summaryrefslogtreecommitdiff
path: root/nn/common/Utils.cpp
diff options
context:
space:
mode:
authorLev Proleev <levp@google.com>2020-05-20 21:16:27 +0100
committerLev Proleev <levp@google.com>2020-05-26 15:47:43 +0100
commitd517e96dba4d6528382a770d38882e4610e2ff1b (patch)
treea2866f5ad2c41b0405cc02aa5f776160c45dbf16 /nn/common/Utils.cpp
parent3594be18553db032e3af6dae5c6b9e310eb2f3d8 (diff)
downloadml-d517e96dba4d6528382a770d38882e4610e2ff1b.tar.gz
Add shape check to CAST validation
The generated test is added only to CTS since VTS would fail on some 1.2 drivers. Fix: 156284111 Test: NeuralNetworksTest_static Change-Id: I4f3c6cdb9f546501e4ca375d7900431c384d6885
Diffstat (limited to 'nn/common/Utils.cpp')
-rw-r--r--nn/common/Utils.cpp22
1 files changed, 20 insertions, 2 deletions
diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp
index 81e5cf1e1..7a66b68ef 100644
--- a/nn/common/Utils.cpp
+++ b/nn/common/Utils.cpp
@@ -26,7 +26,10 @@
#include <sys/system_properties.h>
#include <algorithm>
+#include <functional>
+#include <iostream>
#include <limits>
+#include <numeric>
#include <set>
#include <string>
#include <tuple>
@@ -1508,8 +1511,10 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
logInvalidInOutNumber(1, 1);
return ANEURALNETWORKS_BAD_DATA;
}
- auto inputType = operands[inputIndexes[0]].type;
- auto outputType = operands[outputIndexes[0]].type;
+ auto inputOperand = operands[inputIndexes[0]];
+ auto outputOperand = operands[outputIndexes[0]];
+ auto inputType = inputOperand.type;
+ auto outputType = outputOperand.type;
std::vector<OperandType> inExpectedTypes;
std::vector<OperandType> outExpectedTypes;
if ((inputType == OperandType::TENSOR_FLOAT16 ||
@@ -1535,6 +1540,19 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
LOG(ERROR) << "Unsupported data type for operation " << getOperationName(opType);
return ANEURALNETWORKS_BAD_DATA;
}
+ // Validate that output shape is equal to input shape if dimensions
+ // are already known.
+ auto getNumberOfElements = [](const hardware::hidl_vec<uint32_t>& dims) {
+ if (dims.size() == 0) {
+ return 0;
+ }
+ return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>());
+ };
+ if (inputOperand.dimensions.size() != 0 && outputOperand.dimensions.size() != 0 &&
+ getNumberOfElements(outputOperand.dimensions) != 0 &&
+ inputOperand.dimensions != outputOperand.dimensions) {
+ return ANEURALNETWORKS_BAD_DATA;
+ }
return validateOperationOperandTypes(operands, inputCount, inputIndexes,
inExpectedTypes, outputCount, outputIndexes,
outExpectedTypes);