diff options
author | Ben Nissan <bennissan@google.com> | 2022-08-19 16:44:36 +0000 |
---|---|---|
committer | Ben Nissan <bennissan@google.com> | 2022-08-23 17:30:58 +0000 |
commit | 9fc1fcbd2239b6c8f6fa8399c6580d4f350e2150 (patch) | |
tree | fd7b7d30537d59b8f0a635b20c511543623f87aa | |
parent | e4266a8951b0f53a80bc54392276c355583c0a53 (diff) | |
download | tflite-support-9fc1fcbd2239b6c8f6fa8399c6580d4f350e2150.tar.gz |
Return StatusOr from Classify() methods
This CL updates the Classify() methods in NLClassifier and
BertNLClassifier to return StatusOr wrappers for their contents,
allowing errors to be propagated across the JNI boundary.
To enable this, it fixes a bug when checking for dynamic vs static input
tensors in bert_nl_classifier.cc and throws exceptions from
RunClassifier().
Bug: 242926638, Bug: 242926783
Test: atest tflite_support_classifier_tests
Test: atest TfliteSupportClassifierTests
Change-Id: Ib93385061ace97cf96a168e3d76d2d9162c8055f
5 files changed, 55 insertions, 24 deletions
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc index a246066b..4c8b3fd7 100644 --- a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc @@ -245,13 +245,17 @@ absl::Status BertNLClassifier::InitializeFromMetadata() { segment_ids_tensor.dims->data[1]), TfLiteSupportStatus::kInvalidInputTensorSizeError); } - if (ids_tensor.dims_signature->data[1] == -1 && - mask_tensor.dims_signature->data[1] == -1 && - segment_ids_tensor.dims_signature->data[1] == -1) { + + bool has_valid_dims_signature = ids_tensor.dims_signature->size == 2 && + mask_tensor.dims_signature->size == 2 && + segment_ids_tensor.dims_signature->size == 2; + int num_dynamic_tensors = ids_tensor.dims_signature->data[1] == -1 + + mask_tensor.dims_signature->data[1] == -1 + + segment_ids_tensor.dims_signature->data[1] == -1; + + if (has_valid_dims_signature && num_dynamic_tensors == 3) { input_tensors_are_dynamic_ = true; - } else if (ids_tensor.dims_signature->data[1] == -1 || - mask_tensor.dims_signature->data[1] == -1 || - segment_ids_tensor.dims_signature->data[1] == -1) { + } else if (has_valid_dims_signature && num_dynamic_tensors > 0 && num_dynamic_tensors < 3) { return CreateStatusWithPayload( absl::StatusCode::kInternal, "Input tensors contain a mix of static and dynamic tensors", diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc index a75fe0ff..d322a151 100644 --- a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc @@ -185,9 +185,16 @@ absl::Status NLClassifier::TrySetLabelFromMetadata( } std::vector<Category> NLClassifier::Classify(const std::string& text) { - // The NLClassifier implementation for Preprocess() and Postprocess() never - // returns errors: just call value(). - return Infer(text).value(); + StatusOr<std::vector<Category>> infer_result = ClassifyText(text); + if (!infer_result.ok()) { + return {}; + } + return infer_result.value(); +} + +StatusOr<std::vector<Category>> NLClassifier::ClassifyText( + const std::string& text) { + return Infer(text); } std::string NLClassifier::GetModelVersion() const { diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h index 013b7d53..8e70ba4c 100644 --- a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h @@ -109,9 +109,14 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>, std::unique_ptr<tflite::OpResolver> resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); - // Performs classification on a string input, returns classified results. + // DEPRECATED (unannotated for backward compatibility). Prefer using `ClassifyText`. std::vector<core::Category> Classify(const std::string& text); + // Performs classification on a string input, returns classified results or an + // error. + tflite::support::StatusOr<std::vector<core::Category>> ClassifyText( + const std::string& text); + // Gets the model version, or "NO_VERSION_INFO" in case there is no version. std::string GetModelVersion() const; diff --git a/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc b/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc index 85ebc505..23a3ddca 100644 --- a/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc +++ b/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc @@ -85,15 +85,19 @@ Category* GetCategoryWithClassName(const std::string& class_name, void verify_classifier(std::unique_ptr<BertNLClassifier> classifier, bool verify_positive) { if (verify_positive) { - std::vector<core::Category> results = - classifier->Classify("unflinchingly bleak and desperate"); - EXPECT_GT(GetCategoryWithClassName("negative", results)->score, - GetCategoryWithClassName("positive", results)->score); + tflite::support::StatusOr<std::vector<core::Category>> results = + classifier->ClassifyText("unflinchingly bleak and desperate"); + + EXPECT_TRUE(results.ok()); + EXPECT_GT(GetCategoryWithClassName("negative", results.value())->score, + GetCategoryWithClassName("positive", results.value())->score); } else { - std::vector<Category> results = - classifier->Classify("it's a charming and often affecting journey"); - EXPECT_GT(GetCategoryWithClassName("positive", results)->score, - GetCategoryWithClassName("negative", results)->score); + tflite::support::StatusOr<std::vector<core::Category>> results = + classifier->ClassifyText("it's a charming and often affecting journey"); + + EXPECT_TRUE(results.ok()); + EXPECT_GT(GetCategoryWithClassName("positive", results.value())->score, + GetCategoryWithClassName("negative", results.value())->score); } } @@ -151,11 +155,12 @@ TEST(BertNLClassifierTest, TestNLClassifier_ClassifyLongPositive_notOOB) { BertNLClassifier::CreateFromBuffer(model_buffer.data(), model_buffer.size()); EXPECT_TRUE(classifier.ok()); - std::vector<core::Category> results = - classifier.value()->Classify(ss_for_positive_review.str()); + tflite::support::StatusOr<std::vector<core::Category>> results = + classifier.value()->ClassifyText(ss_for_positive_review.str()); - EXPECT_GT(GetCategoryWithClassName("positive", results)->score, - GetCategoryWithClassName("negative", results)->score); + EXPECT_TRUE(results.ok()); + EXPECT_GT(GetCategoryWithClassName("positive", results.value())->score, + GetCategoryWithClassName("negative", results.value())->score); } } // namespace diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc index 6b002f71..d2f6e7ca 100644 --- a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc +++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc @@ -23,6 +23,9 @@ namespace task { namespace text { namespace nlclassifier { +using ::tflite::support::utils::kAssertionError; +using ::tflite::support::utils::kInvalidPointer; +using ::tflite::support::utils::ThrowException; using ::tflite::support::utils::ConvertVectorToArrayList; using ::tflite::support::utils::JStringToString; using ::tflite::task::core::Category; @@ -31,14 +34,21 @@ using ::tflite::task::text::nlclassifier::NLClassifier; jobject RunClassifier(JNIEnv* env, jlong native_handle, jstring text) { auto* nl_classifier = reinterpret_cast<NLClassifier*>(native_handle); - auto results = nl_classifier->Classify(JStringToString(env, text)); + auto results = nl_classifier->ClassifyText(JStringToString(env, text)); + if (!results.ok()) { + ThrowException(env, kAssertionError, + "Error occurred when running classifier: %s", + results.status().message().data()); + return env->ExceptionOccurred(); + } + jclass category_class = env->FindClass("org/tensorflow/lite/support/label/Category"); jmethodID category_init = env->GetMethodID(category_class, "<init>", "(Ljava/lang/String;F)V"); return ConvertVectorToArrayList<Category>( - env, results, + env, results.value(), [env, category_class, category_init](const Category& category) { jstring class_name = env->NewStringUTF(category.class_name.data()); // Convert double to float as Java interface exposes float as scores. |