summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2022-05-09 06:00:07 +0000
committerAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2022-05-09 06:00:07 +0000
commit91373f542bb90c4ef1377b9503078d8add07cc34 (patch)
tree4b3ec7a534166d5663d826de85b941d82e74ea9b
parent44144f837942b745b5d1eff5403989b4dfa1e0b6 (diff)
parent775e966e07fb11a55afff5ab93b79128c29a84ac (diff)
downloadlibtextclassifier-android13-frc-extservices-release.tar.gz
Snap for 8558685 from 775e966e07fb11a55afff5ab93b79128c29a84ac to tm-frc-extservices-releaset_frc_ext_330443000android13-frc-extservices-release
Change-Id: Ia6e0cf122e64c64e08e96faa687b1ba9a4332345
-rw-r--r--OWNERS6
-rw-r--r--TEST_MAPPING19
-rw-r--r--java/src/com/android/textclassifier/ExtrasUtils.java4
-rw-r--r--java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java4
-rw-r--r--java/src/com/android/textclassifier/downloader/ModelDownloadManager.java179
-rw-r--r--java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java8
-rw-r--r--java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java10
-rw-r--r--java/src/com/android/textclassifier/downloader/ModelDownloaderService.java2
-rw-r--r--java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java2
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java36
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java2
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java29
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java10
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java171
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java55
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java186
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java108
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/testing/TestingLocaleListOverrideRule.java58
-rw-r--r--native/utils/tokenfree/byte_encoder_test.cc4
19 files changed, 460 insertions, 433 deletions
diff --git a/OWNERS b/OWNERS
index 81cfdb8..46bd5b1 100644
--- a/OWNERS
+++ b/OWNERS
@@ -2,6 +2,6 @@
# Please update this list if you find better candidates.
tonymak@google.com
toki@google.com
-zilka@google.com
-mns@google.com
-jalt@google.com
+licha@google.com
+joannechung@google.com
+lpeter@google.com \ No newline at end of file
diff --git a/TEST_MAPPING b/TEST_MAPPING
index 72e022b..370acd6 100644
--- a/TEST_MAPPING
+++ b/TEST_MAPPING
@@ -21,6 +21,25 @@
"name": "TCSModelDownloaderIntegrationTest"
}
],
+ "hwasan-postsubmit": [
+ {
+ "name": "TextClassifierServiceTest",
+ "options": [
+ {
+ "exclude-annotation": "androidx.test.filters.FlakyTest"
+ }
+ ]
+ },
+ {
+ "name": "libtextclassifier_tests"
+ },
+ {
+ "name": "libtextclassifier_java_tests"
+ },
+ {
+ "name": "TextClassifierNotificationTests"
+ }
+ ],
"mainline-presubmit": [
{
"name": "TextClassifierNotificationTests[com.google.android.extservices.apex]"
diff --git a/java/src/com/android/textclassifier/ExtrasUtils.java b/java/src/com/android/textclassifier/ExtrasUtils.java
index fd64581..bde3898 100644
--- a/java/src/com/android/textclassifier/ExtrasUtils.java
+++ b/java/src/com/android/textclassifier/ExtrasUtils.java
@@ -87,7 +87,9 @@ public final class ExtrasUtils {
return classification.getExtras().getBundle(FOREIGN_LANGUAGE);
}
- /** @see #getTopLanguage(Intent) */
+ /**
+ * @see #getTopLanguage(Intent)
+ */
static void putTopLanguageScores(Bundle container, EntityConfidence languageScores) {
final int maxSize = Math.min(3, languageScores.getEntities().size());
final String[] languages =
diff --git a/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java b/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java
index 1ae79ce..9bdfb5e 100644
--- a/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java
+++ b/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java
@@ -195,7 +195,7 @@ public final class DownloadedModelManagerImpl implements DownloadedModelManager
@Override
public void onDownloadCompleted(
ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload) {
- TcLog.v(TAG, "Start to clean up models and update model lookup cache...");
+ TcLog.d(TAG, "Start to clean up models and update model lookup cache...");
// Step 1: Clean up ManifestEnrollment table
List<ManifestEnrollment> allManifestEnrollments = db.dao().queryAllManifestEnrollments();
List<ManifestEnrollment> manifestEnrollmentsToDelete = new ArrayList<>();
@@ -286,7 +286,7 @@ public final class DownloadedModelManagerImpl implements DownloadedModelManager
// Clear the cache table and rebuild the cache based on ModelView table
private void updateCache() {
synchronized (cacheLock) {
- TcLog.v(TAG, "Updating model lookup cache...");
+ TcLog.d(TAG, "Updating model lookup cache...");
for (String modelType : ModelType.values()) {
modelLookupCache.get(modelType).clear();
}
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java b/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java
index b125f13..af33e21 100644
--- a/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java
@@ -44,6 +44,7 @@ import com.android.textclassifier.utils.IndentingPrintWriter;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Enums;
import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
import com.google.common.hash.Hashing;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
@@ -54,6 +55,7 @@ import java.time.Instant;
import java.util.List;
import java.util.Locale;
import java.util.UUID;
+import java.util.concurrent.Callable;
import javax.annotation.Nullable;
/** Manager to listen to config update and download latest models. */
@@ -64,6 +66,7 @@ public final class ModelDownloadManager {
private final Context appContext;
private final Class<? extends ListenableWorker> modelDownloadWorkerClass;
+ private final Callable<WorkManager> workManagerSupplier;
private final DownloadedModelManager downloadedModelManager;
private final TextClassifierSettings settings;
private final ListeningExecutorService executorService;
@@ -84,6 +87,7 @@ public final class ModelDownloadManager {
this(
appContext,
ModelDownloadWorker.class,
+ () -> WorkManager.getInstance(appContext),
DownloadedModelManagerImpl.getInstance(appContext),
settings,
executorService);
@@ -93,11 +97,13 @@ public final class ModelDownloadManager {
public ModelDownloadManager(
Context appContext,
Class<? extends ListenableWorker> modelDownloadWorkerClass,
+ Callable<WorkManager> workManagerSupplier,
DownloadedModelManager downloadedModelManager,
TextClassifierSettings settings,
ListeningExecutorService executorService) {
this.appContext = Preconditions.checkNotNull(appContext);
this.modelDownloadWorkerClass = Preconditions.checkNotNull(modelDownloadWorkerClass);
+ this.workManagerSupplier = Preconditions.checkNotNull(workManagerSupplier);
this.downloadedModelManager = Preconditions.checkNotNull(downloadedModelManager);
this.settings = Preconditions.checkNotNull(settings);
this.executorService = Preconditions.checkNotNull(executorService);
@@ -121,22 +127,31 @@ public final class ModelDownloadManager {
/** Returns the downlaoded models for the given modelType. */
@Nullable
public List<File> listDownloadedModels(@ModelTypeDef String modelType) {
- return downloadedModelManager.listModels(modelType);
+ try {
+ return downloadedModelManager.listModels(modelType);
+ } catch (Throwable t) {
+ TcLog.e(TAG, "Failed to list downloaded models", t);
+ return ImmutableList.of();
+ }
}
/** Notifies the model downlaoder that the text classifier service is created. */
public void onTextClassifierServiceCreated() {
- DeviceConfig.addOnPropertiesChangedListener(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER, executorService, deviceConfigListener);
- appContext.registerReceiver(
- localeChangedReceiver, new IntentFilter(Intent.ACTION_LOCALE_CHANGED));
- TcLog.d(TAG, "DeviceConfig listener and locale change listener are registered.");
- if (!settings.isModelDownloadManagerEnabled()) {
- return;
+ try {
+ DeviceConfig.addOnPropertiesChangedListener(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER, executorService, deviceConfigListener);
+ appContext.registerReceiver(
+ localeChangedReceiver, new IntentFilter(Intent.ACTION_LOCALE_CHANGED));
+ TcLog.d(TAG, "DeviceConfig listener and locale change listener are registered.");
+ if (!settings.isModelDownloadManagerEnabled()) {
+ return;
+ }
+ maybeOverrideLocaleListForTesting();
+ TcLog.d(TAG, "Try to schedule model download work because TextClassifierService started.");
+ scheduleDownloadWork(REASON_TO_SCHEDULE_TCS_STARTED);
+ } catch (Throwable t) {
+ TcLog.e(TAG, "Failed inside onTextClassifierServiceCreated", t);
}
- maybeOverrideLocaleListForTesting();
- TcLog.v(TAG, "Try to schedule model download work because TextClassifierService started.");
- scheduleDownloadWork(REASON_TO_SCHEDULE_TCS_STARTED);
}
// TODO(licha): Make this private. Let the constructor accept a receiver to enable testing.
@@ -146,8 +161,12 @@ public final class ModelDownloadManager {
if (!settings.isModelDownloadManagerEnabled()) {
return;
}
- TcLog.v(TAG, "Try to schedule model download work because of system locale changes.");
- scheduleDownloadWork(REASON_TO_SCHEDULE_LOCALE_SETTINGS_CHANGED);
+ TcLog.d(TAG, "Try to schedule model download work because of system locale changes.");
+ try {
+ scheduleDownloadWork(REASON_TO_SCHEDULE_LOCALE_SETTINGS_CHANGED);
+ } catch (Throwable t) {
+ TcLog.e(TAG, "Failed inside onLocaleChanged", t);
+ }
}
// TODO(licha): Make this private. Let the constructor accept a receiver to enable testing.
@@ -157,16 +176,24 @@ public final class ModelDownloadManager {
if (!settings.isModelDownloadManagerEnabled()) {
return;
}
- maybeOverrideLocaleListForTesting();
- TcLog.v(TAG, "Try to schedule model download work because of device config changes.");
- scheduleDownloadWork(REASON_TO_SCHEDULE_DEVICE_CONFIG_UPDATED);
+ TcLog.d(TAG, "Try to schedule model download work because of device config changes.");
+ try {
+ maybeOverrideLocaleListForTesting();
+ scheduleDownloadWork(REASON_TO_SCHEDULE_DEVICE_CONFIG_UPDATED);
+ } catch (Throwable t) {
+ TcLog.e(TAG, "Failed inside onTextClassifierDeviceConfigChanged", t);
+ }
}
/** Clean up internal states on destroying. */
public void destroy() {
- DeviceConfig.removeOnPropertiesChangedListener(deviceConfigListener);
- appContext.unregisterReceiver(localeChangedReceiver);
- TcLog.d(TAG, "DeviceConfig and Locale listener unregistered by ModelDownloadeManager");
+ try {
+ DeviceConfig.removeOnPropertiesChangedListener(deviceConfigListener);
+ appContext.unregisterReceiver(localeChangedReceiver);
+ TcLog.d(TAG, "DeviceConfig and Locale listener unregistered by ModelDownloadeManager");
+ } catch (Throwable t) {
+ TcLog.e(TAG, "Failed to destroy ModelDownloadManager", t);
+ }
}
/**
@@ -178,10 +205,14 @@ public final class ModelDownloadManager {
if (!settings.isModelDownloadManagerEnabled()) {
return;
}
- printWriter.println("ModelDownloadManager:");
- printWriter.increaseIndent();
- downloadedModelManager.dump(printWriter);
- printWriter.decreaseIndent();
+ try {
+ printWriter.println("ModelDownloadManager:");
+ printWriter.increaseIndent();
+ downloadedModelManager.dump(printWriter);
+ printWriter.decreaseIndent();
+ } catch (Throwable t) {
+ TcLog.e(TAG, "Failed to dump ModelDownloadManager", t);
+ }
}
/**
@@ -193,54 +224,62 @@ public final class ModelDownloadManager {
private void scheduleDownloadWork(int reasonToSchedule) {
long workId =
Hashing.farmHashFingerprint64().hashUnencodedChars(UUID.randomUUID().toString()).asLong();
- NetworkType networkType =
- Enums.getIfPresent(NetworkType.class, settings.getManifestDownloadRequiredNetworkType())
- .or(NetworkType.UNMETERED);
- OneTimeWorkRequest downloadRequest =
- new OneTimeWorkRequest.Builder(modelDownloadWorkerClass)
- .setConstraints(
- new Constraints.Builder()
- .setRequiredNetworkType(networkType)
- .setRequiresBatteryNotLow(true)
- .setRequiresStorageNotLow(true)
- .setRequiresDeviceIdle(settings.getManifestDownloadRequiresDeviceIdle())
- .setRequiresCharging(settings.getManifestDownloadRequiresCharging())
- .build())
- .setBackoffCriteria(
- BackoffPolicy.EXPONENTIAL,
- settings.getModelDownloadBackoffDelayInMillis(),
- MILLISECONDS)
- .setInputData(
- new Data.Builder()
- .putLong(ModelDownloadWorker.INPUT_DATA_KEY_WORK_ID, workId)
- .putLong(
- ModelDownloadWorker.INPUT_DATA_KEY_SCHEDULED_TIMESTAMP,
- Instant.now().toEpochMilli())
- .build())
- .build();
- ListenableFuture<Operation.State.SUCCESS> enqueueResultFuture =
- WorkManager.getInstance(appContext)
- .enqueueUniqueWork(
- UNIQUE_QUEUE_NAME, ExistingWorkPolicy.APPEND_OR_REPLACE, downloadRequest)
- .getResult();
- Futures.addCallback(
- enqueueResultFuture,
- new FutureCallback<Operation.State.SUCCESS>() {
- @Override
- public void onSuccess(Operation.State.SUCCESS unused) {
- TcLog.v(TAG, "Download work scheduled.");
- TextClassifierDownloadLogger.downloadWorkScheduled(
- workId, reasonToSchedule, /* failedToSchedule= */ false);
- }
+ try {
+ NetworkType networkType =
+ Enums.getIfPresent(NetworkType.class, settings.getManifestDownloadRequiredNetworkType())
+ .or(NetworkType.UNMETERED);
+ OneTimeWorkRequest downloadRequest =
+ new OneTimeWorkRequest.Builder(modelDownloadWorkerClass)
+ .setConstraints(
+ new Constraints.Builder()
+ .setRequiredNetworkType(networkType)
+ .setRequiresBatteryNotLow(true)
+ .setRequiresStorageNotLow(true)
+ .setRequiresDeviceIdle(settings.getManifestDownloadRequiresDeviceIdle())
+ .setRequiresCharging(settings.getManifestDownloadRequiresCharging())
+ .build())
+ .setBackoffCriteria(
+ BackoffPolicy.EXPONENTIAL,
+ settings.getModelDownloadBackoffDelayInMillis(),
+ MILLISECONDS)
+ .setInputData(
+ new Data.Builder()
+ .putLong(ModelDownloadWorker.INPUT_DATA_KEY_WORK_ID, workId)
+ .putLong(
+ ModelDownloadWorker.INPUT_DATA_KEY_SCHEDULED_TIMESTAMP,
+ Instant.now().toEpochMilli())
+ .build())
+ .build();
+ ListenableFuture<Operation.State.SUCCESS> enqueueResultFuture =
+ workManagerSupplier
+ .call()
+ .enqueueUniqueWork(
+ UNIQUE_QUEUE_NAME, ExistingWorkPolicy.APPEND_OR_REPLACE, downloadRequest)
+ .getResult();
+ Futures.addCallback(
+ enqueueResultFuture,
+ new FutureCallback<Operation.State.SUCCESS>() {
+ @Override
+ public void onSuccess(Operation.State.SUCCESS unused) {
+ TcLog.d(TAG, "Download work scheduled.");
+ TextClassifierDownloadLogger.downloadWorkScheduled(
+ workId, reasonToSchedule, /* failedToSchedule= */ false);
+ }
- @Override
- public void onFailure(Throwable t) {
- TcLog.e(TAG, "Failed to schedule download work: ", t);
- TextClassifierDownloadLogger.downloadWorkScheduled(
- workId, reasonToSchedule, /* failedToSchedule= */ true);
- }
- },
- executorService);
+ @Override
+ public void onFailure(Throwable t) {
+ TcLog.e(TAG, "Failed to schedule download work: ", t);
+ TextClassifierDownloadLogger.downloadWorkScheduled(
+ workId, reasonToSchedule, /* failedToSchedule= */ true);
+ }
+ },
+ executorService);
+ } catch (Throwable t) {
+ // TODO(licha): this is just for temporary fix. Refactor the try-catch in the future.
+ TcLog.e(TAG, "Failed to schedule download work: ", t);
+ TextClassifierDownloadLogger.downloadWorkScheduled(
+ workId, reasonToSchedule, /* failedToSchedule= */ true);
+ }
}
private void maybeOverrideLocaleListForTesting() {
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java b/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java
index 6e04e16..3db0815 100644
--- a/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java
@@ -113,6 +113,7 @@ public final class ModelDownloadWorker extends ListenableWorker {
@Override
public final ListenableFuture<ListenableWorker.Result> startWork() {
+ TcLog.d(TAG, "Start download work...");
workStartedTimeMillis = getCurrentTimeMillis();
// Notice: startWork() is invoked on the main thread
if (!settings.isModelDownloadManagerEnabled()) {
@@ -121,7 +122,6 @@ public final class ModelDownloadWorker extends ListenableWorker {
TextClassifierDownloadLogger.WORK_RESULT_FAILURE_MODEL_DOWNLOADER_DISABLED);
return Futures.immediateFuture(ListenableWorker.Result.failure());
}
- TcLog.v(TAG, "Start download work...");
if (getRunAttemptCount() >= settings.getModelDownloadWorkerMaxAttempts()) {
TcLog.d(TAG, "Max attempt reached. Abort download work.");
logDownloadWorkCompleted(
@@ -134,7 +134,7 @@ public final class ModelDownloadWorker extends ListenableWorker {
downloadResult -> {
Preconditions.checkNotNull(manifestsToDownload);
downloadedModelManager.onDownloadCompleted(manifestsToDownload);
- TcLog.v(TAG, "Download work completed: " + downloadResult);
+ TcLog.d(TAG, "Download work completed: " + downloadResult);
if (downloadResult.failureCount() == 0) {
logDownloadWorkCompleted(
downloadResult.successCount() > 0
@@ -239,7 +239,7 @@ public final class ModelDownloadWorker extends ListenableWorker {
return Futures.whenAllComplete(downloadResultFutures)
.call(
() -> {
- TcLog.v(TAG, "All Download Tasks Completed");
+ TcLog.d(TAG, "All Download Tasks Completed");
int successCount = 0;
int failureCount = 0;
for (ListenableFuture<Boolean> downloadResultFuture : downloadResultFutures) {
@@ -333,7 +333,7 @@ public final class ModelDownloadWorker extends ListenableWorker {
Manifest downloadedManifest = downloadedModelManager.getManifest(manifestUrl);
if (downloadedManifest != null
&& downloadedManifest.getStatus() == Manifest.STATUS_SUCCEEDED) {
- TcLog.v(TAG, "Manifest already downloaded: " + manifestUrl);
+ TcLog.d(TAG, "Manifest already downloaded: " + manifestUrl);
return Futures.immediateVoidFuture();
}
if (pendingDownloads.containsKey(manifestUrl)) {
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java b/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java
index 2244e9a..0b76f22 100644
--- a/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java
@@ -99,7 +99,7 @@ final class ModelDownloaderImpl implements ModelDownloader {
new FutureCallback<File>() {
@Override
public void onSuccess(File pendingModelFile) {
- TcLog.v(TAG, "Download model successfully: " + pendingModelFile.getAbsolutePath());
+ TcLog.d(TAG, "Download model successfully: " + pendingModelFile.getAbsolutePath());
}
@Override
@@ -170,11 +170,11 @@ final class ModelDownloaderImpl implements ModelDownloader {
} catch (IOException e) {
throw new ModelDownloadException(ModelDownloadException.FAILED_TO_VALIDATE_MODEL, e);
}
- TcLog.v(TAG, "Pending model file passed validation.");
+ TcLog.d(TAG, "Pending model file passed validation.");
}
private ListenableFuture<IModelDownloaderService> connect(DownloaderServiceConnection conn) {
- TcLog.v(TAG, "Starting a new connection to ModelDownloaderService");
+ TcLog.d(TAG, "Starting a new connection to ModelDownloaderService");
return CallbackToFutureAdapter.getFuture(
completer -> {
conn.attachCompleter(completer);
@@ -197,7 +197,7 @@ final class ModelDownloaderImpl implements ModelDownloader {
// restult future will hang there until time out. (WorkManager forces a 10-min running time.)
private static ListenableFuture<Long> scheduleDownload(
IModelDownloaderService service, URI uri, File targetFile) {
- TcLog.v(TAG, "Scheduling a new download task with ModelDownloaderService");
+ TcLog.d(TAG, "Scheduling a new download task with ModelDownloaderService");
return CallbackToFutureAdapter.getFuture(
completer -> {
service.download(
@@ -236,7 +236,7 @@ final class ModelDownloaderImpl implements ModelDownloader {
@Override
public void onServiceConnected(ComponentName componentName, IBinder iBinder) {
- TcLog.v(TAG, "DownloaderService connected");
+ TcLog.d(TAG, "DownloaderService connected");
completer.set(Preconditions.checkNotNull(IModelDownloaderService.Stub.asInterface(iBinder)));
}
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java b/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java
index e4ebbfa..6d7e47e 100644
--- a/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java
@@ -39,7 +39,7 @@ public final class ModelDownloaderService extends Service {
@Override
public IBinder onBind(Intent intent) {
- TcLog.v(TAG, "Binding to ModelDownloadService");
+ TcLog.d(TAG, "Binding to ModelDownloadService");
return iBinder;
}
}
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java b/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java
index 439588b..47e6f19 100644
--- a/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java
@@ -91,7 +91,7 @@ final class ModelDownloaderServiceImpl extends IModelDownloaderService.Stub {
@Override
public void download(String uri, String targetFilePath, IModelDownloaderCallback callback) {
- TcLog.v(TAG, "Download request received: " + uri);
+ TcLog.d(TAG, "Download request received: " + uri);
try {
File targetFile = new File(targetFilePath);
File tempMetadataFile = getMetadataFile(targetFile);
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
index 71f9a4f..ddab8bd 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
@@ -17,7 +17,10 @@
package com.android.textclassifier;
import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
import android.content.Context;
import android.os.CancellationSignal;
@@ -39,6 +42,7 @@ import com.android.os.AtomsProto.Atom;
import com.android.os.AtomsProto.TextClassifierApiUsageReported;
import com.android.os.AtomsProto.TextClassifierApiUsageReported.ApiType;
import com.android.os.AtomsProto.TextClassifierApiUsageReported.ResultType;
+import com.android.textclassifier.common.ModelType;
import com.android.textclassifier.common.TextClassifierSettings;
import com.android.textclassifier.common.statsd.StatsdTestUtils;
import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger;
@@ -47,6 +51,7 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
+import java.io.IOException;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.stream.Collectors;
@@ -81,13 +86,21 @@ public class DefaultTextClassifierServiceTest {
@Mock private TextClassifierService.Callback<TextLinks> textLinksCallback;
@Mock private TextClassifierService.Callback<ConversationActions> conversationActionsCallback;
@Mock private TextClassifierService.Callback<TextLanguage> textLanguageCallback;
+ @Mock private ModelFileManager testModelFileManager;
@Before
- public void setup() {
-
- testInjector = new TestInjector(ApplicationProvider.getApplicationContext());
+ public void setup() throws IOException {
+ testInjector =
+ new TestInjector(ApplicationProvider.getApplicationContext(), testModelFileManager);
defaultTextClassifierService = new DefaultTextClassifierService(testInjector);
defaultTextClassifierService.onCreate();
+
+ when(testModelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+ .thenReturn(TestDataUtils.getTestAnnotatorModelFileWrapped());
+ when(testModelFileManager.findBestModelFile(eq(ModelType.LANG_ID), any(), any()))
+ .thenReturn(TestDataUtils.getLangIdModelFileWrapped());
+ when(testModelFileManager.findBestModelFile(eq(ModelType.ACTIONS_SUGGESTIONS), any(), any()))
+ .thenReturn(TestDataUtils.getTestActionsModelFileWrapped());
}
@Before
@@ -211,11 +224,8 @@ public class DefaultTextClassifierServiceTest {
@Test
public void missingModelFile_onFailureShouldBeCalled() throws Exception {
- testInjector.setModelFileManager(
- new ModelFileManagerImpl(
- ApplicationProvider.getApplicationContext(),
- ImmutableList.of(),
- testInjector.createTextClassifierSettings()));
+ when(testModelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+ .thenReturn(null);
defaultTextClassifierService.onCreate();
TextClassification.Request request = new TextClassification.Request.Builder("hi", 0, 2).build();
@@ -251,12 +261,9 @@ public class DefaultTextClassifierServiceTest {
private final Context context;
private ModelFileManager modelFileManager;
- private TestInjector(Context context) {
+ private TestInjector(Context context, ModelFileManager modelFileManager) {
this.context = Preconditions.checkNotNull(context);
- }
-
- private void setModelFileManager(ModelFileManager modelFileManager) {
- this.modelFileManager = modelFileManager;
+ this.modelFileManager = Preconditions.checkNotNull(modelFileManager);
}
@Override
@@ -267,9 +274,6 @@ public class DefaultTextClassifierServiceTest {
@Override
public ModelFileManager createModelFileManager(
TextClassifierSettings settings, ModelDownloadManager modelDownloadManager) {
- if (modelFileManager == null) {
- return TestDataUtils.createModelFileManagerForTesting(context);
- }
return modelFileManager;
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java
index 20ae592..0e40515 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java
@@ -25,6 +25,7 @@ import android.os.LocaleList;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
+import androidx.work.WorkManager;
import com.android.textclassifier.ModelFileManagerImpl.DownloaderModelsLister;
import com.android.textclassifier.ModelFileManagerImpl.RegularFileFullMatchLister;
import com.android.textclassifier.ModelFileManagerImpl.RegularFilePatternMatchLister;
@@ -87,6 +88,7 @@ public final class ModelFileManagerImplTest {
new ModelDownloadManager(
context,
ModelDownloadWorker.class,
+ () -> WorkManager.getInstance(context),
downloadedModelManager,
settings,
MoreExecutors.newDirectExecutorService());
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
index bac4fa1..a19e3ff 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
@@ -16,12 +16,10 @@
package com.android.textclassifier;
-import android.content.Context;
-import com.android.textclassifier.ModelFileManagerImpl.RegularFileFullMatchLister;
+import com.android.textclassifier.common.ModelFile;
import com.android.textclassifier.common.ModelType;
-import com.android.textclassifier.common.TextClassifierSettings;
-import com.google.common.collect.ImmutableList;
import java.io.File;
+import java.io.IOException;
/** Utils to access test data files. */
public final class TestDataUtils {
@@ -30,7 +28,7 @@ public final class TestDataUtils {
private static final String TEST_LANGID_MODEL_PATH = "testdata/langid.model";
/** Returns the root folder that contains the test data. */
- public static File getTestDataFolder() {
+ private static File getTestDataFolder() {
return new File("/data/local/tmp/TextClassifierServiceTest/");
}
@@ -38,24 +36,25 @@ public final class TestDataUtils {
return new File(getTestDataFolder(), TEST_ANNOTATOR_MODEL_PATH);
}
+ public static ModelFile getTestAnnotatorModelFileWrapped() throws IOException {
+ return ModelFile.createFromRegularFile(getTestAnnotatorModelFile(), ModelType.ANNOTATOR);
+ }
+
public static File getTestActionsModelFile() {
return new File(getTestDataFolder(), TEST_ACTIONS_MODEL_PATH);
}
+ public static ModelFile getTestActionsModelFileWrapped() throws IOException {
+ return ModelFile.createFromRegularFile(
+ getTestActionsModelFile(), ModelType.ACTIONS_SUGGESTIONS);
+ }
+
public static File getLangIdModelFile() {
return new File(getTestDataFolder(), TEST_LANGID_MODEL_PATH);
}
- public static ModelFileManager createModelFileManagerForTesting(Context context) {
- return new ModelFileManagerImpl(
- context,
- ImmutableList.of(
- new RegularFileFullMatchLister(
- ModelType.ANNOTATOR, getTestAnnotatorModelFile(), () -> true),
- new RegularFileFullMatchLister(
- ModelType.ACTIONS_SUGGESTIONS, getTestActionsModelFile(), () -> true),
- new RegularFileFullMatchLister(ModelType.LANG_ID, getLangIdModelFile(), () -> true)),
- new TextClassifierSettings());
+ public static ModelFile getLangIdModelFileWrapped() throws IOException {
+ return ModelFile.createFromRegularFile(getLangIdModelFile(), ModelType.LANG_ID);
}
private TestDataUtils() {}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
index 42177e6..e7bf90c 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
@@ -56,6 +56,10 @@ public class TextClassifierApiTest {
@Before
public void setup() {
+ extServicesTextClassifierRule.enableVerboseLogging();
+ // Verbose logging only takes effect after restarting ExtServices
+ extServicesTextClassifierRule.forceStopExtServices();
+
textClassifier = extServicesTextClassifierRule.getTextClassifier();
}
@@ -81,8 +85,8 @@ public class TextClassifierApiTest {
@Test
public void classifyText() {
- String text = "Contact me at droid@android.com";
- String classifiedText = "droid@android.com";
+ String text = "Contact me at http://www.android.com";
+ String classifiedText = "http://www.android.com";
int startIndex = text.indexOf(classifiedText);
int endIndex = startIndex + classifiedText.length();
TextClassification.Request request =
@@ -90,7 +94,7 @@ public class TextClassifierApiTest {
TextClassification classification = textClassifier.classifyText(request);
assertThat(classification.getEntityCount()).isGreaterThan(0);
- assertThat(classification.getEntity(0)).isEqualTo(TextClassifier.TYPE_EMAIL);
+ assertThat(classification.getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
assertThat(classification.getText()).isEqualTo(classifiedText);
assertThat(classification.getActions()).isNotEmpty();
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
index fb1aea8..c20ec8a 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
@@ -22,6 +22,9 @@ import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.testng.Assert.expectThrows;
@@ -74,30 +77,34 @@ public class TextClassifierImplTest {
private static final LocaleList LOCALES = LocaleList.forLanguageTags("en-US");
private static final String NO_TYPE = null;
- @Mock private ModelFileManagerImpl.ModelFileLister mockModelFileLister;
+ @Mock private ModelFileManager modelFileManager;
- private TextClassifierSettings settings;
private Context context;
private TestingDeviceConfig deviceConfig;
- private TextClassifierImpl classifier;
-
- private final ModelFileManager modelFileManager =
- TestDataUtils.createModelFileManagerForTesting(ApplicationProvider.getApplicationContext());
+ private TextClassifierSettings settings;
private LruCache<ModelFile, AnnotatorModel> annotatorModelCache;
+ private TextClassifierImpl classifier;
@Before
- public void setup() {
+ public void setup() throws IOException {
MockitoAnnotations.initMocks(this);
- deviceConfig = new TestingDeviceConfig();
- Context context =
+ this.context =
new FakeContextBuilder()
.setAllIntentComponent(FakeContextBuilder.DEFAULT_COMPONENT)
.setAppLabel(FakeContextBuilder.DEFAULT_COMPONENT.getPackageName(), "Test app")
.build();
- this.context = context;
- settings = new TextClassifierSettings(deviceConfig);
- // TODO(veronikanikina): consider using a testing constructor here.
- classifier = new TextClassifierImpl(context, settings, modelFileManager);
+ this.deviceConfig = new TestingDeviceConfig();
+ this.settings = new TextClassifierSettings(deviceConfig);
+ this.annotatorModelCache = new LruCache<>(2);
+ this.classifier =
+ new TextClassifierImpl(context, settings, modelFileManager, annotatorModelCache);
+
+ when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+ .thenReturn(TestDataUtils.getTestAnnotatorModelFileWrapped());
+ when(modelFileManager.findBestModelFile(eq(ModelType.LANG_ID), any(), any()))
+ .thenReturn(TestDataUtils.getLangIdModelFileWrapped());
+ when(modelFileManager.findBestModelFile(eq(ModelType.ACTIONS_SUGGESTIONS), any(), any()))
+ .thenReturn(TestDataUtils.getTestActionsModelFileWrapped());
}
@Test
@@ -110,9 +117,7 @@ public class TextClassifierImplTest {
int smartStartIndex = text.indexOf(suggested);
int smartEndIndex = smartStartIndex + suggested.length();
TextSelection.Request request =
- new TextSelection.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
+ new TextSelection.Request.Builder(text, startIndex, endIndex).build();
TextSelection selection = classifier.suggestSelection(null, null, request);
assertThat(
@@ -120,6 +125,24 @@ public class TextClassifierImplTest {
}
@Test
+ public void testSuggestSelection_localePreferenceIsPassedToModelFileManager() throws IOException {
+ String text = "Contact me at droid@android.com";
+ String selected = "droid";
+ String suggested = "droid@android.com";
+ int startIndex = text.indexOf(selected);
+ int endIndex = startIndex + selected.length();
+ int smartStartIndex = text.indexOf(suggested);
+ int smartEndIndex = smartStartIndex + suggested.length();
+ TextSelection.Request request =
+ new TextSelection.Request.Builder(text, startIndex, endIndex)
+ .setDefaultLocales(LOCALES)
+ .build();
+
+ classifier.suggestSelection(null, null, request);
+ verify(modelFileManager).findBestModelFile(eq(ModelType.ANNOTATOR), eq(LOCALES), any());
+ }
+
+ @Test
public void testSuggestSelection_url() throws IOException {
String text = "Visit http://www.android.com for more information";
String selected = "http";
@@ -129,9 +152,7 @@ public class TextClassifierImplTest {
int smartStartIndex = text.indexOf(suggested);
int smartEndIndex = smartStartIndex + suggested.length();
TextSelection.Request request =
- new TextSelection.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
+ new TextSelection.Request.Builder(text, startIndex, endIndex).build();
TextSelection selection = classifier.suggestSelection(null, null, request);
assertThat(selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_URL));
@@ -144,9 +165,7 @@ public class TextClassifierImplTest {
int startIndex = text.indexOf(selected);
int endIndex = startIndex + selected.length();
TextSelection.Request request =
- new TextSelection.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
+ new TextSelection.Request.Builder(text, startIndex, endIndex).build();
TextSelection selection = classifier.suggestSelection(null, null, request);
assertThat(selection, isTextSelection(startIndex, endIndex, NO_TYPE));
@@ -160,7 +179,6 @@ public class TextClassifierImplTest {
int startIndex = text.indexOf(suggested);
TextSelection.Request request =
new TextSelection.Request.Builder(text, startIndex, /*endIndex=*/ startIndex + 1)
- .setDefaultLocales(LOCALES)
.setIncludeTextClassification(true)
.build();
@@ -178,7 +196,6 @@ public class TextClassifierImplTest {
String text = "Visit http://www.android.com for more information";
TextSelection.Request request =
new TextSelection.Request.Builder(text, /*startIndex=*/ 0, /*endIndex=*/ 4)
- .setDefaultLocales(LOCALES)
.setIncludeTextClassification(false)
.build();
@@ -194,9 +211,7 @@ public class TextClassifierImplTest {
int startIndex = text.indexOf(classifiedText);
int endIndex = startIndex + classifiedText.length();
TextClassification.Request request =
- new TextClassification.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
+ new TextClassification.Request.Builder(text, startIndex, endIndex).build();
TextClassification classification =
classifier.classifyText(/* sessionId= */ null, null, request);
@@ -210,9 +225,7 @@ public class TextClassifierImplTest {
int startIndex = text.indexOf(classifiedText);
int endIndex = startIndex + classifiedText.length();
TextClassification.Request request =
- new TextClassification.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
+ new TextClassification.Request.Builder(text, startIndex, endIndex).build();
TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
@@ -223,9 +236,7 @@ public class TextClassifierImplTest {
public void testClassifyText_address() throws IOException {
String text = "Brandschenkestrasse 110, Zürich, Switzerland";
TextClassification.Request request =
- new TextClassification.Request.Builder(text, 0, text.length())
- .setDefaultLocales(LOCALES)
- .build();
+ new TextClassification.Request.Builder(text, 0, text.length()).build();
TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(text, TextClassifier.TYPE_ADDRESS));
@@ -238,9 +249,7 @@ public class TextClassifierImplTest {
int startIndex = text.indexOf(classifiedText);
int endIndex = startIndex + classifiedText.length();
TextClassification.Request request =
- new TextClassification.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
+ new TextClassification.Request.Builder(text, startIndex, endIndex).build();
TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
@@ -254,9 +263,7 @@ public class TextClassifierImplTest {
int startIndex = text.indexOf(classifiedText);
int endIndex = startIndex + classifiedText.length();
TextClassification.Request request =
- new TextClassification.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
+ new TextClassification.Request.Builder(text, startIndex, endIndex).build();
TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE));
@@ -275,9 +282,7 @@ public class TextClassifierImplTest {
int startIndex = text.indexOf(classifiedText);
int endIndex = startIndex + classifiedText.length();
TextClassification.Request request =
- new TextClassification.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
+ new TextClassification.Request.Builder(text, startIndex, endIndex).build();
TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE_TIME));
@@ -289,14 +294,12 @@ public class TextClassifierImplTest {
LocaleList.setDefault(LocaleList.forLanguageTags("en"));
String japaneseText = "これは日本語のテキストです";
TextClassification.Request request =
- new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length())
- .setDefaultLocales(LOCALES)
- .build();
+ new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length()).build();
TextClassification classification = classifier.classifyText(null, null, request);
RemoteAction translateAction = classification.getActions().get(0);
assertEquals(1, classification.getActions().size());
- assertEquals("Translate", translateAction.getTitle().toString());
+ assertEquals(Intent.ACTION_TRANSLATE, classification.getIntent().getAction());
assertEquals(translateAction, ExtrasUtils.findTranslateAction(classification));
Intent intent = ExtrasUtils.getActionsIntents(classification).get(0);
@@ -323,18 +326,17 @@ public class TextClassifierImplTest {
@Test
public void testGenerateLinks_exclude() throws IOException {
- String text = "You want apple@banana.com. See you tonight!";
+ String text = "The number is +12122537077. See you tonight!";
List<String> hints = ImmutableList.of();
List<String> included = ImmutableList.of();
- List<String> excluded = Arrays.asList(TextClassifier.TYPE_EMAIL);
+ List<String> excluded = Arrays.asList(TextClassifier.TYPE_PHONE);
TextLinks.Request request =
new TextLinks.Request.Builder(text)
.setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded))
- .setDefaultLocales(LOCALES)
.build();
assertThat(
classifier.generateLinks(null, null, request),
- not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL)));
+ not(isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE)));
}
@Test
@@ -344,7 +346,6 @@ public class TextClassifierImplTest {
TextLinks.Request request =
new TextLinks.Request.Builder(text)
.setEntityConfig(TextClassifier.EntityConfig.createWithExplicitEntityList(explicit))
- .setDefaultLocales(LOCALES)
.build();
assertThat(
classifier.generateLinks(null, null, request),
@@ -361,7 +362,6 @@ public class TextClassifierImplTest {
TextLinks.Request request =
new TextLinks.Request.Builder(text)
.setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded))
- .setDefaultLocales(LOCALES)
.build();
assertThat(
classifier.generateLinks(null, null, request),
@@ -573,29 +573,16 @@ public class TextClassifierImplTest {
new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false);
ModelFile annotatorModelB =
new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false);
- String langIdFilePath = TestDataUtils.getLangIdModelFile().getPath();
- ModelFile langIdModel = new ModelFile(ModelType.LANG_ID, langIdFilePath, 1, "*", false);
-
- annotatorModelCache = new LruCache<>(2);
- ModelFileManager modelFileManagerCached =
- new ModelFileManagerImpl(context, ImmutableList.of(mockModelFileLister), settings);
- TextClassifierImpl textClassifierImpl =
- new TextClassifierImpl(context, settings, modelFileManagerCached, annotatorModelCache);
- LocaleList.setDefault(LocaleList.forLanguageTags("en"));
String englishText = "You can reach me on +12122537077.";
String classifiedText = "+12122537077";
TextClassification.Request request =
- new TextClassification.Request.Builder(englishText, 0, englishText.length())
- .setDefaultLocales(LOCALES)
- .build();
-
- when(mockModelFileLister.list(ModelType.LANG_ID)).thenReturn(ImmutableList.of(langIdModel));
+ new TextClassification.Request.Builder(englishText, 0, englishText.length()).build();
// Check modelFileA v701
- when(mockModelFileLister.list(ModelType.ANNOTATOR))
- .thenReturn(ImmutableList.of(annotatorModelA));
- TextClassification classificationA = textClassifierImpl.classifyText(null, null, request);
+ when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+ .thenReturn(annotatorModelA);
+ TextClassification classificationA = classifier.classifyText(null, null, request);
assertThat(classificationA.getId()).contains("v701");
assertThat(classificationA.getText()).contains(classifiedText);
@@ -609,9 +596,9 @@ public class TextClassifierImplTest {
});
// Check modelFileB v801
- when(mockModelFileLister.list(ModelType.ANNOTATOR))
- .thenReturn(ImmutableList.of(annotatorModelB));
- TextClassification classificationB = textClassifierImpl.classifyText(null, null, request);
+ when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+ .thenReturn(annotatorModelB);
+ TextClassification classificationB = classifier.classifyText(null, null, request);
assertThat(classificationB.getId()).contains("v801");
assertThat(classificationB.getText()).contains(classifiedText);
@@ -625,9 +612,9 @@ public class TextClassifierImplTest {
});
// Reload modelFileA v701
- when(mockModelFileLister.list(ModelType.ANNOTATOR))
- .thenReturn(ImmutableList.of(annotatorModelA));
- TextClassification classificationAcached = textClassifierImpl.classifyText(null, null, request);
+ when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+ .thenReturn(annotatorModelA);
+ TextClassification classificationAcached = classifier.classifyText(null, null, request);
assertThat(classificationAcached.getId()).contains("v701");
assertThat(classificationAcached.getText()).contains(classifiedText);
@@ -651,28 +638,16 @@ public class TextClassifierImplTest {
new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false);
ModelFile annotatorModelB =
new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false);
- String langIdFilePath = TestDataUtils.getLangIdModelFile().getPath();
- ModelFile langIdModel = new ModelFile(ModelType.LANG_ID, langIdFilePath, 1, "*", false);
-
- annotatorModelCache = new LruCache<>(settings.getMultiAnnotatorCacheSize());
- ModelFileManager modelFileManagerCached =
- new ModelFileManagerImpl(context, ImmutableList.of(mockModelFileLister), settings);
- TextClassifierImpl textClassifierImpl =
- new TextClassifierImpl(context, settings, modelFileManagerCached, annotatorModelCache);
- LocaleList.setDefault(LocaleList.forLanguageTags("en"));
+
String englishText = "You can reach me on +12122537077.";
String classifiedText = "+12122537077";
TextClassification.Request request =
- new TextClassification.Request.Builder(englishText, 0, englishText.length())
- .setDefaultLocales(LOCALES)
- .build();
-
- when(mockModelFileLister.list(ModelType.LANG_ID)).thenReturn(ImmutableList.of(langIdModel));
+ new TextClassification.Request.Builder(englishText, 0, englishText.length()).build();
// Check modelFileA v701
- when(mockModelFileLister.list(ModelType.ANNOTATOR))
- .thenReturn(ImmutableList.of(annotatorModelA));
- TextClassification classification = textClassifierImpl.classifyText(null, null, request);
+ when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+ .thenReturn(annotatorModelA);
+ TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification.getId()).contains("v701");
assertThat(classification.getText()).contains(classifiedText);
@@ -686,9 +661,9 @@ public class TextClassifierImplTest {
});
// Check modelFileB v801
- when(mockModelFileLister.list(ModelType.ANNOTATOR))
- .thenReturn(ImmutableList.of(annotatorModelB));
- TextClassification classificationB = textClassifierImpl.classifyText(null, null, request);
+ when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+ .thenReturn(annotatorModelB);
+ TextClassification classificationB = classifier.classifyText(null, null, request);
assertThat(classificationB.getId()).contains("v801");
assertThat(classificationB.getText()).contains(classifiedText);
@@ -702,9 +677,9 @@ public class TextClassifierImplTest {
});
// Reload modelFileA v701
- when(mockModelFileLister.list(ModelType.ANNOTATOR))
- .thenReturn(ImmutableList.of(annotatorModelA));
- TextClassification classificationAcached = textClassifierImpl.classifyText(null, null, request);
+ when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+ .thenReturn(annotatorModelA);
+ TextClassification classificationAcached = classifier.classifyText(null, null, request);
assertThat(classificationAcached.getId()).contains("v701");
assertThat(classificationAcached.getText()).contains(classifiedText);
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java
index 394b7ad..9e11c09 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java
@@ -67,6 +67,7 @@ public final class ModelDownloadManagerTest {
private TestingDeviceConfig deviceConfig;
private WorkManager workManager;
private ModelDownloadManager downloadManager;
+ private ModelDownloadManager downloadManagerWithBadWorkManager;
@Mock DownloadedModelManager downloadedModelManager;
@Before
@@ -80,6 +81,17 @@ public final class ModelDownloadManagerTest {
new ModelDownloadManager(
context,
ModelDownloadWorker.class,
+ () -> workManager,
+ downloadedModelManager,
+ new TextClassifierSettings(deviceConfig),
+ MoreExecutors.newDirectExecutorService());
+ this.downloadManagerWithBadWorkManager =
+ new ModelDownloadManager(
+ context,
+ ModelDownloadWorker.class,
+ () -> {
+ throw new IllegalStateException("WorkManager may fail!");
+ },
downloadedModelManager,
new TextClassifierSettings(deviceConfig),
MoreExecutors.newDirectExecutorService());
@@ -96,7 +108,20 @@ public final class ModelDownloadManagerTest {
}
@Test
+ public void onTextClassifierServiceCreated_workManagerCrashed() throws Exception {
+ assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty();
+ downloadManagerWithBadWorkManager.onTextClassifierServiceCreated();
+
+ // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test
+ TextClassifierDownloadWorkScheduled atom =
+ Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms());
+ assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.TCS_STARTED);
+ assertThat(atom.getFailedToSchedule()).isTrue();
+ }
+
+ @Test
public void onTextClassifierServiceCreated_requestEnqueued() throws Exception {
+ assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty();
downloadManager.onTextClassifierServiceCreated();
WorkInfo workInfo =
@@ -104,21 +129,34 @@ public final class ModelDownloadManagerTest {
DownloaderTestUtils.queryWorkInfos(
workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME));
assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
+ // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test
verifyWorkScheduledLogging(ReasonToSchedule.TCS_STARTED);
}
@Test
public void onTextClassifierServiceCreated_localeListOverridden() throws Exception {
+ assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty();
deviceConfig.setConfig(TextClassifierSettings.TESTING_LOCALE_LIST_OVERRIDE, "zh,fr");
downloadManager.onTextClassifierServiceCreated();
assertThat(Locale.getDefault()).isEqualTo(Locale.forLanguageTag("zh"));
assertThat(LocaleList.getDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr"));
assertThat(LocaleList.getAdjustedDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr"));
+ // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test
verifyWorkScheduledLogging(ReasonToSchedule.TCS_STARTED);
}
@Test
+ public void onLocaleChanged_workManagerCrashed() throws Exception {
+ downloadManagerWithBadWorkManager.onLocaleChanged();
+
+ TextClassifierDownloadWorkScheduled atom =
+ Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms());
+ assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.LOCALE_SETTINGS_CHANGED);
+ assertThat(atom.getFailedToSchedule()).isTrue();
+ }
+
+ @Test
public void onLocaleChanged_requestEnqueued() throws Exception {
downloadManager.onLocaleChanged();
@@ -131,6 +169,16 @@ public final class ModelDownloadManagerTest {
}
@Test
+ public void onTextClassifierDeviceConfigChanged_workManagerCrashed() throws Exception {
+ downloadManagerWithBadWorkManager.onTextClassifierDeviceConfigChanged();
+
+ TextClassifierDownloadWorkScheduled atom =
+ Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms());
+ assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.DEVICE_CONFIG_UPDATED);
+ assertThat(atom.getFailedToSchedule()).isTrue();
+ }
+
+ @Test
public void onTextClassifierDeviceConfigChanged_requestEnqueued() throws Exception {
downloadManager.onTextClassifierDeviceConfigChanged();
@@ -188,6 +236,13 @@ public final class ModelDownloadManagerTest {
assertThat(downloadManager.listDownloadedModels(MODEL_TYPE)).containsExactly(modelFile);
}
+ @Test
+ public void listDownloadedModels_doNotCrashOnError() throws Exception {
+ when(downloadedModelManager.listModels(MODEL_TYPE)).thenThrow(new IllegalStateException());
+
+ assertThat(downloadManager.listDownloadedModels(MODEL_TYPE)).isEmpty();
+ }
+
private void verifyWorkScheduledLogging(ReasonToSchedule reasonToSchedule) throws Exception {
TextClassifierDownloadWorkScheduled atom =
Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms());
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java
index e4360c6..e261158 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java
@@ -18,16 +18,10 @@ package com.android.textclassifier.downloader;
import static com.google.common.truth.Truth.assertThat;
-import android.app.Instrumentation;
-import android.app.UiAutomation;
import android.util.Log;
import android.view.textclassifier.TextClassification;
import android.view.textclassifier.TextClassification.Request;
-import android.view.textclassifier.TextClassifier;
-import androidx.test.platform.app.InstrumentationRegistry;
import com.android.textclassifier.testing.ExtServicesTextClassifierRule;
-import com.android.textclassifier.testing.TestingLocaleListOverrideRule;
-import java.io.IOException;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
@@ -48,171 +42,133 @@ public class ModelDownloaderIntegrationTest {
private static final String V804_EN_TAG = "en_v804";
private static final String V804_RU_TAG = "ru_v804";
private static final String FACTORY_MODEL_TAG = "*";
-
- @Rule
- public final TestingLocaleListOverrideRule testingLocaleListOverrideRule =
- new TestingLocaleListOverrideRule();
+ private static final int ASSERT_MAX_ATTEMPTS = 20;
+ private static final int ASSERT_SLEEP_BEFORE_RETRY_MS = 1000;
@Rule
public final ExtServicesTextClassifierRule extServicesTextClassifierRule =
new ExtServicesTextClassifierRule();
- private TextClassifier textClassifier;
-
@Before
public void setup() throws Exception {
- // Flag overrides below can be overridden by Phenotype sync, which makes this test flaky
- runShellCommand("device_config put textclassifier config_updater_model_enabled false");
- runShellCommand("device_config put textclassifier model_download_manager_enabled true");
- runShellCommand("device_config put textclassifier model_download_backoff_delay_in_millis 5");
-
- textClassifier = extServicesTextClassifierRule.getTextClassifier();
- startExtservicesProcess();
+ extServicesTextClassifierRule.addDeviceConfigOverride("config_updater_model_enabled", "false");
+ extServicesTextClassifierRule.addDeviceConfigOverride("model_download_manager_enabled", "true");
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "model_download_backoff_delay_in_millis", "5");
+ extServicesTextClassifierRule.addDeviceConfigOverride("testing_locale_list_override", "en-US");
+ extServicesTextClassifierRule.overrideDeviceConfig();
+
+ extServicesTextClassifierRule.enableVerboseLogging();
+ // Verbose logging only takes effect after restarting ExtServices
+ extServicesTextClassifierRule.forceStopExtServices();
}
@After
public void tearDown() throws Exception {
- runShellCommand("device_config delete textclassifier manifest_url_annotator_en");
- runShellCommand("device_config delete textclassifier manifest_url_annotator_ru");
- runShellCommand("device_config put textclassifier config_updater_model_enabled true");
- runShellCommand("device_config delete textclassifier multi_language_support_enabled");
- runShellCommand(
- "device_config put textclassifier model_download_backoff_delay_in_millis 3600000");
+ // This is to reset logging/locale_override for ExtServices.
+ extServicesTextClassifierRule.forceStopExtServices();
}
@Test
- public void smokeTest() throws IOException, InterruptedException {
- runShellCommand(
- "device_config put textclassifier manifest_url_annotator_en "
- + V804_EN_ANNOTATOR_MANIFEST_URL);
+ public void smokeTest() throws Exception {
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "manifest_url_annotator_en", V804_EN_ANNOTATOR_MANIFEST_URL);
- assertWithRetries(
- /* maxAttempts= */ 10, /* sleepMs= */ 1000, () -> verifyActiveEnglishModel(V804_EN_TAG));
+ assertWithRetries(() -> verifyActiveEnglishModel(V804_EN_TAG));
}
@Test
- public void downgradeModel() throws IOException, InterruptedException {
+ public void downgradeModel() throws Exception {
// Download an experimental model.
- {
- runShellCommand(
- "device_config put textclassifier manifest_url_annotator_en "
- + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
-
- assertWithRetries(
- /* maxAttempts= */ 10,
- /* sleepMs= */ 1000,
- () -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
- }
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
+
+ assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
// Downgrade to an older model.
- {
- runShellCommand(
- "device_config put textclassifier manifest_url_annotator_en "
- + V804_EN_ANNOTATOR_MANIFEST_URL);
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "manifest_url_annotator_en", V804_EN_ANNOTATOR_MANIFEST_URL);
- assertWithRetries(
- /* maxAttempts= */ 10, /* sleepMs= */ 1000, () -> verifyActiveEnglishModel(V804_EN_TAG));
- }
+ assertWithRetries(() -> verifyActiveEnglishModel(V804_EN_TAG));
}
@Test
- public void upgradeModel() throws IOException, InterruptedException {
+ public void upgradeModel() throws Exception {
// Download a model.
- {
- runShellCommand(
- "device_config put textclassifier manifest_url_annotator_en "
- + V804_EN_ANNOTATOR_MANIFEST_URL);
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "manifest_url_annotator_en", V804_EN_ANNOTATOR_MANIFEST_URL);
- assertWithRetries(
- /* maxAttempts= */ 10, /* sleepMs= */ 1000, () -> verifyActiveEnglishModel(V804_EN_TAG));
- }
+ assertWithRetries(() -> verifyActiveEnglishModel(V804_EN_TAG));
// Upgrade to an experimental model.
- {
- runShellCommand(
- "device_config put textclassifier manifest_url_annotator_en "
- + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
-
- assertWithRetries(
- /* maxAttempts= */ 10,
- /* sleepMs= */ 1000,
- () -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
- }
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
+
+ assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
}
@Test
- public void clearFlag() throws IOException, InterruptedException {
+ public void clearFlag() throws Exception {
// Download a new model.
- {
- runShellCommand(
- "device_config put textclassifier manifest_url_annotator_en "
- + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
-
- assertWithRetries(
- /* maxAttempts= */ 10,
- /* sleepMs= */ 1000,
- () -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
- }
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
+
+ assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
// Revert the flag.
- {
- runShellCommand("device_config delete textclassifier manifest_url_annotator_en");
- // Fallback to use the universal model.
- assertWithRetries(
- /* maxAttempts= */ 10,
- /* sleepMs= */ 1000,
- () -> verifyActiveModel(/* text= */ "abc", /* expectedVersion= */ FACTORY_MODEL_TAG));
- }
+ extServicesTextClassifierRule.addDeviceConfigOverride("manifest_url_annotator_en", "");
+ // Fallback to use the universal model.
+ assertWithRetries(
+ () -> verifyActiveModel(/* text= */ "abc", /* expectedVersion= */ FACTORY_MODEL_TAG));
}
@Test
- public void modelsForMultipleLanguagesDownloaded() throws IOException, InterruptedException {
- runShellCommand("device_config put textclassifier multi_language_support_enabled true");
- testingLocaleListOverrideRule.set("en-US", "ru-RU");
+ public void modelsForMultipleLanguagesDownloaded() throws Exception {
+ extServicesTextClassifierRule.addDeviceConfigOverride("multi_language_support_enabled", "true");
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "testing_locale_list_override", "en-US,ru-RU");
// download en model
- runShellCommand(
- "device_config put textclassifier manifest_url_annotator_en "
- + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
// download ru model
- runShellCommand(
- "device_config put textclassifier manifest_url_annotator_ru "
- + V804_RU_ANNOTATOR_MANIFEST_URL);
- assertWithRetries(
- /* maxAttempts= */ 10,
- /* sleepMs= */ 1000,
- () -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "manifest_url_annotator_ru", V804_RU_ANNOTATOR_MANIFEST_URL);
+ assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
- assertWithRetries(/* maxAttempts= */ 10, /* sleepMs= */ 1000, this::verifyActiveRussianModel);
+ assertWithRetries(this::verifyActiveRussianModel);
assertWithRetries(
- /* maxAttempts= */ 10,
- /* sleepMs= */ 1000,
() -> verifyActiveModel(/* text= */ "français", /* expectedVersion= */ FACTORY_MODEL_TAG));
}
- private void assertWithRetries(int maxAttempts, int sleepMs, Runnable assertRunnable)
- throws InterruptedException {
- for (int i = 0; i < maxAttempts; i++) {
+ private void assertWithRetries(Runnable assertRunnable) throws Exception {
+ for (int i = 0; i < ASSERT_MAX_ATTEMPTS; i++) {
try {
+ extServicesTextClassifierRule.overrideDeviceConfig();
assertRunnable.run();
break; // success. Bail out.
} catch (AssertionError ex) {
- if (i == maxAttempts - 1) { // last attempt, give up.
+ if (i == ASSERT_MAX_ATTEMPTS - 1) { // last attempt, give up.
+ extServicesTextClassifierRule.dumpDefaultTextClassifierService();
throw ex;
} else {
- Thread.sleep(sleepMs);
+ Thread.sleep(ASSERT_SLEEP_BEFORE_RETRY_MS);
}
+ } catch (Exception unknownException) {
+ throw unknownException;
}
}
}
private void verifyActiveModel(String text, String expectedVersion) {
TextClassification textClassification =
- textClassifier.classifyText(new Request.Builder(text, 0, text.length()).build());
- Log.d(TAG, "verifyActiveModel. TextClassification ID: " + textClassification.getId());
+ extServicesTextClassifierRule
+ .getTextClassifier()
+ .classifyText(new Request.Builder(text, 0, text.length()).build());
// The result id contains the name of the just used model.
+ Log.d(TAG, "verifyActiveModel. TextClassification ID: " + textClassification.getId());
assertThat(textClassification.getId()).contains(expectedVersion);
}
@@ -223,16 +179,4 @@ public class ModelDownloaderIntegrationTest {
private void verifyActiveRussianModel() {
verifyActiveModel("привет", V804_RU_TAG);
}
-
- private void startExtservicesProcess() {
- // Start the process of ExtServices by sending it a text classifier request.
- textClassifier.classifyText(new TextClassification.Request.Builder("abc", 0, 3).build());
- }
-
- private static void runShellCommand(String cmd) {
- Log.v(TAG, "run shell command: " + cmd);
- Instrumentation instrumentation = InstrumentationRegistry.getInstrumentation();
- UiAutomation uiAutomation = instrumentation.getUiAutomation();
- uiAutomation.executeShellCommand(cmd);
- }
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java b/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java
index 3ceb47b..5f8247d 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java
@@ -20,64 +20,72 @@ import android.app.UiAutomation;
import android.content.pm.PackageManager;
import android.content.pm.PackageManager.NameNotFoundException;
import android.provider.DeviceConfig;
+import android.util.Log;
import android.view.textclassifier.TextClassificationManager;
import android.view.textclassifier.TextClassifier;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.platform.app.InstrumentationRegistry;
+import com.google.common.io.ByteStreams;
+import java.io.FileInputStream;
+import java.io.IOException;
import org.junit.rules.ExternalResource;
/** A rule that manages a text classifier that is backed by the ExtServices. */
public final class ExtServicesTextClassifierRule extends ExternalResource {
+ private static final String TAG = "androidtc";
private static final String CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE =
"textclassifier_service_package_override";
private static final String PKG_NAME_GOOGLE_EXTSERVICES = "com.google.android.ext.services";
private static final String PKG_NAME_AOSP_EXTSERVICES = "android.ext.services";
- private String textClassifierServiceOverrideFlagOldValue;
+ private UiAutomation uiAutomation;
+ private DeviceConfig.Properties originalProperties;
+ private DeviceConfig.Properties.Builder newPropertiesBuilder;
@Override
- protected void before() {
- UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
- try {
- uiAutomation.adoptShellPermissionIdentity();
- textClassifierServiceOverrideFlagOldValue =
- DeviceConfig.getString(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
- null);
- DeviceConfig.setProperty(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
- getExtServicesPackageName(),
- /* makeDefault= */ false);
- } finally {
- uiAutomation.dropShellPermissionIdentity();
- }
+ protected void before() throws Exception {
+ uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
+ uiAutomation.adoptShellPermissionIdentity();
+ originalProperties = DeviceConfig.getProperties(DeviceConfig.NAMESPACE_TEXTCLASSIFIER);
+ newPropertiesBuilder =
+ new DeviceConfig.Properties.Builder(DeviceConfig.NAMESPACE_TEXTCLASSIFIER)
+ .setString(
+ CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE, getExtServicesPackageName());
+ overrideDeviceConfig();
}
@Override
protected void after() {
- UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
try {
- uiAutomation.adoptShellPermissionIdentity();
- DeviceConfig.setProperty(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
- textClassifierServiceOverrideFlagOldValue,
- /* makeDefault= */ false);
+ DeviceConfig.setProperties(originalProperties);
+ } catch (Throwable t) {
+ Log.e(TAG, "Failed to reset DeviceConfig", t);
} finally {
uiAutomation.dropShellPermissionIdentity();
}
}
- private static String getExtServicesPackageName() {
- PackageManager packageManager = ApplicationProvider.getApplicationContext().getPackageManager();
- try {
- packageManager.getApplicationInfo(PKG_NAME_GOOGLE_EXTSERVICES, /* flags= */ 0);
- return PKG_NAME_GOOGLE_EXTSERVICES;
- } catch (NameNotFoundException e) {
- return PKG_NAME_AOSP_EXTSERVICES;
- }
+ public void addDeviceConfigOverride(String name, String value) {
+ newPropertiesBuilder.setString(name, value);
+ }
+
+ /**
+ * Overrides the TextClassifier DeviceConfig manually.
+ *
+ * <p>This will clean up all device configs not in newPropertiesBuilder.
+ *
+ * <p>We will need to call this everytime before testing, because DeviceConfig can be synced in
+ * background at anytime. DeviceConfig#setSyncDisabledMode is to disable sync, however it's a
+ * hidden API.
+ */
+ public void overrideDeviceConfig() throws Exception {
+ DeviceConfig.setProperties(newPropertiesBuilder.build());
+ }
+
+ /** Force stop ExtServices. Force-stop-and-start can be helpful to reload some states. */
+ public void forceStopExtServices() {
+ runShellCommand("am force-stop com.google.android.ext.services");
+ runShellCommand("am force-stop android.ext.services");
}
public TextClassifier getTextClassifier() {
@@ -87,4 +95,38 @@ public final class ExtServicesTextClassifierRule extends ExternalResource {
textClassificationManager.setTextClassifier(null); // Reset TC overrides
return textClassificationManager.getTextClassifier();
}
+
+ public void dumpDefaultTextClassifierService() {
+ runShellCommand(
+ "dumpsys activity service com.google.android.ext.services/"
+ + "com.android.textclassifier.DefaultTextClassifierService");
+ runShellCommand("cmd device_config list textclassifier");
+ }
+
+ public void enableVerboseLogging() {
+ runShellCommand("setprop log.tag.androidtc VERBOSE");
+ }
+
+ private void runShellCommand(String cmd) {
+ Log.v(TAG, "run shell command: " + cmd);
+ try (FileInputStream output =
+ new FileInputStream(uiAutomation.executeShellCommand(cmd).getFileDescriptor())) {
+ String cmdOutput = new String(ByteStreams.toByteArray(output));
+ if (!cmdOutput.isEmpty()) {
+ Log.d(TAG, "cmd output: " + cmdOutput);
+ }
+ } catch (IOException ioe) {
+ Log.w(TAG, "failed to get cmd output", ioe);
+ }
+ }
+
+ private static String getExtServicesPackageName() {
+ PackageManager packageManager = ApplicationProvider.getApplicationContext().getPackageManager();
+ try {
+ packageManager.getApplicationInfo(PKG_NAME_GOOGLE_EXTSERVICES, /* flags= */ 0);
+ return PKG_NAME_GOOGLE_EXTSERVICES;
+ } catch (NameNotFoundException e) {
+ return PKG_NAME_AOSP_EXTSERVICES;
+ }
+ }
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/testing/TestingLocaleListOverrideRule.java b/java/tests/instrumentation/src/com/android/textclassifier/testing/TestingLocaleListOverrideRule.java
deleted file mode 100644
index 7d46e97..0000000
--- a/java/tests/instrumentation/src/com/android/textclassifier/testing/TestingLocaleListOverrideRule.java
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier.testing;
-
-import android.app.UiAutomation;
-import android.os.LocaleList;
-import android.util.Log;
-import androidx.test.platform.app.InstrumentationRegistry;
-import org.junit.rules.ExternalResource;
-
-/** class for overriding testing_locale_list_override from {@link TextClassifierSettings} */
-public final class TestingLocaleListOverrideRule extends ExternalResource {
- private static final String TAG = "TestingLocaleListOverrideRule";
-
- private LocaleList originalLocaleList;
-
- @Override
- protected void before() {
- originalLocaleList = LocaleList.getDefault();
- }
-
- public void set(String... localeTags) {
- if (localeTags.length == 0) {
- return;
- }
- runShellCommand(
- "device_config put textclassifier testing_locale_list_override "
- + String.join(",", localeTags));
- }
-
- @Override
- protected void after() {
- runShellCommand(
- "device_config put textclassifier testing_locale_list_override "
- + originalLocaleList.toLanguageTags());
- runShellCommand("device_config delete textclassifier testing_locale_list_override");
- }
-
- private static void runShellCommand(String cmd) {
- Log.v(TAG, "run shell command: " + cmd);
- UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
- uiAutomation.executeShellCommand(cmd);
- }
-}
diff --git a/native/utils/tokenfree/byte_encoder_test.cc b/native/utils/tokenfree/byte_encoder_test.cc
index d4d119e..964e316 100644
--- a/native/utils/tokenfree/byte_encoder_test.cc
+++ b/native/utils/tokenfree/byte_encoder_test.cc
@@ -29,7 +29,7 @@ namespace {
using testing::ElementsAre;
-TEST(EncoderTest, SimpleTokenization) {
+TEST(ByteEncoderTest, SimpleTokenization) {
const ByteEncoder encoder;
{
std::vector<int64_t> encoded_text;
@@ -39,7 +39,7 @@ TEST(EncoderTest, SimpleTokenization) {
}
}
-TEST(EncoderTest, SimpleTokenization2) {
+TEST(ByteEncoderTest, SimpleTokenization2) {
const ByteEncoder encoder;
{
std::vector<int64_t> encoded_text;