aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBen Nissan <bennissan@google.com>2022-08-24 02:53:11 +0000
committerBen Nissan <bennissan@google.com>2022-08-24 02:53:11 +0000
commit510ea86c2952e784f46b81356a0c18a2bc7f0e43 (patch)
tree845121243bce2106a5d816226999d6c87d6748bf
parent3fccb92b5eb9831e71380ab6fd5dcdcda13ef32d (diff)
downloadtflite-support-510ea86c2952e784f46b81356a0c18a2bc7f0e43.tar.gz
Revert "[automerge] Return StatusOr from Classify() methods 2p: ..."
Revert "Return StatusOr from Classify() methods" Revert submission 19678847-presubmit-am-86cd332fad154163bed53094e7663a44 Reason for revert: Potentially causing system failures within AdServices (e.g.: https://android-build.googleplex.com/builds/tests/view?invocationId=I42200010083746944&testResultId=TR21628220208011705), needs further testing. Reverted Changes: Ib77b4dc3d:[automerge] Return StatusOr from Classify() method... Ib93385061:Return StatusOr from Classify() methods Change-Id: Ifabfdfff0092607179d7f70a589fa8dccfc5e77e
-rw-r--r--tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc16
-rw-r--r--tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc13
-rw-r--r--tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h7
-rw-r--r--tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc29
-rw-r--r--tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc14
5 files changed, 24 insertions, 55 deletions
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc
index 4c8b3fd7..a246066b 100644
--- a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc
@@ -245,17 +245,13 @@ absl::Status BertNLClassifier::InitializeFromMetadata() {
segment_ids_tensor.dims->data[1]),
TfLiteSupportStatus::kInvalidInputTensorSizeError);
}
-
- bool has_valid_dims_signature = ids_tensor.dims_signature->size == 2 &&
- mask_tensor.dims_signature->size == 2 &&
- segment_ids_tensor.dims_signature->size == 2;
- int num_dynamic_tensors = ids_tensor.dims_signature->data[1] == -1 +
- mask_tensor.dims_signature->data[1] == -1 +
- segment_ids_tensor.dims_signature->data[1] == -1;
-
- if (has_valid_dims_signature && num_dynamic_tensors == 3) {
+ if (ids_tensor.dims_signature->data[1] == -1 &&
+ mask_tensor.dims_signature->data[1] == -1 &&
+ segment_ids_tensor.dims_signature->data[1] == -1) {
input_tensors_are_dynamic_ = true;
- } else if (has_valid_dims_signature && num_dynamic_tensors > 0 && num_dynamic_tensors < 3) {
+ } else if (ids_tensor.dims_signature->data[1] == -1 ||
+ mask_tensor.dims_signature->data[1] == -1 ||
+ segment_ids_tensor.dims_signature->data[1] == -1) {
return CreateStatusWithPayload(
absl::StatusCode::kInternal,
"Input tensors contain a mix of static and dynamic tensors",
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
index d322a151..a75fe0ff 100644
--- a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
@@ -185,16 +185,9 @@ absl::Status NLClassifier::TrySetLabelFromMetadata(
}
std::vector<Category> NLClassifier::Classify(const std::string& text) {
- StatusOr<std::vector<Category>> infer_result = ClassifyText(text);
- if (!infer_result.ok()) {
- return {};
- }
- return infer_result.value();
-}
-
-StatusOr<std::vector<Category>> NLClassifier::ClassifyText(
- const std::string& text) {
- return Infer(text);
+ // The NLClassifier implementation for Preprocess() and Postprocess() never
+ // returns errors: just call value().
+ return Infer(text).value();
}
std::string NLClassifier::GetModelVersion() const {
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
index 8e70ba4c..013b7d53 100644
--- a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
@@ -109,14 +109,9 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
std::unique_ptr<tflite::OpResolver> resolver =
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
- // DEPRECATED (unannotated for backward compatibility). Prefer using `ClassifyText`.
+ // Performs classification on a string input, returns classified results.
std::vector<core::Category> Classify(const std::string& text);
- // Performs classification on a string input, returns classified results or an
- // error.
- tflite::support::StatusOr<std::vector<core::Category>> ClassifyText(
- const std::string& text);
-
// Gets the model version, or "NO_VERSION_INFO" in case there is no version.
std::string GetModelVersion() const;
diff --git a/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc b/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc
index 23a3ddca..85ebc505 100644
--- a/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc
+++ b/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc
@@ -85,19 +85,15 @@ Category* GetCategoryWithClassName(const std::string& class_name,
void verify_classifier(std::unique_ptr<BertNLClassifier> classifier,
bool verify_positive) {
if (verify_positive) {
- tflite::support::StatusOr<std::vector<core::Category>> results =
- classifier->ClassifyText("unflinchingly bleak and desperate");
-
- EXPECT_TRUE(results.ok());
- EXPECT_GT(GetCategoryWithClassName("negative", results.value())->score,
- GetCategoryWithClassName("positive", results.value())->score);
+ std::vector<core::Category> results =
+ classifier->Classify("unflinchingly bleak and desperate");
+ EXPECT_GT(GetCategoryWithClassName("negative", results)->score,
+ GetCategoryWithClassName("positive", results)->score);
} else {
- tflite::support::StatusOr<std::vector<core::Category>> results =
- classifier->ClassifyText("it's a charming and often affecting journey");
-
- EXPECT_TRUE(results.ok());
- EXPECT_GT(GetCategoryWithClassName("positive", results.value())->score,
- GetCategoryWithClassName("negative", results.value())->score);
+ std::vector<Category> results =
+ classifier->Classify("it's a charming and often affecting journey");
+ EXPECT_GT(GetCategoryWithClassName("positive", results)->score,
+ GetCategoryWithClassName("negative", results)->score);
}
}
@@ -155,12 +151,11 @@ TEST(BertNLClassifierTest, TestNLClassifier_ClassifyLongPositive_notOOB) {
BertNLClassifier::CreateFromBuffer(model_buffer.data(), model_buffer.size());
EXPECT_TRUE(classifier.ok());
- tflite::support::StatusOr<std::vector<core::Category>> results =
- classifier.value()->ClassifyText(ss_for_positive_review.str());
+ std::vector<core::Category> results =
+ classifier.value()->Classify(ss_for_positive_review.str());
- EXPECT_TRUE(results.ok());
- EXPECT_GT(GetCategoryWithClassName("positive", results.value())->score,
- GetCategoryWithClassName("negative", results.value())->score);
+ EXPECT_GT(GetCategoryWithClassName("positive", results)->score,
+ GetCategoryWithClassName("negative", results)->score);
}
} // namespace
diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc
index d2f6e7ca..6b002f71 100644
--- a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc
+++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc
@@ -23,9 +23,6 @@ namespace task {
namespace text {
namespace nlclassifier {
-using ::tflite::support::utils::kAssertionError;
-using ::tflite::support::utils::kInvalidPointer;
-using ::tflite::support::utils::ThrowException;
using ::tflite::support::utils::ConvertVectorToArrayList;
using ::tflite::support::utils::JStringToString;
using ::tflite::task::core::Category;
@@ -34,21 +31,14 @@ using ::tflite::task::text::nlclassifier::NLClassifier;
jobject RunClassifier(JNIEnv* env, jlong native_handle, jstring text) {
auto* nl_classifier = reinterpret_cast<NLClassifier*>(native_handle);
- auto results = nl_classifier->ClassifyText(JStringToString(env, text));
- if (!results.ok()) {
- ThrowException(env, kAssertionError,
- "Error occurred when running classifier: %s",
- results.status().message().data());
- return env->ExceptionOccurred();
- }
-
+ auto results = nl_classifier->Classify(JStringToString(env, text));
jclass category_class =
env->FindClass("org/tensorflow/lite/support/label/Category");
jmethodID category_init =
env->GetMethodID(category_class, "<init>", "(Ljava/lang/String;F)V");
return ConvertVectorToArrayList<Category>(
- env, results.value(),
+ env, results,
[env, category_class, category_init](const Category& category) {
jstring class_name = env->NewStringUTF(category.class_name.data());
// Convert double to float as Java interface exposes float as scores.