diff options
author | Ben Nissan <bennissan@google.com> | 2022-08-24 02:53:11 +0000 |
---|---|---|
committer | Ben Nissan <bennissan@google.com> | 2022-08-24 02:53:11 +0000 |
commit | 510ea86c2952e784f46b81356a0c18a2bc7f0e43 (patch) | |
tree | 845121243bce2106a5d816226999d6c87d6748bf | |
parent | 3fccb92b5eb9831e71380ab6fd5dcdcda13ef32d (diff) | |
download | tflite-support-510ea86c2952e784f46b81356a0c18a2bc7f0e43.tar.gz |
Revert "[automerge] Return StatusOr from Classify() methods 2p: ..."
Revert "Return StatusOr from Classify() methods"
Revert submission 19678847-presubmit-am-86cd332fad154163bed53094e7663a44
Reason for revert: Potentially causing system failures within AdServices (e.g.: https://android-build.googleplex.com/builds/tests/view?invocationId=I42200010083746944&testResultId=TR21628220208011705), needs further testing.
Reverted Changes:
Ib77b4dc3d:[automerge] Return StatusOr from Classify() method...
Ib93385061:Return StatusOr from Classify() methods
Change-Id: Ifabfdfff0092607179d7f70a589fa8dccfc5e77e
5 files changed, 24 insertions, 55 deletions
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc index 4c8b3fd7..a246066b 100644 --- a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc @@ -245,17 +245,13 @@ absl::Status BertNLClassifier::InitializeFromMetadata() { segment_ids_tensor.dims->data[1]), TfLiteSupportStatus::kInvalidInputTensorSizeError); } - - bool has_valid_dims_signature = ids_tensor.dims_signature->size == 2 && - mask_tensor.dims_signature->size == 2 && - segment_ids_tensor.dims_signature->size == 2; - int num_dynamic_tensors = ids_tensor.dims_signature->data[1] == -1 + - mask_tensor.dims_signature->data[1] == -1 + - segment_ids_tensor.dims_signature->data[1] == -1; - - if (has_valid_dims_signature && num_dynamic_tensors == 3) { + if (ids_tensor.dims_signature->data[1] == -1 && + mask_tensor.dims_signature->data[1] == -1 && + segment_ids_tensor.dims_signature->data[1] == -1) { input_tensors_are_dynamic_ = true; - } else if (has_valid_dims_signature && num_dynamic_tensors > 0 && num_dynamic_tensors < 3) { + } else if (ids_tensor.dims_signature->data[1] == -1 || + mask_tensor.dims_signature->data[1] == -1 || + segment_ids_tensor.dims_signature->data[1] == -1) { return CreateStatusWithPayload( absl::StatusCode::kInternal, "Input tensors contain a mix of static and dynamic tensors", diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc index d322a151..a75fe0ff 100644 --- a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc @@ -185,16 +185,9 @@ absl::Status NLClassifier::TrySetLabelFromMetadata( } std::vector<Category> NLClassifier::Classify(const std::string& text) { - StatusOr<std::vector<Category>> infer_result = ClassifyText(text); - if (!infer_result.ok()) { - return {}; - } - return infer_result.value(); -} - -StatusOr<std::vector<Category>> NLClassifier::ClassifyText( - const std::string& text) { - return Infer(text); + // The NLClassifier implementation for Preprocess() and Postprocess() never + // returns errors: just call value(). + return Infer(text).value(); } std::string NLClassifier::GetModelVersion() const { diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h index 8e70ba4c..013b7d53 100644 --- a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h @@ -109,14 +109,9 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>, std::unique_ptr<tflite::OpResolver> resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); - // DEPRECATED (unannotated for backward compatibility). Prefer using `ClassifyText`. + // Performs classification on a string input, returns classified results. std::vector<core::Category> Classify(const std::string& text); - // Performs classification on a string input, returns classified results or an - // error. - tflite::support::StatusOr<std::vector<core::Category>> ClassifyText( - const std::string& text); - // Gets the model version, or "NO_VERSION_INFO" in case there is no version. std::string GetModelVersion() const; diff --git a/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc b/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc index 23a3ddca..85ebc505 100644 --- a/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc +++ b/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc @@ -85,19 +85,15 @@ Category* GetCategoryWithClassName(const std::string& class_name, void verify_classifier(std::unique_ptr<BertNLClassifier> classifier, bool verify_positive) { if (verify_positive) { - tflite::support::StatusOr<std::vector<core::Category>> results = - classifier->ClassifyText("unflinchingly bleak and desperate"); - - EXPECT_TRUE(results.ok()); - EXPECT_GT(GetCategoryWithClassName("negative", results.value())->score, - GetCategoryWithClassName("positive", results.value())->score); + std::vector<core::Category> results = + classifier->Classify("unflinchingly bleak and desperate"); + EXPECT_GT(GetCategoryWithClassName("negative", results)->score, + GetCategoryWithClassName("positive", results)->score); } else { - tflite::support::StatusOr<std::vector<core::Category>> results = - classifier->ClassifyText("it's a charming and often affecting journey"); - - EXPECT_TRUE(results.ok()); - EXPECT_GT(GetCategoryWithClassName("positive", results.value())->score, - GetCategoryWithClassName("negative", results.value())->score); + std::vector<Category> results = + classifier->Classify("it's a charming and often affecting journey"); + EXPECT_GT(GetCategoryWithClassName("positive", results)->score, + GetCategoryWithClassName("negative", results)->score); } } @@ -155,12 +151,11 @@ TEST(BertNLClassifierTest, TestNLClassifier_ClassifyLongPositive_notOOB) { BertNLClassifier::CreateFromBuffer(model_buffer.data(), model_buffer.size()); EXPECT_TRUE(classifier.ok()); - tflite::support::StatusOr<std::vector<core::Category>> results = - classifier.value()->ClassifyText(ss_for_positive_review.str()); + std::vector<core::Category> results = + classifier.value()->Classify(ss_for_positive_review.str()); - EXPECT_TRUE(results.ok()); - EXPECT_GT(GetCategoryWithClassName("positive", results.value())->score, - GetCategoryWithClassName("negative", results.value())->score); + EXPECT_GT(GetCategoryWithClassName("positive", results)->score, + GetCategoryWithClassName("negative", results)->score); } } // namespace diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc index d2f6e7ca..6b002f71 100644 --- a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc +++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc @@ -23,9 +23,6 @@ namespace task { namespace text { namespace nlclassifier { -using ::tflite::support::utils::kAssertionError; -using ::tflite::support::utils::kInvalidPointer; -using ::tflite::support::utils::ThrowException; using ::tflite::support::utils::ConvertVectorToArrayList; using ::tflite::support::utils::JStringToString; using ::tflite::task::core::Category; @@ -34,21 +31,14 @@ using ::tflite::task::text::nlclassifier::NLClassifier; jobject RunClassifier(JNIEnv* env, jlong native_handle, jstring text) { auto* nl_classifier = reinterpret_cast<NLClassifier*>(native_handle); - auto results = nl_classifier->ClassifyText(JStringToString(env, text)); - if (!results.ok()) { - ThrowException(env, kAssertionError, - "Error occurred when running classifier: %s", - results.status().message().data()); - return env->ExceptionOccurred(); - } - + auto results = nl_classifier->Classify(JStringToString(env, text)); jclass category_class = env->FindClass("org/tensorflow/lite/support/label/Category"); jmethodID category_init = env->GetMethodID(category_class, "<init>", "(Ljava/lang/String;F)V"); return ConvertVectorToArrayList<Category>( - env, results.value(), + env, results, [env, category_class, category_init](const Category& category) { jstring class_name = env->NewStringUTF(category.class_name.data()); // Convert double to float as Java interface exposes float as scores. |