diff options
Diffstat (limited to 'annotator')
44 files changed, 12091 insertions, 0 deletions
diff --git a/annotator/annotator.cc b/annotator/annotator.cc new file mode 100644 index 0000000..2be9d3c --- /dev/null +++ b/annotator/annotator.cc @@ -0,0 +1,1695 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "annotator/annotator.h" + +#include <algorithm> +#include <cctype> +#include <cmath> +#include <iterator> +#include <numeric> + +#include "utils/base/logging.h" +#include "utils/checksum.h" +#include "utils/math/softmax.h" +#include "utils/utf8/unicodetext.h" + +namespace libtextclassifier3 { +const std::string& Annotator::kOtherCollection = + *[]() { return new std::string("other"); }(); +const std::string& Annotator::kPhoneCollection = + *[]() { return new std::string("phone"); }(); +const std::string& Annotator::kAddressCollection = + *[]() { return new std::string("address"); }(); +const std::string& Annotator::kDateCollection = + *[]() { return new std::string("date"); }(); +const std::string& Annotator::kUrlCollection = + *[]() { return new std::string("url"); }(); +const std::string& Annotator::kFlightCollection = + *[]() { return new std::string("flight"); }(); +const std::string& Annotator::kEmailCollection = + *[]() { return new std::string("email"); }(); +const std::string& Annotator::kIbanCollection = + *[]() { return new std::string("iban"); }(); +const std::string& Annotator::kPaymentCardCollection = + *[]() { return new std::string("payment_card"); }(); 
const std::string& Annotator::kIsbnCollection =
    *[]() { return new std::string("isbn"); }();
const std::string& Annotator::kTrackingNumberCollection =
    *[]() { return new std::string("tracking_number"); }();

namespace {

// Verifies the flatbuffer at [addr, addr+size) against the Model schema and
// returns the typed root pointer, or nullptr if verification fails. The
// returned Model aliases the caller's buffer; no ownership is transferred.
const Model* LoadAndVerifyModel(const void* addr, int size) {
  flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
  if (VerifyModelBuffer(verifier)) {
    return GetModel(addr);
  } else {
    return nullptr;
  }
}

// If lib is not nullptr, just returns lib. Otherwise, if lib is nullptr, will
// create a new instance, assign ownership to owned_lib, and return it.
const UniLib* MaybeCreateUnilib(const UniLib* lib,
                                std::unique_ptr<UniLib>* owned_lib) {
  if (lib) {
    return lib;
  } else {
    owned_lib->reset(new UniLib);
    return owned_lib->get();
  }
}

// As above, but for CalendarLib.
const CalendarLib* MaybeCreateCalendarlib(
    const CalendarLib* lib, std::unique_ptr<CalendarLib>* owned_lib) {
  if (lib) {
    return lib;
  } else {
    owned_lib->reset(new CalendarLib);
    return owned_lib->get();
  }
}

}  // namespace

// Lazily builds (and caches) the TFLite interpreter for the selection model.
// Returns nullptr if interpreter construction fails; subsequent calls return
// the same cached pointer.
tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
  if (!selection_interpreter_) {
    TC3_CHECK(selection_executor_);
    selection_interpreter_ = selection_executor_->CreateInterpreter();
    if (!selection_interpreter_) {
      TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
    }
  }
  return selection_interpreter_.get();
}

// As above, but for the classification model.
tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
  if (!classification_interpreter_) {
    TC3_CHECK(classification_executor_);
    classification_interpreter_ = classification_executor_->CreateInterpreter();
    if (!classification_interpreter_) {
      TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
    }
  }
  return classification_interpreter_.get();
}

// Creates an Annotator from a caller-owned buffer. The buffer must outlive
// the returned Annotator. Returns nullptr on verification or initialization
// failure.
std::unique_ptr<Annotator> Annotator::FromUnownedBuffer(
    const char* buffer, int size, const UniLib* unilib,
    const CalendarLib* calendarlib) {
  const Model* model =
      LoadAndVerifyModel(buffer, size);
  if (model == nullptr) {
    return nullptr;
  }

  auto classifier =
      std::unique_ptr<Annotator>(new Annotator(model, unilib, calendarlib));
  if (!classifier->IsInitialized()) {
    return nullptr;
  }

  return classifier;
}

// Creates an Annotator that takes ownership of the memory-mapped model.
// Returns nullptr if the mmap is invalid, the flatbuffer fails verification,
// or initialization fails.
std::unique_ptr<Annotator> Annotator::FromScopedMmap(
    std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib,
    const CalendarLib* calendarlib) {
  if (!(*mmap)->handle().ok()) {
    TC3_VLOG(1) << "Mmap failed.";
    return nullptr;
  }

  const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
                                          (*mmap)->handle().num_bytes());
  if (!model) {
    TC3_LOG(ERROR) << "Model verification failed.";
    return nullptr;
  }

  auto classifier = std::unique_ptr<Annotator>(
      new Annotator(mmap, model, unilib, calendarlib));
  if (!classifier->IsInitialized()) {
    return nullptr;
  }

  return classifier;
}

// Convenience factory: maps [offset, offset+size) of the given fd.
std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
    int fd, int offset, int size, const UniLib* unilib,
    const CalendarLib* calendarlib) {
  std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
  return FromScopedMmap(&mmap, unilib, calendarlib);
}

// Convenience factory: maps the whole fd.
std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
    int fd, const UniLib* unilib, const CalendarLib* calendarlib) {
  std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
  return FromScopedMmap(&mmap, unilib, calendarlib);
}

// Convenience factory: maps the model file at 'path'.
std::unique_ptr<Annotator> Annotator::FromPath(const std::string& path,
                                               const UniLib* unilib,
                                               const CalendarLib* calendarlib) {
  std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
  return FromScopedMmap(&mmap, unilib, calendarlib);
}

// Constructor taking ownership of the mmap. Falls back to internally-owned
// UniLib/CalendarLib instances when the caller passes nullptr.
// (Constructor body continues in the next chunk.)
Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
                     const UniLib* unilib, const CalendarLib* calendarlib)
    : model_(model),
      mmap_(std::move(*mmap)),
      owned_unilib_(nullptr),
      unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
      owned_calendarlib_(nullptr),
      calendarlib_(MaybeCreateCalendarlib(calendarlib,
                                          // (continuation of the constructor
                                          // started in the previous chunk)
                                          &owned_calendarlib_)) {
  ValidateAndInitialize();
}

// Constructor for a caller-owned (unowned) model buffer; see FromUnownedBuffer.
Annotator::Annotator(const Model* model, const UniLib* unilib,
                     const CalendarLib* calendarlib)
    : model_(model),
      owned_unilib_(nullptr),
      unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
      owned_calendarlib_(nullptr),
      calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
  ValidateAndInitialize();
}

// Validates the flatbuffer model's options/sub-models for the enabled modes
// and builds the executors, feature processors, regex patterns and datetime
// parser. Sets initialized_ = true only if every required piece is present
// and loads successfully; any early return leaves the Annotator unusable.
void Annotator::ValidateAndInitialize() {
  initialized_ = false;

  if (model_ == nullptr) {
    TC3_LOG(ERROR) << "No model specified.";
    return;
  }

  const bool model_enabled_for_annotation =
      (model_->triggering_options() != nullptr &&
       (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
  const bool model_enabled_for_classification =
      (model_->triggering_options() != nullptr &&
       (model_->triggering_options()->enabled_modes() &
        ModeFlag_CLASSIFICATION));
  const bool model_enabled_for_selection =
      (model_->triggering_options() != nullptr &&
       (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));

  // Annotation requires the selection model.
  if (model_enabled_for_annotation || model_enabled_for_selection) {
    if (!model_->selection_options()) {
      TC3_LOG(ERROR) << "No selection options.";
      return;
    }
    if (!model_->selection_feature_options()) {
      TC3_LOG(ERROR) << "No selection feature options.";
      return;
    }
    // NOTE(review): bounds-sensitive feature options are required here even
    // though the extraction code below also supports the click-context path;
    // presumably all shipped models carry them — confirm.
    if (!model_->selection_feature_options()->bounds_sensitive_features()) {
      TC3_LOG(ERROR) << "No selection bounds sensitive feature options.";
      return;
    }
    if (!model_->selection_model()) {
      TC3_LOG(ERROR) << "No selection model.";
      return;
    }
    selection_executor_ = ModelExecutor::FromBuffer(model_->selection_model());
    if (!selection_executor_) {
      TC3_LOG(ERROR) << "Could not initialize selection executor.";
      return;
    }
    selection_feature_processor_.reset(
        new FeatureProcessor(model_->selection_feature_options(), unilib_));
  }

  // Annotation requires the classification model for conflict resolution and
  // scoring.
  // Selection requires the classification model for conflict resolution.
  if (model_enabled_for_annotation || model_enabled_for_classification ||
      model_enabled_for_selection) {
    if (!model_->classification_options()) {
      TC3_LOG(ERROR) << "No classification options.";
      return;
    }

    if (!model_->classification_feature_options()) {
      TC3_LOG(ERROR) << "No classification feature options.";
      return;
    }

    if (!model_->classification_feature_options()
             ->bounds_sensitive_features()) {
      TC3_LOG(ERROR) << "No classification bounds sensitive feature options.";
      return;
    }
    if (!model_->classification_model()) {
      TC3_LOG(ERROR) << "No clf model.";
      return;
    }

    classification_executor_ =
        ModelExecutor::FromBuffer(model_->classification_model());
    if (!classification_executor_) {
      TC3_LOG(ERROR) << "Could not initialize classification executor.";
      return;
    }

    classification_feature_processor_.reset(new FeatureProcessor(
        model_->classification_feature_options(), unilib_));
  }

  // The embeddings need to be specified if the model is to be used for
  // classification or selection.
  if (model_enabled_for_annotation || model_enabled_for_classification ||
      model_enabled_for_selection) {
    if (!model_->embedding_model()) {
      TC3_LOG(ERROR) << "No embedding model.";
      return;
    }

    // Check that the embedding size of the selection and classification model
    // matches, as they are using the same embeddings.
    if (model_enabled_for_selection &&
        (model_->selection_feature_options()->embedding_size() !=
             model_->classification_feature_options()->embedding_size() ||
         model_->selection_feature_options()->embedding_quantization_bits() !=
             model_->classification_feature_options()
                 ->embedding_quantization_bits())) {
      TC3_LOG(ERROR) << "Mismatching embedding size/quantization.";
      return;
    }

    embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
        model_->embedding_model(),
        model_->classification_feature_options()->embedding_size(),
        model_->classification_feature_options()
            ->embedding_quantization_bits());
    if (!embedding_executor_) {
      TC3_LOG(ERROR) << "Could not initialize embedding executor.";
      return;
    }
  }

  // One shared decompressor instance for both the regex and datetime models.
  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
  if (model_->regex_model()) {
    if (!InitializeRegexModel(decompressor.get())) {
      TC3_LOG(ERROR) << "Could not initialize regex model.";
      return;
    }
  }

  if (model_->datetime_model()) {
    datetime_parser_ = DatetimeParser::Instance(
        model_->datetime_model(), *unilib_, *calendarlib_, decompressor.get());
    if (!datetime_parser_) {
      TC3_LOG(ERROR) << "Could not initialize datetime parser.";
      return;
    }
  }

  // Collections listed in output_options are suppressed from the respective
  // API results (see FilteredFor* below).
  if (model_->output_options()) {
    if (model_->output_options()->filtered_collections_annotation()) {
      for (const auto collection :
           *model_->output_options()->filtered_collections_annotation()) {
        filtered_collections_annotation_.insert(collection->str());
      }
    }
    if (model_->output_options()->filtered_collections_classification()) {
      for (const auto collection :
           *model_->output_options()->filtered_collections_classification()) {
        filtered_collections_classification_.insert(collection->str());
      }
    }
    if (model_->output_options()->filtered_collections_selection()) {
      for (const auto collection :
           *model_->output_options()->filtered_collections_selection()) {
        filtered_collections_selection_.insert(collection->str());
      }
    }
  }

  initialized_ = true;
}

// Decompresses and compiles all regex patterns from the model, recording per
// pattern which modes (annotation/classification/selection) it participates
// in. Returns false on any compilation failure.
bool Annotator::InitializeRegexModel(ZlibDecompressor* decompressor) {
  if (!model_->regex_model()->patterns()) {
    return true;
  }

  // Initialize pattern recognizers.
  int regex_pattern_id = 0;
  for (const auto& regex_pattern : *model_->regex_model()->patterns()) {
    std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
        UncompressMakeRegexPattern(*unilib_, regex_pattern->pattern(),
                                   regex_pattern->compressed_pattern(),
                                   decompressor);
    if (!compiled_pattern) {
      TC3_LOG(INFO) << "Failed to load regex pattern";
      return false;
    }

    if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
      annotation_regex_patterns_.push_back(regex_pattern_id);
    }
    if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
      classification_regex_patterns_.push_back(regex_pattern_id);
    }
    if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
      selection_regex_patterns_.push_back(regex_pattern_id);
    }
    regex_patterns_.push_back({
        regex_pattern->collection_name()->str(),
        regex_pattern->target_classification_score(),
        regex_pattern->priority_score(),
        std::move(compiled_pattern),
        regex_pattern->verification_options(),
    });
    if (regex_pattern->use_approximate_matching()) {
      regex_approximate_match_pattern_ids_.insert(regex_pattern_id);
    }
    ++regex_pattern_id;
  }

  return true;
}

// Builds and installs the (optional) knowledge engine from its serialized
// config. (Function body continues in the next chunk.)
bool Annotator::InitializeKnowledgeEngine(
    const std::string& serialized_config) {
  std::unique_ptr<KnowledgeEngine> knowledge_engine(
      new KnowledgeEngine(unilib_));
  if (!knowledge_engine->Initialize(serialized_config)) {
    TC3_LOG(ERROR) <<
        // (continuation of InitializeKnowledgeEngine from the previous chunk)
        "Failed to initialize the knowledge engine.";
    return false;
  }
  knowledge_engine_ = std::move(knowledge_engine);
  return true;
}

namespace {

// Counts the digit codepoints inside [selection_indices.first,
// selection_indices.second) of 'str', iterating codepoint-by-codepoint.
// NOTE(review): isdigit() is invoked on a full Unicode codepoint; per the C
// standard its argument must be representable as unsigned char (or EOF), so
// codepoints > 0x7F are formally undefined behavior here — consider guarding
// with an ASCII check or using the UniLib digit test.
int CountDigits(const std::string& str, CodepointSpan selection_indices) {
  int count = 0;
  int i = 0;
  const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
  for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
    if (i >= selection_indices.first && i < selection_indices.second &&
        isdigit(*it)) {
      ++count;
    }
  }
  return count;
}

// Returns the UTF-8 substring of 'context' covering the codepoint range
// 'selection_indices'. Indices are codepoint offsets, not byte offsets.
std::string ExtractSelection(const std::string& context,
                             CodepointSpan selection_indices) {
  const UnicodeText context_unicode =
      UTF8ToUnicodeText(context, /*do_copy=*/false);
  auto selection_begin = context_unicode.begin();
  std::advance(selection_begin, selection_indices.first);
  auto selection_end = context_unicode.begin();
  std::advance(selection_end, selection_indices.second);
  return UnicodeText::UTF8Substring(selection_begin, selection_end);
}

// Applies the optional per-pattern verification to a regex match; currently
// only a Luhn checksum check is supported. Absent options always verify.
bool VerifyCandidate(const VerificationOptions* verification_options,
                     const std::string& match) {
  if (!verification_options) {
    return true;
  }
  if (verification_options->verify_luhn_checksum() &&
      !VerifyLuhnChecksum(match)) {
    return false;
  }
  return true;
}

}  // namespace

namespace internal {
// Helper function, which if the initial 'span' contains only white-spaces,
// moves the selection to a single-codepoint selection on a left or right side
// of this space.
CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
                                            const UnicodeText& context_unicode,
                                            const UniLib& unilib) {
  TC3_CHECK(ValidNonEmptySpan(span));

  UnicodeText::const_iterator it;

  // Check that the current selection is all whitespaces.
  it = context_unicode.begin();
  std::advance(it, span.first);
  for (int i = 0; i < (span.second - span.first); ++i, ++it) {
    if (!unilib.IsWhitespace(*it)) {
      return span;
    }
  }

  CodepointSpan result;

  // Try moving left.
  result = span;
  it = context_unicode.begin();
  std::advance(it, span.first);
  while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
    --result.first;
    --it;
  }
  result.second = result.first + 1;
  if (!unilib.IsWhitespace(*it)) {
    return result;
  }

  // If moving left didn't find a non-whitespace character, just return the
  // original span.
  return span;
}
}  // namespace internal

// True if the span's top classification is in the annotation filter set.
bool Annotator::FilteredForAnnotation(const AnnotatedSpan& span) const {
  return !span.classification.empty() &&
         filtered_collections_annotation_.find(
             span.classification[0].collection) !=
             filtered_collections_annotation_.end();
}

// True if the classification's collection is in the classification filter set.
bool Annotator::FilteredForClassification(
    const ClassificationResult& classification) const {
  return filtered_collections_classification_.find(classification.collection) !=
         filtered_collections_classification_.end();
}

// True if the span's top classification is in the selection filter set.
bool Annotator::FilteredForSelection(const AnnotatedSpan& span) const {
  return !span.classification.empty() &&
         filtered_collections_selection_.find(
             span.classification[0].collection) !=
             filtered_collections_selection_.end();
}

// Suggests an expanded selection around 'click_indices' (codepoint offsets
// into 'context'). Merges candidates from the ML model, regexes, the datetime
// parser and the knowledge engine, resolves conflicts, and returns the best
// overlapping span — or the original indices on any failure or filter hit.
CodepointSpan Annotator::SuggestSelection(
    const std::string& context, CodepointSpan click_indices,
    const SelectionOptions& options) const {
  CodepointSpan original_click_indices = click_indices;
  if (!initialized_) {
    TC3_LOG(ERROR) << "Not initialized";
    return original_click_indices;
  }
  if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
    return original_click_indices;
  }

  const UnicodeText context_unicode = UTF8ToUnicodeText(context,
                                                        /*do_copy=*/false);

  if (!context_unicode.is_valid()) {
    return original_click_indices;
  }

  const int context_codepoint_size = context_unicode.size_codepoints();

  if (click_indices.first < 0 || click_indices.second < 0 ||
      click_indices.first >= context_codepoint_size ||
      click_indices.second > context_codepoint_size ||
      click_indices.first >= click_indices.second) {
    TC3_VLOG(1) << "Trying to run SuggestSelection with invalid indices: "
                << click_indices.first << " " << click_indices.second;
    return original_click_indices;
  }

  if (model_->snap_whitespace_selections()) {
    // We want to expand a purely white-space selection to a multi-selection it
    // would've been part of. But with this feature disabled we would do a no-
    // op, because no token is found. Therefore, we need to modify the
    // 'click_indices' a bit to include a part of the token, so that the click-
    // finding logic finds the clicked token correctly. This modification is
    // done by the following function. Note, that it's enough to check the left
    // side of the current selection, because if the white-space is a part of a
    // multi-selection, necessarily both tokens - on the left and the right
    // sides need to be selected. Thus snapping only to the left is sufficient
    // (there's a check at the bottom that makes sure that if we snap to the
    // left token but the result does not contain the initial white-space,
    // returns the original indices).
    click_indices = internal::SnapLeftIfWhitespaceSelection(
        click_indices, context_unicode, *unilib_);
  }

  std::vector<AnnotatedSpan> candidates;
  InterpreterManager interpreter_manager(selection_executor_.get(),
                                         classification_executor_.get());
  std::vector<Token> tokens;
  if (!ModelSuggestSelection(context_unicode, click_indices,
                             &interpreter_manager, &tokens, &candidates)) {
    TC3_LOG(ERROR) << "Model suggest selection failed.";
    return original_click_indices;
  }
  if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates)) {
    TC3_LOG(ERROR) << "Regex suggest selection failed.";
    return original_click_indices;
  }
  if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
                     /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
                     options.locales, ModeFlag_SELECTION, &candidates)) {
    TC3_LOG(ERROR) << "Datetime suggest selection failed.";
    return original_click_indices;
  }
  if (knowledge_engine_ && !knowledge_engine_->Chunk(context, &candidates)) {
    TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
    return original_click_indices;
  }

  // Sort candidates according to their position in the input, so that the next
  // code can assume that any connected component of overlapping spans forms a
  // contiguous block.
  std::sort(candidates.begin(), candidates.end(),
            [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
              return a.span.first < b.span.first;
            });

  std::vector<int> candidate_indices;
  if (!ResolveConflicts(candidates, context, tokens, &interpreter_manager,
                        &candidate_indices)) {
    TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
    return original_click_indices;
  }

  for (const int i : candidate_indices) {
    if (SpansOverlap(candidates[i].span, click_indices) &&
        SpansOverlap(candidates[i].span, original_click_indices)) {
      // Run model classification if not present but requested and there's a
      // classification collection filter specified.
      // (continuation of Annotator::SuggestSelection from the previous chunk)
      if (candidates[i].classification.empty() &&
          model_->selection_options()->always_classify_suggested_selection() &&
          !filtered_collections_selection_.empty()) {
        if (!ModelClassifyText(
                context, candidates[i].span, &interpreter_manager,
                /*embedding_cache=*/nullptr, &candidates[i].classification)) {
          return original_click_indices;
        }
      }

      // Ignore if span classification is filtered.
      if (FilteredForSelection(candidates[i])) {
        return original_click_indices;
      }

      return candidates[i].span;
    }
  }

  return original_click_indices;
}

namespace {
// Helper function that returns the index of the first candidate that
// transitively does not overlap with the candidate on 'start_index'. If the end
// of 'candidates' is reached, it returns the index that points right behind the
// array.
int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
                                 int start_index) {
  int first_non_overlapping = start_index + 1;
  CodepointSpan conflicting_span = candidates[start_index].span;
  while (
      first_non_overlapping < candidates.size() &&
      SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) {
    // Grow the span to include the current one.
    conflicting_span.second = std::max(
        conflicting_span.second, candidates[first_non_overlapping].span.second);

    ++first_non_overlapping;
  }
  return first_non_overlapping;
}
}  // namespace

// Walks the position-sorted 'candidates', keeps non-conflicting ones as-is
// and delegates each group of transitively-overlapping spans to
// ResolveConflict. Indices of the winners are appended to 'result'.
// Precondition: candidates are sorted by span start (see SuggestSelection).
bool Annotator::ResolveConflicts(const std::vector<AnnotatedSpan>& candidates,
                                 const std::string& context,
                                 const std::vector<Token>& cached_tokens,
                                 InterpreterManager* interpreter_manager,
                                 std::vector<int>* result) const {
  result->clear();
  result->reserve(candidates.size());
  for (int i = 0; i < candidates.size();) {
    int first_non_overlapping =
        FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i);

    const bool conflict_found = first_non_overlapping != (i + 1);
    if (conflict_found) {
      std::vector<int> candidate_indices;
      if (!ResolveConflict(context, cached_tokens, candidates, i,
                           first_non_overlapping, interpreter_manager,
                           &candidate_indices)) {
        return false;
      }
      result->insert(result->end(), candidate_indices.begin(),
                     candidate_indices.end());
    } else {
      result->push_back(i);
    }

    // Skip over the whole conflicting group/go to next candidate.
    i = first_non_overlapping;
  }
  return true;
}

namespace {
// True if the top classification is the catch-all "other" collection.
inline bool ClassifiedAsOther(
    const std::vector<ClassificationResult>& classification) {
  return !classification.empty() &&
         classification[0].collection == Annotator::kOtherCollection;
}

// Priority of a candidate for conflict resolution; "other" ranks below
// everything (-1.0).
// NOTE(review): an empty 'classification' is not ClassifiedAsOther and would
// index classification[0] — callers must guarantee non-emptiness; both call
// sites below do.
float GetPriorityScore(
    const std::vector<ClassificationResult>& classification) {
  if (!ClassifiedAsOther(classification)) {
    return classification[0].priority_score;
  } else {
    return -1.0;
  }
}
}  // namespace

// Resolves one group of overlapping candidates [start_index, end_index):
// scores each candidate (classifying model spans on demand), then greedily
// keeps the highest-priority candidates that don't conflict with already
// chosen ones. Chosen indices are returned sorted by span position.
bool Annotator::ResolveConflict(const std::string& context,
                                const std::vector<Token>& cached_tokens,
                                const std::vector<AnnotatedSpan>& candidates,
                                int start_index, int end_index,
                                InterpreterManager* interpreter_manager,
                                std::vector<int>* chosen_indices) const {
  std::vector<int> conflicting_indices;
  std::unordered_map<int, float> scores;
  for (int i = start_index; i < end_index; ++i) {
    conflicting_indices.push_back(i);
    if (!candidates[i].classification.empty()) {
      scores[i] = GetPriorityScore(candidates[i].classification);
      continue;
    }

    // OPTIMIZATION: So that we don't have to classify all the ML model
    // spans apriori, we wait until we get here, when they conflict with
    // something and we need the actual classification scores. So if the
    // candidate conflicts and comes from the model, we need to run a
    // classification to determine its priority:
    std::vector<ClassificationResult> classification;
    if (!ModelClassifyText(context, cached_tokens, candidates[i].span,
                           interpreter_manager,
                           /*embedding_cache=*/nullptr, &classification)) {
      return false;
    }

    if (!classification.empty()) {
      scores[i] = GetPriorityScore(classification);
    }
  }

  std::sort(conflicting_indices.begin(), conflicting_indices.end(),
            [&scores](int i, int j) { return scores[i] > scores[j]; });

  // Keeps the candidates sorted by their position in the text (their left span
  // index) for fast retrieval down.
  std::set<int, std::function<bool(int, int)>> chosen_indices_set(
      [&candidates](int a, int b) {
        return candidates[a].span.first < candidates[b].span.first;
      });

  // Greedily place the candidates if they don't conflict with the already
  // placed ones.
  for (int i = 0; i < conflicting_indices.size(); ++i) {
    const int considered_candidate = conflicting_indices[i];
    if (!DoesCandidateConflict(considered_candidate, candidates,
                               chosen_indices_set)) {
      chosen_indices_set.insert(considered_candidate);
    }
  }

  *chosen_indices =
      std::vector<int>(chosen_indices_set.begin(), chosen_indices_set.end());

  return true;
}

// Runs the ML selection model around the click: tokenizes the context, builds
// the extraction span, extracts features, chunks, and appends the resulting
// (non-empty, boundary-stripped) spans to 'result'. 'tokens' is an out-param
// reused later by the caller. (Function body continues in the next chunk.)
bool Annotator::ModelSuggestSelection(
    const UnicodeText& context_unicode, CodepointSpan click_indices,
    InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
    std::vector<AnnotatedSpan>* result) const {
  if (model_->triggering_options() == nullptr ||
      !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
    return true;
  }

  int click_pos;
  *tokens = selection_feature_processor_->Tokenize(context_unicode);
  selection_feature_processor_->RetokenizeAndFindClick(
      context_unicode, click_indices,
      selection_feature_processor_->GetOptions()->only_use_line_with_click(),
      tokens, &click_pos);
  if (click_pos == kInvalidIndex) {
    TC3_VLOG(1) << "Could not calculate the click position.";
    return false;
  }

  const int symmetry_context_size =
      model_->selection_options()->symmetry_context_size();
  const FeatureProcessorOptions_::BoundsSensitiveFeatures*
      bounds_sensitive_features = selection_feature_processor_->GetOptions()
                                      ->bounds_sensitive_features();

  // The symmetry context span is the clicked token with symmetry_context_size
  // tokens on either side.
  // (continuation of Annotator::ModelSuggestSelection from the previous chunk)
  const TokenSpan symmetry_context_span = IntersectTokenSpans(
      ExpandTokenSpan(SingleTokenSpan(click_pos),
                      /*num_tokens_left=*/symmetry_context_size,
                      /*num_tokens_right=*/symmetry_context_size),
      {0, tokens->size()});

  // Compute the extraction span based on the model type.
  TokenSpan extraction_span;
  if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
    // The extraction span is the symmetry context span expanded to include
    // max_selection_span tokens on either side, which is how far a selection
    // can stretch from the click, plus a relevant number of tokens outside of
    // the bounds of the selection.
    const int max_selection_span =
        selection_feature_processor_->GetOptions()->max_selection_span();
    extraction_span =
        ExpandTokenSpan(symmetry_context_span,
                        /*num_tokens_left=*/max_selection_span +
                            bounds_sensitive_features->num_tokens_before(),
                        /*num_tokens_right=*/max_selection_span +
                            bounds_sensitive_features->num_tokens_after());
  } else {
    // The extraction span is the symmetry context span expanded to include
    // context_size tokens on either side.
    const int context_size =
        selection_feature_processor_->GetOptions()->context_size();
    extraction_span = ExpandTokenSpan(symmetry_context_span,
                                      /*num_tokens_left=*/context_size,
                                      /*num_tokens_right=*/context_size);
  }
  extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});

  // Bail out silently (success, no candidates) when too few codepoints in the
  // span are supported by the model.
  if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
          *tokens, extraction_span)) {
    return true;
  }

  std::unique_ptr<CachedFeatures> cached_features;
  if (!selection_feature_processor_->ExtractFeatures(
          *tokens, extraction_span,
          /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
          embedding_executor_.get(),
          /*embedding_cache=*/nullptr,
          selection_feature_processor_->EmbeddingSize() +
              selection_feature_processor_->DenseFeaturesCount(),
          &cached_features)) {
    TC3_LOG(ERROR) << "Could not extract features.";
    return false;
  }

  // Produce selection model candidates.
  std::vector<TokenSpan> chunks;
  if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
                  interpreter_manager->SelectionInterpreter(), *cached_features,
                  &chunks)) {
    TC3_LOG(ERROR) << "Could not chunk.";
    return false;
  }

  for (const TokenSpan& chunk : chunks) {
    AnnotatedSpan candidate;
    candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
        context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
    if (model_->selection_options()->strip_unpaired_brackets()) {
      candidate.span =
          StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
    }

    // Only output non-empty spans.
    if (candidate.span.first != candidate.span.second) {
      result->push_back(candidate);
    }
  }
  return true;
}

// Classification entry point without cached tokens: checks that the
// classification mode is enabled and forwards to the full overload with an
// empty token cache (forcing a fresh tokenization).
bool Annotator::ModelClassifyText(
    const std::string& context, CodepointSpan selection_indices,
    InterpreterManager* interpreter_manager,
    FeatureProcessor::EmbeddingCache* embedding_cache,
    std::vector<ClassificationResult>* classification_results) const {
  if (model_->triggering_options() == nullptr ||
      !(model_->triggering_options()->enabled_modes() &
        ModeFlag_CLASSIFICATION)) {
    return true;
  }
  return ModelClassifyText(context, {}, selection_indices, interpreter_manager,
                           embedding_cache, classification_results);
}

namespace internal {
// Copies from 'cached_tokens' the tokens overlapping 'selection_indices' plus
// 'tokens_around_selection_to_copy' extra tokens on each side, clamped to the
// cache bounds. Assumes cached_tokens are sorted by position.
std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
                                    CodepointSpan selection_indices,
                                    TokenSpan tokens_around_selection_to_copy) {
  const auto first_selection_token = std::upper_bound(
      cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
      [](int selection_start, const Token& token) {
        return selection_start < token.end;
      });
  const auto last_selection_token = std::lower_bound(
      cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
      [](const Token& token, int selection_end) {
        return token.start < selection_end;
      });

  const int64 first_token = std::max(
      static_cast<int64>(0),
      static_cast<int64>((first_selection_token - cached_tokens.begin()) -
                         tokens_around_selection_to_copy.first));
  const int64 last_token = std::min(
      static_cast<int64>(cached_tokens.size()),
      static_cast<int64>((last_selection_token - cached_tokens.begin()) +
                         tokens_around_selection_to_copy.second));

  std::vector<Token> tokens;
  tokens.reserve(last_token - first_token);
  for (int i = first_token; i < last_token; ++i) {
    tokens.push_back(cached_tokens[i]);
  }
  return tokens;
}
}  // namespace internal

// Upper bound on how many tokens around the selection classification may
// need, used to size the token-cache copy above.
// (Function body continues in the next chunk.)
TokenSpan Annotator::ClassifyTextUpperBoundNeededTokens() const {
  const FeatureProcessorOptions_::BoundsSensitiveFeatures*
      bounds_sensitive_features =
          classification_feature_processor_->GetOptions()
              ->bounds_sensitive_features();
  if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
    // The extraction span is the selection span expanded to include a relevant
    // number of tokens outside of the bounds of the selection.
    return {bounds_sensitive_features->num_tokens_before(),
            bounds_sensitive_features->num_tokens_after()};
  } else {
    // The extraction span is the clicked token with context_size tokens on
    // either side.
    const int context_size =
        selection_feature_processor_->GetOptions()->context_size();
    return {context_size, context_size};
  }
}

// Full classification path: tokenizes (or reuses cached tokens), extracts
// features around the selection, runs the TFLite classification model, and
// applies per-collection sanity checks. Returns false only on hard errors;
// low-confidence or filtered results are reported as the "other" collection.
bool Annotator::ModelClassifyText(
    const std::string& context, const std::vector<Token>& cached_tokens,
    CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
    FeatureProcessor::EmbeddingCache* embedding_cache,
    std::vector<ClassificationResult>* classification_results) const {
  std::vector<Token> tokens;
  if (cached_tokens.empty()) {
    tokens = classification_feature_processor_->Tokenize(context);
  } else {
    // Reuse the caller's tokens, copying only as much context as the
    // classifier can possibly consume.
    tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
                                        ClassifyTextUpperBoundNeededTokens());
  }

  int click_pos;
  classification_feature_processor_->RetokenizeAndFindClick(
      context, selection_indices,
      classification_feature_processor_->GetOptions()
          ->only_use_line_with_click(),
      &tokens, &click_pos);
  const TokenSpan selection_token_span =
      CodepointSpanToTokenSpan(tokens, selection_indices);
  const int selection_num_tokens = TokenSpanSize(selection_token_span);
  // Overly long selections are classified as "other" without running the
  // model.
  if (model_->classification_options()->max_num_tokens() > 0 &&
      model_->classification_options()->max_num_tokens() <
          selection_num_tokens) {
    *classification_results = {{kOtherCollection, 1.0}};
    return true;
  }

  const FeatureProcessorOptions_::BoundsSensitiveFeatures*
      bounds_sensitive_features =
          classification_feature_processor_->GetOptions()
              ->bounds_sensitive_features();
  if (selection_token_span.first == kInvalidIndex ||
      selection_token_span.second == kInvalidIndex) {
    TC3_LOG(ERROR) << "Could not determine span.";
    return false;
  }

  // Compute the extraction span based on the model type.
  TokenSpan extraction_span;
  if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
    // The extraction span is the selection span expanded to include a relevant
    // number of tokens outside of the bounds of the selection.
    extraction_span = ExpandTokenSpan(
        selection_token_span,
        /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
        /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
  } else {
    if (click_pos == kInvalidIndex) {
      TC3_LOG(ERROR) << "Couldn't choose a click position.";
      return false;
    }
    // The extraction span is the clicked token with context_size tokens on
    // either side.
    const int context_size =
        classification_feature_processor_->GetOptions()->context_size();
    extraction_span = ExpandTokenSpan(SingleTokenSpan(click_pos),
                                      /*num_tokens_left=*/context_size,
                                      /*num_tokens_right=*/context_size);
  }
  extraction_span = IntersectTokenSpans(extraction_span, {0, tokens.size()});

  // If the input is mostly in an unsupported script, fall back to "other".
  if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
          tokens, extraction_span)) {
    *classification_results = {{kOtherCollection, 1.0}};
    return true;
  }

  std::unique_ptr<CachedFeatures> cached_features;
  if (!classification_feature_processor_->ExtractFeatures(
          tokens, extraction_span, selection_indices, embedding_executor_.get(),
          embedding_cache,
          classification_feature_processor_->EmbeddingSize() +
              classification_feature_processor_->DenseFeaturesCount(),
          &cached_features)) {
    TC3_LOG(ERROR) << "Could not extract features.";
    return false;
  }

  std::vector<float> features;
  features.reserve(cached_features->OutputFeaturesSize());
  if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
    cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
                                                          &features);
  } else {
    cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
  }

  // Single-example inference: batch size 1.
  TensorView<float> logits = classification_executor_->ComputeLogits(
      TensorView<float>(features.data(),
                        {1, static_cast<int>(features.size())}),
      interpreter_manager->ClassificationInterpreter());
  if (!logits.is_valid()) {
    TC3_LOG(ERROR) << "Couldn't compute logits.";
    return false;
  }

  if (logits.dims() != 2 || logits.dim(0) != 1 ||
      logits.dim(1) != classification_feature_processor_->NumCollections()) {
    TC3_LOG(ERROR) << "Mismatching output";
    return false;
  }

  const std::vector<float> scores =
      ComputeSoftmax(logits.data(), logits.dim(1));

  // One result per collection, sorted by descending score.
  classification_results->resize(scores.size());
  for (int i = 0; i < scores.size(); i++) {
    (*classification_results)[i] = {
        classification_feature_processor_->LabelToCollection(i), scores[i]};
  }
  std::sort(classification_results->begin(), classification_results->end(),
            [](const ClassificationResult& a, const ClassificationResult& b) {
              return a.score > b.score;
            });

  // Phone class sanity check: demote to "other" when the digit count is
  // outside the configured range.
  if (!classification_results->empty() &&
      classification_results->begin()->collection == kPhoneCollection) {
    const int digit_count = CountDigits(context, selection_indices);
    if (digit_count <
            model_->classification_options()->phone_min_num_digits() ||
        digit_count >
            model_->classification_options()->phone_max_num_digits()) {
      *classification_results = {{kOtherCollection, 1.0}};
    }
  }

  // Address class sanity check.
+ if (!classification_results->empty() && + classification_results->begin()->collection == kAddressCollection) { + if (selection_num_tokens < + model_->classification_options()->address_min_num_tokens()) { + *classification_results = {{kOtherCollection, 1.0}}; + } + } + + return true; +} + +bool Annotator::RegexClassifyText( + const std::string& context, CodepointSpan selection_indices, + ClassificationResult* classification_result) const { + const std::string selection_text = + ExtractSelection(context, selection_indices); + const UnicodeText selection_text_unicode( + UTF8ToUnicodeText(selection_text, /*do_copy=*/false)); + + // Check whether any of the regular expressions match. + for (const int pattern_id : classification_regex_patterns_) { + const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id]; + const std::unique_ptr<UniLib::RegexMatcher> matcher = + regex_pattern.pattern->Matcher(selection_text_unicode); + int status = UniLib::RegexMatcher::kNoError; + bool matches; + if (regex_approximate_match_pattern_ids_.find(pattern_id) != + regex_approximate_match_pattern_ids_.end()) { + matches = matcher->ApproximatelyMatches(&status); + } else { + matches = matcher->Matches(&status); + } + if (status != UniLib::RegexMatcher::kNoError) { + return false; + } + if (matches && + VerifyCandidate(regex_pattern.verification_options, selection_text)) { + *classification_result = {regex_pattern.collection_name, + regex_pattern.target_classification_score, + regex_pattern.priority_score}; + return true; + } + if (status != UniLib::RegexMatcher::kNoError) { + TC3_LOG(ERROR) << "Cound't match regex: " << pattern_id; + } + } + + return false; +} + +bool Annotator::DatetimeClassifyText( + const std::string& context, CodepointSpan selection_indices, + const ClassificationOptions& options, + ClassificationResult* classification_result) const { + if (!datetime_parser_) { + return false; + } + + const std::string selection_text = + ExtractSelection(context, 
                       selection_indices);

  std::vector<DatetimeParseResultSpan> datetime_spans;
  // Anchored parse: the datetime must start/end exactly at the selection
  // boundaries.
  if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
                               options.reference_timezone, options.locales,
                               ModeFlag_CLASSIFICATION,
                               /*anchor_start_end=*/true, &datetime_spans)) {
    TC3_LOG(ERROR) << "Error during parsing datetime.";
    return false;
  }
  for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
    // Only consider the result valid if the selection and extracted datetime
    // spans exactly match.
    if (std::make_pair(datetime_span.span.first + selection_indices.first,
                       datetime_span.span.second + selection_indices.first) ==
        selection_indices) {
      *classification_result = {kDateCollection,
                                datetime_span.target_classification_score};
      classification_result->datetime_parse_result = datetime_span.data;
      return true;
    }
  }
  return false;
}

// Public entry point for classification. Tries each classifier in priority
// order — knowledge engine, regexes, datetime parser, then the ML model — and
// returns the first successful result (or "other" if that result is
// filtered). Returns an empty vector on error or no classification.
std::vector<ClassificationResult> Annotator::ClassifyText(
    const std::string& context, CodepointSpan selection_indices,
    const ClassificationOptions& options) const {
  if (!initialized_) {
    TC3_LOG(ERROR) << "Not initialized";
    return {};
  }

  if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
    return {};
  }

  // Reject malformed UTF-8 input outright.
  if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) {
    return {};
  }

  if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) {
    TC3_VLOG(1) << "Trying to run ClassifyText with invalid indices: "
                << std::get<0>(selection_indices) << " "
                << std::get<1>(selection_indices);
    return {};
  }

  // Try the knowledge engine.
  ClassificationResult knowledge_result;
  if (knowledge_engine_ && knowledge_engine_->ClassifyText(
                               context, selection_indices, &knowledge_result)) {
    if (!FilteredForClassification(knowledge_result)) {
      return {knowledge_result};
    } else {
      return {{kOtherCollection, 1.0}};
    }
  }

  // Try the regular expression models.
+ ClassificationResult regex_result; + if (RegexClassifyText(context, selection_indices, ®ex_result)) { + if (!FilteredForClassification(regex_result)) { + return {regex_result}; + } else { + return {{kOtherCollection, 1.0}}; + } + } + + // Try the date model. + ClassificationResult datetime_result; + if (DatetimeClassifyText(context, selection_indices, options, + &datetime_result)) { + if (!FilteredForClassification(datetime_result)) { + return {datetime_result}; + } else { + return {{kOtherCollection, 1.0}}; + } + } + + // Fallback to the model. + std::vector<ClassificationResult> model_result; + + InterpreterManager interpreter_manager(selection_executor_.get(), + classification_executor_.get()); + if (ModelClassifyText(context, selection_indices, &interpreter_manager, + /*embedding_cache=*/nullptr, &model_result) && + !model_result.empty()) { + if (!FilteredForClassification(model_result[0])) { + return model_result; + } else { + return {{kOtherCollection, 1.0}}; + } + } + + // No classifications. + return {}; +} + +bool Annotator::ModelAnnotate(const std::string& context, + InterpreterManager* interpreter_manager, + std::vector<Token>* tokens, + std::vector<AnnotatedSpan>* result) const { + if (model_->triggering_options() == nullptr || + !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) { + return true; + } + + const UnicodeText context_unicode = UTF8ToUnicodeText(context, + /*do_copy=*/false); + std::vector<UnicodeTextRange> lines; + if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) { + lines.push_back({context_unicode.begin(), context_unicode.end()}); + } else { + lines = selection_feature_processor_->SplitContext(context_unicode); + } + + const float min_annotate_confidence = + (model_->triggering_options() != nullptr + ? 
 model_->triggering_options()->min_annotate_confidence()
           : 0.f);

  // The embedding cache is shared across all lines of this annotation call.
  FeatureProcessor::EmbeddingCache embedding_cache;
  for (const UnicodeTextRange& line : lines) {
    const std::string line_str =
        UnicodeText::UTF8Substring(line.first, line.second);

    *tokens = selection_feature_processor_->Tokenize(line_str);
    selection_feature_processor_->RetokenizeAndFindClick(
        line_str, {0, std::distance(line.first, line.second)},
        selection_feature_processor_->GetOptions()->only_use_line_with_click(),
        tokens,
        /*click_pos=*/nullptr);
    const TokenSpan full_line_span = {0, tokens->size()};

    // TODO(zilka): Add support for greater granularity of this check.
    if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
            *tokens, full_line_span)) {
      continue;
    }

    std::unique_ptr<CachedFeatures> cached_features;
    if (!selection_feature_processor_->ExtractFeatures(
            *tokens, full_line_span,
            /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
            embedding_executor_.get(),
            /*embedding_cache=*/nullptr,
            selection_feature_processor_->EmbeddingSize() +
                selection_feature_processor_->DenseFeaturesCount(),
            &cached_features)) {
      TC3_LOG(ERROR) << "Could not extract features.";
      return false;
    }

    std::vector<TokenSpan> local_chunks;
    if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span,
                    interpreter_manager->SelectionInterpreter(),
                    *cached_features, &local_chunks)) {
      TC3_LOG(ERROR) << "Could not chunk.";
      return false;
    }

    // Codepoint offset of this line within the whole context; chunk spans are
    // line-relative and must be shifted by it.
    const int offset = std::distance(context_unicode.begin(), line.first);
    for (const TokenSpan& chunk : local_chunks) {
      const CodepointSpan codepoint_span =
          selection_feature_processor_->StripBoundaryCodepoints(
              line_str, TokenSpanToCodepointSpan(*tokens, chunk));

      // Skip empty spans.
      if (codepoint_span.first != codepoint_span.second) {
        std::vector<ClassificationResult> classification;
        if (!ModelClassifyText(line_str, *tokens, codepoint_span,
                               interpreter_manager, &embedding_cache,
                               &classification)) {
          TC3_LOG(ERROR) << "Could not classify text: "
                         << (codepoint_span.first + offset) << " "
                         << (codepoint_span.second + offset);
          return false;
        }

        // Do not include the span if it's classified as "other".
        if (!classification.empty() && !ClassifiedAsOther(classification) &&
            classification[0].score >= min_annotate_confidence) {
          AnnotatedSpan result_span;
          result_span.span = {codepoint_span.first + offset,
                              codepoint_span.second + offset};
          result_span.classification = std::move(classification);
          result->push_back(std::move(result_span));
        }
      }
    }
  }
  return true;
}

// Test-only accessors exposing internal components.
const FeatureProcessor* Annotator::SelectionFeatureProcessorForTests() const {
  return selection_feature_processor_.get();
}

const FeatureProcessor* Annotator::ClassificationFeatureProcessorForTests()
    const {
  return classification_feature_processor_.get();
}

const DatetimeParser* Annotator::DatetimeParserForTests() const {
  return datetime_parser_.get();
}

// Public entry point for annotation. Gathers candidates from all engines
// (selection model, regexes, datetime parser, knowledge engine), resolves
// overlaps, and returns the surviving spans sorted by position.
std::vector<AnnotatedSpan> Annotator::Annotate(
    const std::string& context, const AnnotationOptions& options) const {
  std::vector<AnnotatedSpan> candidates;

  if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
    return {};
  }

  if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) {
    return {};
  }

  InterpreterManager interpreter_manager(selection_executor_.get(),
                                         classification_executor_.get());
  // Annotate with the selection model.
  std::vector<Token> tokens;
  if (!ModelAnnotate(context, &interpreter_manager, &tokens, &candidates)) {
    TC3_LOG(ERROR) << "Couldn't run ModelAnnotate.";
    return {};
  }

  // Annotate with the regular expression models.
+ if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false), + annotation_regex_patterns_, &candidates)) { + TC3_LOG(ERROR) << "Couldn't run RegexChunk."; + return {}; + } + + // Annotate with the datetime model. + if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false), + options.reference_time_ms_utc, options.reference_timezone, + options.locales, ModeFlag_ANNOTATION, &candidates)) { + TC3_LOG(ERROR) << "Couldn't run RegexChunk."; + return {}; + } + + // Annotate with the knowledge engine. + if (knowledge_engine_ && !knowledge_engine_->Chunk(context, &candidates)) { + TC3_LOG(ERROR) << "Couldn't run knowledge engine Chunk."; + return {}; + } + + // Sort candidates according to their position in the input, so that the next + // code can assume that any connected component of overlapping spans forms a + // contiguous block. + std::sort(candidates.begin(), candidates.end(), + [](const AnnotatedSpan& a, const AnnotatedSpan& b) { + return a.span.first < b.span.first; + }); + + std::vector<int> candidate_indices; + if (!ResolveConflicts(candidates, context, tokens, &interpreter_manager, + &candidate_indices)) { + TC3_LOG(ERROR) << "Couldn't resolve conflicts."; + return {}; + } + + std::vector<AnnotatedSpan> result; + result.reserve(candidate_indices.size()); + for (const int i : candidate_indices) { + if (!candidates[i].classification.empty() && + !ClassifiedAsOther(candidates[i].classification) && + !FilteredForAnnotation(candidates[i])) { + result.push_back(std::move(candidates[i])); + } + } + + return result; +} + +bool Annotator::RegexChunk(const UnicodeText& context_unicode, + const std::vector<int>& rules, + std::vector<AnnotatedSpan>* result) const { + for (int pattern_id : rules) { + const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id]; + const auto matcher = regex_pattern.pattern->Matcher(context_unicode); + if (!matcher) { + TC3_LOG(ERROR) << "Could not get regex matcher for pattern: " + << pattern_id; + return false; + } + + 
    int status = UniLib::RegexMatcher::kNoError;
    while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
      // Optional per-pattern verification (e.g. checksum checks) on the
      // matched text.
      if (regex_pattern.verification_options) {
        if (!VerifyCandidate(regex_pattern.verification_options,
                             matcher->Group(1, &status).ToUTF8String())) {
          continue;
        }
      }
      result->emplace_back();
      // Selection/annotation regular expressions need to specify a capturing
      // group specifying the selection.
      result->back().span = {matcher->Start(1, &status),
                             matcher->End(1, &status)};
      result->back().classification = {
          {regex_pattern.collection_name,
           regex_pattern.target_classification_score,
           regex_pattern.priority_score}};
    }
  }
  return true;
}

// Chunks the tokens in 'span_of_interest' with the selection model: scores
// candidate token spans (via the bounds-sensitive or click-context variant),
// then greedily picks non-overlapping spans from highest to lowest score.
// Outputs the chosen spans, sorted by position, into 'chunks'.
bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
                           tflite::Interpreter* selection_interpreter,
                           const CachedFeatures& cached_features,
                           std::vector<TokenSpan>* chunks) const {
  const int max_selection_span =
      selection_feature_processor_->GetOptions()->max_selection_span();
  // The inference span is the span of interest expanded to include
  // max_selection_span tokens on either side, which is how far a selection can
  // stretch from the click.
  const TokenSpan inference_span = IntersectTokenSpans(
      ExpandTokenSpan(span_of_interest,
                      /*num_tokens_left=*/max_selection_span,
                      /*num_tokens_right=*/max_selection_span),
      {0, num_tokens});

  std::vector<ScoredChunk> scored_chunks;
  if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
      selection_feature_processor_->GetOptions()
          ->bounds_sensitive_features()
          ->enabled()) {
    if (!ModelBoundsSensitiveScoreChunks(
            num_tokens, span_of_interest, inference_span, cached_features,
            selection_interpreter, &scored_chunks)) {
      return false;
    }
  } else {
    if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
                                      cached_features, selection_interpreter,
                                      &scored_chunks)) {
      return false;
    }
  }
  // Sorting the reversed range ascending leaves the vector in descending
  // score order.
  std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
            [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
              return lhs.score < rhs.score;
            });

  // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
  // them greedily as long as they do not overlap with any previously picked
  // chunks.
  std::vector<bool> token_used(TokenSpanSize(inference_span));
  chunks->clear();
  for (const ScoredChunk& scored_chunk : scored_chunks) {
    bool feasible = true;
    for (int i = scored_chunk.token_span.first;
         i < scored_chunk.token_span.second; ++i) {
      if (token_used[i - inference_span.first]) {
        feasible = false;
        break;
      }
    }

    if (!feasible) {
      continue;
    }

    for (int i = scored_chunk.token_span.first;
         i < scored_chunk.token_span.second; ++i) {
      token_used[i - inference_span.first] = true;
    }

    chunks->push_back(scored_chunk.token_span);
  }

  std::sort(chunks->begin(), chunks->end());

  return true;
}

namespace {
// Updates the value at the given key in the map to maximum of the current value
// and the given value, or simply inserts the value if the key is not yet there.
template <typename Map>
void UpdateMax(Map* map, typename Map::key_type key,
               typename Map::mapped_type value) {
  const auto it = map->find(key);
  if (it != map->end()) {
    it->second = std::max(it->second, value);
  } else {
    (*map)[key] = value;
  }
}
}  // namespace

// Scores candidate spans with the click-context model: runs inference for
// every possible click position in 'span_of_interest' (in batches), converts
// each predicted label to a span relative to the click, and keeps the maximum
// score observed per span.
bool Annotator::ModelClickContextScoreChunks(
    int num_tokens, const TokenSpan& span_of_interest,
    const CachedFeatures& cached_features,
    tflite::Interpreter* selection_interpreter,
    std::vector<ScoredChunk>* scored_chunks) const {
  const int max_batch_size = model_->selection_options()->batch_size();

  std::vector<float> all_features;
  std::map<TokenSpan, float> chunk_scores;
  for (int batch_start = span_of_interest.first;
       batch_start < span_of_interest.second; batch_start += max_batch_size) {
    const int batch_end =
        std::min(batch_start + max_batch_size, span_of_interest.second);

    // Prepare features for the whole batch.
    all_features.clear();
    all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
    for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
      cached_features.AppendClickContextFeaturesForClick(click_pos,
                                                         &all_features);
    }

    // Run batched inference.
    const int batch_size = batch_end - batch_start;
    const int features_size = cached_features.OutputFeaturesSize();
    TensorView<float> logits = selection_executor_->ComputeLogits(
        TensorView<float>(all_features.data(), {batch_size, features_size}),
        selection_interpreter);
    if (!logits.is_valid()) {
      TC3_LOG(ERROR) << "Couldn't compute logits.";
      return false;
    }
    if (logits.dims() != 2 || logits.dim(0) != batch_size ||
        logits.dim(1) !=
            selection_feature_processor_->GetSelectionLabelCount()) {
      TC3_LOG(ERROR) << "Mismatching output.";
      return false;
    }

    // Save results.
    for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
      const std::vector<float> scores = ComputeSoftmax(
          logits.data() + logits.dim(1) * (click_pos - batch_start),
          logits.dim(1));
      for (int j = 0;
           j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
        TokenSpan relative_token_span;
        if (!selection_feature_processor_->LabelToTokenSpan(
                j, &relative_token_span)) {
          TC3_LOG(ERROR) << "Couldn't map the label to a token span.";
          return false;
        }
        const TokenSpan candidate_span = ExpandTokenSpan(
            SingleTokenSpan(click_pos), relative_token_span.first,
            relative_token_span.second);
        // Only keep spans fully inside the token range; several clicks may
        // propose the same span, so retain the maximum score.
        if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
          UpdateMax(&chunk_scores, candidate_span, scores[j]);
        }
      }
    }
  }

  scored_chunks->clear();
  scored_chunks->reserve(chunk_scores.size());
  for (const auto& entry : chunk_scores) {
    scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
  }

  return true;
}

// Scores candidate spans with the bounds-sensitive model: enumerates every
// admissible span and scores each directly with a single-logit model.
bool Annotator::ModelBoundsSensitiveScoreChunks(
    int num_tokens, const TokenSpan& span_of_interest,
    const TokenSpan& inference_span, const CachedFeatures& cached_features,
    tflite::Interpreter* selection_interpreter,
    std::vector<ScoredChunk>* scored_chunks) const {
  const int max_selection_span =
      selection_feature_processor_->GetOptions()->max_selection_span();
  const int max_chunk_length = selection_feature_processor_->GetOptions()
                                       ->selection_reduced_output_space()
                                   ?
 max_selection_span + 1
                                   : 2 * max_selection_span + 1;
  const bool score_single_token_spans_as_zero =
      selection_feature_processor_->GetOptions()
          ->bounds_sensitive_features()
          ->score_single_token_spans_as_zero();

  scored_chunks->clear();
  if (score_single_token_spans_as_zero) {
    scored_chunks->reserve(TokenSpanSize(span_of_interest));
  }

  // Prepare all chunk candidates into one batch:
  //   - Are contained in the inference span
  //   - Have a non-empty intersection with the span of interest
  //   - Are at least one token long
  //   - Are not longer than the maximum chunk length
  std::vector<TokenSpan> candidate_spans;
  for (int start = inference_span.first; start < span_of_interest.second;
       ++start) {
    const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
    for (int end = leftmost_end_index;
         end <= inference_span.second && end - start <= max_chunk_length;
         ++end) {
      const TokenSpan candidate_span = {start, end};
      if (score_single_token_spans_as_zero &&
          TokenSpanSize(candidate_span) == 1) {
        // Do not include the single token span in the batch, add a zero score
        // for it directly to the output.
        scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
      } else {
        candidate_spans.push_back(candidate_span);
      }
    }
  }

  const int max_batch_size = model_->selection_options()->batch_size();

  std::vector<float> all_features;
  scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
  for (int batch_start = 0; batch_start < candidate_spans.size();
       batch_start += max_batch_size) {
    const int batch_end = std::min(batch_start + max_batch_size,
                                   static_cast<int>(candidate_spans.size()));

    // Prepare features for the whole batch.
    all_features.clear();
    all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
    for (int i = batch_start; i < batch_end; ++i) {
      cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
                                                           &all_features);
    }

    // Run batched inference.
    const int batch_size = batch_end - batch_start;
    const int features_size = cached_features.OutputFeaturesSize();
    TensorView<float> logits = selection_executor_->ComputeLogits(
        TensorView<float>(all_features.data(), {batch_size, features_size}),
        selection_interpreter);
    if (!logits.is_valid()) {
      TC3_LOG(ERROR) << "Couldn't compute logits.";
      return false;
    }
    // The bounds-sensitive model emits a single score per candidate span.
    if (logits.dims() != 2 || logits.dim(0) != batch_size ||
        logits.dim(1) != 1) {
      TC3_LOG(ERROR) << "Mismatching output.";
      return false;
    }

    // Save results.
    for (int i = batch_start; i < batch_end; ++i) {
      scored_chunks->push_back(
          ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
    }
  }

  return true;
}

// Runs the datetime parser over the whole context and appends one annotated
// span (collection "date") per parse result. A missing parser is a successful
// no-op.
bool Annotator::DatetimeChunk(const UnicodeText& context_unicode,
                              int64 reference_time_ms_utc,
                              const std::string& reference_timezone,
                              const std::string& locales, ModeFlag mode,
                              std::vector<AnnotatedSpan>* result) const {
  if (!datetime_parser_) {
    return true;
  }

  std::vector<DatetimeParseResultSpan> datetime_spans;
  if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
                               reference_timezone, locales, mode,
                               /*anchor_start_end=*/false, &datetime_spans)) {
    return false;
  }
  for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
    AnnotatedSpan annotated_span;
    annotated_span.span = datetime_span.span;
    annotated_span.classification = {{kDateCollection,
                                      datetime_span.target_classification_score,
                                      datetime_span.priority_score}};
    annotated_span.classification[0].datetime_parse_result = datetime_span.data;

    result->push_back(std::move(annotated_span));
  }
  return true;
}

const Model* ViewModel(const void* buffer, int size) {
  // Returns a flatbuffer view of the model, or nullptr when the buffer is
  // null or fails verification.
  if (!buffer) {
    return nullptr;
  }

  return LoadAndVerifyModel(buffer, size);
}

}  // namespace libtextclassifier3
diff --git a/annotator/annotator.h b/annotator/annotator.h
new file mode 100644
index 0000000..c58c03d
--- /dev/null
+++ b/annotator/annotator.h
@@ -0,0 +1,393 @@
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Inference code for the text classification model.

#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_H_
#define LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_H_

#include <memory>
#include <set>
#include <string>
#include <vector>

#include "annotator/datetime/parser.h"
#include "annotator/feature-processor.h"
#include "annotator/knowledge/knowledge-engine.h"
#include "annotator/model-executor.h"
#include "annotator/model_generated.h"
#include "annotator/strip-unpaired-brackets.h"
#include "annotator/types.h"
#include "annotator/zlib-utils.h"
#include "utils/memory/mmap.h"
#include "utils/utf8/unilib.h"
#include "utils/zlib/zlib.h"

namespace libtextclassifier3 {

// Options for Annotator::SuggestSelection.
struct SelectionOptions {
  // Comma-separated list of locale specification for the input text (BCP 47
  // tags).
  std::string locales;

  static SelectionOptions Default() { return SelectionOptions(); }
};

// Options for Annotator::ClassifyText.
struct ClassificationOptions {
  // For parsing relative datetimes, the reference now time against which the
  // relative datetimes get resolved.
  // UTC milliseconds since epoch.
  int64 reference_time_ms_utc = 0;

  // Timezone in which the input text was written (format as accepted by ICU).
  std::string reference_timezone;

  // Comma-separated list of locale specification for the input text (BCP 47
  // tags).
  std::string locales;

  static ClassificationOptions Default() { return ClassificationOptions(); }
};

// Options for Annotator::Annotate.
struct AnnotationOptions {
  // For parsing relative datetimes, the reference now time against which the
  // relative datetimes get resolved.
  // UTC milliseconds since epoch.
  int64 reference_time_ms_utc = 0;

  // Timezone in which the input text was written (format as accepted by ICU).
  std::string reference_timezone;

  // Comma-separated list of locale specification for the input text (BCP 47
  // tags).
  std::string locales;

  static AnnotationOptions Default() { return AnnotationOptions(); }
};

// Holds TFLite interpreters for selection and classification models.
// NOTE: This class is not thread-safe, thus should NOT be re-used across
// threads.
class InterpreterManager {
 public:
  // The constructor can be called with nullptr for any of the executors, and is
  // a defined behavior, as long as the corresponding *Interpreter() method is
  // not called when the executor is null.
  InterpreterManager(const ModelExecutor* selection_executor,
                     const ModelExecutor* classification_executor)
      : selection_executor_(selection_executor),
        classification_executor_(classification_executor) {}

  // Gets or creates and caches an interpreter for the selection model.
  tflite::Interpreter* SelectionInterpreter();

  // Gets or creates and caches an interpreter for the classification model.
  tflite::Interpreter* ClassificationInterpreter();

 private:
  // Executors are borrowed, not owned; interpreters are created lazily and
  // cached for the lifetime of this manager.
  const ModelExecutor* selection_executor_;
  const ModelExecutor* classification_executor_;

  std::unique_ptr<tflite::Interpreter> selection_interpreter_;
  std::unique_ptr<tflite::Interpreter> classification_interpreter_;
};

// A text processing model that provides text classification, annotation,
// selection suggestion for various types.
// NOTE: This class is not thread-safe.
class Annotator {
 public:
  // Factory methods. Each returns nullptr on failure (invalid model, bad
  // file descriptor, etc.); optional UniLib/CalendarLib instances may be
  // injected, otherwise defaults are used.
  static std::unique_ptr<Annotator> FromUnownedBuffer(
      const char* buffer, int size, const UniLib* unilib = nullptr,
      const CalendarLib* calendarlib = nullptr);
  // Takes ownership of the mmap.
  static std::unique_ptr<Annotator> FromScopedMmap(
      std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib = nullptr,
      const CalendarLib* calendarlib = nullptr);
  static std::unique_ptr<Annotator> FromFileDescriptor(
      int fd, int offset, int size, const UniLib* unilib = nullptr,
      const CalendarLib* calendarlib = nullptr);
  static std::unique_ptr<Annotator> FromFileDescriptor(
      int fd, const UniLib* unilib = nullptr,
      const CalendarLib* calendarlib = nullptr);
  static std::unique_ptr<Annotator> FromPath(
      const std::string& path, const UniLib* unilib = nullptr,
      const CalendarLib* calendarlib = nullptr);

  // Returns true if the model is ready for use.
  bool IsInitialized() { return initialized_; }

  // Initializes the knowledge engine with the given config.
  bool InitializeKnowledgeEngine(const std::string& serialized_config);

  // Runs inference for given a context and current selection (i.e. index
  // of the first and one past last selected characters (utf8 codepoint
  // offsets)). Returns the indices (utf8 codepoint offsets) of the selection
  // beginning character and one past selection end character.
  // Returns the original click_indices if an error occurs.
  // NOTE: The selection indices are passed in and returned in terms of
  // UTF8 codepoints (not bytes).
  // Requires that the model is a smart selection model.
  CodepointSpan SuggestSelection(
      const std::string& context, CodepointSpan click_indices,
      const SelectionOptions& options = SelectionOptions::Default()) const;

  // Classifies the selected text given the context string.
  // Returns an empty result if an error occurs.
  std::vector<ClassificationResult> ClassifyText(
      const std::string& context, CodepointSpan selection_indices,
      const ClassificationOptions& options =
          ClassificationOptions::Default()) const;

  // Annotates given input text. The annotations are sorted by their position
  // in the context string and exclude spans classified as 'other'.
  std::vector<AnnotatedSpan> Annotate(
      const std::string& context,
      const AnnotationOptions& options = AnnotationOptions::Default()) const;

  // Exposes the feature processor for tests and evaluations.
  const FeatureProcessor* SelectionFeatureProcessorForTests() const;
  const FeatureProcessor* ClassificationFeatureProcessorForTests() const;

  // Exposes the date time parser for tests and evaluations.
  const DatetimeParser* DatetimeParserForTests() const;

  // String collection names for various classes.
  static const std::string& kOtherCollection;
  static const std::string& kPhoneCollection;
  static const std::string& kAddressCollection;
  static const std::string& kDateCollection;
  static const std::string& kUrlCollection;
  static const std::string& kFlightCollection;
  static const std::string& kEmailCollection;
  static const std::string& kIbanCollection;
  static const std::string& kPaymentCardCollection;
  static const std::string& kIsbnCollection;
  static const std::string& kTrackingNumberCollection;

 protected:
  // A candidate token span together with the model score assigned to it.
  struct ScoredChunk {
    TokenSpan token_span;
    float score;
  };

  // Constructs and initializes text classifier from given model.
  // Takes ownership of 'mmap', and thus owns the buffer that backs 'model'.
  Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
            const UniLib* unilib, const CalendarLib* calendarlib);

  // Constructs, validates and initializes text classifier from given model.
  // Does not own the buffer that backs 'model'.
  explicit Annotator(const Model* model, const UniLib* unilib,
                     const CalendarLib* calendarlib);

  // Checks that model contains all required fields, and initializes internal
  // datastructures.
  void ValidateAndInitialize();

  // Initializes regular expressions for the regex model.
  bool InitializeRegexModel(ZlibDecompressor* decompressor);

  // Resolves conflicts in the list of candidates by removing some overlapping
  // ones. Returns indices of the surviving ones.
  // NOTE: Assumes that the candidates are sorted according to their position in
  // the span.
  bool ResolveConflicts(const std::vector<AnnotatedSpan>& candidates,
                        const std::string& context,
                        const std::vector<Token>& cached_tokens,
                        InterpreterManager* interpreter_manager,
                        std::vector<int>* result) const;

  // Resolves one conflict between candidates on indices 'start_index'
  // (inclusive) and 'end_index' (exclusive). Assigns the winning candidate
  // indices to 'chosen_indices'. Returns false if a problem arises.
  bool ResolveConflict(const std::string& context,
                       const std::vector<Token>& cached_tokens,
                       const std::vector<AnnotatedSpan>& candidates,
                       int start_index, int end_index,
                       InterpreterManager* interpreter_manager,
                       std::vector<int>* chosen_indices) const;

  // Gets selection candidates from the ML model.
  // Provides the tokens produced during tokenization of the context string for
  // reuse.
  bool ModelSuggestSelection(const UnicodeText& context_unicode,
                             CodepointSpan click_indices,
                             InterpreterManager* interpreter_manager,
                             std::vector<Token>* tokens,
                             std::vector<AnnotatedSpan>* result) const;

  // Classifies the selected text given the context string with the
  // classification model.
  // Returns true if no error occurred.
  bool ModelClassifyText(
      const std::string& context, const std::vector<Token>& cached_tokens,
      CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
      FeatureProcessor::EmbeddingCache* embedding_cache,
      std::vector<ClassificationResult>* classification_results) const;

  // Same as above, but tokenizes the context itself (no cached tokens).
  bool ModelClassifyText(
      const std::string& context, CodepointSpan selection_indices,
      InterpreterManager* interpreter_manager,
      FeatureProcessor::EmbeddingCache* embedding_cache,
      std::vector<ClassificationResult>* classification_results) const;

  // Returns a relative token span that represents how many tokens on the left
  // from the selection and right from the selection are needed for the
  // classifier input.
  TokenSpan ClassifyTextUpperBoundNeededTokens() const;

  // Classifies the selected text with the regular expressions models.
  // Returns true if any regular expression matched and the result was set.
  bool RegexClassifyText(const std::string& context,
                         CodepointSpan selection_indices,
                         ClassificationResult* classification_result) const;

  // Classifies the selected text with the date time model.
  // Returns true if there was a match and the result was set.
  bool DatetimeClassifyText(const std::string& context,
                            CodepointSpan selection_indices,
                            const ClassificationOptions& options,
                            ClassificationResult* classification_result) const;

  // Chunks given input text with the selection model and classifies the spans
  // with the classification model.
  // The annotations are sorted by their position in the context string and
  // exclude spans classified as 'other'.
  // Provides the tokens produced during tokenization of the context string for
  // reuse.
  bool ModelAnnotate(const std::string& context,
                     InterpreterManager* interpreter_manager,
                     std::vector<Token>* tokens,
                     std::vector<AnnotatedSpan>* result) const;

  // Groups the tokens into chunks.
A chunk is a token span that should be the + // suggested selection when any of its contained tokens is clicked. The chunks + // are non-overlapping and are sorted by their position in the context string. + // "num_tokens" is the total number of tokens available (as this method does + // not need the actual vector of tokens). + // "span_of_interest" is a span of all the tokens that could be clicked. + // The resulting chunks all have to overlap with it and they cover this span + // completely. The first and last chunk might extend beyond it. + // The chunks vector is cleared before filling. + bool ModelChunk(int num_tokens, const TokenSpan& span_of_interest, + tflite::Interpreter* selection_interpreter, + const CachedFeatures& cached_features, + std::vector<TokenSpan>* chunks) const; + + // A helper method for ModelChunk(). It generates scored chunk candidates for + // a click context model. + // NOTE: The returned chunks can (and most likely do) overlap. + bool ModelClickContextScoreChunks( + int num_tokens, const TokenSpan& span_of_interest, + const CachedFeatures& cached_features, + tflite::Interpreter* selection_interpreter, + std::vector<ScoredChunk>* scored_chunks) const; + + // A helper method for ModelChunk(). It generates scored chunk candidates for + // a bounds-sensitive model. + // NOTE: The returned chunks can (and most likely do) overlap. + bool ModelBoundsSensitiveScoreChunks( + int num_tokens, const TokenSpan& span_of_interest, + const TokenSpan& inference_span, const CachedFeatures& cached_features, + tflite::Interpreter* selection_interpreter, + std::vector<ScoredChunk>* scored_chunks) const; + + // Produces chunks isolated by a set of regular expressions. + bool RegexChunk(const UnicodeText& context_unicode, + const std::vector<int>& rules, + std::vector<AnnotatedSpan>* result) const; + + // Produces chunks from the datetime parser. 
+ bool DatetimeChunk(const UnicodeText& context_unicode, + int64 reference_time_ms_utc, + const std::string& reference_timezone, + const std::string& locales, ModeFlag mode, + std::vector<AnnotatedSpan>* result) const; + + // Returns whether a classification should be filtered. + bool FilteredForAnnotation(const AnnotatedSpan& span) const; + bool FilteredForClassification( + const ClassificationResult& classification) const; + bool FilteredForSelection(const AnnotatedSpan& span) const; + + const Model* model_; + + std::unique_ptr<const ModelExecutor> selection_executor_; + std::unique_ptr<const ModelExecutor> classification_executor_; + std::unique_ptr<const EmbeddingExecutor> embedding_executor_; + + std::unique_ptr<const FeatureProcessor> selection_feature_processor_; + std::unique_ptr<const FeatureProcessor> classification_feature_processor_; + + std::unique_ptr<const DatetimeParser> datetime_parser_; + + private: + struct CompiledRegexPattern { + std::string collection_name; + float target_classification_score; + float priority_score; + std::unique_ptr<UniLib::RegexPattern> pattern; + const VerificationOptions* verification_options; + }; + + std::unique_ptr<ScopedMmap> mmap_; + bool initialized_ = false; + bool enabled_for_annotation_ = false; + bool enabled_for_classification_ = false; + bool enabled_for_selection_ = false; + std::unordered_set<std::string> filtered_collections_annotation_; + std::unordered_set<std::string> filtered_collections_classification_; + std::unordered_set<std::string> filtered_collections_selection_; + + std::vector<CompiledRegexPattern> regex_patterns_; + std::unordered_set<int> regex_approximate_match_pattern_ids_; + + // Indices into regex_patterns_ for the different modes. 
+ std::vector<int> annotation_regex_patterns_, classification_regex_patterns_, + selection_regex_patterns_; + + std::unique_ptr<UniLib> owned_unilib_; + const UniLib* unilib_; + std::unique_ptr<CalendarLib> owned_calendarlib_; + const CalendarLib* calendarlib_; + + std::unique_ptr<const KnowledgeEngine> knowledge_engine_; +}; + +namespace internal { + +// Helper function, which if the initial 'span' contains only white-spaces, +// moves the selection to a single-codepoint selection on the left side +// of this block of white-space. +CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span, + const UnicodeText& context_unicode, + const UniLib& unilib); + +// Copies tokens from 'cached_tokens' that are +// 'tokens_around_selection_to_copy' (on the left, and right) tokens distant +// from the tokens that correspond to 'selection_indices'. +std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens, + CodepointSpan selection_indices, + TokenSpan tokens_around_selection_to_copy); +} // namespace internal + +// Interprets the buffer as a Model flatbuffer and returns it for reading. +const Model* ViewModel(const void* buffer, int size); + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_H_ diff --git a/annotator/annotator_jni.cc b/annotator/annotator_jni.cc new file mode 100644 index 0000000..9bda35a --- /dev/null +++ b/annotator/annotator_jni.cc @@ -0,0 +1,434 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// JNI wrapper for the Annotator. + +#include "annotator/annotator_jni.h" + +#include <jni.h> +#include <type_traits> +#include <vector> + +#include "annotator/annotator.h" +#include "annotator/annotator_jni_common.h" +#include "utils/base/integral_types.h" +#include "utils/calendar/calendar.h" +#include "utils/java/scoped_local_ref.h" +#include "utils/java/string_utils.h" +#include "utils/memory/mmap.h" +#include "utils/utf8/unilib.h" + +#ifdef TC3_UNILIB_JAVAICU +#ifndef TC3_CALENDAR_JAVAICU +#error Inconsistent usage of Java ICU components +#else +#define TC3_USE_JAVAICU +#endif +#endif + +using libtextclassifier3::AnnotatedSpan; +using libtextclassifier3::Annotator; +using libtextclassifier3::ClassificationResult; +using libtextclassifier3::CodepointSpan; +using libtextclassifier3::Model; +using libtextclassifier3::ScopedLocalRef; +// When using the Java's ICU, CalendarLib and UniLib need to be instantiated +// with a JavaVM pointer from JNI. When using a standard ICU the pointer is +// not needed and the objects are instantiated implicitly. 
+#ifdef TC3_USE_JAVAICU +using libtextclassifier3::CalendarLib; +using libtextclassifier3::UniLib; +#endif + +namespace libtextclassifier3 { + +using libtextclassifier3::CodepointSpan; + +namespace { + +jobjectArray ClassificationResultsToJObjectArray( + JNIEnv* env, + const std::vector<ClassificationResult>& classification_result) { + const ScopedLocalRef<jclass> result_class( + env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR + "$ClassificationResult"), + env); + if (!result_class) { + TC3_LOG(ERROR) << "Couldn't find ClassificationResult class."; + return nullptr; + } + const ScopedLocalRef<jclass> datetime_parse_class( + env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR + "$DatetimeResult"), + env); + if (!datetime_parse_class) { + TC3_LOG(ERROR) << "Couldn't find DatetimeResult class."; + return nullptr; + } + + const jmethodID result_class_constructor = env->GetMethodID( + result_class.get(), "<init>", + "(Ljava/lang/String;FL" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR + "$DatetimeResult;[B)V"); + const jmethodID datetime_parse_class_constructor = + env->GetMethodID(datetime_parse_class.get(), "<init>", "(JI)V"); + + const jobjectArray results = env->NewObjectArray(classification_result.size(), + result_class.get(), nullptr); + for (int i = 0; i < classification_result.size(); i++) { + jstring row_string = + env->NewStringUTF(classification_result[i].collection.c_str()); + + jobject row_datetime_parse = nullptr; + if (classification_result[i].datetime_parse_result.IsSet()) { + row_datetime_parse = env->NewObject( + datetime_parse_class.get(), datetime_parse_class_constructor, + classification_result[i].datetime_parse_result.time_ms_utc, + classification_result[i].datetime_parse_result.granularity); + } + + jbyteArray serialized_knowledge_result = nullptr; + const std::string& serialized_knowledge_result_string = + classification_result[i].serialized_knowledge_result; + if (!serialized_knowledge_result_string.empty()) { + 
serialized_knowledge_result =
+          env->NewByteArray(serialized_knowledge_result_string.size());
+      env->SetByteArrayRegion(serialized_knowledge_result, 0,
+                              serialized_knowledge_result_string.size(),
+                              reinterpret_cast<const jbyte*>(
+                                  serialized_knowledge_result_string.data()));
+    }
+
+    jobject result =
+        env->NewObject(result_class.get(), result_class_constructor, row_string,
+                       static_cast<jfloat>(classification_result[i].score),
+                       row_datetime_parse, serialized_knowledge_result);
+    // NOTE(review): row_string, row_datetime_parse and
+    // serialized_knowledge_result are JNI local references that are not
+    // explicitly released each iteration (only 'result' is); for very large
+    // result arrays this may put pressure on the local reference table --
+    // verify against the expected result sizes.
+    env->SetObjectArrayElement(results, i, result);
+    env->DeleteLocalRef(result);
+  }
+  return results;
+}
+
+// Maps a codepoint span between two indexing schemes over the same text:
+// UTF-8 codepoint offsets vs. Java "BMP" (UTF-16 code unit) offsets.
+// 'from_utf8' selects the direction. Span endpoints that never coincide with
+// a walked index stay at -1 in the returned span.
+CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str,
+                                    CodepointSpan orig_indices,
+                                    bool from_utf8) {
+  const libtextclassifier3::UnicodeText unicode_str =
+      libtextclassifier3::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false);
+
+  // The two indices are advanced in lock-step over the codepoints of the
+  // string; 'source' is the scheme orig_indices is expressed in, 'target'
+  // the scheme being converted to.
+  int unicode_index = 0;
+  int bmp_index = 0;
+
+  const int* source_index;
+  const int* target_index;
+  if (from_utf8) {
+    source_index = &unicode_index;
+    target_index = &bmp_index;
+  } else {
+    source_index = &bmp_index;
+    target_index = &unicode_index;
+  }
+
+  CodepointSpan result{-1, -1};
+  // Copies the target index into the result whenever the source index hits
+  // one of the requested span endpoints.
+  std::function<void()> assign_indices_fn = [&result, &orig_indices,
+                                             &source_index, &target_index]() {
+    if (orig_indices.first == *source_index) {
+      result.first = *target_index;
+    }
+
+    if (orig_indices.second == *source_index) {
+      result.second = *target_index;
+    }
+  };
+
+  for (auto it = unicode_str.begin(); it != unicode_str.end();
+       ++it, ++unicode_index, ++bmp_index) {
+    assign_indices_fn();
+
+    // There is 1 extra character in the input for each UTF8 character > 0xFFFF.
+    // (Codepoints above the Basic Multilingual Plane take two UTF-16 code
+    // units -- a surrogate pair -- in a Java string, so the BMP index must
+    // advance one extra step here.)
+    if (*it > 0xFFFF) {
+      ++bmp_index;
+    }
+  }
+  // One final check so endpoints equal to the string length (one past the
+  // last codepoint) are also mapped.
+  assign_indices_fn();
+
+  return result;
+}
+
+}  // namespace
+
+CodepointSpan ConvertIndicesBMPToUTF8(const std::string& utf8_str,
+                                      CodepointSpan bmp_indices) {
+  return ConvertIndicesBMPUTF8(utf8_str, bmp_indices, /*from_utf8=*/false);
+}
+
+CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str,
+                                      CodepointSpan utf8_indices) {
+  return ConvertIndicesBMPUTF8(utf8_str, utf8_indices, /*from_utf8=*/true);
+}
+
+// Reads the comma-separated locales string out of a memory-mapped Model
+// flatbuffer; returns an empty Java string if the mmap or model is invalid.
+jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+  if (!mmap->handle().ok()) {
+    return env->NewStringUTF("");
+  }
+  const Model* model = libtextclassifier3::ViewModel(
+      mmap->handle().start(), mmap->handle().num_bytes());
+  if (!model || !model->locales()) {
+    return env->NewStringUTF("");
+  }
+  return env->NewStringUTF(model->locales()->c_str());
+}
+
+// Reads the model version out of a memory-mapped Model flatbuffer; returns 0
+// if the mmap or model is invalid.
+jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+  if (!mmap->handle().ok()) {
+    return 0;
+  }
+  const Model* model = libtextclassifier3::ViewModel(
+      mmap->handle().start(), mmap->handle().num_bytes());
+  if (!model) {
+    return 0;
+  }
+  return model->version();
+}
+
+// Reads the model name out of a memory-mapped Model flatbuffer; returns an
+// empty Java string if the mmap or model is invalid.
+jstring GetNameFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+  if (!mmap->handle().ok()) {
+    return env->NewStringUTF("");
+  }
+  const Model* model = libtextclassifier3::ViewModel(
+      mmap->handle().start(), mmap->handle().num_bytes());
+  if (!model || !model->name()) {
+    return env->NewStringUTF("");
+  }
+  return env->NewStringUTF(model->name()->c_str());
+}
+
+}  // namespace libtextclassifier3
+
+using libtextclassifier3::ClassificationResultsToJObjectArray;
+using libtextclassifier3::ConvertIndicesBMPToUTF8;
+using libtextclassifier3::ConvertIndicesUTF8ToBMP;
+using libtextclassifier3::FromJavaAnnotationOptions;
+using libtextclassifier3::FromJavaClassificationOptions;
+using libtextclassifier3::FromJavaSelectionOptions;
+using libtextclassifier3::ToStlString;
+
+TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME,
nativeNewAnnotator)
+(JNIEnv* env, jobject thiz, jint fd) {
+// Creates a native Annotator from an open model file descriptor and returns
+// the raw pointer as a jlong handle (0 if loading failed).
+#ifdef TC3_USE_JAVAICU
+  std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
+      libtextclassifier3::JniCache::Create(env));
+  // NOTE(review): UniLib/CalendarLib are allocated with 'new' and handed to
+  // FromFileDescriptor, which (per annotator.h) takes them as const raw
+  // pointers; if the Annotator does not assume ownership these allocations
+  // leak -- verify the ownership contract.
+  return reinterpret_cast<jlong>(
+      Annotator::FromFileDescriptor(fd, new UniLib(jni_cache),
+                                    new CalendarLib(jni_cache))
+          .release());
+#else
+  return reinterpret_cast<jlong>(Annotator::FromFileDescriptor(fd).release());
+#endif
+}
+
+// Same as nativeNewAnnotator, but loads the model from a filesystem path.
+TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorFromPath)
+(JNIEnv* env, jobject thiz, jstring path) {
+  const std::string path_str = ToStlString(env, path);
+#ifdef TC3_USE_JAVAICU
+  std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
+      libtextclassifier3::JniCache::Create(env));
+  // NOTE(review): same potential UniLib/CalendarLib leak as in
+  // nativeNewAnnotator above.
+  return reinterpret_cast<jlong>(Annotator::FromPath(path_str,
+                                                     new UniLib(jni_cache),
+                                                     new CalendarLib(jni_cache))
+                                     .release());
+#else
+  return reinterpret_cast<jlong>(Annotator::FromPath(path_str).release());
+#endif
+}
+
+// Same as nativeNewAnnotator, but loads the model from a region
+// (offset, size) of an Android AssetFileDescriptor.
+TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME,
+               nativeNewAnnotatorFromAssetFileDescriptor)
+(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
+  const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd);
+#ifdef TC3_USE_JAVAICU
+  std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
+      libtextclassifier3::JniCache::Create(env));
+  // NOTE(review): same potential UniLib/CalendarLib leak as in
+  // nativeNewAnnotator above.
+  return reinterpret_cast<jlong>(
+      Annotator::FromFileDescriptor(fd, offset, size, new UniLib(jni_cache),
+                                    new CalendarLib(jni_cache))
+          .release());
+#else
+  return reinterpret_cast<jlong>(
+      Annotator::FromFileDescriptor(fd, offset, size).release());
+#endif
+}
+
+// Forwards a serialized knowledge-engine config (Java byte[]) to the
+// Annotator behind 'ptr'; returns false on a null handle.
+TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
+               nativeInitializeKnowledgeEngine)
+(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config) {
+  if (!ptr) {
+    return false;
+  }
+
+  Annotator* model = reinterpret_cast<Annotator*>(ptr);
+
+  // Copy the Java byte[] into a std::string buffer of the same length.
+  std::string serialized_config_string;
+  const int length = env->GetArrayLength(serialized_config);
+  serialized_config_string.resize(length);
+
  // Writing through const_cast of data() relies on std::string's contiguous
+  // storage (guaranteed since C++11); data() only gained a non-const overload
+  // in C++17, hence the cast.
+  env->GetByteArrayRegion(serialized_config, 0, length,
+                          reinterpret_cast<jbyte*>(const_cast<char*>(
+                              serialized_config_string.data())));
+
+  return model->InitializeKnowledgeEngine(serialized_config_string);
+}
+
+// Runs selection suggestion. Java passes selection indices in UTF-16 (BMP)
+// code units; they are converted to UTF-8 codepoint offsets for the model and
+// the resulting span is converted back before being returned as int[2].
+TC3_JNI_METHOD(jintArray, TC3_ANNOTATOR_CLASS_NAME, nativeSuggestSelection)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
+ jint selection_end, jobject options) {
+  if (!ptr) {
+    return nullptr;
+  }
+
+  Annotator* model = reinterpret_cast<Annotator*>(ptr);
+
+  const std::string context_utf8 = ToStlString(env, context);
+  CodepointSpan input_indices =
+      ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
+  CodepointSpan selection = model->SuggestSelection(
+      context_utf8, input_indices, FromJavaSelectionOptions(env, options));
+  selection = ConvertIndicesUTF8ToBMP(context_utf8, selection);
+
+  jintArray result = env->NewIntArray(2);
+  env->SetIntArrayRegion(result, 0, 1, &(std::get<0>(selection)));
+  env->SetIntArrayRegion(result, 1, 1, &(std::get<1>(selection)));
+  return result;
+}
+
+// Classifies the selected span (BMP indices, converted as above) and returns
+// the results as a ClassificationResult[] Java array.
+TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeClassifyText)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
+ jint selection_end, jobject options) {
+  if (!ptr) {
+    return nullptr;
+  }
+  Annotator* ff_model = reinterpret_cast<Annotator*>(ptr);
+
+  const std::string context_utf8 = ToStlString(env, context);
+  const CodepointSpan input_indices =
+      ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
+  const std::vector<ClassificationResult> classification_result =
+      ff_model->ClassifyText(context_utf8, input_indices,
+                             FromJavaClassificationOptions(env, options));
+
+  return ClassificationResultsToJObjectArray(env, classification_result);
+}
+
+TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotate)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options) {
+  if (!ptr) {
+    return nullptr;
+  }
+  Annotator* model =
reinterpret_cast<Annotator*>(ptr); + std::string context_utf8 = ToStlString(env, context); + std::vector<AnnotatedSpan> annotations = + model->Annotate(context_utf8, FromJavaAnnotationOptions(env, options)); + + jclass result_class = env->FindClass( + TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotatedSpan"); + if (!result_class) { + TC3_LOG(ERROR) << "Couldn't find result class: " + << TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR + "$AnnotatedSpan"; + return nullptr; + } + + jmethodID result_class_constructor = + env->GetMethodID(result_class, "<init>", + "(II[L" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR + "$ClassificationResult;)V"); + + jobjectArray results = + env->NewObjectArray(annotations.size(), result_class, nullptr); + + for (int i = 0; i < annotations.size(); ++i) { + CodepointSpan span_bmp = + ConvertIndicesUTF8ToBMP(context_utf8, annotations[i].span); + jobject result = env->NewObject(result_class, result_class_constructor, + static_cast<jint>(span_bmp.first), + static_cast<jint>(span_bmp.second), + ClassificationResultsToJObjectArray( + env, annotations[i].classification)); + env->SetObjectArrayElement(results, i, result); + env->DeleteLocalRef(result); + } + env->DeleteLocalRef(result_class); + return results; +} + +TC3_JNI_METHOD(void, TC3_ANNOTATOR_CLASS_NAME, nativeCloseAnnotator) +(JNIEnv* env, jobject thiz, jlong ptr) { + Annotator* model = reinterpret_cast<Annotator*>(ptr); + delete model; +} + +TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLanguage) +(JNIEnv* env, jobject clazz, jint fd) { + TC3_LOG(WARNING) << "Using deprecated getLanguage()."; + return TC3_JNI_METHOD_NAME(TC3_ANNOTATOR_CLASS_NAME, nativeGetLocales)( + env, clazz, fd); +} + +TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocales) +(JNIEnv* env, jobject clazz, jint fd) { + const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( + new libtextclassifier3::ScopedMmap(fd)); + return GetLocalesFromMmap(env, mmap.get()); +} + 
+TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, + nativeGetLocalesFromAssetFileDescriptor) +(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { + const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd); + const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( + new libtextclassifier3::ScopedMmap(fd, offset, size)); + return GetLocalesFromMmap(env, mmap.get()); +} + +TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersion) +(JNIEnv* env, jobject clazz, jint fd) { + const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( + new libtextclassifier3::ScopedMmap(fd)); + return GetVersionFromMmap(env, mmap.get()); +} + +TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, + nativeGetVersionFromAssetFileDescriptor) +(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { + const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd); + const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( + new libtextclassifier3::ScopedMmap(fd, offset, size)); + return GetVersionFromMmap(env, mmap.get()); +} + +TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetName) +(JNIEnv* env, jobject clazz, jint fd) { + const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( + new libtextclassifier3::ScopedMmap(fd)); + return GetNameFromMmap(env, mmap.get()); +} + +TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, + nativeGetNameFromAssetFileDescriptor) +(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { + const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd); + const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( + new libtextclassifier3::ScopedMmap(fd, offset, size)); + return GetNameFromMmap(env, mmap.get()); +} diff --git a/annotator/annotator_jni.h b/annotator/annotator_jni.h new file mode 100644 index 0000000..47715b4 --- /dev/null +++ b/annotator/annotator_jni.h @@ -0,0 +1,103 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed 
under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_H_ + +#include <jni.h> +#include <string> +#include "annotator/annotator_jni_common.h" +#include "annotator/types.h" +#include "utils/java/jni-base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// SmartSelection. +TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotator) +(JNIEnv* env, jobject thiz, jint fd); + +TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorFromPath) +(JNIEnv* env, jobject thiz, jstring path); + +TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, + nativeNewAnnotatorFromAssetFileDescriptor) +(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size); + +TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME, + nativeInitializeKnowledgeEngine) +(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config); + +TC3_JNI_METHOD(jintArray, TC3_ANNOTATOR_CLASS_NAME, nativeSuggestSelection) +(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin, + jint selection_end, jobject options); + +TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeClassifyText) +(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin, + jint selection_end, jobject options); + +TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotate) +(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options); + 
+TC3_JNI_METHOD(void, TC3_ANNOTATOR_CLASS_NAME, nativeCloseAnnotator) +(JNIEnv* env, jobject thiz, jlong ptr); + +// DEPRECATED. Use nativeGetLocales instead. +TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLanguage) +(JNIEnv* env, jobject clazz, jint fd); + +TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocales) +(JNIEnv* env, jobject clazz, jint fd); + +TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, + nativeGetLocalesFromAssetFileDescriptor) +(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size); + +TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersion) +(JNIEnv* env, jobject clazz, jint fd); + +TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, + nativeGetVersionFromAssetFileDescriptor) +(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size); + +TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetName) +(JNIEnv* env, jobject clazz, jint fd); + +TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, + nativeGetNameFromAssetFileDescriptor) +(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size); + +#ifdef __cplusplus +} +#endif + +namespace libtextclassifier3 { + +// Given a utf8 string and a span expressed in Java BMP (basic multilingual +// plane) codepoints, converts it to a span expressed in utf8 codepoints. +libtextclassifier3::CodepointSpan ConvertIndicesBMPToUTF8( + const std::string& utf8_str, libtextclassifier3::CodepointSpan bmp_indices); + +// Given a utf8 string and a span expressed in utf8 codepoints, converts it to a +// span expressed in Java BMP (basic multilingual plane) codepoints. 
+libtextclassifier3::CodepointSpan ConvertIndicesUTF8ToBMP( + const std::string& utf8_str, + libtextclassifier3::CodepointSpan utf8_indices); + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_H_ diff --git a/annotator/annotator_jni_common.cc b/annotator/annotator_jni_common.cc new file mode 100644 index 0000000..0fdb87b --- /dev/null +++ b/annotator/annotator_jni_common.cc @@ -0,0 +1,100 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "annotator/annotator_jni_common.h" + +#include "utils/java/jni-base.h" +#include "utils/java/scoped_local_ref.h" + +namespace libtextclassifier3 { +namespace { +template <typename T> +T FromJavaOptionsInternal(JNIEnv* env, jobject joptions, + const std::string& class_name) { + if (!joptions) { + return {}; + } + + const ScopedLocalRef<jclass> options_class(env->FindClass(class_name.c_str()), + env); + if (!options_class) { + return {}; + } + + const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>( + env, joptions, options_class.get(), &JNIEnv::CallObjectMethod, + "getLocale", "Ljava/lang/String;"); + const std::pair<bool, jobject> status_or_reference_timezone = + CallJniMethod0<jobject>(env, joptions, options_class.get(), + &JNIEnv::CallObjectMethod, "getReferenceTimezone", + "Ljava/lang/String;"); + const std::pair<bool, int64> status_or_reference_time_ms_utc = + CallJniMethod0<int64>(env, joptions, options_class.get(), + &JNIEnv::CallLongMethod, "getReferenceTimeMsUtc", + "J"); + + if (!status_or_locales.first || !status_or_reference_timezone.first || + !status_or_reference_time_ms_utc.first) { + return {}; + } + + T options; + options.locales = + ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second)); + options.reference_timezone = ToStlString( + env, reinterpret_cast<jstring>(status_or_reference_timezone.second)); + options.reference_time_ms_utc = status_or_reference_time_ms_utc.second; + return options; +} +} // namespace + +SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions) { + if (!joptions) { + return {}; + } + + const ScopedLocalRef<jclass> options_class( + env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR + "$SelectionOptions"), + env); + const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>( + env, joptions, options_class.get(), &JNIEnv::CallObjectMethod, + "getLocales", "Ljava/lang/String;"); + if (!status_or_locales.first) { + return {}; + } + + 
SelectionOptions options; + options.locales = + ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second)); + + return options; +} + +ClassificationOptions FromJavaClassificationOptions(JNIEnv* env, + jobject joptions) { + return FromJavaOptionsInternal<ClassificationOptions>( + env, joptions, + TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$ClassificationOptions"); +} + +AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions) { + return FromJavaOptionsInternal<AnnotationOptions>( + env, joptions, + TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotationOptions"); +} + +} // namespace libtextclassifier3 diff --git a/annotator/annotator_jni_common.h b/annotator/annotator_jni_common.h new file mode 100644 index 0000000..b62bb21 --- /dev/null +++ b/annotator/annotator_jni_common.h @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_COMMON_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_COMMON_H_ + +#include <jni.h> + +#include "annotator/annotator.h" + +#ifndef TC3_ANNOTATOR_CLASS_NAME +#define TC3_ANNOTATOR_CLASS_NAME AnnotatorModel +#endif + +#define TC3_ANNOTATOR_CLASS_NAME_STR TC3_ADD_QUOTES(TC3_ANNOTATOR_CLASS_NAME) + +namespace libtextclassifier3 { + +SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions); + +ClassificationOptions FromJavaClassificationOptions(JNIEnv* env, + jobject joptions); + +AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions); + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_COMMON_H_ diff --git a/annotator/annotator_jni_test.cc b/annotator/annotator_jni_test.cc new file mode 100644 index 0000000..929fb59 --- /dev/null +++ b/annotator/annotator_jni_test.cc @@ -0,0 +1,79 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "annotator/annotator_jni.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace libtextclassifier3 { +namespace { + +TEST(Annotator, ConvertIndicesBMPUTF8) { + // Test boundary cases. 
+  EXPECT_EQ(ConvertIndicesBMPToUTF8("hello", {0, 5}), std::make_pair(0, 5));
+  EXPECT_EQ(ConvertIndicesUTF8ToBMP("hello", {0, 5}), std::make_pair(0, 5));
+
+  EXPECT_EQ(ConvertIndicesBMPToUTF8("hello world", {0, 5}),
+            std::make_pair(0, 5));
+  EXPECT_EQ(ConvertIndicesUTF8ToBMP("hello world", {0, 5}),
+            std::make_pair(0, 5));
+  EXPECT_EQ(ConvertIndicesBMPToUTF8("😁ello world", {0, 6}),
+            std::make_pair(0, 5));
+  EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁ello world", {0, 5}),
+            std::make_pair(0, 6));
+
+  EXPECT_EQ(ConvertIndicesBMPToUTF8("hello world", {6, 11}),
+            std::make_pair(6, 11));
+  EXPECT_EQ(ConvertIndicesUTF8ToBMP("hello world", {6, 11}),
+            std::make_pair(6, 11));
+  EXPECT_EQ(ConvertIndicesBMPToUTF8("hello worl😁", {6, 12}),
+            std::make_pair(6, 11));
+  EXPECT_EQ(ConvertIndicesUTF8ToBMP("hello worl😁", {6, 11}),
+            std::make_pair(6, 12));
+
+  // Simple example where the longer character is before the selection.
+  // character 😁 is 0x1f601
+  EXPECT_EQ(ConvertIndicesBMPToUTF8("😁 Hello World.", {3, 8}),
+            std::make_pair(2, 7));
+
+  EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁 Hello World.", {2, 7}),
+            std::make_pair(3, 8));
+
+  // Longer character is before and in selection.
+  EXPECT_EQ(ConvertIndicesBMPToUTF8("😁 Hell😁 World.", {3, 9}),
+            std::make_pair(2, 7));
+
+  EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁 Hell😁 World.", {2, 7}),
+            std::make_pair(3, 9));
+
+  // Longer character is before and after selection.
+  EXPECT_EQ(ConvertIndicesBMPToUTF8("😁 Hello😁World.", {3, 8}),
+            std::make_pair(2, 7));
+
+  EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁 Hello😁World.", {2, 7}),
+            std::make_pair(3, 8));
+
+  // Longer character is before, in, and after selection.
+ EXPECT_EQ(ConvertIndicesBMPToUTF8("😁 Hell😁😁World.", {3, 9}), + std::make_pair(2, 7)); + + EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁 Hell😁😁World.", {2, 7}), + std::make_pair(3, 9)); +} + +} // namespace +} // namespace libtextclassifier3 diff --git a/annotator/annotator_test.cc b/annotator/annotator_test.cc new file mode 100644 index 0000000..fbaf039 --- /dev/null +++ b/annotator/annotator_test.cc @@ -0,0 +1,1254 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "annotator/annotator.h" + +#include <fstream> +#include <iostream> +#include <memory> +#include <string> + +#include "annotator/model_generated.h" +#include "annotator/types-test-util.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace libtextclassifier3 { +namespace { + +using testing::ElementsAreArray; +using testing::IsEmpty; +using testing::Pair; +using testing::Values; + +std::string FirstResult(const std::vector<ClassificationResult>& results) { + if (results.empty()) { + return "<INVALID RESULTS>"; + } + return results[0].collection; +} + +MATCHER_P3(IsAnnotatedSpan, start, end, best_class, "") { + return testing::Value(arg.span, Pair(start, end)) && + testing::Value(FirstResult(arg.classification), best_class); +} + +std::string ReadFile(const std::string& file_name) { + std::ifstream file_stream(file_name); + return std::string(std::istreambuf_iterator<char>(file_stream), {}); +} + +std::string GetModelPath() { + return TC3_TEST_DATA_DIR; +} + +class AnnotatorTest : public ::testing::TestWithParam<const char*> { + protected: + AnnotatorTest() + : INIT_UNILIB_FOR_TESTING(unilib_), + INIT_CALENDARLIB_FOR_TESTING(calendarlib_) {} + UniLib unilib_; + CalendarLib calendarlib_; +}; + +TEST_F(AnnotatorTest, EmbeddingExecutorLoadingFails) { + std::unique_ptr<Annotator> classifier = Annotator::FromPath( + GetModelPath() + "wrong_embeddings.fb", &unilib_, &calendarlib_); + EXPECT_FALSE(classifier); +} + +INSTANTIATE_TEST_CASE_P(ClickContext, AnnotatorTest, + Values("test_model_cc.fb")); +INSTANTIATE_TEST_CASE_P(BoundsSensitive, AnnotatorTest, + Values("test_model.fb")); + +TEST_P(AnnotatorTest, ClassifyText) { + std::unique_ptr<Annotator> classifier = + Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + EXPECT_EQ("other", + FirstResult(classifier->ClassifyText( + "this afternoon Barack Obama gave a speech at", {15, 27}))); + EXPECT_EQ("phone", 
FirstResult(classifier->ClassifyText( + "Call me at (800) 123-456 today", {11, 24}))); + + // More lines. + EXPECT_EQ("other", + FirstResult(classifier->ClassifyText( + "this afternoon Barack Obama gave a speech at|Visit " + "www.google.com every today!|Call me at (800) 123-456 today.", + {15, 27}))); + EXPECT_EQ("phone", + FirstResult(classifier->ClassifyText( + "this afternoon Barack Obama gave a speech at|Visit " + "www.google.com every today!|Call me at (800) 123-456 today.", + {90, 103}))); + + // Single word. + EXPECT_EQ("other", FirstResult(classifier->ClassifyText("obama", {0, 5}))); + EXPECT_EQ("other", FirstResult(classifier->ClassifyText("asdf", {0, 4}))); + EXPECT_EQ("<INVALID RESULTS>", + FirstResult(classifier->ClassifyText("asdf", {0, 0}))); + + // Junk. + EXPECT_EQ("<INVALID RESULTS>", + FirstResult(classifier->ClassifyText("", {0, 0}))); + EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText( + "a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5}))); + // Test invalid utf8 input. + EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText( + "\xf0\x9f\x98\x8b\x8b", {0, 0}))); +} + +TEST_P(AnnotatorTest, ClassifyTextDisabledFail) { + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + unpacked_model->classification_model.clear(); + unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT); + unpacked_model->triggering_options->enabled_modes = ModeFlag_SELECTION; + + flatbuffers::FlatBufferBuilder builder; + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); + + // The classification model is still needed for selection scores. 
+ ASSERT_FALSE(classifier); +} + +TEST_P(AnnotatorTest, ClassifyTextDisabled) { + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT); + unpacked_model->triggering_options->enabled_modes = + ModeFlag_ANNOTATION_AND_SELECTION; + + flatbuffers::FlatBufferBuilder builder; + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + EXPECT_THAT( + classifier->ClassifyText("Call me at (800) 123-456 today", {11, 24}), + IsEmpty()); +} + +TEST_P(AnnotatorTest, ClassifyTextFilteredCollections) { + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + test_model.c_str(), test_model.size(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + EXPECT_EQ("phone", FirstResult(classifier->ClassifyText( + "Call me at (800) 123-456 today", {11, 24}))); + + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + unpacked_model->output_options.reset(new OutputOptionsT); + + // Disable phone classification + unpacked_model->output_options->filtered_collections_classification.push_back( + "phone"); + + flatbuffers::FlatBufferBuilder builder; + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); + + classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + EXPECT_EQ("other", FirstResult(classifier->ClassifyText( + "Call me at (800) 123-456 today", {11, 24}))); + + // Check that the address classification still passes. 
+ EXPECT_EQ("address", FirstResult(classifier->ClassifyText( + "350 Third Street, Cambridge", {0, 27}))); +} + +std::unique_ptr<RegexModel_::PatternT> MakePattern( + const std::string& collection_name, const std::string& pattern, + const bool enabled_for_classification, const bool enabled_for_selection, + const bool enabled_for_annotation, const float score) { + std::unique_ptr<RegexModel_::PatternT> result(new RegexModel_::PatternT); + result->collection_name = collection_name; + result->pattern = pattern; + // We cannot directly operate with |= on the flag, so use an int here. + int enabled_modes = ModeFlag_NONE; + if (enabled_for_annotation) enabled_modes |= ModeFlag_ANNOTATION; + if (enabled_for_classification) enabled_modes |= ModeFlag_CLASSIFICATION; + if (enabled_for_selection) enabled_modes |= ModeFlag_SELECTION; + result->enabled_modes = static_cast<ModeFlag>(enabled_modes); + result->target_classification_score = score; + result->priority_score = score; + return result; +} + +#ifdef TC3_UNILIB_ICU +TEST_P(AnnotatorTest, ClassifyTextRegularExpression) { + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + // Add test regex models. 
+ unpacked_model->regex_model->patterns.push_back(MakePattern( + "person", "Barack Obama", /*enabled_for_classification=*/true, + /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 1.0)); + unpacked_model->regex_model->patterns.push_back(MakePattern( + "flight", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true, + /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 0.5)); + std::unique_ptr<RegexModel_::PatternT> verified_pattern = + MakePattern("payment_card", "\\d{4}(?: \\d{4}){3}", + /*enabled_for_classification=*/true, + /*enabled_for_selection=*/false, + /*enabled_for_annotation=*/false, 1.0); + verified_pattern->verification_options.reset(new VerificationOptionsT); + verified_pattern->verification_options->verify_luhn_checksum = true; + unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern)); + + flatbuffers::FlatBufferBuilder builder; + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + EXPECT_EQ("flight", + FirstResult(classifier->ClassifyText( + "Your flight LX373 is delayed by 3 hours.", {12, 17}))); + EXPECT_EQ("person", + FirstResult(classifier->ClassifyText( + "this afternoon Barack Obama gave a speech at", {15, 27}))); + EXPECT_EQ("email", + FirstResult(classifier->ClassifyText("you@android.com", {0, 15}))); + EXPECT_EQ("email", FirstResult(classifier->ClassifyText( + "Contact me at you@android.com", {14, 29}))); + + EXPECT_EQ("url", FirstResult(classifier->ClassifyText( + "Visit www.google.com every today!", {6, 20}))); + + EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("LX 37", {0, 5}))); + EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("flight LX 37 abcd", + {7, 12}))); + EXPECT_EQ("payment_card", FirstResult(classifier->ClassifyText( + 
"cc: 4012 8888 8888 1881", {4, 23}))); + EXPECT_EQ("payment_card", FirstResult(classifier->ClassifyText( + "2221 0067 4735 6281", {0, 19}))); + // Luhn check fails. + EXPECT_EQ("other", FirstResult(classifier->ClassifyText("2221 0067 4735 6282", + {0, 19}))); + + // More lines. + EXPECT_EQ("url", + FirstResult(classifier->ClassifyText( + "this afternoon Barack Obama gave a speech at|Visit " + "www.google.com every today!|Call me at (800) 123-456 today.", + {51, 65}))); +} +#endif // TC3_UNILIB_ICU + +#ifdef TC3_UNILIB_ICU +TEST_P(AnnotatorTest, SuggestSelectionRegularExpression) { + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + // Add test regex models. + unpacked_model->regex_model->patterns.push_back(MakePattern( + "person", " (Barack Obama) ", /*enabled_for_classification=*/false, + /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0)); + unpacked_model->regex_model->patterns.push_back(MakePattern( + "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false, + /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0)); + unpacked_model->regex_model->patterns.back()->priority_score = 1.1; + std::unique_ptr<RegexModel_::PatternT> verified_pattern = + MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})", + /*enabled_for_classification=*/false, + /*enabled_for_selection=*/true, + /*enabled_for_annotation=*/false, 1.0); + verified_pattern->verification_options.reset(new VerificationOptionsT); + verified_pattern->verification_options->verify_luhn_checksum = true; + unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern)); + + flatbuffers::FlatBufferBuilder builder; + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, 
&calendarlib_); + ASSERT_TRUE(classifier); + + // Check regular expression selection. + EXPECT_EQ(classifier->SuggestSelection( + "Your flight MA 0123 is delayed by 3 hours.", {12, 14}), + std::make_pair(12, 19)); + EXPECT_EQ(classifier->SuggestSelection( + "this afternoon Barack Obama gave a speech at", {15, 21}), + std::make_pair(15, 27)); + EXPECT_EQ(classifier->SuggestSelection("cc: 4012 8888 8888 1881", {9, 14}), + std::make_pair(4, 23)); +} +#endif // TC3_UNILIB_ICU + +#ifdef TC3_UNILIB_ICU +TEST_P(AnnotatorTest, SuggestSelectionRegularExpressionConflictsModelWins) { + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + // Add test regex models. + unpacked_model->regex_model->patterns.push_back(MakePattern( + "person", " (Barack Obama) ", /*enabled_for_classification=*/false, + /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0)); + unpacked_model->regex_model->patterns.push_back(MakePattern( + "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false, + /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0)); + unpacked_model->regex_model->patterns.back()->priority_score = 0.5; + + flatbuffers::FlatBufferBuilder builder; + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize()); + ASSERT_TRUE(classifier); + + // Check conflict resolution. + EXPECT_EQ( + classifier->SuggestSelection( + "saw Barack Obama today .. 
350 Third Street, Cambridge, MA 0123", + {55, 57}), + std::make_pair(26, 62)); +} +#endif // TC3_UNILIB_ICU + +#ifdef TC3_UNILIB_ICU +TEST_P(AnnotatorTest, SuggestSelectionRegularExpressionConflictsRegexWins) { + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + // Add test regex models. + unpacked_model->regex_model->patterns.push_back(MakePattern( + "person", " (Barack Obama) ", /*enabled_for_classification=*/false, + /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0)); + unpacked_model->regex_model->patterns.push_back(MakePattern( + "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false, + /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0)); + unpacked_model->regex_model->patterns.back()->priority_score = 1.1; + + flatbuffers::FlatBufferBuilder builder; + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize()); + ASSERT_TRUE(classifier); + + // Check conflict resolution. + EXPECT_EQ( + classifier->SuggestSelection( + "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123", + {55, 57}), + std::make_pair(55, 62)); +} +#endif // TC3_UNILIB_ICU + +#ifdef TC3_UNILIB_ICU +TEST_P(AnnotatorTest, AnnotateRegex) { + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + // Add test regex models. 
+ unpacked_model->regex_model->patterns.push_back(MakePattern( + "person", " (Barack Obama) ", /*enabled_for_classification=*/false, + /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 1.0)); + unpacked_model->regex_model->patterns.push_back(MakePattern( + "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false, + /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 0.5)); + std::unique_ptr<RegexModel_::PatternT> verified_pattern = + MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})", + /*enabled_for_classification=*/false, + /*enabled_for_selection=*/false, + /*enabled_for_annotation=*/true, 1.0); + verified_pattern->verification_options.reset(new VerificationOptionsT); + verified_pattern->verification_options->verify_luhn_checksum = true; + unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern)); + flatbuffers::FlatBufferBuilder builder; + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + const std::string test_string = + "& saw Barack Obama today .. 
350 Third Street, Cambridge\nand my phone "
+      "number is 853 225 3556\nand my card is 4012 8888 8888 1881.\n";
+  EXPECT_THAT(classifier->Annotate(test_string),
+              ElementsAreArray({IsAnnotatedSpan(6, 18, "person"),
+                                IsAnnotatedSpan(28, 55, "address"),
+                                IsAnnotatedSpan(79, 91, "phone"),
+                                IsAnnotatedSpan(107, 126, "payment_card")}));
+}
+#endif  // TC3_UNILIB_ICU
+
+TEST_P(AnnotatorTest, PhoneFiltering) {
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
+                         "phone: (123) 456 789", {7, 20})));
+  EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
+                         "phone: (123) 456 789,0001112", {7, 25})));
+  EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
+                         "phone: (123) 456 789,0001112", {7, 28})));
+}
+
+TEST_P(AnnotatorTest, SuggestSelection) {
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  EXPECT_EQ(classifier->SuggestSelection(
+                "this afternoon Barack Obama gave a speech at", {15, 21}),
+            std::make_pair(15, 21));
+
+  // Try passing the whole string.
+  // If more than 1 token is specified, we should return back what was entered.
+  EXPECT_EQ(
+      classifier->SuggestSelection("350 Third Street, Cambridge", {0, 27}),
+      std::make_pair(0, 27));
+
+  // Single letter.
+  EXPECT_EQ(classifier->SuggestSelection("a", {0, 1}), std::make_pair(0, 1));
+
+  // Single word.
+  EXPECT_EQ(classifier->SuggestSelection("asdf", {0, 4}), std::make_pair(0, 4));
+
+  EXPECT_EQ(
+      classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
+      std::make_pair(11, 23));
+
+  // Unpaired bracket stripping.
+ EXPECT_EQ( + classifier->SuggestSelection("call me at (857) 225 3556 today", {11, 16}), + std::make_pair(11, 25)); + EXPECT_EQ(classifier->SuggestSelection("call me at (857 today", {11, 15}), + std::make_pair(12, 15)); + EXPECT_EQ(classifier->SuggestSelection("call me at 3556) today", {11, 16}), + std::make_pair(11, 15)); + EXPECT_EQ(classifier->SuggestSelection("call me at )857( today", {11, 16}), + std::make_pair(12, 15)); + + // If the resulting selection would be empty, the original span is returned. + EXPECT_EQ(classifier->SuggestSelection("call me at )( today", {11, 13}), + std::make_pair(11, 13)); + EXPECT_EQ(classifier->SuggestSelection("call me at ( today", {11, 12}), + std::make_pair(11, 12)); + EXPECT_EQ(classifier->SuggestSelection("call me at ) today", {11, 12}), + std::make_pair(11, 12)); +} + +TEST_P(AnnotatorTest, SuggestSelectionDisabledFail) { + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + // Disable the selection model. + unpacked_model->selection_model.clear(); + unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT); + unpacked_model->triggering_options->enabled_modes = ModeFlag_ANNOTATION; + + flatbuffers::FlatBufferBuilder builder; + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); + // Selection model needs to be present for annotation. + ASSERT_FALSE(classifier); +} + +TEST_P(AnnotatorTest, SuggestSelectionDisabled) { + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + // Disable the selection model. 
+ unpacked_model->selection_model.clear(); + unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT); + unpacked_model->triggering_options->enabled_modes = ModeFlag_CLASSIFICATION; + unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION; + + flatbuffers::FlatBufferBuilder builder; + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + EXPECT_EQ( + classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}), + std::make_pair(11, 14)); + + EXPECT_EQ("phone", FirstResult(classifier->ClassifyText( + "call me at (800) 123-456 today", {11, 24}))); + + EXPECT_THAT(classifier->Annotate("call me at (800) 123-456 today"), + IsEmpty()); +} + +TEST_P(AnnotatorTest, SuggestSelectionFilteredCollections) { + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + test_model.c_str(), test_model.size(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + EXPECT_EQ( + classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}), + std::make_pair(11, 23)); + + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + unpacked_model->output_options.reset(new OutputOptionsT); + + // Disable phone selection + unpacked_model->output_options->filtered_collections_selection.push_back( + "phone"); + // We need to force this for filtering. 
+ unpacked_model->selection_options->always_classify_suggested_selection = true; + + flatbuffers::FlatBufferBuilder builder; + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); + + classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + EXPECT_EQ( + classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}), + std::make_pair(11, 14)); + + // Address selection should still work. + EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}), + std::make_pair(0, 27)); +} + +TEST_P(AnnotatorTest, SuggestSelectionsAreSymmetric) { + std::unique_ptr<Annotator> classifier = + Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {0, 3}), + std::make_pair(0, 27)); + EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}), + std::make_pair(0, 27)); + EXPECT_EQ( + classifier->SuggestSelection("350 Third Street, Cambridge", {10, 16}), + std::make_pair(0, 27)); + EXPECT_EQ(classifier->SuggestSelection("a\nb\nc\n350 Third Street, Cambridge", + {16, 22}), + std::make_pair(6, 33)); +} + +TEST_P(AnnotatorTest, SuggestSelectionWithNewLine) { + std::unique_ptr<Annotator> classifier = + Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + EXPECT_EQ(classifier->SuggestSelection("abc\n857 225 3556", {4, 7}), + std::make_pair(4, 16)); + EXPECT_EQ(classifier->SuggestSelection("857 225 3556\nabc", {0, 3}), + std::make_pair(0, 12)); + + SelectionOptions options; + EXPECT_EQ(classifier->SuggestSelection("857 225\n3556\nabc", {0, 3}, options), + std::make_pair(0, 7)); +} + +TEST_P(AnnotatorTest, SuggestSelectionWithPunctuation) { + std::unique_ptr<Annotator> classifier = + Annotator::FromPath(GetModelPath() + GetParam(), 
&unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + // From the right. + EXPECT_EQ(classifier->SuggestSelection( + "this afternoon BarackObama, gave a speech at", {15, 26}), + std::make_pair(15, 26)); + + // From the right multiple. + EXPECT_EQ(classifier->SuggestSelection( + "this afternoon BarackObama,.,.,, gave a speech at", {15, 26}), + std::make_pair(15, 26)); + + // From the left multiple. + EXPECT_EQ(classifier->SuggestSelection( + "this afternoon ,.,.,,BarackObama gave a speech at", {21, 32}), + std::make_pair(21, 32)); + + // From both sides. + EXPECT_EQ(classifier->SuggestSelection( + "this afternoon !BarackObama,- gave a speech at", {16, 27}), + std::make_pair(16, 27)); +} + +TEST_P(AnnotatorTest, SuggestSelectionNoCrashWithJunk) { + std::unique_ptr<Annotator> classifier = + Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + // Try passing in bunch of invalid selections. + EXPECT_EQ(classifier->SuggestSelection("", {0, 27}), std::make_pair(0, 27)); + EXPECT_EQ(classifier->SuggestSelection("", {-10, 27}), + std::make_pair(-10, 27)); + EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {0, 27}), + std::make_pair(0, 27)); + EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {-30, 300}), + std::make_pair(-30, 300)); + EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {-10, -1}), + std::make_pair(-10, -1)); + EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {100, 17}), + std::make_pair(100, 17)); + + // Try passing invalid utf8. 
+ EXPECT_EQ(classifier->SuggestSelection("\xf0\x9f\x98\x8b\x8b", {-1, -1}), + std::make_pair(-1, -1)); +} + +TEST_P(AnnotatorTest, SuggestSelectionSelectSpace) { + std::unique_ptr<Annotator> classifier = + Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + EXPECT_EQ( + classifier->SuggestSelection("call me at 857 225 3556 today", {14, 15}), + std::make_pair(11, 23)); + EXPECT_EQ( + classifier->SuggestSelection("call me at 857 225 3556 today", {10, 11}), + std::make_pair(10, 11)); + EXPECT_EQ( + classifier->SuggestSelection("call me at 857 225 3556 today", {23, 24}), + std::make_pair(23, 24)); + EXPECT_EQ( + classifier->SuggestSelection("call me at 857 225 3556, today", {23, 24}), + std::make_pair(23, 24)); + EXPECT_EQ(classifier->SuggestSelection("call me at 857 225 3556, today", + {14, 17}), + std::make_pair(11, 25)); + EXPECT_EQ( + classifier->SuggestSelection("call me at 857-225 3556, today", {14, 17}), + std::make_pair(11, 23)); + EXPECT_EQ( + classifier->SuggestSelection( + "let's meet at 350 Third Street Cambridge and go there", {30, 31}), + std::make_pair(14, 40)); + EXPECT_EQ(classifier->SuggestSelection("call me today", {4, 5}), + std::make_pair(4, 5)); + EXPECT_EQ(classifier->SuggestSelection("call me today", {7, 8}), + std::make_pair(7, 8)); + + // With a punctuation around the selected whitespace. + EXPECT_EQ( + classifier->SuggestSelection( + "let's meet at 350 Third Street, Cambridge and go there", {31, 32}), + std::make_pair(14, 41)); + + // When all's whitespace, should return the original indices. 
+ EXPECT_EQ(classifier->SuggestSelection(" ", {0, 1}), + std::make_pair(0, 1)); + EXPECT_EQ(classifier->SuggestSelection(" ", {0, 3}), + std::make_pair(0, 3)); + EXPECT_EQ(classifier->SuggestSelection(" ", {2, 3}), + std::make_pair(2, 3)); + EXPECT_EQ(classifier->SuggestSelection(" ", {5, 6}), + std::make_pair(5, 6)); +} + +TEST_F(AnnotatorTest, SnapLeftIfWhitespaceSelection) { + UnicodeText text; + + text = UTF8ToUnicodeText("abcd efgh", /*do_copy=*/false); + EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_), + std::make_pair(3, 4)); + text = UTF8ToUnicodeText("abcd ", /*do_copy=*/false); + EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_), + std::make_pair(3, 4)); + + // Nothing on the left. + text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false); + EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_), + std::make_pair(4, 5)); + text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false); + EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib_), + std::make_pair(0, 1)); + + // Whitespace only. + text = UTF8ToUnicodeText(" ", /*do_copy=*/false); + EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({2, 3}, text, unilib_), + std::make_pair(2, 3)); + text = UTF8ToUnicodeText(" ", /*do_copy=*/false); + EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_), + std::make_pair(4, 5)); + text = UTF8ToUnicodeText(" ", /*do_copy=*/false); + EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib_), + std::make_pair(0, 1)); +} + +TEST_P(AnnotatorTest, Annotate) { + std::unique_ptr<Annotator> classifier = + Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + const std::string test_string = + "& saw Barack Obama today .. 
350 Third Street, Cambridge\nand my phone " + "number is 853 225 3556"; + EXPECT_THAT(classifier->Annotate(test_string), + ElementsAreArray({ + IsAnnotatedSpan(28, 55, "address"), + IsAnnotatedSpan(79, 91, "phone"), + })); + + AnnotationOptions options; + EXPECT_THAT(classifier->Annotate("853 225 3556", options), + ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")})); + EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty()); + + // Try passing invalid utf8. + EXPECT_TRUE( + classifier->Annotate("853 225 3556\n\xf0\x9f\x98\x8b\x8b", options) + .empty()); +} + + +TEST_P(AnnotatorTest, AnnotateSmallBatches) { + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + // Set the batch size. + unpacked_model->selection_options->batch_size = 4; + flatbuffers::FlatBufferBuilder builder; + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + const std::string test_string = + "& saw Barack Obama today .. 
350 Third Street, Cambridge\nand my phone " + "number is 853 225 3556"; + EXPECT_THAT(classifier->Annotate(test_string), + ElementsAreArray({ + IsAnnotatedSpan(28, 55, "address"), + IsAnnotatedSpan(79, 91, "phone"), + })); + + AnnotationOptions options; + EXPECT_THAT(classifier->Annotate("853 225 3556", options), + ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")})); + EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty()); +} + +#ifdef TC3_UNILIB_ICU +TEST_P(AnnotatorTest, AnnotateFilteringDiscardAll) { + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT); + // Add test threshold. + unpacked_model->triggering_options->min_annotate_confidence = + 2.f; // Discards all results. + flatbuffers::FlatBufferBuilder builder; + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + const std::string test_string = + "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone " + "number is 853 225 3556"; + + EXPECT_EQ(classifier->Annotate(test_string).size(), 0); +} +#endif // TC3_UNILIB_ICU + +TEST_P(AnnotatorTest, AnnotateFilteringKeepAll) { + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + // Add test thresholds. + unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT); + unpacked_model->triggering_options->min_annotate_confidence = + 0.f; // Keeps all results. 
+ unpacked_model->triggering_options->enabled_modes = ModeFlag_ALL; + flatbuffers::FlatBufferBuilder builder; + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + const std::string test_string = + "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone " + "number is 853 225 3556"; + EXPECT_EQ(classifier->Annotate(test_string).size(), 2); +} + +TEST_P(AnnotatorTest, AnnotateDisabled) { + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + // Disable the model for annotation. + unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION_AND_SELECTION; + flatbuffers::FlatBufferBuilder builder; + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); + + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + const std::string test_string = + "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone " + "number is 853 225 3556"; + EXPECT_THAT(classifier->Annotate(test_string), IsEmpty()); +} + +TEST_P(AnnotatorTest, AnnotateFilteredCollections) { + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + test_model.c_str(), test_model.size(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + const std::string test_string = + "& saw Barack Obama today .. 
350 Third Street, Cambridge\nand my phone " + "number is 853 225 3556"; + + EXPECT_THAT(classifier->Annotate(test_string), + ElementsAreArray({ + IsAnnotatedSpan(28, 55, "address"), + IsAnnotatedSpan(79, 91, "phone"), + })); + + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + unpacked_model->output_options.reset(new OutputOptionsT); + + // Disable phone annotation + unpacked_model->output_options->filtered_collections_annotation.push_back( + "phone"); + + flatbuffers::FlatBufferBuilder builder; + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); + + classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + EXPECT_THAT(classifier->Annotate(test_string), + ElementsAreArray({ + IsAnnotatedSpan(28, 55, "address"), + })); +} + +#ifdef TC3_UNILIB_ICU +TEST_P(AnnotatorTest, AnnotateFilteredCollectionsSuppress) { + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + test_model.c_str(), test_model.size(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + const std::string test_string = + "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone " + "number is 853 225 3556"; + + EXPECT_THAT(classifier->Annotate(test_string), + ElementsAreArray({ + IsAnnotatedSpan(28, 55, "address"), + IsAnnotatedSpan(79, 91, "phone"), + })); + + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + unpacked_model->output_options.reset(new OutputOptionsT); + + // We add a custom annotator that wins against the phone classification + // below and that we subsequently suppress. 
+ unpacked_model->output_options->filtered_collections_annotation.push_back( + "suppress"); + + unpacked_model->regex_model->patterns.push_back(MakePattern( + "suppress", "(\\d{3} ?\\d{4})", + /*enabled_for_classification=*/false, + /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 2.0)); + + flatbuffers::FlatBufferBuilder builder; + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); + + classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + EXPECT_THAT(classifier->Annotate(test_string), + ElementsAreArray({ + IsAnnotatedSpan(28, 55, "address"), + })); +} +#endif // TC3_UNILIB_ICU + +#ifdef TC3_CALENDAR_ICU +TEST_P(AnnotatorTest, ClassifyTextDate) { + std::unique_ptr<Annotator> classifier = + Annotator::FromPath(GetModelPath() + GetParam()); + EXPECT_TRUE(classifier); + + std::vector<ClassificationResult> result; + ClassificationOptions options; + + options.reference_timezone = "Europe/Zurich"; + result = classifier->ClassifyText("january 1, 2017", {0, 15}, options); + + ASSERT_EQ(result.size(), 1); + EXPECT_THAT(result[0].collection, "date"); + EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000); + EXPECT_EQ(result[0].datetime_parse_result.granularity, + DatetimeGranularity::GRANULARITY_DAY); + result.clear(); + + options.reference_timezone = "America/Los_Angeles"; + result = classifier->ClassifyText("march 1, 2017", {0, 13}, options); + ASSERT_EQ(result.size(), 1); + EXPECT_THAT(result[0].collection, "date"); + EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1488355200000); + EXPECT_EQ(result[0].datetime_parse_result.granularity, + DatetimeGranularity::GRANULARITY_DAY); + result.clear(); + + options.reference_timezone = "America/Los_Angeles"; + result = classifier->ClassifyText("2018/01/01 10:30:20", {0, 19}, options); + ASSERT_EQ(result.size(), 1); + 
EXPECT_THAT(result[0].collection, "date");
+  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1514831420000);
+  EXPECT_EQ(result[0].datetime_parse_result.granularity,
+            DatetimeGranularity::GRANULARITY_SECOND);
+  result.clear();
+
+  // Date on another line.
+  options.reference_timezone = "Europe/Zurich";
+  result = classifier->ClassifyText(
+      "hello world this is the first line\n"
+      "january 1, 2017",
+      {35, 50}, options);
+  ASSERT_EQ(result.size(), 1);
+  EXPECT_THAT(result[0].collection, "date");
+  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000);
+  EXPECT_EQ(result[0].datetime_parse_result.granularity,
+            DatetimeGranularity::GRANULARITY_DAY);
+}
+#endif  // TC3_CALENDAR_ICU
+
+#ifdef TC3_CALENDAR_ICU
+TEST_P(AnnotatorTest, ClassifyTextDatePriorities) {
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromPath(GetModelPath() + GetParam());
+  EXPECT_TRUE(classifier);
+
+  std::vector<ClassificationResult> result;
+  ClassificationOptions options;
+
+  result.clear();
+  options.reference_timezone = "Europe/Zurich";
+  options.locales = "en-US";
+  result = classifier->ClassifyText("03.05.1970", {0, 10}, options);
+
+  ASSERT_EQ(result.size(), 1);
+  EXPECT_THAT(result[0].collection, "date");
+  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 5439600000);
+  EXPECT_EQ(result[0].datetime_parse_result.granularity,
+            DatetimeGranularity::GRANULARITY_DAY);
+
+  result.clear();
+  options.reference_timezone = "Europe/Zurich";
+  options.locales = "de";
+  result = classifier->ClassifyText("03.05.1970", {0, 10}, options);
+
+  ASSERT_EQ(result.size(), 1);
+  EXPECT_THAT(result[0].collection, "date");
+  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 10537200000);
+  EXPECT_EQ(result[0].datetime_parse_result.granularity,
+            DatetimeGranularity::GRANULARITY_DAY);
+}
+#endif  // TC3_CALENDAR_ICU
+
+#ifdef TC3_CALENDAR_ICU
+TEST_P(AnnotatorTest, SuggestTextDateDisabled) {
+  const std::string test_model = ReadFile(GetModelPath() + GetParam());
+  
std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+  // Disable the patterns for selection.
+  for (int i = 0; i < unpacked_model->datetime_model->patterns.size(); i++) {
+    unpacked_model->datetime_model->patterns[i]->enabled_modes =
+        ModeFlag_ANNOTATION_AND_CLASSIFICATION;
+  }
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+  EXPECT_EQ("date",
+            FirstResult(classifier->ClassifyText("january 1, 2017", {0, 15})));
+  EXPECT_EQ(classifier->SuggestSelection("january 1, 2017", {0, 7}),
+            std::make_pair(0, 7));
+  EXPECT_THAT(classifier->Annotate("january 1, 2017"),
+              ElementsAreArray({IsAnnotatedSpan(0, 15, "date")}));
+}
+#endif  // TC3_CALENDAR_ICU
+
+class TestingAnnotator : public Annotator {
+ public:
+  TestingAnnotator(const std::string& model, const UniLib* unilib,
+                   const CalendarLib* calendarlib)
+      : Annotator(ViewModel(model.data(), model.size()), unilib, calendarlib) {}
+
+  using Annotator::ResolveConflicts;
+};
+
+AnnotatedSpan MakeAnnotatedSpan(CodepointSpan span,
+                                const std::string& collection,
+                                const float score) {
+  AnnotatedSpan result;
+  result.span = span;
+  result.classification.push_back({collection, score});
+  return result;
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsTrivial) {
+  TestingAnnotator classifier("", &unilib_, &calendarlib_);
+
+  std::vector<AnnotatedSpan> candidates{
+      {MakeAnnotatedSpan({0, 1}, "phone", 1.0)}};
+
+  std::vector<int> chosen;
+  classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+                              /*interpreter_manager=*/nullptr, &chosen);
+  EXPECT_THAT(chosen, ElementsAreArray({0}));
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsSequence) {
+  TestingAnnotator classifier("", &unilib_, &calendarlib_);
+
+  
std::vector<AnnotatedSpan> candidates{{
+      MakeAnnotatedSpan({0, 1}, "phone", 1.0),
+      MakeAnnotatedSpan({1, 2}, "phone", 1.0),
+      MakeAnnotatedSpan({2, 3}, "phone", 1.0),
+      MakeAnnotatedSpan({3, 4}, "phone", 1.0),
+      MakeAnnotatedSpan({4, 5}, "phone", 1.0),
+  }};
+
+  std::vector<int> chosen;
+  classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+                              /*interpreter_manager=*/nullptr, &chosen);
+  EXPECT_THAT(chosen, ElementsAreArray({0, 1, 2, 3, 4}));
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsThreeSpans) {
+  TestingAnnotator classifier("", &unilib_, &calendarlib_);
+
+  std::vector<AnnotatedSpan> candidates{{
+      MakeAnnotatedSpan({0, 3}, "phone", 1.0),
+      MakeAnnotatedSpan({1, 5}, "phone", 0.5),  // Loser!
+      MakeAnnotatedSpan({3, 7}, "phone", 1.0),
+  }};
+
+  std::vector<int> chosen;
+  classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+                              /*interpreter_manager=*/nullptr, &chosen);
+  EXPECT_THAT(chosen, ElementsAreArray({0, 2}));
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsThreeSpansReversed) {
+  TestingAnnotator classifier("", &unilib_, &calendarlib_);
+
+  std::vector<AnnotatedSpan> candidates{{
+      MakeAnnotatedSpan({0, 3}, "phone", 0.5),  // Loser!
+      MakeAnnotatedSpan({1, 5}, "phone", 1.0),
+      MakeAnnotatedSpan({3, 7}, "phone", 0.6),  // Loser!
+  }};
+
+  std::vector<int> chosen;
+  classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+                              /*interpreter_manager=*/nullptr, &chosen);
+  EXPECT_THAT(chosen, ElementsAreArray({1}));
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsFiveSpans) {
+  TestingAnnotator classifier("", &unilib_, &calendarlib_);
+
+  std::vector<AnnotatedSpan> candidates{{
+      MakeAnnotatedSpan({0, 3}, "phone", 0.5),
+      MakeAnnotatedSpan({1, 5}, "other", 1.0),  // Loser!
+      MakeAnnotatedSpan({3, 7}, "phone", 0.6),
+      MakeAnnotatedSpan({8, 12}, "phone", 0.6),  // Loser!
+ MakeAnnotatedSpan({11, 15}, "phone", 0.9), + }}; + + std::vector<int> chosen; + classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{}, + /*interpreter_manager=*/nullptr, &chosen); + EXPECT_THAT(chosen, ElementsAreArray({0, 2, 4})); +} + +#ifdef TC3_UNILIB_ICU +TEST_P(AnnotatorTest, LongInput) { + std::unique_ptr<Annotator> classifier = + Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + for (const auto& type_value_pair : + std::vector<std::pair<std::string, std::string>>{ + {"address", "350 Third Street, Cambridge"}, + {"phone", "123 456-7890"}, + {"url", "www.google.com"}, + {"email", "someone@gmail.com"}, + {"flight", "LX 38"}, + {"date", "September 1, 2018"}}) { + const std::string input_100k = std::string(50000, ' ') + + type_value_pair.second + + std::string(50000, ' '); + const int value_length = type_value_pair.second.size(); + + EXPECT_THAT(classifier->Annotate(input_100k), + ElementsAreArray({IsAnnotatedSpan(50000, 50000 + value_length, + type_value_pair.first)})); + EXPECT_EQ(classifier->SuggestSelection(input_100k, {50000, 50001}), + std::make_pair(50000, 50000 + value_length)); + EXPECT_EQ(type_value_pair.first, + FirstResult(classifier->ClassifyText( + input_100k, {50000, 50000 + value_length}))); + } +} +#endif // TC3_UNILIB_ICU + +#ifdef TC3_UNILIB_ICU +// These coarse tests are there only to make sure the execution happens in +// reasonable amount of time. 
+TEST_P(AnnotatorTest, LongInputNoResultCheck) {
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  for (const std::string& value :
+       std::vector<std::string>{"http://www.aaaaaaaaaaaaaaaaaaaa.com "}) {
+    const std::string input_100k =
+        std::string(50000, ' ') + value + std::string(50000, ' ');
+    const int value_length = value.size();
+
+    classifier->Annotate(input_100k);
+    classifier->SuggestSelection(input_100k, {50000, 50001});
+    classifier->ClassifyText(input_100k, {50000, 50000 + value_length});
+  }
+}
+#endif  // TC3_UNILIB_ICU
+
+#ifdef TC3_UNILIB_ICU
+TEST_P(AnnotatorTest, MaxTokenLength) {
+  const std::string test_model = ReadFile(GetModelPath() + GetParam());
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+  std::unique_ptr<Annotator> classifier;
+
+  // With unrestricted number of tokens should behave normally.
+  unpacked_model->classification_options->max_num_tokens = -1;
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+  classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  EXPECT_EQ(FirstResult(classifier->ClassifyText(
+                "I live at 350 Third Street, Cambridge.", {10, 37})),
+            "address");
+
+  // Lower the maximum number of tokens to suppress the classification.
+ unpacked_model->classification_options->max_num_tokens = 3; + + flatbuffers::FlatBufferBuilder builder2; + FinishModelBuffer(builder2, Model::Pack(builder2, unpacked_model.get())); + classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder2.GetBufferPointer()), + builder2.GetSize(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + EXPECT_EQ(FirstResult(classifier->ClassifyText( + "I live at 350 Third Street, Cambridge.", {10, 37})), + "other"); +} +#endif // TC3_UNILIB_ICU + +#ifdef TC3_UNILIB_ICU +TEST_P(AnnotatorTest, MinAddressTokenLength) { + const std::string test_model = ReadFile(GetModelPath() + GetParam()); + std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); + + std::unique_ptr<Annotator> classifier; + + // With unrestricted number of address tokens should behave normally. + unpacked_model->classification_options->address_min_num_tokens = 0; + + flatbuffers::FlatBufferBuilder builder; + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); + classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + EXPECT_EQ(FirstResult(classifier->ClassifyText( + "I live at 350 Third Street, Cambridge.", {10, 37})), + "address"); + + // Raise number of address tokens to suppress the address classification. 
+ unpacked_model->classification_options->address_min_num_tokens = 5; + + flatbuffers::FlatBufferBuilder builder2; + FinishModelBuffer(builder2, Model::Pack(builder2, unpacked_model.get())); + classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder2.GetBufferPointer()), + builder2.GetSize(), &unilib_, &calendarlib_); + ASSERT_TRUE(classifier); + + EXPECT_EQ(FirstResult(classifier->ClassifyText( + "I live at 350 Third Street, Cambridge.", {10, 37})), + "other"); +} +#endif // TC3_UNILIB_ICU + +} // namespace +} // namespace libtextclassifier3 diff --git a/annotator/cached-features.cc b/annotator/cached-features.cc new file mode 100644 index 0000000..480c044 --- /dev/null +++ b/annotator/cached-features.cc @@ -0,0 +1,173 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "annotator/cached-features.h" + +#include "utils/base/logging.h" +#include "utils/tensor-view.h" + +namespace libtextclassifier3 { + +namespace { + +int CalculateOutputFeaturesSize(const FeatureProcessorOptions* options, + int feature_vector_size) { + const bool bounds_sensitive_enabled = + options->bounds_sensitive_features() && + options->bounds_sensitive_features()->enabled(); + + int num_extracted_tokens = 0; + if (bounds_sensitive_enabled) { + const FeatureProcessorOptions_::BoundsSensitiveFeatures* config = + options->bounds_sensitive_features(); + num_extracted_tokens += config->num_tokens_before(); + num_extracted_tokens += config->num_tokens_inside_left(); + num_extracted_tokens += config->num_tokens_inside_right(); + num_extracted_tokens += config->num_tokens_after(); + if (config->include_inside_bag()) { + ++num_extracted_tokens; + } + } else { + num_extracted_tokens = 2 * options->context_size() + 1; + } + + int output_features_size = num_extracted_tokens * feature_vector_size; + + if (bounds_sensitive_enabled && + options->bounds_sensitive_features()->include_inside_length()) { + ++output_features_size; + } + + return output_features_size; +} + +} // namespace + +std::unique_ptr<CachedFeatures> CachedFeatures::Create( + const TokenSpan& extraction_span, + std::unique_ptr<std::vector<float>> features, + std::unique_ptr<std::vector<float>> padding_features, + const FeatureProcessorOptions* options, int feature_vector_size) { + const int min_feature_version = + options->bounds_sensitive_features() && + options->bounds_sensitive_features()->enabled() + ? 
2 + : 1; + if (options->feature_version() < min_feature_version) { + TC3_LOG(ERROR) << "Unsupported feature version."; + return nullptr; + } + + std::unique_ptr<CachedFeatures> cached_features(new CachedFeatures()); + cached_features->extraction_span_ = extraction_span; + cached_features->features_ = std::move(features); + cached_features->padding_features_ = std::move(padding_features); + cached_features->options_ = options; + + cached_features->output_features_size_ = + CalculateOutputFeaturesSize(options, feature_vector_size); + + return cached_features; +} + +void CachedFeatures::AppendClickContextFeaturesForClick( + int click_pos, std::vector<float>* output_features) const { + click_pos -= extraction_span_.first; + + AppendFeaturesInternal( + /*intended_span=*/ExpandTokenSpan(SingleTokenSpan(click_pos), + options_->context_size(), + options_->context_size()), + /*read_mask_span=*/{0, TokenSpanSize(extraction_span_)}, output_features); +} + +void CachedFeatures::AppendBoundsSensitiveFeaturesForSpan( + TokenSpan selected_span, std::vector<float>* output_features) const { + const FeatureProcessorOptions_::BoundsSensitiveFeatures* config = + options_->bounds_sensitive_features(); + + selected_span.first -= extraction_span_.first; + selected_span.second -= extraction_span_.first; + + // Append the features for tokens around the left bound. Masks out tokens + // after the right bound, so that if num_tokens_inside_left goes past it, + // padding tokens will be used. + AppendFeaturesInternal( + /*intended_span=*/{selected_span.first - config->num_tokens_before(), + selected_span.first + + config->num_tokens_inside_left()}, + /*read_mask_span=*/{0, selected_span.second}, output_features); + + // Append the features for tokens around the right bound. Masks out tokens + // before the left bound, so that if num_tokens_inside_right goes past it, + // padding tokens will be used. 
+ AppendFeaturesInternal( + /*intended_span=*/{selected_span.second - + config->num_tokens_inside_right(), + selected_span.second + config->num_tokens_after()}, + /*read_mask_span=*/{selected_span.first, TokenSpanSize(extraction_span_)}, + output_features); + + if (config->include_inside_bag()) { + AppendBagFeatures(selected_span, output_features); + } + + if (config->include_inside_length()) { + output_features->push_back( + static_cast<float>(TokenSpanSize(selected_span))); + } +} + +void CachedFeatures::AppendFeaturesInternal( + const TokenSpan& intended_span, const TokenSpan& read_mask_span, + std::vector<float>* output_features) const { + const TokenSpan copy_span = + IntersectTokenSpans(intended_span, read_mask_span); + for (int i = intended_span.first; i < copy_span.first; ++i) { + AppendPaddingFeatures(output_features); + } + output_features->insert( + output_features->end(), + features_->begin() + copy_span.first * NumFeaturesPerToken(), + features_->begin() + copy_span.second * NumFeaturesPerToken()); + for (int i = copy_span.second; i < intended_span.second; ++i) { + AppendPaddingFeatures(output_features); + } +} + +void CachedFeatures::AppendPaddingFeatures( + std::vector<float>* output_features) const { + output_features->insert(output_features->end(), padding_features_->begin(), + padding_features_->end()); +} + +void CachedFeatures::AppendBagFeatures( + const TokenSpan& bag_span, std::vector<float>* output_features) const { + const int offset = output_features->size(); + output_features->resize(output_features->size() + NumFeaturesPerToken()); + for (int i = bag_span.first; i < bag_span.second; ++i) { + for (int j = 0; j < NumFeaturesPerToken(); ++j) { + (*output_features)[offset + j] += + (*features_)[i * NumFeaturesPerToken() + j] / TokenSpanSize(bag_span); + } + } +} + +int CachedFeatures::NumFeaturesPerToken() const { + return padding_features_->size(); +} + +} // namespace libtextclassifier3 diff --git a/annotator/cached-features.h 
b/annotator/cached-features.h new file mode 100644 index 0000000..e03f79c --- /dev/null +++ b/annotator/cached-features.h @@ -0,0 +1,83 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_CACHED_FEATURES_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_CACHED_FEATURES_H_ + +#include <memory> +#include <vector> + +#include "annotator/model-executor.h" +#include "annotator/model_generated.h" +#include "annotator/types.h" + +namespace libtextclassifier3 { + +// Holds state for extracting features across multiple calls and reusing them. +// Assumes that features for each Token are independent. +class CachedFeatures { + public: + static std::unique_ptr<CachedFeatures> Create( + const TokenSpan& extraction_span, + std::unique_ptr<std::vector<float>> features, + std::unique_ptr<std::vector<float>> padding_features, + const FeatureProcessorOptions* options, int feature_vector_size); + + // Appends the click context features for the given click position to + // 'output_features'. + void AppendClickContextFeaturesForClick( + int click_pos, std::vector<float>* output_features) const; + + // Appends the bounds-sensitive features for the given token span to + // 'output_features'. + void AppendBoundsSensitiveFeaturesForSpan( + TokenSpan selected_span, std::vector<float>* output_features) const; + + // Returns number of features that 'AppendFeaturesForSpan' appends. 
+ int OutputFeaturesSize() const { return output_features_size_; } + + private: + CachedFeatures() {} + + // Appends token features to the output. The intended_span specifies which + // tokens' features should be used in principle. The read_mask_span restricts + // which tokens are actually read. For tokens outside of the read_mask_span, + // padding tokens are used instead. + void AppendFeaturesInternal(const TokenSpan& intended_span, + const TokenSpan& read_mask_span, + std::vector<float>* output_features) const; + + // Appends features of one padding token to the output. + void AppendPaddingFeatures(std::vector<float>* output_features) const; + + // Appends the features of tokens from the given span to the output. The + // features are averaged so that the appended features have the size + // corresponding to one token. + void AppendBagFeatures(const TokenSpan& bag_span, + std::vector<float>* output_features) const; + + int NumFeaturesPerToken() const; + + TokenSpan extraction_span_; + const FeatureProcessorOptions* options_; + int output_features_size_; + std::unique_ptr<std::vector<float>> features_; + std::unique_ptr<std::vector<float>> padding_features_; +}; + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_CACHED_FEATURES_H_ diff --git a/annotator/cached-features_test.cc b/annotator/cached-features_test.cc new file mode 100644 index 0000000..702f3ca --- /dev/null +++ b/annotator/cached-features_test.cc @@ -0,0 +1,157 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "annotator/cached-features.h"

#include "annotator/model-executor.h"
#include "utils/tensor-view.h"

#include "gmock/gmock.h"
#include "gtest/gtest.h"

using testing::ElementsAreArray;
using testing::FloatEq;
using testing::Matcher;

namespace libtextclassifier3 {
namespace {

// Wraps a vector of expected values into an element-wise FloatEq matcher so
// the comparisons below tolerate floating-point rounding error.
Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) {
  std::vector<Matcher<float>> matchers;
  for (const float value : values) {
    matchers.push_back(FloatEq(value));
  }
  return ElementsAreArray(matchers);
}

// Builds a fake feature matrix with 3 floats per token. Token i (1-based)
// gets the easily recognizable triple {i*11.0, -i*11.0, i*0.1}, which makes
// the expectations in the tests below readable.
std::unique_ptr<std::vector<float>> MakeFeatures(int num_tokens) {
  std::unique_ptr<std::vector<float>> features(new std::vector<float>());
  for (int i = 1; i <= num_tokens; ++i) {
    features->push_back(i * 11.0f);
    features->push_back(-i * 11.0f);
    features->push_back(i * 0.1f);
  }
  return features;
}

// Convenience wrapper: returns the click-context feature vector produced for
// a click at token index 'click_pos'.
std::vector<float> GetCachedClickContextFeatures(
    const CachedFeatures& cached_features, int click_pos) {
  std::vector<float> output_features;
  cached_features.AppendClickContextFeaturesForClick(click_pos,
                                                     &output_features);
  return output_features;
}

// Convenience wrapper: returns the bounds-sensitive feature vector produced
// for the token span 'selected_span'.
std::vector<float> GetCachedBoundsSensitiveFeatures(
    const CachedFeatures& cached_features, TokenSpan selected_span) {
  std::vector<float> output_features;
  cached_features.AppendBoundsSensitiveFeaturesForSpan(selected_span,
                                                       &output_features);
  return output_features;
}

// Checks that click-context features are the concatenation of the per-token
// features for a window of context_size=2 tokens on each side of the click.
TEST(CachedFeaturesTest, ClickContext) {
  FeatureProcessorOptionsT options;
  options.context_size = 2;
  options.feature_version = 1;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(CreateFeatureProcessorOptions(builder, &options));
  flatbuffers::DetachedBuffer options_fb = builder.Release();

  std::unique_ptr<std::vector<float>> features = MakeFeatures(9);
  std::unique_ptr<std::vector<float>> padding_features(
      new std::vector<float>{112233.0, -112233.0, 321.0});

  const std::unique_ptr<CachedFeatures> cached_features =
      CachedFeatures::Create(
          {3, 10}, std::move(features), std::move(padding_features),
          flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
          /*feature_vector_size=*/3);
  ASSERT_TRUE(cached_features);

  EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 5),
              ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0, -33.0,
                                0.3, 44.0, -44.0, 0.4, 55.0, -55.0, 0.5}));

  EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 6),
              ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3, 44.0, -44.0,
                                0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6}));

  EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 7),
              ElementsAreFloat({33.0, -33.0, 0.3, 44.0, -44.0, 0.4, 55.0, -55.0,
                                0.5, 66.0, -66.0, 0.6, 77.0, -77.0, 0.7}));
}

// Checks bounds-sensitive features: tokens before / inside-left / inside-right
// / after the span, the averaged inside bag, and the span length, with the
// padding vector used where a window extends past the extracted tokens.
TEST(CachedFeaturesTest, BoundsSensitive) {
  std::unique_ptr<FeatureProcessorOptions_::BoundsSensitiveFeaturesT> config(
      new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
  config->enabled = true;
  config->num_tokens_before = 2;
  config->num_tokens_inside_left = 2;
  config->num_tokens_inside_right = 2;
  config->num_tokens_after = 2;
  config->include_inside_bag = true;
  config->include_inside_length = true;
  FeatureProcessorOptionsT options;
  options.bounds_sensitive_features = std::move(config);
  options.feature_version = 2;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(CreateFeatureProcessorOptions(builder, &options));
  flatbuffers::DetachedBuffer options_fb = builder.Release();

  std::unique_ptr<std::vector<float>> features = MakeFeatures(9);
  std::unique_ptr<std::vector<float>> padding_features(
      new std::vector<float>{112233.0, -112233.0, 321.0});

  const std::unique_ptr<CachedFeatures> cached_features =
      CachedFeatures::Create(
          {3, 9}, std::move(features), std::move(padding_features),
          flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
          /*feature_vector_size=*/3);
  ASSERT_TRUE(cached_features);

  EXPECT_THAT(
      GetCachedBoundsSensitiveFeatures(*cached_features, {5, 8}),
      ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0,
                        -33.0, 0.3, 44.0, -44.0, 0.4, 44.0, -44.0,
                        0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6,
                        112233.0, -112233.0, 321.0, 44.0, -44.0, 0.4, 3.0}));

  EXPECT_THAT(
      GetCachedBoundsSensitiveFeatures(*cached_features, {5, 7}),
      ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0,
                        -33.0, 0.3, 44.0, -44.0, 0.4, 33.0, -33.0,
                        0.3, 44.0, -44.0, 0.4, 55.0, -55.0, 0.5,
                        66.0, -66.0, 0.6, 38.5, -38.5, 0.35, 2.0}));

  EXPECT_THAT(
      GetCachedBoundsSensitiveFeatures(*cached_features, {6, 8}),
      ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3, 44.0,
                        -44.0, 0.4, 55.0, -55.0, 0.5, 44.0, -44.0,
                        0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6,
                        112233.0, -112233.0, 321.0, 49.5, -49.5, 0.45, 2.0}));

  EXPECT_THAT(
      GetCachedBoundsSensitiveFeatures(*cached_features, {6, 7}),
      ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3,
                        44.0, -44.0, 0.4, 112233.0, -112233.0, 321.0,
                        112233.0, -112233.0, 321.0, 44.0, -44.0, 0.4,
                        55.0, -55.0, 0.5, 66.0, -66.0, 0.6,
                        44.0, -44.0, 0.4, 1.0}));
}

}  // namespace
}  // namespace libtextclassifier3
diff --git a/annotator/datetime/extractor.cc b/annotator/datetime/extractor.cc
new file mode 100644
index 0000000..31229dd
--- /dev/null
+++ b/annotator/datetime/extractor.cc
@@ -0,0 +1,469 @@
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "annotator/datetime/extractor.h" + +#include "utils/base/logging.h" + +namespace libtextclassifier3 { + +bool DatetimeExtractor::Extract(DateParseData* result, + CodepointSpan* result_span) const { + result->field_set_mask = 0; + *result_span = {kInvalidIndex, kInvalidIndex}; + + if (rule_.regex->groups() == nullptr) { + return false; + } + + for (int group_id = 0; group_id < rule_.regex->groups()->size(); group_id++) { + UnicodeText group_text; + const int group_type = rule_.regex->groups()->Get(group_id); + if (group_type == DatetimeGroupType_GROUP_UNUSED) { + continue; + } + if (!GroupTextFromMatch(group_id, &group_text)) { + TC3_LOG(ERROR) << "Couldn't retrieve group."; + return false; + } + // The pattern can have a group defined in a part that was not matched, + // e.g. an optional part. In this case we'll get an empty content here. + if (group_text.empty()) { + continue; + } + switch (group_type) { + case DatetimeGroupType_GROUP_YEAR: { + if (!ParseYear(group_text, &(result->year))) { + TC3_LOG(ERROR) << "Couldn't extract YEAR."; + return false; + } + result->field_set_mask |= DateParseData::YEAR_FIELD; + break; + } + case DatetimeGroupType_GROUP_MONTH: { + if (!ParseMonth(group_text, &(result->month))) { + TC3_LOG(ERROR) << "Couldn't extract MONTH."; + return false; + } + result->field_set_mask |= DateParseData::MONTH_FIELD; + break; + } + case DatetimeGroupType_GROUP_DAY: { + if (!ParseDigits(group_text, &(result->day_of_month))) { + TC3_LOG(ERROR) << "Couldn't extract DAY."; + return false; + } + result->field_set_mask |= DateParseData::DAY_FIELD; + break; + } + case DatetimeGroupType_GROUP_HOUR: { + if (!ParseDigits(group_text, &(result->hour))) { + TC3_LOG(ERROR) << "Couldn't extract HOUR."; + return false; + } + result->field_set_mask |= DateParseData::HOUR_FIELD; + break; + } + case DatetimeGroupType_GROUP_MINUTE: { + if 
(!ParseDigits(group_text, &(result->minute))) { + TC3_LOG(ERROR) << "Couldn't extract MINUTE."; + return false; + } + result->field_set_mask |= DateParseData::MINUTE_FIELD; + break; + } + case DatetimeGroupType_GROUP_SECOND: { + if (!ParseDigits(group_text, &(result->second))) { + TC3_LOG(ERROR) << "Couldn't extract SECOND."; + return false; + } + result->field_set_mask |= DateParseData::SECOND_FIELD; + break; + } + case DatetimeGroupType_GROUP_AMPM: { + if (!ParseAMPM(group_text, &(result->ampm))) { + TC3_LOG(ERROR) << "Couldn't extract AMPM."; + return false; + } + result->field_set_mask |= DateParseData::AMPM_FIELD; + break; + } + case DatetimeGroupType_GROUP_RELATIONDISTANCE: { + if (!ParseRelationDistance(group_text, &(result->relation_distance))) { + TC3_LOG(ERROR) << "Couldn't extract RELATION_DISTANCE_FIELD."; + return false; + } + result->field_set_mask |= DateParseData::RELATION_DISTANCE_FIELD; + break; + } + case DatetimeGroupType_GROUP_RELATION: { + if (!ParseRelation(group_text, &(result->relation))) { + TC3_LOG(ERROR) << "Couldn't extract RELATION_FIELD."; + return false; + } + result->field_set_mask |= DateParseData::RELATION_FIELD; + break; + } + case DatetimeGroupType_GROUP_RELATIONTYPE: { + if (!ParseRelationType(group_text, &(result->relation_type))) { + TC3_LOG(ERROR) << "Couldn't extract RELATION_TYPE_FIELD."; + return false; + } + result->field_set_mask |= DateParseData::RELATION_TYPE_FIELD; + break; + } + case DatetimeGroupType_GROUP_DUMMY1: + case DatetimeGroupType_GROUP_DUMMY2: + break; + default: + TC3_LOG(INFO) << "Unknown group type."; + continue; + } + if (!UpdateMatchSpan(group_id, result_span)) { + TC3_LOG(ERROR) << "Couldn't update span."; + return false; + } + } + + if (result_span->first == kInvalidIndex || + result_span->second == kInvalidIndex) { + *result_span = {kInvalidIndex, kInvalidIndex}; + } + + return true; +} + +bool DatetimeExtractor::RuleIdForType(DatetimeExtractorType type, + int* rule_id) const { + auto type_it = 
type_and_locale_to_rule_.find(type); + if (type_it == type_and_locale_to_rule_.end()) { + return false; + } + + auto locale_it = type_it->second.find(locale_id_); + if (locale_it == type_it->second.end()) { + return false; + } + *rule_id = locale_it->second; + return true; +} + +bool DatetimeExtractor::ExtractType(const UnicodeText& input, + DatetimeExtractorType extractor_type, + UnicodeText* match_result) const { + int rule_id; + if (!RuleIdForType(extractor_type, &rule_id)) { + return false; + } + + std::unique_ptr<UniLib::RegexMatcher> matcher = + rules_[rule_id]->Matcher(input); + if (!matcher) { + return false; + } + + int status; + if (!matcher->Find(&status)) { + return false; + } + + if (match_result != nullptr) { + *match_result = matcher->Group(&status); + if (status != UniLib::RegexMatcher::kNoError) { + return false; + } + } + return true; +} + +bool DatetimeExtractor::GroupTextFromMatch(int group_id, + UnicodeText* result) const { + int status; + *result = matcher_.Group(group_id, &status); + if (status != UniLib::RegexMatcher::kNoError) { + return false; + } + return true; +} + +bool DatetimeExtractor::UpdateMatchSpan(int group_id, + CodepointSpan* span) const { + int status; + const int match_start = matcher_.Start(group_id, &status); + if (status != UniLib::RegexMatcher::kNoError) { + return false; + } + const int match_end = matcher_.End(group_id, &status); + if (status != UniLib::RegexMatcher::kNoError) { + return false; + } + if (span->first == kInvalidIndex || span->first > match_start) { + span->first = match_start; + } + if (span->second == kInvalidIndex || span->second < match_end) { + span->second = match_end; + } + + return true; +} + +template <typename T> +bool DatetimeExtractor::MapInput( + const UnicodeText& input, + const std::vector<std::pair<DatetimeExtractorType, T>>& mapping, + T* result) const { + for (const auto& type_value_pair : mapping) { + if (ExtractType(input, type_value_pair.first)) { + *result = type_value_pair.second; + 
return true; + } + } + return false; +} + +bool DatetimeExtractor::ParseWrittenNumber(const UnicodeText& input, + int* parsed_number) const { + std::vector<std::pair<int, int>> found_numbers; + for (const auto& type_value_pair : + std::vector<std::pair<DatetimeExtractorType, int>>{ + {DatetimeExtractorType_ZERO, 0}, + {DatetimeExtractorType_ONE, 1}, + {DatetimeExtractorType_TWO, 2}, + {DatetimeExtractorType_THREE, 3}, + {DatetimeExtractorType_FOUR, 4}, + {DatetimeExtractorType_FIVE, 5}, + {DatetimeExtractorType_SIX, 6}, + {DatetimeExtractorType_SEVEN, 7}, + {DatetimeExtractorType_EIGHT, 8}, + {DatetimeExtractorType_NINE, 9}, + {DatetimeExtractorType_TEN, 10}, + {DatetimeExtractorType_ELEVEN, 11}, + {DatetimeExtractorType_TWELVE, 12}, + {DatetimeExtractorType_THIRTEEN, 13}, + {DatetimeExtractorType_FOURTEEN, 14}, + {DatetimeExtractorType_FIFTEEN, 15}, + {DatetimeExtractorType_SIXTEEN, 16}, + {DatetimeExtractorType_SEVENTEEN, 17}, + {DatetimeExtractorType_EIGHTEEN, 18}, + {DatetimeExtractorType_NINETEEN, 19}, + {DatetimeExtractorType_TWENTY, 20}, + {DatetimeExtractorType_THIRTY, 30}, + {DatetimeExtractorType_FORTY, 40}, + {DatetimeExtractorType_FIFTY, 50}, + {DatetimeExtractorType_SIXTY, 60}, + {DatetimeExtractorType_SEVENTY, 70}, + {DatetimeExtractorType_EIGHTY, 80}, + {DatetimeExtractorType_NINETY, 90}, + {DatetimeExtractorType_HUNDRED, 100}, + {DatetimeExtractorType_THOUSAND, 1000}, + }) { + int rule_id; + if (!RuleIdForType(type_value_pair.first, &rule_id)) { + return false; + } + + std::unique_ptr<UniLib::RegexMatcher> matcher = + rules_[rule_id]->Matcher(input); + if (!matcher) { + return false; + } + + int status; + while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) { + int span_start = matcher->Start(&status); + if (status != UniLib::RegexMatcher::kNoError) { + return false; + } + found_numbers.push_back({span_start, type_value_pair.second}); + } + } + + std::sort(found_numbers.begin(), found_numbers.end(), + [](const std::pair<int, 
int>& a, const std::pair<int, int>& b) { + return a.first < b.first; + }); + + int sum = 0; + int running_value = -1; + // Simple math to make sure we handle written numerical modifiers correctly + // so that :="fifty one thousand and one" maps to 51001 and not 50 1 1000 1. + for (const std::pair<int, int> position_number_pair : found_numbers) { + if (running_value >= 0) { + if (running_value > position_number_pair.second) { + sum += running_value; + running_value = position_number_pair.second; + } else { + running_value *= position_number_pair.second; + } + } else { + running_value = position_number_pair.second; + } + } + sum += running_value; + *parsed_number = sum; + return true; +} + +bool DatetimeExtractor::ParseDigits(const UnicodeText& input, + int* parsed_digits) const { + UnicodeText digit; + if (!ExtractType(input, DatetimeExtractorType_DIGITS, &digit)) { + return false; + } + + if (!unilib_.ParseInt32(digit, parsed_digits)) { + return false; + } + return true; +} + +bool DatetimeExtractor::ParseYear(const UnicodeText& input, + int* parsed_year) const { + if (!ParseDigits(input, parsed_year)) { + return false; + } + + if (*parsed_year < 100) { + if (*parsed_year < 50) { + *parsed_year += 2000; + } else { + *parsed_year += 1900; + } + } + + return true; +} + +bool DatetimeExtractor::ParseMonth(const UnicodeText& input, + int* parsed_month) const { + if (ParseDigits(input, parsed_month)) { + return true; + } + + if (MapInput(input, + { + {DatetimeExtractorType_JANUARY, 1}, + {DatetimeExtractorType_FEBRUARY, 2}, + {DatetimeExtractorType_MARCH, 3}, + {DatetimeExtractorType_APRIL, 4}, + {DatetimeExtractorType_MAY, 5}, + {DatetimeExtractorType_JUNE, 6}, + {DatetimeExtractorType_JULY, 7}, + {DatetimeExtractorType_AUGUST, 8}, + {DatetimeExtractorType_SEPTEMBER, 9}, + {DatetimeExtractorType_OCTOBER, 10}, + {DatetimeExtractorType_NOVEMBER, 11}, + {DatetimeExtractorType_DECEMBER, 12}, + }, + parsed_month)) { + return true; + } + + return false; +} + +bool 
DatetimeExtractor::ParseAMPM(const UnicodeText& input, + int* parsed_ampm) const { + return MapInput(input, + { + {DatetimeExtractorType_AM, DateParseData::AMPM::AM}, + {DatetimeExtractorType_PM, DateParseData::AMPM::PM}, + }, + parsed_ampm); +} + +bool DatetimeExtractor::ParseRelationDistance(const UnicodeText& input, + int* parsed_distance) const { + if (ParseDigits(input, parsed_distance)) { + return true; + } + if (ParseWrittenNumber(input, parsed_distance)) { + return true; + } + return false; +} + +bool DatetimeExtractor::ParseRelation( + const UnicodeText& input, DateParseData::Relation* parsed_relation) const { + return MapInput( + input, + { + {DatetimeExtractorType_NOW, DateParseData::Relation::NOW}, + {DatetimeExtractorType_YESTERDAY, DateParseData::Relation::YESTERDAY}, + {DatetimeExtractorType_TOMORROW, DateParseData::Relation::TOMORROW}, + {DatetimeExtractorType_NEXT, DateParseData::Relation::NEXT}, + {DatetimeExtractorType_NEXT_OR_SAME, + DateParseData::Relation::NEXT_OR_SAME}, + {DatetimeExtractorType_LAST, DateParseData::Relation::LAST}, + {DatetimeExtractorType_PAST, DateParseData::Relation::PAST}, + {DatetimeExtractorType_FUTURE, DateParseData::Relation::FUTURE}, + }, + parsed_relation); +} + +bool DatetimeExtractor::ParseRelationType( + const UnicodeText& input, + DateParseData::RelationType* parsed_relation_type) const { + return MapInput( + input, + { + {DatetimeExtractorType_MONDAY, DateParseData::MONDAY}, + {DatetimeExtractorType_TUESDAY, DateParseData::TUESDAY}, + {DatetimeExtractorType_WEDNESDAY, DateParseData::WEDNESDAY}, + {DatetimeExtractorType_THURSDAY, DateParseData::THURSDAY}, + {DatetimeExtractorType_FRIDAY, DateParseData::FRIDAY}, + {DatetimeExtractorType_SATURDAY, DateParseData::SATURDAY}, + {DatetimeExtractorType_SUNDAY, DateParseData::SUNDAY}, + {DatetimeExtractorType_DAY, DateParseData::DAY}, + {DatetimeExtractorType_WEEK, DateParseData::WEEK}, + {DatetimeExtractorType_MONTH, DateParseData::MONTH}, + 
{DatetimeExtractorType_YEAR, DateParseData::YEAR}, + }, + parsed_relation_type); +} + +bool DatetimeExtractor::ParseTimeUnit(const UnicodeText& input, + int* parsed_time_unit) const { + return MapInput(input, + { + {DatetimeExtractorType_DAYS, DateParseData::DAYS}, + {DatetimeExtractorType_WEEKS, DateParseData::WEEKS}, + {DatetimeExtractorType_MONTHS, DateParseData::MONTHS}, + {DatetimeExtractorType_HOURS, DateParseData::HOURS}, + {DatetimeExtractorType_MINUTES, DateParseData::MINUTES}, + {DatetimeExtractorType_SECONDS, DateParseData::SECONDS}, + {DatetimeExtractorType_YEARS, DateParseData::YEARS}, + }, + parsed_time_unit); +} + +bool DatetimeExtractor::ParseWeekday(const UnicodeText& input, + int* parsed_weekday) const { + return MapInput( + input, + { + {DatetimeExtractorType_MONDAY, DateParseData::MONDAY}, + {DatetimeExtractorType_TUESDAY, DateParseData::TUESDAY}, + {DatetimeExtractorType_WEDNESDAY, DateParseData::WEDNESDAY}, + {DatetimeExtractorType_THURSDAY, DateParseData::THURSDAY}, + {DatetimeExtractorType_FRIDAY, DateParseData::FRIDAY}, + {DatetimeExtractorType_SATURDAY, DateParseData::SATURDAY}, + {DatetimeExtractorType_SUNDAY, DateParseData::SUNDAY}, + }, + parsed_weekday); +} + +} // namespace libtextclassifier3 diff --git a/annotator/datetime/extractor.h b/annotator/datetime/extractor.h new file mode 100644 index 0000000..4c17aa7 --- /dev/null +++ b/annotator/datetime/extractor.h @@ -0,0 +1,111 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_EXTRACTOR_H_
#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_EXTRACTOR_H_

#include <string>
#include <unordered_map>
#include <vector>

#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/strings/stringpiece.h"
#include "utils/utf8/unicodetext.h"
#include "utils/utf8/unilib.h"

namespace libtextclassifier3 {

// A datetime rule bundled with its compiled regex and the model pattern it
// came from.
struct CompiledRule {
  // The compiled regular expression.
  std::unique_ptr<const UniLib::RegexPattern> compiled_regex;

  // The uncompiled pattern and information about the pattern groups.
  const DatetimeModelPattern_::Regex* regex;

  // DatetimeModelPattern which 'regex' is part of and comes from.
  const DatetimeModelPattern* pattern;
};

// A helper class for DatetimeParser that extracts structured data
// (DateParseDate) from the current match of the passed RegexMatcher.
class DatetimeExtractor {
 public:
  // All references passed here must outlive this extractor; it stores them
  // without taking ownership.
  DatetimeExtractor(
      const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
      int locale_id, const UniLib& unilib,
      const std::vector<std::unique_ptr<const UniLib::RegexPattern>>&
          extractor_rules,
      const std::unordered_map<DatetimeExtractorType,
                               std::unordered_map<int, int>>&
          type_and_locale_to_extractor_rule)
      : rule_(rule),
        matcher_(matcher),
        locale_id_(locale_id),
        unilib_(unilib),
        rules_(extractor_rules),
        type_and_locale_to_rule_(type_and_locale_to_extractor_rule) {}

  // Fills 'result' with the date/time fields extracted from the current match
  // and 'result_span' with the codepoint span they cover.
  bool Extract(DateParseData* result, CodepointSpan* result_span) const;

 private:
  // Maps an extractor type to the rule id registered for the current locale.
  bool RuleIdForType(DatetimeExtractorType type, int* rule_id) const;

  // Returns true if the rule for given extractor matched. If it matched,
  // match_result will contain the first group of the rule (if match_result not
  // nullptr).
  bool ExtractType(const UnicodeText& input,
                   DatetimeExtractorType extractor_type,
                   UnicodeText* match_result = nullptr) const;

  // Retrieves the text captured by the given group of the current match.
  bool GroupTextFromMatch(int group_id, UnicodeText* result) const;

  // Updates the span to include the current match for the given group.
  bool UpdateMatchSpan(int group_id, CodepointSpan* span) const;

  // Returns true if any of the extractors from 'mapping' matched. If it did,
  // will fill 'result' with the associated value from 'mapping'.
  template <typename T>
  bool MapInput(const UnicodeText& input,
                const std::vector<std::pair<DatetimeExtractorType, T>>& mapping,
                T* result) const;

  // Field parsers; each returns false if the corresponding value could not be
  // extracted from 'input'.
  bool ParseDigits(const UnicodeText& input, int* parsed_digits) const;
  bool ParseWrittenNumber(const UnicodeText& input, int* parsed_number) const;
  bool ParseYear(const UnicodeText& input, int* parsed_year) const;
  bool ParseMonth(const UnicodeText& input, int* parsed_month) const;
  bool ParseAMPM(const UnicodeText& input, int* parsed_ampm) const;
  bool ParseRelation(const UnicodeText& input,
                     DateParseData::Relation* parsed_relation) const;
  bool ParseRelationDistance(const UnicodeText& input,
                             int* parsed_distance) const;
  bool ParseTimeUnit(const UnicodeText& input, int* parsed_time_unit) const;
  bool ParseRelationType(
      const UnicodeText& input,
      DateParseData::RelationType* parsed_relation_type) const;
  bool ParseWeekday(const UnicodeText& input, int* parsed_weekday) const;

  // Non-owned references; see constructor.
  const CompiledRule& rule_;
  const UniLib::RegexMatcher& matcher_;
  int locale_id_;
  const UniLib& unilib_;
  const std::vector<std::unique_ptr<const UniLib::RegexPattern>>& rules_;
  const std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>>&
      type_and_locale_to_rule_;
};

}  // namespace libtextclassifier3

#endif  // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_EXTRACTOR_H_
diff --git a/annotator/datetime/parser.cc b/annotator/datetime/parser.cc
new file mode 100644
index 0000000..ac3a62d
--- /dev/null
+++
b/annotator/datetime/parser.cc
@@ -0,0 +1,406 @@
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "annotator/datetime/parser.h"

#include <set>
#include <unordered_set>

#include "annotator/datetime/extractor.h"
#include "utils/calendar/calendar.h"
#include "utils/i18n/locale.h"
#include "utils/strings/split.h"

namespace libtextclassifier3 {

// Factory: returns a fully initialized parser, or nullptr if construction
// failed (e.g. a model pattern could not be compiled).
std::unique_ptr<DatetimeParser> DatetimeParser::Instance(
    const DatetimeModel* model, const UniLib& unilib,
    const CalendarLib& calendarlib, ZlibDecompressor* decompressor) {
  std::unique_ptr<DatetimeParser> result(
      new DatetimeParser(model, unilib, calendarlib, decompressor));
  if (!result->initialized_) {
    result.reset();
  }
  return result;
}

// Compiles all rule and extractor regexes from the model and builds the
// locale lookup tables. On any failure it returns early, leaving
// 'initialized_' false so Instance() discards the object.
DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib& unilib,
                               const CalendarLib& calendarlib,
                               ZlibDecompressor* decompressor)
    : unilib_(unilib), calendarlib_(calendarlib) {
  initialized_ = false;

  if (model == nullptr) {
    return;
  }

  if (model->patterns() != nullptr) {
    for (const DatetimeModelPattern* pattern : *model->patterns()) {
      if (pattern->regexes()) {
        for (const DatetimeModelPattern_::Regex* regex : *pattern->regexes()) {
          std::unique_ptr<UniLib::RegexPattern> regex_pattern =
              UncompressMakeRegexPattern(unilib, regex->pattern(),
                                         regex->compressed_pattern(),
                                         decompressor);
          if (!regex_pattern) {
            TC3_LOG(ERROR) << "Couldn't create rule pattern.";
            return;
          }
          rules_.push_back({std::move(regex_pattern), regex, pattern});
          // Index the newly added rule under every locale it applies to.
          if (pattern->locales()) {
            for (int locale : *pattern->locales()) {
              locale_to_rules_[locale].push_back(rules_.size() - 1);
            }
          }
        }
      }
    }
  }

  if (model->extractors() != nullptr) {
    for (const DatetimeModelExtractor* extractor : *model->extractors()) {
      std::unique_ptr<UniLib::RegexPattern> regex_pattern =
          UncompressMakeRegexPattern(unilib, extractor->pattern(),
                                     extractor->compressed_pattern(),
                                     decompressor);
      if (!regex_pattern) {
        TC3_LOG(ERROR) << "Couldn't create extractor pattern";
        return;
      }
      extractor_rules_.push_back(std::move(regex_pattern));

      if (extractor->locales()) {
        for (int locale : *extractor->locales()) {
          type_and_locale_to_extractor_rule_[extractor->extractor()][locale] =
              extractor_rules_.size() - 1;
        }
      }
    }
  }

  // Locale ids are the indices into the model's locale string list.
  if (model->locales() != nullptr) {
    for (int i = 0; i < model->locales()->Length(); ++i) {
      locale_string_to_id_[model->locales()->Get(i)->str()] = i;
    }
  }

  if (model->default_locales() != nullptr) {
    for (const int locale : *model->default_locales()) {
      default_locale_ids_.push_back(locale);
    }
  }

  use_extractors_for_locating_ = model->use_extractors_for_locating();

  initialized_ = true;
}

// UTF-8 string overload; forwards to the UnicodeText overload without copying
// the input bytes.
bool DatetimeParser::Parse(
    const std::string& input, const int64 reference_time_ms_utc,
    const std::string& reference_timezone, const std::string& locales,
    ModeFlag mode, bool anchor_start_end,
    std::vector<DatetimeParseResultSpan>* results) const {
  return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
               reference_time_ms_utc, reference_timezone, locales, mode,
               anchor_start_end, results);
}

// Runs every not-yet-executed rule registered for the given locales over
// 'input' and appends matches to 'found_spans'. 'executed_rules' is shared
// across locales so a rule indexed under several locales runs only once, for
// the first (highest-priority) locale that lists it.
bool DatetimeParser::FindSpansUsingLocales(
    const std::vector<int>& locale_ids, const UnicodeText& input,
    const int64 reference_time_ms_utc, const std::string& reference_timezone,
    ModeFlag mode, bool anchor_start_end, const std::string& reference_locale,
    std::unordered_set<int>* executed_rules,
    std::vector<DatetimeParseResultSpan>* found_spans) const {
  for (const int locale_id : locale_ids) {
    auto rules_it = locale_to_rules_.find(locale_id);
    if (rules_it == locale_to_rules_.end()) {
      continue;
    }

    for (const int rule_id : rules_it->second) {
      // Skip rules that were already executed in previous locales.
      if (executed_rules->find(rule_id) != executed_rules->end()) {
        continue;
      }

      // Skip rules not enabled for the requested mode.
      if (!(rules_[rule_id].pattern->enabled_modes() & mode)) {
        continue;
      }

      executed_rules->insert(rule_id);

      if (!ParseWithRule(rules_[rule_id], input, reference_time_ms_utc,
                         reference_timezone, reference_locale, locale_id,
                         anchor_start_end, found_spans)) {
        return false;
      }
    }
  }
  return true;
}

// Main entry point: collects candidate spans from all applicable rules, then
// resolves overlaps so the returned results do not conflict.
bool DatetimeParser::Parse(
    const UnicodeText& input, const int64 reference_time_ms_utc,
    const std::string& reference_timezone, const std::string& locales,
    ModeFlag mode, bool anchor_start_end,
    std::vector<DatetimeParseResultSpan>* results) const {
  std::vector<DatetimeParseResultSpan> found_spans;
  std::unordered_set<int> executed_rules;
  std::string reference_locale;
  const std::vector<int> requested_locales =
      ParseAndExpandLocales(locales, &reference_locale);
  if (!FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc,
                             reference_timezone, mode, anchor_start_end,
                             reference_locale, &executed_rules, &found_spans)) {
    return false;
  }

  // Remember each span's discovery order so ties can be broken
  // deterministically after sorting.
  std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans;
  int counter = 0;
  for (const auto& found_span : found_spans) {
    indexed_found_spans.push_back({found_span, counter});
    counter++;
  }

  // Resolve conflicts by always picking the longer span and breaking ties by
  // selecting the earlier entry in the list for a given locale.
  std::sort(indexed_found_spans.begin(), indexed_found_spans.end(),
            [](const std::pair<DatetimeParseResultSpan, int>& a,
               const std::pair<DatetimeParseResultSpan, int>& b) {
              if ((a.first.span.second - a.first.span.first) !=
                  (b.first.span.second - b.first.span.first)) {
                return (a.first.span.second - a.first.span.first) >
                       (b.first.span.second - b.first.span.first);
              } else {
                return a.second < b.second;
              }
            });

  found_spans.clear();
  for (auto& span_index_pair : indexed_found_spans) {
    found_spans.push_back(span_index_pair.first);
  }

  // Greedily accept candidates in the sorted (longest-first) order, keeping a
  // set of chosen spans ordered by start position for the conflict check.
  std::set<int, std::function<bool(int, int)>> chosen_indices_set(
      [&found_spans](int a, int b) {
        return found_spans[a].span.first < found_spans[b].span.first;
      });
  for (int i = 0; i < found_spans.size(); ++i) {
    if (!DoesCandidateConflict(i, found_spans, chosen_indices_set)) {
      chosen_indices_set.insert(i);
      results->push_back(found_spans[i]);
    }
  }

  return true;
}

// Converts the current regex match into a DatetimeParseResultSpan and appends
// it to 'result'. Matches whose extracted span stays invalid are dropped
// silently.
bool DatetimeParser::HandleParseMatch(
    const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
    int64 reference_time_ms_utc, const std::string& reference_timezone,
    const std::string& reference_locale, int locale_id,
    std::vector<DatetimeParseResultSpan>* result) const {
  int status = UniLib::RegexMatcher::kNoError;
  const int start = matcher.Start(&status);
  if (status != UniLib::RegexMatcher::kNoError) {
    return false;
  }

  const int end = matcher.End(&status);
  if (status != UniLib::RegexMatcher::kNoError) {
    return false;
  }

  DatetimeParseResultSpan parse_result;
  if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone,
                       reference_locale, locale_id, &(parse_result.data),
                       &parse_result.span)) {
    return false;
  }
  // Unless the model opts into extractor-located spans, report the span of
  // the whole regex match.
  if (!use_extractors_for_locating_) {
    parse_result.span = {start, end};
  }
  if (parse_result.span.first != kInvalidIndex &&
      parse_result.span.second != kInvalidIndex) {
    parse_result.target_classification_score =
        rule.pattern->target_classification_score();
    parse_result.priority_score = rule.pattern->priority_score();
    result->push_back(parse_result);
  }
  return true;
}

// Runs a single rule over 'input'. With 'anchor_start_end' the rule must
// match the whole input; otherwise every non-overlapping occurrence is
// handled.
bool DatetimeParser::ParseWithRule(
    const CompiledRule& rule, const UnicodeText& input,
    const int64 reference_time_ms_utc, const std::string& reference_timezone,
    const std::string& reference_locale, const int locale_id,
    bool anchor_start_end, std::vector<DatetimeParseResultSpan>* result) const {
  std::unique_ptr<UniLib::RegexMatcher> matcher =
      rule.compiled_regex->Matcher(input);
  int status = UniLib::RegexMatcher::kNoError;
  if (anchor_start_end) {
    if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) {
      if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
                            reference_timezone, reference_locale, locale_id,
                            result)) {
        return false;
      }
    }
  } else {
    while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
      if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
                            reference_timezone, reference_locale, locale_id,
                            result)) {
        return false;
      }
    }
  }
  return true;
}

// Expands a comma-separated BCP47 locale list into known model locale ids:
// for each entry it tries the exact match, then "*-<region>",
// "<language>-<script>-*", and "<language>-*", and finally appends the
// model's default locales that are not already present. The first entry
// becomes 'reference_locale'.
std::vector<int> DatetimeParser::ParseAndExpandLocales(
    const std::string& locales, std::string* reference_locale) const {
  std::vector<StringPiece> split_locales = strings::Split(locales, ',');
  if (!split_locales.empty()) {
    *reference_locale = split_locales[0].ToString();
  } else {
    *reference_locale = "";
  }

  std::vector<int> result;
  for (const StringPiece& locale_str : split_locales) {
    auto locale_it = locale_string_to_id_.find(locale_str.ToString());
    if (locale_it != locale_string_to_id_.end()) {
      result.push_back(locale_it->second);
    }

    const Locale locale = Locale::FromBCP47(locale_str.ToString());
    if (!locale.IsValid()) {
      continue;
    }

    const std::string language = locale.Language();
    const std::string script = locale.Script();
    const std::string region = locale.Region();

    // First, try adding *-region locale.
    if (!region.empty()) {
      locale_it = locale_string_to_id_.find("*-" + region);
      if (locale_it != locale_string_to_id_.end()) {
        result.push_back(locale_it->second);
      }
    }
    // Second, try adding language-script-* locale.
    if (!script.empty()) {
      locale_it = locale_string_to_id_.find(language + "-" + script + "-*");
      if (locale_it != locale_string_to_id_.end()) {
        result.push_back(locale_it->second);
      }
    }
    // Third, try adding language-* locale.
    if (!language.empty()) {
      locale_it = locale_string_to_id_.find(language + "-*");
      if (locale_it != locale_string_to_id_.end()) {
        result.push_back(locale_it->second);
      }
    }
  }

  // Add the default locales if they haven't been added already.
  const std::unordered_set<int> result_set(result.begin(), result.end());
  for (const int default_locale_id : default_locale_ids_) {
    if (result_set.find(default_locale_id) == result_set.end()) {
      result.push_back(default_locale_id);
    }
  }

  return result;
}

namespace {

// Derives the finest granularity implied by the fields (or relative fields)
// set in 'data'. Checks run from coarse to fine, so the last matching check
// wins.
DatetimeGranularity GetGranularity(const DateParseData& data) {
  DatetimeGranularity granularity = DatetimeGranularity::GRANULARITY_YEAR;
  if ((data.field_set_mask & DateParseData::YEAR_FIELD) ||
      (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
       (data.relation_type == DateParseData::RelationType::YEAR))) {
    granularity = DatetimeGranularity::GRANULARITY_YEAR;
  }
  if ((data.field_set_mask & DateParseData::MONTH_FIELD) ||
      (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
       (data.relation_type == DateParseData::RelationType::MONTH))) {
    granularity = DatetimeGranularity::GRANULARITY_MONTH;
  }
  if (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
      (data.relation_type == DateParseData::RelationType::WEEK)) {
    granularity = DatetimeGranularity::GRANULARITY_WEEK;
  }
  if (data.field_set_mask & DateParseData::DAY_FIELD ||
      (data.field_set_mask & DateParseData::RELATION_FIELD &&
       (data.relation == DateParseData::Relation::NOW ||
        data.relation == DateParseData::Relation::TOMORROW ||
        data.relation == DateParseData::Relation::YESTERDAY)) ||
      (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
       (data.relation_type == DateParseData::RelationType::MONDAY ||
        data.relation_type == DateParseData::RelationType::TUESDAY ||
        data.relation_type == DateParseData::RelationType::WEDNESDAY ||
        data.relation_type == DateParseData::RelationType::THURSDAY ||
        data.relation_type == DateParseData::RelationType::FRIDAY ||
        data.relation_type == DateParseData::RelationType::SATURDAY ||
        data.relation_type == DateParseData::RelationType::SUNDAY ||
        data.relation_type == DateParseData::RelationType::DAY))) {
    granularity = DatetimeGranularity::GRANULARITY_DAY;
  }
  if (data.field_set_mask & DateParseData::HOUR_FIELD) {
    granularity = DatetimeGranularity::GRANULARITY_HOUR;
  }
  if (data.field_set_mask & DateParseData::MINUTE_FIELD) {
    granularity = DatetimeGranularity::GRANULARITY_MINUTE;
  }
  if (data.field_set_mask & DateParseData::SECOND_FIELD) {
    granularity = DatetimeGranularity::GRANULARITY_SECOND;
  }
  return granularity;
}

}  // namespace

// Runs the DatetimeExtractor on the current match and interprets the
// extracted fields against the reference time to produce an absolute time
// and granularity in 'result'.
bool DatetimeParser::ExtractDatetime(const CompiledRule& rule,
                                     const UniLib::RegexMatcher& matcher,
                                     const int64 reference_time_ms_utc,
                                     const std::string& reference_timezone,
                                     const std::string& reference_locale,
                                     int locale_id, DatetimeParseResult* result,
                                     CodepointSpan* result_span) const {
  DateParseData parse;
  DatetimeExtractor extractor(rule, matcher, locale_id, unilib_,
                              extractor_rules_,
                              type_and_locale_to_extractor_rule_);
  if (!extractor.Extract(&parse, result_span)) {
    return false;
  }

  result->granularity = GetGranularity(parse);

  if (!calendarlib_.InterpretParseData(
          parse, reference_time_ms_utc, reference_timezone, reference_locale,
          result->granularity, &(result->time_ms_utc))) {
    return false;
  }

  return true;
}

}  // namespace libtextclassifier3
diff --git a/annotator/datetime/parser.h
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_PARSER_H_
#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_PARSER_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "annotator/datetime/extractor.h"
#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/base/integral_types.h"
#include "utils/calendar/calendar.h"
#include "utils/utf8/unilib.h"
#include "utils/zlib/zlib.h"

namespace libtextclassifier3 {

// Parses datetime expressions in the input and resolves them to actual absolute
// time.
class DatetimeParser {
 public:
  // Factory method. Returns nullptr-equivalent empty pointer on failure
  // (construction goes through the protected constructor, which sets
  // initialized_). The model, unilib and calendarlib must outlive the parser.
  static std::unique_ptr<DatetimeParser> Instance(
      const DatetimeModel* model, const UniLib& unilib,
      const CalendarLib& calendarlib, ZlibDecompressor* decompressor);

  // Parses the dates in 'input' and fills result. Makes sure that the results
  // do not overlap.
  // If 'anchor_start_end' is true the extracted results need to start at the
  // beginning of 'input' and end at the end of it.
  bool Parse(const std::string& input, int64 reference_time_ms_utc,
             const std::string& reference_timezone, const std::string& locales,
             ModeFlag mode, bool anchor_start_end,
             std::vector<DatetimeParseResultSpan>* results) const;

  // Same as above but takes UnicodeText.
  bool Parse(const UnicodeText& input, int64 reference_time_ms_utc,
             const std::string& reference_timezone, const std::string& locales,
             ModeFlag mode, bool anchor_start_end,
             std::vector<DatetimeParseResultSpan>* results) const;

 protected:
  DatetimeParser(const DatetimeModel* model, const UniLib& unilib,
                 const CalendarLib& calendarlib,
                 ZlibDecompressor* decompressor);

  // Returns a list of locale ids for given locale spec string (comma-separated
  // locale names). Assigns the first parsed locale to reference_locale.
  std::vector<int> ParseAndExpandLocales(const std::string& locales,
                                         std::string* reference_locale) const;

  // Helper function that finds datetime spans, only using the rules associated
  // with the given locales.
  bool FindSpansUsingLocales(
      const std::vector<int>& locale_ids, const UnicodeText& input,
      const int64 reference_time_ms_utc, const std::string& reference_timezone,
      ModeFlag mode, bool anchor_start_end, const std::string& reference_locale,
      std::unordered_set<int>* executed_rules,
      std::vector<DatetimeParseResultSpan>* found_spans) const;

  // Runs a single compiled rule over 'input' and appends any matches to
  // 'result'.
  bool ParseWithRule(const CompiledRule& rule, const UnicodeText& input,
                     int64 reference_time_ms_utc,
                     const std::string& reference_timezone,
                     const std::string& reference_locale, const int locale_id,
                     bool anchor_start_end,
                     std::vector<DatetimeParseResultSpan>* result) const;

  // Converts the current match in 'matcher' into DatetimeParseResult.
  bool ExtractDatetime(const CompiledRule& rule,
                       const UniLib::RegexMatcher& matcher,
                       int64 reference_time_ms_utc,
                       const std::string& reference_timezone,
                       const std::string& reference_locale, int locale_id,
                       DatetimeParseResult* result,
                       CodepointSpan* result_span) const;

  // Parse and extract information from current match in 'matcher'.
  bool HandleParseMatch(const CompiledRule& rule,
                        const UniLib::RegexMatcher& matcher,
                        int64 reference_time_ms_utc,
                        const std::string& reference_timezone,
                        const std::string& reference_locale, int locale_id,
                        std::vector<DatetimeParseResultSpan>* result) const;

 private:
  // True when the constructor successfully loaded all rules from the model.
  bool initialized_;
  const UniLib& unilib_;
  const CalendarLib& calendarlib_;
  // All compiled datetime matching rules from the model.
  std::vector<CompiledRule> rules_;
  // Maps a locale id to the indices (into rules_) that apply for that locale.
  std::unordered_map<int, std::vector<int>> locale_to_rules_;
  // Compiled regexes used by the field extractors.
  std::vector<std::unique_ptr<const UniLib::RegexPattern>> extractor_rules_;
  // Maps (extractor type, locale id) to an index into extractor_rules_.
  std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>>
      type_and_locale_to_extractor_rule_;
  // Maps a locale name (e.g. "en-US", "en-*") to its id.
  std::unordered_map<std::string, int> locale_string_to_id_;
  // Locale ids that are always tried in addition to the requested ones.
  std::vector<int> default_locale_ids_;
  // When true, extractor rules are also used to locate spans (not only to
  // interpret them).
  bool use_extractors_for_locating_;
};

}  // namespace libtextclassifier3

#endif  // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_PARSER_H_

/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <time.h> +#include <fstream> +#include <iostream> +#include <memory> +#include <string> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "annotator/annotator.h" +#include "annotator/datetime/parser.h" +#include "annotator/model_generated.h" +#include "annotator/types-test-util.h" + +using testing::ElementsAreArray; + +namespace libtextclassifier3 { +namespace { + +std::string GetModelPath() { + return TC3_TEST_DATA_DIR; +} + +std::string ReadFile(const std::string& file_name) { + std::ifstream file_stream(file_name); + return std::string(std::istreambuf_iterator<char>(file_stream), {}); +} + +std::string FormatMillis(int64 time_ms_utc) { + long time_seconds = time_ms_utc / 1000; // NOLINT + // Format time, "ddd yyyy-mm-dd hh:mm:ss zzz" + char buffer[512]; + strftime(buffer, sizeof(buffer), "%a %Y-%m-%d %H:%M:%S %Z", + localtime(&time_seconds)); + return std::string(buffer); +} + +class ParserTest : public testing::Test { + public: + void SetUp() override { + model_buffer_ = ReadFile(GetModelPath() + "test_model.fb"); + classifier_ = Annotator::FromUnownedBuffer(model_buffer_.data(), + model_buffer_.size(), &unilib_); + TC3_CHECK(classifier_); + parser_ = classifier_->DatetimeParserForTests(); + } + + bool HasNoResult(const std::string& text, bool anchor_start_end = false, + const std::string& timezone = "Europe/Zurich") { + std::vector<DatetimeParseResultSpan> results; + if (!parser_->Parse(text, 0, timezone, /*locales=*/"", ModeFlag_ANNOTATION, + anchor_start_end, &results)) { + TC3_LOG(ERROR) << text; + TC3_CHECK(false); + } + return results.empty(); + } + + bool ParsesCorrectly(const std::string& marked_text, + const int64 expected_ms_utc, + DatetimeGranularity expected_granularity, + bool anchor_start_end = false, + const std::string& timezone = "Europe/Zurich", + const std::string& locales = "en-US") { + const UnicodeText 
marked_text_unicode = + UTF8ToUnicodeText(marked_text, /*do_copy=*/false); + auto brace_open_it = + std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '{'); + auto brace_end_it = + std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '}'); + TC3_CHECK(brace_open_it != marked_text_unicode.end()); + TC3_CHECK(brace_end_it != marked_text_unicode.end()); + + std::string text; + text += + UnicodeText::UTF8Substring(marked_text_unicode.begin(), brace_open_it); + text += UnicodeText::UTF8Substring(std::next(brace_open_it), brace_end_it); + text += UnicodeText::UTF8Substring(std::next(brace_end_it), + marked_text_unicode.end()); + + std::vector<DatetimeParseResultSpan> results; + + if (!parser_->Parse(text, 0, timezone, locales, ModeFlag_ANNOTATION, + anchor_start_end, &results)) { + TC3_LOG(ERROR) << text; + TC3_CHECK(false); + } + if (results.empty()) { + TC3_LOG(ERROR) << "No results."; + return false; + } + + const int expected_start_index = + std::distance(marked_text_unicode.begin(), brace_open_it); + // The -1 bellow is to account for the opening bracket character. 
+ const int expected_end_index = + std::distance(marked_text_unicode.begin(), brace_end_it) - 1; + + std::vector<DatetimeParseResultSpan> filtered_results; + for (const DatetimeParseResultSpan& result : results) { + if (SpansOverlap(result.span, + {expected_start_index, expected_end_index})) { + filtered_results.push_back(result); + } + } + + const std::vector<DatetimeParseResultSpan> expected{ + {{expected_start_index, expected_end_index}, + {expected_ms_utc, expected_granularity}, + /*target_classification_score=*/1.0, + /*priority_score=*/0.1}}; + const bool matches = + testing::Matches(ElementsAreArray(expected))(filtered_results); + if (!matches) { + TC3_LOG(ERROR) << "Expected: " << expected[0] << " which corresponds to: " + << FormatMillis(expected[0].data.time_ms_utc); + for (int i = 0; i < filtered_results.size(); ++i) { + TC3_LOG(ERROR) << "Actual[" << i << "]: " << filtered_results[i] + << " which corresponds to: " + << FormatMillis(filtered_results[i].data.time_ms_utc); + } + } + return matches; + } + + bool ParsesCorrectlyGerman(const std::string& marked_text, + const int64 expected_ms_utc, + DatetimeGranularity expected_granularity) { + return ParsesCorrectly(marked_text, expected_ms_utc, expected_granularity, + /*anchor_start_end=*/false, + /*timezone=*/"Europe/Zurich", /*locales=*/"de"); + } + + protected: + std::string model_buffer_; + std::unique_ptr<Annotator> classifier_; + const DatetimeParser* parser_; + UniLib unilib_; +}; + +// Test with just a few cases to make debugging of general failures easier. 
// Test with just a few cases to make debugging of general failures easier.
TEST_F(ParserTest, ParseShort) {
  EXPECT_TRUE(
      ParsesCorrectly("{January 1, 1988}", 567990000000, GRANULARITY_DAY));
}

// Golden-value tests for English: each expected timestamp is ms since epoch,
// computed relative to reference time 0 in timezone Europe/Zurich (unless a
// timezone argument says otherwise).
TEST_F(ParserTest, Parse) {
  EXPECT_TRUE(
      ParsesCorrectly("{January 1, 1988}", 567990000000, GRANULARITY_DAY));
  EXPECT_TRUE(
      ParsesCorrectly("{january 31 2018}", 1517353200000, GRANULARITY_DAY));
  EXPECT_TRUE(ParsesCorrectly("lorem {1 january 2018} ipsum", 1514761200000,
                              GRANULARITY_DAY));
  EXPECT_TRUE(ParsesCorrectly("{09/Mar/2004 22:02:40}", 1078866160000,
                              GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectly("{Dec 2, 2010 2:39:58 AM}", 1291253998000,
                              GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectly("{Jun 09 2011 15:28:14}", 1307626094000,
                              GRANULARITY_SECOND));
  EXPECT_TRUE(
      ParsesCorrectly("{Mar 16 08:12:04}", 6419524000, GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectly("{2010-06-26 02:31:29}", 1277512289000,
                              GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectly("{2006/01/22 04:11:05}", 1137899465000,
                              GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectly("{11:42:35}", 38555000, GRANULARITY_SECOND));
  EXPECT_TRUE(
      ParsesCorrectly("{23/Apr 11:42:35}", 9715355000, GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectly("{23/Apr/2015 11:42:35}", 1429782155000,
                              GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectly("{23-Apr-2015 11:42:35}", 1429782155000,
                              GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectly("{23 Apr 2015 11:42:35}", 1429782155000,
                              GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectly("{04/23/15 11:42:35}", 1429782155000,
                              GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectly("{04/23/2015 11:42:35}", 1429782155000,
                              GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectly("{9/28/2011 2:23:15 PM}", 1317212595000,
                              GRANULARITY_SECOND));
  // Long surrounding filler text -- checks that the date is found even when
  // buried deep inside unrelated prose.
  EXPECT_TRUE(ParsesCorrectly(
      "Are sentiments apartments decisively the especially alteration. "
      "Thrown shy denote ten ladies though ask saw. Or by to he going "
      "think order event music. Incommode so intention defective at "
      "convinced. Led income months itself and houses you. After nor "
      "you leave might share court balls. {19/apr/2010 06:36:15} Are "
      "sentiments apartments decisively the especially alteration. "
      "Thrown shy denote ten ladies though ask saw. Or by to he going "
      "think order event music. Incommode so intention defective at "
      "convinced. Led income months itself and houses you. After nor "
      "you leave might share court balls. ",
      1271651775000, GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30}", 1514777400000,
                              GRANULARITY_MINUTE));
  EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30 am}", 1514777400000,
                              GRANULARITY_MINUTE));
  EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4pm}", 1514818800000,
                              GRANULARITY_HOUR));

  // Relative expressions resolved against reference time 0.
  EXPECT_TRUE(ParsesCorrectly("{today at 0:00}", -3600000, GRANULARITY_MINUTE));
  EXPECT_TRUE(ParsesCorrectly("{today at 0:00}", -57600000, GRANULARITY_MINUTE,
                              /*anchor_start_end=*/false,
                              "America/Los_Angeles"));
  EXPECT_TRUE(
      ParsesCorrectly("{tomorrow at 4:00}", 97200000, GRANULARITY_MINUTE));
  EXPECT_TRUE(ParsesCorrectly("{tomorrow at 4am}", 97200000, GRANULARITY_HOUR));
  EXPECT_TRUE(
      ParsesCorrectly("{wednesday at 4am}", 529200000, GRANULARITY_HOUR));
  EXPECT_TRUE(ParsesCorrectly("last seen {today at 9:01 PM}", 72060000,
                              GRANULARITY_MINUTE));
}

// With anchoring enabled, a match must span the whole input.
TEST_F(ParserTest, ParseWithAnchor) {
  EXPECT_TRUE(ParsesCorrectly("{January 1, 1988}", 567990000000,
                              GRANULARITY_DAY, /*anchor_start_end=*/false));
  EXPECT_TRUE(ParsesCorrectly("{January 1, 1988}", 567990000000,
                              GRANULARITY_DAY, /*anchor_start_end=*/true));
  EXPECT_TRUE(ParsesCorrectly("lorem {1 january 2018} ipsum", 1514761200000,
                              GRANULARITY_DAY, /*anchor_start_end=*/false));
  EXPECT_TRUE(HasNoResult("lorem 1 january 2018 ipsum",
                          /*anchor_start_end=*/true));
}

// Golden-value tests for the German ("de") locale rules.
TEST_F(ParserTest, ParseGerman) {
  EXPECT_TRUE(
      ParsesCorrectlyGerman("{Januar 1 2018}", 1514761200000, GRANULARITY_DAY));
  EXPECT_TRUE(
      ParsesCorrectlyGerman("{1 2 2018}", 1517439600000, GRANULARITY_DAY));
  EXPECT_TRUE(ParsesCorrectlyGerman("lorem {1 Januar 2018} ipsum",
                                    1514761200000, GRANULARITY_DAY));
  EXPECT_TRUE(ParsesCorrectlyGerman("{19/Apr/2010:06:36:15}", 1271651775000,
                                    GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectlyGerman("{09/März/2004 22:02:40}", 1078866160000,
                                    GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectlyGerman("{Dez 2, 2010 2:39:58}", 1291253998000,
                                    GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectlyGerman("{Juni 09 2011 15:28:14}", 1307626094000,
                                    GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectlyGerman("{März 16 08:12:04}", 6419524000,
                                    GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectlyGerman("{2010-06-26 02:31:29}", 1277512289000,
                                    GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectlyGerman("{2006/01/22 04:11:05}", 1137899465000,
                                    GRANULARITY_SECOND));
  EXPECT_TRUE(
      ParsesCorrectlyGerman("{11:42:35}", 38555000, GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr 11:42:35}", 9715355000,
                                    GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr/2015:11:42:35}", 1429782155000,
                                    GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr/2015 11:42:35}", 1429782155000,
                                    GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectlyGerman("{23-Apr-2015 11:42:35}", 1429782155000,
                                    GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectlyGerman("{23 Apr 2015 11:42:35}", 1429782155000,
                                    GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/15 11:42:35}", 1429782155000,
                                    GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/2015 11:42:35}", 1429782155000,
                                    GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectlyGerman("{19/apr/2010:06:36:15}", 1271651775000,
                                    GRANULARITY_SECOND));
  EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4:30}", 1514777400000,
                                    GRANULARITY_MINUTE));
  EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4:30 nachm}",
                                    1514820600000, GRANULARITY_MINUTE));
  EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4 nachm}",
                                    1514818800000, GRANULARITY_HOUR));
  EXPECT_TRUE(
      ParsesCorrectlyGerman("{14.03.2017}", 1489446000000, GRANULARITY_DAY));
  EXPECT_TRUE(
      ParsesCorrectlyGerman("{morgen 0:00}", 82800000, GRANULARITY_MINUTE));
  EXPECT_TRUE(
      ParsesCorrectlyGerman("{morgen um 4:00}", 97200000, GRANULARITY_MINUTE));
  EXPECT_TRUE(
      ParsesCorrectlyGerman("{morgen um 4 vorm}", 97200000, GRANULARITY_HOUR));
}

// Non-US locales interpret d/m/y ordering: 1/5/15 is the 1st of May.
TEST_F(ParserTest, ParseNonUs) {
  EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1430431200000, GRANULARITY_DAY,
                              /*anchor_start_end=*/false,
                              /*timezone=*/"Europe/Zurich",
                              /*locales=*/"en-GB"));
  EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1430431200000, GRANULARITY_DAY,
                              /*anchor_start_end=*/false,
                              /*timezone=*/"Europe/Zurich", /*locales=*/"en"));
}

// US locales interpret m/d/y ordering: 1/5/15 is the 5th of January.
TEST_F(ParserTest, ParseUs) {
  EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1420412400000, GRANULARITY_DAY,
                              /*anchor_start_end=*/false,
                              /*timezone=*/"Europe/Zurich",
                              /*locales=*/"en-US"));
  EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1420412400000, GRANULARITY_DAY,
                              /*anchor_start_end=*/false,
                              /*timezone=*/"Europe/Zurich",
                              /*locales=*/"es-US"));
}

// Unknown locale codes should fall back to the default rules.
TEST_F(ParserTest, ParseUnknownLanguage) {
  EXPECT_TRUE(ParsesCorrectly("bylo to {31. 12. 2015} v 6 hodin", 1451516400000,
                              GRANULARITY_DAY,
                              /*anchor_start_end=*/false,
                              /*timezone=*/"Europe/Zurich", /*locales=*/"xx"));
}

// Tests of locale matching/fallback: each rule's regex is the literal name of
// the locale it is registered for, so HasResult() directly reveals which rule
// sets were activated for a given locale spec.
class ParserLocaleTest : public testing::Test {
 public:
  void SetUp() override;
  bool HasResult(const std::string& input, const std::string& locales);

 protected:
  UniLib unilib_;
  CalendarLib calendarlib_;
  flatbuffers::FlatBufferBuilder builder_;
  std::unique_ptr<DatetimeParser> parser_;
};

// Registers a single-regex pattern for the given locale index.
void AddPattern(const std::string& regex, int locale,
                std::vector<std::unique_ptr<DatetimeModelPatternT>>* patterns) {
  patterns->emplace_back(new DatetimeModelPatternT);
  patterns->back()->regexes.emplace_back(new DatetimeModelPattern_::RegexT);
  patterns->back()->regexes.back()->pattern = regex;
  patterns->back()->regexes.back()->groups.push_back(
      DatetimeGroupType_GROUP_UNUSED);
  patterns->back()->locales.push_back(locale);
}

void ParserLocaleTest::SetUp() {
  // Build a minimal in-memory DatetimeModel with one pattern per locale.
  DatetimeModelT model;
  model.use_extractors_for_locating = false;
  model.locales.clear();
  model.locales.push_back("en-US");
  model.locales.push_back("en-CH");
  model.locales.push_back("zh-Hant");
  model.locales.push_back("en-*");
  model.locales.push_back("zh-Hant-*");
  model.locales.push_back("*-CH");
  model.locales.push_back("default");
  // Locale index 6 ("default") is always applied.
  model.default_locales.push_back(6);

  AddPattern(/*regex=*/"en-US", /*locale=*/0, &model.patterns);
  AddPattern(/*regex=*/"en-CH", /*locale=*/1, &model.patterns);
  AddPattern(/*regex=*/"zh-Hant", /*locale=*/2, &model.patterns);
  AddPattern(/*regex=*/"en-all", /*locale=*/3, &model.patterns);
  AddPattern(/*regex=*/"zh-Hant-all", /*locale=*/4, &model.patterns);
  AddPattern(/*regex=*/"all-CH", /*locale=*/5, &model.patterns);
  AddPattern(/*regex=*/"default", /*locale=*/6, &model.patterns);

  builder_.Finish(DatetimeModel::Pack(builder_, &model));
  const DatetimeModel* model_fb =
      flatbuffers::GetRoot<DatetimeModel>(builder_.GetBufferPointer());
  ASSERT_TRUE(model_fb);

  parser_ = DatetimeParser::Instance(model_fb, unilib_, calendarlib_,
                                     /*decompressor=*/nullptr);
  ASSERT_TRUE(parser_);
}

// Returns true iff exactly one span was found for 'input' under 'locales'.
bool ParserLocaleTest::HasResult(const std::string& input,
                                 const std::string& locales) {
  std::vector<DatetimeParseResultSpan> results;
  EXPECT_TRUE(parser_->Parse(input, /*reference_time_ms_utc=*/0,
                             /*reference_timezone=*/"", locales,
                             ModeFlag_ANNOTATION, false, &results));
  return results.size() == 1;
}

TEST_F(ParserLocaleTest, English) {
  EXPECT_TRUE(HasResult("en-US", /*locales=*/"en-US"));
  EXPECT_FALSE(HasResult("en-CH", /*locales=*/"en-US"));
  EXPECT_FALSE(HasResult("en-US", /*locales=*/"en-CH"));
  EXPECT_TRUE(HasResult("en-CH", /*locales=*/"en-CH"));
  EXPECT_TRUE(HasResult("default", /*locales=*/"en-CH"));
}

TEST_F(ParserLocaleTest, TraditionalChinese) {
  EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant"));
  EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant-TW"));
  EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant-SG"));
  EXPECT_FALSE(HasResult("zh-Hant-all", /*locales=*/"zh-SG"));
  EXPECT_FALSE(HasResult("zh-Hant-all", /*locales=*/"zh"));
  EXPECT_TRUE(HasResult("default", /*locales=*/"zh"));
  EXPECT_TRUE(HasResult("default", /*locales=*/"zh-Hant-SG"));
}

TEST_F(ParserLocaleTest, SwissEnglish) {
  EXPECT_TRUE(HasResult("all-CH", /*locales=*/"de-CH"));
  EXPECT_TRUE(HasResult("all-CH", /*locales=*/"en-CH"));
  EXPECT_TRUE(HasResult("en-all", /*locales=*/"en-CH"));
  EXPECT_FALSE(HasResult("all-CH", /*locales=*/"de-DE"));
  EXPECT_TRUE(HasResult("default", /*locales=*/"de-CH"));
  EXPECT_TRUE(HasResult("default", /*locales=*/"en-CH"));
}

}  // namespace
}  // namespace libtextclassifier3

/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version
2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "annotator/feature-processor.h" + +#include <iterator> +#include <set> +#include <vector> + +#include "utils/base/logging.h" +#include "utils/strings/utf8.h" +#include "utils/utf8/unicodetext.h" + +namespace libtextclassifier3 { + +namespace internal { + +TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions( + const FeatureProcessorOptions* const options) { + TokenFeatureExtractorOptions extractor_options; + + extractor_options.num_buckets = options->num_buckets(); + if (options->chargram_orders() != nullptr) { + for (int order : *options->chargram_orders()) { + extractor_options.chargram_orders.push_back(order); + } + } + extractor_options.max_word_length = options->max_word_length(); + extractor_options.extract_case_feature = options->extract_case_feature(); + extractor_options.unicode_aware_features = options->unicode_aware_features(); + extractor_options.extract_selection_mask_feature = + options->extract_selection_mask_feature(); + if (options->regexp_feature() != nullptr) { + for (const auto& regexp_feauture : *options->regexp_feature()) { + extractor_options.regexp_features.push_back(regexp_feauture->str()); + } + } + extractor_options.remap_digits = options->remap_digits(); + extractor_options.lowercase_tokens = options->lowercase_tokens(); + + if (options->allowed_chargrams() != nullptr) { + for (const auto& chargram : *options->allowed_chargrams()) { + extractor_options.allowed_chargrams.insert(chargram->str()); + } + } + 
return extractor_options; +} + +void SplitTokensOnSelectionBoundaries(CodepointSpan selection, + std::vector<Token>* tokens) { + for (auto it = tokens->begin(); it != tokens->end(); ++it) { + const UnicodeText token_word = + UTF8ToUnicodeText(it->value, /*do_copy=*/false); + + auto last_start = token_word.begin(); + int last_start_index = it->start; + std::vector<UnicodeText::const_iterator> split_points; + + // Selection start split point. + if (selection.first > it->start && selection.first < it->end) { + std::advance(last_start, selection.first - last_start_index); + split_points.push_back(last_start); + last_start_index = selection.first; + } + + // Selection end split point. + if (selection.second > it->start && selection.second < it->end) { + std::advance(last_start, selection.second - last_start_index); + split_points.push_back(last_start); + } + + if (!split_points.empty()) { + // Add a final split for the rest of the token unless it's been all + // consumed already. + if (split_points.back() != token_word.end()) { + split_points.push_back(token_word.end()); + } + + std::vector<Token> replacement_tokens; + last_start = token_word.begin(); + int current_pos = it->start; + for (const auto& split_point : split_points) { + Token new_token(token_word.UTF8Substring(last_start, split_point), + current_pos, + current_pos + std::distance(last_start, split_point)); + + last_start = split_point; + current_pos = new_token.end; + + replacement_tokens.push_back(new_token); + } + + it = tokens->erase(it); + it = tokens->insert(it, replacement_tokens.begin(), + replacement_tokens.end()); + std::advance(it, replacement_tokens.size() - 1); + } + } +} + +} // namespace internal + +void FeatureProcessor::StripTokensFromOtherLines( + const std::string& context, CodepointSpan span, + std::vector<Token>* tokens) const { + const UnicodeText context_unicode = UTF8ToUnicodeText(context, + /*do_copy=*/false); + StripTokensFromOtherLines(context_unicode, span, tokens); +} + +void 
FeatureProcessor::StripTokensFromOtherLines( + const UnicodeText& context_unicode, CodepointSpan span, + std::vector<Token>* tokens) const { + std::vector<UnicodeTextRange> lines = SplitContext(context_unicode); + + auto span_start = context_unicode.begin(); + if (span.first > 0) { + std::advance(span_start, span.first); + } + auto span_end = context_unicode.begin(); + if (span.second > 0) { + std::advance(span_end, span.second); + } + for (const UnicodeTextRange& line : lines) { + // Find the line that completely contains the span. + if (line.first <= span_start && line.second >= span_end) { + const CodepointIndex last_line_begin_index = + std::distance(context_unicode.begin(), line.first); + const CodepointIndex last_line_end_index = + last_line_begin_index + std::distance(line.first, line.second); + + for (auto token = tokens->begin(); token != tokens->end();) { + if (token->start >= last_line_begin_index && + token->end <= last_line_end_index) { + ++token; + } else { + token = tokens->erase(token); + } + } + } + } +} + +std::string FeatureProcessor::GetDefaultCollection() const { + if (options_->default_collection() < 0 || + options_->collections() == nullptr || + options_->default_collection() >= options_->collections()->size()) { + TC3_LOG(ERROR) + << "Invalid or missing default collection. 
Returning empty string."; + return ""; + } + return (*options_->collections())[options_->default_collection()]->str(); +} + +std::vector<Token> FeatureProcessor::Tokenize(const std::string& text) const { + const UnicodeText text_unicode = UTF8ToUnicodeText(text, /*do_copy=*/false); + return Tokenize(text_unicode); +} + +std::vector<Token> FeatureProcessor::Tokenize( + const UnicodeText& text_unicode) const { + if (options_->tokenization_type() == + FeatureProcessorOptions_::TokenizationType_INTERNAL_TOKENIZER) { + return tokenizer_.Tokenize(text_unicode); + } else if (options_->tokenization_type() == + FeatureProcessorOptions_::TokenizationType_ICU || + options_->tokenization_type() == + FeatureProcessorOptions_::TokenizationType_MIXED) { + std::vector<Token> result; + if (!ICUTokenize(text_unicode, &result)) { + return {}; + } + if (options_->tokenization_type() == + FeatureProcessorOptions_::TokenizationType_MIXED) { + InternalRetokenize(text_unicode, &result); + } + return result; + } else { + TC3_LOG(ERROR) << "Unknown tokenization type specified. 
Using " + "internal."; + return tokenizer_.Tokenize(text_unicode); + } +} + +bool FeatureProcessor::LabelToSpan( + const int label, const VectorSpan<Token>& tokens, + std::pair<CodepointIndex, CodepointIndex>* span) const { + if (tokens.size() != GetNumContextTokens()) { + return false; + } + + TokenSpan token_span; + if (!LabelToTokenSpan(label, &token_span)) { + return false; + } + + const int result_begin_token_index = token_span.first; + const Token& result_begin_token = + tokens[options_->context_size() - result_begin_token_index]; + const int result_begin_codepoint = result_begin_token.start; + const int result_end_token_index = token_span.second; + const Token& result_end_token = + tokens[options_->context_size() + result_end_token_index]; + const int result_end_codepoint = result_end_token.end; + + if (result_begin_codepoint == kInvalidIndex || + result_end_codepoint == kInvalidIndex) { + *span = CodepointSpan({kInvalidIndex, kInvalidIndex}); + } else { + const UnicodeText token_begin_unicode = + UTF8ToUnicodeText(result_begin_token.value, /*do_copy=*/false); + UnicodeText::const_iterator token_begin = token_begin_unicode.begin(); + const UnicodeText token_end_unicode = + UTF8ToUnicodeText(result_end_token.value, /*do_copy=*/false); + UnicodeText::const_iterator token_end = token_end_unicode.end(); + + const int begin_ignored = CountIgnoredSpanBoundaryCodepoints( + token_begin, token_begin_unicode.end(), + /*count_from_beginning=*/true); + const int end_ignored = + CountIgnoredSpanBoundaryCodepoints(token_end_unicode.begin(), token_end, + /*count_from_beginning=*/false); + // In case everything would be stripped, set the span to the original + // beginning and zero length. 
+ if (begin_ignored == (result_end_codepoint - result_begin_codepoint)) { + *span = {result_begin_codepoint, result_begin_codepoint}; + } else { + *span = CodepointSpan({result_begin_codepoint + begin_ignored, + result_end_codepoint - end_ignored}); + } + } + return true; +} + +bool FeatureProcessor::LabelToTokenSpan(const int label, + TokenSpan* token_span) const { + if (label >= 0 && label < label_to_selection_.size()) { + *token_span = label_to_selection_[label]; + return true; + } else { + return false; + } +} + +bool FeatureProcessor::SpanToLabel( + const std::pair<CodepointIndex, CodepointIndex>& span, + const std::vector<Token>& tokens, int* label) const { + if (tokens.size() != GetNumContextTokens()) { + return false; + } + + const int click_position = + options_->context_size(); // Click is always in the middle. + const int padding = options_->context_size() - options_->max_selection_span(); + + int span_left = 0; + for (int i = click_position - 1; i >= padding; i--) { + if (tokens[i].start != kInvalidIndex && tokens[i].end > span.first) { + ++span_left; + } else { + break; + } + } + + int span_right = 0; + for (int i = click_position + 1; i < tokens.size() - padding; ++i) { + if (tokens[i].end != kInvalidIndex && tokens[i].start < span.second) { + ++span_right; + } else { + break; + } + } + + // Check that the spanned tokens cover the whole span. 
+ bool tokens_match_span; + const CodepointIndex tokens_start = tokens[click_position - span_left].start; + const CodepointIndex tokens_end = tokens[click_position + span_right].end; + if (options_->snap_label_span_boundaries_to_containing_tokens()) { + tokens_match_span = tokens_start <= span.first && tokens_end >= span.second; + } else { + const UnicodeText token_left_unicode = UTF8ToUnicodeText( + tokens[click_position - span_left].value, /*do_copy=*/false); + const UnicodeText token_right_unicode = UTF8ToUnicodeText( + tokens[click_position + span_right].value, /*do_copy=*/false); + + UnicodeText::const_iterator span_begin = token_left_unicode.begin(); + UnicodeText::const_iterator span_end = token_right_unicode.end(); + + const int num_punctuation_start = CountIgnoredSpanBoundaryCodepoints( + span_begin, token_left_unicode.end(), /*count_from_beginning=*/true); + const int num_punctuation_end = CountIgnoredSpanBoundaryCodepoints( + token_right_unicode.begin(), span_end, + /*count_from_beginning=*/false); + + tokens_match_span = tokens_start <= span.first && + tokens_start + num_punctuation_start >= span.first && + tokens_end >= span.second && + tokens_end - num_punctuation_end <= span.second; + } + + if (tokens_match_span) { + *label = TokenSpanToLabel({span_left, span_right}); + } else { + *label = kInvalidLabel; + } + + return true; +} + +int FeatureProcessor::TokenSpanToLabel(const TokenSpan& span) const { + auto it = selection_to_label_.find(span); + if (it != selection_to_label_.end()) { + return it->second; + } else { + return kInvalidLabel; + } +} + +TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens, + CodepointSpan codepoint_span, + bool snap_boundaries_to_containing_tokens) { + const int codepoint_start = std::get<0>(codepoint_span); + const int codepoint_end = std::get<1>(codepoint_span); + + TokenIndex start_token = kInvalidIndex; + TokenIndex end_token = kInvalidIndex; + for (int i = 0; i < selectable_tokens.size(); 
       ++i) {
    bool is_token_in_span;
    if (snap_boundaries_to_containing_tokens) {
      // Overlap is enough for membership.
      is_token_in_span = codepoint_start < selectable_tokens[i].end &&
                         codepoint_end > selectable_tokens[i].start;
    } else {
      // Token must be fully contained in the span.
      is_token_in_span = codepoint_start <= selectable_tokens[i].start &&
                         codepoint_end >= selectable_tokens[i].end;
    }
    if (is_token_in_span && !selectable_tokens[i].is_padding) {
      if (start_token == kInvalidIndex) {
        start_token = i;
      }
      end_token = i + 1;
    }
  }
  return {start_token, end_token};
}

// Converts a token span back into the codepoint span that its first and last
// tokens cover. Assumes a non-empty token span with valid indices.
CodepointSpan TokenSpanToCodepointSpan(
    const std::vector<Token>& selectable_tokens, TokenSpan token_span) {
  return {selectable_tokens[token_span.first].start,
          selectable_tokens[token_span.second - 1].end};
}

namespace {

// Finds a single token that completely contains the given span.
int FindTokenThatContainsSpan(const std::vector<Token>& selectable_tokens,
                              CodepointSpan codepoint_span) {
  const int codepoint_start = std::get<0>(codepoint_span);
  const int codepoint_end = std::get<1>(codepoint_span);

  for (int i = 0; i < selectable_tokens.size(); ++i) {
    if (codepoint_start >= selectable_tokens[i].start &&
        codepoint_end <= selectable_tokens[i].end) {
      return i;
    }
  }
  return kInvalidIndex;
}

}  // namespace

namespace internal {

// Returns the index of the single token that matches the clicked codepoint
// span, or kInvalidIndex when the click does not resolve to exactly one token.
int CenterTokenFromClick(CodepointSpan span,
                         const std::vector<Token>& selectable_tokens) {
  int range_begin;
  int range_end;
  std::tie(range_begin, range_end) =
      CodepointSpanToTokenSpan(selectable_tokens, span);

  // If no exact match was found, try finding a token that completely contains
  // the click span. This is useful e.g. when Android builds the selection
  // using ICU tokenization, and ends up with only a portion of our space-
  // separated token. E.g. for "(857)" Android would select "857".
  if (range_begin == kInvalidIndex || range_end == kInvalidIndex) {
    int token_index = FindTokenThatContainsSpan(selectable_tokens, span);
    if (token_index != kInvalidIndex) {
      range_begin = token_index;
      range_end = token_index + 1;
    }
  }

  // We only allow clicks that are exactly 1 selectable token.
  if (range_end - range_begin == 1) {
    return range_begin;
  } else {
    return kInvalidIndex;
  }
}

// Returns the index of the token in the middle of the token span covered by
// the given codepoint span, or kInvalidIndex when the span maps to no tokens.
int CenterTokenFromMiddleOfSelection(
    CodepointSpan span, const std::vector<Token>& selectable_tokens) {
  int range_begin;
  int range_end;
  std::tie(range_begin, range_end) =
      CodepointSpanToTokenSpan(selectable_tokens, span);

  // Center the clicked token in the selection range.
  if (range_begin != kInvalidIndex && range_end != kInvalidIndex) {
    return (range_begin + range_end - 1) / 2;
  } else {
    return kInvalidIndex;
  }
}

}  // namespace internal

// Dispatches to one of the center-token strategies above, as selected by the
// model options.
int FeatureProcessor::FindCenterToken(CodepointSpan span,
                                      const std::vector<Token>& tokens) const {
  if (options_->center_token_selection_method() ==
      FeatureProcessorOptions_::
          CenterTokenSelectionMethod_CENTER_TOKEN_FROM_CLICK) {
    return internal::CenterTokenFromClick(span, tokens);
  } else if (options_->center_token_selection_method() ==
             FeatureProcessorOptions_::
                 CenterTokenSelectionMethod_CENTER_TOKEN_MIDDLE_OF_SELECTION) {
    return internal::CenterTokenFromMiddleOfSelection(span, tokens);
  } else if (options_->center_token_selection_method() ==
             FeatureProcessorOptions_::
                 CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD) {
    // TODO(zilka): Remove once we have new models on the device.
    // It uses the fact that sharing model use
    // split_tokens_on_selection_boundaries and selection not. So depending on
    // this we select the right way of finding the click location.
    if (!options_->split_tokens_on_selection_boundaries()) {
      // SmartSelection model.
      return internal::CenterTokenFromClick(span, tokens);
    } else {
      // SmartSharing model.
      return internal::CenterTokenFromMiddleOfSelection(span, tokens);
    }
  } else {
    TC3_LOG(ERROR) << "Invalid center token selection method.";
    return kInvalidIndex;
  }
}

// Computes, for every selection label, the codepoint span it would produce
// over the given output tokens. Fails if any label cannot be converted.
bool FeatureProcessor::SelectionLabelSpans(
    const VectorSpan<Token> tokens,
    std::vector<CodepointSpan>* selection_label_spans) const {
  for (int i = 0; i < label_to_selection_.size(); ++i) {
    CodepointSpan span;
    if (!LabelToSpan(i, tokens, &span)) {
      TC3_LOG(ERROR) << "Could not convert label to span: " << i;
      return false;
    }
    selection_label_spans->push_back(span);
  }
  return true;
}

// Copies the flatbuffer codepoint ranges into CodepointRange structs and
// sorts them by start so IsCodepointInRanges can binary-search them.
void FeatureProcessor::PrepareCodepointRanges(
    const std::vector<const FeatureProcessorOptions_::CodepointRange*>&
        codepoint_ranges,
    std::vector<CodepointRange>* prepared_codepoint_ranges) {
  prepared_codepoint_ranges->clear();
  prepared_codepoint_ranges->reserve(codepoint_ranges.size());
  for (const FeatureProcessorOptions_::CodepointRange* range :
       codepoint_ranges) {
    prepared_codepoint_ranges->push_back(
        CodepointRange(range->start(), range->end()));
  }

  std::sort(prepared_codepoint_ranges->begin(),
            prepared_codepoint_ranges->end(),
            [](const CodepointRange& a, const CodepointRange& b) {
              return a.start < b.start;
            });
}

// Loads the set of span-boundary codepoints to be ignored, from the options.
void FeatureProcessor::PrepareIgnoredSpanBoundaryCodepoints() {
  if (options_->ignored_span_boundary_codepoints() != nullptr) {
    for (const int codepoint : *options_->ignored_span_boundary_codepoints()) {
      ignored_span_boundary_codepoints_.insert(codepoint);
    }
  }
}

// Counts consecutive ignored codepoints starting from one end of the
// [span_start, span_end) iterator range (which end depends on
// count_from_beginning).
int FeatureProcessor::CountIgnoredSpanBoundaryCodepoints(
    const UnicodeText::const_iterator& span_start,
    const UnicodeText::const_iterator& span_end,
    bool count_from_beginning) const {
  if (span_start == span_end) {
    return 0;
  }

  UnicodeText::const_iterator it;
  UnicodeText::const_iterator it_last;
  if (count_from_beginning) {
    it = span_start;
    it_last = span_end;
    // We can assume that the string is non-zero length because of the check
    // above, thus
    // the decrement is always valid here.
    --it_last;
  } else {
    it = span_end;
    it_last = span_start;
    // We can assume that the string is non-zero length because of the check
    // above, thus the decrement is always valid here.
    --it;
  }

  // Move until we encounter a non-ignored character.
  int num_ignored = 0;
  while (ignored_span_boundary_codepoints_.find(*it) !=
         ignored_span_boundary_codepoints_.end()) {
    ++num_ignored;

    if (it == it_last) {
      break;
    }

    if (count_from_beginning) {
      ++it;
    } else {
      --it;
    }
  }

  return num_ignored;
}

namespace {

// Splits text t on the given separator codepoints, appending the non-empty
// pieces between separators (as iterator ranges) to ranges.
void FindSubstrings(const UnicodeText& t, const std::set<char32>& codepoints,
                    std::vector<UnicodeTextRange>* ranges) {
  UnicodeText::const_iterator start = t.begin();
  UnicodeText::const_iterator curr = start;
  UnicodeText::const_iterator end = t.end();
  for (; curr != end; ++curr) {
    if (codepoints.find(*curr) != codepoints.end()) {
      if (start != curr) {
        ranges->push_back(std::make_pair(start, curr));
      }
      start = curr;
      ++start;
    }
  }
  if (start != end) {
    ranges->push_back(std::make_pair(start, end));
  }
}

}  // namespace

// Splits the context into segments on newlines and '|' characters.
std::vector<UnicodeTextRange> FeatureProcessor::SplitContext(
    const UnicodeText& context_unicode) const {
  std::vector<UnicodeTextRange> lines;
  const std::set<char32> codepoints{{'\n', '|'}};
  FindSubstrings(context_unicode, codepoints, &lines);
  return lines;
}

// UTF-8 string convenience overload; delegates to the UnicodeText overload
// without copying the context.
CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
    const std::string& context, CodepointSpan span) const {
  const UnicodeText context_unicode =
      UTF8ToUnicodeText(context, /*do_copy=*/false);
  return StripBoundaryCodepoints(context_unicode, span);
}

// Shrinks the span so that it no longer starts or ends with ignored boundary
// codepoints; collapses to an empty span at span.first if nothing remains.
CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
    const UnicodeText& context_unicode, CodepointSpan span) const {
  if (context_unicode.empty() || !ValidNonEmptySpan(span)) {
    return span;
  }

  UnicodeText::const_iterator span_begin = context_unicode.begin();
  std::advance(span_begin, span.first);
UnicodeText::const_iterator span_end = context_unicode.begin(); + std::advance(span_end, span.second); + + const int start_offset = CountIgnoredSpanBoundaryCodepoints( + span_begin, span_end, /*count_from_beginning=*/true); + const int end_offset = CountIgnoredSpanBoundaryCodepoints( + span_begin, span_end, /*count_from_beginning=*/false); + + if (span.first + start_offset < span.second - end_offset) { + return {span.first + start_offset, span.second - end_offset}; + } else { + return {span.first, span.first}; + } +} + +float FeatureProcessor::SupportedCodepointsRatio( + const TokenSpan& token_span, const std::vector<Token>& tokens) const { + int num_supported = 0; + int num_total = 0; + for (int i = token_span.first; i < token_span.second; ++i) { + const UnicodeText value = + UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false); + for (auto codepoint : value) { + if (IsCodepointInRanges(codepoint, supported_codepoint_ranges_)) { + ++num_supported; + } + ++num_total; + } + } + return static_cast<float>(num_supported) / static_cast<float>(num_total); +} + +bool FeatureProcessor::IsCodepointInRanges( + int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const { + auto it = std::lower_bound(codepoint_ranges.begin(), codepoint_ranges.end(), + codepoint, + [](const CodepointRange& range, int codepoint) { + // This function compares range with the + // codepoint for the purpose of finding the first + // greater or equal range. Because of the use of + // std::lower_bound it needs to return true when + // range < codepoint; the first time it will + // return false the lower bound is found and + // returned. + // + // It might seem weird that the condition is + // range.end <= codepoint here but when codepoint + // == range.end it means it's actually just + // outside of the range, thus the range is less + // than the codepoint. 
                               return range.end <= codepoint;
                             });
  if (it != codepoint_ranges.end() && it->start <= codepoint &&
      it->end > codepoint) {
    return true;
  } else {
    return false;
  }
}

// Maps a collection name to its label id; unknown collections fall back to the
// model's default collection label.
int FeatureProcessor::CollectionToLabel(const std::string& collection) const {
  const auto it = collection_to_label_.find(collection);
  if (it == collection_to_label_.end()) {
    return options_->default_collection();
  } else {
    return it->second;
  }
}

// Maps a label id back to the collection name from the model options;
// out-of-range labels fall back to the default collection name.
std::string FeatureProcessor::LabelToCollection(int label) const {
  if (label >= 0 && label < collection_to_label_.size()) {
    return (*options_->collections())[label]->str();
  } else {
    return GetDefaultCollection();
  }
}

// Builds collection_to_label_ from the options, and enumerates all
// (left, right) selection spans up to max_selection_span into the
// selection_to_label_ / label_to_selection_ tables. With
// selection_reduced_output_space only spans whose total length fits in
// max_selection_span are emitted.
void FeatureProcessor::MakeLabelMaps() {
  if (options_->collections() != nullptr) {
    for (int i = 0; i < options_->collections()->size(); ++i) {
      collection_to_label_[(*options_->collections())[i]->str()] = i;
    }
  }

  int selection_label_id = 0;
  for (int l = 0; l < (options_->max_selection_span() + 1); ++l) {
    for (int r = 0; r < (options_->max_selection_span() + 1); ++r) {
      if (!options_->selection_reduced_output_space() ||
          r + l <= options_->max_selection_span()) {
        TokenSpan token_span{l, r};
        selection_to_label_[token_span] = selection_label_id;
        label_to_selection_.push_back(token_span);
        ++selection_label_id;
      }
    }
  }
}

// UTF-8 string convenience overload; delegates to the UnicodeText overload
// without copying the context.
void FeatureProcessor::RetokenizeAndFindClick(const std::string& context,
                                              CodepointSpan input_span,
                                              bool only_use_line_with_click,
                                              std::vector<Token>* tokens,
                                              int* click_pos) const {
  const UnicodeText context_unicode =
      UTF8ToUnicodeText(context, /*do_copy=*/false);
  RetokenizeAndFindClick(context_unicode, input_span, only_use_line_with_click,
                         tokens, click_pos);
}

// Adjusts the given tokens according to the options (splitting on selection
// boundaries, stripping other lines) and locates the click token.
void FeatureProcessor::RetokenizeAndFindClick(
    const UnicodeText& context_unicode, CodepointSpan input_span,
    bool only_use_line_with_click, std::vector<Token>* tokens,
    int* click_pos) const {
  TC3_CHECK(tokens != nullptr);

  if
  (options_->split_tokens_on_selection_boundaries()) {
    internal::SplitTokensOnSelectionBoundaries(input_span, tokens);
  }

  if (only_use_line_with_click) {
    StripTokensFromOtherLines(context_unicode, input_span, tokens);
  }

  // click_pos is optional for the caller; point it at a local if absent.
  int local_click_pos;
  if (click_pos == nullptr) {
    click_pos = &local_click_pos;
  }
  *click_pos = FindCenterToken(input_span, *tokens);
  if (*click_pos == kInvalidIndex) {
    // If the default click method failed, let's try to do sub-token matching
    // before we fail.
    *click_pos = internal::CenterTokenFromClick(input_span, *tokens);
  }
}

namespace internal {

// Trims or pads the token list so it contains exactly the context needed for
// inferences with clicks within relative_click_span around *click_pos,
// updating *click_pos to stay on the same token.
void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
                      std::vector<Token>* tokens, int* click_pos) {
  // Right side: pad with default (padding) tokens, or strip the excess.
  int right_context_needed = relative_click_span.second + context_size;
  if (*click_pos + right_context_needed + 1 >= tokens->size()) {
    // Pad max the context size.
    const int num_pad_tokens = std::min(
        context_size, static_cast<int>(*click_pos + right_context_needed + 1 -
                                       tokens->size()));
    std::vector<Token> pad_tokens(num_pad_tokens);
    tokens->insert(tokens->end(), pad_tokens.begin(), pad_tokens.end());
  } else if (*click_pos + right_context_needed + 1 < tokens->size() - 1) {
    // Strip unused tokens.
    auto it = tokens->begin();
    std::advance(it, *click_pos + right_context_needed + 1);
    tokens->erase(it, tokens->end());
  }

  // Left side: same treatment; shifting the front moves *click_pos.
  int left_context_needed = relative_click_span.first + context_size;
  if (*click_pos < left_context_needed) {
    // Pad max the context size.
    const int num_pad_tokens =
        std::min(context_size, left_context_needed - *click_pos);
    std::vector<Token> pad_tokens(num_pad_tokens);
    tokens->insert(tokens->begin(), pad_tokens.begin(), pad_tokens.end());
    *click_pos += num_pad_tokens;
  } else if (*click_pos > left_context_needed) {
    // Strip unused tokens.
+ auto it = tokens->begin(); + std::advance(it, *click_pos - left_context_needed); + *click_pos -= it - tokens->begin(); + tokens->erase(tokens->begin(), it); + } +} + +} // namespace internal + +bool FeatureProcessor::HasEnoughSupportedCodepoints( + const std::vector<Token>& tokens, TokenSpan token_span) const { + if (options_->min_supported_codepoint_ratio() > 0) { + const float supported_codepoint_ratio = + SupportedCodepointsRatio(token_span, tokens); + if (supported_codepoint_ratio < options_->min_supported_codepoint_ratio()) { + TC3_VLOG(1) << "Not enough supported codepoints in the context: " + << supported_codepoint_ratio; + return false; + } + } + return true; +} + +bool FeatureProcessor::ExtractFeatures( + const std::vector<Token>& tokens, TokenSpan token_span, + CodepointSpan selection_span_for_feature, + const EmbeddingExecutor* embedding_executor, + EmbeddingCache* embedding_cache, int feature_vector_size, + std::unique_ptr<CachedFeatures>* cached_features) const { + std::unique_ptr<std::vector<float>> features(new std::vector<float>()); + features->reserve(feature_vector_size * TokenSpanSize(token_span)); + for (int i = token_span.first; i < token_span.second; ++i) { + if (!AppendTokenFeaturesWithCache(tokens[i], selection_span_for_feature, + embedding_executor, embedding_cache, + features.get())) { + TC3_LOG(ERROR) << "Could not get token features."; + return false; + } + } + + std::unique_ptr<std::vector<float>> padding_features( + new std::vector<float>()); + padding_features->reserve(feature_vector_size); + if (!AppendTokenFeaturesWithCache(Token(), selection_span_for_feature, + embedding_executor, embedding_cache, + padding_features.get())) { + TC3_LOG(ERROR) << "Count not get padding token features."; + return false; + } + + *cached_features = CachedFeatures::Create(token_span, std::move(features), + std::move(padding_features), + options_, feature_vector_size); + if (!*cached_features) { + TC3_LOG(ERROR) << "Cound not create cached features."; 
    return false;
  }

  return true;
}

// Tokenizes the text with the UniLib/ICU break iterator, tracking codepoint
// offsets for each token. Whitespace-only tokens are dropped unless
// icu_preserve_whitespace_tokens is set. Returns false if no break iterator
// could be created.
bool FeatureProcessor::ICUTokenize(const UnicodeText& context_unicode,
                                   std::vector<Token>* result) const {
  std::unique_ptr<UniLib::BreakIterator> break_iterator =
      unilib_->CreateBreakIterator(context_unicode);
  if (!break_iterator) {
    return false;
  }
  int last_break_index = 0;
  int break_index = 0;
  int last_unicode_index = 0;
  int unicode_index = 0;
  auto token_begin_it = context_unicode.begin();
  while ((break_index = break_iterator->Next()) !=
         UniLib::BreakIterator::kDone) {
    const int token_length = break_index - last_break_index;
    unicode_index = last_unicode_index + token_length;

    auto token_end_it = token_begin_it;
    std::advance(token_end_it, token_length);

    // Determine if the whole token is whitespace.
    bool is_whitespace = true;
    for (auto char_it = token_begin_it; char_it < token_end_it; ++char_it) {
      if (!unilib_->IsWhitespace(*char_it)) {
        is_whitespace = false;
        break;
      }
    }

    const std::string token =
        context_unicode.UTF8Substring(token_begin_it, token_end_it);

    if (!is_whitespace || options_->icu_preserve_whitespace_tokens()) {
      result->push_back(Token(token, last_unicode_index, unicode_index));
    }

    last_break_index = break_index;
    last_unicode_index = unicode_index;
    token_begin_it = token_end_it;
  }

  return true;
}

// Re-tokenizes maximal runs of tokens made entirely of codepoints from
// internal_tokenizer_codepoint_ranges_ with the internal tokenizer; other
// tokens are kept as-is.
void FeatureProcessor::InternalRetokenize(const UnicodeText& unicode_text,
                                          std::vector<Token>* tokens) const {
  std::vector<Token> result;
  // span accumulates the current run of retokenizable tokens;
  // span.first < 0 means no run is open.
  CodepointSpan span(-1, -1);
  for (Token& token : *tokens) {
    const UnicodeText unicode_token_value =
        UTF8ToUnicodeText(token.value, /*do_copy=*/false);
    bool should_retokenize = true;
    for (const int codepoint : unicode_token_value) {
      if (!IsCodepointInRanges(codepoint,
                               internal_tokenizer_codepoint_ranges_)) {
        should_retokenize = false;
        break;
      }
    }

    if (should_retokenize) {
      if (span.first < 0) {
        span.first = token.start;
      }
      span.second = token.end;
    } else {
      // Flush any open run before keeping this token unchanged.
      TokenizeSubstring(unicode_text, span, &result);
      span.first = -1;
      result.emplace_back(std::move(token));
    }
  }
  // Flush a run that extends to the end of the token list.
  TokenizeSubstring(unicode_text, span, &result);

  *tokens = std::move(result);
}

// Tokenizes the codepoint span of unicode_text with the internal tokenizer and
// appends the resulting tokens (with offsets shifted back into the full
// string) to result. A negative span start means "nothing to tokenize".
void FeatureProcessor::TokenizeSubstring(const UnicodeText& unicode_text,
                                         CodepointSpan span,
                                         std::vector<Token>* result) const {
  if (span.first < 0) {
    // There is no span to tokenize.
    return;
  }

  // Extract the substring.
  UnicodeText::const_iterator it_begin = unicode_text.begin();
  for (int i = 0; i < span.first; ++i) {
    ++it_begin;
  }
  UnicodeText::const_iterator it_end = unicode_text.begin();
  for (int i = 0; i < span.second; ++i) {
    ++it_end;
  }
  const std::string text = unicode_text.UTF8Substring(it_begin, it_end);

  // Run the tokenizer and update the token bounds to reflect the offset of the
  // substring.
  std::vector<Token> tokens = tokenizer_.Tokenize(text);
  // Avoids progressive capacity increases in the for loop.
  result->reserve(result->size() + tokens.size());
  for (Token& token : tokens) {
    token.start += span.first;
    token.end += span.first;
    result->emplace_back(std::move(token));
  }
}

// Extracts sparse+dense features for one token and appends them to
// output_features, consulting/updating embedding_cache (keyed by the token's
// codepoint span) to avoid re-embedding the same token twice.
bool FeatureProcessor::AppendTokenFeaturesWithCache(
    const Token& token, CodepointSpan selection_span_for_feature,
    const EmbeddingExecutor* embedding_executor,
    EmbeddingCache* embedding_cache,
    std::vector<float>* output_features) const {
  // Look for the embedded features for the token in the cache, if there is one.
  if (embedding_cache) {
    const auto it = embedding_cache->find({token.start, token.end});
    if (it != embedding_cache->end()) {
      // The embedded features were found in the cache, extract only the dense
      // features.
+ std::vector<float> dense_features; + if (!feature_extractor_.Extract( + token, token.IsContainedInSpan(selection_span_for_feature), + /*sparse_features=*/nullptr, &dense_features)) { + TC3_LOG(ERROR) << "Could not extract token's dense features."; + return false; + } + + // Append both embedded and dense features to the output and return. + output_features->insert(output_features->end(), it->second.begin(), + it->second.end()); + output_features->insert(output_features->end(), dense_features.begin(), + dense_features.end()); + return true; + } + } + + // Extract the sparse and dense features. + std::vector<int> sparse_features; + std::vector<float> dense_features; + if (!feature_extractor_.Extract( + token, token.IsContainedInSpan(selection_span_for_feature), + &sparse_features, &dense_features)) { + TC3_LOG(ERROR) << "Could not extract token's features."; + return false; + } + + // Embed the sparse features, appending them directly to the output. + const int embedding_size = GetOptions()->embedding_size(); + output_features->resize(output_features->size() + embedding_size); + float* output_features_end = + output_features->data() + output_features->size(); + if (!embedding_executor->AddEmbedding( + TensorView<int>(sparse_features.data(), + {static_cast<int>(sparse_features.size())}), + /*dest=*/output_features_end - embedding_size, + /*dest_size=*/embedding_size)) { + TC3_LOG(ERROR) << "Cound not embed token's sparse features."; + return false; + } + + // If there is a cache, the embedded features for the token were not in it, + // so insert them. + if (embedding_cache) { + (*embedding_cache)[{token.start, token.end}] = std::vector<float>( + output_features_end - embedding_size, output_features_end); + } + + // Append the dense features to the output. 
+ output_features->insert(output_features->end(), dense_features.begin(), + dense_features.end()); + return true; +} + +} // namespace libtextclassifier3 diff --git a/annotator/feature-processor.h b/annotator/feature-processor.h new file mode 100644 index 0000000..2d04253 --- /dev/null +++ b/annotator/feature-processor.h @@ -0,0 +1,331 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Feature processing for FFModel (feed-forward SmartSelection model). + +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_ + +#include <map> +#include <memory> +#include <set> +#include <string> +#include <vector> + +#include "annotator/cached-features.h" +#include "annotator/model_generated.h" +#include "annotator/token-feature-extractor.h" +#include "annotator/tokenizer.h" +#include "annotator/types.h" +#include "utils/base/integral_types.h" +#include "utils/base/logging.h" +#include "utils/utf8/unicodetext.h" +#include "utils/utf8/unilib.h" + +namespace libtextclassifier3 { + +constexpr int kInvalidLabel = -1; + +namespace internal { + +TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions( + const FeatureProcessorOptions* options); + +// Splits tokens that contain the selection boundary inside them. +// E.g. 
"foo{bar}@google.com" -> "foo", "bar", "@google.com" +void SplitTokensOnSelectionBoundaries(CodepointSpan selection, + std::vector<Token>* tokens); + +// Returns the index of token that corresponds to the codepoint span. +int CenterTokenFromClick(CodepointSpan span, const std::vector<Token>& tokens); + +// Returns the index of token that corresponds to the middle of the codepoint +// span. +int CenterTokenFromMiddleOfSelection( + CodepointSpan span, const std::vector<Token>& selectable_tokens); + +// Strips the tokens from the tokens vector that are not used for feature +// extraction because they are out of scope, or pads them so that there is +// enough tokens in the required context_size for all inferences with a click +// in relative_click_span. +void StripOrPadTokens(TokenSpan relative_click_span, int context_size, + std::vector<Token>* tokens, int* click_pos); + +} // namespace internal + +// Converts a codepoint span to a token span in the given list of tokens. +// If snap_boundaries_to_containing_tokens is set to true, it is enough for a +// token to overlap with the codepoint range to be considered part of it. +// Otherwise it must be fully included in the range. +TokenSpan CodepointSpanToTokenSpan( + const std::vector<Token>& selectable_tokens, CodepointSpan codepoint_span, + bool snap_boundaries_to_containing_tokens = false); + +// Converts a token span to a codepoint span in the given list of tokens. +CodepointSpan TokenSpanToCodepointSpan( + const std::vector<Token>& selectable_tokens, TokenSpan token_span); + +// Takes care of preparing features for the span prediction model. +class FeatureProcessor { + public: + // A cache mapping codepoint spans to embedded tokens features. An instance + // can be provided to multiple calls to ExtractFeatures() operating on the + // same context (the same codepoint spans corresponding to the same tokens), + // as an optimization. Note that the tokenizations do not have to be + // identical. 
+ typedef std::map<CodepointSpan, std::vector<float>> EmbeddingCache; + + FeatureProcessor(const FeatureProcessorOptions* options, const UniLib* unilib) + : unilib_(unilib), + feature_extractor_(internal::BuildTokenFeatureExtractorOptions(options), + *unilib_), + options_(options), + tokenizer_( + options->tokenization_codepoint_config() != nullptr + ? Tokenizer({options->tokenization_codepoint_config()->begin(), + options->tokenization_codepoint_config()->end()}, + options->tokenize_on_script_change()) + : Tokenizer({}, /*split_on_script_change=*/false)) { + MakeLabelMaps(); + if (options->supported_codepoint_ranges() != nullptr) { + PrepareCodepointRanges({options->supported_codepoint_ranges()->begin(), + options->supported_codepoint_ranges()->end()}, + &supported_codepoint_ranges_); + } + if (options->internal_tokenizer_codepoint_ranges() != nullptr) { + PrepareCodepointRanges( + {options->internal_tokenizer_codepoint_ranges()->begin(), + options->internal_tokenizer_codepoint_ranges()->end()}, + &internal_tokenizer_codepoint_ranges_); + } + PrepareIgnoredSpanBoundaryCodepoints(); + } + + // Tokenizes the input string using the selected tokenization method. + std::vector<Token> Tokenize(const std::string& text) const; + + // Same as above but takes UnicodeText. + std::vector<Token> Tokenize(const UnicodeText& text_unicode) const; + + // Converts a label into a token span. + bool LabelToTokenSpan(int label, TokenSpan* token_span) const; + + // Gets the total number of selection labels. + int GetSelectionLabelCount() const { return label_to_selection_.size(); } + + // Gets the string value for given collection label. + std::string LabelToCollection(int label) const; + + // Gets the total number of collections of the model. + int NumCollections() const { return collection_to_label_.size(); } + + // Gets the name of the default collection. 
+ std::string GetDefaultCollection() const; + + const FeatureProcessorOptions* GetOptions() const { return options_; } + + // Retokenizes the context and input span, and finds the click position. + // Depending on the options, might modify tokens (split them or remove them). + void RetokenizeAndFindClick(const std::string& context, + CodepointSpan input_span, + bool only_use_line_with_click, + std::vector<Token>* tokens, int* click_pos) const; + + // Same as above but takes UnicodeText. + void RetokenizeAndFindClick(const UnicodeText& context_unicode, + CodepointSpan input_span, + bool only_use_line_with_click, + std::vector<Token>* tokens, int* click_pos) const; + + // Returns true if the token span has enough supported codepoints (as defined + // in the model config) or not and model should not run. + bool HasEnoughSupportedCodepoints(const std::vector<Token>& tokens, + TokenSpan token_span) const; + + // Extracts features as a CachedFeatures object that can be used for repeated + // inference over token spans in the given context. + bool ExtractFeatures(const std::vector<Token>& tokens, TokenSpan token_span, + CodepointSpan selection_span_for_feature, + const EmbeddingExecutor* embedding_executor, + EmbeddingCache* embedding_cache, int feature_vector_size, + std::unique_ptr<CachedFeatures>* cached_features) const; + + // Fills selection_label_spans with CodepointSpans that correspond to the + // selection labels. The CodepointSpans are based on the codepoint ranges of + // given tokens. + bool SelectionLabelSpans( + VectorSpan<Token> tokens, + std::vector<CodepointSpan>* selection_label_spans) const; + + int DenseFeaturesCount() const { + return feature_extractor_.DenseFeaturesCount(); + } + + int EmbeddingSize() const { return options_->embedding_size(); } + + // Splits context to several segments. 
+ std::vector<UnicodeTextRange> SplitContext( + const UnicodeText& context_unicode) const; + + // Strips boundary codepoints from the span in context and returns the new + // start and end indices. If the span comprises entirely of boundary + // codepoints, the first index of span is returned for both indices. + CodepointSpan StripBoundaryCodepoints(const std::string& context, + CodepointSpan span) const; + + // Same as above but takes UnicodeText. + CodepointSpan StripBoundaryCodepoints(const UnicodeText& context_unicode, + CodepointSpan span) const; + + protected: + // Represents a codepoint range [start, end). + struct CodepointRange { + int32 start; + int32 end; + + CodepointRange(int32 arg_start, int32 arg_end) + : start(arg_start), end(arg_end) {} + }; + + // Returns the class id corresponding to the given string collection + // identifier. There is a catch-all class id that the function returns for + // unknown collections. + int CollectionToLabel(const std::string& collection) const; + + // Prepares mapping from collection names to labels. + void MakeLabelMaps(); + + // Gets the number of spannable tokens for the model. + // + // Spannable tokens are those tokens of context, which the model predicts + // selection spans over (i.e., there is 1:1 correspondence between the output + // classes of the model and each of the spannable tokens). + int GetNumContextTokens() const { return options_->context_size() * 2 + 1; } + + // Converts a label into a span of codepoint indices corresponding to it + // given output_tokens. + bool LabelToSpan(int label, const VectorSpan<Token>& output_tokens, + CodepointSpan* span) const; + + // Converts a span to the corresponding label given output_tokens. + bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span, + const std::vector<Token>& output_tokens, int* label) const; + + // Converts a token span to the corresponding label. 
  int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;

  void PrepareCodepointRanges(
      const std::vector<const FeatureProcessorOptions_::CodepointRange*>&
          codepoint_ranges,
      std::vector<CodepointRange>* prepared_codepoint_ranges);

  // Returns the ratio of supported codepoints to total number of codepoints in
  // the given token span.
  float SupportedCodepointsRatio(const TokenSpan& token_span,
                                 const std::vector<Token>& tokens) const;

  // Returns true if given codepoint is covered by the given sorted vector of
  // codepoint ranges.
  bool IsCodepointInRanges(
      int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const;

  void PrepareIgnoredSpanBoundaryCodepoints();

  // Counts the number of span boundary codepoints. If count_from_beginning is
  // True, the counting will start at the span_start iterator (inclusive) and at
  // maximum end at span_end (exclusive). If count_from_beginning is False, the
  // counting will start from span_end (exclusive) and end at span_start
  // (inclusive).
  int CountIgnoredSpanBoundaryCodepoints(
      const UnicodeText::const_iterator& span_start,
      const UnicodeText::const_iterator& span_end,
      bool count_from_beginning) const;

  // Finds the center token index in tokens vector, using the method defined
  // in options_.
  int FindCenterToken(CodepointSpan span,
                      const std::vector<Token>& tokens) const;

  // Tokenizes the input text using ICU tokenizer.
  bool ICUTokenize(const UnicodeText& context_unicode,
                   std::vector<Token>* result) const;

  // Takes the result of ICU tokenization and retokenizes stretches of tokens
  // made of a specific subset of characters using the internal tokenizer.
  void InternalRetokenize(const UnicodeText& unicode_text,
                          std::vector<Token>* tokens) const;

  // Tokenizes a substring of the unicode string, appending the resulting tokens
  // to the output vector. The resulting tokens have bounds relative to the full
  // string.
Does nothing if the start of the span is negative.
+  void TokenizeSubstring(const UnicodeText& unicode_text, CodepointSpan span,
+                         std::vector<Token>* result) const;
+
+  // Removes all tokens from tokens that are not on a line (defined by calling
+  // SplitContext on the context) to which span points.
+  void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
+                                 std::vector<Token>* tokens) const;
+
+  // Same as above but takes UnicodeText.
+  void StripTokensFromOtherLines(const UnicodeText& context_unicode,
+                                 CodepointSpan span,
+                                 std::vector<Token>* tokens) const;
+
+  // Extracts the features of a token and appends them to the output vector.
+  // Uses the embedding cache to avoid re-extracting and re-embedding the
+  // sparse features for the same token.
+  bool AppendTokenFeaturesWithCache(const Token& token,
+                                    CodepointSpan selection_span_for_feature,
+                                    const EmbeddingExecutor* embedding_executor,
+                                    EmbeddingCache* embedding_cache,
+                                    std::vector<float>* output_features) const;
+
+ private:
+  const UniLib* unilib_;
+
+ protected:
+  const TokenFeatureExtractor feature_extractor_;
+
+  // Codepoint ranges that define what codepoints are supported by the model.
+  // NOTE: Must be sorted.
+  std::vector<CodepointRange> supported_codepoint_ranges_;
+
+  // Codepoint ranges that define which tokens (consisting of which codepoints)
+  // should be re-tokenized with the internal tokenizer in the mixed
+  // tokenization mode.
+  // NOTE: Must be sorted.
+  std::vector<CodepointRange> internal_tokenizer_codepoint_ranges_;
+
+ private:
+  // Set of codepoints that will be stripped from beginning and end of
+  // predicted spans.
+  std::set<int32> ignored_span_boundary_codepoints_;
+
+  const FeatureProcessorOptions* const options_;
+
+  // Mapping between token selection spans and labels ids.
+  std::map<TokenSpan, int> selection_to_label_;
+  std::vector<TokenSpan> label_to_selection_;
+
+  // Mapping between collections and labels.
+ std::map<std::string, int> collection_to_label_; + + Tokenizer tokenizer_; +}; + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_ diff --git a/annotator/feature-processor_test.cc b/annotator/feature-processor_test.cc new file mode 100644 index 0000000..c9f0e0d --- /dev/null +++ b/annotator/feature-processor_test.cc @@ -0,0 +1,1125 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "annotator/feature-processor.h"
+
+#include "annotator/model-executor.h"
+#include "utils/tensor-view.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::ElementsAreArray;
+using testing::FloatEq;
+using testing::Matcher;
+
+flatbuffers::DetachedBuffer PackFeatureProcessorOptions(
+    const FeatureProcessorOptionsT& options) {
+  flatbuffers::FlatBufferBuilder builder;
+  builder.Finish(CreateFeatureProcessorOptions(builder, &options));
+  return builder.Release();
+}
+
+template <typename T>
+std::vector<T> Subvector(const std::vector<T>& vector, int start, int end) {
+  return std::vector<T>(vector.begin() + start, vector.begin() + end);
+}
+
+Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) {
+  std::vector<Matcher<float>> matchers;
+  for (const float value : values) {
+    matchers.push_back(FloatEq(value));
+  }
+  return ElementsAreArray(matchers);
+}
+
+class TestingFeatureProcessor : public FeatureProcessor {
+ public:
+  using FeatureProcessor::CountIgnoredSpanBoundaryCodepoints;
+  using FeatureProcessor::FeatureProcessor;
+  using FeatureProcessor::ICUTokenize;
+  using FeatureProcessor::IsCodepointInRanges;
+  using FeatureProcessor::SpanToLabel;
+  using FeatureProcessor::StripTokensFromOtherLines;
+  using FeatureProcessor::supported_codepoint_ranges_;
+  using FeatureProcessor::SupportedCodepointsRatio;
+};
+
+// EmbeddingExecutor that always returns features based on the given sparse features.
+class FakeEmbeddingExecutor : public EmbeddingExecutor {
+ public:
+  bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
+                    int dest_size) const override {
+    TC3_CHECK_GE(dest_size, 4);
+    EXPECT_EQ(sparse_features.size(), 1);
+    dest[0] = sparse_features.data()[0];
+    dest[1] = sparse_features.data()[0];
+    dest[2] = -sparse_features.data()[0];
+    dest[3] = -sparse_features.data()[0];
+    return true;
+  }
+
+ private:
+  std::vector<float> storage_;
+};
+
+class 
FeatureProcessorTest : public ::testing::Test { + protected: + FeatureProcessorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {} + UniLib unilib_; +}; + +TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesMiddle) { + std::vector<Token> tokens{Token("Hělló", 0, 5), + Token("fěěbař@google.com", 6, 23), + Token("heře!", 24, 29)}; + + internal::SplitTokensOnSelectionBoundaries({9, 12}, &tokens); + + // clang-format off + EXPECT_THAT(tokens, ElementsAreArray( + {Token("Hělló", 0, 5), + Token("fěě", 6, 9), + Token("bař", 9, 12), + Token("@google.com", 12, 23), + Token("heře!", 24, 29)})); + // clang-format on +} + +TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesBegin) { + std::vector<Token> tokens{Token("Hělló", 0, 5), + Token("fěěbař@google.com", 6, 23), + Token("heře!", 24, 29)}; + + internal::SplitTokensOnSelectionBoundaries({6, 12}, &tokens); + + // clang-format off + EXPECT_THAT(tokens, ElementsAreArray( + {Token("Hělló", 0, 5), + Token("fěěbař", 6, 12), + Token("@google.com", 12, 23), + Token("heře!", 24, 29)})); + // clang-format on +} + +TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesEnd) { + std::vector<Token> tokens{Token("Hělló", 0, 5), + Token("fěěbař@google.com", 6, 23), + Token("heře!", 24, 29)}; + + internal::SplitTokensOnSelectionBoundaries({9, 23}, &tokens); + + // clang-format off + EXPECT_THAT(tokens, ElementsAreArray( + {Token("Hělló", 0, 5), + Token("fěě", 6, 9), + Token("bař@google.com", 9, 23), + Token("heře!", 24, 29)})); + // clang-format on +} + +TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesWhole) { + std::vector<Token> tokens{Token("Hělló", 0, 5), + Token("fěěbař@google.com", 6, 23), + Token("heře!", 24, 29)}; + + internal::SplitTokensOnSelectionBoundaries({6, 23}, &tokens); + + // clang-format off + EXPECT_THAT(tokens, ElementsAreArray( + {Token("Hělló", 0, 5), + Token("fěěbař@google.com", 6, 23), + Token("heře!", 24, 29)})); + // clang-format on +} + +TEST_F(FeatureProcessorTest, 
SplitTokensOnSelectionBoundariesCrossToken) { + std::vector<Token> tokens{Token("Hělló", 0, 5), + Token("fěěbař@google.com", 6, 23), + Token("heře!", 24, 29)}; + + internal::SplitTokensOnSelectionBoundaries({2, 9}, &tokens); + + // clang-format off + EXPECT_THAT(tokens, ElementsAreArray( + {Token("Hě", 0, 2), + Token("lló", 2, 5), + Token("fěě", 6, 9), + Token("bař@google.com", 9, 23), + Token("heře!", 24, 29)})); + // clang-format on +} + +TEST_F(FeatureProcessorTest, KeepLineWithClickFirst) { + FeatureProcessorOptionsT options; + options.only_use_line_with_click = true; + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib_); + + const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině"; + const CodepointSpan span = {0, 5}; + // clang-format off + std::vector<Token> tokens = {Token("Fiřst", 0, 5), + Token("Lině", 6, 10), + Token("Sěcond", 11, 17), + Token("Lině", 18, 22), + Token("Thiřd", 23, 28), + Token("Lině", 29, 33)}; + // clang-format on + + // Keeps the first line. 
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens); + EXPECT_THAT(tokens, + ElementsAreArray({Token("Fiřst", 0, 5), Token("Lině", 6, 10)})); +} + +TEST_F(FeatureProcessorTest, KeepLineWithClickSecond) { + FeatureProcessorOptionsT options; + options.only_use_line_with_click = true; + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib_); + + const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině"; + const CodepointSpan span = {18, 22}; + // clang-format off + std::vector<Token> tokens = {Token("Fiřst", 0, 5), + Token("Lině", 6, 10), + Token("Sěcond", 11, 17), + Token("Lině", 18, 22), + Token("Thiřd", 23, 28), + Token("Lině", 29, 33)}; + // clang-format on + + // Keeps the first line. + feature_processor.StripTokensFromOtherLines(context, span, &tokens); + EXPECT_THAT(tokens, ElementsAreArray( + {Token("Sěcond", 11, 17), Token("Lině", 18, 22)})); +} + +TEST_F(FeatureProcessorTest, KeepLineWithClickThird) { + FeatureProcessorOptionsT options; + options.only_use_line_with_click = true; + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib_); + + const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině"; + const CodepointSpan span = {24, 33}; + // clang-format off + std::vector<Token> tokens = {Token("Fiřst", 0, 5), + Token("Lině", 6, 10), + Token("Sěcond", 11, 17), + Token("Lině", 18, 22), + Token("Thiřd", 23, 28), + Token("Lině", 29, 33)}; + // clang-format on + + // Keeps the first line. 
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens); + EXPECT_THAT(tokens, ElementsAreArray( + {Token("Thiřd", 23, 28), Token("Lině", 29, 33)})); +} + +TEST_F(FeatureProcessorTest, KeepLineWithClickSecondWithPipe) { + FeatureProcessorOptionsT options; + options.only_use_line_with_click = true; + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib_); + + const std::string context = "Fiřst Lině|Sěcond Lině\nThiřd Lině"; + const CodepointSpan span = {18, 22}; + // clang-format off + std::vector<Token> tokens = {Token("Fiřst", 0, 5), + Token("Lině", 6, 10), + Token("Sěcond", 11, 17), + Token("Lině", 18, 22), + Token("Thiřd", 23, 28), + Token("Lině", 29, 33)}; + // clang-format on + + // Keeps the first line. + feature_processor.StripTokensFromOtherLines(context, span, &tokens); + EXPECT_THAT(tokens, ElementsAreArray( + {Token("Sěcond", 11, 17), Token("Lině", 18, 22)})); +} + +TEST_F(FeatureProcessorTest, KeepLineWithCrosslineClick) { + FeatureProcessorOptionsT options; + options.only_use_line_with_click = true; + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib_); + + const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině"; + const CodepointSpan span = {5, 23}; + // clang-format off + std::vector<Token> tokens = {Token("Fiřst", 0, 5), + Token("Lině", 6, 10), + Token("Sěcond", 18, 23), + Token("Lině", 19, 23), + Token("Thiřd", 23, 28), + Token("Lině", 29, 33)}; + // clang-format on + + // Keeps the first line. 
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens); + EXPECT_THAT(tokens, ElementsAreArray( + {Token("Fiřst", 0, 5), Token("Lině", 6, 10), + Token("Sěcond", 18, 23), Token("Lině", 19, 23), + Token("Thiřd", 23, 28), Token("Lině", 29, 33)})); +} + +TEST_F(FeatureProcessorTest, SpanToLabel) { + FeatureProcessorOptionsT options; + options.context_size = 1; + options.max_selection_span = 1; + options.snap_label_span_boundaries_to_containing_tokens = false; + + options.tokenization_codepoint_config.emplace_back( + new TokenizationCodepointRangeT()); + auto& config = options.tokenization_codepoint_config.back(); + config->start = 32; + config->end = 33; + config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR; + + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib_); + std::vector<Token> tokens = feature_processor.Tokenize("one, two, three"); + ASSERT_EQ(3, tokens.size()); + int label; + ASSERT_TRUE(feature_processor.SpanToLabel({5, 8}, tokens, &label)); + EXPECT_EQ(kInvalidLabel, label); + ASSERT_TRUE(feature_processor.SpanToLabel({5, 9}, tokens, &label)); + EXPECT_NE(kInvalidLabel, label); + TokenSpan token_span; + feature_processor.LabelToTokenSpan(label, &token_span); + EXPECT_EQ(0, token_span.first); + EXPECT_EQ(0, token_span.second); + + // Reconfigure with snapping enabled. 
+ options.snap_label_span_boundaries_to_containing_tokens = true; + flatbuffers::DetachedBuffer options2_fb = + PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor2( + flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()), + &unilib_); + int label2; + ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2)); + EXPECT_EQ(label, label2); + ASSERT_TRUE(feature_processor2.SpanToLabel({6, 9}, tokens, &label2)); + EXPECT_EQ(label, label2); + ASSERT_TRUE(feature_processor2.SpanToLabel({5, 9}, tokens, &label2)); + EXPECT_EQ(label, label2); + + // Cross a token boundary. + ASSERT_TRUE(feature_processor2.SpanToLabel({4, 9}, tokens, &label2)); + EXPECT_EQ(kInvalidLabel, label2); + ASSERT_TRUE(feature_processor2.SpanToLabel({5, 10}, tokens, &label2)); + EXPECT_EQ(kInvalidLabel, label2); + + // Multiple tokens. + options.context_size = 2; + options.max_selection_span = 2; + flatbuffers::DetachedBuffer options3_fb = + PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor3( + flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()), + &unilib_); + tokens = feature_processor3.Tokenize("zero, one, two, three, four"); + ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2)); + EXPECT_NE(kInvalidLabel, label2); + feature_processor3.LabelToTokenSpan(label2, &token_span); + EXPECT_EQ(1, token_span.first); + EXPECT_EQ(0, token_span.second); + + int label3; + ASSERT_TRUE(feature_processor3.SpanToLabel({6, 14}, tokens, &label3)); + EXPECT_EQ(label2, label3); + ASSERT_TRUE(feature_processor3.SpanToLabel({6, 13}, tokens, &label3)); + EXPECT_EQ(label2, label3); + ASSERT_TRUE(feature_processor3.SpanToLabel({7, 13}, tokens, &label3)); + EXPECT_EQ(label2, label3); +} + +TEST_F(FeatureProcessorTest, SpanToLabelIgnoresPunctuation) { + FeatureProcessorOptionsT options; + options.context_size = 1; + options.max_selection_span = 1; + options.snap_label_span_boundaries_to_containing_tokens = 
false; + + options.tokenization_codepoint_config.emplace_back( + new TokenizationCodepointRangeT()); + auto& config = options.tokenization_codepoint_config.back(); + config->start = 32; + config->end = 33; + config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR; + + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib_); + std::vector<Token> tokens = feature_processor.Tokenize("one, two, three"); + ASSERT_EQ(3, tokens.size()); + int label; + ASSERT_TRUE(feature_processor.SpanToLabel({5, 8}, tokens, &label)); + EXPECT_EQ(kInvalidLabel, label); + ASSERT_TRUE(feature_processor.SpanToLabel({5, 9}, tokens, &label)); + EXPECT_NE(kInvalidLabel, label); + TokenSpan token_span; + feature_processor.LabelToTokenSpan(label, &token_span); + EXPECT_EQ(0, token_span.first); + EXPECT_EQ(0, token_span.second); + + // Reconfigure with snapping enabled. + options.snap_label_span_boundaries_to_containing_tokens = true; + flatbuffers::DetachedBuffer options2_fb = + PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor2( + flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()), + &unilib_); + int label2; + ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2)); + EXPECT_EQ(label, label2); + ASSERT_TRUE(feature_processor2.SpanToLabel({6, 9}, tokens, &label2)); + EXPECT_EQ(label, label2); + ASSERT_TRUE(feature_processor2.SpanToLabel({5, 9}, tokens, &label2)); + EXPECT_EQ(label, label2); + + // Cross a token boundary. + ASSERT_TRUE(feature_processor2.SpanToLabel({4, 9}, tokens, &label2)); + EXPECT_EQ(kInvalidLabel, label2); + ASSERT_TRUE(feature_processor2.SpanToLabel({5, 10}, tokens, &label2)); + EXPECT_EQ(kInvalidLabel, label2); + + // Multiple tokens. 
+ options.context_size = 2; + options.max_selection_span = 2; + flatbuffers::DetachedBuffer options3_fb = + PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor3( + flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()), + &unilib_); + tokens = feature_processor3.Tokenize("zero, one, two, three, four"); + ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2)); + EXPECT_NE(kInvalidLabel, label2); + feature_processor3.LabelToTokenSpan(label2, &token_span); + EXPECT_EQ(1, token_span.first); + EXPECT_EQ(0, token_span.second); + + int label3; + ASSERT_TRUE(feature_processor3.SpanToLabel({6, 14}, tokens, &label3)); + EXPECT_EQ(label2, label3); + ASSERT_TRUE(feature_processor3.SpanToLabel({6, 13}, tokens, &label3)); + EXPECT_EQ(label2, label3); + ASSERT_TRUE(feature_processor3.SpanToLabel({7, 13}, tokens, &label3)); + EXPECT_EQ(label2, label3); +} + +TEST_F(FeatureProcessorTest, CenterTokenFromClick) { + int token_index; + + // Exactly aligned indices. + token_index = internal::CenterTokenFromClick( + {6, 11}, + {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)}); + EXPECT_EQ(token_index, 1); + + // Click is contained in a token. + token_index = internal::CenterTokenFromClick( + {13, 17}, + {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)}); + EXPECT_EQ(token_index, 2); + + // Click spans two tokens. + token_index = internal::CenterTokenFromClick( + {6, 17}, + {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)}); + EXPECT_EQ(token_index, kInvalidIndex); +} + +TEST_F(FeatureProcessorTest, CenterTokenFromMiddleOfSelection) { + int token_index; + + // Selection of length 3. Exactly aligned indices. + token_index = internal::CenterTokenFromMiddleOfSelection( + {7, 27}, + {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20), + Token("Token4", 21, 27), Token("Token5", 28, 34)}); + EXPECT_EQ(token_index, 2); + + // Selection of length 1 token. 
Exactly aligned indices. + token_index = internal::CenterTokenFromMiddleOfSelection( + {21, 27}, + {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20), + Token("Token4", 21, 27), Token("Token5", 28, 34)}); + EXPECT_EQ(token_index, 3); + + // Selection marks sub-token range, with no tokens in it. + token_index = internal::CenterTokenFromMiddleOfSelection( + {29, 33}, + {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20), + Token("Token4", 21, 27), Token("Token5", 28, 34)}); + EXPECT_EQ(token_index, kInvalidIndex); + + // Selection of length 2. Sub-token indices. + token_index = internal::CenterTokenFromMiddleOfSelection( + {3, 25}, + {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20), + Token("Token4", 21, 27), Token("Token5", 28, 34)}); + EXPECT_EQ(token_index, 1); + + // Selection of length 1. Sub-token indices. + token_index = internal::CenterTokenFromMiddleOfSelection( + {22, 34}, + {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20), + Token("Token4", 21, 27), Token("Token5", 28, 34)}); + EXPECT_EQ(token_index, 4); + + // Some invalid ones. 
+ token_index = internal::CenterTokenFromMiddleOfSelection({7, 27}, {}); + EXPECT_EQ(token_index, -1); +} + +TEST_F(FeatureProcessorTest, SupportedCodepointsRatio) { + FeatureProcessorOptionsT options; + options.context_size = 2; + options.max_selection_span = 2; + options.snap_label_span_boundaries_to_containing_tokens = false; + options.feature_version = 2; + options.embedding_size = 4; + options.bounds_sensitive_features.reset( + new FeatureProcessorOptions_::BoundsSensitiveFeaturesT()); + options.bounds_sensitive_features->enabled = true; + options.bounds_sensitive_features->num_tokens_before = 5; + options.bounds_sensitive_features->num_tokens_inside_left = 3; + options.bounds_sensitive_features->num_tokens_inside_right = 3; + options.bounds_sensitive_features->num_tokens_after = 5; + options.bounds_sensitive_features->include_inside_bag = true; + options.bounds_sensitive_features->include_inside_length = true; + + options.tokenization_codepoint_config.emplace_back( + new TokenizationCodepointRangeT()); + auto& config = options.tokenization_codepoint_config.back(); + config->start = 32; + config->end = 33; + config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR; + + { + options.supported_codepoint_ranges.emplace_back( + new FeatureProcessorOptions_::CodepointRangeT()); + auto& range = options.supported_codepoint_ranges.back(); + range->start = 0; + range->end = 128; + } + + { + options.supported_codepoint_ranges.emplace_back( + new FeatureProcessorOptions_::CodepointRangeT()); + auto& range = options.supported_codepoint_ranges.back(); + range->start = 10000; + range->end = 10001; + } + + { + options.supported_codepoint_ranges.emplace_back( + new FeatureProcessorOptions_::CodepointRangeT()); + auto& range = options.supported_codepoint_ranges.back(); + range->start = 20000; + range->end = 30000; + } + + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor( + 
flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib_); + EXPECT_THAT(feature_processor.SupportedCodepointsRatio( + {0, 3}, feature_processor.Tokenize("aaa bbb ccc")), + FloatEq(1.0)); + EXPECT_THAT(feature_processor.SupportedCodepointsRatio( + {0, 3}, feature_processor.Tokenize("aaa bbb ěěě")), + FloatEq(2.0 / 3)); + EXPECT_THAT(feature_processor.SupportedCodepointsRatio( + {0, 3}, feature_processor.Tokenize("ěěě řřř ěěě")), + FloatEq(0.0)); + EXPECT_FALSE(feature_processor.IsCodepointInRanges( + -1, feature_processor.supported_codepoint_ranges_)); + EXPECT_TRUE(feature_processor.IsCodepointInRanges( + 0, feature_processor.supported_codepoint_ranges_)); + EXPECT_TRUE(feature_processor.IsCodepointInRanges( + 10, feature_processor.supported_codepoint_ranges_)); + EXPECT_TRUE(feature_processor.IsCodepointInRanges( + 127, feature_processor.supported_codepoint_ranges_)); + EXPECT_FALSE(feature_processor.IsCodepointInRanges( + 128, feature_processor.supported_codepoint_ranges_)); + EXPECT_FALSE(feature_processor.IsCodepointInRanges( + 9999, feature_processor.supported_codepoint_ranges_)); + EXPECT_TRUE(feature_processor.IsCodepointInRanges( + 10000, feature_processor.supported_codepoint_ranges_)); + EXPECT_FALSE(feature_processor.IsCodepointInRanges( + 10001, feature_processor.supported_codepoint_ranges_)); + EXPECT_TRUE(feature_processor.IsCodepointInRanges( + 25000, feature_processor.supported_codepoint_ranges_)); + + const std::vector<Token> tokens = {Token("ěěě", 0, 3), Token("řřř", 4, 7), + Token("eee", 8, 11)}; + + options.min_supported_codepoint_ratio = 0.0; + flatbuffers::DetachedBuffer options2_fb = + PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor2( + flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()), + &unilib_); + EXPECT_TRUE(feature_processor2.HasEnoughSupportedCodepoints( + tokens, /*token_span=*/{0, 3})); + + options.min_supported_codepoint_ratio = 0.2; + flatbuffers::DetachedBuffer 
options3_fb = + PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor3( + flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()), + &unilib_); + EXPECT_TRUE(feature_processor3.HasEnoughSupportedCodepoints( + tokens, /*token_span=*/{0, 3})); + + options.min_supported_codepoint_ratio = 0.5; + flatbuffers::DetachedBuffer options4_fb = + PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor4( + flatbuffers::GetRoot<FeatureProcessorOptions>(options4_fb.data()), + &unilib_); + EXPECT_FALSE(feature_processor4.HasEnoughSupportedCodepoints( + tokens, /*token_span=*/{0, 3})); +} + +TEST_F(FeatureProcessorTest, InSpanFeature) { + FeatureProcessorOptionsT options; + options.context_size = 2; + options.max_selection_span = 2; + options.snap_label_span_boundaries_to_containing_tokens = false; + options.feature_version = 2; + options.embedding_size = 4; + options.extract_selection_mask_feature = true; + + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib_); + + std::unique_ptr<CachedFeatures> cached_features; + + FakeEmbeddingExecutor embedding_executor; + + const std::vector<Token> tokens = {Token("aaa", 0, 3), Token("bbb", 4, 7), + Token("ccc", 8, 11), Token("ddd", 12, 15)}; + + EXPECT_TRUE(feature_processor.ExtractFeatures( + tokens, /*token_span=*/{0, 4}, + /*selection_span_for_feature=*/{4, 11}, &embedding_executor, + /*embedding_cache=*/nullptr, /*feature_vector_size=*/5, + &cached_features)); + std::vector<float> features; + cached_features->AppendClickContextFeaturesForClick(1, &features); + ASSERT_EQ(features.size(), 25); + EXPECT_THAT(features[4], FloatEq(0.0)); + EXPECT_THAT(features[9], FloatEq(0.0)); + EXPECT_THAT(features[14], FloatEq(1.0)); + EXPECT_THAT(features[19], FloatEq(1.0)); + EXPECT_THAT(features[24], FloatEq(0.0)); +} + 
+TEST_F(FeatureProcessorTest, EmbeddingCache) { + FeatureProcessorOptionsT options; + options.context_size = 2; + options.max_selection_span = 2; + options.snap_label_span_boundaries_to_containing_tokens = false; + options.feature_version = 2; + options.embedding_size = 4; + options.bounds_sensitive_features.reset( + new FeatureProcessorOptions_::BoundsSensitiveFeaturesT()); + options.bounds_sensitive_features->enabled = true; + options.bounds_sensitive_features->num_tokens_before = 3; + options.bounds_sensitive_features->num_tokens_inside_left = 2; + options.bounds_sensitive_features->num_tokens_inside_right = 2; + options.bounds_sensitive_features->num_tokens_after = 3; + + flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + TestingFeatureProcessor feature_processor( + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib_); + + std::unique_ptr<CachedFeatures> cached_features; + + FakeEmbeddingExecutor embedding_executor; + + const std::vector<Token> tokens = { + Token("aaa", 0, 3), Token("bbb", 4, 7), Token("ccc", 8, 11), + Token("ddd", 12, 15), Token("eee", 16, 19), Token("fff", 20, 23)}; + + // We pre-populate the cache with dummy embeddings, to make sure they are + // used when populating the features vector. 
+ const std::vector<float> cached_padding_features = {10.0, -10.0, 10.0, -10.0}; + const std::vector<float> cached_features1 = {1.0, 2.0, 3.0, 4.0}; + const std::vector<float> cached_features2 = {5.0, 6.0, 7.0, 8.0}; + FeatureProcessor::EmbeddingCache embedding_cache = { + {{kInvalidIndex, kInvalidIndex}, cached_padding_features}, + {{4, 7}, cached_features1}, + {{12, 15}, cached_features2}, + }; + + EXPECT_TRUE(feature_processor.ExtractFeatures( + tokens, /*token_span=*/{0, 6}, + /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex}, + &embedding_executor, &embedding_cache, /*feature_vector_size=*/4, + &cached_features)); + std::vector<float> features; + cached_features->AppendBoundsSensitiveFeaturesForSpan({2, 4}, &features); + ASSERT_EQ(features.size(), 40); + // Check that the dummy embeddings were used. + EXPECT_THAT(Subvector(features, 0, 4), + ElementsAreFloat(cached_padding_features)); + EXPECT_THAT(Subvector(features, 8, 12), ElementsAreFloat(cached_features1)); + EXPECT_THAT(Subvector(features, 16, 20), ElementsAreFloat(cached_features2)); + EXPECT_THAT(Subvector(features, 24, 28), ElementsAreFloat(cached_features2)); + EXPECT_THAT(Subvector(features, 36, 40), + ElementsAreFloat(cached_padding_features)); + // Check that the real embeddings were cached. 
+ EXPECT_EQ(embedding_cache.size(), 7); + EXPECT_THAT(Subvector(features, 4, 8), + ElementsAreFloat(embedding_cache.at({0, 3}))); + EXPECT_THAT(Subvector(features, 12, 16), + ElementsAreFloat(embedding_cache.at({8, 11}))); + EXPECT_THAT(Subvector(features, 20, 24), + ElementsAreFloat(embedding_cache.at({8, 11}))); + EXPECT_THAT(Subvector(features, 28, 32), + ElementsAreFloat(embedding_cache.at({16, 19}))); + EXPECT_THAT(Subvector(features, 32, 36), + ElementsAreFloat(embedding_cache.at({20, 23}))); +} + +TEST_F(FeatureProcessorTest, StripUnusedTokensWithNoRelativeClick) { + std::vector<Token> tokens_orig{ + Token("0", 0, 0), Token("1", 0, 0), Token("2", 0, 0), Token("3", 0, 0), + Token("4", 0, 0), Token("5", 0, 0), Token("6", 0, 0), Token("7", 0, 0), + Token("8", 0, 0), Token("9", 0, 0), Token("10", 0, 0), Token("11", 0, 0), + Token("12", 0, 0)}; + + std::vector<Token> tokens; + int click_index; + + // Try to click first token and see if it gets padded from left. + tokens = tokens_orig; + click_index = 0; + internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index); + // clang-format off + EXPECT_EQ(tokens, std::vector<Token>({Token(), + Token(), + Token("0", 0, 0), + Token("1", 0, 0), + Token("2", 0, 0)})); + // clang-format on + EXPECT_EQ(click_index, 2); + + // When we click the second token nothing should get padded. + tokens = tokens_orig; + click_index = 2; + internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index); + // clang-format off + EXPECT_EQ(tokens, std::vector<Token>({Token("0", 0, 0), + Token("1", 0, 0), + Token("2", 0, 0), + Token("3", 0, 0), + Token("4", 0, 0)})); + // clang-format on + EXPECT_EQ(click_index, 2); + + // When we click the last token tokens should get padded from the right. 
+ tokens = tokens_orig; + click_index = 12; + internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index); + // clang-format off + EXPECT_EQ(tokens, std::vector<Token>({Token("10", 0, 0), + Token("11", 0, 0), + Token("12", 0, 0), + Token(), + Token()})); + // clang-format on + EXPECT_EQ(click_index, 2); +} + +TEST_F(FeatureProcessorTest, StripUnusedTokensWithRelativeClick) { + std::vector<Token> tokens_orig{ + Token("0", 0, 0), Token("1", 0, 0), Token("2", 0, 0), Token("3", 0, 0), + Token("4", 0, 0), Token("5", 0, 0), Token("6", 0, 0), Token("7", 0, 0), + Token("8", 0, 0), Token("9", 0, 0), Token("10", 0, 0), Token("11", 0, 0), + Token("12", 0, 0)}; + + std::vector<Token> tokens; + int click_index; + + // Try to click first token and see if it gets padded from left to maximum + // context_size. + tokens = tokens_orig; + click_index = 0; + internal::StripOrPadTokens({2, 3}, 2, &tokens, &click_index); + // clang-format off + EXPECT_EQ(tokens, std::vector<Token>({Token(), + Token(), + Token("0", 0, 0), + Token("1", 0, 0), + Token("2", 0, 0), + Token("3", 0, 0), + Token("4", 0, 0), + Token("5", 0, 0)})); + // clang-format on + EXPECT_EQ(click_index, 2); + + // Clicking to the middle with enough context should not produce any padding. + tokens = tokens_orig; + click_index = 6; + internal::StripOrPadTokens({3, 1}, 2, &tokens, &click_index); + // clang-format off + EXPECT_EQ(tokens, std::vector<Token>({Token("1", 0, 0), + Token("2", 0, 0), + Token("3", 0, 0), + Token("4", 0, 0), + Token("5", 0, 0), + Token("6", 0, 0), + Token("7", 0, 0), + Token("8", 0, 0), + Token("9", 0, 0)})); + // clang-format on + EXPECT_EQ(click_index, 5); + + // Clicking at the end should pad right to maximum context_size. 
  tokens = tokens_orig;
  click_index = 11;
  internal::StripOrPadTokens({3, 1}, 2, &tokens, &click_index);
  // clang-format off
  EXPECT_EQ(tokens, std::vector<Token>({Token("6", 0, 0),
                                        Token("7", 0, 0),
                                        Token("8", 0, 0),
                                        Token("9", 0, 0),
                                        Token("10", 0, 0),
                                        Token("11", 0, 0),
                                        Token("12", 0, 0),
                                        Token(),
                                        Token()}));
  // clang-format on
  EXPECT_EQ(click_index, 5);
}

// The internal tokenizer should split a run of codepoints at a script change
// only when tokenize_on_script_change is set. Only the [0, 256) range is
// given an explicit script id here, so the Hangul/digit boundaries are the
// script changes exercised.
TEST_F(FeatureProcessorTest, InternalTokenizeOnScriptChange) {
  FeatureProcessorOptionsT options;
  options.tokenization_codepoint_config.emplace_back(
      new TokenizationCodepointRangeT());
  {
    auto& config = options.tokenization_codepoint_config.back();
    config->start = 0;
    config->end = 256;
    config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
    config->script_id = 1;
  }
  options.tokenize_on_script_change = false;

  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
  TestingFeatureProcessor feature_processor(
      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
      &unilib_);

  // Disabled: the mixed Hangul/digit string stays a single token.
  EXPECT_EQ(feature_processor.Tokenize("앨라배마123웹사이트"),
            std::vector<Token>({Token("앨라배마123웹사이트", 0, 11)}));

  options.tokenize_on_script_change = true;
  flatbuffers::DetachedBuffer options_fb2 =
      PackFeatureProcessorOptions(options);
  TestingFeatureProcessor feature_processor2(
      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb2.data()),
      &unilib_);

  // Enabled: token boundaries appear at each script change.
  EXPECT_EQ(feature_processor2.Tokenize("앨라배마123웹사이트"),
            std::vector<Token>({Token("앨라배마", 0, 4), Token("123", 4, 7),
                                Token("웹사이트", 7, 11)}));
}

#ifdef TC3_TEST_ICU
// ICU tokenization of Thai text, which has no whitespace word boundaries;
// whitespace tokens are dropped by default (icu_preserve_whitespace_tokens
// unset).
TEST_F(FeatureProcessorTest, ICUTokenize) {
  FeatureProcessorOptionsT options;
  options.tokenization_type = FeatureProcessorOptions_::TokenizationType_ICU;

  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
  UniLib unilib;
  TestingFeatureProcessor feature_processor(
      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
      &unilib);
  std::vector<Token> tokens =
      feature_processor.Tokenize("พระบาทสมเด็จพระปรมิ");
  ASSERT_EQ(tokens,
            // clang-format off
            std::vector<Token>({Token("พระบาท", 0, 6),
                                Token("สมเด็จ", 6, 12),
                                Token("พระ", 12, 15),
                                Token("ปร", 15, 17),
                                Token("มิ", 17, 19)}));
  // clang-format on
}
#endif

#ifdef TC3_TEST_ICU
// As above, but icu_preserve_whitespace_tokens keeps the space characters as
// their own single-codepoint tokens.
TEST_F(FeatureProcessorTest, ICUTokenizeWithWhitespaces) {
  FeatureProcessorOptionsT options;
  options.tokenization_type = FeatureProcessorOptions_::TokenizationType_ICU;
  options.icu_preserve_whitespace_tokens = true;

  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
  UniLib unilib;
  TestingFeatureProcessor feature_processor(
      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
      &unilib);
  std::vector<Token> tokens =
      feature_processor.Tokenize("พระบาท สมเด็จ พระ ปร มิ");
  ASSERT_EQ(tokens,
            // clang-format off
            std::vector<Token>({Token("พระบาท", 0, 6),
                                Token(" ", 6, 7),
                                Token("สมเด็จ", 7, 13),
                                Token(" ", 13, 14),
                                Token("พระ", 14, 17),
                                Token(" ", 17, 18),
                                Token("ปร", 18, 20),
                                Token(" ", 20, 21),
                                Token("มิ", 21, 23)}));
  // clang-format on
}
#endif

#ifdef TC3_TEST_ICU
// MIXED mode: ICU tokenization first, then stretches of tokens that fall
// entirely into internal_tokenizer_codepoint_ranges ([0, 592) here) are
// re-tokenized with the internal tokenizer (space = whitespace separator).
TEST_F(FeatureProcessorTest, MixedTokenize) {
  FeatureProcessorOptionsT options;
  options.tokenization_type = FeatureProcessorOptions_::TokenizationType_MIXED;

  options.tokenization_codepoint_config.emplace_back(
      new TokenizationCodepointRangeT());
  auto& config = options.tokenization_codepoint_config.back();
  config->start = 32;
  config->end = 33;
  config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;

  {
    options.internal_tokenizer_codepoint_ranges.emplace_back(
        new FeatureProcessorOptions_::CodepointRangeT());
    auto& range = options.internal_tokenizer_codepoint_ranges.back();
    range->start = 0;
    range->end = 128;
  }

  {
    options.internal_tokenizer_codepoint_ranges.emplace_back(
        new FeatureProcessorOptions_::CodepointRangeT());
    auto& range = options.internal_tokenizer_codepoint_ranges.back();
    range->start = 128;
    range->end = 256;
  }

  {
    options.internal_tokenizer_codepoint_ranges.emplace_back(
        new FeatureProcessorOptions_::CodepointRangeT());
    auto& range = options.internal_tokenizer_codepoint_ranges.back();
    range->start = 256;
    range->end = 384;
  }

  {
    options.internal_tokenizer_codepoint_ranges.emplace_back(
        new FeatureProcessorOptions_::CodepointRangeT());
    auto& range = options.internal_tokenizer_codepoint_ranges.back();
    range->start = 384;
    range->end = 592;
  }

  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
  UniLib unilib;
  TestingFeatureProcessor feature_processor(
      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
      &unilib);
  std::vector<Token> tokens = feature_processor.Tokenize(
      "こんにちはJapanese-ląnguagę text 世界 http://www.google.com/");
  ASSERT_EQ(tokens,
            // clang-format off
            std::vector<Token>({Token("こんにちは", 0, 5),
                                Token("Japanese-ląnguagę", 5, 22),
                                Token("text", 23, 27),
                                Token("世界", 28, 30),
                                Token("http://www.google.com/", 31, 53)}));
  // clang-format on
}
#endif

// CountIgnoredSpanBoundaryCodepoints counts, from the given end of the span,
// the run of codepoints that belong to ignored_span_boundary_codepoints;
// StripBoundaryCodepoints shrinks a span accordingly.
TEST_F(FeatureProcessorTest, IgnoredSpanBoundaryCodepoints) {
  FeatureProcessorOptionsT options;
  options.ignored_span_boundary_codepoints.push_back('.');
  options.ignored_span_boundary_codepoints.push_back(',');
  options.ignored_span_boundary_codepoints.push_back('[');
  options.ignored_span_boundary_codepoints.push_back(']');

  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
  TestingFeatureProcessor feature_processor(
      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
      &unilib_);

  // No ignored codepoints at either end.
  const std::string text1_utf8 = "ěščř";
  const UnicodeText text1 = UTF8ToUnicodeText(text1_utf8, /*do_copy=*/false);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text1.begin(), text1.end(),
                /*count_from_beginning=*/true),
            0);

  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text1.begin(), text1.end(),
                /*count_from_beginning=*/false),
            0);

  // Ignored codepoints at the beginning only.
  const std::string text2_utf8 = ".,abčd";
  const UnicodeText text2 = UTF8ToUnicodeText(text2_utf8, /*do_copy=*/false);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text2.begin(), text2.end(),
                /*count_from_beginning=*/true),
            2);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text2.begin(), text2.end(),
                /*count_from_beginning=*/false),
            0);

  // Ignored codepoints at both ends.
  const std::string text3_utf8 = ".,abčd[]";
  const UnicodeText text3 = UTF8ToUnicodeText(text3_utf8, /*do_copy=*/false);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text3.begin(), text3.end(),
                /*count_from_beginning=*/true),
            2);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text3.begin(), text3.end(),
                /*count_from_beginning=*/false),
            2);

  const std::string text4_utf8 = "[abčd]";
  const UnicodeText text4 = UTF8ToUnicodeText(text4_utf8, /*do_copy=*/false);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text4.begin(), text4.end(),
                /*count_from_beginning=*/true),
            1);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text4.begin(), text4.end(),
                /*count_from_beginning=*/false),
            1);

  // Empty text yields zero from either end.
  const std::string text5_utf8 = "";
  const UnicodeText text5 = UTF8ToUnicodeText(text5_utf8, /*do_copy=*/false);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text5.begin(), text5.end(),
                /*count_from_beginning=*/true),
            0);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text5.begin(), text5.end(),
                /*count_from_beginning=*/false),
            0);

  // Iterators may point into the middle of the text.
  const std::string text6_utf8 = "012345ěščř";
  const UnicodeText text6 = UTF8ToUnicodeText(text6_utf8, /*do_copy=*/false);
  UnicodeText::const_iterator text6_begin = text6.begin();
  std::advance(text6_begin, 6);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text6_begin, text6.end(),
                /*count_from_beginning=*/true),
            0);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text6_begin, text6.end(),
                /*count_from_beginning=*/false),
            0);

  const std::string text7_utf8 = "012345.,ěščř";
  const UnicodeText text7 = UTF8ToUnicodeText(text7_utf8, /*do_copy=*/false);
  UnicodeText::const_iterator text7_begin = text7.begin();
  std::advance(text7_begin, 6);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text7_begin, text7.end(),
                /*count_from_beginning=*/true),
            2);
  UnicodeText::const_iterator text7_end = text7.begin();
  std::advance(text7_end, 8);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text7.begin(), text7_end,
                /*count_from_beginning=*/false),
            2);

  // Test not stripping.
  EXPECT_EQ(feature_processor.StripBoundaryCodepoints(
                "Hello [[[Wořld]] or not?", {0, 24}),
            std::make_pair(0, 24));
  // Test basic stripping.
  EXPECT_EQ(feature_processor.StripBoundaryCodepoints(
                "Hello [[[Wořld]] or not?", {6, 16}),
            std::make_pair(9, 14));
  // Test stripping when everything is stripped.
  EXPECT_EQ(
      feature_processor.StripBoundaryCodepoints("Hello [[[]] or not?", {6, 11}),
      std::make_pair(6, 6));
  // Test stripping empty string.
  EXPECT_EQ(feature_processor.StripBoundaryCodepoints("", {0, 0}),
            std::make_pair(0, 0));
}

// CodepointSpanToTokenSpan maps a codepoint span onto the covered token
// range; the optional third argument snaps partially-covered tokens into the
// result instead of dropping them.
TEST_F(FeatureProcessorTest, CodepointSpanToTokenSpan) {
  const std::vector<Token> tokens{Token("Hělló", 0, 5),
                                  Token("fěěbař@google.com", 6, 23),
                                  Token("heře!", 24, 29)};

  // Spans matching the tokens exactly.
  EXPECT_EQ(TokenSpan(0, 1), CodepointSpanToTokenSpan(tokens, {0, 5}));
  EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {6, 23}));
  EXPECT_EQ(TokenSpan(2, 3), CodepointSpanToTokenSpan(tokens, {24, 29}));
  EXPECT_EQ(TokenSpan(0, 2), CodepointSpanToTokenSpan(tokens, {0, 23}));
  EXPECT_EQ(TokenSpan(1, 3), CodepointSpanToTokenSpan(tokens, {6, 29}));
  EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {0, 29}));

  // Snapping to containing tokens has no effect.
  EXPECT_EQ(TokenSpan(0, 1), CodepointSpanToTokenSpan(tokens, {0, 5}, true));
  EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {6, 23}, true));
  EXPECT_EQ(TokenSpan(2, 3), CodepointSpanToTokenSpan(tokens, {24, 29}, true));
  EXPECT_EQ(TokenSpan(0, 2), CodepointSpanToTokenSpan(tokens, {0, 23}, true));
  EXPECT_EQ(TokenSpan(1, 3), CodepointSpanToTokenSpan(tokens, {6, 29}, true));
  EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {0, 29}, true));

  // Span boundaries inside tokens.
  EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {1, 28}));
  EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {1, 28}, true));

  // Tokens adjacent to the span, but not overlapping.
  EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {5, 24}));
  EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {5, 24}, true));
}

}  // namespace
}  // namespace libtextclassifier3
diff --git a/annotator/knowledge/knowledge-engine-dummy.h b/annotator/knowledge/knowledge-engine-dummy.h
new file mode 100644
index 0000000..a6285dc
--- /dev/null
+++ b/annotator/knowledge/knowledge-engine-dummy.h
@@ -0,0 +1,47 @@
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_DUMMY_H_
#define LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_DUMMY_H_

#include <string>

#include "annotator/types.h"
#include "utils/utf8/unilib.h"

namespace libtextclassifier3 {

// A dummy implementation of the knowledge engine. Stands in for the real
// engine in builds that do not ship it; all methods are no-ops that mimic its
// interface.
class KnowledgeEngine {
 public:
  // The unilib parameter is accepted for interface compatibility but unused.
  explicit KnowledgeEngine(const UniLib* unilib) {}

  // Accepts any config; there is nothing to initialize.
  bool Initialize(const std::string& serialized_config) { return true; }

  // Never classifies anything: returns false ("no result") and leaves
  // classification_result untouched.
  bool ClassifyText(const std::string& context, CodepointSpan selection_indices,
                    ClassificationResult* classification_result) const {
    return false;
  }

  // Produces no chunks; returns true ("success") without touching result.
  bool Chunk(const std::string& context,
             std::vector<AnnotatedSpan>* result) const {
    return true;
  }
};

}  // namespace libtextclassifier3

#endif  // LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_DUMMY_H_
diff --git a/annotator/knowledge/knowledge-engine.h b/annotator/knowledge/knowledge-engine.h
new file mode 100644
index 0000000..4776b26
--- /dev/null
+++ b/annotator/knowledge/knowledge-engine.h
@@ -0,0 +1,22 @@
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_H_ + +#include "annotator/knowledge/knowledge-engine-dummy.h" + +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_H_ diff --git a/annotator/model-executor.cc b/annotator/model-executor.cc new file mode 100644 index 0000000..7c57e8f --- /dev/null +++ b/annotator/model-executor.cc @@ -0,0 +1,124 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "annotator/model-executor.h" + +#include "annotator/quantization.h" +#include "utils/base/logging.h" + +namespace libtextclassifier3 { + +TensorView<float> ModelExecutor::ComputeLogits( + const TensorView<float>& features, tflite::Interpreter* interpreter) const { + if (!interpreter) { + return TensorView<float>::Invalid(); + } + interpreter->ResizeInputTensor(kInputIndexFeatures, features.shape()); + if (interpreter->AllocateTensors() != kTfLiteOk) { + TC3_VLOG(1) << "Allocation failed."; + return TensorView<float>::Invalid(); + } + + SetInput<float>(kInputIndexFeatures, features, interpreter); + + if (interpreter->Invoke() != kTfLiteOk) { + TC3_VLOG(1) << "Interpreter failed."; + return TensorView<float>::Invalid(); + } + + return OutputView<float>(kOutputIndexLogits, interpreter); +} + +std::unique_ptr<TFLiteEmbeddingExecutor> TFLiteEmbeddingExecutor::FromBuffer( + const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size, + int quantization_bits) { + std::unique_ptr<TfLiteModelExecutor> executor = + TfLiteModelExecutor::FromBuffer(model_spec_buffer); + if (!executor) { + TC3_LOG(ERROR) << "Could not load TFLite model for embeddings."; + return nullptr; + } + + std::unique_ptr<tflite::Interpreter> interpreter = + executor->CreateInterpreter(); + if (!interpreter) { + TC3_LOG(ERROR) << "Could not build TFLite interpreter for embeddings."; + return nullptr; + } + + if (interpreter->tensors_size() != 2) { + return nullptr; + } + const TfLiteTensor* embeddings = interpreter->tensor(0); + if (embeddings->dims->size != 2) { + return nullptr; + } + int num_buckets = embeddings->dims->data[0]; + const TfLiteTensor* scales = interpreter->tensor(1); + if (scales->dims->size != 2 || scales->dims->data[0] != num_buckets || + scales->dims->data[1] != 1) { + return nullptr; + } + int bytes_per_embedding = embeddings->dims->data[1]; + if (!CheckQuantizationParams(bytes_per_embedding, quantization_bits, + embedding_size)) { + TC3_LOG(ERROR) << 
"Mismatch in quantization parameters."; + return nullptr; + } + + return std::unique_ptr<TFLiteEmbeddingExecutor>(new TFLiteEmbeddingExecutor( + std::move(executor), quantization_bits, num_buckets, bytes_per_embedding, + embedding_size, scales, embeddings, std::move(interpreter))); +} + +TFLiteEmbeddingExecutor::TFLiteEmbeddingExecutor( + std::unique_ptr<TfLiteModelExecutor> executor, int quantization_bits, + int num_buckets, int bytes_per_embedding, int output_embedding_size, + const TfLiteTensor* scales, const TfLiteTensor* embeddings, + std::unique_ptr<tflite::Interpreter> interpreter) + : executor_(std::move(executor)), + quantization_bits_(quantization_bits), + num_buckets_(num_buckets), + bytes_per_embedding_(bytes_per_embedding), + output_embedding_size_(output_embedding_size), + scales_(scales), + embeddings_(embeddings), + interpreter_(std::move(interpreter)) {} + +bool TFLiteEmbeddingExecutor::AddEmbedding( + const TensorView<int>& sparse_features, float* dest, int dest_size) const { + if (dest_size != output_embedding_size_) { + TC3_LOG(ERROR) << "Mismatching dest_size and output_embedding_size: " + << dest_size << " " << output_embedding_size_; + return false; + } + const int num_sparse_features = sparse_features.size(); + for (int i = 0; i < num_sparse_features; ++i) { + const int bucket_id = sparse_features.data()[i]; + if (bucket_id >= num_buckets_) { + return false; + } + + if (!DequantizeAdd(scales_->data.f, embeddings_->data.uint8, + bytes_per_embedding_, num_sparse_features, + quantization_bits_, bucket_id, dest, dest_size)) { + return false; + } + } + return true; +} + +} // namespace libtextclassifier3 diff --git a/annotator/model-executor.h b/annotator/model-executor.h new file mode 100644 index 0000000..5ad3a7f --- /dev/null +++ b/annotator/model-executor.h @@ -0,0 +1,111 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in 
compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains classes that can execute different models/parts of a model.

#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_
#define LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_

#include <memory>

#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/tensor-view.h"
#include "utils/tflite-model-executor.h"

namespace libtextclassifier3 {

// Executor for the text selection prediction and classification models.
class ModelExecutor : public TfLiteModelExecutor {
 public:
  // Builds an executor from an in-memory TFLite model spec; returns nullptr
  // if the model cannot be loaded.
  static std::unique_ptr<ModelExecutor> FromModelSpec(
      const tflite::Model* model_spec) {
    auto model = TfLiteModelFromModelSpec(model_spec);
    if (!model) {
      return nullptr;
    }
    return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model)));
  }

  // Same as above, but from a serialized flatbuffer byte vector.
  static std::unique_ptr<ModelExecutor> FromBuffer(
      const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
    auto model = TfLiteModelFromBuffer(model_spec_buffer);
    if (!model) {
      return nullptr;
    }
    return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model)));
  }

  // Runs the model on the given features with the given interpreter and
  // returns a view over the logits; invalid view on failure.
  TensorView<float> ComputeLogits(const TensorView<float>& features,
                                  tflite::Interpreter* interpreter) const;

 protected:
  explicit ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model)
      : TfLiteModelExecutor(std::move(model)) {}

  // Tensor indices of the model's single input (features) and single output
  // (logits).
  static const int kInputIndexFeatures = 0;
  static const int kOutputIndexLogits = 0;
};

// Executor for embedding sparse features into a dense vector.
+class EmbeddingExecutor { + public: + virtual ~EmbeddingExecutor() {} + + // Embeds the sparse_features into a dense embedding and adds (+) it + // element-wise to the dest vector. + virtual bool AddEmbedding(const TensorView<int>& sparse_features, float* dest, + int dest_size) const = 0; + + // Returns true when the model is ready to be used, false otherwise. + virtual bool IsReady() const { return true; } +}; + +class TFLiteEmbeddingExecutor : public EmbeddingExecutor { + public: + static std::unique_ptr<TFLiteEmbeddingExecutor> FromBuffer( + const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size, + int quantization_bits); + + // Embeds the sparse_features into a dense embedding and adds (+) it + // element-wise to the dest vector. + bool AddEmbedding(const TensorView<int>& sparse_features, float* dest, + int dest_size) const; + + protected: + explicit TFLiteEmbeddingExecutor( + std::unique_ptr<TfLiteModelExecutor> executor, int quantization_bits, + int num_buckets, int bytes_per_embedding, int output_embedding_size, + const TfLiteTensor* scales, const TfLiteTensor* embeddings, + std::unique_ptr<tflite::Interpreter> interpreter); + + std::unique_ptr<TfLiteModelExecutor> executor_; + + int quantization_bits_; + int num_buckets_ = -1; + int bytes_per_embedding_ = -1; + int output_embedding_size_ = -1; + const TfLiteTensor* scales_ = nullptr; + const TfLiteTensor* embeddings_ = nullptr; + + // NOTE: This interpreter is used in a read-only way (as a storage for the + // model params), thus is still thread-safe. 
+ std::unique_ptr<tflite::Interpreter> interpreter_; +}; + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_ diff --git a/annotator/model.fbs b/annotator/model.fbs new file mode 100755 index 0000000..3682994 --- /dev/null +++ b/annotator/model.fbs @@ -0,0 +1,583 @@ +// +// Copyright (C) 2018 The Android Open Source Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +include "utils/intents/intent-config.fbs"; +include "utils/zlib/buffer.fbs"; + +file_identifier "TC2 "; + +// The possible model modes, represents a bit field. 
namespace libtextclassifier3;
// Bit field: ANNOTATION = 1, CLASSIFICATION = 2, SELECTION = 4; the remaining
// values are the named combinations of those bits.
enum ModeFlag : int {
  NONE = 0,
  ANNOTATION = 1,
  CLASSIFICATION = 2,
  ANNOTATION_AND_CLASSIFICATION = 3,
  SELECTION = 4,
  ANNOTATION_AND_SELECTION = 5,
  CLASSIFICATION_AND_SELECTION = 6,
  ALL = 7,
}

namespace libtextclassifier3;
// Identifiers for the datetime sub-pattern extractors (meridiems, month and
// weekday names, relative terms, units and number words).
enum DatetimeExtractorType : int {
  UNKNOWN_DATETIME_EXTRACTOR_TYPE = 0,
  AM = 1,
  PM = 2,
  JANUARY = 3,
  FEBRUARY = 4,
  MARCH = 5,
  APRIL = 6,
  MAY = 7,
  JUNE = 8,
  JULY = 9,
  AUGUST = 10,
  SEPTEMBER = 11,
  OCTOBER = 12,
  NOVEMBER = 13,
  DECEMBER = 14,
  NEXT = 15,
  NEXT_OR_SAME = 16,
  LAST = 17,
  NOW = 18,
  TOMORROW = 19,
  YESTERDAY = 20,
  PAST = 21,
  FUTURE = 22,
  DAY = 23,
  WEEK = 24,
  MONTH = 25,
  YEAR = 26,
  MONDAY = 27,
  TUESDAY = 28,
  WEDNESDAY = 29,
  THURSDAY = 30,
  FRIDAY = 31,
  SATURDAY = 32,
  SUNDAY = 33,
  DAYS = 34,
  WEEKS = 35,
  MONTHS = 36,
  HOURS = 37,
  MINUTES = 38,
  SECONDS = 39,
  YEARS = 40,
  DIGITS = 41,
  SIGNEDDIGITS = 42,
  ZERO = 43,
  ONE = 44,
  TWO = 45,
  THREE = 46,
  FOUR = 47,
  FIVE = 48,
  SIX = 49,
  SEVEN = 50,
  EIGHT = 51,
  NINE = 52,
  TEN = 53,
  ELEVEN = 54,
  TWELVE = 55,
  THIRTEEN = 56,
  FOURTEEN = 57,
  FIFTEEN = 58,
  SIXTEEN = 59,
  SEVENTEEN = 60,
  EIGHTEEN = 61,
  NINETEEN = 62,
  TWENTY = 63,
  THIRTY = 64,
  FORTY = 65,
  FIFTY = 66,
  SIXTY = 67,
  SEVENTY = 68,
  EIGHTY = 69,
  NINETY = 70,
  HUNDRED = 71,
  THOUSAND = 72,
}

namespace libtextclassifier3;
// Semantic role of a capturing group inside a datetime regex.
enum DatetimeGroupType : int {
  GROUP_UNKNOWN = 0,
  GROUP_UNUSED = 1,
  GROUP_YEAR = 2,
  GROUP_MONTH = 3,
  GROUP_DAY = 4,
  GROUP_HOUR = 5,
  GROUP_MINUTE = 6,
  GROUP_SECOND = 7,
  GROUP_AMPM = 8,
  GROUP_RELATIONDISTANCE = 9,
  GROUP_RELATION = 10,
  GROUP_RELATIONTYPE = 11,

  // Dummy groups serve just as an inflator of the selection. E.g. we might want
  // to select more text than was contained in an envelope of all extractor
  // spans.
  GROUP_DUMMY1 = 12,

  GROUP_DUMMY2 = 13,
}

// Options for the model that predicts text selection.
namespace libtextclassifier3;
table SelectionModelOptions {
  // If true, before the selection is returned, the unpaired brackets contained
  // in the predicted selection are stripped from the both selection ends.
  // The bracket codepoints are defined in the Unicode standard:
  // http://www.unicode.org/Public/UNIDATA/BidiBrackets.txt
  strip_unpaired_brackets:bool = true;

  // Number of hypothetical click positions on either side of the actual click
  // to consider in order to enforce symmetry.
  symmetry_context_size:int;

  // Number of examples to bundle in one batch for inference.
  batch_size:int = 1024;

  // Whether to always classify a suggested selection or only on demand.
  always_classify_suggested_selection:bool = false;
}

// Options for the model that classifies a text selection.
namespace libtextclassifier3;
table ClassificationModelOptions {
  // Limits for phone numbers.
  phone_min_num_digits:int = 7;

  phone_max_num_digits:int = 15;

  // Limits for addresses.
  address_min_num_tokens:int;

  // Maximum number of tokens to attempt a classification (-1 is unlimited).
  max_num_tokens:int = -1;
}

// Options for post-checks, checksums and verification to apply on a match.
namespace libtextclassifier3;
table VerificationOptions {
  // If true, the match must pass a Luhn checksum (e.g. payment card numbers).
  verify_luhn_checksum:bool = false;
}

// List of regular expression matchers to check.
namespace libtextclassifier3.RegexModel_;
table Pattern {
  // The name of the collection of a match.
  collection_name:string;

  // The pattern to check.
  // Can specify a single capturing group used as match boundaries.
  pattern:string;

  // The modes for which to apply the patterns.
  enabled_modes:libtextclassifier3.ModeFlag = ALL;

  // The final score to assign to the results of this pattern.
  target_classification_score:float = 1;

  // Priority score used for conflict resolution with the other models.
  priority_score:float = 0;

  // If true, will use an approximate matching implementation implemented
  // using Find() instead of the true Match(). This approximate matching will
  // use the first Find() result and then check that it spans the whole input.
  use_approximate_matching:bool = false;

  // Optional zlib-compressed form of `pattern`; used instead when present.
  compressed_pattern:libtextclassifier3.CompressedBuffer;

  // Verification to apply on a match.
  verification_options:libtextclassifier3.VerificationOptions;
}

namespace libtextclassifier3;
table RegexModel {
  patterns:[libtextclassifier3.RegexModel_.Pattern];
}

// List of regex patterns.
namespace libtextclassifier3.DatetimeModelPattern_;
table Regex {
  pattern:string;

  // The ith entry specifies the type of the ith capturing group.
  // This is used to decide how the matched content has to be parsed.
  groups:[libtextclassifier3.DatetimeGroupType];

  // Optional zlib-compressed form of `pattern`; used instead when present.
  compressed_pattern:libtextclassifier3.CompressedBuffer;
}

namespace libtextclassifier3;
table DatetimeModelPattern {
  regexes:[libtextclassifier3.DatetimeModelPattern_.Regex];

  // List of locale indices in DatetimeModel that represent the locales that
  // these patterns should be used for. If empty, can be used for all locales.
  locales:[int];

  // The final score to assign to the results of this pattern.
  target_classification_score:float = 1;

  // Priority score used for conflict resolution with the other models.
  priority_score:float = 0;

  // The modes for which to apply the patterns.
  enabled_modes:libtextclassifier3.ModeFlag = ALL;
}

namespace libtextclassifier3;
table DatetimeModelExtractor {
  // The sub-pattern this extractor recognizes (month names, units, ...).
  extractor:libtextclassifier3.DatetimeExtractorType;

  // Regex for this extractor; may be stored compressed below instead.
  pattern:string;

  // Locale indices (into DatetimeModel.locales) this extractor applies to.
  locales:[int];

  compressed_pattern:libtextclassifier3.CompressedBuffer;
}

namespace libtextclassifier3;
table DatetimeModel {
  // List of BCP 47 locale strings representing all locales supported by the
  // model. The individual patterns refer back to them using an index.
  locales:[string];

  patterns:[libtextclassifier3.DatetimeModelPattern];
  extractors:[libtextclassifier3.DatetimeModelExtractor];

  // If true, will use the extractors for determining the match location as
  // opposed to using the location where the global pattern matched.
  use_extractors_for_locating:bool = true;

  // List of locale ids, rules of whose are always run, after the requested
  // ones.
  default_locales:[int];
}

namespace libtextclassifier3.DatetimeModelLibrary_;
// A named entry in the datetime model library (key -> model).
table Item {
  key:string;
  value:libtextclassifier3.DatetimeModel;
}

// A set of named DateTime models.
namespace libtextclassifier3;
table DatetimeModelLibrary {
  models:[libtextclassifier3.DatetimeModelLibrary_.Item];
}

// Options controlling the output of the Tensorflow Lite models.
namespace libtextclassifier3;
table ModelTriggeringOptions {
  // Lower bound threshold for filtering annotation model outputs.
  min_annotate_confidence:float = 0;

  // The modes for which to enable the models.
  enabled_modes:libtextclassifier3.ModeFlag = ALL;
}

// Options controlling the output of the classifier.
namespace libtextclassifier3;
table OutputOptions {
  // Lists of collection names that will be filtered out at the output:
  // - For annotation, the spans of given collection are simply dropped.
  // - For classification, the result is mapped to the class "other".
  // - For selection, the spans of given class are returned as
  //   single-selection.
  filtered_collections_annotation:[string];

  filtered_collections_classification:[string];
  filtered_collections_selection:[string];
}

// Top-level model flatbuffer: feature options, embedded TFLite models,
// regex/datetime models and global output switches.
namespace libtextclassifier3;
table Model {
  // Comma-separated list of locales supported by the model as BCP 47 tags.
  locales:string;

  // Model version, for logging and compatibility checks.
  version:int;

  // A name for the model that can be used for e.g. logging.
  name:string;

  selection_feature_options:libtextclassifier3.FeatureProcessorOptions;
  classification_feature_options:libtextclassifier3.FeatureProcessorOptions;

  // Tensorflow Lite models.
  selection_model:[ubyte] (force_align: 16);

  classification_model:[ubyte] (force_align: 16);
  embedding_model:[ubyte] (force_align: 16);

  // Options for the different models.
  selection_options:libtextclassifier3.SelectionModelOptions;

  classification_options:libtextclassifier3.ClassificationModelOptions;
  regex_model:libtextclassifier3.RegexModel;
  datetime_model:libtextclassifier3.DatetimeModel;

  // Options controlling the output of the models.
  triggering_options:libtextclassifier3.ModelTriggeringOptions;

  // Global switch that controls if SuggestSelection(), ClassifyText() and
  // Annotate() will run. If a mode is disabled it returns empty/no-op results.
  enabled_modes:libtextclassifier3.ModeFlag = ALL;

  // If true, will snap the selections that consist only of whitespaces to the
  // containing suggested span. Otherwise, no suggestion is proposed, since the
  // selections are not part of any token.
  snap_whitespace_selections:bool = true;

  // Global configuration for the output of SuggestSelection(), ClassifyText()
  // and Annotate().
  output_options:libtextclassifier3.OutputOptions;

  // Configures how Intents should be generated on Android.
  // TODO(smillius): Remove deprecated factory options.
  android_intent_options:libtextclassifier3.AndroidIntentFactoryOptions;

  intent_options:libtextclassifier3.IntentFactoryModel;
}

// Role of the codepoints in the range.
namespace libtextclassifier3.TokenizationCodepointRange_;
enum Role : int {
  // Concatenates the codepoint to the current run of codepoints.
  DEFAULT_ROLE = 0,

  // Splits a run of codepoints before the current codepoint.
  SPLIT_BEFORE = 1,

  // Splits a run of codepoints after the current codepoint.
  SPLIT_AFTER = 2,

  // Each codepoint will be a separate token. Good e.g. for Chinese
  // characters.
  TOKEN_SEPARATOR = 3,

  // Discards the codepoint.
  DISCARD_CODEPOINT = 4,

  // Common values:
  // Splits on the characters and discards them. Good e.g. for the space
  // character. (Value 7 = SPLIT_BEFORE | SPLIT_AFTER | DISCARD_CODEPOINT
  // combined — presumably a bit field; confirm against the tokenizer.)
  WHITESPACE_SEPARATOR = 7,
}

// Represents a codepoint range [start, end) with its role for tokenization.
namespace libtextclassifier3;
table TokenizationCodepointRange {
  start:int;
  end:int;
  role:libtextclassifier3.TokenizationCodepointRange_.Role;

  // Integer identifier of the script this range denotes. Negative values are
  // reserved for Tokenizer's internal use.
  script_id:int;
}

// Method for selecting the center token.
namespace libtextclassifier3.FeatureProcessorOptions_;
enum CenterTokenSelectionMethod : int {
  DEFAULT_CENTER_TOKEN_METHOD = 0,

  // Use click indices to determine the center token.
  CENTER_TOKEN_FROM_CLICK = 1,

  // Use selection indices to get a token range, and select the middle of it
  // as the center token.
  CENTER_TOKEN_MIDDLE_OF_SELECTION = 2,
}

// Controls the type of tokenization the model will use for the input text.
namespace libtextclassifier3.FeatureProcessorOptions_;
enum TokenizationType : int {
  INVALID_TOKENIZATION_TYPE = 0,

  // Use the internal tokenizer for tokenization.
  INTERNAL_TOKENIZER = 1,

  // Use ICU for tokenization.
  ICU = 2,

  // First apply ICU tokenization. Then identify stretches of tokens
  // consisting only of codepoints in internal_tokenizer_codepoint_ranges
  // and re-tokenize them using the internal tokenizer.
  MIXED = 3,
}

// Range of codepoints start - end, where end is exclusive.
namespace libtextclassifier3.FeatureProcessorOptions_;
table CodepointRange {
  start:int;
  end:int;
}

// Bounds-sensitive feature extraction configuration.
namespace libtextclassifier3.FeatureProcessorOptions_;
table BoundsSensitiveFeatures {
  // Enables the extraction of bounds-sensitive features, instead of the click
  // context features.
  enabled:bool;

  // The numbers of tokens to extract in specific locations relative to the
  // bounds.
  // Immediately before the span.
  num_tokens_before:int;

  // Inside the span, aligned with the beginning.
  num_tokens_inside_left:int;

  // Inside the span, aligned with the end.
  num_tokens_inside_right:int;

  // Immediately after the span.
  num_tokens_after:int;

  // If true, also extracts the tokens of the entire span and adds up their
  // features forming one "token" to include in the extracted features.
  include_inside_bag:bool;

  // If true, includes the selection length (in the number of tokens) as a
  // feature.
  include_inside_length:bool;

  // If true, for selection, single token spans are not run through the model
  // and their score is assumed to be zero.
  score_single_token_spans_as_zero:bool;
}

namespace libtextclassifier3;
table FeatureProcessorOptions {
  // Number of buckets used for hashing charactergrams.
  num_buckets:int = -1;

  // Size of the embedding.
  embedding_size:int = -1;

  // Number of bits for quantization for embeddings.
  embedding_quantization_bits:int = 8;

  // Context size defines the number of words to the left and to the right of
  // the selected word to be used as context. For example, if context size is
  // N, then we take N words to the left and N words to the right of the
  // selected word as its context.
  context_size:int = -1;

  // Maximum number of words of the context to select in total.
+ max_selection_span:int = -1; + + // Orders of charactergrams to extract. E.g., 2 means character bigrams, 3 + // character trigrams etc. + chargram_orders:[int]; + + // Maximum length of a word, in codepoints. + max_word_length:int = 20; + + // If true, will use the unicode-aware functionality for extracting features. + unicode_aware_features:bool = false; + + // Whether to extract the token case feature. + extract_case_feature:bool = false; + + // Whether to extract the selection mask feature. + extract_selection_mask_feature:bool = false; + + // List of regexps to run over each token. For each regexp, if there is a + // match, a dense feature of 1.0 is emitted. Otherwise -1.0 is used. + regexp_feature:[string]; + + // Whether to remap all digits to a single number. + remap_digits:bool = false; + + // Whether to lower-case each token before generating hashgrams. + lowercase_tokens:bool; + + // If true, the selection classifier output will contain only the selections + // that are feasible (e.g., those that are shorter than max_selection_span), + // if false, the output will be a complete cross-product of possible + // selections to the left and possible selections to the right, including the + // infeasible ones. + // NOTE: Exists mainly for compatibility with older models that were trained + // with the non-reduced output space. + selection_reduced_output_space:bool = true; + + // Collection names. + collections:[string]; + + // An index of collection in collections to be used if a collection name can't + // be mapped to an id. + default_collection:int = -1; + + // If true, will split the input by lines, and only use the line that contains + // the clicked token. + only_use_line_with_click:bool = false; + + // If true, will split tokens that contain the selection boundary, at the + // position of the boundary. + // E.g. 
"foo{bar}@google.com" -> "foo", "bar", "@google.com" + split_tokens_on_selection_boundaries:bool = false; + + // Codepoint ranges that determine how different codepoints are tokenized. + // The ranges must not overlap. + tokenization_codepoint_config:[libtextclassifier3.TokenizationCodepointRange]; + + center_token_selection_method:libtextclassifier3.FeatureProcessorOptions_.CenterTokenSelectionMethod; + + // If true, span boundaries will be snapped to containing tokens and not + // required to exactly match token boundaries. + snap_label_span_boundaries_to_containing_tokens:bool; + + // A set of codepoint ranges supported by the model. + supported_codepoint_ranges:[libtextclassifier3.FeatureProcessorOptions_.CodepointRange]; + + // A set of codepoint ranges to use in the mixed tokenization mode to identify + // stretches of tokens to re-tokenize using the internal tokenizer. + internal_tokenizer_codepoint_ranges:[libtextclassifier3.FeatureProcessorOptions_.CodepointRange]; + + // Minimum ratio of supported codepoints in the input context. If the ratio + // is lower than this, the feature computation will fail. + min_supported_codepoint_ratio:float = 0; + + // Used for versioning the format of features the model expects. + // - feature_version == 0: + // For each token the features consist of: + // - chargram embeddings + // - dense features + // Chargram embeddings for tokens are concatenated first together, + // and at the end, the dense features for the tokens are concatenated + // to it. So the resulting feature vector has two regions. + feature_version:int = 0; + + tokenization_type:libtextclassifier3.FeatureProcessorOptions_.TokenizationType = INTERNAL_TOKENIZER; + icu_preserve_whitespace_tokens:bool = false; + + // List of codepoints that will be stripped from beginning and end of + // predicted spans. 
+ ignored_span_boundary_codepoints:[int]; + + bounds_sensitive_features:libtextclassifier3.FeatureProcessorOptions_.BoundsSensitiveFeatures; + + // List of allowed charactergrams. The extracted charactergrams are filtered + // using this list, and charactergrams that are not present are interpreted as + // out-of-vocabulary. + // If no allowed_chargrams are specified, all charactergrams are allowed. + // The field is typed as bytes type to allow non-UTF8 chargrams. + allowed_chargrams:[string]; + + // If true, tokens will be also split when the codepoint's script_id changes + // as defined in TokenizationCodepointRange. + tokenize_on_script_change:bool = false; +} + +root_type libtextclassifier3.Model; diff --git a/annotator/quantization.cc b/annotator/quantization.cc new file mode 100644 index 0000000..2cf11c5 --- /dev/null +++ b/annotator/quantization.cc @@ -0,0 +1,92 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "annotator/quantization.h"
+
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+namespace {
+// Maps a quantized integer back to a float: re-centers it around the
+// quantization bias, applies the per-bucket scale ("multiplier"), and divides
+// by num_sparse_features so that adding the contributions of all sparse
+// features produces an average rather than a sum.
+float DequantizeValue(int num_sparse_features, int quantization_bias,
+                      float multiplier, int value) {
+  return 1.0 / num_sparse_features * (value - quantization_bias) * multiplier;
+}
+
+// Dequantizes one 8-bit-per-value embedding row ("bucket_id") and accumulates
+// it into dest[0..dest_size).
+// NOTE(review): each output dimension reads exactly one byte of the row, so
+// this assumes dest_size <= bytes_per_embedding -- TODO confirm with callers
+// (CheckQuantizationParams enforces this for the generic path).
+void DequantizeAdd8bit(const float* scales, const uint8* embeddings,
+                       int bytes_per_embedding, const int num_sparse_features,
+                       const int bucket_id, float* dest, int dest_size) {
+  // 8-bit values are stored with a fixed bias of 128.
+  static const int kQuantizationBias8bit = 128;
+  const float multiplier = scales[bucket_id];
+  for (int k = 0; k < dest_size; ++k) {
+    dest[k] +=
+        DequantizeValue(num_sparse_features, kQuantizationBias8bit, multiplier,
+                        embeddings[bucket_id * bytes_per_embedding + k]);
+  }
+}
+
+// Dequantizes one embedding row quantized to fewer than 8 bits per value and
+// accumulates it into dest[0..dest_size). Values are bit-packed, so a single
+// value may straddle a byte boundary; a 2-byte window is read to handle that.
+void DequantizeAddNBit(const float* scales, const uint8* embeddings,
+                       int bytes_per_embedding, int num_sparse_features,
+                       int quantization_bits, int bucket_id, float* dest,
+                       int dest_size) {
+  // The bias is the midpoint of the representable range, e.g. 4 for 3 bits.
+  const int quantization_bias = 1 << (quantization_bits - 1);
+  const float multiplier = scales[bucket_id];
+  for (int i = 0; i < dest_size; ++i) {
+    const int bit_offset = i * quantization_bits;
+    const int read16_offset = bit_offset / 8;
+
+    uint16 data = embeddings[bucket_id * bytes_per_embedding + read16_offset];
+    // If we are not at the end of the embedding row, we can read 2-byte uint16,
+    // but if we are, we need to only read uint8.
+ if (read16_offset < bytes_per_embedding - 1) { + data |= embeddings[bucket_id * bytes_per_embedding + read16_offset + 1] + << 8; + } + int value = (data >> (bit_offset % 8)) & ((1 << quantization_bits) - 1); + dest[i] += DequantizeValue(num_sparse_features, quantization_bias, + multiplier, value); + } +} +} // namespace + +bool CheckQuantizationParams(int bytes_per_embedding, int quantization_bits, + int output_embedding_size) { + if (bytes_per_embedding * 8 / quantization_bits < output_embedding_size) { + return false; + } + + return true; +} + +bool DequantizeAdd(const float* scales, const uint8* embeddings, + int bytes_per_embedding, int num_sparse_features, + int quantization_bits, int bucket_id, float* dest, + int dest_size) { + if (quantization_bits == 8) { + DequantizeAdd8bit(scales, embeddings, bytes_per_embedding, + num_sparse_features, bucket_id, dest, dest_size); + } else if (quantization_bits != 8) { + DequantizeAddNBit(scales, embeddings, bytes_per_embedding, + num_sparse_features, quantization_bits, bucket_id, dest, + dest_size); + } else { + TC3_LOG(ERROR) << "Unsupported quantization_bits: " << quantization_bits; + return false; + } + + return true; +} + +} // namespace libtextclassifier3 diff --git a/annotator/quantization.h b/annotator/quantization.h new file mode 100644 index 0000000..d294f37 --- /dev/null +++ b/annotator/quantization.h @@ -0,0 +1,39 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_QUANTIZATION_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_QUANTIZATION_H_ + +#include "utils/base/integral_types.h" + +namespace libtextclassifier3 { + +// Returns true if the quantization parameters are valid. +bool CheckQuantizationParams(int bytes_per_embedding, int quantization_bits, + int output_embedding_size); + +// Dequantizes embeddings (quantized to 1 to 8 bits) into the floats they +// represent. The algorithm proceeds by reading 2-byte words from the embedding +// storage to handle well the cases when the quantized value crosses the byte- +// boundary. +bool DequantizeAdd(const float* scales, const uint8* embeddings, + int bytes_per_embedding, int num_sparse_features, + int quantization_bits, int bucket_id, float* dest, + int dest_size); + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_QUANTIZATION_H_ diff --git a/annotator/quantization_test.cc b/annotator/quantization_test.cc new file mode 100644 index 0000000..b995096 --- /dev/null +++ b/annotator/quantization_test.cc @@ -0,0 +1,163 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "annotator/quantization.h" + +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::ElementsAreArray; +using testing::FloatEq; +using testing::Matcher; + +namespace libtextclassifier3 { +namespace { + +Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) { + std::vector<Matcher<float>> matchers; + for (const float value : values) { + matchers.push_back(FloatEq(value)); + } + return ElementsAreArray(matchers); +} + +TEST(QuantizationTest, DequantizeAdd8bit) { + std::vector<float> scales{{0.1, 9.0, -7.0}}; + std::vector<uint8> embeddings{{/*0: */ 0x00, 0xFF, 0x09, 0x00, + /*1: */ 0xFF, 0x09, 0x00, 0xFF, + /*2: */ 0x09, 0x00, 0xFF, 0x09}}; + + const int quantization_bits = 8; + const int bytes_per_embedding = 4; + const int num_sparse_features = 7; + { + const int bucket_id = 0; + std::vector<float> dest(4, 0.0); + DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding, + num_sparse_features, quantization_bits, bucket_id, + dest.data(), dest.size()); + + EXPECT_THAT(dest, + ElementsAreFloat(std::vector<float>{ + // clang-format off + {1.0 / 7 * 0.1 * (0x00 - 128), + 1.0 / 7 * 0.1 * (0xFF - 128), + 1.0 / 7 * 0.1 * (0x09 - 128), + 1.0 / 7 * 0.1 * (0x00 - 128)} + // clang-format on + })); + } + + { + const int bucket_id = 1; + std::vector<float> dest(4, 0.0); + DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding, + num_sparse_features, quantization_bits, bucket_id, + dest.data(), dest.size()); + + EXPECT_THAT(dest, + ElementsAreFloat(std::vector<float>{ + // clang-format off + {1.0 / 7 * 9.0 * (0xFF - 128), + 1.0 / 7 * 9.0 * (0x09 - 128), + 1.0 / 7 * 9.0 * (0x00 - 128), + 1.0 / 7 * 9.0 * (0xFF - 128)} + // clang-format on + })); + } +} + +TEST(QuantizationTest, DequantizeAdd1bitZeros) { + const int bytes_per_embedding = 4; + const int num_buckets = 3; + const int num_sparse_features = 7; + const int quantization_bits = 1; + const int bucket_id = 1; + + 
std::vector<float> scales(num_buckets); + std::vector<uint8> embeddings(bytes_per_embedding * num_buckets); + std::fill(scales.begin(), scales.end(), 1); + std::fill(embeddings.begin(), embeddings.end(), 0); + + std::vector<float> dest(32); + DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding, + num_sparse_features, quantization_bits, bucket_id, dest.data(), + dest.size()); + + std::vector<float> expected(32); + std::fill(expected.begin(), expected.end(), + 1.0 / num_sparse_features * (0 - 1)); + EXPECT_THAT(dest, ElementsAreFloat(expected)); +} + +TEST(QuantizationTest, DequantizeAdd1bitOnes) { + const int bytes_per_embedding = 4; + const int num_buckets = 3; + const int num_sparse_features = 7; + const int quantization_bits = 1; + const int bucket_id = 1; + + std::vector<float> scales(num_buckets, 1.0); + std::vector<uint8> embeddings(bytes_per_embedding * num_buckets, 0xFF); + + std::vector<float> dest(32); + DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding, + num_sparse_features, quantization_bits, bucket_id, dest.data(), + dest.size()); + std::vector<float> expected(32); + std::fill(expected.begin(), expected.end(), + 1.0 / num_sparse_features * (1 - 1)); + EXPECT_THAT(dest, ElementsAreFloat(expected)); +} + +TEST(QuantizationTest, DequantizeAdd3bit) { + const int bytes_per_embedding = 4; + const int num_buckets = 3; + const int num_sparse_features = 7; + const int quantization_bits = 3; + const int bucket_id = 1; + + std::vector<float> scales(num_buckets, 1.0); + scales[1] = 9.0; + std::vector<uint8> embeddings(bytes_per_embedding * num_buckets, 0); + // For bucket_id=1, the embedding has values 0..9 for indices 0..9: + embeddings[4] = (1 << 7) | (1 << 6) | (1 << 4) | 1; + embeddings[5] = (1 << 6) | (1 << 4) | (1 << 3); + embeddings[6] = (1 << 4) | (1 << 3) | (1 << 2) | (1 << 1) | 1; + + std::vector<float> dest(10); + DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding, + num_sparse_features, 
quantization_bits, bucket_id, dest.data(), + dest.size()); + + std::vector<float> expected; + expected.push_back(1.0 / num_sparse_features * (1 - 4) * scales[bucket_id]); + expected.push_back(1.0 / num_sparse_features * (2 - 4) * scales[bucket_id]); + expected.push_back(1.0 / num_sparse_features * (3 - 4) * scales[bucket_id]); + expected.push_back(1.0 / num_sparse_features * (4 - 4) * scales[bucket_id]); + expected.push_back(1.0 / num_sparse_features * (5 - 4) * scales[bucket_id]); + expected.push_back(1.0 / num_sparse_features * (6 - 4) * scales[bucket_id]); + expected.push_back(1.0 / num_sparse_features * (7 - 4) * scales[bucket_id]); + expected.push_back(1.0 / num_sparse_features * (0 - 4) * scales[bucket_id]); + expected.push_back(1.0 / num_sparse_features * (0 - 4) * scales[bucket_id]); + expected.push_back(1.0 / num_sparse_features * (0 - 4) * scales[bucket_id]); + EXPECT_THAT(dest, ElementsAreFloat(expected)); +} + +} // namespace +} // namespace libtextclassifier3 diff --git a/annotator/strip-unpaired-brackets.cc b/annotator/strip-unpaired-brackets.cc new file mode 100644 index 0000000..b1067ad --- /dev/null +++ b/annotator/strip-unpaired-brackets.cc @@ -0,0 +1,105 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "annotator/strip-unpaired-brackets.h" + +#include <iterator> + +#include "utils/base/logging.h" +#include "utils/utf8/unicodetext.h" + +namespace libtextclassifier3 { +namespace { + +// Returns true if given codepoint is contained in the given span in context. +bool IsCodepointInSpan(const char32 codepoint, + const UnicodeText& context_unicode, + const CodepointSpan span) { + auto begin_it = context_unicode.begin(); + std::advance(begin_it, span.first); + auto end_it = context_unicode.begin(); + std::advance(end_it, span.second); + + return std::find(begin_it, end_it, codepoint) != end_it; +} + +// Returns the first codepoint of the span. +char32 FirstSpanCodepoint(const UnicodeText& context_unicode, + const CodepointSpan span) { + auto it = context_unicode.begin(); + std::advance(it, span.first); + return *it; +} + +// Returns the last codepoint of the span. +char32 LastSpanCodepoint(const UnicodeText& context_unicode, + const CodepointSpan span) { + auto it = context_unicode.begin(); + std::advance(it, span.second - 1); + return *it; +} + +} // namespace + +CodepointSpan StripUnpairedBrackets(const std::string& context, + CodepointSpan span, const UniLib& unilib) { + const UnicodeText context_unicode = + UTF8ToUnicodeText(context, /*do_copy=*/false); + return StripUnpairedBrackets(context_unicode, span, unilib); +} + +// If the first or the last codepoint of the given span is a bracket, the +// bracket is stripped if the span does not contain its corresponding paired +// version. 
+CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode,
+                                    CodepointSpan span, const UniLib& unilib) {
+  if (context_unicode.empty() || !ValidNonEmptySpan(span)) {
+    return span;
+  }
+
+  // GetPairedBracket returns its argument unchanged for non-bracket
+  // codepoints, so a differing result means begin_char is a bracket.
+  const char32 begin_char = FirstSpanCodepoint(context_unicode, span);
+  const char32 paired_begin_char = unilib.GetPairedBracket(begin_char);
+  if (paired_begin_char != begin_char) {
+    // Strip the leading bracket unless it is an opening bracket whose pair
+    // occurs inside the span. In particular, a *closing* bracket at the
+    // start of the span is always stripped.
+    if (!unilib.IsOpeningBracket(begin_char) ||
+        !IsCodepointInSpan(paired_begin_char, context_unicode, span)) {
+      ++span.first;
+    }
+  }
+
+  // The span may have collapsed to empty after stripping the first codepoint.
+  if (span.first == span.second) {
+    return span;
+  }
+
+  const char32 end_char = LastSpanCodepoint(context_unicode, span);
+  const char32 paired_end_char = unilib.GetPairedBracket(end_char);
+  if (paired_end_char != end_char) {
+    // Symmetric to the above: strip the trailing bracket unless it is a
+    // closing bracket whose pair occurs inside the span.
+    if (!unilib.IsClosingBracket(end_char) ||
+        !IsCodepointInSpan(paired_end_char, context_unicode, span)) {
+      --span.second;
+    }
+  }
+
+  // Should not happen, but let's make sure.
+  if (span.first > span.second) {
+    TC3_LOG(WARNING) << "Inverse indices result: " << span.first << ", "
+                     << span.second;
+    span.second = span.first;
+  }
+
+  return span;
+}
+
+}  // namespace libtextclassifier3
diff --git a/annotator/strip-unpaired-brackets.h b/annotator/strip-unpaired-brackets.h
new file mode 100644
index 0000000..ceb8d60
--- /dev/null
+++ b/annotator/strip-unpaired-brackets.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_STRIP_UNPAIRED_BRACKETS_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_STRIP_UNPAIRED_BRACKETS_H_ + +#include <string> + +#include "annotator/types.h" +#include "utils/utf8/unilib.h" + +namespace libtextclassifier3 { +// If the first or the last codepoint of the given span is a bracket, the +// bracket is stripped if the span does not contain its corresponding paired +// version. +CodepointSpan StripUnpairedBrackets(const std::string& context, + CodepointSpan span, const UniLib& unilib); + +// Same as above but takes UnicodeText instance directly. +CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode, + CodepointSpan span, const UniLib& unilib); + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_STRIP_UNPAIRED_BRACKETS_H_ diff --git a/annotator/strip-unpaired-brackets_test.cc b/annotator/strip-unpaired-brackets_test.cc new file mode 100644 index 0000000..32585ce --- /dev/null +++ b/annotator/strip-unpaired-brackets_test.cc @@ -0,0 +1,66 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "annotator/strip-unpaired-brackets.h" + +#include "gtest/gtest.h" + +namespace libtextclassifier3 { +namespace { + +class StripUnpairedBracketsTest : public ::testing::Test { + protected: + StripUnpairedBracketsTest() : INIT_UNILIB_FOR_TESTING(unilib_) {} + UniLib unilib_; +}; + +TEST_F(StripUnpairedBracketsTest, StripUnpairedBrackets) { + // If the brackets match, nothing gets stripped. + EXPECT_EQ(StripUnpairedBrackets("call me (123) 456 today", {8, 17}, unilib_), + std::make_pair(8, 17)); + EXPECT_EQ(StripUnpairedBrackets("call me (123 456) today", {8, 17}, unilib_), + std::make_pair(8, 17)); + + // If the brackets don't match, they get stripped. + EXPECT_EQ(StripUnpairedBrackets("call me (123 456 today", {8, 16}, unilib_), + std::make_pair(9, 16)); + EXPECT_EQ(StripUnpairedBrackets("call me )123 456 today", {8, 16}, unilib_), + std::make_pair(9, 16)); + EXPECT_EQ(StripUnpairedBrackets("call me 123 456) today", {8, 16}, unilib_), + std::make_pair(8, 15)); + EXPECT_EQ(StripUnpairedBrackets("call me 123 456( today", {8, 16}, unilib_), + std::make_pair(8, 15)); + + // Strips brackets correctly from length-1 selections that consist of + // a bracket only. + EXPECT_EQ(StripUnpairedBrackets("call me at ) today", {11, 12}, unilib_), + std::make_pair(12, 12)); + EXPECT_EQ(StripUnpairedBrackets("call me at ( today", {11, 12}, unilib_), + std::make_pair(12, 12)); + + // Handles invalid spans gracefully. 
+ EXPECT_EQ(StripUnpairedBrackets("call me at today", {11, 11}, unilib_), + std::make_pair(11, 11)); + EXPECT_EQ(StripUnpairedBrackets("hello world", {0, 0}, unilib_), + std::make_pair(0, 0)); + EXPECT_EQ(StripUnpairedBrackets("hello world", {11, 11}, unilib_), + std::make_pair(11, 11)); + EXPECT_EQ(StripUnpairedBrackets("hello world", {-1, -1}, unilib_), + std::make_pair(-1, -1)); +} + +} // namespace +} // namespace libtextclassifier3 diff --git a/annotator/test_data/test_model.fb b/annotator/test_data/test_model.fb Binary files differnew file mode 100644 index 0000000..fa9cec5 --- /dev/null +++ b/annotator/test_data/test_model.fb diff --git a/annotator/test_data/test_model_cc.fb b/annotator/test_data/test_model_cc.fb Binary files differnew file mode 100644 index 0000000..b73d84f --- /dev/null +++ b/annotator/test_data/test_model_cc.fb diff --git a/annotator/test_data/wrong_embeddings.fb b/annotator/test_data/wrong_embeddings.fb Binary files differnew file mode 100644 index 0000000..ba71cdd --- /dev/null +++ b/annotator/test_data/wrong_embeddings.fb diff --git a/annotator/token-feature-extractor.cc b/annotator/token-feature-extractor.cc new file mode 100644 index 0000000..77ad7a4 --- /dev/null +++ b/annotator/token-feature-extractor.cc @@ -0,0 +1,311 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "annotator/token-feature-extractor.h" + +#include <cctype> +#include <string> + +#include "utils/base/logging.h" +#include "utils/hash/farmhash.h" +#include "utils/strings/stringpiece.h" +#include "utils/utf8/unicodetext.h" + +namespace libtextclassifier3 { + +namespace { + +std::string RemapTokenAscii(const std::string& token, + const TokenFeatureExtractorOptions& options) { + if (!options.remap_digits && !options.lowercase_tokens) { + return token; + } + + std::string copy = token; + for (int i = 0; i < token.size(); ++i) { + if (options.remap_digits && isdigit(copy[i])) { + copy[i] = '0'; + } + if (options.lowercase_tokens) { + copy[i] = tolower(copy[i]); + } + } + return copy; +} + +void RemapTokenUnicode(const std::string& token, + const TokenFeatureExtractorOptions& options, + const UniLib& unilib, UnicodeText* remapped) { + if (!options.remap_digits && !options.lowercase_tokens) { + // Leave remapped untouched. + return; + } + + UnicodeText word = UTF8ToUnicodeText(token, /*do_copy=*/false); + remapped->clear(); + for (auto it = word.begin(); it != word.end(); ++it) { + if (options.remap_digits && unilib.IsDigit(*it)) { + remapped->push_back('0'); + } else if (options.lowercase_tokens) { + remapped->push_back(unilib.ToLower(*it)); + } else { + remapped->push_back(*it); + } + } +} + +} // namespace + +TokenFeatureExtractor::TokenFeatureExtractor( + const TokenFeatureExtractorOptions& options, const UniLib& unilib) + : options_(options), unilib_(unilib) { + for (const std::string& pattern : options.regexp_features) { + regex_patterns_.push_back(std::unique_ptr<UniLib::RegexPattern>( + unilib_.CreateRegexPattern(UTF8ToUnicodeText( + pattern.c_str(), pattern.size(), /*do_copy=*/false)))); + } +} + +bool TokenFeatureExtractor::Extract(const Token& token, bool is_in_span, + std::vector<int>* sparse_features, + std::vector<float>* dense_features) const { + if (!dense_features) { + return false; + } + if (sparse_features) { + *sparse_features = 
ExtractCharactergramFeatures(token); + } + *dense_features = ExtractDenseFeatures(token, is_in_span); + return true; +} + +std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeatures( + const Token& token) const { + if (options_.unicode_aware_features) { + return ExtractCharactergramFeaturesUnicode(token); + } else { + return ExtractCharactergramFeaturesAscii(token); + } +} + +std::vector<float> TokenFeatureExtractor::ExtractDenseFeatures( + const Token& token, bool is_in_span) const { + std::vector<float> dense_features; + + if (options_.extract_case_feature) { + if (options_.unicode_aware_features) { + UnicodeText token_unicode = + UTF8ToUnicodeText(token.value, /*do_copy=*/false); + const bool is_upper = unilib_.IsUpper(*token_unicode.begin()); + if (!token.value.empty() && is_upper) { + dense_features.push_back(1.0); + } else { + dense_features.push_back(-1.0); + } + } else { + if (!token.value.empty() && isupper(*token.value.begin())) { + dense_features.push_back(1.0); + } else { + dense_features.push_back(-1.0); + } + } + } + + if (options_.extract_selection_mask_feature) { + if (is_in_span) { + dense_features.push_back(1.0); + } else { + if (options_.unicode_aware_features) { + dense_features.push_back(-1.0); + } else { + dense_features.push_back(0.0); + } + } + } + + // Add regexp features. 
+ if (!regex_patterns_.empty()) { + UnicodeText token_unicode = + UTF8ToUnicodeText(token.value, /*do_copy=*/false); + for (int i = 0; i < regex_patterns_.size(); ++i) { + if (!regex_patterns_[i].get()) { + dense_features.push_back(-1.0); + continue; + } + auto matcher = regex_patterns_[i]->Matcher(token_unicode); + int status; + if (matcher->Matches(&status)) { + dense_features.push_back(1.0); + } else { + dense_features.push_back(-1.0); + } + } + } + + return dense_features; +} + +int TokenFeatureExtractor::HashToken(StringPiece token) const { + if (options_.allowed_chargrams.empty()) { + return tc3farmhash::Fingerprint64(token) % options_.num_buckets; + } else { + // Padding and out-of-vocabulary tokens have extra buckets reserved because + // they are special and important tokens, and we don't want them to share + // embedding with other charactergrams. + // TODO(zilka): Experimentally verify. + const int kNumExtraBuckets = 2; + const std::string token_string = token.ToString(); + if (token_string == "<PAD>") { + return 1; + } else if (options_.allowed_chargrams.find(token_string) == + options_.allowed_chargrams.end()) { + return 0; // Out-of-vocabulary. + } else { + return (tc3farmhash::Fingerprint64(token) % + (options_.num_buckets - kNumExtraBuckets)) + + kNumExtraBuckets; + } + } +} + +std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesAscii( + const Token& token) const { + std::vector<int> result; + if (token.is_padding || token.value.empty()) { + result.push_back(HashToken("<PAD>")); + } else { + const std::string word = RemapTokenAscii(token.value, options_); + + // Trim words that are over max_word_length characters. + const int max_word_length = options_.max_word_length; + std::string feature_word; + if (word.size() > max_word_length) { + feature_word = + "^" + word.substr(0, max_word_length / 2) + "\1" + + word.substr(word.size() - max_word_length / 2, max_word_length / 2) + + "$"; + } else { + // Add a prefix and suffix to the word. 
+ feature_word = "^" + word + "$"; + } + + // Upper-bound the number of charactergram extracted to avoid resizing. + result.reserve(options_.chargram_orders.size() * feature_word.size()); + + if (options_.chargram_orders.empty()) { + result.push_back(HashToken(feature_word)); + } else { + // Generate the character-grams. + for (int chargram_order : options_.chargram_orders) { + if (chargram_order == 1) { + for (int i = 1; i < feature_word.size() - 1; ++i) { + result.push_back( + HashToken(StringPiece(feature_word, /*offset=*/i, /*len=*/1))); + } + } else { + for (int i = 0; + i < static_cast<int>(feature_word.size()) - chargram_order + 1; + ++i) { + result.push_back(HashToken(StringPiece(feature_word, /*offset=*/i, + /*len=*/chargram_order))); + } + } + } + } + } + return result; +} + +std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesUnicode( + const Token& token) const { + std::vector<int> result; + if (token.is_padding || token.value.empty()) { + result.push_back(HashToken("<PAD>")); + } else { + UnicodeText word = UTF8ToUnicodeText(token.value, /*do_copy=*/false); + RemapTokenUnicode(token.value, options_, unilib_, &word); + + // Trim the word if needed by finding a left-cut point and right-cut point. + auto left_cut = word.begin(); + auto right_cut = word.end(); + for (int i = 0; i < options_.max_word_length / 2; i++) { + if (left_cut < right_cut) { + ++left_cut; + } + if (left_cut < right_cut) { + --right_cut; + } + } + + std::string feature_word; + if (left_cut == right_cut) { + feature_word = "^" + word.UTF8Substring(word.begin(), word.end()) + "$"; + } else { + // clang-format off + feature_word = "^" + + word.UTF8Substring(word.begin(), left_cut) + + "\1" + + word.UTF8Substring(right_cut, word.end()) + + "$"; + // clang-format on + } + + const UnicodeText feature_word_unicode = + UTF8ToUnicodeText(feature_word, /*do_copy=*/false); + + // Upper-bound the number of charactergram extracted to avoid resizing. 
+ result.reserve(options_.chargram_orders.size() * feature_word.size()); + + if (options_.chargram_orders.empty()) { + result.push_back(HashToken(feature_word)); + } else { + // Generate the character-grams. + for (int chargram_order : options_.chargram_orders) { + UnicodeText::const_iterator it_start = feature_word_unicode.begin(); + UnicodeText::const_iterator it_end = feature_word_unicode.end(); + if (chargram_order == 1) { + ++it_start; + --it_end; + } + + UnicodeText::const_iterator it_chargram_start = it_start; + UnicodeText::const_iterator it_chargram_end = it_start; + bool chargram_is_complete = true; + for (int i = 0; i < chargram_order; ++i) { + if (it_chargram_end == it_end) { + chargram_is_complete = false; + break; + } + ++it_chargram_end; + } + if (!chargram_is_complete) { + continue; + } + + for (; it_chargram_end <= it_end; + ++it_chargram_start, ++it_chargram_end) { + const int length_bytes = + it_chargram_end.utf8_data() - it_chargram_start.utf8_data(); + result.push_back(HashToken( + StringPiece(it_chargram_start.utf8_data(), length_bytes))); + } + } + } + } + return result; +} + +} // namespace libtextclassifier3 diff --git a/annotator/token-feature-extractor.h b/annotator/token-feature-extractor.h new file mode 100644 index 0000000..7dc19fe --- /dev/null +++ b/annotator/token-feature-extractor.h @@ -0,0 +1,115 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TOKEN_FEATURE_EXTRACTOR_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_TOKEN_FEATURE_EXTRACTOR_H_
+
+#include <memory>
+#include <unordered_set>
+#include <vector>
+
+#include "annotator/types.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+struct TokenFeatureExtractorOptions {
+  // Number of buckets used for hashing charactergrams.
+  int num_buckets = 0;
+
+  // Orders of charactergrams to extract. E.g., 2 means character bigrams, 3
+  // character trigrams etc.
+  std::vector<int> chargram_orders;
+
+  // Whether to extract the token case feature.
+  bool extract_case_feature = false;
+
+  // If true, will use the unicode-aware functionality for extracting features.
+  bool unicode_aware_features = false;
+
+  // Whether to extract the selection mask feature.
+  bool extract_selection_mask_feature = false;
+
+  // Regexp features to extract.
+  std::vector<std::string> regexp_features;
+
+  // Whether to remap digits to a single number.
+  bool remap_digits = false;
+
+  // Whether to lowercase all tokens.
+  bool lowercase_tokens = false;
+
+  // Maximum length of a word.
+  int max_word_length = 20;
+
+  // List of allowed charactergrams. The extracted charactergrams are filtered
+  // using this list, and charactergrams that are not present are interpreted as
+  // out-of-vocabulary.
+  // If no allowed_chargrams are specified, all charactergrams are allowed.
+  std::unordered_set<std::string> allowed_chargrams;
+};
+
+class TokenFeatureExtractor {
+ public:
+  TokenFeatureExtractor(const TokenFeatureExtractorOptions& options,
+                        const UniLib& unilib);
+
+  // Extracts both the sparse (charactergram) and the dense features from a
+  // token. is_in_span is a bool indicator whether the token is a part of the
+  // selection span (true) or not (false).
+  // The sparse_features output is optional. Fails and returns false if
+  // dense_features is a nullptr.
+ bool Extract(const Token& token, bool is_in_span, + std::vector<int>* sparse_features, + std::vector<float>* dense_features) const; + + // Extracts the sparse (charactergram) features from the token. + std::vector<int> ExtractCharactergramFeatures(const Token& token) const; + + // Extracts the dense features from the token. is_in_span is a bool indicator + // whether the token is a part of the selection span (true) or not (false). + std::vector<float> ExtractDenseFeatures(const Token& token, + bool is_in_span) const; + + int DenseFeaturesCount() const { + int feature_count = + options_.extract_case_feature + options_.extract_selection_mask_feature; + feature_count += regex_patterns_.size(); + return feature_count; + } + + protected: + // Hashes given token to given number of buckets. + int HashToken(StringPiece token) const; + + // Extracts the charactergram features from the token in a non-unicode-aware + // way. + std::vector<int> ExtractCharactergramFeaturesAscii(const Token& token) const; + + // Extracts the charactergram features from the token in a unicode-aware way. + std::vector<int> ExtractCharactergramFeaturesUnicode( + const Token& token) const; + + private: + TokenFeatureExtractorOptions options_; + std::vector<std::unique_ptr<UniLib::RegexPattern>> regex_patterns_; + const UniLib& unilib_; +}; + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_TOKEN_FEATURE_EXTRACTOR_H_ diff --git a/annotator/token-feature-extractor_test.cc b/annotator/token-feature-extractor_test.cc new file mode 100644 index 0000000..32383a9 --- /dev/null +++ b/annotator/token-feature-extractor_test.cc @@ -0,0 +1,556 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "annotator/token-feature-extractor.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace libtextclassifier3 { +namespace { + +class TokenFeatureExtractorTest : public ::testing::Test { + protected: + TokenFeatureExtractorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {} + UniLib unilib_; +}; + +class TestingTokenFeatureExtractor : public TokenFeatureExtractor { + public: + using TokenFeatureExtractor::HashToken; + using TokenFeatureExtractor::TokenFeatureExtractor; +}; + +TEST_F(TokenFeatureExtractorTest, ExtractAscii) { + TokenFeatureExtractorOptions options; + options.num_buckets = 1000; + options.chargram_orders = std::vector<int>{1, 2, 3}; + options.extract_case_feature = true; + options.unicode_aware_features = false; + options.extract_selection_mask_feature = true; + TestingTokenFeatureExtractor extractor(options, unilib_); + + std::vector<int> sparse_features; + std::vector<float> dense_features; + + extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features, + &dense_features); + + EXPECT_THAT(sparse_features, + testing::ElementsAreArray({ + // clang-format off + extractor.HashToken("H"), + extractor.HashToken("e"), + extractor.HashToken("l"), + extractor.HashToken("l"), + extractor.HashToken("o"), + extractor.HashToken("^H"), + extractor.HashToken("He"), + extractor.HashToken("el"), + extractor.HashToken("ll"), + extractor.HashToken("lo"), + extractor.HashToken("o$"), + extractor.HashToken("^He"), + extractor.HashToken("Hel"), + extractor.HashToken("ell"), + extractor.HashToken("llo"), + 
extractor.HashToken("lo$") + // clang-format on + })); + EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0})); + + sparse_features.clear(); + dense_features.clear(); + extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features, + &dense_features); + + EXPECT_THAT(sparse_features, + testing::ElementsAreArray({ + // clang-format off + extractor.HashToken("w"), + extractor.HashToken("o"), + extractor.HashToken("r"), + extractor.HashToken("l"), + extractor.HashToken("d"), + extractor.HashToken("!"), + extractor.HashToken("^w"), + extractor.HashToken("wo"), + extractor.HashToken("or"), + extractor.HashToken("rl"), + extractor.HashToken("ld"), + extractor.HashToken("d!"), + extractor.HashToken("!$"), + extractor.HashToken("^wo"), + extractor.HashToken("wor"), + extractor.HashToken("orl"), + extractor.HashToken("rld"), + extractor.HashToken("ld!"), + extractor.HashToken("d!$"), + // clang-format on + })); + EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0})); +} + +TEST_F(TokenFeatureExtractorTest, ExtractAsciiNoChargrams) { + TokenFeatureExtractorOptions options; + options.num_buckets = 1000; + options.chargram_orders = std::vector<int>{}; + options.extract_case_feature = true; + options.unicode_aware_features = false; + options.extract_selection_mask_feature = true; + TestingTokenFeatureExtractor extractor(options, unilib_); + + std::vector<int> sparse_features; + std::vector<float> dense_features; + + extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features, + &dense_features); + + EXPECT_THAT(sparse_features, + testing::ElementsAreArray({extractor.HashToken("^Hello$")})); + EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0})); + + sparse_features.clear(); + dense_features.clear(); + extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features, + &dense_features); + + EXPECT_THAT(sparse_features, + testing::ElementsAreArray({extractor.HashToken("^world!$")})); + EXPECT_THAT(dense_features, 
testing::ElementsAreArray({-1.0, 0.0})); +} + +TEST_F(TokenFeatureExtractorTest, ExtractUnicode) { + TokenFeatureExtractorOptions options; + options.num_buckets = 1000; + options.chargram_orders = std::vector<int>{1, 2, 3}; + options.extract_case_feature = true; + options.unicode_aware_features = true; + options.extract_selection_mask_feature = true; + TestingTokenFeatureExtractor extractor(options, unilib_); + + std::vector<int> sparse_features; + std::vector<float> dense_features; + + extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features, + &dense_features); + + EXPECT_THAT(sparse_features, + testing::ElementsAreArray({ + // clang-format off + extractor.HashToken("H"), + extractor.HashToken("ě"), + extractor.HashToken("l"), + extractor.HashToken("l"), + extractor.HashToken("ó"), + extractor.HashToken("^H"), + extractor.HashToken("Hě"), + extractor.HashToken("ěl"), + extractor.HashToken("ll"), + extractor.HashToken("ló"), + extractor.HashToken("ó$"), + extractor.HashToken("^Hě"), + extractor.HashToken("Hěl"), + extractor.HashToken("ěll"), + extractor.HashToken("lló"), + extractor.HashToken("ló$") + // clang-format on + })); + EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0})); + + sparse_features.clear(); + dense_features.clear(); + extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features, + &dense_features); + + EXPECT_THAT(sparse_features, + testing::ElementsAreArray({ + // clang-format off + extractor.HashToken("w"), + extractor.HashToken("o"), + extractor.HashToken("r"), + extractor.HashToken("l"), + extractor.HashToken("d"), + extractor.HashToken("!"), + extractor.HashToken("^w"), + extractor.HashToken("wo"), + extractor.HashToken("or"), + extractor.HashToken("rl"), + extractor.HashToken("ld"), + extractor.HashToken("d!"), + extractor.HashToken("!$"), + extractor.HashToken("^wo"), + extractor.HashToken("wor"), + extractor.HashToken("orl"), + extractor.HashToken("rld"), + extractor.HashToken("ld!"), + 
extractor.HashToken("d!$"), + // clang-format on + })); + EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0})); +} + +TEST_F(TokenFeatureExtractorTest, ExtractUnicodeNoChargrams) { + TokenFeatureExtractorOptions options; + options.num_buckets = 1000; + options.chargram_orders = std::vector<int>{}; + options.extract_case_feature = true; + options.unicode_aware_features = true; + options.extract_selection_mask_feature = true; + TestingTokenFeatureExtractor extractor(options, unilib_); + + std::vector<int> sparse_features; + std::vector<float> dense_features; + + extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features, + &dense_features); + + EXPECT_THAT(sparse_features, + testing::ElementsAreArray({extractor.HashToken("^Hělló$")})); + EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0})); + + sparse_features.clear(); + dense_features.clear(); + extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features, + &dense_features); + + EXPECT_THAT(sparse_features, testing::ElementsAreArray({ + extractor.HashToken("^world!$"), + })); + EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0})); +} + +#ifdef TC3_TEST_ICU +TEST_F(TokenFeatureExtractorTest, ICUCaseFeature) { + TokenFeatureExtractorOptions options; + options.num_buckets = 1000; + options.chargram_orders = std::vector<int>{1, 2}; + options.extract_case_feature = true; + options.unicode_aware_features = true; + options.extract_selection_mask_feature = false; + TestingTokenFeatureExtractor extractor(options, unilib_); + + std::vector<int> sparse_features; + std::vector<float> dense_features; + extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features, + &dense_features); + EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0})); + + sparse_features.clear(); + dense_features.clear(); + extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features, + &dense_features); + EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0})); + + 
sparse_features.clear(); + dense_features.clear(); + extractor.Extract(Token{"Ř", 23, 29}, false, &sparse_features, + &dense_features); + EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0})); + + sparse_features.clear(); + dense_features.clear(); + extractor.Extract(Token{"ř", 23, 29}, false, &sparse_features, + &dense_features); + EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0})); +} +#endif + +TEST_F(TokenFeatureExtractorTest, DigitRemapping) { + TokenFeatureExtractorOptions options; + options.num_buckets = 1000; + options.chargram_orders = std::vector<int>{1, 2}; + options.remap_digits = true; + options.unicode_aware_features = false; + TestingTokenFeatureExtractor extractor(options, unilib_); + + std::vector<int> sparse_features; + std::vector<float> dense_features; + extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features, + &dense_features); + + std::vector<int> sparse_features2; + extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2, + &dense_features); + EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2)); + + extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2, + &dense_features); + EXPECT_THAT(sparse_features, + testing::Not(testing::ElementsAreArray(sparse_features2))); +} + +TEST_F(TokenFeatureExtractorTest, DigitRemappingUnicode) { + TokenFeatureExtractorOptions options; + options.num_buckets = 1000; + options.chargram_orders = std::vector<int>{1, 2}; + options.remap_digits = true; + options.unicode_aware_features = true; + TestingTokenFeatureExtractor extractor(options, unilib_); + + std::vector<int> sparse_features; + std::vector<float> dense_features; + extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features, + &dense_features); + + std::vector<int> sparse_features2; + extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2, + &dense_features); + EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2)); + + 
extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2, + &dense_features); + EXPECT_THAT(sparse_features, + testing::Not(testing::ElementsAreArray(sparse_features2))); +} + +TEST_F(TokenFeatureExtractorTest, LowercaseAscii) { + TokenFeatureExtractorOptions options; + options.num_buckets = 1000; + options.chargram_orders = std::vector<int>{1, 2}; + options.lowercase_tokens = true; + options.unicode_aware_features = false; + TestingTokenFeatureExtractor extractor(options, unilib_); + + std::vector<int> sparse_features; + std::vector<float> dense_features; + extractor.Extract(Token{"AABB", 0, 6}, true, &sparse_features, + &dense_features); + + std::vector<int> sparse_features2; + extractor.Extract(Token{"aaBB", 0, 6}, true, &sparse_features2, + &dense_features); + EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2)); + + extractor.Extract(Token{"aAbB", 0, 6}, true, &sparse_features2, + &dense_features); + EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2)); +} + +#ifdef TC3_TEST_ICU +TEST_F(TokenFeatureExtractorTest, LowercaseUnicode) { + TokenFeatureExtractorOptions options; + options.num_buckets = 1000; + options.chargram_orders = std::vector<int>{1, 2}; + options.lowercase_tokens = true; + options.unicode_aware_features = true; + TestingTokenFeatureExtractor extractor(options, unilib_); + + std::vector<int> sparse_features; + std::vector<float> dense_features; + extractor.Extract(Token{"ŘŘ", 0, 6}, true, &sparse_features, &dense_features); + + std::vector<int> sparse_features2; + extractor.Extract(Token{"řř", 0, 6}, true, &sparse_features2, + &dense_features); + EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2)); +} +#endif + +#ifdef TC3_TEST_ICU +TEST_F(TokenFeatureExtractorTest, RegexFeatures) { + TokenFeatureExtractorOptions options; + options.num_buckets = 1000; + options.chargram_orders = std::vector<int>{1, 2}; + options.remap_digits = false; + options.unicode_aware_features = 
false; + options.regexp_features.push_back("^[a-z]+$"); // all lower case. + options.regexp_features.push_back("^[0-9]+$"); // all digits. + TestingTokenFeatureExtractor extractor(options, unilib_); + + std::vector<int> sparse_features; + std::vector<float> dense_features; + extractor.Extract(Token{"abCde", 0, 6}, true, &sparse_features, + &dense_features); + EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0})); + + dense_features.clear(); + extractor.Extract(Token{"abcde", 0, 6}, true, &sparse_features, + &dense_features); + EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, -1.0})); + + dense_features.clear(); + extractor.Extract(Token{"12c45", 0, 6}, true, &sparse_features, + &dense_features); + EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0})); + + dense_features.clear(); + extractor.Extract(Token{"12345", 0, 6}, true, &sparse_features, + &dense_features); + EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 1.0})); +} +#endif + +TEST_F(TokenFeatureExtractorTest, ExtractTooLongWord) { + TokenFeatureExtractorOptions options; + options.num_buckets = 1000; + options.chargram_orders = std::vector<int>{22}; + options.extract_case_feature = true; + options.unicode_aware_features = true; + options.extract_selection_mask_feature = true; + TestingTokenFeatureExtractor extractor(options, unilib_); + + // Test that this runs. ASAN should catch problems. 
+ std::vector<int> sparse_features; + std::vector<float> dense_features; + extractor.Extract(Token{"abcdefghijklmnopqřstuvwxyz", 0, 0}, true, + &sparse_features, &dense_features); + + EXPECT_THAT(sparse_features, + testing::ElementsAreArray({ + // clang-format off + extractor.HashToken("^abcdefghij\1qřstuvwxyz"), + extractor.HashToken("abcdefghij\1qřstuvwxyz$"), + // clang-format on + })); +} + +TEST_F(TokenFeatureExtractorTest, ExtractAsciiUnicodeMatches) { + TokenFeatureExtractorOptions options; + options.num_buckets = 1000; + options.chargram_orders = std::vector<int>{1, 2, 3, 4, 5}; + options.extract_case_feature = true; + options.unicode_aware_features = true; + options.extract_selection_mask_feature = true; + + TestingTokenFeatureExtractor extractor_unicode(options, unilib_); + + options.unicode_aware_features = false; + TestingTokenFeatureExtractor extractor_ascii(options, unilib_); + + for (const std::string& input : + {"https://www.abcdefgh.com/in/xxxkkkvayio", + "https://www.fjsidofj.om/xx/abadfy/xxxx/?xfjiis=ffffiijiihil", + "asdfhasdofjiasdofj#%()*%#*(aisdojfaosdifjiaofjdsiofjdi_fdis3w", "abcd", + "x", "Hello", "Hey,", "Hi", ""}) { + std::vector<int> sparse_features_unicode; + std::vector<float> dense_features_unicode; + extractor_unicode.Extract(Token{input, 0, 0}, true, + &sparse_features_unicode, + &dense_features_unicode); + + std::vector<int> sparse_features_ascii; + std::vector<float> dense_features_ascii; + extractor_ascii.Extract(Token{input, 0, 0}, true, &sparse_features_ascii, + &dense_features_ascii); + + EXPECT_THAT(sparse_features_unicode, sparse_features_ascii) << input; + EXPECT_THAT(dense_features_unicode, dense_features_ascii) << input; + } +} + +TEST_F(TokenFeatureExtractorTest, ExtractForPadToken) { + TokenFeatureExtractorOptions options; + options.num_buckets = 1000; + options.chargram_orders = std::vector<int>{1, 2}; + options.extract_case_feature = true; + options.unicode_aware_features = false; + 
options.extract_selection_mask_feature = true;
+
+  TestingTokenFeatureExtractor extractor(options, unilib_);
+
+  std::vector<int> sparse_features;
+  std::vector<float> dense_features;
+
+  extractor.Extract(Token(), false, &sparse_features, &dense_features);
+
+  EXPECT_THAT(sparse_features,
+              testing::ElementsAreArray({extractor.HashToken("<PAD>")}));
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
+}
+
+TEST_F(TokenFeatureExtractorTest, ExtractFiltered) {
+  TokenFeatureExtractorOptions options;
+  options.num_buckets = 1000;
+  options.chargram_orders = std::vector<int>{1, 2, 3};
+  options.extract_case_feature = true;
+  options.unicode_aware_features = false;
+  options.extract_selection_mask_feature = true;
+  options.allowed_chargrams.insert("^H");
+  options.allowed_chargrams.insert("ll");
+  options.allowed_chargrams.insert("llo");
+  options.allowed_chargrams.insert("w");
+  options.allowed_chargrams.insert("!");
+  options.allowed_chargrams.insert("\xc4");  // UTF-8 lead byte (first byte of "ě").
+ + TestingTokenFeatureExtractor extractor(options, unilib_); + + std::vector<int> sparse_features; + std::vector<float> dense_features; + + extractor.Extract(Token{"Hěllo", 0, 5}, true, &sparse_features, + &dense_features); + + EXPECT_THAT(sparse_features, + testing::ElementsAreArray({ + // clang-format off + 0, + extractor.HashToken("\xc4"), + 0, + 0, + 0, + 0, + extractor.HashToken("^H"), + 0, + 0, + 0, + extractor.HashToken("ll"), + 0, + 0, + 0, + 0, + 0, + 0, + extractor.HashToken("llo"), + 0 + // clang-format on + })); + EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0})); + + sparse_features.clear(); + dense_features.clear(); + extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features, + &dense_features); + + EXPECT_THAT(sparse_features, testing::ElementsAreArray({ + // clang-format off + extractor.HashToken("w"), + 0, + 0, + 0, + 0, + extractor.HashToken("!"), + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + // clang-format on + })); + EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0})); + EXPECT_EQ(extractor.HashToken("<PAD>"), 1); +} + +} // namespace +} // namespace libtextclassifier3 diff --git a/annotator/tokenizer.cc b/annotator/tokenizer.cc new file mode 100644 index 0000000..099dccc --- /dev/null +++ b/annotator/tokenizer.cc @@ -0,0 +1,126 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "annotator/tokenizer.h" + +#include <algorithm> + +#include "utils/base/logging.h" +#include "utils/strings/utf8.h" + +namespace libtextclassifier3 { + +Tokenizer::Tokenizer( + const std::vector<const TokenizationCodepointRange*>& codepoint_ranges, + bool split_on_script_change) + : split_on_script_change_(split_on_script_change) { + for (const TokenizationCodepointRange* range : codepoint_ranges) { + codepoint_ranges_.emplace_back(range->UnPack()); + } + + std::sort(codepoint_ranges_.begin(), codepoint_ranges_.end(), + [](const std::unique_ptr<const TokenizationCodepointRangeT>& a, + const std::unique_ptr<const TokenizationCodepointRangeT>& b) { + return a->start < b->start; + }); +} + +const TokenizationCodepointRangeT* Tokenizer::FindTokenizationRange( + int codepoint) const { + auto it = std::lower_bound( + codepoint_ranges_.begin(), codepoint_ranges_.end(), codepoint, + [](const std::unique_ptr<const TokenizationCodepointRangeT>& range, + int codepoint) { + // This function compares range with the codepoint for the purpose of + // finding the first greater or equal range. Because of the use of + // std::lower_bound it needs to return true when range < codepoint; + // the first time it will return false the lower bound is found and + // returned. + // + // It might seem weird that the condition is range.end <= codepoint + // here but when codepoint == range.end it means it's actually just + // outside of the range, thus the range is less than the codepoint. 
+ return range->end <= codepoint; + }); + if (it != codepoint_ranges_.end() && (*it)->start <= codepoint && + (*it)->end > codepoint) { + return it->get(); + } else { + return nullptr; + } +} + +void Tokenizer::GetScriptAndRole(char32 codepoint, + TokenizationCodepointRange_::Role* role, + int* script) const { + const TokenizationCodepointRangeT* range = FindTokenizationRange(codepoint); + if (range) { + *role = range->role; + *script = range->script_id; + } else { + *role = TokenizationCodepointRange_::Role_DEFAULT_ROLE; + *script = kUnknownScript; + } +} + +std::vector<Token> Tokenizer::Tokenize(const std::string& text) const { + UnicodeText text_unicode = UTF8ToUnicodeText(text, /*do_copy=*/false); + return Tokenize(text_unicode); +} + +std::vector<Token> Tokenizer::Tokenize(const UnicodeText& text_unicode) const { + std::vector<Token> result; + Token new_token("", 0, 0); + int codepoint_index = 0; + + int last_script = kInvalidScript; + for (auto it = text_unicode.begin(); it != text_unicode.end(); + ++it, ++codepoint_index) { + TokenizationCodepointRange_::Role role; + int script; + GetScriptAndRole(*it, &role, &script); + + if (role & TokenizationCodepointRange_::Role_SPLIT_BEFORE || + (split_on_script_change_ && last_script != kInvalidScript && + last_script != script)) { + if (!new_token.value.empty()) { + result.push_back(new_token); + } + new_token = Token("", codepoint_index, codepoint_index); + } + if (!(role & TokenizationCodepointRange_::Role_DISCARD_CODEPOINT)) { + new_token.value += std::string( + it.utf8_data(), + it.utf8_data() + GetNumBytesForNonZeroUTF8Char(it.utf8_data())); + ++new_token.end; + } + if (role & TokenizationCodepointRange_::Role_SPLIT_AFTER) { + if (!new_token.value.empty()) { + result.push_back(new_token); + } + new_token = Token("", codepoint_index + 1, codepoint_index + 1); + } + + last_script = script; + } + if (!new_token.value.empty()) { + result.push_back(new_token); + } + + return result; +} + +} // namespace 
libtextclassifier3 diff --git a/annotator/tokenizer.h b/annotator/tokenizer.h new file mode 100644 index 0000000..ec33f2d --- /dev/null +++ b/annotator/tokenizer.h @@ -0,0 +1,71 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TOKENIZER_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_TOKENIZER_H_ + +#include <string> +#include <vector> + +#include "annotator/model_generated.h" +#include "annotator/types.h" +#include "utils/base/integral_types.h" +#include "utils/utf8/unicodetext.h" + +namespace libtextclassifier3 { + +const int kInvalidScript = -1; +const int kUnknownScript = -2; + +// Tokenizer splits the input string into a sequence of tokens, according to the +// configuration. +class Tokenizer { + public: + explicit Tokenizer( + const std::vector<const TokenizationCodepointRange*>& codepoint_ranges, + bool split_on_script_change); + + // Tokenizes the input string using the selected tokenization method. + std::vector<Token> Tokenize(const std::string& text) const; + + // Same as above but takes UnicodeText. + std::vector<Token> Tokenize(const UnicodeText& text_unicode) const; + + protected: + // Finds the tokenization codepoint range config for given codepoint. + // Internally uses binary search so should be O(log(# of codepoint_ranges)). 
+ const TokenizationCodepointRangeT* FindTokenizationRange(int codepoint) const; + + // Finds the role and script for given codepoint. If not found, DEFAULT_ROLE + // and kUnknownScript are assigned. + void GetScriptAndRole(char32 codepoint, + TokenizationCodepointRange_::Role* role, + int* script) const; + + private: + // Codepoint ranges that determine how different codepoints are tokenized. + // The ranges must not overlap. + std::vector<std::unique_ptr<const TokenizationCodepointRangeT>> + codepoint_ranges_; + + // If true, tokens will be additionally split when the codepoint's script_id + // changes. + bool split_on_script_change_; +}; + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_TOKENIZER_H_ diff --git a/annotator/tokenizer_test.cc b/annotator/tokenizer_test.cc new file mode 100644 index 0000000..a3ab9da --- /dev/null +++ b/annotator/tokenizer_test.cc @@ -0,0 +1,334 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "annotator/tokenizer.h" + +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace libtextclassifier3 { +namespace { + +using testing::ElementsAreArray; + +class TestingTokenizer : public Tokenizer { + public: + explicit TestingTokenizer( + const std::vector<const TokenizationCodepointRange*>& + codepoint_range_configs, + bool split_on_script_change) + : Tokenizer(codepoint_range_configs, split_on_script_change) {} + + using Tokenizer::FindTokenizationRange; +}; + +class TestingTokenizerProxy { + public: + explicit TestingTokenizerProxy( + const std::vector<TokenizationCodepointRangeT>& codepoint_range_configs, + bool split_on_script_change) { + int num_configs = codepoint_range_configs.size(); + std::vector<const TokenizationCodepointRange*> configs_fb; + buffers_.reserve(num_configs); + for (int i = 0; i < num_configs; i++) { + flatbuffers::FlatBufferBuilder builder; + builder.Finish(CreateTokenizationCodepointRange( + builder, &codepoint_range_configs[i])); + buffers_.push_back(builder.Release()); + configs_fb.push_back( + flatbuffers::GetRoot<TokenizationCodepointRange>(buffers_[i].data())); + } + tokenizer_ = std::unique_ptr<TestingTokenizer>( + new TestingTokenizer(configs_fb, split_on_script_change)); + } + + TokenizationCodepointRange_::Role TestFindTokenizationRole(int c) const { + const TokenizationCodepointRangeT* range = + tokenizer_->FindTokenizationRange(c); + if (range != nullptr) { + return range->role; + } else { + return TokenizationCodepointRange_::Role_DEFAULT_ROLE; + } + } + + std::vector<Token> Tokenize(const std::string& utf8_text) const { + return tokenizer_->Tokenize(utf8_text); + } + + private: + std::vector<flatbuffers::DetachedBuffer> buffers_; + std::unique_ptr<TestingTokenizer> tokenizer_; +}; + +TEST(TokenizerTest, FindTokenizationRange) { + std::vector<TokenizationCodepointRangeT> configs; + TokenizationCodepointRangeT* config; + + configs.emplace_back(); + config = &configs.back(); + 
  // Range [0, 10): token separators.
  config->start = 0;
  config->end = 10;
  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;

  // Range [32, 33): whitespace separator (the ASCII space).
  configs.emplace_back();
  config = &configs.back();
  config->start = 32;
  config->end = 33;
  config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;

  // Range [1234, 12345): token separators.
  configs.emplace_back();
  config = &configs.back();
  config->start = 1234;
  config->end = 12345;
  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;

  TestingTokenizerProxy tokenizer(configs, /*split_on_script_change=*/false);

  // Test hits to the first group. Ranges are half-open, so 10 falls outside.
  EXPECT_EQ(tokenizer.TestFindTokenizationRole(0),
            TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
  EXPECT_EQ(tokenizer.TestFindTokenizationRole(5),
            TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
  EXPECT_EQ(tokenizer.TestFindTokenizationRole(10),
            TokenizationCodepointRange_::Role_DEFAULT_ROLE);

  // Test a hit to the second group (single-codepoint range).
  EXPECT_EQ(tokenizer.TestFindTokenizationRole(31),
            TokenizationCodepointRange_::Role_DEFAULT_ROLE);
  EXPECT_EQ(tokenizer.TestFindTokenizationRole(32),
            TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR);
  EXPECT_EQ(tokenizer.TestFindTokenizationRole(33),
            TokenizationCodepointRange_::Role_DEFAULT_ROLE);

  // Test hits to the third group, including both half-open boundaries.
  EXPECT_EQ(tokenizer.TestFindTokenizationRole(1233),
            TokenizationCodepointRange_::Role_DEFAULT_ROLE);
  EXPECT_EQ(tokenizer.TestFindTokenizationRole(1234),
            TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
  EXPECT_EQ(tokenizer.TestFindTokenizationRole(12344),
            TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
  EXPECT_EQ(tokenizer.TestFindTokenizationRole(12345),
            TokenizationCodepointRange_::Role_DEFAULT_ROLE);

  // Test a hit outside all configured ranges (in the gap between groups).
  EXPECT_EQ(tokenizer.TestFindTokenizationRole(99),
            TokenizationCodepointRange_::Role_DEFAULT_ROLE);
}

TEST(TokenizerTest, TokenizeOnSpace) {
  std::vector<TokenizationCodepointRangeT> configs;
  TokenizationCodepointRangeT* config;

  configs.emplace_back();
  config = &configs.back();
  // Space character.
  config->start = 32;
  config->end = 33;
  config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;

  TestingTokenizerProxy tokenizer(configs, /*split_on_script_change=*/false);
  std::vector<Token> tokens = tokenizer.Tokenize("Hello world!");

  // Token spans are codepoint indices; the separator itself is dropped.
  EXPECT_THAT(tokens,
              ElementsAreArray({Token("Hello", 0, 5), Token("world!", 6, 12)}));
}

TEST(TokenizerTest, TokenizeOnSpaceAndScriptChange) {
  std::vector<TokenizationCodepointRangeT> configs;
  TokenizationCodepointRangeT* config;

  // Latin. All three ranges carry script_id 1, so within this config only the
  // implicit "unknown script" of unconfigured codepoints (e.g. Hangul) can
  // trigger a script-change split.
  configs.emplace_back();
  config = &configs.back();
  config->start = 0;
  config->end = 32;
  config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
  config->script_id = 1;
  configs.emplace_back();
  config = &configs.back();
  config->start = 32;
  config->end = 33;
  config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
  config->script_id = 1;
  configs.emplace_back();
  config = &configs.back();
  config->start = 33;
  config->end = 0x77F + 1;
  config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
  config->script_id = 1;

  TestingTokenizerProxy tokenizer(configs, /*split_on_script_change=*/true);
  // Tokens split on spaces AND on Hangul<->Latin script boundaries.
  EXPECT_THAT(tokenizer.Tokenize("앨라배마 주 전화(123) 456-789웹사이트"),
              std::vector<Token>({Token("앨라배마", 0, 4), Token("주", 5, 6),
                                  Token("전화", 7, 10), Token("(123)", 10, 15),
                                  Token("456-789", 16, 23),
                                  Token("웹사이트", 23, 28)}));
}

TEST(TokenizerTest, TokenizeComplex) {
  std::vector<TokenizationCodepointRangeT> configs;
  TokenizationCodepointRangeT* config;

  // Source: http://www.unicode.org/Public/10.0.0/ucd/Blocks-10.0.0d1.txt
  // Latin - Cyrillic.
  // 0000..007F; Basic Latin
  // 0080..00FF; Latin-1 Supplement
  // 0100..017F; Latin Extended-A
  // 0180..024F; Latin Extended-B
  // 0250..02AF; IPA Extensions
  // 02B0..02FF; Spacing Modifier Letters
  // 0300..036F; Combining Diacritical Marks
  // 0370..03FF; Greek and Coptic
  // 0400..04FF; Cyrillic
  // 0500..052F; Cyrillic Supplement
  // 0530..058F; Armenian
  // 0590..05FF; Hebrew
  // 0600..06FF; Arabic
  // 0700..074F; Syriac
  // 0750..077F; Arabic Supplement
  configs.emplace_back();
  config = &configs.back();
  config->start = 0;
  config->end = 32;
  config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
  configs.emplace_back();
  config = &configs.back();
  config->start = 32;
  config->end = 33;
  config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
  configs.emplace_back();
  config = &configs.back();
  config->start = 33;
  config->end = 0x77F + 1;
  config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;

  // CJK blocks: every codepoint is a TOKEN_SEPARATOR, i.e. each CJK character
  // becomes its own single-character token.
  // 2E80..2EFF; CJK Radicals Supplement
  // 3000..303F; CJK Symbols and Punctuation
  // 3040..309F; Hiragana
  // 30A0..30FF; Katakana
  // 3100..312F; Bopomofo
  // 3130..318F; Hangul Compatibility Jamo
  // 3190..319F; Kanbun
  // 31A0..31BF; Bopomofo Extended
  // 31C0..31EF; CJK Strokes
  // 31F0..31FF; Katakana Phonetic Extensions
  // 3200..32FF; Enclosed CJK Letters and Months
  // 3300..33FF; CJK Compatibility
  // 3400..4DBF; CJK Unified Ideographs Extension A
  // 4DC0..4DFF; Yijing Hexagram Symbols
  // 4E00..9FFF; CJK Unified Ideographs
  // A000..A48F; Yi Syllables
  // A490..A4CF; Yi Radicals
  // A4D0..A4FF; Lisu
  // A500..A63F; Vai
  // F900..FAFF; CJK Compatibility Ideographs
  // FE30..FE4F; CJK Compatibility Forms
  // 20000..2A6DF; CJK Unified Ideographs Extension B
  // 2A700..2B73F; CJK Unified Ideographs Extension C
  // 2B740..2B81F; CJK Unified Ideographs Extension D
  // 2B820..2CEAF; CJK Unified Ideographs Extension E
  // 2CEB0..2EBEF; CJK Unified Ideographs Extension F
  // 2F800..2FA1F; CJK Compatibility Ideographs Supplement
  configs.emplace_back();
  config = &configs.back();
  config->start = 0x2E80;
  config->end = 0x2EFF + 1;
  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
  configs.emplace_back();
  config = &configs.back();
  config->start = 0x3000;
  config->end = 0xA63F + 1;
  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
  configs.emplace_back();
  config = &configs.back();
  config->start = 0xF900;
  config->end = 0xFAFF + 1;
  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
  configs.emplace_back();
  config = &configs.back();
  config->start = 0xFE30;
  config->end = 0xFE4F + 1;
  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
  configs.emplace_back();
  config = &configs.back();
  config->start = 0x20000;
  config->end = 0x2A6DF + 1;
  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
  configs.emplace_back();
  config = &configs.back();
  config->start = 0x2A700;
  config->end = 0x2B73F + 1;
  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
  configs.emplace_back();
  config = &configs.back();
  config->start = 0x2B740;
  config->end = 0x2B81F + 1;
  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
  configs.emplace_back();
  config = &configs.back();
  config->start = 0x2B820;
  config->end = 0x2CEAF + 1;
  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
  configs.emplace_back();
  config = &configs.back();
  config->start = 0x2CEB0;
  config->end = 0x2EBEF + 1;
  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
  configs.emplace_back();
  config = &configs.back();
  config->start = 0x2F800;
  config->end = 0x2FA1F + 1;
  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;

  // Thai: also tokenized character-by-character.
  // 0E00..0E7F; Thai
  configs.emplace_back();
  config = &configs.back();
  config->start = 0x0E00;
  config->end = 0x0E7F + 1;
  config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;

  TestingTokenizerProxy tokenizer(configs, /*split_on_script_change=*/false);
  std::vector<Token> tokens;

  // 30 CJK characters -> 30 single-character tokens.
  tokens = tokenizer.Tokenize(
      "問少目木輸走猶術権自京門録球変。細開括省用掲情結傍走愛明氷。");
  EXPECT_EQ(tokens.size(), 30);

  // Mixed CJK / Latin / Thai / kana input: Latin stays word-tokenized, the
  // rest splits per character. Spans are codepoint indices.
  tokens = tokenizer.Tokenize("問少目 hello 木輸ยามきゃ");
  // clang-format off
  EXPECT_THAT(
      tokens,
      ElementsAreArray({Token("問", 0, 1),
                        Token("少", 1, 2),
                        Token("目", 2, 3),
                        Token("hello", 4, 9),
                        Token("木", 10, 11),
                        Token("輸", 11, 12),
                        Token("ย", 12, 13),
                        Token("า", 13, 14),
                        Token("ม", 14, 15),
                        Token("き", 15, 16),
                        Token("ゃ", 16, 17)}));
  // clang-format on
}

}  // namespace
}  // namespace libtextclassifier3
+ */ + +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_TEST_UTIL_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_TEST_UTIL_H_ + +#include <ostream> + +#include "annotator/types.h" +#include "utils/base/logging.h" + +namespace libtextclassifier3 { + +inline std::ostream& operator<<(std::ostream& stream, const Token& value) { + logging::LoggingStringStream tmp_stream; + tmp_stream << value; + return stream << tmp_stream.message; +} + +inline std::ostream& operator<<(std::ostream& stream, + const AnnotatedSpan& value) { + logging::LoggingStringStream tmp_stream; + tmp_stream << value; + return stream << tmp_stream.message; +} + +inline std::ostream& operator<<(std::ostream& stream, + const DatetimeParseResultSpan& value) { + logging::LoggingStringStream tmp_stream; + tmp_stream << value; + return stream << tmp_stream.message; +} + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_TEST_UTIL_H_ diff --git a/annotator/types.h b/annotator/types.h new file mode 100644 index 0000000..38bce41 --- /dev/null +++ b/annotator/types.h @@ -0,0 +1,402 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
#define LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_

#include <algorithm>
#include <cmath>
#include <functional>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "utils/base/integral_types.h"
#include "utils/base/logging.h"
#include "utils/variant.h"

namespace libtextclassifier3 {

constexpr int kInvalidIndex = -1;

// Index for a 0-based array of tokens.
using TokenIndex = int;

// Index for a 0-based array of codepoints.
using CodepointIndex = int;

// Marks a span in a sequence of codepoints. The first element is the index of
// the first codepoint of the span, and the second element is the index of the
// codepoint one past the end of the span (i.e. the span is half-open).
// TODO(b/71982294): Make it a struct.
using CodepointSpan = std::pair<CodepointIndex, CodepointIndex>;

// Returns true iff the two half-open spans share at least one codepoint.
inline bool SpansOverlap(const CodepointSpan& a, const CodepointSpan& b) {
  return a.first < b.second && b.first < a.second;
}

// Returns true iff the span is well-ordered, non-empty and non-negative.
inline bool ValidNonEmptySpan(const CodepointSpan& span) {
  return span.first < span.second && span.first >= 0 && span.second >= 0;
}

// Returns true iff candidates[considered_candidate].span overlaps the span of
// any candidate already selected in chosen_indices_set. Because the set is
// ordered (by a caller-supplied comparator — presumably one consistent with
// span order; confirm at call sites), only the nearest neighbor on each side
// needs to be checked.
template <typename T>
bool DoesCandidateConflict(
    const int considered_candidate, const std::vector<T>& candidates,
    const std::set<int, std::function<bool(int, int)>>& chosen_indices_set) {
  if (chosen_indices_set.empty()) {
    return false;
  }

  auto conflicting_it = chosen_indices_set.lower_bound(considered_candidate);
  // Check conflict on the right.
  if (conflicting_it != chosen_indices_set.end() &&
      SpansOverlap(candidates[considered_candidate].span,
                   candidates[*conflicting_it].span)) {
    return true;
  }

  // Check conflict on the left.
  // If we can't go more left, there can't be a conflict:
  if (conflicting_it == chosen_indices_set.begin()) {
    return false;
  }
  // Otherwise move one position left and check whether that chosen span
  // overlaps with the candidate.
  --conflicting_it;
  if (!SpansOverlap(candidates[considered_candidate].span,
                    candidates[*conflicting_it].span)) {
    return false;
  }

  return true;
}

// Marks a span in a sequence of tokens. The first element is the index of the
// first token in the span, and the second element is the index of the token one
// past the end of the span.
// TODO(b/71982294): Make it a struct.
using TokenSpan = std::pair<TokenIndex, TokenIndex>;

// Returns the size of the token span. Assumes that the span is valid.
inline int TokenSpanSize(const TokenSpan& token_span) {
  return token_span.second - token_span.first;
}

// Returns a token span consisting of one token.
inline TokenSpan SingleTokenSpan(int token_index) {
  return {token_index, token_index + 1};
}

// Returns an intersection of two token spans. Assumes that both spans are
// valid and overlapping.
inline TokenSpan IntersectTokenSpans(const TokenSpan& token_span1,
                                     const TokenSpan& token_span2) {
  return {std::max(token_span1.first, token_span2.first),
          std::min(token_span1.second, token_span2.second)};
}

// Returns an expanded token span by adding a certain number of tokens on its
// left and on its right.
inline TokenSpan ExpandTokenSpan(const TokenSpan& token_span,
                                 int num_tokens_left, int num_tokens_right) {
  return {token_span.first - num_tokens_left,
          token_span.second + num_tokens_right};
}

// Token holds a token, its position in the original string and whether it was
// part of the input span.
struct Token {
  // UTF-8 text of the token.
  std::string value;
  // Codepoint (not byte) indices into the original string; [start, end).
  CodepointIndex start;
  CodepointIndex end;

  // Whether the token is a padding token.
  bool is_padding;

  // Default constructor constructs the padding-token.
+ Token() + : value(""), start(kInvalidIndex), end(kInvalidIndex), is_padding(true) {} + + Token(const std::string& arg_value, CodepointIndex arg_start, + CodepointIndex arg_end) + : value(arg_value), start(arg_start), end(arg_end), is_padding(false) {} + + bool operator==(const Token& other) const { + return value == other.value && start == other.start && end == other.end && + is_padding == other.is_padding; + } + + bool IsContainedInSpan(CodepointSpan span) const { + return start >= span.first && end <= span.second; + } +}; + +// Pretty-printing function for Token. +inline logging::LoggingStringStream& operator<<( + logging::LoggingStringStream& stream, const Token& token) { + if (!token.is_padding) { + return stream << "Token(\"" << token.value << "\", " << token.start << ", " + << token.end << ")"; + } else { + return stream << "Token()"; + } +} + +enum DatetimeGranularity { + GRANULARITY_UNKNOWN = -1, // GRANULARITY_UNKNOWN is used as a proxy for this + // structure being uninitialized. + GRANULARITY_YEAR = 0, + GRANULARITY_MONTH = 1, + GRANULARITY_WEEK = 2, + GRANULARITY_DAY = 3, + GRANULARITY_HOUR = 4, + GRANULARITY_MINUTE = 5, + GRANULARITY_SECOND = 6 +}; + +struct DatetimeParseResult { + // The absolute time in milliseconds since the epoch in UTC. This is derived + // from the reference time and the fields specified in the text - so it may + // be imperfect where the time was ambiguous. (e.g. 
"at 7:30" may be am or pm) + int64 time_ms_utc; + + // The precision of the estimate then in to calculating the milliseconds + DatetimeGranularity granularity; + + DatetimeParseResult() : time_ms_utc(0), granularity(GRANULARITY_UNKNOWN) {} + + DatetimeParseResult(int64 arg_time_ms_utc, + DatetimeGranularity arg_granularity) + : time_ms_utc(arg_time_ms_utc), granularity(arg_granularity) {} + + bool IsSet() const { return granularity != GRANULARITY_UNKNOWN; } + + bool operator==(const DatetimeParseResult& other) const { + return granularity == other.granularity && time_ms_utc == other.time_ms_utc; + } +}; + +const float kFloatCompareEpsilon = 1e-5; + +struct DatetimeParseResultSpan { + CodepointSpan span; + DatetimeParseResult data; + float target_classification_score; + float priority_score; + + bool operator==(const DatetimeParseResultSpan& other) const { + return span == other.span && data.granularity == other.data.granularity && + data.time_ms_utc == other.data.time_ms_utc && + std::abs(target_classification_score - + other.target_classification_score) < kFloatCompareEpsilon && + std::abs(priority_score - other.priority_score) < + kFloatCompareEpsilon; + } +}; + +// Pretty-printing function for DatetimeParseResultSpan. +inline logging::LoggingStringStream& operator<<( + logging::LoggingStringStream& stream, + const DatetimeParseResultSpan& value) { + return stream << "DatetimeParseResultSpan({" << value.span.first << ", " + << value.span.second << "}, {/*time_ms_utc=*/ " + << value.data.time_ms_utc << ", /*granularity=*/ " + << value.data.granularity << "})"; +} + +struct ClassificationResult { + std::string collection; + float score; + DatetimeParseResult datetime_parse_result; + std::string serialized_knowledge_result; + + // Internal score used for conflict resolution. + float priority_score; + + // Extra information. 
+ std::map<std::string, Variant> extra; + + explicit ClassificationResult() : score(-1.0f), priority_score(-1.0) {} + + ClassificationResult(const std::string& arg_collection, float arg_score) + : collection(arg_collection), + score(arg_score), + priority_score(arg_score) {} + + ClassificationResult(const std::string& arg_collection, float arg_score, + float arg_priority_score) + : collection(arg_collection), + score(arg_score), + priority_score(arg_priority_score) {} +}; + +// Pretty-printing function for ClassificationResult. +inline logging::LoggingStringStream& operator<<( + logging::LoggingStringStream& stream, const ClassificationResult& result) { + return stream << "ClassificationResult(" << result.collection << ", " + << result.score << ")"; +} + +// Pretty-printing function for std::vector<ClassificationResult>. +inline logging::LoggingStringStream& operator<<( + logging::LoggingStringStream& stream, + const std::vector<ClassificationResult>& results) { + stream = stream << "{\n"; + for (const ClassificationResult& result : results) { + stream = stream << " " << result << "\n"; + } + stream = stream << "}"; + return stream; +} + +// Represents a result of Annotate call. +struct AnnotatedSpan { + // Unicode codepoint indices in the input string. + CodepointSpan span = {kInvalidIndex, kInvalidIndex}; + + // Classification result for the span. + std::vector<ClassificationResult> classification; +}; + +// Pretty-printing function for AnnotatedSpan. +inline logging::LoggingStringStream& operator<<( + logging::LoggingStringStream& stream, const AnnotatedSpan& span) { + std::string best_class; + float best_score = -1; + if (!span.classification.empty()) { + best_class = span.classification[0].collection; + best_score = span.classification[0].score; + } + return stream << "Span(" << span.span.first << ", " << span.span.second + << ", " << best_class << ", " << best_score << ")"; +} + +// StringPiece analogue for std::vector<T>. 
+template <class T> +class VectorSpan { + public: + VectorSpan() : begin_(), end_() {} + VectorSpan(const std::vector<T>& v) // NOLINT(runtime/explicit) + : begin_(v.begin()), end_(v.end()) {} + VectorSpan(typename std::vector<T>::const_iterator begin, + typename std::vector<T>::const_iterator end) + : begin_(begin), end_(end) {} + + const T& operator[](typename std::vector<T>::size_type i) const { + return *(begin_ + i); + } + + int size() const { return end_ - begin_; } + typename std::vector<T>::const_iterator begin() const { return begin_; } + typename std::vector<T>::const_iterator end() const { return end_; } + const float* data() const { return &(*begin_); } + + private: + typename std::vector<T>::const_iterator begin_; + typename std::vector<T>::const_iterator end_; +}; + +struct DateParseData { + enum Relation { + NEXT = 1, + NEXT_OR_SAME = 2, + LAST = 3, + NOW = 4, + TOMORROW = 5, + YESTERDAY = 6, + PAST = 7, + FUTURE = 8 + }; + + enum RelationType { + SUNDAY = 1, + MONDAY = 2, + TUESDAY = 3, + WEDNESDAY = 4, + THURSDAY = 5, + FRIDAY = 6, + SATURDAY = 7, + DAY = 8, + WEEK = 9, + MONTH = 10, + YEAR = 11 + }; + + enum Fields { + YEAR_FIELD = 1 << 0, + MONTH_FIELD = 1 << 1, + DAY_FIELD = 1 << 2, + HOUR_FIELD = 1 << 3, + MINUTE_FIELD = 1 << 4, + SECOND_FIELD = 1 << 5, + AMPM_FIELD = 1 << 6, + ZONE_OFFSET_FIELD = 1 << 7, + DST_OFFSET_FIELD = 1 << 8, + RELATION_FIELD = 1 << 9, + RELATION_TYPE_FIELD = 1 << 10, + RELATION_DISTANCE_FIELD = 1 << 11 + }; + + enum AMPM { AM = 0, PM = 1 }; + + enum TimeUnit { + DAYS = 1, + WEEKS = 2, + MONTHS = 3, + HOURS = 4, + MINUTES = 5, + SECONDS = 6, + YEARS = 7 + }; + + // Bit mask of fields which have been set on the struct + int field_set_mask; + + // Fields describing absolute date fields. + // Year of the date seen in the text match. + int year; + // Month of the year starting with January = 1. + int month; + // Day of the month starting with 1. 
+ int day_of_month; + // Hour of the day with a range of 0-23, + // values less than 12 need the AMPM field below or heuristics + // to definitively determine the time. + int hour; + // Hour of the day with a range of 0-59. + int minute; + // Hour of the day with a range of 0-59. + int second; + // 0 == AM, 1 == PM + int ampm; + // Number of hours offset from UTC this date time is in. + int zone_offset; + // Number of hours offest for DST + int dst_offset; + + // The permutation from now that was made to find the date time. + Relation relation; + // The unit of measure of the change to the date time. + RelationType relation_type; + // The number of units of change that were made. + int relation_distance; +}; + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_ diff --git a/annotator/zlib-utils.cc b/annotator/zlib-utils.cc new file mode 100644 index 0000000..6efe025 --- /dev/null +++ b/annotator/zlib-utils.cc @@ -0,0 +1,128 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "annotator/zlib-utils.h" + +#include <memory> + +#include "utils/base/logging.h" +#include "utils/zlib/zlib.h" + +namespace libtextclassifier3 { + +// Compress rule fields in the model. 
+bool CompressModel(ModelT* model) { + std::unique_ptr<ZlibCompressor> zlib_compressor = ZlibCompressor::Instance(); + if (!zlib_compressor) { + TC3_LOG(ERROR) << "Cannot compress model."; + return false; + } + + // Compress regex rules. + if (model->regex_model != nullptr) { + for (int i = 0; i < model->regex_model->patterns.size(); i++) { + RegexModel_::PatternT* pattern = model->regex_model->patterns[i].get(); + pattern->compressed_pattern.reset(new CompressedBufferT); + zlib_compressor->Compress(pattern->pattern, + pattern->compressed_pattern.get()); + pattern->pattern.clear(); + } + } + + // Compress date-time rules. + if (model->datetime_model != nullptr) { + for (int i = 0; i < model->datetime_model->patterns.size(); i++) { + DatetimeModelPatternT* pattern = model->datetime_model->patterns[i].get(); + for (int j = 0; j < pattern->regexes.size(); j++) { + DatetimeModelPattern_::RegexT* regex = pattern->regexes[j].get(); + regex->compressed_pattern.reset(new CompressedBufferT); + zlib_compressor->Compress(regex->pattern, + regex->compressed_pattern.get()); + regex->pattern.clear(); + } + } + for (int i = 0; i < model->datetime_model->extractors.size(); i++) { + DatetimeModelExtractorT* extractor = + model->datetime_model->extractors[i].get(); + extractor->compressed_pattern.reset(new CompressedBufferT); + zlib_compressor->Compress(extractor->pattern, + extractor->compressed_pattern.get()); + extractor->pattern.clear(); + } + } + return true; +} + +bool DecompressModel(ModelT* model) { + std::unique_ptr<ZlibDecompressor> zlib_decompressor = + ZlibDecompressor::Instance(); + if (!zlib_decompressor) { + TC3_LOG(ERROR) << "Cannot initialize decompressor."; + return false; + } + + // Decompress regex rules. 
+ if (model->regex_model != nullptr) { + for (int i = 0; i < model->regex_model->patterns.size(); i++) { + RegexModel_::PatternT* pattern = model->regex_model->patterns[i].get(); + if (!zlib_decompressor->MaybeDecompress(pattern->compressed_pattern.get(), + &pattern->pattern)) { + TC3_LOG(ERROR) << "Cannot decompress pattern: " << i; + return false; + } + pattern->compressed_pattern.reset(nullptr); + } + } + + // Decompress date-time rules. + if (model->datetime_model != nullptr) { + for (int i = 0; i < model->datetime_model->patterns.size(); i++) { + DatetimeModelPatternT* pattern = model->datetime_model->patterns[i].get(); + for (int j = 0; j < pattern->regexes.size(); j++) { + DatetimeModelPattern_::RegexT* regex = pattern->regexes[j].get(); + if (!zlib_decompressor->MaybeDecompress(regex->compressed_pattern.get(), + ®ex->pattern)) { + TC3_LOG(ERROR) << "Cannot decompress pattern: " << i << " " << j; + return false; + } + regex->compressed_pattern.reset(nullptr); + } + } + for (int i = 0; i < model->datetime_model->extractors.size(); i++) { + DatetimeModelExtractorT* extractor = + model->datetime_model->extractors[i].get(); + if (!zlib_decompressor->MaybeDecompress( + extractor->compressed_pattern.get(), &extractor->pattern)) { + TC3_LOG(ERROR) << "Cannot decompress pattern: " << i; + return false; + } + extractor->compressed_pattern.reset(nullptr); + } + } + return true; +} + +std::string CompressSerializedModel(const std::string& model) { + std::unique_ptr<ModelT> unpacked_model = UnPackModel(model.c_str()); + TC3_CHECK(unpacked_model != nullptr); + TC3_CHECK(CompressModel(unpacked_model.get())); + flatbuffers::FlatBufferBuilder builder; + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); + return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize()); +} + +} // namespace libtextclassifier3 diff --git a/annotator/zlib-utils.h b/annotator/zlib-utils.h new file mode 100644 index 0000000..462a02b --- 
/dev/null +++ b/annotator/zlib-utils.h @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Functions to compress and decompress low entropy entries in the model. + +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_ZLIB_UTILS_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_ZLIB_UTILS_H_ + +#include "annotator/model_generated.h" + +namespace libtextclassifier3 { + +// Compresses regex and datetime rules in the model in place. +bool CompressModel(ModelT* model); + +// Decompresses regex and datetime rules in the model in place. +bool DecompressModel(ModelT* model); + +// Compresses regex and datetime rules in the model. +std::string CompressSerializedModel(const std::string& model); + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_ZLIB_UTILS_H_ diff --git a/annotator/zlib-utils_test.cc b/annotator/zlib-utils_test.cc new file mode 100644 index 0000000..7a8d775 --- /dev/null +++ b/annotator/zlib-utils_test.cc @@ -0,0 +1,99 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
#include "annotator/zlib-utils.h"

#include <memory>

#include "annotator/model_generated.h"
#include "utils/zlib/zlib.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"

namespace libtextclassifier3 {

// Round-trip test: builds a model with known rule strings, compresses it,
// verifies the plain-text fields were cleared, decompresses both via the
// packed flatbuffer and via DecompressModel, and checks the originals come
// back.
TEST(ZlibUtilsTest, CompressModel) {
  ModelT model;
  model.regex_model.reset(new RegexModelT);
  model.regex_model->patterns.emplace_back(new RegexModel_::PatternT);
  model.regex_model->patterns.back()->pattern = "this is a test pattern";
  model.regex_model->patterns.emplace_back(new RegexModel_::PatternT);
  model.regex_model->patterns.back()->pattern = "this is a second test pattern";

  model.datetime_model.reset(new DatetimeModelT);
  model.datetime_model->patterns.emplace_back(new DatetimeModelPatternT);
  model.datetime_model->patterns.back()->regexes.emplace_back(
      new DatetimeModelPattern_::RegexT);
  model.datetime_model->patterns.back()->regexes.back()->pattern =
      "an example datetime pattern";
  model.datetime_model->extractors.emplace_back(new DatetimeModelExtractorT);
  model.datetime_model->extractors.back()->pattern =
      "an example datetime extractor";

  // Compress the model.
  EXPECT_TRUE(CompressModel(&model));

  // Sanity check that uncompressed field is removed.
  EXPECT_TRUE(model.regex_model->patterns[0]->pattern.empty());
  EXPECT_TRUE(model.regex_model->patterns[1]->pattern.empty());
  EXPECT_TRUE(model.datetime_model->patterns[0]->regexes[0]->pattern.empty());
  EXPECT_TRUE(model.datetime_model->extractors[0]->pattern.empty());

  // Pack and load the model.
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, &model));
  const Model* compressed_model =
      GetModel(reinterpret_cast<const char*>(builder.GetBufferPointer()));
  ASSERT_TRUE(compressed_model != nullptr);

  // Decompress the fields again and check that they match the original.
  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
  ASSERT_TRUE(decompressor != nullptr);
  std::string uncompressed_pattern;
  EXPECT_TRUE(decompressor->MaybeDecompress(
      compressed_model->regex_model()->patterns()->Get(0)->compressed_pattern(),
      &uncompressed_pattern));
  EXPECT_EQ(uncompressed_pattern, "this is a test pattern");
  EXPECT_TRUE(decompressor->MaybeDecompress(
      compressed_model->regex_model()->patterns()->Get(1)->compressed_pattern(),
      &uncompressed_pattern));
  EXPECT_EQ(uncompressed_pattern, "this is a second test pattern");
  EXPECT_TRUE(decompressor->MaybeDecompress(compressed_model->datetime_model()
                                                ->patterns()
                                                ->Get(0)
                                                ->regexes()
                                                ->Get(0)
                                                ->compressed_pattern(),
                                            &uncompressed_pattern));
  EXPECT_EQ(uncompressed_pattern, "an example datetime pattern");
  EXPECT_TRUE(decompressor->MaybeDecompress(compressed_model->datetime_model()
                                                ->extractors()
                                                ->Get(0)
                                                ->compressed_pattern(),
                                            &uncompressed_pattern));
  EXPECT_EQ(uncompressed_pattern, "an example datetime extractor");

  // In-place round trip through DecompressModel restores the plain text.
  EXPECT_TRUE(DecompressModel(&model));
  EXPECT_EQ(model.regex_model->patterns[0]->pattern, "this is a test pattern");
  EXPECT_EQ(model.regex_model->patterns[1]->pattern,
            "this is a second test pattern");
  EXPECT_EQ(model.datetime_model->patterns[0]->regexes[0]->pattern,
            "an example datetime pattern");
  EXPECT_EQ(model.datetime_model->extractors[0]->pattern,
            "an example datetime extractor");
}

}  // namespace libtextclassifier3