aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2022-10-28 08:54:00 +0000
committerAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2022-10-28 08:54:00 +0000
commitbef7369682efa988264ef60a4c30555d822e3a62 (patch)
treef17cc8ccec5e44b6ad8a19967d14c4b5299182f2
parent1eb3362a578b834284836c47d73eb466b813efaa (diff)
parent0320ce8084721a759e7d2e6965d17540eeec8c36 (diff)
downloadtflite-support-bef7369682efa988264ef60a4c30555d822e3a62.tar.gz
Snap for 9229821 from 0320ce8084721a759e7d2e6965d17540eeec8c36 to mainline-networking-releaseaml_net_331313030aml_net_331313010
Change-Id: Idefb4b9fdec32028ade7bd505487185b461c4788
-rw-r--r--Android.bp2
-rw-r--r--tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc9
-rw-r--r--tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc13
-rw-r--r--tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h7
-rw-r--r--tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc29
-rw-r--r--tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java33
-rw-r--r--tensorflow_lite_support/java/src/javatests/testdata/task/text/rb_model.tflitebin0 -> 5123808 bytes
-rw-r--r--tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc14
-rw-r--r--tensorflow_lite_support/java/tflite_version_script.lds10
9 files changed, 99 insertions, 18 deletions
diff --git a/Android.bp b/Android.bp
index de4f3ddd..598623f1 100644
--- a/Android.bp
+++ b/Android.bp
@@ -221,6 +221,8 @@ cc_library_shared {
"tensorflow_lite_support/java/src/native/task/core/builtin_op_resolver.cc",
"tensorflow_lite_support/cc/utils/jni_utils.cc",
],
+ // TODO(b/247088924): Use linker_scripts here.
+ version_script: "tensorflow_lite_support/java/tflite_version_script.lds",
shared_libs: ["liblog"],
static_libs: [
"libprotobuf-cpp-lite-ndk",
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc
index a246066b..88a8a0af 100644
--- a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc
@@ -245,6 +245,15 @@ absl::Status BertNLClassifier::InitializeFromMetadata() {
segment_ids_tensor.dims->data[1]),
TfLiteSupportStatus::kInvalidInputTensorSizeError);
}
+
+ // If some tensor does not have a size 2 dims_signature, then we
+ // assume the input is not dynamic.
+ if (ids_tensor.dims_signature->size != 2 ||
+ mask_tensor.dims_signature->size != 2 ||
+ segment_ids_tensor.dims_signature->size != 2) {
+ return absl::OkStatus();
+ }
+
if (ids_tensor.dims_signature->data[1] == -1 &&
mask_tensor.dims_signature->data[1] == -1 &&
segment_ids_tensor.dims_signature->data[1] == -1) {
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
index a75fe0ff..d322a151 100644
--- a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
@@ -185,9 +185,16 @@ absl::Status NLClassifier::TrySetLabelFromMetadata(
}
std::vector<Category> NLClassifier::Classify(const std::string& text) {
- // The NLClassifier implementation for Preprocess() and Postprocess() never
- // returns errors: just call value().
- return Infer(text).value();
+ StatusOr<std::vector<Category>> infer_result = ClassifyText(text);
+ if (!infer_result.ok()) {
+ return {};
+ }
+ return infer_result.value();
+}
+
+StatusOr<std::vector<Category>> NLClassifier::ClassifyText(
+ const std::string& text) {
+ return Infer(text);
}
std::string NLClassifier::GetModelVersion() const {
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
index 013b7d53..8e70ba4c 100644
--- a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
@@ -109,9 +109,14 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
std::unique_ptr<tflite::OpResolver> resolver =
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
- // Performs classification on a string input, returns classified results.
+ // DEPRECATED (unannotated for backward compatibility). Prefer using `ClassifyText`.
std::vector<core::Category> Classify(const std::string& text);
+ // Performs classification on a string input, returns classified results or an
+ // error.
+ tflite::support::StatusOr<std::vector<core::Category>> ClassifyText(
+ const std::string& text);
+
// Gets the model version, or "NO_VERSION_INFO" in case there is no version.
std::string GetModelVersion() const;
diff --git a/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc b/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc
index 85ebc505..23a3ddca 100644
--- a/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc
+++ b/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc
@@ -85,15 +85,19 @@ Category* GetCategoryWithClassName(const std::string& class_name,
void verify_classifier(std::unique_ptr<BertNLClassifier> classifier,
bool verify_positive) {
if (verify_positive) {
- std::vector<core::Category> results =
- classifier->Classify("unflinchingly bleak and desperate");
- EXPECT_GT(GetCategoryWithClassName("negative", results)->score,
- GetCategoryWithClassName("positive", results)->score);
+ tflite::support::StatusOr<std::vector<core::Category>> results =
+ classifier->ClassifyText("unflinchingly bleak and desperate");
+
+ EXPECT_TRUE(results.ok());
+ EXPECT_GT(GetCategoryWithClassName("negative", results.value())->score,
+ GetCategoryWithClassName("positive", results.value())->score);
} else {
- std::vector<Category> results =
- classifier->Classify("it's a charming and often affecting journey");
- EXPECT_GT(GetCategoryWithClassName("positive", results)->score,
- GetCategoryWithClassName("negative", results)->score);
+ tflite::support::StatusOr<std::vector<core::Category>> results =
+ classifier->ClassifyText("it's a charming and often affecting journey");
+
+ EXPECT_TRUE(results.ok());
+ EXPECT_GT(GetCategoryWithClassName("positive", results.value())->score,
+ GetCategoryWithClassName("negative", results.value())->score);
}
}
@@ -151,11 +155,12 @@ TEST(BertNLClassifierTest, TestNLClassifier_ClassifyLongPositive_notOOB) {
BertNLClassifier::CreateFromBuffer(model_buffer.data(), model_buffer.size());
EXPECT_TRUE(classifier.ok());
- std::vector<core::Category> results =
- classifier.value()->Classify(ss_for_positive_review.str());
+ tflite::support::StatusOr<std::vector<core::Category>> results =
+ classifier.value()->ClassifyText(ss_for_positive_review.str());
- EXPECT_GT(GetCategoryWithClassName("positive", results)->score,
- GetCategoryWithClassName("negative", results)->score);
+ EXPECT_TRUE(results.ok());
+ EXPECT_GT(GetCategoryWithClassName("positive", results.value())->score,
+ GetCategoryWithClassName("negative", results.value())->score);
}
} // namespace
diff --git a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java
index 0a609fd2..8c71f705 100644
--- a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java
+++ b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java
@@ -29,6 +29,8 @@ import org.tensorflow.lite.task.core.TestUtils;
/** Test for {@link BertNLClassifier}. */
public class BertNLClassifierTest {
private static final String MODEL_FILE = "bert_nl_classifier.tflite";
+ // A classifier model with dynamic input tensors. Provided by the Android Rubidium team.
+ private static final String DYNAMIC_INPUT_MODEL_FILE = "rb_model.tflite";
Category findCategoryWithLabel(List<Category> list, String label) {
return list.stream()
@@ -68,6 +70,15 @@ public class BertNLClassifierTest {
}
@Test
+ public void classify_succeedsWithDynamicInputModelBuffer() throws IOException {
+ verifyDynamicInputResults(
+ BertNLClassifier.createFromBuffer(
+ TestUtils.loadToDirectByteBuffer(
+ ApplicationProvider.getApplicationContext(),
+ DYNAMIC_INPUT_MODEL_FILE)));
+ }
+
+ @Test
public void getModelVersion_succeedsWithVersionInMetadata() throws IOException {
BertNLClassifier classifier = BertNLClassifier.createFromFile(
ApplicationProvider.getApplicationContext(), MODEL_FILE);
@@ -76,6 +87,14 @@ public class BertNLClassifierTest {
}
@Test
+ public void getModelVersion_succeedsWithDynamicInputModelVersion() throws IOException {
+ BertNLClassifier classifier = BertNLClassifier.createFromFile(
+ ApplicationProvider.getApplicationContext(), DYNAMIC_INPUT_MODEL_FILE);
+
+ assertThat(classifier.getModelVersion()).isEqualTo("2");
+ }
+
+ @Test
public void getLabelsVersion_succeedsWithNoVersionInMetadata() throws IOException {
BertNLClassifier classifier = BertNLClassifier.createFromFile(
ApplicationProvider.getApplicationContext(), MODEL_FILE);
@@ -83,6 +102,14 @@ public class BertNLClassifierTest {
assertThat(classifier.getLabelsVersion()).isEqualTo("NO_VERSION_INFO");
}
+ @Test
+ public void getLabelsVersion_succeedsWithDynamicInputLabelsVersion() throws IOException {
+ BertNLClassifier classifier = BertNLClassifier.createFromFile(
+ ApplicationProvider.getApplicationContext(), DYNAMIC_INPUT_MODEL_FILE);
+
+ assertThat(classifier.getLabelsVersion()).isEqualTo("2");
+ }
+
private void verifyResults(BertNLClassifier classifier) {
List<Category> negativeResults = classifier.classify("unflinchingly bleak and desperate");
assertThat(findCategoryWithLabel(negativeResults, "negative").getScore())
@@ -93,4 +120,10 @@ public class BertNLClassifierTest {
assertThat(findCategoryWithLabel(positiveResults, "positive").getScore())
.isGreaterThan(findCategoryWithLabel(positiveResults, "negative").getScore());
}
+
+ private void verifyDynamicInputResults(BertNLClassifier classifier) {
+ List<Category> topics = classifier.classify("FooBarBaz");
+ assertThat(topics.size()).isEqualTo(446);
+ // TODO(ag/19888344): Add a test for a long text input.
+ }
}
diff --git a/tensorflow_lite_support/java/src/javatests/testdata/task/text/rb_model.tflite b/tensorflow_lite_support/java/src/javatests/testdata/task/text/rb_model.tflite
new file mode 100644
index 00000000..56fe4703
--- /dev/null
+++ b/tensorflow_lite_support/java/src/javatests/testdata/task/text/rb_model.tflite
Binary files differ
diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc
index 6b002f71..d2f6e7ca 100644
--- a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc
+++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc
@@ -23,6 +23,9 @@ namespace task {
namespace text {
namespace nlclassifier {
+using ::tflite::support::utils::kAssertionError;
+using ::tflite::support::utils::kInvalidPointer;
+using ::tflite::support::utils::ThrowException;
using ::tflite::support::utils::ConvertVectorToArrayList;
using ::tflite::support::utils::JStringToString;
using ::tflite::task::core::Category;
@@ -31,14 +34,21 @@ using ::tflite::task::text::nlclassifier::NLClassifier;
jobject RunClassifier(JNIEnv* env, jlong native_handle, jstring text) {
auto* nl_classifier = reinterpret_cast<NLClassifier*>(native_handle);
- auto results = nl_classifier->Classify(JStringToString(env, text));
+ auto results = nl_classifier->ClassifyText(JStringToString(env, text));
+ if (!results.ok()) {
+ ThrowException(env, kAssertionError,
+ "Error occurred when running classifier: %s",
+ results.status().message().data());
+ return env->ExceptionOccurred();
+ }
+
jclass category_class =
env->FindClass("org/tensorflow/lite/support/label/Category");
jmethodID category_init =
env->GetMethodID(category_class, "<init>", "(Ljava/lang/String;F)V");
return ConvertVectorToArrayList<Category>(
- env, results,
+ env, results.value(),
[env, category_class, category_init](const Category& category) {
jstring class_name = env->NewStringUTF(category.class_name.data());
// Convert double to float as Java interface exposes float as scores.
diff --git a/tensorflow_lite_support/java/tflite_version_script.lds b/tensorflow_lite_support/java/tflite_version_script.lds
new file mode 100644
index 00000000..604c923a
--- /dev/null
+++ b/tensorflow_lite_support/java/tflite_version_script.lds
@@ -0,0 +1,10 @@
+VERS_1.0 {
+ # Export JNI and native C symbols.
+ global:
+ Java_*;
+
+ # Hide everything else.
+ local:
+ *;
+};
+