summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorandroid-build-team Robot <android-build-team-robot@google.com>2021-05-01 03:04:43 +0000
committerandroid-build-team Robot <android-build-team-robot@google.com>2021-05-01 03:04:43 +0000
commitc9d882c18aa0cc12bb74ed0f13e0d6a45f931911 (patch)
tree3f94817a4e3b75f6708bbbd1a4bf10794b481823
parent9334ab60de3b14c64e10733550bc26d6910b83e1 (diff)
parent97265d0e193cec31f858b14c8c1d8d2e0f0fda8a (diff)
downloadlibtextclassifier-c9d882c18aa0cc12bb74ed0f13e0d6a45f931911.tar.gz
Snap for 7328689 from 97265d0e193cec31f858b14c8c1d8d2e0f0fda8a to sc-v2-release
Change-Id: I308902c887ded131677ee1ca3dfcccc01eec94e2
-rw-r--r--java/src/com/android/textclassifier/DefaultTextClassifierService.java64
-rw-r--r--java/src/com/android/textclassifier/TextClassifierImpl.java36
-rw-r--r--java/src/com/android/textclassifier/common/TextClassifierSettings.java11
-rw-r--r--java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java48
-rw-r--r--java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java55
-rw-r--r--java/src/com/android/textclassifier/common/statsd/WidgetTypeConverter.java57
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java59
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java4
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/common/statsd/GenerateLinksLoggerTest.java38
-rw-r--r--jni/com/google/android/textclassifier/ActionsSuggestionsModel.java11
-rw-r--r--native/FlatBufferHeaders.bp170
-rw-r--r--native/JavaTests.bp24
-rw-r--r--[-rwxr-xr-x]native/actions/actions-entity-data.fbs0
-rw-r--r--native/actions/actions-suggestions.cc30
-rw-r--r--native/actions/actions-suggestions.h12
-rw-r--r--native/actions/actions-suggestions_test.cc16
-rw-r--r--native/actions/actions_jni.cc8
-rw-r--r--native/actions/actions_jni.h3
-rw-r--r--[-rwxr-xr-x]native/actions/actions_model.fbs0
-rw-r--r--native/actions/conversation_intent_detection/conversation-intent-detection-dummy.h (renamed from native/actions/deep_clu/deep-clu-dummy.h)12
-rw-r--r--native/actions/conversation_intent_detection/conversation-intent-detection.h (renamed from native/actions/deep_clu/deep-clu.h)8
-rw-r--r--native/actions/test_data/actions_suggestions_grammar_test.modelbin145160 -> 145144 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.modelbin3387328 -> 3387328 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_9heads.modelbin3874656 -> 3874656 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.modelbin3808816 -> 3808528 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.modelbin3853520 -> 3853536 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.modelbin4671808 -> 4671824 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.modelbin5045120 -> 5045424 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.sensitive_tflite.modelbin0 -> 7111552 bytes
-rw-r--r--native/actions/test_data/en_sensitive_topic_2019117.tflitebin0 -> 439816 bytes
-rw-r--r--native/annotator/annotator.cc17
-rw-r--r--native/annotator/datetime/datetime-grounder.cc3
-rw-r--r--native/annotator/datetime/datetime-grounder_test.cc6
-rw-r--r--[-rwxr-xr-x]native/annotator/datetime/datetime.fbs0
-rw-r--r--[-rwxr-xr-x]native/annotator/entity-data.fbs0
-rw-r--r--[-rwxr-xr-x]native/annotator/experimental/experimental.fbs0
-rw-r--r--[-rwxr-xr-x]native/annotator/model.fbs0
-rw-r--r--[-rwxr-xr-x]native/annotator/person_name/person_name_model.fbs0
-rw-r--r--native/annotator/pod_ner/pod-ner-impl.cc5
-rw-r--r--native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.cc4
-rw-r--r--native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h4
-rw-r--r--native/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h4
-rw-r--r--[-rwxr-xr-x]native/utils/codepoint-range.fbs0
-rw-r--r--[-rwxr-xr-x]native/utils/container/bit-vector.fbs0
-rw-r--r--[-rwxr-xr-x]native/utils/flatbuffers/flatbuffers.fbs0
-rw-r--r--[-rwxr-xr-x]native/utils/grammar/rules.fbs0
-rw-r--r--[-rwxr-xr-x]native/utils/grammar/semantics/expression.fbs0
-rw-r--r--[-rwxr-xr-x]native/utils/grammar/testing/value.fbs0
-rw-r--r--[-rwxr-xr-x]native/utils/i18n/language-tag.fbs0
-rw-r--r--[-rwxr-xr-x]native/utils/intents/intent-config.fbs0
-rw-r--r--[-rwxr-xr-x]native/utils/normalization.fbs0
-rw-r--r--[-rwxr-xr-x]native/utils/resources.fbs0
-rw-r--r--native/utils/tflite-model-executor.cc7
-rw-r--r--native/utils/tflite/blacklist_base.cc8
-rw-r--r--native/utils/tflite/string_projection_base.cc3
-rw-r--r--[-rwxr-xr-x]native/utils/tokenizer.fbs0
-rw-r--r--[-rwxr-xr-x]native/utils/zlib/buffer.fbs0
-rw-r--r--notification/Android.bp4
-rw-r--r--notification/AndroidManifest.xml4
-rw-r--r--notification/LibNoManifest_AndroidManifest.xml2
-rw-r--r--notification/lint-baseline.xml37
-rw-r--r--notification/tests/AndroidManifest.xml4
62 files changed, 457 insertions, 321 deletions
diff --git a/java/src/com/android/textclassifier/DefaultTextClassifierService.java b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
index 378d842..4ca058d 100644
--- a/java/src/com/android/textclassifier/DefaultTextClassifierService.java
+++ b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
@@ -25,11 +25,14 @@ import android.service.textclassifier.TextClassifierService;
import android.view.textclassifier.ConversationActions;
import android.view.textclassifier.SelectionEvent;
import android.view.textclassifier.TextClassification;
+import android.view.textclassifier.TextClassificationContext;
import android.view.textclassifier.TextClassificationSessionId;
import android.view.textclassifier.TextClassifierEvent;
import android.view.textclassifier.TextLanguage;
import android.view.textclassifier.TextLinks;
import android.view.textclassifier.TextSelection;
+import androidx.annotation.NonNull;
+import androidx.collection.LruCache;
import com.android.textclassifier.common.ModelFileManager;
import com.android.textclassifier.common.TextClassifierServiceExecutors;
import com.android.textclassifier.common.TextClassifierSettings;
@@ -45,8 +48,10 @@ import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import java.io.FileDescriptor;
import java.io.PrintWriter;
+import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.Executor;
+import javax.annotation.Nullable;
/** An implementation of a TextClassifierService. */
public final class DefaultTextClassifierService extends TextClassifierService {
@@ -60,6 +65,7 @@ public final class DefaultTextClassifierService extends TextClassifierService {
private TextClassifierSettings settings;
private ModelFileManager modelFileManager;
private BroadcastReceiver localeChangedReceiver;
+ private LruCache<TextClassificationSessionId, TextClassificationContext> sessionIdToContext;
public DefaultTextClassifierService() {
this.injector = new InjectorImpl(this);
@@ -82,6 +88,7 @@ public final class DefaultTextClassifierService extends TextClassifierService {
lowPriorityExecutor = injector.createLowPriorityExecutor();
textClassifier = injector.createTextClassifierImpl(settings, modelFileManager);
localeChangedReceiver = new LocaleChangedReceiver(modelFileManager);
+ sessionIdToContext = new LruCache<>(settings.getSessionIdToContextCacheSize());
textClassifierApiUsageLogger =
injector.createTextClassifierApiUsageLogger(settings, lowPriorityExecutor);
@@ -97,13 +104,26 @@ public final class DefaultTextClassifierService extends TextClassifierService {
}
@Override
+ public void onCreateTextClassificationSession(
+ @NonNull TextClassificationContext context, @NonNull TextClassificationSessionId sessionId) {
+ sessionIdToContext.put(sessionId, context);
+ }
+
+ @Override
+ public void onDestroyTextClassificationSession(@NonNull TextClassificationSessionId sessionId) {
+ sessionIdToContext.remove(sessionId);
+ }
+
+ @Override
public void onSuggestSelection(
TextClassificationSessionId sessionId,
TextSelection.Request request,
CancellationSignal cancellationSignal,
Callback<TextSelection> callback) {
handleRequestAsync(
- () -> textClassifier.suggestSelection(request),
+ () ->
+ textClassifier.suggestSelection(
+ sessionId, sessionIdToTextClassificationContext(sessionId), request),
callback,
textClassifierApiUsageLogger.createSession(
TextClassifierApiUsageLogger.API_TYPE_SUGGEST_SELECTION, sessionId),
@@ -117,7 +137,9 @@ public final class DefaultTextClassifierService extends TextClassifierService {
CancellationSignal cancellationSignal,
Callback<TextClassification> callback) {
handleRequestAsync(
- () -> textClassifier.classifyText(request),
+ () ->
+ textClassifier.classifyText(
+ sessionId, sessionIdToTextClassificationContext(sessionId), request),
callback,
textClassifierApiUsageLogger.createSession(
TextClassifierApiUsageLogger.API_TYPE_CLASSIFY_TEXT, sessionId),
@@ -131,7 +153,9 @@ public final class DefaultTextClassifierService extends TextClassifierService {
CancellationSignal cancellationSignal,
Callback<TextLinks> callback) {
handleRequestAsync(
- () -> textClassifier.generateLinks(request),
+ () ->
+ textClassifier.generateLinks(
+ sessionId, sessionIdToTextClassificationContext(sessionId), request),
callback,
textClassifierApiUsageLogger.createSession(
TextClassifierApiUsageLogger.API_TYPE_GENERATE_LINKS, sessionId),
@@ -145,7 +169,9 @@ public final class DefaultTextClassifierService extends TextClassifierService {
CancellationSignal cancellationSignal,
Callback<ConversationActions> callback) {
handleRequestAsync(
- () -> textClassifier.suggestConversationActions(request),
+ () ->
+ textClassifier.suggestConversationActions(
+ sessionId, sessionIdToTextClassificationContext(sessionId), request),
callback,
textClassifierApiUsageLogger.createSession(
TextClassifierApiUsageLogger.API_TYPE_SUGGEST_CONVERSATION_ACTIONS, sessionId),
@@ -159,7 +185,9 @@ public final class DefaultTextClassifierService extends TextClassifierService {
CancellationSignal cancellationSignal,
Callback<TextLanguage> callback) {
handleRequestAsync(
- () -> textClassifier.detectLanguage(request),
+ () ->
+ textClassifier.detectLanguage(
+ sessionId, sessionIdToTextClassificationContext(sessionId), request),
callback,
textClassifierApiUsageLogger.createSession(
TextClassifierApiUsageLogger.API_TYPE_DETECT_LANGUAGES, sessionId),
@@ -168,7 +196,7 @@ public final class DefaultTextClassifierService extends TextClassifierService {
@Override
public void onSelectionEvent(TextClassificationSessionId sessionId, SelectionEvent event) {
- handleEvent(() -> textClassifier.onSelectionEvent(event));
+ handleEvent(() -> textClassifier.onSelectionEvent(sessionId, event));
}
@Override
@@ -182,9 +210,24 @@ public final class DefaultTextClassifierService extends TextClassifierService {
IndentingPrintWriter indentingPrintWriter = new IndentingPrintWriter(writer);
// TODO(licha): Also dump ModelDownloadManager for debugging
textClassifier.dump(indentingPrintWriter);
+ dumpImpl(indentingPrintWriter);
indentingPrintWriter.flush();
}
+ private void dumpImpl(IndentingPrintWriter printWriter) {
+ printWriter.println("DefaultTextClassifierService:");
+ printWriter.increaseIndent();
+ printWriter.println("sessionIdToContext:");
+ printWriter.increaseIndent();
+ for (Map.Entry<TextClassificationSessionId, TextClassificationContext> entry :
+ sessionIdToContext.snapshot().entrySet()) {
+ printWriter.printPair(entry.getKey().getValue(), entry.getValue());
+ }
+ printWriter.decreaseIndent();
+ printWriter.decreaseIndent();
+ printWriter.println();
+ }
+
private <T> void handleRequestAsync(
Callable<T> callable,
Callback<T> callback,
@@ -232,6 +275,15 @@ public final class DefaultTextClassifierService extends TextClassifierService {
MoreExecutors.directExecutor());
}
+ @Nullable
+ private TextClassificationContext sessionIdToTextClassificationContext(
+ @Nullable TextClassificationSessionId sessionId) {
+ if (sessionId == null) {
+ return null;
+ }
+ return sessionIdToContext.get(sessionId);
+ }
+
/**
* Receiver listening to locale change event. Ask ModelFileManager to do clean-up upon receiving.
*/
diff --git a/java/src/com/android/textclassifier/TextClassifierImpl.java b/java/src/com/android/textclassifier/TextClassifierImpl.java
index f5d7a47..7383bc1 100644
--- a/java/src/com/android/textclassifier/TextClassifierImpl.java
+++ b/java/src/com/android/textclassifier/TextClassifierImpl.java
@@ -33,6 +33,8 @@ import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.ConversationActions;
import android.view.textclassifier.SelectionEvent;
import android.view.textclassifier.TextClassification;
+import android.view.textclassifier.TextClassification.Request;
+import android.view.textclassifier.TextClassificationContext;
import android.view.textclassifier.TextClassificationSessionId;
import android.view.textclassifier.TextClassifier;
import android.view.textclassifier.TextClassifierEvent;
@@ -127,7 +129,11 @@ final class TextClassifierImpl {
}
@WorkerThread
- TextSelection suggestSelection(TextSelection.Request request) throws IOException {
+ TextSelection suggestSelection(
+ @Nullable TextClassificationSessionId sessionId,
+ @Nullable TextClassificationContext textClassificationContext,
+ TextSelection.Request request)
+ throws IOException {
Preconditions.checkNotNull(request);
checkMainThread();
final int rangeLength = request.getEndIndex() - request.getStartIndex();
@@ -186,7 +192,11 @@ final class TextClassifierImpl {
}
@WorkerThread
- TextClassification classifyText(TextClassification.Request request) throws IOException {
+ TextClassification classifyText(
+ @Nullable TextClassificationSessionId sessionId,
+ @Nullable TextClassificationContext textClassificationContext,
+ Request request)
+ throws IOException {
Preconditions.checkNotNull(request);
checkMainThread();
LangIdModel langId = getLangIdImpl();
@@ -226,7 +236,11 @@ final class TextClassifierImpl {
}
@WorkerThread
- TextLinks generateLinks(TextLinks.Request request) throws IOException {
+ TextLinks generateLinks(
+ @Nullable TextClassificationSessionId sessionId,
+ @Nullable TextClassificationContext textClassificationContext,
+ TextLinks.Request request)
+ throws IOException {
Preconditions.checkNotNull(request);
Preconditions.checkArgument(
request.getText().length() <= getMaxGenerateLinksTextLength(),
@@ -293,6 +307,8 @@ final class TextClassifierImpl {
langIdModelInfo = Optional.fromNullable(langIdModelInUse).transform(ModelFile::toModelInfo);
}
generateLinksLogger.logGenerateLinks(
+ sessionId,
+ textClassificationContext,
request.getText(),
links,
callingPackageName,
@@ -321,7 +337,7 @@ final class TextClassifierImpl {
}
}
- void onSelectionEvent(SelectionEvent event) {
+ void onSelectionEvent(@Nullable TextClassificationSessionId sessionId, SelectionEvent event) {
TextClassifierEvent textClassifierEvent = SelectionEventConverter.toTextClassifierEvent(event);
if (textClassifierEvent == null) {
return;
@@ -336,7 +352,11 @@ final class TextClassifierImpl {
TextClassifierEventConverter.fromPlatform(event));
}
- TextLanguage detectLanguage(TextLanguage.Request request) throws IOException {
+ TextLanguage detectLanguage(
+ @Nullable TextClassificationSessionId sessionId,
+ @Nullable TextClassificationContext textClassificationContext,
+ TextLanguage.Request request)
+ throws IOException {
Preconditions.checkNotNull(request);
checkMainThread();
final TextLanguage.Builder builder = new TextLanguage.Builder();
@@ -349,7 +369,10 @@ final class TextClassifierImpl {
return builder.build();
}
- ConversationActions suggestConversationActions(ConversationActions.Request request)
+ ConversationActions suggestConversationActions(
+ @Nullable TextClassificationSessionId sessionId,
+ @Nullable TextClassificationContext textClassificationContext,
+ ConversationActions.Request request)
throws IOException {
Preconditions.checkNotNull(request);
checkMainThread();
@@ -650,6 +673,7 @@ final class TextClassifierImpl {
printWriter.println();
settings.dump(printWriter);
+ printWriter.println();
}
}
diff --git a/java/src/com/android/textclassifier/common/TextClassifierSettings.java b/java/src/com/android/textclassifier/common/TextClassifierSettings.java
index 0f3322e..fdf259e 100644
--- a/java/src/com/android/textclassifier/common/TextClassifierSettings.java
+++ b/java/src/com/android/textclassifier/common/TextClassifierSettings.java
@@ -127,6 +127,9 @@ public final class TextClassifierSettings {
/** Sampling rate for TextClassifier API logging. */
static final String TEXTCLASSIFIER_API_LOG_SAMPLE_RATE = "textclassifier_api_log_sample_rate";
+ /** The size of the cache of the mapping of session id to text classification context. */
+ private static final String SESSION_ID_TO_CONTEXT_CACHE_SIZE = "session_id_to_context_cache_size";
+
/**
* A colon(:) separated string that specifies the configuration to use when including surrounding
* context text in language detection queries.
@@ -202,6 +205,8 @@ public final class TextClassifierSettings {
*/
private static final int TEXTCLASSIFIER_API_LOG_SAMPLE_RATE_DEFAULT = 10;
+ private static final int SESSION_ID_TO_CONTEXT_CACHE_SIZE_DEFAULT = 10;
+
// TODO(licha): Consider removing this. We can use real device config for testing.
/** DeviceConfig interface to facilitate testing. */
@VisibleForTesting
@@ -427,6 +432,11 @@ public final class TextClassifierSettings {
NAMESPACE, TEXTCLASSIFIER_API_LOG_SAMPLE_RATE, TEXTCLASSIFIER_API_LOG_SAMPLE_RATE_DEFAULT);
}
+ public int getSessionIdToContextCacheSize() {
+ return deviceConfig.getInt(
+ NAMESPACE, SESSION_ID_TO_CONTEXT_CACHE_SIZE, SESSION_ID_TO_CONTEXT_CACHE_SIZE_DEFAULT);
+ }
+
public void dump(IndentingPrintWriter pw) {
pw.println("TextClassifierSettings:");
pw.increaseIndent();
@@ -451,6 +461,7 @@ public final class TextClassifierSettings {
pw.printPair(MODEL_DOWNLOAD_MAX_ATTEMPTS, getModelDownloadMaxAttempts());
pw.decreaseIndent();
pw.printPair(TEXTCLASSIFIER_API_LOG_SAMPLE_RATE, getTextClassifierApiLogSampleRate());
+ pw.printPair(SESSION_ID_TO_CONTEXT_CACHE_SIZE, getSessionIdToContextCacheSize());
pw.decreaseIndent();
}
diff --git a/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java b/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java
index 822eb77..df63f2f 100644
--- a/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java
+++ b/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java
@@ -16,21 +16,20 @@
package com.android.textclassifier.common.statsd;
+import android.view.textclassifier.TextClassificationContext;
+import android.view.textclassifier.TextClassificationSessionId;
import android.view.textclassifier.TextClassifier;
import android.view.textclassifier.TextLinks;
import androidx.collection.ArrayMap;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
import com.android.textclassifier.common.logging.TextClassifierEvent;
-import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
-import java.util.UUID;
-import java.util.function.Supplier;
import javax.annotation.Nullable;
/** A helper for logging calls to generateLinks. */
@@ -40,7 +39,6 @@ public final class GenerateLinksLogger {
private final Random random;
private final int sampleRate;
- private final Supplier<String> randomUuidSupplier;
/**
* @param sampleRate the rate at which log events are written. (e.g. 100 means there is a 0.01
@@ -48,24 +46,14 @@ public final class GenerateLinksLogger {
* events, pass 1.
*/
public GenerateLinksLogger(int sampleRate) {
- this(sampleRate, () -> UUID.randomUUID().toString());
- }
-
- /**
- * @param sampleRate the rate at which log events are written. (e.g. 100 means there is a 0.01
- * chance that a call to logGenerateLinks results in an event being written). To write all
- * events, pass 1.
- * @param randomUuidSupplier supplies random UUIDs.
- */
- @VisibleForTesting
- GenerateLinksLogger(int sampleRate, Supplier<String> randomUuidSupplier) {
this.sampleRate = sampleRate;
random = new Random();
- this.randomUuidSupplier = Preconditions.checkNotNull(randomUuidSupplier);
}
/** Logs statistics about a call to generateLinks. */
public void logGenerateLinks(
+ @Nullable TextClassificationSessionId sessionId,
+ @Nullable TextClassificationContext textClassificationContext,
CharSequence text,
TextLinks links,
String callingPackageName,
@@ -95,20 +83,33 @@ public final class GenerateLinksLogger {
totalStats.countLink(link);
perEntityTypeStats.computeIfAbsent(entityType, k -> new LinkifyStats()).countLink(link);
}
+ int widgetType = TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_UNKNOWN;
+ if (textClassificationContext != null) {
+ widgetType = WidgetTypeConverter.toLoggingValue(textClassificationContext.getWidgetType());
+ }
- final String callId = randomUuidSupplier.get();
+ final String sessionIdStr = sessionId == null ? null : sessionId.getValue();
writeStats(
- callId, callingPackageName, null, totalStats, text, latencyMs, annotatorModel, langIdModel);
+ sessionIdStr,
+ callingPackageName,
+ null,
+ totalStats,
+ text,
+ widgetType,
+ latencyMs,
+ annotatorModel,
+ langIdModel);
// Sort the entity types to ensure the logging order is deterministic.
ImmutableList<String> sortedEntityTypes =
ImmutableList.sortedCopyOf(perEntityTypeStats.keySet());
for (String entityType : sortedEntityTypes) {
writeStats(
- callId,
+ sessionIdStr,
callingPackageName,
entityType,
perEntityTypeStats.get(entityType),
text,
+ widgetType,
latencyMs,
annotatorModel,
langIdModel);
@@ -130,11 +131,12 @@ public final class GenerateLinksLogger {
/** Writes a log event for the given stats. */
private static void writeStats(
- String callId,
+ @Nullable String sessionId,
String callingPackageName,
@Nullable String entityType,
LinkifyStats stats,
CharSequence text,
+ int widgetType,
long latencyMs,
Optional<ModelInfo> annotatorModel,
Optional<ModelInfo> langIdModel) {
@@ -142,10 +144,10 @@ public final class GenerateLinksLogger {
String langIdModelName = langIdModel.transform(ModelInfo::toModelName).or("");
TextClassifierStatsLog.write(
TextClassifierStatsLog.TEXT_LINKIFY_EVENT,
- callId,
+ sessionId,
TextClassifierEvent.TYPE_LINKS_GENERATED,
annotatorModelName,
- TextClassifierEventLogger.WidgetType.WIDGET_TYPE_UNKNOWN,
+ widgetType,
/* eventIndex */ 0,
entityType,
stats.numLinks,
@@ -161,7 +163,7 @@ public final class GenerateLinksLogger {
String.format(
Locale.US,
"%s:%s %d links (%d/%d chars) %dms %s annotator=%s langid=%s",
- callId,
+ sessionId,
entityType,
stats.numLinks,
stats.numLinksTextLength,
diff --git a/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java b/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java
index 6678142..06ad44f 100644
--- a/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java
+++ b/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java
@@ -19,7 +19,6 @@ package com.android.textclassifier.common.statsd;
import static com.google.common.base.Charsets.UTF_8;
import static com.google.common.base.Strings.nullToEmpty;
-import android.view.textclassifier.TextClassifier;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.common.logging.ResultIdUtils;
import com.android.textclassifier.common.logging.TextClassificationContext;
@@ -195,60 +194,20 @@ public final class TextClassifierEventLogger {
return ResultIdUtils.getModelNames(event.getResultId());
}
- @Nullable
- private static String getPackageName(TextClassifierEvent event) {
+ private static int getWidgetType(TextClassifierEvent event) {
TextClassificationContext eventContext = event.getEventContext();
if (eventContext == null) {
- return null;
+ return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_UNKNOWN;
}
- return eventContext.getPackageName();
+ return WidgetTypeConverter.toLoggingValue(eventContext.getWidgetType());
}
- private static int getWidgetType(TextClassifierEvent event) {
+ @Nullable
+ private static String getPackageName(TextClassifierEvent event) {
TextClassificationContext eventContext = event.getEventContext();
if (eventContext == null) {
- return WidgetType.WIDGET_TYPE_UNKNOWN;
- }
- switch (eventContext.getWidgetType()) {
- case TextClassifier.WIDGET_TYPE_UNKNOWN:
- return WidgetType.WIDGET_TYPE_UNKNOWN;
- case TextClassifier.WIDGET_TYPE_TEXTVIEW:
- return WidgetType.WIDGET_TYPE_TEXTVIEW;
- case TextClassifier.WIDGET_TYPE_EDITTEXT:
- return WidgetType.WIDGET_TYPE_EDITTEXT;
- case TextClassifier.WIDGET_TYPE_UNSELECTABLE_TEXTVIEW:
- return WidgetType.WIDGET_TYPE_UNSELECTABLE_TEXTVIEW;
- case TextClassifier.WIDGET_TYPE_WEBVIEW:
- return WidgetType.WIDGET_TYPE_WEBVIEW;
- case TextClassifier.WIDGET_TYPE_EDIT_WEBVIEW:
- return WidgetType.WIDGET_TYPE_EDIT_WEBVIEW;
- case TextClassifier.WIDGET_TYPE_CUSTOM_TEXTVIEW:
- return WidgetType.WIDGET_TYPE_CUSTOM_TEXTVIEW;
- case TextClassifier.WIDGET_TYPE_CUSTOM_EDITTEXT:
- return WidgetType.WIDGET_TYPE_CUSTOM_EDITTEXT;
- case TextClassifier.WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW:
- return WidgetType.WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW;
- case TextClassifier.WIDGET_TYPE_NOTIFICATION:
- return WidgetType.WIDGET_TYPE_NOTIFICATION;
- default: // fall out
+ return null;
}
- return WidgetType.WIDGET_TYPE_UNKNOWN;
- }
-
- /** Widget type constants for logging. */
- public static final class WidgetType {
- // Sync these constants with textclassifier_enums.proto.
- public static final int WIDGET_TYPE_UNKNOWN = 0;
- public static final int WIDGET_TYPE_TEXTVIEW = 1;
- public static final int WIDGET_TYPE_EDITTEXT = 2;
- public static final int WIDGET_TYPE_UNSELECTABLE_TEXTVIEW = 3;
- public static final int WIDGET_TYPE_WEBVIEW = 4;
- public static final int WIDGET_TYPE_EDIT_WEBVIEW = 5;
- public static final int WIDGET_TYPE_CUSTOM_TEXTVIEW = 6;
- public static final int WIDGET_TYPE_CUSTOM_EDITTEXT = 7;
- public static final int WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW = 8;
- public static final int WIDGET_TYPE_NOTIFICATION = 9;
-
- private WidgetType() {}
+ return eventContext.getPackageName();
}
}
diff --git a/java/src/com/android/textclassifier/common/statsd/WidgetTypeConverter.java b/java/src/com/android/textclassifier/common/statsd/WidgetTypeConverter.java
new file mode 100644
index 0000000..13c04d1
--- /dev/null
+++ b/java/src/com/android/textclassifier/common/statsd/WidgetTypeConverter.java
@@ -0,0 +1,57 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.statsd;
+
+import android.view.textclassifier.TextClassifier;
+
+/** Converts TextClassifier's WidgetTypes to enum values that are logged to server. */
+final class WidgetTypeConverter {
+ public static int toLoggingValue(String widgetType) {
+ switch (widgetType) {
+ case TextClassifier.WIDGET_TYPE_UNKNOWN:
+ return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_UNKNOWN;
+ case TextClassifier.WIDGET_TYPE_TEXTVIEW:
+ return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_TEXTVIEW;
+ case TextClassifier.WIDGET_TYPE_EDITTEXT:
+ return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_EDITTEXT;
+ case TextClassifier.WIDGET_TYPE_UNSELECTABLE_TEXTVIEW:
+ return TextClassifierStatsLog
+ .TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_UNSELECTABLE_TEXTVIEW;
+ case TextClassifier.WIDGET_TYPE_WEBVIEW:
+ return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_WEBVIEW;
+ case TextClassifier.WIDGET_TYPE_EDIT_WEBVIEW:
+ return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_EDIT_WEBVIEW;
+ case TextClassifier.WIDGET_TYPE_CUSTOM_TEXTVIEW:
+ return TextClassifierStatsLog
+ .TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_CUSTOM_TEXTVIEW;
+ case TextClassifier.WIDGET_TYPE_CUSTOM_EDITTEXT:
+ return TextClassifierStatsLog
+ .TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_CUSTOM_EDITTEXT;
+ case TextClassifier.WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW:
+ return TextClassifierStatsLog
+ .TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW;
+ case TextClassifier.WIDGET_TYPE_NOTIFICATION:
+ return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_NOTIFICATION;
+ case "clipboard": // TODO(tonymak) Replace it with WIDGET_TYPE_CLIPBOARD once S SDK is dropped
+ return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_CLIPBOARD;
+ default: // fall out
+ }
+ return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_UNKNOWN;
+ }
+
+ private WidgetTypeConverter() {}
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
index 07e03ab..81aa832 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
@@ -94,7 +94,7 @@ public class TextClassifierImplTest {
.setDefaultLocales(LOCALES)
.build();
- TextSelection selection = classifier.suggestSelection(request);
+ TextSelection selection = classifier.suggestSelection(null, null, request);
assertThat(
selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_EMAIL));
}
@@ -113,7 +113,7 @@ public class TextClassifierImplTest {
.setDefaultLocales(LOCALES)
.build();
- TextSelection selection = classifier.suggestSelection(request);
+ TextSelection selection = classifier.suggestSelection(null, null, request);
assertThat(selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_URL));
}
@@ -128,7 +128,7 @@ public class TextClassifierImplTest {
.setDefaultLocales(LOCALES)
.build();
- TextSelection selection = classifier.suggestSelection(request);
+ TextSelection selection = classifier.suggestSelection(null, null, request);
assertThat(selection, isTextSelection(startIndex, endIndex, NO_TYPE));
}
@@ -143,7 +143,8 @@ public class TextClassifierImplTest {
.setDefaultLocales(LOCALES)
.build();
- TextClassification classification = classifier.classifyText(request);
+ TextClassification classification =
+ classifier.classifyText(/* sessionId= */ null, null, request);
assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_EMAIL));
}
@@ -158,7 +159,7 @@ public class TextClassifierImplTest {
.setDefaultLocales(LOCALES)
.build();
- TextClassification classification = classifier.classifyText(request);
+ TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW));
}
@@ -171,7 +172,7 @@ public class TextClassifierImplTest {
.setDefaultLocales(LOCALES)
.build();
- TextClassification classification = classifier.classifyText(request);
+ TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(text, TextClassifier.TYPE_ADDRESS));
}
@@ -186,7 +187,7 @@ public class TextClassifierImplTest {
.setDefaultLocales(LOCALES)
.build();
- TextClassification classification = classifier.classifyText(request);
+ TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW));
}
@@ -202,7 +203,7 @@ public class TextClassifierImplTest {
.setDefaultLocales(LOCALES)
.build();
- TextClassification classification = classifier.classifyText(request);
+ TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE));
Bundle extras = classification.getExtras();
List<Bundle> entities = ExtrasUtils.getEntities(extras);
@@ -223,7 +224,7 @@ public class TextClassifierImplTest {
.setDefaultLocales(LOCALES)
.build();
- TextClassification classification = classifier.classifyText(request);
+ TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE_TIME));
}
@@ -237,7 +238,7 @@ public class TextClassifierImplTest {
.setDefaultLocales(LOCALES)
.build();
- TextClassification classification = classifier.classifyText(request);
+ TextClassification classification = classifier.classifyText(null, null, request);
RemoteAction translateAction = classification.getActions().get(0);
assertEquals(1, classification.getActions().size());
assertEquals("Translate", translateAction.getTitle().toString());
@@ -261,7 +262,7 @@ public class TextClassifierImplTest {
String text = "The number is +12122537077. See you tonight!";
TextLinks.Request request = new TextLinks.Request.Builder(text).build();
assertThat(
- classifier.generateLinks(request),
+ classifier.generateLinks(null, null, request),
isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE));
}
@@ -277,7 +278,7 @@ public class TextClassifierImplTest {
.setDefaultLocales(LOCALES)
.build();
assertThat(
- classifier.generateLinks(request),
+ classifier.generateLinks(null, null, request),
not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL)));
}
@@ -291,7 +292,7 @@ public class TextClassifierImplTest {
.setDefaultLocales(LOCALES)
.build();
assertThat(
- classifier.generateLinks(request),
+ classifier.generateLinks(null, null, request),
isTextLinksContaining(
text, "1600 Amphitheater Parkway, Mountain View, CA", TextClassifier.TYPE_ADDRESS));
}
@@ -308,7 +309,7 @@ public class TextClassifierImplTest {
.setDefaultLocales(LOCALES)
.build();
assertThat(
- classifier.generateLinks(request),
+ classifier.generateLinks(null, null, request),
not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL)));
}
@@ -317,7 +318,7 @@ public class TextClassifierImplTest {
char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength()];
Arrays.fill(manySpaces, ' ');
TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build();
- TextLinks links = classifier.generateLinks(request);
+ TextLinks links = classifier.generateLinks(null, null, request);
assertTrue(links.getLinks().isEmpty());
}
@@ -327,7 +328,7 @@ public class TextClassifierImplTest {
TextLinks.Request request = new TextLinks.Request.Builder(url).build();
assertEquals(
TextLinks.STATUS_UNSUPPORTED_CHARACTER,
- classifier.generateLinks(request).apply(url, 0, null));
+ classifier.generateLinks(null, null, request).apply(url, 0, null));
}
@Test
@@ -335,7 +336,8 @@ public class TextClassifierImplTest {
char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength() + 1];
Arrays.fill(manySpaces, ' ');
TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build();
- expectThrows(IllegalArgumentException.class, () -> classifier.generateLinks(request));
+ expectThrows(
+ IllegalArgumentException.class, () -> classifier.generateLinks(null, null, request));
}
@Test
@@ -345,7 +347,7 @@ public class TextClassifierImplTest {
ExtrasUtils.putIsSerializedEntityDataEnabled(extras, true);
TextLinks.Request request = new TextLinks.Request.Builder(text).setExtras(extras).build();
- TextLinks textLinks = classifier.generateLinks(request);
+ TextLinks textLinks = classifier.generateLinks(null, null, request);
assertThat(textLinks.getLinks()).hasSize(1);
TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
@@ -360,7 +362,7 @@ public class TextClassifierImplTest {
String text = "The number is +12122537077.";
TextLinks.Request request = new TextLinks.Request.Builder(text).build();
- TextLinks textLinks = classifier.generateLinks(request);
+ TextLinks textLinks = classifier.generateLinks(null, null, request);
assertThat(textLinks.getLinks()).hasSize(1);
TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
@@ -372,7 +374,7 @@ public class TextClassifierImplTest {
public void testDetectLanguage() throws IOException {
String text = "This is English text";
TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
- TextLanguage textLanguage = classifier.detectLanguage(request);
+ TextLanguage textLanguage = classifier.detectLanguage(null, null, request);
assertThat(textLanguage, isTextLanguage("en"));
}
@@ -380,7 +382,7 @@ public class TextClassifierImplTest {
public void testDetectLanguage_japanese() throws IOException {
String text = "これは日本語のテキストです";
TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
- TextLanguage textLanguage = classifier.detectLanguage(request);
+ TextLanguage textLanguage = classifier.detectLanguage(null, null, request);
assertThat(textLanguage, isTextLanguage("ja"));
}
@@ -401,7 +403,8 @@ public class TextClassifierImplTest {
.setTypeConfig(typeConfig)
.build();
- ConversationActions conversationActions = classifier.suggestConversationActions(request);
+ ConversationActions conversationActions =
+ classifier.suggestConversationActions(null, null, request);
assertThat(conversationActions.getConversationActions()).hasSize(1);
ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_TEXT_REPLY);
@@ -424,7 +427,8 @@ public class TextClassifierImplTest {
.setTypeConfig(typeConfig)
.build();
- ConversationActions conversationActions = classifier.suggestConversationActions(request);
+ ConversationActions conversationActions =
+ classifier.suggestConversationActions(null, null, request);
assertTrue(conversationActions.getConversationActions().size() > 1);
for (ConversationAction conversationAction : conversationActions.getConversationActions()) {
assertThat(conversationAction, isConversationAction(ConversationAction.TYPE_TEXT_REPLY));
@@ -448,7 +452,8 @@ public class TextClassifierImplTest {
.setTypeConfig(typeConfig)
.build();
- ConversationActions conversationActions = classifier.suggestConversationActions(request);
+ ConversationActions conversationActions =
+ classifier.suggestConversationActions(null, null, request);
assertThat(conversationActions.getConversationActions()).hasSize(1);
ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL);
@@ -475,7 +480,8 @@ public class TextClassifierImplTest {
.setTypeConfig(typeConfig)
.build();
- ConversationActions conversationActions = classifier.suggestConversationActions(request);
+ ConversationActions conversationActions =
+ classifier.suggestConversationActions(null, null, request);
assertThat(conversationActions.getConversationActions()).hasSize(1);
ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
assertThat(conversationAction.getType()).isEqualTo(TYPE_COPY);
@@ -497,7 +503,8 @@ public class TextClassifierImplTest {
.setMaxSuggestions(3)
.build();
- ConversationActions conversationActions = classifier.suggestConversationActions(request);
+ ConversationActions conversationActions =
+ classifier.suggestConversationActions(null, null, request);
assertThat(conversationActions.getConversationActions()).isEmpty();
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java
index 03041e1..4d5ca4a 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java
@@ -501,10 +501,10 @@ public final class ModelFileManagerTest {
ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);
assertThat(listedModels).hasSize(2);
- assertThat(listedModels.get(0).absolutePath).isEqualTo(modelFile1.getAbsolutePath());
assertThat(listedModels.get(0).isAsset).isFalse();
- assertThat(listedModels.get(1).absolutePath).isEqualTo(modelFile2.getAbsolutePath());
assertThat(listedModels.get(1).isAsset).isFalse();
+ assertThat(ImmutableList.of(listedModels.get(0).absolutePath, listedModels.get(1).absolutePath))
+ .containsExactly(modelFile1.getAbsolutePath(), modelFile2.getAbsolutePath());
}
@Test
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/GenerateLinksLoggerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/GenerateLinksLoggerTest.java
index 6c66dd5..e215b15 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/GenerateLinksLoggerTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/GenerateLinksLoggerTest.java
@@ -18,8 +18,12 @@ package com.android.textclassifier.common.statsd;
import static com.google.common.truth.Truth.assertThat;
+import android.os.Binder;
+import android.os.Parcel;
import android.stats.textclassifier.EventType;
import android.stats.textclassifier.WidgetType;
+import android.view.textclassifier.TextClassificationContext;
+import android.view.textclassifier.TextClassificationSessionId;
import android.view.textclassifier.TextClassifier;
import android.view.textclassifier.TextLinks;
import androidx.test.core.app.ApplicationProvider;
@@ -55,6 +59,11 @@ public class GenerateLinksLoggerTest {
new ModelInfo(1, ImmutableList.of(Locale.ENGLISH));
private static final ModelInfo LANGID_MODEL =
new ModelInfo(2, ImmutableList.of(Locale.forLanguageTag("*")));
+ private static final String SESSION_ID = "123456";
+ private static final String WIDGET_TYPE = TextClassifier.WIDGET_TYPE_WEBVIEW;
+ private static final WidgetType WIDGET_TYPE_ENUM = WidgetType.WIDGET_TYPE_WEBVIEW;
+ private final TextClassificationContext textClassificationContext =
+ new TextClassificationContext.Builder(PACKAGE_NAME, WIDGET_TYPE).build();
@Before
public void setup() throws Exception {
@@ -83,11 +92,11 @@ public class GenerateLinksLoggerTest {
new TextLinks.Builder(testText)
.addLink(phoneOffset, phoneOffset + phoneText.length(), phoneEntityScores)
.build();
- String uuid = "uuid";
- GenerateLinksLogger generateLinksLogger =
- new GenerateLinksLogger(/* sampleRate= */ 1, () -> uuid);
+ GenerateLinksLogger generateLinksLogger = new GenerateLinksLogger(/* sampleRate= */ 1);
generateLinksLogger.logGenerateLinks(
+ createTextClassificationSessionId(),
+ textClassificationContext,
testText,
links,
PACKAGE_NAME,
@@ -103,10 +112,10 @@ public class GenerateLinksLoggerTest {
assertThat(loggedEvents).hasSize(2);
TextLinkifyEvent summaryEvent =
AtomsProto.TextLinkifyEvent.newBuilder()
- .setSessionId(uuid)
+ .setSessionId(SESSION_ID)
.setEventIndex(0)
.setModelName("en_v1")
- .setWidgetType(WidgetType.WIDGET_TYPE_UNKNOWN)
+ .setWidgetType(WIDGET_TYPE_ENUM)
.setEventType(EventType.LINKS_GENERATED)
.setPackageName(PACKAGE_NAME)
.setEntityType("")
@@ -118,10 +127,10 @@ public class GenerateLinksLoggerTest {
.build();
TextLinkifyEvent phoneEvent =
AtomsProto.TextLinkifyEvent.newBuilder()
- .setSessionId(uuid)
+ .setSessionId(SESSION_ID)
.setEventIndex(0)
.setModelName("en_v1")
- .setWidgetType(WidgetType.WIDGET_TYPE_UNKNOWN)
+ .setWidgetType(WIDGET_TYPE_ENUM)
.setEventType(EventType.LINKS_GENERATED)
.setPackageName(PACKAGE_NAME)
.setEntityType(TextClassifier.TYPE_PHONE)
@@ -148,11 +157,11 @@ public class GenerateLinksLoggerTest {
.addLink(phoneOffset, phoneOffset + phoneText.length(), phoneEntityScores)
.addLink(addressOffset, addressOffset + addressText.length(), addressEntityScores)
.build();
- String uuid = "uuid";
- GenerateLinksLogger generateLinksLogger =
- new GenerateLinksLogger(/* sampleRate= */ 1, () -> uuid);
+ GenerateLinksLogger generateLinksLogger = new GenerateLinksLogger(/* sampleRate= */ 1);
generateLinksLogger.logGenerateLinks(
+ createTextClassificationSessionId(),
+ textClassificationContext,
testText,
links,
PACKAGE_NAME,
@@ -182,4 +191,13 @@ public class GenerateLinksLoggerTest {
assertThat(phoneEvent.getNumLinks()).isEqualTo(1);
assertThat(phoneEvent.getLinkedTextLength()).isEqualTo(phoneText.length());
}
+
+ private static TextClassificationSessionId createTextClassificationSessionId() {
+ // A hack to create TextClassificationSessionId because its constructor is @hide.
+ Parcel parcel = Parcel.obtain();
+ parcel.writeString(SESSION_ID);
+ parcel.writeStrongBinder(new Binder());
+ parcel.setDataPosition(0);
+ return TextClassificationSessionId.CREATOR.createFromParcel(parcel);
+ }
}
diff --git a/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java b/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java
index c5f2112..ad3992d 100644
--- a/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java
+++ b/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java
@@ -172,10 +172,10 @@ public final class ActionsSuggestionsModel implements AutoCloseable {
assetFileDescriptor.getLength());
}
- /** Initializes DeepCLU, passing the given serialized config to it. */
- public void initializeDeepClu(byte[] serializedConfig) {
- if (!nativeInitializeDeepClu(actionsModelPtr, serializedConfig)) {
- throw new IllegalArgumentException("Couldn't initialize DeepCLU");
+ /** Initializes conversation intent detection, passing the given serialized config to it. */
+ public void initializeConversationIntentDetection(byte[] serializedConfig) {
+ if (!nativeInitializeConversationIntentDetection(actionsModelPtr, serializedConfig)) {
+ throw new IllegalArgumentException("Couldn't initialize conversation intent detection");
}
}
@@ -345,7 +345,8 @@ public final class ActionsSuggestionsModel implements AutoCloseable {
private static native long nativeNewActionsModelWithOffset(
int fd, long offset, long size, byte[] preconditionsOverwrite);
- private native boolean nativeInitializeDeepClu(long actionsModelPtr, byte[] serializedConfig);
+ private native boolean nativeInitializeConversationIntentDetection(
+ long actionsModelPtr, byte[] serializedConfig);
private static native String nativeGetLocales(int fd);
diff --git a/native/FlatBufferHeaders.bp b/native/FlatBufferHeaders.bp
index bf4ff44..950eee6 100644
--- a/native/FlatBufferHeaders.bp
+++ b/native/FlatBufferHeaders.bp
@@ -15,20 +15,6 @@
//
genrule {
- name: "libtextclassifier_fbgen_lang_id_common_flatbuffers_model",
- srcs: ["lang_id/common/flatbuffers/model.fbs"],
- out: ["lang_id/common/flatbuffers/model_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network",
- srcs: ["lang_id/common/flatbuffers/embedding-network.fbs"],
- out: ["lang_id/common/flatbuffers/embedding-network_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
name: "libtextclassifier_fbgen_actions_actions_model",
srcs: ["actions/actions_model.fbs"],
out: ["actions/actions_model_generated.h"],
@@ -43,23 +29,23 @@ genrule {
}
genrule {
- name: "libtextclassifier_fbgen_annotator_model",
- srcs: ["annotator/model.fbs"],
- out: ["annotator/model_generated.h"],
+ name: "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network",
+ srcs: ["lang_id/common/flatbuffers/embedding-network.fbs"],
+ out: ["lang_id/common/flatbuffers/embedding-network_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_annotator_person_name_person_name_model",
- srcs: ["annotator/person_name/person_name_model.fbs"],
- out: ["annotator/person_name/person_name_model_generated.h"],
+ name: "libtextclassifier_fbgen_lang_id_common_flatbuffers_model",
+ srcs: ["lang_id/common/flatbuffers/model.fbs"],
+ out: ["lang_id/common/flatbuffers/model_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_annotator_experimental_experimental",
- srcs: ["annotator/experimental/experimental.fbs"],
- out: ["annotator/experimental/experimental_generated.h"],
+ name: "libtextclassifier_fbgen_annotator_person_name_person_name_model",
+ srcs: ["annotator/person_name/person_name_model.fbs"],
+ out: ["annotator/person_name/person_name_model_generated.h"],
defaults: ["fbgen"],
}
@@ -71,37 +57,37 @@ genrule {
}
genrule {
- name: "libtextclassifier_fbgen_annotator_entity-data",
- srcs: ["annotator/entity-data.fbs"],
- out: ["annotator/entity-data_generated.h"],
+ name: "libtextclassifier_fbgen_annotator_experimental_experimental",
+ srcs: ["annotator/experimental/experimental.fbs"],
+ out: ["annotator/experimental/experimental_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_utils_grammar_testing_value",
- srcs: ["utils/grammar/testing/value.fbs"],
- out: ["utils/grammar/testing/value_generated.h"],
+ name: "libtextclassifier_fbgen_annotator_entity-data",
+ srcs: ["annotator/entity-data.fbs"],
+ out: ["annotator/entity-data_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_utils_grammar_semantics_expression",
- srcs: ["utils/grammar/semantics/expression.fbs"],
- out: ["utils/grammar/semantics/expression_generated.h"],
+ name: "libtextclassifier_fbgen_annotator_model",
+ srcs: ["annotator/model.fbs"],
+ out: ["annotator/model_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_utils_grammar_rules",
- srcs: ["utils/grammar/rules.fbs"],
- out: ["utils/grammar/rules_generated.h"],
+ name: "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers",
+ srcs: ["utils/flatbuffers/flatbuffers.fbs"],
+ out: ["utils/flatbuffers/flatbuffers_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_utils_normalization",
- srcs: ["utils/normalization.fbs"],
- out: ["utils/normalization_generated.h"],
+ name: "libtextclassifier_fbgen_utils_tflite_text_encoder_config",
+ srcs: ["utils/tflite/text_encoder_config.fbs"],
+ out: ["utils/tflite/text_encoder_config_generated.h"],
defaults: ["fbgen"],
}
@@ -113,37 +99,51 @@ genrule {
}
genrule {
- name: "libtextclassifier_fbgen_utils_i18n_language-tag",
- srcs: ["utils/i18n/language-tag.fbs"],
- out: ["utils/i18n/language-tag_generated.h"],
+ name: "libtextclassifier_fbgen_utils_zlib_buffer",
+ srcs: ["utils/zlib/buffer.fbs"],
+ out: ["utils/zlib/buffer_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_utils_tflite_text_encoder_config",
- srcs: ["utils/tflite/text_encoder_config.fbs"],
- out: ["utils/tflite/text_encoder_config_generated.h"],
+ name: "libtextclassifier_fbgen_utils_container_bit-vector",
+ srcs: ["utils/container/bit-vector.fbs"],
+ out: ["utils/container/bit-vector_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers",
- srcs: ["utils/flatbuffers/flatbuffers.fbs"],
- out: ["utils/flatbuffers/flatbuffers_generated.h"],
+ name: "libtextclassifier_fbgen_utils_intents_intent-config",
+ srcs: ["utils/intents/intent-config.fbs"],
+ out: ["utils/intents/intent-config_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_utils_container_bit-vector",
- srcs: ["utils/container/bit-vector.fbs"],
- out: ["utils/container/bit-vector_generated.h"],
+ name: "libtextclassifier_fbgen_utils_normalization",
+ srcs: ["utils/normalization.fbs"],
+ out: ["utils/normalization_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_utils_tokenizer",
- srcs: ["utils/tokenizer.fbs"],
- out: ["utils/tokenizer_generated.h"],
+ name: "libtextclassifier_fbgen_utils_grammar_semantics_expression",
+ srcs: ["utils/grammar/semantics/expression.fbs"],
+ out: ["utils/grammar/semantics/expression_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_grammar_rules",
+ srcs: ["utils/grammar/rules.fbs"],
+ out: ["utils/grammar/rules_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_grammar_testing_value",
+ srcs: ["utils/grammar/testing/value.fbs"],
+ out: ["utils/grammar/testing/value_generated.h"],
defaults: ["fbgen"],
}
@@ -155,16 +155,16 @@ genrule {
}
genrule {
- name: "libtextclassifier_fbgen_utils_zlib_buffer",
- srcs: ["utils/zlib/buffer.fbs"],
- out: ["utils/zlib/buffer_generated.h"],
+ name: "libtextclassifier_fbgen_utils_tokenizer",
+ srcs: ["utils/tokenizer.fbs"],
+ out: ["utils/tokenizer_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_utils_intents_intent-config",
- srcs: ["utils/intents/intent-config.fbs"],
- out: ["utils/intents/intent-config_generated.h"],
+ name: "libtextclassifier_fbgen_utils_i18n_language-tag",
+ srcs: ["utils/i18n/language-tag.fbs"],
+ out: ["utils/i18n/language-tag_generated.h"],
defaults: ["fbgen"],
}
@@ -178,50 +178,50 @@ cc_library_headers {
"com.android.extservices",
],
generated_headers: [
- "libtextclassifier_fbgen_lang_id_common_flatbuffers_model",
- "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network",
"libtextclassifier_fbgen_actions_actions_model",
"libtextclassifier_fbgen_actions_actions-entity-data",
- "libtextclassifier_fbgen_annotator_model",
+ "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network",
+ "libtextclassifier_fbgen_lang_id_common_flatbuffers_model",
"libtextclassifier_fbgen_annotator_person_name_person_name_model",
- "libtextclassifier_fbgen_annotator_experimental_experimental",
"libtextclassifier_fbgen_annotator_datetime_datetime",
+ "libtextclassifier_fbgen_annotator_experimental_experimental",
"libtextclassifier_fbgen_annotator_entity-data",
- "libtextclassifier_fbgen_utils_grammar_semantics_expression",
- "libtextclassifier_fbgen_utils_grammar_rules",
- "libtextclassifier_fbgen_utils_normalization",
- "libtextclassifier_fbgen_utils_resources",
- "libtextclassifier_fbgen_utils_i18n_language-tag",
- "libtextclassifier_fbgen_utils_tflite_text_encoder_config",
+ "libtextclassifier_fbgen_annotator_model",
"libtextclassifier_fbgen_utils_flatbuffers_flatbuffers",
- "libtextclassifier_fbgen_utils_container_bit-vector",
- "libtextclassifier_fbgen_utils_tokenizer",
- "libtextclassifier_fbgen_utils_codepoint-range",
+ "libtextclassifier_fbgen_utils_tflite_text_encoder_config",
+ "libtextclassifier_fbgen_utils_resources",
"libtextclassifier_fbgen_utils_zlib_buffer",
+ "libtextclassifier_fbgen_utils_container_bit-vector",
"libtextclassifier_fbgen_utils_intents_intent-config",
+ "libtextclassifier_fbgen_utils_normalization",
+ "libtextclassifier_fbgen_utils_grammar_semantics_expression",
+ "libtextclassifier_fbgen_utils_grammar_rules",
+ "libtextclassifier_fbgen_utils_codepoint-range",
+ "libtextclassifier_fbgen_utils_tokenizer",
+ "libtextclassifier_fbgen_utils_i18n_language-tag",
],
export_generated_headers: [
- "libtextclassifier_fbgen_lang_id_common_flatbuffers_model",
- "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network",
"libtextclassifier_fbgen_actions_actions_model",
"libtextclassifier_fbgen_actions_actions-entity-data",
- "libtextclassifier_fbgen_annotator_model",
+ "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network",
+ "libtextclassifier_fbgen_lang_id_common_flatbuffers_model",
"libtextclassifier_fbgen_annotator_person_name_person_name_model",
- "libtextclassifier_fbgen_annotator_experimental_experimental",
"libtextclassifier_fbgen_annotator_datetime_datetime",
+ "libtextclassifier_fbgen_annotator_experimental_experimental",
"libtextclassifier_fbgen_annotator_entity-data",
- "libtextclassifier_fbgen_utils_grammar_semantics_expression",
- "libtextclassifier_fbgen_utils_grammar_rules",
- "libtextclassifier_fbgen_utils_normalization",
- "libtextclassifier_fbgen_utils_resources",
- "libtextclassifier_fbgen_utils_i18n_language-tag",
- "libtextclassifier_fbgen_utils_tflite_text_encoder_config",
+ "libtextclassifier_fbgen_annotator_model",
"libtextclassifier_fbgen_utils_flatbuffers_flatbuffers",
- "libtextclassifier_fbgen_utils_container_bit-vector",
- "libtextclassifier_fbgen_utils_tokenizer",
- "libtextclassifier_fbgen_utils_codepoint-range",
+ "libtextclassifier_fbgen_utils_tflite_text_encoder_config",
+ "libtextclassifier_fbgen_utils_resources",
"libtextclassifier_fbgen_utils_zlib_buffer",
+ "libtextclassifier_fbgen_utils_container_bit-vector",
"libtextclassifier_fbgen_utils_intents_intent-config",
+ "libtextclassifier_fbgen_utils_normalization",
+ "libtextclassifier_fbgen_utils_grammar_semantics_expression",
+ "libtextclassifier_fbgen_utils_grammar_rules",
+ "libtextclassifier_fbgen_utils_codepoint-range",
+ "libtextclassifier_fbgen_utils_tokenizer",
+ "libtextclassifier_fbgen_utils_i18n_language-tag",
],
}
diff --git a/native/JavaTests.bp b/native/JavaTests.bp
index 4c22ba3..78d5748 100644
--- a/native/JavaTests.bp
+++ b/native/JavaTests.bp
@@ -17,30 +17,30 @@
filegroup {
name: "libtextclassifier_java_test_sources",
srcs: [
- "actions/actions-suggestions_test.cc",
"actions/grammar-actions_test.cc",
+ "actions/actions-suggestions_test.cc",
"annotator/pod_ner/pod-ner-impl_test.cc",
"annotator/datetime/regex-parser_test.cc",
- "annotator/datetime/datetime-grounder_test.cc",
"annotator/datetime/grammar-parser_test.cc",
- "utils/grammar/parsing/lexer_test.cc",
- "utils/regex-match_test.cc",
- "utils/calendar/calendar_test.cc",
+ "annotator/datetime/datetime-grounder_test.cc",
"utils/intents/intent-generator-test-lib.cc",
- "annotator/grammar/grammar-annotator_test.cc",
- "annotator/grammar/test-utils.cc",
+ "utils/calendar/calendar_test.cc",
+ "utils/regex-match_test.cc",
+ "utils/grammar/parsing/lexer_test.cc",
"annotator/number/number_test-include.cc",
"annotator/annotator_test-include.cc",
+ "annotator/grammar/grammar-annotator_test.cc",
+ "annotator/grammar/test-utils.cc",
"utils/utf8/unilib_test-include.cc",
- "utils/grammar/parsing/parser_test.cc",
"utils/grammar/analyzer_test.cc",
"utils/grammar/semantics/composer_test.cc",
- "utils/grammar/semantics/evaluators/merge-values-eval_test.cc",
- "utils/grammar/semantics/evaluators/constituent-eval_test.cc",
- "utils/grammar/semantics/evaluators/parse-number-eval_test.cc",
"utils/grammar/semantics/evaluators/arithmetic-eval_test.cc",
- "utils/grammar/semantics/evaluators/span-eval_test.cc",
+ "utils/grammar/semantics/evaluators/merge-values-eval_test.cc",
"utils/grammar/semantics/evaluators/const-eval_test.cc",
"utils/grammar/semantics/evaluators/compose-eval_test.cc",
+ "utils/grammar/semantics/evaluators/span-eval_test.cc",
+ "utils/grammar/semantics/evaluators/parse-number-eval_test.cc",
+ "utils/grammar/semantics/evaluators/constituent-eval_test.cc",
+ "utils/grammar/parsing/parser_test.cc",
],
}
diff --git a/native/actions/actions-entity-data.fbs b/native/actions/actions-entity-data.fbs
index 21584b6..21584b6 100755..100644
--- a/native/actions/actions-entity-data.fbs
+++ b/native/actions/actions-entity-data.fbs
diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc
index f437c8d..830976e 100644
--- a/native/actions/actions-suggestions.cc
+++ b/native/actions/actions-suggestions.cc
@@ -1040,13 +1040,13 @@ bool ActionsSuggestions::SuggestActionsFromModel(
return ReadModelOutput(interpreter->get(), options, response);
}
-Status ActionsSuggestions::SuggestActionsFromDeepClu(
+Status ActionsSuggestions::SuggestActionsFromConversationIntentDetection(
const Conversation& conversation, const ActionSuggestionOptions& options,
std::vector<ActionSuggestion>* actions) const {
- std::vector<ActionSuggestion> deep_clu_actions;
- TC3_ASSIGN_OR_RETURN(deep_clu_actions,
- deep_clu_->SuggestActions(conversation, options));
- for (const auto& action : deep_clu_actions) {
+ TC3_ASSIGN_OR_RETURN(
+ std::vector<ActionSuggestion> new_actions,
+ conversation_intent_detection_->SuggestActions(conversation, options));
+ for (auto& action : new_actions) {
actions->push_back(std::move(action));
}
return Status::OK;
@@ -1381,12 +1381,13 @@ bool ActionsSuggestions::GatherActionsSuggestions(
return true;
}
- if (deep_clu_) {
+ if (conversation_intent_detection_) {
// TODO(zbin): Ensure the deduplication/ranking logic in ranker.cc works.
- auto actions = SuggestActionsFromDeepClu(annotated_conversation, options,
- &response->actions);
+ auto actions = SuggestActionsFromConversationIntentDetection(
+ annotated_conversation, options, &response->actions);
if (!actions.ok()) {
- TC3_LOG(ERROR) << "Could not run DeepCLU: " << actions.error_message();
+ TC3_LOG(ERROR) << "Could not run conversation intent detection: "
+ << actions.error_message();
return false;
}
}
@@ -1475,14 +1476,15 @@ const ActionsModel* ViewActionsModel(const void* buffer, int size) {
return LoadAndVerifyModel(reinterpret_cast<const uint8_t*>(buffer), size);
}
-bool ActionsSuggestions::InitializeDeepClu(
+bool ActionsSuggestions::InitializeConversationIntentDetection(
const std::string& serialized_config) {
- auto deep_clu = std::make_unique<DeepClu>();
- if (!deep_clu->Initialize(serialized_config).ok()) {
- TC3_LOG(ERROR) << "Failed to initialize DeepCLU.";
+ auto conversation_intent_detection =
+ std::make_unique<ConversationIntentDetection>();
+ if (!conversation_intent_detection->Initialize(serialized_config).ok()) {
+ TC3_LOG(ERROR) << "Failed to initialize conversation intent detection.";
return false;
}
- deep_clu_ = std::move(deep_clu);
+ conversation_intent_detection_ = std::move(conversation_intent_detection);
return true;
}
diff --git a/native/actions/actions-suggestions.h b/native/actions/actions-suggestions.h
index 08a3f65..32edc78 100644
--- a/native/actions/actions-suggestions.h
+++ b/native/actions/actions-suggestions.h
@@ -25,7 +25,7 @@
#include <vector>
#include "actions/actions_model_generated.h"
-#include "actions/deep_clu/deep-clu.h"
+#include "actions/conversation_intent_detection/conversation-intent-detection.h"
#include "actions/feature-processor.h"
#include "actions/grammar-actions.h"
#include "actions/ranker.h"
@@ -105,7 +105,8 @@ class ActionsSuggestions {
const Conversation& conversation, const Annotator* annotator,
const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
- bool InitializeDeepClu(const std::string& serialized_config);
+ bool InitializeConversationIntentDetection(
+ const std::string& serialized_config);
const ActionsModel* model() const;
const reflection::Schema* entity_data_schema() const;
@@ -193,7 +194,7 @@ class ActionsSuggestions {
ActionsSuggestionsResponse* response,
std::unique_ptr<tflite::Interpreter>* interpreter) const;
- Status SuggestActionsFromDeepClu(
+ Status SuggestActionsFromConversationIntentDetection(
const Conversation& conversation, const ActionSuggestionOptions& options,
std::vector<ActionSuggestion>* actions) const;
@@ -267,8 +268,9 @@ class ActionsSuggestions {
// Low confidence input ngram classifier.
std::unique_ptr<const SensitiveTopicModelBase> sensitive_model_;
- // DeepCLU model for additional actions.
- std::unique_ptr<const DeepClu> deep_clu_;
+ // Conversation intent detection model for additional actions.
+ std::unique_ptr<const ConversationIntentDetection>
+ conversation_intent_detection_;
};
// Interprets the buffer as a Model flatbuffer and returns it for reading.
diff --git a/native/actions/actions-suggestions_test.cc b/native/actions/actions-suggestions_test.cc
index 3b0ca0f..6e66b32 100644
--- a/native/actions/actions-suggestions_test.cc
+++ b/native/actions/actions-suggestions_test.cc
@@ -61,6 +61,8 @@ constexpr char kMultiTaskSrP13nModelFileName[] =
"actions_suggestions_test.multi_task_sr_p13n.model";
constexpr char kMultiTaskSrEmojiModelFileName[] =
"actions_suggestions_test.multi_task_sr_emoji.model";
+constexpr char kSensitiveTFliteModelFileName[] =
+ "actions_suggestions_test.sensitive_tflite.model";
std::string ReadFile(const std::string& file_name) {
std::ifstream file_stream(file_name);
@@ -1808,5 +1810,19 @@ TEST_F(ActionsSuggestionsTest, SuggestsActionsFromMultiTaskSrEmojiModel) {
EXPECT_EQ(response.actions[1].type, "REPLY_SUGGESTION");
}
+TEST_F(ActionsSuggestionsTest, SuggestsActionsFromSensitiveTfLiteModel) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kSensitiveTFliteModelFileName);
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "I want to kill myself",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{},
+ /*locales=*/"en"}}});
+ EXPECT_EQ(response.actions.size(), 0);
+ EXPECT_TRUE(response.output_filtered_low_confidence);
+}
+
} // namespace
} // namespace libtextclassifier3
diff --git a/native/actions/actions_jni.cc b/native/actions/actions_jni.cc
index 9c71ab6..5981a17 100644
--- a/native/actions/actions_jni.cc
+++ b/native/actions/actions_jni.cc
@@ -567,7 +567,8 @@ TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeGetNativeModelPtr)
reinterpret_cast<ActionsSuggestionsJniContext*>(ptr)->model());
}
-TC3_JNI_METHOD(jboolean, TC3_ACTIONS_CLASS_NAME, nativeInitializeDeepClu)
+TC3_JNI_METHOD(jboolean, TC3_ACTIONS_CLASS_NAME,
+ nativeInitializeConversationIntentDetection)
(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray jserialized_config) {
if (!ptr) {
return false;
@@ -579,6 +580,7 @@ TC3_JNI_METHOD(jboolean, TC3_ACTIONS_CLASS_NAME, nativeInitializeDeepClu)
std::string serialized_config;
TC3_ASSIGN_OR_RETURN_0(
serialized_config, JByteArrayToString(env, jserialized_config),
- TC3_LOG(ERROR) << "Could not convert serialized DeepCLU config.");
- return model->InitializeDeepClu(serialized_config);
+ TC3_LOG(ERROR) << "Could not convert serialized conversation intent "
+ "detection config.");
+ return model->InitializeConversationIntentDetection(serialized_config);
}
diff --git a/native/actions/actions_jni.h b/native/actions/actions_jni.h
index f693101..5265a9c 100644
--- a/native/actions/actions_jni.h
+++ b/native/actions/actions_jni.h
@@ -41,7 +41,8 @@ TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelWithOffset)
(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size,
jbyteArray serialized_preconditions);
-TC3_JNI_METHOD(jboolean, TC3_ACTIONS_CLASS_NAME, nativeInitializeDeepClu)
+TC3_JNI_METHOD(jboolean, TC3_ACTIONS_CLASS_NAME,
+ nativeInitializeConversationIntentDetection)
(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray jserialized_config);
TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
diff --git a/native/actions/actions_model.fbs b/native/actions/actions_model.fbs
index 8c03eeb..8c03eeb 100755..100644
--- a/native/actions/actions_model.fbs
+++ b/native/actions/actions_model.fbs
diff --git a/native/actions/deep_clu/deep-clu-dummy.h b/native/actions/conversation_intent_detection/conversation-intent-detection-dummy.h
index e547c70..66255c5 100644
--- a/native/actions/deep_clu/deep-clu-dummy.h
+++ b/native/actions/conversation_intent_detection/conversation-intent-detection-dummy.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_ACTIONS_DEEP_CLU_DEEP_CLU_DUMMY_H_
-#define LIBTEXTCLASSIFIER_ACTIONS_DEEP_CLU_DEEP_CLU_DUMMY_H_
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_CONVERSATION_INTENT_DETECTION_CONVERSATION_INTENT_DETECTION_DUMMY_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_CONVERSATION_INTENT_DETECTION_CONVERSATION_INTENT_DETECTION_DUMMY_H_
#include <string>
#include <vector>
@@ -26,10 +26,10 @@
namespace libtextclassifier3 {
-// A dummy implementation of DeepCLU.
-class DeepClu {
+// A dummy implementation of conversation intent detection.
+class ConversationIntentDetection {
public:
- DeepClu() {}
+ ConversationIntentDetection() {}
Status Initialize(const std::string& serialized_config) { return Status::OK; }
@@ -42,4 +42,4 @@ class DeepClu {
} // namespace libtextclassifier3
-#endif // LIBTEXTCLASSIFIER_ACTIONS_DEEP_CLU_DEEP_CLU_DUMMY_H_
+#endif // LIBTEXTCLASSIFIER_ACTIONS_CONVERSATION_INTENT_DETECTION_CONVERSATION_INTENT_DETECTION_DUMMY_H_
diff --git a/native/actions/deep_clu/deep-clu.h b/native/actions/conversation_intent_detection/conversation-intent-detection.h
index a10641b..949ceaf 100644
--- a/native/actions/deep_clu/deep-clu.h
+++ b/native/actions/conversation_intent_detection/conversation-intent-detection.h
@@ -14,9 +14,9 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_ACTIONS_DEEP_CLU_DEEP_CLU_H_
-#define LIBTEXTCLASSIFIER_ACTIONS_DEEP_CLU_DEEP_CLU_H_
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_CONVERSATION_INTENT_DETECTION_CONVERSATION_INTENT_DETECTION_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_CONVERSATION_INTENT_DETECTION_CONVERSATION_INTENT_DETECTION_H_
-#include "actions/deep_clu/deep-clu-dummy.h"
+#include "actions/conversation_intent_detection/conversation-intent-detection-dummy.h"
-#endif // LIBTEXTCLASSIFIER_ACTIONS_DEEP_CLU_DEEP_CLU_H_
+#endif // LIBTEXTCLASSIFIER_ACTIONS_CONVERSATION_INTENT_DETECTION_CONVERSATION_INTENT_DETECTION_H_
diff --git a/native/actions/test_data/actions_suggestions_grammar_test.model b/native/actions/test_data/actions_suggestions_grammar_test.model
index 8c46dec..d900928 100644
--- a/native/actions/test_data/actions_suggestions_grammar_test.model
+++ b/native/actions/test_data/actions_suggestions_grammar_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.model b/native/actions/test_data/actions_suggestions_test.model
index 5e43fd0..aa62c0a 100644
--- a/native/actions/test_data/actions_suggestions_test.model
+++ b/native/actions/test_data/actions_suggestions_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
index b1912f5..50918e5 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
index 772c1bf..b43e6d7 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
index eb75802..6a71da3 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
index 889b7e0..72f4d9d 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
index 3da6e29..a6c8118 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
new file mode 100644
index 0000000..4a120b2
--- /dev/null
+++ b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
Binary files differ
diff --git a/native/actions/test_data/en_sensitive_topic_2019117.tflite b/native/actions/test_data/en_sensitive_topic_2019117.tflite
new file mode 100644
index 0000000..48edfbd
--- /dev/null
+++ b/native/actions/test_data/en_sensitive_topic_2019117.tflite
Binary files differ
diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc
index 78f513e..e296a64 100644
--- a/native/annotator/annotator.cc
+++ b/native/annotator/annotator.cc
@@ -2739,14 +2739,19 @@ bool Annotator::ParseAndFillInMoneyAmount(
std::string quantity;
GetMoneyQuantityFromCapturingGroup(match, config, context_unicode,
&quantity, &quantity_exponent);
- if ((quantity_exponent > 0 && quantity_exponent < 9) ||
- (quantity_exponent == 9 && data->money->amount_whole_part <= 2)) {
- data->money->amount_whole_part =
+ if (quantity_exponent > 0 && quantity_exponent <= 9) {
+ const double amount_whole_part =
data->money->amount_whole_part * pow(10, quantity_exponent) +
data->money->nanos / pow(10, 9 - quantity_exponent);
- data->money->nanos = data->money->nanos %
- static_cast<int>(pow(10, 9 - quantity_exponent)) *
- pow(10, quantity_exponent);
+ // TODO(jacekj): Change type of `data->money->amount_whole_part` to int64
+ // (and `std::numeric_limits<int>::max()` to
+ // `std::numeric_limits<int64>::max()`).
+ if (amount_whole_part < std::numeric_limits<int>::max()) {
+ data->money->amount_whole_part = amount_whole_part;
+ data->money->nanos = data->money->nanos %
+ static_cast<int>(pow(10, 9 - quantity_exponent)) *
+ pow(10, quantity_exponent);
+ }
}
if (quantity_exponent > 0) {
data->money->unnormalized_amount = strings::JoinStrings(
diff --git a/native/annotator/datetime/datetime-grounder.cc b/native/annotator/datetime/datetime-grounder.cc
index aacb917..7d5f440 100644
--- a/native/annotator/datetime/datetime-grounder.cc
+++ b/native/annotator/datetime/datetime-grounder.cc
@@ -58,7 +58,8 @@ const std::unordered_map<int, int> kMonthDefaultLastDayMap(
bool IsValidDatetime(const AbsoluteDateTime* absolute_datetime) {
// Sanity Checks.
if (absolute_datetime->minute() > 59 || absolute_datetime->second() > 59 ||
- absolute_datetime->hour() > 23 || absolute_datetime->month() > 12) {
+ absolute_datetime->hour() > 23 || absolute_datetime->month() > 12 ||
+ absolute_datetime->month() == 0) {
return false;
}
if (absolute_datetime->day() >= 0) {
diff --git a/native/annotator/datetime/datetime-grounder_test.cc b/native/annotator/datetime/datetime-grounder_test.cc
index e55bbfc..121aae8 100644
--- a/native/annotator/datetime/datetime-grounder_test.cc
+++ b/native/annotator/datetime/datetime-grounder_test.cc
@@ -261,6 +261,12 @@ TEST_F(DatetimeGrounderTest, InValidUngroundedDatetime) {
/*hour=*/11, /*minute=*/59, /*second=*/99,
grammar::datetime::Meridiem_AM)
.get());
+
+ VerifyInValidUngroundedDatetime(
+ BuildAbsoluteDatetime(/*year=*/2000, /*month=*/00, /*day=*/28,
+ /*hour=*/11, /*minute=*/59, /*second=*/99,
+ grammar::datetime::Meridiem_AM)
+ .get());
}
TEST_F(DatetimeGrounderTest, ValidUngroundedDatetime) {
diff --git a/native/annotator/datetime/datetime.fbs b/native/annotator/datetime/datetime.fbs
index 9a96bae..9a96bae 100755..100644
--- a/native/annotator/datetime/datetime.fbs
+++ b/native/annotator/datetime/datetime.fbs
diff --git a/native/annotator/entity-data.fbs b/native/annotator/entity-data.fbs
index f82eb44..f82eb44 100755..100644
--- a/native/annotator/entity-data.fbs
+++ b/native/annotator/entity-data.fbs
diff --git a/native/annotator/experimental/experimental.fbs b/native/annotator/experimental/experimental.fbs
index 6e15d04..6e15d04 100755..100644
--- a/native/annotator/experimental/experimental.fbs
+++ b/native/annotator/experimental/experimental.fbs
diff --git a/native/annotator/model.fbs b/native/annotator/model.fbs
index 57187f5..57187f5 100755..100644
--- a/native/annotator/model.fbs
+++ b/native/annotator/model.fbs
diff --git a/native/annotator/person_name/person_name_model.fbs b/native/annotator/person_name/person_name_model.fbs
index b15543f..b15543f 100755..100644
--- a/native/annotator/person_name/person_name_model.fbs
+++ b/native/annotator/person_name/person_name_model.fbs
diff --git a/native/annotator/pod_ner/pod-ner-impl.cc b/native/annotator/pod_ner/pod-ner-impl.cc
index 2d7f0a2..666b7c7 100644
--- a/native/annotator/pod_ner/pod-ner-impl.cc
+++ b/native/annotator/pod_ner/pod-ner-impl.cc
@@ -126,7 +126,7 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter(
::tflite::BuiltinOperator_EXPAND_DIMS,
::tflite::ops::builtin::Register_EXPAND_DIMS());
mutable_resolver->AddCustom(
- "LayerNorm", ::tflite::ops::custom::Register_LAYER_NORM());
+ "LayerNorm", ::seq_flow_lite::ops::custom::Register_LAYER_NORM());
});
std::unique_ptr<tflite::Interpreter> tflite_interpreter;
@@ -253,7 +253,8 @@ std::vector<LabelT> PodNerAnnotator::ReadResultsFromInterpreter(
float max_prob = 0.0f;
int max_index = 0;
for (int cindex = 0; cindex < output->dims->data[2]; ++cindex) {
- const float probability = ::tflite::PodDequantize(*output, index++);
+ const float probability =
+ ::seq_flow_lite::PodDequantize(*output, index++);
if (probability > max_prob) {
max_prob = probability;
max_index = cindex;
diff --git a/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.cc b/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.cc
index df2b55f..e28b04d 100644
--- a/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.cc
+++ b/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.cc
@@ -36,7 +36,7 @@ limitations under the License.
#include "tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
#include "tensorflow/lite/kernels/kernel_util.h"
-namespace tflite {
+namespace seq_flow_lite {
namespace ops {
namespace custom {
@@ -344,4 +344,4 @@ TfLiteRegistration* Register_LAYER_NORM() {
} // namespace custom
} // namespace ops
-} // namespace tflite
+} // namespace seq_flow_lite
diff --git a/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h b/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h
index 017d21a..6d84ca4 100644
--- a/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h
+++ b/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h
@@ -33,7 +33,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/register.h"
-namespace tflite {
+namespace seq_flow_lite {
namespace ops {
namespace custom {
@@ -41,6 +41,6 @@ TfLiteRegistration* Register_LAYER_NORM();
} // namespace custom
} // namespace ops
-} // namespace tflite
+} // namespace seq_flow_lite
#endif // LEARNING_EXPANDER_POD_DEEP_POD_TFLITE_HANDLERS_LAYER_NORM_H_
diff --git a/native/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h b/native/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h
index f8e3836..7f2db41 100644
--- a/native/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h
+++ b/native/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h
@@ -36,7 +36,7 @@ limitations under the License.
#include "tensorflow/lite/context.h"
-namespace tflite {
+namespace seq_flow_lite {
// Returns the original (dequantized) value of 8bit value.
inline float PodDequantizeValue(const TfLiteTensor& tensor, uint8_t value) {
@@ -64,6 +64,6 @@ inline uint8_t PodQuantize(float value, int32_t zero_point,
return static_cast<uint8_t>(std::max(std::min(255, integer_value), 0));
}
-} // namespace tflite
+} // namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_QUANTIZATION_UTIL_H_
diff --git a/native/utils/codepoint-range.fbs b/native/utils/codepoint-range.fbs
index 135ce30..135ce30 100755..100644
--- a/native/utils/codepoint-range.fbs
+++ b/native/utils/codepoint-range.fbs
diff --git a/native/utils/container/bit-vector.fbs b/native/utils/container/bit-vector.fbs
index d117ee5..d117ee5 100755..100644
--- a/native/utils/container/bit-vector.fbs
+++ b/native/utils/container/bit-vector.fbs
diff --git a/native/utils/flatbuffers/flatbuffers.fbs b/native/utils/flatbuffers/flatbuffers.fbs
index 155e8f8..155e8f8 100755..100644
--- a/native/utils/flatbuffers/flatbuffers.fbs
+++ b/native/utils/flatbuffers/flatbuffers.fbs
diff --git a/native/utils/grammar/rules.fbs b/native/utils/grammar/rules.fbs
index bc0136c..bc0136c 100755..100644
--- a/native/utils/grammar/rules.fbs
+++ b/native/utils/grammar/rules.fbs
diff --git a/native/utils/grammar/semantics/expression.fbs b/native/utils/grammar/semantics/expression.fbs
index 5397407..5397407 100755..100644
--- a/native/utils/grammar/semantics/expression.fbs
+++ b/native/utils/grammar/semantics/expression.fbs
diff --git a/native/utils/grammar/testing/value.fbs b/native/utils/grammar/testing/value.fbs
index 0429491..0429491 100755..100644
--- a/native/utils/grammar/testing/value.fbs
+++ b/native/utils/grammar/testing/value.fbs
diff --git a/native/utils/i18n/language-tag.fbs b/native/utils/i18n/language-tag.fbs
index a2e1077..a2e1077 100755..100644
--- a/native/utils/i18n/language-tag.fbs
+++ b/native/utils/i18n/language-tag.fbs
diff --git a/native/utils/intents/intent-config.fbs b/native/utils/intents/intent-config.fbs
index 672eb9d..672eb9d 100755..100644
--- a/native/utils/intents/intent-config.fbs
+++ b/native/utils/intents/intent-config.fbs
diff --git a/native/utils/normalization.fbs b/native/utils/normalization.fbs
index 4d43f10..4d43f10 100755..100644
--- a/native/utils/normalization.fbs
+++ b/native/utils/normalization.fbs
diff --git a/native/utils/resources.fbs b/native/utils/resources.fbs
index b4d9b83..b4d9b83 100755..100644
--- a/native/utils/resources.fbs
+++ b/native/utils/resources.fbs
diff --git a/native/utils/tflite-model-executor.cc b/native/utils/tflite-model-executor.cc
index 104dedc..36db3e9 100644
--- a/native/utils/tflite-model-executor.cc
+++ b/native/utils/tflite-model-executor.cc
@@ -39,6 +39,7 @@ TfLiteRegistration* Register_REDUCE_ANY();
TfLiteRegistration* Register_SOFTMAX();
TfLiteRegistration* Register_GATHER();
TfLiteRegistration* Register_GATHER_ND();
+TfLiteRegistration* Register_IF();
TfLiteRegistration* Register_ROUND();
TfLiteRegistration* Register_ZEROS_LIKE();
TfLiteRegistration* Register_TRANSPOSE();
@@ -145,8 +146,10 @@ void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
resolver->AddBuiltin(::tflite::BuiltinOperator_GATHER_ND,
::tflite::ops::builtin::Register_GATHER_ND(),
/*version=*/2);
- resolver->AddBuiltin(::tflite::BuiltinOperator_ROUND,
- ::tflite::ops::builtin::Register_ROUND());
+ resolver->AddBuiltin(::tflite::BuiltinOperator_IF,
+ ::tflite::ops::builtin::Register_IF()),
+ resolver->AddBuiltin(::tflite::BuiltinOperator_ROUND,
+ ::tflite::ops::builtin::Register_ROUND());
resolver->AddBuiltin(::tflite::BuiltinOperator_ZEROS_LIKE,
::tflite::ops::builtin::Register_ZEROS_LIKE());
resolver->AddBuiltin(tflite::BuiltinOperator_TRANSPOSE,
diff --git a/native/utils/tflite/blacklist_base.cc b/native/utils/tflite/blacklist_base.cc
index 8dcfacb..214283b 100644
--- a/native/utils/tflite/blacklist_base.cc
+++ b/native/utils/tflite/blacklist_base.cc
@@ -77,9 +77,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
}
} else if (output_categories->type == kTfLiteUInt8) {
- const uint8_t one = PodQuantize(1.0, output_categories->params.zero_point,
- 1.0 / output_categories->params.scale);
- const uint8_t zero = PodQuantize(0.0, output_categories->params.zero_point,
+ const uint8_t one =
+ ::seq_flow_lite::PodQuantize(1.0, output_categories->params.zero_point,
+ 1.0 / output_categories->params.scale);
+ const uint8_t zero =
+ ::seq_flow_lite::PodQuantize(0.0, output_categories->params.zero_point,
1.0 / output_categories->params.scale);
for (int i = 0; i < input_size; i++) {
absl::flat_hash_set<int> categories = op->GetCategories(i);
diff --git a/native/utils/tflite/string_projection_base.cc b/native/utils/tflite/string_projection_base.cc
index ac0d6eb..d185f52 100644
--- a/native/utils/tflite/string_projection_base.cc
+++ b/native/utils/tflite/string_projection_base.cc
@@ -125,7 +125,8 @@ void StringProjectionOpBase::DenseLshProjection(
float seed = hash_function_[hash_bit].AsFloat();
float bit = running_sign_bit(input, weight, seed, key.get());
output->data.uint8[batch * num_hash_ * num_bits_ + hash_bit] =
- PodQuantize(bit, output->params.zero_point, inverse_scale);
+ seq_flow_lite::PodQuantize(bit, output->params.zero_point,
+ inverse_scale);
}
}
}
diff --git a/native/utils/tokenizer.fbs b/native/utils/tokenizer.fbs
index c0a3919..c0a3919 100755..100644
--- a/native/utils/tokenizer.fbs
+++ b/native/utils/tokenizer.fbs
diff --git a/native/utils/zlib/buffer.fbs b/native/utils/zlib/buffer.fbs
index 60da23e..60da23e 100755..100644
--- a/native/utils/zlib/buffer.fbs
+++ b/native/utils/zlib/buffer.fbs
diff --git a/notification/Android.bp b/notification/Android.bp
index 277985b..782d5cb 100644
--- a/notification/Android.bp
+++ b/notification/Android.bp
@@ -28,7 +28,7 @@ android_library {
name: "TextClassifierNotificationLib",
static_libs: ["TextClassifierNotificationLibNoManifest"],
sdk_version: "system_current",
- min_sdk_version: "29",
+ min_sdk_version: "30",
manifest: "AndroidManifest.xml",
}
@@ -41,6 +41,6 @@ android_library {
"guava",
],
sdk_version: "system_current",
- min_sdk_version: "29",
+ min_sdk_version: "30",
manifest: "LibNoManifest_AndroidManifest.xml",
}
diff --git a/notification/AndroidManifest.xml b/notification/AndroidManifest.xml
index 3153d1d..5a98ea3 100644
--- a/notification/AndroidManifest.xml
+++ b/notification/AndroidManifest.xml
@@ -1,7 +1,7 @@
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.android.textclassifier.notification">
- <uses-sdk android:minSdkVersion="29" />
+ <uses-sdk android:minSdkVersion="30" />
<application>
<activity
@@ -10,4 +10,4 @@
android:theme="@android:style/Theme.NoDisplay" />
</application>
-</manifest> \ No newline at end of file
+</manifest>
diff --git a/notification/LibNoManifest_AndroidManifest.xml b/notification/LibNoManifest_AndroidManifest.xml
index b9ebf7d..06e8da4 100644
--- a/notification/LibNoManifest_AndroidManifest.xml
+++ b/notification/LibNoManifest_AndroidManifest.xml
@@ -25,6 +25,6 @@
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.android.textclassifier.notification">
- <uses-sdk android:minSdkVersion="29" android:targetSdkVersion="29"/>
+ <uses-sdk android:minSdkVersion="30" android:targetSdkVersion="30"/>
</manifest>
diff --git a/notification/lint-baseline.xml b/notification/lint-baseline.xml
deleted file mode 100644
index f2530d7..0000000
--- a/notification/lint-baseline.xml
+++ /dev/null
@@ -1,37 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<issues format="5" by="lint 4.1.0" client="cli" variant="all" version="4.1.0">
-
- <issue
- id="NewApi"
- message="Call requires API level R (current min is 29): `android.app.Notification#getContextualActions`"
- errorLine1=" boolean hasAppGeneratedContextualActions = !notification.getContextualActions().isEmpty();"
- errorLine2=" ~~~~~~~~~~~~~~~~~~~~">
- <location
- file="external/libtextclassifier/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java"
- line="246"
- column="62"/>
- </issue>
-
- <issue
- id="NewApi"
- message="Call requires API level R (current min is 29): `android.app.Notification#findRemoteInputActionPair`"
- errorLine1=" notification.findRemoteInputActionPair(/* requiresFreeform */ true);"
- errorLine2=" ~~~~~~~~~~~~~~~~~~~~~~~~~">
- <location
- file="external/libtextclassifier/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java"
- line="249"
- column="22"/>
- </issue>
-
- <issue
- id="NewApi"
- message="Call requires API level R (current min is 29): `android.app.Notification.MessagingStyle.Message#getMessagesFromBundleArray`"
- errorLine1=" Message.getMessagesFromBundleArray("
- errorLine2=" ~~~~~~~~~~~~~~~~~~~~~~~~~~">
- <location
- file="external/libtextclassifier/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java"
- line="440"
- column="17"/>
- </issue>
-
-</issues>
diff --git a/notification/tests/AndroidManifest.xml b/notification/tests/AndroidManifest.xml
index 81308e3..d3da067 100644
--- a/notification/tests/AndroidManifest.xml
+++ b/notification/tests/AndroidManifest.xml
@@ -2,8 +2,8 @@
package="com.android.textclassifier.notification">
<uses-sdk
- android:minSdkVersion="29"
- android:targetSdkVersion="29" />
+ android:minSdkVersion="30"
+ android:targetSdkVersion="30" />
<application>
<uses-library android:name="android.test.runner"/>