summaryrefslogtreecommitdiff
path: root/nn/common/Utils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'nn/common/Utils.cpp')
-rw-r--r--nn/common/Utils.cpp21
1 files changed, 10 insertions, 11 deletions
diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp
index da4dbc87f..7417ed8bf 100644
--- a/nn/common/Utils.cpp
+++ b/nn/common/Utils.cpp
@@ -215,18 +215,15 @@ class OperationValidationContext : public IOperationValidationContext {
public:
OperationValidationContext(const char* operationName, uint32_t inputCount,
const uint32_t* inputIndexes, uint32_t outputCount,
- const uint32_t* outputIndexes, const Operand* operands,
- HalVersion halVersion)
+ const uint32_t* outputIndexes, const Operand* operands)
: operationName(operationName),
inputCount(inputCount),
inputIndexes(inputIndexes),
outputCount(outputCount),
outputIndexes(outputIndexes),
- operands(operands),
- version(convert(halVersion)) {}
+ operands(operands) {}
const char* getOperationName() const override;
- Version getVersion() const override;
uint32_t getNumInputs() const override;
OperandType getInputType(uint32_t index) const override;
@@ -254,10 +251,6 @@ const char* OperationValidationContext::getOperationName() const {
return operationName;
}
-Version OperationValidationContext::getVersion() const {
- return version;
-}
-
const Operand* OperationValidationContext::getInputOperand(uint32_t index) const {
CHECK(index < static_cast<uint32_t>(inputCount));
return &operands[inputIndexes[index]];
@@ -1883,8 +1876,14 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
}
OperationValidationContext context(operationRegistration->name, inputCount,
inputIndexes, outputCount, outputIndexes,
- operands.data(), halVersion);
- if (!operationRegistration->validate(&context)) {
+ operands.data());
+ const auto maybeVersion = operationRegistration->validate(&context);
+ if (!maybeVersion.has_value()) {
+ LOG(ERROR) << "Validation failed for operation " << opType << ": "
+ << maybeVersion.error();
+ return ANEURALNETWORKS_BAD_DATA;
+ }
+ if (!validateVersion(&context, convert(halVersion), maybeVersion.value())) {
LOG(ERROR) << "Validation failed for operation " << opType;
return ANEURALNETWORKS_BAD_DATA;
}