diff options
author | Michael Butler <butlermichael@google.com> | 2020-11-02 23:09:34 -0800 |
---|---|---|
committer | Michael Butler <butlermichael@google.com> | 2020-11-06 18:11:02 -0800 |
commit | d6f4f1ed9fea50529006a0aa3436e4bce4decd05 (patch) | |
tree | 8bb1d3ea47aab1a33e489ec220311770f3449b0c | |
parent | f1c452cda6533807bcab2337cb13d5184405505c (diff) | |
download | ml-d6f4f1ed9fea50529006a0aa3436e4bce4decd05.tar.gz |
Reorganize operation validation version code
Bug: N/A
Test: mma
Test: NeuralNetworksTest_static
Change-Id: Iae0d6ef34551b1c3ad5ad670ff54733d38c288af
Merged-In: Iae0d6ef34551b1c3ad5ad670ff54733d38c288af
(cherry picked from commit 109f573c3d0feef5ebe8f86a0d10240dcb43254d)
-rw-r--r-- | nn/common/operations/Activation.cpp | 22 | ||||
-rw-r--r-- | nn/common/operations/Broadcast.cpp | 26 | ||||
-rw-r--r-- | nn/common/operations/Concatenation.cpp | 12 | ||||
-rw-r--r-- | nn/common/operations/Conv2D.cpp | 12 | ||||
-rw-r--r-- | nn/common/operations/DepthwiseConv2D.cpp | 12 | ||||
-rw-r--r-- | nn/common/operations/Elu.cpp | 8 | ||||
-rw-r--r-- | nn/common/operations/FullyConnected.cpp | 13 | ||||
-rw-r--r-- | nn/common/operations/L2Normalization.cpp | 16 | ||||
-rw-r--r-- | nn/common/operations/LocalResponseNormalization.cpp | 14 | ||||
-rw-r--r-- | nn/common/operations/Pooling.cpp | 19 | ||||
-rw-r--r-- | nn/common/operations/ResizeImageOps.cpp | 19 | ||||
-rw-r--r-- | nn/common/operations/Slice.cpp | 12 | ||||
-rw-r--r-- | nn/common/operations/Softmax.cpp | 17 | ||||
-rw-r--r-- | nn/common/operations/Transpose.cpp | 12 | ||||
-rw-r--r-- | nn/common/operations/TransposeConv2D.cpp | 6 |
15 files changed, 126 insertions, 94 deletions
diff --git a/nn/common/operations/Activation.cpp b/nn/common/operations/Activation.cpp index e3d848799..bcf846a80 100644 --- a/nn/common/operations/Activation.cpp +++ b/nn/common/operations/Activation.cpp @@ -357,18 +357,19 @@ bool validate(OperationType opType, const IOperationValidationContext* context) NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); auto inputType = context->getInputType(kInputTensor); + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT32) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; } else if (inputType == OperandType::TENSOR_FLOAT16) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { if (opType == OperationType::TANH) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } else { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; } } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; } else { NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << opType; } @@ -376,21 +377,26 @@ bool validate(OperationType opType, const IOperationValidationContext* context) if (hasKnownRank(input)) { NN_RET_CHECK_LE(getNumberOfDimensions(input), 4); } - return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, {inputType})); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return validateVersion(context, minSupportedVersion); } bool validateHardSwish(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); auto inputType = context->getInputType(kInputTensor); + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM || inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; } else { NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ELU"; } - return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, {inputType})); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return validateVersion(context, minSupportedVersion); } bool prepare(OperationType opType, IOperationExecutionContext* context) { diff --git a/nn/common/operations/Broadcast.cpp b/nn/common/operations/Broadcast.cpp index ce1320fb5..e47bd21e8 100644 --- a/nn/common/operations/Broadcast.cpp +++ b/nn/common/operations/Broadcast.cpp @@ -33,6 +33,7 @@ #include "OperationResolver.h" #include "Tracing.h" #include "nnapi/Types.h" +#include "nnapi/Validation.h" namespace android { namespace nn { @@ -434,19 +435,19 @@ bool divFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, c } // namespace bool validate(OperationType opType, const IOperationValidationContext* context) { - const Version opIntroducedAt = (opType == OperationType::DIV || opType == OperationType::SUB) - ? Version::ANDROID_P - : Version::ANDROID_OC_MR1; + auto minSupportedVersion = (opType == OperationType::DIV || opType == OperationType::SUB) + ? Version::ANDROID_P + : Version::ANDROID_OC_MR1; NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); auto inputType = context->getInputType(kInputTensor1); if (inputType == OperandType::TENSOR_FLOAT32) { - NN_RET_CHECK(validateVersion(context, std::max(Version::ANDROID_OC_MR1, opIntroducedAt))); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_OC_MR1); } else if (inputType == OperandType::TENSOR_FLOAT16) { - NN_RET_CHECK(validateVersion(context, std::max(Version::ANDROID_Q, opIntroducedAt))); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q); } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { if (opType == OperationType::SUB) { - NN_RET_CHECK(validateVersion(context, std::max(Version::ANDROID_Q, opIntroducedAt))); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q); } else if (opType == OperationType::DIV) { NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation DIV"; } else if (opType == OperationType::MUL) { @@ -454,15 +455,13 @@ bool validate(OperationType opType, const IOperationValidationContext* context) Shape input1 = context->getInputShape(kInputTensor1); Shape input2 = context->getInputShape(kInputTensor2); NN_RET_CHECK_GT(output.scale, input1.scale * input2.scale); - NN_RET_CHECK( - validateVersion(context, std::max(Version::ANDROID_OC_MR1, opIntroducedAt))); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_OC_MR1); } else { - NN_RET_CHECK( - validateVersion(context, std::max(Version::ANDROID_OC_MR1, opIntroducedAt))); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_OC_MR1); } } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED || inputType == OperandType::TENSOR_INT32) { - NN_RET_CHECK(validateVersion(context, std::max(Version::ANDROID_R, opIntroducedAt))); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_R); } else { NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << opType; } @@ -472,8 +471,9 @@ bool validate(OperationType opType, const IOperationValidationContext* context) NN_RET_CHECK_LE(getNumberOfDimensions(input1), 4); NN_RET_CHECK_LE(getNumberOfDimensions(input2), 4); } - return validateInputTypes(context, {inputType, inputType, OperandType::INT32}) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, {inputType, inputType, OperandType::INT32})); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return validateVersion(context, minSupportedVersion); } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/Concatenation.cpp b/nn/common/operations/Concatenation.cpp index cadfd0f65..16a08d6b9 100644 --- a/nn/common/operations/Concatenation.cpp +++ b/nn/common/operations/Concatenation.cpp @@ -140,12 +140,13 @@ bool validate(const IOperationValidationContext* context) { NN_RET_CHECK_GE(inputCount, 2); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); const OperandType inputType = context->getInputType(0); + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; } else if (inputType == OperandType::TENSOR_FLOAT16) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; } else { NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName; } @@ -166,8 +167,9 @@ bool validate(const IOperationValidationContext* context) { NN_RET_CHECK_LE(inputRank, 4); } } - return validateInputTypes(context, inExpectedTypes) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return validateVersion(context, minSupportedVersion); } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/Conv2D.cpp b/nn/common/operations/Conv2D.cpp index 5a5e33764..d00da57a9 100644 --- a/nn/common/operations/Conv2D.cpp +++ b/nn/common/operations/Conv2D.cpp @@ -612,17 +612,19 @@ bool validate(const IOperationValidationContext* context) { } } + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; } else if (inputType == OperandType::TENSOR_FLOAT16 || filterType == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL || withLayout || withDilation || !meetsQuantizedScaleConstraintBeforeV1_2) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } else { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; } - return validateInputTypes(context, inExpectedTypes) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return validateVersion(context, minSupportedVersion); } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/DepthwiseConv2D.cpp b/nn/common/operations/DepthwiseConv2D.cpp index 611e38d63..bb158b328 100644 --- a/nn/common/operations/DepthwiseConv2D.cpp +++ b/nn/common/operations/DepthwiseConv2D.cpp @@ -495,17 +495,19 @@ bool validate(const IOperationValidationContext* context) { } } + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; } else if (inputType == OperandType::TENSOR_FLOAT16 || filterType == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL || withLayout || withDilation || !meetsQuantizedScaleConstraintBeforeV1_2) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } else { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; } - return validateInputTypes(context, inExpectedTypes) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return validateVersion(context, minSupportedVersion); } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/Elu.cpp b/nn/common/operations/Elu.cpp index 0c72cb383..105ef01cf 100644 --- a/nn/common/operations/Elu.cpp +++ b/nn/common/operations/Elu.cpp @@ -56,15 +56,17 @@ bool validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); auto inputType = context->getInputType(kInputTensor); + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; } else { NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ELU"; } auto scalarType = inputType == OperandType::TENSOR_FLOAT16 ? OperandType::FLOAT16 : OperandType::FLOAT32; - return validateInputTypes(context, {inputType, scalarType}) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, {inputType, scalarType})); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return validateVersion(context, minSupportedVersion); } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/FullyConnected.cpp b/nn/common/operations/FullyConnected.cpp index 7c8c4e304..873b64abf 100644 --- a/nn/common/operations/FullyConnected.cpp +++ b/nn/common/operations/FullyConnected.cpp @@ -223,8 +223,9 @@ bool validate(const IOperationValidationContext* context) { auto inputType = context->getInputType(kInputTensor); std::vector<OperandType> inExpectedTypes; std::vector<OperandType> outExpectedTypes; + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT32) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; inExpectedTypes = { OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, @@ -232,7 +233,7 @@ bool validate(const IOperationValidationContext* context) { OperandType::INT32, }; } else if (inputType == OperandType::TENSOR_FLOAT16) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; inExpectedTypes = { OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, @@ -249,9 +250,9 @@ bool validate(const IOperationValidationContext* context) { bool meetsQuantizedScaleConstraintBeforeV1_2 = (outputScale > inputScale * weightsScale); if (!meetsQuantizedScaleConstraintBeforeV1_2) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } else { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; } inExpectedTypes = { @@ -261,7 +262,7 @@ bool validate(const IOperationValidationContext* context) { OperandType::INT32, }; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; inExpectedTypes = { OperandType::TENSOR_QUANT8_ASYMM_SIGNED, @@ -283,7 +284,7 @@ bool validate(const IOperationValidationContext* context) { NN_RET_CHECK(validateShapes(input, weights, bias)); } - return true; + return validateVersion(context, minSupportedVersion); } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/L2Normalization.cpp b/nn/common/operations/L2Normalization.cpp index 49cc15dda..22f0cb3d2 100644 --- a/nn/common/operations/L2Normalization.cpp +++ b/nn/common/operations/L2Normalization.cpp @@ -203,27 +203,29 @@ bool validate(const IOperationValidationContext* context) { const OperandType inputType = context->getInputType(kInputTensor); std::vector<OperandType> inExpectedTypes = {inputType}; + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_QUANT8_ASYMM) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } else if (inputType == OperandType::TENSOR_FLOAT32) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; } else { NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName; } if (context->getNumInputs() == kNumInputs) { inExpectedTypes.push_back(OperandType::INT32); - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } else if (context->getInputShape(kInputTensor).dimensions.size() != 4) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } const Shape& input = context->getInputShape(kInputTensor); if (hasKnownRank(input)) { NN_RET_CHECK_LE(getNumberOfDimensions(input), 4); } - return validateInputTypes(context, inExpectedTypes) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return validateVersion(context, minSupportedVersion); } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/LocalResponseNormalization.cpp b/nn/common/operations/LocalResponseNormalization.cpp index 6276168a7..435d602f2 100644 --- a/nn/common/operations/LocalResponseNormalization.cpp +++ b/nn/common/operations/LocalResponseNormalization.cpp @@ -138,15 +138,16 @@ bool validate(const IOperationValidationContext* context) { const OperandType inputType = context->getInputType(kInputTensor); std::vector<OperandType> inExpectedTypes; std::vector<OperandType> outExpectedTypes; + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT32) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; inExpectedTypes = { OperandType::TENSOR_FLOAT32, OperandType::INT32, OperandType::FLOAT32, OperandType::FLOAT32, OperandType::FLOAT32, }; outExpectedTypes = {OperandType::TENSOR_FLOAT32}; } else if (inputType == OperandType::TENSOR_FLOAT16) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; inExpectedTypes = { OperandType::TENSOR_FLOAT16, OperandType::INT32, OperandType::FLOAT16, OperandType::FLOAT16, OperandType::FLOAT16, @@ -158,17 +159,18 @@ bool validate(const IOperationValidationContext* context) { if (context->getNumInputs() == kNumInputs) { inExpectedTypes.push_back(OperandType::INT32); - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } else if (context->getInputShape(kInputTensor).dimensions.size() != 4) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } const Shape& input = context->getInputShape(kInputTensor); if (hasKnownRank(input)) { NN_RET_CHECK_LE(getNumberOfDimensions(input), 4); } - return validateInputTypes(context, inExpectedTypes) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return validateVersion(context, minSupportedVersion); } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/Pooling.cpp b/nn/common/operations/Pooling.cpp index bc6571d79..e1c4cdcd5 100644 --- a/nn/common/operations/Pooling.cpp +++ b/nn/common/operations/Pooling.cpp @@ -24,6 +24,7 @@ #include "CpuOperationUtils.h" #include "OperationResolver.h" #include "Tracing.h" +#include "nnapi/Validation.h" namespace android { namespace nn { @@ -293,14 +294,15 @@ bool validate(OperationType opType, const IOperationValidationContext* context) NN_RET_CHECK(inputCount == 11 || inputCount == 10 || inputCount == 8 || inputCount == 7); auto inputType = context->getInputType(kInputTensor); std::vector<OperandType> inExpectedTypes; + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT32) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; inExpectedTypes = { inputType, OperandType::INT32, OperandType::INT32, OperandType::INT32, OperandType::INT32, OperandType::INT32, OperandType::INT32, }; } else if (inputType == OperandType::TENSOR_FLOAT16) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; inExpectedTypes = { OperandType::TENSOR_FLOAT16, OperandType::INT32, OperandType::INT32, OperandType::INT32, OperandType::INT32, OperandType::INT32, @@ -308,7 +310,7 @@ bool validate(OperationType opType, const IOperationValidationContext* context) }; } else if (opType != OperationType::L2_POOL_2D && inputType == OperandType::TENSOR_QUANT8_ASYMM) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; inExpectedTypes = { OperandType::TENSOR_QUANT8_ASYMM, OperandType::INT32, @@ -320,7 +322,7 @@ bool validate(OperationType opType, const IOperationValidationContext* context) }; } else if (opType != OperationType::L2_POOL_2D && inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; inExpectedTypes = { OperandType::TENSOR_QUANT8_ASYMM_SIGNED, OperandType::INT32, @@ -341,12 +343,13 @@ bool validate(OperationType opType, const IOperationValidationContext* context) } if (inputCount == 11 || inputCount == 8) { inExpectedTypes.push_back(OperandType::BOOL); - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q); } else { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_OC_MR1); } - return validateInputTypes(context, inExpectedTypes) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return validateVersion(context, minSupportedVersion); } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/ResizeImageOps.cpp b/nn/common/operations/ResizeImageOps.cpp index a1acf187d..2c923f852 100644 --- a/nn/common/operations/ResizeImageOps.cpp +++ b/nn/common/operations/ResizeImageOps.cpp @@ -25,6 +25,7 @@ #include "CpuOperationUtils.h" #include "OperationResolver.h" #include "Tracing.h" +#include "nnapi/Validation.h" namespace android { namespace nn { @@ -181,19 +182,20 @@ bool validate(OperationType opType, const IOperationValidationContext* context) auto inputType = context->getInputType(kInputTensor); auto scalarType = context->getInputType(kOutputHeightParamScalar); std::vector<OperandType> inExpectedTypes = {inputType, scalarType, scalarType}; + auto minSupportedVersion = Version::ANDROID_OC_MR1; NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM || inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) << "Unsupported tensor type for operation " << opType; if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_QUANT8_ASYMM) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q); } if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_R); } if (scalarType != OperandType::INT32) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q); if (inputType == OperandType::TENSOR_FLOAT32) { NN_RET_CHECK(scalarType == OperandType::FLOAT32); } else if (inputType == OperandType::TENSOR_FLOAT16) { @@ -204,18 +206,19 @@ bool validate(OperationType opType, const IOperationValidationContext* context) } } if (numInputs < kNumInputs) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_OC_MR1); } else if (numInputs == kNumInputs) { inExpectedTypes.push_back(OperandType::BOOL); - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q); } else { while (inExpectedTypes.size() < numInputs) { inExpectedTypes.push_back(OperandType::BOOL); } - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_R); } - return validateInputTypes(context, inExpectedTypes) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return validateVersion(context, minSupportedVersion); } bool prepare(OperationType opType, IOperationExecutionContext* context) { diff --git a/nn/common/operations/Slice.cpp b/nn/common/operations/Slice.cpp index 3cf3c8a33..882b0eb1a 100644 --- a/nn/common/operations/Slice.cpp +++ b/nn/common/operations/Slice.cpp @@ -89,14 +89,16 @@ bool validate(const IOperationValidationContext* context) { inputType == OperandType::TENSOR_QUANT8_ASYMM || inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) << "Unsupported tensor type for operation " << kOperationName; + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; } else { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } - return validateInputTypes(context, - {inputType, OperandType::TENSOR_INT32, OperandType::TENSOR_INT32}) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes( + context, {inputType, OperandType::TENSOR_INT32, OperandType::TENSOR_INT32})); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return validateVersion(context, minSupportedVersion); } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/Softmax.cpp b/nn/common/operations/Softmax.cpp index a9373957a..e3c362f8a 100644 --- a/nn/common/operations/Softmax.cpp +++ b/nn/common/operations/Softmax.cpp @@ -27,6 +27,7 @@ #include "CpuOperationUtils.h" #include "OperationResolver.h" #include "Tracing.h" +#include "nnapi/Validation.h" namespace android { namespace nn { @@ -232,14 +233,15 @@ bool validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); auto inputType = context->getInputType(kInputTensor); std::vector<OperandType> inExpectedTypes; + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; inExpectedTypes = {inputType, OperandType::FLOAT32}; } else if (inputType == OperandType::TENSOR_FLOAT16) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; inExpectedTypes = {inputType, OperandType::FLOAT16}; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; inExpectedTypes = {inputType, OperandType::FLOAT32}; } else { NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName; @@ -249,15 +251,16 @@ bool validate(const IOperationValidationContext* context) { NN_RET_CHECK_LE(inputRank, 4); } if (context->getNumInputs() == kNumInputs) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q); inExpectedTypes.push_back(OperandType::INT32); } else { if (inputRank != 2 && inputRank != 4 && inputRank != 0) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q); } } - return validateInputTypes(context, inExpectedTypes) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return validateVersion(context, minSupportedVersion); } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/Transpose.cpp b/nn/common/operations/Transpose.cpp index 3bc76f03e..b964c39a8 100644 --- a/nn/common/operations/Transpose.cpp +++ b/nn/common/operations/Transpose.cpp @@ -74,12 +74,13 @@ bool validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); const OperandType inputType = context->getInputType(kInputTensor); + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_P)); + minSupportedVersion = Version::ANDROID_P; } else if (inputType == OperandType::TENSOR_FLOAT16) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; } else { NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName; } @@ -87,8 +88,9 @@ bool validate(const IOperationValidationContext* context) { if (hasKnownRank(input)) { NN_RET_CHECK_LE(getNumberOfDimensions(input), 4); } - return validateInputTypes(context, {inputType, OperandType::TENSOR_INT32}) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, {inputType, OperandType::TENSOR_INT32})); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return validateVersion(context, minSupportedVersion); } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/TransposeConv2D.cpp b/nn/common/operations/TransposeConv2D.cpp index 78d857a35..9d6dbbbfc 100644 --- a/nn/common/operations/TransposeConv2D.cpp +++ b/nn/common/operations/TransposeConv2D.cpp @@ -474,9 +474,9 @@ bool validate(const IOperationValidationContext* context) { OperandType::INT32, OperandType::INT32, OperandType::BOOL}; } inExpectedTypes.insert(inExpectedTypes.end(), argExpectedTypes.begin(), argExpectedTypes.end()); - NN_RET_CHECK(validateVersion(context, minSupportedVersion)); - return validateInputTypes(context, inExpectedTypes) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return validateVersion(context, minSupportedVersion); } bool prepare(IOperationExecutionContext* context) { |