summaryrefslogtreecommitdiff
path: root/nn/common/operations/Broadcast.cpp
diff options
context:
space:
mode:
authorDavid Gross <dgross@google.com>2019-04-22 12:27:19 -0700
committerDavid Gross <dgross@google.com>2019-04-23 12:12:03 -0700
commitad3e6388f0c4b18550356c1675e99cd20723f7b4 (patch)
tree3d48b31dc3ade033f2799713d2d88a2b07e4b1c6 /nn/common/operations/Broadcast.cpp
parent99c70afc22e051ab0737fc2f2ec1dc51dd5af5a9 (diff)
downloadml-ad3e6388f0c4b18550356c1675e99cd20723f7b4.tar.gz
Compliance/Validation: DIV and SUB are only available at V1_1 and later.
Also make NN_RET_CHECK_FAIL() more verbose. Bug: 130917878 Test: NeuralNetworksTest_static Merged-In: Id70c4e6579da825e69dc0e735735cfa8bc254229 Change-Id: Id70c4e6579da825e69dc0e735735cfa8bc254229 (cherry picked from commit 4d4f8e7a54d0b9f5cff4912826eaa2daff41470f)
Diffstat (limited to 'nn/common/operations/Broadcast.cpp')
-rw-r--r--nn/common/operations/Broadcast.cpp13
1 files changed, 9 insertions, 4 deletions
diff --git a/nn/common/operations/Broadcast.cpp b/nn/common/operations/Broadcast.cpp
index 76b1c4413..a575a7e5a 100644
--- a/nn/common/operations/Broadcast.cpp
+++ b/nn/common/operations/Broadcast.cpp
@@ -26,6 +26,8 @@
#include "Tracing.h"
+#include <algorithm>
+
namespace android {
namespace nn {
namespace broadcast {
@@ -346,20 +348,23 @@ bool divFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, c
} // namespace
bool validate(OperationType opType, const IOperationValidationContext* context) {
+ const HalVersion opIntroducedAt = (opType == OperationType::DIV || opType == OperationType::SUB)
+ ? HalVersion::V1_1
+ : HalVersion::V1_0;
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
auto inputType = context->getInputType(kInputTensor1);
if (inputType == OperandType::TENSOR_FLOAT32) {
- NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
+ NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_0, opIntroducedAt)));
} else if (inputType == OperandType::TENSOR_FLOAT16) {
- NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
+ NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_2, opIntroducedAt)));
} else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
if (opType == OperationType::SUB) {
- NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
+ NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_2, opIntroducedAt)));
} else if (opType == OperationType::DIV) {
NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation DIV";
} else {
- NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
+ NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_0, opIntroducedAt)));
}
} else {
NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << getOperationName(opType);