summaryrefslogtreecommitdiff
path: root/native/annotator
diff options
context:
space:
mode:
Diffstat (limited to 'native/annotator')
-rw-r--r--native/annotator/annotator.cc69
-rw-r--r--native/annotator/annotator.h3
-rw-r--r--native/annotator/annotator_jni.cc4
-rw-r--r--native/annotator/annotator_jni_common.cc17
-rw-r--r--native/annotator/collections.h30
-rw-r--r--native/annotator/contact/contact-engine-dummy.h4
-rw-r--r--native/annotator/datetime/grammar-parser.cc11
-rw-r--r--native/annotator/datetime/grammar-parser.h5
-rw-r--r--native/annotator/datetime/grammar-parser_test.cc26
-rw-r--r--native/annotator/duration/duration.cc14
-rw-r--r--native/annotator/duration/duration.h2
-rw-r--r--native/annotator/duration/duration_test.cc174
-rw-r--r--native/annotator/installed_app/installed-app-engine-dummy.h4
-rw-r--r--native/annotator/knowledge/knowledge-engine-dummy.h6
-rw-r--r--native/annotator/model.fbs26
-rw-r--r--native/annotator/number/number.cc7
-rw-r--r--native/annotator/number/number.h2
-rw-r--r--native/annotator/number/number_test-include.cc157
-rw-r--r--native/annotator/number/number_test-include.h22
-rw-r--r--native/annotator/person_name/person-name-engine-dummy.h2
-rw-r--r--native/annotator/person_name/person_name_model.fbs7
-rw-r--r--native/annotator/pod_ner/pod-ner-impl.cc13
-rw-r--r--native/annotator/pod_ner/pod-ner-impl_test.cc49
-rw-r--r--native/annotator/translate/translate.cc5
-rw-r--r--native/annotator/translate/translate_test.cc58
-rw-r--r--native/annotator/types.h9
-rw-r--r--native/annotator/vocab/vocab-annotator-impl.cc6
27 files changed, 563 insertions, 169 deletions
diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc
index e0d4241..a8483f1 100644
--- a/native/annotator/annotator.cc
+++ b/native/annotator/annotator.cc
@@ -432,7 +432,8 @@ void Annotator::ValidateAndInitialize(const Model* model, const UniLib* unilib,
datetime_parser_ = std::make_unique<GrammarDatetimeParser>(
*analyzer_, *datetime_grounder_,
/*target_classification_score=*/1.0,
- /*priority_score=*/1.0);
+ /*priority_score=*/1.0,
+ model_->datetime_grammar_model()->enabled_modes());
}
} else if (model_->datetime_model()) {
datetime_parser_ = RegexDatetimeParser::Instance(
@@ -604,6 +605,8 @@ bool Annotator::InitializeKnowledgeEngine(
if (model_->triggering_options() != nullptr) {
knowledge_engine->SetPriorityScore(
model_->triggering_options()->knowledge_priority_score());
+ knowledge_engine->SetEnabledModes(
+ model_->triggering_options()->knowledge_enabled_modes());
}
knowledge_engine_ = std::move(knowledge_engine);
return true;
@@ -621,10 +624,21 @@ bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
return true;
}
+void Annotator::CleanUpContactEngine() {
+ if (contact_engine_ == nullptr) {
+ TC3_LOG(INFO)
+ << "Attempting to clean up contact engine that does not exist.";
+ return;
+ }
+ contact_engine_->CleanUp();
+}
+
bool Annotator::InitializeInstalledAppEngine(
const std::string& serialized_config) {
std::unique_ptr<InstalledAppEngine> installed_app_engine(
- new InstalledAppEngine(selection_feature_processor_.get(), unilib_));
+ new InstalledAppEngine(
+ selection_feature_processor_.get(), unilib_,
+ model_->triggering_options()->installed_app_enabled_modes()));
if (!installed_app_engine->Initialize(serialized_config)) {
TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
return false;
@@ -912,38 +926,40 @@ CodepointSpan Annotator::SuggestSelection(
!knowledge_engine_
->Chunk(context, options.annotation_usecase,
options.location_context, Permissions(),
- AnnotateMode::kEntityAnnotation, &candidates)
+ AnnotateMode::kEntityAnnotation, ModeFlag_SELECTION,
+ &candidates)
.ok()) {
TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
return original_click_indices;
}
if (contact_engine_ != nullptr &&
- !contact_engine_->Chunk(context_unicode, tokens,
+ !contact_engine_->Chunk(context_unicode, tokens, ModeFlag_SELECTION,
&candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Contact suggest selection failed.";
return original_click_indices;
}
if (installed_app_engine_ != nullptr &&
- !installed_app_engine_->Chunk(context_unicode, tokens,
+ !installed_app_engine_->Chunk(context_unicode, tokens, ModeFlag_SELECTION,
&candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Installed app suggest selection failed.";
return original_click_indices;
}
if (number_annotator_ != nullptr &&
!number_annotator_->FindAll(context_unicode, options.annotation_usecase,
+ ModeFlag_SELECTION,
&candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Number annotator failed in suggest selection.";
return original_click_indices;
}
if (duration_annotator_ != nullptr &&
- !duration_annotator_->FindAll(context_unicode, tokens,
- options.annotation_usecase,
- &candidates.annotated_spans[0])) {
+ !duration_annotator_->FindAll(
+ context_unicode, tokens, options.annotation_usecase,
+ ModeFlag_SELECTION, &candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Duration annotator failed in suggest selection.";
return original_click_indices;
}
if (person_name_engine_ != nullptr &&
- !person_name_engine_->Chunk(context_unicode, tokens,
+ !person_name_engine_->Chunk(context_unicode, tokens, ModeFlag_SELECTION,
&candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Person name suggest selection failed.";
return original_click_indices;
@@ -964,7 +980,9 @@ CodepointSpan Annotator::SuggestSelection(
candidates.annotated_spans[0].push_back(pod_ner_suggested_span);
}
- if (experimental_annotator_ != nullptr) {
+ if (experimental_annotator_ != nullptr &&
+ (model_->triggering_options()->experimental_enabled_modes() &
+ ModeFlag_SELECTION)) {
candidates.annotated_spans[0].push_back(
experimental_annotator_->SuggestSelection(context_unicode,
click_indices));
@@ -1896,7 +1914,9 @@ std::vector<ClassificationResult> Annotator::ClassifyText(
candidates.push_back({selection_indices, {vocab_annotator_result}});
}
- if (experimental_annotator_) {
+ if (experimental_annotator_ &&
+ (model_->triggering_options()->experimental_enabled_modes() &
+ ModeFlag_CLASSIFICATION)) {
experimental_annotator_->ClassifyText(context_unicode, selection_indices,
candidates);
}
@@ -2218,7 +2238,8 @@ Status Annotator::AnnotateSingleInput(
const bool contact_annotations_enabled =
!is_raw_usecase || is_entity_type_enabled(Collections::Contact());
if (contact_annotations_enabled && contact_engine_ &&
- !contact_engine_->Chunk(context_unicode, tokens, candidates)) {
+ !contact_engine_->Chunk(context_unicode, tokens, ModeFlag_ANNOTATION,
+ candidates)) {
return Status(StatusCode::INTERNAL, "Couldn't run contact engine Chunk.");
}
@@ -2226,7 +2247,8 @@ Status Annotator::AnnotateSingleInput(
const bool app_annotations_enabled =
!is_raw_usecase || is_entity_type_enabled(Collections::App());
if (app_annotations_enabled && installed_app_engine_ &&
- !installed_app_engine_->Chunk(context_unicode, tokens, candidates)) {
+ !installed_app_engine_->Chunk(context_unicode, tokens,
+ ModeFlag_ANNOTATION, candidates)) {
return Status(StatusCode::INTERNAL,
"Couldn't run installed app engine Chunk.");
}
@@ -2237,7 +2259,7 @@ Status Annotator::AnnotateSingleInput(
is_entity_type_enabled(Collections::Percentage()));
if (number_annotations_enabled && number_annotator_ != nullptr &&
!number_annotator_->FindAll(context_unicode, options.annotation_usecase,
- candidates)) {
+ ModeFlag_ANNOTATION, candidates)) {
return Status(StatusCode::INTERNAL,
"Couldn't run number annotator FindAll.");
}
@@ -2247,7 +2269,8 @@ Status Annotator::AnnotateSingleInput(
!is_raw_usecase || is_entity_type_enabled(Collections::Duration());
if (duration_annotations_enabled && duration_annotator_ != nullptr &&
!duration_annotator_->FindAll(context_unicode, tokens,
- options.annotation_usecase, candidates)) {
+ options.annotation_usecase,
+ ModeFlag_ANNOTATION, candidates)) {
return Status(StatusCode::INTERNAL,
"Couldn't run duration annotator FindAll.");
}
@@ -2256,7 +2279,8 @@ Status Annotator::AnnotateSingleInput(
const bool person_annotations_enabled =
!is_raw_usecase || is_entity_type_enabled(Collections::PersonName());
if (person_annotations_enabled && person_name_engine_ &&
- !person_name_engine_->Chunk(context_unicode, tokens, candidates)) {
+ !person_name_engine_->Chunk(context_unicode, tokens, ModeFlag_ANNOTATION,
+ candidates)) {
return Status(StatusCode::INTERNAL,
"Couldn't run person name engine Chunk.");
}
@@ -2290,6 +2314,8 @@ Status Annotator::AnnotateSingleInput(
// Annotate with the experimental annotator.
if (experimental_annotator_ != nullptr &&
+ (model_->triggering_options()->experimental_enabled_modes() &
+ ModeFlag_ANNOTATION) &&
!experimental_annotator_->Annotate(context_unicode, candidates)) {
return Status(StatusCode::INTERNAL, "Couldn't run experimental annotator.");
}
@@ -2376,14 +2402,21 @@ StatusOr<Annotations> Annotator::AnnotateStructuredInput(
.relative_bounding_box_height = string_fragment.bounding_box_height});
}
+ const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
+ const bool is_raw_usecase =
+ options.annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW;
+
+ const bool knowledge_engine_annotations_enabled =
+ !is_raw_usecase || is_entity_type_enabled(Collections::Entity());
// KnowledgeEngine is special, because it supports annotation of multiple
// fragments at once.
- if (knowledge_engine_ &&
+ if (knowledge_engine_annotations_enabled && knowledge_engine_ &&
!knowledge_engine_
->ChunkMultipleSpans(text_to_annotate, fragment_metadata,
options.annotation_usecase,
options.location_context, options.permissions,
- options.annotate_mode, &annotation_candidates)
+ options.annotate_mode, ModeFlag_ANNOTATION,
+ &annotation_candidates)
.ok()) {
return Status(StatusCode::INTERNAL, "Couldn't run knowledge engine Chunk.");
}
diff --git a/native/annotator/annotator.h b/native/annotator/annotator.h
index d69fe32..5df8129 100644
--- a/native/annotator/annotator.h
+++ b/native/annotator/annotator.h
@@ -149,6 +149,9 @@ class Annotator {
// Initializes the contact engine with the given config.
bool InitializeContactEngine(const std::string& serialized_config);
+ // Cleans up the resources associated with the contact engine.
+ void CleanUpContactEngine();
+
// Initializes the installed app engine with the given config.
bool InitializeInstalledAppEngine(const std::string& serialized_config);
diff --git a/native/annotator/annotator_jni.cc b/native/annotator/annotator_jni.cc
index 6e7eeab..3d352c6 100644
--- a/native/annotator/annotator_jni.cc
+++ b/native/annotator/annotator_jni.cc
@@ -275,6 +275,7 @@ StatusOr<ScopedLocalRef<jobject>> ClassificationResultWithIntentsToJObject(
device_locales, classification_result,
options->reference_time_ms_utc, context, selection_indices,
app_context, model_context->model()->entity_data_schema(),
+ options->enable_add_contact_intent, options->enable_search_intent,
&remote_action_templates)) {
return {Status::UNKNOWN};
}
@@ -896,6 +897,9 @@ TC3_JNI_METHOD(void, TC3_ANNOTATOR_CLASS_NAME, nativeCloseAnnotator)
(JNIEnv* env, jobject thiz, jlong ptr) {
const AnnotatorJniContext* context =
reinterpret_cast<AnnotatorJniContext*>(ptr);
+ if (context != nullptr && context->model()) {
+ context->model()->CleanUpContactEngine();
+ }
delete context;
}
diff --git a/native/annotator/annotator_jni_common.cc b/native/annotator/annotator_jni_common.cc
index a6f636f..6ee4977 100644
--- a/native/annotator/annotator_jni_common.cc
+++ b/native/annotator/annotator_jni_common.cc
@@ -279,6 +279,23 @@ StatusOr<ClassificationOptions> FromJavaClassificationOptions(
JniHelper::CallBooleanMethod(env, joptions,
get_trigger_dictionary_on_beginner_words));
+ // .getEnableAddContactIntent()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_enable_add_contact_intent,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getEnableAddContactIntent", "()Z"));
+ TC3_ASSIGN_OR_RETURN(classifier_options.enable_add_contact_intent,
+ JniHelper::CallBooleanMethod(
+ env, joptions, get_enable_add_contact_intent));
+
+ // .getEnableSearchIntent()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_enable_search_intent,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getEnableSearchIntent", "()Z"));
+ TC3_ASSIGN_OR_RETURN(
+ classifier_options.enable_search_intent,
+ JniHelper::CallBooleanMethod(env, joptions, get_enable_search_intent));
+
return classifier_options;
}
diff --git a/native/annotator/collections.h b/native/annotator/collections.h
index 417b447..becdcdb 100644
--- a/native/annotator/collections.h
+++ b/native/annotator/collections.h
@@ -144,6 +144,36 @@ class Collections {
*[]() { return new std::string("otp_code"); }();
return value;
}
+ static const std::string& Art() {
+ static const std::string& value =
+ *[]() { return new std::string("art"); }();
+ return value;
+ }
+ static const std::string& ConsumerGood() {
+ static const std::string& value =
+ *[]() { return new std::string("consumer_good"); }();
+ return value;
+ }
+ static const std::string& Event() {
+ static const std::string& value =
+ *[]() { return new std::string("event"); }();
+ return value;
+ }
+ static const std::string& Location() {
+ static const std::string& value =
+ *[]() { return new std::string("location"); }();
+ return value;
+ }
+ static const std::string& Organization() {
+ static const std::string& value =
+ *[]() { return new std::string("organization"); }();
+ return value;
+ }
+ static const std::string& Person() {
+ static const std::string& value =
+ *[]() { return new std::string("person"); }();
+ return value;
+ }
};
} // namespace libtextclassifier3
diff --git a/native/annotator/contact/contact-engine-dummy.h b/native/annotator/contact/contact-engine-dummy.h
index fe60203..211553c 100644
--- a/native/annotator/contact/contact-engine-dummy.h
+++ b/native/annotator/contact/contact-engine-dummy.h
@@ -47,13 +47,15 @@ class ContactEngine {
}
bool Chunk(const UnicodeText& context_unicode,
- const std::vector<Token>& tokens,
+ const std::vector<Token>& tokens, ModeFlag mode,
std::vector<AnnotatedSpan>* result) const {
return true;
}
void AddContactMetadataToKnowledgeClassificationResult(
ClassificationResult* classification_result) const {}
+
+ void CleanUp() const {}
};
} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/grammar-parser.cc b/native/annotator/datetime/grammar-parser.cc
index 6d51c19..c49a01a 100644
--- a/native/annotator/datetime/grammar-parser.cc
+++ b/native/annotator/datetime/grammar-parser.cc
@@ -20,6 +20,7 @@
#include <unordered_set>
#include "annotator/datetime/datetime-grounder.h"
+#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/grammar/analyzer.h"
#include "utils/grammar/evaluated-derivation.h"
@@ -33,11 +34,13 @@ namespace libtextclassifier3 {
GrammarDatetimeParser::GrammarDatetimeParser(
const grammar::Analyzer& analyzer,
const DatetimeGrounder& datetime_grounder,
- const float target_classification_score, const float priority_score)
+ const float target_classification_score, const float priority_score,
+ ModeFlag enabled_modes)
: analyzer_(analyzer),
datetime_grounder_(datetime_grounder),
target_classification_score_(target_classification_score),
- priority_score_(priority_score) {}
+ priority_score_(priority_score),
+ enabled_modes_(enabled_modes) {}
StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
const std::string& input, const int64 reference_time_ms_utc,
@@ -54,6 +57,10 @@ StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
const std::string& reference_timezone, const LocaleList& locale_list,
ModeFlag mode, AnnotationUsecase annotation_usecase,
bool anchor_start_end) const {
+ if (!(enabled_modes_ & mode)) {
+ return std::vector<DatetimeParseResultSpan>();
+ }
+
std::vector<DatetimeParseResultSpan> results;
UnsafeArena arena(/*block_size=*/16 << 10);
std::vector<Locale> locales = locale_list.GetLocales();
diff --git a/native/annotator/datetime/grammar-parser.h b/native/annotator/datetime/grammar-parser.h
index 6ff4b46..35da843 100644
--- a/native/annotator/datetime/grammar-parser.h
+++ b/native/annotator/datetime/grammar-parser.h
@@ -22,6 +22,7 @@
#include "annotator/datetime/datetime-grounder.h"
#include "annotator/datetime/parser.h"
+#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/base/statusor.h"
#include "utils/grammar/analyzer.h"
@@ -37,7 +38,8 @@ class GrammarDatetimeParser : public DatetimeParser {
explicit GrammarDatetimeParser(const grammar::Analyzer& analyzer,
const DatetimeGrounder& datetime_grounder,
const float target_classification_score,
- const float priority_score);
+ const float priority_score,
+ ModeFlag enabled_modes);
// Parses the dates in 'input' and fills result. Makes sure that the results
// do not overlap.
@@ -61,6 +63,7 @@ class GrammarDatetimeParser : public DatetimeParser {
const DatetimeGrounder& datetime_grounder_;
const float target_classification_score_;
const float priority_score_;
+ const ModeFlag enabled_modes_;
};
} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/grammar-parser_test.cc b/native/annotator/datetime/grammar-parser_test.cc
index cf2dffd..b8a270d 100644
--- a/native/annotator/datetime/grammar-parser_test.cc
+++ b/native/annotator/datetime/grammar-parser_test.cc
@@ -22,6 +22,7 @@
#include "annotator/datetime/datetime-grounder.h"
#include "annotator/datetime/testing/base-parser-test.h"
#include "annotator/datetime/testing/datetime-component-builder.h"
+#include "annotator/model_generated.h"
#include "utils/grammar/analyzer.h"
#include "utils/jvm-test-utils.h"
#include "utils/test-data-test-utils.h"
@@ -42,7 +43,15 @@ std::string ReadFile(const std::string& file_name) {
class GrammarDatetimeParserTest : public DateTimeParserTest {
public:
- void SetUp() override {
+ void SetUp() override { ResetParser(ModeFlag_ALL); }
+
+ // Exposes the date time parser for tests and evaluations.
+ const DatetimeParser* DatetimeParserForTests() const override {
+ return parser_.get();
+ }
+
+ protected:
+ void ResetParser(ModeFlag enabled_modes) {
grammar_buffer_ = ReadFile(GetModelPath() + "datetime.fb");
unilib_ = CreateUniLibForTesting();
calendarlib_ = CreateCalendarLibForTesting();
@@ -51,12 +60,8 @@ class GrammarDatetimeParserTest : public DateTimeParserTest {
datetime_grounder_ = std::make_unique<DatetimeGrounder>(calendarlib_.get());
parser_.reset(new GrammarDatetimeParser(*analyzer_, *datetime_grounder_,
/*target_classification_score=*/1.0,
- /*priority_score=*/1.0));
- }
-
- // Exposes the date time parser for tests and evaluations.
- const DatetimeParser* DatetimeParserForTests() const override {
- return parser_.get();
+ /*priority_score=*/1.0,
+ enabled_modes));
}
private:
@@ -486,6 +491,13 @@ TEST_F(GrammarDatetimeParserTest, Parse) {
.Build()}));
}
+TEST_F(GrammarDatetimeParserTest, NotEnabledModeHasNoResult) {
+ ResetParser(ModeFlag_SELECTION);
+ // `DateTimeParserTest` implementation parses the input under the ANNOTATION
+ // mode.
+ EXPECT_TRUE(HasNoResult("{January 1, 1988}"));
+}
+
TEST_F(GrammarDatetimeParserTest, DateValidation) {
EXPECT_TRUE(ParsesCorrectly(
"{01/02/2020}", 1577919600000, GRANULARITY_DAY,
diff --git a/native/annotator/duration/duration.cc b/native/annotator/duration/duration.cc
index c59b8e0..df4c60d 100644
--- a/native/annotator/duration/duration.cc
+++ b/native/annotator/duration/duration.cc
@@ -20,6 +20,7 @@
#include <cstdlib>
#include "annotator/collections.h"
+#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/base/macros.h"
@@ -125,8 +126,10 @@ bool DurationAnnotator::ClassifyText(
const UnicodeText& context, CodepointSpan selection_indices,
AnnotationUsecase annotation_usecase,
ClassificationResult* classification_result) const {
- if (!options_->enabled() || ((options_->enabled_annotation_usecases() &
- (1 << annotation_usecase))) == 0) {
+ if (!options_->enabled() ||
+ ((options_->enabled_annotation_usecases() & (1 << annotation_usecase))) ==
+ 0 ||
+ !(options_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
return false;
}
@@ -151,9 +154,12 @@ bool DurationAnnotator::ClassifyText(
bool DurationAnnotator::FindAll(const UnicodeText& context,
const std::vector<Token>& tokens,
AnnotationUsecase annotation_usecase,
+ ModeFlag mode,
std::vector<AnnotatedSpan>* results) const {
- if (!options_->enabled() || ((options_->enabled_annotation_usecases() &
- (1 << annotation_usecase))) == 0) {
+ if (!options_->enabled() ||
+ ((options_->enabled_annotation_usecases() & (1 << annotation_usecase))) ==
+ 0 ||
+ !(options_->enabled_modes() & mode)) {
return true;
}
diff --git a/native/annotator/duration/duration.h b/native/annotator/duration/duration.h
index 1a42ac3..e99542c 100644
--- a/native/annotator/duration/duration.h
+++ b/native/annotator/duration/duration.h
@@ -87,7 +87,7 @@ class DurationAnnotator {
// Finds all duration instances in the input text.
bool FindAll(const UnicodeText& context, const std::vector<Token>& tokens,
- AnnotationUsecase annotation_usecase,
+ AnnotationUsecase annotation_usecase, ModeFlag mode,
std::vector<AnnotatedSpan>* results) const;
private:
diff --git a/native/annotator/duration/duration_test.cc b/native/annotator/duration/duration_test.cc
index 7c07a72..f726058 100644
--- a/native/annotator/duration/duration_test.cc
+++ b/native/annotator/duration/duration_test.cc
@@ -16,6 +16,7 @@
#include "annotator/duration/duration.h"
+#include <cstddef>
#include <string>
#include <vector>
@@ -37,41 +38,61 @@ using testing::ElementsAre;
using testing::Field;
using testing::IsEmpty;
-const DurationAnnotatorOptions* TestingDurationAnnotatorOptions() {
- static const flatbuffers::DetachedBuffer* options_data = []() {
- DurationAnnotatorOptionsT options;
- options.enabled = true;
+namespace {
+const flatbuffers::DetachedBuffer* CreateOptionsData(ModeFlag enabled_modes) {
+ DurationAnnotatorOptionsT options;
+ options.enabled = true;
+ options.enabled_modes = enabled_modes;
- options.week_expressions.push_back("week");
- options.week_expressions.push_back("weeks");
+ options.week_expressions.push_back("week");
+ options.week_expressions.push_back("weeks");
- options.day_expressions.push_back("day");
- options.day_expressions.push_back("days");
+ options.day_expressions.push_back("day");
+ options.day_expressions.push_back("days");
- options.hour_expressions.push_back("hour");
- options.hour_expressions.push_back("hours");
+ options.hour_expressions.push_back("hour");
+ options.hour_expressions.push_back("hours");
- options.minute_expressions.push_back("minute");
- options.minute_expressions.push_back("minutes");
+ options.minute_expressions.push_back("minute");
+ options.minute_expressions.push_back("minutes");
- options.second_expressions.push_back("second");
- options.second_expressions.push_back("seconds");
+ options.second_expressions.push_back("second");
+ options.second_expressions.push_back("seconds");
- options.filler_expressions.push_back("and");
- options.filler_expressions.push_back("a");
- options.filler_expressions.push_back("an");
- options.filler_expressions.push_back("one");
+ options.filler_expressions.push_back("and");
+ options.filler_expressions.push_back("a");
+ options.filler_expressions.push_back("an");
+ options.filler_expressions.push_back("one");
- options.half_expressions.push_back("half");
+ options.half_expressions.push_back("half");
- options.sub_token_separator_codepoints.push_back('-');
+ options.sub_token_separator_codepoints.push_back('-');
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(DurationAnnotatorOptions::Pack(builder, &options));
- return new flatbuffers::DetachedBuffer(builder.Release());
- }();
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(DurationAnnotatorOptions::Pack(builder, &options));
+ return new flatbuffers::DetachedBuffer(builder.Release());
+}
+} // namespace
- return flatbuffers::GetRoot<DurationAnnotatorOptions>(options_data->data());
+const DurationAnnotatorOptions* TestingDurationAnnotatorOptions(
+ ModeFlag enabled_modes) {
+ static const flatbuffers::DetachedBuffer* options_data_all =
+ CreateOptionsData(ModeFlag_ALL);
+ static const flatbuffers::DetachedBuffer* options_data_selection =
+ CreateOptionsData(ModeFlag_SELECTION);
+ static const flatbuffers::DetachedBuffer* options_data_no_selection =
+ CreateOptionsData(ModeFlag_ANNOTATION_AND_CLASSIFICATION);
+
+ if (enabled_modes == ModeFlag_SELECTION) {
+ return flatbuffers::GetRoot<DurationAnnotatorOptions>(
+ options_data_selection->data());
+ } else if (enabled_modes == ModeFlag_ANNOTATION_AND_CLASSIFICATION) {
+ return flatbuffers::GetRoot<DurationAnnotatorOptions>(
+ options_data_no_selection->data());
+ } else {
+ return flatbuffers::GetRoot<DurationAnnotatorOptions>(
+ options_data_all->data());
+ }
}
std::unique_ptr<FeatureProcessor> BuildFeatureProcessor(const UniLib* unilib) {
@@ -103,10 +124,10 @@ std::unique_ptr<FeatureProcessor> BuildFeatureProcessor(const UniLib* unilib) {
class DurationAnnotatorTest : public ::testing::Test {
protected:
- DurationAnnotatorTest()
+ explicit DurationAnnotatorTest(ModeFlag enabled_modes = ModeFlag_ALL)
: INIT_UNILIB_FOR_TESTING(unilib_),
feature_processor_(BuildFeatureProcessor(&unilib_)),
- duration_annotator_(TestingDurationAnnotatorOptions(),
+ duration_annotator_(TestingDurationAnnotatorOptions(enabled_modes),
feature_processor_.get(), &unilib_) {}
std::vector<Token> Tokenize(const UnicodeText& text) {
@@ -118,6 +139,19 @@ class DurationAnnotatorTest : public ::testing::Test {
DurationAnnotator duration_annotator_;
};
+class DurationAnnotatorForAnnotationAndClassificationTest
+ : public DurationAnnotatorTest {
+ protected:
+ DurationAnnotatorForAnnotationAndClassificationTest()
+ : DurationAnnotatorTest(ModeFlag_ANNOTATION_AND_CLASSIFICATION) {}
+};
+
+class DurationAnnotatorForSelectionTest : public DurationAnnotatorTest {
+ protected:
+ DurationAnnotatorForSelectionTest()
+ : DurationAnnotatorTest(ModeFlag_SELECTION) {}
+};
+
TEST_F(DurationAnnotatorTest, ClassifiesSimpleDuration) {
ClassificationResult classification;
EXPECT_TRUE(duration_annotator_.ClassifyText(
@@ -129,6 +163,14 @@ TEST_F(DurationAnnotatorTest, ClassifiesSimpleDuration) {
Field(&ClassificationResult::duration_ms, 15 * 60 * 1000)));
}
+TEST_F(DurationAnnotatorForSelectionTest,
+ ClassifyTextDisabledClassificationReturnsFalse) {
+ ClassificationResult classification;
+ EXPECT_FALSE(duration_annotator_.ClassifyText(
+ UTF8ToUnicodeText("Wake me up in 15 minutes ok?"), {14, 24},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification));
+}
+
TEST_F(DurationAnnotatorTest, ClassifiesWhenTokensDontAlignWithSelection) {
ClassificationResult classification;
EXPECT_TRUE(duration_annotator_.ClassifyText(
@@ -152,7 +194,8 @@ TEST_F(DurationAnnotatorTest, FindsSimpleDuration) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_SELECTION, &result));
EXPECT_THAT(
result,
@@ -165,13 +208,26 @@ TEST_F(DurationAnnotatorTest, FindsSimpleDuration) {
15 * 60 * 1000)))))));
}
+TEST_F(DurationAnnotatorForAnnotationAndClassificationTest,
+ FindsAllDisabledModeReturnsNoResults) {
+ const UnicodeText text = UTF8ToUnicodeText("Wake me up in 15 minutes ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_SELECTION, &result));
+
+ EXPECT_THAT(result, IsEmpty());
+}
+
TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpression) {
const UnicodeText text =
UTF8ToUnicodeText("Set a timer for 3 and half minutes ok?");
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -190,7 +246,8 @@ TEST_F(DurationAnnotatorTest, FindsComposedDuration) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_SELECTION, &result));
EXPECT_THAT(
result,
@@ -209,7 +266,8 @@ TEST_F(DurationAnnotatorTest, AllUnitsAreCovered) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -228,7 +286,8 @@ TEST_F(DurationAnnotatorTest, FindsHalfAnHour) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -247,7 +306,8 @@ TEST_F(DurationAnnotatorTest, FindsWhenHalfIsAfterGranularitySpecification) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_SELECTION, &result));
EXPECT_THAT(
result,
@@ -266,7 +326,8 @@ TEST_F(DurationAnnotatorTest, FindsAnHourAndAHalf) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -286,7 +347,8 @@ TEST_F(DurationAnnotatorTest,
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -305,7 +367,8 @@ TEST_F(DurationAnnotatorTest, DoesNotGreedilyTakeFillerWords) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -323,7 +386,8 @@ TEST_F(DurationAnnotatorTest, DoesNotCrashWhenJustHalfIsSaid) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
ASSERT_EQ(result.size(), 0);
}
@@ -334,7 +398,8 @@ TEST_F(DurationAnnotatorTest, StripsPunctuationFromTokens) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -352,7 +417,8 @@ TEST_F(DurationAnnotatorTest, FindsCorrectlyWithCombinedQuantityUnitToken) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -383,7 +449,8 @@ TEST_F(DurationAnnotatorTest, FindsSimpleDurationIgnoringCase) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -402,7 +469,8 @@ TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpressionIgnoringCase) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -422,7 +490,8 @@ TEST_F(DurationAnnotatorTest,
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -440,7 +509,8 @@ TEST_F(DurationAnnotatorTest, FindsDurationWithDanglingQuantity) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -458,7 +528,8 @@ TEST_F(DurationAnnotatorTest, FindsDurationWithDanglingQuantityNotSupported) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -475,7 +546,8 @@ TEST_F(DurationAnnotatorTest, FindsDurationWithDecimalQuantity) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -540,7 +612,8 @@ TEST_F(JapaneseDurationAnnotatorTest, FindsDuration) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -558,7 +631,8 @@ TEST_F(JapaneseDurationAnnotatorTest, FindsDurationWithHalfExpression) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -576,7 +650,8 @@ TEST_F(JapaneseDurationAnnotatorTest, IgnoresDurationWithoutQuantity) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result, IsEmpty());
}
@@ -586,7 +661,8 @@ TEST_F(JapaneseDurationAnnotatorTest, FindsDurationWithDanglingQuantity) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_SELECTION, &result));
EXPECT_THAT(
result,
diff --git a/native/annotator/installed_app/installed-app-engine-dummy.h b/native/annotator/installed_app/installed-app-engine-dummy.h
index 2f2b62f..80f5a5c 100644
--- a/native/annotator/installed_app/installed-app-engine-dummy.h
+++ b/native/annotator/installed_app/installed-app-engine-dummy.h
@@ -32,7 +32,7 @@ namespace libtextclassifier3 {
class InstalledAppEngine {
public:
explicit InstalledAppEngine(const FeatureProcessor* feature_processor,
- const UniLib* unilib) {}
+ const UniLib* unilib, ModeFlag enabled_modes) {}
bool Initialize(const std::string& serialized_config) {
TC3_LOG(ERROR) << "No installed app engine to initialize.";
@@ -45,7 +45,7 @@ class InstalledAppEngine {
}
bool Chunk(const UnicodeText& context_unicode,
- const std::vector<Token>& tokens,
+ const std::vector<Token>& tokens, ModeFlag mode,
std::vector<AnnotatedSpan>* result) const {
return true;
}
diff --git a/native/annotator/knowledge/knowledge-engine-dummy.h b/native/annotator/knowledge/knowledge-engine-dummy.h
index 34fa490..949018c 100644
--- a/native/annotator/knowledge/knowledge-engine-dummy.h
+++ b/native/annotator/knowledge/knowledge-engine-dummy.h
@@ -37,6 +37,8 @@ class KnowledgeEngine {
void SetPriorityScore(float priority_score) {}
+ void SetEnabledModes(ModeFlag enabled_modes) {}
+
Status ClassifyText(const std::string& text, CodepointSpan selection_indices,
AnnotationUsecase annotation_usecase,
const Optional<LocationContext>& location_context,
@@ -48,7 +50,7 @@ class KnowledgeEngine {
Status Chunk(const std::string& text, AnnotationUsecase annotation_usecase,
const Optional<LocationContext>& location_context,
const Permissions& permissions, const AnnotateMode annotate_mode,
- Annotations* result) const {
+ ModeFlag mode, Annotations* result) const {
return Status::OK;
}
@@ -58,7 +60,7 @@ class KnowledgeEngine {
AnnotationUsecase annotation_usecase,
const Optional<LocationContext>& location_context,
const Permissions& permissions, const AnnotateMode annotate_mode,
- Annotations* results) const {
+ ModeFlag mode, Annotations* results) const {
return Status::OK;
}
diff --git a/native/annotator/model.fbs b/native/annotator/model.fbs
index 57187f5..eeb4101 100644
--- a/native/annotator/model.fbs
+++ b/native/annotator/model.fbs
@@ -415,6 +415,7 @@ table GrammarModel {
// The grammar rules.
rules:grammar.RulesSet;
+ // Deprecated. Used only for the old implementation of the grammar model.
rule_classification_result:[GrammarModel_.RuleClassificationResult];
// Number of tokens in the context to use for classification and text
@@ -432,6 +433,10 @@ table GrammarModel {
// The priority score used for conflict resolution with the other models.
priority_score:float = 1;
+
+ // Global enabled modes. Use this instead of
+ // `rule_classification_result.enabled_modes`.
+ enabled_modes:ModeFlag = ALL;
}
namespace libtextclassifier3.MoneyParsingOptions_;
@@ -486,6 +491,15 @@ table ModelTriggeringOptions {
// map. Key: collection type e.g. "address", "phone"..., Value: float number.
// NOTE: The entries here need to be sorted since we use LookupByKey.
collection_to_priority:[ModelTriggeringOptions_.CollectionToPriorityEntry];
+
+ // Enabled modes for the knowledge engine model.
+ knowledge_enabled_modes:ModeFlag = ALL;
+
+ // Enabled modes for the experimental model.
+ experimental_enabled_modes:ModeFlag = ALL;
+
+ // Enabled modes for the installed app model.
+ installed_app_enabled_modes:ModeFlag = ALL;
}
// Options controlling the output of the classifier.
@@ -894,6 +908,9 @@ table ContactAnnotatorOptions {
// For each language there is a customized list of supported declensions.
language:string (shared);
+
+ // Enabled modes.
+ enabled_modes:ModeFlag = ALL;
}
namespace libtextclassifier3.TranslateAnnotatorOptions_;
@@ -927,6 +944,9 @@ table TranslateAnnotatorOptions {
algorithm:TranslateAnnotatorOptions_.Algorithm;
backoff_options:TranslateAnnotatorOptions_.BackoffOptions;
+
+ // Enabled modes.
+ enabled_modes:ModeFlag = CLASSIFICATION;
}
namespace libtextclassifier3.PodNerModel_;
@@ -1012,6 +1032,9 @@ table PodNerModel {
min_number_of_tokens:int = 1;
min_number_of_wordpieces:int = 1;
+
+ // Enabled modes.
+ enabled_modes:ModeFlag = ALL;
}
namespace libtextclassifier3;
@@ -1043,6 +1066,9 @@ table VocabModel {
// Priority score used for conflict resolution with the other models.
priority_score:float = 0;
+
+ // Enabled modes.
+ enabled_modes:ModeFlag = ANNOTATION_AND_CLASSIFICATION;
}
root_type libtextclassifier3.Model;
diff --git a/native/annotator/number/number.cc b/native/annotator/number/number.cc
index 3be6ad8..14fc24e 100644
--- a/native/annotator/number/number.cc
+++ b/native/annotator/number/number.cc
@@ -21,6 +21,7 @@
#include <string>
#include "annotator/collections.h"
+#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/strings/split.h"
@@ -38,7 +39,8 @@ bool NumberAnnotator::ClassifyText(
context, selection_indices.first, selection_indices.second);
std::vector<AnnotatedSpan> results;
- if (!FindAll(substring_selected, annotation_usecase, &results)) {
+ if (!FindAll(substring_selected, annotation_usecase, ModeFlag_CLASSIFICATION,
+ &results)) {
return false;
}
@@ -216,8 +218,9 @@ bool NumberAnnotator::TryParseNumber(const UnicodeText& token_text,
bool NumberAnnotator::FindAll(const UnicodeText& context,
AnnotationUsecase annotation_usecase,
+ ModeFlag mode,
std::vector<AnnotatedSpan>* result) const {
- if (!options_->enabled()) {
+ if (!options_->enabled() || !(options_->enabled_modes() & mode)) {
return true;
}
diff --git a/native/annotator/number/number.h b/native/annotator/number/number.h
index d83bea0..dcc2d48 100644
--- a/native/annotator/number/number.h
+++ b/native/annotator/number/number.h
@@ -58,7 +58,7 @@ class NumberAnnotator {
// Finds all number instances in the input text. Returns true in any case.
bool FindAll(const UnicodeText& context_unicode,
- AnnotationUsecase annotation_usecase,
+ AnnotationUsecase annotation_usecase, ModeFlag mode,
std::vector<AnnotatedSpan>* result) const;
private:
diff --git a/native/annotator/number/number_test-include.cc b/native/annotator/number/number_test-include.cc
index f47933f..98140f4 100644
--- a/native/annotator/number/number_test-include.cc
+++ b/native/annotator/number/number_test-include.cc
@@ -16,6 +16,7 @@
#include "annotator/number/number_test-include.h"
+#include <set>
#include <string>
#include <vector>
@@ -34,37 +35,57 @@ namespace test_internal {
using ::testing::AllOf;
using ::testing::ElementsAre;
using ::testing::Field;
+using ::testing::IsEmpty;
using ::testing::Matcher;
using ::testing::UnorderedElementsAre;
+namespace {
+const flatbuffers::DetachedBuffer* CreateOptionsData(ModeFlag enabled_modes) {
+ NumberAnnotatorOptionsT options;
+ options.enabled = true;
+ options.priority_score = -10.0;
+ options.float_number_priority_score = 1.0;
+ options.enabled_annotation_usecases =
+ 1 << AnnotationUsecase_ANNOTATION_USECASE_RAW;
+ options.max_number_of_digits = 20;
+ options.enabled_modes = enabled_modes;
+
+ options.percentage_priority_score = 1.0;
+ options.percentage_annotation_usecases =
+ (1 << AnnotationUsecase_ANNOTATION_USECASE_RAW) +
+ (1 << AnnotationUsecase_ANNOTATION_USECASE_SMART);
+ std::set<std::string> percent_suffixes(
+ {"パーセント", "percent", "pércént", "pc", "pct", "%", "٪", "﹪", "%"});
+ for (const std::string& string_value : percent_suffixes) {
+ options.percentage_pieces_string.append(string_value);
+ options.percentage_pieces_string.push_back('\0');
+ }
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(NumberAnnotatorOptions::Pack(builder, &options));
+ return new flatbuffers::DetachedBuffer(builder.Release());
+}
+} // namespace
+
const NumberAnnotatorOptions*
-NumberAnnotatorTest::TestingNumberAnnotatorOptions() {
- static const flatbuffers::DetachedBuffer* options_data = []() {
- NumberAnnotatorOptionsT options;
- options.enabled = true;
- options.priority_score = -10.0;
- options.float_number_priority_score = 1.0;
- options.enabled_annotation_usecases =
- 1 << AnnotationUsecase_ANNOTATION_USECASE_RAW;
- options.max_number_of_digits = 20;
-
- options.percentage_priority_score = 1.0;
- options.percentage_annotation_usecases =
- (1 << AnnotationUsecase_ANNOTATION_USECASE_RAW) +
- (1 << AnnotationUsecase_ANNOTATION_USECASE_SMART);
- std::set<std::string> percent_suffixes({"パーセント", "percent", "pércént",
- "pc", "pct", "%", "٪", "﹪", "%"});
- for (const std::string& string_value : percent_suffixes) {
- options.percentage_pieces_string.append(string_value);
- options.percentage_pieces_string.push_back('\0');
- }
-
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(NumberAnnotatorOptions::Pack(builder, &options));
- return new flatbuffers::DetachedBuffer(builder.Release());
- }();
-
- return flatbuffers::GetRoot<NumberAnnotatorOptions>(options_data->data());
+NumberAnnotatorTest::TestingNumberAnnotatorOptions(ModeFlag enabled_modes) {
+ static const flatbuffers::DetachedBuffer* options_data_selection =
+ CreateOptionsData(ModeFlag_SELECTION);
+ static const flatbuffers::DetachedBuffer* options_data_no_selection =
+ CreateOptionsData(ModeFlag_ANNOTATION_AND_CLASSIFICATION);
+ static const flatbuffers::DetachedBuffer* options_data_all =
+ CreateOptionsData(ModeFlag_ALL);
+
+ if (enabled_modes == ModeFlag_SELECTION) {
+ return flatbuffers::GetRoot<NumberAnnotatorOptions>(
+ options_data_selection->data());
+ } else if (enabled_modes == ModeFlag_ANNOTATION_AND_CLASSIFICATION) {
+ return flatbuffers::GetRoot<NumberAnnotatorOptions>(
+ options_data_no_selection->data());
+ } else {
+ return flatbuffers::GetRoot<NumberAnnotatorOptions>(
+ options_data_all->data());
+ }
}
MATCHER_P(IsCorrectCollection, collection, "collection is " + collection) {
@@ -124,6 +145,14 @@ TEST_F(NumberAnnotatorTest, ClassifiesAndParsesNumberCorrectly) {
EXPECT_FLOAT_EQ(classification_result.numeric_double_value, 12345);
}
+TEST_F(NumberAnnotatorForSelectionTest,
+ ClassifyTextDisabledClassificationReturnsFalse) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("... 12345 ..."), {4, 9},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
TEST_F(NumberAnnotatorTest, ClassifiesAndParsesNumberAsFloatCorrectly) {
ClassificationResult classification_result;
EXPECT_TRUE(number_annotator_.ClassifyText(
@@ -167,7 +196,7 @@ TEST_F(NumberAnnotatorTest, FindsAllIntegerAndFloatNumbersInText) {
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("how much is 2 plus 5 divided by 7% minus 3.14 "
"what about 68.9# or 68.9#?"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
@@ -268,7 +297,8 @@ TEST_F(NumberAnnotatorTest, ClassifiesNonAsciiJaPercentageCorrectSuffix) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("明日の降水確率は10パーセント 音量を12にセット"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_CLASSIFICATION,
+ &result));
EXPECT_THAT(result,
UnorderedElementsAre(
IsAnnotatedSpan(CodepointSpan(8, 10), "number",
@@ -285,7 +315,7 @@ TEST_F(NumberAnnotatorTest, FindsAllNumbersInText) {
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("... 12345 ... 9 is my number and 27% or 68# #38 #39 "
"but not $99."),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -307,12 +337,23 @@ TEST_F(NumberAnnotatorTest, FindsAllNumbersInText) {
/*int_value=*/39, /*double_value=*/39.0)));
}
+TEST_F(NumberAnnotatorForAnnotationAndClassificationTest,
+ FindsAllDisabledModeReturnsNoResults) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("... 12345 ... 9 is my number and 27% or 68# #38 #39 "
+ "but not $99."),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_SELECTION, &result));
+
+ EXPECT_THAT(result, IsEmpty());
+}
+
TEST_F(NumberAnnotatorTest, FindsNoNumberInText) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("... 12345a ... 12345..12345 and 123a45 are not valid. "
"And -#5% is also bad."),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_SELECTION, &result));
ASSERT_EQ(result.size(), 0);
}
@@ -323,7 +364,8 @@ TEST_F(NumberAnnotatorTest, FindsNumberWithPunctuation) {
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText(
"It's 12, 13, 14! Or 15??? For sure 16: 17; 18. and -19"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_CLASSIFICATION,
+ &result));
EXPECT_THAT(result,
UnorderedElementsAre(
@@ -348,7 +390,7 @@ TEST_F(NumberAnnotatorTest, FindsFloatNumberWithPunctuation) {
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("It's 12.123, 13.45, 14.54321! Or 15.1? Maybe 16.33: "
"17.21; but for sure 18.90."),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
@@ -379,7 +421,7 @@ TEST_F(NumberAnnotatorTest, HandlesNumbersAtBeginning) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("-5"), AnnotationUsecase_ANNOTATION_USECASE_RAW,
- &result));
+ ModeFlag_SELECTION, &result));
EXPECT_THAT(result, UnorderedElementsAre(IsAnnotatedSpan(
CodepointSpan(0, 2), "number",
@@ -390,7 +432,7 @@ TEST_F(NumberAnnotatorTest, HandlesNegativeNumbers) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("Number -5 and -5% and not number --5%"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
@@ -408,7 +450,7 @@ TEST_F(NumberAnnotatorTest, FindGoodPercentageContexts) {
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText(
"5 percent, 10 pct, 25 pc and 17%, -5 percent, 10% are percentages"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_SELECTION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
@@ -448,7 +490,7 @@ TEST_F(NumberAnnotatorTest, FindSinglePercentageInContext) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("5%"), AnnotationUsecase_ANNOTATION_USECASE_RAW,
- &result));
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result, UnorderedElementsAre(
IsAnnotatedSpan(CodepointSpan(0, 1), "number",
@@ -463,7 +505,7 @@ TEST_F(NumberAnnotatorTest, IgnoreBadPercentageContexts) {
// A valid number is followed by only one punctuation element.
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("10, pct, 25 prc, 5#: percentage are not percentages"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
@@ -478,7 +520,7 @@ TEST_F(NumberAnnotatorTest, IgnoreBadPercentagePunctuationContexts) {
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText(
"#!24% or :?33 percent are not valid percentages, nor numbers."),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_TRUE(result.empty());
}
@@ -488,7 +530,7 @@ TEST_F(NumberAnnotatorTest, FindPercentageInNonAsciiContext) {
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText(
"At the café 10% or 25 percent of people are nice. Only 10%!"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
@@ -748,7 +790,7 @@ TEST_F(NumberAnnotatorTest, WhenSuffixWithoutNumberDoesNotParseIt) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("... % ..."), AnnotationUsecase_ANNOTATION_USECASE_RAW,
- &result));
+ ModeFlag_ANNOTATION, &result));
ASSERT_EQ(result.size(), 0);
}
@@ -757,7 +799,7 @@ TEST_F(NumberAnnotatorTest, WhenPrefixWithoutNumberDoesNotParseIt) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("... $ ..."), AnnotationUsecase_ANNOTATION_USECASE_RAW,
- &result));
+ ModeFlag_ANNOTATION, &result));
ASSERT_EQ(result.size(), 0);
}
@@ -766,7 +808,7 @@ TEST_F(NumberAnnotatorTest, WhenPrefixAndSuffixWithoutNumberDoesNotParseIt) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("... $% ..."), AnnotationUsecase_ANNOTATION_USECASE_RAW,
- &result));
+ ModeFlag_ANNOTATION, &result));
ASSERT_EQ(result.size(), 0);
}
@@ -786,7 +828,7 @@ TEST_F(NumberAnnotatorTest, ForNumberAnnotationsSetsScoreAndPriorityScore) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("Come at 9 or 10 ok?"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
@@ -811,7 +853,7 @@ TEST_F(NumberAnnotatorTest,
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("Results are between 12.5 and 13.5, right?"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
IsAnnotatedSpan(CodepointSpan(20, 24), "number",
@@ -845,7 +887,7 @@ TEST_F(NumberAnnotatorTest, ForPercentageAnnotationsSetsScoreAndPriorityScore) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("Results are between 9% and 10 percent."),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
IsAnnotatedSpan(CodepointSpan(20, 21), "number",
@@ -887,7 +929,8 @@ TEST_F(NumberAnnotatorTest, NumberDisabledPercentageEnabledForSmartUsecase) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("Accuracy for experiment 3 is 9%."),
- AnnotationUsecase_ANNOTATION_USECASE_SMART, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_SMART, ModeFlag_ANNOTATION,
+ &result));
EXPECT_THAT(result, UnorderedElementsAre(
IsAnnotatedSpan(CodepointSpan(29, 31), "percentage",
/*int_value=*/9, /*double_value=*/9.0,
@@ -898,7 +941,7 @@ TEST_F(NumberAnnotatorTest, MathOperatorsNotAnnotatedAsNumbersFindAll) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("how much is 2 + 2 or 5 - 96 * 89"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
@@ -928,7 +971,7 @@ TEST_F(NumberAnnotatorTest, SlashSeparatesTwoNumbersFindAll) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("what's 1 + 2/3 * 4/5 * 6 / 7"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
@@ -972,7 +1015,7 @@ TEST_F(NumberAnnotatorTest, SlashDoesNotSeparatesTwoNumbersFindAll) {
// 2 in the "2/" context is a number because / is punctuation
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("what's 2a2/3 or 2/s4 or 2/ or /3 or //3 or 2//"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result, UnorderedElementsAre(IsAnnotatedSpan(
CodepointSpan(24, 25), "number",
@@ -983,7 +1026,7 @@ TEST_F(NumberAnnotatorTest, BracketsContextAnnotatedFindAll) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("The interval is: (12, 13) or [-12, -4.5)"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
@@ -1002,7 +1045,7 @@ TEST_F(NumberAnnotatorTest, BracketsContextNotAnnotatedFindAll) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("The interval is: -(12, 138*)"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_TRUE(result.empty());
}
@@ -1012,7 +1055,7 @@ TEST_F(NumberAnnotatorTest, FractionalNumberDotsFindAll) {
// Dots source: https://unicode-search.net/unicode-namesearch.pl?term=period
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("3.1 3﹒2 3.3"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result, UnorderedElementsAre(
IsAnnotatedSpan(CodepointSpan(0, 3), "number",
@@ -1032,7 +1075,7 @@ TEST_F(NumberAnnotatorTest, NonAsciiDigitsFindAll) {
// Digits source: https://unicode-search.net/unicode-namesearch.pl?term=digit
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("3 3﹒2 3.3%"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result, UnorderedElementsAre(
IsAnnotatedSpan(CodepointSpan(0, 1), "number",
@@ -1052,7 +1095,7 @@ TEST_F(NumberAnnotatorTest, AnnotatedZeroPrecededNumbersFindAll) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("Numbers: 0.9 or 09 or 09.9 or 032310"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result, UnorderedElementsAre(
IsAnnotatedSpan(CodepointSpan(9, 12), "number",
@@ -1072,7 +1115,7 @@ TEST_F(NumberAnnotatorTest, ZeroAfterDotFindAll) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("15.0 16.00"), AnnotationUsecase_ANNOTATION_USECASE_RAW,
- &result));
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
@@ -1086,7 +1129,7 @@ TEST_F(NumberAnnotatorTest, NineDotNineFindAll) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("9.9 9.99 99.99 99.999 99.9999"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
diff --git a/native/annotator/number/number_test-include.h b/native/annotator/number/number_test-include.h
index 9de7c86..14fc6f2 100644
--- a/native/annotator/number/number_test-include.h
+++ b/native/annotator/number/number_test-include.h
@@ -17,6 +17,7 @@
#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_NUMBER_NUMBER_TEST_INCLUDE_H_
#define LIBTEXTCLASSIFIER_ANNOTATOR_NUMBER_NUMBER_TEST_INCLUDE_H_
+#include "annotator/model_generated.h"
#include "annotator/number/number.h"
#include "utils/jvm-test-utils.h"
#include "gtest/gtest.h"
@@ -25,17 +26,32 @@ namespace libtextclassifier3 {
namespace test_internal {
class NumberAnnotatorTest : public ::testing::Test {
+ private:
protected:
- NumberAnnotatorTest()
+ explicit NumberAnnotatorTest(ModeFlag enabled_modes = ModeFlag_ALL)
: unilib_(CreateUniLibForTesting()),
- number_annotator_(TestingNumberAnnotatorOptions(), unilib_.get()) {}
+ number_annotator_(TestingNumberAnnotatorOptions(enabled_modes),
+ unilib_.get()) {}
- const NumberAnnotatorOptions* TestingNumberAnnotatorOptions();
+ const NumberAnnotatorOptions* TestingNumberAnnotatorOptions(
+ ModeFlag enabled_modes);
std::unique_ptr<UniLib> unilib_;
NumberAnnotator number_annotator_;
};
+class NumberAnnotatorForAnnotationAndClassificationTest
+ : public NumberAnnotatorTest {
+ protected:
+ NumberAnnotatorForAnnotationAndClassificationTest()
+ : NumberAnnotatorTest(ModeFlag_ANNOTATION_AND_CLASSIFICATION) {}
+};
+
+class NumberAnnotatorForSelectionTest : public NumberAnnotatorTest {
+ protected:
+ NumberAnnotatorForSelectionTest() : NumberAnnotatorTest(ModeFlag_SELECTION) {}
+};
+
} // namespace test_internal
} // namespace libtextclassifier3
diff --git a/native/annotator/person_name/person-name-engine-dummy.h b/native/annotator/person_name/person-name-engine-dummy.h
index 9c83241..44d2821 100644
--- a/native/annotator/person_name/person-name-engine-dummy.h
+++ b/native/annotator/person_name/person-name-engine-dummy.h
@@ -46,7 +46,7 @@ class PersonNameEngine {
}
bool Chunk(const UnicodeText& context_unicode,
- const std::vector<Token>& tokens,
+ const std::vector<Token>& tokens, ModeFlag mode,
std::vector<AnnotatedSpan>* result) const {
return true;
}
diff --git a/native/annotator/person_name/person_name_model.fbs b/native/annotator/person_name/person_name_model.fbs
index b15543f..6ef4a72 100644
--- a/native/annotator/person_name/person_name_model.fbs
+++ b/native/annotator/person_name/person_name_model.fbs
@@ -14,6 +14,8 @@
// limitations under the License.
//
+include "annotator/model.fbs";
+
file_identifier "TC2 ";
// Next ID: 2
@@ -26,7 +28,7 @@ table PersonName {
person_name:string (shared);
}
-// Next ID: 6
+// Next ID: 7
namespace libtextclassifier3;
table PersonNameModel {
// Decides if the person name annotator is enabled.
@@ -52,6 +54,9 @@ table PersonNameModel {
// upper case character and have at least one lower case character.
// required
annotate_capitalized_names_only:bool;
+
+ // Enabled modes.
+ enabled_modes:ModeFlag = ALL;
}
root_type libtextclassifier3.PersonNameModel;
diff --git a/native/annotator/pod_ner/pod-ner-impl.cc b/native/annotator/pod_ner/pod-ner-impl.cc
index 666b7c7..0cb86ee 100644
--- a/native/annotator/pod_ner/pod-ner-impl.cc
+++ b/native/annotator/pod_ner/pod-ner-impl.cc
@@ -398,6 +398,10 @@ bool PodNerAnnotator::AnnotateAroundSpanOfInterest(
std::vector<AnnotatedSpan> *results) const {
TC3_CHECK(results != nullptr);
+ if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
+ return true;
+ }
+
std::vector<int32_t> wordpiece_indices;
std::vector<int32_t> token_starts;
std::vector<Token> tokens;
@@ -470,6 +474,11 @@ bool PodNerAnnotator::SuggestSelection(const UnicodeText &context,
return false;
}
+ if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
+ *result = {};
+ return false;
+ }
+
for (const AnnotatedSpan &annotation : annotations) {
TC3_VLOG(INFO) << "POD NER SuggestSelection: " << annotation;
if (annotation.span.first <= click.first &&
@@ -491,6 +500,10 @@ bool PodNerAnnotator::ClassifyText(const UnicodeText &context,
CodepointSpan click,
ClassificationResult *result) const {
TC3_VLOG(INFO) << "POD NER ClassifyText " << click;
+ if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
+ return false;
+ }
+
std::vector<AnnotatedSpan> annotations;
if (!AnnotateAroundSpanOfInterest(context, click, &annotations)) {
return false;
diff --git a/native/annotator/pod_ner/pod-ner-impl_test.cc b/native/annotator/pod_ner/pod-ner-impl_test.cc
index c7d0bee..5accebd 100644
--- a/native/annotator/pod_ner/pod-ner-impl_test.cc
+++ b/native/annotator/pod_ner/pod-ner-impl_test.cc
@@ -53,7 +53,7 @@ constexpr float kDefaultPriorityScore = 0.5;
class PodNerTest : public testing::Test {
protected:
- PodNerTest() {
+ explicit PodNerTest(ModeFlag enabled_modes = ModeFlag_ALL) {
PodNerModelT model;
model.min_number_of_tokens = kMinNumberOfTokens;
@@ -68,6 +68,7 @@ class PodNerTest : public testing::Test {
GetTestFileContent("annotator/pod_ner/test_data/vocab.txt");
model.word_piece_vocab = std::vector<uint8_t>(
word_piece_vocab_buffer.begin(), word_piece_vocab_buffer.end());
+ model.enabled_modes = enabled_modes;
flatbuffers::FlatBufferBuilder builder;
builder.Finish(PodNerModel::Pack(builder, &model));
@@ -101,6 +102,17 @@ class PodNerTest : public testing::Test {
std::unique_ptr<UniLib> unilib_;
};
+class PodNerForAnnotationAndClassificationTest : public PodNerTest {  // Fixture whose model is built with enabled_modes = ModeFlag_ANNOTATION_AND_CLASSIFICATION.
+ protected:
+ PodNerForAnnotationAndClassificationTest()
+ : PodNerTest(ModeFlag_ANNOTATION_AND_CLASSIFICATION) {}  // Forwards the mode to the base fixture, which stores it in the model.
+};
+
+class PodNerForSelectionTest : public PodNerTest {  // Fixture whose model is built with enabled_modes = ModeFlag_SELECTION.
+ protected:
+ PodNerForSelectionTest() : PodNerTest(ModeFlag_SELECTION) {}  // Forwards the mode to the base fixture, which stores it in the model.
+};
+
TEST_F(PodNerTest, AnnotateSmokeTest) {
std::unique_ptr<PodNerAnnotator> annotator =
PodNerAnnotator::Create(model_, *unilib_);
@@ -209,6 +221,18 @@ TEST_F(PodNerTest, AnnotateDefaultCollections) {
}
}
+TEST_F(PodNerForSelectionTest, AnnotateWithDisabledAnnotationReturnsNoResults) {
+ std::unique_ptr<PodNerAnnotator> annotator =
+ PodNerAnnotator::Create(model_, *unilib_);
+ ASSERT_TRUE(annotator != nullptr);
+ // The fixture's model uses ModeFlag_SELECTION: Annotate must still report success but produce no spans.
+ std::string multi_word_location = "I live in New York";
+ std::vector<AnnotatedSpan> annotations;
+ ASSERT_TRUE(annotator->Annotate(UTF8ToUnicodeText(multi_word_location),
+ &annotations));
+ EXPECT_THAT(annotations, IsEmpty());
+}
+
TEST_F(PodNerTest, AnnotateConfigurableCollections) {
std::unique_ptr<PodNerModelT> unpacked_model(model_->UnPack());
ASSERT_TRUE(unpacked_model != nullptr);
@@ -525,6 +549,18 @@ TEST_F(PodNerTest, SuggestSelectionTest) {
EXPECT_EQ(suggested_span.span, CodepointSpan(kInvalidIndex, kInvalidIndex));
}
+TEST_F(PodNerForAnnotationAndClassificationTest,
+ SuggestSelectionWithDisabledSelectionReturnsNoResults) {
+ std::unique_ptr<PodNerAnnotator> annotator =
+ PodNerAnnotator::Create(model_, *unilib_);
+ ASSERT_TRUE(annotator != nullptr);
+ // The fixture's model uses ModeFlag_ANNOTATION_AND_CLASSIFICATION: SuggestSelection is expected to return false.
+ AnnotatedSpan suggested_span;
+ EXPECT_FALSE(annotator->SuggestSelection(
+ UTF8ToUnicodeText("Google New York, in New York"), {7, 10},
+ &suggested_span));
+}
+
TEST_F(PodNerTest, ClassifyTextTest) {
std::unique_ptr<PodNerAnnotator> annotator =
PodNerAnnotator::Create(model_, *unilib_);
@@ -536,6 +572,17 @@ TEST_F(PodNerTest, ClassifyTextTest) {
EXPECT_EQ(result.collection, "location");
}
+TEST_F(PodNerForSelectionTest,
+ ClassifyTextWithDisabledClassificationReturnsFalse) {
+ std::unique_ptr<PodNerAnnotator> annotator =
+ PodNerAnnotator::Create(model_, *unilib_);
+ ASSERT_TRUE(annotator != nullptr);
+ // The fixture's model uses ModeFlag_SELECTION: ClassifyText is expected to return false.
+ ClassificationResult result;
+ ASSERT_FALSE(annotator->ClassifyText(UTF8ToUnicodeText("We met in New York"),
+ {10, 18}, &result));
+}
+
TEST_F(PodNerTest, ThreadSafety) {
std::unique_ptr<PodNerAnnotator> annotator =
PodNerAnnotator::Create(model_, *unilib_);
diff --git a/native/annotator/translate/translate.cc b/native/annotator/translate/translate.cc
index 2c5a43c..e38109c 100644
--- a/native/annotator/translate/translate.cc
+++ b/native/annotator/translate/translate.cc
@@ -21,6 +21,7 @@
#include "annotator/collections.h"
#include "annotator/entity-data_generated.h"
+#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "lang_id/lang-id-wrapper.h"
#include "utils/base/logging.h"
@@ -34,6 +35,10 @@ bool TranslateAnnotator::ClassifyText(
const UnicodeText& context, CodepointSpan selection_indices,
const std::string& user_familiar_language_tags,
ClassificationResult* classification_result) const {
+ if (!(options_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
+ return false;
+ }
+
std::vector<TranslateAnnotator::LanguageConfidence> confidences;
if (options_->algorithm() ==
TranslateAnnotatorOptions_::Algorithm::Algorithm_BACKOFF) {
diff --git a/native/annotator/translate/translate_test.cc b/native/annotator/translate/translate_test.cc
index 5c4a63f..90227ec 100644
--- a/native/annotator/translate/translate_test.cc
+++ b/native/annotator/translate/translate_test.cc
@@ -31,20 +31,33 @@ namespace {
using testing::AllOf;
using testing::Field;
-const TranslateAnnotatorOptions* TestingTranslateAnnotatorOptions() {
- static const flatbuffers::DetachedBuffer* options_data = []() {
- TranslateAnnotatorOptionsT options;
- options.enabled = true;
- options.algorithm = TranslateAnnotatorOptions_::Algorithm_BACKOFF;
- options.backoff_options.reset(
- new TranslateAnnotatorOptions_::BackoffOptionsT());
-
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(TranslateAnnotatorOptions::Pack(builder, &options));
- return new flatbuffers::DetachedBuffer(builder.Release());
- }();
-
- return flatbuffers::GetRoot<TranslateAnnotatorOptions>(options_data->data());
+const flatbuffers::DetachedBuffer* CreateOptionsData(ModeFlag enabled_modes) {  // Builds serialized TranslateAnnotatorOptions carrying the given enabled_modes.
+ TranslateAnnotatorOptionsT options;
+ options.enabled = true;
+ options.algorithm = TranslateAnnotatorOptions_::Algorithm_BACKOFF;
+ options.backoff_options.reset(
+ new TranslateAnnotatorOptions_::BackoffOptionsT());
+ options.enabled_modes = enabled_modes;
+ // Pack the object API struct into a finished flatbuffer.
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(TranslateAnnotatorOptions::Pack(builder, &options));
+ return new flatbuffers::DetachedBuffer(builder.Release());  // Heap-allocated and returned as a raw pointer; nothing in this function frees it.
+}
+
+const TranslateAnnotatorOptions* TestingTranslateAnnotatorOptions(
+ ModeFlag enabled_modes) {  // Returns the root of a cached options buffer selected by enabled_modes.
+ static const flatbuffers::DetachedBuffer* options_data_classification =
+ CreateOptionsData(ModeFlag_CLASSIFICATION);
+ static const flatbuffers::DetachedBuffer* options_data_none =
+ CreateOptionsData(ModeFlag_NONE);
+ // NOTE: only two buffers exist; any mode other than ModeFlag_CLASSIFICATION gets the ModeFlag_NONE options.
+ if (enabled_modes == ModeFlag_CLASSIFICATION) {
+ return flatbuffers::GetRoot<TranslateAnnotatorOptions>(
+ options_data_classification->data());
+ } else {
+ return flatbuffers::GetRoot<TranslateAnnotatorOptions>(
+ options_data_none->data());
+ }
+}
class TestingTranslateAnnotator : public TranslateAnnotator {
@@ -60,11 +73,12 @@ std::string GetModelPath() { return GetTestDataPath("annotator/test_data/"); }
class TranslateAnnotatorTest : public ::testing::Test {
protected:
- TranslateAnnotatorTest()
+ explicit TranslateAnnotatorTest(
+ ModeFlag enabled_modes = ModeFlag_CLASSIFICATION)
: INIT_UNILIB_FOR_TESTING(unilib_),
langid_model_(libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile(
GetModelPath() + "lang_id.smfb")),
- translate_annotator_(TestingTranslateAnnotatorOptions(),
+ translate_annotator_(TestingTranslateAnnotatorOptions(enabled_modes),
langid_model_.get(), &unilib_) {}
UniLib unilib_;
@@ -72,6 +86,11 @@ class TranslateAnnotatorTest : public ::testing::Test {
TestingTranslateAnnotator translate_annotator_;
};
+class TranslateAnnotatorForNoneTest : public TranslateAnnotatorTest {  // Fixture whose annotator options are requested with ModeFlag_NONE.
+ protected:
+ TranslateAnnotatorForNoneTest() : TranslateAnnotatorTest(ModeFlag_NONE) {}  // Forwards the mode to the base fixture.
+};
+
TEST_F(TranslateAnnotatorTest, WhenSpeaksEnglishGetsTranslateActionForCzech) {
ClassificationResult classification;
EXPECT_TRUE(translate_annotator_.ClassifyText(
@@ -110,6 +129,13 @@ TEST_F(TranslateAnnotatorTest, EntityDataIsSet) {
predictions->Get(1)->confidence_score());
}
+TEST_F(TranslateAnnotatorForNoneTest,
+ ClassifyTextDisabledClassificationReturnsFalse) {  // The fixture requests ModeFlag_NONE options, so ClassifyText is expected to return false.
+ ClassificationResult classification;
+ EXPECT_FALSE(translate_annotator_.ClassifyText(
+ UTF8ToUnicodeText("学校"), {0, 2}, "en", &classification));
+}
+
TEST_F(TranslateAnnotatorTest,
WhenSpeaksEnglishDoesntGetTranslateActionForEnglish) {
ClassificationResult classification;
diff --git a/native/annotator/types.h b/native/annotator/types.h
index ada301c..8485d44 100644
--- a/native/annotator/types.h
+++ b/native/annotator/types.h
@@ -65,6 +65,7 @@ struct CodepointSpan {
CodepointSpan(CodepointIndex start, CodepointIndex end)
: first(start), second(end) {}
+ CodepointSpan(const CodepointSpan& other) = default;
CodepointSpan& operator=(const CodepointSpan& other) = default;
bool operator==(const CodepointSpan& other) const {
@@ -439,6 +440,8 @@ struct ClassificationResult {
contact_nickname, contact_email_address, contact_phone_number,
contact_account_type, contact_account_name, contact_id,
contact_alternate_name;
+ int64 contact_recognition_source = 0;  // Zero-initialized: this patch adds no constructor initializer, so it would otherwise be indeterminate.
+ float contact_neural_match_score = 0.0f;  // Zero-initialized for the same reason. NOTE(review): confirm 0 is the intended "unset" value.
std::string app_name, app_package_name;
int64 numeric_value;
double numeric_double_value;
@@ -577,12 +580,18 @@ struct ClassificationOptions : public BaseOptions, public DatetimeOptions {
std::string user_familiar_language_tags;
// If true, trigger dictionary on words that are of beginner level.
bool trigger_dictionary_on_beginner_words = false;
+ // If true, generate *Add* contact intent for email/phone entity. Defaults to false; must be initialized because operator== reads it.
+ bool enable_add_contact_intent = false;
+ // If true, generate *Search* intent for named entities. Defaults to false; must be initialized because operator== reads it.
+ bool enable_search_intent = false;
bool operator==(const ClassificationOptions& other) const {
return this->user_familiar_language_tags ==
other.user_familiar_language_tags &&
this->trigger_dictionary_on_beginner_words ==
other.trigger_dictionary_on_beginner_words &&
+ this->enable_add_contact_intent == other.enable_add_contact_intent &&
+ this->enable_search_intent == other.enable_search_intent &&
BaseOptions::operator==(other) && DatetimeOptions::operator==(other);
}
};
diff --git a/native/annotator/vocab/vocab-annotator-impl.cc b/native/annotator/vocab/vocab-annotator-impl.cc
index 4b5cc73..b464f54 100644
--- a/native/annotator/vocab/vocab-annotator-impl.cc
+++ b/native/annotator/vocab/vocab-annotator-impl.cc
@@ -61,6 +61,9 @@ bool VocabAnnotator::Annotate(
const UnicodeText& context,
const std::vector<Locale> detected_text_language_tags,
bool trigger_on_beginner_words, std::vector<AnnotatedSpan>* results) const {
+ if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
+ return true;
+ }
std::vector<Token> tokens = feature_processor_.Tokenize(context);
for (const Token& token : tokens) {
ClassificationResult classification_result;
@@ -90,6 +93,9 @@ bool VocabAnnotator::ClassifyTextInternal(
const std::vector<Locale> detected_text_language_tags,
bool trigger_on_beginner_words, ClassificationResult* classification_result,
CodepointSpan* classified_span) const {
+ if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
+ return false;
+ }
if (vocab_level_table_ == nullptr) {
return false;
}