diff options
author | android-build-team Robot <android-build-team-robot@google.com> | 2021-05-01 03:04:43 +0000 |
---|---|---|
committer | android-build-team Robot <android-build-team-robot@google.com> | 2021-05-01 03:04:43 +0000 |
commit | c9d882c18aa0cc12bb74ed0f13e0d6a45f931911 (patch) | |
tree | 3f94817a4e3b75f6708bbbd1a4bf10794b481823 | |
parent | 9334ab60de3b14c64e10733550bc26d6910b83e1 (diff) | |
parent | 97265d0e193cec31f858b14c8c1d8d2e0f0fda8a (diff) | |
download | libtextclassifier-c9d882c18aa0cc12bb74ed0f13e0d6a45f931911.tar.gz |
Snap for 7328689 from 97265d0e193cec31f858b14c8c1d8d2e0f0fda8a to sc-v2-release
Change-Id: I308902c887ded131677ee1ca3dfcccc01eec94e2
62 files changed, 457 insertions, 321 deletions
diff --git a/java/src/com/android/textclassifier/DefaultTextClassifierService.java b/java/src/com/android/textclassifier/DefaultTextClassifierService.java index 378d842..4ca058d 100644 --- a/java/src/com/android/textclassifier/DefaultTextClassifierService.java +++ b/java/src/com/android/textclassifier/DefaultTextClassifierService.java @@ -25,11 +25,14 @@ import android.service.textclassifier.TextClassifierService; import android.view.textclassifier.ConversationActions; import android.view.textclassifier.SelectionEvent; import android.view.textclassifier.TextClassification; +import android.view.textclassifier.TextClassificationContext; import android.view.textclassifier.TextClassificationSessionId; import android.view.textclassifier.TextClassifierEvent; import android.view.textclassifier.TextLanguage; import android.view.textclassifier.TextLinks; import android.view.textclassifier.TextSelection; +import androidx.annotation.NonNull; +import androidx.collection.LruCache; import com.android.textclassifier.common.ModelFileManager; import com.android.textclassifier.common.TextClassifierServiceExecutors; import com.android.textclassifier.common.TextClassifierSettings; @@ -45,8 +48,10 @@ import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; import java.io.FileDescriptor; import java.io.PrintWriter; +import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.Executor; +import javax.annotation.Nullable; /** An implementation of a TextClassifierService. */ public final class DefaultTextClassifierService extends TextClassifierService { @@ -60,6 +65,7 @@ public final class DefaultTextClassifierService extends TextClassifierService { private TextClassifierSettings settings; private ModelFileManager modelFileManager; private BroadcastReceiver localeChangedReceiver; + private LruCache<TextClassificationSessionId, TextClassificationContext> sessionIdToContext; public DefaultTextClassifierService() { this.injector = new InjectorImpl(this); @@ -82,6 +88,7 @@ public final class DefaultTextClassifierService extends TextClassifierService { lowPriorityExecutor = injector.createLowPriorityExecutor(); textClassifier = injector.createTextClassifierImpl(settings, modelFileManager); localeChangedReceiver = new LocaleChangedReceiver(modelFileManager); + sessionIdToContext = new LruCache<>(settings.getSessionIdToContextCacheSize()); textClassifierApiUsageLogger = injector.createTextClassifierApiUsageLogger(settings, lowPriorityExecutor); @@ -97,13 +104,26 @@ public final class DefaultTextClassifierService extends TextClassifierService { } @Override + public void onCreateTextClassificationSession( + @NonNull TextClassificationContext context, @NonNull TextClassificationSessionId sessionId) { + sessionIdToContext.put(sessionId, context); + } + + @Override + public void onDestroyTextClassificationSession(@NonNull TextClassificationSessionId sessionId) { + sessionIdToContext.remove(sessionId); + } + + @Override public void onSuggestSelection( TextClassificationSessionId sessionId, TextSelection.Request request, CancellationSignal cancellationSignal, Callback<TextSelection> callback) { handleRequestAsync( - () -> textClassifier.suggestSelection(request), + () -> + textClassifier.suggestSelection( + sessionId, sessionIdToTextClassificationContext(sessionId), request), callback, textClassifierApiUsageLogger.createSession( TextClassifierApiUsageLogger.API_TYPE_SUGGEST_SELECTION, sessionId), @@ -117,7 +137,9 @@ public final class DefaultTextClassifierService extends TextClassifierService { CancellationSignal cancellationSignal, Callback<TextClassification> callback) { handleRequestAsync( - () -> textClassifier.classifyText(request), + () -> + textClassifier.classifyText( + sessionId, sessionIdToTextClassificationContext(sessionId), request), callback, textClassifierApiUsageLogger.createSession( TextClassifierApiUsageLogger.API_TYPE_CLASSIFY_TEXT, sessionId), @@ -131,7 +153,9 @@ public final class DefaultTextClassifierService extends TextClassifierService { CancellationSignal cancellationSignal, Callback<TextLinks> callback) { handleRequestAsync( - () -> textClassifier.generateLinks(request), + () -> + textClassifier.generateLinks( + sessionId, sessionIdToTextClassificationContext(sessionId), request), callback, textClassifierApiUsageLogger.createSession( TextClassifierApiUsageLogger.API_TYPE_GENERATE_LINKS, sessionId), @@ -145,7 +169,9 @@ public final class DefaultTextClassifierService extends TextClassifierService { CancellationSignal cancellationSignal, Callback<ConversationActions> callback) { handleRequestAsync( - () -> textClassifier.suggestConversationActions(request), + () -> + textClassifier.suggestConversationActions( + sessionId, sessionIdToTextClassificationContext(sessionId), request), callback, textClassifierApiUsageLogger.createSession( TextClassifierApiUsageLogger.API_TYPE_SUGGEST_CONVERSATION_ACTIONS, sessionId), @@ -159,7 +185,9 @@ public final class DefaultTextClassifierService extends TextClassifierService { CancellationSignal cancellationSignal, Callback<TextLanguage> callback) { handleRequestAsync( - () -> textClassifier.detectLanguage(request), + () -> + textClassifier.detectLanguage( + sessionId, sessionIdToTextClassificationContext(sessionId), request), callback, textClassifierApiUsageLogger.createSession( TextClassifierApiUsageLogger.API_TYPE_DETECT_LANGUAGES, sessionId), @@ -168,7 +196,7 @@ public final class DefaultTextClassifierService extends TextClassifierService { @Override public void onSelectionEvent(TextClassificationSessionId sessionId, SelectionEvent event) { - handleEvent(() -> textClassifier.onSelectionEvent(event)); + handleEvent(() -> textClassifier.onSelectionEvent(sessionId, event)); } @Override @@ -182,9 +210,24 @@ public final class DefaultTextClassifierService extends TextClassifierService { IndentingPrintWriter indentingPrintWriter = new IndentingPrintWriter(writer); // TODO(licha): Also dump ModelDownloadManager for debugging textClassifier.dump(indentingPrintWriter); + dumpImpl(indentingPrintWriter); indentingPrintWriter.flush(); } + private void dumpImpl(IndentingPrintWriter printWriter) { + printWriter.println("DefaultTextClassifierService:"); + printWriter.increaseIndent(); + printWriter.println("sessionIdToContext:"); + printWriter.increaseIndent(); + for (Map.Entry<TextClassificationSessionId, TextClassificationContext> entry : + sessionIdToContext.snapshot().entrySet()) { + printWriter.printPair(entry.getKey().getValue(), entry.getValue()); + } + printWriter.decreaseIndent(); + printWriter.decreaseIndent(); + printWriter.println(); + } + private <T> void handleRequestAsync( Callable<T> callable, Callback<T> callback, @@ -232,6 +275,15 @@ public final class DefaultTextClassifierService extends TextClassifierService { MoreExecutors.directExecutor()); } + @Nullable + private TextClassificationContext sessionIdToTextClassificationContext( + @Nullable TextClassificationSessionId sessionId) { + if (sessionId == null) { + return null; + } + return sessionIdToContext.get(sessionId); + } + /** * Receiver listening to locale change event. Ask ModelFileManager to do clean-up upon receiving. */ diff --git a/java/src/com/android/textclassifier/TextClassifierImpl.java b/java/src/com/android/textclassifier/TextClassifierImpl.java index f5d7a47..7383bc1 100644 --- a/java/src/com/android/textclassifier/TextClassifierImpl.java +++ b/java/src/com/android/textclassifier/TextClassifierImpl.java @@ -33,6 +33,8 @@ import android.view.textclassifier.ConversationAction; import android.view.textclassifier.ConversationActions; import android.view.textclassifier.SelectionEvent; import android.view.textclassifier.TextClassification; +import android.view.textclassifier.TextClassification.Request; +import android.view.textclassifier.TextClassificationContext; import android.view.textclassifier.TextClassificationSessionId; import android.view.textclassifier.TextClassifier; import android.view.textclassifier.TextClassifierEvent; @@ -127,7 +129,11 @@ final class TextClassifierImpl { } @WorkerThread - TextSelection suggestSelection(TextSelection.Request request) throws IOException { + TextSelection suggestSelection( + @Nullable TextClassificationSessionId sessionId, + @Nullable TextClassificationContext textClassificationContext, + TextSelection.Request request) + throws IOException { Preconditions.checkNotNull(request); checkMainThread(); final int rangeLength = request.getEndIndex() - request.getStartIndex(); @@ -186,7 +192,11 @@ final class TextClassifierImpl { } @WorkerThread - TextClassification classifyText(TextClassification.Request request) throws IOException { + TextClassification classifyText( + @Nullable TextClassificationSessionId sessionId, + @Nullable TextClassificationContext textClassificationContext, + Request request) + throws IOException { Preconditions.checkNotNull(request); checkMainThread(); LangIdModel langId = getLangIdImpl(); @@ -226,7 +236,11 @@ final class TextClassifierImpl { } @WorkerThread - TextLinks generateLinks(TextLinks.Request request) throws IOException { + TextLinks generateLinks( + @Nullable TextClassificationSessionId sessionId, + @Nullable TextClassificationContext textClassificationContext, + TextLinks.Request request) + throws IOException { Preconditions.checkNotNull(request); Preconditions.checkArgument( request.getText().length() <= getMaxGenerateLinksTextLength(), @@ -293,6 +307,8 @@ final class TextClassifierImpl { langIdModelInfo = Optional.fromNullable(langIdModelInUse).transform(ModelFile::toModelInfo); } generateLinksLogger.logGenerateLinks( + sessionId, + textClassificationContext, request.getText(), links, callingPackageName, @@ -321,7 +337,7 @@ final class TextClassifierImpl { } } - void onSelectionEvent(SelectionEvent event) { + void onSelectionEvent(@Nullable TextClassificationSessionId sessionId, SelectionEvent event) { TextClassifierEvent textClassifierEvent = SelectionEventConverter.toTextClassifierEvent(event); if (textClassifierEvent == null) { return; @@ -336,7 +352,11 @@ final class TextClassifierImpl { TextClassifierEventConverter.fromPlatform(event)); } - TextLanguage detectLanguage(TextLanguage.Request request) throws IOException { + TextLanguage detectLanguage( + @Nullable TextClassificationSessionId sessionId, + @Nullable TextClassificationContext textClassificationContext, + TextLanguage.Request request) + throws IOException { Preconditions.checkNotNull(request); checkMainThread(); final TextLanguage.Builder builder = new TextLanguage.Builder(); @@ -349,7 +369,10 @@ final class TextClassifierImpl { return builder.build(); } - ConversationActions suggestConversationActions(ConversationActions.Request request) + ConversationActions suggestConversationActions( + @Nullable TextClassificationSessionId sessionId, + @Nullable TextClassificationContext textClassificationContext, + ConversationActions.Request request) throws IOException { Preconditions.checkNotNull(request); checkMainThread(); @@ -650,6 +673,7 @@ final class TextClassifierImpl { printWriter.println(); settings.dump(printWriter); + printWriter.println(); } } diff --git a/java/src/com/android/textclassifier/common/TextClassifierSettings.java b/java/src/com/android/textclassifier/common/TextClassifierSettings.java index 0f3322e..fdf259e 100644 --- a/java/src/com/android/textclassifier/common/TextClassifierSettings.java +++ b/java/src/com/android/textclassifier/common/TextClassifierSettings.java @@ -127,6 +127,9 @@ public final class TextClassifierSettings { /** Sampling rate for TextClassifier API logging. */ static final String TEXTCLASSIFIER_API_LOG_SAMPLE_RATE = "textclassifier_api_log_sample_rate"; + /** The size of the cache of the mapping of session id to text classification context. */ + private static final String SESSION_ID_TO_CONTEXT_CACHE_SIZE = "session_id_to_context_cache_size"; + /** * A colon(:) separated string that specifies the configuration to use when including surrounding * context text in language detection queries. @@ -202,6 +205,8 @@ public final class TextClassifierSettings { */ private static final int TEXTCLASSIFIER_API_LOG_SAMPLE_RATE_DEFAULT = 10; + private static final int SESSION_ID_TO_CONTEXT_CACHE_SIZE_DEFAULT = 10; + // TODO(licha): Consider removing this. We can use real device config for testing. /** DeviceConfig interface to facilitate testing. */ @VisibleForTesting @@ -427,6 +432,11 @@ public final class TextClassifierSettings { NAMESPACE, TEXTCLASSIFIER_API_LOG_SAMPLE_RATE, TEXTCLASSIFIER_API_LOG_SAMPLE_RATE_DEFAULT); } + public int getSessionIdToContextCacheSize() { + return deviceConfig.getInt( + NAMESPACE, SESSION_ID_TO_CONTEXT_CACHE_SIZE, SESSION_ID_TO_CONTEXT_CACHE_SIZE_DEFAULT); + } + public void dump(IndentingPrintWriter pw) { pw.println("TextClassifierSettings:"); pw.increaseIndent(); @@ -451,6 +461,7 @@ public final class TextClassifierSettings { pw.printPair(MODEL_DOWNLOAD_MAX_ATTEMPTS, getModelDownloadMaxAttempts()); pw.decreaseIndent(); pw.printPair(TEXTCLASSIFIER_API_LOG_SAMPLE_RATE, getTextClassifierApiLogSampleRate()); + pw.printPair(SESSION_ID_TO_CONTEXT_CACHE_SIZE, getSessionIdToContextCacheSize()); pw.decreaseIndent(); } diff --git a/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java b/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java index 822eb77..df63f2f 100644 --- a/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java +++ b/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java @@ -16,21 +16,20 @@ package com.android.textclassifier.common.statsd; +import android.view.textclassifier.TextClassificationContext; +import android.view.textclassifier.TextClassificationSessionId; import android.view.textclassifier.TextClassifier; import android.view.textclassifier.TextLinks; import androidx.collection.ArrayMap; import com.android.textclassifier.common.base.TcLog; import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo; import com.android.textclassifier.common.logging.TextClassifierEvent; -import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Optional; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import java.util.Locale; import java.util.Map; import java.util.Random; -import java.util.UUID; -import java.util.function.Supplier; import javax.annotation.Nullable; /** A helper for logging calls to generateLinks. */ @@ -40,7 +39,6 @@ public final class GenerateLinksLogger { private final Random random; private final int sampleRate; - private final Supplier<String> randomUuidSupplier; /** * @param sampleRate the rate at which log events are written. (e.g. 100 means there is a 0.01 @@ -48,24 +46,14 @@ public final class GenerateLinksLogger { * events, pass 1. */ public GenerateLinksLogger(int sampleRate) { - this(sampleRate, () -> UUID.randomUUID().toString()); - } - - /** - * @param sampleRate the rate at which log events are written. (e.g. 100 means there is a 0.01 - * chance that a call to logGenerateLinks results in an event being written). To write all - * events, pass 1. - * @param randomUuidSupplier supplies random UUIDs. - */ - @VisibleForTesting - GenerateLinksLogger(int sampleRate, Supplier<String> randomUuidSupplier) { this.sampleRate = sampleRate; random = new Random(); - this.randomUuidSupplier = Preconditions.checkNotNull(randomUuidSupplier); } /** Logs statistics about a call to generateLinks. */ public void logGenerateLinks( + @Nullable TextClassificationSessionId sessionId, + @Nullable TextClassificationContext textClassificationContext, CharSequence text, TextLinks links, String callingPackageName, @@ -95,20 +83,33 @@ public final class GenerateLinksLogger { totalStats.countLink(link); perEntityTypeStats.computeIfAbsent(entityType, k -> new LinkifyStats()).countLink(link); } + int widgetType = TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_UNKNOWN; + if (textClassificationContext != null) { + widgetType = WidgetTypeConverter.toLoggingValue(textClassificationContext.getWidgetType()); + } - final String callId = randomUuidSupplier.get(); + final String sessionIdStr = sessionId == null ? null : sessionId.getValue(); writeStats( - callId, callingPackageName, null, totalStats, text, latencyMs, annotatorModel, langIdModel); + sessionIdStr, + callingPackageName, + null, + totalStats, + text, + widgetType, + latencyMs, + annotatorModel, + langIdModel); // Sort the entity types to ensure the logging order is deterministic. ImmutableList<String> sortedEntityTypes = ImmutableList.sortedCopyOf(perEntityTypeStats.keySet()); for (String entityType : sortedEntityTypes) { writeStats( - callId, + sessionIdStr, callingPackageName, entityType, perEntityTypeStats.get(entityType), text, + widgetType, latencyMs, annotatorModel, langIdModel); @@ -130,11 +131,12 @@ public final class GenerateLinksLogger { /** Writes a log event for the given stats. */ private static void writeStats( - String callId, + @Nullable String sessionId, String callingPackageName, @Nullable String entityType, LinkifyStats stats, CharSequence text, + int widgetType, long latencyMs, Optional<ModelInfo> annotatorModel, Optional<ModelInfo> langIdModel) { @@ -142,10 +144,10 @@ public final class GenerateLinksLogger { String langIdModelName = langIdModel.transform(ModelInfo::toModelName).or(""); TextClassifierStatsLog.write( TextClassifierStatsLog.TEXT_LINKIFY_EVENT, - callId, + sessionId, TextClassifierEvent.TYPE_LINKS_GENERATED, annotatorModelName, - TextClassifierEventLogger.WidgetType.WIDGET_TYPE_UNKNOWN, + widgetType, /* eventIndex */ 0, entityType, stats.numLinks, @@ -161,7 +163,7 @@ public final class GenerateLinksLogger { String.format( Locale.US, "%s:%s %d links (%d/%d chars) %dms %s annotator=%s langid=%s", - callId, + sessionId, entityType, stats.numLinks, stats.numLinksTextLength, diff --git a/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java b/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java index 6678142..06ad44f 100644 --- a/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java +++ b/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java @@ -19,7 +19,6 @@ package com.android.textclassifier.common.statsd; import static com.google.common.base.Charsets.UTF_8; import static com.google.common.base.Strings.nullToEmpty; -import android.view.textclassifier.TextClassifier; import com.android.textclassifier.common.base.TcLog; import com.android.textclassifier.common.logging.ResultIdUtils; import com.android.textclassifier.common.logging.TextClassificationContext; @@ -195,60 +194,20 @@ public final class TextClassifierEventLogger { return ResultIdUtils.getModelNames(event.getResultId()); } - @Nullable - private static String getPackageName(TextClassifierEvent event) { + private static int getWidgetType(TextClassifierEvent event) { TextClassificationContext eventContext = event.getEventContext(); if (eventContext == null) { - return null; + return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_UNKNOWN; } - return eventContext.getPackageName(); + return WidgetTypeConverter.toLoggingValue(eventContext.getWidgetType()); } - private static int getWidgetType(TextClassifierEvent event) { + @Nullable + private static String getPackageName(TextClassifierEvent event) { TextClassificationContext eventContext = event.getEventContext(); if (eventContext == null) { - return WidgetType.WIDGET_TYPE_UNKNOWN; - } - switch (eventContext.getWidgetType()) { - case TextClassifier.WIDGET_TYPE_UNKNOWN: - return WidgetType.WIDGET_TYPE_UNKNOWN; - case TextClassifier.WIDGET_TYPE_TEXTVIEW: - return WidgetType.WIDGET_TYPE_TEXTVIEW; - case TextClassifier.WIDGET_TYPE_EDITTEXT: - return WidgetType.WIDGET_TYPE_EDITTEXT; - case TextClassifier.WIDGET_TYPE_UNSELECTABLE_TEXTVIEW: - return WidgetType.WIDGET_TYPE_UNSELECTABLE_TEXTVIEW; - case TextClassifier.WIDGET_TYPE_WEBVIEW: - return WidgetType.WIDGET_TYPE_WEBVIEW; - case TextClassifier.WIDGET_TYPE_EDIT_WEBVIEW: - return WidgetType.WIDGET_TYPE_EDIT_WEBVIEW; - case TextClassifier.WIDGET_TYPE_CUSTOM_TEXTVIEW: - return WidgetType.WIDGET_TYPE_CUSTOM_TEXTVIEW; - case TextClassifier.WIDGET_TYPE_CUSTOM_EDITTEXT: - return WidgetType.WIDGET_TYPE_CUSTOM_EDITTEXT; - case TextClassifier.WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW: - return WidgetType.WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW; - case TextClassifier.WIDGET_TYPE_NOTIFICATION: - return WidgetType.WIDGET_TYPE_NOTIFICATION; - default: // fall out + return null; } - return WidgetType.WIDGET_TYPE_UNKNOWN; - } - - /** Widget type constants for logging. */ - public static final class WidgetType { - // Sync these constants with textclassifier_enums.proto. - public static final int WIDGET_TYPE_UNKNOWN = 0; - public static final int WIDGET_TYPE_TEXTVIEW = 1; - public static final int WIDGET_TYPE_EDITTEXT = 2; - public static final int WIDGET_TYPE_UNSELECTABLE_TEXTVIEW = 3; - public static final int WIDGET_TYPE_WEBVIEW = 4; - public static final int WIDGET_TYPE_EDIT_WEBVIEW = 5; - public static final int WIDGET_TYPE_CUSTOM_TEXTVIEW = 6; - public static final int WIDGET_TYPE_CUSTOM_EDITTEXT = 7; - public static final int WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW = 8; - public static final int WIDGET_TYPE_NOTIFICATION = 9; - - private WidgetType() {} + return eventContext.getPackageName(); } } diff --git a/java/src/com/android/textclassifier/common/statsd/WidgetTypeConverter.java b/java/src/com/android/textclassifier/common/statsd/WidgetTypeConverter.java new file mode 100644 index 0000000..13c04d1 --- /dev/null +++ b/java/src/com/android/textclassifier/common/statsd/WidgetTypeConverter.java @@ -0,0 +1,57 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.textclassifier.common.statsd; + +import android.view.textclassifier.TextClassifier; + +/** Converts TextClassifier's WidgetTypes to enum values that are logged to server. */ +final class WidgetTypeConverter { + public static int toLoggingValue(String widgetType) { + switch (widgetType) { + case TextClassifier.WIDGET_TYPE_UNKNOWN: + return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_UNKNOWN; + case TextClassifier.WIDGET_TYPE_TEXTVIEW: + return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_TEXTVIEW; + case TextClassifier.WIDGET_TYPE_EDITTEXT: + return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_EDITTEXT; + case TextClassifier.WIDGET_TYPE_UNSELECTABLE_TEXTVIEW: + return TextClassifierStatsLog + .TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_UNSELECTABLE_TEXTVIEW; + case TextClassifier.WIDGET_TYPE_WEBVIEW: + return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_WEBVIEW; + case TextClassifier.WIDGET_TYPE_EDIT_WEBVIEW: + return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_EDIT_WEBVIEW; + case TextClassifier.WIDGET_TYPE_CUSTOM_TEXTVIEW: + return TextClassifierStatsLog + .TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_CUSTOM_TEXTVIEW; + case TextClassifier.WIDGET_TYPE_CUSTOM_EDITTEXT: + return TextClassifierStatsLog + .TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_CUSTOM_EDITTEXT; + case TextClassifier.WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW: + return TextClassifierStatsLog + .TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW; + case TextClassifier.WIDGET_TYPE_NOTIFICATION: + return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_NOTIFICATION; + case "clipboard": // TODO(tonymak) Replace it with WIDGET_TYPE_CLIPBOARD once S SDK is dropped + return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_CLIPBOARD; + default: // fall out + } + return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_UNKNOWN; + } + + private WidgetTypeConverter() {} +} diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java index 07e03ab..81aa832 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java @@ -94,7 +94,7 @@ public class TextClassifierImplTest { .setDefaultLocales(LOCALES) .build(); - TextSelection selection = classifier.suggestSelection(request); + TextSelection selection = classifier.suggestSelection(null, null, request); assertThat( selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_EMAIL)); } @@ -113,7 +113,7 @@ public class TextClassifierImplTest { .setDefaultLocales(LOCALES) .build(); - TextSelection selection = classifier.suggestSelection(request); + TextSelection selection = classifier.suggestSelection(null, null, request); assertThat(selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_URL)); } @@ -128,7 +128,7 @@ public class TextClassifierImplTest { .setDefaultLocales(LOCALES) .build(); - TextSelection selection = classifier.suggestSelection(request); + TextSelection selection = classifier.suggestSelection(null, null, request); assertThat(selection, isTextSelection(startIndex, endIndex, NO_TYPE)); } @@ -143,7 +143,8 @@ public class TextClassifierImplTest { .setDefaultLocales(LOCALES) .build(); - TextClassification classification = classifier.classifyText(request); + TextClassification classification = + classifier.classifyText(/* sessionId= */ null, null, request); assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_EMAIL)); } @@ -158,7 +159,7 @@ public class TextClassifierImplTest { .setDefaultLocales(LOCALES) .build(); - TextClassification classification = classifier.classifyText(request); + TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL)); assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW)); } @@ -171,7 +172,7 @@ public class TextClassifierImplTest { .setDefaultLocales(LOCALES) .build(); - TextClassification classification = classifier.classifyText(request); + TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(text, TextClassifier.TYPE_ADDRESS)); } @@ -186,7 +187,7 @@ public class TextClassifierImplTest { .setDefaultLocales(LOCALES) .build(); - TextClassification classification = classifier.classifyText(request); + TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL)); assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW)); } @@ -202,7 +203,7 @@ public class TextClassifierImplTest { .setDefaultLocales(LOCALES) .build(); - TextClassification classification = classifier.classifyText(request); + TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE)); Bundle extras = classification.getExtras(); List<Bundle> entities = ExtrasUtils.getEntities(extras); @@ -223,7 +224,7 @@ public class TextClassifierImplTest { .setDefaultLocales(LOCALES) .build(); - TextClassification classification = classifier.classifyText(request); + TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE_TIME)); } @@ -237,7 +238,7 @@ public class TextClassifierImplTest { .setDefaultLocales(LOCALES) .build(); - TextClassification classification = classifier.classifyText(request); + TextClassification classification = classifier.classifyText(null, null, request); RemoteAction translateAction = classification.getActions().get(0); assertEquals(1, classification.getActions().size()); assertEquals("Translate", translateAction.getTitle().toString()); @@ -261,7 +262,7 @@ public class TextClassifierImplTest { String text = "The number is +12122537077. See you tonight!"; TextLinks.Request request = new TextLinks.Request.Builder(text).build(); assertThat( - classifier.generateLinks(request), + classifier.generateLinks(null, null, request), isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE)); } @@ -277,7 +278,7 @@ public class TextClassifierImplTest { .setDefaultLocales(LOCALES) .build(); assertThat( - classifier.generateLinks(request), + classifier.generateLinks(null, null, request), not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL))); } @@ -291,7 +292,7 @@ public class TextClassifierImplTest { .setDefaultLocales(LOCALES) .build(); assertThat( - classifier.generateLinks(request), + classifier.generateLinks(null, null, request), isTextLinksContaining( text, "1600 Amphitheater Parkway, Mountain View, CA", TextClassifier.TYPE_ADDRESS)); } @@ -308,7 +309,7 @@ public class TextClassifierImplTest { .setDefaultLocales(LOCALES) .build(); assertThat( - classifier.generateLinks(request), + classifier.generateLinks(null, null, request), not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL))); } @@ -317,7 +318,7 @@ public class TextClassifierImplTest { char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength()]; Arrays.fill(manySpaces, ' '); TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build(); - TextLinks links = classifier.generateLinks(request); + TextLinks links = classifier.generateLinks(null, null, request); assertTrue(links.getLinks().isEmpty()); } @@ -327,7 +328,7 @@ public class TextClassifierImplTest { TextLinks.Request request = new TextLinks.Request.Builder(url).build(); assertEquals( TextLinks.STATUS_UNSUPPORTED_CHARACTER, - classifier.generateLinks(request).apply(url, 0, null)); + classifier.generateLinks(null, null, request).apply(url, 0, null)); } @Test @@ -335,7 +336,8 @@ public class TextClassifierImplTest { char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength() + 1]; Arrays.fill(manySpaces, ' '); TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build(); - expectThrows(IllegalArgumentException.class, () -> classifier.generateLinks(request)); + expectThrows( + IllegalArgumentException.class, () -> classifier.generateLinks(null, null, request)); } @Test @@ -345,7 +347,7 @@ public class TextClassifierImplTest { ExtrasUtils.putIsSerializedEntityDataEnabled(extras, true); TextLinks.Request request = new TextLinks.Request.Builder(text).setExtras(extras).build(); - TextLinks textLinks = classifier.generateLinks(request); + TextLinks textLinks = classifier.generateLinks(null, null, request); assertThat(textLinks.getLinks()).hasSize(1); TextLinks.TextLink textLink = textLinks.getLinks().iterator().next(); @@ -360,7 +362,7 @@ public class TextClassifierImplTest { String text = "The number is +12122537077."; TextLinks.Request request = new TextLinks.Request.Builder(text).build(); - TextLinks textLinks = classifier.generateLinks(request); + TextLinks textLinks = classifier.generateLinks(null, null, request); assertThat(textLinks.getLinks()).hasSize(1); TextLinks.TextLink textLink = textLinks.getLinks().iterator().next(); @@ -372,7 +374,7 @@ public class TextClassifierImplTest { public void testDetectLanguage() throws IOException { String text = "This is English text"; TextLanguage.Request request = new TextLanguage.Request.Builder(text).build(); - TextLanguage textLanguage = classifier.detectLanguage(request); + TextLanguage textLanguage = classifier.detectLanguage(null, null, request); assertThat(textLanguage, isTextLanguage("en")); } @@ -380,7 +382,7 @@ public class TextClassifierImplTest { public void testDetectLanguage_japanese() throws IOException { String text = "これは日本語のテキストです"; TextLanguage.Request request = new TextLanguage.Request.Builder(text).build(); - TextLanguage textLanguage = classifier.detectLanguage(request); + TextLanguage textLanguage = classifier.detectLanguage(null, null, request); assertThat(textLanguage, isTextLanguage("ja")); } @@ -401,7 +403,8 @@ public class TextClassifierImplTest { .setTypeConfig(typeConfig) .build(); - ConversationActions conversationActions = classifier.suggestConversationActions(request); + ConversationActions conversationActions = + classifier.suggestConversationActions(null, null, request); assertThat(conversationActions.getConversationActions()).hasSize(1); ConversationAction conversationAction = conversationActions.getConversationActions().get(0); assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_TEXT_REPLY); @@ -424,7 +427,8 @@ public class TextClassifierImplTest { .setTypeConfig(typeConfig) .build(); - ConversationActions conversationActions = classifier.suggestConversationActions(request); + ConversationActions conversationActions = + classifier.suggestConversationActions(null, null, request); assertTrue(conversationActions.getConversationActions().size() > 1); for (ConversationAction conversationAction : conversationActions.getConversationActions()) { assertThat(conversationAction, isConversationAction(ConversationAction.TYPE_TEXT_REPLY)); @@ -448,7 +452,8 @@ public class TextClassifierImplTest { .setTypeConfig(typeConfig) .build(); - ConversationActions conversationActions = classifier.suggestConversationActions(request); + ConversationActions conversationActions = + classifier.suggestConversationActions(null, null, request); assertThat(conversationActions.getConversationActions()).hasSize(1); ConversationAction conversationAction = conversationActions.getConversationActions().get(0); assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL); @@ -475,7 +480,8 @@ public class TextClassifierImplTest { .setTypeConfig(typeConfig) .build(); - ConversationActions conversationActions = classifier.suggestConversationActions(request); + ConversationActions conversationActions = + classifier.suggestConversationActions(null, null, request); assertThat(conversationActions.getConversationActions()).hasSize(1); ConversationAction conversationAction = conversationActions.getConversationActions().get(0); assertThat(conversationAction.getType()).isEqualTo(TYPE_COPY); @@ -497,7 +503,8 @@ public class TextClassifierImplTest { .setMaxSuggestions(3) .build(); - ConversationActions conversationActions = classifier.suggestConversationActions(request); + ConversationActions conversationActions = + classifier.suggestConversationActions(null, null, request); assertThat(conversationActions.getConversationActions()).isEmpty(); } diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java index 03041e1..4d5ca4a 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java @@ -501,10 +501,10 @@ public final class ModelFileManagerTest { ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE); assertThat(listedModels).hasSize(2); - assertThat(listedModels.get(0).absolutePath).isEqualTo(modelFile1.getAbsolutePath()); assertThat(listedModels.get(0).isAsset).isFalse(); - assertThat(listedModels.get(1).absolutePath).isEqualTo(modelFile2.getAbsolutePath()); assertThat(listedModels.get(1).isAsset).isFalse(); + assertThat(ImmutableList.of(listedModels.get(0).absolutePath, listedModels.get(1).absolutePath)) + .containsExactly(modelFile1.getAbsolutePath(), modelFile2.getAbsolutePath()); } @Test diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/GenerateLinksLoggerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/GenerateLinksLoggerTest.java index 6c66dd5..e215b15 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/GenerateLinksLoggerTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/GenerateLinksLoggerTest.java @@ -18,8 +18,12 @@ package com.android.textclassifier.common.statsd; import static com.google.common.truth.Truth.assertThat; +import android.os.Binder; +import android.os.Parcel; import android.stats.textclassifier.EventType; import android.stats.textclassifier.WidgetType; +import android.view.textclassifier.TextClassificationContext; +import android.view.textclassifier.TextClassificationSessionId; import android.view.textclassifier.TextClassifier; import android.view.textclassifier.TextLinks; import androidx.test.core.app.ApplicationProvider; @@ -55,6 +59,11 @@ public class GenerateLinksLoggerTest { new ModelInfo(1, ImmutableList.of(Locale.ENGLISH)); private static final ModelInfo LANGID_MODEL = new ModelInfo(2, ImmutableList.of(Locale.forLanguageTag("*"))); + private static final String SESSION_ID = "123456"; + private static final String WIDGET_TYPE = TextClassifier.WIDGET_TYPE_WEBVIEW; + private static final WidgetType WIDGET_TYPE_ENUM = WidgetType.WIDGET_TYPE_WEBVIEW; + private final TextClassificationContext textClassificationContext = + new TextClassificationContext.Builder(PACKAGE_NAME, WIDGET_TYPE).build(); @Before public void setup() throws Exception { @@ -83,11 +92,11 @@ public class GenerateLinksLoggerTest { new TextLinks.Builder(testText) .addLink(phoneOffset, phoneOffset + phoneText.length(), phoneEntityScores) .build(); - String uuid = "uuid"; - GenerateLinksLogger generateLinksLogger = - new GenerateLinksLogger(/* sampleRate= */ 1, () -> uuid); + GenerateLinksLogger generateLinksLogger = new GenerateLinksLogger(/* sampleRate= */ 1); generateLinksLogger.logGenerateLinks( + createTextClassificationSessionId(), + textClassificationContext, testText, links, PACKAGE_NAME, @@ -103,10 +112,10 @@ public class GenerateLinksLoggerTest { assertThat(loggedEvents).hasSize(2); TextLinkifyEvent summaryEvent = AtomsProto.TextLinkifyEvent.newBuilder() - .setSessionId(uuid) + .setSessionId(SESSION_ID) .setEventIndex(0) .setModelName("en_v1") - .setWidgetType(WidgetType.WIDGET_TYPE_UNKNOWN) + .setWidgetType(WIDGET_TYPE_ENUM) .setEventType(EventType.LINKS_GENERATED) .setPackageName(PACKAGE_NAME) .setEntityType("") @@ -118,10 +127,10 @@ public class GenerateLinksLoggerTest { .build(); TextLinkifyEvent phoneEvent = AtomsProto.TextLinkifyEvent.newBuilder() - .setSessionId(uuid) + .setSessionId(SESSION_ID) .setEventIndex(0) .setModelName("en_v1") - .setWidgetType(WidgetType.WIDGET_TYPE_UNKNOWN) + .setWidgetType(WIDGET_TYPE_ENUM) .setEventType(EventType.LINKS_GENERATED) .setPackageName(PACKAGE_NAME) .setEntityType(TextClassifier.TYPE_PHONE) @@ -148,11 +157,11 @@ public class GenerateLinksLoggerTest { .addLink(phoneOffset, phoneOffset + phoneText.length(), phoneEntityScores) .addLink(addressOffset, addressOffset + addressText.length(), addressEntityScores) .build(); - String uuid = "uuid"; - GenerateLinksLogger generateLinksLogger = - new GenerateLinksLogger(/* sampleRate= */ 1, () -> uuid); + GenerateLinksLogger generateLinksLogger = new GenerateLinksLogger(/* sampleRate= */ 1); generateLinksLogger.logGenerateLinks( + createTextClassificationSessionId(), + textClassificationContext, testText, links, PACKAGE_NAME, @@ -182,4 +191,13 @@ public class GenerateLinksLoggerTest { assertThat(phoneEvent.getNumLinks()).isEqualTo(1); assertThat(phoneEvent.getLinkedTextLength()).isEqualTo(phoneText.length()); } + + private static TextClassificationSessionId createTextClassificationSessionId() { + // A hack to create TextClassificationSessionId because its constructor is @hide. + Parcel parcel = Parcel.obtain(); + parcel.writeString(SESSION_ID); + parcel.writeStrongBinder(new Binder()); + parcel.setDataPosition(0); + return TextClassificationSessionId.CREATOR.createFromParcel(parcel); + } } diff --git a/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java b/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java index c5f2112..ad3992d 100644 --- a/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java +++ b/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java @@ -172,10 +172,10 @@ public final class ActionsSuggestionsModel implements AutoCloseable { assetFileDescriptor.getLength()); } - /** Initializes DeepCLU, passing the given serialized config to it. */ - public void initializeDeepClu(byte[] serializedConfig) { - if (!nativeInitializeDeepClu(actionsModelPtr, serializedConfig)) { - throw new IllegalArgumentException("Couldn't initialize DeepCLU"); + /** Initializes conversation intent detection, passing the given serialized config to it. */ + public void initializeConversationIntentDetection(byte[] serializedConfig) { + if (!nativeInitializeConversationIntentDetection(actionsModelPtr, serializedConfig)) { + throw new IllegalArgumentException("Couldn't initialize conversation intent detection"); } } @@ -345,7 +345,8 @@ public final class ActionsSuggestionsModel implements AutoCloseable { private static native long nativeNewActionsModelWithOffset( int fd, long offset, long size, byte[] preconditionsOverwrite); - private native boolean nativeInitializeDeepClu(long actionsModelPtr, byte[] serializedConfig); + private native boolean nativeInitializeConversationIntentDetection( + long actionsModelPtr, byte[] serializedConfig); private static native String nativeGetLocales(int fd); diff --git a/native/FlatBufferHeaders.bp b/native/FlatBufferHeaders.bp index bf4ff44..950eee6 100644 --- a/native/FlatBufferHeaders.bp +++ b/native/FlatBufferHeaders.bp @@ -15,20 +15,6 @@ // genrule { - name: "libtextclassifier_fbgen_lang_id_common_flatbuffers_model", - srcs: ["lang_id/common/flatbuffers/model.fbs"], - out: ["lang_id/common/flatbuffers/model_generated.h"], - defaults: ["fbgen"], -} - -genrule { - name: "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network", - srcs: ["lang_id/common/flatbuffers/embedding-network.fbs"], - out: ["lang_id/common/flatbuffers/embedding-network_generated.h"], - defaults: ["fbgen"], -} - -genrule { name: "libtextclassifier_fbgen_actions_actions_model", srcs: ["actions/actions_model.fbs"], out: ["actions/actions_model_generated.h"], @@ -43,23 +29,23 @@ genrule { } genrule { - name: "libtextclassifier_fbgen_annotator_model", - srcs: ["annotator/model.fbs"], - out: ["annotator/model_generated.h"], + name: "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network", + srcs: ["lang_id/common/flatbuffers/embedding-network.fbs"], + out: ["lang_id/common/flatbuffers/embedding-network_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_annotator_person_name_person_name_model", - srcs: ["annotator/person_name/person_name_model.fbs"], - out: ["annotator/person_name/person_name_model_generated.h"], + name: "libtextclassifier_fbgen_lang_id_common_flatbuffers_model", + srcs: ["lang_id/common/flatbuffers/model.fbs"], + out: ["lang_id/common/flatbuffers/model_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_annotator_experimental_experimental", - srcs: ["annotator/experimental/experimental.fbs"], - out: ["annotator/experimental/experimental_generated.h"], + name: "libtextclassifier_fbgen_annotator_person_name_person_name_model", + srcs: ["annotator/person_name/person_name_model.fbs"], + out: ["annotator/person_name/person_name_model_generated.h"], defaults: ["fbgen"], } @@ -71,37 +57,37 @@ genrule { } genrule { - name: "libtextclassifier_fbgen_annotator_entity-data", - srcs: ["annotator/entity-data.fbs"], - out: ["annotator/entity-data_generated.h"], + name: "libtextclassifier_fbgen_annotator_experimental_experimental", + srcs: ["annotator/experimental/experimental.fbs"], + out: ["annotator/experimental/experimental_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_utils_grammar_testing_value", - srcs: ["utils/grammar/testing/value.fbs"], - out: ["utils/grammar/testing/value_generated.h"], + name: "libtextclassifier_fbgen_annotator_entity-data", + srcs: ["annotator/entity-data.fbs"], + out: ["annotator/entity-data_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_utils_grammar_semantics_expression", - srcs: ["utils/grammar/semantics/expression.fbs"], - out: ["utils/grammar/semantics/expression_generated.h"], + name: "libtextclassifier_fbgen_annotator_model", + srcs: ["annotator/model.fbs"], + out: ["annotator/model_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_utils_grammar_rules", - srcs: ["utils/grammar/rules.fbs"], - out: ["utils/grammar/rules_generated.h"], + name: "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers", + srcs: ["utils/flatbuffers/flatbuffers.fbs"], + out: ["utils/flatbuffers/flatbuffers_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_utils_normalization", - srcs: ["utils/normalization.fbs"], - out: ["utils/normalization_generated.h"], + name: "libtextclassifier_fbgen_utils_tflite_text_encoder_config", + srcs: ["utils/tflite/text_encoder_config.fbs"], + out: ["utils/tflite/text_encoder_config_generated.h"], defaults: ["fbgen"], } @@ -113,37 +99,51 @@ genrule { } genrule { - name: "libtextclassifier_fbgen_utils_i18n_language-tag", - srcs: ["utils/i18n/language-tag.fbs"], - out: ["utils/i18n/language-tag_generated.h"], + name: "libtextclassifier_fbgen_utils_zlib_buffer", + srcs: ["utils/zlib/buffer.fbs"], + out: ["utils/zlib/buffer_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_utils_tflite_text_encoder_config", - srcs: ["utils/tflite/text_encoder_config.fbs"], - out: ["utils/tflite/text_encoder_config_generated.h"], + name: "libtextclassifier_fbgen_utils_container_bit-vector", + srcs: ["utils/container/bit-vector.fbs"], + out: ["utils/container/bit-vector_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers", - srcs: ["utils/flatbuffers/flatbuffers.fbs"], - out: ["utils/flatbuffers/flatbuffers_generated.h"], + name: "libtextclassifier_fbgen_utils_intents_intent-config", + srcs: ["utils/intents/intent-config.fbs"], + out: ["utils/intents/intent-config_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_utils_container_bit-vector", - srcs: ["utils/container/bit-vector.fbs"], - out: ["utils/container/bit-vector_generated.h"], + name: "libtextclassifier_fbgen_utils_normalization", + srcs: ["utils/normalization.fbs"], + out: ["utils/normalization_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_utils_tokenizer", - srcs: ["utils/tokenizer.fbs"], - out: ["utils/tokenizer_generated.h"], + name: "libtextclassifier_fbgen_utils_grammar_semantics_expression", + srcs: ["utils/grammar/semantics/expression.fbs"], + out: ["utils/grammar/semantics/expression_generated.h"], + defaults: ["fbgen"], +} + +genrule { + name: "libtextclassifier_fbgen_utils_grammar_rules", + srcs: ["utils/grammar/rules.fbs"], + out: ["utils/grammar/rules_generated.h"], + defaults: ["fbgen"], +} + +genrule { + name: "libtextclassifier_fbgen_utils_grammar_testing_value", + srcs: ["utils/grammar/testing/value.fbs"], + out: ["utils/grammar/testing/value_generated.h"], defaults: ["fbgen"], } @@ -155,16 +155,16 @@ genrule { } genrule { - name: "libtextclassifier_fbgen_utils_zlib_buffer", - srcs: ["utils/zlib/buffer.fbs"], - out: ["utils/zlib/buffer_generated.h"], + name: "libtextclassifier_fbgen_utils_tokenizer", + srcs: ["utils/tokenizer.fbs"], + out: ["utils/tokenizer_generated.h"], defaults: ["fbgen"], } genrule { - name: "libtextclassifier_fbgen_utils_intents_intent-config", - srcs: ["utils/intents/intent-config.fbs"], - out: ["utils/intents/intent-config_generated.h"], + name: "libtextclassifier_fbgen_utils_i18n_language-tag", + srcs: ["utils/i18n/language-tag.fbs"], + out: ["utils/i18n/language-tag_generated.h"], defaults: ["fbgen"], } @@ -178,50 +178,50 @@ cc_library_headers { "com.android.extservices", ], generated_headers: [ - "libtextclassifier_fbgen_lang_id_common_flatbuffers_model", - "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network", "libtextclassifier_fbgen_actions_actions_model", "libtextclassifier_fbgen_actions_actions-entity-data", - "libtextclassifier_fbgen_annotator_model", + "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network", + "libtextclassifier_fbgen_lang_id_common_flatbuffers_model", "libtextclassifier_fbgen_annotator_person_name_person_name_model", - "libtextclassifier_fbgen_annotator_experimental_experimental", "libtextclassifier_fbgen_annotator_datetime_datetime", + "libtextclassifier_fbgen_annotator_experimental_experimental", "libtextclassifier_fbgen_annotator_entity-data", - "libtextclassifier_fbgen_utils_grammar_semantics_expression", - "libtextclassifier_fbgen_utils_grammar_rules", - "libtextclassifier_fbgen_utils_normalization", - "libtextclassifier_fbgen_utils_resources", - "libtextclassifier_fbgen_utils_i18n_language-tag", - "libtextclassifier_fbgen_utils_tflite_text_encoder_config", + "libtextclassifier_fbgen_annotator_model", "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers", - "libtextclassifier_fbgen_utils_container_bit-vector", - "libtextclassifier_fbgen_utils_tokenizer", - "libtextclassifier_fbgen_utils_codepoint-range", + "libtextclassifier_fbgen_utils_tflite_text_encoder_config", + "libtextclassifier_fbgen_utils_resources", "libtextclassifier_fbgen_utils_zlib_buffer", + "libtextclassifier_fbgen_utils_container_bit-vector", "libtextclassifier_fbgen_utils_intents_intent-config", + "libtextclassifier_fbgen_utils_normalization", + "libtextclassifier_fbgen_utils_grammar_semantics_expression", + "libtextclassifier_fbgen_utils_grammar_rules", + "libtextclassifier_fbgen_utils_codepoint-range", + "libtextclassifier_fbgen_utils_tokenizer", + "libtextclassifier_fbgen_utils_i18n_language-tag", ], export_generated_headers: [ - "libtextclassifier_fbgen_lang_id_common_flatbuffers_model", - "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network", "libtextclassifier_fbgen_actions_actions_model", "libtextclassifier_fbgen_actions_actions-entity-data", - "libtextclassifier_fbgen_annotator_model", + "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network", + "libtextclassifier_fbgen_lang_id_common_flatbuffers_model", "libtextclassifier_fbgen_annotator_person_name_person_name_model", - "libtextclassifier_fbgen_annotator_experimental_experimental", "libtextclassifier_fbgen_annotator_datetime_datetime", + "libtextclassifier_fbgen_annotator_experimental_experimental", "libtextclassifier_fbgen_annotator_entity-data", - "libtextclassifier_fbgen_utils_grammar_semantics_expression", - "libtextclassifier_fbgen_utils_grammar_rules", - "libtextclassifier_fbgen_utils_normalization", - "libtextclassifier_fbgen_utils_resources", - "libtextclassifier_fbgen_utils_i18n_language-tag", - "libtextclassifier_fbgen_utils_tflite_text_encoder_config", + "libtextclassifier_fbgen_annotator_model", "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers", - "libtextclassifier_fbgen_utils_container_bit-vector", - "libtextclassifier_fbgen_utils_tokenizer", - "libtextclassifier_fbgen_utils_codepoint-range", + "libtextclassifier_fbgen_utils_tflite_text_encoder_config", + "libtextclassifier_fbgen_utils_resources", "libtextclassifier_fbgen_utils_zlib_buffer", + "libtextclassifier_fbgen_utils_container_bit-vector", "libtextclassifier_fbgen_utils_intents_intent-config", + "libtextclassifier_fbgen_utils_normalization", + "libtextclassifier_fbgen_utils_grammar_semantics_expression", + "libtextclassifier_fbgen_utils_grammar_rules", + "libtextclassifier_fbgen_utils_codepoint-range", + "libtextclassifier_fbgen_utils_tokenizer", + "libtextclassifier_fbgen_utils_i18n_language-tag", ], } diff --git a/native/JavaTests.bp b/native/JavaTests.bp index 4c22ba3..78d5748 100644 --- a/native/JavaTests.bp +++ b/native/JavaTests.bp @@ -17,30 +17,30 @@ filegroup { name: "libtextclassifier_java_test_sources", srcs: [ - "actions/actions-suggestions_test.cc", "actions/grammar-actions_test.cc", + "actions/actions-suggestions_test.cc", "annotator/pod_ner/pod-ner-impl_test.cc", "annotator/datetime/regex-parser_test.cc", - "annotator/datetime/datetime-grounder_test.cc", "annotator/datetime/grammar-parser_test.cc", - "utils/grammar/parsing/lexer_test.cc", - "utils/regex-match_test.cc", - "utils/calendar/calendar_test.cc", + "annotator/datetime/datetime-grounder_test.cc", "utils/intents/intent-generator-test-lib.cc", - "annotator/grammar/grammar-annotator_test.cc", - "annotator/grammar/test-utils.cc", + "utils/calendar/calendar_test.cc", + "utils/regex-match_test.cc", + "utils/grammar/parsing/lexer_test.cc", "annotator/number/number_test-include.cc", "annotator/annotator_test-include.cc", + "annotator/grammar/grammar-annotator_test.cc", + "annotator/grammar/test-utils.cc", "utils/utf8/unilib_test-include.cc", - "utils/grammar/parsing/parser_test.cc", "utils/grammar/analyzer_test.cc", "utils/grammar/semantics/composer_test.cc", - "utils/grammar/semantics/evaluators/merge-values-eval_test.cc", - "utils/grammar/semantics/evaluators/constituent-eval_test.cc", - "utils/grammar/semantics/evaluators/parse-number-eval_test.cc", "utils/grammar/semantics/evaluators/arithmetic-eval_test.cc", - "utils/grammar/semantics/evaluators/span-eval_test.cc", + "utils/grammar/semantics/evaluators/merge-values-eval_test.cc", "utils/grammar/semantics/evaluators/const-eval_test.cc", "utils/grammar/semantics/evaluators/compose-eval_test.cc", + "utils/grammar/semantics/evaluators/span-eval_test.cc", + "utils/grammar/semantics/evaluators/parse-number-eval_test.cc", + "utils/grammar/semantics/evaluators/constituent-eval_test.cc", + "utils/grammar/parsing/parser_test.cc", ], } diff --git a/native/actions/actions-entity-data.fbs b/native/actions/actions-entity-data.fbs index 21584b6..21584b6 100755..100644 --- a/native/actions/actions-entity-data.fbs +++ b/native/actions/actions-entity-data.fbs diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc index f437c8d..830976e 100644 --- a/native/actions/actions-suggestions.cc +++ b/native/actions/actions-suggestions.cc @@ -1040,13 +1040,13 @@ bool ActionsSuggestions::SuggestActionsFromModel( return ReadModelOutput(interpreter->get(), options, response); } -Status ActionsSuggestions::SuggestActionsFromDeepClu( +Status ActionsSuggestions::SuggestActionsFromConversationIntentDetection( const Conversation& conversation, const ActionSuggestionOptions& options, std::vector<ActionSuggestion>* actions) const { - std::vector<ActionSuggestion> deep_clu_actions; - TC3_ASSIGN_OR_RETURN(deep_clu_actions, - deep_clu_->SuggestActions(conversation, options)); - for (const auto& action : deep_clu_actions) { + TC3_ASSIGN_OR_RETURN( + std::vector<ActionSuggestion> new_actions, + conversation_intent_detection_->SuggestActions(conversation, options)); + for (auto& action : new_actions) { actions->push_back(std::move(action)); } return Status::OK; @@ -1381,12 +1381,13 @@ bool ActionsSuggestions::GatherActionsSuggestions( return true; } - if (deep_clu_) { + if (conversation_intent_detection_) { // TODO(zbin): Ensure the deduplication/ranking logic in ranker.cc works. - auto actions = SuggestActionsFromDeepClu(annotated_conversation, options, - &response->actions); + auto actions = SuggestActionsFromConversationIntentDetection( + annotated_conversation, options, &response->actions); if (!actions.ok()) { - TC3_LOG(ERROR) << "Could not run DeepCLU: " << actions.error_message(); + TC3_LOG(ERROR) << "Could not run conversation intent detection: " + << actions.error_message(); return false; } } @@ -1475,14 +1476,15 @@ const ActionsModel* ViewActionsModel(const void* buffer, int size) { return LoadAndVerifyModel(reinterpret_cast<const uint8_t*>(buffer), size); } -bool ActionsSuggestions::InitializeDeepClu( +bool ActionsSuggestions::InitializeConversationIntentDetection( const std::string& serialized_config) { - auto deep_clu = std::make_unique<DeepClu>(); - if (!deep_clu->Initialize(serialized_config).ok()) { - TC3_LOG(ERROR) << "Failed to initialize DeepCLU."; + auto conversation_intent_detection = + std::make_unique<ConversationIntentDetection>(); + if (!conversation_intent_detection->Initialize(serialized_config).ok()) { + TC3_LOG(ERROR) << "Failed to initialize conversation intent detection."; return false; } - deep_clu_ = std::move(deep_clu); + conversation_intent_detection_ = std::move(conversation_intent_detection); return true; } diff --git a/native/actions/actions-suggestions.h b/native/actions/actions-suggestions.h index 08a3f65..32edc78 100644 --- a/native/actions/actions-suggestions.h +++ b/native/actions/actions-suggestions.h @@ -25,7 +25,7 @@ #include <vector> #include "actions/actions_model_generated.h" -#include "actions/deep_clu/deep-clu.h" +#include "actions/conversation_intent_detection/conversation-intent-detection.h" #include "actions/feature-processor.h" #include "actions/grammar-actions.h" #include "actions/ranker.h" @@ -105,7 +105,8 @@ class ActionsSuggestions { const Conversation& conversation, const Annotator* annotator, const ActionSuggestionOptions& options = ActionSuggestionOptions()) const; - bool InitializeDeepClu(const std::string& serialized_config); + bool InitializeConversationIntentDetection( + const std::string& serialized_config); const ActionsModel* model() const; const reflection::Schema* entity_data_schema() const; @@ -193,7 +194,7 @@ class ActionsSuggestions { ActionsSuggestionsResponse* response, std::unique_ptr<tflite::Interpreter>* interpreter) const; - Status SuggestActionsFromDeepClu( + Status SuggestActionsFromConversationIntentDetection( const Conversation& conversation, const ActionSuggestionOptions& options, std::vector<ActionSuggestion>* actions) const; @@ -267,8 +268,9 @@ class ActionsSuggestions { // Low confidence input ngram classifier. std::unique_ptr<const SensitiveTopicModelBase> sensitive_model_; - // DeepCLU model for additional actions. - std::unique_ptr<const DeepClu> deep_clu_; + // Conversation intent detection model for additional actions. + std::unique_ptr<const ConversationIntentDetection> + conversation_intent_detection_; }; // Interprets the buffer as a Model flatbuffer and returns it for reading. diff --git a/native/actions/actions-suggestions_test.cc b/native/actions/actions-suggestions_test.cc index 3b0ca0f..6e66b32 100644 --- a/native/actions/actions-suggestions_test.cc +++ b/native/actions/actions-suggestions_test.cc @@ -61,6 +61,8 @@ constexpr char kMultiTaskSrP13nModelFileName[] = "actions_suggestions_test.multi_task_sr_p13n.model"; constexpr char kMultiTaskSrEmojiModelFileName[] = "actions_suggestions_test.multi_task_sr_emoji.model"; +constexpr char kSensitiveTFliteModelFileName[] = + "actions_suggestions_test.sensitive_tflite.model"; std::string ReadFile(const std::string& file_name) { std::ifstream file_stream(file_name); @@ -1808,5 +1810,19 @@ TEST_F(ActionsSuggestionsTest, SuggestsActionsFromMultiTaskSrEmojiModel) { EXPECT_EQ(response.actions[1].type, "REPLY_SUGGESTION"); } +TEST_F(ActionsSuggestionsTest, SuggestsActionsFromSensitiveTfLiteModel) { + std::unique_ptr<ActionsSuggestions> actions_suggestions = + LoadTestModel(kSensitiveTFliteModelFileName); + const ActionsSuggestionsResponse response = + actions_suggestions->SuggestActions( + {{{/*user_id=*/1, "I want to kill myself", + /*reference_time_ms_utc=*/0, + /*reference_timezone=*/"Europe/Zurich", + /*annotations=*/{}, + /*locales=*/"en"}}}); + EXPECT_EQ(response.actions.size(), 0); + EXPECT_TRUE(response.output_filtered_low_confidence); +} + } // namespace } // namespace libtextclassifier3 diff --git a/native/actions/actions_jni.cc b/native/actions/actions_jni.cc index 9c71ab6..5981a17 100644 --- a/native/actions/actions_jni.cc +++ b/native/actions/actions_jni.cc @@ -567,7 +567,8 @@ TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeGetNativeModelPtr) reinterpret_cast<ActionsSuggestionsJniContext*>(ptr)->model()); } -TC3_JNI_METHOD(jboolean, TC3_ACTIONS_CLASS_NAME, nativeInitializeDeepClu) +TC3_JNI_METHOD(jboolean, TC3_ACTIONS_CLASS_NAME, + nativeInitializeConversationIntentDetection) (JNIEnv* env, jobject thiz, jlong ptr, jbyteArray jserialized_config) { if (!ptr) { return false; @@ -579,6 +580,7 @@ TC3_JNI_METHOD(jboolean, TC3_ACTIONS_CLASS_NAME, nativeInitializeDeepClu) std::string serialized_config; TC3_ASSIGN_OR_RETURN_0( serialized_config, JByteArrayToString(env, jserialized_config), - TC3_LOG(ERROR) << "Could not convert serialized DeepCLU config."); - return model->InitializeDeepClu(serialized_config); + TC3_LOG(ERROR) << "Could not convert serialized conversation intent " + "detection config."); + return model->InitializeConversationIntentDetection(serialized_config); } diff --git a/native/actions/actions_jni.h b/native/actions/actions_jni.h index f693101..5265a9c 100644 --- a/native/actions/actions_jni.h +++ b/native/actions/actions_jni.h @@ -41,7 +41,8 @@ TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelWithOffset) (JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size, jbyteArray serialized_preconditions); -TC3_JNI_METHOD(jboolean, TC3_ACTIONS_CLASS_NAME, nativeInitializeDeepClu) +TC3_JNI_METHOD(jboolean, TC3_ACTIONS_CLASS_NAME, + nativeInitializeConversationIntentDetection) (JNIEnv* env, jobject thiz, jlong ptr, jbyteArray jserialized_config); TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions) diff --git a/native/actions/actions_model.fbs b/native/actions/actions_model.fbs index 8c03eeb..8c03eeb 100755..100644 --- a/native/actions/actions_model.fbs +++ b/native/actions/actions_model.fbs diff --git a/native/actions/deep_clu/deep-clu-dummy.h b/native/actions/conversation_intent_detection/conversation-intent-detection-dummy.h index e547c70..66255c5 100644 --- a/native/actions/deep_clu/deep-clu-dummy.h +++ b/native/actions/conversation_intent_detection/conversation-intent-detection-dummy.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_ACTIONS_DEEP_CLU_DEEP_CLU_DUMMY_H_ -#define LIBTEXTCLASSIFIER_ACTIONS_DEEP_CLU_DEEP_CLU_DUMMY_H_ +#ifndef LIBTEXTCLASSIFIER_ACTIONS_CONVERSATION_INTENT_DETECTION_CONVERSATION_INTENT_DETECTION_DUMMY_H_ +#define LIBTEXTCLASSIFIER_ACTIONS_CONVERSATION_INTENT_DETECTION_CONVERSATION_INTENT_DETECTION_DUMMY_H_ #include <string> #include <vector> @@ -26,10 +26,10 @@ namespace libtextclassifier3 { -// A dummy implementation of DeepCLU. -class DeepClu { +// A dummy implementation of conversation intent detection. +class ConversationIntentDetection { public: - DeepClu() {} + ConversationIntentDetection() {} Status Initialize(const std::string& serialized_config) { return Status::OK; } @@ -42,4 +42,4 @@ class DeepClu { } // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_ACTIONS_DEEP_CLU_DEEP_CLU_DUMMY_H_ +#endif // LIBTEXTCLASSIFIER_ACTIONS_CONVERSATION_INTENT_DETECTION_CONVERSATION_INTENT_DETECTION_DUMMY_H_ diff --git a/native/actions/deep_clu/deep-clu.h b/native/actions/conversation_intent_detection/conversation-intent-detection.h index a10641b..949ceaf 100644 --- a/native/actions/deep_clu/deep-clu.h +++ b/native/actions/conversation_intent_detection/conversation-intent-detection.h @@ -14,9 +14,9 @@ * limitations under the License. */ -#ifndef LIBTEXTCLASSIFIER_ACTIONS_DEEP_CLU_DEEP_CLU_H_ -#define LIBTEXTCLASSIFIER_ACTIONS_DEEP_CLU_DEEP_CLU_H_ +#ifndef LIBTEXTCLASSIFIER_ACTIONS_CONVERSATION_INTENT_DETECTION_CONVERSATION_INTENT_DETECTION_H_ +#define LIBTEXTCLASSIFIER_ACTIONS_CONVERSATION_INTENT_DETECTION_CONVERSATION_INTENT_DETECTION_H_ -#include "actions/deep_clu/deep-clu-dummy.h" +#include "actions/conversation_intent_detection/conversation-intent-detection-dummy.h" -#endif // LIBTEXTCLASSIFIER_ACTIONS_DEEP_CLU_DEEP_CLU_H_ +#endif // LIBTEXTCLASSIFIER_ACTIONS_CONVERSATION_INTENT_DETECTION_CONVERSATION_INTENT_DETECTION_H_ diff --git a/native/actions/test_data/actions_suggestions_grammar_test.model b/native/actions/test_data/actions_suggestions_grammar_test.model Binary files differindex 8c46dec..d900928 100644 --- a/native/actions/test_data/actions_suggestions_grammar_test.model +++ b/native/actions/test_data/actions_suggestions_grammar_test.model diff --git a/native/actions/test_data/actions_suggestions_test.model b/native/actions/test_data/actions_suggestions_test.model Binary files differindex 5e43fd0..aa62c0a 100644 --- a/native/actions/test_data/actions_suggestions_test.model +++ b/native/actions/test_data/actions_suggestions_test.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model Binary files differindex b1912f5..50918e5 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model Binary files differindex 772c1bf..b43e6d7 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model Binary files differindex eb75802..6a71da3 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model Binary files differindex 889b7e0..72f4d9d 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model Binary files differindex 3da6e29..a6c8118 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model diff --git a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model Binary files differnew file mode 100644 index 0000000..4a120b2 --- /dev/null +++ b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model diff --git a/native/actions/test_data/en_sensitive_topic_2019117.tflite b/native/actions/test_data/en_sensitive_topic_2019117.tflite Binary files differnew file mode 100644 index 0000000..48edfbd --- /dev/null +++ b/native/actions/test_data/en_sensitive_topic_2019117.tflite diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc index 78f513e..e296a64 100644 --- a/native/annotator/annotator.cc +++ b/native/annotator/annotator.cc @@ -2739,14 +2739,19 @@ bool Annotator::ParseAndFillInMoneyAmount( std::string quantity; GetMoneyQuantityFromCapturingGroup(match, config, context_unicode, &quantity, &quantity_exponent); - if ((quantity_exponent > 0 && quantity_exponent < 9) || - (quantity_exponent == 9 && data->money->amount_whole_part <= 2)) { - data->money->amount_whole_part = + if (quantity_exponent > 0 && quantity_exponent <= 9) { + const double amount_whole_part = data->money->amount_whole_part * pow(10, quantity_exponent) + data->money->nanos / pow(10, 9 - quantity_exponent); - data->money->nanos = data->money->nanos % - static_cast<int>(pow(10, 9 - quantity_exponent)) * - pow(10, quantity_exponent); + // TODO(jacekj): Change type of `data->money->amount_whole_part` to int64 + // (and `std::numeric_limits<int>::max()` to + // `std::numeric_limits<int64>::max()`). + if (amount_whole_part < std::numeric_limits<int>::max()) { + data->money->amount_whole_part = amount_whole_part; + data->money->nanos = data->money->nanos % + static_cast<int>(pow(10, 9 - quantity_exponent)) * + pow(10, quantity_exponent); + } } if (quantity_exponent > 0) { data->money->unnormalized_amount = strings::JoinStrings( diff --git a/native/annotator/datetime/datetime-grounder.cc b/native/annotator/datetime/datetime-grounder.cc index aacb917..7d5f440 100644 --- a/native/annotator/datetime/datetime-grounder.cc +++ b/native/annotator/datetime/datetime-grounder.cc @@ -58,7 +58,8 @@ const std::unordered_map<int, int> kMonthDefaultLastDayMap( bool IsValidDatetime(const AbsoluteDateTime* absolute_datetime) { // Sanity Checks. if (absolute_datetime->minute() > 59 || absolute_datetime->second() > 59 || - absolute_datetime->hour() > 23 || absolute_datetime->month() > 12) { + absolute_datetime->hour() > 23 || absolute_datetime->month() > 12 || + absolute_datetime->month() == 0) { return false; } if (absolute_datetime->day() >= 0) { diff --git a/native/annotator/datetime/datetime-grounder_test.cc b/native/annotator/datetime/datetime-grounder_test.cc index e55bbfc..121aae8 100644 --- a/native/annotator/datetime/datetime-grounder_test.cc +++ b/native/annotator/datetime/datetime-grounder_test.cc @@ -261,6 +261,12 @@ TEST_F(DatetimeGrounderTest, InValidUngroundedDatetime) { /*hour=*/11, /*minute=*/59, /*second=*/99, grammar::datetime::Meridiem_AM) .get()); + + VerifyInValidUngroundedDatetime( + BuildAbsoluteDatetime(/*year=*/2000, /*month=*/00, /*day=*/28, + /*hour=*/11, /*minute=*/59, /*second=*/99, + grammar::datetime::Meridiem_AM) + .get()); } TEST_F(DatetimeGrounderTest, ValidUngroundedDatetime) { diff --git a/native/annotator/datetime/datetime.fbs b/native/annotator/datetime/datetime.fbs index 9a96bae..9a96bae 100755..100644 --- a/native/annotator/datetime/datetime.fbs +++ b/native/annotator/datetime/datetime.fbs diff --git a/native/annotator/entity-data.fbs b/native/annotator/entity-data.fbs index f82eb44..f82eb44 100755..100644 --- a/native/annotator/entity-data.fbs +++ b/native/annotator/entity-data.fbs diff --git a/native/annotator/experimental/experimental.fbs b/native/annotator/experimental/experimental.fbs index 6e15d04..6e15d04 100755..100644 --- a/native/annotator/experimental/experimental.fbs +++ b/native/annotator/experimental/experimental.fbs diff --git a/native/annotator/model.fbs b/native/annotator/model.fbs index 57187f5..57187f5 100755..100644 --- a/native/annotator/model.fbs +++ b/native/annotator/model.fbs diff --git a/native/annotator/person_name/person_name_model.fbs b/native/annotator/person_name/person_name_model.fbs index b15543f..b15543f 100755..100644 --- a/native/annotator/person_name/person_name_model.fbs +++ b/native/annotator/person_name/person_name_model.fbs diff --git a/native/annotator/pod_ner/pod-ner-impl.cc b/native/annotator/pod_ner/pod-ner-impl.cc index 2d7f0a2..666b7c7 100644 --- a/native/annotator/pod_ner/pod-ner-impl.cc +++ b/native/annotator/pod_ner/pod-ner-impl.cc @@ -126,7 +126,7 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter( ::tflite::BuiltinOperator_EXPAND_DIMS, ::tflite::ops::builtin::Register_EXPAND_DIMS()); mutable_resolver->AddCustom( - "LayerNorm", ::tflite::ops::custom::Register_LAYER_NORM()); + "LayerNorm", ::seq_flow_lite::ops::custom::Register_LAYER_NORM()); }); std::unique_ptr<tflite::Interpreter> tflite_interpreter; @@ -253,7 +253,8 @@ std::vector<LabelT> PodNerAnnotator::ReadResultsFromInterpreter( float max_prob = 0.0f; int max_index = 0; for (int cindex = 0; cindex < output->dims->data[2]; ++cindex) { - const float probability = ::tflite::PodDequantize(*output, index++); + const float probability = + ::seq_flow_lite::PodDequantize(*output, index++); if (probability > max_prob) { max_prob = probability; max_index = cindex; diff --git a/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.cc b/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.cc index df2b55f..e28b04d 100644 --- a/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.cc +++ b/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.cc @@ -36,7 +36,7 @@ limitations under the License. #include "tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h" #include "tensorflow/lite/kernels/kernel_util.h" -namespace tflite { +namespace seq_flow_lite { namespace ops { namespace custom { @@ -344,4 +344,4 @@ TfLiteRegistration* Register_LAYER_NORM() { } // namespace custom } // namespace ops -} // namespace tflite +} // namespace seq_flow_lite diff --git a/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h b/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h index 017d21a..6d84ca4 100644 --- a/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h +++ b/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h @@ -33,7 +33,7 @@ limitations under the License. #include "tensorflow/lite/kernels/register.h" -namespace tflite { +namespace seq_flow_lite { namespace ops { namespace custom { @@ -41,6 +41,6 @@ TfLiteRegistration* Register_LAYER_NORM(); } // namespace custom } // namespace ops -} // namespace tflite +} // namespace seq_flow_lite #endif // LEARNING_EXPANDER_POD_DEEP_POD_TFLITE_HANDLERS_LAYER_NORM_H_ diff --git a/native/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h b/native/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h index f8e3836..7f2db41 100644 --- a/native/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h +++ b/native/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h @@ -36,7 +36,7 @@ limitations under the License. #include "tensorflow/lite/context.h" -namespace tflite { +namespace seq_flow_lite { // Returns the original (dequantized) value of 8bit value. inline float PodDequantizeValue(const TfLiteTensor& tensor, uint8_t value) { @@ -64,6 +64,6 @@ inline uint8_t PodQuantize(float value, int32_t zero_point, return static_cast<uint8_t>(std::max(std::min(255, integer_value), 0)); } -} // namespace tflite +} // namespace seq_flow_lite #endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_QUANTIZATION_UTIL_H_ diff --git a/native/utils/codepoint-range.fbs b/native/utils/codepoint-range.fbs index 135ce30..135ce30 100755..100644 --- a/native/utils/codepoint-range.fbs +++ b/native/utils/codepoint-range.fbs diff --git a/native/utils/container/bit-vector.fbs b/native/utils/container/bit-vector.fbs index d117ee5..d117ee5 100755..100644 --- a/native/utils/container/bit-vector.fbs +++ b/native/utils/container/bit-vector.fbs diff --git a/native/utils/flatbuffers/flatbuffers.fbs b/native/utils/flatbuffers/flatbuffers.fbs index 155e8f8..155e8f8 100755..100644 --- a/native/utils/flatbuffers/flatbuffers.fbs +++ b/native/utils/flatbuffers/flatbuffers.fbs diff --git a/native/utils/grammar/rules.fbs b/native/utils/grammar/rules.fbs index bc0136c..bc0136c 100755..100644 --- a/native/utils/grammar/rules.fbs +++ b/native/utils/grammar/rules.fbs diff --git a/native/utils/grammar/semantics/expression.fbs b/native/utils/grammar/semantics/expression.fbs index 5397407..5397407 100755..100644 --- a/native/utils/grammar/semantics/expression.fbs +++ b/native/utils/grammar/semantics/expression.fbs diff --git a/native/utils/grammar/testing/value.fbs b/native/utils/grammar/testing/value.fbs index 0429491..0429491 100755..100644 --- a/native/utils/grammar/testing/value.fbs +++ b/native/utils/grammar/testing/value.fbs diff --git a/native/utils/i18n/language-tag.fbs b/native/utils/i18n/language-tag.fbs index a2e1077..a2e1077 100755..100644 --- a/native/utils/i18n/language-tag.fbs +++ b/native/utils/i18n/language-tag.fbs diff --git a/native/utils/intents/intent-config.fbs b/native/utils/intents/intent-config.fbs index 672eb9d..672eb9d 100755..100644 --- a/native/utils/intents/intent-config.fbs +++ b/native/utils/intents/intent-config.fbs diff --git a/native/utils/normalization.fbs b/native/utils/normalization.fbs index 4d43f10..4d43f10 100755..100644 --- a/native/utils/normalization.fbs +++ b/native/utils/normalization.fbs diff --git a/native/utils/resources.fbs b/native/utils/resources.fbs index b4d9b83..b4d9b83 100755..100644 --- a/native/utils/resources.fbs +++ b/native/utils/resources.fbs diff --git a/native/utils/tflite-model-executor.cc b/native/utils/tflite-model-executor.cc index 104dedc..36db3e9 100644 --- a/native/utils/tflite-model-executor.cc +++ b/native/utils/tflite-model-executor.cc @@ -39,6 +39,7 @@ TfLiteRegistration* Register_REDUCE_ANY(); TfLiteRegistration* Register_SOFTMAX(); TfLiteRegistration* Register_GATHER(); TfLiteRegistration* Register_GATHER_ND(); +TfLiteRegistration* Register_IF(); TfLiteRegistration* Register_ROUND(); TfLiteRegistration* Register_ZEROS_LIKE(); TfLiteRegistration* Register_TRANSPOSE(); @@ -145,8 +146,10 @@ void RegisterSelectedOps(tflite::MutableOpResolver* resolver) { resolver->AddBuiltin(::tflite::BuiltinOperator_GATHER_ND, ::tflite::ops::builtin::Register_GATHER_ND(), /*version=*/2); - resolver->AddBuiltin(::tflite::BuiltinOperator_ROUND, - ::tflite::ops::builtin::Register_ROUND()); + resolver->AddBuiltin(::tflite::BuiltinOperator_IF, + ::tflite::ops::builtin::Register_IF()), + resolver->AddBuiltin(::tflite::BuiltinOperator_ROUND, + ::tflite::ops::builtin::Register_ROUND()); resolver->AddBuiltin(::tflite::BuiltinOperator_ZEROS_LIKE, ::tflite::ops::builtin::Register_ZEROS_LIKE()); resolver->AddBuiltin(tflite::BuiltinOperator_TRANSPOSE, diff --git a/native/utils/tflite/blacklist_base.cc b/native/utils/tflite/blacklist_base.cc index 8dcfacb..214283b 100644 --- a/native/utils/tflite/blacklist_base.cc +++ b/native/utils/tflite/blacklist_base.cc @@ -77,9 +77,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } } } else if (output_categories->type == kTfLiteUInt8) { - const uint8_t one = PodQuantize(1.0, output_categories->params.zero_point, - 1.0 / output_categories->params.scale); - const uint8_t zero = PodQuantize(0.0, output_categories->params.zero_point, + const uint8_t one = + ::seq_flow_lite::PodQuantize(1.0, output_categories->params.zero_point, + 1.0 / output_categories->params.scale); + const uint8_t zero = + ::seq_flow_lite::PodQuantize(0.0, output_categories->params.zero_point, 1.0 / output_categories->params.scale); for (int i = 0; i < input_size; i++) { absl::flat_hash_set<int> categories = op->GetCategories(i); diff --git a/native/utils/tflite/string_projection_base.cc b/native/utils/tflite/string_projection_base.cc index ac0d6eb..d185f52 100644 --- a/native/utils/tflite/string_projection_base.cc +++ b/native/utils/tflite/string_projection_base.cc @@ -125,7 +125,8 @@ void StringProjectionOpBase::DenseLshProjection( float seed = hash_function_[hash_bit].AsFloat(); float bit = running_sign_bit(input, weight, seed, key.get()); output->data.uint8[batch * num_hash_ * num_bits_ + hash_bit] = - PodQuantize(bit, output->params.zero_point, inverse_scale); + seq_flow_lite::PodQuantize(bit, output->params.zero_point, + inverse_scale); } } } diff --git a/native/utils/tokenizer.fbs b/native/utils/tokenizer.fbs index c0a3919..c0a3919 100755..100644 --- a/native/utils/tokenizer.fbs +++ b/native/utils/tokenizer.fbs diff --git a/native/utils/zlib/buffer.fbs b/native/utils/zlib/buffer.fbs index 60da23e..60da23e 100755..100644 --- a/native/utils/zlib/buffer.fbs +++ b/native/utils/zlib/buffer.fbs diff --git a/notification/Android.bp b/notification/Android.bp index 277985b..782d5cb 100644 --- a/notification/Android.bp +++ b/notification/Android.bp @@ -28,7 +28,7 @@ android_library { name: "TextClassifierNotificationLib", static_libs: ["TextClassifierNotificationLibNoManifest"], sdk_version: "system_current", - min_sdk_version: "29", + min_sdk_version: "30", manifest: "AndroidManifest.xml", } @@ -41,6 +41,6 @@ android_library { "guava", ], sdk_version: "system_current", - min_sdk_version: "29", + min_sdk_version: "30", manifest: "LibNoManifest_AndroidManifest.xml", } diff --git a/notification/AndroidManifest.xml b/notification/AndroidManifest.xml index 3153d1d..5a98ea3 100644 --- a/notification/AndroidManifest.xml +++ b/notification/AndroidManifest.xml @@ -1,7 +1,7 @@ <manifest xmlns:android="http://schemas.android.com/apk/res/android" package="com.android.textclassifier.notification"> - <uses-sdk android:minSdkVersion="29" /> + <uses-sdk android:minSdkVersion="30" /> <application> <activity @@ -10,4 +10,4 @@ android:theme="@android:style/Theme.NoDisplay" /> </application> -</manifest>
\ No newline at end of file +</manifest> diff --git a/notification/LibNoManifest_AndroidManifest.xml b/notification/LibNoManifest_AndroidManifest.xml index b9ebf7d..06e8da4 100644 --- a/notification/LibNoManifest_AndroidManifest.xml +++ b/notification/LibNoManifest_AndroidManifest.xml @@ -25,6 +25,6 @@ <manifest xmlns:android="http://schemas.android.com/apk/res/android" package="com.android.textclassifier.notification"> - <uses-sdk android:minSdkVersion="29" android:targetSdkVersion="29"/> + <uses-sdk android:minSdkVersion="30" android:targetSdkVersion="30"/> </manifest> diff --git a/notification/lint-baseline.xml b/notification/lint-baseline.xml deleted file mode 100644 index f2530d7..0000000 --- a/notification/lint-baseline.xml +++ /dev/null @@ -1,37 +0,0 @@ -<?xml version="1.0" encoding="UTF-8"?> -<issues format="5" by="lint 4.1.0" client="cli" variant="all" version="4.1.0"> - - <issue - id="NewApi" - message="Call requires API level R (current min is 29): `android.app.Notification#getContextualActions`" - errorLine1=" boolean hasAppGeneratedContextualActions = !notification.getContextualActions().isEmpty();" - errorLine2=" ~~~~~~~~~~~~~~~~~~~~"> - <location - file="external/libtextclassifier/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java" - line="246" - column="62"/> - </issue> - - <issue - id="NewApi" - message="Call requires API level R (current min is 29): `android.app.Notification#findRemoteInputActionPair`" - errorLine1=" notification.findRemoteInputActionPair(/* requiresFreeform */ true);" - errorLine2=" ~~~~~~~~~~~~~~~~~~~~~~~~~"> - <location - file="external/libtextclassifier/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java" - line="249" - column="22"/> - </issue> - - <issue - id="NewApi" - message="Call requires API level R (current min is 29): `android.app.Notification.MessagingStyle.Message#getMessagesFromBundleArray`" - errorLine1=" Message.getMessagesFromBundleArray(" - errorLine2=" ~~~~~~~~~~~~~~~~~~~~~~~~~~"> - <location - file="external/libtextclassifier/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java" - line="440" - column="17"/> - </issue> - -</issues> diff --git a/notification/tests/AndroidManifest.xml b/notification/tests/AndroidManifest.xml index 81308e3..d3da067 100644 --- a/notification/tests/AndroidManifest.xml +++ b/notification/tests/AndroidManifest.xml @@ -2,8 +2,8 @@ package="com.android.textclassifier.notification"> <uses-sdk - android:minSdkVersion="29" - android:targetSdkVersion="29" /> + android:minSdkVersion="30" + android:targetSdkVersion="30" /> <application> <uses-library android:name="android.test.runner"/> |