summaryrefslogtreecommitdiff
path: root/nn/common/Utils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'nn/common/Utils.cpp')
-rw-r--r--nn/common/Utils.cpp22
1 files changed, 20 insertions, 2 deletions
diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp
index 81e5cf1e1..7a66b68ef 100644
--- a/nn/common/Utils.cpp
+++ b/nn/common/Utils.cpp
@@ -26,7 +26,10 @@
#include <sys/system_properties.h>
#include <algorithm>
+#include <functional>
+#include <iostream>
#include <limits>
+#include <numeric>
#include <set>
#include <string>
#include <tuple>
@@ -1508,8 +1511,10 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
logInvalidInOutNumber(1, 1);
return ANEURALNETWORKS_BAD_DATA;
}
- auto inputType = operands[inputIndexes[0]].type;
- auto outputType = operands[outputIndexes[0]].type;
+ auto inputOperand = operands[inputIndexes[0]];
+ auto outputOperand = operands[outputIndexes[0]];
+ auto inputType = inputOperand.type;
+ auto outputType = outputOperand.type;
std::vector<OperandType> inExpectedTypes;
std::vector<OperandType> outExpectedTypes;
if ((inputType == OperandType::TENSOR_FLOAT16 ||
@@ -1535,6 +1540,19 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
LOG(ERROR) << "Unsupported data type for operation " << getOperationName(opType);
return ANEURALNETWORKS_BAD_DATA;
}
+ // Validate that output shape is equal to input shape if dimensions
+ // are already known.
+ auto getNumberOfElements = [](const hardware::hidl_vec<uint32_t>& dims) {
+ if (dims.size() == 0) {
+ return 0;
+ }
+ return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>());
+ };
+ if (inputOperand.dimensions.size() != 0 && outputOperand.dimensions.size() != 0 &&
+ getNumberOfElements(outputOperand.dimensions) != 0 &&
+ inputOperand.dimensions != outputOperand.dimensions) {
+ return ANEURALNETWORKS_BAD_DATA;
+ }
return validateOperationOperandTypes(operands, inputCount, inputIndexes,
inExpectedTypes, outputCount, outputIndexes,
outExpectedTypes);