summaryrefslogtreecommitdiff
path: root/nn/common/operations/Activation.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'nn/common/operations/Activation.cpp')
-rw-r--r--nn/common/operations/Activation.cpp26
1 files changed, 16 insertions, 10 deletions
diff --git a/nn/common/operations/Activation.cpp b/nn/common/operations/Activation.cpp
index e3d848799..651cd020e 100644
--- a/nn/common/operations/Activation.cpp
+++ b/nn/common/operations/Activation.cpp
@@ -353,22 +353,23 @@ bool hardSwishQuant(const T* inputData, const Shape& inputShape, T* outputData,
} // namespace
-bool validate(OperationType opType, const IOperationValidationContext* context) {
+Result<Version> validate(OperationType opType, const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
auto inputType = context->getInputType(kInputTensor);
+ auto minSupportedVersion = Version::ANDROID_OC_MR1;
if (inputType == OperandType::TENSOR_FLOAT32) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1));
+ minSupportedVersion = Version::ANDROID_OC_MR1;
} else if (inputType == OperandType::TENSOR_FLOAT16) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = Version::ANDROID_Q;
} else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
if (opType == OperationType::TANH) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = Version::ANDROID_Q;
} else {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1));
+ minSupportedVersion = Version::ANDROID_OC_MR1;
}
} else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_R));
+ minSupportedVersion = Version::ANDROID_R;
} else {
NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << opType;
}
@@ -376,21 +377,26 @@ bool validate(OperationType opType, const IOperationValidationContext* context)
if (hasKnownRank(input)) {
NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
}
- return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType});
+ NN_RET_CHECK(validateInputTypes(context, {inputType}));
+ NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+ return minSupportedVersion;
}
-bool validateHardSwish(const IOperationValidationContext* context) {
+Result<Version> validateHardSwish(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
auto inputType = context->getInputType(kInputTensor);
+ auto minSupportedVersion = Version::ANDROID_OC_MR1;
if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32 ||
inputType == OperandType::TENSOR_QUANT8_ASYMM ||
inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_R));
+ minSupportedVersion = Version::ANDROID_R;
} else {
NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ELU";
}
- return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType});
+ NN_RET_CHECK(validateInputTypes(context, {inputType}));
+ NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+ return minSupportedVersion;
}
bool prepare(OperationType opType, IOperationExecutionContext* context) {