diff options
author | Michael Butler <butlermichael@google.com> | 2020-11-10 23:06:13 +0000 |
---|---|---|
committer | Gerrit Code Review <noreply-gerritcodereview@google.com> | 2020-11-10 23:06:13 +0000 |
commit | 979afbaee92a97d0ec9f0d13436ef7670a7e9538 (patch) | |
tree | e96b673d1baeb8853e54e1812cfb57a0036a4231 | |
parent | f1c452cda6533807bcab2337cb13d5184405505c (diff) | |
parent | ca7a45a1234977c93fdfb578b64114d13ee27b7f (diff) | |
download | ml-979afbaee92a97d0ec9f0d13436ef7670a7e9538.tar.gz |
Merge changes I47c12e13,Iae0d6ef3
* changes:
Make operation validation return Result<Version>
Reorganize operation validation version code
48 files changed, 265 insertions, 273 deletions
diff --git a/nn/common/OperationsUtils.cpp b/nn/common/OperationsUtils.cpp index d65566f96..c5a71e981 100644 --- a/nn/common/OperationsUtils.cpp +++ b/nn/common/OperationsUtils.cpp @@ -86,8 +86,9 @@ bool validateOutputTypes(const IOperationValidationContext* context, [context](uint32_t index) { return context->getOutputType(index); }); } -bool validateVersion(const IOperationValidationContext* context, Version minSupportedVersion) { - if (context->getVersion() < minSupportedVersion) { +bool validateVersion(const IOperationValidationContext* context, Version contextVersion, + Version minSupportedVersion) { + if (contextVersion < minSupportedVersion) { std::ostringstream message; message << "Operation " << context->getOperationName() << " with inputs {"; for (uint32_t i = 0, n = context->getNumInputs(); i < n; ++i) { @@ -104,7 +105,7 @@ bool validateVersion(const IOperationValidationContext* context, Version minSupp message << context->getOutputType(i); } message << "} is only supported since " << minSupportedVersion << " (validating using " - << context->getVersion() << ")"; + << contextVersion << ")"; NN_RET_CHECK_FAIL() << message.str(); } return true; diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp index da4dbc87f..7417ed8bf 100644 --- a/nn/common/Utils.cpp +++ b/nn/common/Utils.cpp @@ -215,18 +215,15 @@ class OperationValidationContext : public IOperationValidationContext { public: OperationValidationContext(const char* operationName, uint32_t inputCount, const uint32_t* inputIndexes, uint32_t outputCount, - const uint32_t* outputIndexes, const Operand* operands, - HalVersion halVersion) + const uint32_t* outputIndexes, const Operand* operands) : operationName(operationName), inputCount(inputCount), inputIndexes(inputIndexes), outputCount(outputCount), outputIndexes(outputIndexes), - operands(operands), - version(convert(halVersion)) {} + operands(operands) {} const char* getOperationName() const override; - Version getVersion() const override; uint32_t getNumInputs() const override; OperandType getInputType(uint32_t index) const override; @@ -254,10 +251,6 @@ const char* OperationValidationContext::getOperationName() const { return operationName; } -Version OperationValidationContext::getVersion() const { - return version; -} - const Operand* OperationValidationContext::getInputOperand(uint32_t index) const { CHECK(index < static_cast<uint32_t>(inputCount)); return &operands[inputIndexes[index]]; @@ -1883,8 +1876,14 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, } OperationValidationContext context(operationRegistration->name, inputCount, inputIndexes, outputCount, outputIndexes, - operands.data(), halVersion); - if (!operationRegistration->validate(&context)) { + operands.data()); + const auto maybeVersion = operationRegistration->validate(&context); + if (!maybeVersion.has_value()) { + LOG(ERROR) << "Validation failed for operation " << opType << ": " + << maybeVersion.error(); + return ANEURALNETWORKS_BAD_DATA; + } + if (!validateVersion(&context, convert(halVersion), maybeVersion.value())) { LOG(ERROR) << "Validation failed for operation " << opType; return ANEURALNETWORKS_BAD_DATA; } diff --git a/nn/common/Validation.cpp b/nn/common/Validation.cpp index e49f073f2..d37c447c0 100644 --- a/nn/common/Validation.cpp +++ b/nn/common/Validation.cpp @@ -1181,15 +1181,13 @@ class OperationValidationContext : public IOperationValidationContext { public: OperationValidationContext(const char* operationName, const std::vector<uint32_t>& inputIndexes, const std::vector<uint32_t>& outputIndexes, - const std::vector<Operand>& operands, Version version) + const std::vector<Operand>& operands) : operationName(operationName), inputIndexes(inputIndexes), outputIndexes(outputIndexes), - operands(operands), - version(version) {} + operands(operands) {} const char* getOperationName() const override; - Version getVersion() const override; uint32_t getNumInputs() const override; OperandType getInputType(uint32_t index) const override; @@ -1208,17 +1206,12 @@ class OperationValidationContext : public IOperationValidationContext { const std::vector<uint32_t>& inputIndexes; const std::vector<uint32_t>& outputIndexes; const std::vector<Operand>& operands; - Version version; }; const char* OperationValidationContext::getOperationName() const { return operationName; } -Version OperationValidationContext::getVersion() const { - return version; -} - const Operand* OperationValidationContext::getInputOperand(uint32_t index) const { return &operands.at(inputIndexes.at(index)); } @@ -2521,20 +2514,9 @@ Result<Version> validateOperationImpl(const Operation& operation, NN_VALIDATE(operationRegistration->validate != nullptr) << "Incomplete operation registration: " << opType; - constexpr Version kVersions[] = {Version::ANDROID_OC_MR1, Version::ANDROID_P, - Version::ANDROID_Q, Version::ANDROID_R, - Version::CURRENT_RUNTIME}; - - for (const auto version : kVersions) { - OperationValidationContext context(operationRegistration->name, inputIndexes, - outputIndexes, operands, version); - auto valid = operationRegistration->validate(&context); - if (valid) { - return version; - } - } - - return NN_ERROR() << "Validation failed for operation " << opType; + OperationValidationContext context(operationRegistration->name, inputIndexes, + outputIndexes, operands); + return operationRegistration->validate(&context); } } } diff --git a/nn/common/include/OperationResolver.h b/nn/common/include/OperationResolver.h index d2c066cd3..155341a1a 100644 --- a/nn/common/include/OperationResolver.h +++ b/nn/common/include/OperationResolver.h @@ -32,7 +32,7 @@ struct OperationRegistration { const char* name; // Validates operand types, shapes, and any values known during graph creation. - std::function<bool(const IOperationValidationContext*)> validate; + std::function<Result<Version>(const IOperationValidationContext*)> validate; // prepare is called when the inputs this operation depends on have been // computed. Typically, prepare does any remaining validation and sets @@ -50,10 +50,11 @@ struct OperationRegistration { bool allowZeroSizedInput = false; } flags; - OperationRegistration(OperationType type, const char* name, - std::function<bool(const IOperationValidationContext*)> validate, - std::function<bool(IOperationExecutionContext*)> prepare, - std::function<bool(IOperationExecutionContext*)> execute, Flag flags) + OperationRegistration( + OperationType type, const char* name, + std::function<Result<Version>(const IOperationValidationContext*)> validate, + std::function<bool(IOperationExecutionContext*)> prepare, + std::function<bool(IOperationExecutionContext*)> execute, Flag flags) : type(type), name(name), validate(std::move(validate)), diff --git a/nn/common/include/OperationsUtils.h b/nn/common/include/OperationsUtils.h index 676bbb34a..9123139b1 100644 --- a/nn/common/include/OperationsUtils.h +++ b/nn/common/include/OperationsUtils.h @@ -59,19 +59,6 @@ class IOperationValidationContext { virtual const char* getOperationName() const = 0; - // The version of the environment in which the operation is to be executed. - // - // Operation validation logic needs to handle all versions to support the following use cases - // (assume in these examples that the latest version is Version::ANDROID_Q): - // 1. Our runtime wants to distribute work to a driver implementing an older version and calls, - // for example, compliantWithV1_0(const V1_2::Model&). - // 2. A driver implements an older version and delegates model validation to, for example, - // validateModel(const V1_0::Model&). - // - // If getVersion() returns Version::ANDROID_OC_MR1 and the operation is only supported since - // Version::ANDROID_P, validation will fail. - virtual Version getVersion() const = 0; - virtual uint32_t getNumInputs() const = 0; virtual OperandType getInputType(uint32_t index) const = 0; virtual Shape getInputShape(uint32_t index) const = 0; @@ -130,7 +117,8 @@ bool validateOutputTypes(const IOperationValidationContext* context, // Verifies that the HAL version specified in the context is greater or equal // than the minimal supported HAL version. -bool validateVersion(const IOperationValidationContext* context, Version minSupportedVersion); +bool validateVersion(const IOperationValidationContext* context, Version contextVersion, + Version minSupportedVersion); // Verifies that the two shapes are the same. bool SameShape(const Shape& in1, const Shape& in2); diff --git a/nn/common/include/nnapi/TypeUtils.h b/nn/common/include/nnapi/TypeUtils.h index 9dc67cf46..56b62f9b2 100644 --- a/nn/common/include/nnapi/TypeUtils.h +++ b/nn/common/include/nnapi/TypeUtils.h @@ -198,6 +198,8 @@ class FalseyErrorStream { operator bool() const { return false; } + operator Result<Version>() const { return error() << mBuffer.str(); } + private: std::ostringstream mBuffer; }; diff --git a/nn/common/operations/Activation.cpp b/nn/common/operations/Activation.cpp index e3d848799..651cd020e 100644 --- a/nn/common/operations/Activation.cpp +++ b/nn/common/operations/Activation.cpp @@ -353,22 +353,23 @@ bool hardSwishQuant(const T* inputData, const Shape& inputShape, T* outputData, } // namespace -bool validate(OperationType opType, const IOperationValidationContext* context) { +Result<Version> validate(OperationType opType, const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); auto inputType = context->getInputType(kInputTensor); + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT32) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; } else if (inputType == OperandType::TENSOR_FLOAT16) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { if (opType == OperationType::TANH) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } else { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; } } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; } else { NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << opType; } @@ -376,21 +377,26 @@ bool validate(OperationType opType, const IOperationValidationContext* context) if (hasKnownRank(input)) { NN_RET_CHECK_LE(getNumberOfDimensions(input), 4); } - return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, {inputType})); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return minSupportedVersion; } -bool validateHardSwish(const IOperationValidationContext* context) { +Result<Version> validateHardSwish(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); auto inputType = context->getInputType(kInputTensor); + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM || inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; } else { NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ELU"; } - return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, {inputType})); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return minSupportedVersion; } bool prepare(OperationType opType, IOperationExecutionContext* context) { diff --git a/nn/common/operations/BidirectionalSequenceRNN.cpp b/nn/common/operations/BidirectionalSequenceRNN.cpp index f6b4c301c..5a020d1f7 100644 --- a/nn/common/operations/BidirectionalSequenceRNN.cpp +++ b/nn/common/operations/BidirectionalSequenceRNN.cpp @@ -313,7 +313,7 @@ bool executeTyped(IOperationExecutionContext* context) { } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); // Exact number is dependent on the mergeOutputs parameter and checked // during preparation. @@ -323,9 +323,8 @@ bool validate(const IOperationValidationContext* context) { OperandType inputType = context->getInputType(kInputTensor); if (inputType != OperandType::TENSOR_FLOAT16 && inputType != OperandType::TENSOR_FLOAT32) { - LOG(ERROR) << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_RNN op: " - << inputType; - return false; + return NN_ERROR() << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_RNN op: " + << inputType; } NN_RET_CHECK(validateInputTypes( context, {inputType, inputType, inputType, inputType, inputType, inputType, inputType, @@ -339,7 +338,7 @@ bool validate(const IOperationValidationContext* context) { if (numOutputs == kNumOutputsWithState || numOutputs == kNumOutputsMergedWithState) { minSupportedVersion = Version::ANDROID_R; } - return validateVersion(context, minSupportedVersion); + return minSupportedVersion; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/Broadcast.cpp b/nn/common/operations/Broadcast.cpp index ce1320fb5..a2d5b8a39 100644 --- a/nn/common/operations/Broadcast.cpp +++ b/nn/common/operations/Broadcast.cpp @@ -33,6 +33,7 @@ #include "OperationResolver.h" #include "Tracing.h" #include "nnapi/Types.h" +#include "nnapi/Validation.h" namespace android { namespace nn { @@ -433,20 +434,20 @@ bool divFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, c } // namespace -bool validate(OperationType opType, const IOperationValidationContext* context) { - const Version opIntroducedAt = (opType == OperationType::DIV || opType == OperationType::SUB) - ? Version::ANDROID_P - : Version::ANDROID_OC_MR1; +Result<Version> validate(OperationType opType, const IOperationValidationContext* context) { + auto minSupportedVersion = (opType == OperationType::DIV || opType == OperationType::SUB) + ? Version::ANDROID_P + : Version::ANDROID_OC_MR1; NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); auto inputType = context->getInputType(kInputTensor1); if (inputType == OperandType::TENSOR_FLOAT32) { - NN_RET_CHECK(validateVersion(context, std::max(Version::ANDROID_OC_MR1, opIntroducedAt))); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_OC_MR1); } else if (inputType == OperandType::TENSOR_FLOAT16) { - NN_RET_CHECK(validateVersion(context, std::max(Version::ANDROID_Q, opIntroducedAt))); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q); } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { if (opType == OperationType::SUB) { - NN_RET_CHECK(validateVersion(context, std::max(Version::ANDROID_Q, opIntroducedAt))); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q); } else if (opType == OperationType::DIV) { NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation DIV"; } else if (opType == OperationType::MUL) { @@ -454,15 +455,13 @@ bool validate(OperationType opType, const IOperationValidationContext* context) Shape input1 = context->getInputShape(kInputTensor1); Shape input2 = context->getInputShape(kInputTensor2); NN_RET_CHECK_GT(output.scale, input1.scale * input2.scale); - NN_RET_CHECK( - validateVersion(context, std::max(Version::ANDROID_OC_MR1, opIntroducedAt))); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_OC_MR1); } else { - NN_RET_CHECK( - validateVersion(context, std::max(Version::ANDROID_OC_MR1, opIntroducedAt))); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_OC_MR1); } } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED || inputType == OperandType::TENSOR_INT32) { - NN_RET_CHECK(validateVersion(context, std::max(Version::ANDROID_R, opIntroducedAt))); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_R); } else { NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << opType; } @@ -472,8 +471,9 @@ bool validate(OperationType opType, const IOperationValidationContext* context) NN_RET_CHECK_LE(getNumberOfDimensions(input1), 4); NN_RET_CHECK_LE(getNumberOfDimensions(input2), 4); } - return validateInputTypes(context, {inputType, inputType, OperandType::INT32}) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, {inputType, inputType, OperandType::INT32})); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return minSupportedVersion; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/ChannelShuffle.cpp b/nn/common/operations/ChannelShuffle.cpp index 59726fac2..efa08737b 100644 --- a/nn/common/operations/ChannelShuffle.cpp +++ b/nn/common/operations/ChannelShuffle.cpp @@ -57,7 +57,7 @@ inline bool eval(const T* inputData, const Shape& inputShape, int32_t numGroups, return true; } -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); auto inputType = context->getInputType(kInputTensor); @@ -73,9 +73,9 @@ bool validate(const IOperationValidationContext* context) { NN_RET_CHECK(validateInputTypes(context, {inputType, OperandType::INT32, OperandType::INT32})); NN_RET_CHECK(validateOutputTypes(context, {inputType})); if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - return validateVersion(context, Version::ANDROID_R); + return Version::ANDROID_R; } else { - return validateVersion(context, Version::ANDROID_Q); + return Version::ANDROID_Q; } } diff --git a/nn/common/operations/Comparisons.cpp b/nn/common/operations/Comparisons.cpp index 8fdf72c59..b490c9218 100644 --- a/nn/common/operations/Comparisons.cpp +++ b/nn/common/operations/Comparisons.cpp @@ -123,7 +123,7 @@ bool executeGreaterTyped(IOperationExecutionContext* context) { } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); OperandType inputType = context->getInputType(kInputTensor1); @@ -136,9 +136,9 @@ bool validate(const IOperationValidationContext* context) { NN_RET_CHECK(validateInputTypes(context, {inputType, inputType})); NN_RET_CHECK(validateOutputTypes(context, {OperandType::TENSOR_BOOL8})); if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - return validateVersion(context, Version::ANDROID_R); + return Version::ANDROID_R; } else { - return validateVersion(context, Version::ANDROID_Q); + return Version::ANDROID_Q; } } diff --git a/nn/common/operations/Concatenation.cpp b/nn/common/operations/Concatenation.cpp index cadfd0f65..6b9007e5e 100644 --- a/nn/common/operations/Concatenation.cpp +++ b/nn/common/operations/Concatenation.cpp @@ -29,6 +29,7 @@ #include "CpuOperationUtils.h" #include "OperationResolver.h" #include "Tracing.h" +#include "nnapi/Validation.h" namespace android { namespace nn { @@ -135,29 +136,30 @@ inline bool concatenation<int8_t>(IOperationExecutionContext* context) { } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { uint32_t inputCount = context->getNumInputs(); NN_RET_CHECK_GE(inputCount, 2); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); const OperandType inputType = context->getInputType(0); + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; } else if (inputType == OperandType::TENSOR_FLOAT16) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; } else { NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName; } std::vector<OperandType> inExpectedTypes(inputCount - 1, inputType); inExpectedTypes.push_back(OperandType::INT32); - if (context->getVersion() < Version::ANDROID_Q && - inputType == OperandType::TENSOR_QUANT8_ASYMM) { + if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { const Shape& output = context->getOutputShape(kOutputTensor); for (uint32_t i = 0; i < inputCount - 1; ++i) { const Shape& input = context->getInputShape(i); - NN_RET_CHECK_EQ(input.scale, output.scale); - NN_RET_CHECK_EQ(input.offset, output.offset); + if (input.scale != output.scale || input.offset != output.offset) { + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q); + } } } for (uint32_t i = 0; i < inputCount - 1; ++i) { @@ -166,8 +168,9 @@ bool validate(const IOperationValidationContext* context) { NN_RET_CHECK_LE(inputRank, 4); } } - return validateInputTypes(context, inExpectedTypes) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return minSupportedVersion; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/Conv2D.cpp b/nn/common/operations/Conv2D.cpp index 5a5e33764..6d989827e 100644 --- a/nn/common/operations/Conv2D.cpp +++ b/nn/common/operations/Conv2D.cpp @@ -526,7 +526,7 @@ bool convQuant8PerChannel(const T* inputData, const Shape& inputShape, const int } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { const uint32_t numInputs = context->getNumInputs(); NN_RET_CHECK( std::binary_search(std::begin(kNumInputsArray), std::end(kNumInputsArray), numInputs)); @@ -612,17 +612,19 @@ bool validate(const IOperationValidationContext* context) { } } + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; } else if (inputType == OperandType::TENSOR_FLOAT16 || filterType == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL || withLayout || withDilation || !meetsQuantizedScaleConstraintBeforeV1_2) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } else { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; } - return validateInputTypes(context, inExpectedTypes) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return minSupportedVersion; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/DepthwiseConv2D.cpp b/nn/common/operations/DepthwiseConv2D.cpp index 611e38d63..64bd7dd4a 100644 --- a/nn/common/operations/DepthwiseConv2D.cpp +++ b/nn/common/operations/DepthwiseConv2D.cpp @@ -413,7 +413,7 @@ bool depthwiseConvQuant8PerChannel(const T* inputData, const Shape& inputShape, } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { const uint32_t numInputs = context->getNumInputs(); NN_RET_CHECK( std::binary_search(std::begin(kNumInputsArray), std::end(kNumInputsArray), numInputs)); @@ -495,17 +495,19 @@ bool validate(const IOperationValidationContext* context) { } } + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; } else if (inputType == OperandType::TENSOR_FLOAT16 || filterType == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL || withLayout || withDilation || !meetsQuantizedScaleConstraintBeforeV1_2) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } else { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; } - return validateInputTypes(context, inExpectedTypes) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return minSupportedVersion; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/Dequantize.cpp b/nn/common/operations/Dequantize.cpp index f155eb286..b648ff135 100644 --- a/nn/common/operations/Dequantize.cpp +++ b/nn/common/operations/Dequantize.cpp @@ -75,7 +75,7 @@ bool computePerChannel(const int8_t* inputData, const Shape& inputShape, OutputT } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); @@ -89,7 +89,7 @@ bool validate(const IOperationValidationContext* context) { if (inputType == OperandType::TENSOR_QUANT8_ASYMM && outputType == OperandType::TENSOR_FLOAT32) { - return validateVersion(context, Version::ANDROID_OC_MR1); + return Version::ANDROID_OC_MR1; } NN_RET_CHECK(inputType == OperandType::TENSOR_QUANT8_ASYMM || @@ -100,7 +100,7 @@ bool validate(const IOperationValidationContext* context) { NN_RET_CHECK(outputType == OperandType::TENSOR_FLOAT16 || outputType == OperandType::TENSOR_FLOAT32) << "Unsupported output operand type for DEQUANTIZE op: " << outputType; - return validateVersion(context, Version::ANDROID_Q); + return Version::ANDROID_Q; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/Elementwise.cpp b/nn/common/operations/Elementwise.cpp index a0cd78ffe..851000392 100644 --- a/nn/common/operations/Elementwise.cpp +++ b/nn/common/operations/Elementwise.cpp @@ -82,7 +82,7 @@ bool executeAbs(IOperationExecutionContext* context) { } } -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); OperandType inputType = context->getInputType(kInputTensor); @@ -91,10 +91,10 @@ bool validate(const IOperationValidationContext* context) { << "Unsupported tensor type for elementwise operation"; NN_RET_CHECK(validateInputTypes(context, {inputType})); NN_RET_CHECK(validateOutputTypes(context, {inputType})); - return validateVersion(context, Version::ANDROID_Q); + return Version::ANDROID_Q; } -bool validateAbs(const IOperationValidationContext* context) { +Result<Version> validateAbs(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); OperandType inputType = context->getInputType(kInputTensor); @@ -103,11 +103,10 @@ bool validateAbs(const IOperationValidationContext* context) { << "Unsupported tensor type for operation ABS"; NN_RET_CHECK(validateInputTypes(context, {inputType})); NN_RET_CHECK(validateOutputTypes(context, {inputType})); - return validateVersion(context, (inputType == OperandType::TENSOR_INT32 ? Version::ANDROID_R - : Version::ANDROID_Q)); + return inputType == OperandType::TENSOR_INT32 ? Version::ANDROID_R : Version::ANDROID_Q; } -bool validateFloor(const IOperationValidationContext* context) { +Result<Version> validateFloor(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); @@ -123,9 +122,7 @@ bool validateFloor(const IOperationValidationContext* context) { NN_RET_CHECK_LE(getNumberOfDimensions(input), 4); } - return validateVersion(context, - (inputType == OperandType::TENSOR_FLOAT16 ? Version::ANDROID_Q - : Version::ANDROID_OC_MR1)); + return inputType == OperandType::TENSOR_FLOAT16 ? Version::ANDROID_Q : Version::ANDROID_OC_MR1; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/Elu.cpp b/nn/common/operations/Elu.cpp index 0c72cb383..98e066210 100644 --- a/nn/common/operations/Elu.cpp +++ b/nn/common/operations/Elu.cpp @@ -52,19 +52,21 @@ bool eluFloat(const T* inputData, const Shape& inputShape, const T alpha, T* out } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); auto inputType = context->getInputType(kInputTensor); + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; } else { NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ELU"; } auto scalarType = inputType == OperandType::TENSOR_FLOAT16 ? OperandType::FLOAT16 : OperandType::FLOAT32; - return validateInputTypes(context, {inputType, scalarType}) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, {inputType, scalarType})); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return minSupportedVersion; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/Fill.cpp b/nn/common/operations/Fill.cpp index f3b470ed5..9af64f7de 100644 --- a/nn/common/operations/Fill.cpp +++ b/nn/common/operations/Fill.cpp @@ -61,7 +61,7 @@ bool getValueType(OperandType outputType, OperandType* valueType) { } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); // Check output type first because input value type is dependent on the @@ -77,7 +77,7 @@ bool validate(const IOperationValidationContext* context) { NN_RET_CHECK(getValueType(outputType, &valueType)); NN_RET_CHECK(validateInputTypes(context, {OperandType::TENSOR_INT32, valueType})); - return validateVersion(context, Version::ANDROID_R); + return Version::ANDROID_R; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/FullyConnected.cpp b/nn/common/operations/FullyConnected.cpp index 7c8c4e304..ab50d31c1 100644 --- a/nn/common/operations/FullyConnected.cpp +++ b/nn/common/operations/FullyConnected.cpp @@ -217,14 +217,15 @@ bool validateShapes(const Shape& input, const Shape& weights, const Shape& bias, } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); auto inputType = context->getInputType(kInputTensor); std::vector<OperandType> inExpectedTypes; std::vector<OperandType> outExpectedTypes; + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT32) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; inExpectedTypes = { OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, @@ -232,7 +233,7 @@ bool validate(const IOperationValidationContext* context) { OperandType::INT32, }; } else if (inputType == OperandType::TENSOR_FLOAT16) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; inExpectedTypes = { OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, @@ -249,9 +250,9 @@ bool validate(const IOperationValidationContext* context) { bool meetsQuantizedScaleConstraintBeforeV1_2 = (outputScale > inputScale * weightsScale); if (!meetsQuantizedScaleConstraintBeforeV1_2) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } else { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; } inExpectedTypes = { @@ -261,7 +262,7 @@ bool validate(const IOperationValidationContext* context) { OperandType::INT32, }; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; inExpectedTypes = { OperandType::TENSOR_QUANT8_ASYMM_SIGNED, @@ -271,7 +272,6 @@ bool validate(const IOperationValidationContext* context) { }; } else { NN_RET_CHECK_FAIL() << "Unsupported input tensor type for operation " << kOperationName; - return false; } NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); NN_RET_CHECK(validateOutputTypes(context, {inputType})); @@ -283,7 +283,7 @@ bool validate(const IOperationValidationContext* context) { NN_RET_CHECK(validateShapes(input, weights, bias)); } - return true; + return minSupportedVersion; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/Gather.cpp b/nn/common/operations/Gather.cpp index 6707b6d94..5571a6501 100644 --- a/nn/common/operations/Gather.cpp +++ b/nn/common/operations/Gather.cpp @@ -59,7 +59,7 @@ inline bool eval(const T* inputData, const Shape& inputShape, int32_t axis, } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); OperandType inputType = context->getInputType(kInputTensor); @@ -73,9 +73,9 @@ bool validate(const IOperationValidationContext* context) { {inputType, OperandType::INT32, OperandType::TENSOR_INT32})); NN_RET_CHECK(validateOutputTypes(context, {inputType})); if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - return validateVersion(context, Version::ANDROID_R); + return Version::ANDROID_R; } else { - return validateVersion(context, Version::ANDROID_Q); + return Version::ANDROID_Q; } } diff --git a/nn/common/operations/GenerateProposals.cpp b/nn/common/operations/GenerateProposals.cpp index edd7cb0db..95e3676e0 100644 --- a/nn/common/operations/GenerateProposals.cpp +++ b/nn/common/operations/GenerateProposals.cpp @@ -197,7 +197,7 @@ constexpr uint32_t kImageInfoTensor = 3; constexpr uint32_t kNumOutputs = 1; constexpr uint32_t kOutputTensor = 0; -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); std::vector<OperandType> inExpectedTypes; @@ -211,16 +211,14 @@ bool validate(const IOperationValidationContext* context) { inExpectedTypes = {OperandType::TENSOR_QUANT16_ASYMM, deltaInputType, OperandType::TENSOR_INT32, OperandType::TENSOR_QUANT16_ASYMM}; } else { - LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName; - return false; + return NN_ERROR() << "Unsupported input tensor type for operation " << kOperationName; } } else { - LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName; - return false; + return NN_ERROR() << "Unsupported input tensor type for operation " << kOperationName; } NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); NN_RET_CHECK(validateOutputTypes(context, {inputType})); - return validateVersion(context, Version::ANDROID_Q); + return Version::ANDROID_Q; } bool prepare(IOperationExecutionContext* context) { @@ -703,7 +701,7 @@ bool boxWithNmsLimitQuant(const int8_t* scoresData, const Shape& scoresShape, } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); std::vector<OperandType> inExpectedTypes; @@ -742,9 +740,9 @@ bool validate(const IOperationValidationContext* context) { NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes)); if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - return validateVersion(context, Version::ANDROID_R); + return Version::ANDROID_R; } else { - return validateVersion(context, Version::ANDROID_Q); + return Version::ANDROID_Q; } } @@ -1213,7 +1211,7 @@ bool generateProposalsQuant(const T_8QInput* scoresData, const Shape& scoresShap } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); std::vector<OperandType> inExpectedTypes; @@ -1268,9 +1266,9 @@ bool validate(const IOperationValidationContext* context) { NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes)); if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - return validateVersion(context, Version::ANDROID_R); + return Version::ANDROID_R; } else { - return validateVersion(context, Version::ANDROID_Q); + return Version::ANDROID_Q; } } @@ -1569,7 +1567,7 @@ bool detectionPostprocessFloat16( } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); std::vector<OperandType> inExpectedTypes; @@ -1597,7 +1595,7 @@ bool validate(const IOperationValidationContext* context) { NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); NN_RET_CHECK(validateOutputTypes( context, {inputType, inputType, OperandType::TENSOR_INT32, OperandType::TENSOR_INT32})); - return validateVersion(context, Version::ANDROID_Q); + return Version::ANDROID_Q; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/HeatmapMaxKeypoint.cpp b/nn/common/operations/HeatmapMaxKeypoint.cpp index 1da7ed07f..63fc5973b 100644 --- a/nn/common/operations/HeatmapMaxKeypoint.cpp +++ b/nn/common/operations/HeatmapMaxKeypoint.cpp @@ -224,7 +224,7 @@ inline bool heatmapMaxKeypointQuant(const int8_t* heatmap, const Shape& heatmapS } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); std::vector<OperandType> inExpectedTypes; @@ -245,12 +245,11 @@ bool validate(const IOperationValidationContext* context) { OperandType::TENSOR_QUANT16_ASYMM}; minSupportedVersion = Version::ANDROID_R; } else { - LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName; - return false; + return NN_ERROR() << "Unsupported input tensor type for operation " << kOperationName; } NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes)); - return validateVersion(context, minSupportedVersion); + return minSupportedVersion; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/InstanceNormalization.cpp b/nn/common/operations/InstanceNormalization.cpp index 62b7728f8..1a0e488e9 100644 --- a/nn/common/operations/InstanceNormalization.cpp +++ b/nn/common/operations/InstanceNormalization.cpp @@ -99,7 +99,7 @@ inline bool instanceNorm(const T* inputData, const Shape& inputShape, T gamma, T } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); std::vector<OperandType> inExpectedTypes; @@ -111,12 +111,11 @@ bool validate(const IOperationValidationContext* context) { inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::FLOAT16, OperandType::FLOAT16, OperandType::FLOAT16, OperandType::BOOL}; } else { - LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName; - return false; + return NN_ERROR() << "Unsupported input tensor type for operation " << kOperationName; } NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); NN_RET_CHECK(validateOutputTypes(context, {inputType})); - return validateVersion(context, Version::ANDROID_Q); + return Version::ANDROID_Q; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/L2Normalization.cpp b/nn/common/operations/L2Normalization.cpp index 49cc15dda..05682ea3a 100644 --- a/nn/common/operations/L2Normalization.cpp +++ b/nn/common/operations/L2Normalization.cpp @@ -196,34 +196,36 @@ bool l2normQuant8Signed(const int8_t* inputData, const Shape& inputShape, int32_ } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK(context->getNumInputs() == kNumInputs || context->getNumInputs() == kNumInputs - 1); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); const OperandType inputType = context->getInputType(kInputTensor); std::vector<OperandType> inExpectedTypes = {inputType}; + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_QUANT8_ASYMM) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } else if (inputType == OperandType::TENSOR_FLOAT32) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; } else { NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName; } if (context->getNumInputs() == kNumInputs) { inExpectedTypes.push_back(OperandType::INT32); - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } else if (context->getInputShape(kInputTensor).dimensions.size() != 4) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } const Shape& input = context->getInputShape(kInputTensor); if (hasKnownRank(input)) { NN_RET_CHECK_LE(getNumberOfDimensions(input), 4); } - return validateInputTypes(context, inExpectedTypes) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return minSupportedVersion; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/LocalResponseNormalization.cpp b/nn/common/operations/LocalResponseNormalization.cpp index 6276168a7..ed16dec6b 100644 --- a/nn/common/operations/LocalResponseNormalization.cpp +++ b/nn/common/operations/LocalResponseNormalization.cpp @@ -130,7 +130,7 @@ bool executeTyped(IOperationExecutionContext* context) { } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK(context->getNumInputs() == kNumInputs || context->getNumInputs() == kNumInputs - 1); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); @@ -138,15 +138,16 @@ bool validate(const IOperationValidationContext* context) { const OperandType inputType = context->getInputType(kInputTensor); std::vector<OperandType> inExpectedTypes; std::vector<OperandType> outExpectedTypes; + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT32) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; inExpectedTypes = { OperandType::TENSOR_FLOAT32, OperandType::INT32, OperandType::FLOAT32, OperandType::FLOAT32, OperandType::FLOAT32, }; outExpectedTypes = {OperandType::TENSOR_FLOAT32}; } else if (inputType == OperandType::TENSOR_FLOAT16) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; inExpectedTypes = { OperandType::TENSOR_FLOAT16, OperandType::INT32, OperandType::FLOAT16, OperandType::FLOAT16, OperandType::FLOAT16, @@ -158,17 +159,18 @@ bool validate(const IOperationValidationContext* context) { if (context->getNumInputs() == kNumInputs) { inExpectedTypes.push_back(OperandType::INT32); - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } else if (context->getInputShape(kInputTensor).dimensions.size() != 4) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } const Shape& input = context->getInputShape(kInputTensor); if (hasKnownRank(input)) { NN_RET_CHECK_LE(getNumberOfDimensions(input), 4); } - return validateInputTypes(context, inExpectedTypes) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return minSupportedVersion; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/LogSoftmax.cpp b/nn/common/operations/LogSoftmax.cpp index 86a882fd5..6fe934a3d 100644 --- a/nn/common/operations/LogSoftmax.cpp +++ b/nn/common/operations/LogSoftmax.cpp @@ -70,7 +70,7 @@ inline bool compute(const T* input, const Shape& shape, T beta, uint32_t axis, T return true; } -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); OperandType inputType = context->getInputType(kInputTensor); @@ -83,12 +83,11 @@ bool validate(const IOperationValidationContext* context) { inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::FLOAT16, OperandType::INT32}; outExpectedTypes = {OperandType::TENSOR_FLOAT16}; } else { - LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName; - return false; + return NN_ERROR() << "Unsupported input tensor type for operation " << kOperationName; } NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes)); - return validateVersion(context, Version::ANDROID_Q); + return Version::ANDROID_Q; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/LogicalAndOr.cpp b/nn/common/operations/LogicalAndOr.cpp index 163aa5432..e1927a59f 100644 --- a/nn/common/operations/LogicalAndOr.cpp +++ b/nn/common/operations/LogicalAndOr.cpp @@ -60,7 +60,7 @@ bool compute(const std::function<bool(bool, bool)>& func, const bool8* aData, co } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); OperandType inputType = context->getInputType(kInputTensor1); @@ -68,7 +68,7 @@ bool validate(const IOperationValidationContext* context) { << "Unsupported tensor type for a logical operation"; NN_RET_CHECK(validateInputTypes(context, {inputType, inputType})); NN_RET_CHECK(validateOutputTypes(context, {inputType})); - return validateVersion(context, Version::ANDROID_Q); + return Version::ANDROID_Q; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/LogicalNot.cpp b/nn/common/operations/LogicalNot.cpp index 2f6bb637b..b93e71b90 100644 --- a/nn/common/operations/LogicalNot.cpp +++ b/nn/common/operations/LogicalNot.cpp @@ -41,7 +41,7 @@ bool compute(const bool8* input, const Shape& shape, bool8* output) { } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); OperandType inputType = context->getInputType(kInputTensor); @@ -49,7 +49,7 @@ bool validate(const IOperationValidationContext* context) { << "Unsupported tensor type for LOGICAL_NOT"; NN_RET_CHECK(validateInputTypes(context, {inputType})); NN_RET_CHECK(validateOutputTypes(context, {inputType})); - return validateVersion(context, Version::ANDROID_Q); + return Version::ANDROID_Q; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/Neg.cpp b/nn/common/operations/Neg.cpp index 1d042fcbe..39b58b9b2 100644 --- a/nn/common/operations/Neg.cpp +++ b/nn/common/operations/Neg.cpp @@ -47,7 +47,7 @@ inline bool compute(const T* input, const Shape& shape, T* output) { } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); OperandType inputType = context->getInputType(kInputTensor); @@ -56,7 +56,7 @@ bool validate(const IOperationValidationContext* context) { << "Unsupported tensor type for operation " << kOperationName; NN_RET_CHECK(validateInputTypes(context, {inputType})); NN_RET_CHECK(validateOutputTypes(context, {inputType})); - return validateVersion(context, Version::ANDROID_Q); + return Version::ANDROID_Q; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/PRelu.cpp b/nn/common/operations/PRelu.cpp index db2f6d416..88e38fcf4 100644 --- a/nn/common/operations/PRelu.cpp +++ b/nn/common/operations/PRelu.cpp @@ -95,7 +95,7 @@ bool evalQuant8(const T* aData, const Shape& aShape, const T* bData, const Shape aData, aShape, bData, bShape, outputData, outputShape); } -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); auto inputType = context->getInputType(kInputTensor); @@ -107,9 +107,9 @@ bool validate(const IOperationValidationContext* context) { NN_RET_CHECK(validateInputTypes(context, {inputType, inputType})); NN_RET_CHECK(validateOutputTypes(context, {inputType})); if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - return validateVersion(context, Version::ANDROID_R); + return Version::ANDROID_R; } else { - return validateVersion(context, Version::ANDROID_Q); + return Version::ANDROID_Q; } } diff --git a/nn/common/operations/Pooling.cpp b/nn/common/operations/Pooling.cpp index bc6571d79..6cd286439 100644 --- a/nn/common/operations/Pooling.cpp +++ b/nn/common/operations/Pooling.cpp @@ -24,6 +24,7 @@ #include "CpuOperationUtils.h" #include "OperationResolver.h" #include "Tracing.h" +#include "nnapi/Validation.h" namespace android { namespace nn { @@ -287,20 +288,21 @@ bool maxPool(const T* inputData, const Shape& inputShape, const PoolingParam& pa } // namespace -bool validate(OperationType opType, const IOperationValidationContext* context) { +Result<Version> validate(OperationType opType, const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); auto inputCount = context->getNumInputs(); NN_RET_CHECK(inputCount == 11 || inputCount == 10 || inputCount == 8 || inputCount == 7); auto inputType = context->getInputType(kInputTensor); std::vector<OperandType> inExpectedTypes; + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT32) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; inExpectedTypes = { inputType, OperandType::INT32, OperandType::INT32, OperandType::INT32, OperandType::INT32, OperandType::INT32, OperandType::INT32, }; } else if (inputType == OperandType::TENSOR_FLOAT16) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; inExpectedTypes = { OperandType::TENSOR_FLOAT16, OperandType::INT32, OperandType::INT32, OperandType::INT32, OperandType::INT32, OperandType::INT32, @@ -308,7 +310,7 @@ bool validate(OperationType opType, const IOperationValidationContext* context) }; } else if (opType != OperationType::L2_POOL_2D && inputType == OperandType::TENSOR_QUANT8_ASYMM) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; inExpectedTypes = { OperandType::TENSOR_QUANT8_ASYMM, OperandType::INT32, @@ -320,7 +322,7 @@ bool validate(OperationType opType, const IOperationValidationContext* context) }; } else if (opType != OperationType::L2_POOL_2D && inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; inExpectedTypes = { OperandType::TENSOR_QUANT8_ASYMM_SIGNED, OperandType::INT32, @@ -341,12 +343,13 @@ bool validate(OperationType opType, const IOperationValidationContext* context) } if (inputCount == 11 || inputCount == 8) { inExpectedTypes.push_back(OperandType::BOOL); - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q); } else { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_OC_MR1); } - return validateInputTypes(context, inExpectedTypes) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return minSupportedVersion; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/QLSTM.cpp b/nn/common/operations/QLSTM.cpp index 0812e6661..e8c4f90f6 100644 --- a/nn/common/operations/QLSTM.cpp +++ b/nn/common/operations/QLSTM.cpp @@ -101,7 +101,7 @@ inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); @@ -149,7 +149,7 @@ bool validate(const IOperationValidationContext* context) { outExpectedTypes.push_back(OperandType::TENSOR_QUANT8_ASYMM_SIGNED); NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes)); - return validateVersion(context, Version::ANDROID_R); + return Version::ANDROID_R; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/Quantize.cpp b/nn/common/operations/Quantize.cpp index c3f4812c2..b9d37b8da 100644 --- a/nn/common/operations/Quantize.cpp +++ b/nn/common/operations/Quantize.cpp @@ -63,7 +63,7 @@ bool quantizeToQuant8Signed(const T* inputData, int8_t* outputData, const Shape& } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); @@ -77,9 +77,9 @@ bool validate(const IOperationValidationContext* context) { outputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) << "Unsupported output operand type for QUANTIZE op: " << outputType; if (outputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - return validateVersion(context, Version::ANDROID_R); + return Version::ANDROID_R; } else { - return validateVersion(context, Version::ANDROID_Q); + return Version::ANDROID_Q; } } diff --git a/nn/common/operations/Rank.cpp b/nn/common/operations/Rank.cpp index 71951d703..f6363417a 100644 --- a/nn/common/operations/Rank.cpp +++ b/nn/common/operations/Rank.cpp @@ -30,7 +30,7 @@ constexpr uint32_t kInputTensor = 0; constexpr uint32_t kNumOutputs = 1; constexpr uint32_t kOutputScalar = 0; -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); OperandType inputType = context->getInputType(kInputTensor); @@ -46,7 +46,7 @@ bool validate(const IOperationValidationContext* context) { inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) << "Incorrect input type for a RANK op: " << inputType; NN_RET_CHECK(validateOutputTypes(context, {OperandType::INT32})); - return validateVersion(context, Version::ANDROID_R); + return Version::ANDROID_R; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/Reduce.cpp b/nn/common/operations/Reduce.cpp index 0563a3536..9eb195648 100644 --- a/nn/common/operations/Reduce.cpp +++ b/nn/common/operations/Reduce.cpp @@ -66,7 +66,7 @@ inline bool compute(IOperationExecutionContext* context, T init, T func(T, T)) { } // namespace -bool validateProdSum(const IOperationValidationContext* context) { +Result<Version> validateProdSum(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); OperandType inputType = context->getInputType(kInputTensor); @@ -80,10 +80,10 @@ bool validateProdSum(const IOperationValidationContext* context) { if (hasKnownRank(input)) { NN_RET_CHECK_LE(getNumberOfDimensions(input), 4); } - return validateVersion(context, Version::ANDROID_Q); + return Version::ANDROID_Q; } -bool validateMaxMin(const IOperationValidationContext* context) { +Result<Version> validateMaxMin(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); OperandType inputType = context->getInputType(kInputTensor); @@ -103,10 +103,10 @@ bool validateMaxMin(const IOperationValidationContext* context) { if (hasKnownRank(input)) { NN_RET_CHECK_LE(getNumberOfDimensions(input), 4); } - return validateVersion(context, minVersion); + return minVersion; } -bool validateLogical(const IOperationValidationContext* context) { +Result<Version> validateLogical(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); OperandType inputType = context->getInputType(kInputTensor); @@ -119,7 +119,7 @@ bool validateLogical(const IOperationValidationContext* context) { if (hasKnownRank(input)) { NN_RET_CHECK_LE(getNumberOfDimensions(input), 4); } - return validateVersion(context, Version::ANDROID_Q); + return Version::ANDROID_Q; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/ResizeImageOps.cpp b/nn/common/operations/ResizeImageOps.cpp index a1acf187d..733bedb69 100644 --- a/nn/common/operations/ResizeImageOps.cpp +++ b/nn/common/operations/ResizeImageOps.cpp @@ -25,6 +25,7 @@ #include "CpuOperationUtils.h" #include "OperationResolver.h" #include "Tracing.h" +#include "nnapi/Validation.h" namespace android { namespace nn { @@ -168,7 +169,7 @@ inline bool getOptionalScalar(const IOperationExecutionContext* context, uint32_ } // namespace -bool validate(OperationType opType, const IOperationValidationContext* context) { +Result<Version> validate(OperationType opType, const IOperationValidationContext* context) { const auto numInputs = context->getNumInputs(); if (opType == OperationType::RESIZE_BILINEAR) { NN_RET_CHECK(numInputs >= kNumInputs - 1 && numInputs <= kNumInputs + kNumOptionalInputs); @@ -181,19 +182,20 @@ bool validate(OperationType opType, const IOperationValidationContext* context) auto inputType = context->getInputType(kInputTensor); auto scalarType = context->getInputType(kOutputHeightParamScalar); std::vector<OperandType> inExpectedTypes = {inputType, scalarType, scalarType}; + auto minSupportedVersion = Version::ANDROID_OC_MR1; NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM || inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) << "Unsupported tensor type for operation " << opType; if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_QUANT8_ASYMM) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q); } if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_R); } if (scalarType != OperandType::INT32) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q); if (inputType == OperandType::TENSOR_FLOAT32) { NN_RET_CHECK(scalarType == OperandType::FLOAT32); } else if (inputType == OperandType::TENSOR_FLOAT16) { @@ -204,18 +206,19 @@ bool validate(OperationType opType, const IOperationValidationContext* context) } } if (numInputs < kNumInputs) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_OC_MR1); } else if (numInputs == kNumInputs) { inExpectedTypes.push_back(OperandType::BOOL); - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q); } else { while (inExpectedTypes.size() < numInputs) { inExpectedTypes.push_back(OperandType::BOOL); } - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_R); } - return validateInputTypes(context, inExpectedTypes) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return minSupportedVersion; } bool prepare(OperationType opType, IOperationExecutionContext* context) { diff --git a/nn/common/operations/RoiAlign.cpp b/nn/common/operations/RoiAlign.cpp index 78049b8bb..3ca64f56a 100644 --- a/nn/common/operations/RoiAlign.cpp +++ b/nn/common/operations/RoiAlign.cpp @@ -337,7 +337,7 @@ inline bool roiAlign(const T_Input* inputData, const Shape& inputShape, const T_ } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); std::vector<OperandType> inExpectedTypes; @@ -367,15 +367,14 @@ bool validate(const IOperationValidationContext* context) { OperandType::INT32, OperandType::BOOL}; } else { - LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName; - return false; + return NN_ERROR() << "Unsupported input tensor type for operation " << kOperationName; } NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); NN_RET_CHECK(validateOutputTypes(context, {inputType})); if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - return validateVersion(context, Version::ANDROID_R); + return Version::ANDROID_R; } else { - return validateVersion(context, Version::ANDROID_Q); + return Version::ANDROID_Q; } } diff --git a/nn/common/operations/RoiPooling.cpp b/nn/common/operations/RoiPooling.cpp index a011b4ae2..26e2213a3 100644 --- a/nn/common/operations/RoiPooling.cpp +++ b/nn/common/operations/RoiPooling.cpp @@ -184,7 +184,7 @@ inline bool roiPooling<int8_t, uint16_t>(const int8_t* inputData, const Shape& i } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); std::vector<OperandType> inExpectedTypes; @@ -210,16 +210,14 @@ bool validate(const IOperationValidationContext* context) { OperandType::FLOAT32, OperandType::BOOL}; } else { - LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName; - return false; + return NN_ERROR() << "Unsupported input tensor type for operation " << kOperationName; } NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); NN_RET_CHECK(validateOutputTypes(context, {inputType})); if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - return validateVersion(context, Version::ANDROID_R); - ; + return Version::ANDROID_R; } else { - return validateVersion(context, Version::ANDROID_Q); + return Version::ANDROID_Q; } } diff --git a/nn/common/operations/Select.cpp b/nn/common/operations/Select.cpp index 0b7728ab9..f037b4810 100644 --- a/nn/common/operations/Select.cpp +++ b/nn/common/operations/Select.cpp @@ -66,7 +66,7 @@ bool executeTyped(IOperationExecutionContext* context) { } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); OperandType inputType = context->getInputType(kInputTensor1); @@ -78,7 +78,7 @@ bool validate(const IOperationValidationContext* context) { << "Unsupported input operand type for select op: " << inputType; NN_RET_CHECK(validateInputTypes(context, {OperandType::TENSOR_BOOL8, inputType, inputType})); NN_RET_CHECK(validateOutputTypes(context, {inputType})); - return validateVersion(context, Version::ANDROID_Q); + return Version::ANDROID_Q; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/Slice.cpp b/nn/common/operations/Slice.cpp index 3cf3c8a33..db47419f7 100644 --- a/nn/common/operations/Slice.cpp +++ b/nn/common/operations/Slice.cpp @@ -78,7 +78,7 @@ bool evalGeneric(const T* inputData, const Shape& inputShape, const int32_t* beg } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); @@ -89,14 +89,16 @@ bool validate(const IOperationValidationContext* context) { inputType == OperandType::TENSOR_QUANT8_ASYMM || inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) << "Unsupported tensor type for operation " << kOperationName; + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; } else { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } - return validateInputTypes(context, - {inputType, OperandType::TENSOR_INT32, OperandType::TENSOR_INT32}) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes( + context, {inputType, OperandType::TENSOR_INT32, OperandType::TENSOR_INT32})); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return minSupportedVersion; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/Softmax.cpp b/nn/common/operations/Softmax.cpp index a9373957a..3e65d85bf 100644 --- a/nn/common/operations/Softmax.cpp +++ b/nn/common/operations/Softmax.cpp @@ -27,6 +27,7 @@ #include "CpuOperationUtils.h" #include "OperationResolver.h" #include "Tracing.h" +#include "nnapi/Validation.h" namespace android { namespace nn { @@ -226,20 +227,21 @@ bool softmaxQuant8(const T* inputData, const Shape& inputShape, const float beta } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK(context->getNumInputs() == kNumInputs || context->getNumInputs() == kNumInputs - 1); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); auto inputType = context->getInputType(kInputTensor); std::vector<OperandType> inExpectedTypes; + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; inExpectedTypes = {inputType, OperandType::FLOAT32}; } else if (inputType == OperandType::TENSOR_FLOAT16) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; inExpectedTypes = {inputType, OperandType::FLOAT16}; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; inExpectedTypes = {inputType, OperandType::FLOAT32}; } else { NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName; @@ -249,15 +251,16 @@ bool validate(const IOperationValidationContext* context) { NN_RET_CHECK_LE(inputRank, 4); } if (context->getNumInputs() == kNumInputs) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q); inExpectedTypes.push_back(OperandType::INT32); } else { if (inputRank != 2 && inputRank != 4 && inputRank != 0) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q); } } - return validateInputTypes(context, inExpectedTypes) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return minSupportedVersion; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/Squeeze.cpp b/nn/common/operations/Squeeze.cpp index e9640b964..2fe8eb8aa 100644 --- a/nn/common/operations/Squeeze.cpp +++ b/nn/common/operations/Squeeze.cpp @@ -35,7 +35,7 @@ constexpr uint32_t kSqueezeDims = 1; constexpr uint32_t kNumOutputs = 1; constexpr uint32_t kOutputTensor = 0; -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); OperandType inputType = context->getInputType(kInputTensor); @@ -63,7 +63,7 @@ bool validate(const IOperationValidationContext* context) { if (hasKnownRank(input)) { NN_RET_CHECK_LE(getNumberOfDimensions(input), 4); } - return validateVersion(context, minSupportedVersion); + return minSupportedVersion; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/StridedSlice.cpp b/nn/common/operations/StridedSlice.cpp index 654659ac1..fd66ca7c4 100644 --- a/nn/common/operations/StridedSlice.cpp +++ b/nn/common/operations/StridedSlice.cpp @@ -96,7 +96,7 @@ bool executeTyped(IOperationExecutionContext* context) { } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); OperandType inputType = context->getInputType(kInputTensor); @@ -129,7 +129,7 @@ bool validate(const IOperationValidationContext* context) { if (hasKnownRank(input)) { NN_RET_CHECK_LE(getNumberOfDimensions(input), 4); } - return validateVersion(context, minSupportedVersion); + return minSupportedVersion; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/TopK_V2.cpp b/nn/common/operations/TopK_V2.cpp index d91c8131e..d19a309b3 100644 --- a/nn/common/operations/TopK_V2.cpp +++ b/nn/common/operations/TopK_V2.cpp @@ -73,7 +73,7 @@ bool executeTyped(IOperationExecutionContext* context) { } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); OperandType inputType = context->getInputType(kInputTensor); @@ -89,7 +89,7 @@ bool validate(const IOperationValidationContext* context) { if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { minSupportedVersion = Version::ANDROID_R; } - return validateVersion(context, minSupportedVersion); + return minSupportedVersion; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/Transpose.cpp b/nn/common/operations/Transpose.cpp index 3bc76f03e..0e61575eb 100644 --- a/nn/common/operations/Transpose.cpp +++ b/nn/common/operations/Transpose.cpp @@ -69,17 +69,18 @@ bool transposeGeneric(const T* inputData, const Shape& inputShape, const int32_t } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); const OperandType inputType = context->getInputType(kInputTensor); + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_P)); + minSupportedVersion = Version::ANDROID_P; } else if (inputType == OperandType::TENSOR_FLOAT16) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; } else { NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName; } @@ -87,8 +88,9 @@ bool validate(const IOperationValidationContext* context) { if (hasKnownRank(input)) { NN_RET_CHECK_LE(getNumberOfDimensions(input), 4); } - return validateInputTypes(context, {inputType, OperandType::TENSOR_INT32}) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, {inputType, OperandType::TENSOR_INT32})); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return minSupportedVersion; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/TransposeConv2D.cpp b/nn/common/operations/TransposeConv2D.cpp index 78d857a35..002df2780 100644 --- a/nn/common/operations/TransposeConv2D.cpp +++ b/nn/common/operations/TransposeConv2D.cpp @@ -433,7 +433,7 @@ bool transposeConvQuant8PerChannel(const T* inputData, const Shape& inputShape, } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { const uint32_t inputCount = context->getNumInputs(); NN_RET_CHECK(inputCount == kNumInputs1 || inputCount == kNumInputs2); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); @@ -474,9 +474,9 @@ bool validate(const IOperationValidationContext* context) { OperandType::INT32, OperandType::INT32, OperandType::BOOL}; } inExpectedTypes.insert(inExpectedTypes.end(), argExpectedTypes.begin(), argExpectedTypes.end()); - NN_RET_CHECK(validateVersion(context, minSupportedVersion)); - return validateInputTypes(context, inExpectedTypes) && - validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return minSupportedVersion; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/UnidirectionalSequenceLSTM.cpp b/nn/common/operations/UnidirectionalSequenceLSTM.cpp index 02da1581f..dc734e8a4 100644 --- a/nn/common/operations/UnidirectionalSequenceLSTM.cpp +++ b/nn/common/operations/UnidirectionalSequenceLSTM.cpp @@ -112,7 +112,7 @@ inline LSTMParams getLSTMParams(IOperationExecutionContext* context) { } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); const uint32_t numOutputs = context->getNumOutputs(); NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsWithState); @@ -163,7 +163,7 @@ bool validate(const IOperationValidationContext* context) { } NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes)); - return validateVersion(context, minVersionSupported); + return minVersionSupported; } bool prepare(IOperationExecutionContext* context) { diff --git a/nn/common/operations/UnidirectionalSequenceRNN.cpp b/nn/common/operations/UnidirectionalSequenceRNN.cpp index 382aa58e3..eaf60edd3 100644 --- a/nn/common/operations/UnidirectionalSequenceRNN.cpp +++ b/nn/common/operations/UnidirectionalSequenceRNN.cpp @@ -126,15 +126,14 @@ bool executeTyped(IOperationExecutionContext* context) { } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); const int numOutputs = context->getNumOutputs(); NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsWithState); OperandType inputType = context->getInputType(kInputTensor); if (inputType != OperandType::TENSOR_FLOAT16 && inputType != OperandType::TENSOR_FLOAT32) { - LOG(ERROR) << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_RNN op: " - << inputType; - return false; + return NN_ERROR() << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_RNN op: " + << inputType; } NN_RET_CHECK(validateInputTypes(context, {inputType, inputType, inputType, inputType, inputType, OperandType::INT32, OperandType::INT32})); @@ -145,7 +144,7 @@ bool validate(const IOperationValidationContext* context) { outputTypes.push_back(inputType); } NN_RET_CHECK(validateOutputTypes(context, outputTypes)); - return validateVersion(context, minVersionSupported); + return minVersionSupported; } bool prepare(IOperationExecutionContext* context) { |