summaryrefslogtreecommitdiff
path: root/nn/common/operations/Concatenation.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'nn/common/operations/Concatenation.cpp')
-rw-r--r--nn/common/operations/Concatenation.cpp13
1 files changed, 7 insertions, 6 deletions
diff --git a/nn/common/operations/Concatenation.cpp b/nn/common/operations/Concatenation.cpp
index 16a08d6b9..6b9007e5e 100644
--- a/nn/common/operations/Concatenation.cpp
+++ b/nn/common/operations/Concatenation.cpp
@@ -29,6 +29,7 @@
#include "CpuOperationUtils.h"
#include "OperationResolver.h"
#include "Tracing.h"
+#include "nnapi/Validation.h"
namespace android {
namespace nn {
@@ -135,7 +136,7 @@ inline bool concatenation<int8_t>(IOperationExecutionContext* context) {
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
uint32_t inputCount = context->getNumInputs();
NN_RET_CHECK_GE(inputCount, 2);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
@@ -152,13 +153,13 @@ bool validate(const IOperationValidationContext* context) {
}
std::vector<OperandType> inExpectedTypes(inputCount - 1, inputType);
inExpectedTypes.push_back(OperandType::INT32);
- if (context->getVersion() < Version::ANDROID_Q &&
- inputType == OperandType::TENSOR_QUANT8_ASYMM) {
+ if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
const Shape& output = context->getOutputShape(kOutputTensor);
for (uint32_t i = 0; i < inputCount - 1; ++i) {
const Shape& input = context->getInputShape(i);
- NN_RET_CHECK_EQ(input.scale, output.scale);
- NN_RET_CHECK_EQ(input.offset, output.offset);
+ if (input.scale != output.scale || input.offset != output.offset) {
+ minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q);
+ }
}
}
for (uint32_t i = 0; i < inputCount - 1; ++i) {
@@ -169,7 +170,7 @@ bool validate(const IOperationValidationContext* context) {
}
NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
- return validateVersion(context, minSupportedVersion);
+ return minSupportedVersion;
}
bool prepare(IOperationExecutionContext* context) {