diff options
author | Tony Mak <tonymak@google.com> | 2021-02-24 20:08:27 +0000 |
---|---|---|
committer | Tony Mak <tonymak@google.com> | 2021-02-25 14:26:51 +0000 |
commit | 8a501057fd9d5a2c4c194bcd22de93691bc1c452 (patch) | |
tree | 0ffb1f53246bc6cfd075d4d23ca2578d6d1122c1 /native | |
parent | 2587b43b53b9643da23c118f53199132ab28b414 (diff) | |
download | libtextclassifier-8a501057fd9d5a2c4c194bcd22de93691bc1c452.tar.gz |
Export libtextclassifier
Export libtextclassifier without the model downloader code to
mainline-prod for a few changes we want for the upcoming mainline
release.
Test: atest -p external/libtextclassifier
Test: Smart selection + smart reply on a R device
Fixes: 179890518
Change-Id: I0da4487432920e3f95cb00d4f44c8ec257f4b81b
Diffstat (limited to 'native')
47 files changed, 733 insertions, 333 deletions
diff --git a/native/FlatBufferHeaders.bp b/native/FlatBufferHeaders.bp index 6248d2a..4212bbd 100644 --- a/native/FlatBufferHeaders.bp +++ b/native/FlatBufferHeaders.bp @@ -64,13 +64,6 @@ genrule { } genrule { - name: "libtextclassifier_fbgen_annotator_datetime_datetime", - srcs: ["annotator/datetime/datetime.fbs"], - out: ["annotator/datetime/datetime_generated.h"], - defaults: ["fbgen"], -} - -genrule { name: "libtextclassifier_fbgen_annotator_entity-data", srcs: ["annotator/entity-data.fbs"], out: ["annotator/entity-data_generated.h"], @@ -185,7 +178,6 @@ cc_library_headers { "libtextclassifier_fbgen_annotator_model", "libtextclassifier_fbgen_annotator_person_name_person_name_model", "libtextclassifier_fbgen_annotator_experimental_experimental", - "libtextclassifier_fbgen_annotator_datetime_datetime", "libtextclassifier_fbgen_annotator_entity-data", "libtextclassifier_fbgen_utils_grammar_testing_value", "libtextclassifier_fbgen_utils_grammar_semantics_expression", @@ -209,7 +201,6 @@ cc_library_headers { "libtextclassifier_fbgen_annotator_model", "libtextclassifier_fbgen_annotator_person_name_person_name_model", "libtextclassifier_fbgen_annotator_experimental_experimental", - "libtextclassifier_fbgen_annotator_datetime_datetime", "libtextclassifier_fbgen_annotator_entity-data", "libtextclassifier_fbgen_utils_grammar_testing_value", "libtextclassifier_fbgen_utils_grammar_semantics_expression", diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc index a9edde9..69235d7 100644 --- a/native/actions/actions-suggestions.cc +++ b/native/actions/actions-suggestions.cc @@ -72,6 +72,20 @@ int NumMessagesToConsider(const Conversation& conversation, : max_conversation_history_length); } +template <typename T> +std::vector<T> PadOrTruncateToTargetLength(const std::vector<T>& inputs, + const int max_length, + const T pad_value) { + if (inputs.size() >= max_length) { + return std::vector<T>(inputs.begin(), inputs.begin() + max_length); + } else { + std::vector<T> result; + result.reserve(max_length); + result.insert(result.begin(), inputs.begin(), inputs.end()); + result.insert(result.end(), max_length - inputs.size(), pad_value); + return result; + } +} } // namespace std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromUnownedBuffer( @@ -639,8 +653,17 @@ bool ActionsSuggestions::SetupModelInput( return false; } if (model_->tflite_model_spec()->input_context() >= 0) { - model_executor_->SetInput<std::string>( - model_->tflite_model_spec()->input_context(), context, interpreter); + if (model_->tflite_model_spec()->input_length_to_pad() > 0) { + model_executor_->SetInput<std::string>( + model_->tflite_model_spec()->input_context(), + PadOrTruncateToTargetLength( + context, model_->tflite_model_spec()->input_length_to_pad(), + std::string("")), + interpreter); + } else { + model_executor_->SetInput<std::string>( + model_->tflite_model_spec()->input_context(), context, interpreter); + } } if (model_->tflite_model_spec()->input_context_length() >= 0) { model_executor_->SetInput<int>( @@ -648,8 +671,16 @@ bool ActionsSuggestions::SetupModelInput( interpreter); } if (model_->tflite_model_spec()->input_user_id() >= 0) { - model_executor_->SetInput<int>(model_->tflite_model_spec()->input_user_id(), - user_ids, interpreter); + if (model_->tflite_model_spec()->input_length_to_pad() > 0) { + model_executor_->SetInput<int>( + model_->tflite_model_spec()->input_user_id(), + PadOrTruncateToTargetLength( + user_ids, model_->tflite_model_spec()->input_length_to_pad(), 0), + interpreter); + } else { + model_executor_->SetInput<int>( + model_->tflite_model_spec()->input_user_id(), user_ids, interpreter); + } } if (model_->tflite_model_spec()->input_num_suggestions() >= 0) { model_executor_->SetInput<int>( diff --git a/native/actions/actions-suggestions_test.cc b/native/actions/actions-suggestions_test.cc index 55aa852..ddaa604 100644 --- a/native/actions/actions-suggestions_test.cc +++ b/native/actions/actions-suggestions_test.cc @@ -29,6 +29,7 @@ #include "utils/flatbuffers/flatbuffers.h" #include "utils/flatbuffers/flatbuffers_generated.h" #include "utils/flatbuffers/mutable.h" +#include "utils/grammar/utils/locale-shard-map.h" #include "utils/grammar/utils/rules.h" #include "utils/hash/farmhash.h" #include "utils/jvm-test-utils.h" @@ -1042,7 +1043,9 @@ TEST_F(ActionsSuggestionsTest, CreatesActionsFromGrammarRules) { // Setup test rules. action_grammar_rules->rules.reset(new grammar::RulesSetT); - grammar::Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + grammar::Rules rules(locale_shard_map); rules.Add( "<knock>", {"<^>", "ventura", "!?", "<$>"}, /*callback=*/ @@ -1102,7 +1105,9 @@ TEST_F(ActionsSuggestionsTest, CreatesActionsWithAnnotationsFromGrammarRules) { // Setup test rules. action_grammar_rules->rules.reset(new grammar::RulesSetT); - grammar::Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + grammar::Rules rules(locale_shard_map); rules.Add( "<event>", {"it", "is", "at", "<time>"}, /*callback=*/ diff --git a/native/actions/actions_model.fbs b/native/actions/actions_model.fbs index 1548816..0db43f4 100755 --- a/native/actions/actions_model.fbs +++ b/native/actions/actions_model.fbs @@ -15,15 +15,15 @@ // include "actions/actions-entity-data.fbs"; -include "utils/grammar/rules.fbs"; -include "utils/tokenizer.fbs"; -include "utils/flatbuffers/flatbuffers.fbs"; -include "utils/codepoint-range.fbs"; -include "utils/zlib/buffer.fbs"; -include "utils/normalization.fbs"; include "annotator/model.fbs"; +include "utils/codepoint-range.fbs"; +include "utils/flatbuffers/flatbuffers.fbs"; +include "utils/grammar/rules.fbs"; include "utils/intents/intent-config.fbs"; +include "utils/normalization.fbs"; include "utils/resources.fbs"; +include "utils/tokenizer.fbs"; +include "utils/zlib/buffer.fbs"; file_identifier "TC3A"; @@ -116,6 +116,10 @@ table TensorflowLiteModelSpec { // Map of additional input tensor name to its index. input_name_index:[TensorflowLiteModelSpec_.InputNameIndexEntry]; + + // If greater than 0, pad or truncate the input_user_id and input_context + // tensor to length of input_length_to_pad. + input_length_to_pad:int = 0; } // Configuration for the tokenizer. diff --git a/native/actions/grammar-actions_test.cc b/native/actions/grammar-actions_test.cc index 9fe73d4..02deea9 100644 --- a/native/actions/grammar-actions_test.cc +++ b/native/actions/grammar-actions_test.cc @@ -37,6 +37,8 @@ namespace { using ::testing::ElementsAre; using ::testing::IsEmpty; +using ::libtextclassifier3::grammar::LocaleShardMap; + class TestGrammarActions : public GrammarActions { public: explicit TestGrammarActions( @@ -140,12 +142,14 @@ class GrammarActionsTest : public testing::Test { }; TEST_F(GrammarActionsTest, ProducesSmartReplies) { + LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""}); + grammar::Rules rules(locale_shard_map); + // Create test rules. // Rule: ^knock knock.?$ -> "Who's there?", "Yes?" RulesModel_::GrammarRulesT action_grammar_rules; SetTokenizerOptions(&action_grammar_rules); action_grammar_rules.rules.reset(new grammar::RulesSetT); - grammar::Rules rules; rules.Add( "<knock>", {"<^>", "knock", "knock", ".?", "<$>"}, /*callback=*/ @@ -174,7 +178,8 @@ TEST_F(GrammarActionsTest, ProducesSmartRepliesFromCapturingMatches) { RulesModel_::GrammarRulesT action_grammar_rules; SetTokenizerOptions(&action_grammar_rules); action_grammar_rules.rules.reset(new grammar::RulesSetT); - grammar::Rules rules; + LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""}); + grammar::Rules rules(locale_shard_map); rules.Add( "<scripted_reply>", @@ -231,7 +236,8 @@ TEST_F(GrammarActionsTest, ProducesAnnotationsForActions) { RulesModel_::GrammarRulesT action_grammar_rules; SetTokenizerOptions(&action_grammar_rules); action_grammar_rules.rules.reset(new grammar::RulesSetT); - grammar::Rules rules; + LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""}); + grammar::Rules rules(locale_shard_map); rules.Add( "<call_phone>", {"please", "dial", "<phone>"}, @@ -270,7 +276,9 @@ TEST_F(GrammarActionsTest, HandlesLocales) { RulesModel_::GrammarRulesT action_grammar_rules; SetTokenizerOptions(&action_grammar_rules); action_grammar_rules.rules.reset(new grammar::RulesSetT); - grammar::Rules rules(/*num_shards=*/2); + LocaleShardMap locale_shard_map = + LocaleShardMap::CreateLocaleShardMap({"", "fr-CH"}); + grammar::Rules rules(locale_shard_map); rules.Add( "<knock>", {"<^>", "knock", "knock", ".?", "<$>"}, /*callback=*/ @@ -333,7 +341,8 @@ TEST_F(GrammarActionsTest, HandlesAssertions) { RulesModel_::GrammarRulesT action_grammar_rules; SetTokenizerOptions(&action_grammar_rules); action_grammar_rules.rules.reset(new grammar::RulesSetT); - grammar::Rules rules; + LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""}); + grammar::Rules rules(locale_shard_map); rules.Add("<carrier>", {"lx"}); rules.Add("<carrier>", {"aa"}); rules.Add("<flight_code>", {"<2_digits>"}); @@ -387,7 +396,8 @@ TEST_F(GrammarActionsTest, SetsFixedEntityData) { RulesModel_::GrammarRulesT action_grammar_rules; SetTokenizerOptions(&action_grammar_rules); action_grammar_rules.rules.reset(new grammar::RulesSetT); - grammar::Rules rules; + LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""}); + grammar::Rules rules(locale_shard_map); // Create smart reply and static entity data. const int spec_id = @@ -440,7 +450,8 @@ TEST_F(GrammarActionsTest, SetsEntityDataFromCapturingMatches) { RulesModel_::GrammarRulesT action_grammar_rules; SetTokenizerOptions(&action_grammar_rules); action_grammar_rules.rules.reset(new grammar::RulesSetT); - grammar::Rules rules; + LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""}); + grammar::Rules rules(locale_shard_map); // Create smart reply and static entity data. const int spec_id = @@ -522,7 +533,8 @@ TEST_F(GrammarActionsTest, SetsFixedEntityDataFromCapturingGroups) { RulesModel_::GrammarRulesT action_grammar_rules; SetTokenizerOptions(&action_grammar_rules); action_grammar_rules.rules.reset(new grammar::RulesSetT); - grammar::Rules rules; + LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""}); + grammar::Rules rules(locale_shard_map); // Create smart reply. const int spec_id = @@ -572,7 +584,8 @@ TEST_F(GrammarActionsTest, ProducesActionsWithAnnotations) { RulesModel_::GrammarRulesT action_grammar_rules; SetTokenizerOptions(&action_grammar_rules); action_grammar_rules.rules.reset(new grammar::RulesSetT); - grammar::Rules rules; + LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""}); + grammar::Rules rules(locale_shard_map); rules.Add( "<call_phone>", {"please", "dial", "<phone>"}, /*callback=*/ @@ -632,7 +645,8 @@ TEST_F(GrammarActionsTest, HandlesExclusions) { SetTokenizerOptions(&action_grammar_rules); action_grammar_rules.rules.reset(new grammar::RulesSetT); - grammar::Rules rules; + LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""}); + grammar::Rules rules(locale_shard_map); rules.Add("<excluded>", {"be", "safe"}); rules.AddWithExclusion("<tokens_but_not_excluded>", {"<token>", "<token>"}, /*excluded_nonterminal=*/"<excluded>"); diff --git a/native/actions/test_data/actions_suggestions_grammar_test.model b/native/actions/test_data/actions_suggestions_grammar_test.model Binary files differindex ae6dc60..1af22aa 100644 --- a/native/actions/test_data/actions_suggestions_grammar_test.model +++ b/native/actions/test_data/actions_suggestions_grammar_test.model diff --git a/native/actions/test_data/actions_suggestions_test.model b/native/actions/test_data/actions_suggestions_test.model Binary files differindex 52f932e..1361475 100644 --- a/native/actions/test_data/actions_suggestions_test.model +++ b/native/actions/test_data/actions_suggestions_test.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model Binary files differindex 6145540..1396a46 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model Binary files differindex de8520a..660d97f 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc index 2635820..53176e6 100644 --- a/native/annotator/annotator.cc +++ b/native/annotator/annotator.cc @@ -1017,7 +1017,13 @@ CodepointSpan Annotator::SuggestSelection( return original_click_indices; } - return candidates.annotated_spans[0][i].span; + // We return a suggested span contains the original span. + // This compensates for "select all" selection that may come from + // other apps. See http://b/179890518. + if (SpanContains(candidates.annotated_spans[0][i].span, + original_click_indices)) { + return candidates.annotated_spans[0][i].span; + } } } @@ -1935,7 +1941,7 @@ std::vector<ClassificationResult> Annotator::ClassifyText( bool Annotator::ModelAnnotate( const std::string& context, const std::vector<Locale>& detected_text_language_tags, - const BaseOptions& options, InterpreterManager* interpreter_manager, + const AnnotationOptions& options, InterpreterManager* interpreter_manager, std::vector<Token>* tokens, std::vector<AnnotatedSpan>* result) const { if (model_->triggering_options() == nullptr || !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) { @@ -2008,13 +2014,41 @@ bool Annotator::ModelAnnotate( } const int offset = std::distance(context_unicode.begin(), line.first); + UnicodeText line_unicode; + std::vector<UnicodeText::const_iterator> line_codepoints; + if (options.enable_optimization) { + if (local_chunks.empty()) { + continue; + } + line_unicode = UTF8ToUnicodeText(line_str, /*do_copy=*/false); + line_codepoints = line_unicode.Codepoints(); + line_codepoints.push_back(line_unicode.end()); + } for (const TokenSpan& chunk : local_chunks) { CodepointSpan codepoint_span = - selection_feature_processor_->StripBoundaryCodepoints( - line_str, TokenSpanToCodepointSpan(line_tokens, chunk)); - if (model_->selection_options()->strip_unpaired_brackets()) { - codepoint_span = - StripUnpairedBrackets(context_unicode, codepoint_span, *unilib_); + TokenSpanToCodepointSpan(line_tokens, chunk); + if (options.enable_optimization) { + if (!codepoint_span.IsValid() || + codepoint_span.second > line_codepoints.size()) { + continue; + } + codepoint_span = selection_feature_processor_->StripBoundaryCodepoints( + /*span_begin=*/line_codepoints[codepoint_span.first], + /*span_end=*/line_codepoints[codepoint_span.second], + codepoint_span); + if (model_->selection_options()->strip_unpaired_brackets()) { + codepoint_span = StripUnpairedBrackets( + /*span_begin=*/line_codepoints[codepoint_span.first], + /*span_end=*/line_codepoints[codepoint_span.second], + codepoint_span, *unilib_); + } + } else { + codepoint_span = selection_feature_processor_->StripBoundaryCodepoints( + line_str, codepoint_span); + if (model_->selection_options()->strip_unpaired_brackets()) { + codepoint_span = + StripUnpairedBrackets(context_unicode, codepoint_span, *unilib_); + } } // Skip empty spans. @@ -3136,4 +3170,13 @@ bool Annotator::LookUpKnowledgeEntity( knowledge_engine_->LookUpEntity(id, serialized_knowledge_result); } +StatusOr<std::string> Annotator::LookUpKnowledgeEntityProperty( + const std::string& mid_str, const std::string& property) const { + if (!knowledge_engine_) { + return Status(StatusCode::FAILED_PRECONDITION, + "knowledge_engine_ is nullptr"); + } + return knowledge_engine_->LookUpEntityProperty(mid_str, property); +} + } // namespace libtextclassifier3 diff --git a/native/annotator/annotator.h b/native/annotator/annotator.h index 5397f56..a570a83 100644 --- a/native/annotator/annotator.h +++ b/native/annotator/annotator.h @@ -219,6 +219,10 @@ class Annotator { bool LookUpKnowledgeEntity(const std::string& id, std::string* serialized_knowledge_result) const; + // Looks up an entity's property. + StatusOr<std::string> LookUpKnowledgeEntityProperty( + const std::string& mid_str, const std::string& property) const; + const Model* model() const; const reflection::Schema* entity_data_schema() const; @@ -342,7 +346,7 @@ class Annotator { // reuse. bool ModelAnnotate(const std::string& context, const std::vector<Locale>& detected_text_language_tags, - const BaseOptions& options, + const AnnotationOptions& options, InterpreterManager* interpreter_manager, std::vector<Token>* tokens, std::vector<AnnotatedSpan>* result) const; diff --git a/native/annotator/annotator_test-include.cc b/native/annotator/annotator_test-include.cc index 3ed91e1..b852827 100644 --- a/native/annotator/annotator_test-include.cc +++ b/native/annotator/annotator_test-include.cc @@ -27,6 +27,7 @@ #include "annotator/test-utils.h" #include "annotator/types-test-util.h" #include "annotator/types.h" +#include "utils/grammar/utils/locale-shard-map.h" #include "utils/grammar/utils/rules.h" #include "utils/testing/annotator.h" #include "lang_id/fb_model/lang-id-from-fb.h" @@ -921,13 +922,13 @@ TEST_F(AnnotatorTest, SuggestSelection) { // Unpaired bracket stripping. EXPECT_EQ( - classifier->SuggestSelection("call me at (857) 225 3556 today", {11, 16}), + classifier->SuggestSelection("call me at (857) 225 3556 today", {12, 14}), CodepointSpan(11, 25)); - EXPECT_EQ(classifier->SuggestSelection("call me at (857 today", {11, 15}), + EXPECT_EQ(classifier->SuggestSelection("call me at (857 today", {12, 14}), CodepointSpan(12, 15)); - EXPECT_EQ(classifier->SuggestSelection("call me at 3556) today", {11, 16}), + EXPECT_EQ(classifier->SuggestSelection("call me at 3556) today", {12, 14}), CodepointSpan(11, 15)); - EXPECT_EQ(classifier->SuggestSelection("call me at )857( today", {11, 16}), + EXPECT_EQ(classifier->SuggestSelection("call me at )857( today", {12, 14}), CodepointSpan(12, 15)); // If the resulting selection would be empty, the original span is returned. @@ -937,6 +938,12 @@ TEST_F(AnnotatorTest, SuggestSelection) { CodepointSpan(11, 12)); EXPECT_EQ(classifier->SuggestSelection("call me at ) today", {11, 12}), CodepointSpan(11, 12)); + + // If the original span is larger than the found selection, the original span + // is returned. + EXPECT_EQ( + classifier->SuggestSelection("call me at 857 225 3556 today", {5, 24}), + CodepointSpan(5, 24)); } TEST_F(AnnotatorTest, SuggestSelectionDisabledFail) { @@ -1238,6 +1245,34 @@ TEST_F(AnnotatorTest, AnnotatesWithBracketStripping) { })); } +TEST_F(AnnotatorTest, AnnotatesWithBracketStrippingOptimized) { + std::unique_ptr<Annotator> classifier = Annotator::FromPath( + GetTestModelPath(), unilib_.get(), calendarlib_.get()); + ASSERT_TRUE(classifier); + + AnnotationOptions options; + options.enable_optimization = true; + + EXPECT_THAT(classifier->Annotate("call me at (0845) 100 1000 today", options), + ElementsAreArray({ + IsAnnotatedSpan(11, 26, "phone"), + })); + + // Unpaired bracket stripping. + EXPECT_THAT(classifier->Annotate("call me at (07038201818 today", options), + ElementsAreArray({ + IsAnnotatedSpan(12, 23, "phone"), + })); + EXPECT_THAT(classifier->Annotate("call me at 07038201818) today", options), + ElementsAreArray({ + IsAnnotatedSpan(11, 22, "phone"), + })); + EXPECT_THAT(classifier->Annotate("call me at )07038201818( today", options), + ElementsAreArray({ + IsAnnotatedSpan(12, 23, "phone"), + })); +} + TEST_F(AnnotatorTest, AnnotatesOverlappingNumbers) { std::unique_ptr<Annotator> classifier = Annotator::FromPath( GetTestModelPath(), unilib_.get(), calendarlib_.get()); @@ -1743,7 +1778,9 @@ TEST_F(AnnotatorTest, AnnotatesWithGrammarModel) { // Add test rules. grammar_model->rules.reset(new grammar::RulesSetT); - grammar::Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + grammar::Rules rules(locale_shard_map); rules.Add("<tv_detective>", {"jessica", "fletcher"}); rules.Add("<tv_detective>", {"columbo"}); rules.Add("<tv_detective>", {"magnum"}); diff --git a/native/annotator/datetime/datetime.fbs b/native/annotator/datetime/datetime.fbs deleted file mode 100755 index 8012cdc..0000000 --- a/native/annotator/datetime/datetime.fbs +++ /dev/null @@ -1,145 +0,0 @@ -// -// Copyright (C) 2018 The Android Open Source Project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -// Meridiem field. -namespace libtextclassifier3.grammar.datetime; -enum Meridiem : int { - UNKNOWN = 0, - - // Ante meridiem: Before noon - AM = 1, - - // Post meridiem: After noon - PM = 2, -} - -// Enum represents a unit of date and time in the expression. -// Next field: 10 -namespace libtextclassifier3.grammar.datetime; -enum ComponentType : int { - UNSPECIFIED = 0, - - // Year of the date seen in the text match. - YEAR = 1, - - // Month of the year starting with January = 1. - MONTH = 2, - - // Week (7 days). - WEEK = 3, - - // Day of week, start of the week is Sunday & its value is 1. - DAY_OF_WEEK = 4, - - // Day of the month starting with 1. - DAY_OF_MONTH = 5, - - // Hour of the day. - HOUR = 6, - - // Minute of the hour with a range of 0-59. - MINUTE = 7, - - // Seconds of the minute with a range of 0-59. - SECOND = 8, - - // Meridiem field i.e. AM/PM. - MERIDIEM = 9, -} - -namespace libtextclassifier3.grammar.datetime; -table TimeZone { - // Offset from UTC/GTM in minutes. - utc_offset_mins:int; -} - -namespace libtextclassifier3.grammar.datetime.RelativeDatetimeComponent_; -enum Modifier : int { - UNSPECIFIED = 0, - NEXT = 1, - THIS = 2, - LAST = 3, - NOW = 4, - TOMORROW = 5, - YESTERDAY = 6, -} - -// Message for representing the relative date-time component in date-time -// expressions. -// Next field: 4 -namespace libtextclassifier3.grammar.datetime; -table RelativeDatetimeComponent { - component_type:ComponentType = UNSPECIFIED; - modifier:RelativeDatetimeComponent_.Modifier = UNSPECIFIED; - value:int; -} - -// AbsoluteDateTime represents date-time expressions that is not ambiguous. -// Next field: 11 -namespace libtextclassifier3.grammar.datetime; -table AbsoluteDateTime { - // Year value of the date seen in the text match. - year:int = -1; - - // Month value of the year starting with January = 1. - month:int = -1; - - // Day value of the month starting with 1. - day:int = -1; - - // Day of week, start of the week is Sunday and its value is 1. - week_day:int = -1; - - // Hour value of the day. - hour:int = -1; - - // Minute value of the hour with a range of 0-59. - minute:int = -1; - - // Seconds value of the minute with a range of 0-59. - second:int = -1; - - partial_second:double = -1; - - // Meridiem field i.e. AM/PM. - meridiem:Meridiem; - - time_zone:TimeZone; -} - -// Message to represent relative datetime expressions. -// It encode expressions -// - Where modifier such as before/after shift the date e.g.[three days ago], -// [2 days after March 1st]. -// - When prefix make the expression relative e.g. [next weekend], -// [last Monday]. -// Next field: 3 -namespace libtextclassifier3.grammar.datetime; -table RelativeDateTime { - relative_datetime_component:[RelativeDatetimeComponent]; - - // The base could be an absolute datetime point for example: "March 1", a - // relative datetime point, for example: "2 days before March 1" - base:AbsoluteDateTime; -} - -// Datetime result. -namespace libtextclassifier3.grammar.datetime; -table UngroundedDatetime { - absolute_datetime:AbsoluteDateTime; - relative_datetime:RelativeDateTime; -} - diff --git a/native/annotator/duration/duration_test.cc b/native/annotator/duration/duration_test.cc index f5e0510..7c07a72 100644 --- a/native/annotator/duration/duration_test.cc +++ b/native/annotator/duration/duration_test.cc @@ -23,7 +23,7 @@ #include "annotator/model_generated.h" #include "annotator/types-test-util.h" #include "annotator/types.h" -#include "utils/test-utils.h" +#include "utils/tokenizer-utils.h" #include "utils/utf8/unicodetext.h" #include "utils/utf8/unilib.h" #include "gmock/gmock.h" diff --git a/native/annotator/feature-processor.cc b/native/annotator/feature-processor.cc index 3831c5f..99e25e1 100644 --- a/native/annotator/feature-processor.cc +++ b/native/annotator/feature-processor.cc @@ -474,6 +474,13 @@ bool FeatureProcessor::SelectionLabelSpans( return true; } +bool FeatureProcessor::SelectionLabelRelativeTokenSpans( + std::vector<TokenSpan>* selection_label_relative_token_spans) const { + selection_label_relative_token_spans->assign(label_to_selection_.begin(), + label_to_selection_.end()); + return true; +} + void FeatureProcessor::PrepareIgnoredSpanBoundaryCodepoints() { if (options_->ignored_span_boundary_codepoints() != nullptr) { for (const int codepoint : *options_->ignored_span_boundary_codepoints()) { diff --git a/native/annotator/feature-processor.h b/native/annotator/feature-processor.h index 3b865b0..482d274 100644 --- a/native/annotator/feature-processor.h +++ b/native/annotator/feature-processor.h @@ -165,6 +165,11 @@ class FeatureProcessor { VectorSpan<Token> tokens, std::vector<CodepointSpan>* selection_label_spans) const; + // Fills selection_label_relative_token_spans with number of tokens left and + // right from the click. + bool SelectionLabelRelativeTokenSpans( + std::vector<TokenSpan>* selection_label_relative_token_spans) const; + int DenseFeaturesCount() const { return feature_extractor_.DenseFeaturesCount(); } diff --git a/native/annotator/grammar/grammar-annotator_test.cc b/native/annotator/grammar/grammar-annotator_test.cc index b2084cb..6fcd1f5 100644 --- a/native/annotator/grammar/grammar-annotator_test.cc +++ b/native/annotator/grammar/grammar-annotator_test.cc @@ -23,6 +23,7 @@ #include "annotator/model_generated.h" #include "utils/flatbuffers/flatbuffers.h" #include "utils/flatbuffers/mutable.h" +#include "utils/grammar/utils/locale-shard-map.h" #include "utils/grammar/utils/rules.h" #include "utils/tokenizer.h" #include "utils/utf8/unicodetext.h" @@ -45,7 +46,9 @@ TEST_F(GrammarAnnotatorTest, AnnotesWithGrammarRules) { GrammarModelT grammar_model; SetTestTokenizerOptions(&grammar_model); grammar_model.rules.reset(new grammar::RulesSetT); - grammar::Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + grammar::Rules rules(locale_shard_map); rules.Add("<carrier>", {"lx"}); rules.Add("<carrier>", {"aa"}); rules.Add("<flight_code>", {"<2_digits>"}); @@ -79,7 +82,9 @@ TEST_F(GrammarAnnotatorTest, HandlesAssertions) { GrammarModelT grammar_model; SetTestTokenizerOptions(&grammar_model); grammar_model.rules.reset(new grammar::RulesSetT); - grammar::Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + grammar::Rules rules(locale_shard_map); rules.Add("<carrier>", {"lx"}); rules.Add("<carrier>", {"aa"}); rules.Add("<flight_code>", {"<2_digits>"}); @@ -120,7 +125,9 @@ TEST_F(GrammarAnnotatorTest, HandlesCapturingGroups) { GrammarModelT grammar_model; SetTestTokenizerOptions(&grammar_model); grammar_model.rules.reset(new grammar::RulesSetT); - grammar::Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + grammar::Rules rules(locale_shard_map); rules.AddValueMapping("<low_confidence_phone>", {"<digits>"}, /*value=*/0); @@ -157,7 +164,9 @@ TEST_F(GrammarAnnotatorTest, ClassifiesTextWithGrammarRules) { GrammarModelT grammar_model; SetTestTokenizerOptions(&grammar_model); grammar_model.rules.reset(new grammar::RulesSetT); - grammar::Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + grammar::Rules rules(locale_shard_map); rules.Add("<carrier>", {"lx"}); rules.Add("<carrier>", {"aa"}); rules.Add("<flight_code>", {"<2_digits>"}); @@ -194,7 +203,9 @@ TEST_F(GrammarAnnotatorTest, ClassifiesTextWithAssertions) { grammar_model.context_left_num_tokens = -1; grammar_model.context_right_num_tokens = -1; - grammar::Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + grammar::Rules rules(locale_shard_map); rules.Add("<carrier>", {"lx"}); rules.Add("<carrier>", {"aa"}); rules.Add("<flight_code>", {"<2_digits>"}); @@ -254,7 +265,9 @@ TEST_F(GrammarAnnotatorTest, ClassifiesTextWithContext) { grammar_model.context_left_num_tokens = 3; grammar_model.context_right_num_tokens = 0; - grammar::Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + grammar::Rules rules(locale_shard_map); rules.Add("<tracking_number>", {"<5_digits>"}); rules.Add("<tracking_number>", {"<6_digits>"}); rules.Add("<tracking_number>", {"<7_digits>"}); @@ -306,7 +319,9 @@ TEST_F(GrammarAnnotatorTest, SuggestsTextSelection) { GrammarModelT grammar_model; SetTestTokenizerOptions(&grammar_model); grammar_model.rules.reset(new grammar::RulesSetT); - grammar::Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + grammar::Rules rules(locale_shard_map); rules.Add("<carrier>", {"lx"}); rules.Add("<carrier>", {"aa"}); rules.Add("<flight_code>", {"<2_digits>"}); @@ -338,7 +353,9 @@ TEST_F(GrammarAnnotatorTest, SetsFixedEntityData) { GrammarModelT grammar_model; SetTestTokenizerOptions(&grammar_model); grammar_model.rules.reset(new grammar::RulesSetT); - grammar::Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + grammar::Rules rules(locale_shard_map); const int person_result = AddRuleClassificationResult("person", ModeFlag_ALL, 1.0, &grammar_model); rules.Add( @@ -382,7 +399,9 @@ TEST_F(GrammarAnnotatorTest, SetsEntityDataFromCapturingMatches) { GrammarModelT grammar_model; SetTestTokenizerOptions(&grammar_model); grammar_model.rules.reset(new grammar::RulesSetT); - grammar::Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + grammar::Rules rules(locale_shard_map); const int person_result = AddRuleClassificationResult("person", ModeFlag_ALL, 1.0, &grammar_model); @@ -438,7 +457,9 @@ TEST_F(GrammarAnnotatorTest, RespectsRuleModes) { GrammarModelT grammar_model; SetTestTokenizerOptions(&grammar_model); grammar_model.rules.reset(new grammar::RulesSetT); - grammar::Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + grammar::Rules rules(locale_shard_map); rules.Add("<classification_carrier>", {"ei"}); rules.Add("<classification_carrier>", {"en"}); rules.Add("<selection_carrier>", {"ai"}); diff --git a/native/annotator/knowledge/knowledge-engine-dummy.h b/native/annotator/knowledge/knowledge-engine-dummy.h index 615ad06..b6c4f42 100644 --- a/native/annotator/knowledge/knowledge-engine-dummy.h +++ b/native/annotator/knowledge/knowledge-engine-dummy.h @@ -66,6 +66,11 @@ class KnowledgeEngine { std::string* serialized_knowledge_result) const { return false; } + + StatusOr<std::string> LookUpEntityProperty( + const std::string& mid_str, const std::string& property) const { + return Status(StatusCode::UNIMPLEMENTED, "Not implemented"); + } }; } // namespace libtextclassifier3 diff --git a/native/annotator/model.fbs b/native/annotator/model.fbs index dbbb422..f639f06 100755 --- a/native/annotator/model.fbs +++ b/native/annotator/model.fbs @@ -14,17 +14,17 @@ // limitations under the License. // -include "utils/intents/intent-config.fbs"; -include "annotator/experimental/experimental.fbs"; include "annotator/entity-data.fbs"; +include "annotator/experimental/experimental.fbs"; +include "utils/codepoint-range.fbs"; +include "utils/container/bit-vector.fbs"; +include "utils/flatbuffers/flatbuffers.fbs"; include "utils/grammar/rules.fbs"; +include "utils/intents/intent-config.fbs"; include "utils/normalization.fbs"; -include "utils/tokenizer.fbs"; include "utils/resources.fbs"; -include "utils/codepoint-range.fbs"; -include "utils/flatbuffers/flatbuffers.fbs"; +include "utils/tokenizer.fbs"; include "utils/zlib/buffer.fbs"; -include "utils/container/bit-vector.fbs"; file_identifier "TC2 "; diff --git a/native/annotator/number/number_test-include.cc b/native/annotator/number/number_test-include.cc index d95d388..f47933f 100644 --- a/native/annotator/number/number_test-include.cc +++ b/native/annotator/number/number_test-include.cc @@ -23,7 +23,7 @@ #include "annotator/model_generated.h" #include "annotator/types-test-util.h" #include "annotator/types.h" -#include "utils/test-utils.h" +#include "utils/tokenizer-utils.h" #include "utils/utf8/unicodetext.h" #include "gmock/gmock.h" #include "gtest/gtest.h" diff --git a/native/annotator/strip-unpaired-brackets.cc b/native/annotator/strip-unpaired-brackets.cc index df1fcce..8bf93d9 100644 --- a/native/annotator/strip-unpaired-brackets.cc +++ b/native/annotator/strip-unpaired-brackets.cc @@ -22,59 +22,23 @@ #include "utils/utf8/unicodetext.h" namespace libtextclassifier3 { -namespace { -// Returns true if given codepoint is contained in the given span in context. -bool IsCodepointInSpan(const char32 codepoint, - const UnicodeText& context_unicode, - const CodepointSpan span) { - auto begin_it = context_unicode.begin(); - std::advance(begin_it, span.first); - auto end_it = context_unicode.begin(); - std::advance(end_it, span.second); - - return std::find(begin_it, end_it, codepoint) != end_it; -} - -// Returns the first codepoint of the span. -char32 FirstSpanCodepoint(const UnicodeText& context_unicode, - const CodepointSpan span) { - auto it = context_unicode.begin(); - std::advance(it, span.first); - return *it; -} - -// Returns the last codepoint of the span. -char32 LastSpanCodepoint(const UnicodeText& context_unicode, - const CodepointSpan span) { - auto it = context_unicode.begin(); - std::advance(it, span.second - 1); - return *it; -} - -} // namespace - -CodepointSpan StripUnpairedBrackets(const std::string& context, - CodepointSpan span, const UniLib& unilib) { - const UnicodeText context_unicode = - UTF8ToUnicodeText(context, /*do_copy=*/false); - return StripUnpairedBrackets(context_unicode, span, unilib); -} - -// If the first or the last codepoint of the given span is a bracket, the -// bracket is stripped if the span does not contain its corresponding paired -// version. -CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode, - CodepointSpan span, const UniLib& unilib) { - if (context_unicode.empty() || !span.IsValid() || span.IsEmpty()) { +CodepointSpan StripUnpairedBrackets( + const UnicodeText::const_iterator& span_begin, + const UnicodeText::const_iterator& span_end, CodepointSpan span, + const UniLib& unilib) { + if (span_begin == span_end || !span.IsValid() || span.IsEmpty()) { return span; } - const char32 begin_char = FirstSpanCodepoint(context_unicode, span); + UnicodeText::const_iterator begin = span_begin; + const UnicodeText::const_iterator end = span_end; + const char32 begin_char = *begin; const char32 paired_begin_char = unilib.GetPairedBracket(begin_char); if (paired_begin_char != begin_char) { if (!unilib.IsOpeningBracket(begin_char) || - !IsCodepointInSpan(paired_begin_char, context_unicode, span)) { + std::find(begin, end, paired_begin_char) == end) { + ++begin; ++span.first; } } @@ -83,11 +47,11 @@ CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode, return span; } - const char32 end_char = LastSpanCodepoint(context_unicode, span); + const char32 end_char = *std::prev(end); const char32 paired_end_char = unilib.GetPairedBracket(end_char); if (paired_end_char != end_char) { if (!unilib.IsClosingBracket(end_char) || - !IsCodepointInSpan(paired_end_char, context_unicode, span)) { + std::find(begin, end, paired_end_char) == end) { --span.second; } } @@ -102,4 +66,21 @@ CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode, return span; } +CodepointSpan StripUnpairedBrackets(const UnicodeText& context, + CodepointSpan span, const UniLib& unilib) { + if (!span.IsValid() || span.IsEmpty()) { + return span; + } + const UnicodeText span_text = UnicodeText::Substring( + context, span.first, span.second, /*do_copy=*/false); + return StripUnpairedBrackets(span_text.begin(), span_text.end(), span, + unilib); +} + +CodepointSpan StripUnpairedBrackets(const std::string& context, + CodepointSpan span, const UniLib& unilib) { + return StripUnpairedBrackets(UTF8ToUnicodeText(context, /*do_copy=*/false), + span, unilib); +} + } // namespace libtextclassifier3 diff --git a/native/annotator/strip-unpaired-brackets.h b/native/annotator/strip-unpaired-brackets.h index ceb8d60..c6cdc1a 100644 --- a/native/annotator/strip-unpaired-brackets.h +++ b/native/annotator/strip-unpaired-brackets.h @@ -23,14 +23,21 @@ #include "utils/utf8/unilib.h" namespace libtextclassifier3 { + // If the first or the last codepoint of the given span is a bracket, the // bracket is stripped if the span does not contain its corresponding paired // version. -CodepointSpan StripUnpairedBrackets(const std::string& context, +CodepointSpan StripUnpairedBrackets( + const UnicodeText::const_iterator& span_begin, + const UnicodeText::const_iterator& span_end, CodepointSpan span, + const UniLib& unilib); + +// Same as above but takes a UnicodeText instance for the span. +CodepointSpan StripUnpairedBrackets(const UnicodeText& context, CodepointSpan span, const UniLib& unilib); -// Same as above but takes UnicodeText instance directly. -CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode, +// Same as above but takes a string instance. +CodepointSpan StripUnpairedBrackets(const std::string& context, CodepointSpan span, const UniLib& unilib); } // namespace libtextclassifier3 diff --git a/native/annotator/types.h b/native/annotator/types.h index 3063838..45999cd 100644 --- a/native/annotator/types.h +++ b/native/annotator/types.h @@ -101,6 +101,11 @@ inline bool SpansOverlap(const CodepointSpan& a, const CodepointSpan& b) { return a.first < b.second && b.first < a.second; } +inline bool SpanContains(const CodepointSpan& span, + const CodepointSpan& sub_span) { + return span.first <= sub_span.first && span.second >= sub_span.second; +} + template <typename T> bool DoesCandidateConflict( const int considered_candidate, const std::vector<T>& candidates, @@ -610,6 +615,11 @@ struct AnnotationOptions : public BaseOptions, public DatetimeOptions { // If true, trigger dictionary on words that are of beginner level. bool trigger_dictionary_on_beginner_words = false; + // If true, enables an optimized code path for annotation. + // The optimization caused crashes previously, which is why we are rolling it + // out using this temporary flag. See: b/178503899 + bool enable_optimization = false; + bool operator==(const AnnotationOptions& other) const { return this->is_serialized_entity_data_enabled == other.is_serialized_entity_data_enabled && diff --git a/native/lang_id/common/lite_base/endian.h b/native/lang_id/common/lite_base/endian.h index 16c2dca..2e3ee26 100644 --- a/native/lang_id/common/lite_base/endian.h +++ b/native/lang_id/common/lite_base/endian.h @@ -93,15 +93,6 @@ class LittleEndian { // Conversion functions. #ifdef SAFTM_IS_LITTLE_ENDIAN - static uint16 FromHost16(uint16 x) { return x; } - static uint16 ToHost16(uint16 x) { return x; } - - static uint32 FromHost32(uint32 x) { return x; } - static uint32 ToHost32(uint32 x) { return x; } - - static uint64 FromHost64(uint64 x) { return x; } - static uint64 ToHost64(uint64 x) { return x; } - static bool IsLittleEndian() { return true; } #elif defined SAFTM_IS_BIG_ENDIAN diff --git a/native/utils/base/endian.h b/native/utils/base/endian.h index 9312704..810bc46 100644 --- a/native/utils/base/endian.h +++ b/native/utils/base/endian.h @@ -53,8 +53,8 @@ namespace libtextclassifier3 { #define bswap_64(x) OSSwapInt64(x) #endif // !defined(bswap_16) #else -#define GG_LONGLONG(x) x##LL -#define GG_ULONGLONG(x) x##ULL +#define int64_t {x} x##LL +#define uint64_t {x} x##ULL static inline uint16 bswap_16(uint16 x) { return (uint16)(((x & 0xFF) << 8) | ((x & 0xFF00) >> 8)); // NOLINT } @@ -65,14 +65,12 @@ static inline uint32 bswap_32(uint32 x) { } #define bswap_32(x) bswap_32(x) static inline uint64 bswap_64(uint64 x) { - return (((x & GG_ULONGLONG(0xFF)) << 56) | - ((x & GG_ULONGLONG(0xFF00)) << 40) | - ((x & GG_ULONGLONG(0xFF0000)) << 24) | - ((x & GG_ULONGLONG(0xFF000000)) << 8) | - ((x & GG_ULONGLONG(0xFF00000000)) >> 8) | - ((x & GG_ULONGLONG(0xFF0000000000)) >> 24) | - ((x & GG_ULONGLONG(0xFF000000000000)) >> 40) | - ((x & GG_ULONGLONG(0xFF00000000000000)) >> 56)); + return (((x & uint64_t{0xFF}) << 56) | ((x & uint64_t{0xFF00}) << 40) | + ((x & uint64_t{0xFF0000}) << 24) | ((x & uint64_t{0xFF000000}) << 8) | + ((x & uint64_t{0xFF00000000}) >> 8) | + ((x & uint64_t{0xFF0000000000}) >> 24) | + ((x & uint64_t{0xFF000000000000}) >> 40) | + ((x & uint64_t{0xFF00000000000000}) >> 56)); } #define bswap_64(x) bswap_64(x) #endif diff --git a/native/utils/grammar/analyzer_test.cc b/native/utils/grammar/analyzer_test.cc index 9f71efe..4950fb4 100644 --- a/native/utils/grammar/analyzer_test.cc +++ b/native/utils/grammar/analyzer_test.cc @@ -38,7 +38,9 @@ TEST_F(AnalyzerTest, ParsesTextWithGrammar) { semantic_values_schema_.buffer().end()); // Define rules and semantics. - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); rules.Add("<month>", {"january"}, static_cast<CallbackId>(DefaultCallback::kSemanticExpression), /*callback_param=*/model.semantic_expression.size()); diff --git a/native/utils/grammar/parsing/matcher_test.cc b/native/utils/grammar/parsing/matcher_test.cc index 8528009..7c9a14d 100644 --- a/native/utils/grammar/parsing/matcher_test.cc +++ b/native/utils/grammar/parsing/matcher_test.cc @@ -123,7 +123,9 @@ class MatcherTest : public testing::Test { TEST_F(MatcherTest, HandlesBasicOperations) { // Create an example grammar. - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); rules.Add("<test>", {"the", "quick", "brown", "fox"}, static_cast<CallbackId>(DefaultCallback::kRootRule)); rules.Add("<action>", {"<test>"}, @@ -146,7 +148,9 @@ TEST_F(MatcherTest, HandlesBasicOperations) { std::string CreateTestGrammar() { // Create an example grammar. - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); // Callbacks on terminal rules. rules.Add("<output_5>", {"quick"}, @@ -260,7 +264,9 @@ TEST_F(MatcherTest, HandlesManualAddParseTreeCalls) { } TEST_F(MatcherTest, HandlesOptionalRuleElements) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); rules.Add("<output_0>", {"a?", "b?", "c?", "d?", "e"}, static_cast<CallbackId>(DefaultCallback::kRootRule)); rules.Add("<output_1>", {"a", "b?", "c", "d?", "e"}, @@ -293,7 +299,9 @@ TEST_F(MatcherTest, HandlesOptionalRuleElements) { } TEST_F(MatcherTest, HandlesWhitespaceGapLimits) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); rules.Add("<iata>", {"lx"}); rules.Add("<iata>", {"aa"}); // Require no whitespace between code and flight number. @@ -331,7 +339,9 @@ TEST_F(MatcherTest, HandlesWhitespaceGapLimits) { } TEST_F(MatcherTest, HandlesCaseSensitiveTerminals) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); rules.Add("<iata>", {"LX"}, /*callback=*/kNoCallback, 0, /*max_whitespace_gap*/ -1, /*case_sensitive=*/true); rules.Add("<iata>", {"AA"}, /*callback=*/kNoCallback, 0, @@ -383,7 +393,10 @@ TEST_F(MatcherTest, HandlesCaseSensitiveTerminals) { } TEST_F(MatcherTest, HandlesExclusions) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); + rules.Add("<all_zeros>", {"0000"}); rules.AddWithExclusion("<flight_code>", {"<4_digits>"}, /*excluded_nonterminal=*/"<all_zeros>"); diff --git a/native/utils/grammar/parsing/parser_test.cc b/native/utils/grammar/parsing/parser_test.cc index cf8310b..183be0e 100644 --- a/native/utils/grammar/parsing/parser_test.cc +++ b/native/utils/grammar/parsing/parser_test.cc @@ -41,7 +41,9 @@ using ::testing::IsEmpty; class ParserTest : public GrammarTest {}; TEST_F(ParserTest, ParsesSimpleRules) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); rules.Add("<day>", {"<2_digits>"}); rules.Add("<month>", {"<2_digits>"}); rules.Add("<year>", {"<4_digits>"}); @@ -58,7 +60,9 @@ TEST_F(ParserTest, ParsesSimpleRules) { } TEST_F(ParserTest, HandlesEmptyInput) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); constexpr int kTest = 0; rules.Add("<test>", {"test"}, static_cast<CallbackId>(DefaultCallback::kRootRule), kTest); @@ -80,7 +84,9 @@ TEST_F(ParserTest, HandlesEmptyInput) { } TEST_F(ParserTest, HandlesUppercaseTokens) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); constexpr int kScriptedReply = 0; rules.Add("<test>", {"please?", "reply", "<uppercase_token>"}, static_cast<CallbackId>(DefaultCallback::kRootRule), @@ -99,7 +105,9 @@ TEST_F(ParserTest, HandlesUppercaseTokens) { } TEST_F(ParserTest, HandlesAnchors) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); constexpr int kScriptedReply = 0; rules.Add("<test>", {"<^>", "reply", "<uppercase_token>", "<$>"}, static_cast<CallbackId>(DefaultCallback::kRootRule), @@ -118,7 +126,9 @@ TEST_F(ParserTest, HandlesAnchors) { } TEST_F(ParserTest, HandlesWordBreaks) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); rules.Add("<carrier>", {"lx"}); rules.Add("<carrier>", {"aa"}); constexpr int kFlight = 0; @@ -141,7 +151,9 @@ TEST_F(ParserTest, HandlesWordBreaks) { } TEST_F(ParserTest, HandlesAnnotations) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); constexpr int kCallPhone = 0; rules.Add("<flight>", {"dial", "<phone>"}, static_cast<CallbackId>(DefaultCallback::kRootRule), kCallPhone); @@ -167,7 +179,9 @@ TEST_F(ParserTest, HandlesAnnotations) { } TEST_F(ParserTest, HandlesRegexAnnotators) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); rules.AddRegex("<code>", "(\"([A-Za-z]+)\"|\\b\"?(?:[A-Z]+[0-9]*|[0-9])\"?\\b)"); constexpr int kScriptedReply = 0; @@ -188,7 +202,9 @@ TEST_F(ParserTest, HandlesRegexAnnotators) { } TEST_F(ParserTest, HandlesExclusions) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); rules.Add("<excluded>", {"be", "safe"}); rules.AddWithExclusion("<tokens_but_not_excluded>", {"<token>", "<token>"}, /*excluded_nonterminal=*/"<excluded>"); @@ -210,7 +226,9 @@ TEST_F(ParserTest, HandlesExclusions) { } TEST_F(ParserTest, HandlesFillers) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); constexpr int kSetReminder = 0; rules.Add("<set_reminder>", {"do", "not", "forget", "to", "<filler>"}, static_cast<CallbackId>(DefaultCallback::kRootRule), kSetReminder); @@ -224,7 +242,9 @@ TEST_F(ParserTest, HandlesFillers) { } TEST_F(ParserTest, HandlesAssertions) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); rules.Add("<carrier>", {"lx"}); rules.Add("<carrier>", {"aa"}); rules.Add("<flight_code>", {"<2_digits>"}); @@ -249,7 +269,9 @@ TEST_F(ParserTest, HandlesAssertions) { } TEST_F(ParserTest, HandlesWhitespaceGapLimit) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); rules.Add("<carrier>", {"lx"}); rules.Add("<carrier>", {"aa"}); rules.Add("<flight_code>", {"<2_digits>"}); @@ -270,7 +292,9 @@ TEST_F(ParserTest, HandlesWhitespaceGapLimit) { } TEST_F(ParserTest, HandlesCaseSensitiveMatching) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); rules.Add("<carrier>", {"Lx"}, /*callback=*/kNoCallback, /*callback_param=*/0, /*max_whitespace_gap=*/-1, /*case_sensitive=*/true); rules.Add("<carrier>", {"AA"}, /*callback=*/kNoCallback, /*callback_param=*/0, diff --git a/native/utils/grammar/rules.fbs b/native/utils/grammar/rules.fbs index 3225892..bc0136c 100755 --- a/native/utils/grammar/rules.fbs +++ b/native/utils/grammar/rules.fbs @@ -14,9 +14,9 @@ // limitations under the License. // +include "utils/grammar/semantics/expression.fbs"; include "utils/i18n/language-tag.fbs"; include "utils/zlib/buffer.fbs"; -include "utils/grammar/semantics/expression.fbs"; // The terminal rules map as sorted strings table. // The sorted terminal strings table is represented as offsets into the diff --git a/native/utils/grammar/semantics/composer_test.cc b/native/utils/grammar/semantics/composer_test.cc index 95b0759..e768e18 100644 --- a/native/utils/grammar/semantics/composer_test.cc +++ b/native/utils/grammar/semantics/composer_test.cc @@ -38,7 +38,9 @@ class SemanticComposerTest : public GrammarTest {}; TEST_F(SemanticComposerTest, EvaluatesSimpleMapping) { RulesSetT model; - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); const int test_value_type = TypeIdForName(semantic_values_schema_.get(), "libtextclassifier3.grammar.TestValue") @@ -109,7 +111,9 @@ TEST_F(SemanticComposerTest, EvaluatesSimpleMapping) { TEST_F(SemanticComposerTest, RecursivelyEvaluatesConstituents) { RulesSetT model; - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); const int test_value_type = TypeIdForName(semantic_values_schema_.get(), "libtextclassifier3.grammar.TestValue") diff --git a/native/utils/grammar/utils/ir.cc b/native/utils/grammar/utils/ir.cc index 49135bf..9477dd0 100644 --- a/native/utils/grammar/utils/ir.cc +++ b/native/utils/grammar/utils/ir.cc @@ -16,6 +16,7 @@ #include "utils/grammar/utils/ir.h" +#include "utils/i18n/locale.h" #include "utils/strings/append.h" #include "utils/strings/stringpiece.h" #include "utils/zlib/zlib.h" @@ -445,16 +446,29 @@ void Ir::Serialize(const bool include_debug_information, } // Serialize the unary and binary rules. - for (const RulesShard& shard : shards_) { + for (int i = 0; i < shards_.size(); i++) { output->rules.emplace_back(std::make_unique<RulesSet_::RulesT>()); RulesSet_::RulesT* rules = output->rules.back().get(); - // Serialize the unary rules. - SerializeUnaryRulesShard(shard.unary_rules, output, rules); + for (const Locale& shard_locale : locale_shard_map_.GetLocales(i)) { + if (shard_locale.IsValid()) { + // Check if the language is set to all i.e. '*' which is a special, to + // make it consistent with device side parser here instead of filling + // the all locale leave the language tag list empty + rules->locale.emplace_back( + std::make_unique<libtextclassifier3::LanguageTagT>()); + libtextclassifier3::LanguageTagT* language_tag = + rules->locale.back().get(); + language_tag->language = shard_locale.Language(); + language_tag->region = shard_locale.Region(); + language_tag->script = shard_locale.Script(); + } + } + // Serialize the unary rules. + SerializeUnaryRulesShard(shards_[i].unary_rules, output, rules); // Serialize the binary rules. - SerializeBinaryRulesShard(shard.binary_rules, output, rules); + SerializeBinaryRulesShard(shards_[i].binary_rules, output, rules); } - // Serialize the terminal rules. // We keep the rules separate by shard but merge the actual terminals into // one shared string pool to most effectively exploit reuse. diff --git a/native/utils/grammar/utils/ir.h b/native/utils/grammar/utils/ir.h index adafa66..f056d7a 100644 --- a/native/utils/grammar/utils/ir.h +++ b/native/utils/grammar/utils/ir.h @@ -25,6 +25,7 @@ #include "utils/base/integral_types.h" #include "utils/grammar/rules_generated.h" #include "utils/grammar/types.h" +#include "utils/grammar/utils/locale-shard-map.h" namespace libtextclassifier3::grammar { @@ -96,8 +97,10 @@ class Ir { std::unordered_map<TwoNonterms, LhsSet, BinaryRuleHasher> binary_rules; }; - explicit Ir(const int num_shards = 1) - : num_nonterminals_(0), shards_(num_shards) {} + explicit Ir(const LocaleShardMap& locale_shard_map) + : num_nonterminals_(0), + locale_shard_map_(locale_shard_map), + shards_(locale_shard_map_.GetNumberOfShards()) {} // Adds a new non-terminal. Nonterm AddNonterminal(const std::string& name = "") { @@ -224,6 +227,8 @@ class Ir { Nonterm num_nonterminals_; std::unordered_set<Nonterm> nonshareable_; + // Locale information for Rules + const LocaleShardMap& locale_shard_map_; // The sharded rules. std::vector<RulesShard> shards_; diff --git a/native/utils/grammar/utils/ir_test.cc b/native/utils/grammar/utils/ir_test.cc index 279d99a..7a386df 100644 --- a/native/utils/grammar/utils/ir_test.cc +++ b/native/utils/grammar/utils/ir_test.cc @@ -31,7 +31,9 @@ using ::testing::Ne; using ::testing::SizeIs; TEST(IrTest, HandlesSharingWithTerminalRules) { - Ir ir; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Ir ir(locale_shard_map); // <t1> ::= the const Nonterm t1 = ir.Add(kUnassignedNonterm, "the"); @@ -72,7 +74,9 @@ TEST(IrTest, HandlesSharingWithTerminalRules) { } TEST(IrTest, HandlesSharingWithNonterminalRules) { - Ir ir; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Ir ir(locale_shard_map); // Setup a few terminal rules. const std::vector<Nonterm> rhs = { @@ -97,7 +101,9 @@ TEST(IrTest, HandlesSharingWithCallbacksWithSameParameters) { // Test sharing in the presence of callbacks. constexpr CallbackId kOutput1 = 1; constexpr CallbackId kOutput2 = 2; - Ir ir; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Ir ir(locale_shard_map); const Nonterm x1 = ir.Add(kUnassignedNonterm, "hello"); const Nonterm x2 = @@ -116,7 +122,10 @@ TEST(IrTest, HandlesSharingWithCallbacksWithSameParameters) { TEST(IrTest, SerializesRulesToFlatbufferFormat) { constexpr CallbackId kOutput = 1; - Ir ir; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Ir ir(locale_shard_map); + const Nonterm verb = ir.AddUnshareableNonterminal(); ir.Add(verb, "buy"); ir.Add(Ir::Lhs{verb, {kOutput}}, "bring"); @@ -155,7 +164,9 @@ TEST(IrTest, SerializesRulesToFlatbufferFormat) { } TEST(IrTest, HandlesRulesSharding) { - Ir ir(/*num_shards=*/2); + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({"", "de"}); + Ir ir(locale_shard_map); const Nonterm verb = ir.AddUnshareableNonterminal(); const Nonterm set_reminder = ir.AddUnshareableNonterminal(); @@ -210,7 +221,9 @@ TEST(IrTest, HandlesRulesSharding) { } TEST(IrTest, DeduplicatesLhsSets) { - Ir ir; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Ir ir(locale_shard_map); const Nonterm test = ir.AddUnshareableNonterminal(); ir.Add(test, "test"); diff --git a/native/utils/grammar/utils/locale-shard-map.cc b/native/utils/grammar/utils/locale-shard-map.cc new file mode 100644 index 0000000..4f7dc5e --- /dev/null +++ b/native/utils/grammar/utils/locale-shard-map.cc @@ -0,0 +1,86 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/grammar/utils/locale-shard-map.h" + +#include <algorithm> +#include <string> +#include <utility> +#include <vector> + +#include "utils/i18n/locale-list.h" +#include "utils/i18n/locale.h" +#include "utils/strings/append.h" + +namespace libtextclassifier3::grammar { +namespace { + +std::vector<Locale> LocaleTagsToLocaleList(const std::string& locale_tags) { + std::vector<Locale> locale_list; + for (const Locale& locale : LocaleList::ParseFrom(locale_tags).GetLocales()) { + if (locale.IsValid()) { + locale_list.emplace_back(locale); + } + } + std::sort(locale_list.begin(), locale_list.end(), + [](const Locale& a, const Locale& b) { return a < b; }); + return locale_list; +} + +} // namespace + +LocaleShardMap LocaleShardMap::CreateLocaleShardMap( + const std::vector<std::string>& locale_tags) { + LocaleShardMap locale_shard_map; + for (const std::string& locale_tag : locale_tags) { + locale_shard_map.AddLocalTags(locale_tag); + } + return locale_shard_map; +} + +std::vector<Locale> LocaleShardMap::GetLocales(const int shard) const { + auto locale_it = shard_to_locale_data_.find(shard); + if (locale_it != shard_to_locale_data_.end()) { + return locale_it->second; + } + return std::vector<Locale>(); +} + +int LocaleShardMap::GetNumberOfShards() const { + return shard_to_locale_data_.size(); +} + +int LocaleShardMap::GetShard(const std::vector<Locale> locales) const { + for (const auto& [shard, locale_list] : shard_to_locale_data_) { + if (std::equal(locales.begin(), locales.end(), locale_list.begin())) { + return shard; + } + } + return 0; +} + +int LocaleShardMap::GetShard(const std::string& locale_tags) const { + std::vector<Locale> locale_list = LocaleTagsToLocaleList(locale_tags); + return GetShard(locale_list); +} + +void LocaleShardMap::AddLocalTags(const std::string& locale_tags) { + std::vector<Locale> locale_list = LocaleTagsToLocaleList(locale_tags); + int shard_id = shard_to_locale_data_.size(); + shard_to_locale_data_.insert({shard_id, locale_list}); +} + +} // namespace libtextclassifier3::grammar diff --git a/native/utils/grammar/utils/locale-shard-map.h b/native/utils/grammar/utils/locale-shard-map.h new file mode 100644 index 0000000..5e0f5cb --- /dev/null +++ b/native/utils/grammar/utils/locale-shard-map.h @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_UTILS_LOCALE_SHARD_MAP_H_ +#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_UTILS_LOCALE_SHARD_MAP_H_ + +#include <string> +#include <unordered_map> +#include <vector> + +#include "utils/grammar/types.h" +#include "utils/i18n/locale-list.h" +#include "utils/i18n/locale.h" +#include "utils/optional.h" + +namespace libtextclassifier3::grammar { + +// Grammar rules are associated with Locale which serve as a filter during rule +// application. The class holds shard’s information for Locale which is used +// when the Aqua rules are compiled into internal rules.proto flatbuffer. +class LocaleShardMap { + public: + static LocaleShardMap CreateLocaleShardMap( + const std::vector<std::string>& locale_tags); + + std::vector<Locale> GetLocales(const int shard) const; + + int GetShard(const std::vector<Locale> locales) const; + int GetShard(const std::string& locale_tags) const; + + int GetNumberOfShards() const; + + private: + explicit LocaleShardMap() {} + void AddLocalTags(const std::string& locale_tag); + + std::unordered_map<int, std::vector<Locale>> shard_to_locale_data_; +}; + +} // namespace libtextclassifier3::grammar + +#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_UTILS_LOCALE_SHARD_MAP_H_ diff --git a/native/utils/grammar/utils/locale-shard-map_test.cc b/native/utils/grammar/utils/locale-shard-map_test.cc new file mode 100644 index 0000000..14c9081 --- /dev/null +++ b/native/utils/grammar/utils/locale-shard-map_test.cc @@ -0,0 +1,76 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/grammar/utils/locale-shard-map.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace libtextclassifier3::grammar { +namespace { + +using ::testing::SizeIs; + +TEST(LocaleShardMapTest, HandlesSimpleShard) { + LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap( + {"ar-EG", "bn-BD", "cs-CZ", "da-DK", "de-DE", "en-US", "es-ES", "fi-FI", + "fr-FR", "gu-IN", "id-ID", "it-IT", "ja-JP", "kn-IN", "ko-KR", "ml-IN", + "mr-IN", "nl-NL", "no-NO", "pl-PL", "pt-BR", "ru-RU", "sv-SE", "ta-IN", + "te-IN", "th-TH", "tr-TR", "uk-UA", "ur-PK", "vi-VN", "zh-TW"}); + + EXPECT_EQ(locale_shard_map.GetNumberOfShards(), 31); + for (int i = 0; i < 31; i++) { + EXPECT_THAT(locale_shard_map.GetLocales(i), SizeIs(1)); + } + EXPECT_EQ(locale_shard_map.GetLocales(0)[0], Locale::FromBCP47("ar-EG")); + EXPECT_EQ(locale_shard_map.GetLocales(8)[0], Locale::FromBCP47("fr-FR")); + EXPECT_EQ(locale_shard_map.GetLocales(16)[0], Locale::FromBCP47("mr-IN")); + EXPECT_EQ(locale_shard_map.GetLocales(24)[0], Locale::FromBCP47("te-IN")); + EXPECT_EQ(locale_shard_map.GetLocales(30)[0], Locale::FromBCP47("zh-TW")); +} + +TEST(LocaleTagShardTest, HandlesWildCard) { + LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({"*"}); + EXPECT_EQ(locale_shard_map.GetNumberOfShards(), 1); + EXPECT_THAT(locale_shard_map.GetLocales(0), SizeIs(1)); +} + +TEST(LocaleTagShardTest, HandlesMultipleLocalePerShard) { + LocaleShardMap locale_shard_map = + LocaleShardMap::CreateLocaleShardMap({"ar-EG,bn-BD,cs-CZ", "en-*"}); + EXPECT_EQ(locale_shard_map.GetNumberOfShards(), 2); + EXPECT_EQ(locale_shard_map.GetLocales(0)[0], Locale::FromBCP47("ar-EG")); + EXPECT_EQ(locale_shard_map.GetLocales(0)[1], Locale::FromBCP47("bn-BD")); + EXPECT_EQ(locale_shard_map.GetLocales(0)[2], Locale::FromBCP47("cs-CZ")); + EXPECT_EQ(locale_shard_map.GetLocales(1)[0], Locale::FromBCP47("en")); + + EXPECT_EQ(locale_shard_map.GetShard("ar-EG,bn-BD,cs-CZ"), 0); + EXPECT_EQ(locale_shard_map.GetShard("bn-BD,cs-CZ,ar-EG"), 0); + EXPECT_EQ(locale_shard_map.GetShard("bn-BD,ar-EG,cs-CZ"), 0); + EXPECT_EQ(locale_shard_map.GetShard("ar-EG,cs-CZ,bn-BD"), 0); +} + +TEST(LocaleTagShardTest, HandlesEmptyLocaleTag) { + LocaleShardMap locale_shard_map = + LocaleShardMap::CreateLocaleShardMap({"", "en-US"}); + EXPECT_EQ(locale_shard_map.GetNumberOfShards(), 2); + EXPECT_THAT(locale_shard_map.GetLocales(0), SizeIs(0)); + EXPECT_THAT(locale_shard_map.GetLocales(1), SizeIs(1)); + EXPECT_EQ(locale_shard_map.GetLocales(1)[0], Locale::FromBCP47("en-US")); +} + +} // namespace +} // namespace libtextclassifier3::grammar diff --git a/native/utils/grammar/utils/rules.cc b/native/utils/grammar/utils/rules.cc index 623124a..661514a 100644 --- a/native/utils/grammar/utils/rules.cc +++ b/native/utils/grammar/utils/rules.cc @@ -414,7 +414,7 @@ bool Rules::UsesFillers() const { } Ir Rules::Finalize(const std::set<std::string>& predefined_nonterminals) const { - Ir rules(num_shards_); + Ir rules(locale_shard_map_); std::unordered_map<int, Nonterm> nonterminal_ids; // Pending rules to process. diff --git a/native/utils/grammar/utils/rules.h b/native/utils/grammar/utils/rules.h index a6851f3..4931e2f 100644 --- a/native/utils/grammar/utils/rules.h +++ b/native/utils/grammar/utils/rules.h @@ -55,7 +55,8 @@ constexpr const char* kFiller = "<filler>"; // internal representation. class Rules { public: - explicit Rules(const int num_shards = 1) : num_shards_(num_shards) {} + explicit Rules(const LocaleShardMap& locale_shard_map) + : locale_shard_map_(locale_shard_map) {} // Represents one item in a right-hand side, a single terminal or nonterminal. struct RhsElement { @@ -214,7 +215,7 @@ class Rules { // Checks whether the fillers are used in any active rule. bool UsesFillers() const; - const int num_shards_; + const LocaleShardMap& locale_shard_map_; // Non-terminal to id map. std::unordered_map<std::string, int> nonterminal_names_; diff --git a/native/utils/grammar/utils/rules_test.cc b/native/utils/grammar/utils/rules_test.cc index 8db88ab..c71f2b4 100644 --- a/native/utils/grammar/utils/rules_test.cc +++ b/native/utils/grammar/utils/rules_test.cc @@ -28,7 +28,9 @@ using ::testing::IsEmpty; using ::testing::SizeIs; TEST(SerializeRulesTest, HandlesSimpleRuleSet) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); rules.Add("<verb>", {"buy"}); rules.Add("<verb>", {"bring"}); @@ -49,7 +51,9 @@ TEST(SerializeRulesTest, HandlesSimpleRuleSet) { } TEST(SerializeRulesTest, HandlesRulesSetWithCallbacks) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); const CallbackId output = 1; rules.Add("<verb>", {"buy"}); @@ -73,7 +77,9 @@ TEST(SerializeRulesTest, HandlesRulesSetWithCallbacks) { } TEST(SerializeRulesTest, HandlesRulesWithWhitespaceGapLimits) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); rules.Add("<iata>", {"lx"}); rules.Add("<iata>", {"aa"}); rules.Add("<flight>", {"<iata>", "<4_digits>"}, kNoCallback, 0, @@ -89,7 +95,9 @@ TEST(SerializeRulesTest, HandlesRulesWithWhitespaceGapLimits) { } TEST(SerializeRulesTest, HandlesCaseSensitiveTerminals) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); rules.Add("<iata>", {"LX"}, kNoCallback, 0, /*max_whitespace_gap=*/-1, /*case_sensitive=*/true); rules.Add("<iata>", {"AA"}, kNoCallback, 0, /*max_whitespace_gap=*/-1, @@ -109,7 +117,9 @@ TEST(SerializeRulesTest, HandlesCaseSensitiveTerminals) { } TEST(SerializeRulesTest, HandlesMultipleShards) { - Rules rules(/*num_shards=*/2); + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({"", "de"}); + Rules rules(locale_shard_map); rules.Add("<iata>", {"LX"}, kNoCallback, 0, /*max_whitespace_gap=*/-1, /*case_sensitive=*/true, /*shard=*/0); rules.Add("<iata>", {"aa"}, kNoCallback, 0, /*max_whitespace_gap=*/-1, @@ -124,7 +134,10 @@ TEST(SerializeRulesTest, HandlesMultipleShards) { } TEST(SerializeRulesTest, HandlesRegexRules) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); + // Rules rules; rules.AddRegex("<code>", "[A-Z]+"); rules.AddRegex("<numbers>", "\\d+"); RulesSetT frozen_rules; @@ -134,7 +147,9 @@ TEST(SerializeRulesTest, HandlesRegexRules) { } TEST(SerializeRulesTest, HandlesAlias) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); rules.Add("<iata>", {"lx"}); rules.Add("<iata>", {"aa"}); rules.Add("<flight>", {"<iata>", "<4_digits>"}); @@ -155,7 +170,9 @@ TEST(SerializeRulesTest, HandlesAlias) { } TEST(SerializeRulesTest, ResolvesAnchorsAndFillers) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); rules.Add("<code>", {"<^>", "<filler>", "this", "is", "a", "test", "<filler>", "<$>"}); const Ir ir = rules.Finalize(); @@ -177,7 +194,9 @@ TEST(SerializeRulesTest, ResolvesAnchorsAndFillers) { } TEST(SerializeRulesTest, HandlesFillers) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); rules.Add("<test>", {"<filler>?", "a", "test"}); const Ir ir = rules.Finalize(); RulesSetT frozen_rules; @@ -198,7 +217,9 @@ TEST(SerializeRulesTest, HandlesFillers) { } TEST(SerializeRulesTest, HandlesAnnotations) { - Rules rules; + grammar::LocaleShardMap locale_shard_map = + grammar::LocaleShardMap::CreateLocaleShardMap({""}); + Rules rules(locale_shard_map); rules.AddAnnotation("phone"); rules.AddAnnotation("url"); rules.AddAnnotation("tracking_number"); diff --git a/native/utils/i18n/locale.cc b/native/utils/i18n/locale.cc index d5a1109..3719079 100644 --- a/native/utils/i18n/locale.cc +++ b/native/utils/i18n/locale.cc @@ -16,6 +16,8 @@ #include "utils/i18n/locale.h" +#include <string> + #include "utils/strings/split.h" namespace libtextclassifier3 { @@ -196,6 +198,20 @@ bool Locale::IsAnyLocaleSupported(const std::vector<Locale>& locales, return false; } +bool Locale::operator==(const Locale& locale) const { + return language_ == locale.language_ && region_ == locale.region_ && + script_ == locale.script_; +} + +bool Locale::operator<(const Locale& locale) const { + return std::tie(language_, region_, script_) < + std::tie(locale.language_, locale.region_, locale.script_); +} + +bool Locale::operator!=(const Locale& locale) const { + return !(*this == locale); +} + logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream, const Locale& locale) { return stream << "Locale(language=" << locale.Language() diff --git a/native/utils/i18n/locale.h b/native/utils/i18n/locale.h index 308846d..036bacd 100644 --- a/native/utils/i18n/locale.h +++ b/native/utils/i18n/locale.h @@ -60,6 +60,10 @@ class Locale { const std::vector<Locale>& supported_locales, bool default_value); + bool operator==(const Locale& locale) const; + bool operator!=(const Locale& locale) const; + bool operator<(const Locale& locale) const; + private: Locale(const std::string& language, const std::string& script, const std::string& region) diff --git a/native/utils/tflite-model-executor.cc b/native/utils/tflite-model-executor.cc index 2dbd786..e491130 100644 --- a/native/utils/tflite-model-executor.cc +++ b/native/utils/tflite-model-executor.cc @@ -62,6 +62,15 @@ TfLiteRegistration* Register_WHERE(); TfLiteRegistration* Register_ONE_HOT(); TfLiteRegistration* Register_POW(); TfLiteRegistration* Register_TANH(); +#ifndef TC3_AOSP +TfLiteRegistration* Register_REDUCE_PROD(); +TfLiteRegistration* Register_SHAPE(); +TfLiteRegistration* Register_NOT_EQUAL(); +TfLiteRegistration* Register_CUMSUM(); +TfLiteRegistration* Register_EXPAND_DIMS(); +TfLiteRegistration* Register_FILL(); +TfLiteRegistration* Register_PADV2(); +#endif // TC3_AOSP } // namespace builtin } // namespace ops } // namespace tflite @@ -70,6 +79,17 @@ TfLiteRegistration* Register_TANH(); #include "utils/tflite/dist_diversification.h" #include "utils/tflite/text_encoder.h" #include "utils/tflite/token_encoder.h" +#ifndef TC3_AOSP +namespace tflite { +namespace ops { +namespace custom { +TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER(); +TfLiteRegistration* Register_RAGGED_TENSOR_TO_TENSOR(); +TfLiteRegistration* Register_RAGGED_RANGE(); +} // namespace custom +} // namespace ops +} // namespace tflite +#endif // TC3_AOSP void RegisterSelectedOps(tflite::MutableOpResolver* resolver) { resolver->AddBuiltin(tflite::BuiltinOperator_ADD, @@ -191,6 +211,22 @@ void RegisterSelectedOps(tflite::MutableOpResolver* resolver) { tflite::ops::builtin::Register_TANH(), /*min_version=*/1, /*max_version=*/1); +#ifndef TC3_AOSP + resolver->AddBuiltin(::tflite::BuiltinOperator_REDUCE_PROD, + ::tflite::ops::builtin::Register_REDUCE_PROD()); + resolver->AddBuiltin(::tflite::BuiltinOperator_SHAPE, + ::tflite::ops::builtin::Register_SHAPE()); + resolver->AddBuiltin(::tflite::BuiltinOperator_NOT_EQUAL, + ::tflite::ops::builtin::Register_NOT_EQUAL()); + resolver->AddBuiltin(::tflite::BuiltinOperator_CUMSUM, + ::tflite::ops::builtin::Register_CUMSUM()); + resolver->AddBuiltin(::tflite::BuiltinOperator_EXPAND_DIMS, + ::tflite::ops::builtin::Register_EXPAND_DIMS()); + resolver->AddBuiltin(::tflite::BuiltinOperator_FILL, + ::tflite::ops::builtin::Register_FILL()); + resolver->AddBuiltin(::tflite::BuiltinOperator_PADV2, + ::tflite::ops::builtin::Register_PADV2()); +#endif // TC3_AOSP } #else void RegisterSelectedOps(tflite::MutableOpResolver* resolver) { @@ -222,6 +258,16 @@ std::unique_ptr<tflite::OpResolver> BuildOpResolver( tflite::ops::custom::Register_TEXT_ENCODER()); resolver->AddCustom("TokenEncoder", tflite::ops::custom::Register_TOKEN_ENCODER()); +#ifndef TC3_AOSP + resolver->AddCustom( + "TFSentencepieceTokenizeOp", + ::tflite::ops::custom::Register_SENTENCEPIECE_TOKENIZER()); + resolver->AddCustom("RaggedRange", + ::tflite::ops::custom::Register_RAGGED_RANGE()); + resolver->AddCustom( + "RaggedTensorToTensor", + ::tflite::ops::custom::Register_RAGGED_TENSOR_TO_TENSOR()); +#endif // TC3_AOSP #endif // TC3_WITH_ACTIONS_OPS customize_fn(resolver.get()); return std::unique_ptr<tflite::OpResolver>(std::move(resolver)); diff --git a/native/utils/test-utils.cc b/native/utils/tokenizer-utils.cc index 9e3216a..7d07b0c 100644 --- a/native/utils/test-utils.cc +++ b/native/utils/tokenizer-utils.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "utils/test-utils.h" +#include "utils/tokenizer-utils.h" #include <iterator> diff --git a/native/utils/test-utils.h b/native/utils/tokenizer-utils.h index 184e60a..553791b 100644 --- a/native/utils/test-utils.h +++ b/native/utils/tokenizer-utils.h @@ -16,8 +16,8 @@ // Utilities for tests. -#ifndef LIBTEXTCLASSIFIER_UTILS_TEST_UTILS_H_ -#define LIBTEXTCLASSIFIER_UTILS_TEST_UTILS_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_TOKENIZER_UTILS_H_ +#define LIBTEXTCLASSIFIER_UTILS_TOKENIZER_UTILS_H_ #include <string> @@ -40,4 +40,4 @@ std::vector<Token> TokenizeOnDelimiters( } // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_UTILS_TEST_UTILS_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_TOKENIZER_UTILS_H_ diff --git a/native/utils/test-utils_test.cc b/native/utils/tokenizer-utils_test.cc index 88a3ec1..9c632bd 100644 --- a/native/utils/test-utils_test.cc +++ b/native/utils/tokenizer-utils_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "utils/test-utils.h" +#include "utils/tokenizer-utils.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -22,7 +22,7 @@ namespace libtextclassifier3 { namespace { -TEST(TestUtilTest, TokenizeOnSpace) { +TEST(TokenizerUtilTest, TokenizeOnSpace) { std::vector<Token> tokens = TokenizeOnSpace("Where is Jörg Borg located? Maybe in Zürich ..."); @@ -65,7 +65,7 @@ TEST(TestUtilTest, TokenizeOnSpace) { EXPECT_EQ(tokens[8].end, 47); } -TEST(TestUtilTest, TokenizeOnDelimiters) { +TEST(TokenizerUtilTest, TokenizeOnDelimiters) { std::vector<Token> tokens = TokenizeOnDelimiters( "This might be čomplíčateď?!: Oder?", {' ', '?', '!'}); @@ -96,7 +96,7 @@ TEST(TestUtilTest, TokenizeOnDelimiters) { EXPECT_EQ(tokens[5].end, 35); } -TEST(TestUtilTest, TokenizeOnDelimitersKeepNoSpace) { +TEST(TokenizerUtilTest, TokenizeOnDelimitersKeepNoSpace) { std::vector<Token> tokens = TokenizeOnDelimiters( "This might be čomplíčateď?!: Oder?", {' ', '?', '!'}, /* create_tokens_for_non_space_delimiters =*/true); diff --git a/native/utils/utf8/unicodetext.cc b/native/utils/utf8/unicodetext.cc index 2ddd38c..d05e377 100644 --- a/native/utils/utf8/unicodetext.cc +++ b/native/utils/utf8/unicodetext.cc @@ -210,6 +210,14 @@ std::vector<UnicodeText::const_iterator> UnicodeText::Codepoints() const { return codepoints; } +std::vector<char32> UnicodeText::CodepointsChar32() const { + std::vector<char32> codepoints; + for (auto it = begin(); it != end(); it++) { + codepoints.push_back(*it); + } + return codepoints; +} + bool UnicodeText::operator==(const UnicodeText& other) const { if (repr_.size_ != other.repr_.size_) { return false; diff --git a/native/utils/utf8/unicodetext.h b/native/utils/utf8/unicodetext.h index 4ca0dd2..4c1c3ce 100644 --- a/native/utils/utf8/unicodetext.h +++ b/native/utils/utf8/unicodetext.h @@ -178,6 +178,9 @@ class UnicodeText { // Returns an iterator for each codepoint. std::vector<const_iterator> Codepoints() const; + // Returns the list of codepoints of the UnicodeText. + std::vector<char32> CodepointsChar32() const; + std::string ToUTF8String() const; std::string UTF8Substring(int begin_codepoint, int end_codepoint) const; static std::string UTF8Substring(const const_iterator& it_begin, |