diff options
author | Tony Mak <tonymak@google.com> | 2019-10-15 15:29:22 +0100 |
---|---|---|
committer | Tony Mak <tonymak@google.com> | 2019-10-15 18:33:02 +0100 |
commit | 8cd7ba6be23c557c608653330d931e6700f19688 (patch) | |
tree | edaf9e6b6a52c247985af97ac94333b05cfb85bf /native/annotator | |
parent | f9143f3090d0e29353d0841f3e892babc947b4d2 (diff) | |
download | libtextclassifier-8cd7ba6be23c557c608653330d931e6700f19688.tar.gz |
Import libtextclassifier
Test: atest TextClassifierServiceTest
Change-Id: Ief715193072d0af3aea230c3c9475ef18e8ac84c
Diffstat (limited to 'native/annotator')
23 files changed, 916 insertions, 384 deletions
diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc index 867eea0..d910c8d 100644 --- a/native/annotator/annotator.cc +++ b/native/annotator/annotator.cc @@ -29,6 +29,7 @@ #include "utils/base/logging.h" #include "utils/checksum.h" #include "utils/math/softmax.h" +#include "utils/normalization.h" #include "utils/optional.h" #include "utils/regex-match.h" #include "utils/utf8/unicodetext.h" @@ -416,7 +417,7 @@ void Annotator::ValidateAndInitialize() { model_->duration_annotator_options()->enabled()) { duration_annotator_.reset( new DurationAnnotator(model_->duration_annotator_options(), - selection_feature_processor_.get())); + selection_feature_processor_.get(), unilib_)); } if (model_->entity_data_schema()) { @@ -505,6 +506,10 @@ bool Annotator::InitializeKnowledgeEngine( TC3_LOG(ERROR) << "Failed to initialize the knowledge engine."; return false; } + if (model_->triggering_options() != nullptr) { + knowledge_engine->SetPriorityScore( + model_->triggering_options()->knowledge_priority_score()); + } knowledge_engine_ = std::move(knowledge_engine); return true; } @@ -2075,8 +2080,19 @@ bool Annotator::SerializedEntityDataFromRegexMatch( // Set entity field from capturing group text. if (group->entity_field_path() != nullptr) { - if (!entity_data->ParseAndSet(group->entity_field_path(), - group_match_text.value())) { + UnicodeText normalized_group_match_text = + UTF8ToUnicodeText(group_match_text.value(), /*do_copy=*/false); + + // Apply normalization if specified. + if (group->normalization_options() != nullptr) { + normalized_group_match_text = + NormalizeText(unilib_, group->normalization_options(), + normalized_group_match_text); + } + + if (!entity_data->ParseAndSet( + group->entity_field_path(), + normalized_group_match_text.ToUTF8String())) { TC3_LOG(ERROR) << "Could not set entity data from rule capturing group."; return false; diff --git a/native/annotator/annotator_jni.cc b/native/annotator/annotator_jni.cc index e5b7833..28be366 100644 --- a/native/annotator/annotator_jni.cc +++ b/native/annotator/annotator_jni.cc @@ -19,6 +19,7 @@ #include "annotator/annotator_jni.h" #include <jni.h> + #include <type_traits> #include <vector> @@ -26,11 +27,12 @@ #include "annotator/annotator_jni_common.h" #include "annotator/types.h" #include "utils/base/integral_types.h" +#include "utils/base/statusor.h" #include "utils/calendar/calendar.h" #include "utils/intents/intent-generator.h" #include "utils/intents/jni.h" #include "utils/java/jni-cache.h" -#include "utils/java/scoped_local_ref.h" +#include "utils/java/jni-helper.h" #include "utils/java/string_utils.h" #include "utils/memory/mmap.h" #include "utils/strings/stringpiece.h" @@ -48,8 +50,10 @@ using libtextclassifier3::AnnotatedSpan; using libtextclassifier3::Annotator; using libtextclassifier3::ClassificationResult; using libtextclassifier3::CodepointSpan; +using libtextclassifier3::JniHelper; using libtextclassifier3::Model; using libtextclassifier3::ScopedLocalRef; +using libtextclassifier3::StatusOr; // When using the Java's ICU, CalendarLib and UniLib need to be instantiated // with a JavaVM pointer from JNI. When using a standard ICU the pointer is // not needed and the objects are instantiated implicitly. @@ -71,6 +75,7 @@ class AnnotatorJniContext { if (jni_cache == nullptr || model == nullptr) { return nullptr; } + // Intent generator will be null if the options are not specified. std::unique_ptr<IntentGenerator> intent_generator = IntentGenerator::Create(model->model()->intent_options(), model->model()->resources(), jni_cache); @@ -79,6 +84,7 @@ class AnnotatorJniContext { if (template_handler == nullptr) { return nullptr; } + return new AnnotatorJniContext(jni_cache, std::move(model), std::move(intent_generator), std::move(template_handler)); @@ -90,6 +96,8 @@ class AnnotatorJniContext { Annotator* model() const { return model_.get(); } + // NOTE: Intent generator will be null if the options are not specified in + // the model. IntentGenerator* intent_generator() const { return intent_generator_.get(); } RemoteActionTemplatesHandler* template_handler() const { @@ -113,184 +121,217 @@ class AnnotatorJniContext { std::unique_ptr<RemoteActionTemplatesHandler> template_handler_; }; -jobject ClassificationResultWithIntentsToJObject( +StatusOr<ScopedLocalRef<jobject>> ClassificationResultWithIntentsToJObject( JNIEnv* env, const AnnotatorJniContext* model_context, jobject app_context, jclass result_class, jmethodID result_class_constructor, jclass datetime_parse_class, jmethodID datetime_parse_class_constructor, const jstring device_locales, const ClassificationOptions* options, const std::string& context, const CodepointSpan& selection_indices, const ClassificationResult& classification_result, bool generate_intents) { - jstring row_string = - env->NewStringUTF(classification_result.collection.c_str()); + TC3_ASSIGN_OR_RETURN( + ScopedLocalRef<jstring> row_string, + JniHelper::NewStringUTF(env, classification_result.collection.c_str())); - jobject row_datetime_parse = nullptr; + ScopedLocalRef<jobject> row_datetime_parse; if (classification_result.datetime_parse_result.IsSet()) { - row_datetime_parse = - env->NewObject(datetime_parse_class, datetime_parse_class_constructor, - classification_result.datetime_parse_result.time_ms_utc, - classification_result.datetime_parse_result.granularity); + TC3_ASSIGN_OR_RETURN( + row_datetime_parse, + JniHelper::NewObject( + env, datetime_parse_class, datetime_parse_class_constructor, + classification_result.datetime_parse_result.time_ms_utc, + classification_result.datetime_parse_result.granularity)); } - jbyteArray serialized_knowledge_result = nullptr; + ScopedLocalRef<jbyteArray> serialized_knowledge_result; const std::string& serialized_knowledge_result_string = classification_result.serialized_knowledge_result; if (!serialized_knowledge_result_string.empty()) { - serialized_knowledge_result = - env->NewByteArray(serialized_knowledge_result_string.size()); - env->SetByteArrayRegion(serialized_knowledge_result, 0, + TC3_ASSIGN_OR_RETURN(serialized_knowledge_result, + JniHelper::NewByteArray( + env, serialized_knowledge_result_string.size())); + env->SetByteArrayRegion(serialized_knowledge_result.get(), 0, serialized_knowledge_result_string.size(), reinterpret_cast<const jbyte*>( serialized_knowledge_result_string.data())); } - jstring contact_name = nullptr; + ScopedLocalRef<jstring> contact_name; if (!classification_result.contact_name.empty()) { - contact_name = - env->NewStringUTF(classification_result.contact_name.c_str()); + TC3_ASSIGN_OR_RETURN(contact_name, + JniHelper::NewStringUTF( + env, classification_result.contact_name.c_str())); } - jstring contact_given_name = nullptr; + ScopedLocalRef<jstring> contact_given_name; if (!classification_result.contact_given_name.empty()) { - contact_given_name = - env->NewStringUTF(classification_result.contact_given_name.c_str()); + TC3_ASSIGN_OR_RETURN( + contact_given_name, + JniHelper::NewStringUTF( + env, classification_result.contact_given_name.c_str())); } - jstring contact_family_name = nullptr; + ScopedLocalRef<jstring> contact_family_name; if (!classification_result.contact_family_name.empty()) { - contact_family_name = - env->NewStringUTF(classification_result.contact_family_name.c_str()); + TC3_ASSIGN_OR_RETURN( + contact_family_name, + JniHelper::NewStringUTF( + env, classification_result.contact_family_name.c_str())); } - jstring contact_nickname = nullptr; + ScopedLocalRef<jstring> contact_nickname; if (!classification_result.contact_nickname.empty()) { - contact_nickname = - env->NewStringUTF(classification_result.contact_nickname.c_str()); + TC3_ASSIGN_OR_RETURN( + contact_nickname, + JniHelper::NewStringUTF( + env, classification_result.contact_nickname.c_str())); } - jstring contact_email_address = nullptr; + ScopedLocalRef<jstring> contact_email_address; if (!classification_result.contact_email_address.empty()) { - contact_email_address = - env->NewStringUTF(classification_result.contact_email_address.c_str()); + TC3_ASSIGN_OR_RETURN( + contact_email_address, + JniHelper::NewStringUTF( + env, classification_result.contact_email_address.c_str())); } - jstring contact_phone_number = nullptr; + ScopedLocalRef<jstring> contact_phone_number; if (!classification_result.contact_phone_number.empty()) { - contact_phone_number = - env->NewStringUTF(classification_result.contact_phone_number.c_str()); + TC3_ASSIGN_OR_RETURN( + contact_phone_number, + JniHelper::NewStringUTF( + env, classification_result.contact_phone_number.c_str())); } - jstring contact_id = nullptr; + ScopedLocalRef<jstring> contact_id; if (!classification_result.contact_id.empty()) { - contact_id = env->NewStringUTF(classification_result.contact_id.c_str()); + TC3_ASSIGN_OR_RETURN( + contact_id, + JniHelper::NewStringUTF(env, classification_result.contact_id.c_str())); } - jstring app_name = nullptr; + ScopedLocalRef<jstring> app_name; if (!classification_result.app_name.empty()) { - app_name = env->NewStringUTF(classification_result.app_name.c_str()); + TC3_ASSIGN_OR_RETURN( + app_name, + JniHelper::NewStringUTF(env, classification_result.app_name.c_str())); } - jstring app_package_name = nullptr; + ScopedLocalRef<jstring> app_package_name; if (!classification_result.app_package_name.empty()) { - app_package_name = - env->NewStringUTF(classification_result.app_package_name.c_str()); + TC3_ASSIGN_OR_RETURN( + app_package_name, + JniHelper::NewStringUTF( + env, classification_result.app_package_name.c_str())); } - jobject extras = nullptr; + ScopedLocalRef<jobjectArray> extras; if (model_context->model()->entity_data_schema() != nullptr && !classification_result.serialized_entity_data.empty()) { - extras = model_context->template_handler()->EntityDataAsNamedVariantArray( - model_context->model()->entity_data_schema(), - classification_result.serialized_entity_data); + TC3_ASSIGN_OR_RETURN( + extras, + model_context->template_handler()->EntityDataAsNamedVariantArray( + model_context->model()->entity_data_schema(), + classification_result.serialized_entity_data)); } - jbyteArray serialized_entity_data = nullptr; + ScopedLocalRef<jbyteArray> serialized_entity_data; if (!classification_result.serialized_entity_data.empty()) { - serialized_entity_data = - env->NewByteArray(classification_result.serialized_entity_data.size()); + TC3_ASSIGN_OR_RETURN( + serialized_entity_data, + JniHelper::NewByteArray( + env, classification_result.serialized_entity_data.size())); env->SetByteArrayRegion( - serialized_entity_data, 0, + serialized_entity_data.get(), 0, classification_result.serialized_entity_data.size(), reinterpret_cast<const jbyte*>( classification_result.serialized_entity_data.data())); } - jobject remote_action_templates_result = nullptr; + ScopedLocalRef<jobjectArray> remote_action_templates_result; // Only generate RemoteActionTemplate for the top classification result // as classifyText does not need RemoteAction from other results anyway. if (generate_intents && model_context->intent_generator() != nullptr) { std::vector<RemoteActionTemplate> remote_action_templates; - if (model_context->intent_generator()->GenerateIntents( + if (!model_context->intent_generator()->GenerateIntents( device_locales, classification_result, options->reference_time_ms_utc, context, selection_indices, app_context, model_context->model()->entity_data_schema(), &remote_action_templates)) { - remote_action_templates_result = - model_context->template_handler() - ->RemoteActionTemplatesToJObjectArray(remote_action_templates); + return {Status::UNKNOWN}; } - } - return env->NewObject( - result_class, result_class_constructor, row_string, - static_cast<jfloat>(classification_result.score), row_datetime_parse, - serialized_knowledge_result, contact_name, contact_given_name, - contact_family_name, contact_nickname, contact_email_address, - contact_phone_number, contact_id, app_name, app_package_name, extras, - serialized_entity_data, remote_action_templates_result, - classification_result.duration_ms, classification_result.numeric_value, + TC3_ASSIGN_OR_RETURN( + remote_action_templates_result, + model_context->template_handler()->RemoteActionTemplatesToJObjectArray( + remote_action_templates)); + } + + return JniHelper::NewObject( + env, result_class, result_class_constructor, row_string.get(), + static_cast<jfloat>(classification_result.score), + row_datetime_parse.get(), serialized_knowledge_result.get(), + contact_name.get(), contact_given_name.get(), contact_family_name.get(), + contact_nickname.get(), contact_email_address.get(), + contact_phone_number.get(), contact_id.get(), app_name.get(), + app_package_name.get(), extras.get(), serialized_entity_data.get(), + remote_action_templates_result.get(), classification_result.duration_ms, + classification_result.numeric_value, classification_result.numeric_double_value); } -jobjectArray ClassificationResultsWithIntentsToJObjectArray( +StatusOr<ScopedLocalRef<jobjectArray>> +ClassificationResultsWithIntentsToJObjectArray( JNIEnv* env, const AnnotatorJniContext* model_context, jobject app_context, const jstring device_locales, const ClassificationOptions* options, const std::string& context, const CodepointSpan& selection_indices, const std::vector<ClassificationResult>& classification_result, bool generate_intents) { - const ScopedLocalRef<jclass> result_class( - env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR - "$ClassificationResult"), - env); - if (!result_class) { - TC3_LOG(ERROR) << "Couldn't find ClassificationResult class."; - return nullptr; - } - const ScopedLocalRef<jclass> datetime_parse_class( - env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR - "$DatetimeResult"), - env); - if (!datetime_parse_class) { - TC3_LOG(ERROR) << "Couldn't find DatetimeResult class."; - return nullptr; - } - - const jmethodID result_class_constructor = env->GetMethodID( - result_class.get(), "<init>", - "(Ljava/lang/String;FL" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR - "$DatetimeResult;[BLjava/lang/String;Ljava/lang/String;Ljava/lang/String;" - "Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;" - "Ljava/lang/String;Ljava/lang/String;[L" TC3_PACKAGE_PATH - "" TC3_NAMED_VARIANT_CLASS_NAME_STR ";[B[L" TC3_PACKAGE_PATH - "" TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";JJD)V"); - const jmethodID datetime_parse_class_constructor = - env->GetMethodID(datetime_parse_class.get(), "<init>", "(JI)V"); + TC3_ASSIGN_OR_RETURN( + ScopedLocalRef<jclass> result_class, + JniHelper::FindClass(env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR + "$ClassificationResult")); + + TC3_ASSIGN_OR_RETURN( + ScopedLocalRef<jclass> datetime_parse_class, + JniHelper::FindClass(env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR + "$DatetimeResult")); + + TC3_ASSIGN_OR_RETURN( + const jmethodID result_class_constructor, + JniHelper::GetMethodID( + env, result_class.get(), "<init>", + "(Ljava/lang/String;FL" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR + "$DatetimeResult;[BLjava/lang/String;Ljava/lang/String;Ljava/lang/" + "String;" + "Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/" + "String;" + "Ljava/lang/String;Ljava/lang/String;[L" TC3_PACKAGE_PATH + "" TC3_NAMED_VARIANT_CLASS_NAME_STR ";[B[L" TC3_PACKAGE_PATH + "" TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";JJD)V")); + TC3_ASSIGN_OR_RETURN(const jmethodID datetime_parse_class_constructor, + JniHelper::GetMethodID(env, datetime_parse_class.get(), + "<init>", "(JI)V")); + + TC3_ASSIGN_OR_RETURN( + ScopedLocalRef<jobjectArray> results, + JniHelper::NewObjectArray(env, classification_result.size(), + result_class.get())); - const jobjectArray results = env->NewObjectArray(classification_result.size(), - result_class.get(), nullptr); for (int i = 0; i < classification_result.size(); i++) { - jobject result = ClassificationResultWithIntentsToJObject( - env, model_context, app_context, result_class.get(), - result_class_constructor, datetime_parse_class.get(), - datetime_parse_class_constructor, device_locales, options, context, - selection_indices, classification_result[i], - generate_intents && (i == 0)); - env->SetObjectArrayElement(results, i, result); - env->DeleteLocalRef(result); + TC3_ASSIGN_OR_RETURN( + ScopedLocalRef<jobject> result, + ClassificationResultWithIntentsToJObject( + env, model_context, app_context, result_class.get(), + result_class_constructor, datetime_parse_class.get(), + datetime_parse_class_constructor, device_locales, options, context, + selection_indices, classification_result[i], + generate_intents && (i == 0))); + env->SetObjectArrayElement(results.get(), i, result.get()); } return results; } -jobjectArray ClassificationResultsToJObjectArray( +StatusOr<ScopedLocalRef<jobjectArray>> ClassificationResultsToJObjectArray( JNIEnv* env, const AnnotatorJniContext* model_context, const std::vector<ClassificationResult>& classification_result) { return ClassificationResultsWithIntentsToJObjectArray( @@ -361,16 +402,18 @@ CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str, return ConvertIndicesBMPUTF8(utf8_str, utf8_indices, /*from_utf8=*/true); } -jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) { +StatusOr<ScopedLocalRef<jstring>> GetLocalesFromMmap( + JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) { if (!mmap->handle().ok()) { - return env->NewStringUTF(""); + return JniHelper::NewStringUTF(env, ""); } const Model* model = libtextclassifier3::ViewModel( mmap->handle().start(), mmap->handle().num_bytes()); if (!model || !model->locales()) { - return env->NewStringUTF(""); + return JniHelper::NewStringUTF(env, ""); } - return env->NewStringUTF(model->locales()->c_str()); + + return JniHelper::NewStringUTF(env, model->locales()->c_str()); } jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) { @@ -385,16 +428,17 @@ jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) { return model->version(); } -jstring GetNameFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) { +StatusOr<ScopedLocalRef<jstring>> GetNameFromMmap( + JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) { if (!mmap->handle().ok()) { - return env->NewStringUTF(""); + return JniHelper::NewStringUTF(env, ""); } const Model* model = libtextclassifier3::ViewModel( mmap->handle().start(), mmap->handle().num_bytes()); if (!model || !model->name()) { - return env->NewStringUTF(""); + return JniHelper::NewStringUTF(env, ""); } - return env->NewStringUTF(model->name()->c_str()); + return JniHelper::NewStringUTF(env, model->name()->c_str()); } } // namespace libtextclassifier3 @@ -427,7 +471,7 @@ TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotator) TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorFromPath) (JNIEnv* env, jobject thiz, jstring path) { - const std::string path_str = ToStlString(env, path); + TC3_ASSIGN_OR_RETURN_0(const std::string path_str, ToStlString(env, path)); std::shared_ptr<libtextclassifier3::JniCache> jni_cache( libtextclassifier3::JniCache::Create(env)); #ifdef TC3_USE_JAVAICU @@ -531,17 +575,22 @@ TC3_JNI_METHOD(jintArray, TC3_ANNOTATOR_CLASS_NAME, nativeSuggestSelection) return nullptr; } const Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model(); - const std::string context_utf8 = ToStlString(env, context); + TC3_ASSIGN_OR_RETURN_NULL(const std::string context_utf8, + ToStlString(env, context)); CodepointSpan input_indices = ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end}); - CodepointSpan selection = model->SuggestSelection( - context_utf8, input_indices, FromJavaSelectionOptions(env, options)); + TC3_ASSIGN_OR_RETURN_NULL( + libtextclassifier3::SelectionOptions selection_options, + FromJavaSelectionOptions(env, options)); + CodepointSpan selection = + model->SuggestSelection(context_utf8, input_indices, selection_options); selection = ConvertIndicesUTF8ToBMP(context_utf8, selection); - jintArray result = env->NewIntArray(2); - env->SetIntArrayRegion(result, 0, 1, &(std::get<0>(selection))); - env->SetIntArrayRegion(result, 1, 1, &(std::get<1>(selection))); - return result; + TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jintArray> result, + JniHelper::NewIntArray(env, 2)); + env->SetIntArrayRegion(result.get(), 0, 1, &(std::get<0>(selection))); + env->SetIntArrayRegion(result.get(), 1, 1, &(std::get<1>(selection))); + return result.release(); } TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeClassifyText) @@ -554,23 +603,33 @@ TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeClassifyText) const AnnotatorJniContext* model_context = reinterpret_cast<AnnotatorJniContext*>(ptr); - const std::string context_utf8 = ToStlString(env, context); + TC3_ASSIGN_OR_RETURN_NULL(const std::string context_utf8, + ToStlString(env, context)); const CodepointSpan input_indices = ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end}); - const libtextclassifier3::ClassificationOptions classification_options = - FromJavaClassificationOptions(env, options); + TC3_ASSIGN_OR_RETURN_NULL( + const libtextclassifier3::ClassificationOptions classification_options, + FromJavaClassificationOptions(env, options)); const std::vector<ClassificationResult> classification_result = model_context->model()->ClassifyText(context_utf8, input_indices, classification_options); + + ScopedLocalRef<jobjectArray> result; if (app_context != nullptr) { - return ClassificationResultsWithIntentsToJObjectArray( - env, model_context, app_context, device_locales, - &classification_options, context_utf8, input_indices, - classification_result, - /*generate_intents=*/true); - } - return ClassificationResultsToJObjectArray(env, model_context, - classification_result); + TC3_ASSIGN_OR_RETURN_NULL( + result, ClassificationResultsWithIntentsToJObjectArray( + env, model_context, app_context, device_locales, + &classification_options, context_utf8, input_indices, + classification_result, + /*generate_intents=*/true)); + + } else { + TC3_ASSIGN_OR_RETURN_NULL( + result, ClassificationResultsToJObjectArray(env, model_context, + classification_result)); + } + + return result.release(); } TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotate) @@ -580,41 +639,46 @@ TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotate) } const AnnotatorJniContext* model_context = reinterpret_cast<AnnotatorJniContext*>(ptr); - const std::string context_utf8 = ToStlString(env, context); + TC3_ASSIGN_OR_RETURN_NULL(const std::string context_utf8, + ToStlString(env, context)); + TC3_ASSIGN_OR_RETURN_NULL( + libtextclassifier3::AnnotationOptions annotation_options, + FromJavaAnnotationOptions(env, options)); const std::vector<AnnotatedSpan> annotations = - model_context->model()->Annotate(context_utf8, - FromJavaAnnotationOptions(env, options)); - - jclass result_class = env->FindClass( - TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotatedSpan"); - if (!result_class) { - TC3_LOG(ERROR) << "Couldn't find result class: " - << TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR - "$AnnotatedSpan"; - return nullptr; - } + model_context->model()->Annotate(context_utf8, annotation_options); + + TC3_ASSIGN_OR_RETURN_NULL( + ScopedLocalRef<jclass> result_class, + JniHelper::FindClass( + env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotatedSpan")); jmethodID result_class_constructor = - env->GetMethodID(result_class, "<init>", + env->GetMethodID(result_class.get(), "<init>", "(II[L" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$ClassificationResult;)V"); - jobjectArray results = - env->NewObjectArray(annotations.size(), result_class, nullptr); + TC3_ASSIGN_OR_RETURN_NULL( + ScopedLocalRef<jobjectArray> results, + JniHelper::NewObjectArray(env, annotations.size(), result_class.get())); for (int i = 0; i < annotations.size(); ++i) { CodepointSpan span_bmp = ConvertIndicesUTF8ToBMP(context_utf8, annotations[i].span); - jobject result = env->NewObject( - result_class, result_class_constructor, - static_cast<jint>(span_bmp.first), static_cast<jint>(span_bmp.second), + + TC3_ASSIGN_OR_RETURN_NULL( + ScopedLocalRef<jobjectArray> classification_results, ClassificationResultsToJObjectArray(env, model_context, annotations[i].classification)); - env->SetObjectArrayElement(results, i, result); - env->DeleteLocalRef(result); + + TC3_ASSIGN_OR_RETURN_NULL( + ScopedLocalRef<jobject> result, + JniHelper::NewObject(env, result_class.get(), result_class_constructor, + static_cast<jint>(span_bmp.first), + static_cast<jint>(span_bmp.second), + classification_results.get())); + env->SetObjectArrayElement(results.get(), i, result.get()); } - env->DeleteLocalRef(result_class); - return results; + return results.release(); } TC3_JNI_METHOD(jbyteArray, TC3_ANNOTATOR_CLASS_NAME, @@ -624,16 +688,19 @@ TC3_JNI_METHOD(jbyteArray, TC3_ANNOTATOR_CLASS_NAME, return nullptr; } const Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model(); - const std::string id_utf8 = ToStlString(env, id); + TC3_ASSIGN_OR_RETURN_NULL(const std::string id_utf8, ToStlString(env, id)); std::string serialized_knowledge_result; if (!model->LookUpKnowledgeEntity(id_utf8, &serialized_knowledge_result)) { return nullptr; } - jbyteArray result = env->NewByteArray(serialized_knowledge_result.size()); + + TC3_ASSIGN_OR_RETURN_NULL( + ScopedLocalRef<jbyteArray> result, + JniHelper::NewByteArray(env, serialized_knowledge_result.size())); env->SetByteArrayRegion( - result, 0, serialized_knowledge_result.size(), + result.get(), 0, serialized_knowledge_result.size(), reinterpret_cast<const jbyte*>(serialized_knowledge_result.data())); - return result; + return result.release(); } TC3_JNI_METHOD(void, TC3_ANNOTATOR_CLASS_NAME, nativeCloseAnnotator) @@ -654,14 +721,18 @@ TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocales) (JNIEnv* env, jobject clazz, jint fd) { const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( new libtextclassifier3::ScopedMmap(fd)); - return GetLocalesFromMmap(env, mmap.get()); + TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jstring> value, + GetLocalesFromMmap(env, mmap.get())); + return value.release(); } TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocalesWithOffset) (JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size) { const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( new libtextclassifier3::ScopedMmap(fd, offset, size)); - return GetLocalesFromMmap(env, mmap.get()); + TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jstring> value, + GetLocalesFromMmap(env, mmap.get())); + return value.release(); } TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersion) @@ -682,12 +753,16 @@ TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetName) (JNIEnv* env, jobject clazz, jint fd) { const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( new libtextclassifier3::ScopedMmap(fd)); - return GetNameFromMmap(env, mmap.get()); + TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jstring> value, + GetNameFromMmap(env, mmap.get())); + return value.release(); } TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetNameWithOffset) (JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size) { const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( new libtextclassifier3::ScopedMmap(fd, offset, size)); - return GetNameFromMmap(env, mmap.get()); + TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jstring> value, + GetNameFromMmap(env, mmap.get())); + return value.release(); } diff --git a/native/annotator/annotator_jni_common.cc b/native/annotator/annotator_jni_common.cc index 55f14e6..575e71b 100644 --- a/native/annotator/annotator_jni_common.cc +++ b/native/annotator/annotator_jni_common.cc @@ -17,138 +17,176 @@ #include "annotator/annotator_jni_common.h" #include "utils/java/jni-base.h" -#include "utils/java/scoped_local_ref.h" +#include "utils/java/jni-helper.h" namespace libtextclassifier3 { namespace { -std::unordered_set<std::string> EntityTypesFromJObject(JNIEnv* env, - const jobject& jobject) { +StatusOr<std::unordered_set<std::string>> EntityTypesFromJObject( + JNIEnv* env, const jobject& jobject) { std::unordered_set<std::string> entity_types; jobjectArray jentity_types = reinterpret_cast<jobjectArray>(jobject); const int size = env->GetArrayLength(jentity_types); for (int i = 0; i < size; ++i) { - jstring jentity_type = - reinterpret_cast<jstring>(env->GetObjectArrayElement(jentity_types, i)); - entity_types.insert(ToStlString(env, jentity_type)); + TC3_ASSIGN_OR_RETURN( + ScopedLocalRef<jstring> jentity_type, + JniHelper::GetObjectArrayElement<jstring>(env, jentity_types, i)); + TC3_ASSIGN_OR_RETURN(std::string entity_type, + ToStlString(env, jentity_type.get())); + entity_types.insert(entity_type); } return entity_types; } template <typename T> -T FromJavaOptionsInternal(JNIEnv* env, jobject joptions, - const std::string& class_name) { +StatusOr<T> FromJavaOptionsInternal(JNIEnv* env, jobject joptions, + const std::string& class_name) { if (!joptions) { - return {}; + return {Status::UNKNOWN}; } - const ScopedLocalRef<jclass> options_class(env->FindClass(class_name.c_str()), - env); - if (!options_class) { - return {}; - } - - const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>( - env, joptions, options_class.get(), &JNIEnv::CallObjectMethod, - "getLocale", "Ljava/lang/String;"); - const std::pair<bool, jobject> status_or_reference_timezone = - CallJniMethod0<jobject>(env, joptions, options_class.get(), - &JNIEnv::CallObjectMethod, "getReferenceTimezone", - "Ljava/lang/String;"); - const std::pair<bool, int64> status_or_reference_time_ms_utc = - CallJniMethod0<int64>(env, joptions, options_class.get(), - &JNIEnv::CallLongMethod, "getReferenceTimeMsUtc", - "J"); - const std::pair<bool, jobject> status_or_detected_text_language_tags = - CallJniMethod0<jobject>( - env, joptions, options_class.get(), &JNIEnv::CallObjectMethod, - "getDetectedTextLanguageTags", "Ljava/lang/String;"); - const std::pair<bool, int> status_or_annotation_usecase = - CallJniMethod0<int>(env, joptions, options_class.get(), - &JNIEnv::CallIntMethod, "getAnnotationUsecase", "I"); - - if (!status_or_locales.first || !status_or_reference_timezone.first || - !status_or_reference_time_ms_utc.first || - !status_or_detected_text_language_tags.first || - !status_or_annotation_usecase.first) { - return {}; - } + TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jclass> options_class, + JniHelper::FindClass(env, class_name.c_str())); + + // .getLocale() + TC3_ASSIGN_OR_RETURN( + jmethodID get_locale, + JniHelper::GetMethodID(env, options_class.get(), "getLocale", + "()Ljava/lang/String;")); + TC3_ASSIGN_OR_RETURN( + ScopedLocalRef<jstring> locales, + JniHelper::CallObjectMethod<jstring>(env, joptions, get_locale)); + + // .getReferenceTimeMsUtc() + TC3_ASSIGN_OR_RETURN(jmethodID get_reference_time_method, + JniHelper::GetMethodID(env, options_class.get(), + "getReferenceTimeMsUtc", "()J")); + TC3_ASSIGN_OR_RETURN( + int64 reference_time, + JniHelper::CallLongMethod(env, joptions, get_reference_time_method)); + + // .getReferenceTimezone() + TC3_ASSIGN_OR_RETURN( + jmethodID get_reference_timezone_method, + JniHelper::GetMethodID(env, options_class.get(), "getReferenceTimezone", + "()Ljava/lang/String;")); + TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> reference_timezone, + JniHelper::CallObjectMethod<jstring>( + env, joptions, get_reference_timezone_method)); + + // .getDetectedTextLanguageTags() + TC3_ASSIGN_OR_RETURN(jmethodID get_detected_text_language_tags_method, + JniHelper::GetMethodID(env, options_class.get(), + "getDetectedTextLanguageTags", + "()Ljava/lang/String;")); + TC3_ASSIGN_OR_RETURN( + ScopedLocalRef<jstring> detected_text_language_tags, + JniHelper::CallObjectMethod<jstring>( + env, joptions, get_detected_text_language_tags_method)); + + // .getAnnotationUsecase() + TC3_ASSIGN_OR_RETURN(jmethodID get_annotation_usecase, + JniHelper::GetMethodID(env, options_class.get(), + "getAnnotationUsecase", "()I")); + TC3_ASSIGN_OR_RETURN( + int32 annotation_usecase, + JniHelper::CallIntMethod(env, joptions, get_annotation_usecase)); T options; - options.locales = - ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second)); - options.reference_timezone = ToStlString( - env, reinterpret_cast<jstring>(status_or_reference_timezone.second)); - options.reference_time_ms_utc = status_or_reference_time_ms_utc.second; - options.detected_text_language_tags = ToStlString( - env, - reinterpret_cast<jstring>(status_or_detected_text_language_tags.second)); + TC3_ASSIGN_OR_RETURN(options.locales, ToStlString(env, locales.get())); + TC3_ASSIGN_OR_RETURN(options.reference_timezone, + ToStlString(env, reference_timezone.get())); + options.reference_time_ms_utc = reference_time; + TC3_ASSIGN_OR_RETURN(options.detected_text_language_tags, + ToStlString(env, detected_text_language_tags.get())); options.annotation_usecase = - static_cast<AnnotationUsecase>(status_or_annotation_usecase.second); + static_cast<AnnotationUsecase>(annotation_usecase); return options; } } // namespace -SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions) { +StatusOr<SelectionOptions> FromJavaSelectionOptions(JNIEnv* env, + jobject joptions) { if (!joptions) { - return {}; + return {Status::UNKNOWN}; } - const ScopedLocalRef<jclass> options_class( - env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR - "$SelectionOptions"), - env); - const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>( - env, joptions, options_class.get(), &JNIEnv::CallObjectMethod, - "getLocales", "Ljava/lang/String;"); - const std::pair<bool, int> status_or_annotation_usecase = - CallJniMethod0<int>(env, joptions, options_class.get(), - &JNIEnv::CallIntMethod, "getAnnotationUsecase", "I"); - if (!status_or_locales.first || !status_or_annotation_usecase.first) { - return {}; - } + TC3_ASSIGN_OR_RETURN( + ScopedLocalRef<jclass> options_class, + JniHelper::FindClass(env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR + "$SelectionOptions")); + + // .getLocale() + TC3_ASSIGN_OR_RETURN( + jmethodID get_locales, + JniHelper::GetMethodID(env, options_class.get(), "getLocales", + "()Ljava/lang/String;")); + TC3_ASSIGN_OR_RETURN( + ScopedLocalRef<jstring> locales, + JniHelper::CallObjectMethod<jstring>(env, joptions, get_locales)); + + // .getAnnotationUsecase() + TC3_ASSIGN_OR_RETURN(jmethodID get_annotation_usecase, + JniHelper::GetMethodID(env, options_class.get(), + "getAnnotationUsecase", "()I")); + TC3_ASSIGN_OR_RETURN( + int32 annotation_usecase, + JniHelper::CallIntMethod(env, joptions, get_annotation_usecase)); SelectionOptions options; - options.locales = - ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second)); + TC3_ASSIGN_OR_RETURN(options.locales, ToStlString(env, locales.get())); options.annotation_usecase = - static_cast<AnnotationUsecase>(status_or_annotation_usecase.second); + static_cast<AnnotationUsecase>(annotation_usecase); return options; } -ClassificationOptions FromJavaClassificationOptions(JNIEnv* env, - jobject joptions) { +StatusOr<ClassificationOptions> FromJavaClassificationOptions( + JNIEnv* env, jobject joptions) { return FromJavaOptionsInternal<ClassificationOptions>( env, joptions, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$ClassificationOptions"); } -AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions) { - if (!joptions) return {}; - const ScopedLocalRef<jclass> options_class( - env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR - "$AnnotationOptions"), - env); - if (!options_class) return {}; - const std::pair<bool, jobject> status_or_entity_types = - CallJniMethod0<jobject>(env, joptions, options_class.get(), - &JNIEnv::CallObjectMethod, "getEntityTypes", - "[Ljava/lang/String;"); - if (!status_or_entity_types.first) return {}; - const std::pair<bool, bool> status_or_enable_serialized_entity_data = - CallJniMethod0<bool>(env, joptions, options_class.get(), - &JNIEnv::CallBooleanMethod, - "isSerializedEntityDataEnabled", "Z"); - if (!status_or_enable_serialized_entity_data.first) return {}; - AnnotationOptions annotation_options = +StatusOr<AnnotationOptions> FromJavaAnnotationOptions(JNIEnv* env, + jobject joptions) { + if (!joptions) { + return {Status::UNKNOWN}; + } + + TC3_ASSIGN_OR_RETURN( + ScopedLocalRef<jclass> options_class, + JniHelper::FindClass(env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR + "$AnnotationOptions")); + + // .getEntityTypes() + TC3_ASSIGN_OR_RETURN( + jmethodID get_entity_types, + JniHelper::GetMethodID(env, options_class.get(), "getEntityTypes", + "()[Ljava/lang/String;")); + TC3_ASSIGN_OR_RETURN( + ScopedLocalRef<jobject> entity_types, + JniHelper::CallObjectMethod<jobject>(env, joptions, get_entity_types)); + + // .isSerializedEntityDataEnabled() + TC3_ASSIGN_OR_RETURN( + jmethodID is_serialized_entity_data_enabled_method, + JniHelper::GetMethodID(env, options_class.get(), + "isSerializedEntityDataEnabled", "()Z")); + TC3_ASSIGN_OR_RETURN( + bool is_serialized_entity_data_enabled, + JniHelper::CallBooleanMethod(env, joptions, + is_serialized_entity_data_enabled_method)); + + TC3_ASSIGN_OR_RETURN( + AnnotationOptions annotation_options, FromJavaOptionsInternal<AnnotationOptions>( env, joptions, - TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotationOptions"); - annotation_options.entity_types = - EntityTypesFromJObject(env, status_or_entity_types.second); + TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotationOptions")); + TC3_ASSIGN_OR_RETURN(annotation_options.entity_types, + EntityTypesFromJObject(env, entity_types.get())); annotation_options.is_serialized_entity_data_enabled = - status_or_enable_serialized_entity_data.second; + is_serialized_entity_data_enabled; return annotation_options; } diff --git a/native/annotator/annotator_jni_common.h b/native/annotator/annotator_jni_common.h index b62bb21..f1f1d88 100644 --- a/native/annotator/annotator_jni_common.h +++ b/native/annotator/annotator_jni_common.h @@ -20,6 +20,7 @@ #include <jni.h> #include "annotator/annotator.h" +#include "utils/base/statusor.h" #ifndef TC3_ANNOTATOR_CLASS_NAME #define TC3_ANNOTATOR_CLASS_NAME AnnotatorModel @@ -29,12 +30,14 @@ namespace libtextclassifier3 { -SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions); - -ClassificationOptions FromJavaClassificationOptions(JNIEnv* env, +StatusOr<SelectionOptions> FromJavaSelectionOptions(JNIEnv* env, jobject joptions); -AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions); +StatusOr<ClassificationOptions> FromJavaClassificationOptions(JNIEnv* env, + jobject joptions); + +StatusOr<AnnotationOptions> FromJavaAnnotationOptions(JNIEnv* env, + jobject joptions); } // namespace libtextclassifier3 diff --git a/native/annotator/datetime/parser.cc b/native/annotator/datetime/parser.cc index 0f222bd..6c759e7 100644 --- a/native/annotator/datetime/parser.cc +++ b/native/annotator/datetime/parser.cc @@ -92,7 +92,7 @@ DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib& unilib, } if (model->locales() != nullptr) { - for (int i = 0; i < model->locales()->Length(); ++i) { + for (int i = 0; i < model->locales()->size(); ++i) { locale_string_to_id_[model->locales()->Get(i)->str()] = i; } } @@ -106,6 +106,8 @@ DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib& unilib, use_extractors_for_locating_ = model->use_extractors_for_locating(); generate_alternative_interpretations_when_ambiguous_ = model->generate_alternative_interpretations_when_ambiguous(); + prefer_future_for_unspecified_date_ = + model->prefer_future_for_unspecified_date(); initialized_ = true; } @@ -433,7 +435,8 @@ bool DatetimeParser::ExtractDatetime(const CompiledRule& rule, // response. For Details see b/130355975 if (!calendarlib_.InterpretParseData( interpretation, reference_time_ms_utc, reference_timezone, - reference_locale, &(result.time_ms_utc), &(result.granularity))) { + reference_locale, prefer_future_for_unspecified_date_, + &(result.time_ms_utc), &(result.granularity))) { return false; } diff --git a/native/annotator/datetime/parser.h b/native/annotator/datetime/parser.h index 4e995bd..a5192d3 100644 --- a/native/annotator/datetime/parser.h +++ b/native/annotator/datetime/parser.h @@ -59,12 +59,6 @@ class DatetimeParser { bool anchor_start_end, std::vector<DatetimeParseResultSpan>* results) const; -#ifdef TC3_TEST_ONLY - void TestOnlySetGenerateAlternativeInterpretationsWhenAmbiguous(bool value) { - generate_alternative_interpretations_when_ambiguous_ = value; - } -#endif // TC3_TEST_ONLY - protected: DatetimeParser(const DatetimeModel* model, const UniLib& unilib, const CalendarLib& calendarlib, @@ -126,6 +120,7 @@ class DatetimeParser { std::vector<int> default_locale_ids_; bool use_extractors_for_locating_; bool generate_alternative_interpretations_when_ambiguous_; + bool prefer_future_for_unspecified_date_; }; } // namespace libtextclassifier3 diff --git a/native/annotator/datetime/parser_test.cc b/native/annotator/datetime/parser_test.cc index 35c725f..1ddcf50 100644 --- a/native/annotator/datetime/parser_test.cc +++ b/native/annotator/datetime/parser_test.cc @@ -14,20 +14,21 @@ * limitations under the License. */ +#include "annotator/datetime/parser.h" + #include <time.h> + #include <fstream> #include <iostream> #include <memory> #include <string> -#include "gmock/gmock.h" -#include "gtest/gtest.h" - #include "annotator/annotator.h" -#include "annotator/datetime/parser.h" #include "annotator/model_generated.h" #include "annotator/types-test-util.h" #include "utils/testing/annotator.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" using std::vector; using testing::ElementsAreArray; @@ -152,7 +153,7 @@ class ParserTest : public testing::Test { const int expected_start_index = std::distance(marked_text_unicode.begin(), brace_open_it); - // The -1 bellow is to account for the opening bracket character. + // The -1 below is to account for the opening bracket character. const int expected_end_index = std::distance(marked_text_unicode.begin(), brace_end_it) - 1; @@ -746,6 +747,43 @@ TEST_F(ParserTest, ParseWithRawUsecase) { /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_SMART)); } +TEST_F(ParserTest, AddsADayWhenTimeInThePastAndDayNotSpecified) { + // ParsesCorrectly uses 0 as the reference time, which corresponds to: + // "Thu Jan 01 1970 01:00:00" Zurich time. So if we pass "0:30" here, it means + // it is in the past, and so the parser should move this to the next day -> + // "Fri Jan 02 1970 00:30:00" Zurich time (b/139112907). + EXPECT_TRUE(ParsesCorrectly( + "{0:30am}", 84600000L /* 23.5 hours from reference time */, + GRANULARITY_MINUTE, + {DatetimeComponentsBuilder() + .Add(DatetimeComponent::ComponentType::MERIDIEM, 0) + .Add(DatetimeComponent::ComponentType::MINUTE, 30) + .Add(DatetimeComponent::ComponentType::HOUR, 0) + .Build()})); +} + +TEST_F(ParserTest, DoesNotAddADayWhenTimeInThePastAndDayNotSpecifiedDisabled) { + // ParsesCorrectly uses 0 as the reference time, which corresponds to: + // "Thu Jan 01 1970 01:00:00" Zurich time. So if we pass "0:30" here, it means + // it is in the past. The parameter prefer_future_when_unspecified_day is + // disabled, so the parser should annotate this to the same day: "Thu Jan 01 + // 1970 00:30:00" Zurich time. + LoadModel([](ModelT* model) { + // In the test model, the prefer_future_for_unspecified_date is true; make + // it false only for this test. + model->datetime_model->prefer_future_for_unspecified_date = false; + }); + + EXPECT_TRUE(ParsesCorrectly( + "{0:30am}", -1800000L /* -30 minutes from reference time */, + GRANULARITY_MINUTE, + {DatetimeComponentsBuilder() + .Add(DatetimeComponent::ComponentType::MERIDIEM, 0) + .Add(DatetimeComponent::ComponentType::MINUTE, 30) + .Add(DatetimeComponent::ComponentType::HOUR, 0) + .Build()})); +} + TEST_F(ParserTest, ParsesNoonAndMidnightCorrectly) { EXPECT_TRUE(ParsesCorrectly( "{January 1, 1988 12:30am}", 567991800000, GRANULARITY_MINUTE, diff --git a/native/annotator/duration/duration.cc b/native/annotator/duration/duration.cc index 3529691..907a1a4 100644 --- a/native/annotator/duration/duration.cc +++ b/native/annotator/duration/duration.cc @@ -23,6 +23,7 @@ #include "annotator/types.h" #include "utils/base/logging.h" #include "utils/strings/numbers.h" +#include "utils/utf8/unicodetext.h" namespace libtextclassifier3 { @@ -31,46 +32,55 @@ using DurationUnit = internal::DurationUnit; namespace internal { namespace { +std::string ToLowerString(const std::string& str, const UniLib* unilib) { + return unilib->ToLowerText(UTF8ToUnicodeText(str, /*do_copy=*/false)) + .ToUTF8String(); +} + void FillDurationUnitMap( const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>* expressions, DurationUnit duration_unit, - std::unordered_map<std::string, DurationUnit>* target_map) { + std::unordered_map<std::string, DurationUnit>* target_map, + const UniLib* unilib) { if (expressions == nullptr) { return; } for (const flatbuffers::String* expression_string : *expressions) { - (*target_map)[expression_string->c_str()] = duration_unit; + (*target_map)[ToLowerString(expression_string->c_str(), unilib)] = + duration_unit; } } } // namespace std::unordered_map<std::string, DurationUnit> BuildTokenToDurationUnitMapping( - const DurationAnnotatorOptions* options) { + const DurationAnnotatorOptions* options, const UniLib* unilib) { std::unordered_map<std::string, DurationUnit> mapping; - FillDurationUnitMap(options->week_expressions(), DurationUnit::WEEK, - &mapping); - FillDurationUnitMap(options->day_expressions(), DurationUnit::DAY, &mapping); - FillDurationUnitMap(options->hour_expressions(), DurationUnit::HOUR, - &mapping); + FillDurationUnitMap(options->week_expressions(), DurationUnit::WEEK, &mapping, + unilib); + FillDurationUnitMap(options->day_expressions(), DurationUnit::DAY, &mapping, + unilib); + FillDurationUnitMap(options->hour_expressions(), DurationUnit::HOUR, &mapping, + unilib); FillDurationUnitMap(options->minute_expressions(), DurationUnit::MINUTE, - &mapping); + &mapping, unilib); FillDurationUnitMap(options->second_expressions(), DurationUnit::SECOND, - &mapping); + &mapping, unilib); return mapping; } std::unordered_set<std::string> BuildStringSet( const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>* - strings) { + strings, + const UniLib* unilib) { std::unordered_set<std::string> result; if (strings == nullptr) { return result; } for (const flatbuffers::String* string_value : *strings) { - result.insert(string_value->c_str()); + result.insert(ToLowerString(string_value->c_str(), unilib)); } return result; @@ -260,14 +270,17 @@ bool DurationAnnotator::ParseQuantityToken(const Token& token, std::string token_value_buffer; const std::string& token_value = feature_processor_->StripBoundaryCodepoints( token.value, &token_value_buffer); + const std::string& lowercase_token_value = + internal::ToLowerString(token_value, unilib_); - if (half_expressions_.find(token_value) != half_expressions_.end()) { + if (half_expressions_.find(lowercase_token_value) != + half_expressions_.end()) { value->plus_half = true; return true; } int32 parsed_value; - if (ParseInt32(token_value.c_str(), &parsed_value)) { + if (ParseInt32(lowercase_token_value.c_str(), &parsed_value)) { value->value = parsed_value; return true; } @@ -280,8 +293,10 @@ bool DurationAnnotator::ParseDurationUnitToken( std::string token_value_buffer; const std::string& token_value = feature_processor_->StripBoundaryCodepoints( token.value, &token_value_buffer); + const std::string& lowercase_token_value = + internal::ToLowerString(token_value, unilib_); - const auto it = token_value_to_duration_unit_.find(token_value); + const auto it = token_value_to_duration_unit_.find(lowercase_token_value); if (it == token_value_to_duration_unit_.end()) { return false; } @@ -319,8 +334,11 @@ bool DurationAnnotator::ParseFillerToken(const Token& token) const { std::string token_value_buffer; const std::string& token_value = feature_processor_->StripBoundaryCodepoints( token.value, &token_value_buffer); + const std::string& lowercase_token_value = + internal::ToLowerString(token_value, unilib_); - if (filler_expressions_.find(token_value) == filler_expressions_.end()) { + if (filler_expressions_.find(lowercase_token_value) == + filler_expressions_.end()) { return false; } diff --git a/native/annotator/duration/duration.h b/native/annotator/duration/duration.h index 2242259..db4bdae 100644 --- a/native/annotator/duration/duration.h +++ b/native/annotator/duration/duration.h @@ -26,6 +26,7 @@ #include "annotator/model_generated.h" #include "annotator/types.h" #include "utils/utf8/unicodetext.h" +#include "utils/utf8/unilib.h" namespace libtextclassifier3 { @@ -46,12 +47,14 @@ enum class DurationUnit { // Prepares the mapping between token values and duration unit types. std::unordered_map<std::string, internal::DurationUnit> -BuildTokenToDurationUnitMapping(const DurationAnnotatorOptions* options); +BuildTokenToDurationUnitMapping(const DurationAnnotatorOptions* options, + const UniLib* unilib); // Creates a set of strings from a flatbuffer string vector. std::unordered_set<std::string> BuildStringSet( const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>* - strings); + strings, + const UniLib* unilib); // Creates a set of ints from a flatbuffer int vector. std::unordered_set<int32> BuildInt32Set(const flatbuffers::Vector<int32>* ints); @@ -62,15 +65,17 @@ std::unordered_set<int32> BuildInt32Set(const flatbuffers::Vector<int32>* ints); class DurationAnnotator { public: explicit DurationAnnotator(const DurationAnnotatorOptions* options, - const FeatureProcessor* feature_processor) + const FeatureProcessor* feature_processor, + const UniLib* unilib) : options_(options), feature_processor_(feature_processor), + unilib_(unilib), token_value_to_duration_unit_( - internal::BuildTokenToDurationUnitMapping(options)), + internal::BuildTokenToDurationUnitMapping(options, unilib)), filler_expressions_( - internal::BuildStringSet(options->filler_expressions())), + internal::BuildStringSet(options->filler_expressions(), unilib)), half_expressions_( - internal::BuildStringSet(options->half_expressions())), + internal::BuildStringSet(options->half_expressions(), unilib)), sub_token_separator_codepoints_(internal::BuildInt32Set( options->sub_token_separator_codepoints())) {} @@ -125,6 +130,7 @@ class DurationAnnotator { const DurationAnnotatorOptions* options_; const FeatureProcessor* feature_processor_; + const UniLib* unilib_; const std::unordered_map<std::string, internal::DurationUnit> token_value_to_duration_unit_; const std::unordered_set<std::string> filler_expressions_; diff --git a/native/annotator/duration/duration_test.cc b/native/annotator/duration/duration_test.cc index 3fc25e6..d1dc67a 100644 --- a/native/annotator/duration/duration_test.cc +++ b/native/annotator/duration/duration_test.cc @@ -106,7 +106,7 @@ class DurationAnnotatorTest : public ::testing::Test { : INIT_UNILIB_FOR_TESTING(unilib_), feature_processor_(BuildFeatureProcessor(&unilib_)), duration_annotator_(TestingDurationAnnotatorOptions(), - feature_processor_.get()) {} + feature_processor_.get(), &unilib_) {} std::vector<Token> Tokenize(const UnicodeText& text) { return feature_processor_->Tokenize(text); @@ -195,6 +195,26 @@ TEST_F(DurationAnnotatorTest, FindsComposedDuration) { 3 * 60 * 60 * 1000 + 5 * 1000))))))); } +TEST_F(DurationAnnotatorTest, AllUnitsAreCovered) { + const UnicodeText text = UTF8ToUnicodeText( + "See you in a week and a day and an hour and a minute and a second"); + std::vector<Token> tokens = Tokenize(text); + std::vector<AnnotatedSpan> result; + EXPECT_TRUE(duration_annotator_.FindAll( + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + + EXPECT_THAT( + result, + ElementsAre( + AllOf(Field(&AnnotatedSpan::span, CodepointSpan(13, 65)), + Field(&AnnotatedSpan::classification, + ElementsAre(AllOf( + Field(&ClassificationResult::collection, "duration"), + Field(&ClassificationResult::duration_ms, + 7 * 24 * 60 * 60 * 1000 + 24 * 60 * 60 * 1000 + + 60 * 60 * 1000 + 60 * 1000 + 1000))))))); +} + TEST_F(DurationAnnotatorTest, FindsHalfAnHour) { const UnicodeText text = UTF8ToUnicodeText("Set a timer for half an hour"); std::vector<Token> tokens = Tokenize(text); @@ -350,5 +370,62 @@ TEST_F(DurationAnnotatorTest, 1400L * 60L * 60L * 1000L))); } +TEST_F(DurationAnnotatorTest, FindsSimpleDurationIgnoringCase) { + const UnicodeText text = UTF8ToUnicodeText("Wake me up in 15 MiNuTeS ok?"); + std::vector<Token> tokens = Tokenize(text); + std::vector<AnnotatedSpan> result; + EXPECT_TRUE(duration_annotator_.FindAll( + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + + EXPECT_THAT( + result, + ElementsAre( + AllOf(Field(&AnnotatedSpan::span, CodepointSpan(14, 24)), + Field(&AnnotatedSpan::classification, + ElementsAre(AllOf( + Field(&ClassificationResult::collection, "duration"), + Field(&ClassificationResult::duration_ms, + 15 * 60 * 1000))))))); +} + +TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpressionIgnoringCase) { + const UnicodeText text = + UTF8ToUnicodeText("Set a timer for 3 and HaLf minutes ok?"); + std::vector<Token> tokens = Tokenize(text); + std::vector<AnnotatedSpan> result; + EXPECT_TRUE(duration_annotator_.FindAll( + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + + EXPECT_THAT( + result, + ElementsAre( + AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 34)), + Field(&AnnotatedSpan::classification, + ElementsAre(AllOf( + Field(&ClassificationResult::collection, "duration"), + Field(&ClassificationResult::duration_ms, + 3.5 * 60 * 1000))))))); +} + +TEST_F(DurationAnnotatorTest, + FindsDurationWithHalfExpressionIgnoringFillerWordCase) { + const UnicodeText text = + UTF8ToUnicodeText("Set a timer for 3 AnD half minutes ok?"); + std::vector<Token> tokens = Tokenize(text); + std::vector<AnnotatedSpan> result; + EXPECT_TRUE(duration_annotator_.FindAll( + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + + EXPECT_THAT( + result, + ElementsAre( + AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 34)), + Field(&AnnotatedSpan::classification, + ElementsAre(AllOf( + Field(&ClassificationResult::collection, "duration"), + Field(&ClassificationResult::duration_ms, + 3.5 * 60 * 1000))))))); +} + } // namespace } // namespace libtextclassifier3 diff --git a/native/annotator/entity-data.fbs b/native/annotator/entity-data.fbs index 6da3dd5..fa2dc0b 100755 --- a/native/annotator/entity-data.fbs +++ b/native/annotator/entity-data.fbs @@ -125,6 +125,46 @@ table Flight { flight_number:string (shared); } +// Details about an ISBN number. +namespace libtextclassifier3.EntityData_; +table Isbn { + // The (normalized) number. + number:string (shared); +} + +// Details about an IBAN number. +namespace libtextclassifier3.EntityData_; +table Iban { + // The (normalized) number. + number:string (shared); + + // The country code. + country_code:string (shared); +} + +namespace libtextclassifier3.EntityData_.ParcelTracking_; +enum Carrier : int { + UNKNOWN_CARRIER = 0, + FEDEX = 1, + UPS = 2, + DHL = 3, + USPS = 4, + ONTRAC = 5, + LASERSHIP = 6, + ISRAEL_POST = 7, + SWISS_POST = 8, + MSC = 9, + AMAZON = 10, + I_PARCEL = 11, +} + +// Details about a tracking number. +namespace libtextclassifier3.EntityData_; +table ParcelTracking { + carrier:ParcelTracking_.Carrier; + tracking_number:string (shared); +} + // Represents an entity annotated in text. namespace libtextclassifier3; table EntityData { @@ -143,6 +183,9 @@ table EntityData { app:EntityData_.App; payment_card:EntityData_.PaymentCard; flight:EntityData_.Flight; + isbn:EntityData_.Isbn; + iban:EntityData_.Iban; + parcel:EntityData_.ParcelTracking; } root_type libtextclassifier3.EntityData; diff --git a/native/annotator/flatbuffer-utils.cc b/native/annotator/flatbuffer-utils.cc new file mode 100644 index 0000000..14b5901 --- /dev/null +++ b/native/annotator/flatbuffer-utils.cc @@ -0,0 +1,65 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "annotator/flatbuffer-utils.h" + +#include <memory> + +#include "utils/base/logging.h" +#include "utils/flatbuffers.h" +#include "flatbuffers/reflection.h" + +namespace libtextclassifier3 { + +bool SwapFieldNamesForOffsetsInPath(ModelT* model) { + if (model->regex_model == nullptr || model->entity_data_schema.empty()) { + // Nothing to do. + return true; + } + const reflection::Schema* schema = + LoadAndVerifyFlatbuffer<reflection::Schema>( + model->entity_data_schema.data(), model->entity_data_schema.size()); + + for (std::unique_ptr<RegexModel_::PatternT>& pattern : + model->regex_model->patterns) { + for (std::unique_ptr<RegexModel_::Pattern_::CapturingGroupT>& group : + pattern->capturing_group) { + if (group->entity_field_path == nullptr) { + continue; + } + + if (!SwapFieldNamesForOffsetsInPath(schema, + group->entity_field_path.get())) { + return false; + } + } + } + + return true; +} + +std::string SwapFieldNamesForOffsetsInPathInSerializedModel( + const std::string& model) { + std::unique_ptr<ModelT> unpacked_model = UnPackModel(model.c_str()); + TC3_CHECK(unpacked_model != nullptr); + TC3_CHECK(SwapFieldNamesForOffsetsInPath(unpacked_model.get())); + flatbuffers::FlatBufferBuilder builder; + FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); + return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize()); +} + +} // namespace libtextclassifier3 diff --git a/native/annotator/flatbuffer-utils.h b/native/annotator/flatbuffer-utils.h new file mode 100644 index 0000000..a7e5d64 --- /dev/null +++ b/native/annotator/flatbuffer-utils.h @@ -0,0 +1,38 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Utility functions for working with FlatBuffers in the annotator model. + +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_FLATBUFFER_UTILS_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_FLATBUFFER_UTILS_H_ + +#include <string> + +#include "annotator/model_generated.h" + +namespace libtextclassifier3 { + +// Resolves field lookups by name to the concrete field offsets in the regex +// rules of the model. +bool SwapFieldNamesForOffsetsInPath(ModelT* model); + +// Same as above but for a serialized model. +std::string SwapFieldNamesForOffsetsInPathInSerializedModel( + const std::string& model); + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_FLATBUFFER_UTILS_H_ diff --git a/native/annotator/knowledge/knowledge-engine-dummy.h b/native/annotator/knowledge/knowledge-engine-dummy.h index 1787353..865bf85 100644 --- a/native/annotator/knowledge/knowledge-engine-dummy.h +++ b/native/annotator/knowledge/knowledge-engine-dummy.h @@ -29,6 +29,8 @@ class KnowledgeEngine { public: bool Initialize(const std::string& serialized_config) { return true; } + void SetPriorityScore(float priority_score) {} + bool ClassifyText(const std::string& context, CodepointSpan selection_indices, AnnotationUsecase annotation_usecase, ClassificationResult* classification_result) const { diff --git a/native/annotator/model.fbs b/native/annotator/model.fbs index 181a8aa..5bf1472 100755 --- a/native/annotator/model.fbs +++ b/native/annotator/model.fbs @@ -17,6 +17,7 @@ include "utils/codepoint-range.fbs"; include "utils/flatbuffers.fbs"; include "utils/intents/intent-config.fbs"; +include "utils/normalization.fbs"; include "utils/resources.fbs"; include "utils/tokenizer.fbs"; include "utils/zlib/buffer.fbs"; @@ -209,6 +210,9 @@ table CapturingGroup { // If set, the serialized entity data will be merged with the // classification result entity data. serialized_entity_data:string (shared); + + // If set, normalization to apply before text is used in entity data. + normalization_options:NormalizationOptions; } // List of regular expression matchers to check. @@ -329,6 +333,9 @@ table DatetimeModel { // If true, will compile the regexes only on first use. lazy_regex_compilation:bool = true; + + // If true, will give only future dates (when the day is not specified). + prefer_future_for_unspecified_date:bool = false; } namespace libtextclassifier3.DatetimeModelLibrary_; @@ -363,6 +370,9 @@ table ModelTriggeringOptions { // Priority score assigned to the "other" class from ML model. other_collection_priority_score:float = -1000; + + // Priority score assigned to knowledge engine annotations. + knowledge_priority_score:float = 0; } // Options controlling the output of the classifier. @@ -675,6 +685,10 @@ table NumberAnnotatorOptions { // Priority score for the percentage annotation. percentage_priority_score:float = 1; + + // Float number priority score used for conflict resolution with the other + // models. + float_number_priority_score:float = 0; } // DurationAnnotator is so far tailored for English only. diff --git a/native/annotator/number/number.cc b/native/annotator/number/number.cc index 7af63fa..671e1af 100644 --- a/native/annotator/number/number.cc +++ b/native/annotator/number/number.cc @@ -20,7 +20,9 @@ #include <cstdlib> #include "annotator/collections.h" +#include "annotator/types.h" #include "utils/base/logging.h" +#include "utils/utf8/unicodetext.h" namespace libtextclassifier3 { @@ -28,68 +30,38 @@ bool NumberAnnotator::ClassifyText( const UnicodeText& context, CodepointSpan selection_indices, AnnotationUsecase annotation_usecase, ClassificationResult* classification_result) const { - if (!options_->enabled() || ((1 << annotation_usecase) & - options_->enabled_annotation_usecases()) == 0) { - return false; - } + TC3_CHECK(classification_result != nullptr); - int64 parsed_int_value; - double parsed_double_value; - int num_prefix_codepoints; - int num_suffix_codepoints; const UnicodeText substring_selected = UnicodeText::Substring( context, selection_indices.first, selection_indices.second); - if (ParseNumber(substring_selected, &parsed_int_value, &parsed_double_value, - &num_prefix_codepoints, &num_suffix_codepoints)) { - TC3_CHECK(classification_result != nullptr); - classification_result->score = options_->score(); - classification_result->priority_score = options_->priority_score(); - classification_result->numeric_value = parsed_int_value; - classification_result->numeric_double_value = parsed_double_value; - - if (num_suffix_codepoints == 0) { - classification_result->collection = Collections::Number(); - return true; + + std::vector<AnnotatedSpan> results; + if (!FindAll(substring_selected, annotation_usecase, &results)) { + return false; + } + + const CodepointSpan stripped_selection_indices = + feature_processor_->StripBoundaryCodepoints( + context, selection_indices, ignored_prefix_span_boundary_codepoints_, + ignored_suffix_span_boundary_codepoints_); + + for (const AnnotatedSpan& result : results) { + if (result.classification.empty()) { + continue; } - // If the selection ends in %, the parseNumber returns true with - // num_suffix_codepoints = 1 => percent - if (options_->enable_percentage() && - GetPercentSuffixLength( - context, context.size_codepoints(), - selection_indices.second - num_suffix_codepoints) == - num_suffix_codepoints) { - classification_result->collection = Collections::Percentage(); - classification_result->priority_score = - options_->percentage_priority_score(); + // We make sure that the result span is equal to the stripped selection span + // to avoid validating cases like "23 asdf 3.14 pct asdf". FindAll will + // anyway only find valid numbers and percentages and a given selection with + // more than two tokens won't pass this check. + if (result.span.first + selection_indices.first == + stripped_selection_indices.first && + result.span.second + selection_indices.first == + stripped_selection_indices.second) { + *classification_result = result.classification[0]; return true; } - } else if (options_->enable_percentage()) { - // If the substring selected is a percent matching the form: 5 percent, - // 5 pct or 5 pc => percent. - std::vector<AnnotatedSpan> results; - FindAll(substring_selected, annotation_usecase, &results); - for (auto& result : results) { - if (result.classification.empty() || - result.classification[0].collection != Collections::Percentage()) { - continue; - } - if (result.span.first == 0 && - result.span.second == substring_selected.size_codepoints()) { - TC3_CHECK(classification_result != nullptr); - classification_result->collection = Collections::Percentage(); - classification_result->score = options_->score(); - classification_result->priority_score = - options_->percentage_priority_score(); - classification_result->numeric_value = - result.classification[0].numeric_value; - classification_result->numeric_double_value = - result.classification[0].numeric_double_value; - return true; - } - } } - return false; } @@ -107,21 +79,24 @@ bool NumberAnnotator::FindAll(const UnicodeText& context, UTF8ToUnicodeText(token.value, /*do_copy=*/false); int64 parsed_int_value; double parsed_double_value; + bool has_decimal; int num_prefix_codepoints; int num_suffix_codepoints; if (ParseNumber(token_text, &parsed_int_value, &parsed_double_value, - &num_prefix_codepoints, &num_suffix_codepoints)) { + &has_decimal, &num_prefix_codepoints, + &num_suffix_codepoints)) { ClassificationResult classification{Collections::Number(), options_->score()}; classification.numeric_value = parsed_int_value; classification.numeric_double_value = parsed_double_value; - classification.priority_score = options_->priority_score(); + classification.priority_score = + has_decimal ? options_->float_number_priority_score() + : options_->priority_score(); AnnotatedSpan annotated_span; annotated_span.span = {token.start + num_prefix_codepoints, token.end - num_suffix_codepoints}; annotated_span.classification.push_back(classification); - result->push_back(annotated_span); } } @@ -151,7 +126,7 @@ std::vector<uint32> NumberAnnotator::FlatbuffersIntVectorToStdVector( namespace { bool ParseNextNumericCodepoint(int32 codepoint, int64* current_value) { - if (*current_value > INT64_MAX / 10) { + if (*current_value > INT64_MAX / 10 - 10) { return false; } @@ -163,20 +138,20 @@ bool ParseNextNumericCodepoint(int32 codepoint, int64* current_value) { UnicodeText::const_iterator ConsumeAndParseNumber( const UnicodeText::const_iterator& it_begin, const UnicodeText::const_iterator& it_end, int64* int_result, - double* double_result) { + double* double_result, bool* has_decimal) { *int_result = 0; + *has_decimal = false; // See if there's a sign in the beginning of the number. int sign = 1; auto it = it_begin; - if (it != it_end) { + while (it != it_end && (*it == '-' || *it == '+')) { if (*it == '-') { - ++it; sign = -1; - } else if (*it == '+') { - ++it; + } else { sign = 1; } + ++it; } enum class State { @@ -203,6 +178,7 @@ UnicodeText::const_iterator ConsumeAndParseNumber( break; case State::PARSING_FLOATING_PART: if (*it >= '0' && *it <= '9') { + *has_decimal = true; if (!ParseNextNumericCodepoint(*it, &decimal_result)) { state = State::PARSING_DONE; break; @@ -236,7 +212,7 @@ UnicodeText::const_iterator ConsumeAndParseNumber( } // namespace bool NumberAnnotator::ParseNumber(const UnicodeText& text, int64* int_result, - double* double_result, + double* double_result, bool* has_decimal, int* num_prefix_codepoints, int* num_suffix_codepoints) const { TC3_CHECK(int_result != nullptr && double_result != nullptr && @@ -258,13 +234,6 @@ bool NumberAnnotator::ParseNumber(const UnicodeText& text, int64* int_result, // Consume prefix codepoints. *num_prefix_codepoints = stripped_span.first; - bool valid_prefix = true; - // Makes valid_prefix=false for cases like: "#10" where it points to '1'. In - // this case the adjacent prefix is not an allowed one. - if (it != text.begin() && allowed_prefix_codepoints_.find(*std::prev(it)) == - allowed_prefix_codepoints_.end()) { - valid_prefix = false; - } while (it != it_end) { if (allowed_prefix_codepoints_.find(*it) == allowed_prefix_codepoints_.end()) { @@ -276,7 +245,8 @@ bool NumberAnnotator::ParseNumber(const UnicodeText& text, int64* int_result, } auto it_start = it; - it = ConsumeAndParseNumber(it, it_end, int_result, double_result); + it = + ConsumeAndParseNumber(it, it_end, int_result, double_result, has_decimal); if (it == it_start) { return false; } @@ -284,32 +254,35 @@ bool NumberAnnotator::ParseNumber(const UnicodeText& text, int64* int_result, // Consume suffix codepoints. bool valid_suffix = true; *num_suffix_codepoints = 0; + int ignored_suffix_codepoints = 0; while (it != it_end) { - if (allowed_suffix_codepoints_.find(*it) == + if (allowed_suffix_codepoints_.find(*it) != allowed_suffix_codepoints_.end()) { + // Keep track of allowed suffix codepoints. + ++(*num_suffix_codepoints); + } else if (ignored_suffix_span_boundary_codepoints_.find(*it) == + ignored_suffix_span_boundary_codepoints_.end()) { + // There is a suffix codepoint but it's not part of the ignored list of + // codepoints, fail the number parsing. + // Note: We want to support cases like "13.", "34#", "123!" etc. valid_suffix = false; break; + } else { + ++ignored_suffix_codepoints; } ++it; - ++(*num_suffix_codepoints); } *num_suffix_codepoints += num_stripped_end; - // Makes valid_suffix=false for cases like: "10@", when it == it_end and - // points to '@'. This adjacent character is not an allowed suffix. - if (it == it_end && it != text.end() && - allowed_suffix_codepoints_.find(*it) == - allowed_suffix_codepoints_.end()) { - valid_suffix = false; - } - - return valid_suffix && valid_prefix; + return valid_suffix; } int NumberAnnotator::GetPercentSuffixLength(const UnicodeText& context, - int context_size_codepoints, int index_codepoints) const { + if (index_codepoints >= context.size_codepoints()) { + return -1; + } auto context_it = context.begin(); std::advance(context_it, index_codepoints); const StringPiece suffix_context( @@ -329,15 +302,13 @@ int NumberAnnotator::GetPercentSuffixLength(const UnicodeText& context, void NumberAnnotator::FindPercentages( const UnicodeText& context, std::vector<AnnotatedSpan>* result) const { - int context_size_codepoints = context.size_codepoints(); for (auto& res : *result) { if (res.classification.empty() || res.classification[0].collection != Collections::Number()) { continue; } - const int match_length = GetPercentSuffixLength( - context, context_size_codepoints, res.span.second); + const int match_length = GetPercentSuffixLength(context, res.span.second); if (match_length > 0) { res.classification[0].collection = Collections::Percentage(); res.classification[0].priority_score = diff --git a/native/annotator/number/number.h b/native/annotator/number/number.h index 3debd09..3e9e2c3 100644 --- a/native/annotator/number/number.h +++ b/native/annotator/number/number.h @@ -81,16 +81,17 @@ class NumberAnnotator { static std::vector<uint32> FlatbuffersIntVectorToStdVector( const flatbuffers::Vector<int32_t>* ints); - // Parses the text to an int64 value and returns true if succeeded, otherwise - // false. Also returns the number of prefix/suffix codepoints that were - // stripped from the number. + // Parses the text to an int64 value and a double value and returns true if + // succeeded, otherwise false. Also returns whether the number contains a + // decimal and the number of prefix/suffix codepoints that were stripped from + // the number. bool ParseNumber(const UnicodeText& text, int64* int_result, - double* double_result, int* num_prefix_codepoints, + double* double_result, bool* has_decimal, + int* num_prefix_codepoints, int* num_suffix_codepoints) const; // Get the length of the percent suffix at the specified index in the context. int GetPercentSuffixLength(const UnicodeText& context, - int context_size_codepoints, int index_codepoints) const; // Checks if the annotated numbers from the context represent percentages. diff --git a/native/annotator/test_data/test_model.fb b/native/annotator/test_data/test_model.fb Binary files differindex ce5f72f..bbf730e 100644 --- a/native/annotator/test_data/test_model.fb +++ b/native/annotator/test_data/test_model.fb diff --git a/native/annotator/test_data/wrong_embeddings.fb b/native/annotator/test_data/wrong_embeddings.fb Binary files differindex efefa3c..135dec0 100644 --- a/native/annotator/test_data/wrong_embeddings.fb +++ b/native/annotator/test_data/wrong_embeddings.fb diff --git a/native/annotator/types.cc b/native/annotator/types.cc index c31097d..1ec3790 100644 --- a/native/annotator/types.cc +++ b/native/annotator/types.cc @@ -56,6 +56,64 @@ std::string FormatMillis(int64 time_ms_utc) { } } // namespace +std::string ComponentTypeToString( + const DatetimeComponent::ComponentType& component_type) { + switch (component_type) { + case DatetimeComponent::ComponentType::UNSPECIFIED: + return "UNSPECIFIED"; + case DatetimeComponent::ComponentType::YEAR: + return "YEAR"; + case DatetimeComponent::ComponentType::MONTH: + return "MONTH"; + case DatetimeComponent::ComponentType::WEEK: + return "WEEK"; + case DatetimeComponent::ComponentType::DAY_OF_WEEK: + return "DAY_OF_WEEK"; + case DatetimeComponent::ComponentType::DAY_OF_MONTH: + return "DAY_OF_MONTH"; + case DatetimeComponent::ComponentType::HOUR: + return "HOUR"; + case DatetimeComponent::ComponentType::MINUTE: + return "MINUTE"; + case DatetimeComponent::ComponentType::SECOND: + return "SECOND"; + case DatetimeComponent::ComponentType::MERIDIEM: + return "MERIDIEM"; + case DatetimeComponent::ComponentType::ZONE_OFFSET: + return "ZONE_OFFSET"; + case DatetimeComponent::ComponentType::DST_OFFSET: + return "DST_OFFSET"; + default: + return ""; + } +} + +std::string RelativeQualifierToString( + const DatetimeComponent::RelativeQualifier& relative_qualifier) { + switch (relative_qualifier) { + case DatetimeComponent::RelativeQualifier::UNSPECIFIED: + return "UNSPECIFIED"; + case DatetimeComponent::RelativeQualifier::NEXT: + return "NEXT"; + case DatetimeComponent::RelativeQualifier::THIS: + return "THIS"; + case DatetimeComponent::RelativeQualifier::LAST: + return "LAST"; + case DatetimeComponent::RelativeQualifier::NOW: + return "NOW"; + case DatetimeComponent::RelativeQualifier::TOMORROW: + return "TOMORROW"; + case DatetimeComponent::RelativeQualifier::YESTERDAY: + return "YESTERDAY"; + case DatetimeComponent::RelativeQualifier::PAST: + return "PAST"; + case DatetimeComponent::RelativeQualifier::FUTURE: + return "FUTURE"; + default: + return ""; + } +} + logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream, const DatetimeParseResultSpan& value) { stream << "DatetimeParseResultSpan({" << value.span.first << ", " @@ -63,7 +121,16 @@ logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream, for (const DatetimeParseResult& data : value.data) { stream << "{/*time_ms_utc=*/ " << data.time_ms_utc << " /* " << FormatMillis(data.time_ms_utc) << " */, /*granularity=*/ " - << data.granularity << "}, "; + << data.granularity << ", /*datetime_components=*/ "; + for (const DatetimeComponent& datetime_comp : data.datetime_components) { + stream << "{/*component_type=*/ " + << ComponentTypeToString(datetime_comp.component_type) + << " /*relative_qualifier=*/ " + << RelativeQualifierToString(datetime_comp.relative_qualifier) + << " /*value=*/ " << datetime_comp.value << " /*relative_count=*/ " + << datetime_comp.relative_count << "}, "; + } + stream << "}, "; } stream << "})"; return stream; diff --git a/native/annotator/types.h b/native/annotator/types.h index ac24e24..9b94c10 100644 --- a/native/annotator/types.h +++ b/native/annotator/types.h @@ -352,7 +352,7 @@ struct ClassificationResult { // Entity data information. std::string serialized_entity_data; - const EntityData* entity_data() { + const EntityData* entity_data() const { return LoadAndVerifyFlatbuffer<EntityData>(serialized_entity_data.data(), serialized_entity_data.size()); } diff --git a/native/annotator/zlib-utils.cc b/native/annotator/zlib-utils.cc index ec2392b..c3c2cf1 100644 --- a/native/annotator/zlib-utils.cc +++ b/native/annotator/zlib-utils.cc @@ -125,6 +125,15 @@ bool DecompressModel(ModelT* model) { extractor->compressed_pattern.reset(nullptr); } } + + if (model->resources != nullptr) { + DecompressResources(model->resources.get()); + } + + if (model->intent_options != nullptr) { + DecompressIntentModel(model->intent_options.get()); + } + return true; } diff --git a/native/annotator/zlib-utils_test.cc b/native/annotator/zlib-utils_test.cc index 7a8d775..363c155 100644 --- a/native/annotator/zlib-utils_test.cc +++ b/native/annotator/zlib-utils_test.cc @@ -43,6 +43,37 @@ TEST(ZlibUtilsTest, CompressModel) { model.datetime_model->extractors.back()->pattern = "an example datetime extractor"; + model.intent_options.reset(new IntentFactoryModelT); + model.intent_options->generator.emplace_back( + new IntentFactoryModel_::IntentGeneratorT); + const std::string intent_generator1 = "lua generator 1"; + model.intent_options->generator.back()->lua_template_generator = + std::vector<uint8_t>(intent_generator1.begin(), intent_generator1.end()); + model.intent_options->generator.emplace_back( + new IntentFactoryModel_::IntentGeneratorT); + const std::string intent_generator2 = "lua generator 2"; + model.intent_options->generator.back()->lua_template_generator = + std::vector<uint8_t>(intent_generator2.begin(), intent_generator2.end()); + + // NOTE: The resource strings contain some repetition, so that the compressed + // version is smaller than the uncompressed one. Because the compression code + // looks at that as well. + model.resources.reset(new ResourcePoolT); + model.resources->resource_entry.emplace_back(new ResourceEntryT); + model.resources->resource_entry.back()->resource.emplace_back(new ResourceT); + model.resources->resource_entry.back()->resource.back()->content = + "rrrrrrrrrrrrr1.1"; + model.resources->resource_entry.back()->resource.emplace_back(new ResourceT); + model.resources->resource_entry.back()->resource.back()->content = + "rrrrrrrrrrrrr1.2"; + model.resources->resource_entry.emplace_back(new ResourceEntryT); + model.resources->resource_entry.back()->resource.emplace_back(new ResourceT); + model.resources->resource_entry.back()->resource.back()->content = + "rrrrrrrrrrrrr2.1"; + model.resources->resource_entry.back()->resource.emplace_back(new ResourceT); + model.resources->resource_entry.back()->resource.back()->content = + "rrrrrrrrrrrrr2.2"; + // Compress the model. EXPECT_TRUE(CompressModel(&model)); @@ -51,6 +82,14 @@ TEST(ZlibUtilsTest, CompressModel) { EXPECT_TRUE(model.regex_model->patterns[1]->pattern.empty()); EXPECT_TRUE(model.datetime_model->patterns[0]->regexes[0]->pattern.empty()); EXPECT_TRUE(model.datetime_model->extractors[0]->pattern.empty()); + EXPECT_TRUE( + model.intent_options->generator[0]->lua_template_generator.empty()); + EXPECT_TRUE( + model.intent_options->generator[1]->lua_template_generator.empty()); + EXPECT_TRUE(model.resources->resource_entry[0]->resource[0]->content.empty()); + EXPECT_TRUE(model.resources->resource_entry[0]->resource[1]->content.empty()); + EXPECT_TRUE(model.resources->resource_entry[1]->resource[0]->content.empty()); + EXPECT_TRUE(model.resources->resource_entry[1]->resource[1]->content.empty()); // Pack and load the model. flatbuffers::FlatBufferBuilder builder; @@ -94,6 +133,20 @@ TEST(ZlibUtilsTest, CompressModel) { "an example datetime pattern"); EXPECT_EQ(model.datetime_model->extractors[0]->pattern, "an example datetime extractor"); + EXPECT_EQ( + model.intent_options->generator[0]->lua_template_generator, + std::vector<uint8_t>(intent_generator1.begin(), intent_generator1.end())); + EXPECT_EQ( + model.intent_options->generator[1]->lua_template_generator, + std::vector<uint8_t>(intent_generator2.begin(), intent_generator2.end())); + EXPECT_EQ(model.resources->resource_entry[0]->resource[0]->content, + "rrrrrrrrrrrrr1.1"); + EXPECT_EQ(model.resources->resource_entry[0]->resource[1]->content, + "rrrrrrrrrrrrr1.2"); + EXPECT_EQ(model.resources->resource_entry[1]->resource[0]->content, + "rrrrrrrrrrrrr2.1"); + EXPECT_EQ(model.resources->resource_entry[1]->resource[1]->content, + "rrrrrrrrrrrrr2.2"); } } // namespace libtextclassifier3 |