diff options
Diffstat (limited to 'native/annotator')
27 files changed, 563 insertions, 169 deletions
diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc index e0d4241..a8483f1 100644 --- a/native/annotator/annotator.cc +++ b/native/annotator/annotator.cc @@ -432,7 +432,8 @@ void Annotator::ValidateAndInitialize(const Model* model, const UniLib* unilib, datetime_parser_ = std::make_unique<GrammarDatetimeParser>( *analyzer_, *datetime_grounder_, /*target_classification_score=*/1.0, - /*priority_score=*/1.0); + /*priority_score=*/1.0, + model_->datetime_grammar_model()->enabled_modes()); } } else if (model_->datetime_model()) { datetime_parser_ = RegexDatetimeParser::Instance( @@ -604,6 +605,8 @@ bool Annotator::InitializeKnowledgeEngine( if (model_->triggering_options() != nullptr) { knowledge_engine->SetPriorityScore( model_->triggering_options()->knowledge_priority_score()); + knowledge_engine->SetEnabledModes( + model_->triggering_options()->knowledge_enabled_modes()); } knowledge_engine_ = std::move(knowledge_engine); return true; @@ -621,10 +624,21 @@ bool Annotator::InitializeContactEngine(const std::string& serialized_config) { return true; } +void Annotator::CleanUpContactEngine() { + if (contact_engine_ == nullptr) { + TC3_LOG(INFO) + << "Attempting to clean up contact engine that does not exist."; + return; + } + contact_engine_->CleanUp(); +} + bool Annotator::InitializeInstalledAppEngine( const std::string& serialized_config) { std::unique_ptr<InstalledAppEngine> installed_app_engine( - new InstalledAppEngine(selection_feature_processor_.get(), unilib_)); + new InstalledAppEngine( + selection_feature_processor_.get(), unilib_, + model_->triggering_options()->installed_app_enabled_modes())); if (!installed_app_engine->Initialize(serialized_config)) { TC3_LOG(ERROR) << "Failed to initialize the installed app engine."; return false; @@ -912,38 +926,40 @@ CodepointSpan Annotator::SuggestSelection( !knowledge_engine_ ->Chunk(context, options.annotation_usecase, options.location_context, Permissions(), - AnnotateMode::kEntityAnnotation, &candidates) + AnnotateMode::kEntityAnnotation, ModeFlag_SELECTION, + &candidates) .ok()) { TC3_LOG(ERROR) << "Knowledge suggest selection failed."; return original_click_indices; } if (contact_engine_ != nullptr && - !contact_engine_->Chunk(context_unicode, tokens, + !contact_engine_->Chunk(context_unicode, tokens, ModeFlag_SELECTION, &candidates.annotated_spans[0])) { TC3_LOG(ERROR) << "Contact suggest selection failed."; return original_click_indices; } if (installed_app_engine_ != nullptr && - !installed_app_engine_->Chunk(context_unicode, tokens, + !installed_app_engine_->Chunk(context_unicode, tokens, ModeFlag_SELECTION, &candidates.annotated_spans[0])) { TC3_LOG(ERROR) << "Installed app suggest selection failed."; return original_click_indices; } if (number_annotator_ != nullptr && !number_annotator_->FindAll(context_unicode, options.annotation_usecase, + ModeFlag_SELECTION, &candidates.annotated_spans[0])) { TC3_LOG(ERROR) << "Number annotator failed in suggest selection."; return original_click_indices; } if (duration_annotator_ != nullptr && - !duration_annotator_->FindAll(context_unicode, tokens, - options.annotation_usecase, - &candidates.annotated_spans[0])) { + !duration_annotator_->FindAll( + context_unicode, tokens, options.annotation_usecase, + ModeFlag_SELECTION, &candidates.annotated_spans[0])) { TC3_LOG(ERROR) << "Duration annotator failed in suggest selection."; return original_click_indices; } if (person_name_engine_ != nullptr && - !person_name_engine_->Chunk(context_unicode, tokens, + !person_name_engine_->Chunk(context_unicode, tokens, ModeFlag_SELECTION, &candidates.annotated_spans[0])) { TC3_LOG(ERROR) << "Person name suggest selection failed."; return original_click_indices; @@ -964,7 +980,9 @@ CodepointSpan Annotator::SuggestSelection( candidates.annotated_spans[0].push_back(pod_ner_suggested_span); } - if (experimental_annotator_ != nullptr) { + if (experimental_annotator_ != nullptr && + (model_->triggering_options()->experimental_enabled_modes() & + ModeFlag_SELECTION)) { candidates.annotated_spans[0].push_back( experimental_annotator_->SuggestSelection(context_unicode, click_indices)); @@ -1896,7 +1914,9 @@ std::vector<ClassificationResult> Annotator::ClassifyText( candidates.push_back({selection_indices, {vocab_annotator_result}}); } - if (experimental_annotator_) { + if (experimental_annotator_ && + (model_->triggering_options()->experimental_enabled_modes() & + ModeFlag_CLASSIFICATION)) { experimental_annotator_->ClassifyText(context_unicode, selection_indices, candidates); } @@ -2218,7 +2238,8 @@ Status Annotator::AnnotateSingleInput( const bool contact_annotations_enabled = !is_raw_usecase || is_entity_type_enabled(Collections::Contact()); if (contact_annotations_enabled && contact_engine_ && - !contact_engine_->Chunk(context_unicode, tokens, candidates)) { + !contact_engine_->Chunk(context_unicode, tokens, ModeFlag_ANNOTATION, + candidates)) { return Status(StatusCode::INTERNAL, "Couldn't run contact engine Chunk."); } @@ -2226,7 +2247,8 @@ Status Annotator::AnnotateSingleInput( const bool app_annotations_enabled = !is_raw_usecase || is_entity_type_enabled(Collections::App()); if (app_annotations_enabled && installed_app_engine_ && - !installed_app_engine_->Chunk(context_unicode, tokens, candidates)) { + !installed_app_engine_->Chunk(context_unicode, tokens, + ModeFlag_ANNOTATION, candidates)) { return Status(StatusCode::INTERNAL, "Couldn't run installed app engine Chunk."); } @@ -2237,7 +2259,7 @@ Status Annotator::AnnotateSingleInput( is_entity_type_enabled(Collections::Percentage())); if (number_annotations_enabled && number_annotator_ != nullptr && !number_annotator_->FindAll(context_unicode, options.annotation_usecase, - candidates)) { + ModeFlag_ANNOTATION, candidates)) { return Status(StatusCode::INTERNAL, "Couldn't run number annotator FindAll."); } @@ -2247,7 +2269,8 @@ Status Annotator::AnnotateSingleInput( !is_raw_usecase || is_entity_type_enabled(Collections::Duration()); if (duration_annotations_enabled && duration_annotator_ != nullptr && !duration_annotator_->FindAll(context_unicode, tokens, - options.annotation_usecase, candidates)) { + options.annotation_usecase, + ModeFlag_ANNOTATION, candidates)) { return Status(StatusCode::INTERNAL, "Couldn't run duration annotator FindAll."); } @@ -2256,7 +2279,8 @@ Status Annotator::AnnotateSingleInput( const bool person_annotations_enabled = !is_raw_usecase || is_entity_type_enabled(Collections::PersonName()); if (person_annotations_enabled && person_name_engine_ && - !person_name_engine_->Chunk(context_unicode, tokens, candidates)) { + !person_name_engine_->Chunk(context_unicode, tokens, ModeFlag_ANNOTATION, + candidates)) { return Status(StatusCode::INTERNAL, "Couldn't run person name engine Chunk."); } @@ -2290,6 +2314,8 @@ Status Annotator::AnnotateSingleInput( // Annotate with the experimental annotator. if (experimental_annotator_ != nullptr && + (model_->triggering_options()->experimental_enabled_modes() & + ModeFlag_ANNOTATION) && !experimental_annotator_->Annotate(context_unicode, candidates)) { return Status(StatusCode::INTERNAL, "Couldn't run experimental annotator."); } @@ -2376,14 +2402,21 @@ StatusOr<Annotations> Annotator::AnnotateStructuredInput( .relative_bounding_box_height = string_fragment.bounding_box_height}); } + const EnabledEntityTypes is_entity_type_enabled(options.entity_types); + const bool is_raw_usecase = + options.annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW; + + const bool knowledge_engine_annotations_enabled = + !is_raw_usecase || is_entity_type_enabled(Collections::Entity()); // KnowledgeEngine is special, because it supports annotation of multiple // fragments at once. - if (knowledge_engine_ && + if (knowledge_engine_annotations_enabled && knowledge_engine_ && !knowledge_engine_ ->ChunkMultipleSpans(text_to_annotate, fragment_metadata, options.annotation_usecase, options.location_context, options.permissions, - options.annotate_mode, &annotation_candidates) + options.annotate_mode, ModeFlag_ANNOTATION, + &annotation_candidates) .ok()) { return Status(StatusCode::INTERNAL, "Couldn't run knowledge engine Chunk."); } diff --git a/native/annotator/annotator.h b/native/annotator/annotator.h index d69fe32..5df8129 100644 --- a/native/annotator/annotator.h +++ b/native/annotator/annotator.h @@ -149,6 +149,9 @@ class Annotator { // Initializes the contact engine with the given config. bool InitializeContactEngine(const std::string& serialized_config); + // Cleans up the resources associated with the contact engine. + void CleanUpContactEngine(); + // Initializes the installed app engine with the given config. bool InitializeInstalledAppEngine(const std::string& serialized_config); diff --git a/native/annotator/annotator_jni.cc b/native/annotator/annotator_jni.cc index 6e7eeab..3d352c6 100644 --- a/native/annotator/annotator_jni.cc +++ b/native/annotator/annotator_jni.cc @@ -275,6 +275,7 @@ StatusOr<ScopedLocalRef<jobject>> ClassificationResultWithIntentsToJObject( device_locales, classification_result, options->reference_time_ms_utc, context, selection_indices, app_context, model_context->model()->entity_data_schema(), + options->enable_add_contact_intent, options->enable_search_intent, &remote_action_templates)) { return {Status::UNKNOWN}; } @@ -896,6 +897,9 @@ TC3_JNI_METHOD(void, TC3_ANNOTATOR_CLASS_NAME, nativeCloseAnnotator) (JNIEnv* env, jobject thiz, jlong ptr) { const AnnotatorJniContext* context = reinterpret_cast<AnnotatorJniContext*>(ptr); + if (context != nullptr && context->model()) { + context->model()->CleanUpContactEngine(); + } delete context; } diff --git a/native/annotator/annotator_jni_common.cc b/native/annotator/annotator_jni_common.cc index a6f636f..6ee4977 100644 --- a/native/annotator/annotator_jni_common.cc +++ b/native/annotator/annotator_jni_common.cc @@ -279,6 +279,23 @@ StatusOr<ClassificationOptions> FromJavaClassificationOptions( JniHelper::CallBooleanMethod(env, joptions, get_trigger_dictionary_on_beginner_words)); + // .getEnableAddContactIntent() + TC3_ASSIGN_OR_RETURN( + jmethodID get_enable_add_contact_intent, + JniHelper::GetMethodID(env, options_class.get(), + "getEnableAddContactIntent", "()Z")); + TC3_ASSIGN_OR_RETURN(classifier_options.enable_add_contact_intent, + JniHelper::CallBooleanMethod( + env, joptions, get_enable_add_contact_intent)); + + // .getEnableSearchIntent() + TC3_ASSIGN_OR_RETURN(jmethodID get_enable_search_intent, + JniHelper::GetMethodID(env, options_class.get(), + "getEnableSearchIntent", "()Z")); + TC3_ASSIGN_OR_RETURN( + classifier_options.enable_search_intent, + JniHelper::CallBooleanMethod(env, joptions, get_enable_search_intent)); + return classifier_options; } diff --git a/native/annotator/collections.h b/native/annotator/collections.h index 417b447..becdcdb 100644 --- a/native/annotator/collections.h +++ b/native/annotator/collections.h @@ -144,6 +144,36 @@ class Collections { *[]() { return new std::string("otp_code"); }(); return value; } + static const std::string& Art() { + static const std::string& value = + *[]() { return new std::string("art"); }(); + return value; + } + static const std::string& ConsumerGood() { + static const std::string& value = + *[]() { return new std::string("consumer_good"); }(); + return value; + } + static const std::string& Event() { + static const std::string& value = + *[]() { return new std::string("event"); }(); + return value; + } + static const std::string& Location() { + static const std::string& value = + *[]() { return new std::string("location"); }(); + return value; + } + static const std::string& Organization() { + static const std::string& value = + *[]() { return new std::string("organization"); }(); + return value; + } + static const std::string& Person() { + static const std::string& value = + *[]() { return new std::string("person"); }(); + return value; + } }; } // namespace libtextclassifier3 diff --git a/native/annotator/contact/contact-engine-dummy.h b/native/annotator/contact/contact-engine-dummy.h index fe60203..211553c 100644 --- a/native/annotator/contact/contact-engine-dummy.h +++ b/native/annotator/contact/contact-engine-dummy.h @@ -47,13 +47,15 @@ class ContactEngine { } bool Chunk(const UnicodeText& context_unicode, - const std::vector<Token>& tokens, + const std::vector<Token>& tokens, ModeFlag mode, std::vector<AnnotatedSpan>* result) const { return true; } void AddContactMetadataToKnowledgeClassificationResult( ClassificationResult* classification_result) const {} + + void CleanUp() const {} }; } // namespace libtextclassifier3 diff --git a/native/annotator/datetime/grammar-parser.cc b/native/annotator/datetime/grammar-parser.cc index 6d51c19..c49a01a 100644 --- a/native/annotator/datetime/grammar-parser.cc +++ b/native/annotator/datetime/grammar-parser.cc @@ -20,6 +20,7 @@ #include <unordered_set> #include "annotator/datetime/datetime-grounder.h" +#include "annotator/model_generated.h" #include "annotator/types.h" #include "utils/grammar/analyzer.h" #include "utils/grammar/evaluated-derivation.h" @@ -33,11 +34,13 @@ namespace libtextclassifier3 { GrammarDatetimeParser::GrammarDatetimeParser( const grammar::Analyzer& analyzer, const DatetimeGrounder& datetime_grounder, - const float target_classification_score, const float priority_score) + const float target_classification_score, const float priority_score, + ModeFlag enabled_modes) : analyzer_(analyzer), datetime_grounder_(datetime_grounder), target_classification_score_(target_classification_score), - priority_score_(priority_score) {} + priority_score_(priority_score), + enabled_modes_(enabled_modes) {} StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse( const std::string& input, const int64 reference_time_ms_utc, @@ -54,6 +57,10 @@ StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse( const std::string& reference_timezone, const LocaleList& locale_list, ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end) const { + if (!(enabled_modes_ & mode)) { + return std::vector<DatetimeParseResultSpan>(); + } + std::vector<DatetimeParseResultSpan> results; UnsafeArena arena(/*block_size=*/16 << 10); std::vector<Locale> locales = locale_list.GetLocales(); diff --git a/native/annotator/datetime/grammar-parser.h b/native/annotator/datetime/grammar-parser.h index 6ff4b46..35da843 100644 --- a/native/annotator/datetime/grammar-parser.h +++ b/native/annotator/datetime/grammar-parser.h @@ -22,6 +22,7 @@ #include "annotator/datetime/datetime-grounder.h" #include "annotator/datetime/parser.h" +#include "annotator/model_generated.h" #include "annotator/types.h" #include "utils/base/statusor.h" #include "utils/grammar/analyzer.h" @@ -37,7 +38,8 @@ class GrammarDatetimeParser : public DatetimeParser { explicit GrammarDatetimeParser(const grammar::Analyzer& analyzer, const DatetimeGrounder& datetime_grounder, const float target_classification_score, - const float priority_score); + const float priority_score, + ModeFlag enabled_modes); // Parses the dates in 'input' and fills result. Makes sure that the results // do not overlap. @@ -61,6 +63,7 @@ class GrammarDatetimeParser : public DatetimeParser { const DatetimeGrounder& datetime_grounder_; const float target_classification_score_; const float priority_score_; + const ModeFlag enabled_modes_; }; } // namespace libtextclassifier3 diff --git a/native/annotator/datetime/grammar-parser_test.cc b/native/annotator/datetime/grammar-parser_test.cc index cf2dffd..b8a270d 100644 --- a/native/annotator/datetime/grammar-parser_test.cc +++ b/native/annotator/datetime/grammar-parser_test.cc @@ -22,6 +22,7 @@ #include "annotator/datetime/datetime-grounder.h" #include "annotator/datetime/testing/base-parser-test.h" #include "annotator/datetime/testing/datetime-component-builder.h" +#include "annotator/model_generated.h" #include "utils/grammar/analyzer.h" #include "utils/jvm-test-utils.h" #include "utils/test-data-test-utils.h" @@ -42,7 +43,15 @@ std::string ReadFile(const std::string& file_name) { class GrammarDatetimeParserTest : public DateTimeParserTest { public: - void SetUp() override { + void SetUp() override { ResetParser(ModeFlag_ALL); } + + // Exposes the date time parser for tests and evaluations. + const DatetimeParser* DatetimeParserForTests() const override { + return parser_.get(); + } + + protected: + void ResetParser(ModeFlag enabled_modes) { grammar_buffer_ = ReadFile(GetModelPath() + "datetime.fb"); unilib_ = CreateUniLibForTesting(); calendarlib_ = CreateCalendarLibForTesting(); @@ -51,12 +60,8 @@ class GrammarDatetimeParserTest : public DateTimeParserTest { datetime_grounder_ = std::make_unique<DatetimeGrounder>(calendarlib_.get()); parser_.reset(new GrammarDatetimeParser(*analyzer_, *datetime_grounder_, /*target_classification_score=*/1.0, - /*priority_score=*/1.0)); - } - - // Exposes the date time parser for tests and evaluations. - const DatetimeParser* DatetimeParserForTests() const override { - return parser_.get(); + /*priority_score=*/1.0, + enabled_modes)); } private: @@ -486,6 +491,13 @@ TEST_F(GrammarDatetimeParserTest, Parse) { .Build()})); } +TEST_F(GrammarDatetimeParserTest, NotEnabledModeHasNoResult) { + ResetParser(ModeFlag_SELECTION); + // `DateTimeParserTest` implementation parses the input under the ANNOTATION + // mode. + EXPECT_TRUE(HasNoResult("{January 1, 1988}")); +} + TEST_F(GrammarDatetimeParserTest, DateValidation) { EXPECT_TRUE(ParsesCorrectly( "{01/02/2020}", 1577919600000, GRANULARITY_DAY, diff --git a/native/annotator/duration/duration.cc b/native/annotator/duration/duration.cc index c59b8e0..df4c60d 100644 --- a/native/annotator/duration/duration.cc +++ b/native/annotator/duration/duration.cc @@ -20,6 +20,7 @@ #include <cstdlib> #include "annotator/collections.h" +#include "annotator/model_generated.h" #include "annotator/types.h" #include "utils/base/logging.h" #include "utils/base/macros.h" @@ -125,8 +126,10 @@ bool DurationAnnotator::ClassifyText( const UnicodeText& context, CodepointSpan selection_indices, AnnotationUsecase annotation_usecase, ClassificationResult* classification_result) const { - if (!options_->enabled() || ((options_->enabled_annotation_usecases() & - (1 << annotation_usecase))) == 0) { + if (!options_->enabled() || + ((options_->enabled_annotation_usecases() & (1 << annotation_usecase))) == + 0 || + !(options_->enabled_modes() & ModeFlag_CLASSIFICATION)) { return false; } @@ -151,9 +154,12 @@ bool DurationAnnotator::ClassifyText( bool DurationAnnotator::FindAll(const UnicodeText& context, const std::vector<Token>& tokens, AnnotationUsecase annotation_usecase, + ModeFlag mode, std::vector<AnnotatedSpan>* results) const { - if (!options_->enabled() || ((options_->enabled_annotation_usecases() & - (1 << annotation_usecase))) == 0) { + if (!options_->enabled() || + ((options_->enabled_annotation_usecases() & (1 << annotation_usecase))) == + 0 || + !(options_->enabled_modes() & mode)) { return true; } diff --git a/native/annotator/duration/duration.h b/native/annotator/duration/duration.h index 1a42ac3..e99542c 100644 --- a/native/annotator/duration/duration.h +++ b/native/annotator/duration/duration.h @@ -87,7 +87,7 @@ class DurationAnnotator { // Finds all duration instances in the input text. bool FindAll(const UnicodeText& context, const std::vector<Token>& tokens, - AnnotationUsecase annotation_usecase, + AnnotationUsecase annotation_usecase, ModeFlag mode, std::vector<AnnotatedSpan>* results) const; private: diff --git a/native/annotator/duration/duration_test.cc b/native/annotator/duration/duration_test.cc index 7c07a72..f726058 100644 --- a/native/annotator/duration/duration_test.cc +++ b/native/annotator/duration/duration_test.cc @@ -16,6 +16,7 @@ #include "annotator/duration/duration.h" +#include <cstddef> #include <string> #include <vector> @@ -37,41 +38,61 @@ using testing::ElementsAre; using testing::Field; using testing::IsEmpty; -const DurationAnnotatorOptions* TestingDurationAnnotatorOptions() { - static const flatbuffers::DetachedBuffer* options_data = []() { - DurationAnnotatorOptionsT options; - options.enabled = true; +namespace { +const flatbuffers::DetachedBuffer* CreateOptionsData(ModeFlag enabled_modes) { + DurationAnnotatorOptionsT options; + options.enabled = true; + options.enabled_modes = enabled_modes; - options.week_expressions.push_back("week"); - options.week_expressions.push_back("weeks"); + options.week_expressions.push_back("week"); + options.week_expressions.push_back("weeks"); - options.day_expressions.push_back("day"); - options.day_expressions.push_back("days"); + options.day_expressions.push_back("day"); + options.day_expressions.push_back("days"); - options.hour_expressions.push_back("hour"); - options.hour_expressions.push_back("hours"); + options.hour_expressions.push_back("hour"); + options.hour_expressions.push_back("hours"); - options.minute_expressions.push_back("minute"); - options.minute_expressions.push_back("minutes"); + options.minute_expressions.push_back("minute"); + options.minute_expressions.push_back("minutes"); - options.second_expressions.push_back("second"); - options.second_expressions.push_back("seconds"); + options.second_expressions.push_back("second"); + options.second_expressions.push_back("seconds"); - options.filler_expressions.push_back("and"); - options.filler_expressions.push_back("a"); - options.filler_expressions.push_back("an"); - options.filler_expressions.push_back("one"); + options.filler_expressions.push_back("and"); + options.filler_expressions.push_back("a"); + options.filler_expressions.push_back("an"); + options.filler_expressions.push_back("one"); - options.half_expressions.push_back("half"); + options.half_expressions.push_back("half"); - options.sub_token_separator_codepoints.push_back('-'); + options.sub_token_separator_codepoints.push_back('-'); - flatbuffers::FlatBufferBuilder builder; - builder.Finish(DurationAnnotatorOptions::Pack(builder, &options)); - return new flatbuffers::DetachedBuffer(builder.Release()); - }(); + flatbuffers::FlatBufferBuilder builder; + builder.Finish(DurationAnnotatorOptions::Pack(builder, &options)); + return new flatbuffers::DetachedBuffer(builder.Release()); +} +} // namespace - return flatbuffers::GetRoot<DurationAnnotatorOptions>(options_data->data()); +const DurationAnnotatorOptions* TestingDurationAnnotatorOptions( + ModeFlag enabled_modes) { + static const flatbuffers::DetachedBuffer* options_data_all = + CreateOptionsData(ModeFlag_ALL); + static const flatbuffers::DetachedBuffer* options_data_selection = + CreateOptionsData(ModeFlag_SELECTION); + static const flatbuffers::DetachedBuffer* options_data_no_selection = + CreateOptionsData(ModeFlag_ANNOTATION_AND_CLASSIFICATION); + + if (enabled_modes == ModeFlag_SELECTION) { + return flatbuffers::GetRoot<DurationAnnotatorOptions>( + options_data_selection->data()); + } else if (enabled_modes == ModeFlag_ANNOTATION_AND_CLASSIFICATION) { + return flatbuffers::GetRoot<DurationAnnotatorOptions>( + options_data_no_selection->data()); + } else { + return flatbuffers::GetRoot<DurationAnnotatorOptions>( + options_data_all->data()); + } } std::unique_ptr<FeatureProcessor> BuildFeatureProcessor(const UniLib* unilib) { @@ -103,10 +124,10 @@ std::unique_ptr<FeatureProcessor> BuildFeatureProcessor(const UniLib* unilib) { class DurationAnnotatorTest : public ::testing::Test { protected: - DurationAnnotatorTest() + explicit DurationAnnotatorTest(ModeFlag enabled_modes = ModeFlag_ALL) : INIT_UNILIB_FOR_TESTING(unilib_), feature_processor_(BuildFeatureProcessor(&unilib_)), - duration_annotator_(TestingDurationAnnotatorOptions(), + duration_annotator_(TestingDurationAnnotatorOptions(enabled_modes), feature_processor_.get(), &unilib_) {} std::vector<Token> Tokenize(const UnicodeText& text) { @@ -118,6 +139,19 @@ class DurationAnnotatorTest : public ::testing::Test { DurationAnnotator duration_annotator_; }; +class DurationAnnotatorForAnnotationAndClassificationTest + : public DurationAnnotatorTest { + protected: + DurationAnnotatorForAnnotationAndClassificationTest() + : DurationAnnotatorTest(ModeFlag_ANNOTATION_AND_CLASSIFICATION) {} +}; + +class DurationAnnotatorForSelectionTest : public DurationAnnotatorTest { + protected: + DurationAnnotatorForSelectionTest() + : DurationAnnotatorTest(ModeFlag_SELECTION) {} +}; + TEST_F(DurationAnnotatorTest, ClassifiesSimpleDuration) { ClassificationResult classification; EXPECT_TRUE(duration_annotator_.ClassifyText( @@ -129,6 +163,14 @@ TEST_F(DurationAnnotatorTest, ClassifiesSimpleDuration) { Field(&ClassificationResult::duration_ms, 15 * 60 * 1000))); } +TEST_F(DurationAnnotatorForSelectionTest, + ClassifyTextDisabledClassificationReturnsFalse) { + ClassificationResult classification; + EXPECT_FALSE(duration_annotator_.ClassifyText( + UTF8ToUnicodeText("Wake me up in 15 minutes ok?"), {14, 24}, + AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification)); +} + TEST_F(DurationAnnotatorTest, ClassifiesWhenTokensDontAlignWithSelection) { ClassificationResult classification; EXPECT_TRUE(duration_annotator_.ClassifyText( @@ -152,7 +194,8 @@ TEST_F(DurationAnnotatorTest, FindsSimpleDuration) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_SELECTION, &result)); EXPECT_THAT( result, @@ -165,13 +208,26 @@ TEST_F(DurationAnnotatorTest, FindsSimpleDuration) { 15 * 60 * 1000))))))); } +TEST_F(DurationAnnotatorForAnnotationAndClassificationTest, + FindsAllDisabledModeReturnsNoResults) { + const UnicodeText text = UTF8ToUnicodeText("Wake me up in 15 minutes ok?"); + std::vector<Token> tokens = Tokenize(text); + std::vector<AnnotatedSpan> result; + EXPECT_TRUE(duration_annotator_.FindAll( + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_SELECTION, &result)); + + EXPECT_THAT(result, IsEmpty()); +} + TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpression) { const UnicodeText text = UTF8ToUnicodeText("Set a timer for 3 and half minutes ok?"); std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -190,7 +246,8 @@ TEST_F(DurationAnnotatorTest, FindsComposedDuration) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_SELECTION, &result)); EXPECT_THAT( result, @@ -209,7 +266,8 @@ TEST_F(DurationAnnotatorTest, AllUnitsAreCovered) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -228,7 +286,8 @@ TEST_F(DurationAnnotatorTest, FindsHalfAnHour) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -247,7 +306,8 @@ TEST_F(DurationAnnotatorTest, FindsWhenHalfIsAfterGranularitySpecification) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_SELECTION, &result)); EXPECT_THAT( result, @@ -266,7 +326,8 @@ TEST_F(DurationAnnotatorTest, FindsAnHourAndAHalf) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -286,7 +347,8 @@ TEST_F(DurationAnnotatorTest, std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -305,7 +367,8 @@ TEST_F(DurationAnnotatorTest, DoesNotGreedilyTakeFillerWords) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -323,7 +386,8 @@ TEST_F(DurationAnnotatorTest, DoesNotCrashWhenJustHalfIsSaid) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); ASSERT_EQ(result.size(), 0); } @@ -334,7 +398,8 @@ TEST_F(DurationAnnotatorTest, StripsPunctuationFromTokens) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -352,7 +417,8 @@ TEST_F(DurationAnnotatorTest, FindsCorrectlyWithCombinedQuantityUnitToken) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -383,7 +449,8 @@ TEST_F(DurationAnnotatorTest, FindsSimpleDurationIgnoringCase) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -402,7 +469,8 @@ TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpressionIgnoringCase) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -422,7 +490,8 @@ TEST_F(DurationAnnotatorTest, std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -440,7 +509,8 @@ TEST_F(DurationAnnotatorTest, FindsDurationWithDanglingQuantity) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -458,7 +528,8 @@ TEST_F(DurationAnnotatorTest, FindsDurationWithDanglingQuantityNotSupported) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -475,7 +546,8 @@ TEST_F(DurationAnnotatorTest, FindsDurationWithDecimalQuantity) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -540,7 +612,8 @@ TEST_F(JapaneseDurationAnnotatorTest, FindsDuration) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -558,7 +631,8 @@ TEST_F(JapaneseDurationAnnotatorTest, FindsDurationWithHalfExpression) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -576,7 +650,8 @@ TEST_F(JapaneseDurationAnnotatorTest, IgnoresDurationWithoutQuantity) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, IsEmpty()); } @@ -586,7 +661,8 @@ TEST_F(JapaneseDurationAnnotatorTest, FindsDurationWithDanglingQuantity) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_SELECTION, &result)); EXPECT_THAT( result, diff --git a/native/annotator/installed_app/installed-app-engine-dummy.h b/native/annotator/installed_app/installed-app-engine-dummy.h index 2f2b62f..80f5a5c 100644 --- a/native/annotator/installed_app/installed-app-engine-dummy.h +++ b/native/annotator/installed_app/installed-app-engine-dummy.h @@ -32,7 +32,7 @@ namespace libtextclassifier3 { class InstalledAppEngine { public: explicit InstalledAppEngine(const FeatureProcessor* feature_processor, - const UniLib* unilib) {} + const UniLib* unilib, ModeFlag enabled_modes) {} bool Initialize(const std::string& serialized_config) { TC3_LOG(ERROR) << "No installed app engine to initialize."; @@ -45,7 +45,7 @@ class InstalledAppEngine { } bool Chunk(const UnicodeText& context_unicode, - const std::vector<Token>& tokens, + const std::vector<Token>& tokens, ModeFlag mode, std::vector<AnnotatedSpan>* result) const { return true; } diff --git a/native/annotator/knowledge/knowledge-engine-dummy.h b/native/annotator/knowledge/knowledge-engine-dummy.h index 34fa490..949018c 100644 --- a/native/annotator/knowledge/knowledge-engine-dummy.h +++ b/native/annotator/knowledge/knowledge-engine-dummy.h @@ -37,6 +37,8 @@ class KnowledgeEngine { void SetPriorityScore(float priority_score) {} + void SetEnabledModes(ModeFlag enabled_modes) {} + Status ClassifyText(const std::string& text, CodepointSpan selection_indices, AnnotationUsecase annotation_usecase, const Optional<LocationContext>& location_context, @@ -48,7 +50,7 @@ class KnowledgeEngine { Status Chunk(const std::string& text, AnnotationUsecase annotation_usecase, const Optional<LocationContext>& location_context, const Permissions& permissions, const AnnotateMode annotate_mode, - Annotations* result) const { + ModeFlag mode, Annotations* result) const { return Status::OK; } @@ -58,7 +60,7 @@ class KnowledgeEngine { AnnotationUsecase annotation_usecase, const Optional<LocationContext>& location_context, const Permissions& permissions, const AnnotateMode annotate_mode, - Annotations* results) const { + ModeFlag mode, Annotations* results) const { return Status::OK; } diff --git a/native/annotator/model.fbs b/native/annotator/model.fbs index 57187f5..eeb4101 100644 --- a/native/annotator/model.fbs +++ b/native/annotator/model.fbs @@ -415,6 +415,7 @@ table GrammarModel { // The grammar rules. rules:grammar.RulesSet; + // Deprecated. Used only for the old implementation of the grammar model. rule_classification_result:[GrammarModel_.RuleClassificationResult]; // Number of tokens in the context to use for classification and text @@ -432,6 +433,10 @@ table GrammarModel { // The priority score used for conflict resolution with the other models. priority_score:float = 1; + + // Global enabled modes. Use this instead of + // `rule_classification_result.enabled_modes`. + enabled_modes:ModeFlag = ALL; } namespace libtextclassifier3.MoneyParsingOptions_; @@ -486,6 +491,15 @@ table ModelTriggeringOptions { // map. Key: collection type e.g. "address", "phone"..., Value: float number. // NOTE: The entries here need to be sorted since we use LookupByKey. collection_to_priority:[ModelTriggeringOptions_.CollectionToPriorityEntry]; + + // Enabled modes for the knowledge engine model. + knowledge_enabled_modes:ModeFlag = ALL; + + // Enabled modes for the experimental model. + experimental_enabled_modes:ModeFlag = ALL; + + // Enabled modes for the installed app model. + installed_app_enabled_modes:ModeFlag = ALL; } // Options controlling the output of the classifier. @@ -894,6 +908,9 @@ table ContactAnnotatorOptions { // For each language there is a customized list of supported declensions. language:string (shared); + + // Enabled modes. + enabled_modes:ModeFlag = ALL; } namespace libtextclassifier3.TranslateAnnotatorOptions_; @@ -927,6 +944,9 @@ table TranslateAnnotatorOptions { algorithm:TranslateAnnotatorOptions_.Algorithm; backoff_options:TranslateAnnotatorOptions_.BackoffOptions; + + // Enabled modes. + enabled_modes:ModeFlag = CLASSIFICATION; } namespace libtextclassifier3.PodNerModel_; @@ -1012,6 +1032,9 @@ table PodNerModel { min_number_of_tokens:int = 1; min_number_of_wordpieces:int = 1; + + // Enabled modes. + enabled_modes:ModeFlag = ALL; } namespace libtextclassifier3; @@ -1043,6 +1066,9 @@ table VocabModel { // Priority score used for conflict resolution with the other models. priority_score:float = 0; + + // Enabled modes. + enabled_modes:ModeFlag = ANNOTATION_AND_CLASSIFICATION; } root_type libtextclassifier3.Model; diff --git a/native/annotator/number/number.cc b/native/annotator/number/number.cc index 3be6ad8..14fc24e 100644 --- a/native/annotator/number/number.cc +++ b/native/annotator/number/number.cc @@ -21,6 +21,7 @@ #include <string> #include "annotator/collections.h" +#include "annotator/model_generated.h" #include "annotator/types.h" #include "utils/base/logging.h" #include "utils/strings/split.h" @@ -38,7 +39,8 @@ bool NumberAnnotator::ClassifyText( context, selection_indices.first, selection_indices.second); std::vector<AnnotatedSpan> results; - if (!FindAll(substring_selected, annotation_usecase, &results)) { + if (!FindAll(substring_selected, annotation_usecase, ModeFlag_CLASSIFICATION, + &results)) { return false; } @@ -216,8 +218,9 @@ bool NumberAnnotator::TryParseNumber(const UnicodeText& token_text, bool NumberAnnotator::FindAll(const UnicodeText& context, AnnotationUsecase annotation_usecase, + ModeFlag mode, std::vector<AnnotatedSpan>* result) const { - if (!options_->enabled()) { + if (!options_->enabled() || !(options_->enabled_modes() & mode)) { return true; } diff --git a/native/annotator/number/number.h b/native/annotator/number/number.h index d83bea0..dcc2d48 100644 --- a/native/annotator/number/number.h +++ b/native/annotator/number/number.h @@ -58,7 +58,7 @@ class NumberAnnotator { // Finds all number instances in the input text. Returns true in any case. bool FindAll(const UnicodeText& context_unicode, - AnnotationUsecase annotation_usecase, + AnnotationUsecase annotation_usecase, ModeFlag mode, std::vector<AnnotatedSpan>* result) const; private: diff --git a/native/annotator/number/number_test-include.cc b/native/annotator/number/number_test-include.cc index f47933f..98140f4 100644 --- a/native/annotator/number/number_test-include.cc +++ b/native/annotator/number/number_test-include.cc @@ -16,6 +16,7 @@ #include "annotator/number/number_test-include.h" +#include <set> #include <string> #include <vector> @@ -34,37 +35,57 @@ namespace test_internal { using ::testing::AllOf; using ::testing::ElementsAre; using ::testing::Field; +using ::testing::IsEmpty; using ::testing::Matcher; using ::testing::UnorderedElementsAre; +namespace { +const flatbuffers::DetachedBuffer* CreateOptionsData(ModeFlag enabled_modes) { + NumberAnnotatorOptionsT options; + options.enabled = true; + options.priority_score = -10.0; + options.float_number_priority_score = 1.0; + options.enabled_annotation_usecases = + 1 << AnnotationUsecase_ANNOTATION_USECASE_RAW; + options.max_number_of_digits = 20; + options.enabled_modes = enabled_modes; + + options.percentage_priority_score = 1.0; + options.percentage_annotation_usecases = + (1 << AnnotationUsecase_ANNOTATION_USECASE_RAW) + + (1 << AnnotationUsecase_ANNOTATION_USECASE_SMART); + std::set<std::string> percent_suffixes( + {"パーセント", "percent", "pércént", "pc", "pct", "%", "٪", "﹪", "%"}); + for (const std::string& string_value : percent_suffixes) { + options.percentage_pieces_string.append(string_value); + options.percentage_pieces_string.push_back('\0'); + } + + flatbuffers::FlatBufferBuilder builder; + builder.Finish(NumberAnnotatorOptions::Pack(builder, &options)); + return new flatbuffers::DetachedBuffer(builder.Release()); +} +} // namespace + const NumberAnnotatorOptions* -NumberAnnotatorTest::TestingNumberAnnotatorOptions() { - static const flatbuffers::DetachedBuffer* options_data = []() { - NumberAnnotatorOptionsT options; - options.enabled = true; - options.priority_score = -10.0; - options.float_number_priority_score = 1.0; - options.enabled_annotation_usecases = - 1 << AnnotationUsecase_ANNOTATION_USECASE_RAW; - options.max_number_of_digits = 20; - - options.percentage_priority_score = 1.0; - options.percentage_annotation_usecases = - (1 << AnnotationUsecase_ANNOTATION_USECASE_RAW) + - (1 << AnnotationUsecase_ANNOTATION_USECASE_SMART); - std::set<std::string> percent_suffixes({"パーセント", "percent", "pércént", - "pc", "pct", "%", "٪", "﹪", "%"}); - for (const std::string& string_value : percent_suffixes) { - options.percentage_pieces_string.append(string_value); - options.percentage_pieces_string.push_back('\0'); - } - - flatbuffers::FlatBufferBuilder builder; - builder.Finish(NumberAnnotatorOptions::Pack(builder, &options)); - return new flatbuffers::DetachedBuffer(builder.Release()); - }(); - - return flatbuffers::GetRoot<NumberAnnotatorOptions>(options_data->data()); +NumberAnnotatorTest::TestingNumberAnnotatorOptions(ModeFlag enabled_modes) { + static const flatbuffers::DetachedBuffer* options_data_selection = + CreateOptionsData(ModeFlag_SELECTION); + static const flatbuffers::DetachedBuffer* options_data_no_selection = + CreateOptionsData(ModeFlag_ANNOTATION_AND_CLASSIFICATION); + static const flatbuffers::DetachedBuffer* options_data_all = + CreateOptionsData(ModeFlag_ALL); + + if (enabled_modes == ModeFlag_SELECTION) { + return flatbuffers::GetRoot<NumberAnnotatorOptions>( + options_data_selection->data()); + } else if (enabled_modes == ModeFlag_ANNOTATION_AND_CLASSIFICATION) { + return flatbuffers::GetRoot<NumberAnnotatorOptions>( + options_data_no_selection->data()); + } else { + return flatbuffers::GetRoot<NumberAnnotatorOptions>( + options_data_all->data()); + } } MATCHER_P(IsCorrectCollection, collection, "collection is " + collection) { @@ -124,6 +145,14 @@ TEST_F(NumberAnnotatorTest, ClassifiesAndParsesNumberCorrectly) { EXPECT_FLOAT_EQ(classification_result.numeric_double_value, 12345); } +TEST_F(NumberAnnotatorForSelectionTest, + ClassifyTextDisabledClassificationReturnsFalse) { + ClassificationResult classification_result; + EXPECT_FALSE(number_annotator_.ClassifyText( + UTF8ToUnicodeText("... 12345 ..."), {4, 9}, + AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result)); +} + TEST_F(NumberAnnotatorTest, ClassifiesAndParsesNumberAsFloatCorrectly) { ClassificationResult classification_result; EXPECT_TRUE(number_annotator_.ClassifyText( @@ -167,7 +196,7 @@ TEST_F(NumberAnnotatorTest, FindsAllIntegerAndFloatNumbersInText) { EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("how much is 2 plus 5 divided by 7% minus 3.14 " "what about 68.9# or 68.9#?"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( @@ -268,7 +297,8 @@ TEST_F(NumberAnnotatorTest, ClassifiesNonAsciiJaPercentageCorrectSuffix) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("明日の降水確率は10パーセント 音量を12にセット"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_CLASSIFICATION, + &result)); EXPECT_THAT(result, UnorderedElementsAre( IsAnnotatedSpan(CodepointSpan(8, 10), "number", @@ -285,7 +315,7 @@ TEST_F(NumberAnnotatorTest, FindsAllNumbersInText) { EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("... 12345 ... 9 is my number and 27% or 68# #38 #39 " "but not $99."), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -307,12 +337,23 @@ TEST_F(NumberAnnotatorTest, FindsAllNumbersInText) { /*int_value=*/39, /*double_value=*/39.0))); } +TEST_F(NumberAnnotatorForAnnotationAndClassificationTest, + FindsAllDisabledModeReturnsNoResults) { + std::vector<AnnotatedSpan> result; + EXPECT_TRUE(number_annotator_.FindAll( + UTF8ToUnicodeText("... 12345 ... 9 is my number and 27% or 68# #38 #39 " + "but not $99."), + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_SELECTION, &result)); + + EXPECT_THAT(result, IsEmpty()); +} + TEST_F(NumberAnnotatorTest, FindsNoNumberInText) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("... 12345a ... 12345..12345 and 123a45 are not valid. " "And -#5% is also bad."), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_SELECTION, &result)); ASSERT_EQ(result.size(), 0); } @@ -323,7 +364,8 @@ TEST_F(NumberAnnotatorTest, FindsNumberWithPunctuation) { EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText( "It's 12, 13, 14! Or 15??? For sure 16: 17; 18. and -19"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_CLASSIFICATION, + &result)); EXPECT_THAT(result, UnorderedElementsAre( @@ -348,7 +390,7 @@ TEST_F(NumberAnnotatorTest, FindsFloatNumberWithPunctuation) { EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("It's 12.123, 13.45, 14.54321! Or 15.1? Maybe 16.33: " "17.21; but for sure 18.90."), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( @@ -379,7 +421,7 @@ TEST_F(NumberAnnotatorTest, HandlesNumbersAtBeginning) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("-5"), AnnotationUsecase_ANNOTATION_USECASE_RAW, - &result)); + ModeFlag_SELECTION, &result)); EXPECT_THAT(result, UnorderedElementsAre(IsAnnotatedSpan( CodepointSpan(0, 2), "number", @@ -390,7 +432,7 @@ TEST_F(NumberAnnotatorTest, HandlesNegativeNumbers) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("Number -5 and -5% and not number --5%"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( @@ -408,7 +450,7 @@ TEST_F(NumberAnnotatorTest, FindGoodPercentageContexts) { EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText( "5 percent, 10 pct, 25 pc and 17%, -5 percent, 10% are percentages"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_SELECTION, &result)); EXPECT_THAT(result, UnorderedElementsAre( @@ -448,7 +490,7 @@ TEST_F(NumberAnnotatorTest, FindSinglePercentageInContext) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("5%"), AnnotationUsecase_ANNOTATION_USECASE_RAW, - &result)); + ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( IsAnnotatedSpan(CodepointSpan(0, 1), "number", @@ -463,7 +505,7 @@ TEST_F(NumberAnnotatorTest, IgnoreBadPercentageContexts) { // A valid number is followed by only one punctuation element. EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("10, pct, 25 prc, 5#: percentage are not percentages"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( @@ -478,7 +520,7 @@ TEST_F(NumberAnnotatorTest, IgnoreBadPercentagePunctuationContexts) { EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText( "#!24% or :?33 percent are not valid percentages, nor numbers."), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_TRUE(result.empty()); } @@ -488,7 +530,7 @@ TEST_F(NumberAnnotatorTest, FindPercentageInNonAsciiContext) { EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText( "At the café 10% or 25 percent of people are nice. Only 10%!"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( @@ -748,7 +790,7 @@ TEST_F(NumberAnnotatorTest, WhenSuffixWithoutNumberDoesNotParseIt) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("... % ..."), AnnotationUsecase_ANNOTATION_USECASE_RAW, - &result)); + ModeFlag_ANNOTATION, &result)); ASSERT_EQ(result.size(), 0); } @@ -757,7 +799,7 @@ TEST_F(NumberAnnotatorTest, WhenPrefixWithoutNumberDoesNotParseIt) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("... $ ..."), AnnotationUsecase_ANNOTATION_USECASE_RAW, - &result)); + ModeFlag_ANNOTATION, &result)); ASSERT_EQ(result.size(), 0); } @@ -766,7 +808,7 @@ TEST_F(NumberAnnotatorTest, WhenPrefixAndSuffixWithoutNumberDoesNotParseIt) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("... $% ..."), AnnotationUsecase_ANNOTATION_USECASE_RAW, - &result)); + ModeFlag_ANNOTATION, &result)); ASSERT_EQ(result.size(), 0); } @@ -786,7 +828,7 @@ TEST_F(NumberAnnotatorTest, ForNumberAnnotationsSetsScoreAndPriorityScore) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("Come at 9 or 10 ok?"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( @@ -811,7 +853,7 @@ TEST_F(NumberAnnotatorTest, std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("Results are between 12.5 and 13.5, right?"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( IsAnnotatedSpan(CodepointSpan(20, 24), "number", @@ -845,7 +887,7 @@ TEST_F(NumberAnnotatorTest, ForPercentageAnnotationsSetsScoreAndPriorityScore) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("Results are between 9% and 10 percent."), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( IsAnnotatedSpan(CodepointSpan(20, 21), "number", @@ -887,7 +929,8 @@ TEST_F(NumberAnnotatorTest, NumberDisabledPercentageEnabledForSmartUsecase) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("Accuracy for experiment 3 is 9%."), - AnnotationUsecase_ANNOTATION_USECASE_SMART, &result)); + AnnotationUsecase_ANNOTATION_USECASE_SMART, ModeFlag_ANNOTATION, + &result)); EXPECT_THAT(result, UnorderedElementsAre( IsAnnotatedSpan(CodepointSpan(29, 31), "percentage", /*int_value=*/9, /*double_value=*/9.0, @@ -898,7 +941,7 @@ TEST_F(NumberAnnotatorTest, MathOperatorsNotAnnotatedAsNumbersFindAll) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("how much is 2 + 2 or 5 - 96 * 89"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( @@ -928,7 +971,7 @@ TEST_F(NumberAnnotatorTest, SlashSeparatesTwoNumbersFindAll) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("what's 1 + 2/3 * 4/5 * 6 / 7"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( @@ -972,7 +1015,7 @@ TEST_F(NumberAnnotatorTest, SlashDoesNotSeparatesTwoNumbersFindAll) { // 2 in the "2/" context is a number because / is punctuation EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("what's 2a2/3 or 2/s4 or 2/ or /3 or //3 or 2//"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre(IsAnnotatedSpan( CodepointSpan(24, 25), "number", @@ -983,7 +1026,7 @@ TEST_F(NumberAnnotatorTest, BracketsContextAnnotatedFindAll) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("The interval is: (12, 13) or [-12, -4.5)"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( @@ -1002,7 +1045,7 @@ TEST_F(NumberAnnotatorTest, BracketsContextNotAnnotatedFindAll) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("The interval is: -(12, 138*)"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_TRUE(result.empty()); } @@ -1012,7 +1055,7 @@ TEST_F(NumberAnnotatorTest, FractionalNumberDotsFindAll) { // Dots source: https://unicode-search.net/unicode-namesearch.pl?term=period EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("3.1 3﹒2 3.3"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( IsAnnotatedSpan(CodepointSpan(0, 3), "number", @@ -1032,7 +1075,7 @@ TEST_F(NumberAnnotatorTest, NonAsciiDigitsFindAll) { // Digits source: https://unicode-search.net/unicode-namesearch.pl?term=digit EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("3 3﹒2 3.3%"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( IsAnnotatedSpan(CodepointSpan(0, 1), "number", @@ -1052,7 +1095,7 @@ TEST_F(NumberAnnotatorTest, AnnotatedZeroPrecededNumbersFindAll) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("Numbers: 0.9 or 09 or 09.9 or 032310"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( IsAnnotatedSpan(CodepointSpan(9, 12), "number", @@ -1072,7 +1115,7 @@ TEST_F(NumberAnnotatorTest, ZeroAfterDotFindAll) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("15.0 16.00"), AnnotationUsecase_ANNOTATION_USECASE_RAW, - &result)); + ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( @@ -1086,7 +1129,7 @@ TEST_F(NumberAnnotatorTest, NineDotNineFindAll) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("9.9 9.99 99.99 99.999 99.9999"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( diff --git a/native/annotator/number/number_test-include.h b/native/annotator/number/number_test-include.h index 9de7c86..14fc6f2 100644 --- a/native/annotator/number/number_test-include.h +++ b/native/annotator/number/number_test-include.h @@ -17,6 +17,7 @@ #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_NUMBER_NUMBER_TEST_INCLUDE_H_ #define LIBTEXTCLASSIFIER_ANNOTATOR_NUMBER_NUMBER_TEST_INCLUDE_H_ +#include "annotator/model_generated.h" #include "annotator/number/number.h" #include "utils/jvm-test-utils.h" #include "gtest/gtest.h" @@ -25,17 +26,32 @@ namespace libtextclassifier3 { namespace test_internal { class NumberAnnotatorTest : public ::testing::Test { + private: protected: - NumberAnnotatorTest() + explicit NumberAnnotatorTest(ModeFlag enabled_modes = ModeFlag_ALL) : unilib_(CreateUniLibForTesting()), - number_annotator_(TestingNumberAnnotatorOptions(), unilib_.get()) {} + number_annotator_(TestingNumberAnnotatorOptions(enabled_modes), + unilib_.get()) {} - const NumberAnnotatorOptions* TestingNumberAnnotatorOptions(); + const NumberAnnotatorOptions* TestingNumberAnnotatorOptions( + ModeFlag enabled_modes); std::unique_ptr<UniLib> unilib_; NumberAnnotator number_annotator_; }; +class NumberAnnotatorForAnnotationAndClassificationTest + : public NumberAnnotatorTest { + protected: + NumberAnnotatorForAnnotationAndClassificationTest() + : NumberAnnotatorTest(ModeFlag_ANNOTATION_AND_CLASSIFICATION) {} +}; + +class NumberAnnotatorForSelectionTest : public NumberAnnotatorTest { + protected: + NumberAnnotatorForSelectionTest() : NumberAnnotatorTest(ModeFlag_SELECTION) {} +}; + } // namespace test_internal } // namespace libtextclassifier3 diff --git a/native/annotator/person_name/person-name-engine-dummy.h b/native/annotator/person_name/person-name-engine-dummy.h index 9c83241..44d2821 100644 --- a/native/annotator/person_name/person-name-engine-dummy.h +++ b/native/annotator/person_name/person-name-engine-dummy.h @@ -46,7 +46,7 @@ class PersonNameEngine { } bool Chunk(const UnicodeText& context_unicode, - const std::vector<Token>& tokens, + const std::vector<Token>& tokens, ModeFlag mode, std::vector<AnnotatedSpan>* result) const { return true; } diff --git a/native/annotator/person_name/person_name_model.fbs b/native/annotator/person_name/person_name_model.fbs index b15543f..6ef4a72 100644 --- a/native/annotator/person_name/person_name_model.fbs +++ b/native/annotator/person_name/person_name_model.fbs @@ -14,6 +14,8 @@ // limitations under the License. // +include "annotator/model.fbs"; + file_identifier "TC2 "; // Next ID: 2 @@ -26,7 +28,7 @@ table PersonName { person_name:string (shared); } -// Next ID: 6 +// Next ID: 7 namespace libtextclassifier3; table PersonNameModel { // Decides if the person name annotator is enabled. @@ -52,6 +54,9 @@ table PersonNameModel { // upper case character and have at least one lower case character. // required annotate_capitalized_names_only:bool; + + // Enabled modes. + enabled_modes:ModeFlag = ALL; } root_type libtextclassifier3.PersonNameModel; diff --git a/native/annotator/pod_ner/pod-ner-impl.cc b/native/annotator/pod_ner/pod-ner-impl.cc index 666b7c7..0cb86ee 100644 --- a/native/annotator/pod_ner/pod-ner-impl.cc +++ b/native/annotator/pod_ner/pod-ner-impl.cc @@ -398,6 +398,10 @@ bool PodNerAnnotator::AnnotateAroundSpanOfInterest( std::vector<AnnotatedSpan> *results) const { TC3_CHECK(results != nullptr); + if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) { + return true; + } + std::vector<int32_t> wordpiece_indices; std::vector<int32_t> token_starts; std::vector<Token> tokens; @@ -470,6 +474,11 @@ bool PodNerAnnotator::SuggestSelection(const UnicodeText &context, return false; } + if (!(model_->enabled_modes() & ModeFlag_SELECTION)) { + *result = {}; + return false; + } + for (const AnnotatedSpan &annotation : annotations) { TC3_VLOG(INFO) << "POD NER SuggestSelection: " << annotation; if (annotation.span.first <= click.first && @@ -491,6 +500,10 @@ bool PodNerAnnotator::ClassifyText(const UnicodeText &context, CodepointSpan click, ClassificationResult *result) const { TC3_VLOG(INFO) << "POD NER ClassifyText " << click; + if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) { + return false; + } + std::vector<AnnotatedSpan> annotations; if (!AnnotateAroundSpanOfInterest(context, click, &annotations)) { return false; diff --git a/native/annotator/pod_ner/pod-ner-impl_test.cc b/native/annotator/pod_ner/pod-ner-impl_test.cc index c7d0bee..5accebd 100644 --- a/native/annotator/pod_ner/pod-ner-impl_test.cc +++ b/native/annotator/pod_ner/pod-ner-impl_test.cc @@ -53,7 +53,7 @@ constexpr float kDefaultPriorityScore = 0.5; class PodNerTest : public testing::Test { protected: - PodNerTest() { + explicit PodNerTest(ModeFlag enabled_modes = ModeFlag_ALL) { PodNerModelT model; model.min_number_of_tokens = kMinNumberOfTokens; @@ -68,6 +68,7 @@ class PodNerTest : public testing::Test { GetTestFileContent("annotator/pod_ner/test_data/vocab.txt"); model.word_piece_vocab = std::vector<uint8_t>( word_piece_vocab_buffer.begin(), word_piece_vocab_buffer.end()); + model.enabled_modes = enabled_modes; flatbuffers::FlatBufferBuilder builder; builder.Finish(PodNerModel::Pack(builder, &model)); @@ -101,6 +102,17 @@ class PodNerTest : public testing::Test { std::unique_ptr<UniLib> unilib_; }; +class PodNerForAnnotationAndClassificationTest : public PodNerTest { + protected: + PodNerForAnnotationAndClassificationTest() + : PodNerTest(ModeFlag_ANNOTATION_AND_CLASSIFICATION) {} +}; + +class PodNerForSelectionTest : public PodNerTest { + protected: + PodNerForSelectionTest() : PodNerTest(ModeFlag_SELECTION) {} +}; + TEST_F(PodNerTest, AnnotateSmokeTest) { std::unique_ptr<PodNerAnnotator> annotator = PodNerAnnotator::Create(model_, *unilib_); @@ -209,6 +221,18 @@ TEST_F(PodNerTest, AnnotateDefaultCollections) { } } +TEST_F(PodNerForSelectionTest, AnnotateWithDisabledAnnotationReturnsNoResults) { + std::unique_ptr<PodNerAnnotator> annotator = + PodNerAnnotator::Create(model_, *unilib_); + ASSERT_TRUE(annotator != nullptr); + + std::string multi_word_location = "I live in New York"; + std::vector<AnnotatedSpan> annotations; + ASSERT_TRUE(annotator->Annotate(UTF8ToUnicodeText(multi_word_location), + &annotations)); + EXPECT_THAT(annotations, IsEmpty()); +} + TEST_F(PodNerTest, AnnotateConfigurableCollections) { std::unique_ptr<PodNerModelT> unpacked_model(model_->UnPack()); ASSERT_TRUE(unpacked_model != nullptr); @@ -525,6 +549,18 @@ TEST_F(PodNerTest, SuggestSelectionTest) { EXPECT_EQ(suggested_span.span, CodepointSpan(kInvalidIndex, kInvalidIndex)); } +TEST_F(PodNerForAnnotationAndClassificationTest, + SuggestSelectionWithDisabledSelectionReturnsNoResults) { + std::unique_ptr<PodNerAnnotator> annotator = + PodNerAnnotator::Create(model_, *unilib_); + ASSERT_TRUE(annotator != nullptr); + + AnnotatedSpan suggested_span; + EXPECT_FALSE(annotator->SuggestSelection( + UTF8ToUnicodeText("Google New York, in New York"), {7, 10}, + &suggested_span)); +} + TEST_F(PodNerTest, ClassifyTextTest) { std::unique_ptr<PodNerAnnotator> annotator = PodNerAnnotator::Create(model_, *unilib_); @@ -536,6 +572,17 @@ TEST_F(PodNerTest, ClassifyTextTest) { EXPECT_EQ(result.collection, "location"); } +TEST_F(PodNerForSelectionTest, + ClassifyTextWithDisabledClassificationReturnsFalse) { + std::unique_ptr<PodNerAnnotator> annotator = + PodNerAnnotator::Create(model_, *unilib_); + ASSERT_TRUE(annotator != nullptr); + + ClassificationResult result; + ASSERT_FALSE(annotator->ClassifyText(UTF8ToUnicodeText("We met in New York"), + {10, 18}, &result)); +} + TEST_F(PodNerTest, ThreadSafety) { std::unique_ptr<PodNerAnnotator> annotator = PodNerAnnotator::Create(model_, *unilib_); diff --git a/native/annotator/translate/translate.cc b/native/annotator/translate/translate.cc index 2c5a43c..e38109c 100644 --- a/native/annotator/translate/translate.cc +++ b/native/annotator/translate/translate.cc @@ -21,6 +21,7 @@ #include "annotator/collections.h" #include "annotator/entity-data_generated.h" +#include "annotator/model_generated.h" #include "annotator/types.h" #include "lang_id/lang-id-wrapper.h" #include "utils/base/logging.h" @@ -34,6 +35,10 @@ bool TranslateAnnotator::ClassifyText( const UnicodeText& context, CodepointSpan selection_indices, const std::string& user_familiar_language_tags, ClassificationResult* classification_result) const { + if (!(options_->enabled_modes() & ModeFlag_CLASSIFICATION)) { + return false; + } + std::vector<TranslateAnnotator::LanguageConfidence> confidences; if (options_->algorithm() == TranslateAnnotatorOptions_::Algorithm::Algorithm_BACKOFF) { diff --git a/native/annotator/translate/translate_test.cc b/native/annotator/translate/translate_test.cc index 5c4a63f..90227ec 100644 --- a/native/annotator/translate/translate_test.cc +++ b/native/annotator/translate/translate_test.cc @@ -31,20 +31,33 @@ namespace { using testing::AllOf; using testing::Field; -const TranslateAnnotatorOptions* TestingTranslateAnnotatorOptions() { - static const flatbuffers::DetachedBuffer* options_data = []() { - TranslateAnnotatorOptionsT options; - options.enabled = true; - options.algorithm = TranslateAnnotatorOptions_::Algorithm_BACKOFF; - options.backoff_options.reset( - new TranslateAnnotatorOptions_::BackoffOptionsT()); - - flatbuffers::FlatBufferBuilder builder; - builder.Finish(TranslateAnnotatorOptions::Pack(builder, &options)); - return new flatbuffers::DetachedBuffer(builder.Release()); - }(); - - return flatbuffers::GetRoot<TranslateAnnotatorOptions>(options_data->data()); +const flatbuffers::DetachedBuffer* CreateOptionsData(ModeFlag enabled_modes) { + TranslateAnnotatorOptionsT options; + options.enabled = true; + options.algorithm = TranslateAnnotatorOptions_::Algorithm_BACKOFF; + options.backoff_options.reset( + new TranslateAnnotatorOptions_::BackoffOptionsT()); + options.enabled_modes = enabled_modes; + + flatbuffers::FlatBufferBuilder builder; + builder.Finish(TranslateAnnotatorOptions::Pack(builder, &options)); + return new flatbuffers::DetachedBuffer(builder.Release()); +} + +const TranslateAnnotatorOptions* TestingTranslateAnnotatorOptions( + ModeFlag enabled_modes) { + static const flatbuffers::DetachedBuffer* options_data_classification = + CreateOptionsData(ModeFlag_CLASSIFICATION); + static const flatbuffers::DetachedBuffer* options_data_none = + CreateOptionsData(ModeFlag_NONE); + + if (enabled_modes == ModeFlag_CLASSIFICATION) { + return flatbuffers::GetRoot<TranslateAnnotatorOptions>( + options_data_classification->data()); + } else { + return flatbuffers::GetRoot<TranslateAnnotatorOptions>( + options_data_none->data()); + } } class TestingTranslateAnnotator : public TranslateAnnotator { @@ -60,11 +73,12 @@ std::string GetModelPath() { return GetTestDataPath("annotator/test_data/"); } class TranslateAnnotatorTest : public ::testing::Test { protected: - TranslateAnnotatorTest() + explicit TranslateAnnotatorTest( + ModeFlag enabled_modes = ModeFlag_CLASSIFICATION) : INIT_UNILIB_FOR_TESTING(unilib_), langid_model_(libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile( GetModelPath() + "lang_id.smfb")), - translate_annotator_(TestingTranslateAnnotatorOptions(), + translate_annotator_(TestingTranslateAnnotatorOptions(enabled_modes), langid_model_.get(), &unilib_) {} UniLib unilib_; @@ -72,6 +86,11 @@ class TranslateAnnotatorTest : public ::testing::Test { TestingTranslateAnnotator translate_annotator_; }; +class TranslateAnnotatorForNoneTest : public TranslateAnnotatorTest { + protected: + TranslateAnnotatorForNoneTest() : TranslateAnnotatorTest(ModeFlag_NONE) {} +}; + TEST_F(TranslateAnnotatorTest, WhenSpeaksEnglishGetsTranslateActionForCzech) { ClassificationResult classification; EXPECT_TRUE(translate_annotator_.ClassifyText( @@ -110,6 +129,13 @@ TEST_F(TranslateAnnotatorTest, EntityDataIsSet) { predictions->Get(1)->confidence_score()); } +TEST_F(TranslateAnnotatorForNoneTest, + ClassifyTextDisabledClassificationReturnsFalse) { + ClassificationResult classification; + EXPECT_FALSE(translate_annotator_.ClassifyText( + UTF8ToUnicodeText("学校"), {0, 2}, "en", &classification)); +} + TEST_F(TranslateAnnotatorTest, WhenSpeaksEnglishDoesntGetTranslateActionForEnglish) { ClassificationResult classification; diff --git a/native/annotator/types.h b/native/annotator/types.h index ada301c..8485d44 100644 --- a/native/annotator/types.h +++ b/native/annotator/types.h @@ -65,6 +65,7 @@ struct CodepointSpan { CodepointSpan(CodepointIndex start, CodepointIndex end) : first(start), second(end) {} + CodepointSpan(const CodepointSpan& other) = default; CodepointSpan& operator=(const CodepointSpan& other) = default; bool operator==(const CodepointSpan& other) const { @@ -439,6 +440,8 @@ struct ClassificationResult { contact_nickname, contact_email_address, contact_phone_number, contact_account_type, contact_account_name, contact_id, contact_alternate_name; + int64 contact_recognition_source; + float contact_neural_match_score; std::string app_name, app_package_name; int64 numeric_value; double numeric_double_value; @@ -577,12 +580,18 @@ struct ClassificationOptions : public BaseOptions, public DatetimeOptions { std::string user_familiar_language_tags; // If true, trigger dictionary on words that are of beginner level. bool trigger_dictionary_on_beginner_words = false; + // If true, generate *Add* contact intent for email/phone entity. + bool enable_add_contact_intent; + // If true, generate *Search* intent for named entities. + bool enable_search_intent; bool operator==(const ClassificationOptions& other) const { return this->user_familiar_language_tags == other.user_familiar_language_tags && this->trigger_dictionary_on_beginner_words == other.trigger_dictionary_on_beginner_words && + this->enable_add_contact_intent == other.enable_add_contact_intent && + this->enable_search_intent == other.enable_search_intent && BaseOptions::operator==(other) && DatetimeOptions::operator==(other); } }; diff --git a/native/annotator/vocab/vocab-annotator-impl.cc b/native/annotator/vocab/vocab-annotator-impl.cc index 4b5cc73..b464f54 100644 --- a/native/annotator/vocab/vocab-annotator-impl.cc +++ b/native/annotator/vocab/vocab-annotator-impl.cc @@ -61,6 +61,9 @@ bool VocabAnnotator::Annotate( const UnicodeText& context, const std::vector<Locale> detected_text_language_tags, bool trigger_on_beginner_words, std::vector<AnnotatedSpan>* results) const { + if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) { + return true; + } std::vector<Token> tokens = feature_processor_.Tokenize(context); for (const Token& token : tokens) { ClassificationResult classification_result; @@ -90,6 +93,9 @@ bool VocabAnnotator::ClassifyTextInternal( const std::vector<Locale> detected_text_language_tags, bool trigger_on_beginner_words, ClassificationResult* classification_result, CodepointSpan* classified_span) const { + if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) { + return false; + } if (vocab_level_table_ == nullptr) { return false; } |