From 8ebbedca8443b38941a7ddadc8245fcc83c6f866 Mon Sep 17 00:00:00 2001 From: Chang Li Date: Mon, 14 Mar 2022 10:25:28 +0000 Subject: Export java souces to fix several broken tests. BUG: 222229739 Change-Id: Ic1e9f7c1e734f135271effcd7d49fd8c4e0bd268 --- .../com/android/textclassifier/ExtrasUtils.java | 4 +- .../textclassifier/ModelFileManagerImpl.java | 13 +- .../downloader/DownloadedModelManagerImpl.java | 4 +- .../downloader/ModelDownloadManager.java | 8 +- .../downloader/ModelDownloadWorker.java | 8 +- .../downloader/ModelDownloaderImpl.java | 10 +- .../downloader/ModelDownloaderService.java | 2 +- .../downloader/ModelDownloaderServiceImpl.java | 2 +- .../DefaultTextClassifierServiceTest.java | 36 ++-- .../com/android/textclassifier/TestDataUtils.java | 29 ++-- .../textclassifier/TextClassifierApiTest.java | 10 +- .../textclassifier/TextClassifierImplTest.java | 171 ++++++++----------- .../downloader/ModelDownloadManagerTest.java | 6 + .../downloader/ModelDownloaderIntegrationTest.java | 185 ++++++++------------- .../testing/ExtServicesTextClassifierRule.java | 108 ++++++++---- .../testing/TestingLocaleListOverrideRule.java | 58 ------- 16 files changed, 292 insertions(+), 362 deletions(-) delete mode 100644 java/tests/instrumentation/src/com/android/textclassifier/testing/TestingLocaleListOverrideRule.java diff --git a/java/src/com/android/textclassifier/ExtrasUtils.java b/java/src/com/android/textclassifier/ExtrasUtils.java index fd64581..bde3898 100644 --- a/java/src/com/android/textclassifier/ExtrasUtils.java +++ b/java/src/com/android/textclassifier/ExtrasUtils.java @@ -87,7 +87,9 @@ public final class ExtrasUtils { return classification.getExtras().getBundle(FOREIGN_LANGUAGE); } - /** @see #getTopLanguage(Intent) */ + /** + * @see #getTopLanguage(Intent) + */ static void putTopLanguageScores(Bundle container, EntityConfidence languageScores) { final int maxSize = Math.min(3, languageScores.getEntities().size()); final String[] languages = diff --git a/java/src/com/android/textclassifier/ModelFileManagerImpl.java b/java/src/com/android/textclassifier/ModelFileManagerImpl.java index 45426d0..e3b646f 100644 --- a/java/src/com/android/textclassifier/ModelFileManagerImpl.java +++ b/java/src/com/android/textclassifier/ModelFileManagerImpl.java @@ -390,7 +390,18 @@ final class ModelFileManagerImpl implements ModelFileManager { localePreferences.get(0), targetLocale)); } - return findBestModelFile(modelType, targetLocale); + ModelFile modelFile = findBestModelFile(modelType, targetLocale); + TcLog.d( + TAG, + String.format( + Locale.US, + "findBestModelFile: best model: %s; localePreferences: %s; detectedLocales: %s;" + + " targetLocale: %s", + modelFile, + localePreferences, + detectedLocales, + targetLocale)); + return modelFile; } /** diff --git a/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java b/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java index 1ae79ce..9bdfb5e 100644 --- a/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java +++ b/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java @@ -195,7 +195,7 @@ public final class DownloadedModelManagerImpl implements DownloadedModelManager @Override public void onDownloadCompleted( ImmutableMap manifestsToDownload) { - TcLog.v(TAG, "Start to clean up models and update model lookup cache..."); + TcLog.d(TAG, "Start to clean up models and update model lookup cache..."); // Step 1: Clean up ManifestEnrollment table List allManifestEnrollments = db.dao().queryAllManifestEnrollments(); List manifestEnrollmentsToDelete = new ArrayList<>(); @@ -286,7 +286,7 @@ public final class DownloadedModelManagerImpl implements DownloadedModelManager // Clear the cache table and rebuild the cache based on ModelView table private void updateCache() { synchronized (cacheLock) { - TcLog.v(TAG, "Updating model lookup cache..."); + TcLog.d(TAG, "Updating model lookup cache..."); for (String modelType : ModelType.values()) { modelLookupCache.get(modelType).clear(); } diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java b/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java index 3614d35..af33e21 100644 --- a/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java +++ b/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java @@ -147,7 +147,7 @@ public final class ModelDownloadManager { return; } maybeOverrideLocaleListForTesting(); - TcLog.v(TAG, "Try to schedule model download work because TextClassifierService started."); + TcLog.d(TAG, "Try to schedule model download work because TextClassifierService started."); scheduleDownloadWork(REASON_TO_SCHEDULE_TCS_STARTED); } catch (Throwable t) { TcLog.e(TAG, "Failed inside onTextClassifierServiceCreated", t); @@ -161,7 +161,7 @@ public final class ModelDownloadManager { if (!settings.isModelDownloadManagerEnabled()) { return; } - TcLog.v(TAG, "Try to schedule model download work because of system locale changes."); + TcLog.d(TAG, "Try to schedule model download work because of system locale changes."); try { scheduleDownloadWork(REASON_TO_SCHEDULE_LOCALE_SETTINGS_CHANGED); } catch (Throwable t) { @@ -176,7 +176,7 @@ public final class ModelDownloadManager { if (!settings.isModelDownloadManagerEnabled()) { return; } - TcLog.v(TAG, "Try to schedule model download work because of device config changes."); + TcLog.d(TAG, "Try to schedule model download work because of device config changes."); try { maybeOverrideLocaleListForTesting(); scheduleDownloadWork(REASON_TO_SCHEDULE_DEVICE_CONFIG_UPDATED); @@ -261,7 +261,7 @@ public final class ModelDownloadManager { new FutureCallback() { @Override public void onSuccess(Operation.State.SUCCESS unused) { - TcLog.v(TAG, "Download work scheduled."); + TcLog.d(TAG, "Download work scheduled."); TextClassifierDownloadLogger.downloadWorkScheduled( workId, reasonToSchedule, /* failedToSchedule= */ false); } diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java b/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java index 6e04e16..3db0815 100644 --- a/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java +++ b/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java @@ -113,6 +113,7 @@ public final class ModelDownloadWorker extends ListenableWorker { @Override public final ListenableFuture startWork() { + TcLog.d(TAG, "Start download work..."); workStartedTimeMillis = getCurrentTimeMillis(); // Notice: startWork() is invoked on the main thread if (!settings.isModelDownloadManagerEnabled()) { @@ -121,7 +122,6 @@ public final class ModelDownloadWorker extends ListenableWorker { TextClassifierDownloadLogger.WORK_RESULT_FAILURE_MODEL_DOWNLOADER_DISABLED); return Futures.immediateFuture(ListenableWorker.Result.failure()); } - TcLog.v(TAG, "Start download work..."); if (getRunAttemptCount() >= settings.getModelDownloadWorkerMaxAttempts()) { TcLog.d(TAG, "Max attempt reached. Abort download work."); logDownloadWorkCompleted( @@ -134,7 +134,7 @@ public final class ModelDownloadWorker extends ListenableWorker { downloadResult -> { Preconditions.checkNotNull(manifestsToDownload); downloadedModelManager.onDownloadCompleted(manifestsToDownload); - TcLog.v(TAG, "Download work completed: " + downloadResult); + TcLog.d(TAG, "Download work completed: " + downloadResult); if (downloadResult.failureCount() == 0) { logDownloadWorkCompleted( downloadResult.successCount() > 0 @@ -239,7 +239,7 @@ public final class ModelDownloadWorker extends ListenableWorker { return Futures.whenAllComplete(downloadResultFutures) .call( () -> { - TcLog.v(TAG, "All Download Tasks Completed"); + TcLog.d(TAG, "All Download Tasks Completed"); int successCount = 0; int failureCount = 0; for (ListenableFuture downloadResultFuture : downloadResultFutures) { @@ -333,7 +333,7 @@ public final class ModelDownloadWorker extends ListenableWorker { Manifest downloadedManifest = downloadedModelManager.getManifest(manifestUrl); if (downloadedManifest != null && downloadedManifest.getStatus() == Manifest.STATUS_SUCCEEDED) { - TcLog.v(TAG, "Manifest already downloaded: " + manifestUrl); + TcLog.d(TAG, "Manifest already downloaded: " + manifestUrl); return Futures.immediateVoidFuture(); } if (pendingDownloads.containsKey(manifestUrl)) { diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java b/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java index 2244e9a..0b76f22 100644 --- a/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java +++ b/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java @@ -99,7 +99,7 @@ final class ModelDownloaderImpl implements ModelDownloader { new FutureCallback() { @Override public void onSuccess(File pendingModelFile) { - TcLog.v(TAG, "Download model successfully: " + pendingModelFile.getAbsolutePath()); + TcLog.d(TAG, "Download model successfully: " + pendingModelFile.getAbsolutePath()); } @Override @@ -170,11 +170,11 @@ final class ModelDownloaderImpl implements ModelDownloader { } catch (IOException e) { throw new ModelDownloadException(ModelDownloadException.FAILED_TO_VALIDATE_MODEL, e); } - TcLog.v(TAG, "Pending model file passed validation."); + TcLog.d(TAG, "Pending model file passed validation."); } private ListenableFuture connect(DownloaderServiceConnection conn) { - TcLog.v(TAG, "Starting a new connection to ModelDownloaderService"); + TcLog.d(TAG, "Starting a new connection to ModelDownloaderService"); return CallbackToFutureAdapter.getFuture( completer -> { conn.attachCompleter(completer); @@ -197,7 +197,7 @@ final class ModelDownloaderImpl implements ModelDownloader { // restult future will hang there until time out. (WorkManager forces a 10-min running time.) private static ListenableFuture scheduleDownload( IModelDownloaderService service, URI uri, File targetFile) { - TcLog.v(TAG, "Scheduling a new download task with ModelDownloaderService"); + TcLog.d(TAG, "Scheduling a new download task with ModelDownloaderService"); return CallbackToFutureAdapter.getFuture( completer -> { service.download( @@ -236,7 +236,7 @@ final class ModelDownloaderImpl implements ModelDownloader { @Override public void onServiceConnected(ComponentName componentName, IBinder iBinder) { - TcLog.v(TAG, "DownloaderService connected"); + TcLog.d(TAG, "DownloaderService connected"); completer.set(Preconditions.checkNotNull(IModelDownloaderService.Stub.asInterface(iBinder))); } diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java b/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java index e4ebbfa..6d7e47e 100644 --- a/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java +++ b/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java @@ -39,7 +39,7 @@ public final class ModelDownloaderService extends Service { @Override public IBinder onBind(Intent intent) { - TcLog.v(TAG, "Binding to ModelDownloadService"); + TcLog.d(TAG, "Binding to ModelDownloadService"); return iBinder; } } diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java b/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java index 439588b..47e6f19 100644 --- a/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java +++ b/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java @@ -91,7 +91,7 @@ final class ModelDownloaderServiceImpl extends IModelDownloaderService.Stub { @Override public void download(String uri, String targetFilePath, IModelDownloaderCallback callback) { - TcLog.v(TAG, "Download request received: " + uri); + TcLog.d(TAG, "Download request received: " + uri); try { File targetFile = new File(targetFilePath); File tempMetadataFile = getMetadataFile(targetFile); diff --git a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java index 71f9a4f..ddab8bd 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java @@ -17,7 +17,10 @@ package com.android.textclassifier; import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.eq; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import android.content.Context; import android.os.CancellationSignal; @@ -39,6 +42,7 @@ import com.android.os.AtomsProto.Atom; import com.android.os.AtomsProto.TextClassifierApiUsageReported; import com.android.os.AtomsProto.TextClassifierApiUsageReported.ApiType; import com.android.os.AtomsProto.TextClassifierApiUsageReported.ResultType; +import com.android.textclassifier.common.ModelType; import com.android.textclassifier.common.TextClassifierSettings; import com.android.textclassifier.common.statsd.StatsdTestUtils; import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger; @@ -47,6 +51,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; +import java.io.IOException; import java.util.List; import java.util.concurrent.Executor; import java.util.stream.Collectors; @@ -81,13 +86,21 @@ public class DefaultTextClassifierServiceTest { @Mock private TextClassifierService.Callback textLinksCallback; @Mock private TextClassifierService.Callback conversationActionsCallback; @Mock private TextClassifierService.Callback textLanguageCallback; + @Mock private ModelFileManager testModelFileManager; @Before - public void setup() { - - testInjector = new TestInjector(ApplicationProvider.getApplicationContext()); + public void setup() throws IOException { + testInjector = + new TestInjector(ApplicationProvider.getApplicationContext(), testModelFileManager); defaultTextClassifierService = new DefaultTextClassifierService(testInjector); defaultTextClassifierService.onCreate(); + + when(testModelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(TestDataUtils.getTestAnnotatorModelFileWrapped()); + when(testModelFileManager.findBestModelFile(eq(ModelType.LANG_ID), any(), any())) + .thenReturn(TestDataUtils.getLangIdModelFileWrapped()); + when(testModelFileManager.findBestModelFile(eq(ModelType.ACTIONS_SUGGESTIONS), any(), any())) + .thenReturn(TestDataUtils.getTestActionsModelFileWrapped()); } @Before @@ -211,11 +224,8 @@ public class DefaultTextClassifierServiceTest { @Test public void missingModelFile_onFailureShouldBeCalled() throws Exception { - testInjector.setModelFileManager( - new ModelFileManagerImpl( - ApplicationProvider.getApplicationContext(), - ImmutableList.of(), - testInjector.createTextClassifierSettings())); + when(testModelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(null); defaultTextClassifierService.onCreate(); TextClassification.Request request = new TextClassification.Request.Builder("hi", 0, 2).build(); @@ -251,12 +261,9 @@ public class DefaultTextClassifierServiceTest { private final Context context; private ModelFileManager modelFileManager; - private TestInjector(Context context) { + private TestInjector(Context context, ModelFileManager modelFileManager) { this.context = Preconditions.checkNotNull(context); - } - - private void setModelFileManager(ModelFileManager modelFileManager) { - this.modelFileManager = modelFileManager; + this.modelFileManager = Preconditions.checkNotNull(modelFileManager); } @Override @@ -267,9 +274,6 @@ public class DefaultTextClassifierServiceTest { @Override public ModelFileManager createModelFileManager( TextClassifierSettings settings, ModelDownloadManager modelDownloadManager) { - if (modelFileManager == null) { - return TestDataUtils.createModelFileManagerForTesting(context); - } return modelFileManager; } diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java index bac4fa1..a19e3ff 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java @@ -16,12 +16,10 @@ package com.android.textclassifier; -import android.content.Context; -import com.android.textclassifier.ModelFileManagerImpl.RegularFileFullMatchLister; +import com.android.textclassifier.common.ModelFile; import com.android.textclassifier.common.ModelType; -import com.android.textclassifier.common.TextClassifierSettings; -import com.google.common.collect.ImmutableList; import java.io.File; +import java.io.IOException; /** Utils to access test data files. */ public final class TestDataUtils { @@ -30,7 +28,7 @@ public final class TestDataUtils { private static final String TEST_LANGID_MODEL_PATH = "testdata/langid.model"; /** Returns the root folder that contains the test data. */ - public static File getTestDataFolder() { + private static File getTestDataFolder() { return new File("/data/local/tmp/TextClassifierServiceTest/"); } @@ -38,24 +36,25 @@ public final class TestDataUtils { return new File(getTestDataFolder(), TEST_ANNOTATOR_MODEL_PATH); } + public static ModelFile getTestAnnotatorModelFileWrapped() throws IOException { + return ModelFile.createFromRegularFile(getTestAnnotatorModelFile(), ModelType.ANNOTATOR); + } + public static File getTestActionsModelFile() { return new File(getTestDataFolder(), TEST_ACTIONS_MODEL_PATH); } + public static ModelFile getTestActionsModelFileWrapped() throws IOException { + return ModelFile.createFromRegularFile( + getTestActionsModelFile(), ModelType.ACTIONS_SUGGESTIONS); + } + public static File getLangIdModelFile() { return new File(getTestDataFolder(), TEST_LANGID_MODEL_PATH); } - public static ModelFileManager createModelFileManagerForTesting(Context context) { - return new ModelFileManagerImpl( - context, - ImmutableList.of( - new RegularFileFullMatchLister( - ModelType.ANNOTATOR, getTestAnnotatorModelFile(), () -> true), - new RegularFileFullMatchLister( - ModelType.ACTIONS_SUGGESTIONS, getTestActionsModelFile(), () -> true), - new RegularFileFullMatchLister(ModelType.LANG_ID, getLangIdModelFile(), () -> true)), - new TextClassifierSettings()); + public static ModelFile getLangIdModelFileWrapped() throws IOException { + return ModelFile.createFromRegularFile(getLangIdModelFile(), ModelType.LANG_ID); } private TestDataUtils() {} diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java index 42177e6..e7bf90c 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java @@ -56,6 +56,10 @@ public class TextClassifierApiTest { @Before public void setup() { + extServicesTextClassifierRule.enableVerboseLogging(); + // Verbose logging only takes effect after restarting ExtServices + extServicesTextClassifierRule.forceStopExtServices(); + textClassifier = extServicesTextClassifierRule.getTextClassifier(); } @@ -81,8 +85,8 @@ public class TextClassifierApiTest { @Test public void classifyText() { - String text = "Contact me at droid@android.com"; - String classifiedText = "droid@android.com"; + String text = "Contact me at http://www.android.com"; + String classifiedText = "http://www.android.com"; int startIndex = text.indexOf(classifiedText); int endIndex = startIndex + classifiedText.length(); TextClassification.Request request = @@ -90,7 +94,7 @@ public class TextClassifierApiTest { TextClassification classification = textClassifier.classifyText(request); assertThat(classification.getEntityCount()).isGreaterThan(0); - assertThat(classification.getEntity(0)).isEqualTo(TextClassifier.TYPE_EMAIL); + assertThat(classification.getEntity(0)).isEqualTo(TextClassifier.TYPE_URL); assertThat(classification.getText()).isEqualTo(classifiedText); assertThat(classification.getActions()).isNotEmpty(); } diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java index fb1aea8..c20ec8a 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java @@ -22,6 +22,9 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.testng.Assert.expectThrows; @@ -74,30 +77,34 @@ public class TextClassifierImplTest { private static final LocaleList LOCALES = LocaleList.forLanguageTags("en-US"); private static final String NO_TYPE = null; - @Mock private ModelFileManagerImpl.ModelFileLister mockModelFileLister; + @Mock private ModelFileManager modelFileManager; - private TextClassifierSettings settings; private Context context; private TestingDeviceConfig deviceConfig; - private TextClassifierImpl classifier; - - private final ModelFileManager modelFileManager = - TestDataUtils.createModelFileManagerForTesting(ApplicationProvider.getApplicationContext()); + private TextClassifierSettings settings; private LruCache annotatorModelCache; + private TextClassifierImpl classifier; @Before - public void setup() { + public void setup() throws IOException { MockitoAnnotations.initMocks(this); - deviceConfig = new TestingDeviceConfig(); - Context context = + this.context = new FakeContextBuilder() .setAllIntentComponent(FakeContextBuilder.DEFAULT_COMPONENT) .setAppLabel(FakeContextBuilder.DEFAULT_COMPONENT.getPackageName(), "Test app") .build(); - this.context = context; - settings = new TextClassifierSettings(deviceConfig); - // TODO(veronikanikina): consider using a testing constructor here. - classifier = new TextClassifierImpl(context, settings, modelFileManager); + this.deviceConfig = new TestingDeviceConfig(); + this.settings = new TextClassifierSettings(deviceConfig); + this.annotatorModelCache = new LruCache<>(2); + this.classifier = + new TextClassifierImpl(context, settings, modelFileManager, annotatorModelCache); + + when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(TestDataUtils.getTestAnnotatorModelFileWrapped()); + when(modelFileManager.findBestModelFile(eq(ModelType.LANG_ID), any(), any())) + .thenReturn(TestDataUtils.getLangIdModelFileWrapped()); + when(modelFileManager.findBestModelFile(eq(ModelType.ACTIONS_SUGGESTIONS), any(), any())) + .thenReturn(TestDataUtils.getTestActionsModelFileWrapped()); } @Test @@ -110,15 +117,31 @@ public class TextClassifierImplTest { int smartStartIndex = text.indexOf(suggested); int smartEndIndex = smartStartIndex + suggested.length(); TextSelection.Request request = - new TextSelection.Request.Builder(text, startIndex, endIndex) - .setDefaultLocales(LOCALES) - .build(); + new TextSelection.Request.Builder(text, startIndex, endIndex).build(); TextSelection selection = classifier.suggestSelection(null, null, request); assertThat( selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_EMAIL)); } + @Test + public void testSuggestSelection_localePreferenceIsPassedToModelFileManager() throws IOException { + String text = "Contact me at droid@android.com"; + String selected = "droid"; + String suggested = "droid@android.com"; + int startIndex = text.indexOf(selected); + int endIndex = startIndex + selected.length(); + int smartStartIndex = text.indexOf(suggested); + int smartEndIndex = smartStartIndex + suggested.length(); + TextSelection.Request request = + new TextSelection.Request.Builder(text, startIndex, endIndex) + .setDefaultLocales(LOCALES) + .build(); + + classifier.suggestSelection(null, null, request); + verify(modelFileManager).findBestModelFile(eq(ModelType.ANNOTATOR), eq(LOCALES), any()); + } + @Test public void testSuggestSelection_url() throws IOException { String text = "Visit http://www.android.com for more information"; @@ -129,9 +152,7 @@ public class TextClassifierImplTest { int smartStartIndex = text.indexOf(suggested); int smartEndIndex = smartStartIndex + suggested.length(); TextSelection.Request request = - new TextSelection.Request.Builder(text, startIndex, endIndex) - .setDefaultLocales(LOCALES) - .build(); + new TextSelection.Request.Builder(text, startIndex, endIndex).build(); TextSelection selection = classifier.suggestSelection(null, null, request); assertThat(selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_URL)); @@ -144,9 +165,7 @@ public class TextClassifierImplTest { int startIndex = text.indexOf(selected); int endIndex = startIndex + selected.length(); TextSelection.Request request = - new TextSelection.Request.Builder(text, startIndex, endIndex) - .setDefaultLocales(LOCALES) - .build(); + new TextSelection.Request.Builder(text, startIndex, endIndex).build(); TextSelection selection = classifier.suggestSelection(null, null, request); assertThat(selection, isTextSelection(startIndex, endIndex, NO_TYPE)); @@ -160,7 +179,6 @@ public class TextClassifierImplTest { int startIndex = text.indexOf(suggested); TextSelection.Request request = new TextSelection.Request.Builder(text, startIndex, /*endIndex=*/ startIndex + 1) - .setDefaultLocales(LOCALES) .setIncludeTextClassification(true) .build(); @@ -178,7 +196,6 @@ public class TextClassifierImplTest { String text = "Visit http://www.android.com for more information"; TextSelection.Request request = new TextSelection.Request.Builder(text, /*startIndex=*/ 0, /*endIndex=*/ 4) - .setDefaultLocales(LOCALES) .setIncludeTextClassification(false) .build(); @@ -194,9 +211,7 @@ public class TextClassifierImplTest { int startIndex = text.indexOf(classifiedText); int endIndex = startIndex + classifiedText.length(); TextClassification.Request request = - new TextClassification.Request.Builder(text, startIndex, endIndex) - .setDefaultLocales(LOCALES) - .build(); + new TextClassification.Request.Builder(text, startIndex, endIndex).build(); TextClassification classification = classifier.classifyText(/* sessionId= */ null, null, request); @@ -210,9 +225,7 @@ public class TextClassifierImplTest { int startIndex = text.indexOf(classifiedText); int endIndex = startIndex + classifiedText.length(); TextClassification.Request request = - new TextClassification.Request.Builder(text, startIndex, endIndex) - .setDefaultLocales(LOCALES) - .build(); + new TextClassification.Request.Builder(text, startIndex, endIndex).build(); TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL)); @@ -223,9 +236,7 @@ public class TextClassifierImplTest { public void testClassifyText_address() throws IOException { String text = "Brandschenkestrasse 110, Zürich, Switzerland"; TextClassification.Request request = - new TextClassification.Request.Builder(text, 0, text.length()) - .setDefaultLocales(LOCALES) - .build(); + new TextClassification.Request.Builder(text, 0, text.length()).build(); TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(text, TextClassifier.TYPE_ADDRESS)); @@ -238,9 +249,7 @@ public class TextClassifierImplTest { int startIndex = text.indexOf(classifiedText); int endIndex = startIndex + classifiedText.length(); TextClassification.Request request = - new TextClassification.Request.Builder(text, startIndex, endIndex) - .setDefaultLocales(LOCALES) - .build(); + new TextClassification.Request.Builder(text, startIndex, endIndex).build(); TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL)); @@ -254,9 +263,7 @@ public class TextClassifierImplTest { int startIndex = text.indexOf(classifiedText); int endIndex = startIndex + classifiedText.length(); TextClassification.Request request = - new TextClassification.Request.Builder(text, startIndex, endIndex) - .setDefaultLocales(LOCALES) - .build(); + new TextClassification.Request.Builder(text, startIndex, endIndex).build(); TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE)); @@ -275,9 +282,7 @@ public class TextClassifierImplTest { int startIndex = text.indexOf(classifiedText); int endIndex = startIndex + classifiedText.length(); TextClassification.Request request = - new TextClassification.Request.Builder(text, startIndex, endIndex) - .setDefaultLocales(LOCALES) - .build(); + new TextClassification.Request.Builder(text, startIndex, endIndex).build(); TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE_TIME)); @@ -289,14 +294,12 @@ public class TextClassifierImplTest { LocaleList.setDefault(LocaleList.forLanguageTags("en")); String japaneseText = "これは日本語のテキストです"; TextClassification.Request request = - new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length()) - .setDefaultLocales(LOCALES) - .build(); + new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length()).build(); TextClassification classification = classifier.classifyText(null, null, request); RemoteAction translateAction = classification.getActions().get(0); assertEquals(1, classification.getActions().size()); - assertEquals("Translate", translateAction.getTitle().toString()); + assertEquals(Intent.ACTION_TRANSLATE, classification.getIntent().getAction()); assertEquals(translateAction, ExtrasUtils.findTranslateAction(classification)); Intent intent = ExtrasUtils.getActionsIntents(classification).get(0); @@ -323,18 +326,17 @@ public class TextClassifierImplTest { @Test public void testGenerateLinks_exclude() throws IOException { - String text = "You want apple@banana.com. See you tonight!"; + String text = "The number is +12122537077. See you tonight!"; List hints = ImmutableList.of(); List included = ImmutableList.of(); - List excluded = Arrays.asList(TextClassifier.TYPE_EMAIL); + List excluded = Arrays.asList(TextClassifier.TYPE_PHONE); TextLinks.Request request = new TextLinks.Request.Builder(text) .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded)) - .setDefaultLocales(LOCALES) .build(); assertThat( classifier.generateLinks(null, null, request), - not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL))); + not(isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE))); } @Test @@ -344,7 +346,6 @@ public class TextClassifierImplTest { TextLinks.Request request = new TextLinks.Request.Builder(text) .setEntityConfig(TextClassifier.EntityConfig.createWithExplicitEntityList(explicit)) - .setDefaultLocales(LOCALES) .build(); assertThat( classifier.generateLinks(null, null, request), @@ -361,7 +362,6 @@ public class TextClassifierImplTest { TextLinks.Request request = new TextLinks.Request.Builder(text) .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded)) - .setDefaultLocales(LOCALES) .build(); assertThat( classifier.generateLinks(null, null, request), @@ -573,29 +573,16 @@ public class TextClassifierImplTest { new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false); ModelFile annotatorModelB = new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false); - String langIdFilePath = TestDataUtils.getLangIdModelFile().getPath(); - ModelFile langIdModel = new ModelFile(ModelType.LANG_ID, langIdFilePath, 1, "*", false); - - annotatorModelCache = new LruCache<>(2); - ModelFileManager modelFileManagerCached = - new ModelFileManagerImpl(context, ImmutableList.of(mockModelFileLister), settings); - TextClassifierImpl textClassifierImpl = - new TextClassifierImpl(context, settings, modelFileManagerCached, annotatorModelCache); - LocaleList.setDefault(LocaleList.forLanguageTags("en")); String englishText = "You can reach me on +12122537077."; String classifiedText = "+12122537077"; TextClassification.Request request = - new TextClassification.Request.Builder(englishText, 0, englishText.length()) - .setDefaultLocales(LOCALES) - .build(); - - when(mockModelFileLister.list(ModelType.LANG_ID)).thenReturn(ImmutableList.of(langIdModel)); + new TextClassification.Request.Builder(englishText, 0, englishText.length()).build(); // Check modelFileA v701 - when(mockModelFileLister.list(ModelType.ANNOTATOR)) - .thenReturn(ImmutableList.of(annotatorModelA)); - TextClassification classificationA = textClassifierImpl.classifyText(null, null, request); + when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(annotatorModelA); + TextClassification classificationA = classifier.classifyText(null, null, request); assertThat(classificationA.getId()).contains("v701"); assertThat(classificationA.getText()).contains(classifiedText); @@ -609,9 +596,9 @@ public class TextClassifierImplTest { }); // Check modelFileB v801 - when(mockModelFileLister.list(ModelType.ANNOTATOR)) - .thenReturn(ImmutableList.of(annotatorModelB)); - TextClassification classificationB = textClassifierImpl.classifyText(null, null, request); + when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(annotatorModelB); + TextClassification classificationB = classifier.classifyText(null, null, request); assertThat(classificationB.getId()).contains("v801"); assertThat(classificationB.getText()).contains(classifiedText); @@ -625,9 +612,9 @@ public class TextClassifierImplTest { }); // Reload modelFileA v701 - when(mockModelFileLister.list(ModelType.ANNOTATOR)) - .thenReturn(ImmutableList.of(annotatorModelA)); - TextClassification classificationAcached = textClassifierImpl.classifyText(null, null, request); + when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(annotatorModelA); + TextClassification classificationAcached = classifier.classifyText(null, null, request); assertThat(classificationAcached.getId()).contains("v701"); assertThat(classificationAcached.getText()).contains(classifiedText); @@ -651,28 +638,16 @@ public class TextClassifierImplTest { new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false); ModelFile annotatorModelB = new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false); - String langIdFilePath = TestDataUtils.getLangIdModelFile().getPath(); - ModelFile langIdModel = new ModelFile(ModelType.LANG_ID, langIdFilePath, 1, "*", false); - - annotatorModelCache = new LruCache<>(settings.getMultiAnnotatorCacheSize()); - ModelFileManager modelFileManagerCached = - new ModelFileManagerImpl(context, ImmutableList.of(mockModelFileLister), settings); - TextClassifierImpl textClassifierImpl = - new TextClassifierImpl(context, settings, modelFileManagerCached, annotatorModelCache); - LocaleList.setDefault(LocaleList.forLanguageTags("en")); + String englishText = "You can reach me on +12122537077."; String classifiedText = "+12122537077"; TextClassification.Request request = - new TextClassification.Request.Builder(englishText, 0, englishText.length()) - .setDefaultLocales(LOCALES) - .build(); - - when(mockModelFileLister.list(ModelType.LANG_ID)).thenReturn(ImmutableList.of(langIdModel)); + new TextClassification.Request.Builder(englishText, 0, englishText.length()).build(); // Check modelFileA v701 - when(mockModelFileLister.list(ModelType.ANNOTATOR)) - .thenReturn(ImmutableList.of(annotatorModelA)); - TextClassification classification = textClassifierImpl.classifyText(null, null, request); + when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(annotatorModelA); + TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification.getId()).contains("v701"); assertThat(classification.getText()).contains(classifiedText); @@ -686,9 +661,9 @@ public class TextClassifierImplTest { }); // Check modelFileB v801 - when(mockModelFileLister.list(ModelType.ANNOTATOR)) - .thenReturn(ImmutableList.of(annotatorModelB)); - TextClassification classificationB = textClassifierImpl.classifyText(null, null, request); + when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(annotatorModelB); + TextClassification classificationB = classifier.classifyText(null, null, request); assertThat(classificationB.getId()).contains("v801"); assertThat(classificationB.getText()).contains(classifiedText); @@ -702,9 +677,9 @@ public class TextClassifierImplTest { }); // Reload modelFileA v701 - when(mockModelFileLister.list(ModelType.ANNOTATOR)) - .thenReturn(ImmutableList.of(annotatorModelA)); - TextClassification classificationAcached = textClassifierImpl.classifyText(null, null, request); + when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) + .thenReturn(annotatorModelA); + TextClassification classificationAcached = classifier.classifyText(null, null, request); assertThat(classificationAcached.getId()).contains("v701"); assertThat(classificationAcached.getText()).contains(classifiedText); diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java index 5a67f93..9e11c09 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java @@ -109,8 +109,10 @@ public final class ModelDownloadManagerTest { @Test public void onTextClassifierServiceCreated_workManagerCrashed() throws Exception { + assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty(); downloadManagerWithBadWorkManager.onTextClassifierServiceCreated(); + // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test TextClassifierDownloadWorkScheduled atom = Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()); assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.TCS_STARTED); @@ -119,6 +121,7 @@ public final class ModelDownloadManagerTest { @Test public void onTextClassifierServiceCreated_requestEnqueued() throws Exception { + assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty(); downloadManager.onTextClassifierServiceCreated(); WorkInfo workInfo = @@ -126,17 +129,20 @@ public final class ModelDownloadManagerTest { DownloaderTestUtils.queryWorkInfos( workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME)); assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED); + // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test verifyWorkScheduledLogging(ReasonToSchedule.TCS_STARTED); } @Test public void onTextClassifierServiceCreated_localeListOverridden() throws Exception { + assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty(); deviceConfig.setConfig(TextClassifierSettings.TESTING_LOCALE_LIST_OVERRIDE, "zh,fr"); downloadManager.onTextClassifierServiceCreated(); assertThat(Locale.getDefault()).isEqualTo(Locale.forLanguageTag("zh")); assertThat(LocaleList.getDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr")); assertThat(LocaleList.getAdjustedDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr")); + // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test verifyWorkScheduledLogging(ReasonToSchedule.TCS_STARTED); } diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java index 9f555fc..e261158 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java @@ -18,16 +18,10 @@ package com.android.textclassifier.downloader; import static com.google.common.truth.Truth.assertThat; -import android.app.Instrumentation; -import android.app.UiAutomation; import android.util.Log; import android.view.textclassifier.TextClassification; import android.view.textclassifier.TextClassification.Request; -import android.view.textclassifier.TextClassifier; -import androidx.test.platform.app.InstrumentationRegistry; import com.android.textclassifier.testing.ExtServicesTextClassifierRule; -import com.android.textclassifier.testing.TestingLocaleListOverrideRule; -import java.io.IOException; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -48,170 +42,133 @@ public class ModelDownloaderIntegrationTest { private static final String V804_EN_TAG = "en_v804"; private static final String V804_RU_TAG = "ru_v804"; private static final String FACTORY_MODEL_TAG = "*"; - - @Rule - public final TestingLocaleListOverrideRule testingLocaleListOverrideRule = - new TestingLocaleListOverrideRule(); + private static final int ASSERT_MAX_ATTEMPTS = 20; + private static final int ASSERT_SLEEP_BEFORE_RETRY_MS = 1000; @Rule public final ExtServicesTextClassifierRule extServicesTextClassifierRule = new ExtServicesTextClassifierRule(); - private TextClassifier textClassifier; - @Before public void setup() throws Exception { - // Flag overrides below can be overridden by Phenotype sync, which makes this test flaky - runShellCommand("device_config put textclassifier config_updater_model_enabled false"); - runShellCommand("device_config put textclassifier model_download_manager_enabled true"); - runShellCommand("device_config put textclassifier model_download_backoff_delay_in_millis 5"); - - textClassifier = extServicesTextClassifierRule.getTextClassifier(); - startExtservicesProcess(); + extServicesTextClassifierRule.addDeviceConfigOverride("config_updater_model_enabled", "false"); + extServicesTextClassifierRule.addDeviceConfigOverride("model_download_manager_enabled", "true"); + extServicesTextClassifierRule.addDeviceConfigOverride( + "model_download_backoff_delay_in_millis", "5"); + extServicesTextClassifierRule.addDeviceConfigOverride("testing_locale_list_override", "en-US"); + extServicesTextClassifierRule.overrideDeviceConfig(); + + extServicesTextClassifierRule.enableVerboseLogging(); + // Verbose logging only takes effect after restarting ExtServices + extServicesTextClassifierRule.forceStopExtServices(); } @After public void tearDown() throws Exception { - runShellCommand("device_config delete textclassifier manifest_url_annotator_en"); - runShellCommand("device_config delete textclassifier manifest_url_annotator_ru"); - runShellCommand("device_config put textclassifier config_updater_model_enabled true"); - runShellCommand("device_config delete textclassifier multi_language_support_enabled"); - runShellCommand( - "device_config put textclassifier model_download_backoff_delay_in_millis 3600000"); + // This is to reset logging/locale_override for ExtServices. + extServicesTextClassifierRule.forceStopExtServices(); } @Test - public void smokeTest() throws IOException, InterruptedException { - runShellCommand( - "device_config put textclassifier manifest_url_annotator_en " - + V804_EN_ANNOTATOR_MANIFEST_URL); + public void smokeTest() throws Exception { + extServicesTextClassifierRule.addDeviceConfigOverride( + "manifest_url_annotator_en", V804_EN_ANNOTATOR_MANIFEST_URL); - assertWithRetries( - /* maxAttempts= */ 10, /* sleepMs= */ 500, () -> verifyActiveEnglishModel(V804_EN_TAG)); + assertWithRetries(() -> verifyActiveEnglishModel(V804_EN_TAG)); } @Test - public void downgradeModel() throws IOException, InterruptedException { + public void downgradeModel() throws Exception { // Download an experimental model. - { - runShellCommand( - "device_config put textclassifier manifest_url_annotator_en " - + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL); - - assertWithRetries( - /* maxAttempts= */ 10, - /* sleepMs= */ 500, - () -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG)); - } + extServicesTextClassifierRule.addDeviceConfigOverride( + "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL); + + assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG)); // Downgrade to an older model. - { - runShellCommand( - "device_config put textclassifier manifest_url_annotator_en " - + V804_EN_ANNOTATOR_MANIFEST_URL); + extServicesTextClassifierRule.addDeviceConfigOverride( + "manifest_url_annotator_en", V804_EN_ANNOTATOR_MANIFEST_URL); - assertWithRetries( - /* maxAttempts= */ 10, /* sleepMs= */ 500, () -> verifyActiveEnglishModel(V804_EN_TAG)); - } + assertWithRetries(() -> verifyActiveEnglishModel(V804_EN_TAG)); } @Test - public void upgradeModel() throws IOException, InterruptedException { + public void upgradeModel() throws Exception { // Download a model. - { - runShellCommand( - "device_config put textclassifier manifest_url_annotator_en " - + V804_EN_ANNOTATOR_MANIFEST_URL); + extServicesTextClassifierRule.addDeviceConfigOverride( + "manifest_url_annotator_en", V804_EN_ANNOTATOR_MANIFEST_URL); - assertWithRetries( - /* maxAttempts= */ 10, /* sleepMs= */ 500, () -> verifyActiveEnglishModel(V804_EN_TAG)); - } + assertWithRetries(() -> verifyActiveEnglishModel(V804_EN_TAG)); // Upgrade to an experimental model. - { - runShellCommand( - "device_config put textclassifier manifest_url_annotator_en " - + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL); - - assertWithRetries( - /* maxAttempts= */ 10, - /* sleepMs= */ 500, - () -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG)); - } + extServicesTextClassifierRule.addDeviceConfigOverride( + "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL); + + assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG)); } @Test - public void clearFlag() throws IOException, InterruptedException { + public void clearFlag() throws Exception { // Download a new model. - { - runShellCommand( - "device_config put textclassifier manifest_url_annotator_en " - + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL); - - assertWithRetries( - /* maxAttempts= */ 10, - /* sleepMs= */ 500, - () -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG)); - } + extServicesTextClassifierRule.addDeviceConfigOverride( + "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL); + + assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG)); // Revert the flag. - { - runShellCommand("device_config delete textclassifier manifest_url_annotator_en"); - // Fallback to use the universal model. - assertWithRetries( - /* maxAttempts= */ 10, - /* sleepMs= */ 500, - () -> verifyActiveModel(/* text= */ "abc", /* expectedVersion= */ FACTORY_MODEL_TAG)); - } + extServicesTextClassifierRule.addDeviceConfigOverride("manifest_url_annotator_en", ""); + // Fallback to use the universal model. + assertWithRetries( + () -> verifyActiveModel(/* text= */ "abc", /* expectedVersion= */ FACTORY_MODEL_TAG)); } @Test - public void modelsForMultipleLanguagesDownloaded() throws IOException, InterruptedException { - runShellCommand("device_config put textclassifier multi_language_support_enabled true"); - testingLocaleListOverrideRule.set("en-US", "ru-RU"); + public void modelsForMultipleLanguagesDownloaded() throws Exception { + extServicesTextClassifierRule.addDeviceConfigOverride("multi_language_support_enabled", "true"); + extServicesTextClassifierRule.addDeviceConfigOverride( + "testing_locale_list_override", "en-US,ru-RU"); // download en model - runShellCommand( - "device_config put textclassifier manifest_url_annotator_en " - + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL); + extServicesTextClassifierRule.addDeviceConfigOverride( + "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL); // download ru model - runShellCommand( - "device_config put textclassifier manifest_url_annotator_ru " - + V804_RU_ANNOTATOR_MANIFEST_URL); - assertWithRetries( - /* maxAttempts= */ 10, - /* sleepMs= */ 500, - () -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG)); + extServicesTextClassifierRule.addDeviceConfigOverride( + "manifest_url_annotator_ru", V804_RU_ANNOTATOR_MANIFEST_URL); + assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG)); - assertWithRetries(/* maxAttempts= */ 10, /* sleepMs= */ 500, this::verifyActiveRussianModel); + assertWithRetries(this::verifyActiveRussianModel); assertWithRetries( - /* maxAttempts= */ 10, - /* sleepMs= */ 500, () -> verifyActiveModel(/* text= */ "français", /* expectedVersion= */ FACTORY_MODEL_TAG)); } - private void assertWithRetries(int maxAttempts, int sleepMs, Runnable assertRunnable) - throws InterruptedException { - for (int i = 0; i < maxAttempts; i++) { + private void assertWithRetries(Runnable assertRunnable) throws Exception { + for (int i = 0; i < ASSERT_MAX_ATTEMPTS; i++) { try { + extServicesTextClassifierRule.overrideDeviceConfig(); assertRunnable.run(); break; // success. Bail out. } catch (AssertionError ex) { - if (i == maxAttempts - 1) { // last attempt, give up. + if (i == ASSERT_MAX_ATTEMPTS - 1) { // last attempt, give up. + extServicesTextClassifierRule.dumpDefaultTextClassifierService(); throw ex; } else { - Thread.sleep(sleepMs); + Thread.sleep(ASSERT_SLEEP_BEFORE_RETRY_MS); } + } catch (Exception unknownException) { + throw unknownException; } } } private void verifyActiveModel(String text, String expectedVersion) { TextClassification textClassification = - textClassifier.classifyText(new Request.Builder(text, 0, text.length()).build()); + extServicesTextClassifierRule + .getTextClassifier() + .classifyText(new Request.Builder(text, 0, text.length()).build()); // The result id contains the name of the just used model. + Log.d(TAG, "verifyActiveModel. TextClassification ID: " + textClassification.getId()); assertThat(textClassification.getId()).contains(expectedVersion); } @@ -222,16 +179,4 @@ public class ModelDownloaderIntegrationTest { private void verifyActiveRussianModel() { verifyActiveModel("привет", V804_RU_TAG); } - - private void startExtservicesProcess() { - // Start the process of ExtServices by sending it a text classifier request. - textClassifier.classifyText(new TextClassification.Request.Builder("abc", 0, 3).build()); - } - - private static void runShellCommand(String cmd) { - Log.v(TAG, "run shell command: " + cmd); - Instrumentation instrumentation = InstrumentationRegistry.getInstrumentation(); - UiAutomation uiAutomation = instrumentation.getUiAutomation(); - uiAutomation.executeShellCommand(cmd); - } } diff --git a/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java b/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java index 3ceb47b..5f8247d 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java @@ -20,64 +20,72 @@ import android.app.UiAutomation; import android.content.pm.PackageManager; import android.content.pm.PackageManager.NameNotFoundException; import android.provider.DeviceConfig; +import android.util.Log; import android.view.textclassifier.TextClassificationManager; import android.view.textclassifier.TextClassifier; import androidx.test.core.app.ApplicationProvider; import androidx.test.platform.app.InstrumentationRegistry; +import com.google.common.io.ByteStreams; +import java.io.FileInputStream; +import java.io.IOException; import org.junit.rules.ExternalResource; /** A rule that manages a text classifier that is backed by the ExtServices. */ public final class ExtServicesTextClassifierRule extends ExternalResource { + private static final String TAG = "androidtc"; private static final String CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE = "textclassifier_service_package_override"; private static final String PKG_NAME_GOOGLE_EXTSERVICES = "com.google.android.ext.services"; private static final String PKG_NAME_AOSP_EXTSERVICES = "android.ext.services"; - private String textClassifierServiceOverrideFlagOldValue; + private UiAutomation uiAutomation; + private DeviceConfig.Properties originalProperties; + private DeviceConfig.Properties.Builder newPropertiesBuilder; @Override - protected void before() { - UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation(); - try { - uiAutomation.adoptShellPermissionIdentity(); - textClassifierServiceOverrideFlagOldValue = - DeviceConfig.getString( - DeviceConfig.NAMESPACE_TEXTCLASSIFIER, - CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE, - null); - DeviceConfig.setProperty( - DeviceConfig.NAMESPACE_TEXTCLASSIFIER, - CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE, - getExtServicesPackageName(), - /* makeDefault= */ false); - } finally { - uiAutomation.dropShellPermissionIdentity(); - } + protected void before() throws Exception { + uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation(); + uiAutomation.adoptShellPermissionIdentity(); + originalProperties = DeviceConfig.getProperties(DeviceConfig.NAMESPACE_TEXTCLASSIFIER); + newPropertiesBuilder = + new DeviceConfig.Properties.Builder(DeviceConfig.NAMESPACE_TEXTCLASSIFIER) + .setString( + CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE, getExtServicesPackageName()); + overrideDeviceConfig(); } @Override protected void after() { - UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation(); try { - uiAutomation.adoptShellPermissionIdentity(); - DeviceConfig.setProperty( - DeviceConfig.NAMESPACE_TEXTCLASSIFIER, - CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE, - textClassifierServiceOverrideFlagOldValue, - /* makeDefault= */ false); + DeviceConfig.setProperties(originalProperties); + } catch (Throwable t) { + Log.e(TAG, "Failed to reset DeviceConfig", t); } finally { uiAutomation.dropShellPermissionIdentity(); } } - private static String getExtServicesPackageName() { - PackageManager packageManager = ApplicationProvider.getApplicationContext().getPackageManager(); - try { - packageManager.getApplicationInfo(PKG_NAME_GOOGLE_EXTSERVICES, /* flags= */ 0); - return PKG_NAME_GOOGLE_EXTSERVICES; - } catch (NameNotFoundException e) { - return PKG_NAME_AOSP_EXTSERVICES; - } + public void addDeviceConfigOverride(String name, String value) { + newPropertiesBuilder.setString(name, value); + } + + /** + * Overrides the TextClassifier DeviceConfig manually. + * + *

This will clean up all device configs not in newPropertiesBuilder. + * + *

We will need to call this everytime before testing, because DeviceConfig can be synced in + * background at anytime. DeviceConfig#setSyncDisabledMode is to disable sync, however it's a + * hidden API. + */ + public void overrideDeviceConfig() throws Exception { + DeviceConfig.setProperties(newPropertiesBuilder.build()); + } + + /** Force stop ExtServices. Force-stop-and-start can be helpful to reload some states. */ + public void forceStopExtServices() { + runShellCommand("am force-stop com.google.android.ext.services"); + runShellCommand("am force-stop android.ext.services"); } public TextClassifier getTextClassifier() { @@ -87,4 +95,38 @@ public final class ExtServicesTextClassifierRule extends ExternalResource { textClassificationManager.setTextClassifier(null); // Reset TC overrides return textClassificationManager.getTextClassifier(); } + + public void dumpDefaultTextClassifierService() { + runShellCommand( + "dumpsys activity service com.google.android.ext.services/" + + "com.android.textclassifier.DefaultTextClassifierService"); + runShellCommand("cmd device_config list textclassifier"); + } + + public void enableVerboseLogging() { + runShellCommand("setprop log.tag.androidtc VERBOSE"); + } + + private void runShellCommand(String cmd) { + Log.v(TAG, "run shell command: " + cmd); + try (FileInputStream output = + new FileInputStream(uiAutomation.executeShellCommand(cmd).getFileDescriptor())) { + String cmdOutput = new String(ByteStreams.toByteArray(output)); + if (!cmdOutput.isEmpty()) { + Log.d(TAG, "cmd output: " + cmdOutput); + } + } catch (IOException ioe) { + Log.w(TAG, "failed to get cmd output", ioe); + } + } + + private static String getExtServicesPackageName() { + PackageManager packageManager = ApplicationProvider.getApplicationContext().getPackageManager(); + try { + packageManager.getApplicationInfo(PKG_NAME_GOOGLE_EXTSERVICES, /* flags= */ 0); + return PKG_NAME_GOOGLE_EXTSERVICES; + } catch (NameNotFoundException e) { + return PKG_NAME_AOSP_EXTSERVICES; + } + } } diff --git a/java/tests/instrumentation/src/com/android/textclassifier/testing/TestingLocaleListOverrideRule.java b/java/tests/instrumentation/src/com/android/textclassifier/testing/TestingLocaleListOverrideRule.java deleted file mode 100644 index 7d46e97..0000000 --- a/java/tests/instrumentation/src/com/android/textclassifier/testing/TestingLocaleListOverrideRule.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright (C) 2018 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.android.textclassifier.testing; - -import android.app.UiAutomation; -import android.os.LocaleList; -import android.util.Log; -import androidx.test.platform.app.InstrumentationRegistry; -import org.junit.rules.ExternalResource; - -/** class for overriding testing_locale_list_override from {@link TextClassifierSettings} */ -public final class TestingLocaleListOverrideRule extends ExternalResource { - private static final String TAG = "TestingLocaleListOverrideRule"; - - private LocaleList originalLocaleList; - - @Override - protected void before() { - originalLocaleList = LocaleList.getDefault(); - } - - public void set(String... localeTags) { - if (localeTags.length == 0) { - return; - } - runShellCommand( - "device_config put textclassifier testing_locale_list_override " - + String.join(",", localeTags)); - } - - @Override - protected void after() { - runShellCommand( - "device_config put textclassifier testing_locale_list_override " - + originalLocaleList.toLanguageTags()); - runShellCommand("device_config delete textclassifier testing_locale_list_override"); - } - - private static void runShellCommand(String cmd) { - Log.v(TAG, "run shell command: " + cmd); - UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation(); - uiAutomation.executeShellCommand(cmd); - } -} -- cgit v1.2.3 From 0c1f32bfa21cdcd618b55c4e8dd4f122e9a59af1 Mon Sep 17 00:00:00 2001 From: Chang Li Date: Fri, 25 Mar 2022 12:10:15 +0000 Subject: Rename Encoder test names. Bug: 226182044 Change-Id: I7bfcf3948f7a4fcb603ff15cac318a67a891e269 --- native/utils/tokenfree/byte_encoder_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/native/utils/tokenfree/byte_encoder_test.cc b/native/utils/tokenfree/byte_encoder_test.cc index d4d119e..964e316 100644 --- a/native/utils/tokenfree/byte_encoder_test.cc +++ b/native/utils/tokenfree/byte_encoder_test.cc @@ -29,7 +29,7 @@ namespace { using testing::ElementsAre; -TEST(EncoderTest, SimpleTokenization) { +TEST(ByteEncoderTest, SimpleTokenization) { const ByteEncoder encoder; { std::vector encoded_text; @@ -39,7 +39,7 @@ TEST(EncoderTest, SimpleTokenization) { } } -TEST(EncoderTest, SimpleTokenization2) { +TEST(ByteEncoderTest, SimpleTokenization2) { const ByteEncoder encoder; { std::vector encoded_text; -- cgit v1.2.3 From 84b3aa147eabb6fbb916aa20e1e7b5613b267304 Mon Sep 17 00:00:00 2001 From: Chang Li Date: Mon, 4 Apr 2022 15:11:10 +0000 Subject: Remove expensive logging statement. This logging statement is expensive (maybe because of stringification) and it is on the TextClassifier API critical path. Remove it. In the future, we can log model lookup info only when we need to reload the model (i.e. the lookup result changes). BUG: 225081632 Change-Id: I14f950e4101027d1a6596fd1c459c4c4440b8379 --- .../com/android/textclassifier/ModelFileManagerImpl.java | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/java/src/com/android/textclassifier/ModelFileManagerImpl.java b/java/src/com/android/textclassifier/ModelFileManagerImpl.java index e3b646f..45426d0 100644 --- a/java/src/com/android/textclassifier/ModelFileManagerImpl.java +++ b/java/src/com/android/textclassifier/ModelFileManagerImpl.java @@ -390,18 +390,7 @@ final class ModelFileManagerImpl implements ModelFileManager { localePreferences.get(0), targetLocale)); } - ModelFile modelFile = findBestModelFile(modelType, targetLocale); - TcLog.d( - TAG, - String.format( - Locale.US, - "findBestModelFile: best model: %s; localePreferences: %s; detectedLocales: %s;" - + " targetLocale: %s", - modelFile, - localePreferences, - detectedLocales, - targetLocale)); - return modelFile; + return findBestModelFile(modelType, targetLocale); } /** -- cgit v1.2.3