aboutsummaryrefslogtreecommitdiff
path: root/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc')
-rw-r--r--tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc29
1 files changed, 17 insertions, 12 deletions
diff --git a/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc b/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc
index 85ebc505..23a3ddca 100644
--- a/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc
+++ b/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc
@@ -85,15 +85,19 @@ Category* GetCategoryWithClassName(const std::string& class_name,
void verify_classifier(std::unique_ptr<BertNLClassifier> classifier,
bool verify_positive) {
if (verify_positive) {
- std::vector<core::Category> results =
- classifier->Classify("unflinchingly bleak and desperate");
- EXPECT_GT(GetCategoryWithClassName("negative", results)->score,
- GetCategoryWithClassName("positive", results)->score);
+ tflite::support::StatusOr<std::vector<core::Category>> results =
+ classifier->ClassifyText("unflinchingly bleak and desperate");
+
+ EXPECT_TRUE(results.ok());
+ EXPECT_GT(GetCategoryWithClassName("negative", results.value())->score,
+ GetCategoryWithClassName("positive", results.value())->score);
} else {
- std::vector<Category> results =
- classifier->Classify("it's a charming and often affecting journey");
- EXPECT_GT(GetCategoryWithClassName("positive", results)->score,
- GetCategoryWithClassName("negative", results)->score);
+ tflite::support::StatusOr<std::vector<core::Category>> results =
+ classifier->ClassifyText("it's a charming and often affecting journey");
+
+ EXPECT_TRUE(results.ok());
+ EXPECT_GT(GetCategoryWithClassName("positive", results.value())->score,
+ GetCategoryWithClassName("negative", results.value())->score);
}
}
@@ -151,11 +155,12 @@ TEST(BertNLClassifierTest, TestNLClassifier_ClassifyLongPositive_notOOB) {
BertNLClassifier::CreateFromBuffer(model_buffer.data(), model_buffer.size());
EXPECT_TRUE(classifier.ok());
- std::vector<core::Category> results =
- classifier.value()->Classify(ss_for_positive_review.str());
+ tflite::support::StatusOr<std::vector<core::Category>> results =
+ classifier.value()->ClassifyText(ss_for_positive_review.str());
- EXPECT_GT(GetCategoryWithClassName("positive", results)->score,
- GetCategoryWithClassName("negative", results)->score);
+ EXPECT_TRUE(results.ok());
+ EXPECT_GT(GetCategoryWithClassName("positive", results.value())->score,
+ GetCategoryWithClassName("negative", results.value())->score);
}
} // namespace