author    Michael Butler <butlermichael@google.com>   2020-11-02 23:09:34 -0800
committer Michael Butler <butlermichael@google.com>   2020-11-06 18:11:02 -0800
commit    d6f4f1ed9fea50529006a0aa3436e4bce4decd05 (patch)
tree      8bb1d3ea47aab1a33e489ec220311770f3449b0c
parent    f1c452cda6533807bcab2337cb13d5184405505c (diff)
download  ml-d6f4f1ed9fea50529006a0aa3436e4bce4decd05.tar.gz
Reorganize operation validation version code
Bug: N/A
Test: mma
Test: NeuralNetworksTest_static
Change-Id: Iae0d6ef34551b1c3ad5ad670ff54733d38c288af
Merged-In: Iae0d6ef34551b1c3ad5ad670ff54733d38c288af
(cherry picked from commit 109f573c3d0feef5ebe8f86a0d10240dcb43254d)
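The patch applies the same reorganization to every validate() routine listed below: instead of calling validateVersion() inside each tensor-type branch, each branch now only records the minimum supported Version, and a single validateVersion() call runs after the input/output type checks. A self-contained sketch of the resulting shape follows; Version, OperandType, and validateVersion() here are simplified stand-ins, not the real NNAPI declarations.

// Sketch of the reorganized validation pattern (simplified stand-in types).
#include <iostream>

enum class Version { ANDROID_OC_MR1 = 0, ANDROID_P, ANDROID_Q, ANDROID_R };
enum class OperandType { TENSOR_FLOAT32, TENSOR_FLOAT16, TENSOR_QUANT8_ASYMM_SIGNED };

// Stand-in for validateVersion(): true if the caller's version meets the minimum.
bool validateVersion(Version contextVersion, Version minSupportedVersion) {
    return contextVersion >= minSupportedVersion;
}

// Each type branch only records the minimum required version; the single
// validateVersion() call runs after the other checks, as in the hunks below.
bool validate(OperandType inputType, Version contextVersion) {
    auto minSupportedVersion = Version::ANDROID_OC_MR1;
    if (inputType == OperandType::TENSOR_FLOAT32) {
        minSupportedVersion = Version::ANDROID_OC_MR1;
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        minSupportedVersion = Version::ANDROID_Q;
    } else {  // TENSOR_QUANT8_ASYMM_SIGNED
        minSupportedVersion = Version::ANDROID_R;
    }
    // ... validateInputTypes()/validateOutputTypes() equivalents would run here ...
    return validateVersion(contextVersion, minSupportedVersion);
}

int main() {
    std::cout << validate(OperandType::TENSOR_FLOAT16, Version::ANDROID_P) << '\n';  // 0
    std::cout << validate(OperandType::TENSOR_FLOAT16, Version::ANDROID_R) << '\n';  // 1
}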
-rw-r--r--  nn/common/operations/Activation.cpp                  22
-rw-r--r--  nn/common/operations/Broadcast.cpp                   26
-rw-r--r--  nn/common/operations/Concatenation.cpp               12
-rw-r--r--  nn/common/operations/Conv2D.cpp                      12
-rw-r--r--  nn/common/operations/DepthwiseConv2D.cpp             12
-rw-r--r--  nn/common/operations/Elu.cpp                          8
-rw-r--r--  nn/common/operations/FullyConnected.cpp              13
-rw-r--r--  nn/common/operations/L2Normalization.cpp             16
-rw-r--r--  nn/common/operations/LocalResponseNormalization.cpp  14
-rw-r--r--  nn/common/operations/Pooling.cpp                     19
-rw-r--r--  nn/common/operations/ResizeImageOps.cpp              19
-rw-r--r--  nn/common/operations/Slice.cpp                       12
-rw-r--r--  nn/common/operations/Softmax.cpp                     17
-rw-r--r--  nn/common/operations/Transpose.cpp                   12
-rw-r--r--  nn/common/operations/TransposeConv2D.cpp              6
15 files changed, 126 insertions, 94 deletions
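Several hunks below (Broadcast.cpp, Pooling.cpp, ResizeImageOps.cpp, Softmax.cpp) also replace std::max(version, opIntroducedAt) with combineVersions() and add an include of nnapi/Validation.h, which declares that helper. Its definition is not part of this patch; a hypothetical stand-in that captures the intended behavior, assuming the combined requirement is simply the later of the two versions:

// Hypothetical stand-in for combineVersions() from nnapi/Validation.h (not the real
// definition): combining two version requirements keeps the later (higher) one.
#include <algorithm>

enum class Version { ANDROID_OC_MR1 = 0, ANDROID_P, ANDROID_Q, ANDROID_R };

Version combineVersions(Version lhs, Version rhs) {
    return std::max(lhs, rhs);  // enumerators are declared in release order
}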
diff --git a/nn/common/operations/Activation.cpp b/nn/common/operations/Activation.cpp
index e3d848799..bcf846a80 100644
--- a/nn/common/operations/Activation.cpp
+++ b/nn/common/operations/Activation.cpp
@@ -357,18 +357,19 @@ bool validate(OperationType opType, const IOperationValidationContext* context)
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
auto inputType = context->getInputType(kInputTensor);
+ auto minSupportedVersion = Version::ANDROID_OC_MR1;
if (inputType == OperandType::TENSOR_FLOAT32) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1));
+ minSupportedVersion = Version::ANDROID_OC_MR1;
} else if (inputType == OperandType::TENSOR_FLOAT16) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = Version::ANDROID_Q;
} else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
if (opType == OperationType::TANH) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = Version::ANDROID_Q;
} else {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1));
+ minSupportedVersion = Version::ANDROID_OC_MR1;
}
} else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_R));
+ minSupportedVersion = Version::ANDROID_R;
} else {
NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << opType;
}
@@ -376,21 +377,26 @@ bool validate(OperationType opType, const IOperationValidationContext* context)
if (hasKnownRank(input)) {
NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
}
- return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType});
+ NN_RET_CHECK(validateInputTypes(context, {inputType}));
+ NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+ return validateVersion(context, minSupportedVersion);
}
bool validateHardSwish(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
auto inputType = context->getInputType(kInputTensor);
+ auto minSupportedVersion = Version::ANDROID_OC_MR1;
if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32 ||
inputType == OperandType::TENSOR_QUANT8_ASYMM ||
inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_R));
+ minSupportedVersion = Version::ANDROID_R;
} else {
NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ELU";
}
- return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType});
+ NN_RET_CHECK(validateInputTypes(context, {inputType}));
+ NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+ return validateVersion(context, minSupportedVersion);
}
bool prepare(OperationType opType, IOperationExecutionContext* context) {
diff --git a/nn/common/operations/Broadcast.cpp b/nn/common/operations/Broadcast.cpp
index ce1320fb5..e47bd21e8 100644
--- a/nn/common/operations/Broadcast.cpp
+++ b/nn/common/operations/Broadcast.cpp
@@ -33,6 +33,7 @@
#include "OperationResolver.h"
#include "Tracing.h"
#include "nnapi/Types.h"
+#include "nnapi/Validation.h"
namespace android {
namespace nn {
@@ -434,19 +435,19 @@ bool divFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, c
} // namespace
bool validate(OperationType opType, const IOperationValidationContext* context) {
- const Version opIntroducedAt = (opType == OperationType::DIV || opType == OperationType::SUB)
- ? Version::ANDROID_P
- : Version::ANDROID_OC_MR1;
+ auto minSupportedVersion = (opType == OperationType::DIV || opType == OperationType::SUB)
+ ? Version::ANDROID_P
+ : Version::ANDROID_OC_MR1;
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
auto inputType = context->getInputType(kInputTensor1);
if (inputType == OperandType::TENSOR_FLOAT32) {
- NN_RET_CHECK(validateVersion(context, std::max(Version::ANDROID_OC_MR1, opIntroducedAt)));
+ minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_OC_MR1);
} else if (inputType == OperandType::TENSOR_FLOAT16) {
- NN_RET_CHECK(validateVersion(context, std::max(Version::ANDROID_Q, opIntroducedAt)));
+ minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q);
} else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
if (opType == OperationType::SUB) {
- NN_RET_CHECK(validateVersion(context, std::max(Version::ANDROID_Q, opIntroducedAt)));
+ minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q);
} else if (opType == OperationType::DIV) {
NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation DIV";
} else if (opType == OperationType::MUL) {
@@ -454,15 +455,13 @@ bool validate(OperationType opType, const IOperationValidationContext* context)
Shape input1 = context->getInputShape(kInputTensor1);
Shape input2 = context->getInputShape(kInputTensor2);
NN_RET_CHECK_GT(output.scale, input1.scale * input2.scale);
- NN_RET_CHECK(
- validateVersion(context, std::max(Version::ANDROID_OC_MR1, opIntroducedAt)));
+ minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_OC_MR1);
} else {
- NN_RET_CHECK(
- validateVersion(context, std::max(Version::ANDROID_OC_MR1, opIntroducedAt)));
+ minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_OC_MR1);
}
} else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED ||
inputType == OperandType::TENSOR_INT32) {
- NN_RET_CHECK(validateVersion(context, std::max(Version::ANDROID_R, opIntroducedAt)));
+ minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_R);
} else {
NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << opType;
}
@@ -472,8 +471,9 @@ bool validate(OperationType opType, const IOperationValidationContext* context)
NN_RET_CHECK_LE(getNumberOfDimensions(input1), 4);
NN_RET_CHECK_LE(getNumberOfDimensions(input2), 4);
}
- return validateInputTypes(context, {inputType, inputType, OperandType::INT32}) &&
- validateOutputTypes(context, {inputType});
+ NN_RET_CHECK(validateInputTypes(context, {inputType, inputType, OperandType::INT32}));
+ NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+ return validateVersion(context, minSupportedVersion);
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/Concatenation.cpp b/nn/common/operations/Concatenation.cpp
index cadfd0f65..16a08d6b9 100644
--- a/nn/common/operations/Concatenation.cpp
+++ b/nn/common/operations/Concatenation.cpp
@@ -140,12 +140,13 @@ bool validate(const IOperationValidationContext* context) {
NN_RET_CHECK_GE(inputCount, 2);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
const OperandType inputType = context->getInputType(0);
+ auto minSupportedVersion = Version::ANDROID_OC_MR1;
if (inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1));
+ minSupportedVersion = Version::ANDROID_OC_MR1;
} else if (inputType == OperandType::TENSOR_FLOAT16) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = Version::ANDROID_Q;
} else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_R));
+ minSupportedVersion = Version::ANDROID_R;
} else {
NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
}
@@ -166,8 +167,9 @@ bool validate(const IOperationValidationContext* context) {
NN_RET_CHECK_LE(inputRank, 4);
}
}
- return validateInputTypes(context, inExpectedTypes) &&
- validateOutputTypes(context, {inputType});
+ NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
+ NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+ return validateVersion(context, minSupportedVersion);
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/Conv2D.cpp b/nn/common/operations/Conv2D.cpp
index 5a5e33764..d00da57a9 100644
--- a/nn/common/operations/Conv2D.cpp
+++ b/nn/common/operations/Conv2D.cpp
@@ -612,17 +612,19 @@ bool validate(const IOperationValidationContext* context) {
}
}
+ auto minSupportedVersion = Version::ANDROID_OC_MR1;
if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_R));
+ minSupportedVersion = Version::ANDROID_R;
} else if (inputType == OperandType::TENSOR_FLOAT16 ||
filterType == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL || withLayout ||
withDilation || !meetsQuantizedScaleConstraintBeforeV1_2) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = Version::ANDROID_Q;
} else {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1));
+ minSupportedVersion = Version::ANDROID_OC_MR1;
}
- return validateInputTypes(context, inExpectedTypes) &&
- validateOutputTypes(context, {inputType});
+ NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
+ NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+ return validateVersion(context, minSupportedVersion);
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/DepthwiseConv2D.cpp b/nn/common/operations/DepthwiseConv2D.cpp
index 611e38d63..bb158b328 100644
--- a/nn/common/operations/DepthwiseConv2D.cpp
+++ b/nn/common/operations/DepthwiseConv2D.cpp
@@ -495,17 +495,19 @@ bool validate(const IOperationValidationContext* context) {
}
}
+ auto minSupportedVersion = Version::ANDROID_OC_MR1;
if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_R));
+ minSupportedVersion = Version::ANDROID_R;
} else if (inputType == OperandType::TENSOR_FLOAT16 ||
filterType == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL || withLayout ||
withDilation || !meetsQuantizedScaleConstraintBeforeV1_2) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = Version::ANDROID_Q;
} else {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1));
+ minSupportedVersion = Version::ANDROID_OC_MR1;
}
- return validateInputTypes(context, inExpectedTypes) &&
- validateOutputTypes(context, {inputType});
+ NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
+ NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+ return validateVersion(context, minSupportedVersion);
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/Elu.cpp b/nn/common/operations/Elu.cpp
index 0c72cb383..105ef01cf 100644
--- a/nn/common/operations/Elu.cpp
+++ b/nn/common/operations/Elu.cpp
@@ -56,15 +56,17 @@ bool validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
auto inputType = context->getInputType(kInputTensor);
+ auto minSupportedVersion = Version::ANDROID_OC_MR1;
if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_R));
+ minSupportedVersion = Version::ANDROID_R;
} else {
NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ELU";
}
auto scalarType =
inputType == OperandType::TENSOR_FLOAT16 ? OperandType::FLOAT16 : OperandType::FLOAT32;
- return validateInputTypes(context, {inputType, scalarType}) &&
- validateOutputTypes(context, {inputType});
+ NN_RET_CHECK(validateInputTypes(context, {inputType, scalarType}));
+ NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+ return validateVersion(context, minSupportedVersion);
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/FullyConnected.cpp b/nn/common/operations/FullyConnected.cpp
index 7c8c4e304..873b64abf 100644
--- a/nn/common/operations/FullyConnected.cpp
+++ b/nn/common/operations/FullyConnected.cpp
@@ -223,8 +223,9 @@ bool validate(const IOperationValidationContext* context) {
auto inputType = context->getInputType(kInputTensor);
std::vector<OperandType> inExpectedTypes;
std::vector<OperandType> outExpectedTypes;
+ auto minSupportedVersion = Version::ANDROID_OC_MR1;
if (inputType == OperandType::TENSOR_FLOAT32) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1));
+ minSupportedVersion = Version::ANDROID_OC_MR1;
inExpectedTypes = {
OperandType::TENSOR_FLOAT32,
OperandType::TENSOR_FLOAT32,
@@ -232,7 +233,7 @@ bool validate(const IOperationValidationContext* context) {
OperandType::INT32,
};
} else if (inputType == OperandType::TENSOR_FLOAT16) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = Version::ANDROID_Q;
inExpectedTypes = {
OperandType::TENSOR_FLOAT16,
OperandType::TENSOR_FLOAT16,
@@ -249,9 +250,9 @@ bool validate(const IOperationValidationContext* context) {
bool meetsQuantizedScaleConstraintBeforeV1_2 = (outputScale > inputScale * weightsScale);
if (!meetsQuantizedScaleConstraintBeforeV1_2) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = Version::ANDROID_Q;
} else {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1));
+ minSupportedVersion = Version::ANDROID_OC_MR1;
}
inExpectedTypes = {
@@ -261,7 +262,7 @@ bool validate(const IOperationValidationContext* context) {
OperandType::INT32,
};
} else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_R));
+ minSupportedVersion = Version::ANDROID_R;
inExpectedTypes = {
OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
@@ -283,7 +284,7 @@ bool validate(const IOperationValidationContext* context) {
NN_RET_CHECK(validateShapes(input, weights, bias));
}
- return true;
+ return validateVersion(context, minSupportedVersion);
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/L2Normalization.cpp b/nn/common/operations/L2Normalization.cpp
index 49cc15dda..22f0cb3d2 100644
--- a/nn/common/operations/L2Normalization.cpp
+++ b/nn/common/operations/L2Normalization.cpp
@@ -203,27 +203,29 @@ bool validate(const IOperationValidationContext* context) {
const OperandType inputType = context->getInputType(kInputTensor);
std::vector<OperandType> inExpectedTypes = {inputType};
+ auto minSupportedVersion = Version::ANDROID_OC_MR1;
if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_QUANT8_ASYMM) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = Version::ANDROID_Q;
} else if (inputType == OperandType::TENSOR_FLOAT32) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1));
+ minSupportedVersion = Version::ANDROID_OC_MR1;
} else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_R));
+ minSupportedVersion = Version::ANDROID_R;
} else {
NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
}
if (context->getNumInputs() == kNumInputs) {
inExpectedTypes.push_back(OperandType::INT32);
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = Version::ANDROID_Q;
} else if (context->getInputShape(kInputTensor).dimensions.size() != 4) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = Version::ANDROID_Q;
}
const Shape& input = context->getInputShape(kInputTensor);
if (hasKnownRank(input)) {
NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
}
- return validateInputTypes(context, inExpectedTypes) &&
- validateOutputTypes(context, {inputType});
+ NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
+ NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+ return validateVersion(context, minSupportedVersion);
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/LocalResponseNormalization.cpp b/nn/common/operations/LocalResponseNormalization.cpp
index 6276168a7..435d602f2 100644
--- a/nn/common/operations/LocalResponseNormalization.cpp
+++ b/nn/common/operations/LocalResponseNormalization.cpp
@@ -138,15 +138,16 @@ bool validate(const IOperationValidationContext* context) {
const OperandType inputType = context->getInputType(kInputTensor);
std::vector<OperandType> inExpectedTypes;
std::vector<OperandType> outExpectedTypes;
+ auto minSupportedVersion = Version::ANDROID_OC_MR1;
if (inputType == OperandType::TENSOR_FLOAT32) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1));
+ minSupportedVersion = Version::ANDROID_OC_MR1;
inExpectedTypes = {
OperandType::TENSOR_FLOAT32, OperandType::INT32, OperandType::FLOAT32,
OperandType::FLOAT32, OperandType::FLOAT32,
};
outExpectedTypes = {OperandType::TENSOR_FLOAT32};
} else if (inputType == OperandType::TENSOR_FLOAT16) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = Version::ANDROID_Q;
inExpectedTypes = {
OperandType::TENSOR_FLOAT16, OperandType::INT32, OperandType::FLOAT16,
OperandType::FLOAT16, OperandType::FLOAT16,
@@ -158,17 +159,18 @@ bool validate(const IOperationValidationContext* context) {
if (context->getNumInputs() == kNumInputs) {
inExpectedTypes.push_back(OperandType::INT32);
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = Version::ANDROID_Q;
} else if (context->getInputShape(kInputTensor).dimensions.size() != 4) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = Version::ANDROID_Q;
}
const Shape& input = context->getInputShape(kInputTensor);
if (hasKnownRank(input)) {
NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
}
- return validateInputTypes(context, inExpectedTypes) &&
- validateOutputTypes(context, {inputType});
+ NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
+ NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+ return validateVersion(context, minSupportedVersion);
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/Pooling.cpp b/nn/common/operations/Pooling.cpp
index bc6571d79..e1c4cdcd5 100644
--- a/nn/common/operations/Pooling.cpp
+++ b/nn/common/operations/Pooling.cpp
@@ -24,6 +24,7 @@
#include "CpuOperationUtils.h"
#include "OperationResolver.h"
#include "Tracing.h"
+#include "nnapi/Validation.h"
namespace android {
namespace nn {
@@ -293,14 +294,15 @@ bool validate(OperationType opType, const IOperationValidationContext* context)
NN_RET_CHECK(inputCount == 11 || inputCount == 10 || inputCount == 8 || inputCount == 7);
auto inputType = context->getInputType(kInputTensor);
std::vector<OperandType> inExpectedTypes;
+ auto minSupportedVersion = Version::ANDROID_OC_MR1;
if (inputType == OperandType::TENSOR_FLOAT32) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1));
+ minSupportedVersion = Version::ANDROID_OC_MR1;
inExpectedTypes = {
inputType, OperandType::INT32, OperandType::INT32, OperandType::INT32,
OperandType::INT32, OperandType::INT32, OperandType::INT32,
};
} else if (inputType == OperandType::TENSOR_FLOAT16) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = Version::ANDROID_Q;
inExpectedTypes = {
OperandType::TENSOR_FLOAT16, OperandType::INT32, OperandType::INT32,
OperandType::INT32, OperandType::INT32, OperandType::INT32,
@@ -308,7 +310,7 @@ bool validate(OperationType opType, const IOperationValidationContext* context)
};
} else if (opType != OperationType::L2_POOL_2D &&
inputType == OperandType::TENSOR_QUANT8_ASYMM) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1));
+ minSupportedVersion = Version::ANDROID_OC_MR1;
inExpectedTypes = {
OperandType::TENSOR_QUANT8_ASYMM,
OperandType::INT32,
@@ -320,7 +322,7 @@ bool validate(OperationType opType, const IOperationValidationContext* context)
};
} else if (opType != OperationType::L2_POOL_2D &&
inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_R));
+ minSupportedVersion = Version::ANDROID_R;
inExpectedTypes = {
OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
OperandType::INT32,
@@ -341,12 +343,13 @@ bool validate(OperationType opType, const IOperationValidationContext* context)
}
if (inputCount == 11 || inputCount == 8) {
inExpectedTypes.push_back(OperandType::BOOL);
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q);
} else {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1));
+ minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_OC_MR1);
}
- return validateInputTypes(context, inExpectedTypes) &&
- validateOutputTypes(context, {inputType});
+ NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
+ NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+ return validateVersion(context, minSupportedVersion);
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/ResizeImageOps.cpp b/nn/common/operations/ResizeImageOps.cpp
index a1acf187d..2c923f852 100644
--- a/nn/common/operations/ResizeImageOps.cpp
+++ b/nn/common/operations/ResizeImageOps.cpp
@@ -25,6 +25,7 @@
#include "CpuOperationUtils.h"
#include "OperationResolver.h"
#include "Tracing.h"
+#include "nnapi/Validation.h"
namespace android {
namespace nn {
@@ -181,19 +182,20 @@ bool validate(OperationType opType, const IOperationValidationContext* context)
auto inputType = context->getInputType(kInputTensor);
auto scalarType = context->getInputType(kOutputHeightParamScalar);
std::vector<OperandType> inExpectedTypes = {inputType, scalarType, scalarType};
+ auto minSupportedVersion = Version::ANDROID_OC_MR1;
NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
inputType == OperandType::TENSOR_FLOAT32 ||
inputType == OperandType::TENSOR_QUANT8_ASYMM ||
inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)
<< "Unsupported tensor type for operation " << opType;
if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_QUANT8_ASYMM) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q);
}
if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_R));
+ minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_R);
}
if (scalarType != OperandType::INT32) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q);
if (inputType == OperandType::TENSOR_FLOAT32) {
NN_RET_CHECK(scalarType == OperandType::FLOAT32);
} else if (inputType == OperandType::TENSOR_FLOAT16) {
@@ -204,18 +206,19 @@ bool validate(OperationType opType, const IOperationValidationContext* context)
}
}
if (numInputs < kNumInputs) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1));
+ minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_OC_MR1);
} else if (numInputs == kNumInputs) {
inExpectedTypes.push_back(OperandType::BOOL);
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q);
} else {
while (inExpectedTypes.size() < numInputs) {
inExpectedTypes.push_back(OperandType::BOOL);
}
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_R));
+ minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_R);
}
- return validateInputTypes(context, inExpectedTypes) &&
- validateOutputTypes(context, {inputType});
+ NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
+ NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+ return validateVersion(context, minSupportedVersion);
}
bool prepare(OperationType opType, IOperationExecutionContext* context) {
diff --git a/nn/common/operations/Slice.cpp b/nn/common/operations/Slice.cpp
index 3cf3c8a33..882b0eb1a 100644
--- a/nn/common/operations/Slice.cpp
+++ b/nn/common/operations/Slice.cpp
@@ -89,14 +89,16 @@ bool validate(const IOperationValidationContext* context) {
inputType == OperandType::TENSOR_QUANT8_ASYMM ||
inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)
<< "Unsupported tensor type for operation " << kOperationName;
+ auto minSupportedVersion = Version::ANDROID_OC_MR1;
if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_R));
+ minSupportedVersion = Version::ANDROID_R;
} else {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = Version::ANDROID_Q;
}
- return validateInputTypes(context,
- {inputType, OperandType::TENSOR_INT32, OperandType::TENSOR_INT32}) &&
- validateOutputTypes(context, {inputType});
+ NN_RET_CHECK(validateInputTypes(
+ context, {inputType, OperandType::TENSOR_INT32, OperandType::TENSOR_INT32}));
+ NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+ return validateVersion(context, minSupportedVersion);
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/Softmax.cpp b/nn/common/operations/Softmax.cpp
index a9373957a..e3c362f8a 100644
--- a/nn/common/operations/Softmax.cpp
+++ b/nn/common/operations/Softmax.cpp
@@ -27,6 +27,7 @@
#include "CpuOperationUtils.h"
#include "OperationResolver.h"
#include "Tracing.h"
+#include "nnapi/Validation.h"
namespace android {
namespace nn {
@@ -232,14 +233,15 @@ bool validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
auto inputType = context->getInputType(kInputTensor);
std::vector<OperandType> inExpectedTypes;
+ auto minSupportedVersion = Version::ANDROID_OC_MR1;
if (inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1));
+ minSupportedVersion = Version::ANDROID_OC_MR1;
inExpectedTypes = {inputType, OperandType::FLOAT32};
} else if (inputType == OperandType::TENSOR_FLOAT16) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = Version::ANDROID_Q;
inExpectedTypes = {inputType, OperandType::FLOAT16};
} else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_R));
+ minSupportedVersion = Version::ANDROID_R;
inExpectedTypes = {inputType, OperandType::FLOAT32};
} else {
NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
@@ -249,15 +251,16 @@ bool validate(const IOperationValidationContext* context) {
NN_RET_CHECK_LE(inputRank, 4);
}
if (context->getNumInputs() == kNumInputs) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q);
inExpectedTypes.push_back(OperandType::INT32);
} else {
if (inputRank != 2 && inputRank != 4 && inputRank != 0) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q);
}
}
- return validateInputTypes(context, inExpectedTypes) &&
- validateOutputTypes(context, {inputType});
+ NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
+ NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+ return validateVersion(context, minSupportedVersion);
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/Transpose.cpp b/nn/common/operations/Transpose.cpp
index 3bc76f03e..b964c39a8 100644
--- a/nn/common/operations/Transpose.cpp
+++ b/nn/common/operations/Transpose.cpp
@@ -74,12 +74,13 @@ bool validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
const OperandType inputType = context->getInputType(kInputTensor);
+ auto minSupportedVersion = Version::ANDROID_OC_MR1;
if (inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_P));
+ minSupportedVersion = Version::ANDROID_P;
} else if (inputType == OperandType::TENSOR_FLOAT16) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = Version::ANDROID_Q;
} else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_R));
+ minSupportedVersion = Version::ANDROID_R;
} else {
NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
}
@@ -87,8 +88,9 @@ bool validate(const IOperationValidationContext* context) {
if (hasKnownRank(input)) {
NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
}
- return validateInputTypes(context, {inputType, OperandType::TENSOR_INT32}) &&
- validateOutputTypes(context, {inputType});
+ NN_RET_CHECK(validateInputTypes(context, {inputType, OperandType::TENSOR_INT32}));
+ NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+ return validateVersion(context, minSupportedVersion);
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/TransposeConv2D.cpp b/nn/common/operations/TransposeConv2D.cpp
index 78d857a35..9d6dbbbfc 100644
--- a/nn/common/operations/TransposeConv2D.cpp
+++ b/nn/common/operations/TransposeConv2D.cpp
@@ -474,9 +474,9 @@ bool validate(const IOperationValidationContext* context) {
OperandType::INT32, OperandType::INT32, OperandType::BOOL};
}
inExpectedTypes.insert(inExpectedTypes.end(), argExpectedTypes.begin(), argExpectedTypes.end());
- NN_RET_CHECK(validateVersion(context, minSupportedVersion));
- return validateInputTypes(context, inExpectedTypes) &&
- validateOutputTypes(context, {inputType});
+ NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
+ NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+ return validateVersion(context, minSupportedVersion);
}
bool prepare(IOperationExecutionContext* context) {