diff options
author | fban <fban@google.com> | 2022-09-07 19:00:51 +0000 |
---|---|---|
committer | Frank Ban <fban@google.com> | 2022-09-08 18:38:22 +0000 |
commit | 589eb63fba6c01c7c7cf3cbf601f60d7b87334df (patch) | |
tree | 7ac824f9d3359e90590c4b4bf75032854af51fb5 | |
parent | 6d4f359d96b0efb8626d5d7cdf8f49ab7aa0274d (diff) | |
download | tflite-support-589eb63fba6c01c7c7cf3cbf601f60d7b87334df.tar.gz |
Adds a dynamic input tensor model to the BertNLClassifier Java test.
The model is provided by the Android Rubidium team (G3 Path: google3/wireless/android/adservices/mdd/topics_classifier/2/model.tflite). The hope is that
future changes to external/tflite-support can be sanity checked by this
unit test before impacting the Rubidium team's repo.
Bug: 241507692
Test: atest TfliteSupportClassifierTests
Change-Id: I371c4ce822a4650db37083e16e7d6e2fba44c17c
2 files changed, 33 insertions, 0 deletions
diff --git a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java index 0a609fd2..8c71f705 100644 --- a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java +++ b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java @@ -29,6 +29,8 @@ import org.tensorflow.lite.task.core.TestUtils; /** Test for {@link BertNLClassifier}. */ public class BertNLClassifierTest { private static final String MODEL_FILE = "bert_nl_classifier.tflite"; + // A classifier model with dynamic input tensors. Provided by the Android Rubidium team. + private static final String DYNAMIC_INPUT_MODEL_FILE = "rb_model.tflite"; Category findCategoryWithLabel(List<Category> list, String label) { return list.stream() @@ -68,6 +70,15 @@ public class BertNLClassifierTest { } @Test + public void classify_succeedsWithDynamicInputModelBuffer() throws IOException { + verifyDynamicInputResults( + BertNLClassifier.createFromBuffer( + TestUtils.loadToDirectByteBuffer( + ApplicationProvider.getApplicationContext(), + DYNAMIC_INPUT_MODEL_FILE))); + } + + @Test public void getModelVersion_succeedsWithVersionInMetadata() throws IOException { BertNLClassifier classifier = BertNLClassifier.createFromFile( ApplicationProvider.getApplicationContext(), MODEL_FILE); @@ -76,6 +87,14 @@ public class BertNLClassifierTest { } @Test + public void getModelVersion_succeedsWithDynamicInputModelVersion() throws IOException { + BertNLClassifier classifier = BertNLClassifier.createFromFile( + ApplicationProvider.getApplicationContext(), DYNAMIC_INPUT_MODEL_FILE); + + assertThat(classifier.getModelVersion()).isEqualTo("2"); + } + + @Test public void getLabelsVersion_succeedsWithNoVersionInMetadata() throws IOException { BertNLClassifier classifier = BertNLClassifier.createFromFile( ApplicationProvider.getApplicationContext(), MODEL_FILE); @@ -83,6 +102,14 @@ public class BertNLClassifierTest { assertThat(classifier.getLabelsVersion()).isEqualTo("NO_VERSION_INFO"); } + @Test + public void getLabelsVersion_succeedsWithDynamicInputLabelsVersion() throws IOException { + BertNLClassifier classifier = BertNLClassifier.createFromFile( + ApplicationProvider.getApplicationContext(), DYNAMIC_INPUT_MODEL_FILE); + + assertThat(classifier.getLabelsVersion()).isEqualTo("2"); + } + private void verifyResults(BertNLClassifier classifier) { List<Category> negativeResults = classifier.classify("unflinchingly bleak and desperate"); assertThat(findCategoryWithLabel(negativeResults, "negative").getScore()) @@ -93,4 +120,10 @@ public class BertNLClassifierTest { assertThat(findCategoryWithLabel(positiveResults, "positive").getScore()) .isGreaterThan(findCategoryWithLabel(positiveResults, "negative").getScore()); } + + private void verifyDynamicInputResults(BertNLClassifier classifier) { + List<Category> topics = classifier.classify("FooBarBaz"); + assertThat(topics.size()).isEqualTo(446); + // TODO(ag/19888344): Add a test for a long text input. + } } diff --git a/tensorflow_lite_support/java/src/javatests/testdata/task/text/rb_model.tflite b/tensorflow_lite_support/java/src/javatests/testdata/task/text/rb_model.tflite Binary files differnew file mode 100644 index 00000000..56fe4703 --- /dev/null +++ b/tensorflow_lite_support/java/src/javatests/testdata/task/text/rb_model.tflite |