diff options
author | Lev Proleev <levp@google.com> | 2020-05-20 21:16:27 +0100 |
---|---|---|
committer | Lev Proleev <levp@google.com> | 2020-05-26 15:47:43 +0100 |
commit | d517e96dba4d6528382a770d38882e4610e2ff1b (patch) | |
tree | a2866f5ad2c41b0405cc02aa5f776160c45dbf16 /nn/common/Utils.cpp | |
parent | 3594be18553db032e3af6dae5c6b9e310eb2f3d8 (diff) | |
download | ml-d517e96dba4d6528382a770d38882e4610e2ff1b.tar.gz |
Add shape check to CAST validation
The generated test is added only to CTS since VTS would fail on some 1.2
drivers.
Fix: 156284111
Test: NeuralNetworksTest_static
Change-Id: I4f3c6cdb9f546501e4ca375d7900431c384d6885
Diffstat (limited to 'nn/common/Utils.cpp')
-rw-r--r-- | nn/common/Utils.cpp | 22 |
1 files changed, 20 insertions, 2 deletions
diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp index 81e5cf1e1..7a66b68ef 100644 --- a/nn/common/Utils.cpp +++ b/nn/common/Utils.cpp @@ -26,7 +26,10 @@ #include <sys/system_properties.h> #include <algorithm> +#include <functional> +#include <iostream> #include <limits> +#include <numeric> #include <set> #include <string> #include <tuple> @@ -1508,8 +1511,10 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, logInvalidInOutNumber(1, 1); return ANEURALNETWORKS_BAD_DATA; } - auto inputType = operands[inputIndexes[0]].type; - auto outputType = operands[outputIndexes[0]].type; + auto inputOperand = operands[inputIndexes[0]]; + auto outputOperand = operands[outputIndexes[0]]; + auto inputType = inputOperand.type; + auto outputType = outputOperand.type; std::vector<OperandType> inExpectedTypes; std::vector<OperandType> outExpectedTypes; if ((inputType == OperandType::TENSOR_FLOAT16 || @@ -1535,6 +1540,19 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, LOG(ERROR) << "Unsupported data type for operation " << getOperationName(opType); return ANEURALNETWORKS_BAD_DATA; } + // Validate that output shape is equal to input shape if dimensions + // are already known. + auto getNumberOfElements = [](const hardware::hidl_vec<uint32_t>& dims) { + if (dims.size() == 0) { + return 0; + } + return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>()); + }; + if (inputOperand.dimensions.size() != 0 && outputOperand.dimensions.size() != 0 && + getNumberOfElements(outputOperand.dimensions) != 0 && + inputOperand.dimensions != outputOperand.dimensions) { + return ANEURALNETWORKS_BAD_DATA; + } return validateOperationOperandTypes(operands, inputCount, inputIndexes, inExpectedTypes, outputCount, outputIndexes, outExpectedTypes); |