summaryrefslogtreecommitdiff
path: root/nn
diff options
context:
space:
mode:
authorSlava Shklyaev <slavash@google.com>2020-01-10 13:07:02 +0000
committerAndroid (Google) Code Review <android-gerrit@google.com>2020-01-10 13:07:02 +0000
commit004c7f4a66d4968046ff1b27ef55ba35378ba10b (patch)
tree9cba3323e351db9424357bb33552c9288332db26 /nn
parenta329665d1ad6bb609e1bd1b3c3e1a3efe0cc5b4e (diff)
parent5d9f619775fa6b2badde414aac0ffb2f48614e45 (diff)
downloadml-004c7f4a66d4968046ff1b27ef55ba35378ba10b.tar.gz
Merge "Add TENSOR_INT32 support for SUB"
Diffstat (limited to 'nn')
-rw-r--r--nn/common/operations/Broadcast.cpp34
-rw-r--r--nn/runtime/include/NeuralNetworks.h3
-rw-r--r--nn/runtime/test/TestValidateOperations.cpp4
-rw-r--r--nn/runtime/test/generated/spec_V1_3/sub_int32.example.cpp70
-rw-r--r--nn/runtime/test/specs/V1_3/sub_int32.mod.py23
-rw-r--r--nn/tools/api/types.spec5
6 files changed, 128 insertions, 11 deletions
diff --git a/nn/common/operations/Broadcast.cpp b/nn/common/operations/Broadcast.cpp
index 99bb18495..56069a852 100644
--- a/nn/common/operations/Broadcast.cpp
+++ b/nn/common/operations/Broadcast.cpp
@@ -205,8 +205,9 @@ bool addQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& sha
return true;
}
-bool addInt32(const int32_t* aData, const Shape& aShape, const int32_t* bData, const Shape& bShape,
- int32_t activation, int32_t* outputData, const Shape& outputShape) {
+bool executeInt32(const int32_t* aData, const Shape& aShape, const int32_t* bData,
+ const Shape& bShape, int32_t activation, int32_t* outputData,
+ const Shape& outputShape, int32_t func(int32_t, int32_t)) {
NN_RET_CHECK_EQ(activation, ANEURALNETWORKS_FUSED_NONE);
IndexedShapeWrapper aShapeIndexed(aShape);
IndexedShapeWrapper bShapeIndexed(bShape);
@@ -221,7 +222,7 @@ bool addInt32(const int32_t* aData, const Shape& aShape, const int32_t* bData, c
uint32_t bFlatIndex;
NN_RET_CHECK(bShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &bFlatIndex));
- outputData[outputFlatIndex] = aData[aFlatIndex] + bData[bFlatIndex];
+ outputData[outputFlatIndex] = func(aData[aFlatIndex], bData[bFlatIndex]);
NN_RET_CHECK(outputShapeIndexed.nextIndexInplace(&curIndex, &lastIndex));
} while (!lastIndex);
@@ -461,7 +462,8 @@ bool validate(OperationType opType, const IOperationValidationContext* context)
}
} else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_3, opIntroducedAt)));
- } else if (inputType == OperandType::TENSOR_INT32 && opType == OperationType::ADD) {
+ } else if (inputType == OperandType::TENSOR_INT32 &&
+ (opType == OperationType::ADD || opType == OperationType::SUB)) {
NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_3));
} else {
NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << getOperationName(opType);
@@ -517,13 +519,14 @@ bool executeAdd(IOperationExecutionContext* context) {
context->getOutputBuffer<int8_t>(kOutputTensor),
context->getOutputShape(kOutputTensor));
case OperandType::TENSOR_INT32:
- return addInt32(context->getInputBuffer<int32_t>(kInputTensor1),
- context->getInputShape(kInputTensor1),
- context->getInputBuffer<int32_t>(kInputTensor2),
- context->getInputShape(kInputTensor2),
- context->getInputValue<int32_t>(kActivationScalar),
- context->getOutputBuffer<int32_t>(kOutputTensor),
- context->getOutputShape(kOutputTensor));
+ return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
+ context->getInputShape(kInputTensor1),
+ context->getInputBuffer<int32_t>(kInputTensor2),
+ context->getInputShape(kInputTensor2),
+ context->getInputValue<int32_t>(kActivationScalar),
+ context->getOutputBuffer<int32_t>(kOutputTensor),
+ context->getOutputShape(kOutputTensor),
+ [](int32_t a, int32_t b) { return a + b; });
default:
NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ADD";
}
@@ -606,6 +609,15 @@ bool executeSub(IOperationExecutionContext* context) {
context->getInputValue<int32_t>(kActivationScalar),
context->getOutputBuffer<int8_t>(kOutputTensor),
context->getOutputShape(kOutputTensor));
+ case OperandType::TENSOR_INT32:
+ return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
+ context->getInputShape(kInputTensor1),
+ context->getInputBuffer<int32_t>(kInputTensor2),
+ context->getInputShape(kInputTensor2),
+ context->getInputValue<int32_t>(kActivationScalar),
+ context->getOutputBuffer<int32_t>(kOutputTensor),
+ context->getOutputShape(kOutputTensor),
+ [](int32_t a, int32_t b) { return a - b; });
default:
NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation SUB";
}
diff --git a/nn/runtime/include/NeuralNetworks.h b/nn/runtime/include/NeuralNetworks.h
index b2f72c47d..505a19bfb 100644
--- a/nn/runtime/include/NeuralNetworks.h
+++ b/nn/runtime/include/NeuralNetworks.h
@@ -2379,6 +2379,7 @@ typedef enum {
* * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
* * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} (since API level 29)
* * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED} (since API level 30)
+ * * {@link ANEURALNETWORKS_TENSOR_INT32} (since API level 30)
*
* Supported tensor rank: up to 4
*
@@ -2389,6 +2390,8 @@ typedef enum {
* * 2: An {@link ANEURALNETWORKS_INT32} scalar, and has to be one of the
* {@link FuseCode} values. Specifies the activation to
* invoke on the result.
+ * For a {@link ANEURALNETWORKS_TENSOR_INT32} tensor,
+ * the {@link FuseCode} must be "NONE".
*
* Outputs:
* * 0: A tensor of the same {@link OperandCode} as input0.
diff --git a/nn/runtime/test/TestValidateOperations.cpp b/nn/runtime/test/TestValidateOperations.cpp
index 4f17c3d31..d42323c94 100644
--- a/nn/runtime/test/TestValidateOperations.cpp
+++ b/nn/runtime/test/TestValidateOperations.cpp
@@ -682,6 +682,10 @@ TEST(OperationValidationTest, SUB_quant8_signed) {
simpleMathOpTest(ANEURALNETWORKS_SUB, ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED);
}
+TEST(OperationValidationTest, SUB_int32) {
+ simpleMathOpTest(ANEURALNETWORKS_SUB, ANEURALNETWORKS_TENSOR_INT32);
+}
+
TEST(OperationValidationTest, DIV_float16) {
simpleMathOpTest(ANEURALNETWORKS_DIV, ANEURALNETWORKS_TENSOR_FLOAT16);
}
diff --git a/nn/runtime/test/generated/spec_V1_3/sub_int32.example.cpp b/nn/runtime/test/generated/spec_V1_3/sub_int32.example.cpp
new file mode 100644
index 000000000..baa8741be
--- /dev/null
+++ b/nn/runtime/test/generated/spec_V1_3/sub_int32.example.cpp
@@ -0,0 +1,70 @@
+// Generated from sub_int32.mod.py
+// DO NOT EDIT
+// clang-format off
+#include "TestHarness.h"
+using namespace test_helper;
+
+namespace generated_tests::sub_int32 {
+
+const TestModel& get_test_model() {
+ static TestModel model = {
+ .expectFailure = false,
+ .expectedMultinomialDistributionTolerance = 0,
+ .inputIndexes = {0, 1},
+ .isRelaxed = false,
+ .minSupportedVersion = TestHalVersion::V1_3,
+ .operands = {{
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int32_t>({2, -4, 8, -16}),
+ .dimensions = {1, 2, 2},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 0.0f,
+ .type = TestOperandType::TENSOR_INT32,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int32_t>({2, -2, -4, 4}),
+ .dimensions = {1, 2, 2},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_INPUT,
+ .numberOfConsumers = 1,
+ .scale = 0.0f,
+ .type = TestOperandType::TENSOR_INT32,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int32_t>({0}),
+ .dimensions = {},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::CONSTANT_COPY,
+ .numberOfConsumers = 1,
+ .scale = 0.0f,
+ .type = TestOperandType::INT32,
+ .zeroPoint = 0
+ }, {
+ .channelQuant = {},
+ .data = TestBuffer::createFromVector<int32_t>({0, -2, 12, -20}),
+ .dimensions = {1, 2, 2},
+ .isIgnored = false,
+ .lifetime = TestOperandLifeTime::MODEL_OUTPUT,
+ .numberOfConsumers = 0,
+ .scale = 0.0f,
+ .type = TestOperandType::TENSOR_INT32,
+ .zeroPoint = 0
+ }},
+ .operations = {{
+ .inputs = {0, 1, 2},
+ .outputs = {3},
+ .type = TestOperationType::SUB
+ }},
+ .outputIndexes = {3}
+ };
+ return model;
+}
+
+const auto dummy_test_model = TestModelManager::get().add("sub_int32", get_test_model());
+
+} // namespace generated_tests::sub_int32
+
diff --git a/nn/runtime/test/specs/V1_3/sub_int32.mod.py b/nn/runtime/test/specs/V1_3/sub_int32.mod.py
new file mode 100644
index 000000000..82fe66f81
--- /dev/null
+++ b/nn/runtime/test/specs/V1_3/sub_int32.mod.py
@@ -0,0 +1,23 @@
+#
+# Copyright (C) 2020 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+model = Model()
+input0 = Input("input0", "TENSOR_INT32", "{1, 2, 2}")
+input1 = Input("input1", "TENSOR_INT32", "{1, 2, 2}")
+output = Output("output", "TENSOR_INT32", "{1, 2, 2}")
+model = model.Operation("SUB", input0, input1, 0).To(output)
+Example({input0: [2, -4, 8, -16],
+ input1: [2, -2, -4, 4],
+ output: [0, -2, 12, -20]})
diff --git a/nn/tools/api/types.spec b/nn/tools/api/types.spec
index c6650d959..095f491ed 100644
--- a/nn/tools/api/types.spec
+++ b/nn/tools/api/types.spec
@@ -2728,6 +2728,7 @@
%/kind
%kind ndk hal_1.3+
* * {@link %{OperandTypeLinkPfx}TENSOR_QUANT8_ASYMM_SIGNED} (since %{APILevel30})
+ * * {@link %{OperandTypeLinkPfx}TENSOR_INT32} (since %{APILevel30})
%/kind
*
* Supported tensor rank: up to 4
@@ -2739,6 +2740,10 @@
* * 2: An {@link %{OperandTypeLinkPfx}INT32} scalar, and has to be one of the
* {@link %{FusedActivationFunc}} values. Specifies the activation to
* invoke on the result.
+%kind ndk hal_1.3+
+ * For a {@link %{OperandTypeLinkPfx}TENSOR_INT32} tensor,
+ * the {@link %{FusedActivationFunc}} must be "NONE".
+%/kind
*
* Outputs:
* * 0: A tensor of the same {@link %{OperandType}} as input0.