diff options
author | David Gross <dgross@google.com> | 2019-04-22 12:27:19 -0700 |
---|---|---|
committer | David Gross <dgross@google.com> | 2019-04-23 12:12:03 -0700 |
commit | ad3e6388f0c4b18550356c1675e99cd20723f7b4 (patch) | |
tree | 3d48b31dc3ade033f2799713d2d88a2b07e4b1c6 /nn/common/operations/Broadcast.cpp | |
parent | 99c70afc22e051ab0737fc2f2ec1dc51dd5af5a9 (diff) | |
download | ml-ad3e6388f0c4b18550356c1675e99cd20723f7b4.tar.gz |
Compliance/Validation: DIV and SUB are only available at V1_1 and later.
Also make NN_RET_CHECK_FAIL() more verbose.
Bug: 130917878
Test: NeuralNetworksTest_static
Merged-In: Id70c4e6579da825e69dc0e735735cfa8bc254229
Change-Id: Id70c4e6579da825e69dc0e735735cfa8bc254229
(cherry picked from commit 4d4f8e7a54d0b9f5cff4912826eaa2daff41470f)
Diffstat (limited to 'nn/common/operations/Broadcast.cpp')
-rw-r--r-- | nn/common/operations/Broadcast.cpp | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/nn/common/operations/Broadcast.cpp b/nn/common/operations/Broadcast.cpp index 76b1c4413..a575a7e5a 100644 --- a/nn/common/operations/Broadcast.cpp +++ b/nn/common/operations/Broadcast.cpp @@ -26,6 +26,8 @@ #include "Tracing.h" +#include <algorithm> + namespace android { namespace nn { namespace broadcast { @@ -346,20 +348,23 @@ bool divFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, c } // namespace bool validate(OperationType opType, const IOperationValidationContext* context) { + const HalVersion opIntroducedAt = (opType == OperationType::DIV || opType == OperationType::SUB) + ? HalVersion::V1_1 + : HalVersion::V1_0; NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); auto inputType = context->getInputType(kInputTensor1); if (inputType == OperandType::TENSOR_FLOAT32) { - NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0)); + NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_0, opIntroducedAt))); } else if (inputType == OperandType::TENSOR_FLOAT16) { - NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2)); + NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_2, opIntroducedAt))); } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { if (opType == OperationType::SUB) { - NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2)); + NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_2, opIntroducedAt))); } else if (opType == OperationType::DIV) { NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation DIV"; } else { - NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0)); + NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_0, opIntroducedAt))); } } else { NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << getOperationName(opType); |