diff options
Diffstat (limited to 'nn/common/Utils.cpp')
-rw-r--r-- | nn/common/Utils.cpp | 22 |
1 files changed, 20 insertions, 2 deletions
diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp index 81e5cf1e1..7a66b68ef 100644 --- a/nn/common/Utils.cpp +++ b/nn/common/Utils.cpp @@ -26,7 +26,10 @@ #include <sys/system_properties.h> #include <algorithm> +#include <functional> +#include <iostream> #include <limits> +#include <numeric> #include <set> #include <string> #include <tuple> @@ -1508,8 +1511,10 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, logInvalidInOutNumber(1, 1); return ANEURALNETWORKS_BAD_DATA; } - auto inputType = operands[inputIndexes[0]].type; - auto outputType = operands[outputIndexes[0]].type; + auto inputOperand = operands[inputIndexes[0]]; + auto outputOperand = operands[outputIndexes[0]]; + auto inputType = inputOperand.type; + auto outputType = outputOperand.type; std::vector<OperandType> inExpectedTypes; std::vector<OperandType> outExpectedTypes; if ((inputType == OperandType::TENSOR_FLOAT16 || @@ -1535,6 +1540,19 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, LOG(ERROR) << "Unsupported data type for operation " << getOperationName(opType); return ANEURALNETWORKS_BAD_DATA; } + // Validate that output shape is equal to input shape if dimensions + // are already known. + auto getNumberOfElements = [](const hardware::hidl_vec<uint32_t>& dims) { + if (dims.size() == 0) { + return 0; + } + return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>()); + }; + if (inputOperand.dimensions.size() != 0 && outputOperand.dimensions.size() != 0 && + getNumberOfElements(outputOperand.dimensions) != 0 && + inputOperand.dimensions != outputOperand.dimensions) { + return ANEURALNETWORKS_BAD_DATA; + } return validateOperationOperandTypes(operands, inputCount, inputIndexes, inExpectedTypes, outputCount, outputIndexes, outExpectedTypes); |