diff options
Diffstat (limited to 'tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc')
-rw-r--r-- | tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc | 29 |
1 files changed, 17 insertions, 12 deletions
diff --git a/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc b/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc index 85ebc505..23a3ddca 100644 --- a/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc +++ b/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc @@ -85,15 +85,19 @@ Category* GetCategoryWithClassName(const std::string& class_name, void verify_classifier(std::unique_ptr<BertNLClassifier> classifier, bool verify_positive) { if (verify_positive) { - std::vector<core::Category> results = - classifier->Classify("unflinchingly bleak and desperate"); - EXPECT_GT(GetCategoryWithClassName("negative", results)->score, - GetCategoryWithClassName("positive", results)->score); + tflite::support::StatusOr<std::vector<core::Category>> results = + classifier->ClassifyText("unflinchingly bleak and desperate"); + + EXPECT_TRUE(results.ok()); + EXPECT_GT(GetCategoryWithClassName("negative", results.value())->score, + GetCategoryWithClassName("positive", results.value())->score); } else { - std::vector<Category> results = - classifier->Classify("it's a charming and often affecting journey"); - EXPECT_GT(GetCategoryWithClassName("positive", results)->score, - GetCategoryWithClassName("negative", results)->score); + tflite::support::StatusOr<std::vector<core::Category>> results = + classifier->ClassifyText("it's a charming and often affecting journey"); + + EXPECT_TRUE(results.ok()); + EXPECT_GT(GetCategoryWithClassName("positive", results.value())->score, + GetCategoryWithClassName("negative", results.value())->score); } } @@ -151,11 +155,12 @@ TEST(BertNLClassifierTest, TestNLClassifier_ClassifyLongPositive_notOOB) { BertNLClassifier::CreateFromBuffer(model_buffer.data(), model_buffer.size()); EXPECT_TRUE(classifier.ok()); - std::vector<core::Category> results = - classifier.value()->Classify(ss_for_positive_review.str()); + tflite::support::StatusOr<std::vector<core::Category>> results = + classifier.value()->ClassifyText(ss_for_positive_review.str()); - EXPECT_GT(GetCategoryWithClassName("positive", results)->score, - GetCategoryWithClassName("negative", results)->score); + EXPECT_TRUE(results.ok()); + EXPECT_GT(GetCategoryWithClassName("positive", results.value())->score, + GetCategoryWithClassName("negative", results.value())->score); } } // namespace |