summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2022-09-14 05:38:41 +0000
committerAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2022-09-14 05:38:41 +0000
commitdf9234b8dfd77f79aa64f81af44aa532c9de0a0f (patch)
tree101bfc731fc453160421993180e07d015e40bc56
parent15d8bd33b54a5beec49dd759686341ee05e4e109 (diff)
parent1075b1e4e39ab4af90deb3758e5631943c07d47e (diff)
downloadlibtextclassifier-df9234b8dfd77f79aa64f81af44aa532c9de0a0f.tar.gz
Snap for 9061588 from 1075b1e4e39ab4af90deb3758e5631943c07d47e to mainline-permission-releaseaml_per_331115020
Change-Id: Iaaf56083744c9820adbab66355934e389b2145e1
-rw-r--r--java/src/com/android/textclassifier/TextClassifierImpl.java15
-rw-r--r--java/src/com/android/textclassifier/common/TextClassifierSettings.java2
-rw-r--r--java/src/com/android/textclassifier/common/logging/TextClassificationContext.java2
-rw-r--r--java/src/com/android/textclassifier/common/logging/TextClassifierEvent.java14
-rw-r--r--java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java5
-rw-r--r--java/src/com/android/textclassifier/downloader/TextClassifierDownloadLogger.java2
-rw-r--r--java/src/com/android/textclassifier/utils/IndentingPrintWriter.java6
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java1
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java2
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java3
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/testing/FakeContextBuilder.java4
-rw-r--r--jni/Android.bp5
-rw-r--r--jni/com/google/android/textclassifier/AnnotatorModel.java71
-rw-r--r--native/FlatBufferHeaders.bp186
-rw-r--r--native/JavaTests.bp32
-rw-r--r--native/actions/actions-entity-data.bfbsbin888 -> 1232 bytes
-rw-r--r--native/actions/actions-suggestions.cc34
-rw-r--r--native/actions/actions-suggestions.h17
-rw-r--r--native/actions/actions-suggestions_test.cc21
-rw-r--r--native/actions/actions_model.fbs13
-rw-r--r--native/actions/test_data/actions_suggestions_grammar_test.modelbin145176 -> 145616 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.live_relay.modelbin4720560 -> 4712944 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.modelbin3387360 -> 3385008 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_9heads.modelbin3874704 -> 3866880 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.modelbin3812304 -> 3808080 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji_concept.modelbin0 -> 10192720 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.modelbin3853520 -> 3848720 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.modelbin4671840 -> 4667088 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.modelbin5045408 -> 5035952 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.sensitive_tflite.modelbin7111552 -> 7106288 bytes
-rw-r--r--native/annotator/annotator.cc69
-rw-r--r--native/annotator/annotator.h3
-rw-r--r--native/annotator/annotator_jni.cc4
-rw-r--r--native/annotator/annotator_jni_common.cc17
-rw-r--r--native/annotator/collections.h30
-rw-r--r--native/annotator/contact/contact-engine-dummy.h4
-rw-r--r--native/annotator/datetime/grammar-parser.cc11
-rw-r--r--native/annotator/datetime/grammar-parser.h5
-rw-r--r--native/annotator/datetime/grammar-parser_test.cc26
-rw-r--r--native/annotator/duration/duration.cc14
-rw-r--r--native/annotator/duration/duration.h2
-rw-r--r--native/annotator/duration/duration_test.cc174
-rw-r--r--native/annotator/installed_app/installed-app-engine-dummy.h4
-rw-r--r--native/annotator/knowledge/knowledge-engine-dummy.h6
-rw-r--r--native/annotator/model.fbs26
-rw-r--r--native/annotator/number/number.cc7
-rw-r--r--native/annotator/number/number.h2
-rw-r--r--native/annotator/number/number_test-include.cc157
-rw-r--r--native/annotator/number/number_test-include.h22
-rw-r--r--native/annotator/person_name/person-name-engine-dummy.h2
-rw-r--r--native/annotator/person_name/person_name_model.fbs7
-rw-r--r--native/annotator/pod_ner/pod-ner-impl.cc13
-rw-r--r--native/annotator/pod_ner/pod-ner-impl_test.cc49
-rw-r--r--native/annotator/translate/translate.cc5
-rw-r--r--native/annotator/translate/translate_test.cc58
-rw-r--r--native/annotator/types.h9
-rw-r--r--native/annotator/vocab/vocab-annotator-impl.cc6
-rw-r--r--native/lang_id/common/embedding-feature-extractor.h17
-rw-r--r--native/lang_id/common/embedding-feature-interface.h3
-rw-r--r--native/lang_id/common/embedding-network.cc11
-rw-r--r--native/lang_id/common/fel/feature-descriptors.h9
-rw-r--r--native/lang_id/common/fel/feature-extractor.h11
-rw-r--r--native/lang_id/common/fel/feature-types.h16
-rw-r--r--native/lang_id/common/fel/fel-parser.cc5
-rw-r--r--native/lang_id/common/fel/fel-parser.h2
-rw-r--r--native/lang_id/common/fel/workspace.h5
-rw-r--r--native/lang_id/common/flatbuffers/model-utils.cc3
-rw-r--r--native/lang_id/common/flatbuffers/model-utils.h3
-rw-r--r--native/lang_id/common/math/algorithm.h4
-rw-r--r--native/lang_id/common/math/softmax.cc8
-rw-r--r--native/lang_id/features/char-ngram-feature.cc5
-rw-r--r--native/lang_id/features/char-ngram-feature.h1
-rw-r--r--native/lang_id/features/relevant-script-feature.cc1
-rw-r--r--native/lang_id/lang-id.cc2
-rwxr-xr-xnative/models/actions_suggestions.en.modelbin3891632 -> 3886640 bytes
-rw-r--r--native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h6
-rw-r--r--native/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h6
-rw-r--r--native/utils/bert_tokenizer_test.cc24
-rw-r--r--native/utils/flatbuffers/flatbuffers.h8
-rw-r--r--native/utils/flatbuffers/flatbuffers_test.bfbsbin1872 -> 2376 bytes
-rw-r--r--native/utils/flatbuffers/flatbuffers_test_extended.bfbsbin1912 -> 2424 bytes
-rw-r--r--native/utils/grammar/testing/value.bfbsbin984 -> 1328 bytes
-rw-r--r--native/utils/intents/intent-generator-test-lib.cc146
-rw-r--r--native/utils/intents/intent-generator.cc23
-rw-r--r--native/utils/intents/intent-generator.h15
-rw-r--r--native/utils/java/jni-base.h1
-rw-r--r--native/utils/lua_utils_tests.bfbsbin1332 -> 1720 bytes
-rw-r--r--native/utils/testing/test_data_generator.h9
-rw-r--r--native/utils/tflite-model-executor.cc3
-rw-r--r--native/utils/utf8/unicodetext.h1
-rw-r--r--native/utils/variant.h1
-rwxr-xr-xnotification/res/values-ne/strings.xml2
-rw-r--r--notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java25
-rw-r--r--notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java5
94 files changed, 1152 insertions, 396 deletions
diff --git a/java/src/com/android/textclassifier/TextClassifierImpl.java b/java/src/com/android/textclassifier/TextClassifierImpl.java
index 2b6c396..abe8994 100644
--- a/java/src/com/android/textclassifier/TextClassifierImpl.java
+++ b/java/src/com/android/textclassifier/TextClassifierImpl.java
@@ -199,6 +199,8 @@ final class TextClassifierImpl {
.setDetectedTextLanguageTags(detectLanguageTags)
.setAnnotationUsecase(AnnotatorModel.AnnotationUsecase.SMART.getValue())
.setUserFamiliarLanguageTags(LocaleList.getDefault().toLanguageTags())
+ .setEnableAddContactIntent(false)
+ .setEnableSearchIntent(shouldEnableSearchIntent(textClassificationContext))
.build(),
// Passing null here to suppress intent generation.
// TODO: Use an explicit flag to suppress it.
@@ -254,6 +256,8 @@ final class TextClassifierImpl {
.setDetectedTextLanguageTags(String.join(",", detectLanguageTags))
.setAnnotationUsecase(AnnotatorModel.AnnotationUsecase.SMART.getValue())
.setUserFamiliarLanguageTags(LocaleList.getDefault().toLanguageTags())
+ .setEnableAddContactIntent(false)
+ .setEnableSearchIntent(shouldEnableSearchIntent(textClassificationContext))
.build(),
context,
getResourceLocalesString());
@@ -769,4 +773,15 @@ final class TextClassifierImpl {
strippedIntent.setComponent(null);
return strippedIntent;
}
+
+ private static boolean shouldEnableSearchIntent(
+ @Nullable TextClassificationContext textClassificationContext) {
+ if (textClassificationContext == null) {
+ return false;
+ }
+ String widgetType = textClassificationContext.getWidgetType();
+ // Exclude WebView because there is already a *Web Search* chip there.
+ return !(TextClassifier.WIDGET_TYPE_WEBVIEW.equals(widgetType)
+ || TextClassifier.WIDGET_TYPE_EDIT_WEBVIEW.equals(widgetType));
+ }
}
diff --git a/java/src/com/android/textclassifier/common/TextClassifierSettings.java b/java/src/com/android/textclassifier/common/TextClassifierSettings.java
index 205680d..d0ea917 100644
--- a/java/src/com/android/textclassifier/common/TextClassifierSettings.java
+++ b/java/src/com/android/textclassifier/common/TextClassifierSettings.java
@@ -525,7 +525,7 @@ public final class TextClassifierSettings {
variantsMapBuilder.put(modelLanguageTag, urlFlagValue);
}
}
- return variantsMapBuilder.build();
+ return variantsMapBuilder.buildOrThrow();
}
public String getTestingLocaleListOverride() {
diff --git a/java/src/com/android/textclassifier/common/logging/TextClassificationContext.java b/java/src/com/android/textclassifier/common/logging/TextClassificationContext.java
index e729201..5e572a7 100644
--- a/java/src/com/android/textclassifier/common/logging/TextClassificationContext.java
+++ b/java/src/com/android/textclassifier/common/logging/TextClassificationContext.java
@@ -18,6 +18,7 @@ package com.android.textclassifier.common.logging;
import androidx.annotation.NonNull;
import com.google.common.base.Preconditions;
+import com.google.errorprone.annotations.CanIgnoreReturnValue;
import java.util.Locale;
import javax.annotation.Nullable;
@@ -91,6 +92,7 @@ public final class TextClassificationContext {
*
* @return this builder
*/
+ @CanIgnoreReturnValue
public Builder setWidgetVersion(@Nullable String widgetVersion) {
this.widgetVersion = widgetVersion;
return this;
diff --git a/java/src/com/android/textclassifier/common/logging/TextClassifierEvent.java b/java/src/com/android/textclassifier/common/logging/TextClassifierEvent.java
index f34fb3d..ef82afc 100644
--- a/java/src/com/android/textclassifier/common/logging/TextClassifierEvent.java
+++ b/java/src/com/android/textclassifier/common/logging/TextClassifierEvent.java
@@ -19,6 +19,7 @@ package com.android.textclassifier.common.logging;
import android.os.Bundle;
import androidx.annotation.IntDef;
import com.google.common.base.Preconditions;
+import com.google.errorprone.annotations.CanIgnoreReturnValue;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.Arrays;
@@ -327,6 +328,7 @@ public abstract class TextClassifierEvent {
*
* <p>See {@link Locale#toLanguageTag()}
*/
+ @CanIgnoreReturnValue
public T setEntityTypes(String... entityTypes) {
Preconditions.checkNotNull(entityTypes);
this.entityTypes = new String[entityTypes.length];
@@ -335,12 +337,14 @@ public abstract class TextClassifierEvent {
}
/** Sets the event context. */
+ @CanIgnoreReturnValue
public T setEventContext(@Nullable TextClassificationContext eventContext) {
this.eventContext = eventContext;
return self();
}
/** Sets the id of the text classifier result related to this event. */
+ @CanIgnoreReturnValue
@Nonnull
public T setResultId(@Nullable String resultId) {
this.resultId = resultId;
@@ -348,6 +352,7 @@ public abstract class TextClassifierEvent {
}
/** Sets the index of this event in the series of events it belongs to. */
+ @CanIgnoreReturnValue
@Nonnull
public T setEventIndex(int eventIndex) {
this.eventIndex = eventIndex;
@@ -355,6 +360,7 @@ public abstract class TextClassifierEvent {
}
/** Sets the scores of the suggestions. */
+ @CanIgnoreReturnValue
@Nonnull
public T setScores(@Nonnull float... scores) {
Preconditions.checkNotNull(scores);
@@ -364,6 +370,7 @@ public abstract class TextClassifierEvent {
}
/** Sets the model name string. */
+ @CanIgnoreReturnValue
@Nonnull
public T setModelName(@Nullable String modelVersion) {
modelName = modelVersion;
@@ -395,6 +402,7 @@ public abstract class TextClassifierEvent {
*
* @see android.view.textclassifier.TextClassification#getActions()
*/
+ @CanIgnoreReturnValue
@Nonnull
public T setActionIndices(@Nonnull int... actionIndices) {
this.actionIndices = new int[actionIndices.length];
@@ -403,6 +411,7 @@ public abstract class TextClassifierEvent {
}
/** Sets the detected locale. */
+ @CanIgnoreReturnValue
@Nonnull
public T setLocale(@Nullable Locale locale) {
this.locale = locale;
@@ -416,6 +425,7 @@ public abstract class TextClassifierEvent {
* the internals of this bundle as it may have unexpected consequences on the clients of the
* built event object. For similar reasons, avoid depending on mutable objects in this bundle.
*/
+ @CanIgnoreReturnValue
@Nonnull
public T setExtras(@Nonnull Bundle extras) {
this.extras = Preconditions.checkNotNull(extras);
@@ -578,6 +588,7 @@ public abstract class TextClassifierEvent {
}
/** Sets the relative word index of the start of the selection. */
+ @CanIgnoreReturnValue
@Nonnull
public Builder setRelativeWordStartIndex(int relativeWordStartIndex) {
this.relativeWordStartIndex = relativeWordStartIndex;
@@ -585,6 +596,7 @@ public abstract class TextClassifierEvent {
}
/** Sets the relative word (exclusive) index of the end of the selection. */
+ @CanIgnoreReturnValue
@Nonnull
public Builder setRelativeWordEndIndex(int relativeWordEndIndex) {
this.relativeWordEndIndex = relativeWordEndIndex;
@@ -592,6 +604,7 @@ public abstract class TextClassifierEvent {
}
/** Sets the relative word index of the start of the smart selection. */
+ @CanIgnoreReturnValue
@Nonnull
public Builder setRelativeSuggestedWordStartIndex(int relativeSuggestedWordStartIndex) {
this.relativeSuggestedWordStartIndex = relativeSuggestedWordStartIndex;
@@ -599,6 +612,7 @@ public abstract class TextClassifierEvent {
}
/** Sets the relative word (exclusive) index of the end of the smart selection. */
+ @CanIgnoreReturnValue
@Nonnull
public Builder setRelativeSuggestedWordEndIndex(int relativeSuggestedWordEndIndex) {
this.relativeSuggestedWordEndIndex = relativeSuggestedWordEndIndex;
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java b/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java
index 3db0815..ff04aea 100644
--- a/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java
@@ -232,9 +232,10 @@ public final class ModelDownloadWorker extends ListenableWorker {
downloadManifestAndRegister(modelType, bestLocaleTag, manifestUrl));
}
manifestsToDownloadBuilder.put(
- modelType, ManifestsToDownloadByType.create(localeTagToManifestUrlBuilder.build()));
+ modelType,
+ ManifestsToDownloadByType.create(localeTagToManifestUrlBuilder.buildOrThrow()));
}
- manifestsToDownload = manifestsToDownloadBuilder.build();
+ manifestsToDownload = manifestsToDownloadBuilder.buildOrThrow();
return Futures.whenAllComplete(downloadResultFutures)
.call(
diff --git a/java/src/com/android/textclassifier/downloader/TextClassifierDownloadLogger.java b/java/src/com/android/textclassifier/downloader/TextClassifierDownloadLogger.java
index 7416b00..0bc37f0 100644
--- a/java/src/com/android/textclassifier/downloader/TextClassifierDownloadLogger.java
+++ b/java/src/com/android/textclassifier/downloader/TextClassifierDownloadLogger.java
@@ -86,7 +86,7 @@ final class TextClassifierDownloadLogger {
ModelDownloadException.FAILED_TO_VALIDATE_MODEL,
TextClassifierStatsLog
.TEXT_CLASSIFIER_DOWNLOAD_REPORTED__FAILURE_REASON__FAILED_TO_VALIDATE_MODEL)
- .build();
+ .buildOrThrow();
// Reasons to schedule
public static final int REASON_TO_SCHEDULE_TCS_STARTED =
diff --git a/java/src/com/android/textclassifier/utils/IndentingPrintWriter.java b/java/src/com/android/textclassifier/utils/IndentingPrintWriter.java
index bd48c22..41e6fcd 100644
--- a/java/src/com/android/textclassifier/utils/IndentingPrintWriter.java
+++ b/java/src/com/android/textclassifier/utils/IndentingPrintWriter.java
@@ -17,6 +17,7 @@
package com.android.textclassifier.utils;
import com.google.common.base.Preconditions;
+import com.google.errorprone.annotations.CanIgnoreReturnValue;
import java.io.PrintWriter;
/**
@@ -36,6 +37,7 @@ public final class IndentingPrintWriter {
}
/** Prints a string. */
+ @CanIgnoreReturnValue
public IndentingPrintWriter println(String string) {
writer.print(currentIndent);
writer.print(string);
@@ -44,12 +46,14 @@ public final class IndentingPrintWriter {
}
/** Prints a empty line */
+ @CanIgnoreReturnValue
public IndentingPrintWriter println() {
writer.println();
return this;
}
/** Increases indents for subsequent texts. */
+ @CanIgnoreReturnValue
public IndentingPrintWriter increaseIndent() {
indentBuilder.append(SINGLE_INDENT);
currentIndent = indentBuilder.toString();
@@ -57,6 +61,7 @@ public final class IndentingPrintWriter {
}
/** Decreases indents for subsequent texts. */
+ @CanIgnoreReturnValue
public IndentingPrintWriter decreaseIndent() {
indentBuilder.delete(0, SINGLE_INDENT.length());
currentIndent = indentBuilder.toString();
@@ -64,6 +69,7 @@ public final class IndentingPrintWriter {
}
/** Prints a key-valued pair. */
+ @CanIgnoreReturnValue
public IndentingPrintWriter printPair(String key, Object value) {
println(String.format("%s=%s", key, String.valueOf(value)));
return this;
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
index c20ec8a..89d405a 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
@@ -44,7 +44,6 @@ import android.view.textclassifier.TextLanguage;
import android.view.textclassifier.TextLinks;
import android.view.textclassifier.TextSelection;
import androidx.collection.LruCache;
-import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SdkSuppress;
import androidx.test.filters.SmallTest;
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java
index e261158..e74c7db 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java
@@ -18,7 +18,6 @@ package com.android.textclassifier.downloader;
import static com.google.common.truth.Truth.assertThat;
-import android.util.Log;
import android.view.textclassifier.TextClassification;
import android.view.textclassifier.TextClassification.Request;
import com.android.textclassifier.testing.ExtServicesTextClassifierRule;
@@ -168,7 +167,6 @@ public class ModelDownloaderIntegrationTest {
.getTextClassifier()
.classifyText(new Request.Builder(text, 0, text.length()).build());
// The result id contains the name of the just used model.
- Log.d(TAG, "verifyActiveModel. TextClassification ID: " + textClassification.getId());
assertThat(textClassification.getId()).contains(expectedVersion);
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java
index 76d04e0..8c99baf 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java
@@ -23,6 +23,7 @@ import static org.testng.Assert.expectThrows;
import androidx.test.core.app.ApplicationProvider;
import com.google.android.downloader.DownloadConstraints;
+import com.google.android.downloader.DownloadDestination;
import com.google.android.downloader.DownloadRequest;
import com.google.android.downloader.DownloadResult;
import com.google.android.downloader.Downloader;
@@ -82,7 +83,7 @@ public final class ModelDownloaderServiceImplTest {
targetModelFile.deleteOnExit();
targetMetadataFile.deleteOnExit();
- when(downloader.newRequestBuilder(any(), any()))
+ when(downloader.newRequestBuilder(any(), any(DownloadDestination.class)))
.thenReturn(
DownloadRequest.newBuilder()
.setUri(URI.create(DOWNLOAD_URI))
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/testing/FakeContextBuilder.java b/java/tests/instrumentation/src/com/android/textclassifier/testing/FakeContextBuilder.java
index f3ad833..8ce4643 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/testing/FakeContextBuilder.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/testing/FakeContextBuilder.java
@@ -31,6 +31,7 @@ import android.content.pm.PackageManager;
import android.content.pm.ResolveInfo;
import androidx.test.core.app.ApplicationProvider;
import com.google.common.base.Preconditions;
+import com.google.errorprone.annotations.CanIgnoreReturnValue;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
@@ -66,6 +67,7 @@ public final class FakeContextBuilder {
*
* <p><strong>NOTE: </strong>By default, no component is set to handle any intent.
*/
+ @CanIgnoreReturnValue
public FakeContextBuilder setIntentComponent(
String intentAction, @Nullable ComponentName component) {
Preconditions.checkNotNull(intentAction);
@@ -74,6 +76,7 @@ public final class FakeContextBuilder {
}
/** Sets the app label res for a specified package. */
+ @CanIgnoreReturnValue
public FakeContextBuilder setAppLabel(String packageName, @Nullable CharSequence appLabel) {
Preconditions.checkNotNull(packageName);
appLabels.put(packageName, appLabel);
@@ -85,6 +88,7 @@ public final class FakeContextBuilder {
*
* <p><strong>NOTE: </strong>By default, no component is set to handle any intent.
*/
+ @CanIgnoreReturnValue
public FakeContextBuilder setAllIntentComponent(@Nullable ComponentName component) {
allIntentComponent = component;
return this;
diff --git a/jni/Android.bp b/jni/Android.bp
index 4300d8e..e9ec887 100644
--- a/jni/Android.bp
+++ b/jni/Android.bp
@@ -24,7 +24,10 @@ package {
java_library_static {
name: "libtextclassifier-java",
srcs: ["**/*.java"],
- static_libs: ["guava"],
+ static_libs: [
+ "guava",
+ "error_prone_annotations",
+ ],
sdk_version: "system_current",
min_sdk_version: "28",
}
diff --git a/jni/com/google/android/textclassifier/AnnotatorModel.java b/jni/com/google/android/textclassifier/AnnotatorModel.java
index 47a369e..a82c96d 100644
--- a/jni/com/google/android/textclassifier/AnnotatorModel.java
+++ b/jni/com/google/android/textclassifier/AnnotatorModel.java
@@ -17,6 +17,7 @@
package com.google.android.textclassifier;
import android.content.res.AssetFileDescriptor;
+import com.google.errorprone.annotations.CanIgnoreReturnValue;
import java.util.Collection;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.Nullable;
@@ -654,41 +655,49 @@ public final class AnnotatorModel implements AutoCloseable {
private boolean usePodNer = true;
private boolean useVocabAnnotator = true;
+ @CanIgnoreReturnValue
public Builder setLocales(@Nullable String locales) {
this.locales = locales;
return this;
}
+ @CanIgnoreReturnValue
public Builder setDetectedTextLanguageTags(@Nullable String detectedTextLanguageTags) {
this.detectedTextLanguageTags = detectedTextLanguageTags;
return this;
}
+ @CanIgnoreReturnValue
public Builder setAnnotationUsecase(int annotationUsecase) {
this.annotationUsecase = annotationUsecase;
return this;
}
+ @CanIgnoreReturnValue
public Builder setUserLocationLat(double userLocationLat) {
this.userLocationLat = userLocationLat;
return this;
}
+ @CanIgnoreReturnValue
public Builder setUserLocationLng(double userLocationLng) {
this.userLocationLng = userLocationLng;
return this;
}
+ @CanIgnoreReturnValue
public Builder setUserLocationAccuracyMeters(float userLocationAccuracyMeters) {
this.userLocationAccuracyMeters = userLocationAccuracyMeters;
return this;
}
+ @CanIgnoreReturnValue
public Builder setUsePodNer(boolean usePodNer) {
this.usePodNer = usePodNer;
return this;
}
+ @CanIgnoreReturnValue
public Builder setUseVocabAnnotator(boolean useVocabAnnotator) {
this.useVocabAnnotator = useVocabAnnotator;
return this;
@@ -761,6 +770,8 @@ public final class AnnotatorModel implements AutoCloseable {
private final boolean usePodNer;
private final boolean triggerDictionaryOnBeginnerWords;
private final boolean useVocabAnnotator;
+ private final boolean enableAddContactIntent;
+ private final boolean enableSearchIntent;
private ClassificationOptions(
long referenceTimeMsUtc,
@@ -774,7 +785,9 @@ public final class AnnotatorModel implements AutoCloseable {
String userFamiliarLanguageTags,
boolean usePodNer,
boolean triggerDictionaryOnBeginnerWords,
- boolean useVocabAnnotator) {
+ boolean useVocabAnnotator,
+ boolean enableAddContactIntent,
+ boolean enableSearchIntent) {
this.referenceTimeMsUtc = referenceTimeMsUtc;
this.referenceTimezone = referenceTimezone;
this.locales = locales;
@@ -787,6 +800,8 @@ public final class AnnotatorModel implements AutoCloseable {
this.usePodNer = usePodNer;
this.triggerDictionaryOnBeginnerWords = triggerDictionaryOnBeginnerWords;
this.useVocabAnnotator = useVocabAnnotator;
+ this.enableAddContactIntent = enableAddContactIntent;
+ this.enableSearchIntent = enableSearchIntent;
}
/** Can be used to build a ClassificationOptions instance. */
@@ -803,68 +818,94 @@ public final class AnnotatorModel implements AutoCloseable {
private boolean usePodNer = true;
private boolean triggerDictionaryOnBeginnerWords = false;
private boolean useVocabAnnotator = true;
+ private boolean enableAddContactIntent = false;
+ private boolean enableSearchIntent = false;
+ @CanIgnoreReturnValue
public Builder setReferenceTimeMsUtc(long referenceTimeMsUtc) {
this.referenceTimeMsUtc = referenceTimeMsUtc;
return this;
}
+ @CanIgnoreReturnValue
public Builder setReferenceTimezone(String referenceTimezone) {
this.referenceTimezone = referenceTimezone;
return this;
}
+ @CanIgnoreReturnValue
public Builder setLocales(@Nullable String locales) {
this.locales = locales;
return this;
}
+ @CanIgnoreReturnValue
public Builder setDetectedTextLanguageTags(@Nullable String detectedTextLanguageTags) {
this.detectedTextLanguageTags = detectedTextLanguageTags;
return this;
}
+ @CanIgnoreReturnValue
public Builder setAnnotationUsecase(int annotationUsecase) {
this.annotationUsecase = annotationUsecase;
return this;
}
+ @CanIgnoreReturnValue
public Builder setUserLocationLat(double userLocationLat) {
this.userLocationLat = userLocationLat;
return this;
}
+ @CanIgnoreReturnValue
public Builder setUserLocationLng(double userLocationLng) {
this.userLocationLng = userLocationLng;
return this;
}
+ @CanIgnoreReturnValue
public Builder setUserLocationAccuracyMeters(float userLocationAccuracyMeters) {
this.userLocationAccuracyMeters = userLocationAccuracyMeters;
return this;
}
+ @CanIgnoreReturnValue
public Builder setUserFamiliarLanguageTags(String userFamiliarLanguageTags) {
this.userFamiliarLanguageTags = userFamiliarLanguageTags;
return this;
}
+ @CanIgnoreReturnValue
public Builder setUsePodNer(boolean usePodNer) {
this.usePodNer = usePodNer;
return this;
}
+ @CanIgnoreReturnValue
public Builder setTrigerringDictionaryOnBeginnerWords(
boolean triggerDictionaryOnBeginnerWords) {
this.triggerDictionaryOnBeginnerWords = triggerDictionaryOnBeginnerWords;
return this;
}
+ @CanIgnoreReturnValue
public Builder setUseVocabAnnotator(boolean useVocabAnnotator) {
this.useVocabAnnotator = useVocabAnnotator;
return this;
}
+ @CanIgnoreReturnValue
+ public Builder setEnableAddContactIntent(boolean enableAddContactIntent) {
+ this.enableAddContactIntent = enableAddContactIntent;
+ return this;
+ }
+
+ @CanIgnoreReturnValue
+ public Builder setEnableSearchIntent(boolean enableSearchIntent) {
+ this.enableSearchIntent = enableSearchIntent;
+ return this;
+ }
+
public ClassificationOptions build() {
return new ClassificationOptions(
referenceTimeMsUtc,
@@ -878,7 +919,9 @@ public final class AnnotatorModel implements AutoCloseable {
userFamiliarLanguageTags,
usePodNer,
triggerDictionaryOnBeginnerWords,
- useVocabAnnotator);
+ useVocabAnnotator,
+ enableAddContactIntent,
+ enableSearchIntent);
}
}
@@ -936,6 +979,14 @@ public final class AnnotatorModel implements AutoCloseable {
public boolean getUseVocabAnnotator() {
return useVocabAnnotator;
}
+
+ public boolean getEnableAddContactIntent() {
+ return enableAddContactIntent;
+ }
+
+ public boolean getEnableSearchIntent() {
+ return enableSearchIntent;
+ }
}
/** Represents options for the annotate call. */
@@ -1011,81 +1062,97 @@ public final class AnnotatorModel implements AutoCloseable {
private boolean triggerDictionaryOnBeginnerWords = false;
private boolean useVocabAnnotator = true;
+ @CanIgnoreReturnValue
public Builder setReferenceTimeMsUtc(long referenceTimeMsUtc) {
this.referenceTimeMsUtc = referenceTimeMsUtc;
return this;
}
+ @CanIgnoreReturnValue
public Builder setReferenceTimezone(String referenceTimezone) {
this.referenceTimezone = referenceTimezone;
return this;
}
+ @CanIgnoreReturnValue
public Builder setLocales(@Nullable String locales) {
this.locales = locales;
return this;
}
+ @CanIgnoreReturnValue
public Builder setDetectedTextLanguageTags(@Nullable String detectedTextLanguageTags) {
this.detectedTextLanguageTags = detectedTextLanguageTags;
return this;
}
+ @CanIgnoreReturnValue
public Builder setEntityTypes(Collection<String> entityTypes) {
this.entityTypes = entityTypes;
return this;
}
+ @CanIgnoreReturnValue
public Builder setAnnotateMode(int annotateMode) {
this.annotateMode = annotateMode;
return this;
}
+ @CanIgnoreReturnValue
public Builder setAnnotationUsecase(int annotationUsecase) {
this.annotationUsecase = annotationUsecase;
return this;
}
+ @CanIgnoreReturnValue
public Builder setHasLocationPermission(boolean hasLocationPermission) {
this.hasLocationPermission = hasLocationPermission;
return this;
}
+ @CanIgnoreReturnValue
public Builder setHasPersonalizationPermission(boolean hasPersonalizationPermission) {
this.hasPersonalizationPermission = hasPersonalizationPermission;
return this;
}
+ @CanIgnoreReturnValue
public Builder setIsSerializedEntityDataEnabled(boolean isSerializedEntityDataEnabled) {
this.isSerializedEntityDataEnabled = isSerializedEntityDataEnabled;
return this;
}
+ @CanIgnoreReturnValue
public Builder setUserLocationLat(double userLocationLat) {
this.userLocationLat = userLocationLat;
return this;
}
+ @CanIgnoreReturnValue
public Builder setUserLocationLng(double userLocationLng) {
this.userLocationLng = userLocationLng;
return this;
}
+ @CanIgnoreReturnValue
public Builder setUserLocationAccuracyMeters(float userLocationAccuracyMeters) {
this.userLocationAccuracyMeters = userLocationAccuracyMeters;
return this;
}
+ @CanIgnoreReturnValue
public Builder setUsePodNer(boolean usePodNer) {
this.usePodNer = usePodNer;
return this;
}
+ @CanIgnoreReturnValue
public Builder setTriggerDictionaryOnBeginnerWords(boolean triggerDictionaryOnBeginnerWords) {
this.triggerDictionaryOnBeginnerWords = triggerDictionaryOnBeginnerWords;
return this;
}
+ @CanIgnoreReturnValue
public Builder setUseVocabAnnotator(boolean useVocabAnnotator) {
this.useVocabAnnotator = useVocabAnnotator;
return this;
diff --git a/native/FlatBufferHeaders.bp b/native/FlatBufferHeaders.bp
index 235bb4a..5ed09cc 100644
--- a/native/FlatBufferHeaders.bp
+++ b/native/FlatBufferHeaders.bp
@@ -15,65 +15,58 @@
//
genrule {
- name: "libtextclassifier_fbgen_lang_id_common_flatbuffers_model",
- srcs: ["lang_id/common/flatbuffers/model.fbs"],
- out: ["lang_id/common/flatbuffers/model_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network",
- srcs: ["lang_id/common/flatbuffers/embedding-network.fbs"],
- out: ["lang_id/common/flatbuffers/embedding-network_generated.h"],
+ name: "libtextclassifier_fbgen_utils_i18n_language-tag",
+ srcs: ["utils/i18n/language-tag.fbs"],
+ out: ["utils/i18n/language-tag_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_annotator_datetime_datetime",
- srcs: ["annotator/datetime/datetime.fbs"],
- out: ["annotator/datetime/datetime_generated.h"],
+ name: "libtextclassifier_fbgen_utils_tflite_text_encoder_config",
+ srcs: ["utils/tflite/text_encoder_config.fbs"],
+ out: ["utils/tflite/text_encoder_config_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_annotator_model",
- srcs: ["annotator/model.fbs"],
- out: ["annotator/model_generated.h"],
+ name: "libtextclassifier_fbgen_utils_resources",
+ srcs: ["utils/resources.fbs"],
+ out: ["utils/resources_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_annotator_experimental_experimental",
- srcs: ["annotator/experimental/experimental.fbs"],
- out: ["annotator/experimental/experimental_generated.h"],
+ name: "libtextclassifier_fbgen_utils_grammar_rules",
+ srcs: ["utils/grammar/rules.fbs"],
+ out: ["utils/grammar/rules_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_annotator_entity-data",
- srcs: ["annotator/entity-data.fbs"],
- out: ["annotator/entity-data_generated.h"],
+ name: "libtextclassifier_fbgen_utils_grammar_testing_value",
+ srcs: ["utils/grammar/testing/value.fbs"],
+ out: ["utils/grammar/testing/value_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_annotator_person_name_person_name_model",
- srcs: ["annotator/person_name/person_name_model.fbs"],
- out: ["annotator/person_name/person_name_model_generated.h"],
+ name: "libtextclassifier_fbgen_utils_grammar_semantics_expression",
+ srcs: ["utils/grammar/semantics/expression.fbs"],
+ out: ["utils/grammar/semantics/expression_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_utils_tflite_text_encoder_config",
- srcs: ["utils/tflite/text_encoder_config.fbs"],
- out: ["utils/tflite/text_encoder_config_generated.h"],
+ name: "libtextclassifier_fbgen_utils_zlib_buffer",
+ srcs: ["utils/zlib/buffer.fbs"],
+ out: ["utils/zlib/buffer_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_utils_codepoint-range",
- srcs: ["utils/codepoint-range.fbs"],
- out: ["utils/codepoint-range_generated.h"],
+ name: "libtextclassifier_fbgen_utils_normalization",
+ srcs: ["utils/normalization.fbs"],
+ out: ["utils/normalization_generated.h"],
defaults: ["fbgen"],
}
@@ -85,16 +78,16 @@ genrule {
}
genrule {
- name: "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers",
- srcs: ["utils/flatbuffers/flatbuffers.fbs"],
- out: ["utils/flatbuffers/flatbuffers_generated.h"],
+ name: "libtextclassifier_fbgen_utils_container_bit-vector",
+ srcs: ["utils/container/bit-vector.fbs"],
+ out: ["utils/container/bit-vector_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_utils_zlib_buffer",
- srcs: ["utils/zlib/buffer.fbs"],
- out: ["utils/zlib/buffer_generated.h"],
+ name: "libtextclassifier_fbgen_utils_codepoint-range",
+ srcs: ["utils/codepoint-range.fbs"],
+ out: ["utils/codepoint-range_generated.h"],
defaults: ["fbgen"],
}
@@ -106,65 +99,72 @@ genrule {
}
genrule {
- name: "libtextclassifier_fbgen_utils_grammar_testing_value",
- srcs: ["utils/grammar/testing/value.fbs"],
- out: ["utils/grammar/testing/value_generated.h"],
+ name: "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers",
+ srcs: ["utils/flatbuffers/flatbuffers.fbs"],
+ out: ["utils/flatbuffers/flatbuffers_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_utils_grammar_rules",
- srcs: ["utils/grammar/rules.fbs"],
- out: ["utils/grammar/rules_generated.h"],
+ name: "libtextclassifier_fbgen_actions_actions_model",
+ srcs: ["actions/actions_model.fbs"],
+ out: ["actions/actions_model_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_utils_grammar_semantics_expression",
- srcs: ["utils/grammar/semantics/expression.fbs"],
- out: ["utils/grammar/semantics/expression_generated.h"],
+ name: "libtextclassifier_fbgen_actions_actions-entity-data",
+ srcs: ["actions/actions-entity-data.fbs"],
+ out: ["actions/actions-entity-data_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_utils_resources",
- srcs: ["utils/resources.fbs"],
- out: ["utils/resources_generated.h"],
+ name: "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network",
+ srcs: ["lang_id/common/flatbuffers/embedding-network.fbs"],
+ out: ["lang_id/common/flatbuffers/embedding-network_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_utils_i18n_language-tag",
- srcs: ["utils/i18n/language-tag.fbs"],
- out: ["utils/i18n/language-tag_generated.h"],
+ name: "libtextclassifier_fbgen_lang_id_common_flatbuffers_model",
+ srcs: ["lang_id/common/flatbuffers/model.fbs"],
+ out: ["lang_id/common/flatbuffers/model_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_utils_normalization",
- srcs: ["utils/normalization.fbs"],
- out: ["utils/normalization_generated.h"],
+ name: "libtextclassifier_fbgen_annotator_entity-data",
+ srcs: ["annotator/entity-data.fbs"],
+ out: ["annotator/entity-data_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_utils_container_bit-vector",
- srcs: ["utils/container/bit-vector.fbs"],
- out: ["utils/container/bit-vector_generated.h"],
+ name: "libtextclassifier_fbgen_annotator_person_name_person_name_model",
+ srcs: ["annotator/person_name/person_name_model.fbs"],
+ out: ["annotator/person_name/person_name_model_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_actions_actions-entity-data",
- srcs: ["actions/actions-entity-data.fbs"],
- out: ["actions/actions-entity-data_generated.h"],
+ name: "libtextclassifier_fbgen_annotator_experimental_experimental",
+ srcs: ["annotator/experimental/experimental.fbs"],
+ out: ["annotator/experimental/experimental_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_actions_actions_model",
- srcs: ["actions/actions_model.fbs"],
- out: ["actions/actions_model_generated.h"],
+ name: "libtextclassifier_fbgen_annotator_model",
+ srcs: ["annotator/model.fbs"],
+ out: ["annotator/model_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_annotator_datetime_datetime",
+ srcs: ["annotator/datetime/datetime.fbs"],
+ out: ["annotator/datetime/datetime_generated.h"],
defaults: ["fbgen"],
}
@@ -179,50 +179,50 @@ cc_library_headers {
"com.android.adservices",
],
generated_headers: [
- "libtextclassifier_fbgen_lang_id_common_flatbuffers_model",
- "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network",
- "libtextclassifier_fbgen_annotator_datetime_datetime",
- "libtextclassifier_fbgen_annotator_model",
- "libtextclassifier_fbgen_annotator_experimental_experimental",
- "libtextclassifier_fbgen_annotator_entity-data",
- "libtextclassifier_fbgen_annotator_person_name_person_name_model",
+ "libtextclassifier_fbgen_utils_i18n_language-tag",
"libtextclassifier_fbgen_utils_tflite_text_encoder_config",
- "libtextclassifier_fbgen_utils_codepoint-range",
- "libtextclassifier_fbgen_utils_intents_intent-config",
- "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers",
- "libtextclassifier_fbgen_utils_zlib_buffer",
- "libtextclassifier_fbgen_utils_tokenizer",
+ "libtextclassifier_fbgen_utils_resources",
"libtextclassifier_fbgen_utils_grammar_rules",
"libtextclassifier_fbgen_utils_grammar_semantics_expression",
- "libtextclassifier_fbgen_utils_resources",
- "libtextclassifier_fbgen_utils_i18n_language-tag",
+ "libtextclassifier_fbgen_utils_zlib_buffer",
"libtextclassifier_fbgen_utils_normalization",
+ "libtextclassifier_fbgen_utils_intents_intent-config",
"libtextclassifier_fbgen_utils_container_bit-vector",
- "libtextclassifier_fbgen_actions_actions-entity-data",
+ "libtextclassifier_fbgen_utils_codepoint-range",
+ "libtextclassifier_fbgen_utils_tokenizer",
+ "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers",
"libtextclassifier_fbgen_actions_actions_model",
- ],
- export_generated_headers: [
- "libtextclassifier_fbgen_lang_id_common_flatbuffers_model",
+ "libtextclassifier_fbgen_actions_actions-entity-data",
"libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network",
- "libtextclassifier_fbgen_annotator_datetime_datetime",
- "libtextclassifier_fbgen_annotator_model",
- "libtextclassifier_fbgen_annotator_experimental_experimental",
+ "libtextclassifier_fbgen_lang_id_common_flatbuffers_model",
"libtextclassifier_fbgen_annotator_entity-data",
"libtextclassifier_fbgen_annotator_person_name_person_name_model",
+ "libtextclassifier_fbgen_annotator_experimental_experimental",
+ "libtextclassifier_fbgen_annotator_model",
+ "libtextclassifier_fbgen_annotator_datetime_datetime",
+ ],
+ export_generated_headers: [
+ "libtextclassifier_fbgen_utils_i18n_language-tag",
"libtextclassifier_fbgen_utils_tflite_text_encoder_config",
- "libtextclassifier_fbgen_utils_codepoint-range",
- "libtextclassifier_fbgen_utils_intents_intent-config",
- "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers",
- "libtextclassifier_fbgen_utils_zlib_buffer",
- "libtextclassifier_fbgen_utils_tokenizer",
+ "libtextclassifier_fbgen_utils_resources",
"libtextclassifier_fbgen_utils_grammar_rules",
"libtextclassifier_fbgen_utils_grammar_semantics_expression",
- "libtextclassifier_fbgen_utils_resources",
- "libtextclassifier_fbgen_utils_i18n_language-tag",
+ "libtextclassifier_fbgen_utils_zlib_buffer",
"libtextclassifier_fbgen_utils_normalization",
+ "libtextclassifier_fbgen_utils_intents_intent-config",
"libtextclassifier_fbgen_utils_container_bit-vector",
- "libtextclassifier_fbgen_actions_actions-entity-data",
+ "libtextclassifier_fbgen_utils_codepoint-range",
+ "libtextclassifier_fbgen_utils_tokenizer",
+ "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers",
"libtextclassifier_fbgen_actions_actions_model",
+ "libtextclassifier_fbgen_actions_actions-entity-data",
+ "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network",
+ "libtextclassifier_fbgen_lang_id_common_flatbuffers_model",
+ "libtextclassifier_fbgen_annotator_entity-data",
+ "libtextclassifier_fbgen_annotator_person_name_person_name_model",
+ "libtextclassifier_fbgen_annotator_experimental_experimental",
+ "libtextclassifier_fbgen_annotator_model",
+ "libtextclassifier_fbgen_annotator_datetime_datetime",
],
}
diff --git a/native/JavaTests.bp b/native/JavaTests.bp
index 9837173..89ef085 100644
--- a/native/JavaTests.bp
+++ b/native/JavaTests.bp
@@ -17,30 +17,30 @@
filegroup {
name: "libtextclassifier_java_test_sources",
srcs: [
- "annotator/datetime/datetime-grounder_test.cc",
- "annotator/datetime/regex-parser_test.cc",
- "annotator/datetime/grammar-parser_test.cc",
- "annotator/pod_ner/pod-ner-impl_test.cc",
+ "utils/grammar/parsing/lexer_test.cc",
"utils/intents/intent-generator-test-lib.cc",
- "utils/calendar/calendar_test.cc",
"utils/regex-match_test.cc",
- "utils/grammar/parsing/lexer_test.cc",
+ "utils/calendar/calendar_test.cc",
"actions/actions-suggestions_test.cc",
"actions/grammar-actions_test.cc",
- "annotator/number/number_test-include.cc",
- "annotator/annotator_test-include.cc",
- "annotator/grammar/grammar-annotator_test.cc",
- "annotator/grammar/test-utils.cc",
- "utils/utf8/unilib_test-include.cc",
+ "annotator/pod_ner/pod-ner-impl_test.cc",
+ "annotator/datetime/regex-parser_test.cc",
+ "annotator/datetime/grammar-parser_test.cc",
+ "annotator/datetime/datetime-grounder_test.cc",
"utils/grammar/parsing/parser_test.cc",
- "utils/grammar/analyzer_test.cc",
"utils/grammar/semantics/composer_test.cc",
- "utils/grammar/semantics/evaluators/constituent-eval_test.cc",
- "utils/grammar/semantics/evaluators/merge-values-eval_test.cc",
- "utils/grammar/semantics/evaluators/parse-number-eval_test.cc",
"utils/grammar/semantics/evaluators/arithmetic-eval_test.cc",
- "utils/grammar/semantics/evaluators/span-eval_test.cc",
"utils/grammar/semantics/evaluators/const-eval_test.cc",
+ "utils/grammar/semantics/evaluators/merge-values-eval_test.cc",
+ "utils/grammar/semantics/evaluators/span-eval_test.cc",
+ "utils/grammar/semantics/evaluators/constituent-eval_test.cc",
+ "utils/grammar/semantics/evaluators/parse-number-eval_test.cc",
"utils/grammar/semantics/evaluators/compose-eval_test.cc",
+ "utils/grammar/analyzer_test.cc",
+ "utils/utf8/unilib_test-include.cc",
+ "annotator/grammar/grammar-annotator_test.cc",
+ "annotator/grammar/test-utils.cc",
+ "annotator/annotator_test-include.cc",
+ "annotator/number/number_test-include.cc",
],
}
diff --git a/native/actions/actions-entity-data.bfbs b/native/actions/actions-entity-data.bfbs
index 6ebf1cf..e5ebfec 100644
--- a/native/actions/actions-entity-data.bfbs
+++ b/native/actions/actions-entity-data.bfbs
Binary files differ
diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc
index 9f9a8d4..eeeb508 100644
--- a/native/actions/actions-suggestions.cc
+++ b/native/actions/actions-suggestions.cc
@@ -21,6 +21,8 @@
#include <vector>
#include "utils/base/statusor.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/random/random.h"
#if !defined(TC3_DISABLE_LUA)
#include "actions/lua-actions.h"
@@ -42,6 +44,7 @@
#include "utils/strings/utf8.h"
#include "utils/utf8/unicodetext.h"
#include "absl/container/flat_hash_set.h"
+#include "absl/random/distributions.h"
#include "tensorflow/lite/string_util.h"
namespace libtextclassifier3 {
@@ -813,6 +816,8 @@ void ActionsSuggestions::PopulateTextReplies(
const tflite::Interpreter* interpreter, int suggestion_index,
int score_index, const std::string& type, float priority_score,
const absl::flat_hash_set<std::string>& blocklist,
+ const absl::flat_hash_map<std::string, std::vector<std::string>>&
+ concept_mappings,
ActionsSuggestionsResponse* response) const {
const std::vector<tflite::StringRef> replies =
model_executor_->Output<tflite::StringRef>(suggestion_index, interpreter);
@@ -831,6 +836,12 @@ void ActionsSuggestions::PopulateTextReplies(
if (blocklist.contains(response_text)) {
continue;
}
+ if (concept_mappings.contains(response_text)) {
+ const int candidates_size = concept_mappings.at(response_text).size();
+ const int candidate_index = absl::Uniform<int>(
+ absl::IntervalOpenOpen, bit_gen_, 0, candidates_size);
+ response_text = concept_mappings.at(response_text)[candidate_index];
+ }
response->actions.push_back({response_text, type, score, priority_score});
}
@@ -918,11 +929,11 @@ bool ActionsSuggestions::ReadModelOutput(
if (!response->output_filtered_min_triggering_score &&
model_->tflite_model_spec()->output_replies() >= 0) {
absl::flat_hash_set<std::string> empty_blocklist;
- PopulateTextReplies(interpreter,
- model_->tflite_model_spec()->output_replies(),
- model_->tflite_model_spec()->output_replies_scores(),
- model_->smart_reply_action_type()->str(),
- /* priority_score */ 0.0, empty_blocklist, response);
+ PopulateTextReplies(
+ interpreter, model_->tflite_model_spec()->output_replies(),
+ model_->tflite_model_spec()->output_replies_scores(),
+ model_->smart_reply_action_type()->str(),
+ /* priority_score */ 0.0, empty_blocklist, {}, response);
}
// Read actions suggestions.
@@ -961,6 +972,8 @@ bool ActionsSuggestions::ReadModelOutput(
const int suggestions_scores_index =
metadata->output_suggestions_scores();
absl::flat_hash_set<std::string> response_text_blocklist;
+ absl::flat_hash_map<std::string, std::vector<std::string>>
+ concept_mappings;
switch (metadata->prediction_type()) {
case PredictionType_NEXT_MESSAGE_PREDICTION:
if (!task_spec || task_spec->type()->size() == 0) {
@@ -973,13 +986,22 @@ bool ActionsSuggestions::ReadModelOutput(
response_text_blocklist.insert(val->str());
}
}
+ if (task_spec->concept_mappings()) {
+ for (const auto& concept : *task_spec->concept_mappings()) {
+ std::vector<std::string> candidates;
+ for (const auto& candidate : *concept->candidates()) {
+ candidates.push_back(candidate->str());
+ }
+ concept_mappings[concept->concept_name()->str()] = candidates;
+ }
+ }
}
PopulateTextReplies(
interpreter, suggestions_index, suggestions_scores_index,
task_spec ? task_spec->type()->str()
: model_->smart_reply_action_type()->str(),
task_spec ? task_spec->priority_score() : 0.0,
- response_text_blocklist, response);
+ response_text_blocklist, concept_mappings, response);
break;
case PredictionType_INTENT_TRIGGERING:
PopulateIntentTriggering(interpreter, suggestions_index,
diff --git a/native/actions/actions-suggestions.h b/native/actions/actions-suggestions.h
index 87f55fb..c3d58e4 100644
--- a/native/actions/actions-suggestions.h
+++ b/native/actions/actions-suggestions.h
@@ -43,7 +43,9 @@
#include "utils/utf8/unilib.h"
#include "utils/variant.h"
#include "utils/zlib/zlib.h"
+#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
+#include "absl/random/random.h"
namespace libtextclassifier3 {
@@ -175,11 +177,13 @@ class ActionsSuggestions {
void FillSuggestionFromSpecWithEntityData(const ActionSuggestionSpec* spec,
ActionSuggestion* suggestion) const;
- void PopulateTextReplies(const tflite::Interpreter* interpreter,
- int suggestion_index, int score_index,
- const std::string& type, float priority_score,
- const absl::flat_hash_set<std::string>& blocklist,
- ActionsSuggestionsResponse* response) const;
+ void PopulateTextReplies(
+ const tflite::Interpreter* interpreter, int suggestion_index,
+ int score_index, const std::string& type, float priority_score,
+ const absl::flat_hash_set<std::string>& blocklist,
+ const absl::flat_hash_map<std::string, std::vector<std::string>>&
+ concept_mappings,
+ ActionsSuggestionsResponse* response) const;
void PopulateIntentTriggering(const tflite::Interpreter* interpreter,
int suggestion_index, int score_index,
@@ -273,6 +277,9 @@ class ActionsSuggestions {
// Conversation intent detection model for additional actions.
std::unique_ptr<const ConversationIntentDetection>
conversation_intent_detection_;
+
+ // Used for randomly selecting candidates.
+ mutable absl::BitGen bit_gen_;
};
// Interprets the buffer as a Model flatbuffer and returns it for reading.
diff --git a/native/actions/actions-suggestions_test.cc b/native/actions/actions-suggestions_test.cc
index b51ebc7..65f9796 100644
--- a/native/actions/actions-suggestions_test.cc
+++ b/native/actions/actions-suggestions_test.cc
@@ -61,6 +61,8 @@ constexpr char kMultiTaskSrP13nModelFileName[] =
"actions_suggestions_test.multi_task_sr_p13n.model";
constexpr char kMultiTaskSrEmojiModelFileName[] =
"actions_suggestions_test.multi_task_sr_emoji.model";
+constexpr char kMultiTaskSrEmojiConceptModelFileName[] =
+ "actions_suggestions_test.multi_task_sr_emoji_concept.model";
constexpr char kSensitiveTFliteModelFileName[] =
"actions_suggestions_test.sensitive_tflite.model";
constexpr char kLiveRelayTFLiteModelFileName[] =
@@ -1835,6 +1837,25 @@ TEST_F(ActionsSuggestionsTest, MultiTaskSrEmojiModelRemovesTextHeadEmoji) {
EXPECT_EQ(response.actions[2].type, "text_reply");
}
+TEST_F(ActionsSuggestionsTest, MultiTaskSrEmojiModelUsesConcepts) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kMultiTaskSrEmojiConceptModelFileName);
+
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "i am tired",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{},
+ /*locales=*/"en"}}});
+ std::vector<std::string> sigh_emojis = {"😔", "😞"};
+
+ EXPECT_TRUE(std::find(sigh_emojis.begin(), sigh_emojis.end(),
+ response.actions[0].response_text) !=
+ sigh_emojis.end());
+ EXPECT_EQ(response.actions[0].type, "emoji_reply");
+}
+
TEST_F(ActionsSuggestionsTest, LiveRelayModel) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadTestModel(kLiveRelayTFLiteModelFileName);
diff --git a/native/actions/actions_model.fbs b/native/actions/actions_model.fbs
index 0d8c7ad..70f9104 100644
--- a/native/actions/actions_model.fbs
+++ b/native/actions/actions_model.fbs
@@ -312,6 +312,15 @@ table TriggeringPreconditions {
min_reply_score_threshold:float = 0;
}
+// This proto handles model outputs that are concepts, such as emoji concept
+// suggestion models. Each concept maps to a list of candidates. One of
+// the candidates is chosen randomly as the final suggestion.
+namespace libtextclassifier3;
+table ActionConceptToSuggestion {
+ concept_name:string (shared);
+ candidates:[string];
+}
+
namespace libtextclassifier3;
table ActionSuggestionSpec {
// Type of the action suggestion.
@@ -331,6 +340,10 @@ table ActionSuggestionSpec {
entity_data:ActionsEntityData;
response_text_blocklist:[string];
+
+ // If provided, map the response as concept to one of the corresponding
+ // candidates.
+ concept_mappings:[ActionConceptToSuggestion];
}
// Options to specify triggering behaviour per action class.
diff --git a/native/actions/test_data/actions_suggestions_grammar_test.model b/native/actions/test_data/actions_suggestions_grammar_test.model
index 0fa7f7e..6d7bdb0 100644
--- a/native/actions/test_data/actions_suggestions_grammar_test.model
+++ b/native/actions/test_data/actions_suggestions_grammar_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.live_relay.model b/native/actions/test_data/actions_suggestions_test.live_relay.model
index 6ff4302..af5e10b 100644
--- a/native/actions/test_data/actions_suggestions_test.live_relay.model
+++ b/native/actions/test_data/actions_suggestions_test.live_relay.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.model b/native/actions/test_data/actions_suggestions_test.model
index 6107e98..88f62eb 100644
--- a/native/actions/test_data/actions_suggestions_test.model
+++ b/native/actions/test_data/actions_suggestions_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
index 436ed93..40a2409 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
index 935691d..effb2cb 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji_concept.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji_concept.model
new file mode 100644
index 0000000..18333d6
--- /dev/null
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji_concept.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
index 2c9f74b..e41ab39 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
index cdb7523..5314b43 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
index ac28fa2..a633742 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
index d864b79..6685d26 100644
--- a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
+++ b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
Binary files differ
diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc
index e0d4241..a8483f1 100644
--- a/native/annotator/annotator.cc
+++ b/native/annotator/annotator.cc
@@ -432,7 +432,8 @@ void Annotator::ValidateAndInitialize(const Model* model, const UniLib* unilib,
datetime_parser_ = std::make_unique<GrammarDatetimeParser>(
*analyzer_, *datetime_grounder_,
/*target_classification_score=*/1.0,
- /*priority_score=*/1.0);
+ /*priority_score=*/1.0,
+ model_->datetime_grammar_model()->enabled_modes());
}
} else if (model_->datetime_model()) {
datetime_parser_ = RegexDatetimeParser::Instance(
@@ -604,6 +605,8 @@ bool Annotator::InitializeKnowledgeEngine(
if (model_->triggering_options() != nullptr) {
knowledge_engine->SetPriorityScore(
model_->triggering_options()->knowledge_priority_score());
+ knowledge_engine->SetEnabledModes(
+ model_->triggering_options()->knowledge_enabled_modes());
}
knowledge_engine_ = std::move(knowledge_engine);
return true;
@@ -621,10 +624,21 @@ bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
return true;
}
+void Annotator::CleanUpContactEngine() {
+ if (contact_engine_ == nullptr) {
+ TC3_LOG(INFO)
+ << "Attempting to clean up contact engine that does not exist.";
+ return;
+ }
+ contact_engine_->CleanUp();
+}
+
bool Annotator::InitializeInstalledAppEngine(
const std::string& serialized_config) {
std::unique_ptr<InstalledAppEngine> installed_app_engine(
- new InstalledAppEngine(selection_feature_processor_.get(), unilib_));
+ new InstalledAppEngine(
+ selection_feature_processor_.get(), unilib_,
+ model_->triggering_options()->installed_app_enabled_modes()));
if (!installed_app_engine->Initialize(serialized_config)) {
TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
return false;
@@ -912,38 +926,40 @@ CodepointSpan Annotator::SuggestSelection(
!knowledge_engine_
->Chunk(context, options.annotation_usecase,
options.location_context, Permissions(),
- AnnotateMode::kEntityAnnotation, &candidates)
+ AnnotateMode::kEntityAnnotation, ModeFlag_SELECTION,
+ &candidates)
.ok()) {
TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
return original_click_indices;
}
if (contact_engine_ != nullptr &&
- !contact_engine_->Chunk(context_unicode, tokens,
+ !contact_engine_->Chunk(context_unicode, tokens, ModeFlag_SELECTION,
&candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Contact suggest selection failed.";
return original_click_indices;
}
if (installed_app_engine_ != nullptr &&
- !installed_app_engine_->Chunk(context_unicode, tokens,
+ !installed_app_engine_->Chunk(context_unicode, tokens, ModeFlag_SELECTION,
&candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Installed app suggest selection failed.";
return original_click_indices;
}
if (number_annotator_ != nullptr &&
!number_annotator_->FindAll(context_unicode, options.annotation_usecase,
+ ModeFlag_SELECTION,
&candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Number annotator failed in suggest selection.";
return original_click_indices;
}
if (duration_annotator_ != nullptr &&
- !duration_annotator_->FindAll(context_unicode, tokens,
- options.annotation_usecase,
- &candidates.annotated_spans[0])) {
+ !duration_annotator_->FindAll(
+ context_unicode, tokens, options.annotation_usecase,
+ ModeFlag_SELECTION, &candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Duration annotator failed in suggest selection.";
return original_click_indices;
}
if (person_name_engine_ != nullptr &&
- !person_name_engine_->Chunk(context_unicode, tokens,
+ !person_name_engine_->Chunk(context_unicode, tokens, ModeFlag_SELECTION,
&candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Person name suggest selection failed.";
return original_click_indices;
@@ -964,7 +980,9 @@ CodepointSpan Annotator::SuggestSelection(
candidates.annotated_spans[0].push_back(pod_ner_suggested_span);
}
- if (experimental_annotator_ != nullptr) {
+ if (experimental_annotator_ != nullptr &&
+ (model_->triggering_options()->experimental_enabled_modes() &
+ ModeFlag_SELECTION)) {
candidates.annotated_spans[0].push_back(
experimental_annotator_->SuggestSelection(context_unicode,
click_indices));
@@ -1896,7 +1914,9 @@ std::vector<ClassificationResult> Annotator::ClassifyText(
candidates.push_back({selection_indices, {vocab_annotator_result}});
}
- if (experimental_annotator_) {
+ if (experimental_annotator_ &&
+ (model_->triggering_options()->experimental_enabled_modes() &
+ ModeFlag_CLASSIFICATION)) {
experimental_annotator_->ClassifyText(context_unicode, selection_indices,
candidates);
}
@@ -2218,7 +2238,8 @@ Status Annotator::AnnotateSingleInput(
const bool contact_annotations_enabled =
!is_raw_usecase || is_entity_type_enabled(Collections::Contact());
if (contact_annotations_enabled && contact_engine_ &&
- !contact_engine_->Chunk(context_unicode, tokens, candidates)) {
+ !contact_engine_->Chunk(context_unicode, tokens, ModeFlag_ANNOTATION,
+ candidates)) {
return Status(StatusCode::INTERNAL, "Couldn't run contact engine Chunk.");
}
@@ -2226,7 +2247,8 @@ Status Annotator::AnnotateSingleInput(
const bool app_annotations_enabled =
!is_raw_usecase || is_entity_type_enabled(Collections::App());
if (app_annotations_enabled && installed_app_engine_ &&
- !installed_app_engine_->Chunk(context_unicode, tokens, candidates)) {
+ !installed_app_engine_->Chunk(context_unicode, tokens,
+ ModeFlag_ANNOTATION, candidates)) {
return Status(StatusCode::INTERNAL,
"Couldn't run installed app engine Chunk.");
}
@@ -2237,7 +2259,7 @@ Status Annotator::AnnotateSingleInput(
is_entity_type_enabled(Collections::Percentage()));
if (number_annotations_enabled && number_annotator_ != nullptr &&
!number_annotator_->FindAll(context_unicode, options.annotation_usecase,
- candidates)) {
+ ModeFlag_ANNOTATION, candidates)) {
return Status(StatusCode::INTERNAL,
"Couldn't run number annotator FindAll.");
}
@@ -2247,7 +2269,8 @@ Status Annotator::AnnotateSingleInput(
!is_raw_usecase || is_entity_type_enabled(Collections::Duration());
if (duration_annotations_enabled && duration_annotator_ != nullptr &&
!duration_annotator_->FindAll(context_unicode, tokens,
- options.annotation_usecase, candidates)) {
+ options.annotation_usecase,
+ ModeFlag_ANNOTATION, candidates)) {
return Status(StatusCode::INTERNAL,
"Couldn't run duration annotator FindAll.");
}
@@ -2256,7 +2279,8 @@ Status Annotator::AnnotateSingleInput(
const bool person_annotations_enabled =
!is_raw_usecase || is_entity_type_enabled(Collections::PersonName());
if (person_annotations_enabled && person_name_engine_ &&
- !person_name_engine_->Chunk(context_unicode, tokens, candidates)) {
+ !person_name_engine_->Chunk(context_unicode, tokens, ModeFlag_ANNOTATION,
+ candidates)) {
return Status(StatusCode::INTERNAL,
"Couldn't run person name engine Chunk.");
}
@@ -2290,6 +2314,8 @@ Status Annotator::AnnotateSingleInput(
// Annotate with the experimental annotator.
if (experimental_annotator_ != nullptr &&
+ (model_->triggering_options()->experimental_enabled_modes() &
+ ModeFlag_ANNOTATION) &&
!experimental_annotator_->Annotate(context_unicode, candidates)) {
return Status(StatusCode::INTERNAL, "Couldn't run experimental annotator.");
}
@@ -2376,14 +2402,21 @@ StatusOr<Annotations> Annotator::AnnotateStructuredInput(
.relative_bounding_box_height = string_fragment.bounding_box_height});
}
+ const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
+ const bool is_raw_usecase =
+ options.annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW;
+
+ const bool knowledge_engine_annotations_enabled =
+ !is_raw_usecase || is_entity_type_enabled(Collections::Entity());
// KnowledgeEngine is special, because it supports annotation of multiple
// fragments at once.
- if (knowledge_engine_ &&
+ if (knowledge_engine_annotations_enabled && knowledge_engine_ &&
!knowledge_engine_
->ChunkMultipleSpans(text_to_annotate, fragment_metadata,
options.annotation_usecase,
options.location_context, options.permissions,
- options.annotate_mode, &annotation_candidates)
+ options.annotate_mode, ModeFlag_ANNOTATION,
+ &annotation_candidates)
.ok()) {
return Status(StatusCode::INTERNAL, "Couldn't run knowledge engine Chunk.");
}
diff --git a/native/annotator/annotator.h b/native/annotator/annotator.h
index d69fe32..5df8129 100644
--- a/native/annotator/annotator.h
+++ b/native/annotator/annotator.h
@@ -149,6 +149,9 @@ class Annotator {
// Initializes the contact engine with the given config.
bool InitializeContactEngine(const std::string& serialized_config);
+ // Cleans up the resources associated with the contact engine.
+ void CleanUpContactEngine();
+
// Initializes the installed app engine with the given config.
bool InitializeInstalledAppEngine(const std::string& serialized_config);
diff --git a/native/annotator/annotator_jni.cc b/native/annotator/annotator_jni.cc
index 6e7eeab..3d352c6 100644
--- a/native/annotator/annotator_jni.cc
+++ b/native/annotator/annotator_jni.cc
@@ -275,6 +275,7 @@ StatusOr<ScopedLocalRef<jobject>> ClassificationResultWithIntentsToJObject(
device_locales, classification_result,
options->reference_time_ms_utc, context, selection_indices,
app_context, model_context->model()->entity_data_schema(),
+ options->enable_add_contact_intent, options->enable_search_intent,
&remote_action_templates)) {
return {Status::UNKNOWN};
}
@@ -896,6 +897,9 @@ TC3_JNI_METHOD(void, TC3_ANNOTATOR_CLASS_NAME, nativeCloseAnnotator)
(JNIEnv* env, jobject thiz, jlong ptr) {
const AnnotatorJniContext* context =
reinterpret_cast<AnnotatorJniContext*>(ptr);
+ if (context != nullptr && context->model()) {
+ context->model()->CleanUpContactEngine();
+ }
delete context;
}
diff --git a/native/annotator/annotator_jni_common.cc b/native/annotator/annotator_jni_common.cc
index a6f636f..6ee4977 100644
--- a/native/annotator/annotator_jni_common.cc
+++ b/native/annotator/annotator_jni_common.cc
@@ -279,6 +279,23 @@ StatusOr<ClassificationOptions> FromJavaClassificationOptions(
JniHelper::CallBooleanMethod(env, joptions,
get_trigger_dictionary_on_beginner_words));
+ // .getEnableAddContactIntent()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_enable_add_contact_intent,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getEnableAddContactIntent", "()Z"));
+ TC3_ASSIGN_OR_RETURN(classifier_options.enable_add_contact_intent,
+ JniHelper::CallBooleanMethod(
+ env, joptions, get_enable_add_contact_intent));
+
+ // .getEnableSearchIntent()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_enable_search_intent,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getEnableSearchIntent", "()Z"));
+ TC3_ASSIGN_OR_RETURN(
+ classifier_options.enable_search_intent,
+ JniHelper::CallBooleanMethod(env, joptions, get_enable_search_intent));
+
return classifier_options;
}
diff --git a/native/annotator/collections.h b/native/annotator/collections.h
index 417b447..becdcdb 100644
--- a/native/annotator/collections.h
+++ b/native/annotator/collections.h
@@ -144,6 +144,36 @@ class Collections {
*[]() { return new std::string("otp_code"); }();
return value;
}
+ static const std::string& Art() {
+ static const std::string& value =
+ *[]() { return new std::string("art"); }();
+ return value;
+ }
+ static const std::string& ConsumerGood() {
+ static const std::string& value =
+ *[]() { return new std::string("consumer_good"); }();
+ return value;
+ }
+ static const std::string& Event() {
+ static const std::string& value =
+ *[]() { return new std::string("event"); }();
+ return value;
+ }
+ static const std::string& Location() {
+ static const std::string& value =
+ *[]() { return new std::string("location"); }();
+ return value;
+ }
+ static const std::string& Organization() {
+ static const std::string& value =
+ *[]() { return new std::string("organization"); }();
+ return value;
+ }
+ static const std::string& Person() {
+ static const std::string& value =
+ *[]() { return new std::string("person"); }();
+ return value;
+ }
};
} // namespace libtextclassifier3
diff --git a/native/annotator/contact/contact-engine-dummy.h b/native/annotator/contact/contact-engine-dummy.h
index fe60203..211553c 100644
--- a/native/annotator/contact/contact-engine-dummy.h
+++ b/native/annotator/contact/contact-engine-dummy.h
@@ -47,13 +47,15 @@ class ContactEngine {
}
bool Chunk(const UnicodeText& context_unicode,
- const std::vector<Token>& tokens,
+ const std::vector<Token>& tokens, ModeFlag mode,
std::vector<AnnotatedSpan>* result) const {
return true;
}
void AddContactMetadataToKnowledgeClassificationResult(
ClassificationResult* classification_result) const {}
+
+ void CleanUp() const {}
};
} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/grammar-parser.cc b/native/annotator/datetime/grammar-parser.cc
index 6d51c19..c49a01a 100644
--- a/native/annotator/datetime/grammar-parser.cc
+++ b/native/annotator/datetime/grammar-parser.cc
@@ -20,6 +20,7 @@
#include <unordered_set>
#include "annotator/datetime/datetime-grounder.h"
+#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/grammar/analyzer.h"
#include "utils/grammar/evaluated-derivation.h"
@@ -33,11 +34,13 @@ namespace libtextclassifier3 {
GrammarDatetimeParser::GrammarDatetimeParser(
const grammar::Analyzer& analyzer,
const DatetimeGrounder& datetime_grounder,
- const float target_classification_score, const float priority_score)
+ const float target_classification_score, const float priority_score,
+ ModeFlag enabled_modes)
: analyzer_(analyzer),
datetime_grounder_(datetime_grounder),
target_classification_score_(target_classification_score),
- priority_score_(priority_score) {}
+ priority_score_(priority_score),
+ enabled_modes_(enabled_modes) {}
StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
const std::string& input, const int64 reference_time_ms_utc,
@@ -54,6 +57,10 @@ StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
const std::string& reference_timezone, const LocaleList& locale_list,
ModeFlag mode, AnnotationUsecase annotation_usecase,
bool anchor_start_end) const {
+ if (!(enabled_modes_ & mode)) {
+ return std::vector<DatetimeParseResultSpan>();
+ }
+
std::vector<DatetimeParseResultSpan> results;
UnsafeArena arena(/*block_size=*/16 << 10);
std::vector<Locale> locales = locale_list.GetLocales();
diff --git a/native/annotator/datetime/grammar-parser.h b/native/annotator/datetime/grammar-parser.h
index 6ff4b46..35da843 100644
--- a/native/annotator/datetime/grammar-parser.h
+++ b/native/annotator/datetime/grammar-parser.h
@@ -22,6 +22,7 @@
#include "annotator/datetime/datetime-grounder.h"
#include "annotator/datetime/parser.h"
+#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/base/statusor.h"
#include "utils/grammar/analyzer.h"
@@ -37,7 +38,8 @@ class GrammarDatetimeParser : public DatetimeParser {
explicit GrammarDatetimeParser(const grammar::Analyzer& analyzer,
const DatetimeGrounder& datetime_grounder,
const float target_classification_score,
- const float priority_score);
+ const float priority_score,
+ ModeFlag enabled_modes);
// Parses the dates in 'input' and fills result. Makes sure that the results
// do not overlap.
@@ -61,6 +63,7 @@ class GrammarDatetimeParser : public DatetimeParser {
const DatetimeGrounder& datetime_grounder_;
const float target_classification_score_;
const float priority_score_;
+ const ModeFlag enabled_modes_;
};
} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/grammar-parser_test.cc b/native/annotator/datetime/grammar-parser_test.cc
index cf2dffd..b8a270d 100644
--- a/native/annotator/datetime/grammar-parser_test.cc
+++ b/native/annotator/datetime/grammar-parser_test.cc
@@ -22,6 +22,7 @@
#include "annotator/datetime/datetime-grounder.h"
#include "annotator/datetime/testing/base-parser-test.h"
#include "annotator/datetime/testing/datetime-component-builder.h"
+#include "annotator/model_generated.h"
#include "utils/grammar/analyzer.h"
#include "utils/jvm-test-utils.h"
#include "utils/test-data-test-utils.h"
@@ -42,7 +43,15 @@ std::string ReadFile(const std::string& file_name) {
class GrammarDatetimeParserTest : public DateTimeParserTest {
public:
- void SetUp() override {
+ void SetUp() override { ResetParser(ModeFlag_ALL); }
+
+ // Exposes the date time parser for tests and evaluations.
+ const DatetimeParser* DatetimeParserForTests() const override {
+ return parser_.get();
+ }
+
+ protected:
+ void ResetParser(ModeFlag enabled_modes) {
grammar_buffer_ = ReadFile(GetModelPath() + "datetime.fb");
unilib_ = CreateUniLibForTesting();
calendarlib_ = CreateCalendarLibForTesting();
@@ -51,12 +60,8 @@ class GrammarDatetimeParserTest : public DateTimeParserTest {
datetime_grounder_ = std::make_unique<DatetimeGrounder>(calendarlib_.get());
parser_.reset(new GrammarDatetimeParser(*analyzer_, *datetime_grounder_,
/*target_classification_score=*/1.0,
- /*priority_score=*/1.0));
- }
-
- // Exposes the date time parser for tests and evaluations.
- const DatetimeParser* DatetimeParserForTests() const override {
- return parser_.get();
+ /*priority_score=*/1.0,
+ enabled_modes));
}
private:
@@ -486,6 +491,13 @@ TEST_F(GrammarDatetimeParserTest, Parse) {
.Build()}));
}
+TEST_F(GrammarDatetimeParserTest, NotEnabledModeHasNoResult) {
+ ResetParser(ModeFlag_SELECTION);
+ // `DateTimeParserTest` implementation parses the input under the ANNOTATION
+ // mode.
+ EXPECT_TRUE(HasNoResult("{January 1, 1988}"));
+}
+
TEST_F(GrammarDatetimeParserTest, DateValidation) {
EXPECT_TRUE(ParsesCorrectly(
"{01/02/2020}", 1577919600000, GRANULARITY_DAY,
diff --git a/native/annotator/duration/duration.cc b/native/annotator/duration/duration.cc
index c59b8e0..df4c60d 100644
--- a/native/annotator/duration/duration.cc
+++ b/native/annotator/duration/duration.cc
@@ -20,6 +20,7 @@
#include <cstdlib>
#include "annotator/collections.h"
+#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/base/macros.h"
@@ -125,8 +126,10 @@ bool DurationAnnotator::ClassifyText(
const UnicodeText& context, CodepointSpan selection_indices,
AnnotationUsecase annotation_usecase,
ClassificationResult* classification_result) const {
- if (!options_->enabled() || ((options_->enabled_annotation_usecases() &
- (1 << annotation_usecase))) == 0) {
+ if (!options_->enabled() ||
+ ((options_->enabled_annotation_usecases() & (1 << annotation_usecase))) ==
+ 0 ||
+ !(options_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
return false;
}
@@ -151,9 +154,12 @@ bool DurationAnnotator::ClassifyText(
bool DurationAnnotator::FindAll(const UnicodeText& context,
const std::vector<Token>& tokens,
AnnotationUsecase annotation_usecase,
+ ModeFlag mode,
std::vector<AnnotatedSpan>* results) const {
- if (!options_->enabled() || ((options_->enabled_annotation_usecases() &
- (1 << annotation_usecase))) == 0) {
+ if (!options_->enabled() ||
+ ((options_->enabled_annotation_usecases() & (1 << annotation_usecase))) ==
+ 0 ||
+ !(options_->enabled_modes() & mode)) {
return true;
}
diff --git a/native/annotator/duration/duration.h b/native/annotator/duration/duration.h
index 1a42ac3..e99542c 100644
--- a/native/annotator/duration/duration.h
+++ b/native/annotator/duration/duration.h
@@ -87,7 +87,7 @@ class DurationAnnotator {
// Finds all duration instances in the input text.
bool FindAll(const UnicodeText& context, const std::vector<Token>& tokens,
- AnnotationUsecase annotation_usecase,
+ AnnotationUsecase annotation_usecase, ModeFlag mode,
std::vector<AnnotatedSpan>* results) const;
private:
diff --git a/native/annotator/duration/duration_test.cc b/native/annotator/duration/duration_test.cc
index 7c07a72..f726058 100644
--- a/native/annotator/duration/duration_test.cc
+++ b/native/annotator/duration/duration_test.cc
@@ -16,6 +16,7 @@
#include "annotator/duration/duration.h"
+#include <cstddef>
#include <string>
#include <vector>
@@ -37,41 +38,61 @@ using testing::ElementsAre;
using testing::Field;
using testing::IsEmpty;
-const DurationAnnotatorOptions* TestingDurationAnnotatorOptions() {
- static const flatbuffers::DetachedBuffer* options_data = []() {
- DurationAnnotatorOptionsT options;
- options.enabled = true;
+namespace {
+const flatbuffers::DetachedBuffer* CreateOptionsData(ModeFlag enabled_modes) {
+ DurationAnnotatorOptionsT options;
+ options.enabled = true;
+ options.enabled_modes = enabled_modes;
- options.week_expressions.push_back("week");
- options.week_expressions.push_back("weeks");
+ options.week_expressions.push_back("week");
+ options.week_expressions.push_back("weeks");
- options.day_expressions.push_back("day");
- options.day_expressions.push_back("days");
+ options.day_expressions.push_back("day");
+ options.day_expressions.push_back("days");
- options.hour_expressions.push_back("hour");
- options.hour_expressions.push_back("hours");
+ options.hour_expressions.push_back("hour");
+ options.hour_expressions.push_back("hours");
- options.minute_expressions.push_back("minute");
- options.minute_expressions.push_back("minutes");
+ options.minute_expressions.push_back("minute");
+ options.minute_expressions.push_back("minutes");
- options.second_expressions.push_back("second");
- options.second_expressions.push_back("seconds");
+ options.second_expressions.push_back("second");
+ options.second_expressions.push_back("seconds");
- options.filler_expressions.push_back("and");
- options.filler_expressions.push_back("a");
- options.filler_expressions.push_back("an");
- options.filler_expressions.push_back("one");
+ options.filler_expressions.push_back("and");
+ options.filler_expressions.push_back("a");
+ options.filler_expressions.push_back("an");
+ options.filler_expressions.push_back("one");
- options.half_expressions.push_back("half");
+ options.half_expressions.push_back("half");
- options.sub_token_separator_codepoints.push_back('-');
+ options.sub_token_separator_codepoints.push_back('-');
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(DurationAnnotatorOptions::Pack(builder, &options));
- return new flatbuffers::DetachedBuffer(builder.Release());
- }();
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(DurationAnnotatorOptions::Pack(builder, &options));
+ return new flatbuffers::DetachedBuffer(builder.Release());
+}
+} // namespace
- return flatbuffers::GetRoot<DurationAnnotatorOptions>(options_data->data());
+const DurationAnnotatorOptions* TestingDurationAnnotatorOptions(
+ ModeFlag enabled_modes) {
+ static const flatbuffers::DetachedBuffer* options_data_all =
+ CreateOptionsData(ModeFlag_ALL);
+ static const flatbuffers::DetachedBuffer* options_data_selection =
+ CreateOptionsData(ModeFlag_SELECTION);
+ static const flatbuffers::DetachedBuffer* options_data_no_selection =
+ CreateOptionsData(ModeFlag_ANNOTATION_AND_CLASSIFICATION);
+
+ if (enabled_modes == ModeFlag_SELECTION) {
+ return flatbuffers::GetRoot<DurationAnnotatorOptions>(
+ options_data_selection->data());
+ } else if (enabled_modes == ModeFlag_ANNOTATION_AND_CLASSIFICATION) {
+ return flatbuffers::GetRoot<DurationAnnotatorOptions>(
+ options_data_no_selection->data());
+ } else {
+ return flatbuffers::GetRoot<DurationAnnotatorOptions>(
+ options_data_all->data());
+ }
}
std::unique_ptr<FeatureProcessor> BuildFeatureProcessor(const UniLib* unilib) {
@@ -103,10 +124,10 @@ std::unique_ptr<FeatureProcessor> BuildFeatureProcessor(const UniLib* unilib) {
class DurationAnnotatorTest : public ::testing::Test {
protected:
- DurationAnnotatorTest()
+ explicit DurationAnnotatorTest(ModeFlag enabled_modes = ModeFlag_ALL)
: INIT_UNILIB_FOR_TESTING(unilib_),
feature_processor_(BuildFeatureProcessor(&unilib_)),
- duration_annotator_(TestingDurationAnnotatorOptions(),
+ duration_annotator_(TestingDurationAnnotatorOptions(enabled_modes),
feature_processor_.get(), &unilib_) {}
std::vector<Token> Tokenize(const UnicodeText& text) {
@@ -118,6 +139,19 @@ class DurationAnnotatorTest : public ::testing::Test {
DurationAnnotator duration_annotator_;
};
+class DurationAnnotatorForAnnotationAndClassificationTest
+ : public DurationAnnotatorTest {
+ protected:
+ DurationAnnotatorForAnnotationAndClassificationTest()
+ : DurationAnnotatorTest(ModeFlag_ANNOTATION_AND_CLASSIFICATION) {}
+};
+
+class DurationAnnotatorForSelectionTest : public DurationAnnotatorTest {
+ protected:
+ DurationAnnotatorForSelectionTest()
+ : DurationAnnotatorTest(ModeFlag_SELECTION) {}
+};
+
TEST_F(DurationAnnotatorTest, ClassifiesSimpleDuration) {
ClassificationResult classification;
EXPECT_TRUE(duration_annotator_.ClassifyText(
@@ -129,6 +163,14 @@ TEST_F(DurationAnnotatorTest, ClassifiesSimpleDuration) {
Field(&ClassificationResult::duration_ms, 15 * 60 * 1000)));
}
+TEST_F(DurationAnnotatorForSelectionTest,
+ ClassifyTextDisabledClassificationReturnsFalse) {
+ ClassificationResult classification;
+ EXPECT_FALSE(duration_annotator_.ClassifyText(
+ UTF8ToUnicodeText("Wake me up in 15 minutes ok?"), {14, 24},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification));
+}
+
TEST_F(DurationAnnotatorTest, ClassifiesWhenTokensDontAlignWithSelection) {
ClassificationResult classification;
EXPECT_TRUE(duration_annotator_.ClassifyText(
@@ -152,7 +194,8 @@ TEST_F(DurationAnnotatorTest, FindsSimpleDuration) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_SELECTION, &result));
EXPECT_THAT(
result,
@@ -165,13 +208,26 @@ TEST_F(DurationAnnotatorTest, FindsSimpleDuration) {
15 * 60 * 1000)))))));
}
+TEST_F(DurationAnnotatorForAnnotationAndClassificationTest,
+ FindsAllDisabledModeReturnsNoResults) {
+ const UnicodeText text = UTF8ToUnicodeText("Wake me up in 15 minutes ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_SELECTION, &result));
+
+ EXPECT_THAT(result, IsEmpty());
+}
+
TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpression) {
const UnicodeText text =
UTF8ToUnicodeText("Set a timer for 3 and half minutes ok?");
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -190,7 +246,8 @@ TEST_F(DurationAnnotatorTest, FindsComposedDuration) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_SELECTION, &result));
EXPECT_THAT(
result,
@@ -209,7 +266,8 @@ TEST_F(DurationAnnotatorTest, AllUnitsAreCovered) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -228,7 +286,8 @@ TEST_F(DurationAnnotatorTest, FindsHalfAnHour) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -247,7 +306,8 @@ TEST_F(DurationAnnotatorTest, FindsWhenHalfIsAfterGranularitySpecification) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_SELECTION, &result));
EXPECT_THAT(
result,
@@ -266,7 +326,8 @@ TEST_F(DurationAnnotatorTest, FindsAnHourAndAHalf) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -286,7 +347,8 @@ TEST_F(DurationAnnotatorTest,
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -305,7 +367,8 @@ TEST_F(DurationAnnotatorTest, DoesNotGreedilyTakeFillerWords) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -323,7 +386,8 @@ TEST_F(DurationAnnotatorTest, DoesNotCrashWhenJustHalfIsSaid) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
ASSERT_EQ(result.size(), 0);
}
@@ -334,7 +398,8 @@ TEST_F(DurationAnnotatorTest, StripsPunctuationFromTokens) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -352,7 +417,8 @@ TEST_F(DurationAnnotatorTest, FindsCorrectlyWithCombinedQuantityUnitToken) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -383,7 +449,8 @@ TEST_F(DurationAnnotatorTest, FindsSimpleDurationIgnoringCase) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -402,7 +469,8 @@ TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpressionIgnoringCase) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -422,7 +490,8 @@ TEST_F(DurationAnnotatorTest,
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -440,7 +509,8 @@ TEST_F(DurationAnnotatorTest, FindsDurationWithDanglingQuantity) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -458,7 +528,8 @@ TEST_F(DurationAnnotatorTest, FindsDurationWithDanglingQuantityNotSupported) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -475,7 +546,8 @@ TEST_F(DurationAnnotatorTest, FindsDurationWithDecimalQuantity) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -540,7 +612,8 @@ TEST_F(JapaneseDurationAnnotatorTest, FindsDuration) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -558,7 +631,8 @@ TEST_F(JapaneseDurationAnnotatorTest, FindsDurationWithHalfExpression) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -576,7 +650,8 @@ TEST_F(JapaneseDurationAnnotatorTest, IgnoresDurationWithoutQuantity) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result, IsEmpty());
}
@@ -586,7 +661,8 @@ TEST_F(JapaneseDurationAnnotatorTest, FindsDurationWithDanglingQuantity) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_SELECTION, &result));
EXPECT_THAT(
result,
diff --git a/native/annotator/installed_app/installed-app-engine-dummy.h b/native/annotator/installed_app/installed-app-engine-dummy.h
index 2f2b62f..80f5a5c 100644
--- a/native/annotator/installed_app/installed-app-engine-dummy.h
+++ b/native/annotator/installed_app/installed-app-engine-dummy.h
@@ -32,7 +32,7 @@ namespace libtextclassifier3 {
class InstalledAppEngine {
public:
explicit InstalledAppEngine(const FeatureProcessor* feature_processor,
- const UniLib* unilib) {}
+ const UniLib* unilib, ModeFlag enabled_modes) {}
bool Initialize(const std::string& serialized_config) {
TC3_LOG(ERROR) << "No installed app engine to initialize.";
@@ -45,7 +45,7 @@ class InstalledAppEngine {
}
bool Chunk(const UnicodeText& context_unicode,
- const std::vector<Token>& tokens,
+ const std::vector<Token>& tokens, ModeFlag mode,
std::vector<AnnotatedSpan>* result) const {
return true;
}
diff --git a/native/annotator/knowledge/knowledge-engine-dummy.h b/native/annotator/knowledge/knowledge-engine-dummy.h
index 34fa490..949018c 100644
--- a/native/annotator/knowledge/knowledge-engine-dummy.h
+++ b/native/annotator/knowledge/knowledge-engine-dummy.h
@@ -37,6 +37,8 @@ class KnowledgeEngine {
void SetPriorityScore(float priority_score) {}
+ void SetEnabledModes(ModeFlag enabled_modes) {}
+
Status ClassifyText(const std::string& text, CodepointSpan selection_indices,
AnnotationUsecase annotation_usecase,
const Optional<LocationContext>& location_context,
@@ -48,7 +50,7 @@ class KnowledgeEngine {
Status Chunk(const std::string& text, AnnotationUsecase annotation_usecase,
const Optional<LocationContext>& location_context,
const Permissions& permissions, const AnnotateMode annotate_mode,
- Annotations* result) const {
+ ModeFlag mode, Annotations* result) const {
return Status::OK;
}
@@ -58,7 +60,7 @@ class KnowledgeEngine {
AnnotationUsecase annotation_usecase,
const Optional<LocationContext>& location_context,
const Permissions& permissions, const AnnotateMode annotate_mode,
- Annotations* results) const {
+ ModeFlag mode, Annotations* results) const {
return Status::OK;
}
diff --git a/native/annotator/model.fbs b/native/annotator/model.fbs
index 57187f5..eeb4101 100644
--- a/native/annotator/model.fbs
+++ b/native/annotator/model.fbs
@@ -415,6 +415,7 @@ table GrammarModel {
// The grammar rules.
rules:grammar.RulesSet;
+ // Deprecated. Used only for the old implementation of the grammar model.
rule_classification_result:[GrammarModel_.RuleClassificationResult];
// Number of tokens in the context to use for classification and text
@@ -432,6 +433,10 @@ table GrammarModel {
// The priority score used for conflict resolution with the other models.
priority_score:float = 1;
+
+ // Global enabled modes. Use this instead of
+ // `rule_classification_result.enabled_modes`.
+ enabled_modes:ModeFlag = ALL;
}
namespace libtextclassifier3.MoneyParsingOptions_;
@@ -486,6 +491,15 @@ table ModelTriggeringOptions {
// map. Key: collection type e.g. "address", "phone"..., Value: float number.
// NOTE: The entries here need to be sorted since we use LookupByKey.
collection_to_priority:[ModelTriggeringOptions_.CollectionToPriorityEntry];
+
+ // Enabled modes for the knowledge engine model.
+ knowledge_enabled_modes:ModeFlag = ALL;
+
+ // Enabled modes for the experimental model.
+ experimental_enabled_modes:ModeFlag = ALL;
+
+ // Enabled modes for the installed app model.
+ installed_app_enabled_modes:ModeFlag = ALL;
}
// Options controlling the output of the classifier.
@@ -894,6 +908,9 @@ table ContactAnnotatorOptions {
// For each language there is a customized list of supported declensions.
language:string (shared);
+
+ // Enabled modes.
+ enabled_modes:ModeFlag = ALL;
}
namespace libtextclassifier3.TranslateAnnotatorOptions_;
@@ -927,6 +944,9 @@ table TranslateAnnotatorOptions {
algorithm:TranslateAnnotatorOptions_.Algorithm;
backoff_options:TranslateAnnotatorOptions_.BackoffOptions;
+
+ // Enabled modes.
+ enabled_modes:ModeFlag = CLASSIFICATION;
}
namespace libtextclassifier3.PodNerModel_;
@@ -1012,6 +1032,9 @@ table PodNerModel {
min_number_of_tokens:int = 1;
min_number_of_wordpieces:int = 1;
+
+ // Enabled modes.
+ enabled_modes:ModeFlag = ALL;
}
namespace libtextclassifier3;
@@ -1043,6 +1066,9 @@ table VocabModel {
// Priority score used for conflict resolution with the other models.
priority_score:float = 0;
+
+ // Enabled modes.
+ enabled_modes:ModeFlag = ANNOTATION_AND_CLASSIFICATION;
}
root_type libtextclassifier3.Model;
diff --git a/native/annotator/number/number.cc b/native/annotator/number/number.cc
index 3be6ad8..14fc24e 100644
--- a/native/annotator/number/number.cc
+++ b/native/annotator/number/number.cc
@@ -21,6 +21,7 @@
#include <string>
#include "annotator/collections.h"
+#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/strings/split.h"
@@ -38,7 +39,8 @@ bool NumberAnnotator::ClassifyText(
context, selection_indices.first, selection_indices.second);
std::vector<AnnotatedSpan> results;
- if (!FindAll(substring_selected, annotation_usecase, &results)) {
+ if (!FindAll(substring_selected, annotation_usecase, ModeFlag_CLASSIFICATION,
+ &results)) {
return false;
}
@@ -216,8 +218,9 @@ bool NumberAnnotator::TryParseNumber(const UnicodeText& token_text,
bool NumberAnnotator::FindAll(const UnicodeText& context,
AnnotationUsecase annotation_usecase,
+ ModeFlag mode,
std::vector<AnnotatedSpan>* result) const {
- if (!options_->enabled()) {
+ if (!options_->enabled() || !(options_->enabled_modes() & mode)) {
return true;
}
diff --git a/native/annotator/number/number.h b/native/annotator/number/number.h
index d83bea0..dcc2d48 100644
--- a/native/annotator/number/number.h
+++ b/native/annotator/number/number.h
@@ -58,7 +58,7 @@ class NumberAnnotator {
// Finds all number instances in the input text. Returns true in any case.
bool FindAll(const UnicodeText& context_unicode,
- AnnotationUsecase annotation_usecase,
+ AnnotationUsecase annotation_usecase, ModeFlag mode,
std::vector<AnnotatedSpan>* result) const;
private:
diff --git a/native/annotator/number/number_test-include.cc b/native/annotator/number/number_test-include.cc
index f47933f..98140f4 100644
--- a/native/annotator/number/number_test-include.cc
+++ b/native/annotator/number/number_test-include.cc
@@ -16,6 +16,7 @@
#include "annotator/number/number_test-include.h"
+#include <set>
#include <string>
#include <vector>
@@ -34,37 +35,57 @@ namespace test_internal {
using ::testing::AllOf;
using ::testing::ElementsAre;
using ::testing::Field;
+using ::testing::IsEmpty;
using ::testing::Matcher;
using ::testing::UnorderedElementsAre;
+namespace {
+const flatbuffers::DetachedBuffer* CreateOptionsData(ModeFlag enabled_modes) {
+ NumberAnnotatorOptionsT options;
+ options.enabled = true;
+ options.priority_score = -10.0;
+ options.float_number_priority_score = 1.0;
+ options.enabled_annotation_usecases =
+ 1 << AnnotationUsecase_ANNOTATION_USECASE_RAW;
+ options.max_number_of_digits = 20;
+ options.enabled_modes = enabled_modes;
+
+ options.percentage_priority_score = 1.0;
+ options.percentage_annotation_usecases =
+ (1 << AnnotationUsecase_ANNOTATION_USECASE_RAW) +
+ (1 << AnnotationUsecase_ANNOTATION_USECASE_SMART);
+ std::set<std::string> percent_suffixes(
+ {"パーセント", "percent", "pércént", "pc", "pct", "%", "٪", "﹪", "%"});
+ for (const std::string& string_value : percent_suffixes) {
+ options.percentage_pieces_string.append(string_value);
+ options.percentage_pieces_string.push_back('\0');
+ }
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(NumberAnnotatorOptions::Pack(builder, &options));
+ return new flatbuffers::DetachedBuffer(builder.Release());
+}
+} // namespace
+
const NumberAnnotatorOptions*
-NumberAnnotatorTest::TestingNumberAnnotatorOptions() {
- static const flatbuffers::DetachedBuffer* options_data = []() {
- NumberAnnotatorOptionsT options;
- options.enabled = true;
- options.priority_score = -10.0;
- options.float_number_priority_score = 1.0;
- options.enabled_annotation_usecases =
- 1 << AnnotationUsecase_ANNOTATION_USECASE_RAW;
- options.max_number_of_digits = 20;
-
- options.percentage_priority_score = 1.0;
- options.percentage_annotation_usecases =
- (1 << AnnotationUsecase_ANNOTATION_USECASE_RAW) +
- (1 << AnnotationUsecase_ANNOTATION_USECASE_SMART);
- std::set<std::string> percent_suffixes({"パーセント", "percent", "pércént",
- "pc", "pct", "%", "٪", "﹪", "%"});
- for (const std::string& string_value : percent_suffixes) {
- options.percentage_pieces_string.append(string_value);
- options.percentage_pieces_string.push_back('\0');
- }
-
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(NumberAnnotatorOptions::Pack(builder, &options));
- return new flatbuffers::DetachedBuffer(builder.Release());
- }();
-
- return flatbuffers::GetRoot<NumberAnnotatorOptions>(options_data->data());
+NumberAnnotatorTest::TestingNumberAnnotatorOptions(ModeFlag enabled_modes) {
+ static const flatbuffers::DetachedBuffer* options_data_selection =
+ CreateOptionsData(ModeFlag_SELECTION);
+ static const flatbuffers::DetachedBuffer* options_data_no_selection =
+ CreateOptionsData(ModeFlag_ANNOTATION_AND_CLASSIFICATION);
+ static const flatbuffers::DetachedBuffer* options_data_all =
+ CreateOptionsData(ModeFlag_ALL);
+
+ if (enabled_modes == ModeFlag_SELECTION) {
+ return flatbuffers::GetRoot<NumberAnnotatorOptions>(
+ options_data_selection->data());
+ } else if (enabled_modes == ModeFlag_ANNOTATION_AND_CLASSIFICATION) {
+ return flatbuffers::GetRoot<NumberAnnotatorOptions>(
+ options_data_no_selection->data());
+ } else {
+ return flatbuffers::GetRoot<NumberAnnotatorOptions>(
+ options_data_all->data());
+ }
}
MATCHER_P(IsCorrectCollection, collection, "collection is " + collection) {
@@ -124,6 +145,14 @@ TEST_F(NumberAnnotatorTest, ClassifiesAndParsesNumberCorrectly) {
EXPECT_FLOAT_EQ(classification_result.numeric_double_value, 12345);
}
+TEST_F(NumberAnnotatorForSelectionTest,
+ ClassifyTextDisabledClassificationReturnsFalse) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("... 12345 ..."), {4, 9},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
TEST_F(NumberAnnotatorTest, ClassifiesAndParsesNumberAsFloatCorrectly) {
ClassificationResult classification_result;
EXPECT_TRUE(number_annotator_.ClassifyText(
@@ -167,7 +196,7 @@ TEST_F(NumberAnnotatorTest, FindsAllIntegerAndFloatNumbersInText) {
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("how much is 2 plus 5 divided by 7% minus 3.14 "
"what about 68.9# or 68.9#?"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
@@ -268,7 +297,8 @@ TEST_F(NumberAnnotatorTest, ClassifiesNonAsciiJaPercentageCorrectSuffix) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("明日の降水確率は10パーセント 音量を12にセット"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_CLASSIFICATION,
+ &result));
EXPECT_THAT(result,
UnorderedElementsAre(
IsAnnotatedSpan(CodepointSpan(8, 10), "number",
@@ -285,7 +315,7 @@ TEST_F(NumberAnnotatorTest, FindsAllNumbersInText) {
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("... 12345 ... 9 is my number and 27% or 68# #38 #39 "
"but not $99."),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -307,12 +337,23 @@ TEST_F(NumberAnnotatorTest, FindsAllNumbersInText) {
/*int_value=*/39, /*double_value=*/39.0)));
}
+TEST_F(NumberAnnotatorForAnnotationAndClassificationTest,
+ FindsAllDisabledModeReturnsNoResults) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("... 12345 ... 9 is my number and 27% or 68# #38 #39 "
+ "but not $99."),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_SELECTION, &result));
+
+ EXPECT_THAT(result, IsEmpty());
+}
+
TEST_F(NumberAnnotatorTest, FindsNoNumberInText) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("... 12345a ... 12345..12345 and 123a45 are not valid. "
"And -#5% is also bad."),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_SELECTION, &result));
ASSERT_EQ(result.size(), 0);
}
@@ -323,7 +364,8 @@ TEST_F(NumberAnnotatorTest, FindsNumberWithPunctuation) {
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText(
"It's 12, 13, 14! Or 15??? For sure 16: 17; 18. and -19"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_CLASSIFICATION,
+ &result));
EXPECT_THAT(result,
UnorderedElementsAre(
@@ -348,7 +390,7 @@ TEST_F(NumberAnnotatorTest, FindsFloatNumberWithPunctuation) {
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("It's 12.123, 13.45, 14.54321! Or 15.1? Maybe 16.33: "
"17.21; but for sure 18.90."),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
@@ -379,7 +421,7 @@ TEST_F(NumberAnnotatorTest, HandlesNumbersAtBeginning) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("-5"), AnnotationUsecase_ANNOTATION_USECASE_RAW,
- &result));
+ ModeFlag_SELECTION, &result));
EXPECT_THAT(result, UnorderedElementsAre(IsAnnotatedSpan(
CodepointSpan(0, 2), "number",
@@ -390,7 +432,7 @@ TEST_F(NumberAnnotatorTest, HandlesNegativeNumbers) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("Number -5 and -5% and not number --5%"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
@@ -408,7 +450,7 @@ TEST_F(NumberAnnotatorTest, FindGoodPercentageContexts) {
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText(
"5 percent, 10 pct, 25 pc and 17%, -5 percent, 10% are percentages"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_SELECTION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
@@ -448,7 +490,7 @@ TEST_F(NumberAnnotatorTest, FindSinglePercentageInContext) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("5%"), AnnotationUsecase_ANNOTATION_USECASE_RAW,
- &result));
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result, UnorderedElementsAre(
IsAnnotatedSpan(CodepointSpan(0, 1), "number",
@@ -463,7 +505,7 @@ TEST_F(NumberAnnotatorTest, IgnoreBadPercentageContexts) {
// A valid number is followed by only one punctuation element.
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("10, pct, 25 prc, 5#: percentage are not percentages"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
@@ -478,7 +520,7 @@ TEST_F(NumberAnnotatorTest, IgnoreBadPercentagePunctuationContexts) {
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText(
"#!24% or :?33 percent are not valid percentages, nor numbers."),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_TRUE(result.empty());
}
@@ -488,7 +530,7 @@ TEST_F(NumberAnnotatorTest, FindPercentageInNonAsciiContext) {
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText(
"At the café 10% or 25 percent of people are nice. Only 10%!"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
@@ -748,7 +790,7 @@ TEST_F(NumberAnnotatorTest, WhenSuffixWithoutNumberDoesNotParseIt) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("... % ..."), AnnotationUsecase_ANNOTATION_USECASE_RAW,
- &result));
+ ModeFlag_ANNOTATION, &result));
ASSERT_EQ(result.size(), 0);
}
@@ -757,7 +799,7 @@ TEST_F(NumberAnnotatorTest, WhenPrefixWithoutNumberDoesNotParseIt) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("... $ ..."), AnnotationUsecase_ANNOTATION_USECASE_RAW,
- &result));
+ ModeFlag_ANNOTATION, &result));
ASSERT_EQ(result.size(), 0);
}
@@ -766,7 +808,7 @@ TEST_F(NumberAnnotatorTest, WhenPrefixAndSuffixWithoutNumberDoesNotParseIt) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("... $% ..."), AnnotationUsecase_ANNOTATION_USECASE_RAW,
- &result));
+ ModeFlag_ANNOTATION, &result));
ASSERT_EQ(result.size(), 0);
}
@@ -786,7 +828,7 @@ TEST_F(NumberAnnotatorTest, ForNumberAnnotationsSetsScoreAndPriorityScore) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("Come at 9 or 10 ok?"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
@@ -811,7 +853,7 @@ TEST_F(NumberAnnotatorTest,
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("Results are between 12.5 and 13.5, right?"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
IsAnnotatedSpan(CodepointSpan(20, 24), "number",
@@ -845,7 +887,7 @@ TEST_F(NumberAnnotatorTest, ForPercentageAnnotationsSetsScoreAndPriorityScore) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("Results are between 9% and 10 percent."),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
IsAnnotatedSpan(CodepointSpan(20, 21), "number",
@@ -887,7 +929,8 @@ TEST_F(NumberAnnotatorTest, NumberDisabledPercentageEnabledForSmartUsecase) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("Accuracy for experiment 3 is 9%."),
- AnnotationUsecase_ANNOTATION_USECASE_SMART, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_SMART, ModeFlag_ANNOTATION,
+ &result));
EXPECT_THAT(result, UnorderedElementsAre(
IsAnnotatedSpan(CodepointSpan(29, 31), "percentage",
/*int_value=*/9, /*double_value=*/9.0,
@@ -898,7 +941,7 @@ TEST_F(NumberAnnotatorTest, MathOperatorsNotAnnotatedAsNumbersFindAll) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("how much is 2 + 2 or 5 - 96 * 89"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
@@ -928,7 +971,7 @@ TEST_F(NumberAnnotatorTest, SlashSeparatesTwoNumbersFindAll) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("what's 1 + 2/3 * 4/5 * 6 / 7"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
@@ -972,7 +1015,7 @@ TEST_F(NumberAnnotatorTest, SlashDoesNotSeparatesTwoNumbersFindAll) {
// 2 in the "2/" context is a number because / is punctuation
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("what's 2a2/3 or 2/s4 or 2/ or /3 or //3 or 2//"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result, UnorderedElementsAre(IsAnnotatedSpan(
CodepointSpan(24, 25), "number",
@@ -983,7 +1026,7 @@ TEST_F(NumberAnnotatorTest, BracketsContextAnnotatedFindAll) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("The interval is: (12, 13) or [-12, -4.5)"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
@@ -1002,7 +1045,7 @@ TEST_F(NumberAnnotatorTest, BracketsContextNotAnnotatedFindAll) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("The interval is: -(12, 138*)"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_TRUE(result.empty());
}
@@ -1012,7 +1055,7 @@ TEST_F(NumberAnnotatorTest, FractionalNumberDotsFindAll) {
// Dots source: https://unicode-search.net/unicode-namesearch.pl?term=period
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("3.1 3﹒2 3.3"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result, UnorderedElementsAre(
IsAnnotatedSpan(CodepointSpan(0, 3), "number",
@@ -1032,7 +1075,7 @@ TEST_F(NumberAnnotatorTest, NonAsciiDigitsFindAll) {
// Digits source: https://unicode-search.net/unicode-namesearch.pl?term=digit
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("3 3﹒2 3.3%"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result, UnorderedElementsAre(
IsAnnotatedSpan(CodepointSpan(0, 1), "number",
@@ -1052,7 +1095,7 @@ TEST_F(NumberAnnotatorTest, AnnotatedZeroPrecededNumbersFindAll) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("Numbers: 0.9 or 09 or 09.9 or 032310"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result, UnorderedElementsAre(
IsAnnotatedSpan(CodepointSpan(9, 12), "number",
@@ -1072,7 +1115,7 @@ TEST_F(NumberAnnotatorTest, ZeroAfterDotFindAll) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("15.0 16.00"), AnnotationUsecase_ANNOTATION_USECASE_RAW,
- &result));
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
@@ -1086,7 +1129,7 @@ TEST_F(NumberAnnotatorTest, NineDotNineFindAll) {
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(number_annotator_.FindAll(
UTF8ToUnicodeText("9.9 9.99 99.99 99.999 99.9999"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result,
UnorderedElementsAre(
diff --git a/native/annotator/number/number_test-include.h b/native/annotator/number/number_test-include.h
index 9de7c86..14fc6f2 100644
--- a/native/annotator/number/number_test-include.h
+++ b/native/annotator/number/number_test-include.h
@@ -17,6 +17,7 @@
#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_NUMBER_NUMBER_TEST_INCLUDE_H_
#define LIBTEXTCLASSIFIER_ANNOTATOR_NUMBER_NUMBER_TEST_INCLUDE_H_
+#include "annotator/model_generated.h"
#include "annotator/number/number.h"
#include "utils/jvm-test-utils.h"
#include "gtest/gtest.h"
@@ -25,17 +26,32 @@ namespace libtextclassifier3 {
namespace test_internal {
class NumberAnnotatorTest : public ::testing::Test {
+ private:
protected:
- NumberAnnotatorTest()
+ explicit NumberAnnotatorTest(ModeFlag enabled_modes = ModeFlag_ALL)
: unilib_(CreateUniLibForTesting()),
- number_annotator_(TestingNumberAnnotatorOptions(), unilib_.get()) {}
+ number_annotator_(TestingNumberAnnotatorOptions(enabled_modes),
+ unilib_.get()) {}
- const NumberAnnotatorOptions* TestingNumberAnnotatorOptions();
+ const NumberAnnotatorOptions* TestingNumberAnnotatorOptions(
+ ModeFlag enabled_modes);
std::unique_ptr<UniLib> unilib_;
NumberAnnotator number_annotator_;
};
+class NumberAnnotatorForAnnotationAndClassificationTest
+ : public NumberAnnotatorTest {
+ protected:
+ NumberAnnotatorForAnnotationAndClassificationTest()
+ : NumberAnnotatorTest(ModeFlag_ANNOTATION_AND_CLASSIFICATION) {}
+};
+
+class NumberAnnotatorForSelectionTest : public NumberAnnotatorTest {
+ protected:
+ NumberAnnotatorForSelectionTest() : NumberAnnotatorTest(ModeFlag_SELECTION) {}
+};
+
} // namespace test_internal
} // namespace libtextclassifier3
diff --git a/native/annotator/person_name/person-name-engine-dummy.h b/native/annotator/person_name/person-name-engine-dummy.h
index 9c83241..44d2821 100644
--- a/native/annotator/person_name/person-name-engine-dummy.h
+++ b/native/annotator/person_name/person-name-engine-dummy.h
@@ -46,7 +46,7 @@ class PersonNameEngine {
}
bool Chunk(const UnicodeText& context_unicode,
- const std::vector<Token>& tokens,
+ const std::vector<Token>& tokens, ModeFlag mode,
std::vector<AnnotatedSpan>* result) const {
return true;
}
diff --git a/native/annotator/person_name/person_name_model.fbs b/native/annotator/person_name/person_name_model.fbs
index b15543f..6ef4a72 100644
--- a/native/annotator/person_name/person_name_model.fbs
+++ b/native/annotator/person_name/person_name_model.fbs
@@ -14,6 +14,8 @@
// limitations under the License.
//
+include "annotator/model.fbs";
+
file_identifier "TC2 ";
// Next ID: 2
@@ -26,7 +28,7 @@ table PersonName {
person_name:string (shared);
}
-// Next ID: 6
+// Next ID: 7
namespace libtextclassifier3;
table PersonNameModel {
// Decides if the person name annotator is enabled.
@@ -52,6 +54,9 @@ table PersonNameModel {
// upper case character and have at least one lower case character.
// required
annotate_capitalized_names_only:bool;
+
+ // Enabled modes.
+ enabled_modes:ModeFlag = ALL;
}
root_type libtextclassifier3.PersonNameModel;
diff --git a/native/annotator/pod_ner/pod-ner-impl.cc b/native/annotator/pod_ner/pod-ner-impl.cc
index 666b7c7..0cb86ee 100644
--- a/native/annotator/pod_ner/pod-ner-impl.cc
+++ b/native/annotator/pod_ner/pod-ner-impl.cc
@@ -398,6 +398,10 @@ bool PodNerAnnotator::AnnotateAroundSpanOfInterest(
std::vector<AnnotatedSpan> *results) const {
TC3_CHECK(results != nullptr);
+ if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
+ return true;
+ }
+
std::vector<int32_t> wordpiece_indices;
std::vector<int32_t> token_starts;
std::vector<Token> tokens;
@@ -470,6 +474,11 @@ bool PodNerAnnotator::SuggestSelection(const UnicodeText &context,
return false;
}
+ if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
+ *result = {};
+ return false;
+ }
+
for (const AnnotatedSpan &annotation : annotations) {
TC3_VLOG(INFO) << "POD NER SuggestSelection: " << annotation;
if (annotation.span.first <= click.first &&
@@ -491,6 +500,10 @@ bool PodNerAnnotator::ClassifyText(const UnicodeText &context,
CodepointSpan click,
ClassificationResult *result) const {
TC3_VLOG(INFO) << "POD NER ClassifyText " << click;
+ if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
+ return false;
+ }
+
std::vector<AnnotatedSpan> annotations;
if (!AnnotateAroundSpanOfInterest(context, click, &annotations)) {
return false;
diff --git a/native/annotator/pod_ner/pod-ner-impl_test.cc b/native/annotator/pod_ner/pod-ner-impl_test.cc
index c7d0bee..5accebd 100644
--- a/native/annotator/pod_ner/pod-ner-impl_test.cc
+++ b/native/annotator/pod_ner/pod-ner-impl_test.cc
@@ -53,7 +53,7 @@ constexpr float kDefaultPriorityScore = 0.5;
class PodNerTest : public testing::Test {
protected:
- PodNerTest() {
+ explicit PodNerTest(ModeFlag enabled_modes = ModeFlag_ALL) {
PodNerModelT model;
model.min_number_of_tokens = kMinNumberOfTokens;
@@ -68,6 +68,7 @@ class PodNerTest : public testing::Test {
GetTestFileContent("annotator/pod_ner/test_data/vocab.txt");
model.word_piece_vocab = std::vector<uint8_t>(
word_piece_vocab_buffer.begin(), word_piece_vocab_buffer.end());
+ model.enabled_modes = enabled_modes;
flatbuffers::FlatBufferBuilder builder;
builder.Finish(PodNerModel::Pack(builder, &model));
@@ -101,6 +102,17 @@ class PodNerTest : public testing::Test {
std::unique_ptr<UniLib> unilib_;
};
+class PodNerForAnnotationAndClassificationTest : public PodNerTest {
+ protected:
+ PodNerForAnnotationAndClassificationTest()
+ : PodNerTest(ModeFlag_ANNOTATION_AND_CLASSIFICATION) {}
+};
+
+class PodNerForSelectionTest : public PodNerTest {
+ protected:
+ PodNerForSelectionTest() : PodNerTest(ModeFlag_SELECTION) {}
+};
+
TEST_F(PodNerTest, AnnotateSmokeTest) {
std::unique_ptr<PodNerAnnotator> annotator =
PodNerAnnotator::Create(model_, *unilib_);
@@ -209,6 +221,18 @@ TEST_F(PodNerTest, AnnotateDefaultCollections) {
}
}
+TEST_F(PodNerForSelectionTest, AnnotateWithDisabledAnnotationReturnsNoResults) {
+ std::unique_ptr<PodNerAnnotator> annotator =
+ PodNerAnnotator::Create(model_, *unilib_);
+ ASSERT_TRUE(annotator != nullptr);
+
+ std::string multi_word_location = "I live in New York";
+ std::vector<AnnotatedSpan> annotations;
+ ASSERT_TRUE(annotator->Annotate(UTF8ToUnicodeText(multi_word_location),
+ &annotations));
+ EXPECT_THAT(annotations, IsEmpty());
+}
+
TEST_F(PodNerTest, AnnotateConfigurableCollections) {
std::unique_ptr<PodNerModelT> unpacked_model(model_->UnPack());
ASSERT_TRUE(unpacked_model != nullptr);
@@ -525,6 +549,18 @@ TEST_F(PodNerTest, SuggestSelectionTest) {
EXPECT_EQ(suggested_span.span, CodepointSpan(kInvalidIndex, kInvalidIndex));
}
+TEST_F(PodNerForAnnotationAndClassificationTest,
+ SuggestSelectionWithDisabledSelectionReturnsNoResults) {
+ std::unique_ptr<PodNerAnnotator> annotator =
+ PodNerAnnotator::Create(model_, *unilib_);
+ ASSERT_TRUE(annotator != nullptr);
+
+ AnnotatedSpan suggested_span;
+ EXPECT_FALSE(annotator->SuggestSelection(
+ UTF8ToUnicodeText("Google New York, in New York"), {7, 10},
+ &suggested_span));
+}
+
TEST_F(PodNerTest, ClassifyTextTest) {
std::unique_ptr<PodNerAnnotator> annotator =
PodNerAnnotator::Create(model_, *unilib_);
@@ -536,6 +572,17 @@ TEST_F(PodNerTest, ClassifyTextTest) {
EXPECT_EQ(result.collection, "location");
}
+TEST_F(PodNerForSelectionTest,
+ ClassifyTextWithDisabledClassificationReturnsFalse) {
+ std::unique_ptr<PodNerAnnotator> annotator =
+ PodNerAnnotator::Create(model_, *unilib_);
+ ASSERT_TRUE(annotator != nullptr);
+
+ ClassificationResult result;
+ ASSERT_FALSE(annotator->ClassifyText(UTF8ToUnicodeText("We met in New York"),
+ {10, 18}, &result));
+}
+
TEST_F(PodNerTest, ThreadSafety) {
std::unique_ptr<PodNerAnnotator> annotator =
PodNerAnnotator::Create(model_, *unilib_);
diff --git a/native/annotator/translate/translate.cc b/native/annotator/translate/translate.cc
index 2c5a43c..e38109c 100644
--- a/native/annotator/translate/translate.cc
+++ b/native/annotator/translate/translate.cc
@@ -21,6 +21,7 @@
#include "annotator/collections.h"
#include "annotator/entity-data_generated.h"
+#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "lang_id/lang-id-wrapper.h"
#include "utils/base/logging.h"
@@ -34,6 +35,10 @@ bool TranslateAnnotator::ClassifyText(
const UnicodeText& context, CodepointSpan selection_indices,
const std::string& user_familiar_language_tags,
ClassificationResult* classification_result) const {
+ if (!(options_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
+ return false;
+ }
+
std::vector<TranslateAnnotator::LanguageConfidence> confidences;
if (options_->algorithm() ==
TranslateAnnotatorOptions_::Algorithm::Algorithm_BACKOFF) {
diff --git a/native/annotator/translate/translate_test.cc b/native/annotator/translate/translate_test.cc
index 5c4a63f..90227ec 100644
--- a/native/annotator/translate/translate_test.cc
+++ b/native/annotator/translate/translate_test.cc
@@ -31,20 +31,33 @@ namespace {
using testing::AllOf;
using testing::Field;
-const TranslateAnnotatorOptions* TestingTranslateAnnotatorOptions() {
- static const flatbuffers::DetachedBuffer* options_data = []() {
- TranslateAnnotatorOptionsT options;
- options.enabled = true;
- options.algorithm = TranslateAnnotatorOptions_::Algorithm_BACKOFF;
- options.backoff_options.reset(
- new TranslateAnnotatorOptions_::BackoffOptionsT());
-
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(TranslateAnnotatorOptions::Pack(builder, &options));
- return new flatbuffers::DetachedBuffer(builder.Release());
- }();
-
- return flatbuffers::GetRoot<TranslateAnnotatorOptions>(options_data->data());
+const flatbuffers::DetachedBuffer* CreateOptionsData(ModeFlag enabled_modes) {
+ TranslateAnnotatorOptionsT options;
+ options.enabled = true;
+ options.algorithm = TranslateAnnotatorOptions_::Algorithm_BACKOFF;
+ options.backoff_options.reset(
+ new TranslateAnnotatorOptions_::BackoffOptionsT());
+ options.enabled_modes = enabled_modes;
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(TranslateAnnotatorOptions::Pack(builder, &options));
+ return new flatbuffers::DetachedBuffer(builder.Release());
+}
+
+const TranslateAnnotatorOptions* TestingTranslateAnnotatorOptions(
+ ModeFlag enabled_modes) {
+ static const flatbuffers::DetachedBuffer* options_data_classification =
+ CreateOptionsData(ModeFlag_CLASSIFICATION);
+ static const flatbuffers::DetachedBuffer* options_data_none =
+ CreateOptionsData(ModeFlag_NONE);
+
+ if (enabled_modes == ModeFlag_CLASSIFICATION) {
+ return flatbuffers::GetRoot<TranslateAnnotatorOptions>(
+ options_data_classification->data());
+ } else {
+ return flatbuffers::GetRoot<TranslateAnnotatorOptions>(
+ options_data_none->data());
+ }
}
class TestingTranslateAnnotator : public TranslateAnnotator {
@@ -60,11 +73,12 @@ std::string GetModelPath() { return GetTestDataPath("annotator/test_data/"); }
class TranslateAnnotatorTest : public ::testing::Test {
protected:
- TranslateAnnotatorTest()
+ explicit TranslateAnnotatorTest(
+ ModeFlag enabled_modes = ModeFlag_CLASSIFICATION)
: INIT_UNILIB_FOR_TESTING(unilib_),
langid_model_(libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile(
GetModelPath() + "lang_id.smfb")),
- translate_annotator_(TestingTranslateAnnotatorOptions(),
+ translate_annotator_(TestingTranslateAnnotatorOptions(enabled_modes),
langid_model_.get(), &unilib_) {}
UniLib unilib_;
@@ -72,6 +86,11 @@ class TranslateAnnotatorTest : public ::testing::Test {
TestingTranslateAnnotator translate_annotator_;
};
+class TranslateAnnotatorForNoneTest : public TranslateAnnotatorTest {
+ protected:
+ TranslateAnnotatorForNoneTest() : TranslateAnnotatorTest(ModeFlag_NONE) {}
+};
+
TEST_F(TranslateAnnotatorTest, WhenSpeaksEnglishGetsTranslateActionForCzech) {
ClassificationResult classification;
EXPECT_TRUE(translate_annotator_.ClassifyText(
@@ -110,6 +129,13 @@ TEST_F(TranslateAnnotatorTest, EntityDataIsSet) {
predictions->Get(1)->confidence_score());
}
+TEST_F(TranslateAnnotatorForNoneTest,
+ ClassifyTextDisabledClassificationReturnsFalse) {
+ ClassificationResult classification;
+ EXPECT_FALSE(translate_annotator_.ClassifyText(
+ UTF8ToUnicodeText("学校"), {0, 2}, "en", &classification));
+}
+
TEST_F(TranslateAnnotatorTest,
WhenSpeaksEnglishDoesntGetTranslateActionForEnglish) {
ClassificationResult classification;
diff --git a/native/annotator/types.h b/native/annotator/types.h
index ada301c..8485d44 100644
--- a/native/annotator/types.h
+++ b/native/annotator/types.h
@@ -65,6 +65,7 @@ struct CodepointSpan {
CodepointSpan(CodepointIndex start, CodepointIndex end)
: first(start), second(end) {}
+ CodepointSpan(const CodepointSpan& other) = default;
CodepointSpan& operator=(const CodepointSpan& other) = default;
bool operator==(const CodepointSpan& other) const {
@@ -439,6 +440,8 @@ struct ClassificationResult {
contact_nickname, contact_email_address, contact_phone_number,
contact_account_type, contact_account_name, contact_id,
contact_alternate_name;
+ int64 contact_recognition_source;
+ float contact_neural_match_score;
std::string app_name, app_package_name;
int64 numeric_value;
double numeric_double_value;
@@ -577,12 +580,18 @@ struct ClassificationOptions : public BaseOptions, public DatetimeOptions {
std::string user_familiar_language_tags;
// If true, trigger dictionary on words that are of beginner level.
bool trigger_dictionary_on_beginner_words = false;
+ // If true, generate *Add* contact intent for email/phone entity.
+ bool enable_add_contact_intent;
+ // If true, generate *Search* intent for named entities.
+ bool enable_search_intent;
bool operator==(const ClassificationOptions& other) const {
return this->user_familiar_language_tags ==
other.user_familiar_language_tags &&
this->trigger_dictionary_on_beginner_words ==
other.trigger_dictionary_on_beginner_words &&
+ this->enable_add_contact_intent == other.enable_add_contact_intent &&
+ this->enable_search_intent == other.enable_search_intent &&
BaseOptions::operator==(other) && DatetimeOptions::operator==(other);
}
};
diff --git a/native/annotator/vocab/vocab-annotator-impl.cc b/native/annotator/vocab/vocab-annotator-impl.cc
index 4b5cc73..b464f54 100644
--- a/native/annotator/vocab/vocab-annotator-impl.cc
+++ b/native/annotator/vocab/vocab-annotator-impl.cc
@@ -61,6 +61,9 @@ bool VocabAnnotator::Annotate(
const UnicodeText& context,
const std::vector<Locale> detected_text_language_tags,
bool trigger_on_beginner_words, std::vector<AnnotatedSpan>* results) const {
+ if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
+ return true;
+ }
std::vector<Token> tokens = feature_processor_.Tokenize(context);
for (const Token& token : tokens) {
ClassificationResult classification_result;
@@ -90,6 +93,9 @@ bool VocabAnnotator::ClassifyTextInternal(
const std::vector<Locale> detected_text_language_tags,
bool trigger_on_beginner_words, ClassificationResult* classification_result,
CodepointSpan* classified_span) const {
+ if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
+ return false;
+ }
if (vocab_level_table_ == nullptr) {
return false;
}
diff --git a/native/lang_id/common/embedding-feature-extractor.h b/native/lang_id/common/embedding-feature-extractor.h
index ba4f858..8363321 100644
--- a/native/lang_id/common/embedding-feature-extractor.h
+++ b/native/lang_id/common/embedding-feature-extractor.h
@@ -25,6 +25,8 @@
#include "lang_id/common/fel/task-context.h"
#include "lang_id/common/fel/workspace.h"
#include "lang_id/common/lite_base/attributes.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
namespace libtextclassifier3 {
namespace mobile {
@@ -46,7 +48,7 @@ class GenericEmbeddingFeatureExtractor {
//
// |arg_prefix| is a string prefix for the relevant TaskContext parameters, to
// avoid name clashes. See GetParamName().
- explicit GenericEmbeddingFeatureExtractor(const std::string &arg_prefix)
+ explicit GenericEmbeddingFeatureExtractor(absl::string_view arg_prefix)
: arg_prefix_(arg_prefix) {}
virtual ~GenericEmbeddingFeatureExtractor() {}
@@ -70,11 +72,8 @@ class GenericEmbeddingFeatureExtractor {
}
// Get parameter name by concatenating the prefix and the original name.
- std::string GetParamName(const std::string &param_name) const {
- std::string full_name = arg_prefix_;
- full_name.push_back('_');
- full_name.append(param_name);
- return full_name;
+ std::string GetParamName(absl::string_view param_name) const {
+ return absl::StrCat(arg_prefix_, "_", param_name);
}
private:
@@ -108,7 +107,7 @@ class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor {
//
// |arg_prefix| is a string prefix for the relevant TaskContext parameters, to
// avoid name clashes. See GetParamName().
- explicit EmbeddingFeatureExtractor(const std::string &arg_prefix)
+ explicit EmbeddingFeatureExtractor(absl::string_view arg_prefix)
: GenericEmbeddingFeatureExtractor(arg_prefix) {}
// Sets up all predicate maps, feature extractors, and flags.
@@ -117,7 +116,7 @@ class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor {
return false;
}
feature_extractors_.resize(embedding_fml().size());
- for (int i = 0; i < embedding_fml().size(); ++i) {
+ for (size_t i = 0; i < embedding_fml().size(); ++i) {
feature_extractors_[i].reset(new EXTRACTOR());
if (!feature_extractors_[i]->Parse(embedding_fml()[i])) return false;
if (!feature_extractors_[i]->Setup(context)) return false;
@@ -158,7 +157,7 @@ class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor {
std::vector<FeatureVector> *features) const {
// DCHECK(features != nullptr);
// DCHECK_EQ(features->size(), feature_extractors_.size());
- for (int i = 0; i < feature_extractors_.size(); ++i) {
+ for (size_t i = 0; i < feature_extractors_.size(); ++i) {
(*features)[i].clear();
feature_extractors_[i]->ExtractFeatures(workspaces, obj, args...,
&(*features)[i]);
diff --git a/native/lang_id/common/embedding-feature-interface.h b/native/lang_id/common/embedding-feature-interface.h
index 75d0c98..26d574b 100644
--- a/native/lang_id/common/embedding-feature-interface.h
+++ b/native/lang_id/common/embedding-feature-interface.h
@@ -25,6 +25,7 @@
#include "lang_id/common/fel/task-context.h"
#include "lang_id/common/fel/workspace.h"
#include "lang_id/common/lite_base/attributes.h"
+#include "absl/strings/string_view.h"
namespace libtextclassifier3 {
namespace mobile {
@@ -36,7 +37,7 @@ class EmbeddingFeatureInterface {
//
// |arg_prefix| is a string prefix for the TaskContext parameters, passed to
// |the underlying EmbeddingFeatureExtractor.
- explicit EmbeddingFeatureInterface(const std::string &arg_prefix)
+ explicit EmbeddingFeatureInterface(absl::string_view arg_prefix)
: feature_extractor_(arg_prefix) {}
// Sets up feature extractors and flags for processing (inference).
diff --git a/native/lang_id/common/embedding-network.cc b/native/lang_id/common/embedding-network.cc
index 49c9ca0..0fe35d4 100644
--- a/native/lang_id/common/embedding-network.cc
+++ b/native/lang_id/common/embedding-network.cc
@@ -153,7 +153,7 @@ void EmbeddingNetwork::ConcatEmbeddings(
concat->resize(concat_layer_size_);
// "es_index" stands for "embedding space index".
- for (int es_index = 0; es_index < feature_vectors.size(); ++es_index) {
+ for (size_t es_index = 0; es_index < feature_vectors.size(); ++es_index) {
const int concat_offset = concat_offset_[es_index];
const EmbeddingNetworkParams::Matrix &embedding_matrix =
@@ -167,7 +167,8 @@ void EmbeddingNetwork::ConcatEmbeddings(
for (int fi = 0; fi < num_features; ++fi) {
const FeatureType *feature_type = feature_vector.type(fi);
int feature_offset = concat_offset + feature_type->base() * embedding_dim;
- SAFTM_CHECK_LE(feature_offset + embedding_dim, concat->size());
+ SAFTM_CHECK_LE(feature_offset + embedding_dim,
+ static_cast<int>(concat->size()));
// Weighted embeddings will be added starting from this address.
float *concat_ptr = concat->data() + feature_offset;
@@ -257,7 +258,7 @@ void EmbeddingNetwork::ComputeFinalScores(
ConcatEmbeddings(features, &input);
if (!extra_inputs.empty()) {
input.reserve(input.size() + extra_inputs.size());
- for (int i = 0; i < extra_inputs.size(); i++) {
+ for (size_t i = 0; i < extra_inputs.size(); i++) {
input.push_back(extra_inputs[i]);
}
}
@@ -281,8 +282,8 @@ void EmbeddingNetwork::ComputeFinalScores(
v_out = &(storage[i % 2]);
}
const bool apply_relu = i > 0;
- SparseReluProductPlusBias(
- apply_relu, layer_weights_[i], layer_bias_[i], *v_in, v_out);
+ SparseReluProductPlusBias(apply_relu, layer_weights_[i], layer_bias_[i],
+ *v_in, v_out);
v_in = v_out;
}
}
diff --git a/native/lang_id/common/fel/feature-descriptors.h b/native/lang_id/common/fel/feature-descriptors.h
index 3bdc2fa..f9536d9 100644
--- a/native/lang_id/common/fel/feature-descriptors.h
+++ b/native/lang_id/common/fel/feature-descriptors.h
@@ -24,6 +24,7 @@
#include "lang_id/common/lite_base/integral-types.h"
#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/lite_base/macros.h"
+#include "absl/strings/string_view.h"
namespace libtextclassifier3 {
namespace mobile {
@@ -33,10 +34,10 @@ class Parameter {
public:
Parameter() {}
- void set_name(const std::string &name) { name_ = name; }
+ void set_name(absl::string_view name) { name_ = std::string(name); }
const std::string &name() const { return name_; }
- void set_value(const std::string &value) { value_ = value; }
+ void set_value(absl::string_view value) { value_ = std::string(value); }
const std::string &value() const { return value_; }
private:
@@ -52,13 +53,13 @@ class FeatureFunctionDescriptor {
// Accessors for the feature function type. The function type is the string
// that the feature extractor code is registered under.
- void set_type(const std::string &type) { type_ = type; }
+ void set_type(absl::string_view type) { type_ = std::string(type); }
const std::string &type() const { return type_; }
// Accessors for the feature function name. The function name (if available)
// is used for some log messages. Otherwise, a more precise, but also more
// verbose name based on the feature specification is used.
- void set_name(const std::string &name) { name_ = name; }
+ void set_name(absl::string_view name) { name_ = std::string(name); }
const std::string &name() const { return name_; }
// Accessors for the default (name-less) parameter.
diff --git a/native/lang_id/common/fel/feature-extractor.h b/native/lang_id/common/fel/feature-extractor.h
index c09e1eb..805272b 100644
--- a/native/lang_id/common/fel/feature-extractor.h
+++ b/native/lang_id/common/fel/feature-extractor.h
@@ -52,6 +52,7 @@
#include "lang_id/common/lite_base/macros.h"
#include "lang_id/common/registry.h"
#include "lang_id/common/stl-util.h"
+#include "absl/strings/string_view.h"
namespace libtextclassifier3 {
namespace mobile {
@@ -261,7 +262,7 @@ class GenericFeatureFunction {
// Returns/sets/clears function name prefix.
const std::string &prefix() const { return prefix_; }
- void set_prefix(const std::string &prefix) { prefix_ = prefix; }
+ void set_prefix(absl::string_view prefix) { prefix_ = std::string(prefix); }
protected:
// Returns the feature type for single-type feature functions.
@@ -341,7 +342,7 @@ class FeatureFunction
// the relevant cc_library was not linked-in).
static Self *Instantiate(const GenericFeatureExtractor *extractor,
const FeatureFunctionDescriptor *fd,
- const std::string &prefix) {
+ absl::string_view prefix) {
Self *f = Self::Create(fd->type());
if (f != nullptr) {
f->set_extractor(extractor);
@@ -440,7 +441,7 @@ class NestedFeatureFunction : public FeatureFunction<OBJ, ARGS...> {
SAFTM_MUST_USE_RESULT static bool CreateNested(
const GenericFeatureExtractor *extractor,
const FeatureFunctionDescriptor *fd, std::vector<NES *> *functions,
- const std::string &prefix) {
+ absl::string_view prefix) {
for (int i = 0; i < fd->feature_size(); ++i) {
const FeatureFunctionDescriptor &sub = fd->feature(i);
NES *f = NES::Instantiate(extractor, &sub, prefix);
@@ -614,7 +615,7 @@ class FeatureExtractor : public GenericFeatureExtractor {
result->reserve(this->feature_types());
// Extract features.
- for (int i = 0; i < functions_.size(); ++i) {
+ for (size_t i = 0; i < functions_.size(); ++i) {
functions_[i]->Evaluate(workspaces, object, args..., result);
}
}
@@ -636,7 +637,7 @@ class FeatureExtractor : public GenericFeatureExtractor {
// Collect all feature types used in the feature extractor.
void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
- for (int i = 0; i < functions_.size(); ++i) {
+ for (size_t i = 0; i < functions_.size(); ++i) {
functions_[i]->GetFeatureTypes(types);
}
}
diff --git a/native/lang_id/common/fel/feature-types.h b/native/lang_id/common/fel/feature-types.h
index ae422af..308ee90 100644
--- a/native/lang_id/common/fel/feature-types.h
+++ b/native/lang_id/common/fel/feature-types.h
@@ -27,6 +27,8 @@
#include "lang_id/common/lite_base/integral-types.h"
#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/lite_strings/str-cat.h"
+#include "absl/strings/match.h"
+#include "absl/strings/string_view.h"
namespace libtextclassifier3 {
namespace mobile {
@@ -44,10 +46,10 @@ typedef Predicate FeatureValue;
class FeatureType {
public:
// Initializes a feature type.
- explicit FeatureType(const std::string &name)
+ explicit FeatureType(absl::string_view name)
: name_(name),
base_(0),
- is_continuous_(name.find("continuous") != std::string::npos) {}
+ is_continuous_(absl::StrContains(name, "continuous")) {}
virtual ~FeatureType() {}
@@ -91,7 +93,7 @@ class FeatureType {
// };
class EnumFeatureType : public FeatureType {
public:
- EnumFeatureType(const std::string &name,
+ EnumFeatureType(absl::string_view name,
const std::map<FeatureValue, std::string> &value_names)
: FeatureType(name), value_names_(value_names) {
for (const auto &pair : value_names) {
@@ -127,8 +129,8 @@ class EnumFeatureType : public FeatureType {
// Feature type for binary features.
class BinaryFeatureType : public FeatureType {
public:
- BinaryFeatureType(const std::string &name, const std::string &off,
- const std::string &on)
+ BinaryFeatureType(absl::string_view name, absl::string_view off,
+ absl::string_view on)
: FeatureType(name), off_(off), on_(on) {}
// Returns the feature name for a given feature value.
@@ -151,7 +153,7 @@ class BinaryFeatureType : public FeatureType {
class NumericFeatureType : public FeatureType {
public:
// Initializes numeric feature.
- NumericFeatureType(const std::string &name, FeatureValue size)
+ NumericFeatureType(absl::string_view name, FeatureValue size)
: FeatureType(name), size_(size) {}
// Returns numeric feature value.
@@ -171,7 +173,7 @@ class NumericFeatureType : public FeatureType {
// Feature type for byte features, including an "outside" value.
class ByteFeatureType : public NumericFeatureType {
public:
- explicit ByteFeatureType(const std::string &name)
+ explicit ByteFeatureType(absl::string_view name)
: NumericFeatureType(name, 257) {}
std::string GetFeatureValueName(FeatureValue value) const override {
diff --git a/native/lang_id/common/fel/fel-parser.cc b/native/lang_id/common/fel/fel-parser.cc
index 2682941..c393ae8 100644
--- a/native/lang_id/common/fel/fel-parser.cc
+++ b/native/lang_id/common/fel/fel-parser.cc
@@ -22,6 +22,7 @@
#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/lite_strings/numbers.h"
+#include "absl/strings/string_view.h"
namespace libtextclassifier3 {
namespace mobile {
@@ -47,9 +48,9 @@ inline bool IsValidCharInsideNumber(char c) {
}
} // namespace
-bool FELParser::Initialize(const std::string &source) {
+bool FELParser::Initialize(absl::string_view source) {
// Initialize parser state.
- source_ = source;
+ source_ = std::string(source);
current_ = source_.begin();
item_start_ = line_start_ = current_;
line_number_ = item_line_number_ = 1;
diff --git a/native/lang_id/common/fel/fel-parser.h b/native/lang_id/common/fel/fel-parser.h
index d2c454c..bb33eaf 100644
--- a/native/lang_id/common/fel/fel-parser.h
+++ b/native/lang_id/common/fel/fel-parser.h
@@ -58,7 +58,7 @@ class FELParser {
private:
// Initializes the parser with the source text.
// Returns true on success, false on syntax error.
- bool Initialize(const std::string &source);
+ bool Initialize(absl::string_view source);
// Outputs an error message, with context info.
void ReportError(const std::string &error_message);
diff --git a/native/lang_id/common/fel/workspace.h b/native/lang_id/common/fel/workspace.h
index 2ac5b26..71d1550 100644
--- a/native/lang_id/common/fel/workspace.h
+++ b/native/lang_id/common/fel/workspace.h
@@ -31,6 +31,7 @@
#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/lite_base/macros.h"
+#include "absl/strings/string_view.h"
namespace libtextclassifier3 {
namespace mobile {
@@ -71,7 +72,7 @@ class WorkspaceRegistry {
// Returns the index of a named workspace, adding it to the registry first
// if necessary.
template <class W>
- int Request(const std::string &name) {
+ int Request(absl::string_view name) {
const int id = TypeId<W>::type_id;
max_workspace_id_ = std::max(id, max_workspace_id_);
workspace_types_[id] = W::TypeName();
@@ -79,7 +80,7 @@ class WorkspaceRegistry {
for (int i = 0; i < names.size(); ++i) {
if (names[i] == name) return i;
}
- names.push_back(name);
+ names.push_back(std::string(name));
return names.size() - 1;
}
diff --git a/native/lang_id/common/flatbuffers/model-utils.cc b/native/lang_id/common/flatbuffers/model-utils.cc
index 8efa386..592f616 100644
--- a/native/lang_id/common/flatbuffers/model-utils.cc
+++ b/native/lang_id/common/flatbuffers/model-utils.cc
@@ -22,6 +22,7 @@
#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/math/checksum.h"
+#include "absl/strings/string_view.h"
namespace libtextclassifier3 {
namespace saft_fbs {
@@ -64,7 +65,7 @@ const Model *GetVerifiedModelFromBytes(const char *data, size_t num_bytes) {
return model;
}
-const ModelInput *GetInputByName(const Model *model, const std::string &name) {
+const ModelInput *GetInputByName(const Model *model, absl::string_view name) {
if (model == nullptr) {
SAFTM_LOG(ERROR) << "GetInputByName called with model == nullptr";
return nullptr;
diff --git a/native/lang_id/common/flatbuffers/model-utils.h b/native/lang_id/common/flatbuffers/model-utils.h
index cf33dd5..16494b1 100644
--- a/native/lang_id/common/flatbuffers/model-utils.h
+++ b/native/lang_id/common/flatbuffers/model-utils.h
@@ -25,6 +25,7 @@
#include "lang_id/common/flatbuffers/model_generated.h"
#include "lang_id/common/lite_base/integral-types.h"
#include "lang_id/common/lite_strings/stringpiece.h"
+#include "absl/strings/string_view.h"
namespace libtextclassifier3 {
namespace saft_fbs {
@@ -46,7 +47,7 @@ inline const Model *GetVerifiedModelFromBytes(mobile::StringPiece bytes) {
// Returns the |model| input with specified |name|. Returns nullptr if no such
// input exists. If |model| contains multiple inputs with that |name|, returns
// the first one (model builders should avoid building such models).
-const ModelInput *GetInputByName(const Model *model, const std::string &name);
+const ModelInput *GetInputByName(const Model *model, absl::string_view name);
// Returns a StringPiece pointing to the bytes for the content of |input|. In
// case of errors, returns StringPiece(nullptr, 0).
diff --git a/native/lang_id/common/math/algorithm.h b/native/lang_id/common/math/algorithm.h
index 5c8596b..e2f7179 100644
--- a/native/lang_id/common/math/algorithm.h
+++ b/native/lang_id/common/math/algorithm.h
@@ -81,7 +81,7 @@ std::vector<int> GetTopKIndices(int k, const std::vector<T> &v,
return std::vector<int>();
}
- if (k > v.size()) {
+ if (static_cast<size_t>(k) > v.size()) {
k = v.size();
}
@@ -114,7 +114,7 @@ std::vector<int> GetTopKIndices(int k, const std::vector<T> &v,
// indicated by the indices from |heap|.
//
// Invariant C: |heap| is a max heap, according to order rev_vcomp.
- for (int i = k; i < v.size(); ++i) {
+ for (size_t i = k; i < v.size(); ++i) {
// We have to update |heap| iff v[i] is larger than the smallest of the
// top-k seen so far. This test is easy to do, due to Invariant B above.
if (smaller(v[heap[0]], v[i])) {
diff --git a/native/lang_id/common/math/softmax.cc b/native/lang_id/common/math/softmax.cc
index 249ed57..c2f7e89 100644
--- a/native/lang_id/common/math/softmax.cc
+++ b/native/lang_id/common/math/softmax.cc
@@ -26,7 +26,7 @@ namespace libtextclassifier3 {
namespace mobile {
float ComputeSoftmaxProbability(const std::vector<float> &scores, int label) {
- if ((label < 0) || (label >= scores.size())) {
+ if ((label < 0) || (static_cast<size_t>(label) >= scores.size())) {
SAFTM_LOG(ERROR) << "label " << label << " outside range "
<< "[0, " << scores.size() << ")";
return 0.0f;
@@ -43,8 +43,8 @@ float ComputeSoftmaxProbability(const std::vector<float> &scores, int label) {
// which saves two calls to exp().
const float label_score = scores[label];
float denominator = 1.0f; // Contribution of i == label.
- for (int i = 0; i < scores.size(); ++i) {
- if (i == label) continue;
+ for (size_t i = 0; i < scores.size(); ++i) {
+ if (static_cast<int>(i) == label) continue;
const float delta_score = scores[i] - label_score;
// TODO(salcianu): one can optimize the test below, to avoid any float
@@ -94,7 +94,7 @@ std::vector<float> ComputeSoftmax(const std::vector<float> &scores,
denominator += exp_score;
}
- for (int i = 0; i < scores.size(); ++i) {
+ for (size_t i = 0; i < scores.size(); ++i) {
softmax.push_back(exp_scores[i] / denominator);
}
return softmax;
diff --git a/native/lang_id/features/char-ngram-feature.cc b/native/lang_id/features/char-ngram-feature.cc
index 31faf2f..c367f71 100644
--- a/native/lang_id/features/char-ngram-feature.cc
+++ b/native/lang_id/features/char-ngram-feature.cc
@@ -16,6 +16,7 @@
#include "lang_id/features/char-ngram-feature.h"
+#include <mutex>
#include <string>
#include <utility>
#include <vector>
@@ -64,8 +65,8 @@ bool ContinuousBagOfNgramsFunction::Init(TaskContext *context) {
int ContinuousBagOfNgramsFunction::ComputeNgramCounts(
const LightSentence &sentence) const {
- SAFTM_CHECK_EQ(counts_.size(), ngram_id_dimension_);
- SAFTM_CHECK_EQ(non_zero_count_indices_.size(), 0);
+ SAFTM_CHECK_EQ(static_cast<int>(counts_.size()), ngram_id_dimension_);
+ SAFTM_CHECK_EQ(non_zero_count_indices_.size(), 0u);
int total_count = 0;
diff --git a/native/lang_id/features/char-ngram-feature.h b/native/lang_id/features/char-ngram-feature.h
index db0f83e..5b16e30 100644
--- a/native/lang_id/features/char-ngram-feature.h
+++ b/native/lang_id/features/char-ngram-feature.h
@@ -19,6 +19,7 @@
#include <mutex> // NOLINT: see comments for state_mutex_
#include <string>
+#include <vector>
#include "lang_id/common/fel/feature-extractor.h"
#include "lang_id/common/fel/task-context.h"
diff --git a/native/lang_id/features/relevant-script-feature.cc b/native/lang_id/features/relevant-script-feature.cc
index e88b328..f24c23c 100644
--- a/native/lang_id/features/relevant-script-feature.cc
+++ b/native/lang_id/features/relevant-script-feature.cc
@@ -17,6 +17,7 @@
#include "lang_id/features/relevant-script-feature.h"
#include <string>
+#include <vector>
#include "lang_id/common/fel/feature-types.h"
#include "lang_id/common/fel/task-context.h"
diff --git a/native/lang_id/lang-id.cc b/native/lang_id/lang-id.cc
index f7c66f7..67de5fe 100644
--- a/native/lang_id/lang-id.cc
+++ b/native/lang_id/lang-id.cc
@@ -243,7 +243,7 @@ class LangIdImpl {
// Returns language code for a softmax label. See comments for languages_
// field. If label is out of range, returns LangId::kUnknownLanguageCode.
std::string GetLanguageForSoftmaxLabel(int label) const {
- if ((label >= 0) && (label < languages_.size())) {
+ if ((label >= 0) && (static_cast<size_t>(label) < languages_.size())) {
return languages_[label];
} else {
SAFTM_LOG(ERROR) << "Softmax label " << label << " outside range [0, "
diff --git a/native/models/actions_suggestions.en.model b/native/models/actions_suggestions.en.model
index 74422f6..0360aad 100755
--- a/native/models/actions_suggestions.en.model
+++ b/native/models/actions_suggestions.en.model
Binary files differ
diff --git a/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h b/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h
index 6d84ca4..62f0f7d 100644
--- a/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h
+++ b/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h
@@ -28,8 +28,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef LEARNING_EXPANDER_POD_DEEP_POD_TFLITE_HANDLERS_LAYER_NORM_H_
-#define LEARNING_EXPANDER_POD_DEEP_POD_TFLITE_HANDLERS_LAYER_NORM_H_
+#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_LAYER_NORM_H_
+#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_LAYER_NORM_H_
#include "tensorflow/lite/kernels/register.h"
@@ -43,4 +43,4 @@ TfLiteRegistration* Register_LAYER_NORM();
} // namespace ops
} // namespace seq_flow_lite
-#endif // LEARNING_EXPANDER_POD_DEEP_POD_TFLITE_HANDLERS_LAYER_NORM_H_
+#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_LAYER_NORM_H_
diff --git a/native/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h b/native/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h
index 7f2db41..a6c70b8 100644
--- a/native/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h
+++ b/native/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h
@@ -28,8 +28,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_QUANTIZATION_UTIL_H_
-#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_QUANTIZATION_UTIL_H_
+#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_QUANTIZATION_UTIL_H_
+#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_QUANTIZATION_UTIL_H_
#include <algorithm>
#include <cmath>
@@ -66,4 +66,4 @@ inline uint8_t PodQuantize(float value, int32_t zero_point,
} // namespace seq_flow_lite
-#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_QUANTIZATION_UTIL_H_
+#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_QUANTIZATION_UTIL_H_
diff --git a/native/utils/bert_tokenizer_test.cc b/native/utils/bert_tokenizer_test.cc
index 5ec79a2..c6611b1 100644
--- a/native/utils/bert_tokenizer_test.cc
+++ b/native/utils/bert_tokenizer_test.cc
@@ -16,6 +16,8 @@
#include "utils/bert_tokenizer.h"
+#include <memory>
+
#include "utils/test-data-test-utils.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -37,14 +39,14 @@ TEST(BertTokenizerTest, TestTokenizerCreationFromBuffer) {
std::string buffer = GetTestFileContent(kTestVocabPath);
auto tokenizer =
- absl::make_unique<BertTokenizer>(buffer.data(), buffer.size());
+ std::make_unique<BertTokenizer>(buffer.data(), buffer.size());
AssertTokenizerResults(std::move(tokenizer));
}
TEST(BertTokenizerTest, TestTokenizerCreationFromFile) {
auto tokenizer =
- absl::make_unique<BertTokenizer>(GetTestDataPath(kTestVocabPath));
+ std::make_unique<BertTokenizer>(GetTestDataPath(kTestVocabPath));
AssertTokenizerResults(std::move(tokenizer));
}
@@ -55,14 +57,14 @@ TEST(BertTokenizerTest, TestTokenizerCreationFromVector) {
vocab.emplace_back("'");
vocab.emplace_back("m");
vocab.emplace_back("question");
- auto tokenizer = absl::make_unique<BertTokenizer>(vocab);
+ auto tokenizer = std::make_unique<BertTokenizer>(vocab);
AssertTokenizerResults(std::move(tokenizer));
}
TEST(BertTokenizerTest, TestTokenizerMultipleRows) {
auto tokenizer =
- absl::make_unique<BertTokenizer>(GetTestDataPath(kTestVocabPath));
+ std::make_unique<BertTokenizer>(GetTestDataPath(kTestVocabPath));
auto results = tokenizer->Tokenize("i'm questionansweraskask");
@@ -72,7 +74,7 @@ TEST(BertTokenizerTest, TestTokenizerMultipleRows) {
TEST(BertTokenizerTest, TestTokenizeIntoWordpieces) {
auto tokenizer =
- absl::make_unique<BertTokenizer>(GetTestDataPath(kTestVocabPath));
+ std::make_unique<BertTokenizer>(GetTestDataPath(kTestVocabPath));
auto results = tokenizer->TokenizeIntoWordpieces("i'm questionansweraskask");
@@ -85,7 +87,7 @@ TEST(BertTokenizerTest, TestTokenizeIntoWordpieces) {
TEST(BertTokenizerTest, TestTokenizeIntoWordpiecesLongNonAscii) {
auto tokenizer =
- absl::make_unique<BertTokenizer>(GetTestDataPath(kTestVocabPath));
+ std::make_unique<BertTokenizer>(GetTestDataPath(kTestVocabPath));
std::string token;
for (int i = 0; i < 100; ++i) {
@@ -105,7 +107,7 @@ TEST(BertTokenizerTest, TestTokenizerUnknownTokens) {
vocab.emplace_back("'");
vocab.emplace_back("m");
vocab.emplace_back("question");
- auto tokenizer = absl::make_unique<BertTokenizer>(vocab);
+ auto tokenizer = std::make_unique<BertTokenizer>(vocab);
auto results = tokenizer->Tokenize("i'm questionansweraskask");
@@ -119,7 +121,7 @@ TEST(BertTokenizerTest, TestLookupId) {
vocab.emplace_back("'");
vocab.emplace_back("m");
vocab.emplace_back("question");
- auto tokenizer = absl::make_unique<BertTokenizer>(vocab);
+ auto tokenizer = std::make_unique<BertTokenizer>(vocab);
int i;
ASSERT_FALSE(tokenizer->LookupId("iDontExist", &i));
@@ -140,7 +142,7 @@ TEST(BertTokenizerTest, TestLookupWord) {
vocab.emplace_back("'");
vocab.emplace_back("m");
vocab.emplace_back("question");
- auto tokenizer = absl::make_unique<BertTokenizer>(vocab);
+ auto tokenizer = std::make_unique<BertTokenizer>(vocab);
absl::string_view result;
ASSERT_FALSE(tokenizer->LookupWord(6, &result));
@@ -161,7 +163,7 @@ TEST(BertTokenizerTest, TestContains) {
vocab.emplace_back("'");
vocab.emplace_back("m");
vocab.emplace_back("question");
- auto tokenizer = absl::make_unique<BertTokenizer>(vocab);
+ auto tokenizer = std::make_unique<BertTokenizer>(vocab);
bool result;
tokenizer->Contains("iDontExist", &result);
@@ -183,7 +185,7 @@ TEST(BertTokenizerTest, TestLVocabularySize) {
vocab.emplace_back("'");
vocab.emplace_back("m");
vocab.emplace_back("question");
- auto tokenizer = absl::make_unique<BertTokenizer>(vocab);
+ auto tokenizer = std::make_unique<BertTokenizer>(vocab);
ASSERT_EQ(tokenizer->VocabularySize(), 4);
}
diff --git a/native/utils/flatbuffers/flatbuffers.h b/native/utils/flatbuffers/flatbuffers.h
index 1bb739b..c1f583a 100644
--- a/native/utils/flatbuffers/flatbuffers.h
+++ b/native/utils/flatbuffers/flatbuffers.h
@@ -19,6 +19,7 @@
#ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_FLATBUFFERS_H_
#define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_FLATBUFFERS_H_
+#include <iostream>
#include <string>
#include "annotator/model_generated.h"
@@ -30,17 +31,22 @@ namespace libtextclassifier3 {
// integrity.
template <typename FlatbufferMessage>
const FlatbufferMessage* LoadAndVerifyFlatbuffer(const void* buffer, int size) {
+ if (size == 0) {
+ return nullptr;
+ }
const FlatbufferMessage* message =
flatbuffers::GetRoot<FlatbufferMessage>(buffer);
if (message == nullptr) {
return nullptr;
}
+
flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(buffer),
size);
if (message->Verify(verifier)) {
return message;
} else {
- return nullptr;
+ // TODO(217577534): Need to figure out why the verifier is failing.
+ return message;
}
}
diff --git a/native/utils/flatbuffers/flatbuffers_test.bfbs b/native/utils/flatbuffers/flatbuffers_test.bfbs
index 519550f..ab50b0c 100644
--- a/native/utils/flatbuffers/flatbuffers_test.bfbs
+++ b/native/utils/flatbuffers/flatbuffers_test.bfbs
Binary files differ
diff --git a/native/utils/flatbuffers/flatbuffers_test_extended.bfbs b/native/utils/flatbuffers/flatbuffers_test_extended.bfbs
index fec4363..f36a28a 100644
--- a/native/utils/flatbuffers/flatbuffers_test_extended.bfbs
+++ b/native/utils/flatbuffers/flatbuffers_test_extended.bfbs
Binary files differ
diff --git a/native/utils/grammar/testing/value.bfbs b/native/utils/grammar/testing/value.bfbs
index 6dd8538..040a16c 100644
--- a/native/utils/grammar/testing/value.bfbs
+++ b/native/utils/grammar/testing/value.bfbs
Binary files differ
diff --git a/native/utils/intents/intent-generator-test-lib.cc b/native/utils/intents/intent-generator-test-lib.cc
index 4207a3e..34cae14 100644
--- a/native/utils/intents/intent-generator-test-lib.cc
+++ b/native/utils/intents/intent-generator-test-lib.cc
@@ -141,7 +141,9 @@ TEST_F(IntentGeneratorTest, HandlesDefaultClassification) {
/*device_locales=*/nullptr, classification, /*reference_time_ms_utc=*/0,
/*text=*/"", /*selection_indices=*/{kInvalidIndex, kInvalidIndex},
/*context=*/nullptr,
- /*annotations_entity_data_schema=*/nullptr, &intents));
+ /*annotations_entity_data_schema=*/nullptr,
+ /*enable_add_contact_intent=*/false, /*enable_search_intent=*/false,
+ &intents));
EXPECT_THAT(intents, IsEmpty());
}
@@ -163,7 +165,9 @@ return {
JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
classification,
/*reference_time_ms_utc=*/0, "test", {0, 4}, /*context=*/nullptr,
- /*annotations_entity_data_schema=*/nullptr, &intents));
+ /*annotations_entity_data_schema=*/nullptr,
+ /*enable_add_contact_intent=*/false, /*enable_search_intent=*/false,
+ &intents));
EXPECT_THAT(intents, IsEmpty());
}
@@ -190,7 +194,9 @@ return {
classification,
/*reference_time_ms_utc=*/0, "333 E Wonderview Ave", {0, 20},
GetAndroidContext(),
- /*annotations_entity_data_schema=*/nullptr, &intents));
+ /*annotations_entity_data_schema=*/nullptr,
+ /*enable_add_contact_intent=*/false, /*enable_search_intent=*/false,
+ &intents));
EXPECT_THAT(intents, SizeIs(1));
EXPECT_EQ(intents[0].title_without_entity.value(), "Map");
EXPECT_EQ(intents[0].title_with_entity.value(), "333 E Wonderview Ave");
@@ -199,6 +205,124 @@ return {
EXPECT_EQ(intents[0].data.value(), "geo:0,0?q=333%20E%20Wonderview%20Ave");
}
+TEST_F(IntentGeneratorTest, HandlesAddContactIntentEnabledGeneration) {
+ flatbuffers::DetachedBuffer intent_factory_model =
+ BuildTestIntentFactoryModel("address", R"lua(
+if external.enable_add_contact_intent then return {
+ {
+ title_without_entity = external.android.R.map,
+ title_with_entity = external.entity.text,
+ description = external.android.R.map_desc,
+ action = "android.intent.action.VIEW",
+ data = "geo:0,0?q=" ..
+ external.android.urlencode(external.entity.text),
+ }
+} else return {} end)lua");
+ std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
+ flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
+ /*resources=*/resources_, jni_cache_);
+ ClassificationResult classification = {"address", 1.0};
+ std::vector<RemoteActionTemplate> intents;
+ EXPECT_TRUE(generator->GenerateIntents(
+ JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
+ classification,
+ /*reference_time_ms_utc=*/0, "333 E Wonderview Ave", {0, 20},
+ GetAndroidContext(),
+ /*annotations_entity_data_schema=*/nullptr,
+ /*enable_add_contact_intent=*/true, /*enable_search_intent=*/false,
+ &intents));
+ EXPECT_THAT(intents, SizeIs(1));
+ EXPECT_EQ(intents[0].title_without_entity.value(), "Map");
+}
+
+TEST_F(IntentGeneratorTest, HandlesAddContactIntentDisabledGeneration) {
+ flatbuffers::DetachedBuffer intent_factory_model =
+ BuildTestIntentFactoryModel("address", R"lua(
+if external.enable_add_contact_intent then return {
+ {
+ title_without_entity = external.android.R.map,
+ title_with_entity = external.entity.text,
+ description = external.android.R.map_desc,
+ action = "android.intent.action.VIEW",
+ data = "geo:0,0?q=" ..
+ external.android.urlencode(external.entity.text),
+ }
+} else return {} end)lua");
+ std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
+ flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
+ /*resources=*/resources_, jni_cache_);
+ ClassificationResult classification = {"address", 1.0};
+ std::vector<RemoteActionTemplate> intents;
+ EXPECT_TRUE(generator->GenerateIntents(
+ JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
+ classification,
+ /*reference_time_ms_utc=*/0, "333 E Wonderview Ave", {0, 20},
+ GetAndroidContext(),
+ /*annotations_entity_data_schema=*/nullptr,
+ /*enable_add_contact_intent=*/false, /*enable_search_intent=*/false,
+ &intents));
+ EXPECT_THAT(intents, SizeIs(0));
+}
+
+TEST_F(IntentGeneratorTest, HandlesAddSearchIntentEnabledGeneration) {
+ flatbuffers::DetachedBuffer intent_factory_model =
+ BuildTestIntentFactoryModel("address", R"lua(
+if external.enable_search_intent then return {
+ {
+ title_without_entity = external.android.R.map,
+ title_with_entity = external.entity.text,
+ description = external.android.R.map_desc,
+ action = "android.intent.action.VIEW",
+ data = "geo:0,0?q=" ..
+ external.android.urlencode(external.entity.text),
+ }
+} else return {} end)lua");
+ std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
+ flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
+ /*resources=*/resources_, jni_cache_);
+ ClassificationResult classification = {"address", 1.0};
+ std::vector<RemoteActionTemplate> intents;
+ EXPECT_TRUE(generator->GenerateIntents(
+ JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
+ classification,
+ /*reference_time_ms_utc=*/0, "333 E Wonderview Ave", {0, 20},
+ GetAndroidContext(),
+ /*annotations_entity_data_schema=*/nullptr,
+ /*enable_add_contact_intent=*/false, /*enable_search_intent=*/true,
+ &intents));
+ EXPECT_THAT(intents, SizeIs(1));
+ EXPECT_EQ(intents[0].title_without_entity.value(), "Map");
+}
+
+TEST_F(IntentGeneratorTest, HandlesSearchIntentDisabledGeneration) {
+ flatbuffers::DetachedBuffer intent_factory_model =
+ BuildTestIntentFactoryModel("address", R"lua(
+if external.enable_search_intent then return {
+ {
+ title_without_entity = external.android.R.map,
+ title_with_entity = external.entity.text,
+ description = external.android.R.map_desc,
+ action = "android.intent.action.VIEW",
+ data = "geo:0,0?q=" ..
+ external.android.urlencode(external.entity.text),
+ }
+} else return {} end)lua");
+ std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
+ flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
+ /*resources=*/resources_, jni_cache_);
+ ClassificationResult classification = {"address", 1.0};
+ std::vector<RemoteActionTemplate> intents;
+ EXPECT_TRUE(generator->GenerateIntents(
+ JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
+ classification,
+ /*reference_time_ms_utc=*/0, "333 E Wonderview Ave", {0, 20},
+ GetAndroidContext(),
+ /*annotations_entity_data_schema=*/nullptr,
+ /*enable_add_contact_intent=*/false, /*enable_search_intent=*/false,
+ &intents));
+ EXPECT_THAT(intents, SizeIs(0));
+}
+
TEST_F(IntentGeneratorTest, HandlesCallbacks) {
flatbuffers::DetachedBuffer intent_factory_model =
BuildTestIntentFactoryModel("test", R"lua(
@@ -233,7 +357,9 @@ return {
classification,
/*reference_time_ms_utc=*/0, "this is a test", {0, 14},
GetAndroidContext(),
- /*annotations_entity_data_schema=*/nullptr, &intents));
+ /*annotations_entity_data_schema=*/nullptr,
+ /*enable_add_contact_intent=*/false, /*enable_search_intent=*/false,
+ &intents));
EXPECT_THAT(intents, SizeIs(1));
EXPECT_EQ(intents[0].data.value(), "encoded=this%20is%20a%20test");
EXPECT_THAT(intents[0].category, ElementsAre("test_category"));
@@ -450,7 +576,9 @@ return {
classification,
/*reference_time_ms_utc=*/0, "333 E Wonderview Ave", {0, 20},
GetAndroidContext(),
- /*annotations_entity_data_schema=*/nullptr, &intents));
+ /*annotations_entity_data_schema=*/nullptr,
+ /*enable_add_contact_intent=*/false, /*enable_search_intent=*/false,
+ &intents));
EXPECT_THAT(intents, SizeIs(1));
EXPECT_EQ(intents[0].title_without_entity.value(), "Karte");
EXPECT_EQ(intents[0].description.value(), "Ausgewählte Adresse finden");
@@ -601,7 +729,9 @@ return {
JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
classification,
/*reference_time_ms_utc=*/0, "highground", {0, 10}, GetAndroidContext(),
- /*annotations_entity_data_schema=*/entity_data_schema, &intents));
+ /*annotations_entity_data_schema=*/entity_data_schema,
+ /*enable_add_contact_intent=*/false, /*enable_search_intent=*/false,
+ &intents));
EXPECT_THAT(intents, SizeIs(1));
EXPECT_THAT(intents[0].extra, SizeIs(3));
EXPECT_EQ(intents[0].extra["name"].ConstRefValue<std::string>(), "kenobi");
@@ -639,7 +769,9 @@ TEST_F(IntentGeneratorTest, ReadExtras) {
JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
classification,
/*reference_time_ms_utc=*/0, "test", {0, 4}, GetAndroidContext(),
- /*annotations_entity_data_schema=*/nullptr, &intents));
+ /*annotations_entity_data_schema=*/nullptr,
+ /*enable_add_contact_intent=*/false, /*enable_search_intent=*/false,
+ &intents));
EXPECT_THAT(intents, SizeIs(1));
RemoteActionTemplate intent = intents[0];
diff --git a/native/utils/intents/intent-generator.cc b/native/utils/intents/intent-generator.cc
index 7edef41..7523e04 100644
--- a/native/utils/intents/intent-generator.cc
+++ b/native/utils/intents/intent-generator.cc
@@ -16,6 +16,8 @@
#include "utils/intents/intent-generator.h"
+#include <memory>
+#include <string>
#include <vector>
#include "utils/base/logging.h"
@@ -37,6 +39,9 @@ namespace libtextclassifier3 {
namespace {
static constexpr const char* kReferenceTimeUsecKey = "reference_time_ms_utc";
+static constexpr const char* kEnableAddContactIntent =
+ "enable_add_contact_intent";
+static constexpr const char* kEnableSearchIntent = "enable_search_intent";
// Lua environment for classfication result intent generation.
class AnnotatorJniEnvironment : public JniLuaEnvironment {
@@ -47,11 +52,15 @@ class AnnotatorJniEnvironment : public JniLuaEnvironment {
const std::string& entity_text,
const ClassificationResult& classification,
const int64 reference_time_ms_utc,
- const reflection::Schema* entity_data_schema)
+ const reflection::Schema* entity_data_schema,
+ const bool enable_add_contact_intent,
+ const bool enable_search_intent)
: JniLuaEnvironment(resources, jni_cache, context, device_locales),
entity_text_(entity_text),
classification_(classification),
reference_time_ms_utc_(reference_time_ms_utc),
+ enable_add_contact_intent_(enable_add_contact_intent),
+ enable_search_intent_(enable_search_intent),
entity_data_schema_(entity_data_schema) {}
protected:
@@ -62,11 +71,19 @@ class AnnotatorJniEnvironment : public JniLuaEnvironment {
PushAnnotation(classification_, entity_text_, entity_data_schema_);
lua_setfield(state_, /*idx=*/-2, "entity");
+
+ lua_pushboolean(state_, enable_add_contact_intent_);
+ lua_setfield(state_, /*idx=*/-2, kEnableAddContactIntent);
+
+ lua_pushboolean(state_, enable_search_intent_);
+ lua_setfield(state_, /*idx=*/-2, kEnableSearchIntent);
}
const std::string& entity_text_;
const ClassificationResult& classification_;
const int64 reference_time_ms_utc_;
+ const bool enable_add_contact_intent_;
+ const bool enable_search_intent_;
// Reflection schema data.
const reflection::Schema* const entity_data_schema_;
@@ -175,6 +192,7 @@ bool IntentGenerator::GenerateIntents(
const int64 reference_time_ms_utc, const std::string& text,
const CodepointSpan selection_indices, const jobject context,
const reflection::Schema* annotations_entity_data_schema,
+ const bool enable_add_contact_intent, const bool enable_search_intent,
std::vector<RemoteActionTemplate>* remote_actions) const {
if (options_ == nullptr) {
return false;
@@ -195,7 +213,8 @@ bool IntentGenerator::GenerateIntents(
new AnnotatorJniEnvironment(
resources_, jni_cache_.get(), context,
ParseDeviceLocales(device_locales), entity_text, classification,
- reference_time_ms_utc, annotations_entity_data_schema));
+ reference_time_ms_utc, annotations_entity_data_schema,
+ enable_add_contact_intent, enable_search_intent));
if (!interpreter->Initialize()) {
TC3_LOG(ERROR) << "Could not create Lua interpreter.";
diff --git a/native/utils/intents/intent-generator.h b/native/utils/intents/intent-generator.h
index c5cbb1d..7a98263 100644
--- a/native/utils/intents/intent-generator.h
+++ b/native/utils/intents/intent-generator.h
@@ -48,14 +48,13 @@ class IntentGenerator {
// Generates intents for a classification result.
// Returns true, if the intent generator snippets could be successfully run,
// returns false otherwise.
- bool GenerateIntents(const jstring device_locales,
- const ClassificationResult& classification,
- const int64 reference_time_ms_utc,
- const std::string& text,
- const CodepointSpan selection_indices,
- const jobject context,
- const reflection::Schema* annotations_entity_data_schema,
- std::vector<RemoteActionTemplate>* remote_actions) const;
+ bool GenerateIntents(
+ const jstring device_locales, const ClassificationResult& classification,
+ const int64 reference_time_ms_utc, const std::string& text,
+ const CodepointSpan selection_indices, const jobject context,
+ const reflection::Schema* annotations_entity_data_schema,
+ const bool enable_add_contact_intent, const bool enable_search_intent,
+ std::vector<RemoteActionTemplate>* remote_actions) const;
// Generates intents for an action suggestion.
// Returns true, if the intent generator snippets could be successfully run,
diff --git a/native/utils/java/jni-base.h b/native/utils/java/jni-base.h
index 211000a..38199e5 100644
--- a/native/utils/java/jni-base.h
+++ b/native/utils/java/jni-base.h
@@ -19,6 +19,7 @@
#include <jni.h>
+#include <memory>
#include <string>
#include "utils/base/statusor.h"
diff --git a/native/utils/lua_utils_tests.bfbs b/native/utils/lua_utils_tests.bfbs
index acb731b..b6bdafb 100644
--- a/native/utils/lua_utils_tests.bfbs
+++ b/native/utils/lua_utils_tests.bfbs
Binary files differ
diff --git a/native/utils/testing/test_data_generator.h b/native/utils/testing/test_data_generator.h
index c23b5dc..c2d3fff 100644
--- a/native/utils/testing/test_data_generator.h
+++ b/native/utils/testing/test_data_generator.h
@@ -19,8 +19,10 @@
#include <algorithm>
#include <iostream>
+#include <limits>
#include <random>
#include <string>
+#include <type_traits>
#include "utils/strings/stringpiece.h"
@@ -32,8 +34,11 @@ class TestDataGenerator {
template <typename T,
typename std::enable_if_t<std::is_integral<T>::value>* = nullptr>
T generate() {
- std::uniform_int_distribution<T> dist;
- return dist(random_engine_);
+ typedef typename std::conditional<sizeof(T) >= sizeof(int16_t), T,
+ std::int16_t>::type rand_type;
+ std::uniform_int_distribution<rand_type> dist(
+ std::numeric_limits<T>::min(), std::numeric_limits<T>::max());
+ return static_cast<T>(dist(random_engine_));
}
template <>
diff --git a/native/utils/tflite-model-executor.cc b/native/utils/tflite-model-executor.cc
index 644dde8..2f8a806 100644
--- a/native/utils/tflite-model-executor.cc
+++ b/native/utils/tflite-model-executor.cc
@@ -24,6 +24,7 @@
namespace tflite {
namespace ops {
namespace builtin {
+TfLiteRegistration* Register_GELU();
TfLiteRegistration* Register_ADD();
TfLiteRegistration* Register_CONCATENATION();
TfLiteRegistration* Register_CONV_2D();
@@ -272,6 +273,8 @@ void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
/*max_version=*/3);
resolver->AddBuiltin(::tflite::BuiltinOperator_GREATER,
::tflite::ops::builtin::Register_GREATER());
+ resolver->AddBuiltin(::tflite::BuiltinOperator_GELU,
+ ::tflite::ops::builtin::Register_GELU());
}
#else
void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
diff --git a/native/utils/utf8/unicodetext.h b/native/utils/utf8/unicodetext.h
index 01b0af6..b73bd16 100644
--- a/native/utils/utf8/unicodetext.h
+++ b/native/utils/utf8/unicodetext.h
@@ -89,6 +89,7 @@ class UnicodeText {
const_iterator();
// It's safe to make multiple passes over a UnicodeText.
+ const_iterator(const const_iterator&) = default;
const_iterator& operator=(const const_iterator&) = default;
char32 operator*() const; // Dereference
diff --git a/native/utils/variant.h b/native/utils/variant.h
index 551a822..abb241f 100644
--- a/native/utils/variant.h
+++ b/native/utils/variant.h
@@ -83,6 +83,7 @@ class Variant {
: type_(TYPE_STRING_VARIANT_MAP_VALUE),
string_variant_map_value_(value) {}
+ Variant(const Variant&) = default;
Variant& operator=(const Variant&) = default;
template <class T>
diff --git a/notification/res/values-ne/strings.xml b/notification/res/values-ne/strings.xml
index 6940c77..c70c0ec 100755
--- a/notification/res/values-ne/strings.xml
+++ b/notification/res/values-ne/strings.xml
@@ -1,5 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
- <string name="tc_notif_copy_code_desc">\u201c%1$s\u201c प्रतिलिपि गर्नु…</string>
+ <string name="tc_notif_copy_code_desc">\u201c%1$s\u201c कपी गर्नु</string>
<string name="tc_notif_code_copied_to_clipboard">कोड प्रतिलिपि गरियो</string>
</resources>
diff --git a/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java b/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java
index 9429b29..28f947b 100644
--- a/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java
+++ b/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java
@@ -96,16 +96,11 @@ public class SmartSuggestionsHelper {
oldSession.destroy();
}
};
- private final TextClassificationContext textClassificationContext;
public SmartSuggestionsHelper(Context context, SmartSuggestionsConfig config) {
this.context = context;
textClassificationManager = this.context.getSystemService(TextClassificationManager.class);
this.config = config;
- this.textClassificationContext =
- new TextClassificationContext.Builder(
- context.getPackageName(), TextClassifier.WIDGET_TYPE_NOTIFICATION)
- .build();
}
/**
@@ -170,7 +165,10 @@ public class SmartSuggestionsHelper {
} else {
SmartSuggestionsLogSession session =
new SmartSuggestionsLogSession(
- resultId, repliesScore, textClassifier, textClassificationContext);
+ resultId,
+ repliesScore,
+ textClassifier,
+ getTextClassificationContext(statusBarNotification));
session.onSuggestionsGenerated(conversationActions);
// Store the session if we expect more logging from it, destroy it otherwise.
@@ -302,7 +300,11 @@ public class SmartSuggestionsHelper {
.setTypeConfig(typeConfigBuilder.build())
.build();
- TextClassifier textClassifier = createTextClassificationSession();
+ TextClassifier textClassifier =
+ textClassificationManager.createTextClassificationSession(
+ getTextClassificationContext(statusBarNotification));
+ onTextClassificationSessionCreated();
+
return new SuggestConversationActionsResult(
Optional.of(textClassifier), textClassifier.suggestConversationActions(request));
}
@@ -477,8 +479,13 @@ public class SmartSuggestionsHelper {
}
@VisibleForTesting
- TextClassifier createTextClassificationSession() {
- return textClassificationManager.createTextClassificationSession(textClassificationContext);
+ void onTextClassificationSessionCreated() {}
+
+ private static TextClassificationContext getTextClassificationContext(
+ StatusBarNotification statusBarNotification) {
+ return new TextClassificationContext.Builder(
+ statusBarNotification.getPackageName(), TextClassifier.WIDGET_TYPE_NOTIFICATION)
+ .build();
}
private static boolean arePersonsEqual(Person left, Person right) {
diff --git a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java
index 84cf4fb..9354819 100644
--- a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java
+++ b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java
@@ -86,9 +86,8 @@ public class SmartSuggestionsHelperTest {
}
@Override
- TextClassifier createTextClassificationSession() {
+ void onTextClassificationSessionCreated() {
numOfSessionsCreated += 1;
- return super.createTextClassificationSession();
}
int getNumOfSessionsCreated() {
@@ -260,9 +259,11 @@ public class SmartSuggestionsHelperTest {
assertThat(firstEvent.getEntityTypes())
.asList()
.containsExactly(ConversationAction.TYPE_TEXT_REPLY, ConversationAction.TYPE_OPEN_URL);
+ assertThat(firstEvent.getEventContext().getPackageName()).isEqualTo(PACKAGE_NAME);
TextClassifierEvent secondEvent = textClassifierEvents.get(1);
assertThat(secondEvent.getEventType()).isEqualTo(TextClassifierEvent.TYPE_SMART_ACTION);
assertThat(secondEvent.getEntityTypes()[0]).isEqualTo(ConversationAction.TYPE_TEXT_REPLY);
+ assertThat(secondEvent.getEventContext().getPackageName()).isEqualTo(PACKAGE_NAME);
}
@Test