summary refs log tree commit diff
path: root/annotator
diff options
context:
space:
mode:
Diffstat (limited to 'annotator')
-rw-r--r-- annotator/annotator.cc 1695
-rw-r--r-- annotator/annotator.h 393
-rw-r--r-- annotator/annotator_jni.cc 434
-rw-r--r-- annotator/annotator_jni.h 103
-rw-r--r-- annotator/annotator_jni_common.cc 100
-rw-r--r-- annotator/annotator_jni_common.h 41
-rw-r--r-- annotator/annotator_jni_test.cc 79
-rw-r--r-- annotator/annotator_test.cc 1254
-rw-r--r-- annotator/cached-features.cc 173
-rw-r--r-- annotator/cached-features.h 83
-rw-r--r-- annotator/cached-features_test.cc 157
-rw-r--r-- annotator/datetime/extractor.cc 469
-rw-r--r-- annotator/datetime/extractor.h 111
-rw-r--r-- annotator/datetime/parser.cc 406
-rw-r--r-- annotator/datetime/parser.h 118
-rw-r--r-- annotator/datetime/parser_test.cc 413
-rw-r--r-- annotator/feature-processor.cc 988
-rw-r--r-- annotator/feature-processor.h 331
-rw-r--r-- annotator/feature-processor_test.cc 1125
-rw-r--r-- annotator/knowledge/knowledge-engine-dummy.h 47
-rw-r--r-- annotator/knowledge/knowledge-engine.h 22
-rw-r--r-- annotator/model-executor.cc 124
-rw-r--r-- annotator/model-executor.h 111
-rwxr-xr-x annotator/model.fbs 583
-rw-r--r-- annotator/quantization.cc 92
-rw-r--r-- annotator/quantization.h 39
-rw-r--r-- annotator/quantization_test.cc 163
-rw-r--r-- annotator/strip-unpaired-brackets.cc 105
-rw-r--r-- annotator/strip-unpaired-brackets.h 38
-rw-r--r-- annotator/strip-unpaired-brackets_test.cc 66
-rw-r--r-- annotator/test_data/test_model.fb bin 0 -> 522688 bytes
-rw-r--r-- annotator/test_data/test_model_cc.fb bin 0 -> 552160 bytes
-rw-r--r-- annotator/test_data/wrong_embeddings.fb bin 0 -> 288628 bytes
-rw-r--r-- annotator/token-feature-extractor.cc 311
-rw-r--r-- annotator/token-feature-extractor.h 115
-rw-r--r-- annotator/token-feature-extractor_test.cc 556
-rw-r--r-- annotator/tokenizer.cc 126
-rw-r--r-- annotator/tokenizer.h 71
-rw-r--r-- annotator/tokenizer_test.cc 334
-rw-r--r-- annotator/types-test-util.h 49
-rw-r--r-- annotator/types.h 402
-rw-r--r-- annotator/zlib-utils.cc 128
-rw-r--r-- annotator/zlib-utils.h 37
-rw-r--r-- annotator/zlib-utils_test.cc 99
44 files changed, 12091 insertions, 0 deletions
diff --git a/annotator/annotator.cc b/annotator/annotator.cc
new file mode 100644
index 0000000..2be9d3c
--- /dev/null
+++ b/annotator/annotator.cc
@@ -0,0 +1,1695 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/annotator.h"
+
+#include <algorithm>
+#include <cctype>
+#include <cmath>
+#include <iterator>
+#include <numeric>
+
+#include "utils/base/logging.h"
+#include "utils/checksum.h"
+#include "utils/math/softmax.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3 {
+// Canonical collection-name strings returned in classification results.
+// Each is allocated once via an immediately-invoked lambda and intentionally
+// never destructed, so the references remain valid for the whole process
+// lifetime (avoids static-destruction-order issues).
+const std::string& Annotator::kOtherCollection =
+ *[]() { return new std::string("other"); }();
+const std::string& Annotator::kPhoneCollection =
+ *[]() { return new std::string("phone"); }();
+const std::string& Annotator::kAddressCollection =
+ *[]() { return new std::string("address"); }();
+const std::string& Annotator::kDateCollection =
+ *[]() { return new std::string("date"); }();
+const std::string& Annotator::kUrlCollection =
+ *[]() { return new std::string("url"); }();
+const std::string& Annotator::kFlightCollection =
+ *[]() { return new std::string("flight"); }();
+const std::string& Annotator::kEmailCollection =
+ *[]() { return new std::string("email"); }();
+const std::string& Annotator::kIbanCollection =
+ *[]() { return new std::string("iban"); }();
+const std::string& Annotator::kPaymentCardCollection =
+ *[]() { return new std::string("payment_card"); }();
+const std::string& Annotator::kIsbnCollection =
+ *[]() { return new std::string("isbn"); }();
+const std::string& Annotator::kTrackingNumberCollection =
+ *[]() { return new std::string("tracking_number"); }();
+
+namespace {
+// Verifies the flatbuffer in [addr, addr + size) against the Model schema
+// and returns the typed root, or nullptr if verification fails.
+const Model* LoadAndVerifyModel(const void* addr, int size) {
+ flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
+ if (VerifyModelBuffer(verifier)) {
+ return GetModel(addr);
+ } else {
+ return nullptr;
+ }
+}
+
+// If lib is not nullptr, just returns lib. Otherwise, if lib is nullptr, will
+// create a new instance, assign ownership to owned_lib, and return it.
+// In the first case the returned pointer is non-owning; in the second it
+// aliases *owned_lib, which owns the instance.
+const UniLib* MaybeCreateUnilib(const UniLib* lib,
+ std::unique_ptr<UniLib>* owned_lib) {
+ if (lib) {
+ return lib;
+ } else {
+ owned_lib->reset(new UniLib);
+ return owned_lib->get();
+ }
+}
+
+// As above, but for CalendarLib: returns lib if non-null, otherwise creates a
+// new CalendarLib owned by *owned_lib and returns a pointer to it.
+const CalendarLib* MaybeCreateCalendarlib(
+ const CalendarLib* lib, std::unique_ptr<CalendarLib>* owned_lib) {
+ if (lib) {
+ return lib;
+ } else {
+ owned_lib->reset(new CalendarLib);
+ return owned_lib->get();
+ }
+}
+
+} // namespace
+
+// Lazily builds and caches the TFLite interpreter for the selection model.
+// Requires selection_executor_ to be set (checked). May return nullptr if
+// interpreter construction fails; the failure is logged.
+tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
+ if (!selection_interpreter_) {
+ TC3_CHECK(selection_executor_);
+ selection_interpreter_ = selection_executor_->CreateInterpreter();
+ if (!selection_interpreter_) {
+ TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
+ }
+ }
+ return selection_interpreter_.get();
+}
+
+// Lazily builds and caches the TFLite interpreter for the classification
+// model. Requires classification_executor_ to be set (checked). May return
+// nullptr if interpreter construction fails; the failure is logged.
+tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
+ if (!classification_interpreter_) {
+ TC3_CHECK(classification_executor_);
+ classification_interpreter_ = classification_executor_->CreateInterpreter();
+ if (!classification_interpreter_) {
+ TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
+ }
+ }
+ return classification_interpreter_.get();
+}
+
+// Creates an Annotator from a model buffer owned by the caller; the buffer
+// must outlive the returned instance. Returns nullptr if the buffer does not
+// verify as a Model or if initialization fails.
+std::unique_ptr<Annotator> Annotator::FromUnownedBuffer(
+ const char* buffer, int size, const UniLib* unilib,
+ const CalendarLib* calendarlib) {
+ const Model* model = LoadAndVerifyModel(buffer, size);
+ if (model == nullptr) {
+ return nullptr;
+ }
+
+ auto classifier =
+ std::unique_ptr<Annotator>(new Annotator(model, unilib, calendarlib));
+ if (!classifier->IsInitialized()) {
+ return nullptr;
+ }
+
+ return classifier;
+}
+
+
+// Creates an Annotator from an already-mmapped model file; takes ownership of
+// *mmap. Returns nullptr if the mmap failed, the buffer does not verify as a
+// Model, or initialization fails.
+std::unique_ptr<Annotator> Annotator::FromScopedMmap(
+ std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib,
+ const CalendarLib* calendarlib) {
+ if (!(*mmap)->handle().ok()) {
+ TC3_VLOG(1) << "Mmap failed.";
+ return nullptr;
+ }
+
+ const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
+ (*mmap)->handle().num_bytes());
+ if (!model) {
+ TC3_LOG(ERROR) << "Model verification failed.";
+ return nullptr;
+ }
+
+ auto classifier = std::unique_ptr<Annotator>(
+ new Annotator(mmap, model, unilib, calendarlib));
+ if (!classifier->IsInitialized()) {
+ return nullptr;
+ }
+
+ return classifier;
+}
+
+// Creates an Annotator from a region [offset, offset + size) of an open file
+// descriptor; delegates to FromScopedMmap.
+std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
+ int fd, int offset, int size, const UniLib* unilib,
+ const CalendarLib* calendarlib) {
+ std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
+ return FromScopedMmap(&mmap, unilib, calendarlib);
+}
+
+// Creates an Annotator from a whole open file descriptor; delegates to
+// FromScopedMmap.
+std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
+ int fd, const UniLib* unilib, const CalendarLib* calendarlib) {
+ std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
+ return FromScopedMmap(&mmap, unilib, calendarlib);
+}
+
+// Creates an Annotator by mmapping the model file at 'path'; delegates to
+// FromScopedMmap.
+std::unique_ptr<Annotator> Annotator::FromPath(const std::string& path,
+ const UniLib* unilib,
+ const CalendarLib* calendarlib) {
+ std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
+ return FromScopedMmap(&mmap, unilib, calendarlib);
+}
+
+// Constructor taking ownership of the mmapped model file. If unilib or
+// calendarlib is null, an internally-owned instance is created instead.
+// Runs ValidateAndInitialize(); check IsInitialized() afterwards.
+Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
+ const UniLib* unilib, const CalendarLib* calendarlib)
+ : model_(model),
+ mmap_(std::move(*mmap)),
+ owned_unilib_(nullptr),
+ unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
+ owned_calendarlib_(nullptr),
+ calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
+ ValidateAndInitialize();
+}
+
+// Constructor for a caller-owned model buffer (no mmap is taken). If unilib
+// or calendarlib is null, an internally-owned instance is created instead.
+// Runs ValidateAndInitialize(); check IsInitialized() afterwards.
+Annotator::Annotator(const Model* model, const UniLib* unilib,
+ const CalendarLib* calendarlib)
+ : model_(model),
+ owned_unilib_(nullptr),
+ unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
+ owned_calendarlib_(nullptr),
+ calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
+ ValidateAndInitialize();
+}
+
+// Validates the model's structure and builds the executors, feature
+// processors and auxiliary engines required by the modes the model enables.
+// On any validation or construction error it logs and returns early, leaving
+// initialized_ == false; initialized_ is set to true only when everything
+// required by the enabled modes was built successfully.
+void Annotator::ValidateAndInitialize() {
+ initialized_ = false;
+
+ if (model_ == nullptr) {
+ TC3_LOG(ERROR) << "No model specified.";
+ return;
+ }
+
+ const bool model_enabled_for_annotation =
+ (model_->triggering_options() != nullptr &&
+ (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
+ const bool model_enabled_for_classification =
+ (model_->triggering_options() != nullptr &&
+ (model_->triggering_options()->enabled_modes() &
+ ModeFlag_CLASSIFICATION));
+ const bool model_enabled_for_selection =
+ (model_->triggering_options() != nullptr &&
+ (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));
+
+ // Annotation requires the selection model.
+ if (model_enabled_for_annotation || model_enabled_for_selection) {
+ if (!model_->selection_options()) {
+ TC3_LOG(ERROR) << "No selection options.";
+ return;
+ }
+ if (!model_->selection_feature_options()) {
+ TC3_LOG(ERROR) << "No selection feature options.";
+ return;
+ }
+ if (!model_->selection_feature_options()->bounds_sensitive_features()) {
+ TC3_LOG(ERROR) << "No selection bounds sensitive feature options.";
+ return;
+ }
+ if (!model_->selection_model()) {
+ TC3_LOG(ERROR) << "No selection model.";
+ return;
+ }
+ selection_executor_ = ModelExecutor::FromBuffer(model_->selection_model());
+ if (!selection_executor_) {
+ TC3_LOG(ERROR) << "Could not initialize selection executor.";
+ return;
+ }
+ selection_feature_processor_.reset(
+ new FeatureProcessor(model_->selection_feature_options(), unilib_));
+ }
+
+ // Annotation requires the classification model for conflict resolution and
+ // scoring.
+ // Selection requires the classification model for conflict resolution.
+ if (model_enabled_for_annotation || model_enabled_for_classification ||
+ model_enabled_for_selection) {
+ if (!model_->classification_options()) {
+ TC3_LOG(ERROR) << "No classification options.";
+ return;
+ }
+
+ if (!model_->classification_feature_options()) {
+ TC3_LOG(ERROR) << "No classification feature options.";
+ return;
+ }
+
+ if (!model_->classification_feature_options()
+ ->bounds_sensitive_features()) {
+ TC3_LOG(ERROR) << "No classification bounds sensitive feature options.";
+ return;
+ }
+ if (!model_->classification_model()) {
+ TC3_LOG(ERROR) << "No clf model.";
+ return;
+ }
+
+ classification_executor_ =
+ ModelExecutor::FromBuffer(model_->classification_model());
+ if (!classification_executor_) {
+ TC3_LOG(ERROR) << "Could not initialize classification executor.";
+ return;
+ }
+
+ classification_feature_processor_.reset(new FeatureProcessor(
+ model_->classification_feature_options(), unilib_));
+ }
+
+ // The embeddings need to be specified if the model is to be used for
+ // classification or selection.
+ if (model_enabled_for_annotation || model_enabled_for_classification ||
+ model_enabled_for_selection) {
+ if (!model_->embedding_model()) {
+ TC3_LOG(ERROR) << "No embedding model.";
+ return;
+ }
+
+ // Check that the embedding size of the selection and classification model
+ // matches, as they are using the same embeddings.
+ if (model_enabled_for_selection &&
+ (model_->selection_feature_options()->embedding_size() !=
+ model_->classification_feature_options()->embedding_size() ||
+ model_->selection_feature_options()->embedding_quantization_bits() !=
+ model_->classification_feature_options()
+ ->embedding_quantization_bits())) {
+ TC3_LOG(ERROR) << "Mismatching embedding size/quantization.";
+ return;
+ }
+
+ embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
+ model_->embedding_model(),
+ model_->classification_feature_options()->embedding_size(),
+ model_->classification_feature_options()
+ ->embedding_quantization_bits());
+ if (!embedding_executor_) {
+ TC3_LOG(ERROR) << "Could not initialize embedding executor.";
+ return;
+ }
+ }
+
+ // Regex and datetime models are optional; the shared decompressor handles
+ // patterns stored compressed in the flatbuffer.
+ std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
+ if (model_->regex_model()) {
+ if (!InitializeRegexModel(decompressor.get())) {
+ TC3_LOG(ERROR) << "Could not initialize regex model.";
+ return;
+ }
+ }
+
+ if (model_->datetime_model()) {
+ datetime_parser_ = DatetimeParser::Instance(
+ model_->datetime_model(), *unilib_, *calendarlib_, decompressor.get());
+ if (!datetime_parser_) {
+ TC3_LOG(ERROR) << "Could not initialize datetime parser.";
+ return;
+ }
+ }
+
+ // Cache the per-mode collection filters declared in the model's output
+ // options for fast lookup at request time.
+ if (model_->output_options()) {
+ if (model_->output_options()->filtered_collections_annotation()) {
+ for (const auto collection :
+ *model_->output_options()->filtered_collections_annotation()) {
+ filtered_collections_annotation_.insert(collection->str());
+ }
+ }
+ if (model_->output_options()->filtered_collections_classification()) {
+ for (const auto collection :
+ *model_->output_options()->filtered_collections_classification()) {
+ filtered_collections_classification_.insert(collection->str());
+ }
+ }
+ if (model_->output_options()->filtered_collections_selection()) {
+ for (const auto collection :
+ *model_->output_options()->filtered_collections_selection()) {
+ filtered_collections_selection_.insert(collection->str());
+ }
+ }
+ }
+
+ initialized_ = true;
+}
+
+// Compiles the regex patterns from the model (decompressing compressed ones
+// via 'decompressor') and records, per pattern id, which modes (annotation /
+// classification / selection) each pattern is enabled for. Returns false if
+// any pattern fails to compile; returns true trivially when the model has no
+// patterns.
+bool Annotator::InitializeRegexModel(ZlibDecompressor* decompressor) {
+ if (!model_->regex_model()->patterns()) {
+ return true;
+ }
+
+ // Initialize pattern recognizers.
+ int regex_pattern_id = 0;
+ for (const auto& regex_pattern : *model_->regex_model()->patterns()) {
+ std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
+ UncompressMakeRegexPattern(*unilib_, regex_pattern->pattern(),
+ regex_pattern->compressed_pattern(),
+ decompressor);
+ if (!compiled_pattern) {
+ TC3_LOG(INFO) << "Failed to load regex pattern";
+ return false;
+ }
+
+ if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
+ annotation_regex_patterns_.push_back(regex_pattern_id);
+ }
+ if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
+ classification_regex_patterns_.push_back(regex_pattern_id);
+ }
+ if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
+ selection_regex_patterns_.push_back(regex_pattern_id);
+ }
+ regex_patterns_.push_back({
+ regex_pattern->collection_name()->str(),
+ regex_pattern->target_classification_score(),
+ regex_pattern->priority_score(),
+ std::move(compiled_pattern),
+ regex_pattern->verification_options(),
+ });
+ if (regex_pattern->use_approximate_matching()) {
+ regex_approximate_match_pattern_ids_.insert(regex_pattern_id);
+ }
+ ++regex_pattern_id;
+ }
+
+ return true;
+}
+
+// Builds and initializes the optional knowledge engine from its serialized
+// config. On success replaces knowledge_engine_; on failure logs, leaves the
+// previous engine (if any) untouched, and returns false.
+bool Annotator::InitializeKnowledgeEngine(
+ const std::string& serialized_config) {
+ std::unique_ptr<KnowledgeEngine> knowledge_engine(
+ new KnowledgeEngine(unilib_));
+ if (!knowledge_engine->Initialize(serialized_config)) {
+ TC3_LOG(ERROR) << "Failed to initialize the knowledge engine.";
+ return false;
+ }
+ knowledge_engine_ = std::move(knowledge_engine);
+ return true;
+}
+
+namespace {
+
+// Counts digit characters inside the codepoint range 'selection_indices' of
+// 'str' (indices are codepoint positions, not byte offsets).
+// NOTE(review): isdigit() is called on a whole codepoint value; the C library
+// only defines it for unsigned char / EOF, so behavior for non-ASCII
+// codepoints is platform-dependent — confirm ASCII-only digits are intended.
+int CountDigits(const std::string& str, CodepointSpan selection_indices) {
+ int count = 0;
+ int i = 0;
+ const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
+ for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
+ if (i >= selection_indices.first && i < selection_indices.second &&
+ isdigit(*it)) {
+ ++count;
+ }
+ }
+ return count;
+}
+
+// Returns the UTF-8 substring of 'context' covered by the codepoint range
+// 'selection_indices'.
+std::string ExtractSelection(const std::string& context,
+ CodepointSpan selection_indices) {
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+ auto selection_begin = context_unicode.begin();
+ std::advance(selection_begin, selection_indices.first);
+ auto selection_end = context_unicode.begin();
+ std::advance(selection_end, selection_indices.second);
+ return UnicodeText::UTF8Substring(selection_begin, selection_end);
+}
+
+// Applies the optional post-match verification to a regex match. Currently
+// the only check is the Luhn checksum (e.g. for payment card numbers). A null
+// options pointer means no verification, i.e. the match is accepted.
+bool VerifyCandidate(const VerificationOptions* verification_options,
+ const std::string& match) {
+ if (!verification_options) {
+ return true;
+ }
+ if (verification_options->verify_luhn_checksum() &&
+ !VerifyLuhnChecksum(match)) {
+ return false;
+ }
+ return true;
+}
+
+} // namespace
+
+namespace internal {
+// Helper function, which if the initial 'span' contains only white-spaces,
+// moves the selection to a single-codepoint selection on the first
+// non-whitespace codepoint found to the LEFT of the span (only the left side
+// is tried, as the function name says); if none exists, or the span contains
+// any non-whitespace, the original span is returned unchanged.
+// Precondition (checked): 'span' is a valid non-empty span.
+CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
+ const UnicodeText& context_unicode,
+ const UniLib& unilib) {
+ TC3_CHECK(ValidNonEmptySpan(span));
+
+ UnicodeText::const_iterator it;
+
+ // Check that the current selection is all whitespaces.
+ it = context_unicode.begin();
+ std::advance(it, span.first);
+ for (int i = 0; i < (span.second - span.first); ++i, ++it) {
+ if (!unilib.IsWhitespace(*it)) {
+ return span;
+ }
+ }
+
+ CodepointSpan result;
+
+ // Try moving left.
+ result = span;
+ it = context_unicode.begin();
+ std::advance(it, span.first);
+ while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
+ --result.first;
+ --it;
+ }
+ result.second = result.first + 1;
+ if (!unilib.IsWhitespace(*it)) {
+ return result;
+ }
+
+ // If moving left didn't find a non-whitespace character, just return the
+ // original span.
+ return span;
+}
+} // namespace internal
+
+// Returns true if the span's top classification belongs to a collection that
+// the model's output options suppress for annotation.
+bool Annotator::FilteredForAnnotation(const AnnotatedSpan& span) const {
+ return !span.classification.empty() &&
+ filtered_collections_annotation_.find(
+ span.classification[0].collection) !=
+ filtered_collections_annotation_.end();
+}
+
+// Returns true if the classification result's collection is suppressed for
+// classification by the model's output options.
+bool Annotator::FilteredForClassification(
+ const ClassificationResult& classification) const {
+ return filtered_collections_classification_.find(classification.collection) !=
+ filtered_collections_classification_.end();
+}
+
+// Returns true if the span's top classification belongs to a collection that
+// the model's output options suppress for selection.
+bool Annotator::FilteredForSelection(const AnnotatedSpan& span) const {
+ return !span.classification.empty() &&
+ filtered_collections_selection_.find(
+ span.classification[0].collection) !=
+ filtered_collections_selection_.end();
+}
+
+// Suggests a selection span around 'click_indices' in 'context'. Gathers
+// candidate spans from the ML selection model, regex patterns, the datetime
+// parser, and (if present) the knowledge engine, resolves overlaps, and
+// returns the first surviving candidate that overlaps the click. On any
+// error, invalid input, or when selection is disabled, the original click
+// indices are returned unchanged.
+CodepointSpan Annotator::SuggestSelection(
+ const std::string& context, CodepointSpan click_indices,
+ const SelectionOptions& options) const {
+ CodepointSpan original_click_indices = click_indices;
+ if (!initialized_) {
+ TC3_LOG(ERROR) << "Not initialized";
+ return original_click_indices;
+ }
+ if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
+ return original_click_indices;
+ }
+
+ const UnicodeText context_unicode = UTF8ToUnicodeText(context,
+ /*do_copy=*/false);
+
+ if (!context_unicode.is_valid()) {
+ return original_click_indices;
+ }
+
+ const int context_codepoint_size = context_unicode.size_codepoints();
+
+ // Reject empty, inverted, or out-of-range click spans.
+ if (click_indices.first < 0 || click_indices.second < 0 ||
+ click_indices.first >= context_codepoint_size ||
+ click_indices.second > context_codepoint_size ||
+ click_indices.first >= click_indices.second) {
+ TC3_VLOG(1) << "Trying to run SuggestSelection with invalid indices: "
+ << click_indices.first << " " << click_indices.second;
+ return original_click_indices;
+ }
+
+ if (model_->snap_whitespace_selections()) {
+ // We want to expand a purely white-space selection to a multi-selection it
+ // would've been part of. But with this feature disabled we would do a no-
+ // op, because no token is found. Therefore, we need to modify the
+ // 'click_indices' a bit to include a part of the token, so that the click-
+ // finding logic finds the clicked token correctly. This modification is
+ // done by the following function. Note, that it's enough to check the left
+ // side of the current selection, because if the white-space is a part of a
+ // multi-selection, necessarily both tokens - on the left and the right
+ // sides need to be selected. Thus snapping only to the left is sufficient
+ // (there's a check at the bottom that makes sure that if we snap to the
+ // left token but the result does not contain the initial white-space,
+ // returns the original indices).
+ click_indices = internal::SnapLeftIfWhitespaceSelection(
+ click_indices, context_unicode, *unilib_);
+ }
+
+ // Collect candidate spans from all enabled sources.
+ std::vector<AnnotatedSpan> candidates;
+ InterpreterManager interpreter_manager(selection_executor_.get(),
+ classification_executor_.get());
+ std::vector<Token> tokens;
+ if (!ModelSuggestSelection(context_unicode, click_indices,
+ &interpreter_manager, &tokens, &candidates)) {
+ TC3_LOG(ERROR) << "Model suggest selection failed.";
+ return original_click_indices;
+ }
+ if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates)) {
+ TC3_LOG(ERROR) << "Regex suggest selection failed.";
+ return original_click_indices;
+ }
+ if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
+ /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
+ options.locales, ModeFlag_SELECTION, &candidates)) {
+ TC3_LOG(ERROR) << "Datetime suggest selection failed.";
+ return original_click_indices;
+ }
+ if (knowledge_engine_ && !knowledge_engine_->Chunk(context, &candidates)) {
+ TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
+ return original_click_indices;
+ }
+
+ // Sort candidates according to their position in the input, so that the next
+ // code can assume that any connected component of overlapping spans forms a
+ // contiguous block.
+ std::sort(candidates.begin(), candidates.end(),
+ [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
+ return a.span.first < b.span.first;
+ });
+
+ std::vector<int> candidate_indices;
+ if (!ResolveConflicts(candidates, context, tokens, &interpreter_manager,
+ &candidate_indices)) {
+ TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
+ return original_click_indices;
+ }
+
+ // Return the first surviving candidate overlapping both the (possibly
+ // snapped) click and the original click indices.
+ for (const int i : candidate_indices) {
+ if (SpansOverlap(candidates[i].span, click_indices) &&
+ SpansOverlap(candidates[i].span, original_click_indices)) {
+ // Run model classification if not present but requested and there's a
+ // classification collection filter specified.
+ if (candidates[i].classification.empty() &&
+ model_->selection_options()->always_classify_suggested_selection() &&
+ !filtered_collections_selection_.empty()) {
+ if (!ModelClassifyText(
+ context, candidates[i].span, &interpreter_manager,
+ /*embedding_cache=*/nullptr, &candidates[i].classification)) {
+ return original_click_indices;
+ }
+ }
+
+ // Ignore if span classification is filtered.
+ if (FilteredForSelection(candidates[i])) {
+ return original_click_indices;
+ }
+
+ return candidates[i].span;
+ }
+ }
+
+ return original_click_indices;
+}
+
+namespace {
+// Helper function that returns the index of the first candidate that
+// transitively does not overlap with the candidate on 'start_index'. If the end
+// of 'candidates' is reached, it returns the index that points right behind the
+// array.
+// Growing 'conflicting_span' to the maximum right edge seen so far is what
+// makes the overlap transitive: assumes 'candidates' is sorted by span.first.
+int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
+ int start_index) {
+ int first_non_overlapping = start_index + 1;
+ CodepointSpan conflicting_span = candidates[start_index].span;
+ while (
+ first_non_overlapping < candidates.size() &&
+ SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) {
+ // Grow the span to include the current one.
+ conflicting_span.second = std::max(
+ conflicting_span.second, candidates[first_non_overlapping].span.second);
+
+ ++first_non_overlapping;
+ }
+ return first_non_overlapping;
+}
+} // namespace
+
+// Walks the position-sorted 'candidates', splitting them into groups of
+// transitively-overlapping spans; singleton groups are accepted directly,
+// while each conflicting group is delegated to ResolveConflict(). Fills
+// 'result' with the indices of the kept candidates. Returns false only if
+// conflict resolution itself fails.
+bool Annotator::ResolveConflicts(const std::vector<AnnotatedSpan>& candidates,
+ const std::string& context,
+ const std::vector<Token>& cached_tokens,
+ InterpreterManager* interpreter_manager,
+ std::vector<int>* result) const {
+ result->clear();
+ result->reserve(candidates.size());
+ for (int i = 0; i < candidates.size();) {
+ int first_non_overlapping =
+ FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i);
+
+ const bool conflict_found = first_non_overlapping != (i + 1);
+ if (conflict_found) {
+ std::vector<int> candidate_indices;
+ if (!ResolveConflict(context, cached_tokens, candidates, i,
+ first_non_overlapping, interpreter_manager,
+ &candidate_indices)) {
+ return false;
+ }
+ result->insert(result->end(), candidate_indices.begin(),
+ candidate_indices.end());
+ } else {
+ result->push_back(i);
+ }
+
+ // Skip over the whole conflicting group/go to next candidate.
+ i = first_non_overlapping;
+ }
+ return true;
+}
+
+namespace {
+// Returns true if the top classification result is the "other" collection.
+inline bool ClassifiedAsOther(
+ const std::vector<ClassificationResult>& classification) {
+ return !classification.empty() &&
+ classification[0].collection == Annotator::kOtherCollection;
+}
+
+// Returns the priority score of the top classification result, or -1.0 for
+// results classified as "other".
+// Precondition: 'classification' must be non-empty — an empty vector makes
+// ClassifiedAsOther() false and the [0] access below invalid; all callers in
+// this file guard with !classification.empty() before calling.
+float GetPriorityScore(
+ const std::vector<ClassificationResult>& classification) {
+ if (!ClassifiedAsOther(classification)) {
+ return classification[0].priority_score;
+ } else {
+ return -1.0;
+ }
+}
+} // namespace
+
+// Resolves one group of transitively-overlapping candidates
+// [start_index, end_index): scores every candidate (classifying lazily with
+// the ML model when no classification is attached yet), then greedily keeps
+// the highest-priority candidates that do not conflict with already-kept
+// ones. 'chosen_indices' is filled in span order. Returns false only if the
+// lazy classification fails.
+bool Annotator::ResolveConflict(const std::string& context,
+ const std::vector<Token>& cached_tokens,
+ const std::vector<AnnotatedSpan>& candidates,
+ int start_index, int end_index,
+ InterpreterManager* interpreter_manager,
+ std::vector<int>* chosen_indices) const {
+ std::vector<int> conflicting_indices;
+ std::unordered_map<int, float> scores;
+ for (int i = start_index; i < end_index; ++i) {
+ conflicting_indices.push_back(i);
+ if (!candidates[i].classification.empty()) {
+ scores[i] = GetPriorityScore(candidates[i].classification);
+ continue;
+ }
+
+ // OPTIMIZATION: So that we don't have to classify all the ML model
+ // spans apriori, we wait until we get here, when they conflict with
+ // something and we need the actual classification scores. So if the
+ // candidate conflicts and comes from the model, we need to run a
+ // classification to determine its priority:
+ std::vector<ClassificationResult> classification;
+ if (!ModelClassifyText(context, cached_tokens, candidates[i].span,
+ interpreter_manager,
+ /*embedding_cache=*/nullptr, &classification)) {
+ return false;
+ }
+
+ if (!classification.empty()) {
+ scores[i] = GetPriorityScore(classification);
+ }
+ }
+
+ // Highest priority first; unscored candidates get the map's default 0.0.
+ std::sort(conflicting_indices.begin(), conflicting_indices.end(),
+ [&scores](int i, int j) { return scores[i] > scores[j]; });
+
+ // Keeps the candidates sorted by their position in the text (their left span
+ // index) for fast retrieval down.
+ std::set<int, std::function<bool(int, int)>> chosen_indices_set(
+ [&candidates](int a, int b) {
+ return candidates[a].span.first < candidates[b].span.first;
+ });
+
+ // Greedily place the candidates if they don't conflict with the already
+ // placed ones.
+ for (int i = 0; i < conflicting_indices.size(); ++i) {
+ const int considered_candidate = conflicting_indices[i];
+ if (!DoesCandidateConflict(considered_candidate, candidates,
+ chosen_indices_set)) {
+ chosen_indices_set.insert(considered_candidate);
+ }
+ }
+
+ *chosen_indices =
+ std::vector<int>(chosen_indices_set.begin(), chosen_indices_set.end());
+
+ return true;
+}
+
+// Produces selection candidates from the ML selection model around the
+// clicked token: tokenizes the context, locates the click, extracts features
+// over a bounded token window, runs the selection model in chunks, and
+// appends non-empty stripped spans to 'result'. 'tokens' is an output so the
+// caller can reuse the tokenization. No-op (returns true) when the selection
+// mode is disabled or too few codepoints are supported; returns false on
+// click-resolution, feature-extraction, or chunking errors.
+bool Annotator::ModelSuggestSelection(
+ const UnicodeText& context_unicode, CodepointSpan click_indices,
+ InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
+ std::vector<AnnotatedSpan>* result) const {
+ if (model_->triggering_options() == nullptr ||
+ !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
+ return true;
+ }
+
+ int click_pos;
+ *tokens = selection_feature_processor_->Tokenize(context_unicode);
+ selection_feature_processor_->RetokenizeAndFindClick(
+ context_unicode, click_indices,
+ selection_feature_processor_->GetOptions()->only_use_line_with_click(),
+ tokens, &click_pos);
+ if (click_pos == kInvalidIndex) {
+ TC3_VLOG(1) << "Could not calculate the click position.";
+ return false;
+ }
+
+ const int symmetry_context_size =
+ model_->selection_options()->symmetry_context_size();
+ const FeatureProcessorOptions_::BoundsSensitiveFeatures*
+ bounds_sensitive_features = selection_feature_processor_->GetOptions()
+ ->bounds_sensitive_features();
+
+ // The symmetry context span is the clicked token with symmetry_context_size
+ // tokens on either side.
+ const TokenSpan symmetry_context_span = IntersectTokenSpans(
+ ExpandTokenSpan(SingleTokenSpan(click_pos),
+ /*num_tokens_left=*/symmetry_context_size,
+ /*num_tokens_right=*/symmetry_context_size),
+ {0, tokens->size()});
+
+ // Compute the extraction span based on the model type.
+ TokenSpan extraction_span;
+ if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
+ // The extraction span is the symmetry context span expanded to include
+ // max_selection_span tokens on either side, which is how far a selection
+ // can stretch from the click, plus a relevant number of tokens outside of
+ // the bounds of the selection.
+ const int max_selection_span =
+ selection_feature_processor_->GetOptions()->max_selection_span();
+ extraction_span =
+ ExpandTokenSpan(symmetry_context_span,
+ /*num_tokens_left=*/max_selection_span +
+ bounds_sensitive_features->num_tokens_before(),
+ /*num_tokens_right=*/max_selection_span +
+ bounds_sensitive_features->num_tokens_after());
+ } else {
+ // The extraction span is the symmetry context span expanded to include
+ // context_size tokens on either side.
+ const int context_size =
+ selection_feature_processor_->GetOptions()->context_size();
+ extraction_span = ExpandTokenSpan(symmetry_context_span,
+ /*num_tokens_left=*/context_size,
+ /*num_tokens_right=*/context_size);
+ }
+ extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
+
+ // Bail out silently (no candidates) if the text is mostly unsupported
+ // codepoints.
+ if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
+ *tokens, extraction_span)) {
+ return true;
+ }
+
+ std::unique_ptr<CachedFeatures> cached_features;
+ if (!selection_feature_processor_->ExtractFeatures(
+ *tokens, extraction_span,
+ /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
+ embedding_executor_.get(),
+ /*embedding_cache=*/nullptr,
+ selection_feature_processor_->EmbeddingSize() +
+ selection_feature_processor_->DenseFeaturesCount(),
+ &cached_features)) {
+ TC3_LOG(ERROR) << "Could not extract features.";
+ return false;
+ }
+
+ // Produce selection model candidates.
+ std::vector<TokenSpan> chunks;
+ if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
+ interpreter_manager->SelectionInterpreter(), *cached_features,
+ &chunks)) {
+ TC3_LOG(ERROR) << "Could not chunk.";
+ return false;
+ }
+
+ for (const TokenSpan& chunk : chunks) {
+ AnnotatedSpan candidate;
+ candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
+ context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
+ if (model_->selection_options()->strip_unpaired_brackets()) {
+ candidate.span =
+ StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
+ }
+
+ // Only output non-empty spans.
+ if (candidate.span.first != candidate.span.second) {
+ result->push_back(candidate);
+ }
+ }
+ return true;
+}
+
+// Convenience overload without pre-tokenized text: no-op (returns true,
+// leaving 'classification_results' untouched) when the classification mode is
+// disabled, otherwise forwards to the overload taking cached tokens with an
+// empty token list (which triggers re-tokenization there).
+bool Annotator::ModelClassifyText(
+ const std::string& context, CodepointSpan selection_indices,
+ InterpreterManager* interpreter_manager,
+ FeatureProcessor::EmbeddingCache* embedding_cache,
+ std::vector<ClassificationResult>* classification_results) const {
+ if (model_->triggering_options() == nullptr ||
+ !(model_->triggering_options()->enabled_modes() &
+ ModeFlag_CLASSIFICATION)) {
+ return true;
+ }
+ return ModelClassifyText(context, {}, selection_indices, interpreter_manager,
+ embedding_cache, classification_results);
+}
+
+namespace internal {
+// Copies from 'cached_tokens' the tokens overlapping the codepoint range
+// 'selection_indices', plus up to tokens_around_selection_to_copy.first
+// tokens before and .second tokens after, clamped to the vector bounds.
+// Binary searches require 'cached_tokens' to be ordered by token position.
+std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
+ CodepointSpan selection_indices,
+ TokenSpan tokens_around_selection_to_copy) {
+ // First token whose end lies past the selection start.
+ const auto first_selection_token = std::upper_bound(
+ cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
+ [](int selection_start, const Token& token) {
+ return selection_start < token.end;
+ });
+ // First token starting at or after the selection end.
+ const auto last_selection_token = std::lower_bound(
+ cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
+ [](const Token& token, int selection_end) {
+ return token.start < selection_end;
+ });
+
+ const int64 first_token = std::max(
+ static_cast<int64>(0),
+ static_cast<int64>((first_selection_token - cached_tokens.begin()) -
+ tokens_around_selection_to_copy.first));
+ const int64 last_token = std::min(
+ static_cast<int64>(cached_tokens.size()),
+ static_cast<int64>((last_selection_token - cached_tokens.begin()) +
+ tokens_around_selection_to_copy.second));
+
+ std::vector<Token> tokens;
+ tokens.reserve(last_token - first_token);
+ for (int i = first_token; i < last_token; ++i) {
+ tokens.push_back(cached_tokens[i]);
+ }
+ return tokens;
+}
+} // namespace internal
+
+// Returns an upper bound on the number of tokens ClassifyText may need on
+// either side of the selection, depending on the classification model type.
+TokenSpan Annotator::ClassifyTextUpperBoundNeededTokens() const {
+  const auto* bounds_sensitive_features =
+      classification_feature_processor_->GetOptions()
+          ->bounds_sensitive_features();
+  if (bounds_sensitive_features != nullptr &&
+      bounds_sensitive_features->enabled()) {
+    // Bounds-sensitive models look at a fixed number of tokens outside of the
+    // bounds of the selection.
+    return {bounds_sensitive_features->num_tokens_before(),
+            bounds_sensitive_features->num_tokens_after()};
+  }
+  // Click-context models look at context_size tokens on either side of the
+  // clicked token.
+  const int context_size =
+      selection_feature_processor_->GetOptions()->context_size();
+  return {context_size, context_size};
+}
+
+// Runs the TFLite classification model on the given selection.
+//
+// Tokenizes the context (or reuses |cached_tokens|), extracts features around
+// the selection, runs the classification executor and fills
+// |classification_results| with per-collection results sorted by descending
+// score. Returns false on hard errors; returns true with a single "other"
+// result when a sanity check rejects the input.
+bool Annotator::ModelClassifyText(
+    const std::string& context, const std::vector<Token>& cached_tokens,
+    CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
+    FeatureProcessor::EmbeddingCache* embedding_cache,
+    std::vector<ClassificationResult>* classification_results) const {
+  // Reuse tokens from a previous tokenization when provided; otherwise
+  // tokenize the context from scratch.
+  std::vector<Token> tokens;
+  if (cached_tokens.empty()) {
+    tokens = classification_feature_processor_->Tokenize(context);
+  } else {
+    tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
+                                        ClassifyTextUpperBoundNeededTokens());
+  }
+
+  int click_pos;
+  classification_feature_processor_->RetokenizeAndFindClick(
+      context, selection_indices,
+      classification_feature_processor_->GetOptions()
+          ->only_use_line_with_click(),
+      &tokens, &click_pos);
+  const TokenSpan selection_token_span =
+      CodepointSpanToTokenSpan(tokens, selection_indices);
+  const int selection_num_tokens = TokenSpanSize(selection_token_span);
+  // Selections longer than the configured token limit are classified as
+  // "other" without running the model.
+  if (model_->classification_options()->max_num_tokens() > 0 &&
+      model_->classification_options()->max_num_tokens() <
+          selection_num_tokens) {
+    *classification_results = {{kOtherCollection, 1.0}};
+    return true;
+  }
+
+  const FeatureProcessorOptions_::BoundsSensitiveFeatures*
+      bounds_sensitive_features =
+          classification_feature_processor_->GetOptions()
+              ->bounds_sensitive_features();
+  if (selection_token_span.first == kInvalidIndex ||
+      selection_token_span.second == kInvalidIndex) {
+    TC3_LOG(ERROR) << "Could not determine span.";
+    return false;
+  }
+
+  // Compute the extraction span based on the model type.
+  TokenSpan extraction_span;
+  if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
+    // The extraction span is the selection span expanded to include a relevant
+    // number of tokens outside of the bounds of the selection.
+    extraction_span = ExpandTokenSpan(
+        selection_token_span,
+        /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
+        /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
+  } else {
+    if (click_pos == kInvalidIndex) {
+      TC3_LOG(ERROR) << "Couldn't choose a click position.";
+      return false;
+    }
+    // The extraction span is the clicked token with context_size tokens on
+    // either side.
+    const int context_size =
+        classification_feature_processor_->GetOptions()->context_size();
+    extraction_span = ExpandTokenSpan(SingleTokenSpan(click_pos),
+                                      /*num_tokens_left=*/context_size,
+                                      /*num_tokens_right=*/context_size);
+  }
+  extraction_span = IntersectTokenSpans(extraction_span, {0, tokens.size()});
+
+  // Bail out to "other" when too few codepoints in the extraction span are
+  // supported by the model.
+  if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
+          tokens, extraction_span)) {
+    *classification_results = {{kOtherCollection, 1.0}};
+    return true;
+  }
+
+  std::unique_ptr<CachedFeatures> cached_features;
+  if (!classification_feature_processor_->ExtractFeatures(
+          tokens, extraction_span, selection_indices, embedding_executor_.get(),
+          embedding_cache,
+          classification_feature_processor_->EmbeddingSize() +
+              classification_feature_processor_->DenseFeaturesCount(),
+          &cached_features)) {
+    TC3_LOG(ERROR) << "Could not extract features.";
+    return false;
+  }
+
+  // Assemble the model input features for this single example.
+  std::vector<float> features;
+  features.reserve(cached_features->OutputFeaturesSize());
+  if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
+    cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
+                                                          &features);
+  } else {
+    cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
+  }
+
+  TensorView<float> logits = classification_executor_->ComputeLogits(
+      TensorView<float>(features.data(),
+                        {1, static_cast<int>(features.size())}),
+      interpreter_manager->ClassificationInterpreter());
+  if (!logits.is_valid()) {
+    TC3_LOG(ERROR) << "Couldn't compute logits.";
+    return false;
+  }
+
+  // The output must be a [1, num_collections] tensor.
+  if (logits.dims() != 2 || logits.dim(0) != 1 ||
+      logits.dim(1) != classification_feature_processor_->NumCollections()) {
+    TC3_LOG(ERROR) << "Mismatching output";
+    return false;
+  }
+
+  const std::vector<float> scores =
+      ComputeSoftmax(logits.data(), logits.dim(1));
+
+  // One result per collection, sorted by descending score.
+  classification_results->resize(scores.size());
+  for (int i = 0; i < scores.size(); i++) {
+    (*classification_results)[i] = {
+        classification_feature_processor_->LabelToCollection(i), scores[i]};
+  }
+  std::sort(classification_results->begin(), classification_results->end(),
+            [](const ClassificationResult& a, const ClassificationResult& b) {
+              return a.score > b.score;
+            });
+
+  // Phone class sanity check: the digit count must be within the configured
+  // bounds, otherwise the result is demoted to "other".
+  if (!classification_results->empty() &&
+      classification_results->begin()->collection == kPhoneCollection) {
+    const int digit_count = CountDigits(context, selection_indices);
+    if (digit_count <
+            model_->classification_options()->phone_min_num_digits() ||
+        digit_count >
+            model_->classification_options()->phone_max_num_digits()) {
+      *classification_results = {{kOtherCollection, 1.0}};
+    }
+  }
+
+  // Address class sanity check: the selection must span enough tokens,
+  // otherwise the result is demoted to "other".
+  if (!classification_results->empty() &&
+      classification_results->begin()->collection == kAddressCollection) {
+    if (selection_num_tokens <
+            model_->classification_options()->address_min_num_tokens()) {
+      *classification_results = {{kOtherCollection, 1.0}};
+    }
+  }
+
+  return true;
+}
+
+// Tries to classify the selection with the regular expression models.
+// Returns true and fills |classification_result| when some pattern matches
+// and passes verification; returns false otherwise (including on matcher
+// creation or matching errors).
+bool Annotator::RegexClassifyText(
+    const std::string& context, CodepointSpan selection_indices,
+    ClassificationResult* classification_result) const {
+  const std::string selection_text =
+      ExtractSelection(context, selection_indices);
+  const UnicodeText selection_text_unicode(
+      UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
+
+  // Check whether any of the regular expressions match.
+  for (const int pattern_id : classification_regex_patterns_) {
+    const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
+    const std::unique_ptr<UniLib::RegexMatcher> matcher =
+        regex_pattern.pattern->Matcher(selection_text_unicode);
+    // Guard against matcher creation failure, consistent with RegexChunk.
+    if (!matcher) {
+      TC3_LOG(ERROR) << "Could not get regex matcher for pattern: "
+                     << pattern_id;
+      return false;
+    }
+    int status = UniLib::RegexMatcher::kNoError;
+    bool matches;
+    // Some patterns are configured to match approximately.
+    if (regex_approximate_match_pattern_ids_.find(pattern_id) !=
+        regex_approximate_match_pattern_ids_.end()) {
+      matches = matcher->ApproximatelyMatches(&status);
+    } else {
+      matches = matcher->Matches(&status);
+    }
+    // Previously a second, unreachable status check (after the early return
+    // below) held the error log; log here instead, where it can fire.
+    if (status != UniLib::RegexMatcher::kNoError) {
+      TC3_LOG(ERROR) << "Couldn't match regex: " << pattern_id;
+      return false;
+    }
+    if (matches &&
+        VerifyCandidate(regex_pattern.verification_options, selection_text)) {
+      *classification_result = {regex_pattern.collection_name,
+                                regex_pattern.target_classification_score,
+                                regex_pattern.priority_score};
+      return true;
+    }
+  }
+
+  return false;
+}
+
+// Classifies the selection using the datetime parser. Only succeeds when a
+// parsed datetime span exactly covers the selection; on success the result
+// carries the "date" collection and the parsed datetime data.
+bool Annotator::DatetimeClassifyText(
+    const std::string& context, CodepointSpan selection_indices,
+    const ClassificationOptions& options,
+    ClassificationResult* classification_result) const {
+  // Without a datetime parser there is nothing to classify with.
+  if (!datetime_parser_) {
+    return false;
+  }
+
+  const std::string selection_text =
+      ExtractSelection(context, selection_indices);
+
+  std::vector<DatetimeParseResultSpan> datetime_spans;
+  if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
+                               options.reference_timezone, options.locales,
+                               ModeFlag_CLASSIFICATION,
+                               /*anchor_start_end=*/true, &datetime_spans)) {
+    TC3_LOG(ERROR) << "Error during parsing datetime.";
+    return false;
+  }
+  for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
+    // Only consider the result valid if the selection and extracted datetime
+    // spans exactly match. The parsed span is relative to |selection_text|,
+    // so shift it back into context coordinates before comparing.
+    if (std::make_pair(datetime_span.span.first + selection_indices.first,
+                       datetime_span.span.second + selection_indices.first) ==
+        selection_indices) {
+      *classification_result = {kDateCollection,
+                                datetime_span.target_classification_score};
+      classification_result->datetime_parse_result = datetime_span.data;
+      return true;
+    }
+  }
+  return false;
+}
+
+// Classifies the text within |selection_indices|. The engines are consulted
+// in priority order: knowledge engine, regex models, datetime model, and
+// finally the TFLite classification model. The first engine that produces a
+// result wins; a result whose collection is filtered out is replaced by a
+// single "other" result. Returns an empty vector on error or no result.
+std::vector<ClassificationResult> Annotator::ClassifyText(
+    const std::string& context, CodepointSpan selection_indices,
+    const ClassificationOptions& options) const {
+  if (!initialized_) {
+    TC3_LOG(ERROR) << "Not initialized";
+    return {};
+  }
+
+  // Nothing to do when the classification mode is disabled in the model.
+  if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
+    return {};
+  }
+
+  // Input must be valid UTF-8.
+  if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) {
+    return {};
+  }
+
+  // The selection must be a non-empty, correctly ordered span.
+  if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) {
+    TC3_VLOG(1) << "Trying to run ClassifyText with invalid indices: "
+                << std::get<0>(selection_indices) << " "
+                << std::get<1>(selection_indices);
+    return {};
+  }
+
+  // Try the knowledge engine.
+  ClassificationResult knowledge_result;
+  if (knowledge_engine_ && knowledge_engine_->ClassifyText(
+                               context, selection_indices, &knowledge_result)) {
+    if (!FilteredForClassification(knowledge_result)) {
+      return {knowledge_result};
+    } else {
+      return {{kOtherCollection, 1.0}};
+    }
+  }
+
+  // Try the regular expression models.
+  ClassificationResult regex_result;
+  if (RegexClassifyText(context, selection_indices, &regex_result)) {
+    if (!FilteredForClassification(regex_result)) {
+      return {regex_result};
+    } else {
+      return {{kOtherCollection, 1.0}};
+    }
+  }
+
+  // Try the date model.
+  ClassificationResult datetime_result;
+  if (DatetimeClassifyText(context, selection_indices, options,
+                           &datetime_result)) {
+    if (!FilteredForClassification(datetime_result)) {
+      return {datetime_result};
+    } else {
+      return {{kOtherCollection, 1.0}};
+    }
+  }
+
+  // Fallback to the model.
+  std::vector<ClassificationResult> model_result;
+
+  InterpreterManager interpreter_manager(selection_executor_.get(),
+                                         classification_executor_.get());
+  if (ModelClassifyText(context, selection_indices, &interpreter_manager,
+                        /*embedding_cache=*/nullptr, &model_result) &&
+      !model_result.empty()) {
+    if (!FilteredForClassification(model_result[0])) {
+      return model_result;
+    } else {
+      return {{kOtherCollection, 1.0}};
+    }
+  }
+
+  // No classifications.
+  return {};
+}
+
+// Runs the selection model over the context and appends annotated candidate
+// spans to |result|. Depending on the model options the context is processed
+// as a whole or split into lines. Each produced chunk is classified via
+// ModelClassifyText; chunks classified as "other" or scoring below the
+// model's annotation confidence threshold are dropped. The tokens of the
+// last processed line are left in |tokens| for the caller's reuse.
+bool Annotator::ModelAnnotate(const std::string& context,
+                              InterpreterManager* interpreter_manager,
+                              std::vector<Token>* tokens,
+                              std::vector<AnnotatedSpan>* result) const {
+  // Nothing to do when the model has annotation triggering disabled.
+  if (model_->triggering_options() == nullptr ||
+      !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
+    return true;
+  }
+
+  const UnicodeText context_unicode = UTF8ToUnicodeText(context,
+                                                        /*do_copy=*/false);
+  // Process the whole context at once, or line by line when the model only
+  // uses the line with the click.
+  std::vector<UnicodeTextRange> lines;
+  if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
+    lines.push_back({context_unicode.begin(), context_unicode.end()});
+  } else {
+    lines = selection_feature_processor_->SplitContext(context_unicode);
+  }
+
+  // triggering_options() was null-checked above; this recheck is defensive.
+  const float min_annotate_confidence =
+      (model_->triggering_options() != nullptr
+           ? model_->triggering_options()->min_annotate_confidence()
+           : 0.f);
+
+  FeatureProcessor::EmbeddingCache embedding_cache;
+  for (const UnicodeTextRange& line : lines) {
+    const std::string line_str =
+        UnicodeText::UTF8Substring(line.first, line.second);
+
+    *tokens = selection_feature_processor_->Tokenize(line_str);
+    selection_feature_processor_->RetokenizeAndFindClick(
+        line_str, {0, std::distance(line.first, line.second)},
+        selection_feature_processor_->GetOptions()->only_use_line_with_click(),
+        tokens,
+        /*click_pos=*/nullptr);
+    const TokenSpan full_line_span = {0, tokens->size()};
+
+    // TODO(zilka): Add support for greater granularity of this check.
+    if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
+            *tokens, full_line_span)) {
+      continue;
+    }
+
+    std::unique_ptr<CachedFeatures> cached_features;
+    if (!selection_feature_processor_->ExtractFeatures(
+            *tokens, full_line_span,
+            /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
+            embedding_executor_.get(),
+            /*embedding_cache=*/nullptr,
+            selection_feature_processor_->EmbeddingSize() +
+                selection_feature_processor_->DenseFeaturesCount(),
+            &cached_features)) {
+      TC3_LOG(ERROR) << "Could not extract features.";
+      return false;
+    }
+
+    // Produce chunk candidates spanning the whole line.
+    std::vector<TokenSpan> local_chunks;
+    if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span,
+                    interpreter_manager->SelectionInterpreter(),
+                    *cached_features, &local_chunks)) {
+      TC3_LOG(ERROR) << "Could not chunk.";
+      return false;
+    }
+
+    // Codepoint offset of this line within the full context.
+    const int offset = std::distance(context_unicode.begin(), line.first);
+    for (const TokenSpan& chunk : local_chunks) {
+      const CodepointSpan codepoint_span =
+          selection_feature_processor_->StripBoundaryCodepoints(
+              line_str, TokenSpanToCodepointSpan(*tokens, chunk));
+
+      // Skip empty spans.
+      if (codepoint_span.first != codepoint_span.second) {
+        std::vector<ClassificationResult> classification;
+        if (!ModelClassifyText(line_str, *tokens, codepoint_span,
+                               interpreter_manager, &embedding_cache,
+                               &classification)) {
+          TC3_LOG(ERROR) << "Could not classify text: "
+                         << (codepoint_span.first + offset) << " "
+                         << (codepoint_span.second + offset);
+          return false;
+        }
+
+        // Do not include the span if it's classified as "other".
+        if (!classification.empty() && !ClassifiedAsOther(classification) &&
+            classification[0].score >= min_annotate_confidence) {
+          AnnotatedSpan result_span;
+          // Translate the line-relative span into context coordinates.
+          result_span.span = {codepoint_span.first + offset,
+                              codepoint_span.second + offset};
+          result_span.classification = std::move(classification);
+          result->push_back(std::move(result_span));
+        }
+      }
+    }
+  }
+  return true;
+}
+
+// Test-only accessor for the selection feature processor.
+const FeatureProcessor* Annotator::SelectionFeatureProcessorForTests() const {
+  return selection_feature_processor_.get();
+}
+
+// Test-only accessor for the classification feature processor.
+const FeatureProcessor* Annotator::ClassificationFeatureProcessorForTests()
+    const {
+  return classification_feature_processor_.get();
+}
+
+// Test-only accessor for the datetime parser.
+const DatetimeParser* Annotator::DatetimeParserForTests() const {
+  return datetime_parser_.get();
+}
+
+// Annotates the given context: collects candidates from the selection model,
+// the regex models, the datetime model and the knowledge engine, resolves
+// conflicts between overlapping spans and returns the surviving annotations
+// in left-to-right order. Returns an empty vector on error or when the
+// annotation mode is disabled.
+std::vector<AnnotatedSpan> Annotator::Annotate(
+    const std::string& context, const AnnotationOptions& options) const {
+  std::vector<AnnotatedSpan> candidates;
+
+  if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
+    return {};
+  }
+
+  // Input must be valid UTF-8.
+  if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) {
+    return {};
+  }
+
+  InterpreterManager interpreter_manager(selection_executor_.get(),
+                                         classification_executor_.get());
+  // Annotate with the selection model.
+  std::vector<Token> tokens;
+  if (!ModelAnnotate(context, &interpreter_manager, &tokens, &candidates)) {
+    TC3_LOG(ERROR) << "Couldn't run ModelAnnotate.";
+    return {};
+  }
+
+  // Annotate with the regular expression models.
+  if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
+                  annotation_regex_patterns_, &candidates)) {
+    TC3_LOG(ERROR) << "Couldn't run RegexChunk.";
+    return {};
+  }
+
+  // Annotate with the datetime model.
+  if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
+                     options.reference_time_ms_utc, options.reference_timezone,
+                     options.locales, ModeFlag_ANNOTATION, &candidates)) {
+    // Fixed copy-paste: this failure comes from DatetimeChunk, not RegexChunk.
+    TC3_LOG(ERROR) << "Couldn't run DatetimeChunk.";
+    return {};
+  }
+
+  // Annotate with the knowledge engine.
+  if (knowledge_engine_ && !knowledge_engine_->Chunk(context, &candidates)) {
+    TC3_LOG(ERROR) << "Couldn't run knowledge engine Chunk.";
+    return {};
+  }
+
+  // Sort candidates according to their position in the input, so that the next
+  // code can assume that any connected component of overlapping spans forms a
+  // contiguous block.
+  std::sort(candidates.begin(), candidates.end(),
+            [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
+              return a.span.first < b.span.first;
+            });
+
+  // Resolve conflicts between overlapping candidates.
+  std::vector<int> candidate_indices;
+  if (!ResolveConflicts(candidates, context, tokens, &interpreter_manager,
+                        &candidate_indices)) {
+    TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
+    return {};
+  }
+
+  // Keep only candidates that are classified, not "other", and not filtered.
+  std::vector<AnnotatedSpan> result;
+  result.reserve(candidate_indices.size());
+  for (const int i : candidate_indices) {
+    if (!candidates[i].classification.empty() &&
+        !ClassifiedAsOther(candidates[i].classification) &&
+        !FilteredForAnnotation(candidates[i])) {
+      result.push_back(std::move(candidates[i]));
+    }
+  }
+
+  return result;
+}
+
+// Runs the given regex annotation |rules| over the context and appends one
+// candidate span per match, carrying the pattern's collection and scores.
+bool Annotator::RegexChunk(const UnicodeText& context_unicode,
+                           const std::vector<int>& rules,
+                           std::vector<AnnotatedSpan>* result) const {
+  for (int pattern_id : rules) {
+    const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
+    const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
+    if (!matcher) {
+      TC3_LOG(ERROR) << "Could not get regex matcher for pattern: "
+                     << pattern_id;
+      return false;
+    }
+
+    int status = UniLib::RegexMatcher::kNoError;
+    // NOTE(review): errors set in |status| by Group/Start/End below are not
+    // checked before the returned values are used — presumably UniLib returns
+    // safe values on error; confirm.
+    while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
+      if (regex_pattern.verification_options) {
+        if (!VerifyCandidate(regex_pattern.verification_options,
+                             matcher->Group(1, &status).ToUTF8String())) {
+          continue;
+        }
+      }
+      result->emplace_back();
+      // Selection/annotation regular expressions need to specify a capturing
+      // group specifying the selection.
+      result->back().span = {matcher->Start(1, &status),
+                             matcher->End(1, &status)};
+      result->back().classification = {
+          {regex_pattern.collection_name,
+           regex_pattern.target_classification_score,
+           regex_pattern.priority_score}};
+    }
+  }
+  return true;
+}
+
+// Produces a set of disjoint candidate token spans (chunks) for the span of
+// interest: scores candidate spans with the selection model, then greedily
+// keeps the highest-scoring spans that do not overlap anything already kept.
+bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
+                           tflite::Interpreter* selection_interpreter,
+                           const CachedFeatures& cached_features,
+                           std::vector<TokenSpan>* chunks) const {
+  const int max_selection_span =
+      selection_feature_processor_->GetOptions()->max_selection_span();
+  // The inference span is the span of interest expanded to include
+  // max_selection_span tokens on either side, which is how far a selection can
+  // stretch from the click.
+  const TokenSpan inference_span = IntersectTokenSpans(
+      ExpandTokenSpan(span_of_interest,
+                      /*num_tokens_left=*/max_selection_span,
+                      /*num_tokens_right=*/max_selection_span),
+      {0, num_tokens});
+
+  // Score the candidates with the method matching the model type.
+  std::vector<ScoredChunk> scored_chunks;
+  if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
+      selection_feature_processor_->GetOptions()
+          ->bounds_sensitive_features()
+          ->enabled()) {
+    if (!ModelBoundsSensitiveScoreChunks(
+            num_tokens, span_of_interest, inference_span, cached_features,
+            selection_interpreter, &scored_chunks)) {
+      return false;
+    }
+  } else {
+    if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
+                                      cached_features, selection_interpreter,
+                                      &scored_chunks)) {
+      return false;
+    }
+  }
+  // Sorting the reversed range ascending leaves the forward sequence in
+  // descending order of score.
+  std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
+            [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
+              return lhs.score < rhs.score;
+            });
+
+  // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
+  // them greedily as long as they do not overlap with any previously picked
+  // chunks.
+  std::vector<bool> token_used(TokenSpanSize(inference_span));
+  chunks->clear();
+  for (const ScoredChunk& scored_chunk : scored_chunks) {
+    bool feasible = true;
+    for (int i = scored_chunk.token_span.first;
+         i < scored_chunk.token_span.second; ++i) {
+      if (token_used[i - inference_span.first]) {
+        feasible = false;
+        break;
+      }
+    }
+
+    if (!feasible) {
+      continue;
+    }
+
+    for (int i = scored_chunk.token_span.first;
+         i < scored_chunk.token_span.second; ++i) {
+      token_used[i - inference_span.first] = true;
+    }
+
+    chunks->push_back(scored_chunk.token_span);
+  }
+
+  // Return the picked chunks in left-to-right order.
+  std::sort(chunks->begin(), chunks->end());
+
+  return true;
+}
+
+namespace {
+// Updates the value at the given key in the map to maximum of the current value
+// and the given value, or simply inserts the value if the key is not yet there.
+template <typename Map>
+void UpdateMax(Map* map, typename Map::key_type key,
+               typename Map::mapped_type value) {
+  const auto insertion = map->insert({key, value});
+  if (!insertion.second) {
+    // The key was already present; keep the larger of the two values.
+    insertion.first->second = std::max(insertion.first->second, value);
+  }
+}
+}  // namespace
+
+// Scores candidate chunks for click-context models: every token in the span
+// of interest is treated as a potential click, the selection model is run in
+// batches over these clicks, and each predicted selection label is mapped to
+// a concrete token span. The best score seen per span is kept.
+bool Annotator::ModelClickContextScoreChunks(
+    int num_tokens, const TokenSpan& span_of_interest,
+    const CachedFeatures& cached_features,
+    tflite::Interpreter* selection_interpreter,
+    std::vector<ScoredChunk>* scored_chunks) const {
+  const int max_batch_size = model_->selection_options()->batch_size();
+
+  std::vector<float> all_features;
+  // Maps a candidate token span to the best score observed for it.
+  std::map<TokenSpan, float> chunk_scores;
+  for (int batch_start = span_of_interest.first;
+       batch_start < span_of_interest.second; batch_start += max_batch_size) {
+    const int batch_end =
+        std::min(batch_start + max_batch_size, span_of_interest.second);
+
+    // Prepare features for the whole batch.
+    all_features.clear();
+    all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
+    for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
+      cached_features.AppendClickContextFeaturesForClick(click_pos,
+                                                         &all_features);
+    }
+
+    // Run batched inference.
+    const int batch_size = batch_end - batch_start;
+    const int features_size = cached_features.OutputFeaturesSize();
+    TensorView<float> logits = selection_executor_->ComputeLogits(
+        TensorView<float>(all_features.data(), {batch_size, features_size}),
+        selection_interpreter);
+    if (!logits.is_valid()) {
+      TC3_LOG(ERROR) << "Couldn't compute logits.";
+      return false;
+    }
+    if (logits.dims() != 2 || logits.dim(0) != batch_size ||
+        logits.dim(1) !=
+            selection_feature_processor_->GetSelectionLabelCount()) {
+      TC3_LOG(ERROR) << "Mismatching output.";
+      return false;
+    }
+
+    // Save results.
+    for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
+      const std::vector<float> scores = ComputeSoftmax(
+          logits.data() + logits.dim(1) * (click_pos - batch_start),
+          logits.dim(1));
+      for (int j = 0;
+           j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
+        TokenSpan relative_token_span;
+        if (!selection_feature_processor_->LabelToTokenSpan(
+                j, &relative_token_span)) {
+          TC3_LOG(ERROR) << "Couldn't map the label to a token span.";
+          return false;
+        }
+        const TokenSpan candidate_span = ExpandTokenSpan(
+            SingleTokenSpan(click_pos), relative_token_span.first,
+            relative_token_span.second);
+        // Only keep spans that lie fully within the token sequence.
+        if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
+          UpdateMax(&chunk_scores, candidate_span, scores[j]);
+        }
+      }
+    }
+  }
+
+  // Flatten the map into the output list of scored chunks.
+  scored_chunks->clear();
+  scored_chunks->reserve(chunk_scores.size());
+  for (const auto& entry : chunk_scores) {
+    scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
+  }
+
+  return true;
+}
+
+// Scores candidate chunks for bounds-sensitive models by evaluating every
+// eligible [start, end) token span in batches with the selection model.
+// NOTE(review): bounds_sensitive_features() is dereferenced without a null
+// check below — callers are expected to invoke this only when the feature is
+// configured and enabled; confirm.
+bool Annotator::ModelBoundsSensitiveScoreChunks(
+    int num_tokens, const TokenSpan& span_of_interest,
+    const TokenSpan& inference_span, const CachedFeatures& cached_features,
+    tflite::Interpreter* selection_interpreter,
+    std::vector<ScoredChunk>* scored_chunks) const {
+  const int max_selection_span =
+      selection_feature_processor_->GetOptions()->max_selection_span();
+  const int max_chunk_length = selection_feature_processor_->GetOptions()
+                                       ->selection_reduced_output_space()
+                                   ? max_selection_span + 1
+                                   : 2 * max_selection_span + 1;
+  const bool score_single_token_spans_as_zero =
+      selection_feature_processor_->GetOptions()
+          ->bounds_sensitive_features()
+          ->score_single_token_spans_as_zero();
+
+  scored_chunks->clear();
+  if (score_single_token_spans_as_zero) {
+    scored_chunks->reserve(TokenSpanSize(span_of_interest));
+  }
+
+  // Prepare all chunk candidates into one batch:
+  // - Are contained in the inference span
+  // - Have a non-empty intersection with the span of interest
+  // - Are at least one token long
+  // - Are not longer than the maximum chunk length
+  std::vector<TokenSpan> candidate_spans;
+  for (int start = inference_span.first; start < span_of_interest.second;
+       ++start) {
+    const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
+    for (int end = leftmost_end_index;
+         end <= inference_span.second && end - start <= max_chunk_length;
+         ++end) {
+      const TokenSpan candidate_span = {start, end};
+      if (score_single_token_spans_as_zero &&
+          TokenSpanSize(candidate_span) == 1) {
+        // Do not include the single token span in the batch, add a zero score
+        // for it directly to the output.
+        scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
+      } else {
+        candidate_spans.push_back(candidate_span);
+      }
+    }
+  }
+
+  const int max_batch_size = model_->selection_options()->batch_size();
+
+  std::vector<float> all_features;
+  scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
+  for (int batch_start = 0; batch_start < candidate_spans.size();
+       batch_start += max_batch_size) {
+    const int batch_end = std::min(batch_start + max_batch_size,
+                                   static_cast<int>(candidate_spans.size()));
+
+    // Prepare features for the whole batch.
+    all_features.clear();
+    all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
+    for (int i = batch_start; i < batch_end; ++i) {
+      cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
+                                                           &all_features);
+    }
+
+    // Run batched inference.
+    const int batch_size = batch_end - batch_start;
+    const int features_size = cached_features.OutputFeaturesSize();
+    TensorView<float> logits = selection_executor_->ComputeLogits(
+        TensorView<float>(all_features.data(), {batch_size, features_size}),
+        selection_interpreter);
+    if (!logits.is_valid()) {
+      TC3_LOG(ERROR) << "Couldn't compute logits.";
+      return false;
+    }
+    // Bounds-sensitive models emit a single score per candidate span.
+    if (logits.dims() != 2 || logits.dim(0) != batch_size ||
+        logits.dim(1) != 1) {
+      TC3_LOG(ERROR) << "Mismatching output.";
+      return false;
+    }
+
+    // Save results.
+    for (int i = batch_start; i < batch_end; ++i) {
+      scored_chunks->push_back(
+          ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
+    }
+  }
+
+  return true;
+}
+
+// Chunks the input with the datetime parser and appends one annotated span
+// (with the "date" collection) per parsed result. A missing datetime parser
+// is not an error: nothing is produced and true is returned.
+bool Annotator::DatetimeChunk(const UnicodeText& context_unicode,
+                              int64 reference_time_ms_utc,
+                              const std::string& reference_timezone,
+                              const std::string& locales, ModeFlag mode,
+                              std::vector<AnnotatedSpan>* result) const {
+  if (datetime_parser_ == nullptr) {
+    return true;
+  }
+
+  std::vector<DatetimeParseResultSpan> parses;
+  const bool parsed_ok = datetime_parser_->Parse(
+      context_unicode, reference_time_ms_utc, reference_timezone, locales,
+      mode, /*anchor_start_end=*/false, &parses);
+  if (!parsed_ok) {
+    return false;
+  }
+
+  for (const DatetimeParseResultSpan& parse : parses) {
+    AnnotatedSpan span;
+    span.span = parse.span;
+    // A single classification result carrying the parsed datetime payload.
+    span.classification = {{kDateCollection, parse.target_classification_score,
+                            parse.priority_score}};
+    span.classification[0].datetime_parse_result = parse.data;
+    result->push_back(std::move(span));
+  }
+  return true;
+}
+
+// Returns a read-only view of the flatbuffer model stored in |buffer|, or
+// nullptr when the buffer is missing or fails verification.
+const Model* ViewModel(const void* buffer, int size) {
+  return buffer != nullptr ? LoadAndVerifyModel(buffer, size) : nullptr;
+}
+
+} // namespace libtextclassifier3
diff --git a/annotator/annotator.h b/annotator/annotator.h
new file mode 100644
index 0000000..c58c03d
--- /dev/null
+++ b/annotator/annotator.h
@@ -0,0 +1,393 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Inference code for the text classification model.
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_H_
+
#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <vector>
+
+#include "annotator/datetime/parser.h"
+#include "annotator/feature-processor.h"
+#include "annotator/knowledge/knowledge-engine.h"
+#include "annotator/model-executor.h"
+#include "annotator/model_generated.h"
+#include "annotator/strip-unpaired-brackets.h"
+#include "annotator/types.h"
+#include "annotator/zlib-utils.h"
+#include "utils/memory/mmap.h"
+#include "utils/utf8/unilib.h"
+#include "utils/zlib/zlib.h"
+
+namespace libtextclassifier3 {
+
// Options controlling Annotator::SuggestSelection().
struct SelectionOptions {
  // Comma-separated list of locale specification for the input text (BCP 47
  // tags).
  std::string locales;

  // Returns a default-constructed options instance.
  static SelectionOptions Default() { return SelectionOptions(); }
};
+
// Options controlling Annotator::ClassifyText().
struct ClassificationOptions {
  // For parsing relative datetimes, the reference now time against which the
  // relative datetimes get resolved.
  // UTC milliseconds since epoch.
  int64 reference_time_ms_utc = 0;

  // Timezone in which the input text was written (format as accepted by ICU).
  std::string reference_timezone;

  // Comma-separated list of locale specification for the input text (BCP 47
  // tags).
  std::string locales;

  // Returns a default-constructed options instance.
  static ClassificationOptions Default() { return ClassificationOptions(); }
};
+
// Options controlling Annotator::Annotate().
struct AnnotationOptions {
  // For parsing relative datetimes, the reference now time against which the
  // relative datetimes get resolved.
  // UTC milliseconds since epoch.
  int64 reference_time_ms_utc = 0;

  // Timezone in which the input text was written (format as accepted by ICU).
  std::string reference_timezone;

  // Comma-separated list of locale specification for the input text (BCP 47
  // tags).
  std::string locales;

  // Returns a default-constructed options instance.
  static AnnotationOptions Default() { return AnnotationOptions(); }
};
+
// Holds TFLite interpreters for selection and classification models.
// NOTE: This class is not thread-safe, thus should NOT be re-used across
// threads.
class InterpreterManager {
 public:
  // The constructor can be called with nullptr for any of the executors, and is
  // a defined behavior, as long as the corresponding *Interpreter() method is
  // not called when the executor is null.
  InterpreterManager(const ModelExecutor* selection_executor,
                     const ModelExecutor* classification_executor)
      : selection_executor_(selection_executor),
        classification_executor_(classification_executor) {}

  // Gets or creates and caches an interpreter for the selection model.
  tflite::Interpreter* SelectionInterpreter();

  // Gets or creates and caches an interpreter for the classification model.
  tflite::Interpreter* ClassificationInterpreter();

 private:
  // Executors used to create the interpreters; not owned.
  const ModelExecutor* selection_executor_;
  const ModelExecutor* classification_executor_;

  // Lazily-created interpreters, cached across calls; owned.
  std::unique_ptr<tflite::Interpreter> selection_interpreter_;
  std::unique_ptr<tflite::Interpreter> classification_interpreter_;
};
+
// A text processing model that provides text classification, annotation,
// selection suggestion for various types.
// NOTE: This class is not thread-safe.
class Annotator {
 public:
  // Factory functions constructing an Annotator from a flatbuffer model
  // provided as an unowned buffer, a scoped mmap, a file descriptor (whole
  // file or an offset/size window), or a file path. If 'unilib' or
  // 'calendarlib' is null, the annotator presumably falls back to the owned
  // instances below — confirm in annotator.cc.
  static std::unique_ptr<Annotator> FromUnownedBuffer(
      const char* buffer, int size, const UniLib* unilib = nullptr,
      const CalendarLib* calendarlib = nullptr);
  // Takes ownership of the mmap.
  static std::unique_ptr<Annotator> FromScopedMmap(
      std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib = nullptr,
      const CalendarLib* calendarlib = nullptr);
  static std::unique_ptr<Annotator> FromFileDescriptor(
      int fd, int offset, int size, const UniLib* unilib = nullptr,
      const CalendarLib* calendarlib = nullptr);
  static std::unique_ptr<Annotator> FromFileDescriptor(
      int fd, const UniLib* unilib = nullptr,
      const CalendarLib* calendarlib = nullptr);
  static std::unique_ptr<Annotator> FromPath(
      const std::string& path, const UniLib* unilib = nullptr,
      const CalendarLib* calendarlib = nullptr);

  // Returns true if the model is ready for use.
  bool IsInitialized() { return initialized_; }

  // Initializes the knowledge engine with the given config.
  bool InitializeKnowledgeEngine(const std::string& serialized_config);

  // Runs inference for given a context and current selection (i.e. index
  // of the first and one past last selected characters (utf8 codepoint
  // offsets)). Returns the indices (utf8 codepoint offsets) of the selection
  // beginning character and one past selection end character.
  // Returns the original click_indices if an error occurs.
  // NOTE: The selection indices are passed in and returned in terms of
  // UTF8 codepoints (not bytes).
  // Requires that the model is a smart selection model.
  CodepointSpan SuggestSelection(
      const std::string& context, CodepointSpan click_indices,
      const SelectionOptions& options = SelectionOptions::Default()) const;

  // Classifies the selected text given the context string.
  // Returns an empty result if an error occurs.
  std::vector<ClassificationResult> ClassifyText(
      const std::string& context, CodepointSpan selection_indices,
      const ClassificationOptions& options =
          ClassificationOptions::Default()) const;

  // Annotates given input text. The annotations are sorted by their position
  // in the context string and exclude spans classified as 'other'.
  std::vector<AnnotatedSpan> Annotate(
      const std::string& context,
      const AnnotationOptions& options = AnnotationOptions::Default()) const;

  // Exposes the feature processor for tests and evaluations.
  const FeatureProcessor* SelectionFeatureProcessorForTests() const;
  const FeatureProcessor* ClassificationFeatureProcessorForTests() const;

  // Exposes the date time parser for tests and evaluations.
  const DatetimeParser* DatetimeParserForTests() const;

  // String collection names for various classes.
  static const std::string& kOtherCollection;
  static const std::string& kPhoneCollection;
  static const std::string& kAddressCollection;
  static const std::string& kDateCollection;
  static const std::string& kUrlCollection;
  static const std::string& kFlightCollection;
  static const std::string& kEmailCollection;
  static const std::string& kIbanCollection;
  static const std::string& kPaymentCardCollection;
  static const std::string& kIsbnCollection;
  static const std::string& kTrackingNumberCollection;

 protected:
  // A candidate token span together with its model score.
  struct ScoredChunk {
    TokenSpan token_span;
    float score;
  };

  // Constructs and initializes text classifier from given model.
  // Takes ownership of 'mmap', and thus owns the buffer that backs 'model'.
  Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
            const UniLib* unilib, const CalendarLib* calendarlib);

  // Constructs, validates and initializes text classifier from given model.
  // Does not own the buffer that backs 'model'.
  explicit Annotator(const Model* model, const UniLib* unilib,
                     const CalendarLib* calendarlib);

  // Checks that model contains all required fields, and initializes internal
  // datastructures.
  void ValidateAndInitialize();

  // Initializes regular expressions for the regex model.
  bool InitializeRegexModel(ZlibDecompressor* decompressor);

  // Resolves conflicts in the list of candidates by removing some overlapping
  // ones. Returns indices of the surviving ones.
  // NOTE: Assumes that the candidates are sorted according to their position in
  // the span.
  bool ResolveConflicts(const std::vector<AnnotatedSpan>& candidates,
                        const std::string& context,
                        const std::vector<Token>& cached_tokens,
                        InterpreterManager* interpreter_manager,
                        std::vector<int>* result) const;

  // Resolves one conflict between candidates on indices 'start_index'
  // (inclusive) and 'end_index' (exclusive). Assigns the winning candidate
  // indices to 'chosen_indices'. Returns false if a problem arises.
  bool ResolveConflict(const std::string& context,
                       const std::vector<Token>& cached_tokens,
                       const std::vector<AnnotatedSpan>& candidates,
                       int start_index, int end_index,
                       InterpreterManager* interpreter_manager,
                       std::vector<int>* chosen_indices) const;

  // Gets selection candidates from the ML model.
  // Provides the tokens produced during tokenization of the context string for
  // reuse.
  bool ModelSuggestSelection(const UnicodeText& context_unicode,
                             CodepointSpan click_indices,
                             InterpreterManager* interpreter_manager,
                             std::vector<Token>* tokens,
                             std::vector<AnnotatedSpan>* result) const;

  // Classifies the selected text given the context string with the
  // classification model.
  // Returns true if no error occurred.
  bool ModelClassifyText(
      const std::string& context, const std::vector<Token>& cached_tokens,
      CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
      FeatureProcessor::EmbeddingCache* embedding_cache,
      std::vector<ClassificationResult>* classification_results) const;

  // Same as above, but tokenizes the context itself (no cached tokens).
  bool ModelClassifyText(
      const std::string& context, CodepointSpan selection_indices,
      InterpreterManager* interpreter_manager,
      FeatureProcessor::EmbeddingCache* embedding_cache,
      std::vector<ClassificationResult>* classification_results) const;

  // Returns a relative token span that represents how many tokens on the left
  // from the selection and right from the selection are needed for the
  // classifier input.
  TokenSpan ClassifyTextUpperBoundNeededTokens() const;

  // Classifies the selected text with the regular expressions models.
  // Returns true if any regular expression matched and the result was set.
  bool RegexClassifyText(const std::string& context,
                         CodepointSpan selection_indices,
                         ClassificationResult* classification_result) const;

  // Classifies the selected text with the date time model.
  // Returns true if there was a match and the result was set.
  bool DatetimeClassifyText(const std::string& context,
                            CodepointSpan selection_indices,
                            const ClassificationOptions& options,
                            ClassificationResult* classification_result) const;

  // Chunks given input text with the selection model and classifies the spans
  // with the classification model.
  // The annotations are sorted by their position in the context string and
  // exclude spans classified as 'other'.
  // Provides the tokens produced during tokenization of the context string for
  // reuse.
  bool ModelAnnotate(const std::string& context,
                     InterpreterManager* interpreter_manager,
                     std::vector<Token>* tokens,
                     std::vector<AnnotatedSpan>* result) const;

  // Groups the tokens into chunks. A chunk is a token span that should be the
  // suggested selection when any of its contained tokens is clicked. The chunks
  // are non-overlapping and are sorted by their position in the context string.
  // "num_tokens" is the total number of tokens available (as this method does
  // not need the actual vector of tokens).
  // "span_of_interest" is a span of all the tokens that could be clicked.
  // The resulting chunks all have to overlap with it and they cover this span
  // completely. The first and last chunk might extend beyond it.
  // The chunks vector is cleared before filling.
  bool ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
                  tflite::Interpreter* selection_interpreter,
                  const CachedFeatures& cached_features,
                  std::vector<TokenSpan>* chunks) const;

  // A helper method for ModelChunk(). It generates scored chunk candidates for
  // a click context model.
  // NOTE: The returned chunks can (and most likely do) overlap.
  bool ModelClickContextScoreChunks(
      int num_tokens, const TokenSpan& span_of_interest,
      const CachedFeatures& cached_features,
      tflite::Interpreter* selection_interpreter,
      std::vector<ScoredChunk>* scored_chunks) const;

  // A helper method for ModelChunk(). It generates scored chunk candidates for
  // a bounds-sensitive model.
  // NOTE: The returned chunks can (and most likely do) overlap.
  bool ModelBoundsSensitiveScoreChunks(
      int num_tokens, const TokenSpan& span_of_interest,
      const TokenSpan& inference_span, const CachedFeatures& cached_features,
      tflite::Interpreter* selection_interpreter,
      std::vector<ScoredChunk>* scored_chunks) const;

  // Produces chunks isolated by a set of regular expressions.
  bool RegexChunk(const UnicodeText& context_unicode,
                  const std::vector<int>& rules,
                  std::vector<AnnotatedSpan>* result) const;

  // Produces chunks from the datetime parser.
  bool DatetimeChunk(const UnicodeText& context_unicode,
                     int64 reference_time_ms_utc,
                     const std::string& reference_timezone,
                     const std::string& locales, ModeFlag mode,
                     std::vector<AnnotatedSpan>* result) const;

  // Returns whether a classification should be filtered.
  bool FilteredForAnnotation(const AnnotatedSpan& span) const;
  bool FilteredForClassification(
      const ClassificationResult& classification) const;
  bool FilteredForSelection(const AnnotatedSpan& span) const;

  // The model flatbuffer; its backing buffer is owned by mmap_ when the
  // annotator was constructed from a mmap (see the protected constructors).
  const Model* model_;

  // TFLite model executors and the shared embedding executor; owned.
  std::unique_ptr<const ModelExecutor> selection_executor_;
  std::unique_ptr<const ModelExecutor> classification_executor_;
  std::unique_ptr<const EmbeddingExecutor> embedding_executor_;

  // Feature processors for the selection and classification models; owned.
  std::unique_ptr<const FeatureProcessor> selection_feature_processor_;
  std::unique_ptr<const FeatureProcessor> classification_feature_processor_;

  // Parser used by DatetimeChunk()/DatetimeClassifyText(); may be null.
  std::unique_ptr<const DatetimeParser> datetime_parser_;

 private:
  // A compiled regex rule together with its classification metadata.
  struct CompiledRegexPattern {
    std::string collection_name;
    float target_classification_score;
    float priority_score;
    std::unique_ptr<UniLib::RegexPattern> pattern;
    const VerificationOptions* verification_options;
  };

  // Memory mapping backing model_; null when the buffer is externally owned.
  std::unique_ptr<ScopedMmap> mmap_;
  // Set by ValidateAndInitialize(); reported through IsInitialized().
  bool initialized_ = false;
  bool enabled_for_annotation_ = false;
  bool enabled_for_classification_ = false;
  bool enabled_for_selection_ = false;
  // Collection names suppressed per API (see FilteredFor*()).
  std::unordered_set<std::string> filtered_collections_annotation_;
  std::unordered_set<std::string> filtered_collections_classification_;
  std::unordered_set<std::string> filtered_collections_selection_;

  std::vector<CompiledRegexPattern> regex_patterns_;
  std::unordered_set<int> regex_approximate_match_pattern_ids_;

  // Indices into regex_patterns_ for the different modes.
  std::vector<int> annotation_regex_patterns_, classification_regex_patterns_,
      selection_regex_patterns_;

  // UniLib/CalendarLib: the owned_ members hold instances created internally;
  // the raw pointers are what the rest of the class uses and may point either
  // at the owned instances or at caller-provided ones.
  std::unique_ptr<UniLib> owned_unilib_;
  const UniLib* unilib_;
  std::unique_ptr<CalendarLib> owned_calendarlib_;
  const CalendarLib* calendarlib_;

  // Optional knowledge engine, set up via InitializeKnowledgeEngine().
  std::unique_ptr<const KnowledgeEngine> knowledge_engine_;
};
+
// Internal helpers, exposed here so tests can exercise them directly.
namespace internal {

// Helper function, which if the initial 'span' contains only white-spaces,
// moves the selection to a single-codepoint selection on the left side
// of this block of white-space.
CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
                                            const UnicodeText& context_unicode,
                                            const UniLib& unilib);

// Copies tokens from 'cached_tokens' that are
// 'tokens_around_selection_to_copy' (on the left, and right) tokens distant
// from the tokens that correspond to 'selection_indices'.
std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
                                    CodepointSpan selection_indices,
                                    TokenSpan tokens_around_selection_to_copy);
}  // namespace internal
+
+// Interprets the buffer as a Model flatbuffer and returns it for reading.
+const Model* ViewModel(const void* buffer, int size);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_H_
diff --git a/annotator/annotator_jni.cc b/annotator/annotator_jni.cc
new file mode 100644
index 0000000..9bda35a
--- /dev/null
+++ b/annotator/annotator_jni.cc
@@ -0,0 +1,434 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// JNI wrapper for the Annotator.
+
+#include "annotator/annotator_jni.h"
+
+#include <jni.h>
+#include <type_traits>
+#include <vector>
+
+#include "annotator/annotator.h"
+#include "annotator/annotator_jni_common.h"
+#include "utils/base/integral_types.h"
+#include "utils/calendar/calendar.h"
+#include "utils/java/scoped_local_ref.h"
+#include "utils/java/string_utils.h"
+#include "utils/memory/mmap.h"
+#include "utils/utf8/unilib.h"
+
+#ifdef TC3_UNILIB_JAVAICU
+#ifndef TC3_CALENDAR_JAVAICU
+#error Inconsistent usage of Java ICU components
+#else
+#define TC3_USE_JAVAICU
+#endif
+#endif
+
+using libtextclassifier3::AnnotatedSpan;
+using libtextclassifier3::Annotator;
+using libtextclassifier3::ClassificationResult;
+using libtextclassifier3::CodepointSpan;
+using libtextclassifier3::Model;
+using libtextclassifier3::ScopedLocalRef;
+// When using the Java's ICU, CalendarLib and UniLib need to be instantiated
+// with a JavaVM pointer from JNI. When using a standard ICU the pointer is
+// not needed and the objects are instantiated implicitly.
+#ifdef TC3_USE_JAVAICU
+using libtextclassifier3::CalendarLib;
+using libtextclassifier3::UniLib;
+#endif
+
+namespace libtextclassifier3 {
+
+using libtextclassifier3::CodepointSpan;
+
+namespace {
+
// Converts native ClassificationResults into a Java ClassificationResult[].
// Returns nullptr if the required Java classes cannot be resolved.
jobjectArray ClassificationResultsToJObjectArray(
    JNIEnv* env,
    const std::vector<ClassificationResult>& classification_result) {
  const ScopedLocalRef<jclass> result_class(
      env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
                     "$ClassificationResult"),
      env);
  if (!result_class) {
    TC3_LOG(ERROR) << "Couldn't find ClassificationResult class.";
    return nullptr;
  }
  const ScopedLocalRef<jclass> datetime_parse_class(
      env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
                     "$DatetimeResult"),
      env);
  if (!datetime_parse_class) {
    TC3_LOG(ERROR) << "Couldn't find DatetimeResult class.";
    return nullptr;
  }

  // Constructor signatures:
  //   ClassificationResult(String, float, DatetimeResult, byte[])
  //   DatetimeResult(long, int)
  const jmethodID result_class_constructor = env->GetMethodID(
      result_class.get(), "<init>",
      "(Ljava/lang/String;FL" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
      "$DatetimeResult;[B)V");
  const jmethodID datetime_parse_class_constructor =
      env->GetMethodID(datetime_parse_class.get(), "<init>", "(JI)V");

  const jobjectArray results = env->NewObjectArray(classification_result.size(),
                                                  result_class.get(), nullptr);
  for (int i = 0; i < classification_result.size(); i++) {
    jstring row_string =
        env->NewStringUTF(classification_result[i].collection.c_str());

    // Only build a DatetimeResult when the native result actually has one.
    jobject row_datetime_parse = nullptr;
    if (classification_result[i].datetime_parse_result.IsSet()) {
      row_datetime_parse = env->NewObject(
          datetime_parse_class.get(), datetime_parse_class_constructor,
          classification_result[i].datetime_parse_result.time_ms_utc,
          classification_result[i].datetime_parse_result.granularity);
    }

    // Copy the serialized knowledge-engine payload (if any) into a byte[].
    jbyteArray serialized_knowledge_result = nullptr;
    const std::string& serialized_knowledge_result_string =
        classification_result[i].serialized_knowledge_result;
    if (!serialized_knowledge_result_string.empty()) {
      serialized_knowledge_result =
          env->NewByteArray(serialized_knowledge_result_string.size());
      env->SetByteArrayRegion(serialized_knowledge_result, 0,
                              serialized_knowledge_result_string.size(),
                              reinterpret_cast<const jbyte*>(
                                  serialized_knowledge_result_string.data()));
    }

    jobject result =
        env->NewObject(result_class.get(), result_class_constructor, row_string,
                       static_cast<jfloat>(classification_result[i].score),
                       row_datetime_parse, serialized_knowledge_result);
    env->SetObjectArrayElement(results, i, result);
    // Release the local reference eagerly; the array holds its own reference.
    env->DeleteLocalRef(result);
  }
  return results;
}
+
+CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str,
+ CodepointSpan orig_indices,
+ bool from_utf8) {
+ const libtextclassifier3::UnicodeText unicode_str =
+ libtextclassifier3::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false);
+
+ int unicode_index = 0;
+ int bmp_index = 0;
+
+ const int* source_index;
+ const int* target_index;
+ if (from_utf8) {
+ source_index = &unicode_index;
+ target_index = &bmp_index;
+ } else {
+ source_index = &bmp_index;
+ target_index = &unicode_index;
+ }
+
+ CodepointSpan result{-1, -1};
+ std::function<void()> assign_indices_fn = [&result, &orig_indices,
+ &source_index, &target_index]() {
+ if (orig_indices.first == *source_index) {
+ result.first = *target_index;
+ }
+
+ if (orig_indices.second == *source_index) {
+ result.second = *target_index;
+ }
+ };
+
+ for (auto it = unicode_str.begin(); it != unicode_str.end();
+ ++it, ++unicode_index, ++bmp_index) {
+ assign_indices_fn();
+
+ // There is 1 extra character in the input for each UTF8 character > 0xFFFF.
+ if (*it > 0xFFFF) {
+ ++bmp_index;
+ }
+ }
+ assign_indices_fn();
+
+ return result;
+}
+
+} // namespace
+
// Converts Java-side indices (which count UTF-16 code units, per the BMP
// adjustment in ConvertIndicesBMPUTF8) into codepoint indices.
CodepointSpan ConvertIndicesBMPToUTF8(const std::string& utf8_str,
                                      CodepointSpan bmp_indices) {
  return ConvertIndicesBMPUTF8(utf8_str, bmp_indices, /*from_utf8=*/false);
}
+
// Converts codepoint indices into Java-side indices (which count UTF-16 code
// units, per the BMP adjustment in ConvertIndicesBMPUTF8).
CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str,
                                      CodepointSpan utf8_indices) {
  return ConvertIndicesBMPUTF8(utf8_str, utf8_indices, /*from_utf8=*/true);
}
+
+jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+ if (!mmap->handle().ok()) {
+ return env->NewStringUTF("");
+ }
+ const Model* model = libtextclassifier3::ViewModel(
+ mmap->handle().start(), mmap->handle().num_bytes());
+ if (!model || !model->locales()) {
+ return env->NewStringUTF("");
+ }
+ return env->NewStringUTF(model->locales()->c_str());
+}
+
+jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+ if (!mmap->handle().ok()) {
+ return 0;
+ }
+ const Model* model = libtextclassifier3::ViewModel(
+ mmap->handle().start(), mmap->handle().num_bytes());
+ if (!model) {
+ return 0;
+ }
+ return model->version();
+}
+
+jstring GetNameFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+ if (!mmap->handle().ok()) {
+ return env->NewStringUTF("");
+ }
+ const Model* model = libtextclassifier3::ViewModel(
+ mmap->handle().start(), mmap->handle().num_bytes());
+ if (!model || !model->name()) {
+ return env->NewStringUTF("");
+ }
+ return env->NewStringUTF(model->name()->c_str());
+}
+
+} // namespace libtextclassifier3
+
+using libtextclassifier3::ClassificationResultsToJObjectArray;
+using libtextclassifier3::ConvertIndicesBMPToUTF8;
+using libtextclassifier3::ConvertIndicesUTF8ToBMP;
+using libtextclassifier3::FromJavaAnnotationOptions;
+using libtextclassifier3::FromJavaClassificationOptions;
+using libtextclassifier3::FromJavaSelectionOptions;
+using libtextclassifier3::ToStlString;
+
// Creates a native Annotator from a file descriptor and returns it as an
// opaque jlong handle (0 when the factory returns null).
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotator)
(JNIEnv* env, jobject thiz, jint fd) {
#ifdef TC3_USE_JAVAICU
  // With Java ICU, UniLib/CalendarLib need a JNI cache; the raw pointers are
  // handed to the factory (presumably takes ownership — confirm in
  // annotator.cc).
  std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
      libtextclassifier3::JniCache::Create(env));
  return reinterpret_cast<jlong>(
      Annotator::FromFileDescriptor(fd, new UniLib(jni_cache),
                                    new CalendarLib(jni_cache))
          .release());
#else
  return reinterpret_cast<jlong>(Annotator::FromFileDescriptor(fd).release());
#endif
}
+
// Creates a native Annotator from a model file path and returns it as an
// opaque jlong handle (0 when the factory returns null).
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorFromPath)
(JNIEnv* env, jobject thiz, jstring path) {
  const std::string path_str = ToStlString(env, path);
#ifdef TC3_USE_JAVAICU
  std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
      libtextclassifier3::JniCache::Create(env));
  return reinterpret_cast<jlong>(Annotator::FromPath(path_str,
                                                     new UniLib(jni_cache),
                                                     new CalendarLib(jni_cache))
                                     .release());
#else
  return reinterpret_cast<jlong>(Annotator::FromPath(path_str).release());
#endif
}
+
// Creates a native Annotator from an Android AssetFileDescriptor (an fd plus
// an offset/size window) and returns an opaque jlong handle (0 on failure).
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME,
               nativeNewAnnotatorFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
  const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd);
#ifdef TC3_USE_JAVAICU
  std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
      libtextclassifier3::JniCache::Create(env));
  return reinterpret_cast<jlong>(
      Annotator::FromFileDescriptor(fd, offset, size, new UniLib(jni_cache),
                                    new CalendarLib(jni_cache))
          .release());
#else
  return reinterpret_cast<jlong>(
      Annotator::FromFileDescriptor(fd, offset, size).release());
#endif
}
+
// Initializes the knowledge engine of the annotator behind 'ptr' with a
// serialized config passed as a Java byte[]. Returns false on a null handle
// or when initialization fails.
TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
               nativeInitializeKnowledgeEngine)
(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config) {
  if (!ptr) {
    return false;
  }

  Annotator* model = reinterpret_cast<Annotator*>(ptr);

  // Copy the Java byte[] into a std::string buffer.
  std::string serialized_config_string;
  const int length = env->GetArrayLength(serialized_config);
  serialized_config_string.resize(length);
  env->GetByteArrayRegion(serialized_config, 0, length,
                          reinterpret_cast<jbyte*>(const_cast<char*>(
                              serialized_config_string.data())));

  return model->InitializeKnowledgeEngine(serialized_config_string);
}
+
// Suggests a selection span for the clicked range. Indices cross the JNI
// boundary in Java (UTF-16 code unit) coordinates and are converted to/from
// codepoint coordinates around the native call. Returns a 2-element int
// array {begin, end}, or nullptr on a null handle.
TC3_JNI_METHOD(jintArray, TC3_ANNOTATOR_CLASS_NAME, nativeSuggestSelection)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
 jint selection_end, jobject options) {
  if (!ptr) {
    return nullptr;
  }

  Annotator* model = reinterpret_cast<Annotator*>(ptr);

  const std::string context_utf8 = ToStlString(env, context);
  CodepointSpan input_indices =
      ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
  CodepointSpan selection = model->SuggestSelection(
      context_utf8, input_indices, FromJavaSelectionOptions(env, options));
  selection = ConvertIndicesUTF8ToBMP(context_utf8, selection);

  jintArray result = env->NewIntArray(2);
  env->SetIntArrayRegion(result, 0, 1, &(std::get<0>(selection)));
  env->SetIntArrayRegion(result, 1, 1, &(std::get<1>(selection)));
  return result;
}
+
// Classifies the given selection. Input indices are converted from Java
// (UTF-16) coordinates to codepoints before the native call. Returns a Java
// ClassificationResult[] (nullptr on a null handle).
TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeClassifyText)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
 jint selection_end, jobject options) {
  if (!ptr) {
    return nullptr;
  }
  Annotator* ff_model = reinterpret_cast<Annotator*>(ptr);

  const std::string context_utf8 = ToStlString(env, context);
  const CodepointSpan input_indices =
      ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
  const std::vector<ClassificationResult> classification_result =
      ff_model->ClassifyText(context_utf8, input_indices,
                             FromJavaClassificationOptions(env, options));

  return ClassificationResultsToJObjectArray(env, classification_result);
}
+
// Annotates the whole context string and converts the native results into a
// Java AnnotatedSpan[] (span indices converted back to Java/UTF-16
// coordinates). Returns nullptr on a null handle or if the Java result class
// cannot be found.
TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotate)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options) {
  if (!ptr) {
    return nullptr;
  }
  Annotator* model = reinterpret_cast<Annotator*>(ptr);
  std::string context_utf8 = ToStlString(env, context);
  std::vector<AnnotatedSpan> annotations =
      model->Annotate(context_utf8, FromJavaAnnotationOptions(env, options));

  jclass result_class = env->FindClass(
      TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotatedSpan");
  if (!result_class) {
    TC3_LOG(ERROR) << "Couldn't find result class: "
                   << TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
        "$AnnotatedSpan";
    return nullptr;
  }

  // AnnotatedSpan(int, int, ClassificationResult[])
  jmethodID result_class_constructor =
      env->GetMethodID(result_class, "<init>",
                       "(II[L" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
                       "$ClassificationResult;)V");

  jobjectArray results =
      env->NewObjectArray(annotations.size(), result_class, nullptr);

  for (int i = 0; i < annotations.size(); ++i) {
    CodepointSpan span_bmp =
        ConvertIndicesUTF8ToBMP(context_utf8, annotations[i].span);
    jobject result = env->NewObject(result_class, result_class_constructor,
                                    static_cast<jint>(span_bmp.first),
                                    static_cast<jint>(span_bmp.second),
                                    ClassificationResultsToJObjectArray(
                                        env, annotations[i].classification));
    env->SetObjectArrayElement(results, i, result);
    // Release per-element local refs eagerly; the array keeps its own refs.
    env->DeleteLocalRef(result);
  }
  env->DeleteLocalRef(result_class);
  return results;
}
+
// Destroys the native Annotator behind 'ptr' (deleting null is a no-op).
TC3_JNI_METHOD(void, TC3_ANNOTATOR_CLASS_NAME, nativeCloseAnnotator)
(JNIEnv* env, jobject thiz, jlong ptr) {
  Annotator* model = reinterpret_cast<Annotator*>(ptr);
  delete model;
}
+
// Deprecated alias that forwards to nativeGetLocales.
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLanguage)
(JNIEnv* env, jobject clazz, jint fd) {
  TC3_LOG(WARNING) << "Using deprecated getLanguage().";
  return TC3_JNI_METHOD_NAME(TC3_ANNOTATOR_CLASS_NAME, nativeGetLocales)(
      env, clazz, fd);
}
+
// Returns the locales stored in the model file behind 'fd' by mmapping it.
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocales)
(JNIEnv* env, jobject clazz, jint fd) {
  const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(fd));
  return GetLocalesFromMmap(env, mmap.get());
}
+
// Returns the locales stored in the model asset (fd + offset/size window).
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME,
               nativeGetLocalesFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
  const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd);
  const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(fd, offset, size));
  return GetLocalesFromMmap(env, mmap.get());
}
+
// Returns the version stored in the model file behind 'fd' by mmapping it.
TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersion)
(JNIEnv* env, jobject clazz, jint fd) {
  const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(fd));
  return GetVersionFromMmap(env, mmap.get());
}
+
// Returns the version stored in the model asset (fd + offset/size window).
TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME,
               nativeGetVersionFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
  const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd);
  const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(fd, offset, size));
  return GetVersionFromMmap(env, mmap.get());
}
+
// Returns the name stored in the model file behind 'fd' by mmapping it.
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetName)
(JNIEnv* env, jobject clazz, jint fd) {
  const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(fd));
  return GetNameFromMmap(env, mmap.get());
}
+
// Returns the name stored in the model asset (fd + offset/size window).
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME,
               nativeGetNameFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
  const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd);
  const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(fd, offset, size));
  return GetNameFromMmap(env, mmap.get());
}
diff --git a/annotator/annotator_jni.h b/annotator/annotator_jni.h
new file mode 100644
index 0000000..47715b4
--- /dev/null
+++ b/annotator/annotator_jni.h
@@ -0,0 +1,103 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_H_
+
+#include <jni.h>
+#include <string>
+#include "annotator/annotator_jni_common.h"
+#include "annotator/types.h"
+#include "utils/java/jni-base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
// SmartSelection.
// Creates a new Annotator from a model given as an open file descriptor.
// Returns an opaque native handle for use with the other native* methods
// (presumably a pointer cast to jlong — confirm against the .cc).
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotator)
(JNIEnv* env, jobject thiz, jint fd);

// Creates a new Annotator from a model file path.
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorFromPath)
(JNIEnv* env, jobject thiz, jstring path);

// Creates a new Annotator from an AssetFileDescriptor region
// (offset/size into the underlying file).
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME,
               nativeNewAnnotatorFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size);

// Initializes the knowledge engine of the annotator referenced by |ptr|
// from a serialized config; returns whether initialization succeeded.
TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
               nativeInitializeKnowledgeEngine)
(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config);

// Suggests a selection span for the given context and initial selection.
TC3_JNI_METHOD(jintArray, TC3_ANNOTATOR_CLASS_NAME, nativeSuggestSelection)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
 jint selection_end, jobject options);

// Classifies the [selection_begin, selection_end) span of |context|.
TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeClassifyText)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
 jint selection_end, jobject options);

// Annotates the whole |context| text.
TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotate)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options);

// Releases the annotator referenced by |ptr|.
TC3_JNI_METHOD(void, TC3_ANNOTATOR_CLASS_NAME, nativeCloseAnnotator)
(JNIEnv* env, jobject thiz, jlong ptr);

// DEPRECATED. Use nativeGetLocales instead.
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLanguage)
(JNIEnv* env, jobject clazz, jint fd);

// Model metadata accessors: read locales/version/name directly from a model
// file (by fd, or by AssetFileDescriptor region) without constructing a
// full Annotator.
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocales)
(JNIEnv* env, jobject clazz, jint fd);

TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME,
               nativeGetLocalesFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size);

TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersion)
(JNIEnv* env, jobject clazz, jint fd);

TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME,
               nativeGetVersionFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size);

TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetName)
(JNIEnv* env, jobject clazz, jint fd);

TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME,
               nativeGetNameFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size);
+
+#ifdef __cplusplus
+}
+#endif
+
+namespace libtextclassifier3 {
+
+// Given a utf8 string and a span expressed in Java BMP (basic multilingual
+// plane) codepoints, converts it to a span expressed in utf8 codepoints.
+libtextclassifier3::CodepointSpan ConvertIndicesBMPToUTF8(
+ const std::string& utf8_str, libtextclassifier3::CodepointSpan bmp_indices);
+
+// Given a utf8 string and a span expressed in utf8 codepoints, converts it to a
+// span expressed in Java BMP (basic multilingual plane) codepoints.
+libtextclassifier3::CodepointSpan ConvertIndicesUTF8ToBMP(
+ const std::string& utf8_str,
+ libtextclassifier3::CodepointSpan utf8_indices);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_H_
diff --git a/annotator/annotator_jni_common.cc b/annotator/annotator_jni_common.cc
new file mode 100644
index 0000000..0fdb87b
--- /dev/null
+++ b/annotator/annotator_jni_common.cc
@@ -0,0 +1,100 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/annotator_jni_common.h"
+
+#include "utils/java/jni-base.h"
+#include "utils/java/scoped_local_ref.h"
+
+namespace libtextclassifier3 {
+namespace {
// Shared helper that converts a Java *Options object exposing
// getLocale()/getReferenceTimezone()/getReferenceTimeMsUtc() into the
// corresponding C++ options struct T. Returns a default-constructed T when
// |joptions| is null, the Java class cannot be resolved, or any getter
// call fails.
template <typename T>
T FromJavaOptionsInternal(JNIEnv* env, jobject joptions,
                          const std::string& class_name) {
  if (!joptions) {
    return {};
  }

  // FindClass may return null (with a pending Java exception) when the
  // class is unavailable; guard before invoking methods through it.
  const ScopedLocalRef<jclass> options_class(env->FindClass(class_name.c_str()),
                                             env);
  if (!options_class) {
    return {};
  }

  // Each call returns {success, value}; failures are collected and checked
  // together below.
  const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
      env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
      "getLocale", "Ljava/lang/String;");
  const std::pair<bool, jobject> status_or_reference_timezone =
      CallJniMethod0<jobject>(env, joptions, options_class.get(),
                              &JNIEnv::CallObjectMethod, "getReferenceTimezone",
                              "Ljava/lang/String;");
  const std::pair<bool, int64> status_or_reference_time_ms_utc =
      CallJniMethod0<int64>(env, joptions, options_class.get(),
                            &JNIEnv::CallLongMethod, "getReferenceTimeMsUtc",
                            "J");

  // Only populate the result if all three getters succeeded.
  if (!status_or_locales.first || !status_or_reference_timezone.first ||
      !status_or_reference_time_ms_utc.first) {
    return {};
  }

  T options;
  options.locales =
      ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second));
  options.reference_timezone = ToStlString(
      env, reinterpret_cast<jstring>(status_or_reference_timezone.second));
  options.reference_time_ms_utc = status_or_reference_time_ms_utc.second;
  return options;
}
+} // namespace
+
+SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions) {
+ if (!joptions) {
+ return {};
+ }
+
+ const ScopedLocalRef<jclass> options_class(
+ env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$SelectionOptions"),
+ env);
+ const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
+ env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
+ "getLocales", "Ljava/lang/String;");
+ if (!status_or_locales.first) {
+ return {};
+ }
+
+ SelectionOptions options;
+ options.locales =
+ ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second));
+
+ return options;
+}
+
+ClassificationOptions FromJavaClassificationOptions(JNIEnv* env,
+ jobject joptions) {
+ return FromJavaOptionsInternal<ClassificationOptions>(
+ env, joptions,
+ TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$ClassificationOptions");
+}
+
+AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions) {
+ return FromJavaOptionsInternal<AnnotationOptions>(
+ env, joptions,
+ TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotationOptions");
+}
+
+} // namespace libtextclassifier3
diff --git a/annotator/annotator_jni_common.h b/annotator/annotator_jni_common.h
new file mode 100644
index 0000000..b62bb21
--- /dev/null
+++ b/annotator/annotator_jni_common.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_COMMON_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_COMMON_H_
+
+#include <jni.h>
+
+#include "annotator/annotator.h"
+
// Name of the Java counterpart class; may be overridden at build time to
// match the app's package layout.
#ifndef TC3_ANNOTATOR_CLASS_NAME
#define TC3_ANNOTATOR_CLASS_NAME AnnotatorModel
#endif

// Stringified class name, used when building JNI class-lookup strings.
#define TC3_ANNOTATOR_CLASS_NAME_STR TC3_ADD_QUOTES(TC3_ANNOTATOR_CLASS_NAME)

namespace libtextclassifier3 {

// Converts a Java SelectionOptions object into its C++ counterpart.
// Returns a default-constructed value if |joptions| is null or the
// conversion fails.
SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions);

// Converts a Java ClassificationOptions object into its C++ counterpart.
ClassificationOptions FromJavaClassificationOptions(JNIEnv* env,
                                                    jobject joptions);

// Converts a Java AnnotationOptions object into its C++ counterpart.
AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_COMMON_H_
diff --git a/annotator/annotator_jni_test.cc b/annotator/annotator_jni_test.cc
new file mode 100644
index 0000000..929fb59
--- /dev/null
+++ b/annotator/annotator_jni_test.cc
@@ -0,0 +1,79 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/annotator_jni.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
TEST(Annotator, ConvertIndicesBMPUTF8) {
  // Test boundary cases: pure-ASCII strings map identically in both
  // directions.
  EXPECT_EQ(ConvertIndicesBMPToUTF8("hello", {0, 5}), std::make_pair(0, 5));
  EXPECT_EQ(ConvertIndicesUTF8ToBMP("hello", {0, 5}), std::make_pair(0, 5));

  EXPECT_EQ(ConvertIndicesBMPToUTF8("hello world", {0, 5}),
            std::make_pair(0, 5));
  EXPECT_EQ(ConvertIndicesUTF8ToBMP("hello world", {0, 5}),
            std::make_pair(0, 5));
  EXPECT_EQ(ConvertIndicesBMPToUTF8("😁ello world", {0, 6}),
            std::make_pair(0, 5));
  EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁ello world", {0, 5}),
            std::make_pair(0, 6));

  EXPECT_EQ(ConvertIndicesBMPToUTF8("hello world", {6, 11}),
            std::make_pair(6, 11));
  EXPECT_EQ(ConvertIndicesUTF8ToBMP("hello world", {6, 11}),
            std::make_pair(6, 11));
  EXPECT_EQ(ConvertIndicesBMPToUTF8("hello worl😁", {6, 12}),
            std::make_pair(6, 11));
  EXPECT_EQ(ConvertIndicesUTF8ToBMP("hello worl😁", {6, 11}),
            std::make_pair(6, 12));

  // Simple example where the longer character is before the selection.
  // Character 😁 is 0x1f601: one codepoint in UTF-8 terms, but a surrogate
  // pair (two code units) in Java BMP terms.
  EXPECT_EQ(ConvertIndicesBMPToUTF8("😁 Hello World.", {3, 8}),
            std::make_pair(2, 7));

  EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁 Hello World.", {2, 7}),
            std::make_pair(3, 8));

  // Longer character is before and in selection.
  EXPECT_EQ(ConvertIndicesBMPToUTF8("😁 Hell😁 World.", {3, 9}),
            std::make_pair(2, 7));

  EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁 Hell😁 World.", {2, 7}),
            std::make_pair(3, 9));

  // Longer character is before and after selection.
  EXPECT_EQ(ConvertIndicesBMPToUTF8("😁 Hello😁World.", {3, 8}),
            std::make_pair(2, 7));

  EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁 Hello😁World.", {2, 7}),
            std::make_pair(3, 8));

  // Longer character is before, inside, and after the selection.
  EXPECT_EQ(ConvertIndicesBMPToUTF8("😁 Hell😁😁World.", {3, 9}),
            std::make_pair(2, 7));

  EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁 Hell😁😁World.", {2, 7}),
            std::make_pair(3, 9));
}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/annotator/annotator_test.cc b/annotator/annotator_test.cc
new file mode 100644
index 0000000..fbaf039
--- /dev/null
+++ b/annotator/annotator_test.cc
@@ -0,0 +1,1254 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/annotator.h"
+
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <string>
+
+#include "annotator/model_generated.h"
+#include "annotator/types-test-util.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::ElementsAreArray;
+using testing::IsEmpty;
+using testing::Pair;
+using testing::Values;
+
+std::string FirstResult(const std::vector<ClassificationResult>& results) {
+ if (results.empty()) {
+ return "<INVALID RESULTS>";
+ }
+ return results[0].collection;
+}
+
// gMock matcher: an AnnotatedSpan matches when its span equals
// (start, end) and its top classification collection equals |best_class|.
MATCHER_P3(IsAnnotatedSpan, start, end, best_class, "") {
  return testing::Value(arg.span, Pair(start, end)) &&
         testing::Value(FirstResult(arg.classification), best_class);
}
+
// Reads the entire contents of |file_name| into a string. Returns an empty
// string if the file cannot be opened.
std::string ReadFile(const std::string& file_name) {
  std::ifstream stream(file_name);
  std::string contents((std::istreambuf_iterator<char>(stream)),
                       std::istreambuf_iterator<char>());
  return contents;
}
+
// Returns the directory containing the test model files
// (build-time macro supplied by the test build rules).
std::string GetModelPath() {
  return TC3_TEST_DATA_DIR;
}
+
// Fixture parameterized by the model file name, so each TEST_P runs against
// both the click-context model ("test_model_cc.fb") and the bounds-sensitive
// model ("test_model.fb") — see the INSTANTIATE_TEST_CASE_P calls below.
class AnnotatorTest : public ::testing::TestWithParam<const char*> {
 protected:
  AnnotatorTest()
      : INIT_UNILIB_FOR_TESTING(unilib_),
        INIT_CALENDARLIB_FOR_TESTING(calendarlib_) {}
  UniLib unilib_;            // Unicode support passed to the Annotator.
  CalendarLib calendarlib_;  // Calendar support passed to the Annotator.
};
+
TEST_F(AnnotatorTest, EmbeddingExecutorLoadingFails) {
  // Loading a model with a broken embeddings table must fail cleanly by
  // returning a null Annotator instead of a partially-initialized one.
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetModelPath() + "wrong_embeddings.fb", &unilib_, &calendarlib_);
  EXPECT_FALSE(classifier);
}
+
+INSTANTIATE_TEST_CASE_P(ClickContext, AnnotatorTest,
+ Values("test_model_cc.fb"));
+INSTANTIATE_TEST_CASE_P(BoundsSensitive, AnnotatorTest,
+ Values("test_model.fb"));
+
+TEST_P(AnnotatorTest, ClassifyText) {
+ std::unique_ptr<Annotator> classifier =
+ Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ("other",
+ FirstResult(classifier->ClassifyText(
+ "this afternoon Barack Obama gave a speech at", {15, 27})));
+ EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
+ "Call me at (800) 123-456 today", {11, 24})));
+
+ // More lines.
+ EXPECT_EQ("other",
+ FirstResult(classifier->ClassifyText(
+ "this afternoon Barack Obama gave a speech at|Visit "
+ "www.google.com every today!|Call me at (800) 123-456 today.",
+ {15, 27})));
+ EXPECT_EQ("phone",
+ FirstResult(classifier->ClassifyText(
+ "this afternoon Barack Obama gave a speech at|Visit "
+ "www.google.com every today!|Call me at (800) 123-456 today.",
+ {90, 103})));
+
+ // Single word.
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText("obama", {0, 5})));
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText("asdf", {0, 4})));
+ EXPECT_EQ("<INVALID RESULTS>",
+ FirstResult(classifier->ClassifyText("asdf", {0, 0})));
+
+ // Junk.
+ EXPECT_EQ("<INVALID RESULTS>",
+ FirstResult(classifier->ClassifyText("", {0, 0})));
+ EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText(
+ "a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5})));
+ // Test invalid utf8 input.
+ EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText(
+ "\xf0\x9f\x98\x8b\x8b", {0, 0})));
+}
+
+TEST_P(AnnotatorTest, ClassifyTextDisabledFail) {
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ unpacked_model->classification_model.clear();
+ unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
+ unpacked_model->triggering_options->enabled_modes = ModeFlag_SELECTION;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib_, &calendarlib_);
+
+ // The classification model is still needed for selection scores.
+ ASSERT_FALSE(classifier);
+}
+
+TEST_P(AnnotatorTest, ClassifyTextDisabled) {
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
+ unpacked_model->triggering_options->enabled_modes =
+ ModeFlag_ANNOTATION_AND_SELECTION;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib_, &calendarlib_);
+ ASSERT_TRUE(classifier);
+
+ EXPECT_THAT(
+ classifier->ClassifyText("Call me at (800) 123-456 today", {11, 24}),
+ IsEmpty());
+}
+
+TEST_P(AnnotatorTest, ClassifyTextFilteredCollections) {
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ test_model.c_str(), test_model.size(), &unilib_, &calendarlib_);
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
+ "Call me at (800) 123-456 today", {11, 24})));
+
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+ unpacked_model->output_options.reset(new OutputOptionsT);
+
+ // Disable phone classification
+ unpacked_model->output_options->filtered_collections_classification.push_back(
+ "phone");
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib_, &calendarlib_);
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
+ "Call me at (800) 123-456 today", {11, 24})));
+
+ // Check that the address classification still passes.
+ EXPECT_EQ("address", FirstResult(classifier->ClassifyText(
+ "350 Third Street, Cambridge", {0, 27})));
+}
+
+std::unique_ptr<RegexModel_::PatternT> MakePattern(
+ const std::string& collection_name, const std::string& pattern,
+ const bool enabled_for_classification, const bool enabled_for_selection,
+ const bool enabled_for_annotation, const float score) {
+ std::unique_ptr<RegexModel_::PatternT> result(new RegexModel_::PatternT);
+ result->collection_name = collection_name;
+ result->pattern = pattern;
+ // We cannot directly operate with |= on the flag, so use an int here.
+ int enabled_modes = ModeFlag_NONE;
+ if (enabled_for_annotation) enabled_modes |= ModeFlag_ANNOTATION;
+ if (enabled_for_classification) enabled_modes |= ModeFlag_CLASSIFICATION;
+ if (enabled_for_selection) enabled_modes |= ModeFlag_SELECTION;
+ result->enabled_modes = static_cast<ModeFlag>(enabled_modes);
+ result->target_classification_score = score;
+ result->priority_score = score;
+ return result;
+}
+
+#ifdef TC3_UNILIB_ICU
+TEST_P(AnnotatorTest, ClassifyTextRegularExpression) {
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test regex models.
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "person", "Barack Obama", /*enabled_for_classification=*/true,
+ /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 1.0));
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "flight", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
+ /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 0.5));
+ std::unique_ptr<RegexModel_::PatternT> verified_pattern =
+ MakePattern("payment_card", "\\d{4}(?: \\d{4}){3}",
+ /*enabled_for_classification=*/true,
+ /*enabled_for_selection=*/false,
+ /*enabled_for_annotation=*/false, 1.0);
+ verified_pattern->verification_options.reset(new VerificationOptionsT);
+ verified_pattern->verification_options->verify_luhn_checksum = true;
+ unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib_, &calendarlib_);
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ("flight",
+ FirstResult(classifier->ClassifyText(
+ "Your flight LX373 is delayed by 3 hours.", {12, 17})));
+ EXPECT_EQ("person",
+ FirstResult(classifier->ClassifyText(
+ "this afternoon Barack Obama gave a speech at", {15, 27})));
+ EXPECT_EQ("email",
+ FirstResult(classifier->ClassifyText("you@android.com", {0, 15})));
+ EXPECT_EQ("email", FirstResult(classifier->ClassifyText(
+ "Contact me at you@android.com", {14, 29})));
+
+ EXPECT_EQ("url", FirstResult(classifier->ClassifyText(
+ "Visit www.google.com every today!", {6, 20})));
+
+ EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("LX 37", {0, 5})));
+ EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("flight LX 37 abcd",
+ {7, 12})));
+ EXPECT_EQ("payment_card", FirstResult(classifier->ClassifyText(
+ "cc: 4012 8888 8888 1881", {4, 23})));
+ EXPECT_EQ("payment_card", FirstResult(classifier->ClassifyText(
+ "2221 0067 4735 6281", {0, 19})));
+ // Luhn check fails.
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText("2221 0067 4735 6282",
+ {0, 19})));
+
+ // More lines.
+ EXPECT_EQ("url",
+ FirstResult(classifier->ClassifyText(
+ "this afternoon Barack Obama gave a speech at|Visit "
+ "www.google.com every today!|Call me at (800) 123-456 today.",
+ {51, 65})));
+}
+#endif // TC3_UNILIB_ICU
+
+#ifdef TC3_UNILIB_ICU
+TEST_P(AnnotatorTest, SuggestSelectionRegularExpression) {
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test regex models.
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
+ unpacked_model->regex_model->patterns.back()->priority_score = 1.1;
+ std::unique_ptr<RegexModel_::PatternT> verified_pattern =
+ MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})",
+ /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/true,
+ /*enabled_for_annotation=*/false, 1.0);
+ verified_pattern->verification_options.reset(new VerificationOptionsT);
+ verified_pattern->verification_options->verify_luhn_checksum = true;
+ unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib_, &calendarlib_);
+ ASSERT_TRUE(classifier);
+
+ // Check regular expression selection.
+ EXPECT_EQ(classifier->SuggestSelection(
+ "Your flight MA 0123 is delayed by 3 hours.", {12, 14}),
+ std::make_pair(12, 19));
+ EXPECT_EQ(classifier->SuggestSelection(
+ "this afternoon Barack Obama gave a speech at", {15, 21}),
+ std::make_pair(15, 27));
+ EXPECT_EQ(classifier->SuggestSelection("cc: 4012 8888 8888 1881", {9, 14}),
+ std::make_pair(4, 23));
+}
+#endif // TC3_UNILIB_ICU
+
+#ifdef TC3_UNILIB_ICU
+TEST_P(AnnotatorTest, SuggestSelectionRegularExpressionConflictsModelWins) {
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test regex models.
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
+ unpacked_model->regex_model->patterns.back()->priority_score = 0.5;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+ ASSERT_TRUE(classifier);
+
+ // Check conflict resolution.
+ EXPECT_EQ(
+ classifier->SuggestSelection(
+ "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123",
+ {55, 57}),
+ std::make_pair(26, 62));
+}
+#endif // TC3_UNILIB_ICU
+
+#ifdef TC3_UNILIB_ICU
+TEST_P(AnnotatorTest, SuggestSelectionRegularExpressionConflictsRegexWins) {
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test regex models.
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
+ unpacked_model->regex_model->patterns.back()->priority_score = 1.1;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+ ASSERT_TRUE(classifier);
+
+ // Check conflict resolution.
+ EXPECT_EQ(
+ classifier->SuggestSelection(
+ "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123",
+ {55, 57}),
+ std::make_pair(55, 62));
+}
+#endif // TC3_UNILIB_ICU
+
+#ifdef TC3_UNILIB_ICU
+TEST_P(AnnotatorTest, AnnotateRegex) {
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test regex models.
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 1.0));
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 0.5));
+ std::unique_ptr<RegexModel_::PatternT> verified_pattern =
+ MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})",
+ /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/false,
+ /*enabled_for_annotation=*/true, 1.0);
+ verified_pattern->verification_options.reset(new VerificationOptionsT);
+ verified_pattern->verification_options->verify_luhn_checksum = true;
+ unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib_, &calendarlib_);
+ ASSERT_TRUE(classifier);
+
+ const std::string test_string =
+ "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+ "number is 853 225 3556\nand my card is 4012 8888 8888 1881.\n";
+ EXPECT_THAT(classifier->Annotate(test_string),
+ ElementsAreArray({IsAnnotatedSpan(6, 18, "person"),
+ IsAnnotatedSpan(28, 55, "address"),
+ IsAnnotatedSpan(79, 91, "phone"),
+ IsAnnotatedSpan(107, 126, "payment_card")}));
+}
+#endif // TC3_UNILIB_ICU
+
+TEST_P(AnnotatorTest, PhoneFiltering) {
+ std::unique_ptr<Annotator> classifier =
+ Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
+ "phone: (123) 456 789", {7, 20})));
+ EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
+ "phone: (123) 456 789,0001112", {7, 25})));
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
+ "phone: (123) 456 789,0001112", {7, 28})));
+}
+
+TEST_P(AnnotatorTest, SuggestSelection) {
+ std::unique_ptr<Annotator> classifier =
+ Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ(classifier->SuggestSelection(
+ "this afternoon Barack Obama gave a speech at", {15, 21}),
+ std::make_pair(15, 21));
+
+ // Try passing whole string.
+ // If more than 1 token is specified, we should return back what entered.
+ EXPECT_EQ(
+ classifier->SuggestSelection("350 Third Street, Cambridge", {0, 27}),
+ std::make_pair(0, 27));
+
+ // Single letter.
+ EXPECT_EQ(classifier->SuggestSelection("a", {0, 1}), std::make_pair(0, 1));
+
+ // Single word.
+ EXPECT_EQ(classifier->SuggestSelection("asdf", {0, 4}), std::make_pair(0, 4));
+
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
+ std::make_pair(11, 23));
+
+ // Unpaired bracket stripping.
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at (857) 225 3556 today", {11, 16}),
+ std::make_pair(11, 25));
+ EXPECT_EQ(classifier->SuggestSelection("call me at (857 today", {11, 15}),
+ std::make_pair(12, 15));
+ EXPECT_EQ(classifier->SuggestSelection("call me at 3556) today", {11, 16}),
+ std::make_pair(11, 15));
+ EXPECT_EQ(classifier->SuggestSelection("call me at )857( today", {11, 16}),
+ std::make_pair(12, 15));
+
+ // If the resulting selection would be empty, the original span is returned.
+ EXPECT_EQ(classifier->SuggestSelection("call me at )( today", {11, 13}),
+ std::make_pair(11, 13));
+ EXPECT_EQ(classifier->SuggestSelection("call me at ( today", {11, 12}),
+ std::make_pair(11, 12));
+ EXPECT_EQ(classifier->SuggestSelection("call me at ) today", {11, 12}),
+ std::make_pair(11, 12));
+}
+
+TEST_P(AnnotatorTest, SuggestSelectionDisabledFail) {
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Disable the selection model.
+ unpacked_model->selection_model.clear();
+ unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
+ unpacked_model->triggering_options->enabled_modes = ModeFlag_ANNOTATION;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib_, &calendarlib_);
+ // Selection model needs to be present for annotation.
+ ASSERT_FALSE(classifier);
+}
+
+// With selection disabled (and only classification enabled) the model still
+// loads; SuggestSelection echoes the input span, classification keeps
+// working, and Annotate yields nothing.
+TEST_P(AnnotatorTest, SuggestSelectionDisabled) {
+  const std::string test_model = ReadFile(GetModelPath() + GetParam());
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+  // Disable the selection model.
+  unpacked_model->selection_model.clear();
+  unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
+  unpacked_model->triggering_options->enabled_modes = ModeFlag_CLASSIFICATION;
+  unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION;
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  // Selection is off: the requested span is returned unchanged.
+  EXPECT_EQ(
+      classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
+      std::make_pair(11, 14));
+
+  EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
+                         "call me at (800) 123-456 today", {11, 24})));
+
+  EXPECT_THAT(classifier->Annotate("call me at (800) 123-456 today"),
+              IsEmpty());
+}
+
+// Collections listed in output_options.filtered_collections_selection must
+// not be expanded by SuggestSelection, while unfiltered collections still
+// expand normally.
+TEST_P(AnnotatorTest, SuggestSelectionFilteredCollections) {
+  const std::string test_model = ReadFile(GetModelPath() + GetParam());
+
+  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+      test_model.c_str(), test_model.size(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  // Baseline: phone selection expands before filtering is configured.
+  EXPECT_EQ(
+      classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
+      std::make_pair(11, 23));
+
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+  unpacked_model->output_options.reset(new OutputOptionsT);
+
+  // Disable phone selection
+  unpacked_model->output_options->filtered_collections_selection.push_back(
+      "phone");
+  // We need to force this for filtering.
+  unpacked_model->selection_options->always_classify_suggested_selection = true;
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+  classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  EXPECT_EQ(
+      classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
+      std::make_pair(11, 14));
+
+  // Address selection should still work.
+  EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}),
+            std::make_pair(0, 27));
+}
+
+// Tapping any word of a multi-token entity must expand to the same full span,
+// regardless of which token inside the entity was clicked.
+TEST_P(AnnotatorTest, SuggestSelectionsAreSymmetric) {
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {0, 3}),
+            std::make_pair(0, 27));
+  EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}),
+            std::make_pair(0, 27));
+  EXPECT_EQ(
+      classifier->SuggestSelection("350 Third Street, Cambridge", {10, 16}),
+      std::make_pair(0, 27));
+  // Same behavior when the entity is preceded by other lines.
+  EXPECT_EQ(classifier->SuggestSelection("a\nb\nc\n350 Third Street, Cambridge",
+                                         {16, 22}),
+            std::make_pair(6, 33));
+}
+
+// A newline acts as a hard boundary: the suggested selection never crosses it.
+TEST_P(AnnotatorTest, SuggestSelectionWithNewLine) {
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  EXPECT_EQ(classifier->SuggestSelection("abc\n857 225 3556", {4, 7}),
+            std::make_pair(4, 16));
+  EXPECT_EQ(classifier->SuggestSelection("857 225 3556\nabc", {0, 3}),
+            std::make_pair(0, 12));
+
+  SelectionOptions options;
+  EXPECT_EQ(classifier->SuggestSelection("857 225\n3556\nabc", {0, 3}, options),
+            std::make_pair(0, 7));
+}
+
+// Punctuation adjacent to the selection must not leak into the suggested
+// span: the original span is kept when only punctuation surrounds it.
+TEST_P(AnnotatorTest, SuggestSelectionWithPunctuation) {
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  // From the right.
+  EXPECT_EQ(classifier->SuggestSelection(
+                "this afternoon BarackObama, gave a speech at", {15, 26}),
+            std::make_pair(15, 26));
+
+  // From the right multiple.
+  EXPECT_EQ(classifier->SuggestSelection(
+                "this afternoon BarackObama,.,.,, gave a speech at", {15, 26}),
+            std::make_pair(15, 26));
+
+  // From the left multiple.
+  EXPECT_EQ(classifier->SuggestSelection(
+                "this afternoon ,.,.,,BarackObama gave a speech at", {21, 32}),
+            std::make_pair(21, 32));
+
+  // From both sides.
+  EXPECT_EQ(classifier->SuggestSelection(
+                "this afternoon !BarackObama,- gave a speech at", {16, 27}),
+            std::make_pair(16, 27));
+}
+
+// Out-of-range, reversed, and otherwise invalid spans (and invalid UTF-8)
+// must not crash; the input span is returned unchanged in every case.
+TEST_P(AnnotatorTest, SuggestSelectionNoCrashWithJunk) {
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  // Try passing in bunch of invalid selections.
+  EXPECT_EQ(classifier->SuggestSelection("", {0, 27}), std::make_pair(0, 27));
+  EXPECT_EQ(classifier->SuggestSelection("", {-10, 27}),
+            std::make_pair(-10, 27));
+  EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {0, 27}),
+            std::make_pair(0, 27));
+  EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {-30, 300}),
+            std::make_pair(-30, 300));
+  EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {-10, -1}),
+            std::make_pair(-10, -1));
+  EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {100, 17}),
+            std::make_pair(100, 17));
+
+  // Try passing invalid utf8.
+  EXPECT_EQ(classifier->SuggestSelection("\xf0\x9f\x98\x8b\x8b", {-1, -1}),
+            std::make_pair(-1, -1));
+}
+
+// Selecting a whitespace character between entity tokens should snap onto the
+// neighboring entity; whitespace next to non-entities stays as-is.
+TEST_P(AnnotatorTest, SuggestSelectionSelectSpace) {
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  // Space inside the phone number snaps to the whole number.
+  EXPECT_EQ(
+      classifier->SuggestSelection("call me at 857 225 3556 today", {14, 15}),
+      std::make_pair(11, 23));
+  EXPECT_EQ(
+      classifier->SuggestSelection("call me at 857 225 3556 today", {10, 11}),
+      std::make_pair(10, 11));
+  EXPECT_EQ(
+      classifier->SuggestSelection("call me at 857 225 3556 today", {23, 24}),
+      std::make_pair(23, 24));
+  EXPECT_EQ(
+      classifier->SuggestSelection("call me at 857 225 3556, today", {23, 24}),
+      std::make_pair(23, 24));
+  EXPECT_EQ(classifier->SuggestSelection("call me at 857 225 3556, today",
+                                         {14, 17}),
+            std::make_pair(11, 25));
+  EXPECT_EQ(
+      classifier->SuggestSelection("call me at 857-225 3556, today", {14, 17}),
+      std::make_pair(11, 23));
+  EXPECT_EQ(
+      classifier->SuggestSelection(
+          "let's meet at 350 Third Street Cambridge and go there", {30, 31}),
+      std::make_pair(14, 40));
+  EXPECT_EQ(classifier->SuggestSelection("call me today", {4, 5}),
+            std::make_pair(4, 5));
+  EXPECT_EQ(classifier->SuggestSelection("call me today", {7, 8}),
+            std::make_pair(7, 8));
+
+  // With a punctuation around the selected whitespace.
+  EXPECT_EQ(
+      classifier->SuggestSelection(
+          "let's meet at 350 Third Street, Cambridge and go there", {31, 32}),
+      std::make_pair(14, 41));
+
+  // When all's whitespace, should return the original indices.
+  EXPECT_EQ(classifier->SuggestSelection("  ", {0, 1}),
+            std::make_pair(0, 1));
+  EXPECT_EQ(classifier->SuggestSelection("  ", {0, 3}),
+            std::make_pair(0, 3));
+  EXPECT_EQ(classifier->SuggestSelection("  ", {2, 3}),
+            std::make_pair(2, 3));
+  EXPECT_EQ(classifier->SuggestSelection("  ", {5, 6}),
+            std::make_pair(5, 6));
+}
+
+// Unit-tests the helper that moves a single-codepoint whitespace selection
+// one position left onto the preceding non-whitespace character; the span is
+// kept unchanged when there is nothing suitable on the left.
+TEST_F(AnnotatorTest, SnapLeftIfWhitespaceSelection) {
+  UnicodeText text;
+
+  text = UTF8ToUnicodeText("abcd efgh", /*do_copy=*/false);
+  EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_),
+            std::make_pair(3, 4));
+  text = UTF8ToUnicodeText("abcd ", /*do_copy=*/false);
+  EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_),
+            std::make_pair(3, 4));
+
+  // Nothing on the left.
+  text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false);
+  EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_),
+            std::make_pair(4, 5));
+  text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false);
+  EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib_),
+            std::make_pair(0, 1));
+
+  // Whitespace only.
+  text = UTF8ToUnicodeText("  ", /*do_copy=*/false);
+  EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({2, 3}, text, unilib_),
+            std::make_pair(2, 3));
+  text = UTF8ToUnicodeText("  ", /*do_copy=*/false);
+  EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_),
+            std::make_pair(4, 5));
+  text = UTF8ToUnicodeText("  ", /*do_copy=*/false);
+  EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib_),
+            std::make_pair(0, 1));
+}
+
+// End-to-end Annotate: finds the address and phone spans in mixed text,
+// respects newlines as boundaries, and tolerates invalid UTF-8 input.
+TEST_P(AnnotatorTest, Annotate) {
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  const std::string test_string =
+      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+      "number is 853 225 3556";
+  EXPECT_THAT(classifier->Annotate(test_string),
+              ElementsAreArray({
+                  IsAnnotatedSpan(28, 55, "address"),
+                  IsAnnotatedSpan(79, 91, "phone"),
+              }));
+
+  AnnotationOptions options;
+  EXPECT_THAT(classifier->Annotate("853 225 3556", options),
+              ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
+  // A newline splits the number, so no phone annotation is produced.
+  EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());
+
+  // Try passing invalid utf8.
+  EXPECT_TRUE(
+      classifier->Annotate("853 225 3556\n\xf0\x9f\x98\x8b\x8b", options)
+          .empty());
+}
+
+
+// Annotation results must be identical when inference is forced to run in
+// small batches (batch_size = 4) instead of the default batching.
+TEST_P(AnnotatorTest, AnnotateSmallBatches) {
+  const std::string test_model = ReadFile(GetModelPath() + GetParam());
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+  // Set the batch size.
+  unpacked_model->selection_options->batch_size = 4;
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  const std::string test_string =
+      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+      "number is 853 225 3556";
+  EXPECT_THAT(classifier->Annotate(test_string),
+              ElementsAreArray({
+                  IsAnnotatedSpan(28, 55, "address"),
+                  IsAnnotatedSpan(79, 91, "phone"),
+              }));
+
+  AnnotationOptions options;
+  EXPECT_THAT(classifier->Annotate("853 225 3556", options),
+              ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
+  EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());
+}
+
+#ifdef TC3_UNILIB_ICU
+// A min_annotate_confidence above any reachable score must discard every
+// annotation candidate.
+TEST_P(AnnotatorTest, AnnotateFilteringDiscardAll) {
+  const std::string test_model = ReadFile(GetModelPath() + GetParam());
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+  unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
+  // Set a threshold that no candidate score can reach.
+  unpacked_model->triggering_options->min_annotate_confidence =
+      2.f;  // Discards all results.
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  const std::string test_string =
+      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+      "number is 853 225 3556";
+
+  EXPECT_EQ(classifier->Annotate(test_string).size(), 0);
+}
+#endif // TC3_UNILIB_ICU
+
+// A zero min_annotate_confidence keeps every result; the two expected
+// entities (address + phone) survive.
+TEST_P(AnnotatorTest, AnnotateFilteringKeepAll) {
+  const std::string test_model = ReadFile(GetModelPath() + GetParam());
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+  // Add test thresholds.
+  unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
+  unpacked_model->triggering_options->min_annotate_confidence =
+      0.f;  // Keeps all results.
+  unpacked_model->triggering_options->enabled_modes = ModeFlag_ALL;
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  const std::string test_string =
+      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+      "number is 853 225 3556";
+  EXPECT_EQ(classifier->Annotate(test_string).size(), 2);
+}
+
+// Removing ANNOTATION from the model's enabled modes (keeping classification
+// and selection) makes Annotate return nothing.
+TEST_P(AnnotatorTest, AnnotateDisabled) {
+  const std::string test_model = ReadFile(GetModelPath() + GetParam());
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+  // Disable the model for annotation.
+  unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION_AND_SELECTION;
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+  const std::string test_string =
+      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+      "number is 853 225 3556";
+  EXPECT_THAT(classifier->Annotate(test_string), IsEmpty());
+}
+
+// Collections listed in output_options.filtered_collections_annotation are
+// dropped from Annotate results; unfiltered collections remain.
+TEST_P(AnnotatorTest, AnnotateFilteredCollections) {
+  const std::string test_model = ReadFile(GetModelPath() + GetParam());
+
+  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+      test_model.c_str(), test_model.size(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  const std::string test_string =
+      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+      "number is 853 225 3556";
+
+  // Baseline: both entities are reported before filtering is configured.
+  EXPECT_THAT(classifier->Annotate(test_string),
+              ElementsAreArray({
+                  IsAnnotatedSpan(28, 55, "address"),
+                  IsAnnotatedSpan(79, 91, "phone"),
+              }));
+
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+  unpacked_model->output_options.reset(new OutputOptionsT);
+
+  // Disable phone annotation
+  unpacked_model->output_options->filtered_collections_annotation.push_back(
+      "phone");
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+  classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  EXPECT_THAT(classifier->Annotate(test_string),
+              ElementsAreArray({
+                  IsAnnotatedSpan(28, 55, "address"),
+              }));
+}
+
+#ifdef TC3_UNILIB_ICU
+// A filtered collection also suppresses the span entirely: a high-priority
+// custom regex annotator that wins conflict resolution over "phone" is
+// filtered out, and the phone span it overlapped does not reappear.
+TEST_P(AnnotatorTest, AnnotateFilteredCollectionsSuppress) {
+  const std::string test_model = ReadFile(GetModelPath() + GetParam());
+
+  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+      test_model.c_str(), test_model.size(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  const std::string test_string =
+      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+      "number is 853 225 3556";
+
+  EXPECT_THAT(classifier->Annotate(test_string),
+              ElementsAreArray({
+                  IsAnnotatedSpan(28, 55, "address"),
+                  IsAnnotatedSpan(79, 91, "phone"),
+              }));
+
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+  unpacked_model->output_options.reset(new OutputOptionsT);
+
+  // We add a custom annotator that wins against the phone classification
+  // below and that we subsequently suppress.
+  unpacked_model->output_options->filtered_collections_annotation.push_back(
+      "suppress");
+
+  unpacked_model->regex_model->patterns.push_back(MakePattern(
+      "suppress", "(\\d{3} ?\\d{4})",
+      /*enabled_for_classification=*/false,
+      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 2.0));
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+  classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  EXPECT_THAT(classifier->Annotate(test_string),
+              ElementsAreArray({
+                  IsAnnotatedSpan(28, 55, "address"),
+              }));
+}
+#endif // TC3_UNILIB_ICU
+
+#ifdef TC3_CALENDAR_ICU
+// Date classification: timestamps and granularity must respect the
+// reference timezone passed in the options, including dates on later lines.
+TEST_P(AnnotatorTest, ClassifyTextDate) {
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromPath(GetModelPath() + GetParam());
+  EXPECT_TRUE(classifier);
+
+  std::vector<ClassificationResult> result;
+  ClassificationOptions options;
+
+  options.reference_timezone = "Europe/Zurich";
+  result = classifier->ClassifyText("january 1, 2017", {0, 15}, options);
+
+  ASSERT_EQ(result.size(), 1);
+  EXPECT_THAT(result[0].collection, "date");
+  // 2017-01-01 00:00 Europe/Zurich in ms since epoch.
+  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000);
+  EXPECT_EQ(result[0].datetime_parse_result.granularity,
+            DatetimeGranularity::GRANULARITY_DAY);
+  result.clear();
+
+  options.reference_timezone = "America/Los_Angeles";
+  result = classifier->ClassifyText("march 1, 2017", {0, 13}, options);
+  ASSERT_EQ(result.size(), 1);
+  EXPECT_THAT(result[0].collection, "date");
+  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1488355200000);
+  EXPECT_EQ(result[0].datetime_parse_result.granularity,
+            DatetimeGranularity::GRANULARITY_DAY);
+  result.clear();
+
+  options.reference_timezone = "America/Los_Angeles";
+  result = classifier->ClassifyText("2018/01/01 10:30:20", {0, 19}, options);
+  ASSERT_EQ(result.size(), 1);
+  EXPECT_THAT(result[0].collection, "date");
+  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1514831420000);
+  EXPECT_EQ(result[0].datetime_parse_result.granularity,
+            DatetimeGranularity::GRANULARITY_SECOND);
+  result.clear();
+
+  // Date on another line.
+  options.reference_timezone = "Europe/Zurich";
+  result = classifier->ClassifyText(
+      "hello world this is the first line\n"
+      "january 1, 2017",
+      {35, 50}, options);
+  ASSERT_EQ(result.size(), 1);
+  EXPECT_THAT(result[0].collection, "date");
+  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000);
+  EXPECT_EQ(result[0].datetime_parse_result.granularity,
+            DatetimeGranularity::GRANULARITY_DAY);
+}
+#endif  // TC3_CALENDAR_ICU
+
+#ifdef TC3_CALENDAR_ICU
+// Locale-dependent parsing: "03.05.1970" resolves to different days under
+// en-US (month first) vs. de (day first), as shown by the two timestamps.
+TEST_P(AnnotatorTest, ClassifyTextDatePriorities) {
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromPath(GetModelPath() + GetParam());
+  EXPECT_TRUE(classifier);
+
+  std::vector<ClassificationResult> result;
+  ClassificationOptions options;
+
+  result.clear();
+  options.reference_timezone = "Europe/Zurich";
+  options.locales = "en-US";
+  result = classifier->ClassifyText("03.05.1970", {0, 10}, options);
+
+  ASSERT_EQ(result.size(), 1);
+  EXPECT_THAT(result[0].collection, "date");
+  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 5439600000);
+  EXPECT_EQ(result[0].datetime_parse_result.granularity,
+            DatetimeGranularity::GRANULARITY_DAY);
+
+  result.clear();
+  options.reference_timezone = "Europe/Zurich";
+  options.locales = "de";
+  result = classifier->ClassifyText("03.05.1970", {0, 10}, options);
+
+  ASSERT_EQ(result.size(), 1);
+  EXPECT_THAT(result[0].collection, "date");
+  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 10537200000);
+  EXPECT_EQ(result[0].datetime_parse_result.granularity,
+            DatetimeGranularity::GRANULARITY_DAY);
+}
+#endif  // TC3_CALENDAR_ICU
+
+#ifdef TC3_CALENDAR_ICU
+// With datetime patterns limited to annotation+classification (selection
+// removed), dates still classify and annotate but no longer expand selections.
+TEST_P(AnnotatorTest, SuggestTextDateDisabled) {
+  const std::string test_model = ReadFile(GetModelPath() + GetParam());
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+  // Disable the patterns for selection.
+  for (int i = 0; i < unpacked_model->datetime_model->patterns.size(); i++) {
+    unpacked_model->datetime_model->patterns[i]->enabled_modes =
+        ModeFlag_ANNOTATION_AND_CLASSIFICATION;
+  }
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+  EXPECT_EQ("date",
+            FirstResult(classifier->ClassifyText("january 1, 2017", {0, 15})));
+  // Selection does not expand beyond the requested span.
+  EXPECT_EQ(classifier->SuggestSelection("january 1, 2017", {0, 7}),
+            std::make_pair(0, 7));
+  EXPECT_THAT(classifier->Annotate("january 1, 2017"),
+              ElementsAreArray({IsAnnotatedSpan(0, 15, "date")}));
+}
+#endif  // TC3_CALENDAR_ICU
+
+// Test-only subclass that exposes the protected ResolveConflicts method and
+// allows construction directly from a raw model buffer.
+class TestingAnnotator : public Annotator {
+ public:
+  TestingAnnotator(const std::string& model, const UniLib* unilib,
+                   const CalendarLib* calendarlib)
+      : Annotator(ViewModel(model.data(), model.size()), unilib, calendarlib) {}
+
+  using Annotator::ResolveConflicts;
+};
+
+// Test helper: builds an AnnotatedSpan carrying a single classification
+// result with the given collection and score.
+AnnotatedSpan MakeAnnotatedSpan(CodepointSpan span,
+                                const std::string& collection,
+                                const float score) {
+  AnnotatedSpan annotated_span;
+  annotated_span.span = span;
+  annotated_span.classification.push_back({collection, score});
+  return annotated_span;
+}
+
+// A single candidate has nothing to conflict with and is always chosen.
+TEST_F(AnnotatorTest, ResolveConflictsTrivial) {
+  TestingAnnotator classifier("", &unilib_, &calendarlib_);
+
+  std::vector<AnnotatedSpan> candidates{
+      {MakeAnnotatedSpan({0, 1}, "phone", 1.0)}};
+
+  std::vector<int> chosen;
+  classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+                              /*interpreter_manager=*/nullptr, &chosen);
+  EXPECT_THAT(chosen, ElementsAreArray({0}));
+}
+
+// Non-overlapping candidates never conflict: all of them are kept.
+TEST_F(AnnotatorTest, ResolveConflictsSequence) {
+  TestingAnnotator classifier("", &unilib_, &calendarlib_);
+
+  std::vector<AnnotatedSpan> candidates{{
+      MakeAnnotatedSpan({0, 1}, "phone", 1.0),
+      MakeAnnotatedSpan({1, 2}, "phone", 1.0),
+      MakeAnnotatedSpan({2, 3}, "phone", 1.0),
+      MakeAnnotatedSpan({3, 4}, "phone", 1.0),
+      MakeAnnotatedSpan({4, 5}, "phone", 1.0),
+  }};
+
+  std::vector<int> chosen;
+  classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+                              /*interpreter_manager=*/nullptr, &chosen);
+  EXPECT_THAT(chosen, ElementsAreArray({0, 1, 2, 3, 4}));
+}
+
+// The middle span overlaps both neighbors but scores lower, so it loses and
+// the two outer spans are kept.
+TEST_F(AnnotatorTest, ResolveConflictsThreeSpans) {
+  TestingAnnotator classifier("", &unilib_, &calendarlib_);
+
+  std::vector<AnnotatedSpan> candidates{{
+      MakeAnnotatedSpan({0, 3}, "phone", 1.0),
+      MakeAnnotatedSpan({1, 5}, "phone", 0.5),  // Loses conflict resolution.
+      MakeAnnotatedSpan({3, 7}, "phone", 1.0),
+  }};
+
+  std::vector<int> chosen;
+  classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+                              /*interpreter_manager=*/nullptr, &chosen);
+  EXPECT_THAT(chosen, ElementsAreArray({0, 2}));
+}
+
+// Mirror case: the middle span scores highest, so both overlapping outer
+// spans lose and only the middle one is chosen.
+TEST_F(AnnotatorTest, ResolveConflictsThreeSpansReversed) {
+  TestingAnnotator classifier("", &unilib_, &calendarlib_);
+
+  std::vector<AnnotatedSpan> candidates{{
+      MakeAnnotatedSpan({0, 3}, "phone", 0.5),  // Loses conflict resolution.
+      MakeAnnotatedSpan({1, 5}, "phone", 1.0),
+      MakeAnnotatedSpan({3, 7}, "phone", 0.6),  // Loses conflict resolution.
+  }};
+
+  std::vector<int> chosen;
+  classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+                              /*interpreter_manager=*/nullptr, &chosen);
+  EXPECT_THAT(chosen, ElementsAreArray({1}));
+}
+
+// Two independent conflict clusters ({0,1,2} and {3,4}) are resolved
+// separately; the expected winners are indices 0, 2 and 4.
+TEST_F(AnnotatorTest, ResolveConflictsFiveSpans) {
+  TestingAnnotator classifier("", &unilib_, &calendarlib_);
+
+  std::vector<AnnotatedSpan> candidates{{
+      MakeAnnotatedSpan({0, 3}, "phone", 0.5),
+      MakeAnnotatedSpan({1, 5}, "other", 1.0),  // Loses conflict resolution.
+      MakeAnnotatedSpan({3, 7}, "phone", 0.6),
+      MakeAnnotatedSpan({8, 12}, "phone", 0.6),  // Loses conflict resolution.
+      MakeAnnotatedSpan({11, 15}, "phone", 0.9),
+  }};
+
+  std::vector<int> chosen;
+  classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+                              /*interpreter_manager=*/nullptr, &chosen);
+  EXPECT_THAT(chosen, ElementsAreArray({0, 2, 4}));
+}
+
+#ifdef TC3_UNILIB_ICU
+// Each entity type is still found when embedded in ~100k characters of
+// whitespace, exercising Annotate, SuggestSelection, and ClassifyText on
+// long inputs.
+TEST_P(AnnotatorTest, LongInput) {
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  for (const auto& type_value_pair :
+       std::vector<std::pair<std::string, std::string>>{
+           {"address", "350 Third Street, Cambridge"},
+           {"phone", "123 456-7890"},
+           {"url", "www.google.com"},
+           {"email", "someone@gmail.com"},
+           {"flight", "LX 38"},
+           {"date", "September 1, 2018"}}) {
+    const std::string input_100k = std::string(50000, ' ') +
+                                   type_value_pair.second +
+                                   std::string(50000, ' ');
+    const int value_length = type_value_pair.second.size();
+
+    EXPECT_THAT(classifier->Annotate(input_100k),
+                ElementsAreArray({IsAnnotatedSpan(50000, 50000 + value_length,
+                                                  type_value_pair.first)}));
+    EXPECT_EQ(classifier->SuggestSelection(input_100k, {50000, 50001}),
+              std::make_pair(50000, 50000 + value_length));
+    EXPECT_EQ(type_value_pair.first,
+              FirstResult(classifier->ClassifyText(
+                  input_100k, {50000, 50000 + value_length})));
+  }
+}
+#endif // TC3_UNILIB_ICU
+
+#ifdef TC3_UNILIB_ICU
+// These coarse tests are there only to make sure the execution happens in
+// reasonable amount of time.
+// Smoke test: no assertions on the results, only that the three entry points
+// terminate in reasonable time on a ~100k-character input.
+TEST_P(AnnotatorTest, LongInputNoResultCheck) {
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  for (const std::string& value :
+       std::vector<std::string>{"http://www.aaaaaaaaaaaaaaaaaaaa.com "}) {
+    const std::string input_100k =
+        std::string(50000, ' ') + value + std::string(50000, ' ');
+    const int value_length = value.size();
+
+    classifier->Annotate(input_100k);
+    classifier->SuggestSelection(input_100k, {50000, 50001});
+    classifier->ClassifyText(input_100k, {50000, 50000 + value_length});
+  }
+}
+#endif // TC3_UNILIB_ICU
+
+#ifdef TC3_UNILIB_ICU
+// classification_options.max_num_tokens: unlimited (-1) classifies the
+// address normally; a cap smaller than the selection suppresses it to
+// "other".
+TEST_P(AnnotatorTest, MaxTokenLength) {
+  const std::string test_model = ReadFile(GetModelPath() + GetParam());
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+  std::unique_ptr<Annotator> classifier;
+
+  // With unrestricted number of tokens should behave normally.
+  unpacked_model->classification_options->max_num_tokens = -1;
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+  classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  EXPECT_EQ(FirstResult(classifier->ClassifyText(
+                "I live at 350 Third Street, Cambridge.", {10, 37})),
+            "address");
+
+  // Lower the maximum number of tokens below the selection's token count to
+  // suppress the classification.
+  unpacked_model->classification_options->max_num_tokens = 3;
+
+  flatbuffers::FlatBufferBuilder builder2;
+  FinishModelBuffer(builder2, Model::Pack(builder2, unpacked_model.get()));
+  classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder2.GetBufferPointer()),
+      builder2.GetSize(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  EXPECT_EQ(FirstResult(classifier->ClassifyText(
+                "I live at 350 Third Street, Cambridge.", {10, 37})),
+            "other");
+}
+#endif // TC3_UNILIB_ICU
+
+#ifdef TC3_UNILIB_ICU
+// classification_options.address_min_num_tokens: with no minimum the address
+// classifies normally; raising the minimum above the candidate's token count
+// suppresses the address classification to "other".
+TEST_P(AnnotatorTest, MinAddressTokenLength) {
+  const std::string test_model = ReadFile(GetModelPath() + GetParam());
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+  std::unique_ptr<Annotator> classifier;
+
+  // With unrestricted number of address tokens should behave normally.
+  unpacked_model->classification_options->address_min_num_tokens = 0;
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+  classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  EXPECT_EQ(FirstResult(classifier->ClassifyText(
+                "I live at 350 Third Street, Cambridge.", {10, 37})),
+            "address");
+
+  // Raise number of address tokens to suppress the address classification.
+  unpacked_model->classification_options->address_min_num_tokens = 5;
+
+  flatbuffers::FlatBufferBuilder builder2;
+  FinishModelBuffer(builder2, Model::Pack(builder2, unpacked_model.get()));
+  classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder2.GetBufferPointer()),
+      builder2.GetSize(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  EXPECT_EQ(FirstResult(classifier->ClassifyText(
+                "I live at 350 Third Street, Cambridge.", {10, 37})),
+            "other");
+}
+#endif // TC3_UNILIB_ICU
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/annotator/cached-features.cc b/annotator/cached-features.cc
new file mode 100644
index 0000000..480c044
--- /dev/null
+++ b/annotator/cached-features.cc
@@ -0,0 +1,173 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/cached-features.h"
+
+#include "utils/base/logging.h"
+#include "utils/tensor-view.h"
+
+namespace libtextclassifier3 {
+
+namespace {
+
+// Computes how many floats one candidate's feature vector will hold, given
+// the feature-processor configuration and the per-token feature width.
+int CalculateOutputFeaturesSize(const FeatureProcessorOptions* options,
+                                int feature_vector_size) {
+  const FeatureProcessorOptions_::BoundsSensitiveFeatures* config =
+      options->bounds_sensitive_features();
+  const bool bounds_sensitive_enabled = config && config->enabled();
+
+  // Number of token positions whose features are emitted.
+  int num_extracted_tokens;
+  if (!bounds_sensitive_enabled) {
+    // Click-context mode: the clicked token plus context_size tokens on each
+    // side.
+    num_extracted_tokens = 2 * options->context_size() + 1;
+  } else {
+    num_extracted_tokens = config->num_tokens_before() +
+                           config->num_tokens_inside_left() +
+                           config->num_tokens_inside_right() +
+                           config->num_tokens_after();
+    if (config->include_inside_bag()) {
+      // One extra pseudo-token for the bag of tokens inside the span.
+      ++num_extracted_tokens;
+    }
+  }
+
+  int output_features_size = num_extracted_tokens * feature_vector_size;
+  if (bounds_sensitive_enabled && config->include_inside_length()) {
+    // One extra scalar carrying the length of the inside span.
+    ++output_features_size;
+  }
+  return output_features_size;
+}
+
+} // namespace
+
+// Factory for CachedFeatures. Takes ownership of the precomputed per-token
+// features and the padding features, and precalculates the output feature
+// size. Returns nullptr if the model's feature version is too old for the
+// configured mode (bounds-sensitive features require version >= 2,
+// click-context requires >= 1).
+std::unique_ptr<CachedFeatures> CachedFeatures::Create(
+    const TokenSpan& extraction_span,
+    std::unique_ptr<std::vector<float>> features,
+    std::unique_ptr<std::vector<float>> padding_features,
+    const FeatureProcessorOptions* options, int feature_vector_size) {
+  const int min_feature_version =
+      options->bounds_sensitive_features() &&
+              options->bounds_sensitive_features()->enabled()
+          ? 2
+          : 1;
+  if (options->feature_version() < min_feature_version) {
+    TC3_LOG(ERROR) << "Unsupported feature version.";
+    return nullptr;
+  }
+
+  std::unique_ptr<CachedFeatures> cached_features(new CachedFeatures());
+  cached_features->extraction_span_ = extraction_span;
+  cached_features->features_ = std::move(features);
+  cached_features->padding_features_ = std::move(padding_features);
+  cached_features->options_ = options;
+
+  cached_features->output_features_size_ =
+      CalculateOutputFeaturesSize(options, feature_vector_size);
+
+  return cached_features;
+}
+
+// Emits features for the window of context_size tokens on either side of the
+// clicked token. The click position is first shifted into the coordinates of
+// the cached extraction span; positions outside it yield padding features.
+void CachedFeatures::AppendClickContextFeaturesForClick(
+    int click_pos, std::vector<float>* output_features) const {
+  // Convert from absolute token index to an index within extraction_span_.
+  click_pos -= extraction_span_.first;
+
+  AppendFeaturesInternal(
+      /*intended_span=*/ExpandTokenSpan(SingleTokenSpan(click_pos),
+                                        options_->context_size(),
+                                        options_->context_size()),
+      /*read_mask_span=*/{0, TokenSpanSize(extraction_span_)}, output_features);
+}
+
+// Emits the bounds-sensitive feature vector for `selected_span`: token
+// windows around the left and right bound, optionally a bag-of-tokens
+// feature for the span interior, and optionally the span length. The span is
+// first shifted into the coordinates of the cached extraction span.
+void CachedFeatures::AppendBoundsSensitiveFeaturesForSpan(
+    TokenSpan selected_span, std::vector<float>* output_features) const {
+  const FeatureProcessorOptions_::BoundsSensitiveFeatures* config =
+      options_->bounds_sensitive_features();
+
+  selected_span.first -= extraction_span_.first;
+  selected_span.second -= extraction_span_.first;
+
+  // Append the features for tokens around the left bound. Masks out tokens
+  // after the right bound, so that if num_tokens_inside_left goes past it,
+  // padding tokens will be used.
+  AppendFeaturesInternal(
+      /*intended_span=*/{selected_span.first - config->num_tokens_before(),
+                         selected_span.first +
+                             config->num_tokens_inside_left()},
+      /*read_mask_span=*/{0, selected_span.second}, output_features);
+
+  // Append the features for tokens around the right bound. Masks out tokens
+  // before the left bound, so that if num_tokens_inside_right goes past it,
+  // padding tokens will be used.
+  AppendFeaturesInternal(
+      /*intended_span=*/{selected_span.second -
+                             config->num_tokens_inside_right(),
+                         selected_span.second + config->num_tokens_after()},
+      /*read_mask_span=*/{selected_span.first, TokenSpanSize(extraction_span_)},
+      output_features);
+
+  if (config->include_inside_bag()) {
+    AppendBagFeatures(selected_span, output_features);
+  }
+
+  if (config->include_inside_length()) {
+    output_features->push_back(
+        static_cast<float>(TokenSpanSize(selected_span)));
+  }
+}
+
+// Copies the cached features for the tokens in `intended_span`, restricted
+// to `read_mask_span`; positions of intended_span that fall outside the mask
+// (before or after the readable region) are filled with padding features.
+void CachedFeatures::AppendFeaturesInternal(
+    const TokenSpan& intended_span, const TokenSpan& read_mask_span,
+    std::vector<float>* output_features) const {
+  const TokenSpan copy_span =
+      IntersectTokenSpans(intended_span, read_mask_span);
+  // Padding for intended positions left of the readable region.
+  for (int i = intended_span.first; i < copy_span.first; ++i) {
+    AppendPaddingFeatures(output_features);
+  }
+  output_features->insert(
+      output_features->end(),
+      features_->begin() + copy_span.first * NumFeaturesPerToken(),
+      features_->begin() + copy_span.second * NumFeaturesPerToken());
+  // Padding for intended positions right of the readable region.
+  for (int i = copy_span.second; i < intended_span.second; ++i) {
+    AppendPaddingFeatures(output_features);
+  }
+}
+
+// Appends one padding token's worth of features to the output.
+void CachedFeatures::AppendPaddingFeatures(
+    std::vector<float>* output_features) const {
+  const std::vector<float>& padding = *padding_features_;
+  output_features->insert(output_features->end(), padding.begin(),
+                          padding.end());
+}
+
+// Appends the average of the feature vectors of the tokens in `bag_span`
+// (a bag-of-tokens embedding of the span) to the output.
+void CachedFeatures::AppendBagFeatures(
+    const TokenSpan& bag_span, std::vector<float>* output_features) const {
+  const int offset = output_features->size();
+  output_features->resize(output_features->size() + NumFeaturesPerToken());
+  // Hoisted loop invariants: the per-token feature width and the span size
+  // used for normalization.
+  const int num_features_per_token = NumFeaturesPerToken();
+  const float bag_size = TokenSpanSize(bag_span);
+  for (int i = bag_span.first; i < bag_span.second; ++i) {
+    for (int j = 0; j < num_features_per_token; ++j) {
+      (*output_features)[offset + j] +=
+          (*features_)[i * num_features_per_token + j] / bag_size;
+    }
+  }
+}
+
+// The per-token feature width; the padding features hold exactly one token's
+// worth of features, so their size defines it.
+int CachedFeatures::NumFeaturesPerToken() const {
+  return padding_features_->size();
+}
+
+} // namespace libtextclassifier3
diff --git a/annotator/cached-features.h b/annotator/cached-features.h
new file mode 100644
index 0000000..e03f79c
--- /dev/null
+++ b/annotator/cached-features.h
@@ -0,0 +1,83 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_CACHED_FEATURES_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_CACHED_FEATURES_H_
+
+#include <memory>
+#include <vector>
+
+#include "annotator/model-executor.h"
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+
+namespace libtextclassifier3 {
+
+// Holds state for extracting features across multiple calls and reusing them.
+// Assumes that features for each Token are independent.
+class CachedFeatures {
+ public:
+  // Factory. 'features' holds the flattened per-token feature matrix for
+  // 'extraction_span' and 'padding_features' one token's worth of padding
+  // values; both are taken over by the new instance. May return nullptr on
+  // failure (callers null-check the result) — confirm failure conditions in
+  // the implementation.
+  static std::unique_ptr<CachedFeatures> Create(
+      const TokenSpan& extraction_span,
+      std::unique_ptr<std::vector<float>> features,
+      std::unique_ptr<std::vector<float>> padding_features,
+      const FeatureProcessorOptions* options, int feature_vector_size);
+
+  // Appends the click context features for the given click position to
+  // 'output_features'.
+  void AppendClickContextFeaturesForClick(
+      int click_pos, std::vector<float>* output_features) const;
+
+  // Appends the bounds-sensitive features for the given token span to
+  // 'output_features'.
+  void AppendBoundsSensitiveFeaturesForSpan(
+      TokenSpan selected_span, std::vector<float>* output_features) const;
+
+  // Returns number of features that 'AppendFeaturesForSpan' appends.
+  int OutputFeaturesSize() const { return output_features_size_; }
+
+ private:
+  CachedFeatures() {}
+
+  // Appends token features to the output. The intended_span specifies which
+  // tokens' features should be used in principle. The read_mask_span restricts
+  // which tokens are actually read. For tokens outside of the read_mask_span,
+  // padding tokens are used instead.
+  void AppendFeaturesInternal(const TokenSpan& intended_span,
+                              const TokenSpan& read_mask_span,
+                              std::vector<float>* output_features) const;
+
+  // Appends features of one padding token to the output.
+  void AppendPaddingFeatures(std::vector<float>* output_features) const;
+
+  // Appends the features of tokens from the given span to the output. The
+  // features are averaged so that the appended features have the size
+  // corresponding to one token.
+  void AppendBagFeatures(const TokenSpan& bag_span,
+                         std::vector<float>* output_features) const;
+
+  // Width of one token's feature vector; equals padding_features_->size().
+  int NumFeaturesPerToken() const;
+
+  // Token span for which features_ was extracted.
+  TokenSpan extraction_span_;
+  // Not owned; presumably points into the model flatbuffer, which must
+  // outlive this object.
+  const FeatureProcessorOptions* options_;
+  // Precomputed size of one Append*ForSpan/ForClick output.
+  int output_features_size_;
+  // Flattened [token][feature] matrix, row-major.
+  std::unique_ptr<std::vector<float>> features_;
+  // Features of a single padding token.
+  std::unique_ptr<std::vector<float>> padding_features_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_CACHED_FEATURES_H_
diff --git a/annotator/cached-features_test.cc b/annotator/cached-features_test.cc
new file mode 100644
index 0000000..702f3ca
--- /dev/null
+++ b/annotator/cached-features_test.cc
@@ -0,0 +1,157 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/cached-features.h"
+
+#include "annotator/model-executor.h"
+#include "utils/tensor-view.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using testing::ElementsAreArray;
+using testing::FloatEq;
+using testing::Matcher;
+
+namespace libtextclassifier3 {
+namespace {
+
+// Builds a gmock matcher that compares a float vector element-wise with
+// FloatEq, tolerating floating-point rounding differences.
+Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) {
+  std::vector<Matcher<float>> matchers;
+  for (const float value : values) {
+    matchers.push_back(FloatEq(value));
+  }
+  return ElementsAreArray(matchers);
+}
+
+// Creates fake features for 'num_tokens' tokens, three floats per token:
+// token i (1-based) gets {i * 11.0, -i * 11.0, i * 0.1}, making each token's
+// features easy to recognize in the expectations below.
+std::unique_ptr<std::vector<float>> MakeFeatures(int num_tokens) {
+  std::unique_ptr<std::vector<float>> features(new std::vector<float>());
+  for (int i = 1; i <= num_tokens; ++i) {
+    features->push_back(i * 11.0f);
+    features->push_back(-i * 11.0f);
+    features->push_back(i * 0.1f);
+  }
+  return features;
+}
+
+// Convenience wrapper: returns the click-context features for 'click_pos' as
+// a fresh vector.
+std::vector<float> GetCachedClickContextFeatures(
+    const CachedFeatures& cached_features, int click_pos) {
+  std::vector<float> output_features;
+  cached_features.AppendClickContextFeaturesForClick(click_pos,
+                                                     &output_features);
+  return output_features;
+}
+
+// Convenience wrapper: returns the bounds-sensitive features for
+// 'selected_span' as a fresh vector.
+std::vector<float> GetCachedBoundsSensitiveFeatures(
+    const CachedFeatures& cached_features, TokenSpan selected_span) {
+  std::vector<float> output_features;
+  cached_features.AppendBoundsSensitiveFeaturesForSpan(selected_span,
+                                                       &output_features);
+  return output_features;
+}
+
+// With context_size = 2, each click yields a 5-token window (click +/- 2
+// tokens), i.e. 15 floats at 3 features per token. All clicks here are
+// interior to the extraction span, so no padding features appear.
+TEST(CachedFeaturesTest, ClickContext) {
+  FeatureProcessorOptionsT options;
+  options.context_size = 2;
+  options.feature_version = 1;
+  flatbuffers::FlatBufferBuilder builder;
+  builder.Finish(CreateFeatureProcessorOptions(builder, &options));
+  flatbuffers::DetachedBuffer options_fb = builder.Release();
+
+  std::unique_ptr<std::vector<float>> features = MakeFeatures(9);
+  std::unique_ptr<std::vector<float>> padding_features(
+      new std::vector<float>{112233.0, -112233.0, 321.0});
+
+  const std::unique_ptr<CachedFeatures> cached_features =
+      CachedFeatures::Create(
+          {3, 10}, std::move(features), std::move(padding_features),
+          flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+          /*feature_vector_size=*/3);
+  ASSERT_TRUE(cached_features);
+
+  EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 5),
+              ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0, -33.0,
+                                0.3, 44.0, -44.0, 0.4, 55.0, -55.0, 0.5}));
+
+  EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 6),
+              ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3, 44.0, -44.0,
+                                0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6}));
+
+  EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 7),
+              ElementsAreFloat({33.0, -33.0, 0.3, 44.0, -44.0, 0.4, 55.0, -55.0,
+                                0.5, 66.0, -66.0, 0.6, 77.0, -77.0, 0.7}));
+}
+
+// Output layout per span, given the config below:
+//   2 tokens before | 2 inside-left | 2 inside-right | 2 after
+//   (8 tokens * 3 floats = 24 values), then the averaged inside-bag token
+//   (3 values) and the span length (1 value) = 28 floats total. Padding
+//   features ({112233, -112233, 321}) appear where a window extends past
+//   the extraction span or past the selected span's interior.
+TEST(CachedFeaturesTest, BoundsSensitive) {
+  std::unique_ptr<FeatureProcessorOptions_::BoundsSensitiveFeaturesT> config(
+      new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
+  config->enabled = true;
+  config->num_tokens_before = 2;
+  config->num_tokens_inside_left = 2;
+  config->num_tokens_inside_right = 2;
+  config->num_tokens_after = 2;
+  config->include_inside_bag = true;
+  config->include_inside_length = true;
+  FeatureProcessorOptionsT options;
+  options.bounds_sensitive_features = std::move(config);
+  options.feature_version = 2;
+  flatbuffers::FlatBufferBuilder builder;
+  builder.Finish(CreateFeatureProcessorOptions(builder, &options));
+  flatbuffers::DetachedBuffer options_fb = builder.Release();
+
+  std::unique_ptr<std::vector<float>> features = MakeFeatures(9);
+  std::unique_ptr<std::vector<float>> padding_features(
+      new std::vector<float>{112233.0, -112233.0, 321.0});
+
+  const std::unique_ptr<CachedFeatures> cached_features =
+      CachedFeatures::Create(
+          {3, 9}, std::move(features), std::move(padding_features),
+          flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+          /*feature_vector_size=*/3);
+  ASSERT_TRUE(cached_features);
+
+  EXPECT_THAT(
+      GetCachedBoundsSensitiveFeatures(*cached_features, {5, 8}),
+      ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0,
+                        -33.0, 0.3, 44.0, -44.0, 0.4, 44.0, -44.0,
+                        0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6,
+                        112233.0, -112233.0, 321.0, 44.0, -44.0, 0.4, 3.0}));
+
+  EXPECT_THAT(
+      GetCachedBoundsSensitiveFeatures(*cached_features, {5, 7}),
+      ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0,
+                        -33.0, 0.3, 44.0, -44.0, 0.4, 33.0, -33.0,
+                        0.3, 44.0, -44.0, 0.4, 55.0, -55.0, 0.5,
+                        66.0, -66.0, 0.6, 38.5, -38.5, 0.35, 2.0}));
+
+  EXPECT_THAT(
+      GetCachedBoundsSensitiveFeatures(*cached_features, {6, 8}),
+      ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3, 44.0,
+                        -44.0, 0.4, 55.0, -55.0, 0.5, 44.0, -44.0,
+                        0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6,
+                        112233.0, -112233.0, 321.0, 49.5, -49.5, 0.45, 2.0}));
+
+  // Single-token span: both the inside-left and inside-right windows run out
+  // of interior tokens, so padding features fill the remainder.
+  EXPECT_THAT(
+      GetCachedBoundsSensitiveFeatures(*cached_features, {6, 7}),
+      ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3,
+                        44.0, -44.0, 0.4, 112233.0, -112233.0, 321.0,
+                        112233.0, -112233.0, 321.0, 44.0, -44.0, 0.4,
+                        55.0, -55.0, 0.5, 66.0, -66.0, 0.6,
+                        44.0, -44.0, 0.4, 1.0}));
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/annotator/datetime/extractor.cc b/annotator/datetime/extractor.cc
new file mode 100644
index 0000000..31229dd
--- /dev/null
+++ b/annotator/datetime/extractor.cc
@@ -0,0 +1,469 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/datetime/extractor.h"
+
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+
+// Walks all capturing groups of the rule's regex for the current match of
+// 'matcher_', parses each used group into the corresponding field of
+// 'result', and records which fields were set in 'field_set_mask'.
+// 'result_span' is widened to the union of the spans of all matched, used
+// groups. Returns false if any group's content fails to parse; unknown
+// group types are logged and skipped.
+bool DatetimeExtractor::Extract(DateParseData* result,
+                                CodepointSpan* result_span) const {
+  result->field_set_mask = 0;
+  *result_span = {kInvalidIndex, kInvalidIndex};
+
+  // No group metadata in the model means there is nothing to extract.
+  if (rule_.regex->groups() == nullptr) {
+    return false;
+  }
+
+  for (int group_id = 0; group_id < rule_.regex->groups()->size(); group_id++) {
+    UnicodeText group_text;
+    const int group_type = rule_.regex->groups()->Get(group_id);
+    if (group_type == DatetimeGroupType_GROUP_UNUSED) {
+      continue;
+    }
+    if (!GroupTextFromMatch(group_id, &group_text)) {
+      TC3_LOG(ERROR) << "Couldn't retrieve group.";
+      return false;
+    }
+    // The pattern can have a group defined in a part that was not matched,
+    // e.g. an optional part. In this case we'll get an empty content here.
+    if (group_text.empty()) {
+      continue;
+    }
+    switch (group_type) {
+      case DatetimeGroupType_GROUP_YEAR: {
+        if (!ParseYear(group_text, &(result->year))) {
+          TC3_LOG(ERROR) << "Couldn't extract YEAR.";
+          return false;
+        }
+        result->field_set_mask |= DateParseData::YEAR_FIELD;
+        break;
+      }
+      case DatetimeGroupType_GROUP_MONTH: {
+        if (!ParseMonth(group_text, &(result->month))) {
+          TC3_LOG(ERROR) << "Couldn't extract MONTH.";
+          return false;
+        }
+        result->field_set_mask |= DateParseData::MONTH_FIELD;
+        break;
+      }
+      case DatetimeGroupType_GROUP_DAY: {
+        if (!ParseDigits(group_text, &(result->day_of_month))) {
+          TC3_LOG(ERROR) << "Couldn't extract DAY.";
+          return false;
+        }
+        result->field_set_mask |= DateParseData::DAY_FIELD;
+        break;
+      }
+      case DatetimeGroupType_GROUP_HOUR: {
+        if (!ParseDigits(group_text, &(result->hour))) {
+          TC3_LOG(ERROR) << "Couldn't extract HOUR.";
+          return false;
+        }
+        result->field_set_mask |= DateParseData::HOUR_FIELD;
+        break;
+      }
+      case DatetimeGroupType_GROUP_MINUTE: {
+        if (!ParseDigits(group_text, &(result->minute))) {
+          TC3_LOG(ERROR) << "Couldn't extract MINUTE.";
+          return false;
+        }
+        result->field_set_mask |= DateParseData::MINUTE_FIELD;
+        break;
+      }
+      case DatetimeGroupType_GROUP_SECOND: {
+        if (!ParseDigits(group_text, &(result->second))) {
+          TC3_LOG(ERROR) << "Couldn't extract SECOND.";
+          return false;
+        }
+        result->field_set_mask |= DateParseData::SECOND_FIELD;
+        break;
+      }
+      case DatetimeGroupType_GROUP_AMPM: {
+        if (!ParseAMPM(group_text, &(result->ampm))) {
+          TC3_LOG(ERROR) << "Couldn't extract AMPM.";
+          return false;
+        }
+        result->field_set_mask |= DateParseData::AMPM_FIELD;
+        break;
+      }
+      case DatetimeGroupType_GROUP_RELATIONDISTANCE: {
+        if (!ParseRelationDistance(group_text, &(result->relation_distance))) {
+          TC3_LOG(ERROR) << "Couldn't extract RELATION_DISTANCE_FIELD.";
+          return false;
+        }
+        result->field_set_mask |= DateParseData::RELATION_DISTANCE_FIELD;
+        break;
+      }
+      case DatetimeGroupType_GROUP_RELATION: {
+        if (!ParseRelation(group_text, &(result->relation))) {
+          TC3_LOG(ERROR) << "Couldn't extract RELATION_FIELD.";
+          return false;
+        }
+        result->field_set_mask |= DateParseData::RELATION_FIELD;
+        break;
+      }
+      case DatetimeGroupType_GROUP_RELATIONTYPE: {
+        if (!ParseRelationType(group_text, &(result->relation_type))) {
+          TC3_LOG(ERROR) << "Couldn't extract RELATION_TYPE_FIELD.";
+          return false;
+        }
+        result->field_set_mask |= DateParseData::RELATION_TYPE_FIELD;
+        break;
+      }
+      // Dummy groups set no fields but still widen the result span below.
+      case DatetimeGroupType_GROUP_DUMMY1:
+      case DatetimeGroupType_GROUP_DUMMY2:
+        break;
+      default:
+        TC3_LOG(INFO) << "Unknown group type.";
+        continue;
+    }
+    if (!UpdateMatchSpan(group_id, result_span)) {
+      TC3_LOG(ERROR) << "Couldn't update span.";
+      return false;
+    }
+  }
+
+  // If either bound was never set, normalize to a fully invalid span rather
+  // than leaving it half-initialized.
+  if (result_span->first == kInvalidIndex ||
+      result_span->second == kInvalidIndex) {
+    *result_span = {kInvalidIndex, kInvalidIndex};
+  }
+
+  return true;
+}
+
+// Looks up the extractor rule index for ('type', locale_id_) in the
+// two-level map. Returns false if either the type or the locale has no rule.
+bool DatetimeExtractor::RuleIdForType(DatetimeExtractorType type,
+                                      int* rule_id) const {
+  auto type_it = type_and_locale_to_rule_.find(type);
+  if (type_it == type_and_locale_to_rule_.end()) {
+    return false;
+  }
+
+  auto locale_it = type_it->second.find(locale_id_);
+  if (locale_it == type_it->second.end()) {
+    return false;
+  }
+  *rule_id = locale_it->second;
+  return true;
+}
+
+// Runs the locale-specific rule for 'extractor_type' over 'input'. Returns
+// true iff the rule exists and matched. On match, optionally stores the
+// matched text in 'match_result'.
+// NOTE(review): Group(&status) with no index presumably returns the full
+// match, while the header comment promises "the first group of the rule" —
+// confirm against the UniLib::RegexMatcher API.
+bool DatetimeExtractor::ExtractType(const UnicodeText& input,
+                                    DatetimeExtractorType extractor_type,
+                                    UnicodeText* match_result) const {
+  int rule_id;
+  if (!RuleIdForType(extractor_type, &rule_id)) {
+    return false;
+  }
+
+  std::unique_ptr<UniLib::RegexMatcher> matcher =
+      rules_[rule_id]->Matcher(input);
+  if (!matcher) {
+    return false;
+  }
+
+  int status;
+  if (!matcher->Find(&status)) {
+    return false;
+  }
+
+  if (match_result != nullptr) {
+    *match_result = matcher->Group(&status);
+    if (status != UniLib::RegexMatcher::kNoError) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Fetches the text matched by capturing group 'group_id' in the current
+// match of matcher_. Returns false on a matcher error.
+bool DatetimeExtractor::GroupTextFromMatch(int group_id,
+                                           UnicodeText* result) const {
+  int status;
+  *result = matcher_.Group(group_id, &status);
+  if (status != UniLib::RegexMatcher::kNoError) {
+    return false;
+  }
+  return true;
+}
+
+// Widens 'span' so that it also covers the span of capturing group
+// 'group_id' in the current match. kInvalidIndex bounds are treated as
+// "not yet set". Returns false on a matcher error.
+bool DatetimeExtractor::UpdateMatchSpan(int group_id,
+                                        CodepointSpan* span) const {
+  int status;
+  const int match_start = matcher_.Start(group_id, &status);
+  if (status != UniLib::RegexMatcher::kNoError) {
+    return false;
+  }
+  const int match_end = matcher_.End(group_id, &status);
+  if (status != UniLib::RegexMatcher::kNoError) {
+    return false;
+  }
+  if (span->first == kInvalidIndex || span->first > match_start) {
+    span->first = match_start;
+  }
+  if (span->second == kInvalidIndex || span->second < match_end) {
+    span->second = match_end;
+  }
+
+  return true;
+}
+
+// Tries each (extractor type, value) pair in 'mapping' in order; on the
+// first extractor whose rule matches 'input', stores the associated value
+// in 'result' and returns true. Returns false if none match.
+template <typename T>
+bool DatetimeExtractor::MapInput(
+    const UnicodeText& input,
+    const std::vector<std::pair<DatetimeExtractorType, T>>& mapping,
+    T* result) const {
+  for (const auto& type_value_pair : mapping) {
+    if (ExtractType(input, type_value_pair.first)) {
+      *result = type_value_pair.second;
+      return true;
+    }
+  }
+  return false;
+}
+
+// Finds every written-number word ("one", "twenty", "thousand", ...) in
+// 'input', orders the hits by their start position, and folds them into a
+// single integer in 'parsed_number'. Returns false only on a rule-lookup or
+// matcher error, not on "no numbers found".
+bool DatetimeExtractor::ParseWrittenNumber(const UnicodeText& input,
+                                           int* parsed_number) const {
+  // Pairs of (start offset, numeric value) for every matched number word.
+  std::vector<std::pair<int, int>> found_numbers;
+  for (const auto& type_value_pair :
+       std::vector<std::pair<DatetimeExtractorType, int>>{
+           {DatetimeExtractorType_ZERO, 0},
+           {DatetimeExtractorType_ONE, 1},
+           {DatetimeExtractorType_TWO, 2},
+           {DatetimeExtractorType_THREE, 3},
+           {DatetimeExtractorType_FOUR, 4},
+           {DatetimeExtractorType_FIVE, 5},
+           {DatetimeExtractorType_SIX, 6},
+           {DatetimeExtractorType_SEVEN, 7},
+           {DatetimeExtractorType_EIGHT, 8},
+           {DatetimeExtractorType_NINE, 9},
+           {DatetimeExtractorType_TEN, 10},
+           {DatetimeExtractorType_ELEVEN, 11},
+           {DatetimeExtractorType_TWELVE, 12},
+           {DatetimeExtractorType_THIRTEEN, 13},
+           {DatetimeExtractorType_FOURTEEN, 14},
+           {DatetimeExtractorType_FIFTEEN, 15},
+           {DatetimeExtractorType_SIXTEEN, 16},
+           {DatetimeExtractorType_SEVENTEEN, 17},
+           {DatetimeExtractorType_EIGHTEEN, 18},
+           {DatetimeExtractorType_NINETEEN, 19},
+           {DatetimeExtractorType_TWENTY, 20},
+           {DatetimeExtractorType_THIRTY, 30},
+           {DatetimeExtractorType_FORTY, 40},
+           {DatetimeExtractorType_FIFTY, 50},
+           {DatetimeExtractorType_SIXTY, 60},
+           {DatetimeExtractorType_SEVENTY, 70},
+           {DatetimeExtractorType_EIGHTY, 80},
+           {DatetimeExtractorType_NINETY, 90},
+           {DatetimeExtractorType_HUNDRED, 100},
+           {DatetimeExtractorType_THOUSAND, 1000},
+       }) {
+    int rule_id;
+    if (!RuleIdForType(type_value_pair.first, &rule_id)) {
+      return false;
+    }
+
+    std::unique_ptr<UniLib::RegexMatcher> matcher =
+        rules_[rule_id]->Matcher(input);
+    if (!matcher) {
+      return false;
+    }
+
+    // Record every occurrence of this number word, not just the first.
+    int status;
+    while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
+      int span_start = matcher->Start(&status);
+      if (status != UniLib::RegexMatcher::kNoError) {
+        return false;
+      }
+      found_numbers.push_back({span_start, type_value_pair.second});
+    }
+  }
+
+  // Restore textual order, since the loop above grouped hits by word type.
+  std::sort(found_numbers.begin(), found_numbers.end(),
+            [](const std::pair<int, int>& a, const std::pair<int, int>& b) {
+              return a.first < b.first;
+            });
+
+  int sum = 0;
+  int running_value = -1;
+  // Simple math to make sure we handle written numerical modifiers correctly
+  // so that :="fifty one thousand and one" maps to 51001 and not 50 1 1000 1.
+  // NOTE(review): Tracing this loop on "fifty one thousand and one"
+  // (values 50, 1, 1000, 1) gives 50 + 1*1000 + 1 = 1051, not the 51001 the
+  // comment above promises — the 50 is flushed into 'sum' before the
+  // multiplication by 1000 can see it. Confirm intended semantics / test
+  // coverage.
+  for (const std::pair<int, int> position_number_pair : found_numbers) {
+    if (running_value >= 0) {
+      if (running_value > position_number_pair.second) {
+        sum += running_value;
+        running_value = position_number_pair.second;
+      } else {
+        running_value *= position_number_pair.second;
+      }
+    } else {
+      running_value = position_number_pair.second;
+    }
+  }
+  sum += running_value;
+  *parsed_number = sum;
+  return true;
+}
+
+// Extracts the DIGITS group from 'input' and parses it as a decimal int via
+// unilib_. Returns false if no digits matched or parsing failed.
+bool DatetimeExtractor::ParseDigits(const UnicodeText& input,
+                                    int* parsed_digits) const {
+  UnicodeText digit;
+  if (!ExtractType(input, DatetimeExtractorType_DIGITS, &digit)) {
+    return false;
+  }
+
+  if (!unilib_.ParseInt32(digit, parsed_digits)) {
+    return false;
+  }
+  return true;
+}
+
+// Parses a year from 'input'. Two-digit years are windowed: 00-49 map to
+// 2000-2049 and 50-99 to 1950-1999.
+bool DatetimeExtractor::ParseYear(const UnicodeText& input,
+                                  int* parsed_year) const {
+  if (!ParseDigits(input, parsed_year)) {
+    return false;
+  }
+
+  if (*parsed_year < 100) {
+    if (*parsed_year < 50) {
+      *parsed_year += 2000;
+    } else {
+      *parsed_year += 1900;
+    }
+  }
+
+  return true;
+}
+
+// Parses a month from 'input' as a 1-based number: first as digits, then by
+// matching locale-specific month-name rules (January == 1, ...).
+bool DatetimeExtractor::ParseMonth(const UnicodeText& input,
+                                   int* parsed_month) const {
+  if (ParseDigits(input, parsed_month)) {
+    return true;
+  }
+
+  if (MapInput(input,
+               {
+                   {DatetimeExtractorType_JANUARY, 1},
+                   {DatetimeExtractorType_FEBRUARY, 2},
+                   {DatetimeExtractorType_MARCH, 3},
+                   {DatetimeExtractorType_APRIL, 4},
+                   {DatetimeExtractorType_MAY, 5},
+                   {DatetimeExtractorType_JUNE, 6},
+                   {DatetimeExtractorType_JULY, 7},
+                   {DatetimeExtractorType_AUGUST, 8},
+                   {DatetimeExtractorType_SEPTEMBER, 9},
+                   {DatetimeExtractorType_OCTOBER, 10},
+                   {DatetimeExtractorType_NOVEMBER, 11},
+                   {DatetimeExtractorType_DECEMBER, 12},
+               },
+               parsed_month)) {
+    return true;
+  }
+
+  return false;
+}
+
+// Maps an AM/PM marker in 'input' to the DateParseData::AMPM enum value.
+bool DatetimeExtractor::ParseAMPM(const UnicodeText& input,
+                                  int* parsed_ampm) const {
+  return MapInput(input,
+                  {
+                      {DatetimeExtractorType_AM, DateParseData::AMPM::AM},
+                      {DatetimeExtractorType_PM, DateParseData::AMPM::PM},
+                  },
+                  parsed_ampm);
+}
+
+// Parses a relation distance ("in 3 days" -> 3): digits first, then written
+// numbers ("three").
+bool DatetimeExtractor::ParseRelationDistance(const UnicodeText& input,
+                                              int* parsed_distance) const {
+  if (ParseDigits(input, parsed_distance)) {
+    return true;
+  }
+  if (ParseWrittenNumber(input, parsed_distance)) {
+    return true;
+  }
+  return false;
+}
+
+// Maps a relative-time word in 'input' ("now", "tomorrow", "next", ...) to
+// the DateParseData::Relation enum.
+bool DatetimeExtractor::ParseRelation(
+    const UnicodeText& input, DateParseData::Relation* parsed_relation) const {
+  return MapInput(
+      input,
+      {
+          {DatetimeExtractorType_NOW, DateParseData::Relation::NOW},
+          {DatetimeExtractorType_YESTERDAY, DateParseData::Relation::YESTERDAY},
+          {DatetimeExtractorType_TOMORROW, DateParseData::Relation::TOMORROW},
+          {DatetimeExtractorType_NEXT, DateParseData::Relation::NEXT},
+          {DatetimeExtractorType_NEXT_OR_SAME,
+           DateParseData::Relation::NEXT_OR_SAME},
+          {DatetimeExtractorType_LAST, DateParseData::Relation::LAST},
+          {DatetimeExtractorType_PAST, DateParseData::Relation::PAST},
+          {DatetimeExtractorType_FUTURE, DateParseData::Relation::FUTURE},
+      },
+      parsed_relation);
+}
+
+// Maps the unit a relation refers to ("next <Monday|week|month|...>") to
+// the DateParseData::RelationType enum.
+bool DatetimeExtractor::ParseRelationType(
+    const UnicodeText& input,
+    DateParseData::RelationType* parsed_relation_type) const {
+  return MapInput(
+      input,
+      {
+          {DatetimeExtractorType_MONDAY, DateParseData::MONDAY},
+          {DatetimeExtractorType_TUESDAY, DateParseData::TUESDAY},
+          {DatetimeExtractorType_WEDNESDAY, DateParseData::WEDNESDAY},
+          {DatetimeExtractorType_THURSDAY, DateParseData::THURSDAY},
+          {DatetimeExtractorType_FRIDAY, DateParseData::FRIDAY},
+          {DatetimeExtractorType_SATURDAY, DateParseData::SATURDAY},
+          {DatetimeExtractorType_SUNDAY, DateParseData::SUNDAY},
+          {DatetimeExtractorType_DAY, DateParseData::DAY},
+          {DatetimeExtractorType_WEEK, DateParseData::WEEK},
+          {DatetimeExtractorType_MONTH, DateParseData::MONTH},
+          {DatetimeExtractorType_YEAR, DateParseData::YEAR},
+      },
+      parsed_relation_type);
+}
+
+// Maps a plural time-unit word ("days", "weeks", ...) to the corresponding
+// DateParseData time-unit constant.
+bool DatetimeExtractor::ParseTimeUnit(const UnicodeText& input,
+                                      int* parsed_time_unit) const {
+  return MapInput(input,
+                  {
+                      {DatetimeExtractorType_DAYS, DateParseData::DAYS},
+                      {DatetimeExtractorType_WEEKS, DateParseData::WEEKS},
+                      {DatetimeExtractorType_MONTHS, DateParseData::MONTHS},
+                      {DatetimeExtractorType_HOURS, DateParseData::HOURS},
+                      {DatetimeExtractorType_MINUTES, DateParseData::MINUTES},
+                      {DatetimeExtractorType_SECONDS, DateParseData::SECONDS},
+                      {DatetimeExtractorType_YEARS, DateParseData::YEARS},
+                  },
+                  parsed_time_unit);
+}
+
+// Maps a weekday name in 'input' to the corresponding DateParseData weekday
+// constant.
+bool DatetimeExtractor::ParseWeekday(const UnicodeText& input,
+                                     int* parsed_weekday) const {
+  return MapInput(
+      input,
+      {
+          {DatetimeExtractorType_MONDAY, DateParseData::MONDAY},
+          {DatetimeExtractorType_TUESDAY, DateParseData::TUESDAY},
+          {DatetimeExtractorType_WEDNESDAY, DateParseData::WEDNESDAY},
+          {DatetimeExtractorType_THURSDAY, DateParseData::THURSDAY},
+          {DatetimeExtractorType_FRIDAY, DateParseData::FRIDAY},
+          {DatetimeExtractorType_SATURDAY, DateParseData::SATURDAY},
+          {DatetimeExtractorType_SUNDAY, DateParseData::SUNDAY},
+      },
+      parsed_weekday);
+}
+
+} // namespace libtextclassifier3
diff --git a/annotator/datetime/extractor.h b/annotator/datetime/extractor.h
new file mode 100644
index 0000000..4c17aa7
--- /dev/null
+++ b/annotator/datetime/extractor.h
@@ -0,0 +1,111 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_EXTRACTOR_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_EXTRACTOR_H_
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+// A datetime rule: the compiled regex paired with the model metadata it was
+// built from. The raw pointers are not owned and presumably point into the
+// model flatbuffer, which must outlive this struct — confirm ownership at
+// the construction site.
+struct CompiledRule {
+  // The compiled regular expression.
+  std::unique_ptr<const UniLib::RegexPattern> compiled_regex;
+
+  // The uncompiled pattern and information about the pattern groups.
+  const DatetimeModelPattern_::Regex* regex;
+
+  // DatetimeModelPattern which 'regex' is part of and comes from.
+  const DatetimeModelPattern* pattern;
+};
+
+// A helper class for DatetimeParser that extracts structured data
+// (DateParseDate) from the current match of the passed RegexMatcher.
+class DatetimeExtractor {
+ public:
+  // All reference/pointer arguments are captured by reference and must
+  // outlive the extractor; the extractor itself owns nothing.
+  DatetimeExtractor(
+      const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
+      int locale_id, const UniLib& unilib,
+      const std::vector<std::unique_ptr<const UniLib::RegexPattern>>&
+          extractor_rules,
+      const std::unordered_map<DatetimeExtractorType,
+                               std::unordered_map<int, int>>&
+          type_and_locale_to_extractor_rule)
+      : rule_(rule),
+        matcher_(matcher),
+        locale_id_(locale_id),
+        unilib_(unilib),
+        rules_(extractor_rules),
+        type_and_locale_to_rule_(type_and_locale_to_extractor_rule) {}
+  // Parses the current match of 'matcher_' into 'result' and widens
+  // 'result_span' to cover the matched groups. Returns false on error.
+  bool Extract(DateParseData* result, CodepointSpan* result_span) const;
+
+ private:
+  // Looks up the rule index for (type, locale_id_); false if none exists.
+  bool RuleIdForType(DatetimeExtractorType type, int* rule_id) const;
+
+  // Returns true if the rule for given extractor matched. If it matched,
+  // match_result will contain the first group of the rule (if match_result not
+  // nullptr).
+  bool ExtractType(const UnicodeText& input,
+                   DatetimeExtractorType extractor_type,
+                   UnicodeText* match_result = nullptr) const;
+
+  bool GroupTextFromMatch(int group_id, UnicodeText* result) const;
+
+  // Updates the span to include the current match for the given group.
+  bool UpdateMatchSpan(int group_id, CodepointSpan* span) const;
+
+  // Returns true if any of the extractors from 'mapping' matched. If it did,
+  // will fill 'result' with the associated value from 'mapping'.
+  template <typename T>
+  bool MapInput(const UnicodeText& input,
+                const std::vector<std::pair<DatetimeExtractorType, T>>& mapping,
+                T* result) const;
+
+  // Field-specific parsers; each returns false when 'input' could not be
+  // interpreted as the corresponding kind of value.
+  bool ParseDigits(const UnicodeText& input, int* parsed_digits) const;
+  bool ParseWrittenNumber(const UnicodeText& input, int* parsed_number) const;
+  bool ParseYear(const UnicodeText& input, int* parsed_year) const;
+  bool ParseMonth(const UnicodeText& input, int* parsed_month) const;
+  bool ParseAMPM(const UnicodeText& input, int* parsed_ampm) const;
+  bool ParseRelation(const UnicodeText& input,
+                     DateParseData::Relation* parsed_relation) const;
+  bool ParseRelationDistance(const UnicodeText& input,
+                             int* parsed_distance) const;
+  bool ParseTimeUnit(const UnicodeText& input, int* parsed_time_unit) const;
+  bool ParseRelationType(
+      const UnicodeText& input,
+      DateParseData::RelationType* parsed_relation_type) const;
+  bool ParseWeekday(const UnicodeText& input, int* parsed_weekday) const;
+
+  // All members are borrowed references; see the constructor comment.
+  const CompiledRule& rule_;
+  const UniLib::RegexMatcher& matcher_;
+  int locale_id_;
+  const UniLib& unilib_;
+  const std::vector<std::unique_ptr<const UniLib::RegexPattern>>& rules_;
+  const std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>>&
+      type_and_locale_to_rule_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_EXTRACTOR_H_
diff --git a/annotator/datetime/parser.cc b/annotator/datetime/parser.cc
new file mode 100644
index 0000000..ac3a62d
--- /dev/null
+++ b/annotator/datetime/parser.cc
@@ -0,0 +1,406 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/datetime/parser.h"
+
+#include <set>
+#include <unordered_set>
+
+#include "annotator/datetime/extractor.h"
+#include "utils/calendar/calendar.h"
+#include "utils/i18n/locale.h"
+#include "utils/strings/split.h"
+
+namespace libtextclassifier3 {
+// Factory for DatetimeParser. Runs the (protected) constructor and returns
+// nullptr when initialization failed — e.g. a rule or extractor regex from
+// the flatbuffer model could not be compiled.
+std::unique_ptr<DatetimeParser> DatetimeParser::Instance(
+ const DatetimeModel* model, const UniLib& unilib,
+ const CalendarLib& calendarlib, ZlibDecompressor* decompressor) {
+ std::unique_ptr<DatetimeParser> result(
+ new DatetimeParser(model, unilib, calendarlib, decompressor));
+ if (!result->initialized_) {
+ // Constructor bailed out part-way; surface the failure as nullptr.
+ result.reset();
+ }
+ return result;
+}
+
+// Builds the parser state from the flatbuffer DatetimeModel:
+// - compiles every rule regex and indexes rules by locale id,
+// - compiles every extractor regex and indexes them by (type, locale),
+// - maps locale name strings to their integer ids,
+// - records the default locale ids.
+// Leaves initialized_ == false (and returns early) on any regex compile
+// failure, which Instance() turns into a nullptr result.
+DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib& unilib,
+ const CalendarLib& calendarlib,
+ ZlibDecompressor* decompressor)
+ : unilib_(unilib), calendarlib_(calendarlib) {
+ initialized_ = false;
+
+ if (model == nullptr) {
+ return;
+ }
+
+ // Compile the datetime matching rules and build the locale -> rule index.
+ if (model->patterns() != nullptr) {
+ for (const DatetimeModelPattern* pattern : *model->patterns()) {
+ if (pattern->regexes()) {
+ for (const DatetimeModelPattern_::Regex* regex : *pattern->regexes()) {
+ std::unique_ptr<UniLib::RegexPattern> regex_pattern =
+ UncompressMakeRegexPattern(unilib, regex->pattern(),
+ regex->compressed_pattern(),
+ decompressor);
+ if (!regex_pattern) {
+ TC3_LOG(ERROR) << "Couldn't create rule pattern.";
+ return;
+ }
+ rules_.push_back({std::move(regex_pattern), regex, pattern});
+ if (pattern->locales()) {
+ for (int locale : *pattern->locales()) {
+ // Index of the rule just appended.
+ locale_to_rules_[locale].push_back(rules_.size() - 1);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // Compile the field-extractor regexes, indexed by extractor type and locale.
+ if (model->extractors() != nullptr) {
+ for (const DatetimeModelExtractor* extractor : *model->extractors()) {
+ std::unique_ptr<UniLib::RegexPattern> regex_pattern =
+ UncompressMakeRegexPattern(unilib, extractor->pattern(),
+ extractor->compressed_pattern(),
+ decompressor);
+ if (!regex_pattern) {
+ TC3_LOG(ERROR) << "Couldn't create extractor pattern";
+ return;
+ }
+ extractor_rules_.push_back(std::move(regex_pattern));
+
+ if (extractor->locales()) {
+ for (int locale : *extractor->locales()) {
+ type_and_locale_to_extractor_rule_[extractor->extractor()][locale] =
+ extractor_rules_.size() - 1;
+ }
+ }
+ }
+ }
+
+ // Locale ids are simply the position of the locale string in the model list.
+ if (model->locales() != nullptr) {
+ for (int i = 0; i < model->locales()->Length(); ++i) {
+ locale_string_to_id_[model->locales()->Get(i)->str()] = i;
+ }
+ }
+
+ if (model->default_locales() != nullptr) {
+ for (const int locale : *model->default_locales()) {
+ default_locale_ids_.push_back(locale);
+ }
+ }
+
+ use_extractors_for_locating_ = model->use_extractors_for_locating();
+
+ initialized_ = true;
+}
+
+// Convenience overload: wraps the UTF-8 string in a non-copying UnicodeText
+// view and delegates to the UnicodeText overload below.
+bool DatetimeParser::Parse(
+ const std::string& input, const int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const std::string& locales,
+ ModeFlag mode, bool anchor_start_end,
+ std::vector<DatetimeParseResultSpan>* results) const {
+ return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
+ reference_time_ms_utc, reference_timezone, locales, mode,
+ anchor_start_end, results);
+}
+
+// Runs all rules registered for the given locale ids over 'input' and appends
+// matches to 'found_spans'. 'executed_rules' is shared across locales so a
+// rule registered under several of the requested locales runs only once (with
+// the first locale id that reached it). Returns false on matcher errors.
+bool DatetimeParser::FindSpansUsingLocales(
+ const std::vector<int>& locale_ids, const UnicodeText& input,
+ const int64 reference_time_ms_utc, const std::string& reference_timezone,
+ ModeFlag mode, bool anchor_start_end, const std::string& reference_locale,
+ std::unordered_set<int>* executed_rules,
+ std::vector<DatetimeParseResultSpan>* found_spans) const {
+ for (const int locale_id : locale_ids) {
+ auto rules_it = locale_to_rules_.find(locale_id);
+ if (rules_it == locale_to_rules_.end()) {
+ continue;
+ }
+
+ for (const int rule_id : rules_it->second) {
+ // Skip rules that were already executed in previous locales.
+ if (executed_rules->find(rule_id) != executed_rules->end()) {
+ continue;
+ }
+
+ // Honor the per-pattern mode mask (annotation/classification/selection).
+ if (!(rules_[rule_id].pattern->enabled_modes() & mode)) {
+ continue;
+ }
+
+ executed_rules->insert(rule_id);
+
+ if (!ParseWithRule(rules_[rule_id], input, reference_time_ms_utc,
+ reference_timezone, reference_locale, locale_id,
+ anchor_start_end, found_spans)) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+// Main entry point: finds all datetime spans for the requested locales, then
+// resolves overlaps so the returned results never conflict with each other.
+bool DatetimeParser::Parse(
+ const UnicodeText& input, const int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const std::string& locales,
+ ModeFlag mode, bool anchor_start_end,
+ std::vector<DatetimeParseResultSpan>* results) const {
+ std::vector<DatetimeParseResultSpan> found_spans;
+ std::unordered_set<int> executed_rules;
+ std::string reference_locale;
+ const std::vector<int> requested_locales =
+ ParseAndExpandLocales(locales, &reference_locale);
+ if (!FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc,
+ reference_timezone, mode, anchor_start_end,
+ reference_locale, &executed_rules, &found_spans)) {
+ return false;
+ }
+
+ // Remember each span's original position so ties can be broken stably.
+ std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans;
+ int counter = 0;
+ for (const auto& found_span : found_spans) {
+ indexed_found_spans.push_back({found_span, counter});
+ counter++;
+ }
+
+ // Resolve conflicts by always picking the longer span, breaking ties by
+ // the original insertion order (i.e. earlier rule/locale wins).
+ std::sort(indexed_found_spans.begin(), indexed_found_spans.end(),
+ [](const std::pair<DatetimeParseResultSpan, int>& a,
+ const std::pair<DatetimeParseResultSpan, int>& b) {
+ if ((a.first.span.second - a.first.span.first) !=
+ (b.first.span.second - b.first.span.first)) {
+ return (a.first.span.second - a.first.span.first) >
+ (b.first.span.second - b.first.span.first);
+ } else {
+ return a.second < b.second;
+ }
+ });
+
+ found_spans.clear();
+ for (auto& span_index_pair : indexed_found_spans) {
+ found_spans.push_back(span_index_pair.first);
+ }
+
+ // Greedily accept spans in priority order, rejecting any that overlap an
+ // already-chosen span. The set is ordered by span start for the check.
+ std::set<int, std::function<bool(int, int)>> chosen_indices_set(
+ [&found_spans](int a, int b) {
+ return found_spans[a].span.first < found_spans[b].span.first;
+ });
+ for (int i = 0; i < found_spans.size(); ++i) {
+ if (!DoesCandidateConflict(i, found_spans, chosen_indices_set)) {
+ chosen_indices_set.insert(i);
+ results->push_back(found_spans[i]);
+ }
+ }
+
+ return true;
+}
+
+// Converts the current match of 'matcher' into a DatetimeParseResultSpan and
+// appends it to 'result'. Returns false only on matcher/extraction errors; a
+// match that yields an invalid span is silently dropped (still returns true).
+bool DatetimeParser::HandleParseMatch(
+ const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
+ int64 reference_time_ms_utc, const std::string& reference_timezone,
+ const std::string& reference_locale, int locale_id,
+ std::vector<DatetimeParseResultSpan>* result) const {
+ int status = UniLib::RegexMatcher::kNoError;
+ const int start = matcher.Start(&status);
+ if (status != UniLib::RegexMatcher::kNoError) {
+ return false;
+ }
+
+ const int end = matcher.End(&status);
+ if (status != UniLib::RegexMatcher::kNoError) {
+ return false;
+ }
+
+ DatetimeParseResultSpan parse_result;
+ if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone,
+ reference_locale, locale_id, &(parse_result.data),
+ &parse_result.span)) {
+ return false;
+ }
+ // Unless the model asked for extractor-based locating, the reported span is
+ // the full regex match rather than the span computed by the extractor.
+ if (!use_extractors_for_locating_) {
+ parse_result.span = {start, end};
+ }
+ if (parse_result.span.first != kInvalidIndex &&
+ parse_result.span.second != kInvalidIndex) {
+ parse_result.target_classification_score =
+ rule.pattern->target_classification_score();
+ parse_result.priority_score = rule.pattern->priority_score();
+ result->push_back(parse_result);
+ }
+ return true;
+}
+
+// Applies a single compiled rule to 'input'. With anchor_start_end the regex
+// must match the whole input (Matches); otherwise every non-overlapping
+// occurrence is processed (Find loop). Returns false on matcher errors.
+bool DatetimeParser::ParseWithRule(
+ const CompiledRule& rule, const UnicodeText& input,
+ const int64 reference_time_ms_utc, const std::string& reference_timezone,
+ const std::string& reference_locale, const int locale_id,
+ bool anchor_start_end, std::vector<DatetimeParseResultSpan>* result) const {
+ std::unique_ptr<UniLib::RegexMatcher> matcher =
+ rule.compiled_regex->Matcher(input);
+ int status = UniLib::RegexMatcher::kNoError;
+ if (anchor_start_end) {
+ if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) {
+ if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
+ reference_timezone, reference_locale, locale_id,
+ result)) {
+ return false;
+ }
+ }
+ } else {
+ while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
+ if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
+ reference_timezone, reference_locale, locale_id,
+ result)) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+// Maps a comma-separated BCP-47 locale spec to model locale ids. For each
+// requested locale it adds, in order: the exact locale, then the "*-region",
+// "language-script-*" and "language-*" wildcard entries known to the model.
+// Default locale ids are appended last, deduplicated against the expansion.
+// The first entry of 'locales' (verbatim) becomes 'reference_locale'.
+std::vector<int> DatetimeParser::ParseAndExpandLocales(
+ const std::string& locales, std::string* reference_locale) const {
+ std::vector<StringPiece> split_locales = strings::Split(locales, ',');
+ if (!split_locales.empty()) {
+ *reference_locale = split_locales[0].ToString();
+ } else {
+ *reference_locale = "";
+ }
+
+ std::vector<int> result;
+ for (const StringPiece& locale_str : split_locales) {
+ auto locale_it = locale_string_to_id_.find(locale_str.ToString());
+ if (locale_it != locale_string_to_id_.end()) {
+ result.push_back(locale_it->second);
+ }
+
+ const Locale locale = Locale::FromBCP47(locale_str.ToString());
+ if (!locale.IsValid()) {
+ // Unparseable locale strings still matched the exact lookup above (if
+ // present) but get no wildcard expansion.
+ continue;
+ }
+
+ const std::string language = locale.Language();
+ const std::string script = locale.Script();
+ const std::string region = locale.Region();
+
+ // First, try adding *-region locale.
+ if (!region.empty()) {
+ locale_it = locale_string_to_id_.find("*-" + region);
+ if (locale_it != locale_string_to_id_.end()) {
+ result.push_back(locale_it->second);
+ }
+ }
+ // Second, try adding language-script-* locale.
+ if (!script.empty()) {
+ locale_it = locale_string_to_id_.find(language + "-" + script + "-*");
+ if (locale_it != locale_string_to_id_.end()) {
+ result.push_back(locale_it->second);
+ }
+ }
+ // Third, try adding language-* locale.
+ if (!language.empty()) {
+ locale_it = locale_string_to_id_.find(language + "-*");
+ if (locale_it != locale_string_to_id_.end()) {
+ result.push_back(locale_it->second);
+ }
+ }
+ }
+
+ // Add the default locales if they haven't been added already.
+ const std::unordered_set<int> result_set(result.begin(), result.end());
+ for (const int default_locale_id : default_locale_ids_) {
+ if (result_set.find(default_locale_id) == result_set.end()) {
+ result.push_back(default_locale_id);
+ }
+ }
+
+ return result;
+}
+
+namespace {
+
+// Derives the finest granularity implied by the extracted fields. The checks
+// run from coarse (year) to fine (second) and each later match overwrites the
+// previous value, so the last (finest) satisfied condition wins.
+DatetimeGranularity GetGranularity(const DateParseData& data) {
+ DatetimeGranularity granularity = DatetimeGranularity::GRANULARITY_YEAR;
+ if ((data.field_set_mask & DateParseData::YEAR_FIELD) ||
+ (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
+ (data.relation_type == DateParseData::RelationType::YEAR))) {
+ granularity = DatetimeGranularity::GRANULARITY_YEAR;
+ }
+ if ((data.field_set_mask & DateParseData::MONTH_FIELD) ||
+ (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
+ (data.relation_type == DateParseData::RelationType::MONTH))) {
+ granularity = DatetimeGranularity::GRANULARITY_MONTH;
+ }
+ if (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
+ (data.relation_type == DateParseData::RelationType::WEEK)) {
+ granularity = DatetimeGranularity::GRANULARITY_WEEK;
+ }
+ // Day granularity covers explicit day fields, now/tomorrow/yesterday
+ // relations, and weekday (or generic DAY) relation types.
+ if (data.field_set_mask & DateParseData::DAY_FIELD ||
+ (data.field_set_mask & DateParseData::RELATION_FIELD &&
+ (data.relation == DateParseData::Relation::NOW ||
+ data.relation == DateParseData::Relation::TOMORROW ||
+ data.relation == DateParseData::Relation::YESTERDAY)) ||
+ (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
+ (data.relation_type == DateParseData::RelationType::MONDAY ||
+ data.relation_type == DateParseData::RelationType::TUESDAY ||
+ data.relation_type == DateParseData::RelationType::WEDNESDAY ||
+ data.relation_type == DateParseData::RelationType::THURSDAY ||
+ data.relation_type == DateParseData::RelationType::FRIDAY ||
+ data.relation_type == DateParseData::RelationType::SATURDAY ||
+ data.relation_type == DateParseData::RelationType::SUNDAY ||
+ data.relation_type == DateParseData::RelationType::DAY))) {
+ granularity = DatetimeGranularity::GRANULARITY_DAY;
+ }
+ if (data.field_set_mask & DateParseData::HOUR_FIELD) {
+ granularity = DatetimeGranularity::GRANULARITY_HOUR;
+ }
+ if (data.field_set_mask & DateParseData::MINUTE_FIELD) {
+ granularity = DatetimeGranularity::GRANULARITY_MINUTE;
+ }
+ if (data.field_set_mask & DateParseData::SECOND_FIELD) {
+ granularity = DatetimeGranularity::GRANULARITY_SECOND;
+ }
+ return granularity;
+}
+
+} // namespace
+
+// Runs the field extractors on the current regex match, derives the result's
+// granularity, and asks the calendar library to resolve the extracted fields
+// to absolute UTC milliseconds relative to the reference time/timezone.
+bool DatetimeParser::ExtractDatetime(const CompiledRule& rule,
+ const UniLib::RegexMatcher& matcher,
+ const int64 reference_time_ms_utc,
+ const std::string& reference_timezone,
+ const std::string& reference_locale,
+ int locale_id, DatetimeParseResult* result,
+ CodepointSpan* result_span) const {
+ DateParseData parse;
+ DatetimeExtractor extractor(rule, matcher, locale_id, unilib_,
+ extractor_rules_,
+ type_and_locale_to_extractor_rule_);
+ if (!extractor.Extract(&parse, result_span)) {
+ return false;
+ }
+
+ result->granularity = GetGranularity(parse);
+
+ if (!calendarlib_.InterpretParseData(
+ parse, reference_time_ms_utc, reference_timezone, reference_locale,
+ result->granularity, &(result->time_ms_utc))) {
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/annotator/datetime/parser.h b/annotator/datetime/parser.h
new file mode 100644
index 0000000..c7eaf1f
--- /dev/null
+++ b/annotator/datetime/parser.h
@@ -0,0 +1,118 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_PARSER_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_PARSER_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "annotator/datetime/extractor.h"
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/base/integral_types.h"
+#include "utils/calendar/calendar.h"
+#include "utils/utf8/unilib.h"
+#include "utils/zlib/zlib.h"
+
+namespace libtextclassifier3 {
+
+// Parses datetime expressions in the input and resolves them to actual absolute
+// time.
+class DatetimeParser {
+ public:
+ // Creates a parser from the flatbuffer model; returns nullptr on failure
+ // (e.g. uncompilable rule/extractor regexes).
+ static std::unique_ptr<DatetimeParser> Instance(
+ const DatetimeModel* model, const UniLib& unilib,
+ const CalendarLib& calendarlib, ZlibDecompressor* decompressor);
+
+ // Parses the dates in 'input' and fills result. Makes sure that the results
+ // do not overlap.
+ // If 'anchor_start_end' is true the extracted results need to start at the
+ // beginning of 'input' and end at the end of it.
+ bool Parse(const std::string& input, int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const std::string& locales,
+ ModeFlag mode, bool anchor_start_end,
+ std::vector<DatetimeParseResultSpan>* results) const;
+
+ // Same as above but takes UnicodeText.
+ bool Parse(const UnicodeText& input, int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const std::string& locales,
+ ModeFlag mode, bool anchor_start_end,
+ std::vector<DatetimeParseResultSpan>* results) const;
+
+ protected:
+ // Use Instance(); the constructor signals failure via initialized_.
+ DatetimeParser(const DatetimeModel* model, const UniLib& unilib,
+ const CalendarLib& calendarlib,
+ ZlibDecompressor* decompressor);
+
+ // Returns a list of locale ids for given locale spec string (comma-separated
+ // locale names). Assigns the first parsed locale to reference_locale.
+ std::vector<int> ParseAndExpandLocales(const std::string& locales,
+ std::string* reference_locale) const;
+
+ // Helper function that finds datetime spans, only using the rules associated
+ // with the given locales.
+ bool FindSpansUsingLocales(
+ const std::vector<int>& locale_ids, const UnicodeText& input,
+ const int64 reference_time_ms_utc, const std::string& reference_timezone,
+ ModeFlag mode, bool anchor_start_end, const std::string& reference_locale,
+ std::unordered_set<int>* executed_rules,
+ std::vector<DatetimeParseResultSpan>* found_spans) const;
+
+ // Applies one compiled rule to the input (anchored or find-all).
+ bool ParseWithRule(const CompiledRule& rule, const UnicodeText& input,
+ int64 reference_time_ms_utc,
+ const std::string& reference_timezone,
+ const std::string& reference_locale, const int locale_id,
+ bool anchor_start_end,
+ std::vector<DatetimeParseResultSpan>* result) const;
+
+ // Converts the current match in 'matcher' into DatetimeParseResult.
+ bool ExtractDatetime(const CompiledRule& rule,
+ const UniLib::RegexMatcher& matcher,
+ int64 reference_time_ms_utc,
+ const std::string& reference_timezone,
+ const std::string& reference_locale, int locale_id,
+ DatetimeParseResult* result,
+ CodepointSpan* result_span) const;
+
+ // Parse and extract information from current match in 'matcher'.
+ bool HandleParseMatch(const CompiledRule& rule,
+ const UniLib::RegexMatcher& matcher,
+ int64 reference_time_ms_utc,
+ const std::string& reference_timezone,
+ const std::string& reference_locale, int locale_id,
+ std::vector<DatetimeParseResultSpan>* result) const;
+
+ private:
+ // False iff construction failed part-way; checked by Instance().
+ bool initialized_;
+ const UniLib& unilib_;
+ const CalendarLib& calendarlib_;
+ // Compiled matching rules, in model order.
+ std::vector<CompiledRule> rules_;
+ // Locale id -> indices into rules_.
+ std::unordered_map<int, std::vector<int>> locale_to_rules_;
+ // Compiled field-extractor patterns, in model order.
+ std::vector<std::unique_ptr<const UniLib::RegexPattern>> extractor_rules_;
+ // (extractor type, locale id) -> index into extractor_rules_.
+ std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>>
+ type_and_locale_to_extractor_rule_;
+ // Locale name (as listed in the model) -> locale id.
+ std::unordered_map<std::string, int> locale_string_to_id_;
+ std::vector<int> default_locale_ids_;
+ // Whether spans come from the extractors rather than the full regex match.
+ bool use_extractors_for_locating_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_PARSER_H_
diff --git a/annotator/datetime/parser_test.cc b/annotator/datetime/parser_test.cc
new file mode 100644
index 0000000..d46accf
--- /dev/null
+++ b/annotator/datetime/parser_test.cc
@@ -0,0 +1,413 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <time.h>
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <string>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#include "annotator/annotator.h"
+#include "annotator/datetime/parser.h"
+#include "annotator/model_generated.h"
+#include "annotator/types-test-util.h"
+
+using testing::ElementsAreArray;
+
+namespace libtextclassifier3 {
+namespace {
+
+// Returns the directory holding the test model files (set at build time).
+std::string GetModelPath() {
+ return TC3_TEST_DATA_DIR;
+}
+
+// Reads the whole file into a string (empty on open failure).
+std::string ReadFile(const std::string& file_name) {
+ std::ifstream file_stream(file_name);
+ return std::string(std::istreambuf_iterator<char>(file_stream), {});
+}
+
+// Formats epoch milliseconds as a human-readable local time for test logs.
+// NOTE(review): uses localtime(), which is not thread-safe — fine for these
+// single-threaded tests, but consider localtime_r if that ever changes.
+std::string FormatMillis(int64 time_ms_utc) {
+ long time_seconds = time_ms_utc / 1000; // NOLINT
+ // Format time, "ddd yyyy-mm-dd hh:mm:ss zzz"
+ char buffer[512];
+ strftime(buffer, sizeof(buffer), "%a %Y-%m-%d %H:%M:%S %Z",
+ localtime(&time_seconds));
+ return std::string(buffer);
+}
+
+// Fixture that loads the test model through Annotator and exercises its
+// datetime parser. Expected spans are marked in the input with {braces}.
+class ParserTest : public testing::Test {
+ public:
+ void SetUp() override {
+ model_buffer_ = ReadFile(GetModelPath() + "test_model.fb");
+ classifier_ = Annotator::FromUnownedBuffer(model_buffer_.data(),
+ model_buffer_.size(), &unilib_);
+ TC3_CHECK(classifier_);
+ parser_ = classifier_->DatetimeParserForTests();
+ }
+
+ // True iff parsing 'text' (reference time 0) produces no results.
+ bool HasNoResult(const std::string& text, bool anchor_start_end = false,
+ const std::string& timezone = "Europe/Zurich") {
+ std::vector<DatetimeParseResultSpan> results;
+ if (!parser_->Parse(text, 0, timezone, /*locales=*/"", ModeFlag_ANNOTATION,
+ anchor_start_end, &results)) {
+ TC3_LOG(ERROR) << text;
+ TC3_CHECK(false);
+ }
+ return results.empty();
+ }
+
+ // Parses 'marked_text' (with the braces stripped) at reference time 0 and
+ // checks that a result overlapping the {marked} span matches the expected
+ // time and granularity. Logs the mismatching results before returning false.
+ bool ParsesCorrectly(const std::string& marked_text,
+ const int64 expected_ms_utc,
+ DatetimeGranularity expected_granularity,
+ bool anchor_start_end = false,
+ const std::string& timezone = "Europe/Zurich",
+ const std::string& locales = "en-US") {
+ const UnicodeText marked_text_unicode =
+ UTF8ToUnicodeText(marked_text, /*do_copy=*/false);
+ auto brace_open_it =
+ std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '{');
+ auto brace_end_it =
+ std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '}');
+ TC3_CHECK(brace_open_it != marked_text_unicode.end());
+ TC3_CHECK(brace_end_it != marked_text_unicode.end());
+
+ // Reassemble the text without the marker braces.
+ std::string text;
+ text +=
+ UnicodeText::UTF8Substring(marked_text_unicode.begin(), brace_open_it);
+ text += UnicodeText::UTF8Substring(std::next(brace_open_it), brace_end_it);
+ text += UnicodeText::UTF8Substring(std::next(brace_end_it),
+ marked_text_unicode.end());
+
+ std::vector<DatetimeParseResultSpan> results;
+
+ if (!parser_->Parse(text, 0, timezone, locales, ModeFlag_ANNOTATION,
+ anchor_start_end, &results)) {
+ TC3_LOG(ERROR) << text;
+ TC3_CHECK(false);
+ }
+ if (results.empty()) {
+ TC3_LOG(ERROR) << "No results.";
+ return false;
+ }
+
+ const int expected_start_index =
+ std::distance(marked_text_unicode.begin(), brace_open_it);
+ // The -1 below is to account for the opening bracket character.
+ const int expected_end_index =
+ std::distance(marked_text_unicode.begin(), brace_end_it) - 1;
+
+ // Only results overlapping the marked span count; others are ignored.
+ std::vector<DatetimeParseResultSpan> filtered_results;
+ for (const DatetimeParseResultSpan& result : results) {
+ if (SpansOverlap(result.span,
+ {expected_start_index, expected_end_index})) {
+ filtered_results.push_back(result);
+ }
+ }
+
+ const std::vector<DatetimeParseResultSpan> expected{
+ {{expected_start_index, expected_end_index},
+ {expected_ms_utc, expected_granularity},
+ /*target_classification_score=*/1.0,
+ /*priority_score=*/0.1}};
+ const bool matches =
+ testing::Matches(ElementsAreArray(expected))(filtered_results);
+ if (!matches) {
+ TC3_LOG(ERROR) << "Expected: " << expected[0] << " which corresponds to: "
+ << FormatMillis(expected[0].data.time_ms_utc);
+ for (int i = 0; i < filtered_results.size(); ++i) {
+ TC3_LOG(ERROR) << "Actual[" << i << "]: " << filtered_results[i]
+ << " which corresponds to: "
+ << FormatMillis(filtered_results[i].data.time_ms_utc);
+ }
+ }
+ return matches;
+ }
+
+ // Convenience wrapper: same as ParsesCorrectly but with German locale.
+ bool ParsesCorrectlyGerman(const std::string& marked_text,
+ const int64 expected_ms_utc,
+ DatetimeGranularity expected_granularity) {
+ return ParsesCorrectly(marked_text, expected_ms_utc, expected_granularity,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"de");
+ }
+
+ protected:
+ std::string model_buffer_;
+ std::unique_ptr<Annotator> classifier_;
+ // Owned by classifier_.
+ const DatetimeParser* parser_;
+ UniLib unilib_;
+};
+
+// Test with just a few cases to make debugging of general failures easier.
+// Single smoke-test case; see ParserTest.Parse for the full matrix.
+TEST_F(ParserTest, ParseShort) {
+ EXPECT_TRUE(
+ ParsesCorrectly("{January 1, 1988}", 567990000000, GRANULARITY_DAY));
+}
+
+// English (en-US) formats: absolute dates, date-times, bare times, relative
+// expressions, and a match embedded in long surrounding prose. All expected
+// values are epoch millis for reference time 0 in Europe/Zurich.
+TEST_F(ParserTest, Parse) {
+ EXPECT_TRUE(
+ ParsesCorrectly("{January 1, 1988}", 567990000000, GRANULARITY_DAY));
+ EXPECT_TRUE(
+ ParsesCorrectly("{january 31 2018}", 1517353200000, GRANULARITY_DAY));
+ EXPECT_TRUE(ParsesCorrectly("lorem {1 january 2018} ipsum", 1514761200000,
+ GRANULARITY_DAY));
+ EXPECT_TRUE(ParsesCorrectly("{09/Mar/2004 22:02:40}", 1078866160000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{Dec 2, 2010 2:39:58 AM}", 1291253998000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{Jun 09 2011 15:28:14}", 1307626094000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(
+ ParsesCorrectly("{Mar 16 08:12:04}", 6419524000, GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{2010-06-26 02:31:29}", 1277512289000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{2006/01/22 04:11:05}", 1137899465000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{11:42:35}", 38555000, GRANULARITY_SECOND));
+ EXPECT_TRUE(
+ ParsesCorrectly("{23/Apr 11:42:35}", 9715355000, GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{23/Apr/2015 11:42:35}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{23-Apr-2015 11:42:35}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{23 Apr 2015 11:42:35}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{04/23/15 11:42:35}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{04/23/2015 11:42:35}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{9/28/2011 2:23:15 PM}", 1317212595000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly(
+ "Are sentiments apartments decisively the especially alteration. "
+ "Thrown shy denote ten ladies though ask saw. Or by to he going "
+ "think order event music. Incommode so intention defective at "
+ "convinced. Led income months itself and houses you. After nor "
+ "you leave might share court balls. {19/apr/2010 06:36:15} Are "
+ "sentiments apartments decisively the especially alteration. "
+ "Thrown shy denote ten ladies though ask saw. Or by to he going "
+ "think order event music. Incommode so intention defective at "
+ "convinced. Led income months itself and houses you. After nor "
+ "you leave might share court balls. ",
+ 1271651775000, GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30}", 1514777400000,
+ GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30 am}", 1514777400000,
+ GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4pm}", 1514818800000,
+ GRANULARITY_HOUR));
+
+ // Relative expressions; the second case checks timezone sensitivity.
+ EXPECT_TRUE(ParsesCorrectly("{today at 0:00}", -3600000, GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectly("{today at 0:00}", -57600000, GRANULARITY_MINUTE,
+ /*anchor_start_end=*/false,
+ "America/Los_Angeles"));
+ EXPECT_TRUE(
+ ParsesCorrectly("{tomorrow at 4:00}", 97200000, GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectly("{tomorrow at 4am}", 97200000, GRANULARITY_HOUR));
+ EXPECT_TRUE(
+ ParsesCorrectly("{wednesday at 4am}", 529200000, GRANULARITY_HOUR));
+ EXPECT_TRUE(ParsesCorrectly("last seen {today at 9:01 PM}", 72060000,
+ GRANULARITY_MINUTE));
+}
+
+// anchor_start_end: full-string matches succeed anchored, embedded ones don't.
+TEST_F(ParserTest, ParseWithAnchor) {
+ EXPECT_TRUE(ParsesCorrectly("{January 1, 1988}", 567990000000,
+ GRANULARITY_DAY, /*anchor_start_end=*/false));
+ EXPECT_TRUE(ParsesCorrectly("{January 1, 1988}", 567990000000,
+ GRANULARITY_DAY, /*anchor_start_end=*/true));
+ EXPECT_TRUE(ParsesCorrectly("lorem {1 january 2018} ipsum", 1514761200000,
+ GRANULARITY_DAY, /*anchor_start_end=*/false));
+ EXPECT_TRUE(HasNoResult("lorem 1 january 2018 ipsum",
+ /*anchor_start_end=*/true));
+}
+
+// German ("de") formats, including month names, 'um'/'nachm'/'vorm' phrasing,
+// dotted dates and 'morgen' relatives.
+TEST_F(ParserTest, ParseGerman) {
+ EXPECT_TRUE(
+ ParsesCorrectlyGerman("{Januar 1 2018}", 1514761200000, GRANULARITY_DAY));
+ EXPECT_TRUE(
+ ParsesCorrectlyGerman("{1 2 2018}", 1517439600000, GRANULARITY_DAY));
+ EXPECT_TRUE(ParsesCorrectlyGerman("lorem {1 Januar 2018} ipsum",
+ 1514761200000, GRANULARITY_DAY));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{19/Apr/2010:06:36:15}", 1271651775000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{09/März/2004 22:02:40}", 1078866160000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{Dez 2, 2010 2:39:58}", 1291253998000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{Juni 09 2011 15:28:14}", 1307626094000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{März 16 08:12:04}", 6419524000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{2010-06-26 02:31:29}", 1277512289000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{2006/01/22 04:11:05}", 1137899465000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(
+ ParsesCorrectlyGerman("{11:42:35}", 38555000, GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr 11:42:35}", 9715355000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr/2015:11:42:35}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr/2015 11:42:35}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{23-Apr-2015 11:42:35}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{23 Apr 2015 11:42:35}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/15 11:42:35}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/2015 11:42:35}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{19/apr/2010:06:36:15}", 1271651775000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4:30}", 1514777400000,
+ GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4:30 nachm}",
+ 1514820600000, GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4 nachm}", 1514818800000,
+ GRANULARITY_HOUR));
+ EXPECT_TRUE(
+ ParsesCorrectlyGerman("{14.03.2017}", 1489446000000, GRANULARITY_DAY));
+ EXPECT_TRUE(
+ ParsesCorrectlyGerman("{morgen 0:00}", 82800000, GRANULARITY_MINUTE));
+ EXPECT_TRUE(
+ ParsesCorrectlyGerman("{morgen um 4:00}", 97200000, GRANULARITY_MINUTE));
+ EXPECT_TRUE(
+ ParsesCorrectlyGerman("{morgen um 4 vorm}", 97200000, GRANULARITY_HOUR));
+}
+
+// Ambiguous "1/5/15": non-US locales read it day-first (1 May 2015).
+TEST_F(ParserTest, ParseNonUs) {
+ EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1430431200000, GRANULARITY_DAY,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich",
+ /*locales=*/"en-GB"));
+ EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1430431200000, GRANULARITY_DAY,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en"));
+}
+
+// Ambiguous "1/5/15": *-US locales read it month-first (5 Jan 2015).
+TEST_F(ParserTest, ParseUs) {
+ EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1420412400000, GRANULARITY_DAY,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich",
+ /*locales=*/"en-US"));
+ EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1420412400000, GRANULARITY_DAY,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich",
+ /*locales=*/"es-US"));
+}
+
+// Unknown locale "xx" still parses via the model's default locales.
+TEST_F(ParserTest, ParseUnknownLanguage) {
+ EXPECT_TRUE(ParsesCorrectly("bylo to {31. 12. 2015} v 6 hodin", 1451516400000,
+ GRANULARITY_DAY,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"xx"));
+}
+
+// Fixture for locale-expansion tests: builds a tiny in-memory DatetimeModel
+// whose rule regexes are locale-name literals, so a match proves which
+// locale's rules were selected.
+class ParserLocaleTest : public testing::Test {
+ public:
+ void SetUp() override;
+ bool HasResult(const std::string& input, const std::string& locales);
+
+ protected:
+ UniLib unilib_;
+ CalendarLib calendarlib_;
+ // Must outlive parser_: it owns the flatbuffer the model points into.
+ flatbuffers::FlatBufferBuilder builder_;
+ std::unique_ptr<DatetimeParser> parser_;
+};
+
+// Appends one single-regex pattern bound to the given locale id.
+void AddPattern(const std::string& regex, int locale,
+ std::vector<std::unique_ptr<DatetimeModelPatternT>>* patterns) {
+ patterns->emplace_back(new DatetimeModelPatternT);
+ patterns->back()->regexes.emplace_back(new DatetimeModelPattern_::RegexT);
+ patterns->back()->regexes.back()->pattern = regex;
+ patterns->back()->regexes.back()->groups.push_back(
+ DatetimeGroupType_GROUP_UNUSED);
+ patterns->back()->locales.push_back(locale);
+}
+
+// Builds the test model: locale ids 0-6 map to the names below, each with a
+// rule whose regex is a distinguishing literal; id 6 ("default") is the
+// model's default locale.
+void ParserLocaleTest::SetUp() {
+ DatetimeModelT model;
+ model.use_extractors_for_locating = false;
+ model.locales.clear();
+ model.locales.push_back("en-US");
+ model.locales.push_back("en-CH");
+ model.locales.push_back("zh-Hant");
+ model.locales.push_back("en-*");
+ model.locales.push_back("zh-Hant-*");
+ model.locales.push_back("*-CH");
+ model.locales.push_back("default");
+ model.default_locales.push_back(6);
+
+ AddPattern(/*regex=*/"en-US", /*locale=*/0, &model.patterns);
+ AddPattern(/*regex=*/"en-CH", /*locale=*/1, &model.patterns);
+ AddPattern(/*regex=*/"zh-Hant", /*locale=*/2, &model.patterns);
+ AddPattern(/*regex=*/"en-all", /*locale=*/3, &model.patterns);
+ AddPattern(/*regex=*/"zh-Hant-all", /*locale=*/4, &model.patterns);
+ AddPattern(/*regex=*/"all-CH", /*locale=*/5, &model.patterns);
+ AddPattern(/*regex=*/"default", /*locale=*/6, &model.patterns);
+
+ // Serialize and re-read the model as a flatbuffer, as production does.
+ builder_.Finish(DatetimeModel::Pack(builder_, &model));
+ const DatetimeModel* model_fb =
+ flatbuffers::GetRoot<DatetimeModel>(builder_.GetBufferPointer());
+ ASSERT_TRUE(model_fb);
+
+ parser_ = DatetimeParser::Instance(model_fb, unilib_, calendarlib_,
+ /*decompressor=*/nullptr);
+ ASSERT_TRUE(parser_);
+}
+
+// True iff parsing 'input' under 'locales' yields exactly one result, i.e.
+// exactly one locale's rule matched its literal.
+bool ParserLocaleTest::HasResult(const std::string& input,
+ const std::string& locales) {
+ std::vector<DatetimeParseResultSpan> results;
+ EXPECT_TRUE(parser_->Parse(input, /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"", locales,
+ ModeFlag_ANNOTATION, false, &results));
+ return results.size() == 1;
+}
+
+// Exact-locale rules fire only for their own locale; default always applies.
+TEST_F(ParserLocaleTest, English) {
+ EXPECT_TRUE(HasResult("en-US", /*locales=*/"en-US"));
+ EXPECT_FALSE(HasResult("en-CH", /*locales=*/"en-US"));
+ EXPECT_FALSE(HasResult("en-US", /*locales=*/"en-CH"));
+ EXPECT_TRUE(HasResult("en-CH", /*locales=*/"en-CH"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"en-CH"));
+}
+
+// language-script-* wildcard matches any region but not other scripts.
+TEST_F(ParserLocaleTest, TraditionalChinese) {
+ EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant"));
+ EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant-TW"));
+ EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant-SG"));
+ EXPECT_FALSE(HasResult("zh-Hant-all", /*locales=*/"zh-SG"));
+ EXPECT_FALSE(HasResult("zh-Hant-all", /*locales=*/"zh"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"zh"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"zh-Hant-SG"));
+}
+
+// *-region and language-* wildcards expand independently of each other.
+TEST_F(ParserLocaleTest, SwissEnglish) {
+ EXPECT_TRUE(HasResult("all-CH", /*locales=*/"de-CH"));
+ EXPECT_TRUE(HasResult("all-CH", /*locales=*/"en-CH"));
+ EXPECT_TRUE(HasResult("en-all", /*locales=*/"en-CH"));
+ EXPECT_FALSE(HasResult("all-CH", /*locales=*/"de-DE"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"de-CH"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"en-CH"));
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/annotator/feature-processor.cc b/annotator/feature-processor.cc
new file mode 100644
index 0000000..a18393b
--- /dev/null
+++ b/annotator/feature-processor.cc
@@ -0,0 +1,988 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/feature-processor.h"
+
+#include <iterator>
+#include <set>
+#include <vector>
+
+#include "utils/base/logging.h"
+#include "utils/strings/utf8.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3 {
+
+namespace internal {
+
+// Copies the fields of the flatbuffer-backed FeatureProcessorOptions into a
+// plain TokenFeatureExtractorOptions struct, skipping vector fields that are
+// absent (null) in the flatbuffer.
+TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
+    const FeatureProcessorOptions* const options) {
+  TokenFeatureExtractorOptions extractor_options;
+
+  extractor_options.num_buckets = options->num_buckets();
+  if (options->chargram_orders() != nullptr) {
+    for (int order : *options->chargram_orders()) {
+      extractor_options.chargram_orders.push_back(order);
+    }
+  }
+  extractor_options.max_word_length = options->max_word_length();
+  extractor_options.extract_case_feature = options->extract_case_feature();
+  extractor_options.unicode_aware_features = options->unicode_aware_features();
+  extractor_options.extract_selection_mask_feature =
+      options->extract_selection_mask_feature();
+  if (options->regexp_feature() != nullptr) {
+    // NOTE(review): "regexp_feauture" is a typo kept as-is; renaming it is a
+    // pure-code change outside the scope of this pass.
+    for (const auto& regexp_feauture : *options->regexp_feature()) {
+      extractor_options.regexp_features.push_back(regexp_feauture->str());
+    }
+  }
+  extractor_options.remap_digits = options->remap_digits();
+  extractor_options.lowercase_tokens = options->lowercase_tokens();
+
+  if (options->allowed_chargrams() != nullptr) {
+    for (const auto& chargram : *options->allowed_chargrams()) {
+      extractor_options.allowed_chargrams.insert(chargram->str());
+    }
+  }
+  return extractor_options;
+}
+
+// Splits any token that straddles a selection boundary into up to three
+// pieces so that the boundary falls exactly between tokens. Token start/end
+// offsets are codepoint indices into the original text.
+void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
+                                      std::vector<Token>* tokens) {
+  for (auto it = tokens->begin(); it != tokens->end(); ++it) {
+    const UnicodeText token_word =
+        UTF8ToUnicodeText(it->value, /*do_copy=*/false);
+
+    auto last_start = token_word.begin();
+    int last_start_index = it->start;
+    std::vector<UnicodeText::const_iterator> split_points;
+
+    // Selection start split point.
+    if (selection.first > it->start && selection.first < it->end) {
+      std::advance(last_start, selection.first - last_start_index);
+      split_points.push_back(last_start);
+      last_start_index = selection.first;
+    }
+
+    // Selection end split point.
+    if (selection.second > it->start && selection.second < it->end) {
+      std::advance(last_start, selection.second - last_start_index);
+      split_points.push_back(last_start);
+    }
+
+    if (!split_points.empty()) {
+      // Add a final split for the rest of the token unless it's been all
+      // consumed already.
+      if (split_points.back() != token_word.end()) {
+        split_points.push_back(token_word.end());
+      }
+
+      std::vector<Token> replacement_tokens;
+      last_start = token_word.begin();
+      int current_pos = it->start;
+      for (const auto& split_point : split_points) {
+        Token new_token(token_word.UTF8Substring(last_start, split_point),
+                        current_pos,
+                        current_pos + std::distance(last_start, split_point));
+
+        last_start = split_point;
+        current_pos = new_token.end;
+
+        replacement_tokens.push_back(new_token);
+      }
+
+      // Swap the original token for its pieces; leave |it| on the last piece
+      // so the loop's ++it resumes with the following original token.
+      it = tokens->erase(it);
+      it = tokens->insert(it, replacement_tokens.begin(),
+                          replacement_tokens.end());
+      std::advance(it, replacement_tokens.size() - 1);
+    }
+  }
+}
+
+}  // namespace internal
+
+// UTF-8 convenience wrapper: converts |context| without copying and forwards
+// to the UnicodeText overload below.
+void FeatureProcessor::StripTokensFromOtherLines(
+    const std::string& context, CodepointSpan span,
+    std::vector<Token>* tokens) const {
+  const UnicodeText context_unicode = UTF8ToUnicodeText(context,
+                                                        /*do_copy=*/false);
+  StripTokensFromOtherLines(context_unicode, span, tokens);
+}
+
+// Removes from |tokens| every token that lies outside the line (as produced
+// by SplitContext) that fully contains |span|. If no single line contains the
+// span, |tokens| is left untouched.
+void FeatureProcessor::StripTokensFromOtherLines(
+    const UnicodeText& context_unicode, CodepointSpan span,
+    std::vector<Token>* tokens) const {
+  std::vector<UnicodeTextRange> lines = SplitContext(context_unicode);
+
+  auto span_start = context_unicode.begin();
+  if (span.first > 0) {
+    std::advance(span_start, span.first);
+  }
+  auto span_end = context_unicode.begin();
+  if (span.second > 0) {
+    std::advance(span_end, span.second);
+  }
+  for (const UnicodeTextRange& line : lines) {
+    // Find the line that completely contains the span.
+    if (line.first <= span_start && line.second >= span_end) {
+      const CodepointIndex last_line_begin_index =
+          std::distance(context_unicode.begin(), line.first);
+      const CodepointIndex last_line_end_index =
+          last_line_begin_index + std::distance(line.first, line.second);
+
+      // Keep only tokens fully inside [line begin, line end].
+      for (auto token = tokens->begin(); token != tokens->end();) {
+        if (token->start >= last_line_begin_index &&
+            token->end <= last_line_end_index) {
+          ++token;
+        } else {
+          token = tokens->erase(token);
+        }
+      }
+    }
+  }
+}
+
+// Returns the collection name at options->default_collection(), or an empty
+// string (with an error log) when the index or the collections vector is
+// invalid.
+std::string FeatureProcessor::GetDefaultCollection() const {
+  if (options_->default_collection() < 0 ||
+      options_->collections() == nullptr ||
+      options_->default_collection() >= options_->collections()->size()) {
+    TC3_LOG(ERROR)
+        << "Invalid or missing default collection. Returning empty string.";
+    return "";
+  }
+  return (*options_->collections())[options_->default_collection()]->str();
+}
+
+// UTF-8 convenience wrapper over the UnicodeText overload below.
+std::vector<Token> FeatureProcessor::Tokenize(const std::string& text) const {
+  const UnicodeText text_unicode = UTF8ToUnicodeText(text, /*do_copy=*/false);
+  return Tokenize(text_unicode);
+}
+
+// Tokenizes |text_unicode| per options->tokenization_type():
+//  - INTERNAL_TOKENIZER: the built-in tokenizer_;
+//  - ICU: ICU break iteration (empty result on failure);
+//  - MIXED: ICU first, then InternalRetokenize over configured codepoints.
+// Unknown types fall back to the internal tokenizer with an error log.
+std::vector<Token> FeatureProcessor::Tokenize(
+    const UnicodeText& text_unicode) const {
+  if (options_->tokenization_type() ==
+      FeatureProcessorOptions_::TokenizationType_INTERNAL_TOKENIZER) {
+    return tokenizer_.Tokenize(text_unicode);
+  } else if (options_->tokenization_type() ==
+                 FeatureProcessorOptions_::TokenizationType_ICU ||
+             options_->tokenization_type() ==
+                 FeatureProcessorOptions_::TokenizationType_MIXED) {
+    std::vector<Token> result;
+    if (!ICUTokenize(text_unicode, &result)) {
+      return {};
+    }
+    if (options_->tokenization_type() ==
+        FeatureProcessorOptions_::TokenizationType_MIXED) {
+      InternalRetokenize(text_unicode, &result);
+    }
+    return result;
+  } else {
+    TC3_LOG(ERROR) << "Unknown tokenization type specified. Using "
+                      "internal.";
+    return tokenizer_.Tokenize(text_unicode);
+  }
+}
+
+// Converts a selection |label| into a codepoint span over |tokens| (which
+// must be exactly the model's context window, with the click token in the
+// middle). Ignored boundary codepoints are stripped from both ends of the
+// resulting span. Returns false when |tokens| has the wrong size or the
+// label is out of range.
+bool FeatureProcessor::LabelToSpan(
+    const int label, const VectorSpan<Token>& tokens,
+    std::pair<CodepointIndex, CodepointIndex>* span) const {
+  if (tokens.size() != GetNumContextTokens()) {
+    return false;
+  }
+
+  TokenSpan token_span;
+  if (!LabelToTokenSpan(label, &token_span)) {
+    return false;
+  }
+
+  // token_span counts tokens to the left/right of the click, which sits at
+  // index context_size() in |tokens|.
+  const int result_begin_token_index = token_span.first;
+  const Token& result_begin_token =
+      tokens[options_->context_size() - result_begin_token_index];
+  const int result_begin_codepoint = result_begin_token.start;
+  const int result_end_token_index = token_span.second;
+  const Token& result_end_token =
+      tokens[options_->context_size() + result_end_token_index];
+  const int result_end_codepoint = result_end_token.end;
+
+  if (result_begin_codepoint == kInvalidIndex ||
+      result_end_codepoint == kInvalidIndex) {
+    *span = CodepointSpan({kInvalidIndex, kInvalidIndex});
+  } else {
+    const UnicodeText token_begin_unicode =
+        UTF8ToUnicodeText(result_begin_token.value, /*do_copy=*/false);
+    UnicodeText::const_iterator token_begin = token_begin_unicode.begin();
+    const UnicodeText token_end_unicode =
+        UTF8ToUnicodeText(result_end_token.value, /*do_copy=*/false);
+    UnicodeText::const_iterator token_end = token_end_unicode.end();
+
+    const int begin_ignored = CountIgnoredSpanBoundaryCodepoints(
+        token_begin, token_begin_unicode.end(),
+        /*count_from_beginning=*/true);
+    const int end_ignored =
+        CountIgnoredSpanBoundaryCodepoints(token_end_unicode.begin(), token_end,
+                                           /*count_from_beginning=*/false);
+    // In case everything would be stripped, set the span to the original
+    // beginning and zero length.
+    if (begin_ignored == (result_end_codepoint - result_begin_codepoint)) {
+      *span = {result_begin_codepoint, result_begin_codepoint};
+    } else {
+      *span = CodepointSpan({result_begin_codepoint + begin_ignored,
+                             result_end_codepoint - end_ignored});
+    }
+  }
+  return true;
+}
+
+// Looks up the (left, right) token-span encoding of |label| in the
+// label_to_selection_ table built by MakeLabelMaps. Returns false when the
+// label is out of range.
+bool FeatureProcessor::LabelToTokenSpan(const int label,
+                                        TokenSpan* token_span) const {
+  if (label >= 0 && label < label_to_selection_.size()) {
+    *token_span = label_to_selection_[label];
+    return true;
+  } else {
+    return false;
+  }
+}
+
+// Inverse of LabelToSpan: maps a codepoint |span| over the context window
+// |tokens| to a selection label. Grows left/right from the click token (the
+// middle of the window) while tokens overlap the span, then checks that the
+// covered tokens actually match the span boundaries (exactly, or up to
+// ignored boundary codepoints / containing-token snapping, per options).
+// Writes kInvalidLabel when they do not; returns false only when |tokens|
+// is not the expected context size.
+bool FeatureProcessor::SpanToLabel(
+    const std::pair<CodepointIndex, CodepointIndex>& span,
+    const std::vector<Token>& tokens, int* label) const {
+  if (tokens.size() != GetNumContextTokens()) {
+    return false;
+  }
+
+  const int click_position =
+      options_->context_size();  // Click is always in the middle.
+  const int padding = options_->context_size() - options_->max_selection_span();
+
+  // Number of tokens to the left of the click that overlap the span.
+  int span_left = 0;
+  for (int i = click_position - 1; i >= padding; i--) {
+    if (tokens[i].start != kInvalidIndex && tokens[i].end > span.first) {
+      ++span_left;
+    } else {
+      break;
+    }
+  }
+
+  // Number of tokens to the right of the click that overlap the span.
+  int span_right = 0;
+  for (int i = click_position + 1; i < tokens.size() - padding; ++i) {
+    if (tokens[i].end != kInvalidIndex && tokens[i].start < span.second) {
+      ++span_right;
+    } else {
+      break;
+    }
+  }
+
+  // Check that the spanned tokens cover the whole span.
+  bool tokens_match_span;
+  const CodepointIndex tokens_start = tokens[click_position - span_left].start;
+  const CodepointIndex tokens_end = tokens[click_position + span_right].end;
+  if (options_->snap_label_span_boundaries_to_containing_tokens()) {
+    tokens_match_span = tokens_start <= span.first && tokens_end >= span.second;
+  } else {
+    const UnicodeText token_left_unicode = UTF8ToUnicodeText(
+        tokens[click_position - span_left].value, /*do_copy=*/false);
+    const UnicodeText token_right_unicode = UTF8ToUnicodeText(
+        tokens[click_position + span_right].value, /*do_copy=*/false);
+
+    UnicodeText::const_iterator span_begin = token_left_unicode.begin();
+    UnicodeText::const_iterator span_end = token_right_unicode.end();
+
+    const int num_punctuation_start = CountIgnoredSpanBoundaryCodepoints(
+        span_begin, token_left_unicode.end(), /*count_from_beginning=*/true);
+    const int num_punctuation_end = CountIgnoredSpanBoundaryCodepoints(
+        token_right_unicode.begin(), span_end,
+        /*count_from_beginning=*/false);
+
+    // The span may begin/end inside the boundary tokens only by skipping
+    // ignored (punctuation-like) codepoints.
+    tokens_match_span = tokens_start <= span.first &&
+                        tokens_start + num_punctuation_start >= span.first &&
+                        tokens_end >= span.second &&
+                        tokens_end - num_punctuation_end <= span.second;
+  }
+
+  if (tokens_match_span) {
+    *label = TokenSpanToLabel({span_left, span_right});
+  } else {
+    *label = kInvalidLabel;
+  }
+
+  return true;
+}
+
+// Looks up the label for a (left, right) token span; kInvalidLabel when the
+// span is not in the selection_to_label_ table.
+int FeatureProcessor::TokenSpanToLabel(const TokenSpan& span) const {
+  auto it = selection_to_label_.find(span);
+  if (it != selection_to_label_.end()) {
+    return it->second;
+  } else {
+    return kInvalidLabel;
+  }
+}
+
+// Maps a codepoint span to the [start_token, end_token) range of non-padding
+// tokens it covers. With snapping enabled any overlapping token counts;
+// otherwise only tokens fully contained in the span count. Returns
+// {kInvalidIndex, kInvalidIndex} when no token qualifies.
+TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
+                                   CodepointSpan codepoint_span,
+                                   bool snap_boundaries_to_containing_tokens) {
+  const int codepoint_start = std::get<0>(codepoint_span);
+  const int codepoint_end = std::get<1>(codepoint_span);
+
+  TokenIndex start_token = kInvalidIndex;
+  TokenIndex end_token = kInvalidIndex;
+  for (int i = 0; i < selectable_tokens.size(); ++i) {
+    bool is_token_in_span;
+    if (snap_boundaries_to_containing_tokens) {
+      is_token_in_span = codepoint_start < selectable_tokens[i].end &&
+                         codepoint_end > selectable_tokens[i].start;
+    } else {
+      is_token_in_span = codepoint_start <= selectable_tokens[i].start &&
+                         codepoint_end >= selectable_tokens[i].end;
+    }
+    if (is_token_in_span && !selectable_tokens[i].is_padding) {
+      if (start_token == kInvalidIndex) {
+        start_token = i;
+      }
+      end_token = i + 1;
+    }
+  }
+  return {start_token, end_token};
+}
+
+// Maps a [first, second) token span back to the codepoint span from the
+// first token's start to the last token's end. Assumes a non-empty,
+// in-bounds token span.
+CodepointSpan TokenSpanToCodepointSpan(
+    const std::vector<Token>& selectable_tokens, TokenSpan token_span) {
+  return {selectable_tokens[token_span.first].start,
+          selectable_tokens[token_span.second - 1].end};
+}
+
+namespace {
+
+// Finds a single token that completely contains the given span.
+// Returns the index of the first such token, or kInvalidIndex if none.
+int FindTokenThatContainsSpan(const std::vector<Token>& selectable_tokens,
+                              CodepointSpan codepoint_span) {
+  const int codepoint_start = std::get<0>(codepoint_span);
+  const int codepoint_end = std::get<1>(codepoint_span);
+
+  for (int i = 0; i < selectable_tokens.size(); ++i) {
+    if (codepoint_start >= selectable_tokens[i].start &&
+        codepoint_end <= selectable_tokens[i].end) {
+      return i;
+    }
+  }
+  return kInvalidIndex;
+}
+
+}  // namespace
+
+namespace internal {
+
+// Returns the index of the single selectable token covered by the click
+// |span|, or kInvalidIndex when the span does not map to exactly one token.
+int CenterTokenFromClick(CodepointSpan span,
+                         const std::vector<Token>& selectable_tokens) {
+  int range_begin;
+  int range_end;
+  std::tie(range_begin, range_end) =
+      CodepointSpanToTokenSpan(selectable_tokens, span);
+
+  // If no exact match was found, try finding a token that completely contains
+  // the click span. This is useful e.g. when Android builds the selection
+  // using ICU tokenization, and ends up with only a portion of our space-
+  // separated token. E.g. for "(857)" Android would select "857".
+  if (range_begin == kInvalidIndex || range_end == kInvalidIndex) {
+    int token_index = FindTokenThatContainsSpan(selectable_tokens, span);
+    if (token_index != kInvalidIndex) {
+      range_begin = token_index;
+      range_end = token_index + 1;
+    }
+  }
+
+  // We only allow clicks that are exactly 1 selectable token.
+  if (range_end - range_begin == 1) {
+    return range_begin;
+  } else {
+    return kInvalidIndex;
+  }
+}
+
+// Returns the index of the token in the middle of the token range covered by
+// |span|, or kInvalidIndex when the span maps to no tokens.
+int CenterTokenFromMiddleOfSelection(
+    CodepointSpan span, const std::vector<Token>& selectable_tokens) {
+  int range_begin;
+  int range_end;
+  std::tie(range_begin, range_end) =
+      CodepointSpanToTokenSpan(selectable_tokens, span);
+
+  // Center the clicked token in the selection range.
+  if (range_begin != kInvalidIndex && range_end != kInvalidIndex) {
+    return (range_begin + range_end - 1) / 2;
+  } else {
+    return kInvalidIndex;
+  }
+}
+
+}  // namespace internal
+
+// Picks the center token for |span| using the strategy selected in the model
+// options; the DEFAULT method chooses based on whether tokens are split on
+// selection boundaries (a proxy for selection vs. sharing models).
+int FeatureProcessor::FindCenterToken(CodepointSpan span,
+                                      const std::vector<Token>& tokens) const {
+  if (options_->center_token_selection_method() ==
+      FeatureProcessorOptions_::
+          CenterTokenSelectionMethod_CENTER_TOKEN_FROM_CLICK) {
+    return internal::CenterTokenFromClick(span, tokens);
+  } else if (options_->center_token_selection_method() ==
+             FeatureProcessorOptions_::
+                 CenterTokenSelectionMethod_CENTER_TOKEN_MIDDLE_OF_SELECTION) {
+    return internal::CenterTokenFromMiddleOfSelection(span, tokens);
+  } else if (options_->center_token_selection_method() ==
+             FeatureProcessorOptions_::
+                 CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD) {
+    // TODO(zilka): Remove once we have new models on the device.
+    // It uses the fact that sharing model use
+    // split_tokens_on_selection_boundaries and selection not. So depending on
+    // this we select the right way of finding the click location.
+    if (!options_->split_tokens_on_selection_boundaries()) {
+      // SmartSelection model.
+      return internal::CenterTokenFromClick(span, tokens);
+    } else {
+      // SmartSharing model.
+      return internal::CenterTokenFromMiddleOfSelection(span, tokens);
+    }
+  } else {
+    TC3_LOG(ERROR) << "Invalid center token selection method.";
+    return kInvalidIndex;
+  }
+}
+
+// Appends to |selection_label_spans| the codepoint span of every selection
+// label (index i => label i), using LabelToSpan over the context |tokens|.
+// Returns false (with an error log) on the first conversion failure.
+bool FeatureProcessor::SelectionLabelSpans(
+    const VectorSpan<Token> tokens,
+    std::vector<CodepointSpan>* selection_label_spans) const {
+  for (int i = 0; i < label_to_selection_.size(); ++i) {
+    CodepointSpan span;
+    if (!LabelToSpan(i, tokens, &span)) {
+      TC3_LOG(ERROR) << "Could not convert label to span: " << i;
+      return false;
+    }
+    selection_label_spans->push_back(span);
+  }
+  return true;
+}
+
+// Converts flatbuffer codepoint ranges into plain CodepointRange structs and
+// sorts them by start so IsCodepointInRanges can binary-search them.
+void FeatureProcessor::PrepareCodepointRanges(
+    const std::vector<const FeatureProcessorOptions_::CodepointRange*>&
+        codepoint_ranges,
+    std::vector<CodepointRange>* prepared_codepoint_ranges) {
+  prepared_codepoint_ranges->clear();
+  prepared_codepoint_ranges->reserve(codepoint_ranges.size());
+  for (const FeatureProcessorOptions_::CodepointRange* range :
+       codepoint_ranges) {
+    prepared_codepoint_ranges->push_back(
+        CodepointRange(range->start(), range->end()));
+  }
+
+  std::sort(prepared_codepoint_ranges->begin(),
+            prepared_codepoint_ranges->end(),
+            [](const CodepointRange& a, const CodepointRange& b) {
+              return a.start < b.start;
+            });
+}
+
+// Loads the model's ignored span-boundary codepoints into the member set
+// consulted by CountIgnoredSpanBoundaryCodepoints.
+void FeatureProcessor::PrepareIgnoredSpanBoundaryCodepoints() {
+  if (options_->ignored_span_boundary_codepoints() != nullptr) {
+    for (const int codepoint : *options_->ignored_span_boundary_codepoints()) {
+      ignored_span_boundary_codepoints_.insert(codepoint);
+    }
+  }
+}
+
+// Counts how many consecutive codepoints in [span_start, span_end) belong to
+// the ignored-boundary set, scanning forward from span_start when
+// |count_from_beginning| is true and backward from span_end otherwise.
+// Returns 0 for an empty range; the count can equal the range length when
+// every codepoint is ignored.
+int FeatureProcessor::CountIgnoredSpanBoundaryCodepoints(
+    const UnicodeText::const_iterator& span_start,
+    const UnicodeText::const_iterator& span_end,
+    bool count_from_beginning) const {
+  if (span_start == span_end) {
+    return 0;
+  }
+
+  // |it| is the scan cursor, |it_last| the last codepoint to examine
+  // (inclusive) in the chosen direction.
+  UnicodeText::const_iterator it;
+  UnicodeText::const_iterator it_last;
+  if (count_from_beginning) {
+    it = span_start;
+    it_last = span_end;
+    // We can assume that the string is non-zero length because of the check
+    // above, thus the decrement is always valid here.
+    --it_last;
+  } else {
+    it = span_end;
+    it_last = span_start;
+    // We can assume that the string is non-zero length because of the check
+    // above, thus the decrement is always valid here.
+    --it;
+  }
+
+  // Move until we encounter a non-ignored character.
+  int num_ignored = 0;
+  while (ignored_span_boundary_codepoints_.find(*it) !=
+         ignored_span_boundary_codepoints_.end()) {
+    ++num_ignored;
+
+    if (it == it_last) {
+      break;
+    }
+
+    if (count_from_beginning) {
+      ++it;
+    } else {
+      --it;
+    }
+  }
+
+  return num_ignored;
+}
+
+namespace {
+
+// Splits |t| into maximal ranges that contain none of |codepoints|; the
+// separator codepoints themselves are not part of any range, and empty
+// ranges are never emitted.
+void FindSubstrings(const UnicodeText& t, const std::set<char32>& codepoints,
+                    std::vector<UnicodeTextRange>* ranges) {
+  UnicodeText::const_iterator start = t.begin();
+  UnicodeText::const_iterator curr = start;
+  UnicodeText::const_iterator end = t.end();
+  for (; curr != end; ++curr) {
+    if (codepoints.find(*curr) != codepoints.end()) {
+      if (start != curr) {
+        ranges->push_back(std::make_pair(start, curr));
+      }
+      start = curr;
+      ++start;
+    }
+  }
+  if (start != end) {
+    ranges->push_back(std::make_pair(start, end));
+  }
+}
+
+}  // namespace
+
+// Splits the context into "lines", treating both '\n' and '|' as line
+// separators.
+std::vector<UnicodeTextRange> FeatureProcessor::SplitContext(
+    const UnicodeText& context_unicode) const {
+  std::vector<UnicodeTextRange> lines;
+  const std::set<char32> codepoints{{'\n', '|'}};
+  FindSubstrings(context_unicode, codepoints, &lines);
+  return lines;
+}
+
+// UTF-8 convenience wrapper over the UnicodeText overload below.
+CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
+    const std::string& context, CodepointSpan span) const {
+  const UnicodeText context_unicode =
+      UTF8ToUnicodeText(context, /*do_copy=*/false);
+  return StripBoundaryCodepoints(context_unicode, span);
+}
+
+// Shrinks |span| by dropping ignored boundary codepoints from both ends.
+// Invalid/empty spans are returned unchanged; if stripping would make the
+// span empty or inverted, a zero-length span at the original start is
+// returned instead.
+CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
+    const UnicodeText& context_unicode, CodepointSpan span) const {
+  if (context_unicode.empty() || !ValidNonEmptySpan(span)) {
+    return span;
+  }
+
+  UnicodeText::const_iterator span_begin = context_unicode.begin();
+  std::advance(span_begin, span.first);
+  UnicodeText::const_iterator span_end = context_unicode.begin();
+  std::advance(span_end, span.second);
+
+  const int start_offset = CountIgnoredSpanBoundaryCodepoints(
+      span_begin, span_end, /*count_from_beginning=*/true);
+  const int end_offset = CountIgnoredSpanBoundaryCodepoints(
+      span_begin, span_end, /*count_from_beginning=*/false);
+
+  if (span.first + start_offset < span.second - end_offset) {
+    return {span.first + start_offset, span.second - end_offset};
+  } else {
+    return {span.first, span.first};
+  }
+}
+
+// Returns the fraction of codepoints (over all tokens in |token_span|) that
+// fall inside supported_codepoint_ranges_.
+// NOTE(review): if the span contains no codepoints, num_total is 0 and the
+// result is 0/0 == NaN; NaN < threshold is false, so the span would be
+// treated as supported -- confirm this is the intended behavior.
+float FeatureProcessor::SupportedCodepointsRatio(
+    const TokenSpan& token_span, const std::vector<Token>& tokens) const {
+  int num_supported = 0;
+  int num_total = 0;
+  for (int i = token_span.first; i < token_span.second; ++i) {
+    const UnicodeText value =
+        UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false);
+    for (auto codepoint : value) {
+      if (IsCodepointInRanges(codepoint, supported_codepoint_ranges_)) {
+        ++num_supported;
+      }
+      ++num_total;
+    }
+  }
+  return static_cast<float>(num_supported) / static_cast<float>(num_total);
+}
+
+// Binary-searches the (start-sorted, as built by PrepareCodepointRanges)
+// half-open ranges [start, end) for |codepoint|.
+bool FeatureProcessor::IsCodepointInRanges(
+    int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const {
+  auto it = std::lower_bound(codepoint_ranges.begin(), codepoint_ranges.end(),
+                             codepoint,
+                             [](const CodepointRange& range, int codepoint) {
+                               // This function compares range with the
+                               // codepoint for the purpose of finding the first
+                               // greater or equal range. Because of the use of
+                               // std::lower_bound it needs to return true when
+                               // range < codepoint; the first time it will
+                               // return false the lower bound is found and
+                               // returned.
+                               //
+                               // It might seem weird that the condition is
+                               // range.end <= codepoint here but when codepoint
+                               // == range.end it means it's actually just
+                               // outside of the range, thus the range is less
+                               // than the codepoint.
+                               return range.end <= codepoint;
+                             });
+  if (it != codepoint_ranges.end() && it->start <= codepoint &&
+      it->end > codepoint) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
+// Maps a collection name to its label index, falling back to the default
+// collection index for unknown names.
+int FeatureProcessor::CollectionToLabel(const std::string& collection) const {
+  const auto it = collection_to_label_.find(collection);
+  if (it == collection_to_label_.end()) {
+    return options_->default_collection();
+  } else {
+    return it->second;
+  }
+}
+
+// Maps a label index back to its collection name, or the default collection
+// for out-of-range labels.
+// NOTE(review): the bound uses collection_to_label_.size() while the lookup
+// indexes options_->collections(); MakeLabelMaps keeps the two the same size,
+// but the mismatch is worth confirming.
+std::string FeatureProcessor::LabelToCollection(int label) const {
+  if (label >= 0 && label < collection_to_label_.size()) {
+    return (*options_->collections())[label]->str();
+  } else {
+    return GetDefaultCollection();
+  }
+}
+
+// Builds the collection<->label map and the selection (left, right) span <->
+// label tables. Selection labels enumerate all (l, r) pairs with
+// 0 <= l, r <= max_selection_span; with the reduced output space only pairs
+// with l + r <= max_selection_span get a label.
+void FeatureProcessor::MakeLabelMaps() {
+  if (options_->collections() != nullptr) {
+    for (int i = 0; i < options_->collections()->size(); ++i) {
+      collection_to_label_[(*options_->collections())[i]->str()] = i;
+    }
+  }
+
+  int selection_label_id = 0;
+  for (int l = 0; l < (options_->max_selection_span() + 1); ++l) {
+    for (int r = 0; r < (options_->max_selection_span() + 1); ++r) {
+      if (!options_->selection_reduced_output_space() ||
+          r + l <= options_->max_selection_span()) {
+        TokenSpan token_span{l, r};
+        selection_to_label_[token_span] = selection_label_id;
+        label_to_selection_.push_back(token_span);
+        ++selection_label_id;
+      }
+    }
+  }
+}
+
+// UTF-8 convenience wrapper over the UnicodeText overload below.
+void FeatureProcessor::RetokenizeAndFindClick(const std::string& context,
+                                              CodepointSpan input_span,
+                                              bool only_use_line_with_click,
+                                              std::vector<Token>* tokens,
+                                              int* click_pos) const {
+  const UnicodeText context_unicode =
+      UTF8ToUnicodeText(context, /*do_copy=*/false);
+  RetokenizeAndFindClick(context_unicode, input_span, only_use_line_with_click,
+                         tokens, click_pos);
+}
+
+// Post-processes |tokens| for |input_span|: optionally splits tokens on the
+// selection boundaries and restricts tokens to the clicked line, then finds
+// the click position. |click_pos| may be null, in which case the position is
+// computed but discarded.
+void FeatureProcessor::RetokenizeAndFindClick(
+    const UnicodeText& context_unicode, CodepointSpan input_span,
+    bool only_use_line_with_click, std::vector<Token>* tokens,
+    int* click_pos) const {
+  TC3_CHECK(tokens != nullptr);
+
+  if (options_->split_tokens_on_selection_boundaries()) {
+    internal::SplitTokensOnSelectionBoundaries(input_span, tokens);
+  }
+
+  if (only_use_line_with_click) {
+    StripTokensFromOtherLines(context_unicode, input_span, tokens);
+  }
+
+  int local_click_pos;
+  if (click_pos == nullptr) {
+    click_pos = &local_click_pos;
+  }
+  *click_pos = FindCenterToken(input_span, *tokens);
+  if (*click_pos == kInvalidIndex) {
+    // If the default click method failed, let's try to do sub-token matching
+    // before we fail.
+    *click_pos = internal::CenterTokenFromClick(input_span, *tokens);
+  }
+}
+
+namespace internal {
+
+// Trims |tokens| to the window needed around |*click_pos| (the relative
+// selection span plus |context_size| on each side), padding with empty
+// tokens when the window extends past either end. |*click_pos| is adjusted
+// to stay on the same token. Padding is capped at |context_size| per side.
+void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
+                      std::vector<Token>* tokens, int* click_pos) {
+  // Right side: pad or strip after click + needed right context.
+  int right_context_needed = relative_click_span.second + context_size;
+  if (*click_pos + right_context_needed + 1 >= tokens->size()) {
+    // Pad max the context size.
+    const int num_pad_tokens = std::min(
+        context_size, static_cast<int>(*click_pos + right_context_needed + 1 -
+                                       tokens->size()));
+    std::vector<Token> pad_tokens(num_pad_tokens);
+    tokens->insert(tokens->end(), pad_tokens.begin(), pad_tokens.end());
+  } else if (*click_pos + right_context_needed + 1 < tokens->size() - 1) {
+    // Strip unused tokens.
+    auto it = tokens->begin();
+    std::advance(it, *click_pos + right_context_needed + 1);
+    tokens->erase(it, tokens->end());
+  }
+
+  // Left side: pad or strip before click - needed left context.
+  int left_context_needed = relative_click_span.first + context_size;
+  if (*click_pos < left_context_needed) {
+    // Pad max the context size.
+    const int num_pad_tokens =
+        std::min(context_size, left_context_needed - *click_pos);
+    std::vector<Token> pad_tokens(num_pad_tokens);
+    tokens->insert(tokens->begin(), pad_tokens.begin(), pad_tokens.end());
+    *click_pos += num_pad_tokens;
+  } else if (*click_pos > left_context_needed) {
+    // Strip unused tokens.
+    auto it = tokens->begin();
+    std::advance(it, *click_pos - left_context_needed);
+    *click_pos -= it - tokens->begin();
+    tokens->erase(tokens->begin(), it);
+  }
+}
+
+}  // namespace internal
+
+// Returns false when a minimum supported-codepoint ratio is configured and
+// the ratio over |token_span| falls below it; always true when the option
+// is <= 0.
+bool FeatureProcessor::HasEnoughSupportedCodepoints(
+    const std::vector<Token>& tokens, TokenSpan token_span) const {
+  if (options_->min_supported_codepoint_ratio() > 0) {
+    const float supported_codepoint_ratio =
+        SupportedCodepointsRatio(token_span, tokens);
+    if (supported_codepoint_ratio < options_->min_supported_codepoint_ratio()) {
+      TC3_VLOG(1) << "Not enough supported codepoints in the context: "
+                  << supported_codepoint_ratio;
+      return false;
+    }
+  }
+  return true;
+}
+
+// Extracts per-token feature vectors for every token in |token_span| plus
+// one padding-token vector, then wraps them in a CachedFeatures object.
+// Returns false (with an error log) on any extraction or creation failure.
+bool FeatureProcessor::ExtractFeatures(
+    const std::vector<Token>& tokens, TokenSpan token_span,
+    CodepointSpan selection_span_for_feature,
+    const EmbeddingExecutor* embedding_executor,
+    EmbeddingCache* embedding_cache, int feature_vector_size,
+    std::unique_ptr<CachedFeatures>* cached_features) const {
+  std::unique_ptr<std::vector<float>> features(new std::vector<float>());
+  features->reserve(feature_vector_size * TokenSpanSize(token_span));
+  for (int i = token_span.first; i < token_span.second; ++i) {
+    if (!AppendTokenFeaturesWithCache(tokens[i], selection_span_for_feature,
+                                      embedding_executor, embedding_cache,
+                                      features.get())) {
+      TC3_LOG(ERROR) << "Could not get token features.";
+      return false;
+    }
+  }
+
+  // A default-constructed Token() acts as the padding token.
+  std::unique_ptr<std::vector<float>> padding_features(
+      new std::vector<float>());
+  padding_features->reserve(feature_vector_size);
+  if (!AppendTokenFeaturesWithCache(Token(), selection_span_for_feature,
+                                    embedding_executor, embedding_cache,
+                                    padding_features.get())) {
+    TC3_LOG(ERROR) << "Count not get padding token features.";
+    return false;
+  }
+
+  *cached_features = CachedFeatures::Create(token_span, std::move(features),
+                                            std::move(padding_features),
+                                            options_, feature_vector_size);
+  if (!*cached_features) {
+    TC3_LOG(ERROR) << "Cound not create cached features.";
+    return false;
+  }
+
+  return true;
+}
+
+// Tokenizes |context_unicode| with the UniLib break iterator. Token offsets
+// are codepoint indices. Whitespace-only tokens are dropped unless
+// icu_preserve_whitespace_tokens() is set. Returns false when the break
+// iterator cannot be created.
+bool FeatureProcessor::ICUTokenize(const UnicodeText& context_unicode,
+                                   std::vector<Token>* result) const {
+  std::unique_ptr<UniLib::BreakIterator> break_iterator =
+      unilib_->CreateBreakIterator(context_unicode);
+  if (!break_iterator) {
+    return false;
+  }
+  int last_break_index = 0;
+  int break_index = 0;
+  int last_unicode_index = 0;
+  int unicode_index = 0;
+  auto token_begin_it = context_unicode.begin();
+  while ((break_index = break_iterator->Next()) !=
+         UniLib::BreakIterator::kDone) {
+    const int token_length = break_index - last_break_index;
+    unicode_index = last_unicode_index + token_length;
+
+    auto token_end_it = token_begin_it;
+    std::advance(token_end_it, token_length);
+
+    // Determine if the whole token is whitespace.
+    bool is_whitespace = true;
+    for (auto char_it = token_begin_it; char_it < token_end_it; ++char_it) {
+      if (!unilib_->IsWhitespace(*char_it)) {
+        is_whitespace = false;
+        break;
+      }
+    }
+
+    const std::string token =
+        context_unicode.UTF8Substring(token_begin_it, token_end_it);
+
+    if (!is_whitespace || options_->icu_preserve_whitespace_tokens()) {
+      result->push_back(Token(token, last_unicode_index, unicode_index));
+    }
+
+    // Advance to the next break segment.
+    last_break_index = break_index;
+    last_unicode_index = unicode_index;
+    token_begin_it = token_end_it;
+  }
+
+  return true;
+}
+
+// Rewrites |tokens|: maximal runs of consecutive tokens made up entirely of
+// internal-tokenizer codepoints are merged into one span and re-tokenized
+// with the internal tokenizer; all other tokens are kept as-is.
+void FeatureProcessor::InternalRetokenize(const UnicodeText& unicode_text,
+                                          std::vector<Token>* tokens) const {
+  std::vector<Token> result;
+  // Accumulates the current run; first == -1 means no open run.
+  CodepointSpan span(-1, -1);
+  for (Token& token : *tokens) {
+    const UnicodeText unicode_token_value =
+        UTF8ToUnicodeText(token.value, /*do_copy=*/false);
+    bool should_retokenize = true;
+    for (const int codepoint : unicode_token_value) {
+      if (!IsCodepointInRanges(codepoint,
+                               internal_tokenizer_codepoint_ranges_)) {
+        should_retokenize = false;
+        break;
+      }
+    }
+
+    if (should_retokenize) {
+      if (span.first < 0) {
+        span.first = token.start;
+      }
+      span.second = token.end;
+    } else {
+      // Flush the pending run (no-op when span.first < 0), then keep the
+      // token unchanged.
+      TokenizeSubstring(unicode_text, span, &result);
+      span.first = -1;
+      result.emplace_back(std::move(token));
+    }
+  }
+  // Flush a run that extends to the end of the token list.
+  TokenizeSubstring(unicode_text, span, &result);
+
+  *tokens = std::move(result);
+}
+
+// Runs the internal tokenizer on the codepoint span [span.first, span.second)
+// of |unicode_text| and appends the resulting tokens (with offsets shifted
+// back into the full text) to |result|. A span with first < 0 is a no-op.
+void FeatureProcessor::TokenizeSubstring(const UnicodeText& unicode_text,
+                                         CodepointSpan span,
+                                         std::vector<Token>* result) const {
+  if (span.first < 0) {
+    // There is no span to tokenize.
+    return;
+  }
+
+  // Extract the substring.
+  UnicodeText::const_iterator it_begin = unicode_text.begin();
+  for (int i = 0; i < span.first; ++i) {
+    ++it_begin;
+  }
+  UnicodeText::const_iterator it_end = unicode_text.begin();
+  for (int i = 0; i < span.second; ++i) {
+    ++it_end;
+  }
+  const std::string text = unicode_text.UTF8Substring(it_begin, it_end);
+
+  // Run the tokenizer and update the token bounds to reflect the offset of the
+  // substring.
+  std::vector<Token> tokens = tokenizer_.Tokenize(text);
+  // Avoids progressive capacity increases in the for loop.
+  result->reserve(result->size() + tokens.size());
+  for (Token& token : tokens) {
+    token.start += span.first;
+    token.end += span.first;
+    result->emplace_back(std::move(token));
+  }
+}
+
+// Appends |token|'s feature vector (embedded sparse features followed by
+// dense features) to |output_features|. When |embedding_cache| is provided
+// it is consulted (keyed by the token's codepoint span) to skip re-embedding,
+// and freshly computed embeddings are stored back into it. Returns false on
+// any extraction or embedding failure.
+bool FeatureProcessor::AppendTokenFeaturesWithCache(
+    const Token& token, CodepointSpan selection_span_for_feature,
+    const EmbeddingExecutor* embedding_executor,
+    EmbeddingCache* embedding_cache,
+    std::vector<float>* output_features) const {
+  // Look for the embedded features for the token in the cache, if there is one.
+  if (embedding_cache) {
+    const auto it = embedding_cache->find({token.start, token.end});
+    if (it != embedding_cache->end()) {
+      // The embedded features were found in the cache, extract only the dense
+      // features.
+      std::vector<float> dense_features;
+      if (!feature_extractor_.Extract(
+              token, token.IsContainedInSpan(selection_span_for_feature),
+              /*sparse_features=*/nullptr, &dense_features)) {
+        TC3_LOG(ERROR) << "Could not extract token's dense features.";
+        return false;
+      }
+
+      // Append both embedded and dense features to the output and return.
+      output_features->insert(output_features->end(), it->second.begin(),
+                              it->second.end());
+      output_features->insert(output_features->end(), dense_features.begin(),
+                              dense_features.end());
+      return true;
+    }
+  }
+
+  // Extract the sparse and dense features.
+  std::vector<int> sparse_features;
+  std::vector<float> dense_features;
+  if (!feature_extractor_.Extract(
+          token, token.IsContainedInSpan(selection_span_for_feature),
+          &sparse_features, &dense_features)) {
+    TC3_LOG(ERROR) << "Could not extract token's features.";
+    return false;
+  }
+
+  // Embed the sparse features, appending them directly to the output.
+  // The vector is resized first so AddEmbedding can write in place into the
+  // last embedding_size slots.
+  const int embedding_size = GetOptions()->embedding_size();
+  output_features->resize(output_features->size() + embedding_size);
+  float* output_features_end =
+      output_features->data() + output_features->size();
+  if (!embedding_executor->AddEmbedding(
+          TensorView<int>(sparse_features.data(),
+                          {static_cast<int>(sparse_features.size())}),
+          /*dest=*/output_features_end - embedding_size,
+          /*dest_size=*/embedding_size)) {
+    TC3_LOG(ERROR) << "Cound not embed token's sparse features.";
+    return false;
+  }
+
+  // If there is a cache, the embedded features for the token were not in it,
+  // so insert them.
+  if (embedding_cache) {
+    (*embedding_cache)[{token.start, token.end}] = std::vector<float>(
+        output_features_end - embedding_size, output_features_end);
+  }
+
+  // Append the dense features to the output.
+  output_features->insert(output_features->end(), dense_features.begin(),
+                          dense_features.end());
+  return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/annotator/feature-processor.h b/annotator/feature-processor.h
new file mode 100644
index 0000000..2d04253
--- /dev/null
+++ b/annotator/feature-processor.h
@@ -0,0 +1,331 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Feature processing for FFModel (feed-forward SmartSelection model).
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
+
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "annotator/cached-features.h"
+#include "annotator/model_generated.h"
+#include "annotator/token-feature-extractor.h"
+#include "annotator/tokenizer.h"
+#include "annotator/types.h"
+#include "utils/base/integral_types.h"
+#include "utils/base/logging.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
// Sentinel label returned when a span cannot be mapped to a valid selection
// label (see SpanToLabel / CollectionToLabel).
constexpr int kInvalidLabel = -1;

namespace internal {

// Derives the token feature extractor options from the feature processor
// options.
TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
    const FeatureProcessorOptions* options);

// Splits tokens that contain the selection boundary inside them.
// E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
                                      std::vector<Token>* tokens);

// Returns the index of token that corresponds to the codepoint span.
int CenterTokenFromClick(CodepointSpan span, const std::vector<Token>& tokens);

// Returns the index of token that corresponds to the middle of the codepoint
// span.
int CenterTokenFromMiddleOfSelection(
    CodepointSpan span, const std::vector<Token>& selectable_tokens);

// Strips the tokens from the tokens vector that are not used for feature
// extraction because they are out of scope, or pads them so that there is
// enough tokens in the required context_size for all inferences with a click
// in relative_click_span.
void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
                      std::vector<Token>* tokens, int* click_pos);

}  // namespace internal

// Converts a codepoint span to a token span in the given list of tokens.
// If snap_boundaries_to_containing_tokens is set to true, it is enough for a
// token to overlap with the codepoint range to be considered part of it.
// Otherwise it must be fully included in the range.
TokenSpan CodepointSpanToTokenSpan(
    const std::vector<Token>& selectable_tokens, CodepointSpan codepoint_span,
    bool snap_boundaries_to_containing_tokens = false);

// Converts a token span to a codepoint span in the given list of tokens.
CodepointSpan TokenSpanToCodepointSpan(
    const std::vector<Token>& selectable_tokens, TokenSpan token_span);
+
// Takes care of preparing features for the span prediction model.
class FeatureProcessor {
 public:
  // A cache mapping codepoint spans to embedded tokens features. An instance
  // can be provided to multiple calls to ExtractFeatures() operating on the
  // same context (the same codepoint spans corresponding to the same tokens),
  // as an optimization. Note that the tokenizations do not have to be
  // identical.
  typedef std::map<CodepointSpan, std::vector<float>> EmbeddingCache;

  // Does not take ownership of |options| or |unilib|; both must outlive this
  // instance.
  FeatureProcessor(const FeatureProcessorOptions* options, const UniLib* unilib)
      : unilib_(unilib),
        feature_extractor_(internal::BuildTokenFeatureExtractorOptions(options),
                           *unilib_),
        options_(options),
        tokenizer_(
            options->tokenization_codepoint_config() != nullptr
                ? Tokenizer({options->tokenization_codepoint_config()->begin(),
                             options->tokenization_codepoint_config()->end()},
                            options->tokenize_on_script_change())
                : Tokenizer({}, /*split_on_script_change=*/false)) {
    MakeLabelMaps();
    if (options->supported_codepoint_ranges() != nullptr) {
      PrepareCodepointRanges({options->supported_codepoint_ranges()->begin(),
                              options->supported_codepoint_ranges()->end()},
                             &supported_codepoint_ranges_);
    }
    if (options->internal_tokenizer_codepoint_ranges() != nullptr) {
      PrepareCodepointRanges(
          {options->internal_tokenizer_codepoint_ranges()->begin(),
           options->internal_tokenizer_codepoint_ranges()->end()},
          &internal_tokenizer_codepoint_ranges_);
    }
    PrepareIgnoredSpanBoundaryCodepoints();
  }

  // Tokenizes the input string using the selected tokenization method.
  std::vector<Token> Tokenize(const std::string& text) const;

  // Same as above but takes UnicodeText.
  std::vector<Token> Tokenize(const UnicodeText& text_unicode) const;

  // Converts a label into a token span.
  bool LabelToTokenSpan(int label, TokenSpan* token_span) const;

  // Gets the total number of selection labels.
  int GetSelectionLabelCount() const { return label_to_selection_.size(); }

  // Gets the string value for given collection label.
  std::string LabelToCollection(int label) const;

  // Gets the total number of collections of the model.
  int NumCollections() const { return collection_to_label_.size(); }

  // Gets the name of the default collection.
  std::string GetDefaultCollection() const;

  const FeatureProcessorOptions* GetOptions() const { return options_; }

  // Retokenizes the context and input span, and finds the click position.
  // Depending on the options, might modify tokens (split them or remove them).
  void RetokenizeAndFindClick(const std::string& context,
                              CodepointSpan input_span,
                              bool only_use_line_with_click,
                              std::vector<Token>* tokens, int* click_pos) const;

  // Same as above but takes UnicodeText.
  void RetokenizeAndFindClick(const UnicodeText& context_unicode,
                              CodepointSpan input_span,
                              bool only_use_line_with_click,
                              std::vector<Token>* tokens, int* click_pos) const;

  // Returns true if the token span has enough supported codepoints (as defined
  // in the model config) or not and model should not run.
  bool HasEnoughSupportedCodepoints(const std::vector<Token>& tokens,
                                    TokenSpan token_span) const;

  // Extracts features as a CachedFeatures object that can be used for repeated
  // inference over token spans in the given context.
  bool ExtractFeatures(const std::vector<Token>& tokens, TokenSpan token_span,
                       CodepointSpan selection_span_for_feature,
                       const EmbeddingExecutor* embedding_executor,
                       EmbeddingCache* embedding_cache, int feature_vector_size,
                       std::unique_ptr<CachedFeatures>* cached_features) const;

  // Fills selection_label_spans with CodepointSpans that correspond to the
  // selection labels. The CodepointSpans are based on the codepoint ranges of
  // given tokens.
  bool SelectionLabelSpans(
      VectorSpan<Token> tokens,
      std::vector<CodepointSpan>* selection_label_spans) const;

  int DenseFeaturesCount() const {
    return feature_extractor_.DenseFeaturesCount();
  }

  int EmbeddingSize() const { return options_->embedding_size(); }

  // Splits context to several segments.
  std::vector<UnicodeTextRange> SplitContext(
      const UnicodeText& context_unicode) const;

  // Strips boundary codepoints from the span in context and returns the new
  // start and end indices. If the span comprises entirely of boundary
  // codepoints, the first index of span is returned for both indices.
  CodepointSpan StripBoundaryCodepoints(const std::string& context,
                                        CodepointSpan span) const;

  // Same as above but takes UnicodeText.
  CodepointSpan StripBoundaryCodepoints(const UnicodeText& context_unicode,
                                        CodepointSpan span) const;

 protected:
  // Represents a codepoint range [start, end).
  struct CodepointRange {
    int32 start;
    int32 end;

    CodepointRange(int32 arg_start, int32 arg_end)
        : start(arg_start), end(arg_end) {}
  };

  // Returns the class id corresponding to the given string collection
  // identifier. There is a catch-all class id that the function returns for
  // unknown collections.
  int CollectionToLabel(const std::string& collection) const;

  // Prepares mapping from collection names to labels.
  void MakeLabelMaps();

  // Gets the number of spannable tokens for the model.
  //
  // Spannable tokens are those tokens of context, which the model predicts
  // selection spans over (i.e., there is 1:1 correspondence between the output
  // classes of the model and each of the spannable tokens).
  int GetNumContextTokens() const { return options_->context_size() * 2 + 1; }

  // Converts a label into a span of codepoint indices corresponding to it
  // given output_tokens.
  bool LabelToSpan(int label, const VectorSpan<Token>& output_tokens,
                   CodepointSpan* span) const;

  // Converts a span to the corresponding label given output_tokens.
  bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span,
                   const std::vector<Token>& output_tokens, int* label) const;

  // Converts a token span to the corresponding label.
  int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;

  void PrepareCodepointRanges(
      const std::vector<const FeatureProcessorOptions_::CodepointRange*>&
          codepoint_ranges,
      std::vector<CodepointRange>* prepared_codepoint_ranges);

  // Returns the ratio of supported codepoints to total number of codepoints in
  // the given token span.
  float SupportedCodepointsRatio(const TokenSpan& token_span,
                                 const std::vector<Token>& tokens) const;

  // Returns true if given codepoint is covered by the given sorted vector of
  // codepoint ranges.
  bool IsCodepointInRanges(
      int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const;

  void PrepareIgnoredSpanBoundaryCodepoints();

  // Counts the number of span boundary codepoints. If count_from_beginning is
  // true, the counting will start at the span_start iterator (inclusive) and at
  // maximum end at span_end (exclusive). If count_from_beginning is false, the
  // counting will start from span_end (exclusive) and end at span_start
  // (inclusive).
  int CountIgnoredSpanBoundaryCodepoints(
      const UnicodeText::const_iterator& span_start,
      const UnicodeText::const_iterator& span_end,
      bool count_from_beginning) const;

  // Finds the center token index in tokens vector, using the method defined
  // in options_.
  int FindCenterToken(CodepointSpan span,
                      const std::vector<Token>& tokens) const;

  // Tokenizes the input text using ICU tokenizer.
  bool ICUTokenize(const UnicodeText& context_unicode,
                   std::vector<Token>* result) const;

  // Takes the result of ICU tokenization and retokenizes stretches of tokens
  // made of a specific subset of characters using the internal tokenizer.
  void InternalRetokenize(const UnicodeText& unicode_text,
                          std::vector<Token>* tokens) const;

  // Tokenizes a substring of the unicode string, appending the resulting tokens
  // to the output vector. The resulting tokens have bounds relative to the full
  // string. Does nothing if the start of the span is negative.
  void TokenizeSubstring(const UnicodeText& unicode_text, CodepointSpan span,
                         std::vector<Token>* result) const;

  // Removes all tokens from tokens that are not on a line (defined by calling
  // SplitContext on the context) to which span points.
  void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
                                 std::vector<Token>* tokens) const;

  // Same as above but takes UnicodeText.
  void StripTokensFromOtherLines(const UnicodeText& context_unicode,
                                 CodepointSpan span,
                                 std::vector<Token>* tokens) const;

  // Extracts the features of a token and appends them to the output vector.
  // Uses the embedding cache to avoid re-extracting and re-embedding the
  // sparse features for the same token.
  bool AppendTokenFeaturesWithCache(const Token& token,
                                    CodepointSpan selection_span_for_feature,
                                    const EmbeddingExecutor* embedding_executor,
                                    EmbeddingCache* embedding_cache,
                                    std::vector<float>* output_features) const;

 private:
  const UniLib* unilib_;

 protected:
  const TokenFeatureExtractor feature_extractor_;

  // Codepoint ranges that define what codepoints are supported by the model.
  // NOTE: Must be sorted.
  std::vector<CodepointRange> supported_codepoint_ranges_;

  // Codepoint ranges that define which tokens (consisting of which codepoints)
  // should be re-tokenized with the internal tokenizer in the mixed
  // tokenization mode.
  // NOTE: Must be sorted.
  std::vector<CodepointRange> internal_tokenizer_codepoint_ranges_;

 private:
  // Set of codepoints that will be stripped from beginning and end of
  // predicted spans.
  std::set<int32> ignored_span_boundary_codepoints_;

  const FeatureProcessorOptions* const options_;

  // Mapping between token selection spans and labels ids.
  std::map<TokenSpan, int> selection_to_label_;
  std::vector<TokenSpan> label_to_selection_;

  // Mapping between collections and labels.
  std::map<std::string, int> collection_to_label_;

  Tokenizer tokenizer_;
};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
diff --git a/annotator/feature-processor_test.cc b/annotator/feature-processor_test.cc
new file mode 100644
index 0000000..c9f0e0d
--- /dev/null
+++ b/annotator/feature-processor_test.cc
@@ -0,0 +1,1125 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/feature-processor.h"
+
+#include "annotator/model-executor.h"
+#include "utils/tensor-view.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::ElementsAreArray;
+using testing::FloatEq;
+using testing::Matcher;
+
+flatbuffers::DetachedBuffer PackFeatureProcessorOptions(
+ const FeatureProcessorOptionsT& options) {
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(CreateFeatureProcessorOptions(builder, &options));
+ return builder.Release();
+}
+
// Returns a copy of the elements of |vector| in the half-open index range
// [start, end).
template <typename T>
std::vector<T> Subvector(const std::vector<T>& vector, int start, int end) {
  std::vector<T> slice;
  slice.reserve(end - start);
  for (int i = start; i < end; ++i) {
    slice.push_back(vector[i]);
  }
  return slice;
}
+
+Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) {
+ std::vector<Matcher<float>> matchers;
+ for (const float value : values) {
+ matchers.push_back(FloatEq(value));
+ }
+ return ElementsAreArray(matchers);
+}
+
// Test-only subclass that re-exports protected members of FeatureProcessor
// so the tests below can call them directly.
class TestingFeatureProcessor : public FeatureProcessor {
 public:
  using FeatureProcessor::CountIgnoredSpanBoundaryCodepoints;
  using FeatureProcessor::FeatureProcessor;
  using FeatureProcessor::ICUTokenize;
  using FeatureProcessor::IsCodepointInRanges;
  using FeatureProcessor::SpanToLabel;
  using FeatureProcessor::StripTokensFromOtherLines;
  using FeatureProcessor::supported_codepoint_ranges_;
  using FeatureProcessor::SupportedCodepointsRatio;
};
+
+// EmbeddingExecutor that always returns features based on
+class FakeEmbeddingExecutor : public EmbeddingExecutor {
+ public:
+ bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
+ int dest_size) const override {
+ TC3_CHECK_GE(dest_size, 4);
+ EXPECT_EQ(sparse_features.size(), 1);
+ dest[0] = sparse_features.data()[0];
+ dest[1] = sparse_features.data()[0];
+ dest[2] = -sparse_features.data()[0];
+ dest[3] = -sparse_features.data()[0];
+ return true;
+ }
+
+ private:
+ std::vector<float> storage_;
+};
+
// Fixture that owns the UniLib instance required to construct a
// FeatureProcessor in each test.
class FeatureProcessorTest : public ::testing::Test {
 protected:
  FeatureProcessorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
  UniLib unilib_;
};
+
// Both selection boundaries fall strictly inside the middle token: the token
// is split into three pieces.
TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesMiddle) {
  std::vector<Token> tokens{Token("Hělló", 0, 5),
                            Token("fěěbař@google.com", 6, 23),
                            Token("heře!", 24, 29)};

  internal::SplitTokensOnSelectionBoundaries({9, 12}, &tokens);

  // clang-format off
  EXPECT_THAT(tokens, ElementsAreArray(
                          {Token("Hělló", 0, 5),
                           Token("fěě", 6, 9),
                           Token("bař", 9, 12),
                           Token("@google.com", 12, 23),
                           Token("heře!", 24, 29)}));
  // clang-format on
}

// Selection start coincides with a token start, so only the selection end
// splits the token.
TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesBegin) {
  std::vector<Token> tokens{Token("Hělló", 0, 5),
                            Token("fěěbař@google.com", 6, 23),
                            Token("heře!", 24, 29)};

  internal::SplitTokensOnSelectionBoundaries({6, 12}, &tokens);

  // clang-format off
  EXPECT_THAT(tokens, ElementsAreArray(
                          {Token("Hělló", 0, 5),
                           Token("fěěbař", 6, 12),
                           Token("@google.com", 12, 23),
                           Token("heře!", 24, 29)}));
  // clang-format on
}

// Selection end coincides with a token end, so only the selection start
// splits the token.
TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesEnd) {
  std::vector<Token> tokens{Token("Hělló", 0, 5),
                            Token("fěěbař@google.com", 6, 23),
                            Token("heře!", 24, 29)};

  internal::SplitTokensOnSelectionBoundaries({9, 23}, &tokens);

  // clang-format off
  EXPECT_THAT(tokens, ElementsAreArray(
                          {Token("Hělló", 0, 5),
                           Token("fěě", 6, 9),
                           Token("bař@google.com", 9, 23),
                           Token("heře!", 24, 29)}));
  // clang-format on
}

// Selection exactly covers a token: nothing is split.
TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesWhole) {
  std::vector<Token> tokens{Token("Hělló", 0, 5),
                            Token("fěěbař@google.com", 6, 23),
                            Token("heře!", 24, 29)};

  internal::SplitTokensOnSelectionBoundaries({6, 23}, &tokens);

  // clang-format off
  EXPECT_THAT(tokens, ElementsAreArray(
                          {Token("Hělló", 0, 5),
                           Token("fěěbař@google.com", 6, 23),
                           Token("heře!", 24, 29)}));
  // clang-format on
}

// Selection spans two tokens: each of the two tokens is split at its
// respective boundary.
TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesCrossToken) {
  std::vector<Token> tokens{Token("Hělló", 0, 5),
                            Token("fěěbař@google.com", 6, 23),
                            Token("heře!", 24, 29)};

  internal::SplitTokensOnSelectionBoundaries({2, 9}, &tokens);

  // clang-format off
  EXPECT_THAT(tokens, ElementsAreArray(
                          {Token("Hě", 0, 2),
                           Token("lló", 2, 5),
                           Token("fěě", 6, 9),
                           Token("bař@google.com", 9, 23),
                           Token("heře!", 24, 29)}));
  // clang-format on
}
+
TEST_F(FeatureProcessorTest, KeepLineWithClickFirst) {
  FeatureProcessorOptionsT options;
  options.only_use_line_with_click = true;
  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
  TestingFeatureProcessor feature_processor(
      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
      &unilib_);

  const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
  const CodepointSpan span = {0, 5};
  // clang-format off
  std::vector<Token> tokens = {Token("Fiřst", 0, 5),
                               Token("Lině", 6, 10),
                               Token("Sěcond", 11, 17),
                               Token("Lině", 18, 22),
                               Token("Thiřd", 23, 28),
                               Token("Lině", 29, 33)};
  // clang-format on

  // The click is on the first line, so only its tokens are kept.
  feature_processor.StripTokensFromOtherLines(context, span, &tokens);
  EXPECT_THAT(tokens,
              ElementsAreArray({Token("Fiřst", 0, 5), Token("Lině", 6, 10)}));
}

TEST_F(FeatureProcessorTest, KeepLineWithClickSecond) {
  FeatureProcessorOptionsT options;
  options.only_use_line_with_click = true;
  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
  TestingFeatureProcessor feature_processor(
      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
      &unilib_);

  const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
  const CodepointSpan span = {18, 22};
  // clang-format off
  std::vector<Token> tokens = {Token("Fiřst", 0, 5),
                               Token("Lině", 6, 10),
                               Token("Sěcond", 11, 17),
                               Token("Lině", 18, 22),
                               Token("Thiřd", 23, 28),
                               Token("Lině", 29, 33)};
  // clang-format on

  // The click is on the second line, so only its tokens are kept.
  feature_processor.StripTokensFromOtherLines(context, span, &tokens);
  EXPECT_THAT(tokens, ElementsAreArray(
                          {Token("Sěcond", 11, 17), Token("Lině", 18, 22)}));
}

TEST_F(FeatureProcessorTest, KeepLineWithClickThird) {
  FeatureProcessorOptionsT options;
  options.only_use_line_with_click = true;
  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
  TestingFeatureProcessor feature_processor(
      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
      &unilib_);

  const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
  const CodepointSpan span = {24, 33};
  // clang-format off
  std::vector<Token> tokens = {Token("Fiřst", 0, 5),
                               Token("Lině", 6, 10),
                               Token("Sěcond", 11, 17),
                               Token("Lině", 18, 22),
                               Token("Thiřd", 23, 28),
                               Token("Lině", 29, 33)};
  // clang-format on

  // The click is on the third line, so only its tokens are kept.
  feature_processor.StripTokensFromOtherLines(context, span, &tokens);
  EXPECT_THAT(tokens, ElementsAreArray(
                          {Token("Thiřd", 23, 28), Token("Lině", 29, 33)}));
}

TEST_F(FeatureProcessorTest, KeepLineWithClickSecondWithPipe) {
  FeatureProcessorOptionsT options;
  options.only_use_line_with_click = true;
  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
  TestingFeatureProcessor feature_processor(
      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
      &unilib_);

  // '|' separates the first and second segment here instead of '\n'; the
  // expectation below shows it is treated as a line separator too.
  const std::string context = "Fiřst Lině|Sěcond Lině\nThiřd Lině";
  const CodepointSpan span = {18, 22};
  // clang-format off
  std::vector<Token> tokens = {Token("Fiřst", 0, 5),
                               Token("Lině", 6, 10),
                               Token("Sěcond", 11, 17),
                               Token("Lině", 18, 22),
                               Token("Thiřd", 23, 28),
                               Token("Lině", 29, 33)};
  // clang-format on

  // The click is on the second segment, so only its tokens are kept.
  feature_processor.StripTokensFromOtherLines(context, span, &tokens);
  EXPECT_THAT(tokens, ElementsAreArray(
                          {Token("Sěcond", 11, 17), Token("Lině", 18, 22)}));
}

TEST_F(FeatureProcessorTest, KeepLineWithCrosslineClick) {
  FeatureProcessorOptionsT options;
  options.only_use_line_with_click = true;
  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
  TestingFeatureProcessor feature_processor(
      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
      &unilib_);

  const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
  const CodepointSpan span = {5, 23};
  // NOTE(review): the spans of "Sěcond"/"Lině" below (18-23, 19-23) overlap
  // and differ from the 11-17/18-22 spans used in the sibling tests —
  // presumably intentional test data, but worth confirming.
  // clang-format off
  std::vector<Token> tokens = {Token("Fiřst", 0, 5),
                               Token("Lině", 6, 10),
                               Token("Sěcond", 18, 23),
                               Token("Lině", 19, 23),
                               Token("Thiřd", 23, 28),
                               Token("Lině", 29, 33)};
  // clang-format on

  // The click spans multiple lines, so no tokens are stripped.
  feature_processor.StripTokensFromOtherLines(context, span, &tokens);
  EXPECT_THAT(tokens, ElementsAreArray(
                          {Token("Fiřst", 0, 5), Token("Lině", 6, 10),
                           Token("Sěcond", 18, 23), Token("Lině", 19, 23),
                           Token("Thiřd", 23, 28), Token("Lině", 29, 33)}));
}
+
// Verifies SpanToLabel/LabelToTokenSpan round-tripping, with and without
// snapping of span boundaries to containing tokens.
TEST_F(FeatureProcessorTest, SpanToLabel) {
  FeatureProcessorOptionsT options;
  options.context_size = 1;
  options.max_selection_span = 1;
  options.snap_label_span_boundaries_to_containing_tokens = false;

  // Tokenize on spaces (codepoint 32).
  options.tokenization_codepoint_config.emplace_back(
      new TokenizationCodepointRangeT());
  auto& config = options.tokenization_codepoint_config.back();
  config->start = 32;
  config->end = 33;
  config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;

  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
  TestingFeatureProcessor feature_processor(
      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
      &unilib_);
  std::vector<Token> tokens = feature_processor.Tokenize("one, two, three");
  ASSERT_EQ(3, tokens.size());
  int label;
  // Without snapping, {5, 8} ("two" without the trailing comma) does not
  // exactly match the token "two," and is therefore invalid.
  ASSERT_TRUE(feature_processor.SpanToLabel({5, 8}, tokens, &label));
  EXPECT_EQ(kInvalidLabel, label);
  // {5, 9} matches the token "two," exactly and maps to a valid label.
  ASSERT_TRUE(feature_processor.SpanToLabel({5, 9}, tokens, &label));
  EXPECT_NE(kInvalidLabel, label);
  TokenSpan token_span;
  feature_processor.LabelToTokenSpan(label, &token_span);
  EXPECT_EQ(0, token_span.first);
  EXPECT_EQ(0, token_span.second);

  // Reconfigure with snapping enabled.
  options.snap_label_span_boundaries_to_containing_tokens = true;
  flatbuffers::DetachedBuffer options2_fb =
      PackFeatureProcessorOptions(options);
  TestingFeatureProcessor feature_processor2(
      flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()),
      &unilib_);
  int label2;
  // With snapping, sub-token spans snap out to the containing token.
  ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2));
  EXPECT_EQ(label, label2);
  ASSERT_TRUE(feature_processor2.SpanToLabel({6, 9}, tokens, &label2));
  EXPECT_EQ(label, label2);
  ASSERT_TRUE(feature_processor2.SpanToLabel({5, 9}, tokens, &label2));
  EXPECT_EQ(label, label2);

  // Cross a token boundary.
  ASSERT_TRUE(feature_processor2.SpanToLabel({4, 9}, tokens, &label2));
  EXPECT_EQ(kInvalidLabel, label2);
  ASSERT_TRUE(feature_processor2.SpanToLabel({5, 10}, tokens, &label2));
  EXPECT_EQ(kInvalidLabel, label2);

  // Multiple tokens.
  options.context_size = 2;
  options.max_selection_span = 2;
  flatbuffers::DetachedBuffer options3_fb =
      PackFeatureProcessorOptions(options);
  TestingFeatureProcessor feature_processor3(
      flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()),
      &unilib_);
  tokens = feature_processor3.Tokenize("zero, one, two, three, four");
  ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2));
  EXPECT_NE(kInvalidLabel, label2);
  feature_processor3.LabelToTokenSpan(label2, &token_span);
  EXPECT_EQ(1, token_span.first);
  EXPECT_EQ(0, token_span.second);

  int label3;
  ASSERT_TRUE(feature_processor3.SpanToLabel({6, 14}, tokens, &label3));
  EXPECT_EQ(label2, label3);
  ASSERT_TRUE(feature_processor3.SpanToLabel({6, 13}, tokens, &label3));
  EXPECT_EQ(label2, label3);
  ASSERT_TRUE(feature_processor3.SpanToLabel({7, 13}, tokens, &label3));
  EXPECT_EQ(label2, label3);
}
+
// NOTE(review): this test body is byte-identical to SpanToLabel above and
// configures nothing punctuation-specific (e.g. no ignored boundary
// codepoints) — it looks like an unfinished copy. Confirm intent and either
// add punctuation-ignoring configuration or delete the duplicate.
TEST_F(FeatureProcessorTest, SpanToLabelIgnoresPunctuation) {
  FeatureProcessorOptionsT options;
  options.context_size = 1;
  options.max_selection_span = 1;
  options.snap_label_span_boundaries_to_containing_tokens = false;

  // Tokenize on spaces (codepoint 32).
  options.tokenization_codepoint_config.emplace_back(
      new TokenizationCodepointRangeT());
  auto& config = options.tokenization_codepoint_config.back();
  config->start = 32;
  config->end = 33;
  config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;

  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
  TestingFeatureProcessor feature_processor(
      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
      &unilib_);
  std::vector<Token> tokens = feature_processor.Tokenize("one, two, three");
  ASSERT_EQ(3, tokens.size());
  int label;
  ASSERT_TRUE(feature_processor.SpanToLabel({5, 8}, tokens, &label));
  EXPECT_EQ(kInvalidLabel, label);
  ASSERT_TRUE(feature_processor.SpanToLabel({5, 9}, tokens, &label));
  EXPECT_NE(kInvalidLabel, label);
  TokenSpan token_span;
  feature_processor.LabelToTokenSpan(label, &token_span);
  EXPECT_EQ(0, token_span.first);
  EXPECT_EQ(0, token_span.second);

  // Reconfigure with snapping enabled.
  options.snap_label_span_boundaries_to_containing_tokens = true;
  flatbuffers::DetachedBuffer options2_fb =
      PackFeatureProcessorOptions(options);
  TestingFeatureProcessor feature_processor2(
      flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()),
      &unilib_);
  int label2;
  ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2));
  EXPECT_EQ(label, label2);
  ASSERT_TRUE(feature_processor2.SpanToLabel({6, 9}, tokens, &label2));
  EXPECT_EQ(label, label2);
  ASSERT_TRUE(feature_processor2.SpanToLabel({5, 9}, tokens, &label2));
  EXPECT_EQ(label, label2);

  // Cross a token boundary.
  ASSERT_TRUE(feature_processor2.SpanToLabel({4, 9}, tokens, &label2));
  EXPECT_EQ(kInvalidLabel, label2);
  ASSERT_TRUE(feature_processor2.SpanToLabel({5, 10}, tokens, &label2));
  EXPECT_EQ(kInvalidLabel, label2);

  // Multiple tokens.
  options.context_size = 2;
  options.max_selection_span = 2;
  flatbuffers::DetachedBuffer options3_fb =
      PackFeatureProcessorOptions(options);
  TestingFeatureProcessor feature_processor3(
      flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()),
      &unilib_);
  tokens = feature_processor3.Tokenize("zero, one, two, three, four");
  ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2));
  EXPECT_NE(kInvalidLabel, label2);
  feature_processor3.LabelToTokenSpan(label2, &token_span);
  EXPECT_EQ(1, token_span.first);
  EXPECT_EQ(0, token_span.second);

  int label3;
  ASSERT_TRUE(feature_processor3.SpanToLabel({6, 14}, tokens, &label3));
  EXPECT_EQ(label2, label3);
  ASSERT_TRUE(feature_processor3.SpanToLabel({6, 13}, tokens, &label3));
  EXPECT_EQ(label2, label3);
  ASSERT_TRUE(feature_processor3.SpanToLabel({7, 13}, tokens, &label3));
  EXPECT_EQ(label2, label3);
}
+
// A click resolves to a token only when it lies within a single token;
// clicks spanning two tokens yield kInvalidIndex.
TEST_F(FeatureProcessorTest, CenterTokenFromClick) {
  int token_index;

  // Exactly aligned indices.
  token_index = internal::CenterTokenFromClick(
      {6, 11},
      {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
  EXPECT_EQ(token_index, 1);

  // Click is contained in a token.
  token_index = internal::CenterTokenFromClick(
      {13, 17},
      {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
  EXPECT_EQ(token_index, 2);

  // Click spans two tokens.
  token_index = internal::CenterTokenFromClick(
      {6, 17},
      {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
  EXPECT_EQ(token_index, kInvalidIndex);
}
+
+TEST_F(FeatureProcessorTest, CenterTokenFromMiddleOfSelection) {
+ int token_index;
+
+ // Selection of length 3. Exactly aligned indices.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {7, 27},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
+ EXPECT_EQ(token_index, 2);
+
+ // Selection of length 1 token. Exactly aligned indices.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {21, 27},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
+ EXPECT_EQ(token_index, 3);
+
+ // Selection marks sub-token range, with no tokens in it.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {29, 33},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
+ EXPECT_EQ(token_index, kInvalidIndex);
+
+ // Selection of length 2. Sub-token indices.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {3, 25},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
+ EXPECT_EQ(token_index, 1);
+
+ // Selection of length 1. Sub-token indices.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {22, 34},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
+ EXPECT_EQ(token_index, 4);
+
+ // Some invalid ones.
+ token_index = internal::CenterTokenFromMiddleOfSelection({7, 27}, {});
+ EXPECT_EQ(token_index, -1);
+}
+
+// Verifies the codepoint-coverage helpers: SupportedCodepointsRatio computes
+// the fraction of codepoints within a token span that fall inside the
+// configured supported_codepoint_ranges, IsCodepointInRanges does point
+// lookups (inclusive start, exclusive end), and HasEnoughSupportedCodepoints
+// compares the ratio against min_supported_codepoint_ratio.
+TEST_F(FeatureProcessorTest, SupportedCodepointsRatio) {
+  FeatureProcessorOptionsT options;
+  options.context_size = 2;
+  options.max_selection_span = 2;
+  options.snap_label_span_boundaries_to_containing_tokens = false;
+  options.feature_version = 2;
+  options.embedding_size = 4;
+  options.bounds_sensitive_features.reset(
+      new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
+  options.bounds_sensitive_features->enabled = true;
+  options.bounds_sensitive_features->num_tokens_before = 5;
+  options.bounds_sensitive_features->num_tokens_inside_left = 3;
+  options.bounds_sensitive_features->num_tokens_inside_right = 3;
+  options.bounds_sensitive_features->num_tokens_after = 5;
+  options.bounds_sensitive_features->include_inside_bag = true;
+  options.bounds_sensitive_features->include_inside_length = true;
+
+  // Tokenize on the space character (codepoint 32) only.
+  options.tokenization_codepoint_config.emplace_back(
+      new TokenizationCodepointRangeT());
+  auto& config = options.tokenization_codepoint_config.back();
+  config->start = 32;
+  config->end = 33;
+  config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+
+  // Supported ranges: ASCII [0, 128), [10000, 10001) and [20000, 30000).
+  {
+    options.supported_codepoint_ranges.emplace_back(
+        new FeatureProcessorOptions_::CodepointRangeT());
+    auto& range = options.supported_codepoint_ranges.back();
+    range->start = 0;
+    range->end = 128;
+  }
+
+  {
+    options.supported_codepoint_ranges.emplace_back(
+        new FeatureProcessorOptions_::CodepointRangeT());
+    auto& range = options.supported_codepoint_ranges.back();
+    range->start = 10000;
+    range->end = 10001;
+  }
+
+  {
+    options.supported_codepoint_ranges.emplace_back(
+        new FeatureProcessorOptions_::CodepointRangeT());
+    auto& range = options.supported_codepoint_ranges.back();
+    range->start = 20000;
+    range->end = 30000;
+  }
+
+  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+  TestingFeatureProcessor feature_processor(
+      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+      &unilib_);
+  // All/two-thirds/none of the three tokens consist of supported (ASCII)
+  // codepoints.
+  EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
+                  {0, 3}, feature_processor.Tokenize("aaa bbb ccc")),
+              FloatEq(1.0));
+  EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
+                  {0, 3}, feature_processor.Tokenize("aaa bbb ěěě")),
+              FloatEq(2.0 / 3));
+  EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
+                  {0, 3}, feature_processor.Tokenize("ěěě řřř ěěě")),
+              FloatEq(0.0));
+  // Range membership: starts are inclusive, ends are exclusive.
+  EXPECT_FALSE(feature_processor.IsCodepointInRanges(
+      -1, feature_processor.supported_codepoint_ranges_));
+  EXPECT_TRUE(feature_processor.IsCodepointInRanges(
+      0, feature_processor.supported_codepoint_ranges_));
+  EXPECT_TRUE(feature_processor.IsCodepointInRanges(
+      10, feature_processor.supported_codepoint_ranges_));
+  EXPECT_TRUE(feature_processor.IsCodepointInRanges(
+      127, feature_processor.supported_codepoint_ranges_));
+  EXPECT_FALSE(feature_processor.IsCodepointInRanges(
+      128, feature_processor.supported_codepoint_ranges_));
+  EXPECT_FALSE(feature_processor.IsCodepointInRanges(
+      9999, feature_processor.supported_codepoint_ranges_));
+  EXPECT_TRUE(feature_processor.IsCodepointInRanges(
+      10000, feature_processor.supported_codepoint_ranges_));
+  EXPECT_FALSE(feature_processor.IsCodepointInRanges(
+      10001, feature_processor.supported_codepoint_ranges_));
+  EXPECT_TRUE(feature_processor.IsCodepointInRanges(
+      25000, feature_processor.supported_codepoint_ranges_));
+
+  // One of the three tokens ("eee") is fully supported, i.e. ratio 1/3.
+  const std::vector<Token> tokens = {Token("ěěě", 0, 3), Token("řřř", 4, 7),
+                                     Token("eee", 8, 11)};
+
+  options.min_supported_codepoint_ratio = 0.0;
+  flatbuffers::DetachedBuffer options2_fb =
+      PackFeatureProcessorOptions(options);
+  TestingFeatureProcessor feature_processor2(
+      flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()),
+      &unilib_);
+  EXPECT_TRUE(feature_processor2.HasEnoughSupportedCodepoints(
+      tokens, /*token_span=*/{0, 3}));
+
+  options.min_supported_codepoint_ratio = 0.2;
+  flatbuffers::DetachedBuffer options3_fb =
+      PackFeatureProcessorOptions(options);
+  TestingFeatureProcessor feature_processor3(
+      flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()),
+      &unilib_);
+  EXPECT_TRUE(feature_processor3.HasEnoughSupportedCodepoints(
+      tokens, /*token_span=*/{0, 3}));
+
+  // A threshold above the actual ratio makes the check fail.
+  options.min_supported_codepoint_ratio = 0.5;
+  flatbuffers::DetachedBuffer options4_fb =
+      PackFeatureProcessorOptions(options);
+  TestingFeatureProcessor feature_processor4(
+      flatbuffers::GetRoot<FeatureProcessorOptions>(options4_fb.data()),
+      &unilib_);
+  EXPECT_FALSE(feature_processor4.HasEnoughSupportedCodepoints(
+      tokens, /*token_span=*/{0, 3}));
+}
+
+// Verifies that with extract_selection_mask_feature enabled, the last dense
+// feature of every token in the click context marks whether that token lies
+// inside the selection span passed as selection_span_for_feature.
+TEST_F(FeatureProcessorTest, InSpanFeature) {
+  FeatureProcessorOptionsT options;
+  options.context_size = 2;
+  options.max_selection_span = 2;
+  options.snap_label_span_boundaries_to_containing_tokens = false;
+  options.feature_version = 2;
+  options.embedding_size = 4;
+  options.extract_selection_mask_feature = true;
+
+  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+  TestingFeatureProcessor feature_processor(
+      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+      &unilib_);
+
+  std::unique_ptr<CachedFeatures> cached_features;
+
+  FakeEmbeddingExecutor embedding_executor;
+
+  const std::vector<Token> tokens = {Token("aaa", 0, 3), Token("bbb", 4, 7),
+                                     Token("ccc", 8, 11), Token("ddd", 12, 15)};
+
+  EXPECT_TRUE(feature_processor.ExtractFeatures(
+      tokens, /*token_span=*/{0, 4},
+      /*selection_span_for_feature=*/{4, 11}, &embedding_executor,
+      /*embedding_cache=*/nullptr, /*feature_vector_size=*/5,
+      &cached_features));
+  std::vector<float> features;
+  cached_features->AppendClickContextFeaturesForClick(1, &features);
+  // 5 context tokens (click +/- context_size) x 5 features per token.
+  ASSERT_EQ(features.size(), 25);
+  // The selection-mask feature (last of each 5) is 1.0 exactly for "bbb"
+  // (4, 7) and "ccc" (8, 11), the tokens inside the selection span {4, 11}.
+  EXPECT_THAT(features[4], FloatEq(0.0));
+  EXPECT_THAT(features[9], FloatEq(0.0));
+  EXPECT_THAT(features[14], FloatEq(1.0));
+  EXPECT_THAT(features[19], FloatEq(1.0));
+  EXPECT_THAT(features[24], FloatEq(0.0));
+}
+
+// Verifies the embedding cache: pre-populated entries (keyed by token
+// codepoint span) are used instead of the executor, and embeddings computed
+// by the executor during extraction are written back into the cache.
+TEST_F(FeatureProcessorTest, EmbeddingCache) {
+  FeatureProcessorOptionsT options;
+  options.context_size = 2;
+  options.max_selection_span = 2;
+  options.snap_label_span_boundaries_to_containing_tokens = false;
+  options.feature_version = 2;
+  options.embedding_size = 4;
+  options.bounds_sensitive_features.reset(
+      new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
+  options.bounds_sensitive_features->enabled = true;
+  options.bounds_sensitive_features->num_tokens_before = 3;
+  options.bounds_sensitive_features->num_tokens_inside_left = 2;
+  options.bounds_sensitive_features->num_tokens_inside_right = 2;
+  options.bounds_sensitive_features->num_tokens_after = 3;
+
+  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+  TestingFeatureProcessor feature_processor(
+      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+      &unilib_);
+
+  std::unique_ptr<CachedFeatures> cached_features;
+
+  FakeEmbeddingExecutor embedding_executor;
+
+  const std::vector<Token> tokens = {
+      Token("aaa", 0, 3),   Token("bbb", 4, 7),   Token("ccc", 8, 11),
+      Token("ddd", 12, 15), Token("eee", 16, 19), Token("fff", 20, 23)};
+
+  // We pre-populate the cache with dummy embeddings, to make sure they are
+  // used when populating the features vector.
+  const std::vector<float> cached_padding_features = {10.0, -10.0, 10.0, -10.0};
+  const std::vector<float> cached_features1 = {1.0, 2.0, 3.0, 4.0};
+  const std::vector<float> cached_features2 = {5.0, 6.0, 7.0, 8.0};
+  FeatureProcessor::EmbeddingCache embedding_cache = {
+      {{kInvalidIndex, kInvalidIndex}, cached_padding_features},
+      {{4, 7}, cached_features1},
+      {{12, 15}, cached_features2},
+  };
+
+  EXPECT_TRUE(feature_processor.ExtractFeatures(
+      tokens, /*token_span=*/{0, 6},
+      /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
+      &embedding_executor, &embedding_cache, /*feature_vector_size=*/4,
+      &cached_features));
+  std::vector<float> features;
+  cached_features->AppendBoundsSensitiveFeaturesForSpan({2, 4}, &features);
+  // 10 bounds-sensitive positions (3 before + 2 + 2 inside + 3 after) x 4.
+  ASSERT_EQ(features.size(), 40);
+  // Check that the dummy embeddings were used.
+  EXPECT_THAT(Subvector(features, 0, 4),
+              ElementsAreFloat(cached_padding_features));
+  EXPECT_THAT(Subvector(features, 8, 12), ElementsAreFloat(cached_features1));
+  EXPECT_THAT(Subvector(features, 16, 20), ElementsAreFloat(cached_features2));
+  EXPECT_THAT(Subvector(features, 24, 28), ElementsAreFloat(cached_features2));
+  EXPECT_THAT(Subvector(features, 36, 40),
+              ElementsAreFloat(cached_padding_features));
+  // Check that the real embeddings were cached: the 3 initial entries plus
+  // the 4 spans computed by the executor ({0,3}, {8,11}, {16,19}, {20,23}).
+  EXPECT_EQ(embedding_cache.size(), 7);
+  EXPECT_THAT(Subvector(features, 4, 8),
+              ElementsAreFloat(embedding_cache.at({0, 3})));
+  EXPECT_THAT(Subvector(features, 12, 16),
+              ElementsAreFloat(embedding_cache.at({8, 11})));
+  EXPECT_THAT(Subvector(features, 20, 24),
+              ElementsAreFloat(embedding_cache.at({8, 11})));
+  EXPECT_THAT(Subvector(features, 28, 32),
+              ElementsAreFloat(embedding_cache.at({16, 19})));
+  EXPECT_THAT(Subvector(features, 32, 36),
+              ElementsAreFloat(embedding_cache.at({20, 23})));
+}
+
+// Verifies StripOrPadTokens with an empty relative-click span {0, 0}: the
+// token list is reduced to the click token plus context_size (2) tokens on
+// each side, padded with empty Token()s at the edges, and click_index is
+// remapped into the stripped list.
+TEST_F(FeatureProcessorTest, StripUnusedTokensWithNoRelativeClick) {
+  std::vector<Token> tokens_orig{
+      Token("0", 0, 0),  Token("1", 0, 0),  Token("2", 0, 0),  Token("3", 0, 0),
+      Token("4", 0, 0),  Token("5", 0, 0),  Token("6", 0, 0),  Token("7", 0, 0),
+      Token("8", 0, 0),  Token("9", 0, 0),  Token("10", 0, 0), Token("11", 0, 0),
+      Token("12", 0, 0)};
+
+  std::vector<Token> tokens;
+  int click_index;
+
+  // Try to click first token and see if it gets padded from left.
+  tokens = tokens_orig;
+  click_index = 0;
+  internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
+  // clang-format off
+  EXPECT_EQ(tokens, std::vector<Token>({Token(),
+                                        Token(),
+                                        Token("0", 0, 0),
+                                        Token("1", 0, 0),
+                                        Token("2", 0, 0)}));
+  // clang-format on
+  EXPECT_EQ(click_index, 2);
+
+  // When we click the second token nothing should get padded.
+  tokens = tokens_orig;
+  click_index = 2;
+  internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
+  // clang-format off
+  EXPECT_EQ(tokens, std::vector<Token>({Token("0", 0, 0),
+                                        Token("1", 0, 0),
+                                        Token("2", 0, 0),
+                                        Token("3", 0, 0),
+                                        Token("4", 0, 0)}));
+  // clang-format on
+  EXPECT_EQ(click_index, 2);
+
+  // When we click the last token tokens should get padded from the right.
+  tokens = tokens_orig;
+  click_index = 12;
+  internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
+  // clang-format off
+  EXPECT_EQ(tokens, std::vector<Token>({Token("10", 0, 0),
+                                        Token("11", 0, 0),
+                                        Token("12", 0, 0),
+                                        Token(),
+                                        Token()}));
+  // clang-format on
+  EXPECT_EQ(click_index, 2);
+}
+
+// Verifies StripOrPadTokens with non-empty relative-click spans: the kept
+// window extends context_size beyond the relative span on each side, with
+// empty Token() padding where the original list runs out.
+TEST_F(FeatureProcessorTest, StripUnusedTokensWithRelativeClick) {
+  std::vector<Token> tokens_orig{
+      Token("0", 0, 0),  Token("1", 0, 0),  Token("2", 0, 0),  Token("3", 0, 0),
+      Token("4", 0, 0),  Token("5", 0, 0),  Token("6", 0, 0),  Token("7", 0, 0),
+      Token("8", 0, 0),  Token("9", 0, 0),  Token("10", 0, 0), Token("11", 0, 0),
+      Token("12", 0, 0)};
+
+  std::vector<Token> tokens;
+  int click_index;
+
+  // Try to click first token and see if it gets padded from left to maximum
+  // context_size.
+  tokens = tokens_orig;
+  click_index = 0;
+  internal::StripOrPadTokens({2, 3}, 2, &tokens, &click_index);
+  // clang-format off
+  EXPECT_EQ(tokens, std::vector<Token>({Token(),
+                                        Token(),
+                                        Token("0", 0, 0),
+                                        Token("1", 0, 0),
+                                        Token("2", 0, 0),
+                                        Token("3", 0, 0),
+                                        Token("4", 0, 0),
+                                        Token("5", 0, 0)}));
+  // clang-format on
+  EXPECT_EQ(click_index, 2);
+
+  // Clicking to the middle with enough context should not produce any padding.
+  tokens = tokens_orig;
+  click_index = 6;
+  internal::StripOrPadTokens({3, 1}, 2, &tokens, &click_index);
+  // clang-format off
+  EXPECT_EQ(tokens, std::vector<Token>({Token("1", 0, 0),
+                                        Token("2", 0, 0),
+                                        Token("3", 0, 0),
+                                        Token("4", 0, 0),
+                                        Token("5", 0, 0),
+                                        Token("6", 0, 0),
+                                        Token("7", 0, 0),
+                                        Token("8", 0, 0),
+                                        Token("9", 0, 0)}));
+  // clang-format on
+  EXPECT_EQ(click_index, 5);
+
+  // Clicking at the end should pad right to maximum context_size.
+  tokens = tokens_orig;
+  click_index = 11;
+  internal::StripOrPadTokens({3, 1}, 2, &tokens, &click_index);
+  // clang-format off
+  EXPECT_EQ(tokens, std::vector<Token>({Token("6", 0, 0),
+                                        Token("7", 0, 0),
+                                        Token("8", 0, 0),
+                                        Token("9", 0, 0),
+                                        Token("10", 0, 0),
+                                        Token("11", 0, 0),
+                                        Token("12", 0, 0),
+                                        Token(),
+                                        Token()}));
+  // clang-format on
+  EXPECT_EQ(click_index, 5);
+}
+
+// Verifies the internal tokenizer's tokenize_on_script_change flag: with it
+// disabled, a mixed Hangul/digit string stays a single token; with it
+// enabled, a new token starts at every script boundary.
+TEST_F(FeatureProcessorTest, InternalTokenizeOnScriptChange) {
+  FeatureProcessorOptionsT options;
+  options.tokenization_codepoint_config.emplace_back(
+      new TokenizationCodepointRangeT());
+  {
+    auto& config = options.tokenization_codepoint_config.back();
+    config->start = 0;
+    config->end = 256;
+    config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+    config->script_id = 1;
+  }
+  options.tokenize_on_script_change = false;
+
+  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+  TestingFeatureProcessor feature_processor(
+      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+      &unilib_);
+
+  EXPECT_EQ(feature_processor.Tokenize("앨라배마123웹사이트"),
+            std::vector<Token>({Token("앨라배마123웹사이트", 0, 11)}));
+
+  options.tokenize_on_script_change = true;
+  flatbuffers::DetachedBuffer options_fb2 =
+      PackFeatureProcessorOptions(options);
+  TestingFeatureProcessor feature_processor2(
+      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb2.data()),
+      &unilib_);
+
+  EXPECT_EQ(feature_processor2.Tokenize("앨라배마123웹사이트"),
+            std::vector<Token>({Token("앨라배마", 0, 4), Token("123", 4, 7),
+                                Token("웹사이트", 7, 11)}));
+}
+
+#ifdef TC3_TEST_ICU
+// Verifies ICU-based tokenization of Thai text, which has no whitespace
+// between words; token boundaries come from the ICU break iterator.
+TEST_F(FeatureProcessorTest, ICUTokenize) {
+  FeatureProcessorOptionsT options;
+  options.tokenization_type = FeatureProcessorOptions_::TokenizationType_ICU;
+
+  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+  UniLib unilib;
+  TestingFeatureProcessor feature_processor(
+      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+      &unilib);
+  std::vector<Token> tokens = feature_processor.Tokenize("พระบาทสมเด็จพระปรมิ");
+  ASSERT_EQ(tokens,
+            // clang-format off
+            std::vector<Token>({Token("พระบาท", 0, 6),
+                                Token("สมเด็จ", 6, 12),
+                                Token("พระ", 12, 15),
+                                Token("ปร", 15, 17),
+                                Token("มิ", 17, 19)}));
+  // clang-format on
+}
+#endif
+
+#ifdef TC3_TEST_ICU
+// Verifies that with icu_preserve_whitespace_tokens set, the ICU tokenizer
+// emits the separating spaces as their own single-codepoint tokens.
+TEST_F(FeatureProcessorTest, ICUTokenizeWithWhitespaces) {
+  FeatureProcessorOptionsT options;
+  options.tokenization_type = FeatureProcessorOptions_::TokenizationType_ICU;
+  options.icu_preserve_whitespace_tokens = true;
+
+  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+  UniLib unilib;
+  TestingFeatureProcessor feature_processor(
+      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+      &unilib);
+  std::vector<Token> tokens =
+      feature_processor.Tokenize("พระบาท สมเด็จ พระ ปร มิ");
+  ASSERT_EQ(tokens,
+            // clang-format off
+            std::vector<Token>({Token("พระบาท", 0, 6),
+                                Token(" ", 6, 7),
+                                Token("สมเด็จ", 7, 13),
+                                Token(" ", 13, 14),
+                                Token("พระ", 14, 17),
+                                Token(" ", 17, 18),
+                                Token("ปร", 18, 20),
+                                Token(" ", 20, 21),
+                                Token("มิ", 21, 23)}));
+  // clang-format on
+}
+#endif
+
+#ifdef TC3_TEST_ICU
+// Verifies MIXED tokenization: codepoints inside the configured
+// internal_tokenizer_codepoint_ranges (here [0, 592), i.e. Latin) are split
+// by the internal whitespace tokenizer, while everything outside those
+// ranges (CJK, Thai, ...) is delegated to the ICU tokenizer.
+TEST_F(FeatureProcessorTest, MixedTokenize) {
+  FeatureProcessorOptionsT options;
+  options.tokenization_type = FeatureProcessorOptions_::TokenizationType_MIXED;
+
+  // The internal tokenizer splits on the space character (codepoint 32).
+  options.tokenization_codepoint_config.emplace_back(
+      new TokenizationCodepointRangeT());
+  auto& config = options.tokenization_codepoint_config.back();
+  config->start = 32;
+  config->end = 33;
+  config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+
+  {
+    options.internal_tokenizer_codepoint_ranges.emplace_back(
+        new FeatureProcessorOptions_::CodepointRangeT());
+    auto& range = options.internal_tokenizer_codepoint_ranges.back();
+    range->start = 0;
+    range->end = 128;
+  }
+
+  {
+    options.internal_tokenizer_codepoint_ranges.emplace_back(
+        new FeatureProcessorOptions_::CodepointRangeT());
+    auto& range = options.internal_tokenizer_codepoint_ranges.back();
+    range->start = 128;
+    range->end = 256;
+  }
+
+  {
+    options.internal_tokenizer_codepoint_ranges.emplace_back(
+        new FeatureProcessorOptions_::CodepointRangeT());
+    auto& range = options.internal_tokenizer_codepoint_ranges.back();
+    range->start = 256;
+    range->end = 384;
+  }
+
+  {
+    options.internal_tokenizer_codepoint_ranges.emplace_back(
+        new FeatureProcessorOptions_::CodepointRangeT());
+    auto& range = options.internal_tokenizer_codepoint_ranges.back();
+    range->start = 384;
+    range->end = 592;
+  }
+
+  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+  UniLib unilib;
+  TestingFeatureProcessor feature_processor(
+      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+      &unilib);
+  std::vector<Token> tokens = feature_processor.Tokenize(
+      "こんにちはJapanese-ląnguagę text 世界 http://www.google.com/");
+  ASSERT_EQ(tokens,
+            // clang-format off
+            std::vector<Token>({Token("こんにちは", 0, 5),
+                                Token("Japanese-ląnguagę", 5, 22),
+                                Token("text", 23, 27),
+                                Token("世界", 28, 30),
+                                Token("http://www.google.com/", 31, 53)}));
+  // clang-format on
+}
+#endif
+
+// Verifies CountIgnoredSpanBoundaryCodepoints (how many consecutive ignored
+// codepoints sit at the start or end of an iterator range) and
+// StripBoundaryCodepoints (shrinking a codepoint span so it no longer starts
+// or ends on an ignored codepoint).
+TEST_F(FeatureProcessorTest, IgnoredSpanBoundaryCodepoints) {
+  FeatureProcessorOptionsT options;
+  options.ignored_span_boundary_codepoints.push_back('.');
+  options.ignored_span_boundary_codepoints.push_back(',');
+  options.ignored_span_boundary_codepoints.push_back('[');
+  options.ignored_span_boundary_codepoints.push_back(']');
+
+  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+  TestingFeatureProcessor feature_processor(
+      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+      &unilib_);
+
+  // No ignored codepoints at either boundary.
+  const std::string text1_utf8 = "ěščř";
+  const UnicodeText text1 = UTF8ToUnicodeText(text1_utf8, /*do_copy=*/false);
+  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+                text1.begin(), text1.end(),
+                /*count_from_beginning=*/true),
+            0);
+  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+                text1.begin(), text1.end(),
+                /*count_from_beginning=*/false),
+            0);
+
+  // Two ignored codepoints at the beginning only.
+  const std::string text2_utf8 = ".,abčd";
+  const UnicodeText text2 = UTF8ToUnicodeText(text2_utf8, /*do_copy=*/false);
+  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+                text2.begin(), text2.end(),
+                /*count_from_beginning=*/true),
+            2);
+  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+                text2.begin(), text2.end(),
+                /*count_from_beginning=*/false),
+            0);
+
+  // Ignored codepoints at both boundaries.
+  const std::string text3_utf8 = ".,abčd[]";
+  const UnicodeText text3 = UTF8ToUnicodeText(text3_utf8, /*do_copy=*/false);
+  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+                text3.begin(), text3.end(),
+                /*count_from_beginning=*/true),
+            2);
+  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+                text3.begin(), text3.end(),
+                /*count_from_beginning=*/false),
+            2);
+
+  const std::string text4_utf8 = "[abčd]";
+  const UnicodeText text4 = UTF8ToUnicodeText(text4_utf8, /*do_copy=*/false);
+  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+                text4.begin(), text4.end(),
+                /*count_from_beginning=*/true),
+            1);
+  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+                text4.begin(), text4.end(),
+                /*count_from_beginning=*/false),
+            1);
+
+  // Empty range counts zero from either side.
+  const std::string text5_utf8 = "";
+  const UnicodeText text5 = UTF8ToUnicodeText(text5_utf8, /*do_copy=*/false);
+  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+                text5.begin(), text5.end(),
+                /*count_from_beginning=*/true),
+            0);
+  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+                text5.begin(), text5.end(),
+                /*count_from_beginning=*/false),
+            0);
+
+  // Iterator ranges that start mid-string are respected.
+  const std::string text6_utf8 = "012345ěščř";
+  const UnicodeText text6 = UTF8ToUnicodeText(text6_utf8, /*do_copy=*/false);
+  UnicodeText::const_iterator text6_begin = text6.begin();
+  std::advance(text6_begin, 6);
+  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+                text6_begin, text6.end(),
+                /*count_from_beginning=*/true),
+            0);
+  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+                text6_begin, text6.end(),
+                /*count_from_beginning=*/false),
+            0);
+
+  const std::string text7_utf8 = "012345.,ěščř";
+  const UnicodeText text7 = UTF8ToUnicodeText(text7_utf8, /*do_copy=*/false);
+  UnicodeText::const_iterator text7_begin = text7.begin();
+  std::advance(text7_begin, 6);
+  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+                text7_begin, text7.end(),
+                /*count_from_beginning=*/true),
+            2);
+  UnicodeText::const_iterator text7_end = text7.begin();
+  std::advance(text7_end, 8);
+  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+                text7.begin(), text7_end,
+                /*count_from_beginning=*/false),
+            2);
+
+  // Test not stripping.
+  EXPECT_EQ(feature_processor.StripBoundaryCodepoints(
+                "Hello [[[Wořld]] or not?", {0, 24}),
+            std::make_pair(0, 24));
+  // Test basic stripping.
+  EXPECT_EQ(feature_processor.StripBoundaryCodepoints(
+                "Hello [[[Wořld]] or not?", {6, 16}),
+            std::make_pair(9, 14));
+  // Test stripping when everything is stripped.
+  EXPECT_EQ(
+      feature_processor.StripBoundaryCodepoints("Hello [[[]] or not?", {6, 11}),
+      std::make_pair(6, 6));
+  // Test stripping empty string.
+  EXPECT_EQ(feature_processor.StripBoundaryCodepoints("", {0, 0}),
+            std::make_pair(0, 0));
+}
+
+// Verifies CodepointSpanToTokenSpan: by default only tokens fully covered by
+// the codepoint span are included; with snap_boundaries_to_containing_tokens
+// set (third argument true), partially covered tokens are included as well.
+TEST_F(FeatureProcessorTest, CodepointSpanToTokenSpan) {
+  const std::vector<Token> tokens{Token("Hělló", 0, 5),
+                                  Token("fěěbař@google.com", 6, 23),
+                                  Token("heře!", 24, 29)};
+
+  // Spans matching the tokens exactly.
+  EXPECT_EQ(TokenSpan(0, 1), CodepointSpanToTokenSpan(tokens, {0, 5}));
+  EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {6, 23}));
+  EXPECT_EQ(TokenSpan(2, 3), CodepointSpanToTokenSpan(tokens, {24, 29}));
+  EXPECT_EQ(TokenSpan(0, 2), CodepointSpanToTokenSpan(tokens, {0, 23}));
+  EXPECT_EQ(TokenSpan(1, 3), CodepointSpanToTokenSpan(tokens, {6, 29}));
+  EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {0, 29}));
+
+  // Snapping to containing tokens has no effect.
+  EXPECT_EQ(TokenSpan(0, 1), CodepointSpanToTokenSpan(tokens, {0, 5}, true));
+  EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {6, 23}, true));
+  EXPECT_EQ(TokenSpan(2, 3), CodepointSpanToTokenSpan(tokens, {24, 29}, true));
+  EXPECT_EQ(TokenSpan(0, 2), CodepointSpanToTokenSpan(tokens, {0, 23}, true));
+  EXPECT_EQ(TokenSpan(1, 3), CodepointSpanToTokenSpan(tokens, {6, 29}, true));
+  EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {0, 29}, true));
+
+  // Span boundaries inside tokens.
+  EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {1, 28}));
+  EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {1, 28}, true));
+
+  // Tokens adjacent to the span, but not overlapping.
+  EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {5, 24}));
+  EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {5, 24}, true));
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/annotator/knowledge/knowledge-engine-dummy.h b/annotator/knowledge/knowledge-engine-dummy.h
new file mode 100644
index 0000000..a6285dc
--- /dev/null
+++ b/annotator/knowledge/knowledge-engine-dummy.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_DUMMY_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_DUMMY_H_
+
+#include <string>
+
+#include "annotator/types.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+// A dummy implementation of the knowledge engine.
+// Used on build configurations where the real engine is not available; it
+// mirrors the real engine's interface but never produces results.
+class KnowledgeEngine {
+ public:
+  // The UniLib instance is unused by the dummy implementation.
+  explicit KnowledgeEngine(const UniLib* unilib) {}
+
+  // Always succeeds; the serialized config is ignored.
+  bool Initialize(const std::string& serialized_config) { return true; }
+
+  // Always reports "no classification" by returning false;
+  // |classification_result| is left untouched.
+  bool ClassifyText(const std::string& context, CodepointSpan selection_indices,
+                    ClassificationResult* classification_result) const {
+    return false;
+  }
+
+  // Always succeeds without appending any annotated spans to |result|.
+  bool Chunk(const std::string& context,
+             std::vector<AnnotatedSpan>* result) const {
+    return true;
+  }
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_DUMMY_H_
diff --git a/annotator/knowledge/knowledge-engine.h b/annotator/knowledge/knowledge-engine.h
new file mode 100644
index 0000000..4776b26
--- /dev/null
+++ b/annotator/knowledge/knowledge-engine.h
@@ -0,0 +1,22 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_H_
+
+#include "annotator/knowledge/knowledge-engine-dummy.h"
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_H_
diff --git a/annotator/model-executor.cc b/annotator/model-executor.cc
new file mode 100644
index 0000000..7c57e8f
--- /dev/null
+++ b/annotator/model-executor.cc
@@ -0,0 +1,124 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/model-executor.h"
+
+#include "annotator/quantization.h"
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+
+// Runs the TFLite model on the given feature tensor and returns a view of
+// the logits output. Returns TensorView<float>::Invalid() if |interpreter|
+// is null or if any interpreter step (resize, allocation, invocation) fails.
+TensorView<float> ModelExecutor::ComputeLogits(
+    const TensorView<float>& features, tflite::Interpreter* interpreter) const {
+  if (!interpreter) {
+    return TensorView<float>::Invalid();
+  }
+  // ResizeInputTensor returns a TfLiteStatus; a failed resize must not fall
+  // through to allocation with a stale input shape.
+  if (interpreter->ResizeInputTensor(kInputIndexFeatures, features.shape()) !=
+      kTfLiteOk) {
+    TC3_VLOG(1) << "Resize failed.";
+    return TensorView<float>::Invalid();
+  }
+  if (interpreter->AllocateTensors() != kTfLiteOk) {
+    TC3_VLOG(1) << "Allocation failed.";
+    return TensorView<float>::Invalid();
+  }
+
+  SetInput<float>(kInputIndexFeatures, features, interpreter);
+
+  if (interpreter->Invoke() != kTfLiteOk) {
+    TC3_VLOG(1) << "Interpreter failed.";
+    return TensorView<float>::Invalid();
+  }
+
+  return OutputView<float>(kOutputIndexLogits, interpreter);
+}
+
+// Loads the quantized embedding table from a serialized TFLite model and
+// validates its layout against |embedding_size| and |quantization_bits|.
+// Returns nullptr on any load or validation failure.
+std::unique_ptr<TFLiteEmbeddingExecutor> TFLiteEmbeddingExecutor::FromBuffer(
+    const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size,
+    int quantization_bits) {
+  std::unique_ptr<TfLiteModelExecutor> executor =
+      TfLiteModelExecutor::FromBuffer(model_spec_buffer);
+  if (!executor) {
+    TC3_LOG(ERROR) << "Could not load TFLite model for embeddings.";
+    return nullptr;
+  }
+
+  std::unique_ptr<tflite::Interpreter> interpreter =
+      executor->CreateInterpreter();
+  if (!interpreter) {
+    TC3_LOG(ERROR) << "Could not build TFLite interpreter for embeddings.";
+    return nullptr;
+  }
+
+  // The embedding model is expected to hold exactly two tensors: tensor 0
+  // with the quantized embeddings ([num_buckets, bytes_per_embedding]) and
+  // tensor 1 with the per-bucket scales ([num_buckets, 1]).
+  if (interpreter->tensors_size() != 2) {
+    return nullptr;
+  }
+  const TfLiteTensor* embeddings = interpreter->tensor(0);
+  if (embeddings->dims->size != 2) {
+    return nullptr;
+  }
+  int num_buckets = embeddings->dims->data[0];
+  const TfLiteTensor* scales = interpreter->tensor(1);
+  if (scales->dims->size != 2 || scales->dims->data[0] != num_buckets ||
+      scales->dims->data[1] != 1) {
+    return nullptr;
+  }
+  int bytes_per_embedding = embeddings->dims->data[1];
+  if (!CheckQuantizationParams(bytes_per_embedding, quantization_bits,
+                               embedding_size)) {
+    TC3_LOG(ERROR) << "Mismatch in quantization parameters.";
+    return nullptr;
+  }
+
+  return std::unique_ptr<TFLiteEmbeddingExecutor>(new TFLiteEmbeddingExecutor(
+      std::move(executor), quantization_bits, num_buckets, bytes_per_embedding,
+      embedding_size, scales, embeddings, std::move(interpreter)));
+}
+
+// Takes ownership of the executor and interpreter; the interpreter is kept
+// alive as read-only storage backing the |scales_| and |embeddings_| tensor
+// pointers used by AddEmbedding.
+TFLiteEmbeddingExecutor::TFLiteEmbeddingExecutor(
+    std::unique_ptr<TfLiteModelExecutor> executor, int quantization_bits,
+    int num_buckets, int bytes_per_embedding, int output_embedding_size,
+    const TfLiteTensor* scales, const TfLiteTensor* embeddings,
+    std::unique_ptr<tflite::Interpreter> interpreter)
+    : executor_(std::move(executor)),
+      quantization_bits_(quantization_bits),
+      num_buckets_(num_buckets),
+      bytes_per_embedding_(bytes_per_embedding),
+      output_embedding_size_(output_embedding_size),
+      scales_(scales),
+      embeddings_(embeddings),
+      interpreter_(std::move(interpreter)) {}
+
+// Dequantizes the embedding row of every sparse feature bucket and adds it
+// element-wise into |dest| (length |dest_size|). Returns false if the
+// destination size does not match the configured output embedding size, if a
+// bucket id lies outside the embedding table, or if dequantization fails.
+bool TFLiteEmbeddingExecutor::AddEmbedding(
+    const TensorView<int>& sparse_features, float* dest, int dest_size) const {
+  if (dest_size != output_embedding_size_) {
+    TC3_LOG(ERROR) << "Mismatching dest_size and output_embedding_size: "
+                   << dest_size << " " << output_embedding_size_;
+    return false;
+  }
+  const int num_sparse_features = sparse_features.size();
+  for (int i = 0; i < num_sparse_features; ++i) {
+    const int bucket_id = sparse_features.data()[i];
+    // Reject out-of-range bucket ids, including negative ones, which would
+    // otherwise index outside the embedding table in DequantizeAdd.
+    if (bucket_id < 0 || bucket_id >= num_buckets_) {
+      return false;
+    }
+
+    if (!DequantizeAdd(scales_->data.f, embeddings_->data.uint8,
+                       bytes_per_embedding_, num_sparse_features,
+                       quantization_bits_, bucket_id, dest, dest_size)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/annotator/model-executor.h b/annotator/model-executor.h
new file mode 100644
index 0000000..5ad3a7f
--- /dev/null
+++ b/annotator/model-executor.h
@@ -0,0 +1,111 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Contains classes that can execute different models/parts of a model.
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_
+
+#include <memory>
+
+#include "annotator/types.h"
+#include "utils/base/logging.h"
+#include "utils/tensor-view.h"
+#include "utils/tflite-model-executor.h"
+
+namespace libtextclassifier3 {
+
+// Executor for the text selection prediction and classification models.
+class ModelExecutor : public TfLiteModelExecutor {
+ public:
+  // Creates an executor from an in-memory TFLite model spec; returns nullptr
+  // if the model cannot be loaded.
+  static std::unique_ptr<ModelExecutor> FromModelSpec(
+      const tflite::Model* model_spec) {
+    auto model = TfLiteModelFromModelSpec(model_spec);
+    if (!model) {
+      return nullptr;
+    }
+    return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model)));
+  }
+
+  // Creates an executor from a serialized model buffer; returns nullptr if
+  // the buffer cannot be parsed into a model.
+  static std::unique_ptr<ModelExecutor> FromBuffer(
+      const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
+    auto model = TfLiteModelFromBuffer(model_spec_buffer);
+    if (!model) {
+      return nullptr;
+    }
+    return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model)));
+  }
+
+  // Runs the model on |features| using |interpreter| and returns a view of
+  // the logits output (invalid view on failure).
+  TensorView<float> ComputeLogits(const TensorView<float>& features,
+                                  tflite::Interpreter* interpreter) const;
+
+ protected:
+  explicit ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model)
+      : TfLiteModelExecutor(std::move(model)) {}
+
+  // Tensor indices of the model's feature input and logits output.
+  static const int kInputIndexFeatures = 0;
+  static const int kOutputIndexLogits = 0;
+};
+
+// Executor for embedding sparse features into a dense vector.
+// Abstract interface; see TFLiteEmbeddingExecutor for the TFLite-backed
+// implementation.
+class EmbeddingExecutor {
+ public:
+  virtual ~EmbeddingExecutor() {}
+
+  // Embeds the sparse_features into a dense embedding and adds (+) it
+  // element-wise to the dest vector.
+  virtual bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
+                            int dest_size) const = 0;
+
+  // Returns true when the model is ready to be used, false otherwise.
+  virtual bool IsReady() const { return true; }
+};
+
+// EmbeddingExecutor backed by a quantized embedding table stored in a TFLite
+// model (see FromBuffer for the expected tensor layout).
+class TFLiteEmbeddingExecutor : public EmbeddingExecutor {
+ public:
+  // Loads and validates the embedding model from a serialized buffer;
+  // returns nullptr on load or validation failure.
+  static std::unique_ptr<TFLiteEmbeddingExecutor> FromBuffer(
+      const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size,
+      int quantization_bits);
+
+  // Embeds the sparse_features into a dense embedding and adds (+) it
+  // element-wise to the dest vector.
+  // NOTE: marked override — this implements EmbeddingExecutor::AddEmbedding,
+  // and the keyword lets the compiler verify the signatures stay in sync.
+  bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
+                    int dest_size) const override;
+
+ protected:
+  explicit TFLiteEmbeddingExecutor(
+      std::unique_ptr<TfLiteModelExecutor> executor, int quantization_bits,
+      int num_buckets, int bytes_per_embedding, int output_embedding_size,
+      const TfLiteTensor* scales, const TfLiteTensor* embeddings,
+      std::unique_ptr<tflite::Interpreter> interpreter);
+
+  std::unique_ptr<TfLiteModelExecutor> executor_;
+
+  // Quantization/layout parameters validated by FromBuffer.
+  int quantization_bits_;
+  int num_buckets_ = -1;
+  int bytes_per_embedding_ = -1;
+  int output_embedding_size_ = -1;
+  // Tensor views owned by |interpreter_|.
+  const TfLiteTensor* scales_ = nullptr;
+  const TfLiteTensor* embeddings_ = nullptr;
+
+  // NOTE: This interpreter is used in a read-only way (as a storage for the
+  // model params), thus is still thread-safe.
+  std::unique_ptr<tflite::Interpreter> interpreter_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_
diff --git a/annotator/model.fbs b/annotator/model.fbs
new file mode 100755
index 0000000..3682994
--- /dev/null
+++ b/annotator/model.fbs
@@ -0,0 +1,583 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+include "utils/intents/intent-config.fbs";
+include "utils/zlib/buffer.fbs";
+
+file_identifier "TC2 ";
+
// The possible model modes, represents a bit field.
namespace libtextclassifier3;
enum ModeFlag : int {
  NONE = 0,
  ANNOTATION = 1,
  CLASSIFICATION = 2,
  // ANNOTATION | CLASSIFICATION.
  ANNOTATION_AND_CLASSIFICATION = 3,
  SELECTION = 4,
  // ANNOTATION | SELECTION.
  ANNOTATION_AND_SELECTION = 5,
  // CLASSIFICATION | SELECTION.
  CLASSIFICATION_AND_SELECTION = 6,
  // All modes enabled.
  ALL = 7,
}
+
// Identifies the atomic datetime sub-expressions the extractor rules
// recognize; each value names one extractor rule in DatetimeModel.extractors.
namespace libtextclassifier3;
enum DatetimeExtractorType : int {
  UNKNOWN_DATETIME_EXTRACTOR_TYPE = 0,
  // Meridiem markers.
  AM = 1,
  PM = 2,
  // Month names.
  JANUARY = 3,
  FEBRUARY = 4,
  MARCH = 5,
  APRIL = 6,
  MAY = 7,
  JUNE = 8,
  JULY = 9,
  AUGUST = 10,
  SEPTEMBER = 11,
  OCTOBER = 12,
  NOVEMBER = 13,
  DECEMBER = 14,
  // Relative qualifiers and anchors.
  NEXT = 15,
  NEXT_OR_SAME = 16,
  LAST = 17,
  NOW = 18,
  TOMORROW = 19,
  YESTERDAY = 20,
  PAST = 21,
  FUTURE = 22,
  // Singular time units.
  DAY = 23,
  WEEK = 24,
  MONTH = 25,
  YEAR = 26,
  // Day-of-week names.
  MONDAY = 27,
  TUESDAY = 28,
  WEDNESDAY = 29,
  THURSDAY = 30,
  FRIDAY = 31,
  SATURDAY = 32,
  SUNDAY = 33,
  // Plural time units.
  DAYS = 34,
  WEEKS = 35,
  MONTHS = 36,
  HOURS = 37,
  MINUTES = 38,
  SECONDS = 39,
  YEARS = 40,
  // Numbers written as digits, optionally signed.
  DIGITS = 41,
  SIGNEDDIGITS = 42,
  // Spelled-out number words.
  ZERO = 43,
  ONE = 44,
  TWO = 45,
  THREE = 46,
  FOUR = 47,
  FIVE = 48,
  SIX = 49,
  SEVEN = 50,
  EIGHT = 51,
  NINE = 52,
  TEN = 53,
  ELEVEN = 54,
  TWELVE = 55,
  THIRTEEN = 56,
  FOURTEEN = 57,
  FIFTEEN = 58,
  SIXTEEN = 59,
  SEVENTEEN = 60,
  EIGHTEEN = 61,
  NINETEEN = 62,
  TWENTY = 63,
  THIRTY = 64,
  FORTY = 65,
  FIFTY = 66,
  SIXTY = 67,
  SEVENTY = 68,
  EIGHTY = 69,
  NINETY = 70,
  HUNDRED = 71,
  THOUSAND = 72,
}
+
// Semantics of the capturing groups in a datetime regex pattern; the ith
// entry of DatetimeModelPattern_.Regex.groups uses these values.
namespace libtextclassifier3;
enum DatetimeGroupType : int {
  GROUP_UNKNOWN = 0,
  // Group is matched but its content is not parsed.
  GROUP_UNUSED = 1,
  GROUP_YEAR = 2,
  GROUP_MONTH = 3,
  GROUP_DAY = 4,
  GROUP_HOUR = 5,
  GROUP_MINUTE = 6,
  GROUP_SECOND = 7,
  GROUP_AMPM = 8,
  GROUP_RELATIONDISTANCE = 9,
  GROUP_RELATION = 10,
  GROUP_RELATIONTYPE = 11,

  // Dummy groups serve just as an inflator of the selection. E.g. we might want
  // to select more text than was contained in an envelope of all extractor
  // spans.
  GROUP_DUMMY1 = 12,

  GROUP_DUMMY2 = 13,
}
+
// Options for the model that predicts text selection.
namespace libtextclassifier3;
table SelectionModelOptions {
  // If true, before the selection is returned, the unpaired brackets contained
  // in the predicted selection are stripped from the both selection ends.
  // The bracket codepoints are defined in the Unicode standard:
  // http://www.unicode.org/Public/UNIDATA/BidiBrackets.txt
  strip_unpaired_brackets:bool = true;

  // Number of hypothetical click positions on either side of the actual click
  // to consider in order to enforce symmetry.
  symmetry_context_size:int;

  // Number of examples to bundle in one batch for inference.
  batch_size:int = 1024;

  // Whether to always classify a suggested selection or only on demand.
  // Default (false): classification happens only when requested.
  always_classify_suggested_selection:bool = false;
}
+
// Options for the model that classifies a text selection.
namespace libtextclassifier3;
table ClassificationModelOptions {
  // Limits for phone numbers: matches with fewer/more digits are rejected.
  phone_min_num_digits:int = 7;

  phone_max_num_digits:int = 15;

  // Limits for addresses.
  // Minimum token count required to consider a match an address.
  address_min_num_tokens:int;

  // Maximum number of tokens to attempt a classification (-1 is unlimited).
  max_num_tokens:int = -1;
}
+
// Options for post-checks, checksums and verification to apply on a match.
namespace libtextclassifier3;
table VerificationOptions {
  // If true, the match must pass the Luhn checksum (used e.g. for payment
  // card numbers).
  verify_luhn_checksum:bool = false;
}
+
// List of regular expression matchers to check.
namespace libtextclassifier3.RegexModel_;
table Pattern {
  // The name of the collection of a match.
  collection_name:string;

  // The pattern to check.
  // Can specify a single capturing group used as match boundaries.
  pattern:string;

  // The modes for which to apply the patterns.
  enabled_modes:libtextclassifier3.ModeFlag = ALL;

  // The final score to assign to the results of this pattern.
  target_classification_score:float = 1;

  // Priority score used for conflict resolution with the other models.
  priority_score:float = 0;

  // If true, will use an approximate matching implementation implemented
  // using Find() instead of the true Match(). This approximate matching will
  // use the first Find() result and then check that it spans the whole input.
  use_approximate_matching:bool = false;

  // Zlib-compressed form of the pattern (see utils/zlib/buffer.fbs).
  compressed_pattern:libtextclassifier3.CompressedBuffer;

  // Verification to apply on a match.
  verification_options:libtextclassifier3.VerificationOptions;
}
+
// Container for all regex matchers (see RegexModel_.Pattern above).
namespace libtextclassifier3;
table RegexModel {
  patterns:[libtextclassifier3.RegexModel_.Pattern];
}
+
// List of regex patterns.
namespace libtextclassifier3.DatetimeModelPattern_;
table Regex {
  pattern:string;

  // The ith entry specifies the type of the ith capturing group.
  // This is used to decide how the matched content has to be parsed.
  groups:[libtextclassifier3.DatetimeGroupType];

  // Zlib-compressed form of the pattern (see utils/zlib/buffer.fbs).
  compressed_pattern:libtextclassifier3.CompressedBuffer;
}
+
// A group of datetime regexes together with their scoring and locale scope.
namespace libtextclassifier3;
table DatetimeModelPattern {
  // The regexes that belong to this pattern group.
  regexes:[libtextclassifier3.DatetimeModelPattern_.Regex];

  // List of locale indices in DatetimeModel that represent the locales that
  // these patterns should be used for. If empty, can be used for all locales.
  locales:[int];

  // The final score to assign to the results of this pattern.
  target_classification_score:float = 1;

  // Priority score used for conflict resolution with the other models.
  priority_score:float = 0;

  // The modes for which to apply the patterns.
  enabled_modes:libtextclassifier3.ModeFlag = ALL;
}
+
// One extractor rule: the regex used to parse the datetime sub-expression
// identified by `extractor` (see DatetimeExtractorType).
namespace libtextclassifier3;
table DatetimeModelExtractor {
  // Which sub-expression this rule extracts.
  extractor:libtextclassifier3.DatetimeExtractorType;
  pattern:string;
  // Locale indices into DatetimeModel.locales this rule applies to.
  locales:[int];
  // Zlib-compressed form of the pattern (see utils/zlib/buffer.fbs).
  compressed_pattern:libtextclassifier3.CompressedBuffer;
}
+
namespace libtextclassifier3;
table DatetimeModel {
  // List of BCP 47 locale strings representing all locales supported by the
  // model. The individual patterns refer back to them using an index.
  locales:[string];

  // Top-level datetime patterns and the per-group extractor rules.
  patterns:[libtextclassifier3.DatetimeModelPattern];
  extractors:[libtextclassifier3.DatetimeModelExtractor];

  // If true, will use the extractors for determining the match location as
  // opposed to using the location where the global pattern matched.
  use_extractors_for_locating:bool = true;

  // List of locale ids, rules of whose are always run, after the requested
  // ones.
  default_locales:[int];
}
+
// A single named entry of a DatetimeModelLibrary.
namespace libtextclassifier3.DatetimeModelLibrary_;
table Item {
  // Name under which the model is looked up.
  key:string;
  value:libtextclassifier3.DatetimeModel;
}

// A set of named DateTime models.
namespace libtextclassifier3;
table DatetimeModelLibrary {
  models:[libtextclassifier3.DatetimeModelLibrary_.Item];
}
+
// Options controlling the output of the Tensorflow Lite models.
namespace libtextclassifier3;
table ModelTriggeringOptions {
  // Lower bound threshold for filtering annotation model outputs; results
  // scoring below it are dropped.
  min_annotate_confidence:float = 0;

  // The modes for which to enable the models.
  enabled_modes:libtextclassifier3.ModeFlag = ALL;
}
+
// Options controlling the output of the classifier.
namespace libtextclassifier3;
table OutputOptions {
  // Lists of collection names that will be filtered out at the output:
  // - For annotation, the spans of given collection are simply dropped.
  // - For classification, the result is mapped to the class "other".
  // - For selection, the spans of given class are returned as
  //   single-selection.
  filtered_collections_annotation:[string];

  // Same semantics, per API entry point.
  filtered_collections_classification:[string];
  filtered_collections_selection:[string];
}
+
// Root table of the annotator model file (see root_type at the end).
namespace libtextclassifier3;
table Model {
  // Comma-separated list of locales supported by the model as BCP 47 tags.
  locales:string;

  // Version of the model.
  version:int;

  // A name for the model that can be used for e.g. logging.
  name:string;

  // Feature extraction configuration for the two TFLite models below.
  selection_feature_options:libtextclassifier3.FeatureProcessorOptions;
  classification_feature_options:libtextclassifier3.FeatureProcessorOptions;

  // Tensorflow Lite models (serialized flatbuffer model contents).
  selection_model:[ubyte] (force_align: 16);

  classification_model:[ubyte] (force_align: 16);
  embedding_model:[ubyte] (force_align: 16);

  // Options for the different models.
  selection_options:libtextclassifier3.SelectionModelOptions;

  classification_options:libtextclassifier3.ClassificationModelOptions;
  // Rule-based matchers that run alongside the ML models.
  regex_model:libtextclassifier3.RegexModel;
  datetime_model:libtextclassifier3.DatetimeModel;

  // Options controlling the output of the models.
  triggering_options:libtextclassifier3.ModelTriggeringOptions;

  // Global switch that controls if SuggestSelection(), ClassifyText() and
  // Annotate() will run. If a mode is disabled it returns empty/no-op results.
  enabled_modes:libtextclassifier3.ModeFlag = ALL;

  // If true, will snap the selections that consist only of whitespaces to the
  // containing suggested span. Otherwise, no suggestion is proposed, since the
  // selections are not part of any token.
  snap_whitespace_selections:bool = true;

  // Global configuration for the output of SuggestSelection(), ClassifyText()
  // and Annotate().
  output_options:libtextclassifier3.OutputOptions;

  // Configures how Intents should be generated on Android.
  // TODO(smillius): Remove deprecated factory options.
  android_intent_options:libtextclassifier3.AndroidIntentFactoryOptions;

  intent_options:libtextclassifier3.IntentFactoryModel;
}
+
// Role of the codepoints in the range.
namespace libtextclassifier3.TokenizationCodepointRange_;
enum Role : int {
  // Concatenates the codepoint to the current run of codepoints.
  DEFAULT_ROLE = 0,

  // Splits a run of codepoints before the current codepoint.
  SPLIT_BEFORE = 1,

  // Splits a run of codepoints after the current codepoint.
  SPLIT_AFTER = 2,

  // Each codepoint will be a separate token. Good e.g. for Chinese
  // characters.
  TOKEN_SEPARATOR = 3,

  // Discards the codepoint.
  DISCARD_CODEPOINT = 4,

  // Common values:
  // Splits on the characters and discards them. Good e.g. for the space
  // character. (7 == SPLIT_BEFORE | SPLIT_AFTER | DISCARD_CODEPOINT, i.e. the
  // roles combine as a bit field.)
  WHITESPACE_SEPARATOR = 7,
}
+
// Represents a codepoint range [start, end) with its role for tokenization.
namespace libtextclassifier3;
table TokenizationCodepointRange {
  // Start of the range (inclusive) and end (exclusive).
  start:int;
  end:int;
  role:libtextclassifier3.TokenizationCodepointRange_.Role;

  // Integer identifier of the script this range denotes. Negative values are
  // reserved for Tokenizer's internal use.
  script_id:int;
}
+
// Method for selecting the center token.
namespace libtextclassifier3.FeatureProcessorOptions_;
enum CenterTokenSelectionMethod : int {
  // Unspecified; behavior is the feature processor's default.
  DEFAULT_CENTER_TOKEN_METHOD = 0,

  // Use click indices to determine the center token.
  CENTER_TOKEN_FROM_CLICK = 1,

  // Use selection indices to get a token range, and select the middle of it
  // as the center token.
  CENTER_TOKEN_MIDDLE_OF_SELECTION = 2,
}

// Controls the type of tokenization the model will use for the input text.
namespace libtextclassifier3.FeatureProcessorOptions_;
enum TokenizationType : int {
  // Invalid / unspecified value.
  INVALID_TOKENIZATION_TYPE = 0,

  // Use the internal tokenizer for tokenization.
  INTERNAL_TOKENIZER = 1,

  // Use ICU for tokenization.
  ICU = 2,

  // First apply ICU tokenization. Then identify stretches of tokens
  // consisting only of codepoints in internal_tokenizer_codepoint_ranges
  // and re-tokenize them using the internal tokenizer.
  MIXED = 3,
}
+
// Range of codepoints start - end, where end is exclusive.
namespace libtextclassifier3.FeatureProcessorOptions_;
table CodepointRange {
  // Start of the range (inclusive) and end (exclusive).
  start:int;
  end:int;
}
+
// Bounds-sensitive feature extraction configuration.
namespace libtextclassifier3.FeatureProcessorOptions_;
table BoundsSensitiveFeatures {
  // Enables the extraction of bounds-sensitive features, instead of the click
  // context features.
  enabled:bool;

  // The numbers of tokens to extract in specific locations relative to the
  // bounds.
  // Immediately before the span.
  num_tokens_before:int;

  // Inside the span, aligned with the beginning.
  num_tokens_inside_left:int;

  // Inside the span, aligned with the end.
  num_tokens_inside_right:int;

  // Immediately after the span.
  num_tokens_after:int;

  // If true, also extracts the tokens of the entire span and adds up their
  // features forming one "token" to include in the extracted features.
  include_inside_bag:bool;

  // If true, includes the selection length (in the number of tokens) as a
  // feature.
  include_inside_length:bool;

  // If true, for selection, single token spans are not run through the model
  // and their score is assumed to be zero.
  score_single_token_spans_as_zero:bool;
}
+
// Configuration for tokenization and feature extraction feeding the TFLite
// selection/classification models.
namespace libtextclassifier3;
table FeatureProcessorOptions {
  // Number of buckets used for hashing charactergrams.
  num_buckets:int = -1;

  // Size of the embedding.
  embedding_size:int = -1;

  // Number of bits for quantization for embeddings.
  embedding_quantization_bits:int = 8;

  // Context size defines the number of words to the left and to the right of
  // the selected word to be used as context. For example, if context size is
  // N, then we take N words to the left and N words to the right of the
  // selected word as its context.
  context_size:int = -1;

  // Maximum number of words of the context to select in total.
  max_selection_span:int = -1;

  // Orders of charactergrams to extract. E.g., 2 means character bigrams, 3
  // character trigrams etc.
  chargram_orders:[int];

  // Maximum length of a word, in codepoints.
  max_word_length:int = 20;

  // If true, will use the unicode-aware functionality for extracting features.
  unicode_aware_features:bool = false;

  // Whether to extract the token case feature.
  extract_case_feature:bool = false;

  // Whether to extract the selection mask feature.
  extract_selection_mask_feature:bool = false;

  // List of regexps to run over each token. For each regexp, if there is a
  // match, a dense feature of 1.0 is emitted. Otherwise -1.0 is used.
  regexp_feature:[string];

  // Whether to remap all digits to a single number.
  remap_digits:bool = false;

  // Whether to lower-case each token before generating hashgrams.
  lowercase_tokens:bool;

  // If true, the selection classifier output will contain only the selections
  // that are feasible (e.g., those that are shorter than max_selection_span),
  // if false, the output will be a complete cross-product of possible
  // selections to the left and possible selections to the right, including the
  // infeasible ones.
  // NOTE: Exists mainly for compatibility with older models that were trained
  // with the non-reduced output space.
  selection_reduced_output_space:bool = true;

  // Collection names.
  collections:[string];

  // An index of collection in collections to be used if a collection name can't
  // be mapped to an id.
  default_collection:int = -1;

  // If true, will split the input by lines, and only use the line that contains
  // the clicked token.
  only_use_line_with_click:bool = false;

  // If true, will split tokens that contain the selection boundary, at the
  // position of the boundary.
  // E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
  split_tokens_on_selection_boundaries:bool = false;

  // Codepoint ranges that determine how different codepoints are tokenized.
  // The ranges must not overlap.
  tokenization_codepoint_config:[libtextclassifier3.TokenizationCodepointRange];

  // Strategy for picking the center token (see CenterTokenSelectionMethod).
  center_token_selection_method:libtextclassifier3.FeatureProcessorOptions_.CenterTokenSelectionMethod;

  // If true, span boundaries will be snapped to containing tokens and not
  // required to exactly match token boundaries.
  snap_label_span_boundaries_to_containing_tokens:bool;

  // A set of codepoint ranges supported by the model.
  supported_codepoint_ranges:[libtextclassifier3.FeatureProcessorOptions_.CodepointRange];

  // A set of codepoint ranges to use in the mixed tokenization mode to identify
  // stretches of tokens to re-tokenize using the internal tokenizer.
  internal_tokenizer_codepoint_ranges:[libtextclassifier3.FeatureProcessorOptions_.CodepointRange];

  // Minimum ratio of supported codepoints in the input context. If the ratio
  // is lower than this, the feature computation will fail.
  min_supported_codepoint_ratio:float = 0;

  // Used for versioning the format of features the model expects.
  // - feature_version == 0:
  //   For each token the features consist of:
  //    - chargram embeddings
  //    - dense features
  //   Chargram embeddings for tokens are concatenated first together,
  //   and at the end, the dense features for the tokens are concatenated
  //   to it. So the resulting feature vector has two regions.
  feature_version:int = 0;

  // Tokenizer backend selection (see TokenizationType).
  tokenization_type:libtextclassifier3.FeatureProcessorOptions_.TokenizationType = INTERNAL_TOKENIZER;
  // If true, ICU tokenization keeps whitespace tokens in the output.
  icu_preserve_whitespace_tokens:bool = false;

  // List of codepoints that will be stripped from beginning and end of
  // predicted spans.
  ignored_span_boundary_codepoints:[int];

  // Configuration of bounds-sensitive feature extraction.
  bounds_sensitive_features:libtextclassifier3.FeatureProcessorOptions_.BoundsSensitiveFeatures;

  // List of allowed charactergrams. The extracted charactergrams are filtered
  // using this list, and charactergrams that are not present are interpreted as
  // out-of-vocabulary.
  // If no allowed_chargrams are specified, all charactergrams are allowed.
  // The field is typed as bytes type to allow non-UTF8 chargrams.
  allowed_chargrams:[string];

  // If true, tokens will be also split when the codepoint's script_id changes
  // as defined in TokenizationCodepointRange.
  tokenize_on_script_change:bool = false;
}

// The serialized file's root object is a Model.
root_type libtextclassifier3.Model;
diff --git a/annotator/quantization.cc b/annotator/quantization.cc
new file mode 100644
index 0000000..2cf11c5
--- /dev/null
+++ b/annotator/quantization.cc
@@ -0,0 +1,92 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/quantization.h"
+
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+namespace {
// Decodes a single quantized value: centers it around the quantization bias,
// applies the per-bucket multiplier and averages over the number of sparse
// features that contribute to the embedding.
float DequantizeValue(int num_sparse_features, int quantization_bias,
                      float multiplier, int value) {
  const double feature_weight = 1.0 / num_sparse_features;
  return feature_weight * (value - quantization_bias) * multiplier;
}
+
+void DequantizeAdd8bit(const float* scales, const uint8* embeddings,
+ int bytes_per_embedding, const int num_sparse_features,
+ const int bucket_id, float* dest, int dest_size) {
+ static const int kQuantizationBias8bit = 128;
+ const float multiplier = scales[bucket_id];
+ for (int k = 0; k < dest_size; ++k) {
+ dest[k] +=
+ DequantizeValue(num_sparse_features, kQuantizationBias8bit, multiplier,
+ embeddings[bucket_id * bytes_per_embedding + k]);
+ }
+}
+
// General path for 1..7-bit quantization: unpacks quantization_bits-wide
// values from the packed embedding row, dequantizes them and adds (+=) them
// into dest.
void DequantizeAddNBit(const float* scales, const uint8* embeddings,
                       int bytes_per_embedding, int num_sparse_features,
                       int quantization_bits, int bucket_id, float* dest,
                       int dest_size) {
  // Midpoint of the representable range, e.g. 4 for 3-bit values.
  const int quantization_bias = 1 << (quantization_bits - 1);
  const float multiplier = scales[bucket_id];
  for (int i = 0; i < dest_size; ++i) {
    // Bit position of the ith value within the packed row.
    const int bit_offset = i * quantization_bits;
    const int read16_offset = bit_offset / 8;

    uint16 data = embeddings[bucket_id * bytes_per_embedding + read16_offset];
    // If we are not at the end of the embedding row, we can read 2-byte uint16,
    // but if we are, we need to only read uint8.
    if (read16_offset < bytes_per_embedding - 1) {
      data |= embeddings[bucket_id * bytes_per_embedding + read16_offset + 1]
              << 8;
    }
    // Shift the value to the word's low bits and mask off the rest.
    int value = (data >> (bit_offset % 8)) & ((1 << quantization_bits) - 1);
    dest[i] += DequantizeValue(num_sparse_features, quantization_bias,
                               multiplier, value);
  }
}
+} // namespace
+
// Returns true if one embedding row of bytes_per_embedding bytes is large
// enough to hold output_embedding_size values at quantization_bits bits each.
bool CheckQuantizationParams(int bytes_per_embedding, int quantization_bits,
                             int output_embedding_size) {
  const int max_values_per_row = bytes_per_embedding * 8 / quantization_bits;
  return max_values_per_row >= output_embedding_size;
}
+
+bool DequantizeAdd(const float* scales, const uint8* embeddings,
+ int bytes_per_embedding, int num_sparse_features,
+ int quantization_bits, int bucket_id, float* dest,
+ int dest_size) {
+ if (quantization_bits == 8) {
+ DequantizeAdd8bit(scales, embeddings, bytes_per_embedding,
+ num_sparse_features, bucket_id, dest, dest_size);
+ } else if (quantization_bits != 8) {
+ DequantizeAddNBit(scales, embeddings, bytes_per_embedding,
+ num_sparse_features, quantization_bits, bucket_id, dest,
+ dest_size);
+ } else {
+ TC3_LOG(ERROR) << "Unsupported quantization_bits: " << quantization_bits;
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/annotator/quantization.h b/annotator/quantization.h
new file mode 100644
index 0000000..d294f37
--- /dev/null
+++ b/annotator/quantization.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_QUANTIZATION_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_QUANTIZATION_H_
+
+#include "utils/base/integral_types.h"
+
+namespace libtextclassifier3 {
+
// Returns true if the quantization parameters are valid, i.e. if one
// embedding row (bytes_per_embedding bytes) can hold output_embedding_size
// values at quantization_bits bits each.
bool CheckQuantizationParams(int bytes_per_embedding, int quantization_bits,
                             int output_embedding_size);

// Dequantizes embeddings (quantized to 1 to 8 bits) into the floats they
// represent and adds (+=) them element-wise into the dest buffer of dest_size
// floats. The algorithm proceeds by reading 2-byte words from the embedding
// storage to handle well the cases when the quantized value crosses the byte-
// boundary.
bool DequantizeAdd(const float* scales, const uint8* embeddings,
                   int bytes_per_embedding, int num_sparse_features,
                   int quantization_bits, int bucket_id, float* dest,
                   int dest_size);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_QUANTIZATION_H_
diff --git a/annotator/quantization_test.cc b/annotator/quantization_test.cc
new file mode 100644
index 0000000..b995096
--- /dev/null
+++ b/annotator/quantization_test.cc
@@ -0,0 +1,163 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/quantization.h"
+
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using testing::ElementsAreArray;
+using testing::FloatEq;
+using testing::Matcher;
+
+namespace libtextclassifier3 {
+namespace {
+
+Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) {
+ std::vector<Matcher<float>> matchers;
+ for (const float value : values) {
+ matchers.push_back(FloatEq(value));
+ }
+ return ElementsAreArray(matchers);
+}
+
TEST(QuantizationTest, DequantizeAdd8bit) {
  // Three buckets of 4 bytes each; scales hold the per-bucket multiplier.
  std::vector<float> scales{{0.1, 9.0, -7.0}};
  std::vector<uint8> embeddings{{/*0: */ 0x00, 0xFF, 0x09, 0x00,
                                 /*1: */ 0xFF, 0x09, 0x00, 0xFF,
                                 /*2: */ 0x09, 0x00, 0xFF, 0x09}};

  const int quantization_bits = 8;
  const int bytes_per_embedding = 4;
  const int num_sparse_features = 7;
  {
    const int bucket_id = 0;
    std::vector<float> dest(4, 0.0);
    DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
                  num_sparse_features, quantization_bits, bucket_id,
                  dest.data(), dest.size());

    // Each byte is centered by the fixed 8-bit bias (128) and scaled by
    // scales[bucket_id] / num_sparse_features.
    EXPECT_THAT(dest,
                ElementsAreFloat(std::vector<float>{
                    // clang-format off
                    {1.0 / 7 * 0.1 * (0x00 - 128),
                     1.0 / 7 * 0.1 * (0xFF - 128),
                     1.0 / 7 * 0.1 * (0x09 - 128),
                     1.0 / 7 * 0.1 * (0x00 - 128)}
                    // clang-format on
                }));
  }

  {
    const int bucket_id = 1;
    std::vector<float> dest(4, 0.0);
    DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
                  num_sparse_features, quantization_bits, bucket_id,
                  dest.data(), dest.size());

    EXPECT_THAT(dest,
                ElementsAreFloat(std::vector<float>{
                    // clang-format off
                    {1.0 / 7 * 9.0 * (0xFF - 128),
                     1.0 / 7 * 9.0 * (0x09 - 128),
                     1.0 / 7 * 9.0 * (0x00 - 128),
                     1.0 / 7 * 9.0 * (0xFF - 128)}
                    // clang-format on
                }));
  }
}
+
+TEST(QuantizationTest, DequantizeAdd1bitZeros) {
+ const int bytes_per_embedding = 4;
+ const int num_buckets = 3;
+ const int num_sparse_features = 7;
+ const int quantization_bits = 1;
+ const int bucket_id = 1;
+
+ std::vector<float> scales(num_buckets);
+ std::vector<uint8> embeddings(bytes_per_embedding * num_buckets);
+ std::fill(scales.begin(), scales.end(), 1);
+ std::fill(embeddings.begin(), embeddings.end(), 0);
+
+ std::vector<float> dest(32);
+ DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
+ num_sparse_features, quantization_bits, bucket_id, dest.data(),
+ dest.size());
+
+ std::vector<float> expected(32);
+ std::fill(expected.begin(), expected.end(),
+ 1.0 / num_sparse_features * (0 - 1));
+ EXPECT_THAT(dest, ElementsAreFloat(expected));
+}
+
TEST(QuantizationTest, DequantizeAdd1bitOnes) {
  const int bytes_per_embedding = 4;
  const int num_buckets = 3;
  const int num_sparse_features = 7;
  const int quantization_bits = 1;
  const int bucket_id = 1;

  // All bits set: every 1-bit value decodes to (1 - bias) == 0, so the
  // result is zero regardless of the scale.
  std::vector<float> scales(num_buckets, 1.0);
  std::vector<uint8> embeddings(bytes_per_embedding * num_buckets, 0xFF);

  std::vector<float> dest(32);
  DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
                num_sparse_features, quantization_bits, bucket_id, dest.data(),
                dest.size());
  std::vector<float> expected(32);
  std::fill(expected.begin(), expected.end(),
            1.0 / num_sparse_features * (1 - 1));
  EXPECT_THAT(dest, ElementsAreFloat(expected));
}
+
TEST(QuantizationTest, DequantizeAdd3bit) {
  const int bytes_per_embedding = 4;
  const int num_buckets = 3;
  const int num_sparse_features = 7;
  const int quantization_bits = 3;
  const int bucket_id = 1;

  std::vector<float> scales(num_buckets, 1.0);
  scales[1] = 9.0;
  std::vector<uint8> embeddings(bytes_per_embedding * num_buckets, 0);
  // For bucket_id=1, the packed 3-bit little-endian values are 1..7 for
  // indices 0..6; the remaining positions read past the row and decode to 0.
  // Some of the 3-bit values deliberately straddle a byte boundary to
  // exercise the 2-byte read in DequantizeAddNBit.
  embeddings[4] = (1 << 7) | (1 << 6) | (1 << 4) | 1;
  embeddings[5] = (1 << 6) | (1 << 4) | (1 << 3);
  embeddings[6] = (1 << 4) | (1 << 3) | (1 << 2) | (1 << 1) | 1;

  std::vector<float> dest(10);
  DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
                num_sparse_features, quantization_bits, bucket_id, dest.data(),
                dest.size());

  // Expected: (value - bias) * scale / num_sparse_features, bias == 4 for
  // 3-bit quantization.
  std::vector<float> expected;
  expected.push_back(1.0 / num_sparse_features * (1 - 4) * scales[bucket_id]);
  expected.push_back(1.0 / num_sparse_features * (2 - 4) * scales[bucket_id]);
  expected.push_back(1.0 / num_sparse_features * (3 - 4) * scales[bucket_id]);
  expected.push_back(1.0 / num_sparse_features * (4 - 4) * scales[bucket_id]);
  expected.push_back(1.0 / num_sparse_features * (5 - 4) * scales[bucket_id]);
  expected.push_back(1.0 / num_sparse_features * (6 - 4) * scales[bucket_id]);
  expected.push_back(1.0 / num_sparse_features * (7 - 4) * scales[bucket_id]);
  expected.push_back(1.0 / num_sparse_features * (0 - 4) * scales[bucket_id]);
  expected.push_back(1.0 / num_sparse_features * (0 - 4) * scales[bucket_id]);
  expected.push_back(1.0 / num_sparse_features * (0 - 4) * scales[bucket_id]);
  EXPECT_THAT(dest, ElementsAreFloat(expected));
}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/annotator/strip-unpaired-brackets.cc b/annotator/strip-unpaired-brackets.cc
new file mode 100644
index 0000000..b1067ad
--- /dev/null
+++ b/annotator/strip-unpaired-brackets.cc
@@ -0,0 +1,105 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
#include "annotator/strip-unpaired-brackets.h"

#include <algorithm>
#include <iterator>

#include "utils/base/logging.h"
#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3 {
+namespace {
+
// Returns true if given codepoint is contained in the given span in context.
// The span is given in codepoint offsets [span.first, span.second) into
// context_unicode; the lookup is a linear scan over the span.
bool IsCodepointInSpan(const char32 codepoint,
                       const UnicodeText& context_unicode,
                       const CodepointSpan span) {
  auto begin_it = context_unicode.begin();
  std::advance(begin_it, span.first);
  auto end_it = context_unicode.begin();
  std::advance(end_it, span.second);

  return std::find(begin_it, end_it, codepoint) != end_it;
}
+
+// Returns the first codepoint of the span.
+char32 FirstSpanCodepoint(const UnicodeText& context_unicode,
+ const CodepointSpan span) {
+ auto it = context_unicode.begin();
+ std::advance(it, span.first);
+ return *it;
+}
+
+// Returns the last codepoint of the span.
+char32 LastSpanCodepoint(const UnicodeText& context_unicode,
+ const CodepointSpan span) {
+ auto it = context_unicode.begin();
+ std::advance(it, span.second - 1);
+ return *it;
+}
+
+} // namespace
+
+CodepointSpan StripUnpairedBrackets(const std::string& context,
+ CodepointSpan span, const UniLib& unilib) {
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+ return StripUnpairedBrackets(context_unicode, span, unilib);
+}
+
// If the first or the last codepoint of the given span is a bracket, the
// bracket is stripped if the span does not contain its corresponding paired
// version.
CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode,
                                    CodepointSpan span, const UniLib& unilib) {
  // Nothing to strip for an empty context or an invalid/empty span.
  if (context_unicode.empty() || !ValidNonEmptySpan(span)) {
    return span;
  }

  const char32 begin_char = FirstSpanCodepoint(context_unicode, span);
  const char32 paired_begin_char = unilib.GetPairedBracket(begin_char);
  // A differing paired codepoint indicates begin_char is a bracket
  // (presumably GetPairedBracket returns its argument for non-brackets —
  // confirm against unilib).
  if (paired_begin_char != begin_char) {
    // Drop the leading bracket unless it opens a pair that is closed inside
    // the span.
    if (!unilib.IsOpeningBracket(begin_char) ||
        !IsCodepointInSpan(paired_begin_char, context_unicode, span)) {
      ++span.first;
    }
  }

  // The span may have become empty after stripping the leading bracket.
  if (span.first == span.second) {
    return span;
  }

  const char32 end_char = LastSpanCodepoint(context_unicode, span);
  const char32 paired_end_char = unilib.GetPairedBracket(end_char);
  if (paired_end_char != end_char) {
    // Drop the trailing bracket unless it closes a pair opened inside the
    // span.
    if (!unilib.IsClosingBracket(end_char) ||
        !IsCodepointInSpan(paired_end_char, context_unicode, span)) {
      --span.second;
    }
  }

  // Should not happen, but let's make sure.
  if (span.first > span.second) {
    TC3_LOG(WARNING) << "Inverse indices result: " << span.first << ", "
                     << span.second;
    span.second = span.first;
  }

  return span;
}
+
+} // namespace libtextclassifier3
diff --git a/annotator/strip-unpaired-brackets.h b/annotator/strip-unpaired-brackets.h
new file mode 100644
index 0000000..ceb8d60
--- /dev/null
+++ b/annotator/strip-unpaired-brackets.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_STRIP_UNPAIRED_BRACKETS_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_STRIP_UNPAIRED_BRACKETS_H_
+
+#include <string>
+
+#include "annotator/types.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+// If the first or the last codepoint of the given span is a bracket, the
+// bracket is stripped if the span does not contain its corresponding paired
+// version.
+CodepointSpan StripUnpairedBrackets(const std::string& context,
+ CodepointSpan span, const UniLib& unilib);
+
+// Same as above but takes UnicodeText instance directly.
+CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode,
+ CodepointSpan span, const UniLib& unilib);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_STRIP_UNPAIRED_BRACKETS_H_
diff --git a/annotator/strip-unpaired-brackets_test.cc b/annotator/strip-unpaired-brackets_test.cc
new file mode 100644
index 0000000..32585ce
--- /dev/null
+++ b/annotator/strip-unpaired-brackets_test.cc
@@ -0,0 +1,66 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/strip-unpaired-brackets.h"
+
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+// Test fixture that provides a UniLib instance to each test body.
+class StripUnpairedBracketsTest : public ::testing::Test {
+ protected:
+  StripUnpairedBracketsTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
+  UniLib unilib_;
+};
+
+// Exercises bracket stripping on matched pairs, unmatched brackets,
+// length-1 bracket-only spans, and degenerate/invalid spans.
+TEST_F(StripUnpairedBracketsTest, StripUnpairedBrackets) {
+  // If the brackets match, nothing gets stripped.
+  EXPECT_EQ(StripUnpairedBrackets("call me (123) 456 today", {8, 17}, unilib_),
+            std::make_pair(8, 17));
+  EXPECT_EQ(StripUnpairedBrackets("call me (123 456) today", {8, 17}, unilib_),
+            std::make_pair(8, 17));
+
+  // If the brackets don't match, they get stripped.
+  EXPECT_EQ(StripUnpairedBrackets("call me (123 456 today", {8, 16}, unilib_),
+            std::make_pair(9, 16));
+  EXPECT_EQ(StripUnpairedBrackets("call me )123 456 today", {8, 16}, unilib_),
+            std::make_pair(9, 16));
+  EXPECT_EQ(StripUnpairedBrackets("call me 123 456) today", {8, 16}, unilib_),
+            std::make_pair(8, 15));
+  EXPECT_EQ(StripUnpairedBrackets("call me 123 456( today", {8, 16}, unilib_),
+            std::make_pair(8, 15));
+
+  // Strips brackets correctly from length-1 selections that consist of
+  // a bracket only.
+  EXPECT_EQ(StripUnpairedBrackets("call me at ) today", {11, 12}, unilib_),
+            std::make_pair(12, 12));
+  EXPECT_EQ(StripUnpairedBrackets("call me at ( today", {11, 12}, unilib_),
+            std::make_pair(12, 12));
+
+  // Handles invalid spans gracefully.
+  EXPECT_EQ(StripUnpairedBrackets("call me at today", {11, 11}, unilib_),
+            std::make_pair(11, 11));
+  EXPECT_EQ(StripUnpairedBrackets("hello world", {0, 0}, unilib_),
+            std::make_pair(0, 0));
+  EXPECT_EQ(StripUnpairedBrackets("hello world", {11, 11}, unilib_),
+            std::make_pair(11, 11));
+  EXPECT_EQ(StripUnpairedBrackets("hello world", {-1, -1}, unilib_),
+            std::make_pair(-1, -1));
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/annotator/test_data/test_model.fb b/annotator/test_data/test_model.fb
new file mode 100644
index 0000000..fa9cec5
--- /dev/null
+++ b/annotator/test_data/test_model.fb
Binary files differ
diff --git a/annotator/test_data/test_model_cc.fb b/annotator/test_data/test_model_cc.fb
new file mode 100644
index 0000000..b73d84f
--- /dev/null
+++ b/annotator/test_data/test_model_cc.fb
Binary files differ
diff --git a/annotator/test_data/wrong_embeddings.fb b/annotator/test_data/wrong_embeddings.fb
new file mode 100644
index 0000000..ba71cdd
--- /dev/null
+++ b/annotator/test_data/wrong_embeddings.fb
Binary files differ
diff --git a/annotator/token-feature-extractor.cc b/annotator/token-feature-extractor.cc
new file mode 100644
index 0000000..77ad7a4
--- /dev/null
+++ b/annotator/token-feature-extractor.cc
@@ -0,0 +1,311 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/token-feature-extractor.h"
+
+#include <cctype>
+#include <string>
+
+#include "utils/base/logging.h"
+#include "utils/hash/farmhash.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3 {
+
+namespace {
+
+// Returns a copy of |token| with ASCII digits remapped to '0' and/or all
+// bytes lowercased, depending on |options|.  Returns |token| unchanged when
+// neither option is set.
+std::string RemapTokenAscii(const std::string& token,
+                            const TokenFeatureExtractorOptions& options) {
+  if (!options.remap_digits && !options.lowercase_tokens) {
+    return token;
+  }
+
+  std::string copy = token;
+  for (char& c : copy) {
+    // Cast through unsigned char: passing a negative char (e.g. a UTF-8
+    // continuation byte) to isdigit()/tolower() is undefined behavior.
+    const unsigned char uc = static_cast<unsigned char>(c);
+    if (options.remap_digits && isdigit(uc)) {
+      c = '0';
+      // tolower('0') == '0', so skipping the lowercase step is a no-op.
+      continue;
+    }
+    if (options.lowercase_tokens) {
+      c = static_cast<char>(tolower(uc));
+    }
+  }
+  return copy;
+}
+
+// Writes into |remapped| a copy of |token| with digits remapped to '0'
+// and/or codepoints lowercased, per |options|.  If neither option is set,
+// |remapped| is left untouched.
+void RemapTokenUnicode(const std::string& token,
+                       const TokenFeatureExtractorOptions& options,
+                       const UniLib& unilib, UnicodeText* remapped) {
+  if (!options.remap_digits && !options.lowercase_tokens) {
+    // Leave remapped untouched.
+    return;
+  }
+
+  const UnicodeText source = UTF8ToUnicodeText(token, /*do_copy=*/false);
+  remapped->clear();
+  for (const char32 codepoint : source) {
+    if (options.remap_digits && unilib.IsDigit(codepoint)) {
+      remapped->push_back('0');
+    } else if (options.lowercase_tokens) {
+      remapped->push_back(unilib.ToLower(codepoint));
+    } else {
+      remapped->push_back(codepoint);
+    }
+  }
+}
+
+} // namespace
+
+// Compiles every configured regexp feature up front; a pattern that fails
+// to compile is stored as null and handled at extraction time.
+TokenFeatureExtractor::TokenFeatureExtractor(
+    const TokenFeatureExtractorOptions& options, const UniLib& unilib)
+    : options_(options), unilib_(unilib) {
+  regex_patterns_.reserve(options.regexp_features.size());
+  for (const std::string& pattern : options.regexp_features) {
+    const UnicodeText pattern_unicode =
+        UTF8ToUnicodeText(pattern.c_str(), pattern.size(), /*do_copy=*/false);
+    regex_patterns_.emplace_back(unilib_.CreateRegexPattern(pattern_unicode));
+  }
+}
+
+// Extracts sparse and dense features for one token.  The dense output is
+// mandatory (returns false on a null pointer); the sparse output is
+// optional and skipped when null.
+bool TokenFeatureExtractor::Extract(const Token& token, bool is_in_span,
+                                    std::vector<int>* sparse_features,
+                                    std::vector<float>* dense_features) const {
+  if (dense_features == nullptr) {
+    return false;
+  }
+  if (sparse_features != nullptr) {
+    *sparse_features = ExtractCharactergramFeatures(token);
+  }
+  *dense_features = ExtractDenseFeatures(token, is_in_span);
+  return true;
+}
+
+// Dispatches to the unicode-aware or the ASCII-only charactergram
+// implementation according to the configured options.
+std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeatures(
+    const Token& token) const {
+  return options_.unicode_aware_features
+             ? ExtractCharactergramFeaturesUnicode(token)
+             : ExtractCharactergramFeaturesAscii(token);
+}
+
+// Extracts the dense feature vector for one token: an optional case
+// feature, an optional selection-mask feature, and one feature per
+// configured regexp pattern, in that order.
+std::vector<float> TokenFeatureExtractor::ExtractDenseFeatures(
+    const Token& token, bool is_in_span) const {
+  std::vector<float> dense_features;
+
+  if (options_.extract_case_feature) {
+    // Check emptiness before touching the first character: the previous
+    // code dereferenced *token_unicode.begin() unconditionally, which is
+    // undefined for an empty token.  Empty tokens count as not-uppercase.
+    bool is_upper = false;
+    if (!token.value.empty()) {
+      if (options_.unicode_aware_features) {
+        const UnicodeText token_unicode =
+            UTF8ToUnicodeText(token.value, /*do_copy=*/false);
+        is_upper = unilib_.IsUpper(*token_unicode.begin());
+      } else {
+        // Cast through unsigned char: isupper() on a negative char is UB.
+        is_upper =
+            isupper(static_cast<unsigned char>(*token.value.begin())) != 0;
+      }
+    }
+    dense_features.push_back(is_upper ? 1.0 : -1.0);
+  }
+
+  if (options_.extract_selection_mask_feature) {
+    if (is_in_span) {
+      dense_features.push_back(1.0);
+    } else {
+      // The two model variants use different out-of-span values
+      // (-1.0 unicode vs. 0.0 ascii); preserved as-is.
+      dense_features.push_back(options_.unicode_aware_features ? -1.0 : 0.0);
+    }
+  }
+
+  // Add regexp features.
+  if (!regex_patterns_.empty()) {
+    const UnicodeText token_unicode =
+        UTF8ToUnicodeText(token.value, /*do_copy=*/false);
+    for (const auto& pattern : regex_patterns_) {
+      if (pattern == nullptr) {
+        // Pattern failed to compile; emit the "no match" value.
+        dense_features.push_back(-1.0);
+        continue;
+      }
+      auto matcher = pattern->Matcher(token_unicode);
+      int status;
+      dense_features.push_back(matcher->Matches(&status) ? 1.0 : -1.0);
+    }
+  }
+
+  return dense_features;
+}
+
+// Hashes a charactergram into [0, num_buckets).  When a chargram allowlist
+// is configured, bucket 1 is reserved for the padding token, bucket 0 for
+// out-of-vocabulary chargrams, and everything else hashes into the
+// remaining buckets.
+int TokenFeatureExtractor::HashToken(StringPiece token) const {
+  if (options_.allowed_chargrams.empty()) {
+    return tc3farmhash::Fingerprint64(token) % options_.num_buckets;
+  }
+  // Padding and out-of-vocabulary tokens have extra buckets reserved because
+  // they are special and important tokens, and we don't want them to share
+  // embedding with other charactergrams.
+  // TODO(zilka): Experimentally verify.
+  constexpr int kNumExtraBuckets = 2;
+  const std::string token_string = token.ToString();
+  if (token_string == "<PAD>") {
+    return 1;
+  }
+  if (options_.allowed_chargrams.count(token_string) == 0) {
+    return 0;  // Out-of-vocabulary.
+  }
+  return (tc3farmhash::Fingerprint64(token) %
+          (options_.num_buckets - kNumExtraBuckets)) +
+         kNumExtraBuckets;
+}
+
+// Extracts hashed charactergrams from the byte string of the token.
+// Padding/empty tokens map to the single "<PAD>" feature; otherwise the
+// (remapped, length-limited) word is anchored with '^'/'$' and n-grams of
+// each configured order are hashed.
+std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesAscii(
+    const Token& token) const {
+  std::vector<int> result;
+  if (token.is_padding || token.value.empty()) {
+    result.push_back(HashToken("<PAD>"));
+  } else {
+    const std::string word = RemapTokenAscii(token.value, options_);
+
+    // Trim words that are over max_word_length characters.
+    // NOTE(review): size_t vs. int comparison below — benign while
+    // max_word_length is non-negative.
+    const int max_word_length = options_.max_word_length;
+    std::string feature_word;
+    if (word.size() > max_word_length) {
+      // Keep the first and last max_word_length/2 bytes, joined by '\1'.
+      feature_word =
+          "^" + word.substr(0, max_word_length / 2) + "\1" +
+          word.substr(word.size() - max_word_length / 2, max_word_length / 2) +
+          "$";
+    } else {
+      // Add a prefix and suffix to the word.
+      feature_word = "^" + word + "$";
+    }
+
+    // Upper-bound the number of charactergram extracted to avoid resizing.
+    result.reserve(options_.chargram_orders.size() * feature_word.size());
+
+    if (options_.chargram_orders.empty()) {
+      // No orders configured: hash the whole anchored word once.
+      result.push_back(HashToken(feature_word));
+    } else {
+      // Generate the character-grams.
+      for (int chargram_order : options_.chargram_orders) {
+        if (chargram_order == 1) {
+          // Unigrams skip the '^'/'$' anchor characters.
+          for (int i = 1; i < feature_word.size() - 1; ++i) {
+            result.push_back(
+                HashToken(StringPiece(feature_word, /*offset=*/i, /*len=*/1)));
+          }
+        } else {
+          // Higher orders slide a window over the anchored word, anchors
+          // included.
+          for (int i = 0;
+               i < static_cast<int>(feature_word.size()) - chargram_order + 1;
+               ++i) {
+            result.push_back(HashToken(StringPiece(feature_word, /*offset=*/i,
+                                                   /*len=*/chargram_order)));
+          }
+        }
+      }
+    }
+  }
+  return result;
+}
+
+// Unicode-aware variant of ExtractCharactergramFeaturesAscii: operates on
+// codepoints rather than bytes, so length limiting and n-gram windows never
+// split a multi-byte character.
+std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesUnicode(
+    const Token& token) const {
+  std::vector<int> result;
+  if (token.is_padding || token.value.empty()) {
+    result.push_back(HashToken("<PAD>"));
+  } else {
+    UnicodeText word = UTF8ToUnicodeText(token.value, /*do_copy=*/false);
+    RemapTokenUnicode(token.value, options_, unilib_, &word);
+
+    // Trim the word if needed by finding a left-cut point and right-cut point.
+    // Each cut point moves up to max_word_length/2 codepoints inward; if the
+    // word is short enough the iterators meet and nothing is trimmed.
+    auto left_cut = word.begin();
+    auto right_cut = word.end();
+    for (int i = 0; i < options_.max_word_length / 2; i++) {
+      if (left_cut < right_cut) {
+        ++left_cut;
+      }
+      if (left_cut < right_cut) {
+        --right_cut;
+      }
+    }
+
+    std::string feature_word;
+    if (left_cut == right_cut) {
+      // Cut points met: the whole word fits, keep it with '^'/'$' anchors.
+      feature_word = "^" + word.UTF8Substring(word.begin(), word.end()) + "$";
+    } else {
+      // Keep the prefix and suffix around the cut points, joined by '\1'.
+      // clang-format off
+      feature_word = "^" +
+                     word.UTF8Substring(word.begin(), left_cut) +
+                     "\1" +
+                     word.UTF8Substring(right_cut, word.end()) +
+                     "$";
+      // clang-format on
+    }
+
+    const UnicodeText feature_word_unicode =
+        UTF8ToUnicodeText(feature_word, /*do_copy=*/false);
+
+    // Upper-bound the number of charactergram extracted to avoid resizing.
+    result.reserve(options_.chargram_orders.size() * feature_word.size());
+
+    if (options_.chargram_orders.empty()) {
+      result.push_back(HashToken(feature_word));
+    } else {
+      // Generate the character-grams.
+      for (int chargram_order : options_.chargram_orders) {
+        UnicodeText::const_iterator it_start = feature_word_unicode.begin();
+        UnicodeText::const_iterator it_end = feature_word_unicode.end();
+        if (chargram_order == 1) {
+          // Unigrams skip the '^'/'$' anchor characters.
+          ++it_start;
+          --it_end;
+        }
+
+        // Position the initial window; if the word has fewer than
+        // chargram_order codepoints, skip this order entirely.
+        UnicodeText::const_iterator it_chargram_start = it_start;
+        UnicodeText::const_iterator it_chargram_end = it_start;
+        bool chargram_is_complete = true;
+        for (int i = 0; i < chargram_order; ++i) {
+          if (it_chargram_end == it_end) {
+            chargram_is_complete = false;
+            break;
+          }
+          ++it_chargram_end;
+        }
+        if (!chargram_is_complete) {
+          continue;
+        }
+
+        // Slide the window one codepoint at a time, hashing the underlying
+        // UTF-8 bytes of each window.
+        for (; it_chargram_end <= it_end;
+             ++it_chargram_start, ++it_chargram_end) {
+          const int length_bytes =
+              it_chargram_end.utf8_data() - it_chargram_start.utf8_data();
+          result.push_back(HashToken(
+              StringPiece(it_chargram_start.utf8_data(), length_bytes)));
+        }
+      }
+    }
+  }
+  return result;
+}
+
+} // namespace libtextclassifier3
diff --git a/annotator/token-feature-extractor.h b/annotator/token-feature-extractor.h
new file mode 100644
index 0000000..7dc19fe
--- /dev/null
+++ b/annotator/token-feature-extractor.h
@@ -0,0 +1,115 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TOKEN_FEATURE_EXTRACTOR_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_TOKEN_FEATURE_EXTRACTOR_H_
+
+#include <memory>
+#include <unordered_set>
+#include <vector>
+
+#include "annotator/types.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+// Configuration for TokenFeatureExtractor: which dense features to emit and
+// how charactergrams are generated and hashed.
+struct TokenFeatureExtractorOptions {
+  // Number of buckets used for hashing charactergrams.
+  int num_buckets = 0;
+
+  // Orders of charactergrams to extract. E.g., 2 means character bigrams, 3
+  // character trigrams etc.
+  std::vector<int> chargram_orders;
+
+  // Whether to extract the token case feature.
+  bool extract_case_feature = false;
+
+  // If true, will use the unicode-aware functionality for extracting features.
+  bool unicode_aware_features = false;
+
+  // Whether to extract the selection mask feature.
+  bool extract_selection_mask_feature = false;
+
+  // Regexp features to extract.
+  std::vector<std::string> regexp_features;
+
+  // Whether to remap digits to a single number.
+  bool remap_digits = false;
+
+  // Whether to lowercase all tokens.
+  bool lowercase_tokens = false;
+
+  // Maximum length of a word.
+  int max_word_length = 20;
+
+  // List of allowed charactergrams. The extracted charactergrams are filtered
+  // using this list, and charactergrams that are not present are interpreted as
+  // out-of-vocabulary.
+  // If no allowed_chargrams are specified, all charactergrams are allowed.
+  std::unordered_set<std::string> allowed_chargrams;
+};
+
+// Extracts sparse (hashed charactergram) and dense (case / selection-mask /
+// regexp-match) features from individual tokens.
+class TokenFeatureExtractor {
+ public:
+  TokenFeatureExtractor(const TokenFeatureExtractorOptions& options,
+                        const UniLib& unilib);
+
+  // Extracts both the sparse (charactergram) and the dense features from a
+  // token. is_in_span is a bool indicator whether the token is a part of the
+  // selection span (true) or not (false).
+  // The sparse_features output is optional. Fails and returns false if
+  // dense_features is a nullptr.
+  bool Extract(const Token& token, bool is_in_span,
+               std::vector<int>* sparse_features,
+               std::vector<float>* dense_features) const;
+
+  // Extracts the sparse (charactergram) features from the token.
+  std::vector<int> ExtractCharactergramFeatures(const Token& token) const;
+
+  // Extracts the dense features from the token. is_in_span is a bool indicator
+  // whether the token is a part of the selection span (true) or not (false).
+  std::vector<float> ExtractDenseFeatures(const Token& token,
+                                          bool is_in_span) const;
+
+  // Number of dense features produced per token under the current options:
+  // one each for the case and selection-mask features (when enabled) plus
+  // one per regexp pattern.
+  int DenseFeaturesCount() const {
+    int feature_count =
+        options_.extract_case_feature + options_.extract_selection_mask_feature;
+    feature_count += regex_patterns_.size();
+    return feature_count;
+  }
+
+ protected:
+  // Hashes given token to given number of buckets.
+  int HashToken(StringPiece token) const;
+
+  // Extracts the charactergram features from the token in a non-unicode-aware
+  // way.
+  std::vector<int> ExtractCharactergramFeaturesAscii(const Token& token) const;
+
+  // Extracts the charactergram features from the token in a unicode-aware way.
+  std::vector<int> ExtractCharactergramFeaturesUnicode(
+      const Token& token) const;
+
+ private:
+  TokenFeatureExtractorOptions options_;
+  std::vector<std::unique_ptr<UniLib::RegexPattern>> regex_patterns_;
+  // Stored as a reference: the UniLib must outlive this extractor.
+  const UniLib& unilib_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_TOKEN_FEATURE_EXTRACTOR_H_
diff --git a/annotator/token-feature-extractor_test.cc b/annotator/token-feature-extractor_test.cc
new file mode 100644
index 0000000..32383a9
--- /dev/null
+++ b/annotator/token-feature-extractor_test.cc
@@ -0,0 +1,556 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/token-feature-extractor.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+// Test fixture that provides a UniLib instance to each test body.
+class TokenFeatureExtractorTest : public ::testing::Test {
+ protected:
+  TokenFeatureExtractorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
+  UniLib unilib_;
+};
+
+// Exposes the protected HashToken so tests can compute expected hash values.
+class TestingTokenFeatureExtractor : public TokenFeatureExtractor {
+ public:
+  using TokenFeatureExtractor::HashToken;
+  using TokenFeatureExtractor::TokenFeatureExtractor;
+};
+
+// Verifies ASCII chargram extraction (orders 1-3) plus case and selection
+// dense features for an in-span and an out-of-span token.
+TEST_F(TokenFeatureExtractorTest, ExtractAscii) {
+  TokenFeatureExtractorOptions options;
+  options.num_buckets = 1000;
+  options.chargram_orders = std::vector<int>{1, 2, 3};
+  options.extract_case_feature = true;
+  options.unicode_aware_features = false;
+  options.extract_selection_mask_feature = true;
+  TestingTokenFeatureExtractor extractor(options, unilib_);
+
+  std::vector<int> sparse_features;
+  std::vector<float> dense_features;
+
+  extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features,
+                    &dense_features);
+
+  EXPECT_THAT(sparse_features,
+              testing::ElementsAreArray({
+                  // clang-format off
+                  extractor.HashToken("H"),
+                  extractor.HashToken("e"),
+                  extractor.HashToken("l"),
+                  extractor.HashToken("l"),
+                  extractor.HashToken("o"),
+                  extractor.HashToken("^H"),
+                  extractor.HashToken("He"),
+                  extractor.HashToken("el"),
+                  extractor.HashToken("ll"),
+                  extractor.HashToken("lo"),
+                  extractor.HashToken("o$"),
+                  extractor.HashToken("^He"),
+                  extractor.HashToken("Hel"),
+                  extractor.HashToken("ell"),
+                  extractor.HashToken("llo"),
+                  extractor.HashToken("lo$")
+                  // clang-format on
+              }));
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
+
+  sparse_features.clear();
+  dense_features.clear();
+  extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
+                    &dense_features);
+
+  EXPECT_THAT(sparse_features,
+              testing::ElementsAreArray({
+                  // clang-format off
+                  extractor.HashToken("w"),
+                  extractor.HashToken("o"),
+                  extractor.HashToken("r"),
+                  extractor.HashToken("l"),
+                  extractor.HashToken("d"),
+                  extractor.HashToken("!"),
+                  extractor.HashToken("^w"),
+                  extractor.HashToken("wo"),
+                  extractor.HashToken("or"),
+                  extractor.HashToken("rl"),
+                  extractor.HashToken("ld"),
+                  extractor.HashToken("d!"),
+                  extractor.HashToken("!$"),
+                  extractor.HashToken("^wo"),
+                  extractor.HashToken("wor"),
+                  extractor.HashToken("orl"),
+                  extractor.HashToken("rld"),
+                  extractor.HashToken("ld!"),
+                  extractor.HashToken("d!$"),
+                  // clang-format on
+              }));
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
+}
+
+// With no chargram orders configured, the whole anchored word is hashed as
+// a single sparse feature.
+TEST_F(TokenFeatureExtractorTest, ExtractAsciiNoChargrams) {
+  TokenFeatureExtractorOptions options;
+  options.num_buckets = 1000;
+  options.chargram_orders = std::vector<int>{};
+  options.extract_case_feature = true;
+  options.unicode_aware_features = false;
+  options.extract_selection_mask_feature = true;
+  TestingTokenFeatureExtractor extractor(options, unilib_);
+
+  std::vector<int> sparse_features;
+  std::vector<float> dense_features;
+
+  extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features,
+                    &dense_features);
+
+  EXPECT_THAT(sparse_features,
+              testing::ElementsAreArray({extractor.HashToken("^Hello$")}));
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
+
+  sparse_features.clear();
+  dense_features.clear();
+  extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
+                    &dense_features);
+
+  EXPECT_THAT(sparse_features,
+              testing::ElementsAreArray({extractor.HashToken("^world!$")}));
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
+}
+
+// Same as ExtractAscii but unicode-aware: multi-byte codepoints form single
+// chargram units, and the out-of-span mask value is -1.0 instead of 0.0.
+TEST_F(TokenFeatureExtractorTest, ExtractUnicode) {
+  TokenFeatureExtractorOptions options;
+  options.num_buckets = 1000;
+  options.chargram_orders = std::vector<int>{1, 2, 3};
+  options.extract_case_feature = true;
+  options.unicode_aware_features = true;
+  options.extract_selection_mask_feature = true;
+  TestingTokenFeatureExtractor extractor(options, unilib_);
+
+  std::vector<int> sparse_features;
+  std::vector<float> dense_features;
+
+  extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
+                    &dense_features);
+
+  EXPECT_THAT(sparse_features,
+              testing::ElementsAreArray({
+                  // clang-format off
+                  extractor.HashToken("H"),
+                  extractor.HashToken("ě"),
+                  extractor.HashToken("l"),
+                  extractor.HashToken("l"),
+                  extractor.HashToken("ó"),
+                  extractor.HashToken("^H"),
+                  extractor.HashToken("Hě"),
+                  extractor.HashToken("ěl"),
+                  extractor.HashToken("ll"),
+                  extractor.HashToken("ló"),
+                  extractor.HashToken("ó$"),
+                  extractor.HashToken("^Hě"),
+                  extractor.HashToken("Hěl"),
+                  extractor.HashToken("ěll"),
+                  extractor.HashToken("lló"),
+                  extractor.HashToken("ló$")
+                  // clang-format on
+              }));
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
+
+  sparse_features.clear();
+  dense_features.clear();
+  extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
+                    &dense_features);
+
+  EXPECT_THAT(sparse_features,
+              testing::ElementsAreArray({
+                  // clang-format off
+                  extractor.HashToken("w"),
+                  extractor.HashToken("o"),
+                  extractor.HashToken("r"),
+                  extractor.HashToken("l"),
+                  extractor.HashToken("d"),
+                  extractor.HashToken("!"),
+                  extractor.HashToken("^w"),
+                  extractor.HashToken("wo"),
+                  extractor.HashToken("or"),
+                  extractor.HashToken("rl"),
+                  extractor.HashToken("ld"),
+                  extractor.HashToken("d!"),
+                  extractor.HashToken("!$"),
+                  extractor.HashToken("^wo"),
+                  extractor.HashToken("wor"),
+                  extractor.HashToken("orl"),
+                  extractor.HashToken("rld"),
+                  extractor.HashToken("ld!"),
+                  extractor.HashToken("d!$"),
+                  // clang-format on
+              }));
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
+}
+
+// Unicode-aware variant of ExtractAsciiNoChargrams: whole anchored word is
+// hashed as one feature.
+TEST_F(TokenFeatureExtractorTest, ExtractUnicodeNoChargrams) {
+  TokenFeatureExtractorOptions options;
+  options.num_buckets = 1000;
+  options.chargram_orders = std::vector<int>{};
+  options.extract_case_feature = true;
+  options.unicode_aware_features = true;
+  options.extract_selection_mask_feature = true;
+  TestingTokenFeatureExtractor extractor(options, unilib_);
+
+  std::vector<int> sparse_features;
+  std::vector<float> dense_features;
+
+  extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
+                    &dense_features);
+
+  EXPECT_THAT(sparse_features,
+              testing::ElementsAreArray({extractor.HashToken("^Hělló$")}));
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
+
+  sparse_features.clear();
+  dense_features.clear();
+  extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
+                    &dense_features);
+
+  EXPECT_THAT(sparse_features, testing::ElementsAreArray({
+                                   extractor.HashToken("^world!$"),
+                               }));
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
+}
+
+#ifdef TC3_TEST_ICU
+// Checks the case feature on non-ASCII letters (Ř/ř); requires ICU, hence
+// the surrounding TC3_TEST_ICU guard.
+TEST_F(TokenFeatureExtractorTest, ICUCaseFeature) {
+  TokenFeatureExtractorOptions options;
+  options.num_buckets = 1000;
+  options.chargram_orders = std::vector<int>{1, 2};
+  options.extract_case_feature = true;
+  options.unicode_aware_features = true;
+  options.extract_selection_mask_feature = false;
+  TestingTokenFeatureExtractor extractor(options, unilib_);
+
+  std::vector<int> sparse_features;
+  std::vector<float> dense_features;
+  extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
+                    &dense_features);
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
+
+  sparse_features.clear();
+  dense_features.clear();
+  extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
+                    &dense_features);
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
+
+  sparse_features.clear();
+  dense_features.clear();
+  extractor.Extract(Token{"Ř", 23, 29}, false, &sparse_features,
+                    &dense_features);
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
+
+  sparse_features.clear();
+  dense_features.clear();
+  extractor.Extract(Token{"ř", 23, 29}, false, &sparse_features,
+                    &dense_features);
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
+}
+#endif
+
+// With digit remapping, tokens differing only in digits produce identical
+// sparse features; tokens of different lengths must not.
+TEST_F(TokenFeatureExtractorTest, DigitRemapping) {
+  TokenFeatureExtractorOptions options;
+  options.num_buckets = 1000;
+  options.chargram_orders = std::vector<int>{1, 2};
+  options.remap_digits = true;
+  options.unicode_aware_features = false;
+  TestingTokenFeatureExtractor extractor(options, unilib_);
+
+  std::vector<int> sparse_features;
+  std::vector<float> dense_features;
+  extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
+                    &dense_features);
+
+  std::vector<int> sparse_features2;
+  extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
+                    &dense_features);
+  EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
+
+  extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2,
+                    &dense_features);
+  EXPECT_THAT(sparse_features,
+              testing::Not(testing::ElementsAreArray(sparse_features2)));
+}
+
+// Unicode-aware counterpart of DigitRemapping.
+TEST_F(TokenFeatureExtractorTest, DigitRemappingUnicode) {
+  TokenFeatureExtractorOptions options;
+  options.num_buckets = 1000;
+  options.chargram_orders = std::vector<int>{1, 2};
+  options.remap_digits = true;
+  options.unicode_aware_features = true;
+  TestingTokenFeatureExtractor extractor(options, unilib_);
+
+  std::vector<int> sparse_features;
+  std::vector<float> dense_features;
+  extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
+                    &dense_features);
+
+  std::vector<int> sparse_features2;
+  extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
+                    &dense_features);
+  EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
+
+  extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2,
+                    &dense_features);
+  EXPECT_THAT(sparse_features,
+              testing::Not(testing::ElementsAreArray(sparse_features2)));
+}
+
+// With lowercasing, tokens differing only in ASCII case produce identical
+// sparse features.
+TEST_F(TokenFeatureExtractorTest, LowercaseAscii) {
+  TokenFeatureExtractorOptions options;
+  options.num_buckets = 1000;
+  options.chargram_orders = std::vector<int>{1, 2};
+  options.lowercase_tokens = true;
+  options.unicode_aware_features = false;
+  TestingTokenFeatureExtractor extractor(options, unilib_);
+
+  std::vector<int> sparse_features;
+  std::vector<float> dense_features;
+  extractor.Extract(Token{"AABB", 0, 6}, true, &sparse_features,
+                    &dense_features);
+
+  std::vector<int> sparse_features2;
+  extractor.Extract(Token{"aaBB", 0, 6}, true, &sparse_features2,
+                    &dense_features);
+  EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
+
+  extractor.Extract(Token{"aAbB", 0, 6}, true, &sparse_features2,
+                    &dense_features);
+  EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
+}
+
+#ifdef TC3_TEST_ICU
+// Lowercasing of non-ASCII letters (Ř -> ř); requires ICU, hence the
+// surrounding TC3_TEST_ICU guard.
+TEST_F(TokenFeatureExtractorTest, LowercaseUnicode) {
+  TokenFeatureExtractorOptions options;
+  options.num_buckets = 1000;
+  options.chargram_orders = std::vector<int>{1, 2};
+  options.lowercase_tokens = true;
+  options.unicode_aware_features = true;
+  TestingTokenFeatureExtractor extractor(options, unilib_);
+
+  std::vector<int> sparse_features;
+  std::vector<float> dense_features;
+  extractor.Extract(Token{"ŘŘ", 0, 6}, true, &sparse_features, &dense_features);
+
+  std::vector<int> sparse_features2;
+  extractor.Extract(Token{"řř", 0, 6}, true, &sparse_features2,
+                    &dense_features);
+  EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
+}
+}
+#endif
+
+#ifdef TC3_TEST_ICU
+// Each configured regexp contributes one dense feature: 1.0 on full match,
+// -1.0 otherwise.  Requires ICU regex support (TC3_TEST_ICU guard).
+TEST_F(TokenFeatureExtractorTest, RegexFeatures) {
+  TokenFeatureExtractorOptions options;
+  options.num_buckets = 1000;
+  options.chargram_orders = std::vector<int>{1, 2};
+  options.remap_digits = false;
+  options.unicode_aware_features = false;
+  options.regexp_features.push_back("^[a-z]+$");  // all lower case.
+  options.regexp_features.push_back("^[0-9]+$");  // all digits.
+  TestingTokenFeatureExtractor extractor(options, unilib_);
+
+  std::vector<int> sparse_features;
+  std::vector<float> dense_features;
+  extractor.Extract(Token{"abCde", 0, 6}, true, &sparse_features,
+                    &dense_features);
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
+
+  dense_features.clear();
+  extractor.Extract(Token{"abcde", 0, 6}, true, &sparse_features,
+                    &dense_features);
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, -1.0}));
+
+  dense_features.clear();
+  extractor.Extract(Token{"12c45", 0, 6}, true, &sparse_features,
+                    &dense_features);
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
+
+  dense_features.clear();
+  extractor.Extract(Token{"12345", 0, 6}, true, &sparse_features,
+                    &dense_features);
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 1.0}));
+}
+#endif
+
+// A word longer than max_word_length is trimmed to a prefix + '\1' + suffix
+// without splitting multi-byte codepoints; ASAN would flag any overrun.
+TEST_F(TokenFeatureExtractorTest, ExtractTooLongWord) {
+  TokenFeatureExtractorOptions options;
+  options.num_buckets = 1000;
+  options.chargram_orders = std::vector<int>{22};
+  options.extract_case_feature = true;
+  options.unicode_aware_features = true;
+  options.extract_selection_mask_feature = true;
+  TestingTokenFeatureExtractor extractor(options, unilib_);
+
+  // Test that this runs. ASAN should catch problems.
+  std::vector<int> sparse_features;
+  std::vector<float> dense_features;
+  extractor.Extract(Token{"abcdefghijklmnopqřstuvwxyz", 0, 0}, true,
+                    &sparse_features, &dense_features);
+
+  EXPECT_THAT(sparse_features,
+              testing::ElementsAreArray({
+                  // clang-format off
+                  extractor.HashToken("^abcdefghij\1qřstuvwxyz"),
+                  extractor.HashToken("abcdefghij\1qřstuvwxyz$"),
+                  // clang-format on
+              }));
+}
+
+// The ASCII-only and unicode-aware code paths must produce identical features
+// for pure-ASCII input.
+TEST_F(TokenFeatureExtractorTest, ExtractAsciiUnicodeMatches) {
+  TokenFeatureExtractorOptions options;
+  options.num_buckets = 1000;
+  options.chargram_orders = std::vector<int>{1, 2, 3, 4, 5};
+  options.extract_case_feature = true;
+  options.unicode_aware_features = true;
+  options.extract_selection_mask_feature = true;
+
+  TestingTokenFeatureExtractor extractor_unicode(options, unilib_);
+
+  options.unicode_aware_features = false;
+  TestingTokenFeatureExtractor extractor_ascii(options, unilib_);
+
+  // All inputs are deliberately ASCII-only; the short tokens and the empty
+  // string cover the boundary cases.
+  for (const std::string& input :
+       {"https://www.abcdefgh.com/in/xxxkkkvayio",
+        "https://www.fjsidofj.om/xx/abadfy/xxxx/?xfjiis=ffffiijiihil",
+        "asdfhasdofjiasdofj#%()*%#*(aisdojfaosdifjiaofjdsiofjdi_fdis3w", "abcd",
+        "x", "Hello", "Hey,", "Hi", ""}) {
+    std::vector<int> sparse_features_unicode;
+    std::vector<float> dense_features_unicode;
+    extractor_unicode.Extract(Token{input, 0, 0}, true,
+                              &sparse_features_unicode,
+                              &dense_features_unicode);
+
+    std::vector<int> sparse_features_ascii;
+    std::vector<float> dense_features_ascii;
+    extractor_ascii.Extract(Token{input, 0, 0}, true, &sparse_features_ascii,
+                            &dense_features_ascii);
+
+    EXPECT_THAT(sparse_features_unicode, sparse_features_ascii) << input;
+    EXPECT_THAT(dense_features_unicode, dense_features_ascii) << input;
+  }
+}
+
+// The default-constructed Token is the padding token: it maps to the single
+// hashed "<PAD>" sparse feature. Dense features are -1.0 (case feature) and
+// 0.0 (selection-mask feature; Extract's second argument — presumably the
+// in-selection flag, confirm in token-feature-extractor.h — is false here).
+TEST_F(TokenFeatureExtractorTest, ExtractForPadToken) {
+  TokenFeatureExtractorOptions options;
+  options.num_buckets = 1000;
+  options.chargram_orders = std::vector<int>{1, 2};
+  options.extract_case_feature = true;
+  options.unicode_aware_features = false;
+  options.extract_selection_mask_feature = true;
+
+  TestingTokenFeatureExtractor extractor(options, unilib_);
+
+  std::vector<int> sparse_features;
+  std::vector<float> dense_features;
+
+  extractor.Extract(Token(), false, &sparse_features, &dense_features);
+
+  EXPECT_THAT(sparse_features,
+              testing::ElementsAreArray({extractor.HashToken("<PAD>")}));
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
+}
+
+// Chargrams not present in `allowed_chargrams` are mapped to bucket 0 instead
+// of their hash; the allowed ones keep their usual hashed id.
+TEST_F(TokenFeatureExtractorTest, ExtractFiltered) {
+  TokenFeatureExtractorOptions options;
+  options.num_buckets = 1000;
+  options.chargram_orders = std::vector<int>{1, 2, 3};
+  options.extract_case_feature = true;
+  options.unicode_aware_features = false;
+  options.extract_selection_mask_feature = true;
+  options.allowed_chargrams.insert("^H");
+  options.allowed_chargrams.insert("ll");
+  options.allowed_chargrams.insert("llo");
+  options.allowed_chargrams.insert("w");
+  options.allowed_chargrams.insert("!");
+  options.allowed_chargrams.insert("\xc4");  // Lead byte of the UTF-8 "ě".
+
+  TestingTokenFeatureExtractor extractor(options, unilib_);
+
+  std::vector<int> sparse_features;
+  std::vector<float> dense_features;
+
+  extractor.Extract(Token{"Hěllo", 0, 5}, true, &sparse_features,
+                    &dense_features);
+
+  // Disallowed chargrams come back as 0; the allowed ones keep their hash.
+  EXPECT_THAT(sparse_features,
+              testing::ElementsAreArray({
+                  // clang-format off
+                  0,
+                  extractor.HashToken("\xc4"),
+                  0,
+                  0,
+                  0,
+                  0,
+                  extractor.HashToken("^H"),
+                  0,
+                  0,
+                  0,
+                  extractor.HashToken("ll"),
+                  0,
+                  0,
+                  0,
+                  0,
+                  0,
+                  0,
+                  extractor.HashToken("llo"),
+                  0
+                  // clang-format on
+              }));
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
+
+  sparse_features.clear();
+  dense_features.clear();
+  extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
+                    &dense_features);
+
+  EXPECT_THAT(sparse_features, testing::ElementsAreArray({
+                                   // clang-format off
+                                   extractor.HashToken("w"),
+                                   0,
+                                   0,
+                                   0,
+                                   0,
+                                   extractor.HashToken("!"),
+                                   0,
+                                   0,
+                                   0,
+                                   0,
+                                   0,
+                                   0,
+                                   0,
+                                   0,
+                                   0,
+                                   0,
+                                   0,
+                                   0,
+                                   0,
+                                   // clang-format on
+                               }));
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
+  // Sanity check: "<PAD>" hashes to 1, so it cannot be confused with the 0
+  // used for filtered chargrams.
+  EXPECT_EQ(extractor.HashToken("<PAD>"), 1);
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/annotator/tokenizer.cc b/annotator/tokenizer.cc
new file mode 100644
index 0000000..099dccc
--- /dev/null
+++ b/annotator/tokenizer.cc
@@ -0,0 +1,126 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/tokenizer.h"
+
+#include <algorithm>
+
+#include "utils/base/logging.h"
+#include "utils/strings/utf8.h"
+
+namespace libtextclassifier3 {
+
+Tokenizer::Tokenizer(
+    const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
+    bool split_on_script_change)
+    : split_on_script_change_(split_on_script_change) {
+  // Unpack the flatbuffer configs into owned object-API form; the input
+  // flatbuffers are not needed after construction.
+  for (const TokenizationCodepointRange* range : codepoint_ranges) {
+    codepoint_ranges_.emplace_back(range->UnPack());
+  }
+
+  // Sort by range start so FindTokenizationRange can binary-search; the
+  // ranges are documented (tokenizer.h) to be non-overlapping.
+  std::sort(codepoint_ranges_.begin(), codepoint_ranges_.end(),
+            [](const std::unique_ptr<const TokenizationCodepointRangeT>& a,
+               const std::unique_ptr<const TokenizationCodepointRangeT>& b) {
+              return a->start < b->start;
+            });
+}
+
+const TokenizationCodepointRangeT* Tokenizer::FindTokenizationRange(
+    int codepoint) const {
+  // Binary-search for the first range that is not "less than" the codepoint.
+  // Because `end` is exclusive, a range is less than a codepoint exactly when
+  // range->end <= codepoint; the first range failing that test is the only
+  // one that could contain the codepoint.
+  const auto candidate = std::lower_bound(
+      codepoint_ranges_.begin(), codepoint_ranges_.end(), codepoint,
+      [](const std::unique_ptr<const TokenizationCodepointRangeT>& range,
+         int point) { return range->end <= point; });
+  if (candidate == codepoint_ranges_.end()) {
+    return nullptr;
+  }
+  // The candidate ends after the codepoint; it actually contains the
+  // codepoint only if it also starts at or before it.
+  if ((*candidate)->start > codepoint || (*candidate)->end <= codepoint) {
+    return nullptr;
+  }
+  return candidate->get();
+}
+
+void Tokenizer::GetScriptAndRole(char32 codepoint,
+                                 TokenizationCodepointRange_::Role* role,
+                                 int* script) const {
+  // Look up the codepoint's configured range; fall back to the defaults when
+  // no configured range covers it.
+  if (const TokenizationCodepointRangeT* config =
+          FindTokenizationRange(codepoint)) {
+    *role = config->role;
+    *script = config->script_id;
+    return;
+  }
+  *role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+  *script = kUnknownScript;
+}
+
+std::vector<Token> Tokenizer::Tokenize(const std::string& text) const {
+  // Wrap the UTF-8 bytes without copying; the view only needs to live for
+  // the duration of this call.
+  UnicodeText text_unicode = UTF8ToUnicodeText(text, /*do_copy=*/false);
+  return Tokenize(text_unicode);
+}
+
+std::vector<Token> Tokenizer::Tokenize(const UnicodeText& text_unicode) const {
+  std::vector<Token> result;
+  Token new_token("", 0, 0);
+  int codepoint_index = 0;
+
+  int last_script = kInvalidScript;
+  for (auto it = text_unicode.begin(); it != text_unicode.end();
+       ++it, ++codepoint_index) {
+    TokenizationCodepointRange_::Role role;
+    int script;
+    GetScriptAndRole(*it, &role, &script);
+
+    // Close the current token before this codepoint if the codepoint's role
+    // requests a split, or if script-change splitting is enabled and the
+    // script differs from the previous codepoint's.
+    if (role & TokenizationCodepointRange_::Role_SPLIT_BEFORE ||
+        (split_on_script_change_ && last_script != kInvalidScript &&
+         last_script != script)) {
+      if (!new_token.value.empty()) {
+        result.push_back(new_token);
+      }
+      new_token = Token("", codepoint_index, codepoint_index);
+    }
+    // Append the codepoint's UTF-8 bytes unless it is a discarded separator
+    // (e.g. whitespace); `end` tracks the codepoint (not byte) index.
+    if (!(role & TokenizationCodepointRange_::Role_DISCARD_CODEPOINT)) {
+      new_token.value += std::string(
+          it.utf8_data(),
+          it.utf8_data() + GetNumBytesForNonZeroUTF8Char(it.utf8_data()));
+      ++new_token.end;
+    }
+    // Close the current token after this codepoint if requested.
+    if (role & TokenizationCodepointRange_::Role_SPLIT_AFTER) {
+      if (!new_token.value.empty()) {
+        result.push_back(new_token);
+      }
+      new_token = Token("", codepoint_index + 1, codepoint_index + 1);
+    }
+
+    last_script = script;
+  }
+  // Flush the trailing token, if any.
+  if (!new_token.value.empty()) {
+    result.push_back(new_token);
+  }
+
+  return result;
+}
+
+} // namespace libtextclassifier3
diff --git a/annotator/tokenizer.h b/annotator/tokenizer.h
new file mode 100644
index 0000000..ec33f2d
--- /dev/null
+++ b/annotator/tokenizer.h
@@ -0,0 +1,71 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TOKENIZER_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_TOKENIZER_H_
+
+#include <string>
+#include <vector>
+
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/base/integral_types.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3 {
+
+// Sentinel script ids: kInvalidScript marks "no previous codepoint yet"
+// during tokenization; kUnknownScript is assigned to codepoints not covered
+// by any configured range.
+const int kInvalidScript = -1;
+const int kUnknownScript = -2;
+
+// Tokenizer splits the input string into a sequence of tokens, according to the
+// configuration.
+class Tokenizer {
+ public:
+  // `codepoint_ranges`: non-overlapping ranges assigning a tokenization role
+  //     and a script id to codepoints.
+  // `split_on_script_change`: if true, a token boundary is additionally
+  //     introduced whenever the script id changes between adjacent codepoints.
+  explicit Tokenizer(
+      const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
+      bool split_on_script_change);
+
+  // Tokenizes the input string using the selected tokenization method.
+  std::vector<Token> Tokenize(const std::string& text) const;
+
+  // Same as above but takes UnicodeText.
+  std::vector<Token> Tokenize(const UnicodeText& text_unicode) const;
+
+ protected:
+  // Finds the tokenization codepoint range config for given codepoint.
+  // Internally uses binary search so should be O(log(# of codepoint_ranges)).
+  // Returns nullptr when the codepoint is not covered by any range.
+  const TokenizationCodepointRangeT* FindTokenizationRange(int codepoint) const;
+
+  // Finds the role and script for given codepoint. If not found, DEFAULT_ROLE
+  // and kUnknownScript are assigned.
+  void GetScriptAndRole(char32 codepoint,
+                        TokenizationCodepointRange_::Role* role,
+                        int* script) const;
+
+ private:
+  // Codepoint ranges that determine how different codepoints are tokenized.
+  // The ranges must not overlap. Kept sorted by `start` (see constructor).
+  std::vector<std::unique_ptr<const TokenizationCodepointRangeT>>
+      codepoint_ranges_;
+
+  // If true, tokens will be additionally split when the codepoint's script_id
+  // changes.
+  bool split_on_script_change_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_TOKENIZER_H_
diff --git a/annotator/tokenizer_test.cc b/annotator/tokenizer_test.cc
new file mode 100644
index 0000000..a3ab9da
--- /dev/null
+++ b/annotator/tokenizer_test.cc
@@ -0,0 +1,334 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/tokenizer.h"
+
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::ElementsAreArray;
+
+// Subclass that widens access to the protected FindTokenizationRange so the
+// tests below can probe it directly.
+class TestingTokenizer : public Tokenizer {
+ public:
+  explicit TestingTokenizer(
+      const std::vector<const TokenizationCodepointRange*>&
+          codepoint_range_configs,
+      bool split_on_script_change)
+      : Tokenizer(codepoint_range_configs, split_on_script_change) {}
+
+  using Tokenizer::FindTokenizationRange;
+};
+
+// Serializes object-form TokenizationCodepointRangeT configs into flatbuffers
+// and builds a TestingTokenizer on top of them.
+class TestingTokenizerProxy {
+ public:
+  explicit TestingTokenizerProxy(
+      const std::vector<TokenizationCodepointRangeT>& codepoint_range_configs,
+      bool split_on_script_change) {
+    int num_configs = codepoint_range_configs.size();
+    std::vector<const TokenizationCodepointRange*> configs_fb;
+    buffers_.reserve(num_configs);
+    for (int i = 0; i < num_configs; i++) {
+      // Each config gets its own detached buffer, kept in buffers_ so the
+      // flatbuffer roots stay valid at least through Tokenizer construction
+      // (the Tokenizer unpacks them — see tokenizer.cc).
+      flatbuffers::FlatBufferBuilder builder;
+      builder.Finish(CreateTokenizationCodepointRange(
+          builder, &codepoint_range_configs[i]));
+      buffers_.push_back(builder.Release());
+      configs_fb.push_back(
+          flatbuffers::GetRoot<TokenizationCodepointRange>(buffers_[i].data()));
+    }
+    tokenizer_ = std::unique_ptr<TestingTokenizer>(
+        new TestingTokenizer(configs_fb, split_on_script_change));
+  }
+
+  // Returns the role of the range containing codepoint `c`, or
+  // Role_DEFAULT_ROLE when no configured range contains it.
+  TokenizationCodepointRange_::Role TestFindTokenizationRole(int c) const {
+    const TokenizationCodepointRangeT* range =
+        tokenizer_->FindTokenizationRange(c);
+    if (range != nullptr) {
+      return range->role;
+    } else {
+      return TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+    }
+  }
+
+  std::vector<Token> Tokenize(const std::string& utf8_text) const {
+    return tokenizer_->Tokenize(utf8_text);
+  }
+
+ private:
+  // Owns the serialized configs referenced during tokenizer_'s construction.
+  std::vector<flatbuffers::DetachedBuffer> buffers_;
+  std::unique_ptr<TestingTokenizer> tokenizer_;
+};
+
+// Ranges are half-open: `start` is inclusive, `end` exclusive, so a lookup of
+// `end` itself must fall back to the default role.
+TEST(TokenizerTest, FindTokenizationRange) {
+  std::vector<TokenizationCodepointRangeT> configs;
+  TokenizationCodepointRangeT* config;
+
+  configs.emplace_back();
+  config = &configs.back();
+  config->start = 0;
+  config->end = 10;
+  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+
+  configs.emplace_back();
+  config = &configs.back();
+  config->start = 32;
+  config->end = 33;
+  config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+
+  configs.emplace_back();
+  config = &configs.back();
+  config->start = 1234;
+  config->end = 12345;
+  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+
+  TestingTokenizerProxy tokenizer(configs, /*split_on_script_change=*/false);
+
+  // Test hits to the first group.
+  EXPECT_EQ(tokenizer.TestFindTokenizationRole(0),
+            TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
+  EXPECT_EQ(tokenizer.TestFindTokenizationRole(5),
+            TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
+  EXPECT_EQ(tokenizer.TestFindTokenizationRole(10),
+            TokenizationCodepointRange_::Role_DEFAULT_ROLE);
+
+  // Test a hit to the second group.
+  EXPECT_EQ(tokenizer.TestFindTokenizationRole(31),
+            TokenizationCodepointRange_::Role_DEFAULT_ROLE);
+  EXPECT_EQ(tokenizer.TestFindTokenizationRole(32),
+            TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR);
+  EXPECT_EQ(tokenizer.TestFindTokenizationRole(33),
+            TokenizationCodepointRange_::Role_DEFAULT_ROLE);
+
+  // Test hits to the third group.
+  EXPECT_EQ(tokenizer.TestFindTokenizationRole(1233),
+            TokenizationCodepointRange_::Role_DEFAULT_ROLE);
+  EXPECT_EQ(tokenizer.TestFindTokenizationRole(1234),
+            TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
+  EXPECT_EQ(tokenizer.TestFindTokenizationRole(12344),
+            TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
+  EXPECT_EQ(tokenizer.TestFindTokenizationRole(12345),
+            TokenizationCodepointRange_::Role_DEFAULT_ROLE);
+
+  // Test a hit outside.
+  EXPECT_EQ(tokenizer.TestFindTokenizationRole(99),
+            TokenizationCodepointRange_::Role_DEFAULT_ROLE);
+}
+
+// Whitespace (configured as WHITESPACE_SEPARATOR) splits tokens and is
+// dropped from the output; token spans are codepoint indices into the input.
+TEST(TokenizerTest, TokenizeOnSpace) {
+  std::vector<TokenizationCodepointRangeT> configs;
+  TokenizationCodepointRangeT* config;
+
+  configs.emplace_back();
+  config = &configs.back();
+  // Space character.
+  config->start = 32;
+  config->end = 33;
+  config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+
+  TestingTokenizerProxy tokenizer(configs, /*split_on_script_change=*/false);
+  std::vector<Token> tokens = tokenizer.Tokenize("Hello world!");
+
+  EXPECT_THAT(tokens,
+              ElementsAreArray({Token("Hello", 0, 5), Token("world!", 6, 12)}));
+}
+
+// With split_on_script_change, tokens are additionally split wherever the
+// script id changes between adjacent codepoints (here: between Hangul, which
+// has no configured range, and the Latin ranges with script_id 1), even
+// without a separator in between.
+TEST(TokenizerTest, TokenizeOnSpaceAndScriptChange) {
+  std::vector<TokenizationCodepointRangeT> configs;
+  TokenizationCodepointRangeT* config;
+
+  // Latin.
+  configs.emplace_back();
+  config = &configs.back();
+  config->start = 0;
+  config->end = 32;
+  config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+  config->script_id = 1;
+  configs.emplace_back();
+  config = &configs.back();
+  config->start = 32;
+  config->end = 33;
+  config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+  config->script_id = 1;
+  configs.emplace_back();
+  config = &configs.back();
+  config->start = 33;
+  config->end = 0x77F + 1;
+  config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+  config->script_id = 1;
+
+  TestingTokenizerProxy tokenizer(configs, /*split_on_script_change=*/true);
+  EXPECT_THAT(tokenizer.Tokenize("앨라배마 주 전화(123) 456-789웹사이트"),
+              std::vector<Token>({Token("앨라배마", 0, 4), Token("주", 5, 6),
+                                  Token("전화", 7, 10), Token("(123)", 10, 15),
+                                  Token("456-789", 16, 23),
+                                  Token("웹사이트", 23, 28)}));
+}
+
+// CJK and Thai codepoints are configured as TOKEN_SEPARATORs, so each such
+// codepoint becomes its own single-codepoint token, while Latin text is only
+// split on spaces.
+TEST(TokenizerTest, TokenizeComplex) {
+  std::vector<TokenizationCodepointRangeT> configs;
+  TokenizationCodepointRangeT* config;
+
+  // Source: http://www.unicode.org/Public/10.0.0/ucd/Blocks-10.0.0d1.txt
+  // Latin - cyrilic.
+  //   0000..007F; Basic Latin
+  //   0080..00FF; Latin-1 Supplement
+  //   0100..017F; Latin Extended-A
+  //   0180..024F; Latin Extended-B
+  //   0250..02AF; IPA Extensions
+  //   02B0..02FF; Spacing Modifier Letters
+  //   0300..036F; Combining Diacritical Marks
+  //   0370..03FF; Greek and Coptic
+  //   0400..04FF; Cyrillic
+  //   0500..052F; Cyrillic Supplement
+  //   0530..058F; Armenian
+  //   0590..05FF; Hebrew
+  //   0600..06FF; Arabic
+  //   0700..074F; Syriac
+  //   0750..077F; Arabic Supplement
+  configs.emplace_back();
+  config = &configs.back();
+  config->start = 0;
+  config->end = 32;
+  config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+  configs.emplace_back();
+  config = &configs.back();
+  config->start = 32;
+  config->end = 33;
+  config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+  configs.emplace_back();
+  config = &configs.back();
+  config->start = 33;
+  config->end = 0x77F + 1;
+  config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+
+  // CJK
+  //   2E80..2EFF; CJK Radicals Supplement
+  //   3000..303F; CJK Symbols and Punctuation
+  //   3040..309F; Hiragana
+  //   30A0..30FF; Katakana
+  //   3100..312F; Bopomofo
+  //   3130..318F; Hangul Compatibility Jamo
+  //   3190..319F; Kanbun
+  //   31A0..31BF; Bopomofo Extended
+  //   31C0..31EF; CJK Strokes
+  //   31F0..31FF; Katakana Phonetic Extensions
+  //   3200..32FF; Enclosed CJK Letters and Months
+  //   3300..33FF; CJK Compatibility
+  //   3400..4DBF; CJK Unified Ideographs Extension A
+  //   4DC0..4DFF; Yijing Hexagram Symbols
+  //   4E00..9FFF; CJK Unified Ideographs
+  //   A000..A48F; Yi Syllables
+  //   A490..A4CF; Yi Radicals
+  //   A4D0..A4FF; Lisu
+  //   A500..A63F; Vai
+  //   F900..FAFF; CJK Compatibility Ideographs
+  //   FE30..FE4F; CJK Compatibility Forms
+  //   20000..2A6DF; CJK Unified Ideographs Extension B
+  //   2A700..2B73F; CJK Unified Ideographs Extension C
+  //   2B740..2B81F; CJK Unified Ideographs Extension D
+  //   2B820..2CEAF; CJK Unified Ideographs Extension E
+  //   2CEB0..2EBEF; CJK Unified Ideographs Extension F
+  //   2F800..2FA1F; CJK Compatibility Ideographs Supplement
+  configs.emplace_back();
+  config = &configs.back();
+  config->start = 0x2E80;
+  config->end = 0x2EFF + 1;
+  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+  configs.emplace_back();
+  config = &configs.back();
+  config->start = 0x3000;
+  config->end = 0xA63F + 1;
+  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+  configs.emplace_back();
+  config = &configs.back();
+  config->start = 0xF900;
+  config->end = 0xFAFF + 1;
+  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+  configs.emplace_back();
+  config = &configs.back();
+  config->start = 0xFE30;
+  config->end = 0xFE4F + 1;
+  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+  configs.emplace_back();
+  config = &configs.back();
+  config->start = 0x20000;
+  config->end = 0x2A6DF + 1;
+  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+  configs.emplace_back();
+  config = &configs.back();
+  config->start = 0x2A700;
+  config->end = 0x2B73F + 1;
+  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+  configs.emplace_back();
+  config = &configs.back();
+  config->start = 0x2B740;
+  config->end = 0x2B81F + 1;
+  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+  configs.emplace_back();
+  config = &configs.back();
+  config->start = 0x2B820;
+  config->end = 0x2CEAF + 1;
+  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+  configs.emplace_back();
+  config = &configs.back();
+  config->start = 0x2CEB0;
+  config->end = 0x2EBEF + 1;
+  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+  configs.emplace_back();
+  config = &configs.back();
+  config->start = 0x2F800;
+  config->end = 0x2FA1F + 1;
+  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+
+  // Thai.
+  //   0E00..0E7F; Thai
+  configs.emplace_back();
+  config = &configs.back();
+  config->start = 0x0E00;
+  config->end = 0x0E7F + 1;
+  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+
+  TestingTokenizerProxy tokenizer(configs, /*split_on_script_change=*/false);
+  std::vector<Token> tokens;
+
+  // Every codepoint of the all-CJK input becomes its own token.
+  tokens = tokenizer.Tokenize(
+      "問少目木輸走猶術権自京門録球変。細開括省用掲情結傍走愛明氷。");
+  EXPECT_EQ(tokens.size(), 30);
+
+  tokens = tokenizer.Tokenize("問少目 hello 木輸ยามきゃ");
+  // clang-format off
+  EXPECT_THAT(
+      tokens,
+      ElementsAreArray({Token("問", 0, 1),
+                        Token("少", 1, 2),
+                        Token("目", 2, 3),
+                        Token("hello", 4, 9),
+                        Token("木", 10, 11),
+                        Token("輸", 11, 12),
+                        Token("ย", 12, 13),
+                        Token("า", 13, 14),
+                        Token("ม", 14, 15),
+                        Token("き", 15, 16),
+                        Token("ゃ", 16, 17)}));
+  // clang-format on
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/annotator/types-test-util.h b/annotator/types-test-util.h
new file mode 100644
index 0000000..fbbdd63
--- /dev/null
+++ b/annotator/types-test-util.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_TEST_UTIL_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_TEST_UTIL_H_
+
+#include <ostream>
+
+#include "annotator/types.h"
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+
+// std::ostream adapters for the pretty-printers defined on
+// logging::LoggingStringStream in annotator/types.h, so that test frameworks
+// (e.g. gtest) can print these types in failure messages.
+
+inline std::ostream& operator<<(std::ostream& stream, const Token& value) {
+  logging::LoggingStringStream tmp_stream;
+  tmp_stream << value;
+  return stream << tmp_stream.message;
+}
+
+inline std::ostream& operator<<(std::ostream& stream,
+                                const AnnotatedSpan& value) {
+  logging::LoggingStringStream tmp_stream;
+  tmp_stream << value;
+  return stream << tmp_stream.message;
+}
+
+inline std::ostream& operator<<(std::ostream& stream,
+                                const DatetimeParseResultSpan& value) {
+  logging::LoggingStringStream tmp_stream;
+  tmp_stream << value;
+  return stream << tmp_stream.message;
+}
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_TEST_UTIL_H_
diff --git a/annotator/types.h b/annotator/types.h
new file mode 100644
index 0000000..38bce41
--- /dev/null
+++ b/annotator/types.h
@@ -0,0 +1,402 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
+
+#include <algorithm>
+#include <cmath>
+#include <functional>
+#include <map>
+#include <set>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "utils/base/integral_types.h"
+#include "utils/base/logging.h"
+#include "utils/variant.h"
+
+namespace libtextclassifier3 {
+
+constexpr int kInvalidIndex = -1;
+
+// Index for a 0-based array of tokens.
+using TokenIndex = int;
+
+// Index for a 0-based array of codepoints.
+using CodepointIndex = int;
+
+// Marks a span in a sequence of codepoints. The first element is the index of
+// the first codepoint of the span, and the second element is the index of the
+// codepoint one past the end of the span.
+// TODO(b/71982294): Make it a struct.
+using CodepointSpan = std::pair<CodepointIndex, CodepointIndex>;
+
+// Two spans overlap iff neither span ends at or before the start of the
+// other (spans are half-open: [first, second)).
+inline bool SpansOverlap(const CodepointSpan& a, const CodepointSpan& b) {
+  return !(a.second <= b.first || b.second <= a.first);
+}
+
+// A span is valid and non-empty when both endpoints are non-negative and it
+// covers at least one codepoint.
+inline bool ValidNonEmptySpan(const CodepointSpan& span) {
+  if (span.first < 0 || span.second < 0) {
+    return false;
+  }
+  return span.first < span.second;
+}
+
+// Returns true iff the span of candidates[considered_candidate] overlaps the
+// span of an already-chosen candidate. Only the two chosen neighbors of the
+// considered index need checking.
+// NOTE(review): this assumes chosen_indices_set's comparator orders indices
+// consistently with span positions — confirm at call sites.
+template <typename T>
+bool DoesCandidateConflict(
+    const int considered_candidate, const std::vector<T>& candidates,
+    const std::set<int, std::function<bool(int, int)>>& chosen_indices_set) {
+  if (chosen_indices_set.empty()) {
+    return false;
+  }
+
+  auto conflicting_it = chosen_indices_set.lower_bound(considered_candidate);
+  // Check conflict on the right.
+  if (conflicting_it != chosen_indices_set.end() &&
+      SpansOverlap(candidates[considered_candidate].span,
+                   candidates[*conflicting_it].span)) {
+    return true;
+  }
+
+  // Check conflict on the left.
+  // If we can't go more left, there can't be a conflict:
+  if (conflicting_it == chosen_indices_set.begin()) {
+    return false;
+  }
+  // Otherwise move one element left and report no conflict if it doesn't
+  // overlap with the candidate.
+  --conflicting_it;
+  if (!SpansOverlap(candidates[considered_candidate].span,
+                    candidates[*conflicting_it].span)) {
+    return false;
+  }
+
+  return true;
+}
+
+// Marks a span in a sequence of tokens. The first element is the index of the
+// first token in the span, and the second element is the index of the token one
+// past the end of the span.
+// TODO(b/71982294): Make it a struct.
+using TokenSpan = std::pair<TokenIndex, TokenIndex>;
+
+// Returns the size of the token span. Assumes that the span is valid.
+inline int TokenSpanSize(const TokenSpan& token_span) {
+  return token_span.second - token_span.first;
+}
+
+// Returns a token span consisting of one token.
+inline TokenSpan SingleTokenSpan(int token_index) {
+  return {token_index, token_index + 1};
+}
+
+// Returns an intersection of two token spans. Assumes that both spans are valid
+// and overlapping.
+inline TokenSpan IntersectTokenSpans(const TokenSpan& token_span1,
+                                     const TokenSpan& token_span2) {
+  return {std::max(token_span1.first, token_span2.first),
+          std::min(token_span1.second, token_span2.second)};
+}
+
+// Returns an expanded token span by adding a certain number of tokens on its
+// left and on its right. No clamping is done, so the result may contain
+// negative or out-of-range indices.
+inline TokenSpan ExpandTokenSpan(const TokenSpan& token_span,
+                                 int num_tokens_left, int num_tokens_right) {
+  return {token_span.first - num_tokens_left,
+          token_span.second + num_tokens_right};
+}
+
+// Token holds a token, its position in the original string and whether it was
+// part of the input span.
+struct Token {
+  // UTF-8 value of the token.
+  std::string value;
+  // Codepoint index of the first codepoint of the token.
+  CodepointIndex start;
+  // Codepoint index one past the last codepoint of the token.
+  CodepointIndex end;
+
+  // Whether the token is a padding token.
+  bool is_padding;
+
+  // Default constructor constructs the padding-token.
+  Token()
+      : value(""), start(kInvalidIndex), end(kInvalidIndex), is_padding(true) {}
+
+  // Constructs a regular (non-padding) token.
+  Token(const std::string& arg_value, CodepointIndex arg_start,
+        CodepointIndex arg_end)
+      : value(arg_value), start(arg_start), end(arg_end), is_padding(false) {}
+
+  bool operator==(const Token& other) const {
+    return value == other.value && start == other.start && end == other.end &&
+           is_padding == other.is_padding;
+  }
+
+  // True iff the token lies entirely inside `span` (codepoint indices).
+  bool IsContainedInSpan(CodepointSpan span) const {
+    return start >= span.first && end <= span.second;
+  }
+};
+
+// Pretty-printing function for Token. Padding tokens print as "Token()".
+inline logging::LoggingStringStream& operator<<(
+    logging::LoggingStringStream& stream, const Token& token) {
+  if (!token.is_padding) {
+    return stream << "Token(\"" << token.value << "\", " << token.start << ", "
+                  << token.end << ")";
+  } else {
+    return stream << "Token()";
+  }
+}
+
+// Granularity of a parsed date/time expression, ordered from coarsest (year)
+// to finest (second).
+enum DatetimeGranularity {
+  GRANULARITY_UNKNOWN = -1,  // GRANULARITY_UNKNOWN is used as a proxy for this
+                             // structure being uninitialized.
+  GRANULARITY_YEAR = 0,
+  GRANULARITY_MONTH = 1,
+  GRANULARITY_WEEK = 2,
+  GRANULARITY_DAY = 3,
+  GRANULARITY_HOUR = 4,
+  GRANULARITY_MINUTE = 5,
+  GRANULARITY_SECOND = 6
+};
+
+struct DatetimeParseResult {
+  // The absolute time in milliseconds since the epoch in UTC. This is derived
+  // from the reference time and the fields specified in the text - so it may
+  // be imperfect where the time was ambiguous. (e.g. "at 7:30" may be am or pm)
+  int64 time_ms_utc;
+
+  // The granularity (precision) of the datetime expression in the text, i.e.
+  // how precisely time_ms_utc above could be determined.
+  DatetimeGranularity granularity;
+
+  // Constructs an unset result (granularity GRANULARITY_UNKNOWN).
+  DatetimeParseResult() : time_ms_utc(0), granularity(GRANULARITY_UNKNOWN) {}
+
+  DatetimeParseResult(int64 arg_time_ms_utc,
+                      DatetimeGranularity arg_granularity)
+      : time_ms_utc(arg_time_ms_utc), granularity(arg_granularity) {}
+
+  // A result is "set" iff its granularity has been determined.
+  bool IsSet() const { return granularity != GRANULARITY_UNKNOWN; }
+
+  bool operator==(const DatetimeParseResult& other) const {
+    return granularity == other.granularity && time_ms_utc == other.time_ms_utc;
+  }
+};
+
+// Tolerance used when comparing the float scores below for equality.
+const float kFloatCompareEpsilon = 1e-5;
+
+struct DatetimeParseResultSpan {
+  // Codepoint span of the datetime expression in the input text.
+  CodepointSpan span;
+  // The parsed date/time.
+  DatetimeParseResult data;
+  float target_classification_score;
+  float priority_score;
+
+  // Equality up to kFloatCompareEpsilon on the score fields.
+  bool operator==(const DatetimeParseResultSpan& other) const {
+    return span == other.span && data.granularity == other.data.granularity &&
+           data.time_ms_utc == other.data.time_ms_utc &&
+           std::abs(target_classification_score -
+                    other.target_classification_score) < kFloatCompareEpsilon &&
+           std::abs(priority_score - other.priority_score) <
+               kFloatCompareEpsilon;
+  }
+};
+
+// Pretty-printing function for DatetimeParseResultSpan. Scores are omitted.
+inline logging::LoggingStringStream& operator<<(
+    logging::LoggingStringStream& stream,
+    const DatetimeParseResultSpan& value) {
+  return stream << "DatetimeParseResultSpan({" << value.span.first << ", "
+                << value.span.second << "}, {/*time_ms_utc=*/ "
+                << value.data.time_ms_utc << ", /*granularity=*/ "
+                << value.data.granularity << "})";
+}
+
+struct ClassificationResult {
+  // Name of the collection this result classifies the span into.
+  std::string collection;
+  // Model score of the classification.
+  float score;
+  // Set when the result represents a date/time entity.
+  DatetimeParseResult datetime_parse_result;
+  // Opaque serialized payload — presumably produced by the knowledge engine
+  // (see annotator/knowledge/) — empty otherwise.
+  std::string serialized_knowledge_result;
+
+  // Internal score used for conflict resolution.
+  float priority_score;
+
+  // Extra information.
+  std::map<std::string, Variant> extra;
+
+  explicit ClassificationResult() : score(-1.0f), priority_score(-1.0) {}
+
+  // The priority score defaults to the classification score.
+  ClassificationResult(const std::string& arg_collection, float arg_score)
+      : collection(arg_collection),
+        score(arg_score),
+        priority_score(arg_score) {}
+
+  ClassificationResult(const std::string& arg_collection, float arg_score,
+                       float arg_priority_score)
+      : collection(arg_collection),
+        score(arg_score),
+        priority_score(arg_priority_score) {}
+};
+
// Pretty-printing function for ClassificationResult.
// Prints only the collection and the score; priority_score, datetime payload
// and extras are omitted.
inline logging::LoggingStringStream& operator<<(
    logging::LoggingStringStream& stream, const ClassificationResult& result) {
  return stream << "ClassificationResult(" << result.collection << ", "
                << result.score << ")";
}
+
// Pretty-printing function for std::vector<ClassificationResult>.
// Emits one result per line, wrapped in braces.
inline logging::LoggingStringStream& operator<<(
    logging::LoggingStringStream& stream,
    const std::vector<ClassificationResult>& results) {
  // NOTE(review): `stream` is reassigned from the result of `<<` — presumably
  // required by LoggingStringStream's operator<< semantics; confirm against
  // utils/base/logging.h before simplifying these into plain chained calls.
  stream = stream << "{\n";
  for (const ClassificationResult& result : results) {
    stream = stream << "    " << result << "\n";
  }
  stream = stream << "}";
  return stream;
}
+
// Represents a result of Annotate call.
struct AnnotatedSpan {
  // Unicode codepoint indices in the input string. Defaults to an invalid
  // span until set.
  CodepointSpan span = {kInvalidIndex, kInvalidIndex};

  // Classification result for the span. Presumably ordered best match first
  // (the pretty-printer below treats element 0 as the best) — confirm with
  // the code that fills it.
  std::vector<ClassificationResult> classification;
};
+
// Pretty-printing function for AnnotatedSpan.
// Prints the span bounds and the collection/score of the first
// classification entry; an empty classification prints as ", , -1".
inline logging::LoggingStringStream& operator<<(
    logging::LoggingStringStream& stream, const AnnotatedSpan& span) {
  std::string best_class;
  float best_score = -1;
  if (!span.classification.empty()) {
    best_class = span.classification[0].collection;
    best_score = span.classification[0].score;
  }
  return stream << "Span(" << span.span.first << ", " << span.span.second
                << ", " << best_class << ", " << best_score << ")";
}
+
// StringPiece analogue for std::vector<T>: a non-owning view over a
// contiguous range of elements. The underlying vector must outlive the span
// and must not be resized/reallocated while the span is in use.
template <class T>
class VectorSpan {
 public:
  VectorSpan() : begin_(), end_() {}
  VectorSpan(const std::vector<T>& v)  // NOLINT(runtime/explicit)
      : begin_(v.begin()), end_(v.end()) {}
  VectorSpan(typename std::vector<T>::const_iterator begin,
             typename std::vector<T>::const_iterator end)
      : begin_(begin), end_(end) {}

  const T& operator[](typename std::vector<T>::size_type i) const {
    return *(begin_ + i);
  }

  int size() const { return end_ - begin_; }
  typename std::vector<T>::const_iterator begin() const { return begin_; }
  typename std::vector<T>::const_iterator end() const { return end_; }

  // Pointer to the first element, or nullptr for an empty span.
  // Fixes two bugs in the original: the return type was hard-coded to
  // `const float*` (so the method only compiled when instantiated with
  // T == float), and `&(*begin_)` dereferenced the end iterator for empty
  // spans (undefined behavior).
  const T* data() const { return begin_ == end_ ? nullptr : &(*begin_); }

 private:
  typename std::vector<T>::const_iterator begin_;
  typename std::vector<T>::const_iterator end_;
};
+
// Raw fields extracted from a date/time expression in text, before being
// resolved into an absolute time. A bit in `field_set_mask` marks each field
// that was actually populated.
struct DateParseData {
  // How the expression relates to "now" (e.g. "next week", "yesterday").
  enum Relation {
    NEXT = 1,
    NEXT_OR_SAME = 2,
    LAST = 3,
    NOW = 4,
    TOMORROW = 5,
    YESTERDAY = 6,
    PAST = 7,
    FUTURE = 8
  };

  // The unit or weekday that `relation` applies to.
  enum RelationType {
    SUNDAY = 1,
    MONDAY = 2,
    TUESDAY = 3,
    WEDNESDAY = 4,
    THURSDAY = 5,
    FRIDAY = 6,
    SATURDAY = 7,
    DAY = 8,
    WEEK = 9,
    MONTH = 10,
    YEAR = 11
  };

  // Bit flags for `field_set_mask`; one per settable field below.
  enum Fields {
    YEAR_FIELD = 1 << 0,
    MONTH_FIELD = 1 << 1,
    DAY_FIELD = 1 << 2,
    HOUR_FIELD = 1 << 3,
    MINUTE_FIELD = 1 << 4,
    SECOND_FIELD = 1 << 5,
    AMPM_FIELD = 1 << 6,
    ZONE_OFFSET_FIELD = 1 << 7,
    DST_OFFSET_FIELD = 1 << 8,
    RELATION_FIELD = 1 << 9,
    RELATION_TYPE_FIELD = 1 << 10,
    RELATION_DISTANCE_FIELD = 1 << 11
  };

  enum AMPM { AM = 0, PM = 1 };

  enum TimeUnit {
    DAYS = 1,
    WEEKS = 2,
    MONTHS = 3,
    HOURS = 4,
    MINUTES = 5,
    SECONDS = 6,
    YEARS = 7
  };

  // Bit mask of fields which have been set on the struct
  int field_set_mask;

  // Fields describing absolute date fields.
  // Year of the date seen in the text match.
  int year;
  // Month of the year starting with January = 1.
  int month;
  // Day of the month starting with 1.
  int day_of_month;
  // Hour of the day with a range of 0-23,
  // values less than 12 need the AMPM field below or heuristics
  // to definitively determine the time.
  int hour;
  // Minute of the hour with a range of 0-59.
  int minute;
  // Second of the minute with a range of 0-59.
  int second;
  // 0 == AM, 1 == PM
  int ampm;
  // Number of hours offset from UTC this date time is in.
  int zone_offset;
  // Number of hours offset for DST
  int dst_offset;

  // The permutation from now that was made to find the date time.
  Relation relation;
  // The unit of measure of the change to the date time.
  RelationType relation_type;
  // The number of units of change that were made.
  int relation_distance;
};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
diff --git a/annotator/zlib-utils.cc b/annotator/zlib-utils.cc
new file mode 100644
index 0000000..6efe025
--- /dev/null
+++ b/annotator/zlib-utils.cc
@@ -0,0 +1,128 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/zlib-utils.h"
+
+#include <memory>
+
+#include "utils/base/logging.h"
+#include "utils/zlib/zlib.h"
+
+namespace libtextclassifier3 {
+
+// Compress rule fields in the model.
+bool CompressModel(ModelT* model) {
+ std::unique_ptr<ZlibCompressor> zlib_compressor = ZlibCompressor::Instance();
+ if (!zlib_compressor) {
+ TC3_LOG(ERROR) << "Cannot compress model.";
+ return false;
+ }
+
+ // Compress regex rules.
+ if (model->regex_model != nullptr) {
+ for (int i = 0; i < model->regex_model->patterns.size(); i++) {
+ RegexModel_::PatternT* pattern = model->regex_model->patterns[i].get();
+ pattern->compressed_pattern.reset(new CompressedBufferT);
+ zlib_compressor->Compress(pattern->pattern,
+ pattern->compressed_pattern.get());
+ pattern->pattern.clear();
+ }
+ }
+
+ // Compress date-time rules.
+ if (model->datetime_model != nullptr) {
+ for (int i = 0; i < model->datetime_model->patterns.size(); i++) {
+ DatetimeModelPatternT* pattern = model->datetime_model->patterns[i].get();
+ for (int j = 0; j < pattern->regexes.size(); j++) {
+ DatetimeModelPattern_::RegexT* regex = pattern->regexes[j].get();
+ regex->compressed_pattern.reset(new CompressedBufferT);
+ zlib_compressor->Compress(regex->pattern,
+ regex->compressed_pattern.get());
+ regex->pattern.clear();
+ }
+ }
+ for (int i = 0; i < model->datetime_model->extractors.size(); i++) {
+ DatetimeModelExtractorT* extractor =
+ model->datetime_model->extractors[i].get();
+ extractor->compressed_pattern.reset(new CompressedBufferT);
+ zlib_compressor->Compress(extractor->pattern,
+ extractor->compressed_pattern.get());
+ extractor->pattern.clear();
+ }
+ }
+ return true;
+}
+
+bool DecompressModel(ModelT* model) {
+ std::unique_ptr<ZlibDecompressor> zlib_decompressor =
+ ZlibDecompressor::Instance();
+ if (!zlib_decompressor) {
+ TC3_LOG(ERROR) << "Cannot initialize decompressor.";
+ return false;
+ }
+
+ // Decompress regex rules.
+ if (model->regex_model != nullptr) {
+ for (int i = 0; i < model->regex_model->patterns.size(); i++) {
+ RegexModel_::PatternT* pattern = model->regex_model->patterns[i].get();
+ if (!zlib_decompressor->MaybeDecompress(pattern->compressed_pattern.get(),
+ &pattern->pattern)) {
+ TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
+ return false;
+ }
+ pattern->compressed_pattern.reset(nullptr);
+ }
+ }
+
+ // Decompress date-time rules.
+ if (model->datetime_model != nullptr) {
+ for (int i = 0; i < model->datetime_model->patterns.size(); i++) {
+ DatetimeModelPatternT* pattern = model->datetime_model->patterns[i].get();
+ for (int j = 0; j < pattern->regexes.size(); j++) {
+ DatetimeModelPattern_::RegexT* regex = pattern->regexes[j].get();
+ if (!zlib_decompressor->MaybeDecompress(regex->compressed_pattern.get(),
+ &regex->pattern)) {
+ TC3_LOG(ERROR) << "Cannot decompress pattern: " << i << " " << j;
+ return false;
+ }
+ regex->compressed_pattern.reset(nullptr);
+ }
+ }
+ for (int i = 0; i < model->datetime_model->extractors.size(); i++) {
+ DatetimeModelExtractorT* extractor =
+ model->datetime_model->extractors[i].get();
+ if (!zlib_decompressor->MaybeDecompress(
+ extractor->compressed_pattern.get(), &extractor->pattern)) {
+ TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
+ return false;
+ }
+ extractor->compressed_pattern.reset(nullptr);
+ }
+ }
+ return true;
+}
+
+std::string CompressSerializedModel(const std::string& model) {
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(model.c_str());
+ TC3_CHECK(unpacked_model != nullptr);
+ TC3_CHECK(CompressModel(unpacked_model.get()));
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+ return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+}
+
+} // namespace libtextclassifier3
diff --git a/annotator/zlib-utils.h b/annotator/zlib-utils.h
new file mode 100644
index 0000000..462a02b
--- /dev/null
+++ b/annotator/zlib-utils.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
// Functions to compress and decompress low entropy entries in the model.

#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_ZLIB_UTILS_H_
#define LIBTEXTCLASSIFIER_ANNOTATOR_ZLIB_UTILS_H_

#include "annotator/model_generated.h"

namespace libtextclassifier3 {

// Compresses regex and datetime rules in the model in place; the
// uncompressed pattern fields are cleared. Returns false if a compressor
// cannot be created.
bool CompressModel(ModelT* model);

// Decompresses regex and datetime rules in the model in place; the
// compressed buffers are released. Returns false if the decompressor cannot
// be created or any pattern fails to decompress.
bool DecompressModel(ModelT* model);

// Compresses regex and datetime rules in a serialized model and returns the
// re-serialized buffer. Check-fails on an unpackable input.
std::string CompressSerializedModel(const std::string& model);

}  // namespace libtextclassifier3

#endif  // LIBTEXTCLASSIFIER_ANNOTATOR_ZLIB_UTILS_H_
diff --git a/annotator/zlib-utils_test.cc b/annotator/zlib-utils_test.cc
new file mode 100644
index 0000000..7a8d775
--- /dev/null
+++ b/annotator/zlib-utils_test.cc
@@ -0,0 +1,99 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/zlib-utils.h"
+
+#include <memory>
+
+#include "annotator/model_generated.h"
+#include "utils/zlib/zlib.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+
// Round-trip test: builds a model with known regex/datetime patterns,
// compresses it, verifies the plain-text fields are cleared, decompresses
// the packed buffers directly, and finally restores the model in place.
TEST(ZlibUtilsTest, CompressModel) {
  // Build a model with two regex patterns, one datetime regex and one
  // datetime extractor.
  ModelT model;
  model.regex_model.reset(new RegexModelT);
  model.regex_model->patterns.emplace_back(new RegexModel_::PatternT);
  model.regex_model->patterns.back()->pattern = "this is a test pattern";
  model.regex_model->patterns.emplace_back(new RegexModel_::PatternT);
  model.regex_model->patterns.back()->pattern = "this is a second test pattern";

  model.datetime_model.reset(new DatetimeModelT);
  model.datetime_model->patterns.emplace_back(new DatetimeModelPatternT);
  model.datetime_model->patterns.back()->regexes.emplace_back(
      new DatetimeModelPattern_::RegexT);
  model.datetime_model->patterns.back()->regexes.back()->pattern =
      "an example datetime pattern";
  model.datetime_model->extractors.emplace_back(new DatetimeModelExtractorT);
  model.datetime_model->extractors.back()->pattern =
      "an example datetime extractor";

  // Compress the model.
  EXPECT_TRUE(CompressModel(&model));

  // Sanity check that uncompressed field is removed.
  EXPECT_TRUE(model.regex_model->patterns[0]->pattern.empty());
  EXPECT_TRUE(model.regex_model->patterns[1]->pattern.empty());
  EXPECT_TRUE(model.datetime_model->patterns[0]->regexes[0]->pattern.empty());
  EXPECT_TRUE(model.datetime_model->extractors[0]->pattern.empty());

  // Pack and load the model.
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, &model));
  const Model* compressed_model =
      GetModel(reinterpret_cast<const char*>(builder.GetBufferPointer()));
  ASSERT_TRUE(compressed_model != nullptr);

  // Decompress the fields again and check that they match the original.
  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
  ASSERT_TRUE(decompressor != nullptr);
  std::string uncompressed_pattern;
  EXPECT_TRUE(decompressor->MaybeDecompress(
      compressed_model->regex_model()->patterns()->Get(0)->compressed_pattern(),
      &uncompressed_pattern));
  EXPECT_EQ(uncompressed_pattern, "this is a test pattern");
  EXPECT_TRUE(decompressor->MaybeDecompress(
      compressed_model->regex_model()->patterns()->Get(1)->compressed_pattern(),
      &uncompressed_pattern));
  EXPECT_EQ(uncompressed_pattern, "this is a second test pattern");
  EXPECT_TRUE(decompressor->MaybeDecompress(compressed_model->datetime_model()
                                                ->patterns()
                                                ->Get(0)
                                                ->regexes()
                                                ->Get(0)
                                                ->compressed_pattern(),
                                            &uncompressed_pattern));
  EXPECT_EQ(uncompressed_pattern, "an example datetime pattern");
  EXPECT_TRUE(decompressor->MaybeDecompress(compressed_model->datetime_model()
                                                ->extractors()
                                                ->Get(0)
                                                ->compressed_pattern(),
                                            &uncompressed_pattern));
  EXPECT_EQ(uncompressed_pattern, "an example datetime extractor");

  // DecompressModel restores the in-memory model's plain-text fields.
  EXPECT_TRUE(DecompressModel(&model));
  EXPECT_EQ(model.regex_model->patterns[0]->pattern, "this is a test pattern");
  EXPECT_EQ(model.regex_model->patterns[1]->pattern,
            "this is a second test pattern");
  EXPECT_EQ(model.datetime_model->patterns[0]->regexes[0]->pattern,
            "an example datetime pattern");
  EXPECT_EQ(model.datetime_model->extractors[0]->pattern,
            "an example datetime extractor");
}
+
+} // namespace libtextclassifier3