author    Lev Proleev <levp@google.com>        2020-03-18 15:20:46 +0000
committer Slava Shklyaev <slavash@google.com>  2020-04-07 12:16:02 +0100
commit    48b0eaf247c3bbb67869f819dec39886c8ba3117 (patch)
tree      dd3d1112b70e59b99bd92ae253b99ee919a58777
parent    77a56993aee1e1e4322541495500aca4dd0b8071 (diff)
Add rank checks to validation functions
The change adds rank checks to validation of operations that only support
tensors of rank 4 or less. This requirement comes from legacy TF Lite code
and is likely to be relaxed in the future to be on par with TF Lite.

Adding the checks to validation is beneficial for the TF Lite delegate: on a
validation error, the NNAPI node is fully rejected by the delegate, whereas
an execution error causes TF Lite to run the NNAPI node on every invocation
only to receive an error and fall back to the CPU implementation.

Bug: 139957496
Test: NNTest_static
Change-Id: I5cc4c48e775826a237d5ac54c3d2078254bd17a2
Merged-In: I5cc4c48e775826a237d5ac54c3d2078254bd17a2
(cherry picked from commit 9fa29bcc64cb26f3b0c4438d01f3bbe875181ae0)
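The pattern applied throughout the validate() functions below can be summed
up as: skip the rank check when the operand's rank is unknown at validation
time (an empty dimensions vector), otherwise reject ranks above 4. The
following is a minimal, self-contained sketch of that pattern, assuming a
simplified stand-in Shape type and plain error handling in place of the
NN_RET_CHECK_* macros:

#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for the NNAPI Shape type: an empty dimensions vector means
// the rank is not yet known at validation time.
struct Shape {
    std::vector<uint32_t> dimensions;
};

bool hasKnownRank(const Shape& shape) {
    return !shape.dimensions.empty();
}

uint32_t getNumberOfDimensions(const Shape& shape) {
    return static_cast<uint32_t>(shape.dimensions.size());
}

// Simplified analogue of the checks added to validate(): reject only
// when the rank is known and exceeds the legacy TF Lite limit of 4.
// An unknown rank passes validation and is checked again in prepare().
bool validateRankUpTo4(const Shape& input) {
    if (hasKnownRank(input) && getNumberOfDimensions(input) > 4) {
        std::cerr << "Unsupported input tensor rank\n";
        return false;
    }
    return true;
}

int main() {
    std::cout << validateRankUpTo4({{2, 2, 2, 2}}) << '\n';     // 1: rank 4 is fine
    std::cout << validateRankUpTo4({{1, 2, 3, 4, 5}}) << '\n';  // 0: rank 5 rejected
    std::cout << validateRankUpTo4({}) << '\n';                 // 1: unknown rank defers
}

Rejecting at validation time matters because the TF Lite delegate probes
validation once when partitioning the graph; an operation that only fails at
execution time would be retried on every invocation before falling back to
the CPU path.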
-rw-r--r--  nn/common/OperationsUtils.cpp              |  4
-rw-r--r--  nn/common/Utils.cpp                        | 24
-rw-r--r--  nn/common/include/OperationsUtils.h        |  2
-rw-r--r--  nn/common/operations/Activation.cpp        |  4
-rw-r--r--  nn/common/operations/Broadcast.cpp         |  6
-rw-r--r--  nn/common/operations/ChannelShuffle.cpp    |  4
-rw-r--r--  nn/common/operations/Dequantize.cpp        |  6
-rw-r--r--  nn/common/operations/FullyConnected.cpp    |  8
-rw-r--r--  nn/common/operations/L2Normalization.cpp   |  5
-rw-r--r--  nn/common/operations/Reduce.cpp            | 13
-rw-r--r--  nn/common/operations/Softmax.cpp           |  7
-rw-r--r--  nn/common/operations/Squeeze.cpp           |  6
-rw-r--r--  nn/common/operations/StridedSlice.cpp      |  4
-rw-r--r--  nn/common/operations/Transpose.cpp         |  4
-rw-r--r--  nn/runtime/test/TestValidateOperations.cpp | 54
15 files changed, 132 insertions, 19 deletions
diff --git a/nn/common/OperationsUtils.cpp b/nn/common/OperationsUtils.cpp
index c8e594dce..8591afd7e 100644
--- a/nn/common/OperationsUtils.cpp
+++ b/nn/common/OperationsUtils.cpp
@@ -162,6 +162,10 @@ uint32_t getSizeOfDimension(const Shape& shape, uint32_t dimensionIdx) {
return shape.dimensions[dimensionIdx];
}
+bool hasKnownRank(const Shape& shape) {
+ return !shape.dimensions.empty();
+}
+
bool handleNegativeAxis(int32_t numberOfDimensions, int32_t* axis) {
NN_CHECK(-numberOfDimensions <= *axis && *axis < numberOfDimensions);
if (*axis < 0) {
diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp
index 7d5f9a30e..85788ffb2 100644
--- a/nn/common/Utils.cpp
+++ b/nn/common/Utils.cpp
@@ -961,6 +961,12 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
<< getOperationName(opType);
return ANEURALNETWORKS_BAD_DATA;
}
+ const auto inputRank = operands[inputIndexes[0]].dimensions.size();
+ if (inputRank > 4) {
+ LOG(ERROR) << "Unsupported input tensor rank for operation "
+ << getOperationName(opType);
+ return ANEURALNETWORKS_BAD_DATA;
+ }
return validateOperationOperandTypes(operands, inputCount, inputIndexes,
inExpectedTypes, outputCount, outputIndexes,
outExpectedTypes);
@@ -1500,6 +1506,12 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
<< getOperationName(opType);
return ANEURALNETWORKS_BAD_DATA;
}
+ const auto inputRank = operands[inputIndexes[0]].dimensions.size();
+ if (inputRank > 4) {
+ LOG(ERROR) << "Unsupported input tensor rank for operation "
+ << getOperationName(opType);
+ return ANEURALNETWORKS_BAD_DATA;
+ }
return validateOperationOperandTypes(operands, inputCount, inputIndexes,
inExpectedTypes, outputCount, outputIndexes,
outExpectedTypes);
@@ -1546,6 +1558,12 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
<< getOperationName(opType);
return ANEURALNETWORKS_BAD_DATA;
}
+ const auto inputRank = operands[inputIndexes[0]].dimensions.size();
+ if (inputRank > 4) {
+ LOG(ERROR) << "Unsupported input tensor rank for operation "
+ << getOperationName(opType);
+ return ANEURALNETWORKS_BAD_DATA;
+ }
return validateOperationOperandTypes(operands, inputCount, inputIndexes,
inExpectedTypes, outputCount, outputIndexes,
outExpectedTypes);
@@ -1591,6 +1609,12 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
logInvalidInOutNumber(3, 1);
return ANEURALNETWORKS_BAD_DATA;
}
+ const auto inputRank = operands[inputIndexes[0]].dimensions.size();
+ if (inputRank > 4) {
+ LOG(ERROR) << "Unsupported input tensor rank for operation "
+ << getOperationName(opType);
+ return ANEURALNETWORKS_BAD_DATA;
+ }
auto inputType = operands[inputIndexes[0]].type;
if (inputType == OperandType::TENSOR_FLOAT32) {
NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1));
diff --git a/nn/common/include/OperationsUtils.h b/nn/common/include/OperationsUtils.h
index a50b52252..dab1aef35 100644
--- a/nn/common/include/OperationsUtils.h
+++ b/nn/common/include/OperationsUtils.h
@@ -152,6 +152,8 @@ uint32_t getNumberOfDimensions(const Shape& shape);
uint32_t getSizeOfDimension(const Shape& shape, uint32_t dimensionIdx);
+bool hasKnownRank(const Shape& shape);
+
// Converts an axis index from the range [-dims, dims) into the range [0, dims).
bool handleNegativeAxis(int32_t numberOfDimensions, int32_t* axis);
diff --git a/nn/common/operations/Activation.cpp b/nn/common/operations/Activation.cpp
index a6a3e82c7..f12ed7a61 100644
--- a/nn/common/operations/Activation.cpp
+++ b/nn/common/operations/Activation.cpp
@@ -375,6 +375,10 @@ bool validate(OperationType opType, const IOperationValidationContext* context)
} else {
NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << getOperationName(opType);
}
+ const Shape& input = context->getInputShape(kInputTensor);
+ if (hasKnownRank(input)) {
+ NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
+ }
return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType});
}
diff --git a/nn/common/operations/Broadcast.cpp b/nn/common/operations/Broadcast.cpp
index e19ce742e..17094afa3 100644
--- a/nn/common/operations/Broadcast.cpp
+++ b/nn/common/operations/Broadcast.cpp
@@ -466,6 +466,12 @@ bool validate(OperationType opType, const IOperationValidationContext* context)
} else {
NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << getOperationName(opType);
}
+ const Shape& input1 = context->getInputShape(kInputTensor1);
+ const Shape& input2 = context->getInputShape(kInputTensor2);
+ if (hasKnownRank(input1) && hasKnownRank(input2)) {
+ NN_RET_CHECK_LE(getNumberOfDimensions(input1), 4);
+ NN_RET_CHECK_LE(getNumberOfDimensions(input2), 4);
+ }
return validateInputTypes(context, {inputType, inputType, OperandType::INT32}) &&
validateOutputTypes(context, {inputType});
}
diff --git a/nn/common/operations/ChannelShuffle.cpp b/nn/common/operations/ChannelShuffle.cpp
index c78e496d4..7abf224c8 100644
--- a/nn/common/operations/ChannelShuffle.cpp
+++ b/nn/common/operations/ChannelShuffle.cpp
@@ -69,6 +69,10 @@ bool validate(const IOperationValidationContext* context) {
inputType == OperandType::TENSOR_QUANT8_ASYMM ||
inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)
<< "Unsupported tensor type for operation " << kOperationName;
+ const Shape& inputShape = context->getInputShape(kInputTensor);
+ if (hasKnownRank(inputShape)) {
+ NN_RET_CHECK_LE(getNumberOfDimensions(inputShape), 4);
+ }
NN_RET_CHECK(validateInputTypes(context, {inputType, OperandType::INT32, OperandType::INT32}));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
diff --git a/nn/common/operations/Dequantize.cpp b/nn/common/operations/Dequantize.cpp
index 3505540bf..2fb2d5cb0 100644
--- a/nn/common/operations/Dequantize.cpp
+++ b/nn/common/operations/Dequantize.cpp
@@ -83,6 +83,11 @@ bool validate(const IOperationValidationContext* context) {
const OperandType inputType = context->getInputType(kInputTensor);
const OperandType outputType = context->getOutputType(kOutputTensor);
+ const Shape& input = context->getInputShape(kInputTensor);
+ if (hasKnownRank(input)) {
+ NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
+ }
+
if (inputType == OperandType::TENSOR_QUANT8_ASYMM &&
outputType == OperandType::TENSOR_FLOAT32) {
return validateHalVersion(context, HalVersion::V1_0);
@@ -101,6 +106,7 @@ bool validate(const IOperationValidationContext* context) {
bool prepare(IOperationExecutionContext* context) {
const Shape& input = context->getInputShape(kInputTensor);
+ NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
Shape output = context->getOutputShape(kOutputTensor);
output.dimensions = input.dimensions;
return context->setOutputShape(kOutputTensor, output);
diff --git a/nn/common/operations/FullyConnected.cpp b/nn/common/operations/FullyConnected.cpp
index 29fbec77b..2afbee026 100644
--- a/nn/common/operations/FullyConnected.cpp
+++ b/nn/common/operations/FullyConnected.cpp
@@ -240,6 +240,13 @@ bool validate(const IOperationValidationContext* context) {
}
NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+
+ const Shape& input = context->getInputShape(kInputTensor);
+ if (hasKnownRank(input)) {
+ NN_RET_CHECK_GE(getNumberOfDimensions(input), 2);
+ NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
+ }
+
return true;
}
@@ -260,6 +267,7 @@ bool prepare(IOperationExecutionContext* context) {
// The Tensorflow fully connected layer specification says that input should
// be of at least rank 2, so we check. Tflite doesn't check.
NN_RET_CHECK_GE(getNumberOfDimensions(input), 2);
+ NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
NN_RET_CHECK_EQ(getNumberOfDimensions(weights), 2);
uint32_t input_n_elements = getNumberOfElements(input);
uint32_t num_units = getSizeOfDimension(weights, 0);
diff --git a/nn/common/operations/L2Normalization.cpp b/nn/common/operations/L2Normalization.cpp
index 1925d5471..1f0c9d051 100644
--- a/nn/common/operations/L2Normalization.cpp
+++ b/nn/common/operations/L2Normalization.cpp
@@ -221,6 +221,10 @@ bool validate(const IOperationValidationContext* context) {
} else if (context->getInputShape(kInputTensor).dimensions.size() != 4) {
NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
}
+ const Shape& input = context->getInputShape(kInputTensor);
+ if (hasKnownRank(input)) {
+ NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
+ }
return validateInputTypes(context, inExpectedTypes) &&
validateOutputTypes(context, {inputType});
}
@@ -231,6 +235,7 @@ bool prepare(IOperationExecutionContext* context) {
int32_t axis = context->getNumInputs() == kNumInputs
? context->getInputValue<int32_t>(kAxisScalar)
: -1;
+ NN_RET_CHECK_LE(numDimensions, 4);
NN_RET_CHECK_GE(axis, -numDimensions);
NN_RET_CHECK_LT(axis, numDimensions);
Shape output = context->getOutputShape(kOutputTensor);
diff --git a/nn/common/operations/Reduce.cpp b/nn/common/operations/Reduce.cpp
index b3327c9b7..8b2155238 100644
--- a/nn/common/operations/Reduce.cpp
+++ b/nn/common/operations/Reduce.cpp
@@ -79,6 +79,10 @@ bool validateProdSum(const IOperationValidationContext* context) {
NN_RET_CHECK(
validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL}));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+ const Shape& input = context->getInputShape(kInputTensor);
+ if (hasKnownRank(input)) {
+ NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
+ }
return validateHalVersion(context, HalVersion::V1_2);
}
@@ -98,6 +102,10 @@ bool validateMaxMin(const IOperationValidationContext* context) {
if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
minHalVersion = HalVersion::V1_3;
}
+ const Shape& input = context->getInputShape(kInputTensor);
+ if (hasKnownRank(input)) {
+ NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
+ }
return validateHalVersion(context, minHalVersion);
}
@@ -110,12 +118,17 @@ bool validateLogical(const IOperationValidationContext* context) {
NN_RET_CHECK(
validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL}));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+ const Shape& input = context->getInputShape(kInputTensor);
+ if (hasKnownRank(input)) {
+ NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
+ }
return validateHalVersion(context, HalVersion::V1_2);
}
bool prepare(IOperationExecutionContext* context) {
Shape inputShape = context->getInputShape(kInputTensor);
const uint32_t inputRank = getNumberOfDimensions(inputShape);
+ NN_RET_CHECK_LE(inputRank, 4);
std::vector<bool> shouldReduce(inputRank);
const int32_t* axes = context->getInputBuffer<int32_t>(kInputAxes);
diff --git a/nn/common/operations/Softmax.cpp b/nn/common/operations/Softmax.cpp
index f9b8ed2d5..8c0562880 100644
--- a/nn/common/operations/Softmax.cpp
+++ b/nn/common/operations/Softmax.cpp
@@ -246,12 +246,15 @@ bool validate(const IOperationValidationContext* context) {
} else {
NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
}
+ const auto inputRank = getNumberOfDimensions(context->getInputShape(kInputTensor));
+ if (inputRank != 0) {
+ NN_RET_CHECK_LE(inputRank, 4);
+ }
if (context->getNumInputs() == kNumInputs) {
NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
inExpectedTypes.push_back(OperandType::INT32);
} else {
- const size_t ndim = context->getInputShape(kInputTensor).dimensions.size();
- if (ndim != 2 && ndim != 4 && ndim != 0) {
+ if (inputRank != 2 && inputRank != 4 && inputRank != 0) {
NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
}
}
diff --git a/nn/common/operations/Squeeze.cpp b/nn/common/operations/Squeeze.cpp
index ca09703fd..977856d2d 100644
--- a/nn/common/operations/Squeeze.cpp
+++ b/nn/common/operations/Squeeze.cpp
@@ -62,6 +62,10 @@ bool validate(const IOperationValidationContext* context) {
OperandType::TENSOR_INT32,
}));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+ const Shape& input = context->getInputShape(kInputTensor);
+ if (hasKnownRank(input)) {
+ NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
+ }
return validateHalVersion(context, minSupportedHalVersion);
}
@@ -75,6 +79,8 @@ bool prepare(IOperationExecutionContext* context) {
const Shape squeezeDimsShape = context->getInputShape(kSqueezeDims);
int32_t numInputDims = static_cast<int32_t>(getNumberOfDimensions(inputShape));
+ NN_RET_CHECK_LE(getNumberOfDimensions(inputShape), 4);
+
// squeezeDims need to be provided as a 1-D int32 tensor.
NN_OPS_CHECK(squeezeDimsShape.type == OperandType::TENSOR_INT32);
NN_OPS_CHECK(getNumberOfDimensions(squeezeDimsShape) == 1);
diff --git a/nn/common/operations/StridedSlice.cpp b/nn/common/operations/StridedSlice.cpp
index bcc95f66d..88993837a 100644
--- a/nn/common/operations/StridedSlice.cpp
+++ b/nn/common/operations/StridedSlice.cpp
@@ -128,6 +128,10 @@ bool validate(const IOperationValidationContext* context) {
OperandType::INT32,
}));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+ const Shape& input = context->getInputShape(kInputTensor);
+ if (hasKnownRank(input)) {
+ NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
+ }
return validateHalVersion(context, minSupportedHalVersion);
}
diff --git a/nn/common/operations/Transpose.cpp b/nn/common/operations/Transpose.cpp
index e0320c6f7..ff70f9e8b 100644
--- a/nn/common/operations/Transpose.cpp
+++ b/nn/common/operations/Transpose.cpp
@@ -87,6 +87,10 @@ bool validate(const IOperationValidationContext* context) {
} else {
NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
}
+ const Shape& input = context->getInputShape(kInputTensor);
+ if (hasKnownRank(input)) {
+ NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
+ }
return validateInputTypes(context, {inputType, OperandType::TENSOR_INT32}) &&
validateOutputTypes(context, {inputType});
}
diff --git a/nn/runtime/test/TestValidateOperations.cpp b/nn/runtime/test/TestValidateOperations.cpp
index 9e69df0d0..974ce523b 100644
--- a/nn/runtime/test/TestValidateOperations.cpp
+++ b/nn/runtime/test/TestValidateOperations.cpp
@@ -1006,7 +1006,8 @@ void dequantizeOpTest(int32_t inputOperandType, int32_t outputOperandType) {
uint32_t inputDimensions[4] = {2, 2, 2, 2};
ANeuralNetworksOperandType input = getOpType(inputOperandType, 4, inputDimensions);
ANeuralNetworksOperandType output = getOpType(outputOperandType, 4, inputDimensions);
- OperationTestBase dequantizeTest(ANEURALNETWORKS_DEQUANTIZE, {input}, {output});
+ OperationTestBase dequantizeTest(ANEURALNETWORKS_DEQUANTIZE, {input}, {output},
+ {{TensorRankConstraint::UpTo(4)}});
dequantizeTest.testOpsValidations();
}
@@ -1238,7 +1239,9 @@ void simpleMathOpTest(ANeuralNetworksOperationType operationCode, int32_t operan
.scale = 0.0f,
.zeroPoint = 0};
- OperationTestBase simpleMathTest(operationCode, {input1, input2, activation}, {output});
+ OperationTestBase simpleMathTest(
+ operationCode, {input1, input2, activation}, {output},
+ {{TensorRankConstraint::UpTo(4), {0}}, {TensorRankConstraint::UpTo(4), {1}}});
simpleMathTest.testOpsValidations();
}
@@ -1389,7 +1392,13 @@ void activationOpTest(ANeuralNetworksOperationType operationCode, int32_t operan
ANeuralNetworksOperandType input = getOpType(operandCode, 4, inputDimensions);
ANeuralNetworksOperandType output = input;
- OperationTestBase test(operationCode, {input}, {output});
+ std::vector<TensorRankMutator> inputRankMutators;
+ if (operationCode == ANEURALNETWORKS_LOGISTIC || operationCode == ANEURALNETWORKS_RELU ||
+ operationCode == ANEURALNETWORKS_RELU1 || operationCode == ANEURALNETWORKS_RELU6 ||
+ operationCode == ANEURALNETWORKS_TANH) {
+ inputRankMutators.push_back({TensorRankConstraint::UpTo(4)});
+ }
+ OperationTestBase test(operationCode, {input}, {output}, inputRankMutators);
test.testOpsValidations();
}
@@ -1593,7 +1602,8 @@ void reshapeOpTest(int32_t inputOperandCode) {
ANeuralNetworksOperandType shape = getOpType(ANEURALNETWORKS_TENSOR_INT32, 1, shapeDims);
uint32_t outputDimensions[2] = {4, 6};
ANeuralNetworksOperandType output = getOpType(inputOperandCode, 2, outputDimensions);
- OperationTestBase test(ANEURALNETWORKS_RESHAPE, {input, shape}, {output});
+ OperationTestBase test(ANEURALNETWORKS_RESHAPE, {input, shape}, {output},
+ {{TensorRankConstraint::UpTo(4)}});
test.testOpsValidations();
}
@@ -1649,7 +1659,8 @@ void meanOpTest(int32_t inputOperandCode) {
ANeuralNetworksOperandType keepDims = getOpType(ANEURALNETWORKS_INT32);
ANeuralNetworksOperandType output = getOpType(inputOperandCode, 3, inputDimensions);
- OperationTestBase test(ANEURALNETWORKS_MEAN, {input, dims, keepDims}, {output});
+ OperationTestBase test(ANEURALNETWORKS_MEAN, {input, dims, keepDims}, {output},
+ {{TensorRankConstraint::UpTo(4)}});
test.testOpsValidations();
}
@@ -1678,7 +1689,8 @@ void padOpTest(int32_t inputOperandCode) {
getOpType(ANEURALNETWORKS_TENSOR_INT32, 1, padSizeDimensions);
uint32_t outputDimensions[4] = {4, 3, 4, 3};
ANeuralNetworksOperandType output = getOpType(inputOperandCode, 4, outputDimensions);
- OperationTestBase test(ANEURALNETWORKS_PAD, {input, padSize}, {output});
+ OperationTestBase test(ANEURALNETWORKS_PAD, {input, padSize}, {output},
+ {{TensorRankConstraint::UpTo(4)}});
test.testOpsValidations();
}
@@ -1705,7 +1717,8 @@ void padV2OpTest(int32_t inputOperandCode) {
}
uint32_t outputDimensions[4] = {4, 3, 4, 3};
ANeuralNetworksOperandType output = getOpType(inputOperandCode, 4, outputDimensions);
- OperationTestBase test(ANEURALNETWORKS_PAD_V2, {input, padSize, padValue}, {output});
+ OperationTestBase test(ANEURALNETWORKS_PAD_V2, {input, padSize, padValue}, {output},
+ {{TensorRankConstraint::UpTo(4)}});
test.testOpsValidations();
}
@@ -1726,11 +1739,13 @@ void softmaxOpTest(int32_t operandCode) {
beta = getOpType(ANEURALNETWORKS_FLOAT16);
}
- OperationTestBase softmaxTest(ANEURALNETWORKS_SOFTMAX, {input, beta}, {output});
+ OperationTestBase softmaxTest(ANEURALNETWORKS_SOFTMAX, {input, beta}, {output},
+ {{TensorRankConstraint::UpTo(4)}});
softmaxTest.testOpsValidations();
ANeuralNetworksOperandType axis = getOpType(ANEURALNETWORKS_INT32);
- OperationTestBase softmaxAxisTest(ANEURALNETWORKS_SOFTMAX, {input, beta, axis}, {output});
+ OperationTestBase softmaxAxisTest(ANEURALNETWORKS_SOFTMAX, {input, beta, axis}, {output},
+ {{TensorRankConstraint::UpTo(4)}});
softmaxAxisTest.testOpsValidations();
}
@@ -1976,7 +1991,8 @@ void transposeAndSqueezeOpTest(ANeuralNetworksOperationType operationCode, int32
.zeroPoint = 0};
ANeuralNetworksOperandType output = input;
- OperationTestBase transposeAndSqueezeTest(operationCode, {input, dims}, {output});
+ OperationTestBase transposeAndSqueezeTest(operationCode, {input, dims}, {output},
+ {{TensorRankConstraint::UpTo(4)}});
transposeAndSqueezeTest.testOpsValidations();
}
@@ -2320,7 +2336,8 @@ void fullyConnectedOpTest(int32_t operandCode) {
.zeroPoint = 0};
OperationTestBase fullyConnectedTest(ANEURALNETWORKS_FULLY_CONNECTED,
- {input, weights, bias, activation}, {output});
+ {input, weights, bias, activation}, {output},
+ {{TensorRankConstraint::Between(2, 4)}});
fullyConnectedTest.testOpsValidations();
}
@@ -2353,10 +2370,11 @@ void concatenationTest(int32_t operandCode) {
.zeroPoint = 0};
OperationTestBase concat2Test(ANEURALNETWORKS_CONCATENATION, {input1, input2, activation},
- {output});
+ {output}, {{TensorRankConstraint::UpTo(4), {0, 1}}});
concat2Test.testOpsValidations();
- OperationTestBase concat1Test(ANEURALNETWORKS_CONCATENATION, {input1, activation}, {output});
+ OperationTestBase concat1Test(ANEURALNETWORKS_CONCATENATION, {input1, activation}, {output},
+ {{TensorRankConstraint::UpTo(4)}});
concat1Test.testOpsValidations();
}
@@ -3085,7 +3103,8 @@ void stridedSliceOpTest(int32_t operandCode) {
OperationTestBase stridedSliceTest(
ANEURALNETWORKS_STRIDED_SLICE,
- {input, begins, ends, strides, beginMask, endMask, shrinkAxisMask}, {output});
+ {input, begins, ends, strides, beginMask, endMask, shrinkAxisMask}, {output},
+ {{TensorRankConstraint::UpTo(4)}});
stridedSliceTest.testOpsValidations();
}
@@ -3384,7 +3403,7 @@ void channelShuffleOpTest(int32_t operandCode) {
ANEURALNETWORKS_CHANNEL_SHUFFLE,
{getOpType(operandCode, 2, inoutDim), getOpType(ANEURALNETWORKS_INT32),
getOpType(ANEURALNETWORKS_INT32)},
- {getOpType(operandCode, 2, inoutDim)});
+ {getOpType(operandCode, 2, inoutDim)}, {{TensorRankConstraint::UpTo(4)}});
channelShuffleTest.testOpsValidations();
}
@@ -3485,7 +3504,7 @@ void normalizationOpTest(ANeuralNetworksOperationType operationCode, int32_t ope
OperationTestBase normalizationAxisTest(
operationCode, {getOpType(operandCode, 4, inputDim), getOpType(ANEURALNETWORKS_INT32)},
- {getOpType(operandCode, 4, inputDim)});
+ {getOpType(operandCode, 4, inputDim)}, {{TensorRankConstraint::UpTo(4)}});
normalizationAxisTest.testOpsValidations();
}
@@ -3691,7 +3710,8 @@ void reduceOpTest(ANeuralNetworksOperationType operationCode, int32_t inputOpera
ANeuralNetworksOperandType input2 = getOpType(ANEURALNETWORKS_TENSOR_INT32, 1, axesDimensions);
ANeuralNetworksOperandType input3 = getOpType(ANEURALNETWORKS_BOOL, 0);
ANeuralNetworksOperandType output = getOpType(inputOperandType, 4, inputDimensions);
- OperationTestBase test(operationCode, {input1, input2, input3}, {output});
+ OperationTestBase test(operationCode, {input1, input2, input3}, {output},
+ {{TensorRankConstraint::UpTo(4)}});
test.testOpsValidations();
}
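For reference, the test changes above pass rank constraints such as
TensorRankConstraint::UpTo(4) and TensorRankConstraint::Between(2, 4),
optionally paired with the input indexes they apply to (e.g. {0, 1}). The
helper below is a hypothetical sketch of that interface for illustration
only; the actual definitions live in TestValidateOperations.cpp and may
differ.

#include <cstdint>
#include <set>
#include <vector>

// Hypothetical sketch of the rank-constraint helpers used by the tests;
// the real TestValidateOperations.cpp implementation may differ.
class TensorRankConstraint {
   public:
    // Allow any rank from 1 up to and including max.
    static TensorRankConstraint UpTo(uint32_t max) { return {1, max}; }
    // Allow any rank in the inclusive range [min, max].
    static TensorRankConstraint Between(uint32_t min, uint32_t max) { return {min, max}; }

    // Ranks just outside the allowed range; the test harness would mutate
    // an operand to each of these and expect validation to fail.
    std::set<uint32_t> invalidRanks() const {
        std::set<uint32_t> ranks = {mMax + 1};
        if (mMin > 1) ranks.insert(mMin - 1);
        return ranks;
    }

   private:
    TensorRankConstraint(uint32_t min, uint32_t max) : mMin(min), mMax(max) {}
    uint32_t mMin;
    uint32_t mMax;
};

// Pairs a constraint with the operation inputs it covers, mirroring
// usages like {TensorRankConstraint::UpTo(4), {0, 1}} above.
struct TensorRankMutator {
    TensorRankConstraint constraint;
    std::vector<uint32_t> inputIndexes = {0};
};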