diff options
Diffstat (limited to 'nn/common/operations/Activation.cpp')
-rw-r--r-- | nn/common/operations/Activation.cpp | 26 |
1 files changed, 16 insertions, 10 deletions
diff --git a/nn/common/operations/Activation.cpp b/nn/common/operations/Activation.cpp index e3d848799..651cd020e 100644 --- a/nn/common/operations/Activation.cpp +++ b/nn/common/operations/Activation.cpp @@ -353,22 +353,23 @@ bool hardSwishQuant(const T* inputData, const Shape& inputShape, T* outputData, } // namespace -bool validate(OperationType opType, const IOperationValidationContext* context) { +Result<Version> validate(OperationType opType, const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); auto inputType = context->getInputType(kInputTensor); + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT32) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; } else if (inputType == OperandType::TENSOR_FLOAT16) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { if (opType == OperationType::TANH) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } else { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; } } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; } else { NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << opType; } @@ -376,21 +377,26 @@ bool validate(OperationType opType, const IOperationValidationContext* context) if (hasKnownRank(input)) { NN_RET_CHECK_LE(getNumberOfDimensions(input), 4); } - return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, {inputType})); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return minSupportedVersion; } -bool validateHardSwish(const IOperationValidationContext* context) { +Result<Version> validateHardSwish(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); auto inputType = context->getInputType(kInputTensor); + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM || inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; } else { NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ELU"; } - return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType}); + NN_RET_CHECK(validateInputTypes(context, {inputType})); + NN_RET_CHECK(validateOutputTypes(context, {inputType})); + return minSupportedVersion; } bool prepare(OperationType opType, IOperationExecutionContext* context) { |