summaryrefslogtreecommitdiff
path: root/nn/common/operations/FullyConnected.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'nn/common/operations/FullyConnected.cpp')
-rw-r--r--nn/common/operations/FullyConnected.cpp16
1 files changed, 8 insertions, 8 deletions
diff --git a/nn/common/operations/FullyConnected.cpp b/nn/common/operations/FullyConnected.cpp
index 7c8c4e304..ab50d31c1 100644
--- a/nn/common/operations/FullyConnected.cpp
+++ b/nn/common/operations/FullyConnected.cpp
@@ -217,14 +217,15 @@ bool validateShapes(const Shape& input, const Shape& weights, const Shape& bias,
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
auto inputType = context->getInputType(kInputTensor);
std::vector<OperandType> inExpectedTypes;
std::vector<OperandType> outExpectedTypes;
+ auto minSupportedVersion = Version::ANDROID_OC_MR1;
if (inputType == OperandType::TENSOR_FLOAT32) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1));
+ minSupportedVersion = Version::ANDROID_OC_MR1;
inExpectedTypes = {
OperandType::TENSOR_FLOAT32,
OperandType::TENSOR_FLOAT32,
@@ -232,7 +233,7 @@ bool validate(const IOperationValidationContext* context) {
OperandType::INT32,
};
} else if (inputType == OperandType::TENSOR_FLOAT16) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = Version::ANDROID_Q;
inExpectedTypes = {
OperandType::TENSOR_FLOAT16,
OperandType::TENSOR_FLOAT16,
@@ -249,9 +250,9 @@ bool validate(const IOperationValidationContext* context) {
bool meetsQuantizedScaleConstraintBeforeV1_2 = (outputScale > inputScale * weightsScale);
if (!meetsQuantizedScaleConstraintBeforeV1_2) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_Q));
+ minSupportedVersion = Version::ANDROID_Q;
} else {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_OC_MR1));
+ minSupportedVersion = Version::ANDROID_OC_MR1;
}
inExpectedTypes = {
@@ -261,7 +262,7 @@ bool validate(const IOperationValidationContext* context) {
OperandType::INT32,
};
} else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
- NN_RET_CHECK(validateVersion(context, Version::ANDROID_R));
+ minSupportedVersion = Version::ANDROID_R;
inExpectedTypes = {
OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
@@ -271,7 +272,6 @@ bool validate(const IOperationValidationContext* context) {
};
} else {
NN_RET_CHECK_FAIL() << "Unsupported input tensor type for operation " << kOperationName;
- return false;
}
NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
@@ -283,7 +283,7 @@ bool validate(const IOperationValidationContext* context) {
NN_RET_CHECK(validateShapes(input, weights, bias));
}
- return true;
+ return minSupportedVersion;
}
bool prepare(IOperationExecutionContext* context) {