diff options
Diffstat (limited to 'tensorflow_lite_support/custom_ops/kernel')
9 files changed, 1174 insertions, 0 deletions
diff --git a/tensorflow_lite_support/custom_ops/kernel/BUILD b/tensorflow_lite_support/custom_ops/kernel/BUILD index b9b11de9..a55dcb95 100644 --- a/tensorflow_lite_support/custom_ops/kernel/BUILD +++ b/tensorflow_lite_support/custom_ops/kernel/BUILD @@ -144,3 +144,83 @@ py_test( "@absl_py//absl/testing:parameterized", ], ) + +cc_library( + name = "unsorted_segment", + srcs = ["unsorted_segment.cc"], + hdrs = ["unsorted_segment.h"], + deps = [ + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + "@org_tensorflow//tensorflow/lite/kernels/internal:reference_base", + "@org_tensorflow//tensorflow/lite/kernels/internal:tensor", + ], +) + +cc_test( + name = "unsorted_segment_prod_test", + size = "small", + srcs = [ + "unsorted_segment_prod_test.cc", + "unsorted_segment_test.cc", + "unsorted_segment_test.h", + ], + deps = [ + "@com_google_googletest//:gtest_main", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:test_main", + "@org_tensorflow//tensorflow/lite/kernels:test_util", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], +) + +cc_test( + name = "unsorted_segment_max_test", + size = "small", + srcs = [ + "unsorted_segment_max_test.cc", + "unsorted_segment_test.cc", + "unsorted_segment_test.h", + ], + deps = [ + "@com_google_googletest//:gtest_main", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:test_main", + "@org_tensorflow//tensorflow/lite/kernels:test_util", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], +) + +cc_test( + name = "unsorted_segment_sum_test", + size = "small", + srcs = [ + "unsorted_segment_sum_test.cc", + "unsorted_segment_test.cc", + "unsorted_segment_test.h", + ], + deps = [ + "@com_google_googletest//:gtest_main", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:test_main", + "@org_tensorflow//tensorflow/lite/kernels:test_util", + 
"@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], +) + +cc_test( + name = "unsorted_segment_min_test", + size = "small", + srcs = [ + "unsorted_segment_sum_test.cc", + "unsorted_segment_test.cc", + "unsorted_segment_test.h", + ], + deps = [ + "@com_google_googletest//:gtest_main", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:test_main", + "@org_tensorflow//tensorflow/lite/kernels:test_util", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], +) diff --git a/tensorflow_lite_support/custom_ops/kernel/unsorted_segment.cc b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment.cc new file mode 100644 index 00000000..15972ed6 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment.cc @@ -0,0 +1,295 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <stdint.h> + +#include <algorithm> +#include <functional> + +#include "tensorflow_lite_support/custom_ops/kernel/unsorted_segment.h" + +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace unsorted_segment { + +enum SegmentType { + kSegmentMax, + kSegmentMin, + kSegmentProd, + kSegmentSum, +}; + +static const int kInputDataTensor = 0; +static const int kInputSegmentIdsTensor = 1; +static const int kInputNumSegmentsTensor = 2; +static const int kOutputTensor = 0; + +inline bool IsConstantOrPersistentTensor(const TfLiteTensor* tensor) { + return tflite::IsConstantTensor(tensor) || + (tensor->allocation_type == kTfLitePersistentRo); +} + +template <typename T, template <typename T2> typename Op> +void UnsortedSegmentRef(const tflite::RuntimeShape& input_shape, + const T* input_data, + const tflite::RuntimeShape& segment_ids_shape, + const int32_t* segment_ids_data, + const tflite::RuntimeShape& output_shape, + T* output_data) { + for (int i = 0; i < output_shape.FlatSize(); ++i) { + output_data[i] = Op<T>::kInitialValue; + } + Op<T> op; + int segment_flat_size = 1; + for (int i = 1; i < output_shape.DimensionsCount(); ++i) { + segment_flat_size *= output_shape.Dims(i); + } + for (int i = 0; i < segment_ids_shape.FlatSize(); i++) { + int output_index = segment_ids_data[i]; + if (output_index < 0) continue; + for (int j = 0; j < segment_flat_size; ++j) { + output_data[output_index * segment_flat_size + j] = + op(output_data[output_index * segment_flat_size + j], + input_data[i * segment_flat_size + j]); + } + } +} + +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + const TfLiteTensor* data, + const TfLiteTensor* segment_ids, + const TfLiteTensor* num_segments, + TfLiteTensor* output) { + // The shape of 
segment_ids is permitted to be any non-empty prefix of + // the input data's shape. The shape of output's first dimension is always + // equal to num_segments. The remaining dimensions of output's shape are then + // taken to be the suffix of input shape after rank(segment_ids)th position. + // Public facing tensorflow erroneously describe unsorted_segment ops as only + // supporting segment_ids of rank 1, however tensorflow implementation + // supports higher dimensional segment_ids as described. + const int segment_ids_rank = tflite::NumDimensions(segment_ids); + const int data_rank = tflite::NumDimensions(data); + TF_LITE_ENSURE(context, segment_ids_rank <= data_rank); + for (int i = 0; i < segment_ids_rank; ++i) { + // segment_ids shape must be prefix of data shape. + TF_LITE_ENSURE_EQ(context, segment_ids->dims->data[i], data->dims->data[i]); + } + TF_LITE_ENSURE(context, (num_segments->dims->size == 1 && + num_segments->dims->data[0] == 1) || + num_segments->dims->size == 0); + // num_segments can be thought of as number of buckets (segments) in output, + // where each segment is the reduction of all elements mapped to that + // segment_ids. The shape of said elements is the respective + // suffix of the data shape. + int32_t num_segments_ = tflite::GetTensorData<int32_t>(num_segments)[0]; + const int num_segment_ids = tflite::NumElements(segment_ids); + int max_index = -1; + for (int i = 0; i < num_segment_ids; i++) { + max_index = std::max(tflite::GetTensorData<int32_t>(segment_ids)[i], max_index); + } + // num_segments_ must be at greater than max_index else would map elements + // to non existent output segments. 
+ TF_LITE_ENSURE(context, max_index < num_segments_); + const int output_rank = data_rank - segment_ids_rank + 1; + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank); + output_shape->data[0] = num_segments_; + // output_shape[1:] should be data_shape[Rank(segment_ids):] + for (int i = segment_ids_rank; i < data_rank; ++i) { + output_shape->data[i - segment_ids_rank + 1] = data->dims->data[i]; + } + return context->ResizeTensor(context, output, output_shape); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 3); + TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1); + const TfLiteTensor* data; + TF_LITE_ENSURE_OK(context, + tflite::GetInputSafe(context, node, kInputDataTensor, &data)); + const TfLiteTensor* segment_ids; + TF_LITE_ENSURE_OK(context, tflite::GetInputSafe(context, node, kInputSegmentIdsTensor, + &segment_ids)); + const TfLiteTensor* num_segments; + TF_LITE_ENSURE_OK( + context, + tflite::GetInputSafe(context, node, kInputNumSegmentsTensor, &num_segments)); + TfLiteTensor* output; + TF_LITE_ENSURE_OK(context, + tflite::GetOutputSafe(context, node, kOutputTensor, &output)); + TF_LITE_ENSURE(context, + data->type == kTfLiteInt32 || data->type == kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, segment_ids->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, num_segments->type, kTfLiteInt32); + + if (tflite::IsDynamicTensor(data) || !IsConstantOrPersistentTensor(segment_ids) || + !IsConstantOrPersistentTensor(num_segments)) { + tflite::SetTensorToDynamic(output); + return kTfLiteOk; + } + return ResizeOutputTensor(context, data, segment_ids, num_segments, output); +} + +template <typename T> +struct SegmenMax { + inline T operator()(const T& a, const T& b) const { return std::max(a, b); } + static constexpr T kInitialValue = std::numeric_limits<T>::lowest(); +}; + +template <typename T> +struct SegmenMin { + inline T operator()(const T& a, const T& b) const { return 
std::min(a, b); } + static constexpr T kInitialValue = std::numeric_limits<T>::max(); +}; + +template <typename T> +struct SegmenProd { + inline T operator()(const T& a, const T& b) const { return a * b; } + static constexpr T kInitialValue = T(1); +}; + +template <typename T> +struct SegmenSum { + inline T operator()(const T& a, const T& b) const { return a + b; } + static constexpr T kInitialValue = T(0); +}; + +template <typename T> +TfLiteStatus EvalType(TfLiteContext* context, const tflite::RuntimeShape& input_shape, + const T* input_data, + const tflite::RuntimeShape& segment_ids_shape, + const int32_t* segment_ids_data, + const tflite::RuntimeShape& output_shape, T* output_data, + SegmentType segment_type) { + switch (segment_type) { + case kSegmentProd: + unsorted_segment::UnsortedSegmentRef<T, SegmenProd>( + input_shape, input_data, segment_ids_shape, segment_ids_data, + output_shape, output_data); + break; + case kSegmentMax: + unsorted_segment::UnsortedSegmentRef<T, SegmenMax>( + input_shape, input_data, segment_ids_shape, segment_ids_data, + output_shape, output_data); + break; + case kSegmentSum: + unsorted_segment::UnsortedSegmentRef<T, SegmenSum>( + input_shape, input_data, segment_ids_shape, segment_ids_data, + output_shape, output_data); + break; + case kSegmentMin: + unsorted_segment::UnsortedSegmentRef<T, SegmenMin>( + input_shape, input_data, segment_ids_shape, segment_ids_data, + output_shape, output_data); + break; + default: + TF_LITE_KERNEL_LOG(context, "Not recognized segment type: %d", + segment_type); + return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus EvalGeneric(TfLiteContext* context, TfLiteNode* node, + SegmentType segment_type) { + const TfLiteTensor* data; + TF_LITE_ENSURE_OK(context, + tflite::GetInputSafe(context, node, kInputDataTensor, &data)); + const TfLiteTensor* segment_ids; + TF_LITE_ENSURE_OK(context, tflite::GetInputSafe(context, node, kInputSegmentIdsTensor, + &segment_ids)); + const TfLiteTensor* 
num_segments; + TF_LITE_ENSURE_OK( + context, + tflite::GetInputSafe(context, node, kInputNumSegmentsTensor, &num_segments)); + TfLiteTensor* output; + TF_LITE_ENSURE_OK(context, + tflite::GetOutputSafe(context, node, kOutputTensor, &output)); + + if (tflite::IsDynamicTensor(output)) { + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, data, segment_ids, + num_segments, output)); + } + TF_LITE_ENSURE_EQ(context, tflite::GetTensorShape(data).Dims(0), + tflite::GetTensorShape(segment_ids).Dims(0)); + +#define TF_LITE_UNSORTED_SEGMENT(dtype) \ + EvalType<dtype>(context, tflite::GetTensorShape(data), tflite::GetTensorData<dtype>(data), \ + tflite::GetTensorShape(segment_ids), \ + tflite::GetTensorData<int32_t>(segment_ids), tflite::GetTensorShape(output), \ + tflite::GetTensorData<dtype>(output), segment_type); + switch (data->type) { + case kTfLiteInt32: + TF_LITE_UNSORTED_SEGMENT(int32_t); + break; + case kTfLiteFloat32: + TF_LITE_UNSORTED_SEGMENT(float); + break; + default: + TF_LITE_KERNEL_LOG( + context, "Currently UnsortedSegment doesn't support data type: %s", + TfLiteTypeGetName(data->type)); + return kTfLiteError; + } +#undef TF_LITE_UNSORTED_SEGMENT + return kTfLiteOk; +} + +TfLiteStatus EvalProd(TfLiteContext* context, TfLiteNode* node) { + return EvalGeneric(context, node, kSegmentProd); +} +TfLiteStatus EvalMax(TfLiteContext* context, TfLiteNode* node) { + return EvalGeneric(context, node, kSegmentMax); +} +TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) { + return EvalGeneric(context, node, kSegmentSum); +} +TfLiteStatus EvalMin(TfLiteContext* context, TfLiteNode* node) { + return EvalGeneric(context, node, kSegmentMin); +} + +} // namespace unsorted_segment + +TfLiteRegistration* Register_UNSORTED_SEGMENT_PROD() { + static TfLiteRegistration r = {nullptr, nullptr, unsorted_segment::Prepare, + unsorted_segment::EvalProd}; + return &r; +} + +TfLiteRegistration* Register_UNSORTED_SEGMENT_MAX() { + static TfLiteRegistration r = {nullptr, 
nullptr, unsorted_segment::Prepare, + unsorted_segment::EvalMax}; + return &r; +} + +TfLiteRegistration* Register_UNSORTED_SEGMENT_SUM() { + static TfLiteRegistration r = {nullptr, nullptr, unsorted_segment::Prepare, + unsorted_segment::EvalSum}; + return &r; +} + +TfLiteRegistration* Register_UNSORTED_SEGMENT_MIN() { + static TfLiteRegistration r = {nullptr, nullptr, unsorted_segment::Prepare, + unsorted_segment::EvalMin}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite
\ No newline at end of file diff --git a/tensorflow_lite_support/custom_ops/kernel/unsorted_segment.h b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment.h new file mode 100644 index 00000000..fb8fd798 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment.h @@ -0,0 +1,34 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_UNSORTED_SEGMENT_H_ +#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_UNSORTED_SEGMENT_H_ + +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace custom { + +TfLiteRegistration* Register_UNSORTED_SEGMENT_PROD(); +TfLiteRegistration* Register_UNSORTED_SEGMENT_MAX(); +TfLiteRegistration* Register_UNSORTED_SEGMENT_SUM(); +TfLiteRegistration* Register_UNSORTED_SEGMENT_MIN(); + +} // namespace custom +} // namespace ops +} // namespace tflite + +#endif  // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_UNSORTED_SEGMENT_H_
\ No newline at end of file diff --git a/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_max_test.cc b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_max_test.cc new file mode 100644 index 00000000..10baefc0 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_max_test.cc @@ -0,0 +1,137 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <limits.h> +#include <stdint.h> + +#include <initializer_list> +#include <vector> + +#include <gtest/gtest.h> +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow_lite_support/custom_ops/kernel/unsorted_segment_test.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +template <typename T> +class UnsortedSegmentMaxModel : public UnsortedSegmentModel<T> { + public: + UnsortedSegmentMaxModel(const TensorData& data, const TensorData& segment_ids, + const TensorData& num_segments) + : UnsortedSegmentModel<T>(data, segment_ids, num_segments, + BuiltinOperator_UNSORTED_SEGMENT_MAX, + BuiltinOptions_NONE) {} + + explicit UnsortedSegmentMaxModel( + const TensorData& data, const std::initializer_list<int>& segment_id_data, + const std::initializer_list<int>& segment_id_shape, + const std::initializer_list<int>& num_segments_data, + const std::initializer_list<int>& 
num_segments_shape) + : UnsortedSegmentModel<T>(data, segment_id_data, segment_id_shape, + num_segments_data, num_segments_shape, + BuiltinOperator_UNSORTED_SEGMENT_MAX, + BuiltinOptions_NONE) {} +}; + +INSTANTIATE_TEST_SUITE_P(UnsortedSegmentMaxTestP, UnsortedSegmentTest, + testing::Values(BuiltinOperator_UNSORTED_SEGMENT_MAX)); + +TEST(UnsortedSegmentMaxModelTest, Int32Test_Simple) { + UnsortedSegmentMaxModel<int32_t> model({TensorType_INT32, {6}}, + {TensorType_INT32, {6}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<int32_t>(model.data(), {5, 1, 7, 2, 3, 4}); + model.PopulateTensor<int32_t>(model.segment_ids(), {0, 0, 1, 1, 0, 1}); + model.PopulateTensor<int32_t>(model.num_segments(), {2}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 7})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2})); +} + +TEST(UnsortedSegmentMaxModelTest, Int32Test_Simple2D) { + UnsortedSegmentMaxModel<int32_t> model({TensorType_INT32, {3, 4}}, + {TensorType_INT32, {3}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<int32_t>(model.data(), + {1, 2, 3, 4, 5, 6, 7, 8, 4, 3, 2, 1}); + model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1, 0}); + model.PopulateTensor<int32_t>(model.num_segments(), {2}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 3, 3, 4, 5, 6, 7, 8})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4})); +} + +TEST(UnsortedSegmentMaxModelTest, FloatTest_Simple) { + UnsortedSegmentMaxModel<float> model({TensorType_FLOAT32, {8}}, + {TensorType_INT32, {8}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<float>(model.data(), + {1.0, 2.0, 3.0, 4.0, 4.0, 3.0, 2.0, 1.0}); + model.PopulateTensor<int32_t>(model.segment_ids(), {1, 0, 1, 7, 7, 7, 7, 7}); + model.PopulateTensor<int32_t>(model.num_segments(), {8}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {2.0, 3.0, 
std::numeric_limits<float>::lowest(), + std::numeric_limits<float>::lowest(), + std::numeric_limits<float>::lowest(), + std::numeric_limits<float>::lowest(), + std::numeric_limits<float>::lowest(), 4.0}))); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({8})); +} + +TEST(UnsortedSegmentMaxModelTest, FloatTest_Simple2D) { + UnsortedSegmentMaxModel<float> model({TensorType_FLOAT32, {3, 4}}, + {TensorType_INT32, {3}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<float>(model.data(), {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, + 8.0, 4.0, 3.0, 2.0, 1.0}); + model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1, 0}); + model.PopulateTensor<int32_t>(model.num_segments(), {2}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray( + ArrayFloatNear({4.0, 3.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}))); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4})); +} + +TEST(UnsortedSegmentMaxModelTest, SegmentsAreNegative) { + UnsortedSegmentMaxModel<int32_t> model({TensorType_INT32, {2, 2}}, + {TensorType_INT32, {2}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4}); + model.PopulateTensor<int32_t>(model.segment_ids(), {-1, -1}); + model.PopulateTensor<int32_t>(model.num_segments(), {1}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({std::numeric_limits<int32_t>::lowest(), + std::numeric_limits<int32_t>::lowest()})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2})); +} + +TEST(UnsortedSegmentMaxModelTest, ConstParamenterTest) { + UnsortedSegmentMaxModel<int32_t> model({TensorType_INT32, {3, 2}}, {0, 1, 0}, + {3}, {2}, {1}); + model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 6, 3, 4})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2})); +} + +} // namespace +} // namespace tflite
\ No newline at end of file diff --git a/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_min_test.cc b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_min_test.cc new file mode 100644 index 00000000..bb4d4141 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_min_test.cc @@ -0,0 +1,137 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <limits.h> +#include <stdint.h> + +#include <initializer_list> +#include <vector> + +#include <gtest/gtest.h> +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow_lite_support/custom_ops/kernel/unsorted_segment_test.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +template <typename T> +class UnsortedSegmentMinModel : public UnsortedSegmentModel<T> { + public: + UnsortedSegmentMinModel(const TensorData& data, const TensorData& segment_ids, + const TensorData& num_segments) + : UnsortedSegmentModel<T>(data, segment_ids, num_segments, + BuiltinOperator_UNSORTED_SEGMENT_MIN, + BuiltinOptions_NONE) {} + + explicit UnsortedSegmentMinModel( + const TensorData& data, const std::initializer_list<int>& segment_id_data, + const std::initializer_list<int>& segment_id_shape, + const std::initializer_list<int>& num_segments_data, + const std::initializer_list<int>& 
num_segments_shape) + : UnsortedSegmentModel<T>(data, segment_id_data, segment_id_shape, + num_segments_data, num_segments_shape, + BuiltinOperator_UNSORTED_SEGMENT_MIN, + BuiltinOptions_NONE) {} +}; + +INSTANTIATE_TEST_SUITE_P(UnsortedSegmentMinTestP, UnsortedSegmentTest, + testing::Values(BuiltinOperator_UNSORTED_SEGMENT_MIN)); + +TEST(UnsortedSegmentMinModelTest, Int32Test_Simple) { + UnsortedSegmentMinModel<int32_t> model({TensorType_INT32, {6}}, + {TensorType_INT32, {6}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<int32_t>(model.data(), {5, 3, 7, 8, 6, 4}); + model.PopulateTensor<int32_t>(model.segment_ids(), {0, 0, 1, 1, 0, 1}); + model.PopulateTensor<int32_t>(model.num_segments(), {2}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({3, 4})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2})); +} + +TEST(UnsortedSegmentMinModelTest, Int32Test_Simple2D) { + UnsortedSegmentMinModel<int32_t> model({TensorType_INT32, {3, 4}}, + {TensorType_INT32, {3}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<int32_t>(model.data(), + {1, 2, 3, 4, 5, 6, 7, 8, 4, 3, 2, 1}); + model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1, 0}); + model.PopulateTensor<int32_t>(model.num_segments(), {2}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 2, 1, 5, 6, 7, 8})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4})); +} + +TEST(UnsortedSegmentMinModelTest, FloatTest_Simple) { + UnsortedSegmentMinModel<float> model({TensorType_FLOAT32, {8}}, + {TensorType_INT32, {8}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<float>(model.data(), + {1.0, 2.0, 3.0, 4.0, 4.0, 3.0, 2.0, 1.0}); + model.PopulateTensor<int32_t>(model.segment_ids(), {1, 0, 1, 7, 7, 7, 7, 7}); + model.PopulateTensor<int32_t>(model.num_segments(), {8}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT( + model.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {2.0, 1.0, 
std::numeric_limits<float>::max(), + std::numeric_limits<float>::max(), std::numeric_limits<float>::max(), + std::numeric_limits<float>::max(), std::numeric_limits<float>::max(), + 1.0}))); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({8})); +} + +TEST(UnsortedSegmentMinModelTest, FloatTest_Simple2D) { + UnsortedSegmentMinModel<float> model({TensorType_FLOAT32, {3, 4}}, + {TensorType_INT32, {3}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<float>(model.data(), {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, + 8.0, 4.0, 3.0, 2.0, 1.0}); + model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1, 0}); + model.PopulateTensor<int32_t>(model.num_segments(), {2}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray( + ArrayFloatNear({1.0, 2.0, 2.0, 1.0, 5.0, 6.0, 7.0, 8.0}))); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4})); +} + +TEST(UnsortedSegmentMinModelTest, SegmentsAreNegative) { + UnsortedSegmentMinModel<int32_t> model({TensorType_INT32, {2, 2}}, + {TensorType_INT32, {2}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4}); + model.PopulateTensor<int32_t>(model.segment_ids(), {-1, -1}); + model.PopulateTensor<int32_t>(model.num_segments(), {1}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({std::numeric_limits<int32_t>::max(), + std::numeric_limits<int32_t>::max()})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2})); +} + +TEST(UnsortedSegmentMinModelTest, ConstParamenterTest) { + UnsortedSegmentMinModel<int32_t> model({TensorType_INT32, {3, 2}}, {0, 1, 0}, + {3}, {2}, {1}); + model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2})); +} + +} // namespace +} // namespace tflite
\ No newline at end of file diff --git a/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_prod_test.cc b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_prod_test.cc new file mode 100644 index 00000000..524d45a0 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_prod_test.cc @@ -0,0 +1,122 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <stdint.h> + +#include <vector> + +#include <gtest/gtest.h> +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow_lite_support/custom_ops/kernel/unsorted_segment_test.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +template <typename T> +class UnsortedSegmentProdModel : public UnsortedSegmentModel<T> { + public: + UnsortedSegmentProdModel(const TensorData& data, + const TensorData& segment_ids, + const TensorData& num_segments) + : UnsortedSegmentModel<T>(data, segment_ids, num_segments, + BuiltinOperator_UNSORTED_SEGMENT_PROD, + BuiltinOptions_UnsortedSegmentProdOptions) {} + + explicit UnsortedSegmentProdModel( + const TensorData& data, const std::initializer_list<int>& segment_id_data, + const std::initializer_list<int>& segment_id_shape, + const std::initializer_list<int>& num_segments_data, + const std::initializer_list<int>& num_segments_shape) + : 
UnsortedSegmentModel<T>(data, segment_id_data, segment_id_shape, + num_segments_data, num_segments_shape, + BuiltinOperator_UNSORTED_SEGMENT_PROD, + BuiltinOptions_UnsortedSegmentProdOptions) {} +}; + +INSTANTIATE_TEST_SUITE_P( + UnsortedSegmentProdTestP, UnsortedSegmentTest, + testing::Values(BuiltinOperator_UNSORTED_SEGMENT_PROD)); + +TEST(UnsortedSegmentProdModelTest, Int32Test_Simple) { + UnsortedSegmentProdModel<int32_t> model({TensorType_INT32, {8}}, + {TensorType_INT32, {8}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 4, 3, 2, 1}); + model.PopulateTensor<int32_t>(model.segment_ids(), {1, 0, 1, 7, 7, 7, 7, 7}); + model.PopulateTensor<int32_t>(model.num_segments(), {8}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({2, 3, 1, 1, 1, 1, 1, 96})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({8})); +} + +TEST(UnsortedSegmentProdModelTest, TestSkipNegSegmentId) { + UnsortedSegmentProdModel<int32_t> model({TensorType_INT32, {8}}, + {TensorType_INT32, {8}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 4, 3, 2, 1}); + model.PopulateTensor<int32_t>(model.segment_ids(), {1, 0, 1, 7, 7, 7, 7, -1}); + model.PopulateTensor<int32_t>(model.num_segments(), {8}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({2, 3, 1, 1, 1, 1, 1, 96})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({8})); +} + +TEST(UnsortedSegmentProdModelTest, Int32Test_Simple2D) { + UnsortedSegmentProdModel<int32_t> model({TensorType_INT32, {3, 4}}, + {TensorType_INT32, {3}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<int32_t>(model.data(), + {1, 2, 3, 4, 5, 6, 7, 8, 4, 3, 2, 1}); + model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1, 0}); + model.PopulateTensor<int32_t>(model.num_segments(), {2}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 6, 6, 4, 5, 
6, 7, 8})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4})); +} + +TEST(UnsortedSegmentProdModelTest, FloatTest_Simple) { + UnsortedSegmentProdModel<float> model({TensorType_FLOAT32, {8}}, + {TensorType_INT32, {8}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<float>(model.data(), + {1.0, 2.0, 3.0, 4.0, 4.0, 3.0, 2.0, 1.0}); + model.PopulateTensor<int32_t>(model.segment_ids(), {1, 0, 1, 7, 7, 7, 7, 7}); + model.PopulateTensor<int32_t>(model.num_segments(), {8}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray( + ArrayFloatNear({2.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 96.0}))); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({8})); +} + +TEST(UnsortedSegmentProdModelTest, FloatTest_Simple2D) { + UnsortedSegmentProdModel<float> model({TensorType_FLOAT32, {3, 4}}, + {TensorType_INT32, {3}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<float>(model.data(), {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, + 8.0, 4.0, 3.0, 2.0, 1.0}); + model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1, 0}); + model.PopulateTensor<int32_t>(model.num_segments(), {2}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray( + ArrayFloatNear({4.0, 6.0, 6.0, 4.0, 5.0, 6.0, 7.0, 8.0}))); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4})); +} + +} // namespace +} // namespace tflite
\ No newline at end of file diff --git a/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_sum_test.cc b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_sum_test.cc new file mode 100644 index 00000000..2451c169 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_sum_test.cc @@ -0,0 +1,145 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <limits.h> +#include <stdint.h> + +#include <initializer_list> +#include <vector> + +#include "testing/base/public/gunit.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/kernels/unsorted_segment_test.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +template <typename T> +class UnsortedSegmentSumModel : public UnsortedSegmentModel<T> { + public: + UnsortedSegmentSumModel(const TensorData& data, const TensorData& segment_ids, + const TensorData& num_segments) + : UnsortedSegmentModel<T>(data, segment_ids, num_segments, + BuiltinOperator_UNSORTED_SEGMENT_SUM, + BuiltinOptions_NONE) {} + + explicit UnsortedSegmentSumModel( + const TensorData& data, const std::initializer_list<int>& segment_id_data, + const std::initializer_list<int>& segment_id_shape, + const std::initializer_list<int>& num_segments_data, + const std::initializer_list<int>& 
num_segments_shape) + : UnsortedSegmentModel<T>(data, segment_id_data, segment_id_shape, + num_segments_data, num_segments_shape, + BuiltinOperator_UNSORTED_SEGMENT_SUM, + BuiltinOptions_NONE) {} +}; + +INSTANTIATE_TEST_SUITE_P(UnsortedSegmentSumTestP, UnsortedSegmentTest, + testing::Values(BuiltinOperator_UNSORTED_SEGMENT_SUM)); + +TEST(UnsortedSegmentSumModelTest, Int32Test_Simple) { + UnsortedSegmentSumModel<int32_t> model({TensorType_INT32, {7}}, + {TensorType_INT32, {7}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<int32_t>(model.data(), {5, 1, 7, 2, 3, 4, 10}); + model.PopulateTensor<int32_t>(model.segment_ids(), {0, 0, 1, 1, 0, 1, 0}); + model.PopulateTensor<int32_t>(model.num_segments(), {2}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({19, 13})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2})); +} + +TEST(UnsortedSegmentSumModelTest, Int32Test_Simple2D) { + UnsortedSegmentSumModel<int32_t> model({TensorType_INT32, {3, 4}}, + {TensorType_INT32, {3}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<int32_t>(model.data(), + {1, 2, 3, 4, 5, 6, 7, 8, 4, 3, 2, 1}); + model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1, 0}); + model.PopulateTensor<int32_t>(model.num_segments(), {2}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 6, 7, 8})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4})); +} + +TEST(UnsortedSegmentSumModelTest, FloatTest_Simple) { + UnsortedSegmentSumModel<float> model({TensorType_FLOAT32, {6}}, + {TensorType_INT32, {6}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<float>(model.data(), {1.0, 2.0, 3.0, 4.0, 4.0, 3.0}); + model.PopulateTensor<int32_t>(model.segment_ids(), {1, 0, 1, 7, 7, 7}); + model.PopulateTensor<int32_t>(model.num_segments(), {8}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray( + ArrayFloatNear({2.0, 4.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 11.0}))); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({8})); +} + +TEST(UnsortedSegmentSumModelTest, FloatTest_Simple2D) { + UnsortedSegmentSumModel<float> model({TensorType_FLOAT32, {3, 4}}, + {TensorType_INT32, {3}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<float>(model.data(), {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, + 8.0, 4.0, 3.0, 2.0, 1.0}); + model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1, 0}); + model.PopulateTensor<int32_t>(model.num_segments(), {2}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray( + ArrayFloatNear({5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 7.0, 8.0}))); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4})); +} + +TEST(UnsortedSegmentSumModelTest, AllNegativeSegmentIdsZeroTensor) { + UnsortedSegmentSumModel<int32_t> model({TensorType_INT32, {2, 2}}, + {TensorType_INT32, {2}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4}); + model.PopulateTensor<int32_t>(model.segment_ids(), {-1, -1}); + model.PopulateTensor<int32_t>(model.num_segments(), {1}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2})); +} + +TEST(UnsortedSegmentSumModelTest, SomeNegativeSegmentIdsIgnored) { + UnsortedSegmentSumModel<int32_t> model({TensorType_INT32, {4}}, + {TensorType_INT32, {4}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4}); + model.PopulateTensor<int32_t>(model.segment_ids(), {-1, 0, -1, 1}); + model.PopulateTensor<int32_t>(model.num_segments(), {2}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({2, 4})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2})); +} + +TEST(UnsortedSegmentSumModelTest, + NumSegmentsGreaterThanNumIdsPadsWithZeroTensors) { + UnsortedSegmentSumModel<int32_t> model({TensorType_INT32, {2, 2}}, + {TensorType_INT32, 
{2}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4}); + model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1}); + model.PopulateTensor<int32_t>(model.num_segments(), {3}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 0, 0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 2})); +} +} // namespace +} // namespace tflite
\ No newline at end of file diff --git a/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_test.cc b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_test.cc new file mode 100644 index 00000000..6e398374 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_test.cc @@ -0,0 +1,130 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "third_party/tensorflow/lite/kernels/unsorted_segment_test.h" + +#include <limits.h> +#include <stdint.h> + +#include <initializer_list> +#include <vector> + +#include "testing/base/public/gunit.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { +namespace { + +TEST_P(UnsortedSegmentTest, SegmentIdsSizeNotEqualToDataFirstDimensionFails) { + UnsortedSegmentModel<int32_t> model = + getModel({TensorType_INT32, {3, 2}}, {TensorType_INT32, {2}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6}); + model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1}); + model.PopulateTensor<int32_t>(model.num_segments(), {2}); + ASSERT_EQ(model.Invoke(), kTfLiteError); +} +TEST_P(UnsortedSegmentTest, + LargestSegmentIdPlusOneGreaterThanNumSegmentsFails) { + UnsortedSegmentModel<int32_t> model = + getModel({TensorType_INT32, {2, 2}}, {TensorType_INT32, {2}}, + 
{TensorType_INT32, {1}}); + model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4}); + model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1}); + model.PopulateTensor<int32_t>(model.num_segments(), {1}); + ASSERT_EQ(model.Invoke(), kTfLiteError); +} +TEST_P(UnsortedSegmentTest, NumSegmentsNotScalarShapeFails) { + UnsortedSegmentModel<int32_t> model = + getModel({TensorType_INT32, {3, 2}}, {TensorType_INT32, {3}}, + {TensorType_INT32, {2}}); + model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6}); + model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1, 0}); + model.PopulateTensor<int32_t>(model.num_segments(), {2, 1}); + ASSERT_EQ(model.Invoke(), kTfLiteError); +} +TEST_P(UnsortedSegmentTest, Rank2SegIdsNotPrefixFails) { + UnsortedSegmentModel<int32_t> model = + getModel({TensorType_INT32, {2, 2, 2}}, {TensorType_INT32, {2, 1}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6}); + model.PopulateTensor<int32_t>(model.segment_ids(), {1, 1}); + model.PopulateTensor<int32_t>(model.num_segments(), {3}); + ASSERT_EQ(model.Invoke(), kTfLiteError); +} +TEST_P(UnsortedSegmentTest, Rank2SegIdsHasShapeNumSegDataShapeSuffix) { + UnsortedSegmentModel<int32_t> model = + getModel({TensorType_INT32, {2, 2, 2}}, {TensorType_INT32, {2, 2}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6}); + model.PopulateTensor<int32_t>(model.segment_ids(), {1, 2, 0, 8}); + model.PopulateTensor<int32_t>(model.num_segments(), {10}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({10, 2})); +} +TEST_P(UnsortedSegmentTest, Rank2SegIdsHasShapeNumSegDataShapeSuffixConst) { + UnsortedSegmentModel<int32_t> model = getConstModel( + {TensorType_INT32, {2, 2, 2}}, {1, 2, -1, -1}, {2, 2}, {3}, {1}); + model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutputShape(), 
testing::ElementsAreArray({3, 2})); +} +TEST_P(UnsortedSegmentTest, SegIdsHasSameShapeAsData2d) { + UnsortedSegmentModel<int32_t> model = + getModel({TensorType_INT32, {2, 2}}, {TensorType_INT32, {2, 2}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4}); + model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1, 5, 2, 4}); + model.PopulateTensor<int32_t>(model.num_segments(), {10}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({10})); +} +TEST_P(UnsortedSegmentTest, SegIdsHasSameShapeAsData2dConst) { + UnsortedSegmentModel<int32_t> model = + getConstModel({TensorType_INT32, {2, 2}}, {1, 1, 1, 1}, {2, 2}, {3}, {1}); + model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({3})); +} +TEST_P(UnsortedSegmentTest, SegIdsHasSameShapeAsData3d) { + UnsortedSegmentModel<int32_t> model = + getModel({TensorType_INT32, {2, 2, 2}}, {TensorType_INT32, {2, 2, 2}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6, 7, 8}); + model.PopulateTensor<int32_t>(model.segment_ids(), {1, 2, 3, 4, 5, 6, 7, 8}); + model.PopulateTensor<int32_t>(model.num_segments(), {10}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({10})); +} +TEST_P(UnsortedSegmentTest, SegIdsHasSameShapeAsData3dConst) { + UnsortedSegmentModel<int32_t> model = + getConstModel({TensorType_INT32, {2, 2, 2}}, {0, 1, 2, -1, 3, -1, 4, -1}, + {2, 2, 2}, {8}, {1}); + model.PopulateTensor<int32_t>(model.data(), {1, 1, 1, 1, 1, 1}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({8})); +} +TEST_P(UnsortedSegmentTest, Data5dHasShapeNumSegDataShapeSuffix) { + UnsortedSegmentModel<int32_t> model = + getModel({TensorType_INT32, {2, 1, 2, 1, 2}}, {TensorType_INT32, {2, 1}}, + 
{TensorType_INT32, {1}}); + model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6, 7, 8}); + model.PopulateTensor(model.segment_ids(), {0, 1}); + model.PopulateTensor(model.num_segments(), {10}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({10, 2, 1, 2})); +} +} // namespace +} // namespace tflite
\ No newline at end of file diff --git a/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_test.h b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_test.h new file mode 100644 index 00000000..6a5b0f64 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_test.h @@ -0,0 +1,97 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_UNSORTED_SEGMENT_TEST_H_ +#define TENSORFLOW_LITE_KERNELS_UNSORTED_SEGMENT_TEST_H_ + +#include <limits.h> +#include <stdint.h> + +#include <initializer_list> +#include <iostream> +#include <ostream> +#include <vector> + +#include "testing/base/public/gunit.h" +#include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { +template <typename T> +class UnsortedSegmentModel : public SingleOpModel { + public: + UnsortedSegmentModel(const TensorData& data, const TensorData& segment_ids, + const TensorData& num_segments, const BuiltinOperator op, + const BuiltinOptions options) { + data_id_ = AddInput(data); + segment_ids_id_ = AddInput(segment_ids); + num_segments_id_ = AddInput(num_segments); + output_id_ = AddOutput(data.type); + SetBuiltinOp(op, options, 0); + BuildInterpreter({GetShape(data_id_), GetShape(segment_ids_id_), + 
GetShape(num_segments_id_)}); + } + + explicit UnsortedSegmentModel( + const TensorData& data, const std::initializer_list<int>& segment_id_data, + const std::initializer_list<int>& segment_id_shape, + const std::initializer_list<int>& num_segments_data, + const std::initializer_list<int>& num_segments_shape, + const BuiltinOperator op, const BuiltinOptions options) { + data_id_ = AddInput(data); + segment_ids_id_ = + AddConstInput(TensorType_INT32, segment_id_data, segment_id_shape); + num_segments_id_ = + AddConstInput(TensorType_INT32, num_segments_data, num_segments_shape); + output_id_ = AddOutput(data.type); + SetBuiltinOp(op, options, 0); + BuildInterpreter({GetShape(data_id_), GetShape(segment_ids_id_), + GetShape(num_segments_id_)}); + } + + int data() const { return data_id_; } + int segment_ids() const { return segment_ids_id_; } + int num_segments() const { return num_segments_id_; } + std::vector<T> GetOutput() { return ExtractVector<T>(output_id_); } + std::vector<int32_t> GetOutputShape() { return GetTensorShape(output_id_); } + + protected: + int data_id_; + int segment_ids_id_; + int num_segments_id_; + int output_id_; +}; + +class UnsortedSegmentTest : public ::testing::TestWithParam<BuiltinOperator> { + public: + UnsortedSegmentModel<int32_t> getModel(const TensorData& data, + const TensorData& segment_ids, + const TensorData& num_segments) { + return UnsortedSegmentModel<int32_t>(data, segment_ids, num_segments, + GetParam(), BuiltinOptions_NONE); + } + UnsortedSegmentModel<int32_t> getConstModel( + const TensorData& data, const std::initializer_list<int>& segment_id_data, + const std::initializer_list<int>& segment_id_shape, + const std::initializer_list<int>& num_segments_data, + const std::initializer_list<int>& num_segments_shape) { + return UnsortedSegmentModel<int32_t>( + data, segment_id_data, segment_id_shape, num_segments_data, + num_segments_shape, GetParam(), BuiltinOptions_NONE); + } +}; + +} // namespace tflite +#endif // 
TENSORFLOW_LITE_KERNELS_UNSORTED_SEGMENT_TEST_H_
\ No newline at end of file |