diff options
Diffstat (limited to 'nn/common/Utils.cpp')
-rw-r--r-- | nn/common/Utils.cpp | 21 |
1 files changed, 10 insertions, 11 deletions
diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp index da4dbc87f..7417ed8bf 100644 --- a/nn/common/Utils.cpp +++ b/nn/common/Utils.cpp @@ -215,18 +215,15 @@ class OperationValidationContext : public IOperationValidationContext { public: OperationValidationContext(const char* operationName, uint32_t inputCount, const uint32_t* inputIndexes, uint32_t outputCount, - const uint32_t* outputIndexes, const Operand* operands, - HalVersion halVersion) + const uint32_t* outputIndexes, const Operand* operands) : operationName(operationName), inputCount(inputCount), inputIndexes(inputIndexes), outputCount(outputCount), outputIndexes(outputIndexes), - operands(operands), - version(convert(halVersion)) {} + operands(operands) {} const char* getOperationName() const override; - Version getVersion() const override; uint32_t getNumInputs() const override; OperandType getInputType(uint32_t index) const override; @@ -254,10 +251,6 @@ const char* OperationValidationContext::getOperationName() const { return operationName; } -Version OperationValidationContext::getVersion() const { - return version; -} - const Operand* OperationValidationContext::getInputOperand(uint32_t index) const { CHECK(index < static_cast<uint32_t>(inputCount)); return &operands[inputIndexes[index]]; @@ -1883,8 +1876,14 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, } OperationValidationContext context(operationRegistration->name, inputCount, inputIndexes, outputCount, outputIndexes, - operands.data(), halVersion); - if (!operationRegistration->validate(&context)) { + operands.data()); + const auto maybeVersion = operationRegistration->validate(&context); + if (!maybeVersion.has_value()) { + LOG(ERROR) << "Validation failed for operation " << opType << ": " + << maybeVersion.error(); + return ANEURALNETWORKS_BAD_DATA; + } + if (!validateVersion(&context, convert(halVersion), maybeVersion.value())) { LOG(ERROR) << "Validation failed for operation " << opType; return ANEURALNETWORKS_BAD_DATA; } |