summary | refs | log | tree | commit | diff
path: root/native/annotator
diff options
context:
space:
mode:
author Tony Mak <tonymak@google.com> 2019-10-15 15:29:22 +0100
committer Tony Mak <tonymak@google.com> 2019-10-15 18:33:02 +0100
commit 8cd7ba6be23c557c608653330d931e6700f19688 (patch)
tree edaf9e6b6a52c247985af97ac94333b05cfb85bf /native/annotator
parent f9143f3090d0e29353d0841f3e892babc947b4d2 (diff)
download libtextclassifier-8cd7ba6be23c557c608653330d931e6700f19688.tar.gz
Import libtextclassifier
Test: atest TextClassifierServiceTest
Change-Id: Ief715193072d0af3aea230c3c9475ef18e8ac84c
Diffstat (limited to 'native/annotator')
-rw-r--r-- native/annotator/annotator.cc 22
-rw-r--r-- native/annotator/annotator_jni.cc 387
-rw-r--r-- native/annotator/annotator_jni_common.cc 220
-rw-r--r-- native/annotator/annotator_jni_common.h 11
-rw-r--r-- native/annotator/datetime/parser.cc 7
-rw-r--r-- native/annotator/datetime/parser.h 7
-rw-r--r-- native/annotator/datetime/parser_test.cc 48
-rw-r--r-- native/annotator/duration/duration.cc 50
-rw-r--r-- native/annotator/duration/duration.h 18
-rw-r--r-- native/annotator/duration/duration_test.cc 79
-rwxr-xr-x native/annotator/entity-data.fbs 43
-rw-r--r-- native/annotator/flatbuffer-utils.cc 65
-rw-r--r-- native/annotator/flatbuffer-utils.h 38
-rw-r--r-- native/annotator/knowledge/knowledge-engine-dummy.h 2
-rwxr-xr-x native/annotator/model.fbs 14
-rw-r--r-- native/annotator/number/number.cc 145
-rw-r--r-- native/annotator/number/number.h 11
-rw-r--r-- native/annotator/test_data/test_model.fb bin 654952 -> 656008 bytes
-rw-r--r-- native/annotator/test_data/wrong_embeddings.fb bin 393692 -> 394256 bytes
-rw-r--r-- native/annotator/types.cc 69
-rw-r--r-- native/annotator/types.h 2
-rw-r--r-- native/annotator/zlib-utils.cc 9
-rw-r--r-- native/annotator/zlib-utils_test.cc 53
23 files changed, 916 insertions, 384 deletions
diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc
index 867eea0..d910c8d 100644
--- a/native/annotator/annotator.cc
+++ b/native/annotator/annotator.cc
@@ -29,6 +29,7 @@
#include "utils/base/logging.h"
#include "utils/checksum.h"
#include "utils/math/softmax.h"
+#include "utils/normalization.h"
#include "utils/optional.h"
#include "utils/regex-match.h"
#include "utils/utf8/unicodetext.h"
@@ -416,7 +417,7 @@ void Annotator::ValidateAndInitialize() {
model_->duration_annotator_options()->enabled()) {
duration_annotator_.reset(
new DurationAnnotator(model_->duration_annotator_options(),
- selection_feature_processor_.get()));
+ selection_feature_processor_.get(), unilib_));
}
if (model_->entity_data_schema()) {
@@ -505,6 +506,10 @@ bool Annotator::InitializeKnowledgeEngine(
TC3_LOG(ERROR) << "Failed to initialize the knowledge engine.";
return false;
}
+ if (model_->triggering_options() != nullptr) {
+ knowledge_engine->SetPriorityScore(
+ model_->triggering_options()->knowledge_priority_score());
+ }
knowledge_engine_ = std::move(knowledge_engine);
return true;
}
@@ -2075,8 +2080,19 @@ bool Annotator::SerializedEntityDataFromRegexMatch(
// Set entity field from capturing group text.
if (group->entity_field_path() != nullptr) {
- if (!entity_data->ParseAndSet(group->entity_field_path(),
- group_match_text.value())) {
+ UnicodeText normalized_group_match_text =
+ UTF8ToUnicodeText(group_match_text.value(), /*do_copy=*/false);
+
+ // Apply normalization if specified.
+ if (group->normalization_options() != nullptr) {
+ normalized_group_match_text =
+ NormalizeText(unilib_, group->normalization_options(),
+ normalized_group_match_text);
+ }
+
+ if (!entity_data->ParseAndSet(
+ group->entity_field_path(),
+ normalized_group_match_text.ToUTF8String())) {
TC3_LOG(ERROR)
<< "Could not set entity data from rule capturing group.";
return false;
diff --git a/native/annotator/annotator_jni.cc b/native/annotator/annotator_jni.cc
index e5b7833..28be366 100644
--- a/native/annotator/annotator_jni.cc
+++ b/native/annotator/annotator_jni.cc
@@ -19,6 +19,7 @@
#include "annotator/annotator_jni.h"
#include <jni.h>
+
#include <type_traits>
#include <vector>
@@ -26,11 +27,12 @@
#include "annotator/annotator_jni_common.h"
#include "annotator/types.h"
#include "utils/base/integral_types.h"
+#include "utils/base/statusor.h"
#include "utils/calendar/calendar.h"
#include "utils/intents/intent-generator.h"
#include "utils/intents/jni.h"
#include "utils/java/jni-cache.h"
-#include "utils/java/scoped_local_ref.h"
+#include "utils/java/jni-helper.h"
#include "utils/java/string_utils.h"
#include "utils/memory/mmap.h"
#include "utils/strings/stringpiece.h"
@@ -48,8 +50,10 @@ using libtextclassifier3::AnnotatedSpan;
using libtextclassifier3::Annotator;
using libtextclassifier3::ClassificationResult;
using libtextclassifier3::CodepointSpan;
+using libtextclassifier3::JniHelper;
using libtextclassifier3::Model;
using libtextclassifier3::ScopedLocalRef;
+using libtextclassifier3::StatusOr;
// When using the Java's ICU, CalendarLib and UniLib need to be instantiated
// with a JavaVM pointer from JNI. When using a standard ICU the pointer is
// not needed and the objects are instantiated implicitly.
@@ -71,6 +75,7 @@ class AnnotatorJniContext {
if (jni_cache == nullptr || model == nullptr) {
return nullptr;
}
+ // Intent generator will be null if the options are not specified.
std::unique_ptr<IntentGenerator> intent_generator =
IntentGenerator::Create(model->model()->intent_options(),
model->model()->resources(), jni_cache);
@@ -79,6 +84,7 @@ class AnnotatorJniContext {
if (template_handler == nullptr) {
return nullptr;
}
+
return new AnnotatorJniContext(jni_cache, std::move(model),
std::move(intent_generator),
std::move(template_handler));
@@ -90,6 +96,8 @@ class AnnotatorJniContext {
Annotator* model() const { return model_.get(); }
+ // NOTE: Intent generator will be null if the options are not specified in
+ // the model.
IntentGenerator* intent_generator() const { return intent_generator_.get(); }
RemoteActionTemplatesHandler* template_handler() const {
@@ -113,184 +121,217 @@ class AnnotatorJniContext {
std::unique_ptr<RemoteActionTemplatesHandler> template_handler_;
};
-jobject ClassificationResultWithIntentsToJObject(
+StatusOr<ScopedLocalRef<jobject>> ClassificationResultWithIntentsToJObject(
JNIEnv* env, const AnnotatorJniContext* model_context, jobject app_context,
jclass result_class, jmethodID result_class_constructor,
jclass datetime_parse_class, jmethodID datetime_parse_class_constructor,
const jstring device_locales, const ClassificationOptions* options,
const std::string& context, const CodepointSpan& selection_indices,
const ClassificationResult& classification_result, bool generate_intents) {
- jstring row_string =
- env->NewStringUTF(classification_result.collection.c_str());
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jstring> row_string,
+ JniHelper::NewStringUTF(env, classification_result.collection.c_str()));
- jobject row_datetime_parse = nullptr;
+ ScopedLocalRef<jobject> row_datetime_parse;
if (classification_result.datetime_parse_result.IsSet()) {
- row_datetime_parse =
- env->NewObject(datetime_parse_class, datetime_parse_class_constructor,
- classification_result.datetime_parse_result.time_ms_utc,
- classification_result.datetime_parse_result.granularity);
+ TC3_ASSIGN_OR_RETURN(
+ row_datetime_parse,
+ JniHelper::NewObject(
+ env, datetime_parse_class, datetime_parse_class_constructor,
+ classification_result.datetime_parse_result.time_ms_utc,
+ classification_result.datetime_parse_result.granularity));
}
- jbyteArray serialized_knowledge_result = nullptr;
+ ScopedLocalRef<jbyteArray> serialized_knowledge_result;
const std::string& serialized_knowledge_result_string =
classification_result.serialized_knowledge_result;
if (!serialized_knowledge_result_string.empty()) {
- serialized_knowledge_result =
- env->NewByteArray(serialized_knowledge_result_string.size());
- env->SetByteArrayRegion(serialized_knowledge_result, 0,
+ TC3_ASSIGN_OR_RETURN(serialized_knowledge_result,
+ JniHelper::NewByteArray(
+ env, serialized_knowledge_result_string.size()));
+ env->SetByteArrayRegion(serialized_knowledge_result.get(), 0,
serialized_knowledge_result_string.size(),
reinterpret_cast<const jbyte*>(
serialized_knowledge_result_string.data()));
}
- jstring contact_name = nullptr;
+ ScopedLocalRef<jstring> contact_name;
if (!classification_result.contact_name.empty()) {
- contact_name =
- env->NewStringUTF(classification_result.contact_name.c_str());
+ TC3_ASSIGN_OR_RETURN(contact_name,
+ JniHelper::NewStringUTF(
+ env, classification_result.contact_name.c_str()));
}
- jstring contact_given_name = nullptr;
+ ScopedLocalRef<jstring> contact_given_name;
if (!classification_result.contact_given_name.empty()) {
- contact_given_name =
- env->NewStringUTF(classification_result.contact_given_name.c_str());
+ TC3_ASSIGN_OR_RETURN(
+ contact_given_name,
+ JniHelper::NewStringUTF(
+ env, classification_result.contact_given_name.c_str()));
}
- jstring contact_family_name = nullptr;
+ ScopedLocalRef<jstring> contact_family_name;
if (!classification_result.contact_family_name.empty()) {
- contact_family_name =
- env->NewStringUTF(classification_result.contact_family_name.c_str());
+ TC3_ASSIGN_OR_RETURN(
+ contact_family_name,
+ JniHelper::NewStringUTF(
+ env, classification_result.contact_family_name.c_str()));
}
- jstring contact_nickname = nullptr;
+ ScopedLocalRef<jstring> contact_nickname;
if (!classification_result.contact_nickname.empty()) {
- contact_nickname =
- env->NewStringUTF(classification_result.contact_nickname.c_str());
+ TC3_ASSIGN_OR_RETURN(
+ contact_nickname,
+ JniHelper::NewStringUTF(
+ env, classification_result.contact_nickname.c_str()));
}
- jstring contact_email_address = nullptr;
+ ScopedLocalRef<jstring> contact_email_address;
if (!classification_result.contact_email_address.empty()) {
- contact_email_address =
- env->NewStringUTF(classification_result.contact_email_address.c_str());
+ TC3_ASSIGN_OR_RETURN(
+ contact_email_address,
+ JniHelper::NewStringUTF(
+ env, classification_result.contact_email_address.c_str()));
}
- jstring contact_phone_number = nullptr;
+ ScopedLocalRef<jstring> contact_phone_number;
if (!classification_result.contact_phone_number.empty()) {
- contact_phone_number =
- env->NewStringUTF(classification_result.contact_phone_number.c_str());
+ TC3_ASSIGN_OR_RETURN(
+ contact_phone_number,
+ JniHelper::NewStringUTF(
+ env, classification_result.contact_phone_number.c_str()));
}
- jstring contact_id = nullptr;
+ ScopedLocalRef<jstring> contact_id;
if (!classification_result.contact_id.empty()) {
- contact_id = env->NewStringUTF(classification_result.contact_id.c_str());
+ TC3_ASSIGN_OR_RETURN(
+ contact_id,
+ JniHelper::NewStringUTF(env, classification_result.contact_id.c_str()));
}
- jstring app_name = nullptr;
+ ScopedLocalRef<jstring> app_name;
if (!classification_result.app_name.empty()) {
- app_name = env->NewStringUTF(classification_result.app_name.c_str());
+ TC3_ASSIGN_OR_RETURN(
+ app_name,
+ JniHelper::NewStringUTF(env, classification_result.app_name.c_str()));
}
- jstring app_package_name = nullptr;
+ ScopedLocalRef<jstring> app_package_name;
if (!classification_result.app_package_name.empty()) {
- app_package_name =
- env->NewStringUTF(classification_result.app_package_name.c_str());
+ TC3_ASSIGN_OR_RETURN(
+ app_package_name,
+ JniHelper::NewStringUTF(
+ env, classification_result.app_package_name.c_str()));
}
- jobject extras = nullptr;
+ ScopedLocalRef<jobjectArray> extras;
if (model_context->model()->entity_data_schema() != nullptr &&
!classification_result.serialized_entity_data.empty()) {
- extras = model_context->template_handler()->EntityDataAsNamedVariantArray(
- model_context->model()->entity_data_schema(),
- classification_result.serialized_entity_data);
+ TC3_ASSIGN_OR_RETURN(
+ extras,
+ model_context->template_handler()->EntityDataAsNamedVariantArray(
+ model_context->model()->entity_data_schema(),
+ classification_result.serialized_entity_data));
}
- jbyteArray serialized_entity_data = nullptr;
+ ScopedLocalRef<jbyteArray> serialized_entity_data;
if (!classification_result.serialized_entity_data.empty()) {
- serialized_entity_data =
- env->NewByteArray(classification_result.serialized_entity_data.size());
+ TC3_ASSIGN_OR_RETURN(
+ serialized_entity_data,
+ JniHelper::NewByteArray(
+ env, classification_result.serialized_entity_data.size()));
env->SetByteArrayRegion(
- serialized_entity_data, 0,
+ serialized_entity_data.get(), 0,
classification_result.serialized_entity_data.size(),
reinterpret_cast<const jbyte*>(
classification_result.serialized_entity_data.data()));
}
- jobject remote_action_templates_result = nullptr;
+ ScopedLocalRef<jobjectArray> remote_action_templates_result;
// Only generate RemoteActionTemplate for the top classification result
// as classifyText does not need RemoteAction from other results anyway.
if (generate_intents && model_context->intent_generator() != nullptr) {
std::vector<RemoteActionTemplate> remote_action_templates;
- if (model_context->intent_generator()->GenerateIntents(
+ if (!model_context->intent_generator()->GenerateIntents(
device_locales, classification_result,
options->reference_time_ms_utc, context, selection_indices,
app_context, model_context->model()->entity_data_schema(),
&remote_action_templates)) {
- remote_action_templates_result =
- model_context->template_handler()
- ->RemoteActionTemplatesToJObjectArray(remote_action_templates);
+ return {Status::UNKNOWN};
}
- }
- return env->NewObject(
- result_class, result_class_constructor, row_string,
- static_cast<jfloat>(classification_result.score), row_datetime_parse,
- serialized_knowledge_result, contact_name, contact_given_name,
- contact_family_name, contact_nickname, contact_email_address,
- contact_phone_number, contact_id, app_name, app_package_name, extras,
- serialized_entity_data, remote_action_templates_result,
- classification_result.duration_ms, classification_result.numeric_value,
+ TC3_ASSIGN_OR_RETURN(
+ remote_action_templates_result,
+ model_context->template_handler()->RemoteActionTemplatesToJObjectArray(
+ remote_action_templates));
+ }
+
+ return JniHelper::NewObject(
+ env, result_class, result_class_constructor, row_string.get(),
+ static_cast<jfloat>(classification_result.score),
+ row_datetime_parse.get(), serialized_knowledge_result.get(),
+ contact_name.get(), contact_given_name.get(), contact_family_name.get(),
+ contact_nickname.get(), contact_email_address.get(),
+ contact_phone_number.get(), contact_id.get(), app_name.get(),
+ app_package_name.get(), extras.get(), serialized_entity_data.get(),
+ remote_action_templates_result.get(), classification_result.duration_ms,
+ classification_result.numeric_value,
classification_result.numeric_double_value);
}
-jobjectArray ClassificationResultsWithIntentsToJObjectArray(
+StatusOr<ScopedLocalRef<jobjectArray>>
+ClassificationResultsWithIntentsToJObjectArray(
JNIEnv* env, const AnnotatorJniContext* model_context, jobject app_context,
const jstring device_locales, const ClassificationOptions* options,
const std::string& context, const CodepointSpan& selection_indices,
const std::vector<ClassificationResult>& classification_result,
bool generate_intents) {
- const ScopedLocalRef<jclass> result_class(
- env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
- "$ClassificationResult"),
- env);
- if (!result_class) {
- TC3_LOG(ERROR) << "Couldn't find ClassificationResult class.";
- return nullptr;
- }
- const ScopedLocalRef<jclass> datetime_parse_class(
- env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
- "$DatetimeResult"),
- env);
- if (!datetime_parse_class) {
- TC3_LOG(ERROR) << "Couldn't find DatetimeResult class.";
- return nullptr;
- }
-
- const jmethodID result_class_constructor = env->GetMethodID(
- result_class.get(), "<init>",
- "(Ljava/lang/String;FL" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
- "$DatetimeResult;[BLjava/lang/String;Ljava/lang/String;Ljava/lang/String;"
- "Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;"
- "Ljava/lang/String;Ljava/lang/String;[L" TC3_PACKAGE_PATH
- "" TC3_NAMED_VARIANT_CLASS_NAME_STR ";[B[L" TC3_PACKAGE_PATH
- "" TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";JJD)V");
- const jmethodID datetime_parse_class_constructor =
- env->GetMethodID(datetime_parse_class.get(), "<init>", "(JI)V");
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jclass> result_class,
+ JniHelper::FindClass(env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$ClassificationResult"));
+
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jclass> datetime_parse_class,
+ JniHelper::FindClass(env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$DatetimeResult"));
+
+ TC3_ASSIGN_OR_RETURN(
+ const jmethodID result_class_constructor,
+ JniHelper::GetMethodID(
+ env, result_class.get(), "<init>",
+ "(Ljava/lang/String;FL" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$DatetimeResult;[BLjava/lang/String;Ljava/lang/String;Ljava/lang/"
+ "String;"
+ "Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/"
+ "String;"
+ "Ljava/lang/String;Ljava/lang/String;[L" TC3_PACKAGE_PATH
+ "" TC3_NAMED_VARIANT_CLASS_NAME_STR ";[B[L" TC3_PACKAGE_PATH
+ "" TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";JJD)V"));
+ TC3_ASSIGN_OR_RETURN(const jmethodID datetime_parse_class_constructor,
+ JniHelper::GetMethodID(env, datetime_parse_class.get(),
+ "<init>", "(JI)V"));
+
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobjectArray> results,
+ JniHelper::NewObjectArray(env, classification_result.size(),
+ result_class.get()));
- const jobjectArray results = env->NewObjectArray(classification_result.size(),
- result_class.get(), nullptr);
for (int i = 0; i < classification_result.size(); i++) {
- jobject result = ClassificationResultWithIntentsToJObject(
- env, model_context, app_context, result_class.get(),
- result_class_constructor, datetime_parse_class.get(),
- datetime_parse_class_constructor, device_locales, options, context,
- selection_indices, classification_result[i],
- generate_intents && (i == 0));
- env->SetObjectArrayElement(results, i, result);
- env->DeleteLocalRef(result);
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobject> result,
+ ClassificationResultWithIntentsToJObject(
+ env, model_context, app_context, result_class.get(),
+ result_class_constructor, datetime_parse_class.get(),
+ datetime_parse_class_constructor, device_locales, options, context,
+ selection_indices, classification_result[i],
+ generate_intents && (i == 0)));
+ env->SetObjectArrayElement(results.get(), i, result.get());
}
return results;
}
-jobjectArray ClassificationResultsToJObjectArray(
+StatusOr<ScopedLocalRef<jobjectArray>> ClassificationResultsToJObjectArray(
JNIEnv* env, const AnnotatorJniContext* model_context,
const std::vector<ClassificationResult>& classification_result) {
return ClassificationResultsWithIntentsToJObjectArray(
@@ -361,16 +402,18 @@ CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str,
return ConvertIndicesBMPUTF8(utf8_str, utf8_indices, /*from_utf8=*/true);
}
-jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+StatusOr<ScopedLocalRef<jstring>> GetLocalesFromMmap(
+ JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
if (!mmap->handle().ok()) {
- return env->NewStringUTF("");
+ return JniHelper::NewStringUTF(env, "");
}
const Model* model = libtextclassifier3::ViewModel(
mmap->handle().start(), mmap->handle().num_bytes());
if (!model || !model->locales()) {
- return env->NewStringUTF("");
+ return JniHelper::NewStringUTF(env, "");
}
- return env->NewStringUTF(model->locales()->c_str());
+
+ return JniHelper::NewStringUTF(env, model->locales()->c_str());
}
jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
@@ -385,16 +428,17 @@ jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
return model->version();
}
-jstring GetNameFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+StatusOr<ScopedLocalRef<jstring>> GetNameFromMmap(
+ JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
if (!mmap->handle().ok()) {
- return env->NewStringUTF("");
+ return JniHelper::NewStringUTF(env, "");
}
const Model* model = libtextclassifier3::ViewModel(
mmap->handle().start(), mmap->handle().num_bytes());
if (!model || !model->name()) {
- return env->NewStringUTF("");
+ return JniHelper::NewStringUTF(env, "");
}
- return env->NewStringUTF(model->name()->c_str());
+ return JniHelper::NewStringUTF(env, model->name()->c_str());
}
} // namespace libtextclassifier3
@@ -427,7 +471,7 @@ TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotator)
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorFromPath)
(JNIEnv* env, jobject thiz, jstring path) {
- const std::string path_str = ToStlString(env, path);
+ TC3_ASSIGN_OR_RETURN_0(const std::string path_str, ToStlString(env, path));
std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
libtextclassifier3::JniCache::Create(env));
#ifdef TC3_USE_JAVAICU
@@ -531,17 +575,22 @@ TC3_JNI_METHOD(jintArray, TC3_ANNOTATOR_CLASS_NAME, nativeSuggestSelection)
return nullptr;
}
const Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
- const std::string context_utf8 = ToStlString(env, context);
+ TC3_ASSIGN_OR_RETURN_NULL(const std::string context_utf8,
+ ToStlString(env, context));
CodepointSpan input_indices =
ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
- CodepointSpan selection = model->SuggestSelection(
- context_utf8, input_indices, FromJavaSelectionOptions(env, options));
+ TC3_ASSIGN_OR_RETURN_NULL(
+ libtextclassifier3::SelectionOptions selection_options,
+ FromJavaSelectionOptions(env, options));
+ CodepointSpan selection =
+ model->SuggestSelection(context_utf8, input_indices, selection_options);
selection = ConvertIndicesUTF8ToBMP(context_utf8, selection);
- jintArray result = env->NewIntArray(2);
- env->SetIntArrayRegion(result, 0, 1, &(std::get<0>(selection)));
- env->SetIntArrayRegion(result, 1, 1, &(std::get<1>(selection)));
- return result;
+ TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jintArray> result,
+ JniHelper::NewIntArray(env, 2));
+ env->SetIntArrayRegion(result.get(), 0, 1, &(std::get<0>(selection)));
+ env->SetIntArrayRegion(result.get(), 1, 1, &(std::get<1>(selection)));
+ return result.release();
}
TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeClassifyText)
@@ -554,23 +603,33 @@ TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeClassifyText)
const AnnotatorJniContext* model_context =
reinterpret_cast<AnnotatorJniContext*>(ptr);
- const std::string context_utf8 = ToStlString(env, context);
+ TC3_ASSIGN_OR_RETURN_NULL(const std::string context_utf8,
+ ToStlString(env, context));
const CodepointSpan input_indices =
ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
- const libtextclassifier3::ClassificationOptions classification_options =
- FromJavaClassificationOptions(env, options);
+ TC3_ASSIGN_OR_RETURN_NULL(
+ const libtextclassifier3::ClassificationOptions classification_options,
+ FromJavaClassificationOptions(env, options));
const std::vector<ClassificationResult> classification_result =
model_context->model()->ClassifyText(context_utf8, input_indices,
classification_options);
+
+ ScopedLocalRef<jobjectArray> result;
if (app_context != nullptr) {
- return ClassificationResultsWithIntentsToJObjectArray(
- env, model_context, app_context, device_locales,
- &classification_options, context_utf8, input_indices,
- classification_result,
- /*generate_intents=*/true);
- }
- return ClassificationResultsToJObjectArray(env, model_context,
- classification_result);
+ TC3_ASSIGN_OR_RETURN_NULL(
+ result, ClassificationResultsWithIntentsToJObjectArray(
+ env, model_context, app_context, device_locales,
+ &classification_options, context_utf8, input_indices,
+ classification_result,
+ /*generate_intents=*/true));
+
+ } else {
+ TC3_ASSIGN_OR_RETURN_NULL(
+ result, ClassificationResultsToJObjectArray(env, model_context,
+ classification_result));
+ }
+
+ return result.release();
}
TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotate)
@@ -580,41 +639,46 @@ TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotate)
}
const AnnotatorJniContext* model_context =
reinterpret_cast<AnnotatorJniContext*>(ptr);
- const std::string context_utf8 = ToStlString(env, context);
+ TC3_ASSIGN_OR_RETURN_NULL(const std::string context_utf8,
+ ToStlString(env, context));
+ TC3_ASSIGN_OR_RETURN_NULL(
+ libtextclassifier3::AnnotationOptions annotation_options,
+ FromJavaAnnotationOptions(env, options));
const std::vector<AnnotatedSpan> annotations =
- model_context->model()->Annotate(context_utf8,
- FromJavaAnnotationOptions(env, options));
-
- jclass result_class = env->FindClass(
- TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotatedSpan");
- if (!result_class) {
- TC3_LOG(ERROR) << "Couldn't find result class: "
- << TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
- "$AnnotatedSpan";
- return nullptr;
- }
+ model_context->model()->Annotate(context_utf8, annotation_options);
+
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jclass> result_class,
+ JniHelper::FindClass(
+ env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotatedSpan"));
jmethodID result_class_constructor =
- env->GetMethodID(result_class, "<init>",
+ env->GetMethodID(result_class.get(), "<init>",
"(II[L" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
"$ClassificationResult;)V");
- jobjectArray results =
- env->NewObjectArray(annotations.size(), result_class, nullptr);
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jobjectArray> results,
+ JniHelper::NewObjectArray(env, annotations.size(), result_class.get()));
for (int i = 0; i < annotations.size(); ++i) {
CodepointSpan span_bmp =
ConvertIndicesUTF8ToBMP(context_utf8, annotations[i].span);
- jobject result = env->NewObject(
- result_class, result_class_constructor,
- static_cast<jint>(span_bmp.first), static_cast<jint>(span_bmp.second),
+
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jobjectArray> classification_results,
ClassificationResultsToJObjectArray(env, model_context,
annotations[i].classification));
- env->SetObjectArrayElement(results, i, result);
- env->DeleteLocalRef(result);
+
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jobject> result,
+ JniHelper::NewObject(env, result_class.get(), result_class_constructor,
+ static_cast<jint>(span_bmp.first),
+ static_cast<jint>(span_bmp.second),
+ classification_results.get()));
+ env->SetObjectArrayElement(results.get(), i, result.get());
}
- env->DeleteLocalRef(result_class);
- return results;
+ return results.release();
}
TC3_JNI_METHOD(jbyteArray, TC3_ANNOTATOR_CLASS_NAME,
@@ -624,16 +688,19 @@ TC3_JNI_METHOD(jbyteArray, TC3_ANNOTATOR_CLASS_NAME,
return nullptr;
}
const Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
- const std::string id_utf8 = ToStlString(env, id);
+ TC3_ASSIGN_OR_RETURN_NULL(const std::string id_utf8, ToStlString(env, id));
std::string serialized_knowledge_result;
if (!model->LookUpKnowledgeEntity(id_utf8, &serialized_knowledge_result)) {
return nullptr;
}
- jbyteArray result = env->NewByteArray(serialized_knowledge_result.size());
+
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jbyteArray> result,
+ JniHelper::NewByteArray(env, serialized_knowledge_result.size()));
env->SetByteArrayRegion(
- result, 0, serialized_knowledge_result.size(),
+ result.get(), 0, serialized_knowledge_result.size(),
reinterpret_cast<const jbyte*>(serialized_knowledge_result.data()));
- return result;
+ return result.release();
}
TC3_JNI_METHOD(void, TC3_ANNOTATOR_CLASS_NAME, nativeCloseAnnotator)
@@ -654,14 +721,18 @@ TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocales)
(JNIEnv* env, jobject clazz, jint fd) {
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
new libtextclassifier3::ScopedMmap(fd));
- return GetLocalesFromMmap(env, mmap.get());
+ TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jstring> value,
+ GetLocalesFromMmap(env, mmap.get()));
+ return value.release();
}
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocalesWithOffset)
(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size) {
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
new libtextclassifier3::ScopedMmap(fd, offset, size));
- return GetLocalesFromMmap(env, mmap.get());
+ TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jstring> value,
+ GetLocalesFromMmap(env, mmap.get()));
+ return value.release();
}
TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersion)
@@ -682,12 +753,16 @@ TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetName)
(JNIEnv* env, jobject clazz, jint fd) {
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
new libtextclassifier3::ScopedMmap(fd));
- return GetNameFromMmap(env, mmap.get());
+ TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jstring> value,
+ GetNameFromMmap(env, mmap.get()));
+ return value.release();
}
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetNameWithOffset)
(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size) {
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
new libtextclassifier3::ScopedMmap(fd, offset, size));
- return GetNameFromMmap(env, mmap.get());
+ TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jstring> value,
+ GetNameFromMmap(env, mmap.get()));
+ return value.release();
}
diff --git a/native/annotator/annotator_jni_common.cc b/native/annotator/annotator_jni_common.cc
index 55f14e6..575e71b 100644
--- a/native/annotator/annotator_jni_common.cc
+++ b/native/annotator/annotator_jni_common.cc
@@ -17,138 +17,176 @@
#include "annotator/annotator_jni_common.h"
#include "utils/java/jni-base.h"
-#include "utils/java/scoped_local_ref.h"
+#include "utils/java/jni-helper.h"
namespace libtextclassifier3 {
namespace {
-std::unordered_set<std::string> EntityTypesFromJObject(JNIEnv* env,
- const jobject& jobject) {
+StatusOr<std::unordered_set<std::string>> EntityTypesFromJObject(
+ JNIEnv* env, const jobject& jobject) {
std::unordered_set<std::string> entity_types;
jobjectArray jentity_types = reinterpret_cast<jobjectArray>(jobject);
const int size = env->GetArrayLength(jentity_types);
for (int i = 0; i < size; ++i) {
- jstring jentity_type =
- reinterpret_cast<jstring>(env->GetObjectArrayElement(jentity_types, i));
- entity_types.insert(ToStlString(env, jentity_type));
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jstring> jentity_type,
+ JniHelper::GetObjectArrayElement<jstring>(env, jentity_types, i));
+ TC3_ASSIGN_OR_RETURN(std::string entity_type,
+ ToStlString(env, jentity_type.get()));
+ entity_types.insert(entity_type);
}
return entity_types;
}
template <typename T>
-T FromJavaOptionsInternal(JNIEnv* env, jobject joptions,
- const std::string& class_name) {
+StatusOr<T> FromJavaOptionsInternal(JNIEnv* env, jobject joptions,
+ const std::string& class_name) {
if (!joptions) {
- return {};
+ return {Status::UNKNOWN};
}
- const ScopedLocalRef<jclass> options_class(env->FindClass(class_name.c_str()),
- env);
- if (!options_class) {
- return {};
- }
-
- const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
- env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
- "getLocale", "Ljava/lang/String;");
- const std::pair<bool, jobject> status_or_reference_timezone =
- CallJniMethod0<jobject>(env, joptions, options_class.get(),
- &JNIEnv::CallObjectMethod, "getReferenceTimezone",
- "Ljava/lang/String;");
- const std::pair<bool, int64> status_or_reference_time_ms_utc =
- CallJniMethod0<int64>(env, joptions, options_class.get(),
- &JNIEnv::CallLongMethod, "getReferenceTimeMsUtc",
- "J");
- const std::pair<bool, jobject> status_or_detected_text_language_tags =
- CallJniMethod0<jobject>(
- env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
- "getDetectedTextLanguageTags", "Ljava/lang/String;");
- const std::pair<bool, int> status_or_annotation_usecase =
- CallJniMethod0<int>(env, joptions, options_class.get(),
- &JNIEnv::CallIntMethod, "getAnnotationUsecase", "I");
-
- if (!status_or_locales.first || !status_or_reference_timezone.first ||
- !status_or_reference_time_ms_utc.first ||
- !status_or_detected_text_language_tags.first ||
- !status_or_annotation_usecase.first) {
- return {};
- }
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jclass> options_class,
+ JniHelper::FindClass(env, class_name.c_str()));
+
+ // .getLocale()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_locale,
+ JniHelper::GetMethodID(env, options_class.get(), "getLocale",
+ "()Ljava/lang/String;"));
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jstring> locales,
+ JniHelper::CallObjectMethod<jstring>(env, joptions, get_locale));
+
+ // .getReferenceTimeMsUtc()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_reference_time_method,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getReferenceTimeMsUtc", "()J"));
+ TC3_ASSIGN_OR_RETURN(
+ int64 reference_time,
+ JniHelper::CallLongMethod(env, joptions, get_reference_time_method));
+
+ // .getReferenceTimezone()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_reference_timezone_method,
+ JniHelper::GetMethodID(env, options_class.get(), "getReferenceTimezone",
+ "()Ljava/lang/String;"));
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> reference_timezone,
+ JniHelper::CallObjectMethod<jstring>(
+ env, joptions, get_reference_timezone_method));
+
+ // .getDetectedTextLanguageTags()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_detected_text_language_tags_method,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getDetectedTextLanguageTags",
+ "()Ljava/lang/String;"));
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jstring> detected_text_language_tags,
+ JniHelper::CallObjectMethod<jstring>(
+ env, joptions, get_detected_text_language_tags_method));
+
+ // .getAnnotationUsecase()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_annotation_usecase,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getAnnotationUsecase", "()I"));
+ TC3_ASSIGN_OR_RETURN(
+ int32 annotation_usecase,
+ JniHelper::CallIntMethod(env, joptions, get_annotation_usecase));
T options;
- options.locales =
- ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second));
- options.reference_timezone = ToStlString(
- env, reinterpret_cast<jstring>(status_or_reference_timezone.second));
- options.reference_time_ms_utc = status_or_reference_time_ms_utc.second;
- options.detected_text_language_tags = ToStlString(
- env,
- reinterpret_cast<jstring>(status_or_detected_text_language_tags.second));
+ TC3_ASSIGN_OR_RETURN(options.locales, ToStlString(env, locales.get()));
+ TC3_ASSIGN_OR_RETURN(options.reference_timezone,
+ ToStlString(env, reference_timezone.get()));
+ options.reference_time_ms_utc = reference_time;
+ TC3_ASSIGN_OR_RETURN(options.detected_text_language_tags,
+ ToStlString(env, detected_text_language_tags.get()));
options.annotation_usecase =
- static_cast<AnnotationUsecase>(status_or_annotation_usecase.second);
+ static_cast<AnnotationUsecase>(annotation_usecase);
return options;
}
} // namespace
-SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions) {
+StatusOr<SelectionOptions> FromJavaSelectionOptions(JNIEnv* env,
+ jobject joptions) {
if (!joptions) {
- return {};
+ return {Status::UNKNOWN};
}
- const ScopedLocalRef<jclass> options_class(
- env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
- "$SelectionOptions"),
- env);
- const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
- env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
- "getLocales", "Ljava/lang/String;");
- const std::pair<bool, int> status_or_annotation_usecase =
- CallJniMethod0<int>(env, joptions, options_class.get(),
- &JNIEnv::CallIntMethod, "getAnnotationUsecase", "I");
- if (!status_or_locales.first || !status_or_annotation_usecase.first) {
- return {};
- }
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jclass> options_class,
+ JniHelper::FindClass(env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$SelectionOptions"));
+
+ // .getLocales()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_locales,
+ JniHelper::GetMethodID(env, options_class.get(), "getLocales",
+ "()Ljava/lang/String;"));
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jstring> locales,
+ JniHelper::CallObjectMethod<jstring>(env, joptions, get_locales));
+
+ // .getAnnotationUsecase()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_annotation_usecase,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getAnnotationUsecase", "()I"));
+ TC3_ASSIGN_OR_RETURN(
+ int32 annotation_usecase,
+ JniHelper::CallIntMethod(env, joptions, get_annotation_usecase));
SelectionOptions options;
- options.locales =
- ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second));
+ TC3_ASSIGN_OR_RETURN(options.locales, ToStlString(env, locales.get()));
options.annotation_usecase =
- static_cast<AnnotationUsecase>(status_or_annotation_usecase.second);
+ static_cast<AnnotationUsecase>(annotation_usecase);
return options;
}
-ClassificationOptions FromJavaClassificationOptions(JNIEnv* env,
- jobject joptions) {
+StatusOr<ClassificationOptions> FromJavaClassificationOptions(
+ JNIEnv* env, jobject joptions) {
return FromJavaOptionsInternal<ClassificationOptions>(
env, joptions,
TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$ClassificationOptions");
}
-AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions) {
- if (!joptions) return {};
- const ScopedLocalRef<jclass> options_class(
- env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
- "$AnnotationOptions"),
- env);
- if (!options_class) return {};
- const std::pair<bool, jobject> status_or_entity_types =
- CallJniMethod0<jobject>(env, joptions, options_class.get(),
- &JNIEnv::CallObjectMethod, "getEntityTypes",
- "[Ljava/lang/String;");
- if (!status_or_entity_types.first) return {};
- const std::pair<bool, bool> status_or_enable_serialized_entity_data =
- CallJniMethod0<bool>(env, joptions, options_class.get(),
- &JNIEnv::CallBooleanMethod,
- "isSerializedEntityDataEnabled", "Z");
- if (!status_or_enable_serialized_entity_data.first) return {};
- AnnotationOptions annotation_options =
+StatusOr<AnnotationOptions> FromJavaAnnotationOptions(JNIEnv* env,
+ jobject joptions) {
+ if (!joptions) {
+ return {Status::UNKNOWN};
+ }
+
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jclass> options_class,
+ JniHelper::FindClass(env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$AnnotationOptions"));
+
+ // .getEntityTypes()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_entity_types,
+ JniHelper::GetMethodID(env, options_class.get(), "getEntityTypes",
+ "()[Ljava/lang/String;"));
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobject> entity_types,
+ JniHelper::CallObjectMethod<jobject>(env, joptions, get_entity_types));
+
+ // .isSerializedEntityDataEnabled()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID is_serialized_entity_data_enabled_method,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "isSerializedEntityDataEnabled", "()Z"));
+ TC3_ASSIGN_OR_RETURN(
+ bool is_serialized_entity_data_enabled,
+ JniHelper::CallBooleanMethod(env, joptions,
+ is_serialized_entity_data_enabled_method));
+
+ TC3_ASSIGN_OR_RETURN(
+ AnnotationOptions annotation_options,
FromJavaOptionsInternal<AnnotationOptions>(
env, joptions,
- TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotationOptions");
- annotation_options.entity_types =
- EntityTypesFromJObject(env, status_or_entity_types.second);
+ TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotationOptions"));
+ TC3_ASSIGN_OR_RETURN(annotation_options.entity_types,
+ EntityTypesFromJObject(env, entity_types.get()));
annotation_options.is_serialized_entity_data_enabled =
- status_or_enable_serialized_entity_data.second;
+ is_serialized_entity_data_enabled;
return annotation_options;
}
diff --git a/native/annotator/annotator_jni_common.h b/native/annotator/annotator_jni_common.h
index b62bb21..f1f1d88 100644
--- a/native/annotator/annotator_jni_common.h
+++ b/native/annotator/annotator_jni_common.h
@@ -20,6 +20,7 @@
#include <jni.h>
#include "annotator/annotator.h"
+#include "utils/base/statusor.h"
#ifndef TC3_ANNOTATOR_CLASS_NAME
#define TC3_ANNOTATOR_CLASS_NAME AnnotatorModel
@@ -29,12 +30,14 @@
namespace libtextclassifier3 {
-SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions);
-
-ClassificationOptions FromJavaClassificationOptions(JNIEnv* env,
+StatusOr<SelectionOptions> FromJavaSelectionOptions(JNIEnv* env,
jobject joptions);
-AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions);
+StatusOr<ClassificationOptions> FromJavaClassificationOptions(JNIEnv* env,
+ jobject joptions);
+
+StatusOr<AnnotationOptions> FromJavaAnnotationOptions(JNIEnv* env,
+ jobject joptions);
} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/parser.cc b/native/annotator/datetime/parser.cc
index 0f222bd..6c759e7 100644
--- a/native/annotator/datetime/parser.cc
+++ b/native/annotator/datetime/parser.cc
@@ -92,7 +92,7 @@ DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib& unilib,
}
if (model->locales() != nullptr) {
- for (int i = 0; i < model->locales()->Length(); ++i) {
+ for (int i = 0; i < model->locales()->size(); ++i) {
locale_string_to_id_[model->locales()->Get(i)->str()] = i;
}
}
@@ -106,6 +106,8 @@ DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib& unilib,
use_extractors_for_locating_ = model->use_extractors_for_locating();
generate_alternative_interpretations_when_ambiguous_ =
model->generate_alternative_interpretations_when_ambiguous();
+ prefer_future_for_unspecified_date_ =
+ model->prefer_future_for_unspecified_date();
initialized_ = true;
}
@@ -433,7 +435,8 @@ bool DatetimeParser::ExtractDatetime(const CompiledRule& rule,
// response. For Details see b/130355975
if (!calendarlib_.InterpretParseData(
interpretation, reference_time_ms_utc, reference_timezone,
- reference_locale, &(result.time_ms_utc), &(result.granularity))) {
+ reference_locale, prefer_future_for_unspecified_date_,
+ &(result.time_ms_utc), &(result.granularity))) {
return false;
}
diff --git a/native/annotator/datetime/parser.h b/native/annotator/datetime/parser.h
index 4e995bd..a5192d3 100644
--- a/native/annotator/datetime/parser.h
+++ b/native/annotator/datetime/parser.h
@@ -59,12 +59,6 @@ class DatetimeParser {
bool anchor_start_end,
std::vector<DatetimeParseResultSpan>* results) const;
-#ifdef TC3_TEST_ONLY
- void TestOnlySetGenerateAlternativeInterpretationsWhenAmbiguous(bool value) {
- generate_alternative_interpretations_when_ambiguous_ = value;
- }
-#endif // TC3_TEST_ONLY
-
protected:
DatetimeParser(const DatetimeModel* model, const UniLib& unilib,
const CalendarLib& calendarlib,
@@ -126,6 +120,7 @@ class DatetimeParser {
std::vector<int> default_locale_ids_;
bool use_extractors_for_locating_;
bool generate_alternative_interpretations_when_ambiguous_;
+ bool prefer_future_for_unspecified_date_;
};
} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/parser_test.cc b/native/annotator/datetime/parser_test.cc
index 35c725f..1ddcf50 100644
--- a/native/annotator/datetime/parser_test.cc
+++ b/native/annotator/datetime/parser_test.cc
@@ -14,20 +14,21 @@
* limitations under the License.
*/
+#include "annotator/datetime/parser.h"
+
#include <time.h>
+
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
#include "annotator/annotator.h"
-#include "annotator/datetime/parser.h"
#include "annotator/model_generated.h"
#include "annotator/types-test-util.h"
#include "utils/testing/annotator.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
using std::vector;
using testing::ElementsAreArray;
@@ -152,7 +153,7 @@ class ParserTest : public testing::Test {
const int expected_start_index =
std::distance(marked_text_unicode.begin(), brace_open_it);
- // The -1 bellow is to account for the opening bracket character.
+ // The -1 below is to account for the opening bracket character.
const int expected_end_index =
std::distance(marked_text_unicode.begin(), brace_end_it) - 1;
@@ -746,6 +747,43 @@ TEST_F(ParserTest, ParseWithRawUsecase) {
/*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_SMART));
}
+TEST_F(ParserTest, AddsADayWhenTimeInThePastAndDayNotSpecified) {
+ // ParsesCorrectly uses 0 as the reference time, which corresponds to:
+ // "Thu Jan 01 1970 01:00:00" Zurich time. So if we pass "0:30" here, it means
+ // it is in the past, and so the parser should move this to the next day ->
+ // "Fri Jan 02 1970 00:30:00" Zurich time (b/139112907).
+ EXPECT_TRUE(ParsesCorrectly(
+ "{0:30am}", 84600000L /* 23.5 hours from reference time */,
+ GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Build()}));
+}
+
+TEST_F(ParserTest, DoesNotAddADayWhenTimeInThePastAndDayNotSpecifiedDisabled) {
+ // ParsesCorrectly uses 0 as the reference time, which corresponds to:
+ // "Thu Jan 01 1970 01:00:00" Zurich time. So if we pass "0:30" here, it means
+ // it is in the past. The parameter prefer_future_for_unspecified_date is
+ // disabled, so the parser should annotate this to the same day: "Thu Jan 01
+ // 1970 00:30:00" Zurich time.
+ LoadModel([](ModelT* model) {
+ // In the test model, the prefer_future_for_unspecified_date is true; make
+ // it false only for this test.
+ model->datetime_model->prefer_future_for_unspecified_date = false;
+ });
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "{0:30am}", -1800000L /* -30 minutes from reference time */,
+ GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Build()}));
+}
+
TEST_F(ParserTest, ParsesNoonAndMidnightCorrectly) {
EXPECT_TRUE(ParsesCorrectly(
"{January 1, 1988 12:30am}", 567991800000, GRANULARITY_MINUTE,
diff --git a/native/annotator/duration/duration.cc b/native/annotator/duration/duration.cc
index 3529691..907a1a4 100644
--- a/native/annotator/duration/duration.cc
+++ b/native/annotator/duration/duration.cc
@@ -23,6 +23,7 @@
#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/strings/numbers.h"
+#include "utils/utf8/unicodetext.h"
namespace libtextclassifier3 {
@@ -31,46 +32,55 @@ using DurationUnit = internal::DurationUnit;
namespace internal {
namespace {
+std::string ToLowerString(const std::string& str, const UniLib* unilib) {
+ return unilib->ToLowerText(UTF8ToUnicodeText(str, /*do_copy=*/false))
+ .ToUTF8String();
+}
+
void FillDurationUnitMap(
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*
expressions,
DurationUnit duration_unit,
- std::unordered_map<std::string, DurationUnit>* target_map) {
+ std::unordered_map<std::string, DurationUnit>* target_map,
+ const UniLib* unilib) {
if (expressions == nullptr) {
return;
}
for (const flatbuffers::String* expression_string : *expressions) {
- (*target_map)[expression_string->c_str()] = duration_unit;
+ (*target_map)[ToLowerString(expression_string->c_str(), unilib)] =
+ duration_unit;
}
}
} // namespace
std::unordered_map<std::string, DurationUnit> BuildTokenToDurationUnitMapping(
- const DurationAnnotatorOptions* options) {
+ const DurationAnnotatorOptions* options, const UniLib* unilib) {
std::unordered_map<std::string, DurationUnit> mapping;
- FillDurationUnitMap(options->week_expressions(), DurationUnit::WEEK,
- &mapping);
- FillDurationUnitMap(options->day_expressions(), DurationUnit::DAY, &mapping);
- FillDurationUnitMap(options->hour_expressions(), DurationUnit::HOUR,
- &mapping);
+ FillDurationUnitMap(options->week_expressions(), DurationUnit::WEEK, &mapping,
+ unilib);
+ FillDurationUnitMap(options->day_expressions(), DurationUnit::DAY, &mapping,
+ unilib);
+ FillDurationUnitMap(options->hour_expressions(), DurationUnit::HOUR, &mapping,
+ unilib);
FillDurationUnitMap(options->minute_expressions(), DurationUnit::MINUTE,
- &mapping);
+ &mapping, unilib);
FillDurationUnitMap(options->second_expressions(), DurationUnit::SECOND,
- &mapping);
+ &mapping, unilib);
return mapping;
}
std::unordered_set<std::string> BuildStringSet(
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*
- strings) {
+ strings,
+ const UniLib* unilib) {
std::unordered_set<std::string> result;
if (strings == nullptr) {
return result;
}
for (const flatbuffers::String* string_value : *strings) {
- result.insert(string_value->c_str());
+ result.insert(ToLowerString(string_value->c_str(), unilib));
}
return result;
@@ -260,14 +270,17 @@ bool DurationAnnotator::ParseQuantityToken(const Token& token,
std::string token_value_buffer;
const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
token.value, &token_value_buffer);
+ const std::string& lowercase_token_value =
+ internal::ToLowerString(token_value, unilib_);
- if (half_expressions_.find(token_value) != half_expressions_.end()) {
+ if (half_expressions_.find(lowercase_token_value) !=
+ half_expressions_.end()) {
value->plus_half = true;
return true;
}
int32 parsed_value;
- if (ParseInt32(token_value.c_str(), &parsed_value)) {
+ if (ParseInt32(lowercase_token_value.c_str(), &parsed_value)) {
value->value = parsed_value;
return true;
}
@@ -280,8 +293,10 @@ bool DurationAnnotator::ParseDurationUnitToken(
std::string token_value_buffer;
const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
token.value, &token_value_buffer);
+ const std::string& lowercase_token_value =
+ internal::ToLowerString(token_value, unilib_);
- const auto it = token_value_to_duration_unit_.find(token_value);
+ const auto it = token_value_to_duration_unit_.find(lowercase_token_value);
if (it == token_value_to_duration_unit_.end()) {
return false;
}
@@ -319,8 +334,11 @@ bool DurationAnnotator::ParseFillerToken(const Token& token) const {
std::string token_value_buffer;
const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
token.value, &token_value_buffer);
+ const std::string& lowercase_token_value =
+ internal::ToLowerString(token_value, unilib_);
- if (filler_expressions_.find(token_value) == filler_expressions_.end()) {
+ if (filler_expressions_.find(lowercase_token_value) ==
+ filler_expressions_.end()) {
return false;
}
diff --git a/native/annotator/duration/duration.h b/native/annotator/duration/duration.h
index 2242259..db4bdae 100644
--- a/native/annotator/duration/duration.h
+++ b/native/annotator/duration/duration.h
@@ -26,6 +26,7 @@
#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
namespace libtextclassifier3 {
@@ -46,12 +47,14 @@ enum class DurationUnit {
// Prepares the mapping between token values and duration unit types.
std::unordered_map<std::string, internal::DurationUnit>
-BuildTokenToDurationUnitMapping(const DurationAnnotatorOptions* options);
+BuildTokenToDurationUnitMapping(const DurationAnnotatorOptions* options,
+ const UniLib* unilib);
// Creates a set of strings from a flatbuffer string vector.
std::unordered_set<std::string> BuildStringSet(
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*
- strings);
+ strings,
+ const UniLib* unilib);
// Creates a set of ints from a flatbuffer int vector.
std::unordered_set<int32> BuildInt32Set(const flatbuffers::Vector<int32>* ints);
@@ -62,15 +65,17 @@ std::unordered_set<int32> BuildInt32Set(const flatbuffers::Vector<int32>* ints);
class DurationAnnotator {
public:
explicit DurationAnnotator(const DurationAnnotatorOptions* options,
- const FeatureProcessor* feature_processor)
+ const FeatureProcessor* feature_processor,
+ const UniLib* unilib)
: options_(options),
feature_processor_(feature_processor),
+ unilib_(unilib),
token_value_to_duration_unit_(
- internal::BuildTokenToDurationUnitMapping(options)),
+ internal::BuildTokenToDurationUnitMapping(options, unilib)),
filler_expressions_(
- internal::BuildStringSet(options->filler_expressions())),
+ internal::BuildStringSet(options->filler_expressions(), unilib)),
half_expressions_(
- internal::BuildStringSet(options->half_expressions())),
+ internal::BuildStringSet(options->half_expressions(), unilib)),
sub_token_separator_codepoints_(internal::BuildInt32Set(
options->sub_token_separator_codepoints())) {}
@@ -125,6 +130,7 @@ class DurationAnnotator {
const DurationAnnotatorOptions* options_;
const FeatureProcessor* feature_processor_;
+ const UniLib* unilib_;
const std::unordered_map<std::string, internal::DurationUnit>
token_value_to_duration_unit_;
const std::unordered_set<std::string> filler_expressions_;
diff --git a/native/annotator/duration/duration_test.cc b/native/annotator/duration/duration_test.cc
index 3fc25e6..d1dc67a 100644
--- a/native/annotator/duration/duration_test.cc
+++ b/native/annotator/duration/duration_test.cc
@@ -106,7 +106,7 @@ class DurationAnnotatorTest : public ::testing::Test {
: INIT_UNILIB_FOR_TESTING(unilib_),
feature_processor_(BuildFeatureProcessor(&unilib_)),
duration_annotator_(TestingDurationAnnotatorOptions(),
- feature_processor_.get()) {}
+ feature_processor_.get(), &unilib_) {}
std::vector<Token> Tokenize(const UnicodeText& text) {
return feature_processor_->Tokenize(text);
@@ -195,6 +195,26 @@ TEST_F(DurationAnnotatorTest, FindsComposedDuration) {
3 * 60 * 60 * 1000 + 5 * 1000)))))));
}
+TEST_F(DurationAnnotatorTest, AllUnitsAreCovered) {
+ const UnicodeText text = UTF8ToUnicodeText(
+ "See you in a week and a day and an hour and a minute and a second");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(13, 65)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 7 * 24 * 60 * 60 * 1000 + 24 * 60 * 60 * 1000 +
+ 60 * 60 * 1000 + 60 * 1000 + 1000)))))));
+}
+
TEST_F(DurationAnnotatorTest, FindsHalfAnHour) {
const UnicodeText text = UTF8ToUnicodeText("Set a timer for half an hour");
std::vector<Token> tokens = Tokenize(text);
@@ -350,5 +370,62 @@ TEST_F(DurationAnnotatorTest,
1400L * 60L * 60L * 1000L)));
}
+TEST_F(DurationAnnotatorTest, FindsSimpleDurationIgnoringCase) {
+ const UnicodeText text = UTF8ToUnicodeText("Wake me up in 15 MiNuTeS ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(14, 24)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 15 * 60 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpressionIgnoringCase) {
+ const UnicodeText text =
+ UTF8ToUnicodeText("Set a timer for 3 and HaLf minutes ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 34)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 3.5 * 60 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest,
+ FindsDurationWithHalfExpressionIgnoringFillerWordCase) {
+ const UnicodeText text =
+ UTF8ToUnicodeText("Set a timer for 3 AnD half minutes ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 34)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 3.5 * 60 * 1000)))))));
+}
+
} // namespace
} // namespace libtextclassifier3
diff --git a/native/annotator/entity-data.fbs b/native/annotator/entity-data.fbs
index 6da3dd5..fa2dc0b 100755
--- a/native/annotator/entity-data.fbs
+++ b/native/annotator/entity-data.fbs
@@ -125,6 +125,46 @@ table Flight {
flight_number:string (shared);
}
+// Details about an ISBN number.
+namespace libtextclassifier3.EntityData_;
+table Isbn {
+ // The (normalized) number.
+ number:string (shared);
+}
+
+// Details about an IBAN number.
+namespace libtextclassifier3.EntityData_;
+table Iban {
+ // The (normalized) number.
+ number:string (shared);
+
+ // The country code.
+ country_code:string (shared);
+}
+
+namespace libtextclassifier3.EntityData_.ParcelTracking_;
+enum Carrier : int {
+ UNKNOWN_CARRIER = 0,
+ FEDEX = 1,
+ UPS = 2,
+ DHL = 3,
+ USPS = 4,
+ ONTRAC = 5,
+ LASERSHIP = 6,
+ ISRAEL_POST = 7,
+ SWISS_POST = 8,
+ MSC = 9,
+ AMAZON = 10,
+ I_PARCEL = 11,
+}
+
+// Details about a tracking number.
+namespace libtextclassifier3.EntityData_;
+table ParcelTracking {
+ carrier:ParcelTracking_.Carrier;
+ tracking_number:string (shared);
+}
+
// Represents an entity annotated in text.
namespace libtextclassifier3;
table EntityData {
@@ -143,6 +183,9 @@ table EntityData {
app:EntityData_.App;
payment_card:EntityData_.PaymentCard;
flight:EntityData_.Flight;
+ isbn:EntityData_.Isbn;
+ iban:EntityData_.Iban;
+ parcel:EntityData_.ParcelTracking;
}
root_type libtextclassifier3.EntityData;
diff --git a/native/annotator/flatbuffer-utils.cc b/native/annotator/flatbuffer-utils.cc
new file mode 100644
index 0000000..14b5901
--- /dev/null
+++ b/native/annotator/flatbuffer-utils.cc
@@ -0,0 +1,65 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/flatbuffer-utils.h"
+
+#include <memory>
+
+#include "utils/base/logging.h"
+#include "utils/flatbuffers.h"
+#include "flatbuffers/reflection.h"
+
+namespace libtextclassifier3 {
+
+bool SwapFieldNamesForOffsetsInPath(ModelT* model) {
+ if (model->regex_model == nullptr || model->entity_data_schema.empty()) {
+ // Nothing to do.
+ return true;
+ }
+ const reflection::Schema* schema =
+ LoadAndVerifyFlatbuffer<reflection::Schema>(
+ model->entity_data_schema.data(), model->entity_data_schema.size());
+
+ for (std::unique_ptr<RegexModel_::PatternT>& pattern :
+ model->regex_model->patterns) {
+ for (std::unique_ptr<RegexModel_::Pattern_::CapturingGroupT>& group :
+ pattern->capturing_group) {
+ if (group->entity_field_path == nullptr) {
+ continue;
+ }
+
+ if (!SwapFieldNamesForOffsetsInPath(schema,
+ group->entity_field_path.get())) {
+ return false;
+ }
+ }
+ }
+
+ return true;
+}
+
+std::string SwapFieldNamesForOffsetsInPathInSerializedModel(
+ const std::string& model) {
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(model.c_str());
+ TC3_CHECK(unpacked_model != nullptr);
+ TC3_CHECK(SwapFieldNamesForOffsetsInPath(unpacked_model.get()));
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+ return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+}
+
+} // namespace libtextclassifier3
diff --git a/native/annotator/flatbuffer-utils.h b/native/annotator/flatbuffer-utils.h
new file mode 100644
index 0000000..a7e5d64
--- /dev/null
+++ b/native/annotator/flatbuffer-utils.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Utility functions for working with FlatBuffers in the annotator model.
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_FLATBUFFER_UTILS_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_FLATBUFFER_UTILS_H_
+
+#include <string>
+
+#include "annotator/model_generated.h"
+
+namespace libtextclassifier3 {
+
+// Resolves field lookups by name to the concrete field offsets in the regex
+// rules of the model.
+bool SwapFieldNamesForOffsetsInPath(ModelT* model);
+
+// Same as above but for a serialized model.
+std::string SwapFieldNamesForOffsetsInPathInSerializedModel(
+ const std::string& model);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_FLATBUFFER_UTILS_H_
diff --git a/native/annotator/knowledge/knowledge-engine-dummy.h b/native/annotator/knowledge/knowledge-engine-dummy.h
index 1787353..865bf85 100644
--- a/native/annotator/knowledge/knowledge-engine-dummy.h
+++ b/native/annotator/knowledge/knowledge-engine-dummy.h
@@ -29,6 +29,8 @@ class KnowledgeEngine {
public:
bool Initialize(const std::string& serialized_config) { return true; }
+ void SetPriorityScore(float priority_score) {}
+
bool ClassifyText(const std::string& context, CodepointSpan selection_indices,
AnnotationUsecase annotation_usecase,
ClassificationResult* classification_result) const {
diff --git a/native/annotator/model.fbs b/native/annotator/model.fbs
index 181a8aa..5bf1472 100755
--- a/native/annotator/model.fbs
+++ b/native/annotator/model.fbs
@@ -17,6 +17,7 @@
include "utils/codepoint-range.fbs";
include "utils/flatbuffers.fbs";
include "utils/intents/intent-config.fbs";
+include "utils/normalization.fbs";
include "utils/resources.fbs";
include "utils/tokenizer.fbs";
include "utils/zlib/buffer.fbs";
@@ -209,6 +210,9 @@ table CapturingGroup {
// If set, the serialized entity data will be merged with the
// classification result entity data.
serialized_entity_data:string (shared);
+
+ // If set, normalization to apply before text is used in entity data.
+ normalization_options:NormalizationOptions;
}
// List of regular expression matchers to check.
@@ -329,6 +333,9 @@ table DatetimeModel {
// If true, will compile the regexes only on first use.
lazy_regex_compilation:bool = true;
+
+ // If true, will give only future dates (when the day is not specified).
+ prefer_future_for_unspecified_date:bool = false;
}
namespace libtextclassifier3.DatetimeModelLibrary_;
@@ -363,6 +370,9 @@ table ModelTriggeringOptions {
// Priority score assigned to the "other" class from ML model.
other_collection_priority_score:float = -1000;
+
+ // Priority score assigned to knowledge engine annotations.
+ knowledge_priority_score:float = 0;
}
// Options controlling the output of the classifier.
@@ -675,6 +685,10 @@ table NumberAnnotatorOptions {
// Priority score for the percentage annotation.
percentage_priority_score:float = 1;
+
+ // Float number priority score used for conflict resolution with the other
+ // models.
+ float_number_priority_score:float = 0;
}
// DurationAnnotator is so far tailored for English only.
diff --git a/native/annotator/number/number.cc b/native/annotator/number/number.cc
index 7af63fa..671e1af 100644
--- a/native/annotator/number/number.cc
+++ b/native/annotator/number/number.cc
@@ -20,7 +20,9 @@
#include <cstdlib>
#include "annotator/collections.h"
+#include "annotator/types.h"
#include "utils/base/logging.h"
+#include "utils/utf8/unicodetext.h"
namespace libtextclassifier3 {
@@ -28,68 +30,38 @@ bool NumberAnnotator::ClassifyText(
const UnicodeText& context, CodepointSpan selection_indices,
AnnotationUsecase annotation_usecase,
ClassificationResult* classification_result) const {
- if (!options_->enabled() || ((1 << annotation_usecase) &
- options_->enabled_annotation_usecases()) == 0) {
- return false;
- }
+ TC3_CHECK(classification_result != nullptr);
- int64 parsed_int_value;
- double parsed_double_value;
- int num_prefix_codepoints;
- int num_suffix_codepoints;
const UnicodeText substring_selected = UnicodeText::Substring(
context, selection_indices.first, selection_indices.second);
- if (ParseNumber(substring_selected, &parsed_int_value, &parsed_double_value,
- &num_prefix_codepoints, &num_suffix_codepoints)) {
- TC3_CHECK(classification_result != nullptr);
- classification_result->score = options_->score();
- classification_result->priority_score = options_->priority_score();
- classification_result->numeric_value = parsed_int_value;
- classification_result->numeric_double_value = parsed_double_value;
-
- if (num_suffix_codepoints == 0) {
- classification_result->collection = Collections::Number();
- return true;
+
+ std::vector<AnnotatedSpan> results;
+ if (!FindAll(substring_selected, annotation_usecase, &results)) {
+ return false;
+ }
+
+ const CodepointSpan stripped_selection_indices =
+ feature_processor_->StripBoundaryCodepoints(
+ context, selection_indices, ignored_prefix_span_boundary_codepoints_,
+ ignored_suffix_span_boundary_codepoints_);
+
+ for (const AnnotatedSpan& result : results) {
+ if (result.classification.empty()) {
+ continue;
}
- // If the selection ends in %, the parseNumber returns true with
- // num_suffix_codepoints = 1 => percent
- if (options_->enable_percentage() &&
- GetPercentSuffixLength(
- context, context.size_codepoints(),
- selection_indices.second - num_suffix_codepoints) ==
- num_suffix_codepoints) {
- classification_result->collection = Collections::Percentage();
- classification_result->priority_score =
- options_->percentage_priority_score();
+ // We make sure that the result span is equal to the stripped selection span
+ // to avoid validating cases like "23 asdf 3.14 pct asdf". FindAll will
+ // anyway only find valid numbers and percentages and a given selection with
+ // more than two tokens won't pass this check.
+ if (result.span.first + selection_indices.first ==
+ stripped_selection_indices.first &&
+ result.span.second + selection_indices.first ==
+ stripped_selection_indices.second) {
+ *classification_result = result.classification[0];
return true;
}
- } else if (options_->enable_percentage()) {
- // If the substring selected is a percent matching the form: 5 percent,
- // 5 pct or 5 pc => percent.
- std::vector<AnnotatedSpan> results;
- FindAll(substring_selected, annotation_usecase, &results);
- for (auto& result : results) {
- if (result.classification.empty() ||
- result.classification[0].collection != Collections::Percentage()) {
- continue;
- }
- if (result.span.first == 0 &&
- result.span.second == substring_selected.size_codepoints()) {
- TC3_CHECK(classification_result != nullptr);
- classification_result->collection = Collections::Percentage();
- classification_result->score = options_->score();
- classification_result->priority_score =
- options_->percentage_priority_score();
- classification_result->numeric_value =
- result.classification[0].numeric_value;
- classification_result->numeric_double_value =
- result.classification[0].numeric_double_value;
- return true;
- }
- }
}
-
return false;
}
@@ -107,21 +79,24 @@ bool NumberAnnotator::FindAll(const UnicodeText& context,
UTF8ToUnicodeText(token.value, /*do_copy=*/false);
int64 parsed_int_value;
double parsed_double_value;
+ bool has_decimal;
int num_prefix_codepoints;
int num_suffix_codepoints;
if (ParseNumber(token_text, &parsed_int_value, &parsed_double_value,
- &num_prefix_codepoints, &num_suffix_codepoints)) {
+ &has_decimal, &num_prefix_codepoints,
+ &num_suffix_codepoints)) {
ClassificationResult classification{Collections::Number(),
options_->score()};
classification.numeric_value = parsed_int_value;
classification.numeric_double_value = parsed_double_value;
- classification.priority_score = options_->priority_score();
+ classification.priority_score =
+ has_decimal ? options_->float_number_priority_score()
+ : options_->priority_score();
AnnotatedSpan annotated_span;
annotated_span.span = {token.start + num_prefix_codepoints,
token.end - num_suffix_codepoints};
annotated_span.classification.push_back(classification);
-
result->push_back(annotated_span);
}
}
@@ -151,7 +126,7 @@ std::vector<uint32> NumberAnnotator::FlatbuffersIntVectorToStdVector(
namespace {
bool ParseNextNumericCodepoint(int32 codepoint, int64* current_value) {
- if (*current_value > INT64_MAX / 10) {
+ if (*current_value > INT64_MAX / 10 - 10) {
return false;
}
@@ -163,20 +138,20 @@ bool ParseNextNumericCodepoint(int32 codepoint, int64* current_value) {
UnicodeText::const_iterator ConsumeAndParseNumber(
const UnicodeText::const_iterator& it_begin,
const UnicodeText::const_iterator& it_end, int64* int_result,
- double* double_result) {
+ double* double_result, bool* has_decimal) {
*int_result = 0;
+ *has_decimal = false;
// See if there's a sign in the beginning of the number.
int sign = 1;
auto it = it_begin;
- if (it != it_end) {
+ while (it != it_end && (*it == '-' || *it == '+')) {
if (*it == '-') {
- ++it;
sign = -1;
- } else if (*it == '+') {
- ++it;
+ } else {
sign = 1;
}
+ ++it;
}
enum class State {
@@ -203,6 +178,7 @@ UnicodeText::const_iterator ConsumeAndParseNumber(
break;
case State::PARSING_FLOATING_PART:
if (*it >= '0' && *it <= '9') {
+ *has_decimal = true;
if (!ParseNextNumericCodepoint(*it, &decimal_result)) {
state = State::PARSING_DONE;
break;
@@ -236,7 +212,7 @@ UnicodeText::const_iterator ConsumeAndParseNumber(
} // namespace
bool NumberAnnotator::ParseNumber(const UnicodeText& text, int64* int_result,
- double* double_result,
+ double* double_result, bool* has_decimal,
int* num_prefix_codepoints,
int* num_suffix_codepoints) const {
TC3_CHECK(int_result != nullptr && double_result != nullptr &&
@@ -258,13 +234,6 @@ bool NumberAnnotator::ParseNumber(const UnicodeText& text, int64* int_result,
// Consume prefix codepoints.
*num_prefix_codepoints = stripped_span.first;
- bool valid_prefix = true;
- // Makes valid_prefix=false for cases like: "#10" where it points to '1'. In
- // this case the adjacent prefix is not an allowed one.
- if (it != text.begin() && allowed_prefix_codepoints_.find(*std::prev(it)) ==
- allowed_prefix_codepoints_.end()) {
- valid_prefix = false;
- }
while (it != it_end) {
if (allowed_prefix_codepoints_.find(*it) ==
allowed_prefix_codepoints_.end()) {
@@ -276,7 +245,8 @@ bool NumberAnnotator::ParseNumber(const UnicodeText& text, int64* int_result,
}
auto it_start = it;
- it = ConsumeAndParseNumber(it, it_end, int_result, double_result);
+ it =
+ ConsumeAndParseNumber(it, it_end, int_result, double_result, has_decimal);
if (it == it_start) {
return false;
}
@@ -284,32 +254,35 @@ bool NumberAnnotator::ParseNumber(const UnicodeText& text, int64* int_result,
// Consume suffix codepoints.
bool valid_suffix = true;
*num_suffix_codepoints = 0;
+ int ignored_suffix_codepoints = 0;
while (it != it_end) {
- if (allowed_suffix_codepoints_.find(*it) ==
+ if (allowed_suffix_codepoints_.find(*it) !=
allowed_suffix_codepoints_.end()) {
+ // Keep track of allowed suffix codepoints.
+ ++(*num_suffix_codepoints);
+ } else if (ignored_suffix_span_boundary_codepoints_.find(*it) ==
+ ignored_suffix_span_boundary_codepoints_.end()) {
+ // There is a suffix codepoint but it's not part of the ignored list of
+ // codepoints, fail the number parsing.
+ // Note: We want to support cases like "13.", "34#", "123!" etc.
valid_suffix = false;
break;
+ } else {
+ ++ignored_suffix_codepoints;
}
++it;
- ++(*num_suffix_codepoints);
}
*num_suffix_codepoints += num_stripped_end;
- // Makes valid_suffix=false for cases like: "10@", when it == it_end and
- // points to '@'. This adjacent character is not an allowed suffix.
- if (it == it_end && it != text.end() &&
- allowed_suffix_codepoints_.find(*it) ==
- allowed_suffix_codepoints_.end()) {
- valid_suffix = false;
- }
-
- return valid_suffix && valid_prefix;
+ return valid_suffix;
}
int NumberAnnotator::GetPercentSuffixLength(const UnicodeText& context,
- int context_size_codepoints,
int index_codepoints) const {
+ if (index_codepoints >= context.size_codepoints()) {
+ return -1;
+ }
auto context_it = context.begin();
std::advance(context_it, index_codepoints);
const StringPiece suffix_context(
@@ -329,15 +302,13 @@ int NumberAnnotator::GetPercentSuffixLength(const UnicodeText& context,
void NumberAnnotator::FindPercentages(
const UnicodeText& context, std::vector<AnnotatedSpan>* result) const {
- int context_size_codepoints = context.size_codepoints();
for (auto& res : *result) {
if (res.classification.empty() ||
res.classification[0].collection != Collections::Number()) {
continue;
}
- const int match_length = GetPercentSuffixLength(
- context, context_size_codepoints, res.span.second);
+ const int match_length = GetPercentSuffixLength(context, res.span.second);
if (match_length > 0) {
res.classification[0].collection = Collections::Percentage();
res.classification[0].priority_score =
diff --git a/native/annotator/number/number.h b/native/annotator/number/number.h
index 3debd09..3e9e2c3 100644
--- a/native/annotator/number/number.h
+++ b/native/annotator/number/number.h
@@ -81,16 +81,17 @@ class NumberAnnotator {
static std::vector<uint32> FlatbuffersIntVectorToStdVector(
const flatbuffers::Vector<int32_t>* ints);
- // Parses the text to an int64 value and returns true if succeeded, otherwise
- // false. Also returns the number of prefix/suffix codepoints that were
- // stripped from the number.
+ // Parses the text to an int64 value and a double value and returns true if
+ // succeeded, otherwise false. Also returns whether the number contains a
+ // decimal and the number of prefix/suffix codepoints that were stripped from
+ // the number.
bool ParseNumber(const UnicodeText& text, int64* int_result,
- double* double_result, int* num_prefix_codepoints,
+ double* double_result, bool* has_decimal,
+ int* num_prefix_codepoints,
int* num_suffix_codepoints) const;
// Get the length of the percent suffix at the specified index in the context.
int GetPercentSuffixLength(const UnicodeText& context,
- int context_size_codepoints,
int index_codepoints) const;
// Checks if the annotated numbers from the context represent percentages.
diff --git a/native/annotator/test_data/test_model.fb b/native/annotator/test_data/test_model.fb
index ce5f72f..bbf730e 100644
--- a/native/annotator/test_data/test_model.fb
+++ b/native/annotator/test_data/test_model.fb
Binary files differ
diff --git a/native/annotator/test_data/wrong_embeddings.fb b/native/annotator/test_data/wrong_embeddings.fb
index efefa3c..135dec0 100644
--- a/native/annotator/test_data/wrong_embeddings.fb
+++ b/native/annotator/test_data/wrong_embeddings.fb
Binary files differ
diff --git a/native/annotator/types.cc b/native/annotator/types.cc
index c31097d..1ec3790 100644
--- a/native/annotator/types.cc
+++ b/native/annotator/types.cc
@@ -56,6 +56,64 @@ std::string FormatMillis(int64 time_ms_utc) {
}
} // namespace
+// Returns the textual name of `component_type` (e.g. "YEAR" for
+// ComponentType::YEAR). Values without a dedicated case yield an empty string.
+std::string ComponentTypeToString(
+    const DatetimeComponent::ComponentType& component_type) {
+  using ComponentType = DatetimeComponent::ComponentType;
+  // Single exit point: pick the label in the switch, convert once at the end.
+  const char* label = "";
+  switch (component_type) {
+    case ComponentType::UNSPECIFIED:
+      label = "UNSPECIFIED";
+      break;
+    case ComponentType::YEAR:
+      label = "YEAR";
+      break;
+    case ComponentType::MONTH:
+      label = "MONTH";
+      break;
+    case ComponentType::WEEK:
+      label = "WEEK";
+      break;
+    case ComponentType::DAY_OF_WEEK:
+      label = "DAY_OF_WEEK";
+      break;
+    case ComponentType::DAY_OF_MONTH:
+      label = "DAY_OF_MONTH";
+      break;
+    case ComponentType::HOUR:
+      label = "HOUR";
+      break;
+    case ComponentType::MINUTE:
+      label = "MINUTE";
+      break;
+    case ComponentType::SECOND:
+      label = "SECOND";
+      break;
+    case ComponentType::MERIDIEM:
+      label = "MERIDIEM";
+      break;
+    case ComponentType::ZONE_OFFSET:
+      label = "ZONE_OFFSET";
+      break;
+    case ComponentType::DST_OFFSET:
+      label = "DST_OFFSET";
+      break;
+    default:
+      // Unknown values keep the empty label.
+      break;
+  }
+  return label;
+}
+
+// Returns the textual name of `relative_qualifier` (e.g. "NEXT" for
+// RelativeQualifier::NEXT). Values without a dedicated case yield an empty
+// string.
+std::string RelativeQualifierToString(
+    const DatetimeComponent::RelativeQualifier& relative_qualifier) {
+  using RelativeQualifier = DatetimeComponent::RelativeQualifier;
+  // Single exit point: pick the label in the switch, convert once at the end.
+  const char* label = "";
+  switch (relative_qualifier) {
+    case RelativeQualifier::UNSPECIFIED:
+      label = "UNSPECIFIED";
+      break;
+    case RelativeQualifier::NEXT:
+      label = "NEXT";
+      break;
+    case RelativeQualifier::THIS:
+      label = "THIS";
+      break;
+    case RelativeQualifier::LAST:
+      label = "LAST";
+      break;
+    case RelativeQualifier::NOW:
+      label = "NOW";
+      break;
+    case RelativeQualifier::TOMORROW:
+      label = "TOMORROW";
+      break;
+    case RelativeQualifier::YESTERDAY:
+      label = "YESTERDAY";
+      break;
+    case RelativeQualifier::PAST:
+      label = "PAST";
+      break;
+    case RelativeQualifier::FUTURE:
+      label = "FUTURE";
+      break;
+    default:
+      // Unknown values keep the empty label.
+      break;
+  }
+  return label;
+}
+
logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
const DatetimeParseResultSpan& value) {
stream << "DatetimeParseResultSpan({" << value.span.first << ", "
@@ -63,7 +121,16 @@ logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
for (const DatetimeParseResult& data : value.data) {
stream << "{/*time_ms_utc=*/ " << data.time_ms_utc << " /* "
<< FormatMillis(data.time_ms_utc) << " */, /*granularity=*/ "
- << data.granularity << "}, ";
+ << data.granularity << ", /*datetime_components=*/ ";
+ for (const DatetimeComponent& datetime_comp : data.datetime_components) {
+ stream << "{/*component_type=*/ "
+ << ComponentTypeToString(datetime_comp.component_type)
+ << " /*relative_qualifier=*/ "
+ << RelativeQualifierToString(datetime_comp.relative_qualifier)
+ << " /*value=*/ " << datetime_comp.value << " /*relative_count=*/ "
+ << datetime_comp.relative_count << "}, ";
+ }
+ stream << "}, ";
}
stream << "})";
return stream;
diff --git a/native/annotator/types.h b/native/annotator/types.h
index ac24e24..9b94c10 100644
--- a/native/annotator/types.h
+++ b/native/annotator/types.h
@@ -352,7 +352,7 @@ struct ClassificationResult {
// Entity data information.
std::string serialized_entity_data;
- const EntityData* entity_data() {
+ const EntityData* entity_data() const {
return LoadAndVerifyFlatbuffer<EntityData>(serialized_entity_data.data(),
serialized_entity_data.size());
}
diff --git a/native/annotator/zlib-utils.cc b/native/annotator/zlib-utils.cc
index ec2392b..c3c2cf1 100644
--- a/native/annotator/zlib-utils.cc
+++ b/native/annotator/zlib-utils.cc
@@ -125,6 +125,15 @@ bool DecompressModel(ModelT* model) {
extractor->compressed_pattern.reset(nullptr);
}
}
+
+ if (model->resources != nullptr) {
+ DecompressResources(model->resources.get());
+ }
+
+ if (model->intent_options != nullptr) {
+ DecompressIntentModel(model->intent_options.get());
+ }
+
return true;
}
diff --git a/native/annotator/zlib-utils_test.cc b/native/annotator/zlib-utils_test.cc
index 7a8d775..363c155 100644
--- a/native/annotator/zlib-utils_test.cc
+++ b/native/annotator/zlib-utils_test.cc
@@ -43,6 +43,37 @@ TEST(ZlibUtilsTest, CompressModel) {
model.datetime_model->extractors.back()->pattern =
"an example datetime extractor";
+ model.intent_options.reset(new IntentFactoryModelT);
+ model.intent_options->generator.emplace_back(
+ new IntentFactoryModel_::IntentGeneratorT);
+ const std::string intent_generator1 = "lua generator 1";
+ model.intent_options->generator.back()->lua_template_generator =
+ std::vector<uint8_t>(intent_generator1.begin(), intent_generator1.end());
+ model.intent_options->generator.emplace_back(
+ new IntentFactoryModel_::IntentGeneratorT);
+ const std::string intent_generator2 = "lua generator 2";
+ model.intent_options->generator.back()->lua_template_generator =
+ std::vector<uint8_t>(intent_generator2.begin(), intent_generator2.end());
+
+ // NOTE: The resource strings contain some repetition so that the compressed
+ // version is smaller than the uncompressed one, because the compression code
+ // looks at that as well.
+ model.resources.reset(new ResourcePoolT);
+ model.resources->resource_entry.emplace_back(new ResourceEntryT);
+ model.resources->resource_entry.back()->resource.emplace_back(new ResourceT);
+ model.resources->resource_entry.back()->resource.back()->content =
+ "rrrrrrrrrrrrr1.1";
+ model.resources->resource_entry.back()->resource.emplace_back(new ResourceT);
+ model.resources->resource_entry.back()->resource.back()->content =
+ "rrrrrrrrrrrrr1.2";
+ model.resources->resource_entry.emplace_back(new ResourceEntryT);
+ model.resources->resource_entry.back()->resource.emplace_back(new ResourceT);
+ model.resources->resource_entry.back()->resource.back()->content =
+ "rrrrrrrrrrrrr2.1";
+ model.resources->resource_entry.back()->resource.emplace_back(new ResourceT);
+ model.resources->resource_entry.back()->resource.back()->content =
+ "rrrrrrrrrrrrr2.2";
+
// Compress the model.
EXPECT_TRUE(CompressModel(&model));
@@ -51,6 +82,14 @@ TEST(ZlibUtilsTest, CompressModel) {
EXPECT_TRUE(model.regex_model->patterns[1]->pattern.empty());
EXPECT_TRUE(model.datetime_model->patterns[0]->regexes[0]->pattern.empty());
EXPECT_TRUE(model.datetime_model->extractors[0]->pattern.empty());
+ EXPECT_TRUE(
+ model.intent_options->generator[0]->lua_template_generator.empty());
+ EXPECT_TRUE(
+ model.intent_options->generator[1]->lua_template_generator.empty());
+ EXPECT_TRUE(model.resources->resource_entry[0]->resource[0]->content.empty());
+ EXPECT_TRUE(model.resources->resource_entry[0]->resource[1]->content.empty());
+ EXPECT_TRUE(model.resources->resource_entry[1]->resource[0]->content.empty());
+ EXPECT_TRUE(model.resources->resource_entry[1]->resource[1]->content.empty());
// Pack and load the model.
flatbuffers::FlatBufferBuilder builder;
@@ -94,6 +133,20 @@ TEST(ZlibUtilsTest, CompressModel) {
"an example datetime pattern");
EXPECT_EQ(model.datetime_model->extractors[0]->pattern,
"an example datetime extractor");
+ EXPECT_EQ(
+ model.intent_options->generator[0]->lua_template_generator,
+ std::vector<uint8_t>(intent_generator1.begin(), intent_generator1.end()));
+ EXPECT_EQ(
+ model.intent_options->generator[1]->lua_template_generator,
+ std::vector<uint8_t>(intent_generator2.begin(), intent_generator2.end()));
+ EXPECT_EQ(model.resources->resource_entry[0]->resource[0]->content,
+ "rrrrrrrrrrrrr1.1");
+ EXPECT_EQ(model.resources->resource_entry[0]->resource[1]->content,
+ "rrrrrrrrrrrrr1.2");
+ EXPECT_EQ(model.resources->resource_entry[1]->resource[0]->content,
+ "rrrrrrrrrrrrr2.1");
+ EXPECT_EQ(model.resources->resource_entry[1]->resource[1]->content,
+ "rrrrrrrrrrrrr2.2");
}
} // namespace libtextclassifier3