diff options
author | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2023-10-10 23:06:01 +0000 |
---|---|---|
committer | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2023-10-10 23:06:01 +0000 |
commit | fe949b752c0c03a77e01aed75b114bdaed6835a8 (patch) | |
tree | 46e8f64830ee8f8d4b8e9cdc67cb0f13c4773238 | |
parent | 56e59eeef263a3155f036433ca64fe4fe3346e3c (diff) | |
parent | 0d8a53648a30e140e3faf29f7809c5e90963409b (diff) | |
download | libtextclassifier-fe949b752c0c03a77e01aed75b114bdaed6835a8.tar.gz |
Snap for 10929834 from 0d8a53648a30e140e3faf29f7809c5e90963409b to sdk-release
Change-Id: Ic3c2a1fb81976a2234b1db309c02d2c62a6d1ada
100 files changed, 1251 insertions, 401 deletions
diff --git a/TEST_MAPPING b/TEST_MAPPING index 370acd6..17a31d4 100644 --- a/TEST_MAPPING +++ b/TEST_MAPPING @@ -18,7 +18,12 @@ "name": "TextClassifierNotificationTests" }, { - "name": "TCSModelDownloaderIntegrationTest" + "name": "TCSModelDownloaderIntegrationTest", + "options": [ + { + "exclude-annotation": "androidx.test.filters.FlakyTest" + } + ] } ], "hwasan-postsubmit": [ diff --git a/abseil-cpp/Android.bp b/abseil-cpp/Android.bp index a3635f3..0ab3734 100644 --- a/abseil-cpp/Android.bp +++ b/abseil-cpp/Android.bp @@ -35,7 +35,7 @@ cc_library_static { export_include_dirs: ["."], visibility: [ "//external/libtextclassifier:__subpackages__", - "//external/tflite-support:__subpackages__" + "//external/tflite-support:__subpackages__", ], srcs: [ "absl/**/*.cc", @@ -43,6 +43,7 @@ cc_library_static { apex_available: [ "//apex_available:platform", "com.android.extservices", + "com.android.adservices", ], sdk_version: "current", min_sdk_version: "30", diff --git a/java/src/com/android/textclassifier/TextClassifierImpl.java b/java/src/com/android/textclassifier/TextClassifierImpl.java index 2b6c396..abe8994 100644 --- a/java/src/com/android/textclassifier/TextClassifierImpl.java +++ b/java/src/com/android/textclassifier/TextClassifierImpl.java @@ -199,6 +199,8 @@ final class TextClassifierImpl { .setDetectedTextLanguageTags(detectLanguageTags) .setAnnotationUsecase(AnnotatorModel.AnnotationUsecase.SMART.getValue()) .setUserFamiliarLanguageTags(LocaleList.getDefault().toLanguageTags()) + .setEnableAddContactIntent(false) + .setEnableSearchIntent(shouldEnableSearchIntent(textClassificationContext)) .build(), // Passing null here to suppress intent generation. // TODO: Use an explicit flag to suppress it. @@ -254,6 +256,8 @@ final class TextClassifierImpl { .setDetectedTextLanguageTags(String.join(",", detectLanguageTags)) .setAnnotationUsecase(AnnotatorModel.AnnotationUsecase.SMART.getValue()) .setUserFamiliarLanguageTags(LocaleList.getDefault().toLanguageTags()) + .setEnableAddContactIntent(false) + .setEnableSearchIntent(shouldEnableSearchIntent(textClassificationContext)) .build(), context, getResourceLocalesString()); @@ -769,4 +773,15 @@ final class TextClassifierImpl { strippedIntent.setComponent(null); return strippedIntent; } + + private static boolean shouldEnableSearchIntent( + @Nullable TextClassificationContext textClassificationContext) { + if (textClassificationContext == null) { + return false; + } + String widgetType = textClassificationContext.getWidgetType(); + // Exclude WebView because there is already a *Web Search* chip there. + return !(TextClassifier.WIDGET_TYPE_WEBVIEW.equals(widgetType) + || TextClassifier.WIDGET_TYPE_EDIT_WEBVIEW.equals(widgetType)); + } } diff --git a/java/src/com/android/textclassifier/common/TextClassifierSettings.java b/java/src/com/android/textclassifier/common/TextClassifierSettings.java index 205680d..d0ea917 100644 --- a/java/src/com/android/textclassifier/common/TextClassifierSettings.java +++ b/java/src/com/android/textclassifier/common/TextClassifierSettings.java @@ -525,7 +525,7 @@ public final class TextClassifierSettings { variantsMapBuilder.put(modelLanguageTag, urlFlagValue); } } - return variantsMapBuilder.build(); + return variantsMapBuilder.buildOrThrow(); } public String getTestingLocaleListOverride() { diff --git a/java/src/com/android/textclassifier/common/intent/LabeledIntent.java b/java/src/com/android/textclassifier/common/intent/LabeledIntent.java index 5c420ad..abc879d 100644 --- a/java/src/com/android/textclassifier/common/intent/LabeledIntent.java +++ b/java/src/com/android/textclassifier/common/intent/LabeledIntent.java @@ -92,7 +92,7 @@ public final class LabeledIntent { @Nullable public Result resolve(Context context, @Nullable TitleChooser titleChooser) { final PackageManager pm = context.getPackageManager(); - final ResolveInfo resolveInfo = pm.resolveActivity(intent, 0); + final ResolveInfo resolveInfo = pm.resolveActivity(intent, PackageManager.MATCH_DEFAULT_ONLY); if (resolveInfo == null || resolveInfo.activityInfo == null) { // Failed to resolve the intent. It could be because there are no apps to handle diff --git a/java/src/com/android/textclassifier/common/logging/TextClassificationContext.java b/java/src/com/android/textclassifier/common/logging/TextClassificationContext.java index e729201..5e572a7 100644 --- a/java/src/com/android/textclassifier/common/logging/TextClassificationContext.java +++ b/java/src/com/android/textclassifier/common/logging/TextClassificationContext.java @@ -18,6 +18,7 @@ package com.android.textclassifier.common.logging; import androidx.annotation.NonNull; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import java.util.Locale; import javax.annotation.Nullable; @@ -91,6 +92,7 @@ public final class TextClassificationContext { * * @return this builder */ + @CanIgnoreReturnValue public Builder setWidgetVersion(@Nullable String widgetVersion) { this.widgetVersion = widgetVersion; return this; diff --git a/java/src/com/android/textclassifier/common/logging/TextClassifierEvent.java b/java/src/com/android/textclassifier/common/logging/TextClassifierEvent.java index f34fb3d..ef82afc 100644 --- a/java/src/com/android/textclassifier/common/logging/TextClassifierEvent.java +++ b/java/src/com/android/textclassifier/common/logging/TextClassifierEvent.java @@ -19,6 +19,7 @@ package com.android.textclassifier.common.logging; import android.os.Bundle; import androidx.annotation.IntDef; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.util.Arrays; @@ -327,6 +328,7 @@ public abstract class TextClassifierEvent { * * <p>See {@link Locale#toLanguageTag()} */ + @CanIgnoreReturnValue public T setEntityTypes(String... entityTypes) { Preconditions.checkNotNull(entityTypes); this.entityTypes = new String[entityTypes.length]; @@ -335,12 +337,14 @@ public abstract class TextClassifierEvent { } /** Sets the event context. */ + @CanIgnoreReturnValue public T setEventContext(@Nullable TextClassificationContext eventContext) { this.eventContext = eventContext; return self(); } /** Sets the id of the text classifier result related to this event. */ + @CanIgnoreReturnValue @Nonnull public T setResultId(@Nullable String resultId) { this.resultId = resultId; @@ -348,6 +352,7 @@ public abstract class TextClassifierEvent { } /** Sets the index of this event in the series of events it belongs to. */ + @CanIgnoreReturnValue @Nonnull public T setEventIndex(int eventIndex) { this.eventIndex = eventIndex; @@ -355,6 +360,7 @@ public abstract class TextClassifierEvent { } /** Sets the scores of the suggestions. */ + @CanIgnoreReturnValue @Nonnull public T setScores(@Nonnull float... scores) { Preconditions.checkNotNull(scores); @@ -364,6 +370,7 @@ public abstract class TextClassifierEvent { } /** Sets the model name string. */ + @CanIgnoreReturnValue @Nonnull public T setModelName(@Nullable String modelVersion) { modelName = modelVersion; @@ -395,6 +402,7 @@ public abstract class TextClassifierEvent { * * @see android.view.textclassifier.TextClassification#getActions() */ + @CanIgnoreReturnValue @Nonnull public T setActionIndices(@Nonnull int... actionIndices) { this.actionIndices = new int[actionIndices.length]; @@ -403,6 +411,7 @@ public abstract class TextClassifierEvent { } /** Sets the detected locale. */ + @CanIgnoreReturnValue @Nonnull public T setLocale(@Nullable Locale locale) { this.locale = locale; @@ -416,6 +425,7 @@ public abstract class TextClassifierEvent { * the internals of this bundle as it may have unexpected consequences on the clients of the * built event object. For similar reasons, avoid depending on mutable objects in this bundle. */ + @CanIgnoreReturnValue @Nonnull public T setExtras(@Nonnull Bundle extras) { this.extras = Preconditions.checkNotNull(extras); @@ -578,6 +588,7 @@ public abstract class TextClassifierEvent { } /** Sets the relative word index of the start of the selection. */ + @CanIgnoreReturnValue @Nonnull public Builder setRelativeWordStartIndex(int relativeWordStartIndex) { this.relativeWordStartIndex = relativeWordStartIndex; @@ -585,6 +596,7 @@ public abstract class TextClassifierEvent { } /** Sets the relative word (exclusive) index of the end of the selection. */ + @CanIgnoreReturnValue @Nonnull public Builder setRelativeWordEndIndex(int relativeWordEndIndex) { this.relativeWordEndIndex = relativeWordEndIndex; @@ -592,6 +604,7 @@ public abstract class TextClassifierEvent { } /** Sets the relative word index of the start of the smart selection. */ + @CanIgnoreReturnValue @Nonnull public Builder setRelativeSuggestedWordStartIndex(int relativeSuggestedWordStartIndex) { this.relativeSuggestedWordStartIndex = relativeSuggestedWordStartIndex; @@ -599,6 +612,7 @@ public abstract class TextClassifierEvent { } /** Sets the relative word (exclusive) index of the end of the smart selection. */ + @CanIgnoreReturnValue @Nonnull public Builder setRelativeSuggestedWordEndIndex(int relativeSuggestedWordEndIndex) { this.relativeSuggestedWordEndIndex = relativeSuggestedWordEndIndex; diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java b/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java index 3db0815..ff04aea 100644 --- a/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java +++ b/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java @@ -232,9 +232,10 @@ public final class ModelDownloadWorker extends ListenableWorker { downloadManifestAndRegister(modelType, bestLocaleTag, manifestUrl)); } manifestsToDownloadBuilder.put( - modelType, ManifestsToDownloadByType.create(localeTagToManifestUrlBuilder.build())); + modelType, + ManifestsToDownloadByType.create(localeTagToManifestUrlBuilder.buildOrThrow())); } - manifestsToDownload = manifestsToDownloadBuilder.build(); + manifestsToDownload = manifestsToDownloadBuilder.buildOrThrow(); return Futures.whenAllComplete(downloadResultFutures) .call( diff --git a/java/src/com/android/textclassifier/downloader/TextClassifierDownloadLogger.java b/java/src/com/android/textclassifier/downloader/TextClassifierDownloadLogger.java index 7416b00..0bc37f0 100644 --- a/java/src/com/android/textclassifier/downloader/TextClassifierDownloadLogger.java +++ b/java/src/com/android/textclassifier/downloader/TextClassifierDownloadLogger.java @@ -86,7 +86,7 @@ final class TextClassifierDownloadLogger { ModelDownloadException.FAILED_TO_VALIDATE_MODEL, TextClassifierStatsLog .TEXT_CLASSIFIER_DOWNLOAD_REPORTED__FAILURE_REASON__FAILED_TO_VALIDATE_MODEL) - .build(); + .buildOrThrow(); // Reasons to schedule public static final int REASON_TO_SCHEDULE_TCS_STARTED = diff --git a/java/src/com/android/textclassifier/utils/IndentingPrintWriter.java b/java/src/com/android/textclassifier/utils/IndentingPrintWriter.java index bd48c22..41e6fcd 100644 --- a/java/src/com/android/textclassifier/utils/IndentingPrintWriter.java +++ b/java/src/com/android/textclassifier/utils/IndentingPrintWriter.java @@ -17,6 +17,7 @@ package com.android.textclassifier.utils; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import java.io.PrintWriter; /** @@ -36,6 +37,7 @@ public final class IndentingPrintWriter { } /** Prints a string. */ + @CanIgnoreReturnValue public IndentingPrintWriter println(String string) { writer.print(currentIndent); writer.print(string); @@ -44,12 +46,14 @@ public final class IndentingPrintWriter { } /** Prints a empty line */ + @CanIgnoreReturnValue public IndentingPrintWriter println() { writer.println(); return this; } /** Increases indents for subsequent texts. */ + @CanIgnoreReturnValue public IndentingPrintWriter increaseIndent() { indentBuilder.append(SINGLE_INDENT); currentIndent = indentBuilder.toString(); @@ -57,6 +61,7 @@ public final class IndentingPrintWriter { } /** Decreases indents for subsequent texts. */ + @CanIgnoreReturnValue public IndentingPrintWriter decreaseIndent() { indentBuilder.delete(0, SINGLE_INDENT.length()); currentIndent = indentBuilder.toString(); @@ -64,6 +69,7 @@ public final class IndentingPrintWriter { } /** Prints a key-valued pair. */ + @CanIgnoreReturnValue public IndentingPrintWriter printPair(String key, Object value) { println(String.format("%s=%s", key, String.valueOf(value))); return this; diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java index c20ec8a..89d405a 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java @@ -44,7 +44,6 @@ import android.view.textclassifier.TextLanguage; import android.view.textclassifier.TextLinks; import android.view.textclassifier.TextSelection; import androidx.collection.LruCache; -import androidx.test.core.app.ApplicationProvider; import androidx.test.ext.junit.runners.AndroidJUnit4; import androidx.test.filters.SdkSuppress; import androidx.test.filters.SmallTest; diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java index e261158..2d06afd 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java @@ -18,9 +18,9 @@ package com.android.textclassifier.downloader; import static com.google.common.truth.Truth.assertThat; -import android.util.Log; import android.view.textclassifier.TextClassification; import android.view.textclassifier.TextClassification.Request; +import androidx.test.filters.FlakyTest; import com.android.textclassifier.testing.ExtServicesTextClassifierRule; import org.junit.After; import org.junit.Before; @@ -70,6 +70,7 @@ public class ModelDownloaderIntegrationTest { } @Test + @FlakyTest(bugId = 284901878) public void smokeTest() throws Exception { extServicesTextClassifierRule.addDeviceConfigOverride( "manifest_url_annotator_en", V804_EN_ANNOTATOR_MANIFEST_URL); @@ -78,6 +79,7 @@ public class ModelDownloaderIntegrationTest { } @Test + @FlakyTest(bugId = 284901878) public void downgradeModel() throws Exception { // Download an experimental model. extServicesTextClassifierRule.addDeviceConfigOverride( @@ -93,6 +95,7 @@ public class ModelDownloaderIntegrationTest { } @Test + @FlakyTest(bugId = 284901878) public void upgradeModel() throws Exception { // Download a model. extServicesTextClassifierRule.addDeviceConfigOverride( @@ -108,6 +111,7 @@ public class ModelDownloaderIntegrationTest { } @Test + @FlakyTest(bugId = 284901878) public void clearFlag() throws Exception { // Download a new model. extServicesTextClassifierRule.addDeviceConfigOverride( @@ -123,6 +127,7 @@ public class ModelDownloaderIntegrationTest { } @Test + @FlakyTest(bugId = 267344737) public void modelsForMultipleLanguagesDownloaded() throws Exception { extServicesTextClassifierRule.addDeviceConfigOverride("multi_language_support_enabled", "true"); extServicesTextClassifierRule.addDeviceConfigOverride( @@ -168,7 +173,6 @@ public class ModelDownloaderIntegrationTest { .getTextClassifier() .classifyText(new Request.Builder(text, 0, text.length()).build()); // The result id contains the name of the just used model. - Log.d(TAG, "verifyActiveModel. TextClassification ID: " + textClassification.getId()); assertThat(textClassification.getId()).contains(expectedVersion); } diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java index 76d04e0..8c99baf 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java @@ -23,6 +23,7 @@ import static org.testng.Assert.expectThrows; import androidx.test.core.app.ApplicationProvider; import com.google.android.downloader.DownloadConstraints; +import com.google.android.downloader.DownloadDestination; import com.google.android.downloader.DownloadRequest; import com.google.android.downloader.DownloadResult; import com.google.android.downloader.Downloader; @@ -82,7 +83,7 @@ public final class ModelDownloaderServiceImplTest { targetModelFile.deleteOnExit(); targetMetadataFile.deleteOnExit(); - when(downloader.newRequestBuilder(any(), any())) + when(downloader.newRequestBuilder(any(), any(DownloadDestination.class))) .thenReturn( DownloadRequest.newBuilder() .setUri(URI.create(DOWNLOAD_URI)) diff --git a/java/tests/instrumentation/src/com/android/textclassifier/testing/FakeContextBuilder.java b/java/tests/instrumentation/src/com/android/textclassifier/testing/FakeContextBuilder.java index f3ad833..8ce4643 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/testing/FakeContextBuilder.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/testing/FakeContextBuilder.java @@ -31,6 +31,7 @@ import android.content.pm.PackageManager; import android.content.pm.ResolveInfo; import androidx.test.core.app.ApplicationProvider; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import java.util.HashMap; import java.util.Map; import java.util.UUID; @@ -66,6 +67,7 @@ public final class FakeContextBuilder { * * <p><strong>NOTE: </strong>By default, no component is set to handle any intent. */ + @CanIgnoreReturnValue public FakeContextBuilder setIntentComponent( String intentAction, @Nullable ComponentName component) { Preconditions.checkNotNull(intentAction); @@ -74,6 +76,7 @@ public final class FakeContextBuilder { } /** Sets the app label res for a specified package. */ + @CanIgnoreReturnValue public FakeContextBuilder setAppLabel(String packageName, @Nullable CharSequence appLabel) { Preconditions.checkNotNull(packageName); appLabels.put(packageName, appLabel); @@ -85,6 +88,7 @@ public final class FakeContextBuilder { * * <p><strong>NOTE: </strong>By default, no component is set to handle any intent. */ + @CanIgnoreReturnValue public FakeContextBuilder setAllIntentComponent(@Nullable ComponentName component) { allIntentComponent = component; return this; diff --git a/jni/Android.bp b/jni/Android.bp index 35ee441..17d31fe 100644 --- a/jni/Android.bp +++ b/jni/Android.bp @@ -24,7 +24,10 @@ package { java_library_static { name: "libtextclassifier-java", srcs: ["**/*.java"], - static_libs: ["guava"], + static_libs: [ + "guava", + "error_prone_annotations", + ], sdk_version: "system_current", min_sdk_version: "28", apex_available: [ diff --git a/jni/com/google/android/textclassifier/AnnotatorModel.java b/jni/com/google/android/textclassifier/AnnotatorModel.java index 47a369e..a82c96d 100644 --- a/jni/com/google/android/textclassifier/AnnotatorModel.java +++ b/jni/com/google/android/textclassifier/AnnotatorModel.java @@ -17,6 +17,7 @@ package com.google.android.textclassifier; import android.content.res.AssetFileDescriptor; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import java.util.Collection; import java.util.concurrent.atomic.AtomicBoolean; import javax.annotation.Nullable; @@ -654,41 +655,49 @@ public final class AnnotatorModel implements AutoCloseable { private boolean usePodNer = true; private boolean useVocabAnnotator = true; + @CanIgnoreReturnValue public Builder setLocales(@Nullable String locales) { this.locales = locales; return this; } + @CanIgnoreReturnValue public Builder setDetectedTextLanguageTags(@Nullable String detectedTextLanguageTags) { this.detectedTextLanguageTags = detectedTextLanguageTags; return this; } + @CanIgnoreReturnValue public Builder setAnnotationUsecase(int annotationUsecase) { this.annotationUsecase = annotationUsecase; return this; } + @CanIgnoreReturnValue public Builder setUserLocationLat(double userLocationLat) { this.userLocationLat = userLocationLat; return this; } + @CanIgnoreReturnValue public Builder setUserLocationLng(double userLocationLng) { this.userLocationLng = userLocationLng; return this; } + @CanIgnoreReturnValue public Builder setUserLocationAccuracyMeters(float userLocationAccuracyMeters) { this.userLocationAccuracyMeters = userLocationAccuracyMeters; return this; } + @CanIgnoreReturnValue public Builder setUsePodNer(boolean usePodNer) { this.usePodNer = usePodNer; return this; } + @CanIgnoreReturnValue public Builder setUseVocabAnnotator(boolean useVocabAnnotator) { this.useVocabAnnotator = useVocabAnnotator; return this; @@ -761,6 +770,8 @@ public final class AnnotatorModel implements AutoCloseable { private final boolean usePodNer; private final boolean triggerDictionaryOnBeginnerWords; private final boolean useVocabAnnotator; + private final boolean enableAddContactIntent; + private final boolean enableSearchIntent; private ClassificationOptions( long referenceTimeMsUtc, @@ -774,7 +785,9 @@ public final class AnnotatorModel implements AutoCloseable { String userFamiliarLanguageTags, boolean usePodNer, boolean triggerDictionaryOnBeginnerWords, - boolean useVocabAnnotator) { + boolean useVocabAnnotator, + boolean enableAddContactIntent, + boolean enableSearchIntent) { this.referenceTimeMsUtc = referenceTimeMsUtc; this.referenceTimezone = referenceTimezone; this.locales = locales; @@ -787,6 +800,8 @@ public final class AnnotatorModel implements AutoCloseable { this.usePodNer = usePodNer; this.triggerDictionaryOnBeginnerWords = triggerDictionaryOnBeginnerWords; this.useVocabAnnotator = useVocabAnnotator; + this.enableAddContactIntent = enableAddContactIntent; + this.enableSearchIntent = enableSearchIntent; } /** Can be used to build a ClassificationOptions instance. */ @@ -803,68 +818,94 @@ public final class AnnotatorModel implements AutoCloseable { private boolean usePodNer = true; private boolean triggerDictionaryOnBeginnerWords = false; private boolean useVocabAnnotator = true; + private boolean enableAddContactIntent = false; + private boolean enableSearchIntent = false; + @CanIgnoreReturnValue public Builder setReferenceTimeMsUtc(long referenceTimeMsUtc) { this.referenceTimeMsUtc = referenceTimeMsUtc; return this; } + @CanIgnoreReturnValue public Builder setReferenceTimezone(String referenceTimezone) { this.referenceTimezone = referenceTimezone; return this; } + @CanIgnoreReturnValue public Builder setLocales(@Nullable String locales) { this.locales = locales; return this; } + @CanIgnoreReturnValue public Builder setDetectedTextLanguageTags(@Nullable String detectedTextLanguageTags) { this.detectedTextLanguageTags = detectedTextLanguageTags; return this; } + @CanIgnoreReturnValue public Builder setAnnotationUsecase(int annotationUsecase) { this.annotationUsecase = annotationUsecase; return this; } + @CanIgnoreReturnValue public Builder setUserLocationLat(double userLocationLat) { this.userLocationLat = userLocationLat; return this; } + @CanIgnoreReturnValue public Builder setUserLocationLng(double userLocationLng) { this.userLocationLng = userLocationLng; return this; } + @CanIgnoreReturnValue public Builder setUserLocationAccuracyMeters(float userLocationAccuracyMeters) { this.userLocationAccuracyMeters = userLocationAccuracyMeters; return this; } + @CanIgnoreReturnValue public Builder setUserFamiliarLanguageTags(String userFamiliarLanguageTags) { this.userFamiliarLanguageTags = userFamiliarLanguageTags; return this; } + @CanIgnoreReturnValue public Builder setUsePodNer(boolean usePodNer) { this.usePodNer = usePodNer; return this; } + @CanIgnoreReturnValue public Builder setTrigerringDictionaryOnBeginnerWords( boolean triggerDictionaryOnBeginnerWords) { this.triggerDictionaryOnBeginnerWords = triggerDictionaryOnBeginnerWords; return this; } + @CanIgnoreReturnValue public Builder setUseVocabAnnotator(boolean useVocabAnnotator) { this.useVocabAnnotator = useVocabAnnotator; return this; } + @CanIgnoreReturnValue + public Builder setEnableAddContactIntent(boolean enableAddContactIntent) { + this.enableAddContactIntent = enableAddContactIntent; + return this; + } + + @CanIgnoreReturnValue + public Builder setEnableSearchIntent(boolean enableSearchIntent) { + this.enableSearchIntent = enableSearchIntent; + return this; + } + public ClassificationOptions build() { return new ClassificationOptions( referenceTimeMsUtc, @@ -878,7 +919,9 @@ public final class AnnotatorModel implements AutoCloseable { userFamiliarLanguageTags, usePodNer, triggerDictionaryOnBeginnerWords, - useVocabAnnotator); + useVocabAnnotator, + enableAddContactIntent, + enableSearchIntent); } } @@ -936,6 +979,14 @@ public final class AnnotatorModel implements AutoCloseable { public boolean getUseVocabAnnotator() { return useVocabAnnotator; } + + public boolean getEnableAddContactIntent() { + return enableAddContactIntent; + } + + public boolean getEnableSearchIntent() { + return enableSearchIntent; + } } /** Represents options for the annotate call. */ @@ -1011,81 +1062,97 @@ public final class AnnotatorModel implements AutoCloseable { private boolean triggerDictionaryOnBeginnerWords = false; private boolean useVocabAnnotator = true; + @CanIgnoreReturnValue public Builder setReferenceTimeMsUtc(long referenceTimeMsUtc) { this.referenceTimeMsUtc = referenceTimeMsUtc; return this; } + @CanIgnoreReturnValue public Builder setReferenceTimezone(String referenceTimezone) { this.referenceTimezone = referenceTimezone; return this; } + @CanIgnoreReturnValue public Builder setLocales(@Nullable String locales) { this.locales = locales; return this; } + @CanIgnoreReturnValue public Builder setDetectedTextLanguageTags(@Nullable String detectedTextLanguageTags) { this.detectedTextLanguageTags = detectedTextLanguageTags; return this; } + @CanIgnoreReturnValue public Builder setEntityTypes(Collection<String> entityTypes) { this.entityTypes = entityTypes; return this; } + @CanIgnoreReturnValue public Builder setAnnotateMode(int annotateMode) { this.annotateMode = annotateMode; return this; } + @CanIgnoreReturnValue public Builder setAnnotationUsecase(int annotationUsecase) { this.annotationUsecase = annotationUsecase; return this; } + @CanIgnoreReturnValue public Builder setHasLocationPermission(boolean hasLocationPermission) { this.hasLocationPermission = hasLocationPermission; return this; } + @CanIgnoreReturnValue public Builder setHasPersonalizationPermission(boolean hasPersonalizationPermission) { this.hasPersonalizationPermission = hasPersonalizationPermission; return this; } + @CanIgnoreReturnValue public Builder setIsSerializedEntityDataEnabled(boolean isSerializedEntityDataEnabled) { this.isSerializedEntityDataEnabled = isSerializedEntityDataEnabled; return this; } + @CanIgnoreReturnValue public Builder setUserLocationLat(double userLocationLat) { this.userLocationLat = userLocationLat; return this; } + @CanIgnoreReturnValue public Builder setUserLocationLng(double userLocationLng) { this.userLocationLng = userLocationLng; return this; } + @CanIgnoreReturnValue public Builder setUserLocationAccuracyMeters(float userLocationAccuracyMeters) { this.userLocationAccuracyMeters = userLocationAccuracyMeters; return this; } + @CanIgnoreReturnValue public Builder setUsePodNer(boolean usePodNer) { this.usePodNer = usePodNer; return this; } + @CanIgnoreReturnValue public Builder setTriggerDictionaryOnBeginnerWords(boolean triggerDictionaryOnBeginnerWords) { this.triggerDictionaryOnBeginnerWords = triggerDictionaryOnBeginnerWords; return this; } + @CanIgnoreReturnValue public Builder setUseVocabAnnotator(boolean useVocabAnnotator) { this.useVocabAnnotator = useVocabAnnotator; return this; diff --git a/native/Android.bp b/native/Android.bp index 9aaa951..5af5c50 100644 --- a/native/Android.bp +++ b/native/Android.bp @@ -68,6 +68,7 @@ cc_library_static { "com.android.neuralnetworks", "test_com.android.neuralnetworks", "com.android.extservices", + "com.android.adservices", ], } @@ -136,6 +137,61 @@ cc_defaults { ], } +cc_library_static { + name: "libtextclassifier_bert_tokenizer", + export_include_dirs: ["."], + visibility: ["//external/tflite-support:__subpackages__"], + srcs: [ + "utils/base/logging.cc", + "utils/base/logging_raw.cc", + "utils/bert_tokenizer.cc", + "utils/strings/utf8.cc", + "utils/tokenizer-utils.cc", + "utils/utf8/unilib-common.cc", + "utils/utf8/unicodetext.cc", + "utils/wordpiece_tokenizer.cc", + ], + apex_available: [ + "//apex_available:platform", + "com.android.extservices", + "com.android.adservices", + ], + cflags: [ + "-Wno-ignored-qualifiers", + "-Wno-missing-field-initializers", + "-Wno-unused-parameter", + + "-DLIBTEXTCLASSIFIER_UNILIB_ICU", + "-DZLIB_CONST", + "-DSAFTM_COMPACT_LOGGING", + "-DTC3_WITH_ACTIONS_OPS", + "-DTC3_UNILIB_JAVAICU", + "-DTC3_CALENDAR_JAVAICU", + "-DTC3_AOSP", + "-DTC3_VOCAB_ANNOTATOR_IMPL", + "-DTC3_POD_NER_ANNOTATOR_IMPL", + ], + product_variables: { + debuggable: { + // Only enable debug logging in userdebug/eng builds. + cflags: ["-DTC3_DEBUG_LOGGING=1"], + }, + }, + header_libs: [ + "jni_headers", + "tensorflow_headers", + "flatbuffer_headers", + "libtextclassifier_flatbuffer_headers", + ], + static_libs: [ + "libtextclassifier_abseil", + "tflite_support", + ], + sdk_version: "current", + min_sdk_version: "30", + stl: "libc++_static", +} + // ----------------- // Generate headers with FlatBuffer schema compiler. // ----------------- diff --git a/native/FlatBufferHeaders.bp b/native/FlatBufferHeaders.bp index 813ec6a..5ed09cc 100644 --- a/native/FlatBufferHeaders.bp +++ b/native/FlatBufferHeaders.bp @@ -15,65 +15,58 @@ // genrule { - name: "libtextclassifier_fbgen_lang_id_common_flatbuffers_model", - srcs: ["lang_id/common/flatbuffers/model.fbs"], - out: ["lang_id/common/flatbuffers/model_generated.h"], - defaults: ["fbgen"], -} - -genrule { - name: "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network", - srcs: ["lang_id/common/flatbuffers/embedding-network.fbs"], - out: ["lang_id/common/flatbuffers/embedding-network_generated.h"], + name: "libtextclassifier_fbgen_utils_i18n_language-tag", + srcs: ["utils/i18n/language-tag.fbs"], + out: ["utils/i18n/language-tag_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_annotator_datetime_datetime", - srcs: ["annotator/datetime/datetime.fbs"], - out: ["annotator/datetime/datetime_generated.h"], + name: "libtextclassifier_fbgen_utils_tflite_text_encoder_config", + srcs: ["utils/tflite/text_encoder_config.fbs"], + out: ["utils/tflite/text_encoder_config_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_annotator_model", - srcs: ["annotator/model.fbs"], - out: ["annotator/model_generated.h"], + name: "libtextclassifier_fbgen_utils_resources", + srcs: ["utils/resources.fbs"], + out: ["utils/resources_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_annotator_experimental_experimental", - srcs: ["annotator/experimental/experimental.fbs"], - out: ["annotator/experimental/experimental_generated.h"], + name: "libtextclassifier_fbgen_utils_grammar_rules", + srcs: ["utils/grammar/rules.fbs"], + out: ["utils/grammar/rules_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_annotator_entity-data", - srcs: ["annotator/entity-data.fbs"], - out: ["annotator/entity-data_generated.h"], + name: "libtextclassifier_fbgen_utils_grammar_testing_value", + srcs: ["utils/grammar/testing/value.fbs"], + out: ["utils/grammar/testing/value_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_annotator_person_name_person_name_model", - srcs: ["annotator/person_name/person_name_model.fbs"], - out: ["annotator/person_name/person_name_model_generated.h"], + name: "libtextclassifier_fbgen_utils_grammar_semantics_expression", + srcs: ["utils/grammar/semantics/expression.fbs"], + out: ["utils/grammar/semantics/expression_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_utils_tflite_text_encoder_config", - srcs: ["utils/tflite/text_encoder_config.fbs"], - out: ["utils/tflite/text_encoder_config_generated.h"], + name: "libtextclassifier_fbgen_utils_zlib_buffer", + srcs: ["utils/zlib/buffer.fbs"], + out: ["utils/zlib/buffer_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_utils_codepoint-range", - srcs: ["utils/codepoint-range.fbs"], - out: ["utils/codepoint-range_generated.h"], + name: "libtextclassifier_fbgen_utils_normalization", + srcs: ["utils/normalization.fbs"], + out: ["utils/normalization_generated.h"], defaults: ["fbgen"], } @@ -85,16 +78,16 @@ genrule { } genrule { - name: "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers", - srcs: ["utils/flatbuffers/flatbuffers.fbs"], - out: ["utils/flatbuffers/flatbuffers_generated.h"], + name: "libtextclassifier_fbgen_utils_container_bit-vector", + srcs: ["utils/container/bit-vector.fbs"], + out: ["utils/container/bit-vector_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_utils_zlib_buffer", - srcs: ["utils/zlib/buffer.fbs"], - out: ["utils/zlib/buffer_generated.h"], + name: "libtextclassifier_fbgen_utils_codepoint-range", + srcs: ["utils/codepoint-range.fbs"], + out: ["utils/codepoint-range_generated.h"], defaults: ["fbgen"], } @@ -106,65 +99,72 @@ genrule { } genrule { - name: "libtextclassifier_fbgen_utils_grammar_testing_value", - srcs: ["utils/grammar/testing/value.fbs"], - out: ["utils/grammar/testing/value_generated.h"], + name: "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers", + srcs: ["utils/flatbuffers/flatbuffers.fbs"], + out: ["utils/flatbuffers/flatbuffers_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_utils_grammar_rules", - srcs: ["utils/grammar/rules.fbs"], - out: ["utils/grammar/rules_generated.h"], + name: "libtextclassifier_fbgen_actions_actions_model", + srcs: ["actions/actions_model.fbs"], + out: ["actions/actions_model_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_utils_grammar_semantics_expression", - srcs: ["utils/grammar/semantics/expression.fbs"], - out: ["utils/grammar/semantics/expression_generated.h"], + name: "libtextclassifier_fbgen_actions_actions-entity-data", + srcs: ["actions/actions-entity-data.fbs"], + out: ["actions/actions-entity-data_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_utils_resources", - srcs: ["utils/resources.fbs"], - out: ["utils/resources_generated.h"], + name: "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network", + srcs: ["lang_id/common/flatbuffers/embedding-network.fbs"], + out: ["lang_id/common/flatbuffers/embedding-network_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_utils_i18n_language-tag", - srcs: ["utils/i18n/language-tag.fbs"], - out: ["utils/i18n/language-tag_generated.h"], + name: "libtextclassifier_fbgen_lang_id_common_flatbuffers_model", + srcs: ["lang_id/common/flatbuffers/model.fbs"], + out: ["lang_id/common/flatbuffers/model_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_utils_normalization", - srcs: ["utils/normalization.fbs"], - out: ["utils/normalization_generated.h"], + name: "libtextclassifier_fbgen_annotator_entity-data", + srcs: ["annotator/entity-data.fbs"], + out: ["annotator/entity-data_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_utils_container_bit-vector", - srcs: ["utils/container/bit-vector.fbs"], - out: ["utils/container/bit-vector_generated.h"], + name: "libtextclassifier_fbgen_annotator_person_name_person_name_model", + srcs: ["annotator/person_name/person_name_model.fbs"], + out: ["annotator/person_name/person_name_model_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_actions_actions-entity-data", - srcs: ["actions/actions-entity-data.fbs"], - out: ["actions/actions-entity-data_generated.h"], + name: "libtextclassifier_fbgen_annotator_experimental_experimental", + srcs: ["annotator/experimental/experimental.fbs"], + out: ["annotator/experimental/experimental_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_actions_actions_model", - srcs: ["actions/actions_model.fbs"], - out: ["actions/actions_model_generated.h"], + name: "libtextclassifier_fbgen_annotator_model", + srcs: ["annotator/model.fbs"], + out: ["annotator/model_generated.h"], + defaults: ["fbgen"], +} + +genrule { + name: "libtextclassifier_fbgen_annotator_datetime_datetime", + srcs: ["annotator/datetime/datetime.fbs"], + out: ["annotator/datetime/datetime_generated.h"], defaults: ["fbgen"], } @@ -176,52 +176,53 @@ cc_library_headers { apex_available: [ "//apex_available:platform", "com.android.extservices", + "com.android.adservices", ], generated_headers: [ - "libtextclassifier_fbgen_lang_id_common_flatbuffers_model", - "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network", - "libtextclassifier_fbgen_annotator_datetime_datetime", - "libtextclassifier_fbgen_annotator_model", - "libtextclassifier_fbgen_annotator_experimental_experimental", - "libtextclassifier_fbgen_annotator_entity-data", - "libtextclassifier_fbgen_annotator_person_name_person_name_model", + "libtextclassifier_fbgen_utils_i18n_language-tag", "libtextclassifier_fbgen_utils_tflite_text_encoder_config", - "libtextclassifier_fbgen_utils_codepoint-range", - "libtextclassifier_fbgen_utils_intents_intent-config", - "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers", - "libtextclassifier_fbgen_utils_zlib_buffer", - "libtextclassifier_fbgen_utils_tokenizer", + "libtextclassifier_fbgen_utils_resources", "libtextclassifier_fbgen_utils_grammar_rules", "libtextclassifier_fbgen_utils_grammar_semantics_expression", - "libtextclassifier_fbgen_utils_resources", - "libtextclassifier_fbgen_utils_i18n_language-tag", + "libtextclassifier_fbgen_utils_zlib_buffer", "libtextclassifier_fbgen_utils_normalization", + "libtextclassifier_fbgen_utils_intents_intent-config", "libtextclassifier_fbgen_utils_container_bit-vector", - "libtextclassifier_fbgen_actions_actions-entity-data", + "libtextclassifier_fbgen_utils_codepoint-range", + "libtextclassifier_fbgen_utils_tokenizer", + "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers", "libtextclassifier_fbgen_actions_actions_model", - ], - export_generated_headers: [ - "libtextclassifier_fbgen_lang_id_common_flatbuffers_model", + "libtextclassifier_fbgen_actions_actions-entity-data", "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network", - "libtextclassifier_fbgen_annotator_datetime_datetime", - "libtextclassifier_fbgen_annotator_model", - "libtextclassifier_fbgen_annotator_experimental_experimental", + "libtextclassifier_fbgen_lang_id_common_flatbuffers_model", "libtextclassifier_fbgen_annotator_entity-data", "libtextclassifier_fbgen_annotator_person_name_person_name_model", + "libtextclassifier_fbgen_annotator_experimental_experimental", + "libtextclassifier_fbgen_annotator_model", + "libtextclassifier_fbgen_annotator_datetime_datetime", + ], + export_generated_headers: [ + "libtextclassifier_fbgen_utils_i18n_language-tag", "libtextclassifier_fbgen_utils_tflite_text_encoder_config", - "libtextclassifier_fbgen_utils_codepoint-range", - "libtextclassifier_fbgen_utils_intents_intent-config", - "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers", - "libtextclassifier_fbgen_utils_zlib_buffer", - "libtextclassifier_fbgen_utils_tokenizer", + "libtextclassifier_fbgen_utils_resources", "libtextclassifier_fbgen_utils_grammar_rules", "libtextclassifier_fbgen_utils_grammar_semantics_expression", - "libtextclassifier_fbgen_utils_resources", - "libtextclassifier_fbgen_utils_i18n_language-tag", + "libtextclassifier_fbgen_utils_zlib_buffer", "libtextclassifier_fbgen_utils_normalization", + "libtextclassifier_fbgen_utils_intents_intent-config", "libtextclassifier_fbgen_utils_container_bit-vector", - "libtextclassifier_fbgen_actions_actions-entity-data", + "libtextclassifier_fbgen_utils_codepoint-range", + "libtextclassifier_fbgen_utils_tokenizer", + "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers", "libtextclassifier_fbgen_actions_actions_model", + "libtextclassifier_fbgen_actions_actions-entity-data", + "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network", + "libtextclassifier_fbgen_lang_id_common_flatbuffers_model", + "libtextclassifier_fbgen_annotator_entity-data", + "libtextclassifier_fbgen_annotator_person_name_person_name_model", + "libtextclassifier_fbgen_annotator_experimental_experimental", + "libtextclassifier_fbgen_annotator_model", + "libtextclassifier_fbgen_annotator_datetime_datetime", ], } diff --git a/native/JavaTests.bp b/native/JavaTests.bp index 9837173..89ef085 100644 --- a/native/JavaTests.bp +++ b/native/JavaTests.bp @@ -17,30 +17,30 @@ filegroup { name: "libtextclassifier_java_test_sources", srcs: [ - "annotator/datetime/datetime-grounder_test.cc", - "annotator/datetime/regex-parser_test.cc", - "annotator/datetime/grammar-parser_test.cc", - "annotator/pod_ner/pod-ner-impl_test.cc", + "utils/grammar/parsing/lexer_test.cc", "utils/intents/intent-generator-test-lib.cc", - "utils/calendar/calendar_test.cc", "utils/regex-match_test.cc", - "utils/grammar/parsing/lexer_test.cc", + "utils/calendar/calendar_test.cc", "actions/actions-suggestions_test.cc", "actions/grammar-actions_test.cc", - "annotator/number/number_test-include.cc", - "annotator/annotator_test-include.cc", - "annotator/grammar/grammar-annotator_test.cc", - "annotator/grammar/test-utils.cc", - "utils/utf8/unilib_test-include.cc", + "annotator/pod_ner/pod-ner-impl_test.cc", + "annotator/datetime/regex-parser_test.cc", + "annotator/datetime/grammar-parser_test.cc", + "annotator/datetime/datetime-grounder_test.cc", "utils/grammar/parsing/parser_test.cc", - "utils/grammar/analyzer_test.cc", "utils/grammar/semantics/composer_test.cc", - "utils/grammar/semantics/evaluators/constituent-eval_test.cc", - "utils/grammar/semantics/evaluators/merge-values-eval_test.cc", - "utils/grammar/semantics/evaluators/parse-number-eval_test.cc", "utils/grammar/semantics/evaluators/arithmetic-eval_test.cc", - "utils/grammar/semantics/evaluators/span-eval_test.cc", "utils/grammar/semantics/evaluators/const-eval_test.cc", + "utils/grammar/semantics/evaluators/merge-values-eval_test.cc", + "utils/grammar/semantics/evaluators/span-eval_test.cc", + "utils/grammar/semantics/evaluators/constituent-eval_test.cc", + "utils/grammar/semantics/evaluators/parse-number-eval_test.cc", "utils/grammar/semantics/evaluators/compose-eval_test.cc", + "utils/grammar/analyzer_test.cc", + "utils/utf8/unilib_test-include.cc", + "annotator/grammar/grammar-annotator_test.cc", + "annotator/grammar/test-utils.cc", + "annotator/annotator_test-include.cc", + "annotator/number/number_test-include.cc", ], } diff --git a/native/actions/actions-entity-data.bfbs b/native/actions/actions-entity-data.bfbs Binary files differindex 6ebf1cf..e5ebfec 100644 --- a/native/actions/actions-entity-data.bfbs +++ b/native/actions/actions-entity-data.bfbs diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc index 9f9a8d4..eeeb508 100644 --- a/native/actions/actions-suggestions.cc +++ b/native/actions/actions-suggestions.cc @@ -21,6 +21,8 @@ #include <vector> #include "utils/base/statusor.h" +#include "absl/container/flat_hash_map.h" +#include "absl/random/random.h" #if !defined(TC3_DISABLE_LUA) #include "actions/lua-actions.h" @@ -42,6 +44,7 @@ #include "utils/strings/utf8.h" #include "utils/utf8/unicodetext.h" #include "absl/container/flat_hash_set.h" +#include "absl/random/distributions.h" #include "tensorflow/lite/string_util.h" namespace libtextclassifier3 { @@ -813,6 +816,8 @@ void ActionsSuggestions::PopulateTextReplies( const tflite::Interpreter* interpreter, int suggestion_index, int score_index, const std::string& type, float priority_score, const absl::flat_hash_set<std::string>& blocklist, + const absl::flat_hash_map<std::string, std::vector<std::string>>& + concept_mappings, ActionsSuggestionsResponse* response) const { const std::vector<tflite::StringRef> replies = model_executor_->Output<tflite::StringRef>(suggestion_index, interpreter); @@ -831,6 +836,12 @@ void ActionsSuggestions::PopulateTextReplies( if (blocklist.contains(response_text)) { continue; } + if (concept_mappings.contains(response_text)) { + const int candidates_size = concept_mappings.at(response_text).size(); + const int candidate_index = absl::Uniform<int>( + absl::IntervalOpenOpen, bit_gen_, 0, candidates_size); + response_text = concept_mappings.at(response_text)[candidate_index]; + } response->actions.push_back({response_text, type, score, priority_score}); } @@ -918,11 +929,11 @@ bool ActionsSuggestions::ReadModelOutput( if (!response->output_filtered_min_triggering_score && model_->tflite_model_spec()->output_replies() >= 0) { absl::flat_hash_set<std::string> empty_blocklist; - PopulateTextReplies(interpreter, - model_->tflite_model_spec()->output_replies(), - model_->tflite_model_spec()->output_replies_scores(), - model_->smart_reply_action_type()->str(), - /* priority_score */ 0.0, empty_blocklist, response); + PopulateTextReplies( + interpreter, model_->tflite_model_spec()->output_replies(), + model_->tflite_model_spec()->output_replies_scores(), + model_->smart_reply_action_type()->str(), + /* priority_score */ 0.0, empty_blocklist, {}, response); } // Read actions suggestions. @@ -961,6 +972,8 @@ bool ActionsSuggestions::ReadModelOutput( const int suggestions_scores_index = metadata->output_suggestions_scores(); absl::flat_hash_set<std::string> response_text_blocklist; + absl::flat_hash_map<std::string, std::vector<std::string>> + concept_mappings; switch (metadata->prediction_type()) { case PredictionType_NEXT_MESSAGE_PREDICTION: if (!task_spec || task_spec->type()->size() == 0) { @@ -973,13 +986,22 @@ bool ActionsSuggestions::ReadModelOutput( response_text_blocklist.insert(val->str()); } } + if (task_spec->concept_mappings()) { + for (const auto& concept : *task_spec->concept_mappings()) { + std::vector<std::string> candidates; + for (const auto& candidate : *concept->candidates()) { + candidates.push_back(candidate->str()); + } + concept_mappings[concept->concept_name()->str()] = candidates; + } + } } PopulateTextReplies( interpreter, suggestions_index, suggestions_scores_index, task_spec ? task_spec->type()->str() : model_->smart_reply_action_type()->str(), task_spec ? task_spec->priority_score() : 0.0, - response_text_blocklist, response); + response_text_blocklist, concept_mappings, response); break; case PredictionType_INTENT_TRIGGERING: PopulateIntentTriggering(interpreter, suggestions_index, diff --git a/native/actions/actions-suggestions.h b/native/actions/actions-suggestions.h index 87f55fb..c3d58e4 100644 --- a/native/actions/actions-suggestions.h +++ b/native/actions/actions-suggestions.h @@ -43,7 +43,9 @@ #include "utils/utf8/unilib.h" #include "utils/variant.h" #include "utils/zlib/zlib.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/random/random.h" namespace libtextclassifier3 { @@ -175,11 +177,13 @@ class ActionsSuggestions { void FillSuggestionFromSpecWithEntityData(const ActionSuggestionSpec* spec, ActionSuggestion* suggestion) const; - void PopulateTextReplies(const tflite::Interpreter* interpreter, - int suggestion_index, int score_index, - const std::string& type, float priority_score, - const absl::flat_hash_set<std::string>& blocklist, - ActionsSuggestionsResponse* response) const; + void PopulateTextReplies( + const tflite::Interpreter* interpreter, int suggestion_index, + int score_index, const std::string& type, float priority_score, + const absl::flat_hash_set<std::string>& blocklist, + const absl::flat_hash_map<std::string, std::vector<std::string>>& + concept_mappings, + ActionsSuggestionsResponse* response) const; void PopulateIntentTriggering(const tflite::Interpreter* interpreter, int suggestion_index, int score_index, @@ -273,6 +277,9 @@ class ActionsSuggestions { // Conversation intent detection model for additional actions. std::unique_ptr<const ConversationIntentDetection> conversation_intent_detection_; + + // Used for randomly selecting candidates. + mutable absl::BitGen bit_gen_; }; // Interprets the buffer as a Model flatbuffer and returns it for reading. diff --git a/native/actions/actions-suggestions_test.cc b/native/actions/actions-suggestions_test.cc index b51ebc7..65f9796 100644 --- a/native/actions/actions-suggestions_test.cc +++ b/native/actions/actions-suggestions_test.cc @@ -61,6 +61,8 @@ constexpr char kMultiTaskSrP13nModelFileName[] = "actions_suggestions_test.multi_task_sr_p13n.model"; constexpr char kMultiTaskSrEmojiModelFileName[] = "actions_suggestions_test.multi_task_sr_emoji.model"; +constexpr char kMultiTaskSrEmojiConceptModelFileName[] = + "actions_suggestions_test.multi_task_sr_emoji_concept.model"; constexpr char kSensitiveTFliteModelFileName[] = "actions_suggestions_test.sensitive_tflite.model"; constexpr char kLiveRelayTFLiteModelFileName[] = @@ -1835,6 +1837,25 @@ TEST_F(ActionsSuggestionsTest, MultiTaskSrEmojiModelRemovesTextHeadEmoji) { EXPECT_EQ(response.actions[2].type, "text_reply"); } +TEST_F(ActionsSuggestionsTest, MultiTaskSrEmojiModelUsesConcepts) { + std::unique_ptr<ActionsSuggestions> actions_suggestions = + LoadTestModel(kMultiTaskSrEmojiConceptModelFileName); + + const ActionsSuggestionsResponse response = + actions_suggestions->SuggestActions( + {{{/*user_id=*/1, "i am tired", + /*reference_time_ms_utc=*/0, + /*reference_timezone=*/"Europe/Zurich", + /*annotations=*/{}, + /*locales=*/"en"}}}); + std::vector<std::string> sigh_emojis = {"😔", "😞"}; + + EXPECT_TRUE(std::find(sigh_emojis.begin(), sigh_emojis.end(), + response.actions[0].response_text) != + sigh_emojis.end()); + EXPECT_EQ(response.actions[0].type, "emoji_reply"); +} + TEST_F(ActionsSuggestionsTest, LiveRelayModel) { std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel(kLiveRelayTFLiteModelFileName); diff --git a/native/actions/actions_model.fbs b/native/actions/actions_model.fbs index 0d8c7ad..70f9104 100644 --- a/native/actions/actions_model.fbs +++ b/native/actions/actions_model.fbs @@ -312,6 +312,15 @@ table TriggeringPreconditions { min_reply_score_threshold:float = 0; } +// This proto handles model outputs that are concepts, such as emoji concept +// suggestion models. Each concept maps to a list of candidates. One of +// the candidates is chosen randomly as the final suggestion. +namespace libtextclassifier3; +table ActionConceptToSuggestion { + concept_name:string (shared); + candidates:[string]; +} + namespace libtextclassifier3; table ActionSuggestionSpec { // Type of the action suggestion. @@ -331,6 +340,10 @@ table ActionSuggestionSpec { entity_data:ActionsEntityData; response_text_blocklist:[string]; + + // If provided, map the response as concept to one of the corresponding + // candidates. + concept_mappings:[ActionConceptToSuggestion]; } // Options to specify triggering behaviour per action class. diff --git a/native/actions/test_data/actions_suggestions_grammar_test.model b/native/actions/test_data/actions_suggestions_grammar_test.model Binary files differindex 0fa7f7e..6d7bdb0 100644 --- a/native/actions/test_data/actions_suggestions_grammar_test.model +++ b/native/actions/test_data/actions_suggestions_grammar_test.model diff --git a/native/actions/test_data/actions_suggestions_test.live_relay.model b/native/actions/test_data/actions_suggestions_test.live_relay.model Binary files differindex 6ff4302..af5e10b 100644 --- a/native/actions/test_data/actions_suggestions_test.live_relay.model +++ b/native/actions/test_data/actions_suggestions_test.live_relay.model diff --git a/native/actions/test_data/actions_suggestions_test.model b/native/actions/test_data/actions_suggestions_test.model Binary files differindex 6107e98..88f62eb 100644 --- a/native/actions/test_data/actions_suggestions_test.model +++ b/native/actions/test_data/actions_suggestions_test.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model Binary files differindex 436ed93..40a2409 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model Binary files differindex 935691d..effb2cb 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji_concept.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji_concept.model Binary files differnew file mode 100644 index 0000000..18333d6 --- /dev/null +++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji_concept.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model Binary files differindex 2c9f74b..e41ab39 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model Binary files differindex cdb7523..5314b43 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model Binary files differindex ac28fa2..a633742 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model diff --git a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model Binary files differindex d864b79..6685d26 100644 --- a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model +++ b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc index e0d4241..a8483f1 100644 --- a/native/annotator/annotator.cc +++ b/native/annotator/annotator.cc @@ -432,7 +432,8 @@ void Annotator::ValidateAndInitialize(const Model* model, const UniLib* unilib, datetime_parser_ = std::make_unique<GrammarDatetimeParser>( *analyzer_, *datetime_grounder_, /*target_classification_score=*/1.0, - /*priority_score=*/1.0); + /*priority_score=*/1.0, + model_->datetime_grammar_model()->enabled_modes()); } } else if (model_->datetime_model()) { datetime_parser_ = RegexDatetimeParser::Instance( @@ -604,6 +605,8 @@ bool Annotator::InitializeKnowledgeEngine( if (model_->triggering_options() != nullptr) { knowledge_engine->SetPriorityScore( model_->triggering_options()->knowledge_priority_score()); + knowledge_engine->SetEnabledModes( + model_->triggering_options()->knowledge_enabled_modes()); } knowledge_engine_ = std::move(knowledge_engine); return true; @@ -621,10 +624,21 @@ bool Annotator::InitializeContactEngine(const std::string& serialized_config) { return true; } +void Annotator::CleanUpContactEngine() { + if (contact_engine_ == nullptr) { + TC3_LOG(INFO) + << "Attempting to clean up contact engine that does not exist."; + return; + } + contact_engine_->CleanUp(); +} + bool Annotator::InitializeInstalledAppEngine( const std::string& serialized_config) { std::unique_ptr<InstalledAppEngine> installed_app_engine( - new InstalledAppEngine(selection_feature_processor_.get(), unilib_)); + new InstalledAppEngine( + selection_feature_processor_.get(), unilib_, + model_->triggering_options()->installed_app_enabled_modes())); if (!installed_app_engine->Initialize(serialized_config)) { TC3_LOG(ERROR) << "Failed to initialize the installed app engine."; return false; @@ -912,38 +926,40 @@ CodepointSpan Annotator::SuggestSelection( !knowledge_engine_ ->Chunk(context, options.annotation_usecase, options.location_context, Permissions(), - AnnotateMode::kEntityAnnotation, &candidates) + AnnotateMode::kEntityAnnotation, ModeFlag_SELECTION, + &candidates) .ok()) { TC3_LOG(ERROR) << "Knowledge suggest selection failed."; return original_click_indices; } if (contact_engine_ != nullptr && - !contact_engine_->Chunk(context_unicode, tokens, + !contact_engine_->Chunk(context_unicode, tokens, ModeFlag_SELECTION, &candidates.annotated_spans[0])) { TC3_LOG(ERROR) << "Contact suggest selection failed."; return original_click_indices; } if (installed_app_engine_ != nullptr && - !installed_app_engine_->Chunk(context_unicode, tokens, + !installed_app_engine_->Chunk(context_unicode, tokens, ModeFlag_SELECTION, &candidates.annotated_spans[0])) { TC3_LOG(ERROR) << "Installed app suggest selection failed."; return original_click_indices; } if (number_annotator_ != nullptr && !number_annotator_->FindAll(context_unicode, options.annotation_usecase, + ModeFlag_SELECTION, &candidates.annotated_spans[0])) { TC3_LOG(ERROR) << "Number annotator failed in suggest selection."; return original_click_indices; } if (duration_annotator_ != nullptr && - !duration_annotator_->FindAll(context_unicode, tokens, - options.annotation_usecase, - &candidates.annotated_spans[0])) { + !duration_annotator_->FindAll( + context_unicode, tokens, options.annotation_usecase, + ModeFlag_SELECTION, &candidates.annotated_spans[0])) { TC3_LOG(ERROR) << "Duration annotator failed in suggest selection."; return original_click_indices; } if (person_name_engine_ != nullptr && - !person_name_engine_->Chunk(context_unicode, tokens, + !person_name_engine_->Chunk(context_unicode, tokens, ModeFlag_SELECTION, &candidates.annotated_spans[0])) { TC3_LOG(ERROR) << "Person name suggest selection failed."; return original_click_indices; @@ -964,7 +980,9 @@ CodepointSpan Annotator::SuggestSelection( candidates.annotated_spans[0].push_back(pod_ner_suggested_span); } - if (experimental_annotator_ != nullptr) { + if (experimental_annotator_ != nullptr && + (model_->triggering_options()->experimental_enabled_modes() & + ModeFlag_SELECTION)) { candidates.annotated_spans[0].push_back( experimental_annotator_->SuggestSelection(context_unicode, click_indices)); @@ -1896,7 +1914,9 @@ std::vector<ClassificationResult> Annotator::ClassifyText( candidates.push_back({selection_indices, {vocab_annotator_result}}); } - if (experimental_annotator_) { + if (experimental_annotator_ && + (model_->triggering_options()->experimental_enabled_modes() & + ModeFlag_CLASSIFICATION)) { experimental_annotator_->ClassifyText(context_unicode, selection_indices, candidates); } @@ -2218,7 +2238,8 @@ Status Annotator::AnnotateSingleInput( const bool contact_annotations_enabled = !is_raw_usecase || is_entity_type_enabled(Collections::Contact()); if (contact_annotations_enabled && contact_engine_ && - !contact_engine_->Chunk(context_unicode, tokens, candidates)) { + !contact_engine_->Chunk(context_unicode, tokens, ModeFlag_ANNOTATION, + candidates)) { return Status(StatusCode::INTERNAL, "Couldn't run contact engine Chunk."); } @@ -2226,7 +2247,8 @@ Status Annotator::AnnotateSingleInput( const bool app_annotations_enabled = !is_raw_usecase || is_entity_type_enabled(Collections::App()); if (app_annotations_enabled && installed_app_engine_ && - !installed_app_engine_->Chunk(context_unicode, tokens, candidates)) { + !installed_app_engine_->Chunk(context_unicode, tokens, + ModeFlag_ANNOTATION, candidates)) { return Status(StatusCode::INTERNAL, "Couldn't run installed app engine Chunk."); } @@ -2237,7 +2259,7 @@ Status Annotator::AnnotateSingleInput( is_entity_type_enabled(Collections::Percentage())); if (number_annotations_enabled && number_annotator_ != nullptr && !number_annotator_->FindAll(context_unicode, options.annotation_usecase, - candidates)) { + ModeFlag_ANNOTATION, candidates)) { return Status(StatusCode::INTERNAL, "Couldn't run number annotator FindAll."); } @@ -2247,7 +2269,8 @@ Status Annotator::AnnotateSingleInput( !is_raw_usecase || is_entity_type_enabled(Collections::Duration()); if (duration_annotations_enabled && duration_annotator_ != nullptr && !duration_annotator_->FindAll(context_unicode, tokens, - options.annotation_usecase, candidates)) { + options.annotation_usecase, + ModeFlag_ANNOTATION, candidates)) { return Status(StatusCode::INTERNAL, "Couldn't run duration annotator FindAll."); } @@ -2256,7 +2279,8 @@ Status Annotator::AnnotateSingleInput( const bool person_annotations_enabled = !is_raw_usecase || is_entity_type_enabled(Collections::PersonName()); if (person_annotations_enabled && person_name_engine_ && - !person_name_engine_->Chunk(context_unicode, tokens, candidates)) { + !person_name_engine_->Chunk(context_unicode, tokens, ModeFlag_ANNOTATION, + candidates)) { return Status(StatusCode::INTERNAL, "Couldn't run person name engine Chunk."); } @@ -2290,6 +2314,8 @@ Status Annotator::AnnotateSingleInput( // Annotate with the experimental annotator. if (experimental_annotator_ != nullptr && + (model_->triggering_options()->experimental_enabled_modes() & + ModeFlag_ANNOTATION) && !experimental_annotator_->Annotate(context_unicode, candidates)) { return Status(StatusCode::INTERNAL, "Couldn't run experimental annotator."); } @@ -2376,14 +2402,21 @@ StatusOr<Annotations> Annotator::AnnotateStructuredInput( .relative_bounding_box_height = string_fragment.bounding_box_height}); } + const EnabledEntityTypes is_entity_type_enabled(options.entity_types); + const bool is_raw_usecase = + options.annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW; + + const bool knowledge_engine_annotations_enabled = + !is_raw_usecase || is_entity_type_enabled(Collections::Entity()); // KnowledgeEngine is special, because it supports annotation of multiple // fragments at once. - if (knowledge_engine_ && + if (knowledge_engine_annotations_enabled && knowledge_engine_ && !knowledge_engine_ ->ChunkMultipleSpans(text_to_annotate, fragment_metadata, options.annotation_usecase, options.location_context, options.permissions, - options.annotate_mode, &annotation_candidates) + options.annotate_mode, ModeFlag_ANNOTATION, + &annotation_candidates) .ok()) { return Status(StatusCode::INTERNAL, "Couldn't run knowledge engine Chunk."); } diff --git a/native/annotator/annotator.h b/native/annotator/annotator.h index d69fe32..5df8129 100644 --- a/native/annotator/annotator.h +++ b/native/annotator/annotator.h @@ -149,6 +149,9 @@ class Annotator { // Initializes the contact engine with the given config. bool InitializeContactEngine(const std::string& serialized_config); + // Cleans up the resources associated with the contact engine. + void CleanUpContactEngine(); + // Initializes the installed app engine with the given config. bool InitializeInstalledAppEngine(const std::string& serialized_config); diff --git a/native/annotator/annotator_jni.cc b/native/annotator/annotator_jni.cc index 6e7eeab..3d352c6 100644 --- a/native/annotator/annotator_jni.cc +++ b/native/annotator/annotator_jni.cc @@ -275,6 +275,7 @@ StatusOr<ScopedLocalRef<jobject>> ClassificationResultWithIntentsToJObject( device_locales, classification_result, options->reference_time_ms_utc, context, selection_indices, app_context, model_context->model()->entity_data_schema(), + options->enable_add_contact_intent, options->enable_search_intent, &remote_action_templates)) { return {Status::UNKNOWN}; } @@ -896,6 +897,9 @@ TC3_JNI_METHOD(void, TC3_ANNOTATOR_CLASS_NAME, nativeCloseAnnotator) (JNIEnv* env, jobject thiz, jlong ptr) { const AnnotatorJniContext* context = reinterpret_cast<AnnotatorJniContext*>(ptr); + if (context != nullptr && context->model()) { + context->model()->CleanUpContactEngine(); + } delete context; } diff --git a/native/annotator/annotator_jni_common.cc b/native/annotator/annotator_jni_common.cc index a6f636f..6ee4977 100644 --- a/native/annotator/annotator_jni_common.cc +++ b/native/annotator/annotator_jni_common.cc @@ -279,6 +279,23 @@ StatusOr<ClassificationOptions> FromJavaClassificationOptions( JniHelper::CallBooleanMethod(env, joptions, get_trigger_dictionary_on_beginner_words)); + // .getEnableAddContactIntent() + TC3_ASSIGN_OR_RETURN( + jmethodID get_enable_add_contact_intent, + JniHelper::GetMethodID(env, options_class.get(), + "getEnableAddContactIntent", "()Z")); + TC3_ASSIGN_OR_RETURN(classifier_options.enable_add_contact_intent, + JniHelper::CallBooleanMethod( + env, joptions, get_enable_add_contact_intent)); + + // .getEnableSearchIntent() + TC3_ASSIGN_OR_RETURN(jmethodID get_enable_search_intent, + JniHelper::GetMethodID(env, options_class.get(), + "getEnableSearchIntent", "()Z")); + TC3_ASSIGN_OR_RETURN( + classifier_options.enable_search_intent, + JniHelper::CallBooleanMethod(env, joptions, get_enable_search_intent)); + return classifier_options; } diff --git a/native/annotator/collections.h b/native/annotator/collections.h index 417b447..becdcdb 100644 --- a/native/annotator/collections.h +++ b/native/annotator/collections.h @@ -144,6 +144,36 @@ class Collections { *[]() { return new std::string("otp_code"); }(); return value; } + static const std::string& Art() { + static const std::string& value = + *[]() { return new std::string("art"); }(); + return value; + } + static const std::string& ConsumerGood() { + static const std::string& value = + *[]() { return new std::string("consumer_good"); }(); + return value; + } + static const std::string& Event() { + static const std::string& value = + *[]() { return new std::string("event"); }(); + return value; + } + static const std::string& Location() { + static const std::string& value = + *[]() { return new std::string("location"); }(); + return value; + } + static const std::string& Organization() { + static const std::string& value = + *[]() { return new std::string("organization"); }(); + return value; + } + static const std::string& Person() { + static const std::string& value = + *[]() { return new std::string("person"); }(); + return value; + } }; } // namespace libtextclassifier3 diff --git a/native/annotator/contact/contact-engine-dummy.h b/native/annotator/contact/contact-engine-dummy.h index fe60203..211553c 100644 --- a/native/annotator/contact/contact-engine-dummy.h +++ b/native/annotator/contact/contact-engine-dummy.h @@ -47,13 +47,15 @@ class ContactEngine { } bool Chunk(const UnicodeText& context_unicode, - const std::vector<Token>& tokens, + const std::vector<Token>& tokens, ModeFlag mode, std::vector<AnnotatedSpan>* result) const { return true; } void AddContactMetadataToKnowledgeClassificationResult( ClassificationResult* classification_result) const {} + + void CleanUp() const {} }; } // namespace libtextclassifier3 diff --git a/native/annotator/datetime/grammar-parser.cc b/native/annotator/datetime/grammar-parser.cc index 6d51c19..c49a01a 100644 --- a/native/annotator/datetime/grammar-parser.cc +++ b/native/annotator/datetime/grammar-parser.cc @@ -20,6 +20,7 @@ #include <unordered_set> #include "annotator/datetime/datetime-grounder.h" +#include "annotator/model_generated.h" #include "annotator/types.h" #include "utils/grammar/analyzer.h" #include "utils/grammar/evaluated-derivation.h" @@ -33,11 +34,13 @@ namespace libtextclassifier3 { GrammarDatetimeParser::GrammarDatetimeParser( const grammar::Analyzer& analyzer, const DatetimeGrounder& datetime_grounder, - const float target_classification_score, const float priority_score) + const float target_classification_score, const float priority_score, + ModeFlag enabled_modes) : analyzer_(analyzer), datetime_grounder_(datetime_grounder), target_classification_score_(target_classification_score), - priority_score_(priority_score) {} + priority_score_(priority_score), + enabled_modes_(enabled_modes) {} StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse( const std::string& input, const int64 reference_time_ms_utc, @@ -54,6 +57,10 @@ StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse( const std::string& reference_timezone, const LocaleList& locale_list, ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end) const { + if (!(enabled_modes_ & mode)) { + return std::vector<DatetimeParseResultSpan>(); + } + std::vector<DatetimeParseResultSpan> results; UnsafeArena arena(/*block_size=*/16 << 10); std::vector<Locale> locales = locale_list.GetLocales(); diff --git a/native/annotator/datetime/grammar-parser.h b/native/annotator/datetime/grammar-parser.h index 6ff4b46..35da843 100644 --- a/native/annotator/datetime/grammar-parser.h +++ b/native/annotator/datetime/grammar-parser.h @@ -22,6 +22,7 @@ #include "annotator/datetime/datetime-grounder.h" #include "annotator/datetime/parser.h" +#include "annotator/model_generated.h" #include "annotator/types.h" #include "utils/base/statusor.h" #include "utils/grammar/analyzer.h" @@ -37,7 +38,8 @@ class GrammarDatetimeParser : public DatetimeParser { explicit GrammarDatetimeParser(const grammar::Analyzer& analyzer, const DatetimeGrounder& datetime_grounder, const float target_classification_score, - const float priority_score); + const float priority_score, + ModeFlag enabled_modes); // Parses the dates in 'input' and fills result. Makes sure that the results // do not overlap. @@ -61,6 +63,7 @@ class GrammarDatetimeParser : public DatetimeParser { const DatetimeGrounder& datetime_grounder_; const float target_classification_score_; const float priority_score_; + const ModeFlag enabled_modes_; }; } // namespace libtextclassifier3 diff --git a/native/annotator/datetime/grammar-parser_test.cc b/native/annotator/datetime/grammar-parser_test.cc index cf2dffd..b8a270d 100644 --- a/native/annotator/datetime/grammar-parser_test.cc +++ b/native/annotator/datetime/grammar-parser_test.cc @@ -22,6 +22,7 @@ #include "annotator/datetime/datetime-grounder.h" #include "annotator/datetime/testing/base-parser-test.h" #include "annotator/datetime/testing/datetime-component-builder.h" +#include "annotator/model_generated.h" #include "utils/grammar/analyzer.h" #include "utils/jvm-test-utils.h" #include "utils/test-data-test-utils.h" @@ -42,7 +43,15 @@ std::string ReadFile(const std::string& file_name) { class GrammarDatetimeParserTest : public DateTimeParserTest { public: - void SetUp() override { + void SetUp() override { ResetParser(ModeFlag_ALL); } + + // Exposes the date time parser for tests and evaluations. + const DatetimeParser* DatetimeParserForTests() const override { + return parser_.get(); + } + + protected: + void ResetParser(ModeFlag enabled_modes) { grammar_buffer_ = ReadFile(GetModelPath() + "datetime.fb"); unilib_ = CreateUniLibForTesting(); calendarlib_ = CreateCalendarLibForTesting(); @@ -51,12 +60,8 @@ class GrammarDatetimeParserTest : public DateTimeParserTest { datetime_grounder_ = std::make_unique<DatetimeGrounder>(calendarlib_.get()); parser_.reset(new GrammarDatetimeParser(*analyzer_, *datetime_grounder_, /*target_classification_score=*/1.0, - /*priority_score=*/1.0)); - } - - // Exposes the date time parser for tests and evaluations. - const DatetimeParser* DatetimeParserForTests() const override { - return parser_.get(); + /*priority_score=*/1.0, + enabled_modes)); } private: @@ -486,6 +491,13 @@ TEST_F(GrammarDatetimeParserTest, Parse) { .Build()})); } +TEST_F(GrammarDatetimeParserTest, NotEnabledModeHasNoResult) { + ResetParser(ModeFlag_SELECTION); + // `DateTimeParserTest` implementation parses the input under the ANNOTATION + // mode. + EXPECT_TRUE(HasNoResult("{January 1, 1988}")); +} + TEST_F(GrammarDatetimeParserTest, DateValidation) { EXPECT_TRUE(ParsesCorrectly( "{01/02/2020}", 1577919600000, GRANULARITY_DAY, diff --git a/native/annotator/duration/duration.cc b/native/annotator/duration/duration.cc index c59b8e0..df4c60d 100644 --- a/native/annotator/duration/duration.cc +++ b/native/annotator/duration/duration.cc @@ -20,6 +20,7 @@ #include <cstdlib> #include "annotator/collections.h" +#include "annotator/model_generated.h" #include "annotator/types.h" #include "utils/base/logging.h" #include "utils/base/macros.h" @@ -125,8 +126,10 @@ bool DurationAnnotator::ClassifyText( const UnicodeText& context, CodepointSpan selection_indices, AnnotationUsecase annotation_usecase, ClassificationResult* classification_result) const { - if (!options_->enabled() || ((options_->enabled_annotation_usecases() & - (1 << annotation_usecase))) == 0) { + if (!options_->enabled() || + ((options_->enabled_annotation_usecases() & (1 << annotation_usecase))) == + 0 || + !(options_->enabled_modes() & ModeFlag_CLASSIFICATION)) { return false; } @@ -151,9 +154,12 @@ bool DurationAnnotator::ClassifyText( bool DurationAnnotator::FindAll(const UnicodeText& context, const std::vector<Token>& tokens, AnnotationUsecase annotation_usecase, + ModeFlag mode, std::vector<AnnotatedSpan>* results) const { - if (!options_->enabled() || ((options_->enabled_annotation_usecases() & - (1 << annotation_usecase))) == 0) { + if (!options_->enabled() || + ((options_->enabled_annotation_usecases() & (1 << annotation_usecase))) == + 0 || + !(options_->enabled_modes() & mode)) { return true; } diff --git a/native/annotator/duration/duration.h b/native/annotator/duration/duration.h index 1a42ac3..e99542c 100644 --- a/native/annotator/duration/duration.h +++ b/native/annotator/duration/duration.h @@ -87,7 +87,7 @@ class DurationAnnotator { // Finds all duration instances in the input text. bool FindAll(const UnicodeText& context, const std::vector<Token>& tokens, - AnnotationUsecase annotation_usecase, + AnnotationUsecase annotation_usecase, ModeFlag mode, std::vector<AnnotatedSpan>* results) const; private: diff --git a/native/annotator/duration/duration_test.cc b/native/annotator/duration/duration_test.cc index 7c07a72..f726058 100644 --- a/native/annotator/duration/duration_test.cc +++ b/native/annotator/duration/duration_test.cc @@ -16,6 +16,7 @@ #include "annotator/duration/duration.h" +#include <cstddef> #include <string> #include <vector> @@ -37,41 +38,61 @@ using testing::ElementsAre; using testing::Field; using testing::IsEmpty; -const DurationAnnotatorOptions* TestingDurationAnnotatorOptions() { - static const flatbuffers::DetachedBuffer* options_data = []() { - DurationAnnotatorOptionsT options; - options.enabled = true; +namespace { +const flatbuffers::DetachedBuffer* CreateOptionsData(ModeFlag enabled_modes) { + DurationAnnotatorOptionsT options; + options.enabled = true; + options.enabled_modes = enabled_modes; - options.week_expressions.push_back("week"); - options.week_expressions.push_back("weeks"); + options.week_expressions.push_back("week"); + options.week_expressions.push_back("weeks"); - options.day_expressions.push_back("day"); - options.day_expressions.push_back("days"); + options.day_expressions.push_back("day"); + options.day_expressions.push_back("days"); - options.hour_expressions.push_back("hour"); - options.hour_expressions.push_back("hours"); + options.hour_expressions.push_back("hour"); + options.hour_expressions.push_back("hours"); - options.minute_expressions.push_back("minute"); - options.minute_expressions.push_back("minutes"); + options.minute_expressions.push_back("minute"); + options.minute_expressions.push_back("minutes"); - options.second_expressions.push_back("second"); - options.second_expressions.push_back("seconds"); + options.second_expressions.push_back("second"); + options.second_expressions.push_back("seconds"); - options.filler_expressions.push_back("and"); - options.filler_expressions.push_back("a"); - options.filler_expressions.push_back("an"); - options.filler_expressions.push_back("one"); + options.filler_expressions.push_back("and"); + options.filler_expressions.push_back("a"); + options.filler_expressions.push_back("an"); + options.filler_expressions.push_back("one"); - options.half_expressions.push_back("half"); + options.half_expressions.push_back("half"); - options.sub_token_separator_codepoints.push_back('-'); + options.sub_token_separator_codepoints.push_back('-'); - flatbuffers::FlatBufferBuilder builder; - builder.Finish(DurationAnnotatorOptions::Pack(builder, &options)); - return new flatbuffers::DetachedBuffer(builder.Release()); - }(); + flatbuffers::FlatBufferBuilder builder; + builder.Finish(DurationAnnotatorOptions::Pack(builder, &options)); + return new flatbuffers::DetachedBuffer(builder.Release()); +} +} // namespace - return flatbuffers::GetRoot<DurationAnnotatorOptions>(options_data->data()); +const DurationAnnotatorOptions* TestingDurationAnnotatorOptions( + ModeFlag enabled_modes) { + static const flatbuffers::DetachedBuffer* options_data_all = + CreateOptionsData(ModeFlag_ALL); + static const flatbuffers::DetachedBuffer* options_data_selection = + CreateOptionsData(ModeFlag_SELECTION); + static const flatbuffers::DetachedBuffer* options_data_no_selection = + CreateOptionsData(ModeFlag_ANNOTATION_AND_CLASSIFICATION); + + if (enabled_modes == ModeFlag_SELECTION) { + return flatbuffers::GetRoot<DurationAnnotatorOptions>( + options_data_selection->data()); + } else if (enabled_modes == ModeFlag_ANNOTATION_AND_CLASSIFICATION) { + return flatbuffers::GetRoot<DurationAnnotatorOptions>( + options_data_no_selection->data()); + } else { + return flatbuffers::GetRoot<DurationAnnotatorOptions>( + options_data_all->data()); + } } std::unique_ptr<FeatureProcessor> BuildFeatureProcessor(const UniLib* unilib) { @@ -103,10 +124,10 @@ std::unique_ptr<FeatureProcessor> BuildFeatureProcessor(const UniLib* unilib) { class DurationAnnotatorTest : public ::testing::Test { protected: - DurationAnnotatorTest() + explicit DurationAnnotatorTest(ModeFlag enabled_modes = ModeFlag_ALL) : INIT_UNILIB_FOR_TESTING(unilib_), feature_processor_(BuildFeatureProcessor(&unilib_)), - duration_annotator_(TestingDurationAnnotatorOptions(), + duration_annotator_(TestingDurationAnnotatorOptions(enabled_modes), feature_processor_.get(), &unilib_) {} std::vector<Token> Tokenize(const UnicodeText& text) { @@ -118,6 +139,19 @@ class DurationAnnotatorTest : public ::testing::Test { DurationAnnotator duration_annotator_; }; +class DurationAnnotatorForAnnotationAndClassificationTest + : public DurationAnnotatorTest { + protected: + DurationAnnotatorForAnnotationAndClassificationTest() + : DurationAnnotatorTest(ModeFlag_ANNOTATION_AND_CLASSIFICATION) {} +}; + +class DurationAnnotatorForSelectionTest : public DurationAnnotatorTest { + protected: + DurationAnnotatorForSelectionTest() + : DurationAnnotatorTest(ModeFlag_SELECTION) {} +}; + TEST_F(DurationAnnotatorTest, ClassifiesSimpleDuration) { ClassificationResult classification; EXPECT_TRUE(duration_annotator_.ClassifyText( @@ -129,6 +163,14 @@ TEST_F(DurationAnnotatorTest, ClassifiesSimpleDuration) { Field(&ClassificationResult::duration_ms, 15 * 60 * 1000))); } +TEST_F(DurationAnnotatorForSelectionTest, + ClassifyTextDisabledClassificationReturnsFalse) { + ClassificationResult classification; + EXPECT_FALSE(duration_annotator_.ClassifyText( + UTF8ToUnicodeText("Wake me up in 15 minutes ok?"), {14, 24}, + AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification)); +} + TEST_F(DurationAnnotatorTest, ClassifiesWhenTokensDontAlignWithSelection) { ClassificationResult classification; EXPECT_TRUE(duration_annotator_.ClassifyText( @@ -152,7 +194,8 @@ TEST_F(DurationAnnotatorTest, FindsSimpleDuration) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_SELECTION, &result)); EXPECT_THAT( result, @@ -165,13 +208,26 @@ TEST_F(DurationAnnotatorTest, FindsSimpleDuration) { 15 * 60 * 1000))))))); } +TEST_F(DurationAnnotatorForAnnotationAndClassificationTest, + FindsAllDisabledModeReturnsNoResults) { + const UnicodeText text = UTF8ToUnicodeText("Wake me up in 15 minutes ok?"); + std::vector<Token> tokens = Tokenize(text); + std::vector<AnnotatedSpan> result; + EXPECT_TRUE(duration_annotator_.FindAll( + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_SELECTION, &result)); + + EXPECT_THAT(result, IsEmpty()); +} + TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpression) { const UnicodeText text = UTF8ToUnicodeText("Set a timer for 3 and half minutes ok?"); std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -190,7 +246,8 @@ TEST_F(DurationAnnotatorTest, FindsComposedDuration) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_SELECTION, &result)); EXPECT_THAT( result, @@ -209,7 +266,8 @@ TEST_F(DurationAnnotatorTest, AllUnitsAreCovered) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -228,7 +286,8 @@ TEST_F(DurationAnnotatorTest, FindsHalfAnHour) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -247,7 +306,8 @@ TEST_F(DurationAnnotatorTest, FindsWhenHalfIsAfterGranularitySpecification) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_SELECTION, &result)); EXPECT_THAT( result, @@ -266,7 +326,8 @@ TEST_F(DurationAnnotatorTest, FindsAnHourAndAHalf) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -286,7 +347,8 @@ TEST_F(DurationAnnotatorTest, std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -305,7 +367,8 @@ TEST_F(DurationAnnotatorTest, DoesNotGreedilyTakeFillerWords) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -323,7 +386,8 @@ TEST_F(DurationAnnotatorTest, DoesNotCrashWhenJustHalfIsSaid) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); ASSERT_EQ(result.size(), 0); } @@ -334,7 +398,8 @@ TEST_F(DurationAnnotatorTest, StripsPunctuationFromTokens) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -352,7 +417,8 @@ TEST_F(DurationAnnotatorTest, FindsCorrectlyWithCombinedQuantityUnitToken) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -383,7 +449,8 @@ TEST_F(DurationAnnotatorTest, FindsSimpleDurationIgnoringCase) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -402,7 +469,8 @@ TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpressionIgnoringCase) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -422,7 +490,8 @@ TEST_F(DurationAnnotatorTest, std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -440,7 +509,8 @@ TEST_F(DurationAnnotatorTest, FindsDurationWithDanglingQuantity) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -458,7 +528,8 @@ TEST_F(DurationAnnotatorTest, FindsDurationWithDanglingQuantityNotSupported) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -475,7 +546,8 @@ TEST_F(DurationAnnotatorTest, FindsDurationWithDecimalQuantity) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -540,7 +612,8 @@ TEST_F(JapaneseDurationAnnotatorTest, FindsDuration) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -558,7 +631,8 @@ TEST_F(JapaneseDurationAnnotatorTest, FindsDurationWithHalfExpression) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -576,7 +650,8 @@ TEST_F(JapaneseDurationAnnotatorTest, IgnoresDurationWithoutQuantity) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, IsEmpty()); } @@ -586,7 +661,8 @@ TEST_F(JapaneseDurationAnnotatorTest, FindsDurationWithDanglingQuantity) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_SELECTION, &result)); EXPECT_THAT( result, diff --git a/native/annotator/installed_app/installed-app-engine-dummy.h b/native/annotator/installed_app/installed-app-engine-dummy.h index 2f2b62f..80f5a5c 100644 --- a/native/annotator/installed_app/installed-app-engine-dummy.h +++ b/native/annotator/installed_app/installed-app-engine-dummy.h @@ -32,7 +32,7 @@ namespace libtextclassifier3 { class InstalledAppEngine { public: explicit InstalledAppEngine(const FeatureProcessor* feature_processor, - const UniLib* unilib) {} + const UniLib* unilib, ModeFlag enabled_modes) {} bool Initialize(const std::string& serialized_config) { TC3_LOG(ERROR) << "No installed app engine to initialize."; @@ -45,7 +45,7 @@ class InstalledAppEngine { } bool Chunk(const UnicodeText& context_unicode, - const std::vector<Token>& tokens, + const std::vector<Token>& tokens, ModeFlag mode, std::vector<AnnotatedSpan>* result) const { return true; } diff --git a/native/annotator/knowledge/knowledge-engine-dummy.h b/native/annotator/knowledge/knowledge-engine-dummy.h index 34fa490..949018c 100644 --- a/native/annotator/knowledge/knowledge-engine-dummy.h +++ b/native/annotator/knowledge/knowledge-engine-dummy.h @@ -37,6 +37,8 @@ class KnowledgeEngine { void SetPriorityScore(float priority_score) {} + void SetEnabledModes(ModeFlag enabled_modes) {} + Status ClassifyText(const std::string& text, CodepointSpan selection_indices, AnnotationUsecase annotation_usecase, const Optional<LocationContext>& location_context, @@ -48,7 +50,7 @@ class KnowledgeEngine { Status Chunk(const std::string& text, AnnotationUsecase annotation_usecase, const Optional<LocationContext>& location_context, const Permissions& permissions, const AnnotateMode annotate_mode, - Annotations* result) const { + ModeFlag mode, Annotations* result) const { return Status::OK; } @@ -58,7 +60,7 @@ class KnowledgeEngine { AnnotationUsecase annotation_usecase, const Optional<LocationContext>& location_context, const Permissions& permissions, const AnnotateMode annotate_mode, - Annotations* results) const { + ModeFlag mode, Annotations* results) const { return Status::OK; } diff --git a/native/annotator/model.fbs b/native/annotator/model.fbs index 57187f5..eeb4101 100644 --- a/native/annotator/model.fbs +++ b/native/annotator/model.fbs @@ -415,6 +415,7 @@ table GrammarModel { // The grammar rules. rules:grammar.RulesSet; + // Deprecated. Used only for the old implementation of the grammar model. rule_classification_result:[GrammarModel_.RuleClassificationResult]; // Number of tokens in the context to use for classification and text @@ -432,6 +433,10 @@ table GrammarModel { // The priority score used for conflict resolution with the other models. priority_score:float = 1; + + // Global enabled modes. Use this instead of + // `rule_classification_result.enabled_modes`. + enabled_modes:ModeFlag = ALL; } namespace libtextclassifier3.MoneyParsingOptions_; @@ -486,6 +491,15 @@ table ModelTriggeringOptions { // map. Key: collection type e.g. "address", "phone"..., Value: float number. // NOTE: The entries here need to be sorted since we use LookupByKey. collection_to_priority:[ModelTriggeringOptions_.CollectionToPriorityEntry]; + + // Enabled modes for the knowledge engine model. + knowledge_enabled_modes:ModeFlag = ALL; + + // Enabled modes for the experimental model. + experimental_enabled_modes:ModeFlag = ALL; + + // Enabled modes for the installed app model. + installed_app_enabled_modes:ModeFlag = ALL; } // Options controlling the output of the classifier. @@ -894,6 +908,9 @@ table ContactAnnotatorOptions { // For each language there is a customized list of supported declensions. language:string (shared); + + // Enabled modes. + enabled_modes:ModeFlag = ALL; } namespace libtextclassifier3.TranslateAnnotatorOptions_; @@ -927,6 +944,9 @@ table TranslateAnnotatorOptions { algorithm:TranslateAnnotatorOptions_.Algorithm; backoff_options:TranslateAnnotatorOptions_.BackoffOptions; + + // Enabled modes. + enabled_modes:ModeFlag = CLASSIFICATION; } namespace libtextclassifier3.PodNerModel_; @@ -1012,6 +1032,9 @@ table PodNerModel { min_number_of_tokens:int = 1; min_number_of_wordpieces:int = 1; + + // Enabled modes. + enabled_modes:ModeFlag = ALL; } namespace libtextclassifier3; @@ -1043,6 +1066,9 @@ table VocabModel { // Priority score used for conflict resolution with the other models. priority_score:float = 0; + + // Enabled modes. + enabled_modes:ModeFlag = ANNOTATION_AND_CLASSIFICATION; } root_type libtextclassifier3.Model; diff --git a/native/annotator/number/number.cc b/native/annotator/number/number.cc index 3be6ad8..14fc24e 100644 --- a/native/annotator/number/number.cc +++ b/native/annotator/number/number.cc @@ -21,6 +21,7 @@ #include <string> #include "annotator/collections.h" +#include "annotator/model_generated.h" #include "annotator/types.h" #include "utils/base/logging.h" #include "utils/strings/split.h" @@ -38,7 +39,8 @@ bool NumberAnnotator::ClassifyText( context, selection_indices.first, selection_indices.second); std::vector<AnnotatedSpan> results; - if (!FindAll(substring_selected, annotation_usecase, &results)) { + if (!FindAll(substring_selected, annotation_usecase, ModeFlag_CLASSIFICATION, + &results)) { return false; } @@ -216,8 +218,9 @@ bool NumberAnnotator::TryParseNumber(const UnicodeText& token_text, bool NumberAnnotator::FindAll(const UnicodeText& context, AnnotationUsecase annotation_usecase, + ModeFlag mode, std::vector<AnnotatedSpan>* result) const { - if (!options_->enabled()) { + if (!options_->enabled() || !(options_->enabled_modes() & mode)) { return true; } diff --git a/native/annotator/number/number.h b/native/annotator/number/number.h index d83bea0..dcc2d48 100644 --- a/native/annotator/number/number.h +++ b/native/annotator/number/number.h @@ -58,7 +58,7 @@ class NumberAnnotator { // Finds all number instances in the input text. Returns true in any case. bool FindAll(const UnicodeText& context_unicode, - AnnotationUsecase annotation_usecase, + AnnotationUsecase annotation_usecase, ModeFlag mode, std::vector<AnnotatedSpan>* result) const; private: diff --git a/native/annotator/number/number_test-include.cc b/native/annotator/number/number_test-include.cc index f47933f..98140f4 100644 --- a/native/annotator/number/number_test-include.cc +++ b/native/annotator/number/number_test-include.cc @@ -16,6 +16,7 @@ #include "annotator/number/number_test-include.h" +#include <set> #include <string> #include <vector> @@ -34,37 +35,57 @@ namespace test_internal { using ::testing::AllOf; using ::testing::ElementsAre; using ::testing::Field; +using ::testing::IsEmpty; using ::testing::Matcher; using ::testing::UnorderedElementsAre; +namespace { +const flatbuffers::DetachedBuffer* CreateOptionsData(ModeFlag enabled_modes) { + NumberAnnotatorOptionsT options; + options.enabled = true; + options.priority_score = -10.0; + options.float_number_priority_score = 1.0; + options.enabled_annotation_usecases = + 1 << AnnotationUsecase_ANNOTATION_USECASE_RAW; + options.max_number_of_digits = 20; + options.enabled_modes = enabled_modes; + + options.percentage_priority_score = 1.0; + options.percentage_annotation_usecases = + (1 << AnnotationUsecase_ANNOTATION_USECASE_RAW) + + (1 << AnnotationUsecase_ANNOTATION_USECASE_SMART); + std::set<std::string> percent_suffixes( + {"パーセント", "percent", "pércént", "pc", "pct", "%", "٪", "﹪", "%"}); + for (const std::string& string_value : percent_suffixes) { + options.percentage_pieces_string.append(string_value); + options.percentage_pieces_string.push_back('\0'); + } + + flatbuffers::FlatBufferBuilder builder; + builder.Finish(NumberAnnotatorOptions::Pack(builder, &options)); + return new flatbuffers::DetachedBuffer(builder.Release()); +} +} // namespace + const NumberAnnotatorOptions* -NumberAnnotatorTest::TestingNumberAnnotatorOptions() { - static const flatbuffers::DetachedBuffer* options_data = []() { - NumberAnnotatorOptionsT options; - options.enabled = true; - options.priority_score = -10.0; - options.float_number_priority_score = 1.0; - options.enabled_annotation_usecases = - 1 << AnnotationUsecase_ANNOTATION_USECASE_RAW; - options.max_number_of_digits = 20; - - options.percentage_priority_score = 1.0; - options.percentage_annotation_usecases = - (1 << AnnotationUsecase_ANNOTATION_USECASE_RAW) + - (1 << AnnotationUsecase_ANNOTATION_USECASE_SMART); - std::set<std::string> percent_suffixes({"パーセント", "percent", "pércént", - "pc", "pct", "%", "٪", "﹪", "%"}); - for (const std::string& string_value : percent_suffixes) { - options.percentage_pieces_string.append(string_value); - options.percentage_pieces_string.push_back('\0'); - } - - flatbuffers::FlatBufferBuilder builder; - builder.Finish(NumberAnnotatorOptions::Pack(builder, &options)); - return new flatbuffers::DetachedBuffer(builder.Release()); - }(); - - return flatbuffers::GetRoot<NumberAnnotatorOptions>(options_data->data()); +NumberAnnotatorTest::TestingNumberAnnotatorOptions(ModeFlag enabled_modes) { + static const flatbuffers::DetachedBuffer* options_data_selection = + CreateOptionsData(ModeFlag_SELECTION); + static const flatbuffers::DetachedBuffer* options_data_no_selection = + CreateOptionsData(ModeFlag_ANNOTATION_AND_CLASSIFICATION); + static const flatbuffers::DetachedBuffer* options_data_all = + CreateOptionsData(ModeFlag_ALL); + + if (enabled_modes == ModeFlag_SELECTION) { + return flatbuffers::GetRoot<NumberAnnotatorOptions>( + options_data_selection->data()); + } else if (enabled_modes == ModeFlag_ANNOTATION_AND_CLASSIFICATION) { + return flatbuffers::GetRoot<NumberAnnotatorOptions>( + options_data_no_selection->data()); + } else { + return flatbuffers::GetRoot<NumberAnnotatorOptions>( + options_data_all->data()); + } } MATCHER_P(IsCorrectCollection, collection, "collection is " + collection) { @@ -124,6 +145,14 @@ TEST_F(NumberAnnotatorTest, ClassifiesAndParsesNumberCorrectly) { EXPECT_FLOAT_EQ(classification_result.numeric_double_value, 12345); } +TEST_F(NumberAnnotatorForSelectionTest, + ClassifyTextDisabledClassificationReturnsFalse) { + ClassificationResult classification_result; + EXPECT_FALSE(number_annotator_.ClassifyText( + UTF8ToUnicodeText("... 12345 ..."), {4, 9}, + AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result)); +} + TEST_F(NumberAnnotatorTest, ClassifiesAndParsesNumberAsFloatCorrectly) { ClassificationResult classification_result; EXPECT_TRUE(number_annotator_.ClassifyText( @@ -167,7 +196,7 @@ TEST_F(NumberAnnotatorTest, FindsAllIntegerAndFloatNumbersInText) { EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("how much is 2 plus 5 divided by 7% minus 3.14 " "what about 68.9# or 68.9#?"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( @@ -268,7 +297,8 @@ TEST_F(NumberAnnotatorTest, ClassifiesNonAsciiJaPercentageCorrectSuffix) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("明日の降水確率は10パーセント 音量を12にセット"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_CLASSIFICATION, + &result)); EXPECT_THAT(result, UnorderedElementsAre( IsAnnotatedSpan(CodepointSpan(8, 10), "number", @@ -285,7 +315,7 @@ TEST_F(NumberAnnotatorTest, FindsAllNumbersInText) { EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("... 12345 ... 9 is my number and 27% or 68# #38 #39 " "but not $99."), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -307,12 +337,23 @@ TEST_F(NumberAnnotatorTest, FindsAllNumbersInText) { /*int_value=*/39, /*double_value=*/39.0))); } +TEST_F(NumberAnnotatorForAnnotationAndClassificationTest, + FindsAllDisabledModeReturnsNoResults) { + std::vector<AnnotatedSpan> result; + EXPECT_TRUE(number_annotator_.FindAll( + UTF8ToUnicodeText("... 12345 ... 9 is my number and 27% or 68# #38 #39 " + "but not $99."), + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_SELECTION, &result)); + + EXPECT_THAT(result, IsEmpty()); +} + TEST_F(NumberAnnotatorTest, FindsNoNumberInText) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("... 12345a ... 12345..12345 and 123a45 are not valid. " "And -#5% is also bad."), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_SELECTION, &result)); ASSERT_EQ(result.size(), 0); } @@ -323,7 +364,8 @@ TEST_F(NumberAnnotatorTest, FindsNumberWithPunctuation) { EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText( "It's 12, 13, 14! Or 15??? For sure 16: 17; 18. and -19"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_CLASSIFICATION, + &result)); EXPECT_THAT(result, UnorderedElementsAre( @@ -348,7 +390,7 @@ TEST_F(NumberAnnotatorTest, FindsFloatNumberWithPunctuation) { EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("It's 12.123, 13.45, 14.54321! Or 15.1? Maybe 16.33: " "17.21; but for sure 18.90."), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( @@ -379,7 +421,7 @@ TEST_F(NumberAnnotatorTest, HandlesNumbersAtBeginning) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("-5"), AnnotationUsecase_ANNOTATION_USECASE_RAW, - &result)); + ModeFlag_SELECTION, &result)); EXPECT_THAT(result, UnorderedElementsAre(IsAnnotatedSpan( CodepointSpan(0, 2), "number", @@ -390,7 +432,7 @@ TEST_F(NumberAnnotatorTest, HandlesNegativeNumbers) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("Number -5 and -5% and not number --5%"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( @@ -408,7 +450,7 @@ TEST_F(NumberAnnotatorTest, FindGoodPercentageContexts) { EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText( "5 percent, 10 pct, 25 pc and 17%, -5 percent, 10% are percentages"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_SELECTION, &result)); EXPECT_THAT(result, UnorderedElementsAre( @@ -448,7 +490,7 @@ TEST_F(NumberAnnotatorTest, FindSinglePercentageInContext) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("5%"), AnnotationUsecase_ANNOTATION_USECASE_RAW, - &result)); + ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( IsAnnotatedSpan(CodepointSpan(0, 1), "number", @@ -463,7 +505,7 @@ TEST_F(NumberAnnotatorTest, IgnoreBadPercentageContexts) { // A valid number is followed by only one punctuation element. EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("10, pct, 25 prc, 5#: percentage are not percentages"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( @@ -478,7 +520,7 @@ TEST_F(NumberAnnotatorTest, IgnoreBadPercentagePunctuationContexts) { EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText( "#!24% or :?33 percent are not valid percentages, nor numbers."), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_TRUE(result.empty()); } @@ -488,7 +530,7 @@ TEST_F(NumberAnnotatorTest, FindPercentageInNonAsciiContext) { EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText( "At the café 10% or 25 percent of people are nice. Only 10%!"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( @@ -748,7 +790,7 @@ TEST_F(NumberAnnotatorTest, WhenSuffixWithoutNumberDoesNotParseIt) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("... % ..."), AnnotationUsecase_ANNOTATION_USECASE_RAW, - &result)); + ModeFlag_ANNOTATION, &result)); ASSERT_EQ(result.size(), 0); } @@ -757,7 +799,7 @@ TEST_F(NumberAnnotatorTest, WhenPrefixWithoutNumberDoesNotParseIt) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("... $ ..."), AnnotationUsecase_ANNOTATION_USECASE_RAW, - &result)); + ModeFlag_ANNOTATION, &result)); ASSERT_EQ(result.size(), 0); } @@ -766,7 +808,7 @@ TEST_F(NumberAnnotatorTest, WhenPrefixAndSuffixWithoutNumberDoesNotParseIt) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("... $% ..."), AnnotationUsecase_ANNOTATION_USECASE_RAW, - &result)); + ModeFlag_ANNOTATION, &result)); ASSERT_EQ(result.size(), 0); } @@ -786,7 +828,7 @@ TEST_F(NumberAnnotatorTest, ForNumberAnnotationsSetsScoreAndPriorityScore) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("Come at 9 or 10 ok?"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( @@ -811,7 +853,7 @@ TEST_F(NumberAnnotatorTest, std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("Results are between 12.5 and 13.5, right?"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( IsAnnotatedSpan(CodepointSpan(20, 24), "number", @@ -845,7 +887,7 @@ TEST_F(NumberAnnotatorTest, ForPercentageAnnotationsSetsScoreAndPriorityScore) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("Results are between 9% and 10 percent."), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( IsAnnotatedSpan(CodepointSpan(20, 21), "number", @@ -887,7 +929,8 @@ TEST_F(NumberAnnotatorTest, NumberDisabledPercentageEnabledForSmartUsecase) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("Accuracy for experiment 3 is 9%."), - AnnotationUsecase_ANNOTATION_USECASE_SMART, &result)); + AnnotationUsecase_ANNOTATION_USECASE_SMART, ModeFlag_ANNOTATION, + &result)); EXPECT_THAT(result, UnorderedElementsAre( IsAnnotatedSpan(CodepointSpan(29, 31), "percentage", /*int_value=*/9, /*double_value=*/9.0, @@ -898,7 +941,7 @@ TEST_F(NumberAnnotatorTest, MathOperatorsNotAnnotatedAsNumbersFindAll) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("how much is 2 + 2 or 5 - 96 * 89"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( @@ -928,7 +971,7 @@ TEST_F(NumberAnnotatorTest, SlashSeparatesTwoNumbersFindAll) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("what's 1 + 2/3 * 4/5 * 6 / 7"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( @@ -972,7 +1015,7 @@ TEST_F(NumberAnnotatorTest, SlashDoesNotSeparatesTwoNumbersFindAll) { // 2 in the "2/" context is a number because / is punctuation EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("what's 2a2/3 or 2/s4 or 2/ or /3 or //3 or 2//"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre(IsAnnotatedSpan( CodepointSpan(24, 25), "number", @@ -983,7 +1026,7 @@ TEST_F(NumberAnnotatorTest, BracketsContextAnnotatedFindAll) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("The interval is: (12, 13) or [-12, -4.5)"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( @@ -1002,7 +1045,7 @@ TEST_F(NumberAnnotatorTest, BracketsContextNotAnnotatedFindAll) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("The interval is: -(12, 138*)"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_TRUE(result.empty()); } @@ -1012,7 +1055,7 @@ TEST_F(NumberAnnotatorTest, FractionalNumberDotsFindAll) { // Dots source: https://unicode-search.net/unicode-namesearch.pl?term=period EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("3.1 3﹒2 3.3"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( IsAnnotatedSpan(CodepointSpan(0, 3), "number", @@ -1032,7 +1075,7 @@ TEST_F(NumberAnnotatorTest, NonAsciiDigitsFindAll) { // Digits source: https://unicode-search.net/unicode-namesearch.pl?term=digit EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("3 3﹒2 3.3%"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( IsAnnotatedSpan(CodepointSpan(0, 1), "number", @@ -1052,7 +1095,7 @@ TEST_F(NumberAnnotatorTest, AnnotatedZeroPrecededNumbersFindAll) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("Numbers: 0.9 or 09 or 09.9 or 032310"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( IsAnnotatedSpan(CodepointSpan(9, 12), "number", @@ -1072,7 +1115,7 @@ TEST_F(NumberAnnotatorTest, ZeroAfterDotFindAll) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("15.0 16.00"), AnnotationUsecase_ANNOTATION_USECASE_RAW, - &result)); + ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( @@ -1086,7 +1129,7 @@ TEST_F(NumberAnnotatorTest, NineDotNineFindAll) { std::vector<AnnotatedSpan> result; EXPECT_TRUE(number_annotator_.FindAll( UTF8ToUnicodeText("9.9 9.99 99.99 99.999 99.9999"), - AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, UnorderedElementsAre( diff --git a/native/annotator/number/number_test-include.h b/native/annotator/number/number_test-include.h index 9de7c86..14fc6f2 100644 --- a/native/annotator/number/number_test-include.h +++ b/native/annotator/number/number_test-include.h @@ -17,6 +17,7 @@ #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_NUMBER_NUMBER_TEST_INCLUDE_H_ #define LIBTEXTCLASSIFIER_ANNOTATOR_NUMBER_NUMBER_TEST_INCLUDE_H_ +#include "annotator/model_generated.h" #include "annotator/number/number.h" #include "utils/jvm-test-utils.h" #include "gtest/gtest.h" @@ -25,17 +26,32 @@ namespace libtextclassifier3 { namespace test_internal { class NumberAnnotatorTest : public ::testing::Test { + private: protected: - NumberAnnotatorTest() + explicit NumberAnnotatorTest(ModeFlag enabled_modes = ModeFlag_ALL) : unilib_(CreateUniLibForTesting()), - number_annotator_(TestingNumberAnnotatorOptions(), unilib_.get()) {} + number_annotator_(TestingNumberAnnotatorOptions(enabled_modes), + unilib_.get()) {} - const NumberAnnotatorOptions* TestingNumberAnnotatorOptions(); + const NumberAnnotatorOptions* TestingNumberAnnotatorOptions( + ModeFlag enabled_modes); std::unique_ptr<UniLib> unilib_; NumberAnnotator number_annotator_; }; +class NumberAnnotatorForAnnotationAndClassificationTest + : public NumberAnnotatorTest { + protected: + NumberAnnotatorForAnnotationAndClassificationTest() + : NumberAnnotatorTest(ModeFlag_ANNOTATION_AND_CLASSIFICATION) {} +}; + +class NumberAnnotatorForSelectionTest : public NumberAnnotatorTest { + protected: + NumberAnnotatorForSelectionTest() : NumberAnnotatorTest(ModeFlag_SELECTION) {} +}; + } // namespace test_internal } // namespace libtextclassifier3 diff --git a/native/annotator/person_name/person-name-engine-dummy.h b/native/annotator/person_name/person-name-engine-dummy.h index 9c83241..44d2821 100644 --- a/native/annotator/person_name/person-name-engine-dummy.h +++ b/native/annotator/person_name/person-name-engine-dummy.h @@ -46,7 +46,7 @@ class PersonNameEngine { } bool Chunk(const UnicodeText& context_unicode, - const std::vector<Token>& tokens, + const std::vector<Token>& tokens, ModeFlag mode, std::vector<AnnotatedSpan>* result) const { return true; } diff --git a/native/annotator/person_name/person_name_model.fbs b/native/annotator/person_name/person_name_model.fbs index b15543f..6ef4a72 100644 --- a/native/annotator/person_name/person_name_model.fbs +++ b/native/annotator/person_name/person_name_model.fbs @@ -14,6 +14,8 @@ // limitations under the License. // +include "annotator/model.fbs"; + file_identifier "TC2 "; // Next ID: 2 @@ -26,7 +28,7 @@ table PersonName { person_name:string (shared); } -// Next ID: 6 +// Next ID: 7 namespace libtextclassifier3; table PersonNameModel { // Decides if the person name annotator is enabled. @@ -52,6 +54,9 @@ table PersonNameModel { // upper case character and have at least one lower case character. // required annotate_capitalized_names_only:bool; + + // Enabled modes. + enabled_modes:ModeFlag = ALL; } root_type libtextclassifier3.PersonNameModel; diff --git a/native/annotator/pod_ner/pod-ner-impl.cc b/native/annotator/pod_ner/pod-ner-impl.cc index 666b7c7..0cb86ee 100644 --- a/native/annotator/pod_ner/pod-ner-impl.cc +++ b/native/annotator/pod_ner/pod-ner-impl.cc @@ -398,6 +398,10 @@ bool PodNerAnnotator::AnnotateAroundSpanOfInterest( std::vector<AnnotatedSpan> *results) const { TC3_CHECK(results != nullptr); + if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) { + return true; + } + std::vector<int32_t> wordpiece_indices; std::vector<int32_t> token_starts; std::vector<Token> tokens; @@ -470,6 +474,11 @@ bool PodNerAnnotator::SuggestSelection(const UnicodeText &context, return false; } + if (!(model_->enabled_modes() & ModeFlag_SELECTION)) { + *result = {}; + return false; + } + for (const AnnotatedSpan &annotation : annotations) { TC3_VLOG(INFO) << "POD NER SuggestSelection: " << annotation; if (annotation.span.first <= click.first && @@ -491,6 +500,10 @@ bool PodNerAnnotator::ClassifyText(const UnicodeText &context, CodepointSpan click, ClassificationResult *result) const { TC3_VLOG(INFO) << "POD NER ClassifyText " << click; + if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) { + return false; + } + std::vector<AnnotatedSpan> annotations; if (!AnnotateAroundSpanOfInterest(context, click, &annotations)) { return false; diff --git a/native/annotator/pod_ner/pod-ner-impl_test.cc b/native/annotator/pod_ner/pod-ner-impl_test.cc index c7d0bee..5accebd 100644 --- a/native/annotator/pod_ner/pod-ner-impl_test.cc +++ b/native/annotator/pod_ner/pod-ner-impl_test.cc @@ -53,7 +53,7 @@ constexpr float kDefaultPriorityScore = 0.5; class PodNerTest : public testing::Test { protected: - PodNerTest() { + explicit PodNerTest(ModeFlag enabled_modes = ModeFlag_ALL) { PodNerModelT model; model.min_number_of_tokens = kMinNumberOfTokens; @@ -68,6 +68,7 @@ class PodNerTest : public testing::Test { GetTestFileContent("annotator/pod_ner/test_data/vocab.txt"); model.word_piece_vocab = std::vector<uint8_t>( word_piece_vocab_buffer.begin(), word_piece_vocab_buffer.end()); + model.enabled_modes = enabled_modes; flatbuffers::FlatBufferBuilder builder; builder.Finish(PodNerModel::Pack(builder, &model)); @@ -101,6 +102,17 @@ class PodNerTest : public testing::Test { std::unique_ptr<UniLib> unilib_; }; +class PodNerForAnnotationAndClassificationTest : public PodNerTest { + protected: + PodNerForAnnotationAndClassificationTest() + : PodNerTest(ModeFlag_ANNOTATION_AND_CLASSIFICATION) {} +}; + +class PodNerForSelectionTest : public PodNerTest { + protected: + PodNerForSelectionTest() : PodNerTest(ModeFlag_SELECTION) {} +}; + TEST_F(PodNerTest, AnnotateSmokeTest) { std::unique_ptr<PodNerAnnotator> annotator = PodNerAnnotator::Create(model_, *unilib_); @@ -209,6 +221,18 @@ TEST_F(PodNerTest, AnnotateDefaultCollections) { } } +TEST_F(PodNerForSelectionTest, AnnotateWithDisabledAnnotationReturnsNoResults) { + std::unique_ptr<PodNerAnnotator> annotator = + PodNerAnnotator::Create(model_, *unilib_); + ASSERT_TRUE(annotator != nullptr); + + std::string multi_word_location = "I live in New York"; + std::vector<AnnotatedSpan> annotations; + ASSERT_TRUE(annotator->Annotate(UTF8ToUnicodeText(multi_word_location), + &annotations)); + EXPECT_THAT(annotations, IsEmpty()); +} + TEST_F(PodNerTest, AnnotateConfigurableCollections) { std::unique_ptr<PodNerModelT> unpacked_model(model_->UnPack()); ASSERT_TRUE(unpacked_model != nullptr); @@ -525,6 +549,18 @@ TEST_F(PodNerTest, SuggestSelectionTest) { EXPECT_EQ(suggested_span.span, CodepointSpan(kInvalidIndex, kInvalidIndex)); } +TEST_F(PodNerForAnnotationAndClassificationTest, + SuggestSelectionWithDisabledSelectionReturnsNoResults) { + std::unique_ptr<PodNerAnnotator> annotator = + PodNerAnnotator::Create(model_, *unilib_); + ASSERT_TRUE(annotator != nullptr); + + AnnotatedSpan suggested_span; + EXPECT_FALSE(annotator->SuggestSelection( + UTF8ToUnicodeText("Google New York, in New York"), {7, 10}, + &suggested_span)); +} + TEST_F(PodNerTest, ClassifyTextTest) { std::unique_ptr<PodNerAnnotator> annotator = PodNerAnnotator::Create(model_, *unilib_); @@ -536,6 +572,17 @@ TEST_F(PodNerTest, ClassifyTextTest) { EXPECT_EQ(result.collection, "location"); } +TEST_F(PodNerForSelectionTest, + ClassifyTextWithDisabledClassificationReturnsFalse) { + std::unique_ptr<PodNerAnnotator> annotator = + PodNerAnnotator::Create(model_, *unilib_); + ASSERT_TRUE(annotator != nullptr); + + ClassificationResult result; + ASSERT_FALSE(annotator->ClassifyText(UTF8ToUnicodeText("We met in New York"), + {10, 18}, &result)); +} + TEST_F(PodNerTest, ThreadSafety) { std::unique_ptr<PodNerAnnotator> annotator = PodNerAnnotator::Create(model_, *unilib_); diff --git a/native/annotator/translate/translate.cc b/native/annotator/translate/translate.cc index 2c5a43c..e38109c 100644 --- a/native/annotator/translate/translate.cc +++ b/native/annotator/translate/translate.cc @@ -21,6 +21,7 @@ #include "annotator/collections.h" #include "annotator/entity-data_generated.h" +#include "annotator/model_generated.h" #include "annotator/types.h" #include "lang_id/lang-id-wrapper.h" #include "utils/base/logging.h" @@ -34,6 +35,10 @@ bool TranslateAnnotator::ClassifyText( const UnicodeText& context, CodepointSpan selection_indices, const std::string& user_familiar_language_tags, ClassificationResult* classification_result) const { + if (!(options_->enabled_modes() & ModeFlag_CLASSIFICATION)) { + return false; + } + std::vector<TranslateAnnotator::LanguageConfidence> confidences; if (options_->algorithm() == TranslateAnnotatorOptions_::Algorithm::Algorithm_BACKOFF) { diff --git a/native/annotator/translate/translate_test.cc b/native/annotator/translate/translate_test.cc index 5c4a63f..90227ec 100644 --- a/native/annotator/translate/translate_test.cc +++ b/native/annotator/translate/translate_test.cc @@ -31,20 +31,33 @@ namespace { using testing::AllOf; using testing::Field; -const TranslateAnnotatorOptions* TestingTranslateAnnotatorOptions() { - static const flatbuffers::DetachedBuffer* options_data = []() { - TranslateAnnotatorOptionsT options; - options.enabled = true; - options.algorithm = TranslateAnnotatorOptions_::Algorithm_BACKOFF; - options.backoff_options.reset( - new TranslateAnnotatorOptions_::BackoffOptionsT()); - - flatbuffers::FlatBufferBuilder builder; - builder.Finish(TranslateAnnotatorOptions::Pack(builder, &options)); - return new flatbuffers::DetachedBuffer(builder.Release()); - }(); - - return flatbuffers::GetRoot<TranslateAnnotatorOptions>(options_data->data()); +const flatbuffers::DetachedBuffer* CreateOptionsData(ModeFlag enabled_modes) { + TranslateAnnotatorOptionsT options; + options.enabled = true; + options.algorithm = TranslateAnnotatorOptions_::Algorithm_BACKOFF; + options.backoff_options.reset( + new TranslateAnnotatorOptions_::BackoffOptionsT()); + options.enabled_modes = enabled_modes; + + flatbuffers::FlatBufferBuilder builder; + builder.Finish(TranslateAnnotatorOptions::Pack(builder, &options)); + return new flatbuffers::DetachedBuffer(builder.Release()); +} + +const TranslateAnnotatorOptions* TestingTranslateAnnotatorOptions( + ModeFlag enabled_modes) { + static const flatbuffers::DetachedBuffer* options_data_classification = + CreateOptionsData(ModeFlag_CLASSIFICATION); + static const flatbuffers::DetachedBuffer* options_data_none = + CreateOptionsData(ModeFlag_NONE); + + if (enabled_modes == ModeFlag_CLASSIFICATION) { + return flatbuffers::GetRoot<TranslateAnnotatorOptions>( + options_data_classification->data()); + } else { + return flatbuffers::GetRoot<TranslateAnnotatorOptions>( + options_data_none->data()); + } } class TestingTranslateAnnotator : public TranslateAnnotator { @@ -60,11 +73,12 @@ std::string GetModelPath() { return GetTestDataPath("annotator/test_data/"); } class TranslateAnnotatorTest : public ::testing::Test { protected: - TranslateAnnotatorTest() + explicit TranslateAnnotatorTest( + ModeFlag enabled_modes = ModeFlag_CLASSIFICATION) : INIT_UNILIB_FOR_TESTING(unilib_), langid_model_(libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile( GetModelPath() + "lang_id.smfb")), - translate_annotator_(TestingTranslateAnnotatorOptions(), + translate_annotator_(TestingTranslateAnnotatorOptions(enabled_modes), langid_model_.get(), &unilib_) {} UniLib unilib_; @@ -72,6 +86,11 @@ class TranslateAnnotatorTest : public ::testing::Test { TestingTranslateAnnotator translate_annotator_; }; +class TranslateAnnotatorForNoneTest : public TranslateAnnotatorTest { + protected: + TranslateAnnotatorForNoneTest() : TranslateAnnotatorTest(ModeFlag_NONE) {} +}; + TEST_F(TranslateAnnotatorTest, WhenSpeaksEnglishGetsTranslateActionForCzech) { ClassificationResult classification; EXPECT_TRUE(translate_annotator_.ClassifyText( @@ -110,6 +129,13 @@ TEST_F(TranslateAnnotatorTest, EntityDataIsSet) { predictions->Get(1)->confidence_score()); } +TEST_F(TranslateAnnotatorForNoneTest, + ClassifyTextDisabledClassificationReturnsFalse) { + ClassificationResult classification; + EXPECT_FALSE(translate_annotator_.ClassifyText( + UTF8ToUnicodeText("学校"), {0, 2}, "en", &classification)); +} + TEST_F(TranslateAnnotatorTest, WhenSpeaksEnglishDoesntGetTranslateActionForEnglish) { ClassificationResult classification; diff --git a/native/annotator/types.h b/native/annotator/types.h index ada301c..8485d44 100644 --- a/native/annotator/types.h +++ b/native/annotator/types.h @@ -65,6 +65,7 @@ struct CodepointSpan { CodepointSpan(CodepointIndex start, CodepointIndex end) : first(start), second(end) {} + CodepointSpan(const CodepointSpan& other) = default; CodepointSpan& operator=(const CodepointSpan& other) = default; bool operator==(const CodepointSpan& other) const { @@ -439,6 +440,8 @@ struct ClassificationResult { contact_nickname, contact_email_address, contact_phone_number, contact_account_type, contact_account_name, contact_id, contact_alternate_name; + int64 contact_recognition_source; + float contact_neural_match_score; std::string app_name, app_package_name; int64 numeric_value; double numeric_double_value; @@ -577,12 +580,18 @@ struct ClassificationOptions : public BaseOptions, public DatetimeOptions { std::string user_familiar_language_tags; // If true, trigger dictionary on words that are of beginner level. bool trigger_dictionary_on_beginner_words = false; + // If true, generate *Add* contact intent for email/phone entity. + bool enable_add_contact_intent; + // If true, generate *Search* intent for named entities. + bool enable_search_intent; bool operator==(const ClassificationOptions& other) const { return this->user_familiar_language_tags == other.user_familiar_language_tags && this->trigger_dictionary_on_beginner_words == other.trigger_dictionary_on_beginner_words && + this->enable_add_contact_intent == other.enable_add_contact_intent && + this->enable_search_intent == other.enable_search_intent && BaseOptions::operator==(other) && DatetimeOptions::operator==(other); } }; diff --git a/native/annotator/vocab/vocab-annotator-impl.cc b/native/annotator/vocab/vocab-annotator-impl.cc index 4b5cc73..b464f54 100644 --- a/native/annotator/vocab/vocab-annotator-impl.cc +++ b/native/annotator/vocab/vocab-annotator-impl.cc @@ -61,6 +61,9 @@ bool VocabAnnotator::Annotate( const UnicodeText& context, const std::vector<Locale> detected_text_language_tags, bool trigger_on_beginner_words, std::vector<AnnotatedSpan>* results) const { + if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) { + return true; + } std::vector<Token> tokens = feature_processor_.Tokenize(context); for (const Token& token : tokens) { ClassificationResult classification_result; @@ -90,6 +93,9 @@ bool VocabAnnotator::ClassifyTextInternal( const std::vector<Locale> detected_text_language_tags, bool trigger_on_beginner_words, ClassificationResult* classification_result, CodepointSpan* classified_span) const { + if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) { + return false; + } if (vocab_level_table_ == nullptr) { return false; } diff --git a/native/lang_id/common/embedding-feature-extractor.h b/native/lang_id/common/embedding-feature-extractor.h index ba4f858..8363321 100644 --- a/native/lang_id/common/embedding-feature-extractor.h +++ b/native/lang_id/common/embedding-feature-extractor.h @@ -25,6 +25,8 @@ #include "lang_id/common/fel/task-context.h" #include "lang_id/common/fel/workspace.h" #include "lang_id/common/lite_base/attributes.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" namespace libtextclassifier3 { namespace mobile { @@ -46,7 +48,7 @@ class GenericEmbeddingFeatureExtractor { // // |arg_prefix| is a string prefix for the relevant TaskContext parameters, to // avoid name clashes. See GetParamName(). - explicit GenericEmbeddingFeatureExtractor(const std::string &arg_prefix) + explicit GenericEmbeddingFeatureExtractor(absl::string_view arg_prefix) : arg_prefix_(arg_prefix) {} virtual ~GenericEmbeddingFeatureExtractor() {} @@ -70,11 +72,8 @@ class GenericEmbeddingFeatureExtractor { } // Get parameter name by concatenating the prefix and the original name. - std::string GetParamName(const std::string ¶m_name) const { - std::string full_name = arg_prefix_; - full_name.push_back('_'); - full_name.append(param_name); - return full_name; + std::string GetParamName(absl::string_view param_name) const { + return absl::StrCat(arg_prefix_, "_", param_name); } private: @@ -108,7 +107,7 @@ class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor { // // |arg_prefix| is a string prefix for the relevant TaskContext parameters, to // avoid name clashes. See GetParamName(). - explicit EmbeddingFeatureExtractor(const std::string &arg_prefix) + explicit EmbeddingFeatureExtractor(absl::string_view arg_prefix) : GenericEmbeddingFeatureExtractor(arg_prefix) {} // Sets up all predicate maps, feature extractors, and flags. @@ -117,7 +116,7 @@ class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor { return false; } feature_extractors_.resize(embedding_fml().size()); - for (int i = 0; i < embedding_fml().size(); ++i) { + for (size_t i = 0; i < embedding_fml().size(); ++i) { feature_extractors_[i].reset(new EXTRACTOR()); if (!feature_extractors_[i]->Parse(embedding_fml()[i])) return false; if (!feature_extractors_[i]->Setup(context)) return false; @@ -158,7 +157,7 @@ class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor { std::vector<FeatureVector> *features) const { // DCHECK(features != nullptr); // DCHECK_EQ(features->size(), feature_extractors_.size()); - for (int i = 0; i < feature_extractors_.size(); ++i) { + for (size_t i = 0; i < feature_extractors_.size(); ++i) { (*features)[i].clear(); feature_extractors_[i]->ExtractFeatures(workspaces, obj, args..., &(*features)[i]); diff --git a/native/lang_id/common/embedding-feature-interface.h b/native/lang_id/common/embedding-feature-interface.h index 75d0c98..26d574b 100644 --- a/native/lang_id/common/embedding-feature-interface.h +++ b/native/lang_id/common/embedding-feature-interface.h @@ -25,6 +25,7 @@ #include "lang_id/common/fel/task-context.h" #include "lang_id/common/fel/workspace.h" #include "lang_id/common/lite_base/attributes.h" +#include "absl/strings/string_view.h" namespace libtextclassifier3 { namespace mobile { @@ -36,7 +37,7 @@ class EmbeddingFeatureInterface { // // |arg_prefix| is a string prefix for the TaskContext parameters, passed to // |the underlying EmbeddingFeatureExtractor. - explicit EmbeddingFeatureInterface(const std::string &arg_prefix) + explicit EmbeddingFeatureInterface(absl::string_view arg_prefix) : feature_extractor_(arg_prefix) {} // Sets up feature extractors and flags for processing (inference). diff --git a/native/lang_id/common/embedding-network.cc b/native/lang_id/common/embedding-network.cc index 49c9ca0..0fe35d4 100644 --- a/native/lang_id/common/embedding-network.cc +++ b/native/lang_id/common/embedding-network.cc @@ -153,7 +153,7 @@ void EmbeddingNetwork::ConcatEmbeddings( concat->resize(concat_layer_size_); // "es_index" stands for "embedding space index". - for (int es_index = 0; es_index < feature_vectors.size(); ++es_index) { + for (size_t es_index = 0; es_index < feature_vectors.size(); ++es_index) { const int concat_offset = concat_offset_[es_index]; const EmbeddingNetworkParams::Matrix &embedding_matrix = @@ -167,7 +167,8 @@ void EmbeddingNetwork::ConcatEmbeddings( for (int fi = 0; fi < num_features; ++fi) { const FeatureType *feature_type = feature_vector.type(fi); int feature_offset = concat_offset + feature_type->base() * embedding_dim; - SAFTM_CHECK_LE(feature_offset + embedding_dim, concat->size()); + SAFTM_CHECK_LE(feature_offset + embedding_dim, + static_cast<int>(concat->size())); // Weighted embeddings will be added starting from this address. float *concat_ptr = concat->data() + feature_offset; @@ -257,7 +258,7 @@ void EmbeddingNetwork::ComputeFinalScores( ConcatEmbeddings(features, &input); if (!extra_inputs.empty()) { input.reserve(input.size() + extra_inputs.size()); - for (int i = 0; i < extra_inputs.size(); i++) { + for (size_t i = 0; i < extra_inputs.size(); i++) { input.push_back(extra_inputs[i]); } } @@ -281,8 +282,8 @@ void EmbeddingNetwork::ComputeFinalScores( v_out = &(storage[i % 2]); } const bool apply_relu = i > 0; - SparseReluProductPlusBias( - apply_relu, layer_weights_[i], layer_bias_[i], *v_in, v_out); + SparseReluProductPlusBias(apply_relu, layer_weights_[i], layer_bias_[i], + *v_in, v_out); v_in = v_out; } } diff --git a/native/lang_id/common/fel/feature-descriptors.h b/native/lang_id/common/fel/feature-descriptors.h index 3bdc2fa..f9536d9 100644 --- a/native/lang_id/common/fel/feature-descriptors.h +++ b/native/lang_id/common/fel/feature-descriptors.h @@ -24,6 +24,7 @@ #include "lang_id/common/lite_base/integral-types.h" #include "lang_id/common/lite_base/logging.h" #include "lang_id/common/lite_base/macros.h" +#include "absl/strings/string_view.h" namespace libtextclassifier3 { namespace mobile { @@ -33,10 +34,10 @@ class Parameter { public: Parameter() {} - void set_name(const std::string &name) { name_ = name; } + void set_name(absl::string_view name) { name_ = std::string(name); } const std::string &name() const { return name_; } - void set_value(const std::string &value) { value_ = value; } + void set_value(absl::string_view value) { value_ = std::string(value); } const std::string &value() const { return value_; } private: @@ -52,13 +53,13 @@ class FeatureFunctionDescriptor { // Accessors for the feature function type. The function type is the string // that the feature extractor code is registered under. - void set_type(const std::string &type) { type_ = type; } + void set_type(absl::string_view type) { type_ = std::string(type); } const std::string &type() const { return type_; } // Accessors for the feature function name. The function name (if available) // is used for some log messages. Otherwise, a more precise, but also more // verbose name based on the feature specification is used. - void set_name(const std::string &name) { name_ = name; } + void set_name(absl::string_view name) { name_ = std::string(name); } const std::string &name() const { return name_; } // Accessors for the default (name-less) parameter. diff --git a/native/lang_id/common/fel/feature-extractor.h b/native/lang_id/common/fel/feature-extractor.h index c09e1eb..805272b 100644 --- a/native/lang_id/common/fel/feature-extractor.h +++ b/native/lang_id/common/fel/feature-extractor.h @@ -52,6 +52,7 @@ #include "lang_id/common/lite_base/macros.h" #include "lang_id/common/registry.h" #include "lang_id/common/stl-util.h" +#include "absl/strings/string_view.h" namespace libtextclassifier3 { namespace mobile { @@ -261,7 +262,7 @@ class GenericFeatureFunction { // Returns/sets/clears function name prefix. const std::string &prefix() const { return prefix_; } - void set_prefix(const std::string &prefix) { prefix_ = prefix; } + void set_prefix(absl::string_view prefix) { prefix_ = std::string(prefix); } protected: // Returns the feature type for single-type feature functions. @@ -341,7 +342,7 @@ class FeatureFunction // the relevant cc_library was not linked-in). static Self *Instantiate(const GenericFeatureExtractor *extractor, const FeatureFunctionDescriptor *fd, - const std::string &prefix) { + absl::string_view prefix) { Self *f = Self::Create(fd->type()); if (f != nullptr) { f->set_extractor(extractor); @@ -440,7 +441,7 @@ class NestedFeatureFunction : public FeatureFunction<OBJ, ARGS...> { SAFTM_MUST_USE_RESULT static bool CreateNested( const GenericFeatureExtractor *extractor, const FeatureFunctionDescriptor *fd, std::vector<NES *> *functions, - const std::string &prefix) { + absl::string_view prefix) { for (int i = 0; i < fd->feature_size(); ++i) { const FeatureFunctionDescriptor &sub = fd->feature(i); NES *f = NES::Instantiate(extractor, &sub, prefix); @@ -614,7 +615,7 @@ class FeatureExtractor : public GenericFeatureExtractor { result->reserve(this->feature_types()); // Extract features. - for (int i = 0; i < functions_.size(); ++i) { + for (size_t i = 0; i < functions_.size(); ++i) { functions_[i]->Evaluate(workspaces, object, args..., result); } } @@ -636,7 +637,7 @@ class FeatureExtractor : public GenericFeatureExtractor { // Collect all feature types used in the feature extractor. void GetFeatureTypes(std::vector<FeatureType *> *types) const override { - for (int i = 0; i < functions_.size(); ++i) { + for (size_t i = 0; i < functions_.size(); ++i) { functions_[i]->GetFeatureTypes(types); } } diff --git a/native/lang_id/common/fel/feature-types.h b/native/lang_id/common/fel/feature-types.h index ae422af..308ee90 100644 --- a/native/lang_id/common/fel/feature-types.h +++ b/native/lang_id/common/fel/feature-types.h @@ -27,6 +27,8 @@ #include "lang_id/common/lite_base/integral-types.h" #include "lang_id/common/lite_base/logging.h" #include "lang_id/common/lite_strings/str-cat.h" +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" namespace libtextclassifier3 { namespace mobile { @@ -44,10 +46,10 @@ typedef Predicate FeatureValue; class FeatureType { public: // Initializes a feature type. - explicit FeatureType(const std::string &name) + explicit FeatureType(absl::string_view name) : name_(name), base_(0), - is_continuous_(name.find("continuous") != std::string::npos) {} + is_continuous_(absl::StrContains(name, "continuous")) {} virtual ~FeatureType() {} @@ -91,7 +93,7 @@ class FeatureType { // }; class EnumFeatureType : public FeatureType { public: - EnumFeatureType(const std::string &name, + EnumFeatureType(absl::string_view name, const std::map<FeatureValue, std::string> &value_names) : FeatureType(name), value_names_(value_names) { for (const auto &pair : value_names) { @@ -127,8 +129,8 @@ class EnumFeatureType : public FeatureType { // Feature type for binary features. class BinaryFeatureType : public FeatureType { public: - BinaryFeatureType(const std::string &name, const std::string &off, - const std::string &on) + BinaryFeatureType(absl::string_view name, absl::string_view off, + absl::string_view on) : FeatureType(name), off_(off), on_(on) {} // Returns the feature name for a given feature value. @@ -151,7 +153,7 @@ class BinaryFeatureType : public FeatureType { class NumericFeatureType : public FeatureType { public: // Initializes numeric feature. - NumericFeatureType(const std::string &name, FeatureValue size) + NumericFeatureType(absl::string_view name, FeatureValue size) : FeatureType(name), size_(size) {} // Returns numeric feature value. @@ -171,7 +173,7 @@ class NumericFeatureType : public FeatureType { // Feature type for byte features, including an "outside" value. class ByteFeatureType : public NumericFeatureType { public: - explicit ByteFeatureType(const std::string &name) + explicit ByteFeatureType(absl::string_view name) : NumericFeatureType(name, 257) {} std::string GetFeatureValueName(FeatureValue value) const override { diff --git a/native/lang_id/common/fel/fel-parser.cc b/native/lang_id/common/fel/fel-parser.cc index 2682941..c393ae8 100644 --- a/native/lang_id/common/fel/fel-parser.cc +++ b/native/lang_id/common/fel/fel-parser.cc @@ -22,6 +22,7 @@ #include "lang_id/common/lite_base/logging.h" #include "lang_id/common/lite_strings/numbers.h" +#include "absl/strings/string_view.h" namespace libtextclassifier3 { namespace mobile { @@ -47,9 +48,9 @@ inline bool IsValidCharInsideNumber(char c) { } } // namespace -bool FELParser::Initialize(const std::string &source) { +bool FELParser::Initialize(absl::string_view source) { // Initialize parser state. - source_ = source; + source_ = std::string(source); current_ = source_.begin(); item_start_ = line_start_ = current_; line_number_ = item_line_number_ = 1; diff --git a/native/lang_id/common/fel/fel-parser.h b/native/lang_id/common/fel/fel-parser.h index d2c454c..bb33eaf 100644 --- a/native/lang_id/common/fel/fel-parser.h +++ b/native/lang_id/common/fel/fel-parser.h @@ -58,7 +58,7 @@ class FELParser { private: // Initializes the parser with the source text. // Returns true on success, false on syntax error. - bool Initialize(const std::string &source); + bool Initialize(absl::string_view source); // Outputs an error message, with context info. void ReportError(const std::string &error_message); diff --git a/native/lang_id/common/fel/workspace.h b/native/lang_id/common/fel/workspace.h index 2ac5b26..71d1550 100644 --- a/native/lang_id/common/fel/workspace.h +++ b/native/lang_id/common/fel/workspace.h @@ -31,6 +31,7 @@ #include "lang_id/common/lite_base/logging.h" #include "lang_id/common/lite_base/macros.h" +#include "absl/strings/string_view.h" namespace libtextclassifier3 { namespace mobile { @@ -71,7 +72,7 @@ class WorkspaceRegistry { // Returns the index of a named workspace, adding it to the registry first // if necessary. template <class W> - int Request(const std::string &name) { + int Request(absl::string_view name) { const int id = TypeId<W>::type_id; max_workspace_id_ = std::max(id, max_workspace_id_); workspace_types_[id] = W::TypeName(); @@ -79,7 +80,7 @@ class WorkspaceRegistry { for (int i = 0; i < names.size(); ++i) { if (names[i] == name) return i; } - names.push_back(name); + names.push_back(std::string(name)); return names.size() - 1; } diff --git a/native/lang_id/common/flatbuffers/model-utils.cc b/native/lang_id/common/flatbuffers/model-utils.cc index 8efa386..592f616 100644 --- a/native/lang_id/common/flatbuffers/model-utils.cc +++ b/native/lang_id/common/flatbuffers/model-utils.cc @@ -22,6 +22,7 @@ #include "lang_id/common/lite_base/logging.h" #include "lang_id/common/math/checksum.h" +#include "absl/strings/string_view.h" namespace libtextclassifier3 { namespace saft_fbs { @@ -64,7 +65,7 @@ const Model *GetVerifiedModelFromBytes(const char *data, size_t num_bytes) { return model; } -const ModelInput *GetInputByName(const Model *model, const std::string &name) { +const ModelInput *GetInputByName(const Model *model, absl::string_view name) { if (model == nullptr) { SAFTM_LOG(ERROR) << "GetInputByName called with model == nullptr"; return nullptr; diff --git a/native/lang_id/common/flatbuffers/model-utils.h b/native/lang_id/common/flatbuffers/model-utils.h index cf33dd5..16494b1 100644 --- a/native/lang_id/common/flatbuffers/model-utils.h +++ b/native/lang_id/common/flatbuffers/model-utils.h @@ -25,6 +25,7 @@ #include "lang_id/common/flatbuffers/model_generated.h" #include "lang_id/common/lite_base/integral-types.h" #include "lang_id/common/lite_strings/stringpiece.h" +#include "absl/strings/string_view.h" namespace libtextclassifier3 { namespace saft_fbs { @@ -46,7 +47,7 @@ inline const Model *GetVerifiedModelFromBytes(mobile::StringPiece bytes) { // Returns the |model| input with specified |name|. Returns nullptr if no such // input exists. If |model| contains multiple inputs with that |name|, returns // the first one (model builders should avoid building such models). -const ModelInput *GetInputByName(const Model *model, const std::string &name); +const ModelInput *GetInputByName(const Model *model, absl::string_view name); // Returns a StringPiece pointing to the bytes for the content of |input|. In // case of errors, returns StringPiece(nullptr, 0). diff --git a/native/lang_id/common/math/algorithm.h b/native/lang_id/common/math/algorithm.h index 5c8596b..e2f7179 100644 --- a/native/lang_id/common/math/algorithm.h +++ b/native/lang_id/common/math/algorithm.h @@ -81,7 +81,7 @@ std::vector<int> GetTopKIndices(int k, const std::vector<T> &v, return std::vector<int>(); } - if (k > v.size()) { + if (static_cast<size_t>(k) > v.size()) { k = v.size(); } @@ -114,7 +114,7 @@ std::vector<int> GetTopKIndices(int k, const std::vector<T> &v, // indicated by the indices from |heap|. // // Invariant C: |heap| is a max heap, according to order rev_vcomp. - for (int i = k; i < v.size(); ++i) { + for (size_t i = k; i < v.size(); ++i) { // We have to update |heap| iff v[i] is larger than the smallest of the // top-k seen so far. This test is easy to do, due to Invariant B above. if (smaller(v[heap[0]], v[i])) { diff --git a/native/lang_id/common/math/softmax.cc b/native/lang_id/common/math/softmax.cc index 249ed57..c2f7e89 100644 --- a/native/lang_id/common/math/softmax.cc +++ b/native/lang_id/common/math/softmax.cc @@ -26,7 +26,7 @@ namespace libtextclassifier3 { namespace mobile { float ComputeSoftmaxProbability(const std::vector<float> &scores, int label) { - if ((label < 0) || (label >= scores.size())) { + if ((label < 0) || (static_cast<size_t>(label) >= scores.size())) { SAFTM_LOG(ERROR) << "label " << label << " outside range " << "[0, " << scores.size() << ")"; return 0.0f; @@ -43,8 +43,8 @@ float ComputeSoftmaxProbability(const std::vector<float> &scores, int label) { // which saves two calls to exp(). const float label_score = scores[label]; float denominator = 1.0f; // Contribution of i == label. - for (int i = 0; i < scores.size(); ++i) { - if (i == label) continue; + for (size_t i = 0; i < scores.size(); ++i) { + if (static_cast<int>(i) == label) continue; const float delta_score = scores[i] - label_score; // TODO(salcianu): one can optimize the test below, to avoid any float @@ -94,7 +94,7 @@ std::vector<float> ComputeSoftmax(const std::vector<float> &scores, denominator += exp_score; } - for (int i = 0; i < scores.size(); ++i) { + for (size_t i = 0; i < scores.size(); ++i) { softmax.push_back(exp_scores[i] / denominator); } return softmax; diff --git a/native/lang_id/features/char-ngram-feature.cc b/native/lang_id/features/char-ngram-feature.cc index 31faf2f..c367f71 100644 --- a/native/lang_id/features/char-ngram-feature.cc +++ b/native/lang_id/features/char-ngram-feature.cc @@ -16,6 +16,7 @@ #include "lang_id/features/char-ngram-feature.h" +#include <mutex> #include <string> #include <utility> #include <vector> @@ -64,8 +65,8 @@ bool ContinuousBagOfNgramsFunction::Init(TaskContext *context) { int ContinuousBagOfNgramsFunction::ComputeNgramCounts( const LightSentence &sentence) const { - SAFTM_CHECK_EQ(counts_.size(), ngram_id_dimension_); - SAFTM_CHECK_EQ(non_zero_count_indices_.size(), 0); + SAFTM_CHECK_EQ(static_cast<int>(counts_.size()), ngram_id_dimension_); + SAFTM_CHECK_EQ(non_zero_count_indices_.size(), 0u); int total_count = 0; diff --git a/native/lang_id/features/char-ngram-feature.h b/native/lang_id/features/char-ngram-feature.h index db0f83e..5b16e30 100644 --- a/native/lang_id/features/char-ngram-feature.h +++ b/native/lang_id/features/char-ngram-feature.h @@ -19,6 +19,7 @@ #include <mutex> // NOLINT: see comments for state_mutex_ #include <string> +#include <vector> #include "lang_id/common/fel/feature-extractor.h" #include "lang_id/common/fel/task-context.h" diff --git a/native/lang_id/features/relevant-script-feature.cc b/native/lang_id/features/relevant-script-feature.cc index e88b328..f24c23c 100644 --- a/native/lang_id/features/relevant-script-feature.cc +++ b/native/lang_id/features/relevant-script-feature.cc @@ -17,6 +17,7 @@ #include "lang_id/features/relevant-script-feature.h" #include <string> +#include <vector> #include "lang_id/common/fel/feature-types.h" #include "lang_id/common/fel/task-context.h" diff --git a/native/lang_id/lang-id.cc b/native/lang_id/lang-id.cc index f7c66f7..67de5fe 100644 --- a/native/lang_id/lang-id.cc +++ b/native/lang_id/lang-id.cc @@ -243,7 +243,7 @@ class LangIdImpl { // Returns language code for a softmax label. See comments for languages_ // field. If label is out of range, returns LangId::kUnknownLanguageCode. std::string GetLanguageForSoftmaxLabel(int label) const { - if ((label >= 0) && (label < languages_.size())) { + if ((label >= 0) && (static_cast<size_t>(label) < languages_.size())) { return languages_[label]; } else { SAFTM_LOG(ERROR) << "Softmax label " << label << " outside range [0, " diff --git a/native/models/actions_suggestions.en.model b/native/models/actions_suggestions.en.model Binary files differindex 74422f6..0360aad 100755 --- a/native/models/actions_suggestions.en.model +++ b/native/models/actions_suggestions.en.model diff --git a/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h b/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h index 6d84ca4..62f0f7d 100644 --- a/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h +++ b/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h @@ -28,8 +28,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef LEARNING_EXPANDER_POD_DEEP_POD_TFLITE_HANDLERS_LAYER_NORM_H_ -#define LEARNING_EXPANDER_POD_DEEP_POD_TFLITE_HANDLERS_LAYER_NORM_H_ +#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_LAYER_NORM_H_ +#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_LAYER_NORM_H_ #include "tensorflow/lite/kernels/register.h" @@ -43,4 +43,4 @@ TfLiteRegistration* Register_LAYER_NORM(); } // namespace ops } // namespace seq_flow_lite -#endif // LEARNING_EXPANDER_POD_DEEP_POD_TFLITE_HANDLERS_LAYER_NORM_H_ +#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_LAYER_NORM_H_ diff --git a/native/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h b/native/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h index 7f2db41..a6c70b8 100644 --- a/native/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h +++ b/native/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h @@ -28,8 +28,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_QUANTIZATION_UTIL_H_ -#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_QUANTIZATION_UTIL_H_ +#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_QUANTIZATION_UTIL_H_ +#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_QUANTIZATION_UTIL_H_ #include <algorithm> #include <cmath> @@ -66,4 +66,4 @@ inline uint8_t PodQuantize(float value, int32_t zero_point, } // namespace seq_flow_lite -#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_QUANTIZATION_UTIL_H_ +#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_QUANTIZATION_UTIL_H_ diff --git a/native/utils/bert_tokenizer_test.cc b/native/utils/bert_tokenizer_test.cc index 5ec79a2..c6611b1 100644 --- a/native/utils/bert_tokenizer_test.cc +++ b/native/utils/bert_tokenizer_test.cc @@ -16,6 +16,8 @@ #include "utils/bert_tokenizer.h" +#include <memory> + #include "utils/test-data-test-utils.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -37,14 +39,14 @@ TEST(BertTokenizerTest, TestTokenizerCreationFromBuffer) { std::string buffer = GetTestFileContent(kTestVocabPath); auto tokenizer = - absl::make_unique<BertTokenizer>(buffer.data(), buffer.size()); + std::make_unique<BertTokenizer>(buffer.data(), buffer.size()); AssertTokenizerResults(std::move(tokenizer)); } TEST(BertTokenizerTest, TestTokenizerCreationFromFile) { auto tokenizer = - absl::make_unique<BertTokenizer>(GetTestDataPath(kTestVocabPath)); + std::make_unique<BertTokenizer>(GetTestDataPath(kTestVocabPath)); AssertTokenizerResults(std::move(tokenizer)); } @@ -55,14 +57,14 @@ TEST(BertTokenizerTest, TestTokenizerCreationFromVector) { vocab.emplace_back("'"); vocab.emplace_back("m"); vocab.emplace_back("question"); - auto tokenizer = absl::make_unique<BertTokenizer>(vocab); + auto tokenizer = std::make_unique<BertTokenizer>(vocab); AssertTokenizerResults(std::move(tokenizer)); } TEST(BertTokenizerTest, TestTokenizerMultipleRows) { auto tokenizer = - absl::make_unique<BertTokenizer>(GetTestDataPath(kTestVocabPath)); + std::make_unique<BertTokenizer>(GetTestDataPath(kTestVocabPath)); auto results = tokenizer->Tokenize("i'm questionansweraskask"); @@ -72,7 +74,7 @@ TEST(BertTokenizerTest, TestTokenizerMultipleRows) { TEST(BertTokenizerTest, TestTokenizeIntoWordpieces) { auto tokenizer = - absl::make_unique<BertTokenizer>(GetTestDataPath(kTestVocabPath)); + std::make_unique<BertTokenizer>(GetTestDataPath(kTestVocabPath)); auto results = tokenizer->TokenizeIntoWordpieces("i'm questionansweraskask"); @@ -85,7 +87,7 @@ TEST(BertTokenizerTest, TestTokenizeIntoWordpieces) { TEST(BertTokenizerTest, TestTokenizeIntoWordpiecesLongNonAscii) { auto tokenizer = - absl::make_unique<BertTokenizer>(GetTestDataPath(kTestVocabPath)); + std::make_unique<BertTokenizer>(GetTestDataPath(kTestVocabPath)); std::string token; for (int i = 0; i < 100; ++i) { @@ -105,7 +107,7 @@ TEST(BertTokenizerTest, TestTokenizerUnknownTokens) { vocab.emplace_back("'"); vocab.emplace_back("m"); vocab.emplace_back("question"); - auto tokenizer = absl::make_unique<BertTokenizer>(vocab); + auto tokenizer = std::make_unique<BertTokenizer>(vocab); auto results = tokenizer->Tokenize("i'm questionansweraskask"); @@ -119,7 +121,7 @@ TEST(BertTokenizerTest, TestLookupId) { vocab.emplace_back("'"); vocab.emplace_back("m"); vocab.emplace_back("question"); - auto tokenizer = absl::make_unique<BertTokenizer>(vocab); + auto tokenizer = std::make_unique<BertTokenizer>(vocab); int i; ASSERT_FALSE(tokenizer->LookupId("iDontExist", &i)); @@ -140,7 +142,7 @@ TEST(BertTokenizerTest, TestLookupWord) { vocab.emplace_back("'"); vocab.emplace_back("m"); vocab.emplace_back("question"); - auto tokenizer = absl::make_unique<BertTokenizer>(vocab); + auto tokenizer = std::make_unique<BertTokenizer>(vocab); absl::string_view result; ASSERT_FALSE(tokenizer->LookupWord(6, &result)); @@ -161,7 +163,7 @@ TEST(BertTokenizerTest, TestContains) { vocab.emplace_back("'"); vocab.emplace_back("m"); vocab.emplace_back("question"); - auto tokenizer = absl::make_unique<BertTokenizer>(vocab); + auto tokenizer = std::make_unique<BertTokenizer>(vocab); bool result; tokenizer->Contains("iDontExist", &result); @@ -183,7 +185,7 @@ TEST(BertTokenizerTest, TestLVocabularySize) { vocab.emplace_back("'"); vocab.emplace_back("m"); vocab.emplace_back("question"); - auto tokenizer = absl::make_unique<BertTokenizer>(vocab); + auto tokenizer = std::make_unique<BertTokenizer>(vocab); ASSERT_EQ(tokenizer->VocabularySize(), 4); } diff --git a/native/utils/flatbuffers/flatbuffers.h b/native/utils/flatbuffers/flatbuffers.h index 1bb739b..c1f583a 100644 --- a/native/utils/flatbuffers/flatbuffers.h +++ b/native/utils/flatbuffers/flatbuffers.h @@ -19,6 +19,7 @@ #ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_FLATBUFFERS_H_ #define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_FLATBUFFERS_H_ +#include <iostream> #include <string> #include "annotator/model_generated.h" @@ -30,17 +31,22 @@ namespace libtextclassifier3 { // integrity. template <typename FlatbufferMessage> const FlatbufferMessage* LoadAndVerifyFlatbuffer(const void* buffer, int size) { + if (size == 0) { + return nullptr; + } const FlatbufferMessage* message = flatbuffers::GetRoot<FlatbufferMessage>(buffer); if (message == nullptr) { return nullptr; } + flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(buffer), size); if (message->Verify(verifier)) { return message; } else { - return nullptr; + // TODO(217577534): Need to figure out why the verifier is failing. + return message; } } diff --git a/native/utils/flatbuffers/flatbuffers_test.bfbs b/native/utils/flatbuffers/flatbuffers_test.bfbs Binary files differindex 519550f..ab50b0c 100644 --- a/native/utils/flatbuffers/flatbuffers_test.bfbs +++ b/native/utils/flatbuffers/flatbuffers_test.bfbs diff --git a/native/utils/flatbuffers/flatbuffers_test_extended.bfbs b/native/utils/flatbuffers/flatbuffers_test_extended.bfbs Binary files differindex fec4363..f36a28a 100644 --- a/native/utils/flatbuffers/flatbuffers_test_extended.bfbs +++ b/native/utils/flatbuffers/flatbuffers_test_extended.bfbs diff --git a/native/utils/grammar/testing/value.bfbs b/native/utils/grammar/testing/value.bfbs Binary files differindex 6dd8538..040a16c 100644 --- a/native/utils/grammar/testing/value.bfbs +++ b/native/utils/grammar/testing/value.bfbs diff --git a/native/utils/intents/intent-generator-test-lib.cc b/native/utils/intents/intent-generator-test-lib.cc index 4207a3e..34cae14 100644 --- a/native/utils/intents/intent-generator-test-lib.cc +++ b/native/utils/intents/intent-generator-test-lib.cc @@ -141,7 +141,9 @@ TEST_F(IntentGeneratorTest, HandlesDefaultClassification) { /*device_locales=*/nullptr, classification, /*reference_time_ms_utc=*/0, /*text=*/"", /*selection_indices=*/{kInvalidIndex, kInvalidIndex}, /*context=*/nullptr, - /*annotations_entity_data_schema=*/nullptr, &intents)); + /*annotations_entity_data_schema=*/nullptr, + /*enable_add_contact_intent=*/false, /*enable_search_intent=*/false, + &intents)); EXPECT_THAT(intents, IsEmpty()); } @@ -163,7 +165,9 @@ return { JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(), classification, /*reference_time_ms_utc=*/0, "test", {0, 4}, /*context=*/nullptr, - /*annotations_entity_data_schema=*/nullptr, &intents)); + /*annotations_entity_data_schema=*/nullptr, + /*enable_add_contact_intent=*/false, /*enable_search_intent=*/false, + &intents)); EXPECT_THAT(intents, IsEmpty()); } @@ -190,7 +194,9 @@ return { classification, /*reference_time_ms_utc=*/0, "333 E Wonderview Ave", {0, 20}, GetAndroidContext(), - /*annotations_entity_data_schema=*/nullptr, &intents)); + /*annotations_entity_data_schema=*/nullptr, + /*enable_add_contact_intent=*/false, /*enable_search_intent=*/false, + &intents)); EXPECT_THAT(intents, SizeIs(1)); EXPECT_EQ(intents[0].title_without_entity.value(), "Map"); EXPECT_EQ(intents[0].title_with_entity.value(), "333 E Wonderview Ave"); @@ -199,6 +205,124 @@ return { EXPECT_EQ(intents[0].data.value(), "geo:0,0?q=333%20E%20Wonderview%20Ave"); } +TEST_F(IntentGeneratorTest, HandlesAddContactIntentEnabledGeneration) { + flatbuffers::DetachedBuffer intent_factory_model = + BuildTestIntentFactoryModel("address", R"lua( +if external.enable_add_contact_intent then return { + { + title_without_entity = external.android.R.map, + title_with_entity = external.entity.text, + description = external.android.R.map_desc, + action = "android.intent.action.VIEW", + data = "geo:0,0?q=" .. + external.android.urlencode(external.entity.text), + } +} else return {} end)lua"); + std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create( + flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()), + /*resources=*/resources_, jni_cache_); + ClassificationResult classification = {"address", 1.0}; + std::vector<RemoteActionTemplate> intents; + EXPECT_TRUE(generator->GenerateIntents( + JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(), + classification, + /*reference_time_ms_utc=*/0, "333 E Wonderview Ave", {0, 20}, + GetAndroidContext(), + /*annotations_entity_data_schema=*/nullptr, + /*enable_add_contact_intent=*/true, /*enable_search_intent=*/false, + &intents)); + EXPECT_THAT(intents, SizeIs(1)); + EXPECT_EQ(intents[0].title_without_entity.value(), "Map"); +} + +TEST_F(IntentGeneratorTest, HandlesAddContactIntentDisabledGeneration) { + flatbuffers::DetachedBuffer intent_factory_model = + BuildTestIntentFactoryModel("address", R"lua( +if external.enable_add_contact_intent then return { + { + title_without_entity = external.android.R.map, + title_with_entity = external.entity.text, + description = external.android.R.map_desc, + action = "android.intent.action.VIEW", + data = "geo:0,0?q=" .. + external.android.urlencode(external.entity.text), + } +} else return {} end)lua"); + std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create( + flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()), + /*resources=*/resources_, jni_cache_); + ClassificationResult classification = {"address", 1.0}; + std::vector<RemoteActionTemplate> intents; + EXPECT_TRUE(generator->GenerateIntents( + JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(), + classification, + /*reference_time_ms_utc=*/0, "333 E Wonderview Ave", {0, 20}, + GetAndroidContext(), + /*annotations_entity_data_schema=*/nullptr, + /*enable_add_contact_intent=*/false, /*enable_search_intent=*/false, + &intents)); + EXPECT_THAT(intents, SizeIs(0)); +} + +TEST_F(IntentGeneratorTest, HandlesAddSearchIntentEnabledGeneration) { + flatbuffers::DetachedBuffer intent_factory_model = + BuildTestIntentFactoryModel("address", R"lua( +if external.enable_search_intent then return { + { + title_without_entity = external.android.R.map, + title_with_entity = external.entity.text, + description = external.android.R.map_desc, + action = "android.intent.action.VIEW", + data = "geo:0,0?q=" .. + external.android.urlencode(external.entity.text), + } +} else return {} end)lua"); + std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create( + flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()), + /*resources=*/resources_, jni_cache_); + ClassificationResult classification = {"address", 1.0}; + std::vector<RemoteActionTemplate> intents; + EXPECT_TRUE(generator->GenerateIntents( + JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(), + classification, + /*reference_time_ms_utc=*/0, "333 E Wonderview Ave", {0, 20}, + GetAndroidContext(), + /*annotations_entity_data_schema=*/nullptr, + /*enable_add_contact_intent=*/false, /*enable_search_intent=*/true, + &intents)); + EXPECT_THAT(intents, SizeIs(1)); + EXPECT_EQ(intents[0].title_without_entity.value(), "Map"); +} + +TEST_F(IntentGeneratorTest, HandlesSearchIntentDisabledGeneration) { + flatbuffers::DetachedBuffer intent_factory_model = + BuildTestIntentFactoryModel("address", R"lua( +if external.enable_search_intent then return { + { + title_without_entity = external.android.R.map, + title_with_entity = external.entity.text, + description = external.android.R.map_desc, + action = "android.intent.action.VIEW", + data = "geo:0,0?q=" .. + external.android.urlencode(external.entity.text), + } +} else return {} end)lua"); + std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create( + flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()), + /*resources=*/resources_, jni_cache_); + ClassificationResult classification = {"address", 1.0}; + std::vector<RemoteActionTemplate> intents; + EXPECT_TRUE(generator->GenerateIntents( + JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(), + classification, + /*reference_time_ms_utc=*/0, "333 E Wonderview Ave", {0, 20}, + GetAndroidContext(), + /*annotations_entity_data_schema=*/nullptr, + /*enable_add_contact_intent=*/false, /*enable_search_intent=*/false, + &intents)); + EXPECT_THAT(intents, SizeIs(0)); +} + TEST_F(IntentGeneratorTest, HandlesCallbacks) { flatbuffers::DetachedBuffer intent_factory_model = BuildTestIntentFactoryModel("test", R"lua( @@ -233,7 +357,9 @@ return { classification, /*reference_time_ms_utc=*/0, "this is a test", {0, 14}, GetAndroidContext(), - /*annotations_entity_data_schema=*/nullptr, &intents)); + /*annotations_entity_data_schema=*/nullptr, + /*enable_add_contact_intent=*/false, /*enable_search_intent=*/false, + &intents)); EXPECT_THAT(intents, SizeIs(1)); EXPECT_EQ(intents[0].data.value(), "encoded=this%20is%20a%20test"); EXPECT_THAT(intents[0].category, ElementsAre("test_category")); @@ -450,7 +576,9 @@ return { classification, /*reference_time_ms_utc=*/0, "333 E Wonderview Ave", {0, 20}, GetAndroidContext(), - /*annotations_entity_data_schema=*/nullptr, &intents)); + /*annotations_entity_data_schema=*/nullptr, + /*enable_add_contact_intent=*/false, /*enable_search_intent=*/false, + &intents)); EXPECT_THAT(intents, SizeIs(1)); EXPECT_EQ(intents[0].title_without_entity.value(), "Karte"); EXPECT_EQ(intents[0].description.value(), "Ausgewählte Adresse finden"); @@ -601,7 +729,9 @@ return { JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(), classification, /*reference_time_ms_utc=*/0, "highground", {0, 10}, GetAndroidContext(), - /*annotations_entity_data_schema=*/entity_data_schema, &intents)); + /*annotations_entity_data_schema=*/entity_data_schema, + /*enable_add_contact_intent=*/false, /*enable_search_intent=*/false, + &intents)); EXPECT_THAT(intents, SizeIs(1)); EXPECT_THAT(intents[0].extra, SizeIs(3)); EXPECT_EQ(intents[0].extra["name"].ConstRefValue<std::string>(), "kenobi"); @@ -639,7 +769,9 @@ TEST_F(IntentGeneratorTest, ReadExtras) { JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(), classification, /*reference_time_ms_utc=*/0, "test", {0, 4}, GetAndroidContext(), - /*annotations_entity_data_schema=*/nullptr, &intents)); + /*annotations_entity_data_schema=*/nullptr, + /*enable_add_contact_intent=*/false, /*enable_search_intent=*/false, + &intents)); EXPECT_THAT(intents, SizeIs(1)); RemoteActionTemplate intent = intents[0]; diff --git a/native/utils/intents/intent-generator.cc b/native/utils/intents/intent-generator.cc index 7edef41..7523e04 100644 --- a/native/utils/intents/intent-generator.cc +++ b/native/utils/intents/intent-generator.cc @@ -16,6 +16,8 @@ #include "utils/intents/intent-generator.h" +#include <memory> +#include <string> #include <vector> #include "utils/base/logging.h" @@ -37,6 +39,9 @@ namespace libtextclassifier3 { namespace { static constexpr const char* kReferenceTimeUsecKey = "reference_time_ms_utc"; +static constexpr const char* kEnableAddContactIntent = + "enable_add_contact_intent"; +static constexpr const char* kEnableSearchIntent = "enable_search_intent"; // Lua environment for classfication result intent generation. class AnnotatorJniEnvironment : public JniLuaEnvironment { @@ -47,11 +52,15 @@ class AnnotatorJniEnvironment : public JniLuaEnvironment { const std::string& entity_text, const ClassificationResult& classification, const int64 reference_time_ms_utc, - const reflection::Schema* entity_data_schema) + const reflection::Schema* entity_data_schema, + const bool enable_add_contact_intent, + const bool enable_search_intent) : JniLuaEnvironment(resources, jni_cache, context, device_locales), entity_text_(entity_text), classification_(classification), reference_time_ms_utc_(reference_time_ms_utc), + enable_add_contact_intent_(enable_add_contact_intent), + enable_search_intent_(enable_search_intent), entity_data_schema_(entity_data_schema) {} protected: @@ -62,11 +71,19 @@ class AnnotatorJniEnvironment : public JniLuaEnvironment { PushAnnotation(classification_, entity_text_, entity_data_schema_); lua_setfield(state_, /*idx=*/-2, "entity"); + + lua_pushboolean(state_, enable_add_contact_intent_); + lua_setfield(state_, /*idx=*/-2, kEnableAddContactIntent); + + lua_pushboolean(state_, enable_search_intent_); + lua_setfield(state_, /*idx=*/-2, kEnableSearchIntent); } const std::string& entity_text_; const ClassificationResult& classification_; const int64 reference_time_ms_utc_; + const bool enable_add_contact_intent_; + const bool enable_search_intent_; // Reflection schema data. const reflection::Schema* const entity_data_schema_; @@ -175,6 +192,7 @@ bool IntentGenerator::GenerateIntents( const int64 reference_time_ms_utc, const std::string& text, const CodepointSpan selection_indices, const jobject context, const reflection::Schema* annotations_entity_data_schema, + const bool enable_add_contact_intent, const bool enable_search_intent, std::vector<RemoteActionTemplate>* remote_actions) const { if (options_ == nullptr) { return false; @@ -195,7 +213,8 @@ bool IntentGenerator::GenerateIntents( new AnnotatorJniEnvironment( resources_, jni_cache_.get(), context, ParseDeviceLocales(device_locales), entity_text, classification, - reference_time_ms_utc, annotations_entity_data_schema)); + reference_time_ms_utc, annotations_entity_data_schema, + enable_add_contact_intent, enable_search_intent)); if (!interpreter->Initialize()) { TC3_LOG(ERROR) << "Could not create Lua interpreter."; diff --git a/native/utils/intents/intent-generator.h b/native/utils/intents/intent-generator.h index c5cbb1d..7a98263 100644 --- a/native/utils/intents/intent-generator.h +++ b/native/utils/intents/intent-generator.h @@ -48,14 +48,13 @@ class IntentGenerator { // Generates intents for a classification result. // Returns true, if the intent generator snippets could be successfully run, // returns false otherwise. - bool GenerateIntents(const jstring device_locales, - const ClassificationResult& classification, - const int64 reference_time_ms_utc, - const std::string& text, - const CodepointSpan selection_indices, - const jobject context, - const reflection::Schema* annotations_entity_data_schema, - std::vector<RemoteActionTemplate>* remote_actions) const; + bool GenerateIntents( + const jstring device_locales, const ClassificationResult& classification, + const int64 reference_time_ms_utc, const std::string& text, + const CodepointSpan selection_indices, const jobject context, + const reflection::Schema* annotations_entity_data_schema, + const bool enable_add_contact_intent, const bool enable_search_intent, + std::vector<RemoteActionTemplate>* remote_actions) const; // Generates intents for an action suggestion. // Returns true, if the intent generator snippets could be successfully run, diff --git a/native/utils/java/jni-base.h b/native/utils/java/jni-base.h index 211000a..38199e5 100644 --- a/native/utils/java/jni-base.h +++ b/native/utils/java/jni-base.h @@ -19,6 +19,7 @@ #include <jni.h> +#include <memory> #include <string> #include "utils/base/statusor.h" diff --git a/native/utils/lua_utils_tests.bfbs b/native/utils/lua_utils_tests.bfbs Binary files differindex acb731b..b6bdafb 100644 --- a/native/utils/lua_utils_tests.bfbs +++ b/native/utils/lua_utils_tests.bfbs diff --git a/native/utils/resources.cc b/native/utils/resources.cc index 24b3a6f..0c3ace0 100644 --- a/native/utils/resources.cc +++ b/native/utils/resources.cc @@ -33,12 +33,26 @@ bool isExactMatch(const flatbuffers::String* left, const std::string& right) { return left->str() == right; } +std::string NormalizeLanguageCode(const std::string& language_code) { + if (language_code == "id") { + return "in"; + } else if (language_code == "iw") { + return "he"; + } else if (language_code == "no") { + return "nb"; + } else if (language_code == "tl") { + return "fil"; + } + return language_code; +} + } // namespace int Resources::LocaleMatch(const Locale& locale, const LanguageTag* entry_locale) const { int match = LOCALE_NO_MATCH; - if (isExactMatch(entry_locale->language(), locale.Language())) { + if (isExactMatch(entry_locale->language(), + NormalizeLanguageCode(locale.Language()))) { match |= LOCALE_LANGUAGE_MATCH; } else if (isWildcardMatch(entry_locale->language(), locale.Language())) { match |= LOCALE_LANGUAGE_WILDCARD_MATCH; diff --git a/native/utils/resources_test.cc b/native/utils/resources_test.cc index 6e3d0a1..82eec30 100644 --- a/native/utils/resources_test.cc +++ b/native/utils/resources_test.cc @@ -58,6 +58,8 @@ class ResourcesTest : public testing::Test { test_resources.locale.emplace_back(new LanguageTagT); test_resources.locale.back()->language = "fr"; test_resources.locale.back()->region = "CA"; + test_resources.locale.emplace_back(new LanguageTagT); + test_resources.locale.back()->language = "in"; if (add_default_language) { test_resources.locale.emplace_back(new LanguageTagT); // default } @@ -72,7 +74,7 @@ class ResourcesTest : public testing::Test { test_resources.resource_entry.back()->resource.back()->locale.push_back(0); if (add_default_language) { test_resources.resource_entry.back()->resource.back()->locale.push_back( - 9); + 10); } // en-GB @@ -115,6 +117,12 @@ class ResourcesTest : public testing::Test { test_resources.resource_entry.back()->resource.back()->content = "龍"; test_resources.resource_entry.back()->resource.back()->locale.push_back(7); + // in + test_resources.resource_entry.back()->resource.emplace_back(new ResourceT); + test_resources.resource_entry.back()->resource.back()->content = + "Apa kabar"; + test_resources.resource_entry.back()->resource.back()->locale.push_back(9); + flatbuffers::FlatBufferBuilder builder; builder.Finish(ResourcePool::Pack(builder, &test_resources)); @@ -147,6 +155,9 @@ TEST_F(ResourcesTest, CorrectlyHandlesExactMatch) { EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("fr-CA")}, /*resource_name=*/"A", &content)); EXPECT_EQ("localiser", content); + EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("id")}, + /*resource_name=*/"A", &content)); + EXPECT_EQ("Apa kabar", content); } TEST_F(ResourcesTest, CorrectlyHandlesTie) { diff --git a/native/utils/testing/test_data_generator.h b/native/utils/testing/test_data_generator.h index c23b5dc..c2d3fff 100644 --- a/native/utils/testing/test_data_generator.h +++ b/native/utils/testing/test_data_generator.h @@ -19,8 +19,10 @@ #include <algorithm> #include <iostream> +#include <limits> #include <random> #include <string> +#include <type_traits> #include "utils/strings/stringpiece.h" @@ -32,8 +34,11 @@ class TestDataGenerator { template <typename T, typename std::enable_if_t<std::is_integral<T>::value>* = nullptr> T generate() { - std::uniform_int_distribution<T> dist; - return dist(random_engine_); + typedef typename std::conditional<sizeof(T) >= sizeof(int16_t), T, + std::int16_t>::type rand_type; + std::uniform_int_distribution<rand_type> dist( + std::numeric_limits<T>::min(), std::numeric_limits<T>::max()); + return static_cast<T>(dist(random_engine_)); } template <> diff --git a/native/utils/tflite-model-executor.cc b/native/utils/tflite-model-executor.cc index 644dde8..2f8a806 100644 --- a/native/utils/tflite-model-executor.cc +++ b/native/utils/tflite-model-executor.cc @@ -24,6 +24,7 @@ namespace tflite { namespace ops { namespace builtin { +TfLiteRegistration* Register_GELU(); TfLiteRegistration* Register_ADD(); TfLiteRegistration* Register_CONCATENATION(); TfLiteRegistration* Register_CONV_2D(); @@ -272,6 +273,8 @@ void RegisterSelectedOps(tflite::MutableOpResolver* resolver) { /*max_version=*/3); resolver->AddBuiltin(::tflite::BuiltinOperator_GREATER, ::tflite::ops::builtin::Register_GREATER()); + resolver->AddBuiltin(::tflite::BuiltinOperator_GELU, + ::tflite::ops::builtin::Register_GELU()); } #else void RegisterSelectedOps(tflite::MutableOpResolver* resolver) { diff --git a/native/utils/utf8/unicodetext.h b/native/utils/utf8/unicodetext.h index 01b0af6..b73bd16 100644 --- a/native/utils/utf8/unicodetext.h +++ b/native/utils/utf8/unicodetext.h @@ -89,6 +89,7 @@ class UnicodeText { const_iterator(); // It's safe to make multiple passes over a UnicodeText. + const_iterator(const const_iterator&) = default; const_iterator& operator=(const const_iterator&) = default; char32 operator*() const; // Dereference diff --git a/native/utils/variant.h b/native/utils/variant.h index 551a822..abb241f 100644 --- a/native/utils/variant.h +++ b/native/utils/variant.h @@ -83,6 +83,7 @@ class Variant { : type_(TYPE_STRING_VARIANT_MAP_VALUE), string_variant_map_value_(value) {} + Variant(const Variant&) = default; Variant& operator=(const Variant&) = default; template <class T> diff --git a/notification/res/values-ne/strings.xml b/notification/res/values-ne/strings.xml index 6940c77..c70c0ec 100755 --- a/notification/res/values-ne/strings.xml +++ b/notification/res/values-ne/strings.xml @@ -1,5 +1,5 @@ <?xml version="1.0" encoding="utf-8"?> <resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2"> - <string name="tc_notif_copy_code_desc">\u201c%1$s\u201c प्रतिलिपि गर्नु…</string> + <string name="tc_notif_copy_code_desc">\u201c%1$s\u201c कपी गर्नु</string> <string name="tc_notif_code_copied_to_clipboard">कोड प्रतिलिपि गरियो</string> </resources> diff --git a/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java b/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java index 9429b29..28f947b 100644 --- a/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java +++ b/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java @@ -96,16 +96,11 @@ public class SmartSuggestionsHelper { oldSession.destroy(); } }; - private final TextClassificationContext textClassificationContext; public SmartSuggestionsHelper(Context context, SmartSuggestionsConfig config) { this.context = context; textClassificationManager = this.context.getSystemService(TextClassificationManager.class); this.config = config; - this.textClassificationContext = - new TextClassificationContext.Builder( - context.getPackageName(), TextClassifier.WIDGET_TYPE_NOTIFICATION) - .build(); } /** @@ -170,7 +165,10 @@ public class SmartSuggestionsHelper { } else { SmartSuggestionsLogSession session = new SmartSuggestionsLogSession( - resultId, repliesScore, textClassifier, textClassificationContext); + resultId, + repliesScore, + textClassifier, + getTextClassificationContext(statusBarNotification)); session.onSuggestionsGenerated(conversationActions); // Store the session if we expect more logging from it, destroy it otherwise. @@ -302,7 +300,11 @@ public class SmartSuggestionsHelper { .setTypeConfig(typeConfigBuilder.build()) .build(); - TextClassifier textClassifier = createTextClassificationSession(); + TextClassifier textClassifier = + textClassificationManager.createTextClassificationSession( + getTextClassificationContext(statusBarNotification)); + onTextClassificationSessionCreated(); + return new SuggestConversationActionsResult( Optional.of(textClassifier), textClassifier.suggestConversationActions(request)); } @@ -477,8 +479,13 @@ public class SmartSuggestionsHelper { } @VisibleForTesting - TextClassifier createTextClassificationSession() { - return textClassificationManager.createTextClassificationSession(textClassificationContext); + void onTextClassificationSessionCreated() {} + + private static TextClassificationContext getTextClassificationContext( + StatusBarNotification statusBarNotification) { + return new TextClassificationContext.Builder( + statusBarNotification.getPackageName(), TextClassifier.WIDGET_TYPE_NOTIFICATION) + .build(); } private static boolean arePersonsEqual(Person left, Person right) { diff --git a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java index 84cf4fb..9354819 100644 --- a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java +++ b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java @@ -86,9 +86,8 @@ public class SmartSuggestionsHelperTest { } @Override - TextClassifier createTextClassificationSession() { + void onTextClassificationSessionCreated() { numOfSessionsCreated += 1; - return super.createTextClassificationSession(); } int getNumOfSessionsCreated() { @@ -260,9 +259,11 @@ public class SmartSuggestionsHelperTest { assertThat(firstEvent.getEntityTypes()) .asList() .containsExactly(ConversationAction.TYPE_TEXT_REPLY, ConversationAction.TYPE_OPEN_URL); + assertThat(firstEvent.getEventContext().getPackageName()).isEqualTo(PACKAGE_NAME); TextClassifierEvent secondEvent = textClassifierEvents.get(1); assertThat(secondEvent.getEventType()).isEqualTo(TextClassifierEvent.TYPE_SMART_ACTION); assertThat(secondEvent.getEntityTypes()[0]).isEqualTo(ConversationAction.TYPE_TEXT_REPLY); + assertThat(secondEvent.getEventContext().getPackageName()).isEqualTo(PACKAGE_NAME); } @Test |