aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBen Nissan <bennissan@google.com>2022-08-19 16:44:36 +0000
committerBen Nissan <bennissan@google.com>2022-08-23 17:30:58 +0000
commit9fc1fcbd2239b6c8f6fa8399c6580d4f350e2150 (patch)
treefd7b7d30537d59b8f0a635b20c511543623f87aa
parente4266a8951b0f53a80bc54392276c355583c0a53 (diff)
downloadtflite-support-9fc1fcbd2239b6c8f6fa8399c6580d4f350e2150.tar.gz
Return StatusOr from Classify() methods
This CL updates the Classify() methods in NLClassifier and BertNLClassifier to return StatusOr wrappers for their contents, allowing errors to be propagated across the JNI boundary. To enable this, it fixes a bug when checking for dynamic vs static input tensors in bert_nl_classifier.cc and throws exceptions from RunClassifier(). Bug: 242926638, Bug: 242926783 Test: atest tflite_support_classifier_tests Test: atest TfliteSupportClassifierTests Change-Id: Ib93385061ace97cf96a168e3d76d2d9162c8055f
-rw-r--r--tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc16
-rw-r--r--tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc13
-rw-r--r--tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h7
-rw-r--r--tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc29
-rw-r--r--tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc14
5 files changed, 55 insertions, 24 deletions
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc
index a246066b..4c8b3fd7 100644
--- a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc
@@ -245,13 +245,17 @@ absl::Status BertNLClassifier::InitializeFromMetadata() {
segment_ids_tensor.dims->data[1]),
TfLiteSupportStatus::kInvalidInputTensorSizeError);
}
- if (ids_tensor.dims_signature->data[1] == -1 &&
- mask_tensor.dims_signature->data[1] == -1 &&
- segment_ids_tensor.dims_signature->data[1] == -1) {
+
+ bool has_valid_dims_signature = ids_tensor.dims_signature->size == 2 &&
+ mask_tensor.dims_signature->size == 2 &&
+ segment_ids_tensor.dims_signature->size == 2;
+ int num_dynamic_tensors = ids_tensor.dims_signature->data[1] == -1 +
+ mask_tensor.dims_signature->data[1] == -1 +
+ segment_ids_tensor.dims_signature->data[1] == -1;
+
+ if (has_valid_dims_signature && num_dynamic_tensors == 3) {
input_tensors_are_dynamic_ = true;
- } else if (ids_tensor.dims_signature->data[1] == -1 ||
- mask_tensor.dims_signature->data[1] == -1 ||
- segment_ids_tensor.dims_signature->data[1] == -1) {
+ } else if (has_valid_dims_signature && num_dynamic_tensors > 0 && num_dynamic_tensors < 3) {
return CreateStatusWithPayload(
absl::StatusCode::kInternal,
"Input tensors contain a mix of static and dynamic tensors",
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
index a75fe0ff..d322a151 100644
--- a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
@@ -185,9 +185,16 @@ absl::Status NLClassifier::TrySetLabelFromMetadata(
}
std::vector<Category> NLClassifier::Classify(const std::string& text) {
- // The NLClassifier implementation for Preprocess() and Postprocess() never
- // returns errors: just call value().
- return Infer(text).value();
+ StatusOr<std::vector<Category>> infer_result = ClassifyText(text);
+ if (!infer_result.ok()) {
+ return {};
+ }
+ return infer_result.value();
+}
+
+StatusOr<std::vector<Category>> NLClassifier::ClassifyText(
+ const std::string& text) {
+ return Infer(text);
}
std::string NLClassifier::GetModelVersion() const {
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
index 013b7d53..8e70ba4c 100644
--- a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
@@ -109,9 +109,14 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
std::unique_ptr<tflite::OpResolver> resolver =
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
- // Performs classification on a string input, returns classified results.
+ // DEPRECATED (unannotated for backward compatibility). Prefer using `ClassifyText`.
std::vector<core::Category> Classify(const std::string& text);
+ // Performs classification on a string input, returns classified results or an
+ // error.
+ tflite::support::StatusOr<std::vector<core::Category>> ClassifyText(
+ const std::string& text);
+
// Gets the model version, or "NO_VERSION_INFO" in case there is no version.
std::string GetModelVersion() const;
diff --git a/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc b/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc
index 85ebc505..23a3ddca 100644
--- a/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc
+++ b/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc
@@ -85,15 +85,19 @@ Category* GetCategoryWithClassName(const std::string& class_name,
void verify_classifier(std::unique_ptr<BertNLClassifier> classifier,
bool verify_positive) {
if (verify_positive) {
- std::vector<core::Category> results =
- classifier->Classify("unflinchingly bleak and desperate");
- EXPECT_GT(GetCategoryWithClassName("negative", results)->score,
- GetCategoryWithClassName("positive", results)->score);
+ tflite::support::StatusOr<std::vector<core::Category>> results =
+ classifier->ClassifyText("unflinchingly bleak and desperate");
+
+ EXPECT_TRUE(results.ok());
+ EXPECT_GT(GetCategoryWithClassName("negative", results.value())->score,
+ GetCategoryWithClassName("positive", results.value())->score);
} else {
- std::vector<Category> results =
- classifier->Classify("it's a charming and often affecting journey");
- EXPECT_GT(GetCategoryWithClassName("positive", results)->score,
- GetCategoryWithClassName("negative", results)->score);
+ tflite::support::StatusOr<std::vector<core::Category>> results =
+ classifier->ClassifyText("it's a charming and often affecting journey");
+
+ EXPECT_TRUE(results.ok());
+ EXPECT_GT(GetCategoryWithClassName("positive", results.value())->score,
+ GetCategoryWithClassName("negative", results.value())->score);
}
}
@@ -151,11 +155,12 @@ TEST(BertNLClassifierTest, TestNLClassifier_ClassifyLongPositive_notOOB) {
BertNLClassifier::CreateFromBuffer(model_buffer.data(), model_buffer.size());
EXPECT_TRUE(classifier.ok());
- std::vector<core::Category> results =
- classifier.value()->Classify(ss_for_positive_review.str());
+ tflite::support::StatusOr<std::vector<core::Category>> results =
+ classifier.value()->ClassifyText(ss_for_positive_review.str());
- EXPECT_GT(GetCategoryWithClassName("positive", results)->score,
- GetCategoryWithClassName("negative", results)->score);
+ EXPECT_TRUE(results.ok());
+ EXPECT_GT(GetCategoryWithClassName("positive", results.value())->score,
+ GetCategoryWithClassName("negative", results.value())->score);
}
} // namespace
diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc
index 6b002f71..d2f6e7ca 100644
--- a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc
+++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc
@@ -23,6 +23,9 @@ namespace task {
namespace text {
namespace nlclassifier {
+using ::tflite::support::utils::kAssertionError;
+using ::tflite::support::utils::kInvalidPointer;
+using ::tflite::support::utils::ThrowException;
using ::tflite::support::utils::ConvertVectorToArrayList;
using ::tflite::support::utils::JStringToString;
using ::tflite::task::core::Category;
@@ -31,14 +34,21 @@ using ::tflite::task::text::nlclassifier::NLClassifier;
jobject RunClassifier(JNIEnv* env, jlong native_handle, jstring text) {
auto* nl_classifier = reinterpret_cast<NLClassifier*>(native_handle);
- auto results = nl_classifier->Classify(JStringToString(env, text));
+ auto results = nl_classifier->ClassifyText(JStringToString(env, text));
+ if (!results.ok()) {
+ ThrowException(env, kAssertionError,
+ "Error occurred when running classifier: %s",
+ results.status().message().data());
+ return env->ExceptionOccurred();
+ }
+
jclass category_class =
env->FindClass("org/tensorflow/lite/support/label/Category");
jmethodID category_init =
env->GetMethodID(category_class, "<init>", "(Ljava/lang/String;F)V");
return ConvertVectorToArrayList<Category>(
- env, results,
+ env, results.value(),
[env, category_class, category_init](const Category& category) {
jstring class_name = env->NewStringUTF(category.class_name.data());
// Convert double to float as Java interface exposes float as scores.