summaryrefslogtreecommitdiff
path: root/native
diff options
context:
space:
mode:
authorTony Mak <tonymak@google.com>2021-02-24 20:08:27 +0000
committerTony Mak <tonymak@google.com>2021-02-25 14:26:51 +0000
commit8a501057fd9d5a2c4c194bcd22de93691bc1c452 (patch)
tree0ffb1f53246bc6cfd075d4d23ca2578d6d1122c1 /native
parent2587b43b53b9643da23c118f53199132ab28b414 (diff)
downloadlibtextclassifier-8a501057fd9d5a2c4c194bcd22de93691bc1c452.tar.gz
Export libtextclassifier
Export libtextclassifier without the model downloader code to mainline-prod for a few changes we want for the upcoming mainline release. Test: atest -p external/libtextclassifier Test: Smart selection + smart reply on a R device Fixes: 179890518 Change-Id: I0da4487432920e3f95cb00d4f44c8ec257f4b81b
Diffstat (limited to 'native')
-rw-r--r--native/FlatBufferHeaders.bp9
-rw-r--r--native/actions/actions-suggestions.cc39
-rw-r--r--native/actions/actions-suggestions_test.cc9
-rwxr-xr-xnative/actions/actions_model.fbs16
-rw-r--r--native/actions/grammar-actions_test.cc34
-rw-r--r--native/actions/test_data/actions_suggestions_grammar_test.modelbin145160 -> 145160 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.modelbin3387344 -> 3387328 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_9heads.modelbin3874672 -> 3874464 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.modelbin3853552 -> 3853376 bytes
-rw-r--r--native/annotator/annotator.cc57
-rw-r--r--native/annotator/annotator.h6
-rw-r--r--native/annotator/annotator_test-include.cc47
-rwxr-xr-xnative/annotator/datetime/datetime.fbs145
-rw-r--r--native/annotator/duration/duration_test.cc2
-rw-r--r--native/annotator/feature-processor.cc7
-rw-r--r--native/annotator/feature-processor.h5
-rw-r--r--native/annotator/grammar/grammar-annotator_test.cc41
-rw-r--r--native/annotator/knowledge/knowledge-engine-dummy.h5
-rwxr-xr-xnative/annotator/model.fbs12
-rw-r--r--native/annotator/number/number_test-include.cc2
-rw-r--r--native/annotator/strip-unpaired-brackets.cc77
-rw-r--r--native/annotator/strip-unpaired-brackets.h13
-rw-r--r--native/annotator/types.h10
-rw-r--r--native/lang_id/common/lite_base/endian.h9
-rw-r--r--native/utils/base/endian.h18
-rw-r--r--native/utils/grammar/analyzer_test.cc4
-rw-r--r--native/utils/grammar/parsing/matcher_test.cc25
-rw-r--r--native/utils/grammar/parsing/parser_test.cc48
-rwxr-xr-xnative/utils/grammar/rules.fbs2
-rw-r--r--native/utils/grammar/semantics/composer_test.cc8
-rw-r--r--native/utils/grammar/utils/ir.cc24
-rw-r--r--native/utils/grammar/utils/ir.h9
-rw-r--r--native/utils/grammar/utils/ir_test.cc25
-rw-r--r--native/utils/grammar/utils/locale-shard-map.cc86
-rw-r--r--native/utils/grammar/utils/locale-shard-map.h55
-rw-r--r--native/utils/grammar/utils/locale-shard-map_test.cc76
-rw-r--r--native/utils/grammar/utils/rules.cc2
-rw-r--r--native/utils/grammar/utils/rules.h5
-rw-r--r--native/utils/grammar/utils/rules_test.cc41
-rw-r--r--native/utils/i18n/locale.cc16
-rw-r--r--native/utils/i18n/locale.h4
-rw-r--r--native/utils/tflite-model-executor.cc46
-rw-r--r--native/utils/tokenizer-utils.cc (renamed from native/utils/test-utils.cc)2
-rw-r--r--native/utils/tokenizer-utils.h (renamed from native/utils/test-utils.h)6
-rw-r--r--native/utils/tokenizer-utils_test.cc (renamed from native/utils/test-utils_test.cc)8
-rw-r--r--native/utils/utf8/unicodetext.cc8
-rw-r--r--native/utils/utf8/unicodetext.h3
47 files changed, 733 insertions, 333 deletions
diff --git a/native/FlatBufferHeaders.bp b/native/FlatBufferHeaders.bp
index 6248d2a..4212bbd 100644
--- a/native/FlatBufferHeaders.bp
+++ b/native/FlatBufferHeaders.bp
@@ -64,13 +64,6 @@ genrule {
}
genrule {
- name: "libtextclassifier_fbgen_annotator_datetime_datetime",
- srcs: ["annotator/datetime/datetime.fbs"],
- out: ["annotator/datetime/datetime_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
name: "libtextclassifier_fbgen_annotator_entity-data",
srcs: ["annotator/entity-data.fbs"],
out: ["annotator/entity-data_generated.h"],
@@ -185,7 +178,6 @@ cc_library_headers {
"libtextclassifier_fbgen_annotator_model",
"libtextclassifier_fbgen_annotator_person_name_person_name_model",
"libtextclassifier_fbgen_annotator_experimental_experimental",
- "libtextclassifier_fbgen_annotator_datetime_datetime",
"libtextclassifier_fbgen_annotator_entity-data",
"libtextclassifier_fbgen_utils_grammar_testing_value",
"libtextclassifier_fbgen_utils_grammar_semantics_expression",
@@ -209,7 +201,6 @@ cc_library_headers {
"libtextclassifier_fbgen_annotator_model",
"libtextclassifier_fbgen_annotator_person_name_person_name_model",
"libtextclassifier_fbgen_annotator_experimental_experimental",
- "libtextclassifier_fbgen_annotator_datetime_datetime",
"libtextclassifier_fbgen_annotator_entity-data",
"libtextclassifier_fbgen_utils_grammar_testing_value",
"libtextclassifier_fbgen_utils_grammar_semantics_expression",
diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc
index a9edde9..69235d7 100644
--- a/native/actions/actions-suggestions.cc
+++ b/native/actions/actions-suggestions.cc
@@ -72,6 +72,20 @@ int NumMessagesToConsider(const Conversation& conversation,
: max_conversation_history_length);
}
+template <typename T>
+std::vector<T> PadOrTruncateToTargetLength(const std::vector<T>& inputs,
+ const int max_length,
+ const T pad_value) {
+ if (inputs.size() >= max_length) {
+ return std::vector<T>(inputs.begin(), inputs.begin() + max_length);
+ } else {
+ std::vector<T> result;
+ result.reserve(max_length);
+ result.insert(result.begin(), inputs.begin(), inputs.end());
+ result.insert(result.end(), max_length - inputs.size(), pad_value);
+ return result;
+ }
+}
} // namespace
std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromUnownedBuffer(
@@ -639,8 +653,17 @@ bool ActionsSuggestions::SetupModelInput(
return false;
}
if (model_->tflite_model_spec()->input_context() >= 0) {
- model_executor_->SetInput<std::string>(
- model_->tflite_model_spec()->input_context(), context, interpreter);
+ if (model_->tflite_model_spec()->input_length_to_pad() > 0) {
+ model_executor_->SetInput<std::string>(
+ model_->tflite_model_spec()->input_context(),
+ PadOrTruncateToTargetLength(
+ context, model_->tflite_model_spec()->input_length_to_pad(),
+ std::string("")),
+ interpreter);
+ } else {
+ model_executor_->SetInput<std::string>(
+ model_->tflite_model_spec()->input_context(), context, interpreter);
+ }
}
if (model_->tflite_model_spec()->input_context_length() >= 0) {
model_executor_->SetInput<int>(
@@ -648,8 +671,16 @@ bool ActionsSuggestions::SetupModelInput(
interpreter);
}
if (model_->tflite_model_spec()->input_user_id() >= 0) {
- model_executor_->SetInput<int>(model_->tflite_model_spec()->input_user_id(),
- user_ids, interpreter);
+ if (model_->tflite_model_spec()->input_length_to_pad() > 0) {
+ model_executor_->SetInput<int>(
+ model_->tflite_model_spec()->input_user_id(),
+ PadOrTruncateToTargetLength(
+ user_ids, model_->tflite_model_spec()->input_length_to_pad(), 0),
+ interpreter);
+ } else {
+ model_executor_->SetInput<int>(
+ model_->tflite_model_spec()->input_user_id(), user_ids, interpreter);
+ }
}
if (model_->tflite_model_spec()->input_num_suggestions() >= 0) {
model_executor_->SetInput<int>(
diff --git a/native/actions/actions-suggestions_test.cc b/native/actions/actions-suggestions_test.cc
index 55aa852..ddaa604 100644
--- a/native/actions/actions-suggestions_test.cc
+++ b/native/actions/actions-suggestions_test.cc
@@ -29,6 +29,7 @@
#include "utils/flatbuffers/flatbuffers.h"
#include "utils/flatbuffers/flatbuffers_generated.h"
#include "utils/flatbuffers/mutable.h"
+#include "utils/grammar/utils/locale-shard-map.h"
#include "utils/grammar/utils/rules.h"
#include "utils/hash/farmhash.h"
#include "utils/jvm-test-utils.h"
@@ -1042,7 +1043,9 @@ TEST_F(ActionsSuggestionsTest, CreatesActionsFromGrammarRules) {
// Setup test rules.
action_grammar_rules->rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
rules.Add(
"<knock>", {"<^>", "ventura", "!?", "<$>"},
/*callback=*/
@@ -1102,7 +1105,9 @@ TEST_F(ActionsSuggestionsTest, CreatesActionsWithAnnotationsFromGrammarRules) {
// Setup test rules.
action_grammar_rules->rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
rules.Add(
"<event>", {"it", "is", "at", "<time>"},
/*callback=*/
diff --git a/native/actions/actions_model.fbs b/native/actions/actions_model.fbs
index 1548816..0db43f4 100755
--- a/native/actions/actions_model.fbs
+++ b/native/actions/actions_model.fbs
@@ -15,15 +15,15 @@
//
include "actions/actions-entity-data.fbs";
-include "utils/grammar/rules.fbs";
-include "utils/tokenizer.fbs";
-include "utils/flatbuffers/flatbuffers.fbs";
-include "utils/codepoint-range.fbs";
-include "utils/zlib/buffer.fbs";
-include "utils/normalization.fbs";
include "annotator/model.fbs";
+include "utils/codepoint-range.fbs";
+include "utils/flatbuffers/flatbuffers.fbs";
+include "utils/grammar/rules.fbs";
include "utils/intents/intent-config.fbs";
+include "utils/normalization.fbs";
include "utils/resources.fbs";
+include "utils/tokenizer.fbs";
+include "utils/zlib/buffer.fbs";
file_identifier "TC3A";
@@ -116,6 +116,10 @@ table TensorflowLiteModelSpec {
// Map of additional input tensor name to its index.
input_name_index:[TensorflowLiteModelSpec_.InputNameIndexEntry];
+
+ // If greater than 0, pad or truncate the input_user_id and input_context
+ // tensors to a length of input_length_to_pad.
+ input_length_to_pad:int = 0;
}
// Configuration for the tokenizer.
diff --git a/native/actions/grammar-actions_test.cc b/native/actions/grammar-actions_test.cc
index 9fe73d4..02deea9 100644
--- a/native/actions/grammar-actions_test.cc
+++ b/native/actions/grammar-actions_test.cc
@@ -37,6 +37,8 @@ namespace {
using ::testing::ElementsAre;
using ::testing::IsEmpty;
+using ::libtextclassifier3::grammar::LocaleShardMap;
+
class TestGrammarActions : public GrammarActions {
public:
explicit TestGrammarActions(
@@ -140,12 +142,14 @@ class GrammarActionsTest : public testing::Test {
};
TEST_F(GrammarActionsTest, ProducesSmartReplies) {
+ LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
+
// Create test rules.
// Rule: ^knock knock.?$ -> "Who's there?", "Yes?"
RulesModel_::GrammarRulesT action_grammar_rules;
SetTokenizerOptions(&action_grammar_rules);
action_grammar_rules.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
rules.Add(
"<knock>", {"<^>", "knock", "knock", ".?", "<$>"},
/*callback=*/
@@ -174,7 +178,8 @@ TEST_F(GrammarActionsTest, ProducesSmartRepliesFromCapturingMatches) {
RulesModel_::GrammarRulesT action_grammar_rules;
SetTokenizerOptions(&action_grammar_rules);
action_grammar_rules.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
+ LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
rules.Add(
"<scripted_reply>",
@@ -231,7 +236,8 @@ TEST_F(GrammarActionsTest, ProducesAnnotationsForActions) {
RulesModel_::GrammarRulesT action_grammar_rules;
SetTokenizerOptions(&action_grammar_rules);
action_grammar_rules.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
+ LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
rules.Add(
"<call_phone>", {"please", "dial", "<phone>"},
@@ -270,7 +276,9 @@ TEST_F(GrammarActionsTest, HandlesLocales) {
RulesModel_::GrammarRulesT action_grammar_rules;
SetTokenizerOptions(&action_grammar_rules);
action_grammar_rules.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules(/*num_shards=*/2);
+ LocaleShardMap locale_shard_map =
+ LocaleShardMap::CreateLocaleShardMap({"", "fr-CH"});
+ grammar::Rules rules(locale_shard_map);
rules.Add(
"<knock>", {"<^>", "knock", "knock", ".?", "<$>"},
/*callback=*/
@@ -333,7 +341,8 @@ TEST_F(GrammarActionsTest, HandlesAssertions) {
RulesModel_::GrammarRulesT action_grammar_rules;
SetTokenizerOptions(&action_grammar_rules);
action_grammar_rules.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
+ LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
rules.Add("<carrier>", {"lx"});
rules.Add("<carrier>", {"aa"});
rules.Add("<flight_code>", {"<2_digits>"});
@@ -387,7 +396,8 @@ TEST_F(GrammarActionsTest, SetsFixedEntityData) {
RulesModel_::GrammarRulesT action_grammar_rules;
SetTokenizerOptions(&action_grammar_rules);
action_grammar_rules.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
+ LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
// Create smart reply and static entity data.
const int spec_id =
@@ -440,7 +450,8 @@ TEST_F(GrammarActionsTest, SetsEntityDataFromCapturingMatches) {
RulesModel_::GrammarRulesT action_grammar_rules;
SetTokenizerOptions(&action_grammar_rules);
action_grammar_rules.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
+ LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
// Create smart reply and static entity data.
const int spec_id =
@@ -522,7 +533,8 @@ TEST_F(GrammarActionsTest, SetsFixedEntityDataFromCapturingGroups) {
RulesModel_::GrammarRulesT action_grammar_rules;
SetTokenizerOptions(&action_grammar_rules);
action_grammar_rules.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
+ LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
// Create smart reply.
const int spec_id =
@@ -572,7 +584,8 @@ TEST_F(GrammarActionsTest, ProducesActionsWithAnnotations) {
RulesModel_::GrammarRulesT action_grammar_rules;
SetTokenizerOptions(&action_grammar_rules);
action_grammar_rules.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
+ LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
rules.Add(
"<call_phone>", {"please", "dial", "<phone>"},
/*callback=*/
@@ -632,7 +645,8 @@ TEST_F(GrammarActionsTest, HandlesExclusions) {
SetTokenizerOptions(&action_grammar_rules);
action_grammar_rules.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
+ LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
rules.Add("<excluded>", {"be", "safe"});
rules.AddWithExclusion("<tokens_but_not_excluded>", {"<token>", "<token>"},
/*excluded_nonterminal=*/"<excluded>");
diff --git a/native/actions/test_data/actions_suggestions_grammar_test.model b/native/actions/test_data/actions_suggestions_grammar_test.model
index ae6dc60..1af22aa 100644
--- a/native/actions/test_data/actions_suggestions_grammar_test.model
+++ b/native/actions/test_data/actions_suggestions_grammar_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.model b/native/actions/test_data/actions_suggestions_test.model
index 52f932e..1361475 100644
--- a/native/actions/test_data/actions_suggestions_test.model
+++ b/native/actions/test_data/actions_suggestions_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
index 6145540..1396a46 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
index de8520a..660d97f 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
Binary files differ
diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc
index 2635820..53176e6 100644
--- a/native/annotator/annotator.cc
+++ b/native/annotator/annotator.cc
@@ -1017,7 +1017,13 @@ CodepointSpan Annotator::SuggestSelection(
return original_click_indices;
}
- return candidates.annotated_spans[0][i].span;
+ // Only return the suggested span if it contains the original span.
+ // This compensates for "select all" selection that may come from
+ // other apps. See http://b/179890518.
+ if (SpanContains(candidates.annotated_spans[0][i].span,
+ original_click_indices)) {
+ return candidates.annotated_spans[0][i].span;
+ }
}
}
@@ -1935,7 +1941,7 @@ std::vector<ClassificationResult> Annotator::ClassifyText(
bool Annotator::ModelAnnotate(
const std::string& context,
const std::vector<Locale>& detected_text_language_tags,
- const BaseOptions& options, InterpreterManager* interpreter_manager,
+ const AnnotationOptions& options, InterpreterManager* interpreter_manager,
std::vector<Token>* tokens, std::vector<AnnotatedSpan>* result) const {
if (model_->triggering_options() == nullptr ||
!(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
@@ -2008,13 +2014,41 @@ bool Annotator::ModelAnnotate(
}
const int offset = std::distance(context_unicode.begin(), line.first);
+ UnicodeText line_unicode;
+ std::vector<UnicodeText::const_iterator> line_codepoints;
+ if (options.enable_optimization) {
+ if (local_chunks.empty()) {
+ continue;
+ }
+ line_unicode = UTF8ToUnicodeText(line_str, /*do_copy=*/false);
+ line_codepoints = line_unicode.Codepoints();
+ line_codepoints.push_back(line_unicode.end());
+ }
for (const TokenSpan& chunk : local_chunks) {
CodepointSpan codepoint_span =
- selection_feature_processor_->StripBoundaryCodepoints(
- line_str, TokenSpanToCodepointSpan(line_tokens, chunk));
- if (model_->selection_options()->strip_unpaired_brackets()) {
- codepoint_span =
- StripUnpairedBrackets(context_unicode, codepoint_span, *unilib_);
+ TokenSpanToCodepointSpan(line_tokens, chunk);
+ if (options.enable_optimization) {
+ if (!codepoint_span.IsValid() ||
+ codepoint_span.second > line_codepoints.size()) {
+ continue;
+ }
+ codepoint_span = selection_feature_processor_->StripBoundaryCodepoints(
+ /*span_begin=*/line_codepoints[codepoint_span.first],
+ /*span_end=*/line_codepoints[codepoint_span.second],
+ codepoint_span);
+ if (model_->selection_options()->strip_unpaired_brackets()) {
+ codepoint_span = StripUnpairedBrackets(
+ /*span_begin=*/line_codepoints[codepoint_span.first],
+ /*span_end=*/line_codepoints[codepoint_span.second],
+ codepoint_span, *unilib_);
+ }
+ } else {
+ codepoint_span = selection_feature_processor_->StripBoundaryCodepoints(
+ line_str, codepoint_span);
+ if (model_->selection_options()->strip_unpaired_brackets()) {
+ codepoint_span =
+ StripUnpairedBrackets(context_unicode, codepoint_span, *unilib_);
+ }
}
// Skip empty spans.
@@ -3136,4 +3170,13 @@ bool Annotator::LookUpKnowledgeEntity(
knowledge_engine_->LookUpEntity(id, serialized_knowledge_result);
}
+StatusOr<std::string> Annotator::LookUpKnowledgeEntityProperty(
+ const std::string& mid_str, const std::string& property) const {
+ if (!knowledge_engine_) {
+ return Status(StatusCode::FAILED_PRECONDITION,
+ "knowledge_engine_ is nullptr");
+ }
+ return knowledge_engine_->LookUpEntityProperty(mid_str, property);
+}
+
} // namespace libtextclassifier3
diff --git a/native/annotator/annotator.h b/native/annotator/annotator.h
index 5397f56..a570a83 100644
--- a/native/annotator/annotator.h
+++ b/native/annotator/annotator.h
@@ -219,6 +219,10 @@ class Annotator {
bool LookUpKnowledgeEntity(const std::string& id,
std::string* serialized_knowledge_result) const;
+ // Looks up an entity's property.
+ StatusOr<std::string> LookUpKnowledgeEntityProperty(
+ const std::string& mid_str, const std::string& property) const;
+
const Model* model() const;
const reflection::Schema* entity_data_schema() const;
@@ -342,7 +346,7 @@ class Annotator {
// reuse.
bool ModelAnnotate(const std::string& context,
const std::vector<Locale>& detected_text_language_tags,
- const BaseOptions& options,
+ const AnnotationOptions& options,
InterpreterManager* interpreter_manager,
std::vector<Token>* tokens,
std::vector<AnnotatedSpan>* result) const;
diff --git a/native/annotator/annotator_test-include.cc b/native/annotator/annotator_test-include.cc
index 3ed91e1..b852827 100644
--- a/native/annotator/annotator_test-include.cc
+++ b/native/annotator/annotator_test-include.cc
@@ -27,6 +27,7 @@
#include "annotator/test-utils.h"
#include "annotator/types-test-util.h"
#include "annotator/types.h"
+#include "utils/grammar/utils/locale-shard-map.h"
#include "utils/grammar/utils/rules.h"
#include "utils/testing/annotator.h"
#include "lang_id/fb_model/lang-id-from-fb.h"
@@ -921,13 +922,13 @@ TEST_F(AnnotatorTest, SuggestSelection) {
// Unpaired bracket stripping.
EXPECT_EQ(
- classifier->SuggestSelection("call me at (857) 225 3556 today", {11, 16}),
+ classifier->SuggestSelection("call me at (857) 225 3556 today", {12, 14}),
CodepointSpan(11, 25));
- EXPECT_EQ(classifier->SuggestSelection("call me at (857 today", {11, 15}),
+ EXPECT_EQ(classifier->SuggestSelection("call me at (857 today", {12, 14}),
CodepointSpan(12, 15));
- EXPECT_EQ(classifier->SuggestSelection("call me at 3556) today", {11, 16}),
+ EXPECT_EQ(classifier->SuggestSelection("call me at 3556) today", {12, 14}),
CodepointSpan(11, 15));
- EXPECT_EQ(classifier->SuggestSelection("call me at )857( today", {11, 16}),
+ EXPECT_EQ(classifier->SuggestSelection("call me at )857( today", {12, 14}),
CodepointSpan(12, 15));
// If the resulting selection would be empty, the original span is returned.
@@ -937,6 +938,12 @@ TEST_F(AnnotatorTest, SuggestSelection) {
CodepointSpan(11, 12));
EXPECT_EQ(classifier->SuggestSelection("call me at ) today", {11, 12}),
CodepointSpan(11, 12));
+
+ // If the original span is larger than the found selection, the original span
+ // is returned.
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {5, 24}),
+ CodepointSpan(5, 24));
}
TEST_F(AnnotatorTest, SuggestSelectionDisabledFail) {
@@ -1238,6 +1245,34 @@ TEST_F(AnnotatorTest, AnnotatesWithBracketStripping) {
}));
}
+TEST_F(AnnotatorTest, AnnotatesWithBracketStrippingOptimized) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ AnnotationOptions options;
+ options.enable_optimization = true;
+
+ EXPECT_THAT(classifier->Annotate("call me at (0845) 100 1000 today", options),
+ ElementsAreArray({
+ IsAnnotatedSpan(11, 26, "phone"),
+ }));
+
+ // Unpaired bracket stripping.
+ EXPECT_THAT(classifier->Annotate("call me at (07038201818 today", options),
+ ElementsAreArray({
+ IsAnnotatedSpan(12, 23, "phone"),
+ }));
+ EXPECT_THAT(classifier->Annotate("call me at 07038201818) today", options),
+ ElementsAreArray({
+ IsAnnotatedSpan(11, 22, "phone"),
+ }));
+ EXPECT_THAT(classifier->Annotate("call me at )07038201818( today", options),
+ ElementsAreArray({
+ IsAnnotatedSpan(12, 23, "phone"),
+ }));
+}
+
TEST_F(AnnotatorTest, AnnotatesOverlappingNumbers) {
std::unique_ptr<Annotator> classifier = Annotator::FromPath(
GetTestModelPath(), unilib_.get(), calendarlib_.get());
@@ -1743,7 +1778,9 @@ TEST_F(AnnotatorTest, AnnotatesWithGrammarModel) {
// Add test rules.
grammar_model->rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
rules.Add("<tv_detective>", {"jessica", "fletcher"});
rules.Add("<tv_detective>", {"columbo"});
rules.Add("<tv_detective>", {"magnum"});
diff --git a/native/annotator/datetime/datetime.fbs b/native/annotator/datetime/datetime.fbs
deleted file mode 100755
index 8012cdc..0000000
--- a/native/annotator/datetime/datetime.fbs
+++ /dev/null
@@ -1,145 +0,0 @@
-//
-// Copyright (C) 2018 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-// Meridiem field.
-namespace libtextclassifier3.grammar.datetime;
-enum Meridiem : int {
- UNKNOWN = 0,
-
- // Ante meridiem: Before noon
- AM = 1,
-
- // Post meridiem: After noon
- PM = 2,
-}
-
-// Enum represents a unit of date and time in the expression.
-// Next field: 10
-namespace libtextclassifier3.grammar.datetime;
-enum ComponentType : int {
- UNSPECIFIED = 0,
-
- // Year of the date seen in the text match.
- YEAR = 1,
-
- // Month of the year starting with January = 1.
- MONTH = 2,
-
- // Week (7 days).
- WEEK = 3,
-
- // Day of week, start of the week is Sunday & its value is 1.
- DAY_OF_WEEK = 4,
-
- // Day of the month starting with 1.
- DAY_OF_MONTH = 5,
-
- // Hour of the day.
- HOUR = 6,
-
- // Minute of the hour with a range of 0-59.
- MINUTE = 7,
-
- // Seconds of the minute with a range of 0-59.
- SECOND = 8,
-
- // Meridiem field i.e. AM/PM.
- MERIDIEM = 9,
-}
-
-namespace libtextclassifier3.grammar.datetime;
-table TimeZone {
- // Offset from UTC/GTM in minutes.
- utc_offset_mins:int;
-}
-
-namespace libtextclassifier3.grammar.datetime.RelativeDatetimeComponent_;
-enum Modifier : int {
- UNSPECIFIED = 0,
- NEXT = 1,
- THIS = 2,
- LAST = 3,
- NOW = 4,
- TOMORROW = 5,
- YESTERDAY = 6,
-}
-
-// Message for representing the relative date-time component in date-time
-// expressions.
-// Next field: 4
-namespace libtextclassifier3.grammar.datetime;
-table RelativeDatetimeComponent {
- component_type:ComponentType = UNSPECIFIED;
- modifier:RelativeDatetimeComponent_.Modifier = UNSPECIFIED;
- value:int;
-}
-
-// AbsoluteDateTime represents date-time expressions that is not ambiguous.
-// Next field: 11
-namespace libtextclassifier3.grammar.datetime;
-table AbsoluteDateTime {
- // Year value of the date seen in the text match.
- year:int = -1;
-
- // Month value of the year starting with January = 1.
- month:int = -1;
-
- // Day value of the month starting with 1.
- day:int = -1;
-
- // Day of week, start of the week is Sunday and its value is 1.
- week_day:int = -1;
-
- // Hour value of the day.
- hour:int = -1;
-
- // Minute value of the hour with a range of 0-59.
- minute:int = -1;
-
- // Seconds value of the minute with a range of 0-59.
- second:int = -1;
-
- partial_second:double = -1;
-
- // Meridiem field i.e. AM/PM.
- meridiem:Meridiem;
-
- time_zone:TimeZone;
-}
-
-// Message to represent relative datetime expressions.
-// It encode expressions
-// - Where modifier such as before/after shift the date e.g.[three days ago],
-// [2 days after March 1st].
-// - When prefix make the expression relative e.g. [next weekend],
-// [last Monday].
-// Next field: 3
-namespace libtextclassifier3.grammar.datetime;
-table RelativeDateTime {
- relative_datetime_component:[RelativeDatetimeComponent];
-
- // The base could be an absolute datetime point for example: "March 1", a
- // relative datetime point, for example: "2 days before March 1"
- base:AbsoluteDateTime;
-}
-
-// Datetime result.
-namespace libtextclassifier3.grammar.datetime;
-table UngroundedDatetime {
- absolute_datetime:AbsoluteDateTime;
- relative_datetime:RelativeDateTime;
-}
-
diff --git a/native/annotator/duration/duration_test.cc b/native/annotator/duration/duration_test.cc
index f5e0510..7c07a72 100644
--- a/native/annotator/duration/duration_test.cc
+++ b/native/annotator/duration/duration_test.cc
@@ -23,7 +23,7 @@
#include "annotator/model_generated.h"
#include "annotator/types-test-util.h"
#include "annotator/types.h"
-#include "utils/test-utils.h"
+#include "utils/tokenizer-utils.h"
#include "utils/utf8/unicodetext.h"
#include "utils/utf8/unilib.h"
#include "gmock/gmock.h"
diff --git a/native/annotator/feature-processor.cc b/native/annotator/feature-processor.cc
index 3831c5f..99e25e1 100644
--- a/native/annotator/feature-processor.cc
+++ b/native/annotator/feature-processor.cc
@@ -474,6 +474,13 @@ bool FeatureProcessor::SelectionLabelSpans(
return true;
}
+bool FeatureProcessor::SelectionLabelRelativeTokenSpans(
+ std::vector<TokenSpan>* selection_label_relative_token_spans) const {
+ selection_label_relative_token_spans->assign(label_to_selection_.begin(),
+ label_to_selection_.end());
+ return true;
+}
+
void FeatureProcessor::PrepareIgnoredSpanBoundaryCodepoints() {
if (options_->ignored_span_boundary_codepoints() != nullptr) {
for (const int codepoint : *options_->ignored_span_boundary_codepoints()) {
diff --git a/native/annotator/feature-processor.h b/native/annotator/feature-processor.h
index 3b865b0..482d274 100644
--- a/native/annotator/feature-processor.h
+++ b/native/annotator/feature-processor.h
@@ -165,6 +165,11 @@ class FeatureProcessor {
VectorSpan<Token> tokens,
std::vector<CodepointSpan>* selection_label_spans) const;
+ // Fills selection_label_relative_token_spans with the number of tokens to
+ // the left and right of the click.
+ bool SelectionLabelRelativeTokenSpans(
+ std::vector<TokenSpan>* selection_label_relative_token_spans) const;
+
int DenseFeaturesCount() const {
return feature_extractor_.DenseFeaturesCount();
}
diff --git a/native/annotator/grammar/grammar-annotator_test.cc b/native/annotator/grammar/grammar-annotator_test.cc
index b2084cb..6fcd1f5 100644
--- a/native/annotator/grammar/grammar-annotator_test.cc
+++ b/native/annotator/grammar/grammar-annotator_test.cc
@@ -23,6 +23,7 @@
#include "annotator/model_generated.h"
#include "utils/flatbuffers/flatbuffers.h"
#include "utils/flatbuffers/mutable.h"
+#include "utils/grammar/utils/locale-shard-map.h"
#include "utils/grammar/utils/rules.h"
#include "utils/tokenizer.h"
#include "utils/utf8/unicodetext.h"
@@ -45,7 +46,9 @@ TEST_F(GrammarAnnotatorTest, AnnotesWithGrammarRules) {
GrammarModelT grammar_model;
SetTestTokenizerOptions(&grammar_model);
grammar_model.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
rules.Add("<carrier>", {"lx"});
rules.Add("<carrier>", {"aa"});
rules.Add("<flight_code>", {"<2_digits>"});
@@ -79,7 +82,9 @@ TEST_F(GrammarAnnotatorTest, HandlesAssertions) {
GrammarModelT grammar_model;
SetTestTokenizerOptions(&grammar_model);
grammar_model.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
rules.Add("<carrier>", {"lx"});
rules.Add("<carrier>", {"aa"});
rules.Add("<flight_code>", {"<2_digits>"});
@@ -120,7 +125,9 @@ TEST_F(GrammarAnnotatorTest, HandlesCapturingGroups) {
GrammarModelT grammar_model;
SetTestTokenizerOptions(&grammar_model);
grammar_model.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
rules.AddValueMapping("<low_confidence_phone>", {"<digits>"},
/*value=*/0);
@@ -157,7 +164,9 @@ TEST_F(GrammarAnnotatorTest, ClassifiesTextWithGrammarRules) {
GrammarModelT grammar_model;
SetTestTokenizerOptions(&grammar_model);
grammar_model.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
rules.Add("<carrier>", {"lx"});
rules.Add("<carrier>", {"aa"});
rules.Add("<flight_code>", {"<2_digits>"});
@@ -194,7 +203,9 @@ TEST_F(GrammarAnnotatorTest, ClassifiesTextWithAssertions) {
grammar_model.context_left_num_tokens = -1;
grammar_model.context_right_num_tokens = -1;
- grammar::Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
rules.Add("<carrier>", {"lx"});
rules.Add("<carrier>", {"aa"});
rules.Add("<flight_code>", {"<2_digits>"});
@@ -254,7 +265,9 @@ TEST_F(GrammarAnnotatorTest, ClassifiesTextWithContext) {
grammar_model.context_left_num_tokens = 3;
grammar_model.context_right_num_tokens = 0;
- grammar::Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
rules.Add("<tracking_number>", {"<5_digits>"});
rules.Add("<tracking_number>", {"<6_digits>"});
rules.Add("<tracking_number>", {"<7_digits>"});
@@ -306,7 +319,9 @@ TEST_F(GrammarAnnotatorTest, SuggestsTextSelection) {
GrammarModelT grammar_model;
SetTestTokenizerOptions(&grammar_model);
grammar_model.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
rules.Add("<carrier>", {"lx"});
rules.Add("<carrier>", {"aa"});
rules.Add("<flight_code>", {"<2_digits>"});
@@ -338,7 +353,9 @@ TEST_F(GrammarAnnotatorTest, SetsFixedEntityData) {
GrammarModelT grammar_model;
SetTestTokenizerOptions(&grammar_model);
grammar_model.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
const int person_result =
AddRuleClassificationResult("person", ModeFlag_ALL, 1.0, &grammar_model);
rules.Add(
@@ -382,7 +399,9 @@ TEST_F(GrammarAnnotatorTest, SetsEntityDataFromCapturingMatches) {
GrammarModelT grammar_model;
SetTestTokenizerOptions(&grammar_model);
grammar_model.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
const int person_result =
AddRuleClassificationResult("person", ModeFlag_ALL, 1.0, &grammar_model);
@@ -438,7 +457,9 @@ TEST_F(GrammarAnnotatorTest, RespectsRuleModes) {
GrammarModelT grammar_model;
SetTestTokenizerOptions(&grammar_model);
grammar_model.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
rules.Add("<classification_carrier>", {"ei"});
rules.Add("<classification_carrier>", {"en"});
rules.Add("<selection_carrier>", {"ai"});
diff --git a/native/annotator/knowledge/knowledge-engine-dummy.h b/native/annotator/knowledge/knowledge-engine-dummy.h
index 615ad06..b6c4f42 100644
--- a/native/annotator/knowledge/knowledge-engine-dummy.h
+++ b/native/annotator/knowledge/knowledge-engine-dummy.h
@@ -66,6 +66,11 @@ class KnowledgeEngine {
std::string* serialized_knowledge_result) const {
return false;
}
+
+ // Stub for the dummy engine: performs no lookup, ignores both arguments and
+ // always returns an UNIMPLEMENTED status.
+ StatusOr<std::string> LookUpEntityProperty(
+ const std::string& mid_str, const std::string& property) const {
+ return Status(StatusCode::UNIMPLEMENTED, "Not implemented");
+ }
};
} // namespace libtextclassifier3
diff --git a/native/annotator/model.fbs b/native/annotator/model.fbs
index dbbb422..f639f06 100755
--- a/native/annotator/model.fbs
+++ b/native/annotator/model.fbs
@@ -14,17 +14,17 @@
// limitations under the License.
//
-include "utils/intents/intent-config.fbs";
-include "annotator/experimental/experimental.fbs";
include "annotator/entity-data.fbs";
+include "annotator/experimental/experimental.fbs";
+include "utils/codepoint-range.fbs";
+include "utils/container/bit-vector.fbs";
+include "utils/flatbuffers/flatbuffers.fbs";
include "utils/grammar/rules.fbs";
+include "utils/intents/intent-config.fbs";
include "utils/normalization.fbs";
-include "utils/tokenizer.fbs";
include "utils/resources.fbs";
-include "utils/codepoint-range.fbs";
-include "utils/flatbuffers/flatbuffers.fbs";
+include "utils/tokenizer.fbs";
include "utils/zlib/buffer.fbs";
-include "utils/container/bit-vector.fbs";
file_identifier "TC2 ";
diff --git a/native/annotator/number/number_test-include.cc b/native/annotator/number/number_test-include.cc
index d95d388..f47933f 100644
--- a/native/annotator/number/number_test-include.cc
+++ b/native/annotator/number/number_test-include.cc
@@ -23,7 +23,7 @@
#include "annotator/model_generated.h"
#include "annotator/types-test-util.h"
#include "annotator/types.h"
-#include "utils/test-utils.h"
+#include "utils/tokenizer-utils.h"
#include "utils/utf8/unicodetext.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
diff --git a/native/annotator/strip-unpaired-brackets.cc b/native/annotator/strip-unpaired-brackets.cc
index df1fcce..8bf93d9 100644
--- a/native/annotator/strip-unpaired-brackets.cc
+++ b/native/annotator/strip-unpaired-brackets.cc
@@ -22,59 +22,23 @@
#include "utils/utf8/unicodetext.h"
namespace libtextclassifier3 {
-namespace {
-// Returns true if given codepoint is contained in the given span in context.
-bool IsCodepointInSpan(const char32 codepoint,
- const UnicodeText& context_unicode,
- const CodepointSpan span) {
- auto begin_it = context_unicode.begin();
- std::advance(begin_it, span.first);
- auto end_it = context_unicode.begin();
- std::advance(end_it, span.second);
-
- return std::find(begin_it, end_it, codepoint) != end_it;
-}
-
-// Returns the first codepoint of the span.
-char32 FirstSpanCodepoint(const UnicodeText& context_unicode,
- const CodepointSpan span) {
- auto it = context_unicode.begin();
- std::advance(it, span.first);
- return *it;
-}
-
-// Returns the last codepoint of the span.
-char32 LastSpanCodepoint(const UnicodeText& context_unicode,
- const CodepointSpan span) {
- auto it = context_unicode.begin();
- std::advance(it, span.second - 1);
- return *it;
-}
-
-} // namespace
-
-CodepointSpan StripUnpairedBrackets(const std::string& context,
- CodepointSpan span, const UniLib& unilib) {
- const UnicodeText context_unicode =
- UTF8ToUnicodeText(context, /*do_copy=*/false);
- return StripUnpairedBrackets(context_unicode, span, unilib);
-}
-
-// If the first or the last codepoint of the given span is a bracket, the
-// bracket is stripped if the span does not contain its corresponding paired
-// version.
-CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode,
- CodepointSpan span, const UniLib& unilib) {
- if (context_unicode.empty() || !span.IsValid() || span.IsEmpty()) {
+CodepointSpan StripUnpairedBrackets(
+ const UnicodeText::const_iterator& span_begin,
+ const UnicodeText::const_iterator& span_end, CodepointSpan span,
+ const UniLib& unilib) {
+ if (span_begin == span_end || !span.IsValid() || span.IsEmpty()) {
return span;
}
- const char32 begin_char = FirstSpanCodepoint(context_unicode, span);
+ UnicodeText::const_iterator begin = span_begin;
+ const UnicodeText::const_iterator end = span_end;
+ const char32 begin_char = *begin;
const char32 paired_begin_char = unilib.GetPairedBracket(begin_char);
if (paired_begin_char != begin_char) {
if (!unilib.IsOpeningBracket(begin_char) ||
- !IsCodepointInSpan(paired_begin_char, context_unicode, span)) {
+ std::find(begin, end, paired_begin_char) == end) {
+ ++begin;
++span.first;
}
}
@@ -83,11 +47,11 @@ CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode,
return span;
}
- const char32 end_char = LastSpanCodepoint(context_unicode, span);
+ const char32 end_char = *std::prev(end);
const char32 paired_end_char = unilib.GetPairedBracket(end_char);
if (paired_end_char != end_char) {
if (!unilib.IsClosingBracket(end_char) ||
- !IsCodepointInSpan(paired_end_char, context_unicode, span)) {
+ std::find(begin, end, paired_end_char) == end) {
--span.second;
}
}
@@ -102,4 +66,21 @@ CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode,
return span;
}
+// Overload for a context that is already available as UnicodeText: takes a
+// non-copying view of the codepoints covered by `span` and strips unpaired
+// brackets from that view. Invalid or empty spans are returned unchanged.
+CodepointSpan StripUnpairedBrackets(const UnicodeText& context,
+                                    CodepointSpan span, const UniLib& unilib) {
+  const bool has_content = span.IsValid() && !span.IsEmpty();
+  if (!has_content) {
+    return span;
+  }
+  const UnicodeText covered_text = UnicodeText::Substring(
+      context, span.first, span.second, /*do_copy=*/false);
+  const UnicodeText::const_iterator covered_begin = covered_text.begin();
+  const UnicodeText::const_iterator covered_end = covered_text.end();
+  return StripUnpairedBrackets(covered_begin, covered_end, span, unilib);
+}
+
+// Overload for a UTF-8 encoded context: wraps `context` as UnicodeText
+// (without copying the bytes) and defers to the UnicodeText overload.
+CodepointSpan StripUnpairedBrackets(const std::string& context,
+                                    CodepointSpan span, const UniLib& unilib) {
+  const UnicodeText context_unicode =
+      UTF8ToUnicodeText(context, /*do_copy=*/false);
+  return StripUnpairedBrackets(context_unicode, span, unilib);
+}
+
} // namespace libtextclassifier3
diff --git a/native/annotator/strip-unpaired-brackets.h b/native/annotator/strip-unpaired-brackets.h
index ceb8d60..c6cdc1a 100644
--- a/native/annotator/strip-unpaired-brackets.h
+++ b/native/annotator/strip-unpaired-brackets.h
@@ -23,14 +23,21 @@
#include "utils/utf8/unilib.h"
namespace libtextclassifier3 {
+
// If the first or the last codepoint of the given span is a bracket, the
// bracket is stripped if the span does not contain its corresponding paired
// version.
-CodepointSpan StripUnpairedBrackets(const std::string& context,
+CodepointSpan StripUnpairedBrackets(
+ const UnicodeText::const_iterator& span_begin,
+ const UnicodeText::const_iterator& span_end, CodepointSpan span,
+ const UniLib& unilib);
+
+// Same as above but takes the context as UnicodeText; `span` indexes into it.
+CodepointSpan StripUnpairedBrackets(const UnicodeText& context,
CodepointSpan span, const UniLib& unilib);
-// Same as above but takes UnicodeText instance directly.
-CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode,
+// Same as above but takes a string instance.
+CodepointSpan StripUnpairedBrackets(const std::string& context,
CodepointSpan span, const UniLib& unilib);
} // namespace libtextclassifier3
diff --git a/native/annotator/types.h b/native/annotator/types.h
index 3063838..45999cd 100644
--- a/native/annotator/types.h
+++ b/native/annotator/types.h
@@ -101,6 +101,11 @@ inline bool SpansOverlap(const CodepointSpan& a, const CodepointSpan& b) {
return a.first < b.second && b.first < a.second;
}
+// Returns true iff `sub_span` lies entirely within `span`, i.e. it starts no
+// earlier and ends no later than `span` (identical bounds count as contained).
+inline bool SpanContains(const CodepointSpan& span,
+                         const CodepointSpan& sub_span) {
+  const bool starts_inside = sub_span.first >= span.first;
+  const bool ends_inside = sub_span.second <= span.second;
+  return starts_inside && ends_inside;
+}
+
template <typename T>
bool DoesCandidateConflict(
const int considered_candidate, const std::vector<T>& candidates,
@@ -610,6 +615,11 @@ struct AnnotationOptions : public BaseOptions, public DatetimeOptions {
// If true, trigger dictionary on words that are of beginner level.
bool trigger_dictionary_on_beginner_words = false;
+ // If true, enables an optimized code path for annotation.
+ // The optimization caused crashes previously, which is why we are rolling it
+ // out using this temporary flag. See: b/178503899
+ bool enable_optimization = false;
+
bool operator==(const AnnotationOptions& other) const {
return this->is_serialized_entity_data_enabled ==
other.is_serialized_entity_data_enabled &&
diff --git a/native/lang_id/common/lite_base/endian.h b/native/lang_id/common/lite_base/endian.h
index 16c2dca..2e3ee26 100644
--- a/native/lang_id/common/lite_base/endian.h
+++ b/native/lang_id/common/lite_base/endian.h
@@ -93,15 +93,6 @@ class LittleEndian {
// Conversion functions.
#ifdef SAFTM_IS_LITTLE_ENDIAN
- static uint16 FromHost16(uint16 x) { return x; }
- static uint16 ToHost16(uint16 x) { return x; }
-
- static uint32 FromHost32(uint32 x) { return x; }
- static uint32 ToHost32(uint32 x) { return x; }
-
- static uint64 FromHost64(uint64 x) { return x; }
- static uint64 ToHost64(uint64 x) { return x; }
-
static bool IsLittleEndian() { return true; }
#elif defined SAFTM_IS_BIG_ENDIAN
diff --git a/native/utils/base/endian.h b/native/utils/base/endian.h
index 9312704..810bc46 100644
--- a/native/utils/base/endian.h
+++ b/native/utils/base/endian.h
@@ -53,8 +53,8 @@ namespace libtextclassifier3 {
#define bswap_64(x) OSSwapInt64(x)
#endif // !defined(bswap_16)
#else
-#define GG_LONGLONG(x) x##LL
-#define GG_ULONGLONG(x) x##ULL
+// GG_LONGLONG/GG_ULONGLONG are dropped without replacement: bswap_64() below
+// spells its masks as uint64_t{...}, so no literal-suffix macro is needed.
static inline uint16 bswap_16(uint16 x) {
return (uint16)(((x & 0xFF) << 8) | ((x & 0xFF00) >> 8)); // NOLINT
}
@@ -65,14 +65,12 @@ static inline uint32 bswap_32(uint32 x) {
}
#define bswap_32(x) bswap_32(x)
static inline uint64 bswap_64(uint64 x) {
- return (((x & GG_ULONGLONG(0xFF)) << 56) |
- ((x & GG_ULONGLONG(0xFF00)) << 40) |
- ((x & GG_ULONGLONG(0xFF0000)) << 24) |
- ((x & GG_ULONGLONG(0xFF000000)) << 8) |
- ((x & GG_ULONGLONG(0xFF00000000)) >> 8) |
- ((x & GG_ULONGLONG(0xFF0000000000)) >> 24) |
- ((x & GG_ULONGLONG(0xFF000000000000)) >> 40) |
- ((x & GG_ULONGLONG(0xFF00000000000000)) >> 56));
+ return (((x & uint64_t{0xFF}) << 56) | ((x & uint64_t{0xFF00}) << 40) |
+ ((x & uint64_t{0xFF0000}) << 24) | ((x & uint64_t{0xFF000000}) << 8) |
+ ((x & uint64_t{0xFF00000000}) >> 8) |
+ ((x & uint64_t{0xFF0000000000}) >> 24) |
+ ((x & uint64_t{0xFF000000000000}) >> 40) |
+ ((x & uint64_t{0xFF00000000000000}) >> 56));
}
#define bswap_64(x) bswap_64(x)
#endif
diff --git a/native/utils/grammar/analyzer_test.cc b/native/utils/grammar/analyzer_test.cc
index 9f71efe..4950fb4 100644
--- a/native/utils/grammar/analyzer_test.cc
+++ b/native/utils/grammar/analyzer_test.cc
@@ -38,7 +38,9 @@ TEST_F(AnalyzerTest, ParsesTextWithGrammar) {
semantic_values_schema_.buffer().end());
// Define rules and semantics.
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.Add("<month>", {"january"},
static_cast<CallbackId>(DefaultCallback::kSemanticExpression),
/*callback_param=*/model.semantic_expression.size());
diff --git a/native/utils/grammar/parsing/matcher_test.cc b/native/utils/grammar/parsing/matcher_test.cc
index 8528009..7c9a14d 100644
--- a/native/utils/grammar/parsing/matcher_test.cc
+++ b/native/utils/grammar/parsing/matcher_test.cc
@@ -123,7 +123,9 @@ class MatcherTest : public testing::Test {
TEST_F(MatcherTest, HandlesBasicOperations) {
// Create an example grammar.
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.Add("<test>", {"the", "quick", "brown", "fox"},
static_cast<CallbackId>(DefaultCallback::kRootRule));
rules.Add("<action>", {"<test>"},
@@ -146,7 +148,9 @@ TEST_F(MatcherTest, HandlesBasicOperations) {
std::string CreateTestGrammar() {
// Create an example grammar.
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
// Callbacks on terminal rules.
rules.Add("<output_5>", {"quick"},
@@ -260,7 +264,9 @@ TEST_F(MatcherTest, HandlesManualAddParseTreeCalls) {
}
TEST_F(MatcherTest, HandlesOptionalRuleElements) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.Add("<output_0>", {"a?", "b?", "c?", "d?", "e"},
static_cast<CallbackId>(DefaultCallback::kRootRule));
rules.Add("<output_1>", {"a", "b?", "c", "d?", "e"},
@@ -293,7 +299,9 @@ TEST_F(MatcherTest, HandlesOptionalRuleElements) {
}
TEST_F(MatcherTest, HandlesWhitespaceGapLimits) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.Add("<iata>", {"lx"});
rules.Add("<iata>", {"aa"});
// Require no whitespace between code and flight number.
@@ -331,7 +339,9 @@ TEST_F(MatcherTest, HandlesWhitespaceGapLimits) {
}
TEST_F(MatcherTest, HandlesCaseSensitiveTerminals) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.Add("<iata>", {"LX"}, /*callback=*/kNoCallback, 0,
/*max_whitespace_gap*/ -1, /*case_sensitive=*/true);
rules.Add("<iata>", {"AA"}, /*callback=*/kNoCallback, 0,
@@ -383,7 +393,10 @@ TEST_F(MatcherTest, HandlesCaseSensitiveTerminals) {
}
TEST_F(MatcherTest, HandlesExclusions) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+
rules.Add("<all_zeros>", {"0000"});
rules.AddWithExclusion("<flight_code>", {"<4_digits>"},
/*excluded_nonterminal=*/"<all_zeros>");
diff --git a/native/utils/grammar/parsing/parser_test.cc b/native/utils/grammar/parsing/parser_test.cc
index cf8310b..183be0e 100644
--- a/native/utils/grammar/parsing/parser_test.cc
+++ b/native/utils/grammar/parsing/parser_test.cc
@@ -41,7 +41,9 @@ using ::testing::IsEmpty;
class ParserTest : public GrammarTest {};
TEST_F(ParserTest, ParsesSimpleRules) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.Add("<day>", {"<2_digits>"});
rules.Add("<month>", {"<2_digits>"});
rules.Add("<year>", {"<4_digits>"});
@@ -58,7 +60,9 @@ TEST_F(ParserTest, ParsesSimpleRules) {
}
TEST_F(ParserTest, HandlesEmptyInput) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
constexpr int kTest = 0;
rules.Add("<test>", {"test"},
static_cast<CallbackId>(DefaultCallback::kRootRule), kTest);
@@ -80,7 +84,9 @@ TEST_F(ParserTest, HandlesEmptyInput) {
}
TEST_F(ParserTest, HandlesUppercaseTokens) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
constexpr int kScriptedReply = 0;
rules.Add("<test>", {"please?", "reply", "<uppercase_token>"},
static_cast<CallbackId>(DefaultCallback::kRootRule),
@@ -99,7 +105,9 @@ TEST_F(ParserTest, HandlesUppercaseTokens) {
}
TEST_F(ParserTest, HandlesAnchors) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
constexpr int kScriptedReply = 0;
rules.Add("<test>", {"<^>", "reply", "<uppercase_token>", "<$>"},
static_cast<CallbackId>(DefaultCallback::kRootRule),
@@ -118,7 +126,9 @@ TEST_F(ParserTest, HandlesAnchors) {
}
TEST_F(ParserTest, HandlesWordBreaks) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.Add("<carrier>", {"lx"});
rules.Add("<carrier>", {"aa"});
constexpr int kFlight = 0;
@@ -141,7 +151,9 @@ TEST_F(ParserTest, HandlesWordBreaks) {
}
TEST_F(ParserTest, HandlesAnnotations) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
constexpr int kCallPhone = 0;
rules.Add("<flight>", {"dial", "<phone>"},
static_cast<CallbackId>(DefaultCallback::kRootRule), kCallPhone);
@@ -167,7 +179,9 @@ TEST_F(ParserTest, HandlesAnnotations) {
}
TEST_F(ParserTest, HandlesRegexAnnotators) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.AddRegex("<code>",
"(\"([A-Za-z]+)\"|\\b\"?(?:[A-Z]+[0-9]*|[0-9])\"?\\b)");
constexpr int kScriptedReply = 0;
@@ -188,7 +202,9 @@ TEST_F(ParserTest, HandlesRegexAnnotators) {
}
TEST_F(ParserTest, HandlesExclusions) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.Add("<excluded>", {"be", "safe"});
rules.AddWithExclusion("<tokens_but_not_excluded>", {"<token>", "<token>"},
/*excluded_nonterminal=*/"<excluded>");
@@ -210,7 +226,9 @@ TEST_F(ParserTest, HandlesExclusions) {
}
TEST_F(ParserTest, HandlesFillers) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
constexpr int kSetReminder = 0;
rules.Add("<set_reminder>", {"do", "not", "forget", "to", "<filler>"},
static_cast<CallbackId>(DefaultCallback::kRootRule), kSetReminder);
@@ -224,7 +242,9 @@ TEST_F(ParserTest, HandlesFillers) {
}
TEST_F(ParserTest, HandlesAssertions) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.Add("<carrier>", {"lx"});
rules.Add("<carrier>", {"aa"});
rules.Add("<flight_code>", {"<2_digits>"});
@@ -249,7 +269,9 @@ TEST_F(ParserTest, HandlesAssertions) {
}
TEST_F(ParserTest, HandlesWhitespaceGapLimit) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.Add("<carrier>", {"lx"});
rules.Add("<carrier>", {"aa"});
rules.Add("<flight_code>", {"<2_digits>"});
@@ -270,7 +292,9 @@ TEST_F(ParserTest, HandlesWhitespaceGapLimit) {
}
TEST_F(ParserTest, HandlesCaseSensitiveMatching) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.Add("<carrier>", {"Lx"}, /*callback=*/kNoCallback, /*callback_param=*/0,
/*max_whitespace_gap=*/-1, /*case_sensitive=*/true);
rules.Add("<carrier>", {"AA"}, /*callback=*/kNoCallback, /*callback_param=*/0,
diff --git a/native/utils/grammar/rules.fbs b/native/utils/grammar/rules.fbs
index 3225892..bc0136c 100755
--- a/native/utils/grammar/rules.fbs
+++ b/native/utils/grammar/rules.fbs
@@ -14,9 +14,9 @@
// limitations under the License.
//
+include "utils/grammar/semantics/expression.fbs";
include "utils/i18n/language-tag.fbs";
include "utils/zlib/buffer.fbs";
-include "utils/grammar/semantics/expression.fbs";
// The terminal rules map as sorted strings table.
// The sorted terminal strings table is represented as offsets into the
diff --git a/native/utils/grammar/semantics/composer_test.cc b/native/utils/grammar/semantics/composer_test.cc
index 95b0759..e768e18 100644
--- a/native/utils/grammar/semantics/composer_test.cc
+++ b/native/utils/grammar/semantics/composer_test.cc
@@ -38,7 +38,9 @@ class SemanticComposerTest : public GrammarTest {};
TEST_F(SemanticComposerTest, EvaluatesSimpleMapping) {
RulesSetT model;
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
const int test_value_type =
TypeIdForName(semantic_values_schema_.get(),
"libtextclassifier3.grammar.TestValue")
@@ -109,7 +111,9 @@ TEST_F(SemanticComposerTest, EvaluatesSimpleMapping) {
TEST_F(SemanticComposerTest, RecursivelyEvaluatesConstituents) {
RulesSetT model;
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
const int test_value_type =
TypeIdForName(semantic_values_schema_.get(),
"libtextclassifier3.grammar.TestValue")
diff --git a/native/utils/grammar/utils/ir.cc b/native/utils/grammar/utils/ir.cc
index 49135bf..9477dd0 100644
--- a/native/utils/grammar/utils/ir.cc
+++ b/native/utils/grammar/utils/ir.cc
@@ -16,6 +16,7 @@
#include "utils/grammar/utils/ir.h"
+#include "utils/i18n/locale.h"
#include "utils/strings/append.h"
#include "utils/strings/stringpiece.h"
#include "utils/zlib/zlib.h"
@@ -445,16 +446,29 @@ void Ir::Serialize(const bool include_debug_information,
}
// Serialize the unary and binary rules.
- for (const RulesShard& shard : shards_) {
+ for (int i = 0; i < shards_.size(); i++) {
output->rules.emplace_back(std::make_unique<RulesSet_::RulesT>());
RulesSet_::RulesT* rules = output->rules.back().get();
- // Serialize the unary rules.
- SerializeUnaryRulesShard(shard.unary_rules, output, rules);
+ for (const Locale& shard_locale : locale_shard_map_.GetLocales(i)) {
+ if (shard_locale.IsValid()) {
+ // The language '*' ("all") is a special case: to stay consistent with the
+ // device-side parser, such a locale is not written out and the language
+ // tag list is left empty instead.
+ rules->locale.emplace_back(
+ std::make_unique<libtextclassifier3::LanguageTagT>());
+ libtextclassifier3::LanguageTagT* language_tag =
+ rules->locale.back().get();
+ language_tag->language = shard_locale.Language();
+ language_tag->region = shard_locale.Region();
+ language_tag->script = shard_locale.Script();
+ }
+ }
+ // Serialize the unary rules.
+ SerializeUnaryRulesShard(shards_[i].unary_rules, output, rules);
// Serialize the binary rules.
- SerializeBinaryRulesShard(shard.binary_rules, output, rules);
+ SerializeBinaryRulesShard(shards_[i].binary_rules, output, rules);
}
-
// Serialize the terminal rules.
// We keep the rules separate by shard but merge the actual terminals into
// one shared string pool to most effectively exploit reuse.
diff --git a/native/utils/grammar/utils/ir.h b/native/utils/grammar/utils/ir.h
index adafa66..f056d7a 100644
--- a/native/utils/grammar/utils/ir.h
+++ b/native/utils/grammar/utils/ir.h
@@ -25,6 +25,7 @@
#include "utils/base/integral_types.h"
#include "utils/grammar/rules_generated.h"
#include "utils/grammar/types.h"
+#include "utils/grammar/utils/locale-shard-map.h"
namespace libtextclassifier3::grammar {
@@ -96,8 +97,10 @@ class Ir {
std::unordered_map<TwoNonterms, LhsSet, BinaryRuleHasher> binary_rules;
};
- explicit Ir(const int num_shards = 1)
- : num_nonterminals_(0), shards_(num_shards) {}
+ // Creates an empty IR with one rules shard per shard of `locale_shard_map`.
+ // Only a reference to `locale_shard_map` is stored, so the map must outlive
+ // this object.
+ explicit Ir(const LocaleShardMap& locale_shard_map)
+ : num_nonterminals_(0),
+ locale_shard_map_(locale_shard_map),
+ shards_(locale_shard_map_.GetNumberOfShards()) {}
// Adds a new non-terminal.
Nonterm AddNonterminal(const std::string& name = "") {
@@ -224,6 +227,8 @@ class Ir {
Nonterm num_nonterminals_;
std::unordered_set<Nonterm> nonshareable_;
+ // Locale information of the rule shards. Not owned; must outlive this Ir.
+ const LocaleShardMap& locale_shard_map_;
// The sharded rules.
std::vector<RulesShard> shards_;
diff --git a/native/utils/grammar/utils/ir_test.cc b/native/utils/grammar/utils/ir_test.cc
index 279d99a..7a386df 100644
--- a/native/utils/grammar/utils/ir_test.cc
+++ b/native/utils/grammar/utils/ir_test.cc
@@ -31,7 +31,9 @@ using ::testing::Ne;
using ::testing::SizeIs;
TEST(IrTest, HandlesSharingWithTerminalRules) {
- Ir ir;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Ir ir(locale_shard_map);
// <t1> ::= the
const Nonterm t1 = ir.Add(kUnassignedNonterm, "the");
@@ -72,7 +74,9 @@ TEST(IrTest, HandlesSharingWithTerminalRules) {
}
TEST(IrTest, HandlesSharingWithNonterminalRules) {
- Ir ir;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Ir ir(locale_shard_map);
// Setup a few terminal rules.
const std::vector<Nonterm> rhs = {
@@ -97,7 +101,9 @@ TEST(IrTest, HandlesSharingWithCallbacksWithSameParameters) {
// Test sharing in the presence of callbacks.
constexpr CallbackId kOutput1 = 1;
constexpr CallbackId kOutput2 = 2;
- Ir ir;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Ir ir(locale_shard_map);
const Nonterm x1 = ir.Add(kUnassignedNonterm, "hello");
const Nonterm x2 =
@@ -116,7 +122,10 @@ TEST(IrTest, HandlesSharingWithCallbacksWithSameParameters) {
TEST(IrTest, SerializesRulesToFlatbufferFormat) {
constexpr CallbackId kOutput = 1;
- Ir ir;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Ir ir(locale_shard_map);
+
const Nonterm verb = ir.AddUnshareableNonterminal();
ir.Add(verb, "buy");
ir.Add(Ir::Lhs{verb, {kOutput}}, "bring");
@@ -155,7 +164,9 @@ TEST(IrTest, SerializesRulesToFlatbufferFormat) {
}
TEST(IrTest, HandlesRulesSharding) {
- Ir ir(/*num_shards=*/2);
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({"", "de"});
+ Ir ir(locale_shard_map);
const Nonterm verb = ir.AddUnshareableNonterminal();
const Nonterm set_reminder = ir.AddUnshareableNonterminal();
@@ -210,7 +221,9 @@ TEST(IrTest, HandlesRulesSharding) {
}
TEST(IrTest, DeduplicatesLhsSets) {
- Ir ir;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Ir ir(locale_shard_map);
const Nonterm test = ir.AddUnshareableNonterminal();
ir.Add(test, "test");
diff --git a/native/utils/grammar/utils/locale-shard-map.cc b/native/utils/grammar/utils/locale-shard-map.cc
new file mode 100644
index 0000000..4f7dc5e
--- /dev/null
+++ b/native/utils/grammar/utils/locale-shard-map.cc
@@ -0,0 +1,86 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/utils/locale-shard-map.h"
+
+#include <algorithm>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "utils/i18n/locale-list.h"
+#include "utils/i18n/locale.h"
+#include "utils/strings/append.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+// Parses `locale_tags` into a locale list, drops every locale that is not
+// valid and returns the remaining ones sorted in ascending order.
+std::vector<Locale> LocaleTagsToLocaleList(const std::string& locale_tags) {
+  // Keep the parsed list in a named local so that it stays alive for the whole
+  // loop below.
+  const LocaleList parsed_locales = LocaleList::ParseFrom(locale_tags);
+  std::vector<Locale> valid_locales;
+  for (const Locale& candidate : parsed_locales.GetLocales()) {
+    if (!candidate.IsValid()) {
+      continue;
+    }
+    valid_locales.push_back(candidate);
+  }
+  std::sort(valid_locales.begin(), valid_locales.end());
+  return valid_locales;
+}
+
+} // namespace
+
+// Builds a map that holds one shard per entry of `locale_tags`. Shard ids are
+// assigned in input order, starting at 0.
+LocaleShardMap LocaleShardMap::CreateLocaleShardMap(
+    const std::vector<std::string>& locale_tags) {
+  LocaleShardMap result;
+  std::for_each(
+      locale_tags.begin(), locale_tags.end(),
+      [&result](const std::string& tags) { result.AddLocalTags(tags); });
+  return result;
+}
+
+// Returns a copy of the locales registered for `shard`, or an empty list if no
+// such shard exists.
+std::vector<Locale> LocaleShardMap::GetLocales(const int shard) const {
+  const auto entry = shard_to_locale_data_.find(shard);
+  if (entry == shard_to_locale_data_.end()) {
+    return {};
+  }
+  return entry->second;
+}
+
+// Returns how many shards have been registered in this map.
+int LocaleShardMap::GetNumberOfShards() const {
+  const int num_shards = static_cast<int>(shard_to_locale_data_.size());
+  return num_shards;
+}
+
+// Returns the id of the shard whose locale list is exactly `locales`. Falls
+// back to shard 0 if no shard matches.
+// `locales` is taken by value to stay in sync with the declaration in the
+// header.
+int LocaleShardMap::GetShard(const std::vector<Locale> locales) const {
+  for (const auto& [shard, locale_list] : shard_to_locale_data_) {
+    // Compare the complete lists, lengths included. The three-iterator form of
+    // std::equal would read past the end of a shorter `locale_list` and would
+    // accept any `locale_list` that merely starts with `locales`.
+    if (std::equal(locales.begin(), locales.end(), locale_list.begin(),
+                   locale_list.end())) {
+      return shard;
+    }
+  }
+  return 0;
+}
+
+// Convenience overload: parses `locale_tags` into a sorted list of valid
+// locales and looks up the shard for that list.
+int LocaleShardMap::GetShard(const std::string& locale_tags) const {
+  return GetShard(LocaleTagsToLocaleList(locale_tags));
+}
+
+// Registers a new shard for the valid locales parsed from `locale_tags`. The
+// new shard id is the number of shards registered so far.
+void LocaleShardMap::AddLocalTags(const std::string& locale_tags) {
+  const int next_shard_id = shard_to_locale_data_.size();
+  shard_to_locale_data_.emplace(next_shard_id,
+                                LocaleTagsToLocaleList(locale_tags));
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/utils/locale-shard-map.h b/native/utils/grammar/utils/locale-shard-map.h
new file mode 100644
index 0000000..5e0f5cb
--- /dev/null
+++ b/native/utils/grammar/utils/locale-shard-map.h
@@ -0,0 +1,55 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_UTILS_LOCALE_SHARD_MAP_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_UTILS_LOCALE_SHARD_MAP_H_
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "utils/grammar/types.h"
+#include "utils/i18n/locale-list.h"
+#include "utils/i18n/locale.h"
+#include "utils/optional.h"
+
+namespace libtextclassifier3::grammar {
+
+// Grammar rules are associated with locales, which serve as a filter during
+// rule application. This class maps each shard id to the list of locales of
+// that shard; it is used when the Aqua rules are compiled into the internal
+// rules.proto flatbuffer.
+class LocaleShardMap {
+ public:
+ // Creates a map with one shard per element of `locale_tags`. Shard ids are
+ // assigned in input order, starting at 0.
+ static LocaleShardMap CreateLocaleShardMap(
+ const std::vector<std::string>& locale_tags);
+
+ // Returns the locales of `shard`, or an empty vector if there is no such
+ // shard.
+ std::vector<Locale> GetLocales(const int shard) const;
+
+ // Return the id of a shard matching the given locales, or 0 if none matches.
+ // Note that 0 is also a valid shard id.
+ int GetShard(const std::vector<Locale> locales) const;
+ int GetShard(const std::string& locale_tags) const;
+
+ // Returns the number of shards in the map.
+ int GetNumberOfShards() const;
+
+ private:
+ // Instances are created through CreateLocaleShardMap().
+ explicit LocaleShardMap() {}
+ // Appends a new shard holding the locales parsed from the given tags.
+ void AddLocalTags(const std::string& locale_tag);
+
+ // Shard id -> locales of that shard.
+ std::unordered_map<int, std::vector<Locale>> shard_to_locale_data_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_UTILS_LOCALE_SHARD_MAP_H_
diff --git a/native/utils/grammar/utils/locale-shard-map_test.cc b/native/utils/grammar/utils/locale-shard-map_test.cc
new file mode 100644
index 0000000..14c9081
--- /dev/null
+++ b/native/utils/grammar/utils/locale-shard-map_test.cc
@@ -0,0 +1,76 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/utils/locale-shard-map.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+using ::testing::SizeIs;
+
+// Every single-locale tag becomes its own shard, numbered in input order.
+TEST(LocaleShardMapTest, HandlesSimpleShard) {
+  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap(
+      {"ar-EG", "bn-BD", "cs-CZ", "da-DK", "de-DE", "en-US", "es-ES", "fi-FI",
+       "fr-FR", "gu-IN", "id-ID", "it-IT", "ja-JP", "kn-IN", "ko-KR", "ml-IN",
+       "mr-IN", "nl-NL", "no-NO", "pl-PL", "pt-BR", "ru-RU", "sv-SE", "ta-IN",
+       "te-IN", "th-TH", "tr-TR", "uk-UA", "ur-PK", "vi-VN", "zh-TW"});
+
+  EXPECT_EQ(locale_shard_map.GetNumberOfShards(), 31);
+  for (int i = 0; i < 31; i++) {
+    // Fatal on failure: the checks below index into these vectors, which
+    // would be out of bounds if a shard were empty.
+    ASSERT_THAT(locale_shard_map.GetLocales(i), SizeIs(1));
+  }
+  EXPECT_EQ(locale_shard_map.GetLocales(0)[0], Locale::FromBCP47("ar-EG"));
+  EXPECT_EQ(locale_shard_map.GetLocales(8)[0], Locale::FromBCP47("fr-FR"));
+  EXPECT_EQ(locale_shard_map.GetLocales(16)[0], Locale::FromBCP47("mr-IN"));
+  EXPECT_EQ(locale_shard_map.GetLocales(24)[0], Locale::FromBCP47("te-IN"));
+  EXPECT_EQ(locale_shard_map.GetLocales(30)[0], Locale::FromBCP47("zh-TW"));
+}
+
+// A lone wildcard tag yields a single shard holding one locale.
+// The suite name matches the other LocaleShardMap tests in this file.
+TEST(LocaleShardMapTest, HandlesWildCard) {
+  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({"*"});
+  EXPECT_EQ(locale_shard_map.GetNumberOfShards(), 1);
+  EXPECT_THAT(locale_shard_map.GetLocales(0), SizeIs(1));
+}
+
+// A tag naming several locales puts all of them into one shard, and the shard
+// lookup does not depend on the order in which the locales are listed.
+TEST(LocaleShardMapTest, HandlesMultipleLocalePerShard) {
+  LocaleShardMap locale_shard_map =
+      LocaleShardMap::CreateLocaleShardMap({"ar-EG,bn-BD,cs-CZ", "en-*"});
+  EXPECT_EQ(locale_shard_map.GetNumberOfShards(), 2);
+  // Fatal on failure: the element accesses below would otherwise index out of
+  // bounds.
+  ASSERT_THAT(locale_shard_map.GetLocales(0), SizeIs(3));
+  ASSERT_FALSE(locale_shard_map.GetLocales(1).empty());
+  EXPECT_EQ(locale_shard_map.GetLocales(0)[0], Locale::FromBCP47("ar-EG"));
+  EXPECT_EQ(locale_shard_map.GetLocales(0)[1], Locale::FromBCP47("bn-BD"));
+  EXPECT_EQ(locale_shard_map.GetLocales(0)[2], Locale::FromBCP47("cs-CZ"));
+  EXPECT_EQ(locale_shard_map.GetLocales(1)[0], Locale::FromBCP47("en"));
+
+  EXPECT_EQ(locale_shard_map.GetShard("ar-EG,bn-BD,cs-CZ"), 0);
+  EXPECT_EQ(locale_shard_map.GetShard("bn-BD,cs-CZ,ar-EG"), 0);
+  EXPECT_EQ(locale_shard_map.GetShard("bn-BD,ar-EG,cs-CZ"), 0);
+  EXPECT_EQ(locale_shard_map.GetShard("ar-EG,cs-CZ,bn-BD"), 0);
+}
+
+// An empty tag still occupies a shard, but that shard holds no locales.
+TEST(LocaleShardMapTest, HandlesEmptyLocaleTag) {
+  LocaleShardMap locale_shard_map =
+      LocaleShardMap::CreateLocaleShardMap({"", "en-US"});
+  EXPECT_EQ(locale_shard_map.GetNumberOfShards(), 2);
+  EXPECT_THAT(locale_shard_map.GetLocales(0), SizeIs(0));
+  // Fatal on failure: the next check reads element 0 of this vector.
+  ASSERT_THAT(locale_shard_map.GetLocales(1), SizeIs(1));
+  EXPECT_EQ(locale_shard_map.GetLocales(1)[0], Locale::FromBCP47("en-US"));
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/utils/rules.cc b/native/utils/grammar/utils/rules.cc
index 623124a..661514a 100644
--- a/native/utils/grammar/utils/rules.cc
+++ b/native/utils/grammar/utils/rules.cc
@@ -414,7 +414,7 @@ bool Rules::UsesFillers() const {
}
Ir Rules::Finalize(const std::set<std::string>& predefined_nonterminals) const {
- Ir rules(num_shards_);
+ Ir rules(locale_shard_map_);
std::unordered_map<int, Nonterm> nonterminal_ids;
// Pending rules to process.
diff --git a/native/utils/grammar/utils/rules.h b/native/utils/grammar/utils/rules.h
index a6851f3..4931e2f 100644
--- a/native/utils/grammar/utils/rules.h
+++ b/native/utils/grammar/utils/rules.h
@@ -55,7 +55,8 @@ constexpr const char* kFiller = "<filler>";
// internal representation.
class Rules {
public:
- explicit Rules(const int num_shards = 1) : num_shards_(num_shards) {}
+ // Creates an empty rule set bound to `locale_shard_map`. Only a reference to
+ // the map is stored (it is later passed to the Ir built in Finalize()), so
+ // the map must outlive this object; do not pass a temporary.
+ explicit Rules(const LocaleShardMap& locale_shard_map)
+ : locale_shard_map_(locale_shard_map) {}
// Represents one item in a right-hand side, a single terminal or nonterminal.
struct RhsElement {
@@ -214,7 +215,7 @@ class Rules {
// Checks whether the fillers are used in any active rule.
bool UsesFillers() const;
- const int num_shards_;
+ const LocaleShardMap& locale_shard_map_;
// Non-terminal to id map.
std::unordered_map<std::string, int> nonterminal_names_;
diff --git a/native/utils/grammar/utils/rules_test.cc b/native/utils/grammar/utils/rules_test.cc
index 8db88ab..c71f2b4 100644
--- a/native/utils/grammar/utils/rules_test.cc
+++ b/native/utils/grammar/utils/rules_test.cc
@@ -28,7 +28,9 @@ using ::testing::IsEmpty;
using ::testing::SizeIs;
TEST(SerializeRulesTest, HandlesSimpleRuleSet) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.Add("<verb>", {"buy"});
rules.Add("<verb>", {"bring"});
@@ -49,7 +51,9 @@ TEST(SerializeRulesTest, HandlesSimpleRuleSet) {
}
TEST(SerializeRulesTest, HandlesRulesSetWithCallbacks) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
const CallbackId output = 1;
rules.Add("<verb>", {"buy"});
@@ -73,7 +77,9 @@ TEST(SerializeRulesTest, HandlesRulesSetWithCallbacks) {
}
TEST(SerializeRulesTest, HandlesRulesWithWhitespaceGapLimits) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.Add("<iata>", {"lx"});
rules.Add("<iata>", {"aa"});
rules.Add("<flight>", {"<iata>", "<4_digits>"}, kNoCallback, 0,
@@ -89,7 +95,9 @@ TEST(SerializeRulesTest, HandlesRulesWithWhitespaceGapLimits) {
}
TEST(SerializeRulesTest, HandlesCaseSensitiveTerminals) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.Add("<iata>", {"LX"}, kNoCallback, 0, /*max_whitespace_gap=*/-1,
/*case_sensitive=*/true);
rules.Add("<iata>", {"AA"}, kNoCallback, 0, /*max_whitespace_gap=*/-1,
@@ -109,7 +117,9 @@ TEST(SerializeRulesTest, HandlesCaseSensitiveTerminals) {
}
TEST(SerializeRulesTest, HandlesMultipleShards) {
- Rules rules(/*num_shards=*/2);
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({"", "de"});
+ Rules rules(locale_shard_map);
rules.Add("<iata>", {"LX"}, kNoCallback, 0, /*max_whitespace_gap=*/-1,
/*case_sensitive=*/true, /*shard=*/0);
rules.Add("<iata>", {"aa"}, kNoCallback, 0, /*max_whitespace_gap=*/-1,
@@ -124,7 +134,10 @@ TEST(SerializeRulesTest, HandlesMultipleShards) {
}
TEST(SerializeRulesTest, HandlesRegexRules) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+ // Rules rules;
rules.AddRegex("<code>", "[A-Z]+");
rules.AddRegex("<numbers>", "\\d+");
RulesSetT frozen_rules;
@@ -134,7 +147,9 @@ TEST(SerializeRulesTest, HandlesRegexRules) {
}
TEST(SerializeRulesTest, HandlesAlias) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.Add("<iata>", {"lx"});
rules.Add("<iata>", {"aa"});
rules.Add("<flight>", {"<iata>", "<4_digits>"});
@@ -155,7 +170,9 @@ TEST(SerializeRulesTest, HandlesAlias) {
}
TEST(SerializeRulesTest, ResolvesAnchorsAndFillers) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.Add("<code>",
{"<^>", "<filler>", "this", "is", "a", "test", "<filler>", "<$>"});
const Ir ir = rules.Finalize();
@@ -177,7 +194,9 @@ TEST(SerializeRulesTest, ResolvesAnchorsAndFillers) {
}
TEST(SerializeRulesTest, HandlesFillers) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.Add("<test>", {"<filler>?", "a", "test"});
const Ir ir = rules.Finalize();
RulesSetT frozen_rules;
@@ -198,7 +217,9 @@ TEST(SerializeRulesTest, HandlesFillers) {
}
TEST(SerializeRulesTest, HandlesAnnotations) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.AddAnnotation("phone");
rules.AddAnnotation("url");
rules.AddAnnotation("tracking_number");
diff --git a/native/utils/i18n/locale.cc b/native/utils/i18n/locale.cc
index d5a1109..3719079 100644
--- a/native/utils/i18n/locale.cc
+++ b/native/utils/i18n/locale.cc
@@ -16,6 +16,8 @@
#include "utils/i18n/locale.h"
+#include <string>
+#include <tuple>
+
#include "utils/strings/split.h"
namespace libtextclassifier3 {
@@ -196,6 +198,20 @@ bool Locale::IsAnyLocaleSupported(const std::vector<Locale>& locales,
return false;
}
+// Two locales are equal when their language, region and script all match.
+bool Locale::operator==(const Locale& locale) const {
+  const bool same_language = language_ == locale.language_;
+  const bool same_region = region_ == locale.region_;
+  const bool same_script = script_ == locale.script_;
+  return same_language && same_region && same_script;
+}
+
+// Orders locales lexicographically by language, then region, then script.
+bool Locale::operator<(const Locale& locale) const {
+  const auto this_key = std::tie(language_, region_, script_);
+  const auto other_key =
+      std::tie(locale.language_, locale.region_, locale.script_);
+  return this_key < other_key;
+}
+
+// Negation of operator==.
+bool Locale::operator!=(const Locale& locale) const {
+  const bool is_equal = (*this == locale);
+  return !is_equal;
+}
+
logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
const Locale& locale) {
return stream << "Locale(language=" << locale.Language()
diff --git a/native/utils/i18n/locale.h b/native/utils/i18n/locale.h
index 308846d..036bacd 100644
--- a/native/utils/i18n/locale.h
+++ b/native/utils/i18n/locale.h
@@ -60,6 +60,10 @@ class Locale {
const std::vector<Locale>& supported_locales,
bool default_value);
+ // Comparison operators. Equality requires language, region and script to
+ // match; operator< orders by language, then region, then script.
+ bool operator==(const Locale& locale) const;
+ bool operator!=(const Locale& locale) const;
+ bool operator<(const Locale& locale) const;
+
private:
Locale(const std::string& language, const std::string& script,
const std::string& region)
diff --git a/native/utils/tflite-model-executor.cc b/native/utils/tflite-model-executor.cc
index 2dbd786..e491130 100644
--- a/native/utils/tflite-model-executor.cc
+++ b/native/utils/tflite-model-executor.cc
@@ -62,6 +62,15 @@ TfLiteRegistration* Register_WHERE();
TfLiteRegistration* Register_ONE_HOT();
TfLiteRegistration* Register_POW();
TfLiteRegistration* Register_TANH();
+// Additional builtin ops that are only declared (and registered in
+// RegisterSelectedOps) when TC3_AOSP is not defined.
+#ifndef TC3_AOSP
+TfLiteRegistration* Register_REDUCE_PROD();
+TfLiteRegistration* Register_SHAPE();
+TfLiteRegistration* Register_NOT_EQUAL();
+TfLiteRegistration* Register_CUMSUM();
+TfLiteRegistration* Register_EXPAND_DIMS();
+TfLiteRegistration* Register_FILL();
+TfLiteRegistration* Register_PADV2();
+#endif // TC3_AOSP
} // namespace builtin
} // namespace ops
} // namespace tflite
@@ -70,6 +79,17 @@ TfLiteRegistration* Register_TANH();
#include "utils/tflite/dist_diversification.h"
#include "utils/tflite/text_encoder.h"
#include "utils/tflite/token_encoder.h"
+// Forward declarations of custom op registration functions that are only
+// available when TC3_AOSP is not defined; they are added to the resolver in
+// BuildOpResolver().
+#ifndef TC3_AOSP
+namespace tflite {
+namespace ops {
+namespace custom {
+TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER();
+TfLiteRegistration* Register_RAGGED_TENSOR_TO_TENSOR();
+TfLiteRegistration* Register_RAGGED_RANGE();
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+#endif // TC3_AOSP
void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
resolver->AddBuiltin(tflite::BuiltinOperator_ADD,
@@ -191,6 +211,22 @@ void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
tflite::ops::builtin::Register_TANH(),
/*min_version=*/1,
/*max_version=*/1);
+#ifndef TC3_AOSP
+ resolver->AddBuiltin(::tflite::BuiltinOperator_REDUCE_PROD,
+ ::tflite::ops::builtin::Register_REDUCE_PROD());
+ resolver->AddBuiltin(::tflite::BuiltinOperator_SHAPE,
+ ::tflite::ops::builtin::Register_SHAPE());
+ resolver->AddBuiltin(::tflite::BuiltinOperator_NOT_EQUAL,
+ ::tflite::ops::builtin::Register_NOT_EQUAL());
+ resolver->AddBuiltin(::tflite::BuiltinOperator_CUMSUM,
+ ::tflite::ops::builtin::Register_CUMSUM());
+ resolver->AddBuiltin(::tflite::BuiltinOperator_EXPAND_DIMS,
+ ::tflite::ops::builtin::Register_EXPAND_DIMS());
+ resolver->AddBuiltin(::tflite::BuiltinOperator_FILL,
+ ::tflite::ops::builtin::Register_FILL());
+ resolver->AddBuiltin(::tflite::BuiltinOperator_PADV2,
+ ::tflite::ops::builtin::Register_PADV2());
+#endif // TC3_AOSP
}
#else
void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
@@ -222,6 +258,16 @@ std::unique_ptr<tflite::OpResolver> BuildOpResolver(
tflite::ops::custom::Register_TEXT_ENCODER());
resolver->AddCustom("TokenEncoder",
tflite::ops::custom::Register_TOKEN_ENCODER());
+#ifndef TC3_AOSP
+ resolver->AddCustom(
+ "TFSentencepieceTokenizeOp",
+ ::tflite::ops::custom::Register_SENTENCEPIECE_TOKENIZER());
+ resolver->AddCustom("RaggedRange",
+ ::tflite::ops::custom::Register_RAGGED_RANGE());
+ resolver->AddCustom(
+ "RaggedTensorToTensor",
+ ::tflite::ops::custom::Register_RAGGED_TENSOR_TO_TENSOR());
+#endif // TC3_AOSP
#endif // TC3_WITH_ACTIONS_OPS
customize_fn(resolver.get());
return std::unique_ptr<tflite::OpResolver>(std::move(resolver));
diff --git a/native/utils/test-utils.cc b/native/utils/tokenizer-utils.cc
index 9e3216a..7d07b0c 100644
--- a/native/utils/test-utils.cc
+++ b/native/utils/tokenizer-utils.cc
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-#include "utils/test-utils.h"
+#include "utils/tokenizer-utils.h"
#include <iterator>
diff --git a/native/utils/test-utils.h b/native/utils/tokenizer-utils.h
index 184e60a..553791b 100644
--- a/native/utils/test-utils.h
+++ b/native/utils/tokenizer-utils.h
@@ -16,8 +16,8 @@
 // Utilities for tokenizing text (this header was formerly test-utils.h).
-#ifndef LIBTEXTCLASSIFIER_UTILS_TEST_UTILS_H_
-#define LIBTEXTCLASSIFIER_UTILS_TEST_UTILS_H_
+#ifndef LIBTEXTCLASSIFIER_UTILS_TOKENIZER_UTILS_H_
+#define LIBTEXTCLASSIFIER_UTILS_TOKENIZER_UTILS_H_
#include <string>
@@ -40,4 +40,4 @@ std::vector<Token> TokenizeOnDelimiters(
} // namespace libtextclassifier3
-#endif // LIBTEXTCLASSIFIER_UTILS_TEST_UTILS_H_
+#endif // LIBTEXTCLASSIFIER_UTILS_TOKENIZER_UTILS_H_
diff --git a/native/utils/test-utils_test.cc b/native/utils/tokenizer-utils_test.cc
index 88a3ec1..9c632bd 100644
--- a/native/utils/test-utils_test.cc
+++ b/native/utils/tokenizer-utils_test.cc
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-#include "utils/test-utils.h"
+#include "utils/tokenizer-utils.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -22,7 +22,7 @@
namespace libtextclassifier3 {
namespace {
-TEST(TestUtilTest, TokenizeOnSpace) {
+TEST(TokenizerUtilTest, TokenizeOnSpace) {
std::vector<Token> tokens =
TokenizeOnSpace("Where is Jörg Borg located? Maybe in Zürich ...");
@@ -65,7 +65,7 @@ TEST(TestUtilTest, TokenizeOnSpace) {
EXPECT_EQ(tokens[8].end, 47);
}
-TEST(TestUtilTest, TokenizeOnDelimiters) {
+TEST(TokenizerUtilTest, TokenizeOnDelimiters) {
std::vector<Token> tokens = TokenizeOnDelimiters(
"This might be čomplíčateď?!: Oder?", {' ', '?', '!'});
@@ -96,7 +96,7 @@ TEST(TestUtilTest, TokenizeOnDelimiters) {
EXPECT_EQ(tokens[5].end, 35);
}
-TEST(TestUtilTest, TokenizeOnDelimitersKeepNoSpace) {
+TEST(TokenizerUtilTest, TokenizeOnDelimitersKeepNoSpace) {
std::vector<Token> tokens = TokenizeOnDelimiters(
"This might be čomplíčateď?!: Oder?", {' ', '?', '!'},
/* create_tokens_for_non_space_delimiters =*/true);
diff --git a/native/utils/utf8/unicodetext.cc b/native/utils/utf8/unicodetext.cc
index 2ddd38c..d05e377 100644
--- a/native/utils/utf8/unicodetext.cc
+++ b/native/utils/utf8/unicodetext.cc
@@ -210,6 +210,14 @@ std::vector<UnicodeText::const_iterator> UnicodeText::Codepoints() const {
return codepoints;
}
+// Returns the codepoints of this text, in order, as char32 values.
+std::vector<char32> UnicodeText::CodepointsChar32() const {
+  std::vector<char32> result;
+  auto cursor = begin();
+  while (cursor != end()) {
+    result.push_back(*cursor);
+    cursor++;
+  }
+  return result;
+}
+
bool UnicodeText::operator==(const UnicodeText& other) const {
if (repr_.size_ != other.repr_.size_) {
return false;
diff --git a/native/utils/utf8/unicodetext.h b/native/utils/utf8/unicodetext.h
index 4ca0dd2..4c1c3ce 100644
--- a/native/utils/utf8/unicodetext.h
+++ b/native/utils/utf8/unicodetext.h
@@ -178,6 +178,9 @@ class UnicodeText {
// Returns an iterator for each codepoint.
std::vector<const_iterator> Codepoints() const;
+ // Returns the list of codepoints of the UnicodeText.
+ std::vector<char32> CodepointsChar32() const;
+
std::string ToUTF8String() const;
std::string UTF8Substring(int begin_codepoint, int end_codepoint) const;
static std::string UTF8Substring(const const_iterator& it_begin,