diff options
author | Nikita Iashchenko <nikitai@google.com> | 2019-01-11 15:34:43 +0000 |
---|---|---|
committer | Nikita Iashchenko <nikitai@google.com> | 2019-01-15 14:50:44 +0000 |
commit | 27845a3f1d089ea5dfcf180d65d51730c586ae53 (patch) | |
tree | 079de816a511f24dfc068c5d7cddae1124670097 | |
parent | 95b129a22e75c8f4350e6bf6a850df723eb271d7 (diff) | |
download | libtextclassifier-27845a3f1d089ea5dfcf180d65d51730c586ae53.tar.gz |
Update external/libtextclassifier
Bug: 119788152
Test: atest frameworks/base/core/tests/coretests/src/android/view/textclassifier plus manual testing
Change-Id: Ifeb3f248a7e96092e5662aed2daf591ffe020d1d
Merged-In: Icb2458967ef51efa2952b3eaddefbf1f7b359930
-rw-r--r-- | Android.bp | 9 | ||||
-rw-r--r-- | Android.mk | 87 | ||||
-rw-r--r-- | annotator/annotator.cc (renamed from text-classifier.cc) | 419 | ||||
-rw-r--r-- | annotator/annotator.h (renamed from text-classifier.h) | 92 | ||||
-rw-r--r-- | annotator/annotator_jni.cc | 434 | ||||
-rw-r--r-- | annotator/annotator_jni.h | 103 | ||||
-rw-r--r-- | annotator/annotator_jni_common.cc | 100 | ||||
-rw-r--r-- | annotator/annotator_jni_common.h | 41 | ||||
-rw-r--r-- | annotator/annotator_jni_test.cc (renamed from textclassifier_jni_test.cc) | 10 | ||||
-rw-r--r-- | annotator/annotator_test.cc (renamed from text-classifier_test.cc) | 559 | ||||
-rw-r--r-- | annotator/cached-features.cc (renamed from cached-features.cc) | 14 | ||||
-rw-r--r-- | annotator/cached-features.h (renamed from cached-features.h) | 18 | ||||
-rw-r--r-- | annotator/cached-features_test.cc (renamed from cached-features_test.cc) | 12 | ||||
-rw-r--r-- | annotator/datetime/extractor.cc (renamed from datetime/extractor.cc) | 36 | ||||
-rw-r--r-- | annotator/datetime/extractor.h (renamed from datetime/extractor.h) | 22 | ||||
-rw-r--r-- | annotator/datetime/parser.cc (renamed from datetime/parser.cc) | 29 | ||||
-rw-r--r-- | annotator/datetime/parser.h (renamed from datetime/parser.h) | 31 | ||||
-rw-r--r-- | annotator/datetime/parser_test.cc (renamed from datetime/parser_test.cc) | 118 | ||||
-rw-r--r-- | annotator/feature-processor.cc (renamed from feature-processor.cc) | 52 | ||||
-rw-r--r-- | annotator/feature-processor.h (renamed from feature-processor.h) | 44 | ||||
-rw-r--r-- | annotator/feature-processor_test.cc (renamed from feature-processor_test.cc) | 140 | ||||
-rw-r--r-- | annotator/knowledge/knowledge-engine-dummy.h | 47 | ||||
-rw-r--r-- | annotator/knowledge/knowledge-engine.h (renamed from util/calendar/calendar.h) | 10 | ||||
-rw-r--r-- | annotator/model-executor.cc (renamed from model-executor.cc) | 114 | ||||
-rw-r--r-- | annotator/model-executor.h (renamed from model-executor.h) | 86 | ||||
-rwxr-xr-x | annotator/model.fbs (renamed from model.fbs) | 170 | ||||
-rw-r--r-- | annotator/quantization.cc (renamed from quantization.cc) | 12 | ||||
-rw-r--r-- | annotator/quantization.h (renamed from quantization.h) | 14 | ||||
-rw-r--r-- | annotator/quantization_test.cc (renamed from quantization_test.cc) | 8 | ||||
-rw-r--r-- | annotator/strip-unpaired-brackets.cc (renamed from strip-unpaired-brackets.cc) | 16 | ||||
-rw-r--r-- | annotator/strip-unpaired-brackets.h (renamed from strip-unpaired-brackets.h) | 16 | ||||
-rw-r--r-- | annotator/strip-unpaired-brackets_test.cc (renamed from strip-unpaired-brackets_test.cc) | 41 | ||||
-rw-r--r-- | annotator/test_data/test_model.fb (renamed from test_data/test_model.fb) | bin | 578732 -> 522688 bytes | |||
-rw-r--r-- | annotator/test_data/test_model_cc.fb (renamed from test_data/test_model_cc.fb) | bin | 608192 -> 552160 bytes | |||
-rw-r--r-- | annotator/test_data/wrong_embeddings.fb | bin | 0 -> 288628 bytes | |||
-rw-r--r-- | annotator/token-feature-extractor.cc (renamed from token-feature-extractor.cc) | 26 | ||||
-rw-r--r-- | annotator/token-feature-extractor.h (renamed from token-feature-extractor.h) | 18 | ||||
-rw-r--r-- | annotator/token-feature-extractor_test.cc (renamed from token-feature-extractor_test.cc) | 92 | ||||
-rw-r--r-- | annotator/tokenizer.cc (renamed from tokenizer.cc) | 12 | ||||
-rw-r--r-- | annotator/tokenizer.h (renamed from tokenizer.h) | 20 | ||||
-rw-r--r-- | annotator/tokenizer_test.cc (renamed from tokenizer_test.cc) | 8 | ||||
-rw-r--r-- | annotator/types-test-util.h (renamed from types-test-util.h) | 16 | ||||
-rw-r--r-- | annotator/types.h (renamed from types.h) | 36 | ||||
-rw-r--r-- | annotator/zlib-utils.cc | 128 | ||||
-rw-r--r-- | annotator/zlib-utils.h | 37 | ||||
-rw-r--r-- | annotator/zlib-utils_test.cc (renamed from zlib-utils_test.cc) | 39 | ||||
-rw-r--r-- | generate_flatbuffers.mk | 46 | ||||
-rw-r--r-- | java/com/google/android/textclassifier/AnnotatorModel.java | 342 | ||||
-rwxr-xr-x | model_generated.h | 3718 | ||||
-rw-r--r-- | models/textclassifier.ar.model | bin | 599384 -> 539336 bytes | |||
-rw-r--r-- | models/textclassifier.en.model | bin | 599704 -> 542752 bytes | |||
-rw-r--r-- | models/textclassifier.es.model | bin | 599704 -> 538800 bytes | |||
-rw-r--r-- | models/textclassifier.fr.model | bin | 599704 -> 538800 bytes | |||
-rw-r--r-- | models/textclassifier.it.model | bin | 599704 -> 538800 bytes | |||
-rw-r--r-- | models/textclassifier.ja.model | bin | 533832 -> 473064 bytes | |||
-rw-r--r-- | models/textclassifier.ko.model | bin | 599448 -> 539368 bytes | |||
-rw-r--r-- | models/textclassifier.nl.model | bin | 599704 -> 538800 bytes | |||
-rw-r--r-- | models/textclassifier.pl.model | bin | 599704 -> 538800 bytes | |||
-rw-r--r-- | models/textclassifier.pt.model | bin | 599704 -> 538800 bytes | |||
-rw-r--r-- | models/textclassifier.ru.model | bin | 599448 -> 539368 bytes | |||
-rw-r--r-- | models/textclassifier.th.model | bin | 533352 -> 472808 bytes | |||
-rw-r--r-- | models/textclassifier.tr.model | bin | 599704 -> 538800 bytes | |||
-rw-r--r-- | models/textclassifier.universal.model | bin | 218748 -> 39344 bytes | |||
-rw-r--r-- | models/textclassifier.zh-Hant.model | bin | 533740 -> 473036 bytes | |||
-rw-r--r-- | models/textclassifier.zh.model | bin | 533736 -> 473032 bytes | |||
-rw-r--r-- | test_data/wrong_embeddings.fb | bin | 287020 -> 0 bytes | |||
-rw-r--r-- | textclassifier_jni.cc | 496 | ||||
-rw-r--r-- | textclassifier_jni.h | 134 | ||||
-rw-r--r-- | util/calendar/calendar-icu.cc | 436 | ||||
-rw-r--r-- | util/calendar/calendar-icu.h | 41 | ||||
-rw-r--r-- | util/calendar/calendar_test.cc | 129 | ||||
-rw-r--r-- | util/gtl/map_util.h | 65 | ||||
-rw-r--r-- | util/gtl/stl_util.h | 55 | ||||
-rw-r--r-- | util/hash/hash.cc | 6 | ||||
-rw-r--r-- | util/hash/hash.h | 4 | ||||
-rw-r--r-- | util/java/string_utils.h | 29 | ||||
-rw-r--r-- | util/strings/utf8.cc | 42 | ||||
-rw-r--r-- | util/utf8/unilib-icu.cc | 293 | ||||
-rw-r--r-- | utils/base/casts.h (renamed from util/base/casts.h) | 12 | ||||
-rw-r--r-- | utils/base/config.h (renamed from util/base/config.h) | 12 | ||||
-rw-r--r-- | utils/base/endian.h (renamed from util/base/endian.h) | 14 | ||||
-rw-r--r-- | utils/base/integral_types.h (renamed from util/base/integral_types.h) | 14 | ||||
-rw-r--r-- | utils/base/logging.cc (renamed from util/base/logging.cc) | 14 | ||||
-rw-r--r-- | utils/base/logging.h (renamed from util/base/logging.h) | 96 | ||||
-rw-r--r-- | utils/base/logging_levels.h (renamed from util/base/logging_levels.h) | 12 | ||||
-rw-r--r-- | utils/base/logging_raw.cc (renamed from util/base/logging_raw.cc) | 14 | ||||
-rw-r--r-- | utils/base/logging_raw.h (renamed from util/base/logging_raw.h) | 14 | ||||
-rw-r--r-- | utils/base/macros.h (renamed from util/base/macros.h) | 44 | ||||
-rw-r--r-- | utils/base/port.h (renamed from util/base/port.h) | 22 | ||||
-rw-r--r-- | utils/calendar/calendar-common.h | 278 | ||||
-rw-r--r-- | utils/calendar/calendar-javaicu.cc | 190 | ||||
-rw-r--r-- | utils/calendar/calendar-javaicu.h | 89 | ||||
-rw-r--r-- | utils/calendar/calendar.h | 23 | ||||
-rw-r--r-- | utils/calendar/calendar_test.cc | 244 | ||||
-rw-r--r-- | utils/checksum.cc | 50 | ||||
-rw-r--r-- | utils/checksum.h | 34 | ||||
-rw-r--r-- | utils/checksum_test.cc | 57 | ||||
-rw-r--r-- | utils/flatbuffers.cc (renamed from util/flatbuffers.cc) | 8 | ||||
-rw-r--r-- | utils/flatbuffers.h (renamed from util/flatbuffers.h) | 14 | ||||
-rw-r--r-- | utils/hash/farmhash.cc (renamed from util/hash/farmhash.cc) | 18 | ||||
-rw-r--r-- | utils/hash/farmhash.h (renamed from util/hash/farmhash.h) | 10 | ||||
-rw-r--r-- | utils/i18n/locale.cc (renamed from util/i18n/locale.cc) | 10 | ||||
-rw-r--r-- | utils/i18n/locale.h (renamed from util/i18n/locale.h) | 14 | ||||
-rw-r--r-- | utils/i18n/locale_test.cc (renamed from util/i18n/locale_test.cc) | 8 | ||||
-rwxr-xr-x | utils/intents/intent-config.fbs | 192 | ||||
-rw-r--r-- | utils/java/jni-base.cc | 72 | ||||
-rw-r--r-- | utils/java/jni-base.h | 83 | ||||
-rw-r--r-- | utils/java/jni-cache.cc | 284 | ||||
-rw-r--r-- | utils/java/jni-cache.h | 141 | ||||
-rw-r--r-- | utils/java/scoped_global_ref.h (renamed from util/java/scoped_global_ref.h) | 24 | ||||
-rw-r--r-- | utils/java/scoped_local_ref.h (renamed from util/java/scoped_local_ref.h) | 16 | ||||
-rw-r--r-- | utils/java/string_utils.cc (renamed from util/java/string_utils.cc) | 40 | ||||
-rw-r--r-- | utils/java/string_utils.h | 76 | ||||
-rw-r--r-- | utils/math/fastexp.cc (renamed from util/math/fastexp.cc) | 8 | ||||
-rw-r--r-- | utils/math/fastexp.h (renamed from util/math/fastexp.h) | 20 | ||||
-rw-r--r-- | utils/math/softmax.cc (renamed from util/math/softmax.cc) | 16 | ||||
-rw-r--r-- | utils/math/softmax.h (renamed from util/math/softmax.h) | 12 | ||||
-rw-r--r-- | utils/memory/mmap.cc (renamed from util/memory/mmap.cc) | 24 | ||||
-rw-r--r-- | utils/memory/mmap.h (renamed from util/memory/mmap.h) | 18 | ||||
-rw-r--r-- | utils/optional.h | 81 | ||||
-rw-r--r-- | utils/strings/numbers.cc (renamed from util/strings/numbers.cc) | 8 | ||||
-rw-r--r-- | utils/strings/numbers.h (renamed from util/strings/numbers.h) | 14 | ||||
-rw-r--r-- | utils/strings/numbers_test.cc (renamed from util/strings/numbers_test.cc) | 10 | ||||
-rw-r--r-- | utils/strings/split.cc (renamed from util/strings/split.cc) | 8 | ||||
-rw-r--r-- | utils/strings/split.h (renamed from util/strings/split.h) | 14 | ||||
-rw-r--r-- | utils/strings/stringpiece.h (renamed from util/strings/stringpiece.h) | 58 | ||||
-rw-r--r-- | utils/strings/stringpiece_test.cc | 58 | ||||
-rw-r--r-- | utils/strings/utf8.cc | 52 | ||||
-rw-r--r-- | utils/strings/utf8.h (renamed from util/strings/utf8.h) | 16 | ||||
-rw-r--r-- | utils/strings/utf8_test.cc | 59 | ||||
-rw-r--r-- | utils/tensor-view.cc (renamed from tensor-view.cc) | 8 | ||||
-rw-r--r-- | utils/tensor-view.h (renamed from tensor-view.h) | 12 | ||||
-rw-r--r-- | utils/tensor-view_test.cc (renamed from tensor-view_test.cc) | 8 | ||||
-rw-r--r-- | utils/testing/logging_event_listener.h | 62 | ||||
-rw-r--r-- | utils/tflite-model-executor.cc | 129 | ||||
-rw-r--r-- | utils/tflite-model-executor.h | 123 | ||||
-rw-r--r-- | utils/utf8/unicodetext.cc (renamed from util/utf8/unicodetext.cc) | 12 | ||||
-rw-r--r-- | utils/utf8/unicodetext.h (renamed from util/utf8/unicodetext.h) | 16 | ||||
-rw-r--r-- | utils/utf8/unicodetext_test.cc (renamed from util/utf8/unicodetext_test.cc) | 18 | ||||
-rw-r--r-- | utils/utf8/unilib-javaicu.cc | 694 | ||||
-rw-r--r-- | utils/utf8/unilib-javaicu.h (renamed from util/utf8/unilib-icu.h) | 66 | ||||
-rw-r--r-- | utils/utf8/unilib.h (renamed from util/utf8/unilib.h) | 12 | ||||
-rw-r--r-- | utils/utf8/unilib_test.cc (renamed from util/utf8/unilib_test.cc) | 180 | ||||
-rw-r--r-- | utils/variant.h | 64 | ||||
-rwxr-xr-x | utils/zlib/buffer.fbs | 22 | ||||
-rw-r--r-- | utils/zlib/zlib.cc | 174 | ||||
-rw-r--r-- | utils/zlib/zlib.h (renamed from zlib-utils.h) | 34 | ||||
-rw-r--r-- | zlib-utils.cc | 269 |
148 files changed, 6365 insertions, 7429 deletions
@@ -21,7 +21,7 @@ cc_library_headers { cc_defaults { name: "libtextclassifier_hash_defaults", srcs: [ - "util/hash/farmhash.cc", + "utils/hash/farmhash.cc", "util/hash/hash.cc" ], cflags: [ @@ -44,3 +44,10 @@ cc_library_static { sdk_version: "current", stl: "libc++_static", } + +java_library_static { + name: "libtextclassifier-java", + sdk_version: "core_current", + no_framework_libs: true, + srcs: ["java/**/*.java"], +} @@ -33,12 +33,16 @@ MY_LIBTEXTCLASSIFIER_WARNING_CFLAGS := \ -Wno-undefined-var-template \ -Wno-unused-function \ -Wno-unused-parameter \ + -Wno-extern-c-compat MY_LIBTEXTCLASSIFIER_CFLAGS := \ $(MY_LIBTEXTCLASSIFIER_WARNING_CFLAGS) \ -fvisibility=hidden \ -DLIBTEXTCLASSIFIER_UNILIB_ICU \ - -DZLIB_CONST + -DZLIB_CONST \ + -DSAFTM_COMPACT_LOGGING \ + -DTC3_UNILIB_JAVAICU \ + -DTC3_CALENDAR_JAVAICU # Only enable debug logging in userdebug/eng builds. ifneq (,$(filter userdebug eng, $(TARGET_BUILD_VARIANT))) @@ -46,27 +50,16 @@ ifneq (,$(filter userdebug eng, $(TARGET_BUILD_VARIANT))) endif # ----------------- -# flatbuffers -# ----------------- - -# Empty static library so that other projects can include just the basic -# FlatBuffers headers as a module. -include $(CLEAR_VARS) -LOCAL_MODULE := flatbuffers -LOCAL_EXPORT_C_INCLUDES := $(LOCAL_PATH)/include -LOCAL_EXPORT_CPPFLAGS := -std=c++11 -fexceptions -Wall \ - -DFLATBUFFERS_TRACK_VERIFIER_BUFFER_SIZE - -include $(BUILD_STATIC_LIBRARY) - -# ----------------- # libtextclassifier # ----------------- include $(CLEAR_VARS) LOCAL_MODULE := libtextclassifier - +LOCAL_MODULE_CLASS := SHARED_LIBRARIES LOCAL_CPP_EXTENSION := .cc + +include $(LOCAL_PATH)/generate_flatbuffers.mk + LOCAL_CFLAGS += $(MY_LIBTEXTCLASSIFIER_CFLAGS) LOCAL_STRIP_MODULE := $(LIBTEXTCLASSIFIER_STRIP_OPTS) @@ -75,24 +68,22 @@ LOCAL_SRC_FILES := $(filter-out tests/% %_test.cc test-util.%,$(call all-subdir- LOCAL_C_INCLUDES := $(TOP)/external/zlib LOCAL_C_INCLUDES += $(TOP)/external/tensorflow LOCAL_C_INCLUDES += $(TOP)/external/flatbuffers/include +LOCAL_C_INCLUDES += $(TOP)/external/libutf +LOCAL_C_INCLUDES += $(intermediates) LOCAL_SHARED_LIBRARIES += liblog -LOCAL_SHARED_LIBRARIES += libicuuc -LOCAL_SHARED_LIBRARIES += libicui18n LOCAL_SHARED_LIBRARIES += libtflite LOCAL_SHARED_LIBRARIES += libz -LOCAL_STATIC_LIBRARIES += flatbuffers +LOCAL_STATIC_LIBRARIES += libutf -LOCAL_REQUIRED_MODULES := textclassifier.en.model -LOCAL_REQUIRED_MODULES += textclassifier.universal.model +LOCAL_REQUIRED_MODULES := libtextclassifier_annotator_en_model +LOCAL_REQUIRED_MODULES += libtextclassifier_annotator_universal_model LOCAL_ADDITIONAL_DEPENDENCIES += $(LOCAL_PATH)/jni.lds LOCAL_LDFLAGS += -Wl,-version-script=$(LOCAL_PATH)/jni.lds -# TODO(b/119788152): Remove this when the bug is fixed -LOCAL_CFLAGS += -DANDROID_LINK_SHARED_ICU4C -LOCAL_CPPFLAGS_32 += -DLIBTEXTCLASSIFIER_TEST_DATA_DIR="\"/data/nativetest/libtextclassifier_tests/test_data/\"" -LOCAL_CPPFLAGS_64 += -DLIBTEXTCLASSIFIER_TEST_DATA_DIR="\"/data/nativetest64/libtextclassifier_tests/test_data/\"" +LOCAL_CPPFLAGS_32 += -DTC3_TEST_DATA_DIR="\"/data/nativetest/libtextclassifier_tests/test_data/\"" +LOCAL_CPPFLAGS_64 += -DTC3_TEST_DATA_DIR="\"/data/nativetest64/libtextclassifier_tests/test_data/\"" include $(BUILD_SHARED_LIBRARY) @@ -103,43 +94,46 @@ include $(BUILD_SHARED_LIBRARY) include $(CLEAR_VARS) LOCAL_MODULE := libtextclassifier_tests +LOCAL_MODULE_CLASS := NATIVE_TESTS LOCAL_COMPATIBILITY_SUITE := device-tests LOCAL_MODULE_TAGS := tests - LOCAL_CPP_EXTENSION := .cc + +include $(LOCAL_PATH)/generate_flatbuffers.mk + LOCAL_CFLAGS += $(MY_LIBTEXTCLASSIFIER_CFLAGS) LOCAL_STRIP_MODULE := $(LIBTEXTCLASSIFIER_STRIP_OPTS) -LOCAL_TEST_DATA := $(call find-test-data-in-subdirs, $(LOCAL_PATH), *, test_data) +LOCAL_TEST_DATA := $(call find-test-data-in-subdirs, $(LOCAL_PATH), *, annotator/test_data) -# TODO(b/119788152): Remove this when the bug is fixed -LOCAL_CFLAGS += -DANDROID_LINK_SHARED_ICU4C -LOCAL_CPPFLAGS_32 += -DLIBTEXTCLASSIFIER_TEST_DATA_DIR="\"/data/nativetest/libtextclassifier_tests/test_data/\"" -LOCAL_CPPFLAGS_64 += -DLIBTEXTCLASSIFIER_TEST_DATA_DIR="\"/data/nativetest64/libtextclassifier_tests/test_data/\"" +LOCAL_CPPFLAGS_32 += -DTC3_TEST_DATA_DIR="\"/data/nativetest/libtextclassifier_tests/test_data/\"" +LOCAL_CPPFLAGS_64 += -DTC3_TEST_DATA_DIR="\"/data/nativetest64/libtextclassifier_tests/test_data/\"" -LOCAL_SRC_FILES := $(call all-subdir-cpp-files) +# TODO: Do not filter out tflite test once the dependency issue is resolved. +LOCAL_SRC_FILES := $(filter-out utils/tflite/%_test.cc,$(call all-subdir-cpp-files)) LOCAL_C_INCLUDES := $(TOP)/external/zlib LOCAL_C_INCLUDES += $(TOP)/external/tensorflow LOCAL_C_INCLUDES += $(TOP)/external/flatbuffers/include +LOCAL_C_INCLUDES += $(TOP)/external/libutf +LOCAL_C_INCLUDES += $(intermediates) -LOCAL_STATIC_LIBRARIES += libgmock LOCAL_SHARED_LIBRARIES += liblog -LOCAL_SHARED_LIBRARIES += libicuuc -LOCAL_SHARED_LIBRARIES += libicui18n LOCAL_SHARED_LIBRARIES += libtflite LOCAL_SHARED_LIBRARIES += libz -LOCAL_STATIC_LIBRARIES += flatbuffers +LOCAL_STATIC_LIBRARIES += libgmock +LOCAL_STATIC_LIBRARIES += libutf include $(BUILD_NATIVE_TEST) -# ---------------------- -# Smart Selection models -# ---------------------- +# ---------------- +# Annotator models +# ---------------- include $(CLEAR_VARS) -LOCAL_MODULE := textclassifier.en.model +LOCAL_MODULE := libtextclassifier_annotator_en_model +LOCAL_MODULE_STEM := textclassifier.en.model LOCAL_MODULE_CLASS := ETC LOCAL_MODULE_OWNER := google LOCAL_SRC_FILES := ./models/textclassifier.en.model @@ -147,19 +141,10 @@ LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier include $(BUILD_PREBUILT) include $(CLEAR_VARS) -LOCAL_MODULE := textclassifier.universal.model +LOCAL_MODULE := libtextclassifier_annotator_universal_model +LOCAL_MODULE_STEM := textclassifier.universal.model LOCAL_MODULE_CLASS := ETC LOCAL_MODULE_OWNER := google LOCAL_SRC_FILES := ./models/textclassifier.universal.model LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier include $(BUILD_PREBUILT) - -# ----------------------- -# Smart Selection bundles -# ----------------------- - -include $(CLEAR_VARS) -LOCAL_MODULE := textclassifier.bundle1 -LOCAL_REQUIRED_MODULES := textclassifier.en.model -LOCAL_CFLAGS := $(MY_LIBTEXTCLASSIFIER_WARNING_CFLAGS) -include $(BUILD_STATIC_LIBRARY) diff --git a/text-classifier.cc b/annotator/annotator.cc index e20813a..2be9d3c 100644 --- a/text-classifier.cc +++ b/annotator/annotator.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "text-classifier.h" +#include "annotator/annotator.h" #include <algorithm> #include <cctype> @@ -22,39 +22,76 @@ #include <iterator> #include <numeric> -#include "util/base/logging.h" -#include "util/math/softmax.h" -#include "util/utf8/unicodetext.h" +#include "utils/base/logging.h" +#include "utils/checksum.h" +#include "utils/math/softmax.h" +#include "utils/utf8/unicodetext.h" -namespace libtextclassifier2 { -const std::string& TextClassifier::kOtherCollection = +namespace libtextclassifier3 { +const std::string& Annotator::kOtherCollection = *[]() { return new std::string("other"); }(); -const std::string& TextClassifier::kPhoneCollection = +const std::string& Annotator::kPhoneCollection = *[]() { return new std::string("phone"); }(); -const std::string& TextClassifier::kAddressCollection = +const std::string& Annotator::kAddressCollection = *[]() { return new std::string("address"); }(); -const std::string& TextClassifier::kDateCollection = +const std::string& Annotator::kDateCollection = *[]() { return new std::string("date"); }(); +const std::string& Annotator::kUrlCollection = + *[]() { return new std::string("url"); }(); +const std::string& Annotator::kFlightCollection = + *[]() { return new std::string("flight"); }(); +const std::string& Annotator::kEmailCollection = + *[]() { return new std::string("email"); }(); +const std::string& Annotator::kIbanCollection = + *[]() { return new std::string("iban"); }(); +const std::string& Annotator::kPaymentCardCollection = + *[]() { return new std::string("payment_card"); }(); +const std::string& Annotator::kIsbnCollection = + *[]() { return new std::string("isbn"); }(); +const std::string& Annotator::kTrackingNumberCollection = + *[]() { return new std::string("tracking_number"); }(); namespace { const Model* LoadAndVerifyModel(const void* addr, int size) { - const Model* model = GetModel(addr); - flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size); - if (model->Verify(verifier)) { - return model; + if (VerifyModelBuffer(verifier)) { + return GetModel(addr); } else { return nullptr; } } + +// If lib is not nullptr, just returns lib. Otherwise, if lib is nullptr, will +// create a new instance, assign ownership to owned_lib, and return it. +const UniLib* MaybeCreateUnilib(const UniLib* lib, + std::unique_ptr<UniLib>* owned_lib) { + if (lib) { + return lib; + } else { + owned_lib->reset(new UniLib); + return owned_lib->get(); + } +} + +// As above, but for CalendarLib. +const CalendarLib* MaybeCreateCalendarlib( + const CalendarLib* lib, std::unique_ptr<CalendarLib>* owned_lib) { + if (lib) { + return lib; + } else { + owned_lib->reset(new CalendarLib); + return owned_lib->get(); + } +} + } // namespace tflite::Interpreter* InterpreterManager::SelectionInterpreter() { if (!selection_interpreter_) { - TC_CHECK(selection_executor_); + TC3_CHECK(selection_executor_); selection_interpreter_ = selection_executor_->CreateInterpreter(); if (!selection_interpreter_) { - TC_LOG(ERROR) << "Could not build TFLite interpreter."; + TC3_LOG(ERROR) << "Could not build TFLite interpreter."; } } return selection_interpreter_.get(); @@ -62,24 +99,25 @@ tflite::Interpreter* InterpreterManager::SelectionInterpreter() { tflite::Interpreter* InterpreterManager::ClassificationInterpreter() { if (!classification_interpreter_) { - TC_CHECK(classification_executor_); + TC3_CHECK(classification_executor_); classification_interpreter_ = classification_executor_->CreateInterpreter(); if (!classification_interpreter_) { - TC_LOG(ERROR) << "Could not build TFLite interpreter."; + TC3_LOG(ERROR) << "Could not build TFLite interpreter."; } } return classification_interpreter_.get(); } -std::unique_ptr<TextClassifier> TextClassifier::FromUnownedBuffer( - const char* buffer, int size, const UniLib* unilib) { +std::unique_ptr<Annotator> Annotator::FromUnownedBuffer( + const char* buffer, int size, const UniLib* unilib, + const CalendarLib* calendarlib) { const Model* model = LoadAndVerifyModel(buffer, size); if (model == nullptr) { return nullptr; } auto classifier = - std::unique_ptr<TextClassifier>(new TextClassifier(model, unilib)); + std::unique_ptr<Annotator>(new Annotator(model, unilib, calendarlib)); if (!classifier->IsInitialized()) { return nullptr; } @@ -87,22 +125,24 @@ std::unique_ptr<TextClassifier> TextClassifier::FromUnownedBuffer( return classifier; } -std::unique_ptr<TextClassifier> TextClassifier::FromScopedMmap( - std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib) { + +std::unique_ptr<Annotator> Annotator::FromScopedMmap( + std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib, + const CalendarLib* calendarlib) { if (!(*mmap)->handle().ok()) { - TC_VLOG(1) << "Mmap failed."; + TC3_VLOG(1) << "Mmap failed."; return nullptr; } const Model* model = LoadAndVerifyModel((*mmap)->handle().start(), (*mmap)->handle().num_bytes()); if (!model) { - TC_LOG(ERROR) << "Model verification failed."; + TC3_LOG(ERROR) << "Model verification failed."; return nullptr; } - auto classifier = - std::unique_ptr<TextClassifier>(new TextClassifier(mmap, model, unilib)); + auto classifier = std::unique_ptr<Annotator>( + new Annotator(mmap, model, unilib, calendarlib)); if (!classifier->IsInitialized()) { return nullptr; } @@ -110,29 +150,52 @@ std::unique_ptr<TextClassifier> TextClassifier::FromScopedMmap( return classifier; } -std::unique_ptr<TextClassifier> TextClassifier::FromFileDescriptor( - int fd, int offset, int size, const UniLib* unilib) { +std::unique_ptr<Annotator> Annotator::FromFileDescriptor( + int fd, int offset, int size, const UniLib* unilib, + const CalendarLib* calendarlib) { std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size)); - return FromScopedMmap(&mmap, unilib); + return FromScopedMmap(&mmap, unilib, calendarlib); } -std::unique_ptr<TextClassifier> TextClassifier::FromFileDescriptor( - int fd, const UniLib* unilib) { +std::unique_ptr<Annotator> Annotator::FromFileDescriptor( + int fd, const UniLib* unilib, const CalendarLib* calendarlib) { std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd)); - return FromScopedMmap(&mmap, unilib); + return FromScopedMmap(&mmap, unilib, calendarlib); } -std::unique_ptr<TextClassifier> TextClassifier::FromPath( - const std::string& path, const UniLib* unilib) { +std::unique_ptr<Annotator> Annotator::FromPath(const std::string& path, + const UniLib* unilib, + const CalendarLib* calendarlib) { std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path)); - return FromScopedMmap(&mmap, unilib); + return FromScopedMmap(&mmap, unilib, calendarlib); +} + +Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model, + const UniLib* unilib, const CalendarLib* calendarlib) + : model_(model), + mmap_(std::move(*mmap)), + owned_unilib_(nullptr), + unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)), + owned_calendarlib_(nullptr), + calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) { + ValidateAndInitialize(); } -void TextClassifier::ValidateAndInitialize() { +Annotator::Annotator(const Model* model, const UniLib* unilib, + const CalendarLib* calendarlib) + : model_(model), + owned_unilib_(nullptr), + unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)), + owned_calendarlib_(nullptr), + calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) { + ValidateAndInitialize(); +} + +void Annotator::ValidateAndInitialize() { initialized_ = false; if (model_ == nullptr) { - TC_LOG(ERROR) << "No model specified."; + TC3_LOG(ERROR) << "No model specified."; return; } @@ -150,24 +213,24 @@ void TextClassifier::ValidateAndInitialize() { // Annotation requires the selection model. if (model_enabled_for_annotation || model_enabled_for_selection) { if (!model_->selection_options()) { - TC_LOG(ERROR) << "No selection options."; + TC3_LOG(ERROR) << "No selection options."; return; } if (!model_->selection_feature_options()) { - TC_LOG(ERROR) << "No selection feature options."; + TC3_LOG(ERROR) << "No selection feature options."; return; } if (!model_->selection_feature_options()->bounds_sensitive_features()) { - TC_LOG(ERROR) << "No selection bounds sensitive feature options."; + TC3_LOG(ERROR) << "No selection bounds sensitive feature options."; return; } if (!model_->selection_model()) { - TC_LOG(ERROR) << "No selection model."; + TC3_LOG(ERROR) << "No selection model."; return; } - selection_executor_ = ModelExecutor::Instance(model_->selection_model()); + selection_executor_ = ModelExecutor::FromBuffer(model_->selection_model()); if (!selection_executor_) { - TC_LOG(ERROR) << "Could not initialize selection executor."; + TC3_LOG(ERROR) << "Could not initialize selection executor."; return; } selection_feature_processor_.reset( @@ -180,29 +243,29 @@ void TextClassifier::ValidateAndInitialize() { if (model_enabled_for_annotation || model_enabled_for_classification || model_enabled_for_selection) { if (!model_->classification_options()) { - TC_LOG(ERROR) << "No classification options."; + TC3_LOG(ERROR) << "No classification options."; return; } if (!model_->classification_feature_options()) { - TC_LOG(ERROR) << "No classification feature options."; + TC3_LOG(ERROR) << "No classification feature options."; return; } if (!model_->classification_feature_options() ->bounds_sensitive_features()) { - TC_LOG(ERROR) << "No classification bounds sensitive feature options."; + TC3_LOG(ERROR) << "No classification bounds sensitive feature options."; return; } if (!model_->classification_model()) { - TC_LOG(ERROR) << "No clf model."; + TC3_LOG(ERROR) << "No clf model."; return; } classification_executor_ = - ModelExecutor::Instance(model_->classification_model()); + ModelExecutor::FromBuffer(model_->classification_model()); if (!classification_executor_) { - TC_LOG(ERROR) << "Could not initialize classification executor."; + TC3_LOG(ERROR) << "Could not initialize classification executor."; return; } @@ -215,7 +278,7 @@ void TextClassifier::ValidateAndInitialize() { if (model_enabled_for_annotation || model_enabled_for_classification || model_enabled_for_selection) { if (!model_->embedding_model()) { - TC_LOG(ERROR) << "No embedding model."; + TC3_LOG(ERROR) << "No embedding model."; return; } @@ -227,17 +290,17 @@ void TextClassifier::ValidateAndInitialize() { model_->selection_feature_options()->embedding_quantization_bits() != model_->classification_feature_options() ->embedding_quantization_bits())) { - TC_LOG(ERROR) << "Mismatching embedding size/quantization."; + TC3_LOG(ERROR) << "Mismatching embedding size/quantization."; return; } - embedding_executor_ = TFLiteEmbeddingExecutor::Instance( + embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer( model_->embedding_model(), model_->classification_feature_options()->embedding_size(), model_->classification_feature_options() ->embedding_quantization_bits()); if (!embedding_executor_) { - TC_LOG(ERROR) << "Could not initialize embedding executor."; + TC3_LOG(ERROR) << "Could not initialize embedding executor."; return; } } @@ -245,16 +308,16 @@ void TextClassifier::ValidateAndInitialize() { std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance(); if (model_->regex_model()) { if (!InitializeRegexModel(decompressor.get())) { - TC_LOG(ERROR) << "Could not initialize regex model."; + TC3_LOG(ERROR) << "Could not initialize regex model."; return; } } if (model_->datetime_model()) { - datetime_parser_ = DatetimeParser::Instance(model_->datetime_model(), - *unilib_, decompressor.get()); + datetime_parser_ = DatetimeParser::Instance( + model_->datetime_model(), *unilib_, *calendarlib_, decompressor.get()); if (!datetime_parser_) { - TC_LOG(ERROR) << "Could not initialize datetime parser."; + TC3_LOG(ERROR) << "Could not initialize datetime parser."; return; } } @@ -283,7 +346,7 @@ void TextClassifier::ValidateAndInitialize() { initialized_ = true; } -bool TextClassifier::InitializeRegexModel(ZlibDecompressor* decompressor) { +bool Annotator::InitializeRegexModel(ZlibDecompressor* decompressor) { if (!model_->regex_model()->patterns()) { return true; } @@ -296,7 +359,7 @@ bool TextClassifier::InitializeRegexModel(ZlibDecompressor* decompressor) { regex_pattern->compressed_pattern(), decompressor); if (!compiled_pattern) { - TC_LOG(INFO) << "Failed to load regex pattern"; + TC3_LOG(INFO) << "Failed to load regex pattern"; return false; } @@ -309,10 +372,13 @@ bool TextClassifier::InitializeRegexModel(ZlibDecompressor* decompressor) { if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) { selection_regex_patterns_.push_back(regex_pattern_id); } - regex_patterns_.push_back({regex_pattern->collection_name()->str(), - regex_pattern->target_classification_score(), - regex_pattern->priority_score(), - std::move(compiled_pattern)}); + regex_patterns_.push_back({ + regex_pattern->collection_name()->str(), + regex_pattern->target_classification_score(), + regex_pattern->priority_score(), + std::move(compiled_pattern), + regex_pattern->verification_options(), + }); if (regex_pattern->use_approximate_matching()) { regex_approximate_match_pattern_ids_.insert(regex_pattern_id); } @@ -322,6 +388,18 @@ bool TextClassifier::InitializeRegexModel(ZlibDecompressor* decompressor) { return true; } +bool Annotator::InitializeKnowledgeEngine( + const std::string& serialized_config) { + std::unique_ptr<KnowledgeEngine> knowledge_engine( + new KnowledgeEngine(unilib_)); + if (!knowledge_engine->Initialize(serialized_config)) { + TC3_LOG(ERROR) << "Failed to initialize the knowledge engine."; + return false; + } + knowledge_engine_ = std::move(knowledge_engine); + return true; +} + namespace { int CountDigits(const std::string& str, CodepointSpan selection_indices) { @@ -347,6 +425,19 @@ std::string ExtractSelection(const std::string& context, std::advance(selection_end, selection_indices.second); return UnicodeText::UTF8Substring(selection_begin, selection_end); } + +bool VerifyCandidate(const VerificationOptions* verification_options, + const std::string& match) { + if (!verification_options) { + return true; + } + if (verification_options->verify_luhn_checksum() && + !VerifyLuhnChecksum(match)) { + return false; + } + return true; +} + } // namespace namespace internal { @@ -356,7 +447,7 @@ namespace internal { CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span, const UnicodeText& context_unicode, const UniLib& unilib) { - TC_CHECK(ValidNonEmptySpan(span)); + TC3_CHECK(ValidNonEmptySpan(span)); UnicodeText::const_iterator it; @@ -390,32 +481,32 @@ CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span, } } // namespace internal -bool TextClassifier::FilteredForAnnotation(const AnnotatedSpan& span) const { +bool Annotator::FilteredForAnnotation(const AnnotatedSpan& span) const { return !span.classification.empty() && filtered_collections_annotation_.find( span.classification[0].collection) != filtered_collections_annotation_.end(); } -bool TextClassifier::FilteredForClassification( +bool Annotator::FilteredForClassification( const ClassificationResult& classification) const { return filtered_collections_classification_.find(classification.collection) != filtered_collections_classification_.end(); } -bool TextClassifier::FilteredForSelection(const AnnotatedSpan& span) const { +bool Annotator::FilteredForSelection(const AnnotatedSpan& span) const { return !span.classification.empty() && filtered_collections_selection_.find( span.classification[0].collection) != filtered_collections_selection_.end(); } -CodepointSpan TextClassifier::SuggestSelection( +CodepointSpan Annotator::SuggestSelection( const std::string& context, CodepointSpan click_indices, const SelectionOptions& options) const { CodepointSpan original_click_indices = click_indices; if (!initialized_) { - TC_LOG(ERROR) << "Not initialized"; + TC3_LOG(ERROR) << "Not initialized"; return original_click_indices; } if (!(model_->enabled_modes() & ModeFlag_SELECTION)) { @@ -435,8 +526,8 @@ CodepointSpan TextClassifier::SuggestSelection( click_indices.first >= context_codepoint_size || click_indices.second > context_codepoint_size || click_indices.first >= click_indices.second) { - TC_VLOG(1) << "Trying to run SuggestSelection with invalid indices: " - << click_indices.first << " " << click_indices.second; + TC3_VLOG(1) << "Trying to run SuggestSelection with invalid indices: " + << click_indices.first << " " << click_indices.second; return original_click_indices; } @@ -448,7 +539,7 @@ CodepointSpan TextClassifier::SuggestSelection( // finding logic finds the clicked token correctly. This modification is // done by the following function. Note, that it's enough to check the left // side of the current selection, because if the white-space is a part of a - // multi-selection, neccessarily both tokens - on the left and the right + // multi-selection, necessarily both tokens - on the left and the right // sides need to be selected. Thus snapping only to the left is sufficient // (there's a check at the bottom that makes sure that if we snap to the // left token but the result does not contain the initial white-space, @@ -463,17 +554,21 @@ CodepointSpan TextClassifier::SuggestSelection( std::vector<Token> tokens; if (!ModelSuggestSelection(context_unicode, click_indices, &interpreter_manager, &tokens, &candidates)) { - TC_LOG(ERROR) << "Model suggest selection failed."; + TC3_LOG(ERROR) << "Model suggest selection failed."; return original_click_indices; } if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates)) { - TC_LOG(ERROR) << "Regex suggest selection failed."; + TC3_LOG(ERROR) << "Regex suggest selection failed."; return original_click_indices; } if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false), /*reference_time_ms_utc=*/0, /*reference_timezone=*/"", options.locales, ModeFlag_SELECTION, &candidates)) { - TC_LOG(ERROR) << "Datetime suggest selection failed."; + TC3_LOG(ERROR) << "Datetime suggest selection failed."; + return original_click_indices; + } + if (knowledge_engine_ && !knowledge_engine_->Chunk(context, &candidates)) { + TC3_LOG(ERROR) << "Knowledge suggest selection failed."; return original_click_indices; } @@ -488,7 +583,7 @@ CodepointSpan TextClassifier::SuggestSelection( std::vector<int> candidate_indices; if (!ResolveConflicts(candidates, context, tokens, &interpreter_manager, &candidate_indices)) { - TC_LOG(ERROR) << "Couldn't resolve conflicts."; + TC3_LOG(ERROR) << "Couldn't resolve conflicts."; return original_click_indices; } @@ -541,10 +636,11 @@ int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates, } } // namespace -bool TextClassifier::ResolveConflicts( - const std::vector<AnnotatedSpan>& candidates, const std::string& context, - const std::vector<Token>& cached_tokens, - InterpreterManager* interpreter_manager, std::vector<int>* result) const { +bool Annotator::ResolveConflicts(const std::vector<AnnotatedSpan>& candidates, + const std::string& context, + const std::vector<Token>& cached_tokens, + InterpreterManager* interpreter_manager, + std::vector<int>* result) const { result->clear(); result->reserve(candidates.size()); for (int i = 0; i < candidates.size();) { @@ -575,7 +671,7 @@ namespace { inline bool ClassifiedAsOther( const std::vector<ClassificationResult>& classification) { return !classification.empty() && - classification[0].collection == TextClassifier::kOtherCollection; + classification[0].collection == Annotator::kOtherCollection; } float GetPriorityScore( @@ -588,11 +684,12 @@ float GetPriorityScore( } } // namespace -bool TextClassifier::ResolveConflict( - const std::string& context, const std::vector<Token>& cached_tokens, - const std::vector<AnnotatedSpan>& candidates, int start_index, - int end_index, InterpreterManager* interpreter_manager, - std::vector<int>* chosen_indices) const { +bool Annotator::ResolveConflict(const std::string& context, + const std::vector<Token>& cached_tokens, + const std::vector<AnnotatedSpan>& candidates, + int start_index, int end_index, + InterpreterManager* interpreter_manager, + std::vector<int>* chosen_indices) const { std::vector<int> conflicting_indices; std::unordered_map<int, float> scores; for (int i = start_index; i < end_index; ++i) { @@ -645,7 +742,7 @@ bool TextClassifier::ResolveConflict( return true; } -bool TextClassifier::ModelSuggestSelection( +bool Annotator::ModelSuggestSelection( const UnicodeText& context_unicode, CodepointSpan click_indices, InterpreterManager* interpreter_manager, std::vector<Token>* tokens, std::vector<AnnotatedSpan>* result) const { @@ -661,7 +758,7 @@ bool TextClassifier::ModelSuggestSelection( selection_feature_processor_->GetOptions()->only_use_line_with_click(), tokens, &click_pos); if (click_pos == kInvalidIndex) { - TC_VLOG(1) << "Could not calculate the click position."; + TC3_VLOG(1) << "Could not calculate the click position."; return false; } @@ -719,7 +816,7 @@ bool TextClassifier::ModelSuggestSelection( selection_feature_processor_->EmbeddingSize() + selection_feature_processor_->DenseFeaturesCount(), &cached_features)) { - TC_LOG(ERROR) << "Could not extract features."; + TC3_LOG(ERROR) << "Could not extract features."; return false; } @@ -728,7 +825,7 @@ bool TextClassifier::ModelSuggestSelection( if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span, interpreter_manager->SelectionInterpreter(), *cached_features, &chunks)) { - TC_LOG(ERROR) << "Could not chunk."; + TC3_LOG(ERROR) << "Could not chunk."; return false; } @@ -749,7 +846,7 @@ bool TextClassifier::ModelSuggestSelection( return true; } -bool TextClassifier::ModelClassifyText( +bool Annotator::ModelClassifyText( const std::string& context, CodepointSpan selection_indices, InterpreterManager* interpreter_manager, FeatureProcessor::EmbeddingCache* embedding_cache, @@ -796,7 +893,7 @@ std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens, } } // namespace internal -TokenSpan TextClassifier::ClassifyTextUpperBoundNeededTokens() const { +TokenSpan Annotator::ClassifyTextUpperBoundNeededTokens() const { const FeatureProcessorOptions_::BoundsSensitiveFeatures* bounds_sensitive_features = classification_feature_processor_->GetOptions() @@ -815,7 +912,7 @@ TokenSpan TextClassifier::ClassifyTextUpperBoundNeededTokens() const { } } -bool TextClassifier::ModelClassifyText( +bool Annotator::ModelClassifyText( const std::string& context, const std::vector<Token>& cached_tokens, CodepointSpan selection_indices, InterpreterManager* interpreter_manager, FeatureProcessor::EmbeddingCache* embedding_cache, @@ -850,7 +947,7 @@ bool TextClassifier::ModelClassifyText( ->bounds_sensitive_features(); if (selection_token_span.first == kInvalidIndex || selection_token_span.second == kInvalidIndex) { - TC_LOG(ERROR) << "Could not determine span."; + TC3_LOG(ERROR) << "Could not determine span."; return false; } @@ -865,7 +962,7 @@ bool TextClassifier::ModelClassifyText( /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after()); } else { if (click_pos == kInvalidIndex) { - TC_LOG(ERROR) << "Couldn't choose a click position."; + TC3_LOG(ERROR) << "Couldn't choose a click position."; return false; } // The extraction span is the clicked token with context_size tokens on @@ -891,7 +988,7 @@ bool TextClassifier::ModelClassifyText( classification_feature_processor_->EmbeddingSize() + classification_feature_processor_->DenseFeaturesCount(), &cached_features)) { - TC_LOG(ERROR) << "Could not extract features."; + TC3_LOG(ERROR) << "Could not extract features."; return false; } @@ -909,13 +1006,13 @@ bool TextClassifier::ModelClassifyText( {1, static_cast<int>(features.size())}), interpreter_manager->ClassificationInterpreter()); if (!logits.is_valid()) { - TC_LOG(ERROR) << "Couldn't compute logits."; + TC3_LOG(ERROR) << "Couldn't compute logits."; return false; } if (logits.dims() != 2 || logits.dim(0) != 1 || logits.dim(1) != classification_feature_processor_->NumCollections()) { - TC_LOG(ERROR) << "Mismatching output"; + TC3_LOG(ERROR) << "Mismatching output"; return false; } @@ -956,7 +1053,7 @@ bool TextClassifier::ModelClassifyText( return true; } -bool TextClassifier::RegexClassifyText( +bool Annotator::RegexClassifyText( const std::string& context, CodepointSpan selection_indices, ClassificationResult* classification_result) const { const std::string selection_text = @@ -980,21 +1077,22 @@ bool TextClassifier::RegexClassifyText( if (status != UniLib::RegexMatcher::kNoError) { return false; } - if (matches) { + if (matches && + VerifyCandidate(regex_pattern.verification_options, selection_text)) { *classification_result = {regex_pattern.collection_name, regex_pattern.target_classification_score, regex_pattern.priority_score}; return true; } if (status != UniLib::RegexMatcher::kNoError) { - TC_LOG(ERROR) << "Cound't match regex: " << pattern_id; + TC3_LOG(ERROR) << "Cound't match regex: " << pattern_id; } } return false; } -bool TextClassifier::DatetimeClassifyText( +bool Annotator::DatetimeClassifyText( const std::string& context, CodepointSpan selection_indices, const ClassificationOptions& options, ClassificationResult* classification_result) const { @@ -1010,7 +1108,7 @@ bool TextClassifier::DatetimeClassifyText( options.reference_timezone, options.locales, ModeFlag_CLASSIFICATION, /*anchor_start_end=*/true, &datetime_spans)) { - TC_LOG(ERROR) << "Error during parsing datetime."; + TC3_LOG(ERROR) << "Error during parsing datetime."; return false; } for (const DatetimeParseResultSpan& datetime_span : datetime_spans) { @@ -1028,11 +1126,11 @@ bool TextClassifier::DatetimeClassifyText( return false; } -std::vector<ClassificationResult> TextClassifier::ClassifyText( +std::vector<ClassificationResult> Annotator::ClassifyText( const std::string& context, CodepointSpan selection_indices, const ClassificationOptions& options) const { if (!initialized_) { - TC_LOG(ERROR) << "Not initialized"; + TC3_LOG(ERROR) << "Not initialized"; return {}; } @@ -1045,12 +1143,23 @@ std::vector<ClassificationResult> TextClassifier::ClassifyText( } if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) { - TC_VLOG(1) << "Trying to run ClassifyText with invalid indices: " - << std::get<0>(selection_indices) << " " - << std::get<1>(selection_indices); + TC3_VLOG(1) << "Trying to run ClassifyText with invalid indices: " + << std::get<0>(selection_indices) << " " + << std::get<1>(selection_indices); return {}; } + // Try the knowledge engine. + ClassificationResult knowledge_result; + if (knowledge_engine_ && knowledge_engine_->ClassifyText( + context, selection_indices, &knowledge_result)) { + if (!FilteredForClassification(knowledge_result)) { + return {knowledge_result}; + } else { + return {{kOtherCollection, 1.0}}; + } + } + // Try the regular expression models. ClassificationResult regex_result; if (RegexClassifyText(context, selection_indices, ®ex_result)) { @@ -1091,10 +1200,10 @@ std::vector<ClassificationResult> TextClassifier::ClassifyText( return {}; } -bool TextClassifier::ModelAnnotate(const std::string& context, - InterpreterManager* interpreter_manager, - std::vector<Token>* tokens, - std::vector<AnnotatedSpan>* result) const { +bool Annotator::ModelAnnotate(const std::string& context, + InterpreterManager* interpreter_manager, + std::vector<Token>* tokens, + std::vector<AnnotatedSpan>* result) const { if (model_->triggering_options() == nullptr || !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) { return true; @@ -1142,7 +1251,7 @@ bool TextClassifier::ModelAnnotate(const std::string& context, selection_feature_processor_->EmbeddingSize() + selection_feature_processor_->DenseFeaturesCount(), &cached_features)) { - TC_LOG(ERROR) << "Could not extract features."; + TC3_LOG(ERROR) << "Could not extract features."; return false; } @@ -1150,7 +1259,7 @@ bool TextClassifier::ModelAnnotate(const std::string& context, if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span, interpreter_manager->SelectionInterpreter(), *cached_features, &local_chunks)) { - TC_LOG(ERROR) << "Could not chunk."; + TC3_LOG(ERROR) << "Could not chunk."; return false; } @@ -1166,9 +1275,9 @@ bool TextClassifier::ModelAnnotate(const std::string& context, if (!ModelClassifyText(line_str, *tokens, codepoint_span, interpreter_manager, &embedding_cache, &classification)) { - TC_LOG(ERROR) << "Could not classify text: " - << (codepoint_span.first + offset) << " " - << (codepoint_span.second + offset); + TC3_LOG(ERROR) << "Could not classify text: " + << (codepoint_span.first + offset) << " " + << (codepoint_span.second + offset); return false; } @@ -1187,21 +1296,20 @@ bool TextClassifier::ModelAnnotate(const std::string& context, return true; } -const FeatureProcessor* TextClassifier::SelectionFeatureProcessorForTests() - const { +const FeatureProcessor* Annotator::SelectionFeatureProcessorForTests() const { return selection_feature_processor_.get(); } -const FeatureProcessor* TextClassifier::ClassificationFeatureProcessorForTests() +const FeatureProcessor* Annotator::ClassificationFeatureProcessorForTests() const { return classification_feature_processor_.get(); } -const DatetimeParser* TextClassifier::DatetimeParserForTests() const { +const DatetimeParser* Annotator::DatetimeParserForTests() const { return datetime_parser_.get(); } -std::vector<AnnotatedSpan> TextClassifier::Annotate( +std::vector<AnnotatedSpan> Annotator::Annotate( const std::string& context, const AnnotationOptions& options) const { std::vector<AnnotatedSpan> candidates; @@ -1218,14 +1326,14 @@ std::vector<AnnotatedSpan> TextClassifier::Annotate( // Annotate with the selection model. std::vector<Token> tokens; if (!ModelAnnotate(context, &interpreter_manager, &tokens, &candidates)) { - TC_LOG(ERROR) << "Couldn't run ModelAnnotate."; + TC3_LOG(ERROR) << "Couldn't run ModelAnnotate."; return {}; } // Annotate with the regular expression models. if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false), annotation_regex_patterns_, &candidates)) { - TC_LOG(ERROR) << "Couldn't run RegexChunk."; + TC3_LOG(ERROR) << "Couldn't run RegexChunk."; return {}; } @@ -1233,7 +1341,13 @@ std::vector<AnnotatedSpan> TextClassifier::Annotate( if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false), options.reference_time_ms_utc, options.reference_timezone, options.locales, ModeFlag_ANNOTATION, &candidates)) { - TC_LOG(ERROR) << "Couldn't run RegexChunk."; + TC3_LOG(ERROR) << "Couldn't run RegexChunk."; + return {}; + } + + // Annotate with the knowledge engine. + if (knowledge_engine_ && !knowledge_engine_->Chunk(context, &candidates)) { + TC3_LOG(ERROR) << "Couldn't run knowledge engine Chunk."; return {}; } @@ -1248,7 +1362,7 @@ std::vector<AnnotatedSpan> TextClassifier::Annotate( std::vector<int> candidate_indices; if (!ResolveConflicts(candidates, context, tokens, &interpreter_manager, &candidate_indices)) { - TC_LOG(ERROR) << "Couldn't resolve conflicts."; + TC3_LOG(ERROR) << "Couldn't resolve conflicts."; return {}; } @@ -1265,20 +1379,26 @@ std::vector<AnnotatedSpan> TextClassifier::Annotate( return result; } -bool TextClassifier::RegexChunk(const UnicodeText& context_unicode, - const std::vector<int>& rules, - std::vector<AnnotatedSpan>* result) const { +bool Annotator::RegexChunk(const UnicodeText& context_unicode, + const std::vector<int>& rules, + std::vector<AnnotatedSpan>* result) const { for (int pattern_id : rules) { const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id]; const auto matcher = regex_pattern.pattern->Matcher(context_unicode); if (!matcher) { - TC_LOG(ERROR) << "Could not get regex matcher for pattern: " - << pattern_id; + TC3_LOG(ERROR) << "Could not get regex matcher for pattern: " + << pattern_id; return false; } int status = UniLib::RegexMatcher::kNoError; while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) { + if (regex_pattern.verification_options) { + if (!VerifyCandidate(regex_pattern.verification_options, + matcher->Group(1, &status).ToUTF8String())) { + continue; + } + } result->emplace_back(); // Selection/annotation regular expressions need to specify a capturing // group specifying the selection. @@ -1293,11 +1413,10 @@ bool TextClassifier::RegexChunk(const UnicodeText& context_unicode, return true; } -bool TextClassifier::ModelChunk(int num_tokens, - const TokenSpan& span_of_interest, - tflite::Interpreter* selection_interpreter, - const CachedFeatures& cached_features, - std::vector<TokenSpan>* chunks) const { +bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest, + tflite::Interpreter* selection_interpreter, + const CachedFeatures& cached_features, + std::vector<TokenSpan>* chunks) const { const int max_selection_span = selection_feature_processor_->GetOptions()->max_selection_span(); // The inference span is the span of interest expanded to include @@ -1378,7 +1497,7 @@ void UpdateMax(Map* map, typename Map::key_type key, } } // namespace -bool TextClassifier::ModelClickContextScoreChunks( +bool Annotator::ModelClickContextScoreChunks( int num_tokens, const TokenSpan& span_of_interest, const CachedFeatures& cached_features, tflite::Interpreter* selection_interpreter, @@ -1407,13 +1526,13 @@ bool TextClassifier::ModelClickContextScoreChunks( TensorView<float>(all_features.data(), {batch_size, features_size}), selection_interpreter); if (!logits.is_valid()) { - TC_LOG(ERROR) << "Couldn't compute logits."; + TC3_LOG(ERROR) << "Couldn't compute logits."; return false; } if (logits.dims() != 2 || logits.dim(0) != batch_size || logits.dim(1) != selection_feature_processor_->GetSelectionLabelCount()) { - TC_LOG(ERROR) << "Mismatching output."; + TC3_LOG(ERROR) << "Mismatching output."; return false; } @@ -1427,7 +1546,7 @@ bool TextClassifier::ModelClickContextScoreChunks( TokenSpan relative_token_span; if (!selection_feature_processor_->LabelToTokenSpan( j, &relative_token_span)) { - TC_LOG(ERROR) << "Couldn't map the label to a token span."; + TC3_LOG(ERROR) << "Couldn't map the label to a token span."; return false; } const TokenSpan candidate_span = ExpandTokenSpan( @@ -1449,7 +1568,7 @@ bool TextClassifier::ModelClickContextScoreChunks( return true; } -bool TextClassifier::ModelBoundsSensitiveScoreChunks( +bool Annotator::ModelBoundsSensitiveScoreChunks( int num_tokens, const TokenSpan& span_of_interest, const TokenSpan& inference_span, const CachedFeatures& cached_features, tflite::Interpreter* selection_interpreter, @@ -1518,12 +1637,12 @@ bool TextClassifier::ModelBoundsSensitiveScoreChunks( TensorView<float>(all_features.data(), {batch_size, features_size}), selection_interpreter); if (!logits.is_valid()) { - TC_LOG(ERROR) << "Couldn't compute logits."; + TC3_LOG(ERROR) << "Couldn't compute logits."; return false; } if (logits.dims() != 2 || logits.dim(0) != batch_size || logits.dim(1) != 1) { - TC_LOG(ERROR) << "Mismatching output."; + TC3_LOG(ERROR) << "Mismatching output."; return false; } @@ -1537,11 +1656,11 @@ bool TextClassifier::ModelBoundsSensitiveScoreChunks( return true; } -bool TextClassifier::DatetimeChunk(const UnicodeText& context_unicode, - int64 reference_time_ms_utc, - const std::string& reference_timezone, - const std::string& locales, ModeFlag mode, - std::vector<AnnotatedSpan>* result) const { +bool Annotator::DatetimeChunk(const UnicodeText& context_unicode, + int64 reference_time_ms_utc, + const std::string& reference_timezone, + const std::string& locales, ModeFlag mode, + std::vector<AnnotatedSpan>* result) const { if (!datetime_parser_) { return true; } @@ -1573,4 +1692,4 @@ const Model* ViewModel(const void* buffer, int size) { return LoadAndVerifyModel(buffer, size); } -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/text-classifier.h b/annotator/annotator.h index 0692ecd..c58c03d 100644 --- a/text-classifier.h +++ b/annotator/annotator.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,25 +16,27 @@ // Inference code for the text classification model. -#ifndef LIBTEXTCLASSIFIER_TEXT_CLASSIFIER_H_ -#define LIBTEXTCLASSIFIER_TEXT_CLASSIFIER_H_ +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_H_ #include <memory> #include <set> #include <string> #include <vector> -#include "datetime/parser.h" -#include "feature-processor.h" -#include "model-executor.h" -#include "model_generated.h" -#include "strip-unpaired-brackets.h" -#include "types.h" -#include "util/memory/mmap.h" -#include "util/utf8/unilib.h" -#include "zlib-utils.h" +#include "annotator/datetime/parser.h" +#include "annotator/feature-processor.h" +#include "annotator/knowledge/knowledge-engine.h" +#include "annotator/model-executor.h" +#include "annotator/model_generated.h" +#include "annotator/strip-unpaired-brackets.h" +#include "annotator/types.h" +#include "annotator/zlib-utils.h" +#include "utils/memory/mmap.h" +#include "utils/utf8/unilib.h" +#include "utils/zlib/zlib.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { struct SelectionOptions { // Comma-separated list of locale specification for the input text (BCP 47 @@ -106,23 +108,31 @@ class InterpreterManager { // A text processing model that provides text classification, annotation, // selection suggestion for various types. // NOTE: This class is not thread-safe. -class TextClassifier { +class Annotator { public: - static std::unique_ptr<TextClassifier> FromUnownedBuffer( - const char* buffer, int size, const UniLib* unilib = nullptr); + static std::unique_ptr<Annotator> FromUnownedBuffer( + const char* buffer, int size, const UniLib* unilib = nullptr, + const CalendarLib* calendarlib = nullptr); // Takes ownership of the mmap. - static std::unique_ptr<TextClassifier> FromScopedMmap( - std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib = nullptr); - static std::unique_ptr<TextClassifier> FromFileDescriptor( - int fd, int offset, int size, const UniLib* unilib = nullptr); - static std::unique_ptr<TextClassifier> FromFileDescriptor( - int fd, const UniLib* unilib = nullptr); - static std::unique_ptr<TextClassifier> FromPath( - const std::string& path, const UniLib* unilib = nullptr); + static std::unique_ptr<Annotator> FromScopedMmap( + std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib = nullptr, + const CalendarLib* calendarlib = nullptr); + static std::unique_ptr<Annotator> FromFileDescriptor( + int fd, int offset, int size, const UniLib* unilib = nullptr, + const CalendarLib* calendarlib = nullptr); + static std::unique_ptr<Annotator> FromFileDescriptor( + int fd, const UniLib* unilib = nullptr, + const CalendarLib* calendarlib = nullptr); + static std::unique_ptr<Annotator> FromPath( + const std::string& path, const UniLib* unilib = nullptr, + const CalendarLib* calendarlib = nullptr); // Returns true if the model is ready for use. bool IsInitialized() { return initialized_; } + // Initializes the knowledge engine with the given config. + bool InitializeKnowledgeEngine(const std::string& serialized_config); + // Runs inference for given a context and current selection (i.e. index // of the first and one past last selected characters (utf8 codepoint // offsets)). Returns the indices (utf8 codepoint offsets) of the selection @@ -160,6 +170,13 @@ class TextClassifier { static const std::string& kPhoneCollection; static const std::string& kAddressCollection; static const std::string& kDateCollection; + static const std::string& kUrlCollection; + static const std::string& kFlightCollection; + static const std::string& kEmailCollection; + static const std::string& kIbanCollection; + static const std::string& kPaymentCardCollection; + static const std::string& kIsbnCollection; + static const std::string& kTrackingNumberCollection; protected: struct ScoredChunk { @@ -169,23 +186,13 @@ class TextClassifier { // Constructs and initializes text classifier from given model. // Takes ownership of 'mmap', and thus owns the buffer that backs 'model'. - TextClassifier(std::unique_ptr<ScopedMmap>* mmap, const Model* model, - const UniLib* unilib) - : model_(model), - mmap_(std::move(*mmap)), - owned_unilib_(nullptr), - unilib_(internal::MaybeCreateUnilib(unilib, &owned_unilib_)) { - ValidateAndInitialize(); - } + Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model, + const UniLib* unilib, const CalendarLib* calendarlib); // Constructs, validates and initializes text classifier from given model. // Does not own the buffer that backs 'model'. - explicit TextClassifier(const Model* model, const UniLib* unilib) - : model_(model), - owned_unilib_(nullptr), - unilib_(internal::MaybeCreateUnilib(unilib, &owned_unilib_)) { - ValidateAndInitialize(); - } + explicit Annotator(const Model* model, const UniLib* unilib, + const CalendarLib* calendarlib); // Checks that model contains all required fields, and initializes internal // datastructures. @@ -334,6 +341,7 @@ class TextClassifier { float target_classification_score; float priority_score; std::unique_ptr<UniLib::RegexPattern> pattern; + const VerificationOptions* verification_options; }; std::unique_ptr<ScopedMmap> mmap_; @@ -354,6 +362,10 @@ class TextClassifier { std::unique_ptr<UniLib> owned_unilib_; const UniLib* unilib_; + std::unique_ptr<CalendarLib> owned_calendarlib_; + const CalendarLib* calendarlib_; + + std::unique_ptr<const KnowledgeEngine> knowledge_engine_; }; namespace internal { @@ -376,6 +388,6 @@ std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens, // Interprets the buffer as a Model flatbuffer and returns it for reading. const Model* ViewModel(const void* buffer, int size); -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_TEXT_CLASSIFIER_H_ +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_H_ diff --git a/annotator/annotator_jni.cc b/annotator/annotator_jni.cc new file mode 100644 index 0000000..9bda35a --- /dev/null +++ b/annotator/annotator_jni.cc @@ -0,0 +1,434 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// JNI wrapper for the Annotator. + +#include "annotator/annotator_jni.h" + +#include <jni.h> +#include <type_traits> +#include <vector> + +#include "annotator/annotator.h" +#include "annotator/annotator_jni_common.h" +#include "utils/base/integral_types.h" +#include "utils/calendar/calendar.h" +#include "utils/java/scoped_local_ref.h" +#include "utils/java/string_utils.h" +#include "utils/memory/mmap.h" +#include "utils/utf8/unilib.h" + +#ifdef TC3_UNILIB_JAVAICU +#ifndef TC3_CALENDAR_JAVAICU +#error Inconsistent usage of Java ICU components +#else +#define TC3_USE_JAVAICU +#endif +#endif + +using libtextclassifier3::AnnotatedSpan; +using libtextclassifier3::Annotator; +using libtextclassifier3::ClassificationResult; +using libtextclassifier3::CodepointSpan; +using libtextclassifier3::Model; +using libtextclassifier3::ScopedLocalRef; +// When using the Java's ICU, CalendarLib and UniLib need to be instantiated +// with a JavaVM pointer from JNI. When using a standard ICU the pointer is +// not needed and the objects are instantiated implicitly. +#ifdef TC3_USE_JAVAICU +using libtextclassifier3::CalendarLib; +using libtextclassifier3::UniLib; +#endif + +namespace libtextclassifier3 { + +using libtextclassifier3::CodepointSpan; + +namespace { + +jobjectArray ClassificationResultsToJObjectArray( + JNIEnv* env, + const std::vector<ClassificationResult>& classification_result) { + const ScopedLocalRef<jclass> result_class( + env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR + "$ClassificationResult"), + env); + if (!result_class) { + TC3_LOG(ERROR) << "Couldn't find ClassificationResult class."; + return nullptr; + } + const ScopedLocalRef<jclass> datetime_parse_class( + env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR + "$DatetimeResult"), + env); + if (!datetime_parse_class) { + TC3_LOG(ERROR) << "Couldn't find DatetimeResult class."; + return nullptr; + } + + const jmethodID result_class_constructor = env->GetMethodID( + result_class.get(), "<init>", + "(Ljava/lang/String;FL" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR + "$DatetimeResult;[B)V"); + const jmethodID datetime_parse_class_constructor = + env->GetMethodID(datetime_parse_class.get(), "<init>", "(JI)V"); + + const jobjectArray results = env->NewObjectArray(classification_result.size(), + result_class.get(), nullptr); + for (int i = 0; i < classification_result.size(); i++) { + jstring row_string = + env->NewStringUTF(classification_result[i].collection.c_str()); + + jobject row_datetime_parse = nullptr; + if (classification_result[i].datetime_parse_result.IsSet()) { + row_datetime_parse = env->NewObject( + datetime_parse_class.get(), datetime_parse_class_constructor, + classification_result[i].datetime_parse_result.time_ms_utc, + classification_result[i].datetime_parse_result.granularity); + } + + jbyteArray serialized_knowledge_result = nullptr; + const std::string& serialized_knowledge_result_string = + classification_result[i].serialized_knowledge_result; + if (!serialized_knowledge_result_string.empty()) { + serialized_knowledge_result = + env->NewByteArray(serialized_knowledge_result_string.size()); + env->SetByteArrayRegion(serialized_knowledge_result, 0, + serialized_knowledge_result_string.size(), + reinterpret_cast<const jbyte*>( + serialized_knowledge_result_string.data())); + } + + jobject result = + env->NewObject(result_class.get(), result_class_constructor, row_string, + static_cast<jfloat>(classification_result[i].score), + row_datetime_parse, serialized_knowledge_result); + env->SetObjectArrayElement(results, i, result); + env->DeleteLocalRef(result); + } + return results; +} + +CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str, + CodepointSpan orig_indices, + bool from_utf8) { + const libtextclassifier3::UnicodeText unicode_str = + libtextclassifier3::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false); + + int unicode_index = 0; + int bmp_index = 0; + + const int* source_index; + const int* target_index; + if (from_utf8) { + source_index = &unicode_index; + target_index = &bmp_index; + } else { + source_index = &bmp_index; + target_index = &unicode_index; + } + + CodepointSpan result{-1, -1}; + std::function<void()> assign_indices_fn = [&result, &orig_indices, + &source_index, &target_index]() { + if (orig_indices.first == *source_index) { + result.first = *target_index; + } + + if (orig_indices.second == *source_index) { + result.second = *target_index; + } + }; + + for (auto it = unicode_str.begin(); it != unicode_str.end(); + ++it, ++unicode_index, ++bmp_index) { + assign_indices_fn(); + + // There is 1 extra character in the input for each UTF8 character > 0xFFFF. + if (*it > 0xFFFF) { + ++bmp_index; + } + } + assign_indices_fn(); + + return result; +} + +} // namespace + +CodepointSpan ConvertIndicesBMPToUTF8(const std::string& utf8_str, + CodepointSpan bmp_indices) { + return ConvertIndicesBMPUTF8(utf8_str, bmp_indices, /*from_utf8=*/false); +} + +CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str, + CodepointSpan utf8_indices) { + return ConvertIndicesBMPUTF8(utf8_str, utf8_indices, /*from_utf8=*/true); +} + +jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) { + if (!mmap->handle().ok()) { + return env->NewStringUTF(""); + } + const Model* model = libtextclassifier3::ViewModel( + mmap->handle().start(), mmap->handle().num_bytes()); + if (!model || !model->locales()) { + return env->NewStringUTF(""); + } + return env->NewStringUTF(model->locales()->c_str()); +} + +jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) { + if (!mmap->handle().ok()) { + return 0; + } + const Model* model = libtextclassifier3::ViewModel( + mmap->handle().start(), mmap->handle().num_bytes()); + if (!model) { + return 0; + } + return model->version(); +} + +jstring GetNameFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) { + if (!mmap->handle().ok()) { + return env->NewStringUTF(""); + } + const Model* model = libtextclassifier3::ViewModel( + mmap->handle().start(), mmap->handle().num_bytes()); + if (!model || !model->name()) { + return env->NewStringUTF(""); + } + return env->NewStringUTF(model->name()->c_str()); +} + +} // namespace libtextclassifier3 + +using libtextclassifier3::ClassificationResultsToJObjectArray; +using libtextclassifier3::ConvertIndicesBMPToUTF8; +using libtextclassifier3::ConvertIndicesUTF8ToBMP; +using libtextclassifier3::FromJavaAnnotationOptions; +using libtextclassifier3::FromJavaClassificationOptions; +using libtextclassifier3::FromJavaSelectionOptions; +using libtextclassifier3::ToStlString; + +TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotator) +(JNIEnv* env, jobject thiz, jint fd) { +#ifdef TC3_USE_JAVAICU + std::shared_ptr<libtextclassifier3::JniCache> jni_cache( + libtextclassifier3::JniCache::Create(env)); + return reinterpret_cast<jlong>( + Annotator::FromFileDescriptor(fd, new UniLib(jni_cache), + new CalendarLib(jni_cache)) + .release()); +#else + return reinterpret_cast<jlong>(Annotator::FromFileDescriptor(fd).release()); +#endif +} + +TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorFromPath) +(JNIEnv* env, jobject thiz, jstring path) { + const std::string path_str = ToStlString(env, path); +#ifdef TC3_USE_JAVAICU + std::shared_ptr<libtextclassifier3::JniCache> jni_cache( + libtextclassifier3::JniCache::Create(env)); + return reinterpret_cast<jlong>(Annotator::FromPath(path_str, + new UniLib(jni_cache), + new CalendarLib(jni_cache)) + .release()); +#else + return reinterpret_cast<jlong>(Annotator::FromPath(path_str).release()); +#endif +} + +TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, + nativeNewAnnotatorFromAssetFileDescriptor) +(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { + const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd); +#ifdef TC3_USE_JAVAICU + std::shared_ptr<libtextclassifier3::JniCache> jni_cache( + libtextclassifier3::JniCache::Create(env)); + return reinterpret_cast<jlong>( + Annotator::FromFileDescriptor(fd, offset, size, new UniLib(jni_cache), + new CalendarLib(jni_cache)) + .release()); +#else + return reinterpret_cast<jlong>( + Annotator::FromFileDescriptor(fd, offset, size).release()); +#endif +} + +TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME, + nativeInitializeKnowledgeEngine) +(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config) { + if (!ptr) { + return false; + } + + Annotator* model = reinterpret_cast<Annotator*>(ptr); + + std::string serialized_config_string; + const int length = env->GetArrayLength(serialized_config); + serialized_config_string.resize(length); + env->GetByteArrayRegion(serialized_config, 0, length, + reinterpret_cast<jbyte*>(const_cast<char*>( + serialized_config_string.data()))); + + return model->InitializeKnowledgeEngine(serialized_config_string); +} + +TC3_JNI_METHOD(jintArray, TC3_ANNOTATOR_CLASS_NAME, nativeSuggestSelection) +(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin, + jint selection_end, jobject options) { + if (!ptr) { + return nullptr; + } + + Annotator* model = reinterpret_cast<Annotator*>(ptr); + + const std::string context_utf8 = ToStlString(env, context); + CodepointSpan input_indices = + ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end}); + CodepointSpan selection = model->SuggestSelection( + context_utf8, input_indices, FromJavaSelectionOptions(env, options)); + selection = ConvertIndicesUTF8ToBMP(context_utf8, selection); + + jintArray result = env->NewIntArray(2); + env->SetIntArrayRegion(result, 0, 1, &(std::get<0>(selection))); + env->SetIntArrayRegion(result, 1, 1, &(std::get<1>(selection))); + return result; +} + +TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeClassifyText) +(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin, + jint selection_end, jobject options) { + if (!ptr) { + return nullptr; + } + Annotator* ff_model = reinterpret_cast<Annotator*>(ptr); + + const std::string context_utf8 = ToStlString(env, context); + const CodepointSpan input_indices = + ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end}); + const std::vector<ClassificationResult> classification_result = + ff_model->ClassifyText(context_utf8, input_indices, + FromJavaClassificationOptions(env, options)); + + return ClassificationResultsToJObjectArray(env, classification_result); +} + +TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotate) +(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options) { + if (!ptr) { + return nullptr; + } + Annotator* model = reinterpret_cast<Annotator*>(ptr); + std::string context_utf8 = ToStlString(env, context); + std::vector<AnnotatedSpan> annotations = + model->Annotate(context_utf8, FromJavaAnnotationOptions(env, options)); + + jclass result_class = env->FindClass( + TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotatedSpan"); + if (!result_class) { + TC3_LOG(ERROR) << "Couldn't find result class: " + << TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR + "$AnnotatedSpan"; + return nullptr; + } + + jmethodID result_class_constructor = + env->GetMethodID(result_class, "<init>", + "(II[L" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR + "$ClassificationResult;)V"); + + jobjectArray results = + env->NewObjectArray(annotations.size(), result_class, nullptr); + + for (int i = 0; i < annotations.size(); ++i) { + CodepointSpan span_bmp = + ConvertIndicesUTF8ToBMP(context_utf8, annotations[i].span); + jobject result = env->NewObject(result_class, result_class_constructor, + static_cast<jint>(span_bmp.first), + static_cast<jint>(span_bmp.second), + ClassificationResultsToJObjectArray( + env, annotations[i].classification)); + env->SetObjectArrayElement(results, i, result); + env->DeleteLocalRef(result); + } + env->DeleteLocalRef(result_class); + return results; +} + +TC3_JNI_METHOD(void, TC3_ANNOTATOR_CLASS_NAME, nativeCloseAnnotator) +(JNIEnv* env, jobject thiz, jlong ptr) { + Annotator* model = reinterpret_cast<Annotator*>(ptr); + delete model; +} + +TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLanguage) +(JNIEnv* env, jobject clazz, jint fd) { + TC3_LOG(WARNING) << "Using deprecated getLanguage()."; + return TC3_JNI_METHOD_NAME(TC3_ANNOTATOR_CLASS_NAME, nativeGetLocales)( + env, clazz, fd); +} + +TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocales) +(JNIEnv* env, jobject clazz, jint fd) { + const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( + new libtextclassifier3::ScopedMmap(fd)); + return GetLocalesFromMmap(env, mmap.get()); +} + +TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, + nativeGetLocalesFromAssetFileDescriptor) +(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { + const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd); + const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( + new libtextclassifier3::ScopedMmap(fd, offset, size)); + return GetLocalesFromMmap(env, mmap.get()); +} + +TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersion) +(JNIEnv* env, jobject clazz, jint fd) { + const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( + new libtextclassifier3::ScopedMmap(fd)); + return GetVersionFromMmap(env, mmap.get()); +} + +TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, + nativeGetVersionFromAssetFileDescriptor) +(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { + const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd); + const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( + new libtextclassifier3::ScopedMmap(fd, offset, size)); + return GetVersionFromMmap(env, mmap.get()); +} + +TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetName) +(JNIEnv* env, jobject clazz, jint fd) { + const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( + new libtextclassifier3::ScopedMmap(fd)); + return GetNameFromMmap(env, mmap.get()); +} + +TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, + nativeGetNameFromAssetFileDescriptor) +(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { + const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd); + const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( + new libtextclassifier3::ScopedMmap(fd, offset, size)); + return GetNameFromMmap(env, mmap.get()); +} diff --git a/annotator/annotator_jni.h b/annotator/annotator_jni.h new file mode 100644 index 0000000..47715b4 --- /dev/null +++ b/annotator/annotator_jni.h @@ -0,0 +1,103 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_H_ + +#include <jni.h> +#include <string> +#include "annotator/annotator_jni_common.h" +#include "annotator/types.h" +#include "utils/java/jni-base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// SmartSelection. +TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotator) +(JNIEnv* env, jobject thiz, jint fd); + +TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorFromPath) +(JNIEnv* env, jobject thiz, jstring path); + +TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, + nativeNewAnnotatorFromAssetFileDescriptor) +(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size); + +TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME, + nativeInitializeKnowledgeEngine) +(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config); + +TC3_JNI_METHOD(jintArray, TC3_ANNOTATOR_CLASS_NAME, nativeSuggestSelection) +(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin, + jint selection_end, jobject options); + +TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeClassifyText) +(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin, + jint selection_end, jobject options); + +TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotate) +(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options); + +TC3_JNI_METHOD(void, TC3_ANNOTATOR_CLASS_NAME, nativeCloseAnnotator) +(JNIEnv* env, jobject thiz, jlong ptr); + +// DEPRECATED. Use nativeGetLocales instead. +TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLanguage) +(JNIEnv* env, jobject clazz, jint fd); + +TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocales) +(JNIEnv* env, jobject clazz, jint fd); + +TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, + nativeGetLocalesFromAssetFileDescriptor) +(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size); + +TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersion) +(JNIEnv* env, jobject clazz, jint fd); + +TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, + nativeGetVersionFromAssetFileDescriptor) +(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size); + +TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetName) +(JNIEnv* env, jobject clazz, jint fd); + +TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, + nativeGetNameFromAssetFileDescriptor) +(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size); + +#ifdef __cplusplus +} +#endif + +namespace libtextclassifier3 { + +// Given a utf8 string and a span expressed in Java BMP (basic multilingual +// plane) codepoints, converts it to a span expressed in utf8 codepoints. +libtextclassifier3::CodepointSpan ConvertIndicesBMPToUTF8( + const std::string& utf8_str, libtextclassifier3::CodepointSpan bmp_indices); + +// Given a utf8 string and a span expressed in utf8 codepoints, converts it to a +// span expressed in Java BMP (basic multilingual plane) codepoints. +libtextclassifier3::CodepointSpan ConvertIndicesUTF8ToBMP( + const std::string& utf8_str, + libtextclassifier3::CodepointSpan utf8_indices); + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_H_ diff --git a/annotator/annotator_jni_common.cc b/annotator/annotator_jni_common.cc new file mode 100644 index 0000000..0fdb87b --- /dev/null +++ b/annotator/annotator_jni_common.cc @@ -0,0 +1,100 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "annotator/annotator_jni_common.h" + +#include "utils/java/jni-base.h" +#include "utils/java/scoped_local_ref.h" + +namespace libtextclassifier3 { +namespace { +template <typename T> +T FromJavaOptionsInternal(JNIEnv* env, jobject joptions, + const std::string& class_name) { + if (!joptions) { + return {}; + } + + const ScopedLocalRef<jclass> options_class(env->FindClass(class_name.c_str()), + env); + if (!options_class) { + return {}; + } + + const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>( + env, joptions, options_class.get(), &JNIEnv::CallObjectMethod, + "getLocale", "Ljava/lang/String;"); + const std::pair<bool, jobject> status_or_reference_timezone = + CallJniMethod0<jobject>(env, joptions, options_class.get(), + &JNIEnv::CallObjectMethod, "getReferenceTimezone", + "Ljava/lang/String;"); + const std::pair<bool, int64> status_or_reference_time_ms_utc = + CallJniMethod0<int64>(env, joptions, options_class.get(), + &JNIEnv::CallLongMethod, "getReferenceTimeMsUtc", + "J"); + + if (!status_or_locales.first || !status_or_reference_timezone.first || + !status_or_reference_time_ms_utc.first) { + return {}; + } + + T options; + options.locales = + ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second)); + options.reference_timezone = ToStlString( + env, reinterpret_cast<jstring>(status_or_reference_timezone.second)); + options.reference_time_ms_utc = status_or_reference_time_ms_utc.second; + return options; +} +} // namespace + +SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions) { + if (!joptions) { + return {}; + } + + const ScopedLocalRef<jclass> options_class( + env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR + "$SelectionOptions"), + env); + const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>( + env, joptions, options_class.get(), &JNIEnv::CallObjectMethod, + "getLocales", "Ljava/lang/String;"); + if (!status_or_locales.first) { + return {}; + } + + SelectionOptions options; + options.locales = + ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second)); + + return options; +} + +ClassificationOptions FromJavaClassificationOptions(JNIEnv* env, + jobject joptions) { + return FromJavaOptionsInternal<ClassificationOptions>( + env, joptions, + TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$ClassificationOptions"); +} + +AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions) { + return FromJavaOptionsInternal<AnnotationOptions>( + env, joptions, + TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotationOptions"); +} + +} // namespace libtextclassifier3 diff --git a/annotator/annotator_jni_common.h b/annotator/annotator_jni_common.h new file mode 100644 index 0000000..b62bb21 --- /dev/null +++ b/annotator/annotator_jni_common.h @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_COMMON_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_COMMON_H_ + +#include <jni.h> + +#include "annotator/annotator.h" + +#ifndef TC3_ANNOTATOR_CLASS_NAME +#define TC3_ANNOTATOR_CLASS_NAME AnnotatorModel +#endif + +#define TC3_ANNOTATOR_CLASS_NAME_STR TC3_ADD_QUOTES(TC3_ANNOTATOR_CLASS_NAME) + +namespace libtextclassifier3 { + +SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions); + +ClassificationOptions FromJavaClassificationOptions(JNIEnv* env, + jobject joptions); + +AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions); + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_COMMON_H_ diff --git a/textclassifier_jni_test.cc b/annotator/annotator_jni_test.cc index 87b96fa..929fb59 100644 --- a/textclassifier_jni_test.cc +++ b/annotator/annotator_jni_test.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,15 +14,15 @@ * limitations under the License. */ -#include "textclassifier_jni.h" +#include "annotator/annotator_jni.h" #include "gmock/gmock.h" #include "gtest/gtest.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace { -TEST(TextClassifier, ConvertIndicesBMPUTF8) { +TEST(Annotator, ConvertIndicesBMPUTF8) { // Test boundary cases. EXPECT_EQ(ConvertIndicesBMPToUTF8("hello", {0, 5}), std::make_pair(0, 5)); EXPECT_EQ(ConvertIndicesUTF8ToBMP("hello", {0, 5}), std::make_pair(0, 5)); @@ -76,4 +76,4 @@ TEST(TextClassifier, ConvertIndicesBMPUTF8) { } } // namespace -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/text-classifier_test.cc b/annotator/annotator_test.cc index c8ced76..fbaf039 100644 --- a/text-classifier_test.cc +++ b/annotator/annotator_test.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,19 +14,19 @@ * limitations under the License. */ -#include "text-classifier.h" +#include "annotator/annotator.h" #include <fstream> #include <iostream> #include <memory> #include <string> -#include "model_generated.h" -#include "types-test-util.h" +#include "annotator/model_generated.h" +#include "annotator/types-test-util.h" #include "gmock/gmock.h" #include "gtest/gtest.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace { using testing::ElementsAreArray; @@ -52,27 +52,32 @@ std::string ReadFile(const std::string& file_name) { } std::string GetModelPath() { - return LIBTEXTCLASSIFIER_TEST_DATA_DIR; + return TC3_TEST_DATA_DIR; } -TEST(TextClassifierTest, EmbeddingExecutorLoadingFails) { - CREATE_UNILIB_FOR_TESTING; - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromPath(GetModelPath() + "wrong_embeddings.fb", &unilib); +class AnnotatorTest : public ::testing::TestWithParam<const char*> { + protected: + AnnotatorTest() + : INIT_UNILIB_FOR_TESTING(unilib_), + INIT_CALENDARLIB_FOR_TESTING(calendarlib_) {} + UniLib unilib_; + CalendarLib calendarlib_; +}; + +TEST_F(AnnotatorTest, EmbeddingExecutorLoadingFails) { + std::unique_ptr<Annotator> classifier = Annotator::FromPath( + GetModelPath() + "wrong_embeddings.fb", &unilib_, &calendarlib_); EXPECT_FALSE(classifier); } -class TextClassifierTest : public ::testing::TestWithParam<const char*> {}; - -INSTANTIATE_TEST_CASE_P(ClickContext, TextClassifierTest, +INSTANTIATE_TEST_CASE_P(ClickContext, AnnotatorTest, Values("test_model_cc.fb")); -INSTANTIATE_TEST_CASE_P(BoundsSensitive, TextClassifierTest, +INSTANTIATE_TEST_CASE_P(BoundsSensitive, AnnotatorTest, Values("test_model.fb")); -TEST_P(TextClassifierTest, ClassifyText) { - CREATE_UNILIB_FOR_TESTING; - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib); +TEST_P(AnnotatorTest, ClassifyText) { + std::unique_ptr<Annotator> classifier = + Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); EXPECT_EQ("other", @@ -109,8 +114,7 @@ TEST_P(TextClassifierTest, ClassifyText) { "\xf0\x9f\x98\x8b\x8b", {0, 0}))); } -TEST_P(TextClassifierTest, ClassifyTextDisabledFail) { - CREATE_UNILIB_FOR_TESTING; +TEST_P(AnnotatorTest, ClassifyTextDisabledFail) { const std::string test_model = ReadFile(GetModelPath() + GetParam()); std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); @@ -119,19 +123,17 @@ TEST_P(TextClassifierTest, ClassifyTextDisabledFail) { unpacked_model->triggering_options->enabled_modes = ModeFlag_SELECTION; flatbuffers::FlatBufferBuilder builder; - builder.Finish(Model::Pack(builder, unpacked_model.get())); + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromUnownedBuffer( - reinterpret_cast<const char*>(builder.GetBufferPointer()), - builder.GetSize(), &unilib); + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); // The classification model is still needed for selection scores. ASSERT_FALSE(classifier); } -TEST_P(TextClassifierTest, ClassifyTextDisabled) { - CREATE_UNILIB_FOR_TESTING; +TEST_P(AnnotatorTest, ClassifyTextDisabled) { const std::string test_model = ReadFile(GetModelPath() + GetParam()); std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); @@ -140,12 +142,11 @@ TEST_P(TextClassifierTest, ClassifyTextDisabled) { ModeFlag_ANNOTATION_AND_SELECTION; flatbuffers::FlatBufferBuilder builder; - builder.Finish(Model::Pack(builder, unpacked_model.get())); + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromUnownedBuffer( - reinterpret_cast<const char*>(builder.GetBufferPointer()), - builder.GetSize(), &unilib); + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); EXPECT_THAT( @@ -153,13 +154,11 @@ TEST_P(TextClassifierTest, ClassifyTextDisabled) { IsEmpty()); } -TEST_P(TextClassifierTest, ClassifyTextFilteredCollections) { - CREATE_UNILIB_FOR_TESTING; +TEST_P(AnnotatorTest, ClassifyTextFilteredCollections) { const std::string test_model = ReadFile(GetModelPath() + GetParam()); - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(), - &unilib); + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + test_model.c_str(), test_model.size(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); EXPECT_EQ("phone", FirstResult(classifier->ClassifyText( @@ -173,11 +172,11 @@ TEST_P(TextClassifierTest, ClassifyTextFilteredCollections) { "phone"); flatbuffers::FlatBufferBuilder builder; - builder.Finish(Model::Pack(builder, unpacked_model.get())); + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); - classifier = TextClassifier::FromUnownedBuffer( + classifier = Annotator::FromUnownedBuffer( reinterpret_cast<const char*>(builder.GetBufferPointer()), - builder.GetSize(), &unilib); + builder.GetSize(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); EXPECT_EQ("other", FirstResult(classifier->ClassifyText( @@ -206,9 +205,8 @@ std::unique_ptr<RegexModel_::PatternT> MakePattern( return result; } -#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU -TEST_P(TextClassifierTest, ClassifyTextRegularExpression) { - CREATE_UNILIB_FOR_TESTING; +#ifdef TC3_UNILIB_ICU +TEST_P(AnnotatorTest, ClassifyTextRegularExpression) { const std::string test_model = ReadFile(GetModelPath() + GetParam()); std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); @@ -219,14 +217,21 @@ TEST_P(TextClassifierTest, ClassifyTextRegularExpression) { unpacked_model->regex_model->patterns.push_back(MakePattern( "flight", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true, /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 0.5)); + std::unique_ptr<RegexModel_::PatternT> verified_pattern = + MakePattern("payment_card", "\\d{4}(?: \\d{4}){3}", + /*enabled_for_classification=*/true, + /*enabled_for_selection=*/false, + /*enabled_for_annotation=*/false, 1.0); + verified_pattern->verification_options.reset(new VerificationOptionsT); + verified_pattern->verification_options->verify_luhn_checksum = true; + unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern)); flatbuffers::FlatBufferBuilder builder; - builder.Finish(Model::Pack(builder, unpacked_model.get())); + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromUnownedBuffer( - reinterpret_cast<const char*>(builder.GetBufferPointer()), - builder.GetSize(), &unilib); + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); EXPECT_EQ("flight", @@ -246,6 +251,13 @@ TEST_P(TextClassifierTest, ClassifyTextRegularExpression) { EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("LX 37", {0, 5}))); EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("flight LX 37 abcd", {7, 12}))); + EXPECT_EQ("payment_card", FirstResult(classifier->ClassifyText( + "cc: 4012 8888 8888 1881", {4, 23}))); + EXPECT_EQ("payment_card", FirstResult(classifier->ClassifyText( + "2221 0067 4735 6281", {0, 19}))); + // Luhn check fails. + EXPECT_EQ("other", FirstResult(classifier->ClassifyText("2221 0067 4735 6282", + {0, 19}))); // More lines. EXPECT_EQ("url", @@ -254,11 +266,10 @@ TEST_P(TextClassifierTest, ClassifyTextRegularExpression) { "www.google.com every today!|Call me at (800) 123-456 today.", {51, 65}))); } -#endif // LIBTEXTCLASSIFIER_UNILIB_ICU +#endif // TC3_UNILIB_ICU -#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU -TEST_P(TextClassifierTest, SuggestSelectionRegularExpression) { - CREATE_UNILIB_FOR_TESTING; +#ifdef TC3_UNILIB_ICU +TEST_P(AnnotatorTest, SuggestSelectionRegularExpression) { const std::string test_model = ReadFile(GetModelPath() + GetParam()); std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); @@ -270,14 +281,21 @@ TEST_P(TextClassifierTest, SuggestSelectionRegularExpression) { "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false, /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0)); unpacked_model->regex_model->patterns.back()->priority_score = 1.1; + std::unique_ptr<RegexModel_::PatternT> verified_pattern = + MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})", + /*enabled_for_classification=*/false, + /*enabled_for_selection=*/true, + /*enabled_for_annotation=*/false, 1.0); + verified_pattern->verification_options.reset(new VerificationOptionsT); + verified_pattern->verification_options->verify_luhn_checksum = true; + unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern)); flatbuffers::FlatBufferBuilder builder; - builder.Finish(Model::Pack(builder, unpacked_model.get())); + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromUnownedBuffer( - reinterpret_cast<const char*>(builder.GetBufferPointer()), - builder.GetSize(), &unilib); + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); // Check regular expression selection. @@ -287,12 +305,13 @@ TEST_P(TextClassifierTest, SuggestSelectionRegularExpression) { EXPECT_EQ(classifier->SuggestSelection( "this afternoon Barack Obama gave a speech at", {15, 21}), std::make_pair(15, 27)); + EXPECT_EQ(classifier->SuggestSelection("cc: 4012 8888 8888 1881", {9, 14}), + std::make_pair(4, 23)); } -#endif // LIBTEXTCLASSIFIER_UNILIB_ICU +#endif // TC3_UNILIB_ICU -#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU -TEST_P(TextClassifierTest, - SuggestSelectionRegularExpressionConflictsModelWins) { +#ifdef TC3_UNILIB_ICU +TEST_P(AnnotatorTest, SuggestSelectionRegularExpressionConflictsModelWins) { const std::string test_model = ReadFile(GetModelPath() + GetParam()); std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); @@ -306,12 +325,11 @@ TEST_P(TextClassifierTest, unpacked_model->regex_model->patterns.back()->priority_score = 0.5; flatbuffers::FlatBufferBuilder builder; - builder.Finish(Model::Pack(builder, unpacked_model.get())); + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromUnownedBuffer( - reinterpret_cast<const char*>(builder.GetBufferPointer()), - builder.GetSize()); + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize()); ASSERT_TRUE(classifier); // Check conflict resolution. @@ -321,11 +339,10 @@ TEST_P(TextClassifierTest, {55, 57}), std::make_pair(26, 62)); } -#endif // LIBTEXTCLASSIFIER_UNILIB_ICU +#endif // TC3_UNILIB_ICU -#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU -TEST_P(TextClassifierTest, - SuggestSelectionRegularExpressionConflictsRegexWins) { +#ifdef TC3_UNILIB_ICU +TEST_P(AnnotatorTest, SuggestSelectionRegularExpressionConflictsRegexWins) { const std::string test_model = ReadFile(GetModelPath() + GetParam()); std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); @@ -339,12 +356,11 @@ TEST_P(TextClassifierTest, unpacked_model->regex_model->patterns.back()->priority_score = 1.1; flatbuffers::FlatBufferBuilder builder; - builder.Finish(Model::Pack(builder, unpacked_model.get())); + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromUnownedBuffer( - reinterpret_cast<const char*>(builder.GetBufferPointer()), - builder.GetSize()); + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize()); ASSERT_TRUE(classifier); // Check conflict resolution. @@ -354,11 +370,10 @@ TEST_P(TextClassifierTest, {55, 57}), std::make_pair(55, 62)); } -#endif // LIBTEXTCLASSIFIER_UNILIB_ICU +#endif // TC3_UNILIB_ICU -#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU -TEST_P(TextClassifierTest, AnnotateRegex) { - CREATE_UNILIB_FOR_TESTING; +#ifdef TC3_UNILIB_ICU +TEST_P(AnnotatorTest, AnnotateRegex) { const std::string test_model = ReadFile(GetModelPath() + GetParam()); std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); @@ -369,32 +384,36 @@ TEST_P(TextClassifierTest, AnnotateRegex) { unpacked_model->regex_model->patterns.push_back(MakePattern( "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false, /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 0.5)); + std::unique_ptr<RegexModel_::PatternT> verified_pattern = + MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})", + /*enabled_for_classification=*/false, + /*enabled_for_selection=*/false, + /*enabled_for_annotation=*/true, 1.0); + verified_pattern->verification_options.reset(new VerificationOptionsT); + verified_pattern->verification_options->verify_luhn_checksum = true; + unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern)); flatbuffers::FlatBufferBuilder builder; - builder.Finish(Model::Pack(builder, unpacked_model.get())); + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromUnownedBuffer( - reinterpret_cast<const char*>(builder.GetBufferPointer()), - builder.GetSize(), &unilib); + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); const std::string test_string = "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone " - "number is 853 225 3556"; + "number is 853 225 3556\nand my card is 4012 8888 8888 1881.\n"; EXPECT_THAT(classifier->Annotate(test_string), - ElementsAreArray({ - IsAnnotatedSpan(6, 18, "person"), - IsAnnotatedSpan(19, 24, "date"), - IsAnnotatedSpan(28, 55, "address"), - IsAnnotatedSpan(79, 91, "phone"), - })); + ElementsAreArray({IsAnnotatedSpan(6, 18, "person"), + IsAnnotatedSpan(28, 55, "address"), + IsAnnotatedSpan(79, 91, "phone"), + IsAnnotatedSpan(107, 126, "payment_card")})); } -#endif // LIBTEXTCLASSIFIER_UNILIB_ICU +#endif // TC3_UNILIB_ICU -TEST_P(TextClassifierTest, PhoneFiltering) { - CREATE_UNILIB_FOR_TESTING; - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib); +TEST_P(AnnotatorTest, PhoneFiltering) { + std::unique_ptr<Annotator> classifier = + Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); EXPECT_EQ("phone", FirstResult(classifier->ClassifyText( @@ -405,10 +424,9 @@ TEST_P(TextClassifierTest, PhoneFiltering) { "phone: (123) 456 789,0001112", {7, 28}))); } -TEST_P(TextClassifierTest, SuggestSelection) { - CREATE_UNILIB_FOR_TESTING; - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib); +TEST_P(AnnotatorTest, SuggestSelection) { + std::unique_ptr<Annotator> classifier = + Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); EXPECT_EQ(classifier->SuggestSelection( @@ -451,8 +469,7 @@ TEST_P(TextClassifierTest, SuggestSelection) { std::make_pair(11, 12)); } -TEST_P(TextClassifierTest, SuggestSelectionDisabledFail) { - CREATE_UNILIB_FOR_TESTING; +TEST_P(AnnotatorTest, SuggestSelectionDisabledFail) { const std::string test_model = ReadFile(GetModelPath() + GetParam()); std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); @@ -462,18 +479,16 @@ TEST_P(TextClassifierTest, SuggestSelectionDisabledFail) { unpacked_model->triggering_options->enabled_modes = ModeFlag_ANNOTATION; flatbuffers::FlatBufferBuilder builder; - builder.Finish(Model::Pack(builder, unpacked_model.get())); + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromUnownedBuffer( - reinterpret_cast<const char*>(builder.GetBufferPointer()), - builder.GetSize(), &unilib); + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); // Selection model needs to be present for annotation. ASSERT_FALSE(classifier); } -TEST_P(TextClassifierTest, SuggestSelectionDisabled) { - CREATE_UNILIB_FOR_TESTING; +TEST_P(AnnotatorTest, SuggestSelectionDisabled) { const std::string test_model = ReadFile(GetModelPath() + GetParam()); std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); @@ -484,12 +499,11 @@ TEST_P(TextClassifierTest, SuggestSelectionDisabled) { unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION; flatbuffers::FlatBufferBuilder builder; - builder.Finish(Model::Pack(builder, unpacked_model.get())); + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromUnownedBuffer( - reinterpret_cast<const char*>(builder.GetBufferPointer()), - builder.GetSize(), &unilib); + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); EXPECT_EQ( @@ -503,13 +517,11 @@ TEST_P(TextClassifierTest, SuggestSelectionDisabled) { IsEmpty()); } -TEST_P(TextClassifierTest, SuggestSelectionFilteredCollections) { - CREATE_UNILIB_FOR_TESTING; +TEST_P(AnnotatorTest, SuggestSelectionFilteredCollections) { const std::string test_model = ReadFile(GetModelPath() + GetParam()); - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(), - &unilib); + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + test_model.c_str(), test_model.size(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); EXPECT_EQ( @@ -526,11 +538,11 @@ TEST_P(TextClassifierTest, SuggestSelectionFilteredCollections) { unpacked_model->selection_options->always_classify_suggested_selection = true; flatbuffers::FlatBufferBuilder builder; - builder.Finish(Model::Pack(builder, unpacked_model.get())); + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); - classifier = TextClassifier::FromUnownedBuffer( + classifier = Annotator::FromUnownedBuffer( reinterpret_cast<const char*>(builder.GetBufferPointer()), - builder.GetSize(), &unilib); + builder.GetSize(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); EXPECT_EQ( @@ -542,10 +554,9 @@ TEST_P(TextClassifierTest, SuggestSelectionFilteredCollections) { std::make_pair(0, 27)); } -TEST_P(TextClassifierTest, SuggestSelectionsAreSymmetric) { - CREATE_UNILIB_FOR_TESTING; - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib); +TEST_P(AnnotatorTest, SuggestSelectionsAreSymmetric) { + std::unique_ptr<Annotator> classifier = + Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {0, 3}), @@ -560,10 +571,9 @@ TEST_P(TextClassifierTest, SuggestSelectionsAreSymmetric) { std::make_pair(6, 33)); } -TEST_P(TextClassifierTest, SuggestSelectionWithNewLine) { - CREATE_UNILIB_FOR_TESTING; - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib); +TEST_P(AnnotatorTest, SuggestSelectionWithNewLine) { + std::unique_ptr<Annotator> classifier = + Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); EXPECT_EQ(classifier->SuggestSelection("abc\n857 225 3556", {4, 7}), @@ -576,10 +586,9 @@ TEST_P(TextClassifierTest, SuggestSelectionWithNewLine) { std::make_pair(0, 7)); } -TEST_P(TextClassifierTest, SuggestSelectionWithPunctuation) { - CREATE_UNILIB_FOR_TESTING; - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib); +TEST_P(AnnotatorTest, SuggestSelectionWithPunctuation) { + std::unique_ptr<Annotator> classifier = + Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); // From the right. @@ -603,10 +612,9 @@ TEST_P(TextClassifierTest, SuggestSelectionWithPunctuation) { std::make_pair(16, 27)); } -TEST_P(TextClassifierTest, SuggestSelectionNoCrashWithJunk) { - CREATE_UNILIB_FOR_TESTING; - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib); +TEST_P(AnnotatorTest, SuggestSelectionNoCrashWithJunk) { + std::unique_ptr<Annotator> classifier = + Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); // Try passing in bunch of invalid selections. @@ -627,10 +635,9 @@ TEST_P(TextClassifierTest, SuggestSelectionNoCrashWithJunk) { std::make_pair(-1, -1)); } -TEST_P(TextClassifierTest, SuggestSelectionSelectSpace) { - CREATE_UNILIB_FOR_TESTING; - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib); +TEST_P(AnnotatorTest, SuggestSelectionSelectSpace) { + std::unique_ptr<Annotator> classifier = + Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); EXPECT_EQ( @@ -677,41 +684,39 @@ TEST_P(TextClassifierTest, SuggestSelectionSelectSpace) { std::make_pair(5, 6)); } -TEST(TextClassifierTest, SnapLeftIfWhitespaceSelection) { - CREATE_UNILIB_FOR_TESTING; +TEST_F(AnnotatorTest, SnapLeftIfWhitespaceSelection) { UnicodeText text; text = UTF8ToUnicodeText("abcd efgh", /*do_copy=*/false); - EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib), + EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_), std::make_pair(3, 4)); text = UTF8ToUnicodeText("abcd ", /*do_copy=*/false); - EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib), + EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_), std::make_pair(3, 4)); // Nothing on the left. text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false); - EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib), + EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_), std::make_pair(4, 5)); text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false); - EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib), + EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib_), std::make_pair(0, 1)); // Whitespace only. text = UTF8ToUnicodeText(" ", /*do_copy=*/false); - EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({2, 3}, text, unilib), + EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({2, 3}, text, unilib_), std::make_pair(2, 3)); text = UTF8ToUnicodeText(" ", /*do_copy=*/false); - EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib), + EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_), std::make_pair(4, 5)); text = UTF8ToUnicodeText(" ", /*do_copy=*/false); - EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib), + EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib_), std::make_pair(0, 1)); } -TEST_P(TextClassifierTest, Annotate) { - CREATE_UNILIB_FOR_TESTING; - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib); +TEST_P(AnnotatorTest, Annotate) { + std::unique_ptr<Annotator> classifier = + Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); const std::string test_string = @@ -719,9 +724,6 @@ TEST_P(TextClassifierTest, Annotate) { "number is 853 225 3556"; EXPECT_THAT(classifier->Annotate(test_string), ElementsAreArray({ -#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU - IsAnnotatedSpan(19, 24, "date"), -#endif IsAnnotatedSpan(28, 55, "address"), IsAnnotatedSpan(79, 91, "phone"), })); @@ -737,20 +739,19 @@ TEST_P(TextClassifierTest, Annotate) { .empty()); } -TEST_P(TextClassifierTest, AnnotateSmallBatches) { - CREATE_UNILIB_FOR_TESTING; + +TEST_P(AnnotatorTest, AnnotateSmallBatches) { const std::string test_model = ReadFile(GetModelPath() + GetParam()); std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); // Set the batch size. unpacked_model->selection_options->batch_size = 4; flatbuffers::FlatBufferBuilder builder; - builder.Finish(Model::Pack(builder, unpacked_model.get())); + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromUnownedBuffer( - reinterpret_cast<const char*>(builder.GetBufferPointer()), - builder.GetSize(), &unilib); + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); const std::string test_string = @@ -758,9 +759,6 @@ TEST_P(TextClassifierTest, AnnotateSmallBatches) { "number is 853 225 3556"; EXPECT_THAT(classifier->Annotate(test_string), ElementsAreArray({ -#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU - IsAnnotatedSpan(19, 24, "date"), -#endif IsAnnotatedSpan(28, 55, "address"), IsAnnotatedSpan(79, 91, "phone"), })); @@ -771,9 +769,8 @@ TEST_P(TextClassifierTest, AnnotateSmallBatches) { EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty()); } -#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU -TEST_P(TextClassifierTest, AnnotateFilteringDiscardAll) { - CREATE_UNILIB_FOR_TESTING; +#ifdef TC3_UNILIB_ICU +TEST_P(AnnotatorTest, AnnotateFilteringDiscardAll) { const std::string test_model = ReadFile(GetModelPath() + GetParam()); std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); @@ -782,24 +779,22 @@ TEST_P(TextClassifierTest, AnnotateFilteringDiscardAll) { unpacked_model->triggering_options->min_annotate_confidence = 2.f; // Discards all results. flatbuffers::FlatBufferBuilder builder; - builder.Finish(Model::Pack(builder, unpacked_model.get())); + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromUnownedBuffer( - reinterpret_cast<const char*>(builder.GetBufferPointer()), - builder.GetSize(), &unilib); + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); const std::string test_string = "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone " "number is 853 225 3556"; - EXPECT_EQ(classifier->Annotate(test_string).size(), 1); + EXPECT_EQ(classifier->Annotate(test_string).size(), 0); } -#endif +#endif // TC3_UNILIB_ICU -TEST_P(TextClassifierTest, AnnotateFilteringKeepAll) { - CREATE_UNILIB_FOR_TESTING; +TEST_P(AnnotatorTest, AnnotateFilteringKeepAll) { const std::string test_model = ReadFile(GetModelPath() + GetParam()); std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); @@ -809,39 +804,31 @@ TEST_P(TextClassifierTest, AnnotateFilteringKeepAll) { 0.f; // Keeps all results. unpacked_model->triggering_options->enabled_modes = ModeFlag_ALL; flatbuffers::FlatBufferBuilder builder; - builder.Finish(Model::Pack(builder, unpacked_model.get())); + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromUnownedBuffer( - reinterpret_cast<const char*>(builder.GetBufferPointer()), - builder.GetSize(), &unilib); + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); const std::string test_string = "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone " "number is 853 225 3556"; -#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU - EXPECT_EQ(classifier->Annotate(test_string).size(), 3); -#else - // In non-ICU mode there is no "date" result. EXPECT_EQ(classifier->Annotate(test_string).size(), 2); -#endif } -TEST_P(TextClassifierTest, AnnotateDisabled) { - CREATE_UNILIB_FOR_TESTING; +TEST_P(AnnotatorTest, AnnotateDisabled) { const std::string test_model = ReadFile(GetModelPath() + GetParam()); std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); // Disable the model for annotation. unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION_AND_SELECTION; flatbuffers::FlatBufferBuilder builder; - builder.Finish(Model::Pack(builder, unpacked_model.get())); + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromUnownedBuffer( - reinterpret_cast<const char*>(builder.GetBufferPointer()), - builder.GetSize(), &unilib); + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); const std::string test_string = "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone " @@ -849,13 +836,11 @@ TEST_P(TextClassifierTest, AnnotateDisabled) { EXPECT_THAT(classifier->Annotate(test_string), IsEmpty()); } -TEST_P(TextClassifierTest, AnnotateFilteredCollections) { - CREATE_UNILIB_FOR_TESTING; +TEST_P(AnnotatorTest, AnnotateFilteredCollections) { const std::string test_model = ReadFile(GetModelPath() + GetParam()); - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(), - &unilib); + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + test_model.c_str(), test_model.size(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); const std::string test_string = @@ -864,9 +849,6 @@ TEST_P(TextClassifierTest, AnnotateFilteredCollections) { EXPECT_THAT(classifier->Annotate(test_string), ElementsAreArray({ -#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU - IsAnnotatedSpan(19, 24, "date"), -#endif IsAnnotatedSpan(28, 55, "address"), IsAnnotatedSpan(79, 91, "phone"), })); @@ -879,30 +861,25 @@ TEST_P(TextClassifierTest, AnnotateFilteredCollections) { "phone"); flatbuffers::FlatBufferBuilder builder; - builder.Finish(Model::Pack(builder, unpacked_model.get())); + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); - classifier = TextClassifier::FromUnownedBuffer( + classifier = Annotator::FromUnownedBuffer( reinterpret_cast<const char*>(builder.GetBufferPointer()), - builder.GetSize(), &unilib); + builder.GetSize(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); EXPECT_THAT(classifier->Annotate(test_string), ElementsAreArray({ -#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU - IsAnnotatedSpan(19, 24, "date"), -#endif IsAnnotatedSpan(28, 55, "address"), })); } -#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU -TEST_P(TextClassifierTest, AnnotateFilteredCollectionsSuppress) { - CREATE_UNILIB_FOR_TESTING; +#ifdef TC3_UNILIB_ICU +TEST_P(AnnotatorTest, AnnotateFilteredCollectionsSuppress) { const std::string test_model = ReadFile(GetModelPath() + GetParam()); - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(), - &unilib); + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + test_model.c_str(), test_model.size(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); const std::string test_string = @@ -911,9 +888,6 @@ TEST_P(TextClassifierTest, AnnotateFilteredCollectionsSuppress) { EXPECT_THAT(classifier->Annotate(test_string), ElementsAreArray({ -#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU - IsAnnotatedSpan(19, 24, "date"), -#endif IsAnnotatedSpan(28, 55, "address"), IsAnnotatedSpan(79, 91, "phone"), })); @@ -932,25 +906,24 @@ TEST_P(TextClassifierTest, AnnotateFilteredCollectionsSuppress) { /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 2.0)); flatbuffers::FlatBufferBuilder builder; - builder.Finish(Model::Pack(builder, unpacked_model.get())); + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); - classifier = TextClassifier::FromUnownedBuffer( + classifier = Annotator::FromUnownedBuffer( reinterpret_cast<const char*>(builder.GetBufferPointer()), - builder.GetSize(), &unilib); + builder.GetSize(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); EXPECT_THAT(classifier->Annotate(test_string), ElementsAreArray({ - IsAnnotatedSpan(19, 24, "date"), IsAnnotatedSpan(28, 55, "address"), })); } -#endif +#endif // TC3_UNILIB_ICU -#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU -TEST_P(TextClassifierTest, ClassifyTextDate) { - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromPath(GetModelPath() + GetParam()); +#ifdef TC3_CALENDAR_ICU +TEST_P(AnnotatorTest, ClassifyTextDate) { + std::unique_ptr<Annotator> classifier = + Annotator::FromPath(GetModelPath() + GetParam()); EXPECT_TRUE(classifier); std::vector<ClassificationResult> result; @@ -996,12 +969,12 @@ TEST_P(TextClassifierTest, ClassifyTextDate) { EXPECT_EQ(result[0].datetime_parse_result.granularity, DatetimeGranularity::GRANULARITY_DAY); } -#endif // LIBTEXTCLASSIFIER_UNILIB_ICU +#endif // TC3_UNILIB_ICU -#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU -TEST_P(TextClassifierTest, ClassifyTextDatePriorities) { - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromPath(GetModelPath() + GetParam()); +#ifdef TC3_CALENDAR_ICU +TEST_P(AnnotatorTest, ClassifyTextDatePriorities) { + std::unique_ptr<Annotator> classifier = + Annotator::FromPath(GetModelPath() + GetParam()); EXPECT_TRUE(classifier); std::vector<ClassificationResult> result; @@ -1029,11 +1002,10 @@ TEST_P(TextClassifierTest, ClassifyTextDatePriorities) { EXPECT_EQ(result[0].datetime_parse_result.granularity, DatetimeGranularity::GRANULARITY_DAY); } -#endif // LIBTEXTCLASSIFIER_UNILIB_ICU +#endif // TC3_UNILIB_ICU -#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU -TEST_P(TextClassifierTest, SuggestTextDateDisabled) { - CREATE_UNILIB_FOR_TESTING; +#ifdef TC3_CALENDAR_ICU +TEST_P(AnnotatorTest, SuggestTextDateDisabled) { const std::string test_model = ReadFile(GetModelPath() + GetParam()); std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); @@ -1043,12 +1015,11 @@ TEST_P(TextClassifierTest, SuggestTextDateDisabled) { ModeFlag_ANNOTATION_AND_CLASSIFICATION; } flatbuffers::FlatBufferBuilder builder; - builder.Finish(Model::Pack(builder, unpacked_model.get())); + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromUnownedBuffer( - reinterpret_cast<const char*>(builder.GetBufferPointer()), - builder.GetSize(), &unilib); + std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer( + reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); EXPECT_EQ("date", FirstResult(classifier->ClassifyText("january 1, 2017", {0, 15}))); @@ -1057,14 +1028,15 @@ TEST_P(TextClassifierTest, SuggestTextDateDisabled) { EXPECT_THAT(classifier->Annotate("january 1, 2017"), ElementsAreArray({IsAnnotatedSpan(0, 15, "date")})); } -#endif // LIBTEXTCLASSIFIER_UNILIB_ICU +#endif // TC3_UNILIB_ICU -class TestingTextClassifier : public TextClassifier { +class TestingAnnotator : public Annotator { public: - TestingTextClassifier(const std::string& model, const UniLib* unilib) - : TextClassifier(ViewModel(model.data(), model.size()), unilib) {} + TestingAnnotator(const std::string& model, const UniLib* unilib, + const CalendarLib* calendarlib) + : Annotator(ViewModel(model.data(), model.size()), unilib, calendarlib) {} - using TextClassifier::ResolveConflicts; + using Annotator::ResolveConflicts; }; AnnotatedSpan MakeAnnotatedSpan(CodepointSpan span, @@ -1076,9 +1048,8 @@ AnnotatedSpan MakeAnnotatedSpan(CodepointSpan span, return result; } -TEST(TextClassifierTest, ResolveConflictsTrivial) { - CREATE_UNILIB_FOR_TESTING; - TestingTextClassifier classifier("", &unilib); +TEST_F(AnnotatorTest, ResolveConflictsTrivial) { + TestingAnnotator classifier("", &unilib_, &calendarlib_); std::vector<AnnotatedSpan> candidates{ {MakeAnnotatedSpan({0, 1}, "phone", 1.0)}}; @@ -1089,9 +1060,8 @@ TEST(TextClassifierTest, ResolveConflictsTrivial) { EXPECT_THAT(chosen, ElementsAreArray({0})); } -TEST(TextClassifierTest, ResolveConflictsSequence) { - CREATE_UNILIB_FOR_TESTING; - TestingTextClassifier classifier("", &unilib); +TEST_F(AnnotatorTest, ResolveConflictsSequence) { + TestingAnnotator classifier("", &unilib_, &calendarlib_); std::vector<AnnotatedSpan> candidates{{ MakeAnnotatedSpan({0, 1}, "phone", 1.0), @@ -1107,9 +1077,8 @@ TEST(TextClassifierTest, ResolveConflictsSequence) { EXPECT_THAT(chosen, ElementsAreArray({0, 1, 2, 3, 4})); } -TEST(TextClassifierTest, ResolveConflictsThreeSpans) { - CREATE_UNILIB_FOR_TESTING; - TestingTextClassifier classifier("", &unilib); +TEST_F(AnnotatorTest, ResolveConflictsThreeSpans) { + TestingAnnotator classifier("", &unilib_, &calendarlib_); std::vector<AnnotatedSpan> candidates{{ MakeAnnotatedSpan({0, 3}, "phone", 1.0), @@ -1123,9 +1092,8 @@ TEST(TextClassifierTest, ResolveConflictsThreeSpans) { EXPECT_THAT(chosen, ElementsAreArray({0, 2})); } -TEST(TextClassifierTest, ResolveConflictsThreeSpansReversed) { - CREATE_UNILIB_FOR_TESTING; - TestingTextClassifier classifier("", &unilib); +TEST_F(AnnotatorTest, ResolveConflictsThreeSpansReversed) { + TestingAnnotator classifier("", &unilib_, &calendarlib_); std::vector<AnnotatedSpan> candidates{{ MakeAnnotatedSpan({0, 3}, "phone", 0.5), // Looser! @@ -1139,9 +1107,8 @@ TEST(TextClassifierTest, ResolveConflictsThreeSpansReversed) { EXPECT_THAT(chosen, ElementsAreArray({1})); } -TEST(TextClassifierTest, ResolveConflictsFiveSpans) { - CREATE_UNILIB_FOR_TESTING; - TestingTextClassifier classifier("", &unilib); +TEST_F(AnnotatorTest, ResolveConflictsFiveSpans) { + TestingAnnotator classifier("", &unilib_, &calendarlib_); std::vector<AnnotatedSpan> candidates{{ MakeAnnotatedSpan({0, 3}, "phone", 0.5), @@ -1157,11 +1124,10 @@ TEST(TextClassifierTest, ResolveConflictsFiveSpans) { EXPECT_THAT(chosen, ElementsAreArray({0, 2, 4})); } -#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU -TEST_P(TextClassifierTest, LongInput) { - CREATE_UNILIB_FOR_TESTING; - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib); +#ifdef TC3_UNILIB_ICU +TEST_P(AnnotatorTest, LongInput) { + std::unique_ptr<Annotator> classifier = + Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); for (const auto& type_value_pair : @@ -1187,15 +1153,14 @@ TEST_P(TextClassifierTest, LongInput) { input_100k, {50000, 50000 + value_length}))); } } -#endif // LIBTEXTCLASSIFIER_UNILIB_ICU +#endif // TC3_UNILIB_ICU -#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU +#ifdef TC3_UNILIB_ICU // These coarse tests are there only to make sure the execution happens in // reasonable amount of time. -TEST_P(TextClassifierTest, LongInputNoResultCheck) { - CREATE_UNILIB_FOR_TESTING; - std::unique_ptr<TextClassifier> classifier = - TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib); +TEST_P(AnnotatorTest, LongInputNoResultCheck) { + std::unique_ptr<Annotator> classifier = + Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); for (const std::string& value : @@ -1209,24 +1174,23 @@ TEST_P(TextClassifierTest, LongInputNoResultCheck) { classifier->ClassifyText(input_100k, {50000, 50000 + value_length}); } } -#endif // LIBTEXTCLASSIFIER_UNILIB_ICU +#endif // TC3_UNILIB_ICU -#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU -TEST_P(TextClassifierTest, MaxTokenLength) { - CREATE_UNILIB_FOR_TESTING; +#ifdef TC3_UNILIB_ICU +TEST_P(AnnotatorTest, MaxTokenLength) { const std::string test_model = ReadFile(GetModelPath() + GetParam()); std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); - std::unique_ptr<TextClassifier> classifier; + std::unique_ptr<Annotator> classifier; // With unrestricted number of tokens should behave normally. unpacked_model->classification_options->max_num_tokens = -1; flatbuffers::FlatBufferBuilder builder; - builder.Finish(Model::Pack(builder, unpacked_model.get())); - classifier = TextClassifier::FromUnownedBuffer( + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); + classifier = Annotator::FromUnownedBuffer( reinterpret_cast<const char*>(builder.GetBufferPointer()), - builder.GetSize(), &unilib); + builder.GetSize(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); EXPECT_EQ(FirstResult(classifier->ClassifyText( @@ -1237,34 +1201,33 @@ TEST_P(TextClassifierTest, MaxTokenLength) { unpacked_model->classification_options->max_num_tokens = 3; flatbuffers::FlatBufferBuilder builder2; - builder2.Finish(Model::Pack(builder2, unpacked_model.get())); - classifier = TextClassifier::FromUnownedBuffer( + FinishModelBuffer(builder2, Model::Pack(builder2, unpacked_model.get())); + classifier = Annotator::FromUnownedBuffer( reinterpret_cast<const char*>(builder2.GetBufferPointer()), - builder2.GetSize(), &unilib); + builder2.GetSize(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); EXPECT_EQ(FirstResult(classifier->ClassifyText( "I live at 350 Third Street, Cambridge.", {10, 37})), "other"); } -#endif // LIBTEXTCLASSIFIER_UNILIB_ICU +#endif // TC3_UNILIB_ICU -#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU -TEST_P(TextClassifierTest, MinAddressTokenLength) { - CREATE_UNILIB_FOR_TESTING; +#ifdef TC3_UNILIB_ICU +TEST_P(AnnotatorTest, MinAddressTokenLength) { const std::string test_model = ReadFile(GetModelPath() + GetParam()); std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str()); - std::unique_ptr<TextClassifier> classifier; + std::unique_ptr<Annotator> classifier; // With unrestricted number of address tokens should behave normally. unpacked_model->classification_options->address_min_num_tokens = 0; flatbuffers::FlatBufferBuilder builder; - builder.Finish(Model::Pack(builder, unpacked_model.get())); - classifier = TextClassifier::FromUnownedBuffer( + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); + classifier = Annotator::FromUnownedBuffer( reinterpret_cast<const char*>(builder.GetBufferPointer()), - builder.GetSize(), &unilib); + builder.GetSize(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); EXPECT_EQ(FirstResult(classifier->ClassifyText( @@ -1275,17 +1238,17 @@ TEST_P(TextClassifierTest, MinAddressTokenLength) { unpacked_model->classification_options->address_min_num_tokens = 5; flatbuffers::FlatBufferBuilder builder2; - builder2.Finish(Model::Pack(builder2, unpacked_model.get())); - classifier = TextClassifier::FromUnownedBuffer( + FinishModelBuffer(builder2, Model::Pack(builder2, unpacked_model.get())); + classifier = Annotator::FromUnownedBuffer( reinterpret_cast<const char*>(builder2.GetBufferPointer()), - builder2.GetSize(), &unilib); + builder2.GetSize(), &unilib_, &calendarlib_); ASSERT_TRUE(classifier); EXPECT_EQ(FirstResult(classifier->ClassifyText( "I live at 350 Third Street, Cambridge.", {10, 37})), "other"); } -#endif // LIBTEXTCLASSIFIER_UNILIB_ICU +#endif // TC3_UNILIB_ICU } // namespace -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/cached-features.cc b/annotator/cached-features.cc index 2a46780..480c044 100644 --- a/cached-features.cc +++ b/annotator/cached-features.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,12 +14,12 @@ * limitations under the License. */ -#include "cached-features.h" +#include "annotator/cached-features.h" -#include "tensor-view.h" -#include "util/base/logging.h" +#include "utils/base/logging.h" +#include "utils/tensor-view.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace { @@ -67,7 +67,7 @@ std::unique_ptr<CachedFeatures> CachedFeatures::Create( ? 2 : 1; if (options->feature_version() < min_feature_version) { - TC_LOG(ERROR) << "Unsupported feature version."; + TC3_LOG(ERROR) << "Unsupported feature version."; return nullptr; } @@ -170,4 +170,4 @@ int CachedFeatures::NumFeaturesPerToken() const { return padding_features_->size(); } -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/cached-features.h b/annotator/cached-features.h index 0224d86..e03f79c 100644 --- a/cached-features.h +++ b/annotator/cached-features.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,17 +14,17 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_CACHED_FEATURES_H_ -#define LIBTEXTCLASSIFIER_CACHED_FEATURES_H_ +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_CACHED_FEATURES_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_CACHED_FEATURES_H_ #include <memory> #include <vector> -#include "model-executor.h" -#include "model_generated.h" -#include "types.h" +#include "annotator/model-executor.h" +#include "annotator/model_generated.h" +#include "annotator/types.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { // Holds state for extracting features across multiple calls and reusing them. // Assumes that features for each Token are independent. @@ -78,6 +78,6 @@ class CachedFeatures { std::unique_ptr<std::vector<float>> padding_features_; }; -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_CACHED_FEATURES_H_ +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_CACHED_FEATURES_H_ diff --git a/cached-features_test.cc b/annotator/cached-features_test.cc index f064a63..702f3ca 100644 --- a/cached-features_test.cc +++ b/annotator/cached-features_test.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,10 +14,10 @@ * limitations under the License. */ -#include "cached-features.h" +#include "annotator/cached-features.h" -#include "model-executor.h" -#include "tensor-view.h" +#include "annotator/model-executor.h" +#include "utils/tensor-view.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -26,7 +26,7 @@ using testing::ElementsAreArray; using testing::FloatEq; using testing::Matcher; -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace { Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) { @@ -154,4 +154,4 @@ TEST(CachedFeaturesTest, BoundsSensitive) { } } // namespace -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/datetime/extractor.cc b/annotator/datetime/extractor.cc index f4ab8f4..31229dd 100644 --- a/datetime/extractor.cc +++ b/annotator/datetime/extractor.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,11 +14,11 @@ * limitations under the License. */ -#include "datetime/extractor.h" +#include "annotator/datetime/extractor.h" -#include "util/base/logging.h" +#include "utils/base/logging.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { bool DatetimeExtractor::Extract(DateParseData* result, CodepointSpan* result_span) const { @@ -36,7 +36,7 @@ bool DatetimeExtractor::Extract(DateParseData* result, continue; } if (!GroupTextFromMatch(group_id, &group_text)) { - TC_LOG(ERROR) << "Couldn't retrieve group."; + TC3_LOG(ERROR) << "Couldn't retrieve group."; return false; } // The pattern can have a group defined in a part that was not matched, @@ -47,7 +47,7 @@ bool DatetimeExtractor::Extract(DateParseData* result, switch (group_type) { case DatetimeGroupType_GROUP_YEAR: { if (!ParseYear(group_text, &(result->year))) { - TC_LOG(ERROR) << "Couldn't extract YEAR."; + TC3_LOG(ERROR) << "Couldn't extract YEAR."; return false; } result->field_set_mask |= DateParseData::YEAR_FIELD; @@ -55,7 +55,7 @@ bool DatetimeExtractor::Extract(DateParseData* result, } case DatetimeGroupType_GROUP_MONTH: { if (!ParseMonth(group_text, &(result->month))) { - TC_LOG(ERROR) << "Couldn't extract MONTH."; + TC3_LOG(ERROR) << "Couldn't extract MONTH."; return false; } result->field_set_mask |= DateParseData::MONTH_FIELD; @@ -63,7 +63,7 @@ bool DatetimeExtractor::Extract(DateParseData* result, } case DatetimeGroupType_GROUP_DAY: { if (!ParseDigits(group_text, &(result->day_of_month))) { - TC_LOG(ERROR) << "Couldn't extract DAY."; + TC3_LOG(ERROR) << "Couldn't extract DAY."; return false; } result->field_set_mask |= DateParseData::DAY_FIELD; @@ -71,7 +71,7 @@ bool DatetimeExtractor::Extract(DateParseData* result, } case DatetimeGroupType_GROUP_HOUR: { if (!ParseDigits(group_text, &(result->hour))) { - TC_LOG(ERROR) << "Couldn't extract HOUR."; + TC3_LOG(ERROR) << "Couldn't extract HOUR."; return false; } result->field_set_mask |= DateParseData::HOUR_FIELD; @@ -79,7 +79,7 @@ bool DatetimeExtractor::Extract(DateParseData* result, } case DatetimeGroupType_GROUP_MINUTE: { if (!ParseDigits(group_text, &(result->minute))) { - TC_LOG(ERROR) << "Couldn't extract MINUTE."; + TC3_LOG(ERROR) << "Couldn't extract MINUTE."; return false; } result->field_set_mask |= DateParseData::MINUTE_FIELD; @@ -87,7 +87,7 @@ bool DatetimeExtractor::Extract(DateParseData* result, } case DatetimeGroupType_GROUP_SECOND: { if (!ParseDigits(group_text, &(result->second))) { - TC_LOG(ERROR) << "Couldn't extract SECOND."; + TC3_LOG(ERROR) << "Couldn't extract SECOND."; return false; } result->field_set_mask |= DateParseData::SECOND_FIELD; @@ -95,7 +95,7 @@ bool DatetimeExtractor::Extract(DateParseData* result, } case DatetimeGroupType_GROUP_AMPM: { if (!ParseAMPM(group_text, &(result->ampm))) { - TC_LOG(ERROR) << "Couldn't extract AMPM."; + TC3_LOG(ERROR) << "Couldn't extract AMPM."; return false; } result->field_set_mask |= DateParseData::AMPM_FIELD; @@ -103,7 +103,7 @@ bool DatetimeExtractor::Extract(DateParseData* result, } case DatetimeGroupType_GROUP_RELATIONDISTANCE: { if (!ParseRelationDistance(group_text, &(result->relation_distance))) { - TC_LOG(ERROR) << "Couldn't extract RELATION_DISTANCE_FIELD."; + TC3_LOG(ERROR) << "Couldn't extract RELATION_DISTANCE_FIELD."; return false; } result->field_set_mask |= DateParseData::RELATION_DISTANCE_FIELD; @@ -111,7 +111,7 @@ bool DatetimeExtractor::Extract(DateParseData* result, } case DatetimeGroupType_GROUP_RELATION: { if (!ParseRelation(group_text, &(result->relation))) { - TC_LOG(ERROR) << "Couldn't extract RELATION_FIELD."; + TC3_LOG(ERROR) << "Couldn't extract RELATION_FIELD."; return false; } result->field_set_mask |= DateParseData::RELATION_FIELD; @@ -119,7 +119,7 @@ bool DatetimeExtractor::Extract(DateParseData* result, } case DatetimeGroupType_GROUP_RELATIONTYPE: { if (!ParseRelationType(group_text, &(result->relation_type))) { - TC_LOG(ERROR) << "Couldn't extract RELATION_TYPE_FIELD."; + TC3_LOG(ERROR) << "Couldn't extract RELATION_TYPE_FIELD."; return false; } result->field_set_mask |= DateParseData::RELATION_TYPE_FIELD; @@ -129,11 +129,11 @@ bool DatetimeExtractor::Extract(DateParseData* result, case DatetimeGroupType_GROUP_DUMMY2: break; default: - TC_LOG(INFO) << "Unknown group type."; + TC3_LOG(INFO) << "Unknown group type."; continue; } if (!UpdateMatchSpan(group_id, result_span)) { - TC_LOG(ERROR) << "Couldn't update span."; + TC3_LOG(ERROR) << "Couldn't update span."; return false; } } @@ -466,4 +466,4 @@ bool DatetimeExtractor::ParseWeekday(const UnicodeText& input, parsed_weekday); } -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/datetime/extractor.h b/annotator/datetime/extractor.h index 5c36ec4..4c17aa7 100644 --- a/datetime/extractor.h +++ b/annotator/datetime/extractor.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,20 +14,20 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_DATETIME_EXTRACTOR_H_ -#define LIBTEXTCLASSIFIER_DATETIME_EXTRACTOR_H_ +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_EXTRACTOR_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_EXTRACTOR_H_ #include <string> #include <unordered_map> #include <vector> -#include "model_generated.h" -#include "types.h" -#include "util/strings/stringpiece.h" -#include "util/utf8/unicodetext.h" -#include "util/utf8/unilib.h" +#include "annotator/model_generated.h" +#include "annotator/types.h" +#include "utils/strings/stringpiece.h" +#include "utils/utf8/unicodetext.h" +#include "utils/utf8/unilib.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { struct CompiledRule { // The compiled regular expression. @@ -106,6 +106,6 @@ class DatetimeExtractor { type_and_locale_to_rule_; }; -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_DATETIME_EXTRACTOR_H_ +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_EXTRACTOR_H_ diff --git a/datetime/parser.cc b/annotator/datetime/parser.cc index 4bc5dff..ac3a62d 100644 --- a/datetime/parser.cc +++ b/annotator/datetime/parser.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,22 +14,22 @@ * limitations under the License. */ -#include "datetime/parser.h" +#include "annotator/datetime/parser.h" #include <set> #include <unordered_set> -#include "datetime/extractor.h" -#include "util/calendar/calendar.h" -#include "util/i18n/locale.h" -#include "util/strings/split.h" +#include "annotator/datetime/extractor.h" +#include "utils/calendar/calendar.h" +#include "utils/i18n/locale.h" +#include "utils/strings/split.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { std::unique_ptr<DatetimeParser> DatetimeParser::Instance( const DatetimeModel* model, const UniLib& unilib, - ZlibDecompressor* decompressor) { + const CalendarLib& calendarlib, ZlibDecompressor* decompressor) { std::unique_ptr<DatetimeParser> result( - new DatetimeParser(model, unilib, decompressor)); + new DatetimeParser(model, unilib, calendarlib, decompressor)); if (!result->initialized_) { result.reset(); } @@ -37,8 +37,9 @@ std::unique_ptr<DatetimeParser> DatetimeParser::Instance( } DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib& unilib, + const CalendarLib& calendarlib, ZlibDecompressor* decompressor) - : unilib_(unilib) { + : unilib_(unilib), calendarlib_(calendarlib) { initialized_ = false; if (model == nullptr) { @@ -54,7 +55,7 @@ DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib& unilib, regex->compressed_pattern(), decompressor); if (!regex_pattern) { - TC_LOG(ERROR) << "Couldn't create rule pattern."; + TC3_LOG(ERROR) << "Couldn't create rule pattern."; return; } rules_.push_back({std::move(regex_pattern), regex, pattern}); @@ -75,7 +76,7 @@ DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib& unilib, extractor->compressed_pattern(), decompressor); if (!regex_pattern) { - TC_LOG(ERROR) << "Couldn't create extractor pattern"; + TC3_LOG(ERROR) << "Couldn't create extractor pattern"; return; } extractor_rules_.push_back(std::move(regex_pattern)); @@ -393,7 +394,7 @@ bool DatetimeParser::ExtractDatetime(const CompiledRule& rule, result->granularity = GetGranularity(parse); - if (!calendar_lib_.InterpretParseData( + if (!calendarlib_.InterpretParseData( parse, reference_time_ms_utc, reference_timezone, reference_locale, result->granularity, &(result->time_ms_utc))) { return false; @@ -402,4 +403,4 @@ bool DatetimeParser::ExtractDatetime(const CompiledRule& rule, return true; } -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/datetime/parser.h b/annotator/datetime/parser.h index 0666607..c7eaf1f 100644 --- a/datetime/parser.h +++ b/annotator/datetime/parser.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_DATETIME_PARSER_H_ -#define LIBTEXTCLASSIFIER_DATETIME_PARSER_H_ +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_PARSER_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_PARSER_H_ #include <memory> #include <string> @@ -23,15 +23,15 @@ #include <unordered_set> #include <vector> -#include "datetime/extractor.h" -#include "model_generated.h" -#include "types.h" -#include "util/base/integral_types.h" -#include "util/calendar/calendar.h" -#include "util/utf8/unilib.h" -#include "zlib-utils.h" +#include "annotator/datetime/extractor.h" +#include "annotator/model_generated.h" +#include "annotator/types.h" +#include "utils/base/integral_types.h" +#include "utils/calendar/calendar.h" +#include "utils/utf8/unilib.h" +#include "utils/zlib/zlib.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { // Parses datetime expressions in the input and resolves them to actual absolute // time. @@ -39,7 +39,7 @@ class DatetimeParser { public: static std::unique_ptr<DatetimeParser> Instance( const DatetimeModel* model, const UniLib& unilib, - ZlibDecompressor* decompressor); + const CalendarLib& calendarlib, ZlibDecompressor* decompressor); // Parses the dates in 'input' and fills result. Makes sure that the results // do not overlap. @@ -58,6 +58,7 @@ class DatetimeParser { protected: DatetimeParser(const DatetimeModel* model, const UniLib& unilib, + const CalendarLib& calendarlib, ZlibDecompressor* decompressor); // Returns a list of locale ids for given locale spec string (comma-separated @@ -101,6 +102,7 @@ class DatetimeParser { private: bool initialized_; const UniLib& unilib_; + const CalendarLib& calendarlib_; std::vector<CompiledRule> rules_; std::unordered_map<int, std::vector<int>> locale_to_rules_; std::vector<std::unique_ptr<const UniLib::RegexPattern>> extractor_rules_; @@ -108,10 +110,9 @@ class DatetimeParser { type_and_locale_to_extractor_rule_; std::unordered_map<std::string, int> locale_string_to_id_; std::vector<int> default_locale_ids_; - CalendarLib calendar_lib_; bool use_extractors_for_locating_; }; -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_DATETIME_PARSER_H_ +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_PARSER_H_ diff --git a/datetime/parser_test.cc b/annotator/datetime/parser_test.cc index e61ed12..d46accf 100644 --- a/datetime/parser_test.cc +++ b/annotator/datetime/parser_test.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,18 +23,18 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "datetime/parser.h" -#include "model_generated.h" -#include "text-classifier.h" -#include "types-test-util.h" +#include "annotator/annotator.h" +#include "annotator/datetime/parser.h" +#include "annotator/model_generated.h" +#include "annotator/types-test-util.h" using testing::ElementsAreArray; -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace { std::string GetModelPath() { - return LIBTEXTCLASSIFIER_TEST_DATA_DIR; + return TC3_TEST_DATA_DIR; } std::string ReadFile(const std::string& file_name) { @@ -55,9 +55,9 @@ class ParserTest : public testing::Test { public: void SetUp() override { model_buffer_ = ReadFile(GetModelPath() + "test_model.fb"); - classifier_ = TextClassifier::FromUnownedBuffer( - model_buffer_.data(), model_buffer_.size(), &unilib_); - TC_CHECK(classifier_); + classifier_ = Annotator::FromUnownedBuffer(model_buffer_.data(), + model_buffer_.size(), &unilib_); + TC3_CHECK(classifier_); parser_ = classifier_->DatetimeParserForTests(); } @@ -66,8 +66,8 @@ class ParserTest : public testing::Test { std::vector<DatetimeParseResultSpan> results; if (!parser_->Parse(text, 0, timezone, /*locales=*/"", ModeFlag_ANNOTATION, anchor_start_end, &results)) { - TC_LOG(ERROR) << text; - TC_CHECK(false); + TC3_LOG(ERROR) << text; + TC3_CHECK(false); } return results.empty(); } @@ -84,8 +84,8 @@ class ParserTest : public testing::Test { std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '{'); auto brace_end_it = std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '}'); - TC_CHECK(brace_open_it != marked_text_unicode.end()); - TC_CHECK(brace_end_it != marked_text_unicode.end()); + TC3_CHECK(brace_open_it != marked_text_unicode.end()); + TC3_CHECK(brace_end_it != marked_text_unicode.end()); std::string text; text += @@ -98,11 +98,11 @@ class ParserTest : public testing::Test { if (!parser_->Parse(text, 0, timezone, locales, ModeFlag_ANNOTATION, anchor_start_end, &results)) { - TC_LOG(ERROR) << text; - TC_CHECK(false); + TC3_LOG(ERROR) << text; + TC3_CHECK(false); } if (results.empty()) { - TC_LOG(ERROR) << "No results."; + TC3_LOG(ERROR) << "No results."; return false; } @@ -124,16 +124,16 @@ class ParserTest : public testing::Test { {{expected_start_index, expected_end_index}, {expected_ms_utc, expected_granularity}, /*target_classification_score=*/1.0, - /*priority_score=*/0.0}}; + /*priority_score=*/0.1}}; const bool matches = testing::Matches(ElementsAreArray(expected))(filtered_results); if (!matches) { - TC_LOG(ERROR) << "Expected: " << expected[0] << " which corresponds to: " - << FormatMillis(expected[0].data.time_ms_utc); + TC3_LOG(ERROR) << "Expected: " << expected[0] << " which corresponds to: " + << FormatMillis(expected[0].data.time_ms_utc); for (int i = 0; i < filtered_results.size(); ++i) { - TC_LOG(ERROR) << "Actual[" << i << "]: " << filtered_results[i] - << " which corresponds to: " - << FormatMillis(filtered_results[i].data.time_ms_utc); + TC3_LOG(ERROR) << "Actual[" << i << "]: " << filtered_results[i] + << " which corresponds to: " + << FormatMillis(filtered_results[i].data.time_ms_utc); } } return matches; @@ -149,7 +149,7 @@ class ParserTest : public testing::Test { protected: std::string model_buffer_; - std::unique_ptr<TextClassifier> classifier_; + std::unique_ptr<Annotator> classifier_; const DatetimeParser* parser_; UniLib unilib_; }; @@ -158,7 +158,6 @@ class ParserTest : public testing::Test { TEST_F(ParserTest, ParseShort) { EXPECT_TRUE( ParsesCorrectly("{January 1, 1988}", 567990000000, GRANULARITY_DAY)); - EXPECT_TRUE(ParsesCorrectly("{three days ago}", -262800000, GRANULARITY_DAY)); } TEST_F(ParserTest, Parse) { @@ -176,30 +175,23 @@ TEST_F(ParserTest, Parse) { GRANULARITY_SECOND)); EXPECT_TRUE( ParsesCorrectly("{Mar 16 08:12:04}", 6419524000, GRANULARITY_SECOND)); - EXPECT_TRUE(ParsesCorrectly("{2010-06-26 02:31:29},573", 1277512289000, + EXPECT_TRUE(ParsesCorrectly("{2010-06-26 02:31:29}", 1277512289000, GRANULARITY_SECOND)); EXPECT_TRUE(ParsesCorrectly("{2006/01/22 04:11:05}", 1137899465000, GRANULARITY_SECOND)); EXPECT_TRUE(ParsesCorrectly("{11:42:35}", 38555000, GRANULARITY_SECOND)); - EXPECT_TRUE(ParsesCorrectly("{11:42:35}.173", 38555000, GRANULARITY_SECOND)); EXPECT_TRUE( - ParsesCorrectly("{23/Apr 11:42:35},173", 9715355000, GRANULARITY_SECOND)); + ParsesCorrectly("{23/Apr 11:42:35}", 9715355000, GRANULARITY_SECOND)); EXPECT_TRUE(ParsesCorrectly("{23/Apr/2015 11:42:35}", 1429782155000, GRANULARITY_SECOND)); EXPECT_TRUE(ParsesCorrectly("{23-Apr-2015 11:42:35}", 1429782155000, GRANULARITY_SECOND)); - EXPECT_TRUE(ParsesCorrectly("{23-Apr-2015 11:42:35}.883", 1429782155000, - GRANULARITY_SECOND)); EXPECT_TRUE(ParsesCorrectly("{23 Apr 2015 11:42:35}", 1429782155000, GRANULARITY_SECOND)); - EXPECT_TRUE(ParsesCorrectly("{23 Apr 2015 11:42:35}.883", 1429782155000, - GRANULARITY_SECOND)); EXPECT_TRUE(ParsesCorrectly("{04/23/15 11:42:35}", 1429782155000, GRANULARITY_SECOND)); EXPECT_TRUE(ParsesCorrectly("{04/23/2015 11:42:35}", 1429782155000, GRANULARITY_SECOND)); - EXPECT_TRUE(ParsesCorrectly("{04/23/2015 11:42:35}.883", 1429782155000, - GRANULARITY_SECOND)); EXPECT_TRUE(ParsesCorrectly("{9/28/2011 2:23:15 PM}", 1317212595000, GRANULARITY_SECOND)); EXPECT_TRUE(ParsesCorrectly( @@ -221,26 +213,17 @@ TEST_F(ParserTest, Parse) { EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4pm}", 1514818800000, GRANULARITY_HOUR)); - EXPECT_TRUE(ParsesCorrectly("{today}", -3600000, GRANULARITY_DAY)); - EXPECT_TRUE(ParsesCorrectly("{today}", -57600000, GRANULARITY_DAY, + EXPECT_TRUE(ParsesCorrectly("{today at 0:00}", -3600000, GRANULARITY_MINUTE)); + EXPECT_TRUE(ParsesCorrectly("{today at 0:00}", -57600000, GRANULARITY_MINUTE, /*anchor_start_end=*/false, "America/Los_Angeles")); - EXPECT_TRUE(ParsesCorrectly("{next week}", 255600000, GRANULARITY_WEEK)); - EXPECT_TRUE(ParsesCorrectly("{next day}", 82800000, GRANULARITY_DAY)); - EXPECT_TRUE(ParsesCorrectly("{in three days}", 255600000, GRANULARITY_DAY)); - EXPECT_TRUE( - ParsesCorrectly("{in three weeks}", 1465200000, GRANULARITY_WEEK)); - EXPECT_TRUE(ParsesCorrectly("{tomorrow}", 82800000, GRANULARITY_DAY)); EXPECT_TRUE( ParsesCorrectly("{tomorrow at 4:00}", 97200000, GRANULARITY_MINUTE)); - EXPECT_TRUE(ParsesCorrectly("{tomorrow at 4}", 97200000, GRANULARITY_HOUR)); - EXPECT_TRUE(ParsesCorrectly("{next wednesday}", 514800000, GRANULARITY_DAY)); + EXPECT_TRUE(ParsesCorrectly("{tomorrow at 4am}", 97200000, GRANULARITY_HOUR)); EXPECT_TRUE( - ParsesCorrectly("{next wednesday at 4}", 529200000, GRANULARITY_HOUR)); + ParsesCorrectly("{wednesday at 4am}", 529200000, GRANULARITY_HOUR)); EXPECT_TRUE(ParsesCorrectly("last seen {today at 9:01 PM}", 72060000, GRANULARITY_MINUTE)); - EXPECT_TRUE(ParsesCorrectly("{Three days ago}", -262800000, GRANULARITY_DAY)); - EXPECT_TRUE(ParsesCorrectly("{three days ago}", -262800000, GRANULARITY_DAY)); } TEST_F(ParserTest, ParseWithAnchor) { @@ -271,15 +254,13 @@ TEST_F(ParserTest, ParseGerman) { GRANULARITY_SECOND)); EXPECT_TRUE(ParsesCorrectlyGerman("{März 16 08:12:04}", 6419524000, GRANULARITY_SECOND)); - EXPECT_TRUE(ParsesCorrectlyGerman("{2010-06-26 02:31:29},573", 1277512289000, + EXPECT_TRUE(ParsesCorrectlyGerman("{2010-06-26 02:31:29}", 1277512289000, GRANULARITY_SECOND)); EXPECT_TRUE(ParsesCorrectlyGerman("{2006/01/22 04:11:05}", 1137899465000, GRANULARITY_SECOND)); EXPECT_TRUE( ParsesCorrectlyGerman("{11:42:35}", 38555000, GRANULARITY_SECOND)); - EXPECT_TRUE( - ParsesCorrectlyGerman("{11:42:35}.173", 38555000, GRANULARITY_SECOND)); - EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr 11:42:35},173", 9715355000, + EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr 11:42:35}", 9715355000, GRANULARITY_SECOND)); EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr/2015:11:42:35}", 1429782155000, GRANULARITY_SECOND)); @@ -287,18 +268,12 @@ TEST_F(ParserTest, ParseGerman) { GRANULARITY_SECOND)); EXPECT_TRUE(ParsesCorrectlyGerman("{23-Apr-2015 11:42:35}", 1429782155000, GRANULARITY_SECOND)); - EXPECT_TRUE(ParsesCorrectlyGerman("{23-Apr-2015 11:42:35}.883", 1429782155000, - GRANULARITY_SECOND)); EXPECT_TRUE(ParsesCorrectlyGerman("{23 Apr 2015 11:42:35}", 1429782155000, GRANULARITY_SECOND)); - EXPECT_TRUE(ParsesCorrectlyGerman("{23 Apr 2015 11:42:35}.883", 1429782155000, - GRANULARITY_SECOND)); EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/15 11:42:35}", 1429782155000, GRANULARITY_SECOND)); EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/2015 11:42:35}", 1429782155000, GRANULARITY_SECOND)); - EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/2015 11:42:35}.883", 1429782155000, - GRANULARITY_SECOND)); EXPECT_TRUE(ParsesCorrectlyGerman("{19/apr/2010:06:36:15}", 1271651775000, GRANULARITY_SECOND)); EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4:30}", 1514777400000, @@ -309,32 +284,12 @@ TEST_F(ParserTest, ParseGerman) { GRANULARITY_HOUR)); EXPECT_TRUE( ParsesCorrectlyGerman("{14.03.2017}", 1489446000000, GRANULARITY_DAY)); - EXPECT_TRUE(ParsesCorrectlyGerman("{heute}", -3600000, GRANULARITY_DAY)); - EXPECT_TRUE( - ParsesCorrectlyGerman("{nächste Woche}", 342000000, GRANULARITY_WEEK)); EXPECT_TRUE( - ParsesCorrectlyGerman("{nächsten Tag}", 82800000, GRANULARITY_DAY)); - EXPECT_TRUE( - ParsesCorrectlyGerman("{in drei Tagen}", 255600000, GRANULARITY_DAY)); - EXPECT_TRUE( - ParsesCorrectlyGerman("{in drei Wochen}", 1551600000, GRANULARITY_WEEK)); - EXPECT_TRUE( - ParsesCorrectlyGerman("{vor drei Tagen}", -262800000, GRANULARITY_DAY)); - EXPECT_TRUE(ParsesCorrectlyGerman("{morgen}", 82800000, GRANULARITY_DAY)); + ParsesCorrectlyGerman("{morgen 0:00}", 82800000, GRANULARITY_MINUTE)); EXPECT_TRUE( ParsesCorrectlyGerman("{morgen um 4:00}", 97200000, GRANULARITY_MINUTE)); EXPECT_TRUE( - ParsesCorrectlyGerman("{morgen um 4}", 97200000, GRANULARITY_HOUR)); - EXPECT_TRUE( - ParsesCorrectlyGerman("{nächsten Mittwoch}", 514800000, GRANULARITY_DAY)); - EXPECT_TRUE(ParsesCorrectlyGerman("{nächsten Mittwoch um 4}", 529200000, - GRANULARITY_HOUR)); - EXPECT_TRUE( - ParsesCorrectlyGerman("{Vor drei Tagen}", -262800000, GRANULARITY_DAY)); - EXPECT_TRUE( - ParsesCorrectlyGerman("{in einer woche}", 342000000, GRANULARITY_WEEK)); - EXPECT_TRUE( - ParsesCorrectlyGerman("{in einer tag}", 82800000, GRANULARITY_DAY)); + ParsesCorrectlyGerman("{morgen um 4 vorm}", 97200000, GRANULARITY_HOUR)); } TEST_F(ParserTest, ParseNonUs) { @@ -372,6 +327,7 @@ class ParserLocaleTest : public testing::Test { protected: UniLib unilib_; + CalendarLib calendarlib_; flatbuffers::FlatBufferBuilder builder_; std::unique_ptr<DatetimeParser> parser_; }; @@ -412,7 +368,7 @@ void ParserLocaleTest::SetUp() { flatbuffers::GetRoot<DatetimeModel>(builder_.GetBufferPointer()); ASSERT_TRUE(model_fb); - parser_ = DatetimeParser::Instance(model_fb, unilib_, + parser_ = DatetimeParser::Instance(model_fb, unilib_, calendarlib_, /*decompressor=*/nullptr); ASSERT_TRUE(parser_); } @@ -454,4 +410,4 @@ TEST_F(ParserLocaleTest, SwissEnglish) { } } // namespace -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/feature-processor.cc b/annotator/feature-processor.cc index 551e649..a18393b 100644 --- a/feature-processor.cc +++ b/annotator/feature-processor.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,17 +14,17 @@ * limitations under the License. */ -#include "feature-processor.h" +#include "annotator/feature-processor.h" #include <iterator> #include <set> #include <vector> -#include "util/base/logging.h" -#include "util/strings/utf8.h" -#include "util/utf8/unicodetext.h" +#include "utils/base/logging.h" +#include "utils/strings/utf8.h" +#include "utils/utf8/unicodetext.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace internal { @@ -111,16 +111,6 @@ void SplitTokensOnSelectionBoundaries(CodepointSpan selection, } } -const UniLib* MaybeCreateUnilib(const UniLib* unilib, - std::unique_ptr<UniLib>* owned_unilib) { - if (unilib) { - return unilib; - } else { - owned_unilib->reset(new UniLib); - return owned_unilib->get(); - } -} - } // namespace internal void FeatureProcessor::StripTokensFromOtherLines( @@ -168,7 +158,7 @@ std::string FeatureProcessor::GetDefaultCollection() const { if (options_->default_collection() < 0 || options_->collections() == nullptr || options_->default_collection() >= options_->collections()->size()) { - TC_LOG(ERROR) + TC3_LOG(ERROR) << "Invalid or missing default collection. Returning empty string."; return ""; } @@ -199,8 +189,8 @@ std::vector<Token> FeatureProcessor::Tokenize( } return result; } else { - TC_LOG(ERROR) << "Unknown tokenization type specified. Using " - "internal."; + TC3_LOG(ERROR) << "Unknown tokenization type specified. Using " + "internal."; return tokenizer_.Tokenize(text_unicode); } } @@ -462,7 +452,7 @@ int FeatureProcessor::FindCenterToken(CodepointSpan span, return internal::CenterTokenFromMiddleOfSelection(span, tokens); } } else { - TC_LOG(ERROR) << "Invalid center token selection method."; + TC3_LOG(ERROR) << "Invalid center token selection method."; return kInvalidIndex; } } @@ -473,7 +463,7 @@ bool FeatureProcessor::SelectionLabelSpans( for (int i = 0; i < label_to_selection_.size(); ++i) { CodepointSpan span; if (!LabelToSpan(i, tokens, &span)) { - TC_LOG(ERROR) << "Could not convert label to span: " << i; + TC3_LOG(ERROR) << "Could not convert label to span: " << i; return false; } selection_label_spans->push_back(span); @@ -711,7 +701,7 @@ void FeatureProcessor::RetokenizeAndFindClick( const UnicodeText& context_unicode, CodepointSpan input_span, bool only_use_line_with_click, std::vector<Token>* tokens, int* click_pos) const { - TC_CHECK(tokens != nullptr); + TC3_CHECK(tokens != nullptr); if (options_->split_tokens_on_selection_boundaries()) { internal::SplitTokensOnSelectionBoundaries(input_span, tokens); @@ -777,8 +767,8 @@ bool FeatureProcessor::HasEnoughSupportedCodepoints( const float supported_codepoint_ratio = SupportedCodepointsRatio(token_span, tokens); if (supported_codepoint_ratio < options_->min_supported_codepoint_ratio()) { - TC_VLOG(1) << "Not enough supported codepoints in the context: " - << supported_codepoint_ratio; + TC3_VLOG(1) << "Not enough supported codepoints in the context: " + << supported_codepoint_ratio; return false; } } @@ -797,7 +787,7 @@ bool FeatureProcessor::ExtractFeatures( if (!AppendTokenFeaturesWithCache(tokens[i], selection_span_for_feature, embedding_executor, embedding_cache, features.get())) { - TC_LOG(ERROR) << "Could not get token features."; + TC3_LOG(ERROR) << "Could not get token features."; return false; } } @@ -808,7 +798,7 @@ bool FeatureProcessor::ExtractFeatures( if (!AppendTokenFeaturesWithCache(Token(), selection_span_for_feature, embedding_executor, embedding_cache, padding_features.get())) { - TC_LOG(ERROR) << "Count not get padding token features."; + TC3_LOG(ERROR) << "Count not get padding token features."; return false; } @@ -816,7 +806,7 @@ bool FeatureProcessor::ExtractFeatures( std::move(padding_features), options_, feature_vector_size); if (!*cached_features) { - TC_LOG(ERROR) << "Cound not create cached features."; + TC3_LOG(ERROR) << "Cound not create cached features."; return false; } @@ -945,7 +935,7 @@ bool FeatureProcessor::AppendTokenFeaturesWithCache( if (!feature_extractor_.Extract( token, token.IsContainedInSpan(selection_span_for_feature), /*sparse_features=*/nullptr, &dense_features)) { - TC_LOG(ERROR) << "Could not extract token's dense features."; + TC3_LOG(ERROR) << "Could not extract token's dense features."; return false; } @@ -964,7 +954,7 @@ bool FeatureProcessor::AppendTokenFeaturesWithCache( if (!feature_extractor_.Extract( token, token.IsContainedInSpan(selection_span_for_feature), &sparse_features, &dense_features)) { - TC_LOG(ERROR) << "Could not extract token's features."; + TC3_LOG(ERROR) << "Could not extract token's features."; return false; } @@ -978,7 +968,7 @@ bool FeatureProcessor::AppendTokenFeaturesWithCache( {static_cast<int>(sparse_features.size())}), /*dest=*/output_features_end - embedding_size, /*dest_size=*/embedding_size)) { - TC_LOG(ERROR) << "Cound not embed token's sparse features."; + TC3_LOG(ERROR) << "Cound not embed token's sparse features."; return false; } @@ -995,4 +985,4 @@ bool FeatureProcessor::AppendTokenFeaturesWithCache( return true; } -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/feature-processor.h b/annotator/feature-processor.h index 98d3449..2d04253 100644 --- a/feature-processor.h +++ b/annotator/feature-processor.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,8 +16,8 @@ // Feature processing for FFModel (feed-forward SmartSelection model). -#ifndef LIBTEXTCLASSIFIER_FEATURE_PROCESSOR_H_ -#define LIBTEXTCLASSIFIER_FEATURE_PROCESSOR_H_ +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_ #include <map> #include <memory> @@ -25,17 +25,17 @@ #include <string> #include <vector> -#include "cached-features.h" -#include "model_generated.h" -#include "token-feature-extractor.h" -#include "tokenizer.h" -#include "types.h" -#include "util/base/integral_types.h" -#include "util/base/logging.h" -#include "util/utf8/unicodetext.h" -#include "util/utf8/unilib.h" +#include "annotator/cached-features.h" +#include "annotator/model_generated.h" +#include "annotator/token-feature-extractor.h" +#include "annotator/tokenizer.h" +#include "annotator/types.h" +#include "utils/base/integral_types.h" +#include "utils/base/logging.h" +#include "utils/utf8/unicodetext.h" +#include "utils/utf8/unilib.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { constexpr int kInvalidLabel = -1; @@ -64,11 +64,6 @@ int CenterTokenFromMiddleOfSelection( void StripOrPadTokens(TokenSpan relative_click_span, int context_size, std::vector<Token>* tokens, int* click_pos); -// If unilib is not nullptr, just returns unilib. Otherwise, if unilib is -// nullptr, will create UniLib, assign ownership to owned_unilib, and return it. -const UniLib* MaybeCreateUnilib(const UniLib* unilib, - std::unique_ptr<UniLib>* owned_unilib); - } // namespace internal // Converts a codepoint span to a token span in the given list of tokens. @@ -93,12 +88,8 @@ class FeatureProcessor { // identical. typedef std::map<CodepointSpan, std::vector<float>> EmbeddingCache; - // If unilib is nullptr, will create and own an instance of a UniLib, - // otherwise will use what's passed in. - explicit FeatureProcessor(const FeatureProcessorOptions* options, - const UniLib* unilib = nullptr) - : owned_unilib_(nullptr), - unilib_(internal::MaybeCreateUnilib(unilib, &owned_unilib_)), + FeatureProcessor(const FeatureProcessorOptions* options, const UniLib* unilib) + : unilib_(unilib), feature_extractor_(internal::BuildTokenFeatureExtractorOptions(options), *unilib_), options_(options), @@ -303,7 +294,6 @@ class FeatureProcessor { std::vector<float>* output_features) const; private: - std::unique_ptr<UniLib> owned_unilib_; const UniLib* unilib_; protected: @@ -336,6 +326,6 @@ class FeatureProcessor { Tokenizer tokenizer_; }; -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_FEATURE_PROCESSOR_H_ +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_ diff --git a/feature-processor_test.cc b/annotator/feature-processor_test.cc index 58b3033..c9f0e0d 100644 --- a/feature-processor_test.cc +++ b/annotator/feature-processor_test.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,15 +14,15 @@ * limitations under the License. */ -#include "feature-processor.h" +#include "annotator/feature-processor.h" -#include "model-executor.h" -#include "tensor-view.h" +#include "annotator/model-executor.h" +#include "utils/tensor-view.h" #include "gmock/gmock.h" #include "gtest/gtest.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace { using testing::ElementsAreArray; @@ -66,7 +66,7 @@ class FakeEmbeddingExecutor : public EmbeddingExecutor { public: bool AddEmbedding(const TensorView<int>& sparse_features, float* dest, int dest_size) const override { - TC_CHECK_GE(dest_size, 4); + TC3_CHECK_GE(dest_size, 4); EXPECT_EQ(sparse_features.size(), 1); dest[0] = sparse_features.data()[0]; dest[1] = sparse_features.data()[0]; @@ -79,7 +79,13 @@ class FakeEmbeddingExecutor : public EmbeddingExecutor { std::vector<float> storage_; }; -TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesMiddle) { +class FeatureProcessorTest : public ::testing::Test { + protected: + FeatureProcessorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {} + UniLib unilib_; +}; + +TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesMiddle) { std::vector<Token> tokens{Token("Hělló", 0, 5), Token("fěěbař@google.com", 6, 23), Token("heře!", 24, 29)}; @@ -96,7 +102,7 @@ TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesMiddle) { // clang-format on } -TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesBegin) { +TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesBegin) { std::vector<Token> tokens{Token("Hělló", 0, 5), Token("fěěbař@google.com", 6, 23), Token("heře!", 24, 29)}; @@ -112,7 +118,7 @@ TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesBegin) { // clang-format on } -TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesEnd) { +TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesEnd) { std::vector<Token> tokens{Token("Hělló", 0, 5), Token("fěěbař@google.com", 6, 23), Token("heře!", 24, 29)}; @@ -128,7 +134,7 @@ TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesEnd) { // clang-format on } -TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesWhole) { +TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesWhole) { std::vector<Token> tokens{Token("Hělló", 0, 5), Token("fěěbař@google.com", 6, 23), Token("heře!", 24, 29)}; @@ -143,7 +149,7 @@ TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesWhole) { // clang-format on } -TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesCrossToken) { +TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesCrossToken) { std::vector<Token> tokens{Token("Hělló", 0, 5), Token("fěěbař@google.com", 6, 23), Token("heře!", 24, 29)}; @@ -160,14 +166,13 @@ TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesCrossToken) { // clang-format on } -TEST(FeatureProcessorTest, KeepLineWithClickFirst) { - CREATE_UNILIB_FOR_TESTING; +TEST_F(FeatureProcessorTest, KeepLineWithClickFirst) { FeatureProcessorOptionsT options; options.only_use_line_with_click = true; flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); TestingFeatureProcessor feature_processor( flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), - &unilib); + &unilib_); const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině"; const CodepointSpan span = {0, 5}; @@ -186,14 +191,13 @@ TEST(FeatureProcessorTest, KeepLineWithClickFirst) { ElementsAreArray({Token("Fiřst", 0, 5), Token("Lině", 6, 10)})); } -TEST(FeatureProcessorTest, KeepLineWithClickSecond) { - CREATE_UNILIB_FOR_TESTING; +TEST_F(FeatureProcessorTest, KeepLineWithClickSecond) { FeatureProcessorOptionsT options; options.only_use_line_with_click = true; flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); TestingFeatureProcessor feature_processor( flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), - &unilib); + &unilib_); const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině"; const CodepointSpan span = {18, 22}; @@ -212,14 +216,13 @@ TEST(FeatureProcessorTest, KeepLineWithClickSecond) { {Token("Sěcond", 11, 17), Token("Lině", 18, 22)})); } -TEST(FeatureProcessorTest, KeepLineWithClickThird) { - CREATE_UNILIB_FOR_TESTING; +TEST_F(FeatureProcessorTest, KeepLineWithClickThird) { FeatureProcessorOptionsT options; options.only_use_line_with_click = true; flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); TestingFeatureProcessor feature_processor( flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), - &unilib); + &unilib_); const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině"; const CodepointSpan span = {24, 33}; @@ -238,14 +241,13 @@ TEST(FeatureProcessorTest, KeepLineWithClickThird) { {Token("Thiřd", 23, 28), Token("Lině", 29, 33)})); } -TEST(FeatureProcessorTest, KeepLineWithClickSecondWithPipe) { - CREATE_UNILIB_FOR_TESTING; +TEST_F(FeatureProcessorTest, KeepLineWithClickSecondWithPipe) { FeatureProcessorOptionsT options; options.only_use_line_with_click = true; flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); TestingFeatureProcessor feature_processor( flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), - &unilib); + &unilib_); const std::string context = "Fiřst Lině|Sěcond Lině\nThiřd Lině"; const CodepointSpan span = {18, 22}; @@ -264,14 +266,13 @@ TEST(FeatureProcessorTest, KeepLineWithClickSecondWithPipe) { {Token("Sěcond", 11, 17), Token("Lině", 18, 22)})); } -TEST(FeatureProcessorTest, KeepLineWithCrosslineClick) { - CREATE_UNILIB_FOR_TESTING; +TEST_F(FeatureProcessorTest, KeepLineWithCrosslineClick) { FeatureProcessorOptionsT options; options.only_use_line_with_click = true; flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); TestingFeatureProcessor feature_processor( flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), - &unilib); + &unilib_); const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině"; const CodepointSpan span = {5, 23}; @@ -292,8 +293,7 @@ TEST(FeatureProcessorTest, KeepLineWithCrosslineClick) { Token("Thiřd", 23, 28), Token("Lině", 29, 33)})); } -TEST(FeatureProcessorTest, SpanToLabel) { - CREATE_UNILIB_FOR_TESTING; +TEST_F(FeatureProcessorTest, SpanToLabel) { FeatureProcessorOptionsT options; options.context_size = 1; options.max_selection_span = 1; @@ -309,7 +309,7 @@ TEST(FeatureProcessorTest, SpanToLabel) { flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); TestingFeatureProcessor feature_processor( flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), - &unilib); + &unilib_); std::vector<Token> tokens = feature_processor.Tokenize("one, two, three"); ASSERT_EQ(3, tokens.size()); int label; @@ -328,7 +328,7 @@ TEST(FeatureProcessorTest, SpanToLabel) { PackFeatureProcessorOptions(options); TestingFeatureProcessor feature_processor2( flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()), - &unilib); + &unilib_); int label2; ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2)); EXPECT_EQ(label, label2); @@ -350,7 +350,7 @@ TEST(FeatureProcessorTest, SpanToLabel) { PackFeatureProcessorOptions(options); TestingFeatureProcessor feature_processor3( flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()), - &unilib); + &unilib_); tokens = feature_processor3.Tokenize("zero, one, two, three, four"); ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2)); EXPECT_NE(kInvalidLabel, label2); @@ -367,8 +367,7 @@ TEST(FeatureProcessorTest, SpanToLabel) { EXPECT_EQ(label2, label3); } -TEST(FeatureProcessorTest, SpanToLabelIgnoresPunctuation) { - CREATE_UNILIB_FOR_TESTING; +TEST_F(FeatureProcessorTest, SpanToLabelIgnoresPunctuation) { FeatureProcessorOptionsT options; options.context_size = 1; options.max_selection_span = 1; @@ -384,7 +383,7 @@ TEST(FeatureProcessorTest, SpanToLabelIgnoresPunctuation) { flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); TestingFeatureProcessor feature_processor( flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), - &unilib); + &unilib_); std::vector<Token> tokens = feature_processor.Tokenize("one, two, three"); ASSERT_EQ(3, tokens.size()); int label; @@ -403,7 +402,7 @@ TEST(FeatureProcessorTest, SpanToLabelIgnoresPunctuation) { PackFeatureProcessorOptions(options); TestingFeatureProcessor feature_processor2( flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()), - &unilib); + &unilib_); int label2; ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2)); EXPECT_EQ(label, label2); @@ -425,7 +424,7 @@ TEST(FeatureProcessorTest, SpanToLabelIgnoresPunctuation) { PackFeatureProcessorOptions(options); TestingFeatureProcessor feature_processor3( flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()), - &unilib); + &unilib_); tokens = feature_processor3.Tokenize("zero, one, two, three, four"); ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2)); EXPECT_NE(kInvalidLabel, label2); @@ -442,7 +441,7 @@ TEST(FeatureProcessorTest, SpanToLabelIgnoresPunctuation) { EXPECT_EQ(label2, label3); } -TEST(FeatureProcessorTest, CenterTokenFromClick) { +TEST_F(FeatureProcessorTest, CenterTokenFromClick) { int token_index; // Exactly aligned indices. @@ -464,7 +463,7 @@ TEST(FeatureProcessorTest, CenterTokenFromClick) { EXPECT_EQ(token_index, kInvalidIndex); } -TEST(FeatureProcessorTest, CenterTokenFromMiddleOfSelection) { +TEST_F(FeatureProcessorTest, CenterTokenFromMiddleOfSelection) { int token_index; // Selection of length 3. Exactly aligned indices. @@ -507,7 +506,7 @@ TEST(FeatureProcessorTest, CenterTokenFromMiddleOfSelection) { EXPECT_EQ(token_index, -1); } -TEST(FeatureProcessorTest, SupportedCodepointsRatio) { +TEST_F(FeatureProcessorTest, SupportedCodepointsRatio) { FeatureProcessorOptionsT options; options.context_size = 2; options.max_selection_span = 2; @@ -556,10 +555,9 @@ TEST(FeatureProcessorTest, SupportedCodepointsRatio) { } flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); - CREATE_UNILIB_FOR_TESTING; TestingFeatureProcessor feature_processor( flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), - &unilib); + &unilib_); EXPECT_THAT(feature_processor.SupportedCodepointsRatio( {0, 3}, feature_processor.Tokenize("aaa bbb ccc")), FloatEq(1.0)); @@ -596,7 +594,7 @@ TEST(FeatureProcessorTest, SupportedCodepointsRatio) { PackFeatureProcessorOptions(options); TestingFeatureProcessor feature_processor2( flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()), - &unilib); + &unilib_); EXPECT_TRUE(feature_processor2.HasEnoughSupportedCodepoints( tokens, /*token_span=*/{0, 3})); @@ -605,7 +603,7 @@ TEST(FeatureProcessorTest, SupportedCodepointsRatio) { PackFeatureProcessorOptions(options); TestingFeatureProcessor feature_processor3( flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()), - &unilib); + &unilib_); EXPECT_TRUE(feature_processor3.HasEnoughSupportedCodepoints( tokens, /*token_span=*/{0, 3})); @@ -614,12 +612,12 @@ TEST(FeatureProcessorTest, SupportedCodepointsRatio) { PackFeatureProcessorOptions(options); TestingFeatureProcessor feature_processor4( flatbuffers::GetRoot<FeatureProcessorOptions>(options4_fb.data()), - &unilib); + &unilib_); EXPECT_FALSE(feature_processor4.HasEnoughSupportedCodepoints( tokens, /*token_span=*/{0, 3})); } -TEST(FeatureProcessorTest, InSpanFeature) { +TEST_F(FeatureProcessorTest, InSpanFeature) { FeatureProcessorOptionsT options; options.context_size = 2; options.max_selection_span = 2; @@ -629,10 +627,9 @@ TEST(FeatureProcessorTest, InSpanFeature) { options.extract_selection_mask_feature = true; flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); - CREATE_UNILIB_FOR_TESTING; TestingFeatureProcessor feature_processor( flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), - &unilib); + &unilib_); std::unique_ptr<CachedFeatures> cached_features; @@ -656,7 +653,7 @@ TEST(FeatureProcessorTest, InSpanFeature) { EXPECT_THAT(features[24], FloatEq(0.0)); } -TEST(FeatureProcessorTest, EmbeddingCache) { +TEST_F(FeatureProcessorTest, EmbeddingCache) { FeatureProcessorOptionsT options; options.context_size = 2; options.max_selection_span = 2; @@ -672,10 +669,9 @@ TEST(FeatureProcessorTest, EmbeddingCache) { options.bounds_sensitive_features->num_tokens_after = 3; flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); - CREATE_UNILIB_FOR_TESTING; TestingFeatureProcessor feature_processor( flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), - &unilib); + &unilib_); std::unique_ptr<CachedFeatures> cached_features; @@ -726,7 +722,7 @@ TEST(FeatureProcessorTest, EmbeddingCache) { ElementsAreFloat(embedding_cache.at({20, 23}))); } -TEST(FeatureProcessorTest, StripUnusedTokensWithNoRelativeClick) { +TEST_F(FeatureProcessorTest, StripUnusedTokensWithNoRelativeClick) { std::vector<Token> tokens_orig{ Token("0", 0, 0), Token("1", 0, 0), Token("2", 0, 0), Token("3", 0, 0), Token("4", 0, 0), Token("5", 0, 0), Token("6", 0, 0), Token("7", 0, 0), @@ -776,7 +772,7 @@ TEST(FeatureProcessorTest, StripUnusedTokensWithNoRelativeClick) { EXPECT_EQ(click_index, 2); } -TEST(FeatureProcessorTest, StripUnusedTokensWithRelativeClick) { +TEST_F(FeatureProcessorTest, StripUnusedTokensWithRelativeClick) { std::vector<Token> tokens_orig{ Token("0", 0, 0), Token("1", 0, 0), Token("2", 0, 0), Token("3", 0, 0), Token("4", 0, 0), Token("5", 0, 0), Token("6", 0, 0), Token("7", 0, 0), @@ -838,8 +834,7 @@ TEST(FeatureProcessorTest, StripUnusedTokensWithRelativeClick) { EXPECT_EQ(click_index, 5); } -TEST(FeatureProcessorTest, InternalTokenizeOnScriptChange) { - CREATE_UNILIB_FOR_TESTING; +TEST_F(FeatureProcessorTest, InternalTokenizeOnScriptChange) { FeatureProcessorOptionsT options; options.tokenization_codepoint_config.emplace_back( new TokenizationCodepointRangeT()); @@ -855,7 +850,7 @@ TEST(FeatureProcessorTest, InternalTokenizeOnScriptChange) { flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); TestingFeatureProcessor feature_processor( flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), - &unilib); + &unilib_); EXPECT_EQ(feature_processor.Tokenize("앨라배마123웹사이트"), std::vector<Token>({Token("앨라배마123웹사이트", 0, 11)})); @@ -865,21 +860,23 @@ TEST(FeatureProcessorTest, InternalTokenizeOnScriptChange) { PackFeatureProcessorOptions(options); TestingFeatureProcessor feature_processor2( flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb2.data()), - &unilib); + &unilib_); EXPECT_EQ(feature_processor2.Tokenize("앨라배마123웹사이트"), std::vector<Token>({Token("앨라배마", 0, 4), Token("123", 4, 7), Token("웹사이트", 7, 11)})); } -#ifdef LIBTEXTCLASSIFIER_TEST_ICU -TEST(FeatureProcessorTest, ICUTokenize) { +#ifdef TC3_TEST_ICU +TEST_F(FeatureProcessorTest, ICUTokenize) { FeatureProcessorOptionsT options; options.tokenization_type = FeatureProcessorOptions_::TokenizationType_ICU; flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + UniLib unilib; TestingFeatureProcessor feature_processor( - flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data())); + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib); std::vector<Token> tokens = feature_processor.Tokenize("พระบาทสมเด็จพระปรมิ"); ASSERT_EQ(tokens, // clang-format off @@ -892,15 +889,17 @@ TEST(FeatureProcessorTest, ICUTokenize) { } #endif -#ifdef LIBTEXTCLASSIFIER_TEST_ICU -TEST(FeatureProcessorTest, ICUTokenizeWithWhitespaces) { +#ifdef TC3_TEST_ICU +TEST_F(FeatureProcessorTest, ICUTokenizeWithWhitespaces) { FeatureProcessorOptionsT options; options.tokenization_type = FeatureProcessorOptions_::TokenizationType_ICU; options.icu_preserve_whitespace_tokens = true; flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + UniLib unilib; TestingFeatureProcessor feature_processor( - flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data())); + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib); std::vector<Token> tokens = feature_processor.Tokenize("พระบาท สมเด็จ พระ ปร มิ"); ASSERT_EQ(tokens, @@ -918,8 +917,8 @@ TEST(FeatureProcessorTest, ICUTokenizeWithWhitespaces) { } #endif -#ifdef LIBTEXTCLASSIFIER_TEST_ICU -TEST(FeatureProcessorTest, MixedTokenize) { +#ifdef TC3_TEST_ICU +TEST_F(FeatureProcessorTest, MixedTokenize) { FeatureProcessorOptionsT options; options.tokenization_type = FeatureProcessorOptions_::TokenizationType_MIXED; @@ -963,8 +962,10 @@ TEST(FeatureProcessorTest, MixedTokenize) { } flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); + UniLib unilib; TestingFeatureProcessor feature_processor( - flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data())); + flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), + &unilib); std::vector<Token> tokens = feature_processor.Tokenize( "こんにちはJapanese-ląnguagę text 世界 http://www.google.com/"); ASSERT_EQ(tokens, @@ -978,8 +979,7 @@ TEST(FeatureProcessorTest, MixedTokenize) { } #endif -TEST(FeatureProcessorTest, IgnoredSpanBoundaryCodepoints) { - CREATE_UNILIB_FOR_TESTING; +TEST_F(FeatureProcessorTest, IgnoredSpanBoundaryCodepoints) { FeatureProcessorOptionsT options; options.ignored_span_boundary_codepoints.push_back('.'); options.ignored_span_boundary_codepoints.push_back(','); @@ -989,7 +989,7 @@ TEST(FeatureProcessorTest, IgnoredSpanBoundaryCodepoints) { flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); TestingFeatureProcessor feature_processor( flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), - &unilib); + &unilib_); const std::string text1_utf8 = "ěščř"; const UnicodeText text1 = UTF8ToUnicodeText(text1_utf8, /*do_copy=*/false); @@ -1091,7 +1091,7 @@ TEST(FeatureProcessorTest, IgnoredSpanBoundaryCodepoints) { std::make_pair(0, 0)); } -TEST(FeatureProcessorTest, CodepointSpanToTokenSpan) { +TEST_F(FeatureProcessorTest, CodepointSpanToTokenSpan) { const std::vector<Token> tokens{Token("Hělló", 0, 5), Token("fěěbař@google.com", 6, 23), Token("heře!", 24, 29)}; @@ -1122,4 +1122,4 @@ TEST(FeatureProcessorTest, CodepointSpanToTokenSpan) { } } // namespace -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/annotator/knowledge/knowledge-engine-dummy.h b/annotator/knowledge/knowledge-engine-dummy.h new file mode 100644 index 0000000..a6285dc --- /dev/null +++ b/annotator/knowledge/knowledge-engine-dummy.h @@ -0,0 +1,47 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_DUMMY_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_DUMMY_H_ + +#include <string> + +#include "annotator/types.h" +#include "utils/utf8/unilib.h" + +namespace libtextclassifier3 { + +// A dummy implementation of the knowledge engine. +class KnowledgeEngine { + public: + explicit KnowledgeEngine(const UniLib* unilib) {} + + bool Initialize(const std::string& serialized_config) { return true; } + + bool ClassifyText(const std::string& context, CodepointSpan selection_indices, + ClassificationResult* classification_result) const { + return false; + } + + bool Chunk(const std::string& context, + std::vector<AnnotatedSpan>* result) const { + return true; + } +}; + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_DUMMY_H_ diff --git a/util/calendar/calendar.h b/annotator/knowledge/knowledge-engine.h index b0cf2e6..4776b26 100644 --- a/util/calendar/calendar.h +++ b/annotator/knowledge/knowledge-engine.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,9 +14,9 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_UTIL_CALENDAR_CALENDAR_H_ -#define LIBTEXTCLASSIFIER_UTIL_CALENDAR_CALENDAR_H_ +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_H_ -#include "util/calendar/calendar-icu.h" +#include "annotator/knowledge/knowledge-engine-dummy.h" -#endif // LIBTEXTCLASSIFIER_UTIL_CALENDAR_CALENDAR_H_ +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_H_ diff --git a/model-executor.cc b/annotator/model-executor.cc index 69931cb..7c57e8f 100644 --- a/model-executor.cc +++ b/annotator/model-executor.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,49 +14,48 @@ * limitations under the License. */ -#include "model-executor.h" +#include "annotator/model-executor.h" -#include "quantization.h" -#include "util/base/logging.h" +#include "annotator/quantization.h" +#include "utils/base/logging.h" -namespace libtextclassifier2 { -namespace internal { -bool FromModelSpec(const tflite::Model* model_spec, - std::unique_ptr<const tflite::FlatBufferModel>* model) { - *model = tflite::FlatBufferModel::BuildFromModel(model_spec); - if (!(*model) || !(*model)->initialized()) { - TC_LOG(ERROR) << "Could not build TFLite model from a model spec. "; - return false; +namespace libtextclassifier3 { + +TensorView<float> ModelExecutor::ComputeLogits( + const TensorView<float>& features, tflite::Interpreter* interpreter) const { + if (!interpreter) { + return TensorView<float>::Invalid(); + } + interpreter->ResizeInputTensor(kInputIndexFeatures, features.shape()); + if (interpreter->AllocateTensors() != kTfLiteOk) { + TC3_VLOG(1) << "Allocation failed."; + return TensorView<float>::Invalid(); + } + + SetInput<float>(kInputIndexFeatures, features, interpreter); + + if (interpreter->Invoke() != kTfLiteOk) { + TC3_VLOG(1) << "Interpreter failed."; + return TensorView<float>::Invalid(); } - return true; -} -} // namespace internal -std::unique_ptr<tflite::Interpreter> ModelExecutor::CreateInterpreter() const { - std::unique_ptr<tflite::Interpreter> interpreter; - tflite::InterpreterBuilder(*model_, builtins_)(&interpreter); - return interpreter; + return OutputView<float>(kOutputIndexLogits, interpreter); } -std::unique_ptr<TFLiteEmbeddingExecutor> TFLiteEmbeddingExecutor::Instance( +std::unique_ptr<TFLiteEmbeddingExecutor> TFLiteEmbeddingExecutor::FromBuffer( const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size, int quantization_bits) { - const tflite::Model* model_spec = - flatbuffers::GetRoot<tflite::Model>(model_spec_buffer->data()); - flatbuffers::Verifier verifier(model_spec_buffer->data(), - model_spec_buffer->Length()); - std::unique_ptr<const tflite::FlatBufferModel> model; - if (!model_spec->Verify(verifier) || - !internal::FromModelSpec(model_spec, &model)) { - TC_LOG(ERROR) << "Could not load TFLite model."; + std::unique_ptr<TfLiteModelExecutor> executor = + TfLiteModelExecutor::FromBuffer(model_spec_buffer); + if (!executor) { + TC3_LOG(ERROR) << "Could not load TFLite model for embeddings."; return nullptr; } - std::unique_ptr<tflite::Interpreter> interpreter; - tflite::ops::builtin::BuiltinOpResolver builtins; - tflite::InterpreterBuilder(*model, builtins)(&interpreter); + std::unique_ptr<tflite::Interpreter> interpreter = + executor->CreateInterpreter(); if (!interpreter) { - TC_LOG(ERROR) << "Could not build TFLite interpreter for embeddings."; + TC3_LOG(ERROR) << "Could not build TFLite interpreter for embeddings."; return nullptr; } @@ -76,21 +75,21 @@ std::unique_ptr<TFLiteEmbeddingExecutor> TFLiteEmbeddingExecutor::Instance( int bytes_per_embedding = embeddings->dims->data[1]; if (!CheckQuantizationParams(bytes_per_embedding, quantization_bits, embedding_size)) { - TC_LOG(ERROR) << "Mismatch in quantization parameters."; + TC3_LOG(ERROR) << "Mismatch in quantization parameters."; return nullptr; } return std::unique_ptr<TFLiteEmbeddingExecutor>(new TFLiteEmbeddingExecutor( - std::move(model), quantization_bits, num_buckets, bytes_per_embedding, + std::move(executor), quantization_bits, num_buckets, bytes_per_embedding, embedding_size, scales, embeddings, std::move(interpreter))); } TFLiteEmbeddingExecutor::TFLiteEmbeddingExecutor( - std::unique_ptr<const tflite::FlatBufferModel> model, int quantization_bits, + std::unique_ptr<TfLiteModelExecutor> executor, int quantization_bits, int num_buckets, int bytes_per_embedding, int output_embedding_size, const TfLiteTensor* scales, const TfLiteTensor* embeddings, std::unique_ptr<tflite::Interpreter> interpreter) - : model_(std::move(model)), + : executor_(std::move(executor)), quantization_bits_(quantization_bits), num_buckets_(num_buckets), bytes_per_embedding_(bytes_per_embedding), @@ -102,8 +101,8 @@ TFLiteEmbeddingExecutor::TFLiteEmbeddingExecutor( bool TFLiteEmbeddingExecutor::AddEmbedding( const TensorView<int>& sparse_features, float* dest, int dest_size) const { if (dest_size != output_embedding_size_) { - TC_LOG(ERROR) << "Mismatching dest_size and output_embedding_size: " - << dest_size << " " << output_embedding_size_; + TC3_LOG(ERROR) << "Mismatching dest_size and output_embedding_size: " + << dest_size << " " << output_embedding_size_; return false; } const int num_sparse_features = sparse_features.size(); @@ -122,41 +121,4 @@ bool TFLiteEmbeddingExecutor::AddEmbedding( return true; } -TensorView<float> ComputeLogitsHelper(const int input_index_features, - const int output_index_logits, - const TensorView<float>& features, - tflite::Interpreter* interpreter) { - if (!interpreter) { - return TensorView<float>::Invalid(); - } - interpreter->ResizeInputTensor(input_index_features, features.shape()); - if (interpreter->AllocateTensors() != kTfLiteOk) { - TC_VLOG(1) << "Allocation failed."; - return TensorView<float>::Invalid(); - } - - TfLiteTensor* features_tensor = - interpreter->tensor(interpreter->inputs()[input_index_features]); - int size = 1; - for (int i = 0; i < features_tensor->dims->size; ++i) { - size *= features_tensor->dims->data[i]; - } - features.copy_to(features_tensor->data.f, size); - - if (interpreter->Invoke() != kTfLiteOk) { - TC_VLOG(1) << "Interpreter failed."; - return TensorView<float>::Invalid(); - } - - TfLiteTensor* logits_tensor = - interpreter->tensor(interpreter->outputs()[output_index_logits]); - - std::vector<int> output_shape(logits_tensor->dims->size); - for (int i = 0; i < logits_tensor->dims->size; ++i) { - output_shape[i] = logits_tensor->dims->data[i]; - } - - return TensorView<float>(logits_tensor->data.f, output_shape); -} - -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/model-executor.h b/annotator/model-executor.h index ef6d36f..5ad3a7f 100644 --- a/model-executor.h +++ b/annotator/model-executor.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,75 +16,48 @@ // Contains classes that can execute different models/parts of a model. -#ifndef LIBTEXTCLASSIFIER_MODEL_EXECUTOR_H_ -#define LIBTEXTCLASSIFIER_MODEL_EXECUTOR_H_ +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_ #include <memory> -#include "tensor-view.h" -#include "types.h" -#include "util/base/logging.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" +#include "annotator/types.h" +#include "utils/base/logging.h" +#include "utils/tensor-view.h" +#include "utils/tflite-model-executor.h" -namespace libtextclassifier2 { - -namespace internal { -bool FromModelSpec(const tflite::Model* model_spec, - std::unique_ptr<const tflite::FlatBufferModel>* model); -} // namespace internal - -// A helper function that given indices of feature and logits tensor, feature -// values computes the logits using given interpreter. -TensorView<float> ComputeLogitsHelper(const int input_index_features, - const int output_index_logits, - const TensorView<float>& features, - tflite::Interpreter* interpreter); +namespace libtextclassifier3 { // Executor for the text selection prediction and classification models. -class ModelExecutor { +class ModelExecutor : public TfLiteModelExecutor { public: - static std::unique_ptr<const ModelExecutor> Instance( - const flatbuffers::Vector<uint8_t>* model_spec_buffer) { - const tflite::Model* model = - flatbuffers::GetRoot<tflite::Model>(model_spec_buffer->data()); - flatbuffers::Verifier verifier(model_spec_buffer->data(), - model_spec_buffer->Length()); - if (!model->Verify(verifier)) { + static std::unique_ptr<ModelExecutor> FromModelSpec( + const tflite::Model* model_spec) { + auto model = TfLiteModelFromModelSpec(model_spec); + if (!model) { return nullptr; } - return Instance(model); + return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model))); } - static std::unique_ptr<const ModelExecutor> Instance( - const tflite::Model* model_spec) { - std::unique_ptr<const tflite::FlatBufferModel> model; - if (!internal::FromModelSpec(model_spec, &model)) { + static std::unique_ptr<ModelExecutor> FromBuffer( + const flatbuffers::Vector<uint8_t>* model_spec_buffer) { + auto model = TfLiteModelFromBuffer(model_spec_buffer); + if (!model) { return nullptr; } return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model))); } - // Creates an Interpreter for the model that serves as a scratch-pad for the - // inference. The Interpreter is NOT thread-safe. - std::unique_ptr<tflite::Interpreter> CreateInterpreter() const; - TensorView<float> ComputeLogits(const TensorView<float>& features, - tflite::Interpreter* interpreter) const { - return ComputeLogitsHelper(kInputIndexFeatures, kOutputIndexLogits, - features, interpreter); - } + tflite::Interpreter* interpreter) const; protected: explicit ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model) - : model_(std::move(model)) {} + : TfLiteModelExecutor(std::move(model)) {} static const int kInputIndexFeatures = 0; static const int kOutputIndexLogits = 0; - - std::unique_ptr<const tflite::FlatBufferModel> model_; - tflite::ops::builtin::BuiltinOpResolver builtins_; }; // Executor for embedding sparse features into a dense vector. @@ -103,22 +76,23 @@ class EmbeddingExecutor { class TFLiteEmbeddingExecutor : public EmbeddingExecutor { public: - static std::unique_ptr<TFLiteEmbeddingExecutor> Instance( + static std::unique_ptr<TFLiteEmbeddingExecutor> FromBuffer( const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size, int quantization_bits); + // Embeds the sparse_features into a dense embedding and adds (+) it + // element-wise to the dest vector. bool AddEmbedding(const TensorView<int>& sparse_features, float* dest, - int dest_size) const override; + int dest_size) const; protected: explicit TFLiteEmbeddingExecutor( - std::unique_ptr<const tflite::FlatBufferModel> model, - int quantization_bits, int num_buckets, int bytes_per_embedding, - int output_embedding_size, const TfLiteTensor* scales, - const TfLiteTensor* embeddings, + std::unique_ptr<TfLiteModelExecutor> executor, int quantization_bits, + int num_buckets, int bytes_per_embedding, int output_embedding_size, + const TfLiteTensor* scales, const TfLiteTensor* embeddings, std::unique_ptr<tflite::Interpreter> interpreter); - std::unique_ptr<const tflite::FlatBufferModel> model_; + std::unique_ptr<TfLiteModelExecutor> executor_; int quantization_bits_; int num_buckets_ = -1; @@ -132,6 +106,6 @@ class TFLiteEmbeddingExecutor : public EmbeddingExecutor { std::unique_ptr<tflite::Interpreter> interpreter_; }; -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_MODEL_EXECUTOR_H_ +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_ diff --git a/model.fbs b/annotator/model.fbs index fb9778b..3682994 100755 --- a/model.fbs +++ b/annotator/model.fbs @@ -1,5 +1,5 @@ // -// Copyright (C) 2017 The Android Open Source Project +// Copyright (C) 2018 The Android Open Source Project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,10 +14,13 @@ // limitations under the License. // +include "utils/intents/intent-config.fbs"; +include "utils/zlib/buffer.fbs"; + file_identifier "TC2 "; // The possible model modes, represents a bit field. -namespace libtextclassifier2; +namespace libtextclassifier3; enum ModeFlag : int { NONE = 0, ANNOTATION = 1, @@ -29,7 +32,7 @@ enum ModeFlag : int { ALL = 7, } -namespace libtextclassifier2; +namespace libtextclassifier3; enum DatetimeExtractorType : int { UNKNOWN_DATETIME_EXTRACTOR_TYPE = 0, AM = 1, @@ -106,7 +109,7 @@ enum DatetimeExtractorType : int { THOUSAND = 72, } -namespace libtextclassifier2; +namespace libtextclassifier3; enum DatetimeGroupType : int { GROUP_UNKNOWN = 0, GROUP_UNUSED = 1, @@ -129,20 +132,14 @@ enum DatetimeGroupType : int { GROUP_DUMMY2 = 13, } -namespace libtextclassifier2; -table CompressedBuffer { - buffer:[ubyte]; - uncompressed_size:int; -} - // Options for the model that predicts text selection. -namespace libtextclassifier2; +namespace libtextclassifier3; table SelectionModelOptions { // If true, before the selection is returned, the unpaired brackets contained // in the predicted selection are stripped from the both selection ends. // The bracket codepoints are defined in the Unicode standard: // http://www.unicode.org/Public/UNIDATA/BidiBrackets.txt - strip_unpaired_brackets:bool = 1; + strip_unpaired_brackets:bool = true; // Number of hypothetical click positions on either side of the actual click // to consider in order to enforce symmetry. @@ -152,11 +149,11 @@ table SelectionModelOptions { batch_size:int = 1024; // Whether to always classify a suggested selection or only on demand. - always_classify_suggested_selection:bool = 0; + always_classify_suggested_selection:bool = false; } // Options for the model that classifies a text selection. -namespace libtextclassifier2; +namespace libtextclassifier3; table ClassificationModelOptions { // Limits for phone numbers. phone_min_num_digits:int = 7; @@ -170,8 +167,14 @@ table ClassificationModelOptions { max_num_tokens:int = -1; } +// Options for post-checks, checksums and verification to apply on a match. +namespace libtextclassifier3; +table VerificationOptions { + verify_luhn_checksum:bool = false; +} + // List of regular expression matchers to check. -namespace libtextclassifier2.RegexModel_; +namespace libtextclassifier3.RegexModel_; table Pattern { // The name of the collection of a match. collection_name:string; @@ -181,7 +184,7 @@ table Pattern { pattern:string; // The modes for which to apply the patterns. - enabled_modes:libtextclassifier2.ModeFlag = ALL; + enabled_modes:libtextclassifier3.ModeFlag = ALL; // The final score to assign to the results of this pattern. target_classification_score:float = 1; @@ -192,31 +195,34 @@ table Pattern { // If true, will use an approximate matching implementation implemented // using Find() instead of the true Match(). This approximate matching will // use the first Find() result and then check that it spans the whole input. - use_approximate_matching:bool = 0; + use_approximate_matching:bool = false; + + compressed_pattern:libtextclassifier3.CompressedBuffer; - compressed_pattern:libtextclassifier2.CompressedBuffer; + // Verification to apply on a match. + verification_options:libtextclassifier3.VerificationOptions; } -namespace libtextclassifier2; +namespace libtextclassifier3; table RegexModel { - patterns:[libtextclassifier2.RegexModel_.Pattern]; + patterns:[libtextclassifier3.RegexModel_.Pattern]; } // List of regex patterns. -namespace libtextclassifier2.DatetimeModelPattern_; +namespace libtextclassifier3.DatetimeModelPattern_; table Regex { pattern:string; // The ith entry specifies the type of the ith capturing group. // This is used to decide how the matched content has to be parsed. - groups:[libtextclassifier2.DatetimeGroupType]; + groups:[libtextclassifier3.DatetimeGroupType]; - compressed_pattern:libtextclassifier2.CompressedBuffer; + compressed_pattern:libtextclassifier3.CompressedBuffer; } -namespace libtextclassifier2; +namespace libtextclassifier3; table DatetimeModelPattern { - regexes:[libtextclassifier2.DatetimeModelPattern_.Regex]; + regexes:[libtextclassifier3.DatetimeModelPattern_.Regex]; // List of locale indices in DatetimeModel that represent the locales that // these patterns should be used for. If empty, can be used for all locales. @@ -225,63 +231,63 @@ table DatetimeModelPattern { // The final score to assign to the results of this pattern. target_classification_score:float = 1; - // Priority score used for conflict resulution with the other models. + // Priority score used for conflict resolution with the other models. priority_score:float = 0; // The modes for which to apply the patterns. - enabled_modes:libtextclassifier2.ModeFlag = ALL; + enabled_modes:libtextclassifier3.ModeFlag = ALL; } -namespace libtextclassifier2; +namespace libtextclassifier3; table DatetimeModelExtractor { - extractor:libtextclassifier2.DatetimeExtractorType; + extractor:libtextclassifier3.DatetimeExtractorType; pattern:string; locales:[int]; - compressed_pattern:libtextclassifier2.CompressedBuffer; + compressed_pattern:libtextclassifier3.CompressedBuffer; } -namespace libtextclassifier2; +namespace libtextclassifier3; table DatetimeModel { // List of BCP 47 locale strings representing all locales supported by the // model. The individual patterns refer back to them using an index. locales:[string]; - patterns:[libtextclassifier2.DatetimeModelPattern]; - extractors:[libtextclassifier2.DatetimeModelExtractor]; + patterns:[libtextclassifier3.DatetimeModelPattern]; + extractors:[libtextclassifier3.DatetimeModelExtractor]; // If true, will use the extractors for determining the match location as // opposed to using the location where the global pattern matched. - use_extractors_for_locating:bool = 1; + use_extractors_for_locating:bool = true; // List of locale ids, rules of whose are always run, after the requested // ones. default_locales:[int]; } -namespace libtextclassifier2.DatetimeModelLibrary_; +namespace libtextclassifier3.DatetimeModelLibrary_; table Item { key:string; - value:libtextclassifier2.DatetimeModel; + value:libtextclassifier3.DatetimeModel; } // A set of named DateTime models. -namespace libtextclassifier2; +namespace libtextclassifier3; table DatetimeModelLibrary { - models:[libtextclassifier2.DatetimeModelLibrary_.Item]; + models:[libtextclassifier3.DatetimeModelLibrary_.Item]; } // Options controlling the output of the Tensorflow Lite models. -namespace libtextclassifier2; +namespace libtextclassifier3; table ModelTriggeringOptions { // Lower bound threshold for filtering annotation model outputs. min_annotate_confidence:float = 0; // The modes for which to enable the models. - enabled_modes:libtextclassifier2.ModeFlag = ALL; + enabled_modes:libtextclassifier3.ModeFlag = ALL; } // Options controlling the output of the classifier. -namespace libtextclassifier2; +namespace libtextclassifier3; table OutputOptions { // Lists of collection names that will be filtered out at the output: // - For annotation, the spans of given collection are simply dropped. @@ -294,7 +300,7 @@ table OutputOptions { filtered_collections_selection:[string]; } -namespace libtextclassifier2; +namespace libtextclassifier3; table Model { // Comma-separated list of locales supported by the model as BCP 47 tags. locales:string; @@ -304,8 +310,8 @@ table Model { // A name for the model that can be used for e.g. logging. name:string; - selection_feature_options:libtextclassifier2.FeatureProcessorOptions; - classification_feature_options:libtextclassifier2.FeatureProcessorOptions; + selection_feature_options:libtextclassifier3.FeatureProcessorOptions; + classification_feature_options:libtextclassifier3.FeatureProcessorOptions; // Tensorflow Lite models. selection_model:[ubyte] (force_align: 16); @@ -314,31 +320,37 @@ table Model { embedding_model:[ubyte] (force_align: 16); // Options for the different models. - selection_options:libtextclassifier2.SelectionModelOptions; + selection_options:libtextclassifier3.SelectionModelOptions; - classification_options:libtextclassifier2.ClassificationModelOptions; - regex_model:libtextclassifier2.RegexModel; - datetime_model:libtextclassifier2.DatetimeModel; + classification_options:libtextclassifier3.ClassificationModelOptions; + regex_model:libtextclassifier3.RegexModel; + datetime_model:libtextclassifier3.DatetimeModel; // Options controlling the output of the models. - triggering_options:libtextclassifier2.ModelTriggeringOptions; + triggering_options:libtextclassifier3.ModelTriggeringOptions; // Global switch that controls if SuggestSelection(), ClassifyText() and // Annotate() will run. If a mode is disabled it returns empty/no-op results. - enabled_modes:libtextclassifier2.ModeFlag = ALL; + enabled_modes:libtextclassifier3.ModeFlag = ALL; // If true, will snap the selections that consist only of whitespaces to the // containing suggested span. Otherwise, no suggestion is proposed, since the // selections are not part of any token. - snap_whitespace_selections:bool = 1; + snap_whitespace_selections:bool = true; // Global configuration for the output of SuggestSelection(), ClassifyText() // and Annotate(). - output_options:libtextclassifier2.OutputOptions; + output_options:libtextclassifier3.OutputOptions; + + // Configures how Intents should be generated on Android. + // TODO(smillius): Remove deprecated factory options. + android_intent_options:libtextclassifier3.AndroidIntentFactoryOptions; + + intent_options:libtextclassifier3.IntentFactoryModel; } // Role of the codepoints in the range. -namespace libtextclassifier2.TokenizationCodepointRange_; +namespace libtextclassifier3.TokenizationCodepointRange_; enum Role : int { // Concatenates the codepoint to the current run of codepoints. DEFAULT_ROLE = 0, @@ -363,11 +375,11 @@ enum Role : int { } // Represents a codepoint range [start, end) with its role for tokenization. -namespace libtextclassifier2; +namespace libtextclassifier3; table TokenizationCodepointRange { start:int; end:int; - role:libtextclassifier2.TokenizationCodepointRange_.Role; + role:libtextclassifier3.TokenizationCodepointRange_.Role; // Integer identifier of the script this range denotes. Negative values are // reserved for Tokenizer's internal use. @@ -375,7 +387,7 @@ table TokenizationCodepointRange { } // Method for selecting the center token. -namespace libtextclassifier2.FeatureProcessorOptions_; +namespace libtextclassifier3.FeatureProcessorOptions_; enum CenterTokenSelectionMethod : int { DEFAULT_CENTER_TOKEN_METHOD = 0, @@ -388,7 +400,7 @@ enum CenterTokenSelectionMethod : int { } // Controls the type of tokenization the model will use for the input text. -namespace libtextclassifier2.FeatureProcessorOptions_; +namespace libtextclassifier3.FeatureProcessorOptions_; enum TokenizationType : int { INVALID_TOKENIZATION_TYPE = 0, @@ -405,14 +417,14 @@ enum TokenizationType : int { } // Range of codepoints start - end, where end is exclusive. -namespace libtextclassifier2.FeatureProcessorOptions_; +namespace libtextclassifier3.FeatureProcessorOptions_; table CodepointRange { start:int; end:int; } // Bounds-sensitive feature extraction configuration. -namespace libtextclassifier2.FeatureProcessorOptions_; +namespace libtextclassifier3.FeatureProcessorOptions_; table BoundsSensitiveFeatures { // Enables the extraction of bounds-sensitive features, instead of the click // context features. @@ -445,13 +457,7 @@ table BoundsSensitiveFeatures { score_single_token_spans_as_zero:bool; } -namespace libtextclassifier2.FeatureProcessorOptions_; -table AlternativeCollectionMapEntry { - key:string; - value:string; -} - -namespace libtextclassifier2; +namespace libtextclassifier3; table FeatureProcessorOptions { // Number of buckets used for hashing charactergrams. num_buckets:int = -1; @@ -479,20 +485,20 @@ table FeatureProcessorOptions { max_word_length:int = 20; // If true, will use the unicode-aware functionality for extracting features. - unicode_aware_features:bool = 0; + unicode_aware_features:bool = false; // Whether to extract the token case feature. - extract_case_feature:bool = 0; + extract_case_feature:bool = false; // Whether to extract the selection mask feature. - extract_selection_mask_feature:bool = 0; + extract_selection_mask_feature:bool = false; // List of regexps to run over each token. For each regexp, if there is a // match, a dense feature of 1.0 is emitted. Otherwise -1.0 is used. regexp_feature:[string]; // Whether to remap all digits to a single number. - remap_digits:bool = 0; + remap_digits:bool = false; // Whether to lower-case each token before generating hashgrams. lowercase_tokens:bool; @@ -504,7 +510,7 @@ table FeatureProcessorOptions { // infeasible ones. // NOTE: Exists mainly for compatibility with older models that were trained // with the non-reduced output space. - selection_reduced_output_space:bool = 1; + selection_reduced_output_space:bool = true; // Collection names. collections:[string]; @@ -515,29 +521,29 @@ table FeatureProcessorOptions { // If true, will split the input by lines, and only use the line that contains // the clicked token. - only_use_line_with_click:bool = 0; + only_use_line_with_click:bool = false; // If true, will split tokens that contain the selection boundary, at the // position of the boundary. // E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com" - split_tokens_on_selection_boundaries:bool = 0; + split_tokens_on_selection_boundaries:bool = false; // Codepoint ranges that determine how different codepoints are tokenized. // The ranges must not overlap. - tokenization_codepoint_config:[libtextclassifier2.TokenizationCodepointRange]; + tokenization_codepoint_config:[libtextclassifier3.TokenizationCodepointRange]; - center_token_selection_method:libtextclassifier2.FeatureProcessorOptions_.CenterTokenSelectionMethod; + center_token_selection_method:libtextclassifier3.FeatureProcessorOptions_.CenterTokenSelectionMethod; // If true, span boundaries will be snapped to containing tokens and not // required to exactly match token boundaries. snap_label_span_boundaries_to_containing_tokens:bool; // A set of codepoint ranges supported by the model. - supported_codepoint_ranges:[libtextclassifier2.FeatureProcessorOptions_.CodepointRange]; + supported_codepoint_ranges:[libtextclassifier3.FeatureProcessorOptions_.CodepointRange]; // A set of codepoint ranges to use in the mixed tokenization mode to identify // stretches of tokens to re-tokenize using the internal tokenizer. - internal_tokenizer_codepoint_ranges:[libtextclassifier2.FeatureProcessorOptions_.CodepointRange]; + internal_tokenizer_codepoint_ranges:[libtextclassifier3.FeatureProcessorOptions_.CodepointRange]; // Minimum ratio of supported codepoints in the input context. If the ratio // is lower than this, the feature computation will fail. @@ -553,14 +559,14 @@ table FeatureProcessorOptions { // to it. So the resulting feature vector has two regions. feature_version:int = 0; - tokenization_type:libtextclassifier2.FeatureProcessorOptions_.TokenizationType = INTERNAL_TOKENIZER; - icu_preserve_whitespace_tokens:bool = 0; + tokenization_type:libtextclassifier3.FeatureProcessorOptions_.TokenizationType = INTERNAL_TOKENIZER; + icu_preserve_whitespace_tokens:bool = false; // List of codepoints that will be stripped from beginning and end of // predicted spans. ignored_span_boundary_codepoints:[int]; - bounds_sensitive_features:libtextclassifier2.FeatureProcessorOptions_.BoundsSensitiveFeatures; + bounds_sensitive_features:libtextclassifier3.FeatureProcessorOptions_.BoundsSensitiveFeatures; // List of allowed charactergrams. The extracted charactergrams are filtered // using this list, and charactergrams that are not present are interpreted as @@ -571,7 +577,7 @@ table FeatureProcessorOptions { // If true, tokens will be also split when the codepoint's script_id changes // as defined in TokenizationCodepointRange. - tokenize_on_script_change:bool = 0; + tokenize_on_script_change:bool = false; } -root_type libtextclassifier2.Model; +root_type libtextclassifier3.Model; diff --git a/quantization.cc b/annotator/quantization.cc index 1a34565..2cf11c5 100644 --- a/quantization.cc +++ b/annotator/quantization.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,11 +14,11 @@ * limitations under the License. */ -#include "quantization.h" +#include "annotator/quantization.h" -#include "util/base/logging.h" +#include "utils/base/logging.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace { float DequantizeValue(int num_sparse_features, int quantization_bias, float multiplier, int value) { @@ -82,11 +82,11 @@ bool DequantizeAdd(const float* scales, const uint8* embeddings, num_sparse_features, quantization_bits, bucket_id, dest, dest_size); } else { - TC_LOG(ERROR) << "Unsupported quantization_bits: " << quantization_bits; + TC3_LOG(ERROR) << "Unsupported quantization_bits: " << quantization_bits; return false; } return true; } -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/quantization.h b/annotator/quantization.h index c486640..d294f37 100644 --- a/quantization.h +++ b/annotator/quantization.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,12 +14,12 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_QUANTIZATION_H_ -#define LIBTEXTCLASSIFIER_QUANTIZATION_H_ +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_QUANTIZATION_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_QUANTIZATION_H_ -#include "util/base/integral_types.h" +#include "utils/base/integral_types.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { // Returns true if the quantization parameters are valid. bool CheckQuantizationParams(int bytes_per_embedding, int quantization_bits, @@ -34,6 +34,6 @@ bool DequantizeAdd(const float* scales, const uint8* embeddings, int quantization_bits, int bucket_id, float* dest, int dest_size); -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_QUANTIZATION_H_ +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_QUANTIZATION_H_ diff --git a/quantization_test.cc b/annotator/quantization_test.cc index 088daaf..b995096 100644 --- a/quantization_test.cc +++ b/annotator/quantization_test.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "quantization.h" +#include "annotator/quantization.h" #include <vector> @@ -25,7 +25,7 @@ using testing::ElementsAreArray; using testing::FloatEq; using testing::Matcher; -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace { Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) { @@ -160,4 +160,4 @@ TEST(QuantizationTest, DequantizeAdd3bit) { } } // namespace -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/strip-unpaired-brackets.cc b/annotator/strip-unpaired-brackets.cc index ddf3322..b1067ad 100644 --- a/strip-unpaired-brackets.cc +++ b/annotator/strip-unpaired-brackets.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,14 +14,14 @@ * limitations under the License. */ -#include "strip-unpaired-brackets.h" +#include "annotator/strip-unpaired-brackets.h" #include <iterator> -#include "util/base/logging.h" -#include "util/utf8/unicodetext.h" +#include "utils/base/logging.h" +#include "utils/utf8/unicodetext.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace { // Returns true if given codepoint is contained in the given span in context. @@ -94,12 +94,12 @@ CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode, // Should not happen, but let's make sure. if (span.first > span.second) { - TC_LOG(WARNING) << "Inverse indices result: " << span.first << ", " - << span.second; + TC3_LOG(WARNING) << "Inverse indices result: " << span.first << ", " + << span.second; span.second = span.first; } return span; } -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/strip-unpaired-brackets.h b/annotator/strip-unpaired-brackets.h index 4e82c3e..ceb8d60 100644 --- a/strip-unpaired-brackets.h +++ b/annotator/strip-unpaired-brackets.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,15 +14,15 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_STRIP_UNPAIRED_BRACKETS_H_ -#define LIBTEXTCLASSIFIER_STRIP_UNPAIRED_BRACKETS_H_ +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_STRIP_UNPAIRED_BRACKETS_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_STRIP_UNPAIRED_BRACKETS_H_ #include <string> -#include "types.h" -#include "util/utf8/unilib.h" +#include "annotator/types.h" +#include "utils/utf8/unilib.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { // If the first or the last codepoint of the given span is a bracket, the // bracket is stripped if the span does not contain its corresponding paired // version. @@ -33,6 +33,6 @@ CodepointSpan StripUnpairedBrackets(const std::string& context, CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode, CodepointSpan span, const UniLib& unilib); -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_STRIP_UNPAIRED_BRACKETS_H_ +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_STRIP_UNPAIRED_BRACKETS_H_ diff --git a/strip-unpaired-brackets_test.cc b/annotator/strip-unpaired-brackets_test.cc index 5362500..32585ce 100644 --- a/strip-unpaired-brackets_test.cc +++ b/annotator/strip-unpaired-brackets_test.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,48 +14,53 @@ * limitations under the License. */ -#include "strip-unpaired-brackets.h" +#include "annotator/strip-unpaired-brackets.h" #include "gtest/gtest.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace { -TEST(StripUnpairedBracketsTest, StripUnpairedBrackets) { - CREATE_UNILIB_FOR_TESTING +class StripUnpairedBracketsTest : public ::testing::Test { + protected: + StripUnpairedBracketsTest() : INIT_UNILIB_FOR_TESTING(unilib_) {} + UniLib unilib_; +}; + +TEST_F(StripUnpairedBracketsTest, StripUnpairedBrackets) { // If the brackets match, nothing gets stripped. - EXPECT_EQ(StripUnpairedBrackets("call me (123) 456 today", {8, 17}, unilib), + EXPECT_EQ(StripUnpairedBrackets("call me (123) 456 today", {8, 17}, unilib_), std::make_pair(8, 17)); - EXPECT_EQ(StripUnpairedBrackets("call me (123 456) today", {8, 17}, unilib), + EXPECT_EQ(StripUnpairedBrackets("call me (123 456) today", {8, 17}, unilib_), std::make_pair(8, 17)); // If the brackets don't match, they get stripped. - EXPECT_EQ(StripUnpairedBrackets("call me (123 456 today", {8, 16}, unilib), + EXPECT_EQ(StripUnpairedBrackets("call me (123 456 today", {8, 16}, unilib_), std::make_pair(9, 16)); - EXPECT_EQ(StripUnpairedBrackets("call me )123 456 today", {8, 16}, unilib), + EXPECT_EQ(StripUnpairedBrackets("call me )123 456 today", {8, 16}, unilib_), std::make_pair(9, 16)); - EXPECT_EQ(StripUnpairedBrackets("call me 123 456) today", {8, 16}, unilib), + EXPECT_EQ(StripUnpairedBrackets("call me 123 456) today", {8, 16}, unilib_), std::make_pair(8, 15)); - EXPECT_EQ(StripUnpairedBrackets("call me 123 456( today", {8, 16}, unilib), + EXPECT_EQ(StripUnpairedBrackets("call me 123 456( today", {8, 16}, unilib_), std::make_pair(8, 15)); // Strips brackets correctly from length-1 selections that consist of // a bracket only. - EXPECT_EQ(StripUnpairedBrackets("call me at ) today", {11, 12}, unilib), + EXPECT_EQ(StripUnpairedBrackets("call me at ) today", {11, 12}, unilib_), std::make_pair(12, 12)); - EXPECT_EQ(StripUnpairedBrackets("call me at ( today", {11, 12}, unilib), + EXPECT_EQ(StripUnpairedBrackets("call me at ( today", {11, 12}, unilib_), std::make_pair(12, 12)); // Handles invalid spans gracefully. - EXPECT_EQ(StripUnpairedBrackets("call me at today", {11, 11}, unilib), + EXPECT_EQ(StripUnpairedBrackets("call me at today", {11, 11}, unilib_), std::make_pair(11, 11)); - EXPECT_EQ(StripUnpairedBrackets("hello world", {0, 0}, unilib), + EXPECT_EQ(StripUnpairedBrackets("hello world", {0, 0}, unilib_), std::make_pair(0, 0)); - EXPECT_EQ(StripUnpairedBrackets("hello world", {11, 11}, unilib), + EXPECT_EQ(StripUnpairedBrackets("hello world", {11, 11}, unilib_), std::make_pair(11, 11)); - EXPECT_EQ(StripUnpairedBrackets("hello world", {-1, -1}, unilib), + EXPECT_EQ(StripUnpairedBrackets("hello world", {-1, -1}, unilib_), std::make_pair(-1, -1)); } } // namespace -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/test_data/test_model.fb b/annotator/test_data/test_model.fb Binary files differindex c651bdb..fa9cec5 100644 --- a/test_data/test_model.fb +++ b/annotator/test_data/test_model.fb diff --git a/test_data/test_model_cc.fb b/annotator/test_data/test_model_cc.fb Binary files differindex 53af6bf..b73d84f 100644 --- a/test_data/test_model_cc.fb +++ b/annotator/test_data/test_model_cc.fb diff --git a/annotator/test_data/wrong_embeddings.fb b/annotator/test_data/wrong_embeddings.fb Binary files differnew file mode 100644 index 0000000..ba71cdd --- /dev/null +++ b/annotator/test_data/wrong_embeddings.fb diff --git a/token-feature-extractor.cc b/annotator/token-feature-extractor.cc index 13fba30..77ad7a4 100644 --- a/token-feature-extractor.cc +++ b/annotator/token-feature-extractor.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,17 +14,17 @@ * limitations under the License. */ -#include "token-feature-extractor.h" +#include "annotator/token-feature-extractor.h" #include <cctype> #include <string> -#include "util/base/logging.h" -#include "util/hash/farmhash.h" -#include "util/strings/stringpiece.h" -#include "util/utf8/unicodetext.h" +#include "utils/base/logging.h" +#include "utils/hash/farmhash.h" +#include "utils/strings/stringpiece.h" +#include "utils/utf8/unicodetext.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace { @@ -58,11 +58,11 @@ void RemapTokenUnicode(const std::string& token, remapped->clear(); for (auto it = word.begin(); it != word.end(); ++it) { if (options.remap_digits && unilib.IsDigit(*it)) { - remapped->AppendCodepoint('0'); + remapped->push_back('0'); } else if (options.lowercase_tokens) { - remapped->AppendCodepoint(unilib.ToLower(*it)); + remapped->push_back(unilib.ToLower(*it)); } else { - remapped->AppendCodepoint(*it); + remapped->push_back(*it); } } } @@ -160,7 +160,7 @@ std::vector<float> TokenFeatureExtractor::ExtractDenseFeatures( int TokenFeatureExtractor::HashToken(StringPiece token) const { if (options_.allowed_chargrams.empty()) { - return tc2farmhash::Fingerprint64(token) % options_.num_buckets; + return tc3farmhash::Fingerprint64(token) % options_.num_buckets; } else { // Padding and out-of-vocabulary tokens have extra buckets reserved because // they are special and important tokens, and we don't want them to share @@ -174,7 +174,7 @@ int TokenFeatureExtractor::HashToken(StringPiece token) const { options_.allowed_chargrams.end()) { return 0; // Out-of-vocabulary. } else { - return (tc2farmhash::Fingerprint64(token) % + return (tc3farmhash::Fingerprint64(token) % (options_.num_buckets - kNumExtraBuckets)) + kNumExtraBuckets; } @@ -308,4 +308,4 @@ std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesUnicode( return result; } -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/token-feature-extractor.h b/annotator/token-feature-extractor.h index fee1355..7dc19fe 100644 --- a/token-feature-extractor.h +++ b/annotator/token-feature-extractor.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,18 +14,18 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_TOKEN_FEATURE_EXTRACTOR_H_ -#define LIBTEXTCLASSIFIER_TOKEN_FEATURE_EXTRACTOR_H_ +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TOKEN_FEATURE_EXTRACTOR_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_TOKEN_FEATURE_EXTRACTOR_H_ #include <memory> #include <unordered_set> #include <vector> -#include "types.h" -#include "util/strings/stringpiece.h" -#include "util/utf8/unilib.h" +#include "annotator/types.h" +#include "utils/strings/stringpiece.h" +#include "utils/utf8/unilib.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { struct TokenFeatureExtractorOptions { // Number of buckets used for hashing charactergrams. @@ -110,6 +110,6 @@ class TokenFeatureExtractor { const UniLib& unilib_; }; -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_TOKEN_FEATURE_EXTRACTOR_H_ +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_TOKEN_FEATURE_EXTRACTOR_H_ diff --git a/token-feature-extractor_test.cc b/annotator/token-feature-extractor_test.cc index 4b7e011..32383a9 100644 --- a/token-feature-extractor_test.cc +++ b/annotator/token-feature-extractor_test.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,29 +14,34 @@ * limitations under the License. */ -#include "token-feature-extractor.h" +#include "annotator/token-feature-extractor.h" #include "gmock/gmock.h" #include "gtest/gtest.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace { +class TokenFeatureExtractorTest : public ::testing::Test { + protected: + TokenFeatureExtractorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {} + UniLib unilib_; +}; + class TestingTokenFeatureExtractor : public TokenFeatureExtractor { public: using TokenFeatureExtractor::HashToken; using TokenFeatureExtractor::TokenFeatureExtractor; }; -TEST(TokenFeatureExtractorTest, ExtractAscii) { +TEST_F(TokenFeatureExtractorTest, ExtractAscii) { TokenFeatureExtractorOptions options; options.num_buckets = 1000; options.chargram_orders = std::vector<int>{1, 2, 3}; options.extract_case_feature = true; options.unicode_aware_features = false; options.extract_selection_mask_feature = true; - CREATE_UNILIB_FOR_TESTING - TestingTokenFeatureExtractor extractor(options, unilib); + TestingTokenFeatureExtractor extractor(options, unilib_); std::vector<int> sparse_features; std::vector<float> dense_features; @@ -99,15 +104,14 @@ TEST(TokenFeatureExtractorTest, ExtractAscii) { EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0})); } -TEST(TokenFeatureExtractorTest, ExtractAsciiNoChargrams) { +TEST_F(TokenFeatureExtractorTest, ExtractAsciiNoChargrams) { TokenFeatureExtractorOptions options; options.num_buckets = 1000; options.chargram_orders = std::vector<int>{}; options.extract_case_feature = true; options.unicode_aware_features = false; options.extract_selection_mask_feature = true; - CREATE_UNILIB_FOR_TESTING - TestingTokenFeatureExtractor extractor(options, unilib); + TestingTokenFeatureExtractor extractor(options, unilib_); std::vector<int> sparse_features; std::vector<float> dense_features; @@ -129,15 +133,14 @@ TEST(TokenFeatureExtractorTest, ExtractAsciiNoChargrams) { EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0})); } -TEST(TokenFeatureExtractorTest, ExtractUnicode) { +TEST_F(TokenFeatureExtractorTest, ExtractUnicode) { TokenFeatureExtractorOptions options; options.num_buckets = 1000; options.chargram_orders = std::vector<int>{1, 2, 3}; options.extract_case_feature = true; options.unicode_aware_features = true; options.extract_selection_mask_feature = true; - CREATE_UNILIB_FOR_TESTING - TestingTokenFeatureExtractor extractor(options, unilib); + TestingTokenFeatureExtractor extractor(options, unilib_); std::vector<int> sparse_features; std::vector<float> dense_features; @@ -200,15 +203,14 @@ TEST(TokenFeatureExtractorTest, ExtractUnicode) { EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0})); } -TEST(TokenFeatureExtractorTest, ExtractUnicodeNoChargrams) { +TEST_F(TokenFeatureExtractorTest, ExtractUnicodeNoChargrams) { TokenFeatureExtractorOptions options; options.num_buckets = 1000; options.chargram_orders = std::vector<int>{}; options.extract_case_feature = true; options.unicode_aware_features = true; options.extract_selection_mask_feature = true; - CREATE_UNILIB_FOR_TESTING - TestingTokenFeatureExtractor extractor(options, unilib); + TestingTokenFeatureExtractor extractor(options, unilib_); std::vector<int> sparse_features; std::vector<float> dense_features; @@ -231,16 +233,15 @@ TEST(TokenFeatureExtractorTest, ExtractUnicodeNoChargrams) { EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0})); } -#ifdef LIBTEXTCLASSIFIER_TEST_ICU -TEST(TokenFeatureExtractorTest, ICUCaseFeature) { +#ifdef TC3_TEST_ICU +TEST_F(TokenFeatureExtractorTest, ICUCaseFeature) { TokenFeatureExtractorOptions options; options.num_buckets = 1000; options.chargram_orders = std::vector<int>{1, 2}; options.extract_case_feature = true; options.unicode_aware_features = true; options.extract_selection_mask_feature = false; - CREATE_UNILIB_FOR_TESTING - TestingTokenFeatureExtractor extractor(options, unilib); + TestingTokenFeatureExtractor extractor(options, unilib_); std::vector<int> sparse_features; std::vector<float> dense_features; @@ -268,14 +269,13 @@ TEST(TokenFeatureExtractorTest, ICUCaseFeature) { } #endif -TEST(TokenFeatureExtractorTest, DigitRemapping) { +TEST_F(TokenFeatureExtractorTest, DigitRemapping) { TokenFeatureExtractorOptions options; options.num_buckets = 1000; options.chargram_orders = std::vector<int>{1, 2}; options.remap_digits = true; options.unicode_aware_features = false; - CREATE_UNILIB_FOR_TESTING - TestingTokenFeatureExtractor extractor(options, unilib); + TestingTokenFeatureExtractor extractor(options, unilib_); std::vector<int> sparse_features; std::vector<float> dense_features; @@ -293,14 +293,13 @@ TEST(TokenFeatureExtractorTest, DigitRemapping) { testing::Not(testing::ElementsAreArray(sparse_features2))); } -TEST(TokenFeatureExtractorTest, DigitRemappingUnicode) { +TEST_F(TokenFeatureExtractorTest, DigitRemappingUnicode) { TokenFeatureExtractorOptions options; options.num_buckets = 1000; options.chargram_orders = std::vector<int>{1, 2}; options.remap_digits = true; options.unicode_aware_features = true; - CREATE_UNILIB_FOR_TESTING - TestingTokenFeatureExtractor extractor(options, unilib); + TestingTokenFeatureExtractor extractor(options, unilib_); std::vector<int> sparse_features; std::vector<float> dense_features; @@ -318,14 +317,13 @@ TEST(TokenFeatureExtractorTest, DigitRemappingUnicode) { testing::Not(testing::ElementsAreArray(sparse_features2))); } -TEST(TokenFeatureExtractorTest, LowercaseAscii) { +TEST_F(TokenFeatureExtractorTest, LowercaseAscii) { TokenFeatureExtractorOptions options; options.num_buckets = 1000; options.chargram_orders = std::vector<int>{1, 2}; options.lowercase_tokens = true; options.unicode_aware_features = false; - CREATE_UNILIB_FOR_TESTING - TestingTokenFeatureExtractor extractor(options, unilib); + TestingTokenFeatureExtractor extractor(options, unilib_); std::vector<int> sparse_features; std::vector<float> dense_features; @@ -342,15 +340,14 @@ TEST(TokenFeatureExtractorTest, LowercaseAscii) { EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2)); } -#ifdef LIBTEXTCLASSIFIER_TEST_ICU -TEST(TokenFeatureExtractorTest, LowercaseUnicode) { +#ifdef TC3_TEST_ICU +TEST_F(TokenFeatureExtractorTest, LowercaseUnicode) { TokenFeatureExtractorOptions options; options.num_buckets = 1000; options.chargram_orders = std::vector<int>{1, 2}; options.lowercase_tokens = true; options.unicode_aware_features = true; - CREATE_UNILIB_FOR_TESTING - TestingTokenFeatureExtractor extractor(options, unilib); + TestingTokenFeatureExtractor extractor(options, unilib_); std::vector<int> sparse_features; std::vector<float> dense_features; @@ -363,8 +360,8 @@ TEST(TokenFeatureExtractorTest, LowercaseUnicode) { } #endif -#ifdef LIBTEXTCLASSIFIER_TEST_ICU -TEST(TokenFeatureExtractorTest, RegexFeatures) { +#ifdef TC3_TEST_ICU +TEST_F(TokenFeatureExtractorTest, RegexFeatures) { TokenFeatureExtractorOptions options; options.num_buckets = 1000; options.chargram_orders = std::vector<int>{1, 2}; @@ -372,8 +369,7 @@ TEST(TokenFeatureExtractorTest, RegexFeatures) { options.unicode_aware_features = false; options.regexp_features.push_back("^[a-z]+$"); // all lower case. options.regexp_features.push_back("^[0-9]+$"); // all digits. - CREATE_UNILIB_FOR_TESTING - TestingTokenFeatureExtractor extractor(options, unilib); + TestingTokenFeatureExtractor extractor(options, unilib_); std::vector<int> sparse_features; std::vector<float> dense_features; @@ -398,15 +394,14 @@ TEST(TokenFeatureExtractorTest, RegexFeatures) { } #endif -TEST(TokenFeatureExtractorTest, ExtractTooLongWord) { +TEST_F(TokenFeatureExtractorTest, ExtractTooLongWord) { TokenFeatureExtractorOptions options; options.num_buckets = 1000; options.chargram_orders = std::vector<int>{22}; options.extract_case_feature = true; options.unicode_aware_features = true; options.extract_selection_mask_feature = true; - CREATE_UNILIB_FOR_TESTING - TestingTokenFeatureExtractor extractor(options, unilib); + TestingTokenFeatureExtractor extractor(options, unilib_); // Test that this runs. ASAN should catch problems. std::vector<int> sparse_features; @@ -423,7 +418,7 @@ TEST(TokenFeatureExtractorTest, ExtractTooLongWord) { })); } -TEST(TokenFeatureExtractorTest, ExtractAsciiUnicodeMatches) { +TEST_F(TokenFeatureExtractorTest, ExtractAsciiUnicodeMatches) { TokenFeatureExtractorOptions options; options.num_buckets = 1000; options.chargram_orders = std::vector<int>{1, 2, 3, 4, 5}; @@ -431,11 +426,10 @@ TEST(TokenFeatureExtractorTest, ExtractAsciiUnicodeMatches) { options.unicode_aware_features = true; options.extract_selection_mask_feature = true; - CREATE_UNILIB_FOR_TESTING - TestingTokenFeatureExtractor extractor_unicode(options, unilib); + TestingTokenFeatureExtractor extractor_unicode(options, unilib_); options.unicode_aware_features = false; - TestingTokenFeatureExtractor extractor_ascii(options, unilib); + TestingTokenFeatureExtractor extractor_ascii(options, unilib_); for (const std::string& input : {"https://www.abcdefgh.com/in/xxxkkkvayio", @@ -458,7 +452,7 @@ TEST(TokenFeatureExtractorTest, ExtractAsciiUnicodeMatches) { } } -TEST(TokenFeatureExtractorTest, ExtractForPadToken) { +TEST_F(TokenFeatureExtractorTest, ExtractForPadToken) { TokenFeatureExtractorOptions options; options.num_buckets = 1000; options.chargram_orders = std::vector<int>{1, 2}; @@ -466,8 +460,7 @@ TEST(TokenFeatureExtractorTest, ExtractForPadToken) { options.unicode_aware_features = false; options.extract_selection_mask_feature = true; - CREATE_UNILIB_FOR_TESTING - TestingTokenFeatureExtractor extractor(options, unilib); + TestingTokenFeatureExtractor extractor(options, unilib_); std::vector<int> sparse_features; std::vector<float> dense_features; @@ -479,7 +472,7 @@ TEST(TokenFeatureExtractorTest, ExtractForPadToken) { EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0})); } -TEST(TokenFeatureExtractorTest, ExtractFiltered) { +TEST_F(TokenFeatureExtractorTest, ExtractFiltered) { TokenFeatureExtractorOptions options; options.num_buckets = 1000; options.chargram_orders = std::vector<int>{1, 2, 3}; @@ -493,8 +486,7 @@ TEST(TokenFeatureExtractorTest, ExtractFiltered) { options.allowed_chargrams.insert("!"); options.allowed_chargrams.insert("\xc4"); // UTF8 control character. - CREATE_UNILIB_FOR_TESTING - TestingTokenFeatureExtractor extractor(options, unilib); + TestingTokenFeatureExtractor extractor(options, unilib_); std::vector<int> sparse_features; std::vector<float> dense_features; @@ -561,4 +553,4 @@ TEST(TokenFeatureExtractorTest, ExtractFiltered) { } } // namespace -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/tokenizer.cc b/annotator/tokenizer.cc index 722a67b..099dccc 100644 --- a/tokenizer.cc +++ b/annotator/tokenizer.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,14 +14,14 @@ * limitations under the License. */ -#include "tokenizer.h" +#include "annotator/tokenizer.h" #include <algorithm> -#include "util/base/logging.h" -#include "util/strings/utf8.h" +#include "utils/base/logging.h" +#include "utils/strings/utf8.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { Tokenizer::Tokenizer( const std::vector<const TokenizationCodepointRange*>& codepoint_ranges, @@ -123,4 +123,4 @@ std::vector<Token> Tokenizer::Tokenize(const UnicodeText& text_unicode) const { return result; } -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/tokenizer.h b/annotator/tokenizer.h index 2524e12..ec33f2d 100644 --- a/tokenizer.h +++ b/annotator/tokenizer.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,18 +14,18 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_TOKENIZER_H_ -#define LIBTEXTCLASSIFIER_TOKENIZER_H_ +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TOKENIZER_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_TOKENIZER_H_ #include <string> #include <vector> -#include "model_generated.h" -#include "types.h" -#include "util/base/integral_types.h" -#include "util/utf8/unicodetext.h" +#include "annotator/model_generated.h" +#include "annotator/types.h" +#include "utils/base/integral_types.h" +#include "utils/utf8/unicodetext.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { const int kInvalidScript = -1; const int kUnknownScript = -2; @@ -66,6 +66,6 @@ class Tokenizer { bool split_on_script_change_; }; -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_TOKENIZER_H_ +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_TOKENIZER_H_ diff --git a/tokenizer_test.cc b/annotator/tokenizer_test.cc index 65072f3..a3ab9da 100644 --- a/tokenizer_test.cc +++ b/annotator/tokenizer_test.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,14 +14,14 @@ * limitations under the License. */ -#include "tokenizer.h" +#include "annotator/tokenizer.h" #include <vector> #include "gmock/gmock.h" #include "gtest/gtest.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace { using testing::ElementsAreArray; @@ -331,4 +331,4 @@ TEST(TokenizerTest, TokenizeComplex) { } } // namespace -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/types-test-util.h b/annotator/types-test-util.h index 1679e7c..fbbdd63 100644 --- a/types-test-util.h +++ b/annotator/types-test-util.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,15 +14,15 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_TYPES_TEST_UTIL_H_ -#define LIBTEXTCLASSIFIER_TYPES_TEST_UTIL_H_ +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_TEST_UTIL_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_TEST_UTIL_H_ #include <ostream> -#include "types.h" -#include "util/base/logging.h" +#include "annotator/types.h" +#include "utils/base/logging.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { inline std::ostream& operator<<(std::ostream& stream, const Token& value) { logging::LoggingStringStream tmp_stream; @@ -44,6 +44,6 @@ inline std::ostream& operator<<(std::ostream& stream, return stream << tmp_stream.message; } -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_TYPES_TEST_UTIL_H_ +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_TEST_UTIL_H_ diff --git a/types.h b/annotator/types.h index b2f624d..38bce41 100644 --- a/types.h +++ b/annotator/types.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,21 +14,23 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_TYPES_H_ -#define LIBTEXTCLASSIFIER_TYPES_H_ +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_ #include <algorithm> #include <cmath> #include <functional> +#include <map> #include <set> #include <string> #include <utility> #include <vector> -#include "util/base/integral_types.h" -#include "util/base/logging.h" +#include "utils/base/integral_types.h" +#include "utils/base/logging.h" +#include "utils/variant.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { constexpr int kInvalidIndex = -1; @@ -221,10 +223,14 @@ struct ClassificationResult { std::string collection; float score; DatetimeParseResult datetime_parse_result; + std::string serialized_knowledge_result; // Internal score used for conflict resolution. float priority_score; + // Extra information. + std::map<std::string, Variant> extra; + explicit ClassificationResult() : score(-1.0f), priority_score(-1.0) {} ClassificationResult(const std::string& arg_collection, float arg_score) @@ -318,13 +324,13 @@ struct DateParseData { }; enum RelationType { - MONDAY = 1, - TUESDAY = 2, - WEDNESDAY = 3, - THURSDAY = 4, - FRIDAY = 5, - SATURDAY = 6, - SUNDAY = 7, + SUNDAY = 1, + MONDAY = 2, + TUESDAY = 3, + WEDNESDAY = 4, + THURSDAY = 5, + FRIDAY = 6, + SATURDAY = 7, DAY = 8, WEEK = 9, MONTH = 10, @@ -391,6 +397,6 @@ struct DateParseData { int relation_distance; }; -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_TYPES_H_ +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_ diff --git a/annotator/zlib-utils.cc b/annotator/zlib-utils.cc new file mode 100644 index 0000000..6efe025 --- /dev/null +++ b/annotator/zlib-utils.cc @@ -0,0 +1,128 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "annotator/zlib-utils.h" + +#include <memory> + +#include "utils/base/logging.h" +#include "utils/zlib/zlib.h" + +namespace libtextclassifier3 { + +// Compress rule fields in the model. +bool CompressModel(ModelT* model) { + std::unique_ptr<ZlibCompressor> zlib_compressor = ZlibCompressor::Instance(); + if (!zlib_compressor) { + TC3_LOG(ERROR) << "Cannot compress model."; + return false; + } + + // Compress regex rules. + if (model->regex_model != nullptr) { + for (int i = 0; i < model->regex_model->patterns.size(); i++) { + RegexModel_::PatternT* pattern = model->regex_model->patterns[i].get(); + pattern->compressed_pattern.reset(new CompressedBufferT); + zlib_compressor->Compress(pattern->pattern, + pattern->compressed_pattern.get()); + pattern->pattern.clear(); + } + } + + // Compress date-time rules. + if (model->datetime_model != nullptr) { + for (int i = 0; i < model->datetime_model->patterns.size(); i++) { + DatetimeModelPatternT* pattern = model->datetime_model->patterns[i].get(); + for (int j = 0; j < pattern->regexes.size(); j++) { + DatetimeModelPattern_::RegexT* regex = pattern->regexes[j].get(); + regex->compressed_pattern.reset(new CompressedBufferT); + zlib_compressor->Compress(regex->pattern, + regex->compressed_pattern.get()); + regex->pattern.clear(); + } + } + for (int i = 0; i < model->datetime_model->extractors.size(); i++) { + DatetimeModelExtractorT* extractor = + model->datetime_model->extractors[i].get(); + extractor->compressed_pattern.reset(new CompressedBufferT); + zlib_compressor->Compress(extractor->pattern, + extractor->compressed_pattern.get()); + extractor->pattern.clear(); + } + } + return true; +} + +bool DecompressModel(ModelT* model) { + std::unique_ptr<ZlibDecompressor> zlib_decompressor = + ZlibDecompressor::Instance(); + if (!zlib_decompressor) { + TC3_LOG(ERROR) << "Cannot initialize decompressor."; + return false; + } + + // Decompress regex rules. + if (model->regex_model != nullptr) { + for (int i = 0; i < model->regex_model->patterns.size(); i++) { + RegexModel_::PatternT* pattern = model->regex_model->patterns[i].get(); + if (!zlib_decompressor->MaybeDecompress(pattern->compressed_pattern.get(), + &pattern->pattern)) { + TC3_LOG(ERROR) << "Cannot decompress pattern: " << i; + return false; + } + pattern->compressed_pattern.reset(nullptr); + } + } + + // Decompress date-time rules. + if (model->datetime_model != nullptr) { + for (int i = 0; i < model->datetime_model->patterns.size(); i++) { + DatetimeModelPatternT* pattern = model->datetime_model->patterns[i].get(); + for (int j = 0; j < pattern->regexes.size(); j++) { + DatetimeModelPattern_::RegexT* regex = pattern->regexes[j].get(); + if (!zlib_decompressor->MaybeDecompress(regex->compressed_pattern.get(), + ®ex->pattern)) { + TC3_LOG(ERROR) << "Cannot decompress pattern: " << i << " " << j; + return false; + } + regex->compressed_pattern.reset(nullptr); + } + } + for (int i = 0; i < model->datetime_model->extractors.size(); i++) { + DatetimeModelExtractorT* extractor = + model->datetime_model->extractors[i].get(); + if (!zlib_decompressor->MaybeDecompress( + extractor->compressed_pattern.get(), &extractor->pattern)) { + TC3_LOG(ERROR) << "Cannot decompress pattern: " << i; + return false; + } + extractor->compressed_pattern.reset(nullptr); + } + } + return true; +} + +std::string CompressSerializedModel(const std::string& model) { + std::unique_ptr<ModelT> unpacked_model = UnPackModel(model.c_str()); + TC3_CHECK(unpacked_model != nullptr); + TC3_CHECK(CompressModel(unpacked_model.get())); + flatbuffers::FlatBufferBuilder builder; + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); + return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize()); +} + +} // namespace libtextclassifier3 diff --git a/annotator/zlib-utils.h b/annotator/zlib-utils.h new file mode 100644 index 0000000..462a02b --- /dev/null +++ b/annotator/zlib-utils.h @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Functions to compress and decompress low entropy entries in the model. + +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_ZLIB_UTILS_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_ZLIB_UTILS_H_ + +#include "annotator/model_generated.h" + +namespace libtextclassifier3 { + +// Compresses regex and datetime rules in the model in place. +bool CompressModel(ModelT* model); + +// Decompresses regex and datetime rules in the model in place. +bool DecompressModel(ModelT* model); + +// Compresses regex and datetime rules in the model. +std::string CompressSerializedModel(const std::string& model); + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_ZLIB_UTILS_H_ diff --git a/zlib-utils_test.cc b/annotator/zlib-utils_test.cc index 155f14f..7a8d775 100644 --- a/zlib-utils_test.cc +++ b/annotator/zlib-utils_test.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,15 +14,16 @@ * limitations under the License. */ -#include "zlib-utils.h" +#include "annotator/zlib-utils.h" #include <memory> -#include "model_generated.h" +#include "annotator/model_generated.h" +#include "utils/zlib/zlib.h" #include "gmock/gmock.h" #include "gtest/gtest.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { TEST(ZlibUtilsTest, CompressModel) { ModelT model; @@ -62,27 +63,27 @@ TEST(ZlibUtilsTest, CompressModel) { std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance(); ASSERT_TRUE(decompressor != nullptr); std::string uncompressed_pattern; - EXPECT_TRUE(decompressor->Decompress( + EXPECT_TRUE(decompressor->MaybeDecompress( compressed_model->regex_model()->patterns()->Get(0)->compressed_pattern(), &uncompressed_pattern)); EXPECT_EQ(uncompressed_pattern, "this is a test pattern"); - EXPECT_TRUE(decompressor->Decompress( + EXPECT_TRUE(decompressor->MaybeDecompress( compressed_model->regex_model()->patterns()->Get(1)->compressed_pattern(), &uncompressed_pattern)); EXPECT_EQ(uncompressed_pattern, "this is a second test pattern"); - EXPECT_TRUE(decompressor->Decompress(compressed_model->datetime_model() - ->patterns() - ->Get(0) - ->regexes() - ->Get(0) - ->compressed_pattern(), - &uncompressed_pattern)); + EXPECT_TRUE(decompressor->MaybeDecompress(compressed_model->datetime_model() + ->patterns() + ->Get(0) + ->regexes() + ->Get(0) + ->compressed_pattern(), + &uncompressed_pattern)); EXPECT_EQ(uncompressed_pattern, "an example datetime pattern"); - EXPECT_TRUE(decompressor->Decompress(compressed_model->datetime_model() - ->extractors() - ->Get(0) - ->compressed_pattern(), - &uncompressed_pattern)); + EXPECT_TRUE(decompressor->MaybeDecompress(compressed_model->datetime_model() + ->extractors() + ->Get(0) + ->compressed_pattern(), + &uncompressed_pattern)); EXPECT_EQ(uncompressed_pattern, "an example datetime extractor"); EXPECT_TRUE(DecompressModel(&model)); @@ -95,4 +96,4 @@ TEST(ZlibUtilsTest, CompressModel) { "an example datetime extractor"); } -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/generate_flatbuffers.mk b/generate_flatbuffers.mk new file mode 100644 index 0000000..1522463 --- /dev/null +++ b/generate_flatbuffers.mk @@ -0,0 +1,46 @@ +FLATC := $(HOST_OUT_EXECUTABLES)/flatc$(HOST_EXECUTABLE_SUFFIX) + +define transform-fbs-to-cpp +@echo "Flatc: $@ <= $(PRIVATE_INPUT_FBS)" +@rm -f $@ +@mkdir -p $(dir $@) +$(FLATC) \ + --cpp \ + --no-union-value-namespacing \ + --gen-object-api \ + --keep-prefix \ + -I $(INPUT_DIR) \ + -o $(dir $@) \ + $(PRIVATE_INPUT_FBS) \ + || exit 33 +[ -f $@ ] || exit 33 +endef + +intermediates := $(call local-generated-sources-dir) + +# Generate utils/zlib/buffer_generated.h using FlatBuffer schema compiler. +UTILS_ZLIB_BUFFER_FBS := $(LOCAL_PATH)/utils/zlib/buffer.fbs +UTILS_ZLIB_BUFFER_H := $(intermediates)/utils/zlib/buffer_generated.h +$(UTILS_ZLIB_BUFFER_H): PRIVATE_INPUT_FBS := $(UTILS_ZLIB_BUFFER_FBS) +$(UTILS_ZLIB_BUFFER_H): INPUT_DIR := $(LOCAL_PATH) +$(UTILS_ZLIB_BUFFER_H): $(FLATC) $(UTILS_ZLIB_BUFFER_FBS) + $(transform-fbs-to-cpp) +LOCAL_GENERATED_SOURCES += $(UTILS_ZLIB_BUFFER_H) + +# Generate utils/intent/intent-config_generated.h using FlatBuffer schema compiler. +INTENT_CONFIG_FBS := $(LOCAL_PATH)/utils/intents/intent-config.fbs +INTENT_CONFIG_H := $(intermediates)/utils/intents/intent-config_generated.h +$(INTENT_CONFIG_H): PRIVATE_INPUT_FBS := $(INTENT_CONFIG_FBS) +$(INTENT_CONFIG_H): INPUT_DIR := $(LOCAL_PATH) +$(INTENT_CONFIG_H): $(FLATC) $(INTENT_CONFIG_FBS) + $(transform-fbs-to-cpp) +LOCAL_GENERATED_SOURCES += $(INTENT_CONFIG_H) + +# Generate annotator/model_generated.h using FlatBuffer schema compiler. +ANNOTATOR_MODEL_FBS := $(LOCAL_PATH)/annotator/model.fbs +ANNOTATOR_MODEL_H := $(intermediates)/annotator/model_generated.h +$(ANNOTATOR_MODEL_H): PRIVATE_INPUT_FBS := $(ANNOTATOR_MODEL_FBS) +$(ANNOTATOR_MODEL_H): INPUT_DIR := $(LOCAL_PATH) +$(ANNOTATOR_MODEL_H): $(FLATC) $(ANNOTATOR_MODEL_FBS) $(INTENT_CONFIG_H) + $(transform-fbs-to-cpp) +LOCAL_GENERATED_SOURCES += $(ANNOTATOR_MODEL_H) diff --git a/java/com/google/android/textclassifier/AnnotatorModel.java b/java/com/google/android/textclassifier/AnnotatorModel.java new file mode 100644 index 0000000..08a4455 --- /dev/null +++ b/java/com/google/android/textclassifier/AnnotatorModel.java @@ -0,0 +1,342 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.android.textclassifier; + +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Java wrapper for Annotator native library interface. This library is used for detecting entities + * in text. + * + * @hide + */ +public final class AnnotatorModel implements AutoCloseable { + private final AtomicBoolean isClosed = new AtomicBoolean(false); + + static { + System.loadLibrary("textclassifier"); + } + + // Keep these in sync with the constants defined in AOSP. + static final String TYPE_UNKNOWN = ""; + static final String TYPE_OTHER = "other"; + static final String TYPE_EMAIL = "email"; + static final String TYPE_PHONE = "phone"; + static final String TYPE_ADDRESS = "address"; + static final String TYPE_URL = "url"; + static final String TYPE_DATE = "date"; + static final String TYPE_DATE_TIME = "datetime"; + static final String TYPE_FLIGHT_NUMBER = "flight"; + + private long annotatorPtr; + + /** + * Creates a new instance of SmartSelect predictor, using the provided model image, given as a + * file descriptor. + */ + public AnnotatorModel(int fileDescriptor) { + annotatorPtr = nativeNewAnnotator(fileDescriptor); + if (annotatorPtr == 0L) { + throw new IllegalArgumentException("Couldn't initialize TC from file descriptor."); + } + } + + /** + * Creates a new instance of SmartSelect predictor, using the provided model image, given as a + * file path. + */ + public AnnotatorModel(String path) { + annotatorPtr = nativeNewAnnotatorFromPath(path); + if (annotatorPtr == 0L) { + throw new IllegalArgumentException("Couldn't initialize TC from given file."); + } + } + + /** Initializes the knowledge engine, passing the given serialized config to it. */ + public void initializeKnowledgeEngine(byte[] serializedConfig) { + if (!nativeInitializeKnowledgeEngine(annotatorPtr, serializedConfig)) { + throw new IllegalArgumentException("Couldn't initialize the KG engine"); + } + } + + /** + * Given a string context and current selection, computes the selection suggestion. + * + * <p>The begin and end are character indices into the context UTF8 string. selectionBegin is the + * character index where the selection begins, and selectionEnd is the index of one character past + * the selection span. + * + * <p>The return value is an array of two ints: suggested selection beginning and end, with the + * same semantics as the input selectionBeginning and selectionEnd. + */ + public int[] suggestSelection( + String context, int selectionBegin, int selectionEnd, SelectionOptions options) { + return nativeSuggestSelection(annotatorPtr, context, selectionBegin, selectionEnd, options); + } + + /** + * Given a string context and current selection, classifies the type of the selected text. + * + * <p>The begin and end params are character indices in the context string. + * + * <p>Returns an array of ClassificationResult objects with the probability scores for different + * collections. + */ + public ClassificationResult[] classifyText( + String context, int selectionBegin, int selectionEnd, ClassificationOptions options) { + return nativeClassifyText(annotatorPtr, context, selectionBegin, selectionEnd, options); + } + + /** + * Annotates given input text. The annotations should cover the whole input context except for + * whitespaces, and are sorted by their position in the context string. + */ + public AnnotatedSpan[] annotate(String text, AnnotationOptions options) { + return nativeAnnotate(annotatorPtr, text, options); + } + + /** Frees up the allocated memory. */ + @Override + public void close() { + if (isClosed.compareAndSet(false, true)) { + nativeCloseAnnotator(annotatorPtr); + annotatorPtr = 0L; + } + } + + @Override + protected void finalize() throws Throwable { + try { + close(); + } finally { + super.finalize(); + } + } + + /** Returns a comma separated list of locales supported by the model as BCP 47 tags. */ + public static String getLocales(int fd) { + return nativeGetLocales(fd); + } + + /** Returns the version of the model. */ + public static int getVersion(int fd) { + return nativeGetVersion(fd); + } + + /** Returns the name of the model. */ + public static String getName(int fd) { + return nativeGetName(fd); + } + + /** Information about a parsed time/date. */ + public static final class DatetimeResult { + + static final int GRANULARITY_YEAR = 0; + static final int GRANULARITY_MONTH = 1; + static final int GRANULARITY_WEEK = 2; + static final int GRANULARITY_DAY = 3; + static final int GRANULARITY_HOUR = 4; + static final int GRANULARITY_MINUTE = 5; + static final int GRANULARITY_SECOND = 6; + + private final long timeMsUtc; + private final int granularity; + + DatetimeResult(long timeMsUtc, int granularity) { + this.timeMsUtc = timeMsUtc; + this.granularity = granularity; + } + + public long getTimeMsUtc() { + return timeMsUtc; + } + + public int getGranularity() { + return granularity; + } + } + + /** Classification result for classifyText method. */ + public static final class ClassificationResult { + private final String collection; + private final float score; + private final DatetimeResult datetimeResult; + private final byte[] serializedKnowledgeResult; + + public ClassificationResult( + String collection, + float score, + DatetimeResult datetimeResult, + byte[] serializedKnowledgeResult) { + this.collection = collection; + this.score = score; + this.datetimeResult = datetimeResult; + this.serializedKnowledgeResult = serializedKnowledgeResult; + } + + /** Returns the classified entity type. */ + public String getCollection() { + if (TYPE_DATE.equals(collection) && datetimeResult != null) { + switch (datetimeResult.getGranularity()) { + case DatetimeResult.GRANULARITY_HOUR: + case DatetimeResult.GRANULARITY_MINUTE: + case DatetimeResult.GRANULARITY_SECOND: + return TYPE_DATE_TIME; + default: + return TYPE_DATE; + } + } + return collection; + } + + /** Confidence score between 0 and 1. */ + public float getScore() { + return score; + } + + public DatetimeResult getDatetimeResult() { + return datetimeResult; + } + + byte[] getSerializedKnowledgeResult() { + return serializedKnowledgeResult; + } + } + + /** Represents a result of Annotate call. */ + public static final class AnnotatedSpan { + private final int startIndex; + private final int endIndex; + private final ClassificationResult[] classification; + + AnnotatedSpan(int startIndex, int endIndex, ClassificationResult[] classification) { + this.startIndex = startIndex; + this.endIndex = endIndex; + this.classification = classification; + } + + public int getStartIndex() { + return startIndex; + } + + public int getEndIndex() { + return endIndex; + } + + public ClassificationResult[] getClassification() { + return classification; + } + } + + /** Represents options for the suggestSelection call. */ + public static final class SelectionOptions { + private final String locales; + + public SelectionOptions(String locales) { + this.locales = locales; + } + + public String getLocales() { + return locales; + } + } + + /** Represents options for the classifyText call. */ + public static final class ClassificationOptions { + private final long referenceTimeMsUtc; + private final String referenceTimezone; + private final String locales; + + public ClassificationOptions(long referenceTimeMsUtc, String referenceTimezone, String locale) { + this.referenceTimeMsUtc = referenceTimeMsUtc; + this.referenceTimezone = referenceTimezone; + this.locales = locale; + } + + public long getReferenceTimeMsUtc() { + return referenceTimeMsUtc; + } + + public String getReferenceTimezone() { + return referenceTimezone; + } + + public String getLocale() { + return locales; + } + } + + /** Represents options for the annotate call. */ + public static final class AnnotationOptions { + private final long referenceTimeMsUtc; + private final String referenceTimezone; + private final String locales; + + public AnnotationOptions(long referenceTimeMsUtc, String referenceTimezone, String locale) { + this.referenceTimeMsUtc = referenceTimeMsUtc; + this.referenceTimezone = referenceTimezone; + this.locales = locale; + } + + public long getReferenceTimeMsUtc() { + return referenceTimeMsUtc; + } + + public String getReferenceTimezone() { + return referenceTimezone; + } + + public String getLocale() { + return locales; + } + } + + /** + * Retrieves the pointer to the native object. Note: Need to keep the AnnotatorModel alive as long + * as the pointer is used. + */ + long getNativeAnnotator() { + return annotatorPtr; + } + + private static native long nativeNewAnnotator(int fd); + + private static native long nativeNewAnnotatorFromPath(String path); + + private static native String nativeGetLocales(int fd); + + private static native int nativeGetVersion(int fd); + + private static native String nativeGetName(int fd); + + private native boolean nativeInitializeKnowledgeEngine(long context, byte[] serializedConfig); + + private native int[] nativeSuggestSelection( + long context, String text, int selectionBegin, int selectionEnd, SelectionOptions options); + + private native ClassificationResult[] nativeClassifyText( + long context, + String text, + int selectionBegin, + int selectionEnd, + ClassificationOptions options); + + private native AnnotatedSpan[] nativeAnnotate( + long context, String text, AnnotationOptions options); + + private native void nativeCloseAnnotator(long context); +} diff --git a/model_generated.h b/model_generated.h deleted file mode 100755 index 6ef75f6..0000000 --- a/model_generated.h +++ /dev/null @@ -1,3718 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// automatically generated by the FlatBuffers compiler, do not modify - - -#ifndef FLATBUFFERS_GENERATED_MODEL_LIBTEXTCLASSIFIER2_H_ -#define FLATBUFFERS_GENERATED_MODEL_LIBTEXTCLASSIFIER2_H_ - -#include "flatbuffers/flatbuffers.h" - -namespace libtextclassifier2 { - -struct CompressedBuffer; -struct CompressedBufferT; - -struct SelectionModelOptions; -struct SelectionModelOptionsT; - -struct ClassificationModelOptions; -struct ClassificationModelOptionsT; - -namespace RegexModel_ { - -struct Pattern; -struct PatternT; - -} // namespace RegexModel_ - -struct RegexModel; -struct RegexModelT; - -namespace DatetimeModelPattern_ { - -struct Regex; -struct RegexT; - -} // namespace DatetimeModelPattern_ - -struct DatetimeModelPattern; -struct DatetimeModelPatternT; - -struct DatetimeModelExtractor; -struct DatetimeModelExtractorT; - -struct DatetimeModel; -struct DatetimeModelT; - -namespace DatetimeModelLibrary_ { - -struct Item; -struct ItemT; - -} // namespace DatetimeModelLibrary_ - -struct DatetimeModelLibrary; -struct DatetimeModelLibraryT; - -struct ModelTriggeringOptions; -struct ModelTriggeringOptionsT; - -struct OutputOptions; -struct OutputOptionsT; - -struct Model; -struct ModelT; - -struct TokenizationCodepointRange; -struct TokenizationCodepointRangeT; - -namespace FeatureProcessorOptions_ { - -struct CodepointRange; -struct CodepointRangeT; - -struct BoundsSensitiveFeatures; -struct BoundsSensitiveFeaturesT; - -struct AlternativeCollectionMapEntry; -struct AlternativeCollectionMapEntryT; - -} // namespace FeatureProcessorOptions_ - -struct FeatureProcessorOptions; -struct FeatureProcessorOptionsT; - -enum ModeFlag { - ModeFlag_NONE = 0, - ModeFlag_ANNOTATION = 1, - ModeFlag_CLASSIFICATION = 2, - ModeFlag_ANNOTATION_AND_CLASSIFICATION = 3, - ModeFlag_SELECTION = 4, - ModeFlag_ANNOTATION_AND_SELECTION = 5, - ModeFlag_CLASSIFICATION_AND_SELECTION = 6, - ModeFlag_ALL = 7, - ModeFlag_MIN = ModeFlag_NONE, - ModeFlag_MAX = ModeFlag_ALL -}; - -inline ModeFlag (&EnumValuesModeFlag())[8] { - static ModeFlag values[] = { - ModeFlag_NONE, - ModeFlag_ANNOTATION, - ModeFlag_CLASSIFICATION, - ModeFlag_ANNOTATION_AND_CLASSIFICATION, - ModeFlag_SELECTION, - ModeFlag_ANNOTATION_AND_SELECTION, - ModeFlag_CLASSIFICATION_AND_SELECTION, - ModeFlag_ALL - }; - return values; -} - -inline const char **EnumNamesModeFlag() { - static const char *names[] = { - "NONE", - "ANNOTATION", - "CLASSIFICATION", - "ANNOTATION_AND_CLASSIFICATION", - "SELECTION", - "ANNOTATION_AND_SELECTION", - "CLASSIFICATION_AND_SELECTION", - "ALL", - nullptr - }; - return names; -} - -inline const char *EnumNameModeFlag(ModeFlag e) { - const size_t index = static_cast<int>(e); - return EnumNamesModeFlag()[index]; -} - -enum DatetimeExtractorType { - DatetimeExtractorType_UNKNOWN_DATETIME_EXTRACTOR_TYPE = 0, - DatetimeExtractorType_AM = 1, - DatetimeExtractorType_PM = 2, - DatetimeExtractorType_JANUARY = 3, - DatetimeExtractorType_FEBRUARY = 4, - DatetimeExtractorType_MARCH = 5, - DatetimeExtractorType_APRIL = 6, - DatetimeExtractorType_MAY = 7, - DatetimeExtractorType_JUNE = 8, - DatetimeExtractorType_JULY = 9, - DatetimeExtractorType_AUGUST = 10, - DatetimeExtractorType_SEPTEMBER = 11, - DatetimeExtractorType_OCTOBER = 12, - DatetimeExtractorType_NOVEMBER = 13, - DatetimeExtractorType_DECEMBER = 14, - DatetimeExtractorType_NEXT = 15, - DatetimeExtractorType_NEXT_OR_SAME = 16, - DatetimeExtractorType_LAST = 17, - DatetimeExtractorType_NOW = 18, - DatetimeExtractorType_TOMORROW = 19, - DatetimeExtractorType_YESTERDAY = 20, - DatetimeExtractorType_PAST = 21, - DatetimeExtractorType_FUTURE = 22, - DatetimeExtractorType_DAY = 23, - DatetimeExtractorType_WEEK = 24, - DatetimeExtractorType_MONTH = 25, - DatetimeExtractorType_YEAR = 26, - DatetimeExtractorType_MONDAY = 27, - DatetimeExtractorType_TUESDAY = 28, - DatetimeExtractorType_WEDNESDAY = 29, - DatetimeExtractorType_THURSDAY = 30, - DatetimeExtractorType_FRIDAY = 31, - DatetimeExtractorType_SATURDAY = 32, - DatetimeExtractorType_SUNDAY = 33, - DatetimeExtractorType_DAYS = 34, - DatetimeExtractorType_WEEKS = 35, - DatetimeExtractorType_MONTHS = 36, - DatetimeExtractorType_HOURS = 37, - DatetimeExtractorType_MINUTES = 38, - DatetimeExtractorType_SECONDS = 39, - DatetimeExtractorType_YEARS = 40, - DatetimeExtractorType_DIGITS = 41, - DatetimeExtractorType_SIGNEDDIGITS = 42, - DatetimeExtractorType_ZERO = 43, - DatetimeExtractorType_ONE = 44, - DatetimeExtractorType_TWO = 45, - DatetimeExtractorType_THREE = 46, - DatetimeExtractorType_FOUR = 47, - DatetimeExtractorType_FIVE = 48, - DatetimeExtractorType_SIX = 49, - DatetimeExtractorType_SEVEN = 50, - DatetimeExtractorType_EIGHT = 51, - DatetimeExtractorType_NINE = 52, - DatetimeExtractorType_TEN = 53, - DatetimeExtractorType_ELEVEN = 54, - DatetimeExtractorType_TWELVE = 55, - DatetimeExtractorType_THIRTEEN = 56, - DatetimeExtractorType_FOURTEEN = 57, - DatetimeExtractorType_FIFTEEN = 58, - DatetimeExtractorType_SIXTEEN = 59, - DatetimeExtractorType_SEVENTEEN = 60, - DatetimeExtractorType_EIGHTEEN = 61, - DatetimeExtractorType_NINETEEN = 62, - DatetimeExtractorType_TWENTY = 63, - DatetimeExtractorType_THIRTY = 64, - DatetimeExtractorType_FORTY = 65, - DatetimeExtractorType_FIFTY = 66, - DatetimeExtractorType_SIXTY = 67, - DatetimeExtractorType_SEVENTY = 68, - DatetimeExtractorType_EIGHTY = 69, - DatetimeExtractorType_NINETY = 70, - DatetimeExtractorType_HUNDRED = 71, - DatetimeExtractorType_THOUSAND = 72, - DatetimeExtractorType_MIN = DatetimeExtractorType_UNKNOWN_DATETIME_EXTRACTOR_TYPE, - DatetimeExtractorType_MAX = DatetimeExtractorType_THOUSAND -}; - -inline DatetimeExtractorType (&EnumValuesDatetimeExtractorType())[73] { - static DatetimeExtractorType values[] = { - DatetimeExtractorType_UNKNOWN_DATETIME_EXTRACTOR_TYPE, - DatetimeExtractorType_AM, - DatetimeExtractorType_PM, - DatetimeExtractorType_JANUARY, - DatetimeExtractorType_FEBRUARY, - DatetimeExtractorType_MARCH, - DatetimeExtractorType_APRIL, - DatetimeExtractorType_MAY, - DatetimeExtractorType_JUNE, - DatetimeExtractorType_JULY, - DatetimeExtractorType_AUGUST, - DatetimeExtractorType_SEPTEMBER, - DatetimeExtractorType_OCTOBER, - DatetimeExtractorType_NOVEMBER, - DatetimeExtractorType_DECEMBER, - DatetimeExtractorType_NEXT, - DatetimeExtractorType_NEXT_OR_SAME, - DatetimeExtractorType_LAST, - DatetimeExtractorType_NOW, - DatetimeExtractorType_TOMORROW, - DatetimeExtractorType_YESTERDAY, - DatetimeExtractorType_PAST, - DatetimeExtractorType_FUTURE, - DatetimeExtractorType_DAY, - DatetimeExtractorType_WEEK, - DatetimeExtractorType_MONTH, - DatetimeExtractorType_YEAR, - DatetimeExtractorType_MONDAY, - DatetimeExtractorType_TUESDAY, - DatetimeExtractorType_WEDNESDAY, - DatetimeExtractorType_THURSDAY, - DatetimeExtractorType_FRIDAY, - DatetimeExtractorType_SATURDAY, - DatetimeExtractorType_SUNDAY, - DatetimeExtractorType_DAYS, - DatetimeExtractorType_WEEKS, - DatetimeExtractorType_MONTHS, - DatetimeExtractorType_HOURS, - DatetimeExtractorType_MINUTES, - DatetimeExtractorType_SECONDS, - DatetimeExtractorType_YEARS, - DatetimeExtractorType_DIGITS, - DatetimeExtractorType_SIGNEDDIGITS, - DatetimeExtractorType_ZERO, - DatetimeExtractorType_ONE, - DatetimeExtractorType_TWO, - DatetimeExtractorType_THREE, - DatetimeExtractorType_FOUR, - DatetimeExtractorType_FIVE, - DatetimeExtractorType_SIX, - DatetimeExtractorType_SEVEN, - DatetimeExtractorType_EIGHT, - DatetimeExtractorType_NINE, - DatetimeExtractorType_TEN, - DatetimeExtractorType_ELEVEN, - DatetimeExtractorType_TWELVE, - DatetimeExtractorType_THIRTEEN, - DatetimeExtractorType_FOURTEEN, - DatetimeExtractorType_FIFTEEN, - DatetimeExtractorType_SIXTEEN, - DatetimeExtractorType_SEVENTEEN, - DatetimeExtractorType_EIGHTEEN, - DatetimeExtractorType_NINETEEN, - DatetimeExtractorType_TWENTY, - DatetimeExtractorType_THIRTY, - DatetimeExtractorType_FORTY, - DatetimeExtractorType_FIFTY, - DatetimeExtractorType_SIXTY, - DatetimeExtractorType_SEVENTY, - DatetimeExtractorType_EIGHTY, - DatetimeExtractorType_NINETY, - DatetimeExtractorType_HUNDRED, - DatetimeExtractorType_THOUSAND - }; - return values; -} - -inline const char **EnumNamesDatetimeExtractorType() { - static const char *names[] = { - "UNKNOWN_DATETIME_EXTRACTOR_TYPE", - "AM", - "PM", - "JANUARY", - "FEBRUARY", - "MARCH", - "APRIL", - "MAY", - "JUNE", - "JULY", - "AUGUST", - "SEPTEMBER", - "OCTOBER", - "NOVEMBER", - "DECEMBER", - "NEXT", - "NEXT_OR_SAME", - "LAST", - "NOW", - "TOMORROW", - "YESTERDAY", - "PAST", - "FUTURE", - "DAY", - "WEEK", - "MONTH", - "YEAR", - "MONDAY", - "TUESDAY", - "WEDNESDAY", - "THURSDAY", - "FRIDAY", - "SATURDAY", - "SUNDAY", - "DAYS", - "WEEKS", - "MONTHS", - "HOURS", - "MINUTES", - "SECONDS", - "YEARS", - "DIGITS", - "SIGNEDDIGITS", - "ZERO", - "ONE", - "TWO", - "THREE", - "FOUR", - "FIVE", - "SIX", - "SEVEN", - "EIGHT", - "NINE", - "TEN", - "ELEVEN", - "TWELVE", - "THIRTEEN", - "FOURTEEN", - "FIFTEEN", - "SIXTEEN", - "SEVENTEEN", - "EIGHTEEN", - "NINETEEN", - "TWENTY", - "THIRTY", - "FORTY", - "FIFTY", - "SIXTY", - "SEVENTY", - "EIGHTY", - "NINETY", - "HUNDRED", - "THOUSAND", - nullptr - }; - return names; -} - -inline const char *EnumNameDatetimeExtractorType(DatetimeExtractorType e) { - const size_t index = static_cast<int>(e); - return EnumNamesDatetimeExtractorType()[index]; -} - -enum DatetimeGroupType { - DatetimeGroupType_GROUP_UNKNOWN = 0, - DatetimeGroupType_GROUP_UNUSED = 1, - DatetimeGroupType_GROUP_YEAR = 2, - DatetimeGroupType_GROUP_MONTH = 3, - DatetimeGroupType_GROUP_DAY = 4, - DatetimeGroupType_GROUP_HOUR = 5, - DatetimeGroupType_GROUP_MINUTE = 6, - DatetimeGroupType_GROUP_SECOND = 7, - DatetimeGroupType_GROUP_AMPM = 8, - DatetimeGroupType_GROUP_RELATIONDISTANCE = 9, - DatetimeGroupType_GROUP_RELATION = 10, - DatetimeGroupType_GROUP_RELATIONTYPE = 11, - DatetimeGroupType_GROUP_DUMMY1 = 12, - DatetimeGroupType_GROUP_DUMMY2 = 13, - DatetimeGroupType_MIN = DatetimeGroupType_GROUP_UNKNOWN, - DatetimeGroupType_MAX = DatetimeGroupType_GROUP_DUMMY2 -}; - -inline DatetimeGroupType (&EnumValuesDatetimeGroupType())[14] { - static DatetimeGroupType values[] = { - DatetimeGroupType_GROUP_UNKNOWN, - DatetimeGroupType_GROUP_UNUSED, - DatetimeGroupType_GROUP_YEAR, - DatetimeGroupType_GROUP_MONTH, - DatetimeGroupType_GROUP_DAY, - DatetimeGroupType_GROUP_HOUR, - DatetimeGroupType_GROUP_MINUTE, - DatetimeGroupType_GROUP_SECOND, - DatetimeGroupType_GROUP_AMPM, - DatetimeGroupType_GROUP_RELATIONDISTANCE, - DatetimeGroupType_GROUP_RELATION, - DatetimeGroupType_GROUP_RELATIONTYPE, - DatetimeGroupType_GROUP_DUMMY1, - DatetimeGroupType_GROUP_DUMMY2 - }; - return values; -} - -inline const char **EnumNamesDatetimeGroupType() { - static const char *names[] = { - "GROUP_UNKNOWN", - "GROUP_UNUSED", - "GROUP_YEAR", - "GROUP_MONTH", - "GROUP_DAY", - "GROUP_HOUR", - "GROUP_MINUTE", - "GROUP_SECOND", - "GROUP_AMPM", - "GROUP_RELATIONDISTANCE", - "GROUP_RELATION", - "GROUP_RELATIONTYPE", - "GROUP_DUMMY1", - "GROUP_DUMMY2", - nullptr - }; - return names; -} - -inline const char *EnumNameDatetimeGroupType(DatetimeGroupType e) { - const size_t index = static_cast<int>(e); - return EnumNamesDatetimeGroupType()[index]; -} - -namespace TokenizationCodepointRange_ { - -enum Role { - Role_DEFAULT_ROLE = 0, - Role_SPLIT_BEFORE = 1, - Role_SPLIT_AFTER = 2, - Role_TOKEN_SEPARATOR = 3, - Role_DISCARD_CODEPOINT = 4, - Role_WHITESPACE_SEPARATOR = 7, - Role_MIN = Role_DEFAULT_ROLE, - Role_MAX = Role_WHITESPACE_SEPARATOR -}; - -inline Role (&EnumValuesRole())[6] { - static Role values[] = { - Role_DEFAULT_ROLE, - Role_SPLIT_BEFORE, - Role_SPLIT_AFTER, - Role_TOKEN_SEPARATOR, - Role_DISCARD_CODEPOINT, - Role_WHITESPACE_SEPARATOR - }; - return values; -} - -inline const char **EnumNamesRole() { - static const char *names[] = { - "DEFAULT_ROLE", - "SPLIT_BEFORE", - "SPLIT_AFTER", - "TOKEN_SEPARATOR", - "DISCARD_CODEPOINT", - "", - "", - "WHITESPACE_SEPARATOR", - nullptr - }; - return names; -} - -inline const char *EnumNameRole(Role e) { - const size_t index = static_cast<int>(e); - return EnumNamesRole()[index]; -} - -} // namespace TokenizationCodepointRange_ - -namespace FeatureProcessorOptions_ { - -enum CenterTokenSelectionMethod { - CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD = 0, - CenterTokenSelectionMethod_CENTER_TOKEN_FROM_CLICK = 1, - CenterTokenSelectionMethod_CENTER_TOKEN_MIDDLE_OF_SELECTION = 2, - CenterTokenSelectionMethod_MIN = CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD, - CenterTokenSelectionMethod_MAX = CenterTokenSelectionMethod_CENTER_TOKEN_MIDDLE_OF_SELECTION -}; - -inline CenterTokenSelectionMethod (&EnumValuesCenterTokenSelectionMethod())[3] { - static CenterTokenSelectionMethod values[] = { - CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD, - CenterTokenSelectionMethod_CENTER_TOKEN_FROM_CLICK, - CenterTokenSelectionMethod_CENTER_TOKEN_MIDDLE_OF_SELECTION - }; - return values; -} - -inline const char **EnumNamesCenterTokenSelectionMethod() { - static const char *names[] = { - "DEFAULT_CENTER_TOKEN_METHOD", - "CENTER_TOKEN_FROM_CLICK", - "CENTER_TOKEN_MIDDLE_OF_SELECTION", - nullptr - }; - return names; -} - -inline const char *EnumNameCenterTokenSelectionMethod(CenterTokenSelectionMethod e) { - const size_t index = static_cast<int>(e); - return EnumNamesCenterTokenSelectionMethod()[index]; -} - -enum TokenizationType { - TokenizationType_INVALID_TOKENIZATION_TYPE = 0, - TokenizationType_INTERNAL_TOKENIZER = 1, - TokenizationType_ICU = 2, - TokenizationType_MIXED = 3, - TokenizationType_MIN = TokenizationType_INVALID_TOKENIZATION_TYPE, - TokenizationType_MAX = TokenizationType_MIXED -}; - -inline TokenizationType (&EnumValuesTokenizationType())[4] { - static TokenizationType values[] = { - TokenizationType_INVALID_TOKENIZATION_TYPE, - TokenizationType_INTERNAL_TOKENIZER, - TokenizationType_ICU, - TokenizationType_MIXED - }; - return values; -} - -inline const char **EnumNamesTokenizationType() { - static const char *names[] = { - "INVALID_TOKENIZATION_TYPE", - "INTERNAL_TOKENIZER", - "ICU", - "MIXED", - nullptr - }; - return names; -} - -inline const char *EnumNameTokenizationType(TokenizationType e) { - const size_t index = static_cast<int>(e); - return EnumNamesTokenizationType()[index]; -} - -} // namespace FeatureProcessorOptions_ - -struct CompressedBufferT : public flatbuffers::NativeTable { - typedef CompressedBuffer TableType; - std::vector<uint8_t> buffer; - int32_t uncompressed_size; - CompressedBufferT() - : uncompressed_size(0) { - } -}; - -struct CompressedBuffer FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef CompressedBufferT NativeTableType; - enum { - VT_BUFFER = 4, - VT_UNCOMPRESSED_SIZE = 6 - }; - const flatbuffers::Vector<uint8_t> *buffer() const { - return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_BUFFER); - } - int32_t uncompressed_size() const { - return GetField<int32_t>(VT_UNCOMPRESSED_SIZE, 0); - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_BUFFER) && - verifier.Verify(buffer()) && - VerifyField<int32_t>(verifier, VT_UNCOMPRESSED_SIZE) && - verifier.EndTable(); - } - CompressedBufferT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(CompressedBufferT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset<CompressedBuffer> Pack(flatbuffers::FlatBufferBuilder &_fbb, const CompressedBufferT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -}; - -struct CompressedBufferBuilder { - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_buffer(flatbuffers::Offset<flatbuffers::Vector<uint8_t>> buffer) { - fbb_.AddOffset(CompressedBuffer::VT_BUFFER, buffer); - } - void add_uncompressed_size(int32_t uncompressed_size) { - fbb_.AddElement<int32_t>(CompressedBuffer::VT_UNCOMPRESSED_SIZE, uncompressed_size, 0); - } - explicit CompressedBufferBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - CompressedBufferBuilder &operator=(const CompressedBufferBuilder &); - flatbuffers::Offset<CompressedBuffer> Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<CompressedBuffer>(end); - return o; - } -}; - -inline flatbuffers::Offset<CompressedBuffer> CreateCompressedBuffer( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset<flatbuffers::Vector<uint8_t>> buffer = 0, - int32_t uncompressed_size = 0) { - CompressedBufferBuilder builder_(_fbb); - builder_.add_uncompressed_size(uncompressed_size); - builder_.add_buffer(buffer); - return builder_.Finish(); -} - -inline flatbuffers::Offset<CompressedBuffer> CreateCompressedBufferDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const std::vector<uint8_t> *buffer = nullptr, - int32_t uncompressed_size = 0) { - return libtextclassifier2::CreateCompressedBuffer( - _fbb, - buffer ? _fbb.CreateVector<uint8_t>(*buffer) : 0, - uncompressed_size); -} - -flatbuffers::Offset<CompressedBuffer> CreateCompressedBuffer(flatbuffers::FlatBufferBuilder &_fbb, const CompressedBufferT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); - -struct SelectionModelOptionsT : public flatbuffers::NativeTable { - typedef SelectionModelOptions TableType; - bool strip_unpaired_brackets; - int32_t symmetry_context_size; - int32_t batch_size; - bool always_classify_suggested_selection; - SelectionModelOptionsT() - : strip_unpaired_brackets(true), - symmetry_context_size(0), - batch_size(1024), - always_classify_suggested_selection(false) { - } -}; - -struct SelectionModelOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef SelectionModelOptionsT NativeTableType; - enum { - VT_STRIP_UNPAIRED_BRACKETS = 4, - VT_SYMMETRY_CONTEXT_SIZE = 6, - VT_BATCH_SIZE = 8, - VT_ALWAYS_CLASSIFY_SUGGESTED_SELECTION = 10 - }; - bool strip_unpaired_brackets() const { - return GetField<uint8_t>(VT_STRIP_UNPAIRED_BRACKETS, 1) != 0; - } - int32_t symmetry_context_size() const { - return GetField<int32_t>(VT_SYMMETRY_CONTEXT_SIZE, 0); - } - int32_t batch_size() const { - return GetField<int32_t>(VT_BATCH_SIZE, 1024); - } - bool always_classify_suggested_selection() const { - return GetField<uint8_t>(VT_ALWAYS_CLASSIFY_SUGGESTED_SELECTION, 0) != 0; - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyField<uint8_t>(verifier, VT_STRIP_UNPAIRED_BRACKETS) && - VerifyField<int32_t>(verifier, VT_SYMMETRY_CONTEXT_SIZE) && - VerifyField<int32_t>(verifier, VT_BATCH_SIZE) && - VerifyField<uint8_t>(verifier, VT_ALWAYS_CLASSIFY_SUGGESTED_SELECTION) && - verifier.EndTable(); - } - SelectionModelOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(SelectionModelOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset<SelectionModelOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const SelectionModelOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -}; - -struct SelectionModelOptionsBuilder { - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_strip_unpaired_brackets(bool strip_unpaired_brackets) { - fbb_.AddElement<uint8_t>(SelectionModelOptions::VT_STRIP_UNPAIRED_BRACKETS, static_cast<uint8_t>(strip_unpaired_brackets), 1); - } - void add_symmetry_context_size(int32_t symmetry_context_size) { - fbb_.AddElement<int32_t>(SelectionModelOptions::VT_SYMMETRY_CONTEXT_SIZE, symmetry_context_size, 0); - } - void add_batch_size(int32_t batch_size) { - fbb_.AddElement<int32_t>(SelectionModelOptions::VT_BATCH_SIZE, batch_size, 1024); - } - void add_always_classify_suggested_selection(bool always_classify_suggested_selection) { - fbb_.AddElement<uint8_t>(SelectionModelOptions::VT_ALWAYS_CLASSIFY_SUGGESTED_SELECTION, static_cast<uint8_t>(always_classify_suggested_selection), 0); - } - explicit SelectionModelOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - SelectionModelOptionsBuilder &operator=(const SelectionModelOptionsBuilder &); - flatbuffers::Offset<SelectionModelOptions> Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<SelectionModelOptions>(end); - return o; - } -}; - -inline flatbuffers::Offset<SelectionModelOptions> CreateSelectionModelOptions( - flatbuffers::FlatBufferBuilder &_fbb, - bool strip_unpaired_brackets = true, - int32_t symmetry_context_size = 0, - int32_t batch_size = 1024, - bool always_classify_suggested_selection = false) { - SelectionModelOptionsBuilder builder_(_fbb); - builder_.add_batch_size(batch_size); - builder_.add_symmetry_context_size(symmetry_context_size); - builder_.add_always_classify_suggested_selection(always_classify_suggested_selection); - builder_.add_strip_unpaired_brackets(strip_unpaired_brackets); - return builder_.Finish(); -} - -flatbuffers::Offset<SelectionModelOptions> CreateSelectionModelOptions(flatbuffers::FlatBufferBuilder &_fbb, const SelectionModelOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); - -struct ClassificationModelOptionsT : public flatbuffers::NativeTable { - typedef ClassificationModelOptions TableType; - int32_t phone_min_num_digits; - int32_t phone_max_num_digits; - int32_t address_min_num_tokens; - int32_t max_num_tokens; - ClassificationModelOptionsT() - : phone_min_num_digits(7), - phone_max_num_digits(15), - address_min_num_tokens(0), - max_num_tokens(-1) { - } -}; - -struct ClassificationModelOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef ClassificationModelOptionsT NativeTableType; - enum { - VT_PHONE_MIN_NUM_DIGITS = 4, - VT_PHONE_MAX_NUM_DIGITS = 6, - VT_ADDRESS_MIN_NUM_TOKENS = 8, - VT_MAX_NUM_TOKENS = 10 - }; - int32_t phone_min_num_digits() const { - return GetField<int32_t>(VT_PHONE_MIN_NUM_DIGITS, 7); - } - int32_t phone_max_num_digits() const { - return GetField<int32_t>(VT_PHONE_MAX_NUM_DIGITS, 15); - } - int32_t address_min_num_tokens() const { - return GetField<int32_t>(VT_ADDRESS_MIN_NUM_TOKENS, 0); - } - int32_t max_num_tokens() const { - return GetField<int32_t>(VT_MAX_NUM_TOKENS, -1); - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyField<int32_t>(verifier, VT_PHONE_MIN_NUM_DIGITS) && - VerifyField<int32_t>(verifier, VT_PHONE_MAX_NUM_DIGITS) && - VerifyField<int32_t>(verifier, VT_ADDRESS_MIN_NUM_TOKENS) && - VerifyField<int32_t>(verifier, VT_MAX_NUM_TOKENS) && - verifier.EndTable(); - } - ClassificationModelOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(ClassificationModelOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset<ClassificationModelOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ClassificationModelOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -}; - -struct ClassificationModelOptionsBuilder { - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_phone_min_num_digits(int32_t phone_min_num_digits) { - fbb_.AddElement<int32_t>(ClassificationModelOptions::VT_PHONE_MIN_NUM_DIGITS, phone_min_num_digits, 7); - } - void add_phone_max_num_digits(int32_t phone_max_num_digits) { - fbb_.AddElement<int32_t>(ClassificationModelOptions::VT_PHONE_MAX_NUM_DIGITS, phone_max_num_digits, 15); - } - void add_address_min_num_tokens(int32_t address_min_num_tokens) { - fbb_.AddElement<int32_t>(ClassificationModelOptions::VT_ADDRESS_MIN_NUM_TOKENS, address_min_num_tokens, 0); - } - void add_max_num_tokens(int32_t max_num_tokens) { - fbb_.AddElement<int32_t>(ClassificationModelOptions::VT_MAX_NUM_TOKENS, max_num_tokens, -1); - } - explicit ClassificationModelOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - ClassificationModelOptionsBuilder &operator=(const ClassificationModelOptionsBuilder &); - flatbuffers::Offset<ClassificationModelOptions> Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<ClassificationModelOptions>(end); - return o; - } -}; - -inline flatbuffers::Offset<ClassificationModelOptions> CreateClassificationModelOptions( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t phone_min_num_digits = 7, - int32_t phone_max_num_digits = 15, - int32_t address_min_num_tokens = 0, - int32_t max_num_tokens = -1) { - ClassificationModelOptionsBuilder builder_(_fbb); - builder_.add_max_num_tokens(max_num_tokens); - builder_.add_address_min_num_tokens(address_min_num_tokens); - builder_.add_phone_max_num_digits(phone_max_num_digits); - builder_.add_phone_min_num_digits(phone_min_num_digits); - return builder_.Finish(); -} - -flatbuffers::Offset<ClassificationModelOptions> CreateClassificationModelOptions(flatbuffers::FlatBufferBuilder &_fbb, const ClassificationModelOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); - -namespace RegexModel_ { - -struct PatternT : public flatbuffers::NativeTable { - typedef Pattern TableType; - std::string collection_name; - std::string pattern; - libtextclassifier2::ModeFlag enabled_modes; - float target_classification_score; - float priority_score; - bool use_approximate_matching; - std::unique_ptr<libtextclassifier2::CompressedBufferT> compressed_pattern; - PatternT() - : enabled_modes(libtextclassifier2::ModeFlag_ALL), - target_classification_score(1.0f), - priority_score(0.0f), - use_approximate_matching(false) { - } -}; - -struct Pattern FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef PatternT NativeTableType; - enum { - VT_COLLECTION_NAME = 4, - VT_PATTERN = 6, - VT_ENABLED_MODES = 8, - VT_TARGET_CLASSIFICATION_SCORE = 10, - VT_PRIORITY_SCORE = 12, - VT_USE_APPROXIMATE_MATCHING = 14, - VT_COMPRESSED_PATTERN = 16 - }; - const flatbuffers::String *collection_name() const { - return GetPointer<const flatbuffers::String *>(VT_COLLECTION_NAME); - } - const flatbuffers::String *pattern() const { - return GetPointer<const flatbuffers::String *>(VT_PATTERN); - } - libtextclassifier2::ModeFlag enabled_modes() const { - return static_cast<libtextclassifier2::ModeFlag>(GetField<int32_t>(VT_ENABLED_MODES, 7)); - } - float target_classification_score() const { - return GetField<float>(VT_TARGET_CLASSIFICATION_SCORE, 1.0f); - } - float priority_score() const { - return GetField<float>(VT_PRIORITY_SCORE, 0.0f); - } - bool use_approximate_matching() const { - return GetField<uint8_t>(VT_USE_APPROXIMATE_MATCHING, 0) != 0; - } - const libtextclassifier2::CompressedBuffer *compressed_pattern() const { - return GetPointer<const libtextclassifier2::CompressedBuffer *>(VT_COMPRESSED_PATTERN); - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_COLLECTION_NAME) && - verifier.Verify(collection_name()) && - VerifyOffset(verifier, VT_PATTERN) && - verifier.Verify(pattern()) && - VerifyField<int32_t>(verifier, VT_ENABLED_MODES) && - VerifyField<float>(verifier, VT_TARGET_CLASSIFICATION_SCORE) && - VerifyField<float>(verifier, VT_PRIORITY_SCORE) && - VerifyField<uint8_t>(verifier, VT_USE_APPROXIMATE_MATCHING) && - VerifyOffset(verifier, VT_COMPRESSED_PATTERN) && - verifier.VerifyTable(compressed_pattern()) && - verifier.EndTable(); - } - PatternT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(PatternT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset<Pattern> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PatternT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -}; - -struct PatternBuilder { - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_collection_name(flatbuffers::Offset<flatbuffers::String> collection_name) { - fbb_.AddOffset(Pattern::VT_COLLECTION_NAME, collection_name); - } - void add_pattern(flatbuffers::Offset<flatbuffers::String> pattern) { - fbb_.AddOffset(Pattern::VT_PATTERN, pattern); - } - void add_enabled_modes(libtextclassifier2::ModeFlag enabled_modes) { - fbb_.AddElement<int32_t>(Pattern::VT_ENABLED_MODES, static_cast<int32_t>(enabled_modes), 7); - } - void add_target_classification_score(float target_classification_score) { - fbb_.AddElement<float>(Pattern::VT_TARGET_CLASSIFICATION_SCORE, target_classification_score, 1.0f); - } - void add_priority_score(float priority_score) { - fbb_.AddElement<float>(Pattern::VT_PRIORITY_SCORE, priority_score, 0.0f); - } - void add_use_approximate_matching(bool use_approximate_matching) { - fbb_.AddElement<uint8_t>(Pattern::VT_USE_APPROXIMATE_MATCHING, static_cast<uint8_t>(use_approximate_matching), 0); - } - void add_compressed_pattern(flatbuffers::Offset<libtextclassifier2::CompressedBuffer> compressed_pattern) { - fbb_.AddOffset(Pattern::VT_COMPRESSED_PATTERN, compressed_pattern); - } - explicit PatternBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - PatternBuilder &operator=(const PatternBuilder &); - flatbuffers::Offset<Pattern> Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<Pattern>(end); - return o; - } -}; - -inline flatbuffers::Offset<Pattern> CreatePattern( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset<flatbuffers::String> collection_name = 0, - flatbuffers::Offset<flatbuffers::String> pattern = 0, - libtextclassifier2::ModeFlag enabled_modes = libtextclassifier2::ModeFlag_ALL, - float target_classification_score = 1.0f, - float priority_score = 0.0f, - bool use_approximate_matching = false, - flatbuffers::Offset<libtextclassifier2::CompressedBuffer> compressed_pattern = 0) { - PatternBuilder builder_(_fbb); - builder_.add_compressed_pattern(compressed_pattern); - builder_.add_priority_score(priority_score); - builder_.add_target_classification_score(target_classification_score); - builder_.add_enabled_modes(enabled_modes); - builder_.add_pattern(pattern); - builder_.add_collection_name(collection_name); - builder_.add_use_approximate_matching(use_approximate_matching); - return builder_.Finish(); -} - -inline flatbuffers::Offset<Pattern> CreatePatternDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const char *collection_name = nullptr, - const char *pattern = nullptr, - libtextclassifier2::ModeFlag enabled_modes = libtextclassifier2::ModeFlag_ALL, - float target_classification_score = 1.0f, - float priority_score = 0.0f, - bool use_approximate_matching = false, - flatbuffers::Offset<libtextclassifier2::CompressedBuffer> compressed_pattern = 0) { - return libtextclassifier2::RegexModel_::CreatePattern( - _fbb, - collection_name ? _fbb.CreateString(collection_name) : 0, - pattern ? _fbb.CreateString(pattern) : 0, - enabled_modes, - target_classification_score, - priority_score, - use_approximate_matching, - compressed_pattern); -} - -flatbuffers::Offset<Pattern> CreatePattern(flatbuffers::FlatBufferBuilder &_fbb, const PatternT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); - -} // namespace RegexModel_ - -struct RegexModelT : public flatbuffers::NativeTable { - typedef RegexModel TableType; - std::vector<std::unique_ptr<libtextclassifier2::RegexModel_::PatternT>> patterns; - RegexModelT() { - } -}; - -struct RegexModel FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef RegexModelT NativeTableType; - enum { - VT_PATTERNS = 4 - }; - const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::RegexModel_::Pattern>> *patterns() const { - return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::RegexModel_::Pattern>> *>(VT_PATTERNS); - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_PATTERNS) && - verifier.Verify(patterns()) && - verifier.VerifyVectorOfTables(patterns()) && - verifier.EndTable(); - } - RegexModelT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(RegexModelT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset<RegexModel> Pack(flatbuffers::FlatBufferBuilder &_fbb, const RegexModelT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -}; - -struct RegexModelBuilder { - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_patterns(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::RegexModel_::Pattern>>> patterns) { - fbb_.AddOffset(RegexModel::VT_PATTERNS, patterns); - } - explicit RegexModelBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - RegexModelBuilder &operator=(const RegexModelBuilder &); - flatbuffers::Offset<RegexModel> Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<RegexModel>(end); - return o; - } -}; - -inline flatbuffers::Offset<RegexModel> CreateRegexModel( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::RegexModel_::Pattern>>> patterns = 0) { - RegexModelBuilder builder_(_fbb); - builder_.add_patterns(patterns); - return builder_.Finish(); -} - -inline flatbuffers::Offset<RegexModel> CreateRegexModelDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const std::vector<flatbuffers::Offset<libtextclassifier2::RegexModel_::Pattern>> *patterns = nullptr) { - return libtextclassifier2::CreateRegexModel( - _fbb, - patterns ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::RegexModel_::Pattern>>(*patterns) : 0); -} - -flatbuffers::Offset<RegexModel> CreateRegexModel(flatbuffers::FlatBufferBuilder &_fbb, const RegexModelT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); - -namespace DatetimeModelPattern_ { - -struct RegexT : public flatbuffers::NativeTable { - typedef Regex TableType; - std::string pattern; - std::vector<libtextclassifier2::DatetimeGroupType> groups; - std::unique_ptr<libtextclassifier2::CompressedBufferT> compressed_pattern; - RegexT() { - } -}; - -struct Regex FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef RegexT NativeTableType; - enum { - VT_PATTERN = 4, - VT_GROUPS = 6, - VT_COMPRESSED_PATTERN = 8 - }; - const flatbuffers::String *pattern() const { - return GetPointer<const flatbuffers::String *>(VT_PATTERN); - } - const flatbuffers::Vector<int32_t> *groups() const { - return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_GROUPS); - } - const libtextclassifier2::CompressedBuffer *compressed_pattern() const { - return GetPointer<const libtextclassifier2::CompressedBuffer *>(VT_COMPRESSED_PATTERN); - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_PATTERN) && - verifier.Verify(pattern()) && - VerifyOffset(verifier, VT_GROUPS) && - verifier.Verify(groups()) && - VerifyOffset(verifier, VT_COMPRESSED_PATTERN) && - verifier.VerifyTable(compressed_pattern()) && - verifier.EndTable(); - } - RegexT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(RegexT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset<Regex> Pack(flatbuffers::FlatBufferBuilder &_fbb, const RegexT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -}; - -struct RegexBuilder { - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_pattern(flatbuffers::Offset<flatbuffers::String> pattern) { - fbb_.AddOffset(Regex::VT_PATTERN, pattern); - } - void add_groups(flatbuffers::Offset<flatbuffers::Vector<int32_t>> groups) { - fbb_.AddOffset(Regex::VT_GROUPS, groups); - } - void add_compressed_pattern(flatbuffers::Offset<libtextclassifier2::CompressedBuffer> compressed_pattern) { - fbb_.AddOffset(Regex::VT_COMPRESSED_PATTERN, compressed_pattern); - } - explicit RegexBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - RegexBuilder &operator=(const RegexBuilder &); - flatbuffers::Offset<Regex> Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<Regex>(end); - return o; - } -}; - -inline flatbuffers::Offset<Regex> CreateRegex( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset<flatbuffers::String> pattern = 0, - flatbuffers::Offset<flatbuffers::Vector<int32_t>> groups = 0, - flatbuffers::Offset<libtextclassifier2::CompressedBuffer> compressed_pattern = 0) { - RegexBuilder builder_(_fbb); - builder_.add_compressed_pattern(compressed_pattern); - builder_.add_groups(groups); - builder_.add_pattern(pattern); - return builder_.Finish(); -} - -inline flatbuffers::Offset<Regex> CreateRegexDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const char *pattern = nullptr, - const std::vector<int32_t> *groups = nullptr, - flatbuffers::Offset<libtextclassifier2::CompressedBuffer> compressed_pattern = 0) { - return libtextclassifier2::DatetimeModelPattern_::CreateRegex( - _fbb, - pattern ? _fbb.CreateString(pattern) : 0, - groups ? _fbb.CreateVector<int32_t>(*groups) : 0, - compressed_pattern); -} - -flatbuffers::Offset<Regex> CreateRegex(flatbuffers::FlatBufferBuilder &_fbb, const RegexT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); - -} // namespace DatetimeModelPattern_ - -struct DatetimeModelPatternT : public flatbuffers::NativeTable { - typedef DatetimeModelPattern TableType; - std::vector<std::unique_ptr<libtextclassifier2::DatetimeModelPattern_::RegexT>> regexes; - std::vector<int32_t> locales; - float target_classification_score; - float priority_score; - ModeFlag enabled_modes; - DatetimeModelPatternT() - : target_classification_score(1.0f), - priority_score(0.0f), - enabled_modes(ModeFlag_ALL) { - } -}; - -struct DatetimeModelPattern FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef DatetimeModelPatternT NativeTableType; - enum { - VT_REGEXES = 4, - VT_LOCALES = 6, - VT_TARGET_CLASSIFICATION_SCORE = 8, - VT_PRIORITY_SCORE = 10, - VT_ENABLED_MODES = 12 - }; - const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelPattern_::Regex>> *regexes() const { - return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelPattern_::Regex>> *>(VT_REGEXES); - } - const flatbuffers::Vector<int32_t> *locales() const { - return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_LOCALES); - } - float target_classification_score() const { - return GetField<float>(VT_TARGET_CLASSIFICATION_SCORE, 1.0f); - } - float priority_score() const { - return GetField<float>(VT_PRIORITY_SCORE, 0.0f); - } - ModeFlag enabled_modes() const { - return static_cast<ModeFlag>(GetField<int32_t>(VT_ENABLED_MODES, 7)); - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_REGEXES) && - verifier.Verify(regexes()) && - verifier.VerifyVectorOfTables(regexes()) && - VerifyOffset(verifier, VT_LOCALES) && - verifier.Verify(locales()) && - VerifyField<float>(verifier, VT_TARGET_CLASSIFICATION_SCORE) && - VerifyField<float>(verifier, VT_PRIORITY_SCORE) && - VerifyField<int32_t>(verifier, VT_ENABLED_MODES) && - verifier.EndTable(); - } - DatetimeModelPatternT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(DatetimeModelPatternT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset<DatetimeModelPattern> Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelPatternT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -}; - -struct DatetimeModelPatternBuilder { - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_regexes(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelPattern_::Regex>>> regexes) { - fbb_.AddOffset(DatetimeModelPattern::VT_REGEXES, regexes); - } - void add_locales(flatbuffers::Offset<flatbuffers::Vector<int32_t>> locales) { - fbb_.AddOffset(DatetimeModelPattern::VT_LOCALES, locales); - } - void add_target_classification_score(float target_classification_score) { - fbb_.AddElement<float>(DatetimeModelPattern::VT_TARGET_CLASSIFICATION_SCORE, target_classification_score, 1.0f); - } - void add_priority_score(float priority_score) { - fbb_.AddElement<float>(DatetimeModelPattern::VT_PRIORITY_SCORE, priority_score, 0.0f); - } - void add_enabled_modes(ModeFlag enabled_modes) { - fbb_.AddElement<int32_t>(DatetimeModelPattern::VT_ENABLED_MODES, static_cast<int32_t>(enabled_modes), 7); - } - explicit DatetimeModelPatternBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - DatetimeModelPatternBuilder &operator=(const DatetimeModelPatternBuilder &); - flatbuffers::Offset<DatetimeModelPattern> Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<DatetimeModelPattern>(end); - return o; - } -}; - -inline flatbuffers::Offset<DatetimeModelPattern> CreateDatetimeModelPattern( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelPattern_::Regex>>> regexes = 0, - flatbuffers::Offset<flatbuffers::Vector<int32_t>> locales = 0, - float target_classification_score = 1.0f, - float priority_score = 0.0f, - ModeFlag enabled_modes = ModeFlag_ALL) { - DatetimeModelPatternBuilder builder_(_fbb); - builder_.add_enabled_modes(enabled_modes); - builder_.add_priority_score(priority_score); - builder_.add_target_classification_score(target_classification_score); - builder_.add_locales(locales); - builder_.add_regexes(regexes); - return builder_.Finish(); -} - -inline flatbuffers::Offset<DatetimeModelPattern> CreateDatetimeModelPatternDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const std::vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelPattern_::Regex>> *regexes = nullptr, - const std::vector<int32_t> *locales = nullptr, - float target_classification_score = 1.0f, - float priority_score = 0.0f, - ModeFlag enabled_modes = ModeFlag_ALL) { - return libtextclassifier2::CreateDatetimeModelPattern( - _fbb, - regexes ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::DatetimeModelPattern_::Regex>>(*regexes) : 0, - locales ? _fbb.CreateVector<int32_t>(*locales) : 0, - target_classification_score, - priority_score, - enabled_modes); -} - -flatbuffers::Offset<DatetimeModelPattern> CreateDatetimeModelPattern(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelPatternT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); - -struct DatetimeModelExtractorT : public flatbuffers::NativeTable { - typedef DatetimeModelExtractor TableType; - DatetimeExtractorType extractor; - std::string pattern; - std::vector<int32_t> locales; - std::unique_ptr<CompressedBufferT> compressed_pattern; - DatetimeModelExtractorT() - : extractor(DatetimeExtractorType_UNKNOWN_DATETIME_EXTRACTOR_TYPE) { - } -}; - -struct DatetimeModelExtractor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef DatetimeModelExtractorT NativeTableType; - enum { - VT_EXTRACTOR = 4, - VT_PATTERN = 6, - VT_LOCALES = 8, - VT_COMPRESSED_PATTERN = 10 - }; - DatetimeExtractorType extractor() const { - return static_cast<DatetimeExtractorType>(GetField<int32_t>(VT_EXTRACTOR, 0)); - } - const flatbuffers::String *pattern() const { - return GetPointer<const flatbuffers::String *>(VT_PATTERN); - } - const flatbuffers::Vector<int32_t> *locales() const { - return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_LOCALES); - } - const CompressedBuffer *compressed_pattern() const { - return GetPointer<const CompressedBuffer *>(VT_COMPRESSED_PATTERN); - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyField<int32_t>(verifier, VT_EXTRACTOR) && - VerifyOffset(verifier, VT_PATTERN) && - verifier.Verify(pattern()) && - VerifyOffset(verifier, VT_LOCALES) && - verifier.Verify(locales()) && - VerifyOffset(verifier, VT_COMPRESSED_PATTERN) && - verifier.VerifyTable(compressed_pattern()) && - verifier.EndTable(); - } - DatetimeModelExtractorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(DatetimeModelExtractorT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset<DatetimeModelExtractor> Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelExtractorT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -}; - -struct DatetimeModelExtractorBuilder { - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_extractor(DatetimeExtractorType extractor) { - fbb_.AddElement<int32_t>(DatetimeModelExtractor::VT_EXTRACTOR, static_cast<int32_t>(extractor), 0); - } - void add_pattern(flatbuffers::Offset<flatbuffers::String> pattern) { - fbb_.AddOffset(DatetimeModelExtractor::VT_PATTERN, pattern); - } - void add_locales(flatbuffers::Offset<flatbuffers::Vector<int32_t>> locales) { - fbb_.AddOffset(DatetimeModelExtractor::VT_LOCALES, locales); - } - void add_compressed_pattern(flatbuffers::Offset<CompressedBuffer> compressed_pattern) { - fbb_.AddOffset(DatetimeModelExtractor::VT_COMPRESSED_PATTERN, compressed_pattern); - } - explicit DatetimeModelExtractorBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - DatetimeModelExtractorBuilder &operator=(const DatetimeModelExtractorBuilder &); - flatbuffers::Offset<DatetimeModelExtractor> Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<DatetimeModelExtractor>(end); - return o; - } -}; - -inline flatbuffers::Offset<DatetimeModelExtractor> CreateDatetimeModelExtractor( - flatbuffers::FlatBufferBuilder &_fbb, - DatetimeExtractorType extractor = DatetimeExtractorType_UNKNOWN_DATETIME_EXTRACTOR_TYPE, - flatbuffers::Offset<flatbuffers::String> pattern = 0, - flatbuffers::Offset<flatbuffers::Vector<int32_t>> locales = 0, - flatbuffers::Offset<CompressedBuffer> compressed_pattern = 0) { - DatetimeModelExtractorBuilder builder_(_fbb); - builder_.add_compressed_pattern(compressed_pattern); - builder_.add_locales(locales); - builder_.add_pattern(pattern); - builder_.add_extractor(extractor); - return builder_.Finish(); -} - -inline flatbuffers::Offset<DatetimeModelExtractor> CreateDatetimeModelExtractorDirect( - flatbuffers::FlatBufferBuilder &_fbb, - DatetimeExtractorType extractor = DatetimeExtractorType_UNKNOWN_DATETIME_EXTRACTOR_TYPE, - const char *pattern = nullptr, - const std::vector<int32_t> *locales = nullptr, - flatbuffers::Offset<CompressedBuffer> compressed_pattern = 0) { - return libtextclassifier2::CreateDatetimeModelExtractor( - _fbb, - extractor, - pattern ? _fbb.CreateString(pattern) : 0, - locales ? _fbb.CreateVector<int32_t>(*locales) : 0, - compressed_pattern); -} - -flatbuffers::Offset<DatetimeModelExtractor> CreateDatetimeModelExtractor(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelExtractorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); - -struct DatetimeModelT : public flatbuffers::NativeTable { - typedef DatetimeModel TableType; - std::vector<std::string> locales; - std::vector<std::unique_ptr<DatetimeModelPatternT>> patterns; - std::vector<std::unique_ptr<DatetimeModelExtractorT>> extractors; - bool use_extractors_for_locating; - std::vector<int32_t> default_locales; - DatetimeModelT() - : use_extractors_for_locating(true) { - } -}; - -struct DatetimeModel FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef DatetimeModelT NativeTableType; - enum { - VT_LOCALES = 4, - VT_PATTERNS = 6, - VT_EXTRACTORS = 8, - VT_USE_EXTRACTORS_FOR_LOCATING = 10, - VT_DEFAULT_LOCALES = 12 - }; - const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *locales() const { - return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_LOCALES); - } - const flatbuffers::Vector<flatbuffers::Offset<DatetimeModelPattern>> *patterns() const { - return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<DatetimeModelPattern>> *>(VT_PATTERNS); - } - const flatbuffers::Vector<flatbuffers::Offset<DatetimeModelExtractor>> *extractors() const { - return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<DatetimeModelExtractor>> *>(VT_EXTRACTORS); - } - bool use_extractors_for_locating() const { - return GetField<uint8_t>(VT_USE_EXTRACTORS_FOR_LOCATING, 1) != 0; - } - const flatbuffers::Vector<int32_t> *default_locales() const { - return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_DEFAULT_LOCALES); - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_LOCALES) && - verifier.Verify(locales()) && - verifier.VerifyVectorOfStrings(locales()) && - VerifyOffset(verifier, VT_PATTERNS) && - verifier.Verify(patterns()) && - verifier.VerifyVectorOfTables(patterns()) && - VerifyOffset(verifier, VT_EXTRACTORS) && - verifier.Verify(extractors()) && - verifier.VerifyVectorOfTables(extractors()) && - VerifyField<uint8_t>(verifier, VT_USE_EXTRACTORS_FOR_LOCATING) && - VerifyOffset(verifier, VT_DEFAULT_LOCALES) && - verifier.Verify(default_locales()) && - verifier.EndTable(); - } - DatetimeModelT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(DatetimeModelT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset<DatetimeModel> Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -}; - -struct DatetimeModelBuilder { - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_locales(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> locales) { - fbb_.AddOffset(DatetimeModel::VT_LOCALES, locales); - } - void add_patterns(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<DatetimeModelPattern>>> patterns) { - fbb_.AddOffset(DatetimeModel::VT_PATTERNS, patterns); - } - void add_extractors(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<DatetimeModelExtractor>>> extractors) { - fbb_.AddOffset(DatetimeModel::VT_EXTRACTORS, extractors); - } - void add_use_extractors_for_locating(bool use_extractors_for_locating) { - fbb_.AddElement<uint8_t>(DatetimeModel::VT_USE_EXTRACTORS_FOR_LOCATING, static_cast<uint8_t>(use_extractors_for_locating), 1); - } - void add_default_locales(flatbuffers::Offset<flatbuffers::Vector<int32_t>> default_locales) { - fbb_.AddOffset(DatetimeModel::VT_DEFAULT_LOCALES, default_locales); - } - explicit DatetimeModelBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - DatetimeModelBuilder &operator=(const DatetimeModelBuilder &); - flatbuffers::Offset<DatetimeModel> Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<DatetimeModel>(end); - return o; - } -}; - -inline flatbuffers::Offset<DatetimeModel> CreateDatetimeModel( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> locales = 0, - flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<DatetimeModelPattern>>> patterns = 0, - flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<DatetimeModelExtractor>>> extractors = 0, - bool use_extractors_for_locating = true, - flatbuffers::Offset<flatbuffers::Vector<int32_t>> default_locales = 0) { - DatetimeModelBuilder builder_(_fbb); - builder_.add_default_locales(default_locales); - builder_.add_extractors(extractors); - builder_.add_patterns(patterns); - builder_.add_locales(locales); - builder_.add_use_extractors_for_locating(use_extractors_for_locating); - return builder_.Finish(); -} - -inline flatbuffers::Offset<DatetimeModel> CreateDatetimeModelDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const std::vector<flatbuffers::Offset<flatbuffers::String>> *locales = nullptr, - const std::vector<flatbuffers::Offset<DatetimeModelPattern>> *patterns = nullptr, - const std::vector<flatbuffers::Offset<DatetimeModelExtractor>> *extractors = nullptr, - bool use_extractors_for_locating = true, - const std::vector<int32_t> *default_locales = nullptr) { - return libtextclassifier2::CreateDatetimeModel( - _fbb, - locales ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*locales) : 0, - patterns ? _fbb.CreateVector<flatbuffers::Offset<DatetimeModelPattern>>(*patterns) : 0, - extractors ? _fbb.CreateVector<flatbuffers::Offset<DatetimeModelExtractor>>(*extractors) : 0, - use_extractors_for_locating, - default_locales ? _fbb.CreateVector<int32_t>(*default_locales) : 0); -} - -flatbuffers::Offset<DatetimeModel> CreateDatetimeModel(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); - -namespace DatetimeModelLibrary_ { - -struct ItemT : public flatbuffers::NativeTable { - typedef Item TableType; - std::string key; - std::unique_ptr<libtextclassifier2::DatetimeModelT> value; - ItemT() { - } -}; - -struct Item FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef ItemT NativeTableType; - enum { - VT_KEY = 4, - VT_VALUE = 6 - }; - const flatbuffers::String *key() const { - return GetPointer<const flatbuffers::String *>(VT_KEY); - } - const libtextclassifier2::DatetimeModel *value() const { - return GetPointer<const libtextclassifier2::DatetimeModel *>(VT_VALUE); - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_KEY) && - verifier.Verify(key()) && - VerifyOffset(verifier, VT_VALUE) && - verifier.VerifyTable(value()) && - verifier.EndTable(); - } - ItemT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(ItemT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset<Item> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ItemT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -}; - -struct ItemBuilder { - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_key(flatbuffers::Offset<flatbuffers::String> key) { - fbb_.AddOffset(Item::VT_KEY, key); - } - void add_value(flatbuffers::Offset<libtextclassifier2::DatetimeModel> value) { - fbb_.AddOffset(Item::VT_VALUE, value); - } - explicit ItemBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - ItemBuilder &operator=(const ItemBuilder &); - flatbuffers::Offset<Item> Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<Item>(end); - return o; - } -}; - -inline flatbuffers::Offset<Item> CreateItem( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset<flatbuffers::String> key = 0, - flatbuffers::Offset<libtextclassifier2::DatetimeModel> value = 0) { - ItemBuilder builder_(_fbb); - builder_.add_value(value); - builder_.add_key(key); - return builder_.Finish(); -} - -inline flatbuffers::Offset<Item> CreateItemDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const char *key = nullptr, - flatbuffers::Offset<libtextclassifier2::DatetimeModel> value = 0) { - return libtextclassifier2::DatetimeModelLibrary_::CreateItem( - _fbb, - key ? _fbb.CreateString(key) : 0, - value); -} - -flatbuffers::Offset<Item> CreateItem(flatbuffers::FlatBufferBuilder &_fbb, const ItemT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); - -} // namespace DatetimeModelLibrary_ - -struct DatetimeModelLibraryT : public flatbuffers::NativeTable { - typedef DatetimeModelLibrary TableType; - std::vector<std::unique_ptr<libtextclassifier2::DatetimeModelLibrary_::ItemT>> models; - DatetimeModelLibraryT() { - } -}; - -struct DatetimeModelLibrary FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef DatetimeModelLibraryT NativeTableType; - enum { - VT_MODELS = 4 - }; - const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelLibrary_::Item>> *models() const { - return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelLibrary_::Item>> *>(VT_MODELS); - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_MODELS) && - verifier.Verify(models()) && - verifier.VerifyVectorOfTables(models()) && - verifier.EndTable(); - } - DatetimeModelLibraryT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(DatetimeModelLibraryT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset<DatetimeModelLibrary> Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelLibraryT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -}; - -struct DatetimeModelLibraryBuilder { - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_models(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelLibrary_::Item>>> models) { - fbb_.AddOffset(DatetimeModelLibrary::VT_MODELS, models); - } - explicit DatetimeModelLibraryBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - DatetimeModelLibraryBuilder &operator=(const DatetimeModelLibraryBuilder &); - flatbuffers::Offset<DatetimeModelLibrary> Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<DatetimeModelLibrary>(end); - return o; - } -}; - -inline flatbuffers::Offset<DatetimeModelLibrary> CreateDatetimeModelLibrary( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelLibrary_::Item>>> models = 0) { - DatetimeModelLibraryBuilder builder_(_fbb); - builder_.add_models(models); - return builder_.Finish(); -} - -inline flatbuffers::Offset<DatetimeModelLibrary> CreateDatetimeModelLibraryDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const std::vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelLibrary_::Item>> *models = nullptr) { - return libtextclassifier2::CreateDatetimeModelLibrary( - _fbb, - models ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::DatetimeModelLibrary_::Item>>(*models) : 0); -} - -flatbuffers::Offset<DatetimeModelLibrary> CreateDatetimeModelLibrary(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelLibraryT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); - -struct ModelTriggeringOptionsT : public flatbuffers::NativeTable { - typedef ModelTriggeringOptions TableType; - float min_annotate_confidence; - ModeFlag enabled_modes; - ModelTriggeringOptionsT() - : min_annotate_confidence(0.0f), - enabled_modes(ModeFlag_ALL) { - } -}; - -struct ModelTriggeringOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef ModelTriggeringOptionsT NativeTableType; - enum { - VT_MIN_ANNOTATE_CONFIDENCE = 4, - VT_ENABLED_MODES = 6 - }; - float min_annotate_confidence() const { - return GetField<float>(VT_MIN_ANNOTATE_CONFIDENCE, 0.0f); - } - ModeFlag enabled_modes() const { - return static_cast<ModeFlag>(GetField<int32_t>(VT_ENABLED_MODES, 7)); - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyField<float>(verifier, VT_MIN_ANNOTATE_CONFIDENCE) && - VerifyField<int32_t>(verifier, VT_ENABLED_MODES) && - verifier.EndTable(); - } - ModelTriggeringOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(ModelTriggeringOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset<ModelTriggeringOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ModelTriggeringOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -}; - -struct ModelTriggeringOptionsBuilder { - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_min_annotate_confidence(float min_annotate_confidence) { - fbb_.AddElement<float>(ModelTriggeringOptions::VT_MIN_ANNOTATE_CONFIDENCE, min_annotate_confidence, 0.0f); - } - void add_enabled_modes(ModeFlag enabled_modes) { - fbb_.AddElement<int32_t>(ModelTriggeringOptions::VT_ENABLED_MODES, static_cast<int32_t>(enabled_modes), 7); - } - explicit ModelTriggeringOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - ModelTriggeringOptionsBuilder &operator=(const ModelTriggeringOptionsBuilder &); - flatbuffers::Offset<ModelTriggeringOptions> Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<ModelTriggeringOptions>(end); - return o; - } -}; - -inline flatbuffers::Offset<ModelTriggeringOptions> CreateModelTriggeringOptions( - flatbuffers::FlatBufferBuilder &_fbb, - float min_annotate_confidence = 0.0f, - ModeFlag enabled_modes = ModeFlag_ALL) { - ModelTriggeringOptionsBuilder builder_(_fbb); - builder_.add_enabled_modes(enabled_modes); - builder_.add_min_annotate_confidence(min_annotate_confidence); - return builder_.Finish(); -} - -flatbuffers::Offset<ModelTriggeringOptions> CreateModelTriggeringOptions(flatbuffers::FlatBufferBuilder &_fbb, const ModelTriggeringOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); - -struct OutputOptionsT : public flatbuffers::NativeTable { - typedef OutputOptions TableType; - std::vector<std::string> filtered_collections_annotation; - std::vector<std::string> filtered_collections_classification; - std::vector<std::string> filtered_collections_selection; - OutputOptionsT() { - } -}; - -struct OutputOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef OutputOptionsT NativeTableType; - enum { - VT_FILTERED_COLLECTIONS_ANNOTATION = 4, - VT_FILTERED_COLLECTIONS_CLASSIFICATION = 6, - VT_FILTERED_COLLECTIONS_SELECTION = 8 - }; - const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *filtered_collections_annotation() const { - return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_FILTERED_COLLECTIONS_ANNOTATION); - } - const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *filtered_collections_classification() const { - return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_FILTERED_COLLECTIONS_CLASSIFICATION); - } - const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *filtered_collections_selection() const { - return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_FILTERED_COLLECTIONS_SELECTION); - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_FILTERED_COLLECTIONS_ANNOTATION) && - verifier.Verify(filtered_collections_annotation()) && - verifier.VerifyVectorOfStrings(filtered_collections_annotation()) && - VerifyOffset(verifier, VT_FILTERED_COLLECTIONS_CLASSIFICATION) && - verifier.Verify(filtered_collections_classification()) && - verifier.VerifyVectorOfStrings(filtered_collections_classification()) && - VerifyOffset(verifier, VT_FILTERED_COLLECTIONS_SELECTION) && - verifier.Verify(filtered_collections_selection()) && - verifier.VerifyVectorOfStrings(filtered_collections_selection()) && - verifier.EndTable(); - } - OutputOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(OutputOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset<OutputOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const OutputOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -}; - -struct OutputOptionsBuilder { - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_filtered_collections_annotation(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> filtered_collections_annotation) { - fbb_.AddOffset(OutputOptions::VT_FILTERED_COLLECTIONS_ANNOTATION, filtered_collections_annotation); - } - void add_filtered_collections_classification(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> filtered_collections_classification) { - fbb_.AddOffset(OutputOptions::VT_FILTERED_COLLECTIONS_CLASSIFICATION, filtered_collections_classification); - } - void add_filtered_collections_selection(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> filtered_collections_selection) { - fbb_.AddOffset(OutputOptions::VT_FILTERED_COLLECTIONS_SELECTION, filtered_collections_selection); - } - explicit OutputOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - OutputOptionsBuilder &operator=(const OutputOptionsBuilder &); - flatbuffers::Offset<OutputOptions> Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<OutputOptions>(end); - return o; - } -}; - -inline flatbuffers::Offset<OutputOptions> CreateOutputOptions( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> filtered_collections_annotation = 0, - flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> filtered_collections_classification = 0, - flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> filtered_collections_selection = 0) { - OutputOptionsBuilder builder_(_fbb); - builder_.add_filtered_collections_selection(filtered_collections_selection); - builder_.add_filtered_collections_classification(filtered_collections_classification); - builder_.add_filtered_collections_annotation(filtered_collections_annotation); - return builder_.Finish(); -} - -inline flatbuffers::Offset<OutputOptions> CreateOutputOptionsDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const std::vector<flatbuffers::Offset<flatbuffers::String>> *filtered_collections_annotation = nullptr, - const std::vector<flatbuffers::Offset<flatbuffers::String>> *filtered_collections_classification = nullptr, - const std::vector<flatbuffers::Offset<flatbuffers::String>> *filtered_collections_selection = nullptr) { - return libtextclassifier2::CreateOutputOptions( - _fbb, - filtered_collections_annotation ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*filtered_collections_annotation) : 0, - filtered_collections_classification ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*filtered_collections_classification) : 0, - filtered_collections_selection ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*filtered_collections_selection) : 0); -} - -flatbuffers::Offset<OutputOptions> CreateOutputOptions(flatbuffers::FlatBufferBuilder &_fbb, const OutputOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); - -struct ModelT : public flatbuffers::NativeTable { - typedef Model TableType; - std::string locales; - int32_t version; - std::string name; - std::unique_ptr<FeatureProcessorOptionsT> selection_feature_options; - std::unique_ptr<FeatureProcessorOptionsT> classification_feature_options; - std::vector<uint8_t> selection_model; - std::vector<uint8_t> classification_model; - std::vector<uint8_t> embedding_model; - std::unique_ptr<SelectionModelOptionsT> selection_options; - std::unique_ptr<ClassificationModelOptionsT> classification_options; - std::unique_ptr<RegexModelT> regex_model; - std::unique_ptr<DatetimeModelT> datetime_model; - std::unique_ptr<ModelTriggeringOptionsT> triggering_options; - ModeFlag enabled_modes; - bool snap_whitespace_selections; - std::unique_ptr<OutputOptionsT> output_options; - ModelT() - : version(0), - enabled_modes(ModeFlag_ALL), - snap_whitespace_selections(true) { - } -}; - -struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef ModelT NativeTableType; - enum { - VT_LOCALES = 4, - VT_VERSION = 6, - VT_NAME = 8, - VT_SELECTION_FEATURE_OPTIONS = 10, - VT_CLASSIFICATION_FEATURE_OPTIONS = 12, - VT_SELECTION_MODEL = 14, - VT_CLASSIFICATION_MODEL = 16, - VT_EMBEDDING_MODEL = 18, - VT_SELECTION_OPTIONS = 20, - VT_CLASSIFICATION_OPTIONS = 22, - VT_REGEX_MODEL = 24, - VT_DATETIME_MODEL = 26, - VT_TRIGGERING_OPTIONS = 28, - VT_ENABLED_MODES = 30, - VT_SNAP_WHITESPACE_SELECTIONS = 32, - VT_OUTPUT_OPTIONS = 34 - }; - const flatbuffers::String *locales() const { - return GetPointer<const flatbuffers::String *>(VT_LOCALES); - } - int32_t version() const { - return GetField<int32_t>(VT_VERSION, 0); - } - const flatbuffers::String *name() const { - return GetPointer<const flatbuffers::String *>(VT_NAME); - } - const FeatureProcessorOptions *selection_feature_options() const { - return GetPointer<const FeatureProcessorOptions *>(VT_SELECTION_FEATURE_OPTIONS); - } - const FeatureProcessorOptions *classification_feature_options() const { - return GetPointer<const FeatureProcessorOptions *>(VT_CLASSIFICATION_FEATURE_OPTIONS); - } - const flatbuffers::Vector<uint8_t> *selection_model() const { - return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_SELECTION_MODEL); - } - const flatbuffers::Vector<uint8_t> *classification_model() const { - return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CLASSIFICATION_MODEL); - } - const flatbuffers::Vector<uint8_t> *embedding_model() const { - return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_EMBEDDING_MODEL); - } - const SelectionModelOptions *selection_options() const { - return GetPointer<const SelectionModelOptions *>(VT_SELECTION_OPTIONS); - } - const ClassificationModelOptions *classification_options() const { - return GetPointer<const ClassificationModelOptions *>(VT_CLASSIFICATION_OPTIONS); - } - const RegexModel *regex_model() const { - return GetPointer<const RegexModel *>(VT_REGEX_MODEL); - } - const DatetimeModel *datetime_model() const { - return GetPointer<const DatetimeModel *>(VT_DATETIME_MODEL); - } - const ModelTriggeringOptions *triggering_options() const { - return GetPointer<const ModelTriggeringOptions *>(VT_TRIGGERING_OPTIONS); - } - ModeFlag enabled_modes() const { - return static_cast<ModeFlag>(GetField<int32_t>(VT_ENABLED_MODES, 7)); - } - bool snap_whitespace_selections() const { - return GetField<uint8_t>(VT_SNAP_WHITESPACE_SELECTIONS, 1) != 0; - } - const OutputOptions *output_options() const { - return GetPointer<const OutputOptions *>(VT_OUTPUT_OPTIONS); - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_LOCALES) && - verifier.Verify(locales()) && - VerifyField<int32_t>(verifier, VT_VERSION) && - VerifyOffset(verifier, VT_NAME) && - verifier.Verify(name()) && - VerifyOffset(verifier, VT_SELECTION_FEATURE_OPTIONS) && - verifier.VerifyTable(selection_feature_options()) && - VerifyOffset(verifier, VT_CLASSIFICATION_FEATURE_OPTIONS) && - verifier.VerifyTable(classification_feature_options()) && - VerifyOffset(verifier, VT_SELECTION_MODEL) && - verifier.Verify(selection_model()) && - VerifyOffset(verifier, VT_CLASSIFICATION_MODEL) && - verifier.Verify(classification_model()) && - VerifyOffset(verifier, VT_EMBEDDING_MODEL) && - verifier.Verify(embedding_model()) && - VerifyOffset(verifier, VT_SELECTION_OPTIONS) && - verifier.VerifyTable(selection_options()) && - VerifyOffset(verifier, VT_CLASSIFICATION_OPTIONS) && - verifier.VerifyTable(classification_options()) && - VerifyOffset(verifier, VT_REGEX_MODEL) && - verifier.VerifyTable(regex_model()) && - VerifyOffset(verifier, VT_DATETIME_MODEL) && - verifier.VerifyTable(datetime_model()) && - VerifyOffset(verifier, VT_TRIGGERING_OPTIONS) && - verifier.VerifyTable(triggering_options()) && - VerifyField<int32_t>(verifier, VT_ENABLED_MODES) && - VerifyField<uint8_t>(verifier, VT_SNAP_WHITESPACE_SELECTIONS) && - VerifyOffset(verifier, VT_OUTPUT_OPTIONS) && - verifier.VerifyTable(output_options()) && - verifier.EndTable(); - } - ModelT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(ModelT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset<Model> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ModelT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -}; - -struct ModelBuilder { - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_locales(flatbuffers::Offset<flatbuffers::String> locales) { - fbb_.AddOffset(Model::VT_LOCALES, locales); - } - void add_version(int32_t version) { - fbb_.AddElement<int32_t>(Model::VT_VERSION, version, 0); - } - void add_name(flatbuffers::Offset<flatbuffers::String> name) { - fbb_.AddOffset(Model::VT_NAME, name); - } - void add_selection_feature_options(flatbuffers::Offset<FeatureProcessorOptions> selection_feature_options) { - fbb_.AddOffset(Model::VT_SELECTION_FEATURE_OPTIONS, selection_feature_options); - } - void add_classification_feature_options(flatbuffers::Offset<FeatureProcessorOptions> classification_feature_options) { - fbb_.AddOffset(Model::VT_CLASSIFICATION_FEATURE_OPTIONS, classification_feature_options); - } - void add_selection_model(flatbuffers::Offset<flatbuffers::Vector<uint8_t>> selection_model) { - fbb_.AddOffset(Model::VT_SELECTION_MODEL, selection_model); - } - void add_classification_model(flatbuffers::Offset<flatbuffers::Vector<uint8_t>> classification_model) { - fbb_.AddOffset(Model::VT_CLASSIFICATION_MODEL, classification_model); - } - void add_embedding_model(flatbuffers::Offset<flatbuffers::Vector<uint8_t>> embedding_model) { - fbb_.AddOffset(Model::VT_EMBEDDING_MODEL, embedding_model); - } - void add_selection_options(flatbuffers::Offset<SelectionModelOptions> selection_options) { - fbb_.AddOffset(Model::VT_SELECTION_OPTIONS, selection_options); - } - void add_classification_options(flatbuffers::Offset<ClassificationModelOptions> classification_options) { - fbb_.AddOffset(Model::VT_CLASSIFICATION_OPTIONS, classification_options); - } - void add_regex_model(flatbuffers::Offset<RegexModel> regex_model) { - fbb_.AddOffset(Model::VT_REGEX_MODEL, regex_model); - } - void add_datetime_model(flatbuffers::Offset<DatetimeModel> datetime_model) { - fbb_.AddOffset(Model::VT_DATETIME_MODEL, datetime_model); - } - void add_triggering_options(flatbuffers::Offset<ModelTriggeringOptions> triggering_options) { - fbb_.AddOffset(Model::VT_TRIGGERING_OPTIONS, triggering_options); - } - void add_enabled_modes(ModeFlag enabled_modes) { - fbb_.AddElement<int32_t>(Model::VT_ENABLED_MODES, static_cast<int32_t>(enabled_modes), 7); - } - void add_snap_whitespace_selections(bool snap_whitespace_selections) { - fbb_.AddElement<uint8_t>(Model::VT_SNAP_WHITESPACE_SELECTIONS, static_cast<uint8_t>(snap_whitespace_selections), 1); - } - void add_output_options(flatbuffers::Offset<OutputOptions> output_options) { - fbb_.AddOffset(Model::VT_OUTPUT_OPTIONS, output_options); - } - explicit ModelBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - ModelBuilder &operator=(const ModelBuilder &); - flatbuffers::Offset<Model> Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<Model>(end); - return o; - } -}; - -inline flatbuffers::Offset<Model> CreateModel( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset<flatbuffers::String> locales = 0, - int32_t version = 0, - flatbuffers::Offset<flatbuffers::String> name = 0, - flatbuffers::Offset<FeatureProcessorOptions> selection_feature_options = 0, - flatbuffers::Offset<FeatureProcessorOptions> classification_feature_options = 0, - flatbuffers::Offset<flatbuffers::Vector<uint8_t>> selection_model = 0, - flatbuffers::Offset<flatbuffers::Vector<uint8_t>> classification_model = 0, - flatbuffers::Offset<flatbuffers::Vector<uint8_t>> embedding_model = 0, - flatbuffers::Offset<SelectionModelOptions> selection_options = 0, - flatbuffers::Offset<ClassificationModelOptions> classification_options = 0, - flatbuffers::Offset<RegexModel> regex_model = 0, - flatbuffers::Offset<DatetimeModel> datetime_model = 0, - flatbuffers::Offset<ModelTriggeringOptions> triggering_options = 0, - ModeFlag enabled_modes = ModeFlag_ALL, - bool snap_whitespace_selections = true, - flatbuffers::Offset<OutputOptions> output_options = 0) { - ModelBuilder builder_(_fbb); - builder_.add_output_options(output_options); - builder_.add_enabled_modes(enabled_modes); - builder_.add_triggering_options(triggering_options); - builder_.add_datetime_model(datetime_model); - builder_.add_regex_model(regex_model); - builder_.add_classification_options(classification_options); - builder_.add_selection_options(selection_options); - builder_.add_embedding_model(embedding_model); - builder_.add_classification_model(classification_model); - builder_.add_selection_model(selection_model); - builder_.add_classification_feature_options(classification_feature_options); - builder_.add_selection_feature_options(selection_feature_options); - builder_.add_name(name); - builder_.add_version(version); - builder_.add_locales(locales); - builder_.add_snap_whitespace_selections(snap_whitespace_selections); - return builder_.Finish(); -} - -inline flatbuffers::Offset<Model> CreateModelDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const char *locales = nullptr, - int32_t version = 0, - const char *name = nullptr, - flatbuffers::Offset<FeatureProcessorOptions> selection_feature_options = 0, - flatbuffers::Offset<FeatureProcessorOptions> classification_feature_options = 0, - const std::vector<uint8_t> *selection_model = nullptr, - const std::vector<uint8_t> *classification_model = nullptr, - const std::vector<uint8_t> *embedding_model = nullptr, - flatbuffers::Offset<SelectionModelOptions> selection_options = 0, - flatbuffers::Offset<ClassificationModelOptions> classification_options = 0, - flatbuffers::Offset<RegexModel> regex_model = 0, - flatbuffers::Offset<DatetimeModel> datetime_model = 0, - flatbuffers::Offset<ModelTriggeringOptions> triggering_options = 0, - ModeFlag enabled_modes = ModeFlag_ALL, - bool snap_whitespace_selections = true, - flatbuffers::Offset<OutputOptions> output_options = 0) { - return libtextclassifier2::CreateModel( - _fbb, - locales ? _fbb.CreateString(locales) : 0, - version, - name ? _fbb.CreateString(name) : 0, - selection_feature_options, - classification_feature_options, - selection_model ? _fbb.CreateVector<uint8_t>(*selection_model) : 0, - classification_model ? _fbb.CreateVector<uint8_t>(*classification_model) : 0, - embedding_model ? _fbb.CreateVector<uint8_t>(*embedding_model) : 0, - selection_options, - classification_options, - regex_model, - datetime_model, - triggering_options, - enabled_modes, - snap_whitespace_selections, - output_options); -} - -flatbuffers::Offset<Model> CreateModel(flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); - -struct TokenizationCodepointRangeT : public flatbuffers::NativeTable { - typedef TokenizationCodepointRange TableType; - int32_t start; - int32_t end; - libtextclassifier2::TokenizationCodepointRange_::Role role; - int32_t script_id; - TokenizationCodepointRangeT() - : start(0), - end(0), - role(libtextclassifier2::TokenizationCodepointRange_::Role_DEFAULT_ROLE), - script_id(0) { - } -}; - -struct TokenizationCodepointRange FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef TokenizationCodepointRangeT NativeTableType; - enum { - VT_START = 4, - VT_END = 6, - VT_ROLE = 8, - VT_SCRIPT_ID = 10 - }; - int32_t start() const { - return GetField<int32_t>(VT_START, 0); - } - int32_t end() const { - return GetField<int32_t>(VT_END, 0); - } - libtextclassifier2::TokenizationCodepointRange_::Role role() const { - return static_cast<libtextclassifier2::TokenizationCodepointRange_::Role>(GetField<int32_t>(VT_ROLE, 0)); - } - int32_t script_id() const { - return GetField<int32_t>(VT_SCRIPT_ID, 0); - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyField<int32_t>(verifier, VT_START) && - VerifyField<int32_t>(verifier, VT_END) && - VerifyField<int32_t>(verifier, VT_ROLE) && - VerifyField<int32_t>(verifier, VT_SCRIPT_ID) && - verifier.EndTable(); - } - TokenizationCodepointRangeT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(TokenizationCodepointRangeT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset<TokenizationCodepointRange> Pack(flatbuffers::FlatBufferBuilder &_fbb, const TokenizationCodepointRangeT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -}; - -struct TokenizationCodepointRangeBuilder { - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_start(int32_t start) { - fbb_.AddElement<int32_t>(TokenizationCodepointRange::VT_START, start, 0); - } - void add_end(int32_t end) { - fbb_.AddElement<int32_t>(TokenizationCodepointRange::VT_END, end, 0); - } - void add_role(libtextclassifier2::TokenizationCodepointRange_::Role role) { - fbb_.AddElement<int32_t>(TokenizationCodepointRange::VT_ROLE, static_cast<int32_t>(role), 0); - } - void add_script_id(int32_t script_id) { - fbb_.AddElement<int32_t>(TokenizationCodepointRange::VT_SCRIPT_ID, script_id, 0); - } - explicit TokenizationCodepointRangeBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - TokenizationCodepointRangeBuilder &operator=(const TokenizationCodepointRangeBuilder &); - flatbuffers::Offset<TokenizationCodepointRange> Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<TokenizationCodepointRange>(end); - return o; - } -}; - -inline flatbuffers::Offset<TokenizationCodepointRange> CreateTokenizationCodepointRange( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t start = 0, - int32_t end = 0, - libtextclassifier2::TokenizationCodepointRange_::Role role = libtextclassifier2::TokenizationCodepointRange_::Role_DEFAULT_ROLE, - int32_t script_id = 0) { - TokenizationCodepointRangeBuilder builder_(_fbb); - builder_.add_script_id(script_id); - builder_.add_role(role); - builder_.add_end(end); - builder_.add_start(start); - return builder_.Finish(); -} - -flatbuffers::Offset<TokenizationCodepointRange> CreateTokenizationCodepointRange(flatbuffers::FlatBufferBuilder &_fbb, const TokenizationCodepointRangeT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); - -namespace FeatureProcessorOptions_ { - -struct CodepointRangeT : public flatbuffers::NativeTable { - typedef CodepointRange TableType; - int32_t start; - int32_t end; - CodepointRangeT() - : start(0), - end(0) { - } -}; - -struct CodepointRange FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef CodepointRangeT NativeTableType; - enum { - VT_START = 4, - VT_END = 6 - }; - int32_t start() const { - return GetField<int32_t>(VT_START, 0); - } - int32_t end() const { - return GetField<int32_t>(VT_END, 0); - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyField<int32_t>(verifier, VT_START) && - VerifyField<int32_t>(verifier, VT_END) && - verifier.EndTable(); - } - CodepointRangeT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(CodepointRangeT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset<CodepointRange> Pack(flatbuffers::FlatBufferBuilder &_fbb, const CodepointRangeT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -}; - -struct CodepointRangeBuilder { - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_start(int32_t start) { - fbb_.AddElement<int32_t>(CodepointRange::VT_START, start, 0); - } - void add_end(int32_t end) { - fbb_.AddElement<int32_t>(CodepointRange::VT_END, end, 0); - } - explicit CodepointRangeBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - CodepointRangeBuilder &operator=(const CodepointRangeBuilder &); - flatbuffers::Offset<CodepointRange> Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<CodepointRange>(end); - return o; - } -}; - -inline flatbuffers::Offset<CodepointRange> CreateCodepointRange( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t start = 0, - int32_t end = 0) { - CodepointRangeBuilder builder_(_fbb); - builder_.add_end(end); - builder_.add_start(start); - return builder_.Finish(); -} - -flatbuffers::Offset<CodepointRange> CreateCodepointRange(flatbuffers::FlatBufferBuilder &_fbb, const CodepointRangeT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); - -struct BoundsSensitiveFeaturesT : public flatbuffers::NativeTable { - typedef BoundsSensitiveFeatures TableType; - bool enabled; - int32_t num_tokens_before; - int32_t num_tokens_inside_left; - int32_t num_tokens_inside_right; - int32_t num_tokens_after; - bool include_inside_bag; - bool include_inside_length; - bool score_single_token_spans_as_zero; - BoundsSensitiveFeaturesT() - : enabled(false), - num_tokens_before(0), - num_tokens_inside_left(0), - num_tokens_inside_right(0), - num_tokens_after(0), - include_inside_bag(false), - include_inside_length(false), - score_single_token_spans_as_zero(false) { - } -}; - -struct BoundsSensitiveFeatures FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef BoundsSensitiveFeaturesT NativeTableType; - enum { - VT_ENABLED = 4, - VT_NUM_TOKENS_BEFORE = 6, - VT_NUM_TOKENS_INSIDE_LEFT = 8, - VT_NUM_TOKENS_INSIDE_RIGHT = 10, - VT_NUM_TOKENS_AFTER = 12, - VT_INCLUDE_INSIDE_BAG = 14, - VT_INCLUDE_INSIDE_LENGTH = 16, - VT_SCORE_SINGLE_TOKEN_SPANS_AS_ZERO = 18 - }; - bool enabled() const { - return GetField<uint8_t>(VT_ENABLED, 0) != 0; - } - int32_t num_tokens_before() const { - return GetField<int32_t>(VT_NUM_TOKENS_BEFORE, 0); - } - int32_t num_tokens_inside_left() const { - return GetField<int32_t>(VT_NUM_TOKENS_INSIDE_LEFT, 0); - } - int32_t num_tokens_inside_right() const { - return GetField<int32_t>(VT_NUM_TOKENS_INSIDE_RIGHT, 0); - } - int32_t num_tokens_after() const { - return GetField<int32_t>(VT_NUM_TOKENS_AFTER, 0); - } - bool include_inside_bag() const { - return GetField<uint8_t>(VT_INCLUDE_INSIDE_BAG, 0) != 0; - } - bool include_inside_length() const { - return GetField<uint8_t>(VT_INCLUDE_INSIDE_LENGTH, 0) != 0; - } - bool score_single_token_spans_as_zero() const { - return GetField<uint8_t>(VT_SCORE_SINGLE_TOKEN_SPANS_AS_ZERO, 0) != 0; - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyField<uint8_t>(verifier, VT_ENABLED) && - VerifyField<int32_t>(verifier, VT_NUM_TOKENS_BEFORE) && - VerifyField<int32_t>(verifier, VT_NUM_TOKENS_INSIDE_LEFT) && - VerifyField<int32_t>(verifier, VT_NUM_TOKENS_INSIDE_RIGHT) && - VerifyField<int32_t>(verifier, VT_NUM_TOKENS_AFTER) && - VerifyField<uint8_t>(verifier, VT_INCLUDE_INSIDE_BAG) && - VerifyField<uint8_t>(verifier, VT_INCLUDE_INSIDE_LENGTH) && - VerifyField<uint8_t>(verifier, VT_SCORE_SINGLE_TOKEN_SPANS_AS_ZERO) && - verifier.EndTable(); - } - BoundsSensitiveFeaturesT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(BoundsSensitiveFeaturesT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset<BoundsSensitiveFeatures> Pack(flatbuffers::FlatBufferBuilder &_fbb, const BoundsSensitiveFeaturesT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -}; - -struct BoundsSensitiveFeaturesBuilder { - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_enabled(bool enabled) { - fbb_.AddElement<uint8_t>(BoundsSensitiveFeatures::VT_ENABLED, static_cast<uint8_t>(enabled), 0); - } - void add_num_tokens_before(int32_t num_tokens_before) { - fbb_.AddElement<int32_t>(BoundsSensitiveFeatures::VT_NUM_TOKENS_BEFORE, num_tokens_before, 0); - } - void add_num_tokens_inside_left(int32_t num_tokens_inside_left) { - fbb_.AddElement<int32_t>(BoundsSensitiveFeatures::VT_NUM_TOKENS_INSIDE_LEFT, num_tokens_inside_left, 0); - } - void add_num_tokens_inside_right(int32_t num_tokens_inside_right) { - fbb_.AddElement<int32_t>(BoundsSensitiveFeatures::VT_NUM_TOKENS_INSIDE_RIGHT, num_tokens_inside_right, 0); - } - void add_num_tokens_after(int32_t num_tokens_after) { - fbb_.AddElement<int32_t>(BoundsSensitiveFeatures::VT_NUM_TOKENS_AFTER, num_tokens_after, 0); - } - void add_include_inside_bag(bool include_inside_bag) { - fbb_.AddElement<uint8_t>(BoundsSensitiveFeatures::VT_INCLUDE_INSIDE_BAG, static_cast<uint8_t>(include_inside_bag), 0); - } - void add_include_inside_length(bool include_inside_length) { - fbb_.AddElement<uint8_t>(BoundsSensitiveFeatures::VT_INCLUDE_INSIDE_LENGTH, static_cast<uint8_t>(include_inside_length), 0); - } - void add_score_single_token_spans_as_zero(bool score_single_token_spans_as_zero) { - fbb_.AddElement<uint8_t>(BoundsSensitiveFeatures::VT_SCORE_SINGLE_TOKEN_SPANS_AS_ZERO, static_cast<uint8_t>(score_single_token_spans_as_zero), 0); - } - explicit BoundsSensitiveFeaturesBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - BoundsSensitiveFeaturesBuilder &operator=(const BoundsSensitiveFeaturesBuilder &); - flatbuffers::Offset<BoundsSensitiveFeatures> Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<BoundsSensitiveFeatures>(end); - return o; - } -}; - -inline flatbuffers::Offset<BoundsSensitiveFeatures> CreateBoundsSensitiveFeatures( - flatbuffers::FlatBufferBuilder &_fbb, - bool enabled = false, - int32_t num_tokens_before = 0, - int32_t num_tokens_inside_left = 0, - int32_t num_tokens_inside_right = 0, - int32_t num_tokens_after = 0, - bool include_inside_bag = false, - bool include_inside_length = false, - bool score_single_token_spans_as_zero = false) { - BoundsSensitiveFeaturesBuilder builder_(_fbb); - builder_.add_num_tokens_after(num_tokens_after); - builder_.add_num_tokens_inside_right(num_tokens_inside_right); - builder_.add_num_tokens_inside_left(num_tokens_inside_left); - builder_.add_num_tokens_before(num_tokens_before); - builder_.add_score_single_token_spans_as_zero(score_single_token_spans_as_zero); - builder_.add_include_inside_length(include_inside_length); - builder_.add_include_inside_bag(include_inside_bag); - builder_.add_enabled(enabled); - return builder_.Finish(); -} - -flatbuffers::Offset<BoundsSensitiveFeatures> CreateBoundsSensitiveFeatures(flatbuffers::FlatBufferBuilder &_fbb, const BoundsSensitiveFeaturesT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); - -struct AlternativeCollectionMapEntryT : public flatbuffers::NativeTable { - typedef AlternativeCollectionMapEntry TableType; - std::string key; - std::string value; - AlternativeCollectionMapEntryT() { - } -}; - -struct AlternativeCollectionMapEntry FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef AlternativeCollectionMapEntryT NativeTableType; - enum { - VT_KEY = 4, - VT_VALUE = 6 - }; - const flatbuffers::String *key() const { - return GetPointer<const flatbuffers::String *>(VT_KEY); - } - const flatbuffers::String *value() const { - return GetPointer<const flatbuffers::String *>(VT_VALUE); - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_KEY) && - verifier.Verify(key()) && - VerifyOffset(verifier, VT_VALUE) && - verifier.Verify(value()) && - verifier.EndTable(); - } - AlternativeCollectionMapEntryT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(AlternativeCollectionMapEntryT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset<AlternativeCollectionMapEntry> Pack(flatbuffers::FlatBufferBuilder &_fbb, const AlternativeCollectionMapEntryT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -}; - -struct AlternativeCollectionMapEntryBuilder { - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_key(flatbuffers::Offset<flatbuffers::String> key) { - fbb_.AddOffset(AlternativeCollectionMapEntry::VT_KEY, key); - } - void add_value(flatbuffers::Offset<flatbuffers::String> value) { - fbb_.AddOffset(AlternativeCollectionMapEntry::VT_VALUE, value); - } - explicit AlternativeCollectionMapEntryBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - AlternativeCollectionMapEntryBuilder &operator=(const AlternativeCollectionMapEntryBuilder &); - flatbuffers::Offset<AlternativeCollectionMapEntry> Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<AlternativeCollectionMapEntry>(end); - return o; - } -}; - -inline flatbuffers::Offset<AlternativeCollectionMapEntry> CreateAlternativeCollectionMapEntry( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset<flatbuffers::String> key = 0, - flatbuffers::Offset<flatbuffers::String> value = 0) { - AlternativeCollectionMapEntryBuilder builder_(_fbb); - builder_.add_value(value); - builder_.add_key(key); - return builder_.Finish(); -} - -inline flatbuffers::Offset<AlternativeCollectionMapEntry> CreateAlternativeCollectionMapEntryDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const char *key = nullptr, - const char *value = nullptr) { - return libtextclassifier2::FeatureProcessorOptions_::CreateAlternativeCollectionMapEntry( - _fbb, - key ? _fbb.CreateString(key) : 0, - value ? _fbb.CreateString(value) : 0); -} - -flatbuffers::Offset<AlternativeCollectionMapEntry> CreateAlternativeCollectionMapEntry(flatbuffers::FlatBufferBuilder &_fbb, const AlternativeCollectionMapEntryT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); - -} // namespace FeatureProcessorOptions_ - -struct FeatureProcessorOptionsT : public flatbuffers::NativeTable { - typedef FeatureProcessorOptions TableType; - int32_t num_buckets; - int32_t embedding_size; - int32_t embedding_quantization_bits; - int32_t context_size; - int32_t max_selection_span; - std::vector<int32_t> chargram_orders; - int32_t max_word_length; - bool unicode_aware_features; - bool extract_case_feature; - bool extract_selection_mask_feature; - std::vector<std::string> regexp_feature; - bool remap_digits; - bool lowercase_tokens; - bool selection_reduced_output_space; - std::vector<std::string> collections; - int32_t default_collection; - bool only_use_line_with_click; - bool split_tokens_on_selection_boundaries; - std::vector<std::unique_ptr<TokenizationCodepointRangeT>> tokenization_codepoint_config; - libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method; - bool snap_label_span_boundaries_to_containing_tokens; - std::vector<std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::CodepointRangeT>> supported_codepoint_ranges; - std::vector<std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::CodepointRangeT>> internal_tokenizer_codepoint_ranges; - float min_supported_codepoint_ratio; - int32_t feature_version; - libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type; - bool icu_preserve_whitespace_tokens; - std::vector<int32_t> ignored_span_boundary_codepoints; - std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeaturesT> bounds_sensitive_features; - std::vector<std::string> allowed_chargrams; - bool tokenize_on_script_change; - FeatureProcessorOptionsT() - : num_buckets(-1), - embedding_size(-1), - embedding_quantization_bits(8), - context_size(-1), - max_selection_span(-1), - max_word_length(20), - unicode_aware_features(false), - extract_case_feature(false), - extract_selection_mask_feature(false), - remap_digits(false), - lowercase_tokens(false), - selection_reduced_output_space(true), - default_collection(-1), - only_use_line_with_click(false), - split_tokens_on_selection_boundaries(false), - center_token_selection_method(libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD), - snap_label_span_boundaries_to_containing_tokens(false), - min_supported_codepoint_ratio(0.0f), - feature_version(0), - tokenization_type(libtextclassifier2::FeatureProcessorOptions_::TokenizationType_INTERNAL_TOKENIZER), - icu_preserve_whitespace_tokens(false), - tokenize_on_script_change(false) { - } -}; - -struct FeatureProcessorOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef FeatureProcessorOptionsT NativeTableType; - enum { - VT_NUM_BUCKETS = 4, - VT_EMBEDDING_SIZE = 6, - VT_EMBEDDING_QUANTIZATION_BITS = 8, - VT_CONTEXT_SIZE = 10, - VT_MAX_SELECTION_SPAN = 12, - VT_CHARGRAM_ORDERS = 14, - VT_MAX_WORD_LENGTH = 16, - VT_UNICODE_AWARE_FEATURES = 18, - VT_EXTRACT_CASE_FEATURE = 20, - VT_EXTRACT_SELECTION_MASK_FEATURE = 22, - VT_REGEXP_FEATURE = 24, - VT_REMAP_DIGITS = 26, - VT_LOWERCASE_TOKENS = 28, - VT_SELECTION_REDUCED_OUTPUT_SPACE = 30, - VT_COLLECTIONS = 32, - VT_DEFAULT_COLLECTION = 34, - VT_ONLY_USE_LINE_WITH_CLICK = 36, - VT_SPLIT_TOKENS_ON_SELECTION_BOUNDARIES = 38, - VT_TOKENIZATION_CODEPOINT_CONFIG = 40, - VT_CENTER_TOKEN_SELECTION_METHOD = 42, - VT_SNAP_LABEL_SPAN_BOUNDARIES_TO_CONTAINING_TOKENS = 44, - VT_SUPPORTED_CODEPOINT_RANGES = 46, - VT_INTERNAL_TOKENIZER_CODEPOINT_RANGES = 48, - VT_MIN_SUPPORTED_CODEPOINT_RATIO = 50, - VT_FEATURE_VERSION = 52, - VT_TOKENIZATION_TYPE = 54, - VT_ICU_PRESERVE_WHITESPACE_TOKENS = 56, - VT_IGNORED_SPAN_BOUNDARY_CODEPOINTS = 58, - VT_BOUNDS_SENSITIVE_FEATURES = 60, - VT_ALLOWED_CHARGRAMS = 62, - VT_TOKENIZE_ON_SCRIPT_CHANGE = 64 - }; - int32_t num_buckets() const { - return GetField<int32_t>(VT_NUM_BUCKETS, -1); - } - int32_t embedding_size() const { - return GetField<int32_t>(VT_EMBEDDING_SIZE, -1); - } - int32_t embedding_quantization_bits() const { - return GetField<int32_t>(VT_EMBEDDING_QUANTIZATION_BITS, 8); - } - int32_t context_size() const { - return GetField<int32_t>(VT_CONTEXT_SIZE, -1); - } - int32_t max_selection_span() const { - return GetField<int32_t>(VT_MAX_SELECTION_SPAN, -1); - } - const flatbuffers::Vector<int32_t> *chargram_orders() const { - return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_CHARGRAM_ORDERS); - } - int32_t max_word_length() const { - return GetField<int32_t>(VT_MAX_WORD_LENGTH, 20); - } - bool unicode_aware_features() const { - return GetField<uint8_t>(VT_UNICODE_AWARE_FEATURES, 0) != 0; - } - bool extract_case_feature() const { - return GetField<uint8_t>(VT_EXTRACT_CASE_FEATURE, 0) != 0; - } - bool extract_selection_mask_feature() const { - return GetField<uint8_t>(VT_EXTRACT_SELECTION_MASK_FEATURE, 0) != 0; - } - const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *regexp_feature() const { - return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_REGEXP_FEATURE); - } - bool remap_digits() const { - return GetField<uint8_t>(VT_REMAP_DIGITS, 0) != 0; - } - bool lowercase_tokens() const { - return GetField<uint8_t>(VT_LOWERCASE_TOKENS, 0) != 0; - } - bool selection_reduced_output_space() const { - return GetField<uint8_t>(VT_SELECTION_REDUCED_OUTPUT_SPACE, 1) != 0; - } - const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *collections() const { - return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_COLLECTIONS); - } - int32_t default_collection() const { - return GetField<int32_t>(VT_DEFAULT_COLLECTION, -1); - } - bool only_use_line_with_click() const { - return GetField<uint8_t>(VT_ONLY_USE_LINE_WITH_CLICK, 0) != 0; - } - bool split_tokens_on_selection_boundaries() const { - return GetField<uint8_t>(VT_SPLIT_TOKENS_ON_SELECTION_BOUNDARIES, 0) != 0; - } - const flatbuffers::Vector<flatbuffers::Offset<TokenizationCodepointRange>> *tokenization_codepoint_config() const { - return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<TokenizationCodepointRange>> *>(VT_TOKENIZATION_CODEPOINT_CONFIG); - } - libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method() const { - return static_cast<libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod>(GetField<int32_t>(VT_CENTER_TOKEN_SELECTION_METHOD, 0)); - } - bool snap_label_span_boundaries_to_containing_tokens() const { - return GetField<uint8_t>(VT_SNAP_LABEL_SPAN_BOUNDARIES_TO_CONTAINING_TOKENS, 0) != 0; - } - const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *supported_codepoint_ranges() const { - return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *>(VT_SUPPORTED_CODEPOINT_RANGES); - } - const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *internal_tokenizer_codepoint_ranges() const { - return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *>(VT_INTERNAL_TOKENIZER_CODEPOINT_RANGES); - } - float min_supported_codepoint_ratio() const { - return GetField<float>(VT_MIN_SUPPORTED_CODEPOINT_RATIO, 0.0f); - } - int32_t feature_version() const { - return GetField<int32_t>(VT_FEATURE_VERSION, 0); - } - libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type() const { - return static_cast<libtextclassifier2::FeatureProcessorOptions_::TokenizationType>(GetField<int32_t>(VT_TOKENIZATION_TYPE, 1)); - } - bool icu_preserve_whitespace_tokens() const { - return GetField<uint8_t>(VT_ICU_PRESERVE_WHITESPACE_TOKENS, 0) != 0; - } - const flatbuffers::Vector<int32_t> *ignored_span_boundary_codepoints() const { - return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_IGNORED_SPAN_BOUNDARY_CODEPOINTS); - } - const libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures *bounds_sensitive_features() const { - return GetPointer<const libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures *>(VT_BOUNDS_SENSITIVE_FEATURES); - } - const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *allowed_chargrams() const { - return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_ALLOWED_CHARGRAMS); - } - bool tokenize_on_script_change() const { - return GetField<uint8_t>(VT_TOKENIZE_ON_SCRIPT_CHANGE, 0) != 0; - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyField<int32_t>(verifier, VT_NUM_BUCKETS) && - VerifyField<int32_t>(verifier, VT_EMBEDDING_SIZE) && - VerifyField<int32_t>(verifier, VT_EMBEDDING_QUANTIZATION_BITS) && - VerifyField<int32_t>(verifier, VT_CONTEXT_SIZE) && - VerifyField<int32_t>(verifier, VT_MAX_SELECTION_SPAN) && - VerifyOffset(verifier, VT_CHARGRAM_ORDERS) && - verifier.Verify(chargram_orders()) && - VerifyField<int32_t>(verifier, VT_MAX_WORD_LENGTH) && - VerifyField<uint8_t>(verifier, VT_UNICODE_AWARE_FEATURES) && - VerifyField<uint8_t>(verifier, VT_EXTRACT_CASE_FEATURE) && - VerifyField<uint8_t>(verifier, VT_EXTRACT_SELECTION_MASK_FEATURE) && - VerifyOffset(verifier, VT_REGEXP_FEATURE) && - verifier.Verify(regexp_feature()) && - verifier.VerifyVectorOfStrings(regexp_feature()) && - VerifyField<uint8_t>(verifier, VT_REMAP_DIGITS) && - VerifyField<uint8_t>(verifier, VT_LOWERCASE_TOKENS) && - VerifyField<uint8_t>(verifier, VT_SELECTION_REDUCED_OUTPUT_SPACE) && - VerifyOffset(verifier, VT_COLLECTIONS) && - verifier.Verify(collections()) && - verifier.VerifyVectorOfStrings(collections()) && - VerifyField<int32_t>(verifier, VT_DEFAULT_COLLECTION) && - VerifyField<uint8_t>(verifier, VT_ONLY_USE_LINE_WITH_CLICK) && - VerifyField<uint8_t>(verifier, VT_SPLIT_TOKENS_ON_SELECTION_BOUNDARIES) && - VerifyOffset(verifier, VT_TOKENIZATION_CODEPOINT_CONFIG) && - verifier.Verify(tokenization_codepoint_config()) && - verifier.VerifyVectorOfTables(tokenization_codepoint_config()) && - VerifyField<int32_t>(verifier, VT_CENTER_TOKEN_SELECTION_METHOD) && - VerifyField<uint8_t>(verifier, VT_SNAP_LABEL_SPAN_BOUNDARIES_TO_CONTAINING_TOKENS) && - VerifyOffset(verifier, VT_SUPPORTED_CODEPOINT_RANGES) && - verifier.Verify(supported_codepoint_ranges()) && - verifier.VerifyVectorOfTables(supported_codepoint_ranges()) && - VerifyOffset(verifier, VT_INTERNAL_TOKENIZER_CODEPOINT_RANGES) && - verifier.Verify(internal_tokenizer_codepoint_ranges()) && - verifier.VerifyVectorOfTables(internal_tokenizer_codepoint_ranges()) && - VerifyField<float>(verifier, VT_MIN_SUPPORTED_CODEPOINT_RATIO) && - VerifyField<int32_t>(verifier, VT_FEATURE_VERSION) && - VerifyField<int32_t>(verifier, VT_TOKENIZATION_TYPE) && - VerifyField<uint8_t>(verifier, VT_ICU_PRESERVE_WHITESPACE_TOKENS) && - VerifyOffset(verifier, VT_IGNORED_SPAN_BOUNDARY_CODEPOINTS) && - verifier.Verify(ignored_span_boundary_codepoints()) && - VerifyOffset(verifier, VT_BOUNDS_SENSITIVE_FEATURES) && - verifier.VerifyTable(bounds_sensitive_features()) && - VerifyOffset(verifier, VT_ALLOWED_CHARGRAMS) && - verifier.Verify(allowed_chargrams()) && - verifier.VerifyVectorOfStrings(allowed_chargrams()) && - VerifyField<uint8_t>(verifier, VT_TOKENIZE_ON_SCRIPT_CHANGE) && - verifier.EndTable(); - } - FeatureProcessorOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(FeatureProcessorOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset<FeatureProcessorOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const FeatureProcessorOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -}; - -struct FeatureProcessorOptionsBuilder { - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_num_buckets(int32_t num_buckets) { - fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_NUM_BUCKETS, num_buckets, -1); - } - void add_embedding_size(int32_t embedding_size) { - fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_EMBEDDING_SIZE, embedding_size, -1); - } - void add_embedding_quantization_bits(int32_t embedding_quantization_bits) { - fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_EMBEDDING_QUANTIZATION_BITS, embedding_quantization_bits, 8); - } - void add_context_size(int32_t context_size) { - fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_CONTEXT_SIZE, context_size, -1); - } - void add_max_selection_span(int32_t max_selection_span) { - fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_MAX_SELECTION_SPAN, max_selection_span, -1); - } - void add_chargram_orders(flatbuffers::Offset<flatbuffers::Vector<int32_t>> chargram_orders) { - fbb_.AddOffset(FeatureProcessorOptions::VT_CHARGRAM_ORDERS, chargram_orders); - } - void add_max_word_length(int32_t max_word_length) { - fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_MAX_WORD_LENGTH, max_word_length, 20); - } - void add_unicode_aware_features(bool unicode_aware_features) { - fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_UNICODE_AWARE_FEATURES, static_cast<uint8_t>(unicode_aware_features), 0); - } - void add_extract_case_feature(bool extract_case_feature) { - fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_EXTRACT_CASE_FEATURE, static_cast<uint8_t>(extract_case_feature), 0); - } - void add_extract_selection_mask_feature(bool extract_selection_mask_feature) { - fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_EXTRACT_SELECTION_MASK_FEATURE, static_cast<uint8_t>(extract_selection_mask_feature), 0); - } - void add_regexp_feature(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> regexp_feature) { - fbb_.AddOffset(FeatureProcessorOptions::VT_REGEXP_FEATURE, regexp_feature); - } - void add_remap_digits(bool remap_digits) { - fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_REMAP_DIGITS, static_cast<uint8_t>(remap_digits), 0); - } - void add_lowercase_tokens(bool lowercase_tokens) { - fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_LOWERCASE_TOKENS, static_cast<uint8_t>(lowercase_tokens), 0); - } - void add_selection_reduced_output_space(bool selection_reduced_output_space) { - fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_SELECTION_REDUCED_OUTPUT_SPACE, static_cast<uint8_t>(selection_reduced_output_space), 1); - } - void add_collections(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> collections) { - fbb_.AddOffset(FeatureProcessorOptions::VT_COLLECTIONS, collections); - } - void add_default_collection(int32_t default_collection) { - fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_DEFAULT_COLLECTION, default_collection, -1); - } - void add_only_use_line_with_click(bool only_use_line_with_click) { - fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_ONLY_USE_LINE_WITH_CLICK, static_cast<uint8_t>(only_use_line_with_click), 0); - } - void add_split_tokens_on_selection_boundaries(bool split_tokens_on_selection_boundaries) { - fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_SPLIT_TOKENS_ON_SELECTION_BOUNDARIES, static_cast<uint8_t>(split_tokens_on_selection_boundaries), 0); - } - void add_tokenization_codepoint_config(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TokenizationCodepointRange>>> tokenization_codepoint_config) { - fbb_.AddOffset(FeatureProcessorOptions::VT_TOKENIZATION_CODEPOINT_CONFIG, tokenization_codepoint_config); - } - void add_center_token_selection_method(libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method) { - fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_CENTER_TOKEN_SELECTION_METHOD, static_cast<int32_t>(center_token_selection_method), 0); - } - void add_snap_label_span_boundaries_to_containing_tokens(bool snap_label_span_boundaries_to_containing_tokens) { - fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_SNAP_LABEL_SPAN_BOUNDARIES_TO_CONTAINING_TOKENS, static_cast<uint8_t>(snap_label_span_boundaries_to_containing_tokens), 0); - } - void add_supported_codepoint_ranges(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>> supported_codepoint_ranges) { - fbb_.AddOffset(FeatureProcessorOptions::VT_SUPPORTED_CODEPOINT_RANGES, supported_codepoint_ranges); - } - void add_internal_tokenizer_codepoint_ranges(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>> internal_tokenizer_codepoint_ranges) { - fbb_.AddOffset(FeatureProcessorOptions::VT_INTERNAL_TOKENIZER_CODEPOINT_RANGES, internal_tokenizer_codepoint_ranges); - } - void add_min_supported_codepoint_ratio(float min_supported_codepoint_ratio) { - fbb_.AddElement<float>(FeatureProcessorOptions::VT_MIN_SUPPORTED_CODEPOINT_RATIO, min_supported_codepoint_ratio, 0.0f); - } - void add_feature_version(int32_t feature_version) { - fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_FEATURE_VERSION, feature_version, 0); - } - void add_tokenization_type(libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type) { - fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_TOKENIZATION_TYPE, static_cast<int32_t>(tokenization_type), 1); - } - void add_icu_preserve_whitespace_tokens(bool icu_preserve_whitespace_tokens) { - fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_ICU_PRESERVE_WHITESPACE_TOKENS, static_cast<uint8_t>(icu_preserve_whitespace_tokens), 0); - } - void add_ignored_span_boundary_codepoints(flatbuffers::Offset<flatbuffers::Vector<int32_t>> ignored_span_boundary_codepoints) { - fbb_.AddOffset(FeatureProcessorOptions::VT_IGNORED_SPAN_BOUNDARY_CODEPOINTS, ignored_span_boundary_codepoints); - } - void add_bounds_sensitive_features(flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures> bounds_sensitive_features) { - fbb_.AddOffset(FeatureProcessorOptions::VT_BOUNDS_SENSITIVE_FEATURES, bounds_sensitive_features); - } - void add_allowed_chargrams(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> allowed_chargrams) { - fbb_.AddOffset(FeatureProcessorOptions::VT_ALLOWED_CHARGRAMS, allowed_chargrams); - } - void add_tokenize_on_script_change(bool tokenize_on_script_change) { - fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_TOKENIZE_ON_SCRIPT_CHANGE, static_cast<uint8_t>(tokenize_on_script_change), 0); - } - explicit FeatureProcessorOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - FeatureProcessorOptionsBuilder &operator=(const FeatureProcessorOptionsBuilder &); - flatbuffers::Offset<FeatureProcessorOptions> Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<FeatureProcessorOptions>(end); - return o; - } -}; - -inline flatbuffers::Offset<FeatureProcessorOptions> CreateFeatureProcessorOptions( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t num_buckets = -1, - int32_t embedding_size = -1, - int32_t embedding_quantization_bits = 8, - int32_t context_size = -1, - int32_t max_selection_span = -1, - flatbuffers::Offset<flatbuffers::Vector<int32_t>> chargram_orders = 0, - int32_t max_word_length = 20, - bool unicode_aware_features = false, - bool extract_case_feature = false, - bool extract_selection_mask_feature = false, - flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> regexp_feature = 0, - bool remap_digits = false, - bool lowercase_tokens = false, - bool selection_reduced_output_space = true, - flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> collections = 0, - int32_t default_collection = -1, - bool only_use_line_with_click = false, - bool split_tokens_on_selection_boundaries = false, - flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TokenizationCodepointRange>>> tokenization_codepoint_config = 0, - libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method = libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD, - bool snap_label_span_boundaries_to_containing_tokens = false, - flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>> supported_codepoint_ranges = 0, - flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>> internal_tokenizer_codepoint_ranges = 0, - float min_supported_codepoint_ratio = 0.0f, - int32_t feature_version = 0, - libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type = libtextclassifier2::FeatureProcessorOptions_::TokenizationType_INTERNAL_TOKENIZER, - bool icu_preserve_whitespace_tokens = false, - flatbuffers::Offset<flatbuffers::Vector<int32_t>> ignored_span_boundary_codepoints = 0, - flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures> bounds_sensitive_features = 0, - flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> allowed_chargrams = 0, - bool tokenize_on_script_change = false) { - FeatureProcessorOptionsBuilder builder_(_fbb); - builder_.add_allowed_chargrams(allowed_chargrams); - builder_.add_bounds_sensitive_features(bounds_sensitive_features); - builder_.add_ignored_span_boundary_codepoints(ignored_span_boundary_codepoints); - builder_.add_tokenization_type(tokenization_type); - builder_.add_feature_version(feature_version); - builder_.add_min_supported_codepoint_ratio(min_supported_codepoint_ratio); - builder_.add_internal_tokenizer_codepoint_ranges(internal_tokenizer_codepoint_ranges); - builder_.add_supported_codepoint_ranges(supported_codepoint_ranges); - builder_.add_center_token_selection_method(center_token_selection_method); - builder_.add_tokenization_codepoint_config(tokenization_codepoint_config); - builder_.add_default_collection(default_collection); - builder_.add_collections(collections); - builder_.add_regexp_feature(regexp_feature); - builder_.add_max_word_length(max_word_length); - builder_.add_chargram_orders(chargram_orders); - builder_.add_max_selection_span(max_selection_span); - builder_.add_context_size(context_size); - builder_.add_embedding_quantization_bits(embedding_quantization_bits); - builder_.add_embedding_size(embedding_size); - builder_.add_num_buckets(num_buckets); - builder_.add_tokenize_on_script_change(tokenize_on_script_change); - builder_.add_icu_preserve_whitespace_tokens(icu_preserve_whitespace_tokens); - builder_.add_snap_label_span_boundaries_to_containing_tokens(snap_label_span_boundaries_to_containing_tokens); - builder_.add_split_tokens_on_selection_boundaries(split_tokens_on_selection_boundaries); - builder_.add_only_use_line_with_click(only_use_line_with_click); - builder_.add_selection_reduced_output_space(selection_reduced_output_space); - builder_.add_lowercase_tokens(lowercase_tokens); - builder_.add_remap_digits(remap_digits); - builder_.add_extract_selection_mask_feature(extract_selection_mask_feature); - builder_.add_extract_case_feature(extract_case_feature); - builder_.add_unicode_aware_features(unicode_aware_features); - return builder_.Finish(); -} - -inline flatbuffers::Offset<FeatureProcessorOptions> CreateFeatureProcessorOptionsDirect( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t num_buckets = -1, - int32_t embedding_size = -1, - int32_t embedding_quantization_bits = 8, - int32_t context_size = -1, - int32_t max_selection_span = -1, - const std::vector<int32_t> *chargram_orders = nullptr, - int32_t max_word_length = 20, - bool unicode_aware_features = false, - bool extract_case_feature = false, - bool extract_selection_mask_feature = false, - const std::vector<flatbuffers::Offset<flatbuffers::String>> *regexp_feature = nullptr, - bool remap_digits = false, - bool lowercase_tokens = false, - bool selection_reduced_output_space = true, - const std::vector<flatbuffers::Offset<flatbuffers::String>> *collections = nullptr, - int32_t default_collection = -1, - bool only_use_line_with_click = false, - bool split_tokens_on_selection_boundaries = false, - const std::vector<flatbuffers::Offset<TokenizationCodepointRange>> *tokenization_codepoint_config = nullptr, - libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method = libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD, - bool snap_label_span_boundaries_to_containing_tokens = false, - const std::vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *supported_codepoint_ranges = nullptr, - const std::vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *internal_tokenizer_codepoint_ranges = nullptr, - float min_supported_codepoint_ratio = 0.0f, - int32_t feature_version = 0, - libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type = libtextclassifier2::FeatureProcessorOptions_::TokenizationType_INTERNAL_TOKENIZER, - bool icu_preserve_whitespace_tokens = false, - const std::vector<int32_t> *ignored_span_boundary_codepoints = nullptr, - flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures> bounds_sensitive_features = 0, - const std::vector<flatbuffers::Offset<flatbuffers::String>> *allowed_chargrams = nullptr, - bool tokenize_on_script_change = false) { - return libtextclassifier2::CreateFeatureProcessorOptions( - _fbb, - num_buckets, - embedding_size, - embedding_quantization_bits, - context_size, - max_selection_span, - chargram_orders ? _fbb.CreateVector<int32_t>(*chargram_orders) : 0, - max_word_length, - unicode_aware_features, - extract_case_feature, - extract_selection_mask_feature, - regexp_feature ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*regexp_feature) : 0, - remap_digits, - lowercase_tokens, - selection_reduced_output_space, - collections ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*collections) : 0, - default_collection, - only_use_line_with_click, - split_tokens_on_selection_boundaries, - tokenization_codepoint_config ? _fbb.CreateVector<flatbuffers::Offset<TokenizationCodepointRange>>(*tokenization_codepoint_config) : 0, - center_token_selection_method, - snap_label_span_boundaries_to_containing_tokens, - supported_codepoint_ranges ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>(*supported_codepoint_ranges) : 0, - internal_tokenizer_codepoint_ranges ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>(*internal_tokenizer_codepoint_ranges) : 0, - min_supported_codepoint_ratio, - feature_version, - tokenization_type, - icu_preserve_whitespace_tokens, - ignored_span_boundary_codepoints ? _fbb.CreateVector<int32_t>(*ignored_span_boundary_codepoints) : 0, - bounds_sensitive_features, - allowed_chargrams ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*allowed_chargrams) : 0, - tokenize_on_script_change); -} - -flatbuffers::Offset<FeatureProcessorOptions> CreateFeatureProcessorOptions(flatbuffers::FlatBufferBuilder &_fbb, const FeatureProcessorOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); - -inline CompressedBufferT *CompressedBuffer::UnPack(const flatbuffers::resolver_function_t *_resolver) const { - auto _o = new CompressedBufferT(); - UnPackTo(_o, _resolver); - return _o; -} - -inline void CompressedBuffer::UnPackTo(CompressedBufferT *_o, const flatbuffers::resolver_function_t *_resolver) const { - (void)_o; - (void)_resolver; - { auto _e = buffer(); if (_e) { _o->buffer.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->buffer[_i] = _e->Get(_i); } } }; - { auto _e = uncompressed_size(); _o->uncompressed_size = _e; }; -} - -inline flatbuffers::Offset<CompressedBuffer> CompressedBuffer::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CompressedBufferT* _o, const flatbuffers::rehasher_function_t *_rehasher) { - return CreateCompressedBuffer(_fbb, _o, _rehasher); -} - -inline flatbuffers::Offset<CompressedBuffer> CreateCompressedBuffer(flatbuffers::FlatBufferBuilder &_fbb, const CompressedBufferT *_o, const flatbuffers::rehasher_function_t *_rehasher) { - (void)_rehasher; - (void)_o; - struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CompressedBufferT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; - auto _buffer = _o->buffer.size() ? _fbb.CreateVector(_o->buffer) : 0; - auto _uncompressed_size = _o->uncompressed_size; - return libtextclassifier2::CreateCompressedBuffer( - _fbb, - _buffer, - _uncompressed_size); -} - -inline SelectionModelOptionsT *SelectionModelOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { - auto _o = new SelectionModelOptionsT(); - UnPackTo(_o, _resolver); - return _o; -} - -inline void SelectionModelOptions::UnPackTo(SelectionModelOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { - (void)_o; - (void)_resolver; - { auto _e = strip_unpaired_brackets(); _o->strip_unpaired_brackets = _e; }; - { auto _e = symmetry_context_size(); _o->symmetry_context_size = _e; }; - { auto _e = batch_size(); _o->batch_size = _e; }; - { auto _e = always_classify_suggested_selection(); _o->always_classify_suggested_selection = _e; }; -} - -inline flatbuffers::Offset<SelectionModelOptions> SelectionModelOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SelectionModelOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { - return CreateSelectionModelOptions(_fbb, _o, _rehasher); -} - -inline flatbuffers::Offset<SelectionModelOptions> CreateSelectionModelOptions(flatbuffers::FlatBufferBuilder &_fbb, const SelectionModelOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { - (void)_rehasher; - (void)_o; - struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SelectionModelOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; - auto _strip_unpaired_brackets = _o->strip_unpaired_brackets; - auto _symmetry_context_size = _o->symmetry_context_size; - auto _batch_size = _o->batch_size; - auto _always_classify_suggested_selection = _o->always_classify_suggested_selection; - return libtextclassifier2::CreateSelectionModelOptions( - _fbb, - _strip_unpaired_brackets, - _symmetry_context_size, - _batch_size, - _always_classify_suggested_selection); -} - -inline ClassificationModelOptionsT *ClassificationModelOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { - auto _o = new ClassificationModelOptionsT(); - UnPackTo(_o, _resolver); - return _o; -} - -inline void ClassificationModelOptions::UnPackTo(ClassificationModelOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { - (void)_o; - (void)_resolver; - { auto _e = phone_min_num_digits(); _o->phone_min_num_digits = _e; }; - { auto _e = phone_max_num_digits(); _o->phone_max_num_digits = _e; }; - { auto _e = address_min_num_tokens(); _o->address_min_num_tokens = _e; }; - { auto _e = max_num_tokens(); _o->max_num_tokens = _e; }; -} - -inline flatbuffers::Offset<ClassificationModelOptions> ClassificationModelOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ClassificationModelOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { - return CreateClassificationModelOptions(_fbb, _o, _rehasher); -} - -inline flatbuffers::Offset<ClassificationModelOptions> CreateClassificationModelOptions(flatbuffers::FlatBufferBuilder &_fbb, const ClassificationModelOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { - (void)_rehasher; - (void)_o; - struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ClassificationModelOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; - auto _phone_min_num_digits = _o->phone_min_num_digits; - auto _phone_max_num_digits = _o->phone_max_num_digits; - auto _address_min_num_tokens = _o->address_min_num_tokens; - auto _max_num_tokens = _o->max_num_tokens; - return libtextclassifier2::CreateClassificationModelOptions( - _fbb, - _phone_min_num_digits, - _phone_max_num_digits, - _address_min_num_tokens, - _max_num_tokens); -} - -namespace RegexModel_ { - -inline PatternT *Pattern::UnPack(const flatbuffers::resolver_function_t *_resolver) const { - auto _o = new PatternT(); - UnPackTo(_o, _resolver); - return _o; -} - -inline void Pattern::UnPackTo(PatternT *_o, const flatbuffers::resolver_function_t *_resolver) const { - (void)_o; - (void)_resolver; - { auto _e = collection_name(); if (_e) _o->collection_name = _e->str(); }; - { auto _e = pattern(); if (_e) _o->pattern = _e->str(); }; - { auto _e = enabled_modes(); _o->enabled_modes = _e; }; - { auto _e = target_classification_score(); _o->target_classification_score = _e; }; - { auto _e = priority_score(); _o->priority_score = _e; }; - { auto _e = use_approximate_matching(); _o->use_approximate_matching = _e; }; - { auto _e = compressed_pattern(); if (_e) _o->compressed_pattern = std::unique_ptr<libtextclassifier2::CompressedBufferT>(_e->UnPack(_resolver)); }; -} - -inline flatbuffers::Offset<Pattern> Pattern::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PatternT* _o, const flatbuffers::rehasher_function_t *_rehasher) { - return CreatePattern(_fbb, _o, _rehasher); -} - -inline flatbuffers::Offset<Pattern> CreatePattern(flatbuffers::FlatBufferBuilder &_fbb, const PatternT *_o, const flatbuffers::rehasher_function_t *_rehasher) { - (void)_rehasher; - (void)_o; - struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PatternT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; - auto _collection_name = _o->collection_name.empty() ? 0 : _fbb.CreateString(_o->collection_name); - auto _pattern = _o->pattern.empty() ? 0 : _fbb.CreateString(_o->pattern); - auto _enabled_modes = _o->enabled_modes; - auto _target_classification_score = _o->target_classification_score; - auto _priority_score = _o->priority_score; - auto _use_approximate_matching = _o->use_approximate_matching; - auto _compressed_pattern = _o->compressed_pattern ? CreateCompressedBuffer(_fbb, _o->compressed_pattern.get(), _rehasher) : 0; - return libtextclassifier2::RegexModel_::CreatePattern( - _fbb, - _collection_name, - _pattern, - _enabled_modes, - _target_classification_score, - _priority_score, - _use_approximate_matching, - _compressed_pattern); -} - -} // namespace RegexModel_ - -inline RegexModelT *RegexModel::UnPack(const flatbuffers::resolver_function_t *_resolver) const { - auto _o = new RegexModelT(); - UnPackTo(_o, _resolver); - return _o; -} - -inline void RegexModel::UnPackTo(RegexModelT *_o, const flatbuffers::resolver_function_t *_resolver) const { - (void)_o; - (void)_resolver; - { auto _e = patterns(); if (_e) { _o->patterns.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->patterns[_i] = std::unique_ptr<libtextclassifier2::RegexModel_::PatternT>(_e->Get(_i)->UnPack(_resolver)); } } }; -} - -inline flatbuffers::Offset<RegexModel> RegexModel::Pack(flatbuffers::FlatBufferBuilder &_fbb, const RegexModelT* _o, const flatbuffers::rehasher_function_t *_rehasher) { - return CreateRegexModel(_fbb, _o, _rehasher); -} - -inline flatbuffers::Offset<RegexModel> CreateRegexModel(flatbuffers::FlatBufferBuilder &_fbb, const RegexModelT *_o, const flatbuffers::rehasher_function_t *_rehasher) { - (void)_rehasher; - (void)_o; - struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const RegexModelT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; - auto _patterns = _o->patterns.size() ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::RegexModel_::Pattern>> (_o->patterns.size(), [](size_t i, _VectorArgs *__va) { return CreatePattern(*__va->__fbb, __va->__o->patterns[i].get(), __va->__rehasher); }, &_va ) : 0; - return libtextclassifier2::CreateRegexModel( - _fbb, - _patterns); -} - -namespace DatetimeModelPattern_ { - -inline RegexT *Regex::UnPack(const flatbuffers::resolver_function_t *_resolver) const { - auto _o = new RegexT(); - UnPackTo(_o, _resolver); - return _o; -} - -inline void Regex::UnPackTo(RegexT *_o, const flatbuffers::resolver_function_t *_resolver) const { - (void)_o; - (void)_resolver; - { auto _e = pattern(); if (_e) _o->pattern = _e->str(); }; - { auto _e = groups(); if (_e) { _o->groups.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->groups[_i] = (DatetimeGroupType)_e->Get(_i); } } }; - { auto _e = compressed_pattern(); if (_e) _o->compressed_pattern = std::unique_ptr<libtextclassifier2::CompressedBufferT>(_e->UnPack(_resolver)); }; -} - -inline flatbuffers::Offset<Regex> Regex::Pack(flatbuffers::FlatBufferBuilder &_fbb, const RegexT* _o, const flatbuffers::rehasher_function_t *_rehasher) { - return CreateRegex(_fbb, _o, _rehasher); -} - -inline flatbuffers::Offset<Regex> CreateRegex(flatbuffers::FlatBufferBuilder &_fbb, const RegexT *_o, const flatbuffers::rehasher_function_t *_rehasher) { - (void)_rehasher; - (void)_o; - struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const RegexT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; - auto _pattern = _o->pattern.empty() ? 0 : _fbb.CreateString(_o->pattern); - auto _groups = _o->groups.size() ? _fbb.CreateVector((const int32_t*)_o->groups.data(), _o->groups.size()) : 0; - auto _compressed_pattern = _o->compressed_pattern ? CreateCompressedBuffer(_fbb, _o->compressed_pattern.get(), _rehasher) : 0; - return libtextclassifier2::DatetimeModelPattern_::CreateRegex( - _fbb, - _pattern, - _groups, - _compressed_pattern); -} - -} // namespace DatetimeModelPattern_ - -inline DatetimeModelPatternT *DatetimeModelPattern::UnPack(const flatbuffers::resolver_function_t *_resolver) const { - auto _o = new DatetimeModelPatternT(); - UnPackTo(_o, _resolver); - return _o; -} - -inline void DatetimeModelPattern::UnPackTo(DatetimeModelPatternT *_o, const flatbuffers::resolver_function_t *_resolver) const { - (void)_o; - (void)_resolver; - { auto _e = regexes(); if (_e) { _o->regexes.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->regexes[_i] = std::unique_ptr<libtextclassifier2::DatetimeModelPattern_::RegexT>(_e->Get(_i)->UnPack(_resolver)); } } }; - { auto _e = locales(); if (_e) { _o->locales.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->locales[_i] = _e->Get(_i); } } }; - { auto _e = target_classification_score(); _o->target_classification_score = _e; }; - { auto _e = priority_score(); _o->priority_score = _e; }; - { auto _e = enabled_modes(); _o->enabled_modes = _e; }; -} - -inline flatbuffers::Offset<DatetimeModelPattern> DatetimeModelPattern::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelPatternT* _o, const flatbuffers::rehasher_function_t *_rehasher) { - return CreateDatetimeModelPattern(_fbb, _o, _rehasher); -} - -inline flatbuffers::Offset<DatetimeModelPattern> CreateDatetimeModelPattern(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelPatternT *_o, const flatbuffers::rehasher_function_t *_rehasher) { - (void)_rehasher; - (void)_o; - struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DatetimeModelPatternT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; - auto _regexes = _o->regexes.size() ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::DatetimeModelPattern_::Regex>> (_o->regexes.size(), [](size_t i, _VectorArgs *__va) { return CreateRegex(*__va->__fbb, __va->__o->regexes[i].get(), __va->__rehasher); }, &_va ) : 0; - auto _locales = _o->locales.size() ? _fbb.CreateVector(_o->locales) : 0; - auto _target_classification_score = _o->target_classification_score; - auto _priority_score = _o->priority_score; - auto _enabled_modes = _o->enabled_modes; - return libtextclassifier2::CreateDatetimeModelPattern( - _fbb, - _regexes, - _locales, - _target_classification_score, - _priority_score, - _enabled_modes); -} - -inline DatetimeModelExtractorT *DatetimeModelExtractor::UnPack(const flatbuffers::resolver_function_t *_resolver) const { - auto _o = new DatetimeModelExtractorT(); - UnPackTo(_o, _resolver); - return _o; -} - -inline void DatetimeModelExtractor::UnPackTo(DatetimeModelExtractorT *_o, const flatbuffers::resolver_function_t *_resolver) const { - (void)_o; - (void)_resolver; - { auto _e = extractor(); _o->extractor = _e; }; - { auto _e = pattern(); if (_e) _o->pattern = _e->str(); }; - { auto _e = locales(); if (_e) { _o->locales.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->locales[_i] = _e->Get(_i); } } }; - { auto _e = compressed_pattern(); if (_e) _o->compressed_pattern = std::unique_ptr<CompressedBufferT>(_e->UnPack(_resolver)); }; -} - -inline flatbuffers::Offset<DatetimeModelExtractor> DatetimeModelExtractor::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelExtractorT* _o, const flatbuffers::rehasher_function_t *_rehasher) { - return CreateDatetimeModelExtractor(_fbb, _o, _rehasher); -} - -inline flatbuffers::Offset<DatetimeModelExtractor> CreateDatetimeModelExtractor(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelExtractorT *_o, const flatbuffers::rehasher_function_t *_rehasher) { - (void)_rehasher; - (void)_o; - struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DatetimeModelExtractorT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; - auto _extractor = _o->extractor; - auto _pattern = _o->pattern.empty() ? 0 : _fbb.CreateString(_o->pattern); - auto _locales = _o->locales.size() ? _fbb.CreateVector(_o->locales) : 0; - auto _compressed_pattern = _o->compressed_pattern ? CreateCompressedBuffer(_fbb, _o->compressed_pattern.get(), _rehasher) : 0; - return libtextclassifier2::CreateDatetimeModelExtractor( - _fbb, - _extractor, - _pattern, - _locales, - _compressed_pattern); -} - -inline DatetimeModelT *DatetimeModel::UnPack(const flatbuffers::resolver_function_t *_resolver) const { - auto _o = new DatetimeModelT(); - UnPackTo(_o, _resolver); - return _o; -} - -inline void DatetimeModel::UnPackTo(DatetimeModelT *_o, const flatbuffers::resolver_function_t *_resolver) const { - (void)_o; - (void)_resolver; - { auto _e = locales(); if (_e) { _o->locales.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->locales[_i] = _e->Get(_i)->str(); } } }; - { auto _e = patterns(); if (_e) { _o->patterns.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->patterns[_i] = std::unique_ptr<DatetimeModelPatternT>(_e->Get(_i)->UnPack(_resolver)); } } }; - { auto _e = extractors(); if (_e) { _o->extractors.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->extractors[_i] = std::unique_ptr<DatetimeModelExtractorT>(_e->Get(_i)->UnPack(_resolver)); } } }; - { auto _e = use_extractors_for_locating(); _o->use_extractors_for_locating = _e; }; - { auto _e = default_locales(); if (_e) { _o->default_locales.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->default_locales[_i] = _e->Get(_i); } } }; -} - -inline flatbuffers::Offset<DatetimeModel> DatetimeModel::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelT* _o, const flatbuffers::rehasher_function_t *_rehasher) { - return CreateDatetimeModel(_fbb, _o, _rehasher); -} - -inline flatbuffers::Offset<DatetimeModel> CreateDatetimeModel(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelT *_o, const flatbuffers::rehasher_function_t *_rehasher) { - (void)_rehasher; - (void)_o; - struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DatetimeModelT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; - auto _locales = _o->locales.size() ? _fbb.CreateVectorOfStrings(_o->locales) : 0; - auto _patterns = _o->patterns.size() ? _fbb.CreateVector<flatbuffers::Offset<DatetimeModelPattern>> (_o->patterns.size(), [](size_t i, _VectorArgs *__va) { return CreateDatetimeModelPattern(*__va->__fbb, __va->__o->patterns[i].get(), __va->__rehasher); }, &_va ) : 0; - auto _extractors = _o->extractors.size() ? _fbb.CreateVector<flatbuffers::Offset<DatetimeModelExtractor>> (_o->extractors.size(), [](size_t i, _VectorArgs *__va) { return CreateDatetimeModelExtractor(*__va->__fbb, __va->__o->extractors[i].get(), __va->__rehasher); }, &_va ) : 0; - auto _use_extractors_for_locating = _o->use_extractors_for_locating; - auto _default_locales = _o->default_locales.size() ? _fbb.CreateVector(_o->default_locales) : 0; - return libtextclassifier2::CreateDatetimeModel( - _fbb, - _locales, - _patterns, - _extractors, - _use_extractors_for_locating, - _default_locales); -} - -namespace DatetimeModelLibrary_ { - -inline ItemT *Item::UnPack(const flatbuffers::resolver_function_t *_resolver) const { - auto _o = new ItemT(); - UnPackTo(_o, _resolver); - return _o; -} - -inline void Item::UnPackTo(ItemT *_o, const flatbuffers::resolver_function_t *_resolver) const { - (void)_o; - (void)_resolver; - { auto _e = key(); if (_e) _o->key = _e->str(); }; - { auto _e = value(); if (_e) _o->value = std::unique_ptr<libtextclassifier2::DatetimeModelT>(_e->UnPack(_resolver)); }; -} - -inline flatbuffers::Offset<Item> Item::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ItemT* _o, const flatbuffers::rehasher_function_t *_rehasher) { - return CreateItem(_fbb, _o, _rehasher); -} - -inline flatbuffers::Offset<Item> CreateItem(flatbuffers::FlatBufferBuilder &_fbb, const ItemT *_o, const flatbuffers::rehasher_function_t *_rehasher) { - (void)_rehasher; - (void)_o; - struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ItemT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; - auto _key = _o->key.empty() ? 0 : _fbb.CreateString(_o->key); - auto _value = _o->value ? CreateDatetimeModel(_fbb, _o->value.get(), _rehasher) : 0; - return libtextclassifier2::DatetimeModelLibrary_::CreateItem( - _fbb, - _key, - _value); -} - -} // namespace DatetimeModelLibrary_ - -inline DatetimeModelLibraryT *DatetimeModelLibrary::UnPack(const flatbuffers::resolver_function_t *_resolver) const { - auto _o = new DatetimeModelLibraryT(); - UnPackTo(_o, _resolver); - return _o; -} - -inline void DatetimeModelLibrary::UnPackTo(DatetimeModelLibraryT *_o, const flatbuffers::resolver_function_t *_resolver) const { - (void)_o; - (void)_resolver; - { auto _e = models(); if (_e) { _o->models.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->models[_i] = std::unique_ptr<libtextclassifier2::DatetimeModelLibrary_::ItemT>(_e->Get(_i)->UnPack(_resolver)); } } }; -} - -inline flatbuffers::Offset<DatetimeModelLibrary> DatetimeModelLibrary::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelLibraryT* _o, const flatbuffers::rehasher_function_t *_rehasher) { - return CreateDatetimeModelLibrary(_fbb, _o, _rehasher); -} - -inline flatbuffers::Offset<DatetimeModelLibrary> CreateDatetimeModelLibrary(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelLibraryT *_o, const flatbuffers::rehasher_function_t *_rehasher) { - (void)_rehasher; - (void)_o; - struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DatetimeModelLibraryT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; - auto _models = _o->models.size() ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::DatetimeModelLibrary_::Item>> (_o->models.size(), [](size_t i, _VectorArgs *__va) { return CreateItem(*__va->__fbb, __va->__o->models[i].get(), __va->__rehasher); }, &_va ) : 0; - return libtextclassifier2::CreateDatetimeModelLibrary( - _fbb, - _models); -} - -inline ModelTriggeringOptionsT *ModelTriggeringOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { - auto _o = new ModelTriggeringOptionsT(); - UnPackTo(_o, _resolver); - return _o; -} - -inline void ModelTriggeringOptions::UnPackTo(ModelTriggeringOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { - (void)_o; - (void)_resolver; - { auto _e = min_annotate_confidence(); _o->min_annotate_confidence = _e; }; - { auto _e = enabled_modes(); _o->enabled_modes = _e; }; -} - -inline flatbuffers::Offset<ModelTriggeringOptions> ModelTriggeringOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ModelTriggeringOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { - return CreateModelTriggeringOptions(_fbb, _o, _rehasher); -} - -inline flatbuffers::Offset<ModelTriggeringOptions> CreateModelTriggeringOptions(flatbuffers::FlatBufferBuilder &_fbb, const ModelTriggeringOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { - (void)_rehasher; - (void)_o; - struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ModelTriggeringOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; - auto _min_annotate_confidence = _o->min_annotate_confidence; - auto _enabled_modes = _o->enabled_modes; - return libtextclassifier2::CreateModelTriggeringOptions( - _fbb, - _min_annotate_confidence, - _enabled_modes); -} - -inline OutputOptionsT *OutputOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { - auto _o = new OutputOptionsT(); - UnPackTo(_o, _resolver); - return _o; -} - -inline void OutputOptions::UnPackTo(OutputOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { - (void)_o; - (void)_resolver; - { auto _e = filtered_collections_annotation(); if (_e) { _o->filtered_collections_annotation.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->filtered_collections_annotation[_i] = _e->Get(_i)->str(); } } }; - { auto _e = filtered_collections_classification(); if (_e) { _o->filtered_collections_classification.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->filtered_collections_classification[_i] = _e->Get(_i)->str(); } } }; - { auto _e = filtered_collections_selection(); if (_e) { _o->filtered_collections_selection.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->filtered_collections_selection[_i] = _e->Get(_i)->str(); } } }; -} - -inline flatbuffers::Offset<OutputOptions> OutputOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OutputOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { - return CreateOutputOptions(_fbb, _o, _rehasher); -} - -inline flatbuffers::Offset<OutputOptions> CreateOutputOptions(flatbuffers::FlatBufferBuilder &_fbb, const OutputOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { - (void)_rehasher; - (void)_o; - struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const OutputOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; - auto _filtered_collections_annotation = _o->filtered_collections_annotation.size() ? _fbb.CreateVectorOfStrings(_o->filtered_collections_annotation) : 0; - auto _filtered_collections_classification = _o->filtered_collections_classification.size() ? _fbb.CreateVectorOfStrings(_o->filtered_collections_classification) : 0; - auto _filtered_collections_selection = _o->filtered_collections_selection.size() ? _fbb.CreateVectorOfStrings(_o->filtered_collections_selection) : 0; - return libtextclassifier2::CreateOutputOptions( - _fbb, - _filtered_collections_annotation, - _filtered_collections_classification, - _filtered_collections_selection); -} - -inline ModelT *Model::UnPack(const flatbuffers::resolver_function_t *_resolver) const { - auto _o = new ModelT(); - UnPackTo(_o, _resolver); - return _o; -} - -inline void Model::UnPackTo(ModelT *_o, const flatbuffers::resolver_function_t *_resolver) const { - (void)_o; - (void)_resolver; - { auto _e = locales(); if (_e) _o->locales = _e->str(); }; - { auto _e = version(); _o->version = _e; }; - { auto _e = name(); if (_e) _o->name = _e->str(); }; - { auto _e = selection_feature_options(); if (_e) _o->selection_feature_options = std::unique_ptr<FeatureProcessorOptionsT>(_e->UnPack(_resolver)); }; - { auto _e = classification_feature_options(); if (_e) _o->classification_feature_options = std::unique_ptr<FeatureProcessorOptionsT>(_e->UnPack(_resolver)); }; - { auto _e = selection_model(); if (_e) { _o->selection_model.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->selection_model[_i] = _e->Get(_i); } } }; - { auto _e = classification_model(); if (_e) { _o->classification_model.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->classification_model[_i] = _e->Get(_i); } } }; - { auto _e = embedding_model(); if (_e) { _o->embedding_model.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->embedding_model[_i] = _e->Get(_i); } } }; - { auto _e = selection_options(); if (_e) _o->selection_options = std::unique_ptr<SelectionModelOptionsT>(_e->UnPack(_resolver)); }; - { auto _e = classification_options(); if (_e) _o->classification_options = std::unique_ptr<ClassificationModelOptionsT>(_e->UnPack(_resolver)); }; - { auto _e = regex_model(); if (_e) _o->regex_model = std::unique_ptr<RegexModelT>(_e->UnPack(_resolver)); }; - { auto _e = datetime_model(); if (_e) _o->datetime_model = std::unique_ptr<DatetimeModelT>(_e->UnPack(_resolver)); }; - { auto _e = triggering_options(); if (_e) _o->triggering_options = std::unique_ptr<ModelTriggeringOptionsT>(_e->UnPack(_resolver)); }; - { auto _e = enabled_modes(); _o->enabled_modes = _e; }; - { auto _e = snap_whitespace_selections(); _o->snap_whitespace_selections = _e; }; - { auto _e = output_options(); if (_e) _o->output_options = std::unique_ptr<OutputOptionsT>(_e->UnPack(_resolver)); }; -} - -inline flatbuffers::Offset<Model> Model::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ModelT* _o, const flatbuffers::rehasher_function_t *_rehasher) { - return CreateModel(_fbb, _o, _rehasher); -} - -inline flatbuffers::Offset<Model> CreateModel(flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, const flatbuffers::rehasher_function_t *_rehasher) { - (void)_rehasher; - (void)_o; - struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ModelT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; - auto _locales = _o->locales.empty() ? 0 : _fbb.CreateString(_o->locales); - auto _version = _o->version; - auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name); - auto _selection_feature_options = _o->selection_feature_options ? CreateFeatureProcessorOptions(_fbb, _o->selection_feature_options.get(), _rehasher) : 0; - auto _classification_feature_options = _o->classification_feature_options ? CreateFeatureProcessorOptions(_fbb, _o->classification_feature_options.get(), _rehasher) : 0; - auto _selection_model = _o->selection_model.size() ? _fbb.CreateVector(_o->selection_model) : 0; - auto _classification_model = _o->classification_model.size() ? _fbb.CreateVector(_o->classification_model) : 0; - auto _embedding_model = _o->embedding_model.size() ? _fbb.CreateVector(_o->embedding_model) : 0; - auto _selection_options = _o->selection_options ? CreateSelectionModelOptions(_fbb, _o->selection_options.get(), _rehasher) : 0; - auto _classification_options = _o->classification_options ? CreateClassificationModelOptions(_fbb, _o->classification_options.get(), _rehasher) : 0; - auto _regex_model = _o->regex_model ? CreateRegexModel(_fbb, _o->regex_model.get(), _rehasher) : 0; - auto _datetime_model = _o->datetime_model ? CreateDatetimeModel(_fbb, _o->datetime_model.get(), _rehasher) : 0; - auto _triggering_options = _o->triggering_options ? CreateModelTriggeringOptions(_fbb, _o->triggering_options.get(), _rehasher) : 0; - auto _enabled_modes = _o->enabled_modes; - auto _snap_whitespace_selections = _o->snap_whitespace_selections; - auto _output_options = _o->output_options ? CreateOutputOptions(_fbb, _o->output_options.get(), _rehasher) : 0; - return libtextclassifier2::CreateModel( - _fbb, - _locales, - _version, - _name, - _selection_feature_options, - _classification_feature_options, - _selection_model, - _classification_model, - _embedding_model, - _selection_options, - _classification_options, - _regex_model, - _datetime_model, - _triggering_options, - _enabled_modes, - _snap_whitespace_selections, - _output_options); -} - -inline TokenizationCodepointRangeT *TokenizationCodepointRange::UnPack(const flatbuffers::resolver_function_t *_resolver) const { - auto _o = new TokenizationCodepointRangeT(); - UnPackTo(_o, _resolver); - return _o; -} - -inline void TokenizationCodepointRange::UnPackTo(TokenizationCodepointRangeT *_o, const flatbuffers::resolver_function_t *_resolver) const { - (void)_o; - (void)_resolver; - { auto _e = start(); _o->start = _e; }; - { auto _e = end(); _o->end = _e; }; - { auto _e = role(); _o->role = _e; }; - { auto _e = script_id(); _o->script_id = _e; }; -} - -inline flatbuffers::Offset<TokenizationCodepointRange> TokenizationCodepointRange::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TokenizationCodepointRangeT* _o, const flatbuffers::rehasher_function_t *_rehasher) { - return CreateTokenizationCodepointRange(_fbb, _o, _rehasher); -} - -inline flatbuffers::Offset<TokenizationCodepointRange> CreateTokenizationCodepointRange(flatbuffers::FlatBufferBuilder &_fbb, const TokenizationCodepointRangeT *_o, const flatbuffers::rehasher_function_t *_rehasher) { - (void)_rehasher; - (void)_o; - struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const TokenizationCodepointRangeT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; - auto _start = _o->start; - auto _end = _o->end; - auto _role = _o->role; - auto _script_id = _o->script_id; - return libtextclassifier2::CreateTokenizationCodepointRange( - _fbb, - _start, - _end, - _role, - _script_id); -} - -namespace FeatureProcessorOptions_ { - -inline CodepointRangeT *CodepointRange::UnPack(const flatbuffers::resolver_function_t *_resolver) const { - auto _o = new CodepointRangeT(); - UnPackTo(_o, _resolver); - return _o; -} - -inline void CodepointRange::UnPackTo(CodepointRangeT *_o, const flatbuffers::resolver_function_t *_resolver) const { - (void)_o; - (void)_resolver; - { auto _e = start(); _o->start = _e; }; - { auto _e = end(); _o->end = _e; }; -} - -inline flatbuffers::Offset<CodepointRange> CodepointRange::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CodepointRangeT* _o, const flatbuffers::rehasher_function_t *_rehasher) { - return CreateCodepointRange(_fbb, _o, _rehasher); -} - -inline flatbuffers::Offset<CodepointRange> CreateCodepointRange(flatbuffers::FlatBufferBuilder &_fbb, const CodepointRangeT *_o, const flatbuffers::rehasher_function_t *_rehasher) { - (void)_rehasher; - (void)_o; - struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CodepointRangeT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; - auto _start = _o->start; - auto _end = _o->end; - return libtextclassifier2::FeatureProcessorOptions_::CreateCodepointRange( - _fbb, - _start, - _end); -} - -inline BoundsSensitiveFeaturesT *BoundsSensitiveFeatures::UnPack(const flatbuffers::resolver_function_t *_resolver) const { - auto _o = new BoundsSensitiveFeaturesT(); - UnPackTo(_o, _resolver); - return _o; -} - -inline void BoundsSensitiveFeatures::UnPackTo(BoundsSensitiveFeaturesT *_o, const flatbuffers::resolver_function_t *_resolver) const { - (void)_o; - (void)_resolver; - { auto _e = enabled(); _o->enabled = _e; }; - { auto _e = num_tokens_before(); _o->num_tokens_before = _e; }; - { auto _e = num_tokens_inside_left(); _o->num_tokens_inside_left = _e; }; - { auto _e = num_tokens_inside_right(); _o->num_tokens_inside_right = _e; }; - { auto _e = num_tokens_after(); _o->num_tokens_after = _e; }; - { auto _e = include_inside_bag(); _o->include_inside_bag = _e; }; - { auto _e = include_inside_length(); _o->include_inside_length = _e; }; - { auto _e = score_single_token_spans_as_zero(); _o->score_single_token_spans_as_zero = _e; }; -} - -inline flatbuffers::Offset<BoundsSensitiveFeatures> BoundsSensitiveFeatures::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BoundsSensitiveFeaturesT* _o, const flatbuffers::rehasher_function_t *_rehasher) { - return CreateBoundsSensitiveFeatures(_fbb, _o, _rehasher); -} - -inline flatbuffers::Offset<BoundsSensitiveFeatures> CreateBoundsSensitiveFeatures(flatbuffers::FlatBufferBuilder &_fbb, const BoundsSensitiveFeaturesT *_o, const flatbuffers::rehasher_function_t *_rehasher) { - (void)_rehasher; - (void)_o; - struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BoundsSensitiveFeaturesT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; - auto _enabled = _o->enabled; - auto _num_tokens_before = _o->num_tokens_before; - auto _num_tokens_inside_left = _o->num_tokens_inside_left; - auto _num_tokens_inside_right = _o->num_tokens_inside_right; - auto _num_tokens_after = _o->num_tokens_after; - auto _include_inside_bag = _o->include_inside_bag; - auto _include_inside_length = _o->include_inside_length; - auto _score_single_token_spans_as_zero = _o->score_single_token_spans_as_zero; - return libtextclassifier2::FeatureProcessorOptions_::CreateBoundsSensitiveFeatures( - _fbb, - _enabled, - _num_tokens_before, - _num_tokens_inside_left, - _num_tokens_inside_right, - _num_tokens_after, - _include_inside_bag, - _include_inside_length, - _score_single_token_spans_as_zero); -} - -inline AlternativeCollectionMapEntryT *AlternativeCollectionMapEntry::UnPack(const flatbuffers::resolver_function_t *_resolver) const { - auto _o = new AlternativeCollectionMapEntryT(); - UnPackTo(_o, _resolver); - return _o; -} - -inline void AlternativeCollectionMapEntry::UnPackTo(AlternativeCollectionMapEntryT *_o, const flatbuffers::resolver_function_t *_resolver) const { - (void)_o; - (void)_resolver; - { auto _e = key(); if (_e) _o->key = _e->str(); }; - { auto _e = value(); if (_e) _o->value = _e->str(); }; -} - -inline flatbuffers::Offset<AlternativeCollectionMapEntry> AlternativeCollectionMapEntry::Pack(flatbuffers::FlatBufferBuilder &_fbb, const AlternativeCollectionMapEntryT* _o, const flatbuffers::rehasher_function_t *_rehasher) { - return CreateAlternativeCollectionMapEntry(_fbb, _o, _rehasher); -} - -inline flatbuffers::Offset<AlternativeCollectionMapEntry> CreateAlternativeCollectionMapEntry(flatbuffers::FlatBufferBuilder &_fbb, const AlternativeCollectionMapEntryT *_o, const flatbuffers::rehasher_function_t *_rehasher) { - (void)_rehasher; - (void)_o; - struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const AlternativeCollectionMapEntryT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; - auto _key = _o->key.empty() ? 0 : _fbb.CreateString(_o->key); - auto _value = _o->value.empty() ? 0 : _fbb.CreateString(_o->value); - return libtextclassifier2::FeatureProcessorOptions_::CreateAlternativeCollectionMapEntry( - _fbb, - _key, - _value); -} - -} // namespace FeatureProcessorOptions_ - -inline FeatureProcessorOptionsT *FeatureProcessorOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { - auto _o = new FeatureProcessorOptionsT(); - UnPackTo(_o, _resolver); - return _o; -} - -inline void FeatureProcessorOptions::UnPackTo(FeatureProcessorOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { - (void)_o; - (void)_resolver; - { auto _e = num_buckets(); _o->num_buckets = _e; }; - { auto _e = embedding_size(); _o->embedding_size = _e; }; - { auto _e = embedding_quantization_bits(); _o->embedding_quantization_bits = _e; }; - { auto _e = context_size(); _o->context_size = _e; }; - { auto _e = max_selection_span(); _o->max_selection_span = _e; }; - { auto _e = chargram_orders(); if (_e) { _o->chargram_orders.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->chargram_orders[_i] = _e->Get(_i); } } }; - { auto _e = max_word_length(); _o->max_word_length = _e; }; - { auto _e = unicode_aware_features(); _o->unicode_aware_features = _e; }; - { auto _e = extract_case_feature(); _o->extract_case_feature = _e; }; - { auto _e = extract_selection_mask_feature(); _o->extract_selection_mask_feature = _e; }; - { auto _e = regexp_feature(); if (_e) { _o->regexp_feature.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->regexp_feature[_i] = _e->Get(_i)->str(); } } }; - { auto _e = remap_digits(); _o->remap_digits = _e; }; - { auto _e = lowercase_tokens(); _o->lowercase_tokens = _e; }; - { auto _e = selection_reduced_output_space(); _o->selection_reduced_output_space = _e; }; - { auto _e = collections(); if (_e) { _o->collections.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->collections[_i] = _e->Get(_i)->str(); } } }; - { auto _e = default_collection(); _o->default_collection = _e; }; - { auto _e = only_use_line_with_click(); _o->only_use_line_with_click = _e; }; - { auto _e = split_tokens_on_selection_boundaries(); _o->split_tokens_on_selection_boundaries = _e; }; - { auto _e = tokenization_codepoint_config(); if (_e) { _o->tokenization_codepoint_config.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->tokenization_codepoint_config[_i] = std::unique_ptr<TokenizationCodepointRangeT>(_e->Get(_i)->UnPack(_resolver)); } } }; - { auto _e = center_token_selection_method(); _o->center_token_selection_method = _e; }; - { auto _e = snap_label_span_boundaries_to_containing_tokens(); _o->snap_label_span_boundaries_to_containing_tokens = _e; }; - { auto _e = supported_codepoint_ranges(); if (_e) { _o->supported_codepoint_ranges.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->supported_codepoint_ranges[_i] = std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::CodepointRangeT>(_e->Get(_i)->UnPack(_resolver)); } } }; - { auto _e = internal_tokenizer_codepoint_ranges(); if (_e) { _o->internal_tokenizer_codepoint_ranges.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->internal_tokenizer_codepoint_ranges[_i] = std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::CodepointRangeT>(_e->Get(_i)->UnPack(_resolver)); } } }; - { auto _e = min_supported_codepoint_ratio(); _o->min_supported_codepoint_ratio = _e; }; - { auto _e = feature_version(); _o->feature_version = _e; }; - { auto _e = tokenization_type(); _o->tokenization_type = _e; }; - { auto _e = icu_preserve_whitespace_tokens(); _o->icu_preserve_whitespace_tokens = _e; }; - { auto _e = ignored_span_boundary_codepoints(); if (_e) { _o->ignored_span_boundary_codepoints.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->ignored_span_boundary_codepoints[_i] = _e->Get(_i); } } }; - { auto _e = bounds_sensitive_features(); if (_e) _o->bounds_sensitive_features = std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeaturesT>(_e->UnPack(_resolver)); }; - { auto _e = allowed_chargrams(); if (_e) { _o->allowed_chargrams.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->allowed_chargrams[_i] = _e->Get(_i)->str(); } } }; - { auto _e = tokenize_on_script_change(); _o->tokenize_on_script_change = _e; }; -} - -inline flatbuffers::Offset<FeatureProcessorOptions> FeatureProcessorOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FeatureProcessorOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { - return CreateFeatureProcessorOptions(_fbb, _o, _rehasher); -} - -inline flatbuffers::Offset<FeatureProcessorOptions> CreateFeatureProcessorOptions(flatbuffers::FlatBufferBuilder &_fbb, const FeatureProcessorOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { - (void)_rehasher; - (void)_o; - struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FeatureProcessorOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; - auto _num_buckets = _o->num_buckets; - auto _embedding_size = _o->embedding_size; - auto _embedding_quantization_bits = _o->embedding_quantization_bits; - auto _context_size = _o->context_size; - auto _max_selection_span = _o->max_selection_span; - auto _chargram_orders = _o->chargram_orders.size() ? _fbb.CreateVector(_o->chargram_orders) : 0; - auto _max_word_length = _o->max_word_length; - auto _unicode_aware_features = _o->unicode_aware_features; - auto _extract_case_feature = _o->extract_case_feature; - auto _extract_selection_mask_feature = _o->extract_selection_mask_feature; - auto _regexp_feature = _o->regexp_feature.size() ? _fbb.CreateVectorOfStrings(_o->regexp_feature) : 0; - auto _remap_digits = _o->remap_digits; - auto _lowercase_tokens = _o->lowercase_tokens; - auto _selection_reduced_output_space = _o->selection_reduced_output_space; - auto _collections = _o->collections.size() ? _fbb.CreateVectorOfStrings(_o->collections) : 0; - auto _default_collection = _o->default_collection; - auto _only_use_line_with_click = _o->only_use_line_with_click; - auto _split_tokens_on_selection_boundaries = _o->split_tokens_on_selection_boundaries; - auto _tokenization_codepoint_config = _o->tokenization_codepoint_config.size() ? _fbb.CreateVector<flatbuffers::Offset<TokenizationCodepointRange>> (_o->tokenization_codepoint_config.size(), [](size_t i, _VectorArgs *__va) { return CreateTokenizationCodepointRange(*__va->__fbb, __va->__o->tokenization_codepoint_config[i].get(), __va->__rehasher); }, &_va ) : 0; - auto _center_token_selection_method = _o->center_token_selection_method; - auto _snap_label_span_boundaries_to_containing_tokens = _o->snap_label_span_boundaries_to_containing_tokens; - auto _supported_codepoint_ranges = _o->supported_codepoint_ranges.size() ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> (_o->supported_codepoint_ranges.size(), [](size_t i, _VectorArgs *__va) { return CreateCodepointRange(*__va->__fbb, __va->__o->supported_codepoint_ranges[i].get(), __va->__rehasher); }, &_va ) : 0; - auto _internal_tokenizer_codepoint_ranges = _o->internal_tokenizer_codepoint_ranges.size() ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> (_o->internal_tokenizer_codepoint_ranges.size(), [](size_t i, _VectorArgs *__va) { return CreateCodepointRange(*__va->__fbb, __va->__o->internal_tokenizer_codepoint_ranges[i].get(), __va->__rehasher); }, &_va ) : 0; - auto _min_supported_codepoint_ratio = _o->min_supported_codepoint_ratio; - auto _feature_version = _o->feature_version; - auto _tokenization_type = _o->tokenization_type; - auto _icu_preserve_whitespace_tokens = _o->icu_preserve_whitespace_tokens; - auto _ignored_span_boundary_codepoints = _o->ignored_span_boundary_codepoints.size() ? _fbb.CreateVector(_o->ignored_span_boundary_codepoints) : 0; - auto _bounds_sensitive_features = _o->bounds_sensitive_features ? CreateBoundsSensitiveFeatures(_fbb, _o->bounds_sensitive_features.get(), _rehasher) : 0; - auto _allowed_chargrams = _o->allowed_chargrams.size() ? _fbb.CreateVectorOfStrings(_o->allowed_chargrams) : 0; - auto _tokenize_on_script_change = _o->tokenize_on_script_change; - return libtextclassifier2::CreateFeatureProcessorOptions( - _fbb, - _num_buckets, - _embedding_size, - _embedding_quantization_bits, - _context_size, - _max_selection_span, - _chargram_orders, - _max_word_length, - _unicode_aware_features, - _extract_case_feature, - _extract_selection_mask_feature, - _regexp_feature, - _remap_digits, - _lowercase_tokens, - _selection_reduced_output_space, - _collections, - _default_collection, - _only_use_line_with_click, - _split_tokens_on_selection_boundaries, - _tokenization_codepoint_config, - _center_token_selection_method, - _snap_label_span_boundaries_to_containing_tokens, - _supported_codepoint_ranges, - _internal_tokenizer_codepoint_ranges, - _min_supported_codepoint_ratio, - _feature_version, - _tokenization_type, - _icu_preserve_whitespace_tokens, - _ignored_span_boundary_codepoints, - _bounds_sensitive_features, - _allowed_chargrams, - _tokenize_on_script_change); -} - -inline const libtextclassifier2::Model *GetModel(const void *buf) { - return flatbuffers::GetRoot<libtextclassifier2::Model>(buf); -} - -inline const char *ModelIdentifier() { - return "TC2 "; -} - -inline bool ModelBufferHasIdentifier(const void *buf) { - return flatbuffers::BufferHasIdentifier( - buf, ModelIdentifier()); -} - -inline bool VerifyModelBuffer( - flatbuffers::Verifier &verifier) { - return verifier.VerifyBuffer<libtextclassifier2::Model>(ModelIdentifier()); -} - -inline void FinishModelBuffer( - flatbuffers::FlatBufferBuilder &fbb, - flatbuffers::Offset<libtextclassifier2::Model> root) { - fbb.Finish(root, ModelIdentifier()); -} - -inline std::unique_ptr<ModelT> UnPackModel( - const void *buf, - const flatbuffers::resolver_function_t *res = nullptr) { - return std::unique_ptr<ModelT>(GetModel(buf)->UnPack(res)); -} - -} // namespace libtextclassifier2 - -#endif // FLATBUFFERS_GENERATED_MODEL_LIBTEXTCLASSIFIER2_H_ diff --git a/models/textclassifier.ar.model b/models/textclassifier.ar.model Binary files differindex 2342daa..4153026 100644 --- a/models/textclassifier.ar.model +++ b/models/textclassifier.ar.model diff --git a/models/textclassifier.en.model b/models/textclassifier.en.model Binary files differindex a40f940..887d1df 100644 --- a/models/textclassifier.en.model +++ b/models/textclassifier.en.model diff --git a/models/textclassifier.es.model b/models/textclassifier.es.model Binary files differindex 7de4e5d..2093b41 100644 --- a/models/textclassifier.es.model +++ b/models/textclassifier.es.model diff --git a/models/textclassifier.fr.model b/models/textclassifier.fr.model Binary files differindex 1072041..b54345b 100644 --- a/models/textclassifier.fr.model +++ b/models/textclassifier.fr.model diff --git a/models/textclassifier.it.model b/models/textclassifier.it.model Binary files differindex 5bc98ae..e05d2db 100644 --- a/models/textclassifier.it.model +++ b/models/textclassifier.it.model diff --git a/models/textclassifier.ja.model b/models/textclassifier.ja.model Binary files differindex 9f60b8a..de10271 100644 --- a/models/textclassifier.ja.model +++ b/models/textclassifier.ja.model diff --git a/models/textclassifier.ko.model b/models/textclassifier.ko.model Binary files differindex 451df45..00d1bf3 100644 --- a/models/textclassifier.ko.model +++ b/models/textclassifier.ko.model diff --git a/models/textclassifier.nl.model b/models/textclassifier.nl.model Binary files differindex 07ea076..a733938 100644 --- a/models/textclassifier.nl.model +++ b/models/textclassifier.nl.model diff --git a/models/textclassifier.pl.model b/models/textclassifier.pl.model Binary files differindex 6cf62a5..3947dc2 100644 --- a/models/textclassifier.pl.model +++ b/models/textclassifier.pl.model diff --git a/models/textclassifier.pt.model b/models/textclassifier.pt.model Binary files differindex a745d58..b7bb298 100644 --- a/models/textclassifier.pt.model +++ b/models/textclassifier.pt.model diff --git a/models/textclassifier.ru.model b/models/textclassifier.ru.model Binary files differindex aa97ebc..377f73f 100644 --- a/models/textclassifier.ru.model +++ b/models/textclassifier.ru.model diff --git a/models/textclassifier.th.model b/models/textclassifier.th.model Binary files differindex 37339b7..41a3a3b 100644 --- a/models/textclassifier.th.model +++ b/models/textclassifier.th.model diff --git a/models/textclassifier.tr.model b/models/textclassifier.tr.model Binary files differindex 2405d9e..e284388 100644 --- a/models/textclassifier.tr.model +++ b/models/textclassifier.tr.model diff --git a/models/textclassifier.universal.model b/models/textclassifier.universal.model Binary files differindex 5c4220f..7856747 100644 --- a/models/textclassifier.universal.model +++ b/models/textclassifier.universal.model diff --git a/models/textclassifier.zh-Hant.model b/models/textclassifier.zh-Hant.model Binary files differindex 32edfe4..dd04f09 100644 --- a/models/textclassifier.zh-Hant.model +++ b/models/textclassifier.zh-Hant.model diff --git a/models/textclassifier.zh.model b/models/textclassifier.zh.model Binary files differindex eb1ff61..4e5f525 100644 --- a/models/textclassifier.zh.model +++ b/models/textclassifier.zh.model diff --git a/test_data/wrong_embeddings.fb b/test_data/wrong_embeddings.fb Binary files differdeleted file mode 100644 index e1aa3ea..0000000 --- a/test_data/wrong_embeddings.fb +++ /dev/null diff --git a/textclassifier_jni.cc b/textclassifier_jni.cc deleted file mode 100644 index 29cf745..0000000 --- a/textclassifier_jni.cc +++ /dev/null @@ -1,496 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// JNI wrapper for the TextClassifier. - -#include "textclassifier_jni.h" - -#include <jni.h> -#include <type_traits> -#include <vector> - -#include "text-classifier.h" -#include "util/base/integral_types.h" -#include "util/java/scoped_local_ref.h" -#include "util/java/string_utils.h" -#include "util/memory/mmap.h" -#include "util/utf8/unilib.h" - -using libtextclassifier2::AnnotatedSpan; -using libtextclassifier2::AnnotationOptions; -using libtextclassifier2::ClassificationOptions; -using libtextclassifier2::ClassificationResult; -using libtextclassifier2::CodepointSpan; -using libtextclassifier2::JStringToUtf8String; -using libtextclassifier2::Model; -using libtextclassifier2::ScopedLocalRef; -using libtextclassifier2::SelectionOptions; -using libtextclassifier2::TextClassifier; -#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU -using libtextclassifier2::UniLib; -#endif - -namespace libtextclassifier2 { - -using libtextclassifier2::CodepointSpan; - -namespace { - -std::string ToStlString(JNIEnv* env, const jstring& str) { - std::string result; - JStringToUtf8String(env, str, &result); - return result; -} - -jobjectArray ClassificationResultsToJObjectArray( - JNIEnv* env, - const std::vector<ClassificationResult>& classification_result) { - const ScopedLocalRef<jclass> result_class( - env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult"), - env); - if (!result_class) { - TC_LOG(ERROR) << "Couldn't find ClassificationResult class."; - return nullptr; - } - const ScopedLocalRef<jclass> datetime_parse_class( - env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$DatetimeResult"), env); - if (!datetime_parse_class) { - TC_LOG(ERROR) << "Couldn't find DatetimeResult class."; - return nullptr; - } - - const jmethodID result_class_constructor = - env->GetMethodID(result_class.get(), "<init>", - "(Ljava/lang/String;FL" TC_PACKAGE_PATH TC_CLASS_NAME_STR - "$DatetimeResult;)V"); - const jmethodID datetime_parse_class_constructor = - env->GetMethodID(datetime_parse_class.get(), "<init>", "(JI)V"); - - const jobjectArray results = env->NewObjectArray(classification_result.size(), - result_class.get(), nullptr); - for (int i = 0; i < classification_result.size(); i++) { - jstring row_string = - env->NewStringUTF(classification_result[i].collection.c_str()); - jobject row_datetime_parse = nullptr; - if (classification_result[i].datetime_parse_result.IsSet()) { - row_datetime_parse = env->NewObject( - datetime_parse_class.get(), datetime_parse_class_constructor, - classification_result[i].datetime_parse_result.time_ms_utc, - classification_result[i].datetime_parse_result.granularity); - } - jobject result = - env->NewObject(result_class.get(), result_class_constructor, row_string, - static_cast<jfloat>(classification_result[i].score), - row_datetime_parse); - env->SetObjectArrayElement(results, i, result); - env->DeleteLocalRef(result); - } - return results; -} - -template <typename T, typename F> -std::pair<bool, T> CallJniMethod0(JNIEnv* env, jobject object, - jclass class_object, F function, - const std::string& method_name, - const std::string& return_java_type) { - const jmethodID method = env->GetMethodID(class_object, method_name.c_str(), - ("()" + return_java_type).c_str()); - if (!method) { - return std::make_pair(false, T()); - } - return std::make_pair(true, (env->*function)(object, method)); -} - -SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions) { - if (!joptions) { - return {}; - } - - const ScopedLocalRef<jclass> options_class( - env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$SelectionOptions"), - env); - const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>( - env, joptions, options_class.get(), &JNIEnv::CallObjectMethod, - "getLocales", "Ljava/lang/String;"); - if (!status_or_locales.first) { - return {}; - } - - SelectionOptions options; - options.locales = - ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second)); - - return options; -} - -template <typename T> -T FromJavaOptionsInternal(JNIEnv* env, jobject joptions, - const std::string& class_name) { - if (!joptions) { - return {}; - } - - const ScopedLocalRef<jclass> options_class(env->FindClass(class_name.c_str()), - env); - if (!options_class) { - return {}; - } - - const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>( - env, joptions, options_class.get(), &JNIEnv::CallObjectMethod, - "getLocale", "Ljava/lang/String;"); - const std::pair<bool, jobject> status_or_reference_timezone = - CallJniMethod0<jobject>(env, joptions, options_class.get(), - &JNIEnv::CallObjectMethod, "getReferenceTimezone", - "Ljava/lang/String;"); - const std::pair<bool, int64> status_or_reference_time_ms_utc = - CallJniMethod0<int64>(env, joptions, options_class.get(), - &JNIEnv::CallLongMethod, "getReferenceTimeMsUtc", - "J"); - - if (!status_or_locales.first || !status_or_reference_timezone.first || - !status_or_reference_time_ms_utc.first) { - return {}; - } - - T options; - options.locales = - ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second)); - options.reference_timezone = ToStlString( - env, reinterpret_cast<jstring>(status_or_reference_timezone.second)); - options.reference_time_ms_utc = status_or_reference_time_ms_utc.second; - return options; -} - -ClassificationOptions FromJavaClassificationOptions(JNIEnv* env, - jobject joptions) { - return FromJavaOptionsInternal<ClassificationOptions>( - env, joptions, - TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationOptions"); -} - -AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions) { - return FromJavaOptionsInternal<AnnotationOptions>( - env, joptions, TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotationOptions"); -} - -CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str, - CodepointSpan orig_indices, - bool from_utf8) { - const libtextclassifier2::UnicodeText unicode_str = - libtextclassifier2::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false); - - int unicode_index = 0; - int bmp_index = 0; - - const int* source_index; - const int* target_index; - if (from_utf8) { - source_index = &unicode_index; - target_index = &bmp_index; - } else { - source_index = &bmp_index; - target_index = &unicode_index; - } - - CodepointSpan result{-1, -1}; - std::function<void()> assign_indices_fn = [&result, &orig_indices, - &source_index, &target_index]() { - if (orig_indices.first == *source_index) { - result.first = *target_index; - } - - if (orig_indices.second == *source_index) { - result.second = *target_index; - } - }; - - for (auto it = unicode_str.begin(); it != unicode_str.end(); - ++it, ++unicode_index, ++bmp_index) { - assign_indices_fn(); - - // There is 1 extra character in the input for each UTF8 character > 0xFFFF. - if (*it > 0xFFFF) { - ++bmp_index; - } - } - assign_indices_fn(); - - return result; -} - -} // namespace - -CodepointSpan ConvertIndicesBMPToUTF8(const std::string& utf8_str, - CodepointSpan bmp_indices) { - return ConvertIndicesBMPUTF8(utf8_str, bmp_indices, /*from_utf8=*/false); -} - -CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str, - CodepointSpan utf8_indices) { - return ConvertIndicesBMPUTF8(utf8_str, utf8_indices, /*from_utf8=*/true); -} - -jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd) { - // Get system-level file descriptor from AssetFileDescriptor. - ScopedLocalRef<jclass> afd_class( - env->FindClass("android/content/res/AssetFileDescriptor"), env); - if (afd_class == nullptr) { - TC_LOG(ERROR) << "Couldn't find AssetFileDescriptor."; - return reinterpret_cast<jlong>(nullptr); - } - jmethodID afd_class_getFileDescriptor = env->GetMethodID( - afd_class.get(), "getFileDescriptor", "()Ljava/io/FileDescriptor;"); - if (afd_class_getFileDescriptor == nullptr) { - TC_LOG(ERROR) << "Couldn't find getFileDescriptor."; - return reinterpret_cast<jlong>(nullptr); - } - - ScopedLocalRef<jclass> fd_class(env->FindClass("java/io/FileDescriptor"), - env); - if (fd_class == nullptr) { - TC_LOG(ERROR) << "Couldn't find FileDescriptor."; - return reinterpret_cast<jlong>(nullptr); - } - jfieldID fd_class_descriptor = - env->GetFieldID(fd_class.get(), "descriptor", "I"); - if (fd_class_descriptor == nullptr) { - TC_LOG(ERROR) << "Couldn't find descriptor."; - return reinterpret_cast<jlong>(nullptr); - } - - jobject bundle_jfd = env->CallObjectMethod(afd, afd_class_getFileDescriptor); - return env->GetIntField(bundle_jfd, fd_class_descriptor); -} - -jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) { - if (!mmap->handle().ok()) { - return env->NewStringUTF(""); - } - const Model* model = libtextclassifier2::ViewModel( - mmap->handle().start(), mmap->handle().num_bytes()); - if (!model || !model->locales()) { - return env->NewStringUTF(""); - } - return env->NewStringUTF(model->locales()->c_str()); -} - -jint GetVersionFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) { - if (!mmap->handle().ok()) { - return 0; - } - const Model* model = libtextclassifier2::ViewModel( - mmap->handle().start(), mmap->handle().num_bytes()); - if (!model) { - return 0; - } - return model->version(); -} - -jstring GetNameFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) { - if (!mmap->handle().ok()) { - return env->NewStringUTF(""); - } - const Model* model = libtextclassifier2::ViewModel( - mmap->handle().start(), mmap->handle().num_bytes()); - if (!model || !model->name()) { - return env->NewStringUTF(""); - } - return env->NewStringUTF(model->name()->c_str()); -} - -} // namespace libtextclassifier2 - -using libtextclassifier2::ClassificationResultsToJObjectArray; -using libtextclassifier2::ConvertIndicesBMPToUTF8; -using libtextclassifier2::ConvertIndicesUTF8ToBMP; -using libtextclassifier2::FromJavaAnnotationOptions; -using libtextclassifier2::FromJavaClassificationOptions; -using libtextclassifier2::FromJavaSelectionOptions; -using libtextclassifier2::ToStlString; - -JNI_METHOD(jlong, TC_CLASS_NAME, nativeNew) -(JNIEnv* env, jobject thiz, jint fd) { -#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU - return reinterpret_cast<jlong>( - TextClassifier::FromFileDescriptor(fd).release(), new UniLib(env)); -#else - return reinterpret_cast<jlong>( - TextClassifier::FromFileDescriptor(fd).release()); -#endif -} - -JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromPath) -(JNIEnv* env, jobject thiz, jstring path) { - const std::string path_str = ToStlString(env, path); -#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU - return reinterpret_cast<jlong>( - TextClassifier::FromPath(path_str, new UniLib(env)).release()); -#else - return reinterpret_cast<jlong>(TextClassifier::FromPath(path_str).release()); -#endif -} - -JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromAssetFileDescriptor) -(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { - const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd); -#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU - return reinterpret_cast<jlong>( - TextClassifier::FromFileDescriptor(fd, offset, size, new UniLib(env)) - .release()); -#else - return reinterpret_cast<jlong>( - TextClassifier::FromFileDescriptor(fd, offset, size).release()); -#endif -} - -JNI_METHOD(jintArray, TC_CLASS_NAME, nativeSuggestSelection) -(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin, - jint selection_end, jobject options) { - if (!ptr) { - return nullptr; - } - - TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr); - - const std::string context_utf8 = ToStlString(env, context); - CodepointSpan input_indices = - ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end}); - CodepointSpan selection = model->SuggestSelection( - context_utf8, input_indices, FromJavaSelectionOptions(env, options)); - selection = ConvertIndicesUTF8ToBMP(context_utf8, selection); - - jintArray result = env->NewIntArray(2); - env->SetIntArrayRegion(result, 0, 1, &(std::get<0>(selection))); - env->SetIntArrayRegion(result, 1, 1, &(std::get<1>(selection))); - return result; -} - -JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeClassifyText) -(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin, - jint selection_end, jobject options) { - if (!ptr) { - return nullptr; - } - TextClassifier* ff_model = reinterpret_cast<TextClassifier*>(ptr); - - const std::string context_utf8 = ToStlString(env, context); - const CodepointSpan input_indices = - ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end}); - const std::vector<ClassificationResult> classification_result = - ff_model->ClassifyText(context_utf8, input_indices, - FromJavaClassificationOptions(env, options)); - - return ClassificationResultsToJObjectArray(env, classification_result); -} - -JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeAnnotate) -(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options) { - if (!ptr) { - return nullptr; - } - TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr); - std::string context_utf8 = ToStlString(env, context); - std::vector<AnnotatedSpan> annotations = - model->Annotate(context_utf8, FromJavaAnnotationOptions(env, options)); - - jclass result_class = - env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotatedSpan"); - if (!result_class) { - TC_LOG(ERROR) << "Couldn't find result class: " - << TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotatedSpan"; - return nullptr; - } - - jmethodID result_class_constructor = env->GetMethodID( - result_class, "<init>", - "(II[L" TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult;)V"); - - jobjectArray results = - env->NewObjectArray(annotations.size(), result_class, nullptr); - - for (int i = 0; i < annotations.size(); ++i) { - CodepointSpan span_bmp = - ConvertIndicesUTF8ToBMP(context_utf8, annotations[i].span); - jobject result = env->NewObject( - result_class, result_class_constructor, - static_cast<jint>(span_bmp.first), static_cast<jint>(span_bmp.second), - ClassificationResultsToJObjectArray(env, - - annotations[i].classification)); - env->SetObjectArrayElement(results, i, result); - env->DeleteLocalRef(result); - } - env->DeleteLocalRef(result_class); - return results; -} - -JNI_METHOD(void, TC_CLASS_NAME, nativeClose) -(JNIEnv* env, jobject thiz, jlong ptr) { - TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr); - delete model; -} - -JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLanguage) -(JNIEnv* env, jobject clazz, jint fd) { - TC_LOG(WARNING) << "Using deprecated getLanguage()."; - return JNI_METHOD_NAME(TC_CLASS_NAME, nativeGetLocales)(env, clazz, fd); -} - -JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocales) -(JNIEnv* env, jobject clazz, jint fd) { - const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( - new libtextclassifier2::ScopedMmap(fd)); - return GetLocalesFromMmap(env, mmap.get()); -} - -JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocalesFromAssetFileDescriptor) -(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { - const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd); - const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( - new libtextclassifier2::ScopedMmap(fd, offset, size)); - return GetLocalesFromMmap(env, mmap.get()); -} - -JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersion) -(JNIEnv* env, jobject clazz, jint fd) { - const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( - new libtextclassifier2::ScopedMmap(fd)); - return GetVersionFromMmap(env, mmap.get()); -} - -JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersionFromAssetFileDescriptor) -(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { - const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd); - const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( - new libtextclassifier2::ScopedMmap(fd, offset, size)); - return GetVersionFromMmap(env, mmap.get()); -} - -JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetName) -(JNIEnv* env, jobject clazz, jint fd) { - const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( - new libtextclassifier2::ScopedMmap(fd)); - return GetNameFromMmap(env, mmap.get()); -} - -JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetNameFromAssetFileDescriptor) -(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { - const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd); - const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( - new libtextclassifier2::ScopedMmap(fd, offset, size)); - return GetNameFromMmap(env, mmap.get()); -} diff --git a/textclassifier_jni.h b/textclassifier_jni.h deleted file mode 100644 index d6e742e..0000000 --- a/textclassifier_jni.h +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LIBTEXTCLASSIFIER_TEXTCLASSIFIER_JNI_H_ -#define LIBTEXTCLASSIFIER_TEXTCLASSIFIER_JNI_H_ - -#include <jni.h> -#include <string> - -#include "types.h" - -// When we use a macro as an argument for a macro, an additional level of -// indirection is needed, if the macro argument is used with # or ##. -#define ADD_QUOTES_HELPER(TOKEN) #TOKEN -#define ADD_QUOTES(TOKEN) ADD_QUOTES_HELPER(TOKEN) - -#ifndef TC_PACKAGE_NAME -#define TC_PACKAGE_NAME android_view_textclassifier -#endif - -#ifndef TC_CLASS_NAME -#define TC_CLASS_NAME TextClassifierImplNative -#endif -#define TC_CLASS_NAME_STR ADD_QUOTES(TC_CLASS_NAME) - -#ifndef TC_PACKAGE_PATH -#define TC_PACKAGE_PATH "android/view/textclassifier/" -#endif - -#define JNI_METHOD_NAME_INTERNAL(package_name, class_name, method_name) \ - Java_##package_name##_##class_name##_##method_name - -#define JNI_METHOD_PRIMITIVE(return_type, package_name, class_name, \ - method_name) \ - JNIEXPORT return_type JNICALL JNI_METHOD_NAME_INTERNAL( \ - package_name, class_name, method_name) - -// The indirection is needed to correctly expand the TC_PACKAGE_NAME macro. -// See the explanation near ADD_QUOTES macro. -#define JNI_METHOD2(return_type, package_name, class_name, method_name) \ - JNI_METHOD_PRIMITIVE(return_type, package_name, class_name, method_name) - -#define JNI_METHOD(return_type, class_name, method_name) \ - JNI_METHOD2(return_type, TC_PACKAGE_NAME, class_name, method_name) - -#define JNI_METHOD_NAME2(package_name, class_name, method_name) \ - JNI_METHOD_NAME_INTERNAL(package_name, class_name, method_name) - -#define JNI_METHOD_NAME(class_name, method_name) \ - JNI_METHOD_NAME2(TC_PACKAGE_NAME, class_name, method_name) - -#ifdef __cplusplus -extern "C" { -#endif - -// SmartSelection. -JNI_METHOD(jlong, TC_CLASS_NAME, nativeNew) -(JNIEnv* env, jobject thiz, jint fd); - -JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromPath) -(JNIEnv* env, jobject thiz, jstring path); - -JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromAssetFileDescriptor) -(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size); - -JNI_METHOD(jintArray, TC_CLASS_NAME, nativeSuggestSelection) -(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin, - jint selection_end, jobject options); - -JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeClassifyText) -(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin, - jint selection_end, jobject options); - -JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeAnnotate) -(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options); - -JNI_METHOD(void, TC_CLASS_NAME, nativeClose) -(JNIEnv* env, jobject thiz, jlong ptr); - -// DEPRECATED. Use nativeGetLocales instead. -JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLanguage) -(JNIEnv* env, jobject clazz, jint fd); - -JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocales) -(JNIEnv* env, jobject clazz, jint fd); - -JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocalesFromAssetFileDescriptor) -(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size); - -JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersion) -(JNIEnv* env, jobject clazz, jint fd); - -JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersionFromAssetFileDescriptor) -(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size); - -JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetName) -(JNIEnv* env, jobject clazz, jint fd); - -JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetNameFromAssetFileDescriptor) -(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size); - -#ifdef __cplusplus -} -#endif - -namespace libtextclassifier2 { - -// Given a utf8 string and a span expressed in Java BMP (basic multilingual -// plane) codepoints, converts it to a span expressed in utf8 codepoints. -libtextclassifier2::CodepointSpan ConvertIndicesBMPToUTF8( - const std::string& utf8_str, libtextclassifier2::CodepointSpan bmp_indices); - -// Given a utf8 string and a span expressed in utf8 codepoints, converts it to a -// span expressed in Java BMP (basic multilingual plane) codepoints. -libtextclassifier2::CodepointSpan ConvertIndicesUTF8ToBMP( - const std::string& utf8_str, - libtextclassifier2::CodepointSpan utf8_indices); - -} // namespace libtextclassifier2 - -#endif // LIBTEXTCLASSIFIER_TEXTCLASSIFIER_JNI_H_ diff --git a/util/calendar/calendar-icu.cc b/util/calendar/calendar-icu.cc deleted file mode 100644 index 34ea22d..0000000 --- a/util/calendar/calendar-icu.cc +++ /dev/null @@ -1,436 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "util/calendar/calendar-icu.h" - -#include <memory> - -#include "util/base/macros.h" -#include "unicode/gregocal.h" -#include "unicode/timezone.h" -#include "unicode/ucal.h" - -namespace libtextclassifier2 { -namespace { -int MapToDayOfWeekOrDefault(int relation_type, int default_value) { - switch (relation_type) { - case DateParseData::MONDAY: - return UCalendarDaysOfWeek::UCAL_MONDAY; - case DateParseData::TUESDAY: - return UCalendarDaysOfWeek::UCAL_TUESDAY; - case DateParseData::WEDNESDAY: - return UCalendarDaysOfWeek::UCAL_WEDNESDAY; - case DateParseData::THURSDAY: - return UCalendarDaysOfWeek::UCAL_THURSDAY; - case DateParseData::FRIDAY: - return UCalendarDaysOfWeek::UCAL_FRIDAY; - case DateParseData::SATURDAY: - return UCalendarDaysOfWeek::UCAL_SATURDAY; - case DateParseData::SUNDAY: - return UCalendarDaysOfWeek::UCAL_SUNDAY; - default: - return default_value; - } -} - -bool DispatchToRecedeOrToLastDayOfWeek(icu::Calendar* date, int relation_type, - int distance) { - UErrorCode status = U_ZERO_ERROR; - switch (relation_type) { - case DateParseData::MONDAY: - case DateParseData::TUESDAY: - case DateParseData::WEDNESDAY: - case DateParseData::THURSDAY: - case DateParseData::FRIDAY: - case DateParseData::SATURDAY: - case DateParseData::SUNDAY: - for (int i = 0; i < distance; i++) { - do { - if (U_FAILURE(status)) { - TC_LOG(ERROR) << "error day of week"; - return false; - } - date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1, status); - if (U_FAILURE(status)) { - TC_LOG(ERROR) << "error adding a day"; - return false; - } - } while (date->get(UCalendarDateFields::UCAL_DAY_OF_WEEK, status) != - MapToDayOfWeekOrDefault(relation_type, 1)); - } - return true; - case DateParseData::DAY: - date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, -1 * distance, status); - if (U_FAILURE(status)) { - TC_LOG(ERROR) << "error adding a day"; - return false; - } - - return true; - case DateParseData::WEEK: - date->set(UCalendarDateFields::UCAL_DAY_OF_WEEK, 1); - date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, -7 * (distance - 1), - status); - if (U_FAILURE(status)) { - TC_LOG(ERROR) << "error adding a week"; - return false; - } - - return true; - case DateParseData::MONTH: - date->set(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1); - date->add(UCalendarDateFields::UCAL_MONTH, -1 * (distance - 1), status); - if (U_FAILURE(status)) { - TC_LOG(ERROR) << "error adding a month"; - return false; - } - return true; - case DateParseData::YEAR: - date->set(UCalendarDateFields::UCAL_DAY_OF_YEAR, 1); - date->add(UCalendarDateFields::UCAL_YEAR, -1 * (distance - 1), status); - if (U_FAILURE(status)) { - TC_LOG(ERROR) << "error adding a year"; - - return true; - default: - return false; - } - return false; - } -} - -bool DispatchToAdvancerOrToNextOrSameDayOfWeek(icu::Calendar* date, - int relation_type) { - UErrorCode status = U_ZERO_ERROR; - switch (relation_type) { - case DateParseData::MONDAY: - case DateParseData::TUESDAY: - case DateParseData::WEDNESDAY: - case DateParseData::THURSDAY: - case DateParseData::FRIDAY: - case DateParseData::SATURDAY: - case DateParseData::SUNDAY: - while (date->get(UCalendarDateFields::UCAL_DAY_OF_WEEK, status) != - MapToDayOfWeekOrDefault(relation_type, 1)) { - if (U_FAILURE(status)) { - TC_LOG(ERROR) << "error day of week"; - return false; - } - date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1, status); - if (U_FAILURE(status)) { - TC_LOG(ERROR) << "error adding a day"; - return false; - } - } - return true; - case DateParseData::DAY: - date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1, status); - if (U_FAILURE(status)) { - TC_LOG(ERROR) << "error adding a day"; - return false; - } - - return true; - case DateParseData::WEEK: - date->set(UCalendarDateFields::UCAL_DAY_OF_WEEK, 1); - date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, 7, status); - if (U_FAILURE(status)) { - TC_LOG(ERROR) << "error adding a week"; - return false; - } - - return true; - case DateParseData::MONTH: - date->set(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1); - date->add(UCalendarDateFields::UCAL_MONTH, 1, status); - if (U_FAILURE(status)) { - TC_LOG(ERROR) << "error adding a month"; - return false; - } - return true; - case DateParseData::YEAR: - date->set(UCalendarDateFields::UCAL_DAY_OF_YEAR, 1); - date->add(UCalendarDateFields::UCAL_YEAR, 1, status); - if (U_FAILURE(status)) { - TC_LOG(ERROR) << "error adding a year"; - - return true; - default: - return false; - } - return false; - } -} - -bool DispatchToAdvancerOrToNextDayOfWeek(icu::Calendar* date, int relation_type, - int distance) { - UErrorCode status = U_ZERO_ERROR; - switch (relation_type) { - case DateParseData::MONDAY: - case DateParseData::TUESDAY: - case DateParseData::WEDNESDAY: - case DateParseData::THURSDAY: - case DateParseData::FRIDAY: - case DateParseData::SATURDAY: - case DateParseData::SUNDAY: - for (int i = 0; i < distance; i++) { - do { - if (U_FAILURE(status)) { - TC_LOG(ERROR) << "error day of week"; - return false; - } - date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1, status); - if (U_FAILURE(status)) { - TC_LOG(ERROR) << "error adding a day"; - return false; - } - } while (date->get(UCalendarDateFields::UCAL_DAY_OF_WEEK, status) != - MapToDayOfWeekOrDefault(relation_type, 1)); - } - return true; - case DateParseData::DAY: - date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, distance, status); - if (U_FAILURE(status)) { - TC_LOG(ERROR) << "error adding a day"; - return false; - } - - return true; - case DateParseData::WEEK: - date->set(UCalendarDateFields::UCAL_DAY_OF_WEEK, 1); - date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, 7 * distance, status); - if (U_FAILURE(status)) { - TC_LOG(ERROR) << "error adding a week"; - return false; - } - - return true; - case DateParseData::MONTH: - date->set(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1); - date->add(UCalendarDateFields::UCAL_MONTH, 1 * distance, status); - if (U_FAILURE(status)) { - TC_LOG(ERROR) << "error adding a month"; - return false; - } - return true; - case DateParseData::YEAR: - date->set(UCalendarDateFields::UCAL_DAY_OF_YEAR, 1); - date->add(UCalendarDateFields::UCAL_YEAR, 1 * distance, status); - if (U_FAILURE(status)) { - TC_LOG(ERROR) << "error adding a year"; - - return true; - default: - return false; - } - return false; - } -} - -bool RoundToGranularity(DatetimeGranularity granularity, - icu::Calendar* calendar) { - // Force recomputation before doing the rounding. - UErrorCode status = U_ZERO_ERROR; - calendar->get(UCalendarDateFields::UCAL_DAY_OF_WEEK, status); - if (U_FAILURE(status)) { - TC_LOG(ERROR) << "Can't interpret date."; - return false; - } - - switch (granularity) { - case GRANULARITY_YEAR: - calendar->set(UCalendarDateFields::UCAL_MONTH, 0); - TC_FALLTHROUGH_INTENDED; - case GRANULARITY_MONTH: - calendar->set(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1); - TC_FALLTHROUGH_INTENDED; - case GRANULARITY_DAY: - calendar->set(UCalendarDateFields::UCAL_HOUR, 0); - TC_FALLTHROUGH_INTENDED; - case GRANULARITY_HOUR: - calendar->set(UCalendarDateFields::UCAL_MINUTE, 0); - TC_FALLTHROUGH_INTENDED; - case GRANULARITY_MINUTE: - calendar->set(UCalendarDateFields::UCAL_SECOND, 0); - break; - - case GRANULARITY_WEEK: - calendar->set(UCalendarDateFields::UCAL_DAY_OF_WEEK, - calendar->getFirstDayOfWeek()); - calendar->set(UCalendarDateFields::UCAL_HOUR, 0); - calendar->set(UCalendarDateFields::UCAL_MINUTE, 0); - calendar->set(UCalendarDateFields::UCAL_SECOND, 0); - break; - - case GRANULARITY_UNKNOWN: - case GRANULARITY_SECOND: - break; - } - - return true; -} - -} // namespace - -bool CalendarLib::InterpretParseData(const DateParseData& parse_data, - int64 reference_time_ms_utc, - const std::string& reference_timezone, - const std::string& reference_locale, - DatetimeGranularity granularity, - int64* interpreted_time_ms_utc) const { - UErrorCode status = U_ZERO_ERROR; - - std::unique_ptr<icu::Calendar> date(icu::Calendar::createInstance( - icu::Locale::createFromName(reference_locale.c_str()), status)); - if (U_FAILURE(status)) { - TC_LOG(ERROR) << "error getting calendar instance"; - return false; - } - - date->adoptTimeZone(icu::TimeZone::createTimeZone( - icu::UnicodeString::fromUTF8(reference_timezone))); - date->setTime(reference_time_ms_utc, status); - - // By default, the parsed time is interpreted to be on the reference day. But - // a parsed date, should have time 0:00:00 unless specified. - date->set(UCalendarDateFields::UCAL_HOUR_OF_DAY, 0); - date->set(UCalendarDateFields::UCAL_MINUTE, 0); - date->set(UCalendarDateFields::UCAL_SECOND, 0); - date->set(UCalendarDateFields::UCAL_MILLISECOND, 0); - - static const int64 kMillisInHour = 1000 * 60 * 60; - if (parse_data.field_set_mask & DateParseData::Fields::ZONE_OFFSET_FIELD) { - date->set(UCalendarDateFields::UCAL_ZONE_OFFSET, - parse_data.zone_offset * kMillisInHour); - } - if (parse_data.field_set_mask & DateParseData::Fields::DST_OFFSET_FIELD) { - // convert from hours to milliseconds - date->set(UCalendarDateFields::UCAL_DST_OFFSET, - parse_data.dst_offset * kMillisInHour); - } - - if (parse_data.field_set_mask & DateParseData::Fields::RELATION_FIELD) { - switch (parse_data.relation) { - case DateParseData::Relation::NEXT: - if (parse_data.field_set_mask & - DateParseData::Fields::RELATION_TYPE_FIELD) { - if (!DispatchToAdvancerOrToNextDayOfWeek( - date.get(), parse_data.relation_type, 1)) { - return false; - } - } - break; - case DateParseData::Relation::NEXT_OR_SAME: - if (parse_data.field_set_mask & - DateParseData::Fields::RELATION_TYPE_FIELD) { - if (!DispatchToAdvancerOrToNextOrSameDayOfWeek( - date.get(), parse_data.relation_type)) { - return false; - } - } - break; - case DateParseData::Relation::LAST: - if (parse_data.field_set_mask & - DateParseData::Fields::RELATION_TYPE_FIELD) { - if (!DispatchToRecedeOrToLastDayOfWeek(date.get(), - parse_data.relation_type, 1)) { - return false; - } - } - break; - case DateParseData::Relation::NOW: - // NOOP - break; - case DateParseData::Relation::TOMORROW: - date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1, status); - if (U_FAILURE(status)) { - TC_LOG(ERROR) << "error adding a day"; - return false; - } - break; - case DateParseData::Relation::YESTERDAY: - date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, -1, status); - if (U_FAILURE(status)) { - TC_LOG(ERROR) << "error subtracting a day"; - return false; - } - break; - case DateParseData::Relation::PAST: - if (parse_data.field_set_mask & - DateParseData::Fields::RELATION_TYPE_FIELD) { - if (parse_data.field_set_mask & - DateParseData::Fields::RELATION_DISTANCE_FIELD) { - if (!DispatchToRecedeOrToLastDayOfWeek( - date.get(), parse_data.relation_type, - parse_data.relation_distance)) { - return false; - } - } - } - break; - case DateParseData::Relation::FUTURE: - if (parse_data.field_set_mask & - DateParseData::Fields::RELATION_TYPE_FIELD) { - if (parse_data.field_set_mask & - DateParseData::Fields::RELATION_DISTANCE_FIELD) { - if (!DispatchToAdvancerOrToNextDayOfWeek( - date.get(), parse_data.relation_type, - parse_data.relation_distance)) { - return false; - } - } - } - break; - } - } - if (parse_data.field_set_mask & DateParseData::Fields::YEAR_FIELD) { - date->set(UCalendarDateFields::UCAL_YEAR, parse_data.year); - } - if (parse_data.field_set_mask & DateParseData::Fields::MONTH_FIELD) { - // NOTE: Java and ICU disagree on month formats - date->set(UCalendarDateFields::UCAL_MONTH, parse_data.month - 1); - } - if (parse_data.field_set_mask & DateParseData::Fields::DAY_FIELD) { - date->set(UCalendarDateFields::UCAL_DAY_OF_MONTH, parse_data.day_of_month); - } - if (parse_data.field_set_mask & DateParseData::Fields::HOUR_FIELD) { - if (parse_data.field_set_mask & DateParseData::Fields::AMPM_FIELD && - parse_data.ampm == 1 && parse_data.hour < 12) { - date->set(UCalendarDateFields::UCAL_HOUR_OF_DAY, parse_data.hour + 12); - } else { - date->set(UCalendarDateFields::UCAL_HOUR_OF_DAY, parse_data.hour); - } - } - if (parse_data.field_set_mask & DateParseData::Fields::MINUTE_FIELD) { - date->set(UCalendarDateFields::UCAL_MINUTE, parse_data.minute); - } - if (parse_data.field_set_mask & DateParseData::Fields::SECOND_FIELD) { - date->set(UCalendarDateFields::UCAL_SECOND, parse_data.second); - } - - if (!RoundToGranularity(granularity, date.get())) { - return false; - } - - *interpreted_time_ms_utc = date->getTime(status); - if (U_FAILURE(status)) { - TC_LOG(ERROR) << "error getting time from instance"; - return false; - } - - return true; -} -} // namespace libtextclassifier2 diff --git a/util/calendar/calendar-icu.h b/util/calendar/calendar-icu.h deleted file mode 100644 index 8aae7ab..0000000 --- a/util/calendar/calendar-icu.h +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LIBTEXTCLASSIFIER_UTIL_CALENDAR_CALENDAR_ICU_H_ -#define LIBTEXTCLASSIFIER_UTIL_CALENDAR_CALENDAR_ICU_H_ - -#include <string> - -#include "types.h" -#include "util/base/integral_types.h" -#include "util/base/logging.h" - -namespace libtextclassifier2 { - -class CalendarLib { - public: - // Interprets parse_data as milliseconds since_epoch. Relative times are - // resolved against the current time (reference_time_ms_utc). Returns true if - // the interpratation was successful, false otherwise. - bool InterpretParseData(const DateParseData& parse_data, - int64 reference_time_ms_utc, - const std::string& reference_timezone, - const std::string& reference_locale, - DatetimeGranularity granularity, - int64* interpreted_time_ms_utc) const; -}; -} // namespace libtextclassifier2 -#endif // LIBTEXTCLASSIFIER_UTIL_CALENDAR_CALENDAR_ICU_H_ diff --git a/util/calendar/calendar_test.cc b/util/calendar/calendar_test.cc deleted file mode 100644 index 1f29106..0000000 --- a/util/calendar/calendar_test.cc +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// This test serves the purpose of making sure all the different implementations -// of the unspoken CalendarLib interface support the same methods. - -#include "util/calendar/calendar.h" -#include "util/base/logging.h" - -#include "gtest/gtest.h" - -namespace libtextclassifier2 { -namespace { - -TEST(CalendarTest, Interface) { - CalendarLib calendar; - int64 time; - std::string timezone; - bool result = calendar.InterpretParseData( - DateParseData{0l, 0, 0, 0, 0, 0, 0, 0, 0, 0, - static_cast<DateParseData::Relation>(0), - static_cast<DateParseData::RelationType>(0), 0}, - 0L, "Zurich", "en-CH", GRANULARITY_UNKNOWN, &time); - TC_LOG(INFO) << result; -} - -#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU -TEST(CalendarTest, RoundingToGranularity) { - CalendarLib calendar; - int64 time; - std::string timezone; - DateParseData data; - data.year = 2018; - data.month = 4; - data.day_of_month = 25; - data.hour = 9; - data.minute = 33; - data.second = 59; - data.field_set_mask = DateParseData::YEAR_FIELD | DateParseData::MONTH_FIELD | - DateParseData::DAY_FIELD | DateParseData::HOUR_FIELD | - DateParseData::MINUTE_FIELD | - DateParseData::SECOND_FIELD; - ASSERT_TRUE(calendar.InterpretParseData( - data, - /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", - /*reference_locale=*/"en-CH", - /*granularity=*/GRANULARITY_YEAR, &time)); - EXPECT_EQ(time, 1514761200000L /* Jan 01 2018 00:00:00 */); - - ASSERT_TRUE(calendar.InterpretParseData( - data, - /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", - /*reference_locale=*/"en-CH", - /*granularity=*/GRANULARITY_MONTH, &time)); - EXPECT_EQ(time, 1522533600000L /* Apr 01 2018 00:00:00 */); - - ASSERT_TRUE(calendar.InterpretParseData( - data, - /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", - /*reference_locale=*/"en-CH", - /*granularity=*/GRANULARITY_WEEK, &time)); - EXPECT_EQ(time, 1524434400000L /* Mon Apr 23 2018 00:00:00 */); - - ASSERT_TRUE(calendar.InterpretParseData( - data, - /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", - /*reference_locale=*/"*-CH", - /*granularity=*/GRANULARITY_WEEK, &time)); - EXPECT_EQ(time, 1524434400000L /* Mon Apr 23 2018 00:00:00 */); - - ASSERT_TRUE(calendar.InterpretParseData( - data, - /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", - /*reference_locale=*/"en-US", - /*granularity=*/GRANULARITY_WEEK, &time)); - EXPECT_EQ(time, 1524348000000L /* Sun Apr 22 2018 00:00:00 */); - - ASSERT_TRUE(calendar.InterpretParseData( - data, - /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", - /*reference_locale=*/"*-US", - /*granularity=*/GRANULARITY_WEEK, &time)); - EXPECT_EQ(time, 1524348000000L /* Sun Apr 22 2018 00:00:00 */); - - ASSERT_TRUE(calendar.InterpretParseData( - data, - /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", - /*reference_locale=*/"en-CH", - /*granularity=*/GRANULARITY_DAY, &time)); - EXPECT_EQ(time, 1524607200000L /* Apr 25 2018 00:00:00 */); - - ASSERT_TRUE(calendar.InterpretParseData( - data, - /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", - /*reference_locale=*/"en-CH", - /*granularity=*/GRANULARITY_HOUR, &time)); - EXPECT_EQ(time, 1524639600000L /* Apr 25 2018 09:00:00 */); - - ASSERT_TRUE(calendar.InterpretParseData( - data, - /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", - /*reference_locale=*/"en-CH", - /*granularity=*/GRANULARITY_MINUTE, &time)); - EXPECT_EQ(time, 1524641580000 /* Apr 25 2018 09:33:00 */); - - ASSERT_TRUE(calendar.InterpretParseData( - data, - /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", - /*reference_locale=*/"en-CH", - /*granularity=*/GRANULARITY_SECOND, &time)); - EXPECT_EQ(time, 1524641639000 /* Apr 25 2018 09:33:59 */); -} -#endif // LIBTEXTCLASSIFIER_UNILIB_DUMMY - -} // namespace -} // namespace libtextclassifier2 diff --git a/util/gtl/map_util.h b/util/gtl/map_util.h deleted file mode 100644 index bd020f8..0000000 --- a/util/gtl/map_util.h +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LIBTEXTCLASSIFIER_UTIL_GTL_MAP_UTIL_H_ -#define LIBTEXTCLASSIFIER_UTIL_GTL_MAP_UTIL_H_ - -namespace libtextclassifier2 { - -// Returns a const reference to the value associated with the given key if it -// exists, otherwise returns a const reference to the provided default value. -// -// WARNING: If a temporary object is passed as the default "value," -// this function will return a reference to that temporary object, -// which will be destroyed at the end of the statement. A common -// example: if you have a map with string values, and you pass a char* -// as the default "value," either use the returned value immediately -// or store it in a string (not string&). -template <class Collection> -const typename Collection::value_type::second_type& FindWithDefault( - const Collection& collection, - const typename Collection::value_type::first_type& key, - const typename Collection::value_type::second_type& value) { - typename Collection::const_iterator it = collection.find(key); - if (it == collection.end()) { - return value; - } - return it->second; -} - -// Inserts the given key and value into the given collection if and only if the -// given key did NOT already exist in the collection. If the key previously -// existed in the collection, the value is not changed. Returns true if the -// key-value pair was inserted; returns false if the key was already present. -template <class Collection> -bool InsertIfNotPresent(Collection* const collection, - const typename Collection::value_type& vt) { - return collection->insert(vt).second; -} - -// Same as above except the key and value are passed separately. -template <class Collection> -bool InsertIfNotPresent( - Collection* const collection, - const typename Collection::value_type::first_type& key, - const typename Collection::value_type::second_type& value) { - return InsertIfNotPresent(collection, - typename Collection::value_type(key, value)); -} - -} // namespace libtextclassifier2 - -#endif // LIBTEXTCLASSIFIER_UTIL_GTL_MAP_UTIL_H_ diff --git a/util/gtl/stl_util.h b/util/gtl/stl_util.h deleted file mode 100644 index 7b88e05..0000000 --- a/util/gtl/stl_util.h +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LIBTEXTCLASSIFIER_UTIL_GTL_STL_UTIL_H_ -#define LIBTEXTCLASSIFIER_UTIL_GTL_STL_UTIL_H_ - -namespace libtextclassifier2 { - -// Deletes all the elements in an STL container and clears the container. This -// function is suitable for use with a vector, set, hash_set, or any other STL -// container which defines sensible begin(), end(), and clear() methods. -// If container is NULL, this function is a no-op. -template <typename T> -void STLDeleteElements(T *container) { - if (!container) return; - auto it = container->begin(); - while (it != container->end()) { - auto temp = it; - ++it; - delete *temp; - } - container->clear(); -} - -// Given an STL container consisting of (key, value) pairs, STLDeleteValues -// deletes all the "value" components and clears the container. Does nothing in -// the case it's given a nullptr. -template <typename T> -void STLDeleteValues(T *container) { - if (!container) return; - auto it = container->begin(); - while (it != container->end()) { - auto temp = it; - ++it; - delete temp->second; - } - container->clear(); -} - -} // namespace libtextclassifier2 - -#endif // LIBTEXTCLASSIFIER_UTIL_GTL_STL_UTIL_H_ diff --git a/util/hash/hash.cc b/util/hash/hash.cc index 9722ddc..eaa85ae 100644 --- a/util/hash/hash.cc +++ b/util/hash/hash.cc @@ -16,7 +16,7 @@ #include "util/hash/hash.h" -#include "util/base/macros.h" +#include "utils/base/macros.h" namespace libtextclassifier2 { @@ -59,10 +59,10 @@ uint32 Hash32(const char *data, size_t n, uint32 seed) { switch (n) { case 3: h ^= ByteAs32(data[2]) << 16; - TC_FALLTHROUGH_INTENDED; + TC3_FALLTHROUGH_INTENDED; case 2: h ^= ByteAs32(data[1]) << 8; - TC_FALLTHROUGH_INTENDED; + TC3_FALLTHROUGH_INTENDED; case 1: h ^= ByteAs32(data[0]); h *= m; diff --git a/util/hash/hash.h b/util/hash/hash.h index b7a3b53..9353e5f 100644 --- a/util/hash/hash.h +++ b/util/hash/hash.h @@ -19,10 +19,12 @@ #include <string> -#include "util/base/integral_types.h" +#include "utils/base/integral_types.h" namespace libtextclassifier2 { +using namespace libtextclassifier3; + uint32 Hash32(const char *data, size_t n, uint32 seed); static inline uint32 Hash32WithDefaultSeed(const char *data, size_t n) { diff --git a/util/java/string_utils.h b/util/java/string_utils.h deleted file mode 100644 index 6a85856..0000000 --- a/util/java/string_utils.h +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LIBTEXTCLASSIFIER_UTIL_JAVA_STRING_UTILS_H_ -#define LIBTEXTCLASSIFIER_UTIL_JAVA_STRING_UTILS_H_ - -#include <jni.h> -#include <string> - -namespace libtextclassifier2 { - -bool JStringToUtf8String(JNIEnv* env, const jstring& jstr, std::string* result); - -} // namespace libtextclassifier2 - -#endif // LIBTEXTCLASSIFIER_UTIL_JAVA_STRING_UTILS_H_ diff --git a/util/strings/utf8.cc b/util/strings/utf8.cc deleted file mode 100644 index 39dcb4e..0000000 --- a/util/strings/utf8.cc +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "util/strings/utf8.h" - -namespace libtextclassifier2 { -bool IsValidUTF8(const char *src, int size) { - for (int i = 0; i < size;) { - // Unexpected trail byte. - if (IsTrailByte(src[i])) { - return false; - } - - const int num_codepoint_bytes = GetNumBytesForUTF8Char(&src[i]); - if (num_codepoint_bytes <= 0 || i + num_codepoint_bytes > size) { - return false; - } - - // Check that remaining bytes in the codepoint are trailing bytes. - i++; - for (int k = 1; k < num_codepoint_bytes; k++, i++) { - if (!IsTrailByte(src[i])) { - return false; - } - } - } - return true; -} -} // namespace libtextclassifier2 diff --git a/util/utf8/unilib-icu.cc b/util/utf8/unilib-icu.cc deleted file mode 100644 index 9e9ce19..0000000 --- a/util/utf8/unilib-icu.cc +++ /dev/null @@ -1,293 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "util/utf8/unilib-icu.h" - -#include <utility> - -namespace libtextclassifier2 { - -bool UniLib::ParseInt32(const UnicodeText& text, int* result) const { - UErrorCode status = U_ZERO_ERROR; - UNumberFormat* format_alias = - unum_open(UNUM_DECIMAL, nullptr, 0, "en_US_POSIX", nullptr, &status); - if (U_FAILURE(status)) { - return false; - } - icu::UnicodeString utf8_string = icu::UnicodeString::fromUTF8( - icu::StringPiece(text.data(), text.size_bytes())); - int parse_index = 0; - const int32 integer = unum_parse(format_alias, utf8_string.getBuffer(), - utf8_string.length(), &parse_index, &status); - *result = integer; - unum_close(format_alias); - if (U_FAILURE(status) || parse_index != utf8_string.length()) { - return false; - } - return true; -} - -bool UniLib::IsOpeningBracket(char32 codepoint) const { - return u_getIntPropertyValue(codepoint, UCHAR_BIDI_PAIRED_BRACKET_TYPE) == - U_BPT_OPEN; -} - -bool UniLib::IsClosingBracket(char32 codepoint) const { - return u_getIntPropertyValue(codepoint, UCHAR_BIDI_PAIRED_BRACKET_TYPE) == - U_BPT_CLOSE; -} - -bool UniLib::IsWhitespace(char32 codepoint) const { - return u_isWhitespace(codepoint); -} - -bool UniLib::IsDigit(char32 codepoint) const { return u_isdigit(codepoint); } - -bool UniLib::IsUpper(char32 codepoint) const { return u_isupper(codepoint); } - -char32 UniLib::ToLower(char32 codepoint) const { return u_tolower(codepoint); } - -char32 UniLib::GetPairedBracket(char32 codepoint) const { - return u_getBidiPairedBracket(codepoint); -} - -UniLib::RegexMatcher::RegexMatcher(icu::RegexPattern* pattern, - icu::UnicodeString text) - : text_(std::move(text)), - last_find_offset_(0), - last_find_offset_codepoints_(0), - last_find_offset_dirty_(true) { - UErrorCode status = U_ZERO_ERROR; - matcher_.reset(pattern->matcher(text_, status)); - if (U_FAILURE(status)) { - matcher_.reset(nullptr); - } -} - -std::unique_ptr<UniLib::RegexMatcher> UniLib::RegexPattern::Matcher( - const UnicodeText& input) const { - return std::unique_ptr<UniLib::RegexMatcher>(new UniLib::RegexMatcher( - pattern_.get(), icu::UnicodeString::fromUTF8( - icu::StringPiece(input.data(), input.size_bytes())))); -} - -constexpr int UniLib::RegexMatcher::kError; -constexpr int UniLib::RegexMatcher::kNoError; - -bool UniLib::RegexMatcher::Matches(int* status) const { - if (!matcher_) { - *status = kError; - return false; - } - - UErrorCode icu_status = U_ZERO_ERROR; - const bool result = matcher_->matches(/*startIndex=*/0, icu_status); - if (U_FAILURE(icu_status)) { - *status = kError; - return false; - } - *status = kNoError; - return result; -} - -bool UniLib::RegexMatcher::ApproximatelyMatches(int* status) { - if (!matcher_) { - *status = kError; - return false; - } - - matcher_->reset(); - *status = kNoError; - if (!Find(status) || *status != kNoError) { - return false; - } - const int found_start = Start(status); - if (*status != kNoError) { - return false; - } - const int found_end = End(status); - if (*status != kNoError) { - return false; - } - if (found_start != 0 || found_end != text_.countChar32()) { - return false; - } - return true; -} - -bool UniLib::RegexMatcher::UpdateLastFindOffset() const { - if (!last_find_offset_dirty_) { - return true; - } - - // Update the position of the match. - UErrorCode icu_status = U_ZERO_ERROR; - const int find_offset = matcher_->start(0, icu_status); - if (U_FAILURE(icu_status)) { - return false; - } - last_find_offset_codepoints_ += - text_.countChar32(last_find_offset_, find_offset - last_find_offset_); - last_find_offset_ = find_offset; - last_find_offset_dirty_ = false; - - return true; -} - -bool UniLib::RegexMatcher::Find(int* status) { - if (!matcher_) { - *status = kError; - return false; - } - UErrorCode icu_status = U_ZERO_ERROR; - const bool result = matcher_->find(icu_status); - if (U_FAILURE(icu_status)) { - *status = kError; - return false; - } - - last_find_offset_dirty_ = true; - *status = kNoError; - return result; -} - -int UniLib::RegexMatcher::Start(int* status) const { - return Start(/*group_idx=*/0, status); -} - -int UniLib::RegexMatcher::Start(int group_idx, int* status) const { - if (!matcher_ || !UpdateLastFindOffset()) { - *status = kError; - return kError; - } - - UErrorCode icu_status = U_ZERO_ERROR; - const int result = matcher_->start(group_idx, icu_status); - if (U_FAILURE(icu_status)) { - *status = kError; - return kError; - } - *status = kNoError; - - // If the group didn't participate in the match the result is -1 and is - // incompatible with the caching logic bellow. - if (result == -1) { - return -1; - } - - return last_find_offset_codepoints_ + - text_.countChar32(/*start=*/last_find_offset_, - /*length=*/result - last_find_offset_); -} - -int UniLib::RegexMatcher::End(int* status) const { - return End(/*group_idx=*/0, status); -} - -int UniLib::RegexMatcher::End(int group_idx, int* status) const { - if (!matcher_ || !UpdateLastFindOffset()) { - *status = kError; - return kError; - } - UErrorCode icu_status = U_ZERO_ERROR; - const int result = matcher_->end(group_idx, icu_status); - if (U_FAILURE(icu_status)) { - *status = kError; - return kError; - } - *status = kNoError; - - // If the group didn't participate in the match the result is -1 and is - // incompatible with the caching logic bellow. - if (result == -1) { - return -1; - } - - return last_find_offset_codepoints_ + - text_.countChar32(/*start=*/last_find_offset_, - /*length=*/result - last_find_offset_); -} - -UnicodeText UniLib::RegexMatcher::Group(int* status) const { - return Group(/*group_idx=*/0, status); -} - -UnicodeText UniLib::RegexMatcher::Group(int group_idx, int* status) const { - if (!matcher_) { - *status = kError; - return UTF8ToUnicodeText("", /*do_copy=*/false); - } - std::string result = ""; - UErrorCode icu_status = U_ZERO_ERROR; - const icu::UnicodeString result_icu = matcher_->group(group_idx, icu_status); - if (U_FAILURE(icu_status)) { - *status = kError; - return UTF8ToUnicodeText("", /*do_copy=*/false); - } - result_icu.toUTF8String(result); - *status = kNoError; - return UTF8ToUnicodeText(result, /*do_copy=*/true); -} - -constexpr int UniLib::BreakIterator::kDone; - -UniLib::BreakIterator::BreakIterator(const UnicodeText& text) - : text_(icu::UnicodeString::fromUTF8( - icu::StringPiece(text.data(), text.size_bytes()))), - last_break_index_(0), - last_unicode_index_(0) { - icu::ErrorCode status; - break_iterator_.reset( - icu::BreakIterator::createWordInstance(icu::Locale("en"), status)); - if (!status.isSuccess()) { - break_iterator_.reset(); - return; - } - break_iterator_->setText(text_); -} - -int UniLib::BreakIterator::Next() { - const int break_index = break_iterator_->next(); - if (break_index == icu::BreakIterator::DONE) { - return BreakIterator::kDone; - } - last_unicode_index_ += - text_.countChar32(last_break_index_, break_index - last_break_index_); - last_break_index_ = break_index; - return last_unicode_index_; -} - -std::unique_ptr<UniLib::RegexPattern> UniLib::CreateRegexPattern( - const UnicodeText& regex) const { - UErrorCode status = U_ZERO_ERROR; - std::unique_ptr<icu::RegexPattern> pattern( - icu::RegexPattern::compile(icu::UnicodeString::fromUTF8(icu::StringPiece( - regex.data(), regex.size_bytes())), - /*flags=*/UREGEX_MULTILINE, status)); - if (U_FAILURE(status) || !pattern) { - return nullptr; - } - return std::unique_ptr<UniLib::RegexPattern>( - new UniLib::RegexPattern(std::move(pattern))); -} - -std::unique_ptr<UniLib::BreakIterator> UniLib::CreateBreakIterator( - const UnicodeText& text) const { - return std::unique_ptr<UniLib::BreakIterator>( - new UniLib::BreakIterator(text)); -} - -} // namespace libtextclassifier2 diff --git a/util/base/casts.h b/utils/base/casts.h index a1d2056..175f56b 100644 --- a/util/base/casts.h +++ b/utils/base/casts.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,12 +14,12 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_CASTS_H_ -#define LIBTEXTCLASSIFIER_UTIL_BASE_CASTS_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_BASE_CASTS_H_ +#define LIBTEXTCLASSIFIER_UTILS_BASE_CASTS_H_ #include <string.h> // for memcpy -namespace libtextclassifier2 { +namespace libtextclassifier3 { // bit_cast<Dest, Source> is a template function that implements the equivalent // of "*reinterpret_cast<Dest*>(&source)". We need this in very low-level @@ -87,6 +87,6 @@ inline Dest bit_cast(const Source &source) { return dest; } -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_UTIL_BASE_CASTS_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_BASE_CASTS_H_ diff --git a/util/base/config.h b/utils/base/config.h index 8844b14..c476f13 100644 --- a/util/base/config.h +++ b/utils/base/config.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,10 +16,10 @@ // Define macros to indicate C++ standard / platform / etc we use. -#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_CONFIG_H_ -#define LIBTEXTCLASSIFIER_UTIL_BASE_CONFIG_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_BASE_CONFIG_H_ +#define LIBTEXTCLASSIFIER_UTILS_BASE_CONFIG_H_ -namespace libtextclassifier2 { +namespace libtextclassifier3 { // Define LANG_CXX11 to 1 if current compiler supports C++11. // @@ -38,6 +38,6 @@ namespace libtextclassifier2 { #define LANG_CXX11 1 #endif -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_UTIL_BASE_CONFIG_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_BASE_CONFIG_H_ diff --git a/util/base/endian.h b/utils/base/endian.h index 2dfbfd6..9312704 100644 --- a/util/base/endian.h +++ b/utils/base/endian.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,12 +14,12 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_ENDIAN_H_ -#define LIBTEXTCLASSIFIER_UTIL_BASE_ENDIAN_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_BASE_ENDIAN_H_ +#define LIBTEXTCLASSIFIER_UTILS_BASE_ENDIAN_H_ -#include "util/base/integral_types.h" +#include "utils/base/integral_types.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { #if defined OS_LINUX || defined OS_CYGWIN || defined OS_ANDROID || \ defined(__ANDROID__) @@ -133,6 +133,6 @@ class LittleEndian { #endif /* ENDIAN */ }; -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_UTIL_BASE_ENDIAN_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_BASE_ENDIAN_H_ diff --git a/util/base/integral_types.h b/utils/base/integral_types.h index f82c9cd..e3253de 100644 --- a/util/base/integral_types.h +++ b/utils/base/integral_types.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,12 +16,12 @@ // Basic integer type definitions. -#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_INTEGRAL_TYPES_H_ -#define LIBTEXTCLASSIFIER_UTIL_BASE_INTEGRAL_TYPES_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_BASE_INTEGRAL_TYPES_H_ +#define LIBTEXTCLASSIFIER_UTILS_BASE_INTEGRAL_TYPES_H_ -#include "util/base/config.h" +#include "utils/base/config.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { typedef unsigned int uint32; typedef unsigned long long uint64; @@ -56,6 +56,6 @@ static_assert(sizeof(char32) == 4, "wrong size"); static_assert(sizeof(int64) == 8, "wrong size"); #endif // LANG_CXX11 -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_UTIL_BASE_INTEGRAL_TYPES_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_BASE_INTEGRAL_TYPES_H_ diff --git a/util/base/logging.cc b/utils/base/logging.cc index 919bb36..d7ddeb8 100644 --- a/util/base/logging.cc +++ b/utils/base/logging.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,15 +14,15 @@ * limitations under the License. */ -#include "util/base/logging.h" +#include "utils/base/logging.h" #include <stdlib.h> - +#include <exception> #include <iostream> -#include "util/base/logging_raw.h" +#include "utils/base/logging_raw.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace logging { namespace { @@ -59,9 +59,9 @@ LogMessage::LogMessage(LogSeverity severity, const char *file_name, LogMessage::~LogMessage() { LowLevelLogging(severity_, /* tag = */ "txtClsf", stream_.message); if (severity_ == FATAL) { - exit(1); + std::terminate(); // Will print a stacktrace (stdout or logcat). } } } // namespace logging -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/util/base/logging.h b/utils/base/logging.h index 4391d46..1267f5e 100644 --- a/util/base/logging.h +++ b/utils/base/logging.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,24 +14,24 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_H_ -#define LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_BASE_LOGGING_H_ +#define LIBTEXTCLASSIFIER_UTILS_BASE_LOGGING_H_ #include <cassert> #include <string> -#include "util/base/logging_levels.h" -#include "util/base/port.h" +#include "utils/base/logging_levels.h" +#include "utils/base/port.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace logging { // A tiny code footprint string stream for assembling log messages. struct LoggingStringStream { LoggingStringStream() {} LoggingStringStream &stream() { return *this; } - // Needed for invocation in TC_CHECK macro. + // Needed for invocation in TC3_CHECK macro. explicit operator bool() const { return true; } std::string message; @@ -64,8 +64,8 @@ inline LoggingStringStream &operator<<(LoggingStringStream &stream, return stream; } -// The class that does all the work behind our TC_LOG(severity) macros. Each -// TC_LOG(severity) << obj1 << obj2 << ...; logging statement creates a +// The class that does all the work behind our TC3_LOG(severity) macros. Each +// TC3_LOG(severity) << obj1 << obj2 << ...; logging statement creates a // LogMessage temporary object containing a stringstream. Each operator<< adds // info to that stringstream and the LogMessage destructor performs the actual // logging. The reason this works is that in C++, "all temporary objects are @@ -76,9 +76,9 @@ inline LoggingStringStream &operator<<(LoggingStringStream &stream, class LogMessage { public: LogMessage(LogSeverity severity, const char *file_name, - int line_number) TC_ATTRIBUTE_NOINLINE; + int line_number) TC3_ATTRIBUTE_NOINLINE; - ~LogMessage() TC_ATTRIBUTE_NOINLINE; + ~LogMessage() TC3_ATTRIBUTE_NOINLINE; // Returns the stream associated with the logger object. LoggingStringStream &stream() { return stream_; } @@ -104,64 +104,64 @@ inline NullStream &operator<<(NullStream &str, const T &) { } } // namespace logging -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#define TC_LOG(severity) \ - ::libtextclassifier2::logging::LogMessage( \ - ::libtextclassifier2::logging::severity, __FILE__, __LINE__) \ +#define TC3_LOG(severity) \ + ::libtextclassifier3::logging::LogMessage( \ + ::libtextclassifier3::logging::severity, __FILE__, __LINE__) \ .stream() // If condition x is true, does nothing. Otherwise, crashes the program (liek // LOG(FATAL)) with an informative message. Can be continued with extra // messages, via <<, like any logging macro, e.g., // -// TC_CHECK(my_cond) << "I think we hit a problem"; -#define TC_CHECK(x) \ - (x) || TC_LOG(FATAL) << __FILE__ << ":" << __LINE__ << ": check failed: \"" \ - << #x +// TC3_CHECK(my_cond) << "I think we hit a problem"; +#define TC3_CHECK(x) \ + (x) || TC3_LOG(FATAL) << __FILE__ << ":" << __LINE__ << ": check failed: \"" \ + << #x << "\" " -#define TC_CHECK_EQ(x, y) TC_CHECK((x) == (y)) -#define TC_CHECK_LT(x, y) TC_CHECK((x) < (y)) -#define TC_CHECK_GT(x, y) TC_CHECK((x) > (y)) -#define TC_CHECK_LE(x, y) TC_CHECK((x) <= (y)) -#define TC_CHECK_GE(x, y) TC_CHECK((x) >= (y)) -#define TC_CHECK_NE(x, y) TC_CHECK((x) != (y)) +#define TC3_CHECK_EQ(x, y) TC3_CHECK((x) == (y)) +#define TC3_CHECK_LT(x, y) TC3_CHECK((x) < (y)) +#define TC3_CHECK_GT(x, y) TC3_CHECK((x) > (y)) +#define TC3_CHECK_LE(x, y) TC3_CHECK((x) <= (y)) +#define TC3_CHECK_GE(x, y) TC3_CHECK((x) >= (y)) +#define TC3_CHECK_NE(x, y) TC3_CHECK((x) != (y)) -#define TC_NULLSTREAM ::libtextclassifier2::logging::NullStream().stream() +#define TC3_NULLSTREAM ::libtextclassifier3::logging::NullStream().stream() -// Debug checks: a TC_DCHECK<suffix> macro should behave like TC_CHECK<suffix> +// Debug checks: a TC3_DCHECK<suffix> macro should behave like TC3_CHECK<suffix> // in debug mode an don't check / don't print anything in non-debug mode. #ifdef NDEBUG -#define TC_DCHECK(x) TC_NULLSTREAM -#define TC_DCHECK_EQ(x, y) TC_NULLSTREAM -#define TC_DCHECK_LT(x, y) TC_NULLSTREAM -#define TC_DCHECK_GT(x, y) TC_NULLSTREAM -#define TC_DCHECK_LE(x, y) TC_NULLSTREAM -#define TC_DCHECK_GE(x, y) TC_NULLSTREAM -#define TC_DCHECK_NE(x, y) TC_NULLSTREAM +#define TC3_DCHECK(x) TC3_NULLSTREAM +#define TC3_DCHECK_EQ(x, y) TC3_NULLSTREAM +#define TC3_DCHECK_LT(x, y) TC3_NULLSTREAM +#define TC3_DCHECK_GT(x, y) TC3_NULLSTREAM +#define TC3_DCHECK_LE(x, y) TC3_NULLSTREAM +#define TC3_DCHECK_GE(x, y) TC3_NULLSTREAM +#define TC3_DCHECK_NE(x, y) TC3_NULLSTREAM #else // NDEBUG -// In debug mode, each TC_DCHECK<suffix> is equivalent to TC_CHECK<suffix>, +// In debug mode, each TC3_DCHECK<suffix> is equivalent to TC3_CHECK<suffix>, // i.e., a real check that crashes when the condition is not true. -#define TC_DCHECK(x) TC_CHECK(x) -#define TC_DCHECK_EQ(x, y) TC_CHECK_EQ(x, y) -#define TC_DCHECK_LT(x, y) TC_CHECK_LT(x, y) -#define TC_DCHECK_GT(x, y) TC_CHECK_GT(x, y) -#define TC_DCHECK_LE(x, y) TC_CHECK_LE(x, y) -#define TC_DCHECK_GE(x, y) TC_CHECK_GE(x, y) -#define TC_DCHECK_NE(x, y) TC_CHECK_NE(x, y) +#define TC3_DCHECK(x) TC3_CHECK(x) +#define TC3_DCHECK_EQ(x, y) TC3_CHECK_EQ(x, y) +#define TC3_DCHECK_LT(x, y) TC3_CHECK_LT(x, y) +#define TC3_DCHECK_GT(x, y) TC3_CHECK_GT(x, y) +#define TC3_DCHECK_LE(x, y) TC3_CHECK_LE(x, y) +#define TC3_DCHECK_GE(x, y) TC3_CHECK_GE(x, y) +#define TC3_DCHECK_NE(x, y) TC3_CHECK_NE(x, y) #endif // NDEBUG -#ifdef LIBTEXTCLASSIFIER_VLOG -#define TC_VLOG(severity) \ - ::libtextclassifier2::logging::LogMessage( \ - ::libtextclassifier2::logging::INFO, __FILE__, __LINE__) \ +#ifdef TC3_VLOG +#define TC3_VLOG(severity) \ + ::libtextclassifier3::logging::LogMessage( \ + ::libtextclassifier3::logging::INFO, __FILE__, __LINE__) \ .stream() #else -#define TC_VLOG(severity) TC_NULLSTREAM +#define TC3_VLOG(severity) TC3_NULLSTREAM #endif -#endif // LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_BASE_LOGGING_H_ diff --git a/util/base/logging_levels.h b/utils/base/logging_levels.h index 17c882f..dfcb267 100644 --- a/util/base/logging_levels.h +++ b/utils/base/logging_levels.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,10 +14,10 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_LEVELS_H_ -#define LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_LEVELS_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_BASE_LOGGING_LEVELS_H_ +#define LIBTEXTCLASSIFIER_UTILS_BASE_LOGGING_LEVELS_H_ -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace logging { enum LogSeverity { @@ -28,6 +28,6 @@ enum LogSeverity { }; } // namespace logging -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_LEVELS_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_BASE_LOGGING_LEVELS_H_ diff --git a/util/base/logging_raw.cc b/utils/base/logging_raw.cc index 6d97852..ccaef22 100644 --- a/util/base/logging_raw.cc +++ b/utils/base/logging_raw.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "util/base/logging_raw.h" +#include "utils/base/logging_raw.h" #include <stdio.h> #include <string> @@ -26,7 +26,7 @@ // Compiled as part of Android. #include <android/log.h> -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace logging { namespace { @@ -50,7 +50,7 @@ int GetAndroidLogLevel(LogSeverity severity) { void LowLevelLogging(LogSeverity severity, const std::string& tag, const std::string& message) { const int android_log_level = GetAndroidLogLevel(severity); -#if !defined(TC_DEBUG_LOGGING) +#if !defined(TC3_DEBUG_LOGGING) if (android_log_level != ANDROID_LOG_ERROR && android_log_level != ANDROID_LOG_FATAL) { return; @@ -60,12 +60,12 @@ void LowLevelLogging(LogSeverity severity, const std::string& tag, } } // namespace logging -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 #else // if defined(__ANDROID__) // Not on Android: implement LowLevelLogging to print to stderr (see below). -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace logging { namespace { @@ -94,6 +94,6 @@ void LowLevelLogging(LogSeverity severity, const std::string &tag, } } // namespace logging -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 #endif // if defined(__ANDROID__) diff --git a/util/base/logging_raw.h b/utils/base/logging_raw.h index e6265c7..be285ad 100644 --- a/util/base/logging_raw.h +++ b/utils/base/logging_raw.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,14 +14,14 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_RAW_H_ -#define LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_RAW_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_BASE_LOGGING_RAW_H_ +#define LIBTEXTCLASSIFIER_UTILS_BASE_LOGGING_RAW_H_ #include <string> -#include "util/base/logging_levels.h" +#include "utils/base/logging_levels.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace logging { // Low-level logging primitive. Logs a message, with the indicated log @@ -31,6 +31,6 @@ void LowLevelLogging(LogSeverity severity, const std::string &tag, const std::string &message); } // namespace logging -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_RAW_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_BASE_LOGGING_RAW_H_ diff --git a/util/base/macros.h b/utils/base/macros.h index a021ab9..6739c0b 100644 --- a/util/base/macros.h +++ b/utils/base/macros.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,28 +14,28 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_MACROS_H_ -#define LIBTEXTCLASSIFIER_UTIL_BASE_MACROS_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_BASE_MACROS_H_ +#define LIBTEXTCLASSIFIER_UTILS_BASE_MACROS_H_ -#include "util/base/config.h" +#include "utils/base/config.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { #if LANG_CXX11 -#define TC_DISALLOW_COPY_AND_ASSIGN(TypeName) \ - TypeName(const TypeName &) = delete; \ +#define TC3_DISALLOW_COPY_AND_ASSIGN(TypeName) \ + TypeName(const TypeName &) = delete; \ TypeName &operator=(const TypeName &) = delete #else // C++98 case follows // Note that these C++98 implementations cannot completely disallow copying, // as members and friends can still accidentally make elided copies without // triggering a linker error. -#define TC_DISALLOW_COPY_AND_ASSIGN(TypeName) \ - TypeName(const TypeName &); \ +#define TC3_DISALLOW_COPY_AND_ASSIGN(TypeName) \ + TypeName(const TypeName &); \ TypeName &operator=(const TypeName &) #endif // LANG_CXX11 -// The TC_FALLTHROUGH_INTENDED macro can be used to annotate implicit +// The TC3_FALLTHROUGH_INTENDED macro can be used to annotate implicit // fall-through between switch labels: // // switch (x) { @@ -43,7 +43,7 @@ namespace libtextclassifier2 { // case 41: // if (truth_is_out_there) { // ++x; -// TC_FALLTHROUGH_INTENDED; // Use instead of/along with annotations in +// TC3_FALLTHROUGH_INTENDED; // Use instead of/along with annotations in // // comments. // } else { // return x; @@ -51,35 +51,37 @@ namespace libtextclassifier2 { // case 42: // ... // -// As shown in the example above, the TC_FALLTHROUGH_INTENDED macro should be +// As shown in the example above, the TC3_FALLTHROUGH_INTENDED macro should be // followed by a semicolon. It is designed to mimic control-flow statements // like 'break;', so it can be placed in most places where 'break;' can, but // only if there are no statements on the execution path between it and the // next switch label. // -// When compiled with clang in C++11 mode, the TC_FALLTHROUGH_INTENDED macro is -// expanded to [[clang::fallthrough]] attribute, which is analysed when +// When compiled with clang in C++11 mode, the TC3_FALLTHROUGH_INTENDED macro +// is expanded to [[clang::fallthrough]] attribute, which is analysed when // performing switch labels fall-through diagnostic ('-Wimplicit-fallthrough'). // See clang documentation on language extensions for details: // http://clang.llvm.org/docs/AttributeReference.html#fallthrough-clang-fallthrough // -// When used with unsupported compilers, the TC_FALLTHROUGH_INTENDED macro has +// When used with unsupported compilers, the TC3_FALLTHROUGH_INTENDED macro has // no effect on diagnostics. // // In either case this macro has no effect on runtime behavior and performance // of code. #if defined(__clang__) && defined(__has_warning) #if __has_feature(cxx_attributes) && __has_warning("-Wimplicit-fallthrough") -#define TC_FALLTHROUGH_INTENDED [[clang::fallthrough]] +#define TC3_FALLTHROUGH_INTENDED [[clang::fallthrough]] #endif #elif defined(__GNUC__) && __GNUC__ >= 7 -#define TC_FALLTHROUGH_INTENDED [[gnu::fallthrough]] +#define TC3_FALLTHROUGH_INTENDED [[gnu::fallthrough]] #endif -#ifndef TC_FALLTHROUGH_INTENDED -#define TC_FALLTHROUGH_INTENDED do { } while (0) +#ifndef TC3_FALLTHROUGH_INTENDED +#define TC3_FALLTHROUGH_INTENDED \ + do { \ + } while (0) #endif -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_UTIL_BASE_MACROS_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_BASE_MACROS_H_ diff --git a/util/base/port.h b/utils/base/port.h index 90a2bce..24344a0 100644 --- a/util/base/port.h +++ b/utils/base/port.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,30 +16,30 @@ // Various portability macros, type definitions, and inline functions. -#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_PORT_H_ -#define LIBTEXTCLASSIFIER_UTIL_BASE_PORT_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_BASE_PORT_H_ +#define LIBTEXTCLASSIFIER_UTILS_BASE_PORT_H_ -namespace libtextclassifier2 { +namespace libtextclassifier3 { #if defined(__GNUC__) && \ (__GNUC__ > 3 || (__GNUC__ == 3 && __GNUC_MINOR__ >= 1)) // For functions we want to force inline. // Introduced in gcc 3.1. -#define TC_ATTRIBUTE_ALWAYS_INLINE __attribute__((always_inline)) +#define TC3_ATTRIBUTE_ALWAYS_INLINE __attribute__((always_inline)) // For functions we don't want to inline, e.g., to keep code size small. -#define TC_ATTRIBUTE_NOINLINE __attribute__((noinline)) +#define TC3_ATTRIBUTE_NOINLINE __attribute__((noinline)) #elif defined(_MSC_VER) -#define TC_ATTRIBUTE_ALWAYS_INLINE __forceinline +#define TC3_ATTRIBUTE_ALWAYS_INLINE __forceinline #else // Other compilers will have to figure it out for themselves. -#define TC_ATTRIBUTE_ALWAYS_INLINE -#define TC_ATTRIBUTE_NOINLINE +#define TC3_ATTRIBUTE_ALWAYS_INLINE +#define TC3_ATTRIBUTE_NOINLINE #endif -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_UTIL_BASE_PORT_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_BASE_PORT_H_ diff --git a/utils/calendar/calendar-common.h b/utils/calendar/calendar-common.h new file mode 100644 index 0000000..7e606de --- /dev/null +++ b/utils/calendar/calendar-common.h @@ -0,0 +1,278 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_COMMON_H_ +#define LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_COMMON_H_ + +#include "annotator/types.h" +#include "utils/base/integral_types.h" +#include "utils/base/logging.h" +#include "utils/base/macros.h" + +namespace libtextclassifier3 { +namespace calendar { + +// Macro to reduce the amount of boilerplate needed for propagating errors. +#define TC3_CALENDAR_CHECK(EXPR) \ + if (!(EXPR)) { \ + return false; \ + } + +// An implementation of CalendarLib that is independent of the particular +// calendar implementation used (implementation type is passed as template +// argument). +template <class TCalendar> +class CalendarLibTempl { + public: + bool InterpretParseData(const DateParseData& parse_data, + int64 reference_time_ms_utc, + const std::string& reference_timezone, + const std::string& reference_locale, + DatetimeGranularity granularity, + TCalendar* calendar) const; + + private: + // Adjusts the calendar's time instant according to a relative date reference + // in the parsed data. + bool ApplyRelationField(const DateParseData& parse_data, + TCalendar* calendar) const; + + // Round the time instant's precision down to the given granularity. + bool RoundToGranularity(DatetimeGranularity granularity, + TCalendar* calendar) const; + + // Adjusts time in steps of relation_type, by distance steps. + // For example: + // - Adjusting by -2 MONTHS will return the beginning of the 1st + // two weeks ago. + // - Adjusting by +4 Wednesdays will return the beginning of the next + // Wednesday at least 4 weeks from now. + // If allow_today is true, the same day of the week may be kept + // if it already matches the relation type. + bool AdjustByRelation(DateParseData::RelationType relation_type, int distance, + bool allow_today, TCalendar* calendar) const; +}; + +template <class TCalendar> +bool CalendarLibTempl<TCalendar>::InterpretParseData( + const DateParseData& parse_data, int64 reference_time_ms_utc, + const std::string& reference_timezone, const std::string& reference_locale, + DatetimeGranularity granularity, TCalendar* calendar) const { + TC3_CALENDAR_CHECK(calendar->Initialize(reference_timezone, reference_locale, + reference_time_ms_utc)) + + // By default, the parsed time is interpreted to be on the reference day. + // But a parsed date should have time 0:00:00 unless specified. + TC3_CALENDAR_CHECK(calendar->SetHourOfDay(0)) + TC3_CALENDAR_CHECK(calendar->SetMinute(0)) + TC3_CALENDAR_CHECK(calendar->SetSecond(0)) + TC3_CALENDAR_CHECK(calendar->SetMillisecond(0)) + + // Apply each of the parsed fields in order of increasing granularity. + static const int64 kMillisInHour = 1000 * 60 * 60; + if (parse_data.field_set_mask & DateParseData::Fields::ZONE_OFFSET_FIELD) { + TC3_CALENDAR_CHECK( + calendar->SetZoneOffset(parse_data.zone_offset * kMillisInHour)) + } + if (parse_data.field_set_mask & DateParseData::Fields::DST_OFFSET_FIELD) { + TC3_CALENDAR_CHECK( + calendar->SetDstOffset(parse_data.dst_offset * kMillisInHour)) + } + if (parse_data.field_set_mask & DateParseData::Fields::RELATION_FIELD) { + TC3_CALENDAR_CHECK(ApplyRelationField(parse_data, calendar)); + } + if (parse_data.field_set_mask & DateParseData::Fields::YEAR_FIELD) { + TC3_CALENDAR_CHECK(calendar->SetYear(parse_data.year)) + } + if (parse_data.field_set_mask & DateParseData::Fields::MONTH_FIELD) { + // ICU has months starting at 0, Java and Datetime parser at 1, so we + // need to subtract 1. + TC3_CALENDAR_CHECK(calendar->SetMonth(parse_data.month - 1)) + } + if (parse_data.field_set_mask & DateParseData::Fields::DAY_FIELD) { + TC3_CALENDAR_CHECK(calendar->SetDayOfMonth(parse_data.day_of_month)) + } + if (parse_data.field_set_mask & DateParseData::Fields::HOUR_FIELD) { + if (parse_data.field_set_mask & DateParseData::Fields::AMPM_FIELD && + parse_data.ampm == DateParseData::AMPM::PM && parse_data.hour < 12) { + TC3_CALENDAR_CHECK(calendar->SetHourOfDay(parse_data.hour + 12)) + } else { + TC3_CALENDAR_CHECK(calendar->SetHourOfDay(parse_data.hour)) + } + } + if (parse_data.field_set_mask & DateParseData::Fields::MINUTE_FIELD) { + TC3_CALENDAR_CHECK(calendar->SetMinute(parse_data.minute)) + } + if (parse_data.field_set_mask & DateParseData::Fields::SECOND_FIELD) { + TC3_CALENDAR_CHECK(calendar->SetSecond(parse_data.second)) + } + + TC3_CALENDAR_CHECK(RoundToGranularity(granularity, calendar)) + return true; +} + +template <class TCalendar> +bool CalendarLibTempl<TCalendar>::ApplyRelationField( + const DateParseData& parse_data, TCalendar* calendar) const { + constexpr int relation_type_mask = DateParseData::Fields::RELATION_TYPE_FIELD; + constexpr int relation_distance_mask = + DateParseData::Fields::RELATION_DISTANCE_FIELD; + switch (parse_data.relation) { + case DateParseData::Relation::NEXT: + if (parse_data.field_set_mask & relation_type_mask) { + TC3_CALENDAR_CHECK(AdjustByRelation(parse_data.relation_type, + /*distance=*/1, + /*allow_today=*/false, calendar)); + } + return true; + case DateParseData::Relation::NEXT_OR_SAME: + if (parse_data.field_set_mask & relation_type_mask) { + TC3_CALENDAR_CHECK(AdjustByRelation(parse_data.relation_type, + /*distance=*/1, + /*allow_today=*/true, calendar)) + } + return true; + case DateParseData::Relation::LAST: + if (parse_data.field_set_mask & relation_type_mask) { + TC3_CALENDAR_CHECK(AdjustByRelation(parse_data.relation_type, + /*distance=*/-1, + /*allow_today=*/false, calendar)) + } + return true; + case DateParseData::Relation::NOW: + return true; // NOOP + case DateParseData::Relation::TOMORROW: + TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(1)); + return true; + case DateParseData::Relation::YESTERDAY: + TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(-1)); + return true; + case DateParseData::Relation::PAST: + if ((parse_data.field_set_mask & relation_type_mask) && + (parse_data.field_set_mask & relation_distance_mask)) { + TC3_CALENDAR_CHECK(AdjustByRelation(parse_data.relation_type, + -parse_data.relation_distance, + /*allow_today=*/false, calendar)) + } + return true; + case DateParseData::Relation::FUTURE: + if ((parse_data.field_set_mask & relation_type_mask) && + (parse_data.field_set_mask & relation_distance_mask)) { + TC3_CALENDAR_CHECK(AdjustByRelation(parse_data.relation_type, + parse_data.relation_distance, + /*allow_today=*/false, calendar)) + } + return true; + } + return false; +} + +template <class TCalendar> +bool CalendarLibTempl<TCalendar>::RoundToGranularity( + DatetimeGranularity granularity, TCalendar* calendar) const { + // Force recomputation before doing the rounding. + int unused; + TC3_CALENDAR_CHECK(calendar->GetDayOfWeek(&unused)); + + switch (granularity) { + case GRANULARITY_YEAR: + TC3_CALENDAR_CHECK(calendar->SetMonth(0)); + TC3_FALLTHROUGH_INTENDED; + case GRANULARITY_MONTH: + TC3_CALENDAR_CHECK(calendar->SetDayOfMonth(1)); + TC3_FALLTHROUGH_INTENDED; + case GRANULARITY_DAY: + TC3_CALENDAR_CHECK(calendar->SetHourOfDay(0)); + TC3_FALLTHROUGH_INTENDED; + case GRANULARITY_HOUR: + TC3_CALENDAR_CHECK(calendar->SetMinute(0)); + TC3_FALLTHROUGH_INTENDED; + case GRANULARITY_MINUTE: + TC3_CALENDAR_CHECK(calendar->SetSecond(0)); + break; + + case GRANULARITY_WEEK: + int first_day_of_week; + TC3_CALENDAR_CHECK(calendar->GetFirstDayOfWeek(&first_day_of_week)); + TC3_CALENDAR_CHECK(calendar->SetDayOfWeek(first_day_of_week)); + TC3_CALENDAR_CHECK(calendar->SetHourOfDay(0)); + TC3_CALENDAR_CHECK(calendar->SetMinute(0)); + TC3_CALENDAR_CHECK(calendar->SetSecond(0)); + break; + + case GRANULARITY_UNKNOWN: + case GRANULARITY_SECOND: + break; + } + return true; +} + +template <class TCalendar> +bool CalendarLibTempl<TCalendar>::AdjustByRelation( + DateParseData::RelationType relation_type, int distance, bool allow_today, + TCalendar* calendar) const { + const int distance_sign = distance < 0 ? -1 : 1; + switch (relation_type) { + case DateParseData::MONDAY: + case DateParseData::TUESDAY: + case DateParseData::WEDNESDAY: + case DateParseData::THURSDAY: + case DateParseData::FRIDAY: + case DateParseData::SATURDAY: + case DateParseData::SUNDAY: + if (!allow_today) { + // If we're not including the same day as the reference, skip it. + TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(distance_sign)) + } + // Keep walking back until we hit the desired day of the week. + while (distance != 0) { + int day_of_week; + TC3_CALENDAR_CHECK(calendar->GetDayOfWeek(&day_of_week)) + if (day_of_week == relation_type) { + distance += -distance_sign; + if (distance == 0) break; + } + TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(distance_sign)) + } + return true; + case DateParseData::DAY: + TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(distance)); + return true; + case DateParseData::WEEK: + TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(7 * distance)) + TC3_CALENDAR_CHECK(calendar->SetDayOfWeek(1)) + return true; + case DateParseData::MONTH: + TC3_CALENDAR_CHECK(calendar->AddMonth(distance)) + TC3_CALENDAR_CHECK(calendar->SetDayOfMonth(1)) + return true; + case DateParseData::YEAR: + TC3_CALENDAR_CHECK(calendar->AddYear(distance)) + TC3_CALENDAR_CHECK(calendar->SetDayOfYear(1)) + return true; + default: + return false; + } + return false; +} + +}; // namespace calendar + +#undef TC3_CALENDAR_CHECK + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_COMMON_H_ diff --git a/utils/calendar/calendar-javaicu.cc b/utils/calendar/calendar-javaicu.cc new file mode 100644 index 0000000..7b7f2fa --- /dev/null +++ b/utils/calendar/calendar-javaicu.cc @@ -0,0 +1,190 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/calendar/calendar-javaicu.h" + +#include "annotator/types.h" +#include "utils/java/scoped_local_ref.h" + +namespace libtextclassifier3 { +namespace { + +// Generic version of icu::Calendar::add with error checking. +bool CalendarAdd(JniCache* jni_cache, JNIEnv* jenv, jobject calendar, + jint field, jint value) { + jenv->CallVoidMethod(calendar, jni_cache->calendar_add, field, value); + return !jni_cache->ExceptionCheckAndClear(); +} + +// Generic version of icu::Calendar::get with error checking. +bool CalendarGet(JniCache* jni_cache, JNIEnv* jenv, jobject calendar, + jint field, jint* value) { + *value = jenv->CallIntMethod(calendar, jni_cache->calendar_get, field); + return !jni_cache->ExceptionCheckAndClear(); +} + +// Generic version of icu::Calendar::set with error checking. +bool CalendarSet(JniCache* jni_cache, JNIEnv* jenv, jobject calendar, + jint field, jint value) { + jenv->CallVoidMethod(calendar, jni_cache->calendar_set, field, value); + return !jni_cache->ExceptionCheckAndClear(); +} + +// Extracts the first tag from a BCP47 tag (e.g. "en" for "en-US"). +std::string GetFirstBcp47Tag(const std::string& tag) { + for (size_t i = 0; i < tag.size(); ++i) { + if (tag[i] == '_' || tag[i] == '-') { + return std::string(tag, 0, i); + } + } + return tag; +} + +} // anonymous namespace + +Calendar::Calendar(JniCache* jni_cache) + : jni_cache_(jni_cache), + jenv_(jni_cache_ ? jni_cache->GetEnv() : nullptr) {} + +bool Calendar::Initialize(const std::string& time_zone, + const std::string& locale, int64 time_ms_utc) { + if (!jni_cache_ || !jenv_) { + TC3_LOG(ERROR) << "Initialize without env"; + return false; + } + + // We'll assume the day indices match later on, so verify it here. + if (jni_cache_->calendar_sunday != DateParseData::SUNDAY || + jni_cache_->calendar_monday != DateParseData::MONDAY || + jni_cache_->calendar_tuesday != DateParseData::TUESDAY || + jni_cache_->calendar_wednesday != DateParseData::WEDNESDAY || + jni_cache_->calendar_thursday != DateParseData::THURSDAY || + jni_cache_->calendar_friday != DateParseData::FRIDAY || + jni_cache_->calendar_saturday != DateParseData::SATURDAY) { + TC3_LOG(ERROR) << "day of the week indices mismatch"; + return false; + } + + // Get the time zone. + ScopedLocalRef<jstring> java_time_zone_str( + jenv_->NewStringUTF(time_zone.c_str())); + ScopedLocalRef<jobject> java_time_zone(jenv_->CallStaticObjectMethod( + jni_cache_->timezone_class.get(), jni_cache_->timezone_get_timezone, + java_time_zone_str.get())); + if (jni_cache_->ExceptionCheckAndClear() || !java_time_zone) { + TC3_LOG(ERROR) << "failed to get timezone"; + return false; + } + + // Get the locale. + ScopedLocalRef<jobject> java_locale; + if (jni_cache_->locale_for_language_tag) { + // API level 21+, we can actually parse language tags. + ScopedLocalRef<jstring> java_locale_str( + jenv_->NewStringUTF(locale.c_str())); + java_locale.reset(jenv_->CallStaticObjectMethod( + jni_cache_->locale_class.get(), jni_cache_->locale_for_language_tag, + java_locale_str.get())); + } else { + // API level <21. We can't parse tags, so we just use the language. + ScopedLocalRef<jstring> java_language_str( + jenv_->NewStringUTF(GetFirstBcp47Tag(locale).c_str())); + java_locale.reset(jenv_->NewObject(jni_cache_->locale_class.get(), + jni_cache_->locale_init_string, + java_language_str.get())); + } + if (jni_cache_->ExceptionCheckAndClear() || !java_locale) { + TC3_LOG(ERROR) << "failed to get locale"; + return false; + } + + // Get the calendar. + calendar_.reset(jenv_->CallStaticObjectMethod( + jni_cache_->calendar_class.get(), jni_cache_->calendar_get_instance, + java_time_zone.get(), java_locale.get())); + if (jni_cache_->ExceptionCheckAndClear() || !calendar_) { + TC3_LOG(ERROR) << "failed to get calendar"; + return false; + } + + // Set the time. + jenv_->CallVoidMethod(calendar_.get(), + jni_cache_->calendar_set_time_in_millis, time_ms_utc); + if (jni_cache_->ExceptionCheckAndClear()) { + TC3_LOG(ERROR) << "failed to set time"; + return false; + } + return true; +} + +bool Calendar::GetFirstDayOfWeek(int* value) const { + if (!jni_cache_ || !jenv_ || !calendar_) return false; + *value = jenv_->CallIntMethod(calendar_.get(), + jni_cache_->calendar_get_first_day_of_week); + return !jni_cache_->ExceptionCheckAndClear(); +} + +bool Calendar::GetTimeInMillis(int64* value) const { + if (!jni_cache_ || !jenv_ || !calendar_) return false; + *value = jenv_->CallLongMethod(calendar_.get(), + jni_cache_->calendar_get_time_in_millis); + return !jni_cache_->ExceptionCheckAndClear(); +} + +CalendarLib::CalendarLib() { + TC3_LOG(FATAL) << "Java ICU CalendarLib must be initialized with a JniCache."; +} + +CalendarLib::CalendarLib(const std::shared_ptr<JniCache>& jni_cache) + : jni_cache_(jni_cache) {} + +// Below is the boilerplate code for implementing the specialisations of +// get/set/add for the various field types. +#define TC3_DEFINE_FIELD_ACCESSOR(NAME, FIELD, KIND, TYPE) \ + bool Calendar::KIND##NAME(TYPE value) const { \ + if (!jni_cache_ || !jenv_ || !calendar_) return false; \ + return Calendar##KIND(jni_cache_, jenv_, calendar_.get(), \ + jni_cache_->calendar_##FIELD, value); \ + } +#define TC3_DEFINE_ADD(NAME, CONST) \ + TC3_DEFINE_FIELD_ACCESSOR(NAME, CONST, Add, int) +#define TC3_DEFINE_SET(NAME, CONST) \ + TC3_DEFINE_FIELD_ACCESSOR(NAME, CONST, Set, int) +#define TC3_DEFINE_GET(NAME, CONST) \ + TC3_DEFINE_FIELD_ACCESSOR(NAME, CONST, Get, int*) + +TC3_DEFINE_ADD(DayOfMonth, day_of_month) +TC3_DEFINE_ADD(Year, year) +TC3_DEFINE_ADD(Month, month) +TC3_DEFINE_GET(DayOfWeek, day_of_week) +TC3_DEFINE_SET(ZoneOffset, zone_offset) +TC3_DEFINE_SET(DstOffset, dst_offset) +TC3_DEFINE_SET(Year, year) +TC3_DEFINE_SET(Month, month) +TC3_DEFINE_SET(DayOfYear, day_of_year) +TC3_DEFINE_SET(DayOfMonth, day_of_month) +TC3_DEFINE_SET(DayOfWeek, day_of_week) +TC3_DEFINE_SET(HourOfDay, hour_of_day) +TC3_DEFINE_SET(Minute, minute) +TC3_DEFINE_SET(Second, second) +TC3_DEFINE_SET(Millisecond, millisecond) + +#undef TC3_DEFINE_FIELD_ACCESSOR +#undef TC3_DEFINE_ADD +#undef TC3_DEFINE_SET +#undef TC3_DEFINE_GET + +} // namespace libtextclassifier3 diff --git a/utils/calendar/calendar-javaicu.h b/utils/calendar/calendar-javaicu.h new file mode 100644 index 0000000..88e696a --- /dev/null +++ b/utils/calendar/calendar-javaicu.h @@ -0,0 +1,89 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_JAVAICU_H_ +#define LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_JAVAICU_H_ + +#include <jni.h> +#include <memory> +#include <string> + +#include "annotator/types.h" +#include "utils/base/integral_types.h" +#include "utils/calendar/calendar-common.h" +#include "utils/java/jni-cache.h" +#include "utils/java/scoped_local_ref.h" + +namespace libtextclassifier3 { + +class Calendar { + public: + explicit Calendar(JniCache* jni_cache); + bool Initialize(const std::string& time_zone, const std::string& locale, + int64 time_ms_utc); + bool AddDayOfMonth(int value) const; + bool AddYear(int value) const; + bool AddMonth(int value) const; + bool GetDayOfWeek(int* value) const; + bool GetFirstDayOfWeek(int* value) const; + bool GetTimeInMillis(int64* value) const; + bool SetZoneOffset(int value) const; + bool SetDstOffset(int value) const; + bool SetYear(int value) const; + bool SetMonth(int value) const; + bool SetDayOfYear(int value) const; + bool SetDayOfMonth(int value) const; + bool SetDayOfWeek(int value) const; + bool SetHourOfDay(int value) const; + bool SetMinute(int value) const; + bool SetSecond(int value) const; + bool SetMillisecond(int value) const; + + private: + JniCache* jni_cache_; + JNIEnv* jenv_; + ScopedLocalRef<jobject> calendar_; +}; + +class CalendarLib { + public: + CalendarLib(); + explicit CalendarLib(const std::shared_ptr<JniCache>& jni_cache); + + // Returns false (dummy version). + bool InterpretParseData(const DateParseData& parse_data, + int64 reference_time_ms_utc, + const std::string& reference_timezone, + const std::string& reference_locale, + DatetimeGranularity granularity, + int64* interpreted_time_ms_utc) const { + Calendar calendar(jni_cache_.get()); + calendar::CalendarLibTempl<Calendar> impl; + if (!impl.InterpretParseData(parse_data, reference_time_ms_utc, + reference_timezone, reference_locale, + granularity, &calendar)) { + return false; + } + return calendar.GetTimeInMillis(interpreted_time_ms_utc); + } + + private: + std::shared_ptr<JniCache> jni_cache_; +}; + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_JAVAICU_H_ diff --git a/utils/calendar/calendar.h b/utils/calendar/calendar.h new file mode 100644 index 0000000..99b137f --- /dev/null +++ b/utils/calendar/calendar.h @@ -0,0 +1,23 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_H_ +#define LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_H_ + +#include "utils/calendar/calendar-javaicu.h" +#define INIT_CALENDARLIB_FOR_TESTING(VAR) VAR(nullptr) + +#endif // LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_H_ diff --git a/utils/calendar/calendar_test.cc b/utils/calendar/calendar_test.cc new file mode 100644 index 0000000..a8c3af8 --- /dev/null +++ b/utils/calendar/calendar_test.cc @@ -0,0 +1,244 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This test serves the purpose of making sure all the different implementations +// of the unspoken CalendarLib interface support the same methods. + +#include "utils/calendar/calendar.h" +#include "utils/base/logging.h" + +#include "gtest/gtest.h" + +namespace libtextclassifier3 { +namespace { + +class CalendarTest : public ::testing::Test { + protected: + CalendarTest() : INIT_CALENDARLIB_FOR_TESTING(calendarlib_) {} + CalendarLib calendarlib_; +}; + +TEST_F(CalendarTest, Interface) { + int64 time; + std::string timezone; + bool result = calendarlib_.InterpretParseData( + DateParseData{/*field_set_mask=*/0, /*year=*/0, /*month=*/0, + /*day_of_month=*/0, /*hour=*/0, /*minute=*/0, /*second=*/0, + /*ampm=*/0, /*zone_offset=*/0, /*dst_offset=*/0, + static_cast<DateParseData::Relation>(0), + static_cast<DateParseData::RelationType>(0), + /*relation_distance=*/0}, + 0L, "Zurich", "en-CH", GRANULARITY_UNKNOWN, &time); + TC3_LOG(INFO) << result; +} + +#ifdef TC3_CALENDAR_ICU +TEST_F(CalendarTest, RoundingToGranularity) { + int64 time; + DateParseData data; + data.year = 2018; + data.month = 4; + data.day_of_month = 25; + data.hour = 9; + data.minute = 33; + data.second = 59; + data.field_set_mask = DateParseData::YEAR_FIELD | DateParseData::MONTH_FIELD | + DateParseData::DAY_FIELD | DateParseData::HOUR_FIELD | + DateParseData::MINUTE_FIELD | + DateParseData::SECOND_FIELD; + ASSERT_TRUE(calendarlib_.InterpretParseData( + data, + /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"en-CH", + /*granularity=*/GRANULARITY_YEAR, &time)); + EXPECT_EQ(time, 1514761200000L /* Jan 01 2018 00:00:00 */); + + ASSERT_TRUE(calendarlib_.InterpretParseData( + data, + /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"en-CH", + /*granularity=*/GRANULARITY_MONTH, &time)); + EXPECT_EQ(time, 1522533600000L /* Apr 01 2018 00:00:00 */); + + ASSERT_TRUE(calendarlib_.InterpretParseData( + data, + /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"en-CH", + /*granularity=*/GRANULARITY_WEEK, &time)); + EXPECT_EQ(time, 1524434400000L /* Mon Apr 23 2018 00:00:00 */); + + ASSERT_TRUE(calendarlib_.InterpretParseData( + data, + /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"*-CH", + /*granularity=*/GRANULARITY_WEEK, &time)); + EXPECT_EQ(time, 1524434400000L /* Mon Apr 23 2018 00:00:00 */); + + ASSERT_TRUE(calendarlib_.InterpretParseData( + data, + /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"en-US", + /*granularity=*/GRANULARITY_WEEK, &time)); + EXPECT_EQ(time, 1524348000000L /* Sun Apr 22 2018 00:00:00 */); + + ASSERT_TRUE(calendarlib_.InterpretParseData( + data, + /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"*-US", + /*granularity=*/GRANULARITY_WEEK, &time)); + EXPECT_EQ(time, 1524348000000L /* Sun Apr 22 2018 00:00:00 */); + + ASSERT_TRUE(calendarlib_.InterpretParseData( + data, + /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"en-CH", + /*granularity=*/GRANULARITY_DAY, &time)); + EXPECT_EQ(time, 1524607200000L /* Apr 25 2018 00:00:00 */); + + ASSERT_TRUE(calendarlib_.InterpretParseData( + data, + /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"en-CH", + /*granularity=*/GRANULARITY_HOUR, &time)); + EXPECT_EQ(time, 1524639600000L /* Apr 25 2018 09:00:00 */); + + ASSERT_TRUE(calendarlib_.InterpretParseData( + data, + /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"en-CH", + /*granularity=*/GRANULARITY_MINUTE, &time)); + EXPECT_EQ(time, 1524641580000 /* Apr 25 2018 09:33:00 */); + + ASSERT_TRUE(calendarlib_.InterpretParseData( + data, + /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"en-CH", + /*granularity=*/GRANULARITY_SECOND, &time)); + EXPECT_EQ(time, 1524641639000 /* Apr 25 2018 09:33:59 */); +} + +TEST_F(CalendarTest, RelativeTimeWeekday) { + const int field_mask = DateParseData::RELATION_FIELD | + DateParseData::RELATION_TYPE_FIELD | + DateParseData::RELATION_DISTANCE_FIELD; + const int64 ref_time = 1524648839000L; /* 25 April 2018 09:33:59 */ + int64 time; + + // Two Weds from now. + const DateParseData future_wed_parse = { + field_mask, + /*year=*/0, + /*month=*/0, + /*day_of_month=*/0, + /*hour=*/0, + /*minute=*/0, + /*second=*/0, + /*ampm=*/0, + /*zone_offset=*/0, + /*dst_offset=*/0, + DateParseData::Relation::FUTURE, + DateParseData::RelationType::WEDNESDAY, + /*relation_distance=*/2}; + ASSERT_TRUE(calendarlib_.InterpretParseData( + future_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"en-US", + /*granularity=*/GRANULARITY_DAY, &time)); + EXPECT_EQ(time, 1525816800000L /* 9 May 2018 00:00:00 */); + + // Next Wed. + const DateParseData next_wed_parse = {field_mask, + /*year=*/0, + /*month=*/0, + /*day_of_month=*/0, + /*hour=*/0, + /*minute=*/0, + /*second=*/0, + /*ampm=*/0, + /*zone_offset=*/0, + /*dst_offset=*/0, + DateParseData::Relation::NEXT, + DateParseData::RelationType::WEDNESDAY, + /*relation_distance=*/0}; + ASSERT_TRUE(calendarlib_.InterpretParseData( + next_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"en-US", + /*granularity=*/GRANULARITY_DAY, &time)); + EXPECT_EQ(time, 1525212000000L /* 1 May 2018 00:00:00 */); + + // Same Wed. + const DateParseData same_wed_parse = {field_mask, + /*year=*/0, + /*month=*/0, + /*day_of_month=*/0, + /*hour=*/0, + /*minute=*/0, + /*second=*/0, + /*ampm=*/0, + /*zone_offset=*/0, + /*dst_offset=*/0, + DateParseData::Relation::NEXT_OR_SAME, + DateParseData::RelationType::WEDNESDAY, + /*relation_distance=*/0}; + ASSERT_TRUE(calendarlib_.InterpretParseData( + same_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"en-US", + /*granularity=*/GRANULARITY_DAY, &time)); + EXPECT_EQ(time, 1524607200000L /* 25 April 2018 00:00:00 */); + + // Previous Wed. + const DateParseData last_wed_parse = {field_mask, + /*year=*/0, + /*month=*/0, + /*day_of_month=*/0, + /*hour=*/0, + /*minute=*/0, + /*second=*/0, + /*ampm=*/0, + /*zone_offset=*/0, + /*dst_offset=*/0, + DateParseData::Relation::LAST, + DateParseData::RelationType::WEDNESDAY, + /*relation_distance=*/0}; + ASSERT_TRUE(calendarlib_.InterpretParseData( + last_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"en-US", + /*granularity=*/GRANULARITY_DAY, &time)); + EXPECT_EQ(time, 1524002400000L /* 18 April 2018 00:00:00 */); + + // Two Weds ago. + const DateParseData past_wed_parse = {field_mask, + /*year=*/0, + /*month=*/0, + /*day_of_month=*/0, + /*hour=*/0, + /*minute=*/0, + /*second=*/0, + /*ampm=*/0, + /*zone_offset=*/0, + /*dst_offset=*/0, + DateParseData::Relation::PAST, + DateParseData::RelationType::WEDNESDAY, + /*relation_distance=*/2}; + ASSERT_TRUE(calendarlib_.InterpretParseData( + past_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich", + /*reference_locale=*/"en-US", + /*granularity=*/GRANULARITY_DAY, &time)); + EXPECT_EQ(time, 1523397600000L /* 11 April 2018 00:00:00 */); +} +#endif // TC3_UNILIB_DUMMY + +} // namespace +} // namespace libtextclassifier3 diff --git a/utils/checksum.cc b/utils/checksum.cc new file mode 100644 index 0000000..87b2d37 --- /dev/null +++ b/utils/checksum.cc @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/checksum.h" +#include "utils/strings/numbers.h" + +namespace libtextclassifier3 { + +bool VerifyLuhnChecksum(const std::string& input, bool ignore_whitespace) { + int sum = 0; + int num_digits = 0; + bool is_odd = true; + + // http://en.wikipedia.org/wiki/Luhn_algorithm + static const int kPrecomputedSumsOfDoubledDigits[] = {0, 2, 4, 6, 8, + 1, 3, 5, 7, 9}; + for (int i = input.size() - 1; i >= 0; i--) { + const char c = input[i]; + if (ignore_whitespace && c == ' ') { + continue; + } + if (!isdigit(c)) { + return false; + } + ++num_digits; + const int digit = c - '0'; + if (is_odd) { + sum += digit; + } else { + sum += kPrecomputedSumsOfDoubledDigits[digit]; + } + is_odd = !is_odd; + } + return (num_digits > 1 && sum % 10 == 0); +} + +} // namespace libtextclassifier3 diff --git a/utils/checksum.h b/utils/checksum.h new file mode 100644 index 0000000..2f94219 --- /dev/null +++ b/utils/checksum.h @@ -0,0 +1,34 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Utility functions for calculating and verifying checksums. + +#ifndef LIBTEXTCLASSIFIER_UTILS_CHECKSUM_H_ +#define LIBTEXTCLASSIFIER_UTILS_CHECKSUM_H_ + +#include <string> + +namespace libtextclassifier3 { + +// Computes and verifies that the last digit of `input` matches the Luhn +// checksum. Returns false if presented with non-digits, or on whitespace +// characters if `ignore_whitespace` is false. +bool VerifyLuhnChecksum(const std::string& input, + bool ignore_whitespace = true); + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_UTILS_CHECKSUM_H_ diff --git a/utils/checksum_test.cc b/utils/checksum_test.cc new file mode 100644 index 0000000..dd04956 --- /dev/null +++ b/utils/checksum_test.cc @@ -0,0 +1,57 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/checksum.h" +#include "gtest/gtest.h" + +namespace libtextclassifier3 { +namespace { + +TEST(LuhnTest, CorrectlyHandlesSimpleCases) { + EXPECT_TRUE(VerifyLuhnChecksum("3782 8224 6310 005")); + EXPECT_FALSE(VerifyLuhnChecksum("0")); + EXPECT_FALSE(VerifyLuhnChecksum("1")); + EXPECT_FALSE(VerifyLuhnChecksum("0A")); +} + +TEST(LuhnTest, CorrectlyVerifiesPaymentCardNumbers) { + // Fake test numbers. + EXPECT_TRUE(VerifyLuhnChecksum("3782 8224 6310 005")); + EXPECT_TRUE(VerifyLuhnChecksum("371449635398431")); + EXPECT_TRUE(VerifyLuhnChecksum("5610591081018250")); + EXPECT_TRUE(VerifyLuhnChecksum("38520000023237")); + EXPECT_TRUE(VerifyLuhnChecksum("6011000990139424")); + EXPECT_TRUE(VerifyLuhnChecksum("3566002020360505")); + EXPECT_TRUE(VerifyLuhnChecksum("5105105105105100")); + EXPECT_TRUE(VerifyLuhnChecksum("4012 8888 8888 1881")); +} + +TEST(LuhnTest, HandlesWhitespace) { + EXPECT_TRUE( + VerifyLuhnChecksum("3782 8224 6310 005 ", /*ignore_whitespace=*/true)); + EXPECT_FALSE( + VerifyLuhnChecksum("3782 8224 6310 005 ", /*ignore_whitespace=*/false)); +} + +TEST(LuhnTest, HandlesEdgeCases) { + EXPECT_FALSE(VerifyLuhnChecksum(" ", /*ignore_whitespace=*/true)); + EXPECT_FALSE(VerifyLuhnChecksum(" ", /*ignore_whitespace=*/false)); + EXPECT_FALSE(VerifyLuhnChecksum("", /*ignore_whitespace=*/true)); + EXPECT_FALSE(VerifyLuhnChecksum("", /*ignore_whitespace=*/false)); +} + +} // namespace +} // namespace libtextclassifier3 diff --git a/util/flatbuffers.cc b/utils/flatbuffers.cc index 6c0108e..c1c2625 100644 --- a/util/flatbuffers.cc +++ b/utils/flatbuffers.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,13 +14,13 @@ * limitations under the License. */ -#include "util/flatbuffers.h" +#include "utils/flatbuffers.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { template <> const char* FlatbufferFileIdentifier<Model>() { return ModelIdentifier(); } -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/util/flatbuffers.h b/utils/flatbuffers.h index 93d73b6..4031f89 100644 --- a/util/flatbuffers.h +++ b/utils/flatbuffers.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,16 +16,16 @@ // Utility functions for working with FlatBuffers. -#ifndef LIBTEXTCLASSIFIER_UTIL_FLATBUFFERS_H_ -#define LIBTEXTCLASSIFIER_UTIL_FLATBUFFERS_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_ +#define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_ #include <memory> #include <string> -#include "model_generated.h" +#include "annotator/model_generated.h" #include "flatbuffers/flatbuffers.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { // Loads and interprets the buffer as 'FlatbufferMessage' and verifies its // integrity. @@ -93,6 +93,6 @@ std::string PackFlatbuffer( builder.GetSize()); } -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_UTIL_FLATBUFFERS_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_ diff --git a/util/hash/farmhash.cc b/utils/hash/farmhash.cc index 673f45f..884a561 100644 --- a/util/hash/farmhash.cc +++ b/utils/hash/farmhash.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "util/hash/farmhash.h" +#include "utils/hash/farmhash.h" // FARMHASH ASSUMPTIONS: Modify as needed, or use -DFARMHASH_ASSUME_SSE42 etc. // Note that if you use -DFARMHASH_ASSUME_SSE42 you likely need -msse42 @@ -136,8 +136,8 @@ #elif defined(__NetBSD__) -#include <sys/types.h> #include <machine/bswap.h> +#include <sys/types.h> #if defined(__BSWAP_RENAME) && !defined(__bswap_32) #undef bswap_32 #undef bswap_64 @@ -1509,9 +1509,9 @@ uint128_t Fingerprint128(const char* s, size_t len) { #ifndef FARMHASH_SELF_TEST_GUARD #define FARMHASH_SELF_TEST_GUARD +#include <string.h> #include <cstdio> #include <iostream> -#include <string.h> using std::cout; using std::cerr; @@ -3129,9 +3129,9 @@ int main(int argc, char** argv) { #endif #ifndef FARMHASH_SELF_TEST_GUARD #define FARMHASH_SELF_TEST_GUARD +#include <string.h> #include <cstdio> #include <iostream> -#include <string.h> using std::cout; using std::cerr; @@ -4021,9 +4021,9 @@ int main(int argc, char** argv) { #endif #ifndef FARMHASH_SELF_TEST_GUARD #define FARMHASH_SELF_TEST_GUARD +#include <string.h> #include <cstdio> #include <iostream> -#include <string.h> using std::cout; using std::cerr; @@ -5277,9 +5277,9 @@ int main(int argc, char** argv) { #endif #ifndef FARMHASH_SELF_TEST_GUARD #define FARMHASH_SELF_TEST_GUARD +#include <string.h> #include <cstdio> #include <iostream> -#include <string.h> using std::cout; using std::cerr; @@ -6169,9 +6169,9 @@ int main(int argc, char** argv) { #endif #ifndef FARMHASH_SELF_TEST_GUARD #define FARMHASH_SELF_TEST_GUARD +#include <string.h> #include <cstdio> #include <iostream> -#include <string.h> using std::cout; using std::cerr; @@ -7061,9 +7061,9 @@ int main(int argc, char** argv) { #endif #ifndef FARMHASH_SELF_TEST_GUARD #define FARMHASH_SELF_TEST_GUARD +#include <string.h> #include <cstdio> #include <iostream> -#include <string.h> using std::cout; using std::cerr; diff --git a/util/hash/farmhash.h b/utils/hash/farmhash.h index 477b7a8..f374c0b 100644 --- a/util/hash/farmhash.h +++ b/utils/hash/farmhash.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_UTIL_HASH_FARMHASH_H_ -#define LIBTEXTCLASSIFIER_UTIL_HASH_FARMHASH_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_HASH_FARMHASH_H_ +#define LIBTEXTCLASSIFIER_UTILS_HASH_FARMHASH_H_ #include <assert.h> #include <stdint.h> @@ -24,7 +24,7 @@ #include <utility> #ifndef NAMESPACE_FOR_HASH_FUNCTIONS -#define NAMESPACE_FOR_HASH_FUNCTIONS tc2farmhash +#define NAMESPACE_FOR_HASH_FUNCTIONS tc3farmhash #endif namespace NAMESPACE_FOR_HASH_FUNCTIONS { @@ -261,4 +261,4 @@ inline uint128_t Fingerprint128(const Str& s) { } // namespace NAMESPACE_FOR_HASH_FUNCTIONS -#endif // LIBTEXTCLASSIFIER_UTIL_HASH_FARMHASH_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_HASH_FARMHASH_H_ diff --git a/util/i18n/locale.cc b/utils/i18n/locale.cc index c587d2d..acd0379 100644 --- a/util/i18n/locale.cc +++ b/utils/i18n/locale.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,11 +14,11 @@ * limitations under the License. */ -#include "util/i18n/locale.h" +#include "utils/i18n/locale.h" -#include "util/strings/split.h" +#include "utils/strings/split.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace { @@ -107,4 +107,4 @@ Locale Locale::FromBCP47(const std::string& locale_tag) { return Locale(language.ToString(), script.ToString(), region.ToString()); } -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/util/i18n/locale.h b/utils/i18n/locale.h index 16f10dc..4cfcc22 100644 --- a/util/i18n/locale.h +++ b/utils/i18n/locale.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,14 +14,14 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_UTIL_I18N_LOCALE_H_ -#define LIBTEXTCLASSIFIER_UTIL_I18N_LOCALE_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_I18N_LOCALE_H_ +#define LIBTEXTCLASSIFIER_UTILS_I18N_LOCALE_H_ #include <string> -#include "util/base/integral_types.h" +#include "utils/base/integral_types.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { class Locale { public: @@ -58,6 +58,6 @@ class Locale { bool is_valid_; }; -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_UTIL_I18N_LOCALE_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_I18N_LOCALE_H_ diff --git a/util/i18n/locale_test.cc b/utils/i18n/locale_test.cc index 72ece98..3722727 100644 --- a/util/i18n/locale_test.cc +++ b/utils/i18n/locale_test.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,11 +14,11 @@ * limitations under the License. */ -#include "util/i18n/locale.h" +#include "utils/i18n/locale.h" #include "gtest/gtest.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace { TEST(LocaleTest, ParseUnknown) { @@ -67,4 +67,4 @@ TEST(LocaleTest, ParseCineseTraditional) { } } // namespace -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/utils/intents/intent-config.fbs b/utils/intents/intent-config.fbs new file mode 100755 index 0000000..93a6fc9 --- /dev/null +++ b/utils/intents/intent-config.fbs @@ -0,0 +1,192 @@ +// +// Copyright (C) 2018 The Android Open Source Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// The type of variable to fetch. +namespace libtextclassifier3; +enum AndroidSimpleIntentGeneratorVariableType : int { + INVALID_VARIABLE = 0, + + // The raw text that was classified. + RAW_TEXT = 1, + + // Text as a URL with explicit protocol. If no protocol was specified, http + // is prepended. + URL_TEXT = 2, + + // The raw text, but URL encoded. + URL_ENCODED_TEXT = 3, + + // For dates/times: the instant of the event in UTC millis. + EVENT_TIME_MS_UTC = 4, + + // For dates/times: the start of the event in UTC millis. + EVENT_START_MS_UTC = 5, + + // For dates/times: the end of the event in UTC millis. + EVENT_END_MS_UTC = 6, + + // Name of the package that's running the classifier. + PACKAGE_NAME = 7, +} + +// Enumerates the possible extra types for the simple intent generator. +namespace libtextclassifier3; +enum AndroidSimpleIntentGeneratorExtraType : int { + INVALID_EXTRA_TYPE = 0, + STRING = 1, + BOOL = 2, + VARIABLE_AS_LONG = 3, +} + +// Enumerates the possible condition types for the simple intent generator. +namespace libtextclassifier3; +enum AndroidSimpleIntentGeneratorConditionType : int { + INVALID_CONDITION_TYPE = 0, + + // Queries the UserManager for the given boolean restriction. The condition + // passes if the result is of getBoolean is false. The name of the + // restriction to check is in the string_ field. + USER_RESTRICTION_NOT_SET = 1, + + // Checks that the parsed event start time is at least a give number of + // milliseconds in the future. (Only valid if there is a parsed event + // time) The offset is stored in the int64_ field. + EVENT_START_IN_FUTURE_MS = 2, +} + +// Describes how intents for the various entity types should be generated on +// Android. This is distributed through the model, but not used by +// libtextclassifier yet - rather, it's passed to the calling Java code, which +// implements the Intent generation logic. +namespace libtextclassifier3; +table AndroidIntentFactoryOptions { + entity:[libtextclassifier3.AndroidIntentFactoryEntityOptions]; +} + +// Describes how intents should be generated for a particular entity type. +namespace libtextclassifier3; +table AndroidIntentFactoryEntityOptions { + // The entity type as defined by one of the TextClassifier ENTITY_TYPE + // constants. (e.g. "address", "phone", etc.) + entity_type:string; + + // List of generators for all the different types of intents that should + // be made available for the entity type. + generator:[libtextclassifier3.AndroidIntentGeneratorOptions]; +} + +// Configures a single Android Intent generator. +namespace libtextclassifier3; +table AndroidIntentGeneratorOptions { + // Strings for UI elements. + strings:[libtextclassifier3.AndroidIntentGeneratorStrings]; + + // Generator specific configuration. + simple:libtextclassifier3.AndroidSimpleIntentGeneratorOptions; +} + +// Language dependent configuration for an Android Intent generator. +namespace libtextclassifier3; +table AndroidIntentGeneratorStrings { + // BCP 47 tag for the supported locale. Note that because of API level + // restrictions, this must /not/ use wildcards. To e.g. match all English + // locales, use only "en" and not "en_*". Reference the java.util.Locale + // constructor for details. + language_tag:string; + + // Title shown for the action (see RemoteAction.getTitle). + title:string; + + // Description shown for the action (see + // RemoteAction.getContentDescription). + description:string; +} + +// An extra to set on a simple intent generator Intent. +namespace libtextclassifier3; +table AndroidSimpleIntentGeneratorExtra { + // The name of the extra to set. + name:string; + + // The type of the extra to set. + type:libtextclassifier3.AndroidSimpleIntentGeneratorExtraType; + + string_:string; + + bool_:bool; + int32_:int; +} + +// A condition that needs to be fulfilled for an Intent to get generated. +namespace libtextclassifier3; +table AndroidSimpleIntentGeneratorCondition { + type:libtextclassifier3.AndroidSimpleIntentGeneratorConditionType; + + string_:string; + + int32_:int; + int64_:long; +} + +// Configures an intent generator where the logic is simple to be expressed with +// basic rules - which covers the vast majority of use cases and is analogous +// to Android Actions. +// Most strings (action, data, type, ...) may contain variable references. To +// use them, the generator must first declare all the variables it wishes to use +// in the variables field. The values then become available as numbered +// arguments (using the normal java.util.Formatter syntax) in the order they +// were specified. +namespace libtextclassifier3; +table AndroidSimpleIntentGeneratorOptions { + // The action to set on the Intent (see Intent.setAction). Supports variables. + action:string; + + // The data to set on the Intent (see Intent.setData). Supports variables. + data:string; + + // The type to set on the Intent (see Intent.setType). Supports variables. + type:string; + + // The list of all the extras to add to the Intent. + extra:[libtextclassifier3.AndroidSimpleIntentGeneratorExtra]; + + // The list of all the variables that become available for substitution in + // the action, data, type and extra strings. To e.g. set a field to the value + // of the first variable, use "%0$s". + variable:[libtextclassifier3.AndroidSimpleIntentGeneratorVariableType]; + + // The list of all conditions that need to be fulfilled for Intent generation. + condition:[libtextclassifier3.AndroidSimpleIntentGeneratorCondition]; +} + +// Describes how intents should be generated for a particular entity type. +namespace libtextclassifier3.IntentFactoryModel_; +table IntentGenerator { + // The entity type as defined by on the TextClassifier ENTITY_TYPE constants + // e.g. "address", "phone", etc. + entity_type:string; + + // The template generator lua code, either as text source or precompiled + // bytecode. + lua_template_generator:[ubyte]; +} + +// Describes how intents for the various entity types should be generated. +namespace libtextclassifier3; +table IntentFactoryModel { + entities:[libtextclassifier3.IntentFactoryModel_.IntentGenerator]; +} + diff --git a/utils/java/jni-base.cc b/utils/java/jni-base.cc new file mode 100644 index 0000000..330732c --- /dev/null +++ b/utils/java/jni-base.cc @@ -0,0 +1,72 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/java/jni-base.h" + +#include <jni.h> +#include <type_traits> +#include <vector> + +#include "utils/base/integral_types.h" +#include "utils/java/scoped_local_ref.h" +#include "utils/java/string_utils.h" +#include "utils/memory/mmap.h" + +using libtextclassifier3::JStringToUtf8String; +using libtextclassifier3::ScopedLocalRef; + +namespace libtextclassifier3 { + +std::string ToStlString(JNIEnv* env, const jstring& str) { + std::string result; + JStringToUtf8String(env, str, &result); + return result; +} + +jint GetFdFromFileDescriptor(JNIEnv* env, jobject fd) { + ScopedLocalRef<jclass> fd_class(env->FindClass("java/io/FileDescriptor"), + env); + if (fd_class == nullptr) { + TC3_LOG(ERROR) << "Couldn't find FileDescriptor."; + return reinterpret_cast<jlong>(nullptr); + } + jfieldID fd_class_descriptor = + env->GetFieldID(fd_class.get(), "descriptor", "I"); + if (fd_class_descriptor == nullptr) { + TC3_LOG(ERROR) << "Couldn't find descriptor."; + return reinterpret_cast<jlong>(nullptr); + } + return env->GetIntField(fd, fd_class_descriptor); +} + +jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd) { + ScopedLocalRef<jclass> afd_class( + env->FindClass("android/content/res/AssetFileDescriptor"), env); + if (afd_class == nullptr) { + TC3_LOG(ERROR) << "Couldn't find AssetFileDescriptor."; + return reinterpret_cast<jlong>(nullptr); + } + jmethodID afd_class_getFileDescriptor = env->GetMethodID( + afd_class.get(), "getFileDescriptor", "()Ljava/io/FileDescriptor;"); + if (afd_class_getFileDescriptor == nullptr) { + TC3_LOG(ERROR) << "Couldn't find getFileDescriptor."; + return reinterpret_cast<jlong>(nullptr); + } + jobject bundle_jfd = env->CallObjectMethod(afd, afd_class_getFileDescriptor); + return GetFdFromFileDescriptor(env, bundle_jfd); +} + +} // namespace libtextclassifier3 diff --git a/utils/java/jni-base.h b/utils/java/jni-base.h new file mode 100644 index 0000000..23658a3 --- /dev/null +++ b/utils/java/jni-base.h @@ -0,0 +1,83 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_BASE_H_ +#define LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_BASE_H_ + +#include <jni.h> +#include <string> + +// When we use a macro as an argument for a macro, an additional level of +// indirection is needed, if the macro argument is used with # or ##. +#define TC3_ADD_QUOTES_HELPER(TOKEN) #TOKEN +#define TC3_ADD_QUOTES(TOKEN) TC3_ADD_QUOTES_HELPER(TOKEN) + +#ifndef TC3_PACKAGE_NAME +#define TC3_PACKAGE_NAME com_google_android_textclassifier +#endif + +#ifndef TC3_PACKAGE_PATH +#define TC3_PACKAGE_PATH \ + "com/google/android/textclassifier/" +#endif + +#define TC3_JNI_METHOD_NAME_INTERNAL(package_name, class_name, method_name) \ + Java_##package_name##_##class_name##_##method_name + +#define TC3_JNI_METHOD_PRIMITIVE(return_type, package_name, class_name, \ + method_name) \ + JNIEXPORT return_type JNICALL TC3_JNI_METHOD_NAME_INTERNAL( \ + package_name, class_name, method_name) + +// The indirection is needed to correctly expand the TC3_PACKAGE_NAME macro. +// See the explanation near TC3_ADD_QUOTES macro. +#define TC3_JNI_METHOD2(return_type, package_name, class_name, method_name) \ + TC3_JNI_METHOD_PRIMITIVE(return_type, package_name, class_name, method_name) + +#define TC3_JNI_METHOD(return_type, class_name, method_name) \ + TC3_JNI_METHOD2(return_type, TC3_PACKAGE_NAME, class_name, method_name) + +#define TC3_JNI_METHOD_NAME2(package_name, class_name, method_name) \ + TC3_JNI_METHOD_NAME_INTERNAL(package_name, class_name, method_name) + +#define TC3_JNI_METHOD_NAME(class_name, method_name) \ + TC3_JNI_METHOD_NAME2(TC3_PACKAGE_NAME, class_name, method_name) + +namespace libtextclassifier3 { + +template <typename T, typename F> +std::pair<bool, T> CallJniMethod0(JNIEnv* env, jobject object, + jclass class_object, F function, + const std::string& method_name, + const std::string& return_java_type) { + const jmethodID method = env->GetMethodID(class_object, method_name.c_str(), + ("()" + return_java_type).c_str()); + if (!method) { + return std::make_pair(false, T()); + } + return std::make_pair(true, (env->*function)(object, method)); +} + +std::string ToStlString(JNIEnv* env, const jstring& str); + +// Get system-level file descriptor from AssetFileDescriptor. +jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd); + +// Get system-level file descriptor from FileDescriptor. +jint GetFdFromFileDescriptor(JNIEnv* env, jobject fd); +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_BASE_H_ diff --git a/utils/java/jni-cache.cc b/utils/java/jni-cache.cc new file mode 100644 index 0000000..e2a6676 --- /dev/null +++ b/utils/java/jni-cache.cc @@ -0,0 +1,284 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/java/jni-cache.h" + +#include "utils/base/logging.h" + +namespace libtextclassifier3 { + +JniCache::JniCache(JavaVM* jvm) + : jvm(jvm), + string_class(nullptr, jvm), + string_utf8(nullptr, jvm), + pattern_class(nullptr, jvm), + matcher_class(nullptr, jvm), + locale_class(nullptr, jvm), + locale_us(nullptr, jvm), + breakiterator_class(nullptr, jvm), + integer_class(nullptr, jvm), + calendar_class(nullptr, jvm), + timezone_class(nullptr, jvm), + urlencoder_class(nullptr, jvm) +#ifdef __ANDROID__ + , + context_class(nullptr, jvm), + uri_class(nullptr, jvm), + usermanager_class(nullptr, jvm), + bundle_class(nullptr, jvm) +#endif +{ +} + +// The macros below are intended to reduce the boilerplate in Create and avoid +// easily introduced copy/paste errors. +#define TC3_CHECK_JNI_PTR(PTR) TC3_CHECK((PTR) != nullptr) +#define TC3_CHECK_JNI_RESULT(RESULT) TC3_CHECK(RESULT) + +#define TC3_GET_CLASS(FIELD, NAME) \ + result->FIELD##_class = MakeGlobalRef(env->FindClass(NAME), env, jvm); \ + TC3_CHECK_JNI_PTR(result->FIELD##_class) << "Error finding class: " << NAME; + +#define TC3_GET_OPTIONAL_CLASS(FIELD, NAME) \ + { \ + jclass clazz = env->FindClass(NAME); \ + if (clazz != nullptr) { \ + result->FIELD##_class = MakeGlobalRef(clazz, env, jvm); \ + } \ + env->ExceptionClear(); \ + } + +#define TC3_GET_METHOD(CLASS, FIELD, NAME, SIGNATURE) \ + result->CLASS##_##FIELD = \ + env->GetMethodID(result->CLASS##_class.get(), NAME, SIGNATURE); \ + TC3_CHECK_JNI_RESULT(result->CLASS##_##FIELD) \ + << "Error finding method: " << NAME; + +#define TC3_GET_OPTIONAL_METHOD(CLASS, FIELD, NAME, SIGNATURE) \ + if (result->CLASS##_class != nullptr) { \ + result->CLASS##_##FIELD = \ + env->GetMethodID(result->CLASS##_class.get(), NAME, SIGNATURE); \ + env->ExceptionClear(); \ + } + +#define TC3_GET_OPTIONAL_STATIC_METHOD(CLASS, FIELD, NAME, SIGNATURE) \ + if (result->CLASS##_class != nullptr) { \ + result->CLASS##_##FIELD = \ + env->GetStaticMethodID(result->CLASS##_class.get(), NAME, SIGNATURE); \ + env->ExceptionClear(); \ + } + +#define TC3_GET_STATIC_METHOD(CLASS, FIELD, NAME, SIGNATURE) \ + result->CLASS##_##FIELD = \ + env->GetStaticMethodID(result->CLASS##_class.get(), NAME, SIGNATURE); \ + TC3_CHECK_JNI_RESULT(result->CLASS##_##FIELD) \ + << "Error finding method: " << NAME; + +#define TC3_GET_STATIC_OBJECT_FIELD(CLASS, FIELD, NAME, SIGNATURE) \ + const jfieldID CLASS##_##FIELD##_field = \ + env->GetStaticFieldID(result->CLASS##_class.get(), NAME, SIGNATURE); \ + TC3_CHECK_JNI_RESULT(CLASS##_##FIELD##_field) \ + << "Error finding field id: " << NAME; \ + result->CLASS##_##FIELD = \ + MakeGlobalRef(env->GetStaticObjectField(result->CLASS##_class.get(), \ + CLASS##_##FIELD##_field), \ + env, jvm); \ + TC3_CHECK_JNI_RESULT(result->CLASS##_##FIELD) \ + << "Error finding field: " << NAME; + +#define TC3_GET_STATIC_INT_FIELD(CLASS, FIELD, NAME) \ + const jfieldID CLASS##_##FIELD##_field = \ + env->GetStaticFieldID(result->CLASS##_class.get(), NAME, "I"); \ + TC3_CHECK_JNI_RESULT(CLASS##_##FIELD##_field) \ + << "Error finding field id: " << NAME; \ + result->CLASS##_##FIELD = env->GetStaticIntField( \ + result->CLASS##_class.get(), CLASS##_##FIELD##_field); \ + TC3_CHECK_JNI_RESULT(result->CLASS##_##FIELD) \ + << "Error finding field: " << NAME; + +std::unique_ptr<JniCache> JniCache::Create(JNIEnv* env) { + if (env == nullptr) { + return nullptr; + } + JavaVM* jvm = nullptr; + if (JNI_OK != env->GetJavaVM(&jvm) || jvm == nullptr) { + return nullptr; + } + std::unique_ptr<JniCache> result(new JniCache(jvm)); + + // String + TC3_GET_CLASS(string, "java/lang/String"); + TC3_GET_METHOD(string, init_bytes_charset, "<init>", + "([BLjava/lang/String;)V"); + TC3_GET_METHOD(string, code_point_count, "codePointCount", "(II)I"); + TC3_GET_METHOD(string, length, "length", "()I"); + result->string_utf8 = MakeGlobalRef(env->NewStringUTF("UTF-8"), env, jvm); + TC3_CHECK_JNI_PTR(result->string_utf8); + + // Pattern + TC3_GET_CLASS(pattern, "java/util/regex/Pattern"); + TC3_GET_STATIC_METHOD(pattern, compile, "compile", + "(Ljava/lang/String;)Ljava/util/regex/Pattern;"); + TC3_GET_METHOD(pattern, matcher, "matcher", + "(Ljava/lang/CharSequence;)Ljava/util/regex/Matcher;"); + + // Matcher + TC3_GET_CLASS(matcher, "java/util/regex/Matcher"); + TC3_GET_METHOD(matcher, matches, "matches", "()Z"); + TC3_GET_METHOD(matcher, find, "find", "()Z"); + TC3_GET_METHOD(matcher, reset, "reset", "()Ljava/util/regex/Matcher;"); + TC3_GET_METHOD(matcher, start_idx, "start", "(I)I"); + TC3_GET_METHOD(matcher, end_idx, "end", "(I)I"); + TC3_GET_METHOD(matcher, group, "group", "()Ljava/lang/String;"); + TC3_GET_METHOD(matcher, group_idx, "group", "(I)Ljava/lang/String;"); + + // Locale + TC3_GET_CLASS(locale, "java/util/Locale"); + TC3_GET_STATIC_OBJECT_FIELD(locale, us, "US", "Ljava/util/Locale;"); + TC3_GET_METHOD(locale, init_string, "<init>", "(Ljava/lang/String;)V"); + TC3_GET_OPTIONAL_STATIC_METHOD(locale, for_language_tag, "forLanguageTag", + "(Ljava/lang/String;)Ljava/util/Locale;"); + + // BreakIterator + TC3_GET_CLASS(breakiterator, "java/text/BreakIterator"); + TC3_GET_STATIC_METHOD(breakiterator, getwordinstance, "getWordInstance", + "(Ljava/util/Locale;)Ljava/text/BreakIterator;"); + TC3_GET_METHOD(breakiterator, settext, "setText", "(Ljava/lang/String;)V"); + TC3_GET_METHOD(breakiterator, next, "next", "()I"); + + // Integer + TC3_GET_CLASS(integer, "java/lang/Integer"); + TC3_GET_STATIC_METHOD(integer, parse_int, "parseInt", + "(Ljava/lang/String;)I"); + + // Calendar. + TC3_GET_CLASS(calendar, "java/util/Calendar"); + TC3_GET_STATIC_METHOD( + calendar, get_instance, "getInstance", + "(Ljava/util/TimeZone;Ljava/util/Locale;)Ljava/util/Calendar;"); + TC3_GET_METHOD(calendar, get_first_day_of_week, "getFirstDayOfWeek", "()I"); + TC3_GET_METHOD(calendar, get_time_in_millis, "getTimeInMillis", "()J"); + TC3_GET_METHOD(calendar, set_time_in_millis, "setTimeInMillis", "(J)V"); + TC3_GET_METHOD(calendar, add, "add", "(II)V"); + TC3_GET_METHOD(calendar, get, "get", "(I)I"); + TC3_GET_METHOD(calendar, set, "set", "(II)V"); + TC3_GET_STATIC_INT_FIELD(calendar, zone_offset, "ZONE_OFFSET"); + TC3_GET_STATIC_INT_FIELD(calendar, dst_offset, "DST_OFFSET"); + TC3_GET_STATIC_INT_FIELD(calendar, year, "YEAR"); + TC3_GET_STATIC_INT_FIELD(calendar, month, "MONTH"); + TC3_GET_STATIC_INT_FIELD(calendar, day_of_year, "DAY_OF_YEAR"); + TC3_GET_STATIC_INT_FIELD(calendar, day_of_month, "DAY_OF_MONTH"); + TC3_GET_STATIC_INT_FIELD(calendar, day_of_week, "DAY_OF_WEEK"); + TC3_GET_STATIC_INT_FIELD(calendar, hour_of_day, "HOUR_OF_DAY"); + TC3_GET_STATIC_INT_FIELD(calendar, minute, "MINUTE"); + TC3_GET_STATIC_INT_FIELD(calendar, second, "SECOND"); + TC3_GET_STATIC_INT_FIELD(calendar, millisecond, "MILLISECOND"); + TC3_GET_STATIC_INT_FIELD(calendar, sunday, "SUNDAY"); + TC3_GET_STATIC_INT_FIELD(calendar, monday, "MONDAY"); + TC3_GET_STATIC_INT_FIELD(calendar, tuesday, "TUESDAY"); + TC3_GET_STATIC_INT_FIELD(calendar, wednesday, "WEDNESDAY"); + TC3_GET_STATIC_INT_FIELD(calendar, thursday, "THURSDAY"); + TC3_GET_STATIC_INT_FIELD(calendar, friday, "FRIDAY"); + TC3_GET_STATIC_INT_FIELD(calendar, saturday, "SATURDAY"); + + // TimeZone. + TC3_GET_CLASS(timezone, "java/util/TimeZone"); + TC3_GET_STATIC_METHOD(timezone, get_timezone, "getTimeZone", + "(Ljava/lang/String;)Ljava/util/TimeZone;"); + + // URLEncoder. + TC3_GET_CLASS(urlencoder, "java/net/URLEncoder"); + TC3_GET_STATIC_METHOD( + urlencoder, encode, "encode", + "(Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;"); + +#ifdef __ANDROID__ + // Context. + TC3_GET_CLASS(context, "android/content/Context"); + TC3_GET_METHOD(context, get_package_name, "getPackageName", + "()Ljava/lang/String;"); + TC3_GET_METHOD(context, get_system_service, "getSystemService", + "(Ljava/lang/String;)Ljava/lang/Object;"); + + // Uri. + TC3_GET_CLASS(uri, "android/net/Uri"); + TC3_GET_STATIC_METHOD(uri, parse, "parse", + "(Ljava/lang/String;)Landroid/net/Uri;"); + TC3_GET_METHOD(uri, get_scheme, "getScheme", "()Ljava/lang/String;"); + + // UserManager. + TC3_GET_OPTIONAL_CLASS(usermanager, "android/os/UserManager"); + TC3_GET_OPTIONAL_METHOD(usermanager, get_user_restrictions, + "getUserRestrictions", "()Landroid/os/Bundle;"); + + // Bundle. + TC3_GET_CLASS(bundle, "android/os/Bundle"); + TC3_GET_METHOD(bundle, get_boolean, "getBoolean", "(Ljava/lang/String;)Z"); +#endif + + return result; +} + +#undef TC3_GET_STATIC_INT_FIELD +#undef TC3_GET_STATIC_OBJECT_FIELD +#undef TC3_GET_STATIC_METHOD +#undef TC3_GET_METHOD +#undef TC3_GET_CLASS +#undef TC3_CHECK_JNI_PTR + +JNIEnv* JniCache::GetEnv() const { + void* env; + if (JNI_OK == jvm->GetEnv(&env, JNI_VERSION_1_4)) { + return reinterpret_cast<JNIEnv*>(env); + } else { + TC3_LOG(ERROR) << "JavaICU UniLib used on unattached thread"; + return nullptr; + } +} + +bool JniCache::ExceptionCheckAndClear() const { + JNIEnv* env = GetEnv(); + TC3_CHECK(env != nullptr); + const bool result = env->ExceptionCheck(); + if (result) { + env->ExceptionDescribe(); + env->ExceptionClear(); + } + return result; +} + +ScopedLocalRef<jstring> JniCache::ConvertToJavaString( + const UnicodeText& text) const { + // Create java byte array. + JNIEnv* jenv = GetEnv(); + const ScopedLocalRef<jbyteArray> text_java_utf8( + jenv->NewByteArray(text.size_bytes()), jenv); + if (!text_java_utf8) { + return nullptr; + } + + jenv->SetByteArrayRegion(text_java_utf8.get(), 0, text.size_bytes(), + reinterpret_cast<const jbyte*>(text.data())); + + // Create the string with a UTF-8 charset. + return ScopedLocalRef<jstring>( + reinterpret_cast<jstring>( + jenv->NewObject(string_class.get(), string_init_bytes_charset, + text_java_utf8.get(), string_utf8.get())), + jenv); +} + +} // namespace libtextclassifier3 diff --git a/utils/java/jni-cache.h b/utils/java/jni-cache.h new file mode 100644 index 0000000..18675fc --- /dev/null +++ b/utils/java/jni-cache.h @@ -0,0 +1,141 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_CACHE_H_ +#define LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_CACHE_H_ + +#include <jni.h> +#include "utils/java/scoped_global_ref.h" +#include "utils/java/scoped_local_ref.h" +#include "utils/utf8/unicodetext.h" + +namespace libtextclassifier3 { + +// A helper class to cache class and method pointers for calls from JNI to Java. +// (for implementations such as Java ICU that need to make calls from C++ to +// Java) +struct JniCache { + static std::unique_ptr<JniCache> Create(JNIEnv* env); + + JNIEnv* GetEnv() const; + bool ExceptionCheckAndClear() const; + + JavaVM* jvm = nullptr; + + // java.lang.String + ScopedGlobalRef<jclass> string_class; + jmethodID string_init_bytes_charset = nullptr; + jmethodID string_code_point_count = nullptr; + jmethodID string_length = nullptr; + ScopedGlobalRef<jstring> string_utf8; + + // java.util.regex.Pattern + ScopedGlobalRef<jclass> pattern_class; + jmethodID pattern_compile = nullptr; + jmethodID pattern_matcher = nullptr; + + // java.util.regex.Matcher + ScopedGlobalRef<jclass> matcher_class; + jmethodID matcher_matches = nullptr; + jmethodID matcher_find = nullptr; + jmethodID matcher_reset = nullptr; + jmethodID matcher_start_idx = nullptr; + jmethodID matcher_end_idx = nullptr; + jmethodID matcher_group = nullptr; + jmethodID matcher_group_idx = nullptr; + + // java.util.Locale + ScopedGlobalRef<jclass> locale_class; + ScopedGlobalRef<jobject> locale_us; + jmethodID locale_init_string = nullptr; + jmethodID locale_for_language_tag = nullptr; + + // java.text.BreakIterator + ScopedGlobalRef<jclass> breakiterator_class; + jmethodID breakiterator_getwordinstance = nullptr; + jmethodID breakiterator_settext = nullptr; + jmethodID breakiterator_next = nullptr; + + // java.lang.Integer + ScopedGlobalRef<jclass> integer_class; + jmethodID integer_parse_int = nullptr; + + // java.util.Calendar + ScopedGlobalRef<jclass> calendar_class; + jmethodID calendar_get_instance = nullptr; + jmethodID calendar_get_first_day_of_week = nullptr; + jmethodID calendar_get_time_in_millis = nullptr; + jmethodID calendar_set_time_in_millis = nullptr; + jmethodID calendar_add = nullptr; + jmethodID calendar_get = nullptr; + jmethodID calendar_set = nullptr; + jint calendar_zone_offset; + jint calendar_dst_offset; + jint calendar_year; + jint calendar_month; + jint calendar_day_of_year; + jint calendar_day_of_month; + jint calendar_day_of_week; + jint calendar_hour_of_day; + jint calendar_minute; + jint calendar_second; + jint calendar_millisecond; + jint calendar_sunday; + jint calendar_monday; + jint calendar_tuesday; + jint calendar_wednesday; + jint calendar_thursday; + jint calendar_friday; + jint calendar_saturday; + + // java.util.TimeZone + ScopedGlobalRef<jclass> timezone_class; + jmethodID timezone_get_timezone = nullptr; + + // java.net.URLEncoder + ScopedGlobalRef<jclass> urlencoder_class; + jmethodID urlencoder_encode = nullptr; + +#ifdef __ANDROID__ + // android.content.Context + ScopedGlobalRef<jclass> context_class; + jmethodID context_get_package_name = nullptr; + jmethodID context_get_system_service = nullptr; + + // android.net.Uri + ScopedGlobalRef<jclass> uri_class; + jmethodID uri_parse = nullptr; + jmethodID uri_get_scheme = nullptr; + + // android.os.UserManager + ScopedGlobalRef<jclass> usermanager_class; + jmethodID usermanager_get_user_restrictions = nullptr; + + // android.os.Bundle + ScopedGlobalRef<jclass> bundle_class; + jmethodID bundle_get_boolean = nullptr; +#endif + + // Helper to convert lib3 UnicodeText to Java strings. + ScopedLocalRef<jstring> ConvertToJavaString(const UnicodeText& text) const; + + private: + explicit JniCache(JavaVM* jvm); +}; + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_CACHE_H_ diff --git a/util/java/scoped_global_ref.h b/utils/java/scoped_global_ref.h index 3f8754d..de0608e 100644 --- a/util/java/scoped_global_ref.h +++ b/utils/java/scoped_global_ref.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,16 +14,16 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_UTIL_JAVA_SCOPED_GLOBAL_REF_H_ -#define LIBTEXTCLASSIFIER_UTIL_JAVA_SCOPED_GLOBAL_REF_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_JAVA_SCOPED_GLOBAL_REF_H_ +#define LIBTEXTCLASSIFIER_UTILS_JAVA_SCOPED_GLOBAL_REF_H_ #include <jni.h> #include <memory> #include <type_traits> -#include "util/base/logging.h" +#include "utils/base/logging.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { // A deleter to be used with std::unique_ptr to delete JNI global references. class GlobalRefDeleter { @@ -38,7 +38,7 @@ class GlobalRefDeleter { // Copy assignment to allow move semantics in ScopedGlobalRef. GlobalRefDeleter& operator=(const GlobalRefDeleter& rhs) { - TC_CHECK_EQ(jvm_, rhs.jvm_); + TC3_CHECK_EQ(jvm_, rhs.jvm_); return *this; } @@ -64,13 +64,15 @@ template <typename T> using ScopedGlobalRef = std::unique_ptr<typename std::remove_pointer<T>::type, GlobalRefDeleter>; -// A helper to create global references. +// A helper to create global references. Assumes the object has a local +// reference, which it deletes. template <typename T> ScopedGlobalRef<T> MakeGlobalRef(T object, JNIEnv* env, JavaVM* jvm) { - const jobject globalObject = env->NewGlobalRef(object); - return ScopedGlobalRef<T>(reinterpret_cast<T>(globalObject), jvm); + const jobject global_object = env->NewGlobalRef(object); + env->DeleteLocalRef(object); + return ScopedGlobalRef<T>(reinterpret_cast<T>(global_object), jvm); } -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_UTIL_JAVA_SCOPED_GLOBAL_REF_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_JAVA_SCOPED_GLOBAL_REF_H_ diff --git a/util/java/scoped_local_ref.h b/utils/java/scoped_local_ref.h index 8476767..f439c45 100644 --- a/util/java/scoped_local_ref.h +++ b/utils/java/scoped_local_ref.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,16 +14,16 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_UTIL_JAVA_SCOPED_LOCAL_REF_H_ -#define LIBTEXTCLASSIFIER_UTIL_JAVA_SCOPED_LOCAL_REF_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_JAVA_SCOPED_LOCAL_REF_H_ +#define LIBTEXTCLASSIFIER_UTILS_JAVA_SCOPED_LOCAL_REF_H_ #include <jni.h> #include <memory> #include <type_traits> -#include "util/base/logging.h" +#include "utils/base/logging.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { // A deleter to be used with std::unique_ptr to delete JNI local references. class LocalRefDeleter { @@ -40,7 +40,7 @@ class LocalRefDeleter { LocalRefDeleter& operator=(const LocalRefDeleter& rhs) { // As the deleter and its state are thread-local, ensure the envs // are consistent but do nothing. - TC_CHECK_EQ(env_, rhs.env_); + TC3_CHECK_EQ(env_, rhs.env_); return *this; } @@ -66,6 +66,6 @@ template <typename T> using ScopedLocalRef = std::unique_ptr<typename std::remove_pointer<T>::type, LocalRefDeleter>; -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_UTIL_JAVA_SCOPED_LOCAL_REF_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_JAVA_SCOPED_LOCAL_REF_H_ diff --git a/util/java/string_utils.cc b/utils/java/string_utils.cc index ffd5b11..457a667 100644 --- a/util/java/string_utils.cc +++ b/utils/java/string_utils.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,11 +14,26 @@ * limitations under the License. */ -#include "util/java/string_utils.h" +#include "utils/java/string_utils.h" -#include "util/base/logging.h" +#include "utils/base/logging.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { + +bool JByteArrayToString(JNIEnv* env, const jbyteArray& array, + std::string* result) { + jbyte* const array_bytes = env->GetByteArrayElements(array, JNI_FALSE); + if (array_bytes == nullptr) { + return false; + } + + const int array_length = env->GetArrayLength(array); + *result = std::string(reinterpret_cast<char*>(array_bytes), array_length); + + env->ReleaseByteArrayElements(array, array_bytes, JNI_ABORT); + + return true; +} bool JStringToUtf8String(JNIEnv* env, const jstring& jstr, std::string* result) { @@ -29,7 +44,7 @@ bool JStringToUtf8String(JNIEnv* env, const jstring& jstr, jclass string_class = env->FindClass("java/lang/String"); if (!string_class) { - TC_LOG(ERROR) << "Can't find String class"; + TC3_LOG(ERROR) << "Can't find String class"; return false; } @@ -37,16 +52,13 @@ bool JStringToUtf8String(JNIEnv* env, const jstring& jstr, env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B"); jstring encoding = env->NewStringUTF("UTF-8"); + jbyteArray array = reinterpret_cast<jbyteArray>( env->CallObjectMethod(jstr, get_bytes_id, encoding)); - jbyte* const array_bytes = env->GetByteArrayElements(array, JNI_FALSE); - int length = env->GetArrayLength(array); - - *result = std::string(reinterpret_cast<char*>(array_bytes), length); + JByteArrayToString(env, array, result); // Release the array. - env->ReleaseByteArrayElements(array, array_bytes, JNI_ABORT); env->DeleteLocalRef(array); env->DeleteLocalRef(string_class); env->DeleteLocalRef(encoding); @@ -54,4 +66,10 @@ bool JStringToUtf8String(JNIEnv* env, const jstring& jstr, return true; } -} // namespace libtextclassifier2 +ScopedStringChars GetScopedStringChars(JNIEnv* env, jstring string, + jboolean* is_copy) { + return ScopedStringChars(env->GetStringUTFChars(string, is_copy), + StringCharsReleaser(env, string)); +} + +} // namespace libtextclassifier3 diff --git a/utils/java/string_utils.h b/utils/java/string_utils.h new file mode 100644 index 0000000..172a938 --- /dev/null +++ b/utils/java/string_utils.h @@ -0,0 +1,76 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_UTILS_JAVA_STRING_UTILS_H_ +#define LIBTEXTCLASSIFIER_UTILS_JAVA_STRING_UTILS_H_ + +#include <jni.h> +#include <memory> +#include <string> + +#include "utils/base/logging.h" + +namespace libtextclassifier3 { + +bool JByteArrayToString(JNIEnv* env, const jbyteArray& array, + std::string* result); +bool JStringToUtf8String(JNIEnv* env, const jstring& jstr, std::string* result); + +// A deleter to be used with std::unique_ptr to release Java string chars. +class StringCharsReleaser { + public: + StringCharsReleaser() : env_(nullptr) {} + + StringCharsReleaser(JNIEnv* env, jstring jstr) : env_(env), jstr_(jstr) {} + + StringCharsReleaser(const StringCharsReleaser& orig) = default; + + // Copy assignment to allow move semantics in StringCharsReleaser. + StringCharsReleaser& operator=(const StringCharsReleaser& rhs) { + // As the releaser and its state are thread-local, it's enough to only + // ensure the envs are consistent but do nothing. + TC3_CHECK_EQ(env_, rhs.env_); + return *this; + } + + // The delete operator. + void operator()(const char* chars) const { + if (env_ != nullptr) { + env_->ReleaseStringUTFChars(jstr_, chars); + } + } + + private: + // The env_ stashed to use for deletion. Thread-local, don't share! + JNIEnv* const env_; + + // The referenced jstring. + jstring jstr_; +}; + +// A smart pointer that releases string chars when it goes out of scope. +// of scope. +// Note that this class is not thread-safe since it caches JNIEnv in +// the deleter. Do not use the same jobject across different threads. +using ScopedStringChars = std::unique_ptr<const char, StringCharsReleaser>; + +// Returns a scoped pointer to the array of Unicode characters of a string. +ScopedStringChars GetScopedStringChars(JNIEnv* env, jstring string, + jboolean* is_copy = nullptr); + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_UTILS_JAVA_STRING_UTILS_H_ diff --git a/util/math/fastexp.cc b/utils/math/fastexp.cc index 4bf8592..b319eae 100644 --- a/util/math/fastexp.cc +++ b/utils/math/fastexp.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,9 +14,9 @@ * limitations under the License. */ -#include "util/math/fastexp.h" +#include "utils/math/fastexp.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { const int FastMathClass::kBits; const int FastMathClass::kMask1; @@ -45,4 +45,4 @@ const FastMathClass::Table FastMathClass::cache_ = { 7940441, 8029106, 8118253, 8207884, 8298001} }; -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/util/math/fastexp.h b/utils/math/fastexp.h index af7a08c..63e5d5d 100644 --- a/util/math/fastexp.h +++ b/utils/math/fastexp.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,18 +16,18 @@ // Fast approximation for exp. -#ifndef LIBTEXTCLASSIFIER_UTIL_MATH_FASTEXP_H_ -#define LIBTEXTCLASSIFIER_UTIL_MATH_FASTEXP_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_MATH_FASTEXP_H_ +#define LIBTEXTCLASSIFIER_UTILS_MATH_FASTEXP_H_ #include <cassert> #include <cmath> #include <limits> -#include "util/base/casts.h" -#include "util/base/integral_types.h" -#include "util/base/logging.h" +#include "utils/base/casts.h" +#include "utils/base/integral_types.h" +#include "utils/base/logging.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { class FastMathClass { private: @@ -42,7 +42,7 @@ class FastMathClass { public: float VeryFastExp2(float f) const { - TC_DCHECK_LE(fabs(f), 126); + TC3_DCHECK_LE(fabs(f), 126); const float g = f + (127 + (1 << (23 - kBits))); const int32 x = bit_cast<int32>(g); int32 ret = ((x & kMask2) << (23 - kBits)) @@ -63,6 +63,6 @@ extern FastMathClass FastMathInstance; inline float VeryFastExp2(float f) { return FastMathInstance.VeryFastExp2(f); } inline float VeryFastExp(float f) { return FastMathInstance.VeryFastExp(f); } -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_UTIL_MATH_FASTEXP_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_MATH_FASTEXP_H_ diff --git a/util/math/softmax.cc b/utils/math/softmax.cc index 986787f..c278625 100644 --- a/util/math/softmax.cc +++ b/utils/math/softmax.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,19 +14,19 @@ * limitations under the License. */ -#include "util/math/softmax.h" +#include "utils/math/softmax.h" #include <limits> -#include "util/base/logging.h" -#include "util/math/fastexp.h" +#include "utils/base/logging.h" +#include "utils/math/fastexp.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { float ComputeSoftmaxProbability(const std::vector<float> &scores, int label) { if ((label < 0) || (label >= scores.size())) { - TC_LOG(ERROR) << "label " << label << " outside range " - << "[0, " << scores.size() << ")"; + TC3_LOG(ERROR) << "label " << label << " outside range " + << "[0, " << scores.size() << ")"; return 0.0f; } @@ -101,4 +101,4 @@ std::vector<float> ComputeSoftmax(const float *scores, int scores_size) { return softmax; } -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/util/math/softmax.h b/utils/math/softmax.h index f70a9ab..8ac198b 100644 --- a/util/math/softmax.h +++ b/utils/math/softmax.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,12 +14,12 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_UTIL_MATH_SOFTMAX_H_ -#define LIBTEXTCLASSIFIER_UTIL_MATH_SOFTMAX_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_MATH_SOFTMAX_H_ +#define LIBTEXTCLASSIFIER_UTILS_MATH_SOFTMAX_H_ #include <vector> -namespace libtextclassifier2 { +namespace libtextclassifier3 { // Computes probability of a softmax label. Parameter "scores" is the vector of // softmax logits. Returns 0.0f if "label" is outside the range [0, @@ -33,6 +33,6 @@ std::vector<float> ComputeSoftmax(const std::vector<float> &scores); // Same as above but operates on an array of floats. std::vector<float> ComputeSoftmax(const float *scores, int scores_size); -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_UTIL_MATH_SOFTMAX_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_MATH_SOFTMAX_H_ diff --git a/util/memory/mmap.cc b/utils/memory/mmap.cc index 6b0bdf2..a251024 100644 --- a/util/memory/mmap.cc +++ b/utils/memory/mmap.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "util/memory/mmap.h" +#include "utils/memory/mmap.h" #include <errno.h> #include <fcntl.h> @@ -24,10 +24,10 @@ #include <sys/stat.h> #include <unistd.h> -#include "util/base/logging.h" -#include "util/base/macros.h" +#include "utils/base/logging.h" +#include "utils/base/macros.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace { inline std::string GetLastSystemError() { return std::string(strerror(errno)); } @@ -41,14 +41,14 @@ class FileCloser { int result = close(fd_); if (result != 0) { const std::string last_error = GetLastSystemError(); - TC_LOG(ERROR) << "Error closing file descriptor: " << last_error; + TC3_LOG(ERROR) << "Error closing file descriptor: " << last_error; } } private: const int fd_; - TC_DISALLOW_COPY_AND_ASSIGN(FileCloser); + TC3_DISALLOW_COPY_AND_ASSIGN(FileCloser); }; } // namespace @@ -58,7 +58,7 @@ MmapHandle MmapFile(const std::string &filename) { if (fd < 0) { const std::string last_error = GetLastSystemError(); - TC_LOG(ERROR) << "Error opening " << filename << ": " << last_error; + TC3_LOG(ERROR) << "Error opening " << filename << ": " << last_error; return GetErrorMmapHandle(); } @@ -75,7 +75,7 @@ MmapHandle MmapFile(int fd) { struct stat sb; if (fstat(fd, &sb) != 0) { const std::string last_error = GetLastSystemError(); - TC_LOG(ERROR) << "Unable to stat fd: " << last_error; + TC3_LOG(ERROR) << "Unable to stat fd: " << last_error; return GetErrorMmapHandle(); } @@ -111,7 +111,7 @@ MmapHandle MmapFile(int fd, int64 segment_offset, int64 segment_size) { aligned_offset); if (mmap_addr == MAP_FAILED) { const std::string last_error = GetLastSystemError(); - TC_LOG(ERROR) << "Error while mmapping: " << last_error; + TC3_LOG(ERROR) << "Error while mmapping: " << last_error; return GetErrorMmapHandle(); } @@ -126,10 +126,10 @@ bool Unmap(MmapHandle mmap_handle) { } if (munmap(mmap_handle.unmap_addr(), mmap_handle.num_bytes()) != 0) { const std::string last_error = GetLastSystemError(); - TC_LOG(ERROR) << "Error during Unmap / munmap: " << last_error; + TC3_LOG(ERROR) << "Error during Unmap / munmap: " << last_error; return false; } return true; } -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/util/memory/mmap.h b/utils/memory/mmap.h index 7d28b64..acce7db 100644 --- a/util/memory/mmap.h +++ b/utils/memory/mmap.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,17 +14,17 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_UTIL_MEMORY_MMAP_H_ -#define LIBTEXTCLASSIFIER_UTIL_MEMORY_MMAP_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_MEMORY_MMAP_H_ +#define LIBTEXTCLASSIFIER_UTILS_MEMORY_MMAP_H_ #include <stddef.h> #include <string> -#include "util/base/integral_types.h" -#include "util/strings/stringpiece.h" +#include "utils/base/integral_types.h" +#include "utils/strings/stringpiece.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { // Handle for a memory area where a file has been mmapped. // @@ -86,7 +86,7 @@ class MmapHandle { // Sample usage: // // MmapHandle mmap_handle = MmapFile(filename); -// TC_DCHECK(mmap_handle.ok()) << "Unable to mmap " << filename; +// TC3_DCHECK(mmap_handle.ok()) << "Unable to mmap " << filename; // // ... use data from addresses // ... [mmap_handle.start, mmap_handle.start + mmap_handle.num_bytes) @@ -136,6 +136,6 @@ class ScopedMmap { MmapHandle handle_; }; -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_UTIL_MEMORY_MMAP_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_MEMORY_MMAP_H_ diff --git a/utils/optional.h b/utils/optional.h new file mode 100644 index 0000000..15d2619 --- /dev/null +++ b/utils/optional.h @@ -0,0 +1,81 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_UTILS_OPTIONAL_H_ +#define LIBTEXTCLASSIFIER_UTILS_OPTIONAL_H_ + +#include "utils/base/logging.h" + +namespace libtextclassifier3 { + +// Holds an optional value. +template <class T> +class Optional { + public: + Optional() : init_(false) {} + + Optional(const Optional& other) { + init_ = other.init_; + if (other.init_) { + value_ = other.value_; + } + } + + explicit Optional(T value) : init_(true), value_(value) {} + + Optional& operator=(Optional&& other) { + init_ = other.init_; + if (other.init_) { + value_ = std::move(other); + } + return *this; + } + + Optional& operator=(T&& other) { + init_ = true; + value_ = std::move(other); + return *this; + } + + constexpr bool has_value() const { return init_; } + + T const* operator->() const { + TC3_CHECK(init_) << "Bad optional access."; + return value_; + } + + T const& value() const& { + TC3_CHECK(init_) << "Bad optional access."; + return value_; + } + + T const& value_or(T&& default_value) { + return (init_ ? value_ : default_value); + } + + void set(const T& value) { + init_ = true; + value_ = value; + } + + private: + bool init_; + T value_; +}; + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_UTILS_OPTIONAL_H_ diff --git a/util/strings/numbers.cc b/utils/strings/numbers.cc index a89c0ef..3028c69 100644 --- a/util/strings/numbers.cc +++ b/utils/strings/numbers.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "util/strings/numbers.h" +#include "utils/strings/numbers.h" #ifdef COMPILER_MSVC #include <sstream> @@ -22,7 +22,7 @@ #include <stdlib.h> -namespace libtextclassifier2 { +namespace libtextclassifier3 { bool ParseInt32(const char *c_str, int32 *value) { char *temp; @@ -72,4 +72,4 @@ std::string IntToString(int64 input) { } #endif // COMPILER_MSVC -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/util/strings/numbers.h b/utils/strings/numbers.h index a2c8c6e..ae48068 100644 --- a/util/strings/numbers.h +++ b/utils/strings/numbers.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,14 +14,14 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_UTIL_STRINGS_NUMBERS_H_ -#define LIBTEXTCLASSIFIER_UTIL_STRINGS_NUMBERS_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_STRINGS_NUMBERS_H_ +#define LIBTEXTCLASSIFIER_UTILS_STRINGS_NUMBERS_H_ #include <string> -#include "util/base/integral_types.h" +#include "utils/base/integral_types.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { // Parses an int32 from a C-style string. // @@ -47,6 +47,6 @@ bool ParseDouble(const char *c_str, double *value); // int types. std::string IntToString(int64 input); -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_UTIL_STRINGS_NUMBERS_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_STRINGS_NUMBERS_H_ diff --git a/util/strings/numbers_test.cc b/utils/strings/numbers_test.cc index 1fdd78a..57e812f 100644 --- a/util/strings/numbers_test.cc +++ b/utils/strings/numbers_test.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,12 +14,12 @@ * limitations under the License. */ -#include "util/strings/numbers.h" +#include "utils/strings/numbers.h" -#include "util/base/integral_types.h" +#include "utils/base/integral_types.h" #include "gtest/gtest.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace { void TestParseInt32(const char *c_str, bool expected_parsing_success, @@ -100,4 +100,4 @@ TEST(ParseDoubleTest, ErrorCases) { TestParseDouble("23.5a", false); } } // namespace -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/util/strings/split.cc b/utils/strings/split.cc index 2c610ba..584760a 100644 --- a/util/strings/split.cc +++ b/utils/strings/split.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,9 +14,9 @@ * limitations under the License. */ -#include "util/strings/split.h" +#include "utils/strings/split.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace strings { std::vector<StringPiece> Split(const StringPiece &text, char delim) { @@ -35,4 +35,4 @@ std::vector<StringPiece> Split(const StringPiece &text, char delim) { } } // namespace strings -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/util/strings/split.h b/utils/strings/split.h index 96f73fe..b565258 100644 --- a/util/strings/split.h +++ b/utils/strings/split.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,20 +14,20 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_UTIL_STRINGS_SPLIT_H_ -#define LIBTEXTCLASSIFIER_UTIL_STRINGS_SPLIT_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_STRINGS_SPLIT_H_ +#define LIBTEXTCLASSIFIER_UTILS_STRINGS_SPLIT_H_ #include <string> #include <vector> -#include "util/strings/stringpiece.h" +#include "utils/strings/stringpiece.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace strings { std::vector<StringPiece> Split(const StringPiece &text, char delim); } // namespace strings -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_UTIL_STRINGS_SPLIT_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_STRINGS_SPLIT_H_ diff --git a/util/strings/stringpiece.h b/utils/strings/stringpiece.h index cd07848..3ec414f 100644 --- a/util/strings/stringpiece.h +++ b/utils/strings/stringpiece.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,14 +14,15 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_UTIL_STRINGS_STRINGPIECE_H_ -#define LIBTEXTCLASSIFIER_UTIL_STRINGS_STRINGPIECE_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_STRINGS_STRINGPIECE_H_ +#define LIBTEXTCLASSIFIER_UTILS_STRINGS_STRINGPIECE_H_ #include <stddef.h> - #include <string> -namespace libtextclassifier2 { +#include "utils/base/logging.h" + +namespace libtextclassifier3 { // Read-only "view" of a piece of data. Does not own the underlying data. class StringPiece { @@ -31,8 +32,7 @@ class StringPiece { StringPiece(const char *str) // NOLINT(runtime/explicit) : start_(str), size_(strlen(str)) {} - StringPiece(const char *start, size_t size) - : start_(start), size_(size) {} + StringPiece(const char *start, size_t size) : start_(start), size_(size) {} // Intentionally no "explicit" keyword: in function calls, we want strings to // be converted to StringPiece implicitly. @@ -54,8 +54,28 @@ class StringPiece { bool empty() const { return size_ == 0; } // Returns a std::string containing a copy of the underlying data. - std::string ToString() const { - return std::string(data(), size()); + std::string ToString() const { return std::string(data(), size()); } + + // Returns whether string ends with a given suffix. + bool EndsWith(StringPiece suffix) const { + return suffix.empty() || (size_ >= suffix.size() && + memcmp(start_ + (size_ - suffix.size()), + suffix.data(), suffix.size()) == 0); + } + + // Returns whether the string begins with a given prefix. + bool StartsWith(StringPiece prefix) const { + return prefix.empty() || + (size_ >= prefix.size() && + memcmp(start_, prefix.data(), prefix.size()) == 0); + } + + // Removes the first `n` characters from the string piece. Note that the + // underlying string is not changed, only the view. + void RemovePrefix(int n) { + TC3_CHECK_LE(n, size_); + start_ += n; + size_ -= n; } private: @@ -63,6 +83,22 @@ class StringPiece { size_t size_; }; -} // namespace libtextclassifier2 +inline bool EndsWith(StringPiece text, StringPiece suffix) { + return text.EndsWith(suffix); +} + +inline bool StartsWith(StringPiece text, StringPiece prefix) { + return text.StartsWith(prefix); +} + +inline bool ConsumePrefix(StringPiece *text, StringPiece prefix) { + if (!text->StartsWith(prefix)) { + return false; + } + text->RemovePrefix(prefix.size()); + return true; +} + +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_UTIL_STRINGS_STRINGPIECE_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_STRINGS_STRINGPIECE_H_ diff --git a/utils/strings/stringpiece_test.cc b/utils/strings/stringpiece_test.cc new file mode 100644 index 0000000..713a7f9 --- /dev/null +++ b/utils/strings/stringpiece_test.cc @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "utils/strings/stringpiece.h" + +namespace libtextclassifier3 { +namespace { + +TEST(StringPieceTest, EndsWith) { + EXPECT_TRUE(EndsWith("hello there!", "there!")); + EXPECT_TRUE(EndsWith("hello there!", "!")); + EXPECT_FALSE(EndsWith("hello there!", "there")); + EXPECT_FALSE(EndsWith("hello there!", " hello there!")); + EXPECT_TRUE(EndsWith("hello there!", "")); + EXPECT_FALSE(EndsWith("", "hello there!")); +} + +TEST(StringPieceTest, StartsWith) { + EXPECT_TRUE(StartsWith("hello there!", "hello")); + EXPECT_TRUE(StartsWith("hello there!", "hello ")); + EXPECT_FALSE(StartsWith("hello there!", "there!")); + EXPECT_FALSE(StartsWith("hello there!", " hello there! ")); + EXPECT_TRUE(StartsWith("hello there!", "")); + EXPECT_FALSE(StartsWith("", "hello there!")); +} + +TEST(StringPieceTest, ConsumePrefix) { + StringPiece str("hello there!"); + EXPECT_TRUE(ConsumePrefix(&str, "hello ")); + EXPECT_EQ(str.ToString(), "there!"); + EXPECT_TRUE(ConsumePrefix(&str, "there")); + EXPECT_EQ(str.ToString(), "!"); + EXPECT_FALSE(ConsumePrefix(&str, "!!")); + EXPECT_TRUE(ConsumePrefix(&str, "")); + EXPECT_TRUE(ConsumePrefix(&str, "!")); + EXPECT_EQ(str.ToString(), ""); + EXPECT_TRUE(ConsumePrefix(&str, "")); + EXPECT_FALSE(ConsumePrefix(&str, "!")); +} + +} // namespace +} // namespace libtextclassifier3 diff --git a/utils/strings/utf8.cc b/utils/strings/utf8.cc new file mode 100644 index 0000000..faaf854 --- /dev/null +++ b/utils/strings/utf8.cc @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/strings/utf8.h" + +namespace libtextclassifier3 { +bool IsValidUTF8(const char *src, int size) { + for (int i = 0; i < size;) { + const int char_length = ValidUTF8CharLength(src + i, size - i); + if (char_length <= 0) { + return false; + } + i += char_length; + } + return true; +} + +int ValidUTF8CharLength(const char *src, int size) { + // Unexpected trail byte. + if (IsTrailByte(src[0])) { + return -1; + } + + const int num_codepoint_bytes = GetNumBytesForUTF8Char(&src[0]); + if (num_codepoint_bytes <= 0 || num_codepoint_bytes > size) { + return -1; + } + + // Check that remaining bytes in the codepoint are trailing bytes. + for (int k = 1; k < num_codepoint_bytes; k++) { + if (!IsTrailByte(src[k])) { + return -1; + } + } + + return num_codepoint_bytes; +} + +} // namespace libtextclassifier3 diff --git a/util/strings/utf8.h b/utils/strings/utf8.h index 1e75da2..6c4c8a0 100644 --- a/util/strings/utf8.h +++ b/utils/strings/utf8.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,10 +14,10 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_UTIL_STRINGS_UTF8_H_ -#define LIBTEXTCLASSIFIER_UTIL_STRINGS_UTF8_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_STRINGS_UTF8_H_ +#define LIBTEXTCLASSIFIER_UTILS_STRINGS_UTF8_H_ -namespace libtextclassifier2 { +namespace libtextclassifier3 { // Returns the length (number of bytes) of the Unicode code point starting at // src, based on inspecting just that one byte. Preconditions: src != NULL, @@ -47,6 +47,10 @@ static inline bool IsTrailByte(char x) { // Returns true iff src points to a well-formed UTF-8 string. bool IsValidUTF8(const char *src, int size); -} // namespace libtextclassifier2 +// Returns byte length of the first valid codepoint in the string, otherwise -1 +// if pointing to an ill-formed UTF-8 character. +int ValidUTF8CharLength(const char *src, int size); -#endif // LIBTEXTCLASSIFIER_UTIL_STRINGS_UTF8_H_ +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_UTILS_STRINGS_UTF8_H_ diff --git a/utils/strings/utf8_test.cc b/utils/strings/utf8_test.cc new file mode 100644 index 0000000..a71d4f2 --- /dev/null +++ b/utils/strings/utf8_test.cc @@ -0,0 +1,59 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "utils/strings/utf8.h" + +namespace libtextclassifier3 { +namespace { + +TEST(Utf8Test, GetNumBytesForUTF8Char) { + EXPECT_EQ(GetNumBytesForUTF8Char("\x00"), 0); + EXPECT_EQ(GetNumBytesForUTF8Char("h"), 1); + EXPECT_EQ(GetNumBytesForUTF8Char("😋"), 4); + EXPECT_EQ(GetNumBytesForUTF8Char("㍿"), 3); +} + +TEST(Utf8Test, IsValidUTF8) { + EXPECT_TRUE(IsValidUTF8("1234😋hello", 13)); + EXPECT_TRUE(IsValidUTF8("\u304A\u00B0\u106B", 8)); + EXPECT_TRUE(IsValidUTF8("this is a test😋😋😋", 26)); + EXPECT_TRUE(IsValidUTF8("\xf0\x9f\x98\x8b", 4)); + // Too short (string is too short). + EXPECT_FALSE(IsValidUTF8("\xf0\x9f", 2)); + // Too long (too many trailing bytes). + EXPECT_FALSE(IsValidUTF8("\xf0\x9f\x98\x8b\x8b", 5)); + // Too short (too few trailing bytes). + EXPECT_FALSE(IsValidUTF8("\xf0\x9f\x98\x61\x61", 5)); +} + +TEST(Utf8Test, ValidUTF8CharLength) { + EXPECT_EQ(ValidUTF8CharLength("1234😋hello", 13), 1); + EXPECT_EQ(ValidUTF8CharLength("\u304A\u00B0\u106B", 8), 3); + EXPECT_EQ(ValidUTF8CharLength("this is a test😋😋😋", 26), 1); + EXPECT_EQ(ValidUTF8CharLength("\xf0\x9f\x98\x8b", 4), 4); + // Too short (string is too short). + EXPECT_EQ(ValidUTF8CharLength("\xf0\x9f", 2), -1); + // Too long (too many trailing bytes). First character is valid. + EXPECT_EQ(ValidUTF8CharLength("\xf0\x9f\x98\x8b\x8b", 5), 4); + // Too short (too few trailing bytes). + EXPECT_EQ(ValidUTF8CharLength("\xf0\x9f\x98\x61\x61", 5), -1); +} + +} // namespace +} // namespace libtextclassifier3 diff --git a/tensor-view.cc b/utils/tensor-view.cc index 4acadc5..0ca0b7f 100644 --- a/tensor-view.cc +++ b/utils/tensor-view.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,9 +14,9 @@ * limitations under the License. */ -#include "tensor-view.h" +#include "utils/tensor-view.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace internal { int NumberOfElements(const std::vector<int>& shape) { @@ -28,4 +28,4 @@ int NumberOfElements(const std::vector<int>& shape) { } } // namespace internal -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/tensor-view.h b/utils/tensor-view.h index 00ab08c..a46ebd1 100644 --- a/tensor-view.h +++ b/utils/tensor-view.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,13 +14,13 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_TENSOR_VIEW_H_ -#define LIBTEXTCLASSIFIER_TENSOR_VIEW_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_TENSOR_VIEW_H_ +#define LIBTEXTCLASSIFIER_UTILS_TENSOR_VIEW_H_ #include <algorithm> #include <vector> -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace internal { // Computes the number of elements in a tensor of given shape. int NumberOfElements(const std::vector<int>& shape); @@ -67,6 +67,6 @@ class TensorView { const int size_; }; -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_TENSOR_VIEW_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_TENSOR_VIEW_H_ diff --git a/tensor-view_test.cc b/utils/tensor-view_test.cc index d50fac7..9467264 100644 --- a/tensor-view_test.cc +++ b/utils/tensor-view_test.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,12 +14,12 @@ * limitations under the License. */ -#include "tensor-view.h" +#include "utils/tensor-view.h" #include "gmock/gmock.h" #include "gtest/gtest.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace { TEST(TensorViewTest, TestSize) { @@ -49,4 +49,4 @@ TEST(TensorViewTest, TestSize) { } } // namespace -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/utils/testing/logging_event_listener.h b/utils/testing/logging_event_listener.h new file mode 100644 index 0000000..2663a9c --- /dev/null +++ b/utils/testing/logging_event_listener.h @@ -0,0 +1,62 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_UTILS_TESTING_LOGGING_EVENT_LISTENER_H_ +#define LIBTEXTCLASSIFIER_UTILS_TESTING_LOGGING_EVENT_LISTENER_H_ + +#include "gtest/gtest.h" + +namespace libtextclassifier3 { + +// TestEventListener that writes test results to the log so that they will be +// visible in the logcat output in Sponge. +// The formatting of the output is patterend after the output produced by the +// standard PrettyUnitTestResultPrinter. +class LoggingEventListener : public ::testing::TestEventListener { + public: + void OnTestProgramStart(const testing::UnitTest& unit_test) override; + + void OnTestIterationStart(const testing::UnitTest& unit_test, + int iteration) override; + + void OnEnvironmentsSetUpStart(const testing::UnitTest& unit_test) override; + + void OnEnvironmentsSetUpEnd(const testing::UnitTest& unit_test) override; + + void OnTestCaseStart(const testing::TestCase& test_case) override; + + void OnTestStart(const testing::TestInfo& test_info) override; + + void OnTestPartResult( + const testing::TestPartResult& test_part_result) override; + + void OnTestEnd(const testing::TestInfo& test_info) override; + + void OnTestCaseEnd(const testing::TestCase& test_case) override; + + void OnEnvironmentsTearDownStart(const testing::UnitTest& unit_test) override; + + void OnEnvironmentsTearDownEnd(const testing::UnitTest& unit_test) override; + + void OnTestIterationEnd(const testing::UnitTest& unit_test, + int iteration) override; + + void OnTestProgramEnd(const testing::UnitTest& unit_test) override; +}; + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_UTILS_TESTING_LOGGING_EVENT_LISTENER_H_ diff --git a/utils/tflite-model-executor.cc b/utils/tflite-model-executor.cc new file mode 100644 index 0000000..1c850ac --- /dev/null +++ b/utils/tflite-model-executor.cc @@ -0,0 +1,129 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/tflite-model-executor.h" + +#include "utils/base/logging.h" +#include "tensorflow/contrib/lite/kernels/register.h" + +// Forward declaration of custom TensorFlow Lite ops for registration. +namespace tflite { +namespace ops { +namespace builtin { +TfLiteRegistration* Register_DIV(); +TfLiteRegistration* Register_FULLY_CONNECTED(); +TfLiteRegistration* Register_SOFTMAX(); // TODO(smillius): remove. +} // namespace builtin +} // namespace ops +} // namespace tflite + +void RegisterSelectedOps(::tflite::MutableOpResolver* resolver) { + resolver->AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED, + ::tflite::ops::builtin::Register_FULLY_CONNECTED()); +} + +namespace libtextclassifier3 { + +inline std::unique_ptr<tflite::OpResolver> BuildOpResolver() { +#ifdef TC3_USE_SELECTIVE_REGISTRATION + std::unique_ptr<tflite::MutableOpResolver> resolver( + new tflite::MutableOpResolver); + resolver->AddBuiltin(tflite::BuiltinOperator_DIV, + tflite::ops::builtin::Register_DIV()); + resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED, + tflite::ops::builtin::Register_FULLY_CONNECTED()); + resolver->AddBuiltin(tflite::BuiltinOperator_SOFTMAX, + tflite::ops::builtin::Register_SOFTMAX()); + RegisterSelectedOps(resolver.get()); +#else + std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver( + new tflite::ops::builtin::BuiltinOpResolver); +#endif + return std::unique_ptr<tflite::OpResolver>(std::move(resolver)); +} + +std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec( + const tflite::Model* model_spec) { + std::unique_ptr<const tflite::FlatBufferModel> model( + tflite::FlatBufferModel::BuildFromModel(model_spec)); + if (!model || !model->initialized()) { + TC3_LOG(ERROR) << "Could not build TFLite model from a model spec."; + return nullptr; + } + return model; +} + +std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer( + const flatbuffers::Vector<uint8_t>* model_spec_buffer) { + const tflite::Model* model = + flatbuffers::GetRoot<tflite::Model>(model_spec_buffer->data()); + flatbuffers::Verifier verifier(model_spec_buffer->data(), + model_spec_buffer->Length()); + if (!model->Verify(verifier)) { + return nullptr; + } + return TfLiteModelFromModelSpec(model); +} + +TfLiteModelExecutor::TfLiteModelExecutor( + std::unique_ptr<const tflite::FlatBufferModel> model) + : model_(std::move(model)), resolver_(BuildOpResolver()) {} + +std::unique_ptr<tflite::Interpreter> TfLiteModelExecutor::CreateInterpreter() + const { + std::unique_ptr<tflite::Interpreter> interpreter; + tflite::InterpreterBuilder(*model_, *resolver_)(&interpreter); + return interpreter; +} + +template <> +void TfLiteModelExecutor::SetInput(const int input_index, + const std::vector<std::string>& input_data, + tflite::Interpreter* interpreter) const { + tflite::DynamicBuffer buf; + for (const std::string& s : input_data) { + buf.AddString(s.data(), s.length()); + } + // TODO(b/120230709): Use WriteToTensorAsVector() instead, once available in + // AOSP. + buf.WriteToTensor(interpreter->tensor(interpreter->inputs()[input_index])); +} + +template <> +std::vector<tflite::StringRef> TfLiteModelExecutor::Output( + const int output_index, tflite::Interpreter* interpreter) const { + const TfLiteTensor* output_tensor = + interpreter->tensor(interpreter->outputs()[output_index]); + const int num_strings = tflite::GetStringCount(output_tensor); + std::vector<tflite::StringRef> output(num_strings); + for (int i = 0; i < num_strings; i++) { + output[i] = tflite::GetString(output_tensor, i); + } + return output; +} + +template <> +std::vector<std::string> TfLiteModelExecutor::Output( + const int output_index, tflite::Interpreter* interpreter) const { + std::vector<std::string> output; + for (const tflite::StringRef& s : + Output<tflite::StringRef>(output_index, interpreter)) { + output.push_back(std::string(s.str, s.len)); + } + return output; +} + +} // namespace libtextclassifier3 diff --git a/utils/tflite-model-executor.h b/utils/tflite-model-executor.h new file mode 100644 index 0000000..fd00924 --- /dev/null +++ b/utils/tflite-model-executor.h @@ -0,0 +1,123 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Contains classes that can execute different models/parts of a model. + +#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_ +#define LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_ + +#include <memory> + +#include "utils/base/logging.h" +#include "utils/tensor-view.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/op_resolver.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace libtextclassifier3 { + +std::unique_ptr<tflite::OpResolver> BuildOpResolver(); +std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec( + const tflite::Model*); +std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer( + const flatbuffers::Vector<uint8_t>*); + +// Executor for the text selection prediction and classification models. +class TfLiteModelExecutor { + public: + static std::unique_ptr<TfLiteModelExecutor> FromModelSpec( + const tflite::Model* model_spec) { + auto model = TfLiteModelFromModelSpec(model_spec); + if (!model) { + return nullptr; + } + return std::unique_ptr<TfLiteModelExecutor>( + new TfLiteModelExecutor(std::move(model))); + } + + static std::unique_ptr<TfLiteModelExecutor> FromBuffer( + const flatbuffers::Vector<uint8_t>* model_spec_buffer) { + auto model = TfLiteModelFromBuffer(model_spec_buffer); + if (!model) { + return nullptr; + } + return std::unique_ptr<TfLiteModelExecutor>( + new TfLiteModelExecutor(std::move(model))); + } + + // Creates an Interpreter for the model that serves as a scratch-pad for the + // inference. The Interpreter is NOT thread-safe. + std::unique_ptr<tflite::Interpreter> CreateInterpreter() const; + + template <typename T> + void SetInput(const int input_index, const TensorView<T>& input_data, + tflite::Interpreter* interpreter) const { + input_data.copy_to(interpreter->typed_input_tensor<T>(input_index), + input_data.size()); + } + + template <typename T> + void SetInput(const int input_index, const std::vector<T>& input_data, + tflite::Interpreter* interpreter) const { + std::copy(input_data.begin(), input_data.end(), + interpreter->typed_input_tensor<T>(input_index)); + } + + template <typename T> + TensorView<T> OutputView(const int output_index, + tflite::Interpreter* interpreter) const { + TfLiteTensor* output_tensor = + interpreter->tensor(interpreter->outputs()[output_index]); + return TensorView<T>(interpreter->typed_output_tensor<T>(output_index), + std::vector<int>(output_tensor->dims->data, + output_tensor->dims->data + + output_tensor->dims->size)); + } + + template <typename T> + std::vector<T> Output(const int output_index, + tflite::Interpreter* interpreter) const { + TensorView<T> output_view = OutputView<T>(output_index, interpreter); + return std::vector<T>(output_view.data(), + output_view.data() + output_view.size()); + } + + protected: + explicit TfLiteModelExecutor( + std::unique_ptr<const tflite::FlatBufferModel> model); + + std::unique_ptr<const tflite::FlatBufferModel> model_; + std::unique_ptr<tflite::OpResolver> resolver_; +}; + +template <> +void TfLiteModelExecutor::SetInput(const int input_index, + const std::vector<std::string>& input_data, + tflite::Interpreter* interpreter) const; + +template <> +std::vector<tflite::StringRef> TfLiteModelExecutor::Output( + const int output_index, tflite::Interpreter* interpreter) const; + +template <> +std::vector<std::string> TfLiteModelExecutor::Output( + const int output_index, tflite::Interpreter* interpreter) const; + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_ diff --git a/util/utf8/unicodetext.cc b/utils/utf8/unicodetext.cc index 2ef79e9..81492d8 100644 --- a/util/utf8/unicodetext.cc +++ b/utils/utf8/unicodetext.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,15 +14,15 @@ * limitations under the License. */ -#include "util/utf8/unicodetext.h" +#include "utils/utf8/unicodetext.h" #include <string.h> #include <algorithm> -#include "util/strings/utf8.h" +#include "utils/strings/utf8.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { // *************** Data representation ********** // Note: the copy constructor is undefined. @@ -176,7 +176,7 @@ int runetochar(const char32 rune, char* dest) { } // namespace -UnicodeText& UnicodeText::AppendCodepoint(char32 ch) { +UnicodeText& UnicodeText::push_back(char32 ch) { char str[4]; int char_len = runetochar(ch, str); repr_.append(str, char_len); @@ -296,4 +296,4 @@ UnicodeText UTF8ToUnicodeText(const std::string& str) { return UTF8ToUnicodeText(str, /*do_copy=*/true); } -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/util/utf8/unicodetext.h b/utils/utf8/unicodetext.h index ec08f53..eb206b8 100644 --- a/util/utf8/unicodetext.h +++ b/utils/utf8/unicodetext.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,16 +14,16 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_UTIL_UTF8_UNICODETEXT_H_ -#define LIBTEXTCLASSIFIER_UTIL_UTF8_UNICODETEXT_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_UTF8_UNICODETEXT_H_ +#define LIBTEXTCLASSIFIER_UTILS_UTF8_UNICODETEXT_H_ #include <iterator> #include <string> #include <utility> -#include "util/base/integral_types.h" +#include "utils/base/integral_types.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { // ***************************** UnicodeText ************************** // @@ -168,7 +168,7 @@ class UnicodeText { // Calling this may invalidate pointers to underlying data. UnicodeText& AppendUTF8(const char* utf8, int len); - UnicodeText& AppendCodepoint(char32 ch); + UnicodeText& push_back(char32 ch); void clear(); std::string ToUTF8String() const; @@ -219,6 +219,6 @@ UnicodeText UTF8ToUnicodeText(const char* utf8_buf, bool do_copy); UnicodeText UTF8ToUnicodeText(const std::string& str, bool do_copy); UnicodeText UTF8ToUnicodeText(const std::string& str); -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_UTIL_UTF8_UNICODETEXT_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_UTF8_UNICODETEXT_H_ diff --git a/util/utf8/unicodetext_test.cc b/utils/utf8/unicodetext_test.cc index 9ec7621..7ebb415 100644 --- a/util/utf8/unicodetext_test.cc +++ b/utils/utf8/unicodetext_test.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,21 +14,21 @@ * limitations under the License. */ -#include "util/utf8/unicodetext.h" +#include "utils/utf8/unicodetext.h" #include "gtest/gtest.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace { class UnicodeTextTest : public testing::Test { protected: UnicodeTextTest() : empty_text_() { - text_.AppendCodepoint(0x1C0); - text_.AppendCodepoint(0x4E8C); - text_.AppendCodepoint(0xD7DB); - text_.AppendCodepoint(0x34); - text_.AppendCodepoint(0x1D11E); + text_.push_back(0x1C0); + text_.push_back(0x4E8C); + text_.push_back(0xD7DB); + text_.push_back(0x34); + text_.push_back(0x1D11E); } UnicodeText empty_text_; @@ -186,4 +186,4 @@ TEST_F(OperatorTest, Empty) { } } // namespace -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/utils/utf8/unilib-javaicu.cc b/utils/utf8/unilib-javaicu.cc new file mode 100644 index 0000000..dab4f70 --- /dev/null +++ b/utils/utf8/unilib-javaicu.cc @@ -0,0 +1,694 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/utf8/unilib-javaicu.h" + +#include <algorithm> +#include <cassert> +#include <cctype> +#include <map> + +#include "utils/java/string_utils.h" + +namespace libtextclassifier3 { +namespace { + +// ----------------------------------------------------------------------------- +// Native implementations. +// ----------------------------------------------------------------------------- + +#define ARRAYSIZE(a) sizeof(a) / sizeof(*a) + +// Derived from http://www.unicode.org/Public/UNIDATA/UnicodeData.txt +// grep -E "Ps" UnicodeData.txt | \ +// sed -rne "s/^([0-9A-Z]{4});.*(PAREN|BRACKET|BRAKCET|BRACE).*/0x\1, /p" +// IMPORTANT: entries with the same offsets in kOpeningBrackets and +// kClosingBrackets must be counterparts. +constexpr char32 kOpeningBrackets[] = { + 0x0028, 0x005B, 0x007B, 0x0F3C, 0x2045, 0x207D, 0x208D, 0x2329, 0x2768, + 0x276A, 0x276C, 0x2770, 0x2772, 0x2774, 0x27E6, 0x27E8, 0x27EA, 0x27EC, + 0x27EE, 0x2983, 0x2985, 0x2987, 0x2989, 0x298B, 0x298D, 0x298F, 0x2991, + 0x2993, 0x2995, 0x2997, 0x29FC, 0x2E22, 0x2E24, 0x2E26, 0x2E28, 0x3008, + 0x300A, 0x300C, 0x300E, 0x3010, 0x3014, 0x3016, 0x3018, 0x301A, 0xFD3F, + 0xFE17, 0xFE35, 0xFE37, 0xFE39, 0xFE3B, 0xFE3D, 0xFE3F, 0xFE41, 0xFE43, + 0xFE47, 0xFE59, 0xFE5B, 0xFE5D, 0xFF08, 0xFF3B, 0xFF5B, 0xFF5F, 0xFF62}; +constexpr int kNumOpeningBrackets = ARRAYSIZE(kOpeningBrackets); + +// grep -E "Pe" UnicodeData.txt | \ +// sed -rne "s/^([0-9A-Z]{4});.*(PAREN|BRACKET|BRAKCET|BRACE).*/0x\1, /p" +constexpr char32 kClosingBrackets[] = { + 0x0029, 0x005D, 0x007D, 0x0F3D, 0x2046, 0x207E, 0x208E, 0x232A, 0x2769, + 0x276B, 0x276D, 0x2771, 0x2773, 0x2775, 0x27E7, 0x27E9, 0x27EB, 0x27ED, + 0x27EF, 0x2984, 0x2986, 0x2988, 0x298A, 0x298C, 0x298E, 0x2990, 0x2992, + 0x2994, 0x2996, 0x2998, 0x29FD, 0x2E23, 0x2E25, 0x2E27, 0x2E29, 0x3009, + 0x300B, 0x300D, 0x300F, 0x3011, 0x3015, 0x3017, 0x3019, 0x301B, 0xFD3E, + 0xFE18, 0xFE36, 0xFE38, 0xFE3A, 0xFE3C, 0xFE3E, 0xFE40, 0xFE42, 0xFE44, + 0xFE48, 0xFE5A, 0xFE5C, 0xFE5E, 0xFF09, 0xFF3D, 0xFF5D, 0xFF60, 0xFF63}; +constexpr int kNumClosingBrackets = ARRAYSIZE(kClosingBrackets); + +// grep -E "WS" UnicodeData.txt | sed -re "s/([0-9A-Z]+);.*/0x\1, /" +constexpr char32 kWhitespaces[] = { + 0x000C, 0x0020, 0x1680, 0x2000, 0x2001, 0x2002, 0x2003, 0x2004, + 0x2005, 0x2006, 0x2007, 0x2008, 0x2009, 0x200A, 0x2028, 0x205F, + 0x21C7, 0x21C8, 0x21C9, 0x21CA, 0x21F6, 0x2B31, 0x2B84, 0x2B85, + 0x2B86, 0x2B87, 0x2B94, 0x3000, 0x4DCC, 0x10344, 0x10347, 0x1DA0A, + 0x1DA0B, 0x1DA0C, 0x1DA0D, 0x1DA0E, 0x1DA0F, 0x1DA10, 0x1F4F0, 0x1F500, + 0x1F501, 0x1F502, 0x1F503, 0x1F504, 0x1F5D8, 0x1F5DE}; +constexpr int kNumWhitespaces = ARRAYSIZE(kWhitespaces); + +// grep -E "Nd" UnicodeData.txt | sed -re "s/([0-9A-Z]+);.*/0x\1, /" +// As the name suggests, these ranges are always 10 codepoints long, so we just +// store the end of the range. +constexpr char32 kDecimalDigitRangesEnd[] = { + 0x0039, 0x0669, 0x06f9, 0x07c9, 0x096f, 0x09ef, 0x0a6f, 0x0aef, + 0x0b6f, 0x0bef, 0x0c6f, 0x0cef, 0x0d6f, 0x0def, 0x0e59, 0x0ed9, + 0x0f29, 0x1049, 0x1099, 0x17e9, 0x1819, 0x194f, 0x19d9, 0x1a89, + 0x1a99, 0x1b59, 0x1bb9, 0x1c49, 0x1c59, 0xa629, 0xa8d9, 0xa909, + 0xa9d9, 0xa9f9, 0xaa59, 0xabf9, 0xff19, 0x104a9, 0x1106f, 0x110f9, + 0x1113f, 0x111d9, 0x112f9, 0x11459, 0x114d9, 0x11659, 0x116c9, 0x11739, + 0x118e9, 0x11c59, 0x11d59, 0x16a69, 0x16b59, 0x1d7ff}; +constexpr int kNumDecimalDigitRangesEnd = ARRAYSIZE(kDecimalDigitRangesEnd); + +// grep -E "Lu" UnicodeData.txt | sed -re "s/([0-9A-Z]+);.*/0x\1, /" +// There are three common ways in which upper/lower case codepoint ranges +// were introduced: one offs, dense ranges, and ranges that alternate between +// lower and upper case. For the sake of keeping out binary size down, we +// treat each independently. +constexpr char32 kUpperSingles[] = { + 0x01b8, 0x01bc, 0x01c4, 0x01c7, 0x01ca, 0x01f1, 0x0376, 0x037f, + 0x03cf, 0x03f4, 0x03fa, 0x10c7, 0x10cd, 0x2102, 0x2107, 0x2115, + 0x2145, 0x2183, 0x2c72, 0x2c75, 0x2cf2, 0xa7b6}; +constexpr int kNumUpperSingles = ARRAYSIZE(kUpperSingles); +constexpr char32 kUpperRanges1Start[] = { + 0x0041, 0x00c0, 0x00d8, 0x0181, 0x018a, 0x018e, 0x0193, 0x0196, + 0x019c, 0x019f, 0x01b2, 0x01f7, 0x023a, 0x023d, 0x0244, 0x0389, + 0x0392, 0x03a3, 0x03d2, 0x03fd, 0x0531, 0x10a0, 0x13a0, 0x1f08, + 0x1f18, 0x1f28, 0x1f38, 0x1f48, 0x1f68, 0x1fb8, 0x1fc8, 0x1fd8, + 0x1fe8, 0x1ff8, 0x210b, 0x2110, 0x2119, 0x212b, 0x2130, 0x213e, + 0x2c00, 0x2c63, 0x2c6e, 0x2c7e, 0xa7ab, 0xa7b0}; +constexpr int kNumUpperRanges1Start = ARRAYSIZE(kUpperRanges1Start); +constexpr char32 kUpperRanges1End[] = { + 0x005a, 0x00d6, 0x00de, 0x0182, 0x018b, 0x0191, 0x0194, 0x0198, + 0x019d, 0x01a0, 0x01b3, 0x01f8, 0x023b, 0x023e, 0x0246, 0x038a, + 0x03a1, 0x03ab, 0x03d4, 0x042f, 0x0556, 0x10c5, 0x13f5, 0x1f0f, + 0x1f1d, 0x1f2f, 0x1f3f, 0x1f4d, 0x1f6f, 0x1fbb, 0x1fcb, 0x1fdb, + 0x1fec, 0x1ffb, 0x210d, 0x2112, 0x211d, 0x212d, 0x2133, 0x213f, + 0x2c2e, 0x2c64, 0x2c70, 0x2c80, 0xa7ae, 0xa7b4}; +constexpr int kNumUpperRanges1End = ARRAYSIZE(kUpperRanges1End); +constexpr char32 kUpperRanges2Start[] = { + 0x0100, 0x0139, 0x014a, 0x0179, 0x0184, 0x0187, 0x01a2, 0x01a7, 0x01ac, + 0x01af, 0x01b5, 0x01cd, 0x01de, 0x01f4, 0x01fa, 0x0241, 0x0248, 0x0370, + 0x0386, 0x038c, 0x038f, 0x03d8, 0x03f7, 0x0460, 0x048a, 0x04c1, 0x04d0, + 0x1e00, 0x1e9e, 0x1f59, 0x2124, 0x2c60, 0x2c67, 0x2c82, 0x2ceb, 0xa640, + 0xa680, 0xa722, 0xa732, 0xa779, 0xa77e, 0xa78b, 0xa790, 0xa796}; +constexpr int kNumUpperRanges2Start = ARRAYSIZE(kUpperRanges2Start); +constexpr char32 kUpperRanges2End[] = { + 0x0136, 0x0147, 0x0178, 0x017d, 0x0186, 0x0189, 0x01a6, 0x01a9, 0x01ae, + 0x01b1, 0x01b7, 0x01db, 0x01ee, 0x01f6, 0x0232, 0x0243, 0x024e, 0x0372, + 0x0388, 0x038e, 0x0391, 0x03ee, 0x03f9, 0x0480, 0x04c0, 0x04cd, 0x052e, + 0x1e94, 0x1efe, 0x1f5f, 0x212a, 0x2c62, 0x2c6d, 0x2ce2, 0x2ced, 0xa66c, + 0xa69a, 0xa72e, 0xa76e, 0xa77d, 0xa786, 0xa78d, 0xa792, 0xa7aa}; +constexpr int kNumUpperRanges2End = ARRAYSIZE(kUpperRanges2End); + +// grep -E "Lu" UnicodeData.txt | \ +// sed -rne "s/^([0-9A-Z]+);.*;([0-9A-Z]+);$/(0x\1, 0x\2), /p" +// We have two strategies for mapping from upper to lower case. We have single +// character lookups that do not follow a pattern, and ranges for which there +// is a constant codepoint shift. +// Note that these ranges ignore anything that's not an upper case character, +// so when applied to a non-uppercase character the result is incorrect. +constexpr int kToLowerSingles[] = { + 0x0130, 0x0178, 0x0181, 0x0186, 0x018b, 0x018e, 0x018f, 0x0190, 0x0191, + 0x0194, 0x0196, 0x0197, 0x0198, 0x019c, 0x019d, 0x019f, 0x01a6, 0x01a9, + 0x01ae, 0x01b7, 0x01f6, 0x01f7, 0x0220, 0x023a, 0x023d, 0x023e, 0x0243, + 0x0244, 0x0245, 0x037f, 0x0386, 0x038c, 0x03cf, 0x03f4, 0x03f9, 0x04c0, + 0x1e9e, 0x1fec, 0x2126, 0x212a, 0x212b, 0x2132, 0x2183, 0x2c60, 0x2c62, + 0x2c63, 0x2c64, 0x2c6d, 0x2c6e, 0x2c6f, 0x2c70, 0xa77d, 0xa78d, 0xa7aa, + 0xa7ab, 0xa7ac, 0xa7ad, 0xa7ae, 0xa7b0, 0xa7b1, 0xa7b2, 0xa7b3}; +constexpr int kNumToLowerSingles = ARRAYSIZE(kToLowerSingles); +constexpr int kToLowerSinglesOffsets[] = { + -199, -121, 210, 206, 1, 79, 202, 203, 1, + 207, 211, 209, 1, 211, 213, 214, 218, 218, + 218, 219, -97, -56, -130, 10795, -163, 10792, -195, + 69, 71, 116, 38, 64, 8, -60, -7, 15, + -7615, -7, -7517, -8383, -8262, 28, 1, 1, -10743, + -3814, -10727, -10780, -10749, -10783, -10782, -35332, -42280, -42308, + -42319, -42315, -42305, -42308, -42258, -42282, -42261, 928}; +constexpr int kNumToLowerSinglesOffsets = ARRAYSIZE(kToLowerSinglesOffsets); +constexpr int kToLowerRangesStart[] = { + 0x0041, 0x0100, 0x0189, 0x01a0, 0x01b1, 0x01b3, 0x0388, 0x038e, 0x0391, + 0x03d8, 0x03fd, 0x0400, 0x0410, 0x0460, 0x0531, 0x10a0, 0x13a0, 0x13f0, + 0x1e00, 0x1f08, 0x1fba, 0x1fc8, 0x1fd8, 0x1fda, 0x1fe8, 0x1fea, 0x1ff8, + 0x1ffa, 0x2c00, 0x2c67, 0x2c7e, 0x2c80, 0xff21, 0x10400, 0x10c80, 0x118a0}; +constexpr int kNumToLowerRangesStart = ARRAYSIZE(kToLowerRangesStart); +constexpr int kToLowerRangesEnd[] = { + 0x00de, 0x0187, 0x019f, 0x01af, 0x01b2, 0x0386, 0x038c, 0x038f, 0x03cf, + 0x03fa, 0x03ff, 0x040f, 0x042f, 0x052e, 0x0556, 0x10cd, 0x13ef, 0x13f5, + 0x1efe, 0x1fb9, 0x1fbb, 0x1fcb, 0x1fd9, 0x1fdb, 0x1fe9, 0x1fec, 0x1ff9, + 0x2183, 0x2c64, 0x2c75, 0x2c7f, 0xa7b6, 0xff3a, 0x104d3, 0x10cb2, 0x118bf}; +constexpr int kNumToLowerRangesEnd = ARRAYSIZE(kToLowerRangesEnd); +constexpr int kToLowerRangesOffsets[] = { + 32, 1, 205, 1, 217, 1, 37, 63, 32, 1, -130, 80, + 32, 1, 48, 7264, 38864, 8, 1, -8, -74, -86, -8, -100, + -8, -112, -128, -126, 48, 1, -10815, 1, 32, 40, 64, 32}; +constexpr int kNumToLowerRangesOffsets = ARRAYSIZE(kToLowerRangesOffsets); + +#undef ARRAYSIZE + +static_assert(kNumOpeningBrackets == kNumClosingBrackets, + "mismatching number of opening and closing brackets"); +static_assert(kNumUpperRanges1Start == kNumUpperRanges1End, + "number of uppercase stride 1 range starts/ends doesn't match"); +static_assert(kNumUpperRanges2Start == kNumUpperRanges2End, + "number of uppercase stride 2 range starts/ends doesn't match"); +static_assert(kNumToLowerSingles == kNumToLowerSinglesOffsets, + "number of to lower singles and offsets doesn't match"); +static_assert(kNumToLowerRangesStart == kNumToLowerRangesEnd, + "mismatching number of range starts/ends for to lower ranges"); +static_assert(kNumToLowerRangesStart == kNumToLowerRangesOffsets, + "number of to lower ranges and offsets doesn't match"); + +constexpr int kNoMatch = -1; + +// Returns the index of the element in the array that matched the given +// codepoint, or kNoMatch if the element didn't exist. +// The input array must be in sorted order. +int GetMatchIndex(const char32* array, int array_length, char32 c) { + const char32* end = array + array_length; + const auto find_it = std::lower_bound(array, end, c); + if (find_it != end && *find_it == c) { + return find_it - array; + } else { + return kNoMatch; + } +} + +// Returns the index of the range in the array that overlapped the given +// codepoint, or kNoMatch if no such range existed. +// The input array must be in sorted order. +int GetOverlappingRangeIndex(const char32* arr, int arr_length, + int range_length, char32 c) { + const char32* end = arr + arr_length; + const auto find_it = std::lower_bound(arr, end, c); + if (find_it == end) { + return kNoMatch; + } + // The end is inclusive, we so subtract one less than the range length. + const char32 range_end = *find_it; + const char32 range_start = range_end - (range_length - 1); + if (c < range_start || range_end < c) { + return kNoMatch; + } else { + return find_it - arr; + } +} + +// As above, but with explicit codepoint start and end indices for the range. +// The input array must be in sorted order. +int GetOverlappingRangeIndex(const char32* start_arr, const char32* end_arr, + int arr_length, int stride, char32 c) { + const char32* end_arr_end = end_arr + arr_length; + const auto find_it = std::lower_bound(end_arr, end_arr_end, c); + if (find_it == end_arr_end) { + return kNoMatch; + } + // Find the corresponding start. + const int range_index = find_it - end_arr; + const char32 range_start = start_arr[range_index]; + const char32 range_end = *find_it; + if (c < range_start || range_end < c) { + return kNoMatch; + } + if ((c - range_start) % stride == 0) { + return range_index; + } else { + return kNoMatch; + } +} + +} // anonymous namespace + +UniLib::UniLib() { + TC3_LOG(FATAL) << "Java ICU UniLib must be initialized with a JniCache."; +} + +UniLib::UniLib(const std::shared_ptr<JniCache>& jni_cache) + : jni_cache_(jni_cache) {} + +bool UniLib::IsOpeningBracket(char32 codepoint) const { + return GetMatchIndex(kOpeningBrackets, kNumOpeningBrackets, codepoint) >= 0; +} + +bool UniLib::IsClosingBracket(char32 codepoint) const { + return GetMatchIndex(kClosingBrackets, kNumClosingBrackets, codepoint) >= 0; +} + +bool UniLib::IsWhitespace(char32 codepoint) const { + return GetMatchIndex(kWhitespaces, kNumWhitespaces, codepoint) >= 0; +} + +bool UniLib::IsDigit(char32 codepoint) const { + return GetOverlappingRangeIndex(kDecimalDigitRangesEnd, + kNumDecimalDigitRangesEnd, + /*range_length=*/10, codepoint) >= 0; +} + +bool UniLib::IsUpper(char32 codepoint) const { + if (GetMatchIndex(kUpperSingles, kNumUpperSingles, codepoint) >= 0) { + return true; + } else if (GetOverlappingRangeIndex(kUpperRanges1Start, kUpperRanges1End, + kNumUpperRanges1Start, /*stride=*/1, + codepoint) >= 0) { + return true; + } else if (GetOverlappingRangeIndex(kUpperRanges2Start, kUpperRanges2End, + kNumUpperRanges2Start, /*stride=*/2, + codepoint) >= 0) { + return true; + } else { + return false; + } +} + +char32 UniLib::ToLower(char32 codepoint) const { + // Make sure we still produce output even if the method is called for a + // codepoint that's not an uppercase character. + if (!IsUpper(codepoint)) { + return codepoint; + } + const int singles_idx = + GetMatchIndex(kToLowerSingles, kNumToLowerSingles, codepoint); + if (singles_idx >= 0) { + return codepoint + kToLowerSinglesOffsets[singles_idx]; + } + const int ranges_idx = + GetOverlappingRangeIndex(kToLowerRangesStart, kToLowerRangesEnd, + kNumToLowerRangesStart, /*stride=*/1, codepoint); + if (ranges_idx >= 0) { + return codepoint + kToLowerRangesOffsets[ranges_idx]; + } + return codepoint; +} + +char32 UniLib::GetPairedBracket(char32 codepoint) const { + const int open_offset = + GetMatchIndex(kOpeningBrackets, kNumOpeningBrackets, codepoint); + if (open_offset >= 0) { + return kClosingBrackets[open_offset]; + } + const int close_offset = + GetMatchIndex(kClosingBrackets, kNumClosingBrackets, codepoint); + if (close_offset >= 0) { + return kOpeningBrackets[close_offset]; + } + return codepoint; +} + +// ----------------------------------------------------------------------------- +// Implementations that call out to JVM. Behold the beauty. +// ----------------------------------------------------------------------------- + +bool UniLib::ParseInt32(const UnicodeText& text, int* result) const { + if (jni_cache_) { + JNIEnv* env = jni_cache_->GetEnv(); + const ScopedLocalRef<jstring> text_java = + jni_cache_->ConvertToJavaString(text); + jint res = env->CallStaticIntMethod(jni_cache_->integer_class.get(), + jni_cache_->integer_parse_int, + text_java.get()); + if (jni_cache_->ExceptionCheckAndClear()) { + return false; + } + *result = res; + return true; + } + return false; +} + +std::unique_ptr<UniLib::RegexPattern> UniLib::CreateRegexPattern( + const UnicodeText& regex) const { + return std::unique_ptr<UniLib::RegexPattern>( + new UniLib::RegexPattern(jni_cache_.get(), regex)); +} + +UniLib::RegexPattern::RegexPattern(const JniCache* jni_cache, + const UnicodeText& regex) + : jni_cache_(jni_cache), + pattern_(nullptr, jni_cache ? jni_cache->jvm : nullptr) { + if (jni_cache_) { + JNIEnv* jenv = jni_cache_->GetEnv(); + const ScopedLocalRef<jstring> regex_java = + jni_cache->ConvertToJavaString(regex); + pattern_ = MakeGlobalRef(jenv->CallStaticObjectMethod( + jni_cache_->pattern_class.get(), + jni_cache_->pattern_compile, regex_java.get()), + jenv, jni_cache_->jvm); + } +} + +constexpr int UniLib::RegexMatcher::kError; +constexpr int UniLib::RegexMatcher::kNoError; + +std::unique_ptr<UniLib::RegexMatcher> UniLib::RegexPattern::Matcher( + const UnicodeText& context) const { + if (jni_cache_) { + JNIEnv* env = jni_cache_->GetEnv(); + const jstring context_java = + jni_cache_->ConvertToJavaString(context).release(); + if (!context_java) { + return nullptr; + } + const jobject matcher = env->CallObjectMethod( + pattern_.get(), jni_cache_->pattern_matcher, context_java); + if (jni_cache_->ExceptionCheckAndClear() || !matcher) { + return nullptr; + } + return std::unique_ptr<UniLib::RegexMatcher>(new RegexMatcher( + jni_cache_, MakeGlobalRef(matcher, env, jni_cache_->jvm), + MakeGlobalRef(context_java, env, jni_cache_->jvm))); + } else { + // NOTE: A valid object needs to be created here to pass the interface + // tests. + return std::unique_ptr<UniLib::RegexMatcher>( + new RegexMatcher(jni_cache_, nullptr, nullptr)); + } +} + +UniLib::RegexMatcher::RegexMatcher(const JniCache* jni_cache, + ScopedGlobalRef<jobject> matcher, + ScopedGlobalRef<jstring> text) + : jni_cache_(jni_cache), + matcher_(std::move(matcher)), + text_(std::move(text)) {} + +bool UniLib::RegexMatcher::Matches(int* status) const { + if (jni_cache_) { + *status = kNoError; + const bool result = jni_cache_->GetEnv()->CallBooleanMethod( + matcher_.get(), jni_cache_->matcher_matches); + if (jni_cache_->ExceptionCheckAndClear()) { + *status = kError; + return false; + } + return result; + } else { + *status = kError; + return false; + } +} + +bool UniLib::RegexMatcher::ApproximatelyMatches(int* status) { + *status = kNoError; + + jni_cache_->GetEnv()->CallObjectMethod(matcher_.get(), + jni_cache_->matcher_reset); + if (jni_cache_->ExceptionCheckAndClear()) { + *status = kError; + return kError; + } + + if (!Find(status) || *status != kNoError) { + return false; + } + + const int found_start = jni_cache_->GetEnv()->CallIntMethod( + matcher_.get(), jni_cache_->matcher_start_idx, 0); + if (jni_cache_->ExceptionCheckAndClear()) { + *status = kError; + return kError; + } + + const int found_end = jni_cache_->GetEnv()->CallIntMethod( + matcher_.get(), jni_cache_->matcher_end_idx, 0); + if (jni_cache_->ExceptionCheckAndClear()) { + *status = kError; + return kError; + } + + int context_length_bmp = jni_cache_->GetEnv()->CallIntMethod( + text_.get(), jni_cache_->string_length); + if (jni_cache_->ExceptionCheckAndClear()) { + *status = kError; + return false; + } + + if (found_start != 0 || found_end != context_length_bmp) { + return false; + } + + return true; +} + +bool UniLib::RegexMatcher::UpdateLastFindOffset() const { + if (!last_find_offset_dirty_) { + return true; + } + + const int find_offset = jni_cache_->GetEnv()->CallIntMethod( + matcher_.get(), jni_cache_->matcher_start_idx, 0); + if (jni_cache_->ExceptionCheckAndClear()) { + return false; + } + + const int codepoint_count = jni_cache_->GetEnv()->CallIntMethod( + text_.get(), jni_cache_->string_code_point_count, last_find_offset_, + find_offset); + if (jni_cache_->ExceptionCheckAndClear()) { + return false; + } + + last_find_offset_codepoints_ += codepoint_count; + last_find_offset_ = find_offset; + last_find_offset_dirty_ = false; + + return true; +} + +bool UniLib::RegexMatcher::Find(int* status) { + if (jni_cache_) { + const bool result = jni_cache_->GetEnv()->CallBooleanMethod( + matcher_.get(), jni_cache_->matcher_find); + if (jni_cache_->ExceptionCheckAndClear()) { + *status = kError; + return false; + } + + last_find_offset_dirty_ = true; + *status = kNoError; + return result; + } else { + *status = kError; + return false; + } +} + +int UniLib::RegexMatcher::Start(int* status) const { + return Start(/*group_idx=*/0, status); +} + +int UniLib::RegexMatcher::Start(int group_idx, int* status) const { + if (jni_cache_) { + *status = kNoError; + + if (!UpdateLastFindOffset()) { + *status = kError; + return kError; + } + + const int java_index = jni_cache_->GetEnv()->CallIntMethod( + matcher_.get(), jni_cache_->matcher_start_idx, group_idx); + if (jni_cache_->ExceptionCheckAndClear()) { + *status = kError; + return kError; + } + + // If the group didn't participate in the match the index is -1. + if (java_index == -1) { + return -1; + } + + const int unicode_index = jni_cache_->GetEnv()->CallIntMethod( + text_.get(), jni_cache_->string_code_point_count, last_find_offset_, + java_index); + if (jni_cache_->ExceptionCheckAndClear()) { + *status = kError; + return kError; + } + + return unicode_index + last_find_offset_codepoints_; + } else { + *status = kError; + return kError; + } +} + +int UniLib::RegexMatcher::End(int* status) const { + return End(/*group_idx=*/0, status); +} + +int UniLib::RegexMatcher::End(int group_idx, int* status) const { + if (jni_cache_) { + *status = kNoError; + + if (!UpdateLastFindOffset()) { + *status = kError; + return kError; + } + + const int java_index = jni_cache_->GetEnv()->CallIntMethod( + matcher_.get(), jni_cache_->matcher_end_idx, group_idx); + if (jni_cache_->ExceptionCheckAndClear()) { + *status = kError; + return kError; + } + + // If the group didn't participate in the match the index is -1. + if (java_index == -1) { + return -1; + } + + const int unicode_index = jni_cache_->GetEnv()->CallIntMethod( + text_.get(), jni_cache_->string_code_point_count, last_find_offset_, + java_index); + if (jni_cache_->ExceptionCheckAndClear()) { + *status = kError; + return kError; + } + + return unicode_index + last_find_offset_codepoints_; + } else { + *status = kError; + return kError; + } +} + +UnicodeText UniLib::RegexMatcher::Group(int* status) const { + if (jni_cache_) { + JNIEnv* jenv = jni_cache_->GetEnv(); + const ScopedLocalRef<jstring> java_result( + reinterpret_cast<jstring>( + jenv->CallObjectMethod(matcher_.get(), jni_cache_->matcher_group)), + jenv); + if (jni_cache_->ExceptionCheckAndClear() || !java_result) { + *status = kError; + return UTF8ToUnicodeText("", /*do_copy=*/false); + } + + std::string result; + if (!JStringToUtf8String(jenv, java_result.get(), &result)) { + *status = kError; + return UTF8ToUnicodeText("", /*do_copy=*/false); + } + *status = kNoError; + return UTF8ToUnicodeText(result, /*do_copy=*/true); + } else { + *status = kError; + return UTF8ToUnicodeText("", /*do_copy=*/false); + } +} + +UnicodeText UniLib::RegexMatcher::Group(int group_idx, int* status) const { + if (jni_cache_) { + JNIEnv* jenv = jni_cache_->GetEnv(); + const ScopedLocalRef<jstring> java_result( + reinterpret_cast<jstring>(jenv->CallObjectMethod( + matcher_.get(), jni_cache_->matcher_group_idx, group_idx)), + jenv); + if (jni_cache_->ExceptionCheckAndClear()) { + *status = kError; + TC3_LOG(ERROR) << "Exception occurred"; + return UTF8ToUnicodeText("", /*do_copy=*/false); + } + + // java_result is nullptr when the group did not participate in the match. + // For these cases other UniLib implementations return empty string, and + // the participation can be checked by checking if Start() == -1. + if (!java_result) { + *status = kNoError; + return UTF8ToUnicodeText("", /*do_copy=*/false); + } + + std::string result; + if (!JStringToUtf8String(jenv, java_result.get(), &result)) { + *status = kError; + return UTF8ToUnicodeText("", /*do_copy=*/false); + } + *status = kNoError; + return UTF8ToUnicodeText(result, /*do_copy=*/true); + } else { + *status = kError; + return UTF8ToUnicodeText("", /*do_copy=*/false); + } +} + +constexpr int UniLib::BreakIterator::kDone; + +UniLib::BreakIterator::BreakIterator(const JniCache* jni_cache, + const UnicodeText& text) + : jni_cache_(jni_cache), + text_(nullptr, jni_cache ? jni_cache->jvm : nullptr), + iterator_(nullptr, jni_cache ? jni_cache->jvm : nullptr), + last_break_index_(0), + last_unicode_index_(0) { + if (jni_cache_) { + JNIEnv* jenv = jni_cache_->GetEnv(); + text_ = MakeGlobalRef(jni_cache_->ConvertToJavaString(text).release(), jenv, + jni_cache->jvm); + if (!text_) { + return; + } + + iterator_ = MakeGlobalRef( + jenv->CallStaticObjectMethod(jni_cache->breakiterator_class.get(), + jni_cache->breakiterator_getwordinstance, + jni_cache->locale_us.get()), + jenv, jni_cache->jvm); + if (!iterator_) { + return; + } + jenv->CallVoidMethod(iterator_.get(), jni_cache->breakiterator_settext, + text_.get()); + } +} + +int UniLib::BreakIterator::Next() { + if (jni_cache_) { + const int break_index = jni_cache_->GetEnv()->CallIntMethod( + iterator_.get(), jni_cache_->breakiterator_next); + if (jni_cache_->ExceptionCheckAndClear() || + break_index == BreakIterator::kDone) { + return BreakIterator::kDone; + } + + const int token_unicode_length = jni_cache_->GetEnv()->CallIntMethod( + text_.get(), jni_cache_->string_code_point_count, last_break_index_, + break_index); + if (jni_cache_->ExceptionCheckAndClear()) { + return BreakIterator::kDone; + } + + last_break_index_ = break_index; + return last_unicode_index_ += token_unicode_length; + } + return BreakIterator::kDone; +} + +std::unique_ptr<UniLib::BreakIterator> UniLib::CreateBreakIterator( + const UnicodeText& text) const { + return std::unique_ptr<UniLib::BreakIterator>( + new UniLib::BreakIterator(jni_cache_.get(), text)); +} + +} // namespace libtextclassifier3 diff --git a/util/utf8/unilib-icu.h b/utils/utf8/unilib-javaicu.h index 8983756..a5ea54f 100644 --- a/util/utf8/unilib-icu.h +++ b/utils/utf8/unilib-javaicu.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,26 +14,30 @@ * limitations under the License. */ -// UniLib implementation with the help of ICU. UniLib is basically a wrapper -// around the ICU functionality. +// An implementation of Unilib that uses Android Java interfaces via JNI. The +// performance critical ops have been re-implemented in C++. +// Specifically, this class must be compatible with API level 14 (ICS). -#ifndef LIBTEXTCLASSIFIER_UTIL_UTF8_UNILIB_ICU_H_ -#define LIBTEXTCLASSIFIER_UTIL_UTF8_UNILIB_ICU_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_JAVAICU_H_ +#define LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_JAVAICU_H_ +#include <jni.h> #include <memory> +#include <string> -#include "util/base/integral_types.h" -#include "util/utf8/unicodetext.h" -#include "unicode/brkiter.h" -#include "unicode/errorcode.h" -#include "unicode/regex.h" -#include "unicode/uchar.h" -#include "unicode/unum.h" +#include "utils/base/integral_types.h" +#include "utils/java/jni-cache.h" +#include "utils/java/scoped_global_ref.h" +#include "utils/java/scoped_local_ref.h" +#include "utils/utf8/unicodetext.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { class UniLib { public: + UniLib(); + explicit UniLib(const std::shared_ptr<JniCache>& jni_cache); + bool ParseInt32(const UnicodeText& text, int* result) const; bool IsOpeningBracket(char32 codepoint) const; bool IsClosingBracket(char32 codepoint) const; @@ -102,29 +106,31 @@ class UniLib { protected: friend class RegexPattern; - explicit RegexMatcher(icu::RegexPattern* pattern, icu::UnicodeString text); + RegexMatcher(const JniCache* jni_cache, ScopedGlobalRef<jobject> matcher, + ScopedGlobalRef<jstring> text); private: bool UpdateLastFindOffset() const; - std::unique_ptr<icu::RegexMatcher> matcher_; - icu::UnicodeString text_; - mutable int last_find_offset_; - mutable int last_find_offset_codepoints_; - mutable bool last_find_offset_dirty_; + const JniCache* jni_cache_; + ScopedGlobalRef<jobject> matcher_; + ScopedGlobalRef<jstring> text_; + mutable int last_find_offset_ = 0; + mutable int last_find_offset_codepoints_ = 0; + mutable bool last_find_offset_dirty_ = true; }; class RegexPattern { public: - std::unique_ptr<RegexMatcher> Matcher(const UnicodeText& input) const; + std::unique_ptr<RegexMatcher> Matcher(const UnicodeText& context) const; protected: friend class UniLib; - explicit RegexPattern(std::unique_ptr<icu::RegexPattern> pattern) - : pattern_(std::move(pattern)) {} + RegexPattern(const JniCache* jni_cache, const UnicodeText& regex); private: - std::unique_ptr<icu::RegexPattern> pattern_; + const JniCache* jni_cache_; + ScopedGlobalRef<jobject> pattern_; }; class BreakIterator { @@ -135,11 +141,12 @@ class UniLib { protected: friend class UniLib; - explicit BreakIterator(const UnicodeText& text); + BreakIterator(const JniCache* jni_cache, const UnicodeText& text); private: - std::unique_ptr<icu::BreakIterator> break_iterator_; - icu::UnicodeString text_; + const JniCache* jni_cache_; + ScopedGlobalRef<jstring> text_; + ScopedGlobalRef<jobject> iterator_; int last_break_index_; int last_unicode_index_; }; @@ -148,8 +155,11 @@ class UniLib { const UnicodeText& regex) const; std::unique_ptr<BreakIterator> CreateBreakIterator( const UnicodeText& text) const; + + private: + std::shared_ptr<JniCache> jni_cache_; }; -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_UTIL_UTF8_UNILIB_ICU_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_JAVAICU_H_ diff --git a/util/utf8/unilib.h b/utils/utf8/unilib.h index 29b4575..ec1f329 100644 --- a/util/utf8/unilib.h +++ b/utils/utf8/unilib.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,10 +14,10 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_UTIL_UTF8_UNILIB_H_ -#define LIBTEXTCLASSIFIER_UTIL_UTF8_UNILIB_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_H_ +#define LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_H_ -#include "util/utf8/unilib-icu.h" -#define CREATE_UNILIB_FOR_TESTING const UniLib unilib; +#include "utils/utf8/unilib-javaicu.h" +#define INIT_UNILIB_FOR_TESTING(VAR) VAR(nullptr) -#endif // LIBTEXTCLASSIFIER_UTIL_UTF8_UNILIB_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_H_ diff --git a/util/utf8/unilib_test.cc b/utils/utf8/unilib_test.cc index 13b1347..96b2c2d 100644 --- a/util/utf8/unilib_test.cc +++ b/utils/utf8/unilib_test.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,89 +14,90 @@ * limitations under the License. */ -#include "util/utf8/unilib.h" +#include "utils/utf8/unilib.h" -#include "util/base/logging.h" -#include "util/utf8/unicodetext.h" +#include "utils/base/logging.h" +#include "utils/utf8/unicodetext.h" #include "gmock/gmock.h" #include "gtest/gtest.h" -namespace libtextclassifier2 { +namespace libtextclassifier3 { namespace { using ::testing::ElementsAre; -TEST(UniLibTest, CharacterClassesAscii) { - CREATE_UNILIB_FOR_TESTING; - EXPECT_TRUE(unilib.IsOpeningBracket('(')); - EXPECT_TRUE(unilib.IsClosingBracket(')')); - EXPECT_FALSE(unilib.IsWhitespace(')')); - EXPECT_TRUE(unilib.IsWhitespace(' ')); - EXPECT_FALSE(unilib.IsDigit(')')); - EXPECT_TRUE(unilib.IsDigit('0')); - EXPECT_TRUE(unilib.IsDigit('9')); - EXPECT_FALSE(unilib.IsUpper(')')); - EXPECT_TRUE(unilib.IsUpper('A')); - EXPECT_TRUE(unilib.IsUpper('Z')); - EXPECT_EQ(unilib.ToLower('A'), 'a'); - EXPECT_EQ(unilib.ToLower('Z'), 'z'); - EXPECT_EQ(unilib.ToLower(')'), ')'); - EXPECT_EQ(unilib.GetPairedBracket(')'), '('); - EXPECT_EQ(unilib.GetPairedBracket('}'), '{'); +class UniLibTest : public ::testing::Test { + protected: + UniLibTest() : INIT_UNILIB_FOR_TESTING(unilib_) {} + UniLib unilib_; +}; + +TEST_F(UniLibTest, CharacterClassesAscii) { + EXPECT_TRUE(unilib_.IsOpeningBracket('(')); + EXPECT_TRUE(unilib_.IsClosingBracket(')')); + EXPECT_FALSE(unilib_.IsWhitespace(')')); + EXPECT_TRUE(unilib_.IsWhitespace(' ')); + EXPECT_FALSE(unilib_.IsDigit(')')); + EXPECT_TRUE(unilib_.IsDigit('0')); + EXPECT_TRUE(unilib_.IsDigit('9')); + EXPECT_FALSE(unilib_.IsUpper(')')); + EXPECT_TRUE(unilib_.IsUpper('A')); + EXPECT_TRUE(unilib_.IsUpper('Z')); + EXPECT_EQ(unilib_.ToLower('A'), 'a'); + EXPECT_EQ(unilib_.ToLower('Z'), 'z'); + EXPECT_EQ(unilib_.ToLower(')'), ')'); + EXPECT_EQ(unilib_.GetPairedBracket(')'), '('); + EXPECT_EQ(unilib_.GetPairedBracket('}'), '{'); } -#ifndef LIBTEXTCLASSIFIER_UNILIB_DUMMY -TEST(UniLibTest, CharacterClassesUnicode) { - CREATE_UNILIB_FOR_TESTING; - EXPECT_TRUE(unilib.IsOpeningBracket(0x0F3C)); // TIBET ANG KHANG GYON - EXPECT_TRUE(unilib.IsClosingBracket(0x0F3D)); // TIBET ANG KHANG GYAS - EXPECT_FALSE(unilib.IsWhitespace(0x23F0)); // ALARM CLOCK - EXPECT_TRUE(unilib.IsWhitespace(0x2003)); // EM SPACE - EXPECT_FALSE(unilib.IsDigit(0xA619)); // VAI SYMBOL JONG - EXPECT_TRUE(unilib.IsDigit(0xA620)); // VAI DIGIT ZERO - EXPECT_TRUE(unilib.IsDigit(0xA629)); // VAI DIGIT NINE - EXPECT_FALSE(unilib.IsDigit(0xA62A)); // VAI SYLLABLE NDOLE MA - EXPECT_FALSE(unilib.IsUpper(0x0211)); // SMALL R WITH DOUBLE GRAVE - EXPECT_TRUE(unilib.IsUpper(0x0212)); // CAPITAL R WITH DOUBLE GRAVE - EXPECT_TRUE(unilib.IsUpper(0x0391)); // GREEK CAPITAL ALPHA - EXPECT_TRUE(unilib.IsUpper(0x03AB)); // GREEK CAPITAL UPSILON W DIAL - EXPECT_FALSE(unilib.IsUpper(0x03AC)); // GREEK SMALL ALPHA WITH TONOS - EXPECT_EQ(unilib.ToLower(0x0391), 0x03B1); // GREEK ALPHA - EXPECT_EQ(unilib.ToLower(0x03AB), 0x03CB); // GREEK UPSILON WITH DIALYTIKA - EXPECT_EQ(unilib.ToLower(0x03C0), 0x03C0); // GREEK SMALL PI - - EXPECT_EQ(unilib.GetPairedBracket(0x0F3C), 0x0F3D); - EXPECT_EQ(unilib.GetPairedBracket(0x0F3D), 0x0F3C); +#ifndef TC3_UNILIB_DUMMY +TEST_F(UniLibTest, CharacterClassesUnicode) { + EXPECT_TRUE(unilib_.IsOpeningBracket(0x0F3C)); // TIBET ANG KHANG GYON + EXPECT_TRUE(unilib_.IsClosingBracket(0x0F3D)); // TIBET ANG KHANG GYAS + EXPECT_FALSE(unilib_.IsWhitespace(0x23F0)); // ALARM CLOCK + EXPECT_TRUE(unilib_.IsWhitespace(0x2003)); // EM SPACE + EXPECT_FALSE(unilib_.IsDigit(0xA619)); // VAI SYMBOL JONG + EXPECT_TRUE(unilib_.IsDigit(0xA620)); // VAI DIGIT ZERO + EXPECT_TRUE(unilib_.IsDigit(0xA629)); // VAI DIGIT NINE + EXPECT_FALSE(unilib_.IsDigit(0xA62A)); // VAI SYLLABLE NDOLE MA + EXPECT_FALSE(unilib_.IsUpper(0x0211)); // SMALL R WITH DOUBLE GRAVE + EXPECT_TRUE(unilib_.IsUpper(0x0212)); // CAPITAL R WITH DOUBLE GRAVE + EXPECT_TRUE(unilib_.IsUpper(0x0391)); // GREEK CAPITAL ALPHA + EXPECT_TRUE(unilib_.IsUpper(0x03AB)); // GREEK CAPITAL UPSILON W DIAL + EXPECT_FALSE(unilib_.IsUpper(0x03AC)); // GREEK SMALL ALPHA WITH TONOS + EXPECT_EQ(unilib_.ToLower(0x0391), 0x03B1); // GREEK ALPHA + EXPECT_EQ(unilib_.ToLower(0x03AB), 0x03CB); // GREEK UPSILON WITH DIALYTIKA + EXPECT_EQ(unilib_.ToLower(0x03C0), 0x03C0); // GREEK SMALL PI + + EXPECT_EQ(unilib_.GetPairedBracket(0x0F3C), 0x0F3D); + EXPECT_EQ(unilib_.GetPairedBracket(0x0F3D), 0x0F3C); } -#endif // ndef LIBTEXTCLASSIFIER_UNILIB_DUMMY +#endif // ndef TC3_UNILIB_DUMMY -TEST(UniLibTest, RegexInterface) { - CREATE_UNILIB_FOR_TESTING; +TEST_F(UniLibTest, RegexInterface) { const UnicodeText regex_pattern = UTF8ToUnicodeText("[0-9]+", /*do_copy=*/true); std::unique_ptr<UniLib::RegexPattern> pattern = - unilib.CreateRegexPattern(regex_pattern); + unilib_.CreateRegexPattern(regex_pattern); const UnicodeText input = UTF8ToUnicodeText("hello 0123", /*do_copy=*/false); int status; std::unique_ptr<UniLib::RegexMatcher> matcher = pattern->Matcher(input); - TC_LOG(INFO) << matcher->Matches(&status); - TC_LOG(INFO) << matcher->Find(&status); - TC_LOG(INFO) << matcher->Start(0, &status); - TC_LOG(INFO) << matcher->End(0, &status); - TC_LOG(INFO) << matcher->Group(0, &status).size_codepoints(); + TC3_LOG(INFO) << matcher->Matches(&status); + TC3_LOG(INFO) << matcher->Find(&status); + TC3_LOG(INFO) << matcher->Start(0, &status); + TC3_LOG(INFO) << matcher->End(0, &status); + TC3_LOG(INFO) << matcher->Group(0, &status).size_codepoints(); } -#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU -TEST(UniLibTest, Regex) { - CREATE_UNILIB_FOR_TESTING; - +#ifdef TC3_UNILIB_ICU +TEST_F(UniLibTest, Regex) { // The smiley face is a 4-byte UTF8 codepoint 0x1F60B, and it's important to // test the regex functionality with it to verify we are handling the indices // correctly. const UnicodeText regex_pattern = UTF8ToUnicodeText("[0-9]+😋", /*do_copy=*/false); std::unique_ptr<UniLib::RegexPattern> pattern = - unilib.CreateRegexPattern(regex_pattern); + unilib_.CreateRegexPattern(regex_pattern); int status; std::unique_ptr<UniLib::RegexMatcher> matcher; @@ -125,19 +126,17 @@ TEST(UniLibTest, Regex) { EXPECT_EQ(matcher->Group(0, &status).ToUTF8String(), "0123😋"); EXPECT_EQ(status, UniLib::RegexMatcher::kNoError); } -#endif // LIBTEXTCLASSIFIER_UNILIB_ICU - -#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU -TEST(UniLibTest, RegexGroups) { - CREATE_UNILIB_FOR_TESTING; +#endif // TC3_UNILIB_ICU +#ifdef TC3_UNILIB_ICU +TEST_F(UniLibTest, RegexGroups) { // The smiley face is a 4-byte UTF8 codepoint 0x1F60B, and it's important to // test the regex functionality with it to verify we are handling the indices // correctly. const UnicodeText regex_pattern = UTF8ToUnicodeText( "(?<group1>[0-9])(?<group2>[0-9]+)😋", /*do_copy=*/false); std::unique_ptr<UniLib::RegexPattern> pattern = - unilib.CreateRegexPattern(regex_pattern); + unilib_.CreateRegexPattern(regex_pattern); int status; std::unique_ptr<UniLib::RegexMatcher> matcher; @@ -164,15 +163,14 @@ TEST(UniLibTest, RegexGroups) { EXPECT_EQ(matcher->Group(2, &status).ToUTF8String(), "123"); EXPECT_EQ(status, UniLib::RegexMatcher::kNoError); } -#endif // LIBTEXTCLASSIFIER_UNILIB_ICU +#endif // TC3_UNILIB_ICU -#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU +#ifdef TC3_UNILIB_ICU -TEST(UniLibTest, BreakIterator) { - CREATE_UNILIB_FOR_TESTING; +TEST_F(UniLibTest, BreakIterator) { const UnicodeText text = UTF8ToUnicodeText("some text", /*do_copy=*/false); std::unique_ptr<UniLib::BreakIterator> iterator = - unilib.CreateBreakIterator(text); + unilib_.CreateBreakIterator(text); std::vector<int> break_indices; int break_index = 0; while ((break_index = iterator->Next()) != UniLib::BreakIterator::kDone) { @@ -180,14 +178,13 @@ TEST(UniLibTest, BreakIterator) { } EXPECT_THAT(break_indices, ElementsAre(4, 5, 9)); } -#endif // LIBTEXTCLASSIFIER_UNILIB_ICU +#endif // TC3_UNILIB_ICU -#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU -TEST(UniLibTest, BreakIterator4ByteUTF8) { - CREATE_UNILIB_FOR_TESTING; +#ifdef TC3_UNILIB_ICU +TEST_F(UniLibTest, BreakIterator4ByteUTF8) { const UnicodeText text = UTF8ToUnicodeText("😀😂😋", /*do_copy=*/false); std::unique_ptr<UniLib::BreakIterator> iterator = - unilib.CreateBreakIterator(text); + unilib_.CreateBreakIterator(text); std::vector<int> break_indices; int break_index = 0; while ((break_index = iterator->Next()) != UniLib::BreakIterator::kDone) { @@ -195,38 +192,35 @@ TEST(UniLibTest, BreakIterator4ByteUTF8) { } EXPECT_THAT(break_indices, ElementsAre(1, 2, 3)); } -#endif // LIBTEXTCLASSIFIER_UNILIB_ICU +#endif // TC3_UNILIB_ICU -#ifndef LIBTEXTCLASSIFIER_UNILIB_JAVAICU -TEST(UniLibTest, IntegerParse) { - CREATE_UNILIB_FOR_TESTING; +#ifndef TC3_UNILIB_JAVAICU +TEST_F(UniLibTest, IntegerParse) { int result; EXPECT_TRUE( - unilib.ParseInt32(UTF8ToUnicodeText("123", /*do_copy=*/false), &result)); + unilib_.ParseInt32(UTF8ToUnicodeText("123", /*do_copy=*/false), &result)); EXPECT_EQ(result, 123); } -#endif // ndef LIBTEXTCLASSIFIER_UNILIB_JAVAICU +#endif // ndef TC3_UNILIB_JAVAICU -#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU -TEST(UniLibTest, IntegerParseFullWidth) { - CREATE_UNILIB_FOR_TESTING; +#ifdef TC3_UNILIB_ICU +TEST_F(UniLibTest, IntegerParseFullWidth) { int result; // The input string here is full width - EXPECT_TRUE(unilib.ParseInt32(UTF8ToUnicodeText("123", /*do_copy=*/false), - &result)); + EXPECT_TRUE(unilib_.ParseInt32(UTF8ToUnicodeText("123", /*do_copy=*/false), + &result)); EXPECT_EQ(result, 123); } -#endif // LIBTEXTCLASSIFIER_UNILIB_ICU +#endif // TC3_UNILIB_ICU -#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU -TEST(UniLibTest, IntegerParseFullWidthWithAlpha) { - CREATE_UNILIB_FOR_TESTING; +#ifdef TC3_UNILIB_ICU +TEST_F(UniLibTest, IntegerParseFullWidthWithAlpha) { int result; // The input string here is full width - EXPECT_FALSE(unilib.ParseInt32(UTF8ToUnicodeText("1a3", /*do_copy=*/false), - &result)); + EXPECT_FALSE(unilib_.ParseInt32(UTF8ToUnicodeText("1a3", /*do_copy=*/false), + &result)); } -#endif // LIBTEXTCLASSIFIER_UNILIB_ICU +#endif // TC3_UNILIB_ICU } // namespace -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 diff --git a/utils/variant.h b/utils/variant.h new file mode 100644 index 0000000..ddb0d60 --- /dev/null +++ b/utils/variant.h @@ -0,0 +1,64 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_UTILS_VARIANT_H_ +#define LIBTEXTCLASSIFIER_UTILS_VARIANT_H_ + +#include <string> + +#include "utils/base/integral_types.h" +#include "utils/strings/stringpiece.h" + +namespace libtextclassifier3 { + +// Represents a type-tagged union of different basic types. +struct Variant { + Variant() : type(TYPE_INVALID) {} + explicit Variant(int value) : type(TYPE_INT_VALUE), int_value(value) {} + explicit Variant(int64 value) : type(TYPE_LONG_VALUE), long_value(value) {} + explicit Variant(float value) : type(TYPE_FLOAT_VALUE), float_value(value) {} + explicit Variant(double value) + : type(TYPE_DOUBLE_VALUE), double_value(value) {} + explicit Variant(StringPiece value) + : type(TYPE_STRING_VALUE), string_value(value.ToString()) {} + explicit Variant(std::string value) + : type(TYPE_STRING_VALUE), string_value(value) {} + explicit Variant(const char* value) + : type(TYPE_STRING_VALUE), string_value(value) {} + explicit Variant(bool value) : type(TYPE_BOOL_VALUE), bool_value(value) {} + enum Type { + TYPE_INVALID = 0, + TYPE_INT_VALUE = 1, + TYPE_LONG_VALUE = 2, + TYPE_FLOAT_VALUE = 3, + TYPE_DOUBLE_VALUE = 4, + TYPE_BOOL_VALUE = 5, + TYPE_STRING_VALUE = 6, + }; + Type type; + union { + int int_value; + int64 long_value; + float float_value; + double double_value; + bool bool_value; + }; + std::string string_value; +}; + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_UTILS_VARIANT_H_ diff --git a/utils/zlib/buffer.fbs b/utils/zlib/buffer.fbs new file mode 100755 index 0000000..60da23e --- /dev/null +++ b/utils/zlib/buffer.fbs @@ -0,0 +1,22 @@ +// +// Copyright (C) 2018 The Android Open Source Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +namespace libtextclassifier3; +table CompressedBuffer { + buffer:[ubyte]; + uncompressed_size:int; +} + diff --git a/utils/zlib/zlib.cc b/utils/zlib/zlib.cc new file mode 100644 index 0000000..e9991e0 --- /dev/null +++ b/utils/zlib/zlib.cc @@ -0,0 +1,174 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/zlib/zlib.h" + +#include <memory> + +#include "utils/base/logging.h" +#include "utils/flatbuffers.h" + +namespace libtextclassifier3 { + +std::unique_ptr<ZlibDecompressor> ZlibDecompressor::Instance() { + std::unique_ptr<ZlibDecompressor> result(new ZlibDecompressor()); + if (!result->initialized_) { + result.reset(); + } + return result; +} + +ZlibDecompressor::ZlibDecompressor() { + memset(&stream_, 0, sizeof(stream_)); + stream_.zalloc = Z_NULL; + stream_.zfree = Z_NULL; + initialized_ = (inflateInit(&stream_) == Z_OK); +} + +ZlibDecompressor::~ZlibDecompressor() { + if (initialized_) { + inflateEnd(&stream_); + } +} + +bool ZlibDecompressor::Decompress(const uint8* buffer, const int buffer_size, + const int uncompressed_size, + std::string* out) { + if (out == nullptr) { + return false; + } + out->resize(uncompressed_size); + stream_.next_in = reinterpret_cast<const Bytef*>(buffer); + stream_.avail_in = buffer_size; + stream_.next_out = reinterpret_cast<Bytef*>(const_cast<char*>(out->c_str())); + stream_.avail_out = uncompressed_size; + return (inflate(&stream_, Z_SYNC_FLUSH) == Z_OK); +} + +bool ZlibDecompressor::MaybeDecompress( + const CompressedBuffer* compressed_buffer, std::string* out) { + if (!compressed_buffer) { + return true; + } + return Decompress(compressed_buffer->buffer()->Data(), + compressed_buffer->buffer()->size(), + compressed_buffer->uncompressed_size(), out); +} + +bool ZlibDecompressor::MaybeDecompress( + const CompressedBufferT* compressed_buffer, std::string* out) { + if (!compressed_buffer) { + return true; + } + return Decompress(compressed_buffer->buffer.data(), + compressed_buffer->buffer.size(), + compressed_buffer->uncompressed_size, out); +} + +std::unique_ptr<ZlibCompressor> ZlibCompressor::Instance() { + std::unique_ptr<ZlibCompressor> result(new ZlibCompressor()); + if (!result->initialized_) { + result.reset(); + } + return result; +} + +ZlibCompressor::ZlibCompressor(int level, int tmp_buffer_size) { + memset(&stream_, 0, sizeof(stream_)); + stream_.zalloc = Z_NULL; + stream_.zfree = Z_NULL; + buffer_size_ = tmp_buffer_size; + buffer_.reset(new Bytef[buffer_size_]); + initialized_ = (deflateInit(&stream_, level) == Z_OK); +} + +ZlibCompressor::~ZlibCompressor() { deflateEnd(&stream_); } + +void ZlibCompressor::Compress(const std::string& uncompressed_content, + CompressedBufferT* out) { + out->uncompressed_size = uncompressed_content.size(); + out->buffer.clear(); + stream_.next_in = + reinterpret_cast<const Bytef*>(uncompressed_content.c_str()); + stream_.avail_in = uncompressed_content.size(); + stream_.next_out = buffer_.get(); + stream_.avail_out = buffer_size_; + unsigned char* buffer_deflate_start_position = + reinterpret_cast<unsigned char*>(buffer_.get()); + int status; + do { + // Deflate chunk-wise. + // Z_SYNC_FLUSH causes all pending output to be flushed, but doesn't + // reset the compression state. + // As we do not know how big the compressed buffer will be, we compress + // chunk wise and append the flushed content to the output string buffer. + // As we store the uncompressed size, we do not have to do this during + // decompression. + status = deflate(&stream_, Z_SYNC_FLUSH); + unsigned char* buffer_deflate_end_position = + reinterpret_cast<unsigned char*>(stream_.next_out); + if (buffer_deflate_end_position != buffer_deflate_start_position) { + out->buffer.insert(out->buffer.end(), buffer_deflate_start_position, + buffer_deflate_end_position); + stream_.next_out = buffer_deflate_start_position; + stream_.avail_out = buffer_size_; + } else { + break; + } + } while (status == Z_OK); +} + +std::unique_ptr<UniLib::RegexPattern> UncompressMakeRegexPattern( + const UniLib& unilib, const flatbuffers::String* uncompressed_pattern, + const CompressedBuffer* compressed_pattern, ZlibDecompressor* decompressor, + std::string* result_pattern_text) { + UnicodeText unicode_regex_pattern; + std::string decompressed_pattern; + if (compressed_pattern != nullptr && + compressed_pattern->buffer() != nullptr) { + if (decompressor == nullptr || + !decompressor->MaybeDecompress(compressed_pattern, + &decompressed_pattern)) { + TC3_LOG(ERROR) << "Cannot decompress pattern."; + return nullptr; + } + unicode_regex_pattern = + UTF8ToUnicodeText(decompressed_pattern.data(), + decompressed_pattern.size(), /*do_copy=*/false); + } else { + if (uncompressed_pattern == nullptr) { + TC3_LOG(ERROR) << "Cannot load uncompressed pattern."; + return nullptr; + } + unicode_regex_pattern = + UTF8ToUnicodeText(uncompressed_pattern->c_str(), + uncompressed_pattern->Length(), /*do_copy=*/false); + } + + if (result_pattern_text != nullptr) { + *result_pattern_text = unicode_regex_pattern.ToUTF8String(); + } + + std::unique_ptr<UniLib::RegexPattern> regex_pattern = + unilib.CreateRegexPattern(unicode_regex_pattern); + if (!regex_pattern) { + TC3_LOG(ERROR) << "Could not create pattern: " + << unicode_regex_pattern.ToUTF8String(); + } + return regex_pattern; +} + +} // namespace libtextclassifier3 diff --git a/zlib-utils.h b/utils/zlib/zlib.h index 136f4d2..d93527e 100644 --- a/zlib-utils.h +++ b/utils/zlib/zlib.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017 The Android Open Source Project + * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,23 +16,28 @@ // Functions to compress and decompress low entropy entries in the model. -#ifndef LIBTEXTCLASSIFIER_ZLIB_UTILS_H_ -#define LIBTEXTCLASSIFIER_ZLIB_UTILS_H_ +#ifndef LIBTEXTCLASSIFIER_UTILS_ZLIB_ZLIB_H_ +#define LIBTEXTCLASSIFIER_UTILS_ZLIB_ZLIB_H_ #include <memory> -#include "model_generated.h" -#include "util/utf8/unilib.h" -#include "zlib.h" +#include "utils/utf8/unilib.h" +#include "utils/zlib/buffer_generated.h" +#include <zlib.h> -namespace libtextclassifier2 { +namespace libtextclassifier3 { class ZlibDecompressor { public: static std::unique_ptr<ZlibDecompressor> Instance(); ~ZlibDecompressor(); - bool Decompress(const CompressedBuffer* compressed_buffer, std::string* out); + bool Decompress(const uint8* buffer, const int buffer_size, + const int uncompressed_size, std::string* out); + bool MaybeDecompress(const CompressedBuffer* compressed_buffer, + std::string* out); + bool MaybeDecompress(const CompressedBufferT* compressed_buffer, + std::string* out); private: ZlibDecompressor(); @@ -59,21 +64,12 @@ class ZlibCompressor { bool initialized_; }; -// Compresses regex and datetime rules in the model in place. -bool CompressModel(ModelT* model); - -// Decompresses regex and datetime rules in the model in place. -bool DecompressModel(ModelT* model); - -// Compresses regex and datetime rules in the model. -std::string CompressSerializedModel(const std::string& model); - // Create and compile a regex pattern from optionally compressed pattern. std::unique_ptr<UniLib::RegexPattern> UncompressMakeRegexPattern( const UniLib& unilib, const flatbuffers::String* uncompressed_pattern, const CompressedBuffer* compressed_pattern, ZlibDecompressor* decompressor, std::string* result_pattern_text = nullptr); -} // namespace libtextclassifier2 +} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_ZLIB_UTILS_H_ +#endif // LIBTEXTCLASSIFIER_UTILS_ZLIB_ZLIB_H_ diff --git a/zlib-utils.cc b/zlib-utils.cc deleted file mode 100644 index 7e6646f..0000000 --- a/zlib-utils.cc +++ /dev/null @@ -1,269 +0,0 @@ -/* - * Copyright (C) 2017 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "zlib-utils.h" - -#include <memory> - -#include "util/base/logging.h" -#include "util/flatbuffers.h" - -namespace libtextclassifier2 { - -std::unique_ptr<ZlibDecompressor> ZlibDecompressor::Instance() { - std::unique_ptr<ZlibDecompressor> result(new ZlibDecompressor()); - if (!result->initialized_) { - result.reset(); - } - return result; -} - -ZlibDecompressor::ZlibDecompressor() { - memset(&stream_, 0, sizeof(stream_)); - stream_.zalloc = Z_NULL; - stream_.zfree = Z_NULL; - initialized_ = (inflateInit(&stream_) == Z_OK); -} - -ZlibDecompressor::~ZlibDecompressor() { - if (initialized_) { - inflateEnd(&stream_); - } -} - -bool ZlibDecompressor::Decompress(const CompressedBuffer* compressed_buffer, - std::string* out) { - out->resize(compressed_buffer->uncompressed_size()); - stream_.next_in = - reinterpret_cast<const Bytef*>(compressed_buffer->buffer()->Data()); - stream_.avail_in = compressed_buffer->buffer()->Length(); - stream_.next_out = reinterpret_cast<Bytef*>(const_cast<char*>(out->c_str())); - stream_.avail_out = compressed_buffer->uncompressed_size(); - return (inflate(&stream_, Z_SYNC_FLUSH) == Z_OK); -} - -std::unique_ptr<ZlibCompressor> ZlibCompressor::Instance() { - std::unique_ptr<ZlibCompressor> result(new ZlibCompressor()); - if (!result->initialized_) { - result.reset(); - } - return result; -} - -ZlibCompressor::ZlibCompressor(int level, int tmp_buffer_size) { - memset(&stream_, 0, sizeof(stream_)); - stream_.zalloc = Z_NULL; - stream_.zfree = Z_NULL; - buffer_size_ = tmp_buffer_size; - buffer_.reset(new Bytef[buffer_size_]); - initialized_ = (deflateInit(&stream_, level) == Z_OK); -} - -ZlibCompressor::~ZlibCompressor() { deflateEnd(&stream_); } - -void ZlibCompressor::Compress(const std::string& uncompressed_content, - CompressedBufferT* out) { - out->uncompressed_size = uncompressed_content.size(); - out->buffer.clear(); - stream_.next_in = - reinterpret_cast<const Bytef*>(uncompressed_content.c_str()); - stream_.avail_in = uncompressed_content.size(); - stream_.next_out = buffer_.get(); - stream_.avail_out = buffer_size_; - unsigned char* buffer_deflate_start_position = - reinterpret_cast<unsigned char*>(buffer_.get()); - int status; - do { - // Deflate chunk-wise. - // Z_SYNC_FLUSH causes all pending output to be flushed, but doesn't - // reset the compression state. - // As we do not know how big the compressed buffer will be, we compress - // chunk wise and append the flushed content to the output string buffer. - // As we store the uncompressed size, we do not have to do this during - // decompression. - status = deflate(&stream_, Z_SYNC_FLUSH); - unsigned char* buffer_deflate_end_position = - reinterpret_cast<unsigned char*>(stream_.next_out); - if (buffer_deflate_end_position != buffer_deflate_start_position) { - out->buffer.insert(out->buffer.end(), buffer_deflate_start_position, - buffer_deflate_end_position); - stream_.next_out = buffer_deflate_start_position; - stream_.avail_out = buffer_size_; - } else { - break; - } - } while (status == Z_OK); -} - -// Compress rule fields in the model. -bool CompressModel(ModelT* model) { - std::unique_ptr<ZlibCompressor> zlib_compressor = ZlibCompressor::Instance(); - if (!zlib_compressor) { - TC_LOG(ERROR) << "Cannot compress model."; - return false; - } - - // Compress regex rules. - if (model->regex_model != nullptr) { - for (int i = 0; i < model->regex_model->patterns.size(); i++) { - RegexModel_::PatternT* pattern = model->regex_model->patterns[i].get(); - pattern->compressed_pattern.reset(new CompressedBufferT); - zlib_compressor->Compress(pattern->pattern, - pattern->compressed_pattern.get()); - pattern->pattern.clear(); - } - } - - // Compress date-time rules. - if (model->datetime_model != nullptr) { - for (int i = 0; i < model->datetime_model->patterns.size(); i++) { - DatetimeModelPatternT* pattern = model->datetime_model->patterns[i].get(); - for (int j = 0; j < pattern->regexes.size(); j++) { - DatetimeModelPattern_::RegexT* regex = pattern->regexes[j].get(); - regex->compressed_pattern.reset(new CompressedBufferT); - zlib_compressor->Compress(regex->pattern, - regex->compressed_pattern.get()); - regex->pattern.clear(); - } - } - for (int i = 0; i < model->datetime_model->extractors.size(); i++) { - DatetimeModelExtractorT* extractor = - model->datetime_model->extractors[i].get(); - extractor->compressed_pattern.reset(new CompressedBufferT); - zlib_compressor->Compress(extractor->pattern, - extractor->compressed_pattern.get()); - extractor->pattern.clear(); - } - } - return true; -} - -namespace { - -bool DecompressBuffer(const CompressedBufferT* compressed_pattern, - ZlibDecompressor* zlib_decompressor, - std::string* uncompressed_pattern) { - std::string packed_pattern = - PackFlatbuffer<CompressedBuffer>(compressed_pattern); - if (!zlib_decompressor->Decompress( - LoadAndVerifyFlatbuffer<CompressedBuffer>(packed_pattern), - uncompressed_pattern)) { - return false; - } - return true; -} - -} // namespace - -bool DecompressModel(ModelT* model) { - std::unique_ptr<ZlibDecompressor> zlib_decompressor = - ZlibDecompressor::Instance(); - if (!zlib_decompressor) { - TC_LOG(ERROR) << "Cannot initialize decompressor."; - return false; - } - - // Decompress regex rules. - if (model->regex_model != nullptr) { - for (int i = 0; i < model->regex_model->patterns.size(); i++) { - RegexModel_::PatternT* pattern = model->regex_model->patterns[i].get(); - if (!DecompressBuffer(pattern->compressed_pattern.get(), - zlib_decompressor.get(), &pattern->pattern)) { - TC_LOG(ERROR) << "Cannot decompress pattern: " << i; - return false; - } - pattern->compressed_pattern.reset(nullptr); - } - } - - // Decompress date-time rules. - if (model->datetime_model != nullptr) { - for (int i = 0; i < model->datetime_model->patterns.size(); i++) { - DatetimeModelPatternT* pattern = model->datetime_model->patterns[i].get(); - for (int j = 0; j < pattern->regexes.size(); j++) { - DatetimeModelPattern_::RegexT* regex = pattern->regexes[j].get(); - if (!DecompressBuffer(regex->compressed_pattern.get(), - zlib_decompressor.get(), ®ex->pattern)) { - TC_LOG(ERROR) << "Cannot decompress pattern: " << i << " " << j; - return false; - } - regex->compressed_pattern.reset(nullptr); - } - } - for (int i = 0; i < model->datetime_model->extractors.size(); i++) { - DatetimeModelExtractorT* extractor = - model->datetime_model->extractors[i].get(); - if (!DecompressBuffer(extractor->compressed_pattern.get(), - zlib_decompressor.get(), &extractor->pattern)) { - TC_LOG(ERROR) << "Cannot decompress pattern: " << i; - return false; - } - extractor->compressed_pattern.reset(nullptr); - } - } - return true; -} - -std::string CompressSerializedModel(const std::string& model) { - std::unique_ptr<ModelT> unpacked_model = UnPackModel(model.c_str()); - TC_CHECK(unpacked_model != nullptr); - TC_CHECK(CompressModel(unpacked_model.get())); - flatbuffers::FlatBufferBuilder builder; - FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); - return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()), - builder.GetSize()); -} - -std::unique_ptr<UniLib::RegexPattern> UncompressMakeRegexPattern( - const UniLib& unilib, const flatbuffers::String* uncompressed_pattern, - const CompressedBuffer* compressed_pattern, ZlibDecompressor* decompressor, - std::string* result_pattern_text) { - UnicodeText unicode_regex_pattern; - std::string decompressed_pattern; - if (compressed_pattern != nullptr && - compressed_pattern->buffer() != nullptr) { - if (decompressor == nullptr || - !decompressor->Decompress(compressed_pattern, &decompressed_pattern)) { - TC_LOG(ERROR) << "Cannot decompress pattern."; - return nullptr; - } - unicode_regex_pattern = - UTF8ToUnicodeText(decompressed_pattern.data(), - decompressed_pattern.size(), /*do_copy=*/false); - } else { - if (uncompressed_pattern == nullptr) { - TC_LOG(ERROR) << "Cannot load uncompressed pattern."; - return nullptr; - } - unicode_regex_pattern = - UTF8ToUnicodeText(uncompressed_pattern->c_str(), - uncompressed_pattern->Length(), /*do_copy=*/false); - } - - if (result_pattern_text != nullptr) { - *result_pattern_text = unicode_regex_pattern.ToUTF8String(); - } - - std::unique_ptr<UniLib::RegexPattern> regex_pattern = - unilib.CreateRegexPattern(unicode_regex_pattern); - if (!regex_pattern) { - TC_LOG(ERROR) << "Could not create pattern: " - << unicode_regex_pattern.ToUTF8String(); - } - return regex_pattern; -} - -} // namespace libtextclassifier2 |