author    Android Build Coastguard Worker <android-build-coastguard-worker@google.com>  2023-07-07 05:16:22 +0000
committer Android Build Coastguard Worker <android-build-coastguard-worker@google.com>  2023-07-07 05:16:22 +0000
commit    a04077a9ddb51e6bf49f5426608940ea8826c5f0 (patch)
tree      bc2d1560d1cf9dfb9d8c113d6578dec84a456656
parent    ebd8ac5829946386185531c3ddac1f87a9f3909a (diff)
parent    2644c2c0081562f5ee815b06e2bda4ec5f280a7d (diff)
Snap for 10453563 from 2644c2c0081562f5ee815b06e2bda4ec5f280a7d to mainline-sdkext-release
Refs: aml_sdk_341510000, aml_sdk_341410000, aml_sdk_341110080, aml_sdk_341110000, aml_sdk_341010000, aml_sdk_340912010, android14-mainline-sdkext-release
Change-Id: Idc71f7cca16ef3ee451f4657810ec1427b4ca099
-rw-r--r--  Android.bp                                                                 |    3
-rw-r--r--  tensorflow_lite_support/custom_ops/kernel/BUILD                            |   80
-rw-r--r--  tensorflow_lite_support/custom_ops/kernel/unsorted_segment.cc              |  295
-rw-r--r--  tensorflow_lite_support/custom_ops/kernel/unsorted_segment.h               |   31
-rw-r--r--  tensorflow_lite_support/custom_ops/kernel/unsorted_segment_max_test.cc     |  137
-rw-r--r--  tensorflow_lite_support/custom_ops/kernel/unsorted_segment_min_test.cc     |  137
-rw-r--r--  tensorflow_lite_support/custom_ops/kernel/unsorted_segment_prod_test.cc    |  122
-rw-r--r--  tensorflow_lite_support/custom_ops/kernel/unsorted_segment_sum_test.cc     |  145
-rw-r--r--  tensorflow_lite_support/custom_ops/kernel/unsorted_segment_test.cc         |  130
-rw-r--r--  tensorflow_lite_support/custom_ops/kernel/unsorted_segment_test.h          |   97
-rw-r--r--  tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java |    4
-rw-r--r--  tensorflow_lite_support/java/src/javatests/testdata/task/text/rb_model.tflite    | bin 5123808 -> 0 bytes
-rw-r--r--  tensorflow_lite_support/java/src/javatests/testdata/task/text/rb_v4_model.tflite | bin 0 -> 5107316 bytes
-rw-r--r--  tensorflow_lite_support/java/src/native/task/core/BUILD                    |   12
-rw-r--r--  tensorflow_lite_support/java/src/native/task/core/rbtml_op_resolver.cc     |   40
-rw-r--r--  third_party/zlib/Android.bp                                                |   32
16 files changed, 1261 insertions, 4 deletions
diff --git a/Android.bp b/Android.bp
index f3d5bb4e..2545e4a2 100644
--- a/Android.bp
+++ b/Android.bp
@@ -222,7 +222,8 @@ cc_library_shared {
"tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc",
"tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc",
"tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc",
- "tensorflow_lite_support/java/src/native/task/core/builtin_op_resolver.cc",
+ "tensorflow_lite_support/java/src/native/task/core/rbtml_op_resolver.cc",
+ "tensorflow_lite_support/custom_ops/kernel/unsorted_segment.cc",
"tensorflow_lite_support/cc/utils/jni_utils.cc",
],
version_script: "tensorflow_lite_support/java/tflite_version_script.lds",
diff --git a/tensorflow_lite_support/custom_ops/kernel/BUILD b/tensorflow_lite_support/custom_ops/kernel/BUILD
index b9b11de9..a55dcb95 100644
--- a/tensorflow_lite_support/custom_ops/kernel/BUILD
+++ b/tensorflow_lite_support/custom_ops/kernel/BUILD
@@ -144,3 +144,83 @@ py_test(
"@absl_py//absl/testing:parameterized",
],
)
+
+cc_library(
+ name = "unsorted_segment",
+ srcs = ["unsorted_segment.cc"],
+ hdrs = ["unsorted_segment.h"],
+ deps = [
+ "@org_tensorflow//tensorflow/lite/c:common",
+ "@org_tensorflow//tensorflow/lite/kernels:kernel_util",
+ "@org_tensorflow//tensorflow/lite/kernels/internal:reference_base",
+ "@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
+ ],
+)
+
+cc_test(
+ name = "unsorted_segment_prod_test",
+ size = "small",
+ srcs = [
+ "unsorted_segment_prod_test.cc",
+ "unsorted_segment_test.cc",
+ "unsorted_segment_test.h",
+ ],
+ deps = [
+ "@com_google_googletest//:gtest_main",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ "@org_tensorflow//tensorflow/lite/kernels:test_main",
+ "@org_tensorflow//tensorflow/lite/kernels:test_util",
+ "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
+ ],
+)
+
+cc_test(
+ name = "unsorted_segment_max_test",
+ size = "small",
+ srcs = [
+ "unsorted_segment_max_test.cc",
+ "unsorted_segment_test.cc",
+ "unsorted_segment_test.h",
+ ],
+ deps = [
+ "@com_google_googletest//:gtest_main",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ "@org_tensorflow//tensorflow/lite/kernels:test_main",
+ "@org_tensorflow//tensorflow/lite/kernels:test_util",
+ "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
+ ],
+)
+
+cc_test(
+ name = "unsorted_segment_sum_test",
+ size = "small",
+ srcs = [
+ "unsorted_segment_sum_test.cc",
+ "unsorted_segment_test.cc",
+ "unsorted_segment_test.h",
+ ],
+ deps = [
+ "@com_google_googletest//:gtest_main",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ "@org_tensorflow//tensorflow/lite/kernels:test_main",
+ "@org_tensorflow//tensorflow/lite/kernels:test_util",
+ "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
+ ],
+)
+
+cc_test(
+ name = "unsorted_segment_min_test",
+ size = "small",
+ srcs = [
+ "unsorted_segment_sum_test.cc",
+ "unsorted_segment_test.cc",
+ "unsorted_segment_test.h",
+ ],
+ deps = [
+ "@com_google_googletest//:gtest_main",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ "@org_tensorflow//tensorflow/lite/kernels:test_main",
+ "@org_tensorflow//tensorflow/lite/kernels:test_util",
+ "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
+ ],
+)
diff --git a/tensorflow_lite_support/custom_ops/kernel/unsorted_segment.cc b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment.cc
new file mode 100644
index 00000000..15972ed6
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment.cc
@@ -0,0 +1,295 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stdint.h>
+
+#include <algorithm>
+#include <functional>
+
+#include "tensorflow_lite_support/custom_ops/kernel/unsorted_segment.h"
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace unsorted_segment {
+
+enum SegmentType {
+ kSegmentMax,
+ kSegmentMin,
+ kSegmentProd,
+ kSegmentSum,
+};
+
+static const int kInputDataTensor = 0;
+static const int kInputSegmentIdsTensor = 1;
+static const int kInputNumSegmentsTensor = 2;
+static const int kOutputTensor = 0;
+
+inline bool IsConstantOrPersistentTensor(const TfLiteTensor* tensor) {
+ return tflite::IsConstantTensor(tensor) ||
+ (tensor->allocation_type == kTfLitePersistentRo);
+}
+
+template <typename T, template <typename T2> typename Op>
+void UnsortedSegmentRef(const tflite::RuntimeShape& input_shape,
+ const T* input_data,
+ const tflite::RuntimeShape& segment_ids_shape,
+ const int32_t* segment_ids_data,
+ const tflite::RuntimeShape& output_shape,
+ T* output_data) {
+ for (int i = 0; i < output_shape.FlatSize(); ++i) {
+ output_data[i] = Op<T>::kInitialValue;
+ }
+ Op<T> op;
+ int segment_flat_size = 1;
+ for (int i = 1; i < output_shape.DimensionsCount(); ++i) {
+ segment_flat_size *= output_shape.Dims(i);
+ }
+ for (int i = 0; i < segment_ids_shape.FlatSize(); i++) {
+ int output_index = segment_ids_data[i];
+ if (output_index < 0) continue;
+ for (int j = 0; j < segment_flat_size; ++j) {
+ output_data[output_index * segment_flat_size + j] =
+ op(output_data[output_index * segment_flat_size + j],
+ input_data[i * segment_flat_size + j]);
+ }
+ }
+}
+
+TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
+ const TfLiteTensor* data,
+ const TfLiteTensor* segment_ids,
+ const TfLiteTensor* num_segments,
+ TfLiteTensor* output) {
+  // The shape of segment_ids is permitted to be any non-empty prefix of
+  // the input data's shape. The first dimension of the output is always
+  // equal to num_segments; the remaining output dimensions are the suffix of
+  // the input shape starting at position rank(segment_ids). Public-facing
+  // TensorFlow documentation erroneously describes the unsorted_segment ops
+  // as only supporting segment_ids of rank 1; however, the TensorFlow
+  // implementation supports higher-dimensional segment_ids as described here.
+ const int segment_ids_rank = tflite::NumDimensions(segment_ids);
+ const int data_rank = tflite::NumDimensions(data);
+ TF_LITE_ENSURE(context, segment_ids_rank <= data_rank);
+ for (int i = 0; i < segment_ids_rank; ++i) {
+ // segment_ids shape must be prefix of data shape.
+ TF_LITE_ENSURE_EQ(context, segment_ids->dims->data[i], data->dims->data[i]);
+ }
+ TF_LITE_ENSURE(context, (num_segments->dims->size == 1 &&
+ num_segments->dims->data[0] == 1) ||
+ num_segments->dims->size == 0);
+  // num_segments can be thought of as the number of buckets (segments) in the
+  // output, where each segment is the reduction of all elements mapped to
+  // that segment id. The shape of each such element is the corresponding
+  // suffix of the data shape.
+ int32_t num_segments_ = tflite::GetTensorData<int32_t>(num_segments)[0];
+ const int num_segment_ids = tflite::NumElements(segment_ids);
+ int max_index = -1;
+ for (int i = 0; i < num_segment_ids; i++) {
+ max_index = std::max(tflite::GetTensorData<int32_t>(segment_ids)[i], max_index);
+ }
+  // num_segments_ must be greater than max_index, otherwise elements would be
+  // mapped to non-existent output segments.
+ TF_LITE_ENSURE(context, max_index < num_segments_);
+ const int output_rank = data_rank - segment_ids_rank + 1;
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
+ output_shape->data[0] = num_segments_;
+ // output_shape[1:] should be data_shape[Rank(segment_ids):]
+ for (int i = segment_ids_rank; i < data_rank; ++i) {
+ output_shape->data[i - segment_ids_rank + 1] = data->dims->data[i];
+ }
+ return context->ResizeTensor(context, output, output_shape);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 3);
+ TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1);
+ const TfLiteTensor* data;
+ TF_LITE_ENSURE_OK(context,
+ tflite::GetInputSafe(context, node, kInputDataTensor, &data));
+ const TfLiteTensor* segment_ids;
+ TF_LITE_ENSURE_OK(context, tflite::GetInputSafe(context, node, kInputSegmentIdsTensor,
+ &segment_ids));
+ const TfLiteTensor* num_segments;
+ TF_LITE_ENSURE_OK(
+ context,
+ tflite::GetInputSafe(context, node, kInputNumSegmentsTensor, &num_segments));
+ TfLiteTensor* output;
+ TF_LITE_ENSURE_OK(context,
+ tflite::GetOutputSafe(context, node, kOutputTensor, &output));
+ TF_LITE_ENSURE(context,
+ data->type == kTfLiteInt32 || data->type == kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, segment_ids->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, num_segments->type, kTfLiteInt32);
+
+ if (tflite::IsDynamicTensor(data) || !IsConstantOrPersistentTensor(segment_ids) ||
+ !IsConstantOrPersistentTensor(num_segments)) {
+ tflite::SetTensorToDynamic(output);
+ return kTfLiteOk;
+ }
+ return ResizeOutputTensor(context, data, segment_ids, num_segments, output);
+}
+
+template <typename T>
+struct SegmenMax {
+ inline T operator()(const T& a, const T& b) const { return std::max(a, b); }
+ static constexpr T kInitialValue = std::numeric_limits<T>::lowest();
+};
+
+template <typename T>
+struct SegmenMin {
+ inline T operator()(const T& a, const T& b) const { return std::min(a, b); }
+ static constexpr T kInitialValue = std::numeric_limits<T>::max();
+};
+
+template <typename T>
+struct SegmenProd {
+ inline T operator()(const T& a, const T& b) const { return a * b; }
+ static constexpr T kInitialValue = T(1);
+};
+
+template <typename T>
+struct SegmenSum {
+ inline T operator()(const T& a, const T& b) const { return a + b; }
+ static constexpr T kInitialValue = T(0);
+};
+
+template <typename T>
+TfLiteStatus EvalType(TfLiteContext* context, const tflite::RuntimeShape& input_shape,
+ const T* input_data,
+ const tflite::RuntimeShape& segment_ids_shape,
+ const int32_t* segment_ids_data,
+ const tflite::RuntimeShape& output_shape, T* output_data,
+ SegmentType segment_type) {
+ switch (segment_type) {
+ case kSegmentProd:
+ unsorted_segment::UnsortedSegmentRef<T, SegmenProd>(
+ input_shape, input_data, segment_ids_shape, segment_ids_data,
+ output_shape, output_data);
+ break;
+ case kSegmentMax:
+ unsorted_segment::UnsortedSegmentRef<T, SegmenMax>(
+ input_shape, input_data, segment_ids_shape, segment_ids_data,
+ output_shape, output_data);
+ break;
+ case kSegmentSum:
+ unsorted_segment::UnsortedSegmentRef<T, SegmenSum>(
+ input_shape, input_data, segment_ids_shape, segment_ids_data,
+ output_shape, output_data);
+ break;
+ case kSegmentMin:
+ unsorted_segment::UnsortedSegmentRef<T, SegmenMin>(
+ input_shape, input_data, segment_ids_shape, segment_ids_data,
+ output_shape, output_data);
+ break;
+ default:
+ TF_LITE_KERNEL_LOG(context, "Not recognized segment type: %d",
+ segment_type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalGeneric(TfLiteContext* context, TfLiteNode* node,
+ SegmentType segment_type) {
+ const TfLiteTensor* data;
+ TF_LITE_ENSURE_OK(context,
+ tflite::GetInputSafe(context, node, kInputDataTensor, &data));
+ const TfLiteTensor* segment_ids;
+ TF_LITE_ENSURE_OK(context, tflite::GetInputSafe(context, node, kInputSegmentIdsTensor,
+ &segment_ids));
+ const TfLiteTensor* num_segments;
+ TF_LITE_ENSURE_OK(
+ context,
+ tflite::GetInputSafe(context, node, kInputNumSegmentsTensor, &num_segments));
+ TfLiteTensor* output;
+ TF_LITE_ENSURE_OK(context,
+ tflite::GetOutputSafe(context, node, kOutputTensor, &output));
+
+ if (tflite::IsDynamicTensor(output)) {
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, data, segment_ids,
+ num_segments, output));
+ }
+ TF_LITE_ENSURE_EQ(context, tflite::GetTensorShape(data).Dims(0),
+ tflite::GetTensorShape(segment_ids).Dims(0));
+
+#define TF_LITE_UNSORTED_SEGMENT(dtype) \
+ EvalType<dtype>(context, tflite::GetTensorShape(data), tflite::GetTensorData<dtype>(data), \
+ tflite::GetTensorShape(segment_ids), \
+ tflite::GetTensorData<int32_t>(segment_ids), tflite::GetTensorShape(output), \
+ tflite::GetTensorData<dtype>(output), segment_type);
+ switch (data->type) {
+ case kTfLiteInt32:
+ TF_LITE_UNSORTED_SEGMENT(int32_t);
+ break;
+ case kTfLiteFloat32:
+ TF_LITE_UNSORTED_SEGMENT(float);
+ break;
+ default:
+ TF_LITE_KERNEL_LOG(
+ context, "Currently UnsortedSegment doesn't support data type: %s",
+ TfLiteTypeGetName(data->type));
+ return kTfLiteError;
+ }
+#undef TF_LITE_UNSORTED_SEGMENT
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalProd(TfLiteContext* context, TfLiteNode* node) {
+ return EvalGeneric(context, node, kSegmentProd);
+}
+TfLiteStatus EvalMax(TfLiteContext* context, TfLiteNode* node) {
+ return EvalGeneric(context, node, kSegmentMax);
+}
+TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
+ return EvalGeneric(context, node, kSegmentSum);
+}
+TfLiteStatus EvalMin(TfLiteContext* context, TfLiteNode* node) {
+ return EvalGeneric(context, node, kSegmentMin);
+}
+
+} // namespace unsorted_segment
+
+TfLiteRegistration* Register_UNSORTED_SEGMENT_PROD() {
+ static TfLiteRegistration r = {nullptr, nullptr, unsorted_segment::Prepare,
+ unsorted_segment::EvalProd};
+ return &r;
+}
+
+TfLiteRegistration* Register_UNSORTED_SEGMENT_MAX() {
+ static TfLiteRegistration r = {nullptr, nullptr, unsorted_segment::Prepare,
+ unsorted_segment::EvalMax};
+ return &r;
+}
+
+TfLiteRegistration* Register_UNSORTED_SEGMENT_SUM() {
+ static TfLiteRegistration r = {nullptr, nullptr, unsorted_segment::Prepare,
+ unsorted_segment::EvalSum};
+ return &r;
+}
+
+TfLiteRegistration* Register_UNSORTED_SEGMENT_MIN() {
+ static TfLiteRegistration r = {nullptr, nullptr, unsorted_segment::Prepare,
+ unsorted_segment::EvalMin};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
\ No newline at end of file
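The shape and reduction rules described in the ResizeOutputTensor comments above are easiest to see on a small worked example. The following standalone sketch is not part of the patch and uses plain arrays instead of TfLite tensors; under those assumptions it reproduces the accumulation loop of UnsortedSegmentRef for a sum reduction, with data of shape [3, 2], segment_ids of shape [3] (a prefix of the data shape), and num_segments = 2, so the output has shape [2, 2].

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // data has shape [3, 2]; segment_ids has shape [3]; num_segments is 2.
  const std::vector<int32_t> data = {1, 2, 3, 4, 5, 6};
  const std::vector<int32_t> segment_ids = {0, 1, 0};
  const int num_segments = 2;
  const int inner = 2;  // product of the data dims not covered by segment_ids
  // Output shape is [num_segments, inner]; the initial value for a sum is 0.
  std::vector<int32_t> output(num_segments * inner, 0);
  for (int i = 0; i < static_cast<int>(segment_ids.size()); ++i) {
    const int out = segment_ids[i];
    if (out < 0) continue;  // negative segment ids are skipped, as in the kernel
    for (int j = 0; j < inner; ++j) {
      output[out * inner + j] += data[i * inner + j];
    }
  }
  for (int32_t v : output) std::cout << v << ' ';  // prints: 6 8 3 4
  std::cout << '\n';
  return 0;
}

Rows 0 and 2 of the input map to segment 0 and are reduced element-wise, while row 1 maps to segment 1 unchanged; the max, min, and prod variants differ only in the operator applied and in kInitialValue, as the Op structs in the kernel above show.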
diff --git a/tensorflow_lite_support/custom_ops/kernel/unsorted_segment.h b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment.h
new file mode 100644
index 00000000..fb8fd798
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment.h
@@ -0,0 +1,31 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_UNSORTED_SEGMENT_H_
+#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_UNSORTED_SEGMENT_H_
+
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_UNSORTED_SEGMENT_PROD();
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+#endif  // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_UNSORTED_SEGMENT_H_
\ No newline at end of file
diff --git a/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_max_test.cc b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_max_test.cc
new file mode 100644
index 00000000..10baefc0
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_max_test.cc
@@ -0,0 +1,137 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <limits.h>
+#include <stdint.h>
+
+#include <initializer_list>
+#include <vector>
+
+#include "testing/base/public/gunit.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/kernels/unsorted_segment_test.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+template <typename T>
+class UnsortedSegmentMaxModel : public UnsortedSegmentModel<T> {
+ public:
+ UnsortedSegmentMaxModel(const TensorData& data, const TensorData& segment_ids,
+ const TensorData& num_segments)
+ : UnsortedSegmentModel<T>(data, segment_ids, num_segments,
+ BuiltinOperator_UNSORTED_SEGMENT_MAX,
+ BuiltinOptions_NONE) {}
+
+ explicit UnsortedSegmentMaxModel(
+ const TensorData& data, const std::initializer_list<int>& segment_id_data,
+ const std::initializer_list<int>& segment_id_shape,
+ const std::initializer_list<int>& num_segments_data,
+ const std::initializer_list<int>& num_segments_shape)
+ : UnsortedSegmentModel<T>(data, segment_id_data, segment_id_shape,
+ num_segments_data, num_segments_shape,
+ BuiltinOperator_UNSORTED_SEGMENT_MAX,
+ BuiltinOptions_NONE) {}
+};
+
+INSTANTIATE_TEST_SUITE_P(UnsortedSegmentMaxTestP, UnsortedSegmentTest,
+ testing::Values(BuiltinOperator_UNSORTED_SEGMENT_MAX));
+
+TEST(UnsortedSegmentMaxModelTest, Int32Test_Simple) {
+ UnsortedSegmentMaxModel<int32_t> model({TensorType_INT32, {6}},
+ {TensorType_INT32, {6}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<int32_t>(model.data(), {5, 1, 7, 2, 3, 4});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {0, 0, 1, 1, 0, 1});
+ model.PopulateTensor<int32_t>(model.num_segments(), {2});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 7}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
+}
+
+TEST(UnsortedSegmentMaxModelTest, Int32Test_Simple2D) {
+ UnsortedSegmentMaxModel<int32_t> model({TensorType_INT32, {3, 4}},
+ {TensorType_INT32, {3}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<int32_t>(model.data(),
+ {1, 2, 3, 4, 5, 6, 7, 8, 4, 3, 2, 1});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1, 0});
+ model.PopulateTensor<int32_t>(model.num_segments(), {2});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 3, 3, 4, 5, 6, 7, 8}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4}));
+}
+
+TEST(UnsortedSegmentMaxModelTest, FloatTest_Simple) {
+ UnsortedSegmentMaxModel<float> model({TensorType_FLOAT32, {8}},
+ {TensorType_INT32, {8}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<float>(model.data(),
+ {1.0, 2.0, 3.0, 4.0, 4.0, 3.0, 2.0, 1.0});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {1, 0, 1, 7, 7, 7, 7, 7});
+ model.PopulateTensor<int32_t>(model.num_segments(), {8});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {2.0, 3.0, std::numeric_limits<float>::lowest(),
+ std::numeric_limits<float>::lowest(),
+ std::numeric_limits<float>::lowest(),
+ std::numeric_limits<float>::lowest(),
+ std::numeric_limits<float>::lowest(), 4.0})));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({8}));
+}
+
+TEST(UnsortedSegmentMaxModelTest, FloatTest_Simple2D) {
+ UnsortedSegmentMaxModel<float> model({TensorType_FLOAT32, {3, 4}},
+ {TensorType_INT32, {3}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<float>(model.data(), {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
+ 8.0, 4.0, 3.0, 2.0, 1.0});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1, 0});
+ model.PopulateTensor<int32_t>(model.num_segments(), {2});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray(
+ ArrayFloatNear({4.0, 3.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0})));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4}));
+}
+
+TEST(UnsortedSegmentMaxModelTest, SegmentsAreNegative) {
+ UnsortedSegmentMaxModel<int32_t> model({TensorType_INT32, {2, 2}},
+ {TensorType_INT32, {2}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {-1, -1});
+ model.PopulateTensor<int32_t>(model.num_segments(), {1});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({std::numeric_limits<int32_t>::lowest(),
+ std::numeric_limits<int32_t>::lowest()}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2}));
+}
+
+TEST(UnsortedSegmentMaxModelTest, ConstParameterTest) {
+ UnsortedSegmentMaxModel<int32_t> model({TensorType_INT32, {3, 2}}, {0, 1, 0},
+ {3}, {2}, {1});
+ model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 6, 3, 4}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2}));
+}
+
+} // namespace
+} // namespace tflite
\ No newline at end of file
diff --git a/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_min_test.cc b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_min_test.cc
new file mode 100644
index 00000000..bb4d4141
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_min_test.cc
@@ -0,0 +1,137 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <limits.h>
+#include <stdint.h>
+
+#include <initializer_list>
+#include <vector>
+
+#include "testing/base/public/gunit.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/kernels/unsorted_segment_test.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+template <typename T>
+class UnsortedSegmentMinModel : public UnsortedSegmentModel<T> {
+ public:
+ UnsortedSegmentMinModel(const TensorData& data, const TensorData& segment_ids,
+ const TensorData& num_segments)
+ : UnsortedSegmentModel<T>(data, segment_ids, num_segments,
+ BuiltinOperator_UNSORTED_SEGMENT_MIN,
+ BuiltinOptions_NONE) {}
+
+ explicit UnsortedSegmentMinModel(
+ const TensorData& data, const std::initializer_list<int>& segment_id_data,
+ const std::initializer_list<int>& segment_id_shape,
+ const std::initializer_list<int>& num_segments_data,
+ const std::initializer_list<int>& num_segments_shape)
+ : UnsortedSegmentModel<T>(data, segment_id_data, segment_id_shape,
+ num_segments_data, num_segments_shape,
+ BuiltinOperator_UNSORTED_SEGMENT_MIN,
+ BuiltinOptions_NONE) {}
+};
+
+INSTANTIATE_TEST_SUITE_P(UnsortedSegmentMinTestP, UnsortedSegmentTest,
+ testing::Values(BuiltinOperator_UNSORTED_SEGMENT_MIN));
+
+TEST(UnsortedSegmentMinModelTest, Int32Test_Simple) {
+ UnsortedSegmentMinModel<int32_t> model({TensorType_INT32, {6}},
+ {TensorType_INT32, {6}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<int32_t>(model.data(), {5, 3, 7, 8, 6, 4});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {0, 0, 1, 1, 0, 1});
+ model.PopulateTensor<int32_t>(model.num_segments(), {2});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({3, 4}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
+}
+
+TEST(UnsortedSegmentMinModelTest, Int32Test_Simple2D) {
+ UnsortedSegmentMinModel<int32_t> model({TensorType_INT32, {3, 4}},
+ {TensorType_INT32, {3}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<int32_t>(model.data(),
+ {1, 2, 3, 4, 5, 6, 7, 8, 4, 3, 2, 1});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1, 0});
+ model.PopulateTensor<int32_t>(model.num_segments(), {2});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 2, 1, 5, 6, 7, 8}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4}));
+}
+
+TEST(UnsortedSegmentMinModelTest, FloatTest_Simple) {
+ UnsortedSegmentMinModel<float> model({TensorType_FLOAT32, {8}},
+ {TensorType_INT32, {8}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<float>(model.data(),
+ {1.0, 2.0, 3.0, 4.0, 4.0, 3.0, 2.0, 1.0});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {1, 0, 1, 7, 7, 7, 7, 7});
+ model.PopulateTensor<int32_t>(model.num_segments(), {8});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(
+ model.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {2.0, 1.0, std::numeric_limits<float>::max(),
+ std::numeric_limits<float>::max(), std::numeric_limits<float>::max(),
+ std::numeric_limits<float>::max(), std::numeric_limits<float>::max(),
+ 1.0})));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({8}));
+}
+
+TEST(UnsortedSegmentMinModelTest, FloatTest_Simple2D) {
+ UnsortedSegmentMinModel<float> model({TensorType_FLOAT32, {3, 4}},
+ {TensorType_INT32, {3}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<float>(model.data(), {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
+ 8.0, 4.0, 3.0, 2.0, 1.0});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1, 0});
+ model.PopulateTensor<int32_t>(model.num_segments(), {2});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray(
+ ArrayFloatNear({1.0, 2.0, 2.0, 1.0, 5.0, 6.0, 7.0, 8.0})));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4}));
+}
+
+TEST(UnsortedSegmentMinModelTest, SegmentsAreNegative) {
+ UnsortedSegmentMinModel<int32_t> model({TensorType_INT32, {2, 2}},
+ {TensorType_INT32, {2}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {-1, -1});
+ model.PopulateTensor<int32_t>(model.num_segments(), {1});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({std::numeric_limits<int32_t>::max(),
+ std::numeric_limits<int32_t>::max()}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2}));
+}
+
+TEST(UnsortedSegmentMinModelTest, ConstParameterTest) {
+ UnsortedSegmentMinModel<int32_t> model({TensorType_INT32, {3, 2}}, {0, 1, 0},
+ {3}, {2}, {1});
+ model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2}));
+}
+
+} // namespace
+} // namespace tflite
\ No newline at end of file
diff --git a/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_prod_test.cc b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_prod_test.cc
new file mode 100644
index 00000000..524d45a0
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_prod_test.cc
@@ -0,0 +1,122 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <stdint.h>
+
+#include <vector>
+
+#include "testing/base/public/gunit.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/kernels/unsorted_segment_test.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+template <typename T>
+class UnsortedSegmentProdModel : public UnsortedSegmentModel<T> {
+ public:
+ UnsortedSegmentProdModel(const TensorData& data,
+ const TensorData& segment_ids,
+ const TensorData& num_segments)
+ : UnsortedSegmentModel<T>(data, segment_ids, num_segments,
+ BuiltinOperator_UNSORTED_SEGMENT_PROD,
+ BuiltinOptions_UnsortedSegmentProdOptions) {}
+
+ explicit UnsortedSegmentProdModel(
+ const TensorData& data, const std::initializer_list<int>& segment_id_data,
+ const std::initializer_list<int>& segment_id_shape,
+ const std::initializer_list<int>& num_segments_data,
+ const std::initializer_list<int>& num_segments_shape)
+ : UnsortedSegmentModel<T>(data, segment_id_data, segment_id_shape,
+ num_segments_data, num_segments_shape,
+ BuiltinOperator_UNSORTED_SEGMENT_PROD,
+ BuiltinOptions_UnsortedSegmentProdOptions) {}
+};
+
+INSTANTIATE_TEST_SUITE_P(
+ UnsortedSegmentProdTestP, UnsortedSegmentTest,
+ testing::Values(BuiltinOperator_UNSORTED_SEGMENT_PROD));
+
+TEST(UnsortedSegmentProdModelTest, Int32Test_Simple) {
+ UnsortedSegmentProdModel<int32_t> model({TensorType_INT32, {8}},
+ {TensorType_INT32, {8}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 4, 3, 2, 1});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {1, 0, 1, 7, 7, 7, 7, 7});
+ model.PopulateTensor<int32_t>(model.num_segments(), {8});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({2, 3, 1, 1, 1, 1, 1, 96}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({8}));
+}
+
+TEST(UnsortedSegmentProdModelTest, TestSkipNegSegmentId) {
+ UnsortedSegmentProdModel<int32_t> model({TensorType_INT32, {8}},
+ {TensorType_INT32, {8}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 4, 3, 2, 1});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {1, 0, 1, 7, 7, 7, 7, -1});
+ model.PopulateTensor<int32_t>(model.num_segments(), {8});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({2, 3, 1, 1, 1, 1, 1, 96}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({8}));
+}
+
+TEST(UnsortedSegmentProdModelTest, Int32Test_Simple2D) {
+ UnsortedSegmentProdModel<int32_t> model({TensorType_INT32, {3, 4}},
+ {TensorType_INT32, {3}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<int32_t>(model.data(),
+ {1, 2, 3, 4, 5, 6, 7, 8, 4, 3, 2, 1});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1, 0});
+ model.PopulateTensor<int32_t>(model.num_segments(), {2});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 6, 6, 4, 5, 6, 7, 8}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4}));
+}
+
+TEST(UnsortedSegmentProdModelTest, FloatTest_Simple) {
+ UnsortedSegmentProdModel<float> model({TensorType_FLOAT32, {8}},
+ {TensorType_INT32, {8}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<float>(model.data(),
+ {1.0, 2.0, 3.0, 4.0, 4.0, 3.0, 2.0, 1.0});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {1, 0, 1, 7, 7, 7, 7, 7});
+ model.PopulateTensor<int32_t>(model.num_segments(), {8});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray(
+ ArrayFloatNear({2.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 96.0})));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({8}));
+}
+
+TEST(UnsortedSegmentProdModelTest, FloatTest_Simple2D) {
+ UnsortedSegmentProdModel<float> model({TensorType_FLOAT32, {3, 4}},
+ {TensorType_INT32, {3}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<float>(model.data(), {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
+ 8.0, 4.0, 3.0, 2.0, 1.0});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1, 0});
+ model.PopulateTensor<int32_t>(model.num_segments(), {2});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray(
+ ArrayFloatNear({4.0, 6.0, 6.0, 4.0, 5.0, 6.0, 7.0, 8.0})));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4}));
+}
+
+} // namespace
+} // namespace tflite
\ No newline at end of file
diff --git a/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_sum_test.cc b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_sum_test.cc
new file mode 100644
index 00000000..2451c169
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_sum_test.cc
@@ -0,0 +1,145 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <limits.h>
+#include <stdint.h>
+
+#include <initializer_list>
+#include <vector>
+
+#include "testing/base/public/gunit.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/kernels/unsorted_segment_test.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+template <typename T>
+class UnsortedSegmentSumModel : public UnsortedSegmentModel<T> {
+ public:
+ UnsortedSegmentSumModel(const TensorData& data, const TensorData& segment_ids,
+ const TensorData& num_segments)
+ : UnsortedSegmentModel<T>(data, segment_ids, num_segments,
+ BuiltinOperator_UNSORTED_SEGMENT_SUM,
+ BuiltinOptions_NONE) {}
+
+ explicit UnsortedSegmentSumModel(
+ const TensorData& data, const std::initializer_list<int>& segment_id_data,
+ const std::initializer_list<int>& segment_id_shape,
+ const std::initializer_list<int>& num_segments_data,
+ const std::initializer_list<int>& num_segments_shape)
+ : UnsortedSegmentModel<T>(data, segment_id_data, segment_id_shape,
+ num_segments_data, num_segments_shape,
+ BuiltinOperator_UNSORTED_SEGMENT_SUM,
+ BuiltinOptions_NONE) {}
+};
+
+INSTANTIATE_TEST_SUITE_P(UnsortedSegmentSumTestP, UnsortedSegmentTest,
+ testing::Values(BuiltinOperator_UNSORTED_SEGMENT_SUM));
+
+TEST(UnsortedSegmentSumModelTest, Int32Test_Simple) {
+ UnsortedSegmentSumModel<int32_t> model({TensorType_INT32, {7}},
+ {TensorType_INT32, {7}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<int32_t>(model.data(), {5, 1, 7, 2, 3, 4, 10});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {0, 0, 1, 1, 0, 1, 0});
+ model.PopulateTensor<int32_t>(model.num_segments(), {2});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({19, 13}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
+}
+
+TEST(UnsortedSegmentSumModelTest, Int32Test_Simple2D) {
+ UnsortedSegmentSumModel<int32_t> model({TensorType_INT32, {3, 4}},
+ {TensorType_INT32, {3}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<int32_t>(model.data(),
+ {1, 2, 3, 4, 5, 6, 7, 8, 4, 3, 2, 1});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1, 0});
+ model.PopulateTensor<int32_t>(model.num_segments(), {2});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 6, 7, 8}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4}));
+}
+
+TEST(UnsortedSegmentSumModelTest, FloatTest_Simple) {
+ UnsortedSegmentSumModel<float> model({TensorType_FLOAT32, {6}},
+ {TensorType_INT32, {6}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<float>(model.data(), {1.0, 2.0, 3.0, 4.0, 4.0, 3.0});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {1, 0, 1, 7, 7, 7});
+ model.PopulateTensor<int32_t>(model.num_segments(), {8});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray(
+ ArrayFloatNear({2.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11.0})));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({8}));
+}
+
+TEST(UnsortedSegmentSumModelTest, FloatTest_Simple2D) {
+ UnsortedSegmentSumModel<float> model({TensorType_FLOAT32, {3, 4}},
+ {TensorType_INT32, {3}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<float>(model.data(), {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
+ 8.0, 4.0, 3.0, 2.0, 1.0});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1, 0});
+ model.PopulateTensor<int32_t>(model.num_segments(), {2});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray(
+ ArrayFloatNear({5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 7.0, 8.0})));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4}));
+}
+
+TEST(UnsortedSegmentSumModelTest, AllNegativeSegmentIdsZeroTensor) {
+ UnsortedSegmentSumModel<int32_t> model({TensorType_INT32, {2, 2}},
+ {TensorType_INT32, {2}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {-1, -1});
+ model.PopulateTensor<int32_t>(model.num_segments(), {1});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2}));
+}
+
+TEST(UnsortedSegmentSumModelTest, SomeNegativeSegmentIdsIgnored) {
+ UnsortedSegmentSumModel<int32_t> model({TensorType_INT32, {4}},
+ {TensorType_INT32, {4}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {-1, 0, -1, 1});
+ model.PopulateTensor<int32_t>(model.num_segments(), {2});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({2, 4}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
+}
+
+TEST(UnsortedSegmentSumModelTest,
+ NumSegmentsGreaterThanNumIdsPadsWithZeroTensors) {
+ UnsortedSegmentSumModel<int32_t> model({TensorType_INT32, {2, 2}},
+ {TensorType_INT32, {2}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1});
+ model.PopulateTensor<int32_t>(model.num_segments(), {3});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 0, 0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 2}));
+}
+} // namespace
+} // namespace tflite
\ No newline at end of file
diff --git a/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_test.cc b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_test.cc
new file mode 100644
index 00000000..6e398374
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_test.cc
@@ -0,0 +1,130 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "third_party/tensorflow/lite/kernels/unsorted_segment_test.h"
+
+#include <limits.h>
+#include <stdint.h>
+
+#include <initializer_list>
+#include <vector>
+
+#include "testing/base/public/gunit.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace {
+
+TEST_P(UnsortedSegmentTest, SegmentIdsSizeNotEqualToDataFirstDimensionFails) {
+ UnsortedSegmentModel<int32_t> model =
+ getModel({TensorType_INT32, {3, 2}}, {TensorType_INT32, {2}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1});
+ model.PopulateTensor<int32_t>(model.num_segments(), {2});
+ ASSERT_EQ(model.Invoke(), kTfLiteError);
+}
+TEST_P(UnsortedSegmentTest,
+ LargestSegmentIdPlusOneGreaterThanNumSegmentsFails) {
+ UnsortedSegmentModel<int32_t> model =
+ getModel({TensorType_INT32, {2, 2}}, {TensorType_INT32, {2}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1});
+ model.PopulateTensor<int32_t>(model.num_segments(), {1});
+ ASSERT_EQ(model.Invoke(), kTfLiteError);
+}
+TEST_P(UnsortedSegmentTest, NumSegmentsNotScalarShapeFails) {
+ UnsortedSegmentModel<int32_t> model =
+ getModel({TensorType_INT32, {3, 2}}, {TensorType_INT32, {3}},
+ {TensorType_INT32, {2}});
+ model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1, 0});
+ model.PopulateTensor<int32_t>(model.num_segments(), {2, 1});
+ ASSERT_EQ(model.Invoke(), kTfLiteError);
+}
+TEST_P(UnsortedSegmentTest, Rank2SegIdsNotPrefixFails) {
+ UnsortedSegmentModel<int32_t> model =
+ getModel({TensorType_INT32, {2, 2, 2}}, {TensorType_INT32, {2, 1}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {1, 1});
+ model.PopulateTensor<int32_t>(model.num_segments(), {3});
+ ASSERT_EQ(model.Invoke(), kTfLiteError);
+}
+TEST_P(UnsortedSegmentTest, Rank2SegIdsHasShapeNumSegDataShapeSuffix) {
+ UnsortedSegmentModel<int32_t> model =
+ getModel({TensorType_INT32, {2, 2, 2}}, {TensorType_INT32, {2, 2}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {1, 2, 0, 8});
+ model.PopulateTensor<int32_t>(model.num_segments(), {10});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({10, 2}));
+}
+TEST_P(UnsortedSegmentTest, Rank2SegIdsHasShapeNumSegDataShapeSuffixConst) {
+ UnsortedSegmentModel<int32_t> model = getConstModel(
+ {TensorType_INT32, {2, 2, 2}}, {1, 2, -1, -1}, {2, 2}, {3}, {1});
+ model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({3, 2}));
+}
+TEST_P(UnsortedSegmentTest, SegIdsHasSameShapeAsData2d) {
+ UnsortedSegmentModel<int32_t> model =
+ getModel({TensorType_INT32, {2, 2}}, {TensorType_INT32, {2, 2}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1, 5, 2, 4});
+ model.PopulateTensor<int32_t>(model.num_segments(), {10});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({10}));
+}
+TEST_P(UnsortedSegmentTest, SegIdsHasSameShapeAsData2dConst) {
+ UnsortedSegmentModel<int32_t> model =
+ getConstModel({TensorType_INT32, {2, 2}}, {1, 1, 1, 1}, {2, 2}, {3}, {1});
+ model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({3}));
+}
+TEST_P(UnsortedSegmentTest, SegIdsHasSameShapeAsData3d) {
+ UnsortedSegmentModel<int32_t> model =
+ getModel({TensorType_INT32, {2, 2, 2}}, {TensorType_INT32, {2, 2, 2}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6, 7, 8});
+ model.PopulateTensor<int32_t>(model.segment_ids(), {1, 2, 3, 4, 5, 6, 7, 8});
+ model.PopulateTensor<int32_t>(model.num_segments(), {10});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({10}));
+}
+TEST_P(UnsortedSegmentTest, SegIdsHasSameShapeAsData3dConst) {
+ UnsortedSegmentModel<int32_t> model =
+ getConstModel({TensorType_INT32, {2, 2, 2}}, {0, 1, 2, -1, 3, -1, 4, -1},
+ {2, 2, 2}, {8}, {1});
+ model.PopulateTensor<int32_t>(model.data(), {1, 1, 1, 1, 1, 1});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({8}));
+}
+TEST_P(UnsortedSegmentTest, Data5dHasShapeNumSegDataShapeSuffix) {
+ UnsortedSegmentModel<int32_t> model =
+ getModel({TensorType_INT32, {2, 1, 2, 1, 2}}, {TensorType_INT32, {2, 1}},
+ {TensorType_INT32, {1}});
+ model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6, 7, 8});
+ model.PopulateTensor(model.segment_ids(), {0, 1});
+ model.PopulateTensor(model.num_segments(), {10});
+ ASSERT_EQ(model.Invoke(), kTfLiteOk);
+ EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({10, 2, 1, 2}));
+}
+} // namespace
+} // namespace tflite
\ No newline at end of file
diff --git a/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_test.h b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_test.h
new file mode 100644
index 00000000..6a5b0f64
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_test.h
@@ -0,0 +1,97 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_UNSORTED_SEGMENT_TEST_H_
+#define TENSORFLOW_LITE_KERNELS_UNSORTED_SEGMENT_TEST_H_
+
+#include <limits.h>
+#include <stdint.h>
+
+#include <initializer_list>
+#include <iostream>
+#include <ostream>
+#include <vector>
+
+#include "testing/base/public/gunit.h"
+#include "tensorflow/lite/core/c/common.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+
+namespace tflite {
+template <typename T>
+class UnsortedSegmentModel : public SingleOpModel {
+ public:
+ UnsortedSegmentModel(const TensorData& data, const TensorData& segment_ids,
+ const TensorData& num_segments, const BuiltinOperator op,
+ const BuiltinOptions options) {
+ data_id_ = AddInput(data);
+ segment_ids_id_ = AddInput(segment_ids);
+ num_segments_id_ = AddInput(num_segments);
+ output_id_ = AddOutput(data.type);
+ SetBuiltinOp(op, options, 0);
+ BuildInterpreter({GetShape(data_id_), GetShape(segment_ids_id_),
+ GetShape(num_segments_id_)});
+ }
+
+ explicit UnsortedSegmentModel(
+ const TensorData& data, const std::initializer_list<int>& segment_id_data,
+ const std::initializer_list<int>& segment_id_shape,
+ const std::initializer_list<int>& num_segments_data,
+ const std::initializer_list<int>& num_segments_shape,
+ const BuiltinOperator op, const BuiltinOptions options) {
+ data_id_ = AddInput(data);
+ segment_ids_id_ =
+ AddConstInput(TensorType_INT32, segment_id_data, segment_id_shape);
+ num_segments_id_ =
+ AddConstInput(TensorType_INT32, num_segments_data, num_segments_shape);
+ output_id_ = AddOutput(data.type);
+ SetBuiltinOp(op, options, 0);
+ BuildInterpreter({GetShape(data_id_), GetShape(segment_ids_id_),
+ GetShape(num_segments_id_)});
+ }
+
+ int data() const { return data_id_; }
+ int segment_ids() const { return segment_ids_id_; }
+ int num_segments() const { return num_segments_id_; }
+ std::vector<T> GetOutput() { return ExtractVector<T>(output_id_); }
+ std::vector<int32_t> GetOutputShape() { return GetTensorShape(output_id_); }
+
+ protected:
+ int data_id_;
+ int segment_ids_id_;
+ int num_segments_id_;
+ int output_id_;
+};
+
+class UnsortedSegmentTest : public ::testing::TestWithParam<BuiltinOperator> {
+ public:
+ UnsortedSegmentModel<int32_t> getModel(const TensorData& data,
+ const TensorData& segment_ids,
+ const TensorData& num_segments) {
+ return UnsortedSegmentModel<int32_t>(data, segment_ids, num_segments,
+ GetParam(), BuiltinOptions_NONE);
+ }
+ UnsortedSegmentModel<int32_t> getConstModel(
+ const TensorData& data, const std::initializer_list<int>& segment_id_data,
+ const std::initializer_list<int>& segment_id_shape,
+ const std::initializer_list<int>& num_segments_data,
+ const std::initializer_list<int>& num_segments_shape) {
+ return UnsortedSegmentModel<int32_t>(
+ data, segment_id_data, segment_id_shape, num_segments_data,
+ num_segments_shape, GetParam(), BuiltinOptions_NONE);
+ }
+};
+
+} // namespace tflite
+#endif  // TENSORFLOW_LITE_KERNELS_UNSORTED_SEGMENT_TEST_H_
\ No newline at end of file
diff --git a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java
index efaa9d99..71d8fee2 100644
--- a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java
+++ b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java
@@ -30,7 +30,7 @@ import org.tensorflow.lite.task.core.TestUtils;
public class BertNLClassifierTest {
private static final String MODEL_FILE = "bert_nl_classifier.tflite";
// A classifier model with dynamic input tensors. Provided by the Android Rubidium team.
- private static final String DYNAMIC_INPUT_MODEL_FILE = "rb_model.tflite";
+ private static final String DYNAMIC_INPUT_MODEL_FILE = "rb_v4_model.tflite";
Category findCategoryWithLabel(List<Category> list, String label) {
return list.stream()
@@ -91,7 +91,7 @@ public class BertNLClassifierTest {
BertNLClassifier classifier = BertNLClassifier.createFromFile(
ApplicationProvider.getApplicationContext(), DYNAMIC_INPUT_MODEL_FILE);
- assertThat(classifier.getModelVersion()).isEqualTo("2");
+ assertThat(classifier.getModelVersion()).isEqualTo("4");
}
@Test
diff --git a/tensorflow_lite_support/java/src/javatests/testdata/task/text/rb_model.tflite b/tensorflow_lite_support/java/src/javatests/testdata/task/text/rb_model.tflite
deleted file mode 100644
index 56fe4703..00000000
--- a/tensorflow_lite_support/java/src/javatests/testdata/task/text/rb_model.tflite
+++ /dev/null
Binary files differ
diff --git a/tensorflow_lite_support/java/src/javatests/testdata/task/text/rb_v4_model.tflite b/tensorflow_lite_support/java/src/javatests/testdata/task/text/rb_v4_model.tflite
new file mode 100644
index 00000000..7265ba27
--- /dev/null
+++ b/tensorflow_lite_support/java/src/javatests/testdata/task/text/rb_v4_model.tflite
Binary files differ
diff --git a/tensorflow_lite_support/java/src/native/task/core/BUILD b/tensorflow_lite_support/java/src/native/task/core/BUILD
index d4dd7ab3..65e9e30c 100644
--- a/tensorflow_lite_support/java/src/native/task/core/BUILD
+++ b/tensorflow_lite_support/java/src/native/task/core/BUILD
@@ -3,7 +3,7 @@ package(
licenses = ["notice"], # Apache 2.0
)
-# Default provider for BuiltInOpResover. Create your own target, overwrite the
+# Default provider for BuiltInOpResolver. Create your own target, overwrite the
# function to provide a MutableOpResolver for customized OPs and/or a subset of
# builtin OPs.
cc_library(
@@ -14,3 +14,13 @@ cc_library(
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
)
+
+cc_library(
+ name = "rbtml_op_resolver",
+ srcs = ["rbtml_op_resolver.cc"],
+ deps = [
+ "//tensorflow_lite_support/custom_ops/kernel:unsorted_segment",
+ "@org_tensorflow//tensorflow/lite:framework",
+ "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
+ ],
+)
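The comment at the top of this BUILD file notes that the default builtin_op_resolver target can be replaced by one whose CreateOpResolver returns a MutableOpResolver with customized ops and/or only a subset of the builtin ops. A minimal sketch of that alternative follows; the two builtin ops chosen are purely illustrative assumptions and are not taken from this change (the resolver actually added by this change, rbtml_op_resolver.cc, comes next).

#include <memory>

#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace tflite {
namespace task {

// Illustrative only: register just the builtin ops a hypothetical model uses
// instead of pulling in the full BuiltinOpResolver, to reduce binary size.
std::unique_ptr<tflite::OpResolver> CreateOpResolver() {  // NOLINT
  auto resolver = std::make_unique<tflite::MutableOpResolver>();
  resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
                       tflite::ops::builtin::Register_FULLY_CONNECTED());
  resolver->AddBuiltin(tflite::BuiltinOperator_SOFTMAX,
                       tflite::ops::builtin::Register_SOFTMAX());
  return resolver;
}

}  // namespace task
}  // namespace tflite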
diff --git a/tensorflow_lite_support/java/src/native/task/core/rbtml_op_resolver.cc b/tensorflow_lite_support/java/src/native/task/core/rbtml_op_resolver.cc
new file mode 100644
index 00000000..c35501cb
--- /dev/null
+++ b/tensorflow_lite_support/java/src/native/task/core/rbtml_op_resolver.cc
@@ -0,0 +1,40 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+
+#include "tensorflow/lite/kernels/register.h"
+
+#include "tensorflow_lite_support/custom_ops/kernel/unsorted_segment.h"
+
+namespace tflite {
+namespace task {
+// Create a custom op resolver to provide the unsorted_segment_prod op
+// required by the bert_nl_classifier and rb_model for BertNLClassifier.
+std::unique_ptr<tflite::OpResolver> CreateOpResolver() { // NOLINT
+ std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver(
+ new tflite::ops::builtin::BuiltinOpResolver);
+ // "UnsortedSegmentProd" is the name used by unsorted_segment_prod op when
+ // when converting SavedModel to tflite using the size optimization approach.
+ resolver->AddCustom("UnsortedSegmentProd",
+ tflite::ops::custom::Register_UNSORTED_SEGMENT_PROD());
+ // "FlexUnsortedSegmentProd" is the name used by unsorted_segment_prod op when
+ // when converting SavedModel to tflite using the the other approaches.
+ resolver->AddCustom("FlexUnsortedSegmentProd",
+ tflite::ops::custom::Register_UNSORTED_SEGMENT_PROD());
+ return std::unique_ptr<tflite::OpResolver>(std::move(resolver));
+}
+
+} // namespace task
+} // namespace tflite
\ No newline at end of file
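For context, a resolver like the one above is consumed when the TFLite interpreter is built. The sketch below is illustrative rather than part of this change: CreateOpResolver() is assumed to be the function defined above, the model path is a placeholder, and the point is only that whichever resolver is handed to InterpreterBuilder decides how the "UnsortedSegmentProd" / "FlexUnsortedSegmentProd" names found in a converted model are mapped onto the registered kernel.

#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model.h"

namespace tflite {
namespace task {
std::unique_ptr<tflite::OpResolver> CreateOpResolver();  // defined in rbtml_op_resolver.cc above
}  // namespace task
}  // namespace tflite

// Builds an interpreter for a model that relies on the UnsortedSegmentProd
// custom op. Returns nullptr on any failure; the path argument is hypothetical.
std::unique_ptr<tflite::Interpreter> BuildInterpreter(const char* model_path) {
  auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
  if (model == nullptr) return nullptr;
  std::unique_ptr<tflite::OpResolver> resolver = tflite::task::CreateOpResolver();
  std::unique_ptr<tflite::Interpreter> interpreter;
  // The resolver maps both builtin ops and the custom UnsortedSegmentProd
  // registrations onto kernels while the graph is being instantiated.
  if (tflite::InterpreterBuilder(*model, *resolver)(&interpreter) != kTfLiteOk) {
    return nullptr;
  }
  return interpreter;
}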
diff --git a/third_party/zlib/Android.bp b/third_party/zlib/Android.bp
index a0c7570a..2da5f53c 100644
--- a/third_party/zlib/Android.bp
+++ b/third_party/zlib/Android.bp
@@ -2,6 +2,38 @@
// TODO(b/233151429): Clean up external/tflite-support/third_party/zlib so that it contains the
// minimum set of files needed to support `BertNLClassifier`.
+package {
+ default_applicable_licenses: [
+ "external_tflite-support_third_party_zlib_license",
+ ],
+}
+
+// Added automatically by a large-scale-change that took the approach of
+// 'apply every license found to every target'. While this makes sure we respect
+// every license restriction, it may not be entirely correct.
+//
+// e.g. GPL in an MIT project might only apply to the contrib/ directory.
+//
+// Please consider splitting the single license below into multiple licenses,
+// taking care not to lose any license_kind information, and overriding the
+// default license using the 'licenses: [...]' property on targets as needed.
+//
+// For unused files, consider creating a 'fileGroup' with "//visibility:private"
+// to attach the license to, and including a comment whether the files may be
+// used in the current project.
+// See: http://go/android-license-faq
+license {
+ name: "external_tflite-support_third_party_zlib_license",
+ visibility: [":__subpackages__"],
+ license_kinds: [
+ "SPDX-license-identifier-BSD",
+ "SPDX-license-identifier-Zlib",
+ ],
+ license_text: [
+ "LICENSE",
+ ],
+}
+
srcs_opt = [
"adler32_simd.c",
// See https://chromium-review.googlesource.com/749732.