diff options
Diffstat (limited to 'nn/common/operations/Concatenation.cpp')
-rw-r--r-- | nn/common/operations/Concatenation.cpp | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/nn/common/operations/Concatenation.cpp b/nn/common/operations/Concatenation.cpp index 16a08d6b9..6b9007e5e 100644 --- a/nn/common/operations/Concatenation.cpp +++ b/nn/common/operations/Concatenation.cpp @@ -29,6 +29,7 @@ #include "CpuOperationUtils.h" #include "OperationResolver.h" #include "Tracing.h" +#include "nnapi/Validation.h" namespace android { namespace nn { @@ -135,7 +136,7 @@ inline bool concatenation<int8_t>(IOperationExecutionContext* context) { } // namespace -bool validate(const IOperationValidationContext* context) { +Result<Version> validate(const IOperationValidationContext* context) { uint32_t inputCount = context->getNumInputs(); NN_RET_CHECK_GE(inputCount, 2); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); @@ -152,13 +153,13 @@ bool validate(const IOperationValidationContext* context) { } std::vector<OperandType> inExpectedTypes(inputCount - 1, inputType); inExpectedTypes.push_back(OperandType::INT32); - if (context->getVersion() < Version::ANDROID_Q && - inputType == OperandType::TENSOR_QUANT8_ASYMM) { + if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { const Shape& output = context->getOutputShape(kOutputTensor); for (uint32_t i = 0; i < inputCount - 1; ++i) { const Shape& input = context->getInputShape(i); - NN_RET_CHECK_EQ(input.scale, output.scale); - NN_RET_CHECK_EQ(input.offset, output.offset); + if (input.scale != output.scale || input.offset != output.offset) { + minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q); + } } } for (uint32_t i = 0; i < inputCount - 1; ++i) { @@ -169,7 +170,7 @@ bool validate(const IOperationValidationContext* context) { } NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); NN_RET_CHECK(validateOutputTypes(context, {inputType})); - return validateVersion(context, minSupportedVersion); + return minSupportedVersion; } bool prepare(IOperationExecutionContext* context) { |