diff options
Diffstat (limited to 'nn/common/operations/FullyConnected.cpp')
-rw-r--r-- | nn/common/operations/FullyConnected.cpp | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/nn/common/operations/FullyConnected.cpp b/nn/common/operations/FullyConnected.cpp index 7c8c4e304..ab50d31c1 100644 --- a/nn/common/operations/FullyConnected.cpp +++ b/nn/common/operations/FullyConnected.cpp @@ -217,14 +217,15 @@ bool validateShapes(const Shape& input, const Shape& weights, const Shape& bias, } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); auto inputType = context->getInputType(kInputTensor); std::vector<OperandType> inExpectedTypes; std::vector<OperandType> outExpectedTypes; + auto minSupportedVersion = Version::ANDROID_OC_MR1; if (inputType == OperandType::TENSOR_FLOAT32) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; inExpectedTypes = { OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, @@ -232,7 +233,7 @@ bool validate(const IOperationValidationContext* context) { OperandType::INT32, }; } else if (inputType == OperandType::TENSOR_FLOAT16) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; inExpectedTypes = { OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, @@ -249,9 +250,9 @@ bool validate(const IOperationValidationContext* context) { bool meetsQuantizedScaleConstraintBeforeV1_2 = (outputScale > inputScale * weightsScale); if (!meetsQuantizedScaleConstraintBeforeV1_2) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q)); + minSupportedVersion = Version::ANDROID_Q; } else { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1)); + minSupportedVersion = Version::ANDROID_OC_MR1; } inExpectedTypes = { @@ -261,7 +262,7 @@ bool validate(const IOperationValidationContext* context) { OperandType::INT32, }; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { - NN_RET_CHECK(validateVersion(context, Version::ANDROID_R)); + minSupportedVersion = Version::ANDROID_R; inExpectedTypes = { OperandType::TENSOR_QUANT8_ASYMM_SIGNED, @@ -271,7 +272,6 @@ bool validate(const IOperationValidationContext* context) { }; } else { NN_RET_CHECK_FAIL() << "Unsupported input tensor type for operation " << kOperationName; - return false; } NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); NN_RET_CHECK(validateOutputTypes(context, {inputType})); @@ -283,7 +283,7 @@ bool validate(const IOperationValidationContext* context) { NN_RET_CHECK(validateShapes(input, weights, bias)); } - return true; + return minSupportedVersion; } bool prepare(IOperationExecutionContext* context) { |