summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael Butler <butlermichael@google.com>2020-11-02 23:17:11 -0800
committerMichael Butler <butlermichael@google.com>2020-11-06 18:15:52 -0800
commitca7a45a1234977c93fdfb578b64114d13ee27b7f (patch)
treee96b673d1baeb8853e54e1812cfb57a0036a4231
parentd6f4f1ed9fea50529006a0aa3436e4bce4decd05 (diff)
downloadml-ca7a45a1234977c93fdfb578b64114d13ee27b7f.tar.gz
Make operation validation return Result<Version>
Bug: N/A Test: mma Test: NeuralNetworksTest_static Change-Id: I47c12e13fcb41f832e31043b3f14e7b93472b0f8 Merged-In: I47c12e13fcb41f832e31043b3f14e7b93472b0f8 (cherry picked from commit 72146e08363add234a71d19accd96325cb77ce7c)
-rw-r--r--nn/common/OperationsUtils.cpp7
-rw-r--r--nn/common/Utils.cpp21
-rw-r--r--nn/common/Validation.cpp28
-rw-r--r--nn/common/include/OperationResolver.h11
-rw-r--r--nn/common/include/OperationsUtils.h16
-rw-r--r--nn/common/include/nnapi/TypeUtils.h2
-rw-r--r--nn/common/operations/Activation.cpp8
-rw-r--r--nn/common/operations/BidirectionalSequenceRNN.cpp9
-rw-r--r--nn/common/operations/Broadcast.cpp4
-rw-r--r--nn/common/operations/ChannelShuffle.cpp6
-rw-r--r--nn/common/operations/Comparisons.cpp6
-rw-r--r--nn/common/operations/Concatenation.cpp13
-rw-r--r--nn/common/operations/Conv2D.cpp4
-rw-r--r--nn/common/operations/DepthwiseConv2D.cpp4
-rw-r--r--nn/common/operations/Dequantize.cpp6
-rw-r--r--nn/common/operations/Elementwise.cpp15
-rw-r--r--nn/common/operations/Elu.cpp4
-rw-r--r--nn/common/operations/Fill.cpp4
-rw-r--r--nn/common/operations/FullyConnected.cpp5
-rw-r--r--nn/common/operations/Gather.cpp6
-rw-r--r--nn/common/operations/GenerateProposals.cpp26
-rw-r--r--nn/common/operations/HeatmapMaxKeypoint.cpp7
-rw-r--r--nn/common/operations/InstanceNormalization.cpp7
-rw-r--r--nn/common/operations/L2Normalization.cpp4
-rw-r--r--nn/common/operations/LocalResponseNormalization.cpp4
-rw-r--r--nn/common/operations/LogSoftmax.cpp7
-rw-r--r--nn/common/operations/LogicalAndOr.cpp4
-rw-r--r--nn/common/operations/LogicalNot.cpp4
-rw-r--r--nn/common/operations/Neg.cpp4
-rw-r--r--nn/common/operations/PRelu.cpp6
-rw-r--r--nn/common/operations/Pooling.cpp4
-rw-r--r--nn/common/operations/QLSTM.cpp4
-rw-r--r--nn/common/operations/Quantize.cpp6
-rw-r--r--nn/common/operations/Rank.cpp4
-rw-r--r--nn/common/operations/Reduce.cpp12
-rw-r--r--nn/common/operations/ResizeImageOps.cpp4
-rw-r--r--nn/common/operations/RoiAlign.cpp9
-rw-r--r--nn/common/operations/RoiPooling.cpp10
-rw-r--r--nn/common/operations/Select.cpp4
-rw-r--r--nn/common/operations/Slice.cpp4
-rw-r--r--nn/common/operations/Softmax.cpp4
-rw-r--r--nn/common/operations/Squeeze.cpp4
-rw-r--r--nn/common/operations/StridedSlice.cpp4
-rw-r--r--nn/common/operations/TopK_V2.cpp4
-rw-r--r--nn/common/operations/Transpose.cpp4
-rw-r--r--nn/common/operations/TransposeConv2D.cpp4
-rw-r--r--nn/common/operations/UnidirectionalSequenceLSTM.cpp4
-rw-r--r--nn/common/operations/UnidirectionalSequenceRNN.cpp9
48 files changed, 155 insertions, 195 deletions
diff --git a/nn/common/OperationsUtils.cpp b/nn/common/OperationsUtils.cpp
index d65566f96..c5a71e981 100644
--- a/nn/common/OperationsUtils.cpp
+++ b/nn/common/OperationsUtils.cpp
@@ -86,8 +86,9 @@ bool validateOutputTypes(const IOperationValidationContext* context,
[context](uint32_t index) { return context->getOutputType(index); });
}
-bool validateVersion(const IOperationValidationContext* context, Version minSupportedVersion) {
- if (context->getVersion() < minSupportedVersion) {
+bool validateVersion(const IOperationValidationContext* context, Version contextVersion,
+ Version minSupportedVersion) {
+ if (contextVersion < minSupportedVersion) {
std::ostringstream message;
message << "Operation " << context->getOperationName() << " with inputs {";
for (uint32_t i = 0, n = context->getNumInputs(); i < n; ++i) {
@@ -104,7 +105,7 @@ bool validateVersion(const IOperationValidationContext* context, Version minSupp
message << context->getOutputType(i);
}
message << "} is only supported since " << minSupportedVersion << " (validating using "
- << context->getVersion() << ")";
+ << contextVersion << ")";
NN_RET_CHECK_FAIL() << message.str();
}
return true;
diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp
index da4dbc87f..7417ed8bf 100644
--- a/nn/common/Utils.cpp
+++ b/nn/common/Utils.cpp
@@ -215,18 +215,15 @@ class OperationValidationContext : public IOperationValidationContext {
public:
OperationValidationContext(const char* operationName, uint32_t inputCount,
const uint32_t* inputIndexes, uint32_t outputCount,
- const uint32_t* outputIndexes, const Operand* operands,
- HalVersion halVersion)
+ const uint32_t* outputIndexes, const Operand* operands)
: operationName(operationName),
inputCount(inputCount),
inputIndexes(inputIndexes),
outputCount(outputCount),
outputIndexes(outputIndexes),
- operands(operands),
- version(convert(halVersion)) {}
+ operands(operands) {}
const char* getOperationName() const override;
- Version getVersion() const override;
uint32_t getNumInputs() const override;
OperandType getInputType(uint32_t index) const override;
@@ -254,10 +251,6 @@ const char* OperationValidationContext::getOperationName() const {
return operationName;
}
-Version OperationValidationContext::getVersion() const {
- return version;
-}
-
const Operand* OperationValidationContext::getInputOperand(uint32_t index) const {
CHECK(index < static_cast<uint32_t>(inputCount));
return &operands[inputIndexes[index]];
@@ -1883,8 +1876,14 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
}
OperationValidationContext context(operationRegistration->name, inputCount,
inputIndexes, outputCount, outputIndexes,
- operands.data(), halVersion);
- if (!operationRegistration->validate(&context)) {
+ operands.data());
+ const auto maybeVersion = operationRegistration->validate(&context);
+ if (!maybeVersion.has_value()) {
+ LOG(ERROR) << "Validation failed for operation " << opType << ": "
+ << maybeVersion.error();
+ return ANEURALNETWORKS_BAD_DATA;
+ }
+ if (!validateVersion(&context, convert(halVersion), maybeVersion.value())) {
LOG(ERROR) << "Validation failed for operation " << opType;
return ANEURALNETWORKS_BAD_DATA;
}
diff --git a/nn/common/Validation.cpp b/nn/common/Validation.cpp
index e49f073f2..d37c447c0 100644
--- a/nn/common/Validation.cpp
+++ b/nn/common/Validation.cpp
@@ -1181,15 +1181,13 @@ class OperationValidationContext : public IOperationValidationContext {
public:
OperationValidationContext(const char* operationName, const std::vector<uint32_t>& inputIndexes,
const std::vector<uint32_t>& outputIndexes,
- const std::vector<Operand>& operands, Version version)
+ const std::vector<Operand>& operands)
: operationName(operationName),
inputIndexes(inputIndexes),
outputIndexes(outputIndexes),
- operands(operands),
- version(version) {}
+ operands(operands) {}
const char* getOperationName() const override;
- Version getVersion() const override;
uint32_t getNumInputs() const override;
OperandType getInputType(uint32_t index) const override;
@@ -1208,17 +1206,12 @@ class OperationValidationContext : public IOperationValidationContext {
const std::vector<uint32_t>& inputIndexes;
const std::vector<uint32_t>& outputIndexes;
const std::vector<Operand>& operands;
- Version version;
};
const char* OperationValidationContext::getOperationName() const {
return operationName;
}
-Version OperationValidationContext::getVersion() const {
- return version;
-}
-
const Operand* OperationValidationContext::getInputOperand(uint32_t index) const {
return &operands.at(inputIndexes.at(index));
}
@@ -2521,20 +2514,9 @@ Result<Version> validateOperationImpl(const Operation& operation,
NN_VALIDATE(operationRegistration->validate != nullptr)
<< "Incomplete operation registration: " << opType;
- constexpr Version kVersions[] = {Version::ANDROID_OC_MR1, Version::ANDROID_P,
- Version::ANDROID_Q, Version::ANDROID_R,
- Version::CURRENT_RUNTIME};
-
- for (const auto version : kVersions) {
- OperationValidationContext context(operationRegistration->name, inputIndexes,
- outputIndexes, operands, version);
- auto valid = operationRegistration->validate(&context);
- if (valid) {
- return version;
- }
- }
-
- return NN_ERROR() << "Validation failed for operation " << opType;
+ OperationValidationContext context(operationRegistration->name, inputIndexes,
+ outputIndexes, operands);
+ return operationRegistration->validate(&context);
}
}
}
diff --git a/nn/common/include/OperationResolver.h b/nn/common/include/OperationResolver.h
index d2c066cd3..155341a1a 100644
--- a/nn/common/include/OperationResolver.h
+++ b/nn/common/include/OperationResolver.h
@@ -32,7 +32,7 @@ struct OperationRegistration {
const char* name;
// Validates operand types, shapes, and any values known during graph creation.
- std::function<bool(const IOperationValidationContext*)> validate;
+ std::function<Result<Version>(const IOperationValidationContext*)> validate;
// prepare is called when the inputs this operation depends on have been
// computed. Typically, prepare does any remaining validation and sets
@@ -50,10 +50,11 @@ struct OperationRegistration {
bool allowZeroSizedInput = false;
} flags;
- OperationRegistration(OperationType type, const char* name,
- std::function<bool(const IOperationValidationContext*)> validate,
- std::function<bool(IOperationExecutionContext*)> prepare,
- std::function<bool(IOperationExecutionContext*)> execute, Flag flags)
+ OperationRegistration(
+ OperationType type, const char* name,
+ std::function<Result<Version>(const IOperationValidationContext*)> validate,
+ std::function<bool(IOperationExecutionContext*)> prepare,
+ std::function<bool(IOperationExecutionContext*)> execute, Flag flags)
: type(type),
name(name),
validate(std::move(validate)),
diff --git a/nn/common/include/OperationsUtils.h b/nn/common/include/OperationsUtils.h
index 676bbb34a..9123139b1 100644
--- a/nn/common/include/OperationsUtils.h
+++ b/nn/common/include/OperationsUtils.h
@@ -59,19 +59,6 @@ class IOperationValidationContext {
virtual const char* getOperationName() const = 0;
- // The version of the environment in which the operation is to be executed.
- //
- // Operation validation logic needs to handle all versions to support the following use cases
- // (assume in these examples that the latest version is Version::ANDROID_Q):
- // 1. Our runtime wants to distribute work to a driver implementing an older version and calls,
- // for example, compliantWithV1_0(const V1_2::Model&).
- // 2. A driver implements an older version and delegates model validation to, for example,
- // validateModel(const V1_0::Model&).
- //
- // If getVersion() returns Version::ANDROID_OC_MR1 and the operation is only supported since
- // Version::ANDROID_P, validation will fail.
- virtual Version getVersion() const = 0;
-
virtual uint32_t getNumInputs() const = 0;
virtual OperandType getInputType(uint32_t index) const = 0;
virtual Shape getInputShape(uint32_t index) const = 0;
@@ -130,7 +117,8 @@ bool validateOutputTypes(const IOperationValidationContext* context,
// Verifies that the HAL version specified in the context is greater or equal
// than the minimal supported HAL version.
-bool validateVersion(const IOperationValidationContext* context, Version minSupportedVersion);
+bool validateVersion(const IOperationValidationContext* context, Version contextVersion,
+ Version minSupportedVersion);
// Verifies that the two shapes are the same.
bool SameShape(const Shape& in1, const Shape& in2);
diff --git a/nn/common/include/nnapi/TypeUtils.h b/nn/common/include/nnapi/TypeUtils.h
index 9dc67cf46..56b62f9b2 100644
--- a/nn/common/include/nnapi/TypeUtils.h
+++ b/nn/common/include/nnapi/TypeUtils.h
@@ -198,6 +198,8 @@ class FalseyErrorStream {
operator bool() const { return false; }
+ operator Result<Version>() const { return error() << mBuffer.str(); }
+
private:
std::ostringstream mBuffer;
};
diff --git a/nn/common/operations/Activation.cpp b/nn/common/operations/Activation.cpp
index bcf846a80..651cd020e 100644
--- a/nn/common/operations/Activation.cpp
+++ b/nn/common/operations/Activation.cpp
@@ -353,7 +353,7 @@ bool hardSwishQuant(const T* inputData, const Shape& inputShape, T* outputData,
} // namespace
-bool validate(OperationType opType, const IOperationValidationContext* context) {
+Result<Version> validate(OperationType opType, const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
auto inputType = context->getInputType(kInputTensor);
@@ -379,10 +379,10 @@ bool validate(OperationType opType, const IOperationValidationContext* context)
}
NN_RET_CHECK(validateInputTypes(context, {inputType}));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
- return validateVersion(context, minSupportedVersion);
+ return minSupportedVersion;
}
-bool validateHardSwish(const IOperationValidationContext* context) {
+Result<Version> validateHardSwish(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
auto inputType = context->getInputType(kInputTensor);
@@ -396,7 +396,7 @@ bool validateHardSwish(const IOperationValidationContext* context) {
}
NN_RET_CHECK(validateInputTypes(context, {inputType}));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
- return validateVersion(context, minSupportedVersion);
+ return minSupportedVersion;
}
bool prepare(OperationType opType, IOperationExecutionContext* context) {
diff --git a/nn/common/operations/BidirectionalSequenceRNN.cpp b/nn/common/operations/BidirectionalSequenceRNN.cpp
index f6b4c301c..5a020d1f7 100644
--- a/nn/common/operations/BidirectionalSequenceRNN.cpp
+++ b/nn/common/operations/BidirectionalSequenceRNN.cpp
@@ -313,7 +313,7 @@ bool executeTyped(IOperationExecutionContext* context) {
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
// Exact number is dependent on the mergeOutputs parameter and checked
// during preparation.
@@ -323,9 +323,8 @@ bool validate(const IOperationValidationContext* context) {
OperandType inputType = context->getInputType(kInputTensor);
if (inputType != OperandType::TENSOR_FLOAT16 && inputType != OperandType::TENSOR_FLOAT32) {
- LOG(ERROR) << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_RNN op: "
- << inputType;
- return false;
+ return NN_ERROR() << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_RNN op: "
+ << inputType;
}
NN_RET_CHECK(validateInputTypes(
context, {inputType, inputType, inputType, inputType, inputType, inputType, inputType,
@@ -339,7 +338,7 @@ bool validate(const IOperationValidationContext* context) {
if (numOutputs == kNumOutputsWithState || numOutputs == kNumOutputsMergedWithState) {
minSupportedVersion = Version::ANDROID_R;
}
- return validateVersion(context, minSupportedVersion);
+ return minSupportedVersion;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/Broadcast.cpp b/nn/common/operations/Broadcast.cpp
index e47bd21e8..a2d5b8a39 100644
--- a/nn/common/operations/Broadcast.cpp
+++ b/nn/common/operations/Broadcast.cpp
@@ -434,7 +434,7 @@ bool divFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, c
} // namespace
-bool validate(OperationType opType, const IOperationValidationContext* context) {
+Result<Version> validate(OperationType opType, const IOperationValidationContext* context) {
auto minSupportedVersion = (opType == OperationType::DIV || opType == OperationType::SUB)
? Version::ANDROID_P
: Version::ANDROID_OC_MR1;
@@ -473,7 +473,7 @@ bool validate(OperationType opType, const IOperationValidationContext* context)
}
NN_RET_CHECK(validateInputTypes(context, {inputType, inputType, OperandType::INT32}));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
- return validateVersion(context, minSupportedVersion);
+ return minSupportedVersion;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/ChannelShuffle.cpp b/nn/common/operations/ChannelShuffle.cpp
index 59726fac2..efa08737b 100644
--- a/nn/common/operations/ChannelShuffle.cpp
+++ b/nn/common/operations/ChannelShuffle.cpp
@@ -57,7 +57,7 @@ inline bool eval(const T* inputData, const Shape& inputShape, int32_t numGroups,
return true;
}
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
auto inputType = context->getInputType(kInputTensor);
@@ -73,9 +73,9 @@ bool validate(const IOperationValidationContext* context) {
NN_RET_CHECK(validateInputTypes(context, {inputType, OperandType::INT32, OperandType::INT32}));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
- return validateVersion(context, Version::ANDROID_R);
+ return Version::ANDROID_R;
} else {
- return validateVersion(context, Version::ANDROID_Q);
+ return Version::ANDROID_Q;
}
}
diff --git a/nn/common/operations/Comparisons.cpp b/nn/common/operations/Comparisons.cpp
index 8fdf72c59..b490c9218 100644
--- a/nn/common/operations/Comparisons.cpp
+++ b/nn/common/operations/Comparisons.cpp
@@ -123,7 +123,7 @@ bool executeGreaterTyped(IOperationExecutionContext* context) {
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
OperandType inputType = context->getInputType(kInputTensor1);
@@ -136,9 +136,9 @@ bool validate(const IOperationValidationContext* context) {
NN_RET_CHECK(validateInputTypes(context, {inputType, inputType}));
NN_RET_CHECK(validateOutputTypes(context, {OperandType::TENSOR_BOOL8}));
if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
- return validateVersion(context, Version::ANDROID_R);
+ return Version::ANDROID_R;
} else {
- return validateVersion(context, Version::ANDROID_Q);
+ return Version::ANDROID_Q;
}
}
diff --git a/nn/common/operations/Concatenation.cpp b/nn/common/operations/Concatenation.cpp
index 16a08d6b9..6b9007e5e 100644
--- a/nn/common/operations/Concatenation.cpp
+++ b/nn/common/operations/Concatenation.cpp
@@ -29,6 +29,7 @@
#include "CpuOperationUtils.h"
#include "OperationResolver.h"
#include "Tracing.h"
+#include "nnapi/Validation.h"
namespace android {
namespace nn {
@@ -135,7 +136,7 @@ inline bool concatenation<int8_t>(IOperationExecutionContext* context) {
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
uint32_t inputCount = context->getNumInputs();
NN_RET_CHECK_GE(inputCount, 2);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
@@ -152,13 +153,13 @@ bool validate(const IOperationValidationContext* context) {
}
std::vector<OperandType> inExpectedTypes(inputCount - 1, inputType);
inExpectedTypes.push_back(OperandType::INT32);
- if (context->getVersion() < Version::ANDROID_Q &&
- inputType == OperandType::TENSOR_QUANT8_ASYMM) {
+ if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
const Shape& output = context->getOutputShape(kOutputTensor);
for (uint32_t i = 0; i < inputCount - 1; ++i) {
const Shape& input = context->getInputShape(i);
- NN_RET_CHECK_EQ(input.scale, output.scale);
- NN_RET_CHECK_EQ(input.offset, output.offset);
+ if (input.scale != output.scale || input.offset != output.offset) {
+ minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q);
+ }
}
}
for (uint32_t i = 0; i < inputCount - 1; ++i) {
@@ -169,7 +170,7 @@ bool validate(const IOperationValidationContext* context) {
}
NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
- return validateVersion(context, minSupportedVersion);
+ return minSupportedVersion;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/Conv2D.cpp b/nn/common/operations/Conv2D.cpp
index d00da57a9..6d989827e 100644
--- a/nn/common/operations/Conv2D.cpp
+++ b/nn/common/operations/Conv2D.cpp
@@ -526,7 +526,7 @@ bool convQuant8PerChannel(const T* inputData, const Shape& inputShape, const int
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
const uint32_t numInputs = context->getNumInputs();
NN_RET_CHECK(
std::binary_search(std::begin(kNumInputsArray), std::end(kNumInputsArray), numInputs));
@@ -624,7 +624,7 @@ bool validate(const IOperationValidationContext* context) {
}
NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
- return validateVersion(context, minSupportedVersion);
+ return minSupportedVersion;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/DepthwiseConv2D.cpp b/nn/common/operations/DepthwiseConv2D.cpp
index bb158b328..64bd7dd4a 100644
--- a/nn/common/operations/DepthwiseConv2D.cpp
+++ b/nn/common/operations/DepthwiseConv2D.cpp
@@ -413,7 +413,7 @@ bool depthwiseConvQuant8PerChannel(const T* inputData, const Shape& inputShape,
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
const uint32_t numInputs = context->getNumInputs();
NN_RET_CHECK(
std::binary_search(std::begin(kNumInputsArray), std::end(kNumInputsArray), numInputs));
@@ -507,7 +507,7 @@ bool validate(const IOperationValidationContext* context) {
}
NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
- return validateVersion(context, minSupportedVersion);
+ return minSupportedVersion;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/Dequantize.cpp b/nn/common/operations/Dequantize.cpp
index f155eb286..b648ff135 100644
--- a/nn/common/operations/Dequantize.cpp
+++ b/nn/common/operations/Dequantize.cpp
@@ -75,7 +75,7 @@ bool computePerChannel(const int8_t* inputData, const Shape& inputShape, OutputT
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
@@ -89,7 +89,7 @@ bool validate(const IOperationValidationContext* context) {
if (inputType == OperandType::TENSOR_QUANT8_ASYMM &&
outputType == OperandType::TENSOR_FLOAT32) {
- return validateVersion(context, Version::ANDROID_OC_MR1);
+ return Version::ANDROID_OC_MR1;
}
NN_RET_CHECK(inputType == OperandType::TENSOR_QUANT8_ASYMM ||
@@ -100,7 +100,7 @@ bool validate(const IOperationValidationContext* context) {
NN_RET_CHECK(outputType == OperandType::TENSOR_FLOAT16 ||
outputType == OperandType::TENSOR_FLOAT32)
<< "Unsupported output operand type for DEQUANTIZE op: " << outputType;
- return validateVersion(context, Version::ANDROID_Q);
+ return Version::ANDROID_Q;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/Elementwise.cpp b/nn/common/operations/Elementwise.cpp
index a0cd78ffe..851000392 100644
--- a/nn/common/operations/Elementwise.cpp
+++ b/nn/common/operations/Elementwise.cpp
@@ -82,7 +82,7 @@ bool executeAbs(IOperationExecutionContext* context) {
}
}
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
OperandType inputType = context->getInputType(kInputTensor);
@@ -91,10 +91,10 @@ bool validate(const IOperationValidationContext* context) {
<< "Unsupported tensor type for elementwise operation";
NN_RET_CHECK(validateInputTypes(context, {inputType}));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
- return validateVersion(context, Version::ANDROID_Q);
+ return Version::ANDROID_Q;
}
-bool validateAbs(const IOperationValidationContext* context) {
+Result<Version> validateAbs(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
OperandType inputType = context->getInputType(kInputTensor);
@@ -103,11 +103,10 @@ bool validateAbs(const IOperationValidationContext* context) {
<< "Unsupported tensor type for operation ABS";
NN_RET_CHECK(validateInputTypes(context, {inputType}));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
- return validateVersion(context, (inputType == OperandType::TENSOR_INT32 ? Version::ANDROID_R
- : Version::ANDROID_Q));
+ return inputType == OperandType::TENSOR_INT32 ? Version::ANDROID_R : Version::ANDROID_Q;
}
-bool validateFloor(const IOperationValidationContext* context) {
+Result<Version> validateFloor(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
@@ -123,9 +122,7 @@ bool validateFloor(const IOperationValidationContext* context) {
NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
}
- return validateVersion(context,
- (inputType == OperandType::TENSOR_FLOAT16 ? Version::ANDROID_Q
- : Version::ANDROID_OC_MR1));
+ return inputType == OperandType::TENSOR_FLOAT16 ? Version::ANDROID_Q : Version::ANDROID_OC_MR1;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/Elu.cpp b/nn/common/operations/Elu.cpp
index 105ef01cf..98e066210 100644
--- a/nn/common/operations/Elu.cpp
+++ b/nn/common/operations/Elu.cpp
@@ -52,7 +52,7 @@ bool eluFloat(const T* inputData, const Shape& inputShape, const T alpha, T* out
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
auto inputType = context->getInputType(kInputTensor);
@@ -66,7 +66,7 @@ bool validate(const IOperationValidationContext* context) {
inputType == OperandType::TENSOR_FLOAT16 ? OperandType::FLOAT16 : OperandType::FLOAT32;
NN_RET_CHECK(validateInputTypes(context, {inputType, scalarType}));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
- return validateVersion(context, minSupportedVersion);
+ return minSupportedVersion;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/Fill.cpp b/nn/common/operations/Fill.cpp
index f3b470ed5..9af64f7de 100644
--- a/nn/common/operations/Fill.cpp
+++ b/nn/common/operations/Fill.cpp
@@ -61,7 +61,7 @@ bool getValueType(OperandType outputType, OperandType* valueType) {
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
// Check output type first because input value type is dependent on the
@@ -77,7 +77,7 @@ bool validate(const IOperationValidationContext* context) {
NN_RET_CHECK(getValueType(outputType, &valueType));
NN_RET_CHECK(validateInputTypes(context, {OperandType::TENSOR_INT32, valueType}));
- return validateVersion(context, Version::ANDROID_R);
+ return Version::ANDROID_R;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/FullyConnected.cpp b/nn/common/operations/FullyConnected.cpp
index 873b64abf..ab50d31c1 100644
--- a/nn/common/operations/FullyConnected.cpp
+++ b/nn/common/operations/FullyConnected.cpp
@@ -217,7 +217,7 @@ bool validateShapes(const Shape& input, const Shape& weights, const Shape& bias,
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
auto inputType = context->getInputType(kInputTensor);
@@ -272,7 +272,6 @@ bool validate(const IOperationValidationContext* context) {
};
} else {
NN_RET_CHECK_FAIL() << "Unsupported input tensor type for operation " << kOperationName;
- return false;
}
NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
@@ -284,7 +283,7 @@ bool validate(const IOperationValidationContext* context) {
NN_RET_CHECK(validateShapes(input, weights, bias));
}
- return validateVersion(context, minSupportedVersion);
+ return minSupportedVersion;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/Gather.cpp b/nn/common/operations/Gather.cpp
index 6707b6d94..5571a6501 100644
--- a/nn/common/operations/Gather.cpp
+++ b/nn/common/operations/Gather.cpp
@@ -59,7 +59,7 @@ inline bool eval(const T* inputData, const Shape& inputShape, int32_t axis,
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
OperandType inputType = context->getInputType(kInputTensor);
@@ -73,9 +73,9 @@ bool validate(const IOperationValidationContext* context) {
{inputType, OperandType::INT32, OperandType::TENSOR_INT32}));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
- return validateVersion(context, Version::ANDROID_R);
+ return Version::ANDROID_R;
} else {
- return validateVersion(context, Version::ANDROID_Q);
+ return Version::ANDROID_Q;
}
}
diff --git a/nn/common/operations/GenerateProposals.cpp b/nn/common/operations/GenerateProposals.cpp
index edd7cb0db..95e3676e0 100644
--- a/nn/common/operations/GenerateProposals.cpp
+++ b/nn/common/operations/GenerateProposals.cpp
@@ -197,7 +197,7 @@ constexpr uint32_t kImageInfoTensor = 3;
constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
std::vector<OperandType> inExpectedTypes;
@@ -211,16 +211,14 @@ bool validate(const IOperationValidationContext* context) {
inExpectedTypes = {OperandType::TENSOR_QUANT16_ASYMM, deltaInputType,
OperandType::TENSOR_INT32, OperandType::TENSOR_QUANT16_ASYMM};
} else {
- LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName;
- return false;
+ return NN_ERROR() << "Unsupported input tensor type for operation " << kOperationName;
}
} else {
- LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName;
- return false;
+ return NN_ERROR() << "Unsupported input tensor type for operation " << kOperationName;
}
NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
- return validateVersion(context, Version::ANDROID_Q);
+ return Version::ANDROID_Q;
}
bool prepare(IOperationExecutionContext* context) {
@@ -703,7 +701,7 @@ bool boxWithNmsLimitQuant(const int8_t* scoresData, const Shape& scoresShape,
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
std::vector<OperandType> inExpectedTypes;
@@ -742,9 +740,9 @@ bool validate(const IOperationValidationContext* context) {
NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));
if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
- return validateVersion(context, Version::ANDROID_R);
+ return Version::ANDROID_R;
} else {
- return validateVersion(context, Version::ANDROID_Q);
+ return Version::ANDROID_Q;
}
}
@@ -1213,7 +1211,7 @@ bool generateProposalsQuant(const T_8QInput* scoresData, const Shape& scoresShap
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
std::vector<OperandType> inExpectedTypes;
@@ -1268,9 +1266,9 @@ bool validate(const IOperationValidationContext* context) {
NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));
if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
- return validateVersion(context, Version::ANDROID_R);
+ return Version::ANDROID_R;
} else {
- return validateVersion(context, Version::ANDROID_Q);
+ return Version::ANDROID_Q;
}
}
@@ -1569,7 +1567,7 @@ bool detectionPostprocessFloat16(
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
std::vector<OperandType> inExpectedTypes;
@@ -1597,7 +1595,7 @@ bool validate(const IOperationValidationContext* context) {
NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
NN_RET_CHECK(validateOutputTypes(
context, {inputType, inputType, OperandType::TENSOR_INT32, OperandType::TENSOR_INT32}));
- return validateVersion(context, Version::ANDROID_Q);
+ return Version::ANDROID_Q;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/HeatmapMaxKeypoint.cpp b/nn/common/operations/HeatmapMaxKeypoint.cpp
index 1da7ed07f..63fc5973b 100644
--- a/nn/common/operations/HeatmapMaxKeypoint.cpp
+++ b/nn/common/operations/HeatmapMaxKeypoint.cpp
@@ -224,7 +224,7 @@ inline bool heatmapMaxKeypointQuant(const int8_t* heatmap, const Shape& heatmapS
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
std::vector<OperandType> inExpectedTypes;
@@ -245,12 +245,11 @@ bool validate(const IOperationValidationContext* context) {
OperandType::TENSOR_QUANT16_ASYMM};
minSupportedVersion = Version::ANDROID_R;
} else {
- LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName;
- return false;
+ return NN_ERROR() << "Unsupported input tensor type for operation " << kOperationName;
}
NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));
- return validateVersion(context, minSupportedVersion);
+ return minSupportedVersion;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/InstanceNormalization.cpp b/nn/common/operations/InstanceNormalization.cpp
index 62b7728f8..1a0e488e9 100644
--- a/nn/common/operations/InstanceNormalization.cpp
+++ b/nn/common/operations/InstanceNormalization.cpp
@@ -99,7 +99,7 @@ inline bool instanceNorm(const T* inputData, const Shape& inputShape, T gamma, T
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
std::vector<OperandType> inExpectedTypes;
@@ -111,12 +111,11 @@ bool validate(const IOperationValidationContext* context) {
inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::FLOAT16, OperandType::FLOAT16,
OperandType::FLOAT16, OperandType::BOOL};
} else {
- LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName;
- return false;
+ return NN_ERROR() << "Unsupported input tensor type for operation " << kOperationName;
}
NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
- return validateVersion(context, Version::ANDROID_Q);
+ return Version::ANDROID_Q;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/L2Normalization.cpp b/nn/common/operations/L2Normalization.cpp
index 22f0cb3d2..05682ea3a 100644
--- a/nn/common/operations/L2Normalization.cpp
+++ b/nn/common/operations/L2Normalization.cpp
@@ -196,7 +196,7 @@ bool l2normQuant8Signed(const int8_t* inputData, const Shape& inputShape, int32_
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK(context->getNumInputs() == kNumInputs ||
context->getNumInputs() == kNumInputs - 1);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
@@ -225,7 +225,7 @@ bool validate(const IOperationValidationContext* context) {
}
NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
- return validateVersion(context, minSupportedVersion);
+ return minSupportedVersion;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/LocalResponseNormalization.cpp b/nn/common/operations/LocalResponseNormalization.cpp
index 435d602f2..ed16dec6b 100644
--- a/nn/common/operations/LocalResponseNormalization.cpp
+++ b/nn/common/operations/LocalResponseNormalization.cpp
@@ -130,7 +130,7 @@ bool executeTyped(IOperationExecutionContext* context) {
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK(context->getNumInputs() == kNumInputs ||
context->getNumInputs() == kNumInputs - 1);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
@@ -170,7 +170,7 @@ bool validate(const IOperationValidationContext* context) {
}
NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
- return validateVersion(context, minSupportedVersion);
+ return minSupportedVersion;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/LogSoftmax.cpp b/nn/common/operations/LogSoftmax.cpp
index 86a882fd5..6fe934a3d 100644
--- a/nn/common/operations/LogSoftmax.cpp
+++ b/nn/common/operations/LogSoftmax.cpp
@@ -70,7 +70,7 @@ inline bool compute(const T* input, const Shape& shape, T beta, uint32_t axis, T
return true;
}
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
OperandType inputType = context->getInputType(kInputTensor);
@@ -83,12 +83,11 @@ bool validate(const IOperationValidationContext* context) {
inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::FLOAT16, OperandType::INT32};
outExpectedTypes = {OperandType::TENSOR_FLOAT16};
} else {
- LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName;
- return false;
+ return NN_ERROR() << "Unsupported input tensor type for operation " << kOperationName;
}
NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));
- return validateVersion(context, Version::ANDROID_Q);
+ return Version::ANDROID_Q;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/LogicalAndOr.cpp b/nn/common/operations/LogicalAndOr.cpp
index 163aa5432..e1927a59f 100644
--- a/nn/common/operations/LogicalAndOr.cpp
+++ b/nn/common/operations/LogicalAndOr.cpp
@@ -60,7 +60,7 @@ bool compute(const std::function<bool(bool, bool)>& func, const bool8* aData, co
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
OperandType inputType = context->getInputType(kInputTensor1);
@@ -68,7 +68,7 @@ bool validate(const IOperationValidationContext* context) {
<< "Unsupported tensor type for a logical operation";
NN_RET_CHECK(validateInputTypes(context, {inputType, inputType}));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
- return validateVersion(context, Version::ANDROID_Q);
+ return Version::ANDROID_Q;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/LogicalNot.cpp b/nn/common/operations/LogicalNot.cpp
index 2f6bb637b..b93e71b90 100644
--- a/nn/common/operations/LogicalNot.cpp
+++ b/nn/common/operations/LogicalNot.cpp
@@ -41,7 +41,7 @@ bool compute(const bool8* input, const Shape& shape, bool8* output) {
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
OperandType inputType = context->getInputType(kInputTensor);
@@ -49,7 +49,7 @@ bool validate(const IOperationValidationContext* context) {
<< "Unsupported tensor type for LOGICAL_NOT";
NN_RET_CHECK(validateInputTypes(context, {inputType}));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
- return validateVersion(context, Version::ANDROID_Q);
+ return Version::ANDROID_Q;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/Neg.cpp b/nn/common/operations/Neg.cpp
index 1d042fcbe..39b58b9b2 100644
--- a/nn/common/operations/Neg.cpp
+++ b/nn/common/operations/Neg.cpp
@@ -47,7 +47,7 @@ inline bool compute(const T* input, const Shape& shape, T* output) {
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
OperandType inputType = context->getInputType(kInputTensor);
@@ -56,7 +56,7 @@ bool validate(const IOperationValidationContext* context) {
<< "Unsupported tensor type for operation " << kOperationName;
NN_RET_CHECK(validateInputTypes(context, {inputType}));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
- return validateVersion(context, Version::ANDROID_Q);
+ return Version::ANDROID_Q;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/PRelu.cpp b/nn/common/operations/PRelu.cpp
index db2f6d416..88e38fcf4 100644
--- a/nn/common/operations/PRelu.cpp
+++ b/nn/common/operations/PRelu.cpp
@@ -95,7 +95,7 @@ bool evalQuant8(const T* aData, const Shape& aShape, const T* bData, const Shape
aData, aShape, bData, bShape, outputData, outputShape);
}
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
auto inputType = context->getInputType(kInputTensor);
@@ -107,9 +107,9 @@ bool validate(const IOperationValidationContext* context) {
NN_RET_CHECK(validateInputTypes(context, {inputType, inputType}));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
- return validateVersion(context, Version::ANDROID_R);
+ return Version::ANDROID_R;
} else {
- return validateVersion(context, Version::ANDROID_Q);
+ return Version::ANDROID_Q;
}
}
diff --git a/nn/common/operations/Pooling.cpp b/nn/common/operations/Pooling.cpp
index e1c4cdcd5..6cd286439 100644
--- a/nn/common/operations/Pooling.cpp
+++ b/nn/common/operations/Pooling.cpp
@@ -288,7 +288,7 @@ bool maxPool(const T* inputData, const Shape& inputShape, const PoolingParam& pa
} // namespace
-bool validate(OperationType opType, const IOperationValidationContext* context) {
+Result<Version> validate(OperationType opType, const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
auto inputCount = context->getNumInputs();
NN_RET_CHECK(inputCount == 11 || inputCount == 10 || inputCount == 8 || inputCount == 7);
@@ -349,7 +349,7 @@ bool validate(OperationType opType, const IOperationValidationContext* context)
}
NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
- return validateVersion(context, minSupportedVersion);
+ return minSupportedVersion;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/QLSTM.cpp b/nn/common/operations/QLSTM.cpp
index 0812e6661..e8c4f90f6 100644
--- a/nn/common/operations/QLSTM.cpp
+++ b/nn/common/operations/QLSTM.cpp
@@ -101,7 +101,7 @@ inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
@@ -149,7 +149,7 @@ bool validate(const IOperationValidationContext* context) {
outExpectedTypes.push_back(OperandType::TENSOR_QUANT8_ASYMM_SIGNED);
NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));
- return validateVersion(context, Version::ANDROID_R);
+ return Version::ANDROID_R;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/Quantize.cpp b/nn/common/operations/Quantize.cpp
index c3f4812c2..b9d37b8da 100644
--- a/nn/common/operations/Quantize.cpp
+++ b/nn/common/operations/Quantize.cpp
@@ -63,7 +63,7 @@ bool quantizeToQuant8Signed(const T* inputData, int8_t* outputData, const Shape&
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
@@ -77,9 +77,9 @@ bool validate(const IOperationValidationContext* context) {
outputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)
<< "Unsupported output operand type for QUANTIZE op: " << outputType;
if (outputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
- return validateVersion(context, Version::ANDROID_R);
+ return Version::ANDROID_R;
} else {
- return validateVersion(context, Version::ANDROID_Q);
+ return Version::ANDROID_Q;
}
}
diff --git a/nn/common/operations/Rank.cpp b/nn/common/operations/Rank.cpp
index 71951d703..f6363417a 100644
--- a/nn/common/operations/Rank.cpp
+++ b/nn/common/operations/Rank.cpp
@@ -30,7 +30,7 @@ constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputScalar = 0;
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
OperandType inputType = context->getInputType(kInputTensor);
@@ -46,7 +46,7 @@ bool validate(const IOperationValidationContext* context) {
inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)
<< "Incorrect input type for a RANK op: " << inputType;
NN_RET_CHECK(validateOutputTypes(context, {OperandType::INT32}));
- return validateVersion(context, Version::ANDROID_R);
+ return Version::ANDROID_R;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/Reduce.cpp b/nn/common/operations/Reduce.cpp
index 0563a3536..9eb195648 100644
--- a/nn/common/operations/Reduce.cpp
+++ b/nn/common/operations/Reduce.cpp
@@ -66,7 +66,7 @@ inline bool compute(IOperationExecutionContext* context, T init, T func(T, T)) {
} // namespace
-bool validateProdSum(const IOperationValidationContext* context) {
+Result<Version> validateProdSum(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
OperandType inputType = context->getInputType(kInputTensor);
@@ -80,10 +80,10 @@ bool validateProdSum(const IOperationValidationContext* context) {
if (hasKnownRank(input)) {
NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
}
- return validateVersion(context, Version::ANDROID_Q);
+ return Version::ANDROID_Q;
}
-bool validateMaxMin(const IOperationValidationContext* context) {
+Result<Version> validateMaxMin(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
OperandType inputType = context->getInputType(kInputTensor);
@@ -103,10 +103,10 @@ bool validateMaxMin(const IOperationValidationContext* context) {
if (hasKnownRank(input)) {
NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
}
- return validateVersion(context, minVersion);
+ return minVersion;
}
-bool validateLogical(const IOperationValidationContext* context) {
+Result<Version> validateLogical(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
OperandType inputType = context->getInputType(kInputTensor);
@@ -119,7 +119,7 @@ bool validateLogical(const IOperationValidationContext* context) {
if (hasKnownRank(input)) {
NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
}
- return validateVersion(context, Version::ANDROID_Q);
+ return Version::ANDROID_Q;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/ResizeImageOps.cpp b/nn/common/operations/ResizeImageOps.cpp
index 2c923f852..733bedb69 100644
--- a/nn/common/operations/ResizeImageOps.cpp
+++ b/nn/common/operations/ResizeImageOps.cpp
@@ -169,7 +169,7 @@ inline bool getOptionalScalar(const IOperationExecutionContext* context, uint32_
} // namespace
-bool validate(OperationType opType, const IOperationValidationContext* context) {
+Result<Version> validate(OperationType opType, const IOperationValidationContext* context) {
const auto numInputs = context->getNumInputs();
if (opType == OperationType::RESIZE_BILINEAR) {
NN_RET_CHECK(numInputs >= kNumInputs - 1 && numInputs <= kNumInputs + kNumOptionalInputs);
@@ -218,7 +218,7 @@ bool validate(OperationType opType, const IOperationValidationContext* context)
}
NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
- return validateVersion(context, minSupportedVersion);
+ return minSupportedVersion;
}
bool prepare(OperationType opType, IOperationExecutionContext* context) {
diff --git a/nn/common/operations/RoiAlign.cpp b/nn/common/operations/RoiAlign.cpp
index 78049b8bb..3ca64f56a 100644
--- a/nn/common/operations/RoiAlign.cpp
+++ b/nn/common/operations/RoiAlign.cpp
@@ -337,7 +337,7 @@ inline bool roiAlign(const T_Input* inputData, const Shape& inputShape, const T_
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
std::vector<OperandType> inExpectedTypes;
@@ -367,15 +367,14 @@ bool validate(const IOperationValidationContext* context) {
OperandType::INT32,
OperandType::BOOL};
} else {
- LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName;
- return false;
+ return NN_ERROR() << "Unsupported input tensor type for operation " << kOperationName;
}
NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
- return validateVersion(context, Version::ANDROID_R);
+ return Version::ANDROID_R;
} else {
- return validateVersion(context, Version::ANDROID_Q);
+ return Version::ANDROID_Q;
}
}
diff --git a/nn/common/operations/RoiPooling.cpp b/nn/common/operations/RoiPooling.cpp
index a011b4ae2..26e2213a3 100644
--- a/nn/common/operations/RoiPooling.cpp
+++ b/nn/common/operations/RoiPooling.cpp
@@ -184,7 +184,7 @@ inline bool roiPooling<int8_t, uint16_t>(const int8_t* inputData, const Shape& i
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
std::vector<OperandType> inExpectedTypes;
@@ -210,16 +210,14 @@ bool validate(const IOperationValidationContext* context) {
OperandType::FLOAT32,
OperandType::BOOL};
} else {
- LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName;
- return false;
+ return NN_ERROR() << "Unsupported input tensor type for operation " << kOperationName;
}
NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
- return validateVersion(context, Version::ANDROID_R);
- ;
+ return Version::ANDROID_R;
} else {
- return validateVersion(context, Version::ANDROID_Q);
+ return Version::ANDROID_Q;
}
}
diff --git a/nn/common/operations/Select.cpp b/nn/common/operations/Select.cpp
index 0b7728ab9..f037b4810 100644
--- a/nn/common/operations/Select.cpp
+++ b/nn/common/operations/Select.cpp
@@ -66,7 +66,7 @@ bool executeTyped(IOperationExecutionContext* context) {
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
OperandType inputType = context->getInputType(kInputTensor1);
@@ -78,7 +78,7 @@ bool validate(const IOperationValidationContext* context) {
<< "Unsupported input operand type for select op: " << inputType;
NN_RET_CHECK(validateInputTypes(context, {OperandType::TENSOR_BOOL8, inputType, inputType}));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
- return validateVersion(context, Version::ANDROID_Q);
+ return Version::ANDROID_Q;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/Slice.cpp b/nn/common/operations/Slice.cpp
index 882b0eb1a..db47419f7 100644
--- a/nn/common/operations/Slice.cpp
+++ b/nn/common/operations/Slice.cpp
@@ -78,7 +78,7 @@ bool evalGeneric(const T* inputData, const Shape& inputShape, const int32_t* beg
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
@@ -98,7 +98,7 @@ bool validate(const IOperationValidationContext* context) {
NN_RET_CHECK(validateInputTypes(
context, {inputType, OperandType::TENSOR_INT32, OperandType::TENSOR_INT32}));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
- return validateVersion(context, minSupportedVersion);
+ return minSupportedVersion;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/Softmax.cpp b/nn/common/operations/Softmax.cpp
index e3c362f8a..3e65d85bf 100644
--- a/nn/common/operations/Softmax.cpp
+++ b/nn/common/operations/Softmax.cpp
@@ -227,7 +227,7 @@ bool softmaxQuant8(const T* inputData, const Shape& inputShape, const float beta
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK(context->getNumInputs() == kNumInputs ||
context->getNumInputs() == kNumInputs - 1);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
@@ -260,7 +260,7 @@ bool validate(const IOperationValidationContext* context) {
}
NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
- return validateVersion(context, minSupportedVersion);
+ return minSupportedVersion;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/Squeeze.cpp b/nn/common/operations/Squeeze.cpp
index e9640b964..2fe8eb8aa 100644
--- a/nn/common/operations/Squeeze.cpp
+++ b/nn/common/operations/Squeeze.cpp
@@ -35,7 +35,7 @@ constexpr uint32_t kSqueezeDims = 1;
constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
OperandType inputType = context->getInputType(kInputTensor);
@@ -63,7 +63,7 @@ bool validate(const IOperationValidationContext* context) {
if (hasKnownRank(input)) {
NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
}
- return validateVersion(context, minSupportedVersion);
+ return minSupportedVersion;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/StridedSlice.cpp b/nn/common/operations/StridedSlice.cpp
index 654659ac1..fd66ca7c4 100644
--- a/nn/common/operations/StridedSlice.cpp
+++ b/nn/common/operations/StridedSlice.cpp
@@ -96,7 +96,7 @@ bool executeTyped(IOperationExecutionContext* context) {
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
OperandType inputType = context->getInputType(kInputTensor);
@@ -129,7 +129,7 @@ bool validate(const IOperationValidationContext* context) {
if (hasKnownRank(input)) {
NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
}
- return validateVersion(context, minSupportedVersion);
+ return minSupportedVersion;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/TopK_V2.cpp b/nn/common/operations/TopK_V2.cpp
index d91c8131e..d19a309b3 100644
--- a/nn/common/operations/TopK_V2.cpp
+++ b/nn/common/operations/TopK_V2.cpp
@@ -73,7 +73,7 @@ bool executeTyped(IOperationExecutionContext* context) {
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
OperandType inputType = context->getInputType(kInputTensor);
@@ -89,7 +89,7 @@ bool validate(const IOperationValidationContext* context) {
if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
minSupportedVersion = Version::ANDROID_R;
}
- return validateVersion(context, minSupportedVersion);
+ return minSupportedVersion;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/Transpose.cpp b/nn/common/operations/Transpose.cpp
index b964c39a8..0e61575eb 100644
--- a/nn/common/operations/Transpose.cpp
+++ b/nn/common/operations/Transpose.cpp
@@ -69,7 +69,7 @@ bool transposeGeneric(const T* inputData, const Shape& inputShape, const int32_t
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
@@ -90,7 +90,7 @@ bool validate(const IOperationValidationContext* context) {
}
NN_RET_CHECK(validateInputTypes(context, {inputType, OperandType::TENSOR_INT32}));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
- return validateVersion(context, minSupportedVersion);
+ return minSupportedVersion;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/TransposeConv2D.cpp b/nn/common/operations/TransposeConv2D.cpp
index 9d6dbbbfc..002df2780 100644
--- a/nn/common/operations/TransposeConv2D.cpp
+++ b/nn/common/operations/TransposeConv2D.cpp
@@ -433,7 +433,7 @@ bool transposeConvQuant8PerChannel(const T* inputData, const Shape& inputShape,
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
const uint32_t inputCount = context->getNumInputs();
NN_RET_CHECK(inputCount == kNumInputs1 || inputCount == kNumInputs2);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
@@ -476,7 +476,7 @@ bool validate(const IOperationValidationContext* context) {
inExpectedTypes.insert(inExpectedTypes.end(), argExpectedTypes.begin(), argExpectedTypes.end());
NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
- return validateVersion(context, minSupportedVersion);
+ return minSupportedVersion;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/UnidirectionalSequenceLSTM.cpp b/nn/common/operations/UnidirectionalSequenceLSTM.cpp
index 02da1581f..dc734e8a4 100644
--- a/nn/common/operations/UnidirectionalSequenceLSTM.cpp
+++ b/nn/common/operations/UnidirectionalSequenceLSTM.cpp
@@ -112,7 +112,7 @@ inline LSTMParams getLSTMParams(IOperationExecutionContext* context) {
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
const uint32_t numOutputs = context->getNumOutputs();
NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsWithState);
@@ -163,7 +163,7 @@ bool validate(const IOperationValidationContext* context) {
}
NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));
- return validateVersion(context, minVersionSupported);
+ return minVersionSupported;
}
bool prepare(IOperationExecutionContext* context) {
diff --git a/nn/common/operations/UnidirectionalSequenceRNN.cpp b/nn/common/operations/UnidirectionalSequenceRNN.cpp
index 382aa58e3..eaf60edd3 100644
--- a/nn/common/operations/UnidirectionalSequenceRNN.cpp
+++ b/nn/common/operations/UnidirectionalSequenceRNN.cpp
@@ -126,15 +126,14 @@ bool executeTyped(IOperationExecutionContext* context) {
} // namespace
-bool validate(const IOperationValidationContext* context) {
+Result<Version> validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
const int numOutputs = context->getNumOutputs();
NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsWithState);
OperandType inputType = context->getInputType(kInputTensor);
if (inputType != OperandType::TENSOR_FLOAT16 && inputType != OperandType::TENSOR_FLOAT32) {
- LOG(ERROR) << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_RNN op: "
- << inputType;
- return false;
+ return NN_ERROR() << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_RNN op: "
+ << inputType;
}
NN_RET_CHECK(validateInputTypes(context, {inputType, inputType, inputType, inputType, inputType,
OperandType::INT32, OperandType::INT32}));
@@ -145,7 +144,7 @@ bool validate(const IOperationValidationContext* context) {
outputTypes.push_back(inputType);
}
NN_RET_CHECK(validateOutputTypes(context, outputTypes));
- return validateVersion(context, minVersionSupported);
+ return minVersionSupported;
}
bool prepare(IOperationExecutionContext* context) {