diff options
author | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2022-10-28 08:54:00 +0000 |
---|---|---|
committer | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2022-10-28 08:54:00 +0000 |
commit | bef7369682efa988264ef60a4c30555d822e3a62 (patch) | |
tree | f17cc8ccec5e44b6ad8a19967d14c4b5299182f2 | |
parent | 1eb3362a578b834284836c47d73eb466b813efaa (diff) | |
parent | 0320ce8084721a759e7d2e6965d17540eeec8c36 (diff) | |
download | tflite-support-bef7369682efa988264ef60a4c30555d822e3a62.tar.gz |
Snap for 9229821 from 0320ce8084721a759e7d2e6965d17540eeec8c36 to mainline-networking-releaseaml_net_331313030aml_net_331313010
Change-Id: Idefb4b9fdec32028ade7bd505487185b461c4788
9 files changed, 99 insertions, 18 deletions
@@ -221,6 +221,8 @@ cc_library_shared { "tensorflow_lite_support/java/src/native/task/core/builtin_op_resolver.cc", "tensorflow_lite_support/cc/utils/jni_utils.cc", ], + // TODO(b/247088924): Use linker_scripts here. + version_script: "tensorflow_lite_support/java/tflite_version_script.lds", shared_libs: ["liblog"], static_libs: [ "libprotobuf-cpp-lite-ndk", diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc index a246066b..88a8a0af 100644 --- a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc @@ -245,6 +245,15 @@ absl::Status BertNLClassifier::InitializeFromMetadata() { segment_ids_tensor.dims->data[1]), TfLiteSupportStatus::kInvalidInputTensorSizeError); } + + // If some tensor does not have a size 2 dims_signature, then we + // assume the input is not dynamic. + if (ids_tensor.dims_signature->size != 2 || + mask_tensor.dims_signature->size != 2 || + segment_ids_tensor.dims_signature->size != 2) { + return absl::OkStatus(); + } + if (ids_tensor.dims_signature->data[1] == -1 && mask_tensor.dims_signature->data[1] == -1 && segment_ids_tensor.dims_signature->data[1] == -1) { diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc index a75fe0ff..d322a151 100644 --- a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc @@ -185,9 +185,16 @@ absl::Status NLClassifier::TrySetLabelFromMetadata( } std::vector<Category> NLClassifier::Classify(const std::string& text) { - // The NLClassifier implementation for Preprocess() and Postprocess() never - // returns errors: just call value(). - return Infer(text).value(); + StatusOr<std::vector<Category>> infer_result = ClassifyText(text); + if (!infer_result.ok()) { + return {}; + } + return infer_result.value(); +} + +StatusOr<std::vector<Category>> NLClassifier::ClassifyText( + const std::string& text) { + return Infer(text); } std::string NLClassifier::GetModelVersion() const { diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h index 013b7d53..8e70ba4c 100644 --- a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h @@ -109,9 +109,14 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>, std::unique_ptr<tflite::OpResolver> resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); - // Performs classification on a string input, returns classified results. + // DEPRECATED (unannotated for backward compatibility). Prefer using `ClassifyText`. std::vector<core::Category> Classify(const std::string& text); + // Performs classification on a string input, returns classified results or an + // error. + tflite::support::StatusOr<std::vector<core::Category>> ClassifyText( + const std::string& text); + // Gets the model version, or "NO_VERSION_INFO" in case there is no version. std::string GetModelVersion() const; diff --git a/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc b/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc index 85ebc505..23a3ddca 100644 --- a/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc +++ b/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc @@ -85,15 +85,19 @@ Category* GetCategoryWithClassName(const std::string& class_name, void verify_classifier(std::unique_ptr<BertNLClassifier> classifier, bool verify_positive) { if (verify_positive) { - std::vector<core::Category> results = - classifier->Classify("unflinchingly bleak and desperate"); - EXPECT_GT(GetCategoryWithClassName("negative", results)->score, - GetCategoryWithClassName("positive", results)->score); + tflite::support::StatusOr<std::vector<core::Category>> results = + classifier->ClassifyText("unflinchingly bleak and desperate"); + + EXPECT_TRUE(results.ok()); + EXPECT_GT(GetCategoryWithClassName("negative", results.value())->score, + GetCategoryWithClassName("positive", results.value())->score); } else { - std::vector<Category> results = - classifier->Classify("it's a charming and often affecting journey"); - EXPECT_GT(GetCategoryWithClassName("positive", results)->score, - GetCategoryWithClassName("negative", results)->score); + tflite::support::StatusOr<std::vector<core::Category>> results = + classifier->ClassifyText("it's a charming and often affecting journey"); + + EXPECT_TRUE(results.ok()); + EXPECT_GT(GetCategoryWithClassName("positive", results.value())->score, + GetCategoryWithClassName("negative", results.value())->score); } } @@ -151,11 +155,12 @@ TEST(BertNLClassifierTest, TestNLClassifier_ClassifyLongPositive_notOOB) { BertNLClassifier::CreateFromBuffer(model_buffer.data(), model_buffer.size()); EXPECT_TRUE(classifier.ok()); - std::vector<core::Category> results = - classifier.value()->Classify(ss_for_positive_review.str()); + tflite::support::StatusOr<std::vector<core::Category>> results = + classifier.value()->ClassifyText(ss_for_positive_review.str()); - EXPECT_GT(GetCategoryWithClassName("positive", results)->score, - GetCategoryWithClassName("negative", results)->score); + EXPECT_TRUE(results.ok()); + EXPECT_GT(GetCategoryWithClassName("positive", results.value())->score, + GetCategoryWithClassName("negative", results.value())->score); } } // namespace diff --git a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java index 0a609fd2..8c71f705 100644 --- a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java +++ b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java @@ -29,6 +29,8 @@ import org.tensorflow.lite.task.core.TestUtils; /** Test for {@link BertNLClassifier}. */ public class BertNLClassifierTest { private static final String MODEL_FILE = "bert_nl_classifier.tflite"; + // A classifier model with dynamic input tensors. Provided by the Android Rubidium team. + private static final String DYNAMIC_INPUT_MODEL_FILE = "rb_model.tflite"; Category findCategoryWithLabel(List<Category> list, String label) { return list.stream() @@ -68,6 +70,15 @@ public class BertNLClassifierTest { } @Test + public void classify_succeedsWithDynamicInputModelBuffer() throws IOException { + verifyDynamicInputResults( + BertNLClassifier.createFromBuffer( + TestUtils.loadToDirectByteBuffer( + ApplicationProvider.getApplicationContext(), + DYNAMIC_INPUT_MODEL_FILE))); + } + + @Test public void getModelVersion_succeedsWithVersionInMetadata() throws IOException { BertNLClassifier classifier = BertNLClassifier.createFromFile( ApplicationProvider.getApplicationContext(), MODEL_FILE); @@ -76,6 +87,14 @@ public class BertNLClassifierTest { } @Test + public void getModelVersion_succeedsWithDynamicInputModelVersion() throws IOException { + BertNLClassifier classifier = BertNLClassifier.createFromFile( + ApplicationProvider.getApplicationContext(), DYNAMIC_INPUT_MODEL_FILE); + + assertThat(classifier.getModelVersion()).isEqualTo("2"); + } + + @Test public void getLabelsVersion_succeedsWithNoVersionInMetadata() throws IOException { BertNLClassifier classifier = BertNLClassifier.createFromFile( ApplicationProvider.getApplicationContext(), MODEL_FILE); @@ -83,6 +102,14 @@ public class BertNLClassifierTest { assertThat(classifier.getLabelsVersion()).isEqualTo("NO_VERSION_INFO"); } + @Test + public void getLabelsVersion_succeedsWithDynamicInputLabelsVersion() throws IOException { + BertNLClassifier classifier = BertNLClassifier.createFromFile( + ApplicationProvider.getApplicationContext(), DYNAMIC_INPUT_MODEL_FILE); + + assertThat(classifier.getLabelsVersion()).isEqualTo("2"); + } + private void verifyResults(BertNLClassifier classifier) { List<Category> negativeResults = classifier.classify("unflinchingly bleak and desperate"); assertThat(findCategoryWithLabel(negativeResults, "negative").getScore()) @@ -93,4 +120,10 @@ public class BertNLClassifierTest { assertThat(findCategoryWithLabel(positiveResults, "positive").getScore()) .isGreaterThan(findCategoryWithLabel(positiveResults, "negative").getScore()); } + + private void verifyDynamicInputResults(BertNLClassifier classifier) { + List<Category> topics = classifier.classify("FooBarBaz"); + assertThat(topics.size()).isEqualTo(446); + // TODO(ag/19888344): Add a test for a long text input. + } } diff --git a/tensorflow_lite_support/java/src/javatests/testdata/task/text/rb_model.tflite b/tensorflow_lite_support/java/src/javatests/testdata/task/text/rb_model.tflite Binary files differnew file mode 100644 index 00000000..56fe4703 --- /dev/null +++ b/tensorflow_lite_support/java/src/javatests/testdata/task/text/rb_model.tflite diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc index 6b002f71..d2f6e7ca 100644 --- a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc +++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc @@ -23,6 +23,9 @@ namespace task { namespace text { namespace nlclassifier { +using ::tflite::support::utils::kAssertionError; +using ::tflite::support::utils::kInvalidPointer; +using ::tflite::support::utils::ThrowException; using ::tflite::support::utils::ConvertVectorToArrayList; using ::tflite::support::utils::JStringToString; using ::tflite::task::core::Category; @@ -31,14 +34,21 @@ using ::tflite::task::text::nlclassifier::NLClassifier; jobject RunClassifier(JNIEnv* env, jlong native_handle, jstring text) { auto* nl_classifier = reinterpret_cast<NLClassifier*>(native_handle); - auto results = nl_classifier->Classify(JStringToString(env, text)); + auto results = nl_classifier->ClassifyText(JStringToString(env, text)); + if (!results.ok()) { + ThrowException(env, kAssertionError, + "Error occurred when running classifier: %s", + results.status().message().data()); + return env->ExceptionOccurred(); + } + jclass category_class = env->FindClass("org/tensorflow/lite/support/label/Category"); jmethodID category_init = env->GetMethodID(category_class, "<init>", "(Ljava/lang/String;F)V"); return ConvertVectorToArrayList<Category>( - env, results, + env, results.value(), [env, category_class, category_init](const Category& category) { jstring class_name = env->NewStringUTF(category.class_name.data()); // Convert double to float as Java interface exposes float as scores. diff --git a/tensorflow_lite_support/java/tflite_version_script.lds b/tensorflow_lite_support/java/tflite_version_script.lds new file mode 100644 index 00000000..604c923a --- /dev/null +++ b/tensorflow_lite_support/java/tflite_version_script.lds @@ -0,0 +1,10 @@ +VERS_1.0 { + # Export JNI and native C symbols. + global: + Java_*; + + # Hide everything else. + local: + *; +}; + |