summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--TEST_MAPPING5
-rw-r--r--java/src/com/android/textclassifier/ModelFileManager.java8
-rw-r--r--java/src/com/android/textclassifier/common/TextClassifierSettings.java10
-rw-r--r--java/tests/instrumentation/Android.bp57
-rw-r--r--java/tests/instrumentation/AndroidManifest_TCSModelDownloaderIntegrationTest.xml14
-rw-r--r--java/tests/instrumentation/AndroidTest_TCSModelDownloaderIntegrationTest.xml28
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java72
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java184
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java90
9 files changed, 390 insertions, 78 deletions
diff --git a/TEST_MAPPING b/TEST_MAPPING
index 93ea6d6..72e022b 100644
--- a/TEST_MAPPING
+++ b/TEST_MAPPING
@@ -16,6 +16,9 @@
},
{
"name": "TextClassifierNotificationTests"
+ },
+ {
+ "name": "TCSModelDownloaderIntegrationTest"
}
],
"mainline-presubmit": [
@@ -32,4 +35,4 @@
"name": "libtextclassifier_java_tests[com.google.android.extservices.apex]"
}
]
-} \ No newline at end of file
+}
diff --git a/java/src/com/android/textclassifier/ModelFileManager.java b/java/src/com/android/textclassifier/ModelFileManager.java
index 63e7155..f61b917 100644
--- a/java/src/com/android/textclassifier/ModelFileManager.java
+++ b/java/src/com/android/textclassifier/ModelFileManager.java
@@ -78,7 +78,7 @@ public final class ModelFileManager {
new RegularFileFullMatchLister(
ModelType.ANNOTATOR,
new File(CONFIG_UPDATER_DIR, "textclassifier.model"),
- /* isEnabled= */ () -> true),
+ /* isEnabled= */ () -> settings.isConfigUpdaterModelEnabled()),
new AssetFilePatternMatchLister(
assetManager,
ModelType.ANNOTATOR,
@@ -89,7 +89,7 @@ public final class ModelFileManager {
new RegularFileFullMatchLister(
ModelType.ACTIONS_SUGGESTIONS,
new File(CONFIG_UPDATER_DIR, "actions_suggestions.model"),
- /* isEnabled= */ () -> true),
+ /* isEnabled= */ () -> settings.isConfigUpdaterModelEnabled()),
new AssetFilePatternMatchLister(
assetManager,
ModelType.ACTIONS_SUGGESTIONS,
@@ -100,7 +100,7 @@ public final class ModelFileManager {
new RegularFileFullMatchLister(
ModelType.LANG_ID,
new File(CONFIG_UPDATER_DIR, "lang_id.model"),
- /* isEnabled= */ () -> true),
+ /* isEnabled= */ () -> settings.isConfigUpdaterModelEnabled()),
new AssetFilePatternMatchLister(
assetManager,
ModelType.LANG_ID,
@@ -321,7 +321,7 @@ public final class ModelFileManager {
try {
modelFilesBuilder.add(ModelFile.createFromAsset(assetManager, absolutePath, modelType));
} catch (IOException e) {
- TcLog.w(TAG, "Failed to call createFromAsset with: " + absolutePath);
+ TcLog.e(TAG, "Failed to call createFromAsset with: " + absolutePath, e);
}
}
ImmutableList<ModelFile> result = modelFilesBuilder.build();
diff --git a/java/src/com/android/textclassifier/common/TextClassifierSettings.java b/java/src/com/android/textclassifier/common/TextClassifierSettings.java
index 5b419a2..d8c98fa 100644
--- a/java/src/com/android/textclassifier/common/TextClassifierSettings.java
+++ b/java/src/com/android/textclassifier/common/TextClassifierSettings.java
@@ -109,7 +109,8 @@ public final class TextClassifierSettings {
*/
private static final String DETECT_LANGUAGES_FROM_TEXT_ENABLED =
"detect_languages_from_text_enabled";
-
+ /** Whether to use models downloaded by config updater. */
+ private static final String CONFIG_UPDATER_MODEL_ENABLED = "config_updater_model_enabled";
/** Whether to enable model downloading with ModelDownloadManager */
@VisibleForTesting
public static final String MODEL_DOWNLOAD_MANAGER_ENABLED = "model_download_manager_enabled";
@@ -206,6 +207,7 @@ public final class TextClassifierSettings {
private static final boolean TEMPLATE_INTENT_FACTORY_ENABLED_DEFAULT = true;
private static final boolean TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT = true;
private static final boolean DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT = true;
+ private static final boolean CONFIG_UPDATER_MODEL_ENABLED_DEFAULT = true;
private static final boolean MODEL_DOWNLOAD_MANAGER_ENABLED_DEFAULT = false;
private static final String MANIFEST_DOWNLOAD_REQUIRED_NETWORK_TYPE_DEFAULT = "UNMETERED";
private static final int MODEL_DOWNLOAD_WORKER_MAX_ATTEMPTS_DEFAULT = 5;
@@ -384,6 +386,11 @@ public final class TextClassifierSettings {
return getDeviceConfigFloatArray(LANG_ID_CONTEXT_SETTINGS, LANG_ID_CONTEXT_SETTINGS_DEFAULT);
}
+ public boolean isConfigUpdaterModelEnabled() {
+ return deviceConfig.getBoolean(
+ NAMESPACE, CONFIG_UPDATER_MODEL_ENABLED, CONFIG_UPDATER_MODEL_ENABLED_DEFAULT);
+ }
+
public boolean isModelDownloadManagerEnabled() {
return deviceConfig.getBoolean(
NAMESPACE, MODEL_DOWNLOAD_MANAGER_ENABLED, MODEL_DOWNLOAD_MANAGER_ENABLED_DEFAULT);
@@ -506,6 +513,7 @@ public final class TextClassifierSettings {
pw.printPair(USER_LANGUAGE_PROFILE_ENABLED, isUserLanguageProfileEnabled());
pw.printPair(TEMPLATE_INTENT_FACTORY_ENABLED, isTemplateIntentFactoryEnabled());
pw.printPair(TRANSLATE_IN_CLASSIFICATION_ENABLED, isTranslateInClassificationEnabled());
+ pw.printPair(CONFIG_UPDATER_MODEL_ENABLED, isConfigUpdaterModelEnabled());
pw.printPair(MODEL_DOWNLOAD_MANAGER_ENABLED, isModelDownloadManagerEnabled());
pw.printPair(MODEL_DOWNLOAD_WORKER_MAX_ATTEMPTS, getModelDownloadWorkerMaxAttempts());
pw.printPair(MANIFEST_DOWNLOAD_MAX_ATTEMPTS, getManifestDownloadMaxAttempts());
diff --git a/java/tests/instrumentation/Android.bp b/java/tests/instrumentation/Android.bp
index 62390e2..775f9f9 100644
--- a/java/tests/instrumentation/Android.bp
+++ b/java/tests/instrumentation/Android.bp
@@ -23,6 +23,22 @@ package {
default_applicable_licenses: ["external_libtextclassifier_license"],
}
+java_library {
+ name: "TextClassifierServiceTestingLib",
+
+ srcs: [
+ "src/com/android/textclassifier/testing/*.java",
+ ],
+
+ static_libs: [
+ "androidx.test.ext.junit",
+ "androidx.test.rules",
+ "TextClassifierServiceLib",
+ "androidx.test.espresso.core",
+ "mockito-target-minus-junit4",
+ ],
+}
+
android_test {
name: "TextClassifierServiceTest",
@@ -32,9 +48,14 @@ android_test {
"src/**/*.java",
],
+ exclude_srcs: [
+ "src/**/ModelDownloaderIntegrationTest.java",
+ "src/com/android/textclassifier/testing/*.java",
+ ],
+
+
static_libs: [
"androidx.test.ext.junit",
- "androidx.test.rules",
"androidx.test.espresso.core",
"androidx.test.ext.truth",
"mockito-target-minus-junit4",
@@ -47,6 +68,7 @@ android_test {
"textclassifierprotoslite",
"TextClassifierCoverageLib",
"androidx.work_work-testing",
+ "TextClassifierServiceTestingLib",
],
jni_libs: [
@@ -67,4 +89,37 @@ android_test {
instrumentation_for: "TextClassifierService",
data: ["testdata/*"],
+
+ test_config: "AndroidTest.xml",
+}
+
+android_test {
+ name: "TCSModelDownloaderIntegrationTest",
+
+ manifest: "AndroidManifest_TCSModelDownloaderIntegrationTest.xml",
+
+ srcs: [
+ "src/**/ModelDownloaderIntegrationTest.java",
+ ],
+
+ static_libs: [
+ "androidx.test.ext.junit",
+ "androidx.test.espresso.core",
+ "androidx.test.ext.truth",
+ "ub-uiautomator",
+ "TextClassifierServiceTestingLib",
+ ],
+
+ jni_libs: [
+ "libtextclassifier",
+ ],
+
+ test_suites: [
+ "general-tests"
+ ],
+
+ min_sdk_version: "30",
+ sdk_version: "system_current",
+
+ test_config: "AndroidTest_TCSModelDownloaderIntegrationTest.xml",
}
diff --git a/java/tests/instrumentation/AndroidManifest_TCSModelDownloaderIntegrationTest.xml b/java/tests/instrumentation/AndroidManifest_TCSModelDownloaderIntegrationTest.xml
new file mode 100644
index 0000000..ff6ab85
--- /dev/null
+++ b/java/tests/instrumentation/AndroidManifest_TCSModelDownloaderIntegrationTest.xml
@@ -0,0 +1,14 @@
+<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="com.android.textclassifier.downloader.tests">
+
+ <uses-sdk android:minSdkVersion="30" android:targetSdkVersion="30"/>
+
+ <application>
+ <uses-library android:name="android.test.runner"/>
+ </application>
+
+ <instrumentation
+ android:name="androidx.test.runner.AndroidJUnitRunner"
+ android:targetPackage="com.android.textclassifier.downloader.tests"/>
+</manifest>
diff --git a/java/tests/instrumentation/AndroidTest_TCSModelDownloaderIntegrationTest.xml b/java/tests/instrumentation/AndroidTest_TCSModelDownloaderIntegrationTest.xml
new file mode 100644
index 0000000..424b0f5
--- /dev/null
+++ b/java/tests/instrumentation/AndroidTest_TCSModelDownloaderIntegrationTest.xml
@@ -0,0 +1,28 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!-- Copyright (C) 2020 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<configuration description="Runs TCSModelDownloaderIntegrationTest.">
+ <option name="test-suite-tag" value="apct" />
+ <option name="test-suite-tag" value="apct-instrumentation" />
+ <target_preparer class="com.android.tradefed.targetprep.suite.SuiteApkInstaller">
+ <option name="cleanup-apks" value="true" />
+ <option name="test-file-name" value="TCSModelDownloaderIntegrationTest.apk" />
+ </target_preparer>
+
+ <test class="com.android.tradefed.testtype.AndroidJUnitTest" >
+ <option name="package" value="com.android.textclassifier.downloader.tests" />
+ <option name="runner" value="androidx.test.runner.AndroidJUnitRunner" />
+ </test>
+</configuration>
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
index 27ea7f0..42177e6 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
@@ -18,31 +18,24 @@ package com.android.textclassifier;
import static com.google.common.truth.Truth.assertThat;
-import android.app.UiAutomation;
-import android.content.pm.PackageManager;
-import android.content.pm.PackageManager.NameNotFoundException;
import android.icu.util.ULocale;
-import android.provider.DeviceConfig;
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.ConversationActions;
import android.view.textclassifier.TextClassification;
-import android.view.textclassifier.TextClassificationManager;
import android.view.textclassifier.TextClassifier;
import android.view.textclassifier.TextLanguage;
import android.view.textclassifier.TextLinks;
import android.view.textclassifier.TextLinks.TextLink;
import android.view.textclassifier.TextSelection;
-import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
-import androidx.test.platform.app.InstrumentationRegistry;
+import com.android.textclassifier.testing.ExtServicesTextClassifierRule;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
-import org.junit.rules.ExternalResource;
import org.junit.runner.RunWith;
/**
@@ -146,67 +139,4 @@ public class TextClassifierApiTest {
assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL);
assertThat(conversationAction.getAction()).isNotNull();
}
-
- /** A rule that manages a text classifier that is backed by the ExtServices. */
- private static class ExtServicesTextClassifierRule extends ExternalResource {
- private static final String CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE =
- "textclassifier_service_package_override";
- private static final String PKG_NAME_GOOGLE_EXTSERVICES = "com.google.android.ext.services";
- private static final String PKG_NAME_AOSP_EXTSERVICES = "android.ext.services";
-
- private String textClassifierServiceOverrideFlagOldValue;
-
- @Override
- protected void before() {
- UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
- try {
- uiAutomation.adoptShellPermissionIdentity();
- textClassifierServiceOverrideFlagOldValue =
- DeviceConfig.getString(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
- null);
- DeviceConfig.setProperty(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
- getExtServicesPackageName(),
- /* makeDefault= */ false);
- } finally {
- uiAutomation.dropShellPermissionIdentity();
- }
- }
-
- @Override
- protected void after() {
- UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
- try {
- uiAutomation.adoptShellPermissionIdentity();
- DeviceConfig.setProperty(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
- textClassifierServiceOverrideFlagOldValue,
- /* makeDefault= */ false);
- } finally {
- uiAutomation.dropShellPermissionIdentity();
- }
- }
-
- private static String getExtServicesPackageName() {
- PackageManager packageManager =
- ApplicationProvider.getApplicationContext().getPackageManager();
- try {
- packageManager.getApplicationInfo(PKG_NAME_GOOGLE_EXTSERVICES, /* flags= */ 0);
- return PKG_NAME_GOOGLE_EXTSERVICES;
- } catch (NameNotFoundException e) {
- return PKG_NAME_AOSP_EXTSERVICES;
- }
- }
-
- public TextClassifier getTextClassifier() {
- TextClassificationManager textClassificationManager =
- ApplicationProvider.getApplicationContext()
- .getSystemService(TextClassificationManager.class);
- return textClassificationManager.getTextClassifier();
- }
- }
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java
new file mode 100644
index 0000000..e17dbef
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java
@@ -0,0 +1,184 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.app.Instrumentation;
+import android.app.UiAutomation;
+import android.util.Log;
+import android.view.textclassifier.TextClassification;
+import android.view.textclassifier.TextClassification.Request;
+import android.view.textclassifier.TextClassifier;
+import androidx.test.platform.app.InstrumentationRegistry;
+import com.android.textclassifier.testing.ExtServicesTextClassifierRule;
+import java.io.IOException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class ModelDownloaderIntegrationTest {
+ private static final String TAG = "ModelDownloaderTest";
+ private static final String EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL =
+ "https://www.gstatic.com/android/text_classifier/r/experimental/v999999999/en.fb.manifest";
+ private static final String EXPERIMENTAL_EN_TAG = "en_v999999999";
+ private static final String V804_EN_ANNOTATOR_MANIFEST_URL =
+ "https://www.gstatic.com/android/text_classifier/r/v804/en.fb.manifest";
+ private static final String V804_EN_TAG = "en_v804";
+ private static final String FACTORY_MODEL_TAG = "*";
+
+ @Rule
+ public final ExtServicesTextClassifierRule extServicesTextClassifierRule =
+ new ExtServicesTextClassifierRule();
+
+ private TextClassifier textClassifier;
+
+ @Before
+ public void setup() throws Exception {
+ // Flag overrides below can be overridden by Phenotype sync, which makes this test flaky
+ runShellCommand("device_config put textclassifier config_updater_model_enabled false");
+ runShellCommand("device_config put textclassifier model_download_manager_enabled true");
+ runShellCommand("device_config put textclassifier model_download_backoff_delay_in_millis 1");
+
+ textClassifier = extServicesTextClassifierRule.getTextClassifier();
+ startExtservicesProcess();
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ runShellCommand("device_config delete textclassifier manifest_url_annotator_en");
+ runShellCommand("device_config put textclassifier config_updater_model_enabled true");
+ runShellCommand(
+ "device_config put textclassifier model_download_backoff_delay_in_millis 3600000");
+ }
+
+ @Test
+ public void smokeTest() throws IOException, InterruptedException {
+ runShellCommand(
+ "device_config put textclassifier manifest_url_annotator_en "
+ + V804_EN_ANNOTATOR_MANIFEST_URL);
+
+ assertWithRetries(
+ /* maxAttempts= */ 10, /* sleepMs= */ 500, () -> verifyActiveModel(V804_EN_TAG));
+ }
+
+ @Test
+ public void downgradeModel() throws IOException, InterruptedException {
+ // Download an experimental model.
+ {
+ runShellCommand(
+ "device_config put textclassifier manifest_url_annotator_en "
+ + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
+
+ assertWithRetries(
+ /* maxAttempts= */ 10, /* sleepMs= */ 500, () -> verifyActiveModel(EXPERIMENTAL_EN_TAG));
+ }
+
+ // Downgrade to an older model.
+ {
+ runShellCommand(
+ "device_config put textclassifier manifest_url_annotator_en "
+ + V804_EN_ANNOTATOR_MANIFEST_URL);
+
+ assertWithRetries(
+ /* maxAttempts= */ 10, /* sleepMs= */ 500, () -> verifyActiveModel(V804_EN_TAG));
+ }
+ }
+
+ @Test
+ public void upgradeModel() throws IOException, InterruptedException {
+ // Download a model.
+ {
+ runShellCommand(
+ "device_config put textclassifier manifest_url_annotator_en "
+ + V804_EN_ANNOTATOR_MANIFEST_URL);
+
+ assertWithRetries(
+ /* maxAttempts= */ 10, /* sleepMs= */ 500, () -> verifyActiveModel(V804_EN_TAG));
+ }
+
+ // Upgrade to an experimental model.
+ {
+ runShellCommand(
+ "device_config put textclassifier manifest_url_annotator_en "
+ + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
+
+ assertWithRetries(
+ /* maxAttempts= */ 10, /* sleepMs= */ 500, () -> verifyActiveModel(EXPERIMENTAL_EN_TAG));
+ }
+ }
+
+ @Test
+ public void clearFlag() throws IOException, InterruptedException {
+ // Download a new model.
+ {
+ runShellCommand(
+ "device_config put textclassifier manifest_url_annotator_en "
+ + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
+
+ assertWithRetries(
+ /* maxAttempts= */ 10, /* sleepMs= */ 500, () -> verifyActiveModel(EXPERIMENTAL_EN_TAG));
+ }
+
+ // Revert the flag.
+ {
+ runShellCommand("device_config delete textclassifier manifest_url_annotator_en");
+ // Fallback to use the universal model.
+ assertWithRetries(
+ /* maxAttempts= */ 10, /* sleepMs= */ 500, () -> verifyActiveModel(FACTORY_MODEL_TAG));
+ }
+ }
+
+ private void assertWithRetries(int maxAttempts, int sleepMs, Runnable assertRunnable)
+ throws InterruptedException {
+ for (int i = 0; i < maxAttempts; i++) {
+ try {
+ assertRunnable.run();
+ break; // success. Bail out.
+ } catch (AssertionError ex) {
+ if (i == maxAttempts - 1) { // last attempt, give up.
+ throw ex;
+ } else {
+ Thread.sleep(sleepMs);
+ }
+ }
+ }
+ }
+
+ private void verifyActiveModel(String expectedVersion) {
+ TextClassification textClassification =
+ textClassifier.classifyText(new Request.Builder("abc", 0, 3).build());
+ // The result id contains the name of the just used model.
+ assertThat(textClassification.getId()).contains(expectedVersion);
+ }
+
+ private void startExtservicesProcess() {
+ // Start the process of ExtServices by sending it a text classifier request.
+ textClassifier.classifyText(new TextClassification.Request.Builder("abc", 0, 3).build());
+ }
+
+ private static void runShellCommand(String cmd) {
+ Log.v(TAG, "run shell command: " + cmd);
+ Instrumentation instrumentation = InstrumentationRegistry.getInstrumentation();
+ UiAutomation uiAutomation = instrumentation.getUiAutomation();
+ uiAutomation.executeShellCommand(cmd);
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java b/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java
new file mode 100644
index 0000000..3ceb47b
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java
@@ -0,0 +1,90 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.testing;
+
+import android.app.UiAutomation;
+import android.content.pm.PackageManager;
+import android.content.pm.PackageManager.NameNotFoundException;
+import android.provider.DeviceConfig;
+import android.view.textclassifier.TextClassificationManager;
+import android.view.textclassifier.TextClassifier;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.platform.app.InstrumentationRegistry;
+import org.junit.rules.ExternalResource;
+
+/** A rule that manages a text classifier that is backed by the ExtServices. */
+public final class ExtServicesTextClassifierRule extends ExternalResource {
+ private static final String CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE =
+ "textclassifier_service_package_override";
+ private static final String PKG_NAME_GOOGLE_EXTSERVICES = "com.google.android.ext.services";
+ private static final String PKG_NAME_AOSP_EXTSERVICES = "android.ext.services";
+
+ private String textClassifierServiceOverrideFlagOldValue;
+
+ @Override
+ protected void before() {
+ UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
+ try {
+ uiAutomation.adoptShellPermissionIdentity();
+ textClassifierServiceOverrideFlagOldValue =
+ DeviceConfig.getString(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
+ null);
+ DeviceConfig.setProperty(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
+ getExtServicesPackageName(),
+ /* makeDefault= */ false);
+ } finally {
+ uiAutomation.dropShellPermissionIdentity();
+ }
+ }
+
+ @Override
+ protected void after() {
+ UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
+ try {
+ uiAutomation.adoptShellPermissionIdentity();
+ DeviceConfig.setProperty(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
+ textClassifierServiceOverrideFlagOldValue,
+ /* makeDefault= */ false);
+ } finally {
+ uiAutomation.dropShellPermissionIdentity();
+ }
+ }
+
+ private static String getExtServicesPackageName() {
+ PackageManager packageManager = ApplicationProvider.getApplicationContext().getPackageManager();
+ try {
+ packageManager.getApplicationInfo(PKG_NAME_GOOGLE_EXTSERVICES, /* flags= */ 0);
+ return PKG_NAME_GOOGLE_EXTSERVICES;
+ } catch (NameNotFoundException e) {
+ return PKG_NAME_AOSP_EXTSERVICES;
+ }
+ }
+
+ public TextClassifier getTextClassifier() {
+ TextClassificationManager textClassificationManager =
+ ApplicationProvider.getApplicationContext()
+ .getSystemService(TextClassificationManager.class);
+ textClassificationManager.setTextClassifier(null); // Reset TC overrides
+ return textClassificationManager.getTextClassifier();
+ }
+}