diff options
author | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2023-04-25 01:10:58 +0000 |
---|---|---|
committer | Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com> | 2023-04-25 01:10:58 +0000 |
commit | 306515176b8b0b18615e33ef3418cdab620655ea (patch) | |
tree | bc2d1560d1cf9dfb9d8c113d6578dec84a456656 /tensorflow_lite_support/custom_ops/kernel/unsorted_segment_prod_test.cc | |
parent | e74873ae528c6bae607b41bc4d0f22d6a3392a71 (diff) | |
parent | cfed19f460a47f84949afa32b37965658588091d (diff) | |
download | tflite-support-306515176b8b0b18615e33ef3418cdab620655ea.tar.gz |
Snap for 9997652 from 70c43f65ce1a3ded41663501daaf6361bb644ae1 to udc-release am: cfed19f460android14-gsi
Original change: https://googleplex-android-review.googlesource.com/c/platform/external/tflite-support/+/22825883
Change-Id: Ic933f00e1ca2b5b1b929829e7fa807e3dbc33667
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
Diffstat (limited to 'tensorflow_lite_support/custom_ops/kernel/unsorted_segment_prod_test.cc')
-rw-r--r-- | tensorflow_lite_support/custom_ops/kernel/unsorted_segment_prod_test.cc | 122 |
1 files changed, 122 insertions, 0 deletions
diff --git a/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_prod_test.cc b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_prod_test.cc new file mode 100644 index 00000000..524d45a0 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/unsorted_segment_prod_test.cc @@ -0,0 +1,122 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <stdint.h> + +#include <vector> + +#include "testing/base/public/gunit.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/kernels/unsorted_segment_test.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +template <typename T> +class UnsortedSegmentProdModel : public UnsortedSegmentModel<T> { + public: + UnsortedSegmentProdModel(const TensorData& data, + const TensorData& segment_ids, + const TensorData& num_segments) + : UnsortedSegmentModel<T>(data, segment_ids, num_segments, + BuiltinOperator_UNSORTED_SEGMENT_PROD, + BuiltinOptions_UnsortedSegmentProdOptions) {} + + explicit UnsortedSegmentProdModel( + const TensorData& data, const std::initializer_list<int>& segment_id_data, + const std::initializer_list<int>& segment_id_shape, + const std::initializer_list<int>& num_segments_data, + const std::initializer_list<int>& num_segments_shape) + : UnsortedSegmentModel<T>(data, segment_id_data, segment_id_shape, + num_segments_data, num_segments_shape, + BuiltinOperator_UNSORTED_SEGMENT_PROD, + BuiltinOptions_UnsortedSegmentProdOptions) {} +}; + +INSTANTIATE_TEST_SUITE_P( + UnsortedSegmentProdTestP, UnsortedSegmentTest, + testing::Values(BuiltinOperator_UNSORTED_SEGMENT_PROD)); + +TEST(UnsortedSegmentProdModelTest, Int32Test_Simple) { + UnsortedSegmentProdModel<int32_t> model({TensorType_INT32, {8}}, + {TensorType_INT32, {8}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 4, 3, 2, 1}); + model.PopulateTensor<int32_t>(model.segment_ids(), {1, 0, 1, 7, 7, 7, 7, 7}); + model.PopulateTensor<int32_t>(model.num_segments(), {8}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({2, 3, 1, 1, 1, 1, 1, 96})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({8})); +} + +TEST(UnsortedSegmentProdModelTest, TestSkipNegSegmentId) { + UnsortedSegmentProdModel<int32_t> model({TensorType_INT32, {8}}, + {TensorType_INT32, {8}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 4, 3, 2, 1}); + model.PopulateTensor<int32_t>(model.segment_ids(), {1, 0, 1, 7, 7, 7, 7, -1}); + model.PopulateTensor<int32_t>(model.num_segments(), {8}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({2, 3, 1, 1, 1, 1, 1, 96})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({8})); +} + +TEST(UnsortedSegmentProdModelTest, Int32Test_Simple2D) { + UnsortedSegmentProdModel<int32_t> model({TensorType_INT32, {3, 4}}, + {TensorType_INT32, {3}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<int32_t>(model.data(), + {1, 2, 3, 4, 5, 6, 7, 8, 4, 3, 2, 1}); + model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1, 0}); + model.PopulateTensor<int32_t>(model.num_segments(), {2}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 6, 6, 4, 5, 6, 7, 8})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4})); +} + +TEST(UnsortedSegmentProdModelTest, FloatTest_Simple) { + UnsortedSegmentProdModel<float> model({TensorType_FLOAT32, {8}}, + {TensorType_INT32, {8}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<float>(model.data(), + {1.0, 2.0, 3.0, 4.0, 4.0, 3.0, 2.0, 1.0}); + model.PopulateTensor<int32_t>(model.segment_ids(), {1, 0, 1, 7, 7, 7, 7, 7}); + model.PopulateTensor<int32_t>(model.num_segments(), {8}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray( + ArrayFloatNear({2.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 96.0}))); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({8})); +} + +TEST(UnsortedSegmentProdModelTest, FloatTest_Simple2D) { + UnsortedSegmentProdModel<float> model({TensorType_FLOAT32, {3, 4}}, + {TensorType_INT32, {3}}, + {TensorType_INT32, {1}}); + model.PopulateTensor<float>(model.data(), {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, + 8.0, 4.0, 3.0, 2.0, 1.0}); + model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1, 0}); + model.PopulateTensor<int32_t>(model.num_segments(), {2}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray( + ArrayFloatNear({4.0, 6.0, 6.0, 4.0, 5.0, 6.0, 7.0, 8.0}))); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4})); +} + +} // namespace +} // namespace tflite
\ No newline at end of file |