aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorfban <fban@google.com>2022-09-07 19:00:51 +0000
committerFrank Ban <fban@google.com>2022-09-08 18:38:22 +0000
commit589eb63fba6c01c7c7cf3cbf601f60d7b87334df (patch)
tree7ac824f9d3359e90590c4b4bf75032854af51fb5
parent6d4f359d96b0efb8626d5d7cdf8f49ab7aa0274d (diff)
downloadtflite-support-589eb63fba6c01c7c7cf3cbf601f60d7b87334df.tar.gz
Adds a dynamic input tensor model to the BertNLClassifier Java test.
The model is provided by the Android Rubidium team (G3 Path: google3/wireless/android/adservices/mdd/topics_classifier/2/model.tflite). The hope is that future changes to external/tflite-support can be sanity checked by this unit test before impacting the Rubidium team's repo. Bug: 241507692 Test: atest TfliteSupportClassifierTests Change-Id: I371c4ce822a4650db37083e16e7d6e2fba44c17c
-rw-r--r--tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java33
-rw-r--r--tensorflow_lite_support/java/src/javatests/testdata/task/text/rb_model.tflitebin0 -> 5123808 bytes
2 files changed, 33 insertions, 0 deletions
diff --git a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java
index 0a609fd2..8c71f705 100644
--- a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java
+++ b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java
@@ -29,6 +29,8 @@ import org.tensorflow.lite.task.core.TestUtils;
/** Test for {@link BertNLClassifier}. */
public class BertNLClassifierTest {
private static final String MODEL_FILE = "bert_nl_classifier.tflite";
+ // A classifier model with dynamic input tensors. Provided by the Android Rubidium team.
+ private static final String DYNAMIC_INPUT_MODEL_FILE = "rb_model.tflite";
Category findCategoryWithLabel(List<Category> list, String label) {
return list.stream()
@@ -68,6 +70,15 @@ public class BertNLClassifierTest {
}
@Test
+ public void classify_succeedsWithDynamicInputModelBuffer() throws IOException {
+ verifyDynamicInputResults(
+ BertNLClassifier.createFromBuffer(
+ TestUtils.loadToDirectByteBuffer(
+ ApplicationProvider.getApplicationContext(),
+ DYNAMIC_INPUT_MODEL_FILE)));
+ }
+
+ @Test
public void getModelVersion_succeedsWithVersionInMetadata() throws IOException {
BertNLClassifier classifier = BertNLClassifier.createFromFile(
ApplicationProvider.getApplicationContext(), MODEL_FILE);
@@ -76,6 +87,14 @@ public class BertNLClassifierTest {
}
@Test
+ public void getModelVersion_succeedsWithDynamicInputModelVersion() throws IOException {
+ BertNLClassifier classifier = BertNLClassifier.createFromFile(
+ ApplicationProvider.getApplicationContext(), DYNAMIC_INPUT_MODEL_FILE);
+
+ assertThat(classifier.getModelVersion()).isEqualTo("2");
+ }
+
+ @Test
public void getLabelsVersion_succeedsWithNoVersionInMetadata() throws IOException {
BertNLClassifier classifier = BertNLClassifier.createFromFile(
ApplicationProvider.getApplicationContext(), MODEL_FILE);
@@ -83,6 +102,14 @@ public class BertNLClassifierTest {
assertThat(classifier.getLabelsVersion()).isEqualTo("NO_VERSION_INFO");
}
+ @Test
+ public void getLabelsVersion_succeedsWithDynamicInputLabelsVersion() throws IOException {
+ BertNLClassifier classifier = BertNLClassifier.createFromFile(
+ ApplicationProvider.getApplicationContext(), DYNAMIC_INPUT_MODEL_FILE);
+
+ assertThat(classifier.getLabelsVersion()).isEqualTo("2");
+ }
+
private void verifyResults(BertNLClassifier classifier) {
List<Category> negativeResults = classifier.classify("unflinchingly bleak and desperate");
assertThat(findCategoryWithLabel(negativeResults, "negative").getScore())
@@ -93,4 +120,10 @@ public class BertNLClassifierTest {
assertThat(findCategoryWithLabel(positiveResults, "positive").getScore())
.isGreaterThan(findCategoryWithLabel(positiveResults, "negative").getScore());
}
+
+ private void verifyDynamicInputResults(BertNLClassifier classifier) {
+ List<Category> topics = classifier.classify("FooBarBaz");
+ assertThat(topics.size()).isEqualTo(446);
+ // TODO(ag/19888344): Add a test for a long text input.
+ }
}
diff --git a/tensorflow_lite_support/java/src/javatests/testdata/task/text/rb_model.tflite b/tensorflow_lite_support/java/src/javatests/testdata/task/text/rb_model.tflite
new file mode 100644
index 00000000..56fe4703
--- /dev/null
+++ b/tensorflow_lite_support/java/src/javatests/testdata/task/text/rb_model.tflite
Binary files differ