diff options
author | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2023-11-02 22:45:31 +0000 |
---|---|---|
committer | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2023-11-02 22:45:31 +0000 |
commit | 33fc97465f0c65b2dbd383b13454e50577d20158 (patch) | |
tree | 1ba5b6be224081464652eea44bfb309e02b288f5 | |
parent | 8ff093ddd3cd8e85cd728aa7dcd664cc95f00244 (diff) | |
parent | dd9ce894754412a5bb795c1ed38e78b34d77490c (diff) | |
download | OnDevicePersonalization-android14-mainline-uwb-release.tar.gz |
Snap for 11041982 from dd9ce894754412a5bb795c1ed38e78b34d77490c to mainline-uwb-releaseaml_uwb_341310300aml_uwb_341310030android14-mainline-uwb-release
Change-Id: Ia5299067f822b8e0485ef9223c83b05c143b0ca2
348 files changed, 22617 insertions, 6953 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..5e6ab4b1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +# IntelliJ project files +**/.idea +**/*.iml +**/*.ipr
\ No newline at end of file @@ -113,6 +113,7 @@ android_app { "kotlinx-coroutines-android", "ondevicepersonalization-protos", "mobile_data_downloader_lib", + "modules-utils-build", "ondevicepersonalization-plugin-lib", "flatbuffers-java", "apache-velocity-engine-core", @@ -122,6 +123,7 @@ android_app { min_sdk_version: "33", updatable: true, certificate: "platform", + privileged: true, apex_available: ["com.android.ondevicepersonalization"], defaults: [ "ondevicepersonalization-java-defaults", diff --git a/AndroidManifest.xml b/AndroidManifest.xml index adf7bd6b..55bfdab9 100644 --- a/AndroidManifest.xml +++ b/AndroidManifest.xml @@ -23,12 +23,10 @@ android:versionName="T-initial"> <uses-sdk android:minSdkVersion="33"/> - <uses-permission android:name="android.permission.START_ACTIVITIES_FROM_BACKGROUND"/> <uses-permission android:name="android.permission.INTERNET"/> <uses-permission android:name="android.permission.ACCESS_NETWORK_STATE"/> <uses-permission android:name="android.permission.ACCESS_COARSE_LOCATION"/> <uses-permission android:name="android.permission.ACCESS_FINE_LOCATION"/> - <uses-permission android:name="android.permission.PACKAGE_USAGE_STATS"/> <uses-permission android:name="android.permission.READ_PHONE_STATE"/> <!-- Required for the app to find all packages onboarded to ODP --> @@ -40,6 +38,10 @@ <!-- Required for reading device configs --> <uses-permission android:name="android.permission.READ_DEVICE_CONFIG"/> + <!-- Required for modifying personalization state --> + <permission android:name="android.permission.ondevicepersonalization.MODIFY_ONDEVICEPERSONALIZATION_STATE" + android:protectionLevel="signature|privileged"/> + <application android:name=".OnDevicePersonalizationApplication" android:forceQueryable="true"> <service android:name=".OnDevicePersonalizationManagingServiceImpl" android:exported="true" > @@ -47,9 +49,14 @@ <action android:name="android.OnDevicePersonalizationService" /> </intent-filter> </service> - <service android:name=".OnDevicePersonalizationPrivacyStatusServiceImpl" android:exported="true" > + <service android:name=".OnDevicePersonalizationConfigServiceImpl" android:exported="true" > + <intent-filter> + <action android:name="android.OnDevicePersonalizationConfigService" /> + </intent-filter> + </service> + <service android:name=".OnDevicePersonalizationDebugServiceImpl" android:exported="true" > <intent-filter> - <action android:name="android.OnDevicePersonalizationPrivacyStatusService" /> + <action android:name="android.OnDevicePersonalizationDebugService" /> </intent-filter> </service> <service android:name="com.android.ondevicepersonalization.services.download.OnDevicePersonalizationDownloadProcessingJobService" @@ -68,25 +75,22 @@ android:exported="false" android:permission="android.permission.BIND_JOB_SERVICE"> </service> - <service android:name="com.android.ondevicepersonalization.services.federatedcompute.OdpFederatedComputeJobService" - android:exported="false" - android:permission="android.permission.BIND_JOB_SERVICE"> - </service> <service android:name="com.android.ondevicepersonalization.libraries.plugin.internal.PluginExecutorService" android:isolatedProcess="true" android:exported="false" > </service> - <service android:name="com.android.ondevicepersonalization.services.federatedcompute.OdpExampleStoreService" - android:enabled="true" android:exported="true" > + <service + android:name="com.android.ondevicepersonalization.services.federatedcompute.OdpExampleStoreService" + android:enabled="true" + android:exported="true" + android:permission="android.permission.BIND_EXAMPLE_STORE_SERVICE"> <intent-filter> <action android:name="android.federatedcompute.EXAMPLE_STORE" /> - <data android:scheme="app" /> </intent-filter> </service> <service android:name="com.android.ondevicepersonalization.services.federatedcompute.OdpResultHandlingService" android:enabled="true" android:exported="true" > <intent-filter> <action android:name="android.federatedcompute.COMPUTATION_RESULT" /> - <data android:scheme="app" /> </intent-filter> </service> diff --git a/TEST_MAPPING b/TEST_MAPPING index 50f5e0e7..8fe285ec 100644 --- a/TEST_MAPPING +++ b/TEST_MAPPING @@ -23,6 +23,10 @@ { // Install com.google.android.ondevicepersonalization.apex and run OnDevicePersonalizationSystemServiceImplTests. "name": "OnDevicePersonalizationSystemServiceImplTests[com.google.android.ondevicepersonalization.apex]" + }, + { + // Install com.google.android.ondevicepersonalization.apex and run OnDevicePersonalizationSystemServiceImplTests. + "name": "CtsOnDevicePersonalizationE2ETests[com.google.android.ondevicepersonalization.apex]" } ], "presubmit": [ @@ -43,6 +47,9 @@ }, { "name": "OnDevicePersonalizationSystemServiceImplTests" + }, + { + "name": "CtsOnDevicePersonalizationE2ETests" } ] } diff --git a/apex/Android.bp b/apex/Android.bp index a6256964..db58dd39 100644 --- a/apex/Android.bp +++ b/apex/Android.bp @@ -71,6 +71,7 @@ bootclasspath_fragment { "android.ondevicepersonalization", "com.android.ondevicepersonalization.internal", "android.federatedcompute", + "com.android.federatedcompute.internal", ], }, } @@ -93,4 +94,5 @@ apex { systemserverclasspath_fragments: ["com.android.ondevicepersonalization-systemserverclasspath-fragment"], defaults: ["t-launched-apex-module"], prebuilts: ["current_sdkinfo"], + jni_libs: ["libfcp_cpp_dep_jni"], } diff --git a/federatedcompute/Android.bp b/federatedcompute/Android.bp index cffcae7f..a09bf9af 100644 --- a/federatedcompute/Android.bp +++ b/federatedcompute/Android.bp @@ -53,3 +53,36 @@ filegroup { "//packages/modules/OnDevicePersonalization:__subpackages__" ], } + +cc_library_shared { + name: "libfcp_cpp_dep_jni", + srcs: [ + "jni/cpp/*.cc", + ], + version_script: "jni/jni.lds", + whole_static_libs: [ + "libfederatedcompute", + ], + static_libs:[ + "federated-compute-cc-proto-lite", + "libprotobuf-cpp-lite-ndk", + ], + header_libs: [ + "libnativehelper_header_only", + "libeigen", + ], + shared_libs: [ + "liblog", + "libcrypto", + ], + stl: "libc++_static", + apex_available: ["com.android.ondevicepersonalization"], + sdk_version: "current", + min_sdk_version: "Tiramisu", + visibility: [ + "//packages/modules/OnDevicePersonalization:__subpackages__", + ], + cflags: [ + "-Wno-unused-parameter", + ], +} diff --git a/federatedcompute/apk/Android.bp b/federatedcompute/apk/Android.bp index 228d0060..a8c73204 100644 --- a/federatedcompute/apk/Android.bp +++ b/federatedcompute/apk/Android.bp @@ -21,9 +21,17 @@ genrule { name: "statslog-federatedcompute-java-gen", tools: ["stats-log-api-gen"], cmd: "$(location stats-log-api-gen) --java $(out) --module ondevicepersonalization" + - " --javaPackage com.android.ondevicepersonalization" + + " --javaPackage com.android.federatedcompute.services.stats" + " --javaClass FederatedComputeStatsLog", - out: ["com/android/federatedcompute/services/FederatedComputeStatsLog.java"], + out: ["com/android/federatedcompute/services/stats/FederatedComputeStatsLog.java"], + visibility: [ + "//packages/modules/OnDevicePersonalization:__subpackages__" + ], +} + +filegroup { + name: "fcp-apk-jarjar", + srcs: ["jarjar_rules.txt"] } android_app_certificate { @@ -37,10 +45,10 @@ android_app { ":federatedcompute-sources", ":federatedcompute-fbs", ":statslog-federatedcompute-java-gen", + ":fcp_native_wrapper", ], libs: [ "auto_value_annotations", - "flatbuffers-java", "framework-ondevicepersonalization.impl", "framework-annotations-lib", "framework-configinfrastructure", // For PH flags @@ -48,17 +56,23 @@ android_app { ], plugins: ["auto_value_plugin"], static_libs: [ + "flatbuffers-java", + "androidx.concurrent_concurrent-futures", "federated-compute-java-proto-lite", "guava", - "modules-utils-preconditions", "libprotobuf-java-lite", + "modules-utils-preconditions", ], sdk_version: "module_current", min_sdk_version: "Tiramisu", + jarjar_rules: ":fcp-apk-jarjar", updatable: true, certificate: ":com.android.federatedcompute.certificate", apex_available: ["com.android.ondevicepersonalization"], defaults: [ "federatedcompute-java-defaults", ], + optimize: { + proguard_flags_files: ["proguard.flags"], + }, }
\ No newline at end of file diff --git a/federatedcompute/apk/AndroidManifest.xml b/federatedcompute/apk/AndroidManifest.xml index c6c7aa21..dec46ab3 100644 --- a/federatedcompute/apk/AndroidManifest.xml +++ b/federatedcompute/apk/AndroidManifest.xml @@ -20,6 +20,11 @@ android:versionCode="1" android:versionName="T-initial"> + <!-- Define the permission to call federated compute clients implemented ExampleStoreService --> + <permission + android:name="android.permission.BIND_EXAMPLE_STORE_SERVICE" + android:protectionLevel="signature" /> + <!-- Required for persisting scheduled jobs --> <uses-permission android:name="android.permission.RECEIVE_BOOT_COMPLETED" /> <!-- Used for scheduling connectivity jobs --> @@ -27,11 +32,14 @@ <uses-permission android:name="android.permission.INTERNET" /> <!-- Required for reading device configs --> <uses-permission android:name="android.permission.READ_DEVICE_CONFIG"/> + <!-- Permission to call federated compute clients implemented ExampleStoreService --> + <uses-permission + android:name="android.permission.BIND_EXAMPLE_STORE_SERVICE" /> - <application> - <service android:name=".FederatedComputeServiceImpl" android:exported="true" > + <application android:forceQueryable="true"> + <service android:name=".FederatedComputeManagingServiceImpl" android:exported="true" > <intent-filter> - <action android:name="com.android.federatedcompute.FederatedComputeService" /> + <action android:name="android.federatedcompute.FederatedComputeService" /> </intent-filter> </service> <!-- The JobService runs in main process, so when JobScheduler wakes up, it allows us to diff --git a/federatedcompute/apk/jarjar_rules.txt b/federatedcompute/apk/jarjar_rules.txt new file mode 100644 index 00000000..711d1d2c --- /dev/null +++ b/federatedcompute/apk/jarjar_rules.txt @@ -0,0 +1 @@ +rule com.android.internal.util.** com.android.ondevicepersonalization.internal.util.masked.@1
\ No newline at end of file diff --git a/federatedcompute/apk/proguard.flags b/federatedcompute/apk/proguard.flags new file mode 100644 index 00000000..db6acaed --- /dev/null +++ b/federatedcompute/apk/proguard.flags @@ -0,0 +1,3 @@ +# Keep all impl classes referenced via JNI. +-keep class com.android.federatedcompute.services.training.jni.** { *; } +-keep class com.android.federatedcompute.services.training.util.** { *; }
\ No newline at end of file diff --git a/federatedcompute/jni/cpp/example_iterator_wrapper_impl.cc b/federatedcompute/jni/cpp/example_iterator_wrapper_impl.cc new file mode 100644 index 00000000..eb9f77d6 --- /dev/null +++ b/federatedcompute/jni/cpp/example_iterator_wrapper_impl.cc @@ -0,0 +1,117 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "example_iterator_wrapper_impl.h" + +#include <jni.h> + +#include <string> + +#include "fcp/jni/jni_util.h" +#include "fcp/protos/federatedcompute/common.pb.h" +#include "more_jni_util.h" +#include "nativehelper/scoped_local_ref.h" + +namespace fcp { +namespace client { +namespace engine { +namespace jni { + +using ::fcp::jni::JavaMethodSig; +using ::fcp::jni::ParseProtoFromJByteArray; +using ::fcp::jni::ScopedJniEnv; +using ::fcp::jni::SerializeProtoToJByteArray; + +struct JavaExampleIteratorClassDesc { + static constexpr JavaMethodSig kNext = {"next", "()[B"}; + static constexpr JavaMethodSig kClose = {"close", "()V"}; +}; + +ExampleIteratorWrapperImpl::ExampleIteratorWrapperImpl(JavaVM *jvm, + jobject example_iterator) + : jvm_(jvm) { + { + std::lock_guard<std::mutex> lock(close_mu_); + closed_ = false; + } + ScopedJniEnv scoped_env(jvm_); + JNIEnv *env = scoped_env.env(); + jthis_ = env->NewGlobalRef(example_iterator); + FCP_CHECK(jthis_ != nullptr); + + ScopedLocalRef<jclass> example_iterator_class( + env, env->GetObjectClass(example_iterator)); + FCP_CHECK(!MoreJniUtil::CheckForJniException(env)); + FCP_CHECK(example_iterator_class.get() != nullptr); + + next_id_ = MoreJniUtil::GetMethodIdOrAbort( + env, example_iterator_class.get(), JavaExampleIteratorClassDesc::kNext); + close_id_ = MoreJniUtil::GetMethodIdOrAbort( + env, example_iterator_class.get(), JavaExampleIteratorClassDesc::kClose); +} + +ExampleIteratorWrapperImpl::~ExampleIteratorWrapperImpl() { + // This ensures that the java iterator instance is released when the + // TensorFlow session is freed. + Close(); + ScopedJniEnv scoped_env(jvm_); + JNIEnv *env = scoped_env.env(); + env->DeleteGlobalRef(jthis_); +} + +absl::StatusOr<std::string> ExampleIteratorWrapperImpl::Next() { + { + std::lock_guard<std::mutex> lock(close_mu_); + if (closed_) { + return absl::InternalError("Next() called on closed iterator."); + } + } + ScopedJniEnv scoped_env(jvm_); + JNIEnv *env = scoped_env.env(); + + ScopedLocalRef<jbyteArray> example( + env, (jbyteArray)env->CallObjectMethod(jthis_, next_id_)); + FCP_RETURN_IF_ERROR( + MoreJniUtil::GetExceptionStatus(env, "call JavaExampleIterator.Next()")); + FCP_CHECK(example.get() != nullptr); + + int result_size = env->GetArrayLength(example.get()); + if (result_size == 0) { + return absl::OutOfRangeError("end of iterator reached"); + } + std::string example_string = + MoreJniUtil::JByteArrayToString(env, example.get()); + return example_string; +} + +// Close the iterator to release associated resources. +void ExampleIteratorWrapperImpl::Close() { + std::lock_guard<std::mutex> lock(close_mu_); + if (closed_) { + return; + } + ScopedJniEnv scoped_env(jvm_); + JNIEnv *env = scoped_env.env(); + + env->CallVoidMethod(jthis_, close_id_); + FCP_CHECK(!MoreJniUtil::CheckForJniException(env)); + closed_ = true; +} + +} // namespace jni +} // namespace engine +} // namespace client +} // namespace fcp diff --git a/federatedcompute/jni/cpp/example_iterator_wrapper_impl.h b/federatedcompute/jni/cpp/example_iterator_wrapper_impl.h new file mode 100644 index 00000000..a4df061c --- /dev/null +++ b/federatedcompute/jni/cpp/example_iterator_wrapper_impl.h @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include <jni.h> + +#include <optional> +#include <string> + +#include "absl/status/statusor.h" +#include "fcp/client/flags.h" +#include "fcp/client/log_manager.h" +#include "fcp/client/simple_task_environment.h" + +namespace fcp { +namespace client { +namespace engine { +namespace jni { + +// A wrapper around the Java example iterator class. Object is only valid for +// the time of a JNI call and must not be kept around for any longer. Can be +// created on an arbitrary thread. +class ExampleIteratorWrapperImpl : public ExampleIterator { + public: + ExampleIteratorWrapperImpl(JavaVM *jvm, jobject example_iterator); + + ~ExampleIteratorWrapperImpl() override; + + absl::StatusOr<std::string> Next() override; + + void Close() override final; + + private: + jmethodID next_id_; + jmethodID close_id_; + std::mutex close_mu_; + bool closed_ GUARDED_BY(close_mu_); + JavaVM *const jvm_; + jobject jthis_; +}; + +} // namespace jni +} // namespace engine +} // namespace client +} // namespace fcp diff --git a/federatedcompute/jni/cpp/fl_runner_jni.cc b/federatedcompute/jni/cpp/fl_runner_jni.cc new file mode 100644 index 00000000..10a373da --- /dev/null +++ b/federatedcompute/jni/cpp/fl_runner_jni.cc @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <jni.h> + +#include "example_iterator_wrapper_impl.h" +#include "fcp/client/fcp_runner.h" +#include "fcp/client/fl_runner.pb.h" +#include "fcp/client/interruptible_runner.h" +#include "fcp/jni/jni_util.h" +#include "flags_impl.h" +#include "log_manager_wrapper_impl.h" +#include "more_jni_util.h" +#include "simple_task_environment_wrapper_impl.h" + +#define JFUN(METHOD_NAME) \ + Java_com_android_federatedcompute_services_training_jni_FlRunnerWrapper_##METHOD_NAME // NOLINT + +using fcp::jni::ParseProtoFromJByteArray; + +extern "C" JNIEXPORT jbyteArray JNICALL JFUN(runNativeFederatedComputation)( + JNIEnv *env, jclass, jobject java_simple_task_env, + jstring population_name_jstring, jstring session_name_jstring, + jstring task_name_jstring, jobject java_native_log_manager, + jbyteArray client_only_plan_bytes, + jstring checkpoint_input_filename_jstring, + jstring checkpoint_output_filename_jstring) { + google::internal::federated::plan::ClientOnlyPlan client_only_plan = + ParseProtoFromJByteArray< + google::internal::federated::plan::ClientOnlyPlan>( + env, client_only_plan_bytes); + + const fcp::client::engine::jni::FlagsImpl flags; + JavaVM *jvm = MoreJniUtil::getJavaVm(env); + fcp::client::engine::jni::SimpleTaskEnvironmentWrapperImpl + simple_task_env_impl(jvm, java_simple_task_env); + + std::string population_name = + MoreJniUtil::JStringToString(env, population_name_jstring); + std::string session_name = + MoreJniUtil::JStringToString(env, session_name_jstring); + std::string task_name = MoreJniUtil::JStringToString(env, task_name_jstring); + // TODO: add real implementation of log manager. + fcp::client::engine::jni::LogManagerWrapperImpl log_manager_impl( + jvm, java_native_log_manager); + std::string checkpoint_input_filename = + MoreJniUtil::JStringToString(env, checkpoint_input_filename_jstring); + std::string checkpoint_output_filename = + MoreJniUtil::JStringToString(env, checkpoint_output_filename_jstring); + fcp::client::InterruptibleRunner::TimingConfig timing_config = { + .polling_period = absl::Milliseconds(1000), + }; + + absl::StatusOr<fcp::client::FLRunnerResult> fl_runner_result = + fcp::client::RunFederatedComputation( + &simple_task_env_impl, &log_manager_impl, &flags, client_only_plan, + checkpoint_input_filename, checkpoint_output_filename, session_name, + population_name, task_name, timing_config); + FCP_CHECK(fl_runner_result.ok()); + // Serialize + return the FLRunnerResult. + jbyteArray fl_runner_result_serialized = + fcp::jni::SerializeProtoToJByteArray(env, fl_runner_result.value()); + return fl_runner_result_serialized; +} diff --git a/federatedcompute/jni/cpp/flags_impl.h b/federatedcompute/jni/cpp/flags_impl.h new file mode 100644 index 00000000..7d0a5546 --- /dev/null +++ b/federatedcompute/jni/cpp/flags_impl.h @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <jni.h> + +#include <optional> +#include <string> + +#include "fcp/client/flags.h" + +namespace fcp { +namespace client { +namespace engine { +namespace jni { + +class FlagsImpl : public fcp::client::Flags { + public: + bool use_tflite_training() const override { return true; } + bool use_http_federated_compute_protocol() const override { return true; } + bool enable_opstats() const override { return false; } + + int64_t condition_polling_period_millis() const override { return 1000; } + int64_t tf_execution_teardown_grace_period_millis() const override { + return 1000; + } + int64_t tf_execution_teardown_extended_period_millis() const override { + return 2000; + } + int64_t grpc_channel_deadline_seconds() const override { return 0; } + bool log_tensorflow_error_messages() const override { return true; } + bool enable_example_query_plan_engine() const override { return true; } +}; + +} // namespace jni +} // namespace engine +} // namespace client +} // namespace fcp diff --git a/federatedcompute/jni/cpp/log_manager_wrapper_impl.cc b/federatedcompute/jni/cpp/log_manager_wrapper_impl.cc new file mode 100644 index 00000000..9dafa2d5 --- /dev/null +++ b/federatedcompute/jni/cpp/log_manager_wrapper_impl.cc @@ -0,0 +1,126 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "log_manager_wrapper_impl.h" + +#include <jni.h> + +#include <string> + +#include "fcp/jni/jni_util.h" +#include "fcp/protos/federatedcompute/common.pb.h" +#include "more_jni_util.h" +#include "nativehelper/scoped_local_ref.h" + +namespace fcp { +namespace client { +namespace engine { +namespace jni { + +using ::fcp::jni::JavaMethodSig; +using ::fcp::jni::ParseProtoFromJByteArray; +using ::fcp::jni::ScopedJniEnv; +using ::fcp::jni::SerializeProtoToJByteArray; + +struct LogManagerClassDesc { + static constexpr JavaMethodSig kLogProdDiag = {"logProdDiag", "(I)V"}; + static constexpr JavaMethodSig kLogDebugDiag = {"logDebugDiag", "(I)V"}; + static constexpr JavaMethodSig kLogToLongHistogram = {"logToLongHistogram", + "(IIIIJ)V"}; + static constexpr JavaMethodSig kLogToLongHistogramWithModelIdentifier = { + "logToLongHistogram", "(IIIILjava/lang/String;J)V"}; +}; + +LogManagerWrapperImpl::LogManagerWrapperImpl(JavaVM* jvm, + jobject java_log_manager) + : jvm_(jvm) { + ScopedJniEnv scoped_env(jvm_); + JNIEnv* env = scoped_env.env(); + jthis_ = env->NewGlobalRef(java_log_manager); + FCP_CHECK(jthis_ != nullptr); + + ScopedLocalRef<jclass> java_log_manager_class( + env, env->GetObjectClass(java_log_manager)); + FCP_CHECK(java_log_manager_class.get() != nullptr); + + log_prod_diag_id_ = MoreJniUtil::GetMethodIdOrAbort( + env, java_log_manager_class.get(), LogManagerClassDesc::kLogProdDiag); + + log_debug_diag_id_ = MoreJniUtil::GetMethodIdOrAbort( + env, java_log_manager_class.get(), LogManagerClassDesc::kLogDebugDiag); + + log_to_long_histogram_id_ = + MoreJniUtil::GetMethodIdOrAbort(env, java_log_manager_class.get(), + LogManagerClassDesc::kLogToLongHistogram); + log_to_long_histogram_with_model_identifier_id_ = + MoreJniUtil::GetMethodIdOrAbort( + env, java_log_manager_class.get(), + LogManagerClassDesc::kLogToLongHistogramWithModelIdentifier); +} + +LogManagerWrapperImpl::~LogManagerWrapperImpl() { + ScopedJniEnv scoped_env(jvm_); + JNIEnv* env = scoped_env.env(); + env->DeleteGlobalRef(jthis_); +} + +void LogManagerWrapperImpl::LogDiag(ProdDiagCode diag_code) { + ScopedJniEnv scoped_env(jvm_); + JNIEnv* env = scoped_env.env(); + env->CallVoidMethod(jthis_, log_prod_diag_id_, static_cast<jint>(diag_code)); + FCP_CHECK(!MoreJniUtil::CheckForJniException(env)); +} + +void LogManagerWrapperImpl::LogDiag(DebugDiagCode diag_code) { + ScopedJniEnv scoped_env(jvm_); + JNIEnv* env = scoped_env.env(); + env->CallVoidMethod(jthis_, log_debug_diag_id_, static_cast<jint>(diag_code)); + FCP_CHECK(!MoreJniUtil::CheckForJniException(env)); +} + +void LogManagerWrapperImpl::LogToLongHistogram( + HistogramCounters histogram_counter, int execution_index, int epoch_index, + DataSourceType data_source_type, int64_t value) { + ScopedJniEnv scoped_env(jvm_); + JNIEnv* env = scoped_env.env(); + if (model_identifier_.has_value()) { + ScopedLocalRef<jstring> model_identifier_jstring( + env, env->NewStringUTF(model_identifier_.value().c_str())); + FCP_CHECK(!MoreJniUtil::CheckForJniException(env)); + + env->CallVoidMethod(jthis_, log_to_long_histogram_with_model_identifier_id_, + static_cast<jint>(histogram_counter), execution_index, + epoch_index, static_cast<jint>(data_source_type), + model_identifier_jstring.get(), + static_cast<jlong>(value)); + } else { + env->CallVoidMethod(jthis_, log_to_long_histogram_id_, + static_cast<jint>(histogram_counter), execution_index, + epoch_index, static_cast<jint>(data_source_type), + static_cast<jlong>(value)); + } + FCP_CHECK(!MoreJniUtil::CheckForJniException(env)); +} + +void LogManagerWrapperImpl::SetModelIdentifier( + const std::string& model_identifier) { + model_identifier_ = model_identifier; +} + +} // namespace jni +} // namespace engine +} // namespace client +} // namespace fcp diff --git a/federatedcompute/jni/cpp/log_manager_wrapper_impl.h b/federatedcompute/jni/cpp/log_manager_wrapper_impl.h new file mode 100644 index 00000000..44f72171 --- /dev/null +++ b/federatedcompute/jni/cpp/log_manager_wrapper_impl.h @@ -0,0 +1,69 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <jni.h> + +#include <optional> +#include <string> + +#include "example_iterator_wrapper_impl.h" +#include "fcp/client/flags.h" +#include "fcp/client/log_manager.h" + +namespace fcp { +namespace client { +namespace engine { +namespace jni { + +// A wrapper around the LogManager Java class. Can be created on an +// arbitrary thread. +class LogManagerWrapperImpl : public LogManager { + public: + LogManagerWrapperImpl(const LogManagerWrapperImpl &) = delete; + void *operator new(std::size_t) = delete; + LogManagerWrapperImpl &operator=(const LogManagerWrapperImpl &) = delete; + + LogManagerWrapperImpl(JavaVM *jvm, jobject java_log_manager); + + ~LogManagerWrapperImpl() override; + + void LogDiag(ProdDiagCode diag_code) override; + + void LogDiag(DebugDiagCode diag_code) override; + + void LogToLongHistogram(HistogramCounters histogram_counter, + int execution_index, int epoch_index, + DataSourceType data_source_type, + int64_t value) override; + + void SetModelIdentifier(const std::string &model_identifier) override; + + private: + JavaVM *const jvm_; + jobject jthis_; + jmethodID log_prod_diag_id_; + jmethodID log_debug_diag_id_; + jmethodID log_to_long_histogram_id_; + jmethodID log_to_long_histogram_with_model_identifier_id_; + std::optional<std::string> model_identifier_; +}; + +} // namespace jni +} // namespace engine +} // namespace client +} // namespace fcp diff --git a/federatedcompute/jni/cpp/more_jni_util.h b/federatedcompute/jni/cpp/more_jni_util.h new file mode 100644 index 00000000..c7789570 --- /dev/null +++ b/federatedcompute/jni/cpp/more_jni_util.h @@ -0,0 +1,91 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <jni.h> + +#include "fcp/base/monitoring.h" +#include "fcp/jni/jni_util.h" +#include "fcp/protos/federatedcompute/common.pb.h" + +class MoreJniUtil { + public: + static jmethodID GetMethodIdOrAbort(JNIEnv *env, jclass clazz, + fcp::jni::JavaMethodSig method) { + jmethodID id = env->GetMethodID(clazz, method.name, method.signature); + FCP_CHECK(!CheckForJniException(env)); + FCP_CHECK(id != nullptr); + return id; + } + + static bool CheckForJniException(JNIEnv *env) { + if (env->ExceptionCheck()) { + env->ExceptionDescribe(); + env->ExceptionClear(); + return true; + } + return false; + } + + static absl::Status GetExceptionStatus(JNIEnv *env, const char *msg) { + if (env->ExceptionCheck()) { + FCP_LOG(WARNING) << "GetExceptionStatus [" << msg + << "] Java exception follows"; + env->ExceptionDescribe(); + env->ExceptionClear(); + // We gracefully return AbortedError to c++ code instead of crash app. The + // underlying c++ code only checks if status is ok, if not it will return + // error status code to java. + return absl::AbortedError(absl::StrCat("Got Exception when ", msg)); + } else { + return absl::OkStatus(); + } + } + + static std::string JStringToString(JNIEnv *env, jstring jstr) { + if (jstr == nullptr) { + return std::string(); + } + const char *cstring = env->GetStringUTFChars(jstr, nullptr); + FCP_CHECK(!env->ExceptionCheck()); + std::string result(cstring); + env->ReleaseStringUTFChars(jstr, cstring); + return result; + } + + static std::string JByteArrayToString(JNIEnv *env, jbyteArray array) { + FCP_CHECK(array != nullptr); + std::string result; + int array_length = env->GetArrayLength(array); + FCP_CHECK(!env->ExceptionCheck()); + + result.resize(array_length); + env->GetByteArrayRegion( + array, 0, array_length, + reinterpret_cast<jbyte *>(const_cast<char *>(result.data()))); + FCP_CHECK(!env->ExceptionCheck()); + + return result; + } + + static JavaVM *getJavaVm(JNIEnv *env) { + JavaVM *jvm = nullptr; + env->GetJavaVM(&jvm); + FCP_CHECK(jvm != nullptr); + return jvm; + } +}; diff --git a/federatedcompute/jni/cpp/simple_task_environment_wrapper_impl.cc b/federatedcompute/jni/cpp/simple_task_environment_wrapper_impl.cc new file mode 100644 index 00000000..127bd819 --- /dev/null +++ b/federatedcompute/jni/cpp/simple_task_environment_wrapper_impl.cc @@ -0,0 +1,109 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "simple_task_environment_wrapper_impl.h" + +#include "example_iterator_wrapper_impl.h" +#include "fcp/protos/federatedcompute/common.pb.h" +#include "more_jni_util.h" +#include "nativehelper/scoped_local_ref.h" + +namespace fcp { +namespace client { +namespace engine { +namespace jni { + +using ::fcp::jni::ParseProtoFromJByteArray; +using ::fcp::jni::ScopedJniEnv; +using ::fcp::jni::SerializeProtoToJByteArray; +using ::google::internal::federated::plan::ExampleSelector; + +SimpleTaskEnvironmentWrapperImpl::SimpleTaskEnvironmentWrapperImpl( + JavaVM *jvm, jobject simple_task_env) + : jvm_(jvm) { + ScopedJniEnv scoped_env(jvm_); + JNIEnv *env = scoped_env.env(); + jthis_ = env->NewGlobalRef(simple_task_env); + FCP_CHECK(jthis_ != nullptr); + + ScopedLocalRef<jclass> simple_task_env_class( + env, env->GetObjectClass(simple_task_env)); + FCP_CHECK(simple_task_env_class.get() != nullptr); + + training_conditions_satisfied_id_ = MoreJniUtil::GetMethodIdOrAbort( + env, simple_task_env_class.get(), + SimpleTaskEnvironmentImplClassDesc::kTrainingConditionsSatisfied); + create_example_iterator_id_ = MoreJniUtil::GetMethodIdOrAbort( + env, simple_task_env_class.get(), + SimpleTaskEnvironmentImplClassDesc::kCreateExampleIterator); +} + +SimpleTaskEnvironmentWrapperImpl::~SimpleTaskEnvironmentWrapperImpl() { + ScopedJniEnv scoped_env(jvm_); + JNIEnv *env = scoped_env.env(); + env->DeleteGlobalRef(jthis_); +} + +absl::StatusOr<std::unique_ptr<ExampleIterator>> +SimpleTaskEnvironmentWrapperImpl::CreateExampleIterator( + const ExampleSelector &example_selector, + const SelectorContext &selector_context) { + ScopedJniEnv scoped_env(jvm_); + JNIEnv *env = scoped_env.env(); + ScopedLocalRef<jbyteArray> serialized_example_selector( + env, SerializeProtoToJByteArray(env, example_selector)); + ScopedLocalRef<jbyteArray> serialized_selector_context( + env, SerializeProtoToJByteArray(env, selector_context)); + + jobject java_example_iterator = env->CallObjectMethod( + jthis_, create_example_iterator_id_, serialized_example_selector.get(), + serialized_selector_context.get()); + FCP_RETURN_IF_ERROR( + MoreJniUtil::GetExceptionStatus(env, "call JavaCreateExampleIterator")); + FCP_CHECK(java_example_iterator != nullptr); + // TODO(b/301323421): check if abseil can be used in JNI layer. + return absl::StatusOr<std::unique_ptr<ExampleIterator>>( + std::make_unique<ExampleIteratorWrapperImpl>(jvm_, + java_example_iterator)); +} + +absl::StatusOr<std::unique_ptr<ExampleIterator>> +SimpleTaskEnvironmentWrapperImpl::CreateExampleIterator( + const ExampleSelector &example_selector) { + return CreateExampleIterator(example_selector, SelectorContext()); +} + +// Don't call to java implementation because training runs in isolated process +// which doesn't have file system access. +std::string SimpleTaskEnvironmentWrapperImpl::GetBaseDir() { return ""; } + +// Don't call to java implementation because training runs in isolated process +// which doesn't have file system access. +std::string SimpleTaskEnvironmentWrapperImpl::GetCacheDir() { return ""; } + +bool SimpleTaskEnvironmentWrapperImpl::TrainingConditionsSatisfied() { + ScopedJniEnv scoped_env(jvm_); + JNIEnv *env = scoped_env.env(); + bool training_conditions_satisfied = + env->CallBooleanMethod(jthis_, training_conditions_satisfied_id_); + FCP_CHECK(!MoreJniUtil::CheckForJniException(env)); + return training_conditions_satisfied; +} + +} // namespace jni +} // namespace engine +} // namespace client +} // namespace fcp
\ No newline at end of file diff --git a/federatedcompute/jni/cpp/simple_task_environment_wrapper_impl.h b/federatedcompute/jni/cpp/simple_task_environment_wrapper_impl.h new file mode 100644 index 00000000..5742f086 --- /dev/null +++ b/federatedcompute/jni/cpp/simple_task_environment_wrapper_impl.h @@ -0,0 +1,83 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <jni.h> + +#include <memory> + +#include "absl/status/statusor.h" +#include "fcp/client/simple_task_environment.h" +#include "fcp/jni/jni_util.h" +#include "fcp/protos/plan.pb.h" + +/** + * A wrapper around a Java class for callbacks into Java. This is required for + * functionality we do not provide in c++ code, such as Clearcut logging, + * checking for device conditions, or querying apps for examples. + */ + +namespace fcp { +namespace client { +namespace engine { +namespace jni { + +using ::google::internal::federated::plan::ExampleSelector; + +// Descriptions of Java classes we call into - full class path, and method +// names + JNI signatures. +struct SimpleTaskEnvironmentImplClassDesc { + static constexpr fcp::jni::JavaMethodSig kTrainingConditionsSatisfied = { + "trainingConditionsSatisfied", "()Z"}; + static constexpr fcp::jni::JavaMethodSig kCreateExampleIterator = { + "createExampleIteratorWithContext", + "([B[B)Lcom/android/federatedcompute/services/training/jni/" + "JavaExampleIterator;"}; +}; + +class SimpleTaskEnvironmentWrapperImpl : public SimpleTaskEnvironment { + public: + explicit SimpleTaskEnvironmentWrapperImpl(JavaVM *jvm, + jobject simple_task_env); + + ~SimpleTaskEnvironmentWrapperImpl() override; + + std::string GetBaseDir() override; + + std::string GetCacheDir() override; + + bool TrainingConditionsSatisfied() override; + + absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator( + const google::internal::federated::plan::ExampleSelector + &example_selector) override; + + absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator( + const ExampleSelector &example_selector, + const SelectorContext &selector_context) override; + + private: + JavaVM *jvm_; + jobject jthis_; + jmethodID training_conditions_satisfied_id_; + jmethodID create_example_iterator_id_; +}; + +} // namespace jni +} // namespace engine +} // namespace client +} // namespace fcp diff --git a/federatedcompute/jni/jni.lds b/federatedcompute/jni/jni.lds new file mode 100644 index 00000000..73ec879e --- /dev/null +++ b/federatedcompute/jni/jni.lds @@ -0,0 +1,9 @@ +VERS_1.0 { + # Export JNI symbols. + global: + Java_*; + + # Hide everything else. + local: + *; +};
\ No newline at end of file diff --git a/federatedcompute/src/com/android/federatedcompute/services/FederatedComputeManagingServiceDelegate.java b/federatedcompute/src/com/android/federatedcompute/services/FederatedComputeManagingServiceDelegate.java index 5e2540bc..8537ba4c 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/FederatedComputeManagingServiceDelegate.java +++ b/federatedcompute/src/com/android/federatedcompute/services/FederatedComputeManagingServiceDelegate.java @@ -16,16 +16,28 @@ package com.android.federatedcompute.services; +import static android.federatedcompute.common.ClientConstants.STATUS_INTERNAL_ERROR; +import static android.federatedcompute.common.ClientConstants.STATUS_SUCCESS; + +import static com.android.federatedcompute.services.stats.FederatedComputeStatsLog.FEDERATED_COMPUTE_API_CALLED__API_NAME__CANCEL; +import static com.android.federatedcompute.services.stats.FederatedComputeStatsLog.FEDERATED_COMPUTE_API_CALLED__API_NAME__SCHEDULE; + import android.annotation.NonNull; import android.content.Context; import android.federatedcompute.aidl.IFederatedComputeCallback; import android.federatedcompute.aidl.IFederatedComputeService; import android.federatedcompute.common.TrainingOptions; import android.os.Binder; +import android.os.RemoteException; +import com.android.federatedcompute.internal.util.LogUtil; +import com.android.federatedcompute.services.common.Clock; import com.android.federatedcompute.services.common.FederatedComputeExecutors; import com.android.federatedcompute.services.common.FlagsFactory; +import com.android.federatedcompute.services.common.MonotonicClock; import com.android.federatedcompute.services.scheduling.FederatedComputeJobManager; +import com.android.federatedcompute.services.statsd.ApiCallStats; +import com.android.federatedcompute.services.statsd.FederatedComputeStatsdLogger; import com.google.common.annotations.VisibleForTesting; @@ -35,6 +47,8 @@ import java.util.Objects; public class FederatedComputeManagingServiceDelegate extends IFederatedComputeService.Stub { private static final String TAG = "FcpServiceDelegate"; @NonNull private final Context mContext; + private final FederatedComputeStatsdLogger mFcStatsdLogger; + private final Clock mClock; @VisibleForTesting static class Injector { @@ -45,19 +59,25 @@ public class FederatedComputeManagingServiceDelegate extends IFederatedComputeSe @NonNull private final Injector mInjector; - public FederatedComputeManagingServiceDelegate(@NonNull Context context) { - this(context, new Injector()); + public FederatedComputeManagingServiceDelegate( + @NonNull Context context, FederatedComputeStatsdLogger federatedComputeStatsdLogger) { + this(context, new Injector(), federatedComputeStatsdLogger, MonotonicClock.getInstance()); } @VisibleForTesting public FederatedComputeManagingServiceDelegate( - @NonNull Context context, @NonNull Injector injector) { + @NonNull Context context, + @NonNull Injector injector, + FederatedComputeStatsdLogger federatedComputeStatsdLogger, + Clock clock) { mContext = Objects.requireNonNull(context); mInjector = Objects.requireNonNull(injector); + mClock = clock; + this.mFcStatsdLogger = federatedComputeStatsdLogger; } @Override - public void scheduleFederatedCompute( + public void schedule( String callingPackageName, TrainingOptions trainingOptions, IFederatedComputeCallback callback) { @@ -65,19 +85,98 @@ public class FederatedComputeManagingServiceDelegate extends IFederatedComputeSe // READ_DEVICE_CONFIG permission. long origId = Binder.clearCallingIdentity(); if (FlagsFactory.getFlags().getGlobalKillSwitch()) { + throw new IllegalStateException( + "FederatedComputeService skipped as the global kill switch is on."); + } + Binder.restoreCallingIdentity(origId); + + Objects.requireNonNull(callingPackageName); + Objects.requireNonNull(callback); + + final long startServiceTime = mClock.elapsedRealtime(); + FederatedComputeJobManager jobManager = mInjector.getJobManager(mContext); + FederatedComputeExecutors.getBackgroundExecutor() + .execute( + () -> { + int resultCode = STATUS_SUCCESS; + try { + resultCode = + jobManager.onTrainerStartCalled( + callingPackageName, trainingOptions); + } catch (Exception e) { + resultCode = STATUS_INTERNAL_ERROR; + LogUtil.e(TAG, "Got exception for schedule()", e); + } finally { + sendResult(callback, resultCode); + int serviceLatency = + (int) (mClock.elapsedRealtime() - startServiceTime); + mFcStatsdLogger.logApiCallStats( + new ApiCallStats.Builder() + .setApiName( + FEDERATED_COMPUTE_API_CALLED__API_NAME__SCHEDULE) + .setLatencyMillis(serviceLatency) + .setResponseCode(resultCode) + .build()); + } + }); + } + + @Override + public void cancel( + String callingPackageName, String populationName, IFederatedComputeCallback callback) { + // Use FederatedCompute instead of caller permission to read experiment flags. It requires + // READ_DEVICE_CONFIG permission. + long origId = Binder.clearCallingIdentity(); + if (FlagsFactory.getFlags().getGlobalKillSwitch()) { throw new IllegalStateException("Service skipped as the global kill switch is on."); } Binder.restoreCallingIdentity(origId); Objects.requireNonNull(callingPackageName); Objects.requireNonNull(callback); + Objects.requireNonNull(populationName); + final long startServiceTime = mClock.elapsedRealtime(); FederatedComputeJobManager jobManager = mInjector.getJobManager(mContext); FederatedComputeExecutors.getBackgroundExecutor() .execute( () -> { - jobManager.onTrainerStartCalled( - callingPackageName, trainingOptions, callback); + int resultCode = STATUS_SUCCESS; + try { + resultCode = + jobManager.onTrainerStopCalled( + callingPackageName, populationName); + } catch (Exception e) { + resultCode = STATUS_INTERNAL_ERROR; + LogUtil.e( + TAG, + e, + "Got exception when call Cancel %s", + populationName); + } finally { + sendResult(callback, resultCode); + int serviceLatency = + (int) (mClock.elapsedRealtime() - startServiceTime); + mFcStatsdLogger.logApiCallStats( + new ApiCallStats.Builder() + .setApiName( + FEDERATED_COMPUTE_API_CALLED__API_NAME__CANCEL) + .setLatencyMillis(serviceLatency) + .setResponseCode(resultCode) + .build()); + } }); } + + private void sendResult(@NonNull IFederatedComputeCallback callback, int resultCode) { + try { + if (resultCode == STATUS_SUCCESS) { + callback.onSuccess(); + return; + } + callback.onFailure(resultCode); + } catch (RemoteException e) { + LogUtil.e(TAG, e, "Callback error"); + } + } } diff --git a/federatedcompute/src/com/android/federatedcompute/services/FederatedComputeManagingServiceImpl.java b/federatedcompute/src/com/android/federatedcompute/services/FederatedComputeManagingServiceImpl.java index 2f0234bf..7d8db4d8 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/FederatedComputeManagingServiceImpl.java +++ b/federatedcompute/src/com/android/federatedcompute/services/FederatedComputeManagingServiceImpl.java @@ -20,6 +20,8 @@ import android.app.Service; import android.content.Intent; import android.os.IBinder; +import com.android.federatedcompute.services.statsd.FederatedComputeStatsdLogger; + import java.util.Objects; /** Implementation of FederatedCompute Service */ @@ -30,7 +32,9 @@ public class FederatedComputeManagingServiceImpl extends Service { public void onCreate() { super.onCreate(); if (mFcpServiceDelegate == null) { - mFcpServiceDelegate = new FederatedComputeManagingServiceDelegate(this); + mFcpServiceDelegate = + new FederatedComputeManagingServiceDelegate( + this, FederatedComputeStatsdLogger.getInstance()); } } diff --git a/federatedcompute/src/com/android/federatedcompute/services/common/BatteryInfo.java b/federatedcompute/src/com/android/federatedcompute/services/common/BatteryInfo.java index 7bfe1974..454ed1fb 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/common/BatteryInfo.java +++ b/federatedcompute/src/com/android/federatedcompute/services/common/BatteryInfo.java @@ -20,11 +20,12 @@ import android.content.Context; import android.content.Intent; import android.content.IntentFilter; import android.os.BatteryManager; -import android.util.Log; + +import com.android.federatedcompute.internal.util.LogUtil; /** Checks the battery status of the device. */ public class BatteryInfo { - private static final String TAG = "BatteryInfo"; + private static final String TAG = BatteryInfo.class.getSimpleName(); private final Context mContext; private final Flags mFlags; @@ -41,7 +42,7 @@ public class BatteryInfo { if (level != -1 && scale > 0) { return level / (float) scale; } else { - Log.e(TAG, "Bad Battery Changed intent: batteryLevel=" + level + ", scale=" + scale); + LogUtil.e(TAG, "Bad Battery Changed intent: batteryLevel= %d, scale=%d", level, scale); return -1; } } @@ -61,11 +62,11 @@ public class BatteryInfo { // Check if battery level is sufficient. float minBatteryLevel = mFlags.getTrainingMinBatteryLevel() / 100.0f; if (requireBatteryNotLow && level < minBatteryLevel) { - Log.i( + LogUtil.i( TAG, - String.format( - "Battery level insufficient (%f < %f), skipping training.", - level, minBatteryLevel)); + "Battery level insufficient (%f < %f), skipping training.", + level, + minBatteryLevel); return false; } return true; diff --git a/federatedcompute/src/com/android/federatedcompute/services/common/Constants.java b/federatedcompute/src/com/android/federatedcompute/services/common/Constants.java index 92ed6ec8..33b64603 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/common/Constants.java +++ b/federatedcompute/src/com/android/federatedcompute/services/common/Constants.java @@ -18,25 +18,23 @@ package com.android.federatedcompute.services.common; /** Constants used internally in the FederatedCompute APK. */ public class Constants { - public static final String EXTRA_COLLECTION_NAME = "android.federatedcompute.collection_name"; - public static final String EXTRA_EXAMPLE_ITERATOR_CRITERIA = - "android.federatedcompute.example_iterator_criteria"; - public static final String EXTRA_EXAMPLE_ITERATOR_RESUMPTION_TOKEN = - "android.federatedcompute.example_iterator_resumption_token"; public static final String EXTRA_EXAMPLE_STORE_ITERATOR_BINDER = "android.federatedcompute.example_store_iterator_binder"; - public static final String EXTRA_RESULT_HANDLING_SERVICE_BINDER = - "android.federatedcompute.result_handling_service_binder"; public static final String EXTRA_INPUT_CHECKPOINT_FD = "android.federatedcompute.input_checkpoint_fd"; public static final String EXTRA_OUTPUT_CHECKPOINT_FD = "android.federatedcompute.output_checkpoint_fd"; - public static final String EXTRA_POPULATION_NAME = "android.federatedcompute.population_name"; public static final String EXTRA_FL_RUNNER_RESULT = "android.federatedcompute.fl_runner_result"; public static final String EXTRA_JOB_ID = "android.federatedcompute.job_id"; public static final String EXTRA_EXAMPLE_SELECTOR = "android.federatedcompute.example_selector"; - public static final String EXTRA_CLIENT_ONLY_PLAN = "android.federatedcompute.client_only_plan"; + public static final String EXTRA_CLIENT_ONLY_PLAN_FD = + "android.federatedcompute.client_only_plan_fd"; + + public static final String CLIENT_ONLY_PLAN_FILE_NAME = "federated_client_only_plan"; + + public static final String ISOLATED_TRAINING_SERVICE_NAME = + "com.android.federatedcompute.services.training.IsolatedTrainingService"; private Constants() {} } diff --git a/federatedcompute/src/com/android/federatedcompute/services/common/FileUtils.java b/federatedcompute/src/com/android/federatedcompute/services/common/FileUtils.java new file mode 100644 index 00000000..75490297 --- /dev/null +++ b/federatedcompute/src/com/android/federatedcompute/services/common/FileUtils.java @@ -0,0 +1,102 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.services.common; + +import android.os.ParcelFileDescriptor; + +import com.android.federatedcompute.internal.util.LogUtil; + +import java.io.BufferedInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; + +/** Utils related to {@link File} and {@link ParcelFileDescriptor}. */ +public class FileUtils { + private static final String TAG = FileUtils.class.getSimpleName(); + + private static final int BUFFER_SIZE = 1024; + + /** Create {@link ParcelFileDescriptor} based on the input file. */ + public static ParcelFileDescriptor createTempFileDescriptor(String fileName, int mode) { + ParcelFileDescriptor fileDescriptor; + try { + fileDescriptor = ParcelFileDescriptor.open(new File(fileName), mode); + } catch (IOException e) { + LogUtil.e(TAG, e, "Failed to createTempFileDescriptor %s", fileName); + throw new RuntimeException(e); + } + return fileDescriptor; + } + + /** Create a temporary file based on provided name and extension. */ + public static String createTempFile(String name, String extension) { + String fileName; + try { + File tempFile = File.createTempFile(name, extension); + fileName = tempFile.getAbsolutePath(); + } catch (IOException e) { + throw new RuntimeException(e); + } + return fileName; + } + + /** Write the provided data to the file. */ + public static void writeToFile(String fileName, byte[] data) throws IOException { + FileOutputStream out = new FileOutputStream(fileName); + out.write(data); + out.close(); + } + + /** Read the input file content to a byte array. */ + public static byte[] readFileAsByteArray(String filePath) throws IOException { + File file = new File(filePath); + long fileLength = file.length(); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream((int) fileLength); + try (BufferedInputStream inputStream = new BufferedInputStream(new FileInputStream(file))) { + byte[] buffer = new byte[BUFFER_SIZE]; + for (int len = inputStream.read(buffer); len > 0; len = inputStream.read(buffer)) { + outputStream.write(buffer, 0, len); + } + } catch (IOException e) { + LogUtil.e(TAG, e, "Failed to read the content of binary file %s", filePath); + throw e; + } + return outputStream.toByteArray(); + } + + /** Read the content from a file descriptor to a byte array. */ + public static byte[] readFileDescriptorAsByteArray(ParcelFileDescriptor fd) { + InputStream inputStream = new ParcelFileDescriptor.AutoCloseInputStream(fd); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + + byte[] buffer = new byte[1024]; + int bytesRead; + try { + while ((bytesRead = inputStream.read(buffer)) != -1) { + outputStream.write(buffer, 0, bytesRead); + } + inputStream.close(); + } catch (IOException e) { + LogUtil.e(TAG, e, "Failed to read the content of binary file %d", fd.getFd()); + } + return outputStream.toByteArray(); + } +} diff --git a/federatedcompute/src/com/android/federatedcompute/services/data/FederatedTrainingTaskDbHelper.java b/federatedcompute/src/com/android/federatedcompute/services/data/FederatedComputeDbHelper.java index a4a46766..92ff384c 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/data/FederatedTrainingTaskDbHelper.java +++ b/federatedcompute/src/com/android/federatedcompute/services/data/FederatedComputeDbHelper.java @@ -16,23 +16,25 @@ package com.android.federatedcompute.services.data; +import static com.android.federatedcompute.services.data.FederatedComputeEncryptionKeyContract.ENCRYPTION_KEY_TABLE; import static com.android.federatedcompute.services.data.FederatedTraningTaskContract.FEDERATED_TRAINING_TASKS_TABLE; import android.content.Context; import android.database.sqlite.SQLiteDatabase; import android.database.sqlite.SQLiteOpenHelper; -import android.util.Log; +import com.android.federatedcompute.internal.util.LogUtil; +import com.android.federatedcompute.services.data.FederatedComputeEncryptionKeyContract.FederatedComputeEncryptionColumns; import com.android.federatedcompute.services.data.FederatedTraningTaskContract.FederatedTrainingTaskColumns; import com.android.internal.annotations.VisibleForTesting; /** Helper to manage FederatedTrainingTask database. */ -public class FederatedTrainingTaskDbHelper extends SQLiteOpenHelper { +public class FederatedComputeDbHelper extends SQLiteOpenHelper { - private static final String TAG = "FederatedTrainingTaskDbHelper"; + private static final String TAG = FederatedComputeDbHelper.class.getSimpleName(); private static final int DATABASE_VERSION = 1; - private static final String DATABASE_NAME = "trainingtasks.db"; + private static final String DATABASE_NAME = "federatedcompute.db"; private static final String CREATE_TRAINING_TASK_TABLE = "CREATE TABLE " + FEDERATED_TRAINING_TASKS_TABLE @@ -66,32 +68,51 @@ public class FederatedTrainingTaskDbHelper extends SQLiteOpenHelper { + FederatedTrainingTaskColumns.SCHEDULING_REASON + " INTEGER )"; - private static FederatedTrainingTaskDbHelper sInstance = null; + private static final String CREATE_ENCRYPTION_KEY_TABLE = + "CREATE TABLE " + + ENCRYPTION_KEY_TABLE + + " ( " + + FederatedComputeEncryptionColumns.KEY_IDENTIFIER + + " TEXT PRIMARY KEY, " + + FederatedComputeEncryptionColumns.PUBLIC_KEY + + " TEXT NOT NULL, " + + FederatedComputeEncryptionColumns.KEY_TYPE + + " INTEGER, " + + FederatedComputeEncryptionColumns.CREATION_TIME + + " INTEGER NOT NULL, " + + FederatedComputeEncryptionColumns.EXPIRY_TIME + + " INTEGER NOT NULL)"; - private FederatedTrainingTaskDbHelper(Context context, String dbName) { + private static volatile FederatedComputeDbHelper sInstance = null; + + private FederatedComputeDbHelper(Context context, String dbName) { super(context, dbName, null, DATABASE_VERSION); } - /** Returns an instance of the FederatedTrainingTaskDbHelper given a context. */ - public static FederatedTrainingTaskDbHelper getInstance(Context context) { - synchronized (FederatedTrainingTaskDbHelper.class) { - if (sInstance == null) { - sInstance = new FederatedTrainingTaskDbHelper(context, DATABASE_NAME); + /** Returns an instance of the FederatedComputeDbHelper given a context. */ + public static FederatedComputeDbHelper getInstance(Context context) { + if (sInstance == null) { + synchronized (FederatedComputeDbHelper.class) { + if (sInstance == null) { + sInstance = + new FederatedComputeDbHelper( + context.getApplicationContext(), DATABASE_NAME); + } } - return sInstance; } + return sInstance; } /** - * Returns an instance of the FederatedTrainingTaskDbHelper given a context. This is used for - * testing only. + * Returns an instance of the FederatedComputeDbHelper given a context. This is used for testing + * only. */ @VisibleForTesting - public static FederatedTrainingTaskDbHelper getInstanceForTest(Context context) { - synchronized (FederatedTrainingTaskDbHelper.class) { + public static FederatedComputeDbHelper getInstanceForTest(Context context) { + synchronized (FederatedComputeDbHelper.class) { if (sInstance == null) { // Use null database name to make it in-memory - sInstance = new FederatedTrainingTaskDbHelper(context, null); + sInstance = new FederatedComputeDbHelper(context, null); } return sInstance; } @@ -100,12 +121,13 @@ public class FederatedTrainingTaskDbHelper extends SQLiteOpenHelper { @Override public void onCreate(SQLiteDatabase db) { db.execSQL(CREATE_TRAINING_TASK_TABLE); + db.execSQL(CREATE_ENCRYPTION_KEY_TABLE); } @Override public void onUpgrade(SQLiteDatabase db, int oldVersion, int newVersion) { // TODO: handle upgrade when the db schema is changed. - Log.d(TAG, "DB upgrade from " + oldVersion + " to " + newVersion); + LogUtil.d(TAG, "DB upgrade from %d to %d", oldVersion, newVersion); } @VisibleForTesting @@ -113,6 +135,7 @@ public class FederatedTrainingTaskDbHelper extends SQLiteOpenHelper { // Delete and recreate the database. // These tables must be dropped in order because of database constraints. db.execSQL("DROP TABLE IF EXISTS " + FEDERATED_TRAINING_TASKS_TABLE); + db.execSQL("DROP TABLE IF EXISTS " + ENCRYPTION_KEY_TABLE); onCreate(db); } @@ -124,7 +147,7 @@ public class FederatedTrainingTaskDbHelper extends SQLiteOpenHelper { /** It's only public to testing. */ @VisibleForTesting public static void resetInstance() { - synchronized (FederatedTrainingTaskDbHelper.class) { + synchronized (FederatedComputeDbHelper.class) { if (sInstance != null) { sInstance.close(); sInstance = null; diff --git a/federatedcompute/src/com/android/federatedcompute/services/data/FederatedComputeEncryptionKey.java b/federatedcompute/src/com/android/federatedcompute/services/data/FederatedComputeEncryptionKey.java new file mode 100644 index 00000000..6d0d7b7f --- /dev/null +++ b/federatedcompute/src/com/android/federatedcompute/services/data/FederatedComputeEncryptionKey.java @@ -0,0 +1,339 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.services.data; + +import android.annotation.NonNull; + +import com.android.ondevicepersonalization.internal.util.AnnotationValidations; +import com.android.ondevicepersonalization.internal.util.DataClass; + +import java.io.Serializable; + +/** The details of a federated compute encryption key. */ +@DataClass(genHiddenBuilder = true, genEqualsHashCode = true) +public class FederatedComputeEncryptionKey implements Serializable { + + + /** Define the key type as enum. + * Currently keys are used to encrypt results only. Keys might be used to + * sign (and verify on server) in the future. + */ + public static final int KEY_TYPE_UNDEFINED = 0; + + public static final int KEY_TYPE_ENCRYPTION = 1; + + /** + * @return the key identifier. + */ + @NonNull private final String mKeyIdentifier; + + /** + * @return the public key. + */ + @NonNull private final String mPublicKey; + + /** + * @return the key type. + */ + @KeyType private final int mKeyType; + + /** + * @return the creation time in milliseconds. + */ + private final long mCreationTime; + + /** + * @return the expiry time in milliseconds. + */ + private final long mExpiryTime; + + + + // Code below generated by codegen v1.0.23. + // + // DO NOT MODIFY! + // CHECKSTYLE:OFF Generated code + // + // To regenerate run: + // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/federatedcompute/src/com/android/federatedcompute/services/data/FederatedComputeEncryptionKey.java + // + // To exclude the generated code from IntelliJ auto-formatting enable (one-time): + // Settings > Editor > Code Style > Formatter Control + //@formatter:off + + + @android.annotation.IntDef(prefix = "KEY_TYPE_", value = { + KEY_TYPE_UNDEFINED, + KEY_TYPE_ENCRYPTION + }) + @java.lang.annotation.Retention(java.lang.annotation.RetentionPolicy.SOURCE) + @DataClass.Generated.Member + public @interface KeyType {} + + @DataClass.Generated.Member + public static String keyTypeToString(@KeyType int value) { + switch (value) { + case KEY_TYPE_UNDEFINED: + return "KEY_TYPE_UNDEFINED"; + case KEY_TYPE_ENCRYPTION: + return "KEY_TYPE_ENCRYPTION"; + default: return Integer.toHexString(value); + } + } + + @DataClass.Generated.Member + /* package-private */ FederatedComputeEncryptionKey( + @NonNull String keyIdentifier, + @NonNull String publicKey, + @KeyType int keyType, + long creationTime, + long expiryTime) { + this.mKeyIdentifier = keyIdentifier; + AnnotationValidations.validate( + NonNull.class, null, mKeyIdentifier); + this.mPublicKey = publicKey; + AnnotationValidations.validate( + NonNull.class, null, mPublicKey); + this.mKeyType = keyType; + + if (!(mKeyType == KEY_TYPE_UNDEFINED) + && !(mKeyType == KEY_TYPE_ENCRYPTION)) { + throw new java.lang.IllegalArgumentException( + "keyType was " + mKeyType + " but must be one of: " + + "KEY_TYPE_UNDEFINED(" + KEY_TYPE_UNDEFINED + "), " + + "KEY_TYPE_ENCRYPTION(" + KEY_TYPE_ENCRYPTION + ")"); + } + + this.mCreationTime = creationTime; + this.mExpiryTime = expiryTime; + + // onConstructed(); // You can define this method to get a callback + } + + /** + * @return the key identifier. + */ + @DataClass.Generated.Member + public @NonNull String getKeyIdentifier() { + return mKeyIdentifier; + } + + /** + * @return the public key. + */ + @DataClass.Generated.Member + public @NonNull String getPublicKey() { + return mPublicKey; + } + + /** + * @return the key type. + */ + @DataClass.Generated.Member + public @KeyType int getKeyType() { + return mKeyType; + } + + /** + * @return the creation time in milliseconds. + */ + @DataClass.Generated.Member + public long getCreationTime() { + return mCreationTime; + } + + /** + * @return the expiry time in milliseconds. + */ + @DataClass.Generated.Member + public long getExpiryTime() { + return mExpiryTime; + } + + @Override + @DataClass.Generated.Member + public boolean equals(@android.annotation.Nullable Object o) { + // You can override field equality logic by defining either of the methods like: + // boolean fieldNameEquals(FederatedComputeEncryptionKey other) { ... } + // boolean fieldNameEquals(FieldType otherValue) { ... } + + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + @SuppressWarnings("unchecked") + FederatedComputeEncryptionKey that = (FederatedComputeEncryptionKey) o; + //noinspection PointlessBooleanExpression + return true + && java.util.Objects.equals(mKeyIdentifier, that.mKeyIdentifier) + && java.util.Objects.equals(mPublicKey, that.mPublicKey) + && mKeyType == that.mKeyType + && mCreationTime == that.mCreationTime + && mExpiryTime == that.mExpiryTime; + } + + @Override + @DataClass.Generated.Member + public int hashCode() { + // You can override field hashCode logic by defining methods like: + // int fieldNameHashCode() { ... } + + int _hash = 1; + _hash = 31 * _hash + java.util.Objects.hashCode(mKeyIdentifier); + _hash = 31 * _hash + java.util.Objects.hashCode(mPublicKey); + _hash = 31 * _hash + mKeyType; + _hash = 31 * _hash + Long.hashCode(mCreationTime); + _hash = 31 * _hash + Long.hashCode(mExpiryTime); + return _hash; + } + + /** + * A builder for {@link FederatedComputeEncryptionKey} + * @hide + */ + @SuppressWarnings("WeakerAccess") + @DataClass.Generated.Member + public static class Builder { + + private @NonNull String mKeyIdentifier; + private @NonNull String mPublicKey; + private @KeyType int mKeyType; + private long mCreationTime; + private long mExpiryTime; + + private long mBuilderFieldsSet = 0L; + + public Builder() {} + + /** + * Creates a new Builder. + * + */ + public Builder( + @NonNull String keyIdentifier, + @NonNull String publicKey, + @KeyType int keyType, + long creationTime, + long expiryTime) { + mKeyIdentifier = keyIdentifier; + AnnotationValidations.validate( + NonNull.class, null, mKeyIdentifier); + mPublicKey = publicKey; + AnnotationValidations.validate( + NonNull.class, null, mPublicKey); + mKeyType = keyType; + + if (!(mKeyType == KEY_TYPE_UNDEFINED) + && !(mKeyType == KEY_TYPE_ENCRYPTION)) { + throw new java.lang.IllegalArgumentException( + "keyType was " + mKeyType + " but must be one of: " + + "KEY_TYPE_UNDEFINED(" + KEY_TYPE_UNDEFINED + "), " + + "KEY_TYPE_ENCRYPTION(" + KEY_TYPE_ENCRYPTION + ")"); + } + + mCreationTime = creationTime; + mExpiryTime = expiryTime; + } + + /** + * @return the key identifier. + */ + @DataClass.Generated.Member + public @NonNull Builder setKeyIdentifier(@NonNull String value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x1; + mKeyIdentifier = value; + return this; + } + + /** + * @return the public key. + */ + @DataClass.Generated.Member + public @NonNull Builder setPublicKey(@NonNull String value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x2; + mPublicKey = value; + return this; + } + + /** + * @return the key type. + */ + @DataClass.Generated.Member + public @NonNull Builder setKeyType(@KeyType int value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x4; + mKeyType = value; + return this; + } + + /** + * @return the creation time in milliseconds. + */ + @DataClass.Generated.Member + public @NonNull Builder setCreationTime(long value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x8; + mCreationTime = value; + return this; + } + + /** + * @return the expiry time in milliseconds. + */ + @DataClass.Generated.Member + public @NonNull Builder setExpiryTime(long value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x10; + mExpiryTime = value; + return this; + } + + /** Builds the instance. This builder should not be touched after calling this! */ + public @NonNull FederatedComputeEncryptionKey build() { + checkNotUsed(); + mBuilderFieldsSet |= 0x20; // Mark builder used + + FederatedComputeEncryptionKey o = new FederatedComputeEncryptionKey( + mKeyIdentifier, + mPublicKey, + mKeyType, + mCreationTime, + mExpiryTime); + return o; + } + + private void checkNotUsed() { + if ((mBuilderFieldsSet & 0x20) != 0) { + throw new IllegalStateException( + "This Builder should not be reused. Use a new Builder instance instead"); + } + } + } + + @DataClass.Generated( + time = 1698371312320L, + codegenVersion = "1.0.23", + sourceFile = "packages/modules/OnDevicePersonalization/federatedcompute/src/com/android/federatedcompute/services/data/FederatedComputeEncryptionKey.java", + inputSignatures = "public static final int KEY_TYPE_UNDEFINED\npublic static final int KEY_TYPE_ENCRYPTION\nprivate final @android.annotation.NonNull java.lang.String mKeyIdentifier\nprivate final @android.annotation.NonNull java.lang.String mPublicKey\nprivate final @com.android.federatedcompute.services.data.FederatedComputeEncryptionKey.KeyType int mKeyType\nprivate final long mCreationTime\nprivate final long mExpiryTime\nclass FederatedComputeEncryptionKey extends java.lang.Object implements [java.io.Serializable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genHiddenBuilder=true, genEqualsHashCode=true)") + @Deprecated + private void __metadata() {} + + + //@formatter:on + // End of generated code + +} diff --git a/federatedcompute/src/com/android/federatedcompute/services/data/FederatedComputeEncryptionKeyContract.java b/federatedcompute/src/com/android/federatedcompute/services/data/FederatedComputeEncryptionKeyContract.java new file mode 100644 index 00000000..f5c6d079 --- /dev/null +++ b/federatedcompute/src/com/android/federatedcompute/services/data/FederatedComputeEncryptionKeyContract.java @@ -0,0 +1,48 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.services.data; + +public final class FederatedComputeEncryptionKeyContract { + public static final String ENCRYPTION_KEY_TABLE = "encryption_keys"; + + private FederatedComputeEncryptionKeyContract() {} + + public static final class FederatedComputeEncryptionColumns { + private FederatedComputeEncryptionColumns() {} + + /** + * A unique identifier of the key, in thd form of UUID. FCP server uses key_identifier to + * get private key. + */ + public static final String KEY_IDENTIFIER = "key_identifier"; + + /** The public key base64 encoded. */ + public static final String PUBLIC_KEY = "public_key"; + + /** + * The type of the key in @link {com.android.federatedcompute.services.data.fbs.KeyType} + * Currently only encryption key is allowed. + */ + public static final String KEY_TYPE = "key_type"; + + /** Creation instant of the key in the database in milliseconds. */ + public static final String CREATION_TIME = "creation_time"; + + /** Expiry time of the key in milliseconds. */ + public static final String EXPIRY_TIME = "expiry_time"; + } +} diff --git a/federatedcompute/src/com/android/federatedcompute/services/data/FederatedComputeEncryptionKeyDao.java b/federatedcompute/src/com/android/federatedcompute/services/data/FederatedComputeEncryptionKeyDao.java new file mode 100644 index 00000000..b1cb18be --- /dev/null +++ b/federatedcompute/src/com/android/federatedcompute/services/data/FederatedComputeEncryptionKeyDao.java @@ -0,0 +1,235 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.services.data; + +import static com.android.federatedcompute.services.data.FederatedComputeEncryptionKeyContract.ENCRYPTION_KEY_TABLE; + +import android.annotation.NonNull; +import android.annotation.Nullable; +import android.content.ContentValues; +import android.content.Context; +import android.database.Cursor; +import android.database.sqlite.SQLiteDatabase; +import android.database.sqlite.SQLiteException; +import android.database.sqlite.SQLiteOpenHelper; + +import com.android.federatedcompute.internal.util.LogUtil; +import com.android.federatedcompute.services.common.Clock; +import com.android.federatedcompute.services.common.MonotonicClock; +import com.android.federatedcompute.services.data.FederatedComputeEncryptionKeyContract.FederatedComputeEncryptionColumns; + +import com.google.common.annotations.VisibleForTesting; + +import java.util.ArrayList; +import java.util.List; + +/** DAO for accessing encryption key table */ +public class FederatedComputeEncryptionKeyDao { + private static final String TAG = FederatedComputeEncryptionKeyDao.class.getSimpleName(); + + private final SQLiteOpenHelper mDbHelper; + + private final Clock mClock; + + private static volatile FederatedComputeEncryptionKeyDao sSingletonInstance; + + private FederatedComputeEncryptionKeyDao(SQLiteOpenHelper dbHelper, Clock clock) { + mDbHelper = dbHelper; + mClock = clock; + } + + /** + * @return an instance of FederatedComputeEncryptionKeyDao given a context + */ + @NonNull + public static FederatedComputeEncryptionKeyDao getInstance(Context context) { + if (sSingletonInstance == null) { + synchronized (FederatedComputeEncryptionKeyDao.class) { + if (sSingletonInstance == null) { + sSingletonInstance = + new FederatedComputeEncryptionKeyDao( + FederatedComputeDbHelper.getInstance(context), + MonotonicClock.getInstance()); + } + } + } + return sSingletonInstance; + } + + /** It is only public to unit test. */ + @VisibleForTesting + public static FederatedComputeEncryptionKeyDao getInstanceForTest(Context context) { + if (sSingletonInstance == null) { + synchronized (FederatedComputeEncryptionKeyDao.class) { + if (sSingletonInstance == null) { + FederatedComputeDbHelper dbHelper = + FederatedComputeDbHelper.getInstanceForTest(context); + Clock clk = MonotonicClock.getInstance(); + sSingletonInstance = new FederatedComputeEncryptionKeyDao(dbHelper, clk); + } + } + } + return sSingletonInstance; + } + + /** Insert a key to the encryption_key table. */ + public boolean insertEncryptionKey(FederatedComputeEncryptionKey key) { + + SQLiteDatabase db = getWritableDatabase(); + if (db == null) { + throw new SQLiteException(TAG + ": Failed to open database."); + } + + ContentValues values = new ContentValues(); + values.put(FederatedComputeEncryptionColumns.KEY_IDENTIFIER, key.getKeyIdentifier()); + values.put(FederatedComputeEncryptionColumns.PUBLIC_KEY, key.getPublicKey()); + values.put(FederatedComputeEncryptionColumns.KEY_TYPE, key.getKeyType()); + values.put(FederatedComputeEncryptionColumns.CREATION_TIME, key.getCreationTime()); + values.put(FederatedComputeEncryptionColumns.EXPIRY_TIME, key.getExpiryTime()); + + long jobId = + db.insertWithOnConflict( + ENCRYPTION_KEY_TABLE, "", values, SQLiteDatabase.CONFLICT_REPLACE); + return jobId != -1; + } + + /** + * Read from encryption key table given selection, order and limit conidtions. + * + * @return a list of {@link FederatedComputeEncryptionKey}. + */ + @VisibleForTesting + public List<FederatedComputeEncryptionKey> readFederatedComputeEncryptionKeysFromDatabase( + String selection, String[] selectionArgs, String orderBy, int count) { + List<FederatedComputeEncryptionKey> keyList = new ArrayList<>(); + SQLiteDatabase db = getReadableDatabase(); + if (db == null) { + throw new SQLiteException(TAG + ": Failed to open database."); + } + + String[] selectColumns = { + FederatedComputeEncryptionColumns.KEY_IDENTIFIER, + FederatedComputeEncryptionColumns.PUBLIC_KEY, + FederatedComputeEncryptionColumns.KEY_TYPE, + FederatedComputeEncryptionColumns.CREATION_TIME, + FederatedComputeEncryptionColumns.EXPIRY_TIME + }; + + Cursor cursor = null; + try { + cursor = + db.query( + ENCRYPTION_KEY_TABLE, + selectColumns, + selection, + selectionArgs, + null + /* groupBy= */ , + null + /* having= */ , + orderBy + /* orderBy= */ , + String.valueOf(count) + /* limit= */); + while (cursor.moveToNext()) { + FederatedComputeEncryptionKey.Builder encryptionKeyBuilder = + new FederatedComputeEncryptionKey.Builder() + .setKeyIdentifier( + cursor.getString( + cursor.getColumnIndexOrThrow( + FederatedComputeEncryptionColumns + .KEY_IDENTIFIER))) + .setPublicKey( + cursor.getString( + cursor.getColumnIndexOrThrow( + FederatedComputeEncryptionColumns + .PUBLIC_KEY))) + .setKeyType( + cursor.getInt( + cursor.getColumnIndexOrThrow( + FederatedComputeEncryptionColumns + .KEY_TYPE))) + .setCreationTime( + cursor.getLong( + cursor.getColumnIndexOrThrow( + FederatedComputeEncryptionColumns + .CREATION_TIME))) + .setExpiryTime( + cursor.getLong( + cursor.getColumnIndexOrThrow( + FederatedComputeEncryptionColumns + .EXPIRY_TIME))); + keyList.add(encryptionKeyBuilder.build()); + } + } finally { + if (cursor != null) { + cursor.close(); + } + } + return keyList; + } + + /** + * @return latest expired keys (order by expiry time). + */ + public List<FederatedComputeEncryptionKey> getLatestExpiryNKeys(int count) { + String selection = FederatedComputeEncryptionColumns.EXPIRY_TIME + " > ?"; + String[] selectionArgs = {String.valueOf(mClock.currentTimeMillis())}; + // reverse order of expiry time + String orderBy = FederatedComputeEncryptionColumns.EXPIRY_TIME + " DESC"; + return readFederatedComputeEncryptionKeysFromDatabase( + selection, selectionArgs, orderBy, count); + } + + /** + * Delete expired keys. + * + * @return number of keys deleted. + */ + public int deleteExpiredKeys() { + SQLiteDatabase db = getWritableDatabase(); + if (db == null) { + throw new SQLiteException(TAG + ": Failed to open database."); + } + String whereClause = FederatedComputeEncryptionColumns.EXPIRY_TIME + " < ?"; + String[] whereArgs = {String.valueOf(mClock.currentTimeMillis())}; + int deletedRows = db.delete(ENCRYPTION_KEY_TABLE, whereClause, whereArgs); + LogUtil.d(TAG, "Deleted %s expired keys from database", deletedRows); + return deletedRows; + } + + @Nullable + private SQLiteDatabase getReadableDatabase() { + try { + return mDbHelper.getReadableDatabase(); + } catch (SQLiteException e) { + LogUtil.e(TAG, e, "Failed to open the database."); + } + return null; + } + + /* @return a writable database object or null if error occurs. */ + @Nullable + private SQLiteDatabase getWritableDatabase() { + try { + return mDbHelper.getWritableDatabase(); + } catch (SQLiteException e) { + LogUtil.e(TAG, e, "Failed to open the database."); + } + return null; + } +} diff --git a/federatedcompute/src/com/android/federatedcompute/services/data/FederatedTrainingTask.java b/federatedcompute/src/com/android/federatedcompute/services/data/FederatedTrainingTask.java index c71340e6..a14dd319 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/data/FederatedTrainingTask.java +++ b/federatedcompute/src/com/android/federatedcompute/services/data/FederatedTrainingTask.java @@ -103,12 +103,22 @@ public abstract class FederatedTrainingTask { @Nullable public abstract Long lastRunStartTime(); + @NonNull + public long getLastRunStartTime() { + return lastRunStartTime() == null ? 0 : lastRunStartTime(); + } + /** * @return the end time of the task's last run. */ @Nullable public abstract Long lastRunEndTime(); + @NonNull + public long getLastRunEndTime() { + return lastRunEndTime() == null ? 0 : lastRunEndTime(); + } + /** * @return the earliest time to run the task by. */ @@ -263,7 +273,7 @@ public abstract class FederatedTrainingTask { null /* having= */ , null - /* orderBy= */); + /* orderBy= */ ); while (cursor.moveToNext()) { FederatedTrainingTask.Builder trainingTaskBuilder = FederatedTrainingTask.builder() diff --git a/federatedcompute/src/com/android/federatedcompute/services/data/FederatedTrainingTaskDao.java b/federatedcompute/src/com/android/federatedcompute/services/data/FederatedTrainingTaskDao.java index 67bac172..409505f7 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/data/FederatedTrainingTaskDao.java +++ b/federatedcompute/src/com/android/federatedcompute/services/data/FederatedTrainingTaskDao.java @@ -24,8 +24,8 @@ import android.content.Context; import android.database.sqlite.SQLiteDatabase; import android.database.sqlite.SQLiteException; import android.database.sqlite.SQLiteOpenHelper; -import android.util.Log; +import com.android.federatedcompute.internal.util.LogUtil; import com.android.federatedcompute.services.data.FederatedTraningTaskContract.FederatedTrainingTaskColumns; import com.google.common.annotations.VisibleForTesting; @@ -36,10 +36,10 @@ import java.util.List; /** DAO for accessing training task table. */ public class FederatedTrainingTaskDao { - private static final String TAG = "FederatedTrainingTaskDao"; + private static final String TAG = FederatedTrainingTaskDao.class.getSimpleName(); private final SQLiteOpenHelper mDbHelper; - private static FederatedTrainingTaskDao sSingletonInstance; + private static volatile FederatedTrainingTaskDao sSingletonInstance; private FederatedTrainingTaskDao(SQLiteOpenHelper dbHelper) { this.mDbHelper = dbHelper; @@ -48,14 +48,16 @@ public class FederatedTrainingTaskDao { /** Returns an instance of the FederatedTrainingTaskDao given a context. */ @NonNull public static FederatedTrainingTaskDao getInstance(Context context) { - synchronized (FederatedTrainingTaskDao.class) { - if (sSingletonInstance == null) { - sSingletonInstance = - new FederatedTrainingTaskDao( - FederatedTrainingTaskDbHelper.getInstance(context)); + if (sSingletonInstance == null) { + synchronized (FederatedTrainingTaskDao.class) { + if (sSingletonInstance == null) { + sSingletonInstance = + new FederatedTrainingTaskDao( + FederatedComputeDbHelper.getInstance(context)); + } } - return sSingletonInstance; } + return sSingletonInstance; } /** It's only public to unit test. */ @@ -63,8 +65,8 @@ public class FederatedTrainingTaskDao { public static FederatedTrainingTaskDao getInstanceForTest(Context context) { synchronized (FederatedTrainingTaskDao.class) { if (sSingletonInstance == null) { - FederatedTrainingTaskDbHelper dbHelper = - FederatedTrainingTaskDbHelper.getInstanceForTest(context); + FederatedComputeDbHelper dbHelper = + FederatedComputeDbHelper.getInstanceForTest(context); sSingletonInstance = new FederatedTrainingTaskDao(dbHelper); } return sSingletonInstance; @@ -167,7 +169,7 @@ public class FederatedTrainingTaskDao { try { return mDbHelper.getWritableDatabase(); } catch (SQLiteException e) { - Log.e(TAG, "Failed to open the database.", e); + LogUtil.e(TAG, e, "Failed to open the database."); } return null; } diff --git a/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleConsumptionRecorder.java b/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleConsumptionRecorder.java index 22dcf9b1..fd7dc5ef 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleConsumptionRecorder.java +++ b/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleConsumptionRecorder.java @@ -39,7 +39,7 @@ public class ExampleConsumptionRecorder { /** Records information for a single query. */ @ThreadSafe public static class SingleQueryRecorder { - private final String mCollection; + private final String mTaskName; private final byte[] mCriteria; // The pair of example count and resumption token needs to be updated atomically, therefore @@ -52,8 +52,8 @@ public class ExampleConsumptionRecorder { @GuardedBy("SingleQueryRecorder.this") private byte[] mResumptionToken; - private SingleQueryRecorder(String collection, byte[] criteria) { - this.mCollection = collection; + private SingleQueryRecorder(String taskName, byte[] criteria) { + this.mTaskName = taskName; this.mCriteria = criteria; this.mExampleCount = 0; this.mResumptionToken = null; @@ -72,7 +72,7 @@ public class ExampleConsumptionRecorder { /** Returns a single recorded of {@link ExampleConsumption}. */ public synchronized ExampleConsumption finishRecordingAndGet() { return new ExampleConsumption.Builder() - .setCollectionName(mCollection) + .setTaskName(mTaskName) .setSelectionCriteria(mCriteria) .setExampleCount(mExampleCount) .setResumptionToken(mResumptionToken) @@ -80,17 +80,17 @@ public class ExampleConsumptionRecorder { } } - /** Create a {@link SingleQueryRecorder} for the current query. */ + /** Create a {@link SingleQueryRecorder} for the current task. */ public synchronized SingleQueryRecorder createRecorderForTracking( - String collection, byte[] criteria) { - SingleQueryRecorder recorder = new SingleQueryRecorder(collection, criteria); + String taskName, byte[] criteria) { + SingleQueryRecorder recorder = new SingleQueryRecorder(taskName, criteria); mSingleQueryRecorders.add(recorder); return recorder; } /** Returns all recorded {@link ExampleConsumption}. */ - public synchronized List<ExampleConsumption> finishRecordingAndGet() { - List<ExampleConsumption> exampleConsumptions = new ArrayList<>(); + public synchronized ArrayList<ExampleConsumption> finishRecordingAndGet() { + ArrayList<ExampleConsumption> exampleConsumptions = new ArrayList<>(); for (SingleQueryRecorder recorder : mSingleQueryRecorders) { exampleConsumptions.add(recorder.finishRecordingAndGet()); } diff --git a/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleStoreIteratorProvider.java b/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleStoreIteratorProvider.java deleted file mode 100644 index c928134c..00000000 --- a/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleStoreIteratorProvider.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (C) 2023 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.android.federatedcompute.services.examplestore; - -import android.federatedcompute.aidl.IExampleStoreIterator; - -import com.android.federatedcompute.services.common.ErrorStatusException; - -import com.google.internal.federated.plan.ExampleSelector; - -/** Interface used to provide a reference to the IExampleStoreIterator. */ -public interface ExampleStoreIteratorProvider { - - /** Returns the selected ExampleStoreIterator. */ - IExampleStoreIterator getExampleStoreIterator( - String packageName, ExampleSelector exampleSelector) - throws InterruptedException, ErrorStatusException; -} diff --git a/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleStoreIteratorProviderImpl.java b/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleStoreIteratorProviderImpl.java deleted file mode 100644 index 982bca53..00000000 --- a/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleStoreIteratorProviderImpl.java +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Copyright (C) 2023 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.android.federatedcompute.services.examplestore; - -import android.content.Intent; -import android.federatedcompute.aidl.IExampleStoreCallback; -import android.federatedcompute.aidl.IExampleStoreIterator; -import android.federatedcompute.aidl.IExampleStoreService; -import android.federatedcompute.common.ClientConstants; -import android.net.Uri; -import android.os.Bundle; -import android.os.RemoteException; -import android.util.Log; -import android.util.Pair; - -import com.android.federatedcompute.services.common.Constants; -import com.android.federatedcompute.services.common.ErrorStatusException; -import com.android.federatedcompute.services.common.Flags; - -import com.google.common.util.concurrent.SettableFuture; -import com.google.common.util.concurrent.UncheckedExecutionException; -import com.google.internal.federated.plan.ExampleSelector; - -import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; - -/** Implementation of the ExampleStoreIterator interface. */ -public class ExampleStoreIteratorProviderImpl implements ExampleStoreIteratorProvider { - private static final String TAG = "ExampleStoreIteratorProvider"; - private final ExampleStoreServiceProvider mExampleStoreServiceProvider; - private final Flags mFlags; - - public ExampleStoreIteratorProviderImpl( - ExampleStoreServiceProvider exampleStoreServiceProvider, Flags flags) { - this.mExampleStoreServiceProvider = exampleStoreServiceProvider; - this.mFlags = flags; - } - - @Override - public IExampleStoreIterator getExampleStoreIterator( - String packageName, ExampleSelector exampleSelector) - throws InterruptedException, ErrorStatusException { - String collection = exampleSelector.getCollectionUri(); - byte[] criteria = exampleSelector.getCriteria().toByteArray(); - byte[] resumptionToken = exampleSelector.getResumptionToken().toByteArray(); - Intent intent = new Intent(); - intent.setAction(ClientConstants.EXAMPLE_STORE_ACTION).setPackage(packageName); - intent.setData( - new Uri.Builder().scheme("app").authority(packageName).path(collection).build()); - Log.d(TAG, "Attempting to bind to example store service: " + intent); - if (!mExampleStoreServiceProvider.bindService(intent)) { - Log.w(TAG, "bindService failed for example store service: " + intent); - mExampleStoreServiceProvider.unbindService(); - return null; - } - IExampleStoreService exampleStoreService = - mExampleStoreServiceProvider.getExampleStoreService(); - Bundle bundle = new Bundle(); - bundle.putString(Constants.EXTRA_COLLECTION_NAME, collection); - bundle.putByteArray(Constants.EXTRA_EXAMPLE_ITERATOR_RESUMPTION_TOKEN, resumptionToken); - bundle.putByteArray(Constants.EXTRA_EXAMPLE_ITERATOR_CRITERIA, criteria); - SettableFuture<Pair<IExampleStoreIterator, Integer>> iteratorOrFailureFuture = - SettableFuture.create(); - try { - try { - exampleStoreService.startQuery( - bundle, - new IExampleStoreCallback.Stub() { - @Override - public void onStartQuerySuccess(IExampleStoreIterator iterator) { - Log.d(TAG, "Acquire iterator"); - iteratorOrFailureFuture.set(Pair.create(iterator, null)); - } - - @Override - public void onStartQueryFailure(int errorCode) { - Log.e(TAG, "Could not acquire iterator: " + errorCode); - iteratorOrFailureFuture.set(Pair.create(null, errorCode)); - } - }); - } catch (RemoteException e) { - Log.e(TAG, "StartQuery failure: " + e.getMessage()); - throw new IllegalStateException(e); - } - Pair<IExampleStoreIterator, Integer> iteratorOrFailure; - try { - iteratorOrFailure = - iteratorOrFailureFuture.get( - mFlags.getAppHostedExampleStoreTimeoutSecs(), TimeUnit.SECONDS); - } catch (ExecutionException e) { - // Should not happen. - throw new UncheckedExecutionException(e); - } catch (TimeoutException e) { - Log.e(TAG, "startQuery timed out: ", e); - throw new IllegalStateException( - String.format( - "startQuery timed out (%ss): %s", - mFlags.getAppHostedExampleStoreTimeoutSecs(), collection), - e); - } - if (iteratorOrFailure.second != null) { - throw new IllegalStateException( - String.format( - "onStartQueryFailure collection %s error code %d", - collection, iteratorOrFailure.second)); - } - Log.d(TAG, "Wrapping IExampleStoreIterator"); - return iteratorOrFailure.first; - } catch (Exception e) { - // If any exception is thrown in try block, we first call unbindService to avoid service - // connection hanging. - Log.d(TAG, "Unbinding from service due to exception", e); - mExampleStoreServiceProvider.unbindService(); - throw e; - } - } -} diff --git a/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleStoreServiceProvider.java b/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleStoreServiceProvider.java deleted file mode 100644 index 7bdbf064..00000000 --- a/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleStoreServiceProvider.java +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright (C) 2023 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.android.federatedcompute.services.examplestore; - -import android.annotation.Nullable; -import android.content.Intent; -import android.federatedcompute.aidl.IExampleStoreService; - -/** Interface used to provide a reference to the IExampleStore. */ -public interface ExampleStoreServiceProvider { - - /** Returns the connected ExampleStoreService, or otherwise {@code null}. */ - @Nullable - IExampleStoreService getExampleStoreService(); - - /** Bind to and establish a connection with client implemented ExampleStoreService. */ - boolean bindService(Intent intent); - - /** Unbind from the client implemented ExampleStoreService. */ - void unbindService(); -} diff --git a/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleStoreServiceProviderImpl.java b/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleStoreServiceProviderImpl.java deleted file mode 100644 index 4bd6bb80..00000000 --- a/federatedcompute/src/com/android/federatedcompute/services/examplestore/ExampleStoreServiceProviderImpl.java +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Copyright (C) 2023 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.android.federatedcompute.services.examplestore; - -import android.annotation.Nullable; -import android.content.ComponentName; -import android.content.Context; -import android.content.Intent; -import android.content.ServiceConnection; -import android.federatedcompute.aidl.IExampleStoreService; -import android.os.IBinder; -import android.util.Log; - -import com.android.federatedcompute.services.common.Flags; - -import com.google.common.util.concurrent.SettableFuture; -import com.google.common.util.concurrent.UncheckedExecutionException; - -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; - -/** Implementation of the ExampleStoreServiceProvider interface. */ -public final class ExampleStoreServiceProviderImpl implements ExampleStoreServiceProvider { - private static final String TAG = "ExampleStoreServiceProviderImpl"; - private static final Executor SINGLE_THREAD_EXECUTOR = Executors.newSingleThreadExecutor(); - private final Context mContext; - private IExampleStoreService mExampleStoreService; - private SettableFuture<IExampleStoreService> mExampleStoreServiceFuture = - SettableFuture.create(); - private boolean mBound; - private Flags mFlags; - - public ExampleStoreServiceProviderImpl(Context context, Flags flags) { - this.mContext = context; - this.mFlags = flags; - } - - ServiceConnection mServiceConnection = - new ServiceConnection() { - @Override - public void onServiceConnected(ComponentName name, IBinder service) { - if (service == null) { - Log.e(TAG, "onServiceConnected() received null binder"); - return; - } - mExampleStoreServiceFuture.set(IExampleStoreService.Stub.asInterface(service)); - mBound = true; - } - - @Override - public void onServiceDisconnected(ComponentName name) { - reset(); - Log.d(TAG, "Connection unexpectedly disconnected"); - } - }; - - @Override - @Nullable - public IExampleStoreService getExampleStoreService() { - return mExampleStoreService; - } - - @Override - public boolean bindService(Intent intent) { - if (!mContext.bindService( - intent, Context.BIND_AUTO_CREATE, SINGLE_THREAD_EXECUTOR, mServiceConnection)) { - Log.e(TAG, "Unable to bind to ExampleStoreService intent: " + intent); - return false; - } - try { - mExampleStoreService = - mExampleStoreServiceFuture.get( - mFlags.getAppHostedExampleStoreTimeoutSecs(), TimeUnit.SECONDS); - } catch (TimeoutException e) { - throw new IllegalStateException( - String.format( - "Service connection time out (%ss) for app hosted examplestore.", - mFlags.getAppHostedExampleStoreTimeoutSecs()), - e); - } catch (ExecutionException e) { - throw new UncheckedExecutionException(e); - } catch (InterruptedException e) { - Log.e(TAG, "ExampleStoreService interrupted", e); - unbindService(); - return false; - } - return true; - } - - @Override - public void unbindService() { - if (mBound) { - mContext.unbindService(mServiceConnection); - reset(); - } - } - - private void reset() { - mExampleStoreServiceFuture = SettableFuture.create(); - mExampleStoreService = null; - mBound = false; - } -} diff --git a/federatedcompute/src/com/android/federatedcompute/services/examplestore/FederatedExampleIterator.java b/federatedcompute/src/com/android/federatedcompute/services/examplestore/FederatedExampleIterator.java index 784558dc..567908c1 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/examplestore/FederatedExampleIterator.java +++ b/federatedcompute/src/com/android/federatedcompute/services/examplestore/FederatedExampleIterator.java @@ -26,11 +26,10 @@ import android.federatedcompute.aidl.IExampleStoreIteratorCallback; import android.os.Bundle; import android.os.Looper; import android.os.RemoteException; -import android.util.Log; import android.util.Pair; +import com.android.federatedcompute.internal.util.LogUtil; import com.android.federatedcompute.services.common.ErrorStatusException; -import com.android.federatedcompute.services.common.Flags; import com.android.federatedcompute.services.examplestore.ExampleConsumptionRecorder.SingleQueryRecorder; import com.android.internal.util.Preconditions; @@ -50,14 +49,12 @@ import javax.annotation.Nullable; * main thread. */ public final class FederatedExampleIterator implements ExampleIterator { - private static final String TAG = "FederatedExampleIterator"; + private static final String TAG = FederatedExampleIterator.class.getSimpleName(); // TODO: replace with PH flag. private static final long TIMEOUT_SECS = 2L; private boolean mClosed; - private final String mCollection; - private final byte[] mCriteria; - @Nullable private ProxyIteratorWrapper mIteratorWrapper; + @Nullable private final ProxyIteratorWrapper mIteratorWrapper; @Nullable private IteratorResult mCurrentResult; private byte[] mResumptionToken; @Nullable private final SingleQueryRecorder mRecorder; @@ -78,16 +75,11 @@ public final class FederatedExampleIterator implements ExampleIterator { } private NextResultState mNextResultState; - private Flags mFlags; public FederatedExampleIterator( IExampleStoreIterator exampleStoreIterator, - String collectionName, - byte[] criteria, byte[] resumptionToken, SingleQueryRecorder recorder) { - this.mCollection = collectionName; - this.mCriteria = criteria; this.mResumptionToken = resumptionToken; this.mNextResultState = NextResultState.UNKNOWN; this.mCurrentResult = null; @@ -143,7 +135,7 @@ public final class FederatedExampleIterator implements ExampleIterator { mCurrentResult = mIteratorWrapper.next(); if (mCurrentResult == null) { mNextResultState = NextResultState.END_OF_ITERATOR; - Log.d(TAG, "App example store returns null, end of iterator."); + LogUtil.d(TAG, "App example store returns null, end of iterator."); } else { mNextResultState = NextResultState.RESULT_AVAILABLE; } @@ -199,8 +191,7 @@ public final class FederatedExampleIterator implements ExampleIterator { close(); throw ErrorStatusException.create( Code.UNAVAILABLE_VALUE, - "OnIteratorNextFailure: %s %s", - mCollection, + "OnIteratorNextFailure: %s", resultOrFailure.second); } if (resultOrFailure.first == null) { @@ -219,7 +210,7 @@ public final class FederatedExampleIterator implements ExampleIterator { try { mExampleStoreIterator.close(); } catch (RemoteException e) { - Log.w(TAG, "Exception during call to IExampleStoreIterator.close", e); + LogUtil.w(TAG, e, "Exception during call to IExampleStoreIterator.close"); } } } diff --git a/federatedcompute/src/com/android/federatedcompute/services/http/CheckinResult.java b/federatedcompute/src/com/android/federatedcompute/services/http/CheckinResult.java index 44dd4f1c..e82bfe61 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/http/CheckinResult.java +++ b/federatedcompute/src/com/android/federatedcompute/services/http/CheckinResult.java @@ -16,23 +16,43 @@ package com.android.federatedcompute.services.http; +import android.annotation.Nullable; + +import com.android.internal.util.Preconditions; + +import com.google.internal.federated.plan.ClientOnlyPlan; +import com.google.ondevicepersonalization.federatedcompute.proto.TaskAssignment; + /** * The result after client calls TaskAssignemnt API. It includes init checkpoint data and plan data. */ public class CheckinResult { - private byte[] mCheckpointData; - private byte[] mPlanData; + private String mInputCheckpoint = null; + private ClientOnlyPlan mPlanData = null; + private TaskAssignment mTaskAssignment = null; - public CheckinResult(byte[] checkpointData, byte[] planData) { - this.mCheckpointData = checkpointData; + public CheckinResult( + String inputCheckpoint, ClientOnlyPlan planData, TaskAssignment taskAssignment) { + this.mInputCheckpoint = inputCheckpoint; this.mPlanData = planData; + this.mTaskAssignment = taskAssignment; } - public byte[] getCheckpointData() { - return mCheckpointData; + @Nullable + public String getInputCheckpointFile() { + Preconditions.checkArgument( + mInputCheckpoint != null && !mInputCheckpoint.isEmpty(), + "Input checkpoint file should not be none or empty"); + return mInputCheckpoint; } - public byte[] getPlanData() { + @Nullable + public ClientOnlyPlan getPlanData() { return mPlanData; } + + @Nullable + public TaskAssignment getTaskAssignment() { + return mTaskAssignment; + } } diff --git a/federatedcompute/src/com/android/federatedcompute/services/http/FederatedComputeHttpRequest.java b/federatedcompute/src/com/android/federatedcompute/services/http/FederatedComputeHttpRequest.java index 7a3b14f2..527811df 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/http/FederatedComputeHttpRequest.java +++ b/federatedcompute/src/com/android/federatedcompute/services/http/FederatedComputeHttpRequest.java @@ -61,7 +61,7 @@ public final class FederatedComputeHttpRequest { throw new IllegalArgumentException("Content-Length header should not be provided!"); } if (body.length > 0) { - if (httpMethod != HttpMethod.POST) { + if (httpMethod != HttpMethod.POST && httpMethod != HttpMethod.PUT) { throw new IllegalArgumentException( "Request method does not allow request mBody: " + httpMethod); } diff --git a/federatedcompute/src/com/android/federatedcompute/services/http/HttpClient.java b/federatedcompute/src/com/android/federatedcompute/services/http/HttpClient.java index 500f3d19..2caeccea 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/http/HttpClient.java +++ b/federatedcompute/src/com/android/federatedcompute/services/http/HttpClient.java @@ -16,13 +16,17 @@ package com.android.federatedcompute.services.http; +import static com.android.federatedcompute.services.common.FederatedComputeExecutors.getBlockingExecutor; import static com.android.federatedcompute.services.http.HttpClientUtil.HTTP_OK_STATUS; import android.annotation.NonNull; import android.annotation.Nullable; -import android.util.Log; + +import com.android.federatedcompute.internal.util.LogUtil; import com.google.common.annotations.VisibleForTesting; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; import java.io.BufferedOutputStream; import java.io.ByteArrayOutputStream; @@ -40,7 +44,7 @@ import java.util.concurrent.TimeUnit; * The HTTP client to be used by the FederatedCompute to communicate with remote federated servers. */ public class HttpClient { - private static final String TAG = "HttpClient"; + private static final String TAG = HttpClient.class.getSimpleName(); private static final int NETWORK_CONNECT_TIMEOUT_MS = (int) TimeUnit.SECONDS.toMillis(5); private static final int NETWORK_READ_TIMEOUT_MS = (int) TimeUnit.SECONDS.toMillis(30); @@ -56,12 +60,23 @@ public class HttpClient { return urlConnection; } + /** Perform HTTP requests based on given information asynchronously. */ + @NonNull + public ListenableFuture<FederatedComputeHttpResponse> performRequestAsync( + FederatedComputeHttpRequest request) { + try { + return getBlockingExecutor().submit(() -> performRequest(request)); + } catch (Exception e) { + return Futures.immediateFailedFuture(e); + } + } + /** Perform HTTP requests based on given information. */ @NonNull public FederatedComputeHttpResponse performRequest(FederatedComputeHttpRequest request) throws IOException { if (request.getUri() == null || request.getHttpMethod() == null) { - Log.e(TAG, "Endpoint or http method is empty"); + LogUtil.e(TAG, "Endpoint or http method is empty"); throw new IllegalArgumentException("Endpoint or http method is empty"); } @@ -69,7 +84,7 @@ public class HttpClient { try { url = new URL(request.getUri()); } catch (MalformedURLException e) { - Log.e(TAG, "Malformed registration target URL", e); + LogUtil.e(TAG, e, "Malformed registration target URL"); throw new IllegalArgumentException("Malformed registration target URL", e); } @@ -77,7 +92,7 @@ public class HttpClient { try { urlConnection = (HttpURLConnection) setup(url); } catch (IOException e) { - Log.e(TAG, "Failed to open target URL", e); + LogUtil.e(TAG, e, "Failed to open target URL"); throw new IllegalArgumentException("Failed to open target URL", e); } @@ -100,7 +115,7 @@ public class HttpClient { } int responseCode = urlConnection.getResponseCode(); - if (responseCode == HTTP_OK_STATUS) { + if (HTTP_OK_STATUS.contains(responseCode)) { return new FederatedComputeHttpResponse.Builder() .setPayload( getByteArray( @@ -120,7 +135,7 @@ public class HttpClient { .build(); } } catch (IOException e) { - Log.e(TAG, "Failed to get registration response", e); + LogUtil.e(TAG, e, "Failed to get registration response"); throw new IOException("Failed to get registration response", e); } finally { if (urlConnection != null) { @@ -133,8 +148,8 @@ public class HttpClient { if (contentLength == 0) { return HttpClientUtil.EMPTY_BODY; } - try { + // TODO(b/297952090): evaluate the large file download. byte[] buffer = new byte[HttpClientUtil.DEFAULT_BUFFER_SIZE]; ByteArrayOutputStream out = new ByteArrayOutputStream(); int bytesRead; diff --git a/federatedcompute/src/com/android/federatedcompute/services/http/HttpClientUtil.java b/federatedcompute/src/com/android/federatedcompute/services/http/HttpClientUtil.java index a99420a7..1d6bf784 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/http/HttpClientUtil.java +++ b/federatedcompute/src/com/android/federatedcompute/services/http/HttpClientUtil.java @@ -16,8 +16,9 @@ package com.android.federatedcompute.services.http; -import android.util.Log; +import com.android.federatedcompute.internal.util.LogUtil; +import com.google.common.collect.ImmutableSet; import com.google.protobuf.ByteString; import java.io.IOException; @@ -25,7 +26,7 @@ import java.util.zip.GZIPOutputStream; /** Utility class containing http related variable e.g. headers, method. */ public final class HttpClientUtil { - private static final String TAG = "HttpClientUtil"; + private static final String TAG = HttpClientUtil.class.getSimpleName(); public static final String IDENTITY_ENCODING_HDR = "identity"; public static final String CONTENT_ENCODING_HDR = "Content-Encoding"; public static final String CONTENT_LENGTH_HDR = "Content-Length"; @@ -35,7 +36,7 @@ public final class HttpClientUtil { public static final String PROTOBUF_CONTENT_TYPE = "application/x-protobuf"; public static final String OCTET_STREAM = "application/octet-stream"; public static final String CLIENT_DECODE_GZIP_SUFFIX = "+gzip"; - public static final int HTTP_OK_STATUS = 200; + public static final ImmutableSet<Integer> HTTP_OK_STATUS = ImmutableSet.of(200, 201); public static final String FAKE_API_KEY = "FAKE_API_KEY"; public static final int DEFAULT_BUFFER_SIZE = 1024; public static final byte[] EMPTY_BODY = new byte[0]; @@ -43,7 +44,8 @@ public final class HttpClientUtil { /** The supported http methods. */ public enum HttpMethod { GET, - POST + POST, + PUT, } /** Compresses the input data using Gzip. */ @@ -54,7 +56,7 @@ public final class HttpClientUtil { gzipOutputStream.finish(); return outputStream.toByteString().toByteArray(); } catch (IOException e) { - Log.e(TAG, "Failed to compress using Gzip"); + LogUtil.e(TAG, "Failed to compress using Gzip"); throw new IllegalArgumentException("Failed to compress using Gzip", e); } } diff --git a/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java b/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java index 8d63755f..a55fcf49 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java +++ b/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java @@ -18,51 +18,45 @@ package com.android.federatedcompute.services.http; import static com.android.federatedcompute.services.common.FederatedComputeExecutors.getBackgroundExecutor; import static com.android.federatedcompute.services.common.FederatedComputeExecutors.getLightweightExecutor; -import static com.android.federatedcompute.services.http.HttpClientUtil.CLIENT_DECODE_GZIP_SUFFIX; -import static com.android.federatedcompute.services.http.HttpClientUtil.FAKE_API_KEY; +import static com.android.federatedcompute.services.common.FileUtils.createTempFile; +import static com.android.federatedcompute.services.common.FileUtils.readFileAsByteArray; +import static com.android.federatedcompute.services.common.FileUtils.writeToFile; import static com.android.federatedcompute.services.http.HttpClientUtil.HTTP_OK_STATUS; -import static com.android.federatedcompute.services.http.HttpClientUtil.OCTET_STREAM; -import android.util.Log; - -import com.android.federatedcompute.services.common.FederatedComputeExecutors; +import com.android.federatedcompute.internal.util.LogUtil; import com.android.federatedcompute.services.http.HttpClientUtil.HttpMethod; +import com.android.federatedcompute.services.training.util.ComputationResult; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; import com.google.common.util.concurrent.AsyncCallable; import com.google.common.util.concurrent.FluentFuture; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import com.google.internal.federated.plan.ClientOnlyPlan; import com.google.internal.federatedcompute.v1.ClientVersion; import com.google.internal.federatedcompute.v1.Resource; -import com.google.internal.federatedcompute.v1.ResourceCapabilities; -import com.google.internal.federatedcompute.v1.ResourceCompressionFormat; -import com.google.internal.federatedcompute.v1.StartAggregationDataUploadRequest; -import com.google.internal.federatedcompute.v1.StartAggregationDataUploadResponse; -import com.google.internal.federatedcompute.v1.StartTaskAssignmentRequest; -import com.google.internal.federatedcompute.v1.StartTaskAssignmentResponse; -import com.google.internal.federatedcompute.v1.SubmitAggregationResultRequest; -import com.google.internal.federatedcompute.v1.TaskAssignment; -import com.google.protobuf.ExtensionRegistryLite; +import com.google.ondevicepersonalization.federatedcompute.proto.CreateTaskAssignmentRequest; +import com.google.ondevicepersonalization.federatedcompute.proto.CreateTaskAssignmentResponse; +import com.google.ondevicepersonalization.federatedcompute.proto.ReportResultRequest; +import com.google.ondevicepersonalization.federatedcompute.proto.ReportResultRequest.Result; +import com.google.ondevicepersonalization.federatedcompute.proto.ReportResultResponse; +import com.google.ondevicepersonalization.federatedcompute.proto.TaskAssignment; +import com.google.ondevicepersonalization.federatedcompute.proto.UploadInstruction; import com.google.protobuf.InvalidProtocolBufferException; -import java.io.IOException; import java.util.HashMap; -import java.util.List; /** Implements a single session of HTTP-based federated compute protocol. */ public final class HttpFederatedProtocol { public static final String TAG = "HttpFederatedProtocol"; - private final String mClientVersion; private final String mPopulationName; private final HttpClient mHttpClient; - private String mAggregatedSessionId; - private String mAggregationAuthroizationToken; - private String mAggregationResourceName; + private String mTaskId; + private String mAggregationId; + private String mAssignmentId; private final ProtocolRequestCreator mTaskAssignmentRequestCreator; - private ProtocolRequestCreator mAggregationRequestCreator; - private ProtocolRequestCreator mDataUploadRequestCreator; @VisibleForTesting HttpFederatedProtocol( @@ -71,7 +65,7 @@ public final class HttpFederatedProtocol { this.mPopulationName = populationName; this.mHttpClient = httpClient; this.mTaskAssignmentRequestCreator = - new ProtocolRequestCreator(entryUri, FAKE_API_KEY, new HashMap<>(), false); + new ProtocolRequestCreator(entryUri, new HashMap<>(), false); } /** Creates a HttpFederatedProtocol object. */ @@ -83,38 +77,22 @@ public final class HttpFederatedProtocol { /** Helper function to perform check in and download federated task from remote servers. */ public ListenableFuture<CheckinResult> issueCheckin() { ListenableFuture<TaskAssignment> taskAssignmentFuture = - FluentFuture.from(getBackgroundExecutor().submit(() -> callStartTaskAssignment())) - .transformAsync( + FluentFuture.from(createTaskAssignment()) + .transform( getTaskAssignmentHttpResponse -> getTaskAssignment(getTaskAssignmentHttpResponse), - FederatedComputeExecutors.getLightweightExecutor()); - + getLightweightExecutor()); ListenableFuture<FederatedComputeHttpResponse> planDataResponseFuture = FluentFuture.from(taskAssignmentFuture) .transformAsync( - taskAssignment -> - getBackgroundExecutor() - .submit( - () -> - fetchTaskResource( - taskAssignment.getPlan())), + taskAssignment -> fetchTaskResource(taskAssignment.getPlan()), getBackgroundExecutor()); - ListenableFuture<FederatedComputeHttpResponse> checkpointDataResponseFuture = FluentFuture.from(taskAssignmentFuture) .transformAsync( taskAssignment -> - getBackgroundExecutor() - .submit( - () -> { - Resource checkpointResource = - taskAssignment - .getInitCheckpoint(); - return fetchTaskResource( - checkpointResource); - }), + fetchTaskResource(taskAssignment.getInitCheckpoint()), getBackgroundExecutor()); - return Futures.whenAllSucceed( taskAssignmentFuture, planDataResponseFuture, checkpointDataResponseFuture) .callAsync( @@ -130,73 +108,59 @@ public final class HttpFederatedProtocol { getBackgroundExecutor()); } - /** Helper functions to reporting result via simple aggregation. */ - public FluentFuture<Void> reportViaSimpleAggregation(byte[] computationResult) { - return FluentFuture.from(getBackgroundExecutor().submit(() -> performStartDataUpload())) - .transform( - startResp -> processStartDataUploadResponse(startResp), - FederatedComputeExecutors.getLightweightExecutor()) - .transformAsync( - voidIgnore -> - getBackgroundExecutor() - .submit( - () -> - uploadViaSimpleAggregation( - computationResult)), - getBackgroundExecutor()) - .transform( - resp -> processFederatedComputeHttpResponse("upload failed", resp), - FederatedComputeExecutors.getLightweightExecutor()) - .transformAsync( - voidIgnore -> - getBackgroundExecutor().submit(() -> subimitAggregationResult()), - getBackgroundExecutor()) - .transform( - resp -> - processFederatedComputeHttpResponse( - "submit aggregation result failed", resp), - getLightweightExecutor()); + /** Helper functions to reporting result and upload result. */ + public FluentFuture<Void> reportResult(ComputationResult computationResult) { + if (computationResult != null && computationResult.isResultSuccess()) { + return FluentFuture.from(performReportResult(computationResult)) + .transformAsync( + reportResp -> + processReportResultResponseAndUploadResult( + reportResp, computationResult), + getBackgroundExecutor()) + .transform( + resp -> { + validateHttpResponseStatus("Upload result", resp); + return null; + }, + getLightweightExecutor()); + } else { + return FluentFuture.from(performReportResult(computationResult)) + .transform( + resp -> { + validateHttpResponseStatus("Report failure result", resp); + return null; + }, + getLightweightExecutor()); + } } - private FederatedComputeHttpResponse callStartTaskAssignment() throws IOException { - StartTaskAssignmentRequest request = - StartTaskAssignmentRequest.newBuilder() + private ListenableFuture<FederatedComputeHttpResponse> createTaskAssignment() { + CreateTaskAssignmentRequest request = + CreateTaskAssignmentRequest.newBuilder() .setClientVersion(ClientVersion.newBuilder().setVersionCode(mClientVersion)) - .setPopulationName(mPopulationName) - .setResourceCapabilities( - ResourceCapabilities.newBuilder() - .addSupportedCompressionFormats( - ResourceCompressionFormat - .RESOURCE_COMPRESSION_FORMAT_GZIP) - .build()) .build(); String taskAssignmentUriSuffix = - String.format("/v1/populations/%1$s/taskassignments:start", mPopulationName); + String.format( + "/taskassignment/v1/population/%1$s:create-task-assignment", + mPopulationName); FederatedComputeHttpRequest httpRequest = mTaskAssignmentRequestCreator.createProtoRequest( taskAssignmentUriSuffix, HttpMethod.POST, request.toByteArray(), /* isProtobufEncoded= */ true); - return mHttpClient.performRequest(httpRequest); + return mHttpClient.performRequestAsync(httpRequest); } - private ListenableFuture<TaskAssignment> getTaskAssignment( - FederatedComputeHttpResponse httpResponse) { - if (httpResponse.getStatusCode() != HTTP_OK_STATUS) { - Log.e(TAG, "start task assignment failed: " + httpResponse.getStatusCode()); - throw new IllegalStateException( - "start task assignment failed: " + httpResponse.getStatusCode()); - } - StartTaskAssignmentResponse taskAssignmentResponse; + private TaskAssignment getTaskAssignment(FederatedComputeHttpResponse httpResponse) { + validateHttpResponseStatus("Start task assignment", httpResponse); + CreateTaskAssignmentResponse taskAssignmentResponse; try { taskAssignmentResponse = - StartTaskAssignmentResponse.parseFrom( - httpResponse.getPayload(), ExtensionRegistryLite.getEmptyRegistry()); + CreateTaskAssignmentResponse.parseFrom(httpResponse.getPayload()); } catch (InvalidProtocolBufferException e) { throw new IllegalStateException("Could not parse StartTaskAssignmentResponse proto", e); } - Log.i(TAG, "start task assignment response: " + taskAssignmentResponse); if (taskAssignmentResponse.hasRejectionInfo()) { throw new IllegalStateException("Device rejected by server."); } @@ -204,7 +168,30 @@ public final class HttpFederatedProtocol { throw new IllegalStateException( "Could not find both task assignment and rejection info."); } - return Futures.immediateFuture(taskAssignmentResponse.getTaskAssignment()); + validateTaskAssignment(taskAssignmentResponse.getTaskAssignment()); + TaskAssignment taskAssignment = taskAssignmentResponse.getTaskAssignment(); + LogUtil.d( + TAG, + "Receive CreateTaskAssignmentResponse: task name %s assignment id %s", + taskAssignment.getTaskName(), + taskAssignment.getAssignmentId()); + return taskAssignment; + } + + private void validateTaskAssignment(TaskAssignment taskAssignment) { + Preconditions.checkArgument( + taskAssignment.getPopulationName().equals(mPopulationName), + "Population name should match"); + // These fields are required to construct ReportResultRequest. + Preconditions.checkArgument( + !taskAssignment.getTaskId().isEmpty(), "Task id should not be empty"); + Preconditions.checkArgument( + !taskAssignment.getAggregationId().isEmpty(), "Aggregation id should not be empty"); + Preconditions.checkArgument( + !taskAssignment.getAssignmentId().isEmpty(), "Assignment id should not be empty"); + this.mTaskId = taskAssignment.getTaskId(); + this.mAggregationId = taskAssignment.getAggregationId(); + this.mAssignmentId = taskAssignment.getAssignmentId(); } private ListenableFuture<CheckinResult> getCheckinResult( @@ -216,167 +203,132 @@ public final class HttpFederatedProtocol { FederatedComputeHttpResponse checkpointDataResponse = Futures.getDone(checkpointDataResponseFuture); TaskAssignment taskAssignment = Futures.getDone(taskAssignmentFuture); - if (planDataResponse.getStatusCode() != HTTP_OK_STATUS) { - throw new IllegalStateException( - "plan fetch failed: " + planDataResponse.getStatusCode()); + validateHttpResponseStatus("Fetch plan", planDataResponse); + validateHttpResponseStatus("Fetch checkpoint", checkpointDataResponse); + ClientOnlyPlan clientOnlyPlan; + try { + clientOnlyPlan = ClientOnlyPlan.parseFrom(planDataResponse.getPayload()); + } catch (InvalidProtocolBufferException e) { + LogUtil.e(TAG, e, "Could not parse ClientOnlyPlan proto"); + return Futures.immediateFailedFuture( + new IllegalStateException("Could not parse ClientOnlyPlan proto", e)); } - if (checkpointDataResponse.getStatusCode() != HTTP_OK_STATUS) { - throw new IllegalStateException( - "checkpoint data fetch failed: " + checkpointDataResponse.getStatusCode()); - } - if (taskAssignment.getAggregationId().isEmpty() - || taskAssignment.getAuthorizationToken().isEmpty()) { - throw new IllegalStateException( - "Aggregation id and authorization token should not be empty: " - + taskAssignment.getAggregationId() - + " " - + taskAssignment.getAuthorizationToken()); - } - this.mAggregatedSessionId = taskAssignment.getAggregationId(); - this.mAggregationAuthroizationToken = taskAssignment.getAuthorizationToken(); - mAggregationRequestCreator = - ProtocolRequestCreator.create( - FAKE_API_KEY, - taskAssignment.getAggregationDataForwardingInfo(), - /* useCompression= */ false); - + String inputCheckpointFile = createTempFile("input", ".ckp"); + writeToFile(inputCheckpointFile, checkpointDataResponse.getPayload()); return Futures.immediateFuture( - new CheckinResult( - planDataResponse.getPayload(), checkpointDataResponse.getPayload())); + new CheckinResult(inputCheckpointFile, clientOnlyPlan, taskAssignment)); } catch (Exception e) { return Futures.immediateFailedFuture(e); } } - private FederatedComputeHttpResponse performStartDataUpload() throws IOException { - StartAggregationDataUploadRequest startDataUploadRequest = - StartAggregationDataUploadRequest.newBuilder() - .setAggregationId(mAggregatedSessionId) - .setAuthorizationToken(mAggregationAuthroizationToken) - .build(); + private ListenableFuture<FederatedComputeHttpResponse> performReportResult( + ComputationResult computationResult) { + Result result = + (computationResult != null && computationResult.isResultSuccess()) + ? Result.COMPLETED + : Result.FAILED; + ReportResultRequest startDataUploadRequest = + ReportResultRequest.newBuilder().setResult(result).build(); String startDataUploadUri = String.format( - "/v1/aggregations/%1$s/clients/%2$s:startdataupload", - mAggregatedSessionId, mAggregationAuthroizationToken); + "/taskassignment/v1/population/%1$s/task/%2$s/aggregation" + + "/%3$s/task-assignment/%4$s:report-result", + mPopulationName, mTaskId, mAggregationId, mAssignmentId); + LogUtil.d( + TAG, + "send ReportResultRequest: population name %s, task name %s," + + " assignment id %s, result %s", + mPopulationName, + mTaskId, + mAssignmentId, + result.toString()); FederatedComputeHttpRequest httpRequest = - mAggregationRequestCreator.createProtoRequest( + mTaskAssignmentRequestCreator.createProtoRequest( startDataUploadUri, - HttpMethod.POST, + HttpMethod.PUT, startDataUploadRequest.toByteArray(), /* isProtobufEncoded= */ true); - return mHttpClient.performRequest(httpRequest); + return mHttpClient.performRequestAsync(httpRequest); } - private Void processStartDataUploadResponse(FederatedComputeHttpResponse httpResponse) { - StartAggregationDataUploadResponse startDataUploadResponse; - if (httpResponse.getStatusCode() != HTTP_OK_STATUS) { - Log.e(TAG, "start data upload failed: " + httpResponse.getStatusCode()); - throw new IllegalStateException( - "start data upload failed: " + httpResponse.getStatusCode()); - } + private ListenableFuture<FederatedComputeHttpResponse> + processReportResultResponseAndUploadResult( + FederatedComputeHttpResponse httpResponse, + ComputationResult computationResult) { try { - startDataUploadResponse = - StartAggregationDataUploadResponse.parseFrom( - httpResponse.getPayload(), ExtensionRegistryLite.getEmptyRegistry()); - } catch (InvalidProtocolBufferException e) { - throw new IllegalStateException( - "Could not parse StartAggregationDataUploadResponse proto", e); - } - if (startDataUploadResponse - .getAggregationProtocolForwardingInfo() - .getTargetUriPrefix() - .isEmpty()) { - throw new IllegalStateException( - "Missing ForwardingInfo.target_uri_prefix in" - + " StartAggregationDataUploadResponse"); + validateHttpResponseStatus("ReportResult", httpResponse); + ReportResultResponse reportResultResponse = + ReportResultResponse.parseFrom(httpResponse.getPayload()); + // TODO(b/297605806): better handle rejection info. + if (reportResultResponse.hasRejectionInfo()) { + return Futures.immediateFailedFuture( + new IllegalStateException( + "ReportResult got rejection: " + httpResponse.getStatusCode())); + } + Preconditions.checkArgument( + !computationResult.getOutputCheckpointFile().isEmpty(), + "Output checkpoint file should not be empty"); + byte[] outputBytes = readFileAsByteArray(computationResult.getOutputCheckpointFile()); + UploadInstruction uploadInstruction = reportResultResponse.getUploadInstruction(); + Preconditions.checkArgument( + !uploadInstruction.getUploadLocation().isEmpty(), + "UploadInstruction.upload_location must not be empty"); + HashMap<String, String> requestHeader = new HashMap<>(); + uploadInstruction + .getExtraRequestHeadersMap() + .forEach( + (key, value) -> { + requestHeader.put(key, value); + }); + LogUtil.d( + TAG, + "Start upload training result: population name %s, task name %s," + + " assignment id %s", + mPopulationName, + mTaskId, + mAssignmentId); + FederatedComputeHttpRequest httpUploadRequest = + FederatedComputeHttpRequest.create( + uploadInstruction.getUploadLocation(), + HttpMethod.PUT, + requestHeader, + outputBytes, + /* useCompression= */ false); + return mHttpClient.performRequestAsync(httpUploadRequest); + } catch (Exception e) { + return Futures.immediateFailedFuture(e); } - mAggregationRequestCreator = - ProtocolRequestCreator.create( - FAKE_API_KEY, - startDataUploadResponse.getAggregationProtocolForwardingInfo(), - /* useCompression= */ false); - mDataUploadRequestCreator = - ProtocolRequestCreator.create( - FAKE_API_KEY, - startDataUploadResponse.getResource().getDataUploadForwardingInfo(), - /* useCompression= */ false); - mAggregationAuthroizationToken = - startDataUploadResponse.getClientToken().isEmpty() - ? mAggregationAuthroizationToken - : startDataUploadResponse.getClientToken(); - mAggregationResourceName = startDataUploadResponse.getResource().getResourceName(); - - return null; } - private FederatedComputeHttpResponse subimitAggregationResult() throws IOException { - String submitAggregationResultUri = - String.format( - "/v1/aggregations/%1$s/clients/%2$s:submit", - mAggregatedSessionId, mAggregationAuthroizationToken); - SubmitAggregationResultRequest request = - SubmitAggregationResultRequest.newBuilder() - .setResourceName(mAggregationResourceName) - .build(); - FederatedComputeHttpRequest httpRequest = - mAggregationRequestCreator.createProtoRequest( - submitAggregationResultUri, - HttpMethod.POST, - request.toByteArray(), - /* isProtobufEncoded= */ true); - return mHttpClient.performRequest(httpRequest); - } - - private Void processFederatedComputeHttpResponse( + private void validateHttpResponseStatus( String stage, FederatedComputeHttpResponse httpResponse) { - if (httpResponse.getStatusCode() != HTTP_OK_STATUS) { - throw new IllegalStateException( - stage + ": " + httpResponse.getStatusCode() + " " + mAggregationResourceName); + if (!HTTP_OK_STATUS.contains(httpResponse.getStatusCode())) { + throw new IllegalStateException(stage + " failed: " + httpResponse.getStatusCode()); + } else { + LogUtil.d(TAG, stage + " success."); } - return null; } - private FederatedComputeHttpResponse uploadViaSimpleAggregation(byte[] computationResult) - throws IOException { - String uploadUri = String.format("/upload/v1/media/%1$s", mAggregationResourceName); - HashMap<String, String> params = new HashMap<>(); - params.put("upload_protocol", "raw"); - FederatedComputeHttpRequest request = - mDataUploadRequestCreator.createProtoRequest( - uploadUri, HttpMethod.POST, params, computationResult, false); - return mHttpClient.performRequest(request); - } - - private FederatedComputeHttpResponse fetchTaskResource(Resource resource) throws IOException { - if (resource.getResourceCase() == Resource.ResourceCase.URI) { - if (resource.getUri().isEmpty()) { - throw new IllegalArgumentException("Resource.uri must be non-empty when set"); - } - if (!resource.getClientCacheId().isEmpty()) { - throw new UnsupportedOperationException("Resource cache is not supported yet."); - } - - FederatedComputeHttpRequest httpRequest = - FederatedComputeHttpRequest.create( - resource.getUri(), - HttpMethod.GET, - new HashMap<String, String>(), - HttpClientUtil.EMPTY_BODY, - /* useCompression= */ false); - return mHttpClient.performRequest(httpRequest); - - } else if (resource.getResourceCase() == Resource.ResourceCase.INLINE_RESOURCE) { - String contentType = OCTET_STREAM; - if (resource.getInlineResource().getCompressionFormat() - == ResourceCompressionFormat.RESOURCE_COMPRESSION_FORMAT_GZIP) { - contentType = contentType.concat(CLIENT_DECODE_GZIP_SUFFIX); - } - return new FederatedComputeHttpResponse.Builder() - .setPayload(resource.getInlineResource().getData().toByteArray()) - .setStatusCode(HTTP_OK_STATUS) - .setHeaders(new HashMap<String, List<String>>()) - .build(); + private ListenableFuture<FederatedComputeHttpResponse> fetchTaskResource(Resource resource) { + switch (resource.getResourceCase()) { + case URI: + Preconditions.checkArgument( + !resource.getUri().isEmpty(), "Resource.uri must be non-empty when set"); + FederatedComputeHttpRequest httpRequest = + FederatedComputeHttpRequest.create( + resource.getUri(), + HttpMethod.GET, + new HashMap<String, String>(), + HttpClientUtil.EMPTY_BODY, + /* useCompression= */ false); + return mHttpClient.performRequestAsync(httpRequest); + case INLINE_RESOURCE: + return Futures.immediateFailedFuture( + new UnsupportedOperationException("Inline resource is not supported yet.")); + default: + return Futures.immediateFailedFuture( + new UnsupportedOperationException("Unknown Resource type")); } - throw new UnsupportedOperationException("Unknown Resource type"); } } diff --git a/federatedcompute/src/com/android/federatedcompute/services/http/ProtocolRequestCreator.java b/federatedcompute/src/com/android/federatedcompute/services/http/ProtocolRequestCreator.java index 968365f8..e87447f4 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/http/ProtocolRequestCreator.java +++ b/federatedcompute/src/com/android/federatedcompute/services/http/ProtocolRequestCreator.java @@ -16,9 +16,7 @@ package com.android.federatedcompute.services.http; -import static com.android.federatedcompute.services.http.HttpClientUtil.API_KEY_HDR; import static com.android.federatedcompute.services.http.HttpClientUtil.CONTENT_TYPE_HDR; -import static com.android.federatedcompute.services.http.HttpClientUtil.FAKE_API_KEY; import static com.android.federatedcompute.services.http.HttpClientUtil.PROTOBUF_CONTENT_TYPE; import com.android.federatedcompute.services.http.HttpClientUtil.HttpMethod; @@ -33,17 +31,12 @@ import java.util.HashMap; */ public final class ProtocolRequestCreator { private final String mRequestBaseUri; - private final String mApiKey; private final HashMap<String, String> mHeaderList; private boolean mUseCompression; public ProtocolRequestCreator( - String requestBaseUri, - String apiKey, - HashMap<String, String> headerList, - boolean useCompression) { + String requestBaseUri, HashMap<String, String> headerList, boolean useCompression) { this.mRequestBaseUri = requestBaseUri; - this.mApiKey = apiKey; this.mHeaderList = headerList; this.mUseCompression = useCompression; } @@ -53,14 +46,14 @@ public final class ProtocolRequestCreator { * base URI for the subsequent requests. */ public static ProtocolRequestCreator create( - String apiKey, ForwardingInfo forwardingInfo, boolean useCompression) { + ForwardingInfo forwardingInfo, boolean useCompression) { if (forwardingInfo.getTargetUriPrefix().isEmpty()) { throw new IllegalArgumentException("Missing `ForwardingInfo.target_uri_prefix`"); } HashMap<String, String> extraHeaders = new HashMap<>(); extraHeaders.putAll(forwardingInfo.getExtraRequestHeadersMap()); return new ProtocolRequestCreator( - forwardingInfo.getTargetUriPrefix(), apiKey, extraHeaders, useCompression); + forwardingInfo.getTargetUriPrefix(), extraHeaders, useCompression); } /** Creates a {@link FederatedComputeHttpRequest} with base uri and compression setting. */ @@ -80,14 +73,11 @@ public final class ProtocolRequestCreator { HashMap<String, String> params, byte[] requestBody, boolean isProtobufEncoded) { - HashMap<String, String> requestHeader = mHeaderList; - requestHeader.put(API_KEY_HDR, mApiKey.isEmpty() ? FAKE_API_KEY : mApiKey); + HashMap<String, String> requestHeader = new HashMap<>(); + requestHeader.putAll(mHeaderList); - if (isProtobufEncoded) { - if (requestBody.length > 0) { - requestHeader.put(CONTENT_TYPE_HDR, PROTOBUF_CONTENT_TYPE); - } - params.put("%24alt", "proto"); + if (isProtobufEncoded && requestBody.length > 0) { + requestHeader.put(CONTENT_TYPE_HDR, PROTOBUF_CONTENT_TYPE); } String requestUriSuffix = uri; diff --git a/federatedcompute/src/com/android/federatedcompute/services/scheduling/FederatedComputeJobManager.java b/federatedcompute/src/com/android/federatedcompute/services/scheduling/FederatedComputeJobManager.java index a9ee67f2..cfd38970 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/scheduling/FederatedComputeJobManager.java +++ b/federatedcompute/src/com/android/federatedcompute/services/scheduling/FederatedComputeJobManager.java @@ -17,6 +17,7 @@ package com.android.federatedcompute.services.scheduling; import static android.federatedcompute.common.ClientConstants.STATUS_INTERNAL_ERROR; +import static android.federatedcompute.common.ClientConstants.STATUS_SUCCESS; import static com.android.federatedcompute.services.scheduling.SchedulingUtil.convertSchedulingMode; @@ -25,17 +26,14 @@ import static java.util.concurrent.TimeUnit.SECONDS; import android.annotation.NonNull; import android.annotation.Nullable; import android.content.Context; -import android.federatedcompute.aidl.IFederatedComputeCallback; import android.federatedcompute.common.TrainingInterval; import android.federatedcompute.common.TrainingOptions; -import android.os.RemoteException; -import android.util.Log; +import com.android.federatedcompute.internal.util.LogUtil; import com.android.federatedcompute.services.common.Clock; import com.android.federatedcompute.services.common.Flags; import com.android.federatedcompute.services.common.MonotonicClock; import com.android.federatedcompute.services.common.PhFlags; -import com.android.federatedcompute.services.common.TrainingResult; import com.android.federatedcompute.services.data.FederatedTrainingTask; import com.android.federatedcompute.services.data.FederatedTrainingTaskDao; import com.android.federatedcompute.services.data.fbs.SchedulingMode; @@ -46,6 +44,7 @@ import com.android.internal.util.Preconditions; import com.google.common.annotations.VisibleForTesting; import com.google.flatbuffers.FlatBufferBuilder; +import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult; import com.google.intelligence.fcp.client.engine.TaskRetry; import java.util.Arrays; @@ -54,14 +53,13 @@ import java.util.Set; /** Handles scheduling training tasks e.g. calling into JobScheduler, maintaining datastore. */ public class FederatedComputeJobManager { - private static final String TAG = "FederatedComputeJobManager"; - + private static final String TAG = FederatedComputeJobManager.class.getSimpleName(); + private static volatile FederatedComputeJobManager sSingletonInstance; @NonNull private final Context mContext; private final FederatedTrainingTaskDao mFederatedTrainingTaskDao; private final JobSchedulerHelper mJobSchedulerHelper; - private static FederatedComputeJobManager sSingletonInstance; private final FederatedJobIdGenerator mJobIdGenerator; - private Clock mClock; + private final Clock mClock; private final Flags mFlags; @VisibleForTesting @@ -72,7 +70,7 @@ public class FederatedComputeJobManager { JobSchedulerHelper jobSchedulerHelper, @NonNull Clock clock, Flags flag) { - this.mContext = context; + this.mContext = context.getApplicationContext(); this.mFederatedTrainingTaskDao = federatedTrainingTaskDao; this.mJobIdGenerator = jobIdGenerator; this.mJobSchedulerHelper = jobSchedulerHelper; @@ -83,45 +81,44 @@ public class FederatedComputeJobManager { /** Returns an instance of FederatedComputeJobManager given a context. */ @NonNull public static FederatedComputeJobManager getInstance(@NonNull Context mContext) { - synchronized (FederatedComputeJobManager.class) { - if (sSingletonInstance == null) { - Clock clock = MonotonicClock.getInstance(); - sSingletonInstance = - new FederatedComputeJobManager( - mContext, - FederatedTrainingTaskDao.getInstance(mContext), - FederatedJobIdGenerator.getInstance(), - new JobSchedulerHelper(clock), - clock, - PhFlags.getInstance()); + if (sSingletonInstance == null) { + synchronized (FederatedComputeJobManager.class) { + if (sSingletonInstance == null) { + Clock clock = MonotonicClock.getInstance(); + sSingletonInstance = + new FederatedComputeJobManager( + mContext.getApplicationContext(), + FederatedTrainingTaskDao.getInstance(mContext), + FederatedJobIdGenerator.getInstance(), + new JobSchedulerHelper(clock), + clock, + PhFlags.getInstance()); + } } - return sSingletonInstance; } + return sSingletonInstance; } /** * Called when a client indicates via the client API that a task with the given parameters * should be scheduled. */ - public synchronized void onTrainerStartCalled( - String callingPackageName, - TrainingOptions trainingOptions, - IFederatedComputeCallback callback) { + public synchronized int onTrainerStartCalled( + String callingPackageName, TrainingOptions trainingOptions) { FederatedTrainingTask existingTask = mFederatedTrainingTaskDao.findAndRemoveTaskByPopulationName( trainingOptions.getPopulationName()); Set<FederatedTrainingTask> trainingTasksToCancel = new HashSet<>(); String populationName = trainingOptions.getPopulationName(); long nowMs = mClock.currentTimeMillis(); - boolean shouldSchedule = false; + boolean shouldSchedule; FederatedTrainingTask newTask; byte[] newTrainingConstraint = buildTrainingConstraints(); + // Federated server address is required to schedule the job. + Preconditions.checkStringNotEmpty(trainingOptions.getServerAddress()); if (existingTask == null) { int jobId = mJobIdGenerator.generateJobId(this.mContext, populationName); - // Federated server address is required to provide when first time schedule the - // job. - Preconditions.checkStringNotEmpty(trainingOptions.getServerAddress()); FederatedTrainingTask.Builder newTaskBuilder = FederatedTrainingTask.builder() .appPackageName(callingPackageName) @@ -130,15 +127,15 @@ public class FederatedComputeJobManager { .lastScheduledTime(nowMs) .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) .constraints(newTrainingConstraint) + .intervalOptions( + buildTrainingIntervalOptions( + trainingOptions.getTrainingInterval())) .populationName(trainingOptions.getPopulationName()) + .contextData(trainingOptions.getContextData()) .serverAddress(trainingOptions.getServerAddress()) .earliestNextRunTime( SchedulingUtil.getEarliestRuntimeForInitialSchedule( nowMs, 0, trainingOptions, mFlags)); - if (trainingOptions.getTrainingInterval() != null) { - newTaskBuilder.intervalOptions( - buildTrainingIntervalOptions(trainingOptions.getTrainingInterval())); - } newTask = newTaskBuilder.build(); shouldSchedule = true; } else { @@ -152,18 +149,18 @@ public class FederatedComputeJobManager { FederatedTrainingTask.Builder newTaskBuilder = existingTask.toBuilder() .constraints(buildTrainingConstraints()) + .serverAddress(trainingOptions.getServerAddress()) + .contextData(trainingOptions.getContextData()) .lastScheduledTime(nowMs); - if (detectKeyParametersChanged(trainingOptions, existingTask, trainingTasksToCancel)) { + if (detectKeyParametersChanged(trainingOptions, existingTask)) { newTaskBuilder.intervalOptions(null).lastRunStartTime(null).lastRunEndTime(null); newTaskBuilder .populationName(trainingOptions.getPopulationName()) + .intervalOptions( + buildTrainingIntervalOptions(trainingOptions.getTrainingInterval())) .earliestNextRunTime( SchedulingUtil.getEarliestRuntimeForInitialSchedule( nowMs, nowMs, trainingOptions, mFlags)); - if (trainingOptions.getTrainingInterval() != null) { - newTaskBuilder.intervalOptions( - buildTrainingIntervalOptions(trainingOptions.getTrainingInterval())); - } shouldSchedule = true; } else { long earliestNextRunTime = @@ -172,10 +169,11 @@ public class FederatedComputeJobManager { long maxExpectedRuntimeSecs = mFlags.getTrainingServiceResultCallbackTimeoutSecs() + /*buffer*/ 30; boolean currentlyRunningHeuristic = - existingTask.lastRunStartTime() < nowMs - && nowMs - existingTask.lastRunStartTime() + existingTask.getLastRunStartTime() < nowMs + && nowMs - existingTask.getLastRunStartTime() < 1000 * maxExpectedRuntimeSecs - && existingTask.lastRunStartTime() > existingTask.lastRunEndTime(); + && existingTask.getLastRunStartTime() + > existingTask.getLastRunEndTime(); shouldSchedule = !currentlyRunningHeuristic && (!mJobSchedulerHelper.isTaskScheduled(mContext, existingTask) @@ -196,10 +194,6 @@ public class FederatedComputeJobManager { shouldSchedule ? SchedulingReason.SCHEDULING_REASON_NEW_TASK : existingTask.schedulingReason()); - if (trainingOptions.getServerAddress() != null - && !trainingOptions.getServerAddress().isEmpty()) { - newTaskBuilder.serverAddress(trainingOptions.getServerAddress()); - } newTask = newTaskBuilder.build(); } @@ -209,13 +203,12 @@ public class FederatedComputeJobManager { if (shouldSchedule) { boolean scheduleResult = mJobSchedulerHelper.scheduleTask(mContext, newTask); if (!scheduleResult) { - Log.w( + LogUtil.w( TAG, - "JobScheduler returned failure when starting training job " - + newTask.jobId()); + "JobScheduler returned failure when starting training job %d", + newTask.jobId()); // If scheduling failed then leave the task store as-is, and bail. - sendError(callback); - return; + return STATUS_INTERNAL_ERROR; } } @@ -223,19 +216,38 @@ public class FederatedComputeJobManager { boolean storeResult = mFederatedTrainingTaskDao.updateOrInsertFederatedTrainingTask(newTask); if (!storeResult) { - Log.w( + LogUtil.w( TAG, - "JobScheduler returned failure when storing training job!" + newTask.jobId()); - sendError(callback); - return; + "JobScheduler returned failure when storing training job with id %d!", + newTask.jobId()); + return STATUS_INTERNAL_ERROR; } // Second, if the task previously had a different job ID or a if there was another // task with the same population name, then cancel the corresponding old tasks. for (FederatedTrainingTask trainingTaskToCancel : trainingTasksToCancel) { - Log.i(TAG, " JobScheduler cancel the task " + newTask.jobId()); + LogUtil.i(TAG, " JobScheduler cancel the task %d", newTask.jobId()); mJobSchedulerHelper.cancelTask(mContext, trainingTaskToCancel); } - sendSuccess(callback); + return STATUS_SUCCESS; + } + + /** + * Called when a client indicates via the client API that a task with the given parameters + * should be canceled. + */ + public synchronized int onTrainerStopCalled(String callingPackageName, String populationName) { + FederatedTrainingTask taskToCancel = + mFederatedTrainingTaskDao.findAndRemoveTaskByPopulationName(populationName); + // If no matching task exists then there's nothing for us to do. This is not an error + // case though. + if (taskToCancel == null) { + LogUtil.i(TAG, "No matching task exists when cancel the job %s", populationName); + return STATUS_SUCCESS; + } + + LogUtil.i(TAG, " onTrainerStopCalled cancel the task %d", taskToCancel.jobId()); + mJobSchedulerHelper.cancelTask(mContext, taskToCancel); + return STATUS_SUCCESS; } /** Called when a training task identified by {@code jobId} starts running. */ @@ -250,7 +262,7 @@ public class FederatedComputeJobManager { long nowMs = mClock.currentTimeMillis(); if (ttlMs > 0 && nowMs - existingTask.lastScheduledTime() > ttlMs) { // If the TTL is expired, then delete the task. - Log.i(TAG, String.format("Training task %d TTLd", jobId)); + LogUtil.i(TAG, "Training task %d TTLd", jobId); return null; } FederatedTrainingTask newTask = existingTask.toBuilder().lastRunStartTime(nowMs).build(); @@ -264,12 +276,12 @@ public class FederatedComputeJobManager { String populationName, TrainingIntervalOptions trainingIntervalOptions, TaskRetry taskRetry, - @TrainingResult int trainingResult) { + ContributionResult trainingResult) { boolean result = rescheduleFederatedTaskAfterTraining( jobId, populationName, trainingIntervalOptions, taskRetry, trainingResult); if (!result) { - Log.e(TAG, "JobScheduler returned failure after successful run!"); + LogUtil.e(TAG, "JobScheduler returned failure after successful run!"); } } @@ -279,7 +291,7 @@ public class FederatedComputeJobManager { String populationName, TrainingIntervalOptions intervalOptions, TaskRetry taskRetry, - @TrainingResult int trainingResult) { + ContributionResult trainingResult) { FederatedTrainingTask existingTask = mFederatedTrainingTaskDao.findAndRemoveTaskByPopulationAndJobId( populationName, jobId); @@ -289,12 +301,12 @@ public class FederatedComputeJobManager { if (existingTask == null) { return true; } - boolean hasContributed = trainingResult == TrainingResult.SUCCESS; + boolean hasContributed = trainingResult == ContributionResult.SUCCESS; if (intervalOptions != null && intervalOptions.schedulingMode() == SchedulingMode.ONE_TIME && hasContributed) { mJobSchedulerHelper.cancelTask(mContext, existingTask); - Log.i(TAG, "federated task remove because oneoff task succeeded: " + jobId); + LogUtil.i(TAG, "federated task remove because oneoff task succeeded: %d", jobId); return true; } // Update the task and add it back to the training task store. @@ -330,15 +342,21 @@ public class FederatedComputeJobManager { return builder.sizedByteArray(); } + private static byte[] buildDefaultTrainingInterval() { + FlatBufferBuilder builder = new FlatBufferBuilder(); + builder.finish( + TrainingIntervalOptions.createTrainingIntervalOptions( + builder, SchedulingMode.ONE_TIME, 0)); + return builder.sizedByteArray(); + } + private static byte[] buildTrainingIntervalOptions( @Nullable TrainingInterval trainingInterval) { - FlatBufferBuilder builder = new FlatBufferBuilder(); if (trainingInterval == null) { - builder.finish( - TrainingIntervalOptions.createTrainingIntervalOptions( - builder, SchedulingMode.ONE_TIME, 0)); - return builder.sizedByteArray(); + return buildDefaultTrainingInterval(); } + + FlatBufferBuilder builder = new FlatBufferBuilder(); builder.finish( TrainingIntervalOptions.createTrainingIntervalOptions( builder, @@ -349,28 +367,25 @@ public class FederatedComputeJobManager { } private boolean detectKeyParametersChanged( - TrainingOptions newTaskOptions, - FederatedTrainingTask existingTask, - Set<FederatedTrainingTask> trainingTasksToCancel) { + TrainingOptions newTaskOptions, FederatedTrainingTask existingTask) { // Check if the task previously had a different population name. boolean populationChanged = !existingTask.populationName().equals(newTaskOptions.getPopulationName()); if (populationChanged) { - Log.i( + LogUtil.i( TAG, - String.format( - "JobScheduler population name changed from %s to %s", - existingTask.populationName(), newTaskOptions.getPopulationName())); + "JobScheduler population name changed from %s to %s", + existingTask.populationName(), + newTaskOptions.getPopulationName()); } boolean trainingIntervalChanged = trainingIntervalChanged(newTaskOptions, existingTask); if (trainingIntervalChanged) { - Log.i( + LogUtil.i( TAG, - String.format( - "JobScheduler training interval changed from %s to %s", - existingTask.getTrainingIntervalOptions(), - newTaskOptions.getTrainingInterval())); + "JobScheduler training interval changed from %s to %s", + existingTask.getTrainingIntervalOptions(), + newTaskOptions.getTrainingInterval()); } return populationChanged || trainingIntervalChanged; } @@ -378,25 +393,7 @@ public class FederatedComputeJobManager { private static boolean trainingIntervalChanged( TrainingOptions newTaskOptions, FederatedTrainingTask existingTask) { byte[] incomingTrainingIntervalOptions = - newTaskOptions.getTrainingInterval() == null - ? null - : buildTrainingIntervalOptions(newTaskOptions.getTrainingInterval()); + buildTrainingIntervalOptions(newTaskOptions.getTrainingInterval()); return !Arrays.equals(incomingTrainingIntervalOptions, existingTask.intervalOptions()); } - - private void sendError(@NonNull IFederatedComputeCallback callback) { - try { - callback.onFailure(STATUS_INTERNAL_ERROR); - } catch (RemoteException e) { - Log.e(TAG, "IFederatedComputeCallback error", e); - } - } - - private void sendSuccess(@NonNull IFederatedComputeCallback callback) { - try { - callback.onSuccess(); - } catch (RemoteException e) { - Log.e(TAG, "IFederatedComputeCallback error", e); - } - } } diff --git a/federatedcompute/src/com/android/federatedcompute/services/scheduling/JobSchedulerHelper.java b/federatedcompute/src/com/android/federatedcompute/services/scheduling/JobSchedulerHelper.java index ce6511f2..d009c910 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/scheduling/JobSchedulerHelper.java +++ b/federatedcompute/src/com/android/federatedcompute/services/scheduling/JobSchedulerHelper.java @@ -20,14 +20,14 @@ import android.app.job.JobInfo; import android.app.job.JobScheduler; import android.content.ComponentName; import android.content.Context; -import android.util.Log; +import com.android.federatedcompute.internal.util.LogUtil; import com.android.federatedcompute.services.common.Clock; import com.android.federatedcompute.services.data.FederatedTrainingTask; /** The helper class of JobScheduler. */ public class JobSchedulerHelper { - private static final String TAG = "JobSchedulerHelper"; + private static final String TAG = JobSchedulerHelper.class.getSimpleName(); private static final String TRAINING_JOB_SERVICE = "com.android.federatedcompute.services.training.FederatedJobService"; private Clock mClock; @@ -39,6 +39,7 @@ public class JobSchedulerHelper { /** Schedules a task using JobScheduler. */ public boolean scheduleTask(Context context, FederatedTrainingTask newTask) { JobInfo jobInfo = convertToJobInfo(context, newTask); + LogUtil.i(TAG, "Scheduling job %s", jobInfo.getId()); return tryScheduleJob(context, jobInfo); } @@ -55,12 +56,11 @@ public class JobSchedulerHelper { private boolean tryScheduleJob(Context context, JobInfo jobInfo) { final JobScheduler jobScheduler = context.getSystemService(JobScheduler.class); if (checkCollidesWithNonFederatedComputationJob(jobScheduler, jobInfo)) { - Log.w( + LogUtil.w( TAG, - String.format( - "Collision with non-FederatedComputation job with same job ID (%s)" - + " detected, not scheduling!", - jobInfo.getId())); + "Collision with non-FederatedComputation job with same job ID (%s)" + + " detected, not scheduling!", + jobInfo.getId()); return false; } diff --git a/federatedcompute/src/com/android/federatedcompute/services/statsd/ApiCallStats.java b/federatedcompute/src/com/android/federatedcompute/services/statsd/ApiCallStats.java new file mode 100644 index 00000000..d9719ef0 --- /dev/null +++ b/federatedcompute/src/com/android/federatedcompute/services/statsd/ApiCallStats.java @@ -0,0 +1,185 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.services.statsd; + +import com.android.ondevicepersonalization.internal.util.DataClass; + +/** + * Class holds FederatedComputeApiCalled defined at + * frameworks/proto_logging/stats/atoms/federatedcompute/federatedcompute_extension_atoms.proto + */ +@DataClass(genBuilder = true, genEqualsHashCode = true) +public class ApiCallStats { + private final int mApiClass; + private final int mApiName; + private final int mLatencyMillis; + private final int mResponseCode; + + // Code below generated by codegen v1.0.23. + // + // DO NOT MODIFY! + // CHECKSTYLE:OFF Generated code + // + // To regenerate run: + // $ codegen + // $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/federatedcompute/src/com/android/federatedcompute/services/statsd/ApiCallStats.java + // + // To exclude the generated code from IntelliJ auto-formatting enable (one-time): + // Settings > Editor > Code Style > Formatter Control + // @formatter:off + + @DataClass.Generated.Member + /* package-private */ ApiCallStats( + int apiClass, int apiName, int latencyMillis, int responseCode) { + this.mApiClass = apiClass; + this.mApiName = apiName; + this.mLatencyMillis = latencyMillis; + this.mResponseCode = responseCode; + + // onConstructed(); // You can define this method to get a callback + } + + @DataClass.Generated.Member + public int getApiClass() { + return mApiClass; + } + + @DataClass.Generated.Member + public int getApiName() { + return mApiName; + } + + @DataClass.Generated.Member + public int getLatencyMillis() { + return mLatencyMillis; + } + + @DataClass.Generated.Member + public int getResponseCode() { + return mResponseCode; + } + + @Override + @DataClass.Generated.Member + public boolean equals(@android.annotation.Nullable Object o) { + // You can override field equality logic by defining either of the methods like: + // boolean fieldNameEquals(ApiCallStats other) { ... } + // boolean fieldNameEquals(FieldType otherValue) { ... } + + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + @SuppressWarnings("unchecked") + ApiCallStats that = (ApiCallStats) o; + //noinspection PointlessBooleanExpression + return true + && mApiClass == that.mApiClass + && mApiName == that.mApiName + && mLatencyMillis == that.mLatencyMillis + && mResponseCode == that.mResponseCode; + } + + @Override + @DataClass.Generated.Member + public int hashCode() { + // You can override field hashCode logic by defining methods like: + // int fieldNameHashCode() { ... } + + int _hash = 1; + _hash = 31 * _hash + mApiClass; + _hash = 31 * _hash + mApiName; + _hash = 31 * _hash + mLatencyMillis; + _hash = 31 * _hash + mResponseCode; + return _hash; + } + + /** A builder for {@link ApiCallStats} */ + @SuppressWarnings("WeakerAccess") + @DataClass.Generated.Member + public static class Builder { + + private int mApiClass; + private int mApiName; + private int mLatencyMillis; + private int mResponseCode; + + private long mBuilderFieldsSet = 0L; + + public Builder() {} + + @DataClass.Generated.Member + public @android.annotation.NonNull Builder setApiClass(int value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x1; + mApiClass = value; + return this; + } + + @DataClass.Generated.Member + public @android.annotation.NonNull Builder setApiName(int value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x2; + mApiName = value; + return this; + } + + @DataClass.Generated.Member + public @android.annotation.NonNull Builder setLatencyMillis(int value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x4; + mLatencyMillis = value; + return this; + } + + @DataClass.Generated.Member + public @android.annotation.NonNull Builder setResponseCode(int value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x8; + mResponseCode = value; + return this; + } + + /** Builds the instance. This builder should not be touched after calling this! */ + public @android.annotation.NonNull ApiCallStats build() { + checkNotUsed(); + mBuilderFieldsSet |= 0x10; // Mark builder used + + ApiCallStats o = new ApiCallStats(mApiClass, mApiName, mLatencyMillis, mResponseCode); + return o; + } + + private void checkNotUsed() { + if ((mBuilderFieldsSet & 0x10) != 0) { + throw new IllegalStateException( + "This Builder should not be reused. Use a new Builder instance instead"); + } + } + } + + @DataClass.Generated( + time = 1695427755295L, + codegenVersion = "1.0.23", + sourceFile = + "packages/modules/OnDevicePersonalization/federatedcompute/src/com/android/federatedcompute/services/statsd/ApiCallStats.java", + inputSignatures = + "private int mApiClass\nprivate int mApiName\nprivate int mLatencyMillis\nprivate int mResponseCode\nclass ApiCallStats extends java.lang.Object implements []\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") + @Deprecated + private void __metadata() {} + + // @formatter:on + // End of generated code + +} diff --git a/federatedcompute/src/com/android/federatedcompute/services/statsd/FederatedComputeStatsdLogger.java b/federatedcompute/src/com/android/federatedcompute/services/statsd/FederatedComputeStatsdLogger.java new file mode 100644 index 00000000..69be1594 --- /dev/null +++ b/federatedcompute/src/com/android/federatedcompute/services/statsd/FederatedComputeStatsdLogger.java @@ -0,0 +1,48 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.services.statsd; + +import static com.android.federatedcompute.services.stats.FederatedComputeStatsLog.FEDERATED_COMPUTE_API_CALLED; + +import com.android.federatedcompute.services.stats.FederatedComputeStatsLog; + +/** Log API stats and client error stats to StatsD. */ +public class FederatedComputeStatsdLogger { + private static volatile FederatedComputeStatsdLogger sFCStatsdLogger = null; + + /** Returns an instance of {@link FederatedComputeStatsdLogger}. */ + public static FederatedComputeStatsdLogger getInstance() { + if (sFCStatsdLogger == null) { + synchronized (FederatedComputeStatsdLogger.class) { + if (sFCStatsdLogger == null) { + sFCStatsdLogger = new FederatedComputeStatsdLogger(); + } + } + } + return sFCStatsdLogger; + } + + /** Log API call stats e.g. response code, API name etc. */ + public void logApiCallStats(ApiCallStats apiCallStats) { + FederatedComputeStatsLog.write( + FEDERATED_COMPUTE_API_CALLED, + apiCallStats.getApiClass(), + apiCallStats.getApiName(), + apiCallStats.getLatencyMillis(), + apiCallStats.getResponseCode()); + } +} diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/ComputationRunner.java b/federatedcompute/src/com/android/federatedcompute/services/training/ComputationRunner.java index a80a266d..da42298b 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/training/ComputationRunner.java +++ b/federatedcompute/src/com/android/federatedcompute/services/training/ComputationRunner.java @@ -16,15 +16,14 @@ package com.android.federatedcompute.services.training; -import android.content.Context; import android.federatedcompute.aidl.IExampleStoreIterator; -import android.federatedcompute.aidl.IResultHandlingService; import com.android.federatedcompute.services.examplestore.ExampleConsumptionRecorder; +import com.android.federatedcompute.services.examplestore.FederatedExampleIterator; +import com.android.federatedcompute.services.training.jni.FlRunnerWrapper; import com.android.federatedcompute.services.training.util.ListenableSupplier; import com.google.intelligence.fcp.client.FLRunnerResult; -import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult; import com.google.internal.federated.plan.ClientOnlyPlan; import com.google.internal.federated.plan.ExampleSelector; @@ -33,15 +32,12 @@ import com.google.internal.federated.plan.ExampleSelector; * start federated ananlytic and federated training jobs. */ public class ComputationRunner { - private final String mPackageName; - public ComputationRunner(Context context) { - this.mPackageName = context.getPackageName(); - } + public ComputationRunner() {} /** Run a single round of federated computation. */ public FLRunnerResult runTaskWithNativeRunner( - int jobId, + String taskName, String populationName, String inputCheckpointFd, String outputCheckpointFd, @@ -49,11 +45,25 @@ public class ComputationRunner { ExampleSelector exampleSelector, ExampleConsumptionRecorder recorder, IExampleStoreIterator exampleStoreIterator, - IResultHandlingService resultHandlingService, ListenableSupplier<Boolean> interruptState) { - // TODO(b/241799297): add native fl runner to call fcp client. - return FLRunnerResult.newBuilder() - .setContributionResult(ContributionResult.SUCCESS) - .build(); + byte[] resumptionToken = exampleSelector.getResumptionToken().toByteArray(); + FederatedExampleIterator federatedExampleIterator = + new FederatedExampleIterator( + exampleStoreIterator, + resumptionToken, + recorder.createRecorderForTracking(taskName, resumptionToken)); + + FlRunnerWrapper flRunnerWrapper = + new FlRunnerWrapper(interruptState, populationName, federatedExampleIterator); + + FLRunnerResult runResult = + flRunnerWrapper.run( + taskName, + populationName, + clientOnlyPlan, + inputCheckpointFd, + outputCheckpointFd); + + return runResult; } } diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java b/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java index 15b2ad33..c991b916 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java +++ b/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java @@ -16,128 +16,681 @@ package com.android.federatedcompute.services.training; +import static com.android.federatedcompute.services.common.Constants.CLIENT_ONLY_PLAN_FILE_NAME; +import static com.android.federatedcompute.services.common.Constants.ISOLATED_TRAINING_SERVICE_NAME; +import static com.android.federatedcompute.services.common.FederatedComputeExecutors.getBackgroundExecutor; +import static com.android.federatedcompute.services.common.FederatedComputeExecutors.getLightweightExecutor; +import static com.android.federatedcompute.services.common.FileUtils.createTempFile; +import static com.android.federatedcompute.services.common.FileUtils.createTempFileDescriptor; + import android.annotation.NonNull; import android.annotation.Nullable; import android.content.Context; -import android.util.Log; +import android.federatedcompute.aidl.IExampleStoreCallback; +import android.federatedcompute.aidl.IExampleStoreIterator; +import android.federatedcompute.aidl.IExampleStoreService; +import android.federatedcompute.common.ClientConstants; +import android.federatedcompute.common.ExampleConsumption; +import android.os.Bundle; +import android.os.ParcelFileDescriptor; + +import androidx.concurrent.futures.CallbackToFutureAdapter; -import com.android.federatedcompute.services.common.Flags; -import com.android.federatedcompute.services.common.PhFlags; -import com.android.federatedcompute.services.common.TrainingResult; +import com.android.federatedcompute.internal.util.AbstractServiceBinder; +import com.android.federatedcompute.internal.util.LogUtil; +import com.android.federatedcompute.services.common.Constants; +import com.android.federatedcompute.services.common.FileUtils; import com.android.federatedcompute.services.data.FederatedTrainingTask; +import com.android.federatedcompute.services.data.fbs.TrainingConstraints; +import com.android.federatedcompute.services.examplestore.ExampleConsumptionRecorder; +import com.android.federatedcompute.services.http.CheckinResult; +import com.android.federatedcompute.services.http.HttpFederatedProtocol; import com.android.federatedcompute.services.scheduling.FederatedComputeJobManager; -import com.android.federatedcompute.services.scheduling.SchedulingUtil; +import com.android.federatedcompute.services.training.aidl.IIsolatedTrainingService; +import com.android.federatedcompute.services.training.aidl.ITrainingResultCallback; +import com.android.federatedcompute.services.training.util.ComputationResult; +import com.android.federatedcompute.services.training.util.ListenableSupplier; +import com.android.federatedcompute.services.training.util.TrainingConditionsChecker; +import com.android.federatedcompute.services.training.util.TrainingConditionsChecker.Condition; +import com.android.internal.annotations.GuardedBy; +import com.android.internal.util.Preconditions; import com.google.common.annotations.VisibleForTesting; +import com.google.common.util.concurrent.FluentFuture; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ListeningExecutorService; +import com.google.intelligence.fcp.client.FLRunnerResult; +import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult; +import com.google.intelligence.fcp.client.RetryInfo; import com.google.intelligence.fcp.client.engine.TaskRetry; +import com.google.internal.federated.plan.ClientOnlyPlan; +import com.google.internal.federated.plan.ExampleSelector; +import com.google.protobuf.InvalidProtocolBufferException; -import javax.annotation.concurrent.GuardedBy; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; /** The worker to execute federated computation jobs. */ public class FederatedComputeWorker { - private static final String TAG = "FederatedComputeWorker"; - - static final Object LOCK = new Object(); + private static final String TAG = FederatedComputeWorker.class.getSimpleName(); + private static volatile FederatedComputeWorker sWorker; + private final Object mLock = new Object(); + private final AtomicBoolean mInterruptFlag = new AtomicBoolean(false); + private final ListenableSupplier<Boolean> mInterruptSupplier = + new ListenableSupplier<>(mInterruptFlag::get); + private final Context mContext; + @Nullable private final FederatedComputeJobManager mJobManager; + @Nullable private final TrainingConditionsChecker mTrainingConditionsChecker; + private final ComputationRunner mComputationRunner; + private final ResultCallbackHelper mResultCallbackHelper; + @NonNull private final Injector mInjector; - @GuardedBy("LOCK") + @GuardedBy("mLock") @Nullable private TrainingRun mActiveRun = null; - @Nullable private final FederatedComputeJobManager mJobManager; - private static volatile FederatedComputeWorker sFederatedComputeWorker; - private final Flags mFlags; + private HttpFederatedProtocol mHttpFederatedProtocol; + private AbstractServiceBinder<IExampleStoreService> mExampleStoreServiceBinder; + private AbstractServiceBinder<IIsolatedTrainingService> mIsolatedTrainingServiceBinder; @VisibleForTesting - public FederatedComputeWorker(FederatedComputeJobManager jobManager, Flags flags) { + public FederatedComputeWorker( + Context context, + FederatedComputeJobManager jobManager, + TrainingConditionsChecker trainingConditionsChecker, + ComputationRunner computationRunner, + ResultCallbackHelper resultCallbackHelper, + Injector injector) { + this.mContext = context.getApplicationContext(); this.mJobManager = jobManager; - this.mFlags = flags; + this.mTrainingConditionsChecker = trainingConditionsChecker; + this.mComputationRunner = computationRunner; + this.mInjector = injector; + this.mResultCallbackHelper = resultCallbackHelper; } /** Gets an instance of {@link FederatedComputeWorker}. */ @NonNull public static FederatedComputeWorker getInstance(Context context) { - synchronized (FederatedComputeWorker.class) { - if (sFederatedComputeWorker == null) { - sFederatedComputeWorker = - new FederatedComputeWorker( - FederatedComputeJobManager.getInstance(context), - PhFlags.getInstance()); + if (sWorker == null) { + synchronized (FederatedComputeWorker.class) { + if (sWorker == null) { + sWorker = + new FederatedComputeWorker( + context, + FederatedComputeJobManager.getInstance(context), + TrainingConditionsChecker.getInstance(context), + new ComputationRunner(), + new ResultCallbackHelper(context), + new Injector()); + } } - return sFederatedComputeWorker; } + return sWorker; } /** Starts a training run with the given job Id. */ - public boolean startTrainingRun(int jobId) { - Log.d(TAG, "startTrainingRun()"); + public ListenableFuture<FLRunnerResult> startTrainingRun(int jobId) { + LogUtil.d(TAG, "startTrainingRun() %d", jobId); + return FluentFuture.from( + mInjector + .getBgExecutor() + .submit( + () -> { + return getTrainableTask(jobId); + })) + .transformAsync( + task -> { + if (task == null) { + return Futures.immediateFuture(null); + } + return startTrainingRun(jobId, task); + }, + mInjector.getBgExecutor()); + } + + private ListenableFuture<FLRunnerResult> startTrainingRun( + int jobId, FederatedTrainingTask trainingTask) { + synchronized (mLock) { + // Only allows one concurrent job running. + TrainingRun run = new TrainingRun(jobId, trainingTask); + mActiveRun = run; + ListenableFuture<FLRunnerResult> runCompletedFuture = doTraining(run); + var unused = + Futures.whenAllComplete(runCompletedFuture) + .call( + () -> { + unBindServicesIfNecessary(run); + return null; + }, + mInjector.getBgExecutor()); + run.mFuture = runCompletedFuture; + return runCompletedFuture; + } + } + + @Nullable + private FederatedTrainingTask getTrainableTask(int jobId) { FederatedTrainingTask trainingTask = mJobManager.onTrainingStarted(jobId); if (trainingTask == null) { - Log.i(TAG, String.format("Could not find task to run for job ID %s", jobId)); - return false; + LogUtil.i(TAG, "Could not find task to run for job ID %s", jobId); + return null; } - - synchronized (LOCK) { - // Only allow one concurrent federated computation job. + if (!checkTrainingConditions(trainingTask.getTrainingConstraints())) { + mJobManager.onTrainingCompleted( + jobId, + trainingTask.populationName(), + trainingTask.getTrainingIntervalOptions(), + /* taskRetry= */ null, + ContributionResult.FAIL); + LogUtil.i(TAG, "Training conditions not satisfied (before bindService)!"); + return null; + } + synchronized (mLock) { + // Only allows one concurrent job running. if (mActiveRun != null) { - Log.i( + LogUtil.i( TAG, - String.format( - "Delaying %d/%s another run is already active!", - jobId, trainingTask.populationName())); + "Delaying %d/%s another run is already active!", + jobId, + trainingTask.populationName()); mJobManager.onTrainingCompleted( jobId, trainingTask.populationName(), trainingTask.getTrainingIntervalOptions(), /* taskRetry= */ null, - TrainingResult.FAIL); - return false; + ContributionResult.FAIL); + return null; } - TrainingRun run = new TrainingRun(jobId, trainingTask); - this.mActiveRun = run; - doTraining(run); - // TODO: get retry info from federated server. - TaskRetry taskRetry = SchedulingUtil.generateTransientErrorTaskRetry(mFlags); - finish(this.mActiveRun, taskRetry, TrainingResult.SUCCESS); + return trainingTask; } - return true; } - /** Cancels the running job if present. */ - public void cancelActiveRun() { - Log.d(TAG, "cancelActiveRun()"); - synchronized (LOCK) { - if (mActiveRun == null) { - return; + private ListenableFuture<FLRunnerResult> doTraining(TrainingRun run) { + try { + // 1. Communicate with remote federated compute server to start task assignment and + // download client plan and initial model checkpoint. Note: use bLocking executors for + // http requests. + mHttpFederatedProtocol = + getHttpFederatedProtocol(run.mTask.serverAddress(), run.mTask.populationName()); + ListenableFuture<CheckinResult> checkinResultFuture = + mHttpFederatedProtocol.issueCheckin(); + + // 2. Bind to client app implemented ExampleStoreService based on ExampleSelector. + ListenableFuture<IExampleStoreIterator> iteratorFuture = + FluentFuture.from(checkinResultFuture) + .transform( + result -> { + // Set active run's task name. + String taskName = result.getTaskAssignment().getTaskName(); + Preconditions.checkArgument( + !taskName.isEmpty(), + "Task name should not be empty"); + synchronized (mLock) { + mActiveRun.mTaskName = taskName; + } + return getExampleSelector(result); + }, + getLightweightExecutor()) + .transformAsync( + selector -> + getExampleStoreIterator( + run, + run.mTask.appPackageName(), + run.mTaskName, + selector), + mInjector.getBgExecutor()); + + // 3. Run federated learning or federated analytic depends on task type. Federated + // learning job will start a new isolated process to run TFLite training. + ListenableFuture<ComputationResult> computationResultFuture = + Futures.whenAllSucceed(checkinResultFuture, iteratorFuture) + .callAsync( + () -> + runFederatedComputation( + Futures.getDone(checkinResultFuture), + run, + Futures.getDone(iteratorFuture)), + mInjector.getBgExecutor()); + + // 4. Report computation result to federated compute server. + ListenableFuture<Void> reportToServerFuture = + FluentFuture.from(computationResultFuture) + .transformAsync( + result -> mHttpFederatedProtocol.reportResult(result), + getLightweightExecutor()); + return FluentFuture.from( + Futures.whenAllSucceed(reportToServerFuture, computationResultFuture) + .call( + () -> { + ComputationResult result = + Futures.getDone(computationResultFuture); + var reportToServer = Futures.getDone(reportToServerFuture); + // 5. Publish computation result and consumed + // examples to client implemented + // ResultHandlingService. + var unused = + mResultCallbackHelper.callHandleResult( + run.mTaskName, run.mTask, result); + return result.getFlRunnerResult(); + }, + mInjector.getBgExecutor())); + } catch (Exception e) { + return Futures.immediateFailedFuture(e); + } + } + + /** + * Completes the running job , schedule recurrent job, and unbind from ExampleStoreService and + * ResultHandlingService etc. + */ + public void finish(FLRunnerResult flRunnerResult) { + TaskRetry taskRetry = null; + if (flRunnerResult != null) { + if (flRunnerResult.hasRetryInfo()) { + RetryInfo retryInfo = flRunnerResult.getRetryInfo(); + long delay = retryInfo.getMinimumDelay().getSeconds() * 1000L; + taskRetry = + TaskRetry.newBuilder() + .setRetryToken(retryInfo.getRetryToken()) + .setDelayMin(delay) + .setDelayMax(delay) + .build(); + LogUtil.i(TAG, "Finished with task retry= %s", taskRetry); } - finish(mActiveRun, /* taskRetry= */ null, TrainingResult.FAIL); } + finish(taskRetry, flRunnerResult.getContributionResult(), true); } - private void finish( - TrainingRun runToFinish, TaskRetry taskRetry, @TrainingResult int trainingResult) { - synchronized (LOCK) { - if (mActiveRun != runToFinish) { + /** + * Cancel the current running job, schedule recurrent job, unbind from ExampleStoreService and + * ResultHandlingService etc. + */ + public void finish( + TaskRetry taskRetry, ContributionResult contributionResult, boolean cancelFuture) { + TrainingRun runToFinish; + synchronized (mLock) { + if (mActiveRun == null) { return; } + + runToFinish = mActiveRun; mActiveRun = null; - mJobManager.onTrainingCompleted( - runToFinish.mJobId, - runToFinish.mTask.populationName(), - runToFinish.mTask.getTrainingIntervalOptions(), - taskRetry, - trainingResult); + if (cancelFuture) { + runToFinish.mFuture.cancel(true); + } } + + mJobManager.onTrainingCompleted( + runToFinish.mJobId, + runToFinish.mTask.populationName(), + runToFinish.mTask.getTrainingIntervalOptions(), + taskRetry, + contributionResult); } - private void doTraining(TrainingRun run) { - // TODO: add training logic. - Log.d(TAG, "Start run training job " + run.mJobId); + private void unBindServicesIfNecessary(TrainingRun runToFinish) { + if (runToFinish.mIsolatedTrainingService != null) { + LogUtil.i(TAG, "Unbinding from IsolatedTrainingService"); + unbindFromIsolatedTrainingService(); + runToFinish.mIsolatedTrainingService = null; + } + if (runToFinish.mExampleStoreService != null) { + LogUtil.i(TAG, "Unbinding from ExampleStoreService"); + unbindFromExampleStoreService(); + runToFinish.mExampleStoreService = null; + } + } + + @VisibleForTesting + HttpFederatedProtocol getHttpFederatedProtocol(String serverAddress, String populationName) { + return HttpFederatedProtocol.create(serverAddress, "1.0.0.1", populationName); + } + + private ExampleSelector getExampleSelector(CheckinResult checkinResult) { + ClientOnlyPlan clientPlan = checkinResult.getPlanData(); + switch (clientPlan.getPhase().getSpecCase()) { + case EXAMPLE_QUERY_SPEC: + // Only support one FA query for now. + return clientPlan + .getPhase() + .getExampleQuerySpec() + .getExampleQueries(0) + .getExampleSelector(); + case TENSORFLOW_SPEC: + return clientPlan.getPhase().getTensorflowSpec().getExampleSelector(); + default: + throw new IllegalArgumentException( + String.format( + "Client plan spec is not supported %s", + clientPlan.getPhase().getSpecCase().toString())); + } + } + + private boolean checkTrainingConditions(TrainingConstraints constraints) { + Set<Condition> conditions = + mTrainingConditionsChecker.checkAllConditionsForFlTraining(constraints); + for (Condition condition : conditions) { + switch (condition) { + case THERMALS_NOT_OK: + LogUtil.e(TAG, "training job service interrupt thermals not ok"); + break; + case BATTERY_NOT_OK: + LogUtil.e(TAG, "training job service interrupt battery not ok"); + break; + } + } + return conditions.isEmpty(); + } + + @VisibleForTesting + ListenableFuture<ComputationResult> runFlComputation( + TrainingRun run, + CheckinResult checkinResult, + String outputCheckpointFile, + IExampleStoreIterator iterator) { + ParcelFileDescriptor outputCheckpointFd = + createTempFileDescriptor( + outputCheckpointFile, ParcelFileDescriptor.MODE_READ_WRITE); + ParcelFileDescriptor inputCheckpointFd = + createTempFileDescriptor( + checkinResult.getInputCheckpointFile(), + ParcelFileDescriptor.MODE_READ_ONLY); + ExampleSelector exampleSelector = getExampleSelector(checkinResult); + ClientOnlyPlan clientPlan = checkinResult.getPlanData(); + if (clientPlan.getTfliteGraph().isEmpty()) { + LogUtil.e( + TAG, + "ClientOnlyPlan input tflite graph is empty." + + " population name: %s, task name: %s", + run.mTask.populationName(), + run.mTaskName); + return Futures.immediateFailedFuture( + new IllegalStateException("Client plan input tflite graph is empty")); + } + + try { + // Write ClientOnlyPlan to file and pass ParcelFileDescriptor to isolated process to + // avoid TransactionTooLargeException through IPC. + String clientOnlyPlanFile = createTempFile(CLIENT_ONLY_PLAN_FILE_NAME, ".pb"); + FileUtils.writeToFile(clientOnlyPlanFile, clientPlan.toByteArray()); + ParcelFileDescriptor clientPlanFd = + createTempFileDescriptor( + clientOnlyPlanFile, ParcelFileDescriptor.MODE_READ_ONLY); + IIsolatedTrainingService trainingService = getIsolatedTrainingService(); + if (trainingService == null) { + LogUtil.w(TAG, "Could not bind to IsolatedTrainingService"); + throw new IllegalStateException("Could not bind to IsolatedTrainingService"); + } + run.mIsolatedTrainingService = trainingService; + + Bundle bundle = new Bundle(); + bundle.putByteArray(Constants.EXTRA_EXAMPLE_SELECTOR, exampleSelector.toByteArray()); + bundle.putString(ClientConstants.EXTRA_POPULATION_NAME, run.mTask.populationName()); + bundle.putString(ClientConstants.EXTRA_TASK_NAME, run.mTaskName); + bundle.putParcelable(Constants.EXTRA_CLIENT_ONLY_PLAN_FD, clientPlanFd); + bundle.putParcelable(Constants.EXTRA_INPUT_CHECKPOINT_FD, inputCheckpointFd); + bundle.putParcelable(Constants.EXTRA_OUTPUT_CHECKPOINT_FD, outputCheckpointFd); + bundle.putBinder(Constants.EXTRA_EXAMPLE_STORE_ITERATOR_BINDER, iterator.asBinder()); + + return FluentFuture.from(runIsolatedTrainingProcess(run, bundle)) + .transform( + result -> { + ComputationResult computationResult = + processIsolatedTrainingResult(outputCheckpointFile, result); + // Close opened file descriptor. + try { + if (outputCheckpointFd != null) { + outputCheckpointFd.close(); + } + if (inputCheckpointFd != null) { + inputCheckpointFd.close(); + } + } catch (IOException e) { + LogUtil.e(TAG, "Failed to close file descriptor", e); + } finally { + // Unbind from IsolatedTrainingService. + LogUtil.i(TAG, "Unbinding from IsolatedTrainingService"); + unbindFromIsolatedTrainingService(); + run.mIsolatedTrainingService = null; + } + return computationResult; + }, + getLightweightExecutor()); + } catch (Exception e) { + // Close opened file descriptor. + try { + if (outputCheckpointFd != null) { + outputCheckpointFd.close(); + } + if (inputCheckpointFd != null) { + inputCheckpointFd.close(); + } + } catch (IOException t) { + LogUtil.e(TAG, t, "Failed to close file descriptor"); + } finally { + // Unbind from IsolatedTrainingService. + LogUtil.i(TAG, "Unbinding from IsolatedTrainingService"); + unbindFromIsolatedTrainingService(); + run.mIsolatedTrainingService = null; + } + return Futures.immediateFailedFuture(e); + } + } + + private ComputationResult processIsolatedTrainingResult( + String outputCheckpoint, Bundle result) { + byte[] resultBytes = + Objects.requireNonNull(result.getByteArray(Constants.EXTRA_FL_RUNNER_RESULT)); + FLRunnerResult flRunnerResult; + try { + flRunnerResult = FLRunnerResult.parseFrom(resultBytes); + } catch (InvalidProtocolBufferException e) { + throw new IllegalArgumentException(e); + } + if (flRunnerResult.getContributionResult() == ContributionResult.FAIL) { + return new ComputationResult(outputCheckpoint, flRunnerResult, new ArrayList<>()); + } + ArrayList<ExampleConsumption> exampleList = + result.getParcelableArrayList( + ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, ExampleConsumption.class); + if (exampleList == null || exampleList.isEmpty()) { + throw new IllegalArgumentException("example consumption list should not be empty"); + } + + return new ComputationResult(outputCheckpoint, flRunnerResult, exampleList); + } + + private ListenableFuture<Bundle> runIsolatedTrainingProcess(TrainingRun run, Bundle input) { + return CallbackToFutureAdapter.getFuture( + completer -> { + try { + run.mIsolatedTrainingService.runFlTraining( + input, + new ITrainingResultCallback.Stub() { + @Override + public void onResult(Bundle result) { + completer.set(result); + } + }); + } catch (Exception e) { + LogUtil.e(TAG, e, "Got exception when runIsolatedTrainingProcess"); + completer.setException(e); + } + return "runIsolatedTrainingProcess"; + }); + } + + private ListenableFuture<ComputationResult> runFederatedComputation( + CheckinResult checkinResult, TrainingRun run, IExampleStoreIterator iterator) { + ClientOnlyPlan clientPlan = checkinResult.getPlanData(); + String outputCheckpointFile = createTempFile("output", ".ckp"); + + ListenableFuture<ComputationResult> computationResultFuture; + switch (clientPlan.getPhase().getSpecCase()) { + case EXAMPLE_QUERY_SPEC: + computationResultFuture = + runFAComputation(run, checkinResult, outputCheckpointFile, iterator); + break; + case TENSORFLOW_SPEC: + computationResultFuture = + runFlComputation(run, checkinResult, outputCheckpointFile, iterator); + break; + default: + return Futures.immediateFailedFuture( + new IllegalArgumentException( + String.format( + "Client plan spec is not supported %s", + clientPlan.getPhase().getSpecCase().toString()))); + } + return computationResultFuture; + } + + private ListenableFuture<ComputationResult> runFAComputation( + TrainingRun run, + CheckinResult checkinResult, + String outputCheckpointFile, + IExampleStoreIterator exampleStoreIterator) { + ExampleSelector exampleSelector = getExampleSelector(checkinResult); + ClientOnlyPlan clientPlan = checkinResult.getPlanData(); + // The federated analytic runs in main process which has permission to file system. + ExampleConsumptionRecorder recorder = mInjector.getExampleConsumptionRecorder(); + FLRunnerResult runResult = + mComputationRunner.runTaskWithNativeRunner( + run.mTaskName, + run.mTask.populationName(), + checkinResult.getInputCheckpointFile(), + outputCheckpointFile, + clientPlan, + exampleSelector, + recorder, + exampleStoreIterator, + mInterruptSupplier); + ArrayList<ExampleConsumption> exampleConsumptions = recorder.finishRecordingAndGet(); + return Futures.immediateFuture( + new ComputationResult(outputCheckpointFile, runResult, exampleConsumptions)); + } + + @VisibleForTesting + IExampleStoreService getExampleStoreService(String packageName) { + mExampleStoreServiceBinder = + AbstractServiceBinder.getServiceBinderByIntent( + mContext, + ClientConstants.EXAMPLE_STORE_ACTION, + packageName, + IExampleStoreService.Stub::asInterface); + return mExampleStoreServiceBinder.getService(Runnable::run); + } + + @VisibleForTesting + void unbindFromExampleStoreService() { + mExampleStoreServiceBinder.unbindFromService(); + } + + private ListenableFuture<IExampleStoreIterator> runExampleStoreStartQuery( + TrainingRun run, Bundle input) { + return CallbackToFutureAdapter.getFuture( + completer -> { + try { + run.mExampleStoreService.startQuery( + input, + new IExampleStoreCallback.Stub() { + @Override + public void onStartQuerySuccess( + IExampleStoreIterator iterator) { + LogUtil.d(TAG, "Acquire iterator"); + completer.set(iterator); + } + + @Override + public void onStartQueryFailure(int errorCode) { + LogUtil.e(TAG, "Could not acquire iterator: " + errorCode); + completer.setException( + new IllegalStateException( + "StartQuery failed: " + errorCode)); + } + }); + } catch (Exception e) { + completer.setException(e); + } + return "runExampleStoreStartQuery"; + }); + } + + private ListenableFuture<IExampleStoreIterator> getExampleStoreIterator( + TrainingRun run, String packageName, String taskName, ExampleSelector exampleSelector) { + try { + run.mTaskName = taskName; + + IExampleStoreService exampleStoreService = getExampleStoreService(packageName); + if (exampleStoreService == null) { + return Futures.immediateFailedFuture( + new IllegalStateException( + "Could not bind to ExampleStoreService " + packageName)); + } + run.mExampleStoreService = exampleStoreService; + + byte[] criteria = exampleSelector.getCriteria().toByteArray(); + byte[] resumptionToken = exampleSelector.getResumptionToken().toByteArray(); + Bundle bundle = new Bundle(); + bundle.putString(ClientConstants.EXTRA_POPULATION_NAME, run.mTask.populationName()); + bundle.putString(ClientConstants.EXTRA_TASK_NAME, taskName); + bundle.putByteArray(ClientConstants.EXTRA_CONTEXT_DATA, run.mTask.contextData()); + bundle.putByteArray( + ClientConstants.EXTRA_EXAMPLE_ITERATOR_RESUMPTION_TOKEN, resumptionToken); + bundle.putByteArray(ClientConstants.EXTRA_EXAMPLE_ITERATOR_CRITERIA, criteria); + + return runExampleStoreStartQuery(run, bundle); + } catch (Exception e) { + LogUtil.e(TAG, "StartQuery failure: " + e.getMessage()); + return Futures.immediateFailedFuture(e); + } + } + + @VisibleForTesting + @Nullable + IIsolatedTrainingService getIsolatedTrainingService() { + mIsolatedTrainingServiceBinder = + AbstractServiceBinder.getServiceBinderByServiceName( + mContext, + ISOLATED_TRAINING_SERVICE_NAME, + mContext.getPackageName(), + IIsolatedTrainingService.Stub::asInterface); + return mIsolatedTrainingServiceBinder.getService(Runnable::run); + } + + @VisibleForTesting + void unbindFromIsolatedTrainingService() { + mIsolatedTrainingServiceBinder.unbindFromService(); + } + + @VisibleForTesting + static class Injector { + ExampleConsumptionRecorder getExampleConsumptionRecorder() { + return new ExampleConsumptionRecorder(); + } + + ListeningExecutorService getBgExecutor() { + return getBackgroundExecutor(); + } } private static final class TrainingRun { private final int mJobId; + + private String mTaskName; private final FederatedTrainingTask mTask; + @Nullable private ListenableFuture<?> mFuture; + + @Nullable private IIsolatedTrainingService mIsolatedTrainingService = null; + + @Nullable private IExampleStoreService mExampleStoreService = null; + private TrainingRun(int jobId, FederatedTrainingTask task) { this.mJobId = jobId; this.mTask = task; diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/FederatedJobService.java b/federatedcompute/src/com/android/federatedcompute/services/training/FederatedJobService.java index 9f31127f..d15b37ea 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/training/FederatedJobService.java +++ b/federatedcompute/src/com/android/federatedcompute/services/training/FederatedJobService.java @@ -16,66 +16,64 @@ package com.android.federatedcompute.services.training; -import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static com.android.federatedcompute.services.common.FederatedComputeExecutors.getBackgroundExecutor; import android.app.job.JobParameters; import android.app.job.JobService; -import android.util.Log; -import com.android.federatedcompute.services.common.FederatedComputeExecutors; +import com.android.federatedcompute.internal.util.LogUtil; import com.android.federatedcompute.services.common.FlagsFactory; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import com.google.intelligence.fcp.client.FLRunnerResult; +import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult; /** Main service for the scheduled federated computation jobs. */ public class FederatedJobService extends JobService { - private static final String TAG = "FederatedJobService"; - private ListenableFuture<Boolean> mRunCompleteFuture; + private static final String TAG = FederatedJobService.class.getSimpleName(); @Override public boolean onStartJob(JobParameters params) { - Log.d(TAG, "FederatedJobService.onStartJob"); + LogUtil.d(TAG, "FederatedJobService.onStartJob"); if (FlagsFactory.getFlags().getGlobalKillSwitch()) { - Log.d(TAG, "GlobalKillSwitch enabled, finishing job."); + LogUtil.d(TAG, "GlobalKillSwitch enabled, finishing job."); jobFinished(params, /* wantsReschedule= */ false); return true; } - mRunCompleteFuture = - Futures.submit( - () -> - FederatedComputeWorker.getInstance(this) - .startTrainingRun(params.getJobId()), - FederatedComputeExecutors.getBackgroundExecutor()); + FederatedComputeWorker worker = FederatedComputeWorker.getInstance(this); + ListenableFuture<FLRunnerResult> runCompleteFuture = + worker.startTrainingRun(params.getJobId()); Futures.addCallback( - mRunCompleteFuture, - new FutureCallback<Boolean>() { + runCompleteFuture, + new FutureCallback<FLRunnerResult>() { @Override - public void onSuccess(Boolean result) { - Log.d(TAG, "federated computation job is done!"); + public void onSuccess(FLRunnerResult flRunnerResult) { + LogUtil.d(TAG, "Federated computation job %d is done!", params.getJobId()); jobFinished(params, /* wantsReschedule= */ false); + if (flRunnerResult != null) { + worker.finish(flRunnerResult); + } } @Override public void onFailure(Throwable t) { - Log.e(TAG, "Failed to handle computation job: " + params.getJobId()); + LogUtil.e( + TAG, t, "Failed to handle computation job: %d", params.getJobId()); jobFinished(params, /* wantsReschedule= */ false); + worker.finish(null, ContributionResult.FAIL, false); } }, - directExecutor()); + getBackgroundExecutor()); return true; } @Override public boolean onStopJob(JobParameters params) { - if (mRunCompleteFuture != null) { - mRunCompleteFuture.cancel(true); - } - FederatedComputeWorker.getInstance(this).cancelActiveRun(); - // Reschedule the job since it's not done. TODO: we should implement specify reschedule - // logic instead. - return true; + LogUtil.d(TAG, "FederatedJobService.onStopJob %d", params.getJobId()); + FederatedComputeWorker.getInstance(this).finish(null, ContributionResult.FAIL, true); + return false; } } diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/IsolatedTrainingService.java b/federatedcompute/src/com/android/federatedcompute/services/training/IsolatedTrainingService.java index 1d98500e..31363db4 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/training/IsolatedTrainingService.java +++ b/federatedcompute/src/com/android/federatedcompute/services/training/IsolatedTrainingService.java @@ -33,7 +33,7 @@ public class IsolatedTrainingService extends Service { @Override public void onCreate() { - mBinder = new IsolatedTrainingServiceImpl(this); + mBinder = new IsolatedTrainingServiceImpl(); } @Override diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/IsolatedTrainingServiceImpl.java b/federatedcompute/src/com/android/federatedcompute/services/training/IsolatedTrainingServiceImpl.java index 002b4b38..ebaf7ef0 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/training/IsolatedTrainingServiceImpl.java +++ b/federatedcompute/src/com/android/federatedcompute/services/training/IsolatedTrainingServiceImpl.java @@ -16,16 +16,18 @@ package com.android.federatedcompute.services.training; -import android.content.Context; +import android.annotation.NonNull; import android.federatedcompute.aidl.IExampleStoreIterator; -import android.federatedcompute.aidl.IResultHandlingService; +import android.federatedcompute.common.ClientConstants; +import android.federatedcompute.common.ExampleConsumption; import android.os.Bundle; import android.os.ParcelFileDescriptor; import android.os.RemoteException; -import android.util.Log; +import com.android.federatedcompute.internal.util.LogUtil; import com.android.federatedcompute.services.common.Constants; import com.android.federatedcompute.services.common.FederatedComputeExecutors; +import com.android.federatedcompute.services.common.FileUtils; import com.android.federatedcompute.services.examplestore.ExampleConsumptionRecorder; import com.android.federatedcompute.services.training.aidl.IIsolatedTrainingService; import com.android.federatedcompute.services.training.aidl.ITrainingResultCallback; @@ -41,23 +43,22 @@ import com.google.internal.federated.plan.ClientOnlyPlan; import com.google.internal.federated.plan.ExampleSelector; import com.google.protobuf.InvalidProtocolBufferException; +import java.util.ArrayList; import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; -import javax.annotation.Nonnull; - /** The implementation of {@link IsolatedTrainingService}. */ public class IsolatedTrainingServiceImpl extends IIsolatedTrainingService.Stub { - private static final String TAG = "IsolatedTrainingServiceImpl"; + private static final String TAG = IsolatedTrainingServiceImpl.class.getSimpleName(); private final AtomicBoolean mInterruptFlag = new AtomicBoolean(false); @VisibleForTesting ListenableSupplier<Boolean> mInterruptState = new ListenableSupplier<>(mInterruptFlag::get); - private ComputationRunner mComputationRunner; + private final ComputationRunner mComputationRunner; - public IsolatedTrainingServiceImpl(Context context) { - mComputationRunner = new ComputationRunner(context); + public IsolatedTrainingServiceImpl() { + mComputationRunner = new ComputationRunner(); } @VisibleForTesting @@ -66,20 +67,16 @@ public class IsolatedTrainingServiceImpl extends IIsolatedTrainingService.Stub { } @Override - public void runFlTraining(@Nonnull Bundle params, @Nonnull ITrainingResultCallback callback) { + public void runFlTraining(@NonNull Bundle params, @NonNull ITrainingResultCallback callback) { Objects.requireNonNull(params); Objects.requireNonNull(callback); + IExampleStoreIterator exampleStoreIteratorBinder = IExampleStoreIterator.Stub.asInterface( Objects.requireNonNull( params.getBinder(Constants.EXTRA_EXAMPLE_STORE_ITERATOR_BINDER))); Objects.requireNonNull(exampleStoreIteratorBinder); - IResultHandlingService resultHandlingServiceBinder = - IResultHandlingService.Stub.asInterface( - Objects.requireNonNull( - params.getBinder(Constants.EXTRA_RESULT_HANDLING_SERVICE_BINDER))); - Objects.requireNonNull(resultHandlingServiceBinder); - ExampleConsumptionRecorder recorder = new ExampleConsumptionRecorder(); + byte[] exampleSelectorBytes = Objects.requireNonNull(params.getByteArray(Constants.EXTRA_EXAMPLE_SELECTOR)); ExampleSelector exampleSelector; @@ -88,8 +85,11 @@ public class IsolatedTrainingServiceImpl extends IIsolatedTrainingService.Stub { } catch (InvalidProtocolBufferException e) { throw new IllegalArgumentException("ExampleSelector proto is invalid", e); } + ExampleConsumptionRecorder recorder = new ExampleConsumptionRecorder(); String populationName = - Objects.requireNonNull(params.getString(Constants.EXTRA_POPULATION_NAME)); + Objects.requireNonNull(params.getString(ClientConstants.EXTRA_POPULATION_NAME)); + String taskName = Objects.requireNonNull(params.getString(ClientConstants.EXTRA_TASK_NAME)); + ParcelFileDescriptor inputCheckpointFd = Objects.requireNonNull( params.getParcelable( @@ -98,9 +98,12 @@ public class IsolatedTrainingServiceImpl extends IIsolatedTrainingService.Stub { Objects.requireNonNull( params.getParcelable( Constants.EXTRA_OUTPUT_CHECKPOINT_FD, ParcelFileDescriptor.class)); - int jobId = params.getInt(Constants.EXTRA_JOB_ID); - byte[] clientPlanBytes = - Objects.requireNonNull(params.getByteArray(Constants.EXTRA_CLIENT_ONLY_PLAN)); + ParcelFileDescriptor clientPlanFd = + Objects.requireNonNull( + params.getParcelable( + Constants.EXTRA_CLIENT_ONLY_PLAN_FD, ParcelFileDescriptor.class)); + + byte[] clientPlanBytes = FileUtils.readFileDescriptorAsByteArray(clientPlanFd); ClientOnlyPlan clientPlan; try { clientPlan = ClientOnlyPlan.parseFrom(clientPlanBytes); @@ -112,7 +115,7 @@ public class IsolatedTrainingServiceImpl extends IIsolatedTrainingService.Stub { Futures.submit( () -> mComputationRunner.runTaskWithNativeRunner( - jobId, + taskName, populationName, getFileDescriptorForTensorflow(inputCheckpointFd), getFileDescriptorForTensorflow(outputCheckpointFd), @@ -120,7 +123,6 @@ public class IsolatedTrainingServiceImpl extends IIsolatedTrainingService.Stub { exampleSelector, recorder, exampleStoreIteratorBinder, - resultHandlingServiceBinder, mInterruptState), FederatedComputeExecutors.getBackgroundExecutor()); @@ -129,17 +131,32 @@ public class IsolatedTrainingServiceImpl extends IIsolatedTrainingService.Stub { new FutureCallback<FLRunnerResult>() { @Override public void onSuccess(FLRunnerResult result) { - sendResult(result, callback); + Bundle bundle = new Bundle(); + bundle.putByteArray(Constants.EXTRA_FL_RUNNER_RESULT, result.toByteArray()); + ArrayList<ExampleConsumption> exampleConsumptionArrayList = + recorder.finishRecordingAndGet(); + LogUtil.i( + TAG, + "training task %s: result %s, used %d examples", + populationName, + result.toString(), + exampleConsumptionArrayList.size()); + bundle.putParcelableArrayList( + ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, + exampleConsumptionArrayList); + sendResult(bundle, callback); } @Override public void onFailure(Throwable t) { - Log.e(TAG, "Failed to runTaskWithNativeRunner", t); + LogUtil.e(TAG, t, "Failed to runTaskWithNativeRunner"); + Bundle bundle = new Bundle(); FLRunnerResult result = FLRunnerResult.newBuilder() .setContributionResult(ContributionResult.FAIL) .build(); - sendResult(result, callback); + bundle.putByteArray(Constants.EXTRA_FL_RUNNER_RESULT, result.toByteArray()); + sendResult(bundle, callback); } }, FederatedComputeExecutors.getLightweightExecutor()); @@ -151,13 +168,11 @@ public class IsolatedTrainingServiceImpl extends IIsolatedTrainingService.Stub { return "fd:///" + parcelFileDescriptor.getFd(); } - private void sendResult(FLRunnerResult result, ITrainingResultCallback callback) { - Bundle bundle = new Bundle(); - bundle.putByteArray(Constants.EXTRA_FL_RUNNER_RESULT, result.toByteArray()); + private void sendResult(Bundle result, ITrainingResultCallback callback) { try { - callback.onResult(bundle); + callback.onResult(result); } catch (RemoteException e) { - Log.w(TAG + ": Callback failed ", e); + LogUtil.w(TAG, e, ": Callback failed "); } } diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/ResultCallbackHelper.java b/federatedcompute/src/com/android/federatedcompute/services/training/ResultCallbackHelper.java index f1477135..30ae9d90 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/training/ResultCallbackHelper.java +++ b/federatedcompute/src/com/android/federatedcompute/services/training/ResultCallbackHelper.java @@ -16,33 +16,35 @@ package com.android.federatedcompute.services.training; +import static android.federatedcompute.common.ClientConstants.RESULT_HANDLING_SERVICE_ACTION; import static android.federatedcompute.common.ClientConstants.STATUS_SUCCESS; +import static android.federatedcompute.common.ClientConstants.STATUS_TRAINING_FAILED; +import android.content.Context; import android.federatedcompute.aidl.IFederatedComputeCallback; import android.federatedcompute.aidl.IResultHandlingService; -import android.federatedcompute.common.ExampleConsumption; -import android.federatedcompute.common.TrainingInterval; -import android.federatedcompute.common.TrainingOptions; -import android.os.RemoteException; -import android.util.Log; +import android.federatedcompute.common.ClientConstants; +import android.os.Bundle; -import com.android.federatedcompute.services.data.fbs.TrainingIntervalOptions; +import com.android.federatedcompute.internal.util.AbstractServiceBinder; +import com.android.federatedcompute.internal.util.LogUtil; +import com.android.federatedcompute.services.data.FederatedTrainingTask; +import com.android.federatedcompute.services.training.util.ComputationResult; import com.google.common.annotations.VisibleForTesting; -import com.google.common.util.concurrent.SettableFuture; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; -import java.nio.ByteBuffer; -import java.util.List; -import java.util.concurrent.ExecutionException; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; /** * A helper class for binding to client implemented ResultHandlingService and trigger handleResult. */ public class ResultCallbackHelper { - private static final String TAG = "ResultCallbackHelper"; - private static final long RESULT_HANDLING_SERVICE_CALLBACK_TIMEOUT_SECS = 60 * 9 + 45; + private static final String TAG = ResultCallbackHelper.class.getSimpleName(); + private static final long RESULT_HANDLING_SERVICE_CALLBACK_TIMEOUT_SECS = 10; /** The outcome of the result handling. */ public enum CallbackResult { @@ -54,105 +56,85 @@ public class ResultCallbackHelper { NEEDS_RESUME, } - private final List<ExampleConsumption> mExampleConsumptions; - private final IResultHandlingService mResultHandlingService; - private final long mResultHandlingServiceCallbackTimeoutSecs; - - public ResultCallbackHelper( - List<ExampleConsumption> exampleConsumptions, - IResultHandlingService resultHandlingService) { - this.mExampleConsumptions = exampleConsumptions; - this.mResultHandlingService = resultHandlingService; - this.mResultHandlingServiceCallbackTimeoutSecs = - RESULT_HANDLING_SERVICE_CALLBACK_TIMEOUT_SECS; - } + private final Context mContext; + private AbstractServiceBinder<IResultHandlingService> mResultHandlingServiceBinder; - @VisibleForTesting - ResultCallbackHelper( - List<ExampleConsumption> exampleConsumptions, - IResultHandlingService resultHandlingService, - long resultHandlingServiceCallbackTimeoutSecs) { - this.mExampleConsumptions = exampleConsumptions; - this.mResultHandlingService = resultHandlingService; - this.mResultHandlingServiceCallbackTimeoutSecs = resultHandlingServiceCallbackTimeoutSecs; + public ResultCallbackHelper(Context context) { + this.mContext = context.getApplicationContext(); } - public CallbackResult callHandleResult( - int jobId, String populationName, byte[] intervalOptions, boolean success) { - - SettableFuture<Integer> errorCodeFuture = SettableFuture.create(); - IFederatedComputeCallback callback = - new IFederatedComputeCallback.Stub() { - @Override - public void onSuccess() { - errorCodeFuture.set(STATUS_SUCCESS); - } - - @Override - public void onFailure(int errorCode) { - errorCodeFuture.set(errorCode); - } - }; + /** + * Publishes the training result and example list to client implemented ResultHandlingService. + */ + public ListenableFuture<CallbackResult> callHandleResult( + String taskName, FederatedTrainingTask task, ComputationResult result) { + Bundle input = new Bundle(); + input.putString(ClientConstants.EXTRA_POPULATION_NAME, task.populationName()); + input.putString(ClientConstants.EXTRA_TASK_NAME, taskName); + input.putByteArray(ClientConstants.EXTRA_CONTEXT_DATA, task.contextData()); + input.putInt( + ClientConstants.EXTRA_COMPUTATION_RESULT, + result.isResultSuccess() ? STATUS_SUCCESS : STATUS_TRAINING_FAILED); + input.putParcelableArrayList( + ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, result.getExampleConsumptionList()); + try { - mResultHandlingService.handleResult( - buildTrainingOptions(jobId, populationName, intervalOptions), - success, - mExampleConsumptions, - callback); + IResultHandlingService resultHandlingService = + getResultHandlingService(task.appPackageName()); + if (resultHandlingService == null) { + LogUtil.e( + TAG, + "ResultHandlingService binding died. population name: " + + task.populationName()); + return Futures.immediateFuture(CallbackResult.FAIL); + } + + BlockingQueue<Integer> asyncResult = new ArrayBlockingQueue<>(1); + resultHandlingService.handleResult( + input, + new IFederatedComputeCallback.Stub() { + @Override + public void onSuccess() { + asyncResult.add(STATUS_SUCCESS); + } + + @Override + public void onFailure(int errorCode) { + asyncResult.add(errorCode); + } + }); int statusCode = - errorCodeFuture.get( - mResultHandlingServiceCallbackTimeoutSecs, TimeUnit.SECONDS); - return statusCode == STATUS_SUCCESS ? CallbackResult.SUCCESS : CallbackResult.FAIL; - } catch (RemoteException e) { - Log.e( - TAG, - String.format( - "ResultHandlingService binding died. population name: %s", - populationName), - e); - return CallbackResult.FAIL; - } catch (InterruptedException interruptedException) { - Log.e( - TAG, - String.format( - "ResultHandlingService callback interrupted. population name: %s", - populationName), - interruptedException); - return CallbackResult.FAIL; - } catch (ExecutionException e) { - Log.e( + asyncResult.poll( + RESULT_HANDLING_SERVICE_CALLBACK_TIMEOUT_SECS, TimeUnit.SECONDS); + CallbackResult callbackResult = + statusCode == STATUS_SUCCESS ? CallbackResult.SUCCESS : CallbackResult.FAIL; + return Futures.immediateFuture(callbackResult); + } catch (Exception e) { + LogUtil.e( TAG, - String.format( - "ResultHandlingService callback failed. population name: %s", - populationName), - e); - return CallbackResult.FAIL; - } catch (TimeoutException e) { - Log.e( - TAG, - String.format( - "ResultHandlingService callback timed out %d population name: %s", - mResultHandlingServiceCallbackTimeoutSecs, populationName), - e); + e, + "ResultHandlingService binding died. population name: %s", + task.populationName()); + // We publish result to client app with best effort and should not crash flow. + return Futures.immediateFuture(CallbackResult.FAIL); + } finally { + unbindFromResultHandlingService(); } - return CallbackResult.FAIL; } - private TrainingOptions buildTrainingOptions( - int jobId, String populationName, byte[] intervalBytes) { - TrainingIntervalOptions intervalOptions = - TrainingIntervalOptions.getRootAsTrainingIntervalOptions( - ByteBuffer.wrap(intervalBytes)); - TrainingOptions.Builder trainingOptionsBuilder = new TrainingOptions.Builder(); - trainingOptionsBuilder.setPopulationName(populationName); - if (intervalOptions != null) { - TrainingInterval interval = - new TrainingInterval.Builder() - .setSchedulingMode(intervalOptions.schedulingMode()) - .setMinimumIntervalMillis(intervalOptions.minIntervalMillis()) - .build(); - trainingOptionsBuilder.setTrainingInterval(interval); - } - return trainingOptionsBuilder.build(); + @VisibleForTesting + IResultHandlingService getResultHandlingService(String appPackageName) { + mResultHandlingServiceBinder = + AbstractServiceBinder.getServiceBinderByIntent( + this.mContext, + RESULT_HANDLING_SERVICE_ACTION, + appPackageName, + IResultHandlingService.Stub::asInterface); + return mResultHandlingServiceBinder.getService(Runnable::run); + } + + @VisibleForTesting + void unbindFromResultHandlingService() { + mResultHandlingServiceBinder.unbindFromService(); } } diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/ResultHandlingServiceProvider.java b/federatedcompute/src/com/android/federatedcompute/services/training/ResultHandlingServiceProvider.java deleted file mode 100644 index 9a128fa7..00000000 --- a/federatedcompute/src/com/android/federatedcompute/services/training/ResultHandlingServiceProvider.java +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright (C) 2023 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.android.federatedcompute.services.training; - -import android.annotation.Nullable; -import android.content.Intent; -import android.federatedcompute.aidl.IResultHandlingService; - -/** Interface used to provide a reference to the IResultHandlingService. */ -public interface ResultHandlingServiceProvider { - - /** Returns the connected ResultHandlingService, or otherwise {@code null}. */ - @Nullable - IResultHandlingService getResultHandlingService(); - - /** Bind to and establish a connection with client implemented ResultHandlingService. */ - boolean bindService(Intent intent); - - /** Unbind from the client implemented ResultHandlingService. */ - void unbindService(); -} diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/ResultHandlingServiceProviderImpl.java b/federatedcompute/src/com/android/federatedcompute/services/training/ResultHandlingServiceProviderImpl.java deleted file mode 100644 index 392591b8..00000000 --- a/federatedcompute/src/com/android/federatedcompute/services/training/ResultHandlingServiceProviderImpl.java +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Copyright (C) 2023 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.android.federatedcompute.services.training; - -import android.annotation.Nullable; -import android.content.ComponentName; -import android.content.Context; -import android.content.Intent; -import android.content.ServiceConnection; -import android.federatedcompute.aidl.IResultHandlingService; -import android.os.IBinder; -import android.util.Log; - -import com.android.federatedcompute.services.common.Flags; - -import com.google.common.util.concurrent.SettableFuture; -import com.google.common.util.concurrent.UncheckedExecutionException; - -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; - -/** Implementation of the ResultHandlingServiceProvider interface. */ -public final class ResultHandlingServiceProviderImpl implements ResultHandlingServiceProvider { - private static final String TAG = "ResultHandlingServiceProviderImpl"; - private static final Executor SINGLE_THREAD_EXECUTOR = Executors.newSingleThreadExecutor(); - private final Context mContext; - private IResultHandlingService mResultHandlingService; - private SettableFuture<IResultHandlingService> mResultHandlingServiceFuture = - SettableFuture.create(); - private boolean mBound; - private Flags mFlags; - - public ResultHandlingServiceProviderImpl(Context context, Flags flags) { - this.mContext = context; - this.mFlags = flags; - } - - ServiceConnection mServiceConnection = - new ServiceConnection() { - @Override - public void onServiceConnected(ComponentName name, IBinder service) { - if (service == null) { - Log.e(TAG, "onServiceConnected() received null binder"); - return; - } - mResultHandlingServiceFuture.set( - IResultHandlingService.Stub.asInterface(service)); - mBound = true; - } - - @Override - public void onServiceDisconnected(ComponentName name) { - reset(); - Log.d(TAG, "Connection unexpectedly disconnected"); - } - }; - - @Override - @Nullable - public IResultHandlingService getResultHandlingService() { - return mResultHandlingService; - } - - @Override - public boolean bindService(Intent intent) { - if (!mContext.bindService( - intent, Context.BIND_AUTO_CREATE, SINGLE_THREAD_EXECUTOR, mServiceConnection)) { - Log.e(TAG, "Unable to bind to ResultHandlingService intent: " + intent); - return false; - } - try { - mResultHandlingService = - mResultHandlingServiceFuture.get( - mFlags.getResultHandlingBindServiceTimeoutSecs(), TimeUnit.SECONDS); - } catch (TimeoutException e) { - throw new IllegalStateException( - String.format( - "Service connection time out (%ss) for app hosted" - + " ResultHandlingService.", - mFlags.getResultHandlingBindServiceTimeoutSecs()), - e); - } catch (ExecutionException e) { - throw new UncheckedExecutionException(e); - } catch (InterruptedException e) { - Log.e(TAG, "ResultHandlingService interrupted", e); - unbindService(); - return false; - } - return true; - } - - @Override - public void unbindService() { - if (mBound) { - mContext.unbindService(mServiceConnection); - reset(); - } - } - - private void reset() { - mResultHandlingServiceFuture = SettableFuture.create(); - mResultHandlingService = null; - mBound = false; - } -} diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/jni/FlRunnerWrapper.java b/federatedcompute/src/com/android/federatedcompute/services/training/jni/FlRunnerWrapper.java new file mode 100644 index 00000000..33ec656f --- /dev/null +++ b/federatedcompute/src/com/android/federatedcompute/services/training/jni/FlRunnerWrapper.java @@ -0,0 +1,96 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.services.training.jni; + +import com.android.federatedcompute.internal.util.LogUtil; +import com.android.federatedcompute.services.examplestore.ExampleIterator; +import com.android.federatedcompute.services.training.util.ListenableSupplier; + +import com.google.intelligence.fcp.client.FLRunnerResult; +import com.google.internal.federated.plan.ClientOnlyPlan; +import com.google.protobuf.InvalidProtocolBufferException; + +import java.io.Closeable; + +import javax.annotation.Nullable; + +/** Runs a federated computation using C++ code. */ +public final class FlRunnerWrapper implements Closeable { + private static final String TAG = FlRunnerWrapper.class.getSimpleName(); + + static { + System.loadLibrary("fcp_cpp_dep_jni"); + } + + private final LogManager mLogManager; + + private final ListenableSupplier<Boolean> mInterruptionFlag; + private final ExampleIterator mExampleIterator; + + public FlRunnerWrapper( + ListenableSupplier<Boolean> interruptionFlag, + String populationName, + ExampleIterator exampleIterator) { + this.mLogManager = new LogManagerImpl(populationName); + this.mInterruptionFlag = interruptionFlag; + this.mExampleIterator = exampleIterator; + } + + /** Starts run a federated computation job. */ + public FLRunnerResult run( + String taskName, + String populationName, + ClientOnlyPlan clientOnlyPlan, + String checkpointInputFileName, + String checkpointOutputFileName) { + SimpleTaskEnvironmentImpl simpleTaskEnv = + new SimpleTaskEnvironmentImpl(mInterruptionFlag, mExampleIterator); + byte[] flRunnerResultSerialized = + runNativeFederatedComputation( + simpleTaskEnv, + populationName, + // Session name is optional and mainly used by legacy customers + "", + taskName, + mLogManager, + clientOnlyPlan.toByteArray(), + checkpointInputFileName, + checkpointOutputFileName); + try { + return FLRunnerResult.parseFrom(flRunnerResultSerialized); + } catch (InvalidProtocolBufferException e) { + // Promote to a RuntimeException, this should never happen and if it does, we shouldn't + // recover from it. + LogUtil.e(TAG, "Cannot parse FLRunnerResult", e); + throw new IllegalArgumentException(e); + } + } + + @Override + public void close() {} + + @Nullable + static native byte[] runNativeFederatedComputation( + SimpleTaskEnvironment simpleTaskEnvironment, + String populationName, + String sessionName, + String taskName, + LogManager logManager, + byte[] clientOnlyPlan, + String checkpointInputFileName, + String checkpointOutputFileName); +} diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/jni/JavaExampleIterator.java b/federatedcompute/src/com/android/federatedcompute/services/training/jni/JavaExampleIterator.java new file mode 100644 index 00000000..b63f25ba --- /dev/null +++ b/federatedcompute/src/com/android/federatedcompute/services/training/jni/JavaExampleIterator.java @@ -0,0 +1,36 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.services.training.jni; + +import java.io.Closeable; + +/** This class offers a task environment to C++ code. */ +public interface JavaExampleIterator extends Closeable { + /** + * Returns the next training example which is typically a serialized tf.Example, but can really + * be any sort of serialized data that the TensorFlow graph can process in batches of DT_STRING + * tensors. + */ + byte[] next() throws Exception; + + /** + * Called by C++ when the iterator is no longer used and can be closed. May be called multiple + * times, by C++ and Java. + */ + @Override + void close(); +} diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/jni/JavaExampleStore.java b/federatedcompute/src/com/android/federatedcompute/services/training/jni/JavaExampleStore.java new file mode 100644 index 00000000..200bded2 --- /dev/null +++ b/federatedcompute/src/com/android/federatedcompute/services/training/jni/JavaExampleStore.java @@ -0,0 +1,130 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.services.training.jni; + +import com.android.federatedcompute.internal.util.LogUtil; +import com.android.federatedcompute.services.common.ErrorStatusException; +import com.android.federatedcompute.services.examplestore.ExampleIterator; + +import com.google.intelligence.fcp.client.SelectorContext; +import com.google.internal.federated.plan.ExampleSelector; +import com.google.protobuf.InvalidProtocolBufferException; + +import java.io.Closeable; +import java.util.ArrayList; +import java.util.List; + +import javax.annotation.concurrent.GuardedBy; + +/** + * Java class that forms the connective tissue between C++ code and Java example stores. C++ code + * (e.g. tflite plan engine, or Dataset op) uses this class to create example iterators and retrieve + * examples, and the class takes care of closing leftover iterators as well as storing unexpected + * exceptions for later re-throwing. + */ +public class JavaExampleStore implements Closeable { + private static final String TAG = JavaExampleStore.class.getSimpleName(); + + private final ExampleIterator mExampleIterator; + private final Object mIteratorLock = new Object(); + + @GuardedBy("mIteratorLock") + private final List<ExampleIterator> mServedExampleIterators; + + public JavaExampleStore(ExampleIterator exampleIterator) { + this.mExampleIterator = exampleIterator; + this.mServedExampleIterators = new ArrayList<>(); + } + + /** Creates an ExampleIterator based on provided contexts. */ + public JavaExampleIterator createExampleIteratorWithContext( + byte[] exampleSelector, byte[] selectorContext) { + ExampleSelector selector; + // 1. Deserialize the ExampleSelector. The ExampleIterator is already validated + // and created ahead. We only do validation here but not crash if it goes wrong. + try { + selector = ExampleSelector.parseFrom(exampleSelector); + SelectorContext.parseFrom(selectorContext); + } catch (InvalidProtocolBufferException e) { + LogUtil.e(TAG, "Invalid protobuf message", e); + } + + // 2. Stores ExampleIterator in a list so we can close it after training. + synchronized (mIteratorLock) { + mServedExampleIterators.add(this.mExampleIterator); + } + // 3. Wrap the ExampleIterator in a {@link JavaExampleIterator} that translates + // exceptions. + JavaExampleIterator javaExampleIterator = + new JavaExampleIterator() { + @GuardedBy("mIteratorLock") + final com.android.federatedcompute.services.examplestore.ExampleIterator + mIterator = mExampleIterator; + + @Override + public byte[] next() throws InterruptedException, ErrorStatusException { + synchronized (mIteratorLock) { + try { + boolean hasNext = mIterator.hasNext(); + if (!hasNext) { + return new byte[0]; + } + return mIterator.next(); + } catch (InterruptedException e) { + LogUtil.e(TAG, "ExampleStore.next()", e); + throw e; + } catch (ErrorStatusException e) { + LogUtil.e(TAG, "ExampleStore.next()", e); + throw e; + } + } + } + + @Override + public void close() { + + // Avoid closing an iterator twice. We do this by + // keeping open iterators in {@link + // mServedExampleIterators} and removing them when + // closing. If the list does not contain the + // iterator anymore, it has already been closed, and + // we avoid closing it again. + boolean iteratorOpen; + synchronized (mIteratorLock) { + iteratorOpen = mServedExampleIterators.remove(mExampleIterator); + if (iteratorOpen) { + mIterator.close(); + } + } + } + }; + return javaExampleIterator; + } + + @Override + public void close() { + // Close remaining open iterators, if any. This can happen when C++ code fails + // and returns via an error path that does not close the iterators. + synchronized (mIteratorLock) { + for (ExampleIterator exampleIterator : mServedExampleIterators) { + // TODO(b/283309324): add metrics to track iterator left open case. + LogUtil.e(TAG, "Close left open iterator"); + exampleIterator.close(); + } + } + } +} diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/jni/LogManager.java b/federatedcompute/src/com/android/federatedcompute/services/training/jni/LogManager.java new file mode 100644 index 00000000..2fd1483b --- /dev/null +++ b/federatedcompute/src/com/android/federatedcompute/services/training/jni/LogManager.java @@ -0,0 +1,60 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.services.training.jni; + +import javax.annotation.Nullable; + +/** + * This class offers logging functionality to c++ code. It will be used by + * //external/federatedcompute/fcp/client/log_manager.h. + */ +public interface LogManager { + + /** + * Called to log a {@link com.google.android.libraries.micore.learning.proto.ProdDiagCode}. + * + * @param prodDiagCode a serialized ProdDiagCode. + */ + void logProdDiag(int prodDiagCode); + + /** + * Called to log a {@link com.google.android.libraries.micore.learning.proto.DebugDiagCode}. + * + * @param debugDiagCode a serialized DebugDiagCode. + */ + void logDebugDiag(int debugDiagCode); + + /** + * Like {@link #logToLongHistogram(int, int, int, int, String, long)}, but without attaching a + * model identifier. + */ + void logToLongHistogram( + int histogramCounterCode, + int executionIndex, + int epochIndex, + int dataSourceType, + long value); + + /** Called to log a long value, by adding it to an annotated histogram. */ + void logToLongHistogram( + int histogramCounterCode, + int executionIndex, + int epochIndex, + int dataSourceType, + @Nullable String modelIdentifier, + long value); +} diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/jni/LogManagerImpl.java b/federatedcompute/src/com/android/federatedcompute/services/training/jni/LogManagerImpl.java new file mode 100644 index 00000000..3707694f --- /dev/null +++ b/federatedcompute/src/com/android/federatedcompute/services/training/jni/LogManagerImpl.java @@ -0,0 +1,88 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.services.training.jni; + +import com.android.federatedcompute.internal.util.LogUtil; + +import com.google.common.base.Preconditions; +import com.google.intelligence.fcp.client.DebugDiagCode; +import com.google.intelligence.fcp.client.HistogramCounters; +import com.google.intelligence.fcp.client.ProdDiagCode; + +import javax.annotation.Nullable; + +/** + * An implementation of the NativeLogManager interface based on {@link LogManager}, used by C++ + * code. + */ +public class LogManagerImpl implements LogManager { + private static final String TAG = LogManagerImpl.class.getSimpleName(); + private final String mClientPackageName; + + public LogManagerImpl(String clientPackageName) { + this.mClientPackageName = clientPackageName; + } + + @Override + public void logProdDiag(int prodDiagCode) { + // The diag code comes from a C++ engine which could theoretically supply + // invalid codes. + ProdDiagCode diagCode = ProdDiagCode.forNumber(prodDiagCode); + Preconditions.checkNotNull(diagCode); + LogUtil.i(TAG, "Send FL diagnosis log %s for package %s", diagCode, mClientPackageName); + } + + @Override + public void logDebugDiag(int debugDiagCode) { + DebugDiagCode diagCode = DebugDiagCode.forNumber(debugDiagCode); + Preconditions.checkNotNull(diagCode); + LogUtil.i(TAG, "Send FL diagnosis log %s for package %s", diagCode, mClientPackageName); + } + + @Override + public void logToLongHistogram( + int histogramCounterCode, + int executionIndex, + int epochIndex, + int dataSourceTypeCode, + long value) { + logToLongHistogram( + histogramCounterCode, executionIndex, epochIndex, dataSourceTypeCode, null, value); + } + + @Override + public void logToLongHistogram( + int histogramCounterCode, + int executionIndex, + int epochIndex, + int dataSourceTypeCode, + @Nullable String modelIdentifier, + long value) { + HistogramCounters histogramCounter = HistogramCounters.forNumber(histogramCounterCode); + Preconditions.checkNotNull(histogramCounter); + LogUtil.i( + TAG, + "Calling logToLongHistogram %d %d %d %d %d", + histogramCounterCode, + executionIndex, + epochIndex, + dataSourceTypeCode, + value); + // TODO: implement histogram counter logic in LogManager. + + } +} diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/jni/SimpleTaskEnvironment.java b/federatedcompute/src/com/android/federatedcompute/services/training/jni/SimpleTaskEnvironment.java new file mode 100644 index 00000000..5abd84cc --- /dev/null +++ b/federatedcompute/src/com/android/federatedcompute/services/training/jni/SimpleTaskEnvironment.java @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.services.training.jni; + +/** + * This class provides callbacks for the C++ federated learning logic, "FL runner". + * + * <p>This interface is used by "fl_runner_jni.cc". + */ +public interface SimpleTaskEnvironment { + /** + * Returns the path of the directory that the C++ implementation should use to store persistent + * files. If files created by the C++ runtime in this directory are deleted, it may not function + * properly. + */ + String getBaseDir(); + + /** + * Returns the path of the directory that the C++ implementation should use to store temporary + * files. If files created by the C++ runtime in this directory are deleted, it will still + * function properly. + */ + String getCacheDir(); + + /** + * Checks whether the device conditions - e.g. Network, Battery, Idleness - allow for running a + * federated computation. + */ + boolean trainingConditionsSatisfied(); + + /** Returns an {@link JavaExampleIterator} object. */ + JavaExampleIterator createExampleIterator(byte[] exampleSelector); + + /** Returns an {@link JavaExampleIterator} object. */ + JavaExampleIterator createExampleIteratorWithContext( + byte[] exampleSelector, byte[] selectorContext); +} diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/jni/SimpleTaskEnvironmentImpl.java b/federatedcompute/src/com/android/federatedcompute/services/training/jni/SimpleTaskEnvironmentImpl.java new file mode 100644 index 00000000..794c36b1 --- /dev/null +++ b/federatedcompute/src/com/android/federatedcompute/services/training/jni/SimpleTaskEnvironmentImpl.java @@ -0,0 +1,97 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.services.training.jni; + +import com.android.federatedcompute.internal.util.LogUtil; +import com.android.federatedcompute.services.examplestore.ExampleIterator; +import com.android.federatedcompute.services.training.util.ListenableSupplier; + +import com.google.intelligence.fcp.client.SelectorContext; + +import java.io.Closeable; + +/** Default implementation for {@link SimpleTaskEnvironment} */ +public class SimpleTaskEnvironmentImpl implements SimpleTaskEnvironment, Closeable { + private static final String TAG = SimpleTaskEnvironmentImpl.class.getSimpleName(); + private final ListenableSupplier<Boolean> mInterruptionFlag; + private final JavaExampleStore mJavaExampleStore; + private final Object mLock = new Object(); + + public SimpleTaskEnvironmentImpl( + ListenableSupplier<Boolean> interruptionFlag, ExampleIterator exampleIterator) { + this.mInterruptionFlag = interruptionFlag; + this.mJavaExampleStore = new JavaExampleStore(exampleIterator); + } + + /** + * Isolated training process should not have file system access, so we throw exception here on + * purpose. It should be called because we directly return at JNI layer. + */ + @Override + public String getBaseDir() { + throw new UnsupportedOperationException("getBaseDir is not supported yet."); + } + + /** + * Isolated training process should not have file system access, so we throw exception here on + * purpose. It should be called because we directly return at JNI layer. + */ + @Override + public String getCacheDir() { + throw new UnsupportedOperationException("getCacheDir is not supported yet."); + } + + /** + * We don't check real training conditions here because the isolated training process can't + * access system service e.g. PowerManager, battery intent. We only check all the conditions + * before the training starts. + */ + @Override + public boolean trainingConditionsSatisfied() { + // Check for external termination. + if (Thread.interrupted()) { + return false; + } + + if (mInterruptionFlag.get()) { + LogUtil.i( + TAG, "Interrupting training due to custom interruption flag set to" + " true"); + return false; + } + return true; + } + + @Override + public JavaExampleIterator createExampleIterator(byte[] exampleSelector) { + return createExampleIteratorWithContext( + exampleSelector, SelectorContext.getDefaultInstance().toByteArray()); + } + + @Override + public JavaExampleIterator createExampleIteratorWithContext( + byte[] exampleSelector, byte[] selectorContext) { + synchronized (mLock) { + return mJavaExampleStore.createExampleIteratorWithContext( + exampleSelector, selectorContext); + } + } + + @Override + public void close() { + mJavaExampleStore.close(); + } +} diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/util/ComputationResult.java b/federatedcompute/src/com/android/federatedcompute/services/training/util/ComputationResult.java new file mode 100644 index 00000000..89d98472 --- /dev/null +++ b/federatedcompute/src/com/android/federatedcompute/services/training/util/ComputationResult.java @@ -0,0 +1,56 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.services.training.util; + +import android.federatedcompute.common.ExampleConsumption; + +import com.google.intelligence.fcp.client.FLRunnerResult; +import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult; + +import java.util.ArrayList; + +/** The result of federated computation. */ +public class ComputationResult { + private String mOutputCheckpointFile = ""; + private FLRunnerResult mFlRunnerResult = null; + private ArrayList<ExampleConsumption> mExampleConsumptionList = null; + + public ComputationResult( + String outputCheckpointFile, + FLRunnerResult flRunnerResult, + ArrayList<ExampleConsumption> exampleConsumptionList) { + this.mOutputCheckpointFile = outputCheckpointFile; + this.mFlRunnerResult = flRunnerResult; + this.mExampleConsumptionList = exampleConsumptionList; + } + + public ArrayList<ExampleConsumption> getExampleConsumptionList() { + return mExampleConsumptionList; + } + + public String getOutputCheckpointFile() { + return mOutputCheckpointFile; + } + + public FLRunnerResult getFlRunnerResult() { + return mFlRunnerResult; + } + + public boolean isResultSuccess() { + return mFlRunnerResult.getContributionResult() == ContributionResult.SUCCESS; + } +} diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/util/TrainingConditionsChecker.java b/federatedcompute/src/com/android/federatedcompute/services/training/util/TrainingConditionsChecker.java index 25845492..caf45ecb 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/training/util/TrainingConditionsChecker.java +++ b/federatedcompute/src/com/android/federatedcompute/services/training/util/TrainingConditionsChecker.java @@ -19,8 +19,8 @@ package com.android.federatedcompute.services.training.util; import android.annotation.NonNull; import android.content.Context; import android.os.PowerManager; -import android.util.Log; +import com.android.federatedcompute.internal.util.LogUtil; import com.android.federatedcompute.services.common.BatteryInfo; import com.android.federatedcompute.services.common.Clock; import com.android.federatedcompute.services.common.Flags; @@ -37,7 +37,7 @@ import java.util.concurrent.atomic.AtomicLong; * Utilities for checking that the device is currently in an acceptable state for training models. */ public class TrainingConditionsChecker { - private static final String TAG = "TrainingCondChecker"; + private static final String TAG = TrainingConditionsChecker.class.getSimpleName(); private final BatteryInfo mBatteryInfo; private final PowerManager mPowerManager; private final Clock mClock; @@ -45,7 +45,7 @@ public class TrainingConditionsChecker { private final long mThrottlePeriodMillis; private final AtomicLong mLastConditionCheckTimeMillis; - private static TrainingConditionsChecker sSingletonInstance; + private static volatile TrainingConditionsChecker sSingletonInstance; /** * Result of the training condition check. We rely on JobScheduler for unmetered network and @@ -71,25 +71,27 @@ public class TrainingConditionsChecker { /** Gets an instance of {@link TrainingConditionsChecker}. */ public static TrainingConditionsChecker getInstance(Context context) { - synchronized (TrainingConditionsChecker.class) { - if (sSingletonInstance == null) { - Flags flags = PhFlags.getInstance(); - sSingletonInstance = - new TrainingConditionsChecker( - new BatteryInfo(context, flags), - context.getSystemService(PowerManager.class), - flags, - MonotonicClock.getInstance()); + if (sSingletonInstance == null) { + synchronized (TrainingConditionsChecker.class) { + if (sSingletonInstance == null) { + Flags flags = PhFlags.getInstance(); + sSingletonInstance = + new TrainingConditionsChecker( + new BatteryInfo(context.getApplicationContext(), flags), + context.getSystemService(PowerManager.class), + flags, + MonotonicClock.getInstance()); + } } - return sSingletonInstance; } + return sSingletonInstance; } private boolean deviceThermalsOkForTraining() { if (mPowerManager == null) { // If the device does not expose a PowerManager service, then we can't determine // idleness, and then we err on the side of caution and return false. - Log.w(TAG, "PowerManager is not available when do background training"); + LogUtil.w(TAG, "PowerManager is not available when do background training"); return false; } return mPowerManager.getCurrentThermalStatus() < mFlags.getThermalStatusToThrottle(); @@ -105,11 +107,11 @@ public class TrainingConditionsChecker { // Throttling is enabled. if (mThrottlePeriodMillis > 0) { if (nowMillis - mLastConditionCheckTimeMillis.get() < mThrottlePeriodMillis) { - Log.i( + LogUtil.i( TAG, - String.format( - "training condition check is throttled %d %d", - nowMillis, mLastConditionCheckTimeMillis.get())); + "training condition check is throttled %d %d", + nowMillis, + mLastConditionCheckTimeMillis.get()); return EnumSet.noneOf(Condition.class); } else { mLastConditionCheckTimeMillis.set(nowMillis); diff --git a/framework/Android.bp b/framework/Android.bp index 73d17773..c927a944 100644 --- a/framework/Android.bp +++ b/framework/Android.bp @@ -51,6 +51,7 @@ java_sdk_library { ], libs: [ "modules-utils-preconditions", + "framework-connectivity.stubs.module_lib", ], sdk_version: "module_current", defaults: [ @@ -62,6 +63,7 @@ java_sdk_library { "android.ondevicepersonalization", "com.android.ondevicepersonalization.internal", "android.federatedcompute", + "com.android.federatedcompute.internal", ], jarjar_rules: ":framework-ondevicepersonalization-jarjar", static_libs: [ diff --git a/framework/api/current.txt b/framework/api/current.txt index d802177e..e07681c6 100644 --- a/framework/api/current.txt +++ b/framework/api/current.txt @@ -1 +1,183 @@ // Signature format: 2.0 +package android.adservices.ondevicepersonalization { + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public final class AppInfo implements android.os.Parcelable { + method public int describeContents(); + method @NonNull public boolean isInstalled(); + method public void writeToParcel(@NonNull android.os.Parcel, int); + field @NonNull public static final android.os.Parcelable.Creator<android.adservices.ondevicepersonalization.AppInfo> CREATOR; + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public final class DownloadCompletedInput { + method @NonNull public java.util.Map<java.lang.String,byte[]> getData(); + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public final class DownloadCompletedOutput implements android.os.Parcelable { + method public int describeContents(); + method @NonNull public java.util.List<java.lang.String> getRetainedKeys(); + method public void writeToParcel(@NonNull android.os.Parcel, int); + field @NonNull public static final android.os.Parcelable.Creator<android.adservices.ondevicepersonalization.DownloadCompletedOutput> CREATOR; + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public static final class DownloadCompletedOutput.Builder { + ctor public DownloadCompletedOutput.Builder(); + method @NonNull public android.adservices.ondevicepersonalization.DownloadCompletedOutput.Builder addRetainedKey(@NonNull String); + method @NonNull public android.adservices.ondevicepersonalization.DownloadCompletedOutput build(); + method @NonNull public android.adservices.ondevicepersonalization.DownloadCompletedOutput.Builder setRetainedKeys(@NonNull java.util.List<java.lang.String>); + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public final class EventLogRecord implements android.os.Parcelable { + method public int describeContents(); + method @Nullable public android.content.ContentValues getData(); + method @IntRange(from=0) public int getRowIndex(); + method @IntRange(from=1, to=127) public int getType(); + method public void writeToParcel(@NonNull android.os.Parcel, int); + field @NonNull public static final android.os.Parcelable.Creator<android.adservices.ondevicepersonalization.EventLogRecord> CREATOR; + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public static final class EventLogRecord.Builder { + ctor public EventLogRecord.Builder(); + method @NonNull public android.adservices.ondevicepersonalization.EventLogRecord build(); + method @NonNull public android.adservices.ondevicepersonalization.EventLogRecord.Builder setData(@NonNull android.content.ContentValues); + method @NonNull public android.adservices.ondevicepersonalization.EventLogRecord.Builder setRowIndex(@IntRange(from=0) int); + method @NonNull public android.adservices.ondevicepersonalization.EventLogRecord.Builder setType(@IntRange(from=1, to=127) int); + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public class EventUrlProvider { + method @NonNull @WorkerThread public android.net.Uri createEventTrackingUrlWithRedirect(@NonNull android.os.PersistableBundle, @Nullable android.net.Uri); + method @NonNull @WorkerThread public android.net.Uri createEventTrackingUrlWithResponse(@NonNull android.os.PersistableBundle, @Nullable byte[], @Nullable String); + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public final class ExecuteInput implements android.os.Parcelable { + method public int describeContents(); + method @NonNull public String getAppPackageName(); + method @NonNull public android.os.PersistableBundle getAppParams(); + method public void writeToParcel(@NonNull android.os.Parcel, int); + field @NonNull public static final android.os.Parcelable.Creator<android.adservices.ondevicepersonalization.ExecuteInput> CREATOR; + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public final class ExecuteOutput implements android.os.Parcelable { + method public int describeContents(); + method @NonNull public java.util.List<android.adservices.ondevicepersonalization.RenderingConfig> getRenderingConfigs(); + method @Nullable public android.adservices.ondevicepersonalization.RequestLogRecord getRequestLogRecord(); + method public void writeToParcel(@NonNull android.os.Parcel, int); + field @NonNull public static final android.os.Parcelable.Creator<android.adservices.ondevicepersonalization.ExecuteOutput> CREATOR; + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public static final class ExecuteOutput.Builder { + ctor public ExecuteOutput.Builder(); + method @NonNull public android.adservices.ondevicepersonalization.ExecuteOutput.Builder addRenderingConfig(@NonNull android.adservices.ondevicepersonalization.RenderingConfig); + method @NonNull public android.adservices.ondevicepersonalization.ExecuteOutput build(); + method @NonNull public android.adservices.ondevicepersonalization.ExecuteOutput.Builder setRenderingConfigs(@NonNull java.util.List<android.adservices.ondevicepersonalization.RenderingConfig>); + method @NonNull public android.adservices.ondevicepersonalization.ExecuteOutput.Builder setRequestLogRecord(@NonNull android.adservices.ondevicepersonalization.RequestLogRecord); + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public abstract class IsolatedService extends android.app.Service { + ctor public IsolatedService(); + method @NonNull public final android.adservices.ondevicepersonalization.EventUrlProvider getEventUrlProvider(@NonNull android.adservices.ondevicepersonalization.RequestToken); + method @NonNull public final android.adservices.ondevicepersonalization.MutableKeyValueStore getLocalData(@NonNull android.adservices.ondevicepersonalization.RequestToken); + method @NonNull public final android.adservices.ondevicepersonalization.KeyValueStore getRemoteData(@NonNull android.adservices.ondevicepersonalization.RequestToken); + method @Nullable public final android.adservices.ondevicepersonalization.UserData getUserData(@NonNull android.adservices.ondevicepersonalization.RequestToken); + method @Nullable public android.os.IBinder onBind(@NonNull android.content.Intent); + method @NonNull public abstract android.adservices.ondevicepersonalization.IsolatedWorker onRequest(@NonNull android.adservices.ondevicepersonalization.RequestToken); + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public interface IsolatedWorker { + method public default void onDownloadCompleted(@NonNull android.adservices.ondevicepersonalization.DownloadCompletedInput, @NonNull java.util.function.Consumer<android.adservices.ondevicepersonalization.DownloadCompletedOutput>); + method public default void onExecute(@NonNull android.adservices.ondevicepersonalization.ExecuteInput, @NonNull java.util.function.Consumer<android.adservices.ondevicepersonalization.ExecuteOutput>); + method public default void onRender(@NonNull android.adservices.ondevicepersonalization.RenderInput, @NonNull java.util.function.Consumer<android.adservices.ondevicepersonalization.RenderOutput>); + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public interface KeyValueStore { + method @Nullable @WorkerThread public byte[] get(@NonNull String); + method @NonNull @WorkerThread public java.util.Set<java.lang.String> keySet(); + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public interface MutableKeyValueStore extends android.adservices.ondevicepersonalization.KeyValueStore { + method @Nullable @WorkerThread public byte[] put(@NonNull String, @NonNull byte[]); + method @Nullable @WorkerThread public byte[] remove(@NonNull String); + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public class OnDevicePersonalizationException extends java.lang.Exception { + method public int getErrorCode(); + field public static final int ERROR_ISOLATED_SERVICE_FAILED = 1; // 0x1 + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public class OnDevicePersonalizationManager { + method public void requestSurfacePackage(@NonNull android.adservices.ondevicepersonalization.SurfacePackageToken, @NonNull android.os.IBinder, int, int, int, @NonNull java.util.concurrent.Executor, @NonNull android.os.OutcomeReceiver<android.view.SurfaceControlViewHost.SurfacePackage,java.lang.Exception>); + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public final class RenderInput implements android.os.Parcelable { + method public int describeContents(); + method public int getHeight(); + method @Nullable public android.adservices.ondevicepersonalization.RenderingConfig getRenderingConfig(); + method public int getRenderingConfigIndex(); + method public int getWidth(); + method public void writeToParcel(@NonNull android.os.Parcel, int); + field @NonNull public static final android.os.Parcelable.Creator<android.adservices.ondevicepersonalization.RenderInput> CREATOR; + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public final class RenderOutput implements android.os.Parcelable { + method public int describeContents(); + method @Nullable public String getContent(); + method @Nullable public String getTemplateId(); + method @NonNull public android.os.PersistableBundle getTemplateParams(); + method public void writeToParcel(@NonNull android.os.Parcel, int); + field @NonNull public static final android.os.Parcelable.Creator<android.adservices.ondevicepersonalization.RenderOutput> CREATOR; + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public static final class RenderOutput.Builder { + ctor public RenderOutput.Builder(); + method @NonNull public android.adservices.ondevicepersonalization.RenderOutput build(); + method @NonNull public android.adservices.ondevicepersonalization.RenderOutput.Builder setContent(@NonNull String); + method @NonNull public android.adservices.ondevicepersonalization.RenderOutput.Builder setTemplateId(@NonNull String); + method @NonNull public android.adservices.ondevicepersonalization.RenderOutput.Builder setTemplateParams(@NonNull android.os.PersistableBundle); + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public final class RenderingConfig implements android.os.Parcelable { + method public int describeContents(); + method @NonNull public java.util.List<java.lang.String> getKeys(); + method public void writeToParcel(@NonNull android.os.Parcel, int); + field @NonNull public static final android.os.Parcelable.Creator<android.adservices.ondevicepersonalization.RenderingConfig> CREATOR; + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public static final class RenderingConfig.Builder { + ctor public RenderingConfig.Builder(); + method @NonNull public android.adservices.ondevicepersonalization.RenderingConfig.Builder addKey(@NonNull String); + method @NonNull public android.adservices.ondevicepersonalization.RenderingConfig build(); + method @NonNull public android.adservices.ondevicepersonalization.RenderingConfig.Builder setKeys(@NonNull java.util.List<java.lang.String>); + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public final class RequestLogRecord implements android.os.Parcelable { + method public int describeContents(); + method @NonNull public java.util.List<android.content.ContentValues> getRows(); + method public void writeToParcel(@NonNull android.os.Parcel, int); + field @NonNull public static final android.os.Parcelable.Creator<android.adservices.ondevicepersonalization.RequestLogRecord> CREATOR; + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public static final class RequestLogRecord.Builder { + ctor public RequestLogRecord.Builder(); + method @NonNull public android.adservices.ondevicepersonalization.RequestLogRecord.Builder addRow(@NonNull android.content.ContentValues); + method @NonNull public android.adservices.ondevicepersonalization.RequestLogRecord build(); + method @NonNull public android.adservices.ondevicepersonalization.RequestLogRecord.Builder setRows(@NonNull java.util.List<android.content.ContentValues>); + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public class RequestToken { + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public class SurfacePackageToken { + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public final class UserData implements android.os.Parcelable { + method public int describeContents(); + method @NonNull public java.util.Map<java.lang.String,android.adservices.ondevicepersonalization.AppInfo> getAppInfos(); + method @IntRange(from=0) public long getAvailableStorageBytes(); + method @IntRange(from=0, to=100) public int getBatteryPercentage(); + method @NonNull public String getCarrier(); + method public int getOrientation(); + method public void writeToParcel(@NonNull android.os.Parcel, int); + field @NonNull public static final android.os.Parcelable.Creator<android.adservices.ondevicepersonalization.UserData> CREATOR; + } + +} + diff --git a/framework/api/system-current.txt b/framework/api/system-current.txt index d802177e..f125fe9c 100644 --- a/framework/api/system-current.txt +++ b/framework/api/system-current.txt @@ -1 +1,13 @@ // Signature format: 2.0 +package android.adservices.ondevicepersonalization { + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public class OnDevicePersonalizationConfigManager { + method @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) @RequiresPermission(android.adservices.ondevicepersonalization.OnDevicePersonalizationPermissions.MODIFY_ONDEVICEPERSONALIZATION_STATE) public void setPersonalizationEnabled(boolean, @NonNull java.util.concurrent.Executor, @NonNull android.os.OutcomeReceiver<java.lang.Void,java.lang.Exception>); + } + + @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public class OnDevicePersonalizationPermissions { + field @FlaggedApi(android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public static final String MODIFY_ONDEVICEPERSONALIZATION_STATE = "android.permission.ondevicepersonalization.MODIFY_ONDEVICEPERSONALIZATION_STATE"; + } + +} + diff --git a/framework/java/android/adservices/ondevicepersonalization/DownloadOutput.aidl b/framework/java/android/adservices/ondevicepersonalization/AppInfo.aidl index aec35c14..fa69edbe 100644 --- a/framework/java/android/adservices/ondevicepersonalization/DownloadOutput.aidl +++ b/framework/java/android/adservices/ondevicepersonalization/AppInfo.aidl @@ -16,4 +16,4 @@ package android.adservices.ondevicepersonalization; -parcelable DownloadOutput; +parcelable AppInfo; diff --git a/framework/java/android/adservices/ondevicepersonalization/AppInstallInfo.java b/framework/java/android/adservices/ondevicepersonalization/AppInfo.java index c6d57893..4a628d8d 100644 --- a/framework/java/android/adservices/ondevicepersonalization/AppInstallInfo.java +++ b/framework/java/android/adservices/ondevicepersonalization/AppInfo.java @@ -16,6 +16,9 @@ package android.adservices.ondevicepersonalization; +import static android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS; + +import android.annotation.FlaggedApi; import android.annotation.NonNull; import android.os.Parcelable; @@ -23,12 +26,12 @@ import com.android.ondevicepersonalization.internal.util.AnnotationValidations; import com.android.ondevicepersonalization.internal.util.DataClass; /** - * Installation information for an app. + * Information about apps. * - * @hide */ -@DataClass(genBuilder = true, genEqualsHashCode = true) -public final class AppInstallInfo implements Parcelable { +@FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) +@DataClass(genHiddenBuilder = true, genEqualsHashCode = true) +public final class AppInfo implements Parcelable { /** Whether the app is installed. */ @NonNull boolean mInstalled = false; @@ -40,7 +43,7 @@ public final class AppInstallInfo implements Parcelable { // CHECKSTYLE:OFF Generated code // // To regenerate run: - // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/AppInstallInfo.java + // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/AppInfo.java // // To exclude the generated code from IntelliJ auto-formatting enable (one-time): // Settings > Editor > Code Style > Formatter Control @@ -48,7 +51,7 @@ public final class AppInstallInfo implements Parcelable { @DataClass.Generated.Member - /* package-private */ AppInstallInfo( + /* package-private */ AppInfo( @NonNull boolean installed) { this.mInstalled = installed; AnnotationValidations.validate( @@ -69,13 +72,13 @@ public final class AppInstallInfo implements Parcelable { @DataClass.Generated.Member public boolean equals(@android.annotation.Nullable Object o) { // You can override field equality logic by defining either of the methods like: - // boolean fieldNameEquals(AppInstallInfo other) { ... } + // boolean fieldNameEquals(AppInfo other) { ... } // boolean fieldNameEquals(FieldType otherValue) { ... } if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; @SuppressWarnings("unchecked") - AppInstallInfo that = (AppInstallInfo) o; + AppInfo that = (AppInfo) o; //noinspection PointlessBooleanExpression return true && mInstalled == that.mInstalled; @@ -110,7 +113,7 @@ public final class AppInstallInfo implements Parcelable { /** @hide */ @SuppressWarnings({"unchecked", "RedundantCast"}) @DataClass.Generated.Member - /* package-private */ AppInstallInfo(@NonNull android.os.Parcel in) { + /* package-private */ AppInfo(@NonNull android.os.Parcel in) { // You can override field unparcelling by defining methods like: // static FieldType unparcelFieldName(Parcel in) { ... } @@ -125,21 +128,22 @@ public final class AppInstallInfo implements Parcelable { } @DataClass.Generated.Member - public static final @NonNull Parcelable.Creator<AppInstallInfo> CREATOR - = new Parcelable.Creator<AppInstallInfo>() { + public static final @NonNull Parcelable.Creator<AppInfo> CREATOR + = new Parcelable.Creator<AppInfo>() { @Override - public AppInstallInfo[] newArray(int size) { - return new AppInstallInfo[size]; + public AppInfo[] newArray(int size) { + return new AppInfo[size]; } @Override - public AppInstallInfo createFromParcel(@NonNull android.os.Parcel in) { - return new AppInstallInfo(in); + public AppInfo createFromParcel(@NonNull android.os.Parcel in) { + return new AppInfo(in); } }; /** - * A builder for {@link AppInstallInfo} + * A builder for {@link AppInfo} + * @hide */ @SuppressWarnings("WeakerAccess") @DataClass.Generated.Member @@ -164,14 +168,14 @@ public final class AppInstallInfo implements Parcelable { } /** Builds the instance. This builder should not be touched after calling this! */ - public @NonNull AppInstallInfo build() { + public @NonNull AppInfo build() { checkNotUsed(); mBuilderFieldsSet |= 0x2; // Mark builder used if ((mBuilderFieldsSet & 0x1) == 0) { mInstalled = false; } - AppInstallInfo o = new AppInstallInfo( + AppInfo o = new AppInfo( mInstalled); return o; } @@ -185,10 +189,10 @@ public final class AppInstallInfo implements Parcelable { } @DataClass.Generated( - time = 1693265003084L, + time = 1695492606666L, codegenVersion = "1.0.23", - sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/AppInstallInfo.java", - inputSignatures = " @android.annotation.NonNull boolean mInstalled\nclass AppInstallInfo extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") + sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/AppInfo.java", + inputSignatures = " @android.annotation.NonNull boolean mInstalled\nclass AppInfo extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genHiddenBuilder=true, genEqualsHashCode=true)") @Deprecated private void __metadata() {} diff --git a/framework/java/android/adservices/ondevicepersonalization/AppInstallInfo.aidl b/framework/java/android/adservices/ondevicepersonalization/CalleeMetadata.aidl index 2a8aa983..3e936c9b 100644 --- a/framework/java/android/adservices/ondevicepersonalization/AppInstallInfo.aidl +++ b/framework/java/android/adservices/ondevicepersonalization/CalleeMetadata.aidl @@ -16,4 +16,4 @@ package android.adservices.ondevicepersonalization; -parcelable AppInstallInfo; +parcelable CalleeMetadata; diff --git a/framework/java/android/adservices/ondevicepersonalization/CalleeMetadata.java b/framework/java/android/adservices/ondevicepersonalization/CalleeMetadata.java new file mode 100644 index 00000000..f5ad7a0a --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/CalleeMetadata.java @@ -0,0 +1,191 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +import android.annotation.NonNull; +import android.os.Parcelable; + +import com.android.ondevicepersonalization.internal.util.DataClass; + +/** + * Wrapper class for additional information returned with IPC results. +* +* @hide +*/ +@DataClass(genBuilder = true, genEqualsHashCode = true) +public final class CalleeMetadata implements Parcelable { + /** Time elapsed in callee, as measured by callee. */ + private long mElapsedTimeMillis = 0; + + + + // Code below generated by codegen v1.0.23. + // + // DO NOT MODIFY! + // CHECKSTYLE:OFF Generated code + // + // To regenerate run: + // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/CalleeMetadata.java + // + // To exclude the generated code from IntelliJ auto-formatting enable (one-time): + // Settings > Editor > Code Style > Formatter Control + //@formatter:off + + + @DataClass.Generated.Member + /* package-private */ CalleeMetadata( + long elapsedTimeMillis) { + this.mElapsedTimeMillis = elapsedTimeMillis; + + // onConstructed(); // You can define this method to get a callback + } + + /** + * Time elapsed in callee, as measured by callee. + */ + @DataClass.Generated.Member + public long getElapsedTimeMillis() { + return mElapsedTimeMillis; + } + + @Override + @DataClass.Generated.Member + public boolean equals(@android.annotation.Nullable Object o) { + // You can override field equality logic by defining either of the methods like: + // boolean fieldNameEquals(CalleeMetadata other) { ... } + // boolean fieldNameEquals(FieldType otherValue) { ... } + + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + @SuppressWarnings("unchecked") + CalleeMetadata that = (CalleeMetadata) o; + //noinspection PointlessBooleanExpression + return true + && mElapsedTimeMillis == that.mElapsedTimeMillis; + } + + @Override + @DataClass.Generated.Member + public int hashCode() { + // You can override field hashCode logic by defining methods like: + // int fieldNameHashCode() { ... } + + int _hash = 1; + _hash = 31 * _hash + Long.hashCode(mElapsedTimeMillis); + return _hash; + } + + @Override + @DataClass.Generated.Member + public void writeToParcel(@NonNull android.os.Parcel dest, int flags) { + // You can override field parcelling by defining methods like: + // void parcelFieldName(Parcel dest, int flags) { ... } + + dest.writeLong(mElapsedTimeMillis); + } + + @Override + @DataClass.Generated.Member + public int describeContents() { return 0; } + + /** @hide */ + @SuppressWarnings({"unchecked", "RedundantCast"}) + @DataClass.Generated.Member + /* package-private */ CalleeMetadata(@NonNull android.os.Parcel in) { + // You can override field unparcelling by defining methods like: + // static FieldType unparcelFieldName(Parcel in) { ... } + + long elapsedTimeMillis = in.readLong(); + + this.mElapsedTimeMillis = elapsedTimeMillis; + + // onConstructed(); // You can define this method to get a callback + } + + @DataClass.Generated.Member + public static final @NonNull Parcelable.Creator<CalleeMetadata> CREATOR + = new Parcelable.Creator<CalleeMetadata>() { + @Override + public CalleeMetadata[] newArray(int size) { + return new CalleeMetadata[size]; + } + + @Override + public CalleeMetadata createFromParcel(@NonNull android.os.Parcel in) { + return new CalleeMetadata(in); + } + }; + + /** + * A builder for {@link CalleeMetadata} + */ + @SuppressWarnings("WeakerAccess") + @DataClass.Generated.Member + public static final class Builder { + + private long mElapsedTimeMillis; + + private long mBuilderFieldsSet = 0L; + + public Builder() { + } + + /** + * Time elapsed in callee, as measured by callee. + */ + @DataClass.Generated.Member + public @NonNull Builder setElapsedTimeMillis(long value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x1; + mElapsedTimeMillis = value; + return this; + } + + /** Builds the instance. This builder should not be touched after calling this! */ + public @NonNull CalleeMetadata build() { + checkNotUsed(); + mBuilderFieldsSet |= 0x2; // Mark builder used + + if ((mBuilderFieldsSet & 0x1) == 0) { + mElapsedTimeMillis = 0; + } + CalleeMetadata o = new CalleeMetadata( + mElapsedTimeMillis); + return o; + } + + private void checkNotUsed() { + if ((mBuilderFieldsSet & 0x2) != 0) { + throw new IllegalStateException( + "This Builder should not be reused. Use a new Builder instance instead"); + } + } + } + + @DataClass.Generated( + time = 1696885546254L, + codegenVersion = "1.0.23", + sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/CalleeMetadata.java", + inputSignatures = "private long mElapsedTimeMillis\nclass CalleeMetadata extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") + @Deprecated + private void __metadata() {} + + + //@formatter:on + // End of generated code + +} diff --git a/framework/java/android/adservices/ondevicepersonalization/CallerMetadata.aidl b/framework/java/android/adservices/ondevicepersonalization/CallerMetadata.aidl new file mode 100644 index 00000000..d313206e --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/CallerMetadata.aidl @@ -0,0 +1,19 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +parcelable CallerMetadata; diff --git a/framework/java/android/adservices/ondevicepersonalization/CallerMetadata.java b/framework/java/android/adservices/ondevicepersonalization/CallerMetadata.java new file mode 100644 index 00000000..68a9cd86 --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/CallerMetadata.java @@ -0,0 +1,191 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +import android.annotation.NonNull; +import android.os.Parcelable; + +import com.android.ondevicepersonalization.internal.util.DataClass; + +/** + * Wrapper class for additional information passed to IPC requests. +* +* @hide +*/ +@DataClass(genBuilder = true, genEqualsHashCode = true) +public final class CallerMetadata implements Parcelable { + /** Start time of the operation. */ + private long mStartTimeMillis = 0; + + + + // Code below generated by codegen v1.0.23. + // + // DO NOT MODIFY! + // CHECKSTYLE:OFF Generated code + // + // To regenerate run: + // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/CallerMetadata.java + // + // To exclude the generated code from IntelliJ auto-formatting enable (one-time): + // Settings > Editor > Code Style > Formatter Control + //@formatter:off + + + @DataClass.Generated.Member + /* package-private */ CallerMetadata( + long startTimeMillis) { + this.mStartTimeMillis = startTimeMillis; + + // onConstructed(); // You can define this method to get a callback + } + + /** + * Start time of the operation. + */ + @DataClass.Generated.Member + public long getStartTimeMillis() { + return mStartTimeMillis; + } + + @Override + @DataClass.Generated.Member + public boolean equals(@android.annotation.Nullable Object o) { + // You can override field equality logic by defining either of the methods like: + // boolean fieldNameEquals(CallerMetadata other) { ... } + // boolean fieldNameEquals(FieldType otherValue) { ... } + + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + @SuppressWarnings("unchecked") + CallerMetadata that = (CallerMetadata) o; + //noinspection PointlessBooleanExpression + return true + && mStartTimeMillis == that.mStartTimeMillis; + } + + @Override + @DataClass.Generated.Member + public int hashCode() { + // You can override field hashCode logic by defining methods like: + // int fieldNameHashCode() { ... } + + int _hash = 1; + _hash = 31 * _hash + Long.hashCode(mStartTimeMillis); + return _hash; + } + + @Override + @DataClass.Generated.Member + public void writeToParcel(@NonNull android.os.Parcel dest, int flags) { + // You can override field parcelling by defining methods like: + // void parcelFieldName(Parcel dest, int flags) { ... } + + dest.writeLong(mStartTimeMillis); + } + + @Override + @DataClass.Generated.Member + public int describeContents() { return 0; } + + /** @hide */ + @SuppressWarnings({"unchecked", "RedundantCast"}) + @DataClass.Generated.Member + /* package-private */ CallerMetadata(@NonNull android.os.Parcel in) { + // You can override field unparcelling by defining methods like: + // static FieldType unparcelFieldName(Parcel in) { ... } + + long startTimeMillis = in.readLong(); + + this.mStartTimeMillis = startTimeMillis; + + // onConstructed(); // You can define this method to get a callback + } + + @DataClass.Generated.Member + public static final @NonNull Parcelable.Creator<CallerMetadata> CREATOR + = new Parcelable.Creator<CallerMetadata>() { + @Override + public CallerMetadata[] newArray(int size) { + return new CallerMetadata[size]; + } + + @Override + public CallerMetadata createFromParcel(@NonNull android.os.Parcel in) { + return new CallerMetadata(in); + } + }; + + /** + * A builder for {@link CallerMetadata} + */ + @SuppressWarnings("WeakerAccess") + @DataClass.Generated.Member + public static final class Builder { + + private long mStartTimeMillis; + + private long mBuilderFieldsSet = 0L; + + public Builder() { + } + + /** + * Start time of the operation. + */ + @DataClass.Generated.Member + public @NonNull Builder setStartTimeMillis(long value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x1; + mStartTimeMillis = value; + return this; + } + + /** Builds the instance. This builder should not be touched after calling this! */ + public @NonNull CallerMetadata build() { + checkNotUsed(); + mBuilderFieldsSet |= 0x2; // Mark builder used + + if ((mBuilderFieldsSet & 0x1) == 0) { + mStartTimeMillis = 0; + } + CallerMetadata o = new CallerMetadata( + mStartTimeMillis); + return o; + } + + private void checkNotUsed() { + if ((mBuilderFieldsSet & 0x2) != 0) { + throw new IllegalStateException( + "This Builder should not be reused. Use a new Builder instance instead"); + } + } + } + + @DataClass.Generated( + time = 1696884555838L, + codegenVersion = "1.0.23", + sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/CallerMetadata.java", + inputSignatures = "private long mStartTimeMillis\nclass CallerMetadata extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") + @Deprecated + private void __metadata() {} + + + //@formatter:on + // End of generated code + +} diff --git a/framework/java/android/adservices/ondevicepersonalization/Constants.java b/framework/java/android/adservices/ondevicepersonalization/Constants.java index 354d2669..756a7c2e 100644 --- a/framework/java/android/adservices/ondevicepersonalization/Constants.java +++ b/framework/java/android/adservices/ondevicepersonalization/Constants.java @@ -17,42 +17,46 @@ package android.adservices.ondevicepersonalization; /** - * Constants used internally in the OnDevicePersonalization Module and not - * used in public APIs. + * Constants used internally in the OnDevicePersonalization Module and not used in public APIs. * * @hide */ public class Constants { + public static final int STATUS_SUCCESS = 0; public static final int STATUS_INTERNAL_ERROR = 100; - + public static final int STATUS_NAME_NOT_FOUND = 101; + public static final int STATUS_CLASS_NOT_FOUND = 102; + public static final int STATUS_SERVICE_FAILED = 103; + public static final int STATUS_PERSONALIZATION_DISABLED = 104; // Operations implemented by personalization services. public static final int OP_EXECUTE = 1; public static final int OP_DOWNLOAD = 2; public static final int OP_RENDER = 3; public static final int OP_WEB_VIEW_EVENT = 4; + public static final int OP_TRAINING_EXAMPLE = 5; // Keys for Bundle objects passed between processes. - public static final String - EXTRA_DATA_ACCESS_SERVICE_BINDER = - "android.ondevicepersonalization.extra.data_access_service_binder"; - public static final String - EXTRA_DESTINATION_URL = "android.ondevicepersonalization.extra.destination_url"; - public static final String - EXTRA_EVENT_PARAMS = "android.ondevicepersonalization.extra.event_params"; - public static final String - EXTRA_INPUT = "android.ondevicepersonalization.extra.input"; - public static final String - EXTRA_LOOKUP_KEYS = "android.ondevicepersonalization.extra.lookup_keys"; - public static final String - EXTRA_MIME_TYPE = "android.ondevicepersonalization.extra.mime_type"; - public static final String - EXTRA_RESPONSE_DATA = "android.ondevicepersonalization.extra.response_data"; - public static final String - EXTRA_USER_DATA = "android.ondevicepersonalization.extra.user_data"; - public static final String - EXTRA_VALUE = "android.ondevicepersonalization.extra.value"; - public static final String - EXTRA_RESULT = "android.ondevicepersonalization.extra.result"; + public static final String EXTRA_CALLEE_METADATA = + "android.ondevicepersonalization.extra.callee_metadata"; + public static final String EXTRA_DATA_ACCESS_SERVICE_BINDER = + "android.ondevicepersonalization.extra.data_access_service_binder"; + public static final String EXTRA_FEDERATED_COMPUTE_SERVICE_BINDER = + "android.ondevicepersonalization.extra.federated_computation_service_binder"; + public static final String EXTRA_DESTINATION_URL = + "android.ondevicepersonalization.extra.destination_url"; + public static final String EXTRA_EVENT_PARAMS = + "android.ondevicepersonalization.extra.event_params"; + public static final String EXTRA_INPUT = "android.ondevicepersonalization.extra.input"; + public static final String EXTRA_LOOKUP_KEYS = + "android.ondevicepersonalization.extra.lookup_keys"; + public static final String EXTRA_MIME_TYPE = "android.ondevicepersonalization.extra.mime_type"; + public static final String EXTRA_RESPONSE_DATA = + "android.ondevicepersonalization.extra.response_data"; + public static final String EXTRA_USER_DATA = "android.ondevicepersonalization.extra.user_data"; + public static final String EXTRA_VALUE = "android.ondevicepersonalization.extra.value"; + public static final String EXTRA_RESULT = "android.ondevicepersonalization.extra.result"; + public static final String KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS = + "enable_ondevicepersonalization_apis"; // Data Access Service operations. public static final int DATA_ACCESS_OP_REMOTE_DATA_LOOKUP = 1; @@ -62,7 +66,8 @@ public class Constants { public static final int DATA_ACCESS_OP_LOCAL_DATA_KEYSET = 5; public static final int DATA_ACCESS_OP_LOCAL_DATA_PUT = 6; public static final int DATA_ACCESS_OP_LOCAL_DATA_REMOVE = 7; - + public static final int DATA_ACCESS_OP_GET_REQUESTS = 8; + public static final int DATA_ACCESS_OP_GET_JOINED_EVENTS = 9; private Constants() {} } diff --git a/framework/java/android/adservices/ondevicepersonalization/DownloadInput.java b/framework/java/android/adservices/ondevicepersonalization/DownloadCompletedInput.java index 72a5fd8f..14dfb320 100644 --- a/framework/java/android/adservices/ondevicepersonalization/DownloadInput.java +++ b/framework/java/android/adservices/ondevicepersonalization/DownloadCompletedInput.java @@ -16,6 +16,9 @@ package android.adservices.ondevicepersonalization; +import static android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS; + +import android.annotation.FlaggedApi; import android.annotation.NonNull; import android.annotation.Nullable; @@ -26,12 +29,13 @@ import java.util.Collections; import java.util.Map; /** - * The input data for {@link IsolatedWorker#onDownload()}. + * The input data for {@link + * IsolatedWorker#onDownloadCompleted(DownloadCompletedInput, java.util.function.Consumer)}. * - * @hide */ -@DataClass(genBuilder = true, genEqualsHashCode = true) -public final class DownloadInput { +@FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) +@DataClass(genHiddenBuilder = true, genEqualsHashCode = true) +public final class DownloadCompletedInput { /** Map containing downloaded keys and values */ @NonNull Map<String, byte[]> mData = Collections.emptyMap(); @@ -43,7 +47,7 @@ public final class DownloadInput { // CHECKSTYLE:OFF Generated code // // To regenerate run: - // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/DownloadInput.java + // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/DownloadCompletedInput.java // // To exclude the generated code from IntelliJ auto-formatting enable (one-time): // Settings > Editor > Code Style > Formatter Control @@ -51,7 +55,7 @@ public final class DownloadInput { @DataClass.Generated.Member - /* package-private */ DownloadInput( + /* package-private */ DownloadCompletedInput( @NonNull Map<String,byte[]> data) { this.mData = data; AnnotationValidations.validate( @@ -72,13 +76,13 @@ public final class DownloadInput { @DataClass.Generated.Member public boolean equals(@Nullable Object o) { // You can override field equality logic by defining either of the methods like: - // boolean fieldNameEquals(DownloadInput other) { ... } + // boolean fieldNameEquals(DownloadCompletedInput other) { ... } // boolean fieldNameEquals(FieldType otherValue) { ... } if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; @SuppressWarnings("unchecked") - DownloadInput that = (DownloadInput) o; + DownloadCompletedInput that = (DownloadCompletedInput) o; //noinspection PointlessBooleanExpression return true && java.util.Objects.equals(mData, that.mData); @@ -96,7 +100,8 @@ public final class DownloadInput { } /** - * A builder for {@link DownloadInput} + * A builder for {@link DownloadCompletedInput} + * @hide */ @SuppressWarnings("WeakerAccess") @DataClass.Generated.Member @@ -120,15 +125,26 @@ public final class DownloadInput { return this; } + /** @see #setData */ + @DataClass.Generated.Member + public @NonNull Builder addData(@NonNull String key, @NonNull byte[] value) { + // You can refine this method's name by providing item's singular name, e.g.: + // @DataClass.PluralOf("item")) mItems = ... + + if (mData == null) setData(new java.util.LinkedHashMap()); + mData.put(key, value); + return this; + } + /** Builds the instance. This builder should not be touched after calling this! */ - public @NonNull DownloadInput build() { + public @NonNull DownloadCompletedInput build() { checkNotUsed(); mBuilderFieldsSet |= 0x2; // Mark builder used if ((mBuilderFieldsSet & 0x1) == 0) { mData = Collections.emptyMap(); } - DownloadInput o = new DownloadInput( + DownloadCompletedInput o = new DownloadCompletedInput( mData); return o; } @@ -142,10 +158,10 @@ public final class DownloadInput { } @DataClass.Generated( - time = 1692119978934L, + time = 1695492633750L, codegenVersion = "1.0.23", - sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/DownloadInput.java", - inputSignatures = " @android.annotation.NonNull java.util.Map<java.lang.String,byte[]> mData\nclass DownloadInput extends java.lang.Object implements []\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") + sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/DownloadCompletedInput.java", + inputSignatures = " @android.annotation.NonNull java.util.Map<java.lang.String,byte[]> mData\nclass DownloadCompletedInput extends java.lang.Object implements []\n@com.android.ondevicepersonalization.internal.util.DataClass(genHiddenBuilder=true, genEqualsHashCode=true)") @Deprecated private void __metadata() {} diff --git a/framework/java/android/adservices/ondevicepersonalization/DownloadCompletedOutput.aidl b/framework/java/android/adservices/ondevicepersonalization/DownloadCompletedOutput.aidl new file mode 100644 index 00000000..8ebe46ef --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/DownloadCompletedOutput.aidl @@ -0,0 +1,19 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +parcelable DownloadCompletedOutput; diff --git a/framework/java/android/adservices/ondevicepersonalization/DownloadOutput.java b/framework/java/android/adservices/ondevicepersonalization/DownloadCompletedOutput.java index 1b6b88f0..10d953d0 100644 --- a/framework/java/android/adservices/ondevicepersonalization/DownloadOutput.java +++ b/framework/java/android/adservices/ondevicepersonalization/DownloadCompletedOutput.java @@ -16,6 +16,9 @@ package android.adservices.ondevicepersonalization; +import static android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS; + +import android.annotation.FlaggedApi; import android.annotation.NonNull; import android.os.Parcelable; @@ -26,12 +29,13 @@ import java.util.Collections; import java.util.List; /** - * The result returned by {@link IsolatedWorker#onDownload()}. + * The result returned by {@link + * IsolatedWorker#onDownloadCompleted(DownloadCompletedInput, java.util.function.Consumer)}. * - * @hide */ +@FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) @DataClass(genBuilder = true, genEqualsHashCode = true) -public final class DownloadOutput implements Parcelable { +public final class DownloadCompletedOutput implements Parcelable { /** * The keys to be retained in the REMOTE_DATA table. Any existing keys that are not * present in this list are removed from the table. @@ -47,7 +51,7 @@ public final class DownloadOutput implements Parcelable { // CHECKSTYLE:OFF Generated code // // To regenerate run: - // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/DownloadOutput.java + // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/DownloadCompletedOutput.java // // To exclude the generated code from IntelliJ auto-formatting enable (one-time): // Settings > Editor > Code Style > Formatter Control @@ -55,7 +59,7 @@ public final class DownloadOutput implements Parcelable { @DataClass.Generated.Member - /* package-private */ DownloadOutput( + /* package-private */ DownloadCompletedOutput( @NonNull List<String> retainedKeys) { this.mRetainedKeys = retainedKeys; AnnotationValidations.validate( @@ -77,13 +81,13 @@ public final class DownloadOutput implements Parcelable { @DataClass.Generated.Member public boolean equals(@android.annotation.Nullable Object o) { // You can override field equality logic by defining either of the methods like: - // boolean fieldNameEquals(DownloadOutput other) { ... } + // boolean fieldNameEquals(DownloadCompletedOutput other) { ... } // boolean fieldNameEquals(FieldType otherValue) { ... } if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; @SuppressWarnings("unchecked") - DownloadOutput that = (DownloadOutput) o; + DownloadCompletedOutput that = (DownloadCompletedOutput) o; //noinspection PointlessBooleanExpression return true && java.util.Objects.equals(mRetainedKeys, that.mRetainedKeys); @@ -116,7 +120,7 @@ public final class DownloadOutput implements Parcelable { /** @hide */ @SuppressWarnings({"unchecked", "RedundantCast"}) @DataClass.Generated.Member - /* package-private */ DownloadOutput(@NonNull android.os.Parcel in) { + /* package-private */ DownloadCompletedOutput(@NonNull android.os.Parcel in) { // You can override field unparcelling by defining methods like: // static FieldType unparcelFieldName(Parcel in) { ... } @@ -131,22 +135,23 @@ public final class DownloadOutput implements Parcelable { } @DataClass.Generated.Member - public static final @NonNull Parcelable.Creator<DownloadOutput> CREATOR - = new Parcelable.Creator<DownloadOutput>() { + public static final @NonNull Parcelable.Creator<DownloadCompletedOutput> CREATOR + = new Parcelable.Creator<DownloadCompletedOutput>() { @Override - public DownloadOutput[] newArray(int size) { - return new DownloadOutput[size]; + public DownloadCompletedOutput[] newArray(int size) { + return new DownloadCompletedOutput[size]; } @Override - public DownloadOutput createFromParcel(@NonNull android.os.Parcel in) { - return new DownloadOutput(in); + public DownloadCompletedOutput createFromParcel(@NonNull android.os.Parcel in) { + return new DownloadCompletedOutput(in); } }; /** - * A builder for {@link DownloadOutput} + * A builder for {@link DownloadCompletedOutput} */ + @FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) @SuppressWarnings("WeakerAccess") @DataClass.Generated.Member public static final class Builder { @@ -179,14 +184,14 @@ public final class DownloadOutput implements Parcelable { } /** Builds the instance. This builder should not be touched after calling this! */ - public @NonNull DownloadOutput build() { + public @NonNull DownloadCompletedOutput build() { checkNotUsed(); mBuilderFieldsSet |= 0x2; // Mark builder used if ((mBuilderFieldsSet & 0x1) == 0) { mRetainedKeys = Collections.emptyList(); } - DownloadOutput o = new DownloadOutput( + DownloadCompletedOutput o = new DownloadCompletedOutput( mRetainedKeys); return o; } @@ -200,10 +205,10 @@ public final class DownloadOutput implements Parcelable { } @DataClass.Generated( - time = 1692118344685L, + time = 1696972554365L, codegenVersion = "1.0.23", - sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/DownloadOutput.java", - inputSignatures = "private @com.android.ondevicepersonalization.internal.util.DataClass.PluralOf(\"retainedKey\") @android.annotation.NonNull java.util.List<java.lang.String> mRetainedKeys\nclass DownloadOutput extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") + sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/DownloadCompletedOutput.java", + inputSignatures = "private @com.android.ondevicepersonalization.internal.util.DataClass.PluralOf(\"retainedKey\") @android.annotation.NonNull java.util.List<java.lang.String> mRetainedKeys\nclass DownloadCompletedOutput extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") @Deprecated private void __metadata() {} diff --git a/framework/java/android/adservices/ondevicepersonalization/DownloadCompletedOutputParcel.java b/framework/java/android/adservices/ondevicepersonalization/DownloadCompletedOutputParcel.java new file mode 100644 index 00000000..7a906f1d --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/DownloadCompletedOutputParcel.java @@ -0,0 +1,142 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +import android.annotation.NonNull; +import android.os.Parcelable; + +import com.android.ondevicepersonalization.internal.util.AnnotationValidations; +import com.android.ondevicepersonalization.internal.util.DataClass; + +import java.util.Collections; +import java.util.List; + +/** + * Parcelable version of {@link DownloadCompletedOutput}. + * @hide + */ +@DataClass(genAidl = false, genBuilder = false) +public final class DownloadCompletedOutputParcel implements Parcelable { + /** + * The keys to be retained in the REMOTE_DATA table. Any existing keys that are not + * present in this list are removed from the table. + */ + @NonNull private List<String> mRetainedKeys = Collections.emptyList(); + + /** @hide */ + public DownloadCompletedOutputParcel(@NonNull DownloadCompletedOutput value) { + this(value.getRetainedKeys()); + } + + + + // Code below generated by codegen v1.0.23. + // + // DO NOT MODIFY! + // CHECKSTYLE:OFF Generated code + // + // To regenerate run: + // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/DownloadCompletedOutputParcel.java + // + // To exclude the generated code from IntelliJ auto-formatting enable (one-time): + // Settings > Editor > Code Style > Formatter Control + //@formatter:off + + + /** + * Creates a new DownloadCompletedOutputParcel. + * + * @param retainedKeys + * The keys to be retained in the REMOTE_DATA table. Any existing keys that are not + * present in this list are removed from the table. + */ + @DataClass.Generated.Member + public DownloadCompletedOutputParcel( + @NonNull List<String> retainedKeys) { + this.mRetainedKeys = retainedKeys; + AnnotationValidations.validate( + NonNull.class, null, mRetainedKeys); + + // onConstructed(); // You can define this method to get a callback + } + + /** + * The keys to be retained in the REMOTE_DATA table. Any existing keys that are not + * present in this list are removed from the table. + */ + @DataClass.Generated.Member + public @NonNull List<String> getRetainedKeys() { + return mRetainedKeys; + } + + @Override + @DataClass.Generated.Member + public void writeToParcel(@NonNull android.os.Parcel dest, int flags) { + // You can override field parcelling by defining methods like: + // void parcelFieldName(Parcel dest, int flags) { ... } + + dest.writeStringList(mRetainedKeys); + } + + @Override + @DataClass.Generated.Member + public int describeContents() { return 0; } + + /** @hide */ + @SuppressWarnings({"unchecked", "RedundantCast"}) + @DataClass.Generated.Member + /* package-private */ DownloadCompletedOutputParcel(@NonNull android.os.Parcel in) { + // You can override field unparcelling by defining methods like: + // static FieldType unparcelFieldName(Parcel in) { ... } + + List<String> retainedKeys = new java.util.ArrayList<>(); + in.readStringList(retainedKeys); + + this.mRetainedKeys = retainedKeys; + AnnotationValidations.validate( + NonNull.class, null, mRetainedKeys); + + // onConstructed(); // You can define this method to get a callback + } + + @DataClass.Generated.Member + public static final @NonNull Parcelable.Creator<DownloadCompletedOutputParcel> CREATOR + = new Parcelable.Creator<DownloadCompletedOutputParcel>() { + @Override + public DownloadCompletedOutputParcel[] newArray(int size) { + return new DownloadCompletedOutputParcel[size]; + } + + @Override + public DownloadCompletedOutputParcel createFromParcel(@NonNull android.os.Parcel in) { + return new DownloadCompletedOutputParcel(in); + } + }; + + @DataClass.Generated( + time = 1698783477713L, + codegenVersion = "1.0.23", + sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/DownloadCompletedOutputParcel.java", + inputSignatures = "private @android.annotation.NonNull java.util.List<java.lang.String> mRetainedKeys\nclass DownloadCompletedOutputParcel extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genAidl=false, genBuilder=false)") + @Deprecated + private void __metadata() {} + + + //@formatter:on + // End of generated code + +} diff --git a/framework/java/android/adservices/ondevicepersonalization/WebViewEventInput.aidl b/framework/java/android/adservices/ondevicepersonalization/EventInput.aidl index 9b6caa6d..156f47e7 100644 --- a/framework/java/android/adservices/ondevicepersonalization/WebViewEventInput.aidl +++ b/framework/java/android/adservices/ondevicepersonalization/EventInput.aidl @@ -16,4 +16,4 @@ package android.adservices.ondevicepersonalization; -parcelable WebViewEventInput; +parcelable EventInput; diff --git a/framework/java/android/adservices/ondevicepersonalization/EventInput.java b/framework/java/android/adservices/ondevicepersonalization/EventInput.java new file mode 100644 index 00000000..bb5de28b --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/EventInput.java @@ -0,0 +1,201 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +import android.annotation.NonNull; +import android.annotation.Nullable; +import android.os.Parcelable; +import android.os.PersistableBundle; + +import com.android.ondevicepersonalization.internal.util.AnnotationValidations; +import com.android.ondevicepersonalization.internal.util.DataClass; + +/** + * The input data for {@link + * IsolatedWorker#onEvent(EventInput, java.util.function.Consumer)}. + * @hide + */ +@DataClass(genBuilder = false, genHiddenConstructor = true, genEqualsHashCode = true) +public final class EventInput implements Parcelable { + /** + * The {@link RequestLogRecord} that was returned as a result of + * {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)}. + */ + @Nullable private RequestLogRecord mRequestLogRecord = null; + + /** + * The Event URL parameters that the service passed to {@link + * EventUrlProvider#createEventTrackingUrlWithResponse(PersistableBundle, byte[], String)} + * or {@link EventUrlProvider#createEventTrackingUrlWithRedirect(PersistableBundle, Uri)}. + */ + @NonNull private PersistableBundle mParameters = PersistableBundle.EMPTY; + + /** @hide */ + public EventInput(@NonNull EventInputParcel parcel) { + this(parcel.getRequestLogRecord(), parcel.getParameters()); + } + + + + // Code below generated by codegen v1.0.23. + // + // DO NOT MODIFY! + // CHECKSTYLE:OFF Generated code + // + // To regenerate run: + // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/EventInput.java + // + // To exclude the generated code from IntelliJ auto-formatting enable (one-time): + // Settings > Editor > Code Style > Formatter Control + //@formatter:off + + + /** + * Creates a new EventInput. + * + * @param requestLogRecord + * The {@link RequestLogRecord} that was returned as a result of + * {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)}. + * @param parameters + * The Event URL parameters that the service passed to {@link + * EventUrlProvider#createEventTrackingUrlWithResponse(PersistableBundle, byte[], String)} + * or {@link EventUrlProvider#createEventTrackingUrlWithRedirect(PersistableBundle, Uri)}. + * @hide + */ + @DataClass.Generated.Member + public EventInput( + @Nullable RequestLogRecord requestLogRecord, + @NonNull PersistableBundle parameters) { + this.mRequestLogRecord = requestLogRecord; + this.mParameters = parameters; + AnnotationValidations.validate( + NonNull.class, null, mParameters); + + // onConstructed(); // You can define this method to get a callback + } + + /** + * The {@link RequestLogRecord} that was returned as a result of + * {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)}. + */ + @DataClass.Generated.Member + public @Nullable RequestLogRecord getRequestLogRecord() { + return mRequestLogRecord; + } + + /** + * The Event URL parameters that the service passed to {@link + * EventUrlProvider#createEventTrackingUrlWithResponse(PersistableBundle, byte[], String)} + * or {@link EventUrlProvider#createEventTrackingUrlWithRedirect(PersistableBundle, Uri)}. + */ + @DataClass.Generated.Member + public @NonNull PersistableBundle getParameters() { + return mParameters; + } + + @Override + @DataClass.Generated.Member + public boolean equals(@Nullable Object o) { + // You can override field equality logic by defining either of the methods like: + // boolean fieldNameEquals(EventInput other) { ... } + // boolean fieldNameEquals(FieldType otherValue) { ... } + + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + @SuppressWarnings("unchecked") + EventInput that = (EventInput) o; + //noinspection PointlessBooleanExpression + return true + && java.util.Objects.equals(mRequestLogRecord, that.mRequestLogRecord) + && java.util.Objects.equals(mParameters, that.mParameters); + } + + @Override + @DataClass.Generated.Member + public int hashCode() { + // You can override field hashCode logic by defining methods like: + // int fieldNameHashCode() { ... } + + int _hash = 1; + _hash = 31 * _hash + java.util.Objects.hashCode(mRequestLogRecord); + _hash = 31 * _hash + java.util.Objects.hashCode(mParameters); + return _hash; + } + + @Override + @DataClass.Generated.Member + public void writeToParcel(@NonNull android.os.Parcel dest, int flags) { + // You can override field parcelling by defining methods like: + // void parcelFieldName(Parcel dest, int flags) { ... } + + byte flg = 0; + if (mRequestLogRecord != null) flg |= 0x1; + dest.writeByte(flg); + if (mRequestLogRecord != null) dest.writeTypedObject(mRequestLogRecord, flags); + dest.writeTypedObject(mParameters, flags); + } + + @Override + @DataClass.Generated.Member + public int describeContents() { return 0; } + + /** @hide */ + @SuppressWarnings({"unchecked", "RedundantCast"}) + @DataClass.Generated.Member + /* package-private */ EventInput(@NonNull android.os.Parcel in) { + // You can override field unparcelling by defining methods like: + // static FieldType unparcelFieldName(Parcel in) { ... } + + byte flg = in.readByte(); + RequestLogRecord requestLogRecord = (flg & 0x1) == 0 ? null : (RequestLogRecord) in.readTypedObject(RequestLogRecord.CREATOR); + PersistableBundle parameters = (PersistableBundle) in.readTypedObject(PersistableBundle.CREATOR); + + this.mRequestLogRecord = requestLogRecord; + this.mParameters = parameters; + AnnotationValidations.validate( + NonNull.class, null, mParameters); + + // onConstructed(); // You can define this method to get a callback + } + + @DataClass.Generated.Member + public static final @NonNull Parcelable.Creator<EventInput> CREATOR + = new Parcelable.Creator<EventInput>() { + @Override + public EventInput[] newArray(int size) { + return new EventInput[size]; + } + + @Override + public EventInput createFromParcel(@NonNull android.os.Parcel in) { + return new EventInput(in); + } + }; + + @DataClass.Generated( + time = 1698875332901L, + codegenVersion = "1.0.23", + sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/EventInput.java", + inputSignatures = "private @android.annotation.Nullable android.adservices.ondevicepersonalization.RequestLogRecord mRequestLogRecord\nprivate @android.annotation.NonNull android.os.PersistableBundle mParameters\nclass EventInput extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=false, genHiddenConstructor=true, genEqualsHashCode=true)") + @Deprecated + private void __metadata() {} + + + //@formatter:on + // End of generated code + +} diff --git a/framework/java/android/adservices/ondevicepersonalization/WebViewEventInput.java b/framework/java/android/adservices/ondevicepersonalization/EventInputParcel.java index 0487c783..7434064d 100644 --- a/framework/java/android/adservices/ondevicepersonalization/WebViewEventInput.java +++ b/framework/java/android/adservices/ondevicepersonalization/EventInputParcel.java @@ -25,21 +25,21 @@ import com.android.ondevicepersonalization.internal.util.AnnotationValidations; import com.android.ondevicepersonalization.internal.util.DataClass; /** - * The input data for {@link IsolatedWorker#onWebViewEvent()}. - * + * Parcelable version of {@link EventInput}. * @hide */ -@DataClass(genBuilder = true, genEqualsHashCode = true) -public final class WebViewEventInput implements Parcelable { +@DataClass(genAidl = false, genHiddenBuilder = true) +public final class EventInputParcel implements Parcelable { /** * The {@link RequestLogRecord} that was returned as a result of - * {@link IsolatedWorker#onExecute()}. + * {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)}. */ @Nullable private RequestLogRecord mRequestLogRecord = null; /** - * The Event URL parameters that the service passed to - * {@link EventUrlProvider#getEventTrackingUrl()}. + * The Event URL parameters that the service passed to {@link + * EventUrlProvider#createEventTrackingUrlWithResponse(PersistableBundle, byte[], String)} + * or {@link EventUrlProvider#createEventTrackingUrlWithRedirect(PersistableBundle, Uri)}. */ @NonNull private PersistableBundle mParameters = PersistableBundle.EMPTY; @@ -51,7 +51,7 @@ public final class WebViewEventInput implements Parcelable { // CHECKSTYLE:OFF Generated code // // To regenerate run: - // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/WebViewEventInput.java + // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/EventInputParcel.java // // To exclude the generated code from IntelliJ auto-formatting enable (one-time): // Settings > Editor > Code Style > Formatter Control @@ -59,7 +59,7 @@ public final class WebViewEventInput implements Parcelable { @DataClass.Generated.Member - /* package-private */ WebViewEventInput( + /* package-private */ EventInputParcel( @Nullable RequestLogRecord requestLogRecord, @NonNull PersistableBundle parameters) { this.mRequestLogRecord = requestLogRecord; @@ -72,7 +72,7 @@ public final class WebViewEventInput implements Parcelable { /** * The {@link RequestLogRecord} that was returned as a result of - * {@link IsolatedWorker#onExecute()}. + * {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)}. */ @DataClass.Generated.Member public @Nullable RequestLogRecord getRequestLogRecord() { @@ -80,8 +80,9 @@ public final class WebViewEventInput implements Parcelable { } /** - * The Event URL parameters that the service passed to - * {@link EventUrlProvider#getEventTrackingUrl()}. + * The Event URL parameters that the service passed to {@link + * EventUrlProvider#createEventTrackingUrlWithResponse(PersistableBundle, byte[], String)} + * or {@link EventUrlProvider#createEventTrackingUrlWithRedirect(PersistableBundle, Uri)}. */ @DataClass.Generated.Member public @NonNull PersistableBundle getParameters() { @@ -90,35 +91,6 @@ public final class WebViewEventInput implements Parcelable { @Override @DataClass.Generated.Member - public boolean equals(@Nullable Object o) { - // You can override field equality logic by defining either of the methods like: - // boolean fieldNameEquals(WebViewEventInput other) { ... } - // boolean fieldNameEquals(FieldType otherValue) { ... } - - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - @SuppressWarnings("unchecked") - WebViewEventInput that = (WebViewEventInput) o; - //noinspection PointlessBooleanExpression - return true - && java.util.Objects.equals(mRequestLogRecord, that.mRequestLogRecord) - && java.util.Objects.equals(mParameters, that.mParameters); - } - - @Override - @DataClass.Generated.Member - public int hashCode() { - // You can override field hashCode logic by defining methods like: - // int fieldNameHashCode() { ... } - - int _hash = 1; - _hash = 31 * _hash + java.util.Objects.hashCode(mRequestLogRecord); - _hash = 31 * _hash + java.util.Objects.hashCode(mParameters); - return _hash; - } - - @Override - @DataClass.Generated.Member public void writeToParcel(@NonNull android.os.Parcel dest, int flags) { // You can override field parcelling by defining methods like: // void parcelFieldName(Parcel dest, int flags) { ... } @@ -137,7 +109,7 @@ public final class WebViewEventInput implements Parcelable { /** @hide */ @SuppressWarnings({"unchecked", "RedundantCast"}) @DataClass.Generated.Member - /* package-private */ WebViewEventInput(@NonNull android.os.Parcel in) { + /* package-private */ EventInputParcel(@NonNull android.os.Parcel in) { // You can override field unparcelling by defining methods like: // static FieldType unparcelFieldName(Parcel in) { ... } @@ -154,21 +126,22 @@ public final class WebViewEventInput implements Parcelable { } @DataClass.Generated.Member - public static final @NonNull Parcelable.Creator<WebViewEventInput> CREATOR - = new Parcelable.Creator<WebViewEventInput>() { + public static final @NonNull Parcelable.Creator<EventInputParcel> CREATOR + = new Parcelable.Creator<EventInputParcel>() { @Override - public WebViewEventInput[] newArray(int size) { - return new WebViewEventInput[size]; + public EventInputParcel[] newArray(int size) { + return new EventInputParcel[size]; } @Override - public WebViewEventInput createFromParcel(@NonNull android.os.Parcel in) { - return new WebViewEventInput(in); + public EventInputParcel createFromParcel(@NonNull android.os.Parcel in) { + return new EventInputParcel(in); } }; /** - * A builder for {@link WebViewEventInput} + * A builder for {@link EventInputParcel} + * @hide */ @SuppressWarnings("WeakerAccess") @DataClass.Generated.Member @@ -184,7 +157,7 @@ public final class WebViewEventInput implements Parcelable { /** * The {@link RequestLogRecord} that was returned as a result of - * {@link IsolatedWorker#onExecute()}. + * {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)}. */ @DataClass.Generated.Member public @NonNull Builder setRequestLogRecord(@NonNull RequestLogRecord value) { @@ -195,8 +168,9 @@ public final class WebViewEventInput implements Parcelable { } /** - * The Event URL parameters that the service passed to - * {@link EventUrlProvider#getEventTrackingUrl()}. + * The Event URL parameters that the service passed to {@link + * EventUrlProvider#createEventTrackingUrlWithResponse(PersistableBundle, byte[], String)} + * or {@link EventUrlProvider#createEventTrackingUrlWithRedirect(PersistableBundle, Uri)}. */ @DataClass.Generated.Member public @NonNull Builder setParameters(@NonNull PersistableBundle value) { @@ -207,7 +181,7 @@ public final class WebViewEventInput implements Parcelable { } /** Builds the instance. This builder should not be touched after calling this! */ - public @NonNull WebViewEventInput build() { + public @NonNull EventInputParcel build() { checkNotUsed(); mBuilderFieldsSet |= 0x4; // Mark builder used @@ -217,7 +191,7 @@ public final class WebViewEventInput implements Parcelable { if ((mBuilderFieldsSet & 0x2) == 0) { mParameters = PersistableBundle.EMPTY; } - WebViewEventInput o = new WebViewEventInput( + EventInputParcel o = new EventInputParcel( mRequestLogRecord, mParameters); return o; @@ -232,10 +206,10 @@ public final class WebViewEventInput implements Parcelable { } @DataClass.Generated( - time = 1692118434746L, + time = 1698875208124L, codegenVersion = "1.0.23", - sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/WebViewEventInput.java", - inputSignatures = "private @android.annotation.Nullable android.adservices.ondevicepersonalization.RequestLogRecord mRequestLogRecord\nprivate @android.annotation.NonNull android.os.PersistableBundle mParameters\nclass WebViewEventInput extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") + sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/EventInputParcel.java", + inputSignatures = "private @android.annotation.Nullable android.adservices.ondevicepersonalization.RequestLogRecord mRequestLogRecord\nprivate @android.annotation.NonNull android.os.PersistableBundle mParameters\nclass EventInputParcel extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genAidl=false, genHiddenBuilder=true)") @Deprecated private void __metadata() {} diff --git a/framework/java/android/adservices/ondevicepersonalization/EventLogRecord.java b/framework/java/android/adservices/ondevicepersonalization/EventLogRecord.java index 1944140c..a65b2aa4 100644 --- a/framework/java/android/adservices/ondevicepersonalization/EventLogRecord.java +++ b/framework/java/android/adservices/ondevicepersonalization/EventLogRecord.java @@ -16,35 +16,82 @@ package android.adservices.ondevicepersonalization; +import static android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS; + +import android.annotation.FlaggedApi; +import android.annotation.IntRange; +import android.annotation.NonNull; import android.annotation.Nullable; import android.content.ContentValues; import android.os.Parcelable; +import com.android.ondevicepersonalization.internal.util.AnnotationValidations; import com.android.ondevicepersonalization.internal.util.DataClass; +// TODO(b/289102463): Add a link to the public doc for the EVENTS table when available. /** - * Data to be logged in the EVENTS table that is associated with a pre-existing - * {@link RequestLogRecord} that has been written to the REQUESTS table. + * Data to be logged in the EVENTS table. * - * @hide + * Each record in the EVENTS table is associated with one row from an existing + * {@link RequestLogRecord} in the requests table {@link RequestLogRecord#getRows()}. + * The purpose of the EVENTS table is to add supplemental information to logged data + * from a prior request, e.g., logging an event when a link in a rendered WebView is + * clicked {@code IsolatedWorker#onEvent(EventInput, java.util.function.Consumer)}. + * The contents of the EVENTS table can be + * consumed by Federated Learning facilitated model training, or Federated Analytics facilitated + * cross-device statistical analysis. */ +@FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) @DataClass(genBuilder = true, genEqualsHashCode = true) public final class EventLogRecord implements Parcelable { /** * The index of the row in an existing {@link RequestLogRecord} that this payload should be * associated with. **/ - private int mRowIndex = 0; + private @IntRange(from = 0) int mRowIndex = 0; + + /** + * The service-assigned identifier that identifies this payload. Each row in + * {@link RequestLogRecord} can be associated with up to one event of a specified type. + * The platform drops events if another event with the same type already exists for a row + * in {@link RequestLogRecord}. Must be >0 and <128. This allows up to 127 events to be + * written for each row in {@link RequestLogRecord}. If unspecified, the default is 1. + */ + private @IntRange(from = 1, to = 127) int mType = 1; /** - * The service-assigned type that identifies this payload. Unique for each row. Duplicates are - * discarded. Must be >0 and <128. + * Time of the event in milliseconds. + * @hide */ - private int mType = 0; + private long mTimeMillis = 0; - /** Additional data to be logged. */ + /** + * Additional data to be logged. Can be null if no additional data needs to be written as part + * of the event, and only the occurrence of the event needs to be logged. + */ @Nullable ContentValues mData = null; + /** + * The existing {@link RequestLogRecord} that this payload should be associated with. In an + * implementation of + * {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)}, this should be + * set to a value returned by {@link LogReader#getRequests(long, long)}. In an implementation + * of {@link IsolatedWorker#onEvent(EventInput, java.util.function.Consumer)}, this should be + * set to {@code null} because the payload will be automatically associated with the current + * {@link RequestLogRecord}. + * + * @hide + */ + @Nullable RequestLogRecord mRequestLogRecord = null; + + + abstract static class BaseBuilder { + /** + * @hide + */ + public abstract Builder setTimeMillis(long value); + } + // Code below generated by codegen v1.0.23. @@ -62,12 +109,23 @@ public final class EventLogRecord implements Parcelable { @DataClass.Generated.Member /* package-private */ EventLogRecord( - int rowIndex, - int type, - @Nullable ContentValues data) { + @IntRange(from = 0) int rowIndex, + @IntRange(from = 1, to = 127) int type, + long timeMillis, + @Nullable ContentValues data, + @Nullable RequestLogRecord requestLogRecord) { this.mRowIndex = rowIndex; + AnnotationValidations.validate( + IntRange.class, null, mRowIndex, + "from", 0); this.mType = type; + AnnotationValidations.validate( + IntRange.class, null, mType, + "from", 1, + "to", 127); + this.mTimeMillis = timeMillis; this.mData = data; + this.mRequestLogRecord = requestLogRecord; // onConstructed(); // You can define this method to get a callback } @@ -77,27 +135,57 @@ public final class EventLogRecord implements Parcelable { * associated with. */ @DataClass.Generated.Member - public int getRowIndex() { + public @IntRange(from = 0) int getRowIndex() { return mRowIndex; } /** - * The service-assigned type that identifies this payload. Unique for each row. Duplicates are - * discarded. Must be >0 and <128. + * The service-assigned identifier that identifies this payload. Each row in + * {@link RequestLogRecord} can be associated with up to one event of a specified type. + * The platform drops events if another event with the same type already exists for a row + * in {@link RequestLogRecord}. Must be >0 and <128. This allows up to 127 events to be + * written for each row in {@link RequestLogRecord}. If unspecified, the default is 1. */ @DataClass.Generated.Member - public int getType() { + public @IntRange(from = 1, to = 127) int getType() { return mType; } /** - * Additional data to be logged. + * Time of the event in milliseconds. + * + * @hide + */ + @DataClass.Generated.Member + public long getTimeMillis() { + return mTimeMillis; + } + + /** + * Additional data to be logged. Can be null if no additional data needs to be written as part + * of the event, and only the occurrence of the event needs to be logged. */ @DataClass.Generated.Member public @Nullable ContentValues getData() { return mData; } + /** + * The existing {@link RequestLogRecord} that this payload should be associated with. In an + * implementation of + * {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)}, this should be + * set to a value returned by {@link LogReader#getRequests(long, long)}. In an implementation + * of {@link IsolatedWorker#onEvent(EventInput, java.util.function.Consumer)}, this should be + * set to {@code null} because the payload will be automatically associated with the current + * {@link RequestLogRecord}. + * + * @hide + */ + @DataClass.Generated.Member + public @Nullable RequestLogRecord getRequestLogRecord() { + return mRequestLogRecord; + } + @Override @DataClass.Generated.Member public boolean equals(@Nullable Object o) { @@ -113,7 +201,9 @@ public final class EventLogRecord implements Parcelable { return true && mRowIndex == that.mRowIndex && mType == that.mType - && java.util.Objects.equals(mData, that.mData); + && mTimeMillis == that.mTimeMillis + && java.util.Objects.equals(mData, that.mData) + && java.util.Objects.equals(mRequestLogRecord, that.mRequestLogRecord); } @Override @@ -125,22 +215,27 @@ public final class EventLogRecord implements Parcelable { int _hash = 1; _hash = 31 * _hash + mRowIndex; _hash = 31 * _hash + mType; + _hash = 31 * _hash + Long.hashCode(mTimeMillis); _hash = 31 * _hash + java.util.Objects.hashCode(mData); + _hash = 31 * _hash + java.util.Objects.hashCode(mRequestLogRecord); return _hash; } @Override @DataClass.Generated.Member - public void writeToParcel(@android.annotation.NonNull android.os.Parcel dest, int flags) { + public void writeToParcel(@NonNull android.os.Parcel dest, int flags) { // You can override field parcelling by defining methods like: // void parcelFieldName(Parcel dest, int flags) { ... } byte flg = 0; - if (mData != null) flg |= 0x4; + if (mData != null) flg |= 0x8; + if (mRequestLogRecord != null) flg |= 0x10; dest.writeByte(flg); dest.writeInt(mRowIndex); dest.writeInt(mType); + dest.writeLong(mTimeMillis); if (mData != null) dest.writeTypedObject(mData, flags); + if (mRequestLogRecord != null) dest.writeTypedObject(mRequestLogRecord, flags); } @Override @@ -150,24 +245,35 @@ public final class EventLogRecord implements Parcelable { /** @hide */ @SuppressWarnings({"unchecked", "RedundantCast"}) @DataClass.Generated.Member - /* package-private */ EventLogRecord(@android.annotation.NonNull android.os.Parcel in) { + /* package-private */ EventLogRecord(@NonNull android.os.Parcel in) { // You can override field unparcelling by defining methods like: // static FieldType unparcelFieldName(Parcel in) { ... } byte flg = in.readByte(); int rowIndex = in.readInt(); int type = in.readInt(); - ContentValues data = (flg & 0x4) == 0 ? null : (ContentValues) in.readTypedObject(ContentValues.CREATOR); + long timeMillis = in.readLong(); + ContentValues data = (flg & 0x8) == 0 ? null : (ContentValues) in.readTypedObject(ContentValues.CREATOR); + RequestLogRecord requestLogRecord = (flg & 0x10) == 0 ? null : (RequestLogRecord) in.readTypedObject(RequestLogRecord.CREATOR); this.mRowIndex = rowIndex; + AnnotationValidations.validate( + IntRange.class, null, mRowIndex, + "from", 0); this.mType = type; + AnnotationValidations.validate( + IntRange.class, null, mType, + "from", 1, + "to", 127); + this.mTimeMillis = timeMillis; this.mData = data; + this.mRequestLogRecord = requestLogRecord; // onConstructed(); // You can define this method to get a callback } @DataClass.Generated.Member - public static final @android.annotation.NonNull Parcelable.Creator<EventLogRecord> CREATOR + public static final @NonNull Parcelable.Creator<EventLogRecord> CREATOR = new Parcelable.Creator<EventLogRecord>() { @Override public EventLogRecord[] newArray(int size) { @@ -175,7 +281,7 @@ public final class EventLogRecord implements Parcelable { } @Override - public EventLogRecord createFromParcel(@android.annotation.NonNull android.os.Parcel in) { + public EventLogRecord createFromParcel(@NonNull android.os.Parcel in) { return new EventLogRecord(in); } }; @@ -183,13 +289,16 @@ public final class EventLogRecord implements Parcelable { /** * A builder for {@link EventLogRecord} */ + @FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) @SuppressWarnings("WeakerAccess") @DataClass.Generated.Member - public static final class Builder { + public static final class Builder extends BaseBuilder { - private int mRowIndex; - private int mType; + private @IntRange(from = 0) int mRowIndex; + private @IntRange(from = 1, to = 127) int mType; + private long mTimeMillis; private @Nullable ContentValues mData; + private @Nullable RequestLogRecord mRequestLogRecord; private long mBuilderFieldsSet = 0L; @@ -201,7 +310,7 @@ public final class EventLogRecord implements Parcelable { * associated with. */ @DataClass.Generated.Member - public @android.annotation.NonNull Builder setRowIndex(int value) { + public @NonNull Builder setRowIndex(@IntRange(from = 0) int value) { checkNotUsed(); mBuilderFieldsSet |= 0x1; mRowIndex = value; @@ -209,11 +318,14 @@ public final class EventLogRecord implements Parcelable { } /** - * The service-assigned type that identifies this payload. Unique for each row. Duplicates are - * discarded. Must be >0 and <128. + * The service-assigned identifier that identifies this payload. Each row in + * {@link RequestLogRecord} can be associated with up to one event of a specified type. + * The platform drops events if another event with the same type already exists for a row + * in {@link RequestLogRecord}. Must be >0 and <128. This allows up to 127 events to be + * written for each row in {@link RequestLogRecord}. If unspecified, the default is 1. */ @DataClass.Generated.Member - public @android.annotation.NonNull Builder setType(int value) { + public @NonNull Builder setType(@IntRange(from = 1, to = 127) int value) { checkNotUsed(); mBuilderFieldsSet |= 0x2; mType = value; @@ -221,39 +333,81 @@ public final class EventLogRecord implements Parcelable { } /** - * Additional data to be logged. + * Time of the event in milliseconds. + * + * @hide */ @DataClass.Generated.Member - public @android.annotation.NonNull Builder setData(@android.annotation.NonNull ContentValues value) { + @Override + public @NonNull Builder setTimeMillis(long value) { checkNotUsed(); mBuilderFieldsSet |= 0x4; + mTimeMillis = value; + return this; + } + + /** + * Additional data to be logged. Can be null if no additional data needs to be written as part + * of the event, and only the occurrence of the event needs to be logged. + */ + @DataClass.Generated.Member + public @NonNull Builder setData(@NonNull ContentValues value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x8; mData = value; return this; } + /** + * The existing {@link RequestLogRecord} that this payload should be associated with. In an + * implementation of + * {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)}, this should be + * set to a value returned by {@link LogReader#getRequests(long, long)}. In an implementation + * of {@link IsolatedWorker#onEvent(EventInput, java.util.function.Consumer)}, this should be + * set to {@code null} because the payload will be automatically associated with the current + * {@link RequestLogRecord}. + * + * @hide + */ + @DataClass.Generated.Member + public @NonNull Builder setRequestLogRecord(@NonNull RequestLogRecord value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x10; + mRequestLogRecord = value; + return this; + } + /** Builds the instance. This builder should not be touched after calling this! */ - public @android.annotation.NonNull EventLogRecord build() { + public @NonNull EventLogRecord build() { checkNotUsed(); - mBuilderFieldsSet |= 0x8; // Mark builder used + mBuilderFieldsSet |= 0x20; // Mark builder used if ((mBuilderFieldsSet & 0x1) == 0) { mRowIndex = 0; } if ((mBuilderFieldsSet & 0x2) == 0) { - mType = 0; + mType = 1; } if ((mBuilderFieldsSet & 0x4) == 0) { + mTimeMillis = 0; + } + if ((mBuilderFieldsSet & 0x8) == 0) { mData = null; } + if ((mBuilderFieldsSet & 0x10) == 0) { + mRequestLogRecord = null; + } EventLogRecord o = new EventLogRecord( mRowIndex, mType, - mData); + mTimeMillis, + mData, + mRequestLogRecord); return o; } private void checkNotUsed() { - if ((mBuilderFieldsSet & 0x8) != 0) { + if ((mBuilderFieldsSet & 0x20) != 0) { throw new IllegalStateException( "This Builder should not be reused. Use a new Builder instance instead"); } @@ -261,10 +415,10 @@ public final class EventLogRecord implements Parcelable { } @DataClass.Generated( - time = 1692118350572L, + time = 1697576750150L, codegenVersion = "1.0.23", sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/EventLogRecord.java", - inputSignatures = "private int mRowIndex\nprivate int mType\n @android.annotation.Nullable android.content.ContentValues mData\nclass EventLogRecord extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") + inputSignatures = "private @android.annotation.IntRange int mRowIndex\nprivate @android.annotation.IntRange int mType\nprivate long mTimeMillis\n @android.annotation.Nullable android.content.ContentValues mData\n @android.annotation.Nullable android.adservices.ondevicepersonalization.RequestLogRecord mRequestLogRecord\nclass EventLogRecord extends java.lang.Object implements [android.os.Parcelable]\npublic abstract android.adservices.ondevicepersonalization.EventLogRecord.Builder setTimeMillis(long)\nclass BaseBuilder extends java.lang.Object implements []\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)\npublic abstract android.adservices.ondevicepersonalization.EventLogRecord.Builder setTimeMillis(long)\nclass BaseBuilder extends java.lang.Object implements []") @Deprecated private void __metadata() {} diff --git a/framework/java/android/adservices/ondevicepersonalization/WebViewEventOutput.aidl b/framework/java/android/adservices/ondevicepersonalization/EventOutput.aidl index 7cbe3696..e7e0f952 100644 --- a/framework/java/android/adservices/ondevicepersonalization/WebViewEventOutput.aidl +++ b/framework/java/android/adservices/ondevicepersonalization/EventOutput.aidl @@ -16,4 +16,4 @@ package android.adservices.ondevicepersonalization; -parcelable WebViewEventOutput; +parcelable EventOutput; diff --git a/framework/java/android/adservices/ondevicepersonalization/WebViewEventOutput.java b/framework/java/android/adservices/ondevicepersonalization/EventOutput.java index b69b36fc..82b65750 100644 --- a/framework/java/android/adservices/ondevicepersonalization/WebViewEventOutput.java +++ b/framework/java/android/adservices/ondevicepersonalization/EventOutput.java @@ -22,12 +22,11 @@ import android.os.Parcelable; import com.android.ondevicepersonalization.internal.util.DataClass; /** - * The result returned by {@link IsolatedWorker#onWebViewEvent()} - * + * The result returned by {@link IsolatedWorker#onEvent(EventInput, java.util.function.Consumer)}. * @hide */ @DataClass(genBuilder = true, genEqualsHashCode = true) -public final class WebViewEventOutput implements Parcelable { +public final class EventOutput implements Parcelable { /** * An {@link EventLogRecord} to be written to the EVENTS table, if not null. Each * {@link EventLogRecord} is associated with a row in an existing {@link RequestLogRecord} that @@ -43,7 +42,7 @@ public final class WebViewEventOutput implements Parcelable { // CHECKSTYLE:OFF Generated code // // To regenerate run: - // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/WebViewEventOutput.java + // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/EventOutput.java // // To exclude the generated code from IntelliJ auto-formatting enable (one-time): // Settings > Editor > Code Style > Formatter Control @@ -51,7 +50,7 @@ public final class WebViewEventOutput implements Parcelable { @DataClass.Generated.Member - /* package-private */ WebViewEventOutput( + /* package-private */ EventOutput( @Nullable EventLogRecord eventLogRecord) { this.mEventLogRecord = eventLogRecord; @@ -72,13 +71,13 @@ public final class WebViewEventOutput implements Parcelable { @DataClass.Generated.Member public boolean equals(@Nullable Object o) { // You can override field equality logic by defining either of the methods like: - // boolean fieldNameEquals(WebViewEventOutput other) { ... } + // boolean fieldNameEquals(EventOutput other) { ... } // boolean fieldNameEquals(FieldType otherValue) { ... } if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; @SuppressWarnings("unchecked") - WebViewEventOutput that = (WebViewEventOutput) o; + EventOutput that = (EventOutput) o; //noinspection PointlessBooleanExpression return true && java.util.Objects.equals(mEventLogRecord, that.mEventLogRecord); @@ -114,7 +113,7 @@ public final class WebViewEventOutput implements Parcelable { /** @hide */ @SuppressWarnings({"unchecked", "RedundantCast"}) @DataClass.Generated.Member - /* package-private */ WebViewEventOutput(@android.annotation.NonNull android.os.Parcel in) { + /* package-private */ EventOutput(@android.annotation.NonNull android.os.Parcel in) { // You can override field unparcelling by defining methods like: // static FieldType unparcelFieldName(Parcel in) { ... } @@ -127,21 +126,21 @@ public final class WebViewEventOutput implements Parcelable { } @DataClass.Generated.Member - public static final @android.annotation.NonNull Parcelable.Creator<WebViewEventOutput> CREATOR - = new Parcelable.Creator<WebViewEventOutput>() { + public static final @android.annotation.NonNull Parcelable.Creator<EventOutput> CREATOR + = new Parcelable.Creator<EventOutput>() { @Override - public WebViewEventOutput[] newArray(int size) { - return new WebViewEventOutput[size]; + public EventOutput[] newArray(int size) { + return new EventOutput[size]; } @Override - public WebViewEventOutput createFromParcel(@android.annotation.NonNull android.os.Parcel in) { - return new WebViewEventOutput(in); + public EventOutput createFromParcel(@android.annotation.NonNull android.os.Parcel in) { + return new EventOutput(in); } }; /** - * A builder for {@link WebViewEventOutput} + * A builder for {@link EventOutput} */ @SuppressWarnings("WeakerAccess") @DataClass.Generated.Member @@ -168,14 +167,14 @@ public final class WebViewEventOutput implements Parcelable { } /** Builds the instance. This builder should not be touched after calling this! */ - public @android.annotation.NonNull WebViewEventOutput build() { + public @android.annotation.NonNull EventOutput build() { checkNotUsed(); mBuilderFieldsSet |= 0x2; // Mark builder used if ((mBuilderFieldsSet & 0x1) == 0) { mEventLogRecord = null; } - WebViewEventOutput o = new WebViewEventOutput( + EventOutput o = new EventOutput( mEventLogRecord); return o; } @@ -189,10 +188,10 @@ public final class WebViewEventOutput implements Parcelable { } @DataClass.Generated( - time = 1692118441223L, + time = 1696369232183L, codegenVersion = "1.0.23", - sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/WebViewEventOutput.java", - inputSignatures = " @android.annotation.Nullable android.adservices.ondevicepersonalization.EventLogRecord mEventLogRecord\nclass WebViewEventOutput extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") + sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/EventOutput.java", + inputSignatures = " @android.annotation.Nullable android.adservices.ondevicepersonalization.EventLogRecord mEventLogRecord\nclass EventOutput extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") @Deprecated private void __metadata() {} diff --git a/framework/java/android/adservices/ondevicepersonalization/EventOutputParcel.java b/framework/java/android/adservices/ondevicepersonalization/EventOutputParcel.java new file mode 100644 index 00000000..5432a6c4 --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/EventOutputParcel.java @@ -0,0 +1,141 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +import android.annotation.NonNull; +import android.annotation.Nullable; +import android.os.Parcelable; + +import com.android.ondevicepersonalization.internal.util.DataClass; + +/** + * Parcelable version of {@link EventOutput}. + * @hide + */ +@DataClass(genAidl = false, genBuilder = false) +public final class EventOutputParcel implements Parcelable { + /** + * An {@link EventLogRecord} to be written to the EVENTS table, if not null. Each + * {@link EventLogRecord} is associated with a row in an existing {@link RequestLogRecord} that + * has been written to the REQUESTS table. + */ + @Nullable EventLogRecord mEventLogRecord = null; + + /** @hide */ + public EventOutputParcel(@NonNull EventOutput value) { + this(value.getEventLogRecord()); + } + + + + // Code below generated by codegen v1.0.23. + // + // DO NOT MODIFY! + // CHECKSTYLE:OFF Generated code + // + // To regenerate run: + // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/EventOutputParcel.java + // + // To exclude the generated code from IntelliJ auto-formatting enable (one-time): + // Settings > Editor > Code Style > Formatter Control + //@formatter:off + + + /** + * Creates a new EventOutputParcel. + * + * @param eventLogRecord + * An {@link EventLogRecord} to be written to the EVENTS table, if not null. Each + * {@link EventLogRecord} is associated with a row in an existing {@link RequestLogRecord} that + * has been written to the REQUESTS table. + */ + @DataClass.Generated.Member + public EventOutputParcel( + @Nullable EventLogRecord eventLogRecord) { + this.mEventLogRecord = eventLogRecord; + + // onConstructed(); // You can define this method to get a callback + } + + /** + * An {@link EventLogRecord} to be written to the EVENTS table, if not null. Each + * {@link EventLogRecord} is associated with a row in an existing {@link RequestLogRecord} that + * has been written to the REQUESTS table. + */ + @DataClass.Generated.Member + public @Nullable EventLogRecord getEventLogRecord() { + return mEventLogRecord; + } + + @Override + @DataClass.Generated.Member + public void writeToParcel(@android.annotation.NonNull android.os.Parcel dest, int flags) { + // You can override field parcelling by defining methods like: + // void parcelFieldName(Parcel dest, int flags) { ... } + + byte flg = 0; + if (mEventLogRecord != null) flg |= 0x1; + dest.writeByte(flg); + if (mEventLogRecord != null) dest.writeTypedObject(mEventLogRecord, flags); + } + + @Override + @DataClass.Generated.Member + public int describeContents() { return 0; } + + /** @hide */ + @SuppressWarnings({"unchecked", "RedundantCast"}) + @DataClass.Generated.Member + /* package-private */ EventOutputParcel(@android.annotation.NonNull android.os.Parcel in) { + // You can override field unparcelling by defining methods like: + // static FieldType unparcelFieldName(Parcel in) { ... } + + byte flg = in.readByte(); + EventLogRecord eventLogRecord = (flg & 0x1) == 0 ? null : (EventLogRecord) in.readTypedObject(EventLogRecord.CREATOR); + + this.mEventLogRecord = eventLogRecord; + + // onConstructed(); // You can define this method to get a callback + } + + @DataClass.Generated.Member + public static final @android.annotation.NonNull Parcelable.Creator<EventOutputParcel> CREATOR + = new Parcelable.Creator<EventOutputParcel>() { + @Override + public EventOutputParcel[] newArray(int size) { + return new EventOutputParcel[size]; + } + + @Override + public EventOutputParcel createFromParcel(@android.annotation.NonNull android.os.Parcel in) { + return new EventOutputParcel(in); + } + }; + + @DataClass.Generated( + time = 1698864082503L, + codegenVersion = "1.0.23", + sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/EventOutputParcel.java", + inputSignatures = " @android.annotation.Nullable android.adservices.ondevicepersonalization.EventLogRecord mEventLogRecord\nclass EventOutputParcel extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genAidl=false, genBuilder=false)") + @Deprecated + private void __metadata() {} + + + //@formatter:on + // End of generated code + +} diff --git a/framework/java/android/adservices/ondevicepersonalization/EventUrlProvider.java b/framework/java/android/adservices/ondevicepersonalization/EventUrlProvider.java index f55bfcd5..930c434f 100644 --- a/framework/java/android/adservices/ondevicepersonalization/EventUrlProvider.java +++ b/framework/java/android/adservices/ondevicepersonalization/EventUrlProvider.java @@ -16,10 +16,14 @@ package android.adservices.ondevicepersonalization; +import static android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS; + import android.adservices.ondevicepersonalization.aidl.IDataAccessService; import android.adservices.ondevicepersonalization.aidl.IDataAccessServiceCallback; +import android.annotation.FlaggedApi; import android.annotation.NonNull; import android.annotation.Nullable; +import android.annotation.WorkerThread; import android.net.Uri; import android.os.Bundle; import android.os.PersistableBundle; @@ -28,16 +32,16 @@ import android.os.RemoteException; import java.util.Objects; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; -import java.util.concurrent.TimeUnit; /** - * Generates event tracking URLs for a request. The {@link IsolatedService} can - * embed these URLs in the HTML output. When the HTML is rendered, ODP will intercept requests - * to these URLs, call {@link IsolatedWorker#onEvent}, and log the returned output - * in the EVENTS table. + * Generates event tracking URLs for a request. The service can embed these URLs within the + * HTML output as needed. When the HTML is rendered within an ODP WebView, ODP will intercept + * requests to these URLs, call + * {@code IsolatedWorker#onEvent(EventInput, java.util.function.Consumer)}, and log the returned + * output in the EVENTS table. * - * @hide */ +@FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public class EventUrlProvider { private static final long ASYNC_TIMEOUT_MS = 1000; @@ -53,16 +57,18 @@ public class EventUrlProvider { * 200 (OK) if the response data is not empty. Returns HTTP Status 204 (No Content) if the * response data is empty. * - * @param eventParams The data to be passed to {@link IsolatedWorker#onEvent} + * @param eventParams The data to be passed to + * {@code IsolatedWorker#onEvent(EventInput, java.util.function.Consumer)} * when the event occurs. * @param responseData The content to be returned to the WebView when the URL is fetched. * @param mimeType The Mime Type of the URL response. * @return An ODP event URL that can be inserted into a WebView. */ - @NonNull public Uri getEventTrackingUrl( + @WorkerThread + @NonNull public Uri createEventTrackingUrlWithResponse( @NonNull PersistableBundle eventParams, @Nullable byte[] responseData, - @Nullable String mimeType) throws OnDevicePersonalizationException { + @Nullable String mimeType) { Bundle params = new Bundle(); params.putParcelable(Constants.EXTRA_EVENT_PARAMS, eventParams); params.putByteArray(Constants.EXTRA_RESPONSE_DATA, responseData); @@ -74,22 +80,23 @@ public class EventUrlProvider { * Creates an event tracking URL that redirects to the provided destination URL when it is * clicked in an ODP webview. * - * @param eventParams The data to be passed to {@link IsolatedWorker#onEvent} + * @param eventParams The data to be passed to + * {@code IsolatedWorker#onEvent(EventInput, java.util.function.Consumer)} * when the event occurs * @param destinationUrl The URL to redirect to. * @return An ODP event URL that can be inserted into a WebView. */ - @NonNull public Uri getEventTrackingUrlWithRedirect( + @WorkerThread + @NonNull public Uri createEventTrackingUrlWithRedirect( @NonNull PersistableBundle eventParams, - @Nullable String destinationUrl) throws OnDevicePersonalizationException { + @Nullable Uri destinationUrl) { Bundle params = new Bundle(); params.putParcelable(Constants.EXTRA_EVENT_PARAMS, eventParams); - params.putString(Constants.EXTRA_DESTINATION_URL, destinationUrl); + params.putString(Constants.EXTRA_DESTINATION_URL, destinationUrl.toString()); return getUrl(params); } - @NonNull private Uri getUrl(@NonNull Bundle params) - throws OnDevicePersonalizationException { + @NonNull private Uri getUrl(@NonNull Bundle params) { try { BlockingQueue<CallbackResult> asyncResult = new ArrayBlockingQueue<>(1); @@ -106,19 +113,17 @@ public class EventUrlProvider { asyncResult.add(new CallbackResult(null, errorCode)); } }); - CallbackResult callbackResult = - asyncResult.poll(ASYNC_TIMEOUT_MS, TimeUnit.MILLISECONDS); + CallbackResult callbackResult = asyncResult.take(); Objects.requireNonNull(callbackResult); if (callbackResult.mErrorCode != 0) { - throw new OnDevicePersonalizationException(callbackResult.mErrorCode); + throw new IllegalStateException("Error: " + callbackResult.mErrorCode); } Bundle result = Objects.requireNonNull(callbackResult.mResult); Uri url = Objects.requireNonNull( result.getParcelable(Constants.EXTRA_RESULT, Uri.class)); return url; } catch (InterruptedException | RemoteException e) { - throw new OnDevicePersonalizationException( - Constants.STATUS_INTERNAL_ERROR, (Throwable) e); + throw new RuntimeException(e); } } diff --git a/framework/java/android/adservices/ondevicepersonalization/ExecuteInput.java b/framework/java/android/adservices/ondevicepersonalization/ExecuteInput.java index 5123315f..7ded2213 100644 --- a/framework/java/android/adservices/ondevicepersonalization/ExecuteInput.java +++ b/framework/java/android/adservices/ondevicepersonalization/ExecuteInput.java @@ -16,6 +16,9 @@ package android.adservices.ondevicepersonalization; +import static android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS; + +import android.annotation.FlaggedApi; import android.annotation.NonNull; import android.os.Parcelable; import android.os.PersistableBundle; @@ -24,11 +27,11 @@ import com.android.ondevicepersonalization.internal.util.AnnotationValidations; import com.android.ondevicepersonalization.internal.util.DataClass; /** - * The input data for {@link IsolatedWorker#onExecute()}. + * The input data for {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)}. * - * @hide */ -@DataClass(genBuilder = true, genEqualsHashCode = true) +@FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) +@DataClass(genBuilder = false, genHiddenConstructor = true, genEqualsHashCode = true) public final class ExecuteInput implements Parcelable { /** * The package name of the calling app. @@ -41,6 +44,11 @@ public final class ExecuteInput implements Parcelable { */ @NonNull PersistableBundle mAppParams = PersistableBundle.EMPTY; + /** @hide */ + public ExecuteInput(@NonNull ExecuteInputParcel parcel) { + this(parcel.getAppPackageName(), parcel.getAppParams()); + } + // Code below generated by codegen v1.0.23. @@ -56,8 +64,18 @@ public final class ExecuteInput implements Parcelable { //@formatter:off + /** + * Creates a new ExecuteInput. + * + * @param appPackageName + * The package name of the calling app. + * @param appParams + * The parameters provided by the app to the {@link IsolatedService}. The service + * defines the expected keys in this {@link PersistableBundle}. + * @hide + */ @DataClass.Generated.Member - /* package-private */ ExecuteInput( + public ExecuteInput( @NonNull String appPackageName, @NonNull PersistableBundle appParams) { this.mAppPackageName = appPackageName; @@ -164,74 +182,11 @@ public final class ExecuteInput implements Parcelable { } }; - /** - * A builder for {@link ExecuteInput} - */ - @SuppressWarnings("WeakerAccess") - @DataClass.Generated.Member - public static final class Builder { - - private @NonNull String mAppPackageName; - private @NonNull PersistableBundle mAppParams; - - private long mBuilderFieldsSet = 0L; - - public Builder() { - } - - /** - * The package name of the calling app. - */ - @DataClass.Generated.Member - public @NonNull Builder setAppPackageName(@NonNull String value) { - checkNotUsed(); - mBuilderFieldsSet |= 0x1; - mAppPackageName = value; - return this; - } - - /** - * The parameters provided by the app to the {@link IsolatedService}. The service - * defines the expected keys in this {@link PersistableBundle}. - */ - @DataClass.Generated.Member - public @NonNull Builder setAppParams(@NonNull PersistableBundle value) { - checkNotUsed(); - mBuilderFieldsSet |= 0x2; - mAppParams = value; - return this; - } - - /** Builds the instance. This builder should not be touched after calling this! */ - public @NonNull ExecuteInput build() { - checkNotUsed(); - mBuilderFieldsSet |= 0x4; // Mark builder used - - if ((mBuilderFieldsSet & 0x1) == 0) { - mAppPackageName = ""; - } - if ((mBuilderFieldsSet & 0x2) == 0) { - mAppParams = PersistableBundle.EMPTY; - } - ExecuteInput o = new ExecuteInput( - mAppPackageName, - mAppParams); - return o; - } - - private void checkNotUsed() { - if ((mBuilderFieldsSet & 0x4) != 0) { - throw new IllegalStateException( - "This Builder should not be reused. Use a new Builder instance instead"); - } - } - } - @DataClass.Generated( - time = 1692118363539L, + time = 1698872215353L, codegenVersion = "1.0.23", sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/ExecuteInput.java", - inputSignatures = " @android.annotation.NonNull java.lang.String mAppPackageName\n @android.annotation.NonNull android.os.PersistableBundle mAppParams\nclass ExecuteInput extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") + inputSignatures = " @android.annotation.NonNull java.lang.String mAppPackageName\n @android.annotation.NonNull android.os.PersistableBundle mAppParams\nclass ExecuteInput extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=false, genHiddenConstructor=true, genEqualsHashCode=true)") @Deprecated private void __metadata() {} diff --git a/framework/java/android/adservices/ondevicepersonalization/ExecuteInputParcel.java b/framework/java/android/adservices/ondevicepersonalization/ExecuteInputParcel.java new file mode 100644 index 00000000..44c60f30 --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/ExecuteInputParcel.java @@ -0,0 +1,213 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +import android.annotation.NonNull; +import android.os.Parcelable; +import android.os.PersistableBundle; + +import com.android.ondevicepersonalization.internal.util.AnnotationValidations; +import com.android.ondevicepersonalization.internal.util.DataClass; + +/** + * Parcelable version of {@link ExecuteInput}. + * @hide + */ +@DataClass(genAidl = false, genHiddenBuilder = true) +public final class ExecuteInputParcel implements Parcelable { + /** + * The package name of the calling app. + */ + @NonNull String mAppPackageName = ""; + + /** + * The parameters provided by the app to the {@link IsolatedService}. The service + * defines the expected keys in this {@link PersistableBundle}. + */ + @NonNull PersistableBundle mAppParams = PersistableBundle.EMPTY; + + + + // Code below generated by codegen v1.0.23. + // + // DO NOT MODIFY! + // CHECKSTYLE:OFF Generated code + // + // To regenerate run: + // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/ExecuteInputParcel.java + // + // To exclude the generated code from IntelliJ auto-formatting enable (one-time): + // Settings > Editor > Code Style > Formatter Control + //@formatter:off + + + @DataClass.Generated.Member + /* package-private */ ExecuteInputParcel( + @NonNull String appPackageName, + @NonNull PersistableBundle appParams) { + this.mAppPackageName = appPackageName; + AnnotationValidations.validate( + NonNull.class, null, mAppPackageName); + this.mAppParams = appParams; + AnnotationValidations.validate( + NonNull.class, null, mAppParams); + + // onConstructed(); // You can define this method to get a callback + } + + /** + * The package name of the calling app. + */ + @DataClass.Generated.Member + public @NonNull String getAppPackageName() { + return mAppPackageName; + } + + /** + * The parameters provided by the app to the {@link IsolatedService}. The service + * defines the expected keys in this {@link PersistableBundle}. + */ + @DataClass.Generated.Member + public @NonNull PersistableBundle getAppParams() { + return mAppParams; + } + + @Override + @DataClass.Generated.Member + public void writeToParcel(@NonNull android.os.Parcel dest, int flags) { + // You can override field parcelling by defining methods like: + // void parcelFieldName(Parcel dest, int flags) { ... } + + dest.writeString(mAppPackageName); + dest.writeTypedObject(mAppParams, flags); + } + + @Override + @DataClass.Generated.Member + public int describeContents() { return 0; } + + /** @hide */ + @SuppressWarnings({"unchecked", "RedundantCast"}) + @DataClass.Generated.Member + /* package-private */ ExecuteInputParcel(@NonNull android.os.Parcel in) { + // You can override field unparcelling by defining methods like: + // static FieldType unparcelFieldName(Parcel in) { ... } + + String appPackageName = in.readString(); + PersistableBundle appParams = (PersistableBundle) in.readTypedObject(PersistableBundle.CREATOR); + + this.mAppPackageName = appPackageName; + AnnotationValidations.validate( + NonNull.class, null, mAppPackageName); + this.mAppParams = appParams; + AnnotationValidations.validate( + NonNull.class, null, mAppParams); + + // onConstructed(); // You can define this method to get a callback + } + + @DataClass.Generated.Member + public static final @NonNull Parcelable.Creator<ExecuteInputParcel> CREATOR + = new Parcelable.Creator<ExecuteInputParcel>() { + @Override + public ExecuteInputParcel[] newArray(int size) { + return new ExecuteInputParcel[size]; + } + + @Override + public ExecuteInputParcel createFromParcel(@NonNull android.os.Parcel in) { + return new ExecuteInputParcel(in); + } + }; + + /** + * A builder for {@link ExecuteInputParcel} + * @hide + */ + @SuppressWarnings("WeakerAccess") + @DataClass.Generated.Member + public static final class Builder { + + private @NonNull String mAppPackageName; + private @NonNull PersistableBundle mAppParams; + + private long mBuilderFieldsSet = 0L; + + public Builder() { + } + + /** + * The package name of the calling app. + */ + @DataClass.Generated.Member + public @NonNull Builder setAppPackageName(@NonNull String value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x1; + mAppPackageName = value; + return this; + } + + /** + * The parameters provided by the app to the {@link IsolatedService}. The service + * defines the expected keys in this {@link PersistableBundle}. + */ + @DataClass.Generated.Member + public @NonNull Builder setAppParams(@NonNull PersistableBundle value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x2; + mAppParams = value; + return this; + } + + /** Builds the instance. This builder should not be touched after calling this! */ + public @NonNull ExecuteInputParcel build() { + checkNotUsed(); + mBuilderFieldsSet |= 0x4; // Mark builder used + + if ((mBuilderFieldsSet & 0x1) == 0) { + mAppPackageName = ""; + } + if ((mBuilderFieldsSet & 0x2) == 0) { + mAppParams = PersistableBundle.EMPTY; + } + ExecuteInputParcel o = new ExecuteInputParcel( + mAppPackageName, + mAppParams); + return o; + } + + private void checkNotUsed() { + if ((mBuilderFieldsSet & 0x4) != 0) { + throw new IllegalStateException( + "This Builder should not be reused. Use a new Builder instance instead"); + } + } + } + + @DataClass.Generated( + time = 1698868678877L, + codegenVersion = "1.0.23", + sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/ExecuteInputParcel.java", + inputSignatures = " @android.annotation.NonNull java.lang.String mAppPackageName\n @android.annotation.NonNull android.os.PersistableBundle mAppParams\nclass ExecuteInputParcel extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genAidl=false, genHiddenBuilder=true)") + @Deprecated + private void __metadata() {} + + + //@formatter:on + // End of generated code + +} diff --git a/framework/java/android/adservices/ondevicepersonalization/ExecuteOutput.java b/framework/java/android/adservices/ondevicepersonalization/ExecuteOutput.java index f8b673b3..ad7c5d2e 100644 --- a/framework/java/android/adservices/ondevicepersonalization/ExecuteOutput.java +++ b/framework/java/android/adservices/ondevicepersonalization/ExecuteOutput.java @@ -16,6 +16,9 @@ package android.adservices.ondevicepersonalization; +import static android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS; + +import android.annotation.FlaggedApi; import android.annotation.NonNull; import android.annotation.Nullable; import android.os.Parcelable; @@ -27,17 +30,19 @@ import java.util.Collections; import java.util.List; /** - * The result returned by {@link IsolatedWorker#onExecute()} in response to a call to - * {@link OnDevicePersonalizationManager#execute()} from a client app. - * - * @hide + * The result returned by + * {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)} in response to a call to + * {@code OnDevicePersonalizationManager#execute(ComponentName, PersistableBundle, + * java.util.concurrent.Executor, OutcomeReceiver)} + * from a client app. */ +@FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) @DataClass(genBuilder = true, genEqualsHashCode = true) public final class ExecuteOutput implements Parcelable { /** * Persistent data to be written to the REQUESTS table after - * {@link IsolatedWorker#onExecute()} completes. If null, no persistent data will - * be written. + * {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)} + * completes. If null, no persistent data will be written. */ @Nullable private RequestLogRecord mRequestLogRecord = null; @@ -48,6 +53,18 @@ public final class ExecuteOutput implements Parcelable { @DataClass.PluralOf("renderingConfig") @NonNull private List<RenderingConfig> mRenderingConfigs = Collections.emptyList(); + /** + * A list of {@link EventLogRecord}. Writes events to the EVENTS table and associates + * them with requests with the specified corresponding {@link RequestLogRecord} from + * {@link EventLogRecord#getRequestLogRecord()}. + * If the event does not contain a {@link RequestLogRecord} emitted by this package, the + * EventLogRecord is not written. + * + * @hide + */ + @DataClass.PluralOf("eventLogRecord") + @NonNull private List<EventLogRecord> mEventLogRecords = Collections.emptyList(); + // Code below generated by codegen v1.0.23. @@ -66,19 +83,23 @@ public final class ExecuteOutput implements Parcelable { @DataClass.Generated.Member /* package-private */ ExecuteOutput( @Nullable RequestLogRecord requestLogRecord, - @NonNull List<RenderingConfig> renderingConfigs) { + @NonNull List<RenderingConfig> renderingConfigs, + @NonNull List<EventLogRecord> eventLogRecords) { this.mRequestLogRecord = requestLogRecord; this.mRenderingConfigs = renderingConfigs; AnnotationValidations.validate( NonNull.class, null, mRenderingConfigs); + this.mEventLogRecords = eventLogRecords; + AnnotationValidations.validate( + NonNull.class, null, mEventLogRecords); // onConstructed(); // You can define this method to get a callback } /** * Persistent data to be written to the REQUESTS table after - * {@link IsolatedWorker#onExecute()} completes. If null, no persistent data will - * be written. + * {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)} + * completes. If null, no persistent data will be written. */ @DataClass.Generated.Member public @Nullable RequestLogRecord getRequestLogRecord() { @@ -94,6 +115,20 @@ public final class ExecuteOutput implements Parcelable { return mRenderingConfigs; } + /** + * A list of {@link EventLogRecord}. Writes events to the EVENTS table and associates + * them with requests with the specified corresponding {@link RequestLogRecord} from + * {@link EventLogRecord#getRequestLogRecord()}. + * If the event does not contain a {@link RequestLogRecord} emitted by this package, the + * EventLogRecord is not written. + * + * @hide + */ + @DataClass.Generated.Member + public @NonNull List<EventLogRecord> getEventLogRecords() { + return mEventLogRecords; + } + @Override @DataClass.Generated.Member public boolean equals(@Nullable Object o) { @@ -108,7 +143,8 @@ public final class ExecuteOutput implements Parcelable { //noinspection PointlessBooleanExpression return true && java.util.Objects.equals(mRequestLogRecord, that.mRequestLogRecord) - && java.util.Objects.equals(mRenderingConfigs, that.mRenderingConfigs); + && java.util.Objects.equals(mRenderingConfigs, that.mRenderingConfigs) + && java.util.Objects.equals(mEventLogRecords, that.mEventLogRecords); } @Override @@ -120,6 +156,7 @@ public final class ExecuteOutput implements Parcelable { int _hash = 1; _hash = 31 * _hash + java.util.Objects.hashCode(mRequestLogRecord); _hash = 31 * _hash + java.util.Objects.hashCode(mRenderingConfigs); + _hash = 31 * _hash + java.util.Objects.hashCode(mEventLogRecords); return _hash; } @@ -134,6 +171,7 @@ public final class ExecuteOutput implements Parcelable { dest.writeByte(flg); if (mRequestLogRecord != null) dest.writeTypedObject(mRequestLogRecord, flags); dest.writeParcelableList(mRenderingConfigs, flags); + dest.writeParcelableList(mEventLogRecords, flags); } @Override @@ -151,11 +189,16 @@ public final class ExecuteOutput implements Parcelable { RequestLogRecord requestLogRecord = (flg & 0x1) == 0 ? null : (RequestLogRecord) in.readTypedObject(RequestLogRecord.CREATOR); List<RenderingConfig> renderingConfigs = new java.util.ArrayList<>(); in.readParcelableList(renderingConfigs, RenderingConfig.class.getClassLoader()); + List<EventLogRecord> eventLogRecords = new java.util.ArrayList<>(); + in.readParcelableList(eventLogRecords, EventLogRecord.class.getClassLoader()); this.mRequestLogRecord = requestLogRecord; this.mRenderingConfigs = renderingConfigs; AnnotationValidations.validate( NonNull.class, null, mRenderingConfigs); + this.mEventLogRecords = eventLogRecords; + AnnotationValidations.validate( + NonNull.class, null, mEventLogRecords); // onConstructed(); // You can define this method to get a callback } @@ -177,12 +220,14 @@ public final class ExecuteOutput implements Parcelable { /** * A builder for {@link ExecuteOutput} */ + @FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) @SuppressWarnings("WeakerAccess") @DataClass.Generated.Member public static final class Builder { private @Nullable RequestLogRecord mRequestLogRecord; private @NonNull List<RenderingConfig> mRenderingConfigs; + private @NonNull List<EventLogRecord> mEventLogRecords; private long mBuilderFieldsSet = 0L; @@ -191,8 +236,8 @@ public final class ExecuteOutput implements Parcelable { /** * Persistent data to be written to the REQUESTS table after - * {@link IsolatedWorker#onExecute()} completes. If null, no persistent data will - * be written. + * {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)} + * completes. If null, no persistent data will be written. */ @DataClass.Generated.Member public @NonNull Builder setRequestLogRecord(@NonNull RequestLogRecord value) { @@ -222,10 +267,38 @@ public final class ExecuteOutput implements Parcelable { return this; } + /** + * A list of {@link EventLogRecord}. Writes events to the EVENTS table and associates + * them with requests with the specified corresponding {@link RequestLogRecord} from + * {@link EventLogRecord#getRequestLogRecord()}. + * If the event does not contain a {@link RequestLogRecord} emitted by this package, the + * EventLogRecord is not written. + * + * @hide + */ + @DataClass.Generated.Member + public @NonNull Builder setEventLogRecords(@NonNull List<EventLogRecord> value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x4; + mEventLogRecords = value; + return this; + } + + /** + * @see #setEventLogRecords + * @hide + */ + @DataClass.Generated.Member + public @NonNull Builder addEventLogRecord(@NonNull EventLogRecord value) { + if (mEventLogRecords == null) setEventLogRecords(new java.util.ArrayList<>()); + mEventLogRecords.add(value); + return this; + } + /** Builds the instance. This builder should not be touched after calling this! */ public @NonNull ExecuteOutput build() { checkNotUsed(); - mBuilderFieldsSet |= 0x4; // Mark builder used + mBuilderFieldsSet |= 0x8; // Mark builder used if ((mBuilderFieldsSet & 0x1) == 0) { mRequestLogRecord = null; @@ -233,14 +306,18 @@ public final class ExecuteOutput implements Parcelable { if ((mBuilderFieldsSet & 0x2) == 0) { mRenderingConfigs = Collections.emptyList(); } + if ((mBuilderFieldsSet & 0x4) == 0) { + mEventLogRecords = Collections.emptyList(); + } ExecuteOutput o = new ExecuteOutput( mRequestLogRecord, - mRenderingConfigs); + mRenderingConfigs, + mEventLogRecords); return o; } private void checkNotUsed() { - if ((mBuilderFieldsSet & 0x4) != 0) { + if ((mBuilderFieldsSet & 0x8) != 0) { throw new IllegalStateException( "This Builder should not be reused. Use a new Builder instance instead"); } @@ -248,10 +325,10 @@ public final class ExecuteOutput implements Parcelable { } @DataClass.Generated( - time = 1692118370720L, + time = 1697132452641L, codegenVersion = "1.0.23", sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/ExecuteOutput.java", - inputSignatures = "private @android.annotation.Nullable android.adservices.ondevicepersonalization.RequestLogRecord mRequestLogRecord\nprivate @com.android.ondevicepersonalization.internal.util.DataClass.PluralOf(\"renderingConfig\") @android.annotation.NonNull java.util.List<android.adservices.ondevicepersonalization.RenderingConfig> mRenderingConfigs\nclass ExecuteOutput extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") + inputSignatures = "private @android.annotation.Nullable android.adservices.ondevicepersonalization.RequestLogRecord mRequestLogRecord\nprivate @com.android.ondevicepersonalization.internal.util.DataClass.PluralOf(\"renderingConfig\") @android.annotation.NonNull java.util.List<android.adservices.ondevicepersonalization.RenderingConfig> mRenderingConfigs\nprivate @com.android.ondevicepersonalization.internal.util.DataClass.PluralOf(\"eventLogRecord\") @android.annotation.NonNull java.util.List<android.adservices.ondevicepersonalization.EventLogRecord> mEventLogRecords\nclass ExecuteOutput extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") @Deprecated private void __metadata() {} diff --git a/framework/java/android/adservices/ondevicepersonalization/ExecuteOutputParcel.java b/framework/java/android/adservices/ondevicepersonalization/ExecuteOutputParcel.java new file mode 100644 index 00000000..df23d7cf --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/ExecuteOutputParcel.java @@ -0,0 +1,216 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +import android.annotation.NonNull; +import android.annotation.Nullable; +import android.os.Parcelable; + +import com.android.ondevicepersonalization.internal.util.AnnotationValidations; +import com.android.ondevicepersonalization.internal.util.DataClass; + +import java.util.Collections; +import java.util.List; + +/** + * Parcelable version of {@link ExecuteOutput}. + * @hide + */ +@DataClass(genAidl = false, genBuilder = false) +public final class ExecuteOutputParcel implements Parcelable { + /** + * Persistent data to be written to the REQUESTS table after + * {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)} + * completes. If null, no persistent data will be written. + */ + @Nullable private RequestLogRecord mRequestLogRecord = null; + + /** + * A list of {@link RenderingConfig} objects, one per slot specified in the request from the + * calling app. The calling app and the service must agree on the expected size of this list. + */ + @DataClass.PluralOf("renderingConfig") + @NonNull private List<RenderingConfig> mRenderingConfigs = Collections.emptyList(); + + /** + * A list of {@link EventLogRecord}. Writes events to the EVENTS table and associates + * them with requests with the specified corresponding {@link RequestLogRecord} from + * {@link EventLogRecord#getRequestLogRecord()}. + * If the event does not contain a {@link RequestLogRecord} emitted by this package, the + * EventLogRecord is not written. + * + * @hide + */ + @DataClass.PluralOf("eventLogRecord") + @NonNull private List<EventLogRecord> mEventLogRecords = Collections.emptyList(); + + /** @hide */ + public ExecuteOutputParcel(@NonNull ExecuteOutput value) { + this(value.getRequestLogRecord(), value.getRenderingConfigs(), value.getEventLogRecords()); + } + + + + // Code below generated by codegen v1.0.23. + // + // DO NOT MODIFY! + // CHECKSTYLE:OFF Generated code + // + // To regenerate run: + // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/ExecuteOutputParcel.java + // + // To exclude the generated code from IntelliJ auto-formatting enable (one-time): + // Settings > Editor > Code Style > Formatter Control + //@formatter:off + + + /** + * Creates a new ExecuteOutputParcel. + * + * @param requestLogRecord + * Persistent data to be written to the REQUESTS table after + * {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)} + * completes. If null, no persistent data will be written. + * @param renderingConfigs + * A list of {@link RenderingConfig} objects, one per slot specified in the request from the + * calling app. The calling app and the service must agree on the expected size of this list. + * @param eventLogRecords + * A list of {@link EventLogRecord}. Writes events to the EVENTS table and associates + * them with requests with the specified corresponding {@link RequestLogRecord} from + * {@link EventLogRecord#getRequestLogRecord()}. + * If the event does not contain a {@link RequestLogRecord} emitted by this package, the + * EventLogRecord is not written. + */ + @DataClass.Generated.Member + public ExecuteOutputParcel( + @Nullable RequestLogRecord requestLogRecord, + @NonNull List<RenderingConfig> renderingConfigs, + @NonNull List<EventLogRecord> eventLogRecords) { + this.mRequestLogRecord = requestLogRecord; + this.mRenderingConfigs = renderingConfigs; + AnnotationValidations.validate( + NonNull.class, null, mRenderingConfigs); + this.mEventLogRecords = eventLogRecords; + AnnotationValidations.validate( + NonNull.class, null, mEventLogRecords); + + // onConstructed(); // You can define this method to get a callback + } + + /** + * Persistent data to be written to the REQUESTS table after + * {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)} + * completes. If null, no persistent data will be written. + */ + @DataClass.Generated.Member + public @Nullable RequestLogRecord getRequestLogRecord() { + return mRequestLogRecord; + } + + /** + * A list of {@link RenderingConfig} objects, one per slot specified in the request from the + * calling app. The calling app and the service must agree on the expected size of this list. + */ + @DataClass.Generated.Member + public @NonNull List<RenderingConfig> getRenderingConfigs() { + return mRenderingConfigs; + } + + /** + * A list of {@link EventLogRecord}. Writes events to the EVENTS table and associates + * them with requests with the specified corresponding {@link RequestLogRecord} from + * {@link EventLogRecord#getRequestLogRecord()}. + * If the event does not contain a {@link RequestLogRecord} emitted by this package, the + * EventLogRecord is not written. + * + * @hide + */ + @DataClass.Generated.Member + public @NonNull List<EventLogRecord> getEventLogRecords() { + return mEventLogRecords; + } + + @Override + @DataClass.Generated.Member + public void writeToParcel(@NonNull android.os.Parcel dest, int flags) { + // You can override field parcelling by defining methods like: + // void parcelFieldName(Parcel dest, int flags) { ... } + + byte flg = 0; + if (mRequestLogRecord != null) flg |= 0x1; + dest.writeByte(flg); + if (mRequestLogRecord != null) dest.writeTypedObject(mRequestLogRecord, flags); + dest.writeParcelableList(mRenderingConfigs, flags); + dest.writeParcelableList(mEventLogRecords, flags); + } + + @Override + @DataClass.Generated.Member + public int describeContents() { return 0; } + + /** @hide */ + @SuppressWarnings({"unchecked", "RedundantCast"}) + @DataClass.Generated.Member + /* package-private */ ExecuteOutputParcel(@NonNull android.os.Parcel in) { + // You can override field unparcelling by defining methods like: + // static FieldType unparcelFieldName(Parcel in) { ... } + + byte flg = in.readByte(); + RequestLogRecord requestLogRecord = (flg & 0x1) == 0 ? null : (RequestLogRecord) in.readTypedObject(RequestLogRecord.CREATOR); + List<RenderingConfig> renderingConfigs = new java.util.ArrayList<>(); + in.readParcelableList(renderingConfigs, RenderingConfig.class.getClassLoader()); + List<EventLogRecord> eventLogRecords = new java.util.ArrayList<>(); + in.readParcelableList(eventLogRecords, EventLogRecord.class.getClassLoader()); + + this.mRequestLogRecord = requestLogRecord; + this.mRenderingConfigs = renderingConfigs; + AnnotationValidations.validate( + NonNull.class, null, mRenderingConfigs); + this.mEventLogRecords = eventLogRecords; + AnnotationValidations.validate( + NonNull.class, null, mEventLogRecords); + + // onConstructed(); // You can define this method to get a callback + } + + @DataClass.Generated.Member + public static final @NonNull Parcelable.Creator<ExecuteOutputParcel> CREATOR + = new Parcelable.Creator<ExecuteOutputParcel>() { + @Override + public ExecuteOutputParcel[] newArray(int size) { + return new ExecuteOutputParcel[size]; + } + + @Override + public ExecuteOutputParcel createFromParcel(@NonNull android.os.Parcel in) { + return new ExecuteOutputParcel(in); + } + }; + + @DataClass.Generated( + time = 1698864579986L, + codegenVersion = "1.0.23", + sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/ExecuteOutputParcel.java", + inputSignatures = "private @android.annotation.Nullable android.adservices.ondevicepersonalization.RequestLogRecord mRequestLogRecord\nprivate @com.android.ondevicepersonalization.internal.util.DataClass.PluralOf(\"renderingConfig\") @android.annotation.NonNull java.util.List<android.adservices.ondevicepersonalization.RenderingConfig> mRenderingConfigs\nprivate @com.android.ondevicepersonalization.internal.util.DataClass.PluralOf(\"eventLogRecord\") @android.annotation.NonNull java.util.List<android.adservices.ondevicepersonalization.EventLogRecord> mEventLogRecords\nclass ExecuteOutputParcel extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genAidl=false, genBuilder=false)") + @Deprecated + private void __metadata() {} + + + //@formatter:on + // End of generated code + +} diff --git a/framework/java/android/adservices/ondevicepersonalization/FederatedComputeInput.aidl b/framework/java/android/adservices/ondevicepersonalization/FederatedComputeInput.aidl new file mode 100644 index 00000000..738f25dd --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/FederatedComputeInput.aidl @@ -0,0 +1,19 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +parcelable FederatedComputeInput; diff --git a/framework/java/android/adservices/ondevicepersonalization/FederatedComputeInput.java b/framework/java/android/adservices/ondevicepersonalization/FederatedComputeInput.java new file mode 100644 index 00000000..e5dee4c5 --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/FederatedComputeInput.java @@ -0,0 +1,159 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +import android.annotation.NonNull; + +import com.android.ondevicepersonalization.internal.util.AnnotationValidations; +import com.android.ondevicepersonalization.internal.util.DataClass; + +/** + * The input data for + * {@link FederatedComputeScheduler#schedule(FederatedComputeScheduler.Params, FederatedComputeInput)} + * + * @hide + */ +@DataClass(genBuilder = true, genEqualsHashCode = true) +public final class FederatedComputeInput { + // TODO(b/300461799): add federated compute server document. + /** + * Population refers to a collection of devices that specific task groups can run on. It should + * match task plan configured at remote federated computation server. + */ + @NonNull private String mPopulationName = ""; + + + + // Code below generated by codegen v1.0.23. + // + // DO NOT MODIFY! + // CHECKSTYLE:OFF Generated code + // + // To regenerate run: + // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/FederatedComputeInput.java + // + // To exclude the generated code from IntelliJ auto-formatting enable (one-time): + // Settings > Editor > Code Style > Formatter Control + //@formatter:off + + + @DataClass.Generated.Member + /* package-private */ FederatedComputeInput( + @NonNull String populationName) { + this.mPopulationName = populationName; + AnnotationValidations.validate( + NonNull.class, null, mPopulationName); + + // onConstructed(); // You can define this method to get a callback + } + + /** + * Population refers to a collection of devices that specific task groups can run on. It should + * match task plan configured at remote federated computation server. + */ + @DataClass.Generated.Member + public @NonNull String getPopulationName() { + return mPopulationName; + } + + @Override + @DataClass.Generated.Member + public boolean equals(@android.annotation.Nullable Object o) { + // You can override field equality logic by defining either of the methods like: + // boolean fieldNameEquals(FederatedComputeInput other) { ... } + // boolean fieldNameEquals(FieldType otherValue) { ... } + + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + @SuppressWarnings("unchecked") + FederatedComputeInput that = (FederatedComputeInput) o; + //noinspection PointlessBooleanExpression + return true + && java.util.Objects.equals(mPopulationName, that.mPopulationName); + } + + @Override + @DataClass.Generated.Member + public int hashCode() { + // You can override field hashCode logic by defining methods like: + // int fieldNameHashCode() { ... } + + int _hash = 1; + _hash = 31 * _hash + java.util.Objects.hashCode(mPopulationName); + return _hash; + } + + /** + * A builder for {@link FederatedComputeInput} + */ + @SuppressWarnings("WeakerAccess") + @DataClass.Generated.Member + public static final class Builder { + + private @NonNull String mPopulationName; + + private long mBuilderFieldsSet = 0L; + + public Builder() { + } + + /** + * Population refers to a collection of devices that specific task groups can run on. It should + * match task plan configured at remote federated computation server. + */ + @DataClass.Generated.Member + public @NonNull Builder setPopulationName(@NonNull String value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x1; + mPopulationName = value; + return this; + } + + /** Builds the instance. This builder should not be touched after calling this! */ + public @NonNull FederatedComputeInput build() { + checkNotUsed(); + mBuilderFieldsSet |= 0x2; // Mark builder used + + if ((mBuilderFieldsSet & 0x1) == 0) { + mPopulationName = ""; + } + FederatedComputeInput o = new FederatedComputeInput( + mPopulationName); + return o; + } + + private void checkNotUsed() { + if ((mBuilderFieldsSet & 0x2) != 0) { + throw new IllegalStateException( + "This Builder should not be reused. Use a new Builder instance instead"); + } + } + } + + @DataClass.Generated( + time = 1697578140247L, + codegenVersion = "1.0.23", + sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/FederatedComputeInput.java", + inputSignatures = "private @android.annotation.NonNull java.lang.String mPopulationName\nclass FederatedComputeInput extends java.lang.Object implements []\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") + @Deprecated + private void __metadata() {} + + + //@formatter:on + // End of generated code + +} diff --git a/framework/java/android/adservices/ondevicepersonalization/FederatedComputeScheduler.java b/framework/java/android/adservices/ondevicepersonalization/FederatedComputeScheduler.java new file mode 100644 index 00000000..4077e7e8 --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/FederatedComputeScheduler.java @@ -0,0 +1,163 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +import android.adservices.ondevicepersonalization.aidl.IFederatedComputeCallback; +import android.adservices.ondevicepersonalization.aidl.IFederatedComputeService; +import android.annotation.NonNull; +import android.annotation.WorkerThread; +import android.federatedcompute.common.TrainingOptions; +import android.os.RemoteException; + +import com.android.ondevicepersonalization.internal.util.LoggerFactory; + +import java.util.concurrent.CountDownLatch; + +/** + * Handles scheduling Federated Learning and Federated Analytics jobs. + * + * @hide + */ +public class FederatedComputeScheduler { + private static final String TAG = FederatedComputeScheduler.class.getSimpleName(); + private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); + + private final IFederatedComputeService mFcService; + + /** @hide */ + public FederatedComputeScheduler(IFederatedComputeService binder) { + mFcService = binder; + } + + // TODO(b/300461799): add federated compute server document. + /** + * Schedule a federated computation job. + * + * @param params parameters related to job scheduling. + * @param input the configuration of the federated computation. It should be consistent with + * the federated computation server setup. + * @throws IllegalArgumentException caused by caller supplied invalid input argument. + * @throws IllegalStateException caused by an internal failure of FederatedComputeScheduler. + */ + @WorkerThread + public void schedule(@NonNull Params params, @NonNull FederatedComputeInput input) { + if (mFcService == null) { + throw new IllegalStateException( + "FederatedComputeScheduler not available for this instance."); + } + android.federatedcompute.common.TrainingInterval trainingInterval = + convertTrainingInterval(params.getTrainingInterval()); + TrainingOptions trainingOptions = + new TrainingOptions.Builder() + .setPopulationName(input.getPopulationName()) + .setTrainingInterval(trainingInterval) + .build(); + CountDownLatch latch = new CountDownLatch(1); + final int[] err = {0}; + try { + mFcService.schedule( + trainingOptions, + new IFederatedComputeCallback.Stub() { + @Override + public void onSuccess() { + latch.countDown(); + } + + @Override + public void onFailure(int i) { + err[0] = i; + latch.countDown(); + } + }); + latch.await(); + if (err[0] != 0) { + throw new IllegalStateException("Internal failure occurred while scheduling job"); + } + } catch (RemoteException | InterruptedException e) { + sLogger.e(TAG + ": Failed to schedule federated compute job", e); + throw new IllegalStateException(e); + } + } + + /** + * Cancel a federated computation job with input training params. + * + * @param populationName population name of the job that caller wants to cancel + * @throws IllegalStateException caused by an internal failure of FederatedComputeScheduler. + */ + @WorkerThread + public void cancel(@NonNull String populationName) { + if (mFcService == null) { + throw new IllegalStateException( + "FederatedComputeScheduler not available for this instance."); + } + CountDownLatch latch = new CountDownLatch(1); + final int[] err = {0}; + try { + mFcService.cancel( + populationName, + new IFederatedComputeCallback.Stub() { + @Override + public void onSuccess() { + latch.countDown(); + } + + @Override + public void onFailure(int i) { + err[0] = i; + latch.countDown(); + } + }); + latch.await(); + if (err[0] != 0) { + throw new IllegalStateException("Internal failure occurred while cancelling job"); + } + } catch (RemoteException | InterruptedException e) { + sLogger.e(TAG + ": Failed to cancel federated compute job", e); + throw new IllegalStateException(e); + } + } + + private android.federatedcompute.common.TrainingInterval convertTrainingInterval( + TrainingInterval interval) { + return new android.federatedcompute.common.TrainingInterval.Builder() + .setMinimumIntervalMillis(interval.getMinimumInterval().toMillis()) + .setSchedulingMode(interval.getSchedulingMode()) + .build(); + } + + /** The parameters related to job scheduling. */ + public static class Params { + /** + * If training interval is scheduled for recurrent tasks, the earliest time this task could + * start is after the minimum training interval expires. E.g. If the task is set to run + * maximum once per day, the first run of this task will be one day after this task is + * scheduled. When a one time job is scheduled, the earliest next runtime is calculated + * based on federated compute default interval. + */ + @NonNull private final TrainingInterval mTrainingInterval; + + public Params(@NonNull TrainingInterval trainingInterval) { + mTrainingInterval = trainingInterval; + } + + @NonNull + public TrainingInterval getTrainingInterval() { + return mTrainingInterval; + } + } +} diff --git a/framework/java/android/adservices/ondevicepersonalization/IsolatedService.java b/framework/java/android/adservices/ondevicepersonalization/IsolatedService.java index bc1a1a3d..3f7329fe 100644 --- a/framework/java/android/adservices/ondevicepersonalization/IsolatedService.java +++ b/framework/java/android/adservices/ondevicepersonalization/IsolatedService.java @@ -16,9 +16,13 @@ package android.adservices.ondevicepersonalization; +import static android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS; + import android.adservices.ondevicepersonalization.aidl.IDataAccessService; +import android.adservices.ondevicepersonalization.aidl.IFederatedComputeService; import android.adservices.ondevicepersonalization.aidl.IIsolatedService; import android.adservices.ondevicepersonalization.aidl.IIsolatedServiceCallback; +import android.annotation.FlaggedApi; import android.annotation.NonNull; import android.annotation.Nullable; import android.app.Service; @@ -27,103 +31,177 @@ import android.os.Bundle; import android.os.IBinder; import android.os.Parcelable; import android.os.RemoteException; +import android.os.SystemClock; +import com.android.ondevicepersonalization.internal.util.ByteArrayParceledListSlice; import com.android.ondevicepersonalization.internal.util.LoggerFactory; import java.util.HashMap; import java.util.List; import java.util.Objects; import java.util.function.Consumer; +import java.util.function.Function; // TODO(b/289102463): Add a link to the public ODP developer documentation. /** - * Base class for services that run in an - * <a href="https://developer.android.com/guide/topics/manifest/service-element#isolated">isolated - * process</a> and can produce content to be displayed in a {@link SurfaceView} in a calling app - * and write persistent results to on-device storage, which can be consumed by Federated Analytics - * for cross-device statistical analysis or by Federated Learning for model training. Client apps - * use {@link OnDevicePersonalizationManager} to interact with an - * {@link IsolatedService}. - * - * @hide + * Base class for services that are started by ODP on a call to + * {@code OnDevicePersonalizationManager#execute(ComponentName, PersistableBundle, + * java.util.concurrent.Executor, OutcomeReceiver)} + * and run in an <a + * href="https://developer.android.com/guide/topics/manifest/service-element#isolated">isolated + * process</a>. The service can produce content to be displayed in a + * {@link android.view.SurfaceView} in a calling app and write persistent results to on-device + * storage, which can be consumed by Federated Analytics for cross-device statistical analysis or + * by Federated Learning for model training. + * Client apps use {@link OnDevicePersonalizationManager} to interact with an {@link + * IsolatedService}. */ +@FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public abstract class IsolatedService extends Service { private static final String TAG = "IsolatedService"; private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); private IBinder mBinder; - @Override public void onCreate() { + /** Creates a binder for an {@link IsolatedService}. */ + @Override + public void onCreate() { mBinder = new ServiceBinder(); } - @Override @Nullable public IBinder onBind(@NonNull Intent intent) { + /** + * Handles binding to the {@link IsolatedService}. + * + * @param intent The Intent that was used to bind to this service, as given to {@link + * android.content.Context#bindService Context.bindService}. Note that any extras that were + * included with the Intent at that point will <em>not</em> be seen here. + */ + @Override + @Nullable + public IBinder onBind(@NonNull Intent intent) { return mBinder; } /** - * Return an instance of {@link IsolatedWorker} that handles client requests. + * Return an instance of an {@link IsolatedWorker} that handles client requests. + * + * @param requestToken an opaque token that identifies the current request to the service that + * must be passed to service methods that depend on per-request state. */ - @NonNull public abstract IsolatedWorker onRequest( - @NonNull RequestToken requestToken); + @NonNull + public abstract IsolatedWorker onRequest(@NonNull RequestToken requestToken); /** - * Returns a DAO for the REMOTE_DATA table. The REMOTE_DATA table is a read-only key-value - * store that contains data that is periodically downloaded from an endpoint declared in the - * package manifest of the service. + * Returns a Data Access Object for the REMOTE_DATA table. The REMOTE_DATA table is a read-only + * key-value store that contains data that is periodically downloaded from an endpoint declared + * in the <download> tag in the ODP manifest of the service, as shown in the following example. + * + * <pre>{@code + * <!-- Contents of res/xml/OdpSettings.xml --> + * <on-device-personalization> + * <!-- Name of the service subclass --> + * <service "com.example.odpsample.SampleService"> + * <!-- If this tag is present, ODP will periodically poll this URL and + * download content to populate REMOTE_DATA. Adopters that do not need to + * download content from their servers can skip this tag. --> + * <download-settings url="https://example.com/get" /> + * </service> + * </on-device-personalization> + * }</pre> * * @param requestToken an opaque token that identifies the current request to the service. - * @return A {@link KeyValueStore} object that provides access to the REMOTE_DATA table. + * @return A {@link KeyValueStore} object that provides access to the REMOTE_DATA table. The + * methods in the returned {@link KeyValueStore} are blocking operations and should be + * called from a worker thread and not the main thread or a binder thread. + * @see #onRequest(RequestToken) */ - @NonNull public final KeyValueStore getRemoteData( - @NonNull RequestToken requestToken) { + @NonNull + public final KeyValueStore getRemoteData(@NonNull RequestToken requestToken) { return new RemoteDataImpl(requestToken.getDataAccessService()); } /** - * Returns a DAO for the LOCAL_DATA table. The LOCAL_DATA table is a persistent key-value - * store that the service can use to store any data. The contents of this table are visible - * only to the service running in an isolated process and cannot be sent outside the device. + * Returns a Data Access Object for the LOCAL_DATA table. The LOCAL_DATA table is a persistent + * key-value store that the service can use to store any data. The contents of this table are + * visible only to the service running in an isolated process and cannot be sent outside the + * device. * * @param requestToken an opaque token that identifies the current request to the service. * @return A {@link MutableKeyValueStore} object that provides access to the LOCAL_DATA table. + * The methods in the returned {@link MutableKeyValueStore} are blocking operations and + * should be called from a worker thread and not the main thread or a binder thread. + * @see #onRequest(RequestToken) */ - @NonNull public final MutableKeyValueStore getLocalData( - @NonNull RequestToken requestToken) { + @NonNull + public final MutableKeyValueStore getLocalData(@NonNull RequestToken requestToken) { return new LocalDataImpl(requestToken.getDataAccessService()); } /** + * Returns a DAO for the REQUESTS and EVENTS tables that provides + * access to the rows that are readable by the IsolatedService. + * + * @param requestToken an opaque token that identifies the current request to the service. + * @return A {@link LogReader} object that provides access to the REQUESTS and EVENTS table. + * The methods in the returned {@link LogReader} are blocking operations and + * should be called from a worker thread and not the main thread or a binder thread. + * @see #onRequest(RequestToken) + * + * @hide + */ + @NonNull + public final LogReader getLogReader(@NonNull RequestToken requestToken) { + return new LogReader(requestToken.getDataAccessService()); + } + + /** * Returns an {@link EventUrlProvider} for the current request. The {@link EventUrlProvider} - * provides URLs that can be embedded in HTML. When the HTML is rendered in a {@link WebView}, - * the platform intercepts requests to these URLs and invokes - * {@link IsolatedCmputationCallback#onWebViewEvent()}. + * provides URLs that can be embedded in HTML. When the HTML is rendered in an + * {@link android.webkit.WebView}, the platform intercepts requests to these URLs and invokes + * {@code IsolatedWorker#onEvent(EventInput, Consumer)}. * * @param requestToken an opaque token that identifies the current request to the service. * @return An {@link EventUrlProvider} that returns event tracking URLs. + * @see #onRequest(RequestToken) */ - @NonNull public final EventUrlProvider getEventUrlProvider( - @NonNull RequestToken requestToken) { + @NonNull + public final EventUrlProvider getEventUrlProvider(@NonNull RequestToken requestToken) { return new EventUrlProvider(requestToken.getDataAccessService()); } /** - * Returns the {@link UserData} for the current request. The {@link UserData} object contains - * location and app usage history collected by the platform. + * Returns the platform-provided {@link UserData} for the current request. * * @param requestToken an opaque token that identifies the current request to the service. * @return A {@link UserData} object. + * @see #onRequest(RequestToken) + */ + @Nullable + public final UserData getUserData(@NonNull RequestToken requestToken) { + return requestToken.getUserData(); + } + + /** + * Returns an {@link FederatedComputeScheduler} for the current request. The {@link + * FederatedComputeScheduler} can be used to schedule and cancel federated computation jobs. + * The federated computation includes federated learning and federated analytic jobs. * + * @param requestToken an opaque token that identifies the current request to the service. + * @return An {@link FederatedComputeScheduler} that returns a federated computation job + * scheduler. + * @see #onRequest(RequestToken) * @hide */ - @Nullable public final UserData getUserData( + @NonNull + public final FederatedComputeScheduler getFederatedComputeScheduler( @NonNull RequestToken requestToken) { - return requestToken.getUserData(); + return new FederatedComputeScheduler(requestToken.getFederatedComputeService()); } // TODO(b/228200518): Add onBidRequest()/onBidResponse() methods. class ServiceBinder extends IIsolatedService.Stub { - @Override public void onRequest( + @Override + public void onRequest( int operationCode, @NonNull Bundle params, @NonNull IIsolatedServiceCallback resultCallback) { @@ -133,97 +211,190 @@ public abstract class IsolatedService extends Service { if (operationCode == Constants.OP_EXECUTE) { - ExecuteInput input = Objects.requireNonNull( - params.getParcelable(Constants.EXTRA_INPUT, ExecuteInput.class)); + ExecuteInputParcel inputParcel = Objects.requireNonNull( + params.getParcelable(Constants.EXTRA_INPUT, ExecuteInputParcel.class)); + ExecuteInput input = new ExecuteInput(inputParcel); Objects.requireNonNull(input.getAppPackageName()); IDataAccessService binder = - IDataAccessService.Stub.asInterface(Objects.requireNonNull( - params.getBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER))); + IDataAccessService.Stub.asInterface( + Objects.requireNonNull( + params.getBinder( + Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER))); Objects.requireNonNull(binder); - UserData userData = params.getParcelable( - Constants.EXTRA_USER_DATA, UserData.class); - RequestToken requestToken = new RequestToken(binder, userData); - IsolatedWorker implCallback = - IsolatedService.this.onRequest(requestToken); + IFederatedComputeService fcBinder = + IFederatedComputeService.Stub.asInterface( + Objects.requireNonNull( + params.getBinder( + Constants.EXTRA_FEDERATED_COMPUTE_SERVICE_BINDER))); + Objects.requireNonNull(fcBinder); + UserData userData = params.getParcelable(Constants.EXTRA_USER_DATA, UserData.class); + RequestToken requestToken = new RequestToken(binder, fcBinder, userData); + IsolatedWorker implCallback = IsolatedService.this.onRequest(requestToken); implCallback.onExecute( - input, new WrappedCallback<ExecuteOutput>(resultCallback)); + input, + new WrappedCallback<ExecuteOutput, ExecuteOutputParcel>( + resultCallback, requestToken, v -> new ExecuteOutputParcel(v))); } else if (operationCode == Constants.OP_DOWNLOAD) { - DownloadInputParcel input = Objects.requireNonNull( - params.getParcelable(Constants.EXTRA_INPUT, DownloadInputParcel.class)); + DownloadInputParcel inputParcel = + Objects.requireNonNull( + params.getParcelable( + Constants.EXTRA_INPUT, DownloadInputParcel.class)); - List<String> keys = Objects.requireNonNull(input.getDownloadedKeys()).getList(); - List<byte[]> values = Objects.requireNonNull(input.getDownloadedValues()).getList(); + List<String> keys = + Objects.requireNonNull(inputParcel.getDownloadedKeys()).getList(); + List<byte[]> values = + Objects.requireNonNull(inputParcel.getDownloadedValues()).getList(); if (keys.size() != values.size()) { throw new IllegalArgumentException( "Mismatching key and value list sizes of " - + keys.size() + " and " + values.size()); + + keys.size() + + " and " + + values.size()); } HashMap<String, byte[]> downloadData = new HashMap<>(); for (int i = 0; i < keys.size(); i++) { downloadData.put(keys.get(i), values.get(i)); } - DownloadInput downloadInput = new DownloadInput.Builder() - .setData(downloadData) - .build(); + DownloadCompletedInput input = + new DownloadCompletedInput.Builder().setData(downloadData).build(); IDataAccessService binder = - IDataAccessService.Stub.asInterface(Objects.requireNonNull( - params.getBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER))); + IDataAccessService.Stub.asInterface( + Objects.requireNonNull( + params.getBinder( + Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER))); Objects.requireNonNull(binder); - UserData userData = params.getParcelable( - Constants.EXTRA_USER_DATA, UserData.class); - RequestToken requestToken = new RequestToken(binder, userData); - IsolatedWorker implCallback = - IsolatedService.this.onRequest(requestToken); - implCallback.onDownload( - downloadInput, new WrappedCallback<DownloadOutput>(resultCallback)); + IFederatedComputeService fcBinder = + IFederatedComputeService.Stub.asInterface( + Objects.requireNonNull( + params.getBinder( + Constants.EXTRA_FEDERATED_COMPUTE_SERVICE_BINDER))); + Objects.requireNonNull(fcBinder); + UserData userData = params.getParcelable(Constants.EXTRA_USER_DATA, UserData.class); + RequestToken requestToken = new RequestToken(binder, fcBinder, userData); + IsolatedWorker implCallback = IsolatedService.this.onRequest(requestToken); + implCallback.onDownloadCompleted( + input, + new WrappedCallback<DownloadCompletedOutput, + DownloadCompletedOutputParcel>( + resultCallback, + requestToken, + v -> new DownloadCompletedOutputParcel(v))); } else if (operationCode == Constants.OP_RENDER) { - RenderInput input = Objects.requireNonNull( - params.getParcelable(Constants.EXTRA_INPUT, RenderInput.class)); + RenderInputParcel inputParcel = + Objects.requireNonNull(params.getParcelable( + Constants.EXTRA_INPUT, RenderInputParcel.class)); + RenderInput input = new RenderInput(inputParcel); Objects.requireNonNull(input.getRenderingConfig()); IDataAccessService binder = - IDataAccessService.Stub.asInterface(Objects.requireNonNull( - params.getBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER))); + IDataAccessService.Stub.asInterface( + Objects.requireNonNull( + params.getBinder( + Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER))); Objects.requireNonNull(binder); - RequestToken requestToken = new RequestToken(binder, null); - IsolatedWorker implCallback = - IsolatedService.this.onRequest(requestToken); - implCallback.onRender( - input, new WrappedCallback<RenderOutput>(resultCallback)); + RequestToken requestToken = new RequestToken(binder, null, null); + IsolatedWorker implCallback = IsolatedService.this.onRequest(requestToken); + implCallback.onRender(input, new WrappedCallback<RenderOutput, RenderOutputParcel>( + resultCallback, requestToken, v -> new RenderOutputParcel(v))); } else if (operationCode == Constants.OP_WEB_VIEW_EVENT) { - WebViewEventInput input = Objects.requireNonNull( - params.getParcelable(Constants.EXTRA_INPUT, WebViewEventInput.class)); + EventInputParcel inputParcel = + Objects.requireNonNull( + params.getParcelable( + Constants.EXTRA_INPUT, EventInputParcel.class)); + EventInput input = new EventInput(inputParcel); IDataAccessService binder = - IDataAccessService.Stub.asInterface(Objects.requireNonNull( - params.getBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER))); - UserData userData = params.getParcelable( - Constants.EXTRA_USER_DATA, UserData.class); - RequestToken requestToken = new RequestToken(binder, userData); - IsolatedWorker implCallback = - IsolatedService.this.onRequest(requestToken); - implCallback.onWebViewEvent( - input, new WrappedCallback<WebViewEventOutput>(resultCallback)); + IDataAccessService.Stub.asInterface( + Objects.requireNonNull( + params.getBinder( + Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER))); + UserData userData = params.getParcelable(Constants.EXTRA_USER_DATA, UserData.class); + RequestToken requestToken = new RequestToken(binder, null, userData); + IsolatedWorker implCallback = IsolatedService.this.onRequest(requestToken); + implCallback.onEvent( + input, new WrappedCallback<EventOutput, EventOutputParcel>( + resultCallback, requestToken, v -> new EventOutputParcel(v))); + } else if (operationCode == Constants.OP_TRAINING_EXAMPLE) { + TrainingExampleInput input = + Objects.requireNonNull( + params.getParcelable( + Constants.EXTRA_INPUT, TrainingExampleInput.class)); + IDataAccessService binder = + IDataAccessService.Stub.asInterface( + Objects.requireNonNull( + params.getBinder( + Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER))); + Objects.requireNonNull(binder); + UserData userData = params.getParcelable(Constants.EXTRA_USER_DATA, UserData.class); + RequestToken requestToken = new RequestToken(binder, null, userData); + IsolatedWorker implCallback = IsolatedService.this.onRequest(requestToken); + implCallback.onTrainingExample( + input, new Consumer<TrainingExampleOutput>() { + @Override + public void accept(TrainingExampleOutput result) { + long elapsedTimeMillis = SystemClock.elapsedRealtime() + - requestToken.getStartTimeMillis(); + if (result == null) { + try { + resultCallback.onError(Constants.STATUS_INTERNAL_ERROR); + } catch (RemoteException e) { + sLogger.w(TAG + ": Callback failed.", e); + } + } else { + TrainingExampleOutputParcel parcelResult = + new TrainingExampleOutputParcel.Builder() + .setTrainingExamples( + new ByteArrayParceledListSlice( + result.getTrainingExamples())) + .setResumptionTokens( + new ByteArrayParceledListSlice( + result.getResumptionTokens())) + .build(); + Bundle bundle = new Bundle(); + bundle.putParcelable(Constants.EXTRA_RESULT, parcelResult); + bundle.putParcelable(Constants.EXTRA_CALLEE_METADATA, + new CalleeMetadata.Builder() + .setElapsedTimeMillis(elapsedTimeMillis) + .build()); + try { + resultCallback.onSuccess(bundle); + } catch (RemoteException e) { + sLogger.w(TAG + ": Callback failed.", e); + } + } + } + }); } else { throw new IllegalArgumentException("Invalid op code: " + operationCode); } } } - private static class WrappedCallback<T extends Parcelable> implements Consumer<T> { + private static class WrappedCallback<T, U extends Parcelable> implements Consumer<T> { @NonNull private final IIsolatedServiceCallback mCallback; - WrappedCallback(IIsolatedServiceCallback callback) { + @NonNull private final RequestToken mRequestToken; + @NonNull private final Function<T, U> mConverter; + + WrappedCallback( + IIsolatedServiceCallback callback, + RequestToken requestToken, + Function<T, U> converter) { mCallback = Objects.requireNonNull(callback); + mRequestToken = Objects.requireNonNull(requestToken); + mConverter = Objects.requireNonNull(converter); } - @Override public void accept(T result) { + @Override + public void accept(T result) { + long elapsedTimeMillis = + SystemClock.elapsedRealtime() - mRequestToken.getStartTimeMillis(); if (result == null) { try { mCallback.onError(Constants.STATUS_INTERNAL_ERROR); @@ -232,7 +403,12 @@ public abstract class IsolatedService extends Service { } } else { Bundle bundle = new Bundle(); - bundle.putParcelable(Constants.EXTRA_RESULT, result); + U wrappedResult = mConverter.apply(result); + bundle.putParcelable(Constants.EXTRA_RESULT, wrappedResult); + bundle.putParcelable(Constants.EXTRA_CALLEE_METADATA, + new CalleeMetadata.Builder() + .setElapsedTimeMillis(elapsedTimeMillis) + .build()); try { mCallback.onSuccess(bundle); } catch (RemoteException e) { diff --git a/framework/java/android/adservices/ondevicepersonalization/IsolatedWorker.java b/framework/java/android/adservices/ondevicepersonalization/IsolatedWorker.java index 06c20bf4..39adfc36 100644 --- a/framework/java/android/adservices/ondevicepersonalization/IsolatedWorker.java +++ b/framework/java/android/adservices/ondevicepersonalization/IsolatedWorker.java @@ -16,71 +16,116 @@ package android.adservices.ondevicepersonalization; +import static android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS; + +import android.annotation.FlaggedApi; import android.annotation.NonNull; import java.util.function.Consumer; /** - * Interface with methods that need to be implemented to handle requests to an - * {@link IsolatedService}. - * @hide + * Interface with methods that need to be implemented to handle requests from the OS to an {@link + * IsolatedService}. The {@link IsolatedService} creates an instance of {@link IsolatedWorker} on + * each request and calls one of the methods below, depending the type of the request. The {@link + * IsolatedService} calls the method on a Binder thread and the {@link IsolatedWorker} should + * offload long running operations to a worker thread. The consumer parameter of each method is used + * to return results. */ +@FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public interface IsolatedWorker { /** - * Handle a request from an app. A {@link IsolatedService} that - * processes requests from apps must override this method. + * Handles a request from an app. This method is called when an app calls {@code + * OnDevicePersonalizationManager#execute(ComponentName, PersistableBundle, + * java.util.concurrent.Executor, OutcomeReceiver)} that refers to a named + * {@link IsolatedService}. * - * @param input App Request Parameters. - * @param consumer Callback to be invoked on completion. + * @param input Request Parameters from the calling app. + * @param consumer Callback that receives the result {@link ExecuteOutput}. Should be called + * with <code>null</code> on an error. The error is propagated to the calling app as an + * {@link OnDevicePersonalizationException} with error code {@link + * OnDevicePersonalizationException#ERROR_ISOLATED_SERVICE_FAILED}. To avoid leaking private + * data to the calling app, more detailed error reporting is not available. If the {@link + * IsolatedService} needs to report error stats to its backend, it should populate {@link + * ExecuteOutput} with error data for logging, and rely on Federated Analytics to aggregate + * the error reports. + * <p>If this method throws a {@link RuntimeException}, that is also reported to + * calling apps as an {@link OnDevicePersonalizationException} with error code {@link + * OnDevicePersonalizationException#ERROR_ISOLATED_SERVICE_FAILED}. */ - default void onExecute( - @NonNull ExecuteInput input, - @NonNull Consumer<ExecuteOutput> consumer - ) { + default void onExecute(@NonNull ExecuteInput input, @NonNull Consumer<ExecuteOutput> consumer) { consumer.accept(null); } /** - * Handle a completed download. The platform downloads content using the - * parameters defined in the package manifest of the {@link IsolatedService} - * and calls this function after the download is complete. + * Handles a completed download. The platform downloads content using the parameters defined in + * the package manifest of the {@link IsolatedService}, calls this function after the download + * is complete, and updates the REMOTE_DATA table from + * {@link IsolatedService#getRemoteData(RequestToken)} with the result of this method. * * @param input Download handler parameters. - * @param consumer Callback to be invoked on completion. + * @param consumer Callback that receives the result. Should be called with <code>null</code> on + * an error. If called with <code>null</code>, no updates are made to the REMOTE_DATA table. + * <p>If this method throws a {@link RuntimeException}, no updates are made to the + * REMOTE_DATA table. */ - default void onDownload( - @NonNull DownloadInput input, - @NonNull Consumer<DownloadOutput> consumer - ) { + default void onDownloadCompleted( + @NonNull DownloadCompletedInput input, + @NonNull Consumer<DownloadCompletedOutput> consumer) { consumer.accept(null); } /** - * Generate HTML for the results that were returned as a result of {@link #onExecute()}. - * The platform will render this HTML in a WebView inside a fenced frame. + * Generates HTML for the results that were returned as a result of + * {@link #onExecute(ExecuteInput, Consumer)}. Called when a client app calls + * {@link OnDevicePersonalizationManager#requestSurfacePackage(SurfacePackageToken, IBinder, int, int, int, java.util.concurrent.Executor, OutcomeReceiver)}. + * The platform will render this HTML in an {@link android.webkit.WebView} inside a fenced + * frame. * - * @param input Parameters for the renderContent request. - * @param consumer Callback to be invoked on completion. + * @param input Parameters for the render request. + * @param consumer Callback that receives the result. Should be called with <code>null</code> on + * an error. The error is propagated to the calling app as an {@link + * OnDevicePersonalizationException} with error code {@link + * OnDevicePersonalizationException#ERROR_ISOLATED_SERVICE_FAILED}. + * <p>If this method throws a {@link RuntimeException}, that is also reported to calling + * apps as an {@link OnDevicePersonalizationException} with error code {@link + * OnDevicePersonalizationException#ERROR_ISOLATED_SERVICE_FAILED}. */ - default void onRender( - @NonNull RenderInput input, - @NonNull Consumer<RenderOutput> consumer - ) { + default void onRender(@NonNull RenderInput input, @NonNull Consumer<RenderOutput> consumer) { consumer.accept(null); } /** - * Handle an event triggered by a request to a URL generated by {@link EventUrlProvider} - * and embedded in the HTML output returned by {@link #onRender()}. + * Handles an event triggered by a request to a platform-provided tracking URL {@link + * EventUrlProvider} that was embedded in the HTML output returned by + * {@link #onRender(RenderInput, Consumer)}. The platform updates the EVENTS table with + * {@link EventOutput#getEventLogRecord()}. * * @param input The parameters needed to compute event data. - * @param consumer Callback to be invoked on completion. + * @param consumer Callback that receives the result. Should be called with <code>null</code> on + * an error. If called with <code>null</code>, no data is written to the EVENTS table. + * <p>If this method throws a {@link RuntimeException}, no data is written to the EVENTS + * table. + * @hide + */ + default void onEvent( + @NonNull EventInput input, @NonNull Consumer<EventOutput> consumer) { + consumer.accept(null); + } + + /** + * Generate a single training example used for federated computation job. + * + * @param input The parameters needed to generate the training example. + * @param consumer Callback that receives the result. Should be called with <code>null</code> on + * an error. If called with <code>null</code>, no training examples is produced for this + * training session. <p>If this method throws a {@link RuntimeException}, no training + * examples are produced for this training session. + * @hide */ - default void onWebViewEvent( - @NonNull WebViewEventInput input, - @NonNull Consumer<WebViewEventOutput> consumer - ) { + default void onTrainingExample( + @NonNull TrainingExampleInput input, + @NonNull Consumer<TrainingExampleOutput> consumer) { consumer.accept(null); } } diff --git a/framework/java/android/adservices/ondevicepersonalization/JoinedLogRecord.aidl b/framework/java/android/adservices/ondevicepersonalization/JoinedLogRecord.aidl new file mode 100644 index 00000000..2139ba46 --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/JoinedLogRecord.aidl @@ -0,0 +1,19 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +parcelable JoinedLogRecord; diff --git a/framework/java/android/adservices/ondevicepersonalization/JoinedLogRecord.java b/framework/java/android/adservices/ondevicepersonalization/JoinedLogRecord.java new file mode 100644 index 00000000..451c35a4 --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/JoinedLogRecord.java @@ -0,0 +1,365 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +import android.annotation.IntRange; +import android.annotation.Nullable; +import android.content.ContentValues; +import android.os.Parcelable; + +import com.android.ondevicepersonalization.internal.util.AnnotationValidations; +import com.android.ondevicepersonalization.internal.util.DataClass; + +/** + * Input data to create example from. Represents a single joined log record. + * + * @hide + */ +@DataClass(genBuilder = true, genEqualsHashCode = true) +public class JoinedLogRecord implements Parcelable { + /** Time of the request in milliseconds */ + private final long mRequestTimeMillis; + + /** Time of the event in milliseconds */ + private final long mEventTimeMillis; + + /** + * The service-assigned type that identifies the event data. Must be >0 and <128. If type is 0, + * it is an Request-only row with no associated event. + */ + @IntRange(from = 0, to = 127) + private final int mType; + + /** Request data logged in a {@link RequestLogRecord} */ + @Nullable private ContentValues mRequestData = null; + + /** Event data logged in an {@link EventLogRecord} */ + @Nullable private ContentValues mEventData = null; + + + + // Code below generated by codegen v1.0.23. + // + // DO NOT MODIFY! + // CHECKSTYLE:OFF Generated code + // + // To regenerate run: + // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/JoinedLogRecord.java + // + // To exclude the generated code from IntelliJ auto-formatting enable (one-time): + // Settings > Editor > Code Style > Formatter Control + //@formatter:off + + + @DataClass.Generated.Member + /* package-private */ JoinedLogRecord( + long requestTimeMillis, + long eventTimeMillis, + @IntRange(from = 0, to = 127) int type, + @Nullable ContentValues requestData, + @Nullable ContentValues eventData) { + this.mRequestTimeMillis = requestTimeMillis; + this.mEventTimeMillis = eventTimeMillis; + this.mType = type; + AnnotationValidations.validate( + IntRange.class, null, mType, + "from", 0, + "to", 127); + this.mRequestData = requestData; + this.mEventData = eventData; + + // onConstructed(); // You can define this method to get a callback + } + + /** + * Time of the request in milliseconds + */ + @DataClass.Generated.Member + public long getRequestTimeMillis() { + return mRequestTimeMillis; + } + + /** + * Time of the event in milliseconds + */ + @DataClass.Generated.Member + public long getEventTimeMillis() { + return mEventTimeMillis; + } + + /** + * The service-assigned type that identifies the event data. Must be >0 and <128. If type is 0, + * it is an Request-only row with no associated event. + */ + @DataClass.Generated.Member + public @IntRange(from = 0, to = 127) int getType() { + return mType; + } + + /** + * Request data logged in a {@link RequestLogRecord} + */ + @DataClass.Generated.Member + public @Nullable ContentValues getRequestData() { + return mRequestData; + } + + /** + * Event data logged in an {@link EventLogRecord} + */ + @DataClass.Generated.Member + public @Nullable ContentValues getEventData() { + return mEventData; + } + + @Override + @DataClass.Generated.Member + public boolean equals(@Nullable Object o) { + // You can override field equality logic by defining either of the methods like: + // boolean fieldNameEquals(JoinedLogRecord other) { ... } + // boolean fieldNameEquals(FieldType otherValue) { ... } + + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + @SuppressWarnings("unchecked") + JoinedLogRecord that = (JoinedLogRecord) o; + //noinspection PointlessBooleanExpression + return true + && mRequestTimeMillis == that.mRequestTimeMillis + && mEventTimeMillis == that.mEventTimeMillis + && mType == that.mType + && java.util.Objects.equals(mRequestData, that.mRequestData) + && java.util.Objects.equals(mEventData, that.mEventData); + } + + @Override + @DataClass.Generated.Member + public int hashCode() { + // You can override field hashCode logic by defining methods like: + // int fieldNameHashCode() { ... } + + int _hash = 1; + _hash = 31 * _hash + Long.hashCode(mRequestTimeMillis); + _hash = 31 * _hash + Long.hashCode(mEventTimeMillis); + _hash = 31 * _hash + mType; + _hash = 31 * _hash + java.util.Objects.hashCode(mRequestData); + _hash = 31 * _hash + java.util.Objects.hashCode(mEventData); + return _hash; + } + + @Override + @DataClass.Generated.Member + public void writeToParcel(@android.annotation.NonNull android.os.Parcel dest, int flags) { + // You can override field parcelling by defining methods like: + // void parcelFieldName(Parcel dest, int flags) { ... } + + byte flg = 0; + if (mRequestData != null) flg |= 0x8; + if (mEventData != null) flg |= 0x10; + dest.writeByte(flg); + dest.writeLong(mRequestTimeMillis); + dest.writeLong(mEventTimeMillis); + dest.writeInt(mType); + if (mRequestData != null) dest.writeTypedObject(mRequestData, flags); + if (mEventData != null) dest.writeTypedObject(mEventData, flags); + } + + @Override + @DataClass.Generated.Member + public int describeContents() { return 0; } + + /** @hide */ + @SuppressWarnings({"unchecked", "RedundantCast"}) + @DataClass.Generated.Member + protected JoinedLogRecord(@android.annotation.NonNull android.os.Parcel in) { + // You can override field unparcelling by defining methods like: + // static FieldType unparcelFieldName(Parcel in) { ... } + + byte flg = in.readByte(); + long requestTimeMillis = in.readLong(); + long eventTimeMillis = in.readLong(); + int type = in.readInt(); + ContentValues requestData = (flg & 0x8) == 0 ? null : (ContentValues) in.readTypedObject(ContentValues.CREATOR); + ContentValues eventData = (flg & 0x10) == 0 ? null : (ContentValues) in.readTypedObject(ContentValues.CREATOR); + + this.mRequestTimeMillis = requestTimeMillis; + this.mEventTimeMillis = eventTimeMillis; + this.mType = type; + AnnotationValidations.validate( + IntRange.class, null, mType, + "from", 0, + "to", 127); + this.mRequestData = requestData; + this.mEventData = eventData; + + // onConstructed(); // You can define this method to get a callback + } + + @DataClass.Generated.Member + public static final @android.annotation.NonNull Parcelable.Creator<JoinedLogRecord> CREATOR + = new Parcelable.Creator<JoinedLogRecord>() { + @Override + public JoinedLogRecord[] newArray(int size) { + return new JoinedLogRecord[size]; + } + + @Override + public JoinedLogRecord createFromParcel(@android.annotation.NonNull android.os.Parcel in) { + return new JoinedLogRecord(in); + } + }; + + /** + * A builder for {@link JoinedLogRecord} + * @hide + */ + @SuppressWarnings("WeakerAccess") + @DataClass.Generated.Member + public static class Builder { + + private long mRequestTimeMillis; + private long mEventTimeMillis; + private @IntRange(from = 0, to = 127) int mType; + private @Nullable ContentValues mRequestData; + private @Nullable ContentValues mEventData; + + private long mBuilderFieldsSet = 0L; + + public Builder() { + } + + /** + * Creates a new Builder. + * + * @param requestTimeMillis + * Time of the request in milliseconds + * @param eventTimeMillis + * Time of the event in milliseconds + * @param type + * The service-assigned type that identifies the event data. Must be >0 and <128. If type is 0, + * it is an Request-only row with no associated event. + */ + public Builder( + long requestTimeMillis, + long eventTimeMillis, + @IntRange(from = 0, to = 127) int type) { + mRequestTimeMillis = requestTimeMillis; + mEventTimeMillis = eventTimeMillis; + mType = type; + AnnotationValidations.validate( + IntRange.class, null, mType, + "from", 0, + "to", 127); + } + + /** + * Time of the request in milliseconds + */ + @DataClass.Generated.Member + public @android.annotation.NonNull Builder setRequestTimeMillis(long value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x1; + mRequestTimeMillis = value; + return this; + } + + /** + * Time of the event in milliseconds + */ + @DataClass.Generated.Member + public @android.annotation.NonNull Builder setEventTimeMillis(long value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x2; + mEventTimeMillis = value; + return this; + } + + /** + * The service-assigned type that identifies the event data. Must be >0 and <128. If type is 0, + * it is an Request-only row with no associated event. + */ + @DataClass.Generated.Member + public @android.annotation.NonNull Builder setType(@IntRange(from = 0, to = 127) int value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x4; + mType = value; + return this; + } + + /** + * Request data logged in a {@link RequestLogRecord} + */ + @DataClass.Generated.Member + public @android.annotation.NonNull Builder setRequestData(@android.annotation.NonNull ContentValues value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x8; + mRequestData = value; + return this; + } + + /** + * Event data logged in an {@link EventLogRecord} + */ + @DataClass.Generated.Member + public @android.annotation.NonNull Builder setEventData(@android.annotation.NonNull ContentValues value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x10; + mEventData = value; + return this; + } + + /** Builds the instance. This builder should not be touched after calling this! */ + public @android.annotation.NonNull JoinedLogRecord build() { + checkNotUsed(); + mBuilderFieldsSet |= 0x20; // Mark builder used + + if ((mBuilderFieldsSet & 0x8) == 0) { + mRequestData = null; + } + if ((mBuilderFieldsSet & 0x10) == 0) { + mEventData = null; + } + JoinedLogRecord o = new JoinedLogRecord( + mRequestTimeMillis, + mEventTimeMillis, + mType, + mRequestData, + mEventData); + return o; + } + + private void checkNotUsed() { + if ((mBuilderFieldsSet & 0x20) != 0) { + throw new IllegalStateException( + "This Builder should not be reused. Use a new Builder instance instead"); + } + } + } + + @DataClass.Generated( + time = 1695413878624L, + codegenVersion = "1.0.23", + sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/JoinedLogRecord.java", + inputSignatures = "private final long mRequestTimeMillis\nprivate final long mEventTimeMillis\nprivate final @android.annotation.IntRange int mType\nprivate @android.annotation.Nullable android.content.ContentValues mRequestData\nprivate @android.annotation.Nullable android.content.ContentValues mEventData\nclass JoinedLogRecord extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") + @Deprecated + private void __metadata() {} + + + //@formatter:on + // End of generated code + +} diff --git a/framework/java/android/adservices/ondevicepersonalization/KeyValueStore.java b/framework/java/android/adservices/ondevicepersonalization/KeyValueStore.java index e412aa9e..b50f70e2 100644 --- a/framework/java/android/adservices/ondevicepersonalization/KeyValueStore.java +++ b/framework/java/android/adservices/ondevicepersonalization/KeyValueStore.java @@ -16,32 +16,39 @@ package android.adservices.ondevicepersonalization; +import static android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS; + +import android.annotation.FlaggedApi; import android.annotation.NonNull; import android.annotation.Nullable; +import android.annotation.WorkerThread; import java.util.Set; /** - * Data Access Object for the REMOTE_DATA table. The REMOTE_DATA table is a immutable - * data store that contains data that has been downloaded by the ODP platform from - * the vendor endpoint that is declared in the package manifest. + * An interface to a read-only key-value store. + * + * Used as a Data Access Object for the REMOTE_DATA table. + * + * @see IsolatedService#getRemoteData(RequestToken) * - * @hide */ +@FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public interface KeyValueStore { /** - * Looks up a key in the REMOTE_DATA table. + * Looks up a key in a read-only store. * * @param key The key to look up. * @return the value to which the specified key is mapped, * or null if there contains no mapping for the key. + * */ - @Nullable byte[] get(@NonNull String key) throws OnDevicePersonalizationException; + @WorkerThread + @Nullable byte[] get(@NonNull String key); /** * Returns a Set view of the keys contained in the REMOTE_DATA table. - * - * @return a Set view of the keys contained in the REMOTE_DATA table. */ - @NonNull Set<String> keySet() throws OnDevicePersonalizationException; + @WorkerThread + @NonNull Set<String> keySet(); } diff --git a/framework/java/android/adservices/ondevicepersonalization/LocalDataImpl.java b/framework/java/android/adservices/ondevicepersonalization/LocalDataImpl.java index dbfc89ae..e87f4ac7 100644 --- a/framework/java/android/adservices/ondevicepersonalization/LocalDataImpl.java +++ b/framework/java/android/adservices/ondevicepersonalization/LocalDataImpl.java @@ -32,7 +32,6 @@ import java.util.Objects; import java.util.Set; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; -import java.util.concurrent.TimeUnit; /** @hide */ public class LocalDataImpl implements MutableKeyValueStore { @@ -41,15 +40,13 @@ public class LocalDataImpl implements MutableKeyValueStore { @NonNull IDataAccessService mDataAccessService; - private static final long ASYNC_TIMEOUT_MS = 1000; - /** @hide */ public LocalDataImpl(@NonNull IDataAccessService binder) { mDataAccessService = Objects.requireNonNull(binder); } @Override @Nullable - public byte[] get(@NonNull String key) throws OnDevicePersonalizationException { + public byte[] get(@NonNull String key) { Objects.requireNonNull(key); Bundle params = new Bundle(); params.putStringArray(Constants.EXTRA_LOOKUP_KEYS, new String[]{key}); @@ -57,7 +54,7 @@ public class LocalDataImpl implements MutableKeyValueStore { } @Override @Nullable - public byte[] put(@NonNull String key, byte[] value) throws OnDevicePersonalizationException { + public byte[] put(@NonNull String key, byte[] value) { Objects.requireNonNull(key); Bundle params = new Bundle(); params.putStringArray(Constants.EXTRA_LOOKUP_KEYS, new String[]{key}); @@ -66,48 +63,38 @@ public class LocalDataImpl implements MutableKeyValueStore { } @Override @Nullable - public byte[] remove(@NonNull String key) throws OnDevicePersonalizationException { + public byte[] remove(@NonNull String key) { Objects.requireNonNull(key); Bundle params = new Bundle(); params.putStringArray(Constants.EXTRA_LOOKUP_KEYS, new String[]{key}); return handleLookupRequest(Constants.DATA_ACCESS_OP_LOCAL_DATA_REMOVE, key, params); } - private byte[] handleLookupRequest(int op, String key, Bundle params) - throws OnDevicePersonalizationException { + private byte[] handleLookupRequest(int op, String key, Bundle params) { Bundle result = handleAsyncRequest(op, params); - if (null == result) { - sLogger.e(TAG + ": Timed out waiting for result of lookup for op: " + op); - throw new OnDevicePersonalizationException(Constants.STATUS_INTERNAL_ERROR); - } HashMap<String, byte[]> data = result.getSerializable( Constants.EXTRA_RESULT, HashMap.class); if (null == data) { sLogger.e(TAG + ": No EXTRA_RESULT was present in bundle"); - throw new OnDevicePersonalizationException(Constants.STATUS_INTERNAL_ERROR); + throw new IllegalStateException("Bundle missing EXTRA_RESULT."); } return data.get(key); } @Override @NonNull - public Set<String> keySet() throws OnDevicePersonalizationException { + public Set<String> keySet() { Bundle result = handleAsyncRequest(Constants.DATA_ACCESS_OP_LOCAL_DATA_KEYSET, Bundle.EMPTY); - if (null == result) { - sLogger.e(TAG + ": Timed out waiting for result of keySet"); - throw new OnDevicePersonalizationException(Constants.STATUS_INTERNAL_ERROR); - } HashSet<String> resultSet = result.getSerializable(Constants.EXTRA_RESULT, HashSet.class); if (null == resultSet) { sLogger.e(TAG + ": No EXTRA_RESULT was present in bundle"); - throw new OnDevicePersonalizationException(Constants.STATUS_INTERNAL_ERROR); + throw new IllegalStateException("Bundle missing EXTRA_RESULT."); } return resultSet; } - private Bundle handleAsyncRequest(int op, Bundle params) - throws OnDevicePersonalizationException { + private Bundle handleAsyncRequest(int op, Bundle params) { try { BlockingQueue<Bundle> asyncResult = new ArrayBlockingQueue<>(1); mDataAccessService.onRequest( @@ -116,7 +103,11 @@ public class LocalDataImpl implements MutableKeyValueStore { new IDataAccessServiceCallback.Stub() { @Override public void onSuccess(@NonNull Bundle result) { - asyncResult.add(result); + if (result != null) { + asyncResult.add(result); + } else { + asyncResult.add(Bundle.EMPTY); + } } @Override @@ -124,10 +115,10 @@ public class LocalDataImpl implements MutableKeyValueStore { asyncResult.add(Bundle.EMPTY); } }); - return asyncResult.poll(ASYNC_TIMEOUT_MS, TimeUnit.MILLISECONDS); + return asyncResult.take(); } catch (InterruptedException | RemoteException e) { sLogger.e(TAG + ": Failed to retrieve result from localData", e); - throw new OnDevicePersonalizationException(Constants.STATUS_INTERNAL_ERROR); + throw new IllegalStateException(e); } } } diff --git a/framework/java/android/adservices/ondevicepersonalization/LogReader.java b/framework/java/android/adservices/ondevicepersonalization/LogReader.java new file mode 100644 index 00000000..7d54dd55 --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/LogReader.java @@ -0,0 +1,144 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +import android.adservices.ondevicepersonalization.aidl.IDataAccessService; +import android.adservices.ondevicepersonalization.aidl.IDataAccessServiceCallback; +import android.annotation.NonNull; +import android.annotation.WorkerThread; +import android.os.Bundle; +import android.os.Parcelable; +import android.os.RemoteException; + +import com.android.ondevicepersonalization.internal.util.LoggerFactory; +import com.android.ondevicepersonalization.internal.util.OdpParceledListSlice; + +import java.util.List; +import java.util.Objects; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; + +/** + * An interface to a read logs from REQUESTS and EVENTS + * + * Used as a Data Access Object for the REQUESTS and EVENTS table. + * + * @see IsolatedService#getLogReader(RequestToken) + * + * @hide + */ +public class LogReader { + private static final String TAG = "LogReader"; + private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); + + @NonNull + private final IDataAccessService mDataAccessService; + + /** @hide */ + public LogReader(@NonNull IDataAccessService binder) { + mDataAccessService = Objects.requireNonNull(binder); + } + + + /** + * Retrieves a List of RequestLogRecords written by this IsolatedService within + * the specified time range. + */ + @WorkerThread + @NonNull + public List<RequestLogRecord> getRequests(long startTimeMillis, long endTimeMillis) { + if (endTimeMillis <= startTimeMillis) { + throw new IllegalArgumentException( + "endTimeMillis must be greater than startTimeMillis"); + } + if (startTimeMillis < 0) { + throw new IllegalArgumentException("startTimeMillis must be greater than 0"); + } + Bundle params = new Bundle(); + params.putLongArray(Constants.EXTRA_LOOKUP_KEYS, + new long[]{startTimeMillis, endTimeMillis}); + OdpParceledListSlice<RequestLogRecord> result = + handleListLookupRequest(Constants.DATA_ACCESS_OP_GET_REQUESTS, params); + return result.getList(); + } + + /** + * Retrieves a List of EventLogRecord with its corresponding RequestLogRecord written by this + * IsolatedService within the specified time range. + */ + @WorkerThread + @NonNull + public List<EventLogRecord> getJoinedEvents(long startTimeMillis, long endTimeMillis) { + if (endTimeMillis <= startTimeMillis) { + throw new IllegalArgumentException( + "endTimeMillis must be greater than startTimeMillis"); + } + if (startTimeMillis < 0) { + throw new IllegalArgumentException("startTimeMillis must be greater than 0"); + } + Bundle params = new Bundle(); + params.putLongArray(Constants.EXTRA_LOOKUP_KEYS, + new long[]{startTimeMillis, endTimeMillis}); + OdpParceledListSlice<EventLogRecord> result = + handleListLookupRequest(Constants.DATA_ACCESS_OP_GET_JOINED_EVENTS, params); + return result.getList(); + } + + private Bundle handleAsyncRequest(int op, Bundle params) { + try { + BlockingQueue<Bundle> asyncResult = new ArrayBlockingQueue<>(1); + mDataAccessService.onRequest( + op, + params, + new IDataAccessServiceCallback.Stub() { + @Override + public void onSuccess(@NonNull Bundle result) { + if (result != null) { + asyncResult.add(result); + } else { + asyncResult.add(Bundle.EMPTY); + } + } + + @Override + public void onError(int errorCode) { + asyncResult.add(Bundle.EMPTY); + } + }); + return asyncResult.take(); + } catch (InterruptedException | RemoteException e) { + sLogger.e(TAG + ": Failed to retrieve result", e); + throw new IllegalStateException(e); + } + } + + private <T extends Parcelable> OdpParceledListSlice<T> handleListLookupRequest(int op, + Bundle params) { + Bundle result = handleAsyncRequest(op, params); + try { + OdpParceledListSlice<T> data = result.getParcelable( + Constants.EXTRA_RESULT, OdpParceledListSlice.class); + if (null == data) { + sLogger.e(TAG + ": No EXTRA_RESULT was present in bundle"); + throw new IllegalStateException("Bundle missing EXTRA_RESULT."); + } + return data; + } catch (ClassCastException e) { + throw new IllegalStateException("Failed to retrieve parceled list"); + } + } +} diff --git a/framework/java/android/adservices/ondevicepersonalization/MutableKeyValueStore.java b/framework/java/android/adservices/ondevicepersonalization/MutableKeyValueStore.java index 5e350d97..e63bcc6f 100644 --- a/framework/java/android/adservices/ondevicepersonalization/MutableKeyValueStore.java +++ b/framework/java/android/adservices/ondevicepersonalization/MutableKeyValueStore.java @@ -16,34 +16,42 @@ package android.adservices.ondevicepersonalization; +import static android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS; + +import android.annotation.FlaggedApi; import android.annotation.NonNull; import android.annotation.Nullable; +import android.annotation.WorkerThread; /** - * Data Access Object for the LOCAL_DATA table. The LOCAL_DATA table is a mutable - * data store that contains data that has been stored locally by the vendor. + * An interface to a read-write key-value store. + * + * Used as a Data Access Object for the LOCAL_DATA table. + * + * @see IsolatedService#getLocalData(RequestToken) * - * @hide */ +@FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public interface MutableKeyValueStore extends KeyValueStore { /** - * Associates the specified value with the specified key in LOCAL_DATA. - * If LOCAL_DATA previously contained a mapping for the key, the old value is replaced. + * Associates the specified value with the specified key. + * If a value already exists for that key, the old value is replaced. * * @param key key with which the specified value is to be associated * @param value value to be associated with the specified key * * @return the previous value associated with key, or null if there was no mapping for key. */ - @Nullable byte[] put(@NonNull String key, @NonNull byte[] value) - throws OnDevicePersonalizationException; + @WorkerThread + @Nullable byte[] put(@NonNull String key, @NonNull byte[] value); /** - * Removes the mapping for the specified key from LOCAL_DATA if present. + * Removes the mapping for the specified key. * - * @param key key whose mapping is to be removed from the LOCAL_DATA + * @param key key whose mapping is to be removed * * @return the previous value associated with key, or null if there was no mapping for key. */ - @Nullable byte[] remove(@NonNull String key) throws OnDevicePersonalizationException; + @WorkerThread + @Nullable byte[] remove(@NonNull String key); } diff --git a/framework/java/android/adservices/ondevicepersonalization/OnDevicePersonalizationConfigManager.java b/framework/java/android/adservices/ondevicepersonalization/OnDevicePersonalizationConfigManager.java new file mode 100644 index 00000000..f6442fec --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/OnDevicePersonalizationConfigManager.java @@ -0,0 +1,197 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +import static android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS; +import static android.adservices.ondevicepersonalization.OnDevicePersonalizationPermissions.MODIFY_ONDEVICEPERSONALIZATION_STATE; + +import android.adservices.ondevicepersonalization.aidl.IOnDevicePersonalizationConfigService; +import android.adservices.ondevicepersonalization.aidl.IOnDevicePersonalizationConfigServiceCallback; +import android.annotation.CallbackExecutor; +import android.annotation.FlaggedApi; +import android.annotation.NonNull; +import android.annotation.Nullable; +import android.annotation.RequiresPermission; +import android.annotation.SystemApi; +import android.content.ComponentName; +import android.content.Context; +import android.content.Intent; +import android.content.ServiceConnection; +import android.content.pm.ResolveInfo; +import android.content.pm.ServiceInfo; +import android.os.Binder; +import android.os.IBinder; +import android.os.OutcomeReceiver; +import android.os.RemoteException; + +import com.android.ondevicepersonalization.internal.util.LoggerFactory; + +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; + +/** + * OnDevicePersonalizationConfigManager provides system APIs + * for privileged APKs to control OnDevicePersonalization's enablement status. + * + * @hide + */ +@SystemApi +@FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) +public class OnDevicePersonalizationConfigManager { + /** @hide */ + public static final String ON_DEVICE_PERSONALIZATION_CONFIG_SERVICE = + "on_device_personalization_config_service"; + private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); + private static final String TAG = "OnDevicePersonalizationConfigManager"; + private static final String ODP_CONFIG_SERVICE_INTENT = + "android.OnDevicePersonalizationConfigService"; + private static final int BIND_SERVICE_TIMEOUT_SEC = 5; + private final Context mContext; + private final CountDownLatch mConnectionLatch = new CountDownLatch(1); + private boolean mBound = false; + private IOnDevicePersonalizationConfigService mService = null; + private final ServiceConnection mConnection = new ServiceConnection() { + @Override + public void onServiceConnected(ComponentName name, IBinder binder) { + mService = IOnDevicePersonalizationConfigService.Stub.asInterface(binder); + mBound = true; + mConnectionLatch.countDown(); + } + + @Override + public void onNullBinding(ComponentName name) { + mBound = false; + mConnectionLatch.countDown(); + } + + @Override + public void onServiceDisconnected(ComponentName name) { + mService = null; + mBound = false; + } + }; + + /** @hide */ + public OnDevicePersonalizationConfigManager(@NonNull Context context) { + mContext = context; + } + + /** + * API users are expected to call this to modify personalization status for + * On Device Personalization. The status is persisted both in memory and to the disk. + * When reboot, the in-memory status will be restored from the disk. + * Personalization is disabled by default. + * + * @param enabled boolean whether On Device Personalization should be enabled. + * @param executor The {@link Executor} on which to invoke the callback. + * @param receiver This either returns null on success or {@link Exception} on failure. + * + * In case of an error, the receiver returns one of the following exceptions: + * Returns an {@link IllegalStateException} if the callback is unable to send back results. + * Returns a {@link SecurityException} if the caller is unauthorized to modify + * personalization status. + */ + @FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) + @RequiresPermission(MODIFY_ONDEVICEPERSONALIZATION_STATE) + public void setPersonalizationEnabled(boolean enabled, + @NonNull @CallbackExecutor Executor executor, + @NonNull OutcomeReceiver<Void, Exception> receiver) { + + try { + bindService(executor); + + mService.setPersonalizationStatus(enabled, + new IOnDevicePersonalizationConfigServiceCallback.Stub() { + @Override + public void onSuccess() { + executor.execute(() -> { + Binder.clearCallingIdentity(); + receiver.onResult(null); + }); + } + + @Override + public void onFailure(int errorCode) { + executor.execute(() -> { + sLogger.w(TAG + ": Unexpected failure from ODP" + + "config service with error code: " + errorCode); + Binder.clearCallingIdentity(); + receiver.onError(new IllegalStateException("Unexpected failure.")); + }); + } + }); + } catch (IllegalStateException | InterruptedException | RemoteException e) { + executor.execute(() -> { + receiver.onError(new IllegalStateException(e)); + }); + } catch (SecurityException e) { + executor.execute(() -> { + sLogger.w(TAG + ": Unauthorized call to ODP config service."); + receiver.onError(e); + }); + } + } + + private void bindService(@NonNull Executor executor) throws InterruptedException { + if (!mBound) { + Intent intent = new Intent(ODP_CONFIG_SERVICE_INTENT); + ComponentName serviceComponent = resolveService(intent); + if (serviceComponent == null) { + sLogger.e(TAG + ": Invalid component for ODP config service"); + return; + } + + intent.setComponent(serviceComponent); + boolean r = mContext.bindService( + intent, Context.BIND_AUTO_CREATE, executor, mConnection); + if (!r) { + return; + } + mConnectionLatch.await(BIND_SERVICE_TIMEOUT_SEC, TimeUnit.SECONDS); + } + } + + /** + * Find the ComponentName of the service, given its intent. + * + * @return ComponentName of the service. Null if the service is not found. + */ + @Nullable + private ComponentName resolveService(@NonNull Intent intent) { + List<ResolveInfo> services = mContext.getPackageManager().queryIntentServices(intent, 0); + if (services == null || services.isEmpty()) { + sLogger.e(TAG + ": Failed to find OnDevicePersonalizationConfigService"); + return null; + } + + for (int i = 0; i < services.size(); i++) { + ServiceInfo serviceInfo = services.get(i).serviceInfo; + if (serviceInfo == null) { + sLogger.e(TAG + ": Failed to find serviceInfo " + + "for OnDevicePersonalizationConfigService."); + return null; + } + // There should only be one matching service inside the given package. + // If there's more than one, return the first one found. + return new ComponentName(serviceInfo.packageName, serviceInfo.name); + } + sLogger.e(TAG + ": Didn't find any matching OnDevicePersonalizationConfigService."); + return null; + } +} diff --git a/framework/java/android/adservices/ondevicepersonalization/OnDevicePersonalizationDebugManager.java b/framework/java/android/adservices/ondevicepersonalization/OnDevicePersonalizationDebugManager.java new file mode 100644 index 00000000..d0558385 --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/OnDevicePersonalizationDebugManager.java @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +import android.adservices.ondevicepersonalization.aidl.IOnDevicePersonalizationDebugService; +import android.content.Context; +import android.os.RemoteException; + +import com.android.federatedcompute.internal.util.AbstractServiceBinder; + +import java.util.List; +import java.util.Objects; +import java.util.concurrent.Executors; + +/** + * Provides APIs to support testing. + * @hide + */ +public class OnDevicePersonalizationDebugManager { + + private static final String INTENT_FILTER_ACTION = + "android.OnDevicePersonalizationDebugService"; + private static final String ODP_MANAGING_SERVICE_PACKAGE_SUFFIX = + "com.android.ondevicepersonalization.services"; + + private static final String ALT_ODP_MANAGING_SERVICE_PACKAGE_SUFFIX = + "com.google.android.ondevicepersonalization.services"; + + private final AbstractServiceBinder<IOnDevicePersonalizationDebugService> mServiceBinder; + + public OnDevicePersonalizationDebugManager(Context context) { + mServiceBinder = + AbstractServiceBinder.getServiceBinderByIntent( + context, + INTENT_FILTER_ACTION, + List.of( + ODP_MANAGING_SERVICE_PACKAGE_SUFFIX, + ALT_ODP_MANAGING_SERVICE_PACKAGE_SUFFIX), + IOnDevicePersonalizationDebugService.Stub::asInterface); + } + + /** Returns whether the service is enabled. */ + public Boolean isEnabled() { + try { + IOnDevicePersonalizationDebugService service = Objects.requireNonNull( + mServiceBinder.getService(Executors.newSingleThreadExecutor())); + boolean result = service.isEnabled(); + mServiceBinder.unbindFromService(); + return result; + } catch (RemoteException e) { + throw e.rethrowAsRuntimeException(); + } + } +} diff --git a/framework/java/android/adservices/ondevicepersonalization/OnDevicePersonalizationException.java b/framework/java/android/adservices/ondevicepersonalization/OnDevicePersonalizationException.java index 54a251af..e54f76bf 100644 --- a/framework/java/android/adservices/ondevicepersonalization/OnDevicePersonalizationException.java +++ b/framework/java/android/adservices/ondevicepersonalization/OnDevicePersonalizationException.java @@ -16,36 +16,48 @@ package android.adservices.ondevicepersonalization; -import android.annotation.NonNull; +import static android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS; + +import android.annotation.FlaggedApi; +import android.annotation.IntDef; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; /** * Exception thrown by OnDevicePersonalization APIs. * - * @hide */ +@FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public class OnDevicePersonalizationException extends Exception { - private final int mErrorCode; - - public OnDevicePersonalizationException(int errorCode) { - this(errorCode, ""); - } - - public OnDevicePersonalizationException(int errorCode, @NonNull String errorMessage) { - super("Error code: " + errorCode + " message: " + errorMessage); - mErrorCode = errorCode; - } - - public OnDevicePersonalizationException(int errorCode, @NonNull Throwable cause) { - this(errorCode, "", cause); - } - - public OnDevicePersonalizationException( - int errorCode, @NonNull String errorMessage, @NonNull Throwable cause) { - super("Error code: " + errorCode + " message: " + errorMessage, cause); + /** + * The {@link IsolatedService} that was invoked failed to run. + */ + public static final int ERROR_ISOLATED_SERVICE_FAILED = 1; + + /** + * Personalization is disabled. + * @hide + */ + public static final int ERROR_PERSONALIZATION_DISABLED = 2; + + /** @hide */ + @IntDef(prefix = "ERROR_", value = { + ERROR_ISOLATED_SERVICE_FAILED, + ERROR_PERSONALIZATION_DISABLED + }) + @Retention(RetentionPolicy.SOURCE) + public @interface ErrorCode {} + + private final @ErrorCode int mErrorCode; + + /** @hide */ + public OnDevicePersonalizationException(@ErrorCode int errorCode) { mErrorCode = errorCode; } - public int getErrorCode() { + /** Returns the error code for this exception. */ + public @ErrorCode int getErrorCode() { return mErrorCode; } } diff --git a/framework/java/android/adservices/ondevicepersonalization/OnDevicePersonalizationManager.java b/framework/java/android/adservices/ondevicepersonalization/OnDevicePersonalizationManager.java index beb44965..5c0e9b5c 100644 --- a/framework/java/android/adservices/ondevicepersonalization/OnDevicePersonalizationManager.java +++ b/framework/java/android/adservices/ondevicepersonalization/OnDevicePersonalizationManager.java @@ -16,32 +16,31 @@ package android.adservices.ondevicepersonalization; +import static android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS; + import android.adservices.ondevicepersonalization.aidl.IExecuteCallback; import android.adservices.ondevicepersonalization.aidl.IOnDevicePersonalizationManagingService; import android.adservices.ondevicepersonalization.aidl.IRequestSurfacePackageCallback; import android.annotation.CallbackExecutor; +import android.annotation.FlaggedApi; import android.annotation.NonNull; -import android.annotation.Nullable; import android.content.ComponentName; import android.content.Context; -import android.content.Intent; -import android.content.ServiceConnection; import android.content.pm.PackageManager; -import android.content.pm.ResolveInfo; import android.os.IBinder; import android.os.OutcomeReceiver; import android.os.PersistableBundle; import android.os.RemoteException; +import android.os.SystemClock; import android.view.SurfaceControlViewHost; +import com.android.federatedcompute.internal.util.AbstractServiceBinder; import com.android.modules.utils.build.SdkLevel; -import com.android.ondevicepersonalization.internal.util.LoggerFactory; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.CountDownLatch; +import java.util.Objects; import java.util.concurrent.Executor; -import java.util.concurrent.TimeUnit; // TODO(b/289102463): Add a link to the public ODP developer documentation. /** @@ -49,74 +48,48 @@ import java.util.concurrent.TimeUnit; * {@link IsolatedService} in an isolated process and interact with it. * * An app can request an {@link IsolatedService} to generate content for display - * within a {@link SurfaceView} within the app's view hierarchy, and also write persistent results - * to on-device storage which can be consumed by Federated Analytics for cross-device statistical - * analysis or by Federated Learning for model training. The displayed content and the persistent - * output are both not directly accessible by the calling app. - * - * @hide + * within an {@link android.view.SurfaceView} within the app's view hierarchy, and also write + * persistent results to on-device storage which can be consumed by Federated Analytics for + * cross-device statistical analysis or by Federated Learning for model training. The displayed + * content and the persistent output are both not directly accessible by the calling app. */ +@FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public class OnDevicePersonalizationManager { /** @hide */ public static final String ON_DEVICE_PERSONALIZATION_SERVICE = "on_device_personalization_service"; + private static final String INTENT_FILTER_ACTION = "android.OnDevicePersonalizationService"; + private static final String ODP_MANAGING_SERVICE_PACKAGE_SUFFIX = + "com.android.ondevicepersonalization.services"; - private boolean mBound = false; - private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); - private static final String TAG = "OdpManager"; + private static final String ALT_ODP_MANAGING_SERVICE_PACKAGE_SUFFIX = + "com.google.android.ondevicepersonalization.services"; - private IOnDevicePersonalizationManagingService mService; + private final AbstractServiceBinder<IOnDevicePersonalizationManagingService> mServiceBinder; private final Context mContext; /** @hide */ public OnDevicePersonalizationManager(Context context) { mContext = context; - } - - private final CountDownLatch mConnectionLatch = new CountDownLatch(1); - - private final ServiceConnection mConnection = - new ServiceConnection() { - @Override - public void onServiceConnected(ComponentName name, IBinder service) { - mService = IOnDevicePersonalizationManagingService.Stub.asInterface(service); - mBound = true; - mConnectionLatch.countDown(); - } - - @Override - public void onNullBinding(ComponentName name) { - mBound = false; - mConnectionLatch.countDown(); - } - - @Override - public void onServiceDisconnected(ComponentName name) { - mService = null; - mBound = false; - } - }; - - private static final int BIND_SERVICE_TIMEOUT_SEC = 5; - private static final String VERSION = "1.0"; - - /** - * Gets OnDevicePersonalization version. - * This function is a temporary place holder. It will be removed when new APIs are added. - * - * @hide - */ - public String getVersion() { - return VERSION; + this.mServiceBinder = + AbstractServiceBinder.getServiceBinderByIntent( + context, + INTENT_FILTER_ACTION, + List.of( + ODP_MANAGING_SERVICE_PACKAGE_SUFFIX, + ALT_ODP_MANAGING_SERVICE_PACKAGE_SUFFIX), + SdkLevel.isAtLeastU() ? Context.BIND_ALLOW_ACTIVITY_STARTS : 0, + IOnDevicePersonalizationManagingService.Stub::asInterface); } /** - * Executes a {@link IsolatedService} in the OnDevicePersonalization sandbox. The + * Executes an {@link IsolatedService} in the OnDevicePersonalization sandbox. The * platform binds to the specified {@link IsolatedService} in an isolated process - * and calls {@link IsolatedService#onExecute()} with the caller-provided - * parameters. When the {@link IsolatedService} finishes execution, the platform - * returns tokens that refer to the results from the service to the caller. These tokens can - * be subsequently used to display results in a {@link SurfaceView} within the calling app. + * and calls {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)} + * with the caller-provided parameters. When the {@link IsolatedService} finishes execution, + * the platform returns tokens that refer to the results from the service to the caller. + * These tokens can be subsequently used to display results in a + * {@link android.view.SurfaceView} within the calling app. * * @param handler The {@link ComponentName} of the {@link IsolatedService}. * @param params a {@link PersistableBundle} that is passed from the calling app to the @@ -127,10 +100,18 @@ public class OnDevicePersonalizationManager { * an opaque reference to a {@link RenderingConfig} returned by an * {@link IsolatedService}, or an {@link Exception} on failure. The returned * {@link SurfacePackageToken} objects can be used in a subsequent - * {@link requestSurfacePackage} call to display the result in a view. The calling app and + * {@link #requestSurfacePackage(SurfacePackageToken, IBinder, int, int, int, Executor, + * OutcomeReceiver)} call to display the result in a view. The calling app and * the {@link IsolatedService} must agree on the expected size of this list. * An entry in the returned list of {@link SurfacePackageToken} objects may be null to * indicate that the service has no output for that specific surface. + * + * In case of an error, the receiver returns one of the following exceptions: + * Returns a {@link android.content.pm.PackageManager.NameNotFoundException} if the handler + * package is not installed or does not have a valid ODP manifest. + * Returns {@link ClassNotFoundException} if the handler class is not found. + * Returns an {@link OnDevicePersonalizationException} if execution of the handler fails. + * @hide */ public void execute( @NonNull ComponentName handler, @@ -138,8 +119,15 @@ public class OnDevicePersonalizationManager { @NonNull @CallbackExecutor Executor executor, @NonNull OutcomeReceiver<List<SurfacePackageToken>, Exception> receiver ) { + Objects.requireNonNull(handler); + Objects.requireNonNull(params); + Objects.requireNonNull(executor); + Objects.requireNonNull(receiver); + long startTimeMillis = SystemClock.elapsedRealtime(); + try { - bindService(executor); + final IOnDevicePersonalizationManagingService service = + mServiceBinder.getService(executor); IExecuteCallback callbackWrapper = new IExecuteCallback.Stub() { @Override @@ -165,49 +153,65 @@ public class OnDevicePersonalizationManager { @Override public void onError(int errorCode) { - executor.execute(() -> receiver.onError( - new OnDevicePersonalizationException(errorCode))); + executor.execute(() -> receiver.onError(createException(errorCode))); } }; - mService.execute( - mContext.getPackageName(), handler, params, callbackWrapper); + service.execute( + mContext.getPackageName(), + handler, + params, + new CallerMetadata.Builder().setStartTimeMillis(startTimeMillis).build(), + callbackWrapper); - } catch (Exception e) { - receiver.onError(e); + } catch (RemoteException e) { + receiver.onError(new IllegalStateException(e)); } } /** - * Requests a {@link SurfacePackage} to be inserted into a {@link SurfaceView} inside the - * calling app. The surface package will contain a {@link View} with the content from a result - * of a prior call to {@link #execute()} running in the OnDevicePersonalization sandbox. + * Requests a {@link android.view.SurfaceControlViewHost.SurfacePackage} to be inserted into a + * {@link android.view.SurfaceView} inside the calling app. The surface package will contain an + * {@link android.view.View} with the content from a result of a prior call to + * {@code #execute(ComponentName, PersistableBundle, Executor, OutcomeReceiver)} running in + * the OnDevicePersonalization sandbox. * * @param surfacePackageToken a reference to a {@link SurfacePackageToken} returned by a prior - * call to {@link execute}. - * @param hostToken the hostToken of the {@link SurfaceView}, which is returned by - * {@link SurfaceView#getHostToken()} after the {@link SurfaceView} has been added to the - * view hierarchy. + * call to {@code #execute(ComponentName, PersistableBundle, Executor, OutcomeReceiver)}. + * @param surfaceViewHostToken the hostToken of the {@link android.view.SurfaceView}, which is + * returned by {@link android.view.SurfaceView#getHostToken()} after the + * {@link android.view.SurfaceView} has been added to the view hierarchy. * @param displayId the integer ID of the logical display on which to display the - * {@link SurfacePackage}, returned by {@code Context.getDisplay().getDisplayId()}. - * @param width the width of the {@link SurfacePackage} in pixels. - * @param height the height of the {@link SurfacePackage} in pixels. + * {@link android.view.SurfaceControlViewHost.SurfacePackage}, returned by + * {@code Context.getDisplay().getDisplayId()}. + * @param width the width of the {@link android.view.SurfaceControlViewHost.SurfacePackage} + * in pixels. + * @param height the height of the {@link android.view.SurfaceControlViewHost.SurfacePackage} + * in pixels. * @param executor the {@link Executor} on which to invoke the callback - * @param receiver This either returns a {@link SurfacePackage} on success, or {@link - * Exception} on failure. - * + * @param receiver This either returns a + * {@link android.view.SurfaceControlViewHost.SurfacePackage} on success, or + * {@link Exception} on failure. The exception type is + * {@link OnDevicePersonalizationException} if execution of the handler fails. */ public void requestSurfacePackage( @NonNull SurfacePackageToken surfacePackageToken, - @NonNull IBinder hostToken, + @NonNull IBinder surfaceViewHostToken, int displayId, int width, int height, @NonNull @CallbackExecutor Executor executor, @NonNull OutcomeReceiver<SurfaceControlViewHost.SurfacePackage, Exception> receiver ) { + Objects.requireNonNull(surfacePackageToken); + Objects.requireNonNull(surfaceViewHostToken); + Objects.requireNonNull(executor); + Objects.requireNonNull(receiver); + long startTimeMillis = SystemClock.elapsedRealtime(); + try { - bindService(executor); + final IOnDevicePersonalizationManagingService service = + mServiceBinder.getService(executor); IRequestSurfacePackageCallback callbackWrapper = new IRequestSurfacePackageCallback.Stub() { @@ -221,70 +225,37 @@ public class OnDevicePersonalizationManager { @Override public void onError(int errorCode) { - executor.execute(() -> receiver.onError( - new OnDevicePersonalizationException(errorCode))); + executor.execute(() -> receiver.onError(createException(errorCode))); } }; - mService.requestSurfacePackage( - surfacePackageToken.getTokenString(), hostToken, displayId, - width, height, callbackWrapper); + service.requestSurfacePackage( + surfacePackageToken.getTokenString(), + surfaceViewHostToken, + displayId, + width, + height, + new CallerMetadata.Builder().setStartTimeMillis(startTimeMillis).build(), + callbackWrapper); - } catch (InterruptedException - | NullPointerException - | RemoteException e) { - receiver.onError(e); + } catch (RemoteException e) { + receiver.onError(new IllegalStateException(e)); } } - /** Bind to the service, if not already bound. */ - private void bindService(@NonNull Executor executor) throws InterruptedException { - if (!mBound) { - Intent intent = new Intent("android.OnDevicePersonalizationService"); - ComponentName serviceComponent = - resolveService(intent, mContext.getPackageManager()); - if (serviceComponent == null) { - sLogger.e(TAG + ": Invalid component for ondevicepersonalization service"); - return; - } - - intent.setComponent(serviceComponent); - int bindFlags = Context.BIND_AUTO_CREATE; - if (SdkLevel.isAtLeastU()) { - bindFlags |= Context.BIND_ALLOW_ACTIVITY_STARTS; - } - boolean r = mContext.bindService( - intent, bindFlags, executor, mConnection); - if (!r) { - return; - } - mConnectionLatch.await(BIND_SERVICE_TIMEOUT_SEC, TimeUnit.SECONDS); - } - } - - /** - * Find the ComponentName of the service, given its intent and package manager. - * - * @return ComponentName of the service. Null if the service is not found. - */ - private @Nullable ComponentName resolveService( - @NonNull Intent intent, @NonNull PackageManager pm) { - List<ResolveInfo> services = - pm.queryIntentServices(intent, PackageManager.ResolveInfoFlags.of(0)); - if (services == null || services.isEmpty()) { - sLogger.e(TAG + ": Failed to find ondevicepersonalization service"); - return null; - } - - for (int i = 0; i < services.size(); i++) { - ResolveInfo ri = services.get(i); - ComponentName resolved = - new ComponentName(ri.serviceInfo.packageName, ri.serviceInfo.name); - // There should only be one matching service inside the given package. - // If there's more than one, return the first one found. - return resolved; + private Exception createException(int errorCode) { + if (errorCode == Constants.STATUS_NAME_NOT_FOUND) { + return new PackageManager.NameNotFoundException(); + } else if (errorCode == Constants.STATUS_CLASS_NOT_FOUND) { + return new ClassNotFoundException(); + } else if (errorCode == Constants.STATUS_SERVICE_FAILED) { + return new OnDevicePersonalizationException( + OnDevicePersonalizationException.ERROR_ISOLATED_SERVICE_FAILED); + } else if (errorCode == Constants.STATUS_PERSONALIZATION_DISABLED) { + return new OnDevicePersonalizationException( + OnDevicePersonalizationException.ERROR_PERSONALIZATION_DISABLED); + } else { + return new IllegalStateException("Error: " + errorCode); } - sLogger.e(TAG + ": Didn't find any matching ondevicepersonalization service."); - return null; } } diff --git a/framework/java/android/adservices/ondevicepersonalization/OnDevicePersonalizationPermissions.java b/framework/java/android/adservices/ondevicepersonalization/OnDevicePersonalizationPermissions.java new file mode 100644 index 00000000..694e336c --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/OnDevicePersonalizationPermissions.java @@ -0,0 +1,56 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +import static android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS; + +import android.annotation.FlaggedApi; +import android.annotation.NonNull; +import android.annotation.SystemApi; +import android.content.Context; +import android.content.pm.PackageManager; + +/** + * OnDevicePersonalization permission settings. + * + * @hide +*/ +@SystemApi +@FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) +public class OnDevicePersonalizationPermissions { + private OnDevicePersonalizationPermissions() {} + + /** + * The permission that lets it modify ODP's enablement state. + */ + @FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) + public static final String MODIFY_ONDEVICEPERSONALIZATION_STATE = + "android.permission.ondevicepersonalization.MODIFY_ONDEVICEPERSONALIZATION_STATE"; + + /** + * verify that caller has the permission to modify ODP's enablement state. + * @throws SecurityException otherwise. + * + * @hide + */ + public static void enforceCallingPermission(@NonNull Context context, + @NonNull String permission) { + if (context.checkCallingOrSelfPermission(permission) != PackageManager.PERMISSION_GRANTED) { + throw new SecurityException("Unauthorized call to ODP."); + } + } +} diff --git a/framework/java/android/adservices/ondevicepersonalization/OnDevicePersonalizationPrivacyStatusManager.java b/framework/java/android/adservices/ondevicepersonalization/OnDevicePersonalizationPrivacyStatusManager.java deleted file mode 100644 index 7cd43c3a..00000000 --- a/framework/java/android/adservices/ondevicepersonalization/OnDevicePersonalizationPrivacyStatusManager.java +++ /dev/null @@ -1,161 +0,0 @@ -/* - * Copyright (C) 2022 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package android.adservices.ondevicepersonalization; - -import android.adservices.ondevicepersonalization.aidl.IPrivacyStatusService; -import android.adservices.ondevicepersonalization.aidl.IPrivacyStatusServiceCallback; -import android.annotation.CallbackExecutor; -import android.annotation.NonNull; -import android.annotation.Nullable; -import android.content.ComponentName; -import android.content.Context; -import android.content.Intent; -import android.content.ServiceConnection; -import android.content.pm.ResolveInfo; -import android.content.pm.ServiceInfo; -import android.os.IBinder; -import android.os.OutcomeReceiver; -import android.os.RemoteException; - -import com.android.ondevicepersonalization.internal.util.LoggerFactory; - -import java.util.List; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.Executor; -import java.util.concurrent.TimeUnit; - -/** - * OnDevicePersonalizationPrivacyStatusManager provides system APIs - * for GMSCore to control privacy statuses of ODP users. - * @hide - */ -public class OnDevicePersonalizationPrivacyStatusManager { - public static final String ON_DEVICE_PERSONALIZATION_PRIVACY_STATUS_SERVICE = - "on_device_personalization_privacy_status_service"; - private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); - private static final String TAG = "OdpPrivacyStatusManager"; - private static final String ODP_PRIVACY_STATUS_SERVICE_INTENT = - "android.OnDevicePersonalizationPrivacyStatusService"; - private boolean mBound = false; - private IPrivacyStatusService mService = null; - private final Context mContext; - private final CountDownLatch mConnectionLatch = new CountDownLatch(1); - - public OnDevicePersonalizationPrivacyStatusManager(@NonNull Context context) { - mContext = context; - } - - private final ServiceConnection mConnection = new ServiceConnection() { - @Override - public void onServiceConnected(ComponentName name, IBinder binder) { - mService = IPrivacyStatusService.Stub.asInterface(binder); - mBound = true; - mConnectionLatch.countDown(); - } - - @Override - public void onNullBinding(ComponentName name) { - mBound = false; - mConnectionLatch.countDown(); - } - - @Override - public void onServiceDisconnected(ComponentName name) { - mService = null; - mBound = false; - } - }; - - private static final int BIND_SERVICE_TIMEOUT_SEC = 5; - - /** - * Modify the user's kid status in ODP from GMSCore. - * - * @param isKidStatusEnabled user's kid status available at GMSCore. - * @param executor the {@link Executor} on which to invoke the callback - * @param receiver This either returns true on success or {@link Exception} on failure. - * @hide - */ - public void setKidStatus(boolean isKidStatusEnabled, - @NonNull @CallbackExecutor Executor executor, - @NonNull OutcomeReceiver<Boolean, Exception> receiver) { - try { - bindService(executor); - - mService.setKidStatus(isKidStatusEnabled, new IPrivacyStatusServiceCallback.Stub() { - @Override - public void onSuccess() { - executor.execute(() -> receiver.onResult(true)); - } - - @Override - public void onFailure(int errorCode) { - executor.execute(() -> receiver.onError( - new OnDevicePersonalizationException(errorCode))); - } - }); - } catch (InterruptedException | RemoteException e) { - receiver.onError(e); - } - } - - private void bindService(@NonNull Executor executor) throws InterruptedException { - if (!mBound) { - Intent intent = new Intent(ODP_PRIVACY_STATUS_SERVICE_INTENT); - ComponentName serviceComponent = resolveService(intent); - if (serviceComponent == null) { - sLogger.e(TAG + ": Invalid component for ODP privacy status service"); - return; - } - - intent.setComponent(serviceComponent); - boolean r = mContext.bindService( - intent, Context.BIND_AUTO_CREATE, executor, mConnection); - if (!r) { - return; - } - mConnectionLatch.await(BIND_SERVICE_TIMEOUT_SEC, TimeUnit.SECONDS); - } - } - - /** - * Find the ComponentName of the service, given its intent. - * - * @return ComponentName of the service. Null if the service is not found. - */ - @Nullable - private ComponentName resolveService(@NonNull Intent intent) { - List<ResolveInfo> services = mContext.getPackageManager().queryIntentServices(intent, 0); - if (services == null || services.isEmpty()) { - sLogger.e(TAG + ": Failed to find OdpPrivacyStatus service"); - return null; - } - - for (int i = 0; i < services.size(); i++) { - ServiceInfo serviceInfo = services.get(i).serviceInfo; - if (serviceInfo == null) { - sLogger.e(TAG + ": Failed to find serviceInfo for OdpPrivacyStatus service."); - return null; - } - // There should only be one matching service inside the given package. - // If there's more than one, return the first one found. - return new ComponentName(serviceInfo.packageName, serviceInfo.name); - } - sLogger.e(TAG + ": Didn't find any matching OdpPrivacyStatus service."); - return null; - } -} diff --git a/framework/java/android/adservices/ondevicepersonalization/RemoteDataImpl.java b/framework/java/android/adservices/ondevicepersonalization/RemoteDataImpl.java index ce0d3b43..03752bac 100644 --- a/framework/java/android/adservices/ondevicepersonalization/RemoteDataImpl.java +++ b/framework/java/android/adservices/ondevicepersonalization/RemoteDataImpl.java @@ -32,7 +32,6 @@ import java.util.Objects; import java.util.Set; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; -import java.util.concurrent.TimeUnit; /** @hide */ public class RemoteDataImpl implements KeyValueStore { @@ -41,15 +40,13 @@ public class RemoteDataImpl implements KeyValueStore { @NonNull IDataAccessService mDataAccessService; - private static final long ASYNC_TIMEOUT_MS = 1000; - /** @hide */ public RemoteDataImpl(@NonNull IDataAccessService binder) { mDataAccessService = Objects.requireNonNull(binder); } @Override @Nullable - public byte[] get(@NonNull String key) throws OnDevicePersonalizationException { + public byte[] get(@NonNull String key) { Objects.requireNonNull(key); try { BlockingQueue<Bundle> asyncResult = new ArrayBlockingQueue<>(1); @@ -61,7 +58,11 @@ public class RemoteDataImpl implements KeyValueStore { new IDataAccessServiceCallback.Stub() { @Override public void onSuccess(@NonNull Bundle result) { - asyncResult.add(result); + if (result != null) { + asyncResult.add(result); + } else { + asyncResult.add(Bundle.EMPTY); + } } @Override @@ -69,26 +70,22 @@ public class RemoteDataImpl implements KeyValueStore { asyncResult.add(Bundle.EMPTY); } }); - Bundle result = asyncResult.poll(ASYNC_TIMEOUT_MS, TimeUnit.MILLISECONDS); - if (null == result) { - sLogger.e(TAG + ": Timed out waiting for result of remoteData lookup"); - throw new OnDevicePersonalizationException(Constants.STATUS_INTERNAL_ERROR); - } + Bundle result = asyncResult.take(); HashMap<String, byte[]> data = result.getSerializable( Constants.EXTRA_RESULT, HashMap.class); if (null == data) { sLogger.e(TAG + ": No EXTRA_RESULT was present in bundle"); - throw new OnDevicePersonalizationException(Constants.STATUS_INTERNAL_ERROR); + throw new IllegalStateException("Bundle missing EXTRA_RESULT."); } return data.get(key); } catch (InterruptedException | RemoteException e) { sLogger.e(TAG + ": Failed to retrieve key from remoteData", e); - throw new OnDevicePersonalizationException(Constants.STATUS_INTERNAL_ERROR); + throw new IllegalStateException(e); } } @Override @NonNull - public Set<String> keySet() throws OnDevicePersonalizationException { + public Set<String> keySet() { try { BlockingQueue<Bundle> asyncResult = new ArrayBlockingQueue<>(1); mDataAccessService.onRequest( @@ -97,7 +94,11 @@ public class RemoteDataImpl implements KeyValueStore { new IDataAccessServiceCallback.Stub() { @Override public void onSuccess(@NonNull Bundle result) { - asyncResult.add(result); + if (result != null) { + asyncResult.add(result); + } else { + asyncResult.add(Bundle.EMPTY); + } } @Override @@ -105,21 +106,17 @@ public class RemoteDataImpl implements KeyValueStore { asyncResult.add(Bundle.EMPTY); } }); - Bundle result = asyncResult.poll(ASYNC_TIMEOUT_MS, TimeUnit.MILLISECONDS); - if (null == result) { - sLogger.e(TAG + ": Timed out waiting for result of remoteData keySet"); - throw new OnDevicePersonalizationException(Constants.STATUS_INTERNAL_ERROR); - } + Bundle result = asyncResult.take(); HashSet<String> resultSet = result.getSerializable(Constants.EXTRA_RESULT, HashSet.class); if (null == resultSet) { sLogger.e(TAG + ": No EXTRA_RESULT was present in bundle"); - throw new OnDevicePersonalizationException(Constants.STATUS_INTERNAL_ERROR); + throw new IllegalStateException("Bundle missing EXTRA_RESULT."); } return resultSet; } catch (InterruptedException | RemoteException e) { sLogger.e(TAG + ": Failed to retrieve keySet from remoteData", e); - throw new OnDevicePersonalizationException(Constants.STATUS_INTERNAL_ERROR); + throw new IllegalStateException(e); } } } diff --git a/framework/java/android/adservices/ondevicepersonalization/RenderInput.java b/framework/java/android/adservices/ondevicepersonalization/RenderInput.java index 8061a52d..a19f4ddb 100644 --- a/framework/java/android/adservices/ondevicepersonalization/RenderInput.java +++ b/framework/java/android/adservices/ondevicepersonalization/RenderInput.java @@ -16,17 +16,22 @@ package android.adservices.ondevicepersonalization; +import static android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS; + +import android.annotation.FlaggedApi; +import android.annotation.NonNull; import android.annotation.Nullable; import android.os.Parcelable; import com.android.ondevicepersonalization.internal.util.DataClass; /** - * The input data for {@link IsolatedWorker#onRender()}. + * The input data for + * {@link IsolatedWorker#onRender(RenderInput, java.util.function.Consumer)}. * - * @hide */ -@DataClass(genBuilder = true, genEqualsHashCode = true) +@FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) +@DataClass(genBuilder = false, genHiddenConstructor = true, genEqualsHashCode = true) public final class RenderInput implements Parcelable { /** The width of the slot. */ private int mWidth = 0; @@ -40,9 +45,18 @@ public final class RenderInput implements Parcelable { */ private int mRenderingConfigIndex = 0; - /** A {@link RenderingConfig} returned by {@link onExecute}. */ + /** + * A {@link RenderingConfig} within an {@link ExecuteOutput} that was returned by + * {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)}. + */ @Nullable RenderingConfig mRenderingConfig = null; + /** @hide */ + public RenderInput(@NonNull RenderInputParcel parcel) { + this(parcel.getWidth(), parcel.getHeight(), parcel.getRenderingConfigIndex(), + parcel.getRenderingConfig()); + } + // Code below generated by codegen v1.0.23. @@ -58,8 +72,23 @@ public final class RenderInput implements Parcelable { //@formatter:off + /** + * Creates a new RenderInput. + * + * @param width + * The width of the slot. + * @param height + * The height of the slot. + * @param renderingConfigIndex + * The index of the {@link RenderingConfig} in {@link ExecuteOutput} that this render + * request is for. + * @param renderingConfig + * A {@link RenderingConfig} within an {@link ExecuteOutput} that was returned by + * {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)}. + * @hide + */ @DataClass.Generated.Member - /* package-private */ RenderInput( + public RenderInput( int width, int height, int renderingConfigIndex, @@ -98,7 +127,8 @@ public final class RenderInput implements Parcelable { } /** - * A {@link RenderingConfig} returned by {@link onExecute}. + * A {@link RenderingConfig} within an {@link ExecuteOutput} that was returned by + * {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)}. */ @DataClass.Generated.Member public @Nullable RenderingConfig getRenderingConfig() { @@ -140,7 +170,7 @@ public final class RenderInput implements Parcelable { @Override @DataClass.Generated.Member - public void writeToParcel(@android.annotation.NonNull android.os.Parcel dest, int flags) { + public void writeToParcel(@NonNull android.os.Parcel dest, int flags) { // You can override field parcelling by defining methods like: // void parcelFieldName(Parcel dest, int flags) { ... } @@ -160,7 +190,7 @@ public final class RenderInput implements Parcelable { /** @hide */ @SuppressWarnings({"unchecked", "RedundantCast"}) @DataClass.Generated.Member - /* package-private */ RenderInput(@android.annotation.NonNull android.os.Parcel in) { + /* package-private */ RenderInput(@NonNull android.os.Parcel in) { // You can override field unparcelling by defining methods like: // static FieldType unparcelFieldName(Parcel in) { ... } @@ -179,7 +209,7 @@ public final class RenderInput implements Parcelable { } @DataClass.Generated.Member - public static final @android.annotation.NonNull Parcelable.Creator<RenderInput> CREATOR + public static final @NonNull Parcelable.Creator<RenderInput> CREATOR = new Parcelable.Creator<RenderInput>() { @Override public RenderInput[] newArray(int size) { @@ -187,111 +217,16 @@ public final class RenderInput implements Parcelable { } @Override - public RenderInput createFromParcel(@android.annotation.NonNull android.os.Parcel in) { + public RenderInput createFromParcel(@NonNull android.os.Parcel in) { return new RenderInput(in); } }; - /** - * A builder for {@link RenderInput} - */ - @SuppressWarnings("WeakerAccess") - @DataClass.Generated.Member - public static final class Builder { - - private int mWidth; - private int mHeight; - private int mRenderingConfigIndex; - private @Nullable RenderingConfig mRenderingConfig; - - private long mBuilderFieldsSet = 0L; - - public Builder() { - } - - /** - * The width of the slot. - */ - @DataClass.Generated.Member - public @android.annotation.NonNull Builder setWidth(int value) { - checkNotUsed(); - mBuilderFieldsSet |= 0x1; - mWidth = value; - return this; - } - - /** - * The height of the slot. - */ - @DataClass.Generated.Member - public @android.annotation.NonNull Builder setHeight(int value) { - checkNotUsed(); - mBuilderFieldsSet |= 0x2; - mHeight = value; - return this; - } - - /** - * The index of the {@link RenderingConfig} in {@link ExecuteOutput} that this render - * request is for. - */ - @DataClass.Generated.Member - public @android.annotation.NonNull Builder setRenderingConfigIndex(int value) { - checkNotUsed(); - mBuilderFieldsSet |= 0x4; - mRenderingConfigIndex = value; - return this; - } - - /** - * A {@link RenderingConfig} returned by {@link onExecute}. - */ - @DataClass.Generated.Member - public @android.annotation.NonNull Builder setRenderingConfig(@android.annotation.NonNull RenderingConfig value) { - checkNotUsed(); - mBuilderFieldsSet |= 0x8; - mRenderingConfig = value; - return this; - } - - /** Builds the instance. This builder should not be touched after calling this! */ - public @android.annotation.NonNull RenderInput build() { - checkNotUsed(); - mBuilderFieldsSet |= 0x10; // Mark builder used - - if ((mBuilderFieldsSet & 0x1) == 0) { - mWidth = 0; - } - if ((mBuilderFieldsSet & 0x2) == 0) { - mHeight = 0; - } - if ((mBuilderFieldsSet & 0x4) == 0) { - mRenderingConfigIndex = 0; - } - if ((mBuilderFieldsSet & 0x8) == 0) { - mRenderingConfig = null; - } - RenderInput o = new RenderInput( - mWidth, - mHeight, - mRenderingConfigIndex, - mRenderingConfig); - return o; - } - - private void checkNotUsed() { - if ((mBuilderFieldsSet & 0x10) != 0) { - throw new IllegalStateException( - "This Builder should not be reused. Use a new Builder instance instead"); - } - } - } - @DataClass.Generated( - time = 1692118409407L, + time = 1698873113096L, codegenVersion = "1.0.23", sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/RenderInput.java", - inputSignatures = "private int mWidth\nprivate int mHeight\nprivate int mRenderingConfigIndex\n @android.annotation.Nullable android.adservices.ondevicepersonalization.RenderingConfig mRenderingConfig\nclass RenderInput extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") + inputSignatures = "private int mWidth\nprivate int mHeight\nprivate int mRenderingConfigIndex\n @android.annotation.Nullable android.adservices.ondevicepersonalization.RenderingConfig mRenderingConfig\nclass RenderInput extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=false, genHiddenConstructor=true, genEqualsHashCode=true)") @Deprecated private void __metadata() {} diff --git a/framework/java/android/adservices/ondevicepersonalization/RenderInputParcel.java b/framework/java/android/adservices/ondevicepersonalization/RenderInputParcel.java new file mode 100644 index 00000000..bf2af756 --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/RenderInputParcel.java @@ -0,0 +1,274 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +import android.annotation.Nullable; +import android.os.Parcelable; + +import com.android.ondevicepersonalization.internal.util.DataClass; + +/** + * Parcelable version of {@link RenderInput}. + * @hide + */ +@DataClass(genAidl = false, genHiddenBuilder = true) +public final class RenderInputParcel implements Parcelable { + /** The width of the slot. */ + private int mWidth = 0; + + /** The height of the slot. */ + private int mHeight = 0; + + /** + * The index of the {@link RenderingConfig} in {@link ExecuteOutput} that this render + * request is for. + */ + private int mRenderingConfigIndex = 0; + + /** + * A {@link RenderingConfig} within an {@link ExecuteOutput} that was returned by + * {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)}. + */ + @Nullable RenderingConfig mRenderingConfig = null; + + + + // Code below generated by codegen v1.0.23. + // + // DO NOT MODIFY! + // CHECKSTYLE:OFF Generated code + // + // To regenerate run: + // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/RenderInputParcel.java + // + // To exclude the generated code from IntelliJ auto-formatting enable (one-time): + // Settings > Editor > Code Style > Formatter Control + //@formatter:off + + + @DataClass.Generated.Member + /* package-private */ RenderInputParcel( + int width, + int height, + int renderingConfigIndex, + @Nullable RenderingConfig renderingConfig) { + this.mWidth = width; + this.mHeight = height; + this.mRenderingConfigIndex = renderingConfigIndex; + this.mRenderingConfig = renderingConfig; + + // onConstructed(); // You can define this method to get a callback + } + + /** + * The width of the slot. + */ + @DataClass.Generated.Member + public int getWidth() { + return mWidth; + } + + /** + * The height of the slot. + */ + @DataClass.Generated.Member + public int getHeight() { + return mHeight; + } + + /** + * The index of the {@link RenderingConfig} in {@link ExecuteOutput} that this render + * request is for. + */ + @DataClass.Generated.Member + public int getRenderingConfigIndex() { + return mRenderingConfigIndex; + } + + /** + * A {@link RenderingConfig} within an {@link ExecuteOutput} that was returned by + * {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)}. + */ + @DataClass.Generated.Member + public @Nullable RenderingConfig getRenderingConfig() { + return mRenderingConfig; + } + + @Override + @DataClass.Generated.Member + public void writeToParcel(@android.annotation.NonNull android.os.Parcel dest, int flags) { + // You can override field parcelling by defining methods like: + // void parcelFieldName(Parcel dest, int flags) { ... } + + byte flg = 0; + if (mRenderingConfig != null) flg |= 0x8; + dest.writeByte(flg); + dest.writeInt(mWidth); + dest.writeInt(mHeight); + dest.writeInt(mRenderingConfigIndex); + if (mRenderingConfig != null) dest.writeTypedObject(mRenderingConfig, flags); + } + + @Override + @DataClass.Generated.Member + public int describeContents() { return 0; } + + /** @hide */ + @SuppressWarnings({"unchecked", "RedundantCast"}) + @DataClass.Generated.Member + /* package-private */ RenderInputParcel(@android.annotation.NonNull android.os.Parcel in) { + // You can override field unparcelling by defining methods like: + // static FieldType unparcelFieldName(Parcel in) { ... } + + byte flg = in.readByte(); + int width = in.readInt(); + int height = in.readInt(); + int renderingConfigIndex = in.readInt(); + RenderingConfig renderingConfig = (flg & 0x8) == 0 ? null : (RenderingConfig) in.readTypedObject(RenderingConfig.CREATOR); + + this.mWidth = width; + this.mHeight = height; + this.mRenderingConfigIndex = renderingConfigIndex; + this.mRenderingConfig = renderingConfig; + + // onConstructed(); // You can define this method to get a callback + } + + @DataClass.Generated.Member + public static final @android.annotation.NonNull Parcelable.Creator<RenderInputParcel> CREATOR + = new Parcelable.Creator<RenderInputParcel>() { + @Override + public RenderInputParcel[] newArray(int size) { + return new RenderInputParcel[size]; + } + + @Override + public RenderInputParcel createFromParcel(@android.annotation.NonNull android.os.Parcel in) { + return new RenderInputParcel(in); + } + }; + + /** + * A builder for {@link RenderInputParcel} + * @hide + */ + @SuppressWarnings("WeakerAccess") + @DataClass.Generated.Member + public static final class Builder { + + private int mWidth; + private int mHeight; + private int mRenderingConfigIndex; + private @Nullable RenderingConfig mRenderingConfig; + + private long mBuilderFieldsSet = 0L; + + public Builder() { + } + + /** + * The width of the slot. + */ + @DataClass.Generated.Member + public @android.annotation.NonNull Builder setWidth(int value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x1; + mWidth = value; + return this; + } + + /** + * The height of the slot. + */ + @DataClass.Generated.Member + public @android.annotation.NonNull Builder setHeight(int value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x2; + mHeight = value; + return this; + } + + /** + * The index of the {@link RenderingConfig} in {@link ExecuteOutput} that this render + * request is for. + */ + @DataClass.Generated.Member + public @android.annotation.NonNull Builder setRenderingConfigIndex(int value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x4; + mRenderingConfigIndex = value; + return this; + } + + /** + * A {@link RenderingConfig} within an {@link ExecuteOutput} that was returned by + * {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)}. + */ + @DataClass.Generated.Member + public @android.annotation.NonNull Builder setRenderingConfig(@android.annotation.NonNull RenderingConfig value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x8; + mRenderingConfig = value; + return this; + } + + /** Builds the instance. This builder should not be touched after calling this! */ + public @android.annotation.NonNull RenderInputParcel build() { + checkNotUsed(); + mBuilderFieldsSet |= 0x10; // Mark builder used + + if ((mBuilderFieldsSet & 0x1) == 0) { + mWidth = 0; + } + if ((mBuilderFieldsSet & 0x2) == 0) { + mHeight = 0; + } + if ((mBuilderFieldsSet & 0x4) == 0) { + mRenderingConfigIndex = 0; + } + if ((mBuilderFieldsSet & 0x8) == 0) { + mRenderingConfig = null; + } + RenderInputParcel o = new RenderInputParcel( + mWidth, + mHeight, + mRenderingConfigIndex, + mRenderingConfig); + return o; + } + + private void checkNotUsed() { + if ((mBuilderFieldsSet & 0x10) != 0) { + throw new IllegalStateException( + "This Builder should not be reused. Use a new Builder instance instead"); + } + } + } + + @DataClass.Generated( + time = 1698872925083L, + codegenVersion = "1.0.23", + sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/RenderInputParcel.java", + inputSignatures = "private int mWidth\nprivate int mHeight\nprivate int mRenderingConfigIndex\n @android.annotation.Nullable android.adservices.ondevicepersonalization.RenderingConfig mRenderingConfig\nclass RenderInputParcel extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genAidl=false, genHiddenBuilder=true)") + @Deprecated + private void __metadata() {} + + + //@formatter:on + // End of generated code + +} diff --git a/framework/java/android/adservices/ondevicepersonalization/RenderOutput.java b/framework/java/android/adservices/ondevicepersonalization/RenderOutput.java index 015e2876..3b7c8dc1 100644 --- a/framework/java/android/adservices/ondevicepersonalization/RenderOutput.java +++ b/framework/java/android/adservices/ondevicepersonalization/RenderOutput.java @@ -16,6 +16,9 @@ package android.adservices.ondevicepersonalization; +import static android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS; + +import android.annotation.FlaggedApi; import android.annotation.NonNull; import android.annotation.Nullable; import android.os.Parcelable; @@ -25,25 +28,34 @@ import com.android.ondevicepersonalization.internal.util.AnnotationValidations; import com.android.ondevicepersonalization.internal.util.DataClass; /** - * The result returned by {@link IsolatedWorker#onExecute()} in response to a - * {@link OnDevicePersonalizationManager#requestSurfacePackage()} request from a calling app. + * The result returned by + * {@link IsolatedWorker#onRender(RenderInput, java.util.function.Consumer)}. * - * @hide */ +@FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) @DataClass(genBuilder = true, genEqualsHashCode = true) public final class RenderOutput implements Parcelable { - /** The content to be rendered. */ + /** + * The HTML content to be rendered in a webview. If this is null, the ODP service + * generates HTML from the data in {@link #getTemplateId()} and {@link #getTemplateParams()} + * as described below. + */ @Nullable private String mContent = null; /** - * Parameters for template rendering + * A key in the REMOTE_DATA {@link IsolatedService#getRemoteData(RequestToken)} table that + * points to an <a href="velocity.apache.org">Apache Velocity</a> template. This is ignored if + * {@link #getContent()} is not null. */ - @NonNull private PersistableBundle mTemplateParams = PersistableBundle.EMPTY; + @Nullable private String mTemplateId = null; /** - * Template ID to retrieve from REMOTE_DATA for rendering + * The parameters to be populated in the template from {@link #getTemplateId()}. This is + * ignored if {@link #getContent()} is not null. */ - @Nullable private String mTemplateId = null; + @NonNull private PersistableBundle mTemplateParams = PersistableBundle.EMPTY; + + @@ -63,19 +75,21 @@ public final class RenderOutput implements Parcelable { @DataClass.Generated.Member /* package-private */ RenderOutput( @Nullable String content, - @NonNull PersistableBundle templateParams, - @Nullable String templateId) { + @Nullable String templateId, + @NonNull PersistableBundle templateParams) { this.mContent = content; + this.mTemplateId = templateId; this.mTemplateParams = templateParams; AnnotationValidations.validate( NonNull.class, null, mTemplateParams); - this.mTemplateId = templateId; // onConstructed(); // You can define this method to get a callback } /** - * The content to be rendered. + * The HTML content to be rendered in a webview. If this is null, the ODP service + * generates HTML from the data in {@link #getTemplateId()} and {@link #getTemplateParams()} + * as described below. */ @DataClass.Generated.Member public @Nullable String getContent() { @@ -83,19 +97,22 @@ public final class RenderOutput implements Parcelable { } /** - * Parameters for template rendering + * A key in the REMOTE_DATA {@link IsolatedService#getRemoteData(RequestToken)} table that + * points to an <a href="velocity.apache.org">Apache Velocity</a> template. This is ignored if + * {@link #getContent()} is not null. */ @DataClass.Generated.Member - public @NonNull PersistableBundle getTemplateParams() { - return mTemplateParams; + public @Nullable String getTemplateId() { + return mTemplateId; } /** - * Template ID to retrieve from REMOTE_DATA for rendering + * The parameters to be populated in the template from {@link #getTemplateId()}. This is + * ignored if {@link #getContent()} is not null. */ @DataClass.Generated.Member - public @Nullable String getTemplateId() { - return mTemplateId; + public @NonNull PersistableBundle getTemplateParams() { + return mTemplateParams; } @Override @@ -112,8 +129,8 @@ public final class RenderOutput implements Parcelable { //noinspection PointlessBooleanExpression return true && java.util.Objects.equals(mContent, that.mContent) - && java.util.Objects.equals(mTemplateParams, that.mTemplateParams) - && java.util.Objects.equals(mTemplateId, that.mTemplateId); + && java.util.Objects.equals(mTemplateId, that.mTemplateId) + && java.util.Objects.equals(mTemplateParams, that.mTemplateParams); } @Override @@ -124,8 +141,8 @@ public final class RenderOutput implements Parcelable { int _hash = 1; _hash = 31 * _hash + java.util.Objects.hashCode(mContent); - _hash = 31 * _hash + java.util.Objects.hashCode(mTemplateParams); _hash = 31 * _hash + java.util.Objects.hashCode(mTemplateId); + _hash = 31 * _hash + java.util.Objects.hashCode(mTemplateParams); return _hash; } @@ -137,11 +154,11 @@ public final class RenderOutput implements Parcelable { byte flg = 0; if (mContent != null) flg |= 0x1; - if (mTemplateId != null) flg |= 0x4; + if (mTemplateId != null) flg |= 0x2; dest.writeByte(flg); if (mContent != null) dest.writeString(mContent); - dest.writeTypedObject(mTemplateParams, flags); if (mTemplateId != null) dest.writeString(mTemplateId); + dest.writeTypedObject(mTemplateParams, flags); } @Override @@ -157,14 +174,14 @@ public final class RenderOutput implements Parcelable { byte flg = in.readByte(); String content = (flg & 0x1) == 0 ? null : in.readString(); + String templateId = (flg & 0x2) == 0 ? null : in.readString(); PersistableBundle templateParams = (PersistableBundle) in.readTypedObject(PersistableBundle.CREATOR); - String templateId = (flg & 0x4) == 0 ? null : in.readString(); this.mContent = content; + this.mTemplateId = templateId; this.mTemplateParams = templateParams; AnnotationValidations.validate( NonNull.class, null, mTemplateParams); - this.mTemplateId = templateId; // onConstructed(); // You can define this method to get a callback } @@ -186,13 +203,14 @@ public final class RenderOutput implements Parcelable { /** * A builder for {@link RenderOutput} */ + @FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) @SuppressWarnings("WeakerAccess") @DataClass.Generated.Member public static final class Builder { private @Nullable String mContent; - private @NonNull PersistableBundle mTemplateParams; private @Nullable String mTemplateId; + private @NonNull PersistableBundle mTemplateParams; private long mBuilderFieldsSet = 0L; @@ -200,7 +218,9 @@ public final class RenderOutput implements Parcelable { } /** - * The content to be rendered. + * The HTML content to be rendered in a webview. If this is null, the ODP service + * generates HTML from the data in {@link #getTemplateId()} and {@link #getTemplateParams()} + * as described below. */ @DataClass.Generated.Member public @NonNull Builder setContent(@NonNull String value) { @@ -211,24 +231,27 @@ public final class RenderOutput implements Parcelable { } /** - * Parameters for template rendering + * A key in the REMOTE_DATA {@link IsolatedService#getRemoteData(RequestToken)} table that + * points to an <a href="velocity.apache.org">Apache Velocity</a> template. This is ignored if + * {@link #getContent()} is not null. */ @DataClass.Generated.Member - public @NonNull Builder setTemplateParams(@NonNull PersistableBundle value) { + public @NonNull Builder setTemplateId(@NonNull String value) { checkNotUsed(); mBuilderFieldsSet |= 0x2; - mTemplateParams = value; + mTemplateId = value; return this; } /** - * Template ID to retrieve from REMOTE_DATA for rendering + * The parameters to be populated in the template from {@link #getTemplateId()}. This is + * ignored if {@link #getContent()} is not null. */ @DataClass.Generated.Member - public @NonNull Builder setTemplateId(@NonNull String value) { + public @NonNull Builder setTemplateParams(@NonNull PersistableBundle value) { checkNotUsed(); mBuilderFieldsSet |= 0x4; - mTemplateId = value; + mTemplateParams = value; return this; } @@ -241,15 +264,15 @@ public final class RenderOutput implements Parcelable { mContent = null; } if ((mBuilderFieldsSet & 0x2) == 0) { - mTemplateParams = PersistableBundle.EMPTY; + mTemplateId = null; } if ((mBuilderFieldsSet & 0x4) == 0) { - mTemplateId = null; + mTemplateParams = PersistableBundle.EMPTY; } RenderOutput o = new RenderOutput( mContent, - mTemplateParams, - mTemplateId); + mTemplateId, + mTemplateParams); return o; } @@ -262,10 +285,10 @@ public final class RenderOutput implements Parcelable { } @DataClass.Generated( - time = 1692118415895L, + time = 1697132582732L, codegenVersion = "1.0.23", sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/RenderOutput.java", - inputSignatures = "private @android.annotation.Nullable java.lang.String mContent\nprivate @android.annotation.NonNull android.os.PersistableBundle mTemplateParams\nprivate @android.annotation.Nullable java.lang.String mTemplateId\nclass RenderOutput extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") + inputSignatures = "private @android.annotation.Nullable java.lang.String mContent\nprivate @android.annotation.Nullable java.lang.String mTemplateId\nprivate @android.annotation.NonNull android.os.PersistableBundle mTemplateParams\nclass RenderOutput extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") @Deprecated private void __metadata() {} diff --git a/framework/java/android/adservices/ondevicepersonalization/RenderOutputParcel.java b/framework/java/android/adservices/ondevicepersonalization/RenderOutputParcel.java new file mode 100644 index 00000000..afd4cfaf --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/RenderOutputParcel.java @@ -0,0 +1,197 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +import android.annotation.NonNull; +import android.annotation.Nullable; +import android.os.Parcelable; +import android.os.PersistableBundle; + +import com.android.ondevicepersonalization.internal.util.AnnotationValidations; +import com.android.ondevicepersonalization.internal.util.DataClass; + +/** + * Parcelable version of {@link RenderOutput}. + * @hide + */ +@DataClass(genAidl = false, genBuilder = false) +public final class RenderOutputParcel implements Parcelable { + /** + * The HTML content to be rendered in a webview. If this is null, the ODP service + * generates HTML from the data in {@link #getTemplateId()} and {@link #getTemplateParams()} + * as described below. + */ + @Nullable private String mContent = null; + + /** + * A key in the REMOTE_DATA {@link IsolatedService#getRemoteData(RequestToken)} table that + * points to an <a href="velocity.apache.org">Apache Velocity</a> template. This is ignored if + * {@link #getContent()} is not null. + */ + @Nullable private String mTemplateId = null; + + /** + * The parameters to be populated in the template from {@link #getTemplateId()}. This is + * ignored if {@link #getContent()} is not null. + */ + @NonNull private PersistableBundle mTemplateParams = PersistableBundle.EMPTY; + + /** @hide */ + public RenderOutputParcel(@NonNull RenderOutput value) { + this(value.getContent(), value.getTemplateId(), value.getTemplateParams()); + } + + + + // Code below generated by codegen v1.0.23. + // + // DO NOT MODIFY! + // CHECKSTYLE:OFF Generated code + // + // To regenerate run: + // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/RenderOutputParcel.java + // + // To exclude the generated code from IntelliJ auto-formatting enable (one-time): + // Settings > Editor > Code Style > Formatter Control + //@formatter:off + + + /** + * Creates a new RenderOutputParcel. + * + * @param content + * The HTML content to be rendered in a webview. If this is null, the ODP service + * generates HTML from the data in {@link #getTemplateId()} and {@link #getTemplateParams()} + * as described below. + * @param templateId + * A key in the REMOTE_DATA {@link IsolatedService#getRemoteData(RequestToken)} table that + * points to an <a href="velocity.apache.org">Apache Velocity</a> template. This is ignored if + * {@link #getContent()} is not null. + * @param templateParams + * The parameters to be populated in the template from {@link #getTemplateId()}. This is + * ignored if {@link #getContent()} is not null. + */ + @DataClass.Generated.Member + public RenderOutputParcel( + @Nullable String content, + @Nullable String templateId, + @NonNull PersistableBundle templateParams) { + this.mContent = content; + this.mTemplateId = templateId; + this.mTemplateParams = templateParams; + AnnotationValidations.validate( + NonNull.class, null, mTemplateParams); + + // onConstructed(); // You can define this method to get a callback + } + + /** + * The HTML content to be rendered in a webview. If this is null, the ODP service + * generates HTML from the data in {@link #getTemplateId()} and {@link #getTemplateParams()} + * as described below. + */ + @DataClass.Generated.Member + public @Nullable String getContent() { + return mContent; + } + + /** + * A key in the REMOTE_DATA {@link IsolatedService#getRemoteData(RequestToken)} table that + * points to an <a href="velocity.apache.org">Apache Velocity</a> template. This is ignored if + * {@link #getContent()} is not null. + */ + @DataClass.Generated.Member + public @Nullable String getTemplateId() { + return mTemplateId; + } + + /** + * The parameters to be populated in the template from {@link #getTemplateId()}. This is + * ignored if {@link #getContent()} is not null. + */ + @DataClass.Generated.Member + public @NonNull PersistableBundle getTemplateParams() { + return mTemplateParams; + } + + @Override + @DataClass.Generated.Member + public void writeToParcel(@NonNull android.os.Parcel dest, int flags) { + // You can override field parcelling by defining methods like: + // void parcelFieldName(Parcel dest, int flags) { ... } + + byte flg = 0; + if (mContent != null) flg |= 0x1; + if (mTemplateId != null) flg |= 0x2; + dest.writeByte(flg); + if (mContent != null) dest.writeString(mContent); + if (mTemplateId != null) dest.writeString(mTemplateId); + dest.writeTypedObject(mTemplateParams, flags); + } + + @Override + @DataClass.Generated.Member + public int describeContents() { return 0; } + + /** @hide */ + @SuppressWarnings({"unchecked", "RedundantCast"}) + @DataClass.Generated.Member + /* package-private */ RenderOutputParcel(@NonNull android.os.Parcel in) { + // You can override field unparcelling by defining methods like: + // static FieldType unparcelFieldName(Parcel in) { ... } + + byte flg = in.readByte(); + String content = (flg & 0x1) == 0 ? null : in.readString(); + String templateId = (flg & 0x2) == 0 ? null : in.readString(); + PersistableBundle templateParams = (PersistableBundle) in.readTypedObject(PersistableBundle.CREATOR); + + this.mContent = content; + this.mTemplateId = templateId; + this.mTemplateParams = templateParams; + AnnotationValidations.validate( + NonNull.class, null, mTemplateParams); + + // onConstructed(); // You can define this method to get a callback + } + + @DataClass.Generated.Member + public static final @NonNull Parcelable.Creator<RenderOutputParcel> CREATOR + = new Parcelable.Creator<RenderOutputParcel>() { + @Override + public RenderOutputParcel[] newArray(int size) { + return new RenderOutputParcel[size]; + } + + @Override + public RenderOutputParcel createFromParcel(@NonNull android.os.Parcel in) { + return new RenderOutputParcel(in); + } + }; + + @DataClass.Generated( + time = 1698864341247L, + codegenVersion = "1.0.23", + sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/RenderOutputParcel.java", + inputSignatures = "private @android.annotation.Nullable java.lang.String mContent\nprivate @android.annotation.Nullable java.lang.String mTemplateId\nprivate @android.annotation.NonNull android.os.PersistableBundle mTemplateParams\nclass RenderOutputParcel extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genAidl=false, genBuilder=false)") + @Deprecated + private void __metadata() {} + + + //@formatter:on + // End of generated code + +} diff --git a/framework/java/android/adservices/ondevicepersonalization/RenderingConfig.java b/framework/java/android/adservices/ondevicepersonalization/RenderingConfig.java index fec58f73..ae54c77a 100644 --- a/framework/java/android/adservices/ondevicepersonalization/RenderingConfig.java +++ b/framework/java/android/adservices/ondevicepersonalization/RenderingConfig.java @@ -16,6 +16,9 @@ package android.adservices.ondevicepersonalization; +import static android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS; + +import android.annotation.FlaggedApi; import android.annotation.NonNull; import android.os.Parcelable; @@ -26,16 +29,20 @@ import java.util.Collections; import java.util.List; /** - * Information returned by {@link IsolatedWorker#onExecute()} that is used - * in a subesequent call to {@link IsolatedWorker#onRender()} to identify the - * content to be displayed in a single {@link View}. + * Information returned by + * {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)} + * that is used in a subesequent call to + * {@link IsolatedWorker#onRender(RenderInput, java.util.function.Consumer)} to identify the + * content to be displayed in a single {@link android.view.View}. * - * @hide */ +@FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) @DataClass(genBuilder = true, genEqualsHashCode = true) public final class RenderingConfig implements Parcelable { /** - * A List of keys in the REMOTE_DATA table that identify the content to be rendered. + * A List of keys in the REMOTE_DATA + * {@link IsolatedService#getRemoteData(RequestToken)} + * table that identify the content to be rendered. **/ @DataClass.PluralOf("key") @NonNull private List<String> mKeys = Collections.emptyList(); @@ -66,7 +73,9 @@ public final class RenderingConfig implements Parcelable { } /** - * A List of keys in the REMOTE_DATA table that identify the content to be rendered. + * A List of keys in the REMOTE_DATA + * {@link IsolatedService#getRemoteData(RequestToken)} + * table that identify the content to be rendered. */ @DataClass.Generated.Member public @NonNull List<String> getKeys() { @@ -147,6 +156,7 @@ public final class RenderingConfig implements Parcelable { /** * A builder for {@link RenderingConfig} */ + @FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) @SuppressWarnings("WeakerAccess") @DataClass.Generated.Member public static final class Builder { @@ -159,7 +169,9 @@ public final class RenderingConfig implements Parcelable { } /** - * A List of keys in the REMOTE_DATA table that identify the content to be rendered. + * A List of keys in the REMOTE_DATA + * {@link IsolatedService#getRemoteData(RequestToken)} + * table that identify the content to be rendered. */ @DataClass.Generated.Member public @NonNull Builder setKeys(@NonNull List<String> value) { @@ -199,7 +211,7 @@ public final class RenderingConfig implements Parcelable { } @DataClass.Generated( - time = 1692118399909L, + time = 1697132616124L, codegenVersion = "1.0.23", sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/RenderingConfig.java", inputSignatures = "private @com.android.ondevicepersonalization.internal.util.DataClass.PluralOf(\"key\") @android.annotation.NonNull java.util.List<java.lang.String> mKeys\nclass RenderingConfig extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") diff --git a/framework/java/android/adservices/ondevicepersonalization/RequestLogRecord.java b/framework/java/android/adservices/ondevicepersonalization/RequestLogRecord.java index 1d6d0ddc..8cd16764 100644 --- a/framework/java/android/adservices/ondevicepersonalization/RequestLogRecord.java +++ b/framework/java/android/adservices/ondevicepersonalization/RequestLogRecord.java @@ -16,6 +16,9 @@ package android.adservices.ondevicepersonalization; +import static android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS; + +import android.annotation.FlaggedApi; import android.annotation.NonNull; import android.content.ContentValues; import android.os.Parcelable; @@ -26,12 +29,17 @@ import com.android.ondevicepersonalization.internal.util.DataClass; import java.util.Collections; import java.util.List; +// TODO(b/289102463): Add a link to the public doc for the REQUESTS table when available. /** * Contains data that will be written to the REQUESTS table at the end of a call to - * {@link IsolatedWorker#onExecute()}. + * {@link IsolatedWorker#onExecute(ExecuteInput, java.util.function.Consumer)}. + * A single {@link RequestLogRecord} is appended to the + * REQUESTS table if it is provided as a part of {@link ExecuteOutput}. The contents of + * the REQUESTS table can be consumed by Federated Learning facilitated model training, + * or Federated Analytics facilitated cross-device statistical analysis. * - * @hide */ +@FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) @DataClass(genBuilder = true, genEqualsHashCode = true) public final class RequestLogRecord implements Parcelable { /** @@ -40,6 +48,25 @@ public final class RequestLogRecord implements Parcelable { @DataClass.PluralOf("row") @NonNull List<ContentValues> mRows = Collections.emptyList(); + /** + * Internal id for the RequestLogRecord. + * @hide + */ + private long mRequestId = 0; + + /** + * Time of the request in milliseconds + * @hide + */ + private long mTimeMillis = 0; + + abstract static class BaseBuilder { + /** + * @hide + */ + public abstract Builder setTimeMillis(long value); + } + // Code below generated by codegen v1.0.23. @@ -57,10 +84,14 @@ public final class RequestLogRecord implements Parcelable { @DataClass.Generated.Member /* package-private */ RequestLogRecord( - @NonNull List<ContentValues> rows) { + @NonNull List<ContentValues> rows, + long requestId, + long timeMillis) { this.mRows = rows; AnnotationValidations.validate( NonNull.class, null, mRows); + this.mRequestId = requestId; + this.mTimeMillis = timeMillis; // onConstructed(); // You can define this method to get a callback } @@ -73,6 +104,26 @@ public final class RequestLogRecord implements Parcelable { return mRows; } + /** + * Internal id for the RequestLogRecord. + * + * @hide + */ + @DataClass.Generated.Member + public long getRequestId() { + return mRequestId; + } + + /** + * Time of the request in milliseconds + * + * @hide + */ + @DataClass.Generated.Member + public long getTimeMillis() { + return mTimeMillis; + } + @Override @DataClass.Generated.Member public boolean equals(@android.annotation.Nullable Object o) { @@ -86,7 +137,9 @@ public final class RequestLogRecord implements Parcelable { RequestLogRecord that = (RequestLogRecord) o; //noinspection PointlessBooleanExpression return true - && java.util.Objects.equals(mRows, that.mRows); + && java.util.Objects.equals(mRows, that.mRows) + && mRequestId == that.mRequestId + && mTimeMillis == that.mTimeMillis; } @Override @@ -97,6 +150,8 @@ public final class RequestLogRecord implements Parcelable { int _hash = 1; _hash = 31 * _hash + java.util.Objects.hashCode(mRows); + _hash = 31 * _hash + Long.hashCode(mRequestId); + _hash = 31 * _hash + Long.hashCode(mTimeMillis); return _hash; } @@ -107,6 +162,8 @@ public final class RequestLogRecord implements Parcelable { // void parcelFieldName(Parcel dest, int flags) { ... } dest.writeParcelableList(mRows, flags); + dest.writeLong(mRequestId); + dest.writeLong(mTimeMillis); } @Override @@ -122,10 +179,14 @@ public final class RequestLogRecord implements Parcelable { List<ContentValues> rows = new java.util.ArrayList<>(); in.readParcelableList(rows, ContentValues.class.getClassLoader()); + long requestId = in.readLong(); + long timeMillis = in.readLong(); this.mRows = rows; AnnotationValidations.validate( NonNull.class, null, mRows); + this.mRequestId = requestId; + this.mTimeMillis = timeMillis; // onConstructed(); // You can define this method to get a callback } @@ -147,11 +208,14 @@ public final class RequestLogRecord implements Parcelable { /** * A builder for {@link RequestLogRecord} */ + @FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) @SuppressWarnings("WeakerAccess") @DataClass.Generated.Member - public static final class Builder { + public static final class Builder extends BaseBuilder { private @NonNull List<ContentValues> mRows; + private long mRequestId; + private long mTimeMillis; private long mBuilderFieldsSet = 0L; @@ -177,21 +241,56 @@ public final class RequestLogRecord implements Parcelable { return this; } + /** + * Internal id for the RequestLogRecord. + * + * @hide + */ + @DataClass.Generated.Member + public @NonNull Builder setRequestId(long value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x2; + mRequestId = value; + return this; + } + + /** + * Time of the request in milliseconds + * + * @hide + */ + @DataClass.Generated.Member + @Override + public @NonNull Builder setTimeMillis(long value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x4; + mTimeMillis = value; + return this; + } + /** Builds the instance. This builder should not be touched after calling this! */ public @NonNull RequestLogRecord build() { checkNotUsed(); - mBuilderFieldsSet |= 0x2; // Mark builder used + mBuilderFieldsSet |= 0x8; // Mark builder used if ((mBuilderFieldsSet & 0x1) == 0) { mRows = Collections.emptyList(); } + if ((mBuilderFieldsSet & 0x2) == 0) { + mRequestId = 0; + } + if ((mBuilderFieldsSet & 0x4) == 0) { + mTimeMillis = 0; + } RequestLogRecord o = new RequestLogRecord( - mRows); + mRows, + mRequestId, + mTimeMillis); return o; } private void checkNotUsed() { - if ((mBuilderFieldsSet & 0x2) != 0) { + if ((mBuilderFieldsSet & 0x8) != 0) { throw new IllegalStateException( "This Builder should not be reused. Use a new Builder instance instead"); } @@ -199,10 +298,10 @@ public final class RequestLogRecord implements Parcelable { } @DataClass.Generated( - time = 1692118422731L, + time = 1696978492795L, codegenVersion = "1.0.23", sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/RequestLogRecord.java", - inputSignatures = " @com.android.ondevicepersonalization.internal.util.DataClass.PluralOf(\"row\") @android.annotation.NonNull java.util.List<android.content.ContentValues> mRows\nclass RequestLogRecord extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") + inputSignatures = " @com.android.ondevicepersonalization.internal.util.DataClass.PluralOf(\"row\") @android.annotation.NonNull java.util.List<android.content.ContentValues> mRows\nprivate long mRequestId\nprivate long mTimeMillis\nclass RequestLogRecord extends java.lang.Object implements [android.os.Parcelable]\npublic abstract android.adservices.ondevicepersonalization.RequestLogRecord.Builder setTimeMillis(long)\nclass BaseBuilder extends java.lang.Object implements []\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)\npublic abstract android.adservices.ondevicepersonalization.RequestLogRecord.Builder setTimeMillis(long)\nclass BaseBuilder extends java.lang.Object implements []") @Deprecated private void __metadata() {} diff --git a/framework/java/android/adservices/ondevicepersonalization/RequestToken.java b/framework/java/android/adservices/ondevicepersonalization/RequestToken.java index 3967c9ef..d3e23310 100644 --- a/framework/java/android/adservices/ondevicepersonalization/RequestToken.java +++ b/framework/java/android/adservices/ondevicepersonalization/RequestToken.java @@ -16,38 +16,65 @@ package android.adservices.ondevicepersonalization; +import static android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS; + import android.adservices.ondevicepersonalization.aidl.IDataAccessService; +import android.adservices.ondevicepersonalization.aidl.IFederatedComputeService; +import android.annotation.FlaggedApi; import android.annotation.NonNull; import android.annotation.Nullable; +import android.os.SystemClock; import java.util.Objects; /** - * An opaque token that identifies the current request to an - * {@link IsolatedService}. This token must be passed as a parameter to all service - * methods that depend on per-request state. - * - * @hide + * An opaque token that identifies the current request to an {@link IsolatedService}. This token + * must be passed as a parameter to all service methods that depend on per-request state. */ +@FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public class RequestToken { - @NonNull private IDataAccessService mDataAccessService; - @Nullable private UserData mUserData; + @NonNull + private final IDataAccessService mDataAccessService; + + @Nullable + private final IFederatedComputeService mFcService; + + @Nullable + private final UserData mUserData; + + private final long mStartTimeMillis; /** @hide */ RequestToken( @NonNull IDataAccessService binder, + @Nullable IFederatedComputeService fcServiceBinder, @Nullable UserData userData) { mDataAccessService = Objects.requireNonNull(binder); + mFcService = fcServiceBinder; mUserData = userData; + mStartTimeMillis = SystemClock.elapsedRealtime(); } /** @hide */ - @NonNull IDataAccessService getDataAccessService() { + @NonNull + IDataAccessService getDataAccessService() { return mDataAccessService; } /** @hide */ - @Nullable UserData getUserData() { + @Nullable + IFederatedComputeService getFederatedComputeService() { + return mFcService; + } + + /** @hide */ + @Nullable + UserData getUserData() { return mUserData; } + + /** @hide */ + long getStartTimeMillis() { + return mStartTimeMillis; + } } diff --git a/framework/java/android/adservices/ondevicepersonalization/SurfacePackageToken.java b/framework/java/android/adservices/ondevicepersonalization/SurfacePackageToken.java index 216063a2..48406c3d 100644 --- a/framework/java/android/adservices/ondevicepersonalization/SurfacePackageToken.java +++ b/framework/java/android/adservices/ondevicepersonalization/SurfacePackageToken.java @@ -16,14 +16,17 @@ package android.adservices.ondevicepersonalization; +import static android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS; + +import android.annotation.FlaggedApi; import android.annotation.NonNull; /** - * An opaque reference to content that can be displayed in a {@link SurfaceView}. This maps - * to a {@link RenderingConfig} returned by an {@link IsolatedService}. + * An opaque reference to content that can be displayed in a {@link android.view.SurfaceView}. This + * maps to a {@link RenderingConfig} returned by an {@link IsolatedService}. * - * @hide */ +@FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) public class SurfacePackageToken { @NonNull private final String mTokenString; diff --git a/framework/java/android/adservices/ondevicepersonalization/TrainingExampleInput.aidl b/framework/java/android/adservices/ondevicepersonalization/TrainingExampleInput.aidl new file mode 100644 index 00000000..7d841a50 --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/TrainingExampleInput.aidl @@ -0,0 +1,19 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +parcelable TrainingExampleInput; diff --git a/framework/java/android/adservices/ondevicepersonalization/TrainingExampleInput.java b/framework/java/android/adservices/ondevicepersonalization/TrainingExampleInput.java new file mode 100644 index 00000000..412f99aa --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/TrainingExampleInput.java @@ -0,0 +1,277 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +import android.annotation.NonNull; +import android.annotation.Nullable; +import android.os.Parcelable; + +import com.android.ondevicepersonalization.internal.util.AnnotationValidations; +import com.android.ondevicepersonalization.internal.util.DataClass; + +import java.util.function.Consumer; + +/** + * The input data for {@link IsolatedWorker#onTrainingExample(TrainingExampleInput, Consumer)} + * + * @hide + */ +@DataClass(genHiddenBuilder = true, genEqualsHashCode = true) +public final class TrainingExampleInput implements Parcelable { + /** The name of the federated compute population. */ + @NonNull private String mPopulationName = ""; + + /** + * The name of the task within the population. One population may have multiple tasks. + * The task name can be used to uniquely identify the job. + */ + @NonNull private String mTaskName = ""; + + /** Token used to support the resumption of training. */ + @Nullable private byte[] mResumptionToken = null; + + + + // Code below generated by codegen v1.0.23. + // + // DO NOT MODIFY! + // CHECKSTYLE:OFF Generated code + // + // To regenerate run: + // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/TrainingExampleInput.java + // + // To exclude the generated code from IntelliJ auto-formatting enable (one-time): + // Settings > Editor > Code Style > Formatter Control + //@formatter:off + + + @DataClass.Generated.Member + /* package-private */ TrainingExampleInput( + @NonNull String populationName, + @NonNull String taskName, + @Nullable byte[] resumptionToken) { + this.mPopulationName = populationName; + AnnotationValidations.validate( + NonNull.class, null, mPopulationName); + this.mTaskName = taskName; + AnnotationValidations.validate( + NonNull.class, null, mTaskName); + this.mResumptionToken = resumptionToken; + + // onConstructed(); // You can define this method to get a callback + } + + /** + * The name of the federated compute population. + */ + @DataClass.Generated.Member + public @NonNull String getPopulationName() { + return mPopulationName; + } + + /** + * The name of the task within the population. One population may have multiple tasks. + * The task name can be used to uniquely identify the job. + */ + @DataClass.Generated.Member + public @NonNull String getTaskName() { + return mTaskName; + } + + /** + * Token used to support the resumption of training. + */ + @DataClass.Generated.Member + public @Nullable byte[] getResumptionToken() { + return mResumptionToken; + } + + @Override + @DataClass.Generated.Member + public boolean equals(@Nullable Object o) { + // You can override field equality logic by defining either of the methods like: + // boolean fieldNameEquals(TrainingExampleInput other) { ... } + // boolean fieldNameEquals(FieldType otherValue) { ... } + + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + @SuppressWarnings("unchecked") + TrainingExampleInput that = (TrainingExampleInput) o; + //noinspection PointlessBooleanExpression + return true + && java.util.Objects.equals(mPopulationName, that.mPopulationName) + && java.util.Objects.equals(mTaskName, that.mTaskName) + && java.util.Arrays.equals(mResumptionToken, that.mResumptionToken); + } + + @Override + @DataClass.Generated.Member + public int hashCode() { + // You can override field hashCode logic by defining methods like: + // int fieldNameHashCode() { ... } + + int _hash = 1; + _hash = 31 * _hash + java.util.Objects.hashCode(mPopulationName); + _hash = 31 * _hash + java.util.Objects.hashCode(mTaskName); + _hash = 31 * _hash + java.util.Arrays.hashCode(mResumptionToken); + return _hash; + } + + @Override + @DataClass.Generated.Member + public void writeToParcel(@NonNull android.os.Parcel dest, int flags) { + // You can override field parcelling by defining methods like: + // void parcelFieldName(Parcel dest, int flags) { ... } + + dest.writeString(mPopulationName); + dest.writeString(mTaskName); + dest.writeByteArray(mResumptionToken); + } + + @Override + @DataClass.Generated.Member + public int describeContents() { return 0; } + + /** @hide */ + @SuppressWarnings({"unchecked", "RedundantCast"}) + @DataClass.Generated.Member + /* package-private */ TrainingExampleInput(@NonNull android.os.Parcel in) { + // You can override field unparcelling by defining methods like: + // static FieldType unparcelFieldName(Parcel in) { ... } + + String populationName = in.readString(); + String taskName = in.readString(); + byte[] resumptionToken = in.createByteArray(); + + this.mPopulationName = populationName; + AnnotationValidations.validate( + NonNull.class, null, mPopulationName); + this.mTaskName = taskName; + AnnotationValidations.validate( + NonNull.class, null, mTaskName); + this.mResumptionToken = resumptionToken; + + // onConstructed(); // You can define this method to get a callback + } + + @DataClass.Generated.Member + public static final @NonNull Parcelable.Creator<TrainingExampleInput> CREATOR + = new Parcelable.Creator<TrainingExampleInput>() { + @Override + public TrainingExampleInput[] newArray(int size) { + return new TrainingExampleInput[size]; + } + + @Override + public TrainingExampleInput createFromParcel(@NonNull android.os.Parcel in) { + return new TrainingExampleInput(in); + } + }; + + /** + * A builder for {@link TrainingExampleInput} + * @hide + */ + @SuppressWarnings("WeakerAccess") + @DataClass.Generated.Member + public static final class Builder { + + private @NonNull String mPopulationName; + private @NonNull String mTaskName; + private @Nullable byte[] mResumptionToken; + + private long mBuilderFieldsSet = 0L; + + public Builder() { + } + + /** + * The name of the federated compute population. + */ + @DataClass.Generated.Member + public @NonNull Builder setPopulationName(@NonNull String value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x1; + mPopulationName = value; + return this; + } + + /** + * The name of the task within the population. One population may have multiple tasks. + * The task name can be used to uniquely identify the job. + */ + @DataClass.Generated.Member + public @NonNull Builder setTaskName(@NonNull String value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x2; + mTaskName = value; + return this; + } + + /** + * Token used to support the resumption of training. + */ + @DataClass.Generated.Member + public @NonNull Builder setResumptionToken(@NonNull byte... value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x4; + mResumptionToken = value; + return this; + } + + /** Builds the instance. This builder should not be touched after calling this! */ + public @NonNull TrainingExampleInput build() { + checkNotUsed(); + mBuilderFieldsSet |= 0x8; // Mark builder used + + if ((mBuilderFieldsSet & 0x1) == 0) { + mPopulationName = ""; + } + if ((mBuilderFieldsSet & 0x2) == 0) { + mTaskName = ""; + } + if ((mBuilderFieldsSet & 0x4) == 0) { + mResumptionToken = null; + } + TrainingExampleInput o = new TrainingExampleInput( + mPopulationName, + mTaskName, + mResumptionToken); + return o; + } + + private void checkNotUsed() { + if ((mBuilderFieldsSet & 0x8) != 0) { + throw new IllegalStateException( + "This Builder should not be reused. Use a new Builder instance instead"); + } + } + } + + @DataClass.Generated( + time = 1697577073626L, + codegenVersion = "1.0.23", + sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/TrainingExampleInput.java", + inputSignatures = "private @android.annotation.NonNull java.lang.String mPopulationName\nprivate @android.annotation.NonNull java.lang.String mTaskName\nprivate @android.annotation.Nullable byte[] mResumptionToken\nclass TrainingExampleInput extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genHiddenBuilder=true, genEqualsHashCode=true)") + @Deprecated + private void __metadata() {} + + + //@formatter:on + // End of generated code + +} diff --git a/framework/java/android/adservices/ondevicepersonalization/TrainingExampleOutput.java b/framework/java/android/adservices/ondevicepersonalization/TrainingExampleOutput.java new file mode 100644 index 00000000..1a698686 --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/TrainingExampleOutput.java @@ -0,0 +1,234 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +import android.annotation.NonNull; + +import com.android.internal.util.Preconditions; +import com.android.ondevicepersonalization.internal.util.AnnotationValidations; +import com.android.ondevicepersonalization.internal.util.DataClass; + +import java.util.Collections; +import java.util.List; +import java.util.function.Consumer; + +/** + * The output data of {@link IsolatedWorker#onTrainingExample(TrainingExampleInput, Consumer)} + * + * @hide + */ +@DataClass(genBuilder = true, genEqualsHashCode = true) +public final class TrainingExampleOutput { + /** + * A list of training example byte arrays. The format is a binary serialized + * <a href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto"> + * tensorflow.Example</a> proto. The maximum allowed example size is 50KB. + */ + @NonNull + @DataClass.PluralOf("trainingExample") + private List<byte[]> mTrainingExamples = Collections.emptyList(); + + /** + * A list of resumption token byte arrays corresponding to training examples. The last + * processed example's corresponding resumption token will be passed to + * {@link IsolatedWorker#onTrainingExample(TrainingExampleInput, Consumer)} to support + * resumption. The length of this list must match the length {@link #getTrainingExamples()}. + */ + @NonNull + @DataClass.PluralOf("resumptionToken") + private List<byte[]> mResumptionTokens = Collections.emptyList(); + + + private void onConstructed() { + Preconditions.checkArgument(mTrainingExamples.size() == mResumptionTokens.size()); + } + + + + // Code below generated by codegen v1.0.23. + // + // DO NOT MODIFY! + // CHECKSTYLE:OFF Generated code + // + // To regenerate run: + // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/TrainingExampleOutput.java + // + // To exclude the generated code from IntelliJ auto-formatting enable (one-time): + // Settings > Editor > Code Style > Formatter Control + //@formatter:off + + + @DataClass.Generated.Member + /* package-private */ TrainingExampleOutput( + @NonNull List<byte[]> trainingExamples, + @NonNull List<byte[]> resumptionTokens) { + this.mTrainingExamples = trainingExamples; + AnnotationValidations.validate( + NonNull.class, null, mTrainingExamples); + this.mResumptionTokens = resumptionTokens; + AnnotationValidations.validate( + NonNull.class, null, mResumptionTokens); + + onConstructed(); + } + + /** + * A list of training example byte arrays. The format is a binary serialized + * <a href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto"> + * tensorflow.Example</a> proto. The maximum allowed example size is 50KB. + */ + @DataClass.Generated.Member + public @NonNull List<byte[]> getTrainingExamples() { + return mTrainingExamples; + } + + /** + * A list of resumption token byte arrays corresponding to training examples. The last + * processed example's corresponding resumption token will be passed to + * {@link IsolatedWorker#onTrainingExample(TrainingExampleInput, Consumer)} to support + * resumption. The length of this list must match the length {@link #getTrainingExamples()}. + */ + @DataClass.Generated.Member + public @NonNull List<byte[]> getResumptionTokens() { + return mResumptionTokens; + } + + @Override + @DataClass.Generated.Member + public boolean equals(@android.annotation.Nullable Object o) { + // You can override field equality logic by defining either of the methods like: + // boolean fieldNameEquals(TrainingExampleOutput other) { ... } + // boolean fieldNameEquals(FieldType otherValue) { ... } + + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + @SuppressWarnings("unchecked") + TrainingExampleOutput that = (TrainingExampleOutput) o; + //noinspection PointlessBooleanExpression + return true + && java.util.Objects.equals(mTrainingExamples, that.mTrainingExamples) + && java.util.Objects.equals(mResumptionTokens, that.mResumptionTokens); + } + + @Override + @DataClass.Generated.Member + public int hashCode() { + // You can override field hashCode logic by defining methods like: + // int fieldNameHashCode() { ... } + + int _hash = 1; + _hash = 31 * _hash + java.util.Objects.hashCode(mTrainingExamples); + _hash = 31 * _hash + java.util.Objects.hashCode(mResumptionTokens); + return _hash; + } + + /** + * A builder for {@link TrainingExampleOutput} + */ + @SuppressWarnings("WeakerAccess") + @DataClass.Generated.Member + public static final class Builder { + + private @NonNull List<byte[]> mTrainingExamples; + private @NonNull List<byte[]> mResumptionTokens; + + private long mBuilderFieldsSet = 0L; + + public Builder() { + } + + /** + * A list of training example byte arrays. The format is a binary serialized + * <a href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto"> + * tensorflow.Example</a> proto. The maximum allowed example size is 50KB. + */ + @DataClass.Generated.Member + public @NonNull Builder setTrainingExamples(@NonNull List<byte[]> value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x1; + mTrainingExamples = value; + return this; + } + + /** @see #setTrainingExamples */ + @DataClass.Generated.Member + public @NonNull Builder addTrainingExample(@NonNull byte[] value) { + if (mTrainingExamples == null) setTrainingExamples(new java.util.ArrayList<>()); + mTrainingExamples.add(value); + return this; + } + + /** + * A list of resumption token byte arrays corresponding to training examples. The last + * processed example's corresponding resumption token will be passed to + * {@link IsolatedWorker#onTrainingExample(TrainingExampleInput, Consumer)} to support + * resumption. The length of this list must match the length {@link #getTrainingExamples()}. + */ + @DataClass.Generated.Member + public @NonNull Builder setResumptionTokens(@NonNull List<byte[]> value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x2; + mResumptionTokens = value; + return this; + } + + /** @see #setResumptionTokens */ + @DataClass.Generated.Member + public @NonNull Builder addResumptionToken(@NonNull byte[] value) { + if (mResumptionTokens == null) setResumptionTokens(new java.util.ArrayList<>()); + mResumptionTokens.add(value); + return this; + } + + /** Builds the instance. This builder should not be touched after calling this! */ + public @NonNull TrainingExampleOutput build() { + checkNotUsed(); + mBuilderFieldsSet |= 0x4; // Mark builder used + + if ((mBuilderFieldsSet & 0x1) == 0) { + mTrainingExamples = Collections.emptyList(); + } + if ((mBuilderFieldsSet & 0x2) == 0) { + mResumptionTokens = Collections.emptyList(); + } + TrainingExampleOutput o = new TrainingExampleOutput( + mTrainingExamples, + mResumptionTokens); + return o; + } + + private void checkNotUsed() { + if ((mBuilderFieldsSet & 0x4) != 0) { + throw new IllegalStateException( + "This Builder should not be reused. Use a new Builder instance instead"); + } + } + } + + @DataClass.Generated( + time = 1697575854959L, + codegenVersion = "1.0.23", + sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/TrainingExampleOutput.java", + inputSignatures = "private @android.annotation.NonNull @com.android.ondevicepersonalization.internal.util.DataClass.PluralOf(\"trainingExample\") java.util.List<byte[]> mTrainingExamples\nprivate @android.annotation.NonNull @com.android.ondevicepersonalization.internal.util.DataClass.PluralOf(\"resumptionToken\") java.util.List<byte[]> mResumptionTokens\nprivate void onConstructed()\nclass TrainingExampleOutput extends java.lang.Object implements []\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") + @Deprecated + private void __metadata() {} + + + //@formatter:on + // End of generated code + +} diff --git a/framework/java/android/adservices/ondevicepersonalization/TrainingExampleOutputParcel.aidl b/framework/java/android/adservices/ondevicepersonalization/TrainingExampleOutputParcel.aidl new file mode 100644 index 00000000..699e9c59 --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/TrainingExampleOutputParcel.aidl @@ -0,0 +1,19 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +parcelable TrainingExampleOutputParcel; diff --git a/framework/java/android/adservices/ondevicepersonalization/TrainingExampleOutputParcel.java b/framework/java/android/adservices/ondevicepersonalization/TrainingExampleOutputParcel.java new file mode 100644 index 00000000..8bbd518b --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/TrainingExampleOutputParcel.java @@ -0,0 +1,234 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +import android.annotation.Nullable; +import android.os.Parcelable; + +import com.android.ondevicepersonalization.internal.util.ByteArrayParceledListSlice; +import com.android.ondevicepersonalization.internal.util.DataClass; + +/** + * Parcelable version of {@link TrainingExampleOutput} + * + * @hide + */ +@DataClass(genHiddenBuilder = true, genEqualsHashCode = true) +public class TrainingExampleOutputParcel implements Parcelable { + /** List of training examples */ + @Nullable + ByteArrayParceledListSlice mTrainingExamples = null; + + /** List of resumption tokens */ + @Nullable + ByteArrayParceledListSlice mResumptionTokens = null; + + + + // Code below generated by codegen v1.0.23. + // + // DO NOT MODIFY! + // CHECKSTYLE:OFF Generated code + // + // To regenerate run: + // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/TrainingExampleOutputParcel.java + // + // To exclude the generated code from IntelliJ auto-formatting enable (one-time): + // Settings > Editor > Code Style > Formatter Control + //@formatter:off + + + @DataClass.Generated.Member + /* package-private */ TrainingExampleOutputParcel( + @Nullable ByteArrayParceledListSlice trainingExamples, + @Nullable ByteArrayParceledListSlice resumptionTokens) { + this.mTrainingExamples = trainingExamples; + this.mResumptionTokens = resumptionTokens; + + // onConstructed(); // You can define this method to get a callback + } + + /** + * List of training examples + */ + @DataClass.Generated.Member + public @Nullable ByteArrayParceledListSlice getTrainingExamples() { + return mTrainingExamples; + } + + /** + * List of resumption tokens + */ + @DataClass.Generated.Member + public @Nullable ByteArrayParceledListSlice getResumptionTokens() { + return mResumptionTokens; + } + + @Override + @DataClass.Generated.Member + public boolean equals(@Nullable Object o) { + // You can override field equality logic by defining either of the methods like: + // boolean fieldNameEquals(TrainingExampleOutputParcel other) { ... } + // boolean fieldNameEquals(FieldType otherValue) { ... } + + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + @SuppressWarnings("unchecked") + TrainingExampleOutputParcel that = (TrainingExampleOutputParcel) o; + //noinspection PointlessBooleanExpression + return true + && java.util.Objects.equals(mTrainingExamples, that.mTrainingExamples) + && java.util.Objects.equals(mResumptionTokens, that.mResumptionTokens); + } + + @Override + @DataClass.Generated.Member + public int hashCode() { + // You can override field hashCode logic by defining methods like: + // int fieldNameHashCode() { ... } + + int _hash = 1; + _hash = 31 * _hash + java.util.Objects.hashCode(mTrainingExamples); + _hash = 31 * _hash + java.util.Objects.hashCode(mResumptionTokens); + return _hash; + } + + @Override + @DataClass.Generated.Member + public void writeToParcel(@android.annotation.NonNull android.os.Parcel dest, int flags) { + // You can override field parcelling by defining methods like: + // void parcelFieldName(Parcel dest, int flags) { ... } + + byte flg = 0; + if (mTrainingExamples != null) flg |= 0x1; + if (mResumptionTokens != null) flg |= 0x2; + dest.writeByte(flg); + if (mTrainingExamples != null) dest.writeTypedObject(mTrainingExamples, flags); + if (mResumptionTokens != null) dest.writeTypedObject(mResumptionTokens, flags); + } + + @Override + @DataClass.Generated.Member + public int describeContents() { return 0; } + + /** @hide */ + @SuppressWarnings({"unchecked", "RedundantCast"}) + @DataClass.Generated.Member + protected TrainingExampleOutputParcel(@android.annotation.NonNull android.os.Parcel in) { + // You can override field unparcelling by defining methods like: + // static FieldType unparcelFieldName(Parcel in) { ... } + + byte flg = in.readByte(); + ByteArrayParceledListSlice trainingExamples = (flg & 0x1) == 0 ? null : (ByteArrayParceledListSlice) in.readTypedObject(ByteArrayParceledListSlice.CREATOR); + ByteArrayParceledListSlice resumptionTokens = (flg & 0x2) == 0 ? null : (ByteArrayParceledListSlice) in.readTypedObject(ByteArrayParceledListSlice.CREATOR); + + this.mTrainingExamples = trainingExamples; + this.mResumptionTokens = resumptionTokens; + + // onConstructed(); // You can define this method to get a callback + } + + @DataClass.Generated.Member + public static final @android.annotation.NonNull Parcelable.Creator<TrainingExampleOutputParcel> CREATOR + = new Parcelable.Creator<TrainingExampleOutputParcel>() { + @Override + public TrainingExampleOutputParcel[] newArray(int size) { + return new TrainingExampleOutputParcel[size]; + } + + @Override + public TrainingExampleOutputParcel createFromParcel(@android.annotation.NonNull android.os.Parcel in) { + return new TrainingExampleOutputParcel(in); + } + }; + + /** + * A builder for {@link TrainingExampleOutputParcel} + * @hide + */ + @SuppressWarnings("WeakerAccess") + @DataClass.Generated.Member + public static class Builder { + + private @Nullable ByteArrayParceledListSlice mTrainingExamples; + private @Nullable ByteArrayParceledListSlice mResumptionTokens; + + private long mBuilderFieldsSet = 0L; + + public Builder() { + } + + /** + * List of training examples + */ + @DataClass.Generated.Member + public @android.annotation.NonNull Builder setTrainingExamples(@android.annotation.NonNull ByteArrayParceledListSlice value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x1; + mTrainingExamples = value; + return this; + } + + /** + * List of resumption tokens + */ + @DataClass.Generated.Member + public @android.annotation.NonNull Builder setResumptionTokens(@android.annotation.NonNull ByteArrayParceledListSlice value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x2; + mResumptionTokens = value; + return this; + } + + /** Builds the instance. This builder should not be touched after calling this! */ + public @android.annotation.NonNull TrainingExampleOutputParcel build() { + checkNotUsed(); + mBuilderFieldsSet |= 0x4; // Mark builder used + + if ((mBuilderFieldsSet & 0x1) == 0) { + mTrainingExamples = null; + } + if ((mBuilderFieldsSet & 0x2) == 0) { + mResumptionTokens = null; + } + TrainingExampleOutputParcel o = new TrainingExampleOutputParcel( + mTrainingExamples, + mResumptionTokens); + return o; + } + + private void checkNotUsed() { + if ((mBuilderFieldsSet & 0x4) != 0) { + throw new IllegalStateException( + "This Builder should not be reused. Use a new Builder instance instead"); + } + } + } + + @DataClass.Generated( + time = 1695743444776L, + codegenVersion = "1.0.23", + sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/TrainingExampleOutputParcel.java", + inputSignatures = " @android.annotation.Nullable com.android.ondevicepersonalization.internal.util.ByteArrayParceledListSlice mTrainingExamples\n @android.annotation.Nullable com.android.ondevicepersonalization.internal.util.ByteArrayParceledListSlice mResumptionTokens\nclass TrainingExampleOutputParcel extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genHiddenBuilder=true, genEqualsHashCode=true)") + @Deprecated + private void __metadata() {} + + + //@formatter:on + // End of generated code + +} diff --git a/framework/java/android/adservices/ondevicepersonalization/TrainingInterval.aidl b/framework/java/android/adservices/ondevicepersonalization/TrainingInterval.aidl new file mode 100644 index 00000000..610e5808 --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/TrainingInterval.aidl @@ -0,0 +1,19 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +parcelable TrainingInterval;
\ No newline at end of file diff --git a/framework/java/android/adservices/ondevicepersonalization/TrainingInterval.java b/framework/java/android/adservices/ondevicepersonalization/TrainingInterval.java new file mode 100644 index 00000000..eba12de1 --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/TrainingInterval.java @@ -0,0 +1,263 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +import android.annotation.NonNull; + +import com.android.ondevicepersonalization.internal.util.AnnotationValidations; +import com.android.ondevicepersonalization.internal.util.DataClass; + +import java.time.Duration; + +/** + * Training interval settings required for federated computation jobs. + * + * @hide + */ +@DataClass(genBuilder = true, genHiddenConstDefs = true, genEqualsHashCode = true) +public final class TrainingInterval { + /** + * The scheduling mode for a one-off task. + */ + public static final int SCHEDULING_MODE_ONE_TIME = 1; + + /** + * The scheduling mode for a task that will be rescheduled after each run. + */ + public static final int SCHEDULING_MODE_RECURRENT = 2; + + + /** + * The scheduling mode for this task, either {@link #SCHEDULING_MODE_ONE_TIME} or + * {@link #SCHEDULING_MODE_RECURRENT}. The default scheduling mode is + * {@link #SCHEDULING_MODE_ONE_TIME} if unspecified. + */ + @SchedulingMode private int mSchedulingMode = SCHEDULING_MODE_ONE_TIME; + + /** + * Sets the minimum time interval between two training runs. + * + * <p>This field will only be used when the scheduling mode is + * {@link #SCHEDULING_MODE_RECURRENT}. Only positive values are accepted, zero or negative + * values will result in IllegalArgumentException. + * + * <p>Please also note this value is advisory, which does not guarantee the job will be run + * immediately after the interval expired. Federated compute will still enforce a minimum + * required interval and training constraints to ensure system health. The current training + * constraints are device on unmetered network, idle and battery not low. + */ + @NonNull private Duration mMinimumInterval = Duration.ZERO; + + + + // Code below generated by codegen v1.0.23. + // + // DO NOT MODIFY! + // CHECKSTYLE:OFF Generated code + // + // To regenerate run: + // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/TrainingInterval.java + // + // To exclude the generated code from IntelliJ auto-formatting enable (one-time): + // Settings > Editor > Code Style > Formatter Control + //@formatter:off + + + /** @hide */ + @android.annotation.IntDef(prefix = "SCHEDULING_MODE_", value = { + SCHEDULING_MODE_ONE_TIME, + SCHEDULING_MODE_RECURRENT + }) + @java.lang.annotation.Retention(java.lang.annotation.RetentionPolicy.SOURCE) + @DataClass.Generated.Member + public @interface SchedulingMode {} + + /** @hide */ + @DataClass.Generated.Member + public static String schedulingModeToString(@SchedulingMode int value) { + switch (value) { + case SCHEDULING_MODE_ONE_TIME: + return "SCHEDULING_MODE_ONE_TIME"; + case SCHEDULING_MODE_RECURRENT: + return "SCHEDULING_MODE_RECURRENT"; + default: return Integer.toHexString(value); + } + } + + @DataClass.Generated.Member + /* package-private */ TrainingInterval( + @SchedulingMode int schedulingMode, + @NonNull Duration minimumInterval) { + this.mSchedulingMode = schedulingMode; + + if (!(mSchedulingMode == SCHEDULING_MODE_ONE_TIME) + && !(mSchedulingMode == SCHEDULING_MODE_RECURRENT)) { + throw new java.lang.IllegalArgumentException( + "schedulingMode was " + mSchedulingMode + " but must be one of: " + + "SCHEDULING_MODE_ONE_TIME(" + SCHEDULING_MODE_ONE_TIME + "), " + + "SCHEDULING_MODE_RECURRENT(" + SCHEDULING_MODE_RECURRENT + ")"); + } + + this.mMinimumInterval = minimumInterval; + AnnotationValidations.validate( + NonNull.class, null, mMinimumInterval); + + // onConstructed(); // You can define this method to get a callback + } + + /** + * The scheduling mode for this task, either {@link #SCHEDULING_MODE_ONE_TIME} or + * {@link #SCHEDULING_MODE_RECURRENT}. The default scheduling mode is + * {@link #SCHEDULING_MODE_ONE_TIME} if unspecified. + */ + @DataClass.Generated.Member + public @SchedulingMode int getSchedulingMode() { + return mSchedulingMode; + } + + /** + * Sets the minimum time interval between two training runs. + * + * <p>This field will only be used when the scheduling mode is + * {@link #SCHEDULING_MODE_RECURRENT}. Only positive values are accepted, zero or negative + * values will result in IllegalArgumentException. + * + * <p>Please also note this value is advisory, which does not guarantee the job will be run + * immediately after the interval expired. Federated compute will still enforce a minimum + * required interval and training constraints to ensure system health. The current training + * constraints are device on unmetered network, idle and battery not low. + */ + @DataClass.Generated.Member + public @NonNull Duration getMinimumInterval() { + return mMinimumInterval; + } + + @Override + @DataClass.Generated.Member + public boolean equals(@android.annotation.Nullable Object o) { + // You can override field equality logic by defining either of the methods like: + // boolean fieldNameEquals(TrainingInterval other) { ... } + // boolean fieldNameEquals(FieldType otherValue) { ... } + + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + @SuppressWarnings("unchecked") + TrainingInterval that = (TrainingInterval) o; + //noinspection PointlessBooleanExpression + return true + && mSchedulingMode == that.mSchedulingMode + && java.util.Objects.equals(mMinimumInterval, that.mMinimumInterval); + } + + @Override + @DataClass.Generated.Member + public int hashCode() { + // You can override field hashCode logic by defining methods like: + // int fieldNameHashCode() { ... } + + int _hash = 1; + _hash = 31 * _hash + mSchedulingMode; + _hash = 31 * _hash + java.util.Objects.hashCode(mMinimumInterval); + return _hash; + } + + /** + * A builder for {@link TrainingInterval} + */ + @SuppressWarnings("WeakerAccess") + @DataClass.Generated.Member + public static final class Builder { + + private @SchedulingMode int mSchedulingMode; + private @NonNull Duration mMinimumInterval; + + private long mBuilderFieldsSet = 0L; + + public Builder() { + } + + /** + * The scheduling mode for this task, either {@link #SCHEDULING_MODE_ONE_TIME} or + * {@link #SCHEDULING_MODE_RECURRENT}. The default scheduling mode is + * {@link #SCHEDULING_MODE_ONE_TIME} if unspecified. + */ + @DataClass.Generated.Member + public @NonNull Builder setSchedulingMode(@SchedulingMode int value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x1; + mSchedulingMode = value; + return this; + } + + /** + * Sets the minimum time interval between two training runs. + * + * <p>This field will only be used when the scheduling mode is + * {@link #SCHEDULING_MODE_RECURRENT}. Only positive values are accepted, zero or negative + * values will result in IllegalArgumentException. + * + * <p>Please also note this value is advisory, which does not guarantee the job will be run + * immediately after the interval expired. Federated compute will still enforce a minimum + * required interval and training constraints to ensure system health. The current training + * constraints are device on unmetered network, idle and battery not low. + */ + @DataClass.Generated.Member + public @NonNull Builder setMinimumInterval(@NonNull Duration value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x2; + mMinimumInterval = value; + return this; + } + + /** Builds the instance. This builder should not be touched after calling this! */ + public @NonNull TrainingInterval build() { + checkNotUsed(); + mBuilderFieldsSet |= 0x4; // Mark builder used + + if ((mBuilderFieldsSet & 0x1) == 0) { + mSchedulingMode = SCHEDULING_MODE_ONE_TIME; + } + if ((mBuilderFieldsSet & 0x2) == 0) { + mMinimumInterval = Duration.ZERO; + } + TrainingInterval o = new TrainingInterval( + mSchedulingMode, + mMinimumInterval); + return o; + } + + private void checkNotUsed() { + if ((mBuilderFieldsSet & 0x4) != 0) { + throw new IllegalStateException( + "This Builder should not be reused. Use a new Builder instance instead"); + } + } + } + + @DataClass.Generated( + time = 1697653739724L, + codegenVersion = "1.0.23", + sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/TrainingInterval.java", + inputSignatures = "public static final int SCHEDULING_MODE_ONE_TIME\npublic static final int SCHEDULING_MODE_RECURRENT\nprivate @android.adservices.ondevicepersonalization.TrainingInterval.SchedulingMode int mSchedulingMode\nprivate @android.annotation.NonNull java.time.Duration mMinimumInterval\nclass TrainingInterval extends java.lang.Object implements []\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genHiddenConstDefs=true, genEqualsHashCode=true)") + @Deprecated + private void __metadata() {} + + + //@formatter:on + // End of generated code + +} diff --git a/framework/java/android/adservices/ondevicepersonalization/UserData.java b/framework/java/android/adservices/ondevicepersonalization/UserData.java index eeb8cd50..644ea74a 100644 --- a/framework/java/android/adservices/ondevicepersonalization/UserData.java +++ b/framework/java/android/adservices/ondevicepersonalization/UserData.java @@ -16,16 +16,27 @@ package android.adservices.ondevicepersonalization; +import static android.adservices.ondevicepersonalization.Constants.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS; +import static android.content.res.Configuration.ORIENTATION_LANDSCAPE; +import static android.content.res.Configuration.ORIENTATION_PORTRAIT; +import static android.content.res.Configuration.ORIENTATION_SQUARE; +import static android.content.res.Configuration.ORIENTATION_UNDEFINED; + +import android.annotation.FlaggedApi; import android.annotation.IntDef; import android.annotation.IntRange; import android.annotation.NonNull; +import android.annotation.Nullable; +import android.net.NetworkCapabilities; import android.os.Parcelable; +import android.telephony.TelephonyManager; import com.android.ondevicepersonalization.internal.util.AnnotationValidations; import com.android.ondevicepersonalization.internal.util.DataClass; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; +import java.time.Duration; import java.util.Collections; import java.util.List; import java.util.Map; @@ -33,21 +44,36 @@ import java.util.Map; /** * User data provided by the platform to an {@link IsolatedService}. * - * @hide */ // This class should be updated with the Kotlin mirror // {@link com.android.ondevicepersonalization.services.policyengine.data.UserData}. -@DataClass(genBuilder = true, genEqualsHashCode = true) +@FlaggedApi(KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS) +@DataClass(genHiddenBuilder = true, genEqualsHashCode = true, genConstDefs = false) public final class UserData implements Parcelable { - /** The device timezone +/- minutes offset from UTC. */ + /** + * The device timezone +/- offset from UTC. + * + * @hide + */ int mTimezoneUtcOffsetMins = 0; + /** @hide **/ + @IntDef(prefix = {"ORIENTATION_"}, value = { + ORIENTATION_UNDEFINED, + ORIENTATION_PORTRAIT, + ORIENTATION_LANDSCAPE, + ORIENTATION_SQUARE + }) + @Retention(RetentionPolicy.SOURCE) + public @interface Orientation { + } + /** * The device orientation. The value can be one of the constants ORIENTATION_UNDEFINED, * ORIENTATION_PORTRAIT or ORIENTATION_LANDSCAPE defined in * {@link android.content.res.Configuration}. */ - int mOrientation = 0; + @Orientation int mOrientation = 0; /** The available space on device in bytes. */ @IntRange(from = 0) long mAvailableStorageBytes = 0; @@ -55,35 +81,59 @@ public final class UserData implements Parcelable { /** Battery percentage. */ @IntRange(from = 0, to = 100) int mBatteryPercentage = 0; - /** The name of the carrier. */ + /** The Service Provider Name (SPN) returned by {@link TelephonyManager#getSimOperatorName()} */ @NonNull String mCarrier = ""; - /** Connection type unknown. @hide */ - public static final int CONNECTION_TYPE_UNKNOWN = 0; - /** Connection type ethernet. @hide */ - public static final int CONNECTION_TYPE_ETHERNET = 1; - /** Connection type wifi. @hide */ - public static final int CONNECTION_TYPE_WIFI = 2; - /** Connection type cellular 2G. @hide */ - public static final int CONNECTION_TYPE_CELLULAR_2G = 3; - /** Connection type cellular 3G. @hide */ - public static final int CONNECTION_TYPE_CELLULAR_3G = 4; - /** Connection type cellular 4G. @hide */ - public static final int CONNECTION_TYPE_CELLULAR_4G = 5; - /** Connection type cellular 5G. @hide */ - public static final int CONNECTION_TYPE_CELLULAR_5G = 6; - - /** Connection types. @hide */ - @ConnectionType int mConnectionType = 0; - - /** Network connection speed in kbps. 0 if no network connection is present. @hide */ - @IntRange(from = 0) long mNetworkConnectionSpeedKbps = 0; - - /** Whether the network is metered. False - not metered. True - metered. @hide */ - boolean mNetworkMetered = false; - - /** The history of installed/uninstalled packages. */ - @NonNull Map<String, AppInstallInfo> mAppInstallInfo = Collections.emptyMap(); + /** @hide **/ + @IntDef({ + TelephonyManager.NETWORK_TYPE_UNKNOWN, + TelephonyManager.NETWORK_TYPE_GPRS, + TelephonyManager.NETWORK_TYPE_EDGE, + TelephonyManager.NETWORK_TYPE_UMTS, + TelephonyManager.NETWORK_TYPE_CDMA, + TelephonyManager.NETWORK_TYPE_EVDO_0, + TelephonyManager.NETWORK_TYPE_EVDO_A, + TelephonyManager.NETWORK_TYPE_1xRTT, + TelephonyManager.NETWORK_TYPE_HSDPA, + TelephonyManager.NETWORK_TYPE_HSUPA, + TelephonyManager.NETWORK_TYPE_HSPA, + TelephonyManager.NETWORK_TYPE_EVDO_B, + TelephonyManager.NETWORK_TYPE_LTE, + TelephonyManager.NETWORK_TYPE_EHRPD, + TelephonyManager.NETWORK_TYPE_HSPAP, + TelephonyManager.NETWORK_TYPE_GSM, + TelephonyManager.NETWORK_TYPE_TD_SCDMA, + TelephonyManager.NETWORK_TYPE_IWLAN, + + //TODO: In order for @SystemApi methods to use this class, there cannot be any + // public hidden members. This network type is marked as hidden because it is not a + // true network type and we are looking to remove it completely from the available list + // of network types. + //TelephonyManager.NETWORK_TYPE_LTE_CA, + + TelephonyManager.NETWORK_TYPE_NR, + }) + @Retention(RetentionPolicy.SOURCE) + public @interface NetworkType { + } + + /** + * Network capabilities of the device. This is the value of + * {@link android.net.ConnectivityManager#getNetworkCapabilities(android.net.Network)}. + * @hide + */ + @Nullable NetworkCapabilities mNetworkCapabilities = null; + + /** + * Data network type. This is the value of + * {@link android.telephony.TelephonyManager#getDataNetworkType()}. + * @hide + */ + @NetworkType int mDataNetworkType = 0; + + /** A map from package name to app information for installed and uninstalled apps. */ + @DataClass.PluralOf("appInfo") + @NonNull Map<String, AppInfo> mAppInfos = Collections.emptyMap(); /** The app usage history in the last 30 days, sorted by total time spent. @hide */ @NonNull List<AppUsageStatus> mAppUsageHistory = Collections.emptyList(); @@ -94,6 +144,11 @@ public final class UserData implements Parcelable { /** The location history in last 30 days, sorted by the stay duration. @hide */ @NonNull List<LocationStatus> mLocationHistory = Collections.emptyList(); + /** The device timezone +/- offset from UTC in {@link Duration}. @hide */ + @NonNull public Duration getTimezoneUtcOffset() { + return Duration.ofMinutes(mTimezoneUtcOffsetMins); + } + // Code below generated by codegen v1.0.23. @@ -109,58 +164,23 @@ public final class UserData implements Parcelable { //@formatter:off - /** @hide */ - @IntDef(prefix = "CONNECTION_TYPE_", value = { - CONNECTION_TYPE_UNKNOWN, - CONNECTION_TYPE_ETHERNET, - CONNECTION_TYPE_WIFI, - CONNECTION_TYPE_CELLULAR_2G, - CONNECTION_TYPE_CELLULAR_3G, - CONNECTION_TYPE_CELLULAR_4G, - CONNECTION_TYPE_CELLULAR_5G - }) - @Retention(RetentionPolicy.SOURCE) - @DataClass.Generated.Member - public @interface ConnectionType {} - - /** @hide */ - @DataClass.Generated.Member - @NonNull public static String connectionTypeToString(@ConnectionType int value) { - switch (value) { - case CONNECTION_TYPE_UNKNOWN: - return "CONNECTION_TYPE_UNKNOWN"; - case CONNECTION_TYPE_ETHERNET: - return "CONNECTION_TYPE_ETHERNET"; - case CONNECTION_TYPE_WIFI: - return "CONNECTION_TYPE_WIFI"; - case CONNECTION_TYPE_CELLULAR_2G: - return "CONNECTION_TYPE_CELLULAR_2G"; - case CONNECTION_TYPE_CELLULAR_3G: - return "CONNECTION_TYPE_CELLULAR_3G"; - case CONNECTION_TYPE_CELLULAR_4G: - return "CONNECTION_TYPE_CELLULAR_4G"; - case CONNECTION_TYPE_CELLULAR_5G: - return "CONNECTION_TYPE_CELLULAR_5G"; - default: return Integer.toHexString(value); - } - } - @DataClass.Generated.Member /* package-private */ UserData( int timezoneUtcOffsetMins, - int orientation, + @Orientation int orientation, @IntRange(from = 0) long availableStorageBytes, @IntRange(from = 0, to = 100) int batteryPercentage, @NonNull String carrier, - @ConnectionType int connectionType, - @IntRange(from = 0) long networkConnectionSpeedKbps, - boolean networkMetered, - @NonNull Map<String,AppInstallInfo> appInstallInfo, + @Nullable NetworkCapabilities networkCapabilities, + @NetworkType int dataNetworkType, + @NonNull Map<String,AppInfo> appInfos, @NonNull List<AppUsageStatus> appUsageHistory, @NonNull Location currentLocation, @NonNull List<LocationStatus> locationHistory) { this.mTimezoneUtcOffsetMins = timezoneUtcOffsetMins; this.mOrientation = orientation; + AnnotationValidations.validate( + Orientation.class, null, mOrientation); this.mAvailableStorageBytes = availableStorageBytes; AnnotationValidations.validate( IntRange.class, null, mAvailableStorageBytes, @@ -173,34 +193,13 @@ public final class UserData implements Parcelable { this.mCarrier = carrier; AnnotationValidations.validate( NonNull.class, null, mCarrier); - this.mConnectionType = connectionType; - - if (!(mConnectionType == CONNECTION_TYPE_UNKNOWN) - && !(mConnectionType == CONNECTION_TYPE_ETHERNET) - && !(mConnectionType == CONNECTION_TYPE_WIFI) - && !(mConnectionType == CONNECTION_TYPE_CELLULAR_2G) - && !(mConnectionType == CONNECTION_TYPE_CELLULAR_3G) - && !(mConnectionType == CONNECTION_TYPE_CELLULAR_4G) - && !(mConnectionType == CONNECTION_TYPE_CELLULAR_5G)) { - throw new java.lang.IllegalArgumentException( - "connectionType was " + mConnectionType + " but must be one of: " - + "CONNECTION_TYPE_UNKNOWN(" + CONNECTION_TYPE_UNKNOWN + "), " - + "CONNECTION_TYPE_ETHERNET(" + CONNECTION_TYPE_ETHERNET + "), " - + "CONNECTION_TYPE_WIFI(" + CONNECTION_TYPE_WIFI + "), " - + "CONNECTION_TYPE_CELLULAR_2G(" + CONNECTION_TYPE_CELLULAR_2G + "), " - + "CONNECTION_TYPE_CELLULAR_3G(" + CONNECTION_TYPE_CELLULAR_3G + "), " - + "CONNECTION_TYPE_CELLULAR_4G(" + CONNECTION_TYPE_CELLULAR_4G + "), " - + "CONNECTION_TYPE_CELLULAR_5G(" + CONNECTION_TYPE_CELLULAR_5G + ")"); - } - - this.mNetworkConnectionSpeedKbps = networkConnectionSpeedKbps; + this.mNetworkCapabilities = networkCapabilities; + this.mDataNetworkType = dataNetworkType; AnnotationValidations.validate( - IntRange.class, null, mNetworkConnectionSpeedKbps, - "from", 0); - this.mNetworkMetered = networkMetered; - this.mAppInstallInfo = appInstallInfo; + NetworkType.class, null, mDataNetworkType); + this.mAppInfos = appInfos; AnnotationValidations.validate( - NonNull.class, null, mAppInstallInfo); + NonNull.class, null, mAppInfos); this.mAppUsageHistory = appUsageHistory; AnnotationValidations.validate( NonNull.class, null, mAppUsageHistory); @@ -215,7 +214,9 @@ public final class UserData implements Parcelable { } /** - * The device timezone +/- minutes offset from UTC. + * The device timezone +/- offset from UTC. + * + * @hide */ @DataClass.Generated.Member public int getTimezoneUtcOffsetMins() { @@ -228,7 +229,7 @@ public final class UserData implements Parcelable { * {@link android.content.res.Configuration}. */ @DataClass.Generated.Member - public int getOrientation() { + public @Orientation int getOrientation() { return mOrientation; } @@ -249,7 +250,7 @@ public final class UserData implements Parcelable { } /** - * The name of the carrier. + * The Service Provider Name (SPN) returned by {@link TelephonyManager#getSimOperatorName()} */ @DataClass.Generated.Member public @NonNull String getCarrier() { @@ -257,35 +258,33 @@ public final class UserData implements Parcelable { } /** - * Connection types. @hide + * Network capabilities of the device. This is the value of + * {@link android.net.ConnectivityManager#getNetworkCapabilities(android.net.Network)}. + * + * @hide */ @DataClass.Generated.Member - public @ConnectionType int getConnectionType() { - return mConnectionType; + public @Nullable NetworkCapabilities getNetworkCapabilities() { + return mNetworkCapabilities; } /** - * Network connection speed in kbps. 0 if no network connection is present. @hide + * Data network type. This is the value of + * {@link android.telephony.TelephonyManager#getDataNetworkType()}. + * + * @hide */ @DataClass.Generated.Member - public @IntRange(from = 0) long getNetworkConnectionSpeedKbps() { - return mNetworkConnectionSpeedKbps; + public @NetworkType int getDataNetworkType() { + return mDataNetworkType; } /** - * Whether the network is metered. False - not metered. True - metered. @hide + * A map from package name to app information for installed and uninstalled apps. */ @DataClass.Generated.Member - public boolean isNetworkMetered() { - return mNetworkMetered; - } - - /** - * The history of installed/uninstalled packages. - */ - @DataClass.Generated.Member - public @NonNull Map<String,AppInstallInfo> getAppInstallInfo() { - return mAppInstallInfo; + public @NonNull Map<String,AppInfo> getAppInfos() { + return mAppInfos; } /** @@ -314,7 +313,7 @@ public final class UserData implements Parcelable { @Override @DataClass.Generated.Member - public boolean equals(@android.annotation.Nullable Object o) { + public boolean equals(@Nullable Object o) { // You can override field equality logic by defining either of the methods like: // boolean fieldNameEquals(UserData other) { ... } // boolean fieldNameEquals(FieldType otherValue) { ... } @@ -330,10 +329,9 @@ public final class UserData implements Parcelable { && mAvailableStorageBytes == that.mAvailableStorageBytes && mBatteryPercentage == that.mBatteryPercentage && java.util.Objects.equals(mCarrier, that.mCarrier) - && mConnectionType == that.mConnectionType - && mNetworkConnectionSpeedKbps == that.mNetworkConnectionSpeedKbps - && mNetworkMetered == that.mNetworkMetered - && java.util.Objects.equals(mAppInstallInfo, that.mAppInstallInfo) + && java.util.Objects.equals(mNetworkCapabilities, that.mNetworkCapabilities) + && mDataNetworkType == that.mDataNetworkType + && java.util.Objects.equals(mAppInfos, that.mAppInfos) && java.util.Objects.equals(mAppUsageHistory, that.mAppUsageHistory) && java.util.Objects.equals(mCurrentLocation, that.mCurrentLocation) && java.util.Objects.equals(mLocationHistory, that.mLocationHistory); @@ -351,10 +349,9 @@ public final class UserData implements Parcelable { _hash = 31 * _hash + Long.hashCode(mAvailableStorageBytes); _hash = 31 * _hash + mBatteryPercentage; _hash = 31 * _hash + java.util.Objects.hashCode(mCarrier); - _hash = 31 * _hash + mConnectionType; - _hash = 31 * _hash + Long.hashCode(mNetworkConnectionSpeedKbps); - _hash = 31 * _hash + Boolean.hashCode(mNetworkMetered); - _hash = 31 * _hash + java.util.Objects.hashCode(mAppInstallInfo); + _hash = 31 * _hash + java.util.Objects.hashCode(mNetworkCapabilities); + _hash = 31 * _hash + mDataNetworkType; + _hash = 31 * _hash + java.util.Objects.hashCode(mAppInfos); _hash = 31 * _hash + java.util.Objects.hashCode(mAppUsageHistory); _hash = 31 * _hash + java.util.Objects.hashCode(mCurrentLocation); _hash = 31 * _hash + java.util.Objects.hashCode(mLocationHistory); @@ -368,16 +365,16 @@ public final class UserData implements Parcelable { // void parcelFieldName(Parcel dest, int flags) { ... } int flg = 0; - if (mNetworkMetered) flg |= 0x80; + if (mNetworkCapabilities != null) flg |= 0x20; dest.writeInt(flg); dest.writeInt(mTimezoneUtcOffsetMins); dest.writeInt(mOrientation); dest.writeLong(mAvailableStorageBytes); dest.writeInt(mBatteryPercentage); dest.writeString(mCarrier); - dest.writeInt(mConnectionType); - dest.writeLong(mNetworkConnectionSpeedKbps); - dest.writeMap(mAppInstallInfo); + if (mNetworkCapabilities != null) dest.writeTypedObject(mNetworkCapabilities, flags); + dest.writeInt(mDataNetworkType); + dest.writeMap(mAppInfos); dest.writeParcelableList(mAppUsageHistory, flags); dest.writeTypedObject(mCurrentLocation, flags); dest.writeParcelableList(mLocationHistory, flags); @@ -395,16 +392,15 @@ public final class UserData implements Parcelable { // static FieldType unparcelFieldName(Parcel in) { ... } int flg = in.readInt(); - boolean networkMetered = (flg & 0x80) != 0; int timezoneUtcOffsetMins = in.readInt(); int orientation = in.readInt(); long availableStorageBytes = in.readLong(); int batteryPercentage = in.readInt(); String carrier = in.readString(); - int connectionType = in.readInt(); - long networkConnectionSpeedKbps = in.readLong(); - Map<String,AppInstallInfo> appInstallInfo = new java.util.LinkedHashMap<>(); - in.readMap(appInstallInfo, AppInstallInfo.class.getClassLoader()); + NetworkCapabilities networkCapabilities = (flg & 0x20) == 0 ? null : (NetworkCapabilities) in.readTypedObject(NetworkCapabilities.CREATOR); + int dataNetworkType = in.readInt(); + Map<String,AppInfo> appInfos = new java.util.LinkedHashMap<>(); + in.readMap(appInfos, AppInfo.class.getClassLoader()); List<AppUsageStatus> appUsageHistory = new java.util.ArrayList<>(); in.readParcelableList(appUsageHistory, AppUsageStatus.class.getClassLoader()); Location currentLocation = (Location) in.readTypedObject(Location.CREATOR); @@ -413,6 +409,8 @@ public final class UserData implements Parcelable { this.mTimezoneUtcOffsetMins = timezoneUtcOffsetMins; this.mOrientation = orientation; + AnnotationValidations.validate( + Orientation.class, null, mOrientation); this.mAvailableStorageBytes = availableStorageBytes; AnnotationValidations.validate( IntRange.class, null, mAvailableStorageBytes, @@ -425,34 +423,13 @@ public final class UserData implements Parcelable { this.mCarrier = carrier; AnnotationValidations.validate( NonNull.class, null, mCarrier); - this.mConnectionType = connectionType; - - if (!(mConnectionType == CONNECTION_TYPE_UNKNOWN) - && !(mConnectionType == CONNECTION_TYPE_ETHERNET) - && !(mConnectionType == CONNECTION_TYPE_WIFI) - && !(mConnectionType == CONNECTION_TYPE_CELLULAR_2G) - && !(mConnectionType == CONNECTION_TYPE_CELLULAR_3G) - && !(mConnectionType == CONNECTION_TYPE_CELLULAR_4G) - && !(mConnectionType == CONNECTION_TYPE_CELLULAR_5G)) { - throw new java.lang.IllegalArgumentException( - "connectionType was " + mConnectionType + " but must be one of: " - + "CONNECTION_TYPE_UNKNOWN(" + CONNECTION_TYPE_UNKNOWN + "), " - + "CONNECTION_TYPE_ETHERNET(" + CONNECTION_TYPE_ETHERNET + "), " - + "CONNECTION_TYPE_WIFI(" + CONNECTION_TYPE_WIFI + "), " - + "CONNECTION_TYPE_CELLULAR_2G(" + CONNECTION_TYPE_CELLULAR_2G + "), " - + "CONNECTION_TYPE_CELLULAR_3G(" + CONNECTION_TYPE_CELLULAR_3G + "), " - + "CONNECTION_TYPE_CELLULAR_4G(" + CONNECTION_TYPE_CELLULAR_4G + "), " - + "CONNECTION_TYPE_CELLULAR_5G(" + CONNECTION_TYPE_CELLULAR_5G + ")"); - } - - this.mNetworkConnectionSpeedKbps = networkConnectionSpeedKbps; + this.mNetworkCapabilities = networkCapabilities; + this.mDataNetworkType = dataNetworkType; AnnotationValidations.validate( - IntRange.class, null, mNetworkConnectionSpeedKbps, - "from", 0); - this.mNetworkMetered = networkMetered; - this.mAppInstallInfo = appInstallInfo; + NetworkType.class, null, mDataNetworkType); + this.mAppInfos = appInfos; AnnotationValidations.validate( - NonNull.class, null, mAppInstallInfo); + NonNull.class, null, mAppInfos); this.mAppUsageHistory = appUsageHistory; AnnotationValidations.validate( NonNull.class, null, mAppUsageHistory); @@ -482,20 +459,20 @@ public final class UserData implements Parcelable { /** * A builder for {@link UserData} + * @hide */ @SuppressWarnings("WeakerAccess") @DataClass.Generated.Member public static final class Builder { private int mTimezoneUtcOffsetMins; - private int mOrientation; + private @Orientation int mOrientation; private @IntRange(from = 0) long mAvailableStorageBytes; private @IntRange(from = 0, to = 100) int mBatteryPercentage; private @NonNull String mCarrier; - private @ConnectionType int mConnectionType; - private @IntRange(from = 0) long mNetworkConnectionSpeedKbps; - private boolean mNetworkMetered; - private @NonNull Map<String,AppInstallInfo> mAppInstallInfo; + private @Nullable NetworkCapabilities mNetworkCapabilities; + private @NetworkType int mDataNetworkType; + private @NonNull Map<String,AppInfo> mAppInfos; private @NonNull List<AppUsageStatus> mAppUsageHistory; private @NonNull Location mCurrentLocation; private @NonNull List<LocationStatus> mLocationHistory; @@ -506,7 +483,9 @@ public final class UserData implements Parcelable { } /** - * The device timezone +/- minutes offset from UTC. + * The device timezone +/- offset from UTC. + * + * @hide */ @DataClass.Generated.Member public @NonNull Builder setTimezoneUtcOffsetMins(int value) { @@ -522,7 +501,7 @@ public final class UserData implements Parcelable { * {@link android.content.res.Configuration}. */ @DataClass.Generated.Member - public @NonNull Builder setOrientation(int value) { + public @NonNull Builder setOrientation(@Orientation int value) { checkNotUsed(); mBuilderFieldsSet |= 0x2; mOrientation = value; @@ -552,7 +531,7 @@ public final class UserData implements Parcelable { } /** - * The name of the carrier. + * The Service Provider Name (SPN) returned by {@link TelephonyManager#getSimOperatorName()} */ @DataClass.Generated.Member public @NonNull Builder setCarrier(@NonNull String value) { @@ -563,46 +542,49 @@ public final class UserData implements Parcelable { } /** - * Connection types. @hide + * Network capabilities of the device. This is the value of + * {@link android.net.ConnectivityManager#getNetworkCapabilities(android.net.Network)}. + * + * @hide */ @DataClass.Generated.Member - public @NonNull Builder setConnectionType(@ConnectionType int value) { + public @NonNull Builder setNetworkCapabilities(@NonNull NetworkCapabilities value) { checkNotUsed(); mBuilderFieldsSet |= 0x20; - mConnectionType = value; + mNetworkCapabilities = value; return this; } /** - * Network connection speed in kbps. 0 if no network connection is present. @hide + * Data network type. This is the value of + * {@link android.telephony.TelephonyManager#getDataNetworkType()}. + * + * @hide */ @DataClass.Generated.Member - public @NonNull Builder setNetworkConnectionSpeedKbps(@IntRange(from = 0) long value) { + public @NonNull Builder setDataNetworkType(@NetworkType int value) { checkNotUsed(); mBuilderFieldsSet |= 0x40; - mNetworkConnectionSpeedKbps = value; + mDataNetworkType = value; return this; } /** - * Whether the network is metered. False - not metered. True - metered. @hide + * A map from package name to app information for installed and uninstalled apps. */ @DataClass.Generated.Member - public @NonNull Builder setNetworkMetered(boolean value) { + public @NonNull Builder setAppInfos(@NonNull Map<String,AppInfo> value) { checkNotUsed(); mBuilderFieldsSet |= 0x80; - mNetworkMetered = value; + mAppInfos = value; return this; } - /** - * The history of installed/uninstalled packages. - */ + /** @see #setAppInfos */ @DataClass.Generated.Member - public @NonNull Builder setAppInstallInfo(@NonNull Map<String,AppInstallInfo> value) { - checkNotUsed(); - mBuilderFieldsSet |= 0x100; - mAppInstallInfo = value; + public @NonNull Builder addAppInfo(@NonNull String key, @NonNull AppInfo value) { + if (mAppInfos == null) setAppInfos(new java.util.LinkedHashMap()); + mAppInfos.put(key, value); return this; } @@ -612,18 +594,29 @@ public final class UserData implements Parcelable { @DataClass.Generated.Member public @NonNull Builder setAppUsageHistory(@NonNull List<AppUsageStatus> value) { checkNotUsed(); - mBuilderFieldsSet |= 0x200; + mBuilderFieldsSet |= 0x100; mAppUsageHistory = value; return this; } + /** @see #setAppUsageHistory */ + @DataClass.Generated.Member + public @NonNull Builder addAppUsageHistory(@NonNull AppUsageStatus value) { + // You can refine this method's name by providing item's singular name, e.g.: + // @DataClass.PluralOf("item")) mItems = ... + + if (mAppUsageHistory == null) setAppUsageHistory(new java.util.ArrayList<>()); + mAppUsageHistory.add(value); + return this; + } + /** * The most recently known location. @hide */ @DataClass.Generated.Member public @NonNull Builder setCurrentLocation(@NonNull Location value) { checkNotUsed(); - mBuilderFieldsSet |= 0x400; + mBuilderFieldsSet |= 0x200; mCurrentLocation = value; return this; } @@ -634,15 +627,26 @@ public final class UserData implements Parcelable { @DataClass.Generated.Member public @NonNull Builder setLocationHistory(@NonNull List<LocationStatus> value) { checkNotUsed(); - mBuilderFieldsSet |= 0x800; + mBuilderFieldsSet |= 0x400; mLocationHistory = value; return this; } + /** @see #setLocationHistory */ + @DataClass.Generated.Member + public @NonNull Builder addLocationHistory(@NonNull LocationStatus value) { + // You can refine this method's name by providing item's singular name, e.g.: + // @DataClass.PluralOf("item")) mItems = ... + + if (mLocationHistory == null) setLocationHistory(new java.util.ArrayList<>()); + mLocationHistory.add(value); + return this; + } + /** Builds the instance. This builder should not be touched after calling this! */ public @NonNull UserData build() { checkNotUsed(); - mBuilderFieldsSet |= 0x1000; // Mark builder used + mBuilderFieldsSet |= 0x800; // Mark builder used if ((mBuilderFieldsSet & 0x1) == 0) { mTimezoneUtcOffsetMins = 0; @@ -660,24 +664,21 @@ public final class UserData implements Parcelable { mCarrier = ""; } if ((mBuilderFieldsSet & 0x20) == 0) { - mConnectionType = 0; + mNetworkCapabilities = null; } if ((mBuilderFieldsSet & 0x40) == 0) { - mNetworkConnectionSpeedKbps = 0; + mDataNetworkType = 0; } if ((mBuilderFieldsSet & 0x80) == 0) { - mNetworkMetered = false; + mAppInfos = Collections.emptyMap(); } if ((mBuilderFieldsSet & 0x100) == 0) { - mAppInstallInfo = Collections.emptyMap(); - } - if ((mBuilderFieldsSet & 0x200) == 0) { mAppUsageHistory = Collections.emptyList(); } - if ((mBuilderFieldsSet & 0x400) == 0) { + if ((mBuilderFieldsSet & 0x200) == 0) { mCurrentLocation = Location.EMPTY; } - if ((mBuilderFieldsSet & 0x800) == 0) { + if ((mBuilderFieldsSet & 0x400) == 0) { mLocationHistory = Collections.emptyList(); } UserData o = new UserData( @@ -686,10 +687,9 @@ public final class UserData implements Parcelable { mAvailableStorageBytes, mBatteryPercentage, mCarrier, - mConnectionType, - mNetworkConnectionSpeedKbps, - mNetworkMetered, - mAppInstallInfo, + mNetworkCapabilities, + mDataNetworkType, + mAppInfos, mAppUsageHistory, mCurrentLocation, mLocationHistory); @@ -697,7 +697,7 @@ public final class UserData implements Parcelable { } private void checkNotUsed() { - if ((mBuilderFieldsSet & 0x1000) != 0) { + if ((mBuilderFieldsSet & 0x800) != 0) { throw new IllegalStateException( "This Builder should not be reused. Use a new Builder instance instead"); } @@ -705,10 +705,10 @@ public final class UserData implements Parcelable { } @DataClass.Generated( - time = 1693528589621L, + time = 1697063796387L, codegenVersion = "1.0.23", sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/UserData.java", - inputSignatures = " int mTimezoneUtcOffsetMins\n int mOrientation\n @android.annotation.IntRange long mAvailableStorageBytes\n @android.annotation.IntRange int mBatteryPercentage\n @android.annotation.NonNull java.lang.String mCarrier\npublic static final int CONNECTION_TYPE_UNKNOWN\npublic static final int CONNECTION_TYPE_ETHERNET\npublic static final int CONNECTION_TYPE_WIFI\npublic static final int CONNECTION_TYPE_CELLULAR_2G\npublic static final int CONNECTION_TYPE_CELLULAR_3G\npublic static final int CONNECTION_TYPE_CELLULAR_4G\npublic static final int CONNECTION_TYPE_CELLULAR_5G\n @android.adservices.ondevicepersonalization.UserData.ConnectionType int mConnectionType\n @android.annotation.IntRange long mNetworkConnectionSpeedKbps\n boolean mNetworkMetered\n @android.annotation.NonNull java.util.Map<java.lang.String,android.adservices.ondevicepersonalization.AppInstallInfo> mAppInstallInfo\n @android.annotation.NonNull java.util.List<android.adservices.ondevicepersonalization.AppUsageStatus> mAppUsageHistory\n @android.annotation.NonNull android.adservices.ondevicepersonalization.Location mCurrentLocation\n @android.annotation.NonNull java.util.List<android.adservices.ondevicepersonalization.LocationStatus> mLocationHistory\nclass UserData extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") + inputSignatures = " int mTimezoneUtcOffsetMins\n @android.adservices.ondevicepersonalization.UserData.Orientation int mOrientation\n @android.annotation.IntRange long mAvailableStorageBytes\n @android.annotation.IntRange int mBatteryPercentage\n @android.annotation.NonNull java.lang.String mCarrier\n @android.annotation.Nullable android.net.NetworkCapabilities mNetworkCapabilities\n @android.adservices.ondevicepersonalization.UserData.NetworkType int mDataNetworkType\n @com.android.ondevicepersonalization.internal.util.DataClass.PluralOf(\"appInfo\") @android.annotation.NonNull java.util.Map<java.lang.String,android.adservices.ondevicepersonalization.AppInfo> mAppInfos\n @android.annotation.NonNull java.util.List<android.adservices.ondevicepersonalization.AppUsageStatus> mAppUsageHistory\n @android.annotation.NonNull android.adservices.ondevicepersonalization.Location mCurrentLocation\n @android.annotation.NonNull java.util.List<android.adservices.ondevicepersonalization.LocationStatus> mLocationHistory\npublic @android.annotation.NonNull java.time.Duration getTimezoneUtcOffset()\nclass UserData extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genHiddenBuilder=true, genEqualsHashCode=true, genConstDefs=false)") @Deprecated private void __metadata() {} diff --git a/framework/java/android/adservices/ondevicepersonalization/aidl/IFederatedComputeCallback.aidl b/framework/java/android/adservices/ondevicepersonalization/aidl/IFederatedComputeCallback.aidl new file mode 100644 index 00000000..b224b5ee --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/aidl/IFederatedComputeCallback.aidl @@ -0,0 +1,13 @@ +package android.adservices.ondevicepersonalization.aidl; + + +/** + * Callback from a schedule/cancel federated computation request. + * @hide + */ +oneway interface IFederatedComputeCallback { + /** Sends back a void indicating success. */ + void onSuccess(); + /** Sends back a status code indicating failure. */ + void onFailure(int errorCode); +}
\ No newline at end of file diff --git a/framework/java/android/adservices/ondevicepersonalization/aidl/IFederatedComputeService.aidl b/framework/java/android/adservices/ondevicepersonalization/aidl/IFederatedComputeService.aidl new file mode 100644 index 00000000..b81b7621 --- /dev/null +++ b/framework/java/android/adservices/ondevicepersonalization/aidl/IFederatedComputeService.aidl @@ -0,0 +1,30 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization.aidl; +import android.federatedcompute.common.TrainingOptions; +import android.adservices.ondevicepersonalization.aidl.IFederatedComputeCallback; + +/** @hide */ +interface IFederatedComputeService { + void schedule( + in TrainingOptions trainingOptions, + in IFederatedComputeCallback callback); + + void cancel( + in String populationName, + in IFederatedComputeCallback callback); +}
\ No newline at end of file diff --git a/framework/java/android/adservices/ondevicepersonalization/aidl/IPrivacyStatusService.aidl b/framework/java/android/adservices/ondevicepersonalization/aidl/IOnDevicePersonalizationConfigService.aidl index 5e85f98c..f2d04e9a 100644 --- a/framework/java/android/adservices/ondevicepersonalization/aidl/IPrivacyStatusService.aidl +++ b/framework/java/android/adservices/ondevicepersonalization/aidl/IOnDevicePersonalizationConfigService.aidl @@ -16,14 +16,15 @@ package android.adservices.ondevicepersonalization.aidl; -import android.adservices.ondevicepersonalization.aidl.IPrivacyStatusServiceCallback; +import android.adservices.ondevicepersonalization.aidl.IOnDevicePersonalizationConfigServiceCallback; /** * OnDevicePersonalization service that modifies - * user's privacy status by GMS Core only. + * ODP's enablement status by GMS Core only. * @hide */ -interface IPrivacyStatusService { +interface IOnDevicePersonalizationConfigService { - void setKidStatus(in boolean kidStatusEnabled, in IPrivacyStatusServiceCallback callback); + void setPersonalizationStatus(in boolean enabled, + in IOnDevicePersonalizationConfigServiceCallback callback); }
\ No newline at end of file diff --git a/framework/java/android/adservices/ondevicepersonalization/aidl/IPrivacyStatusServiceCallback.aidl b/framework/java/android/adservices/ondevicepersonalization/aidl/IOnDevicePersonalizationConfigServiceCallback.aidl index e97d9c50..74092b41 100644 --- a/framework/java/android/adservices/ondevicepersonalization/aidl/IPrivacyStatusServiceCallback.aidl +++ b/framework/java/android/adservices/ondevicepersonalization/aidl/IOnDevicePersonalizationConfigServiceCallback.aidl @@ -19,10 +19,10 @@ package android.adservices.ondevicepersonalization.aidl; import android.os.Bundle; /** - * Callback from a OdpPrivacyStatusService. + * Callback from a OnDevicePersonalizationConfigService. * @hide */ -oneway interface IPrivacyStatusServiceCallback { +oneway interface IOnDevicePersonalizationConfigServiceCallback { void onSuccess(); diff --git a/src/com/android/ondevicepersonalization/services/data/user/OSVersion.java b/framework/java/android/adservices/ondevicepersonalization/aidl/IOnDevicePersonalizationDebugService.aidl index e9355051..c4e1de6a 100644 --- a/src/com/android/ondevicepersonalization/services/data/user/OSVersion.java +++ b/framework/java/android/adservices/ondevicepersonalization/aidl/IOnDevicePersonalizationDebugService.aidl @@ -14,11 +14,9 @@ * limitations under the License. */ -package com.android.ondevicepersonalization.services.data.user; +package android.adservices.ondevicepersonalization.aidl; -/** Values for OS versions. */ -public class OSVersion { - public int major = 0; - public int minor = 0; - public int micro = 0; +/** @hide */ +interface IOnDevicePersonalizationDebugService { + boolean isEnabled(); } diff --git a/framework/java/android/adservices/ondevicepersonalization/aidl/IOnDevicePersonalizationManagingService.aidl b/framework/java/android/adservices/ondevicepersonalization/aidl/IOnDevicePersonalizationManagingService.aidl index 66cde536..a749fde6 100644 --- a/framework/java/android/adservices/ondevicepersonalization/aidl/IOnDevicePersonalizationManagingService.aidl +++ b/framework/java/android/adservices/ondevicepersonalization/aidl/IOnDevicePersonalizationManagingService.aidl @@ -17,6 +17,7 @@ package android.adservices.ondevicepersonalization.aidl; import android.content.ComponentName; +import android.adservices.ondevicepersonalization.CallerMetadata; import android.adservices.ondevicepersonalization.aidl.IExecuteCallback; import android.adservices.ondevicepersonalization.aidl.IRequestSurfacePackageCallback; import android.os.Bundle; @@ -28,6 +29,7 @@ interface IOnDevicePersonalizationManagingService { in String callingPackageName, in ComponentName handler, in PersistableBundle params, + in CallerMetadata metadata, in IExecuteCallback callback); void requestSurfacePackage( @@ -36,5 +38,6 @@ interface IOnDevicePersonalizationManagingService { int displayId, int width, int height, + in CallerMetadata metadata, in IRequestSurfacePackageCallback callback); } diff --git a/framework/java/android/federatedcompute/ExampleStoreQueryCallbackImpl.java b/framework/java/android/federatedcompute/ExampleStoreQueryCallbackImpl.java index 2d56e00b..7b491885 100644 --- a/framework/java/android/federatedcompute/ExampleStoreQueryCallbackImpl.java +++ b/framework/java/android/federatedcompute/ExampleStoreQueryCallbackImpl.java @@ -23,8 +23,8 @@ import android.federatedcompute.aidl.IExampleStoreIterator; import android.federatedcompute.aidl.IExampleStoreIteratorCallback; import android.os.Bundle; import android.os.RemoteException; -import android.util.Log; +import com.android.federatedcompute.internal.util.LogUtil; import com.android.internal.util.Preconditions; /** @@ -47,7 +47,7 @@ public class ExampleStoreQueryCallbackImpl implements QueryCallback { try { mExampleStoreQueryCallback.onStartQuerySuccess(iteratorAdapter); } catch (RemoteException e) { - Log.w(TAG, "onIteratorNextSuccess AIDL call failed, closing iterator", e); + LogUtil.w(TAG, e, "onIteratorNextSuccess AIDL call failed, closing iterator"); iteratorAdapter.close(); } } @@ -57,7 +57,7 @@ public class ExampleStoreQueryCallbackImpl implements QueryCallback { try { mExampleStoreQueryCallback.onStartQueryFailure(errorCode); } catch (RemoteException e) { - Log.w(TAG, "onIteratorNextFailure AIDL call failed, closing iterator", e); + LogUtil.w(TAG, e, "onIteratorNextFailure AIDL call failed, closing iterator"); } } /** @@ -79,7 +79,7 @@ public class ExampleStoreQueryCallbackImpl implements QueryCallback { Preconditions.checkNotNull(callback, "callback must not be null"); synchronized (mLock) { if (mClosed) { - Log.w(TAG, "IExampleStoreIterator.next called after close"); + LogUtil.w(TAG, "IExampleStoreIterator.next called after close"); return; } IteratorCallbackAdapter callbackAdapter = @@ -92,7 +92,7 @@ public class ExampleStoreQueryCallbackImpl implements QueryCallback { public void close() { synchronized (mLock) { if (mClosed) { - Log.w(TAG, "IExampleStoreIterator.close called more than once"); + LogUtil.w(TAG, "IExampleStoreIterator.close called more than once"); return; } mClosed = true; @@ -124,7 +124,7 @@ public class ExampleStoreQueryCallbackImpl implements QueryCallback { mExampleStoreIteratorCallback.onIteratorNextSuccess(result); return true; } catch (RemoteException e) { - Log.w(TAG, "onIteratorNextSuccess AIDL call failed, closing iterator", e); + LogUtil.w(TAG, e, "onIteratorNextSuccess AIDL call failed, closing iterator"); mIteratorAdapter.close(); } return false; @@ -135,7 +135,7 @@ public class ExampleStoreQueryCallbackImpl implements QueryCallback { try { mExampleStoreIteratorCallback.onIteratorNextFailure(errorCode); } catch (RemoteException e) { - Log.w(TAG, "onIteratorNextFailure AIDL call failed, closing iterator", e); + LogUtil.w(TAG, e, "onIteratorNextFailure AIDL call failed, closing iterator"); mIteratorAdapter.close(); } } diff --git a/framework/java/android/federatedcompute/ExampleStoreService.java b/framework/java/android/federatedcompute/ExampleStoreService.java index f35af2c2..7daadd3c 100644 --- a/framework/java/android/federatedcompute/ExampleStoreService.java +++ b/framework/java/android/federatedcompute/ExampleStoreService.java @@ -19,6 +19,7 @@ package android.federatedcompute; import android.annotation.NonNull; import android.app.Service; import android.content.Intent; +import android.content.pm.PackageManager; import android.federatedcompute.aidl.IExampleStoreCallback; import android.federatedcompute.aidl.IExampleStoreService; import android.os.Bundle; @@ -37,7 +38,6 @@ import android.os.IBinder; * <service android:enabled="true" android:exported="true" android:name=".YourServiceClass"> * <intent-filter> * <action android:name="com.android.federatedcompute.EXAMPLE_STORE"/> - * <data android:scheme="app"/> * </intent-filter> * </service> * </application> @@ -46,7 +46,10 @@ import android.os.IBinder; * @hide */ public abstract class ExampleStoreService extends Service { - private static final String TAG = "ExampleStoreService"; + private static final String TAG = ExampleStoreService.class.getSimpleName(); + + private static final String BIND_EXAMPLE_STORE_SERVICE = + "android.permission.BIND_EXAMPLE_STORE_SERVICE"; private IBinder mIBinder; @Override @@ -62,15 +65,32 @@ public abstract class ExampleStoreService extends Service { class ServiceBinder extends IExampleStoreService.Stub { @Override public void startQuery(Bundle params, IExampleStoreCallback callback) { + if (!ExampleStoreService.this.checkCallerPermission()) { + throw new SecurityException( + "Unauthorized startQuery call to ExampleStore."); + } ExampleStoreService.this.startQuery( params, new ExampleStoreQueryCallbackImpl(callback)); } } + + /** + * To be overridden by implementation to provide checks if caller has specific permission to be + * used in ServiceBinder call. + * + * @return true if permission granted + */ + protected boolean checkCallerPermission() { + return checkCallingOrSelfPermission(BIND_EXAMPLE_STORE_SERVICE) + == PackageManager.PERMISSION_GRANTED; + } + /** * The abstract method that client apps should implement to start a new example store query * using the given selection criteria. */ public abstract void startQuery(@NonNull Bundle params, @NonNull QueryCallback callback); + /** * The client apps use this callback to return their ExampleStoreIterator implementation to the * federated training service. @@ -78,7 +98,8 @@ public abstract class ExampleStoreService extends Service { public interface QueryCallback { /** Called when the iterator is ready for use. */ void onStartQuerySuccess(@NonNull ExampleStoreIterator iterator); - /** Called when an error occurred and the iterator cannot not be created. */ + + /** Called when an error occurred and the iterator cannot be created. */ void onStartQueryFailure(int errorCode); } } diff --git a/framework/java/android/federatedcompute/FederatedComputeException.java b/framework/java/android/federatedcompute/FederatedComputeException.java new file mode 100644 index 00000000..8fdfa2e6 --- /dev/null +++ b/framework/java/android/federatedcompute/FederatedComputeException.java @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.federatedcompute; + +import android.annotation.IntDef; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +/** + * Exception thrown by Federated Compute APIs. + * @hide + */ +public class FederatedComputeException extends Exception { + /** + * Internal error. + */ + public static final int INTERNAL_ERROR = 1; + + /** @hide */ + @IntDef(prefix = "ERROR_", value = { + INTERNAL_ERROR + }) + @Retention(RetentionPolicy.SOURCE) + public @interface ErrorCode {} + + private final @ErrorCode int mErrorCode; + + /** @hide */ + public FederatedComputeException(@ErrorCode int errorCode) { + mErrorCode = errorCode; + } + + /** Returns the error code for this exception. */ + public @ErrorCode int getErrorCode() { + return mErrorCode; + } +} diff --git a/framework/java/android/federatedcompute/FederatedComputeManager.java b/framework/java/android/federatedcompute/FederatedComputeManager.java index be971bd4..2447d837 100644 --- a/framework/java/android/federatedcompute/FederatedComputeManager.java +++ b/framework/java/android/federatedcompute/FederatedComputeManager.java @@ -16,31 +16,20 @@ package android.federatedcompute; -import static java.util.concurrent.TimeUnit.MILLISECONDS; - -import android.adservices.ondevicepersonalization.OnDevicePersonalizationException; import android.annotation.CallbackExecutor; import android.annotation.NonNull; -import android.annotation.Nullable; -import android.content.ComponentName; import android.content.Context; -import android.content.Intent; -import android.content.ServiceConnection; -import android.content.pm.ResolveInfo; -import android.content.pm.ServiceInfo; import android.federatedcompute.aidl.IFederatedComputeCallback; import android.federatedcompute.aidl.IFederatedComputeService; import android.federatedcompute.common.ScheduleFederatedComputeRequest; -import android.os.IBinder; import android.os.OutcomeReceiver; import android.os.RemoteException; -import android.util.Log; -import com.android.internal.annotations.GuardedBy; +import com.android.federatedcompute.internal.util.AbstractServiceBinder; +import com.android.federatedcompute.internal.util.LogUtil; import java.util.List; import java.util.Objects; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; /** @@ -49,27 +38,37 @@ import java.util.concurrent.Executor; * @hide */ public final class FederatedComputeManager { - private static final String TAG = "FederatedComputeManager"; + /** + * Constant that represents the service name for {@link FederatedComputeManager} to be used in + * {@link android.ondevicepersonalization.OnDevicePersonalizationFrameworkInitializer + * #registerServiceWrappers} + * + * @hide + */ + public static final String FEDERATED_COMPUTE_SERVICE = "federated_compute_service"; + + private static final String TAG = FederatedComputeManager.class.getSimpleName(); private static final String FEDERATED_COMPUTATION_SERVICE_INTENT_FILTER_NAME = "android.federatedcompute.FederatedComputeService"; - private static final int BINDER_CONNECTION_TIMEOUT_MS = 5000; - - // A CountDownloadLatch which will be opened when the connection is established or any error - // occurs. - private CountDownLatch mConnectionCountDownLatch; - // Concurrency mLock. - private final Object mLock = new Object(); - - @GuardedBy("mLock") - private IFederatedComputeService mFcpService; - - @GuardedBy("mLock") - private ServiceConnection mServiceConnection; + private static final String FEDERATED_COMPUTATION_SERVICE_PACKAGE = + "com.android.federatedcompute.services"; + private static final String ALT_FEDERATED_COMPUTATION_SERVICE_PACKAGE = + "com.google.android.federatedcompute"; private final Context mContext; - FederatedComputeManager(Context context) { + private final AbstractServiceBinder<IFederatedComputeService> mServiceBinder; + + public FederatedComputeManager(Context context) { this.mContext = context; + this.mServiceBinder = + AbstractServiceBinder.getServiceBinderByIntent( + context, + FEDERATED_COMPUTATION_SERVICE_INTENT_FILTER_NAME, + List.of( + FEDERATED_COMPUTATION_SERVICE_PACKAGE, + ALT_FEDERATED_COMPUTATION_SERVICE_PACKAGE), + IFederatedComputeService.Stub::asInterface); } /** @@ -77,149 +76,89 @@ public final class FederatedComputeManager { * * @hide */ - public void scheduleFederatedCompute( + public void schedule( @NonNull ScheduleFederatedComputeRequest request, @NonNull @CallbackExecutor Executor executor, @NonNull OutcomeReceiver<Object, Exception> callback) { Objects.requireNonNull(request); - final IFederatedComputeService service = getService(executor); + final IFederatedComputeService service = mServiceBinder.getService(executor); try { IFederatedComputeCallback federatedComputeCallback = new IFederatedComputeCallback.Stub() { @Override public void onSuccess() { + LogUtil.d(TAG, ": schedule onSuccess() called"); executor.execute(() -> callback.onResult(null)); + unbindFromService(); } @Override public void onFailure(int errorCode) { + LogUtil.d( + TAG, + ": schedule onFailure() called with errorCode %d", + errorCode); executor.execute( () -> callback.onError( - new OnDevicePersonalizationException( - errorCode))); + new FederatedComputeException(errorCode))); + unbindFromService(); } }; - service.scheduleFederatedCompute( + service.schedule( mContext.getPackageName(), request.getTrainingOptions(), federatedComputeCallback); } catch (RemoteException e) { - Log.e(TAG, "Remote Exception", e); + LogUtil.e(TAG, e, "Remote Exception"); executor.execute(() -> callback.onError(e)); - } - } - - private IFederatedComputeService getService(@NonNull Executor executor) { - synchronized (mLock) { - if (mFcpService != null) { - return mFcpService; - } - if (mServiceConnection == null) { - Intent intent = new Intent(FEDERATED_COMPUTATION_SERVICE_INTENT_FILTER_NAME); - ComponentName serviceComponent = resolveService(intent); - if (serviceComponent == null) { - Log.e(TAG, "Invalid component for federatedcompute service"); - throw new IllegalStateException( - "Invalid component for federatedcompute service"); - } - intent.setComponent(serviceComponent); - // This latch will open when the connection is established or any error occurs. - mConnectionCountDownLatch = new CountDownLatch(1); - mServiceConnection = new FederatedComputeServiceConnection(); - boolean result = - mContext.bindService( - intent, Context.BIND_AUTO_CREATE, executor, mServiceConnection); - if (!result) { - mServiceConnection = null; - throw new IllegalStateException("Unable to bind to the service"); - } else { - Log.i(TAG, "bindService() succeeded..."); - } - } else { - Log.i(TAG, "bindService() already pending..."); - } - try { - mConnectionCountDownLatch.await(BINDER_CONNECTION_TIMEOUT_MS, MILLISECONDS); - } catch (InterruptedException e) { - throw new IllegalStateException("Thread interrupted"); // TODO Handle it better. - } - synchronized (mLock) { - if (mFcpService == null) { - throw new IllegalStateException("Failed to connect to the service"); - } - return mFcpService; - } + unbindFromService(); } } /** - * Find the ComponentName of the service, given its intent and package manager. + * Cancel FederatedCompute task. * - * @return ComponentName of the service. Null if the service is not found. + * @hide */ - @Nullable - private ComponentName resolveService(@NonNull Intent intent) { - List<ResolveInfo> services = mContext.getPackageManager().queryIntentServices(intent, 0); - if (services == null || services.isEmpty()) { - Log.e(TAG, "Failed to find federatedcompute service"); - return null; - } + public void cancel( + @NonNull String populationName, + @NonNull @CallbackExecutor Executor executor, + @NonNull OutcomeReceiver<Object, Exception> callback) { + Objects.requireNonNull(populationName); + final IFederatedComputeService service = mServiceBinder.getService(executor); + try { + IFederatedComputeCallback federatedComputeCallback = + new IFederatedComputeCallback.Stub() { + @Override + public void onSuccess() { + LogUtil.d(TAG, ": cancel onSuccess() called"); + executor.execute(() -> callback.onResult(null)); + unbindFromService(); + } - for (int i = 0; i < services.size(); i++) { - ServiceInfo serviceInfo = services.get(i).serviceInfo; - if (serviceInfo == null) { - Log.e(TAG, "Failed to find serviceInfo for federatedcompute service."); - return null; - } - // There should only be one matching service inside the given package. - // If there's more than one, return the first one found. - return new ComponentName(serviceInfo.packageName, serviceInfo.name); + @Override + public void onFailure(int errorCode) { + LogUtil.d( + TAG, + ": cancel onFailure() called with errorCode %d", + errorCode); + executor.execute( + () -> + callback.onError( + new FederatedComputeException(errorCode))); + unbindFromService(); + } + }; + service.cancel(mContext.getPackageName(), populationName, federatedComputeCallback); + } catch (RemoteException e) { + LogUtil.e(TAG, e, "Remote Exception"); + executor.execute(() -> callback.onError(e)); + unbindFromService(); } - Log.e(TAG, "Didn't find any matching federatedcompute service."); - return null; } public void unbindFromService() { - synchronized (mLock) { - if (mServiceConnection != null) { - Log.i(TAG, "unbinding..."); - mContext.unbindService(mServiceConnection); - } - mServiceConnection = null; - mFcpService = null; - } - } - - private class FederatedComputeServiceConnection implements ServiceConnection { - @Override - public void onServiceConnected(ComponentName name, IBinder service) { - Log.d(TAG, "onServiceConnected"); - synchronized (mLock) { - mFcpService = IFederatedComputeService.Stub.asInterface(service); - } - mConnectionCountDownLatch.countDown(); - } - - @Override - public void onServiceDisconnected(ComponentName name) { - Log.d(TAG, "onServiceDisconnected"); - unbindFromService(); - mConnectionCountDownLatch.countDown(); - } - - @Override - public void onBindingDied(ComponentName name) { - Log.e(TAG, "onBindingDied"); - unbindFromService(); - mConnectionCountDownLatch.countDown(); - } - - @Override - public void onNullBinding(ComponentName name) { - Log.e(TAG, "onNullBinding shouldn't happen."); - unbindFromService(); - mConnectionCountDownLatch.countDown(); - } + mServiceBinder.unbindFromService(); } } diff --git a/framework/java/android/federatedcompute/ResultHandlingService.java b/framework/java/android/federatedcompute/ResultHandlingService.java index 62f51e6d..8b79e2c3 100644 --- a/framework/java/android/federatedcompute/ResultHandlingService.java +++ b/framework/java/android/federatedcompute/ResultHandlingService.java @@ -18,17 +18,17 @@ package android.federatedcompute; import static android.federatedcompute.common.ClientConstants.STATUS_SUCCESS; +import android.annotation.NonNull; import android.app.Service; import android.content.Intent; import android.federatedcompute.aidl.IFederatedComputeCallback; import android.federatedcompute.aidl.IResultHandlingService; -import android.federatedcompute.common.ExampleConsumption; -import android.federatedcompute.common.TrainingOptions; +import android.os.Bundle; import android.os.IBinder; import android.os.RemoteException; -import android.util.Log; -import java.util.List; +import com.android.federatedcompute.internal.util.LogUtil; + import java.util.function.Consumer; /** @@ -66,16 +66,8 @@ public abstract class ResultHandlingService extends Service { private class ServiceBinder extends IResultHandlingService.Stub { @Override - public void handleResult( - TrainingOptions trainingOptions, - boolean success, - List<ExampleConsumption> exampleConsumptionList, - IFederatedComputeCallback callback) { - ResultHandlingService.this.handleResult( - trainingOptions, - success, - exampleConsumptionList, - new ResultHandlingCallback(callback)); + public void handleResult(Bundle params, IFederatedComputeCallback callback) { + ResultHandlingService.this.handleResult(params, new ResultHandlingCallback(callback)); } } @@ -101,7 +93,8 @@ public abstract class ResultHandlingService extends Service { } mInternalCallback.onFailure(status); } catch (RemoteException e) { - Log.w(TAG, "An error occurred when trying to communicate with FederatedCompute."); + LogUtil.w( + TAG, "An error occurred when trying to communicate with FederatedCompute."); } } } @@ -110,9 +103,5 @@ public abstract class ResultHandlingService extends Service { * The client app needs to implement this method to handle results. After handling the results, * the client app should signal FederatedCompute via the ResultHandlingCallback. */ - public abstract void handleResult( - TrainingOptions trainingOptions, - boolean success, - List<ExampleConsumption> exampleConsumptionList, - Consumer<Integer> callback); + public abstract void handleResult(@NonNull Bundle params, Consumer<Integer> callback); } diff --git a/framework/java/android/federatedcompute/aidl/IFederatedComputeService.aidl b/framework/java/android/federatedcompute/aidl/IFederatedComputeService.aidl index 212c6155..1c247d3d 100644 --- a/framework/java/android/federatedcompute/aidl/IFederatedComputeService.aidl +++ b/framework/java/android/federatedcompute/aidl/IFederatedComputeService.aidl @@ -21,8 +21,13 @@ import android.federatedcompute.aidl.IFederatedComputeCallback; /** @hide */ interface IFederatedComputeService { - void scheduleFederatedCompute( + void schedule( in String callingPackageName, in TrainingOptions trainingOptions, in IFederatedComputeCallback callback); + + void cancel( + in String callingPackageName, + in String populationName, + in IFederatedComputeCallback callback); }
\ No newline at end of file diff --git a/framework/java/android/federatedcompute/aidl/IResultHandlingService.aidl b/framework/java/android/federatedcompute/aidl/IResultHandlingService.aidl index 8dbefcb5..b07271a8 100644 --- a/framework/java/android/federatedcompute/aidl/IResultHandlingService.aidl +++ b/framework/java/android/federatedcompute/aidl/IResultHandlingService.aidl @@ -29,7 +29,5 @@ import android.federatedcompute.aidl.IFederatedComputeCallback; interface IResultHandlingService { /** The app will implement this method to handle results. */ - oneway void handleResult(in TrainingOptions trainingOptions, - boolean success, in List<ExampleConsumption> exampleConsumptionList, - in IFederatedComputeCallback callback); + oneway void handleResult(in Bundle params, in IFederatedComputeCallback callback); }
\ No newline at end of file diff --git a/framework/java/android/federatedcompute/common/ClientConstants.java b/framework/java/android/federatedcompute/common/ClientConstants.java index 87242e9e..29342eb4 100644 --- a/framework/java/android/federatedcompute/common/ClientConstants.java +++ b/framework/java/android/federatedcompute/common/ClientConstants.java @@ -25,8 +25,18 @@ public final class ClientConstants { // Status code constants. public static final int STATUS_SUCCESS = 0; public static final int STATUS_INTERNAL_ERROR = 1; + public static final int STATUS_TRAINING_FAILED = 2; + public static final String EXTRA_POPULATION_NAME = "android.federatedcompute.population_name"; - public static final String EXTRA_COLLECTION_NAME = "android.federatedcompute.collection_name"; + public static final String EXTRA_TASK_NAME = "android.federatedcompute.task_name"; + + public static final String EXTRA_CONTEXT_DATA = "android.federatedcompute.context_data"; + + public static final String EXTRA_COMPUTATION_RESULT = + "android.federatedcompute.computation_result"; + + public static final String EXTRA_EXAMPLE_CONSUMPTION_LIST = + "android.federatedcompute.example_consumption_list"; // ExampleStoreService related constants. public static final String EXAMPLE_STORE_ACTION = "android.federatedcompute.EXAMPLE_STORE"; diff --git a/framework/java/android/federatedcompute/common/ExampleConsumption.java b/framework/java/android/federatedcompute/common/ExampleConsumption.java index 3fa1ec0f..57ba1a6d 100644 --- a/framework/java/android/federatedcompute/common/ExampleConsumption.java +++ b/framework/java/android/federatedcompute/common/ExampleConsumption.java @@ -25,16 +25,16 @@ import com.android.ondevicepersonalization.internal.util.AnnotationValidations; import com.android.ondevicepersonalization.internal.util.DataClass; /** - * A container for information regarding an example store access, including the collection name, the + * A container for information regarding an example store access, including the task name, the * selection criteria and the number of examples which has been used. * * @hide */ @DataClass(genBuilder = true, genEqualsHashCode = true) public class ExampleConsumption implements Parcelable { - @NonNull private final String mCollectionName; + @NonNull private final String mTaskName; - @NonNull private final byte[] mSelectionCriteria; + @Nullable private final byte[] mSelectionCriteria; private final int mExampleCount; @@ -55,14 +55,13 @@ public class ExampleConsumption implements Parcelable { @DataClass.Generated.Member /* package-private */ ExampleConsumption( - @NonNull String collectionName, - @NonNull byte[] selectionCriteria, + @NonNull String taskName, + @Nullable byte[] selectionCriteria, int exampleCount, @Nullable byte[] resumptionToken) { - this.mCollectionName = collectionName; - AnnotationValidations.validate(NonNull.class, null, mCollectionName); + this.mTaskName = taskName; + AnnotationValidations.validate(NonNull.class, null, mTaskName); this.mSelectionCriteria = selectionCriteria; - AnnotationValidations.validate(NonNull.class, null, mSelectionCriteria); this.mExampleCount = exampleCount; this.mResumptionToken = resumptionToken; @@ -70,12 +69,12 @@ public class ExampleConsumption implements Parcelable { } @DataClass.Generated.Member - public @NonNull String getCollectionName() { - return mCollectionName; + public @NonNull String getTaskName() { + return mTaskName; } @DataClass.Generated.Member - public @NonNull byte[] getSelectionCriteria() { + public @Nullable byte[] getSelectionCriteria() { return mSelectionCriteria; } @@ -102,7 +101,7 @@ public class ExampleConsumption implements Parcelable { ExampleConsumption that = (ExampleConsumption) o; //noinspection PointlessBooleanExpression return true - && java.util.Objects.equals(mCollectionName, that.mCollectionName) + && java.util.Objects.equals(mTaskName, that.mTaskName) && java.util.Arrays.equals(mSelectionCriteria, that.mSelectionCriteria) && mExampleCount == that.mExampleCount && java.util.Arrays.equals(mResumptionToken, that.mResumptionToken); @@ -115,7 +114,7 @@ public class ExampleConsumption implements Parcelable { // int fieldNameHashCode() { ... } int _hash = 1; - _hash = 31 * _hash + java.util.Objects.hashCode(mCollectionName); + _hash = 31 * _hash + java.util.Objects.hashCode(mTaskName); _hash = 31 * _hash + java.util.Arrays.hashCode(mSelectionCriteria); _hash = 31 * _hash + mExampleCount; _hash = 31 * _hash + java.util.Arrays.hashCode(mResumptionToken); @@ -128,13 +127,10 @@ public class ExampleConsumption implements Parcelable { // You can override field parcelling by defining methods like: // void parcelFieldName(Parcel dest, int flags) { ... } - byte flg = 0; - if (mResumptionToken != null) flg |= 0x8; - dest.writeByte(flg); - dest.writeString(mCollectionName); + dest.writeString(mTaskName); dest.writeByteArray(mSelectionCriteria); dest.writeInt(mExampleCount); - if (mResumptionToken != null) dest.writeByteArray(mResumptionToken); + dest.writeByteArray(mResumptionToken); } @Override @@ -150,16 +146,14 @@ public class ExampleConsumption implements Parcelable { // You can override field unparcelling by defining methods like: // static FieldType unparcelFieldName(Parcel in) { ... } - byte flg = in.readByte(); - String collectionName = in.readString(); + String taskName = in.readString(); byte[] selectionCriteria = in.createByteArray(); int exampleCount = in.readInt(); - byte[] resumptionToken = (flg & 0x8) == 0 ? null : in.createByteArray(); + byte[] resumptionToken = in.createByteArray(); - this.mCollectionName = collectionName; - AnnotationValidations.validate(NonNull.class, null, mCollectionName); + this.mTaskName = taskName; + AnnotationValidations.validate(NonNull.class, null, mTaskName); this.mSelectionCriteria = selectionCriteria; - AnnotationValidations.validate(NonNull.class, null, mSelectionCriteria); this.mExampleCount = exampleCount; this.mResumptionToken = resumptionToken; @@ -185,8 +179,8 @@ public class ExampleConsumption implements Parcelable { @DataClass.Generated.Member public static class Builder { - private @NonNull String mCollectionName; - private @NonNull byte[] mSelectionCriteria; + private @NonNull String mTaskName; + private @Nullable byte[] mSelectionCriteria; private int mExampleCount; private @Nullable byte[] mResumptionToken; @@ -195,11 +189,11 @@ public class ExampleConsumption implements Parcelable { public Builder() {} @DataClass.Generated.Member - public @NonNull Builder setCollectionName(@NonNull String value) { + public @NonNull Builder setTaskName(@NonNull String value) { checkNotUsed(); Preconditions.checkStringNotEmpty(value); mBuilderFieldsSet |= 0x1; - mCollectionName = value; + mTaskName = value; return this; } @@ -234,7 +228,7 @@ public class ExampleConsumption implements Parcelable { ExampleConsumption o = new ExampleConsumption( - mCollectionName, mSelectionCriteria, mExampleCount, mResumptionToken); + mTaskName, mSelectionCriteria, mExampleCount, mResumptionToken); return o; } diff --git a/framework/java/android/federatedcompute/common/TrainingOptions.java b/framework/java/android/federatedcompute/common/TrainingOptions.java index 17d29eb6..3a6b8b24 100644 --- a/framework/java/android/federatedcompute/common/TrainingOptions.java +++ b/framework/java/android/federatedcompute/common/TrainingOptions.java @@ -31,14 +31,17 @@ import com.android.ondevicepersonalization.internal.util.DataClass; */ @DataClass(genBuilder = true, genEqualsHashCode = true) public final class TrainingOptions implements Parcelable { - /** The task name to be provided to the federated compute server during checkin. */ + /** + * The task name to be provided to the federated compute server during checkin. The field is + * required and should not be empty. + */ @NonNull private String mPopulationName = ""; /** * The remote federated compute server address that federated compute client need to checkin. - * It's required when you first time schedule the job. + * The field is required and should not be empty. */ - @Nullable private String mServerAddress = ""; + @NonNull private String mServerAddress = ""; @Nullable private TrainingInterval mTrainingInterval = null; @@ -46,7 +49,7 @@ public final class TrainingOptions implements Parcelable { * The context data that federatedcompute will pass back to client when bind to * ExampleStoreService and ResultHandlingService. */ - @Nullable private byte[] mContextData; + @Nullable private final byte[] mContextData; // Code below generated by codegen v1.0.23. // @@ -64,12 +67,13 @@ public final class TrainingOptions implements Parcelable { @DataClass.Generated.Member /* package-private */ TrainingOptions( @NonNull String populationName, - @Nullable String serverAddress, + @NonNull String serverAddress, @Nullable TrainingInterval trainingInterval, @Nullable byte[] contextData) { this.mPopulationName = populationName; AnnotationValidations.validate(NonNull.class, null, mPopulationName); this.mServerAddress = serverAddress; + AnnotationValidations.validate(NonNull.class, null, mServerAddress); this.mTrainingInterval = trainingInterval; this.mContextData = contextData; @@ -87,7 +91,7 @@ public final class TrainingOptions implements Parcelable { * It's required when you first time schedule the job. */ @DataClass.Generated.Member - public @Nullable String getServerAddress() { + public @NonNull String getServerAddress() { return mServerAddress; } @@ -145,11 +149,10 @@ public final class TrainingOptions implements Parcelable { // void parcelFieldName(Parcel dest, int flags) { ... } byte flg = 0; - if (mServerAddress != null) flg |= 0x2; if (mTrainingInterval != null) flg |= 0x4; dest.writeByte(flg); dest.writeString(mPopulationName); - if (mServerAddress != null) dest.writeString(mServerAddress); + dest.writeString(mServerAddress); if (mTrainingInterval != null) dest.writeTypedObject(mTrainingInterval, flags); dest.writeByteArray(mContextData); } @@ -169,7 +172,7 @@ public final class TrainingOptions implements Parcelable { byte flg = in.readByte(); String populationName = in.readString(); - String serverAddress = (flg & 0x2) == 0 ? null : in.readString(); + String serverAddress = in.readString(); TrainingInterval trainingInterval = (flg & 0x4) == 0 ? null @@ -179,6 +182,7 @@ public final class TrainingOptions implements Parcelable { this.mPopulationName = populationName; AnnotationValidations.validate(NonNull.class, null, mPopulationName); this.mServerAddress = serverAddress; + AnnotationValidations.validate(NonNull.class, null, mServerAddress); this.mTrainingInterval = trainingInterval; this.mContextData = contextData; @@ -205,7 +209,7 @@ public final class TrainingOptions implements Parcelable { public static final class Builder { private @NonNull String mPopulationName; - private @Nullable String mServerAddress; + private @NonNull String mServerAddress; private @Nullable TrainingInterval mTrainingInterval; private @Nullable byte[] mContextData; @@ -230,6 +234,7 @@ public final class TrainingOptions implements Parcelable { @DataClass.Generated.Member public @NonNull Builder setServerAddress(@NonNull String value) { checkNotUsed(); + Preconditions.checkStringNotEmpty(value); mBuilderFieldsSet |= 0x2; mServerAddress = value; return this; @@ -284,20 +289,12 @@ public final class TrainingOptions implements Parcelable { } @DataClass.Generated( - time = 1692048512258L, + time = 1696467816547L, codegenVersion = "1.0.23", sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/federatedcompute/common/TrainingOptions.java", inputSignatures = - "private @android.annotation.NonNull java.lang.String mPopulationName\n" - + "private @android.annotation.Nullable java.lang.String mServerAddress\n" - + "private @android.annotation.Nullable" - + " android.federatedcompute.common.TrainingInterval mTrainingInterval\n" - + "private @android.annotation.Nullable byte[] mContextData\n" - + "class TrainingOptions extends java.lang.Object implements" - + " [android.os.Parcelable]\n" - + "@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true," - + " genEqualsHashCode=true)") + "private final @android.annotation.NonNull java.lang.String mPopulationName\nprivate final @android.annotation.NonNull java.lang.String mServerAddress\nprivate final @android.annotation.Nullable android.federatedcompute.common.TrainingInterval mTrainingInterval\nprivate @android.annotation.Nullable byte[] mContextData\nclass TrainingOptions extends java.lang.Object implements [android.os.Parcelable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") @Deprecated private void __metadata() {} diff --git a/framework/java/android/ondevicepersonalization/IOnDevicePersonalizationSystemService.aidl b/framework/java/android/ondevicepersonalization/IOnDevicePersonalizationSystemService.aidl index 7ab14ed3..c6915279 100644 --- a/framework/java/android/ondevicepersonalization/IOnDevicePersonalizationSystemService.aidl +++ b/framework/java/android/ondevicepersonalization/IOnDevicePersonalizationSystemService.aidl @@ -25,4 +25,11 @@ interface IOnDevicePersonalizationSystemService { in Bundle params, in IOnDevicePersonalizationSystemServiceCallback callback ); + + void setPersonalizationStatus( + in boolean enabled, + in IOnDevicePersonalizationSystemServiceCallback callback + ); + + void readPersonalizationStatus(in IOnDevicePersonalizationSystemServiceCallback callback); } diff --git a/framework/java/android/ondevicepersonalization/IOnDevicePersonalizationSystemServiceCallback.aidl b/framework/java/android/ondevicepersonalization/IOnDevicePersonalizationSystemServiceCallback.aidl index 4c2f960f..a2c90005 100644 --- a/framework/java/android/ondevicepersonalization/IOnDevicePersonalizationSystemServiceCallback.aidl +++ b/framework/java/android/ondevicepersonalization/IOnDevicePersonalizationSystemServiceCallback.aidl @@ -21,4 +21,6 @@ import android.os.Bundle; /** @hide */ oneway interface IOnDevicePersonalizationSystemServiceCallback { void onResult(in Bundle result); + + void onError(int errorCode); } diff --git a/framework/java/android/ondevicepersonalization/OnDevicePersonalizationFrameworkInitializer.java b/framework/java/android/ondevicepersonalization/OnDevicePersonalizationFrameworkInitializer.java index d386b602..bd591e79 100644 --- a/framework/java/android/ondevicepersonalization/OnDevicePersonalizationFrameworkInitializer.java +++ b/framework/java/android/ondevicepersonalization/OnDevicePersonalizationFrameworkInitializer.java @@ -16,15 +16,17 @@ package android.ondevicepersonalization; +import static android.adservices.ondevicepersonalization.OnDevicePersonalizationConfigManager.ON_DEVICE_PERSONALIZATION_CONFIG_SERVICE; import static android.adservices.ondevicepersonalization.OnDevicePersonalizationManager.ON_DEVICE_PERSONALIZATION_SERVICE; -import static android.adservices.ondevicepersonalization.OnDevicePersonalizationPrivacyStatusManager.ON_DEVICE_PERSONALIZATION_PRIVACY_STATUS_SERVICE; +import static android.federatedcompute.FederatedComputeManager.FEDERATED_COMPUTE_SERVICE; import static android.ondevicepersonalization.OnDevicePersonalizationSystemServiceManager.ON_DEVICE_PERSONALIZATION_SYSTEM_SERVICE; +import android.adservices.ondevicepersonalization.OnDevicePersonalizationConfigManager; import android.adservices.ondevicepersonalization.OnDevicePersonalizationManager; -import android.adservices.ondevicepersonalization.OnDevicePersonalizationPrivacyStatusManager; import android.annotation.SystemApi; import android.app.SystemServiceRegistry; import android.content.Context; +import android.federatedcompute.FederatedComputeManager; import com.android.modules.utils.build.SdkLevel; @@ -51,9 +53,12 @@ public class OnDevicePersonalizationFrameworkInitializer { ON_DEVICE_PERSONALIZATION_SERVICE, OnDevicePersonalizationManager.class, (c) -> new OnDevicePersonalizationManager(c)); SystemServiceRegistry.registerContextAwareService( - ON_DEVICE_PERSONALIZATION_PRIVACY_STATUS_SERVICE, - OnDevicePersonalizationPrivacyStatusManager.class, - (c) -> new OnDevicePersonalizationPrivacyStatusManager(c)); + ON_DEVICE_PERSONALIZATION_CONFIG_SERVICE, + OnDevicePersonalizationConfigManager.class, + (c) -> new OnDevicePersonalizationConfigManager(c)); + SystemServiceRegistry.registerContextAwareService( + FEDERATED_COMPUTE_SERVICE, FederatedComputeManager.class, + (c) -> new FederatedComputeManager(c)); if (SdkLevel.isAtLeastU()) { SystemServiceRegistry.registerStaticService( diff --git a/framework/java/com/android/federatedcompute/internal/util/AbstractServiceBinder.java b/framework/java/com/android/federatedcompute/internal/util/AbstractServiceBinder.java new file mode 100644 index 00000000..9d01a15e --- /dev/null +++ b/framework/java/com/android/federatedcompute/internal/util/AbstractServiceBinder.java @@ -0,0 +1,81 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.android.federatedcompute.internal.util; + +import android.content.Context; +import android.os.IBinder; + +import java.util.List; +import java.util.concurrent.Executor; +import java.util.function.Function; + +/** + * Abstracts how to find and connect the service binder object. + * + * @param <T> The type of Service Binder. + * @hide + */ +public abstract class AbstractServiceBinder<T> { + /** Get the {@link AbstractServiceBinder} suitable for the configuration. */ + public static <T2> AbstractServiceBinder<T2> getServiceBinderByIntent( + Context context, + String serviceIntentAction, + String servicePackage, + Function<IBinder, T2> converter) { + return new AndroidServiceBinder<>( + context, serviceIntentAction, servicePackage, converter); + } + + /** Get the {@link AbstractServiceBinder} suitable for the configuration. */ + public static <T2> AbstractServiceBinder<T2> getServiceBinderByServiceName( + Context context, + String serviceName, + String servicePackage, + Function<IBinder, T2> converter) { + return new AndroidServiceBinder<>( + context, serviceName, servicePackage, true, converter); + } + + /** Get the {@link AbstractServiceBinder} suitable for the configuration. */ + public static <T2> AbstractServiceBinder<T2> getServiceBinderByIntent( + Context context, + String serviceIntentAction, + List<String> servicePackages, + Function<IBinder, T2> converter) { + return new AndroidServiceBinder<>( + context, serviceIntentAction, servicePackages, converter); + } + + /** Get the {@link AbstractServiceBinder} suitable for the configuration. */ + public static <T2> AbstractServiceBinder<T2> getServiceBinderByIntent( + Context context, + String serviceIntentAction, + List<String> servicePackages, + int bindFlags, + Function<IBinder, T2> converter) { + return new AndroidServiceBinder<>( + context, serviceIntentAction, servicePackages, bindFlags, converter); + } + + /** Get the binder service. */ + public abstract T getService(Executor executor); + + /** + * The service is in an APK (as opposed to the system service), unbind it from the service to + * allow the APK process to die. + */ + public abstract void unbindFromService(); +} diff --git a/framework/java/com/android/federatedcompute/internal/util/AndroidServiceBinder.java b/framework/java/com/android/federatedcompute/internal/util/AndroidServiceBinder.java new file mode 100644 index 00000000..1340de82 --- /dev/null +++ b/framework/java/com/android/federatedcompute/internal/util/AndroidServiceBinder.java @@ -0,0 +1,262 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.internal.util; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +import android.annotation.NonNull; +import android.annotation.Nullable; +import android.content.ComponentName; +import android.content.Context; +import android.content.Intent; +import android.content.ServiceConnection; +import android.content.pm.PackageManager; +import android.content.pm.ResolveInfo; +import android.os.IBinder; + +import com.android.internal.annotations.GuardedBy; + +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.function.Function; + +class AndroidServiceBinder<T> extends AbstractServiceBinder<T> { + private static final String TAG = AndroidServiceBinder.class.getSimpleName(); + + private static final int BINDER_CONNECTION_TIMEOUT_MS = 5000; + private final String mServiceIntentActionOrName; + private final List<String> mServicePackages; + private final Function<IBinder, T> mBinderConverter; + private final Context mContext; + private final boolean mEnableLookupByServiceName; + private final int mBindFlags; + // Concurrency mLock. + private final Object mLock = new Object(); + // A CountDownloadLatch which will be opened when the connection is established or any error + // occurs. + private CountDownLatch mConnectionCountDownLatch; + + @GuardedBy("mLock") + private T mService; + + @GuardedBy("mLock") + private ServiceConnection mServiceConnection; + + AndroidServiceBinder( + @NonNull Context context, + @NonNull String serviceIntentAction, + @NonNull String servicePackage, + @NonNull Function<IBinder, T> converter) { + this(context, serviceIntentAction, List.of(servicePackage), converter); + } + + AndroidServiceBinder( + @NonNull Context context, + @NonNull String serviceIntentAction, + @NonNull List<String> servicePackages, + @NonNull Function<IBinder, T> converter) { + this(context, serviceIntentAction, servicePackages, 0, converter); + } + + AndroidServiceBinder( + @NonNull Context context, + @NonNull String serviceIntentAction, + @NonNull List<String> servicePackages, + int bindFlags, + @NonNull Function<IBinder, T> converter) { + this.mServiceIntentActionOrName = serviceIntentAction; + this.mContext = context; + this.mBinderConverter = converter; + this.mServicePackages = servicePackages; + this.mEnableLookupByServiceName = false; + this.mBindFlags = bindFlags; + } + + AndroidServiceBinder( + @NonNull Context context, + @NonNull String serviceIntentActionOrName, + @NonNull String servicePackage, + boolean enableLookupByName, + @NonNull Function<IBinder, T> converter) { + this.mServiceIntentActionOrName = serviceIntentActionOrName; + this.mContext = context; + this.mBinderConverter = converter; + this.mServicePackages = List.of(servicePackage); + this.mEnableLookupByServiceName = enableLookupByName; + this.mBindFlags = 0; + } + + @Override + public T getService(@NonNull Executor executor) { + synchronized (mLock) { + if (mService != null) { + return mService; + } + if (mServiceConnection == null) { + Intent bindIntent = + mEnableLookupByServiceName + ? getIntentBasedOnServiceName() + : getIntentBasedOnAction(); + // This latch will open when the connection is established or any error occurs. + mConnectionCountDownLatch = new CountDownLatch(1); + mServiceConnection = new GenericServiceConnection(); + boolean result = + mContext.bindService( + bindIntent, + Context.BIND_AUTO_CREATE | mBindFlags, + executor, + mServiceConnection); + if (!result) { + mServiceConnection = null; + throw new IllegalStateException( + String.format( + "Unable to bind to the service %s", + mServiceIntentActionOrName)); + } else { + LogUtil.i(TAG, "bindService() %s succeeded...", mServiceIntentActionOrName); + } + } else { + LogUtil.i(TAG, "bindService() %s already pending...", mServiceIntentActionOrName); + } + } + // release the lock to let connection to set the mFcpService + try { + mConnectionCountDownLatch.await(BINDER_CONNECTION_TIMEOUT_MS, MILLISECONDS); + } catch (InterruptedException e) { + throw new IllegalStateException("Thread interrupted"); // TODO Handle it better. + } + synchronized (mLock) { + if (mService == null) { + throw new IllegalStateException( + String.format( + "Failed to connect to the service %s", mServiceIntentActionOrName)); + } + return mService; + } + } + + private Intent getIntentBasedOnServiceName() { + Intent intent = new Intent(); + ComponentName serviceComponent = + new ComponentName(mServicePackages.get(0), mServiceIntentActionOrName); + intent.setComponent(serviceComponent); + return intent; + } + + private Intent getIntentBasedOnAction() { + Intent intent = new Intent(mServiceIntentActionOrName); + ComponentName serviceComponent = resolveComponentName(intent); + if (serviceComponent == null) { + LogUtil.e(TAG, "Invalid component for %s intent", mServiceIntentActionOrName); + throw new IllegalStateException( + String.format("Invalid component for %s service", mServiceIntentActionOrName)); + } + intent.setComponent(serviceComponent); + return intent; + } + + /** + * Find the ComponentName of the service, given its intent and package manager. + * + * @return ComponentName of the service. Null if the service is not found. + */ + @Nullable + private ComponentName resolveComponentName(@NonNull Intent intent) { + List<ResolveInfo> services = + mContext.getPackageManager() + .queryIntentServices(intent, PackageManager.MATCH_SYSTEM_ONLY); + if (services == null || services.isEmpty()) { + LogUtil.e(TAG, "Failed to find service %s!", intent.getAction()); + return null; + } else if (services.size() != 1) { + LogUtil.i(TAG, "Found more than 1 (%d) service by intent %s!", services.size(), intent); + } + + for (ResolveInfo ri : services) { + // Check that found service has expected package. + if (ri != null + && ri.serviceInfo != null + && ri.serviceInfo.packageName != null + && mServicePackages.contains(ri.serviceInfo.packageName)) { + // There should only be one matching service inside the given package. + // If there's more than one, return the first one found. + LogUtil.d( + TAG, + "Resolved component with pkg %s, class %s", + ri.serviceInfo.packageName, + ri.serviceInfo.name); + return new ComponentName(ri.serviceInfo.packageName, ri.serviceInfo.name); + } else { + if (ri != null && ri.serviceInfo != null) { + LogUtil.d( + TAG, + "Resolved component with pkg %s, class %s", + ri.serviceInfo.packageName, + ri.serviceInfo.name); + } else { + LogUtil.d(TAG, "Resolved component is null or service info is null"); + } + } + } + LogUtil.e(TAG, "Didn't find any matching service %s.", intent.getAction()); + return null; + } + + public void unbindFromService() { + synchronized (mLock) { + if (mServiceConnection != null) { + LogUtil.d(TAG, "unbinding %s...", mServiceIntentActionOrName); + mContext.unbindService(mServiceConnection); + } + mServiceConnection = null; + mService = null; + } + } + + private class GenericServiceConnection implements ServiceConnection { + @Override + public void onServiceConnected(ComponentName name, IBinder service) { + LogUtil.d(TAG, "onServiceConnected " + mServiceIntentActionOrName); + synchronized (mLock) { + mService = mBinderConverter.apply(service); + } + mConnectionCountDownLatch.countDown(); + } + + @Override + public void onServiceDisconnected(ComponentName name) { + LogUtil.d(TAG, "onServiceDisconnected " + mServiceIntentActionOrName); + unbindFromService(); + mConnectionCountDownLatch.countDown(); + } + + @Override + public void onBindingDied(ComponentName name) { + LogUtil.e(TAG, "onBindingDied " + mServiceIntentActionOrName); + unbindFromService(); + mConnectionCountDownLatch.countDown(); + } + + @Override + public void onNullBinding(ComponentName name) { + LogUtil.e(TAG, "onNullBinding shouldn't happen. " + mServiceIntentActionOrName); + unbindFromService(); + mConnectionCountDownLatch.countDown(); + } + } +} diff --git a/framework/java/com/android/federatedcompute/internal/util/LogUtil.java b/framework/java/com/android/federatedcompute/internal/util/LogUtil.java new file mode 100644 index 00000000..4c1db654 --- /dev/null +++ b/framework/java/com/android/federatedcompute/internal/util/LogUtil.java @@ -0,0 +1,148 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.internal.util; + +import android.util.Log; + +import java.util.Locale; + +/** + * Logger for logging to logcat with the various logcat tags. + * + * @hide + */ +public final class LogUtil { + /* Unified TAG for logging across Federated Compute Program app. */ + public static final String TAG = "federatedcompute"; + + /* This is a static wrapper over standard Android logger. + * No instantiation assumed. */ + private LogUtil() {} + + /** Log the message as VERBOSE. Return The number of bytes written. */ + public static int v(String msg) { + if (Log.isLoggable(TAG, Log.VERBOSE)) { + return Log.v(TAG, msg); + } + return 0; + } + + /** Log the message as VERBOSE. Return The number of bytes written. */ + public static int v(String tag, String msg) { + if (Log.isLoggable(TAG, Log.VERBOSE)) { + return Log.v(TAG, tag + " - " + msg); + } + return 0; + } + + /** Log the message as DEBUG. Return The number of bytes written. */ + public static int d(String tag, String msg) { + if (Log.isLoggable(TAG, Log.DEBUG)) { + return Log.d(TAG, tag + " - " + msg); + } + return 0; + } + + /** Log the message as DEBUG. Return The number of bytes written. */ + public static int d(String tag, Throwable throwable, String msg) { + if (Log.isLoggable(TAG, Log.DEBUG)) { + return Log.d(TAG, tag + " - " + msg, throwable); + } + return 0; + } + + /** Log the message as DEBUG. Return The number of bytes written. */ + public static int d(String tag, String format, Object... params) { + if (Log.isLoggable(TAG, Log.DEBUG)) { + String msg = format(format, params); + return Log.d(TAG, tag + " - " + msg); + } + return 0; + } + + /** Log the message as INFO. Return The number of bytes written. */ + public static int i(String tag, String format, Object... params) { + if (Log.isLoggable(TAG, Log.INFO)) { + String msg = format(format, params); + return Log.i(TAG, tag + " - " + msg); + } + return 0; + } + + /** Log the message as WARN. Return The number of bytes written. */ + public static int w(String tag, String msg) { + if (Log.isLoggable(TAG, Log.WARN)) { + return Log.w(TAG, tag + " - " + msg); + } + return 0; + } + + /** Log the message as WARN. Return The number of bytes written. */ + public static int w(String tag, Throwable throwable, String msg) { + if (Log.isLoggable(TAG, Log.WARN)) { + return Log.w(TAG, tag + " - " + msg, throwable); + } + return 0; + } + + /** Log the message as WARN. Return The number of bytes written. */ + public static int w(String tag, String format, Object... params) { + if (Log.isLoggable(TAG, Log.WARN)) { + String msg = format(format, params); + return Log.w(TAG, tag + " - " + msg); + } + return 0; + } + + /** Log the message as ERROR. Return The number of bytes written. */ + public static int e(String tag, String msg) { + if (Log.isLoggable(TAG, Log.ERROR)) { + return Log.e(TAG, tag + " - " + msg); + } + return 0; + } + + /** Log the message as ERROR. Return The number of bytes written. */ + public static int e(String tag, Throwable throwable, String msg) { + if (Log.isLoggable(TAG, Log.ERROR)) { + return Log.e(TAG, tag + " - " + msg, throwable); + } + return 0; + } + + /** Log the message as ERROR. Return The number of bytes written. */ + public static int e(String tag, String format, Object... params) { + if (Log.isLoggable(TAG, Log.ERROR)) { + String msg = format(format, params); + return Log.e(TAG, tag + " - " + msg); + } + return 0; + } + + /** Log the message as ERROR. Return The number of bytes written. */ + public static int e(String tag, Throwable throwable, String format, Object... params) { + if (Log.isLoggable(TAG, Log.ERROR)) { + String msg = format(format, params); + return Log.e(TAG, tag + " - " + msg, throwable); + } + return 0; + } + + private static String format(String format, Object... args) { + return String.format(Locale.US, format, args); + } +} diff --git a/framework/java/com/android/ondevicepersonalization/internal/util/LoggerFactory.java b/framework/java/com/android/ondevicepersonalization/internal/util/LoggerFactory.java index 8dba2682..d900796a 100644 --- a/framework/java/com/android/ondevicepersonalization/internal/util/LoggerFactory.java +++ b/framework/java/com/android/ondevicepersonalization/internal/util/LoggerFactory.java @@ -35,7 +35,6 @@ public class LoggerFactory { return sLogger; } - /** * Logger for logging to logcat with the various logcat tags. * diff --git a/framework/java/com/android/ondevicepersonalization/internal/util/OdpParceledListSlice.java b/framework/java/com/android/ondevicepersonalization/internal/util/OdpParceledListSlice.java new file mode 100644 index 00000000..14a89e0d --- /dev/null +++ b/framework/java/com/android/ondevicepersonalization/internal/util/OdpParceledListSlice.java @@ -0,0 +1,93 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.ondevicepersonalization.internal.util; + +import android.annotation.NonNull; +import android.os.Parcel; +import android.os.Parcelable; + +import java.util.Collections; +import java.util.List; + +/** + * Transfer a large list of parcelable objects across an IPC. Splits into + * multiple transactions if needed. + * + * @param <T> Parcelable type for the List + * @hide + */ +public final class OdpParceledListSlice<T extends Parcelable> extends BaseOdpParceledListSlice<T> { + @NonNull + @SuppressWarnings("unchecked") + public static final Parcelable.ClassLoaderCreator<OdpParceledListSlice> CREATOR = + new Parcelable.ClassLoaderCreator<OdpParceledListSlice>() { + public OdpParceledListSlice createFromParcel(Parcel in) { + return new OdpParceledListSlice(in, null); + } + + @Override + public OdpParceledListSlice createFromParcel(Parcel in, ClassLoader loader) { + return new OdpParceledListSlice(in, loader); + } + + @Override + public OdpParceledListSlice[] newArray(int size) { + return new OdpParceledListSlice[size]; + } + }; + + public OdpParceledListSlice(@NonNull List<T> list) { + super(list); + } + + private OdpParceledListSlice(Parcel in, ClassLoader loader) { + super(in, loader); + } + + /** + * Returns an empty OdpParceledListSlice. + */ + @NonNull + public static <T extends Parcelable> OdpParceledListSlice<T> emptyList() { + return new OdpParceledListSlice<T>(Collections.<T>emptyList()); + } + + @Override + public int describeContents() { + int contents = 0; + final List<T> list = getList(); + for (int i = 0; i < list.size(); i++) { + contents |= list.get(i).describeContents(); + } + return contents; + } + + @Override + protected void writeElement(T parcelable, Parcel dest, int callFlags) { + parcelable.writeToParcel(dest, callFlags); + } + + @Override + protected void writeParcelableCreator(T parcelable, Parcel dest) { + dest.writeParcelableCreator((Parcelable) parcelable); + } + + @Override + protected Parcelable.Creator<?> readParcelableCreator(Parcel from, ClassLoader loader) { + return from.readParcelableCreator(loader); + } +} diff --git a/src/com/android/ondevicepersonalization/services/Flags.java b/src/com/android/ondevicepersonalization/services/Flags.java index dbd7c897..86615ac3 100644 --- a/src/com/android/ondevicepersonalization/services/Flags.java +++ b/src/com/android/ondevicepersonalization/services/Flags.java @@ -23,21 +23,55 @@ package com.android.ondevicepersonalization.services; * generated from the GCL. */ public interface Flags { + /** + * Global OnDevicePersonalization Kill Switch. This overrides all other killswitches. + * The default value is true which means OnDevicePersonalization is disabled. + * This flag is used for ramp-up and emergency turning off the whole module. + */ + boolean GLOBAL_KILL_SWITCH = true; - boolean ONDEVICEPERSONALIZATION_ENABLED = false; + /** + * P/H flag to enable all APIs under OnDevicePersonalization (ODP). + * The default value is false, which means all APIs are disabled. + * This flag is used for ramp-up and emergency turning off ODP API. + */ + boolean ENABLE_ONDEVICEPERSONALIZATION_APIS = false; - default boolean getOnDevicePersonalizationEnabled() { - return ONDEVICEPERSONALIZATION_ENABLED; - } + /** + * P/H flag to override the personalization status for end-to-end tests. + * The default value is false, which means UserPrivacyStatus#personalizationStatus is not + * override by PERSONALIZATION_STATUS_OVERRIDE_VALUE. If true, returns the personalization + * status in PERSONALIZATION_STATUS_OVERRIDE_VALUE. + */ + boolean ENABLE_PERSONALIZATION_STATUS_OVERRIDE = false; /** - * Global OnDevicePersonalization Kill Switch. This overrides all other killswitches. - * The default value is false which means OnDevicePersonalization is enabled. - * This flag is used for emergency turning off the whole module. + * Value of the personalization status, if ENABLE_PERSONALIZATION_STATUS_OVERRIDE is true. */ - boolean GLOBAL_KILL_SWITCH = true; + boolean PERSONALIZATION_STATUS_OVERRIDE_VALUE = false; + + /** + * Deadline for calls from ODP to isolated services. + */ + int ISOLATED_SERVICE_DEADLINE_SECONDS = 30; default boolean getGlobalKillSwitch() { return GLOBAL_KILL_SWITCH; } + + default boolean isOnDevicePersonalizationApisEnabled() { + return ENABLE_ONDEVICEPERSONALIZATION_APIS; + } + + default boolean isPersonalizationStatusOverrideEnabled() { + return ENABLE_PERSONALIZATION_STATUS_OVERRIDE; + } + + default boolean getPersonalizationStatusOverrideValue() { + return PERSONALIZATION_STATUS_OVERRIDE_VALUE; + } + + default int getIsolatedServiceDeadlineSeconds() { + return ISOLATED_SERVICE_DEADLINE_SECONDS; + } } diff --git a/src/com/android/ondevicepersonalization/services/OdpServiceException.java b/src/com/android/ondevicepersonalization/services/OdpServiceException.java new file mode 100644 index 00000000..d46a7867 --- /dev/null +++ b/src/com/android/ondevicepersonalization/services/OdpServiceException.java @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.ondevicepersonalization.services; + +import android.annotation.NonNull; + +/** + * Exception thrown inside the ODP service. + * + * @hide + */ +public class OdpServiceException extends Exception { + private final int mErrorCode; + + public OdpServiceException(int errorCode) { + this(errorCode, ""); + } + + public OdpServiceException(int errorCode, @NonNull String errorMessage) { + super("Error code: " + errorCode + " message: " + errorMessage); + mErrorCode = errorCode; + } + + public OdpServiceException(int errorCode, @NonNull Throwable cause) { + this(errorCode, "", cause); + } + + public OdpServiceException( + int errorCode, @NonNull String errorMessage, @NonNull Throwable cause) { + super("Error code: " + errorCode + " message: " + errorMessage, cause); + mErrorCode = errorCode; + } + + /** Returns the error code for this exception. */ + public int getErrorCode() { + return mErrorCode; + } +} diff --git a/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationApplication.java b/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationApplication.java index 4e9a6682..52b20ddd 100644 --- a/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationApplication.java +++ b/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationApplication.java @@ -16,8 +16,10 @@ package com.android.ondevicepersonalization.services; +import android.annotation.NonNull; import android.annotation.Nullable; import android.app.Application; +import android.content.Context; import com.android.ondevicepersonalization.libraries.plugin.PluginApplication; import com.android.ondevicepersonalization.libraries.plugin.PluginHost; @@ -26,7 +28,18 @@ import com.android.ondevicepersonalization.services.process.OnDevicePersonalizat /** The Application class for OnDevicePersonalization. */ public final class OnDevicePersonalizationApplication extends Application implements PluginApplication { + @NonNull private static Context sApplicationContext; + @Override public @Nullable PluginHost getPluginHost() { return new OnDevicePersonalizationPluginHost(this); } + + @Override public void onCreate() { + super.onCreate(); + sApplicationContext = getApplicationContext(); + } + + @NonNull public static Context getAppContext() { + return sApplicationContext; + } } diff --git a/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationBroadcastReceiver.java b/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationBroadcastReceiver.java index 7dd39dfd..128d6868 100644 --- a/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationBroadcastReceiver.java +++ b/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationBroadcastReceiver.java @@ -23,13 +23,18 @@ import android.content.ComponentName; import android.content.Context; import android.content.Intent; import android.content.pm.PackageManager; - +import android.ondevicepersonalization.IOnDevicePersonalizationSystemService; +import android.ondevicepersonalization.IOnDevicePersonalizationSystemServiceCallback; +import android.ondevicepersonalization.OnDevicePersonalizationSystemServiceManager; +import android.os.Bundle; +import android.os.RemoteException; import com.android.internal.annotations.VisibleForTesting; +import com.android.modules.utils.build.SdkLevel; import com.android.ondevicepersonalization.internal.util.LoggerFactory; import com.android.ondevicepersonalization.services.data.user.UserDataCollectionJobService; +import com.android.ondevicepersonalization.services.data.user.UserPrivacyStatus; import com.android.ondevicepersonalization.services.download.mdd.MobileDataDownloadFactory; -import com.android.ondevicepersonalization.services.federatedcompute.OdpFederatedComputeJobService; import com.android.ondevicepersonalization.services.maintenance.OnDevicePersonalizationMaintenanceJobService; import com.android.ondevicepersonalization.services.policyengine.api.ChronicleManager; import com.android.ondevicepersonalization.services.policyengine.data.impl.UserDataConnectionProvider; @@ -42,12 +47,12 @@ import java.util.Arrays; import java.util.HashSet; import java.util.concurrent.Executor; -/** - * BroadcastReceiver used to schedule OnDevicePersonalization jobs/workers. - */ +/** BroadcastReceiver used to schedule OnDevicePersonalization jobs/workers. */ public class OnDevicePersonalizationBroadcastReceiver extends BroadcastReceiver { private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); private static final String TAG = "OnDevicePersonalizationBroadcastReceiver"; + private static final String PERSONALIZATION_STATUS_KEY = "PERSONALIZATION_STATUS"; + private static final int KEY_NOT_FOUND_ERROR = 404; private final Executor mExecutor; public OnDevicePersonalizationBroadcastReceiver() { @@ -64,8 +69,8 @@ public class OnDevicePersonalizationBroadcastReceiver extends BroadcastReceiver try { context.getPackageManager() .setComponentEnabledSetting( - new ComponentName(context, - OnDevicePersonalizationBroadcastReceiver.class), + new ComponentName( + context, OnDevicePersonalizationBroadcastReceiver.class), COMPONENT_ENABLED_STATE_ENABLED, PackageManager.DONT_KILL_APP); } catch (IllegalArgumentException e) { @@ -75,15 +80,53 @@ public class OnDevicePersonalizationBroadcastReceiver extends BroadcastReceiver return true; } - /** - * Called when the broadcast is received. OnDevicePersonalization jobs will be started here. - */ + /** Called when the broadcast is received. OnDevicePersonalization jobs will be started here. */ public void onReceive(Context context, Intent intent) { sLogger.d(TAG + ": onReceive() with intent + " + intent.getAction()); if (!Intent.ACTION_BOOT_COMPLETED.equals(intent.getAction())) { sLogger.d(TAG + ": Received unexpected intent " + intent.getAction()); return; } + // Restore personalization status from the system server on U+ devices. + if (SdkLevel.isAtLeastU()) { + OnDevicePersonalizationSystemServiceManager systemServiceManager = + context.getSystemService(OnDevicePersonalizationSystemServiceManager.class); + if (systemServiceManager != null) { + IOnDevicePersonalizationSystemService systemService = + systemServiceManager.getService(); + if (systemService != null) { + try { + systemService.readPersonalizationStatus( + new IOnDevicePersonalizationSystemServiceCallback.Stub() { + @Override + public void onResult(Bundle bundle) { + boolean personalizationStatus = + bundle.getBoolean(PERSONALIZATION_STATUS_KEY); + UserPrivacyStatus.getInstance() + .setPersonalizationStatusEnabled( + personalizationStatus); + } + + @Override + public void onError(int errorCode) { + if (errorCode == KEY_NOT_FOUND_ERROR) { + sLogger.d( + TAG + + ": Personalization status " + + "not found in the system server"); + } + } + }); + } catch (RemoteException e) { + sLogger.e(TAG + ": Callback error."); + } + } else { + sLogger.w(TAG + ": System service is not ready."); + } + } else { + sLogger.w(TAG + ": Cannot find system server on U+ devices."); + } + } // Initialize policy engine instance ChronicleManager.getInstance( @@ -93,9 +136,6 @@ public class OnDevicePersonalizationBroadcastReceiver extends BroadcastReceiver // Schedule maintenance task OnDevicePersonalizationMaintenanceJobService.schedule(context); - // Schedule federatedCompute task - OdpFederatedComputeJobService.schedule(context); - // Schedule user data collection task UserDataCollectionJobService.schedule(context); diff --git a/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationConfig.java b/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationConfig.java index 25d08192..f0fcb14e 100644 --- a/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationConfig.java +++ b/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationConfig.java @@ -18,51 +18,45 @@ package com.android.ondevicepersonalization.services; import com.android.ondevicepersonalization.services.data.user.UserDataCollectionJobService; import com.android.ondevicepersonalization.services.download.OnDevicePersonalizationDownloadProcessingJobService; -import com.android.ondevicepersonalization.services.federatedcompute.OdpFederatedComputeJobService; import com.android.ondevicepersonalization.services.maintenance.OnDevicePersonalizationMaintenanceJobService; -/** - * Hard-coded configs for OnDevicePersonalization - */ +/** Hard-coded configs for OnDevicePersonalization */ public class OnDevicePersonalizationConfig { - private OnDevicePersonalizationConfig() { - } + private OnDevicePersonalizationConfig() {} - /** Job ID for Mdd Maintenance Task - * ({@link com.android.ondevicepersonalization.services.download.mdd.MddJobService}) */ + /** + * Job ID for Mdd Maintenance Task ({@link + * com.android.ondevicepersonalization.services.download.mdd.MddJobService}) + */ public static final int MDD_MAINTENANCE_PERIODIC_TASK_JOB_ID = 1000; /** - * Job ID for Mdd Charging Periodic Task - * ({@link com.android.ondevicepersonalization.services.download.mdd.MddJobService}) + * Job ID for Mdd Charging Periodic Task ({@link + * com.android.ondevicepersonalization.services.download.mdd.MddJobService}) */ public static final int MDD_CHARGING_PERIODIC_TASK_JOB_ID = 1001; /** - * Job ID for Mdd Cellular Charging Task - * ({@link com.android.ondevicepersonalization.services.download.mdd.MddJobService}) + * Job ID for Mdd Cellular Charging Task ({@link + * com.android.ondevicepersonalization.services.download.mdd.MddJobService}) */ public static final int MDD_CELLULAR_CHARGING_PERIODIC_TASK_JOB_ID = 1002; - /** Job ID for Mdd Wifi Charging Task - * ({@link com.android.ondevicepersonalization.services.download.mdd.MddJobService}) */ + /** + * Job ID for Mdd Wifi Charging Task ({@link + * com.android.ondevicepersonalization.services.download.mdd.MddJobService}) + */ public static final int MDD_WIFI_CHARGING_PERIODIC_TASK_JOB_ID = 1003; - /** Job ID for Download Processing Task - * ({@link OnDevicePersonalizationDownloadProcessingJobService}) */ + /** + * Job ID for Download Processing Task ({@link + * OnDevicePersonalizationDownloadProcessingJobService}) + */ public static final int DOWNLOAD_PROCESSING_TASK_JOB_ID = 1004; - /** Job ID for Maintenance Task - * ({@link OnDevicePersonalizationMaintenanceJobService}) */ + /** Job ID for Maintenance Task ({@link OnDevicePersonalizationMaintenanceJobService}) */ public static final int MAINTENANCE_TASK_JOB_ID = 1005; - /** Job ID for User Data Collection Task - * ({@link UserDataCollectionJobService}) */ + /** Job ID for User Data Collection Task ({@link UserDataCollectionJobService}) */ public static final int USER_DATA_COLLECTION_ID = 1006; - - /** Job ID for Maintenance Task - * ({@link OdpFederatedComputeJobService}) */ - public static final int FEDERATED_COMPUTE_TASK_JOB_ID = 1007; - - public static final String ODP_POPULATION_NAME = "odp"; } diff --git a/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationConfigServiceDelegate.java b/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationConfigServiceDelegate.java new file mode 100644 index 00000000..3f57e94c --- /dev/null +++ b/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationConfigServiceDelegate.java @@ -0,0 +1,142 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.ondevicepersonalization.services; + +import static android.adservices.ondevicepersonalization.OnDevicePersonalizationPermissions.MODIFY_ONDEVICEPERSONALIZATION_STATE; + +import android.adservices.ondevicepersonalization.OnDevicePersonalizationPermissions; +import android.adservices.ondevicepersonalization.aidl.IOnDevicePersonalizationConfigService; +import android.adservices.ondevicepersonalization.aidl.IOnDevicePersonalizationConfigServiceCallback; +import android.annotation.NonNull; +import android.annotation.RequiresPermission; +import android.content.Context; +import android.ondevicepersonalization.IOnDevicePersonalizationSystemService; +import android.ondevicepersonalization.IOnDevicePersonalizationSystemServiceCallback; +import android.ondevicepersonalization.OnDevicePersonalizationSystemServiceManager; +import android.os.Binder; +import android.os.Bundle; +import android.os.RemoteException; + +import com.android.modules.utils.build.SdkLevel; +import com.android.ondevicepersonalization.internal.util.LoggerFactory; +import com.android.ondevicepersonalization.services.data.user.RawUserData; +import com.android.ondevicepersonalization.services.data.user.UserDataCollector; +import com.android.ondevicepersonalization.services.data.user.UserPrivacyStatus; + +import java.util.Objects; +import java.util.concurrent.Executor; + +/** + * ODP service that modifies and persists ODP enablement status + */ +public class OnDevicePersonalizationConfigServiceDelegate + extends IOnDevicePersonalizationConfigService.Stub { + private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); + private static final String TAG = "OnDevicePersonalizationConfigServiceDelegate"; + private final Context mContext; + private static final Executor sBackgroundExecutor = + OnDevicePersonalizationExecutors.getBackgroundExecutor(); + private static final int SERVICE_NOT_IMPLEMENTED = 501; + + public OnDevicePersonalizationConfigServiceDelegate(Context context) { + mContext = context; + } + + @Override + @RequiresPermission(MODIFY_ONDEVICEPERSONALIZATION_STATE) + public void setPersonalizationStatus(boolean enabled, + @NonNull IOnDevicePersonalizationConfigServiceCallback + callback) { + if (!isOnDevicePersonalizationApisEnabled()) { + throw new IllegalStateException("Service skipped as the API flag is turned off."); + } + // Verify caller's permission + OnDevicePersonalizationPermissions.enforceCallingPermission(mContext, + MODIFY_ONDEVICEPERSONALIZATION_STATE); + Objects.requireNonNull(callback); + sBackgroundExecutor.execute( + () -> { + try { + UserPrivacyStatus userPrivacyStatus = UserPrivacyStatus.getInstance(); + + boolean oldStatus = userPrivacyStatus.isPersonalizationStatusEnabled(); + userPrivacyStatus.setPersonalizationStatusEnabled(enabled); + boolean newStatus = userPrivacyStatus.isPersonalizationStatusEnabled(); + + if (oldStatus == newStatus) { + callback.onSuccess(); + return; + } + + // Rollback all user data if personalization status changes + RawUserData userData = RawUserData.getInstance(); + UserDataCollector userDataCollector = + UserDataCollector.getInstance(mContext); + userDataCollector.clearUserData(userData); + userDataCollector.clearMetadata(); + userDataCollector.clearDatabase(); + + // TODO(b/302018665): replicate system server storage to T devices. + if (!SdkLevel.isAtLeastU()) { + userPrivacyStatus.setPersonalizationStatusEnabled(enabled); + callback.onSuccess(); + return; + } + // Persist in the system server for U+ devices + OnDevicePersonalizationSystemServiceManager systemServiceManager = + mContext.getSystemService( + OnDevicePersonalizationSystemServiceManager.class); + // Cannot find system server on U+. + if (systemServiceManager == null) { + callback.onFailure(SERVICE_NOT_IMPLEMENTED); + return; + } + IOnDevicePersonalizationSystemService systemService = + systemServiceManager.getService(); + // The system service is not ready. + if (systemService == null) { + callback.onFailure(SERVICE_NOT_IMPLEMENTED); + return; + } + systemService.setPersonalizationStatus(enabled, + new IOnDevicePersonalizationSystemServiceCallback.Stub() { + @Override + public void onResult(Bundle bundle) throws RemoteException { + userPrivacyStatus.setPersonalizationStatusEnabled(enabled); + callback.onSuccess(); + } + + @Override + public void onError(int errorCode) throws RemoteException { + callback.onFailure(errorCode); + } + }); + } catch (RemoteException re) { + sLogger.e(TAG + ": Unable to send result to the callback.", re); + } + } + ); + } + + private boolean isOnDevicePersonalizationApisEnabled() { + long origId = Binder.clearCallingIdentity(); + boolean isOnDevicePersonalizationApisEnabled = + FlagsFactory.getFlags().isOnDevicePersonalizationApisEnabled(); + Binder.restoreCallingIdentity(origId); + return isOnDevicePersonalizationApisEnabled; + } +} diff --git a/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationPrivacyStatusServiceImpl.java b/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationConfigServiceImpl.java index 0340b393..53b50e59 100644 --- a/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationPrivacyStatusServiceImpl.java +++ b/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationConfigServiceImpl.java @@ -23,14 +23,14 @@ import android.os.IBinder; /** * ODP service that modifies and persists user's privacy status. */ -public class OnDevicePersonalizationPrivacyStatusServiceImpl extends Service { +public class OnDevicePersonalizationConfigServiceImpl extends Service { /** Binder interface. */ - private OnDevicePersonalizationPrivacyStatusServiceDelegate mBinder; + private OnDevicePersonalizationConfigServiceDelegate mBinder; @Override public void onCreate() { - mBinder = new OnDevicePersonalizationPrivacyStatusServiceDelegate(this); + mBinder = new OnDevicePersonalizationConfigServiceDelegate(this); } @Override diff --git a/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationDebugServiceDelegate.java b/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationDebugServiceDelegate.java new file mode 100644 index 00000000..f3fb88cf --- /dev/null +++ b/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationDebugServiceDelegate.java @@ -0,0 +1,43 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.ondevicepersonalization.services; + + +import android.adservices.ondevicepersonalization.aidl.IOnDevicePersonalizationDebugService; +import android.annotation.NonNull; +import android.content.Context; + +import com.android.ondevicepersonalization.services.util.DebugUtils; + +import java.util.Objects; + +/** + * Service that provides test and debug APIs. + */ +public class OnDevicePersonalizationDebugServiceDelegate + extends IOnDevicePersonalizationDebugService.Stub { + @NonNull private final Context mContext; + + public OnDevicePersonalizationDebugServiceDelegate(@NonNull Context context) { + mContext = Objects.requireNonNull(context); + } + + @Override + public boolean isEnabled() { + return DebugUtils.isDeveloperModeEnabled(mContext); + } +} diff --git a/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationDebugServiceImpl.java b/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationDebugServiceImpl.java new file mode 100644 index 00000000..b4a46259 --- /dev/null +++ b/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationDebugServiceImpl.java @@ -0,0 +1,40 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.ondevicepersonalization.services; + +import android.app.Service; +import android.content.Intent; +import android.os.IBinder; + +/** + * Service that provides test and debug APIs. + */ +public class OnDevicePersonalizationDebugServiceImpl extends Service { + + /** Binder interface. */ + private OnDevicePersonalizationDebugServiceDelegate mBinder; + + @Override + public void onCreate() { + mBinder = new OnDevicePersonalizationDebugServiceDelegate(this); + } + + @Override + public IBinder onBind(Intent intent) { + return mBinder; + } +} diff --git a/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationExecutors.java b/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationExecutors.java index 837e5eb3..3d03f785 100644 --- a/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationExecutors.java +++ b/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationExecutors.java @@ -24,6 +24,7 @@ import android.os.StrictMode; import android.os.StrictMode.ThreadPolicy; import com.google.common.util.concurrent.ListeningExecutorService; +import com.google.common.util.concurrent.ListeningScheduledExecutorService; import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.ThreadFactoryBuilder; @@ -52,6 +53,12 @@ public final class OnDevicePersonalizationExecutors { createThreadFactory("Blocking Thread", Process.THREAD_PRIORITY_BACKGROUND + Process.THREAD_PRIORITY_LESS_FAVORABLE, Optional.empty()))); + private static final ListeningScheduledExecutorService sScheduledExecutor = + MoreExecutors.listeningDecorator(Executors.newScheduledThreadPool( + /* nThreads */ 4, + createThreadFactory("SCH Thread", Process.THREAD_PRIORITY_BACKGROUND, + Optional.of(getIoThreadPolicy())))); + private static final HandlerThread sHandlerThread = createHandlerThread(); private static final Handler sHandler = new Handler(sHandlerThread.getLooper()); @@ -86,6 +93,14 @@ public final class OnDevicePersonalizationExecutors { } /** + * Returns an executor that can start tasks after a delay. + */ + @NonNull + public static ListeningScheduledExecutorService getScheduledExecutor() { + return sScheduledExecutor; + } + + /** * Returns a Handler that can post messages to a HandlerThread. */ public static Handler getHandler() { diff --git a/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationManagingServiceDelegate.java b/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationManagingServiceDelegate.java index 57597a58..75b046ea 100644 --- a/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationManagingServiceDelegate.java +++ b/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationManagingServiceDelegate.java @@ -16,6 +16,7 @@ package com.android.ondevicepersonalization.services; +import android.adservices.ondevicepersonalization.CallerMetadata; import android.adservices.ondevicepersonalization.aidl.IExecuteCallback; import android.adservices.ondevicepersonalization.aidl.IOnDevicePersonalizationManagingService; import android.adservices.ondevicepersonalization.aidl.IRequestSurfacePackageCallback; @@ -45,9 +46,10 @@ public class OnDevicePersonalizationManagingServiceDelegate ComponentName handler, PersistableBundle params, IExecuteCallback callback, - Context context) { + Context context, + long startTimeMillis) { return new AppRequestFlow( - callingPackageName, handler, params, callback, context); + callingPackageName, handler, params, callback, context, startTimeMillis); } RenderFlow getRenderFlow( @@ -57,9 +59,11 @@ public class OnDevicePersonalizationManagingServiceDelegate int width, int height, IRequestSurfacePackageCallback callback, - Context context) { + Context context, + long startTimeMillis) { return new RenderFlow( - slotResultToken, hostToken, displayId, width, height, callback, context); + slotResultToken, hostToken, displayId, width, height, callback, context, + startTimeMillis); } } @@ -87,19 +91,28 @@ public class OnDevicePersonalizationManagingServiceDelegate @NonNull String callingPackageName, @NonNull ComponentName handler, @NonNull PersistableBundle params, + @NonNull CallerMetadata metadata, @NonNull IExecuteCallback callback) { - long origId = Binder.clearCallingIdentity(); - if (FlagsFactory.getFlags().getGlobalKillSwitch()) { + if (getGlobalKillSwitch()) { throw new IllegalStateException("Service skipped as the global kill switch is on."); } - Binder.restoreCallingIdentity(origId); Objects.requireNonNull(callingPackageName); Objects.requireNonNull(handler); Objects.requireNonNull(handler.getPackageName()); Objects.requireNonNull(handler.getClassName()); Objects.requireNonNull(params); + Objects.requireNonNull(metadata); Objects.requireNonNull(callback); + if (callingPackageName.isEmpty()) { + throw new IllegalArgumentException("missing app package name"); + } + if (handler.getPackageName().isEmpty()) { + throw new IllegalArgumentException("missing service package name"); + } + if (handler.getClassName().isEmpty()) { + throw new IllegalArgumentException("missing service class name"); + } final int uid = Binder.getCallingUid(); enforceCallingPackageBelongsToUid(callingPackageName, uid); @@ -109,7 +122,8 @@ public class OnDevicePersonalizationManagingServiceDelegate handler, params, callback, - mContext); + mContext, + metadata.getStartTimeMillis()); flow.run(); } @@ -120,12 +134,11 @@ public class OnDevicePersonalizationManagingServiceDelegate int displayId, int width, int height, + @NonNull CallerMetadata metadata, @NonNull IRequestSurfacePackageCallback callback) { - long origId = Binder.clearCallingIdentity(); - if (FlagsFactory.getFlags().getGlobalKillSwitch()) { + if (getGlobalKillSwitch()) { throw new IllegalStateException("Service skipped as the global kill switch is on."); } - Binder.restoreCallingIdentity(origId); Objects.requireNonNull(slotResultToken); Objects.requireNonNull(hostToken); @@ -149,10 +162,17 @@ public class OnDevicePersonalizationManagingServiceDelegate width, height, callback, - mContext); + mContext, + metadata.getStartTimeMillis()); flow.run(); } + private boolean getGlobalKillSwitch() { + long origId = Binder.clearCallingIdentity(); + boolean globalKillSwitch = FlagsFactory.getFlags().getGlobalKillSwitch(); + Binder.restoreCallingIdentity(origId); + return globalKillSwitch; + } private void enforceCallingPackageBelongsToUid(@NonNull String packageName, int uid) { int packageUid; PackageManager pm = mContext.getPackageManager(); diff --git a/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationPrivacyStatusServiceDelegate.java b/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationPrivacyStatusServiceDelegate.java deleted file mode 100644 index 85f91bde..00000000 --- a/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationPrivacyStatusServiceDelegate.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright (C) 2022 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.android.ondevicepersonalization.services; - -import android.adservices.ondevicepersonalization.aidl.IPrivacyStatusService; -import android.adservices.ondevicepersonalization.aidl.IPrivacyStatusServiceCallback; -import android.annotation.NonNull; -import android.content.Context; -import android.os.RemoteException; - - -import com.android.ondevicepersonalization.internal.util.LoggerFactory; -import com.android.ondevicepersonalization.services.data.user.PrivacySignal; -import com.android.ondevicepersonalization.services.data.user.RawUserData; -import com.android.ondevicepersonalization.services.data.user.UserDataCollector; - -import java.util.Objects; -import java.util.concurrent.Executor; - -/** - * ODP service that modifies and persists user's privacy status. - */ -public class OnDevicePersonalizationPrivacyStatusServiceDelegate - extends IPrivacyStatusService.Stub { - private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); - private static final String TAG = "OnDevicePersonalizationPrivacyStatusServiceDelegate"; - private final Context mContext; - private static final Executor sBackgroundExecutor = - OnDevicePersonalizationExecutors.getBackgroundExecutor(); - - public OnDevicePersonalizationPrivacyStatusServiceDelegate(Context context) { - mContext = context; - } - - @Override - public void setKidStatus(boolean kidStatusEnabled, - @NonNull IPrivacyStatusServiceCallback callback) { - Objects.requireNonNull(callback); - // TODO(b/272823829): Verify caller's permission - // TODO(b/270468742): Call system server for U+ devices - sBackgroundExecutor.execute( - () -> { - try { - PrivacySignal privacySignal = PrivacySignal.getInstance(); - - if (kidStatusEnabled == privacySignal.isKidStatusEnabled()) { - callback.onSuccess(); - return; - } - - privacySignal.setKidStatusEnabled(kidStatusEnabled); - // Rollback all user data if kid status changes - RawUserData userData = RawUserData.getInstance(); - UserDataCollector userDataCollector = - UserDataCollector.getInstance(mContext); - userDataCollector.clearUserData(userData); - userDataCollector.clearMetadata(); - userDataCollector.clearDatabase(); - callback.onSuccess(); - } catch (RemoteException re) { - sLogger.e(TAG + ": Unable to send result to the callback.", re); - } - } - ); - } -} diff --git a/src/com/android/ondevicepersonalization/services/PhFlags.java b/src/com/android/ondevicepersonalization/services/PhFlags.java index b5dc0db0..b253e044 100644 --- a/src/com/android/ondevicepersonalization/services/PhFlags.java +++ b/src/com/android/ondevicepersonalization/services/PhFlags.java @@ -17,11 +17,8 @@ package com.android.ondevicepersonalization.services; import android.annotation.NonNull; -import android.os.SystemProperties; import android.provider.DeviceConfig; -import com.android.internal.annotations.VisibleForTesting; - /** Flags Implementation that delegates to DeviceConfig. */ // TODO(b/228037065): Add validation logics for Feature flags read from PH. public final class PhFlags implements Flags { @@ -31,8 +28,17 @@ public final class PhFlags implements Flags { // Killswitch keys static final String KEY_GLOBAL_KILL_SWITCH = "global_kill_switch"; - // SystemProperty prefix. SystemProperty is for overriding OnDevicePersonalization Configs. - private static final String SYSTEM_PROPERTY_PREFIX = "debug.ondevicepersonalization."; + static final String KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS = + "enable_ondevicepersonalization_apis"; + + static final String KEY_ENABLE_PERSONALIZATION_STATUS_OVERRIDE = + "enable_personalization_status_override"; + + static final String KEY_PERSONALIZATION_STATUS_OVERRIDE_VALUE = + "personalization_status_override_value"; + + static final String KEY_ISOLATED_SERVICE_DEADLINE_SECONDS = + "isolated_service_deadline_seconds"; // OnDevicePersonalization Namespace String from DeviceConfig class static final String NAMESPACE_ON_DEVICE_PERSONALIZATION = "on_device_personalization"; @@ -47,18 +53,55 @@ public final class PhFlags implements Flags { // Group of All Killswitches @Override public boolean getGlobalKillSwitch() { - // The priority of applying the flag values: SystemProperties, PH (DeviceConfig), - // then hard-coded value. - return SystemProperties.getBoolean( - getSystemPropertyName(KEY_GLOBAL_KILL_SWITCH), - DeviceConfig.getBoolean( + // The priority of applying the flag values: PH (DeviceConfig), then hard-coded value. + return DeviceConfig.getBoolean( /* namespace= */ NAMESPACE_ON_DEVICE_PERSONALIZATION, /* name= */ KEY_GLOBAL_KILL_SWITCH, - /* defaultValue= */ GLOBAL_KILL_SWITCH)); + /* defaultValue= */ GLOBAL_KILL_SWITCH); + } + + @Override + public boolean isOnDevicePersonalizationApisEnabled() { + if (getGlobalKillSwitch()) { + return false; + } + // The priority of applying the flag values: PH (DeviceConfig), then user hard-coded value. + return DeviceConfig.getBoolean( + /* namespace= */ NAMESPACE_ON_DEVICE_PERSONALIZATION, + /* name= */ KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS, + /* defaultValue= */ ENABLE_ONDEVICEPERSONALIZATION_APIS); } - @VisibleForTesting - static String getSystemPropertyName(String key) { - return SYSTEM_PROPERTY_PREFIX + key; + @Override + public boolean isPersonalizationStatusOverrideEnabled() { + if (getGlobalKillSwitch()) { + return false; + } + // The priority of applying the flag values: PH (DeviceConfig), then user hard-coded value. + return DeviceConfig.getBoolean( + /* namespace= */ NAMESPACE_ON_DEVICE_PERSONALIZATION, + /* name= */ KEY_ENABLE_PERSONALIZATION_STATUS_OVERRIDE, + /* defaultValue= */ ENABLE_PERSONALIZATION_STATUS_OVERRIDE); + } + + @Override + public boolean getPersonalizationStatusOverrideValue() { + if (getGlobalKillSwitch()) { + return false; + } + // The priority of applying the flag values: PH (DeviceConfig), then user hard-coded value. + return DeviceConfig.getBoolean( + /* namespace= */ NAMESPACE_ON_DEVICE_PERSONALIZATION, + /* name= */ KEY_PERSONALIZATION_STATUS_OVERRIDE_VALUE, + /* defaultValue= */ PERSONALIZATION_STATUS_OVERRIDE_VALUE); + } + + @Override + public int getIsolatedServiceDeadlineSeconds() { + // The priority of applying the flag values: PH (DeviceConfig), then user hard-coded value. + return DeviceConfig.getInt( + /* namespace= */ NAMESPACE_ON_DEVICE_PERSONALIZATION, + /* name= */ KEY_ISOLATED_SERVICE_DEADLINE_SECONDS, + /* defaultValue= */ ISOLATED_SERVICE_DEADLINE_SECONDS); } } diff --git a/src/com/android/ondevicepersonalization/services/data/DataAccessServiceImpl.java b/src/com/android/ondevicepersonalization/services/data/DataAccessServiceImpl.java index ab1f5440..dd612342 100644 --- a/src/com/android/ondevicepersonalization/services/data/DataAccessServiceImpl.java +++ b/src/com/android/ondevicepersonalization/services/data/DataAccessServiceImpl.java @@ -17,6 +17,8 @@ package com.android.ondevicepersonalization.services.data; import android.adservices.ondevicepersonalization.Constants; +import android.adservices.ondevicepersonalization.EventLogRecord; +import android.adservices.ondevicepersonalization.RequestLogRecord; import android.adservices.ondevicepersonalization.aidl.IDataAccessService; import android.adservices.ondevicepersonalization.aidl.IDataAccessServiceCallback; import android.annotation.NonNull; @@ -28,21 +30,27 @@ import android.os.Bundle; import android.os.PersistableBundle; import android.os.RemoteException; - import com.android.internal.annotations.VisibleForTesting; import com.android.ondevicepersonalization.internal.util.LoggerFactory; +import com.android.ondevicepersonalization.internal.util.OdpParceledListSlice; import com.android.ondevicepersonalization.services.OnDevicePersonalizationExecutors; import com.android.ondevicepersonalization.services.data.events.EventUrlHelper; import com.android.ondevicepersonalization.services.data.events.EventUrlPayload; +import com.android.ondevicepersonalization.services.data.events.EventsDao; +import com.android.ondevicepersonalization.services.data.events.JoinedEvent; +import com.android.ondevicepersonalization.services.data.events.Query; import com.android.ondevicepersonalization.services.data.vendor.LocalData; import com.android.ondevicepersonalization.services.data.vendor.OnDevicePersonalizationLocalDataDao; import com.android.ondevicepersonalization.services.data.vendor.OnDevicePersonalizationVendorDataDao; +import com.android.ondevicepersonalization.services.util.OnDevicePersonalizationFlatbufferUtils; import com.android.ondevicepersonalization.services.util.PackageUtils; import com.google.common.util.concurrent.ListeningExecutorService; +import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Objects; /** @@ -52,44 +60,26 @@ import java.util.Objects; public class DataAccessServiceImpl extends IDataAccessService.Stub { private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); private static final String TAG = "DataAccessServiceImpl"; - - @VisibleForTesting - static class Injector { - long getTimeMillis() { - return System.currentTimeMillis(); - } - - ListeningExecutorService getExecutor() { - return OnDevicePersonalizationExecutors.getBackgroundExecutor(); - } - - OnDevicePersonalizationVendorDataDao getVendorDataDao( - Context context, String packageName, String certDigest - ) { - return OnDevicePersonalizationVendorDataDao.getInstance(context, - packageName, certDigest); - } - - OnDevicePersonalizationLocalDataDao getLocalDataDao( - Context context, String packageName, String certDigest - ) { - return OnDevicePersonalizationLocalDataDao.getInstance(context, - packageName, certDigest); - } - } - - @NonNull private final Context mApplicationContext; - @NonNull private final String mServicePackageName; + @NonNull + private final Context mApplicationContext; + @NonNull + private final String mServicePackageName; private final OnDevicePersonalizationVendorDataDao mVendorDataDao; - @Nullable private final OnDevicePersonalizationLocalDataDao mLocalDataDao; + @Nullable + private final OnDevicePersonalizationLocalDataDao mLocalDataDao; + @Nullable + private final EventsDao mEventsDao; private final boolean mIncludeLocalData; - @NonNull private final Injector mInjector; + private final boolean mIncludeEventData; + @NonNull + private final Injector mInjector; public DataAccessServiceImpl( @NonNull String servicePackageName, @NonNull Context applicationContext, - boolean includeLocalData) { - this(servicePackageName, applicationContext, includeLocalData, + boolean includeLocalData, + boolean includeEventData) { + this(servicePackageName, applicationContext, includeLocalData, includeEventData, new Injector()); } @@ -98,6 +88,7 @@ public class DataAccessServiceImpl extends IDataAccessService.Stub { @NonNull String servicePackageName, @NonNull Context applicationContext, boolean includeLocalData, + boolean includeEventData, @NonNull Injector injector) { mApplicationContext = Objects.requireNonNull(applicationContext, "applicationContext"); mServicePackageName = Objects.requireNonNull(servicePackageName, "servicePackageName"); @@ -118,6 +109,12 @@ public class DataAccessServiceImpl extends IDataAccessService.Stub { throw new IllegalArgumentException("Package: " + servicePackageName + " does not exist.", nnfe); } + mIncludeEventData = includeEventData; + if (includeEventData) { + mEventsDao = mInjector.getEventsDao(mApplicationContext); + } else { + mEventsDao = null; + } } /** Handle a request from the isolated process. */ @@ -196,6 +193,32 @@ public class DataAccessServiceImpl extends IDataAccessService.Stub { eventParams, responseData, mimeType, destinationUrl, callback) ); break; + case Constants.DATA_ACCESS_OP_GET_REQUESTS: + if (!mIncludeEventData) { + throw new IllegalStateException( + "request and event data are not included for this instance."); + } + long[] requestTimes = Objects.requireNonNull(params.getLongArray( + Constants.EXTRA_LOOKUP_KEYS)); + if (requestTimes.length != 2) { + throw new IllegalArgumentException("Invalid request timestamps provided."); + } + mInjector.getExecutor().execute( + () -> getRequests(requestTimes[0], requestTimes[1], callback)); + break; + case Constants.DATA_ACCESS_OP_GET_JOINED_EVENTS: + if (!mIncludeEventData) { + throw new IllegalStateException( + "request and event data are not included for this instance."); + } + long[] eventTimes = Objects.requireNonNull(params.getLongArray( + Constants.EXTRA_LOOKUP_KEYS)); + if (eventTimes.length != 2) { + throw new IllegalArgumentException("Invalid event timestamps provided."); + } + mInjector.getExecutor().execute( + () -> getJoinedEvents(eventTimes[0], eventTimes[1], callback)); + break; default: sendError(callback); } @@ -281,7 +304,7 @@ public class DataAccessServiceImpl extends IDataAccessService.Stub { @NonNull IDataAccessServiceCallback callback) { try { sLogger.d(TAG, ": getEventUrl() started."); - EventUrlPayload payload = new EventUrlPayload(eventParams, responseData, mimeType); + EventUrlPayload payload = new EventUrlPayload(eventParams, responseData, mimeType); Uri eventUrl; if (destinationUrl == null || destinationUrl.isEmpty()) { eventUrl = EventUrlHelper.getEncryptedOdpEventUrl(payload); @@ -299,6 +322,63 @@ public class DataAccessServiceImpl extends IDataAccessService.Stub { } } + private void getRequests(long startTimeMillis, long endTimeMillis, + @NonNull IDataAccessServiceCallback callback) { + try { + List<Query> queries = mEventsDao.readAllQueries(startTimeMillis, endTimeMillis, + mServicePackageName); + List<RequestLogRecord> requestLogRecords = new ArrayList<>(); + for (Query query : queries) { + RequestLogRecord record = new RequestLogRecord.Builder() + .setRows(OnDevicePersonalizationFlatbufferUtils + .getContentValuesFromQueryData(query.getQueryData())) + .setRequestId(query.getQueryId()) + .setTimeMillis(query.getTimeMillis()) + .build(); + requestLogRecords.add(record); + } + Bundle result = new Bundle(); + result.putParcelable(Constants.EXTRA_RESULT, + new OdpParceledListSlice<>(requestLogRecords)); + sendResult(result, callback); + } catch (Exception e) { + sendError(callback); + } + } + + private void getJoinedEvents(long startTimeMillis, long endTimeMillis, + @NonNull IDataAccessServiceCallback callback) { + try { + List<JoinedEvent> joinedEvents = mEventsDao.readJoinedTableRows(startTimeMillis, + endTimeMillis, + mServicePackageName); + List<EventLogRecord> joinedLogRecords = new ArrayList<>(); + for (JoinedEvent joinedEvent : joinedEvents) { + RequestLogRecord requestLogRecord = new RequestLogRecord.Builder() + .setRequestId(joinedEvent.getQueryId()) + .setRows(OnDevicePersonalizationFlatbufferUtils + .getContentValuesFromQueryData(joinedEvent.getQueryData())) + .setTimeMillis(joinedEvent.getQueryTimeMillis()) + .build(); + EventLogRecord record = new EventLogRecord.Builder() + .setTimeMillis(joinedEvent.getEventTimeMillis()) + .setType(joinedEvent.getType()) + .setData( + OnDevicePersonalizationFlatbufferUtils + .getContentValuesFromEventData(joinedEvent.getEventData())) + .setRequestLogRecord(requestLogRecord) + .build(); + joinedLogRecords.add(record); + } + Bundle result = new Bundle(); + result.putParcelable(Constants.EXTRA_RESULT, + new OdpParceledListSlice<>(joinedLogRecords)); + sendResult(result, callback); + } catch (Exception e) { + sendError(callback); + } + } + private void sendResult( @NonNull Bundle result, @NonNull IDataAccessServiceCallback callback) { @@ -316,4 +396,35 @@ public class DataAccessServiceImpl extends IDataAccessService.Stub { sLogger.e(TAG + ": Callback error", e); } } + + @VisibleForTesting + static class Injector { + long getTimeMillis() { + return System.currentTimeMillis(); + } + + ListeningExecutorService getExecutor() { + return OnDevicePersonalizationExecutors.getBackgroundExecutor(); + } + + OnDevicePersonalizationVendorDataDao getVendorDataDao( + Context context, String packageName, String certDigest + ) { + return OnDevicePersonalizationVendorDataDao.getInstance(context, + packageName, certDigest); + } + + OnDevicePersonalizationLocalDataDao getLocalDataDao( + Context context, String packageName, String certDigest + ) { + return OnDevicePersonalizationLocalDataDao.getInstance(context, + packageName, certDigest); + } + + EventsDao getEventsDao( + Context context + ) { + return EventsDao.getInstance(context); + } + } } diff --git a/src/com/android/ondevicepersonalization/services/data/OnDevicePersonalizationDbHelper.java b/src/com/android/ondevicepersonalization/services/data/OnDevicePersonalizationDbHelper.java index 4e18cebf..c8981e3c 100644 --- a/src/com/android/ondevicepersonalization/services/data/OnDevicePersonalizationDbHelper.java +++ b/src/com/android/ondevicepersonalization/services/data/OnDevicePersonalizationDbHelper.java @@ -40,7 +40,7 @@ public class OnDevicePersonalizationDbHelper extends SQLiteOpenHelper { private static final int DATABASE_VERSION = 1; private static final String DATABASE_NAME = "ondevicepersonalization.db"; - private static OnDevicePersonalizationDbHelper sSingleton = null; + private static volatile OnDevicePersonalizationDbHelper sSingleton = null; private OnDevicePersonalizationDbHelper(Context context, String dbName) { super(context, dbName, null, DATABASE_VERSION); @@ -48,12 +48,15 @@ public class OnDevicePersonalizationDbHelper extends SQLiteOpenHelper { /** Returns an instance of the OnDevicePersonalizationDbHelper given a context. */ public static OnDevicePersonalizationDbHelper getInstance(Context context) { - synchronized (OnDevicePersonalizationDbHelper.class) { - if (sSingleton == null) { - sSingleton = new OnDevicePersonalizationDbHelper(context, DATABASE_NAME); + if (sSingleton == null) { + synchronized (OnDevicePersonalizationDbHelper.class) { + if (sSingleton == null) { + sSingleton = new OnDevicePersonalizationDbHelper( + context.getApplicationContext(), DATABASE_NAME); + } } - return sSingleton; } + return sSingleton; } /** diff --git a/src/com/android/ondevicepersonalization/services/data/events/ColumnSchema.java b/src/com/android/ondevicepersonalization/services/data/events/ColumnSchema.java new file mode 100644 index 00000000..cad06bd9 --- /dev/null +++ b/src/com/android/ondevicepersonalization/services/data/events/ColumnSchema.java @@ -0,0 +1,274 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.ondevicepersonalization.services.data.events; + +import android.annotation.NonNull; + +import com.android.ondevicepersonalization.internal.util.AnnotationValidations; +import com.android.ondevicepersonalization.internal.util.DataClass; + +/** + * ColumnSchema object representing a SQL column + */ +@DataClass( + genBuilder = true, + genEqualsHashCode = true +) +public class ColumnSchema { + public static final int SQL_DATA_TYPE_INTEGER = 1; + public static final int SQL_DATA_TYPE_REAL = 2; + public static final int SQL_DATA_TYPE_TEXT = 3; + public static final int SQL_DATA_TYPE_BLOB = 4; + + /** The name of the column. */ + @NonNull + private final String mName; + + /** The SQL type of the column. */ + @SqlDataType + private final int mType; + + @Override + public String toString() { + String result = mName + " "; + switch (mType) { + case ColumnSchema.SQL_DATA_TYPE_INTEGER: + result += "INTEGER"; + break; + case ColumnSchema.SQL_DATA_TYPE_REAL: + result += "REAL"; + break; + case ColumnSchema.SQL_DATA_TYPE_TEXT: + result += "TEXT"; + break; + case ColumnSchema.SQL_DATA_TYPE_BLOB: + default: + result += "BLOB"; + break; + } + return result; + } + + + + // Code below generated by codegen v1.0.23. + // + // DO NOT MODIFY! + // CHECKSTYLE:OFF Generated code + // + // To regenerate run: + // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/src/com/android/ondevicepersonalization/services/data/events/ColumnSchema.java + // + // To exclude the generated code from IntelliJ auto-formatting enable (one-time): + // Settings > Editor > Code Style > Formatter Control + //@formatter:off + + + @android.annotation.IntDef(prefix = "SQL_DATA_TYPE_", value = { + SQL_DATA_TYPE_INTEGER, + SQL_DATA_TYPE_REAL, + SQL_DATA_TYPE_TEXT, + SQL_DATA_TYPE_BLOB + }) + @java.lang.annotation.Retention(java.lang.annotation.RetentionPolicy.SOURCE) + @DataClass.Generated.Member + public @interface SqlDataType {} + + @DataClass.Generated.Member + public static String sqlDataTypeToString(@SqlDataType int value) { + switch (value) { + case SQL_DATA_TYPE_INTEGER: + return "SQL_DATA_TYPE_INTEGER"; + case SQL_DATA_TYPE_REAL: + return "SQL_DATA_TYPE_REAL"; + case SQL_DATA_TYPE_TEXT: + return "SQL_DATA_TYPE_TEXT"; + case SQL_DATA_TYPE_BLOB: + return "SQL_DATA_TYPE_BLOB"; + default: return Integer.toHexString(value); + } + } + + @DataClass.Generated.Member + /* package-private */ ColumnSchema( + @NonNull String name, + @SqlDataType int type) { + this.mName = name; + AnnotationValidations.validate( + NonNull.class, null, mName); + this.mType = type; + + if (!(mType == SQL_DATA_TYPE_INTEGER) + && !(mType == SQL_DATA_TYPE_REAL) + && !(mType == SQL_DATA_TYPE_TEXT) + && !(mType == SQL_DATA_TYPE_BLOB)) { + throw new java.lang.IllegalArgumentException( + "type was " + mType + " but must be one of: " + + "SQL_DATA_TYPE_INTEGER(" + SQL_DATA_TYPE_INTEGER + "), " + + "SQL_DATA_TYPE_REAL(" + SQL_DATA_TYPE_REAL + "), " + + "SQL_DATA_TYPE_TEXT(" + SQL_DATA_TYPE_TEXT + "), " + + "SQL_DATA_TYPE_BLOB(" + SQL_DATA_TYPE_BLOB + ")"); + } + + + // onConstructed(); // You can define this method to get a callback + } + + /** + * The name of the column. + */ + @DataClass.Generated.Member + public @NonNull String getName() { + return mName; + } + + /** + * The SQL type of the column. + */ + @DataClass.Generated.Member + public @SqlDataType int getType() { + return mType; + } + + @Override + @DataClass.Generated.Member + public boolean equals(@android.annotation.Nullable Object o) { + // You can override field equality logic by defining either of the methods like: + // boolean fieldNameEquals(ColumnSchema other) { ... } + // boolean fieldNameEquals(FieldType otherValue) { ... } + + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + @SuppressWarnings("unchecked") + ColumnSchema that = (ColumnSchema) o; + //noinspection PointlessBooleanExpression + return true + && java.util.Objects.equals(mName, that.mName) + && mType == that.mType; + } + + @Override + @DataClass.Generated.Member + public int hashCode() { + // You can override field hashCode logic by defining methods like: + // int fieldNameHashCode() { ... } + + int _hash = 1; + _hash = 31 * _hash + java.util.Objects.hashCode(mName); + _hash = 31 * _hash + mType; + return _hash; + } + + /** + * A builder for {@link ColumnSchema} + */ + @SuppressWarnings("WeakerAccess") + @DataClass.Generated.Member + public static class Builder { + + private @NonNull String mName; + private @SqlDataType int mType; + + private long mBuilderFieldsSet = 0L; + + public Builder() { + } + + /** + * Creates a new Builder. + * + * @param name + * The name of the column. + * @param type + * The SQL type of the column. + */ + public Builder( + @NonNull String name, + @SqlDataType int type) { + mName = name; + AnnotationValidations.validate( + NonNull.class, null, mName); + mType = type; + + if (!(mType == SQL_DATA_TYPE_INTEGER) + && !(mType == SQL_DATA_TYPE_REAL) + && !(mType == SQL_DATA_TYPE_TEXT) + && !(mType == SQL_DATA_TYPE_BLOB)) { + throw new java.lang.IllegalArgumentException( + "type was " + mType + " but must be one of: " + + "SQL_DATA_TYPE_INTEGER(" + SQL_DATA_TYPE_INTEGER + "), " + + "SQL_DATA_TYPE_REAL(" + SQL_DATA_TYPE_REAL + "), " + + "SQL_DATA_TYPE_TEXT(" + SQL_DATA_TYPE_TEXT + "), " + + "SQL_DATA_TYPE_BLOB(" + SQL_DATA_TYPE_BLOB + ")"); + } + + } + + /** + * The name of the column. + */ + @DataClass.Generated.Member + public @NonNull Builder setName(@NonNull String value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x1; + mName = value; + return this; + } + + /** + * The SQL type of the column. + */ + @DataClass.Generated.Member + public @NonNull Builder setType(@SqlDataType int value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x2; + mType = value; + return this; + } + + /** Builds the instance. This builder should not be touched after calling this! */ + public @NonNull ColumnSchema build() { + checkNotUsed(); + mBuilderFieldsSet |= 0x4; // Mark builder used + + ColumnSchema o = new ColumnSchema( + mName, + mType); + return o; + } + + private void checkNotUsed() { + if ((mBuilderFieldsSet & 0x4) != 0) { + throw new IllegalStateException( + "This Builder should not be reused. Use a new Builder instance instead"); + } + } + } + + @DataClass.Generated( + time = 1693945130500L, + codegenVersion = "1.0.23", + sourceFile = "packages/modules/OnDevicePersonalization/src/com/android/ondevicepersonalization/services/data/events/ColumnSchema.java", + inputSignatures = "public static final int SQL_DATA_TYPE_INTEGER\npublic static final int SQL_DATA_TYPE_REAL\npublic static final int SQL_DATA_TYPE_TEXT\npublic static final int SQL_DATA_TYPE_BLOB\nprivate final @android.annotation.NonNull java.lang.String mName\nprivate final @com.android.ondevicepersonalization.services.data.events.ColumnSchema.SqlDataType int mType\nclass ColumnSchema extends java.lang.Object implements []\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") + @Deprecated + private void __metadata() {} + + + //@formatter:on + // End of generated code + +} diff --git a/src/com/android/ondevicepersonalization/services/data/events/Event.java b/src/com/android/ondevicepersonalization/services/data/events/Event.java index 3879d78d..1b867668 100644 --- a/src/com/android/ondevicepersonalization/services/data/events/Event.java +++ b/src/com/android/ondevicepersonalization/services/data/events/Event.java @@ -39,7 +39,7 @@ public class Event implements Serializable { private final long mQueryId; /** Index of the associated entry in the request log for this event. */ - private final long mRowIndex; + private final int mRowIndex; /** Name of the service package for this event */ @NonNull @@ -74,7 +74,7 @@ public class Event implements Serializable { /* package-private */ Event( long eventId, long queryId, - long rowIndex, + int rowIndex, @NonNull String servicePackageName, int type, long timeMillis, @@ -112,7 +112,7 @@ public class Event implements Serializable { * Index of the associated entry in the request log for this event. */ @DataClass.Generated.Member - public long getRowIndex() { + public int getRowIndex() { return mRowIndex; } @@ -125,7 +125,7 @@ public class Event implements Serializable { } /** - * {@link EventType} defining the type of event + * The service assigned type of the event. */ @DataClass.Generated.Member public int getType() { @@ -179,7 +179,7 @@ public class Event implements Serializable { int _hash = 1; _hash = 31 * _hash + Long.hashCode(mEventId); _hash = 31 * _hash + Long.hashCode(mQueryId); - _hash = 31 * _hash + Long.hashCode(mRowIndex); + _hash = 31 * _hash + mRowIndex; _hash = 31 * _hash + java.util.Objects.hashCode(mServicePackageName); _hash = 31 * _hash + mType; _hash = 31 * _hash + Long.hashCode(mTimeMillis); @@ -196,7 +196,7 @@ public class Event implements Serializable { private long mEventId; private long mQueryId; - private long mRowIndex; + private int mRowIndex; private @NonNull String mServicePackageName; private int mType; private long mTimeMillis; @@ -219,7 +219,7 @@ public class Event implements Serializable { * @param servicePackageName * Name of the service package for this event * @param type - * {@link EventType} defining the type of event + * The service assigned type of the event. * @param timeMillis * Time of the event in milliseconds. * @param eventData @@ -228,7 +228,7 @@ public class Event implements Serializable { public Builder( long eventId, long queryId, - long rowIndex, + int rowIndex, @NonNull String servicePackageName, int type, long timeMillis, @@ -270,7 +270,7 @@ public class Event implements Serializable { * Index of the associated entry in the request log for this event. */ @DataClass.Generated.Member - public @NonNull Builder setRowIndex(long value) { + public @NonNull Builder setRowIndex(int value) { checkNotUsed(); mBuilderFieldsSet |= 0x4; mRowIndex = value; @@ -289,7 +289,7 @@ public class Event implements Serializable { } /** - * {@link EventType} defining the type of event + * The service assigned type of the event. */ @DataClass.Generated.Member public @NonNull Builder setType(int value) { @@ -346,10 +346,10 @@ public class Event implements Serializable { } @DataClass.Generated( - time = 1686610284923L, + time = 1693520125987L, codegenVersion = "1.0.23", sourceFile = "packages/modules/OnDevicePersonalization/src/com/android/ondevicepersonalization/services/data/events/Event.java", - inputSignatures = "private final long mEventId\nprivate final long mQueryId\nprivate final long mRowIndex\nprivate final @android.annotation.NonNull java.lang.String mServicePackageName\nprivate final int mType\nprivate final long mTimeMillis\nprivate final @android.annotation.Nullable byte[] mEventData\nclass Event extends java.lang.Object implements [java.io.Serializable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") + inputSignatures = "private final long mEventId\nprivate final long mQueryId\nprivate final int mRowIndex\nprivate final @android.annotation.NonNull java.lang.String mServicePackageName\nprivate final int mType\nprivate final long mTimeMillis\nprivate final @android.annotation.Nullable byte[] mEventData\nclass Event extends java.lang.Object implements [java.io.Serializable]\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") @Deprecated private void __metadata() {} diff --git a/src/com/android/ondevicepersonalization/services/data/events/EventState.java b/src/com/android/ondevicepersonalization/services/data/events/EventState.java index 8c97801a..d82b1d60 100644 --- a/src/com/android/ondevicepersonalization/services/data/events/EventState.java +++ b/src/com/android/ondevicepersonalization/services/data/events/EventState.java @@ -29,11 +29,9 @@ import com.android.ondevicepersonalization.internal.util.DataClass; genEqualsHashCode = true ) public class EventState { - /** The id of the last confirmed event. */ - private final long mEventId; - - /** The id of the last confirmed query. */ - private final long mQueryId; + /** Token representing the event state. */ + @NonNull + private final byte[] mToken; /** Name of the service package for this event */ @NonNull @@ -60,12 +58,12 @@ public class EventState { @DataClass.Generated.Member /* package-private */ EventState( - long eventId, - long queryId, + @NonNull byte[] token, @NonNull String servicePackageName, @NonNull String taskIdentifier) { - this.mEventId = eventId; - this.mQueryId = queryId; + this.mToken = token; + AnnotationValidations.validate( + NonNull.class, null, mToken); this.mServicePackageName = servicePackageName; AnnotationValidations.validate( NonNull.class, null, mServicePackageName); @@ -77,19 +75,11 @@ public class EventState { } /** - * The id of the last confirmed event. + * Token representing the event state. */ @DataClass.Generated.Member - public long getEventId() { - return mEventId; - } - - /** - * The id of the last confirmed query. - */ - @DataClass.Generated.Member - public long getQueryId() { - return mQueryId; + public @NonNull byte[] getToken() { + return mToken; } /** @@ -121,8 +111,7 @@ public class EventState { EventState that = (EventState) o; //noinspection PointlessBooleanExpression return true - && mEventId == that.mEventId - && mQueryId == that.mQueryId + && java.util.Arrays.equals(mToken, that.mToken) && java.util.Objects.equals(mServicePackageName, that.mServicePackageName) && java.util.Objects.equals(mTaskIdentifier, that.mTaskIdentifier); } @@ -134,8 +123,7 @@ public class EventState { // int fieldNameHashCode() { ... } int _hash = 1; - _hash = 31 * _hash + Long.hashCode(mEventId); - _hash = 31 * _hash + Long.hashCode(mQueryId); + _hash = 31 * _hash + java.util.Arrays.hashCode(mToken); _hash = 31 * _hash + java.util.Objects.hashCode(mServicePackageName); _hash = 31 * _hash + java.util.Objects.hashCode(mTaskIdentifier); return _hash; @@ -148,35 +136,31 @@ public class EventState { @DataClass.Generated.Member public static class Builder { - private long mEventId; - private long mQueryId; + private @NonNull byte[] mToken; private @NonNull String mServicePackageName; private @NonNull String mTaskIdentifier; private long mBuilderFieldsSet = 0L; - public Builder() { - } + public Builder() {} /** * Creates a new Builder. * - * @param eventId - * The id of the last confirmed event. - * @param queryId - * The id of the last confirmed query. + * @param token + * Token representing the event state. * @param servicePackageName * Name of the service package for this event * @param taskIdentifier * Unique identifier of the task for processing this event */ public Builder( - long eventId, - long queryId, + @NonNull byte[] token, @NonNull String servicePackageName, @NonNull String taskIdentifier) { - mEventId = eventId; - mQueryId = queryId; + mToken = token; + AnnotationValidations.validate( + NonNull.class, null, mToken); mServicePackageName = servicePackageName; AnnotationValidations.validate( NonNull.class, null, mServicePackageName); @@ -186,24 +170,13 @@ public class EventState { } /** - * The id of the last confirmed event. + * Token representing the event state. */ @DataClass.Generated.Member - public @NonNull Builder setEventId(long value) { + public @NonNull Builder setToken(@NonNull byte... value) { checkNotUsed(); mBuilderFieldsSet |= 0x1; - mEventId = value; - return this; - } - - /** - * The id of the last confirmed query. - */ - @DataClass.Generated.Member - public @NonNull Builder setQueryId(long value) { - checkNotUsed(); - mBuilderFieldsSet |= 0x2; - mQueryId = value; + mToken = value; return this; } @@ -213,7 +186,7 @@ public class EventState { @DataClass.Generated.Member public @NonNull Builder setServicePackageName(@NonNull String value) { checkNotUsed(); - mBuilderFieldsSet |= 0x4; + mBuilderFieldsSet |= 0x2; mServicePackageName = value; return this; } @@ -224,7 +197,7 @@ public class EventState { @DataClass.Generated.Member public @NonNull Builder setTaskIdentifier(@NonNull String value) { checkNotUsed(); - mBuilderFieldsSet |= 0x8; + mBuilderFieldsSet |= 0x4; mTaskIdentifier = value; return this; } @@ -232,18 +205,17 @@ public class EventState { /** Builds the instance. This builder should not be touched after calling this! */ public @NonNull EventState build() { checkNotUsed(); - mBuilderFieldsSet |= 0x10; // Mark builder used + mBuilderFieldsSet |= 0x8; // Mark builder used EventState o = new EventState( - mEventId, - mQueryId, + mToken, mServicePackageName, mTaskIdentifier); return o; } private void checkNotUsed() { - if ((mBuilderFieldsSet & 0x10) != 0) { + if ((mBuilderFieldsSet & 0x8) != 0) { throw new IllegalStateException( "This Builder should not be reused. Use a new Builder instance instead"); } @@ -251,10 +223,10 @@ public class EventState { } @DataClass.Generated( - time = 1692638505171L, + time = 1695678195125L, codegenVersion = "1.0.23", sourceFile = "packages/modules/OnDevicePersonalization/src/com/android/ondevicepersonalization/services/data/events/EventState.java", - inputSignatures = "private final long mEventId\nprivate final long mQueryId\nprivate final @android.annotation.NonNull java.lang.String mServicePackageName\nprivate final @android.annotation.NonNull java.lang.String mTaskIdentifier\nclass EventState extends java.lang.Object implements []\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") + inputSignatures = "private final @android.annotation.NonNull byte[] mToken\nprivate final @android.annotation.NonNull java.lang.String mServicePackageName\nprivate final @android.annotation.NonNull java.lang.String mTaskIdentifier\nclass EventState extends java.lang.Object implements []\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") @Deprecated private void __metadata() {} diff --git a/src/com/android/ondevicepersonalization/services/data/events/EventStateContract.java b/src/com/android/ondevicepersonalization/services/data/events/EventStateContract.java index 7104abdf..b1764434 100644 --- a/src/com/android/ondevicepersonalization/services/data/events/EventStateContract.java +++ b/src/com/android/ondevicepersonalization/services/data/events/EventStateContract.java @@ -36,18 +36,15 @@ public class EventStateContract { /** Name of the service package for this event */ public static final String SERVICE_PACKAGE_NAME = "servicePackageName"; - /** The id of the last confirmed event. */ - public static final String EVENT_ID = "eventId"; + /** Token representing the event state. */ + public static final String TOKEN = "token"; - /** The id of the last confirmed query. */ - public static final String QUERY_ID = "queryId"; public static final String CREATE_TABLE_STATEMENT = "CREATE TABLE IF NOT EXISTS " + TABLE_NAME + " (" + TASK_IDENTIFIER + " TEXT NOT NULL," + SERVICE_PACKAGE_NAME + " TEXT NOT NULL," - + EVENT_ID + " INTEGER," - + QUERY_ID + " INTEGER," + + TOKEN + " BLOB NOT NULL," + "UNIQUE(" + TASK_IDENTIFIER + "," + SERVICE_PACKAGE_NAME + "))"; diff --git a/src/com/android/ondevicepersonalization/services/data/events/EventsDao.java b/src/com/android/ondevicepersonalization/services/data/events/EventsDao.java index 51fe3434..e9d987d8 100644 --- a/src/com/android/ondevicepersonalization/services/data/events/EventsDao.java +++ b/src/com/android/ondevicepersonalization/services/data/events/EventsDao.java @@ -36,8 +36,10 @@ import java.util.List; public class EventsDao { private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); private static final String TAG = "EventsDao"; + private static final String JOINED_EVENT_TIME_MILLIS = "eventTimeMillis"; + private static final String JOINED_QUERY_TIME_MILLIS = "queryTimeMillis"; - private static EventsDao sSingleton; + private static volatile EventsDao sSingleton; private final OnDevicePersonalizationDbHelper mDbHelper; @@ -47,14 +49,16 @@ public class EventsDao { /** Returns an instance of the EventsDao given a context. */ public static EventsDao getInstance(@NonNull Context context) { - synchronized (EventsDao.class) { - if (sSingleton == null) { - OnDevicePersonalizationDbHelper dbHelper = - OnDevicePersonalizationDbHelper.getInstance(context); - sSingleton = new EventsDao(dbHelper); + if (sSingleton == null) { + synchronized (EventsDao.class) { + if (sSingleton == null) { + OnDevicePersonalizationDbHelper dbHelper = + OnDevicePersonalizationDbHelper.getInstance(context); + sSingleton = new EventsDao(dbHelper); + } } - return sSingleton; } + return sSingleton; } /** @@ -97,6 +101,31 @@ public class EventsDao { return -1; } + + /** + * Inserts the List of Events into the Events table. + * + * @return true if all inserts succeeded, false otherwise. + */ + public boolean insertEvents(@NonNull List<Event> events) { + SQLiteDatabase db = mDbHelper.getWritableDatabase(); + try { + db.beginTransactionNonExclusive(); + for (Event event : events) { + if (insertEvent(event) == -1) { + return false; + } + } + db.setTransactionSuccessful(); + } catch (Exception e) { + sLogger.e(TAG + ": Failed to insert events", e); + return false; + } finally { + db.endTransaction(); + } + return true; + } + /** * Inserts the Query into the Queries table. * @@ -127,8 +156,7 @@ public class EventsDao { try { SQLiteDatabase db = mDbHelper.getWritableDatabase(); ContentValues values = new ContentValues(); - values.put(EventStateContract.EventStateEntry.EVENT_ID, eventState.getEventId()); - values.put(EventStateContract.EventStateEntry.QUERY_ID, eventState.getQueryId()); + values.put(EventStateContract.EventStateEntry.TOKEN, eventState.getToken()); values.put(EventStateContract.EventStateEntry.SERVICE_PACKAGE_NAME, eventState.getServicePackageName()); values.put(EventStateContract.EventStateEntry.TASK_IDENTIFIER, @@ -142,6 +170,31 @@ public class EventsDao { } /** + * Updates/inserts a list of EventStates as a transaction + * + * @return true if the all the update/inserts succeeded, false otherwise + */ + public boolean updateOrInsertEventStatesTransaction(List<EventState> eventStates) { + SQLiteDatabase db = mDbHelper.getWritableDatabase(); + try { + db.beginTransactionNonExclusive(); + for (EventState eventState : eventStates) { + if (!updateOrInsertEventState(eventState)) { + return false; + } + } + + db.setTransactionSuccessful(); + } catch (Exception e) { + sLogger.e(TAG + ": Failed to insert/update eventstates", e); + return false; + } finally { + db.endTransaction(); + } + return true; + } + + /** * Gets the eventState for the given package and task * * @return eventState if found, null otherwise @@ -151,8 +204,7 @@ public class EventsDao { String selection = EventStateContract.EventStateEntry.TASK_IDENTIFIER + " = ? AND " + EventStateContract.EventStateEntry.SERVICE_PACKAGE_NAME + " = ?"; String[] selectionArgs = {taskIdentifier, packageName}; - String[] projection = {EventStateContract.EventStateEntry.EVENT_ID, - EventStateContract.EventStateEntry.QUERY_ID}; + String[] projection = {EventStateContract.EventStateEntry.TOKEN}; try (Cursor cursor = db.query( EventStateContract.EventStateEntry.TABLE_NAME, projection, @@ -163,14 +215,11 @@ public class EventsDao { /* orderBy= */ null )) { if (cursor.moveToFirst()) { - long eventId = cursor.getLong(cursor.getColumnIndexOrThrow( - EventStateContract.EventStateEntry.EVENT_ID)); - long queryId = cursor.getLong(cursor.getColumnIndexOrThrow( - EventStateContract.EventStateEntry.QUERY_ID)); + byte[] token = cursor.getBlob(cursor.getColumnIndexOrThrow( + EventStateContract.EventStateEntry.TOKEN)); return new EventState.Builder() - .setEventId(eventId) - .setQueryId(queryId) + .setToken(token) .setServicePackageName(packageName) .setTaskIdentifier(taskIdentifier) .build(); @@ -287,9 +336,6 @@ public class EventsDao { List<JoinedEvent> joinedEventList = new ArrayList<>(); SQLiteDatabase db = mDbHelper.getReadableDatabase(); - - String eventTimeMillisCol = "eventTimeMillis"; - String queryTimeMillisCol = "queryTimeMillis"; String select = "SELECT " + EventsContract.EventsEntry.EVENT_ID + "," + EventsContract.EventsEntry.ROW_INDEX + "," @@ -298,12 +344,12 @@ public class EventsDao { + EventsContract.EventsEntry.SERVICE_PACKAGE_NAME + "," + EventsContract.EventsEntry.EVENT_DATA + "," + EventsContract.EventsEntry.TABLE_NAME + "." - + EventsContract.EventsEntry.TIME_MILLIS + " AS " + eventTimeMillisCol + "," + + EventsContract.EventsEntry.TIME_MILLIS + " AS " + JOINED_EVENT_TIME_MILLIS + "," + EventsContract.EventsEntry.TABLE_NAME + "." + EventsContract.EventsEntry.QUERY_ID + "," + QueriesContract.QueriesEntry.QUERY_DATA + "," + QueriesContract.QueriesEntry.TABLE_NAME + "." - + QueriesContract.QueriesEntry.TIME_MILLIS + " AS " + queryTimeMillisCol; + + QueriesContract.QueriesEntry.TIME_MILLIS + " AS " + JOINED_QUERY_TIME_MILLIS; String from = " FROM " + EventsContract.EventsEntry.TABLE_NAME + " INNER JOIN " + QueriesContract.QueriesEntry.TABLE_NAME + " ON " @@ -317,7 +363,7 @@ public class EventsDao { while (cursor.moveToNext()) { long eventId = cursor.getLong( cursor.getColumnIndexOrThrow(EventsContract.EventsEntry.EVENT_ID)); - long rowIndex = cursor.getLong( + int rowIndex = cursor.getInt( cursor.getColumnIndexOrThrow(EventsContract.EventsEntry.ROW_INDEX)); int type = cursor.getInt( cursor.getColumnIndexOrThrow(EventsContract.EventsEntry.TYPE)); @@ -327,13 +373,13 @@ public class EventsDao { byte[] eventData = cursor.getBlob( cursor.getColumnIndexOrThrow(EventsContract.EventsEntry.EVENT_DATA)); long eventTimeMillis = cursor.getLong( - cursor.getColumnIndexOrThrow(eventTimeMillisCol)); + cursor.getColumnIndexOrThrow(JOINED_EVENT_TIME_MILLIS)); long queryId = cursor.getLong( cursor.getColumnIndexOrThrow(QueriesContract.QueriesEntry.QUERY_ID)); byte[] queryData = cursor.getBlob( cursor.getColumnIndexOrThrow(QueriesContract.QueriesEntry.QUERY_DATA)); long queryTimeMillis = cursor.getLong( - cursor.getColumnIndexOrThrow(queryTimeMillisCol)); + cursor.getColumnIndexOrThrow(JOINED_QUERY_TIME_MILLIS)); joinedEventList.add(new JoinedEvent.Builder() .setEventId(eventId) .setRowIndex(rowIndex) @@ -406,4 +452,179 @@ public class EventsDao { } return true; } + + /** + * Reads all queries in the query table between the given timestamps. + * + * @return List of Query in the query table. + */ + public List<Query> readAllQueries(long startTimeMillis, long endTimeMillis, + String packageName) { + String selection = QueriesContract.QueriesEntry.TIME_MILLIS + " > ?" + + " AND " + QueriesContract.QueriesEntry.TIME_MILLIS + " < ?" + + " AND " + QueriesContract.QueriesEntry.SERVICE_PACKAGE_NAME + " = ?"; + String[] selectionArgs = {String.valueOf(startTimeMillis), String.valueOf( + endTimeMillis), packageName}; + return readQueryRows(selection, selectionArgs); + } + + /** + * Reads all ids in the event table between the given timestamps. + * + * @return List of ids in the event table. + */ + public List<Long> readAllEventIds(long startTimeMillis, long endTimeMillis, + String packageName) { + List<Long> idList = new ArrayList<>(); + try { + SQLiteDatabase db = mDbHelper.getReadableDatabase(); + String[] projection = {EventsContract.EventsEntry.EVENT_ID}; + String selection = EventsContract.EventsEntry.TIME_MILLIS + " > ?" + + " AND " + EventsContract.EventsEntry.TIME_MILLIS + " < ?" + + " AND " + EventsContract.EventsEntry.SERVICE_PACKAGE_NAME + " = ?"; + String[] selectionArgs = {String.valueOf(startTimeMillis), String.valueOf( + endTimeMillis), packageName}; + String orderBy = EventsContract.EventsEntry.EVENT_ID; + try (Cursor cursor = db.query( + EventsContract.EventsEntry.TABLE_NAME, + projection, + selection, + selectionArgs, + /* groupBy= */ null, + /* having= */ null, + orderBy + )) { + while (cursor.moveToNext()) { + Long id = cursor.getLong( + cursor.getColumnIndexOrThrow(EventsContract.EventsEntry.EVENT_ID)); + idList.add(id); + } + cursor.close(); + return idList; + } + } catch (SQLiteException e) { + sLogger.e(TAG + ": Failed to read event ids", e); + } + return idList; + } + + + /** + * Reads all ids in the event table associated with the specified queryId + * + * @return List of ids in the event table. + */ + public List<Long> readAllEventIdsForQuery(long queryId, String packageName) { + List<Long> idList = new ArrayList<>(); + try { + SQLiteDatabase db = mDbHelper.getReadableDatabase(); + String[] projection = {EventsContract.EventsEntry.EVENT_ID}; + String selection = EventsContract.EventsEntry.QUERY_ID + " = ?" + + " AND " + EventsContract.EventsEntry.SERVICE_PACKAGE_NAME + " = ?"; + String[] selectionArgs = {String.valueOf(queryId), packageName}; + String orderBy = EventsContract.EventsEntry.EVENT_ID; + try (Cursor cursor = db.query( + EventsContract.EventsEntry.TABLE_NAME, + projection, + selection, + selectionArgs, + /* groupBy= */ null, + /* having= */ null, + orderBy + )) { + while (cursor.moveToNext()) { + Long id = cursor.getLong( + cursor.getColumnIndexOrThrow(EventsContract.EventsEntry.EVENT_ID)); + idList.add(id); + } + cursor.close(); + return idList; + } + } catch (SQLiteException e) { + sLogger.e(TAG + ": Failed to read event ids for specified queryid", e); + } + return idList; + } + + /** + * Reads single row in the query table + * + * @return Query object for the single row requested + */ + public Query readSingleQueryRow(long queryId, String packageName) { + try { + SQLiteDatabase db = mDbHelper.getReadableDatabase(); + String selection = QueriesContract.QueriesEntry.QUERY_ID + " = ?" + + " AND " + QueriesContract.QueriesEntry.SERVICE_PACKAGE_NAME + " = ?"; + String[] selectionArgs = {String.valueOf(queryId), packageName}; + try (Cursor cursor = db.query( + QueriesContract.QueriesEntry.TABLE_NAME, + /* projection= */ null, + selection, + selectionArgs, + /* groupBy= */ null, + /* having= */ null, + /* orderBy= */ null + )) { + if (cursor.getCount() < 1) { + sLogger.d(TAG + ": Failed to find requested id: " + queryId); + return null; + } + cursor.moveToNext(); + long id = cursor.getLong( + cursor.getColumnIndexOrThrow(QueriesContract.QueriesEntry.QUERY_ID)); + byte[] queryData = cursor.getBlob( + cursor.getColumnIndexOrThrow(QueriesContract.QueriesEntry.QUERY_DATA)); + long timeMillis = cursor.getLong( + cursor.getColumnIndexOrThrow(QueriesContract.QueriesEntry.TIME_MILLIS)); + String servicePackageName = cursor.getString( + cursor.getColumnIndexOrThrow( + QueriesContract.QueriesEntry.SERVICE_PACKAGE_NAME)); + return new Query.Builder() + .setQueryId(id) + .setQueryData(queryData) + .setTimeMillis(timeMillis) + .setServicePackageName(servicePackageName) + .build(); + } + } catch (SQLiteException e) { + sLogger.e(TAG + ": Failed to read query row", e); + } + return null; + } + + /** + * Reads single row in the event table joined with its corresponding query + * + * @return JoinedEvent representing the event joined with its query + */ + public JoinedEvent readSingleJoinedTableRow(long eventId, String packageName) { + String selection = EventsContract.EventsEntry.EVENT_ID + " = ?" + + " AND " + EventsContract.EventsEntry.TABLE_NAME + "." + + EventsContract.EventsEntry.SERVICE_PACKAGE_NAME + " = ?"; + String[] selectionArgs = {String.valueOf(eventId), packageName}; + List<JoinedEvent> joinedEventList = readJoinedTableRows(selection, selectionArgs); + if (joinedEventList.size() < 1) { + sLogger.d(TAG + ": Failed to find requested id: " + eventId); + return null; + } + return joinedEventList.get(0); + } + + /** + * Reads all row in the event table joined with its corresponding query within the given time + * range. + * + * @return List of JoinedEvents representing the event joined with its query + */ + public List<JoinedEvent> readJoinedTableRows(long startTimeMillis, long endTimeMillis, + String packageName) { + String selection = JOINED_EVENT_TIME_MILLIS + " > ?" + + " AND " + JOINED_EVENT_TIME_MILLIS + " < ?" + + " AND " + EventsContract.EventsEntry.TABLE_NAME + "." + + EventsContract.EventsEntry.SERVICE_PACKAGE_NAME + " = ?"; + String[] selectionArgs = {String.valueOf(startTimeMillis), String.valueOf( + endTimeMillis), packageName}; + return readJoinedTableRows(selection, selectionArgs); + } } diff --git a/src/com/android/ondevicepersonalization/services/data/events/JoinedEvent.java b/src/com/android/ondevicepersonalization/services/data/events/JoinedEvent.java index 8055e009..ce431c96 100644 --- a/src/com/android/ondevicepersonalization/services/data/events/JoinedEvent.java +++ b/src/com/android/ondevicepersonalization/services/data/events/JoinedEvent.java @@ -37,7 +37,7 @@ public class JoinedEvent { private final long mQueryId; /** Index of the associated entry in the request log for this event. */ - private final long mRowIndex; + private final int mRowIndex; /** Name of the service package for this event */ @NonNull @@ -79,7 +79,7 @@ public class JoinedEvent { /* package-private */ JoinedEvent( long eventId, long queryId, - long rowIndex, + int rowIndex, @NonNull String servicePackageName, int type, long eventTimeMillis, @@ -121,7 +121,7 @@ public class JoinedEvent { * Index of the associated entry in the request log for this event. */ @DataClass.Generated.Member - public long getRowIndex() { + public int getRowIndex() { return mRowIndex; } @@ -206,7 +206,7 @@ public class JoinedEvent { int _hash = 1; _hash = 31 * _hash + Long.hashCode(mEventId); _hash = 31 * _hash + Long.hashCode(mQueryId); - _hash = 31 * _hash + Long.hashCode(mRowIndex); + _hash = 31 * _hash + mRowIndex; _hash = 31 * _hash + java.util.Objects.hashCode(mServicePackageName); _hash = 31 * _hash + mType; _hash = 31 * _hash + Long.hashCode(mEventTimeMillis); @@ -225,7 +225,7 @@ public class JoinedEvent { private long mEventId; private long mQueryId; - private long mRowIndex; + private int mRowIndex; private @NonNull String mServicePackageName; private int mType; private long mEventTimeMillis; @@ -263,7 +263,7 @@ public class JoinedEvent { public Builder( long eventId, long queryId, - long rowIndex, + int rowIndex, @NonNull String servicePackageName, int type, long eventTimeMillis, @@ -309,7 +309,7 @@ public class JoinedEvent { * Index of the associated entry in the request log for this event. */ @DataClass.Generated.Member - public @NonNull Builder setRowIndex(long value) { + public @NonNull Builder setRowIndex(int value) { checkNotUsed(); mBuilderFieldsSet |= 0x4; mRowIndex = value; @@ -409,10 +409,10 @@ public class JoinedEvent { } @DataClass.Generated( - time = 1693354428802L, + time = 1693520269776L, codegenVersion = "1.0.23", sourceFile = "packages/modules/OnDevicePersonalization/src/com/android/ondevicepersonalization/services/data/events/JoinedEvent.java", - inputSignatures = "private final long mEventId\nprivate final long mQueryId\nprivate final long mRowIndex\nprivate final @android.annotation.NonNull java.lang.String mServicePackageName\nprivate final int mType\nprivate final long mEventTimeMillis\nprivate final @android.annotation.Nullable byte[] mEventData\nprivate final long mQueryTimeMillis\nprivate final @android.annotation.Nullable byte[] mQueryData\nclass JoinedEvent extends java.lang.Object implements []\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") + inputSignatures = "private final long mEventId\nprivate final long mQueryId\nprivate final int mRowIndex\nprivate final @android.annotation.NonNull java.lang.String mServicePackageName\nprivate final int mType\nprivate final long mEventTimeMillis\nprivate final @android.annotation.Nullable byte[] mEventData\nprivate final long mQueryTimeMillis\nprivate final @android.annotation.Nullable byte[] mQueryData\nclass JoinedEvent extends java.lang.Object implements []\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") @Deprecated private void __metadata() {} diff --git a/src/com/android/ondevicepersonalization/services/data/events/JoinedTableDao.java b/src/com/android/ondevicepersonalization/services/data/events/JoinedTableDao.java new file mode 100644 index 00000000..ef9cb18a --- /dev/null +++ b/src/com/android/ondevicepersonalization/services/data/events/JoinedTableDao.java @@ -0,0 +1,267 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.ondevicepersonalization.services.data.events; + +import android.content.ContentValues; +import android.content.Context; +import android.database.Cursor; +import android.database.SQLException; +import android.database.sqlite.SQLiteDatabase; +import android.database.sqlite.SQLiteOpenHelper; + +import com.android.ondevicepersonalization.internal.util.LoggerFactory; +import com.android.ondevicepersonalization.services.util.OnDevicePersonalizationFlatbufferUtils; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * Dao used to manage and create in-memory table for joined Events and Queries tables + */ +public class JoinedTableDao { + /** Map of column name to {@link ColumnSchema} of columns provided by OnDevicePersonalization */ + public static final Map<String, ColumnSchema> ODP_PROVIDED_COLUMNS; + // TODO(298682670): Finalize provided column and table names. + public static final String SERVICE_PACKAGE_NAME_COL = "servicePackageName"; + public static final String TYPE_COL = "type"; + public static final String EVENT_TIME_MILLIS_COL = "eventTimeMillis"; + public static final String QUERY_TIME_MILLIS_COL = "queryTimeMillis"; + public static final String TABLE_NAME = "odp_joined_table"; + private static final String TAG = "JoinedTableDao"; + private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); + + static { + ODP_PROVIDED_COLUMNS = new HashMap<>(); + ODP_PROVIDED_COLUMNS.put(SERVICE_PACKAGE_NAME_COL, new ColumnSchema.Builder().setName( + SERVICE_PACKAGE_NAME_COL).setType(ColumnSchema.SQL_DATA_TYPE_TEXT).build()); + ODP_PROVIDED_COLUMNS.put(TYPE_COL, new ColumnSchema.Builder().setName(TYPE_COL).setType( + ColumnSchema.SQL_DATA_TYPE_INTEGER).build()); + ODP_PROVIDED_COLUMNS.put(EVENT_TIME_MILLIS_COL, new ColumnSchema.Builder().setName( + EVENT_TIME_MILLIS_COL).setType(ColumnSchema.SQL_DATA_TYPE_INTEGER).build()); + ODP_PROVIDED_COLUMNS.put(QUERY_TIME_MILLIS_COL, new ColumnSchema.Builder().setName( + QUERY_TIME_MILLIS_COL).setType(ColumnSchema.SQL_DATA_TYPE_INTEGER).build()); + } + + private final SQLiteOpenHelper mDbHelper; + private final Map<String, ColumnSchema> mColumns; + + public JoinedTableDao(List<ColumnSchema> columnSchemaList, long fromEventId, long fromQueryId, + Context context) { + if (!validateColumns(columnSchemaList)) { + throw new IllegalArgumentException("Provided columns are invalid."); + } + // Move the List to a HashMap <ColumnName, ColumnSchema> for easier access. + mColumns = columnSchemaList.stream().collect(Collectors.toMap( + ColumnSchema::getName, + Function.identity(), + (v1, v2) -> { + // Throw on duplicate keys. + throw new IllegalArgumentException("Duplicate key found in columnSchemaList"); + }, + HashMap::new)); + mDbHelper = createInMemoryTable(columnSchemaList, context); + populateTable(fromEventId, fromQueryId, context); + } + + private static SQLiteOpenHelper createInMemoryTable(List<ColumnSchema> columnSchemaList, + Context context) { + List<String> columns = columnSchemaList.stream().map(ColumnSchema::toString).collect( + Collectors.toList()); + String createTableStatement = "CREATE TABLE IF NOT EXISTS " + TABLE_NAME + " (" + + String.join(",", columns) + ")"; + SQLiteOpenHelper sqLiteOpenHelper = new SQLiteOpenHelper(context, null, null, 1) { + @Override + public void onCreate(SQLiteDatabase db) { + // Do nothing. + } + + @Override + public void onUpgrade(SQLiteDatabase db, int oldVersion, int newVersion) { + // Do nothing. Should never be called. + } + }; + + try { + sqLiteOpenHelper.getReadableDatabase().execSQL(createTableStatement); + } catch (SQLException e) { + sLogger.e(e, TAG + " : Failed to create JoinedTable database in memory."); + throw new IllegalStateException(e); + } + return sqLiteOpenHelper; + } + + private static boolean validateColumns(List<ColumnSchema> columnSchemaList) { + if (columnSchemaList.size() == 0) { + sLogger.d(TAG, ": Empty columnSchemaList provided"); + return false; + } + for (ColumnSchema columnSchema : columnSchemaList) { + // Validate any ODP_PROVIDED_COLUMNS are the correct type + if (ODP_PROVIDED_COLUMNS.containsKey(columnSchema.getName())) { + ColumnSchema expected = ODP_PROVIDED_COLUMNS.get(columnSchema.getName()); + if (expected.getType() != columnSchema.getType()) { + sLogger.d(TAG + + ": ODP column %s of type %s provided does not match " + + "expected type %s", + columnSchema.getName(), columnSchema.getType(), expected.getType()); + return false; + } + } + } + // TODO(298225729): Additional validation on column name formatting. + return true; + } + + /** + * Executes the given query on the in-memory db. + * + * @return Cursor holding result of the query. + */ + public Cursor rawQuery(String sql) { + SQLiteDatabase db = mDbHelper.getReadableDatabase(); + // TODO(298225729): Determine return format. + return db.rawQuery(sql, null); + } + + private void populateTable(long fromEventId, long fromQueryId, Context context) { + EventsDao eventsDao = EventsDao.getInstance(context); + List<JoinedEvent> joinedEventList = eventsDao.readAllNewRows(fromEventId, + fromQueryId); + + SQLiteDatabase db = mDbHelper.getWritableDatabase(); + try { + db.beginTransactionNonExclusive(); + for (JoinedEvent joinedEvent : joinedEventList) { + if (joinedEvent.getEventId() == 0) { + // Process Query-only rows + if (joinedEvent.getQueryData() != null) { + List<ContentValues> queryFieldRows = + OnDevicePersonalizationFlatbufferUtils + .getContentValuesFromQueryData( + joinedEvent.getQueryData()); + for (ContentValues queryRow : queryFieldRows) { + ContentValues insertValues = new ContentValues(); + insertValues.putAll(extractValidColumns(queryRow)); + insertValues.putAll(addProvidedColumns(joinedEvent)); + long insertResult = db.insert(TABLE_NAME, null, insertValues); + if (insertResult == -1) { + throw new IllegalStateException("Failed to insert row into SQL DB"); + } + } + } + } else { + ContentValues insertValues = new ContentValues(); + // Add eventData columns + if (joinedEvent.getEventData() != null) { + ContentValues eventData = + OnDevicePersonalizationFlatbufferUtils + .getContentValuesFromEventData( + joinedEvent.getEventData()); + insertValues.putAll(extractValidColumns(eventData)); + } + // Add queryData columns + if (joinedEvent.getQueryData() != null) { + ContentValues queryData = + OnDevicePersonalizationFlatbufferUtils + .getContentValuesRowFromQueryData( + joinedEvent.getQueryData(), + joinedEvent.getRowIndex()); + insertValues.putAll(extractValidColumns(queryData)); + } + // Add ODP provided columns + insertValues.putAll(addProvidedColumns(joinedEvent)); + long insertResult = db.insert(TABLE_NAME, null, insertValues); + if (insertResult == -1) { + throw new IllegalStateException("Failed to insert row into SQL DB"); + } + } + } + db.setTransactionSuccessful(); + } finally { + db.endTransaction(); + } + } + + private ContentValues addProvidedColumns(JoinedEvent joinedEvent) { + ContentValues result = new ContentValues(); + if (mColumns.containsKey(SERVICE_PACKAGE_NAME_COL)) { + result.put(SERVICE_PACKAGE_NAME_COL, + joinedEvent.getServicePackageName()); + } + if (mColumns.containsKey(TYPE_COL)) { + result.put(TYPE_COL, joinedEvent.getType()); + } + if (mColumns.containsKey(EVENT_TIME_MILLIS_COL)) { + result.put(EVENT_TIME_MILLIS_COL, joinedEvent.getEventTimeMillis()); + } + if (mColumns.containsKey(QUERY_TIME_MILLIS_COL)) { + result.put(QUERY_TIME_MILLIS_COL, joinedEvent.getQueryTimeMillis()); + } + return result; + } + + private ContentValues extractValidColumns(ContentValues data) { + ContentValues result = new ContentValues(); + for (String key : data.keySet()) { + if (mColumns.containsKey(key)) { + Object value = data.get(key); + int sqlType = mColumns.get(key).getType(); + if (value instanceof Byte) { + if (sqlType == ColumnSchema.SQL_DATA_TYPE_INTEGER) { + result.put(key, (Byte) value); + } + } else if (value instanceof Short) { + if (sqlType == ColumnSchema.SQL_DATA_TYPE_INTEGER) { + result.put(key, (Short) value); + } + } else if (value instanceof Integer) { + if (sqlType == ColumnSchema.SQL_DATA_TYPE_INTEGER) { + result.put(key, (Integer) value); + } + } else if (value instanceof Long) { + if (sqlType == ColumnSchema.SQL_DATA_TYPE_INTEGER) { + result.put(key, (Long) value); + } + } else if (value instanceof Float) { + if (sqlType == ColumnSchema.SQL_DATA_TYPE_REAL) { + result.put(key, (Float) value); + } + } else if (value instanceof Double) { + if (sqlType == ColumnSchema.SQL_DATA_TYPE_REAL) { + result.put(key, (Double) value); + } + } else if (value instanceof String) { + if (sqlType == ColumnSchema.SQL_DATA_TYPE_TEXT) { + result.put(key, (String) value); + } + } else if (value instanceof byte[]) { + if (sqlType == ColumnSchema.SQL_DATA_TYPE_BLOB) { + result.put(key, (byte[]) value); + } + } else if (value instanceof Boolean) { + if (sqlType == ColumnSchema.SQL_DATA_TYPE_INTEGER) { + result.put(key, (Boolean) value); + } + } + } + } + return result; + } +} diff --git a/src/com/android/ondevicepersonalization/services/data/user/Country.java b/src/com/android/ondevicepersonalization/services/data/user/Country.java deleted file mode 100644 index 5beb8148..00000000 --- a/src/com/android/ondevicepersonalization/services/data/user/Country.java +++ /dev/null @@ -1,273 +0,0 @@ -/* - * Copyright (C) 2022 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.android.ondevicepersonalization.services.data.user; - -/** - * A enum class for all countries in ISO-3166 alpha-3 format. - */ -public enum Country { - UNKNOWN, - ABW, // Aruba - AFG, // Afghanistan - AGO, // Angola - AIA, // Anguilla - ALA, // Åland Islands - ALB, // Albania - AND, // Andorra - ARE, // United Arab Emirates - ARG, // Argentina - ARM, // Armenia - ASM, // American Samoa - ATA, // Antarctica - ATF, // French Southern Territories - ATG, // Antigua and Barbuda - AUS, // Australia - AUT, // Austria - AZE, // Azerbaijan - BDI, // Burundi - BEL, // Belgium - BEN, // Benin - BES, // Bonaire, Sint Eustatius and Saba - BFA, // Burkina Faso - BGD, // Bangladesh - BGR, // Bulgaria - BHR, // Bahrain - BHS, // Bahamas - BIH, // Bosnia and Herzegovina - BLM, // Saint Barthélemy - BLR, // Belarus - BLZ, // Belize - BMU, // Bermuda - BOL, // Bolivia (Plurinational State of) - BRA, // Brazil - BRB, // Barbados - BRN, // Brunei Darussalam - BTN, // Bhutan - BVT, // Bouvet Island - BWA, // Botswana - CAF, // Central African Republic - CAN, // Canada - CCK, // Cocos (Keeling) Islands - CHE, // Switzerland - CHL, // Chile - CHN, // China - CIV, // Côte d'Ivoire - CMR, // Cameroon - COD, // Congo, Democratic Republic of the - COG, // Congo - COK, // Cook Islands - COL, // Colombia - COM, // Comoros - CPV, // Cabo Verde - CRI, // Costa Rica - CUB, // Cuba - CUW, // Curaçao - CXR, // Christmas Island - CYM, // Cayman Islands - CYP, // Cyprus - CZE, // Czechia - DEU, // Germany - DJI, // Djibouti - DMA, // Dominica - DNK, // Denmark - DOM, // Dominican Republic - DZA, // Algeria - ECU, // Ecuador - EGY, // Egypt - ERI, // Eritrea - ESH, // Western Sahara - ESP, // Spain - EST, // Estonia - ETH, // Ethiopia - FIN, // Finland - FJI, // Fiji - FLK, // Falkland Islands (Malvinas) - FRA, // France - FRO, // Faroe Islands - FSM, // Micronesia (Federated States of) - GAB, // Gabon - GBR, // United Kingdom of Great Britain and Northern Ireland - GEO, // Georgia - GGY, // Guernsey - GHA, // Ghana - GIB, // Gibraltar - GIN, // Guinea - GLP, // Guadeloupe - GMB, // Gambia - GNB, // Guinea-Bissau - GNQ, // Equatorial Guinea - GRC, // Greece - GRD, // Grenada - GRL, // Greenland - GTM, // Guatemala - GUF, // French Guiana - GUM, // Guam - GUY, // Guyana - HKG, // Hong Kong - HMD, // Heard Island and McDonald Islands - HND, // Honduras - HRV, // Croatia - HTI, // Haiti - HUN, // Hungary - IDN, // Indonesia - IMN, // Isle of Man - IND, // India - IOT, // British Indian Ocean Territory - IRL, // Ireland - IRN, // Iran (Islamic Republic of) - IRQ, // Iraq - ISL, // Iceland - ISR, // Israel - ITA, // Italy - JAM, // Jamaica - JEY, // Jersey - JOR, // Jordan - JPN, // Japan - KAZ, // Kazakhstan - KEN, // Kenya - KGZ, // Kyrgyzstan - KHM, // Cambodia - KIR, // Kiribati - KNA, // Saint Kitts and Nevis - KOR, // Korea, Republic of - KWT, // Kuwait - LAO, // Lao People's Democratic Republic - LBN, // Lebanon - LBR, // Liberia - LBY, // Libya - LCA, // Saint Lucia - LIE, // Liechtenstein - LKA, // Sri Lanka - LSO, // Lesotho - LTU, // Lithuania - LUX, // Luxembourg - LVA, // Latvia - MAC, // Macao - MAF, // Saint Martin (French part) - MAR, // Morocco - MCO, // Monaco - MDA, // Moldova, Republic of - MDG, // Madagascar - MDV, // Maldives - MEX, // Mexico - MHL, // Marshall Islands - MKD, // North Macedonia - MLI, // Mali - MLT, // Malta - MMR, // Myanmar - MNE, // Montenegro - MNG, // Mongolia - MNP, // Northern Mariana Islands - MOZ, // Mozambique - MRT, // Mauritania - MSR, // Montserrat - MTQ, // Martinique - MUS, // Mauritius - MWI, // Malawi - MYS, // Malaysia - MYT, // Mayotte - NAM, // Namibia - NCL, // New Caledonia - NER, // Niger - NFK, // Norfolk Island - NGA, // Nigeria - NIC, // Nicaragua - NIU, // Niue - NLD, // Netherlands - NOR, // Norway - NPL, // Nepal - NRU, // Nauru - NZL, // New Zealand - OMN, // Oman - PAK, // Pakistan - PAN, // Panama - PCN, // Pitcairn - PER, // Peru - PHL, // Philippines - PLW, // Palau - PNG, // Papua New Guinea - POL, // Poland - PRI, // Puerto Rico - PRK, // Korea (Democratic People's Republic of) - PRT, // Portugal - PRY, // Paraguay - PSE, // Palestine, State of - PYF, // French Polynesia - QAT, // Qatar - REU, // Réunion - ROU, // Romania - RUS, // Russian Federation - RWA, // Rwanda - SAU, // Saudi Arabia - SDN, // Sudan - SEN, // Senegal - SGP, // Singapore - SGS, // South Georgia and the South Sandwich Islands - SHN, // Saint Helena, Ascension and Tristan da Cunha - SJM, // Svalbard and Jan Mayen - SLB, // Solomon Islands - SLE, // Sierra Leone - SLV, // El Salvador - SMR, // San Marino - SOM, // Somalia - SPM, // Saint Pierre and Miquelon - SRB, // Serbia - SSD, // South Sudan - STP, // Sao Tome and Principe - SUR, // Suriname - SVK, // Slovakia - SVN, // Slovenia - SWE, // Sweden - SWZ, // Eswatini - SXM, // Sint Maarten (Dutch part) - SYC, // Seychelles - SYR, // Syrian Arab Republic - TCA, // Turks and Caicos Islands - TCD, // Chad - TGO, // Togo - THA, // Thailand - TJK, // Tajikistan - TKL, // Tokelau - TKM, // Turkmenistan - TLS, // Timor-Leste - TON, // Tonga - TTO, // Trinidad and Tobago - TUN, // Tunisia - TUR, // Türkiye - TUV, // Tuvalu - TWN, // Taiwan, Province of China - TZA, // Tanzania, United Republic of - UGA, // Uganda - UKR, // Ukraine - UMI, // United States Minor Outlying Islands - URY, // Uruguay - USA, // United States of America - UZB, // Uzbekistan - VAT, // Holy See - VCT, // Saint Vincent and the Grenadines - VEN, // Venezuela (Bolivarian Republic of) - VGB, // Virgin Islands (British) - VIR, // Virgin Islands (U.S.) - VNM, // Viet Nam - VUT, // Vanuatu - WLF, // Wallis and Futuna - WSM, // Samoa - YEM, // Yemen - ZAF, // South Africa - ZMB, // Zambia - ZWE // Zimbabwe -} diff --git a/src/com/android/ondevicepersonalization/services/data/user/DeviceMetrics.java b/src/com/android/ondevicepersonalization/services/data/user/DeviceMetrics.java deleted file mode 100644 index 6a40f5c1..00000000 --- a/src/com/android/ondevicepersonalization/services/data/user/DeviceMetrics.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright (C) 2022 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.android.ondevicepersonalization.services.data.user; - -import android.content.res.Configuration; - -/** Constant device metrics values. */ -public class DeviceMetrics { - // Device manufacturer - public Make make = Make.UNKNOWN; - - // Device model - public Model model = Model.UNKNOWN; - - // Screen height of the device in dp units - public int screenHeight = Configuration.SCREEN_HEIGHT_DP_UNDEFINED; - - // Screen weight of the device in dp units - public int screenWidth = Configuration.SCREEN_WIDTH_DP_UNDEFINED; - - // Device x dpi; - public float xdpi = 0; - - // Device y dpi; - public float ydpi = 0; - - // Dveice pixel ratio. - public float pxRatio = 0; -} diff --git a/src/com/android/ondevicepersonalization/services/data/user/Language.java b/src/com/android/ondevicepersonalization/services/data/user/Language.java deleted file mode 100644 index a1a3cd24..00000000 --- a/src/com/android/ondevicepersonalization/services/data/user/Language.java +++ /dev/null @@ -1,208 +0,0 @@ -/* - * Copyright (C) 2022 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.android.ondevicepersonalization.services.data.user; - -/** - * A enum class for all languages in ISO 639-1 format. - */ -public enum Language { - UNKNOWN, - AA, // Afar - AB, // Abkhazian - AE, // Avestan - AF, // Afrikaans - AK, // Akan - AM, // Amharic - AN, // Aragonese - AR, // Arabic - AS, // Assamese - AV, // Avaric - AY, // Aymara - AZ, // Azerbaijani - BA, // Bashkir - BE, // Belarusian - BG, // Bulgarian - BH, // Bihari languages - BI, // Bislama - BM, // Bambara - BN, // Bengali - BO, // Tibetan - BR, // Breton - BS, // Bosnian - CA, // Catalan; Valencian - CE, // Chechen - CH, // Chamorro - CO, // Corsican - CR, // Cree - CS, // Czech - CU, // Church Slavic; Old Slavonic; Church Slavonic; Old Bulgarian; Old Church Slavonic - CV, // Chuvash - CY, // Welsh - DA, // Danish - DE, // German - DV, // Divehi; Dhivehi; Maldivian - DZ, // Dzongkha - EE, // Ewe - EL, // "Greek - EN, // English - EO, // Esperanto - ES, // Spanish; Castilian - ET, // Estonian - EU, // Basque - FA, // Persian - FF, // Fulah - FI, // Finnish - FJ, // Fijian - FO, // Faroese - FR, // French - FY, // Western Frisian - GA, // Irish - GD, // Gaelic; Scottish Gaelic - GL, // Galician - GN, // Guarani - GU, // Gujarati - GV, // Manx - HA, // Hausa - HE, // Hebrew - HI, // Hindi - HO, // Hiri Motu - HR, // Croatian - HT, // Haitian; Haitian Creole - HU, // Hungarian - HY, // Armenian - HZ, // Herero - IA, // Interlingua (International Auxiliary Language Association) - ID, // Indonesian - IE, // Interlingue; Occidental - IG, // Igbo - II, // Sichuan Yi; Nuosu - IK, // Inupiaq - IO, // Ido - IS, // Icelandic - IT, // Italian - IU, // Inuktitut - JA, // Japanese - JV, // Javanese - KA, // Georgian - KG, // Kongo - KI, // Kikuyu; Gikuyu - KJ, // Kuanyama; Kwanyama - KK, // Kazakh - KL, // Kalaallisut; Greenlandic - KM, // Central Khmer - KN, // Kannada - KO, // Korean - KR, // Kanuri - KS, // Kashmiri - KU, // Kurdish - KV, // Komi - KW, // Cornish - KY, // Kirghiz; Kyrgyz - LA, // Latin - LB, // Luxembourgish; Letzeburgesch - LG, // Ganda - LI, // Limburgan; Limburger; Limburgish - LN, // Lingala - LO, // Lao - LT, // Lithuanian - LU, // Luba-Katanga - LV, // Latvian - MG, // Malagasy - MH, // Marshallese - MI, // Maori - MK, // Macedonian - ML, // Malayalam - MN, // Mongolian - MR, // Marathi - MS, // Malay - MT, // Maltese - MY, // Burmese - NA, // Nauru - NB, // "Bokmål - ND, // "Ndebele - NE, // Nepali - NG, // Ndonga - NL, // Dutch; Flemish - NN, // "Norwegian Nynorsk; Nynorsk - NO, // Norwegian - NR, // "Ndebele - NV, // Navajo; Navaho - NY, // Chichewa; Chewa; Nyanja - OC, // Occitan (post 1500) - OJ, // Ojibwa - OM, // Oromo - OR, // Oriya - OS, // Ossetian; Ossetic - PA, // Panjabi; Punjabi - PI, // Pali - PL, // Polish - PS, // Pushto; Pashto - PT, // Portuguese - QU, // Quechua - RM, // Romansh - RN, // Rundi - RO, // Romanian; Moldavian; Moldovan - RU, // Russian - RW, // Kinyarwanda - SA, // Sanskrit - SC, // Sardinian - SD, // Sindhi - SE, // Northern Sami - SG, // Sango - SI, // Sinhala; Sinhalese - SK, // Slovak - SL, // Slovenian - SM, // Samoan - SN, // Shona - SO, // Somali - SQ, // Albanian - SR, // Serbian - SS, // Swati - ST, // "Sotho - SU, // Sundanese - SV, // Swedish - SW, // Swahili - TA, // Tamil - TE, // Telugu - TG, // Tajik - TH, // Thai - TI, // Tigrinya - TK, // Turkmen - TL, // Tagalog - TN, // Tswana - TO, // Tonga (Tonga Islands) - TR, // Turkish - TS, // Tsonga - TT, // Tatar - TW, // Twi - TY, // Tahitian - UG, // Uighur; Uyghur - UK, // Ukrainian - UR, // Urdu - UZ, // Uzbek - VE, // Venda - VI, // Vietnamese - VO, // Volapük - WA, // Walloon - WO, // Wolof - XH, // Xhosa - YI, // Yiddish - YO, // Yoruba - ZA, // Zhuang; Chuang - ZH, // Chinese - ZU // Zulu -} diff --git a/src/com/android/ondevicepersonalization/services/data/user/Make.java b/src/com/android/ondevicepersonalization/services/data/user/Make.java deleted file mode 100644 index bff0cfd3..00000000 --- a/src/com/android/ondevicepersonalization/services/data/user/Make.java +++ /dev/null @@ -1,200 +0,0 @@ -/* - * Copyright (C) 2022 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.android.ondevicepersonalization.services.data.user; - -/** - * A enum class for top 200 popular device manufacturers. - */ - -public enum Make { - UNKNOWN, - SAMSUNG, // samsung/Samsung - XIAOMI, // Xiaomi - MOTOROLA, // motorola - OPPO, // OPPO/oppo - VIVO, // vivo - REALME, // realme/Realme - HUAWEI, // HUAWEI - GOOGLE, // Google - SONY, // Sony - TECNO, // TECNO/TECNO MOBILE LIMITED - ZTE, // ZTE - TCL, // TCL - INFINIX, // INFINIX/Infinix/INFINIX MOBILITY LIMITED - LENOVO, // LENOVO/Lenovo - ONEPLUS, // OnePlus - SHARP, // SHARP - HMD_GLOBAL, // HMD Global - TECHNICOLOR, // Technicolor - LG_ELECTRONICS, // LGE/LG Electronics - SKYWORTH, // skyworth/SkyworthDigital - WNC, // WNC/wnc - ITEL_MOBILE, // itel/ITEL/ITEL MOBILE LIMITED - KAON_MEDIA, // KaonMedia/KAON - TPV, // TPV - ZEBRA_TECHNOLOGIES, // Zebra Technologies - SDMC, // SDMC - KYOCERA, // KYOCERA - INTEK, // INTEK - HISENSE, // Hisense - FUJITSU, // FUJITSU - BLU, // BLU - SAGEMCOM, // Sagemcom - HONOR, // HONOR - MARUSYS, // MARUSYS - FCNT, // FCNT - VOLVO_CARS, // VolvoCars - TINNO, // TINNO - INNOPIA, // Innopia/INNOPIA - WINGTECH, // Wingtech - WIKO, // WIKO - CHANGHONG, // Changhong - FOSSIL, // Fossil - INCAR, // incar - GM, // gm - SCHOK, // Schok - BLACKVIEW, // Blackview - CAT, // Cat - PANASONIC, // Panasonic - ULEFONE, // Ulefone - SUMITOMO_ELECTRIC_INDUSTRIES, // SumitomoElectricIndustries - ALONG, // along - PRIME, // Prime - COMPAL, // Compal - POSITIVO, // positivo - MULTILASER, // Multilaser - ASUS, // asus - WHEATEK, // wheatek - VSMART, // vsmart - HONEYWELL, // Honeywell - ARRIS, // ARRIS - MOBICEL, // Mobicel - MOBIWIRE, // MobiWire - DOOGEE, // DOOGEE - ASKEY, // askey - JIO, // Jio - HAIER, // Haier - RHINO, // RHINO - CROSSCALL, // Crosscall - SCBC, // SCBC - STAR_HUB, // StarHub - SZMTC, // SZMTC - FAIRPHONE, // Fairphone - OUKITEL, // OUKITEL - SUNMI, // SUNMI - ITREE, // iTree - ALLDOCUBE, // Alldocube - GENERAL_MOBILE, // General Mobile - KTC, // KTC - HXY, // HXY - FUNAI, // Funai - JOYAR, // JOYAR/Joyar - MAXWEST, // Maxwest - AMINO, // Amino - INSPUR, // Inspur - NEC, // NEC - KONNECT_ONE, // KonnectONE - CUBOT, // CUBOT - LAVA, // LAVA - FISE, // FISE - CASIO_COMPUTER, // CASIO COMPUTER CO., LTD. - DISH, // DISH - EXPRESSLUCK, // Expressluck - MOBVOI, // Mobvoi - ONN, // onn - ARCELIK, // ARCELIK - SEI_ROBOTICS, // SEI/SEI Robotics - IMUZ, // IMUZ - FORTUNE_SHIP, // Fortune Ship - SYMPHONY, // Symphony - EAST_AEON, // EastAeon - ACER, // Acer - SONIMTECH, // Sonimtech - SAFARICOM, // Safaricom - GIGASET, // Gigaset - EMDOOR, // emdoor - REEDER, // reeder - TRESWAVE, // Treswave - TECLAST, // Teclast/TECLAST - MTC, // MTC - BQRU, // BQru - HKC, // HKC - STYLO, // STYLO - HI_MEDIA, // HiMedia - HOT_PEPPER, // Hot Pepper Inc - MY, // My - UMX, // Umx - NOTHING, // Nothing - NUU, // NUU - SKY_DEVICES, // sky/SKY/Sky Devices/SkyDevices/Sky_Devices - KODAK, // KODAK - FOXX_DEVELOPMENT, // Foxx Development Inc. - CHINOE, // Chinoe - MICROMAX, // Micromax - SKYTHTEK, // skythtek - HTC, // HTC - LT, // LT - WALTON, // WALTON - DEXP, // DEXP - RELIANCE_COMMUNICATIONS, // Reliance Communications - BMOBILE, // Bmobile - EMPORIA, // emporia/Emporia Telecom GmbH & Co. KG - OMIX, // OMIX - ASCOM, // Ascom - PHILCO, // Philco - BENCO, // benco - PHILIPS, // Philips - COOSEA, // Coosea - DATALOGIC, // Datalogic - UNITECH_ELECTRONICS, // Unitech_Electronics - ZUUM, // ZUUM - CASPER, // Casper - IMG, // IMG - CIPHER_LAB, // CipherLab - GIONEE, // GIONEE - UMIDIGI, // UMIDIGI - DEEJOY, // Deejoy - YULONG, // Yulong - VONINO, // Vonino - HENA, // hena - ISAFE_MOBILE, // isafemobile - ARCADYAN, // Arcadyan - CLARO, // Claro/Claro Colombia - BLACKSHARK, // blackshark - NVIDIA, // NVIDIA - MY_PHONE, // myPhone - VESTEL, // Vestel - KONROW, // KONROW - TAG_HEUER, // TAG Heuer - UROVO, // Urovo - VS, // VS - INNOVATECH, // Innovatech - CHAINWAY, // CHAINWAY - KRIP, // KRIP - EACRUGGED, // EACRUGGED - M3MOBILE, // M3Mobile - SPECTRALINK, // Spectralink - XWIRELESS, // XwirelessLLC - FREEBOX, // Freebox - ADVAN, // ADVAN - JIUZHOU, // Jiuzhou - MPTECH, // mPTech - ALCO, // Alco - ALLVIEW, // ALLVIEW - NUBIA, // nubia - TABLET_PC, // Tablet_PC -} diff --git a/src/com/android/ondevicepersonalization/services/data/user/Model.java b/src/com/android/ondevicepersonalization/services/data/user/Model.java deleted file mode 100644 index 7fdf13a4..00000000 --- a/src/com/android/ondevicepersonalization/services/data/user/Model.java +++ /dev/null @@ -1,225 +0,0 @@ -/* - * Copyright (C) 2022 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.android.ondevicepersonalization.services.data.user; - -/** - * A enum class for top 200 popular device model. - */ - -public enum Model { - UNKNOWN, - SM_A515F, // SM-A515F - SM_A127F, // SM-A127F - M2006C3LG, // M2006C3LG - SM_A125F, // SM-A125F - OCTOPUS, // octopus - SM_A217F, // SM-A217F - SM_A528B, // SM-A528B - UIE4027LGU, // uie4027lgu - M2003J15SC, // M2003J15SC - BRAVIA_4K_VH2, // BRAVIA 4K VH2 - SM_A325F, // SM-A325F - SM_R860, // SM-R860 - SM_A202F, // SM-A202F - M2004J19C, // M2004J19C - BFX_AT100, // BFX-AT100 - ATT_TV, // AT&T TV - SM_A715F, // SM-A715F - REDMI_NOTE_8_PRO, // Redmi Note 8 Pro - REDMI_NOTE_8, // Redmi Note 8 - REDMI_NOTE_9_PRO, // Redmi Note 9 Pro - SM_A135F, // SM-A135F - SM_A226B, // SM-A226B - SM_G991B, // SM-G991B - SM_A405FN, // SM-A405FN - SM_T500, // SM-T500 - XVIEW_PLUS, // Xview+ - SM_A525F, // SM-A525F - SM_A505FN, // SM-A505FN - SM_P610, // SM-P610 - SM_G973F, // SM-G973F - SM_A107F, // SM-A107F - SM_A326B, // SM-A326B - SM_R870, // SM-R870 - SM_T510, // SM-T510 - CHROMECAST, // Chromecast - B820C_A15, // B820C-A15 - SM_G960F, // SM-G960F - SM_G780G, // SM-G780G - SKYWORTH_HY4002, // SKYWORTH-HY4002 - M2101K6G, // M2101K6G - MAR_LX1A, // MAR-LX1A - RMX3231, // RMX3231 - REDMI_NOTE_7, // Redmi Note 7 - SM_A137F, // SM-A137F - LENOVO_TB_8505F, // Lenovo TB-8505F - SM_G781B, // SM-G781B - M2102J20SG, // M2102J20SG - SM_A536B, // SM-A536B - M2006C3MG, // M2006C3MG - SM_A032F, // SM-A032F - SM_A115F, // SM-A115F - SM_G975F, // SM-G975F - SM_R890, // SM-R890 - SM_A105FN, // SM-A105FN - SM_A107M, // SM-A107M - SM_A415F, // SM-A415F - SM_A315G, // SM-A315G - SM_A037F, // SM-A037F - M2006C3MNG, // M2006C3MNG - CPH2185, // CPH2185 - SM_X200, // SM-X200 - SM_G991U, // SM-G991U - SMART_TV, // Smart TV - M2006C3LI, // M2006C3LI - SM_A325M, // SM-A325M - MOTO_G_PURE, // moto g pure - CPH2269, // CPH2269 - SM_T290, // SM-T290 - SM_T295, // SM-T295 - SM_T220, // SM-T220 - SM_S901U, // SM-S901U - SM_A326U, // SM-A326U - SM_A105F, // SM-A105F - REDMI_NOTE_9S, // Redmi Note 9S - SM_A125U, // SM-A125U - SM_A207F, // SM-A207F - SM_A217M, // SM-A217M - REDMI_8, // Redmi 8 - SM_G965F, // SM-G965F - SM_G960U, // SM-G960U - SM_A307FN, // SM-A307FN - SM_A225F, // SM-A225F - SM_G970F, // SM-G970F - BID_AT200, // BID-AT200 - SM_G998B, // SM-G998B - SM_S908U, // SM-S908U - SM_A705FN, // SM-A705FN - SM_A136U, // SM-A136U - SM_G780F, // SM-G780F - SM_A035F, // SM-A035F - M2101K7BNY, // M2101K7BNY - REDMI_2201117TG, // 2201117TG - SM_G970U, // SM-G970U - POT_LX1, // POT-LX1 - SM_S901B, // SM-S901B - REDMI_2201117TY, // 2201117TY - SM_G525F, // SM-G525F - VOG_L29, // VOG-L29 - SM_M315F, // SM-M315F - M2103K19G, // M2103K19G - SM_A315F, // SM-A315F - B860H_V5, // B860H V5.0 - CPH2239, // CPH2239 - INFINIX_X657, // Infinix X657 - CPH2127, // CPH2127 - REDMI_NOTE_8T, // Redmi Note 8T - SM_A505F, // SM-A505F - MOTO_G20, // moto g(20) - SM_A037U, // SM-A037U - PHILIPS_2020_2021_UHD_ANDROID_TV, // 2020/2021 UHD Android TV - VIVO_1906, // vivo 1906 - SM_M127F, // SM-M127F - SM_G990B, // SM-G990B - SM_G715U1, // SM-G715U1 - SM_A022F, // SM-A022F - SM_A127M, // SM-A127M - BEYONDTV, // BeyondTV - MA4000, // MA4000 - PIXEL_6, // Pixel 6 - SM_A035M, // SM-A035M - M2007J20CG, // M2007J20CG - V2111, // V2111 - XIAOMI_2109119DG, // 2109119DG - SM_A032M, // SM-A032M - SM_G998U, // SM-G998U - SM_G973U, // SM-G973U - LENOVO_TB_X606F, // Lenovo TB-X606F - SM_T505, // SM-T505 - SM_A235F, // SM-A235F - REDMI_7A, // Redmi 7A - SM_T225, // SM-T225 - SM_A336B, // SM-A336B - SM_G398FN, // SM-G398FN - SM_S908B, // SM-S908B - SM_A600FN, // SM-A600FN - M2006C3MII, // M2006C3MII - SM_A015M, // SM-A015M - SM_G781U, // SM-G781U - SM_A536E, // SM-A536E - SM_A750FN, // SM-A750FN - LENOVO_TB_X306X, // Lenovo TB-X306X - SM_A022M, // SM-A022M - SM_M115F, // SM-M115F - SM_G991N, // SM-G991N - SM_A105M, // SM-A105M - M2101K7AG, // M2101K7AG - SM_A125M, // SM-A125M - SM_A025F, // SM-A025F - SNE_LX1, // SNE-LX1 - MOTO_G9_PLAY, // moto g(9) play - SM_G980F, // SM-G980F - SM_A115M, // SM-A115M - BRAVIA_4K_VH21, // BRAVIA 4K VH21 - MOTO_E7, // moto e(7) - SM_R865U, // SM-R865U - SM_G996B, // SM-G996B - M2010J19SY, // M2010J19SY - PHILIPS_2021_22_UHD_ANDROID_TV, // 2021/22 Philips UHD Android TV - INFINIX_X688B, // Infinix X688B - SM_A205F, // SM-A205F - SM_M215F, // SM-M215F - F_42A, // F-42A - UHD4K, // UHD4K - SM_N960F, // SM-N960F - SM_A135M, // SM-A135M - SM_F711N, // SM-F711N - SM_A305F, // SM-A305F - SM_T227U, // SM-T227U - AQUOS_TVJ19, // AQUOS-TVJ19 - VIVO_1904, // vivo 1904 - SMART_TV_PRO, // Smart TV Pro - RMX3263, // RMX3263 - MOTO_E20, // moto e20 - SM_T720, // SM-T720 - M2010J19SG, // M2010J19SG - PIXEL_4A, // Pixel 4a - SM_P615N, // SM-P615N - MOTO_G30, // moto g(30) - STK_L21, // STK-L21 - MOTO_G_POWER_2021, // moto g power (2021) - SM_G781V, // SM-G781V - MOTO_G7_POWER, // moto g(7) power - V2027, // V2027 - CPH2219, // CPH2219 - SM_T515, // SM-T515 - SM_A725F, // SM-A725F - RMX3085, // RMX3085 - ELE_L29, // ELE-L29 - SM_A013G, // SM-A013G - SM_A207M, // SM-A207M - OTT_XVIEW_PLUS_AV1, // OTT Xview+ AV1 - SM_G975U, // SM-G975U - SM_N975F, // SM-N975F - SM_A013M, // SM-A013M - SM_J600FN, // SM-J600FN - VOLVO, // Volvo - SM_S908E, // SM-S908E - INFINIX_X657B, // Infinix X657B - M2010J19CG, // M2010J19CG - K1100UA, // K1100UA -} diff --git a/src/com/android/ondevicepersonalization/services/data/user/PrivacySignal.java b/src/com/android/ondevicepersonalization/services/data/user/PrivacySignal.java deleted file mode 100644 index 21c03360..00000000 --- a/src/com/android/ondevicepersonalization/services/data/user/PrivacySignal.java +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright (C) 2022 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.android.ondevicepersonalization.services.data.user; - -/** - * A singleton class that stores all privacy signals of the user in memory. - * TODO (b/272075982): make it as a part of ODP flags. - */ -public final class PrivacySignal { - public static PrivacySignal sPrivacySignal = null; - private boolean mKidStatusEnabled; - private boolean mLimitedAdsTrackingEnabled; - - private PrivacySignal() { - // Assume the more privacy-safe option until updated. - mKidStatusEnabled = true; - mLimitedAdsTrackingEnabled = true; - } - - /** Returns an instance of PrivacySignal. */ - public static PrivacySignal getInstance() { - synchronized (PrivacySignal.class) { - if (sPrivacySignal == null) { - sPrivacySignal = new PrivacySignal(); - } - return sPrivacySignal; - } - } - - public void setKidStatusEnabled(boolean kidStatusEnabled) { - mKidStatusEnabled = kidStatusEnabled; - } - - public boolean isKidStatusEnabled() { - return mKidStatusEnabled; - } - - public void setLimitedAdsTrackingEnabled(boolean limitedAdsTrackingEnabled) { - mLimitedAdsTrackingEnabled = limitedAdsTrackingEnabled; - } - - public boolean isLimitedAdsTrackingEnabled() { - return mLimitedAdsTrackingEnabled; - } -} diff --git a/src/com/android/ondevicepersonalization/services/data/user/RawUserData.java b/src/com/android/ondevicepersonalization/services/data/user/RawUserData.java index 7ecdf18b..9ab50d08 100644 --- a/src/com/android/ondevicepersonalization/services/data/user/RawUserData.java +++ b/src/com/android/ondevicepersonalization/services/data/user/RawUserData.java @@ -16,9 +16,9 @@ package com.android.ondevicepersonalization.services.data.user; +import android.adservices.ondevicepersonalization.UserData; import android.content.res.Configuration; - -import com.android.ondevicepersonalization.internal.util.LoggerFactory; +import android.net.NetworkCapabilities; import java.util.ArrayList; import java.util.HashMap; @@ -30,13 +30,8 @@ import java.util.List; public final class RawUserData { private static RawUserData sUserData = null; - private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); - private static final String TAG = "UserData"; - - // The current system time in milliseconds. - public long timeMillis = 0; - // The device time zone +/- minutes offset from UTC. + // The device time zone +/- offset in minute from UTC. public int utcOffset = 0; // The device orientation. @@ -48,40 +43,12 @@ public final class RawUserData { // Battery percentage. public int batteryPercentage = 0; - // The 3-letter ISO-3166 country code - public Country country = Country.UNKNOWN; - - // The 2-letter ISO-639 language code - public Language language = Language.UNKNOWN; - // Mobile carrier. public Carrier carrier = Carrier.UNKNOWN; - // OS versions of the device. - public OSVersion osVersions = new OSVersion(); - - // Connection type values. - public enum ConnectionType { - UNKNOWN, - ETHERNET, - WIFI, - CELLULAR_2G, - CELLULAR_3G, - CELLULAR_4G, - CELLULAR_5G - }; - - // Connection type. - public ConnectionType connectionType = ConnectionType.UNKNOWN; - - // Status if network is metered. False - not metered. True - metered. - public boolean networkMetered = false; - - // Connection speed in kbps. - public long connectionSpeedKbps = 0; + public NetworkCapabilities networkCapabilities; - // Device metrics values. - public DeviceMetrics deviceMetrics = new DeviceMetrics(); + @UserData.NetworkType public int dataNetworkType; // installed packages. public List<AppInfo> appsInfo = new ArrayList<>(); diff --git a/src/com/android/ondevicepersonalization/services/data/user/UserDataCollectionJobService.java b/src/com/android/ondevicepersonalization/services/data/user/UserDataCollectionJobService.java index 8bb515db..eac673c4 100644 --- a/src/com/android/ondevicepersonalization/services/data/user/UserDataCollectionJobService.java +++ b/src/com/android/ondevicepersonalization/services/data/user/UserDataCollectionJobService.java @@ -76,14 +76,17 @@ public class UserDataCollectionJobService extends JobService { @Override public boolean onStartJob(JobParameters params) { - // TODO(b/265856477): return false to disable data collection if kid status is enabled. sLogger.d(TAG + ": onStartJob()"); if (FlagsFactory.getFlags().getGlobalKillSwitch()) { sLogger.d(TAG + ": GlobalKillSwitch enabled, finishing job."); jobFinished(params, /* wantsReschedule = */ false); return true; } - + if (!UserPrivacyStatus.getInstance().isPersonalizationStatusEnabled()) { + sLogger.d(TAG + ": Personalization is not allowed, finishing job."); + jobFinished(params, /* wantsReschedule = */ false); + return true; + } mUserDataCollector = UserDataCollector.getInstance(this); mUserData = RawUserData.getInstance(); mFuture = Futures.submit(new Runnable() { diff --git a/src/com/android/ondevicepersonalization/services/data/user/UserDataCollector.java b/src/com/android/ondevicepersonalization/services/data/user/UserDataCollector.java index 26eeb000..4ff2e769 100644 --- a/src/com/android/ondevicepersonalization/services/data/user/UserDataCollector.java +++ b/src/com/android/ondevicepersonalization/services/data/user/UserDataCollector.java @@ -30,14 +30,10 @@ import android.database.Cursor; import android.location.Location; import android.location.LocationManager; import android.net.ConnectivityManager; -import android.net.NetworkCapabilities; import android.os.BatteryManager; -import android.os.Build; import android.os.Environment; import android.os.StatFs; import android.telephony.TelephonyManager; -import android.util.DisplayMetrics; -import android.view.WindowManager; import androidx.annotation.NonNull; @@ -45,8 +41,6 @@ import com.android.internal.annotations.VisibleForTesting; import com.android.ondevicepersonalization.internal.util.LoggerFactory; import com.android.ondevicepersonalization.services.data.user.LocationInfo.LocationProvider; -import com.google.common.base.Strings; - import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Calendar; @@ -55,6 +49,7 @@ import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Locale; +import java.util.Set; import java.util.TimeZone; /** @@ -67,20 +62,41 @@ import java.util.TimeZone; * and update a few time-sensitive signals in UserData to the latest version. */ public class UserDataCollector { - public static final int BYTES_IN_MB = 1048576; + private static final int MILLISECONDS_IN_MINUTE = 60000; - private static UserDataCollector sUserDataCollector = null; + private static volatile UserDataCollector sUserDataCollector = null; private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); private static final String TAG = "UserDataCollector"; + @VisibleForTesting + public static final Set<Integer> ALLOWED_NETWORK_TYPE = + Set.of( + TelephonyManager.NETWORK_TYPE_UNKNOWN, + TelephonyManager.NETWORK_TYPE_GPRS, + TelephonyManager.NETWORK_TYPE_EDGE, + TelephonyManager.NETWORK_TYPE_UMTS, + TelephonyManager.NETWORK_TYPE_CDMA, + TelephonyManager.NETWORK_TYPE_EVDO_0, + TelephonyManager.NETWORK_TYPE_EVDO_A, + TelephonyManager.NETWORK_TYPE_1xRTT, + TelephonyManager.NETWORK_TYPE_HSDPA, + TelephonyManager.NETWORK_TYPE_HSUPA, + TelephonyManager.NETWORK_TYPE_HSPA, + TelephonyManager.NETWORK_TYPE_EVDO_B, + TelephonyManager.NETWORK_TYPE_LTE, + TelephonyManager.NETWORK_TYPE_EHRPD, + TelephonyManager.NETWORK_TYPE_HSPAP, + TelephonyManager.NETWORK_TYPE_GSM, + TelephonyManager.NETWORK_TYPE_TD_SCDMA, + TelephonyManager.NETWORK_TYPE_IWLAN, + TelephonyManager.NETWORK_TYPE_NR + ); + @NonNull private final Context mContext; @NonNull - private Locale mLocale; - @NonNull private final TelephonyManager mTelephonyManager; - @NonNull - private final NetworkCapabilities mNetworkCapabilities; + @NonNull final ConnectivityManager mConnectivityManager; @NonNull private final LocationManager mLocationManager; @NonNull @@ -101,12 +117,8 @@ public class UserDataCollector { private UserDataCollector(Context context, UserDataDao userDataDao) { mContext = context; - mLocale = Locale.getDefault(); mTelephonyManager = mContext.getSystemService(TelephonyManager.class); - ConnectivityManager connectivityManager = mContext.getSystemService( - ConnectivityManager.class); - mNetworkCapabilities = connectivityManager.getNetworkCapabilities( - connectivityManager.getActiveNetwork()); + mConnectivityManager = mContext.getSystemService(ConnectivityManager.class); mLocationManager = mContext.getSystemService(LocationManager.class); mUserDataDao = userDataDao; mLastTimeMillisAppUsageCollected = 0L; @@ -117,13 +129,16 @@ public class UserDataCollector { /** Returns an instance of UserDataCollector. */ public static UserDataCollector getInstance(Context context) { - synchronized (UserDataCollector.class) { - if (sUserDataCollector == null) { - sUserDataCollector = new UserDataCollector( - context, UserDataDao.getInstance(context)); + if (sUserDataCollector == null) { + synchronized (UserDataCollector.class) { + if (sUserDataCollector == null) { + sUserDataCollector = new UserDataCollector( + context.getApplicationContext(), + UserDataDao.getInstance(context.getApplicationContext())); + } } - return sUserDataCollector; } + return sUserDataCollector; } /** @@ -151,9 +166,8 @@ public class UserDataCollector { if (!mInitialized) { return; } - userData.timeMillis = getTimeMillis(); - userData.utcOffset = getUtcOffset(); - userData.orientation = getOrientation(); + getUtcOffset(userData); + getOrientation(userData); } /** Update user data per periodic job servce. */ @@ -162,16 +176,12 @@ public class UserDataCollector { initializeUserData(userData); return; } - userData.availableStorageBytes = getAvailableStorageBytes(); - userData.batteryPercentage = getBatteryPercentage(); - userData.country = getCountry(); - userData.language = getLanguage(); - userData.carrier = getCarrier(); - userData.connectionType = getConnectionType(); - userData.networkMetered = isNetworkMetered(); - userData.connectionSpeedKbps = getConnectionSpeedKbps(); + getAvailableStorageBytes(userData); + getBatteryPercentage(userData); + getCarrier(userData); + getNetworkCapabilities(userData); + getDataNetworkType(userData); - getOSVersions(userData.osVersions); getInstalledApps(userData.appsInfo); getAppUsageStats(userData.appUsageHistory); getLastknownLocation(userData.locationHistory, userData.currentLocation); @@ -183,21 +193,13 @@ public class UserDataCollector { * for the schedule of {@link UserDataCollectionJobService} */ private void initializeUserData(@NonNull RawUserData userData) { - userData.timeMillis = getTimeMillis(); - userData.utcOffset = getUtcOffset(); - userData.orientation = getOrientation(); - userData.availableStorageBytes = getAvailableStorageBytes(); - userData.batteryPercentage = getBatteryPercentage(); - userData.country = getCountry(); - userData.language = getLanguage(); - userData.carrier = getCarrier(); - userData.connectionType = getConnectionType(); - userData.networkMetered = isNetworkMetered(); - userData.connectionSpeedKbps = getConnectionSpeedKbps(); - - getOSVersions(userData.osVersions); - - getDeviceMetrics(userData.deviceMetrics); + getUtcOffset(userData); + getOrientation(userData); + getAvailableStorageBytes(userData); + getBatteryPercentage(userData); + getCarrier(userData); + getNetworkCapabilities(userData); + getDataNetworkType(userData); getInstalledApps(userData.appsInfo); @@ -213,453 +215,141 @@ public class UserDataCollector { mInitialized = true; } - /** Collects current system clock on the device. */ - @VisibleForTesting - public long getTimeMillis() { - return System.currentTimeMillis(); - } - - /** Collects current device's time zone in +/- of minutes from UTC. */ - @VisibleForTesting - public int getUtcOffset() { - return TimeZone.getDefault().getOffset(System.currentTimeMillis()) / 60000; - } - - /** Collects the current device orientation. */ - @VisibleForTesting - public int getOrientation() { - return mContext.getResources().getConfiguration().orientation; - } - - /** Collects available bytes and converts to MB. */ - @VisibleForTesting - public long getAvailableStorageBytes() { - StatFs statFs = new StatFs(Environment.getDataDirectory().getPath()); - return statFs.getAvailableBytes(); - } - - /** Collects the battery percentage of the device. */ - @VisibleForTesting - public int getBatteryPercentage() { - IntentFilter ifilter = new IntentFilter(Intent.ACTION_BATTERY_CHANGED); - Intent batteryStatus = mContext.registerReceiver(null, ifilter); - - int level = batteryStatus.getIntExtra(BatteryManager.EXTRA_LEVEL, -1); - int scale = batteryStatus.getIntExtra(BatteryManager.EXTRA_SCALE, -1); - if (level >= 0 && scale > 0) { - return Math.round(level * 100.0f / (float) scale); - } - return 0; - } - - /** Collects current device's country information. */ + /** Collects current device's time zone in +/- offset of minutes from UTC. */ @VisibleForTesting - public Country getCountry() { - String countryCode = mLocale.getISO3Country(); - if (Strings.isNullOrEmpty(countryCode)) { - return Country.UNKNOWN; - } else { - Country country = Country.UNKNOWN; - try { - country = Country.valueOf(countryCode); - } catch (IllegalArgumentException iae) { - sLogger.e(TAG + ": Country code cannot match to a country.", iae); - return country; - } - return country; - } - } - - /** Collects current device's language information. */ - @VisibleForTesting - public Language getLanguage() { - String langCode = mLocale.getLanguage(); - if (Strings.isNullOrEmpty(langCode)) { - return Language.UNKNOWN; - } else { - Language language = Language.UNKNOWN; - try { - language = Language.valueOf(langCode.toUpperCase(Locale.US)); - } catch (IllegalArgumentException iae) { - sLogger.e(TAG + ": Language code cannot match to a language.", iae); - return language; - } - return language; - } - } - - /** Collects carrier info. */ - @VisibleForTesting - public Carrier getCarrier() { - // TODO: handle i18n later if the carrier's name is in non-English script. - switch (mTelephonyManager.getSimOperatorName().toUpperCase(Locale.US)) { - case "RELIANCE JIO": - return Carrier.RELIANCE_JIO; - case "VODAFONE": - return Carrier.VODAFONE; - case "T-MOBILE - US": - case "T-MOBILE": - return Carrier.T_MOBILE; - case "VERIZON WIRELESS": - return Carrier.VERIZON_WIRELESS; - case "AIRTEL": - return Carrier.AIRTEL; - case "ORANGE": - return Carrier.ORANGE; - case "NTT DOCOMO": - return Carrier.NTT_DOCOMO; - case "MOVISTAR": - return Carrier.MOVISTAR; - case "AT&T": - return Carrier.AT_T; - case "TELCEL": - return Carrier.TELCEL; - case "VIVO": - return Carrier.VIVO; - case "VI": - return Carrier.VI; - case "TIM": - return Carrier.TIM; - case "O2": - return Carrier.O2; - case "TELEKOM": - return Carrier.TELEKOM; - case "CLARO BR": - return Carrier.CLARO_BR; - case "SK TELECOM": - return Carrier.SK_TELECOM; - case "MTC": - return Carrier.MTC; - case "AU": - return Carrier.AU; - case "TELE2": - return Carrier.TELE2; - case "SFR": - return Carrier.SFR; - case "ETECSA": - return Carrier.ETECSA; - case "IR-MCI (HAMRAHE AVVAL)": - return Carrier.IR_MCI; - case "KT": - return Carrier.KT; - case "TELKOMSEL": - return Carrier.TELKOMSEL; - case "IRANCELL": - return Carrier.IRANCELL; - case "MEGAFON": - return Carrier.MEGAFON; - case "TELEFONICA": - return Carrier.TELEFONICA; - default: - return Carrier.UNKNOWN; + public void getUtcOffset(RawUserData userData) { + try { + userData.utcOffset = TimeZone.getDefault().getOffset(System.currentTimeMillis()) + / MILLISECONDS_IN_MINUTE; + } catch (Exception e) { + sLogger.w(TAG + ": Failed to collect timezone offset."); } } - /** - * Collects device OS version info. - * ODA only identifies three valid raw forms of OS releases - * and convert it to the three-version format. - * 13 -> 13.0.0 - * 8.1 -> 8.1.0 - * 4.1.2 as it is. - */ + /** Collects the current device orientation. */ @VisibleForTesting - public void getOSVersions(@NonNull OSVersion osVersions) { - String osRelease = Build.VERSION.RELEASE; - int major = 0; - int minor = 0; - int micro = 0; + public void getOrientation(RawUserData userData) { try { - major = Integer.parseInt(osRelease); - } catch (NumberFormatException nfe1) { - try { - String[] versions = osRelease.split("[.]"); - if (versions.length == 2) { - major = Integer.parseInt(versions[0]); - minor = Integer.parseInt(versions[1]); - } else if (versions.length == 3) { - major = Integer.parseInt(versions[0]); - minor = Integer.parseInt(versions[1]); - micro = Integer.parseInt(versions[2]); - } else { - // An irregular release like "UpsideDownCake" - sLogger.e(TAG + ": OS release string cannot be matched to a regular version.", - nfe1); - } - } catch (NumberFormatException nfe2) { - // An irrgular release like "QKQ1.200830.002" - sLogger.e(TAG + ": OS release string cannot be matched to a regular version.", - nfe2); - } - } finally { - osVersions.major = major; - osVersions.minor = minor; - osVersions.micro = micro; + userData.orientation = mContext.getResources().getConfiguration().orientation; + } catch (Exception e) { + sLogger.w(TAG + ": Failed to collect device orientation."); } } - /** Collects connection type. */ + /** Collects available bytes and converts to MB. */ @VisibleForTesting - public RawUserData.ConnectionType getConnectionType() { + public void getAvailableStorageBytes(RawUserData userData) { try { - // TODO(b/290256559): Fix permissions issue. - if (mNetworkCapabilities == null) { - return RawUserData.ConnectionType.UNKNOWN; - } else if (mNetworkCapabilities.hasTransport(NetworkCapabilities.TRANSPORT_CELLULAR)) { - switch (mTelephonyManager.getDataNetworkType()) { - case TelephonyManager.NETWORK_TYPE_1xRTT: - case TelephonyManager.NETWORK_TYPE_CDMA: - case TelephonyManager.NETWORK_TYPE_EDGE: - case TelephonyManager.NETWORK_TYPE_GPRS: - case TelephonyManager.NETWORK_TYPE_GSM: - case TelephonyManager.NETWORK_TYPE_IDEN: - return RawUserData.ConnectionType.CELLULAR_2G; - case TelephonyManager.NETWORK_TYPE_EHRPD: - case TelephonyManager.NETWORK_TYPE_EVDO_0: - case TelephonyManager.NETWORK_TYPE_EVDO_A: - case TelephonyManager.NETWORK_TYPE_EVDO_B: - case TelephonyManager.NETWORK_TYPE_HSDPA: - case TelephonyManager.NETWORK_TYPE_HSPA: - case TelephonyManager.NETWORK_TYPE_HSPAP: - case TelephonyManager.NETWORK_TYPE_HSUPA: - case TelephonyManager.NETWORK_TYPE_TD_SCDMA: - case TelephonyManager.NETWORK_TYPE_UMTS: - return RawUserData.ConnectionType.CELLULAR_3G; - case TelephonyManager.NETWORK_TYPE_LTE: - case TelephonyManager.NETWORK_TYPE_IWLAN: - return RawUserData.ConnectionType.CELLULAR_4G; - case TelephonyManager.NETWORK_TYPE_NR: - return RawUserData.ConnectionType.CELLULAR_5G; - default: - return RawUserData.ConnectionType.UNKNOWN; - } - } else if (mNetworkCapabilities.hasTransport(NetworkCapabilities.TRANSPORT_WIFI)) { - return RawUserData.ConnectionType.WIFI; - } else if (mNetworkCapabilities.hasTransport(NetworkCapabilities.TRANSPORT_ETHERNET)) { - return RawUserData.ConnectionType.ETHERNET; - } + StatFs statFs = new StatFs(Environment.getDataDirectory().getPath()); + userData.availableStorageBytes = statFs.getAvailableBytes(); } catch (Exception e) { - sLogger.e(TAG + ": getConnectionType() failed.", e); + sLogger.w(TAG + ": Failed to collect availableStorageBytes."); } - return RawUserData.ConnectionType.UNKNOWN; } - /** Collects metered status. */ + /** Collects the battery percentage of the device. */ @VisibleForTesting - public boolean isNetworkMetered() { - if (mNetworkCapabilities == null) { - return false; - } - int[] capabilities = mNetworkCapabilities.getCapabilities(); - for (int i = 0; i < capabilities.length; ++i) { - if (capabilities[i] == NetworkCapabilities.NET_CAPABILITY_NOT_METERED) { - return false; - } - } - return true; - } + public void getBatteryPercentage(RawUserData userData) { + try { + IntentFilter ifilter = new IntentFilter(Intent.ACTION_BATTERY_CHANGED); + Intent batteryStatus = mContext.registerReceiver(null, ifilter); - /** Collects connection speed in kbps */ - @VisibleForTesting - public long getConnectionSpeedKbps() { - if (mNetworkCapabilities == null) { - return 0; + int level = batteryStatus.getIntExtra(BatteryManager.EXTRA_LEVEL, -1); + int scale = batteryStatus.getIntExtra(BatteryManager.EXTRA_SCALE, -1); + if (level >= 0 && scale > 0) { + userData.batteryPercentage = Math.round(level * 100.0f / (float) scale); + } + } catch (Exception e) { + sLogger.w(TAG + ": Failed to collect batteryPercentage."); } - return mNetworkCapabilities.getLinkDownstreamBandwidthKbps(); } - /** Collects current device's static metrics. */ + /** Collects carrier info. */ @VisibleForTesting - public void getDeviceMetrics(DeviceMetrics deviceMetrics) { + public void getCarrier(RawUserData userData) { + // TODO (b/307158231): handle i18n later if the carrier's name is in non-English script. try { - // TODO(b/290256559): Fix permissions issue. - if (deviceMetrics == null) { - return; + switch (mTelephonyManager.getSimOperatorName().toUpperCase(Locale.US)) { + case "RELIANCE JIO" -> userData.carrier = Carrier.RELIANCE_JIO; + case "VODAFONE" -> userData.carrier = Carrier.VODAFONE; + case "T-MOBILE - US", "T-MOBILE" -> userData.carrier = Carrier.T_MOBILE; + case "VERIZON WIRELESS" -> userData.carrier = Carrier.VERIZON_WIRELESS; + case "AIRTEL" -> userData.carrier = Carrier.AIRTEL; + case "ORANGE" -> userData.carrier = Carrier.ORANGE; + case "NTT DOCOMO" -> userData.carrier = Carrier.NTT_DOCOMO; + case "MOVISTAR" -> userData.carrier = Carrier.MOVISTAR; + case "AT&T" -> userData.carrier = Carrier.AT_T; + case "TELCEL" -> userData.carrier = Carrier.TELCEL; + case "VIVO" -> userData.carrier = Carrier.VIVO; + case "VI" -> userData.carrier = Carrier.VI; + case "TIM" -> userData.carrier = Carrier.TIM; + case "O2" -> userData.carrier = Carrier.O2; + case "TELEKOM" -> userData.carrier = Carrier.TELEKOM; + case "CLARO BR" -> userData.carrier = Carrier.CLARO_BR; + case "SK TELECOM" -> userData.carrier = Carrier.SK_TELECOM; + case "MTC" -> userData.carrier = Carrier.MTC; + case "AU" -> userData.carrier = Carrier.AU; + case "TELE2" -> userData.carrier = Carrier.TELE2; + case "SFR" -> userData.carrier = Carrier.SFR; + case "ETECSA" -> userData.carrier = Carrier.ETECSA; + case "IR-MCI (HAMRAHE AVVAL)" -> userData.carrier = Carrier.IR_MCI; + case "KT" -> userData.carrier = Carrier.KT; + case "TELKOMSEL" -> userData.carrier = Carrier.TELKOMSEL; + case "IRANCELL" -> userData.carrier = Carrier.IRANCELL; + case "MEGAFON" -> userData.carrier = Carrier.MEGAFON; + case "TELEFONICA" -> userData.carrier = Carrier.TELEFONICA; + default -> userData.carrier = Carrier.UNKNOWN; } - deviceMetrics.make = getDeviceMake(); - deviceMetrics.model = getDeviceModel(); - deviceMetrics.screenHeight = mContext.getResources().getConfiguration().screenHeightDp; - deviceMetrics.screenWidth = mContext.getResources().getConfiguration().screenWidthDp; - DisplayMetrics displayMetrics = new DisplayMetrics(); - WindowManager wm = mContext.getSystemService(WindowManager.class); - wm.getDefaultDisplay().getMetrics(displayMetrics); - deviceMetrics.xdpi = displayMetrics.xdpi; - deviceMetrics.ydpi = displayMetrics.ydpi; - deviceMetrics.pxRatio = displayMetrics.density; } catch (Exception e) { - sLogger.e(TAG + ": getDeviceMetrics() failed.", e); + sLogger.w(TAG + "Failed to collect carrier info."); } } - /** - * Collects device make info. - */ + /** Collects network capabilities. */ @VisibleForTesting - public Make getDeviceMake() { - String manufacturer = Build.MANUFACTURER.toUpperCase(Locale.US); - Make make = Make.UNKNOWN; + public void getNetworkCapabilities(RawUserData userData) { try { - make = Make.valueOf(manufacturer); - } catch (IllegalArgumentException iae) { - // handle corner cases for irregularly formatted string. - make = getMakeFromSpecialString(manufacturer); - if (make == Make.UNKNOWN) { - sLogger.e(TAG + ": Manufacturer string cannot match to an available make type.", - iae); - } - return make; + userData.networkCapabilities = mConnectivityManager.getNetworkCapabilities( + mConnectivityManager.getActiveNetwork()); + } catch (Exception e) { + sLogger.w(TAG + ": Failed to collect networkCapabilities."); } - return make; } - /** Collects device model info */ @VisibleForTesting - public Model getDeviceModel() { - // Uppercase and replace whitespace/hyphen with underscore character - String deviceModel = Build.MODEL.toUpperCase(Locale.US).replace(' ', '_').replace('-', '_'); - Model model = Model.UNKNOWN; + public void getDataNetworkType(RawUserData userData) { try { - model = Model.valueOf(deviceModel); - } catch (IllegalArgumentException iae) { - // handle corner cases for irregularly formatted string. - model = getModelFromSpecialString(deviceModel); - if (model == Model.UNKNOWN) { - sLogger.e(TAG + ": Model string cannot match to an available make type.", iae); + int dataNetworkType = mTelephonyManager.getDataNetworkType(); + if (!ALLOWED_NETWORK_TYPE.contains(dataNetworkType)) { + userData.dataNetworkType = TelephonyManager.NETWORK_TYPE_UNKNOWN; + } else { + userData.dataNetworkType = dataNetworkType; } - return model; - } - return model; - } - - /** - * Helper function that handles irregularly formatted manufacturer string, - * which cannot be directly cast into enums. - */ - private Make getMakeFromSpecialString(String deviceMake) { - switch (deviceMake) { - case "TECNO MOBILE LIMITED": - return Make.TECNO; - case "INFINIX MOBILITY LIMITED": - return Make.INFINIX; - case "HMD GLOBAL": - return Make.HMD_GLOBAL; - case "LGE": - case "LG ELECTRONICS": - return Make.LG_ELECTRONICS; - case "SKYWORTHDIGITAL": - return Make.SKYWORTH; - case "ITEL": - case "ITEL MOBILE LIMITED": - return Make.ITEL_MOBILE; - case "KAON": - case "KAONMEDIA": - return Make.KAON_MEDIA; - case "ZEBRA TECHNOLOGIES": - return Make.ZEBRA_TECHNOLOGIES; - case "VOLVOCARS": - return Make.VOLVO_CARS; - case "SUMITOMOELECTRICINDUSTRIES": - return Make.SUMITOMO_ELECTRIC_INDUSTRIES; - case "STARHUB": - return Make.STAR_HUB; - case "GENERALMOBILE": - return Make.GENERAL_MOBILE; - case "KONNECTONE": - return Make.KONNECT_ONE; - case "CASIO COMPUTER CO., LTD.": - return Make.CASIO_COMPUTER; - case "SEI": - case "SEI ROBOTICS": - return Make.SEI_ROBOTICS; - case "EASTAEON": - return Make.EAST_AEON; - case "HIMEDIA": - return Make.HI_MEDIA; - case "HOT PEPPER INC": - return Make.HOT_PEPPER; - case "SKY": - case "SKY DEVICES": - case "SKYDEVICES": - return Make.SKY_DEVICES; - case "FOXX DEVELOPMENT INC.": - return Make.FOXX_DEVELOPMENT; - case "RELIANCE COMMUNICATIONS": - return Make.RELIANCE_COMMUNICATIONS; - case "EMPORIA TELECOM GMBH & CO. KG": - return Make.EMPORIA; - case "CIPHERLAB": - return Make.CIPHER_LAB; - case "ISAFEMOBILE": - return Make.ISAFE_MOBILE; - case "CLARO COLUMBIA": - return Make.CLARO; - case "MYPHONE": - return Make.MY_PHONE; - case "TAG HEUER": - return Make.TAG_HEUER; - case "XWIRELESSLLC": - return Make.XWIRELESS; - default: - return Make.UNKNOWN; - } - } - - /** - * Helper function that handles irregularly formatted model string, - * which cannot be directly cast into enums. - */ - private Model getModelFromSpecialString(String deviceModel) { - switch (deviceModel) { - case "AT&T_TV": - return Model.ATT_TV; - case "XVIEW+": - return Model.XVIEW_PLUS; - case "2201117TG": - return Model.REDMI_2201117TG; - case "2201117TY": - return Model.REDMI_2201117TY; - case "B860H_V5.0": - return Model.B860H_V5; - case "MOTO_G(20)": - return Model.MOTO_G20; - case "2020/2021_UHD_ANDROID_TV": - return Model.PHILIPS_2020_2021_UHD_ANDROID_TV; - case "2109119DG": - return Model.XIAOMI_2109119DG; - case "MOTO_G(9)_PLAY": - return Model.MOTO_G9_PLAY; - case "MOTO_E(7)": - return Model.MOTO_E7; - case "2021/22_PHILIPS_UHD_ANDROID_TV": - return Model.PHILIPS_2021_22_UHD_ANDROID_TV; - case "MOTO_G(30)": - return Model.MOTO_G30; - case "MOTO_G_POWER_(2021)": - return Model.MOTO_G_POWER_2021; - case "MOTO_G(7)_POWER": - return Model.MOTO_G7_POWER; - case "OTT_XVIEW+_AV1": - return Model.OTT_XVIEW_PLUS_AV1; - default: - return Model.UNKNOWN; + } catch (Exception e) { + sLogger.w(TAG + ": Failed to collect data network type."); } } /** Get app install and uninstall record. */ @VisibleForTesting public void getInstalledApps(@NonNull List<AppInfo> appsInfo) { - appsInfo.clear(); - PackageManager packageManager = mContext.getPackageManager(); - for (ApplicationInfo appInfo : - packageManager.getInstalledApplications(MATCH_UNINSTALLED_PACKAGES)) { - AppInfo app = new AppInfo(); - app.packageName = appInfo.packageName; - if ((appInfo.flags & ApplicationInfo.FLAG_INSTALLED) != 0) { - app.installed = true; - } else { - app.installed = false; + try { + appsInfo.clear(); + PackageManager packageManager = mContext.getPackageManager(); + for (ApplicationInfo appInfo : + packageManager.getInstalledApplications(MATCH_UNINSTALLED_PACKAGES)) { + AppInfo app = new AppInfo(); + app.packageName = appInfo.packageName; + if ((appInfo.flags & ApplicationInfo.FLAG_INSTALLED) != 0) { + app.installed = true; + } else { + app.installed = false; + } + appsInfo.add(app); } - appsInfo.add(app); + sLogger.d(TAG + ": Finished collecting AppInfo."); + } catch (Exception e) { + sLogger.w(TAG + ": Failed to collect installed AppInfo."); } } @@ -670,48 +360,53 @@ public class UserDataCollector { */ @VisibleForTesting public void getAppUsageStats(HashMap<String, Long> appUsageHistory) { - Calendar cal = Calendar.getInstance(); - // Obtain the 24-hour query range between [yesterday midnight] and [today midnight]. - cal.set(Calendar.MILLISECOND, 0); - cal.set(Calendar.SECOND, 0); - cal.set(Calendar.MINUTE, 0); - cal.set(Calendar.HOUR_OF_DAY, 0); - final long endTimeMillis = cal.getTimeInMillis(); - - // Skip the current collection cycle. - if (endTimeMillis == mLastTimeMillisAppUsageCollected) { - return; - } - - // Collect yesterday's app usage stats. - cal.add(Calendar.DATE, -1); - final long startTimeMillis = cal.getTimeInMillis(); - UsageStatsManager usageStatsManager = mContext.getSystemService(UsageStatsManager.class); - final List<UsageStats> statsList = usageStatsManager.queryUsageStats( - UsageStatsManager.INTERVAL_BEST, startTimeMillis, endTimeMillis); + try { + Calendar cal = Calendar.getInstance(); + // Obtain the 24-hour query range between [yesterday midnight] and [today midnight]. + cal.set(Calendar.MILLISECOND, 0); + cal.set(Calendar.SECOND, 0); + cal.set(Calendar.MINUTE, 0); + cal.set(Calendar.HOUR_OF_DAY, 0); + final long endTimeMillis = cal.getTimeInMillis(); + + // Skip the current collection cycle. + if (endTimeMillis == mLastTimeMillisAppUsageCollected) { + return; + } - List<AppUsageEntry> appUsageEntries = new ArrayList<>(); - for (UsageStats stats : statsList) { - if (stats.getTotalTimeVisible() == 0) { - continue; + // Collect yesterday's app usage stats. + cal.add(Calendar.DATE, -1); + final long startTimeMillis = cal.getTimeInMillis(); + UsageStatsManager usageStatsManager = + mContext.getSystemService(UsageStatsManager.class); + final List<UsageStats> statsList = usageStatsManager.queryUsageStats( + UsageStatsManager.INTERVAL_BEST, startTimeMillis, endTimeMillis); + + List<AppUsageEntry> appUsageEntries = new ArrayList<>(); + for (UsageStats stats : statsList) { + if (stats.getTotalTimeVisible() == 0) { + continue; + } + appUsageEntries.add(new AppUsageEntry(stats.getPackageName(), + startTimeMillis, endTimeMillis, stats.getTotalTimeVisible())); } - appUsageEntries.add(new AppUsageEntry(stats.getPackageName(), - startTimeMillis, endTimeMillis, stats.getTotalTimeVisible())); - } - // TODO(267678607): refactor the business logic when no stats is available. - if (appUsageEntries.size() == 0) { - return; - } + // TODO(267678607): refactor the business logic when no stats is available. + if (appUsageEntries.size() == 0) { + return; + } - // Update database. - if (!mUserDataDao.batchInsertAppUsageStatsData(appUsageEntries)) { - return; + // Update database. + if (!mUserDataDao.batchInsertAppUsageStatsData(appUsageEntries)) { + return; + } + // Update in-memory histogram. + updateAppUsageHistogram(appUsageHistory, appUsageEntries); + // Update metadata if all steps succeed as a transaction. + mLastTimeMillisAppUsageCollected = endTimeMillis; + } catch (Exception e) { + sLogger.w(TAG + ": Failed to collect app usage."); } - // Update in-memory histogram. - updateAppUsageHistogram(appUsageHistory, appUsageEntries); - // Update metadata if all steps succeed as a transaction. - mLastTimeMillisAppUsageCollected = endTimeMillis; } /** @@ -801,7 +496,7 @@ public class UserDataCollector { * @return true if location info collection is successful, false otherwise. */ private boolean setLocationInfo(Location location, LocationInfo locationInfo) { - long timeMillis = getTimeMillis() - location.getElapsedRealtimeAgeMillis(); + long timeMillis = System.currentTimeMillis() - location.getElapsedRealtimeAgeMillis(); double truncatedLatitude = Math.round(location.getLatitude() * 10000.0) / 10000.0; double truncatedLongitude = Math.round(location.getLongitude() * 10000.0) / 10000.0; LocationInfo.LocationProvider locationProvider = LocationProvider.UNKNOWN; @@ -869,30 +564,15 @@ public class UserDataCollector { } /** - * Setter to update locale for testing purpose. - */ - @VisibleForTesting - public void setLocale(Locale locale) { - mLocale = locale; - } - - /** * Util to reset all fields in [UserData] to default for testing purpose */ public void clearUserData(@NonNull RawUserData userData) { - userData.timeMillis = 0; userData.utcOffset = 0; userData.orientation = Configuration.ORIENTATION_PORTRAIT; userData.availableStorageBytes = 0; userData.batteryPercentage = 0; - userData.country = Country.UNKNOWN; - userData.language = Language.UNKNOWN; userData.carrier = Carrier.UNKNOWN; - userData.osVersions = new OSVersion(); - userData.connectionType = RawUserData.ConnectionType.UNKNOWN; - userData.networkMetered = false; - userData.connectionSpeedKbps = 0; - userData.deviceMetrics = new DeviceMetrics(); + userData.networkCapabilities = null; userData.appsInfo.clear(); userData.appUsageHistory.clear(); userData.locationHistory.clear(); diff --git a/src/com/android/ondevicepersonalization/services/data/user/UserDataDao.java b/src/com/android/ondevicepersonalization/services/data/user/UserDataDao.java index 5ef5f9f4..361e6880 100644 --- a/src/com/android/ondevicepersonalization/services/data/user/UserDataDao.java +++ b/src/com/android/ondevicepersonalization/services/data/user/UserDataDao.java @@ -35,7 +35,7 @@ public class UserDataDao { private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); private static final String TAG = "UserDataDao"; - private static UserDataDao sUserDataDao; + private static volatile UserDataDao sUserDataDao; private final OnDevicePersonalizationDbHelper mDbHelper; public static final int TTL_IN_MEMORY_DAYS = 30; @@ -50,13 +50,16 @@ public class UserDataDao { * @return Instance of UserDataDao for accessing the requested package's table. */ public static UserDataDao getInstance(Context context) { - synchronized (UserDataDao.class) { - if (sUserDataDao == null) { - sUserDataDao = new UserDataDao( - OnDevicePersonalizationDbHelper.getInstance(context)); + if (sUserDataDao == null) { + synchronized (UserDataDao.class) { + if (sUserDataDao == null) { + sUserDataDao = new UserDataDao( + OnDevicePersonalizationDbHelper.getInstance( + context.getApplicationContext())); + } } - return sUserDataDao; } + return sUserDataDao; } /** diff --git a/src/com/android/ondevicepersonalization/services/data/user/UserPrivacyStatus.java b/src/com/android/ondevicepersonalization/services/data/user/UserPrivacyStatus.java new file mode 100644 index 00000000..2dc5451c --- /dev/null +++ b/src/com/android/ondevicepersonalization/services/data/user/UserPrivacyStatus.java @@ -0,0 +1,57 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.ondevicepersonalization.services.data.user; + +import com.android.ondevicepersonalization.services.PhFlags; + +/** + * A singleton class that stores all user privacy statuses in memory. + */ +public final class UserPrivacyStatus { + public static UserPrivacyStatus sUserPrivacyStatus = null; + private boolean mPersonalizationStatusEnabled; + + private UserPrivacyStatus() { + // Assume the more privacy-safe option until updated. + mPersonalizationStatusEnabled = false; + } + + /** Returns an instance of UserPrivacyStatus. */ + public static UserPrivacyStatus getInstance() { + synchronized (UserPrivacyStatus.class) { + if (sUserPrivacyStatus == null) { + sUserPrivacyStatus = new UserPrivacyStatus(); + } + return sUserPrivacyStatus; + } + } + + public void setPersonalizationStatusEnabled(boolean personalizationStatusEnabled) { + PhFlags phFlags = PhFlags.getInstance(); + if (!phFlags.isPersonalizationStatusOverrideEnabled()) { + mPersonalizationStatusEnabled = personalizationStatusEnabled; + } + } + + public boolean isPersonalizationStatusEnabled() { + PhFlags phFlags = PhFlags.getInstance(); + if (phFlags.isPersonalizationStatusOverrideEnabled()) { + return phFlags.getPersonalizationStatusOverrideValue(); + } + return mPersonalizationStatusEnabled; + } +} diff --git a/src/com/android/ondevicepersonalization/services/data/vendor/OnDevicePersonalizationLocalDataDao.java b/src/com/android/ondevicepersonalization/services/data/vendor/OnDevicePersonalizationLocalDataDao.java index f3f68ae4..9971f1f0 100644 --- a/src/com/android/ondevicepersonalization/services/data/vendor/OnDevicePersonalizationLocalDataDao.java +++ b/src/com/android/ondevicepersonalization/services/data/vendor/OnDevicePersonalizationLocalDataDao.java @@ -29,10 +29,10 @@ import com.android.internal.annotations.VisibleForTesting; import com.android.ondevicepersonalization.internal.util.LoggerFactory; import com.android.ondevicepersonalization.services.data.OnDevicePersonalizationDbHelper; -import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; /** * Dao used to manage access to local data tables @@ -43,7 +43,7 @@ public class OnDevicePersonalizationLocalDataDao { private static final String LOCAL_DATA_TABLE_NAME_PREFIX = "localdata_"; private static final Map<String, OnDevicePersonalizationLocalDataDao> sLocalDataDaos = - new HashMap<>(); + new ConcurrentHashMap<>(); private final OnDevicePersonalizationDbHelper mDbHelper; private final String mOwner; private final String mCertDigest; @@ -67,20 +67,23 @@ public class OnDevicePersonalizationLocalDataDao { * package's table */ public static OnDevicePersonalizationLocalDataDao getInstance(Context context, String owner, - String certDigest) { - synchronized (OnDevicePersonalizationLocalDataDao.class) { - // TODO: Validate the owner and certDigest - String tableName = getTableName(owner, certDigest); - OnDevicePersonalizationLocalDataDao instance = sLocalDataDaos.get(tableName); - if (instance == null) { - OnDevicePersonalizationDbHelper dbHelper = - OnDevicePersonalizationDbHelper.getInstance(context); - instance = new OnDevicePersonalizationLocalDataDao( - dbHelper, owner, certDigest); - sLocalDataDaos.put(tableName, instance); + String certDigest) { + // TODO: Validate the owner and certDigest + String tableName = getTableName(owner, certDigest); + OnDevicePersonalizationLocalDataDao instance = sLocalDataDaos.get(tableName); + if (instance == null) { + synchronized (sLocalDataDaos) { + instance = sLocalDataDaos.get(tableName); + if (instance == null) { + OnDevicePersonalizationDbHelper dbHelper = + OnDevicePersonalizationDbHelper.getInstance(context); + instance = new OnDevicePersonalizationLocalDataDao( + dbHelper, owner, certDigest); + sLocalDataDaos.put(tableName, instance); + } } - return instance; } + return instance; } /** diff --git a/src/com/android/ondevicepersonalization/services/data/vendor/OnDevicePersonalizationVendorDataDao.java b/src/com/android/ondevicepersonalization/services/data/vendor/OnDevicePersonalizationVendorDataDao.java index e8586603..ded1c12b 100644 --- a/src/com/android/ondevicepersonalization/services/data/vendor/OnDevicePersonalizationVendorDataDao.java +++ b/src/com/android/ondevicepersonalization/services/data/vendor/OnDevicePersonalizationVendorDataDao.java @@ -23,18 +23,17 @@ import android.database.SQLException; import android.database.sqlite.SQLiteDatabase; import android.database.sqlite.SQLiteException; - import com.android.internal.annotations.VisibleForTesting; import com.android.ondevicepersonalization.internal.util.LoggerFactory; import com.android.ondevicepersonalization.services.data.OnDevicePersonalizationDbHelper; import java.util.AbstractMap; import java.util.ArrayList; -import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; /** @@ -46,7 +45,7 @@ public class OnDevicePersonalizationVendorDataDao { private static final String VENDOR_DATA_TABLE_NAME_PREFIX = "vendordata_"; private static final Map<String, OnDevicePersonalizationVendorDataDao> sVendorDataDaos = - new HashMap<>(); + new ConcurrentHashMap<>(); private final OnDevicePersonalizationDbHelper mDbHelper; private final String mOwner; private final String mCertDigest; @@ -71,19 +70,22 @@ public class OnDevicePersonalizationVendorDataDao { */ public static OnDevicePersonalizationVendorDataDao getInstance(Context context, String owner, String certDigest) { - synchronized (OnDevicePersonalizationVendorDataDao.class) { - // TODO: Validate the owner and certDigest - String tableName = getTableName(owner, certDigest); - OnDevicePersonalizationVendorDataDao instance = sVendorDataDaos.get(tableName); - if (instance == null) { - OnDevicePersonalizationDbHelper dbHelper = - OnDevicePersonalizationDbHelper.getInstance(context); - instance = new OnDevicePersonalizationVendorDataDao( - dbHelper, owner, certDigest); - sVendorDataDaos.put(tableName, instance); + // TODO: Validate the owner and certDigest + String tableName = getTableName(owner, certDigest); + OnDevicePersonalizationVendorDataDao instance = sVendorDataDaos.get(tableName); + if (instance == null) { + synchronized (sVendorDataDaos) { + instance = sVendorDataDaos.get(tableName); + if (instance == null) { + OnDevicePersonalizationDbHelper dbHelper = + OnDevicePersonalizationDbHelper.getInstance(context); + instance = new OnDevicePersonalizationVendorDataDao( + dbHelper, owner, certDigest); + sVendorDataDaos.put(tableName, instance); + } } - return instance; } + return instance; } /** diff --git a/src/com/android/ondevicepersonalization/services/display/DisplayHelper.java b/src/com/android/ondevicepersonalization/services/display/DisplayHelper.java index b0aa954d..9d757bcf 100644 --- a/src/com/android/ondevicepersonalization/services/display/DisplayHelper.java +++ b/src/com/android/ondevicepersonalization/services/display/DisplayHelper.java @@ -16,7 +16,7 @@ package com.android.ondevicepersonalization.services.display; -import android.adservices.ondevicepersonalization.RenderOutput; +import android.adservices.ondevicepersonalization.RenderOutputParcel; import android.adservices.ondevicepersonalization.RequestLogRecord; import android.annotation.NonNull; import android.content.Context; @@ -58,9 +58,9 @@ public class DisplayHelper { mContext = context; } - /** Generates an HTML string from the template data in RenderOutput. */ + /** Generates an HTML string from the template data in RenderOutputParcel. */ @NonNull public String generateHtml( - @NonNull RenderOutput renderContentResult, + @NonNull RenderOutputParcel renderContentResult, @NonNull String servicePackageName) { // If htmlContent is provided, do not render the template. String htmlContent = renderContentResult.getContent(); diff --git a/src/com/android/ondevicepersonalization/services/display/OdpWebViewClient.java b/src/com/android/ondevicepersonalization/services/display/OdpWebViewClient.java index 9b00ac78..d74f0c71 100644 --- a/src/com/android/ondevicepersonalization/services/display/OdpWebViewClient.java +++ b/src/com/android/ondevicepersonalization/services/display/OdpWebViewClient.java @@ -17,11 +17,11 @@ package com.android.ondevicepersonalization.services.display; import android.adservices.ondevicepersonalization.Constants; +import android.adservices.ondevicepersonalization.EventInputParcel; import android.adservices.ondevicepersonalization.EventLogRecord; +import android.adservices.ondevicepersonalization.EventOutputParcel; import android.adservices.ondevicepersonalization.RequestLogRecord; import android.adservices.ondevicepersonalization.UserData; -import android.adservices.ondevicepersonalization.WebViewEventInput; -import android.adservices.ondevicepersonalization.WebViewEventOutput; import android.annotation.NonNull; import android.content.Context; import android.content.Intent; @@ -34,6 +34,8 @@ import android.webkit.WebViewClient; import com.android.internal.annotations.VisibleForTesting; import com.android.ondevicepersonalization.internal.util.LoggerFactory; +import com.android.ondevicepersonalization.services.Flags; +import com.android.ondevicepersonalization.services.FlagsFactory; import com.android.ondevicepersonalization.services.OnDevicePersonalizationExecutors; import com.android.ondevicepersonalization.services.data.DataAccessServiceImpl; import com.android.ondevicepersonalization.services.data.events.Event; @@ -43,18 +45,25 @@ import com.android.ondevicepersonalization.services.data.events.EventsDao; import com.android.ondevicepersonalization.services.manifest.AppManifestConfigHelper; import com.android.ondevicepersonalization.services.policyengine.UserDataAccessor; import com.android.ondevicepersonalization.services.process.IsolatedServiceInfo; -import com.android.ondevicepersonalization.services.process.ProcessUtils; +import com.android.ondevicepersonalization.services.process.ProcessRunner; +import com.android.ondevicepersonalization.services.statsd.ApiCallStats; +import com.android.ondevicepersonalization.services.statsd.OdpStatsdLogger; +import com.android.ondevicepersonalization.services.util.Clock; +import com.android.ondevicepersonalization.services.util.MonotonicClock; import com.android.ondevicepersonalization.services.util.OnDevicePersonalizationFlatbufferUtils; +import com.android.ondevicepersonalization.services.util.StatsUtils; import com.google.common.util.concurrent.FluentFuture; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ListeningExecutorService; +import com.google.common.util.concurrent.ListeningScheduledExecutorService; import java.io.ByteArrayInputStream; import java.io.InputStream; import java.net.HttpURLConnection; import java.util.Collections; -import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; class OdpWebViewClient extends WebViewClient { private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); @@ -63,7 +72,7 @@ class OdpWebViewClient extends WebViewClient { @VisibleForTesting static class Injector { - Executor getExecutor() { + ListeningExecutorService getExecutor() { return OnDevicePersonalizationExecutors.getBackgroundExecutor(); } @@ -76,8 +85,20 @@ class OdpWebViewClient extends WebViewClient { } } - long getTimeMillis() { - return System.currentTimeMillis(); + Clock getClock() { + return MonotonicClock.getInstance(); + } + + Flags getFlags() { + return FlagsFactory.getFlags(); + } + + ListeningScheduledExecutorService getScheduledExecutor() { + return OnDevicePersonalizationExecutors.getScheduledExecutor(); + } + + ProcessRunner getProcessRunner() { + return ProcessRunner.getInstance(); } } @@ -160,16 +181,18 @@ class OdpWebViewClient extends WebViewClient { return true; } - private ListenableFuture<WebViewEventOutput> executeEventHandler( - IsolatedServiceInfo isolatedServiceInfo, EventUrlPayload payload) { + private ListenableFuture<EventOutputParcel> executeEventHandler( + IsolatedServiceInfo isolatedServiceInfo, + EventUrlPayload payload) { try { sLogger.d(TAG + ": executeEventHandler() called"); Bundle serviceParams = new Bundle(); DataAccessServiceImpl binder = new DataAccessServiceImpl( - mServicePackageName, mContext, true); + mServicePackageName, mContext, /* includeLocalData */ true, + /* includeEventData */ true); serviceParams.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, binder); // TODO(b/259950177): Add Query row to input. - WebViewEventInput input = new WebViewEventInput.Builder() + EventInputParcel input = new EventInputParcel.Builder() .setParameters(payload.getEventParams()) .setRequestLogRecord(mLogRecord) .build(); @@ -178,16 +201,31 @@ class OdpWebViewClient extends WebViewClient { UserData userData = userDataAccessor.getUserData(); serviceParams.putParcelable(Constants.EXTRA_USER_DATA, userData); return FluentFuture.from( - ProcessUtils.runIsolatedService( + mInjector.getProcessRunner().runIsolatedService( isolatedServiceInfo, AppManifestConfigHelper.getServiceNameFromOdpSettings( mContext, mServicePackageName), Constants.OP_WEB_VIEW_EVENT, serviceParams)) .transform( - result -> result.getParcelable( - Constants.EXTRA_RESULT, WebViewEventOutput.class), - mInjector.getExecutor()); + result -> { + writeServiceRequestMetrics( + result, isolatedServiceInfo.getStartTimeMillis(), + Constants.STATUS_SUCCESS); + return result.getParcelable( + Constants.EXTRA_RESULT, EventOutputParcel.class); + }, + mInjector.getExecutor()) + .catchingAsync( + Exception.class, + e -> { + writeServiceRequestMetrics( + null, isolatedServiceInfo.getStartTimeMillis(), + Constants.STATUS_INTERNAL_ERROR); + return Futures.immediateFailedFuture(e); + }, + mInjector.getExecutor() + ); } catch (Exception e) { sLogger.e(TAG + ": executeEventHandler() failed", e); return Futures.immediateFailedFuture(e); @@ -195,24 +233,25 @@ class OdpWebViewClient extends WebViewClient { } - ListenableFuture<WebViewEventOutput> getEventOutput(EventUrlPayload payload) { + ListenableFuture<EventOutputParcel> getEventOutputParcel( + ListenableFuture<IsolatedServiceInfo> loadFuture, + EventUrlPayload payload) { try { - sLogger.d(TAG + ": getEventOutput(): Starting isolated process."); - return FluentFuture.from(ProcessUtils.loadIsolatedService( - TASK_NAME, mServicePackageName, mContext)) + sLogger.d(TAG + ": getEventOutputParcel(): Starting isolated process."); + return FluentFuture.from(loadFuture) .transformAsync( result -> executeEventHandler(result, payload), mInjector.getExecutor()); } catch (Exception e) { - sLogger.e(TAG + ": getEventOutput() failed", e); + sLogger.e(TAG + ": getEventOutputParcel() failed", e); return Futures.immediateFailedFuture(e); } } - private ListenableFuture<Void> writeEvent(WebViewEventOutput result) { + private ListenableFuture<Void> writeEvent(EventOutputParcel result) { try { - sLogger.d(TAG + ": writeEvent() called. EventOutput: " + result.toString()); + sLogger.d(TAG + ": writeEvent() called. EventOutputParcel: " + result.toString()); if (result == null || result.getEventLogRecord() == null) { return Futures.immediateFuture(null); } @@ -232,7 +271,7 @@ class OdpWebViewClient extends WebViewClient { .setType(eventData.getType()) .setQueryId(mQueryId) .setServicePackageName(mServicePackageName) - .setTimeMillis(mInjector.getTimeMillis()) + .setTimeMillis(mInjector.getClock().currentTimeMillis()) .setRowIndex(eventData.getRowIndex()) .setEventData(data) .build(); @@ -250,13 +289,38 @@ class OdpWebViewClient extends WebViewClient { try { sLogger.d(TAG + ": handleEvent() called"); - var unused = FluentFuture.from(getEventOutput(eventUrlPayload)) + ListenableFuture<IsolatedServiceInfo> loadFuture = + mInjector.getProcessRunner().loadIsolatedService( + TASK_NAME, mServicePackageName); + + var doneFuture = FluentFuture.from(getEventOutputParcel(loadFuture, eventUrlPayload)) .transformAsync( result -> writeEvent(result), - mInjector.getExecutor()); + mInjector.getExecutor()) + .withTimeout( + mInjector.getFlags().getIsolatedServiceDeadlineSeconds(), + TimeUnit.SECONDS, + mInjector.getScheduledExecutor() + ); + var unused = Futures.whenAllComplete(loadFuture, doneFuture) + .callAsync(() -> mInjector.getProcessRunner().unloadIsolatedService( + loadFuture.get()), + mInjector.getExecutor()); } catch (Exception e) { sLogger.e(TAG + ": Failed to handle Event", e); } } + + private void writeServiceRequestMetrics(Bundle result, long startTimeMillis, int responseCode) { + int latencyMillis = (int) (mInjector.getClock().elapsedRealtime() - startTimeMillis); + int overheadLatencyMillis = + (int) StatsUtils.getOverheadLatencyMillis(latencyMillis, result); + ApiCallStats callStats = new ApiCallStats.Builder(ApiCallStats.API_SERVICE_ON_EVENT) + .setLatencyMillis(latencyMillis) + .setOverheadLatencyMillis(overheadLatencyMillis) + .setResponseCode(responseCode) + .build(); + OdpStatsdLogger.getInstance().logApiCallStats(callStats); + } } diff --git a/src/com/android/ondevicepersonalization/services/download/OnDevicePersonalizationDataProcessingAsyncCallable.java b/src/com/android/ondevicepersonalization/services/download/OnDevicePersonalizationDataProcessingAsyncCallable.java index a0982818..b05527e4 100644 --- a/src/com/android/ondevicepersonalization/services/download/OnDevicePersonalizationDataProcessingAsyncCallable.java +++ b/src/com/android/ondevicepersonalization/services/download/OnDevicePersonalizationDataProcessingAsyncCallable.java @@ -17,8 +17,8 @@ package com.android.ondevicepersonalization.services.download; import android.adservices.ondevicepersonalization.Constants; +import android.adservices.ondevicepersonalization.DownloadCompletedOutputParcel; import android.adservices.ondevicepersonalization.DownloadInputParcel; -import android.adservices.ondevicepersonalization.DownloadOutput; import android.adservices.ondevicepersonalization.UserData; import android.content.Context; import android.content.pm.PackageManager; @@ -36,11 +36,17 @@ import com.android.ondevicepersonalization.services.data.vendor.OnDevicePersonal import com.android.ondevicepersonalization.services.data.vendor.VendorData; import com.android.ondevicepersonalization.services.download.mdd.MobileDataDownloadFactory; import com.android.ondevicepersonalization.services.download.mdd.OnDevicePersonalizationFileGroupPopulator; +import com.android.ondevicepersonalization.services.federatedcompute.FederatedComputeServiceImpl; import com.android.ondevicepersonalization.services.manifest.AppManifestConfigHelper; import com.android.ondevicepersonalization.services.policyengine.UserDataAccessor; import com.android.ondevicepersonalization.services.process.IsolatedServiceInfo; -import com.android.ondevicepersonalization.services.process.ProcessUtils; +import com.android.ondevicepersonalization.services.process.ProcessRunner; +import com.android.ondevicepersonalization.services.statsd.ApiCallStats; +import com.android.ondevicepersonalization.services.statsd.OdpStatsdLogger; +import com.android.ondevicepersonalization.services.util.Clock; +import com.android.ondevicepersonalization.services.util.MonotonicClock; import com.android.ondevicepersonalization.services.util.PackageUtils; +import com.android.ondevicepersonalization.services.util.StatsUtils; import com.google.android.libraries.mobiledatadownload.GetFileGroupRequest; import com.google.android.libraries.mobiledatadownload.MobileDataDownload; @@ -57,6 +63,7 @@ import com.google.mobiledatadownload.ClientConfigProto.ClientFileGroup; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -74,10 +81,28 @@ public class OnDevicePersonalizationDataProcessingAsyncCallable implements Async private final Context mContext; private OnDevicePersonalizationVendorDataDao mDao; + static class Injector { + Clock getClock() { + return MonotonicClock.getInstance(); + } + + ProcessRunner getProcessRunner() { + return ProcessRunner.getInstance(); + } + } + + private final Injector mInjector; + public OnDevicePersonalizationDataProcessingAsyncCallable(String packageName, Context context) { + this(packageName, context, new Injector()); + } + + public OnDevicePersonalizationDataProcessingAsyncCallable(String packageName, + Context context, Injector injector) { mPackageName = packageName; mContext = context; + mInjector = injector; } private static boolean validateSyncToken(long syncToken) { @@ -177,13 +202,12 @@ public class OnDevicePersonalizationDataProcessingAsyncCallable implements Async Map<String, VendorData> finalVendorDataMap = vendorDataMap; long finalSyncToken = syncToken; try { - return FluentFuture.from(ProcessUtils.loadIsolatedService( - TASK_NAME, mPackageName, mContext)) + ListenableFuture<IsolatedServiceInfo> loadFuture = + mInjector.getProcessRunner().loadIsolatedService( + TASK_NAME, mPackageName); + var resultFuture = FluentFuture.from(loadFuture) .transformAsync( - result -> - executeDownloadHandler( - result, - finalVendorDataMap), + result -> executeDownloadHandler(result, finalVendorDataMap), OnDevicePersonalizationExecutors.getBackgroundExecutor()) .transform(pluginResult -> filterAndStoreData(pluginResult, finalSyncToken, finalVendorDataMap), @@ -195,6 +219,13 @@ public class OnDevicePersonalizationDataProcessingAsyncCallable implements Async return null; }, OnDevicePersonalizationExecutors.getBackgroundExecutor()); + + var unused = Futures.whenAllComplete(loadFuture, resultFuture) + .callAsync(() -> mInjector.getProcessRunner().unloadIsolatedService( + loadFuture.get()), + OnDevicePersonalizationExecutors.getBackgroundExecutor()); + + return resultFuture; } catch (Exception e) { sLogger.e(TAG + ": Could not run isolated service.", e); return Futures.immediateFuture(null); @@ -205,8 +236,8 @@ public class OnDevicePersonalizationDataProcessingAsyncCallable implements Async Map<String, VendorData> vendorDataMap) { sLogger.d(TAG + ": Plugin filter code completed successfully"); List<VendorData> filteredList = new ArrayList<>(); - DownloadOutput downloadResult = pluginResult.getParcelable( - Constants.EXTRA_RESULT, DownloadOutput.class); + DownloadCompletedOutputParcel downloadResult = pluginResult.getParcelable( + Constants.EXTRA_RESULT, DownloadCompletedOutputParcel.class); List<String> retainedKeys = downloadResult.getRetainedKeys(); if (retainedKeys == null) { // TODO(b/270710021): Determine how to correctly handle null retainedKeys. @@ -227,8 +258,12 @@ public class OnDevicePersonalizationDataProcessingAsyncCallable implements Async Map<String, VendorData> vendorDataMap) { Bundle pluginParams = new Bundle(); DataAccessServiceImpl binder = new DataAccessServiceImpl( - mPackageName, mContext, true); + mPackageName, mContext, /* includeLocalData */ true, + /* includeEventData */ true); pluginParams.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, binder); + FederatedComputeServiceImpl fcpBinder = new FederatedComputeServiceImpl( + mPackageName, mContext); + pluginParams.putBinder(Constants.EXTRA_FEDERATED_COMPUTE_SERVICE_BINDER, fcpBinder); List<String> keys = new ArrayList<>(); List<byte[]> values = new ArrayList<>(); @@ -252,11 +287,31 @@ public class OnDevicePersonalizationDataProcessingAsyncCallable implements Async UserDataAccessor userDataAccessor = new UserDataAccessor(); UserData userData = userDataAccessor.getUserData(); pluginParams.putParcelable(Constants.EXTRA_USER_DATA, userData); - return ProcessUtils.runIsolatedService( + ListenableFuture<Bundle> result = mInjector.getProcessRunner().runIsolatedService( isolatedServiceInfo, AppManifestConfigHelper.getServiceNameFromOdpSettings(mContext, mPackageName), Constants.OP_DOWNLOAD, pluginParams); + return FluentFuture.from(result) + .transform( + val -> { + writeServiceRequestMetrics( + val, isolatedServiceInfo.getStartTimeMillis(), + Constants.STATUS_SUCCESS); + return val; + }, + OnDevicePersonalizationExecutors.getBackgroundExecutor() + ) + .catchingAsync( + Exception.class, + e -> { + writeServiceRequestMetrics( + null, isolatedServiceInfo.getStartTimeMillis(), + Constants.STATUS_INTERNAL_ERROR); + return Futures.immediateFailedFuture(e); + }, + OnDevicePersonalizationExecutors.getBackgroundExecutor() + ); } private Map<String, VendorData> readContentsArray(JsonReader reader) throws IOException { @@ -282,7 +337,7 @@ public class OnDevicePersonalizationDataProcessingAsyncCallable implements Async if (name.equals("key")) { key = reader.nextString(); } else if (name.equals("data")) { - data = reader.nextString().getBytes(); + data = reader.nextString().getBytes(StandardCharsets.UTF_8); } else { reader.skipValue(); } @@ -293,4 +348,17 @@ public class OnDevicePersonalizationDataProcessingAsyncCallable implements Async } return new VendorData.Builder().setKey(key).setData(data).build(); } + + private void writeServiceRequestMetrics(Bundle result, long startTimeMillis, int responseCode) { + int latencyMillis = (int) (mInjector.getClock().elapsedRealtime() - startTimeMillis); + int overheadLatencyMillis = + (int) StatsUtils.getOverheadLatencyMillis(latencyMillis, result); + ApiCallStats callStats = + new ApiCallStats.Builder(ApiCallStats.API_SERVICE_ON_DOWNLOAD_COMPLETED) + .setLatencyMillis(latencyMillis) + .setOverheadLatencyMillis(overheadLatencyMillis) + .setResponseCode(responseCode) + .build(); + OdpStatsdLogger.getInstance().logApiCallStats(callStats); + } } diff --git a/src/com/android/ondevicepersonalization/services/download/OnDevicePersonalizationDownloadProcessingJobService.java b/src/com/android/ondevicepersonalization/services/download/OnDevicePersonalizationDownloadProcessingJobService.java index bb05ebe9..519ed10e 100644 --- a/src/com/android/ondevicepersonalization/services/download/OnDevicePersonalizationDownloadProcessingJobService.java +++ b/src/com/android/ondevicepersonalization/services/download/OnDevicePersonalizationDownloadProcessingJobService.java @@ -94,7 +94,7 @@ public class OnDevicePersonalizationDownloadProcessingJobService extends JobServ OnDevicePersonalizationExecutors.getBackgroundExecutor())); } } - Futures.whenAllComplete(mFutures).call(() -> { + var unused = Futures.whenAllComplete(mFutures).call(() -> { jobFinished(params, /* wantsReschedule */ false); return null; }, OnDevicePersonalizationExecutors.getLightweightExecutor()); diff --git a/src/com/android/ondevicepersonalization/services/download/mdd/MddJobService.java b/src/com/android/ondevicepersonalization/services/download/mdd/MddJobService.java index a428fa47..fce0b320 100644 --- a/src/com/android/ondevicepersonalization/services/download/mdd/MddJobService.java +++ b/src/com/android/ondevicepersonalization/services/download/mdd/MddJobService.java @@ -29,6 +29,7 @@ import android.os.PersistableBundle; import com.android.ondevicepersonalization.internal.util.LoggerFactory; import com.android.ondevicepersonalization.services.FlagsFactory; import com.android.ondevicepersonalization.services.OnDevicePersonalizationExecutors; +import com.android.ondevicepersonalization.services.data.user.UserPrivacyStatus; import com.android.ondevicepersonalization.services.download.OnDevicePersonalizationDownloadProcessingJobService; import com.google.android.libraries.mobiledatadownload.tracing.PropagatedFutures; @@ -54,6 +55,12 @@ public class MddJobService extends JobService { return true; } + if (!UserPrivacyStatus.getInstance().isPersonalizationStatusEnabled()) { + sLogger.d(TAG + ": Personalization is not allowed, finishing job."); + jobFinished(params, false); + return true; + } + // Get the mddTaskTag from input. PersistableBundle extras = params.getExtras(); if (null == extras) { diff --git a/src/com/android/ondevicepersonalization/services/federatedcompute/ContextData.java b/src/com/android/ondevicepersonalization/services/federatedcompute/ContextData.java new file mode 100644 index 00000000..9d0c5cfa --- /dev/null +++ b/src/com/android/ondevicepersonalization/services/federatedcompute/ContextData.java @@ -0,0 +1,66 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.ondevicepersonalization.services.federatedcompute; + +import android.annotation.NonNull; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; + +/** + * ContextData object to pass to federatedcompute + * TODO(278106108): Move this class depending on scheduling impl. + */ +public class ContextData implements Serializable { + @NonNull + String mPackageName; + + public ContextData(@NonNull String packageName) { + mPackageName = packageName; + } + + /** + * Converts the given ContextData into a serialized byte[] + */ + public static byte[] toByteArray(ContextData contextData) throws IOException { + try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + ObjectOutputStream objectOutputStream = new ObjectOutputStream( + byteArrayOutputStream)) { + objectOutputStream.writeObject(contextData); + return byteArrayOutputStream.toByteArray(); + } + } + + /** + * Converts the given serialized byte[] into a ContextData object + */ + public static ContextData fromByteArray(byte[] arr) throws IOException, ClassNotFoundException { + try (ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(arr); + ObjectInputStream objectInputStream = new ObjectInputStream(byteArrayInputStream)) { + return (ContextData) objectInputStream.readObject(); + } + } + + @NonNull + public String getPackageName() { + return mPackageName; + } +} diff --git a/src/com/android/ondevicepersonalization/services/federatedcompute/FederatedComputeServiceImpl.java b/src/com/android/ondevicepersonalization/services/federatedcompute/FederatedComputeServiceImpl.java new file mode 100644 index 00000000..69a0c68b --- /dev/null +++ b/src/com/android/ondevicepersonalization/services/federatedcompute/FederatedComputeServiceImpl.java @@ -0,0 +1,211 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.ondevicepersonalization.services.federatedcompute; + +import android.adservices.ondevicepersonalization.aidl.IFederatedComputeCallback; +import android.adservices.ondevicepersonalization.aidl.IFederatedComputeService; +import android.annotation.NonNull; +import android.content.Context; +import android.content.pm.PackageManager; +import android.federatedcompute.FederatedComputeManager; +import android.federatedcompute.common.ClientConstants; +import android.federatedcompute.common.ScheduleFederatedComputeRequest; +import android.federatedcompute.common.TrainingOptions; +import android.os.OutcomeReceiver; +import android.os.RemoteException; +import android.os.SystemProperties; + +import com.android.internal.annotations.VisibleForTesting; +import com.android.ondevicepersonalization.internal.util.LoggerFactory; +import com.android.ondevicepersonalization.services.OnDevicePersonalizationExecutors; +import com.android.ondevicepersonalization.services.data.events.EventState; +import com.android.ondevicepersonalization.services.data.events.EventsDao; +import com.android.ondevicepersonalization.services.manifest.AppManifestConfigHelper; +import com.android.ondevicepersonalization.services.util.PackageUtils; + +import com.google.common.util.concurrent.ListeningExecutorService; + +import java.io.IOException; +import java.util.Objects; + +/** + * A class that exports methods that plugin code in the isolated process + * can use to schedule federatedCompute jobs. + */ +public class FederatedComputeServiceImpl extends IFederatedComputeService.Stub { + private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); + private static final String TAG = "FederatedComputeServiceImpl"; + + private static final String OVERRIDE_FC_SERVER_URL_PACKAGE = + "debug.ondevicepersonalization.override_fc_server_url_package"; + private static final String OVERRIDE_FC_SERVER_URL = + "debug.ondevicepersonalization.override_fc_server_url"; + + + @NonNull + private final Context mApplicationContext; + @NonNull + private final String mServicePackageName; + @NonNull + private final Injector mInjector; + + @NonNull + private final FederatedComputeManager mFederatedComputeManager; + + @VisibleForTesting + public FederatedComputeServiceImpl( + @NonNull String servicePackageName, + @NonNull Context applicationContext, + @NonNull Injector injector) { + this.mApplicationContext = Objects.requireNonNull(applicationContext); + this.mServicePackageName = Objects.requireNonNull(servicePackageName); + this.mInjector = Objects.requireNonNull(injector); + this.mFederatedComputeManager = Objects.requireNonNull( + injector.getFederatedComputeManager(mApplicationContext)); + } + + public FederatedComputeServiceImpl( + @NonNull String servicePackageName, + @NonNull Context applicationContext) { + this(servicePackageName, applicationContext, new Injector()); + } + + @Override + public void schedule(TrainingOptions trainingOptions, + IFederatedComputeCallback callback) { + try { + String url = AppManifestConfigHelper.getFcRemoteServerUrlFromOdpSettings( + mApplicationContext, mServicePackageName); + + // Check for override manifest url property, if package is debuggable + if (PackageUtils.isPackageDebuggable(mApplicationContext, mServicePackageName)) { + if (SystemProperties.get(OVERRIDE_FC_SERVER_URL_PACKAGE, "").equals( + mServicePackageName)) { + String overrideManifestUrl = SystemProperties.get(OVERRIDE_FC_SERVER_URL, ""); + if (!overrideManifestUrl.isEmpty()) { + sLogger.d(TAG + ": Overriding fc server URL for package " + + mServicePackageName + " to " + overrideManifestUrl); + url = overrideManifestUrl; + } + } + } + + if (url == null) { + sLogger.d("Missing remote server URL for package: " + mServicePackageName); + sendError(callback); + return; + } + + ContextData contextData = new ContextData(mServicePackageName); + TrainingOptions trainingOptionsWithContext = new TrainingOptions.Builder() + .setContextData(ContextData.toByteArray(contextData)) + .setTrainingInterval(trainingOptions.getTrainingInterval()) + .setPopulationName(trainingOptions.getPopulationName()) + .setServerAddress(url) + .build(); + ScheduleFederatedComputeRequest request = + new ScheduleFederatedComputeRequest.Builder() + .setTrainingOptions(trainingOptionsWithContext) + .build(); + mFederatedComputeManager.schedule( + request, + mInjector.getExecutor(), + new OutcomeReceiver<>() { + @Override + public void onResult(Object result) { + mInjector.getEventsDao(mApplicationContext).updateOrInsertEventState( + new EventState.Builder() + .setServicePackageName(mServicePackageName) + .setTaskIdentifier(trainingOptions.getPopulationName()) + .setToken(new byte[]{}) + .build()); + sendSuccess(callback); + } + + @Override + public void onError(Exception e) { + sLogger.e(TAG + ": Error while scheduling federatedCompute", e); + sendError(callback); + } + }); + } catch (IOException | PackageManager.NameNotFoundException e) { + sLogger.e(TAG + ": Error while scheduling federatedCompute", e); + sendError(callback); + } + } + + @Override + public void cancel(String populationName, + IFederatedComputeCallback callback) { + EventState eventState = mInjector.getEventsDao(mApplicationContext).getEventState( + populationName, mServicePackageName); + if (eventState == null) { + sLogger.d("No population registered for package: " + mServicePackageName); + sendSuccess(callback); + return; + } + mFederatedComputeManager.cancel( + populationName, + mInjector.getExecutor(), + new OutcomeReceiver<>() { + @Override + public void onResult(Object result) { + sendSuccess(callback); + } + + @Override + public void onError(Exception e) { + sLogger.e(TAG + ": Error while cancelling federatedCompute", e); + sendError(callback); + } + }); + } + + private void sendSuccess( + @NonNull IFederatedComputeCallback callback) { + try { + callback.onSuccess(); + } catch (RemoteException e) { + sLogger.e(TAG + ": Callback error", e); + } + } + + private void sendError(@NonNull IFederatedComputeCallback callback) { + try { + callback.onFailure(ClientConstants.STATUS_INTERNAL_ERROR); + } catch (RemoteException e) { + sLogger.e(TAG + ": Callback error", e); + } + } + + @VisibleForTesting + static class Injector { + ListeningExecutorService getExecutor() { + return OnDevicePersonalizationExecutors.getBackgroundExecutor(); + } + + FederatedComputeManager getFederatedComputeManager(Context context) { + return context.getSystemService(FederatedComputeManager.class); + } + + EventsDao getEventsDao( + Context context + ) { + return EventsDao.getInstance(context); + } + } +} diff --git a/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreIterator.java b/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreIterator.java index 8ab2b2a9..793f45b3 100644 --- a/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreIterator.java +++ b/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreIterator.java @@ -16,15 +16,45 @@ package com.android.ondevicepersonalization.services.federatedcompute; +import static android.federatedcompute.common.ClientConstants.EXTRA_EXAMPLE_ITERATOR_RESULT; +import static android.federatedcompute.common.ClientConstants.EXTRA_EXAMPLE_ITERATOR_RESUMPTION_TOKEN; + import android.federatedcompute.ExampleStoreIterator; +import android.os.Bundle; + +import androidx.annotation.NonNull; + +import java.util.List; +import java.util.ListIterator; /** * Implementation of ExampleStoreIterator for OnDevicePersonalization */ public class OdpExampleStoreIterator implements ExampleStoreIterator { + + ListIterator<byte[]> mExampleIterator; + ListIterator<byte[]> mResumptionTokens; + + OdpExampleStoreIterator(List<byte[]> exampleList, List<byte[]> resumptionTokens) { + if (exampleList.size() != resumptionTokens.size()) { + throw new IllegalArgumentException( + "exampleList and resumptionTokens must be the same size"); + } + mExampleIterator = exampleList.listIterator(); + mResumptionTokens = resumptionTokens.listIterator(); + } + @Override - public void next(IteratorCallback callback) { - // TODO(278106108): Implement this method. + public void next(@NonNull IteratorCallback callback) { + if (mExampleIterator.hasNext()) { + byte[] example = mExampleIterator.next(); + byte[] resumptionToken = mResumptionTokens.next(); + Bundle result = new Bundle(); + result.putByteArray(EXTRA_EXAMPLE_ITERATOR_RESULT, example); + result.putByteArray(EXTRA_EXAMPLE_ITERATOR_RESUMPTION_TOKEN, resumptionToken); + callback.onIteratorNextSuccess(result); + return; + } callback.onIteratorNextSuccess(null); } diff --git a/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreIteratorFactory.java b/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreIteratorFactory.java index ae98cd5b..6ed73640 100644 --- a/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreIteratorFactory.java +++ b/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreIteratorFactory.java @@ -16,48 +16,36 @@ package com.android.ondevicepersonalization.services.federatedcompute; -import android.annotation.NonNull; -import android.content.Context; - -import com.android.ondevicepersonalization.services.OnDevicePersonalizationExecutors; -import com.android.ondevicepersonalization.services.data.events.EventsDao; - -import com.google.common.util.concurrent.ListeningExecutorService; +import java.util.List; /** * Factory for creating iterators */ public class OdpExampleStoreIteratorFactory { + private static volatile OdpExampleStoreIteratorFactory sSingleton; - private final ListeningExecutorService mExecutor; - private final EventsDao mEventsDao; - private static OdpExampleStoreIteratorFactory sSingleton; - - private OdpExampleStoreIteratorFactory( - @NonNull Context context, - @NonNull ListeningExecutorService executor) { - mExecutor = executor; - mEventsDao = EventsDao.getInstance(context); + private OdpExampleStoreIteratorFactory() { } /** * Returns an instance of OdpExampleStoreIteratorFactory */ - public static OdpExampleStoreIteratorFactory getInstance(Context context) { - synchronized (OdpExampleStoreIteratorFactory.class) { - if (null == sSingleton) { - sSingleton = new OdpExampleStoreIteratorFactory(context, - OnDevicePersonalizationExecutors.getBackgroundExecutor()); + public static OdpExampleStoreIteratorFactory getInstance() { + if (null == sSingleton) { + synchronized (OdpExampleStoreIteratorFactory.class) { + if (null == sSingleton) { + sSingleton = new OdpExampleStoreIteratorFactory(); + } } - return sSingleton; } + return sSingleton; } /** * Creates an OdpExampleStoreIterator */ - public OdpExampleStoreIterator createIterator() { - // TODO(278106108): Implement this method. - return new OdpExampleStoreIterator(); + public OdpExampleStoreIterator createIterator(List<byte[]> exampleList, + List<byte[]> resumptionTokens) { + return new OdpExampleStoreIterator(exampleList, resumptionTokens); } } diff --git a/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreService.java b/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreService.java index e14c450c..b0ae9bd4 100644 --- a/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreService.java +++ b/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreService.java @@ -16,17 +16,272 @@ package com.android.ondevicepersonalization.services.federatedcompute; +import android.adservices.ondevicepersonalization.Constants; +import android.adservices.ondevicepersonalization.TrainingExampleInput; +import android.adservices.ondevicepersonalization.TrainingExampleOutputParcel; +import android.adservices.ondevicepersonalization.UserData; +import android.annotation.NonNull; +import android.content.Context; import android.federatedcompute.ExampleStoreService; +import android.federatedcompute.FederatedComputeManager; +import android.federatedcompute.common.ClientConstants; import android.os.Bundle; +import android.os.OutcomeReceiver; + +import com.android.ondevicepersonalization.internal.util.ByteArrayParceledListSlice; +import com.android.ondevicepersonalization.internal.util.LoggerFactory; +import com.android.ondevicepersonalization.services.Flags; +import com.android.ondevicepersonalization.services.FlagsFactory; +import com.android.ondevicepersonalization.services.OnDevicePersonalizationExecutors; +import com.android.ondevicepersonalization.services.data.DataAccessServiceImpl; +import com.android.ondevicepersonalization.services.data.events.EventState; +import com.android.ondevicepersonalization.services.data.events.EventsDao; +import com.android.ondevicepersonalization.services.manifest.AppManifestConfigHelper; +import com.android.ondevicepersonalization.services.policyengine.UserDataAccessor; +import com.android.ondevicepersonalization.services.process.IsolatedServiceInfo; +import com.android.ondevicepersonalization.services.process.ProcessRunner; +import com.android.ondevicepersonalization.services.statsd.ApiCallStats; +import com.android.ondevicepersonalization.services.statsd.OdpStatsdLogger; +import com.android.ondevicepersonalization.services.util.Clock; +import com.android.ondevicepersonalization.services.util.MonotonicClock; +import com.android.ondevicepersonalization.services.util.StatsUtils; + +import com.google.common.util.concurrent.FluentFuture; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ListeningScheduledExecutorService; + +import java.util.ArrayList; +import java.util.Objects; +import java.util.concurrent.TimeUnit; + +/** Implementation of ExampleStoreService for OnDevicePersonalization */ +public final class OdpExampleStoreService extends ExampleStoreService { + + private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); + private static final String TAG = OdpExampleStoreService.class.getSimpleName(); + private static final String TASK_NAME = "ExampleStore"; + + static class Injector { + Clock getClock() { + return MonotonicClock.getInstance(); + } + + Flags getFlags() { + return FlagsFactory.getFlags(); + } + + ListeningScheduledExecutorService getScheduledExecutor() { + return OnDevicePersonalizationExecutors.getScheduledExecutor(); + } + + ProcessRunner getProcessRunner() { + return ProcessRunner.getInstance(); + } + } + + private final Injector mInjector = new Injector(); + + /** Generates a unique task identifier from the given strings */ + public static String getTaskIdentifier(String populationName, String taskName) { + return populationName + "_" + taskName; + } -/** - * Implementation of ExampleStoreService for OnDevicePersonalization - */ -public class OdpExampleStoreService extends ExampleStoreService { @Override - public void startQuery(Bundle params, QueryCallback callback) { - // TODO(278106108): Validate params and pass to iterator - callback.onStartQuerySuccess( - OdpExampleStoreIteratorFactory.getInstance(this).createIterator()); + public void startQuery(@NonNull Bundle params, @NonNull QueryCallback callback) { + try { + ContextData contextData = + ContextData.fromByteArray( + Objects.requireNonNull( + params.getByteArray(ClientConstants.EXTRA_CONTEXT_DATA))); + String packageName = contextData.getPackageName(); + String populationName = + Objects.requireNonNull(params.getString(ClientConstants.EXTRA_POPULATION_NAME)); + String taskName = + Objects.requireNonNull(params.getString(ClientConstants.EXTRA_TASK_NAME)); + + EventsDao eventDao = EventsDao.getInstance(getContext()); + + // Cancel job if on longer valid. This is written to the table during scheduling + // via {@link FederatedComputeServiceImpl} and deleted either during cancel or + // during maintenance for uninstalled packages. + EventState eventStatePopulation = eventDao.getEventState(populationName, packageName); + if (eventStatePopulation == null) { + sLogger.w("Job was either cancelled or package was uninstalled"); + // Cancel job. + FederatedComputeManager FCManager = + getContext().getSystemService(FederatedComputeManager.class); + if (FCManager == null) { + sLogger.e(TAG + ": Failed to get FederatedCompute Service"); + callback.onStartQueryFailure(ClientConstants.STATUS_INTERNAL_ERROR); + return; + } + FCManager.cancel( + populationName, + OnDevicePersonalizationExecutors.getBackgroundExecutor(), + new OutcomeReceiver<Object, Exception>() { + @Override + public void onResult(Object result) { + sLogger.d(TAG + ": Successfully canceled job"); + callback.onStartQueryFailure(ClientConstants.STATUS_INTERNAL_ERROR); + } + + @Override + public void onError(Exception error) { + sLogger.e(TAG + ": Error while cancelling job", error); + OutcomeReceiver.super.onError(error); + callback.onStartQueryFailure(ClientConstants.STATUS_INTERNAL_ERROR); + } + }); + return; + } + + // Get resumptionToken + EventState eventState = + eventDao.getEventState( + getTaskIdentifier(populationName, taskName), packageName); + byte[] resumptionToken = null; + if (eventState != null) { + resumptionToken = eventState.getToken(); + } + + TrainingExampleInput input = + new TrainingExampleInput.Builder() + .setResumptionToken(resumptionToken) + .setPopulationName(populationName) + .setTaskName(taskName) + .build(); + + ListenableFuture<IsolatedServiceInfo> loadFuture = + mInjector.getProcessRunner().loadIsolatedService(TASK_NAME, packageName); + ListenableFuture<TrainingExampleOutputParcel> resultFuture = + FluentFuture.from(loadFuture) + .transformAsync( + result -> executeOnTrainingExample(result, input, packageName), + OnDevicePersonalizationExecutors.getBackgroundExecutor()) + .transform( + result -> { + return result.getParcelable( + Constants.EXTRA_RESULT, + TrainingExampleOutputParcel.class); + }, + OnDevicePersonalizationExecutors.getBackgroundExecutor()) + .withTimeout( + mInjector.getFlags().getIsolatedServiceDeadlineSeconds(), + TimeUnit.SECONDS, + mInjector.getScheduledExecutor()); + + Futures.addCallback( + resultFuture, + new FutureCallback<TrainingExampleOutputParcel>() { + @Override + public void onSuccess( + TrainingExampleOutputParcel trainingExampleOutputParcel) { + ByteArrayParceledListSlice trainingExamplesListSlice = + trainingExampleOutputParcel.getTrainingExamples(); + ByteArrayParceledListSlice resumptionTokensListSlice = + trainingExampleOutputParcel.getResumptionTokens(); + if (trainingExamplesListSlice == null + || resumptionTokensListSlice == null) { + callback.onStartQuerySuccess( + OdpExampleStoreIteratorFactory.getInstance() + .createIterator( + new ArrayList<>(), new ArrayList<>())); + } else { + callback.onStartQuerySuccess( + OdpExampleStoreIteratorFactory.getInstance() + .createIterator( + trainingExamplesListSlice.getList(), + resumptionTokensListSlice.getList())); + } + } + + @Override + public void onFailure(Throwable t) { + sLogger.w(t, "%s : Request failed.", TAG); + callback.onStartQueryFailure(ClientConstants.STATUS_INTERNAL_ERROR); + } + }, + OnDevicePersonalizationExecutors.getBackgroundExecutor()); + + var unused = + Futures.whenAllComplete(loadFuture, resultFuture) + .callAsync( + () -> + mInjector + .getProcessRunner() + .unloadIsolatedService(loadFuture.get()), + OnDevicePersonalizationExecutors.getBackgroundExecutor()); + } catch (Exception e) { + sLogger.w(e, "%s : Start query failed.", TAG); + callback.onStartQueryFailure(ClientConstants.STATUS_INTERNAL_ERROR); + } + } + + private ListenableFuture<Bundle> executeOnTrainingExample( + IsolatedServiceInfo isolatedServiceInfo, + TrainingExampleInput exampleInput, + String packageName) { + sLogger.d(TAG + ": executeOnTrainingExample() started."); + Bundle serviceParams = new Bundle(); + serviceParams.putParcelable(Constants.EXTRA_INPUT, exampleInput); + DataAccessServiceImpl binder = + new DataAccessServiceImpl( + packageName, + getContext(), /* includeLocalData */ + true, + /* includeEventData */ true); + serviceParams.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, binder); + UserDataAccessor userDataAccessor = new UserDataAccessor(); + UserData userData = userDataAccessor.getUserData(); + serviceParams.putParcelable(Constants.EXTRA_USER_DATA, userData); + ListenableFuture<Bundle> result = + mInjector + .getProcessRunner() + .runIsolatedService( + isolatedServiceInfo, + AppManifestConfigHelper.getServiceNameFromOdpSettings( + getContext(), packageName), + Constants.OP_TRAINING_EXAMPLE, + serviceParams); + return FluentFuture.from(result) + .transform( + val -> { + writeServiceRequestMetrics( + val, + isolatedServiceInfo.getStartTimeMillis(), + Constants.STATUS_SUCCESS); + return val; + }, + OnDevicePersonalizationExecutors.getBackgroundExecutor()) + .catchingAsync( + Exception.class, + e -> { + writeServiceRequestMetrics( + null, + isolatedServiceInfo.getStartTimeMillis(), + Constants.STATUS_INTERNAL_ERROR); + return Futures.immediateFailedFuture(e); + }, + OnDevicePersonalizationExecutors.getBackgroundExecutor()); + } + + private void writeServiceRequestMetrics(Bundle result, long startTimeMillis, int responseCode) { + int latencyMillis = (int) (mInjector.getClock().elapsedRealtime() - startTimeMillis); + int overheadLatencyMillis = + (int) StatsUtils.getOverheadLatencyMillis(latencyMillis, result); + ApiCallStats callStats = + new ApiCallStats.Builder(ApiCallStats.API_SERVICE_ON_TRAINING_EXAMPLE) + .setLatencyMillis(latencyMillis) + .setOverheadLatencyMillis(overheadLatencyMillis) + .setResponseCode(responseCode) + .build(); + OdpStatsdLogger.getInstance().logApiCallStats(callStats); + } + + // used for tests to provide mock/real implementation of context. + private Context getContext() { + return this.getApplicationContext(); } } diff --git a/src/com/android/ondevicepersonalization/services/federatedcompute/OdpFederatedComputeJobService.java b/src/com/android/ondevicepersonalization/services/federatedcompute/OdpFederatedComputeJobService.java deleted file mode 100644 index a4ad84e1..00000000 --- a/src/com/android/ondevicepersonalization/services/federatedcompute/OdpFederatedComputeJobService.java +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Copyright (C) 2022 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.android.ondevicepersonalization.services.federatedcompute; - -import static android.app.job.JobScheduler.RESULT_FAILURE; - -import static com.android.ondevicepersonalization.services.OnDevicePersonalizationConfig.FEDERATED_COMPUTE_TASK_JOB_ID; -import static com.android.ondevicepersonalization.services.OnDevicePersonalizationConfig.ODP_POPULATION_NAME; - -import android.app.job.JobInfo; -import android.app.job.JobParameters; -import android.app.job.JobScheduler; -import android.app.job.JobService; -import android.content.ComponentName; -import android.content.Context; -import android.federatedcompute.FederatedComputeManager; -import android.federatedcompute.common.ScheduleFederatedComputeRequest; -import android.federatedcompute.common.TrainingOptions; -import android.os.OutcomeReceiver; - -import com.android.internal.annotations.VisibleForTesting; -import com.android.ondevicepersonalization.internal.util.LoggerFactory; -import com.android.ondevicepersonalization.services.FlagsFactory; -import com.android.ondevicepersonalization.services.OnDevicePersonalizationExecutors; - -import com.google.common.util.concurrent.FutureCallback; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; - -/** JobService to handle the OnDevicePersonalization FederatedCompute scheduling */ -public class OdpFederatedComputeJobService extends JobService { - private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); - private static final String TAG = "OdpFederatedComputeJobService"; - private static final long PERIOD_SECONDS = 86400; - private ListenableFuture<Void> mFuture; - - /** Schedules a unique instance of OdpFederatedComputeJobService to be run. */ - public static int schedule(Context context) { - JobScheduler jobScheduler = context.getSystemService(JobScheduler.class); - if (jobScheduler.getPendingJob(FEDERATED_COMPUTE_TASK_JOB_ID) != null) { - sLogger.d(TAG + ": Job is already scheduled. Doing nothing,"); - return RESULT_FAILURE; - } - ComponentName serviceComponent = - new ComponentName(context, OdpFederatedComputeJobService.class); - JobInfo.Builder builder = - new JobInfo.Builder(FEDERATED_COMPUTE_TASK_JOB_ID, serviceComponent); - - // Constraints. - // TODO(278106108): Update scheduling conditions. - builder.setRequiresDeviceIdle(true); - builder.setRequiresBatteryNotLow(true); - builder.setRequiredNetworkType(JobInfo.NETWORK_TYPE_NONE); - builder.setPeriodic(1000 * PERIOD_SECONDS); // JobScheduler uses Milliseconds. - // persist this job across boots - builder.setPersisted(true); - - return jobScheduler.schedule(builder.build()); - } - - @Override - public boolean onStartJob(JobParameters params) { - sLogger.d(TAG + ": onStartJob()"); - if (FlagsFactory.getFlags().getGlobalKillSwitch()) { - sLogger.d(TAG + ": GlobalKillSwitch enabled, finishing job."); - jobFinished(params, /* wantsReschedule= */ false); - return true; - } - mFuture = - Futures.submit( - new Runnable() { - @Override - public void run() { - scheduleFederatedCompute(); - } - }, - OnDevicePersonalizationExecutors.getBackgroundExecutor()); - - Futures.addCallback( - mFuture, - new FutureCallback<Void>() { - @Override - public void onSuccess(Void result) { - sLogger.d(TAG + ": Job completed successfully."); - // Tell the JobScheduler that the job has completed and does not needs to be - // rescheduled. - jobFinished(params, /* wantsReschedule= */ false); - } - - @Override - public void onFailure(Throwable t) { - sLogger.e(TAG + ": Failed to handle JobService: " + params.getJobId(), t); - // When failure, also tell the JobScheduler that the job has completed and - // does not need to be rescheduled. - jobFinished(params, /* wantsReschedule= */ false); - } - }, - OnDevicePersonalizationExecutors.getBackgroundExecutor()); - - return true; - } - - @Override - public boolean onStopJob(JobParameters params) { - if (mFuture != null) { - mFuture.cancel(true); - } - // Reschedule the job since it ended before finishing - return true; - } - - @VisibleForTesting - void scheduleFederatedCompute() { - if (federatedComputeNeedsScheduling()) { - FederatedComputeManager FCManager = - this.getSystemService(FederatedComputeManager.class); - TrainingOptions trainingOptions = - new TrainingOptions.Builder().setPopulationName(ODP_POPULATION_NAME).build(); - ScheduleFederatedComputeRequest request = - new ScheduleFederatedComputeRequest.Builder() - .setTrainingOptions(trainingOptions) - .build(); - FCManager.scheduleFederatedCompute( - request, - OnDevicePersonalizationExecutors.getBackgroundExecutor(), - new OutcomeReceiver<Object, Exception>() { - @Override - public void onResult(Object result) { - sLogger.d(TAG + ": Successfully scheduled federatedCompute"); - } - - @Override - public void onError(Exception error) { - sLogger.e(TAG + ": Error while scheduling federatedCompute", error); - OutcomeReceiver.super.onError(error); - } - }); - } - } - - private boolean federatedComputeNeedsScheduling() { - // TODO(278106108): Add conditions for when to schedule. - return true; - } -} diff --git a/src/com/android/ondevicepersonalization/services/federatedcompute/OdpResultHandlingService.java b/src/com/android/ondevicepersonalization/services/federatedcompute/OdpResultHandlingService.java index af1bad17..259c4187 100644 --- a/src/com/android/ondevicepersonalization/services/federatedcompute/OdpResultHandlingService.java +++ b/src/com/android/ondevicepersonalization/services/federatedcompute/OdpResultHandlingService.java @@ -19,20 +19,106 @@ package com.android.ondevicepersonalization.services.federatedcompute; import static android.federatedcompute.common.ClientConstants.STATUS_SUCCESS; import android.federatedcompute.ResultHandlingService; +import android.federatedcompute.common.ClientConstants; import android.federatedcompute.common.ExampleConsumption; -import android.federatedcompute.common.TrainingOptions; +import android.os.Bundle; +import com.android.ondevicepersonalization.internal.util.LoggerFactory; +import com.android.ondevicepersonalization.services.OnDevicePersonalizationExecutors; +import com.android.ondevicepersonalization.services.data.events.EventState; +import com.android.ondevicepersonalization.services.data.events.EventsDao; + +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; + +import java.util.ArrayList; import java.util.List; +import java.util.Objects; import java.util.function.Consumer; -/** - * Implementation of ResultHandlingService for OnDevicePersonalization - */ +/** Implementation of ResultHandlingService for OnDevicePersonalization */ public class OdpResultHandlingService extends ResultHandlingService { + private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); + private static final String TAG = "OdpResultHandlingService"; + @Override - public void handleResult(TrainingOptions trainingOptions, boolean success, - List<ExampleConsumption> exampleConsumptionList, Consumer<Integer> callback) { - // TODO(278106108): Implement this method - callback.accept(STATUS_SUCCESS); + public void handleResult(Bundle params, Consumer<Integer> callback) { + try { + ContextData contextData = + ContextData.fromByteArray( + Objects.requireNonNull( + params.getByteArray(ClientConstants.EXTRA_CONTEXT_DATA))); + String packageName = contextData.getPackageName(); + String populationName = + Objects.requireNonNull(params.getString(ClientConstants.EXTRA_POPULATION_NAME)); + String taskName = + Objects.requireNonNull(params.getString(ClientConstants.EXTRA_TASK_NAME)); + int computationResult = params.getInt(ClientConstants.EXTRA_COMPUTATION_RESULT); + ArrayList<ExampleConsumption> consumptionList = + Objects.requireNonNull( + params.getParcelableArrayList( + ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, + ExampleConsumption.class)); + + // Just return if training failed. Next query will retry the failed examples. + if (computationResult != STATUS_SUCCESS) { + callback.accept(ClientConstants.STATUS_SUCCESS); + return; + } + + ListenableFuture<Boolean> result = + Futures.submit( + () -> + processExampleConsumptions( + consumptionList, populationName, taskName, packageName), + OnDevicePersonalizationExecutors.getBackgroundExecutor()); + Futures.addCallback( + result, + new FutureCallback<Boolean>() { + @Override + public void onSuccess(Boolean result) { + if (result) { + callback.accept(STATUS_SUCCESS); + } else { + callback.accept(ClientConstants.STATUS_INTERNAL_ERROR); + } + } + + @Override + public void onFailure(Throwable t) { + sLogger.w(TAG + ": handleResult failed.", t); + callback.accept(ClientConstants.STATUS_INTERNAL_ERROR); + } + }, + OnDevicePersonalizationExecutors.getBackgroundExecutor()); + + } catch (Exception e) { + sLogger.w(TAG + ": handleResult failed.", e); + callback.accept(ClientConstants.STATUS_INTERNAL_ERROR); + } + } + + private Boolean processExampleConsumptions( + List<ExampleConsumption> exampleConsumptions, + String populationName, + String taskName, + String packageName) { + List<EventState> eventStates = new ArrayList<>(); + for (ExampleConsumption consumption : exampleConsumptions) { + String taskIdentifier = + OdpExampleStoreService.getTaskIdentifier(populationName, taskName); + byte[] resumptionToken = consumption.getResumptionToken(); + if (resumptionToken != null) { + eventStates.add( + new EventState.Builder() + .setServicePackageName(packageName) + .setTaskIdentifier(taskIdentifier) + .setToken(resumptionToken) + .build()); + } + } + EventsDao eventsDao = EventsDao.getInstance(this); + return eventsDao.updateOrInsertEventStatesTransaction(eventStates); } } diff --git a/src/com/android/ondevicepersonalization/services/maintenance/OnDevicePersonalizationMaintenanceJobService.java b/src/com/android/ondevicepersonalization/services/maintenance/OnDevicePersonalizationMaintenanceJobService.java index 80a2d014..f02c311e 100644 --- a/src/com/android/ondevicepersonalization/services/maintenance/OnDevicePersonalizationMaintenanceJobService.java +++ b/src/com/android/ondevicepersonalization/services/maintenance/OnDevicePersonalizationMaintenanceJobService.java @@ -28,12 +28,13 @@ import android.content.Context; import android.content.pm.PackageInfo; import android.content.pm.PackageManager; - import com.android.internal.annotations.VisibleForTesting; import com.android.ondevicepersonalization.internal.util.LoggerFactory; import com.android.ondevicepersonalization.services.FlagsFactory; import com.android.ondevicepersonalization.services.OnDevicePersonalizationConfig; import com.android.ondevicepersonalization.services.OnDevicePersonalizationExecutors; +import com.android.ondevicepersonalization.services.data.events.EventsDao; +import com.android.ondevicepersonalization.services.data.user.UserPrivacyStatus; import com.android.ondevicepersonalization.services.data.vendor.OnDevicePersonalizationVendorDataDao; import com.android.ondevicepersonalization.services.manifest.AppManifestConfigHelper; import com.android.ondevicepersonalization.services.util.PackageUtils; @@ -53,7 +54,13 @@ import java.util.Set; public class OnDevicePersonalizationMaintenanceJobService extends JobService { private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); private static final String TAG = "OnDevicePersonalizationMaintenanceJobService"; + + // Every 24hrs. private static final long PERIOD_SECONDS = 86400; + + // The maximum deletion timeframe is 63 days. + // Set parameter to 60 days to account for job scheduler delays. + private static final long MAXIMUM_DELETION_TIMEFRAME_MILLIS = 5184000000L; private ListenableFuture<Void> mFuture; /** @@ -83,6 +90,39 @@ public class OnDevicePersonalizationMaintenanceJobService extends JobService { return jobScheduler.schedule(builder.build()); } + @VisibleForTesting + static void cleanupVendorData(Context context) throws Exception { + EventsDao eventsDao = EventsDao.getInstance(context); + + // Set of packageName and cert + Set<Map.Entry<String, String>> vendors = new HashSet<>( + OnDevicePersonalizationVendorDataDao.getVendors(context)); + + // Remove all valid packages from the set + for (PackageInfo packageInfo : context.getPackageManager().getInstalledPackages( + PackageManager.PackageInfoFlags.of(GET_META_DATA))) { + String packageName = packageInfo.packageName; + if (AppManifestConfigHelper.manifestContainsOdpSettings( + context, packageName)) { + vendors.remove(new AbstractMap.SimpleImmutableEntry<>(packageName, + PackageUtils.getCertDigest(context, packageName))); + } + } + + sLogger.d(TAG + ": Deleting: " + vendors); + // Delete the remaining tables for packages not found onboarded + for (Map.Entry<String, String> entry : vendors) { + String packageName = entry.getKey(); + String certDigest = entry.getValue(); + OnDevicePersonalizationVendorDataDao.deleteVendorData(context, packageName, certDigest); + eventsDao.deleteEventState(entry.getKey()); + } + + // Cleanup event and queries table. + eventsDao.deleteEventsAndQueries( + System.currentTimeMillis() - MAXIMUM_DELETION_TIMEFRAME_MILLIS); + } + @Override public boolean onStartJob(JobParameters params) { sLogger.d(TAG + ": onStartJob()"); @@ -91,6 +131,11 @@ public class OnDevicePersonalizationMaintenanceJobService extends JobService { jobFinished(params, /* wantsReschedule = */ false); return true; } + if (!UserPrivacyStatus.getInstance().isPersonalizationStatusEnabled()) { + sLogger.d(TAG + ": Personalization is not allowed, finishing job."); + jobFinished(params, false); + return true; + } Context context = this; mFuture = Futures.submit(new Runnable() { @Override @@ -136,29 +181,4 @@ public class OnDevicePersonalizationMaintenanceJobService extends JobService { // Reschedule the job since it ended before finishing return true; } - - @VisibleForTesting - static void cleanupVendorData(Context context) throws Exception { - Set<Map.Entry<String, String>> vendors = new HashSet<>( - OnDevicePersonalizationVendorDataDao.getVendors(context)); - - // Remove all valid packages from the set - for (PackageInfo packageInfo : context.getPackageManager().getInstalledPackages( - PackageManager.PackageInfoFlags.of(GET_META_DATA))) { - String packageName = packageInfo.packageName; - if (AppManifestConfigHelper.manifestContainsOdpSettings( - context, packageName)) { - vendors.remove(new AbstractMap.SimpleImmutableEntry<>(packageName, - PackageUtils.getCertDigest(context, packageName))); - } - } - - sLogger.d(TAG + ": Deleting: " + vendors.toString()); - // Delete the remaining tables for packages not found onboarded - for (Map.Entry<String, String> entry : vendors) { - OnDevicePersonalizationVendorDataDao.deleteVendorData(context, entry.getKey(), - entry.getValue()); - } - - } } diff --git a/src/com/android/ondevicepersonalization/services/manifest/AppManifestConfig.java b/src/com/android/ondevicepersonalization/services/manifest/AppManifestConfig.java index f970dc2d..037e9a48 100644 --- a/src/com/android/ondevicepersonalization/services/manifest/AppManifestConfig.java +++ b/src/com/android/ondevicepersonalization/services/manifest/AppManifestConfig.java @@ -22,10 +22,12 @@ package com.android.ondevicepersonalization.services.manifest; public class AppManifestConfig { private final String mDownloadUrl; private final String mServiceName; + private final String mFcRemoteServerUrl; - public AppManifestConfig(String downloadUrl, String serviceName) { + public AppManifestConfig(String downloadUrl, String serviceName, String fcRemoteServerUrl) { mDownloadUrl = downloadUrl; mServiceName = serviceName; + mFcRemoteServerUrl = fcRemoteServerUrl; } /** @@ -41,4 +43,11 @@ public class AppManifestConfig { public String getServiceName() { return mServiceName; } + + /** + * @return The federated compute service remote server url configured in manifest + */ + public String getFcRemoteServerUrl() { + return mFcRemoteServerUrl; + } } diff --git a/src/com/android/ondevicepersonalization/services/manifest/AppManifestConfigHelper.java b/src/com/android/ondevicepersonalization/services/manifest/AppManifestConfigHelper.java index 9d114fc9..b2447fe8 100644 --- a/src/com/android/ondevicepersonalization/services/manifest/AppManifestConfigHelper.java +++ b/src/com/android/ondevicepersonalization/services/manifest/AppManifestConfigHelper.java @@ -92,4 +92,15 @@ public final class AppManifestConfigHelper { String packageName) { return getAppManifestConfig(context, packageName).getServiceName(); } + + /** + * Gets the federated compute service remote server url from package's ODP settings config + * + * @param context the context of the API call. + * @param packageName the packageName of the package whose manifest config will be read + */ + public static String getFcRemoteServerUrlFromOdpSettings(Context context, + String packageName) { + return getAppManifestConfig(context, packageName).getFcRemoteServerUrl(); + } } diff --git a/src/com/android/ondevicepersonalization/services/manifest/AppManifestConfigParser.java b/src/com/android/ondevicepersonalization/services/manifest/AppManifestConfigParser.java index 26c4a7da..05c6be6e 100644 --- a/src/com/android/ondevicepersonalization/services/manifest/AppManifestConfigParser.java +++ b/src/com/android/ondevicepersonalization/services/manifest/AppManifestConfigParser.java @@ -31,8 +31,10 @@ public class AppManifestConfigParser { private static final String TAG = "AppManifestConfigParser"; private static final String TAG_ON_DEVICE_PERSONALIZATION_CONFIG = "on-device-personalization"; private static final String TAG_DOWNLOAD_SETTINGS = "download-settings"; + private static final String TAG_FEDERATED_COMPUTE_SETTINGS = "federated-compute-settings"; private static final String TAG_SERVICE = "service"; private static final String ATTR_DOWNLOAD_URL = "url"; + private static final String ATTR_FC_URL = "url"; private static final String ATTR_NAME = "name"; private AppManifestConfigParser() { @@ -47,6 +49,7 @@ public class AppManifestConfigParser { XmlPullParserException { String downloadUrl = null; String serviceName = null; + String fcServerUrl = null; while (parser.getEventType() != XmlPullParser.START_TAG) { parser.next(); @@ -70,12 +73,15 @@ public class AppManifestConfigParser { case TAG_DOWNLOAD_SETTINGS: downloadUrl = parser.getAttributeValue(null, ATTR_DOWNLOAD_URL); break; + case TAG_FEDERATED_COMPUTE_SETTINGS: + fcServerUrl = parser.getAttributeValue(null, ATTR_FC_URL); + break; default: sLogger.i(TAG + ": Unknown tag: " + parser.getName()); } parser.next(); } - return new AppManifestConfig(downloadUrl, serviceName); + return new AppManifestConfig(downloadUrl, serviceName, fcServerUrl); } } diff --git a/src/com/android/ondevicepersonalization/services/policyengine/UserDataAccessor.kt b/src/com/android/ondevicepersonalization/services/policyengine/UserDataAccessor.kt index 51ec8fdc..188f4695 100644 --- a/src/com/android/ondevicepersonalization/services/policyengine/UserDataAccessor.kt +++ b/src/com/android/ondevicepersonalization/services/policyengine/UserDataAccessor.kt @@ -16,19 +16,12 @@ package com.android.ondevicepersonalization.services.policyengine -import android.util.Log - import android.adservices.ondevicepersonalization.UserData import com.android.libraries.pcc.chronicle.util.MutableTypedMap import com.android.libraries.pcc.chronicle.util.TypedMap import com.android.libraries.pcc.chronicle.api.ConnectionRequest -import com.android.libraries.pcc.chronicle.api.ConnectionResult -import com.android.libraries.pcc.chronicle.api.ReadConnection import com.android.libraries.pcc.chronicle.api.error.ChronicleError -import com.android.libraries.pcc.chronicle.api.error.PolicyNotFound -import com.android.libraries.pcc.chronicle.api.error.PolicyViolation -import com.android.libraries.pcc.chronicle.api.error.Disabled import com.android.libraries.pcc.chronicle.api.ProcessorNode import com.android.ondevicepersonalization.internal.util.LoggerFactory @@ -39,13 +32,10 @@ import com.android.ondevicepersonalization.services.policyengine.policy.DataIngr import com.android.ondevicepersonalization.services.policyengine.policy.rules.KidStatusEnabled import com.android.ondevicepersonalization.services.policyengine.policy.rules.LimitedAdsTrackingEnabled -import com.android.ondevicepersonalization.services.data.user.RawUserData - class UserDataAccessor : ProcessorNode { - private val sLogger: LoggerFactory.Logger = LoggerFactory.getLogger(); - private val TAG: String = "UserDataAccessor"; - private var policyContext: MutableTypedMap - private val rawUserData: RawUserData = RawUserData.getInstance() + private val sLogger: LoggerFactory.Logger = LoggerFactory.getLogger() + private val TAG: String = "UserDataAccessor" + private var policyContext: MutableTypedMap = MutableTypedMap() override val requiredConnectionTypes = setOf(UserDataReader::class.java) @@ -56,7 +46,6 @@ class UserDataAccessor : ProcessorNode { ) init { - policyContext = MutableTypedMap() policyContext[KidStatusEnabled] = false policyContext[LimitedAdsTrackingEnabled] = false @@ -67,12 +56,11 @@ class UserDataAccessor : ProcessorNode { fun getUserData(): UserData? { try { val userDataReader: UserDataReader? = - chronicleManager.chronicle.getConnectionOrThrow( - ConnectionRequest(UserDataReader::class.java, this, - DataIngressPolicy.NPA_DATA_POLICY) - ) - val userData: UserData? = userDataReader?.readUserData() - return userData + chronicleManager.chronicle.getConnectionOrThrow( + ConnectionRequest(UserDataReader::class.java, this, + DataIngressPolicy.NPA_DATA_POLICY) + ) + return userDataReader?.readUserData() } catch (e: ChronicleError) { sLogger.e(e, TAG + ": Expect success but connection failed with: ") return null diff --git a/src/com/android/ondevicepersonalization/services/policyengine/data/GeneratedDTD.kt b/src/com/android/ondevicepersonalization/services/policyengine/data/GeneratedDTD.kt index b29080d2..60e371cd 100644 --- a/src/com/android/ondevicepersonalization/services/policyengine/data/GeneratedDTD.kt +++ b/src/com/android/ondevicepersonalization/services/policyengine/data/GeneratedDTD.kt @@ -29,11 +29,9 @@ public val USER_DATA_GENERATED_DTD: DataTypeDescriptor = dataTypeDescriptor(name "availableStorageMB" to FieldType.Long "batteryPercentage" to FieldType.Integer "carrier" to FieldType.String - "connectionType" to FieldType.Integer - "connectionSpeedKbps" to FieldType.Long - "networkMetered" to FieldType.Boolean - "appInstallInfo" to FieldType.List(dataTypeDescriptor(name = - "chronicle_dtd.AppInstallInfo", cls = AppInstallInfo::class) { + "dataNetworkType" to FieldType.Integer + "appInfos" to FieldType.List(dataTypeDescriptor(name = + "chronicle_dtd.AppInfo", cls = AppInfo::class) { "packageName" to FieldType.String "installed" to FieldType.Boolean }) diff --git a/src/com/android/ondevicepersonalization/services/policyengine/data/UserData.kt b/src/com/android/ondevicepersonalization/services/policyengine/data/UserData.kt index 1dcf526a..f9961518 100644 --- a/src/com/android/ondevicepersonalization/services/policyengine/data/UserData.kt +++ b/src/com/android/ondevicepersonalization/services/policyengine/data/UserData.kt @@ -32,16 +32,14 @@ data class UserData ( val availableStorageBytes: Long, val batteryPercentage: Int, val carrier: String, - val connectionType: Int, - val connectionSpeedKbps: Long, - val networkMetered: Boolean, - val appInstallInfo: List<AppInstallInfo>, + val dataNetworkType: Int, + val appInfos: List<AppInfo>, val appUsageHistory: List<AppUsageStatus>, val currentLocation: Location, val locationHistory: List<LocationStatus>, ) -data class AppInstallInfo ( +data class AppInfo ( val packageName: String, val installed: Boolean ) diff --git a/src/com/android/ondevicepersonalization/services/policyengine/data/impl/UserDataConnectionProvider.kt b/src/com/android/ondevicepersonalization/services/policyengine/data/impl/UserDataConnectionProvider.kt index 01e21298..bc1184c0 100644 --- a/src/com/android/ondevicepersonalization/services/policyengine/data/impl/UserDataConnectionProvider.kt +++ b/src/com/android/ondevicepersonalization/services/policyengine/data/impl/UserDataConnectionProvider.kt @@ -17,17 +17,13 @@ package com.android.ondevicepersonalization.services.policyengine.data.impl import android.adservices.ondevicepersonalization.UserData -import android.adservices.ondevicepersonalization.OSVersion -import android.adservices.ondevicepersonalization.DeviceMetrics import android.adservices.ondevicepersonalization.Location -import android.adservices.ondevicepersonalization.AppInstallInfo +import android.adservices.ondevicepersonalization.AppInfo import android.adservices.ondevicepersonalization.AppUsageStatus import android.adservices.ondevicepersonalization.LocationStatus import android.util.ArrayMap -import com.android.ondevicepersonalization.services.data.user.UserDataDao import com.android.ondevicepersonalization.services.data.user.RawUserData -import com.android.ondevicepersonalization.services.data.user.UserDataCollector import com.android.libraries.pcc.chronicle.api.Connection import com.android.libraries.pcc.chronicle.api.ConnectionProvider import com.android.libraries.pcc.chronicle.api.ConnectionRequest @@ -41,9 +37,6 @@ import com.android.ondevicepersonalization.services.policyengine.data.UserDataRe import java.time.Duration -import kotlinx.coroutines.CoroutineDispatcher -import kotlinx.coroutines.withContext - /** [ConnectionProvider] implementation for ODA use data. */ class UserDataConnectionProvider() : ConnectionProvider { override val dataType: DataType = @@ -59,21 +52,15 @@ class UserDataConnectionProvider() : ConnectionProvider { class UserDataReaderImpl : UserDataReader { override fun readUserData(): UserData? { - val rawUserData: RawUserData? = RawUserData.getInstance(); - if (rawUserData == null) { - return null; - } - + val rawUserData: RawUserData = RawUserData.getInstance() ?: return null // TODO(b/267013762): more privacy-preserving processing may be needed - return UserData.Builder() + val builder: UserData.Builder = UserData.Builder() .setTimezoneUtcOffsetMins(rawUserData.utcOffset) .setOrientation(rawUserData.orientation) .setAvailableStorageBytes(rawUserData.availableStorageBytes) .setBatteryPercentage(rawUserData.batteryPercentage) .setCarrier(rawUserData.carrier.toString()) - .setConnectionType(rawUserData.connectionType.ordinal) - .setNetworkConnectionSpeedKbps(rawUserData.connectionSpeedKbps) - .setNetworkMetered(rawUserData.networkMetered) + .setDataNetworkType(rawUserData.dataNetworkType) .setCurrentLocation(Location.Builder() .setTimestampSeconds(rawUserData.currentLocation.timeMillis / 1000) .setLatitude(rawUserData.currentLocation.latitude) @@ -81,17 +68,21 @@ class UserDataConnectionProvider() : ConnectionProvider { .setLocationProvider(rawUserData.currentLocation.provider.ordinal) .setPreciseLocation(rawUserData.currentLocation.isPreciseLocation) .build()) - .setAppInstallInfo(getAppInstallInfo(rawUserData)) + .setAppInfos(getAppInfos(rawUserData)) .setAppUsageHistory(getAppUsageHistory(rawUserData)) .setLocationHistory(getLocationHistory(rawUserData)) - .build() + // TODO (b/299683848): follow up the codegen bug + if (rawUserData.networkCapabilities != null) { + builder.setNetworkCapabilities(rawUserData.networkCapabilities) + } + return builder.build() } - private fun getAppInstallInfo(rawUserData: RawUserData): Map<String, AppInstallInfo> { - var res = ArrayMap<String, AppInstallInfo>() + private fun getAppInfos(rawUserData: RawUserData): Map<String, AppInfo> { + var res = ArrayMap<String, AppInfo>() for (appInfo in rawUserData.appsInfo) { res.put(appInfo.packageName, - AppInstallInfo.Builder() + AppInfo.Builder() .setInstalled(appInfo.installed) .build()) } diff --git a/src/com/android/ondevicepersonalization/services/policyengine/policy/DataIngressPolicy.kt b/src/com/android/ondevicepersonalization/services/policyengine/policy/DataIngressPolicy.kt index 3a9fb867..2e569674 100644 --- a/src/com/android/ondevicepersonalization/services/policyengine/policy/DataIngressPolicy.kt +++ b/src/com/android/ondevicepersonalization/services/policyengine/policy/DataIngressPolicy.kt @@ -49,10 +49,8 @@ class DataIngressPolicy { "availableStorageMB" {rawUsage(UsageType.ANY)} "batteryPercentage" {rawUsage(UsageType.ANY)} "carrier" {rawUsage(UsageType.ANY)} - "connectionType" {rawUsage(UsageType.ANY)} - "connectionSpeedKbps" {rawUsage(UsageType.ANY)} - "networkMetered" {rawUsage(UsageType.ANY)} - "appInstallInfo" { + "dataNetworkType" {rawUsage(UsageType.ANY)} + "appInfos" { "packageName" {rawUsage(UsageType.ANY)} "installed" {rawUsage(UsageType.ANY)} } diff --git a/src/com/android/ondevicepersonalization/services/process/IsolatedServiceInfo.java b/src/com/android/ondevicepersonalization/services/process/IsolatedServiceInfo.java index b35cfcce..06af5ea5 100644 --- a/src/com/android/ondevicepersonalization/services/process/IsolatedServiceInfo.java +++ b/src/com/android/ondevicepersonalization/services/process/IsolatedServiceInfo.java @@ -22,13 +22,22 @@ import com.android.ondevicepersonalization.libraries.plugin.PluginController; /** Wraps an instance of a loaded isolated service */ public class IsolatedServiceInfo { + @NonNull private final long mStartTimeMillis; @NonNull private final PluginController mPluginController; - IsolatedServiceInfo(@NonNull PluginController pluginController) { + IsolatedServiceInfo( + long startTimeMillis, + @NonNull PluginController pluginController) { + mStartTimeMillis = startTimeMillis; mPluginController = pluginController; } PluginController getPluginController() { return mPluginController; } + + /** Returns the service start time. */ + public long getStartTimeMillis() { + return mStartTimeMillis; + } } diff --git a/src/com/android/ondevicepersonalization/services/process/OnDevicePersonalizationPlugin.java b/src/com/android/ondevicepersonalization/services/process/OnDevicePersonalizationPlugin.java index a891efa6..440fd832 100644 --- a/src/com/android/ondevicepersonalization/services/process/OnDevicePersonalizationPlugin.java +++ b/src/com/android/ondevicepersonalization/services/process/OnDevicePersonalizationPlugin.java @@ -56,21 +56,21 @@ public class OnDevicePersonalizationPlugin implements Plugin { mPluginContext = pluginContext; try { - String className = input.getString(ProcessUtils.PARAM_CLASS_NAME_KEY); + String className = input.getString(ProcessRunner.PARAM_CLASS_NAME_KEY); if (className == null || className.isEmpty()) { sLogger.e(TAG + ": className missing."); sendErrorResult(FailureType.ERROR_EXECUTING_PLUGIN); return; } - int operation = input.getInt(ProcessUtils.PARAM_OPERATION_KEY); + int operation = input.getInt(ProcessRunner.PARAM_OPERATION_KEY); if (operation == 0) { sLogger.e(TAG + ": operation missing or invalid."); sendErrorResult(FailureType.ERROR_EXECUTING_PLUGIN); return; } - Bundle serviceParams = input.getParcelable(ProcessUtils.PARAM_SERVICE_INPUT, + Bundle serviceParams = input.getParcelable(ProcessRunner.PARAM_SERVICE_INPUT, Bundle.class); if (serviceParams == null) { sLogger.e(TAG + ": Missing service input."); diff --git a/src/com/android/ondevicepersonalization/services/process/ProcessUtils.java b/src/com/android/ondevicepersonalization/services/process/ProcessRunner.java index 293cf207..8de09e94 100644 --- a/src/com/android/ondevicepersonalization/services/process/ProcessUtils.java +++ b/src/com/android/ondevicepersonalization/services/process/ProcessRunner.java @@ -17,7 +17,6 @@ package com.android.ondevicepersonalization.services.process; import android.adservices.ondevicepersonalization.Constants; -import android.adservices.ondevicepersonalization.OnDevicePersonalizationException; import android.annotation.NonNull; import android.annotation.Nullable; import android.content.Context; @@ -33,6 +32,10 @@ import com.android.ondevicepersonalization.libraries.plugin.PluginController; import com.android.ondevicepersonalization.libraries.plugin.PluginInfo; import com.android.ondevicepersonalization.libraries.plugin.PluginManager; import com.android.ondevicepersonalization.libraries.plugin.impl.PluginManagerImpl; +import com.android.ondevicepersonalization.services.OdpServiceException; +import com.android.ondevicepersonalization.services.OnDevicePersonalizationApplication; +import com.android.ondevicepersonalization.services.util.Clock; +import com.android.ondevicepersonalization.services.util.MonotonicClock; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.Futures; @@ -41,7 +44,7 @@ import com.google.common.util.concurrent.ListenableFuture; import java.util.Objects; /** Utilities to support loading and executing plugins. */ -public class ProcessUtils { +public class ProcessRunner { private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); private static final String TAG = "ProcessUtils"; private static final String ENTRY_POINT_CLASS = @@ -51,24 +54,59 @@ public class ProcessUtils { public static final String PARAM_OPERATION_KEY = "param.operation"; public static final String PARAM_SERVICE_INPUT = "param.service_input"; - private static PluginManager sPluginManager; + @NonNull private Context mApplicationContext; + + private static volatile ProcessRunner sProcessRunner; + private static volatile PluginManager sPluginManager; + + static class Injector { + Clock getClock() { + return MonotonicClock.getInstance(); + } + } + + private final Injector mInjector; + + /** Creates a ProcessRunner. */ + ProcessRunner( + @NonNull Context applicationContext, + @NonNull Injector injector) { + mApplicationContext = Objects.requireNonNull(applicationContext); + mInjector = Objects.requireNonNull(injector); + } + + /** Returns the global ProcessRunner */ + @NonNull public static ProcessRunner getInstance() { + if (sProcessRunner == null) { + synchronized (ProcessRunner.class) { + if (sProcessRunner == null) { + sProcessRunner = new ProcessRunner( + OnDevicePersonalizationApplication.getAppContext(), + new Injector()); + } + } + } + return sProcessRunner; + } /** Loads a service in an isolated process */ - @NonNull public static ListenableFuture<IsolatedServiceInfo> loadIsolatedService( - @NonNull String taskName, @NonNull String packageName, - @NonNull Context context) { + @NonNull public ListenableFuture<IsolatedServiceInfo> loadIsolatedService( + @NonNull String taskName, @NonNull String packageName) { try { sLogger.d(TAG + ": loadIsolatedService: " + packageName); - return loadPlugin(createPluginController( - createPluginId(packageName, taskName), - getPluginManager(context), packageName)); + return loadPlugin( + mInjector.getClock().elapsedRealtime(), + createPluginController( + createPluginId(packageName, taskName), + getPluginManager(mApplicationContext), + packageName)); } catch (Exception e) { return Futures.immediateFailedFuture(e); } } /** Executes a service loaded in an isolated process */ - @NonNull public static ListenableFuture<Bundle> runIsolatedService( + @NonNull public ListenableFuture<Bundle> runIsolatedService( @NonNull IsolatedServiceInfo isolatedProcessInfo, @NonNull String className, int operationCode, @@ -81,13 +119,22 @@ public class ProcessUtils { return executePlugin(isolatedProcessInfo.getPluginController(), pluginParams); } - @NonNull static PluginManager getPluginManager(@NonNull Context context) { - synchronized (ProcessUtils.class) { - if (sPluginManager == null) { - sPluginManager = new PluginManagerImpl(context); + /** Unloads a service loaded in an isolated process */ + @NonNull public ListenableFuture<Void> unloadIsolatedService( + @NonNull IsolatedServiceInfo isolatedServiceInfo) { + return unloadPlugin(isolatedServiceInfo.getPluginController()); + } + + @NonNull + static PluginManager getPluginManager(@NonNull Context applicationContext) { + if (sPluginManager == null) { + synchronized (ProcessRunner.class) { + if (sPluginManager == null) { + sPluginManager = new PluginManagerImpl(applicationContext); + } } - return sPluginManager; } + return sPluginManager; } @NonNull static PluginController createPluginController( @@ -99,6 +146,7 @@ public class ProcessUtils { } @NonNull static ListenableFuture<IsolatedServiceInfo> loadPlugin( + long startTimeMillis, @NonNull PluginController pluginController) { return CallbackToFutureAdapter.getFuture( completer -> { @@ -106,10 +154,11 @@ public class ProcessUtils { sLogger.d(TAG + ": loadPlugin"); pluginController.load(new PluginCallback() { @Override public void onSuccess(Bundle bundle) { - completer.set(new IsolatedServiceInfo(pluginController)); + completer.set(new IsolatedServiceInfo( + startTimeMillis, pluginController)); } @Override public void onFailure(FailureType failure) { - completer.setException(new OnDevicePersonalizationException( + completer.setException(new OdpServiceException( Constants.STATUS_INTERNAL_ERROR, String.format("loadPlugin failed. %s", failure.toString()))); } @@ -133,7 +182,31 @@ public class ProcessUtils { completer.set(bundle); } @Override public void onFailure(FailureType failure) { - completer.setException(new OnDevicePersonalizationException( + completer.setException(new OdpServiceException( + Constants.STATUS_INTERNAL_ERROR, + String.format("executePlugin failed: %s", failure.toString()))); + } + }); + } catch (Exception e) { + completer.setException(e); + } + return "executePlugin"; + } + ); + } + + @NonNull static ListenableFuture<Void> unloadPlugin( + @NonNull PluginController pluginController) { + return CallbackToFutureAdapter.getFuture( + completer -> { + try { + sLogger.d(TAG + ": unloadPlugin"); + pluginController.unload(new PluginCallback() { + @Override public void onSuccess(Bundle bundle) { + completer.set(null); + } + @Override public void onFailure(FailureType failure) { + completer.setException(new OdpServiceException( Constants.STATUS_INTERNAL_ERROR, String.format("executePlugin failed: %s", failure.toString()))); } @@ -161,6 +234,4 @@ public class ProcessUtils { // TODO(b/249345663) Perform any validation needed on the input. return vendorPackageName + "-" + taskName; } - - private ProcessUtils() {} } diff --git a/src/com/android/ondevicepersonalization/services/request/AppRequestFlow.java b/src/com/android/ondevicepersonalization/services/request/AppRequestFlow.java index e20d7a80..88d94214 100644 --- a/src/com/android/ondevicepersonalization/services/request/AppRequestFlow.java +++ b/src/com/android/ondevicepersonalization/services/request/AppRequestFlow.java @@ -17,14 +17,16 @@ package com.android.ondevicepersonalization.services.request; import android.adservices.ondevicepersonalization.Constants; -import android.adservices.ondevicepersonalization.ExecuteInput; -import android.adservices.ondevicepersonalization.ExecuteOutput; -import android.adservices.ondevicepersonalization.OnDevicePersonalizationException; +import android.adservices.ondevicepersonalization.EventLogRecord; +import android.adservices.ondevicepersonalization.ExecuteInputParcel; +import android.adservices.ondevicepersonalization.ExecuteOutputParcel; import android.adservices.ondevicepersonalization.RenderingConfig; +import android.adservices.ondevicepersonalization.RequestLogRecord; import android.adservices.ondevicepersonalization.UserData; import android.adservices.ondevicepersonalization.aidl.IExecuteCallback; import android.annotation.NonNull; import android.content.ComponentName; +import android.content.ContentValues; import android.content.Context; import android.os.Bundle; import android.os.PersistableBundle; @@ -32,17 +34,27 @@ import android.os.RemoteException; import com.android.internal.annotations.VisibleForTesting; import com.android.ondevicepersonalization.internal.util.LoggerFactory; +import com.android.ondevicepersonalization.services.Flags; +import com.android.ondevicepersonalization.services.FlagsFactory; import com.android.ondevicepersonalization.services.OnDevicePersonalizationExecutors; import com.android.ondevicepersonalization.services.data.DataAccessServiceImpl; +import com.android.ondevicepersonalization.services.data.events.Event; import com.android.ondevicepersonalization.services.data.events.EventsDao; import com.android.ondevicepersonalization.services.data.events.Query; +import com.android.ondevicepersonalization.services.data.user.UserPrivacyStatus; +import com.android.ondevicepersonalization.services.federatedcompute.FederatedComputeServiceImpl; import com.android.ondevicepersonalization.services.manifest.AppManifestConfig; import com.android.ondevicepersonalization.services.manifest.AppManifestConfigHelper; import com.android.ondevicepersonalization.services.policyengine.UserDataAccessor; import com.android.ondevicepersonalization.services.process.IsolatedServiceInfo; -import com.android.ondevicepersonalization.services.process.ProcessUtils; +import com.android.ondevicepersonalization.services.process.ProcessRunner; +import com.android.ondevicepersonalization.services.statsd.ApiCallStats; +import com.android.ondevicepersonalization.services.statsd.OdpStatsdLogger; +import com.android.ondevicepersonalization.services.util.Clock; import com.android.ondevicepersonalization.services.util.CryptUtils; +import com.android.ondevicepersonalization.services.util.MonotonicClock; import com.android.ondevicepersonalization.services.util.OnDevicePersonalizationFlatbufferUtils; +import com.android.ondevicepersonalization.services.util.StatsUtils; import com.google.common.util.concurrent.AsyncCallable; import com.google.common.util.concurrent.FluentFuture; @@ -50,10 +62,12 @@ import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; +import com.google.common.util.concurrent.ListeningScheduledExecutorService; import java.util.ArrayList; import java.util.List; import java.util.Objects; +import java.util.concurrent.TimeUnit; /** * Handles a surface package request from an app or SDK. @@ -72,20 +86,46 @@ public class AppRequestFlow { private final IExecuteCallback mCallback; @NonNull private final Context mContext; + private final long mStartTimeMillis; @NonNull private String mServiceClassName; + @VisibleForTesting + static class Injector { + ListeningExecutorService getExecutor() { + return OnDevicePersonalizationExecutors.getBackgroundExecutor(); + } + + Clock getClock() { + return MonotonicClock.getInstance(); + } + + Flags getFlags() { + return FlagsFactory.getFlags(); + } + + ListeningScheduledExecutorService getScheduledExecutor() { + return OnDevicePersonalizationExecutors.getScheduledExecutor(); + } + + ProcessRunner getProcessRunner() { + return ProcessRunner.getInstance(); + } + } + @NonNull - private final ListeningExecutorService mExecutorService; + private final Injector mInjector; public AppRequestFlow( @NonNull String callingPackageName, @NonNull ComponentName service, @NonNull PersistableBundle params, @NonNull IExecuteCallback callback, - @NonNull Context context) { + @NonNull Context context, + long startTimeMillis) { this(callingPackageName, service, params, - callback, context, OnDevicePersonalizationExecutors.getBackgroundExecutor()); + callback, context, startTimeMillis, + new Injector()); } @VisibleForTesting @@ -95,61 +135,79 @@ public class AppRequestFlow { @NonNull PersistableBundle params, @NonNull IExecuteCallback callback, @NonNull Context context, - @NonNull ListeningExecutorService executorService) { + long startTimeMillis, + @NonNull Injector injector) { sLogger.d(TAG + ": AppRequestFlow created."); mCallingPackageName = Objects.requireNonNull(callingPackageName); mService = Objects.requireNonNull(service); mParams = Objects.requireNonNull(params); mCallback = Objects.requireNonNull(callback); mContext = Objects.requireNonNull(context); - mExecutorService = Objects.requireNonNull(executorService); + mStartTimeMillis = startTimeMillis; + mInjector = Objects.requireNonNull(injector); } /** Runs the request processing flow. */ public void run() { - var unused = Futures.submit(() -> this.processRequest(), mExecutorService); + var unused = Futures.submit(() -> this.processRequest(), mInjector.getExecutor()); } private void processRequest() { try { - AppManifestConfig config = Objects.requireNonNull( - AppManifestConfigHelper.getAppManifestConfig( + if (!isPersonalizationStatusEnabled()) { + sLogger.d(TAG + ": Personalization is disabled."); + sendErrorResult(Constants.STATUS_PERSONALIZATION_DISABLED); + return; + } + AppManifestConfig config = null; + try { + config = Objects.requireNonNull( + AppManifestConfigHelper.getAppManifestConfig( mContext, mService.getPackageName())); + } catch (Exception e) { + sLogger.d(TAG + ": Failed to read manifest.", e); + sendErrorResult(Constants.STATUS_NAME_NOT_FOUND); + return; + } if (!mService.getClassName().equals(config.getServiceName())) { - // TODO(b/228200518): Define a new error code and map it to a specific - // exception type in the client API. - throw new OnDevicePersonalizationException( - Constants.STATUS_INTERNAL_ERROR, - "Name not found: " + mService.getClassName() - + " expected: " + config.getServiceName()); + sLogger.d(TAG + "service class not found"); + sendErrorResult(Constants.STATUS_CLASS_NOT_FOUND); + return; } mServiceClassName = Objects.requireNonNull(config.getServiceName()); - ListenableFuture<ExecuteOutput> resultFuture = FluentFuture.from( - ProcessUtils.loadIsolatedService( - TASK_NAME, mService.getPackageName(), mContext)) + ListenableFuture<IsolatedServiceInfo> loadFuture = + mInjector.getProcessRunner().loadIsolatedService( + TASK_NAME, mService.getPackageName()); + ListenableFuture<ExecuteOutputParcel> resultFuture = FluentFuture.from(loadFuture) .transformAsync( result -> executeAppRequest(result), - mExecutorService + mInjector.getExecutor() ) .transform( result -> { return result.getParcelable( - Constants.EXTRA_RESULT, ExecuteOutput.class); + Constants.EXTRA_RESULT, ExecuteOutputParcel.class); }, - mExecutorService + mInjector.getExecutor() ); ListenableFuture<Long> queryIdFuture = FluentFuture.from(resultFuture) - .transformAsync(input -> logQuery(input), mExecutorService); + .transformAsync(input -> logQuery(input), mInjector.getExecutor()); ListenableFuture<List<String>> slotResultTokensFuture = - Futures.whenAllSucceed(resultFuture, queryIdFuture) - .callAsync(new AsyncCallable<List<String>>() { - @Override - public ListenableFuture<List<String>> call() { - return createTokens(resultFuture, queryIdFuture); - } - }, mExecutorService); + FluentFuture.from( + Futures.whenAllSucceed(resultFuture, queryIdFuture) + .callAsync(new AsyncCallable<List<String>>() { + @Override + public ListenableFuture<List<String>> call() { + return createTokens(resultFuture, queryIdFuture); + } + }, mInjector.getExecutor())) + .withTimeout( + mInjector.getFlags().getIsolatedServiceDeadlineSeconds(), + TimeUnit.SECONDS, + mInjector.getScheduledExecutor() + ); Futures.addCallback( slotResultTokensFuture, @@ -165,55 +223,123 @@ public class AppRequestFlow { sendErrorResult(Constants.STATUS_INTERNAL_ERROR); } }, - mExecutorService); + mInjector.getExecutor()); + + var unused = Futures.whenAllComplete(loadFuture, slotResultTokensFuture) + .callAsync(() -> mInjector.getProcessRunner().unloadIsolatedService( + loadFuture.get()), + mInjector.getExecutor()); } catch (Exception e) { sLogger.e(TAG + ": Could not process request.", e); sendErrorResult(Constants.STATUS_INTERNAL_ERROR); } } - private ListenableFuture<Bundle> executeAppRequest(IsolatedServiceInfo isolatedServiceInfo) { + private ListenableFuture<Bundle> executeAppRequest( + IsolatedServiceInfo isolatedServiceInfo) { sLogger.d(TAG + ": executeAppRequest() started."); Bundle serviceParams = new Bundle(); - ExecuteInput input = - new ExecuteInput.Builder() + ExecuteInputParcel input = + new ExecuteInputParcel.Builder() .setAppPackageName(mCallingPackageName) .setAppParams(mParams) .build(); serviceParams.putParcelable(Constants.EXTRA_INPUT, input); DataAccessServiceImpl binder = new DataAccessServiceImpl( - mService.getPackageName(), mContext, true); + mService.getPackageName(), mContext, /* includeLocalData */ true, + /* includeEventData */ true); serviceParams.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, binder); + FederatedComputeServiceImpl fcpBinder = new FederatedComputeServiceImpl( + mService.getPackageName(), mContext); + serviceParams.putBinder(Constants.EXTRA_FEDERATED_COMPUTE_SERVICE_BINDER, fcpBinder); UserDataAccessor userDataAccessor = new UserDataAccessor(); UserData userData = userDataAccessor.getUserData(); serviceParams.putParcelable(Constants.EXTRA_USER_DATA, userData); - return ProcessUtils.runIsolatedService( + ListenableFuture<Bundle> result = mInjector.getProcessRunner().runIsolatedService( isolatedServiceInfo, mServiceClassName, Constants.OP_EXECUTE, serviceParams); + return FluentFuture.from(result) + .transform( + val -> { + writeServiceRequestMetrics( + val, isolatedServiceInfo.getStartTimeMillis(), + Constants.STATUS_SUCCESS); + return val; + }, + mInjector.getExecutor() + ) + .catchingAsync( + Exception.class, + e -> { + writeServiceRequestMetrics( + null, isolatedServiceInfo.getStartTimeMillis(), + Constants.STATUS_INTERNAL_ERROR); + return Futures.immediateFailedFuture(e); + }, + mInjector.getExecutor() + ); } - private ListenableFuture<Long> logQuery(ExecuteOutput result) { + private ListenableFuture<Long> logQuery(ExecuteOutputParcel result) { sLogger.d(TAG + ": logQuery() started."); - // TODO(b/228200518): Extract log data from ExecuteOutput. + EventsDao eventsDao = EventsDao.getInstance(mContext); + // Insert query + List<ContentValues> rows = null; + if (result.getRequestLogRecord() != null) { + rows = result.getRequestLogRecord().getRows(); + } byte[] queryData = OnDevicePersonalizationFlatbufferUtils.createQueryData( - mService.getPackageName(), null, result.getRequestLogRecord().getRows()); + mService.getPackageName(), null, rows); Query query = new Query.Builder() .setServicePackageName(mService.getPackageName()) .setQueryData(queryData) .setTimeMillis(System.currentTimeMillis()) .build(); - long queryId = EventsDao.getInstance(mContext).insertQuery(query); + long queryId = eventsDao.insertQuery(query); if (queryId == -1) { return Futures.immediateFailedFuture(new RuntimeException("Failed to log query.")); } + // Insert events + List<Event> events = new ArrayList<>(); + List<EventLogRecord> eventLogRecords = result.getEventLogRecords(); + for (EventLogRecord eventLogRecord : eventLogRecords) { + RequestLogRecord requestLogRecord = eventLogRecord.getRequestLogRecord(); + // Verify requestLogRecord exists and has the corresponding rowIndex + if (requestLogRecord == null || requestLogRecord.getRequestId() == 0 + || eventLogRecord.getRowIndex() >= requestLogRecord.getRows().size()) { + continue; + } + // Make sure query exists for package in QUERY table + Query queryRow = eventsDao.readSingleQueryRow(requestLogRecord.getRequestId(), + mService.getPackageName()); + if (queryRow == null || eventLogRecord.getRowIndex() + >= OnDevicePersonalizationFlatbufferUtils.getContentValuesLengthFromQueryData( + queryRow.getQueryData())) { + continue; + } + Event event = new Event.Builder() + .setEventData(OnDevicePersonalizationFlatbufferUtils.createEventData( + eventLogRecord.getData())) + .setQueryId(requestLogRecord.getRequestId()) + .setRowIndex(eventLogRecord.getRowIndex()) + .setServicePackageName(mService.getPackageName()) + .setTimeMillis(System.currentTimeMillis()) + .setType(eventLogRecord.getType()) + .build(); + events.add(event); + } + if (!eventsDao.insertEvents(events)) { + return Futures.immediateFailedFuture(new RuntimeException("Failed to log events.")); + } + return Futures.immediateFuture(queryId); } private ListenableFuture<List<String>> createTokens( - ListenableFuture<ExecuteOutput> resultFuture, + ListenableFuture<ExecuteOutputParcel> resultFuture, ListenableFuture<Long> queryIdFuture) { try { sLogger.d(TAG + ": createTokens() started."); - ExecuteOutput result = Futures.getDone(resultFuture); + ExecuteOutputParcel result = Futures.getDone(resultFuture); long queryId = Futures.getDone(queryIdFuture); List<RenderingConfig> renderingConfigs = result.getRenderingConfigs(); Objects.requireNonNull(renderingConfigs); @@ -239,15 +365,23 @@ public class AppRequestFlow { } private void sendResult(List<String> slotResultTokens) { + if (slotResultTokens != null) { + sendSuccessResult(slotResultTokens); + } else { + sLogger.w(TAG + ": slotResultTokens is null or empty"); + sendErrorResult(Constants.STATUS_INTERNAL_ERROR); + } + } + + private void sendSuccessResult(List<String> slotResultTokens) { + int responseCode = Constants.STATUS_SUCCESS; try { - if (slotResultTokens != null && slotResultTokens.size() > 0) { - mCallback.onSuccess(slotResultTokens); - } else { - sLogger.w(TAG + ": slotResultTokens is null or empty"); - sendErrorResult(Constants.STATUS_INTERNAL_ERROR); - } + mCallback.onSuccess(slotResultTokens); } catch (RemoteException e) { + responseCode = Constants.STATUS_INTERNAL_ERROR; sLogger.w(TAG + ": Callback error", e); + } finally { + writeAppRequestMetrics(responseCode); } } @@ -256,6 +390,36 @@ public class AppRequestFlow { mCallback.onError(errorCode); } catch (RemoteException e) { sLogger.w(TAG + ": Callback error", e); + } finally { + writeAppRequestMetrics(errorCode); } } + + private void writeAppRequestMetrics(int responseCode) { + int latencyMillis = (int) (mInjector.getClock().elapsedRealtime() - mStartTimeMillis); + ApiCallStats callStats = new ApiCallStats.Builder(ApiCallStats.API_EXECUTE) + .setLatencyMillis(latencyMillis) + .setResponseCode(responseCode) + .build(); + OdpStatsdLogger.getInstance().logApiCallStats(callStats); + } + + private void writeServiceRequestMetrics(Bundle result, long startTimeMillis, int responseCode) { + int latencyMillis = (int) (mInjector.getClock().elapsedRealtime() - startTimeMillis); + int overheadLatencyMillis = + (int) StatsUtils.getOverheadLatencyMillis(latencyMillis, result); + ApiCallStats callStats = new ApiCallStats.Builder(ApiCallStats.API_SERVICE_ON_EXECUTE) + .setLatencyMillis(latencyMillis) + .setOverheadLatencyMillis(overheadLatencyMillis) + .setResponseCode(responseCode) + .build(); + OdpStatsdLogger.getInstance().logApiCallStats(callStats); + } + + private boolean isPersonalizationStatusEnabled() { + UserPrivacyStatus privacyStatus = UserPrivacyStatus.getInstance(); + return privacyStatus.isPersonalizationStatusEnabled(); + } } + + diff --git a/src/com/android/ondevicepersonalization/services/request/RenderFlow.java b/src/com/android/ondevicepersonalization/services/request/RenderFlow.java index 0052c09f..d88e67a5 100644 --- a/src/com/android/ondevicepersonalization/services/request/RenderFlow.java +++ b/src/com/android/ondevicepersonalization/services/request/RenderFlow.java @@ -17,8 +17,8 @@ package com.android.ondevicepersonalization.services.request; import android.adservices.ondevicepersonalization.Constants; -import android.adservices.ondevicepersonalization.RenderInput; -import android.adservices.ondevicepersonalization.RenderOutput; +import android.adservices.ondevicepersonalization.RenderInputParcel; +import android.adservices.ondevicepersonalization.RenderOutputParcel; import android.adservices.ondevicepersonalization.RenderingConfig; import android.adservices.ondevicepersonalization.RequestLogRecord; import android.adservices.ondevicepersonalization.aidl.IRequestSurfacePackageCallback; @@ -31,21 +31,31 @@ import android.view.SurfaceControlViewHost.SurfacePackage; import com.android.internal.annotations.VisibleForTesting; import com.android.ondevicepersonalization.internal.util.LoggerFactory; +import com.android.ondevicepersonalization.services.Flags; +import com.android.ondevicepersonalization.services.FlagsFactory; import com.android.ondevicepersonalization.services.OnDevicePersonalizationExecutors; import com.android.ondevicepersonalization.services.data.DataAccessServiceImpl; +import com.android.ondevicepersonalization.services.data.user.UserPrivacyStatus; import com.android.ondevicepersonalization.services.display.DisplayHelper; import com.android.ondevicepersonalization.services.manifest.AppManifestConfigHelper; import com.android.ondevicepersonalization.services.process.IsolatedServiceInfo; -import com.android.ondevicepersonalization.services.process.ProcessUtils; +import com.android.ondevicepersonalization.services.process.ProcessRunner; +import com.android.ondevicepersonalization.services.statsd.ApiCallStats; +import com.android.ondevicepersonalization.services.statsd.OdpStatsdLogger; +import com.android.ondevicepersonalization.services.util.Clock; import com.android.ondevicepersonalization.services.util.CryptUtils; +import com.android.ondevicepersonalization.services.util.MonotonicClock; +import com.android.ondevicepersonalization.services.util.StatsUtils; import com.google.common.util.concurrent.FluentFuture; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; +import com.google.common.util.concurrent.ListeningScheduledExecutorService; import java.util.Objects; +import java.util.concurrent.TimeUnit; /** * Handles a surface package request from an app or SDK. @@ -64,6 +74,22 @@ public class RenderFlow { SlotWrapper decryptToken(String slotResultToken) throws Exception { return (SlotWrapper) CryptUtils.decrypt(slotResultToken); } + + Clock getClock() { + return MonotonicClock.getInstance(); + } + + Flags getFlags() { + return FlagsFactory.getFlags(); + } + + ListeningScheduledExecutorService getScheduledExecutor() { + return OnDevicePersonalizationExecutors.getScheduledExecutor(); + } + + ProcessRunner getProcessRunner() { + return ProcessRunner.getInstance(); + } } @NonNull @@ -77,6 +103,7 @@ public class RenderFlow { private final IRequestSurfacePackageCallback mCallback; @NonNull private final Context mContext; + private final long mStartTimeMillis; @NonNull private final Injector mInjector; @NonNull @@ -93,9 +120,10 @@ public class RenderFlow { int width, int height, @NonNull IRequestSurfacePackageCallback callback, - @NonNull Context context) { + @NonNull Context context, + long startTimeMillis) { this(slotResultToken, hostToken, displayId, width, height, - callback, context, + callback, context, startTimeMillis, new Injector(), new DisplayHelper(context)); } @@ -109,6 +137,7 @@ public class RenderFlow { int height, @NonNull IRequestSurfacePackageCallback callback, @NonNull Context context, + long startTimeMillis, @NonNull Injector injector, @NonNull DisplayHelper displayHelper) { sLogger.d(TAG + ": RenderFlow created."); @@ -118,6 +147,7 @@ public class RenderFlow { mWidth = width; mHeight = height; mCallback = Objects.requireNonNull(callback); + mStartTimeMillis = startTimeMillis; mInjector = Objects.requireNonNull(injector); mContext = Objects.requireNonNull(context); mDisplayHelper = Objects.requireNonNull(displayHelper); @@ -130,6 +160,11 @@ public class RenderFlow { private void processRequest() { try { + if (!isPersonalizationStatusEnabled()) { + sLogger.d(TAG + ": Personalization is disabled."); + sendErrorResult(Constants.STATUS_PERSONALIZATION_DISABLED); + return; + } SlotWrapper slotWrapper = Objects.requireNonNull( mInjector.decryptToken(mSlotResultToken)); mServicePackageName = Objects.requireNonNull( @@ -138,8 +173,16 @@ public class RenderFlow { AppManifestConfigHelper.getServiceNameFromOdpSettings( mContext, mServicePackageName)); + ListenableFuture<IsolatedServiceInfo> loadFuture = + mInjector.getProcessRunner().loadIsolatedService( + TASK_NAME, mServicePackageName); ListenableFuture<SurfacePackage> surfacePackageFuture = - renderContentForSlot(slotWrapper); + FluentFuture.from(renderContentForSlot(loadFuture, slotWrapper)) + .withTimeout( + mInjector.getFlags().getIsolatedServiceDeadlineSeconds(), + TimeUnit.SECONDS, + mInjector.getScheduledExecutor() + ); Futures.addCallback( surfacePackageFuture, @@ -156,6 +199,11 @@ public class RenderFlow { } }, mInjector.getExecutor()); + + var unused = Futures.whenAllComplete(loadFuture, surfacePackageFuture) + .callAsync(() -> mInjector.getProcessRunner().unloadIsolatedService( + loadFuture.get()), + mInjector.getExecutor()); } catch (Exception e) { sLogger.e(TAG + ": Could not process request.", e); sendErrorResult(Constants.STATUS_INTERNAL_ERROR); @@ -163,6 +211,7 @@ public class RenderFlow { } private ListenableFuture<SurfacePackage> renderContentForSlot( + @NonNull ListenableFuture<IsolatedServiceInfo> loadFuture, @NonNull SlotWrapper slotWrapper ) { try { @@ -173,15 +222,14 @@ public class RenderFlow { Objects.requireNonNull(slotWrapper.getRenderingConfig()); long queryId = slotWrapper.getQueryId(); - return FluentFuture.from(ProcessUtils.loadIsolatedService( - TASK_NAME, mServicePackageName, mContext)) + return FluentFuture.from(loadFuture) .transformAsync( loadResult -> executeRenderContentRequest( loadResult, slotWrapper.getSlotIndex(), renderingConfig), mInjector.getExecutor()) .transform(result -> { return result.getParcelable( - Constants.EXTRA_RESULT, RenderOutput.class); + Constants.EXTRA_RESULT, RenderOutputParcel.class); }, mInjector.getExecutor()) .transform( result -> mDisplayHelper.generateHtml(result, mServicePackageName), @@ -207,8 +255,8 @@ public class RenderFlow { RenderingConfig renderingConfig) { sLogger.d(TAG + "executeRenderContentRequest() started."); Bundle serviceParams = new Bundle(); - RenderInput input = - new RenderInput.Builder() + RenderInputParcel input = + new RenderInputParcel.Builder() .setHeight(mHeight) .setWidth(mWidth) .setRenderingConfigIndex(slotIndex) @@ -216,23 +264,52 @@ public class RenderFlow { .build(); serviceParams.putParcelable(Constants.EXTRA_INPUT, input); DataAccessServiceImpl binder = new DataAccessServiceImpl( - mServicePackageName, mContext, false); + mServicePackageName, mContext, /* includeLocalData */ false, + /* includeEventData */ false); serviceParams.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, binder); - return ProcessUtils.runIsolatedService( + ListenableFuture<Bundle> result = mInjector.getProcessRunner().runIsolatedService( isolatedServiceInfo, mServiceClassName, Constants.OP_RENDER, serviceParams); + return FluentFuture.from(result) + .transform( + val -> { + writeServiceRequestMetrics( + val, isolatedServiceInfo.getStartTimeMillis(), + Constants.STATUS_SUCCESS); + return val; + }, + mInjector.getExecutor() + ) + .catchingAsync( + Exception.class, + e -> { + writeServiceRequestMetrics( + null, isolatedServiceInfo.getStartTimeMillis(), + Constants.STATUS_INTERNAL_ERROR); + return Futures.immediateFailedFuture(e); + }, + mInjector.getExecutor() + ); } private void sendDisplayResult(SurfacePackage surfacePackage) { + if (surfacePackage != null) { + sendSuccessResult(surfacePackage); + } else { + sLogger.w(TAG + ": surfacePackages is null or empty"); + sendErrorResult(Constants.STATUS_INTERNAL_ERROR); + } + } + + private void sendSuccessResult(SurfacePackage surfacePackage) { + int responseCode = Constants.STATUS_SUCCESS; try { - if (surfacePackage != null) { - mCallback.onSuccess(surfacePackage); - } else { - sLogger.w(TAG + ": surfacePackages is null or empty"); - sendErrorResult(Constants.STATUS_INTERNAL_ERROR); - } + mCallback.onSuccess(surfacePackage); } catch (RemoteException e) { + responseCode = Constants.STATUS_INTERNAL_ERROR; sLogger.w(TAG + ": Callback error", e); + } finally { + writeAppRequestMetrics(responseCode); } } @@ -241,6 +318,34 @@ public class RenderFlow { mCallback.onError(errorCode); } catch (RemoteException e) { sLogger.w(TAG + ": Callback error", e); + } finally { + writeAppRequestMetrics(errorCode); } } + + private void writeAppRequestMetrics(int responseCode) { + int latencyMillis = (int) (mInjector.getClock().elapsedRealtime() - mStartTimeMillis); + ApiCallStats callStats = new ApiCallStats.Builder(ApiCallStats.API_REQUEST_SURFACE_PACKAGE) + .setLatencyMillis(latencyMillis) + .setResponseCode(responseCode) + .build(); + OdpStatsdLogger.getInstance().logApiCallStats(callStats); + } + + private void writeServiceRequestMetrics(Bundle result, long startTimeMillis, int responseCode) { + int latencyMillis = (int) (mInjector.getClock().elapsedRealtime() - startTimeMillis); + int overheadLatencyMillis = + (int) StatsUtils.getOverheadLatencyMillis(latencyMillis, result); + ApiCallStats callStats = new ApiCallStats.Builder(ApiCallStats.API_SERVICE_ON_RENDER) + .setLatencyMillis(latencyMillis) + .setOverheadLatencyMillis(overheadLatencyMillis) + .setResponseCode(responseCode) + .build(); + OdpStatsdLogger.getInstance().logApiCallStats(callStats); + } + + private boolean isPersonalizationStatusEnabled() { + UserPrivacyStatus privacyStatus = UserPrivacyStatus.getInstance(); + return privacyStatus.isPersonalizationStatusEnabled(); + } } diff --git a/src/com/android/ondevicepersonalization/services/statsd/ApiCallStats.java b/src/com/android/ondevicepersonalization/services/statsd/ApiCallStats.java new file mode 100644 index 00000000..56c1b531 --- /dev/null +++ b/src/com/android/ondevicepersonalization/services/statsd/ApiCallStats.java @@ -0,0 +1,326 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.ondevicepersonalization.services.statsd; + +import static com.android.ondevicepersonalization.OnDevicePersonalizationStatsLog.ON_DEVICE_PERSONALIZATION_API_CALLED__API_CLASS__UNKNOWN; +import static com.android.ondevicepersonalization.OnDevicePersonalizationStatsLog.ON_DEVICE_PERSONALIZATION_API_CALLED__API_NAME__EXECUTE; +import static com.android.ondevicepersonalization.OnDevicePersonalizationStatsLog.ON_DEVICE_PERSONALIZATION_API_CALLED__API_NAME__REQUEST_SURFACE_PACKAGE; +import static com.android.ondevicepersonalization.OnDevicePersonalizationStatsLog.ON_DEVICE_PERSONALIZATION_API_CALLED__API_NAME__SERVICE_ON_DOWNLOAD_COMPLETED; +import static com.android.ondevicepersonalization.OnDevicePersonalizationStatsLog.ON_DEVICE_PERSONALIZATION_API_CALLED__API_NAME__SERVICE_ON_EVENT; +import static com.android.ondevicepersonalization.OnDevicePersonalizationStatsLog.ON_DEVICE_PERSONALIZATION_API_CALLED__API_NAME__SERVICE_ON_EXECUTE; +import static com.android.ondevicepersonalization.OnDevicePersonalizationStatsLog.ON_DEVICE_PERSONALIZATION_API_CALLED__API_NAME__SERVICE_ON_RENDER; +import static com.android.ondevicepersonalization.OnDevicePersonalizationStatsLog.ON_DEVICE_PERSONALIZATION_API_CALLED__API_NAME__SERVICE_ON_TRAINING_EXAMPLE; + +import com.android.ondevicepersonalization.internal.util.DataClass; + +/** + * Class holds OnDevicePersonalizationApiCalled defined at + * frameworks/proto_logging/stats/atoms/ondevicepersonalization/ondevicepersonalization_extension_atoms.proto + */ +@DataClass(genBuilder = true, genEqualsHashCode = true) +public class ApiCallStats { + public static final int API_EXECUTE = + ON_DEVICE_PERSONALIZATION_API_CALLED__API_NAME__EXECUTE; + public static final int API_REQUEST_SURFACE_PACKAGE = + ON_DEVICE_PERSONALIZATION_API_CALLED__API_NAME__REQUEST_SURFACE_PACKAGE; + public static final int API_SERVICE_ON_EXECUTE = + ON_DEVICE_PERSONALIZATION_API_CALLED__API_NAME__SERVICE_ON_EXECUTE; + public static final int API_SERVICE_ON_DOWNLOAD_COMPLETED = + ON_DEVICE_PERSONALIZATION_API_CALLED__API_NAME__SERVICE_ON_DOWNLOAD_COMPLETED; + public static final int API_SERVICE_ON_RENDER = + ON_DEVICE_PERSONALIZATION_API_CALLED__API_NAME__SERVICE_ON_RENDER; + public static final int API_SERVICE_ON_EVENT = + ON_DEVICE_PERSONALIZATION_API_CALLED__API_NAME__SERVICE_ON_EVENT; + public static final int API_SERVICE_ON_TRAINING_EXAMPLE = + ON_DEVICE_PERSONALIZATION_API_CALLED__API_NAME__SERVICE_ON_TRAINING_EXAMPLE; + + private int mApiClass = ON_DEVICE_PERSONALIZATION_API_CALLED__API_CLASS__UNKNOWN; + private final @Api int mApiName; + private int mLatencyMillis = 0; + private int mResponseCode = 0; + private int mOverheadLatencyMillis = 0; + + + + // Code below generated by codegen v1.0.23. + // + // DO NOT MODIFY! + // CHECKSTYLE:OFF Generated code + // + // To regenerate run: + // $ codegen $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/src/com/android/ondevicepersonalization/services/statsd/ApiCallStats.java + // + // To exclude the generated code from IntelliJ auto-formatting enable (one-time): + // Settings > Editor > Code Style > Formatter Control + //@formatter:off + + + @android.annotation.IntDef(prefix = "API_", value = { + API_EXECUTE, + API_REQUEST_SURFACE_PACKAGE, + API_SERVICE_ON_EXECUTE, + API_SERVICE_ON_DOWNLOAD_COMPLETED, + API_SERVICE_ON_RENDER, + API_SERVICE_ON_EVENT, + API_SERVICE_ON_TRAINING_EXAMPLE + }) + @java.lang.annotation.Retention(java.lang.annotation.RetentionPolicy.SOURCE) + @DataClass.Generated.Member + public @interface Api {} + + @DataClass.Generated.Member + public static String apiToString(@Api int value) { + switch (value) { + case API_EXECUTE: + return "API_EXECUTE"; + case API_REQUEST_SURFACE_PACKAGE: + return "API_REQUEST_SURFACE_PACKAGE"; + case API_SERVICE_ON_EXECUTE: + return "API_SERVICE_ON_EXECUTE"; + case API_SERVICE_ON_DOWNLOAD_COMPLETED: + return "API_SERVICE_ON_DOWNLOAD_COMPLETED"; + case API_SERVICE_ON_RENDER: + return "API_SERVICE_ON_RENDER"; + case API_SERVICE_ON_EVENT: + return "API_SERVICE_ON_EVENT"; + case API_SERVICE_ON_TRAINING_EXAMPLE: + return "API_SERVICE_ON_TRAINING_EXAMPLE"; + default: return Integer.toHexString(value); + } + } + + @DataClass.Generated.Member + /* package-private */ ApiCallStats( + int apiClass, + @Api int apiName, + int latencyMillis, + int responseCode, + int overheadLatencyMillis) { + this.mApiClass = apiClass; + this.mApiName = apiName; + + if (!(mApiName == API_EXECUTE) + && !(mApiName == API_REQUEST_SURFACE_PACKAGE) + && !(mApiName == API_SERVICE_ON_EXECUTE) + && !(mApiName == API_SERVICE_ON_DOWNLOAD_COMPLETED) + && !(mApiName == API_SERVICE_ON_RENDER) + && !(mApiName == API_SERVICE_ON_EVENT) + && !(mApiName == API_SERVICE_ON_TRAINING_EXAMPLE)) { + throw new java.lang.IllegalArgumentException( + "apiName was " + mApiName + " but must be one of: " + + "API_EXECUTE(" + API_EXECUTE + "), " + + "API_REQUEST_SURFACE_PACKAGE(" + API_REQUEST_SURFACE_PACKAGE + "), " + + "API_SERVICE_ON_EXECUTE(" + API_SERVICE_ON_EXECUTE + "), " + + "API_SERVICE_ON_DOWNLOAD_COMPLETED(" + API_SERVICE_ON_DOWNLOAD_COMPLETED + "), " + + "API_SERVICE_ON_RENDER(" + API_SERVICE_ON_RENDER + "), " + + "API_SERVICE_ON_EVENT(" + API_SERVICE_ON_EVENT + "), " + + "API_SERVICE_ON_TRAINING_EXAMPLE(" + API_SERVICE_ON_TRAINING_EXAMPLE + ")"); + } + + this.mLatencyMillis = latencyMillis; + this.mResponseCode = responseCode; + this.mOverheadLatencyMillis = overheadLatencyMillis; + + // onConstructed(); // You can define this method to get a callback + } + + @DataClass.Generated.Member + public int getApiClass() { + return mApiClass; + } + + @DataClass.Generated.Member + public @Api int getApiName() { + return mApiName; + } + + @DataClass.Generated.Member + public int getLatencyMillis() { + return mLatencyMillis; + } + + @DataClass.Generated.Member + public int getResponseCode() { + return mResponseCode; + } + + @DataClass.Generated.Member + public int getOverheadLatencyMillis() { + return mOverheadLatencyMillis; + } + + @Override + @DataClass.Generated.Member + public boolean equals(@android.annotation.Nullable Object o) { + // You can override field equality logic by defining either of the methods like: + // boolean fieldNameEquals(ApiCallStats other) { ... } + // boolean fieldNameEquals(FieldType otherValue) { ... } + + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + @SuppressWarnings("unchecked") + ApiCallStats that = (ApiCallStats) o; + //noinspection PointlessBooleanExpression + return true + && mApiClass == that.mApiClass + && mApiName == that.mApiName + && mLatencyMillis == that.mLatencyMillis + && mResponseCode == that.mResponseCode + && mOverheadLatencyMillis == that.mOverheadLatencyMillis; + } + + @Override + @DataClass.Generated.Member + public int hashCode() { + // You can override field hashCode logic by defining methods like: + // int fieldNameHashCode() { ... } + + int _hash = 1; + _hash = 31 * _hash + mApiClass; + _hash = 31 * _hash + mApiName; + _hash = 31 * _hash + mLatencyMillis; + _hash = 31 * _hash + mResponseCode; + _hash = 31 * _hash + mOverheadLatencyMillis; + return _hash; + } + + /** + * A builder for {@link ApiCallStats} + */ + @SuppressWarnings("WeakerAccess") + @DataClass.Generated.Member + public static class Builder { + + private int mApiClass; + private @Api int mApiName; + private int mLatencyMillis; + private int mResponseCode; + private int mOverheadLatencyMillis; + + private long mBuilderFieldsSet = 0L; + + public Builder( + @Api int apiName) { + mApiName = apiName; + + if (!(mApiName == API_EXECUTE) + && !(mApiName == API_REQUEST_SURFACE_PACKAGE) + && !(mApiName == API_SERVICE_ON_EXECUTE) + && !(mApiName == API_SERVICE_ON_DOWNLOAD_COMPLETED) + && !(mApiName == API_SERVICE_ON_RENDER) + && !(mApiName == API_SERVICE_ON_EVENT) + && !(mApiName == API_SERVICE_ON_TRAINING_EXAMPLE)) { + throw new java.lang.IllegalArgumentException( + "apiName was " + mApiName + " but must be one of: " + + "API_EXECUTE(" + API_EXECUTE + "), " + + "API_REQUEST_SURFACE_PACKAGE(" + API_REQUEST_SURFACE_PACKAGE + "), " + + "API_SERVICE_ON_EXECUTE(" + API_SERVICE_ON_EXECUTE + "), " + + "API_SERVICE_ON_DOWNLOAD_COMPLETED(" + API_SERVICE_ON_DOWNLOAD_COMPLETED + "), " + + "API_SERVICE_ON_RENDER(" + API_SERVICE_ON_RENDER + "), " + + "API_SERVICE_ON_EVENT(" + API_SERVICE_ON_EVENT + "), " + + "API_SERVICE_ON_TRAINING_EXAMPLE(" + API_SERVICE_ON_TRAINING_EXAMPLE + ")"); + } + + } + + @DataClass.Generated.Member + public @android.annotation.NonNull Builder setApiClass(int value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x1; + mApiClass = value; + return this; + } + + @DataClass.Generated.Member + public @android.annotation.NonNull Builder setApiName(@Api int value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x2; + mApiName = value; + return this; + } + + @DataClass.Generated.Member + public @android.annotation.NonNull Builder setLatencyMillis(int value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x4; + mLatencyMillis = value; + return this; + } + + @DataClass.Generated.Member + public @android.annotation.NonNull Builder setResponseCode(int value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x8; + mResponseCode = value; + return this; + } + + @DataClass.Generated.Member + public @android.annotation.NonNull Builder setOverheadLatencyMillis(int value) { + checkNotUsed(); + mBuilderFieldsSet |= 0x10; + mOverheadLatencyMillis = value; + return this; + } + + /** Builds the instance. This builder should not be touched after calling this! */ + public @android.annotation.NonNull ApiCallStats build() { + checkNotUsed(); + mBuilderFieldsSet |= 0x20; // Mark builder used + + if ((mBuilderFieldsSet & 0x1) == 0) { + mApiClass = ON_DEVICE_PERSONALIZATION_API_CALLED__API_CLASS__UNKNOWN; + } + if ((mBuilderFieldsSet & 0x4) == 0) { + mLatencyMillis = 0; + } + if ((mBuilderFieldsSet & 0x8) == 0) { + mResponseCode = 0; + } + if ((mBuilderFieldsSet & 0x10) == 0) { + mOverheadLatencyMillis = 0; + } + ApiCallStats o = new ApiCallStats( + mApiClass, + mApiName, + mLatencyMillis, + mResponseCode, + mOverheadLatencyMillis); + return o; + } + + private void checkNotUsed() { + if ((mBuilderFieldsSet & 0x20) != 0) { + throw new IllegalStateException( + "This Builder should not be reused. Use a new Builder instance instead"); + } + } + } + + @DataClass.Generated( + time = 1696624777344L, + codegenVersion = "1.0.23", + sourceFile = "packages/modules/OnDevicePersonalization/src/com/android/ondevicepersonalization/services/statsd/ApiCallStats.java", + inputSignatures = "public static final int API_EXECUTE\npublic static final int API_REQUEST_SURFACE_PACKAGE\npublic static final int API_SERVICE_ON_EXECUTE\npublic static final int API_SERVICE_ON_DOWNLOAD_COMPLETED\npublic static final int API_SERVICE_ON_RENDER\npublic static final int API_SERVICE_ON_EVENT\npublic static final int API_SERVICE_ON_TRAINING_EXAMPLE\nprivate int mApiClass\nprivate final @com.android.ondevicepersonalization.services.statsd.ApiCallStats.Api int mApiName\nprivate int mLatencyMillis\nprivate int mResponseCode\nprivate int mOverheadLatencyMillis\nclass ApiCallStats extends java.lang.Object implements []\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") + @Deprecated + private void __metadata() {} + + + //@formatter:on + // End of generated code + +} diff --git a/src/com/android/ondevicepersonalization/services/statsd/OdpStatsdLogger.java b/src/com/android/ondevicepersonalization/services/statsd/OdpStatsdLogger.java new file mode 100644 index 00000000..4b0cdf6b --- /dev/null +++ b/src/com/android/ondevicepersonalization/services/statsd/OdpStatsdLogger.java @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.ondevicepersonalization.services.statsd; + +import static com.android.ondevicepersonalization.OnDevicePersonalizationStatsLog.ONDEVICEPERSONALIZATION_API_CALLED; + +import com.android.ondevicepersonalization.OnDevicePersonalizationStatsLog; + +/** Log API stats and client error stats to StatsD. */ +public class OdpStatsdLogger { + private static volatile OdpStatsdLogger sStatsdLogger = null; + + /** Returns an instance of {@link OdpStatsdLogger}. */ + public static OdpStatsdLogger getInstance() { + if (sStatsdLogger == null) { + synchronized (OdpStatsdLogger.class) { + if (sStatsdLogger == null) { + sStatsdLogger = new OdpStatsdLogger(); + } + } + } + return sStatsdLogger; + } + + /** Log API call stats e.g. response code, API name etc. */ + public void logApiCallStats(ApiCallStats apiCallStats) { + OnDevicePersonalizationStatsLog.write( + ONDEVICEPERSONALIZATION_API_CALLED, + apiCallStats.getApiClass(), + apiCallStats.getApiName(), + apiCallStats.getLatencyMillis(), + apiCallStats.getResponseCode(), + apiCallStats.getOverheadLatencyMillis()); + } +} diff --git a/src/com/android/ondevicepersonalization/services/util/Clock.java b/src/com/android/ondevicepersonalization/services/util/Clock.java new file mode 100644 index 00000000..3f6f040c --- /dev/null +++ b/src/com/android/ondevicepersonalization/services/util/Clock.java @@ -0,0 +1,26 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.ondevicepersonalization.services.util; + +/** Wrapper of time operations. */ +public interface Clock { + /** Returns milliseconds since boot, including time spent in sleep. */ + long elapsedRealtime(); + + /** Get the current time of the clock in milliseconds. */ + long currentTimeMillis(); +} diff --git a/src/com/android/ondevicepersonalization/services/util/DebugUtils.java b/src/com/android/ondevicepersonalization/services/util/DebugUtils.java new file mode 100644 index 00000000..8d0e09cb --- /dev/null +++ b/src/com/android/ondevicepersonalization/services/util/DebugUtils.java @@ -0,0 +1,38 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.ondevicepersonalization.services.util; + +import android.annotation.NonNull; +import android.content.ContentResolver; +import android.content.Context; +import android.os.Build; +import android.provider.Settings; + +import java.util.Objects; + +/** Fuctions for testing and debugging. */ +public class DebugUtils { + /** Returns true if the device is debuggable. */ + public static boolean isDeveloperModeEnabled(@NonNull Context context) { + ContentResolver resolver = Objects.requireNonNull(context.getContentResolver()); + return Build.isDebuggable() + || Settings.Global.getInt( + resolver, Settings.Global.DEVELOPMENT_SETTINGS_ENABLED, 0) != 0; + } + + private DebugUtils() {} +} diff --git a/src/com/android/ondevicepersonalization/services/util/MonotonicClock.java b/src/com/android/ondevicepersonalization/services/util/MonotonicClock.java new file mode 100644 index 00000000..3b8cd1cb --- /dev/null +++ b/src/com/android/ondevicepersonalization/services/util/MonotonicClock.java @@ -0,0 +1,47 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.ondevicepersonalization.services.util; + +import android.os.SystemClock; + +/** + * The implementation of {@link Clock}. Allows replacement of clock operations for testing. It is + * monotonic until device reboots. + */ +public class MonotonicClock implements Clock { + private static final MonotonicClock INSTANCE = new MonotonicClock(); + + private final long mStartTimestampMs; + + public static Clock getInstance() { + return INSTANCE; + } + + private MonotonicClock() { + mStartTimestampMs = System.currentTimeMillis() - SystemClock.elapsedRealtime(); + } + + @Override + public long currentTimeMillis() { + return mStartTimestampMs + elapsedRealtime(); + } + + @Override + public long elapsedRealtime() { + return SystemClock.elapsedRealtime(); + } +} diff --git a/src/com/android/ondevicepersonalization/services/util/OnDevicePersonalizationFlatbufferUtils.java b/src/com/android/ondevicepersonalization/services/util/OnDevicePersonalizationFlatbufferUtils.java index 4c9a3469..20de1412 100644 --- a/src/com/android/ondevicepersonalization/services/util/OnDevicePersonalizationFlatbufferUtils.java +++ b/src/com/android/ondevicepersonalization/services/util/OnDevicePersonalizationFlatbufferUtils.java @@ -132,6 +132,15 @@ public class OnDevicePersonalizationFlatbufferUtils { } /** + * Retrieves the length of the rows in a QueryField flatbuffer. + */ + public static int getContentValuesLengthFromQueryData(byte[] queryData) { + QueryFields queryFields = QueryData.getRootAsQueryData( + ByteBuffer.wrap(queryData)).queryFields(0); + return queryFields.rowsLength(); + } + + /** * Retrieves the KeyValueList in an EventData flatbuffer as a ContentValues object. */ public static ContentValues getContentValuesFromEventData(byte[] eventData) { diff --git a/src/com/android/ondevicepersonalization/services/util/StatsUtils.java b/src/com/android/ondevicepersonalization/services/util/StatsUtils.java new file mode 100644 index 00000000..a4f2a892 --- /dev/null +++ b/src/com/android/ondevicepersonalization/services/util/StatsUtils.java @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.ondevicepersonalization.services.util; + +import android.adservices.ondevicepersonalization.CalleeMetadata; +import android.adservices.ondevicepersonalization.Constants; +import android.os.Bundle; + +/** Utilities for stats logging */ +public class StatsUtils { + /** Subtracts callee reported latency from caller reported latency. */ + public static long getOverheadLatencyMillis(long callerLatencyMillis, Bundle result) { + long calleeLatencyMillis = callerLatencyMillis; + if (result != null) { + CalleeMetadata metadata = + result.getParcelable(Constants.EXTRA_CALLEE_METADATA, CalleeMetadata.class); + if (metadata != null) { + if (metadata.getElapsedTimeMillis() > 0 + && metadata.getElapsedTimeMillis() < callerLatencyMillis) { + calleeLatencyMillis = metadata.getElapsedTimeMillis(); + } + } + } + return callerLatencyMillis - calleeLatencyMillis; + } + + private StatsUtils() {} +} diff --git a/systemservice/Android.bp b/systemservice/Android.bp index 1286fbbe..5e1c2370 100644 --- a/systemservice/Android.bp +++ b/systemservice/Android.bp @@ -36,6 +36,7 @@ java_sdk_library { defaults: ["framework-system-server-module-defaults"], libs: [ "framework-ondevicepersonalization.impl", + "modules-utils-preconditions" ], visibility: [ "//packages/modules/OnDevicePersonalization/tests:__subpackages__", diff --git a/systemservice/java/com/android/server/ondevicepersonalization/BooleanFileDataStore.java b/systemservice/java/com/android/server/ondevicepersonalization/BooleanFileDataStore.java new file mode 100644 index 00000000..127e4778 --- /dev/null +++ b/systemservice/java/com/android/server/ondevicepersonalization/BooleanFileDataStore.java @@ -0,0 +1,223 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.server.ondevicepersonalization; + +import android.annotation.NonNull; +import android.annotation.Nullable; +import android.os.PersistableBundle; +import android.util.AtomicFile; + +import com.android.internal.annotations.GuardedBy; +import com.android.internal.annotations.VisibleForTesting; +import com.android.internal.util.Preconditions; +import com.android.ondevicepersonalization.internal.util.LoggerFactory; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; + +/** + * A generic key-value datastore utilizing {@link android.util.AtomicFile} and {@link + * android.os.PersistableBundle} to read/write a simple key/value map to file. + * This class is thread-safe. + * @hide + */ +public class BooleanFileDataStore { + private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); + private static final String TAG = "BooleanFileDataStore"; + private final ReadWriteLock mReadWriteLock = new ReentrantReadWriteLock(); + private final Lock mReadLock = mReadWriteLock.readLock(); + private final Lock mWriteLock = mReadWriteLock.writeLock(); + + private final AtomicFile mAtomicFile; + private final Map<String, Boolean> mLocalMap = new HashMap<>(); + + // TODO (b/300993651): make the datastore access singleton. + // TODO (b/301131410): add version history feature. + public BooleanFileDataStore(@NonNull String parentPath, @NonNull String filename) { + Objects.requireNonNull(parentPath); + Objects.requireNonNull(filename); + Preconditions.checkStringNotEmpty(parentPath); + Preconditions.checkStringNotEmpty(filename); + mAtomicFile = new AtomicFile(new File(parentPath, filename)); + } + + /** + * Loads data from the datastore file on disk. + * @throws IOException if file read fails. + */ + public void initialize() throws IOException { + sLogger.d(TAG + ": reading from file " + mAtomicFile.getBaseFile()); + mReadLock.lock(); + try { + readFromFile(); + } finally { + mReadLock.unlock(); + } + } + + /** + * Stores a value to the datastore file, which is immediately committed. + * @param key A non-null, non-empty String to store the {@code value}. + * @param value A boolean to be stored. + * @throws IOException if file write fails. + * @throws NullPointerException if {@code key} is null. + * @throws IllegalArgumentException if (@code key) is an empty string. + */ + public void put(@NonNull String key, boolean value) throws IOException { + Objects.requireNonNull(key); + Preconditions.checkStringNotEmpty(key, "Key must not be empty."); + + mWriteLock.lock(); + try { + mLocalMap.put(key, value); + writeToFile(); + } finally { + mWriteLock.unlock(); + } + } + + /** + * Retrieves a boolean value from the loaded datastore file. + * + * @param key A non-null, non-empty String key to fetch a value from. + * @return The boolean value stored against a {@code key}, or null if it doesn't exist. + * @throws IllegalArgumentException if {@code key} is an empty string. + * @throws NullPointerException if {@code key} is null. + */ + @Nullable + public Boolean get(@NonNull String key) { + Objects.requireNonNull(key); + Preconditions.checkStringNotEmpty(key); + + mReadLock.lock(); + try { + return mLocalMap.get(key); + } finally { + mReadLock.unlock(); + } + } + + /** + * Retrieves a {@link Set} of all keys loaded from the datastore file. + * + * @return A {@link Set} of {@link String} keys currently in the loaded datastore + */ + @NonNull + public Set<String> keySet() { + mReadLock.lock(); + try { + return Set.copyOf(mLocalMap.keySet()); + } finally { + mReadLock.unlock(); + } + } + + /** + * Clears all entries from the datastore file and committed immediately. + * + * @throws IOException if file write fails. + */ + public void clear() throws IOException { + sLogger.d(TAG + ": clearing all entries from datastore"); + + mWriteLock.lock(); + try { + mLocalMap.clear(); + writeToFile(); + } finally { + mWriteLock.unlock(); + } + } + + @GuardedBy("mWriteLock") + private void writeToFile() throws IOException { + final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + final PersistableBundle persistableBundle = new PersistableBundle(); + for (Map.Entry<String, Boolean> entry: mLocalMap.entrySet()) { + persistableBundle.putBoolean(entry.getKey(), entry.getValue()); + } + + persistableBundle.writeToStream(outputStream); + + FileOutputStream out = null; + try { + out = mAtomicFile.startWrite(); + out.write(outputStream.toByteArray()); + mAtomicFile.finishWrite(out); + } catch (IOException e) { + mAtomicFile.failWrite(out); + sLogger.e(TAG + ": write to file " + mAtomicFile.getBaseFile() + " failed."); + throw e; + } + } + + @GuardedBy("mReadLock") + private void readFromFile() throws IOException { + try { + final ByteArrayInputStream inputStream = new ByteArrayInputStream( + mAtomicFile.readFully()); + final PersistableBundle persistableBundle = PersistableBundle.readFromStream( + inputStream); + + mLocalMap.clear(); + for (String key: persistableBundle.keySet()) { + mLocalMap.put(key, persistableBundle.getBoolean(key)); + } + } catch (FileNotFoundException e) { + sLogger.d(TAG + ": file not found exception."); + mLocalMap.clear(); + } catch (IOException e) { + sLogger.e(TAG + ": read from " + mAtomicFile.getBaseFile() + " failed"); + throw e; + } + } + + /** + * Delete the datastore file for testing. + */ + @VisibleForTesting + public void tearDownForTesting() { + mWriteLock.lock(); + try { + mAtomicFile.delete(); + mLocalMap.clear(); + } finally { + mWriteLock.unlock(); + } + } + + /** + * Clear the loaded content from local map for testing. + */ + @VisibleForTesting + public void clearLocalMapForTesting() { + mWriteLock.lock(); + mLocalMap.clear(); + mWriteLock.unlock(); + } +} diff --git a/systemservice/java/com/android/server/ondevicepersonalization/OnDevicePersonalizationSystemService.java b/systemservice/java/com/android/server/ondevicepersonalization/OnDevicePersonalizationSystemService.java index 89b8a84e..3de0e600 100644 --- a/systemservice/java/com/android/server/ondevicepersonalization/OnDevicePersonalizationSystemService.java +++ b/systemservice/java/com/android/server/ondevicepersonalization/OnDevicePersonalizationSystemService.java @@ -27,15 +27,31 @@ import android.util.Log; import com.android.internal.annotations.VisibleForTesting; import com.android.server.SystemService; +import java.io.IOException; +import java.util.Objects; + /** * @hide */ public class OnDevicePersonalizationSystemService extends IOnDevicePersonalizationSystemService.Stub { private static final String TAG = "OnDevicePersonalizationSystemService"; + // TODO(b/302991763): set up per-user directory if needed. + private static final String ODP_BASE_DIR = "/data/system/ondevicepersonalization/0/"; + private static final String CONFIG_FILE_IDENTIFIER = "CONFIG"; + public static final String PERSONALIZATION_STATUS_KEY = "PERSONALIZATION_STATUS"; + private BooleanFileDataStore mDataStore = null; + public static final int INTERNAL_SERVER_ERROR = 500; + public static final int KEY_NOT_FOUND_ERROR = 404; + // TODO(b/302992251): use a manager to access configs instead of directly exposing DataStore. @VisibleForTesting - OnDevicePersonalizationSystemService(Context context) { + OnDevicePersonalizationSystemService(Context context, BooleanFileDataStore dataStore) + throws IOException { + Objects.requireNonNull(context); + Objects.requireNonNull(dataStore); + this.mDataStore = dataStore; + mDataStore.initialize(); } @Override public void onRequest( @@ -48,6 +64,52 @@ public class OnDevicePersonalizationSystemService } } + @Override + public void setPersonalizationStatus(boolean enabled, + IOnDevicePersonalizationSystemServiceCallback callback) { + Bundle result = new Bundle(); + try { + mDataStore.put(PERSONALIZATION_STATUS_KEY, enabled); + // Confirm the value was updated. + Boolean statusResult = mDataStore.get(PERSONALIZATION_STATUS_KEY); + if (statusResult == null || statusResult.booleanValue() != enabled) { + callback.onError(INTERNAL_SERVER_ERROR); + return; + } + // Echo the result back + result.putBoolean(PERSONALIZATION_STATUS_KEY, statusResult); + callback.onResult(result); + } catch (IOException e) { + Log.e(TAG, "Unable to persist personalization status", e); + try { + callback.onError(INTERNAL_SERVER_ERROR); + } catch (RemoteException re) { + Log.e(TAG, "Callback error", e); + } + } catch (RemoteException e) { + Log.e(TAG, "Callback error", e); + } + } + + @Override + public void readPersonalizationStatus( + IOnDevicePersonalizationSystemServiceCallback callback) { + Boolean result = null; + Bundle bundle = new Bundle(); + try { + result = mDataStore.get(PERSONALIZATION_STATUS_KEY); + if (result == null) { + Log.d(TAG, "Unable to restore personalization status"); + callback.onError(KEY_NOT_FOUND_ERROR); + } else { + bundle.putBoolean(PERSONALIZATION_STATUS_KEY, result.booleanValue()); + callback.onResult(bundle); + } + } catch (RemoteException e) { + Log.e(TAG, "Callback error", e); + } + } + /** @hide */ public static class Lifecycle extends SystemService { private OnDevicePersonalizationSystemService mService; @@ -55,7 +117,12 @@ public class OnDevicePersonalizationSystemService /** @hide */ public Lifecycle(Context context) { super(context); - mService = new OnDevicePersonalizationSystemService(getContext()); + try { + mService = new OnDevicePersonalizationSystemService(getContext(), + new BooleanFileDataStore(ODP_BASE_DIR, CONFIG_FILE_IDENTIFIER)); + } catch (IOException e) { + Log.e(TAG, "Cannot initialize the system service.", e); + } } /** @hide */ diff --git a/tests/cts/endtoend/Android.bp b/tests/cts/endtoend/Android.bp new file mode 100644 index 00000000..95f26cab --- /dev/null +++ b/tests/cts/endtoend/Android.bp @@ -0,0 +1,51 @@ +// Copyright (C) 2022 The Android Open Source Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Make test APK +// ============================================================ + +package { + default_applicable_licenses: ["Android-Apache-2.0"], +} + +android_test { + name: "CtsOnDevicePersonalizationE2ETests", + srcs: ["src/**/*.java"], + defaults: ["framework-ondevicepersonalization-test-defaults"], + min_sdk_version: "Tiramisu", + target_sdk_version: "Tiramisu", + static_libs: [ + "androidx.test.core", + "androidx.test.ext.junit", + "androidx.test.ext.truth", + "androidx.test.rules", + ], + libs: [ + "android.test.base", + "android.test.runner", + "truth-prebuilt", + ], + data: [ + ":OdpTestingSampleService", + ], + resource_dirs: [ + "res", + ], + test_mainline_modules: ["com.google.android.ondevicepersonalization.apex"], + test_suites: [ + "general-tests", + "mts-ondevicepersonalization", + ], + test_config: "AndroidTest.xml", +} diff --git a/tests/cts/endtoend/AndroidManifest.xml b/tests/cts/endtoend/AndroidManifest.xml new file mode 100644 index 00000000..286223b0 --- /dev/null +++ b/tests/cts/endtoend/AndroidManifest.xml @@ -0,0 +1,36 @@ +<?xml version="1.0" encoding="utf-8"?> +<!-- Copyright 2022 The Android Open Source Project + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--> + +<manifest xmlns:android="http://schemas.android.com/apk/res/android" + package="com.android.ondevicepersonalization.cts.e2e"> + + <application android:label="CtsOnDevicePersonalizationE2ETests"> + <uses-library android:name="android.test.runner" /> + <activity + android:name=".TestActivity" + android:exported="true"> + <intent-filter> + <action android:name="android.intent.action.MAIN" /> + <category android:name="android.intent.category.LAUNCHER" /> + </intent-filter> + </activity> + </application> + <instrumentation android:name="androidx.test.runner.AndroidJUnitRunner" + android:targetPackage="com.android.ondevicepersonalization.cts.e2e" + android:label="OnDevicePersonalizationManager CTS Tests"> + </instrumentation> + +</manifest> diff --git a/tests/cts/endtoend/AndroidTest.xml b/tests/cts/endtoend/AndroidTest.xml new file mode 100644 index 00000000..9750f31c --- /dev/null +++ b/tests/cts/endtoend/AndroidTest.xml @@ -0,0 +1,48 @@ +<?xml version="1.0" encoding="utf-8"?> +<!-- Copyright 2022 The Android Open Source Project + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--> +<configuration description="Config for OnDevicePersonalization CTS tests"> + <option name="config-descriptor:metadata" key="component" value="framework" /> + <option name="config-descriptor:metadata" key="parameter" value="not_instant_app" /> + <option name="config-descriptor:metadata" key="parameter" value="not_multi_abi" /> + <option name="config-descriptor:metadata" key="parameter" value="secondary_user" /> + + <target_preparer class="com.android.tradefed.targetprep.suite.SuiteApkInstaller"> + <option name="cleanup-apks" value="true"/> + <option name="test-file-name" value="CtsOnDevicePersonalizationE2ETests.apk"/> + <option name="test-file-name" value="OdpTestingSampleService.apk"/> + </target_preparer> + + <target_preparer class="com.android.tradefed.targetprep.RunCommandTargetPreparer"> + <option name="run-command" value="device_config set_sync_disabled_for_tests persistent" /> + <option name="run-command" value="device_config put on_device_personalization global_kill_switch false" /> + <option name="run-command" value="device_config put on_device_personalization enable_personalization_status_override true"/> + <option name="run-command" value="device_config put on_device_personalization personalization_status_override_value true"/> + <option name="teardown-command" value="device_config delete on_device_personalization global_kill_switch" /> + <option name="teardown-command" value="device_config delete on_device_personalization enable_personalization_status_override" /> + <option name="teardown-command" value="device_config delete on_device_personalization personalization_status_override_value" /> + <option name="teardown-command" value="device_config set_sync_disabled_for_tests none" /> + </target_preparer> + + <test class="com.android.tradefed.testtype.AndroidJUnitTest"> + <option name="hidden-api-checks" value="false" /> <!-- Allow hidden API uses --> + <option name="package" value="com.android.ondevicepersonalization.cts.e2e"/> + </test> + + <object type="module_controller" class="com.android.tradefed.testtype.suite.module.MainlineTestModuleController"> + <option name="mainline-module-package-name" value="com.google.android.ondevicepersonalization" /> + </object> + <option name="config-descriptor:metadata" key="mainline-param" value="com.google.android.ondevicepersonalization.apex" /> +</configuration> diff --git a/tests/cts/endtoend/res/layout/activity_main.xml b/tests/cts/endtoend/res/layout/activity_main.xml new file mode 100644 index 00000000..5f390caf --- /dev/null +++ b/tests/cts/endtoend/res/layout/activity_main.xml @@ -0,0 +1,26 @@ +<?xml version="1.0" encoding="utf-8"?> +<!-- + ~ Copyright (C) 2022 The Android Open Source Project + ~ + ~ Licensed under the Apache License, Version 2.0 (the "License"); + ~ you may not use this file except in compliance with the License. + ~ You may obtain a copy of the License at + ~ + ~ http://www.apache.org/licenses/LICENSE-2.0 + ~ + ~ Unless required by applicable law or agreed to in writing, software + ~ distributed under the License is distributed on an "AS IS" BASIS, + ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + ~ See the License for the specific language governing permissions and + ~ limitations under the License. + --> + +<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android" + android:orientation="vertical" + android:layout_width="match_parent" + android:layout_height="match_parent"> + <SurfaceView + android:id="@+id/test_surface_view" + android:layout_width="200dp" + android:layout_height="200dp" /> +</LinearLayout> diff --git a/tests/cts/endtoend/src/com/android/ondevicepersonalization/cts/e2e/CtsOdpManagerTests.java b/tests/cts/endtoend/src/com/android/ondevicepersonalization/cts/e2e/CtsOdpManagerTests.java new file mode 100644 index 00000000..8546fa91 --- /dev/null +++ b/tests/cts/endtoend/src/com/android/ondevicepersonalization/cts/e2e/CtsOdpManagerTests.java @@ -0,0 +1,417 @@ +/* + * Copyright 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.android.ondevicepersonalization.cts.e2e; + +import static android.view.Display.DEFAULT_DISPLAY; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import android.adservices.ondevicepersonalization.OnDevicePersonalizationManager; +import android.adservices.ondevicepersonalization.SurfacePackageToken; +import android.content.ComponentName; +import android.content.Context; +import android.content.pm.PackageManager.NameNotFoundException; +import android.hardware.display.DisplayManager; +import android.os.OutcomeReceiver; +import android.os.PersistableBundle; +import android.view.Display; +import android.view.SurfaceControlViewHost.SurfacePackage; +import android.view.SurfaceView; + +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.rules.ActivityScenarioRule; +import androidx.test.ext.junit.runners.AndroidJUnit4; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.util.List; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executors; + +/** + * CTS Test cases for OnDevicePersonalizationManager APIs. + */ +@RunWith(AndroidJUnit4.class) +public class CtsOdpManagerTests { + private static final String SERVICE_PACKAGE = + "com.android.ondevicepersonalization.testing.sampleservice"; + private static final String SERVICE_CLASS = + "com.android.ondevicepersonalization.testing.sampleservice.SampleService"; + + private final Context mContext = ApplicationProvider.getApplicationContext(); + + @Rule + public final ActivityScenarioRule<TestActivity> mActivityScenarioRule = + new ActivityScenarioRule<>(TestActivity.class); + + @Test + public void testExecuteThrowsIfComponentNameMissing() throws InterruptedException { + OnDevicePersonalizationManager manager = + mContext.getSystemService(OnDevicePersonalizationManager.class); + assertNotNull(manager); + + assertThrows( + NullPointerException.class, + () -> manager.execute( + null, + PersistableBundle.EMPTY, + Executors.newSingleThreadExecutor(), + new ResultReceiver<List<SurfacePackageToken>>())); + } + + @Test + public void testExecuteThrowsIfParamsMissing() throws InterruptedException { + OnDevicePersonalizationManager manager = + mContext.getSystemService(OnDevicePersonalizationManager.class); + assertNotNull(manager); + + assertThrows( + NullPointerException.class, + () -> manager.execute( + new ComponentName(SERVICE_PACKAGE, SERVICE_CLASS), + null, + Executors.newSingleThreadExecutor(), + new ResultReceiver<List<SurfacePackageToken>>())); + } + + @Test + public void testExecuteThrowsIfExecutorMissing() throws InterruptedException { + OnDevicePersonalizationManager manager = + mContext.getSystemService(OnDevicePersonalizationManager.class); + assertNotNull(manager); + + assertThrows( + NullPointerException.class, + () -> manager.execute( + new ComponentName(SERVICE_PACKAGE, SERVICE_CLASS), + PersistableBundle.EMPTY, + null, + new ResultReceiver<List<SurfacePackageToken>>())); + } + + @Test + public void testExecuteThrowsIfReceiverMissing() throws InterruptedException { + OnDevicePersonalizationManager manager = + mContext.getSystemService(OnDevicePersonalizationManager.class); + assertNotNull(manager); + + assertThrows( + NullPointerException.class, + () -> manager.execute( + new ComponentName(SERVICE_PACKAGE, SERVICE_CLASS), + PersistableBundle.EMPTY, + Executors.newSingleThreadExecutor(), + null)); + } + + @Test + public void testExecuteThrowsIfPackageNameMissing() throws InterruptedException { + OnDevicePersonalizationManager manager = + mContext.getSystemService(OnDevicePersonalizationManager.class); + assertNotNull(manager); + + assertThrows( + IllegalArgumentException.class, + () -> manager.execute( + new ComponentName("", SERVICE_CLASS), + PersistableBundle.EMPTY, + Executors.newSingleThreadExecutor(), + new ResultReceiver<List<SurfacePackageToken>>())); + } + + @Test + public void testExecuteThrowsIfClassNameMissing() throws InterruptedException { + OnDevicePersonalizationManager manager = + mContext.getSystemService(OnDevicePersonalizationManager.class); + assertNotNull(manager); + + assertThrows( + IllegalArgumentException.class, + () -> manager.execute( + new ComponentName(SERVICE_PACKAGE, ""), + PersistableBundle.EMPTY, + Executors.newSingleThreadExecutor(), + new ResultReceiver<List<SurfacePackageToken>>())); + } + + @Test + public void testExecuteReturnsNameNotFoundIfServiceNotInstalled() throws InterruptedException { + OnDevicePersonalizationManager manager = + mContext.getSystemService(OnDevicePersonalizationManager.class); + assertNotNull(manager); + var receiver = new ResultReceiver<List<SurfacePackageToken>>(); + manager.execute( + new ComponentName("somepackage", "someclass"), + PersistableBundle.EMPTY, + Executors.newSingleThreadExecutor(), + receiver); + receiver.await(); + assertNull(receiver.getResult()); + assertTrue(receiver.getException() instanceof NameNotFoundException); + } + + @Test + public void testExecuteReturnsClassNotFoundIfServiceClassNotFound() + throws InterruptedException { + OnDevicePersonalizationManager manager = + mContext.getSystemService(OnDevicePersonalizationManager.class); + assertNotNull(manager); + var receiver = new ResultReceiver<List<SurfacePackageToken>>(); + manager.execute( + new ComponentName(SERVICE_PACKAGE, "someclass"), + PersistableBundle.EMPTY, + Executors.newSingleThreadExecutor(), + receiver); + receiver.await(); + assertNull(receiver.getResult()); + assertTrue(receiver.getException() instanceof ClassNotFoundException); + } + + @Test + public void testExecute() throws InterruptedException { + OnDevicePersonalizationManager manager = + mContext.getSystemService(OnDevicePersonalizationManager.class); + assertNotNull(manager); + var receiver = new ResultReceiver<List<SurfacePackageToken>>(); + manager.execute( + new ComponentName(SERVICE_PACKAGE, SERVICE_CLASS), + PersistableBundle.EMPTY, + Executors.newSingleThreadExecutor(), + receiver); + receiver.await(); + List<SurfacePackageToken> results = receiver.getResult(); + assertNotNull(results); + assertEquals(1, results.size()); + SurfacePackageToken token = results.get(0); + assertNotNull(token); + } + + @Test + public void testRequestSurfacePackage() throws InterruptedException { + OnDevicePersonalizationManager manager = + mContext.getSystemService(OnDevicePersonalizationManager.class); + List<SurfacePackageToken> tokens = + runExecute(manager, PersistableBundle.EMPTY); + var receiver = new ResultReceiver<SurfacePackage>(); + SurfaceView surfaceView = createSurfaceView(); + manager.requestSurfacePackage( + tokens.get(0), + surfaceView.getHostToken(), + getDisplayId(), + surfaceView.getWidth(), + surfaceView.getHeight(), + Executors.newSingleThreadExecutor(), + receiver); + receiver.await(); + assertNotNull(receiver.getResult()); + } + + @Test + public void testRequestSurfacePackageThrowsIfSurfacePackageTokenMissing() + throws InterruptedException { + OnDevicePersonalizationManager manager = + mContext.getSystemService(OnDevicePersonalizationManager.class); + SurfaceView surfaceView = createSurfaceView(); + assertThrows( + NullPointerException.class, + () -> manager.requestSurfacePackage( + null, + surfaceView.getHostToken(), + getDisplayId(), + surfaceView.getWidth(), + surfaceView.getHeight(), + Executors.newSingleThreadExecutor(), + new ResultReceiver<SurfacePackage>())); + } + + @Test + public void testRequestSurfacePackageThrowsIfSurfaceViewHostTokenMissing() + throws InterruptedException { + OnDevicePersonalizationManager manager = + mContext.getSystemService(OnDevicePersonalizationManager.class); + List<SurfacePackageToken> tokens = + runExecute(manager, PersistableBundle.EMPTY); + SurfaceView surfaceView = createSurfaceView(); + assertThrows( + NullPointerException.class, + () -> manager.requestSurfacePackage( + tokens.get(0), + null, + getDisplayId(), + surfaceView.getWidth(), + surfaceView.getHeight(), + Executors.newSingleThreadExecutor(), + new ResultReceiver<SurfacePackage>())); + } + + @Test + public void testRequestSurfacePackageThrowsIfInvalidDisplayId() + throws InterruptedException { + OnDevicePersonalizationManager manager = + mContext.getSystemService(OnDevicePersonalizationManager.class); + List<SurfacePackageToken> tokens = + runExecute(manager, PersistableBundle.EMPTY); + SurfaceView surfaceView = createSurfaceView(); + assertThrows( + IllegalArgumentException.class, + () -> manager.requestSurfacePackage( + tokens.get(0), + surfaceView.getHostToken(), + -1, + surfaceView.getWidth(), + surfaceView.getHeight(), + Executors.newSingleThreadExecutor(), + new ResultReceiver<SurfacePackage>())); + } + + @Test + public void testRequestSurfacePackageThrowsIfInvalidWidth() + throws InterruptedException { + OnDevicePersonalizationManager manager = + mContext.getSystemService(OnDevicePersonalizationManager.class); + List<SurfacePackageToken> tokens = + runExecute(manager, PersistableBundle.EMPTY); + SurfaceView surfaceView = createSurfaceView(); + assertThrows( + IllegalArgumentException.class, + () -> manager.requestSurfacePackage( + tokens.get(0), + surfaceView.getHostToken(), + getDisplayId(), + 0, + surfaceView.getHeight(), + Executors.newSingleThreadExecutor(), + new ResultReceiver<SurfacePackage>())); + } + + @Test + public void testRequestSurfacePackageThrowsIfInvalidHeight() + throws InterruptedException { + OnDevicePersonalizationManager manager = + mContext.getSystemService(OnDevicePersonalizationManager.class); + List<SurfacePackageToken> tokens = + runExecute(manager, PersistableBundle.EMPTY); + SurfaceView surfaceView = createSurfaceView(); + assertThrows( + IllegalArgumentException.class, + () -> manager.requestSurfacePackage( + tokens.get(0), + surfaceView.getHostToken(), + getDisplayId(), + surfaceView.getWidth(), + 0, + Executors.newSingleThreadExecutor(), + new ResultReceiver<SurfacePackage>())); + } + + @Test + public void testRequestSurfacePackageThrowsIfExecutorMissing() + throws InterruptedException { + OnDevicePersonalizationManager manager = + mContext.getSystemService(OnDevicePersonalizationManager.class); + List<SurfacePackageToken> tokens = + runExecute(manager, PersistableBundle.EMPTY); + SurfaceView surfaceView = createSurfaceView(); + assertThrows( + NullPointerException.class, + () -> manager.requestSurfacePackage( + tokens.get(0), + surfaceView.getHostToken(), + getDisplayId(), + surfaceView.getWidth(), + surfaceView.getHeight(), + null, + new ResultReceiver<SurfacePackage>())); + } + + @Test + public void testRequestSurfacePackageThrowsIfOutcomeReceiverMissing() + throws InterruptedException { + OnDevicePersonalizationManager manager = + mContext.getSystemService(OnDevicePersonalizationManager.class); + List<SurfacePackageToken> tokens = + runExecute(manager, PersistableBundle.EMPTY); + SurfaceView surfaceView = createSurfaceView(); + assertThrows( + NullPointerException.class, + () -> manager.requestSurfacePackage( + tokens.get(0), + surfaceView.getHostToken(), + getDisplayId(), + surfaceView.getWidth(), + surfaceView.getHeight(), + Executors.newSingleThreadExecutor(), + null)); + } + + int getDisplayId() { + final DisplayManager dm = mContext.getSystemService(DisplayManager.class); + final Display primaryDisplay = dm.getDisplay(DEFAULT_DISPLAY); + final Context windowContext = mContext.createDisplayContext(primaryDisplay); + return windowContext.getDisplay().getDisplayId(); + } + + SurfaceView createSurfaceView() throws InterruptedException { + ArrayBlockingQueue<SurfaceView> viewQueue = new ArrayBlockingQueue<>(1); + mActivityScenarioRule.getScenario().onActivity( + a -> viewQueue.add(a.findViewById(R.id.test_surface_view))); + return viewQueue.take(); + } + + private List<SurfacePackageToken> runExecute( + OnDevicePersonalizationManager manager, PersistableBundle params) + throws InterruptedException { + var receiver = new ResultReceiver<List<SurfacePackageToken>>(); + manager.execute( + new ComponentName(SERVICE_PACKAGE, SERVICE_CLASS), + params, + Executors.newSingleThreadExecutor(), + receiver); + receiver.await(); + List<SurfacePackageToken> results = receiver.getResult(); + return results; + } + + class ResultReceiver<T> implements OutcomeReceiver<T, Exception> { + private CountDownLatch mLatch = new CountDownLatch(1); + private T mResult; + private Exception mException; + @Override public void onResult(T result) { + mResult = result; + mLatch.countDown(); + } + @Override public void onError(Exception e) { + mException = e; + mLatch.countDown(); + } + void await() throws InterruptedException { + mLatch.await(); + } + T getResult() { + return mResult; + } + Exception getException() { + return mException; + } + } +} diff --git a/tests/cts/endtoend/src/com/android/ondevicepersonalization/cts/e2e/TestActivity.java b/tests/cts/endtoend/src/com/android/ondevicepersonalization/cts/e2e/TestActivity.java new file mode 100644 index 00000000..22556a3d --- /dev/null +++ b/tests/cts/endtoend/src/com/android/ondevicepersonalization/cts/e2e/TestActivity.java @@ -0,0 +1,31 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.ondevicepersonalization.cts.e2e; + +import android.app.Activity; +import android.os.Bundle; + +/** + * A simple activity that can contain views. + */ +public class TestActivity extends Activity { + @Override + public void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + } +} diff --git a/tests/cts/service/Android.bp b/tests/cts/service/Android.bp new file mode 100644 index 00000000..3d27b621 --- /dev/null +++ b/tests/cts/service/Android.bp @@ -0,0 +1,41 @@ +// Copyright (C) 2022 The Android Open Source Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package { + default_applicable_licenses: ["Android-Apache-2.0"], +} + +android_test_helper_app { + name: "OdpTestingSampleService", + defaults: ["platform_app_defaults"], + platform_apis: true, + srcs: [ + "src/**/*.java", + ], + libs: [ + "framework-ondevicepersonalization.impl", + ], + static_libs: [ + "androidx.annotation_annotation", + "androidx.core_core", + "guava", + ], + resource_dirs: [ + "res", + ], + target_sdk_version: "Tiramisu", + min_sdk_version: "Tiramisu", + manifest: "AndroidManifest.xml", + certificate: "platform", +} diff --git a/tests/cts/service/AndroidManifest.xml b/tests/cts/service/AndroidManifest.xml new file mode 100644 index 00000000..767ab6c6 --- /dev/null +++ b/tests/cts/service/AndroidManifest.xml @@ -0,0 +1,28 @@ +<?xml version="1.0" encoding="utf-8"?> +<!-- + ~ Copyright (C) 2022 The Android Open Source Project + ~ + ~ Licensed under the Apache License, Version 2.0 (the "License"); + ~ you may not use this file except in compliance with the License. + ~ You may obtain a copy of the License at + ~ + ~ http://www.apache.org/licenses/LICENSE-2.0 + ~ + ~ Unless required by applicable law or agreed to in writing, software + ~ distributed under the License is distributed on an "AS IS" BASIS, + ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + ~ See the License for the specific language governing permissions and + ~ limitations under the License. + --> + +<manifest xmlns:android="http://schemas.android.com/apk/res/android" + package="com.android.ondevicepersonalization.testing.sampleservice" + android:versionName="1.0.0" > + <application android:label="OdpTestingSampleService" + android:debuggable="true"> + <property android:name="android.ondevicepersonalization.ON_DEVICE_PERSONALIZATION_CONFIG" + android:resource="@xml/OdpSettings"></property> + <service android:name="com.android.ondevicepersonalization.testing.sampleservice.SampleService" + android:exported="true" android:isolatedProcess="true" /> + </application> +</manifest> diff --git a/tests/cts/service/res/raw/test_data1.json b/tests/cts/service/res/raw/test_data1.json new file mode 100644 index 00000000..173e9431 --- /dev/null +++ b/tests/cts/service/res/raw/test_data1.json @@ -0,0 +1,8 @@ +{ + "syncToken": 1662138000, + "contents": [ + { "key": "ad1", + "data": "{ \"price\": 5.0 }" + } + ] +} diff --git a/tests/cts/service/res/xml/OdpSettings.xml b/tests/cts/service/res/xml/OdpSettings.xml new file mode 100644 index 00000000..42eefdb1 --- /dev/null +++ b/tests/cts/service/res/xml/OdpSettings.xml @@ -0,0 +1,23 @@ +<?xml version="1.0" encoding="utf-8"?> +<!-- + ~ Copyright (C) 2022 The Android Open Source Project + ~ + ~ Licensed under the Apache License, Version 2.0 (the "License"); + ~ you may not use this file except in compliance with the License. + ~ You may obtain a copy of the License at + ~ + ~ http://www.apache.org/licenses/LICENSE-2.0 + ~ + ~ Unless required by applicable law or agreed to in writing, software + ~ distributed under the License is distributed on an "AS IS" BASIS, + ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + ~ See the License for the specific language governing permissions and + ~ limitations under the License. + --> + +<!-- Odp Settings, in XML resource --> +<on-device-personalization> + <service name="com.android.ondevicepersonalization.testing.sampleservice.SampleService" > + <download-settings url="android.resource://com.android.ondevicepersonalization.testing.sampleservice/raw/test_data1" /> + </service> +</on-device-personalization> diff --git a/tests/cts/service/src/com/android/ondevicepersonalization/testing/sampleservice/SampleService.java b/tests/cts/service/src/com/android/ondevicepersonalization/testing/sampleservice/SampleService.java new file mode 100644 index 00000000..b802ce53 --- /dev/null +++ b/tests/cts/service/src/com/android/ondevicepersonalization/testing/sampleservice/SampleService.java @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.ondevicepersonalization.testing.sampleservice; + +import android.adservices.ondevicepersonalization.ExecuteInput; +import android.adservices.ondevicepersonalization.ExecuteOutput; +import android.adservices.ondevicepersonalization.IsolatedService; +import android.adservices.ondevicepersonalization.IsolatedWorker; +import android.adservices.ondevicepersonalization.RenderInput; +import android.adservices.ondevicepersonalization.RenderOutput; +import android.adservices.ondevicepersonalization.RenderingConfig; +import android.adservices.ondevicepersonalization.RequestLogRecord; +import android.adservices.ondevicepersonalization.RequestToken; +import android.content.ContentValues; + +import java.util.function.Consumer; + +public class SampleService extends IsolatedService { + class SampleWorker implements IsolatedWorker { + @Override public void onExecute(ExecuteInput input, Consumer<ExecuteOutput> consumer) { + ContentValues logData = new ContentValues(); + logData.put("id", "ad1"); + logData.put("pr", 5.0); + ExecuteOutput result = new ExecuteOutput.Builder() + .setRequestLogRecord(new RequestLogRecord.Builder().addRow(logData).build()) + .addRenderingConfig( + new RenderingConfig.Builder().addKey("bid1").build() + ) + .build(); + consumer.accept(result); + } + + @Override public void onRender(RenderInput input, Consumer<RenderOutput> consumer) { + var keys = input.getRenderingConfig().getKeys(); + if (keys.size() > 0) { + String html = "<body>" + input.getRenderingConfig().getKeys().get(0) + "</body>"; + consumer.accept(new RenderOutput.Builder().setContent(html).build()); + } else { + consumer.accept(null); + } + } + } + + @Override public IsolatedWorker onRequest(RequestToken requestToken) { + return new SampleWorker(); + } +} diff --git a/tests/endtoendtests/OdpClient/Android.bp b/tests/endtoendtests/OdpClient/Android.bp index eddd74a3..56760a59 100644 --- a/tests/endtoendtests/OdpClient/Android.bp +++ b/tests/endtoendtests/OdpClient/Android.bp @@ -18,7 +18,7 @@ package { android_test_helper_app { name: "OdpClient", - defaults: ["platform_app_defaults"], + defaults: ["framework-ondevicepersonalization-test-defaults"], platform_apis: true, srcs: [ "src/**/odpclient/*.java", diff --git a/tests/endtoendtests/OdpClient/AndroidManifest.xml b/tests/endtoendtests/OdpClient/AndroidManifest.xml index 9d19854c..71025450 100644 --- a/tests/endtoendtests/OdpClient/AndroidManifest.xml +++ b/tests/endtoendtests/OdpClient/AndroidManifest.xml @@ -17,7 +17,8 @@ <manifest xmlns:android="http://schemas.android.com/apk/res/android" package="com.example.odpclient" - android:versionName="1.0.0" > + android:versionName="1.0.2" > + <uses-permission android:name="android.permission.ondevicepersonalization.MODIFY_ONDEVICEPERSONALIZATION_STATE"/> <application android:label="@string/title_activity_main"> diff --git a/tests/endtoendtests/OdpClient/res/layout/activity_main.xml b/tests/endtoendtests/OdpClient/res/layout/activity_main.xml index b96a0786..acee6c76 100644 --- a/tests/endtoendtests/OdpClient/res/layout/activity_main.xml +++ b/tests/endtoendtests/OdpClient/res/layout/activity_main.xml @@ -15,30 +15,73 @@ ~ limitations under the License. --> -<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android" - android:orientation="vertical" - android:layout_width="match_parent" - android:layout_height="match_parent"> - - - <EditText - android:id="@+id/text_box" - android:inputType="text" - android:maxLines="1" - android:lines="1" - android:layout_height="wrap_content" - android:layout_width="match_parent" - android:hint="Keyword" /> - - <Button - android:id="@+id/get_ad_button" - android:layout_height="wrap_content" - android:layout_width="wrap_content" - style="?android:attr/buttonBarButtonStyle" - android:text="@string/get_ad" /> - - <SurfaceView - android:id="@+id/rendered_view" - android:layout_width="200dp" - android:layout_height="200dp" /> -</LinearLayout> +<ScrollView xmlns:android="http://schemas.android.com/apk/res/android" + android:layout_width="match_parent" + android:layout_height="match_parent"> + + <LinearLayout xmlns:android="http://schemas.android.com/apk/res/android" + android:orientation="vertical" + android:layout_width="match_parent" + android:layout_height="match_parent"> + + <EditText + android:id="@+id/text_box" + android:inputType="text" + android:maxLines="1" + android:lines="1" + android:layout_height="wrap_content" + android:layout_width="match_parent" + android:hint="Keyword" /> + + <Button + android:id="@+id/get_ad_button" + android:layout_height="wrap_content" + android:layout_width="wrap_content" + style="?android:attr/buttonBarButtonStyle" + android:text="@string/get_ad" /> + + <EditText + android:id="@+id/schedule_training_text_box" + android:inputType="text" + android:maxLines="1" + android:lines="1" + android:layout_height="wrap_content" + android:layout_width="match_parent" + android:hint="Population" /> + + <Button + android:id="@+id/schedule_training_button" + android:layout_height="wrap_content" + android:layout_width="wrap_content" + style="?android:attr/buttonBarButtonStyle" + android:text="@string/schedule_training" /> + + <EditText + android:id="@+id/report_conversion_text_box" + android:inputType="text" + android:maxLines="1" + android:lines="1" + android:layout_height="wrap_content" + android:layout_width="match_parent" + android:hint="SourceAdId" /> + + <Button + android:id="@+id/report_conversion_button" + android:layout_height="wrap_content" + android:layout_width="wrap_content" + style="?android:attr/buttonBarButtonStyle" + android:text="@string/report_conversion" /> + + <Button + android:id="@+id/set_status_button" + android:layout_height="wrap_content" + android:layout_width="wrap_content" + style="?android:attr/buttonBarButtonStyle" + android:text="@string/set_status" /> + + <SurfaceView + android:id="@+id/rendered_view" + android:layout_width="200dp" + android:layout_height="200dp" /> + </LinearLayout> +</ScrollView>
\ No newline at end of file diff --git a/tests/endtoendtests/OdpClient/res/values/strings.xml b/tests/endtoendtests/OdpClient/res/values/strings.xml index bfb80f4f..54799014 100644 --- a/tests/endtoendtests/OdpClient/res/values/strings.xml +++ b/tests/endtoendtests/OdpClient/res/values/strings.xml @@ -18,4 +18,7 @@ <resources> <string name="title_activity_main" description="Launcher title">Odp Test</string> <string name="get_ad">Get Ad</string> + <string name="schedule_training">Schedule Training</string> + <string name="report_conversion">Report Conversion</string> + <string name="set_status">Set Status</string> </resources> diff --git a/tests/endtoendtests/OdpClient/src/com/example/odpclient/MainActivity.java b/tests/endtoendtests/OdpClient/src/com/example/odpclient/MainActivity.java index 5e84a5ee..2fe34124 100644 --- a/tests/endtoendtests/OdpClient/src/com/example/odpclient/MainActivity.java +++ b/tests/endtoendtests/OdpClient/src/com/example/odpclient/MainActivity.java @@ -16,6 +16,7 @@ package com.example.odpclient; +import android.adservices.ondevicepersonalization.OnDevicePersonalizationConfigManager; import android.adservices.ondevicepersonalization.OnDevicePersonalizationManager; import android.adservices.ondevicepersonalization.SurfacePackageToken; import android.app.Activity; @@ -34,20 +35,29 @@ import android.widget.Button; import android.widget.EditText; import android.widget.Toast; +import androidx.annotation.NonNull; + import java.util.List; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicReference; public class MainActivity extends Activity { private static final String TAG = "OdpClient"; private OnDevicePersonalizationManager mOdpManager = null; + private OnDevicePersonalizationConfigManager mOdpConfigManager = null; private EditText mTextBox; private Button mGetAdButton; + private EditText mScheduleTrainingTextBox; + private Button mScheduleTrainingButton; + private Button mSetStatusButton; + private EditText mReportConversionTextBox; + private Button mReportConversionButton; private SurfaceView mRenderedView; - private Context mContext; + private static Executor sCallbackExecutor = Executors.newSingleThreadExecutor(); @Override public void onCreate(Bundle savedInstanceState) { @@ -57,11 +67,23 @@ public class MainActivity extends Activity { if (mOdpManager == null) { mOdpManager = mContext.getSystemService(OnDevicePersonalizationManager.class); } + if (mOdpConfigManager == null) { + mOdpConfigManager = mContext.getSystemService( + OnDevicePersonalizationConfigManager.class); + } mRenderedView = findViewById(R.id.rendered_view); mRenderedView.setVisibility(View.INVISIBLE); mGetAdButton = findViewById(R.id.get_ad_button); + mScheduleTrainingButton = findViewById(R.id.schedule_training_button); + mSetStatusButton = findViewById(R.id.set_status_button); + mReportConversionButton = findViewById(R.id.report_conversion_button); mTextBox = findViewById(R.id.text_box); + mScheduleTrainingTextBox = findViewById(R.id.schedule_training_text_box); + mReportConversionTextBox = findViewById(R.id.report_conversion_text_box); registerGetAdButton(); + registerScheduleTrainingButton(); + registerSetStatusButton(); + registerReportConversionButton(); } private void registerGetAdButton() { @@ -69,6 +91,14 @@ public class MainActivity extends Activity { v -> makeRequest()); } + private void registerSetStatusButton() { + mSetStatusButton.setOnClickListener(v -> setPersonalizationStatus()); + } + + private void registerReportConversionButton() { + mReportConversionButton.setOnClickListener(v -> reportConversion()); + } + private void makeRequest() { try { if (mOdpManager == null) { @@ -85,7 +115,7 @@ public class MainActivity extends Activity { "com.example.odpsamplenetwork", "com.example.odpsamplenetwork.SampleService"), appParams, - Executors.newSingleThreadExecutor(), + sCallbackExecutor, new OutcomeReceiver<List<SurfacePackageToken>, Exception>() { @Override public void onResult(List<SurfacePackageToken> result) { @@ -112,7 +142,7 @@ public class MainActivity extends Activity { getDisplay().getDisplayId(), mRenderedView.getWidth(), mRenderedView.getHeight(), - Executors.newSingleThreadExecutor(), + sCallbackExecutor, new OutcomeReceiver<SurfacePackage, Exception>() { @Override public void onResult(SurfacePackage surfacePackage) { @@ -139,6 +169,101 @@ public class MainActivity extends Activity { } } + private void registerScheduleTrainingButton() { + mScheduleTrainingButton.setOnClickListener( + v -> scheduleTraining()); + } + + private void scheduleTraining() { + try { + if (mOdpManager == null) { + makeToast("OnDevicePersonalizationManager is null"); + return; + } + CountDownLatch latch = new CountDownLatch(1); + Log.i(TAG, "Starting execute()"); + PersistableBundle appParams = new PersistableBundle(); + appParams.putString("schedule_training", mScheduleTrainingTextBox.getText().toString()); + mOdpManager.execute( + ComponentName.createRelative( + "com.example.odpsamplenetwork", + "com.example.odpsamplenetwork.SampleService"), + appParams, + sCallbackExecutor, + new OutcomeReceiver<List<SurfacePackageToken>, Exception>() { + @Override + public void onResult(List<SurfacePackageToken> result) { + Log.i(TAG, "execute() success: " + result.size()); + latch.countDown(); + } + + @Override + public void onError(Exception e) { + makeToast("execute() error: " + e.toString()); + latch.countDown(); + } + }); + latch.await(); + } catch (Exception e) { + Log.e(TAG, "Error", e); + } + } + + private void reportConversion() { + try { + if (mOdpManager == null) { + makeToast("OnDevicePersonalizationManager is null"); + return; + } + CountDownLatch latch = new CountDownLatch(1); + Log.i(TAG, "Starting execute()"); + PersistableBundle appParams = new PersistableBundle(); + appParams.putString("conversion_ad_id", mReportConversionTextBox.getText().toString()); + mOdpManager.execute( + ComponentName.createRelative( + "com.example.odpsamplenetwork", + "com.example.odpsamplenetwork.SampleService"), + appParams, + sCallbackExecutor, + new OutcomeReceiver<List<SurfacePackageToken>, Exception>() { + @Override + public void onResult(List<SurfacePackageToken> result) { + Log.i(TAG, "execute() success: " + result.size()); + latch.countDown(); + } + + @Override + public void onError(Exception e) { + makeToast("execute() error: " + e.toString()); + latch.countDown(); + } + }); + latch.await(); + } catch (Exception e) { + Log.e(TAG, "Error", e); + } + } + + private void setPersonalizationStatus() { + if (mOdpConfigManager == null) { + makeToast("OnDevicePersonalizationConfigManager is null"); + } + boolean enabled = true; + mOdpConfigManager.setPersonalizationEnabled(enabled, + sCallbackExecutor, + new OutcomeReceiver<Void, Exception>() { + @Override + public void onResult(Void result) { + makeToast("Personalization status is set to " + enabled); + } + + @Override + public void onError(@NonNull Exception error) { + makeToast(error.getMessage()); + } + }); + } + private void makeToast(String message) { Log.i(TAG, message); runOnUiThread(() -> Toast.makeText(MainActivity.this, message, Toast.LENGTH_LONG).show()); diff --git a/tests/endtoendtests/OdpSampleNetwork/Android.bp b/tests/endtoendtests/OdpSampleNetwork/Android.bp index 535236b5..0231e241 100644 --- a/tests/endtoendtests/OdpSampleNetwork/Android.bp +++ b/tests/endtoendtests/OdpSampleNetwork/Android.bp @@ -18,7 +18,7 @@ package { android_test_helper_app { name: "OdpSampleNetwork", - defaults: ["platform_app_defaults"], + defaults: ["framework-ondevicepersonalization-test-defaults"], platform_apis: true, srcs: [ "src/**/*.java", @@ -32,6 +32,7 @@ android_test_helper_app { "androidx.core_core", "cuckoofilter", "guava", + "tensorflow_core_proto_java_lite", ], resource_dirs: [ "res", diff --git a/tests/endtoendtests/OdpSampleNetwork/AndroidManifest.xml b/tests/endtoendtests/OdpSampleNetwork/AndroidManifest.xml index da31dbef..3dce1e63 100644 --- a/tests/endtoendtests/OdpSampleNetwork/AndroidManifest.xml +++ b/tests/endtoendtests/OdpSampleNetwork/AndroidManifest.xml @@ -17,7 +17,7 @@ <manifest xmlns:android="http://schemas.android.com/apk/res/android" package="com.example.odpsamplenetwork" - android:versionName="1.0.0" > + android:versionName="1.0.2" > <application android:label="OdpSampleNetwork" android:debuggable="true"> <property android:name="android.ondevicepersonalization.ON_DEVICE_PERSONALIZATION_CONFIG" diff --git a/tests/endtoendtests/OdpSampleNetwork/res/xml/OdpSettings.xml b/tests/endtoendtests/OdpSampleNetwork/res/xml/OdpSettings.xml index 73ee0ffc..b6b03efe 100644 --- a/tests/endtoendtests/OdpSampleNetwork/res/xml/OdpSettings.xml +++ b/tests/endtoendtests/OdpSampleNetwork/res/xml/OdpSettings.xml @@ -19,5 +19,6 @@ <on-device-personalization> <service name="com.example.odpsamplenetwork.SampleService" > <download-settings url="android.resource://com.example.odpsamplenetwork/raw/test_data1" /> + <federated-compute-settings url="https://fcp.odp-androidtest.dev/" /> </service> </on-device-personalization> diff --git a/tests/endtoendtests/OdpSampleNetwork/src/com/example/odpsamplenetwork/SampleHandler.java b/tests/endtoendtests/OdpSampleNetwork/src/com/example/odpsamplenetwork/SampleHandler.java index 6aac011e..d0e255e5 100644 --- a/tests/endtoendtests/OdpSampleNetwork/src/com/example/odpsamplenetwork/SampleHandler.java +++ b/tests/endtoendtests/OdpSampleNetwork/src/com/example/odpsamplenetwork/SampleHandler.java @@ -16,24 +16,30 @@ package com.example.odpsamplenetwork; -import android.adservices.ondevicepersonalization.AppInstallInfo; -import android.adservices.ondevicepersonalization.DownloadInput; -import android.adservices.ondevicepersonalization.DownloadOutput; +import android.adservices.ondevicepersonalization.AppInfo; +import android.adservices.ondevicepersonalization.DownloadCompletedInput; +import android.adservices.ondevicepersonalization.DownloadCompletedOutput; +import android.adservices.ondevicepersonalization.EventInput; import android.adservices.ondevicepersonalization.EventLogRecord; +import android.adservices.ondevicepersonalization.EventOutput; import android.adservices.ondevicepersonalization.EventUrlProvider; import android.adservices.ondevicepersonalization.ExecuteInput; import android.adservices.ondevicepersonalization.ExecuteOutput; +import android.adservices.ondevicepersonalization.FederatedComputeInput; +import android.adservices.ondevicepersonalization.FederatedComputeScheduler; import android.adservices.ondevicepersonalization.IsolatedWorker; import android.adservices.ondevicepersonalization.KeyValueStore; +import android.adservices.ondevicepersonalization.LogReader; import android.adservices.ondevicepersonalization.RenderInput; import android.adservices.ondevicepersonalization.RenderOutput; import android.adservices.ondevicepersonalization.RenderingConfig; import android.adservices.ondevicepersonalization.RequestLogRecord; +import android.adservices.ondevicepersonalization.TrainingExampleInput; +import android.adservices.ondevicepersonalization.TrainingExampleOutput; +import android.adservices.ondevicepersonalization.TrainingInterval; import android.adservices.ondevicepersonalization.UserData; -import android.adservices.ondevicepersonalization.WebViewEventInput; -import android.adservices.ondevicepersonalization.WebViewEventOutput; -import android.annotation.NonNull; import android.content.ContentValues; +import android.net.Uri; import android.os.PersistableBundle; import android.os.Process; import android.os.StrictMode; @@ -42,17 +48,28 @@ import android.util.Base64; import android.util.JsonReader; import android.util.Log; +import androidx.annotation.NonNull; + +import com.google.common.base.Strings; import com.google.common.util.concurrent.FluentFuture; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.ThreadFactoryBuilder; +import com.google.protobuf.ByteString; import com.google.setfilters.cuckoofilter.CuckooFilter; +import org.tensorflow.example.BytesList; +import org.tensorflow.example.Example; +import org.tensorflow.example.Feature; +import org.tensorflow.example.Features; +import org.tensorflow.example.Int64List; + import java.io.IOException; import java.io.StringReader; import java.nio.charset.StandardCharsets; +import java.time.Duration; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -65,12 +82,14 @@ public class SampleHandler implements IsolatedWorker { public static final String TAG = "OdpSampleNetwork"; public static final int EVENT_TYPE_IMPRESSION = 1; public static final int EVENT_TYPE_CLICK = 2; + public static final int EVENT_TYPE_CONVERSION = 3; public static final double COST_RAISING_FACTOR = 2.0; private static final String AD_ID_KEY = "adid"; private static final String BID_PRICE_KEY = "price"; private static final String AUCTION_SCORE_KEY = "score"; private static final String CLICK_COST_KEY = "clkcost"; private static final String EVENT_TYPE_KEY = "type"; + private static final String SOURCE_TYPE_KEY = "sourcetype"; private static final int BID_PRICE_OFFSET = 0; private static final String TRANSPARENT_PNG_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAA" @@ -86,12 +105,17 @@ public class SampleHandler implements IsolatedWorker { private final KeyValueStore mRemoteData; private final EventUrlProvider mEventUrlProvider; private final UserData mUserData; + private final FederatedComputeScheduler mFCScheduler; + private final LogReader mLogReader; SampleHandler(KeyValueStore remoteData, EventUrlProvider eventUrlProvider, - UserData userData) { + UserData userData, FederatedComputeScheduler fcScheduler, + LogReader logReader) { mRemoteData = remoteData; mEventUrlProvider = eventUrlProvider; mUserData = userData; + mFCScheduler = fcScheduler; + mLogReader = logReader; if (mRemoteData == null) { Log.e(TAG, "RemoteData missing"); } @@ -101,15 +125,21 @@ public class SampleHandler implements IsolatedWorker { if (mUserData == null) { Log.e(TAG, "UserData missing"); } + if (mFCScheduler == null) { + Log.e(TAG, "Federated Compute Scheduler missing"); + } + if (mLogReader == null) { + Log.e(TAG, "LogReader missing"); + } } @Override - public void onDownload( - @NonNull DownloadInput input, - @NonNull Consumer<DownloadOutput> consumer) { + public void onDownloadCompleted( + @NonNull DownloadCompletedInput input, + @NonNull Consumer<DownloadCompletedOutput> consumer) { Log.d(TAG, "onDownload() started."); - DownloadOutput downloadResult = - new DownloadOutput.Builder() + DownloadCompletedOutput downloadResult = + new DownloadCompletedOutput.Builder() .setRetainedKeys(getFilteredKeys(input.getData())) .build(); consumer.accept(downloadResult); @@ -123,6 +153,13 @@ public class SampleHandler implements IsolatedWorker { sBackgroundExecutor.execute(() -> handleOnExecute(input, consumer)); } + @Override public void onTrainingExample( + @NonNull TrainingExampleInput input, + @NonNull Consumer<TrainingExampleOutput> consumer) { + Log.d(TAG, "onTrainingExample() started."); + sBackgroundExecutor.execute(() -> handleOnTrainingExample(input, consumer)); + } + @Override public void onRender( @NonNull RenderInput input, @NonNull Consumer<RenderOutput> consumer @@ -131,10 +168,10 @@ public class SampleHandler implements IsolatedWorker { sBackgroundExecutor.execute(() -> handleOnRender(input, consumer)); } - @Override public void onWebViewEvent( - @NonNull WebViewEventInput input, - @NonNull Consumer<WebViewEventOutput> consumer) { - Log.d(TAG, "onWebViewEvent() started."); + @Override public void onEvent( + @NonNull EventInput input, + @NonNull Consumer<EventOutput> consumer) { + Log.d(TAG, "onEvent() started."); sBackgroundExecutor.execute( () -> handleOnWebViewEvent(input, consumer)); } @@ -144,6 +181,9 @@ public class SampleHandler implements IsolatedWorker { try { ArrayList<Ad> ads = new ArrayList<>(); for (var key: remoteData.keySet()) { + if (!key.startsWith("ad")) { + continue; + } Ad ad = parseAd(key, remoteData.get(key)); if (ad != null) { ads.add(ad); @@ -242,30 +282,127 @@ public class SampleHandler implements IsolatedWorker { .build(); } + static Feature convertStringToFeature(String value) { + BytesList.Builder bytesListBuilder = BytesList.newBuilder(); + String nonNullValue = Strings.nullToEmpty(value); + bytesListBuilder.addValue(ByteString.copyFromUtf8(nonNullValue)); + return Feature.newBuilder().setBytesList(bytesListBuilder.build()).build(); + } + + static Feature convertLongToFeature(long value) { + return Feature.newBuilder() + .setInt64List(Int64List.newBuilder().addValue(value).build()) + .build(); + } + + private void handleOnTrainingExample( + @NonNull TrainingExampleInput input, + @NonNull Consumer<TrainingExampleOutput> consumer) { + Features.Builder featuresBuilder = Features.newBuilder(); + + featuresBuilder.putFeature("int-feature-1", convertLongToFeature(0L)); + featuresBuilder.putFeature("int-feature-2", convertLongToFeature(0L)); + featuresBuilder.putFeature("int-feature-3", convertLongToFeature(0L)); + featuresBuilder.putFeature("int-feature-4", convertLongToFeature(0L)); + featuresBuilder.putFeature("int-feature-5", convertLongToFeature(0L)); + featuresBuilder.putFeature("int-feature-6", convertLongToFeature(0L)); + featuresBuilder.putFeature("int-feature-7", convertLongToFeature(0L)); + featuresBuilder.putFeature("int-feature-8", convertLongToFeature(0L)); + featuresBuilder.putFeature("int-feature-9", convertLongToFeature(0L)); + featuresBuilder.putFeature("int-feature-10", convertLongToFeature(0L)); + featuresBuilder.putFeature("int-feature-11", convertLongToFeature(0L)); + featuresBuilder.putFeature("int-feature-12", convertLongToFeature(0L)); + featuresBuilder.putFeature("int-feature-13", convertLongToFeature(0L)); + + featuresBuilder.putFeature("categorical-feature-14", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-15", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-16", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-17", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-18", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-19", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-20", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-21", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-22", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-23", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-24", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-25", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-26", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-27", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-28", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-29", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-30", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-31", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-32", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-33", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-34", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-35", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-36", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-37", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-38", convertStringToFeature("")); + featuresBuilder.putFeature("categorical-feature-39", convertStringToFeature("")); + + featuresBuilder.putFeature("clicked", convertLongToFeature(1L)); + + Example example = Example.newBuilder().setFeatures(featuresBuilder.build()).build(); + TrainingExampleOutput result = new TrainingExampleOutput + .Builder() + .addTrainingExample(example.toByteArray()) + .addResumptionToken("token1".getBytes()).build(); + consumer.accept(result); + } + private void handleOnExecute( @NonNull ExecuteInput input, @NonNull Consumer<ExecuteOutput> consumer ) { try { - var unused = FluentFuture.from(readAds(mRemoteData)) - .transform( - ads -> buildResult(runAuction(matchAds(ads, input))), - sBackgroundExecutor) - .transform( - result -> { - consumer.accept(result); - return null; - }, - MoreExecutors.directExecutor()) - .catching( - Exception.class, - e -> { - Log.e(TAG, "Execution failed.", e); - consumer.accept(null); - return null; - }, - MoreExecutors.directExecutor()); - + if (input != null && input.getAppParams() != null + && input.getAppParams().getString("schedule_training") != null) { + if (input.getAppParams().getString("schedule_training").isEmpty()) { + consumer.accept(null); + return; + } + TrainingInterval interval = new TrainingInterval.Builder() + .setMinimumInterval(Duration.ofSeconds(10)) + .setSchedulingMode(2) + .build(); + FederatedComputeScheduler.Params params = new FederatedComputeScheduler + .Params(interval); + FederatedComputeInput fcInput = new FederatedComputeInput.Builder() + .setPopulationName(input.getAppParams().getString("schedule_training")) + .build(); + mFCScheduler.schedule(params, fcInput); + + ExecuteOutput result = new ExecuteOutput.Builder().build(); + consumer.accept(result); + } else if (input != null && input.getAppParams() != null + && input.getAppParams().getString("conversion_ad_id") != null) { + try { + consumer.accept(handleConversion(input)); + } catch (Exception e) { + consumer.accept(null); + return; + } + } else { + var unused = FluentFuture.from(readAds(mRemoteData)) + .transform( + ads -> buildResult(runAuction(matchAds(ads, input))), + sBackgroundExecutor) + .transform( + result -> { + consumer.accept(result); + return null; + }, + MoreExecutors.directExecutor()) + .catching( + Exception.class, + e -> { + Log.e(TAG, "Execution failed.", e); + consumer.accept(null); + return null; + }, + MoreExecutors.directExecutor()); + } } catch (Exception e) { Log.e(TAG, "handleOnExecute() failed", e); consumer.accept(null); @@ -276,7 +413,7 @@ public class SampleHandler implements IsolatedWorker { try { PersistableBundle eventParams = new PersistableBundle(); eventParams.putInt(EVENT_TYPE_KEY, EVENT_TYPE_IMPRESSION); - String url = mEventUrlProvider.getEventTrackingUrl( + String url = mEventUrlProvider.createEventTrackingUrlWithResponse( eventParams, TRANSPARENT_PNG_BYTES, "image/png").toString(); return Futures.immediateFuture(url); } catch (Exception e) { @@ -289,8 +426,8 @@ public class SampleHandler implements IsolatedWorker { try { PersistableBundle eventParams = new PersistableBundle(); eventParams.putInt(EVENT_TYPE_KEY, EVENT_TYPE_CLICK); - String url = mEventUrlProvider.getEventTrackingUrlWithRedirect( - eventParams, landingPage).toString(); + String url = mEventUrlProvider.createEventTrackingUrlWithRedirect( + eventParams, Uri.parse(landingPage)).toString(); return Futures.immediateFuture(url); } catch (Exception e) { return Futures.immediateFailedFuture(e); @@ -325,6 +462,45 @@ public class SampleHandler implements IsolatedWorker { } } + private ExecuteOutput handleConversion(ExecuteInput input) { + String adId = input.getAppParams().getString("conversion_ad_id"); + if (adId.isEmpty()) { + return null; + } + long now = System.currentTimeMillis(); + List<EventLogRecord> logRecords = mLogReader.getJoinedEvents( + now - 24 * 60 * 60 * 1000, now); + EventLogRecord found = null; + // Attribute conversion to most recent impression or click. + for (EventLogRecord ev : logRecords) { + RequestLogRecord req = ev.getRequestLogRecord(); + if (req == null || req.getRows() == null + || req.getRows().size() <= ev.getRowIndex() + || req.getRows().get(ev.getRowIndex()) == null) { + continue; + } + String reqAdId = (String) req.getRows().get(ev.getRowIndex()).get(AD_ID_KEY); + if (adId.equals(reqAdId)) { + if (found == null || found.getTimeMillis() < ev.getTimeMillis()) { + found = ev; + } + } + } + var builder = new ExecuteOutput.Builder(); + if (found != null) { + ContentValues values = new ContentValues(); + values.put(SOURCE_TYPE_KEY, found.getType()); + EventLogRecord conv = new EventLogRecord.Builder() + .setType(EVENT_TYPE_CONVERSION) + .setData(values) + .setRowIndex(found.getRowIndex()) + .setRequestLogRecord(found.getRequestLogRecord()) + .build(); + builder.addEventLogRecord(conv); + } + return builder.build(); + } + private void handleOnRender( @NonNull RenderInput input, @NonNull Consumer<RenderOutput> consumer @@ -366,14 +542,14 @@ public class SampleHandler implements IsolatedWorker { } public void handleOnWebViewEvent( - @NonNull WebViewEventInput input, - @NonNull Consumer<WebViewEventOutput> consumer) { + @NonNull EventInput input, + @NonNull Consumer<EventOutput> consumer) { try { Log.d(TAG, "handleOnEvent() started."); PersistableBundle eventParams = input.getParameters(); int eventType = eventParams.getInt(EVENT_TYPE_KEY); if (eventType <= 0) { - consumer.accept(new WebViewEventOutput.Builder().build()); + consumer.accept(new EventOutput.Builder().build()); return; } ContentValues logData = null; @@ -392,7 +568,7 @@ public class SampleHandler implements IsolatedWorker { logData = new ContentValues(); logData.put(CLICK_COST_KEY, updatedPrice); } - WebViewEventOutput result = new WebViewEventOutput.Builder() + EventOutput result = new EventOutput.Builder() .setEventLogRecord( new EventLogRecord.Builder() .setRowIndex(0) @@ -412,8 +588,7 @@ public class SampleHandler implements IsolatedWorker { return false; } - if (mUserData.getAppInstallInfo() == null - || mUserData.getAppInstallInfo().isEmpty()) { + if (mUserData.getAppInfos() == null || mUserData.getAppInfos().isEmpty()) { Log.i(TAG, "No installed apps."); return false; } @@ -422,8 +597,8 @@ public class SampleHandler implements IsolatedWorker { return false; } - for (String app: mUserData.getAppInstallInfo().keySet()) { - AppInstallInfo value = mUserData.getAppInstallInfo().get(app); + for (String app: mUserData.getAppInfos().keySet()) { + AppInfo value = mUserData.getAppInfos().get(app); if (value != null && value.isInstalled() && filter.contains(app)) { return true; } @@ -438,8 +613,7 @@ public class SampleHandler implements IsolatedWorker { return false; } - if (mUserData.getAppInstallInfo() == null - || mUserData.getAppInstallInfo().isEmpty()) { + if (mUserData.getAppInfos() == null || mUserData.getAppInfos().isEmpty()) { Log.i(TAG, "No installed apps."); return false; } @@ -448,7 +622,7 @@ public class SampleHandler implements IsolatedWorker { return false; } - for (String app: mUserData.getAppInstallInfo().keySet()) { + for (String app: mUserData.getAppInfos().keySet()) { if (apps.contains(app)) { return true; } diff --git a/tests/endtoendtests/OdpSampleNetwork/src/com/example/odpsamplenetwork/SampleService.java b/tests/endtoendtests/OdpSampleNetwork/src/com/example/odpsamplenetwork/SampleService.java index 4cd090f7..9ff7eeda 100644 --- a/tests/endtoendtests/OdpSampleNetwork/src/com/example/odpsamplenetwork/SampleService.java +++ b/tests/endtoendtests/OdpSampleNetwork/src/com/example/odpsamplenetwork/SampleService.java @@ -19,12 +19,14 @@ package com.example.odpsamplenetwork; import android.adservices.ondevicepersonalization.IsolatedService; import android.adservices.ondevicepersonalization.IsolatedWorker; import android.adservices.ondevicepersonalization.RequestToken; -import android.annotation.NonNull; + +import androidx.annotation.NonNull; public class SampleService extends IsolatedService { @NonNull @Override public IsolatedWorker onRequest( RequestToken requestToken) { return new SampleHandler(getRemoteData(requestToken), getEventUrlProvider(requestToken), - getUserData(requestToken)); + getUserData(requestToken), getFederatedComputeScheduler(requestToken), + getLogReader(requestToken)); } } diff --git a/tests/endtoendtests/OdpTargetingApp1/AndroidManifest.xml b/tests/endtoendtests/OdpTargetingApp1/AndroidManifest.xml index 662f7643..8ad3364a 100644 --- a/tests/endtoendtests/OdpTargetingApp1/AndroidManifest.xml +++ b/tests/endtoendtests/OdpTargetingApp1/AndroidManifest.xml @@ -17,7 +17,7 @@ <manifest xmlns:android="http://schemas.android.com/apk/res/android" package="com.example.odptargetingapp1" - android:versionName="1.0.0" > + android:versionName="1.0.1" > <application android:label="ODP Targeting App 1"> </application> diff --git a/tests/endtoendtests/OdpTargetingApp2/AndroidManifest.xml b/tests/endtoendtests/OdpTargetingApp2/AndroidManifest.xml index 892dc212..90439c23 100644 --- a/tests/endtoendtests/OdpTargetingApp2/AndroidManifest.xml +++ b/tests/endtoendtests/OdpTargetingApp2/AndroidManifest.xml @@ -17,7 +17,7 @@ <manifest xmlns:android="http://schemas.android.com/apk/res/android" package="com.example.odptargetingapp2" - android:versionName="1.0.0" > + android:versionName="1.0.1" > <application android:label="ODP Targeting App 2"> </application> diff --git a/tests/federatedcomputetests/Android.bp b/tests/federatedcomputetests/Android.bp index 4268aa32..bfc77aa5 100644 --- a/tests/federatedcomputetests/Android.bp +++ b/tests/federatedcomputetests/Android.bp @@ -22,7 +22,8 @@ android_test { "src/**/*.java", ":federatedcompute-sources", ":federatedcompute-fbs", - + ":statslog-federatedcompute-java-gen", + ":fcp_native_wrapper", ], libs: [ "android.test.base", @@ -31,6 +32,7 @@ android_test { "framework-annotations-lib", "framework-ondevicepersonalization.impl", "framework-configinfrastructure", // For PH flags + "framework-statsd.stubs.module_lib", // For WW logging "truth-prebuilt", ], static_libs: [ @@ -42,7 +44,6 @@ android_test { "mockito-target-extended-minus-junit4", "modules-utils-build", "modules-utils-preconditions", - "flatbuffers-java", "libprotobuf-java-lite", "tensorflow_core_proto_java_lite", ], @@ -61,5 +62,6 @@ android_test { jni_libs: [ "libdexmakerjvmtiagent", "libstaticjvmtiagent", + "libfcp_cpp_dep_jni", ], } diff --git a/tests/federatedcomputetests/AndroidManifest.xml b/tests/federatedcomputetests/AndroidManifest.xml index 5d4087e5..9767a489 100644 --- a/tests/federatedcomputetests/AndroidManifest.xml +++ b/tests/federatedcomputetests/AndroidManifest.xml @@ -23,6 +23,7 @@ <uses-permission android:name="android.permission.INTERNET" /> <!-- Used for persisting scheduled jobs --> <uses-permission android:name="android.permission.RECEIVE_BOOT_COMPLETED" /> + <uses-permission android:name="android.permission.BIND_EXAMPLE_STORE_SERVICE" /> <application android:label="FederatedComputeServicesTests" android:debuggable="true"> @@ -34,12 +35,12 @@ <service android:name="com.android.federatedcompute.services.training.IsolatedTrainingService" android:isolatedProcess="true" android:exported="false" > </service> - <service android:name="com.android.federatedcompute.services.examplestore.SampleExampleStoreService" - android:enabled="true" - android:exported="true"> + <service + android:name="com.android.federatedcompute.services.examplestore.SampleExampleStoreService" + android:enabled="true" + android:exported="true"> <intent-filter> <action android:name="android.federatedcompute.EXAMPLE_STORE" /> - <data android:scheme="app" /> </intent-filter> </service> <service android:name="com.android.federatedcompute.services.training.SampleResultHandlingService" @@ -47,9 +48,13 @@ android:exported="true"> <intent-filter> <action android:name="android.federatedcompute.COMPUTATION_RESULT" /> - <data android:scheme="app" /> </intent-filter> </service> + <service android:name="com.android.federatedcompute.services.training.FederatedIsolatedTrainingService" + android:exported="false" + android:process=":remote" + android:isolatedProcess="true"> + </service> </application> <instrumentation android:name="androidx.test.runner.AndroidJUnitRunner" android:targetPackage="com.android.ondevicepersonalization.federatedcomputetests" diff --git a/tests/federatedcomputetests/res/raw/federation_client_only_plan.pb b/tests/federatedcomputetests/res/raw/federation_client_only_plan.pb Binary files differnew file mode 100644 index 00000000..03a0945c --- /dev/null +++ b/tests/federatedcomputetests/res/raw/federation_client_only_plan.pb diff --git a/tests/federatedcomputetests/res/raw/federation_proxy_train_examples.pb b/tests/federatedcomputetests/res/raw/federation_proxy_train_examples.pb Binary files differnew file mode 100644 index 00000000..36237420 --- /dev/null +++ b/tests/federatedcomputetests/res/raw/federation_proxy_train_examples.pb diff --git a/tests/federatedcomputetests/res/raw/federation_test_checkpoint_client.ckp b/tests/federatedcomputetests/res/raw/federation_test_checkpoint_client.ckp Binary files differnew file mode 100644 index 00000000..1a8858c4 --- /dev/null +++ b/tests/federatedcomputetests/res/raw/federation_test_checkpoint_client.ckp diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/FederatedComputeManagingServiceDelegateTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/FederatedComputeManagingServiceDelegateTest.java index e9d6568b..64a70001 100644 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/FederatedComputeManagingServiceDelegateTest.java +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/FederatedComputeManagingServiceDelegateTest.java @@ -16,7 +16,21 @@ package com.android.federatedcompute.services; +import static android.federatedcompute.common.ClientConstants.STATUS_INTERNAL_ERROR; +import static android.federatedcompute.common.ClientConstants.STATUS_SUCCESS; + +import static com.android.dx.mockito.inline.extended.ExtendedMockito.doAnswer; +import static com.android.federatedcompute.services.stats.FederatedComputeStatsLog.FEDERATED_COMPUTE_API_CALLED__API_NAME__CANCEL; +import static com.android.federatedcompute.services.stats.FederatedComputeStatsLog.FEDERATED_COMPUTE_API_CALLED__API_NAME__SCHEDULE; + +import static com.google.common.truth.Truth.assertThat; + import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import android.content.Context; import android.federatedcompute.aidl.IFederatedComputeCallback; @@ -24,23 +38,37 @@ import android.federatedcompute.common.TrainingOptions; import androidx.test.core.app.ApplicationProvider; +import com.android.federatedcompute.services.common.Clock; import com.android.federatedcompute.services.common.PhFlagsTestUtil; import com.android.federatedcompute.services.scheduling.FederatedComputeJobManager; +import com.android.federatedcompute.services.statsd.ApiCallStats; +import com.android.federatedcompute.services.statsd.FederatedComputeStatsdLogger; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; @RunWith(JUnit4.class) public final class FederatedComputeManagingServiceDelegateTest { + private static final int BINDER_CONNECTION_TIMEOUT_MS = 10_000; + private FederatedComputeManagingServiceDelegate mFcpService; private Context mContext; + private final FederatedComputeStatsdLogger mFcStatsdLogger = + spy(FederatedComputeStatsdLogger.getInstance()); + private static final CountDownLatch sJobFinishCountDown = new CountDownLatch(1); + @Mock FederatedComputeJobManager mMockJobManager; + @Mock private Clock mClock; @Before public void setUp() throws Exception { @@ -49,38 +77,48 @@ public final class FederatedComputeManagingServiceDelegateTest { PhFlagsTestUtil.disableGlobalKillSwitch(); mContext = ApplicationProvider.getApplicationContext(); - mFcpService = new FederatedComputeManagingServiceDelegate(mContext, new TestInjector()); + mFcpService = + new FederatedComputeManagingServiceDelegate( + mContext, new TestInjector(), mFcStatsdLogger, mClock); + when(mClock.elapsedRealtime()).thenReturn(100L, 200L); } @Test - public void testScheduleMissingPackageName_throwsException() throws Exception { + public void testScheduleMissingPackageName_throwsException() { TrainingOptions trainingOptions = new TrainingOptions.Builder().setPopulationName("fake-population").build(); assertThrows( NullPointerException.class, - () -> - mFcpService.scheduleFederatedCompute( - null, trainingOptions, new FederatedComputeCallback())); + () -> mFcpService.schedule(null, trainingOptions, new FederatedComputeCallback())); } @Test - public void testScheduleMissingCallback_throwsException() throws Exception { + public void testScheduleMissingCallback_throwsException() { TrainingOptions trainingOptions = new TrainingOptions.Builder().setPopulationName("fake-population").build(); assertThrows( NullPointerException.class, - () -> - mFcpService.scheduleFederatedCompute( - mContext.getPackageName(), trainingOptions, null)); + () -> mFcpService.schedule(mContext.getPackageName(), trainingOptions, null)); } @Test public void testSchedule_returnsSuccess() throws Exception { + when(mMockJobManager.onTrainerStartCalled(anyString(), any())).thenReturn(STATUS_SUCCESS); + TrainingOptions trainingOptions = new TrainingOptions.Builder().setPopulationName("fake-population").build(); - mFcpService.scheduleFederatedCompute( - mContext.getPackageName(), trainingOptions, new FederatedComputeCallback()); + invokeScheduleAndVerifyLogging(trainingOptions, STATUS_SUCCESS); + } + + @Test + public void testScheduleFailed() throws Exception { + when(mMockJobManager.onTrainerStartCalled(anyString(), any())) + .thenReturn(STATUS_INTERNAL_ERROR); + + TrainingOptions trainingOptions = + new TrainingOptions.Builder().setPopulationName("fake-population").build(); + invokeScheduleAndVerifyLogging(trainingOptions, STATUS_INTERNAL_ERROR); } @Test @@ -92,7 +130,7 @@ public final class FederatedComputeManagingServiceDelegateTest { assertThrows( IllegalStateException.class, () -> - mFcpService.scheduleFederatedCompute( + mFcpService.schedule( mContext.getPackageName(), trainingOptions, new FederatedComputeCallback())); @@ -101,25 +139,124 @@ public final class FederatedComputeManagingServiceDelegateTest { } } + @Test + public void testCancelMissingPackageName_throwsException() { + assertThrows( + NullPointerException.class, + () -> mFcpService.cancel(null, "fake-population", new FederatedComputeCallback())); + } + + @Test + public void testCancelMissingCallback_throwsException() { + assertThrows( + NullPointerException.class, + () -> mFcpService.cancel(mContext.getPackageName(), "fake-population", null)); + } + + @Test + public void testCancel_returnsSuccess() throws Exception { + when(mMockJobManager.onTrainerStopCalled(anyString(), anyString())) + .thenReturn(STATUS_SUCCESS); + + invokeCancelAndVerifyLogging("fake-population", STATUS_SUCCESS); + } + + @Test + public void testCancelFails() throws Exception { + when(mMockJobManager.onTrainerStopCalled(anyString(), anyString())) + .thenReturn(STATUS_INTERNAL_ERROR); + + invokeCancelAndVerifyLogging("fake-population", STATUS_INTERNAL_ERROR); + } + + @Test + public void testCancelEnabledGlobalKillSwitch_throwsException() { + PhFlagsTestUtil.enableGlobalKillSwitch(); + try { + assertThrows( + IllegalStateException.class, + () -> + mFcpService.cancel( + mContext.getPackageName(), + "fake-population", + new FederatedComputeCallback())); + } finally { + PhFlagsTestUtil.disableGlobalKillSwitch(); + } + } + + private void invokeScheduleAndVerifyLogging( + TrainingOptions trainingOptions, int expectedResultCode) throws InterruptedException { + mFcpService.schedule( + mContext.getPackageName(), trainingOptions, new FederatedComputeCallback()); + + final CountDownLatch logOperationCalledLatch = new CountDownLatch(1); + doAnswer( + new Answer<Object>() { + @Override + public Object answer(InvocationOnMock invocation) throws Throwable { + // The method logAPiCallStats is called. + invocation.callRealMethod(); + logOperationCalledLatch.countDown(); + return null; + } + }) + .when(mFcStatsdLogger) + .logApiCallStats(any(ApiCallStats.class)); + sJobFinishCountDown.await(BINDER_CONNECTION_TIMEOUT_MS, TimeUnit.MILLISECONDS); + logOperationCalledLatch.await(BINDER_CONNECTION_TIMEOUT_MS, TimeUnit.MILLISECONDS); + + ArgumentCaptor<ApiCallStats> argument = ArgumentCaptor.forClass(ApiCallStats.class); + verify(mFcStatsdLogger).logApiCallStats(argument.capture()); + assertThat(argument.getValue().getResponseCode()).isEqualTo(expectedResultCode); + assertThat(argument.getValue().getLatencyMillis()).isEqualTo(100); + assertThat(argument.getValue().getApiName()) + .isEqualTo(FEDERATED_COMPUTE_API_CALLED__API_NAME__SCHEDULE); + } + + private void invokeCancelAndVerifyLogging(String populationName, int expectedResultCode) + throws InterruptedException { + mFcpService.cancel( + mContext.getPackageName(), populationName, new FederatedComputeCallback()); + + final CountDownLatch logOperationCalledLatch = new CountDownLatch(1); + doAnswer( + new Answer<Object>() { + @Override + public Object answer(InvocationOnMock invocation) throws Throwable { + // The method logAPiCallStats is called. + invocation.callRealMethod(); + logOperationCalledLatch.countDown(); + return null; + } + }) + .when(mFcStatsdLogger) + .logApiCallStats(any(ApiCallStats.class)); + sJobFinishCountDown.await(BINDER_CONNECTION_TIMEOUT_MS, TimeUnit.MILLISECONDS); + logOperationCalledLatch.await(BINDER_CONNECTION_TIMEOUT_MS, TimeUnit.MILLISECONDS); + + ArgumentCaptor<ApiCallStats> argument = ArgumentCaptor.forClass(ApiCallStats.class); + verify(mFcStatsdLogger).logApiCallStats(argument.capture()); + assertThat(argument.getValue().getResponseCode()).isEqualTo(expectedResultCode); + assertThat(argument.getValue().getLatencyMillis()).isEqualTo(100); + assertThat(argument.getValue().getApiName()) + .isEqualTo(FEDERATED_COMPUTE_API_CALLED__API_NAME__CANCEL); + } + static class FederatedComputeCallback extends IFederatedComputeCallback.Stub { public boolean mError = false; public int mErrorCode = 0; - private CountDownLatch mLatch = new CountDownLatch(1); @Override public void onSuccess() { - mLatch.countDown(); + sJobFinishCountDown.countDown(); } @Override public void onFailure(int errorCode) { mError = true; mErrorCode = errorCode; - mLatch.countDown(); - } - - public void await() throws Exception { - mLatch.await(); + sJobFinishCountDown.countDown(); } } diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/IsolatedServiceBinderTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/IsolatedServiceBinderTest.java new file mode 100644 index 00000000..12896c54 --- /dev/null +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/IsolatedServiceBinderTest.java @@ -0,0 +1,48 @@ +/* + * Copyright 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.services; + +import static com.android.federatedcompute.services.common.Constants.ISOLATED_TRAINING_SERVICE_NAME; + +import static org.junit.Assert.assertNotNull; + +import android.content.Context; + +import androidx.test.core.app.ApplicationProvider; + +import com.android.federatedcompute.internal.util.AbstractServiceBinder; +import com.android.federatedcompute.services.training.aidl.IIsolatedTrainingService; + +import org.junit.Test; + +public class IsolatedServiceBinderTest { + + private final Context mContext = ApplicationProvider.getApplicationContext(); + + @Test + public void testFcpServiceBindingByName() { + AbstractServiceBinder<IIsolatedTrainingService> serviceBinder = + AbstractServiceBinder.getServiceBinderByServiceName( + mContext, + ISOLATED_TRAINING_SERVICE_NAME, + mContext.getPackageName(), + IIsolatedTrainingService.Stub::asInterface); + + final IIsolatedTrainingService service = serviceBinder.getService(Runnable::run); + assertNotNull(service); + } +} diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/data/FederatedTrainingTaskDbHelperTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/data/FederatedComputeDbHelperTest.java index 15ad83c9..651fb7fe 100644 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/data/FederatedTrainingTaskDbHelperTest.java +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/data/FederatedComputeDbHelperTest.java @@ -17,10 +17,12 @@ package com.android.federatedcompute.services.data; import static com.android.federatedcompute.services.data.FederatedTraningTaskContract.FEDERATED_TRAINING_TASKS_TABLE; +import static com.android.federatedcompute.services.data.FederatedComputeEncryptionKeyContract.ENCRYPTION_KEY_TABLE; import static com.google.common.truth.Truth.assertThat; import android.content.Context; +import android.database.Cursor; import android.database.DatabaseUtils; import android.database.sqlite.SQLiteDatabase; @@ -32,8 +34,10 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; +import java.util.ArrayList; + @RunWith(AndroidJUnit4.class) -public final class FederatedTrainingTaskDbHelperTest { +public final class FederatedComputeDbHelperTest { private Context mContext; @Before @@ -43,17 +47,30 @@ public final class FederatedTrainingTaskDbHelperTest { @After public void cleanUp() throws Exception { - FederatedTrainingTaskDbHelper.resetInstance(); + FederatedComputeDbHelper.resetInstance(); } @Test public void onCreate() { - FederatedTrainingTaskDbHelper dbHelper = - FederatedTrainingTaskDbHelper.getInstanceForTest(mContext); + FederatedComputeDbHelper dbHelper = FederatedComputeDbHelper.getInstanceForTest(mContext); SQLiteDatabase db = dbHelper.getReadableDatabase(); assertThat(db).isNotNull(); assertThat(DatabaseUtils.queryNumEntries(db, FEDERATED_TRAINING_TASKS_TABLE)).isEqualTo(0); + assertThat(DatabaseUtils.queryNumEntries(db, ENCRYPTION_KEY_TABLE)).isEqualTo(0); + + // query number of tables + ArrayList<String> tableNames = new ArrayList<String>(); + Cursor cursor = db.rawQuery("SELECT name FROM sqlite_master WHERE type='table'", null); + if (cursor.moveToFirst()) { + while (!cursor.isAfterLast()) { + tableNames.add(cursor.getString(cursor.getColumnIndex("name"))); + cursor.moveToNext(); + } + } + String[] expectedTables = {FEDERATED_TRAINING_TASKS_TABLE, ENCRYPTION_KEY_TABLE}; + // android metadata table also exists in the database + assertThat(tableNames).containsAtLeastElementsIn(expectedTables); } } diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/data/FederatedComputeEncryptionKeyDaoTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/data/FederatedComputeEncryptionKeyDaoTest.java new file mode 100644 index 00000000..37e22944 --- /dev/null +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/data/FederatedComputeEncryptionKeyDaoTest.java @@ -0,0 +1,259 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.services.data; + +import static com.android.federatedcompute.services.data.FederatedComputeEncryptionKeyContract.ENCRYPTION_KEY_TABLE; + +import static com.google.common.truth.Truth.assertThat; + +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import android.content.Context; +import android.database.DatabaseUtils; +import android.database.sqlite.SQLiteDatabase; + +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; + +import com.android.federatedcompute.services.common.Clock; +import com.android.federatedcompute.services.common.MonotonicClock; +import com.android.federatedcompute.services.data.FederatedComputeEncryptionKeyContract.FederatedComputeEncryptionColumns; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.util.List; +import java.util.Random; +import java.util.UUID; + +@RunWith(AndroidJUnit4.class) +public class FederatedComputeEncryptionKeyDaoTest { + private static final String KEY_ID = "0962201a-5abd-4e25-a486-2c7bd1ee1887"; + private static final String PUBLICKEY = "GOcMAnY4WkDYp6R3WSw8IpYK6eVe2RGZ9Z0OBb3EbjQ\\u003d"; + private static final int KEY_TYPE = FederatedComputeEncryptionKey.KEY_TYPE_ENCRYPTION; + private static final long NOW = 1698193647L; + private static final long TTL = 100L; + + private FederatedComputeEncryptionKeyDao mEncryptionKeyDao; + private Context mContext; + + private final Clock mClock = MonotonicClock.getInstance(); + + @Before + public void setUp() { + mContext = ApplicationProvider.getApplicationContext(); + mEncryptionKeyDao = FederatedComputeEncryptionKeyDao.getInstanceForTest(mContext); + } + + @After + public void cleanUp() throws Exception { + FederatedComputeDbHelper dbHelper = FederatedComputeDbHelper.getInstanceForTest(mContext); + dbHelper.getWritableDatabase().close(); + dbHelper.getReadableDatabase().close(); + dbHelper.close(); + } + + @Test + public void testInsertEncryptionKey_success() throws Exception { + FederatedComputeEncryptionKey key1 = createRandomPublicKeyWithConstantTTL(3600); + FederatedComputeEncryptionKey key2 = createRandomPublicKeyWithConstantTTL(3600); + + assertTrue(mEncryptionKeyDao.insertEncryptionKey(key1)); + assertTrue(mEncryptionKeyDao.insertEncryptionKey(key2)); + + SQLiteDatabase db = + FederatedComputeDbHelper.getInstanceForTest(mContext).getReadableDatabase(); + + assertThat(DatabaseUtils.queryNumEntries(db, ENCRYPTION_KEY_TABLE)).isEqualTo(2); + } + + @Test + public void testInsertDuplicateEncryptionKey_success() { + FederatedComputeEncryptionKey key1 = createRandomPublicKeyWithConstantTTL(3600); + + assertTrue(mEncryptionKeyDao.insertEncryptionKey(key1)); + + FederatedComputeEncryptionKey key2 = + new FederatedComputeEncryptionKey.Builder() + .setKeyIdentifier(key1.getKeyIdentifier()) + .setPublicKey(key1.getPublicKey()) + .setKeyType(key1.getKeyType()) + .setCreationTime(key1.getCreationTime()) + .setExpiryTime(key1.getExpiryTime() + 10000L).build(); + + assertTrue(mEncryptionKeyDao.insertEncryptionKey(key2)); + + List<FederatedComputeEncryptionKey> keyList = mEncryptionKeyDao + .getLatestExpiryNKeys(2); + + assertThat(keyList.size()).isEqualTo(1); + assertThat(keyList.get(0)).isEqualTo(key2); + } + + @Test + public void testInsertNullPublicKeyFieldThrows() { + assertThrows(NullPointerException.class, () -> insertNullFieldEncryptionKey()); + } + + @Test + public void testQueryKeys_success() { + List<FederatedComputeEncryptionKey> keyList0 = + mEncryptionKeyDao.readFederatedComputeEncryptionKeysFromDatabase( + "" /* selection= */, new String[0] /* selectionArgs= */, + "" /* orderBy= */, 5); + + assertThat(keyList0.size()).isEqualTo(0); + + FederatedComputeEncryptionKey key1 = createFixedPublicKey(); + mEncryptionKeyDao.insertEncryptionKey(key1); + + List<FederatedComputeEncryptionKey> keyList1 = + mEncryptionKeyDao.readFederatedComputeEncryptionKeysFromDatabase( + "" /* selection= */, + new String[0] /* selectionArgs= */, + FederatedComputeEncryptionColumns.EXPIRY_TIME + " DESC", + 1); + + assertThat(keyList1.get(0)).isEqualTo(key1); + + // with selection args + String selection = + FederatedComputeEncryptionKeyContract.FederatedComputeEncryptionColumns + .KEY_IDENTIFIER + + " = ? "; + String[] selectionArgs = {KEY_ID}; + + List<FederatedComputeEncryptionKey> keyList2 = + mEncryptionKeyDao.readFederatedComputeEncryptionKeysFromDatabase( + selection, + selectionArgs, + FederatedComputeEncryptionKeyContract.FederatedComputeEncryptionColumns + .EXPIRY_TIME + + " DESC", + 1); + + assertThat(keyList2.size()).isEqualTo(1); + assertThat(keyList2.get(0)).isEqualTo(key1); + } + + @Test + public void findExpiryKeys_success() { + FederatedComputeEncryptionKey key1 = createRandomPublicKeyWithConstantTTL(1000000L); + FederatedComputeEncryptionKey key2 = createRandomPublicKeyWithConstantTTL(2000000L); + FederatedComputeEncryptionKey key3 = createRandomPublicKeyWithConstantTTL(3000000L); + mEncryptionKeyDao.insertEncryptionKey(key1); + mEncryptionKeyDao.insertEncryptionKey(key2); + mEncryptionKeyDao.insertEncryptionKey(key3); + + List<FederatedComputeEncryptionKey> keyList = mEncryptionKeyDao.getLatestExpiryNKeys(3); + + assertThat(keyList.size()).isEqualTo(3); + assertThat(keyList.get(0)).isEqualTo(key3); + assertThat(keyList.get(1)).isEqualTo(key2); + assertThat(keyList.get(2)).isEqualTo(key1); + } + + @Test + public void findExpiryKeysWithlimit_success() { + FederatedComputeEncryptionKey key1 = createRandomPublicKeyWithConstantTTL(1000000L); + FederatedComputeEncryptionKey key2 = createRandomPublicKeyWithConstantTTL(2000000L); + FederatedComputeEncryptionKey key3 = createRandomPublicKeyWithConstantTTL(3000000L); + mEncryptionKeyDao.insertEncryptionKey(key1); + mEncryptionKeyDao.insertEncryptionKey(key2); + mEncryptionKeyDao.insertEncryptionKey(key3); + + List<FederatedComputeEncryptionKey> keyList = mEncryptionKeyDao.getLatestExpiryNKeys(2); + + assertThat(keyList.size()).isEqualTo(2); + assertThat(keyList.get(0)).isEqualTo(key3); + assertThat(keyList.get(1)).isEqualTo(key2); + + // limit = 0 + List<FederatedComputeEncryptionKey> keyList0 = mEncryptionKeyDao.getLatestExpiryNKeys(0); + assertThat(keyList0.size()).isEqualTo(0); + } + + @Test + public void findExpiryKeys_empty_success() { + List<FederatedComputeEncryptionKey> keyList = mEncryptionKeyDao.getLatestExpiryNKeys(3); + + assertThat(keyList.size()).isEqualTo(0); + + List<FederatedComputeEncryptionKey> keyList0 = mEncryptionKeyDao.getLatestExpiryNKeys(0); + + assertThat(keyList0.size()).isEqualTo(0); + } + + @Test + public void deleteExpiredKeys_success() throws Exception { + FederatedComputeEncryptionKey key1 = createRandomPublicKeyWithConstantTTL(0); + mEncryptionKeyDao.insertEncryptionKey(key1); + + int deletedRows = mEncryptionKeyDao.deleteExpiredKeys(); + + assertThat(deletedRows).isEqualTo(1); + + // check current number of rows + List<FederatedComputeEncryptionKey> keyList = mEncryptionKeyDao.getLatestExpiryNKeys(3); + + assertThat(keyList.size()).isEqualTo(0); + } + + @Test + public void deleteNoKeys_success() { + int deletedRows = mEncryptionKeyDao.deleteExpiredKeys(); + assertThat(deletedRows).isEqualTo(0); + } + + private void insertNullFieldEncryptionKey() throws Exception { + FederatedComputeEncryptionKey key1 = + new FederatedComputeEncryptionKey.Builder() + .setKeyIdentifier(UUID.randomUUID().toString()) + .setKeyType(FederatedComputeEncryptionKey.KEY_TYPE_UNDEFINED) + .setCreationTime(mClock.currentTimeMillis()) + .setExpiryTime(mClock.currentTimeMillis() + TTL) + .build(); + + mEncryptionKeyDao.insertEncryptionKey(key1); + } + + private FederatedComputeEncryptionKey createRandomPublicKeyWithConstantTTL(long ttl) { + byte[] bytes = new byte[32]; + Random generator = new Random(); + generator.nextBytes(bytes); + return new FederatedComputeEncryptionKey.Builder() + .setKeyIdentifier(UUID.randomUUID().toString()) + .setPublicKey(new String(bytes, 0, bytes.length)) + .setKeyType(FederatedComputeEncryptionKey.KEY_TYPE_UNDEFINED) + .setCreationTime(mClock.currentTimeMillis()) + .setExpiryTime(mClock.currentTimeMillis() + ttl) + .build(); + } + + private FederatedComputeEncryptionKey createFixedPublicKey() { + return new FederatedComputeEncryptionKey.Builder() + .setKeyIdentifier(KEY_ID) + .setPublicKey(PUBLICKEY) + .setKeyType(KEY_TYPE) + .setCreationTime(NOW) + .setExpiryTime(NOW + TTL) + .build(); + } +} diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/data/FederatedComputeEncryptionKeyTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/data/FederatedComputeEncryptionKeyTest.java new file mode 100644 index 00000000..be9785cb --- /dev/null +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/data/FederatedComputeEncryptionKeyTest.java @@ -0,0 +1,81 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.services.data; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertThrows; + +import androidx.test.ext.junit.runners.AndroidJUnit4; + +import org.junit.Test; +import org.junit.runner.RunWith; + + +@RunWith(AndroidJUnit4.class) +public class FederatedComputeEncryptionKeyTest { + + private static final String KEY_ID = "0962201a-5abd-4e25-a486-2c7bd1ee1887"; + private static final String PUBLIC_KEY = "GOcMAnY4WkDYp6R3WSw8IpYK6eVe2RGZ9Z0OBb3EbjQ\\u003d"; + + private static final int KEY_TYPE = FederatedComputeEncryptionKey.KEY_TYPE_ENCRYPTION; + + private static final long NOW = 1698193647L; + + private static final long TTL = 100L; + + @Test + public void testBuilderAndEquals() { + FederatedComputeEncryptionKey key1 = + new FederatedComputeEncryptionKey.Builder( + KEY_ID, PUBLIC_KEY, KEY_TYPE, NOW, NOW + TTL) + .build(); + + FederatedComputeEncryptionKey key2 = + new FederatedComputeEncryptionKey.Builder() + .setKeyIdentifier(KEY_ID) + .setPublicKey(PUBLIC_KEY) + .setKeyType(KEY_TYPE) + .setCreationTime(NOW) + .setExpiryTime(NOW + TTL) + .build(); + + assertEquals(key1, key2); + + FederatedComputeEncryptionKey key3 = + new FederatedComputeEncryptionKey.Builder() + .setKeyIdentifier(KEY_ID) + .setPublicKey(PUBLIC_KEY) + .setKeyType(FederatedComputeEncryptionKey.KEY_TYPE_UNDEFINED) + .setCreationTime(NOW) + .setExpiryTime(NOW + TTL) + .build(); + + assertNotEquals(key1, key3); + assertNotEquals(key2, key3); + } + + @Test + public void testBuildTwiceThrows() { + FederatedComputeEncryptionKey.Builder builder = + new FederatedComputeEncryptionKey.Builder( + KEY_ID, PUBLIC_KEY, KEY_TYPE, NOW, NOW + TTL); + builder.build(); + + assertThrows(IllegalStateException.class, () -> builder.build()); + } +} diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/data/FederatedTrainingTaskDaoTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/data/FederatedTrainingTaskDaoTest.java index 16aaafaf..6daff32d 100644 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/data/FederatedTrainingTaskDaoTest.java +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/data/FederatedTrainingTaskDaoTest.java @@ -60,8 +60,8 @@ public final class FederatedTrainingTaskDaoTest { @After public void cleanUp() throws Exception { - FederatedTrainingTaskDbHelper dbHelper = - FederatedTrainingTaskDbHelper.getInstanceForTest(mContext); + FederatedComputeDbHelper dbHelper = + FederatedComputeDbHelper.getInstanceForTest(mContext); dbHelper.getWritableDatabase().close(); dbHelper.getReadableDatabase().close(); dbHelper.close(); diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/data/FederatedTrainingTaskTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/data/FederatedTrainingTaskTest.java index e0f772b2..5650b958 100644 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/data/FederatedTrainingTaskTest.java +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/data/FederatedTrainingTaskTest.java @@ -59,12 +59,12 @@ public final class FederatedTrainingTaskTest { private static final byte[] TRAINING_CONSTRAINTS = createDefaultTrainingConstraints(); private SQLiteDatabase mDatabase; - private FederatedTrainingTaskDbHelper mDbHelper; + private FederatedComputeDbHelper mDbHelper; @Before public void setUp() { Context context = ApplicationProvider.getApplicationContext(); - mDbHelper = FederatedTrainingTaskDbHelper.getInstanceForTest(context); + mDbHelper = FederatedComputeDbHelper.getInstanceForTest(context); mDatabase = mDbHelper.getWritableDatabase(); mDbHelper.resetDatabase(mDatabase); } diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/examplestore/ExampleConsumptionRecorderTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/examplestore/ExampleConsumptionRecorderTest.java index 4a03169d..fc2b31cc 100644 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/examplestore/ExampleConsumptionRecorderTest.java +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/examplestore/ExampleConsumptionRecorderTest.java @@ -38,20 +38,20 @@ public final class ExampleConsumptionRecorderTest { } @Test - public void testIncrementSameCollectionAndCriteria() { - String collection = "collection"; + public void testIncrementSameTaskNameAndCriteria() { + String taskName = "taskName"; byte[] selectionCriteria = new byte[] {10, 0, 1}; ExampleConsumptionRecorder recorder = new ExampleConsumptionRecorder(); byte[] token1 = "token1".getBytes(Charset.defaultCharset()); SingleQueryRecorder singleRecorder = - recorder.createRecorderForTracking(collection, selectionCriteria); + recorder.createRecorderForTracking(taskName, selectionCriteria); singleRecorder.incrementAndUpdateResumptionToken(token1); byte[] token2 = "token2".getBytes(Charset.defaultCharset()); singleRecorder.incrementAndUpdateResumptionToken(token2); assertThat(recorder.finishRecordingAndGet()) .containsExactly( new ExampleConsumption.Builder() - .setCollectionName(collection) + .setTaskName(taskName) .setExampleCount(2) .setSelectionCriteria(selectionCriteria) .setResumptionToken(token2) @@ -59,29 +59,29 @@ public final class ExampleConsumptionRecorderTest { } @Test - public void testIncrementDifferentCollection() { - String collection1 = "collection1"; + public void testIncrementDifferentTaskName() { + String taskName = "taskName"; byte[] criteria = new byte[] {10, 0, 1}; ExampleConsumptionRecorder recorder = new ExampleConsumptionRecorder(); byte[] token1 = "token1".getBytes(Charset.defaultCharset()); SingleQueryRecorder singleRecorder1 = - recorder.createRecorderForTracking(collection1, criteria); + recorder.createRecorderForTracking(taskName, criteria); singleRecorder1.incrementAndUpdateResumptionToken(token1); - String collection2 = "collection2"; + String taskName2 = "taskName2"; byte[] token2 = "token2".getBytes(Charset.defaultCharset()); SingleQueryRecorder singleQueryRecorder2 = - recorder.createRecorderForTracking(collection2, criteria); + recorder.createRecorderForTracking(taskName2, criteria); singleQueryRecorder2.incrementAndUpdateResumptionToken(token2); assertThat(recorder.finishRecordingAndGet()) .containsExactly( new ExampleConsumption.Builder() - .setCollectionName(collection1) + .setTaskName(taskName) .setSelectionCriteria(criteria) .setExampleCount(1) .setResumptionToken(token1) .build(), new ExampleConsumption.Builder() - .setCollectionName(collection2) + .setTaskName(taskName2) .setExampleCount(1) .setSelectionCriteria(criteria) .setResumptionToken(token2) diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/examplestore/ExampleStoreIteratorProviderImplTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/examplestore/ExampleStoreIteratorProviderImplTest.java deleted file mode 100644 index 7bc9769f..00000000 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/examplestore/ExampleStoreIteratorProviderImplTest.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Copyright (C) 2023 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.android.federatedcompute.services.examplestore; - -import static android.federatedcompute.common.ClientConstants.EXAMPLE_STORE_ACTION; - -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertThrows; -import static org.mockito.Mockito.when; - -import android.content.Context; -import android.content.Intent; -import android.federatedcompute.aidl.IExampleStoreIterator; -import android.net.Uri; - -import androidx.test.core.app.ApplicationProvider; -import androidx.test.ext.junit.runners.AndroidJUnit4; - -import com.android.federatedcompute.services.common.Flags; - -import com.google.internal.federated.plan.ExampleSelector; - -import org.junit.After; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; - -@RunWith(AndroidJUnit4.class) -public final class ExampleStoreIteratorProviderImplTest { - private static final long TIMEOUT_SECS = 5L; - private static final String EXPECTED_COLLECTION_NAME = - "/federatedcompute.examplestoretest/test_collection"; - - private ExampleStoreIteratorProviderImpl mExampleStoreIteratorProvider; - private ExampleStoreServiceProviderImpl mExampleStoreServiceProvider; - private Context mContext = ApplicationProvider.getApplicationContext(); - - private String mPackageName; - - private Intent mIntent; - @Mock private Flags mMockFlags; - - @Before - public void setUp() { - MockitoAnnotations.initMocks(this); - mPackageName = mContext.getPackageName(); - mExampleStoreServiceProvider = new ExampleStoreServiceProviderImpl(mContext, mMockFlags); - mIntent = new Intent(); - mIntent.setAction(EXAMPLE_STORE_ACTION).setPackage(mPackageName); - mIntent.setData( - new Uri.Builder().scheme("app").authority(mPackageName).path("collection").build()); - when(mMockFlags.getAppHostedExampleStoreTimeoutSecs()).thenReturn(TIMEOUT_SECS); - mExampleStoreIteratorProvider = - new ExampleStoreIteratorProviderImpl(mExampleStoreServiceProvider, mMockFlags); - } - - @After - public void cleanup() { - mExampleStoreServiceProvider.unbindService(); - } - - @Test - public void testGetExampleStoreIterator() throws Exception { - ExampleSelector exampleSelector = - ExampleSelector.newBuilder().setCollectionUri(EXPECTED_COLLECTION_NAME).build(); - IExampleStoreIterator iterator = - mExampleStoreIteratorProvider.getExampleStoreIterator( - mPackageName, exampleSelector); - assertNotNull(iterator); - } - - @Test - public void testGetExampleStoreIterator_fail() throws Exception { - ExampleSelector exampleSelector = - ExampleSelector.newBuilder().setCollectionUri("bad_collection_name").build(); - assertThrows( - IllegalStateException.class, - () -> - mExampleStoreIteratorProvider.getExampleStoreIterator( - mPackageName, exampleSelector)); - } -} diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/examplestore/ExampleStoreServiceProviderImplTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/examplestore/ExampleStoreServiceProviderImplTest.java deleted file mode 100644 index d0261ab8..00000000 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/examplestore/ExampleStoreServiceProviderImplTest.java +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright (C) 2023 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.android.federatedcompute.services.examplestore; - -import static android.federatedcompute.common.ClientConstants.EXAMPLE_STORE_ACTION; - -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; -import static org.mockito.Mockito.when; - -import android.content.Context; -import android.content.Intent; -import android.net.Uri; - -import androidx.test.core.app.ApplicationProvider; -import androidx.test.ext.junit.runners.AndroidJUnit4; - -import com.android.federatedcompute.services.common.Flags; - -import org.junit.After; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; - -@RunWith(AndroidJUnit4.class) -public final class ExampleStoreServiceProviderImplTest { - private ExampleStoreServiceProviderImpl mExampleStoreServiceProvider; - private Context mContext = ApplicationProvider.getApplicationContext(); - private static final long TIMEOUT_SECS = 5L; - - private Intent mIntent; - @Mock private Flags mMockFlags; - - @Before - public void setUp() { - MockitoAnnotations.initMocks(this); - String packageName = mContext.getPackageName(); - mExampleStoreServiceProvider = new ExampleStoreServiceProviderImpl(mContext, mMockFlags); - mIntent = new Intent(); - mIntent.setAction(EXAMPLE_STORE_ACTION).setPackage(packageName); - mIntent.setData( - new Uri.Builder().scheme("app").authority(packageName).path("collection").build()); - when(mMockFlags.getAppHostedExampleStoreTimeoutSecs()).thenReturn(TIMEOUT_SECS); - } - - @After - public void cleanup() { - mExampleStoreServiceProvider.unbindService(); - } - - @Test - public void testBindService() { - assertTrue(mExampleStoreServiceProvider.bindService(mIntent)); - } - - @Test - public void testGetExampleStoreService() throws Exception { - mExampleStoreServiceProvider.bindService(mIntent); - - assertNotNull(mExampleStoreServiceProvider.getExampleStoreService()); - } - - @Test - public void testUnbindService() throws Exception { - assertTrue(mExampleStoreServiceProvider.bindService(mIntent)); - - mExampleStoreServiceProvider.unbindService(); - } - - @Test - public void testUnbindService_serviceNonExist() { - mExampleStoreServiceProvider.unbindService(); - } -} diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/examplestore/FederatedExampleIteratorTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/examplestore/FederatedExampleIteratorTest.java index 1382647f..e96365e5 100644 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/examplestore/FederatedExampleIteratorTest.java +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/examplestore/FederatedExampleIteratorTest.java @@ -61,12 +61,11 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; @RunWith(JUnit4.class) public final class FederatedExampleIteratorTest { private static final String APP_ID = "com.foo.bar"; - private static final String FAKE_COLLECTION = "/collection1"; + private static final String FAKE_TASK_NAME = "task-name"; private static final byte[] FAKE_CRITERIA = new byte[] {10, 0, 1}; private static final byte[] RESUMPTION_TOKEN = "token1".getBytes(Charset.defaultCharset()); private static final byte[] EXAMPLE_1 = "example1".getBytes(Charset.defaultCharset()); @@ -77,19 +76,14 @@ public final class FederatedExampleIteratorTest { MoreExecutors.listeningDecorator(Executors.newSingleThreadExecutor()); private final SingleQueryRecorder mRecorder = new ExampleConsumptionRecorder() - .createRecorderForTracking(FAKE_COLLECTION, FAKE_CRITERIA); + .createRecorderForTracking(FAKE_TASK_NAME, FAKE_CRITERIA); private FederatedExampleIterator mIterator; - private AtomicReference<IExampleStoreIterator> mExampleStoreIteratorStub; - private AtomicReference<Integer> mExampleStoreIteratorError; @Mock private Flags mMockFlags; @Before public void setUp() throws Exception { MockitoAnnotations.initMocks(this); - - mExampleStoreIteratorStub = new AtomicReference<>(null); - mExampleStoreIteratorError = new AtomicReference<>(null); when(mMockFlags.getAppHostedExampleStoreTimeoutSecs()).thenReturn(TIMEOUT_SECS); } @@ -98,9 +92,7 @@ public final class FederatedExampleIteratorTest { ImmutableList<byte[]> fakeResults = ImmutableList.of(EXAMPLE_1, EXAMPLE_2); FakeExampleStoreIterator fakeIterator = new FakeExampleStoreIterator(fakeResults); - mIterator = - new FederatedExampleIterator( - fakeIterator, FAKE_COLLECTION, FAKE_CRITERIA, RESUMPTION_TOKEN, mRecorder); + mIterator = new FederatedExampleIterator(fakeIterator, RESUMPTION_TOKEN, mRecorder); // Verify the mIterator works in a typical hasNext/Next/hasNext/next/hasNext. assertThat(runInBackgroundAndWait(mIterator::hasNext)).isTrue(); @@ -119,9 +111,7 @@ public final class FederatedExampleIteratorTest { ImmutableList<byte[]> fakeResults = ImmutableList.of(EXAMPLE_1, EXAMPLE_2); FakeExampleStoreIterator fakeIterator = new FakeExampleStoreIterator(fakeResults); - mIterator = - new FederatedExampleIterator( - fakeIterator, FAKE_COLLECTION, FAKE_CRITERIA, RESUMPTION_TOKEN, mRecorder); + mIterator = new FederatedExampleIterator(fakeIterator, RESUMPTION_TOKEN, mRecorder); // Verify the mIterator works in a typical hasNext/Next/hasNext/next. assertThat(runInBackgroundAndWait(mIterator::hasNext)).isTrue(); @@ -143,9 +133,7 @@ public final class FederatedExampleIteratorTest { ImmutableList<byte[]> fakeResults = ImmutableList.of(EXAMPLE_1, EXAMPLE_2); FakeExampleStoreIterator fakeIterator = new FakeExampleStoreIterator(fakeResults); - mIterator = - new FederatedExampleIterator( - fakeIterator, FAKE_COLLECTION, FAKE_CRITERIA, RESUMPTION_TOKEN, mRecorder); + mIterator = new FederatedExampleIterator(fakeIterator, RESUMPTION_TOKEN, mRecorder); // Verify the mIterator works if only next() is called and hasNext() never called. assertThat(runInBackgroundAndWait(mIterator::next)).isEqualTo(EXAMPLE_1); @@ -161,9 +149,7 @@ public final class FederatedExampleIteratorTest { ImmutableList<byte[]> fakeResults = ImmutableList.of(EXAMPLE_1, EXAMPLE_2); FakeExampleStoreIterator fakeIterator = new FakeExampleStoreIterator(fakeResults); - mIterator = - new FederatedExampleIterator( - fakeIterator, FAKE_COLLECTION, FAKE_CRITERIA, RESUMPTION_TOKEN, mRecorder); + mIterator = new FederatedExampleIterator(fakeIterator, RESUMPTION_TOKEN, mRecorder); runInBackgroundAndWait(mIterator::close); } @@ -173,9 +159,7 @@ public final class FederatedExampleIteratorTest { ImmutableList<byte[]> fakeResults = ImmutableList.of(EXAMPLE_1, EXAMPLE_2); FakeExampleStoreIterator fakeIterator = new FakeExampleStoreIterator(fakeResults); - mIterator = - new FederatedExampleIterator( - fakeIterator, FAKE_COLLECTION, FAKE_CRITERIA, RESUMPTION_TOKEN, mRecorder); + mIterator = new FederatedExampleIterator(fakeIterator, RESUMPTION_TOKEN, mRecorder); assertThat(runInBackgroundAndWait(mIterator::hasNext)).isTrue(); runInBackgroundAndWait(mIterator::close); @@ -188,9 +172,7 @@ public final class FederatedExampleIteratorTest { ImmutableList<byte[]> fakeResults = ImmutableList.of(EXAMPLE_1, EXAMPLE_2); FakeExampleStoreIterator fakeIterator = new FakeExampleStoreIterator(fakeResults); - mIterator = - new FederatedExampleIterator( - fakeIterator, FAKE_COLLECTION, FAKE_CRITERIA, RESUMPTION_TOKEN, mRecorder); + mIterator = new FederatedExampleIterator(fakeIterator, RESUMPTION_TOKEN, mRecorder); assertThat(runInBackgroundAndWait(mIterator::next)).isEqualTo(EXAMPLE_1); assertThat(runInBackgroundAndWait(mIterator::hasNext)).isTrue(); @@ -205,9 +187,7 @@ public final class FederatedExampleIteratorTest { FakeExampleStoreIterator fakeIterator = new FakeExampleStoreIterator(fakeResults, STATUS_INTERNAL_ERROR); - mIterator = - new FederatedExampleIterator( - fakeIterator, FAKE_COLLECTION, FAKE_CRITERIA, RESUMPTION_TOKEN, mRecorder); + mIterator = new FederatedExampleIterator(fakeIterator, RESUMPTION_TOKEN, mRecorder); runInBackgroundAndWait(mIterator::next); @@ -218,7 +198,7 @@ public final class FederatedExampleIteratorTest { ErrorStatusException errorStatusException = (ErrorStatusException) exception.getCause(); assertThat(errorStatusException.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE_VALUE); assertThat(errorStatusException.getStatus().getMessage()) - .isEqualTo("OnIteratorNextFailure: /collection1 1"); + .isEqualTo("OnIteratorNextFailure: 1"); assertThat(fakeIterator.mClosed.get()).isEqualTo(1); runInBackgroundAndWait(mIterator::close); assertThat(fakeIterator.mClosed.get()).isEqualTo(1); @@ -230,9 +210,7 @@ public final class FederatedExampleIteratorTest { FakeExampleStoreIterator fakeIterator = new FakeExampleStoreIterator(fakeResults, STATUS_INTERNAL_ERROR); - mIterator = - new FederatedExampleIterator( - fakeIterator, FAKE_COLLECTION, FAKE_CRITERIA, RESUMPTION_TOKEN, mRecorder); + mIterator = new FederatedExampleIterator(fakeIterator, RESUMPTION_TOKEN, mRecorder); runInBackgroundAndWait(mIterator::next); @@ -243,7 +221,7 @@ public final class FederatedExampleIteratorTest { ErrorStatusException errorStatusException = (ErrorStatusException) exception.getCause(); assertThat(errorStatusException.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE_VALUE); assertThat(errorStatusException.getStatus().getMessage()) - .isEqualTo("OnIteratorNextFailure: /collection1 1"); + .isEqualTo("OnIteratorNextFailure: 1"); assertThat(fakeIterator.mClosed.get()).isEqualTo(1); runInBackgroundAndWait(mIterator::close); assertThat(fakeIterator.mClosed.get()).isEqualTo(1); @@ -255,9 +233,7 @@ public final class FederatedExampleIteratorTest { FakeExampleStoreIterator fakeIterator = new FakeExampleStoreIterator(fakeResults, STATUS_INTERNAL_ERROR); - mIterator = - new FederatedExampleIterator( - fakeIterator, FAKE_COLLECTION, FAKE_CRITERIA, RESUMPTION_TOKEN, mRecorder); + mIterator = new FederatedExampleIterator(fakeIterator, RESUMPTION_TOKEN, mRecorder); runInBackgroundAndWait(mIterator::close); diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/examplestore/SampleExampleStoreService.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/examplestore/SampleExampleStoreService.java index a35abf2b..0546cb32 100644 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/examplestore/SampleExampleStoreService.java +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/examplestore/SampleExampleStoreService.java @@ -16,8 +16,8 @@ package com.android.federatedcompute.services.examplestore; -import static android.federatedcompute.common.ClientConstants.EXTRA_COLLECTION_NAME; import static android.federatedcompute.common.ClientConstants.EXTRA_EXAMPLE_ITERATOR_RESULT; +import static android.federatedcompute.common.ClientConstants.EXTRA_TASK_NAME; import static android.federatedcompute.common.ClientConstants.STATUS_INTERNAL_ERROR; import android.annotation.NonNull; @@ -40,8 +40,7 @@ import java.util.List; /** A sample ExampleStoreService implementation. */ public class SampleExampleStoreService extends ExampleStoreService { - private static final String EXPECTED_COLLECTION_NAME = - "/federatedcompute.examplestoretest/test_collection"; + private static final String EXPECTED_TASK_NAME = "test_task"; private static final Example EXAMPLE_PROTO_1 = Example.newBuilder() .setFeatures( @@ -59,8 +58,8 @@ public class SampleExampleStoreService extends ExampleStoreService { @Override public void startQuery(@NonNull Bundle params, @NonNull QueryCallback callback) { - String collection = params.getString(EXTRA_COLLECTION_NAME); - if (!collection.equals(EXPECTED_COLLECTION_NAME)) { + String collection = params.getString(EXTRA_TASK_NAME); + if (!collection.equals(EXPECTED_TASK_NAME)) { callback.onStartQueryFailure(STATUS_INTERNAL_ERROR); return; } @@ -68,6 +67,11 @@ public class SampleExampleStoreService extends ExampleStoreService { new ListExampleStoreIterator(ImmutableList.of(EXAMPLE_PROTO_1))); } + @Override + protected boolean checkCallerPermission() { + return true; + } + /** * A simple {@link ExampleStoreIterator} that returns the contents of the {@link List} it's * constructed with. diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/FederatedComputeHttpResponseTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/FederatedComputeHttpResponseTest.java index 28627ee1..2497b19d 100644 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/FederatedComputeHttpResponseTest.java +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/FederatedComputeHttpResponseTest.java @@ -22,18 +22,17 @@ import static org.junit.Assert.assertThrows; import static java.nio.charset.StandardCharsets.UTF_8; -import androidx.test.ext.junit.runners.AndroidJUnit4; - import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.junit.Test; import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; import java.util.List; import java.util.Map; -@RunWith(AndroidJUnit4.class) +@RunWith(JUnit4.class) public final class FederatedComputeHttpResponseTest { @Test public void testBuildWithAllValues() { diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpClientTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpClientTest.java index a1d235c9..2674f2d3 100644 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpClientTest.java +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpClientTest.java @@ -26,13 +26,12 @@ import static org.mockito.Mockito.when; import static java.nio.charset.StandardCharsets.UTF_8; -import androidx.test.ext.junit.runners.AndroidJUnit4; - import com.android.federatedcompute.services.http.HttpClientUtil.HttpMethod; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; import org.mockito.ArgumentMatchers; import org.mockito.Mock; import org.mockito.Spy; @@ -49,7 +48,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -@RunWith(AndroidJUnit4.class) +@RunWith(JUnit4.class) public final class HttpClientTest { @Spy private HttpClient mHttpClient = new HttpClient(); @Rule public MockitoRule rule = MockitoJUnit.rule(); diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpFederatedProtocolTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpFederatedProtocolTest.java index 92f7b7fe..fb1d95c4 100644 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpFederatedProtocolTest.java +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpFederatedProtocolTest.java @@ -16,83 +16,88 @@ package com.android.federatedcompute.services.http; -import static com.android.federatedcompute.services.http.HttpClientUtil.API_KEY_HDR; import static com.android.federatedcompute.services.http.HttpClientUtil.CONTENT_LENGTH_HDR; import static com.android.federatedcompute.services.http.HttpClientUtil.CONTENT_TYPE_HDR; -import static com.android.federatedcompute.services.http.HttpClientUtil.FAKE_API_KEY; import static com.android.federatedcompute.services.http.HttpClientUtil.PROTOBUF_CONTENT_TYPE; import static com.google.common.truth.Truth.assertThat; +import static com.google.common.util.concurrent.Futures.immediateFuture; import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.when; import static java.nio.charset.StandardCharsets.UTF_8; -import androidx.test.ext.junit.runners.AndroidJUnit4; - import com.android.federatedcompute.services.http.HttpClientUtil.HttpMethod; +import com.android.federatedcompute.services.testutils.TrainingTestUtil; +import com.android.federatedcompute.services.training.util.ComputationResult; -import com.google.internal.federatedcompute.v1.ByteStreamResource; +import com.google.common.io.Files; +import com.google.intelligence.fcp.client.FLRunnerResult; +import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult; +import com.google.internal.federated.plan.ClientOnlyPlan; import com.google.internal.federatedcompute.v1.ClientVersion; -import com.google.internal.federatedcompute.v1.ForwardingInfo; import com.google.internal.federatedcompute.v1.RejectionInfo; import com.google.internal.federatedcompute.v1.Resource; -import com.google.internal.federatedcompute.v1.Resource.InlineResource; -import com.google.internal.federatedcompute.v1.ResourceCapabilities; -import com.google.internal.federatedcompute.v1.ResourceCompressionFormat; -import com.google.internal.federatedcompute.v1.StartAggregationDataUploadResponse; -import com.google.internal.federatedcompute.v1.StartTaskAssignmentRequest; -import com.google.internal.federatedcompute.v1.StartTaskAssignmentResponse; -import com.google.internal.federatedcompute.v1.TaskAssignment; -import com.google.protobuf.ByteString; +import com.google.ondevicepersonalization.federatedcompute.proto.CreateTaskAssignmentRequest; +import com.google.ondevicepersonalization.federatedcompute.proto.CreateTaskAssignmentResponse; +import com.google.ondevicepersonalization.federatedcompute.proto.ReportResultRequest; +import com.google.ondevicepersonalization.federatedcompute.proto.ReportResultRequest.Result; +import com.google.ondevicepersonalization.federatedcompute.proto.ReportResultResponse; +import com.google.ondevicepersonalization.federatedcompute.proto.TaskAssignment; +import com.google.ondevicepersonalization.federatedcompute.proto.UploadInstruction; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; +import java.io.File; import java.util.HashMap; import java.util.List; import java.util.concurrent.ExecutionException; -@RunWith(AndroidJUnit4.class) +@RunWith(JUnit4.class) public final class HttpFederatedProtocolTest { - private static final String TASK_ASSIGNMENT_TARGET_URI = "https://taskassignment.uri/"; - private static final String AGGREGATION_TARGET_URI = "https://aggregation.uri/"; + private static final String TASK_ASSIGNMENT_TARGET_URI = "https://test-server.com/"; private static final String PLAN_URI = "https://fake.uri/plan"; private static final String CHECKPOINT_URI = "https://fake.uri/checkpoint"; - private static final String BYTE_STREAM_TARGET_URI = "https://bytestream.uri/"; - private static final String SECOND_STAGE_AGGREGATION_TARGET_URI = - "https://aggregation.second.uri/"; - private static final String POPULATION_NAME = "TEST/POPULATION"; - private static final byte[] PLAN = "CLIENT_ONLY_PLAN".getBytes(UTF_8); - private static final String INIT_CHECKPOINT = "INIT_CHECKPOINT"; + private static final String START_TASK_ASSIGNMENT_URI = + "https://test-server.com/taskassignment/v1/population/test_population:create-task-assignment"; + private static final String REPORT_RESULT_URI = + "https://test-server.com/taskassignment/v1/population/test_population/task/task-id/" + + "aggregation/aggregation-id/task-assignment/assignment-id:report-result"; + private static final String UPLOAD_LOCATION_URI = "https://dataupload.uri"; + private static final String POPULATION_NAME = "test_population"; + private static final byte[] CHECKPOINT = "INIT_CHECKPOINT".getBytes(UTF_8); private static final String CLIENT_VERSION = "CLIENT_VERSION"; - private static final String CLIENT_SESSION_ID = "CLIENT_SESSION_ID"; - private static final String AGGREGATION_SESSION_ID = "AGGREGATION_SESSION_ID"; - private static final String AUTHORIZATION_TOKEN = "AUTHORIZATION_TOKEN"; - private static final String RESOURCE_NAME = "CHECKPOINT_RESOURCE"; - private static final String CLIENT_TOKEN = "CLIENT_TOKEN"; - private static final byte[] COMPUTATION_RESULT = "COMPUTATION_RESULT".getBytes(UTF_8); - - private static final StartTaskAssignmentRequest START_TASK_ASSIGNMENT_REQUEST = - StartTaskAssignmentRequest.newBuilder() + private static final String TASK_ID = "task-id"; + private static final String ASSIGNMENT_ID = "assignment-id"; + private static final String AGGREGATION_ID = "aggregation-id"; + private static final String OCTET_STREAM = "application/octet-stream"; + private static final FLRunnerResult FL_RUNNER_SUCCESS_RESULT = + FLRunnerResult.newBuilder().setContributionResult(ContributionResult.SUCCESS).build(); + private static final FLRunnerResult FL_RUNNER_FAIL_RESULT = + FLRunnerResult.newBuilder().setContributionResult(ContributionResult.FAIL).build(); + private static final FederatedComputeHttpResponse CHECKPOINT_HTTP_RESPONSE = + new FederatedComputeHttpResponse.Builder() + .setStatusCode(200) + .setPayload(CHECKPOINT) + .build(); + + private static final CreateTaskAssignmentRequest START_TASK_ASSIGNMENT_REQUEST = + CreateTaskAssignmentRequest.newBuilder() .setClientVersion(ClientVersion.newBuilder().setVersionCode(CLIENT_VERSION)) - .setPopulationName(POPULATION_NAME) - .setResourceCapabilities( - ResourceCapabilities.newBuilder() - .addSupportedCompressionFormats( - ResourceCompressionFormat - .RESOURCE_COMPRESSION_FORMAT_GZIP) - .build()) .build(); + private static final FederatedComputeHttpResponse SUCCESS_EMPTY_HTTP_RESPONSE = new FederatedComputeHttpResponse.Builder().setStatusCode(200).build(); @@ -113,19 +118,8 @@ public final class HttpFederatedProtocolTest { } @Test - public void testTaskAssignedPlanDataFetchSuccess() throws Exception { - FederatedComputeHttpResponse startTaskAssignmentResponse = - createStartTaskAssignmentHttpResponse(); - FederatedComputeHttpResponse planHttpResponse = - new FederatedComputeHttpResponse.Builder() - .setStatusCode(200) - .setPayload(PLAN) - .build(); - // The workflow is start task assignment and download plan. The checkpoint is defined as - // inline resource. - when(mMockHttpClient.performRequest(mHttpRequestCaptor.capture())) - .thenReturn(startTaskAssignmentResponse) - .thenReturn(planHttpResponse); + public void testIssueCheckinSuccess() throws Exception { + setUpHttpFederatedProtocol(); mHttpFederatedProtocol.issueCheckin().get(); @@ -133,29 +127,22 @@ public final class HttpFederatedProtocolTest { // Verify task assignment request. FederatedComputeHttpRequest actualStartTaskAssignmentRequest = actualHttpRequests.get(0); - assertThat(actualStartTaskAssignmentRequest.getUri()) - .isEqualTo( - "https://taskassignment.uri/v1/populations/TEST/POPULATION/taskassignments:start?%24alt=proto"); + assertThat(actualStartTaskAssignmentRequest.getUri()).isEqualTo(START_TASK_ASSIGNMENT_URI); assertThat(actualStartTaskAssignmentRequest.getBody()) .isEqualTo(START_TASK_ASSIGNMENT_REQUEST.toByteArray()); assertThat(actualStartTaskAssignmentRequest.getHttpMethod()).isEqualTo(HttpMethod.POST); HashMap<String, String> expectedHeaders = new HashMap<>(); - expectedHeaders.put(API_KEY_HDR, FAKE_API_KEY); - expectedHeaders.put(CONTENT_LENGTH_HDR, String.valueOf(40)); + expectedHeaders.put(CONTENT_LENGTH_HDR, String.valueOf(18)); expectedHeaders.put(CONTENT_TYPE_HDR, PROTOBUF_CONTENT_TYPE); assertThat(actualStartTaskAssignmentRequest.getExtraHeaders()) .containsExactlyEntriesIn(expectedHeaders); - - FederatedComputeHttpRequest actualPlanHttpRequest = actualHttpRequests.get(1); - assertThat(actualPlanHttpRequest.getUri()).isEqualTo(PLAN_URI); - assertThat(actualPlanHttpRequest.getHttpMethod()).isEqualTo(HttpMethod.GET); } @Test - public void testCheckinFailsFromHttp() throws Exception { + public void testCreateTaskAssignmentFailed() throws Exception { FederatedComputeHttpResponse httpResponse = new FederatedComputeHttpResponse.Builder().setStatusCode(404).build(); - when(mMockHttpClient.performRequest(any())).thenReturn(httpResponse); + when(mMockHttpClient.performRequestAsync(any())).thenReturn(immediateFuture(httpResponse)); ExecutionException exception = assertThrows( @@ -165,21 +152,21 @@ public final class HttpFederatedProtocolTest { assertThat(exception.getCause()).isInstanceOf(IllegalStateException.class); assertThat(exception.getCause()) .hasMessageThat() - .isEqualTo("start task assignment failed: 404"); + .isEqualTo("Start task assignment failed: 404"); } @Test - public void testCheckinRejection() throws Exception { - StartTaskAssignmentResponse startTaskAssignmentResponse = - StartTaskAssignmentResponse.newBuilder() + public void testCreateTaskAssignmentRejection() throws Exception { + CreateTaskAssignmentResponse createTaskAssignmentResponse = + CreateTaskAssignmentResponse.newBuilder() .setRejectionInfo(RejectionInfo.getDefaultInstance()) .build(); FederatedComputeHttpResponse httpResponse = new FederatedComputeHttpResponse.Builder() .setStatusCode(200) - .setPayload(startTaskAssignmentResponse.toByteArray()) + .setPayload(createTaskAssignmentResponse.toByteArray()) .build(); - when(mMockHttpClient.performRequest(any())).thenReturn(httpResponse); + when(mMockHttpClient.performRequestAsync(any())).thenReturn(immediateFuture(httpResponse)); ExecutionException exception = assertThrows( @@ -191,16 +178,19 @@ public final class HttpFederatedProtocolTest { } @Test - public void testTaskAssignedPlanDataFetchFailed() throws Exception { - FederatedComputeHttpResponse startTaskAssignmentResponse = - createStartTaskAssignmentHttpResponse(); + public void testTaskAssignmentSuccessPlanFetchFailed() throws Exception { FederatedComputeHttpResponse planHttpResponse = new FederatedComputeHttpResponse.Builder().setStatusCode(404).build(); - // The workflow is start task assignment and download plan. The checkpoint is defined as - // inline resource. - when(mMockHttpClient.performRequest(any())) - .thenReturn(startTaskAssignmentResponse) - .thenReturn(planHttpResponse); + // The workflow: start task assignment success, download plan failed and download + // checkpoint success. + setUpHttpFederatedProtocol( + createStartTaskAssignmentHttpResponse(), + planHttpResponse, + CHECKPOINT_HTTP_RESPONSE, + /** reportResultHttpResponse= */ + null, + /** uploadResultHttpResponse= */ + null); ExecutionException exception = assertThrows( @@ -208,285 +198,231 @@ public final class HttpFederatedProtocolTest { () -> mHttpFederatedProtocol.issueCheckin().get()); assertThat(exception).hasCauseThat().isInstanceOf(IllegalStateException.class); - assertThat(exception.getCause()).hasMessageThat().isEqualTo("plan fetch failed: 404"); + assertThat(exception.getCause()).hasMessageThat().isEqualTo("Fetch plan failed: 404"); } @Test - public void testTaskAssignedCheckpointDataFetchFailed() throws Exception { - StartTaskAssignmentResponse taskAssignmentResponse = - createStartTaskAssignmentResponse( - Resource.newBuilder().setUri(PLAN_URI).build(), - Resource.newBuilder().setUri(CHECKPOINT_URI).build()); - FederatedComputeHttpResponse startTaskAssignmentResponse = - new FederatedComputeHttpResponse.Builder() - .setStatusCode(200) - .setPayload(taskAssignmentResponse.toByteArray()) - .build(); - FederatedComputeHttpResponse planHttpResponse = - new FederatedComputeHttpResponse.Builder() - .setStatusCode(200) - .setPayload(PLAN) - .build(); + public void testTaskAssignmentSuccessCheckpointDataFetchFailed() throws Exception { FederatedComputeHttpResponse checkpointHttpResponse = new FederatedComputeHttpResponse.Builder().setStatusCode(404).build(); + // The workflow: start task assignment success, download plan success and download // checkpoint failed. - when(mMockHttpClient.performRequest(any())) - .thenReturn(startTaskAssignmentResponse) - .thenReturn(planHttpResponse) - .thenReturn(checkpointHttpResponse); + setUpHttpFederatedProtocol( + createStartTaskAssignmentHttpResponse(), + createPlanHttpResponse(), + checkpointHttpResponse, + /** reportResultHttpResponse= */ + null, + /** uploadResultHttpResponse= */ + null); ExecutionException exception = assertThrows( ExecutionException.class, () -> mHttpFederatedProtocol.issueCheckin().get()); - System.out.println(mHttpRequestCaptor.getAllValues()); assertThat(exception).hasCauseThat().isInstanceOf(IllegalStateException.class); } @Test - public void testReportViaSimpleAggregation() throws Exception { - FederatedComputeHttpResponse startTaskAssignmentResponse = - createStartTaskAssignmentHttpResponse(); - FederatedComputeHttpResponse planHttpResponse = - new FederatedComputeHttpResponse.Builder() - .setStatusCode(200) - .setPayload(PLAN) - .build(); - StartAggregationDataUploadResponse startAggregationDataUploadResponse = - StartAggregationDataUploadResponse.newBuilder() - .setAggregationProtocolForwardingInfo( - ForwardingInfo.newBuilder() - .setTargetUriPrefix(SECOND_STAGE_AGGREGATION_TARGET_URI) - .build()) - .setResource( - ByteStreamResource.newBuilder() - .setResourceName(RESOURCE_NAME) - .setDataUploadForwardingInfo( - ForwardingInfo.newBuilder() - .setTargetUriPrefix(BYTE_STREAM_TARGET_URI)) - .build()) - .setClientToken(CLIENT_TOKEN) - .build(); - FederatedComputeHttpResponse startAggregationDataUploadHttpResponse = - new FederatedComputeHttpResponse.Builder() - .setStatusCode(200) - .setPayload(startAggregationDataUploadResponse.toByteArray()) - .build(); + public void testReportFailedTrainingResult_returnSuccess() throws Exception { + ComputationResult computationResult = + new ComputationResult(createOutputCheckpointFile(), FL_RUNNER_FAIL_RESULT, null); + + setUpHttpFederatedProtocol(); + // Setup task id, aggregation id for report result. + mHttpFederatedProtocol.issueCheckin().get(); + + mHttpFederatedProtocol.reportResult(computationResult).get(); + + // Verify ReportResult request. + List<FederatedComputeHttpRequest> actualHttpRequests = mHttpRequestCaptor.getAllValues(); + assertThat(actualHttpRequests).hasSize(4); + FederatedComputeHttpRequest acutalReportResultRequest = actualHttpRequests.get(3); + ReportResultRequest reportResultRequest = + ReportResultRequest.newBuilder().setResult(Result.FAILED).build(); + assertThat(acutalReportResultRequest.getBody()) + .isEqualTo(reportResultRequest.toByteArray()); + } - when(mMockHttpClient.performRequest(mHttpRequestCaptor.capture())) - .thenReturn(startTaskAssignmentResponse) - .thenReturn(planHttpResponse) - .thenReturn(startAggregationDataUploadHttpResponse) - .thenReturn(SUCCESS_EMPTY_HTTP_RESPONSE) - .thenReturn(SUCCESS_EMPTY_HTTP_RESPONSE); + @Test + public void testReportAndUploadResultSuccess() throws Exception { + ComputationResult computationResult = + new ComputationResult(createOutputCheckpointFile(), FL_RUNNER_SUCCESS_RESULT, null); + setUpHttpFederatedProtocol(); + // Setup task id, aggregation id for report result. mHttpFederatedProtocol.issueCheckin().get(); - mHttpFederatedProtocol.reportViaSimpleAggregation(COMPUTATION_RESULT).get(); - // Verify start aggregation request. + mHttpFederatedProtocol.reportResult(computationResult).get(); + + // Verify ReportResult request. List<FederatedComputeHttpRequest> actualHttpRequests = mHttpRequestCaptor.getAllValues(); - FederatedComputeHttpRequest acutalStartAggregationRequest = actualHttpRequests.get(2); - assertThat(acutalStartAggregationRequest.getUri()) - .isEqualTo( - "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto"); - assertThat(acutalStartAggregationRequest.getHttpMethod()).isEqualTo(HttpMethod.POST); + FederatedComputeHttpRequest acutalReportResultRequest = actualHttpRequests.get(3); + assertThat(acutalReportResultRequest.getUri()).isEqualTo(REPORT_RESULT_URI); + assertThat(acutalReportResultRequest.getHttpMethod()).isEqualTo(HttpMethod.PUT); + ReportResultRequest reportResultRequest = + ReportResultRequest.newBuilder().setResult(Result.COMPLETED).build(); + assertThat(acutalReportResultRequest.getBody()) + .isEqualTo(reportResultRequest.toByteArray()); HashMap<String, String> expectedHeaders = new HashMap<>(); - expectedHeaders.put(API_KEY_HDR, FAKE_API_KEY); - expectedHeaders.put(CONTENT_LENGTH_HDR, String.valueOf(45)); + expectedHeaders.put(CONTENT_LENGTH_HDR, String.valueOf(2)); expectedHeaders.put(CONTENT_TYPE_HDR, PROTOBUF_CONTENT_TYPE); - assertThat(acutalStartAggregationRequest.getExtraHeaders()).isEqualTo(expectedHeaders); + assertThat(acutalReportResultRequest.getExtraHeaders()).isEqualTo(expectedHeaders); // Verify upload data request. - FederatedComputeHttpRequest actualDataUploadRequest = actualHttpRequests.get(3); - assertThat(actualDataUploadRequest.getUri()) - .isEqualTo( - "https://bytestream.uri/upload/v1/media/CHECKPOINT_RESOURCE?upload_protocol=raw"); - assertThat(acutalStartAggregationRequest.getHttpMethod()).isEqualTo(HttpMethod.POST); + FederatedComputeHttpRequest actualDataUploadRequest = actualHttpRequests.get(4); + assertThat(actualDataUploadRequest.getUri()).isEqualTo(UPLOAD_LOCATION_URI); + assertThat(acutalReportResultRequest.getHttpMethod()).isEqualTo(HttpMethod.PUT); expectedHeaders = new HashMap<>(); - expectedHeaders.put(API_KEY_HDR, FAKE_API_KEY); - expectedHeaders.put(CONTENT_LENGTH_HDR, String.valueOf(18)); + expectedHeaders.put(CONTENT_LENGTH_HDR, String.valueOf(17)); + expectedHeaders.put(CONTENT_TYPE_HDR, OCTET_STREAM); assertThat(actualDataUploadRequest.getExtraHeaders()).isEqualTo(expectedHeaders); - - // Verify submit aggregation report request. - FederatedComputeHttpRequest actualSubmitAggregationReportRequest = - actualHttpRequests.get(4); - assertThat(actualSubmitAggregationReportRequest.getUri()) - .isEqualTo( - "https://aggregation.second.uri/v1/aggregations/AGGREGATION_SESSION_ID/clients/CLIENT_TOKEN:submit?%24alt=proto"); - assertThat(actualSubmitAggregationReportRequest.getHttpMethod()).isEqualTo(HttpMethod.POST); - expectedHeaders = new HashMap<>(); - expectedHeaders.put(API_KEY_HDR, FAKE_API_KEY); - expectedHeaders.put(CONTENT_LENGTH_HDR, String.valueOf(21)); - expectedHeaders.put(CONTENT_TYPE_HDR, PROTOBUF_CONTENT_TYPE); - assertThat(actualSubmitAggregationReportRequest.getExtraHeaders()) - .isEqualTo(expectedHeaders); } @Test - public void testReportCompleteStartAggregationFailed() throws Exception { - FederatedComputeHttpResponse startTaskAssignmentResponse = - createStartTaskAssignmentHttpResponse(); - FederatedComputeHttpResponse planHttpResponse = - new FederatedComputeHttpResponse.Builder() - .setStatusCode(200) - .setPayload(PLAN) - .build(); - - when(mMockHttpClient.performRequest(any())) - .thenReturn(startTaskAssignmentResponse) - .thenReturn(planHttpResponse) - .thenReturn(new FederatedComputeHttpResponse.Builder().setStatusCode(503).build()); + public void testReportResultFailed() throws Exception { + FederatedComputeHttpResponse reportResultHttpResponse = + new FederatedComputeHttpResponse.Builder().setStatusCode(503).build(); + ComputationResult computationResult = + new ComputationResult(createOutputCheckpointFile(), FL_RUNNER_SUCCESS_RESULT, null); + + setUpHttpFederatedProtocol( + createStartTaskAssignmentHttpResponse(), + createPlanHttpResponse(), + CHECKPOINT_HTTP_RESPONSE, + reportResultHttpResponse, + null); mHttpFederatedProtocol.issueCheckin().get(); ExecutionException exception = assertThrows( ExecutionException.class, - () -> - mHttpFederatedProtocol - .reportViaSimpleAggregation(COMPUTATION_RESULT) - .get()); + () -> mHttpFederatedProtocol.reportResult(computationResult).get()); assertThat(exception.getCause()).isInstanceOf(IllegalStateException.class); - assertThat(exception.getCause()) - .hasMessageThat() - .isEqualTo("start data upload failed: 503"); + assertThat(exception.getCause()).hasMessageThat().isEqualTo("ReportResult failed: 503"); } @Test - public void testReportCompleteUploadFailed() throws Exception { - FederatedComputeHttpResponse startTaskAssignmentResponse = - createStartTaskAssignmentHttpResponse(); - FederatedComputeHttpResponse planHttpResponse = - new FederatedComputeHttpResponse.Builder() - .setStatusCode(200) - .setPayload(PLAN) - .build(); - StartAggregationDataUploadResponse startAggregationDataUploadResponse = - StartAggregationDataUploadResponse.newBuilder() - .setAggregationProtocolForwardingInfo( - ForwardingInfo.newBuilder() - .setTargetUriPrefix(SECOND_STAGE_AGGREGATION_TARGET_URI) - .build()) - .setResource( - ByteStreamResource.newBuilder() - .setResourceName(RESOURCE_NAME) - .setDataUploadForwardingInfo( - ForwardingInfo.newBuilder() - .setTargetUriPrefix(BYTE_STREAM_TARGET_URI)) - .build()) - .setClientToken(CLIENT_TOKEN) - .build(); - FederatedComputeHttpResponse startAggregationDataUploadHttpResponse = - new FederatedComputeHttpResponse.Builder() - .setStatusCode(200) - .setPayload(startAggregationDataUploadResponse.toByteArray()) - .build(); - - when(mMockHttpClient.performRequest(any())) - .thenReturn(startTaskAssignmentResponse) - .thenReturn(planHttpResponse) - .thenReturn(startAggregationDataUploadHttpResponse) - .thenReturn(new FederatedComputeHttpResponse.Builder().setStatusCode(503).build()); + public void testReportResultSuccessUploadFailed() throws Exception { + FederatedComputeHttpResponse uploadResultHttpResponse = + new FederatedComputeHttpResponse.Builder().setStatusCode(503).build(); + ComputationResult computationResult = + new ComputationResult(createOutputCheckpointFile(), FL_RUNNER_SUCCESS_RESULT, null); + + setUpHttpFederatedProtocol( + createStartTaskAssignmentHttpResponse(), + createPlanHttpResponse(), + CHECKPOINT_HTTP_RESPONSE, + createReportResultHttpResponse(), + uploadResultHttpResponse); mHttpFederatedProtocol.issueCheckin().get(); ExecutionException exception = assertThrows( ExecutionException.class, - () -> - mHttpFederatedProtocol - .reportViaSimpleAggregation(COMPUTATION_RESULT) - .get()); + () -> mHttpFederatedProtocol.reportResult(computationResult).get()); assertThat(exception).hasCauseThat().isInstanceOf(IllegalStateException.class); - assertThat(exception.getCause()) - .hasMessageThat() - .isEqualTo("upload failed: 503 CHECKPOINT_RESOURCE"); + assertThat(exception.getCause()).hasMessageThat().isEqualTo("Upload result failed: 503"); } - @Test - public void testReportCompleteSubmitAggregationFailed() throws Exception { - FederatedComputeHttpResponse startTaskAssignmentResponse = - createStartTaskAssignmentHttpResponse(); - FederatedComputeHttpResponse planHttpResponse = - new FederatedComputeHttpResponse.Builder() - .setStatusCode(200) - .setPayload(PLAN) - .build(); - StartAggregationDataUploadResponse startAggregationDataUploadResponse = - StartAggregationDataUploadResponse.newBuilder() - .setAggregationProtocolForwardingInfo( - ForwardingInfo.newBuilder() - .setTargetUriPrefix(SECOND_STAGE_AGGREGATION_TARGET_URI) - .build()) - .setResource( - ByteStreamResource.newBuilder() - .setResourceName(RESOURCE_NAME) - .setDataUploadForwardingInfo( - ForwardingInfo.newBuilder() - .setTargetUriPrefix(BYTE_STREAM_TARGET_URI)) - .build()) - .setClientToken(CLIENT_TOKEN) - .build(); - FederatedComputeHttpResponse startAggregationDataUploadHttpResponse = + private String createOutputCheckpointFile() throws Exception { + File outputCheckpointFile = File.createTempFile("output", ".ckp"); + Files.write("output checkpoint".getBytes(), outputCheckpointFile); + outputCheckpointFile.deleteOnExit(); + return outputCheckpointFile.getAbsolutePath(); + } + + private FederatedComputeHttpResponse createPlanHttpResponse() { + ClientOnlyPlan clientOnlyPlan = TrainingTestUtil.createFederatedAnalyticClientPlan(); + return new FederatedComputeHttpResponse.Builder() + .setStatusCode(200) + .setPayload(clientOnlyPlan.toByteArray()) + .build(); + } + + private void setUpHttpFederatedProtocol() throws Exception { + FederatedComputeHttpResponse checkpointHttpResponse = new FederatedComputeHttpResponse.Builder() .setStatusCode(200) - .setPayload(startAggregationDataUploadResponse.toByteArray()) + .setPayload(CHECKPOINT) .build(); + setUpHttpFederatedProtocol( + createStartTaskAssignmentHttpResponse(), + createPlanHttpResponse(), + checkpointHttpResponse, + createReportResultHttpResponse(), + SUCCESS_EMPTY_HTTP_RESPONSE); + } - when(mMockHttpClient.performRequest(any())) - .thenReturn(startTaskAssignmentResponse) - .thenReturn(planHttpResponse) - .thenReturn(startAggregationDataUploadHttpResponse) - .thenReturn(SUCCESS_EMPTY_HTTP_RESPONSE) - .thenReturn(new FederatedComputeHttpResponse.Builder().setStatusCode(503).build()); - - mHttpFederatedProtocol.issueCheckin().get(); - ExecutionException exception = - assertThrows( - ExecutionException.class, - () -> - mHttpFederatedProtocol - .reportViaSimpleAggregation(COMPUTATION_RESULT) - .get()); + private void setUpHttpFederatedProtocol( + FederatedComputeHttpResponse createTaskAssignmentResponse, + FederatedComputeHttpResponse planHttpResponse, + FederatedComputeHttpResponse checkpointHttpResponse, + FederatedComputeHttpResponse reportResultHttpResponse, + FederatedComputeHttpResponse uploadResultHttpResponse) + throws Exception { + doAnswer( + invocation -> { + FederatedComputeHttpRequest httpRequest = invocation.getArgument(0); + String uri = httpRequest.getUri(); + if (uri.equals(PLAN_URI)) { + return immediateFuture(planHttpResponse); + } else if (uri.equals(CHECKPOINT_URI)) { + return immediateFuture(checkpointHttpResponse); + } else if (uri.equals(START_TASK_ASSIGNMENT_URI)) { + return immediateFuture(createTaskAssignmentResponse); + } else if (uri.equals(REPORT_RESULT_URI)) { + return immediateFuture(reportResultHttpResponse); + } else if (uri.equals(UPLOAD_LOCATION_URI)) { + return immediateFuture(uploadResultHttpResponse); + } + return immediateFuture(SUCCESS_EMPTY_HTTP_RESPONSE); + }) + .when(mMockHttpClient) + .performRequestAsync(mHttpRequestCaptor.capture()); + } - assertThat(exception).hasCauseThat().isInstanceOf(IllegalStateException.class); - assertThat(exception.getCause()) - .hasMessageThat() - .isEqualTo("submit aggregation result failed: 503 CHECKPOINT_RESOURCE"); + private FederatedComputeHttpResponse createReportResultHttpResponse() throws Exception { + UploadInstruction.Builder uploadInstruction = + UploadInstruction.newBuilder().setUploadLocation(UPLOAD_LOCATION_URI); + uploadInstruction.putExtraRequestHeaders(CONTENT_TYPE_HDR, OCTET_STREAM); + ReportResultResponse reportResultResponse = + ReportResultResponse.newBuilder() + .setUploadInstruction(uploadInstruction.build()) + .build(); + return new FederatedComputeHttpResponse.Builder() + .setStatusCode(200) + .setPayload(reportResultResponse.toByteArray()) + .build(); } private FederatedComputeHttpResponse createStartTaskAssignmentHttpResponse() throws Exception { - StartTaskAssignmentResponse startTaskAssignmentResponse = - createStartTaskAssignmentResponse( + CreateTaskAssignmentResponse createTaskAssignmentResponse = + createCreateTaskAssignmentResponse( Resource.newBuilder().setUri(PLAN_URI).build(), - Resource.newBuilder() - .setInlineResource( - InlineResource.newBuilder() - .setData(ByteString.copyFromUtf8(INIT_CHECKPOINT)) - .build()) - .build()); + Resource.newBuilder().setUri(CHECKPOINT_URI).build()); return new FederatedComputeHttpResponse.Builder() .setStatusCode(200) - .setPayload(startTaskAssignmentResponse.toByteArray()) + .setPayload(createTaskAssignmentResponse.toByteArray()) .build(); } - private StartTaskAssignmentResponse createStartTaskAssignmentResponse( + private CreateTaskAssignmentResponse createCreateTaskAssignmentResponse( Resource plan, Resource checkpoint) { - ForwardingInfo forwardingInfo = - ForwardingInfo.newBuilder().setTargetUriPrefix(AGGREGATION_TARGET_URI).build(); TaskAssignment taskAssignment = TaskAssignment.newBuilder() - .setSessionId(CLIENT_SESSION_ID) - .setAggregationId(AGGREGATION_SESSION_ID) - .setAuthorizationToken(AUTHORIZATION_TOKEN) + .setPopulationName(POPULATION_NAME) + .setAggregationId(AGGREGATION_ID) + .setTaskId(TASK_ID) + .setAssignmentId(ASSIGNMENT_ID) .setPlan(plan) .setInitCheckpoint(checkpoint) - .setAggregationDataForwardingInfo(forwardingInfo) .build(); - return StartTaskAssignmentResponse.newBuilder().setTaskAssignment(taskAssignment).build(); + return CreateTaskAssignmentResponse.newBuilder().setTaskAssignment(taskAssignment).build(); } } diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/ProtocolRequestCreatorTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/ProtocolRequestCreatorTest.java index e69d15bf..1fa86e40 100644 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/ProtocolRequestCreatorTest.java +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/ProtocolRequestCreatorTest.java @@ -16,7 +16,6 @@ package com.android.federatedcompute.services.http; -import static com.android.federatedcompute.services.http.HttpClientUtil.API_KEY_HDR; import static com.android.federatedcompute.services.http.HttpClientUtil.CONTENT_LENGTH_HDR; import static com.android.federatedcompute.services.http.HttpClientUtil.CONTENT_TYPE_HDR; import static com.android.federatedcompute.services.http.HttpClientUtil.PROTOBUF_CONTENT_TYPE; @@ -25,45 +24,35 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; -import androidx.test.ext.junit.runners.AndroidJUnit4; - import com.android.federatedcompute.services.http.HttpClientUtil.HttpMethod; import com.google.internal.federatedcompute.v1.ForwardingInfo; import org.junit.Test; import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; import java.util.HashMap; -@RunWith(AndroidJUnit4.class) +@RunWith(JUnit4.class) public final class ProtocolRequestCreatorTest { private static final String REQUEST_BASE_URI = "https://initial.uri"; - private static final String API_KEY = "apiKey"; private static final byte[] REQUEST_BODY = "expectedBody".getBytes(); private static final String AGGREGATION_TARGET_URI = "https://aggregation.uri/"; @Test public void testCreateProtobufEncodedRequest() { ProtocolRequestCreator requestCreator = - new ProtocolRequestCreator( - REQUEST_BASE_URI, - API_KEY, - new HashMap<String, String>(), - /* useCompression= */ false); + new ProtocolRequestCreator(REQUEST_BASE_URI, new HashMap<String, String>(), false); FederatedComputeHttpRequest request = requestCreator.createProtoRequest( - "/v1/request", - HttpMethod.POST, - REQUEST_BODY, - /* isProtobufEncoded= */ true); + "/v1/request", HttpMethod.POST, REQUEST_BODY, true); - assertThat(request.getUri()).isEqualTo("https://initial.uri/v1/request?%24alt=proto"); + assertThat(request.getUri()).isEqualTo("https://initial.uri/v1/request"); assertThat(request.getHttpMethod()).isEqualTo(HttpMethod.POST); assertThat(request.getBody()).isEqualTo(REQUEST_BODY); HashMap<String, String> expectedHeaders = new HashMap<String, String>(); - expectedHeaders.put(API_KEY_HDR, API_KEY); expectedHeaders.put(CONTENT_LENGTH_HDR, String.valueOf(12)); expectedHeaders.put(CONTENT_TYPE_HDR, PROTOBUF_CONTENT_TYPE); assertThat(request.getExtraHeaders()).isEqualTo(expectedHeaders); @@ -76,9 +65,7 @@ public final class ProtocolRequestCreatorTest { IllegalArgumentException.class, () -> ProtocolRequestCreator.create( - API_KEY, - ForwardingInfo.getDefaultInstance(), - /* useCompression= */ false)); + ForwardingInfo.getDefaultInstance(), false)); assertThat(exception) .hasMessageThat() @@ -88,21 +75,14 @@ public final class ProtocolRequestCreatorTest { @Test public void testCreateProtocolRequestInvalidSuffix() { ProtocolRequestCreator requestCreator = - new ProtocolRequestCreator( - REQUEST_BASE_URI, - API_KEY, - new HashMap<String, String>(), - /* useCompression= */ false); + new ProtocolRequestCreator(REQUEST_BASE_URI, new HashMap<String, String>(), false); IllegalArgumentException exception = assertThrows( IllegalArgumentException.class, () -> requestCreator.createProtoRequest( - "v1/request", - HttpMethod.POST, - REQUEST_BODY, - /* isProtobufEncoded= */ true)); + "v1/request", HttpMethod.POST, REQUEST_BODY, false)); assertThat(exception) .hasMessageThat() @@ -110,45 +90,34 @@ public final class ProtocolRequestCreatorTest { } @Test - public void testCreateProtoRequest() { + public void testCreateProtocolRequestWithForwardingInfo() { + ForwardingInfo forwardingInfo = + ForwardingInfo.newBuilder().setTargetUriPrefix(AGGREGATION_TARGET_URI).build(); ProtocolRequestCreator requestCreator = - new ProtocolRequestCreator( - REQUEST_BASE_URI, - API_KEY, - new HashMap<String, String>(), - /* useCompression= */ false); + ProtocolRequestCreator.create(forwardingInfo, false); FederatedComputeHttpRequest request = requestCreator.createProtoRequest( - "/v1/request", - HttpMethod.POST, - REQUEST_BODY, - /* isProtobufEncoded= */ true); + "/v1/request", HttpMethod.POST, REQUEST_BODY, false); - assertThat(request.getUri()).isEqualTo("https://initial.uri/v1/request?%24alt=proto"); - assertThat(request.getHttpMethod()).isEqualTo(HttpMethod.POST); - assertThat(request.getBody()).isEqualTo(REQUEST_BODY); - HashMap<String, String> expectedHeaders = new HashMap<String, String>(); - expectedHeaders.put(API_KEY_HDR, API_KEY); - expectedHeaders.put(CONTENT_LENGTH_HDR, String.valueOf(12)); - expectedHeaders.put(CONTENT_TYPE_HDR, PROTOBUF_CONTENT_TYPE); - assertThat(request.getExtraHeaders()).isEqualTo(expectedHeaders); + assertThat(request.getUri()).isEqualTo("https://aggregation.uri/v1/request"); } @Test - public void testCreateProtocolRequestWithForwardingInfo() { - ForwardingInfo forwardingInfo = - ForwardingInfo.newBuilder().setTargetUriPrefix(AGGREGATION_TARGET_URI).build(); + public void testCreateProtoRequest() { ProtocolRequestCreator requestCreator = - ProtocolRequestCreator.create(API_KEY, forwardingInfo, /* useCompression= */ false); + new ProtocolRequestCreator(REQUEST_BASE_URI, new HashMap<String, String>(), false); FederatedComputeHttpRequest request = requestCreator.createProtoRequest( - "/v1/request", - HttpMethod.POST, - REQUEST_BODY, - /* isProtobufEncoded= */ true); + "/v1/request", HttpMethod.POST, REQUEST_BODY, true); - assertThat(request.getUri()).isEqualTo("https://aggregation.uri/v1/request?%24alt=proto"); + assertThat(request.getUri()).isEqualTo("https://initial.uri/v1/request"); + assertThat(request.getHttpMethod()).isEqualTo(HttpMethod.POST); + assertThat(request.getBody()).isEqualTo(REQUEST_BODY); + HashMap<String, String> expectedHeaders = new HashMap<String, String>(); + expectedHeaders.put(CONTENT_LENGTH_HDR, String.valueOf(12)); + expectedHeaders.put(CONTENT_TYPE_HDR, PROTOBUF_CONTENT_TYPE); + assertThat(request.getExtraHeaders()).isEqualTo(expectedHeaders); } } diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/scheduling/FederatedComputeJobManagerTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/scheduling/FederatedComputeJobManagerTest.java index 8547e631..2fc8188e 100644 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/scheduling/FederatedComputeJobManagerTest.java +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/scheduling/FederatedComputeJobManagerTest.java @@ -16,6 +16,8 @@ package com.android.federatedcompute.services.scheduling; +import static android.federatedcompute.common.ClientConstants.STATUS_SUCCESS; + import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; @@ -29,7 +31,6 @@ import android.app.job.JobInfo; import android.app.job.JobScheduler; import android.content.ComponentName; import android.content.Context; -import android.federatedcompute.aidl.IFederatedComputeCallback; import android.federatedcompute.common.TrainingInterval; import android.federatedcompute.common.TrainingOptions; @@ -37,16 +38,16 @@ import androidx.test.core.app.ApplicationProvider; import com.android.federatedcompute.services.common.Clock; import com.android.federatedcompute.services.common.Flags; -import com.android.federatedcompute.services.common.TrainingResult; +import com.android.federatedcompute.services.data.FederatedComputeDbHelper; import com.android.federatedcompute.services.data.FederatedTrainingTask; import com.android.federatedcompute.services.data.FederatedTrainingTaskDao; -import com.android.federatedcompute.services.data.FederatedTrainingTaskDbHelper; import com.android.federatedcompute.services.data.fbs.SchedulingMode; import com.android.federatedcompute.services.data.fbs.SchedulingReason; import com.android.federatedcompute.services.data.fbs.TrainingConstraints; import com.android.federatedcompute.services.data.fbs.TrainingIntervalOptions; import com.google.flatbuffers.FlatBufferBuilder; +import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult; import com.google.intelligence.fcp.client.engine.TaskRetry; import org.junit.After; @@ -58,7 +59,6 @@ import org.mockito.junit.MockitoJUnitRunner; import java.nio.ByteBuffer; import java.util.List; -import java.util.concurrent.CountDownLatch; import javax.annotation.Nullable; @@ -92,13 +92,9 @@ public final class FederatedComputeJobManagerTest { .build(); private static final TaskRetry TASK_RETRY = TaskRetry.newBuilder().setDelayMin(5000000).setDelayMax(6000000).build(); - private FederatedComputeJobManager mJobManager; private Context mContext; private FederatedTrainingTaskDao mTrainingTaskDao; - private boolean mSuccess = false; - private final CountDownLatch mLatch = new CountDownLatch(1); - @Mock private Clock mClock; @Mock private Flags mMockFlags; @Mock private FederatedJobIdGenerator mMockJobIdGenerator; @@ -134,8 +130,8 @@ public final class FederatedComputeJobManagerTest { public void tearDown() { // Manually clean up the database. mTrainingTaskDao.clearDatabase(); - FederatedTrainingTaskDbHelper dbHelper = - FederatedTrainingTaskDbHelper.getInstanceForTest(mContext); + FederatedComputeDbHelper dbHelper = + FederatedComputeDbHelper.getInstanceForTest(mContext); dbHelper.getWritableDatabase().close(); dbHelper.getReadableDatabase().close(); dbHelper.close(); @@ -145,10 +141,9 @@ public final class FederatedComputeJobManagerTest { public void testOnTrainerStartCalledSuccess() throws Exception { when(mClock.currentTimeMillis()).thenReturn(1000L).thenReturn(2000L); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, OPTIONS1, new TestFederatedComputeCallback()); + int resultCode = mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, OPTIONS1); - assertThat(mSuccess).isTrue(); + assertThat(resultCode).isEqualTo(STATUS_SUCCESS); List<FederatedTrainingTask> taskList = mTrainingTaskDao.getFederatedTrainingTask(null, null); assertThat(taskList) @@ -157,6 +152,7 @@ public final class FederatedComputeJobManagerTest { .creationTime(1000L) .lastScheduledTime(1000L) .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) + .intervalOptions(createDefaultTrainingInterval()) .earliestNextRunTime(1000 + DEFAULT_SCHEDULING_PERIOD_MILLIS) .build()); } @@ -165,14 +161,11 @@ public final class FederatedComputeJobManagerTest { public void testOnTrainerStartCalled_firstTime() throws Exception { when(mClock.currentTimeMillis()).thenReturn(1000L); // Make three onTrainerStart calls, each with different job ID and session name. - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, OPTIONS1, new TestFederatedComputeCallback()); + int resultCode = mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, OPTIONS1); when(mClock.currentTimeMillis()).thenReturn(2000L); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, OPTIONS2, new TestFederatedComputeCallback()); - mLatch.await(); + mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, OPTIONS2); - assertThat(mSuccess).isTrue(); + assertThat(resultCode).isEqualTo(STATUS_SUCCESS); // verify training tasks in database. List<FederatedTrainingTask> taskList = mTrainingTaskDao.getFederatedTrainingTask(null, null); @@ -183,12 +176,14 @@ public final class FederatedComputeJobManagerTest { .lastScheduledTime(1000L) .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) .earliestNextRunTime(1000 + DEFAULT_SCHEDULING_PERIOD_MILLIS) + .intervalOptions(createDefaultTrainingInterval()) .build(), basicFLTrainingTaskBuilder(JOB_ID2, POPULATION_NAME2, null) .creationTime(2000L) .lastScheduledTime(2000L) .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) .earliestNextRunTime(2000 + DEFAULT_SCHEDULING_PERIOD_MILLIS) + .intervalOptions(createDefaultTrainingInterval()) .build()); assertThat(mJobScheduler.getAllPendingJobs()).hasSize(2); @@ -226,8 +221,7 @@ public final class FederatedComputeJobManagerTest { .setMinimumIntervalMillis(userDefinedIntervalMillis) .build()) .build(); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, trainerOptions, new TestFederatedComputeCallback()); + mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, trainerOptions); byte[] trainingIntervalOptions = createTrainingIntervalOptions(SchedulingMode.RECURRENT, userDefinedIntervalMillis); @@ -260,17 +254,15 @@ public final class FederatedComputeJobManagerTest { @Test public void testOnTrainerStartCalled_multipleTimes_sameParams() throws Exception { when(mClock.currentTimeMillis()).thenReturn(1000L); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, OPTIONS1, new TestFederatedComputeCallback()); + mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, OPTIONS1); when(mClock.currentTimeMillis()).thenReturn(2000L); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, OPTIONS1, new TestFederatedComputeCallback()); + mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, OPTIONS1); when(mClock.currentTimeMillis()).thenReturn(3000L); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, OPTIONS1, new TestFederatedComputeCallback()); + int resultCode = mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, OPTIONS1); + assertThat(resultCode).isEqualTo(STATUS_SUCCESS); List<FederatedTrainingTask> taskList = mTrainingTaskDao.getFederatedTrainingTask(null, null); FederatedTrainingTask expectedTask = @@ -279,6 +271,7 @@ public final class FederatedComputeJobManagerTest { .lastScheduledTime(3000L) .creationTime(1000L) .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) + .intervalOptions(createDefaultTrainingInterval()) .build(); assertThat(taskList).containsExactly(expectedTask); @@ -310,16 +303,13 @@ public final class FederatedComputeJobManagerTest { .build(); when(mClock.currentTimeMillis()).thenReturn(1000L); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, trainingOptions, new TestFederatedComputeCallback()); + mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, trainingOptions); when(mClock.currentTimeMillis()).thenReturn(2000L); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, trainingOptions, new TestFederatedComputeCallback()); + mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, trainingOptions); when(mClock.currentTimeMillis()).thenReturn(3000L); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, trainingOptions, new TestFederatedComputeCallback()); + mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, trainingOptions); List<FederatedTrainingTask> taskList = mTrainingTaskDao.getFederatedTrainingTask(null, null); @@ -362,8 +352,7 @@ public final class FederatedComputeJobManagerTest { .build(); when(mClock.currentTimeMillis()).thenReturn(1000L); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, trainingOptions, new TestFederatedComputeCallback()); + mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, trainingOptions); List<FederatedTrainingTask> taskList = mTrainingTaskDao.getFederatedTrainingTask(null, null); @@ -399,8 +388,7 @@ public final class FederatedComputeJobManagerTest { .build(); when(mClock.currentTimeMillis()).thenReturn(2000L); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, newTrainingOptions, new TestFederatedComputeCallback()); + mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, newTrainingOptions); taskList = mTrainingTaskDao.getFederatedTrainingTask(null, null); expectedInterval = @@ -440,8 +428,7 @@ public final class FederatedComputeJobManagerTest { .build(); when(mClock.currentTimeMillis()).thenReturn(1000L); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, trainingOptions, new TestFederatedComputeCallback()); + mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, trainingOptions); List<FederatedTrainingTask> taskList = mTrainingTaskDao.getFederatedTrainingTask(null, null); @@ -465,9 +452,7 @@ public final class FederatedComputeJobManagerTest { public void testOnTrainerStartCalled_trainingIntervalChange_FL() throws Exception { when(mClock.currentTimeMillis()).thenReturn(1000L); mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, - basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1).build(), - new TestFederatedComputeCallback()); + CALLING_PACKAGE_NAME, basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1).build()); long minTrainingIntervalMillis = 60000L; when(mClock.currentTimeMillis()).thenReturn(2000L); @@ -480,8 +465,7 @@ public final class FederatedComputeJobManagerTest { TrainingInterval.SCHEDULING_MODE_RECURRENT) .setMinimumIntervalMillis(minTrainingIntervalMillis) .build()) - .build(), - new TestFederatedComputeCallback()); + .build()); byte[] trainingInterval = createTrainingIntervalOptions(SchedulingMode.RECURRENT, minTrainingIntervalMillis); verifyTaskAndJobAfterIntervalChange( @@ -498,14 +482,14 @@ public final class FederatedComputeJobManagerTest { TrainingInterval.SCHEDULING_MODE_RECURRENT) .setMinimumIntervalMillis(newInterval) .build()) - .build(), - new TestFederatedComputeCallback()); + .build()); byte[] trainingIntervalOption2 = createTrainingIntervalOptions(SchedulingMode.RECURRENT, newInterval); // Verify the creation time not changed, modified time is set to now, and the min interval // is set to the new interval. verifyTaskAndJobAfterIntervalChange(trainingIntervalOption2, 1000, 3000, newInterval); + // Change to default training interval {one_time, interval 0}. when(mClock.currentTimeMillis()).thenReturn(4000L); mJobManager.onTrainerStartCalled( CALLING_PACKAGE_NAME, @@ -515,21 +499,12 @@ public final class FederatedComputeJobManagerTest { .setSchedulingMode( TrainingInterval.SCHEDULING_MODE_ONE_TIME) .build()) - .build(), - new TestFederatedComputeCallback()); + .build()); byte[] trainingIntervalOption3 = createTrainingIntervalOptions(SchedulingMode.ONE_TIME, 0L); // Verify the creation time not changed, modified time is set to now, and the min interval // is set to the new interval. verifyTaskAndJobAfterIntervalChange( trainingIntervalOption3, 1000, 4000, DEFAULT_SCHEDULING_PERIOD_MILLIS); - - // Transition back to not set - when(mClock.currentTimeMillis()).thenReturn(5000L); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, - basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1).build(), - new TestFederatedComputeCallback()); - verifyTaskAndJobAfterIntervalChange(null, 1000, 5000, DEFAULT_SCHEDULING_PERIOD_MILLIS); } private void verifyTaskAndJobAfterIntervalChange( @@ -563,8 +538,7 @@ public final class FederatedComputeJobManagerTest { .setPopulationName(POPULATION_NAME1) .setServerAddress(SERVER_ADDRESS) .build(); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, options1, new TestFederatedComputeCallback()); + mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, options1); // Pass in a new population name and We will assign new job id since population name // changes. @@ -574,8 +548,7 @@ public final class FederatedComputeJobManagerTest { .setPopulationName(POPULATION_NAME2) .setServerAddress(SERVER_ADDRESS) .build(); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, options2, new TestFederatedComputeCallback()); + mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, options2); // Verify two training tasks in database. List<FederatedTrainingTask> taskList = @@ -587,12 +560,14 @@ public final class FederatedComputeJobManagerTest { .lastScheduledTime(1000L) .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) .earliestNextRunTime(1000 + DEFAULT_SCHEDULING_PERIOD_MILLIS) + .intervalOptions(createDefaultTrainingInterval()) .build(), basicFLTrainingTaskBuilder(JOB_ID2, POPULATION_NAME2, null) .creationTime(2000L) .lastScheduledTime(2000L) .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) .earliestNextRunTime(2000 + DEFAULT_SCHEDULING_PERIOD_MILLIS) + .intervalOptions(createDefaultTrainingInterval()) .build()); assertThat(mJobScheduler.getAllPendingJobs()).hasSize(2); assertJobInfosMatch( @@ -611,8 +586,7 @@ public final class FederatedComputeJobManagerTest { .setPopulationName(POPULATION_NAME1) .setServerAddress(SERVER_ADDRESS) .build(); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, options1, new TestFederatedComputeCallback()); + mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, options1); // For same population, we will reuse the job id assigned to the previous task. when(mClock.currentTimeMillis()).thenReturn(2000L); @@ -621,8 +595,7 @@ public final class FederatedComputeJobManagerTest { .setPopulationName(POPULATION_NAME1) .setServerAddress(SERVER_ADDRESS) .build(); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, options2, new TestFederatedComputeCallback()); + mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, options2); // Verify only task in database. List<FederatedTrainingTask> taskList = @@ -636,6 +609,7 @@ public final class FederatedComputeJobManagerTest { .creationTime(1000L) .constraints(DEFAULT_CONSTRAINTS) .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) + .intervalOptions(createDefaultTrainingInterval()) .build(); assertThat(taskList).containsExactly(expectedTask); @@ -664,8 +638,7 @@ public final class FederatedComputeJobManagerTest { long nowMillis = 1000; when(mClock.currentTimeMillis()).thenReturn(nowMillis); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, OPTIONS1, new TestFederatedComputeCallback()); + mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, OPTIONS1); // Simulate attempting to run a task a lot later. This should not fail, b/c we're not yet // past the TTL threshold. @@ -679,8 +652,7 @@ public final class FederatedComputeJobManagerTest { when(mMockFlags.getTrainingTimeForLiveSeconds()).thenReturn(1L); when(mClock.currentTimeMillis()).thenReturn(1000L); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, OPTIONS1, new TestFederatedComputeCallback()); + mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, OPTIONS1); // Simulate attempting to run a task one second later. This should not fail, b/c we're not // yet @@ -691,11 +663,11 @@ public final class FederatedComputeJobManagerTest { assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).hasSize(1); // Now reschedule again, should keep the task alive for another second. - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, OPTIONS1, new TestFederatedComputeCallback()); + int resultCode = mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, OPTIONS1); // The task should again still be alive a second later. nowMillis = 3000; + assertThat(resultCode).isEqualTo(STATUS_SUCCESS); when(mClock.currentTimeMillis()).thenReturn(nowMillis); assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNotNull(); @@ -710,8 +682,7 @@ public final class FederatedComputeJobManagerTest { public void testRescheduleFLTask_success() throws Exception { long nowMillis = 1000; when(mClock.currentTimeMillis()).thenReturn(nowMillis); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, OPTIONS1, new TestFederatedComputeCallback()); + mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, OPTIONS1); nowMillis = 2000; when(mClock.currentTimeMillis()).thenReturn(nowMillis); @@ -724,7 +695,7 @@ public final class FederatedComputeJobManagerTest { POPULATION_NAME1, createTrainingIntervalOptionsAsRoot(SchedulingMode.RECURRENT, 0), TASK_RETRY, - TrainingResult.SUCCESS); + ContributionResult.SUCCESS); assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNotNull(); assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).hasSize(1); @@ -734,8 +705,7 @@ public final class FederatedComputeJobManagerTest { public void testRescheduleFLTask_oneoff_success() throws Exception { long nowMillis = 1000; when(mClock.currentTimeMillis()).thenReturn(nowMillis); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, OPTIONS1, new TestFederatedComputeCallback()); + mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, OPTIONS1); nowMillis = 2000; when(mClock.currentTimeMillis()).thenReturn(nowMillis); @@ -748,7 +718,7 @@ public final class FederatedComputeJobManagerTest { POPULATION_NAME1, createTrainingIntervalOptionsAsRoot(SchedulingMode.ONE_TIME, 0), TASK_RETRY, - TrainingResult.SUCCESS); + ContributionResult.SUCCESS); assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNull(); assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).isEmpty(); @@ -768,8 +738,7 @@ public final class FederatedComputeJobManagerTest { TrainingInterval.SCHEDULING_MODE_ONE_TIME) .build()) .build(); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, trainerOptions, new TestFederatedComputeCallback()); + mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, trainerOptions); nowMillis = 2000; when(mClock.currentTimeMillis()).thenReturn(nowMillis); @@ -787,7 +756,7 @@ public final class FederatedComputeJobManagerTest { .setDelayMin(serverRetryDelayMillis) .setDelayMax(serverRetryDelayMillis) .build(), - TrainingResult.FAIL); + ContributionResult.FAIL); List<FederatedTrainingTask> taskList = mTrainingTaskDao.getFederatedTrainingTask(null, null); @@ -827,8 +796,7 @@ public final class FederatedComputeJobManagerTest { .build(); long nowMillis = 1000; when(mClock.currentTimeMillis()).thenReturn(nowMillis); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, trainerOptions, new TestFederatedComputeCallback()); + mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, trainerOptions); nowMillis = 2000; when(mClock.currentTimeMillis()).thenReturn(nowMillis); @@ -847,7 +815,7 @@ public final class FederatedComputeJobManagerTest { .setDelayMin(minRetryDelayMillis) .setDelayMax(maxRetryDelayMillis) .build(), - TrainingResult.SUCCESS); + ContributionResult.SUCCESS); List<FederatedTrainingTask> taskList = mTrainingTaskDao.getFederatedTrainingTask(null, null); @@ -887,8 +855,7 @@ public final class FederatedComputeJobManagerTest { long nowMillis = 1000; when(mClock.currentTimeMillis()).thenReturn(nowMillis); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, trainerOptions, new TestFederatedComputeCallback()); + mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, trainerOptions); nowMillis = 2000; when(mClock.currentTimeMillis()).thenReturn(nowMillis); @@ -907,7 +874,7 @@ public final class FederatedComputeJobManagerTest { .setDelayMin(serverDefinedIntervalMillis) .setDelayMax(serverDefinedIntervalMillis) .build(), - TrainingResult.SUCCESS); + ContributionResult.SUCCESS); List<FederatedTrainingTask> taskList = mTrainingTaskDao.getFederatedTrainingTask(null, null); @@ -947,8 +914,7 @@ public final class FederatedComputeJobManagerTest { long nowMillis = 1000; when(mClock.currentTimeMillis()).thenReturn(nowMillis); - mJobManager.onTrainerStartCalled( - CALLING_PACKAGE_NAME, trainerOptions, new TestFederatedComputeCallback()); + mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, trainerOptions); nowMillis = 2000; when(mClock.currentTimeMillis()).thenReturn(nowMillis); @@ -967,7 +933,7 @@ public final class FederatedComputeJobManagerTest { .setDelayMin(serverDefinedIntervalMillis) .setDelayMax(serverDefinedIntervalMillis) .build(), - TrainingResult.FAIL); + ContributionResult.FAIL); List<FederatedTrainingTask> taskList = mTrainingTaskDao.getFederatedTrainingTask(null, null); @@ -989,6 +955,34 @@ public final class FederatedComputeJobManagerTest { buildExpectedJobInfo(JOB_ID1, serverDefinedIntervalMillis)); } + @Test + public void testOnTrainerStopCalled_withoutOnTrainerStartCalled() throws Exception { + // Should not fail, even if onTrainerStartCalled was never called. + int resultCode = mJobManager.onTrainerStopCalled(CALLING_PACKAGE_NAME, POPULATION_NAME1); + + // No task should exist, nor should a job have been scheduled. + assertThat(resultCode).isEqualTo(STATUS_SUCCESS); + assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).isEmpty(); + assertThat(mJobScheduler.getAllPendingJobs()).isEmpty(); + } + + @Test + public void testOnTrainerStopCalled_afterOnTrainerStartCalled() throws Exception { + // After a cycle of onTrainerStartCalled -> onTrainerStopCalled there should be no pending + // jobs. + long nowMillis = 1000; + when(mClock.currentTimeMillis()).thenReturn(nowMillis); + mJobManager.onTrainerStartCalled(CALLING_PACKAGE_NAME, OPTIONS1); + + nowMillis = 2000; + when(mClock.currentTimeMillis()).thenReturn(nowMillis); + mJobManager.onTrainerStopCalled(CALLING_PACKAGE_NAME, POPULATION_NAME1); + + // No task should exist, nor should a job be scheduled anymore + assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).isEmpty(); + assertThat(mJobScheduler.getAllPendingJobs()).isEmpty(); + } + /** * Helper for checking that two JobInfos match, since JobInfos unfortunately can't be compared * directly. @@ -1073,22 +1067,13 @@ public final class FederatedComputeJobManagerTest { return builder.sizedByteArray(); } + private static byte[] createDefaultTrainingInterval() { + return createTrainingIntervalOptions(SchedulingMode.ONE_TIME, 0); + } + private static byte[] createDefaultTrainingConstraints() { FlatBufferBuilder builder = new FlatBufferBuilder(); builder.finish(TrainingConstraints.createTrainingConstraints(builder, true, true, true)); return builder.sizedByteArray(); } - - class TestFederatedComputeCallback extends IFederatedComputeCallback.Stub { - @Override - public void onSuccess() { - mSuccess = true; - mLatch.countDown(); - } - - @Override - public void onFailure(int errorCode) { - mLatch.countDown(); - } - } } diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/testutils/ResettingExampleIterator.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/testutils/ResettingExampleIterator.java new file mode 100644 index 00000000..ede61f90 --- /dev/null +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/testutils/ResettingExampleIterator.java @@ -0,0 +1,112 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.services.testutils; + +import com.android.federatedcompute.services.examplestore.ExampleIterator; + +import com.google.common.collect.ImmutableList; +import com.google.internal.federated.plan.Dataset; +import com.google.internal.federated.plan.Dataset.ClientDataset; +import com.google.protobuf.ByteString; + +import java.time.Duration; + +/** An ExampleIterator that serves a fixed number of examples. */ +public class ResettingExampleIterator implements ExampleIterator { + private final int mCapacity; + ImmutableList<ByteString> mExamples; + Duration mNextLatency; + private int mNumServed; + private int mNumHasNextInvocations; + private int mNumNextInvocations; + private int mNumCloseInvocations; + + /** + * Creates an {@link ExampleIterator} for use in tests. + * + * @param limit number of examples being served until {@link #hasNext()} fails. + * @param dataset The Dataset to serve examples from. + * @param nextLatency amount of time to sleep before returning an example from next(). + */ + public ResettingExampleIterator(int limit, Dataset dataset, Duration nextLatency) { + this.mCapacity = limit; + ImmutableList.Builder<ByteString> exampleDatasetBuilder = new ImmutableList.Builder<>(); + for (ClientDataset clientDataset : dataset.getClientDataList()) { + exampleDatasetBuilder.addAll(clientDataset.getExampleList()); + } + this.mExamples = exampleDatasetBuilder.build(); + this.mNextLatency = nextLatency; + } + + public ResettingExampleIterator(int limit, Dataset dataset) { + this(limit, dataset, Duration.ZERO); + } + + /** + * Iterator that always returns the same example. See {@link #ResettingExampleIterator(int, + * Dataset)} for details. + */ + public ResettingExampleIterator(int capacity, ByteString example) { + this.mCapacity = capacity; + this.mNextLatency = Duration.ZERO; + this.mExamples = ImmutableList.of(example); + } + + @Override + public boolean hasNext() { + mNumHasNextInvocations++; + if (mNumServed < mCapacity) { + return true; + } else { + mNumServed = 0; + return false; + } + } + + @Override + public byte[] next() { + mNumNextInvocations++; + mNumServed++; + // If the end of the provided examples has been reached, wrap around and start serving from + // the beginning again. + if (!mNextLatency.isZero()) { + try { + Thread.sleep(mNextLatency.toMillis()); + } catch (InterruptedException e) { + throw new IllegalStateException("error in Thread.sleep()", e); + } + } + return mExamples.get((mNumServed - 1) % mExamples.size()).toByteArray(); + } + + @Override + public void close() { + mNumCloseInvocations++; + } + + public int getNumHasNextInvocations() { + return mNumHasNextInvocations; + } + + public int getNumNextInvocations() { + return mNumNextInvocations; + } + + public int getNumCloseInvocations() { + return mNumCloseInvocations; + } +} diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/testutils/TrainingTestUtil.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/testutils/TrainingTestUtil.java index b08c1c4a..01c434ab 100644 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/testutils/TrainingTestUtil.java +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/testutils/TrainingTestUtil.java @@ -26,14 +26,11 @@ import com.google.internal.federated.plan.ExampleQuerySpec.OutputVectorSpec.Data import com.google.internal.federated.plan.ExampleSelector; import com.google.internal.federated.plan.FederatedExampleQueryIORouter; import com.google.internal.federated.plan.TFV1CheckpointAggregation; +import com.google.internal.federated.plan.TensorflowSpec; +import com.google.protobuf.ByteString; /** The utility class for federated learning related tests. */ public class TrainingTestUtil { - public static final String CLIENT_PACKAGE_NAME = "de.myselph"; - public static final long RUN_ID = 12345L; - public static final String SESSION_NAME = "session_name"; - public static final String TASK_NAME = "task_name"; - public static final String POPULATION_NAME = "population_name"; public static final String STRING_VECTOR_NAME = "vector1"; public static final String INT_VECTOR_NAME = "vector2"; public static final String STRING_TENSOR_NAME = "tensor1"; @@ -85,4 +82,19 @@ public class TrainingTestUtil { .build(); return clientOnlyPlan; } + + public static ClientOnlyPlan createFakeFederatedLearningClientPlan() { + TensorflowSpec tensorflowSpec = + TensorflowSpec.newBuilder() + .setDatasetTokenTensorName("dataset") + .addTargetNodeNames("target") + .build(); + ClientOnlyPlan clientOnlyPlan = + ClientOnlyPlan.newBuilder() + .setTfliteGraph(ByteString.copyFromUtf8("tflite_graph")) + .setPhase( + ClientPhase.newBuilder().setTensorflowSpec(tensorflowSpec).build()) + .build(); + return clientOnlyPlan; + } } diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java index 4bbfca3e..34179225 100644 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java @@ -16,108 +16,176 @@ package com.android.federatedcompute.services.training; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static com.android.federatedcompute.services.common.FileUtils.createTempFile; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.util.concurrent.Futures.immediateFailedFuture; +import static com.google.common.util.concurrent.Futures.immediateFuture; + +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import com.android.federatedcompute.services.common.Flags; -import com.android.federatedcompute.services.common.TrainingResult; +import android.content.Context; +import android.federatedcompute.aidl.IExampleStoreCallback; +import android.federatedcompute.aidl.IExampleStoreService; +import android.federatedcompute.common.ClientConstants; +import android.federatedcompute.common.ExampleConsumption; +import android.os.Bundle; +import android.os.RemoteException; + +import androidx.test.core.app.ApplicationProvider; + +import com.android.federatedcompute.services.common.Constants; import com.android.federatedcompute.services.data.FederatedTrainingTask; import com.android.federatedcompute.services.data.fbs.SchedulingMode; import com.android.federatedcompute.services.data.fbs.SchedulingReason; import com.android.federatedcompute.services.data.fbs.TrainingConstraints; import com.android.federatedcompute.services.data.fbs.TrainingIntervalOptions; +import com.android.federatedcompute.services.examplestore.ExampleConsumptionRecorder; +import com.android.federatedcompute.services.http.CheckinResult; +import com.android.federatedcompute.services.http.HttpFederatedProtocol; import com.android.federatedcompute.services.scheduling.FederatedComputeJobManager; +import com.android.federatedcompute.services.testutils.FakeExampleStoreIterator; +import com.android.federatedcompute.services.testutils.TrainingTestUtil; +import com.android.federatedcompute.services.training.ResultCallbackHelper.CallbackResult; +import com.android.federatedcompute.services.training.aidl.IIsolatedTrainingService; +import com.android.federatedcompute.services.training.aidl.ITrainingResultCallback; +import com.android.federatedcompute.services.training.util.TrainingConditionsChecker; +import com.android.federatedcompute.services.training.util.TrainingConditionsChecker.Condition; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.util.concurrent.FluentFuture; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListeningExecutorService; +import com.google.common.util.concurrent.MoreExecutors; import com.google.flatbuffers.FlatBufferBuilder; +import com.google.intelligence.fcp.client.FLRunnerResult; +import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult; +import com.google.intelligence.fcp.client.RetryInfo; +import com.google.intelligence.fcp.client.engine.TaskRetry; +import com.google.internal.federated.plan.ClientOnlyPlan; +import com.google.internal.federated.plan.ClientPhase; +import com.google.internal.federated.plan.TensorflowSpec; +import com.google.ondevicepersonalization.federatedcompute.proto.TaskAssignment; +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.Mock; -import org.mockito.MockitoAnnotations; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; +import org.tensorflow.example.BytesList; +import org.tensorflow.example.Example; +import org.tensorflow.example.Feature; +import org.tensorflow.example.Features; + +import java.util.ArrayList; +import java.util.EnumSet; +import java.util.concurrent.ExecutionException; @RunWith(JUnit4.class) public final class FederatedComputeWorkerTest { private static final int JOB_ID = 1234; private static final String POPULATION_NAME = "barPopulation"; - private static final String SERVER_ADDRESS = "https://server.uri/"; + private static final String TASK_NAME = "task-id"; private static final long CREATION_TIME_MS = 10000L; private static final long TASK_EARLIEST_NEXT_RUN_TIME_MS = 1234567L; private static final String PACKAGE_NAME = "com.android.federatedcompute.services.training"; + private static final String SERVER_ADDRESS = "https://server.com/"; private static final byte[] DEFAULT_TRAINING_CONSTRAINTS = createTrainingConstraints(true, true, true); - private static final long FEDERATED_TRANSIENT_ERROR_RETRY_PERIOD_SECS = 50000; + private static final TaskRetry TASK_RETRY = + TaskRetry.newBuilder().setRetryToken("foobar").build(); + private static final CheckinResult FL_CHECKIN_RESULT = + new CheckinResult( + createTempFile("input", ".ckp"), + TrainingTestUtil.createFakeFederatedLearningClientPlan(), + TaskAssignment.newBuilder().setTaskName(TASK_NAME).build()); + private static final CheckinResult FA_CHECKIN_RESULT = + new CheckinResult( + createTempFile("input", ".ckp"), + TrainingTestUtil.createFederatedAnalyticClientPlan(), + TaskAssignment.newBuilder().setTaskName(TASK_NAME).build()); + private static final FLRunnerResult FL_RUNNER_FAILURE_RESULT = + FLRunnerResult.newBuilder().setContributionResult(ContributionResult.FAIL).build(); + + private static final FLRunnerResult FL_RUNNER_SUCCESS_RESULT = + FLRunnerResult.newBuilder() + .setContributionResult(ContributionResult.SUCCESS) + .setRetryInfo( + RetryInfo.newBuilder() + .setRetryToken(TASK_RETRY.getRetryToken()) + .build()) + .build(); private static final byte[] INTERVAL_OPTIONS = createDefaultTrainingIntervalOptions(); private static final FederatedTrainingTask FEDERATED_TRAINING_TASK_1 = FederatedTrainingTask.builder() .appPackageName(PACKAGE_NAME) .creationTime(CREATION_TIME_MS) .lastScheduledTime(TASK_EARLIEST_NEXT_RUN_TIME_MS) + .serverAddress(SERVER_ADDRESS) .populationName(POPULATION_NAME) .jobId(JOB_ID) - .serverAddress(SERVER_ADDRESS) .intervalOptions(INTERVAL_OPTIONS) .constraints(DEFAULT_TRAINING_CONSTRAINTS) .earliestNextRunTime(TASK_EARLIEST_NEXT_RUN_TIME_MS) .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) .build(); - + private static final Example EXAMPLE_PROTO_1 = + Example.newBuilder() + .setFeatures( + Features.newBuilder() + .putFeature( + "feature1", + Feature.newBuilder() + .setBytesList( + BytesList.newBuilder() + .addValue( + ByteString.copyFromUtf8( + "f1_value1"))) + .build())) + .build(); + private static final Any FAKE_CRITERIA = Any.newBuilder().setTypeUrl("baz.com").build(); + private static final ExampleConsumption EXAMPLE_CONSUMPTION_1 = + new ExampleConsumption.Builder() + .setTaskName(TASK_NAME) + .setSelectionCriteria(FAKE_CRITERIA.toByteArray()) + .setExampleCount(100) + .build(); + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + @Mock TrainingConditionsChecker mTrainingConditionsChecker; @Mock FederatedComputeJobManager mMockJobManager; - @Mock private Flags mMockFlags; - private FederatedComputeWorker mFcpWorker; - - @Before - public void doBeforeEachTest() { - MockitoAnnotations.initMocks(this); - mFcpWorker = new FederatedComputeWorker(mMockJobManager, mMockFlags); - doNothing() - .when(mMockJobManager) - .onTrainingCompleted(anyInt(), anyString(), any(), any(), anyInt()); - when(mMockFlags.getTransientErrorRetryDelayJitterPercent()).thenReturn(0.1f); - when(mMockFlags.getTransientErrorRetryDelaySecs()) - .thenReturn(FEDERATED_TRANSIENT_ERROR_RETRY_PERIOD_SECS); - } - - @Test - public void testTrainingSuccess() { - when(mMockJobManager.onTrainingStarted(anyInt())).thenReturn(FEDERATED_TRAINING_TASK_1); - boolean result = mFcpWorker.startTrainingRun(JOB_ID); - - assertTrue(result); - verify(mMockJobManager, times(1)) - .onTrainingCompleted( - eq(JOB_ID), eq(POPULATION_NAME), any(), any(), eq(TrainingResult.SUCCESS)); - } - - @Test - public void testTrainingFailure_nonExist() { - when(mMockJobManager.onTrainingStarted(anyInt())).thenReturn(null); - boolean result = mFcpWorker.startTrainingRun(JOB_ID); - - assertFalse(result); - verify(mMockJobManager, times(0)) - .onTrainingCompleted(eq(JOB_ID), eq(POPULATION_NAME), any(), any(), anyInt()); - } + private Context mContext; + private FederatedComputeWorker mSpyWorker; + @Mock private HttpFederatedProtocol mMockHttpFederatedProtocol; + @Mock private ComputationRunner mMockComputationRunner; + @Mock private ResultCallbackHelper mMockResultCallbackHelper; private static byte[] createTrainingConstraints( boolean requiresSchedulerIdle, - boolean requiresSchedulerCharging, + boolean requiresSchedulerBatteryNotLow, boolean requiresSchedulerUnmeteredNetwork) { FlatBufferBuilder builder = new FlatBufferBuilder(); builder.finish( TrainingConstraints.createTrainingConstraints( builder, requiresSchedulerIdle, - requiresSchedulerCharging, + requiresSchedulerBatteryNotLow, requiresSchedulerUnmeteredNetwork)); return builder.sizedByteArray(); } @@ -129,4 +197,347 @@ public final class FederatedComputeWorkerTest { builder, SchedulingMode.ONE_TIME, 0)); return builder.sizedByteArray(); } + + @Before + public void doBeforeEachTest() throws Exception { + mContext = ApplicationProvider.getApplicationContext(); + mSpyWorker = + Mockito.spy( + new FederatedComputeWorker( + mContext, + mMockJobManager, + mTrainingConditionsChecker, + mMockComputationRunner, + mMockResultCallbackHelper, + new TestInjector())); + when(mTrainingConditionsChecker.checkAllConditionsForFlTraining(any())) + .thenReturn(EnumSet.noneOf(Condition.class)); + when(mMockResultCallbackHelper.callHandleResult(eq(TASK_NAME), any(), any())) + .thenReturn(Futures.immediateFuture(CallbackResult.SUCCESS)); + when(mMockJobManager.onTrainingStarted(anyInt())).thenReturn(FEDERATED_TRAINING_TASK_1); + doReturn(mMockHttpFederatedProtocol) + .when(mSpyWorker) + .getHttpFederatedProtocol(anyString(), anyString()); + when(mMockComputationRunner.runTaskWithNativeRunner( + anyString(), + anyString(), + anyString(), + anyString(), + any(), + any(), + any(), + any(), + any())) + .thenReturn(FL_RUNNER_SUCCESS_RESULT); + } + + @Test + public void testJobNonExist_returnsFail() throws Exception { + when(mMockJobManager.onTrainingStarted(anyInt())).thenReturn(null); + + FLRunnerResult result = mSpyWorker.startTrainingRun(JOB_ID).get(); + + assertNull(result); + verify(mMockJobManager, times(0)) + .onTrainingCompleted(eq(JOB_ID), eq(POPULATION_NAME), any(), any(), any()); + } + + @Test + public void testTrainingConditionsCheckFailed_returnsFail() throws Exception { + when(mTrainingConditionsChecker.checkAllConditionsForFlTraining(any())) + .thenReturn(ImmutableSet.of(Condition.BATTERY_NOT_OK)); + + FLRunnerResult result = mSpyWorker.startTrainingRun(JOB_ID).get(); + + assertNull(result); + verify(mMockJobManager) + .onTrainingCompleted(eq(JOB_ID), eq(POPULATION_NAME), any(), any(), any()); + } + + @Test + public void testCheckinFails_throwsException() throws Exception { + setUpExampleStoreService(); + + doReturn( + immediateFailedFuture( + new ExecutionException( + "issue checkin failed", + new IllegalStateException("http 404")))) + .when(mMockHttpFederatedProtocol) + .issueCheckin(); + doReturn(FluentFuture.from(immediateFuture(null))) + .when(mMockHttpFederatedProtocol) + .reportResult(any()); + + assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get()); + + mSpyWorker.finish(null, ContributionResult.FAIL, false); + verify(mMockJobManager) + .onTrainingCompleted( + anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL)); + } + + @Test + public void testReportResultFails_throwsException() throws Exception { + setUpExampleStoreService(); + + doReturn(immediateFuture(FA_CHECKIN_RESULT)) + .when(mMockHttpFederatedProtocol) + .issueCheckin(); + doReturn( + FluentFuture.from( + immediateFailedFuture( + new ExecutionException( + "report result failed", + new IllegalStateException("http 404"))))) + .when(mMockHttpFederatedProtocol) + .reportResult(any()); + + assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get()); + + mSpyWorker.finish(null, ContributionResult.FAIL, false); + verify(mMockJobManager) + .onTrainingCompleted( + anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL)); + verify(mSpyWorker).unbindFromExampleStoreService(); + } + + @Test + public void testBindToExampleStoreFails_throwsException() throws Exception { + setUpHttpFederatedProtocol(FL_CHECKIN_RESULT); + + // Mock failure bind to ExampleStoreService. + doReturn(null).when(mSpyWorker).getExampleStoreService(anyString()); + doNothing().when(mSpyWorker).unbindFromExampleStoreService(); + + assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get()); + + mSpyWorker.finish(null, ContributionResult.FAIL, false); + verify(mMockJobManager) + .onTrainingCompleted( + anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL)); + verify(mSpyWorker, times(0)).unbindFromExampleStoreService(); + } + + @Test + public void testRunFAComputationReturnsFailResult() throws Exception { + setUpExampleStoreService(); + setUpHttpFederatedProtocol(FA_CHECKIN_RESULT); + + // Mock return failed runner result from native fcp client. + when(mMockComputationRunner.runTaskWithNativeRunner( + anyString(), + anyString(), + anyString(), + anyString(), + any(), + any(), + any(), + any(), + any())) + .thenReturn(FL_RUNNER_FAILURE_RESULT); + + FLRunnerResult result = mSpyWorker.startTrainingRun(JOB_ID).get(); + assertThat(result.getContributionResult()).isEqualTo(ContributionResult.FAIL); + + mSpyWorker.finish(result); + verify(mMockJobManager) + .onTrainingCompleted( + anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL)); + verify(mSpyWorker).unbindFromExampleStoreService(); + } + + @Test + public void testPublishToResultHandlingServiceFails_returnsSuccess() throws Exception { + setUpExampleStoreService(); + setUpHttpFederatedProtocol(FA_CHECKIN_RESULT); + + // Mock publish to ResultHandlingService fails which is best effort and should not affect + // final result. + when(mMockResultCallbackHelper.callHandleResult(eq(TASK_NAME), any(), any())) + .thenReturn(Futures.immediateFuture(CallbackResult.FAIL)); + + FLRunnerResult result = mSpyWorker.startTrainingRun(JOB_ID).get(); + assertThat(result.getContributionResult()).isEqualTo(ContributionResult.SUCCESS); + + mSpyWorker.finish(result); + verify(mMockJobManager) + .onTrainingCompleted( + anyInt(), anyString(), any(), any(), eq(ContributionResult.SUCCESS)); + verify(mSpyWorker).unbindFromExampleStoreService(); + } + + @Test + public void testPublishToResultHandlingServiceThrowsException_returnsSuccess() + throws Exception { + setUpExampleStoreService(); + setUpHttpFederatedProtocol(FA_CHECKIN_RESULT); + + // Mock publish to ResultHandlingService throws exception which is best effort and should + // not affect final result. + when(mMockResultCallbackHelper.callHandleResult(eq(TASK_NAME), any(), any())) + .thenReturn( + immediateFailedFuture( + new ExecutionException( + "ResultHandlingService fail", + new IllegalStateException("can't bind to service")))); + + FLRunnerResult result = mSpyWorker.startTrainingRun(JOB_ID).get(); + assertThat(result.getContributionResult()).isEqualTo(ContributionResult.SUCCESS); + + mSpyWorker.finish(result); + verify(mMockJobManager) + .onTrainingCompleted( + anyInt(), anyString(), any(), any(), eq(ContributionResult.SUCCESS)); + verify(mSpyWorker).unbindFromExampleStoreService(); + verify(mMockResultCallbackHelper).callHandleResult(eq(TASK_NAME), any(), any()); + } + + @Test + public void testRunFAComputation_returnsSuccess() throws Exception { + setUpExampleStoreService(); + setUpHttpFederatedProtocol(FA_CHECKIN_RESULT); + + FLRunnerResult result = mSpyWorker.startTrainingRun(JOB_ID).get(); + assertThat(result.getContributionResult()).isEqualTo(ContributionResult.SUCCESS); + + mSpyWorker.finish(result); + verify(mMockJobManager).onTrainingCompleted(anyInt(), anyString(), any(), any(), any()); + } + + @Test + public void testBindToIsolatedTrainingServiceFail_returnsFail() throws Exception { + doReturn(immediateFuture(FL_CHECKIN_RESULT)) + .when(mMockHttpFederatedProtocol) + .issueCheckin(); + setUpExampleStoreService(); + + // Mock failure bind to IsolatedTrainingService. + doReturn(null).when(mSpyWorker).getIsolatedTrainingService(); + doNothing().when(mSpyWorker).unbindFromIsolatedTrainingService(); + + ExecutionException exception = + assertThrows( + ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get()); + assertThat(exception.getCause()).isInstanceOf(IllegalStateException.class); + assertThat(exception.getCause()) + .hasMessageThat() + .isEqualTo("Could not bind to IsolatedTrainingService"); + + mSpyWorker.finish(null, ContributionResult.FAIL, false); + verify(mMockJobManager) + .onTrainingCompleted( + anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL)); + } + + @Test + public void testRunFLComputation_emptyTfliteGraph_returns() throws Exception { + setUpExampleStoreService(); + TensorflowSpec tensorflowSpec = + TensorflowSpec.newBuilder() + .setDatasetTokenTensorName("dataset") + .addTargetNodeNames("target") + .build(); + ClientOnlyPlan clientOnlyPlan = + ClientOnlyPlan.newBuilder() + .setPhase( + ClientPhase.newBuilder().setTensorflowSpec(tensorflowSpec).build()) + .build(); + CheckinResult checkinResultNoTfliteGraph = + new CheckinResult( + createTempFile("input", ".ckp"), + clientOnlyPlan, + TaskAssignment.newBuilder().setTaskName(TASK_NAME).build()); + setUpHttpFederatedProtocol(checkinResultNoTfliteGraph); + + // Mock bind to IsolatedTrainingService. + doReturn(new FakeIsolatedTrainingService()).when(mSpyWorker).getIsolatedTrainingService(); + doNothing().when(mSpyWorker).unbindFromIsolatedTrainingService(); + + assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get()); + + mSpyWorker.finish(null, ContributionResult.FAIL, false); + verify(mMockJobManager) + .onTrainingCompleted( + anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL)); + } + + @Test + public void testRunFLComputation_returnsSuccess() throws Exception { + setUpExampleStoreService(); + setUpHttpFederatedProtocol(FL_CHECKIN_RESULT); + + // Mock bind to IsolatedTrainingService. + doReturn(new FakeIsolatedTrainingService()).when(mSpyWorker).getIsolatedTrainingService(); + doNothing().when(mSpyWorker).unbindFromIsolatedTrainingService(); + + FLRunnerResult result = mSpyWorker.startTrainingRun(JOB_ID).get(); + assertThat(result.getContributionResult()).isEqualTo(ContributionResult.SUCCESS); + + mSpyWorker.finish(result); + verify(mMockJobManager) + .onTrainingCompleted( + anyInt(), anyString(), any(), any(), eq(ContributionResult.SUCCESS)); + verify(mSpyWorker).unbindFromIsolatedTrainingService(); + verify(mSpyWorker).unbindFromExampleStoreService(); + } + + private void setUpExampleStoreService() { + TestExampleStoreService testExampleStoreService = new TestExampleStoreService(); + doReturn(testExampleStoreService).when(mSpyWorker).getExampleStoreService(anyString()); + doNothing().when(mSpyWorker).unbindFromExampleStoreService(); + } + + private void setUpHttpFederatedProtocol(CheckinResult checkinResult) { + doReturn(immediateFuture(checkinResult)).when(mMockHttpFederatedProtocol).issueCheckin(); + doReturn(FluentFuture.from(immediateFuture(null))) + .when(mMockHttpFederatedProtocol) + .reportResult(any()); + } + + private static class TestExampleStoreService extends IExampleStoreService.Stub { + @Override + public void startQuery(Bundle params, IExampleStoreCallback callback) + throws RemoteException { + callback.onStartQuerySuccess( + new FakeExampleStoreIterator(ImmutableList.of(EXAMPLE_PROTO_1.toByteArray()))); + } + } + + private static class TestInjector extends FederatedComputeWorker.Injector { + @Override + ExampleConsumptionRecorder getExampleConsumptionRecorder() { + return new ExampleConsumptionRecorder() { + @Override + public synchronized ArrayList<ExampleConsumption> finishRecordingAndGet() { + ArrayList<ExampleConsumption> exampleList = new ArrayList<>(); + exampleList.add(EXAMPLE_CONSUMPTION_1); + return exampleList; + } + }; + } + + @Override + ListeningExecutorService getBgExecutor() { + return MoreExecutors.newDirectExecutorService(); + } + } + + private static final class FakeIsolatedTrainingService extends IIsolatedTrainingService.Stub { + @Override + public void runFlTraining(Bundle params, ITrainingResultCallback callback) + throws RemoteException { + Bundle bundle = new Bundle(); + bundle.putByteArray( + Constants.EXTRA_FL_RUNNER_RESULT, FL_RUNNER_SUCCESS_RESULT.toByteArray()); + ArrayList<ExampleConsumption> exampleConsumptionList = new ArrayList<>(); + exampleConsumptionList.add(EXAMPLE_CONSUMPTION_1); + bundle.putParcelableArrayList( + ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, exampleConsumptionList); + callback.onResult(bundle); + } + + @Override + public void cancelTraining(long runId) {} + } } diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedJobServiceTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedJobServiceTest.java index 37eb6809..b9e6086a 100644 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedJobServiceTest.java +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedJobServiceTest.java @@ -16,12 +16,16 @@ package com.android.federatedcompute.services.training; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -32,20 +36,39 @@ import com.android.dx.mockito.inline.extended.ExtendedMockito; import com.android.federatedcompute.services.common.FederatedComputeExecutors; import com.android.federatedcompute.services.common.PhFlagsTestUtil; +import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.MoreExecutors; +import com.google.intelligence.fcp.client.FLRunnerResult; +import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult; +import com.google.intelligence.fcp.client.RetryInfo; +import com.google.intelligence.fcp.client.engine.TaskRetry; +import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.Mock; import org.mockito.MockitoSession; import org.mockito.quality.Strictness; @RunWith(JUnit4.class) public final class FederatedJobServiceTest { - private static final long WAIT_IN_MILLIS = 1_000L; + private static final TaskRetry TASK_RETRY = + TaskRetry.newBuilder().setRetryToken("foobar").build(); + private static final FLRunnerResult FL_RUNNER_SUCCESS_RESULT = + FLRunnerResult.newBuilder() + .setContributionResult(ContributionResult.SUCCESS) + .setRetryInfo( + RetryInfo.newBuilder() + .setRetryToken(TASK_RETRY.getRetryToken()) + .build()) + .build(); private FederatedJobService mSpyService; + @Mock private FederatedComputeWorker mMockWorker; + + private MockitoSession mStaticMockSession; @Before public void setUp() throws Exception { @@ -54,60 +77,55 @@ public final class FederatedJobServiceTest { mSpyService = spy(new FederatedJobService()); doNothing().when(mSpyService).jobFinished(any(), anyBoolean()); - } - - @Test - public void testOnStartJob() throws Exception { - MockitoSession session = + doReturn(mSpyService).when(mSpyService).getApplicationContext(); + mStaticMockSession = ExtendedMockito.mockitoSession() .spyStatic(FederatedComputeExecutors.class) + .spyStatic(FederatedComputeWorker.class) + .initMocks(this) .strictness(Strictness.LENIENT) .startMocking(); - try { - ExtendedMockito.doReturn(MoreExecutors.newDirectExecutorService()) - .when(FederatedComputeExecutors::getBackgroundExecutor); + ExtendedMockito.doReturn(MoreExecutors.newDirectExecutorService()) + .when(() -> FederatedComputeExecutors.getBackgroundExecutor()); + ExtendedMockito.doReturn(mMockWorker).when(() -> FederatedComputeWorker.getInstance(any())); + } - boolean result = mSpyService.onStartJob(mock(JobParameters.class)); + @After + public void teardown() { + if (mStaticMockSession != null) { + mStaticMockSession.finishMocking(); + } + } + + @Test + public void testOnStartJob() throws Exception { + doReturn(Futures.immediateFuture(FL_RUNNER_SUCCESS_RESULT)) + .when(mMockWorker) + .startTrainingRun(anyInt()); + doNothing().when(mMockWorker).finish(eq(FL_RUNNER_SUCCESS_RESULT)); - assertTrue(result); - Thread.sleep(WAIT_IN_MILLIS); + boolean result = mSpyService.onStartJob(mock(JobParameters.class)); - verify(mSpyService, times(1)).jobFinished(any(), anyBoolean()); - } finally { - session.finishMocking(); - } + assertTrue(result); + verify(mSpyService, times(1)).jobFinished(any(), anyBoolean()); } @Test public void testOnStartJobKillSwitch() throws Exception { PhFlagsTestUtil.enableGlobalKillSwitch(); - MockitoSession session = - ExtendedMockito.mockitoSession() - .spyStatic(FederatedComputeExecutors.class) - .strictness(Strictness.LENIENT) - .startMocking(); - try { - ExtendedMockito.doReturn(MoreExecutors.newDirectExecutorService()) - .when(FederatedComputeExecutors::getBackgroundExecutor); - - boolean result = mSpyService.onStartJob(mock(JobParameters.class)); - assertTrue(result); + boolean result = mSpyService.onStartJob(mock(JobParameters.class)); - verify(mSpyService, times(1)).jobFinished(any(), eq(false)); - } finally { - session.finishMocking(); - } + assertTrue(result); + verify(mMockWorker, never()).startTrainingRun(anyInt()); + verify(mSpyService, times(1)).jobFinished(any(), eq(false)); } @Test public void testOnStopJob() { - MockitoSession session = - ExtendedMockito.mockitoSession().strictness(Strictness.LENIENT).startMocking(); - try { - assertTrue(mSpyService.onStopJob(mock(JobParameters.class))); - } finally { - session.finishMocking(); - } + doNothing().when(mMockWorker).finish(any(), eq(ContributionResult.FAIL), eq(true)); + + // Do not reschedule in JobService. FederatedComputeJobManager will handle it. + assertFalse(mSpyService.onStopJob(mock(JobParameters.class))); } } diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/IsolatedTrainingServiceImplTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/IsolatedTrainingServiceImplTest.java index 2ce331a3..d297e4f3 100644 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/IsolatedTrainingServiceImplTest.java +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/IsolatedTrainingServiceImplTest.java @@ -20,24 +20,17 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.when; -import android.content.Context; -import android.federatedcompute.aidl.IFederatedComputeCallback; -import android.federatedcompute.aidl.IResultHandlingService; -import android.federatedcompute.common.ExampleConsumption; -import android.federatedcompute.common.TrainingOptions; +import android.federatedcompute.common.ClientConstants; import android.os.Bundle; import android.os.ParcelFileDescriptor; -import android.os.RemoteException; - -import androidx.test.core.app.ApplicationProvider; import com.android.dx.mockito.inline.extended.ExtendedMockito; import com.android.federatedcompute.services.common.Constants; import com.android.federatedcompute.services.common.FederatedComputeExecutors; +import com.android.federatedcompute.services.common.FileUtils; import com.android.federatedcompute.services.testutils.FakeExampleStoreIterator; import com.android.federatedcompute.services.testutils.TrainingTestUtil; import com.android.federatedcompute.services.training.aidl.ITrainingResultCallback; @@ -62,26 +55,17 @@ import org.mockito.MockitoSession; import org.mockito.quality.Strictness; import java.io.File; -import java.util.List; import java.util.concurrent.CountDownLatch; @RunWith(JUnit4.class) public final class IsolatedTrainingServiceImplTest { private static final String POPULATION_NAME = "population_name"; - private static final String CLIENT_PACKAGE_NAME = "de.myselph"; - private static final long RUN_ID = 12345L; - private static final String SESSION_NAME = "session_name"; private static final String TASK_NAME = "task_name"; - private static final String INPUT_CHECKPOINT_FD = "fd:///5"; - private static final String OUTPUT_CHECKPOINT_FD = "fd:///6"; + private static final long RUN_ID = 12345L; private static final FakeExampleStoreIterator FAKE_EXAMPLE_STORE_ITERATOR = new FakeExampleStoreIterator(ImmutableList.of()); - private static final FakeResultHandlingService FAKE_RESULT_HANDLING_SERVICE = - new FakeResultHandlingService(); private static final ExampleSelector EXAMPLE_SELECTOR = ExampleSelector.newBuilder().setCollectionUri("collection_uri").build(); - - private final CountDownLatch mLatch = new CountDownLatch(1); private static final TaskRetry TASK_RETRY = TaskRetry.newBuilder().setRetryToken("foobar").build(); private static final FLRunnerResult FL_RUNNER_SUCCESS_RESULT = @@ -92,7 +76,6 @@ public final class IsolatedTrainingServiceImplTest { .setRetryToken(TASK_RETRY.getRetryToken()) .build()) .build(); - private static final FLRunnerResult FL_RUNNER_FAIL_RESULT = FLRunnerResult.newBuilder() .setContributionResult(ContributionResult.FAIL) @@ -101,11 +84,9 @@ public final class IsolatedTrainingServiceImplTest { .setRetryToken(TASK_RETRY.getRetryToken()) .build()) .build(); - - private final Context mContext = ApplicationProvider.getApplicationContext(); + private final CountDownLatch mLatch = new CountDownLatch(1); private IsolatedTrainingServiceImpl mIsolatedTrainingService; private Bundle mCallbackResult; - private int mCallbackErrorCode; @Mock private ComputationRunner mComputationRunner; private MockitoSession mStaticMockSession; private ParcelFileDescriptor mInputCheckpointFd; @@ -139,16 +120,7 @@ public final class IsolatedTrainingServiceImplTest { @Test public void runFlTrainingSuccess() throws Exception { when(mComputationRunner.runTaskWithNativeRunner( - anyInt(), - anyString(), - any(), - any(), - any(), - any(), - any(), - any(), - any(), - any())) + anyString(), anyString(), any(), any(), any(), any(), any(), any(), any())) .thenReturn(FL_RUNNER_SUCCESS_RESULT); Bundle bundle = buildInputBundle(); @@ -162,16 +134,7 @@ public final class IsolatedTrainingServiceImplTest { @Test public void runFlTrainingFailure() throws Exception { when(mComputationRunner.runTaskWithNativeRunner( - anyInt(), - anyString(), - any(), - any(), - any(), - any(), - any(), - any(), - any(), - any())) + anyString(), anyString(), any(), any(), any(), any(), any(), any(), any())) .thenReturn(FL_RUNNER_FAIL_RESULT); Bundle bundle = buildInputBundle(); @@ -183,15 +146,13 @@ public final class IsolatedTrainingServiceImplTest { } @Test - public void runFlTrainingMissingExampleSelector_returnsFailure() throws Exception { + public void runFlTrainingMissingExampleSelector_returnsFailure() { Bundle bundle = new Bundle(); - bundle.putString(Constants.EXTRA_POPULATION_NAME, POPULATION_NAME); + bundle.putString(ClientConstants.EXTRA_POPULATION_NAME, POPULATION_NAME); bundle.putParcelable(Constants.EXTRA_INPUT_CHECKPOINT_FD, mInputCheckpointFd); bundle.putParcelable(Constants.EXTRA_OUTPUT_CHECKPOINT_FD, mOutputCheckpointFd); bundle.putBinder( Constants.EXTRA_EXAMPLE_STORE_ITERATOR_BINDER, FAKE_EXAMPLE_STORE_ITERATOR); - bundle.putBinder( - Constants.EXTRA_RESULT_HANDLING_SERVICE_BINDER, FAKE_RESULT_HANDLING_SERVICE); assertThrows( NullPointerException.class, @@ -199,15 +160,13 @@ public final class IsolatedTrainingServiceImplTest { } @Test - public void runFlTrainingInvalidExampleSelector_returnsFailure() throws Exception { + public void runFlTrainingInvalidExampleSelector_returnsFailure() { Bundle bundle = new Bundle(); - bundle.putString(Constants.EXTRA_POPULATION_NAME, POPULATION_NAME); + bundle.putString(ClientConstants.EXTRA_POPULATION_NAME, POPULATION_NAME); bundle.putParcelable(Constants.EXTRA_INPUT_CHECKPOINT_FD, mInputCheckpointFd); bundle.putParcelable(Constants.EXTRA_OUTPUT_CHECKPOINT_FD, mOutputCheckpointFd); bundle.putBinder( Constants.EXTRA_EXAMPLE_STORE_ITERATOR_BINDER, FAKE_EXAMPLE_STORE_ITERATOR); - bundle.putBinder( - Constants.EXTRA_RESULT_HANDLING_SERVICE_BINDER, FAKE_RESULT_HANDLING_SERVICE); bundle.putByteArray(Constants.EXTRA_EXAMPLE_SELECTOR, "exampleselector".getBytes()); @@ -219,13 +178,11 @@ public final class IsolatedTrainingServiceImplTest { @Test public void runFlTrainingNullPlan_returnsFailure() throws Exception { Bundle bundle = new Bundle(); - bundle.putString(Constants.EXTRA_POPULATION_NAME, POPULATION_NAME); + bundle.putString(ClientConstants.EXTRA_POPULATION_NAME, POPULATION_NAME); bundle.putParcelable(Constants.EXTRA_INPUT_CHECKPOINT_FD, mInputCheckpointFd); bundle.putParcelable(Constants.EXTRA_OUTPUT_CHECKPOINT_FD, mOutputCheckpointFd); bundle.putBinder( Constants.EXTRA_EXAMPLE_STORE_ITERATOR_BINDER, FAKE_EXAMPLE_STORE_ITERATOR); - bundle.putBinder( - Constants.EXTRA_RESULT_HANDLING_SERVICE_BINDER, FAKE_RESULT_HANDLING_SERVICE); bundle.putByteArray(Constants.EXTRA_EXAMPLE_SELECTOR, EXAMPLE_SELECTOR.toByteArray()); assertThrows( @@ -243,16 +200,21 @@ public final class IsolatedTrainingServiceImplTest { private Bundle buildInputBundle() throws Exception { Bundle bundle = new Bundle(); - bundle.putString(Constants.EXTRA_POPULATION_NAME, POPULATION_NAME); + bundle.putString(ClientConstants.EXTRA_POPULATION_NAME, POPULATION_NAME); + bundle.putString(ClientConstants.EXTRA_TASK_NAME, TASK_NAME); bundle.putParcelable(Constants.EXTRA_INPUT_CHECKPOINT_FD, mInputCheckpointFd); bundle.putParcelable(Constants.EXTRA_OUTPUT_CHECKPOINT_FD, mOutputCheckpointFd); bundle.putByteArray(Constants.EXTRA_EXAMPLE_SELECTOR, EXAMPLE_SELECTOR.toByteArray()); bundle.putBinder( Constants.EXTRA_EXAMPLE_STORE_ITERATOR_BINDER, FAKE_EXAMPLE_STORE_ITERATOR); - bundle.putBinder( - Constants.EXTRA_RESULT_HANDLING_SERVICE_BINDER, FAKE_RESULT_HANDLING_SERVICE); ClientOnlyPlan clientOnlyPlan = TrainingTestUtil.createFederatedAnalyticClientPlan(); - bundle.putByteArray(Constants.EXTRA_CLIENT_ONLY_PLAN, clientOnlyPlan.toByteArray()); + String clientPlanFile = + FileUtils.createTempFile(Constants.EXTRA_CLIENT_ONLY_PLAN_FD, ".pb"); + FileUtils.writeToFile(clientPlanFile, clientOnlyPlan.toByteArray()); + bundle.putParcelable( + Constants.EXTRA_CLIENT_ONLY_PLAN_FD, + FileUtils.createTempFileDescriptor( + clientPlanFile, ParcelFileDescriptor.MODE_READ_ONLY)); return bundle; } @@ -274,16 +236,4 @@ public final class IsolatedTrainingServiceImplTest { mLatch.countDown(); } } - - private static final class FakeResultHandlingService extends IResultHandlingService.Stub { - @Override - public void handleResult( - TrainingOptions trainingOptions, - boolean success, - List<ExampleConsumption> exampleConsumptionList, - IFederatedComputeCallback callback) - throws RemoteException { - callback.onSuccess(); - } - } } diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/ResultCallbackHelperTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/ResultCallbackHelperTest.java index 75ba38a8..0f9affa2 100644 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/ResultCallbackHelperTest.java +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/ResultCallbackHelperTest.java @@ -17,153 +17,153 @@ package com.android.federatedcompute.services.training; import static android.federatedcompute.common.ClientConstants.STATUS_INTERNAL_ERROR; +import static android.federatedcompute.common.ClientConstants.STATUS_SUCCESS; import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doReturn; + +import android.content.Context; import android.federatedcompute.aidl.IFederatedComputeCallback; import android.federatedcompute.aidl.IResultHandlingService; +import android.federatedcompute.common.ClientConstants; import android.federatedcompute.common.ExampleConsumption; -import android.federatedcompute.common.TrainingOptions; +import android.os.Bundle; import android.os.RemoteException; +import androidx.test.core.app.ApplicationProvider; + +import com.android.federatedcompute.services.data.FederatedTrainingTask; import com.android.federatedcompute.services.data.fbs.SchedulingMode; +import com.android.federatedcompute.services.data.fbs.SchedulingReason; +import com.android.federatedcompute.services.data.fbs.TrainingConstraints; import com.android.federatedcompute.services.data.fbs.TrainingIntervalOptions; import com.android.federatedcompute.services.training.ResultCallbackHelper.CallbackResult; +import com.android.federatedcompute.services.training.util.ComputationResult; -import com.google.common.collect.ImmutableList; import com.google.flatbuffers.FlatBufferBuilder; +import com.google.intelligence.fcp.client.FLRunnerResult; +import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.Mockito; -import java.util.List; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; +import java.util.ArrayList; @RunWith(JUnit4.class) public final class ResultCallbackHelperTest { + private static final String PACKAGE_NAME = "app_package_name"; private static final byte[] SELECTION_CRITERIA = new byte[] {10, 0, 1}; + private static final int SCHEDULING_REASON = SchedulingReason.SCHEDULING_REASON_NEW_TASK; private static final String POPULATION_NAME = "population_name"; - private static final int JOB_ID = 123; + private static final String TASK_NAME = "task_name"; private static final byte[] INTERVAL_OPTIONS = createDefaultTrainingIntervalOptions(); - private final CountDownLatch mLatch = new CountDownLatch(1); - private static final ImmutableList<ExampleConsumption> EXAMPLE_CONSUMPTIONS = - ImmutableList.of( - new ExampleConsumption.Builder() - .setCollectionName("collection") - .setExampleCount(100) - .setSelectionCriteria(SELECTION_CRITERIA) - .build()); + private static final byte[] TRAINING_CONSTRAINTS = createDefaultTrainingConstraints(); + private static final FLRunnerResult FL_RUNNER_SUCCESS_RESULT = + FLRunnerResult.newBuilder().setContributionResult(ContributionResult.SUCCESS).build(); + + private static final FederatedTrainingTask TRAINING_TASK = + FederatedTrainingTask.builder() + .appPackageName(PACKAGE_NAME) + .jobId(123) + .populationName(POPULATION_NAME) + .intervalOptions(INTERVAL_OPTIONS) + .constraints(TRAINING_CONSTRAINTS) + .serverAddress("server_address") + .creationTime(123L) + .lastScheduledTime(123L) + .earliestNextRunTime(123L) + .schedulingReason(SCHEDULING_REASON) + .build(); + private static final ArrayList<ExampleConsumption> EXAMPLE_CONSUMPTIONS = + getExampleConsumptions(); + private ComputationResult mComputationResult; + private ResultCallbackHelper mHelper; + + @Before + public void setUp() { + Context context = ApplicationProvider.getApplicationContext(); + mComputationResult = + new ComputationResult("output", FL_RUNNER_SUCCESS_RESULT, EXAMPLE_CONSUMPTIONS); + mHelper = Mockito.spy(new ResultCallbackHelper(context)); + doNothing().when(mHelper).unbindFromResultHandlingService(); + } @Test public void testHandleResult_success() throws Exception { - ResultCallbackHelper helper = - new ResultCallbackHelper(EXAMPLE_CONSUMPTIONS, new TestResultHandlingService()); + doReturn(new TestResultHandlingService()) + .when(mHelper) + .getResultHandlingService(eq(PACKAGE_NAME)); CallbackResult result = - helper.callHandleResult(JOB_ID, POPULATION_NAME, INTERVAL_OPTIONS, true); + mHelper.callHandleResult(TASK_NAME, TRAINING_TASK, mComputationResult).get(); assertThat(result).isEqualTo(CallbackResult.SUCCESS); } @Test public void testHandleResult_remoteException() throws Exception { - ResultCallbackHelper helper = - new ResultCallbackHelper( - EXAMPLE_CONSUMPTIONS, new ResultHandlingServiceWithRemoteException()); + doReturn(new ResultHandlingServiceWithRemoteException()) + .when(mHelper) + .getResultHandlingService(eq(PACKAGE_NAME)); CallbackResult result = - helper.callHandleResult(JOB_ID, POPULATION_NAME, INTERVAL_OPTIONS, true); + mHelper.callHandleResult(TASK_NAME, TRAINING_TASK, mComputationResult).get(); assertThat(result).isEqualTo(CallbackResult.FAIL); } - private static final class ResultHandlingServiceWithRemoteException - extends IResultHandlingService.Stub { - @Override - public void handleResult( - TrainingOptions trainingOptions, - boolean success, - List<ExampleConsumption> exampleConsumptionList, - IFederatedComputeCallback callback) - throws RemoteException { - throw new RemoteException("expected remote exception"); - } - } - @Test - public void testHandleResult_timeoutException() throws Exception { - ResultCallbackHelper helper = - new ResultCallbackHelper( - EXAMPLE_CONSUMPTIONS, new ResultHandlingServiceWithTimeoutException(), 20); + public void testHandleResult_interruptedException() throws Exception { + doReturn(new ResultHandlingServiceWithInterruptedException()) + .when(mHelper) + .getResultHandlingService(eq(PACKAGE_NAME)); CallbackResult result = - helper.callHandleResult(JOB_ID, POPULATION_NAME, INTERVAL_OPTIONS, true); + mHelper.callHandleResult(TASK_NAME, TRAINING_TASK, mComputationResult).get(); assertThat(result).isEqualTo(CallbackResult.FAIL); } - class ResultHandlingServiceWithTimeoutException extends IResultHandlingService.Stub { - - @Override - public void handleResult( - TrainingOptions trainingOptions, - boolean success, - List<ExampleConsumption> exampleConsumptionList, - IFederatedComputeCallback callback) - throws RemoteException { - try { - mLatch.await(2000, TimeUnit.MILLISECONDS); - } catch (InterruptedException e) { - } - } - } - @Test - public void testHandleResult_interruptedException() throws Exception { - ResultCallbackHelper helper = - new ResultCallbackHelper( - EXAMPLE_CONSUMPTIONS, new ResultHandlingServiceWithInterruptedException()); + public void testHandleResult_failed() throws Exception { + doReturn(new ResultHandlingServiceWithFail()) + .when(mHelper) + .getResultHandlingService(eq(PACKAGE_NAME)); CallbackResult result = - helper.callHandleResult(JOB_ID, POPULATION_NAME, INTERVAL_OPTIONS, true); + mHelper.callHandleResult(TASK_NAME, TRAINING_TASK, mComputationResult).get(); assertThat(result).isEqualTo(CallbackResult.FAIL); } - private static class ResultHandlingServiceWithInterruptedException + private static final class ResultHandlingServiceWithRemoteException extends IResultHandlingService.Stub { - @Override - public void handleResult( - TrainingOptions trainingOptions, - boolean success, - List<ExampleConsumption> exampleConsumptionList, - IFederatedComputeCallback callback) + public void handleResult(Bundle input, IFederatedComputeCallback callback) throws RemoteException { - Thread.currentThread().interrupt(); + throw new RemoteException("expected remote exception"); } } - @Test - public void testHandleResult_failed() throws Exception { - ResultCallbackHelper helper = - new ResultCallbackHelper(EXAMPLE_CONSUMPTIONS, new ResultHandlingServiceWithFail()); - - CallbackResult result = - helper.callHandleResult(JOB_ID, POPULATION_NAME, INTERVAL_OPTIONS, true); + private static class ResultHandlingServiceWithInterruptedException + extends IResultHandlingService.Stub { - assertThat(result).isEqualTo(CallbackResult.FAIL); + @Override + public void handleResult(Bundle input, IFederatedComputeCallback callback) + throws RemoteException { + Thread.currentThread().interrupt(); + } } private static final class ResultHandlingServiceWithFail extends IResultHandlingService.Stub { @Override - public void handleResult( - TrainingOptions trainingOptions, - boolean success, - List<ExampleConsumption> exampleConsumptionList, - IFederatedComputeCallback callback) + public void handleResult(Bundle input, IFederatedComputeCallback callback) throws RemoteException { callback.onFailure(STATUS_INTERNAL_ERROR); } @@ -171,18 +171,30 @@ public final class ResultCallbackHelperTest { private static final class TestResultHandlingService extends IResultHandlingService.Stub { @Override - public void handleResult( - TrainingOptions trainingOptions, - boolean success, - List<ExampleConsumption> exampleConsumptionList, - IFederatedComputeCallback callback) + public void handleResult(Bundle input, IFederatedComputeCallback callback) throws RemoteException { - assertThat(success).isTrue(); - assertThat(exampleConsumptionList).containsExactlyElementsIn(EXAMPLE_CONSUMPTIONS); + assertThat(input.getInt(ClientConstants.EXTRA_COMPUTATION_RESULT)) + .isEqualTo(STATUS_SUCCESS); + assertThat( + input.getParcelableArrayList( + ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, + ExampleConsumption.class)) + .containsExactlyElementsIn(EXAMPLE_CONSUMPTIONS); callback.onSuccess(); } } + private static ArrayList<ExampleConsumption> getExampleConsumptions() { + ArrayList<ExampleConsumption> exampleList = new ArrayList<>(); + exampleList.add( + new ExampleConsumption.Builder() + .setTaskName("taskName") + .setExampleCount(100) + .setSelectionCriteria(SELECTION_CRITERIA) + .build()); + return exampleList; + } + private static byte[] createDefaultTrainingIntervalOptions() { FlatBufferBuilder builder = new FlatBufferBuilder(); builder.finish( @@ -190,4 +202,10 @@ public final class ResultCallbackHelperTest { builder, SchedulingMode.ONE_TIME, 0)); return builder.sizedByteArray(); } + + private static byte[] createDefaultTrainingConstraints() { + FlatBufferBuilder builder = new FlatBufferBuilder(); + builder.finish(TrainingConstraints.createTrainingConstraints(builder, true, true, true)); + return builder.sizedByteArray(); + } } diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/ResultHandlingServiceProviderImplTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/ResultHandlingServiceProviderImplTest.java deleted file mode 100644 index d5c7a08d..00000000 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/ResultHandlingServiceProviderImplTest.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright (C) 2023 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.android.federatedcompute.services.training; - -import static android.federatedcompute.common.ClientConstants.RESULT_HANDLING_SERVICE_ACTION; - -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; -import static org.mockito.Mockito.when; - -import android.content.Context; -import android.content.Intent; -import android.net.Uri; - -import androidx.test.core.app.ApplicationProvider; -import androidx.test.ext.junit.runners.AndroidJUnit4; - -import com.android.federatedcompute.services.common.Flags; - -import org.junit.After; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; - -@RunWith(AndroidJUnit4.class) -public final class ResultHandlingServiceProviderImplTest { - private ResultHandlingServiceProviderImpl mResultHandlingServiceProvider; - private Context mContext = ApplicationProvider.getApplicationContext(); - private static final long TIMEOUT_SECS = 5L; - - private Intent mIntent; - @Mock private Flags mMockFlags; - - @Before - public void setUp() { - MockitoAnnotations.initMocks(this); - String packageName = mContext.getPackageName(); - mResultHandlingServiceProvider = - new ResultHandlingServiceProviderImpl(mContext, mMockFlags); - mIntent = new Intent(); - mIntent.setAction(RESULT_HANDLING_SERVICE_ACTION).setPackage(packageName); - mIntent.setData(new Uri.Builder().scheme("app").authority(packageName).build()); - when(mMockFlags.getResultHandlingBindServiceTimeoutSecs()).thenReturn(TIMEOUT_SECS); - } - - @After - public void cleanup() { - mResultHandlingServiceProvider.unbindService(); - } - - @Test - public void testBindService() { - assertTrue(mResultHandlingServiceProvider.bindService(mIntent)); - } - - @Test - public void testBindService_invalidIntent() { - Intent intent = new Intent(); - intent.setAction(RESULT_HANDLING_SERVICE_ACTION).setPackage("android.foo"); - intent.setData(new Uri.Builder().scheme("app").build()); - assertFalse(mResultHandlingServiceProvider.bindService(intent)); - } - - @Test - public void testGetResultHandlingService() throws Exception { - mResultHandlingServiceProvider.bindService(mIntent); - - assertNotNull(mResultHandlingServiceProvider.getResultHandlingService()); - } - - @Test - public void testUnbindService() throws Exception { - assertTrue(mResultHandlingServiceProvider.bindService(mIntent)); - - mResultHandlingServiceProvider.unbindService(); - } - - @Test - public void testUnbindService_serviceNonExist() { - mResultHandlingServiceProvider.unbindService(); - } -} diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/SampleResultHandlingService.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/SampleResultHandlingService.java index c31de611..65df1e3d 100644 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/SampleResultHandlingService.java +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/SampleResultHandlingService.java @@ -20,23 +20,26 @@ import static android.federatedcompute.common.ClientConstants.STATUS_INTERNAL_ER import static android.federatedcompute.common.ClientConstants.STATUS_SUCCESS; import android.federatedcompute.ResultHandlingService; +import android.federatedcompute.common.ClientConstants; import android.federatedcompute.common.ExampleConsumption; -import android.federatedcompute.common.TrainingOptions; +import android.os.Bundle; import android.util.Log; -import java.util.List; +import java.util.ArrayList; import java.util.function.Consumer; /** A simple implementation of {@link ResultHandlingService}. */ public class SampleResultHandlingService extends ResultHandlingService { private static final String TAG = "SampleResultHandlingService"; - public void handleResult( - TrainingOptions trainingOptions, - boolean success, - List<ExampleConsumption> exampleConsumptionList, - Consumer<Integer> callback) { - Log.i(TAG, "Handling result for population: " + trainingOptions.getPopulationName()); + public void handleResult(Bundle params, Consumer<Integer> callback) { + Log.i( + TAG, + "Handling result for population: " + + params.getString(ClientConstants.EXTRA_POPULATION_NAME)); + ArrayList<ExampleConsumption> exampleConsumptionList = + params.getParcelableArrayList( + ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, ExampleConsumption.class); if (exampleConsumptionList.isEmpty()) { callback.accept(STATUS_INTERNAL_ERROR); return; diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/jni/FlRunnerWrapperTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/jni/FlRunnerWrapperTest.java new file mode 100644 index 00000000..ea007369 --- /dev/null +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/jni/FlRunnerWrapperTest.java @@ -0,0 +1,300 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.services.training.jni; + +import static com.google.common.truth.Truth.assertThat; + +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.when; + +import static java.nio.file.StandardCopyOption.REPLACE_EXISTING; + +import android.content.Context; +import android.net.Uri; +import android.os.ParcelFileDescriptor; + +import androidx.test.core.app.ApplicationProvider; + +import com.android.federatedcompute.services.examplestore.ExampleIterator; +import com.android.federatedcompute.services.testutils.ResettingExampleIterator; +import com.android.federatedcompute.services.testutils.TrainingTestUtil; +import com.android.federatedcompute.services.training.util.ListenableSupplier; + +import com.google.intelligence.fcp.client.ExampleQueryResult; +import com.google.intelligence.fcp.client.ExampleQueryResult.VectorData; +import com.google.intelligence.fcp.client.ExampleQueryResult.VectorData.Int64Values; +import com.google.intelligence.fcp.client.ExampleQueryResult.VectorData.StringValues; +import com.google.intelligence.fcp.client.ExampleQueryResult.VectorData.Values; +import com.google.intelligence.fcp.client.FLRunnerResult; +import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult; +import com.google.internal.federated.plan.AggregationConfig; +import com.google.internal.federated.plan.ClientOnlyPlan; +import com.google.internal.federated.plan.ClientPhase; +import com.google.internal.federated.plan.Dataset; +import com.google.internal.federated.plan.Dataset.ClientDataset; +import com.google.internal.federated.plan.ExampleQuerySpec; +import com.google.internal.federated.plan.ExampleQuerySpec.ExampleQuery; +import com.google.internal.federated.plan.ExampleQuerySpec.OutputVectorSpec; +import com.google.internal.federated.plan.ExampleQuerySpec.OutputVectorSpec.DataType; +import com.google.internal.federated.plan.ExampleSelector; +import com.google.internal.federated.plan.FederatedExampleQueryIORouter; +import com.google.internal.federated.plan.TFV1CheckpointAggregation; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.File; +import java.io.InputStream; +import java.nio.file.Files; + +@RunWith(JUnit4.class) +public final class FlRunnerWrapperTest { + private static final String SESSION_NAME = "session_name"; + private static final String TASK_NAME = "task_name"; + private static final String POPULATION_NAME = "population_name"; + private static final String STRING_VECTOR_NAME = "vector1"; + private static final String INT_VECTOR_NAME = "vector2"; + private static final String STRING_TENSOR_NAME = "tensor1"; + private static final String INT_TENSOR_NAME = "tensor2"; + private static final String TEST_URI_PREFIX = + "android.resource://com.android.ondevicepersonalization.federatedcomputetests/raw/"; + + @Mock ListenableSupplier<Boolean> mInterruptionFlag; + @Mock ExampleIterator mExampleIterator; + + private FlRunnerWrapper mFlRunnerWrapper; + + @Before + public void doBeforeEachTest() throws Exception { + MockitoAnnotations.initMocks(this); + + when(mInterruptionFlag.get()).thenReturn(false); + } + + @After + public void tearDown() throws Exception { + mFlRunnerWrapper.close(); + } + + @Test + public void testRunInvalidPlan_returnFail() throws Exception { + File inputCheckpointFile = File.createTempFile("input", ".ckp"); + File outputCheckpointFile = File.createTempFile("output", ".ckp"); + ClientOnlyPlan invalidClientOnlyPlan = ClientOnlyPlan.getDefaultInstance(); + when(mExampleIterator.hasNext()).thenReturn(false); + + mFlRunnerWrapper = + new FlRunnerWrapper(mInterruptionFlag, POPULATION_NAME, mExampleIterator); + + FLRunnerResult result = + mFlRunnerWrapper.run( + TASK_NAME, + POPULATION_NAME, + invalidClientOnlyPlan, + inputCheckpointFile.getAbsolutePath(), + outputCheckpointFile.getAbsolutePath()); + + assertThat(result.getContributionResult()).isEqualTo(ContributionResult.FAIL); + } + + @Test + public void testRunFederatedAnalytics_returnSuccess() throws Exception { + ClientOnlyPlan clientOnlyPlan = TrainingTestUtil.createFederatedAnalyticClientPlan(); + Values stringValues = + Values.newBuilder() + .setStringValues( + StringValues.newBuilder() + .addValue("value1") + .addValue("value2") + .build()) + .build(); + Values intValues = + Values.newBuilder() + .setInt64Values(Int64Values.newBuilder().addValue(42).addValue(24).build()) + .build(); + ExampleQueryResult queryResult = + ExampleQueryResult.newBuilder() + .setVectorData( + VectorData.newBuilder() + .putVectors(STRING_VECTOR_NAME, stringValues) + .putVectors(INT_VECTOR_NAME, intValues)) + .build(); + ClientDataset clientDataset = + ClientDataset.newBuilder() + .setClientId("clientId") + .addExample(queryResult.toByteString()) + .build(); + + Dataset dataset = Dataset.newBuilder().addClientData(clientDataset).build(); + File inputCheckpointFile = File.createTempFile("input", ".ckp"); + File outputCheckpointFile = File.createTempFile("output", ".ckp"); + + ResettingExampleIterator resettingExampleIterator = + new ResettingExampleIterator(dataset.getClientDataCount(), dataset); + + mFlRunnerWrapper = + new FlRunnerWrapper(mInterruptionFlag, POPULATION_NAME, resettingExampleIterator); + + FLRunnerResult result = + mFlRunnerWrapper.run( + TASK_NAME, + POPULATION_NAME, + clientOnlyPlan, + inputCheckpointFile.getAbsolutePath(), + outputCheckpointFile.getAbsolutePath()); + + assertThat(result.getContributionResult()).isEqualTo(ContributionResult.SUCCESS); + byte[] content = Files.readAllBytes(outputCheckpointFile.toPath()); + assertTrue(content.length > 0); + } + + @Test + public void testInvalidCollection_returnsFail() throws Exception { + OutputVectorSpec stringVectorSpec = + OutputVectorSpec.newBuilder() + .setVectorName(STRING_VECTOR_NAME) + .setDataType(DataType.STRING) + .build(); + OutputVectorSpec intVectorSpec = + OutputVectorSpec.newBuilder() + .setVectorName(INT_VECTOR_NAME) + .setDataType(DataType.INT64) + .build(); + + ExampleQuery exampleQuery = + ExampleQuery.newBuilder() + .setExampleSelector( + ExampleSelector.newBuilder() + .setCollectionUri("app://com.foo.bar/inapp/collection1#") + .build()) + .putOutputVectorSpecs(STRING_TENSOR_NAME, stringVectorSpec) + .putOutputVectorSpecs(INT_TENSOR_NAME, intVectorSpec) + .build(); + AggregationConfig aggregationConfig = + AggregationConfig.newBuilder() + .setTfV1CheckpointAggregation( + TFV1CheckpointAggregation.getDefaultInstance()) + .build(); + FederatedExampleQueryIORouter ioRouter = + FederatedExampleQueryIORouter.newBuilder() + .putAggregations(STRING_TENSOR_NAME, aggregationConfig) + .putAggregations(INT_TENSOR_NAME, aggregationConfig) + .build(); + ClientOnlyPlan clientOnlyPlan = + ClientOnlyPlan.newBuilder() + .setPhase( + ClientPhase.newBuilder() + .setFederatedExampleQuery(ioRouter) + .setExampleQuerySpec( + ExampleQuerySpec.newBuilder() + .addExampleQueries(exampleQuery) + .build())) + .build(); + + File inputCheckpointFile = File.createTempFile("input", ".ckp"); + File outputCheckpointFile = File.createTempFile("output", ".ckp"); + ResettingExampleIterator resettingExampleIterator = + new ResettingExampleIterator(0, Dataset.getDefaultInstance()); + + mFlRunnerWrapper = + new FlRunnerWrapper(mInterruptionFlag, POPULATION_NAME, resettingExampleIterator); + + FLRunnerResult result = + mFlRunnerWrapper.run( + TASK_NAME, + POPULATION_NAME, + clientOnlyPlan, + inputCheckpointFile.getAbsolutePath(), + outputCheckpointFile.getAbsolutePath()); + + assertThat(result.getContributionResult()).isEqualTo(ContributionResult.FAIL); + } + + @Test + public void testEmptyIterator_returnsFail() throws Exception { + ClientOnlyPlan clientOnlyPlan = TrainingTestUtil.createFederatedAnalyticClientPlan(); + File inputCheckpointFile = File.createTempFile("input", ".ckp"); + File outputCheckpointFile = File.createTempFile("output", ".ckp"); + ResettingExampleIterator resettingExampleIterator = + new ResettingExampleIterator(0, Dataset.getDefaultInstance()); + + mFlRunnerWrapper = + new FlRunnerWrapper(mInterruptionFlag, POPULATION_NAME, resettingExampleIterator); + + FLRunnerResult result = + mFlRunnerWrapper.run( + TASK_NAME, + POPULATION_NAME, + clientOnlyPlan, + inputCheckpointFile.getAbsolutePath(), + outputCheckpointFile.getAbsolutePath()); + + assertThat(result.getContributionResult()).isEqualTo(ContributionResult.FAIL); + } + + @Test + public void testRunFederatedLearning_returnsSuccess() throws Exception { + Context context = ApplicationProvider.getApplicationContext(); + Uri checkpointUri = Uri.parse(TEST_URI_PREFIX + "federation_test_checkpoint_client"); + Uri clientOnlyPlanUri = Uri.parse(TEST_URI_PREFIX + "federation_client_only_plan"); + Uri trainExamplesUri = Uri.parse(TEST_URI_PREFIX + "federation_proxy_train_examples"); + File inputCheckpointFile = File.createTempFile("input", ".ckp"); + File outputCheckpointFile = File.createTempFile("output", ".ckp"); + InputStream in = context.getContentResolver().openInputStream(checkpointUri); + Files.copy(in, inputCheckpointFile.toPath(), REPLACE_EXISTING); + in.close(); + ParcelFileDescriptor inputCheckpointFd = + ParcelFileDescriptor.open(inputCheckpointFile, ParcelFileDescriptor.MODE_READ_ONLY); + ParcelFileDescriptor outputCheckpointFd = + ParcelFileDescriptor.open( + outputCheckpointFile, ParcelFileDescriptor.MODE_WRITE_ONLY); + + in = context.getContentResolver().openInputStream(clientOnlyPlanUri); + ClientOnlyPlan clientOnlyPlan = ClientOnlyPlan.parseFrom(in); + in.close(); + + in = context.getContentResolver().openInputStream(trainExamplesUri); + Dataset dataset = Dataset.parseFrom(in); + in.close(); + + ResettingExampleIterator resettingExampleIterator = + new ResettingExampleIterator(dataset.getClientDataCount(), dataset); + mFlRunnerWrapper = + new FlRunnerWrapper(mInterruptionFlag, POPULATION_NAME, resettingExampleIterator); + + FLRunnerResult result = + mFlRunnerWrapper.run( + TASK_NAME, + POPULATION_NAME, + clientOnlyPlan, + getFileDescriptorForTensorflow(inputCheckpointFd), + getFileDescriptorForTensorflow(outputCheckpointFd)); + + assertThat(result.getContributionResult()).isEqualTo(ContributionResult.SUCCESS); + } + + // We implement a customized tensorflow filesystem which support file descriptor for read and + // write. The file format is "fd:///${fd_number}". + private String getFileDescriptorForTensorflow(ParcelFileDescriptor parcelFileDescriptor) { + return "fd:///" + parcelFileDescriptor.getFd(); + } +} diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/jni/JavaExampleStoreTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/jni/JavaExampleStoreTest.java new file mode 100644 index 00000000..9511d245 --- /dev/null +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/jni/JavaExampleStoreTest.java @@ -0,0 +1,127 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.services.training.jni; + +import static com.android.federatedcompute.services.common.ErrorStatusException.buildStatus; + +import static com.google.common.truth.Truth.assertThat; + +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.when; + +import com.android.federatedcompute.services.common.ErrorStatusException; +import com.android.federatedcompute.services.examplestore.ExampleIterator; + +import com.google.common.collect.ImmutableList; +import com.google.internal.federated.plan.ExampleSelector; +import com.google.internal.federatedcompute.v1.Code; +import com.google.protobuf.ByteString; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.lang.Thread.UncaughtExceptionHandler; +import java.util.List; + +@RunWith(JUnit4.class) +public final class JavaExampleStoreTest { + private static final String COLLECTION_URI = "app://test_collection"; + private static final ExampleSelector SELECTOR = + ExampleSelector.newBuilder().setCollectionUri(COLLECTION_URI).build(); + + @Mock private com.android.federatedcompute.services.examplestore.ExampleIterator mIterator; + private JavaExampleStore mJavaExampleStore; + @Mock UncaughtExceptionHandler mUncaughtExceptionHandler; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + + @Test + public void testCreateExampleIterator_success() throws Exception { + ByteString data1 = ByteString.copyFromUtf8("data1"); + ByteString data2 = ByteString.copyFromUtf8("data2"); + List<ByteString> data = ImmutableList.of(data1, data2); + mJavaExampleStore = new JavaExampleStore(new FakeExampleIterator(data)); + JavaExampleIterator nativeIterator = + mJavaExampleStore.createExampleIteratorWithContext( + SELECTOR.toByteArray(), new byte[] {}); + + // Verify all the data returns correctly. + assertThat(nativeIterator.next()).isEqualTo(data1.toByteArray()); + assertThat(nativeIterator.next()).isEqualTo(data2.toByteArray()); + + // Verify all subsequent next() calls return empty result array which indicates reach to the + // end of iterator. + assertThat(nativeIterator.next()).hasLength(0); + } + + @Test + public void testCreateExampleIterator_interruptedIterator() throws Exception { + when(mIterator.hasNext()).thenThrow(new InterruptedException("Interrupted")); + + mJavaExampleStore = new JavaExampleStore(mIterator); + + JavaExampleIterator iterator = + mJavaExampleStore.createExampleIteratorWithContext( + SELECTOR.toByteArray(), new byte[] {}); + + assertThrows(InterruptedException.class, () -> iterator.next()); + } + + @Test + public void testCreateExampleIterator_throwingIterator() throws Exception { + when(mIterator.hasNext()) + .thenThrow( + new ErrorStatusException( + buildStatus(Code.UNAVAILABLE_VALUE, "can't get next example"))); + + mJavaExampleStore = new JavaExampleStore(mIterator); + JavaExampleIterator javaExampleIterator = + mJavaExampleStore.createExampleIteratorWithContext( + SELECTOR.toByteArray(), new byte[] {}); + assertThrows(ErrorStatusException.class, () -> javaExampleIterator.next()); + } + + private static class FakeExampleIterator implements ExampleIterator { + private final List<ByteString> mData; + private int mNext; + + FakeExampleIterator(List<ByteString> data) { + this.mData = data; + this.mNext = 0; + } + + @Override + public boolean hasNext() { + return this.mNext < this.mData.size(); + } + + @Override + public byte[] next() { + return this.mData.get(this.mNext++).toByteArray(); + } + + @Override + public void close() {} + } +} diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/jni/SimpleTaskEnvironmentImplTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/jni/SimpleTaskEnvironmentImplTest.java new file mode 100644 index 00000000..fe087b79 --- /dev/null +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/jni/SimpleTaskEnvironmentImplTest.java @@ -0,0 +1,78 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.services.training.jni; + +import static com.google.common.truth.Truth.assertThat; + +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.when; + +import com.android.federatedcompute.services.examplestore.ExampleIterator; +import com.android.federatedcompute.services.training.util.ListenableSupplier; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +@RunWith(JUnit4.class) +public final class SimpleTaskEnvironmentImplTest { + private SimpleTaskEnvironmentImpl mNativeRunnerDeps; + + @Mock ListenableSupplier<Boolean> mInterruptionFlag; + @Mock ExampleIterator mExampleIterator; + + @Before + public void setUp() throws Exception { + MockitoAnnotations.initMocks(this); + when(mInterruptionFlag.get()).thenReturn(false); + when(mExampleIterator.hasNext()).thenReturn(false); + mNativeRunnerDeps = new SimpleTaskEnvironmentImpl(mInterruptionFlag, mExampleIterator); + } + + @After + public void tearDown() { + if (mNativeRunnerDeps != null) { + mNativeRunnerDeps.close(); + } + } + + @Test + public void testTrainingConditionsSatisfied() { + assertThat(mNativeRunnerDeps.trainingConditionsSatisfied()).isTrue(); + } + + @Test + public void setTrainingConditionsInterruptionFlag() { + when(mInterruptionFlag.get()).thenReturn(Boolean.TRUE); + + assertThat(mNativeRunnerDeps.trainingConditionsSatisfied()).isFalse(); + } + + @Test + public void testGetCacheDir() { + assertThrows(UnsupportedOperationException.class, () -> mNativeRunnerDeps.getCacheDir()); + } + + @Test + public void testGetBaseDir() { + assertThrows(UnsupportedOperationException.class, () -> mNativeRunnerDeps.getBaseDir()); + } +} diff --git a/tests/frameworktests/src/android/adservices/ondevicepersonalization/EventUrlProviderTest.java b/tests/frameworktests/src/android/adservices/ondevicepersonalization/EventUrlProviderTest.java index 0b80adba..8524010a 100644 --- a/tests/frameworktests/src/android/adservices/ondevicepersonalization/EventUrlProviderTest.java +++ b/tests/frameworktests/src/android/adservices/ondevicepersonalization/EventUrlProviderTest.java @@ -48,7 +48,8 @@ public class EventUrlProviderTest { params.putString("id", "abc"); assertEquals( "odp://5-abc-null-null-null", - mEventUrlProvider.getEventTrackingUrl(params, null, null).toString()); + mEventUrlProvider.createEventTrackingUrlWithResponse( + params, null, null).toString()); } @Test public void testGetEventUrlReturnsResponseFromService() throws Exception { @@ -57,7 +58,7 @@ public class EventUrlProviderTest { params.putString("id", "abc"); assertEquals( "odp://5-abc-AB-image/gif-null", - mEventUrlProvider.getEventTrackingUrl( + mEventUrlProvider.createEventTrackingUrlWithResponse( params, RESPONSE_BYTES, "image/gif").toString()); } @@ -66,8 +67,10 @@ public class EventUrlProviderTest { params.putInt("type", 5); params.putString("id", "abc"); assertEquals( - "odp://5-abc-null-null-def", - mEventUrlProvider.getEventTrackingUrlWithRedirect(params, "def").toString()); + "odp://5-abc-null-null-http://def", + mEventUrlProvider.createEventTrackingUrlWithRedirect( + params, Uri.parse("http://def")) + .toString()); } @Test public void testGetEventUrlThrowsOnError() throws Exception { @@ -76,8 +79,8 @@ public class EventUrlProviderTest { params.putInt("type", EVENT_TYPE_ERROR); params.putString("id", "abc"); assertThrows( - OnDevicePersonalizationException.class, - () -> mEventUrlProvider.getEventTrackingUrl( + IllegalStateException.class, + () -> mEventUrlProvider.createEventTrackingUrlWithResponse( params, null, null)); } diff --git a/tests/frameworktests/src/android/adservices/ondevicepersonalization/FederatedComputeSchedulerTest.java b/tests/frameworktests/src/android/adservices/ondevicepersonalization/FederatedComputeSchedulerTest.java new file mode 100644 index 00000000..34b3f414 --- /dev/null +++ b/tests/frameworktests/src/android/adservices/ondevicepersonalization/FederatedComputeSchedulerTest.java @@ -0,0 +1,129 @@ +/* + * Copyright 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import android.adservices.ondevicepersonalization.aidl.IFederatedComputeCallback; +import android.adservices.ondevicepersonalization.aidl.IFederatedComputeService; +import android.federatedcompute.common.TrainingOptions; +import android.os.RemoteException; + +import androidx.test.ext.junit.runners.AndroidJUnit4; +import androidx.test.filters.SmallTest; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.time.Duration; + +/** + * Unit Tests of RemoteData API. + */ +@SmallTest +@RunWith(AndroidJUnit4.class) +public class FederatedComputeSchedulerTest { + FederatedComputeScheduler mFederatedComputeScheduler = new FederatedComputeScheduler( + IFederatedComputeService.Stub.asInterface( + new FederatedComputeService())); + + private boolean mCancelCalled = false; + private boolean mScheduleCalled = false; + + + @Test + public void testScheduleSuccess() { + TrainingInterval interval = new TrainingInterval.Builder() + .setMinimumInterval(Duration.ofHours(10)) + .setSchedulingMode(1) + .build(); + FederatedComputeScheduler.Params params = new FederatedComputeScheduler.Params(interval); + FederatedComputeInput input = new FederatedComputeInput.Builder() + .setPopulationName("population") + .build(); + mFederatedComputeScheduler.schedule(params, input); + assertTrue(mScheduleCalled); + } + + @Test + public void testScheduleNull() { + FederatedComputeScheduler fcs = new FederatedComputeScheduler(null); + TrainingInterval interval = new TrainingInterval.Builder() + .setMinimumInterval(Duration.ofHours(10)) + .setSchedulingMode(1) + .build(); + FederatedComputeScheduler.Params params = new FederatedComputeScheduler.Params(interval); + FederatedComputeInput input = new FederatedComputeInput.Builder() + .setPopulationName("population") + .build(); + assertThrows(IllegalStateException.class, () -> fcs.schedule(params, input)); + } + + @Test + public void testScheduleErr() { + TrainingInterval interval = new TrainingInterval.Builder() + .setMinimumInterval(Duration.ofHours(10)) + .setSchedulingMode(1) + .build(); + FederatedComputeScheduler.Params params = new FederatedComputeScheduler.Params(interval); + FederatedComputeInput input = new FederatedComputeInput.Builder() + .setPopulationName("err") + .build(); + assertThrows(IllegalStateException.class, + () -> mFederatedComputeScheduler.schedule(params, input)); + } + + @Test + public void testCancelSuccess() { + mFederatedComputeScheduler.cancel("population"); + assertTrue(mCancelCalled); + } + + @Test + public void testCancelNull() { + FederatedComputeScheduler fcs = new FederatedComputeScheduler(null); + assertThrows(IllegalStateException.class, () -> fcs.cancel("population")); + } + + @Test + public void testCancelErr() { + assertThrows(IllegalStateException.class, () -> mFederatedComputeScheduler.cancel("err")); + } + + class FederatedComputeService extends IFederatedComputeService.Stub { + @Override + public void schedule(TrainingOptions trainingOptions, + IFederatedComputeCallback iFederatedComputeCallback) throws RemoteException { + mScheduleCalled = true; + if (trainingOptions.getPopulationName().equals("err")) { + iFederatedComputeCallback.onFailure(1); + } + iFederatedComputeCallback.onSuccess(); + } + + @Override + public void cancel(String s, IFederatedComputeCallback iFederatedComputeCallback) + throws RemoteException { + mCancelCalled = true; + if (s.equals("err")) { + iFederatedComputeCallback.onFailure(1); + } + iFederatedComputeCallback.onSuccess(); + } + } +} diff --git a/tests/frameworktests/src/android/adservices/ondevicepersonalization/IsolatedServiceTest.java b/tests/frameworktests/src/android/adservices/ondevicepersonalization/IsolatedServiceTest.java index 4cd8d088..9fd0e768 100644 --- a/tests/frameworktests/src/android/adservices/ondevicepersonalization/IsolatedServiceTest.java +++ b/tests/frameworktests/src/android/adservices/ondevicepersonalization/IsolatedServiceTest.java @@ -16,15 +16,19 @@ package android.adservices.ondevicepersonalization; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import android.adservices.ondevicepersonalization.aidl.IDataAccessService; import android.adservices.ondevicepersonalization.aidl.IDataAccessServiceCallback; +import android.adservices.ondevicepersonalization.aidl.IFederatedComputeCallback; +import android.adservices.ondevicepersonalization.aidl.IFederatedComputeService; import android.adservices.ondevicepersonalization.aidl.IIsolatedService; import android.adservices.ondevicepersonalization.aidl.IIsolatedServiceCallback; import android.content.ContentValues; +import android.federatedcompute.common.TrainingOptions; import android.os.Bundle; import android.os.ParcelFileDescriptor; import android.os.PersistableBundle; @@ -39,23 +43,24 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.function.Consumer; -/** - * Unit Tests of IsolatedService class. - */ +/** Unit Tests of IsolatedService class. */ @SmallTest @RunWith(AndroidJUnit4.class) public class IsolatedServiceTest { private static final String EVENT_TYPE_KEY = "event_type"; private final TestService mTestService = new TestService(); - private IIsolatedService mBinder; private final CountDownLatch mLatch = new CountDownLatch(1); + private IIsolatedService mBinder; private boolean mSelectContentCalled; private boolean mOnDownloadCalled; private boolean mOnRenderCalled; private boolean mOnEventCalled; + private boolean mOnTrainingExampleCalled; private Bundle mCallbackResult; private int mCallbackErrorCode; @@ -70,9 +75,7 @@ public class IsolatedServiceTest { assertThrows( IllegalArgumentException.class, () -> { - mBinder.onRequest( - 9999, new Bundle(), - new TestServiceCallback()); + mBinder.onRequest(9999, new Bundle(), new TestServiceCallback()); }); } @@ -80,39 +83,42 @@ public class IsolatedServiceTest { public void testOnExecute() throws Exception { PersistableBundle appParams = new PersistableBundle(); appParams.putString("x", "y"); - ExecuteInput input = - new ExecuteInput.Builder() - .setAppPackageName("com.testapp") - .setAppParams(appParams) - .build(); + ExecuteInputParcel input = + new ExecuteInputParcel.Builder() + .setAppPackageName("com.testapp") + .setAppParams(appParams) + .build(); Bundle params = new Bundle(); params.putParcelable(Constants.EXTRA_INPUT, input); params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService()); - mBinder.onRequest( - Constants.OP_EXECUTE, params, new TestServiceCallback()); + params.putBinder( + Constants.EXTRA_FEDERATED_COMPUTE_SERVICE_BINDER, + new TestFederatedComputeService()); + mBinder.onRequest(Constants.OP_EXECUTE, params, new TestServiceCallback()); mLatch.await(); assertTrue(mSelectContentCalled); - ExecuteOutput result = - mCallbackResult.getParcelable(Constants.EXTRA_RESULT, ExecuteOutput.class); - assertEquals( - 5, result.getRequestLogRecord().getRows().get(0).getAsInteger("a").intValue()); + ExecuteOutputParcel result = + mCallbackResult.getParcelable(Constants.EXTRA_RESULT, ExecuteOutputParcel.class); + assertEquals(5, result.getRequestLogRecord().getRows().get(0).getAsInteger("a").intValue()); assertEquals("123", result.getRenderingConfigs().get(0).getKeys().get(0)); } @Test public void testOnExecutePropagatesError() throws Exception { PersistableBundle appParams = new PersistableBundle(); - appParams.putInt("error", 1); // Trigger an error in the service. - ExecuteInput input = - new ExecuteInput.Builder() - .setAppPackageName("com.testapp") - .setAppParams(appParams) - .build(); + appParams.putInt("error", 1); // Trigger an error in the service. + ExecuteInputParcel input = + new ExecuteInputParcel.Builder() + .setAppPackageName("com.testapp") + .setAppParams(appParams) + .build(); Bundle params = new Bundle(); params.putParcelable(Constants.EXTRA_INPUT, input); params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService()); - mBinder.onRequest( - Constants.OP_EXECUTE, params, new TestServiceCallback()); + params.putBinder( + Constants.EXTRA_FEDERATED_COMPUTE_SERVICE_BINDER, + new TestFederatedComputeService()); + mBinder.onRequest(Constants.OP_EXECUTE, params, new TestServiceCallback()); mLatch.await(); assertTrue(mSelectContentCalled); assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode); @@ -120,15 +126,14 @@ public class IsolatedServiceTest { @Test public void testOnExecuteWithoutAppParams() throws Exception { - ExecuteInput input = - new ExecuteInput.Builder() - .setAppPackageName("com.testapp") - .build(); + ExecuteInputParcel input = new ExecuteInputParcel.Builder().setAppPackageName("com.testapp").build(); Bundle params = new Bundle(); params.putParcelable(Constants.EXTRA_INPUT, input); params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService()); - mBinder.onRequest( - Constants.OP_EXECUTE, params, new TestServiceCallback()); + params.putBinder( + Constants.EXTRA_FEDERATED_COMPUTE_SERVICE_BINDER, + new TestFederatedComputeService()); + mBinder.onRequest(Constants.OP_EXECUTE, params, new TestServiceCallback()); mLatch.await(); assertTrue(mSelectContentCalled); } @@ -138,9 +143,7 @@ public class IsolatedServiceTest { assertThrows( NullPointerException.class, () -> { - mBinder.onRequest( - Constants.OP_EXECUTE, null, - new TestServiceCallback()); + mBinder.onRequest(Constants.OP_EXECUTE, null, new TestServiceCallback()); }); } @@ -148,64 +151,76 @@ public class IsolatedServiceTest { public void testOnExecuteThrowsIfInputMissing() throws Exception { Bundle params = new Bundle(); params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService()); + params.putBinder( + Constants.EXTRA_FEDERATED_COMPUTE_SERVICE_BINDER, + new TestFederatedComputeService()); assertThrows( NullPointerException.class, () -> { - mBinder.onRequest( - Constants.OP_EXECUTE, params, - new TestServiceCallback()); + mBinder.onRequest(Constants.OP_EXECUTE, params, new TestServiceCallback()); }); } @Test public void testOnExecuteThrowsIfDataAccessServiceMissing() throws Exception { - ExecuteInput input = - new ExecuteInput.Builder() - .setAppPackageName("com.testapp") - .build(); + ExecuteInputParcel input = new ExecuteInputParcel.Builder().setAppPackageName("com.testapp").build(); Bundle params = new Bundle(); + params.putBinder( + Constants.EXTRA_FEDERATED_COMPUTE_SERVICE_BINDER, + new TestFederatedComputeService()); params.putParcelable(Constants.EXTRA_INPUT, input); assertThrows( NullPointerException.class, () -> { - mBinder.onRequest( - Constants.OP_EXECUTE, params, - new TestServiceCallback()); + mBinder.onRequest(Constants.OP_EXECUTE, params, new TestServiceCallback()); + }); + } + + @Test + public void testOnExecuteThrowsIfFederatedComputeServiceMissing() throws Exception { + ExecuteInputParcel input = new ExecuteInputParcel.Builder().setAppPackageName("com.testapp").build(); + Bundle params = new Bundle(); + params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService()); + params.putParcelable(Constants.EXTRA_INPUT, input); + assertThrows( + NullPointerException.class, + () -> { + mBinder.onRequest(Constants.OP_EXECUTE, params, new TestServiceCallback()); }); } @Test public void testOnExecuteThrowsIfCallbackMissing() throws Exception { - ExecuteInput input = - new ExecuteInput.Builder() - .setAppPackageName("com.testapp") - .build(); + ExecuteInputParcel input = new ExecuteInputParcel.Builder().setAppPackageName("com.testapp").build(); Bundle params = new Bundle(); params.putParcelable(Constants.EXTRA_INPUT, input); params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService()); assertThrows( NullPointerException.class, () -> { - mBinder.onRequest( - Constants.OP_EXECUTE, params, null); + mBinder.onRequest(Constants.OP_EXECUTE, params, null); }); } @Test public void testOnDownload() throws Exception { - DownloadInputParcel input = new DownloadInputParcel.Builder() - .setDownloadedKeys(StringParceledListSlice.emptyList()) - .setDownloadedValues(ByteArrayParceledListSlice.emptyList()) - .build(); + DownloadInputParcel input = + new DownloadInputParcel.Builder() + .setDownloadedKeys(StringParceledListSlice.emptyList()) + .setDownloadedValues(ByteArrayParceledListSlice.emptyList()) + .build(); Bundle params = new Bundle(); params.putParcelable(Constants.EXTRA_INPUT, input); params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService()); - mBinder.onRequest( - Constants.OP_DOWNLOAD, params, new TestServiceCallback()); + params.putBinder( + Constants.EXTRA_FEDERATED_COMPUTE_SERVICE_BINDER, + new TestFederatedComputeService()); + mBinder.onRequest(Constants.OP_DOWNLOAD, params, new TestServiceCallback()); mLatch.await(); assertTrue(mOnDownloadCalled); - DownloadOutput result = - mCallbackResult.getParcelable(Constants.EXTRA_RESULT, DownloadOutput.class); + DownloadCompletedOutputParcel result = + mCallbackResult.getParcelable( + Constants.EXTRA_RESULT, DownloadCompletedOutputParcel.class); assertEquals("12", result.getRetainedKeys().get(0)); } @@ -214,9 +229,7 @@ public class IsolatedServiceTest { assertThrows( NullPointerException.class, () -> { - mBinder.onRequest( - Constants.OP_DOWNLOAD, null, - new TestServiceCallback()); + mBinder.onRequest(Constants.OP_DOWNLOAD, null, new TestServiceCallback()); }); } @@ -227,83 +240,92 @@ public class IsolatedServiceTest { assertThrows( NullPointerException.class, () -> { - mBinder.onRequest( - Constants.OP_DOWNLOAD, params, - new TestServiceCallback()); + mBinder.onRequest(Constants.OP_DOWNLOAD, params, new TestServiceCallback()); }); } @Test public void testOnDownloadThrowsIfDataAccessServiceMissing() throws Exception { - DownloadInputParcel input = new DownloadInputParcel.Builder() - .setDownloadedKeys(StringParceledListSlice.emptyList()) - .setDownloadedValues(ByteArrayParceledListSlice.emptyList()) - .build(); + DownloadInputParcel input = + new DownloadInputParcel.Builder() + .setDownloadedKeys(StringParceledListSlice.emptyList()) + .setDownloadedValues(ByteArrayParceledListSlice.emptyList()) + .build(); Bundle params = new Bundle(); params.putParcelable(Constants.EXTRA_INPUT, input); assertThrows( NullPointerException.class, () -> { - mBinder.onRequest( - Constants.OP_DOWNLOAD, params, - new TestServiceCallback()); + mBinder.onRequest(Constants.OP_DOWNLOAD, params, new TestServiceCallback()); + }); + } + + @Test + public void testOnDownloadThrowsIfFederatedComputeServiceMissing() throws Exception { + DownloadInputParcel input = + new DownloadInputParcel.Builder() + .setDownloadedKeys(StringParceledListSlice.emptyList()) + .setDownloadedValues(ByteArrayParceledListSlice.emptyList()) + .build(); + Bundle params = new Bundle(); + params.putParcelable(Constants.EXTRA_INPUT, input); + params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService()); + assertThrows( + NullPointerException.class, + () -> { + mBinder.onRequest(Constants.OP_DOWNLOAD, params, new TestServiceCallback()); }); } @Test public void testOnDownloadThrowsIfCallbackMissing() throws Exception { ParcelFileDescriptor[] pfds = ParcelFileDescriptor.createPipe(); - DownloadInputParcel input = new DownloadInputParcel.Builder() - .setDownloadedKeys(StringParceledListSlice.emptyList()) - .setDownloadedValues(ByteArrayParceledListSlice.emptyList()) - .build(); + DownloadInputParcel input = + new DownloadInputParcel.Builder() + .setDownloadedKeys(StringParceledListSlice.emptyList()) + .setDownloadedValues(ByteArrayParceledListSlice.emptyList()) + .build(); Bundle params = new Bundle(); params.putParcelable(Constants.EXTRA_INPUT, input); params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService()); assertThrows( NullPointerException.class, () -> { - mBinder.onRequest( - Constants.OP_DOWNLOAD, params, null); + mBinder.onRequest(Constants.OP_DOWNLOAD, params, null); }); } @Test public void testOnRender() throws Exception { - RenderInput input = - new RenderInput.Builder() - .setRenderingConfig( - new RenderingConfig.Builder() - .addKey("a") - .addKey("b") - .build()) - .build(); + RenderInputParcel input = + new RenderInputParcel.Builder() + .setRenderingConfig( + new RenderingConfig.Builder().addKey("a").addKey("b").build()) + .build(); Bundle params = new Bundle(); params.putParcelable(Constants.EXTRA_INPUT, input); params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService()); - mBinder.onRequest( - Constants.OP_RENDER, params, new TestServiceCallback()); + mBinder.onRequest(Constants.OP_RENDER, params, new TestServiceCallback()); mLatch.await(); assertTrue(mOnRenderCalled); - RenderOutput result = - mCallbackResult.getParcelable(Constants.EXTRA_RESULT, RenderOutput.class); + RenderOutputParcel result = + mCallbackResult.getParcelable(Constants.EXTRA_RESULT, RenderOutputParcel.class); assertEquals("htmlstring", result.getContent()); } @Test public void testOnRenderPropagatesError() throws Exception { - RenderInput input = - new RenderInput.Builder() - .setRenderingConfig( - new RenderingConfig.Builder() - .addKey("z") // Trigger error in service. - .build()) - .build(); + RenderInputParcel input = + new RenderInputParcel.Builder() + .setRenderingConfig( + new RenderingConfig.Builder() + .addKey("z") // Trigger error in service. + .build()) + .build(); Bundle params = new Bundle(); params.putParcelable(Constants.EXTRA_INPUT, input); params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService()); - mBinder.onRequest( - Constants.OP_RENDER, params, new TestServiceCallback()); + mBinder.onRequest(Constants.OP_RENDER, params, new TestServiceCallback()); mLatch.await(); assertTrue(mOnRenderCalled); assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode); @@ -314,9 +336,7 @@ public class IsolatedServiceTest { assertThrows( NullPointerException.class, () -> { - mBinder.onRequest( - Constants.OP_RENDER, null, - new TestServiceCallback()); + mBinder.onRequest(Constants.OP_RENDER, null, new TestServiceCallback()); }); } @@ -327,85 +347,70 @@ public class IsolatedServiceTest { assertThrows( NullPointerException.class, () -> { - mBinder.onRequest( - Constants.OP_RENDER, params, - new TestServiceCallback()); + mBinder.onRequest(Constants.OP_RENDER, params, new TestServiceCallback()); }); } @Test public void testOnRenderThrowsIfDataAccessServiceMissing() throws Exception { - RenderInput input = - new RenderInput.Builder() - .setRenderingConfig( - new RenderingConfig.Builder() - .addKey("a") - .addKey("b") - .build()) - .build(); + RenderInputParcel input = + new RenderInputParcel.Builder() + .setRenderingConfig( + new RenderingConfig.Builder().addKey("a").addKey("b").build()) + .build(); Bundle params = new Bundle(); params.putParcelable(Constants.EXTRA_INPUT, input); assertThrows( NullPointerException.class, () -> { - mBinder.onRequest( - Constants.OP_RENDER, params, - new TestServiceCallback()); + mBinder.onRequest(Constants.OP_RENDER, params, new TestServiceCallback()); }); } @Test public void testOnRenderThrowsIfCallbackMissing() throws Exception { - RenderInput input = - new RenderInput.Builder() - .setRenderingConfig( - new RenderingConfig.Builder() - .addKey("a") - .addKey("b") - .build()) - .build(); + RenderInputParcel input = + new RenderInputParcel.Builder() + .setRenderingConfig( + new RenderingConfig.Builder().addKey("a").addKey("b").build()) + .build(); Bundle params = new Bundle(); params.putParcelable(Constants.EXTRA_INPUT, input); params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService()); assertThrows( NullPointerException.class, () -> { - mBinder.onRequest( - Constants.OP_RENDER, params, null); + mBinder.onRequest(Constants.OP_RENDER, params, null); }); } @Test - public void testOnWebViewEvent() throws Exception { + public void testOnEvent() throws Exception { Bundle params = new Bundle(); params.putParcelable( Constants.EXTRA_INPUT, - new WebViewEventInput.Builder().setParameters(PersistableBundle.EMPTY).build()); + new EventInputParcel.Builder().setParameters(PersistableBundle.EMPTY).build()); params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService()); - mBinder.onRequest( - Constants.OP_WEB_VIEW_EVENT, params, - new TestServiceCallback()); + mBinder.onRequest(Constants.OP_WEB_VIEW_EVENT, params, new TestServiceCallback()); mLatch.await(); assertTrue(mOnEventCalled); - WebViewEventOutput result = - mCallbackResult.getParcelable(Constants.EXTRA_RESULT, WebViewEventOutput.class); + EventOutputParcel result = + mCallbackResult.getParcelable(Constants.EXTRA_RESULT, EventOutputParcel.class); assertEquals(1, result.getEventLogRecord().getType()); assertEquals(2, result.getEventLogRecord().getRowIndex()); } @Test - public void testOnWebViewEventPropagatesError() throws Exception { + public void testOnEventPropagatesError() throws Exception { PersistableBundle eventParams = new PersistableBundle(); // Input value 9999 will trigger an error in the mock service. eventParams.putInt(EVENT_TYPE_KEY, 9999); Bundle params = new Bundle(); params.putParcelable( Constants.EXTRA_INPUT, - new WebViewEventInput.Builder().setParameters(eventParams).build()); + new EventInputParcel.Builder().setParameters(eventParams).build()); params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService()); - mBinder.onRequest( - Constants.OP_WEB_VIEW_EVENT, params, - new TestServiceCallback()); + mBinder.onRequest(Constants.OP_WEB_VIEW_EVENT, params, new TestServiceCallback()); mLatch.await(); assertTrue(mOnEventCalled); assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode); @@ -416,9 +421,7 @@ public class IsolatedServiceTest { assertThrows( NullPointerException.class, () -> { - mBinder.onRequest( - Constants.OP_WEB_VIEW_EVENT, null, - new TestServiceCallback()); + mBinder.onRequest(Constants.OP_WEB_VIEW_EVENT, null, new TestServiceCallback()); }); } @@ -430,8 +433,7 @@ public class IsolatedServiceTest { NullPointerException.class, () -> { mBinder.onRequest( - Constants.OP_WEB_VIEW_EVENT, params, - new TestServiceCallback()); + Constants.OP_WEB_VIEW_EVENT, params, new TestServiceCallback()); }); } @@ -440,13 +442,12 @@ public class IsolatedServiceTest { Bundle params = new Bundle(); params.putParcelable( Constants.EXTRA_INPUT, - new WebViewEventInput.Builder().setParameters(PersistableBundle.EMPTY).build()); + new EventInputParcel.Builder().setParameters(PersistableBundle.EMPTY).build()); assertThrows( NullPointerException.class, () -> { mBinder.onRequest( - Constants.OP_WEB_VIEW_EVENT, params, - new TestServiceCallback()); + Constants.OP_WEB_VIEW_EVENT, params, new TestServiceCallback()); }); } @@ -455,21 +456,105 @@ public class IsolatedServiceTest { Bundle params = new Bundle(); params.putParcelable( Constants.EXTRA_INPUT, - new WebViewEventInput.Builder().setParameters(PersistableBundle.EMPTY).build()); + new EventInputParcel.Builder().setParameters(PersistableBundle.EMPTY).build()); params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService()); assertThrows( NullPointerException.class, () -> { + mBinder.onRequest(Constants.OP_WEB_VIEW_EVENT, params, null); + }); + } + + @Test + public void testOnTrainingExample() throws Exception { + JoinedLogRecord joinedLogRecord = new JoinedLogRecord.Builder().build(); + TrainingExampleInput input = + new TrainingExampleInput.Builder() + .setPopulationName("") + .setTaskName("") + .setResumptionToken(new byte[] {0}) + .build(); + Bundle params = new Bundle(); + params.putParcelable(Constants.EXTRA_INPUT, input); + params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService()); + mBinder.onRequest(Constants.OP_TRAINING_EXAMPLE, params, new TestServiceCallback()); + mLatch.await(); + assertTrue(mOnTrainingExampleCalled); + TrainingExampleOutputParcel result = + mCallbackResult.getParcelable( + Constants.EXTRA_RESULT, TrainingExampleOutputParcel.class); + List<byte[]> examples = result.getTrainingExamples().getList(); + List<byte[]> tokens = result.getResumptionTokens().getList(); + assertEquals(1, examples.size()); + assertEquals(1, tokens.size()); + assertArrayEquals(new byte[] {12}, examples.get(0)); + assertArrayEquals(new byte[] {13}, tokens.get(0)); + } + + @Test + public void testOnTrainingExampleThrowsIfParamsMissing() throws Exception { + assertThrows( + NullPointerException.class, + () -> { + mBinder.onRequest( + Constants.OP_TRAINING_EXAMPLE, null, new TestServiceCallback()); + }); + } + + @Test + public void testOnTrainingExampleThrowsIfDataAccessServiceMissing() throws Exception { + JoinedLogRecord joinedLogRecord = new JoinedLogRecord.Builder().build(); + TrainingExampleInput input = + new TrainingExampleInput.Builder() + .setPopulationName("") + .setTaskName("") + .setResumptionToken(new byte[] {0}) + .build(); + Bundle params = new Bundle(); + params.putParcelable(Constants.EXTRA_INPUT, input); + assertThrows( + NullPointerException.class, + () -> { mBinder.onRequest( - Constants.OP_WEB_VIEW_EVENT, params, null); + Constants.OP_TRAINING_EXAMPLE, params, new TestServiceCallback()); + }); + } + + @Test + public void testOnTrainingExampleThrowsIfCallbackMissing() throws Exception { + JoinedLogRecord joinedLogRecord = new JoinedLogRecord.Builder().build(); + TrainingExampleInput input = + new TrainingExampleInput.Builder() + .setPopulationName("") + .setTaskName("") + .setResumptionToken(new byte[] {0}) + .build(); + Bundle params = new Bundle(); + params.putParcelable(Constants.EXTRA_INPUT, input); + params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService()); + mBinder.onRequest(Constants.OP_TRAINING_EXAMPLE, params, new TestServiceCallback()); + assertThrows( + NullPointerException.class, + () -> { + mBinder.onRequest(Constants.OP_TRAINING_EXAMPLE, params, null); }); } + static class TestDataAccessService extends IDataAccessService.Stub { + @Override + public void onRequest(int operation, Bundle params, IDataAccessServiceCallback callback) {} + } + + static class TestFederatedComputeService extends IFederatedComputeService.Stub { + @Override + public void schedule(TrainingOptions trainingOptions, IFederatedComputeCallback callback) {} + + public void cancel(String populationName, IFederatedComputeCallback callback) {} + } + class TestHandler implements IsolatedWorker { - @Override public void onExecute( - ExecuteInput input, - Consumer<ExecuteOutput> consumer - ) { + @Override + public void onExecute(ExecuteInput input, Consumer<ExecuteOutput> consumer) { mSelectContentCalled = true; if (input.getAppParams() != null && input.getAppParams().getInt("error") > 0) { consumer.accept(null); @@ -478,77 +563,83 @@ public class IsolatedServiceTest { row.put("a", 5); consumer.accept( new ExecuteOutput.Builder() - .setRequestLogRecord(new RequestLogRecord.Builder().addRow(row).build()) - .addRenderingConfig(new RenderingConfig.Builder().addKey("123").build()) - .build()); + .setRequestLogRecord( + new RequestLogRecord.Builder().addRow(row).build()) + .addRenderingConfig( + new RenderingConfig.Builder().addKey("123").build()) + .build()); } } - @Override public void onDownload( - DownloadInput input, - Consumer<DownloadOutput> consumer - ) { + @Override + public void onDownloadCompleted( + DownloadCompletedInput input, Consumer<DownloadCompletedOutput> consumer) { mOnDownloadCalled = true; - consumer.accept(new DownloadOutput.Builder().addRetainedKey("12").build()); + consumer.accept(new DownloadCompletedOutput.Builder().addRetainedKey("12").build()); } - @Override public void onRender( - RenderInput input, - Consumer<RenderOutput> consumer - ) { + @Override + public void onRender(RenderInput input, Consumer<RenderOutput> consumer) { mOnRenderCalled = true; if (input.getRenderingConfig().getKeys().size() >= 1 - && input.getRenderingConfig().getKeys().get(0).equals("z")) { + && input.getRenderingConfig().getKeys().get(0).equals("z")) { consumer.accept(null); } else { - consumer.accept( - new RenderOutput.Builder().setContent("htmlstring").build()); + consumer.accept(new RenderOutput.Builder().setContent("htmlstring").build()); } } - @Override public void onWebViewEvent( - WebViewEventInput input, - Consumer<WebViewEventOutput> consumer - ) { + @Override + public void onEvent(EventInput input, Consumer<EventOutput> consumer) { mOnEventCalled = true; int eventType = input.getParameters().getInt(EVENT_TYPE_KEY); if (eventType == 9999) { consumer.accept(null); } else { consumer.accept( - new WebViewEventOutput.Builder() - .setEventLogRecord( - new EventLogRecord.Builder() - .setType(1) - .setRowIndex(2) - .setData(new ContentValues()) - .build()) - .build()); + new EventOutput.Builder() + .setEventLogRecord( + new EventLogRecord.Builder() + .setType(1) + .setRowIndex(2) + .setData(new ContentValues()) + .build()) + .build()); } } - } - class TestService extends IsolatedService { - @Override public IsolatedWorker onRequest(RequestToken token) { - return new TestHandler(); + @Override + public void onTrainingExample( + TrainingExampleInput input, Consumer<TrainingExampleOutput> consumer) { + mOnTrainingExampleCalled = true; + List<byte[]> examples = new ArrayList<>(); + examples.add(new byte[] {12}); + List<byte[]> tokens = new ArrayList<>(); + tokens.add(new byte[] {13}); + consumer.accept( + new TrainingExampleOutput.Builder() + .setTrainingExamples(examples) + .setResumptionTokens(tokens) + .build()); } } - static class TestDataAccessService extends IDataAccessService.Stub { + class TestService extends IsolatedService { @Override - public void onRequest( - int operation, - Bundle params, - IDataAccessServiceCallback callback - ) {} + public IsolatedWorker onRequest(RequestToken token) { + return new TestHandler(); + } } class TestServiceCallback extends IIsolatedServiceCallback.Stub { - @Override public void onSuccess(Bundle result) { + @Override + public void onSuccess(Bundle result) { mCallbackResult = result; mLatch.countDown(); } - @Override public void onError(int errorCode) { + + @Override + public void onError(int errorCode) { mCallbackErrorCode = errorCode; mLatch.countDown(); } diff --git a/tests/frameworktests/src/android/adservices/ondevicepersonalization/LocalDataTest.java b/tests/frameworktests/src/android/adservices/ondevicepersonalization/LocalDataTest.java index 8e3a9ed9..bc065160 100644 --- a/tests/frameworktests/src/android/adservices/ondevicepersonalization/LocalDataTest.java +++ b/tests/frameworktests/src/android/adservices/ondevicepersonalization/LocalDataTest.java @@ -62,17 +62,11 @@ public class LocalDataTest { @Test public void testLookupError() { // Triggers an expected error in the mock service. - assertThrows(OnDevicePersonalizationException.class, () -> mLocalData.get("z")); + assertThrows(IllegalStateException.class, () -> mLocalData.get("z")); } @Test - public void testLookupTimeout() { - // Triggers an expected error in the mock service. - assertThrows(OnDevicePersonalizationException.class, () -> mLocalData.get("timeout")); - } - - @Test - public void testKeysetSuccess() throws OnDevicePersonalizationException { + public void testKeysetSuccess() { Set<String> expectedResult = new HashSet<>(); expectedResult.add("a"); expectedResult.add("b"); @@ -91,14 +85,7 @@ public class LocalDataTest { @Test public void testPutError() { // Triggers an expected error in the mock service. - assertThrows(OnDevicePersonalizationException.class, () -> mLocalData.put("z", - new byte[10])); - } - - @Test - public void testPutTimeout() { - // Triggers an expected error in the mock service. - assertThrows(OnDevicePersonalizationException.class, () -> mLocalData.put("timeout", + assertThrows(IllegalStateException.class, () -> mLocalData.put("z", new byte[10])); } @@ -112,16 +99,9 @@ public class LocalDataTest { @Test public void testRemoveError() { // Triggers an expected error in the mock service. - assertThrows(OnDevicePersonalizationException.class, () -> mLocalData.remove("z")); + assertThrows(IllegalStateException.class, () -> mLocalData.remove("z")); } - @Test - public void testRemoveTimeout() { - // Triggers an expected error in the mock service. - assertThrows(OnDevicePersonalizationException.class, () -> mLocalData.remove("timeout")); - } - - public static class LocalDataService extends IDataAccessService.Stub { HashMap<String, byte[]> mContents = new HashMap<String, byte[]>(); @@ -161,16 +141,6 @@ public class LocalDataTest { return; } - if (keys.length == 1 && keys[0].equals("timeout")) { - // Force timeout by sleeping. - try { - Thread.sleep(2000); - } catch (Exception e) { - // Ignored - } - return; - } - if (operation == Constants.DATA_ACCESS_OP_LOCAL_DATA_LOOKUP || operation == Constants.DATA_ACCESS_OP_LOCAL_DATA_PUT || operation == Constants.DATA_ACCESS_OP_LOCAL_DATA_REMOVE) { diff --git a/tests/frameworktests/src/android/adservices/ondevicepersonalization/LogReaderTest.java b/tests/frameworktests/src/android/adservices/ondevicepersonalization/LogReaderTest.java new file mode 100644 index 00000000..6397a44e --- /dev/null +++ b/tests/frameworktests/src/android/adservices/ondevicepersonalization/LogReaderTest.java @@ -0,0 +1,191 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.ondevicepersonalization; + +import static junit.framework.Assert.assertEquals; + +import static org.junit.Assert.assertThrows; + +import android.adservices.ondevicepersonalization.aidl.IDataAccessService; +import android.adservices.ondevicepersonalization.aidl.IDataAccessServiceCallback; +import android.content.ContentValues; +import android.os.Bundle; +import android.os.RemoteException; + +import androidx.test.ext.junit.runners.AndroidJUnit4; +import androidx.test.filters.SmallTest; + +import com.android.ondevicepersonalization.internal.util.OdpParceledListSlice; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.util.ArrayList; +import java.util.List; + +/** + * Unit Tests of LogReader API. + */ +@SmallTest +@RunWith(AndroidJUnit4.class) +public class LogReaderTest { + + LogReader mLogReader; + + @Before + public void setup() { + mLogReader = new LogReader( + IDataAccessService.Stub.asInterface( + new LogReaderTest.LocalDataService())); + } + + @Test + public void testGetRequestsSuccess() { + List<RequestLogRecord> result = mLogReader.getRequests(10, 100); + assertEquals(2, result.size()); + assertEquals(1, result.get(0).getRows().size()); + assertEquals((int) (result.get(0).getRows().get(0).getAsInteger("a")), 1); + assertEquals((int) (result.get(0).getRows().get(0).getAsInteger("b")), 1); + assertEquals(1, result.get(1).getRows().size()); + assertEquals((int) (result.get(1).getRows().get(0).getAsInteger("a")), 1); + assertEquals((int) (result.get(1).getRows().get(0).getAsInteger("b")), 1); + } + + @Test + public void testGetRequestsError() { + // Triggers an expected error in the mock service. + assertThrows(IllegalStateException.class, () -> mLogReader.getRequests(7, 100)); + } + + @Test + public void testGetRequestsNegativeTimeError() { + assertThrows(IllegalArgumentException.class, () -> mLogReader.getRequests(-1, 100)); + } + + @Test + public void testGetRequestsBadTimeRangeError() { + assertThrows(IllegalArgumentException.class, () -> mLogReader.getRequests(100, 100)); + assertThrows(IllegalArgumentException.class, () -> mLogReader.getRequests(1000, 100)); + } + + @Test + public void testGetJoinedEventsSuccess() { + List<EventLogRecord> result = mLogReader.getJoinedEvents(10, 100); + assertEquals(2, result.size()); + assertEquals(result.get(0).getTimeMillis(), 30); + assertEquals(result.get(0).getRequestLogRecord().getTimeMillis(), 20); + assertEquals(result.get(0).getType(), 1); + assertEquals((int) (result.get(0).getData().getAsInteger("a")), 1); + assertEquals((int) (result.get(0).getData().getAsInteger("b")), 1); + assertEquals(result.get(1).getTimeMillis(), 40); + assertEquals(result.get(1).getRequestLogRecord().getTimeMillis(), 30); + assertEquals(result.get(1).getType(), 2); + assertEquals((int) (result.get(1).getData().getAsInteger("a")), 1); + assertEquals((int) (result.get(1).getData().getAsInteger("b")), 1); + } + + @Test + public void testGetJoinedEventsError() { + // Triggers an expected error in the mock service. + assertThrows(IllegalStateException.class, () -> mLogReader.getJoinedEvents(7, 100)); + } + + @Test + public void testGetJoinedEventsNegativeTimeError() { + assertThrows(IllegalArgumentException.class, () -> mLogReader.getJoinedEvents(-1, 100)); + } + + @Test + public void testGetJoinedEventsInputError() { + assertThrows(IllegalArgumentException.class, () -> mLogReader.getJoinedEvents(100, 100)); + assertThrows(IllegalArgumentException.class, () -> mLogReader.getJoinedEvents(1000, 100)); + } + + public static class LocalDataService extends IDataAccessService.Stub { + + public LocalDataService() { + } + + @Override + public void onRequest( + int operation, + Bundle params, + IDataAccessServiceCallback callback) { + if (operation == Constants.DATA_ACCESS_OP_GET_REQUESTS + || operation == Constants.DATA_ACCESS_OP_GET_JOINED_EVENTS) { + long[] timestamps = params.getLongArray(Constants.EXTRA_LOOKUP_KEYS); + if (timestamps[0] == 7) { + // Raise expected error. + try { + callback.onError(Constants.STATUS_INTERNAL_ERROR); + } catch (RemoteException e) { + // Ignored. + } + return; + } + + Bundle result = new Bundle(); + ContentValues values = new ContentValues(); + values.put("a", 1); + values.put("b", 1); + if (operation == Constants.DATA_ACCESS_OP_GET_REQUESTS) { + List<RequestLogRecord> records = new ArrayList<>(); + records.add(new RequestLogRecord.Builder() + .setRequestId(1) + .addRow(values) + .build()); + records.add(new RequestLogRecord.Builder() + .setRequestId(2) + .addRow(values) + .build()); + result.putParcelable(Constants.EXTRA_RESULT, + new OdpParceledListSlice<RequestLogRecord>(records)); + } else if (operation == Constants.DATA_ACCESS_OP_GET_JOINED_EVENTS) { + List<EventLogRecord> records = new ArrayList<>(); + records.add(new EventLogRecord.Builder() + .setType(1) + .setTimeMillis(30) + .setData(values) + .setRequestLogRecord(new RequestLogRecord.Builder() + .setRequestId(0) + .addRow(values) + .setTimeMillis(20) + .build()) + .build()); + records.add(new EventLogRecord.Builder() + .setType(2) + .setTimeMillis(40) + .setData(values) + .setRequestLogRecord(new RequestLogRecord.Builder() + .setRequestId(0) + .addRow(values) + .setTimeMillis(30) + .build()) + .build()); + result.putParcelable(Constants.EXTRA_RESULT, + new OdpParceledListSlice<EventLogRecord>(records)); + } + try { + callback.onSuccess(result); + } catch (RemoteException e) { + // Ignored. + } + } + } + } +} diff --git a/tests/frameworktests/src/android/adservices/ondevicepersonalization/OnDevicePersonalizationFrameworkClassesTest.java b/tests/frameworktests/src/android/adservices/ondevicepersonalization/OnDevicePersonalizationFrameworkClassesTest.java index b676724b..258ecac1 100644 --- a/tests/frameworktests/src/android/adservices/ondevicepersonalization/OnDevicePersonalizationFrameworkClassesTest.java +++ b/tests/frameworktests/src/android/adservices/ondevicepersonalization/OnDevicePersonalizationFrameworkClassesTest.java @@ -37,26 +37,75 @@ import java.util.ArrayList; @RunWith(AndroidJUnit4.class) public class OnDevicePersonalizationFrameworkClassesTest { /** + * Tests that the ExecuteInput object serializes correctly. + */ + @Test + public void testExecuteInput() { + PersistableBundle bundle = new PersistableBundle(); + bundle.putInt("a", 5); + ExecuteInputParcel data = new ExecuteInputParcel.Builder() + .setAppPackageName("com.example.test") + .setAppParams(bundle) + .build(); + + Parcel parcel = Parcel.obtain(); + data.writeToParcel(parcel, 0); + parcel.setDataPosition(0); + ExecuteInputParcel data2 = ExecuteInputParcel.CREATOR.createFromParcel(parcel); + ExecuteInput result = new ExecuteInput(data2); + + assertEquals("com.example.test", result.getAppPackageName()); + assertEquals(5, result.getAppParams().getInt("a")); + } + + /** * Tests that the ExecuteOutput object serializes correctly. */ @Test public void testExecuteOutput() { ContentValues row = new ContentValues(); row.put("a", 5); - ExecuteOutput result = + ExecuteOutput data = new ExecuteOutput.Builder() .setRequestLogRecord(new RequestLogRecord.Builder().addRow(row).build()) .addRenderingConfig(new RenderingConfig.Builder().addKey("abc").build()) + .addEventLogRecord(new EventLogRecord.Builder().setType(1).build()) .build(); + ExecuteOutputParcel result = new ExecuteOutputParcel(data); Parcel parcel = Parcel.obtain(); result.writeToParcel(parcel, 0); parcel.setDataPosition(0); - ExecuteOutput result2 = ExecuteOutput.CREATOR.createFromParcel(parcel); + ExecuteOutputParcel result2 = ExecuteOutputParcel.CREATOR.createFromParcel(parcel); assertEquals( 5, result2.getRequestLogRecord().getRows().get(0).getAsInteger("a").intValue()); assertEquals("abc", result2.getRenderingConfigs().get(0).getKeys().get(0)); + assertEquals(1, result2.getEventLogRecords().get(0).getType()); + } + + /** + * Tests that the RenderInput object serializes correctly. + */ + @Test + public void testRenderInput() { + RenderInputParcel data = new RenderInputParcel.Builder() + .setWidth(10) + .setHeight(20) + .setRenderingConfigIndex(5) + .setRenderingConfig(new RenderingConfig.Builder().addKey("abc").build()) + .build(); + + Parcel parcel = Parcel.obtain(); + data.writeToParcel(parcel, 0); + parcel.setDataPosition(0); + RenderInputParcel data2 = RenderInputParcel.CREATOR.createFromParcel(parcel); + RenderInput result = new RenderInput(data2); + + assertEquals(10, result.getWidth()); + assertEquals(20, result.getHeight()); + assertEquals(5, result.getRenderingConfigIndex()); + assertEquals("abc", result.getRenderingConfig().getKeys().get(0)); } /** @@ -64,12 +113,13 @@ public class OnDevicePersonalizationFrameworkClassesTest { */ @Test public void testRenderOutput() { - RenderOutput result = new RenderOutput.Builder().setContent("abc").build(); + RenderOutput data = new RenderOutput.Builder().setContent("abc").build(); + RenderOutputParcel result = new RenderOutputParcel(data); Parcel parcel = Parcel.obtain(); result.writeToParcel(parcel, 0); parcel.setDataPosition(0); - RenderOutput result2 = RenderOutput.CREATOR.createFromParcel(parcel); + RenderOutputParcel result2 = RenderOutputParcel.CREATOR.createFromParcel(parcel); assertEquals("abc", result2.getContent()); } @@ -78,31 +128,33 @@ public class OnDevicePersonalizationFrameworkClassesTest { * Tests that the DownloadOutput object serializes correctly. */ @Test - public void teetDownloadOutput() { - DownloadOutput result = new DownloadOutput.Builder() + public void teetDownloadCompletedOutput() { + DownloadCompletedOutput data = new DownloadCompletedOutput.Builder() .addRetainedKey("abc").addRetainedKey("def").build(); + DownloadCompletedOutputParcel result = + new DownloadCompletedOutputParcel(data); Parcel parcel = Parcel.obtain(); result.writeToParcel(parcel, 0); parcel.setDataPosition(0); - DownloadOutput result2 = DownloadOutput.CREATOR.createFromParcel(parcel); + DownloadCompletedOutputParcel result2 = + DownloadCompletedOutputParcel.CREATOR.createFromParcel(parcel); - assertEquals(result, result2); assertEquals("abc", result2.getRetainedKeys().get(0)); assertEquals("def", result2.getRetainedKeys().get(1)); } /** - * Tests that the WebViewEventInput object serializes correctly. + * Tests that the EventInput object serializes correctly. */ @Test - public void testWebViewEventInput() { + public void testEventInput() { PersistableBundle params = new PersistableBundle(); params.putInt("x", 3); ArrayList<ContentValues> rows = new ArrayList<>(); rows.add(new ContentValues()); rows.get(0).put("a", 5); - WebViewEventInput result = new WebViewEventInput.Builder() + EventInputParcel data = new EventInputParcel.Builder() .setParameters(params) .setRequestLogRecord( new RequestLogRecord.Builder() @@ -111,23 +163,24 @@ public class OnDevicePersonalizationFrameworkClassesTest { .build(); Parcel parcel = Parcel.obtain(); - result.writeToParcel(parcel, 0); + data.writeToParcel(parcel, 0); parcel.setDataPosition(0); - WebViewEventInput result2 = WebViewEventInput.CREATOR.createFromParcel(parcel); + EventInputParcel data2 = EventInputParcel.CREATOR.createFromParcel(parcel); + EventInput result = new EventInput(data2); - assertEquals(3, result2.getParameters().getInt("x")); + assertEquals(3, result.getParameters().getInt("x")); assertEquals( - 5, result2.getRequestLogRecord().getRows().get(0).getAsInteger("a").intValue()); + 5, result.getRequestLogRecord().getRows().get(0).getAsInteger("a").intValue()); } /** - * Tests that the WebViewEventOutput object serializes correctly. + * Tests that the EventOutput object serializes correctly. */ @Test - public void testWebViewEventOutput() { + public void testEventOutput() { ContentValues data = new ContentValues(); data.put("a", 3); - WebViewEventOutput result = new WebViewEventOutput.Builder() + EventOutput output = new EventOutput.Builder() .setEventLogRecord( new EventLogRecord.Builder() .setType(5) @@ -135,13 +188,13 @@ public class OnDevicePersonalizationFrameworkClassesTest { .setData(data) .build()) .build(); + EventOutputParcel result = new EventOutputParcel(output); Parcel parcel = Parcel.obtain(); result.writeToParcel(parcel, 0); parcel.setDataPosition(0); - WebViewEventOutput result2 = WebViewEventOutput.CREATOR.createFromParcel(parcel); + EventOutputParcel result2 = EventOutputParcel.CREATOR.createFromParcel(parcel); - assertEquals(result, result2); assertEquals(5, result2.getEventLogRecord().getType()); assertEquals(6, result2.getEventLogRecord().getRowIndex()); assertEquals(3, result2.getEventLogRecord().getData().getAsInteger("a").intValue()); @@ -166,10 +219,12 @@ public class OnDevicePersonalizationFrameworkClassesTest { row = new ContentValues(); row.put("b", 6); rows.add(row); - RequestLogRecord logRecord = new RequestLogRecord.Builder().setRows(rows).build(); + RequestLogRecord logRecord = new RequestLogRecord.Builder().setRows(rows) + .setRequestId(1).build(); assertEquals(2, logRecord.getRows().size()); assertEquals(5, logRecord.getRows().get(0).getAsInteger("a").intValue()); assertEquals(6, logRecord.getRows().get(1).getAsInteger("b").intValue()); + assertEquals(1, logRecord.getRequestId()); } /** Test for RequestLogRecord class. */ @@ -181,9 +236,12 @@ public class OnDevicePersonalizationFrameworkClassesTest { .setType(1) .setRowIndex(2) .setData(row) + .setRequestLogRecord(new RequestLogRecord.Builder().addRow(row).build()) .build(); assertEquals(1, logRecord.getType()); assertEquals(2, logRecord.getRowIndex()); assertEquals(5, logRecord.getData().getAsInteger("a").intValue()); + assertEquals(5, logRecord.getRequestLogRecord().getRows() + .get(0).getAsInteger("a").intValue()); } } diff --git a/tests/frameworktests/src/android/adservices/ondevicepersonalization/RemoteDataTest.java b/tests/frameworktests/src/android/adservices/ondevicepersonalization/RemoteDataTest.java index fcff3ba8..6a8bfbc9 100644 --- a/tests/frameworktests/src/android/adservices/ondevicepersonalization/RemoteDataTest.java +++ b/tests/frameworktests/src/android/adservices/ondevicepersonalization/RemoteDataTest.java @@ -56,11 +56,11 @@ public class RemoteDataTest { @Test public void testLookupError() { // Triggers an expected error in the mock service. - assertThrows(OnDevicePersonalizationException.class, () -> mRemoteData.get("z")); + assertThrows(IllegalStateException.class, () -> mRemoteData.get("z")); } @Test - public void testKeysetSuccess() throws OnDevicePersonalizationException { + public void testKeysetSuccess() { Set<String> expectedResult = new HashSet<>(); expectedResult.add("a"); expectedResult.add("b"); diff --git a/tests/frameworktests/src/android/federatedcompute/ExampleStoreQueryCallbackImplTest.java b/tests/frameworktests/src/android/federatedcompute/ExampleStoreQueryCallbackImplTest.java index 4c7bce8c..b2bfce95 100644 --- a/tests/frameworktests/src/android/federatedcompute/ExampleStoreQueryCallbackImplTest.java +++ b/tests/frameworktests/src/android/federatedcompute/ExampleStoreQueryCallbackImplTest.java @@ -137,6 +137,7 @@ public final class ExampleStoreQueryCallbackImplTest { // The second call shouldn't result in another call to the app's close() method. verify(mMockExampleStoreIterator, never()).next(any()); } + /** * Tests that additional calls to a the callback are passed through to the proxy. It will be in * charge of ignoring all but the first call. @@ -150,6 +151,7 @@ public final class ExampleStoreQueryCallbackImplTest { assertThat(adapter.onIteratorNextSuccess(new Bundle())).isTrue(); verify(mMockAidlExampleStoreIteratorCallback, times(2)).onIteratorNextSuccess(any()); } + /** * Tests that additional calls to a the callback are passed through to the proxy. It will be in * charge of ignoring all but the first call. @@ -164,6 +166,7 @@ public final class ExampleStoreQueryCallbackImplTest { verify(mMockAidlExampleStoreIteratorCallback, times(2)) .onIteratorNextFailure(eq(STATUS_INTERNAL_ERROR)); } + /** * Tests that additional calls to a the callback are passed through to the proxy. It will be in * charge of ignoring all but the first call. @@ -180,6 +183,7 @@ public final class ExampleStoreQueryCallbackImplTest { .onIteratorNextFailure(eq(STATUS_INTERNAL_ERROR)); verify(mMockAidlExampleStoreIteratorCallback, times(2)).onIteratorNextSuccess(any()); } + /** * Tests that additional calls to a the callback are passed through to the proxy. It will be in * charge of ignoring all but the first call. diff --git a/tests/frameworktests/src/android/federatedcompute/ExampleStoreServiceTest.java b/tests/frameworktests/src/android/federatedcompute/ExampleStoreServiceTest.java index 26363d37..c6a0a310 100644 --- a/tests/frameworktests/src/android/federatedcompute/ExampleStoreServiceTest.java +++ b/tests/frameworktests/src/android/federatedcompute/ExampleStoreServiceTest.java @@ -16,7 +16,7 @@ package android.federatedcompute; -import static android.federatedcompute.common.ClientConstants.EXTRA_COLLECTION_NAME; +import static android.federatedcompute.common.ClientConstants.EXTRA_TASK_NAME; import static android.federatedcompute.common.ClientConstants.STATUS_INTERNAL_ERROR; import static com.google.common.truth.Truth.assertThat; @@ -50,8 +50,7 @@ import javax.annotation.Nonnull; @RunWith(AndroidJUnit4.class) public class ExampleStoreServiceTest { - private static final String EXPECTED_COLLECTION_NAME = - "/federatedcompute.examplestoretest/test_collection"; + private static final String EXPECTED_TASK_NAME = "federated_task"; private static final Example EXAMPLE_PROTO_1 = Example.newBuilder() .setFeatures( @@ -70,7 +69,8 @@ public class ExampleStoreServiceTest { private int mCallbackErrorCode; private boolean mStartQueryCalled; private final CountDownLatch mLatch = new CountDownLatch(1); - private final TestExampleStoreService mTestExampleStoreService = new TestExampleStoreService(); + private final TestJavaExampleStoreService mTestExampleStoreService = + new TestJavaExampleStoreService(); private IExampleStoreService mBinder; @Before @@ -82,8 +82,8 @@ public class ExampleStoreServiceTest { @Test public void testStartQuerySuccess() throws Exception { Bundle bundle = new Bundle(); - bundle.putString(EXTRA_COLLECTION_NAME, EXPECTED_COLLECTION_NAME); - mBinder.startQuery(bundle, new TestExampleStoreServiceCallback()); + bundle.putString(EXTRA_TASK_NAME, EXPECTED_TASK_NAME); + mBinder.startQuery(bundle, new TestJavaExampleStoreServiceCallback()); mLatch.await(); assertTrue(mStartQueryCalled); assertThat(mCallbackResult).isInstanceOf(IteratorAdapter.class); @@ -92,34 +92,41 @@ public class ExampleStoreServiceTest { @Test public void testStartQueryFailure() throws Exception { Bundle bundle = new Bundle(); - bundle.putString(EXTRA_COLLECTION_NAME, "wrong_collection"); - mBinder.startQuery(bundle, new TestExampleStoreServiceCallback()); + bundle.putString(EXTRA_TASK_NAME, "wrong_taskName"); + mBinder.startQuery(bundle, new TestJavaExampleStoreServiceCallback()); mLatch.await(); assertTrue(mStartQueryCalled); assertThat(mCallbackErrorCode).isEqualTo(STATUS_INTERNAL_ERROR); + assertThat(mCallbackErrorCode).isEqualTo(STATUS_INTERNAL_ERROR); } - class TestExampleStoreService extends ExampleStoreService { + class TestJavaExampleStoreService extends ExampleStoreService { @Override public void startQuery(@Nonnull Bundle params, @Nonnull QueryCallback callback) { mStartQueryCalled = true; - String collection = params.getString(EXTRA_COLLECTION_NAME); - if (!collection.equals(EXPECTED_COLLECTION_NAME)) { + String taskName = params.getString(EXTRA_TASK_NAME); + if (!taskName.equals(EXPECTED_TASK_NAME)) { callback.onStartQueryFailure(STATUS_INTERNAL_ERROR); return; } callback.onStartQuerySuccess( - new ListExampleStoreIterator(ImmutableList.of(EXAMPLE_PROTO_1))); + new ListJavaExampleStoreIterator(ImmutableList.of(EXAMPLE_PROTO_1))); + } + + @Override + protected boolean checkCallerPermission() { + return true; } } + /** * A simple {@link ExampleStoreIterator} that returns the contents of the {@link List} it's * constructed with. */ - private static class ListExampleStoreIterator implements ExampleStoreIterator { + private static class ListJavaExampleStoreIterator implements ExampleStoreIterator { private final Iterator<Example> mExampleIterator; - ListExampleStoreIterator(List<Example> examples) { + ListJavaExampleStoreIterator(List<Example> examples) { mExampleIterator = examples.iterator(); } @@ -132,7 +139,7 @@ public class ExampleStoreServiceTest { public void close() {} } - class TestExampleStoreServiceCallback extends IExampleStoreCallback.Stub { + class TestJavaExampleStoreServiceCallback extends IExampleStoreCallback.Stub { @Override public void onStartQuerySuccess(IExampleStoreIterator iterator) { mCallbackResult = iterator; diff --git a/tests/frameworktests/src/android/federatedcompute/FederatedComputeManagerTest.java b/tests/frameworktests/src/android/federatedcompute/FederatedComputeManagerTest.java new file mode 100644 index 00000000..214b3ad9 --- /dev/null +++ b/tests/frameworktests/src/android/federatedcompute/FederatedComputeManagerTest.java @@ -0,0 +1,317 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package android.federatedcompute; + +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import android.content.Context; +import android.content.ContextWrapper; +import android.content.Intent; +import android.content.ServiceConnection; +import android.content.pm.PackageManager; +import android.content.pm.ResolveInfo; +import android.content.pm.ServiceInfo; +import android.federatedcompute.aidl.IFederatedComputeCallback; +import android.federatedcompute.aidl.IFederatedComputeService; +import android.federatedcompute.common.ScheduleFederatedComputeRequest; +import android.federatedcompute.common.TrainingOptions; +import android.os.IBinder; +import android.os.OutcomeReceiver; +import android.os.RemoteException; + +import androidx.test.core.app.ApplicationProvider; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; + +@RunWith(Parameterized.class) +public class FederatedComputeManagerTest { + + private final Context mContext = + spy(new MyTestContext(ApplicationProvider.getApplicationContext())); + + @Parameterized.Parameter(0) + public String scenario; + + @Parameterized.Parameter(1) + public ScheduleFederatedComputeRequest request; + + @Parameterized.Parameter(2) + public String populationName; + + @Parameterized.Parameter(3) + public IFederatedComputeService iFederatedComputeService; + + @Mock private PackageManager mMockPackageManager; + @Mock private IBinder mMockIBinder; + @Mock private IFederatedComputeService mMockIService; + + @Parameterized.Parameters + public static Collection<Object[]> data() { + return Arrays.asList( + new Object[][] { + {"schedule-allNull", null, null, null}, + { + "schedule-default-iService", + new ScheduleFederatedComputeRequest.Builder() + .setTrainingOptions(new TrainingOptions.Builder().build()) + .build(), + null, + new IFederatedComputeService.Default() + }, + { + "schedule-mockIService-RemoteException", + new ScheduleFederatedComputeRequest.Builder() + .setTrainingOptions(new TrainingOptions.Builder().build()) + .build(), + null, + null /* mock will be returned */ + }, + { + "schedule-mockIService-onSuccess", + new ScheduleFederatedComputeRequest.Builder() + .setTrainingOptions(new TrainingOptions.Builder().build()) + .build(), + null, + null /* mock will be returned */ + }, + { + "schedule-mockIService-onFailure", + new ScheduleFederatedComputeRequest.Builder() + .setTrainingOptions(new TrainingOptions.Builder().build()) + .build(), + null, + null /* mock will be returned */ + }, + {"cancel-allNull", null, null, null}, + { + "cancel-default-iService", + null, + "testPopulation", + new IFederatedComputeService.Default() + }, + { + "cancel-mockIService-RemoteException", + null, + "testPopulation", + null /* mock will be returned */ + }, + { + "cancel-mockIService-onSuccess", + null, + "testPopulation", + null /* mock will be returned */ + }, + { + "cancel-mockIService-onFailure", + null, + "testPopulation", + null /* mock will be returned */ + }, + }); + } + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + ResolveInfo resolveInfo = new ResolveInfo(); + ServiceInfo serviceInfo = new ServiceInfo(); + serviceInfo.name = "TestName"; + serviceInfo.packageName = "com.android.federatedcompute.services"; + resolveInfo.serviceInfo = serviceInfo; + when(mMockPackageManager.queryIntentServices(any(), anyInt())) + .thenReturn(List.of(resolveInfo)); + when(mMockIBinder.queryLocalInterface(any())).thenReturn(iFederatedComputeService); + } + + @Test + public void testScheduleFederatedCompute() throws RemoteException { + FederatedComputeManager manager = new FederatedComputeManager(mContext); + OutcomeReceiver<Object, Exception> spyCallback; + + switch (scenario) { + case "schedule-allNull": + assertThrows( + NullPointerException.class, () -> manager.schedule(request, null, null)); + break; + case "schedule-default-iService": + manager.schedule(request, Executors.newSingleThreadExecutor(), null); + break; + case "schedule-mockIService-RemoteException": + when(mMockIBinder.queryLocalInterface(any())).thenReturn(mMockIService); + doThrow(new RemoteException()).when(mMockIService).schedule(any(), any(), any()); + spyCallback = spy(new MyTestCallback()); + + manager.schedule(request, Runnable::run, spyCallback); + + verify(mContext, times(1)).bindService(any(), anyInt(), any(), any()); + verify(spyCallback, times(1)).onError(any(RemoteException.class)); + verify(mContext, times(1)).unbindService(any()); + break; + case "schedule-mockIService-onSuccess": + when(mMockIBinder.queryLocalInterface(any())).thenReturn(mMockIService); + doAnswer( + invocation -> { + IFederatedComputeCallback federatedComputeCallback = + invocation.getArgument(2); + federatedComputeCallback.onSuccess(); + return null; + }) + .when(mMockIService) + .schedule(any(), any(), any()); + spyCallback = spy(new MyTestCallback()); + + manager.schedule(request, Runnable::run, spyCallback); + + verify(mContext, times(1)).bindService(any(), anyInt(), any(), any()); + verify(spyCallback, times(1)).onResult(isNull()); + verify(mContext, times(1)).unbindService(any()); + break; + case "schedule-mockIService-onFailure": + when(mMockIBinder.queryLocalInterface(any())).thenReturn(mMockIService); + doAnswer( + invocation -> { + IFederatedComputeCallback federatedComputeCallback = + invocation.getArgument(2); + federatedComputeCallback.onFailure(1); + return null; + }) + .when(mMockIService) + .schedule(any(), any(), any()); + spyCallback = spy(new MyTestCallback()); + + manager.schedule(request, Runnable::run, spyCallback); + + verify(mContext, times(1)).bindService(any(), anyInt(), any(), any()); + verify(spyCallback, times(1)).onError(any(FederatedComputeException.class)); + verify(mContext, times(1)).unbindService(any()); + break; + case "cancel-allNull": + assertThrows( + NullPointerException.class, + () -> manager.cancel(populationName, null, null)); + break; + case "cancel-default-iService": + manager.cancel(populationName, Executors.newSingleThreadExecutor(), null); + break; + case "cancel-mockIService-RemoteException": + when(mMockIBinder.queryLocalInterface(any())).thenReturn(mMockIService); + doThrow(new RemoteException()).when(mMockIService).cancel(any(), any(), any()); + spyCallback = spy(new MyTestCallback()); + + manager.cancel(populationName, Runnable::run, spyCallback); + + verify(mContext, times(1)).bindService(any(), anyInt(), any(), any()); + verify(spyCallback, times(1)).onError(any(RemoteException.class)); + verify(mContext, times(1)).unbindService(any()); + break; + case "cancel-mockIService-onSuccess": + when(mMockIBinder.queryLocalInterface(any())).thenReturn(mMockIService); + doAnswer( + invocation -> { + IFederatedComputeCallback federatedComputeCallback = + invocation.getArgument(2); + federatedComputeCallback.onSuccess(); + return null; + }) + .when(mMockIService) + .cancel(any(), any(), any()); + spyCallback = spy(new MyTestCallback()); + + manager.cancel(populationName, Runnable::run, spyCallback); + + verify(mContext, times(1)).bindService(any(), anyInt(), any(), any()); + verify(spyCallback, times(1)).onResult(isNull()); + verify(mContext, times(1)).unbindService(any()); + break; + case "cancel-mockIService-onFailure": + when(mMockIBinder.queryLocalInterface(any())).thenReturn(mMockIService); + doAnswer( + invocation -> { + IFederatedComputeCallback federatedComputeCallback = + invocation.getArgument(2); + federatedComputeCallback.onFailure(1); + return null; + }) + .when(mMockIService) + .cancel(any(), any(), any()); + spyCallback = spy(new MyTestCallback()); + + manager.cancel(populationName, Runnable::run, spyCallback); + + verify(mContext, times(1)).bindService(any(), anyInt(), any(), any()); + verify(spyCallback, times(1)).onError(any(FederatedComputeException.class)); + verify(mContext, times(1)).unbindService(any()); + break; + default: + break; + } + } + + public class MyTestContext extends ContextWrapper { + + MyTestContext(Context context) { + super(context); + } + + @Override + public PackageManager getPackageManager() { + return mMockPackageManager != null ? mMockPackageManager : super.getPackageManager(); + } + + @Override + public boolean bindService( + Intent service, int flags, Executor executor, ServiceConnection conn) { + executor.execute( + () -> { + conn.onServiceConnected(null, mMockIBinder); + }); + return true; + } + + public void unbindService(ServiceConnection conn) {} + } + + public class MyTestCallback implements OutcomeReceiver<Object, Exception> { + + @Override + public void onResult(Object o) {} + + @Override + public void onError(Exception error) { + OutcomeReceiver.super.onError(error); + } + } +} diff --git a/tests/frameworktests/src/android/federatedcompute/ResultHandlingServiceTest.java b/tests/frameworktests/src/android/federatedcompute/ResultHandlingServiceTest.java index b5726e92..7ce3b804 100644 --- a/tests/frameworktests/src/android/federatedcompute/ResultHandlingServiceTest.java +++ b/tests/frameworktests/src/android/federatedcompute/ResultHandlingServiceTest.java @@ -26,24 +26,25 @@ import static org.junit.Assert.assertTrue; import android.federatedcompute.aidl.IFederatedComputeCallback; import android.federatedcompute.aidl.IResultHandlingService; +import android.federatedcompute.common.ClientConstants; import android.federatedcompute.common.ExampleConsumption; import android.federatedcompute.common.TrainingInterval; import android.federatedcompute.common.TrainingOptions; +import android.os.Bundle; import androidx.test.ext.junit.runners.AndroidJUnit4; -import com.google.common.collect.ImmutableList; - import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; -import java.util.List; +import java.util.ArrayList; import java.util.concurrent.CountDownLatch; import java.util.function.Consumer; @RunWith(AndroidJUnit4.class) public final class ResultHandlingServiceTest { + private static final String TASK_NAME = "task-name"; private static final String TEST_POPULATION = "testPopulation"; private static final int JOB_ID = 12345; private static final byte[] SELECTION_CRITERIA = new byte[] {10, 0, 1}; @@ -55,13 +56,8 @@ public final class ResultHandlingServiceTest { .setSchedulingMode(SCHEDULING_MODE_ONE_TIME) .build()) .build(); - private static final ImmutableList<ExampleConsumption> EXAMPLE_CONSUMPTIONS = - ImmutableList.of( - new ExampleConsumption.Builder() - .setCollectionName("collection") - .setExampleCount(100) - .setSelectionCriteria(SELECTION_CRITERIA) - .build()); + private static final ArrayList<ExampleConsumption> EXAMPLE_CONSUMPTIONS = + createExampleConsumptionList(); private boolean mSuccess = false; private boolean mHandleResultCalled = false; @@ -80,8 +76,14 @@ public final class ResultHandlingServiceTest { @Test public void testHandleResult_success() throws Exception { - mBinder.handleResult( - TRAINING_OPTIONS, true, EXAMPLE_CONSUMPTIONS, new TestFederatedComputeCallback()); + Bundle input = new Bundle(); + input.putString(ClientConstants.EXTRA_TASK_NAME, TASK_NAME); + input.putString(ClientConstants.EXTRA_POPULATION_NAME, TEST_POPULATION); + input.putInt(ClientConstants.EXTRA_COMPUTATION_RESULT, STATUS_SUCCESS); + input.putParcelableArrayList( + ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, EXAMPLE_CONSUMPTIONS); + + mBinder.handleResult(input, new TestFederatedComputeCallback()); mLatch.await(); assertTrue(mHandleResultCalled); @@ -90,7 +92,13 @@ public final class ResultHandlingServiceTest { @Test public void testHandleResult_failure() throws Exception { - mBinder.handleResult(TRAINING_OPTIONS, true, null, new TestFederatedComputeCallback()); + Bundle input = new Bundle(); + input.putString(ClientConstants.EXTRA_TASK_NAME, TASK_NAME); + input.putString(ClientConstants.EXTRA_POPULATION_NAME, TEST_POPULATION); + input.putInt(ClientConstants.EXTRA_COMPUTATION_RESULT, STATUS_SUCCESS); + input.putParcelableArrayList(ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, null); + + mBinder.handleResult(input, new TestFederatedComputeCallback()); mLatch.await(); assertTrue(mHandleResultCalled); @@ -99,12 +107,12 @@ public final class ResultHandlingServiceTest { class TestResultHandlingService extends ResultHandlingService { @Override - public void handleResult( - TrainingOptions trainingOptions, - boolean success, - List<ExampleConsumption> exampleConsumptionList, - Consumer<Integer> callback) { + public void handleResult(Bundle input, Consumer<Integer> callback) { mHandleResultCalled = true; + ArrayList<ExampleConsumption> exampleConsumptionList = + input.getParcelableArrayList( + ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, + ExampleConsumption.class); if (exampleConsumptionList == null || exampleConsumptionList.isEmpty()) { callback.accept(STATUS_INTERNAL_ERROR); return; @@ -113,6 +121,17 @@ public final class ResultHandlingServiceTest { } } + private static ArrayList<ExampleConsumption> createExampleConsumptionList() { + ArrayList<ExampleConsumption> exampleList = new ArrayList<>(); + exampleList.add( + new ExampleConsumption.Builder() + .setTaskName("taskName") + .setExampleCount(100) + .setSelectionCriteria(SELECTION_CRITERIA) + .build()); + return exampleList; + } + class TestFederatedComputeCallback extends IFederatedComputeCallback.Stub { @Override public void onSuccess() { diff --git a/tests/frameworktests/src/android/federatedcompute/common/ExampleConsumptionTest.java b/tests/frameworktests/src/android/federatedcompute/common/ExampleConsumptionTest.java index db394fe4..81daa6cf 100644 --- a/tests/frameworktests/src/android/federatedcompute/common/ExampleConsumptionTest.java +++ b/tests/frameworktests/src/android/federatedcompute/common/ExampleConsumptionTest.java @@ -35,49 +35,37 @@ import javax.annotation.Nullable; public final class ExampleConsumptionTest { @Test - public void testBuilder_emptyCollectionName() { + public void testBuilder_emptyTaskName() { assertThrows( IllegalArgumentException.class, () -> new ExampleConsumption.Builder() - .setCollectionName("") + .setTaskName("") .setExampleCount(10) .setSelectionCriteria(new byte[] {10, 0, 1}) .build()); } @Test - public void testBuilder_nullCollectionName() { + public void testBuilder_nullTaskName() { assertThrows( IllegalArgumentException.class, () -> new ExampleConsumption.Builder() - .setCollectionName(null) + .setTaskName(null) .setExampleCount(10) .setSelectionCriteria(new byte[] {10, 0, 1}) .build()); } @Test - public void testBuilder_nullSelectionCriteria() { - assertThrows( - NullPointerException.class, - () -> - new ExampleConsumption.Builder() - .setCollectionName("collection_name") - .setExampleCount(10) - .setSelectionCriteria(null) - .build()); - } - - @Test public void testBuilder_normalCaseWithoutResumptionToken() { - String collectionName = "my_collection"; + String taskName = "my_task"; byte[] selectionCriteria = new byte[] {10, 0, 1}; int exampleCount = 10; ExampleConsumption consumption = - createExampleConsumption(collectionName, selectionCriteria, exampleCount, null); - assertThat(consumption.getCollectionName()).isEqualTo(collectionName); + createExampleConsumption(taskName, selectionCriteria, exampleCount, null); + assertThat(consumption.getTaskName()).isEqualTo(taskName); assertThat(consumption.getExampleCount()).isEqualTo(exampleCount); assertThat(ByteString.copyFrom(consumption.getSelectionCriteria())) .isEqualTo(ByteString.copyFrom(selectionCriteria)); @@ -86,14 +74,14 @@ public final class ExampleConsumptionTest { @Test public void testBuilder_normalCaseWithResumptionToken() { - String collectionName = "my_collection"; + String taskName = "my_task"; byte[] selectionCriteria = new byte[] {10, 0, 1}; int exampleCount = 10; byte[] resumptionToken = new byte[] {25, 10, 4, 56}; ExampleConsumption consumption = createExampleConsumption( - collectionName, selectionCriteria, exampleCount, resumptionToken); - assertThat(consumption.getCollectionName()).isEqualTo(collectionName); + taskName, selectionCriteria, exampleCount, resumptionToken); + assertThat(consumption.getTaskName()).isEqualTo(taskName); assertThat(consumption.getExampleCount()).isEqualTo(exampleCount); assertThat(ByteString.copyFrom(consumption.getSelectionCriteria())) .isEqualTo(ByteString.copyFrom(selectionCriteria)); @@ -103,13 +91,13 @@ public final class ExampleConsumptionTest { @Test public void testWriteToParcel() { - String collectionName = "my_collection"; + String taskName = "my_task"; byte[] selectionCriteria = new byte[] {10, 0, 1}; int exampleCount = 10; byte[] resumptionToken = new byte[] {25, 10, 4, 56}; ExampleConsumption consumption = createExampleConsumption( - collectionName, selectionCriteria, exampleCount, resumptionToken); + taskName, selectionCriteria, exampleCount, resumptionToken); Parcel parcel = Parcel.obtain(); consumption.writeToParcel(parcel, 0); @@ -118,7 +106,7 @@ public final class ExampleConsumptionTest { parcel.setDataPosition(0); ExampleConsumption recoveredConsumption = ExampleConsumption.CREATOR.createFromParcel(parcel); - assertThat(recoveredConsumption.getCollectionName()).isEqualTo(collectionName); + assertThat(recoveredConsumption.getTaskName()).isEqualTo(taskName); assertThat(recoveredConsumption.getExampleCount()).isEqualTo(exampleCount); assertThat(ByteString.copyFrom(recoveredConsumption.getSelectionCriteria())) .isEqualTo(ByteString.copyFrom(selectionCriteria)); @@ -127,13 +115,13 @@ public final class ExampleConsumptionTest { } private static ExampleConsumption createExampleConsumption( - String collectionName, + String taskName, byte[] selectionCriteria, int exampleCount, @Nullable byte[] resumptionToken) { ExampleConsumption.Builder builder = new ExampleConsumption.Builder() - .setCollectionName(collectionName) + .setTaskName(taskName) .setSelectionCriteria(selectionCriteria) .setExampleCount(exampleCount); if (resumptionToken != null) { diff --git a/tests/frameworktests/src/android/federatedcompute/common/TrainingOptionsTest.java b/tests/frameworktests/src/android/federatedcompute/common/TrainingOptionsTest.java index ac878089..5eeda233 100644 --- a/tests/frameworktests/src/android/federatedcompute/common/TrainingOptionsTest.java +++ b/tests/frameworktests/src/android/federatedcompute/common/TrainingOptionsTest.java @@ -71,25 +71,25 @@ public final class TrainingOptionsTest { } @Test - public void testNullServerAddressIsAllowed() { - TrainingOptions options = - new TrainingOptions.Builder() - .setPopulationName(POPULATION_NAME) - .setServerAddress(null) - .build(); - - assertThat(options.getServerAddress()).isNull(); + public void testNullServerAddressIsNotAllowed() { + assertThrows( + IllegalArgumentException.class, + () -> + new TrainingOptions.Builder() + .setPopulationName(POPULATION_NAME) + .setServerAddress(null) + .build()); } @Test - public void testEmptyServerAddressIsAllowed() { - TrainingOptions options = - new TrainingOptions.Builder() - .setPopulationName(POPULATION_NAME) - .setServerAddress("") - .build(); - - assertThat(options.getServerAddress()).isEmpty(); + public void testEmptyServerAddressIsNotAllowed() { + assertThrows( + IllegalArgumentException.class, + () -> + new TrainingOptions.Builder() + .setPopulationName(POPULATION_NAME) + .setServerAddress("") + .build()); } @Test diff --git a/tests/frameworktests/src/com/android/federatedcompute/internal/util/AndroidServiceBinderTest.java b/tests/frameworktests/src/com/android/federatedcompute/internal/util/AndroidServiceBinderTest.java new file mode 100644 index 00000000..7afd4165 --- /dev/null +++ b/tests/frameworktests/src/com/android/federatedcompute/internal/util/AndroidServiceBinderTest.java @@ -0,0 +1,133 @@ +/* + * Copyright 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.internal.util; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; + +import android.adservices.ondevicepersonalization.aidl.IOnDevicePersonalizationManagingService; +import android.content.Context; +import android.federatedcompute.aidl.IFederatedComputeService; + +import androidx.test.core.app.ApplicationProvider; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.List; + +@RunWith(JUnit4.class) +public class AndroidServiceBinderTest { + public static final String ODP_MANAGING_SERVICE_INTENT_ACTION = + "android.OnDevicePersonalizationService"; + public static final String ODP_MANAGING_SERVICE_PACKAGE = + "com.android.ondevicepersonalization.services"; + public static final String ALT_ODP_MANAGING_SERVICE_PACKAGE = + "com.google.android.ondevicepersonalization.services"; + public static final String INCORRECT_PACKAGE = + "NOT.android.ondevicepersonalization.or.federatedcompute.services"; + private static final String FEDERATED_COMPUTATION_SERVICE_INTENT_ACTION = + "android.federatedcompute.FederatedComputeService"; + private static final String FEDERATED_COMPUTATION_SERVICE_PACKAGE = + "com.android.federatedcompute.services"; + private static final String GOOGLE_RENAMED_FEDERATED_COMPUTATION_SERVICE_PACKAGE = + "com.google.android.federatedcompute"; + private final Context mSpyContext = spy(ApplicationProvider.getApplicationContext()); + + @Test + public void testOdpServiceBinding() { + AbstractServiceBinder<IOnDevicePersonalizationManagingService> serviceBinder = + AbstractServiceBinder.getServiceBinderByIntent( + mSpyContext, + ODP_MANAGING_SERVICE_INTENT_ACTION, + List.of( + ODP_MANAGING_SERVICE_PACKAGE, + ALT_ODP_MANAGING_SERVICE_PACKAGE), + IOnDevicePersonalizationManagingService.Stub::asInterface); + + final IOnDevicePersonalizationManagingService service = + serviceBinder.getService(Runnable::run); + assertNotNull(service); + } + + @Test + public void testServiceBindingWithFlags() { + AbstractServiceBinder<IOnDevicePersonalizationManagingService> serviceBinder = + AbstractServiceBinder.getServiceBinderByIntent( + mSpyContext, + ODP_MANAGING_SERVICE_INTENT_ACTION, + List.of( + ODP_MANAGING_SERVICE_PACKAGE, + ALT_ODP_MANAGING_SERVICE_PACKAGE), + Context.BIND_ALLOW_ACTIVITY_STARTS, + IOnDevicePersonalizationManagingService.Stub::asInterface); + + final IOnDevicePersonalizationManagingService service = + serviceBinder.getService(Runnable::run); + verify(mSpyContext) + .bindService( + any(), + eq(Context.BIND_ALLOW_ACTIVITY_STARTS | Context.BIND_AUTO_CREATE), + any(), + any()); + assertNotNull(service); + } + + @Test + public void testFcpServiceBinding() { + AbstractServiceBinder<IFederatedComputeService> serviceBinder = + AbstractServiceBinder.getServiceBinderByIntent( + mSpyContext, + FEDERATED_COMPUTATION_SERVICE_INTENT_ACTION, + List.of( + FEDERATED_COMPUTATION_SERVICE_PACKAGE, + GOOGLE_RENAMED_FEDERATED_COMPUTATION_SERVICE_PACKAGE), + IFederatedComputeService.Stub::asInterface); + + final IFederatedComputeService service = serviceBinder.getService(Runnable::run); + assertNotNull(service); + } + + @Test + public void testOdpServiceBindingWrongPackage() { + AbstractServiceBinder<IOnDevicePersonalizationManagingService> serviceBinder = + AbstractServiceBinder.getServiceBinderByIntent( + mSpyContext, + ODP_MANAGING_SERVICE_INTENT_ACTION, + INCORRECT_PACKAGE, + IOnDevicePersonalizationManagingService.Stub::asInterface); + + assertThrows(IllegalStateException.class, () -> serviceBinder.getService(Runnable::run)); + } + + @Test + public void testFcpServiceBindingWrongPackage() { + AbstractServiceBinder<IFederatedComputeService> serviceBinder = + AbstractServiceBinder.getServiceBinderByIntent( + mSpyContext, + FEDERATED_COMPUTATION_SERVICE_INTENT_ACTION, + INCORRECT_PACKAGE, + IFederatedComputeService.Stub::asInterface); + + assertThrows(IllegalStateException.class, () -> serviceBinder.getService(Runnable::run)); + } +} diff --git a/tests/manualtests/Android.bp b/tests/manualtests/Android.bp index c39e9698..09142797 100644 --- a/tests/manualtests/Android.bp +++ b/tests/manualtests/Android.bp @@ -23,6 +23,7 @@ android_test { ":ondevicepersonalization-sources", ":ondevicepersonalization-fbs", ":chronicle-sources", + ":statslog-ondevicepersonalization-java-gen", ], libs: [ "android.test.base", @@ -33,6 +34,7 @@ android_test { "kotlin-annotations", "truth-prebuilt", "framework-ondevicepersonalization.impl", + "framework-statsd.stubs.module_lib", // For WW logging ], static_libs: [ "androidx.test.ext.junit", diff --git a/tests/manualtests/src/com/test/TestPersonalizationHandler.java b/tests/manualtests/src/com/test/TestPersonalizationHandler.java index 7091ba2d..f0406e76 100644 --- a/tests/manualtests/src/com/test/TestPersonalizationHandler.java +++ b/tests/manualtests/src/com/test/TestPersonalizationHandler.java @@ -16,8 +16,8 @@ package com.test; -import android.adservices.ondevicepersonalization.DownloadInput; -import android.adservices.ondevicepersonalization.DownloadOutput; +import android.adservices.ondevicepersonalization.DownloadCompletedInput; +import android.adservices.ondevicepersonalization.DownloadCompletedOutput; import android.adservices.ondevicepersonalization.IsolatedWorker; import android.adservices.ondevicepersonalization.KeyValueStore; import android.util.Log; @@ -39,15 +39,17 @@ public class TestPersonalizationHandler implements IsolatedWorker { } @Override - public void onDownload(DownloadInput input, Consumer<DownloadOutput> consumer) { + public void onDownloadCompleted( + DownloadCompletedInput input, + Consumer<DownloadCompletedOutput> consumer) { try { Log.d(TAG, "Starting filterData."); Log.d(TAG, "Existing keyExtra: " + Arrays.toString(mRemoteData.get("keyExtra"))); Log.d(TAG, "Existing keySet: " + mRemoteData.keySet()); - DownloadOutput result = - new DownloadOutput.Builder() + DownloadCompletedOutput result = + new DownloadCompletedOutput.Builder() .setRetainedKeys(getFilteredKeys(input.getData())) .build(); consumer.accept(result); diff --git a/tests/perftests/scenarios/Android.bp b/tests/perftests/scenarios/Android.bp index 1165e8e3..088b58f8 100644 --- a/tests/perftests/scenarios/Android.bp +++ b/tests/perftests/scenarios/Android.bp @@ -19,7 +19,27 @@ package { java_library { name: "ondevicepersonalization-test-scenarios", srcs: [ - "src/**/*.java", + "src/android/ondevicepersonalization/**/*.java", + ], + libs: [ + "framework-ondevicepersonalization.impl", + ], + static_libs: [ + "androidx.media_media", + "androidx.test.rules", + "androidx.test.runner", + "androidx.test.uiautomator_uiautomator", + "common-platform-scenarios", + "platform-test-annotations", + "platform-test-options", + "platform-test-rules", + ], +} + +java_library { + name: "federatedcompute-test-scenarios", + srcs: [ + "src/android/federatedcompute/**/*.java", ], libs: [ "framework-ondevicepersonalization.impl", diff --git a/tests/perftests/scenarios/src/android/federatedcompute/test/scenario/federatedcompute/ScheduleAndForceTraining.java b/tests/perftests/scenarios/src/android/federatedcompute/test/scenario/federatedcompute/ScheduleAndForceTraining.java new file mode 100644 index 00000000..d66ca093 --- /dev/null +++ b/tests/perftests/scenarios/src/android/federatedcompute/test/scenario/federatedcompute/ScheduleAndForceTraining.java @@ -0,0 +1,63 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package android.federatedcompute.test.scenario.federatedcompute; + +import android.platform.test.scenario.annotation.Scenario; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; + +@Scenario +@RunWith(JUnit4.class) +/** + * Schedule a federatedCompute training task from Odp Test app UI + * Then force the task execution through ADB commands + */ +public class ScheduleAndForceTraining { + private TestHelper mTestHelper = new TestHelper(); + + /** Prepare the device before entering the test class */ + @BeforeClass + public static void prepareDevice() { + TestHelper.initialize(); + } + + @Before + public void setup() throws IOException { + mTestHelper.pressHome(); + mTestHelper.openTestApp(); + mTestHelper.inputPopulationForScheduleTraining(); + } + + @Test + public void testScheduleAndForceTraining() throws IOException { + mTestHelper.clickScheduleTraining(); + mTestHelper.forceExecuteTrainingTaskForTestApp(); + } + + /** Return device to original state after test exeuction */ + @AfterClass + public static void tearDown() throws IOException { + TestHelper.pressHome(); + TestHelper.wrapUp(); + } +} diff --git a/tests/perftests/scenarios/src/android/federatedcompute/test/scenario/federatedcompute/TestHelper.java b/tests/perftests/scenarios/src/android/federatedcompute/test/scenario/federatedcompute/TestHelper.java new file mode 100644 index 00000000..1a4fc6a3 --- /dev/null +++ b/tests/perftests/scenarios/src/android/federatedcompute/test/scenario/federatedcompute/TestHelper.java @@ -0,0 +1,176 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package android.federatedcompute.test.scenario.federatedcompute; + +import static org.junit.Assert.assertNotNull; + +import android.os.SystemClock; + +import androidx.test.platform.app.InstrumentationRegistry; +import androidx.test.uiautomator.By; +import androidx.test.uiautomator.UiDevice; +import androidx.test.uiautomator.UiObject2; +import androidx.test.uiautomator.Until; + +import org.junit.Assert; + +import java.io.IOException; + +/** Helper class for interacting with federatedcompute in perf tests. */ +public class TestHelper { + private static UiDevice sUiDevice; + private static final long UI_FIND_RESOURCE_TIMEOUT = 5000; + private static final long TRAINING_TASK_COMPLETION_TIMEOUT = 120_000; + private static final String ODP_CLIENT_TEST_APP_PACKAGE_NAME = "com.example.odpclient"; + private static final String SCHEDULE_TRAINING_BUTTON_RESOURCE_ID = "schedule_training_button"; + private static final String SCHEDULE_TRAINING_TEXT_BOX_RESOURCE_ID = + "schedule_training_text_box"; + private static final String ODP_TEST_APP_POPULATION_NAME = "criteo_app_test_task"; + private static final String ODP_TEST_APP_TRAINING_TASK_JOB_ID = "1586947961"; + private static final String FEDERATED_TRAINING_JOB_SUCCESS_LOG = + "FederatedJobService - Federated computation job 1586947961 is done"; + + public static void pressHome() { + getUiDevice().pressHome(); + } + + /** Commands to prepare the device, odp module, fcp module before testing. */ + public static void initialize() { + executeShellCommand( + "device_config set_sync_disabled_for_tests persistent"); + disableGlobalKillSwitch(); + disableFederatedComputeKillSwitch(); + executeShellCommand( + "device_config put on_device_personalization " + + "enable_ondevicepersonalization_apis true"); + executeShellCommand( + "device_config put on_device_personalization " + + "enable_personalization_status_override true"); + executeShellCommand( + "device_config put on_device_personalization " + + "personalization_status_override_value true"); + executeShellCommand("setprop log.tag.ondevicepersonalization VERBOSE"); + executeShellCommand("setprop log.tag.federatedcompute VERBOSE"); + executeShellCommand( + "am broadcast -a android.intent.action.BOOT_COMPLETED -p " + + "com.google.android.ondevicepersonalization.services"); + executeShellCommand( + "am broadcast -a android.intent.action.BOOT_COMPLETED -p " + + "com.google.android.federatedcompute"); + } + + /** Commands to return device to original state */ + public static void wrapUp() { + executeShellCommand( + "device_config set_sync_disabled_for_tests none"); + } + + /** Open ODP client test app. */ + public void openTestApp() throws IOException { + sUiDevice.executeShellCommand( + "am start " + ODP_CLIENT_TEST_APP_PACKAGE_NAME + "/.MainActivity"); + } + + /** Put the default population name down for training */ + public void inputPopulationForScheduleTraining() { + UiObject2 scheduleTrainingTextBox = getScheduleTrainingTextBox(); + assertNotNull("Schedule Training text box not found", scheduleTrainingTextBox); + scheduleTrainingTextBox.setText(ODP_TEST_APP_POPULATION_NAME); + } + + /** Click Schedule Training button. */ + public void clickScheduleTraining() { + UiObject2 scheduleTrainingButton = getScheduleTrainingButton(); + assertNotNull("Schedule Training button not found", scheduleTrainingButton); + scheduleTrainingButton.click(); + SystemClock.sleep(10000); + } + + /** Force the JobScheduler to execute the training task, bypassing all constraints */ + public void forceExecuteTrainingTaskForTestApp() throws IOException { + executeShellCommand("logcat -c"); // Cleans the log buffer + executeShellCommand("logcat -G 32M"); // Set log buffer to 32MB + executeShellCommand( + "cmd jobscheduler run -f com.google.android.federatedcompute " + + ODP_TEST_APP_TRAINING_TASK_JOB_ID); + SystemClock.sleep(10000); + + boolean foundTrainingJobSuccessLog = findLog( + FEDERATED_TRAINING_JOB_SUCCESS_LOG, + TRAINING_TASK_COMPLETION_TIMEOUT, + 10000); + + if (!foundTrainingJobSuccessLog) { + Assert.fail(String.format( + "Failed to find federated training job success log within test window %d ms", + TRAINING_TASK_COMPLETION_TIMEOUT)); + } + } + + /** Attempt to find a specific log entry within the timeout window */ + private boolean findLog(final String targetLog, long timeoutMillis, + long queryIntervalMillis) throws IOException { + + long startTime = System.currentTimeMillis(); + while (System.currentTimeMillis() - startTime < timeoutMillis) { + if (getUiDevice().executeShellCommand("logcat -d").contains(targetLog)) { + return true; + } + SystemClock.sleep(queryIntervalMillis); + } + return false; + } + + private static void disableGlobalKillSwitch() { + executeShellCommand( + "device_config put on_device_personalization global_kill_switch false"); + } + + private static void disableFederatedComputeKillSwitch() { + executeShellCommand( + "device_config put on_device_personalization federated_compute_kill_switch false"); + } + + private static void executeShellCommand(String cmd) { + try { + getUiDevice().executeShellCommand(cmd); + } catch (IOException e) { + Assert.fail("Failed to execute shell command: " + cmd + ". error: " + e); + } + } + + private static UiDevice getUiDevice() { + if (sUiDevice == null) { + sUiDevice = UiDevice.getInstance(InstrumentationRegistry.getInstrumentation()); + } + return sUiDevice; + } + + private UiObject2 getScheduleTrainingTextBox() { + return sUiDevice.wait( + Until.findObject( + By.res(ODP_CLIENT_TEST_APP_PACKAGE_NAME, SCHEDULE_TRAINING_TEXT_BOX_RESOURCE_ID)), + UI_FIND_RESOURCE_TIMEOUT); + } + + private UiObject2 getScheduleTrainingButton() { + return sUiDevice.wait( + Until.findObject( + By.res(ODP_CLIENT_TEST_APP_PACKAGE_NAME, SCHEDULE_TRAINING_BUTTON_RESOURCE_ID)), + UI_FIND_RESOURCE_TIMEOUT); + } + +} diff --git a/tests/perftests/scenarios/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/DownloadHelper.java b/tests/perftests/scenarios/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/DownloadHelper.java index 31b1c042..1bda1e99 100644 --- a/tests/perftests/scenarios/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/DownloadHelper.java +++ b/tests/perftests/scenarios/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/DownloadHelper.java @@ -45,6 +45,34 @@ public class DownloadHelper { private static final String DOWNLOAD_PROCESSING_TASK_JOB_ID = "1004"; private static final String MAINTENANCE_TASK_JOB_ID = "1005"; + /** Commands to prepare the device and odp module before testing. */ + public static void initialize() throws IOException { + executeShellCommand( + "device_config set_sync_disabled_for_tests persistent"); + executeShellCommand( + "device_config put on_device_personalization global_kill_switch false"); + executeShellCommand( + "device_config put on_device_personalization " + + "enable_ondevicepersonalization_apis true"); + executeShellCommand( + "device_config put on_device_personalization " + + "enable_personalization_status_override true"); + executeShellCommand( + "device_config put on_device_personalization " + + "personalization_status_override_value true"); + executeShellCommand("setprop log.tag.ondevicepersonalization VERBOSE"); + executeShellCommand( + "am broadcast -a android.intent.action.BOOT_COMPLETED -p " + + "com.google.android.ondevicepersonalization.services"); + executeShellCommand( + "cmd jobscheduler run -f " + + "com.google.android.ondevicepersonalization.services 1000"); + SystemClock.sleep(5000); + executeShellCommand( + "cmd jobscheduler run -f " + + "com.google.android.ondevicepersonalization.services 1006"); + SystemClock.sleep(5000); + } public static void pressHome() { getUiDevice().pressHome(); } @@ -91,7 +119,7 @@ public class DownloadHelper { executeShellCommand( "cmd jobscheduler run -f com.google.android.ondevicepersonalization.services " + MDD_WIFI_CHARGING_PERIODIC_TASK_JOB_ID); - SystemClock.sleep(5000); + SystemClock.sleep(60000); } public void processDownloadedVendorData() throws IOException { @@ -101,7 +129,7 @@ public class DownloadHelper { SystemClock.sleep(5000); } - private void executeShellCommand(String cmd) { + private static void executeShellCommand(String cmd) { try { getUiDevice().executeShellCommand(cmd); } catch (IOException e) { @@ -116,4 +144,10 @@ public class DownloadHelper { return sUiDevice; } + /** Commands to return device to original state */ + public static void wrapUp() throws IOException { + executeShellCommand( + "device_config set_sync_disabled_for_tests none"); + } + } diff --git a/tests/perftests/scenarios/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/DownloadVendorData.java b/tests/perftests/scenarios/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/DownloadVendorData.java new file mode 100644 index 00000000..a97a9fc2 --- /dev/null +++ b/tests/perftests/scenarios/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/DownloadVendorData.java @@ -0,0 +1,60 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package android.ondevicepersonalization.test.scenario.ondevicepersonalization; + +import android.platform.test.scenario.annotation.Scenario; + +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; + +@Scenario +@RunWith(JUnit4.class) +public class DownloadVendorData { + + private DownloadHelper mDownloadHelper = new DownloadHelper(); + + /** Prepare the device before entering the test class */ + @BeforeClass + public static void prepareDevice() throws IOException { + DownloadHelper.initialize(); + } + @Before + public void setUp() throws IOException { + mDownloadHelper.pressHome(); + } + + @Test + public void testDownloadVendorData() throws IOException { + mDownloadHelper.downloadVendorData(); + mDownloadHelper.processDownloadedVendorData(); + } + + @After + public void tearDown() throws IOException { + mDownloadHelper.uninstallVendorApk(); + mDownloadHelper.cleanupDatabase(); + mDownloadHelper.cleanupDownloadedMetadata(); + mDownloadHelper.pressHome(); + mDownloadHelper.wrapUp(); + } + +} diff --git a/tests/perftests/scenarios/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/RequestAd.java b/tests/perftests/scenarios/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/RequestAd.java index 18759572..71f22bca 100644 --- a/tests/perftests/scenarios/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/RequestAd.java +++ b/tests/perftests/scenarios/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/RequestAd.java @@ -19,6 +19,7 @@ import android.platform.test.scenario.annotation.Scenario; import org.junit.AfterClass; import org.junit.Before; +import org.junit.BeforeClass; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -31,6 +32,12 @@ public class RequestAd { private TestAppHelper mTestAppHelper = new TestAppHelper(); + /** Prepare the device before entering the test class */ + @BeforeClass + public static void prepareDevice() throws IOException { + TestAppHelper.initialize(); + } + @Before public void setup() throws IOException { mTestAppHelper.openApp(); @@ -45,5 +52,6 @@ public class RequestAd { @AfterClass public static void tearDown() throws IOException { TestAppHelper.goToHomeScreen(); + TestAppHelper.wrapUp(); } } diff --git a/tests/perftests/scenarios/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/RequestAdAndClickAd.java b/tests/perftests/scenarios/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/RequestAdAndClickAd.java new file mode 100644 index 00000000..9cd3d359 --- /dev/null +++ b/tests/perftests/scenarios/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/RequestAdAndClickAd.java @@ -0,0 +1,59 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package android.ondevicepersonalization.test.scenario.ondevicepersonalization; + +import android.platform.test.scenario.annotation.Scenario; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; + +@Scenario +@RunWith(JUnit4.class) +public class RequestAdAndClickAd { + + private TestAppHelper mTestAppHelper = new TestAppHelper(); + + /** Prepare the device before entering the test class */ + @BeforeClass + public static void prepareDevice() throws IOException { + TestAppHelper.initialize(); + } + + @Before + public void setup() throws IOException { + mTestAppHelper.openApp(); + } + + @Test + public void testRequestAdAndClickAd() { + mTestAppHelper.clickGetAd(); + mTestAppHelper.verifyRenderedView(); + mTestAppHelper.clickAd("Google!"); + } + + /** Return device to original state after test exeuction */ + @AfterClass + public static void tearDown() throws IOException { + TestAppHelper.goToHomeScreen(); + TestAppHelper.wrapUp(); + } +} diff --git a/tests/perftests/scenarios/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/RequestAdWithTestAppRotations.java b/tests/perftests/scenarios/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/RequestAdWithTestAppRotations.java new file mode 100644 index 00000000..5e899f4d --- /dev/null +++ b/tests/perftests/scenarios/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/RequestAdWithTestAppRotations.java @@ -0,0 +1,79 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package android.ondevicepersonalization.test.scenario.ondevicepersonalization; + +import android.os.RemoteException; +import android.platform.test.scenario.annotation.Scenario; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; + +@Scenario +@RunWith(JUnit4.class) +public class RequestAdWithTestAppRotations { + + private TestAppHelper mTestAppHelper = new TestAppHelper(); + + /** Prepare the device before entering the test class */ + @BeforeClass + public static void prepareDevice() throws IOException { + TestAppHelper.initialize(); + } + + @Before + public void setup() throws IOException { + mTestAppHelper.openApp(); + } + + @Test + public void testRequestAdWithTestAppRotations() throws RemoteException { + mTestAppHelper.clickGetAd(); + mTestAppHelper.verifyRenderedView(); + + // Rotate to landscape layout + mTestAppHelper.setOrientationLandscape(); + mTestAppHelper.clickGetAd(); + mTestAppHelper.verifyRenderedView(); + + // Rotate to portrait layout + mTestAppHelper.setOrientationPortrait(); + mTestAppHelper.clickGetAd(); + mTestAppHelper.verifyRenderedView(); + + // Rotate to landscape layout + mTestAppHelper.setOrientationLandscape(); + mTestAppHelper.clickGetAd(); + mTestAppHelper.verifyRenderedView(); + + // Rotate to portrait layout + mTestAppHelper.setOrientationPortrait(); + mTestAppHelper.clickGetAd(); + mTestAppHelper.verifyRenderedView(); + } + + /** Return device to original state after test exeuction */ + @AfterClass + public static void tearDown() throws IOException { + TestAppHelper.goToHomeScreen(); + TestAppHelper.wrapUp(); + } +} diff --git a/tests/perftests/scenarios/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/TestAppHelper.java b/tests/perftests/scenarios/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/TestAppHelper.java index 655994f6..6c2c720e 100644 --- a/tests/perftests/scenarios/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/TestAppHelper.java +++ b/tests/perftests/scenarios/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/TestAppHelper.java @@ -19,21 +19,81 @@ import static androidx.test.platform.app.InstrumentationRegistry.getInstrumentat import static org.junit.Assert.assertNotNull; +import android.os.RemoteException; +import android.os.SystemClock; + import androidx.test.uiautomator.By; import androidx.test.uiautomator.UiDevice; import androidx.test.uiautomator.UiObject2; +import androidx.test.uiautomator.UiObjectNotFoundException; +import androidx.test.uiautomator.UiScrollable; +import androidx.test.uiautomator.UiSelector; import androidx.test.uiautomator.Until; +import org.junit.Assert; + import java.io.IOException; /** Helper class for interacting with OdpClient test app in perf tests. */ public class TestAppHelper { private static final UiDevice sUiDevice = UiDevice.getInstance(getInstrumentation()); + private static UiScrollable sUiScrollable; private static final long UI_FIND_RESOURCE_TIMEOUT = 5000; - private static final String ODP_CLIENT_TEST_APP_PACKAGE_NAME = "com.android.odpclient"; + private static final long UI_ROTATE_IDLE_TIMEOUT = 5000; + private static final String ODP_CLIENT_TEST_APP_PACKAGE_NAME = "com.example.odpclient"; private static final String GET_AD_BUTTON_RESOURCE_ID = "get_ad_button"; private static final String RENDERED_VIEW_RESOURCE_ID = "rendered_view"; - private static final String SURFACE_VIEW_TEXT = "Nest"; + + /** Commands to prepare the device and odp module before testing. */ + public static void initialize() throws IOException { + executeShellCommand( + "device_config set_sync_disabled_for_tests persistent"); + executeShellCommand( + "device_config put on_device_personalization global_kill_switch false"); + executeShellCommand( + "device_config put on_device_personalization " + + "enable_ondevicepersonalization_apis true"); + executeShellCommand( + "device_config put on_device_personalization " + + "enable_personalization_status_override true"); + executeShellCommand( + "device_config put on_device_personalization " + + "personalization_status_override_value true"); + executeShellCommand("setprop log.tag.ondevicepersonalization VERBOSE"); + executeShellCommand( + "am broadcast -a android.intent.action.BOOT_COMPLETED -p " + + "com.google.android.ondevicepersonalization.services"); + executeShellCommand( + "cmd jobscheduler run -f " + + "com.google.android.ondevicepersonalization.services 1000"); + SystemClock.sleep(5000); + executeShellCommand( + "cmd jobscheduler run -f " + + "com.google.android.ondevicepersonalization.services 1006"); + SystemClock.sleep(5000); + executeShellCommand( + "cmd jobscheduler run -f " + + "com.google.android.ondevicepersonalization.services 1003"); + SystemClock.sleep(5000); + executeShellCommand( + "cmd jobscheduler run -f " + + "com.google.android.ondevicepersonalization.services 1004"); + SystemClock.sleep(5000); + } + + /** Commands to return device to original state */ + public static void wrapUp() throws IOException { + executeShellCommand( + "device_config set_sync_disabled_for_tests none"); + } + + private static void executeShellCommand(String cmd) { + try { + sUiDevice.executeShellCommand(cmd); + } catch (IOException e) { + Assert.fail("Failed to execute shell command: " + cmd + ". error: " + e); + } + } /** Open ODP client test app. */ public static void openApp() throws IOException { @@ -46,6 +106,20 @@ public class TestAppHelper { sUiDevice.pressHome(); } + /** Rotate screen to landscape orientation */ + public void setOrientationLandscape() throws RemoteException { + sUiDevice.unfreezeRotation(); + sUiDevice.setOrientationLandscape(); + SystemClock.sleep(UI_ROTATE_IDLE_TIMEOUT); + } + + /** Rotate screen to portrait orientation */ + public void setOrientationPortrait() throws RemoteException { + sUiDevice.unfreezeRotation(); + sUiDevice.setOrientationPortrait(); + SystemClock.sleep(UI_ROTATE_IDLE_TIMEOUT); + } + /** Click Get Ad button. */ public void clickGetAd() { UiObject2 getAdButton = getGetAdButton(); @@ -58,8 +132,23 @@ public class TestAppHelper { UiObject2 renderedView = getRenderedView(); assertNotNull("Rendered view not found", renderedView); - UiObject2 childSurfaceView = getChildSurfaceViewByText(SURFACE_VIEW_TEXT); - assertNotNull("Child surface view not found", childSurfaceView); + SystemClock.sleep(UI_FIND_RESOURCE_TIMEOUT); + if (renderedView.getChildCount() == 0) { + Assert.fail("Failed to render child surface view"); + } + } + + /** Click text on rendered Ad */ + public void clickAd(final String text) { + UiObject2 adUiObject = getUiObjectByText(text); + assertNotNull("Could not find Ad UiObject by given text " + text, adUiObject); + adUiObject.click(); + SystemClock.sleep(5000); + + if (sUiDevice.getCurrentPackageName() == null + || !sUiDevice.getCurrentPackageName().contains("com.android.chrome")) { + Assert.fail("Failed to click ad and jump to landing page in the default browser"); + } } private UiObject2 getGetAdButton() { @@ -68,15 +157,40 @@ public class TestAppHelper { UI_FIND_RESOURCE_TIMEOUT); } + /** Locate the rendered UI element in the scrollable view */ private UiObject2 getRenderedView() { - return sUiDevice.wait( - Until.findObject(By.res(ODP_CLIENT_TEST_APP_PACKAGE_NAME, RENDERED_VIEW_RESOURCE_ID)), - UI_FIND_RESOURCE_TIMEOUT); + for (int i = 0; i < 2; i++) { + // Try finding the renderedView on current screen + UiObject2 renderedView = sUiDevice.wait( + Until.findObject( + By.res(ODP_CLIENT_TEST_APP_PACKAGE_NAME, RENDERED_VIEW_RESOURCE_ID)), + UI_FIND_RESOURCE_TIMEOUT); + if (renderedView != null) { + return renderedView; + } + + // Try scroll to the end + try { + getUiScrollable().scrollToEnd(5); + } catch (UiObjectNotFoundException e) { + throw new RuntimeException(e); + } + } + return null; } - private UiObject2 getChildSurfaceViewByText(final String text) { + private UiObject2 getUiObjectByText(final String text) { return sUiDevice.wait( Until.findObject(By.desc(text)), UI_FIND_RESOURCE_TIMEOUT); } + + /** Get a UiScrollable instance configured for vertical scrolling */ + private static UiScrollable getUiScrollable() { + if (sUiScrollable == null) { + sUiScrollable = new UiScrollable(new UiSelector().scrollable(true)); + sUiScrollable.setAsVerticalList(); + } + return sUiScrollable; + } } diff --git a/tests/perftests/scenarios/tests/Android.bp b/tests/perftests/scenarios/tests/Android.bp index 4f660f18..7be79d1d 100644 --- a/tests/perftests/scenarios/tests/Android.bp +++ b/tests/perftests/scenarios/tests/Android.bp @@ -42,4 +42,32 @@ android_test { // Certificate and platform api is needed for collector-device-lib-platform. certificate: "platform", platform_apis: true, +} + +android_test { + name: "FederatedComputePerfScenariosTests", + srcs: [ + "src/android/federatedcompute/**/*.java", + ], + static_libs: [ + "androidx.test.runner", + "androidx.test.rules", + "collector-device-lib-platform", + "microbenchmark-device-lib", + "federatedcompute-test-scenarios", + "platform-test-options", + "platform-test-rules", + ], + data: [ + ":perfetto_artifacts" + ], + min_sdk_version: "Tiramisu", + test_suites: [ + "device-tests", + "general-tests", + ], + + // Certificate and platform api is needed for collector-device-lib-platform. + certificate: "platform", + platform_apis: true, }
\ No newline at end of file diff --git a/tests/perftests/scenarios/tests/src/android/federatedcompute/test/scenario/federatedcompute/ScheduleAndForceTrainingMicrobenchmark.java b/tests/perftests/scenarios/tests/src/android/federatedcompute/test/scenario/federatedcompute/ScheduleAndForceTrainingMicrobenchmark.java new file mode 100644 index 00000000..48e36cf7 --- /dev/null +++ b/tests/perftests/scenarios/tests/src/android/federatedcompute/test/scenario/federatedcompute/ScheduleAndForceTrainingMicrobenchmark.java @@ -0,0 +1,35 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.federatedcompute.test.scenario.federatedcompute; + +import android.platform.test.microbenchmark.Microbenchmark; +import android.platform.test.rule.DropCachesRule; +import android.platform.test.rule.KillAppsRule; +import android.platform.test.rule.PressHomeRule; + +import org.junit.Rule; +import org.junit.rules.RuleChain; +import org.junit.runner.RunWith; + +@RunWith(Microbenchmark.class) +public class ScheduleAndForceTrainingMicrobenchmark extends ScheduleAndForceTraining { + + @Rule + public RuleChain rules = RuleChain.outerRule(new DropCachesRule()) + .around(new KillAppsRule("com.example.odpclient")) + .around(new PressHomeRule()); +} diff --git a/tests/perftests/scenarios/tests/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/DownloadVendorDataMicrobenchmark.java b/tests/perftests/scenarios/tests/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/DownloadVendorDataMicrobenchmark.java new file mode 100644 index 00000000..c3af84ca --- /dev/null +++ b/tests/perftests/scenarios/tests/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/DownloadVendorDataMicrobenchmark.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.ondevicepersonalization.test.scenario.ondevicepersonalization; + +import android.platform.test.microbenchmark.Microbenchmark; +import android.platform.test.rule.DropCachesRule; +import android.platform.test.rule.PressHomeRule; + +import org.junit.Rule; +import org.junit.rules.RuleChain; +import org.junit.runner.RunWith; + +@RunWith(Microbenchmark.class) +public class DownloadVendorDataMicrobenchmark extends DownloadVendorData { + + @Rule + public RuleChain rules = RuleChain.outerRule(new DropCachesRule()) + .around(new PressHomeRule()); +} diff --git a/tests/perftests/scenarios/tests/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/RequestAdAndClickAdMicrobenchmark.java b/tests/perftests/scenarios/tests/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/RequestAdAndClickAdMicrobenchmark.java new file mode 100644 index 00000000..0c74be8c --- /dev/null +++ b/tests/perftests/scenarios/tests/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/RequestAdAndClickAdMicrobenchmark.java @@ -0,0 +1,36 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.ondevicepersonalization.test.scenario.ondevicepersonalization; + +import android.platform.test.microbenchmark.Microbenchmark; +import android.platform.test.rule.DropCachesRule; +import android.platform.test.rule.KillAppsRule; +import android.platform.test.rule.PressHomeRule; + +import org.junit.Rule; +import org.junit.rules.RuleChain; +import org.junit.runner.RunWith; + +@RunWith(Microbenchmark.class) +public class RequestAdAndClickAdMicrobenchmark extends RequestAdAndClickAd { + + @Rule + public RuleChain rules = RuleChain.outerRule(new DropCachesRule()) + .around(new KillAppsRule("com.example.odpclient")) + .around(new KillAppsRule("com.android.chrome")) + .around(new PressHomeRule()); +} diff --git a/tests/perftests/scenarios/tests/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/RequestAdMicrobenchmark.java b/tests/perftests/scenarios/tests/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/RequestAdMicrobenchmark.java index 4d2135e9..a33f039f 100644 --- a/tests/perftests/scenarios/tests/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/RequestAdMicrobenchmark.java +++ b/tests/perftests/scenarios/tests/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/RequestAdMicrobenchmark.java @@ -30,6 +30,6 @@ public class RequestAdMicrobenchmark extends RequestAd { @Rule public RuleChain rules = RuleChain.outerRule(new DropCachesRule()) - .around(new KillAppsRule("com.android.odpclient")) + .around(new KillAppsRule("com.example.odpclient")) .around(new PressHomeRule()); } diff --git a/tests/perftests/scenarios/tests/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/RequestAdWithTestAppRotationsMicrobenchmark.java b/tests/perftests/scenarios/tests/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/RequestAdWithTestAppRotationsMicrobenchmark.java new file mode 100644 index 00000000..4fe02c78 --- /dev/null +++ b/tests/perftests/scenarios/tests/src/android/ondevicepersonalization/test/scenario/ondevicepersonalization/RequestAdWithTestAppRotationsMicrobenchmark.java @@ -0,0 +1,35 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.ondevicepersonalization.test.scenario.ondevicepersonalization; + +import android.platform.test.microbenchmark.Microbenchmark; +import android.platform.test.rule.DropCachesRule; +import android.platform.test.rule.KillAppsRule; +import android.platform.test.rule.PressHomeRule; + +import org.junit.Rule; +import org.junit.rules.RuleChain; +import org.junit.runner.RunWith; + +@RunWith(Microbenchmark.class) +public class RequestAdWithTestAppRotationsMicrobenchmark extends RequestAdWithTestAppRotations { + + @Rule + public RuleChain rules = RuleChain.outerRule(new DropCachesRule()) + .around(new KillAppsRule("com.example.odpclient")) + .around(new PressHomeRule()); +} diff --git a/tests/servicetests/Android.bp b/tests/servicetests/Android.bp index 52c54655..b5267eed 100644 --- a/tests/servicetests/Android.bp +++ b/tests/servicetests/Android.bp @@ -24,6 +24,7 @@ android_test { ":ondevicepersonalization-sources", ":ondevicepersonalization-fbs", ":chronicle-sources", + ":statslog-ondevicepersonalization-java-gen", ], libs: [ "android.test.base", @@ -34,6 +35,7 @@ android_test { "kotlin-annotations", "truth-prebuilt", "framework-ondevicepersonalization.impl", + "framework-statsd.stubs.module_lib", // For WW logging ], static_libs: [ "androidx.test.ext.junit", diff --git a/tests/servicetests/AndroidManifest.xml b/tests/servicetests/AndroidManifest.xml index 72c21fe2..c8c78ff4 100644 --- a/tests/servicetests/AndroidManifest.xml +++ b/tests/servicetests/AndroidManifest.xml @@ -31,6 +31,9 @@ <!-- Used for persisting scheduled jobs --> <uses-permission android:name="android.permission.RECEIVE_BOOT_COMPLETED" /> + <!-- Permission to call OdpExampleStore --> + <uses-permission android:name="android.permission.BIND_EXAMPLE_STORE_SERVICE" /> + <application android:name="com.android.ondevicepersonalization.services.OnDevicePersonalizationApplication" android:label="OnDevicePersonalizationManagingServicesTests" android:debuggable="true"> @@ -42,9 +45,14 @@ <action android:name="android.OnDevicePersonalizationService" /> </intent-filter> </service> - <service android:name="com.android.ondevicepersonalization.services.OnDevicePersonalizationPrivacyStatusServiceImpl" android:exported="true" > + <service android:name="com.android.ondevicepersonalization.services.OnDevicePersonalizationConfigServiceImpl" android:exported="true" > <intent-filter> - <action android:name="android.OnDevicePersonalizationPrivacyStatusService" /> + <action android:name="android.OnDevicePersonalizationConfigService" /> + </intent-filter> + </service> + <service android:name="com.android.ondevicepersonalization.services.OnDevicePersonalizationDebugServiceImpl" android:exported="true" > + <intent-filter> + <action android:name="android.OnDevicePersonalizationService" /> </intent-filter> </service> <!-- TODO(b/258808270): Set isolated process to true --> @@ -67,24 +75,21 @@ android:exported="false" android:permission="android.permission.BIND_JOB_SERVICE"> </service> - <service android:name="com.android.ondevicepersonalization.services.federatedcompute.OdpExampleStoreService" - android:enabled="true" android:exported="true" > + <service + android:name="com.android.ondevicepersonalization.services.federatedcompute.OdpExampleStoreService" + android:enabled="true" + android:exported="true" + android:permission="android.permission.BIND_EXAMPLE_STORE_SERVICE"> <intent-filter> <action android:name="android.federatedcompute.EXAMPLE_STORE" /> - <data android:scheme="app" /> </intent-filter> </service> <service android:name="com.android.ondevicepersonalization.services.federatedcompute.OdpResultHandlingService" android:enabled="true" android:exported="true" > <intent-filter> <action android:name="android.federatedcompute.COMPUTATION_RESULT" /> - <data android:scheme="app" /> </intent-filter> </service> - <service android:name="com.android.ondevicepersonalization.services.federatedcompute.OdpFederatedComputeJobService" - android:exported="false" - android:permission="android.permission.BIND_JOB_SERVICE"> - </service> <!-- On BOOT_COMPLETED receiver for registering jobs --> <!-- TODO(b/250001593) Enable any required broadcast receivers during runtime/onCreate. --> diff --git a/tests/servicetests/res/xml/OdpSettings.xml b/tests/servicetests/res/xml/OdpSettings.xml index aa60cb5d..31636673 100644 --- a/tests/servicetests/res/xml/OdpSettings.xml +++ b/tests/servicetests/res/xml/OdpSettings.xml @@ -19,6 +19,7 @@ <on-device-personalization> <service name="com.test.TestPersonalizationService"> <download-settings url="android.resource://com.android.ondevicepersonalization.servicetests/raw/test_data1" /> + <federated-compute-settings url="https://google.com" /> </service> </on-device-personalization> diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationBroadcastReceiverTests.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationBroadcastReceiverTests.java index 64258a59..9376bb58 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationBroadcastReceiverTests.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationBroadcastReceiverTests.java @@ -58,7 +58,6 @@ public class OnDevicePersonalizationBroadcastReceiverTests { jobScheduler.cancel(OnDevicePersonalizationConfig.MDD_WIFI_CHARGING_PERIODIC_TASK_JOB_ID); jobScheduler.cancel(OnDevicePersonalizationConfig.MAINTENANCE_TASK_JOB_ID); jobScheduler.cancel(OnDevicePersonalizationConfig.USER_DATA_COLLECTION_ID); - jobScheduler.cancel(OnDevicePersonalizationConfig.FEDERATED_COMPUTE_TASK_JOB_ID); } @Test @@ -68,8 +67,7 @@ public class OnDevicePersonalizationBroadcastReceiverTests { MobileDataDownloadFactory.getMdd(mContext, executorService, executorService); OnDevicePersonalizationBroadcastReceiver receiver = - new OnDevicePersonalizationBroadcastReceiver( - executorService); + new OnDevicePersonalizationBroadcastReceiver(executorService); Intent intent = new Intent(Intent.ACTION_BOOT_COMPLETED); receiver.onReceive(mContext, intent); @@ -78,21 +76,31 @@ public class OnDevicePersonalizationBroadcastReceiverTests { JobScheduler jobScheduler = mContext.getSystemService(JobScheduler.class); - assertTrue(jobScheduler.getPendingJob( - OnDevicePersonalizationConfig.MAINTENANCE_TASK_JOB_ID) != null); - assertTrue(jobScheduler.getPendingJob( - OnDevicePersonalizationConfig.FEDERATED_COMPUTE_TASK_JOB_ID) != null); - assertTrue(jobScheduler.getPendingJob( - OnDevicePersonalizationConfig.USER_DATA_COLLECTION_ID) != null); + assertTrue( + jobScheduler.getPendingJob(OnDevicePersonalizationConfig.MAINTENANCE_TASK_JOB_ID) + != null); + assertTrue( + jobScheduler.getPendingJob(OnDevicePersonalizationConfig.USER_DATA_COLLECTION_ID) + != null); // MDD tasks - assertTrue(jobScheduler.getPendingJob( - OnDevicePersonalizationConfig.MDD_MAINTENANCE_PERIODIC_TASK_JOB_ID) != null); - assertTrue(jobScheduler.getPendingJob( - OnDevicePersonalizationConfig.MDD_CHARGING_PERIODIC_TASK_JOB_ID) != null); - assertTrue(jobScheduler.getPendingJob( - OnDevicePersonalizationConfig.MDD_CELLULAR_CHARGING_PERIODIC_TASK_JOB_ID) != null); - assertTrue(jobScheduler.getPendingJob( - OnDevicePersonalizationConfig.MDD_WIFI_CHARGING_PERIODIC_TASK_JOB_ID) != null); + assertTrue( + jobScheduler.getPendingJob( + OnDevicePersonalizationConfig.MDD_MAINTENANCE_PERIODIC_TASK_JOB_ID) + != null); + assertTrue( + jobScheduler.getPendingJob( + OnDevicePersonalizationConfig.MDD_CHARGING_PERIODIC_TASK_JOB_ID) + != null); + assertTrue( + jobScheduler.getPendingJob( + OnDevicePersonalizationConfig + .MDD_CELLULAR_CHARGING_PERIODIC_TASK_JOB_ID) + != null); + assertTrue( + jobScheduler.getPendingJob( + OnDevicePersonalizationConfig + .MDD_WIFI_CHARGING_PERIODIC_TASK_JOB_ID) + != null); } @Test @@ -106,28 +114,38 @@ public class OnDevicePersonalizationBroadcastReceiverTests { JobScheduler jobScheduler = mContext.getSystemService(JobScheduler.class); - assertTrue(jobScheduler.getPendingJob( - OnDevicePersonalizationConfig.MAINTENANCE_TASK_JOB_ID) == null); - assertTrue(jobScheduler.getPendingJob( - OnDevicePersonalizationConfig.FEDERATED_COMPUTE_TASK_JOB_ID) == null); - assertTrue(jobScheduler.getPendingJob( - OnDevicePersonalizationConfig.USER_DATA_COLLECTION_ID) == null); + assertTrue( + jobScheduler.getPendingJob(OnDevicePersonalizationConfig.MAINTENANCE_TASK_JOB_ID) + == null); + assertTrue( + jobScheduler.getPendingJob(OnDevicePersonalizationConfig.USER_DATA_COLLECTION_ID) + == null); // MDD tasks - assertTrue(jobScheduler.getPendingJob( - OnDevicePersonalizationConfig.MDD_MAINTENANCE_PERIODIC_TASK_JOB_ID) == null); - assertTrue(jobScheduler.getPendingJob( - OnDevicePersonalizationConfig.MDD_CHARGING_PERIODIC_TASK_JOB_ID) == null); - assertTrue(jobScheduler.getPendingJob( - OnDevicePersonalizationConfig.MDD_CELLULAR_CHARGING_PERIODIC_TASK_JOB_ID) == null); - assertTrue(jobScheduler.getPendingJob( - OnDevicePersonalizationConfig.MDD_WIFI_CHARGING_PERIODIC_TASK_JOB_ID) == null); + assertTrue( + jobScheduler.getPendingJob( + OnDevicePersonalizationConfig.MDD_MAINTENANCE_PERIODIC_TASK_JOB_ID) + == null); + assertTrue( + jobScheduler.getPendingJob( + OnDevicePersonalizationConfig.MDD_CHARGING_PERIODIC_TASK_JOB_ID) + == null); + assertTrue( + jobScheduler.getPendingJob( + OnDevicePersonalizationConfig + .MDD_CELLULAR_CHARGING_PERIODIC_TASK_JOB_ID) + == null); + assertTrue( + jobScheduler.getPendingJob( + OnDevicePersonalizationConfig + .MDD_WIFI_CHARGING_PERIODIC_TASK_JOB_ID) + == null); } @Test public void testEnableReceiver() { assertTrue(OnDevicePersonalizationBroadcastReceiver.enableReceiver(mContext)); - ComponentName componentName = new ComponentName(mContext, - OnDevicePersonalizationBroadcastReceiver.class); + ComponentName componentName = + new ComponentName(mContext, OnDevicePersonalizationBroadcastReceiver.class); final PackageManager pm = mContext.getPackageManager(); final int result = pm.getComponentEnabledSetting(componentName); assertEquals(COMPONENT_ENABLED_STATE_ENABLED, result); diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationPrivacyStatusServiceTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationConfigServiceTest.java index 8b25dc73..45000418 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationPrivacyStatusServiceTest.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationConfigServiceTest.java @@ -22,85 +22,114 @@ import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; -import android.adservices.ondevicepersonalization.aidl.IPrivacyStatusServiceCallback; +import android.adservices.ondevicepersonalization.aidl.IOnDevicePersonalizationConfigServiceCallback; import android.content.Context; import android.content.Intent; +import android.content.pm.PackageManager; import android.database.Cursor; import android.os.IBinder; import androidx.test.core.app.ApplicationProvider; import androidx.test.rule.ServiceTestRule; -import com.android.ondevicepersonalization.services.data.user.PrivacySignal; import com.android.ondevicepersonalization.services.data.user.RawUserData; import com.android.ondevicepersonalization.services.data.user.UserDataCollector; import com.android.ondevicepersonalization.services.data.user.UserDataDao; +import com.android.ondevicepersonalization.services.data.user.UserPrivacyStatus; import org.junit.After; -import org.junit.Assert; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import java.util.TimeZone; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeoutException; @RunWith(JUnit4.class) -public class OnDevicePersonalizationPrivacyStatusServiceTest { +public class OnDevicePersonalizationConfigServiceTest { @Rule public final ServiceTestRule serviceRule = new ServiceTestRule(); - private Context mContext = ApplicationProvider.getApplicationContext(); - private OnDevicePersonalizationPrivacyStatusServiceDelegate mBinder; - private PrivacySignal mPrivacySignal; + private Context mContext = spy(ApplicationProvider.getApplicationContext()); + private OnDevicePersonalizationConfigServiceDelegate mBinder; + private UserPrivacyStatus mUserPrivacyStatus; private RawUserData mUserData; private UserDataCollector mUserDataCollector; private UserDataDao mUserDataDao; @Before public void setup() throws Exception { - mBinder = new OnDevicePersonalizationPrivacyStatusServiceDelegate(mContext); - mPrivacySignal = PrivacySignal.getInstance(); + PhFlagsTestUtil.setUpDeviceConfigPermissions(); + PhFlagsTestUtil.disableGlobalKillSwitch(); + PhFlagsTestUtil.enableOnDevicePersonalizationApis(); + PhFlagsTestUtil.disablePersonalizationStatusOverride(); + when(mContext.checkCallingOrSelfPermission(anyString())) + .thenReturn(PackageManager.PERMISSION_GRANTED); + mBinder = new OnDevicePersonalizationConfigServiceDelegate(mContext); + mUserPrivacyStatus = UserPrivacyStatus.getInstance(); + mUserPrivacyStatus.setPersonalizationStatusEnabled(false); mUserData = RawUserData.getInstance(); + TimeZone pstTime = TimeZone.getTimeZone("GMT-08:00"); + TimeZone.setDefault(pstTime); mUserDataCollector = UserDataCollector.getInstanceForTest(mContext); mUserDataDao = UserDataDao.getInstanceForTest(mContext); } @Test - public void testSetKidStatusChanged() throws Exception { - assertTrue(mPrivacySignal.isKidStatusEnabled()); + public void testDisableOnDevicePersonalizationApis() throws Exception { + PhFlagsTestUtil.disableOnDevicePersonalizationApis(); + try { + assertThrows( + IllegalStateException.class, + () -> + mBinder.setPersonalizationStatus(true, null) + ); + } finally { + PhFlagsTestUtil.enableOnDevicePersonalizationApis(); + } + } + + @Test + public void testSetPersonalizationStatusNoCallingPermission() throws Exception { + when(mContext.checkCallingOrSelfPermission(anyString())) + .thenReturn(PackageManager.PERMISSION_DENIED); + assertThrows(SecurityException.class, () -> { + mBinder.setPersonalizationStatus(true, null); + }); + } + @Test + public void testSetPersonalizationStatusChanged() throws Exception { + assertFalse(mUserPrivacyStatus.isPersonalizationStatusEnabled()); populateUserData(); - assertNotEquals(0, mUserData.timeMillis); + assertNotEquals(0, mUserData.utcOffset); assertTrue(mUserDataCollector.isInitialized()); CountDownLatch latch = new CountDownLatch(1); - mBinder.setKidStatus(false, new IPrivacyStatusServiceCallback() { - @Override - public void onSuccess() { - latch.countDown(); - } - - @Override - public void onFailure(int errorCode) { - Assert.fail(); - } - - @Override - public IBinder asBinder() { - return null; - } - }); + mBinder.setPersonalizationStatus(true, + new IOnDevicePersonalizationConfigServiceCallback.Stub() { + @Override + public void onSuccess() { + latch.countDown(); + } + + @Override + public void onFailure(int errorCode) { + latch.countDown(); + } + }); latch.await(); + assertTrue(mUserPrivacyStatus.isPersonalizationStatusEnabled()); - assertFalse(mPrivacySignal.isKidStatusEnabled()); - - assertEquals(0, mUserData.timeMillis); + assertEquals(0, mUserData.utcOffset); assertFalse(mUserDataCollector.isInitialized()); - assertEquals(0, mUserData.timeMillis); Cursor appUsageCursor = mUserDataDao.readAppUsageInLastXDays(30); assertNotNull(appUsageCursor); assertEquals(0, appUsageCursor.getCount()); @@ -110,19 +139,19 @@ public class OnDevicePersonalizationPrivacyStatusServiceTest { } @Test - public void testSetKidStatusIfCallbackMissing() throws Exception { + public void testSetPersonalizationStatusIfCallbackMissing() throws Exception { assertThrows(NullPointerException.class, () -> { - mBinder.setKidStatus(false, null); + mBinder.setPersonalizationStatus(true, null); }); } @Test - public void testSetKidStatusNoOps() throws Exception { - mPrivacySignal.setKidStatusEnabled(false); + public void testSetPersonalizationStatusNoOps() throws Exception { + mUserPrivacyStatus.setPersonalizationStatusEnabled(true); populateUserData(); - assertNotEquals(0, mUserData.timeMillis); - long timeMillis = mUserData.timeMillis; + assertNotEquals(0, mUserData.utcOffset); + int utcOffset = mUserData.utcOffset; assertTrue(mUserDataCollector.isInitialized()); Cursor appUsageCursor = mUserDataDao.readAppUsageInLastXDays(30); Cursor locationCursor = mUserDataDao.readLocationInLastXDays(30); @@ -134,28 +163,24 @@ public class OnDevicePersonalizationPrivacyStatusServiceTest { assertTrue(locationCount > 0); CountDownLatch latch = new CountDownLatch(1); - mBinder.setKidStatus(false, new IPrivacyStatusServiceCallback() { - @Override - public void onSuccess() { - latch.countDown(); - } - - @Override - public void onFailure(int errorCode) { - Assert.fail(); - } - - @Override - public IBinder asBinder() { - return null; - } - }); + mBinder.setPersonalizationStatus(true, + new IOnDevicePersonalizationConfigServiceCallback.Stub() { + @Override + public void onSuccess() { + latch.countDown(); + } + + @Override + public void onFailure(int errorCode) { + latch.countDown(); + } + }); latch.await(); - assertFalse(mPrivacySignal.isKidStatusEnabled()); + assertTrue(mUserPrivacyStatus.isPersonalizationStatusEnabled()); // Adult data should not be roll-back'ed - assertEquals(timeMillis, mUserData.timeMillis); + assertEquals(utcOffset, mUserData.utcOffset); assertTrue(mUserDataCollector.isInitialized()); Cursor newAppUsageCursor = mUserDataDao.readAppUsageInLastXDays(30); Cursor newLocationCursor = mUserDataDao.readLocationInLastXDays(30); @@ -168,14 +193,13 @@ public class OnDevicePersonalizationPrivacyStatusServiceTest { @Test public void testWithBoundService() throws TimeoutException { Intent serviceIntent = new Intent(mContext, - OnDevicePersonalizationPrivacyStatusServiceImpl.class); + OnDevicePersonalizationConfigServiceImpl.class); IBinder binder = serviceRule.bindService(serviceIntent); - assertTrue(binder instanceof OnDevicePersonalizationPrivacyStatusServiceDelegate); + assertTrue(binder instanceof OnDevicePersonalizationConfigServiceDelegate); } @After public void tearDown() throws Exception { - mPrivacySignal.setKidStatusEnabled(true); mUserDataCollector.clearUserData(mUserData); mUserDataCollector.clearMetadata(); mUserDataCollector.clearDatabase(); diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationDebugServiceTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationDebugServiceTest.java new file mode 100644 index 00000000..10dfac2c --- /dev/null +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationDebugServiceTest.java @@ -0,0 +1,89 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.ondevicepersonalization.services; + +import static com.android.dx.mockito.inline.extended.ExtendedMockito.when; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import android.content.Context; +import android.content.Intent; +import android.os.IBinder; + +import androidx.test.core.app.ApplicationProvider; +import androidx.test.rule.ServiceTestRule; + +import com.android.dx.mockito.inline.extended.ExtendedMockito; +import com.android.ondevicepersonalization.services.util.DebugUtils; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.MockitoSession; + +import java.util.concurrent.TimeoutException; + +@RunWith(JUnit4.class) +public class OnDevicePersonalizationDebugServiceTest { + @Rule + public final ServiceTestRule serviceRule = new ServiceTestRule(); + private final Context mContext = ApplicationProvider.getApplicationContext(); + private OnDevicePersonalizationDebugServiceDelegate mService; + + @Before + public void setup() throws Exception { + mService = new OnDevicePersonalizationDebugServiceDelegate(mContext); + } + + @Test + public void testIsEnabledTrueWhenDeveloperModeOn() throws Exception { + MockitoSession session = ExtendedMockito.mockitoSession() + .mockStatic(DebugUtils.class) + .startMocking(); + try { + when(DebugUtils.isDeveloperModeEnabled(mContext)).thenReturn(true); + assertTrue(mService.isEnabled()); + } finally { + session.finishMocking(); + } + } + + @Test + public void testIsEnabledFalseWhenDeveloperModeOff() throws Exception { + MockitoSession session = ExtendedMockito.mockitoSession() + .mockStatic(DebugUtils.class) + .startMocking(); + try { + when(DebugUtils.isDeveloperModeEnabled(mContext)).thenReturn(false); + assertFalse(mService.isEnabled()); + } finally { + session.finishMocking(); + } + } + + @Test + public void testWithBoundService() throws TimeoutException { + Intent serviceIntent = new Intent(mContext, + OnDevicePersonalizationDebugServiceImpl.class); + IBinder binder = serviceRule.bindService(serviceIntent); + assertTrue(binder instanceof OnDevicePersonalizationDebugServiceDelegate); + } + +} diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationManagingServiceTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationManagingServiceTest.java index 172dc31f..8f120924 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationManagingServiceTest.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/OnDevicePersonalizationManagingServiceTest.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +import android.adservices.ondevicepersonalization.CallerMetadata; import android.adservices.ondevicepersonalization.aidl.IExecuteCallback; import android.adservices.ondevicepersonalization.aidl.IRequestSurfacePackageCallback; import android.content.ComponentName; @@ -34,6 +35,7 @@ import android.view.SurfaceControlViewHost; import androidx.test.core.app.ApplicationProvider; import androidx.test.rule.ServiceTestRule; +import com.android.ondevicepersonalization.services.data.user.UserPrivacyStatus; import com.android.ondevicepersonalization.services.request.AppRequestFlow; import com.android.ondevicepersonalization.services.request.RenderFlow; @@ -55,11 +57,13 @@ public class OnDevicePersonalizationManagingServiceTest { private OnDevicePersonalizationManagingServiceDelegate mService; private boolean mAppRequestFlowStarted = false; private boolean mRenderFlowStarted = false; + private UserPrivacyStatus mPrivacyStatus = UserPrivacyStatus.getInstance(); @Before public void setup() throws Exception { PhFlagsTestUtil.setUpDeviceConfigPermissions(); PhFlagsTestUtil.disableGlobalKillSwitch(); + mPrivacyStatus.setPersonalizationStatusEnabled(true); mService = new OnDevicePersonalizationManagingServiceDelegate( mContext, new TestInjector()); } @@ -81,6 +85,7 @@ public class OnDevicePersonalizationManagingServiceTest { new ComponentName( mContext.getPackageName(), "com.test.TestPersonalizationHandler"), PersistableBundle.EMPTY, + new CallerMetadata.Builder().build(), callback )); } finally { @@ -96,6 +101,7 @@ public class OnDevicePersonalizationManagingServiceTest { new ComponentName( mContext.getPackageName(), "com.test.TestPersonalizationHandler"), PersistableBundle.EMPTY, + new CallerMetadata.Builder().build(), callback); assertTrue(mAppRequestFlowStarted); } @@ -112,11 +118,12 @@ public class OnDevicePersonalizationManagingServiceTest { mContext.getPackageName(), "com.test.TestPersonalizationHandler"), PersistableBundle.EMPTY, + new CallerMetadata.Builder().build(), callback)); } @Test - public void testExecuteThrowsIfAppPackageNameMissing() throws Exception { + public void testExecuteThrowsIfAppPackageNameNull() throws Exception { var callback = new ExecuteCallback(); assertThrows( NullPointerException.class, @@ -127,11 +134,28 @@ public class OnDevicePersonalizationManagingServiceTest { mContext.getPackageName(), "com.test.TestPersonalizationHandler"), PersistableBundle.EMPTY, + new CallerMetadata.Builder().build(), callback)); } @Test - public void testExecuteThrowsIfSHandlerMissing() throws Exception { + public void testExecuteThrowsIfAppPackageNameMissing() throws Exception { + var callback = new ExecuteCallback(); + assertThrows( + IllegalArgumentException.class, + () -> + mService.execute( + "", + new ComponentName( + mContext.getPackageName(), + "com.test.TestPersonalizationHandler"), + PersistableBundle.EMPTY, + new CallerMetadata.Builder().build(), + callback)); + } + + @Test + public void testExecuteThrowsIfHandlerMissing() throws Exception { var callback = new ExecuteCallback(); assertThrows( NullPointerException.class, @@ -140,6 +164,50 @@ public class OnDevicePersonalizationManagingServiceTest { mContext.getPackageName(), null, PersistableBundle.EMPTY, + new CallerMetadata.Builder().build(), + callback)); + } + + @Test + public void testExecuteThrowsIfServicePackageMissing() throws Exception { + var callback = new ExecuteCallback(); + assertThrows( + IllegalArgumentException.class, + () -> + mService.execute( + mContext.getPackageName(), + new ComponentName("", "ServiceClass"), + PersistableBundle.EMPTY, + new CallerMetadata.Builder().build(), + callback)); + } + + @Test + public void testExecuteThrowsIfServiceClassMissing() throws Exception { + var callback = new ExecuteCallback(); + assertThrows( + IllegalArgumentException.class, + () -> + mService.execute( + mContext.getPackageName(), + new ComponentName("com.test.TestPackage", ""), + PersistableBundle.EMPTY, + new CallerMetadata.Builder().build(), + callback)); + } + + @Test + public void testExecuteThrowsIfMetadataMissing() throws Exception { + var callback = new ExecuteCallback(); + assertThrows( + NullPointerException.class, + () -> + mService.execute( + mContext.getPackageName(), + new ComponentName( + mContext.getPackageName(), "com.test.TestPersonalizationHandler"), + PersistableBundle.EMPTY, + null, callback)); } @@ -153,6 +221,7 @@ public class OnDevicePersonalizationManagingServiceTest { new ComponentName( mContext.getPackageName(), "com.test.TestPersonalizationHandler"), PersistableBundle.EMPTY, + new CallerMetadata.Builder().build(), null)); } @@ -170,6 +239,7 @@ public class OnDevicePersonalizationManagingServiceTest { 0, 100, 50, + new CallerMetadata.Builder().build(), callback )); } finally { @@ -186,6 +256,7 @@ public class OnDevicePersonalizationManagingServiceTest { 0, 100, 50, + new CallerMetadata.Builder().build(), callback); assertTrue(mRenderFlowStarted); } @@ -202,6 +273,7 @@ public class OnDevicePersonalizationManagingServiceTest { 0, 100, 50, + new CallerMetadata.Builder().build(), callback)); } @@ -217,6 +289,7 @@ public class OnDevicePersonalizationManagingServiceTest { 0, 100, 50, + new CallerMetadata.Builder().build(), callback)); } @@ -232,6 +305,7 @@ public class OnDevicePersonalizationManagingServiceTest { -1, 100, 50, + new CallerMetadata.Builder().build(), callback)); } @@ -247,6 +321,7 @@ public class OnDevicePersonalizationManagingServiceTest { 0, 0, 50, + new CallerMetadata.Builder().build(), callback)); } @@ -262,6 +337,23 @@ public class OnDevicePersonalizationManagingServiceTest { 0, 100, 0, + new CallerMetadata.Builder().build(), + callback)); + } + + @Test + public void testRequestSurfacePackageThrowsIfMetadataMissing() throws Exception { + var callback = new RequestSurfacePackageCallback(); + assertThrows( + NullPointerException.class, + () -> + mService.requestSurfacePackage( + "resultToken", + new Binder(), + 0, + 100, + 50, + null, callback)); } @@ -276,6 +368,7 @@ public class OnDevicePersonalizationManagingServiceTest { 0, 100, 50, + new CallerMetadata.Builder().build(), null)); } @@ -292,7 +385,8 @@ public class OnDevicePersonalizationManagingServiceTest { mContext.getPackageName(), "com.test.TestPersonalizationHandler"), PersistableBundle.EMPTY, executeCallback, - mContext)); + mContext, + 0L)); assertNotNull(injector.getRenderFlow( "resultToken", @@ -301,7 +395,8 @@ public class OnDevicePersonalizationManagingServiceTest { 100, 50, renderCallback, - mContext + mContext, + 0L )); } @@ -319,9 +414,10 @@ public class OnDevicePersonalizationManagingServiceTest { ComponentName handler, PersistableBundle params, IExecuteCallback callback, - Context context) { + Context context, + long startTimeMillis) { return new AppRequestFlow( - callingPackageName, handler, params, callback, context) { + callingPackageName, handler, params, callback, context, startTimeMillis) { @Override public void run() { mAppRequestFlowStarted = true; } @@ -335,9 +431,11 @@ public class OnDevicePersonalizationManagingServiceTest { int width, int height, IRequestSurfacePackageCallback callback, - Context context) { + Context context, + long startTimeMillis) { return new RenderFlow( - slotResultToken, hostToken, displayId, width, height, callback, context) { + slotResultToken, hostToken, displayId, width, height, callback, context, + startTimeMillis) { @Override public void run() { mRenderFlowStarted = true; } diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/PhFlagsTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/PhFlagsTest.java index 3cc46fa9..4cada534 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/PhFlagsTest.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/PhFlagsTest.java @@ -16,8 +16,14 @@ package com.android.ondevicepersonalization.services; +import static com.android.ondevicepersonalization.services.Flags.ENABLE_ONDEVICEPERSONALIZATION_APIS; +import static com.android.ondevicepersonalization.services.Flags.ENABLE_PERSONALIZATION_STATUS_OVERRIDE; import static com.android.ondevicepersonalization.services.Flags.GLOBAL_KILL_SWITCH; +import static com.android.ondevicepersonalization.services.Flags.PERSONALIZATION_STATUS_OVERRIDE_VALUE; +import static com.android.ondevicepersonalization.services.PhFlags.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS; +import static com.android.ondevicepersonalization.services.PhFlags.KEY_ENABLE_PERSONALIZATION_STATUS_OVERRIDE; import static com.android.ondevicepersonalization.services.PhFlags.KEY_GLOBAL_KILL_SWITCH; +import static com.android.ondevicepersonalization.services.PhFlags.KEY_PERSONALIZATION_STATUS_OVERRIDE_VALUE; import static com.google.common.truth.Truth.assertThat; @@ -61,4 +67,70 @@ public class PhFlagsTest { Flags phFlags = FlagsFactory.getFlags(); assertThat(phFlags.getGlobalKillSwitch()).isEqualTo(phOverridingValue); } + + @Test + public void testIsOnDevicePersonalizationApisEnabled() { + PhFlagsTestUtil.disableGlobalKillSwitch(); + DeviceConfig.setProperty( + DeviceConfig.NAMESPACE_ON_DEVICE_PERSONALIZATION, + KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS, + Boolean.toString(ENABLE_ONDEVICEPERSONALIZATION_APIS), + /* makeDefault */ false); + assertThat(FlagsFactory.getFlags().isOnDevicePersonalizationApisEnabled()).isEqualTo( + ENABLE_ONDEVICEPERSONALIZATION_APIS); + + final boolean phOverridingValue = true; + DeviceConfig.setProperty( + DeviceConfig.NAMESPACE_ON_DEVICE_PERSONALIZATION, + KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS, + Boolean.toString(phOverridingValue), + /* makeDefault */ false); + + Flags phFlags = FlagsFactory.getFlags(); + assertThat(phFlags.isOnDevicePersonalizationApisEnabled()).isEqualTo(phOverridingValue); + } + + @Test + public void testIsPersonalizationStatusOverrideEnabled() { + PhFlagsTestUtil.disableGlobalKillSwitch(); + DeviceConfig.setProperty( + DeviceConfig.NAMESPACE_ON_DEVICE_PERSONALIZATION, + KEY_ENABLE_PERSONALIZATION_STATUS_OVERRIDE, + Boolean.toString(ENABLE_PERSONALIZATION_STATUS_OVERRIDE), + /* makeDefault */ false); + assertThat(FlagsFactory.getFlags().isPersonalizationStatusOverrideEnabled()).isEqualTo( + ENABLE_PERSONALIZATION_STATUS_OVERRIDE); + + final boolean phOverridingValue = true; + DeviceConfig.setProperty( + DeviceConfig.NAMESPACE_ON_DEVICE_PERSONALIZATION, + KEY_ENABLE_PERSONALIZATION_STATUS_OVERRIDE, + Boolean.toString(phOverridingValue), + /* makeDefault */ false); + + Flags phFlags = FlagsFactory.getFlags(); + assertThat(phFlags.isPersonalizationStatusOverrideEnabled()).isEqualTo(phOverridingValue); + } + + @Test + public void testGetPersonalizationStatusOverrideValue() { + PhFlagsTestUtil.disableGlobalKillSwitch(); + DeviceConfig.setProperty( + DeviceConfig.NAMESPACE_ON_DEVICE_PERSONALIZATION, + KEY_PERSONALIZATION_STATUS_OVERRIDE_VALUE, + Boolean.toString(PERSONALIZATION_STATUS_OVERRIDE_VALUE), + /* makeDefault */ false); + assertThat(FlagsFactory.getFlags().getPersonalizationStatusOverrideValue()).isEqualTo( + PERSONALIZATION_STATUS_OVERRIDE_VALUE); + + final boolean phOverridingValue = true; + DeviceConfig.setProperty( + DeviceConfig.NAMESPACE_ON_DEVICE_PERSONALIZATION, + KEY_PERSONALIZATION_STATUS_OVERRIDE_VALUE, + Boolean.toString(phOverridingValue), + /* makeDefault */ false); + + Flags phFlags = FlagsFactory.getFlags(); + assertThat(phFlags.getPersonalizationStatusOverrideValue()).isEqualTo(phOverridingValue); + } } diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/PhFlagsTestUtil.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/PhFlagsTestUtil.java index 0066884c..81926d30 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/PhFlagsTestUtil.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/PhFlagsTestUtil.java @@ -16,6 +16,8 @@ package com.android.ondevicepersonalization.services; +import static com.android.ondevicepersonalization.services.PhFlags.KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS; +import static com.android.ondevicepersonalization.services.PhFlags.KEY_ENABLE_PERSONALIZATION_STATUS_OVERRIDE; import static com.android.ondevicepersonalization.services.PhFlags.KEY_GLOBAL_KILL_SWITCH; import android.provider.DeviceConfig; @@ -58,4 +60,33 @@ public class PhFlagsTestUtil { Boolean.toString(false), /* makeDefault */ false); } + + /** Enable OnDevicePersonalization APIs. */ + public static void enableOnDevicePersonalizationApis() { + DeviceConfig.setProperty( + DeviceConfig.NAMESPACE_ON_DEVICE_PERSONALIZATION, + KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS, + Boolean.toString(true), + /* makeDefault */ false); + } + + /** Disable OnDevicePersonalization APIs. */ + public static void disableOnDevicePersonalizationApis() { + DeviceConfig.setProperty( + DeviceConfig.NAMESPACE_ON_DEVICE_PERSONALIZATION, + KEY_ENABLE_ONDEVICEPERSONALIZATION_APIS, + Boolean.toString(false), + /* makeDefault */ false); + } + + /** + * Disable the enable_personalization_status_override to test personalization-related features. + */ + public static void disablePersonalizationStatusOverride() { + DeviceConfig.setProperty( + DeviceConfig.NAMESPACE_ON_DEVICE_PERSONALIZATION, + KEY_ENABLE_PERSONALIZATION_STATUS_OVERRIDE, + Boolean.toString(false), + /* makeDefault */ false); + } } diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/data/DataAccessServiceImplTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/data/DataAccessServiceImplTest.java index b4c5f8eb..2e953aa7 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/data/DataAccessServiceImplTest.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/data/DataAccessServiceImplTest.java @@ -25,8 +25,11 @@ import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import android.adservices.ondevicepersonalization.Constants; +import android.adservices.ondevicepersonalization.EventLogRecord; +import android.adservices.ondevicepersonalization.RequestLogRecord; import android.adservices.ondevicepersonalization.aidl.IDataAccessService; import android.adservices.ondevicepersonalization.aidl.IDataAccessServiceCallback; +import android.content.ContentValues; import android.content.Context; import android.net.Uri; import android.os.Bundle; @@ -34,12 +37,17 @@ import android.os.PersistableBundle; import androidx.test.core.app.ApplicationProvider; +import com.android.ondevicepersonalization.internal.util.OdpParceledListSlice; +import com.android.ondevicepersonalization.services.data.events.Event; import com.android.ondevicepersonalization.services.data.events.EventUrlHelper; import com.android.ondevicepersonalization.services.data.events.EventUrlPayload; +import com.android.ondevicepersonalization.services.data.events.EventsDao; +import com.android.ondevicepersonalization.services.data.events.Query; import com.android.ondevicepersonalization.services.data.vendor.LocalData; import com.android.ondevicepersonalization.services.data.vendor.OnDevicePersonalizationLocalDataDao; import com.android.ondevicepersonalization.services.data.vendor.OnDevicePersonalizationVendorDataDao; import com.android.ondevicepersonalization.services.data.vendor.VendorData; +import com.android.ondevicepersonalization.services.util.OnDevicePersonalizationFlatbufferUtils; import com.android.ondevicepersonalization.services.util.PackageUtils; import com.google.common.util.concurrent.ListeningExecutorService; @@ -61,6 +69,7 @@ import java.util.concurrent.CountDownLatch; public class DataAccessServiceImplTest { private static final double DELTA = 0.001; private static final byte[] RESPONSE_BYTES = {'A', 'B'}; + private static final int EVENT_TYPE_B2D = 1; private final Context mApplicationContext = ApplicationProvider.getApplicationContext(); private long mTimeMillis = 1000; private EventUrlPayload mEventUrlPayload; @@ -72,25 +81,28 @@ public class DataAccessServiceImplTest { private boolean mOnErrorCalled = false; private OnDevicePersonalizationLocalDataDao mLocalDao; private OnDevicePersonalizationVendorDataDao mVendorDao; + private EventsDao mEventsDao; private DataAccessServiceImpl mServiceImpl; private IDataAccessService mServiceProxy; @Before public void setup() throws Exception { mInjector = new TestInjector(); - mVendorDao = mInjector.getVendorDataDao(mApplicationContext, + mVendorDao = mInjector.getVendorDataDao(mApplicationContext, mApplicationContext.getPackageName(), PackageUtils.getCertDigest(mApplicationContext, mApplicationContext.getPackageName())); - mLocalDao = mInjector.getLocalDataDao(mApplicationContext, + mLocalDao = mInjector.getLocalDataDao(mApplicationContext, mApplicationContext.getPackageName(), PackageUtils.getCertDigest(mApplicationContext, mApplicationContext.getPackageName())); + mEventsDao = mInjector.getEventsDao(mApplicationContext); + mServiceImpl = new DataAccessServiceImpl( mApplicationContext.getPackageName(), mApplicationContext, - true, mInjector); + true, true, mInjector); mServiceProxy = IDataAccessService.Stub.asInterface(mServiceImpl); } @@ -251,13 +263,14 @@ public class DataAccessServiceImplTest { assertEquals("xyz", eventParamsFromUrl.getString("b")); assertEquals(5.0, eventParamsFromUrl.getDouble("c"), DELTA); Uri uri = Uri.parse(eventUrl); - assertEquals(uri.getQueryParameter(EventUrlHelper.URL_LANDING_PAGE_EVENT_KEY), "http://example.com"); + assertEquals(uri.getQueryParameter(EventUrlHelper.URL_LANDING_PAGE_EVENT_KEY), + "http://example.com"); } @Test public void testLocalDataThrowsNotIncluded() { mServiceImpl = new DataAccessServiceImpl( - mApplicationContext.getPackageName(), mApplicationContext, false, mInjector); + mApplicationContext.getPackageName(), mApplicationContext, false, true, mInjector); mServiceProxy = IDataAccessService.Stub.asInterface(mServiceImpl); Bundle params = new Bundle(); params.putStringArray(Constants.EXTRA_LOOKUP_KEYS, new String[]{"localkey"}); @@ -281,41 +294,84 @@ public class DataAccessServiceImplTest { } - class TestCallback extends IDataAccessServiceCallback.Stub { - @Override public void onSuccess(Bundle result) { - mResult = result; - mOnSuccessCalled = true; - mLatch.countDown(); - } - @Override public void onError(int errorCode) { - mErrorCode = errorCode; - mOnErrorCalled = true; - mLatch.countDown(); - } + @Test + public void testGetRequests() throws Exception { + addTestData(); + Bundle params = new Bundle(); + params.putLongArray(Constants.EXTRA_LOOKUP_KEYS, new long[]{0L, 200L}); + mServiceProxy.onRequest( + Constants.DATA_ACCESS_OP_GET_REQUESTS, + params, + new TestCallback()); + mLatch.await(); + assertNotNull(mResult); + List<RequestLogRecord> data = mResult.getParcelable( + Constants.EXTRA_RESULT, OdpParceledListSlice.class).getList(); + assertEquals(3, data.size()); + assertEquals(1, data.get(0).getRequestId()); + assertEquals(2, data.get(1).getRequestId()); + assertEquals(3, data.get(2).getRequestId()); } - class TestInjector extends DataAccessServiceImpl.Injector { - long getTimeMillis() { - return mTimeMillis; - } + @Test + public void testGetRequestsBadInput() { + addTestData(); + Bundle params = new Bundle(); + params.putLongArray(Constants.EXTRA_LOOKUP_KEYS, new long[]{0L}); + assertThrows(IllegalArgumentException.class, () -> mServiceProxy.onRequest( + Constants.DATA_ACCESS_OP_GET_REQUESTS, + params, + new TestCallback())); + } - ListeningExecutorService getExecutor() { - return MoreExecutors.newDirectExecutorService(); - } + @Test + public void testGetJoinedEvents() throws Exception { + addTestData(); + Bundle params = new Bundle(); + params.putLongArray(Constants.EXTRA_LOOKUP_KEYS, new long[]{0L, 200L}); + mServiceProxy.onRequest( + Constants.DATA_ACCESS_OP_GET_JOINED_EVENTS, + params, + new TestCallback()); + mLatch.await(); + assertNotNull(mResult); + List<EventLogRecord> data = mResult.getParcelable( + Constants.EXTRA_RESULT, OdpParceledListSlice.class).getList(); + assertEquals(3, data.size()); + assertEquals(2L, data.get(0).getTimeMillis()); + assertEquals(1L, data.get(0).getRequestLogRecord().getTimeMillis()); + assertEquals(11L, data.get(1).getTimeMillis()); + assertEquals(10L, data.get(1).getRequestLogRecord().getTimeMillis()); + assertEquals(101L, data.get(2).getTimeMillis()); + assertEquals(100L, data.get(2).getRequestLogRecord().getTimeMillis()); + } - OnDevicePersonalizationVendorDataDao getVendorDataDao( - Context context, String packageName, String certDigest - ) { - return OnDevicePersonalizationVendorDataDao.getInstanceForTest( - context, packageName, certDigest); - } + @Test + public void testGetJoinedEventsBadInput() { + addTestData(); + Bundle params = new Bundle(); + params.putLongArray(Constants.EXTRA_LOOKUP_KEYS, new long[]{0L}); + assertThrows(IllegalArgumentException.class, () -> mServiceProxy.onRequest( + Constants.DATA_ACCESS_OP_GET_JOINED_EVENTS, + params, + new TestCallback())); + } - OnDevicePersonalizationLocalDataDao getLocalDataDao( - Context context, String packageName, String certDigest - ) { - return OnDevicePersonalizationLocalDataDao.getInstanceForTest( - context, packageName, certDigest); - } + @Test + public void testEventDataThrowsNotIncluded() { + mServiceImpl = new DataAccessServiceImpl( + mApplicationContext.getPackageName(), mApplicationContext, true, false, mInjector); + mServiceProxy = IDataAccessService.Stub.asInterface(mServiceImpl); + Bundle params = new Bundle(); + params.putLongArray(Constants.EXTRA_LOOKUP_KEYS, new long[]{1L, 2L}); + assertThrows(IllegalStateException.class, () -> mServiceProxy.onRequest( + Constants.DATA_ACCESS_OP_GET_REQUESTS, + params, + new TestCallback())); + assertThrows(IllegalStateException.class, () -> mServiceProxy.onRequest( + Constants.DATA_ACCESS_OP_GET_JOINED_EVENTS, + params, + new TestCallback())); } private void addTestData() { @@ -334,6 +390,70 @@ public class DataAccessServiceImplTest { mLocalDao.updateOrInsertLocalData( new LocalData.Builder().setKey("localkey2").setData(new byte[10]).build()); + + ArrayList<ContentValues> rows = new ArrayList<>(); + ContentValues row = new ContentValues(); + row.put("a", 1); + rows.add(row); + byte[] queryDataBytes = OnDevicePersonalizationFlatbufferUtils.createQueryData( + mApplicationContext.getPackageName(), "AABBCCDD", rows); + + Query query1 = new Query.Builder() + .setTimeMillis(1L) + .setServicePackageName(mApplicationContext.getPackageName()) + .setQueryData(queryDataBytes) + .build(); + long queryId1 = mEventsDao.insertQuery(query1); + Query query2 = new Query.Builder() + .setTimeMillis(10L) + .setServicePackageName(mApplicationContext.getPackageName()) + .setQueryData(queryDataBytes) + .build(); + long queryId2 = mEventsDao.insertQuery(query2); + Query query3 = new Query.Builder() + .setTimeMillis(100L) + .setServicePackageName(mApplicationContext.getPackageName()) + .setQueryData(queryDataBytes) + .build(); + long queryId3 = mEventsDao.insertQuery(query3); + Query query4 = new Query.Builder() + .setTimeMillis(100L) + .setServicePackageName("packageA") + .setQueryData(queryDataBytes) + .build(); + mEventsDao.insertQuery(query4); + + ContentValues data = new ContentValues(); + data.put("a", 1); + byte[] eventData = OnDevicePersonalizationFlatbufferUtils.createEventData(data); + + Event event1 = new Event.Builder() + .setType(EVENT_TYPE_B2D) + .setEventData(eventData) + .setServicePackageName(mApplicationContext.getPackageName()) + .setQueryId(queryId1) + .setTimeMillis(2L) + .setRowIndex(0) + .build(); + mEventsDao.insertEvent(event1); + Event event2 = new Event.Builder() + .setType(EVENT_TYPE_B2D) + .setEventData(eventData) + .setServicePackageName(mApplicationContext.getPackageName()) + .setQueryId(queryId2) + .setTimeMillis(11L) + .setRowIndex(0) + .build(); + mEventsDao.insertEvent(event2); + Event event3 = new Event.Builder() + .setType(EVENT_TYPE_B2D) + .setEventData(eventData) + .setServicePackageName(mApplicationContext.getPackageName()) + .setQueryId(queryId3) + .setTimeMillis(101L) + .setRowIndex(0) + .build(); + mEventsDao.insertEvent(event3); } @After @@ -352,4 +472,50 @@ public class DataAccessServiceImplTest { params.putDouble("c", 5.0); return params; } + + class TestCallback extends IDataAccessServiceCallback.Stub { + @Override + public void onSuccess(Bundle result) { + mResult = result; + mOnSuccessCalled = true; + mLatch.countDown(); + } + + @Override + public void onError(int errorCode) { + mErrorCode = errorCode; + mOnErrorCalled = true; + mLatch.countDown(); + } + } + + class TestInjector extends DataAccessServiceImpl.Injector { + long getTimeMillis() { + return mTimeMillis; + } + + ListeningExecutorService getExecutor() { + return MoreExecutors.newDirectExecutorService(); + } + + OnDevicePersonalizationVendorDataDao getVendorDataDao( + Context context, String packageName, String certDigest + ) { + return OnDevicePersonalizationVendorDataDao.getInstanceForTest( + context, packageName, certDigest); + } + + OnDevicePersonalizationLocalDataDao getLocalDataDao( + Context context, String packageName, String certDigest + ) { + return OnDevicePersonalizationLocalDataDao.getInstanceForTest( + context, packageName, certDigest); + } + + EventsDao getEventsDao( + Context context + ) { + return EventsDao.getInstanceForTest(context); + } + } } diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/data/events/ColumnSchemaTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/data/events/ColumnSchemaTest.java new file mode 100644 index 00000000..3192f7a2 --- /dev/null +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/data/events/ColumnSchemaTest.java @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.ondevicepersonalization.services.data.events; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ColumnSchemaTest { + @Test + public void testBuilderAndEquals() { + String columnName = "column"; + int sqlType = ColumnSchema.SQL_DATA_TYPE_INTEGER; + ColumnSchema columnSchema1 = new ColumnSchema.Builder().setName(columnName).setType( + sqlType).build(); + + assertEquals(columnSchema1.getName(), columnName); + assertEquals(columnSchema1.getType(), sqlType); + + ColumnSchema columnSchema2 = new ColumnSchema.Builder( + columnName, sqlType) + .build(); + assertEquals(columnSchema1, columnSchema2); + assertEquals(columnSchema1.hashCode(), columnSchema2.hashCode()); + } + + @Test + public void testToString() { + String columnName = "column"; + ColumnSchema columnSchema = new ColumnSchema.Builder().setName(columnName).setType( + ColumnSchema.SQL_DATA_TYPE_INTEGER).build(); + assertEquals(columnName + " " + "INTEGER", columnSchema.toString()); + + columnSchema = new ColumnSchema.Builder().setName(columnName).setType( + ColumnSchema.SQL_DATA_TYPE_REAL).build(); + assertEquals(columnName + " " + "REAL", columnSchema.toString()); + + columnSchema = new ColumnSchema.Builder().setName(columnName).setType( + ColumnSchema.SQL_DATA_TYPE_TEXT).build(); + assertEquals(columnName + " " + "TEXT", columnSchema.toString()); + + columnSchema = new ColumnSchema.Builder().setName(columnName).setType( + ColumnSchema.SQL_DATA_TYPE_BLOB).build(); + assertEquals(columnName + " " + "BLOB", columnSchema.toString()); + + + } + + @Test + public void testBuildTwiceThrows() { + String columnName = "column"; + int sqlType = ColumnSchema.SQL_DATA_TYPE_INTEGER; + ColumnSchema.Builder builder = new ColumnSchema.Builder().setName(columnName).setType( + sqlType); + + builder.build(); + assertThrows(IllegalStateException.class, builder::build); + } +} diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/data/events/EventStateTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/data/events/EventStateTest.java index 83d9cfe6..3a6fe7e1 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/data/events/EventStateTest.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/data/events/EventStateTest.java @@ -16,6 +16,7 @@ package com.android.ondevicepersonalization.services.data.events; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; @@ -29,22 +30,20 @@ public class EventStateTest { public void testBuilderAndEquals() { String servicePackageName = "servicePackageName"; String taskIdentifier = "taskIdentifier"; - long queryId = 1; - long eventId = 1; + byte[] token = new byte[] {1}; + EventState eventState1 = new EventState.Builder() .setTaskIdentifier(taskIdentifier) .setServicePackageName(servicePackageName) - .setQueryId(queryId) - .setEventId(eventId) + .setToken(token) .build(); assertEquals(eventState1.getTaskIdentifier(), taskIdentifier); assertEquals(eventState1.getServicePackageName(), servicePackageName); - assertEquals(eventState1.getQueryId(), queryId); - assertEquals(eventState1.getEventId(), eventId); + assertArrayEquals(eventState1.getToken(), token); EventState eventState2 = new EventState.Builder( - eventId, queryId, servicePackageName, taskIdentifier) + token, servicePackageName, taskIdentifier) .build(); assertEquals(eventState1, eventState2); assertEquals(eventState1.hashCode(), eventState2.hashCode()); @@ -54,13 +53,11 @@ public class EventStateTest { public void testBuildTwiceThrows() { String servicePackageName = "servicePackageName"; String taskIdentifier = "taskIdentifier"; - long queryId = 1; - long eventId = 1; + byte[] token = new byte[] {1}; EventState.Builder builder = new EventState.Builder() .setTaskIdentifier(taskIdentifier) .setServicePackageName(servicePackageName) - .setQueryId(queryId) - .setEventId(eventId); + .setToken(token); builder.build(); assertThrows(IllegalStateException.class, builder::build); diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/data/events/EventTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/data/events/EventTest.java index 25102fbd..a791509f 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/data/events/EventTest.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/data/events/EventTest.java @@ -35,7 +35,7 @@ public class EventTest { long queryId = 1; long timeMillis = 1; long eventId = 1; - long rowIndex = 1; + int rowIndex = 1; Event event1 = new Event.Builder() .setType(EVENT_TYPE) .setEventData(eventData) @@ -68,7 +68,7 @@ public class EventTest { long queryId = 1; long timeMillis = 1; long eventId = 1; - long rowIndex = 1; + int rowIndex = 1; Event.Builder builder = new Event.Builder() .setType(EVENT_TYPE) .setEventData(eventData) diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/data/events/EventsDaoTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/data/events/EventsDaoTest.java index 23c008df..08106889 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/data/events/EventsDaoTest.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/data/events/EventsDaoTest.java @@ -17,6 +17,7 @@ package com.android.ondevicepersonalization.services.data.events; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; @@ -33,6 +34,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.List; @RunWith(JUnit4.class) @@ -57,8 +59,7 @@ public class EventsDaoTest { private final EventState mEventState = new EventState.Builder() .setTaskIdentifier(TASK_IDENTIFIER) .setServicePackageName(mContext.getPackageName()) - .setQueryId(1L) - .setEventId(1L) + .setToken(new byte[]{1}) .build(); private EventsDao mDao; @@ -92,28 +93,65 @@ public class EventsDaoTest { } @Test + public void testInsertEvents() { + mDao.insertQuery(mTestQuery); + Event testEvent = new Event.Builder() + .setType(EVENT_TYPE_CLICK) + .setEventData("event".getBytes(StandardCharsets.UTF_8)) + .setServicePackageName(mContext.getPackageName()) + .setQueryId(1L) + .setTimeMillis(1L) + .setRowIndex(0) + .build(); + List<Event> events = new ArrayList<>(); + events.add(mTestEvent); + events.add(testEvent); + assertTrue(mDao.insertEvents(events)); + } + + @Test + public void testInsertEventsFalse() { + List<Event> events = new ArrayList<>(); + events.add(mTestEvent); + assertFalse(mDao.insertEvents(events)); + } + + @Test public void testInsertAndReadEventState() { assertTrue(mDao.updateOrInsertEventState(mEventState)); assertEquals(mEventState, mDao.getEventState(TASK_IDENTIFIER, mContext.getPackageName())); EventState testEventState = new EventState.Builder() .setTaskIdentifier(TASK_IDENTIFIER) .setServicePackageName(mContext.getPackageName()) - .setQueryId(5L) - .setEventId(7L) + .setToken(new byte[]{100}) .build(); assertTrue(mDao.updateOrInsertEventState(testEventState)); assertEquals(testEventState, mDao.getEventState(TASK_IDENTIFIER, mContext.getPackageName())); } + + @Test + public void testInsertAndReadEventStatesTransaction() { + EventState testEventState = new EventState.Builder() + .setTaskIdentifier(TASK_IDENTIFIER) + .setServicePackageName(mContext.getPackageName()) + .setToken(new byte[]{100}) + .build(); + List<EventState> eventStates = new ArrayList<>(); + eventStates.add(mEventState); + eventStates.add(testEventState); + assertTrue(mDao.updateOrInsertEventStatesTransaction(eventStates)); + assertEquals(testEventState, + mDao.getEventState(TASK_IDENTIFIER, mContext.getPackageName())); + } @Test public void testDeleteEventState() { mDao.updateOrInsertEventState(mEventState); EventState testEventState = new EventState.Builder() .setTaskIdentifier(TASK_IDENTIFIER) .setServicePackageName("packageA") - .setQueryId(5L) - .setEventId(7L) + .setToken(new byte[]{100}) .build(); mDao.updateOrInsertEventState(testEventState); mDao.deleteEventState(mContext.getPackageName()); @@ -270,6 +308,269 @@ public class EventsDaoTest { } @Test + public void testReadAllQueries() { + Query query1 = new Query.Builder() + .setTimeMillis(1L) + .setServicePackageName(mContext.getPackageName()) + .setQueryData("query".getBytes(StandardCharsets.UTF_8)) + .build(); + long queryId1 = mDao.insertQuery(query1); + Query query2 = new Query.Builder() + .setTimeMillis(10L) + .setServicePackageName(mContext.getPackageName()) + .setQueryData("query".getBytes(StandardCharsets.UTF_8)) + .build(); + long queryId2 = mDao.insertQuery(query2); + Query query3 = new Query.Builder() + .setTimeMillis(100L) + .setServicePackageName(mContext.getPackageName()) + .setQueryData("query".getBytes(StandardCharsets.UTF_8)) + .build(); + long queryId3 = mDao.insertQuery(query3); + Query query4 = new Query.Builder() + .setTimeMillis(100L) + .setServicePackageName("package") + .setQueryData("query".getBytes(StandardCharsets.UTF_8)) + .build(); + long queryId4 = mDao.insertQuery(query4); + + List<Query> result = mDao.readAllQueries(0, 1000, mContext.getPackageName()); + assertEquals(3, result.size()); + assertEquals(queryId1, (long) result.get(0).getQueryId()); + assertEquals(queryId2, (long) result.get(1).getQueryId()); + assertEquals(queryId3, (long) result.get(2).getQueryId()); + + result = mDao.readAllQueries(0, 1000, "package"); + assertEquals(1, result.size()); + assertEquals(queryId4, (long) result.get(0).getQueryId()); + + result = mDao.readAllQueries(500, 1000, mContext.getPackageName()); + assertEquals(0, result.size()); + + result = mDao.readAllQueries(5, 1000, mContext.getPackageName()); + assertEquals(2, result.size()); + assertEquals(queryId2, (long) result.get(0).getQueryId()); + assertEquals(queryId3, (long) result.get(1).getQueryId()); + } + + @Test + public void testReadAllEventIds() { + Query query1 = new Query.Builder() + .setTimeMillis(1L) + .setServicePackageName(mContext.getPackageName()) + .setQueryData("query".getBytes(StandardCharsets.UTF_8)) + .build(); + long queryId1 = mDao.insertQuery(query1); + Query query2 = new Query.Builder() + .setTimeMillis(10L) + .setServicePackageName(mContext.getPackageName()) + .setQueryData("query".getBytes(StandardCharsets.UTF_8)) + .build(); + long queryId2 = mDao.insertQuery(query2); + Query query3 = new Query.Builder() + .setTimeMillis(100L) + .setServicePackageName(mContext.getPackageName()) + .setQueryData("query".getBytes(StandardCharsets.UTF_8)) + .build(); + long queryId3 = mDao.insertQuery(query3); + + Event event1 = new Event.Builder() + .setType(EVENT_TYPE_B2D) + .setEventData("event".getBytes(StandardCharsets.UTF_8)) + .setServicePackageName(mContext.getPackageName()) + .setQueryId(queryId1) + .setTimeMillis(2L) + .setRowIndex(0) + .build(); + long eventId1 = mDao.insertEvent(event1); + Event event2 = new Event.Builder() + .setType(EVENT_TYPE_B2D) + .setEventData("event".getBytes(StandardCharsets.UTF_8)) + .setServicePackageName(mContext.getPackageName()) + .setQueryId(queryId2) + .setTimeMillis(11L) + .setRowIndex(0) + .build(); + long eventId2 = mDao.insertEvent(event2); + Event event3 = new Event.Builder() + .setType(EVENT_TYPE_B2D) + .setEventData("event".getBytes(StandardCharsets.UTF_8)) + .setServicePackageName(mContext.getPackageName()) + .setQueryId(queryId3) + .setTimeMillis(101L) + .setRowIndex(0) + .build(); + long eventId3 = mDao.insertEvent(event3); + + List<Long> result = mDao.readAllEventIds(0, 1000, mContext.getPackageName()); + assertEquals(3, result.size()); + assertEquals(eventId1, (long) result.get(0)); + assertEquals(eventId2, (long) result.get(1)); + assertEquals(eventId3, (long) result.get(2)); + + result = mDao.readAllEventIds(0, 1000, "package"); + assertEquals(0, result.size()); + + result = mDao.readAllEventIds(500, 1000, mContext.getPackageName()); + assertEquals(0, result.size()); + + result = mDao.readAllEventIds(5, 1000, mContext.getPackageName()); + assertEquals(2, result.size()); + assertEquals(eventId2, (long) result.get(0)); + assertEquals(eventId3, (long) result.get(1)); + } + + @Test + public void testReadEventIdsForRequest() { + Query query1 = new Query.Builder() + .setTimeMillis(1L) + .setServicePackageName(mContext.getPackageName()) + .setQueryData("query".getBytes(StandardCharsets.UTF_8)) + .build(); + long queryId1 = mDao.insertQuery(query1); + Query query2 = new Query.Builder() + .setTimeMillis(10L) + .setServicePackageName(mContext.getPackageName()) + .setQueryData("query".getBytes(StandardCharsets.UTF_8)) + .build(); + long queryId2 = mDao.insertQuery(query2); + + Event event1 = new Event.Builder() + .setType(EVENT_TYPE_B2D) + .setEventData("event".getBytes(StandardCharsets.UTF_8)) + .setServicePackageName(mContext.getPackageName()) + .setQueryId(queryId1) + .setTimeMillis(2L) + .setRowIndex(0) + .build(); + long eventId1 = mDao.insertEvent(event1); + Event event2 = new Event.Builder() + .setType(EVENT_TYPE_B2D) + .setEventData("event".getBytes(StandardCharsets.UTF_8)) + .setServicePackageName(mContext.getPackageName()) + .setQueryId(queryId2) + .setTimeMillis(11L) + .setRowIndex(0) + .build(); + long eventId2 = mDao.insertEvent(event2); + Event event3 = new Event.Builder() + .setType(EVENT_TYPE_CLICK) + .setEventData("event".getBytes(StandardCharsets.UTF_8)) + .setServicePackageName(mContext.getPackageName()) + .setQueryId(queryId2) + .setTimeMillis(101L) + .setRowIndex(0) + .build(); + long eventId3 = mDao.insertEvent(event3); + + List<Long> result = mDao.readAllEventIdsForQuery(queryId1, mContext.getPackageName()); + assertEquals(1, result.size()); + assertEquals(eventId1, (long) result.get(0)); + + result = mDao.readAllEventIdsForQuery(queryId2, mContext.getPackageName()); + assertEquals(2, result.size()); + assertEquals(eventId2, (long) result.get(0)); + assertEquals(eventId3, (long) result.get(1)); + + result = mDao.readAllEventIdsForQuery(1000, mContext.getPackageName()); + assertEquals(0, result.size()); + + result = mDao.readAllEventIdsForQuery(queryId1, "package"); + assertEquals(0, result.size()); + } + + @Test + public void testReadJoinedEvents() { + Query query1 = new Query.Builder() + .setTimeMillis(1L) + .setServicePackageName(mContext.getPackageName()) + .setQueryData("query".getBytes(StandardCharsets.UTF_8)) + .build(); + long queryId1 = mDao.insertQuery(query1); + Query query2 = new Query.Builder() + .setTimeMillis(10L) + .setServicePackageName(mContext.getPackageName()) + .setQueryData("query".getBytes(StandardCharsets.UTF_8)) + .build(); + long queryId2 = mDao.insertQuery(query2); + Query query3 = new Query.Builder() + .setTimeMillis(100L) + .setServicePackageName(mContext.getPackageName()) + .setQueryData("query".getBytes(StandardCharsets.UTF_8)) + .build(); + long queryId3 = mDao.insertQuery(query3); + + Event event1 = new Event.Builder() + .setType(EVENT_TYPE_B2D) + .setEventData("event".getBytes(StandardCharsets.UTF_8)) + .setServicePackageName(mContext.getPackageName()) + .setQueryId(queryId1) + .setTimeMillis(2L) + .setRowIndex(0) + .build(); + long eventId1 = mDao.insertEvent(event1); + Event event2 = new Event.Builder() + .setType(EVENT_TYPE_B2D) + .setEventData("event".getBytes(StandardCharsets.UTF_8)) + .setServicePackageName(mContext.getPackageName()) + .setQueryId(queryId2) + .setTimeMillis(11L) + .setRowIndex(0) + .build(); + long eventId2 = mDao.insertEvent(event2); + Event event3 = new Event.Builder() + .setType(EVENT_TYPE_B2D) + .setEventData("event".getBytes(StandardCharsets.UTF_8)) + .setServicePackageName(mContext.getPackageName()) + .setQueryId(queryId3) + .setTimeMillis(101L) + .setRowIndex(0) + .build(); + long eventId3 = mDao.insertEvent(event3); + + List<JoinedEvent> result = mDao.readJoinedTableRows(0, 1000, mContext.getPackageName()); + assertEquals(3, result.size()); + assertEquals(createExpectedJoinedEvent(event1, query1, eventId1, queryId1), result.get(0)); + assertEquals(createExpectedJoinedEvent(event2, query2, eventId2, queryId2), result.get(1)); + assertEquals(createExpectedJoinedEvent(event3, query3, eventId3, queryId3), result.get(2)); + + result = mDao.readJoinedTableRows(0, 1000, "package"); + assertEquals(0, result.size()); + + result = mDao.readJoinedTableRows(500, 1000, mContext.getPackageName()); + assertEquals(0, result.size()); + + result = mDao.readJoinedTableRows(5, 1000, mContext.getPackageName()); + assertEquals(2, result.size()); + assertEquals(createExpectedJoinedEvent(event2, query2, eventId2, queryId2), result.get(0)); + assertEquals(createExpectedJoinedEvent(event3, query3, eventId3, queryId3), result.get(1)); + } + + @Test + public void testReadSingleQuery() { + Query query1 = new Query.Builder() + .setQueryId(1) + .setTimeMillis(1L) + .setServicePackageName(mContext.getPackageName()) + .setQueryData("query".getBytes(StandardCharsets.UTF_8)) + .build(); + mDao.insertQuery(query1); + assertEquals(query1, mDao.readSingleQueryRow(1, mContext.getPackageName())); + assertNull(mDao.readSingleQueryRow(100, mContext.getPackageName())); + assertNull(mDao.readSingleQueryRow(1, "package")); + } + + @Test + public void testReadSingleJoinedTableRow() { + mDao.insertQuery(mTestQuery); + mDao.insertEvent(mTestEvent); + assertEquals(createExpectedJoinedEvent(mTestEvent, mTestQuery, 1, 1), + mDao.readSingleJoinedTableRow(1, mContext.getPackageName())); + assertNull(mDao.readSingleJoinedTableRow(100, mContext.getPackageName())); + assertNull(mDao.readSingleJoinedTableRow(1, "package")); + } + + @Test public void testReadEventStateNoEventState() { assertNull(mDao.getEventState(TASK_IDENTIFIER, mContext.getPackageName())); } diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/data/events/JoinedEventTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/data/events/JoinedEventTest.java index d12d6e52..64a52d68 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/data/events/JoinedEventTest.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/data/events/JoinedEventTest.java @@ -31,7 +31,7 @@ public class JoinedEventTest { String servicePackageName = "servicePackageName"; long queryId = 1; long eventId = 1; - long rowIndex = 1; + int rowIndex = 1; int type = 1; long eventTimeMillis = 100; long queryTimeMillis = 50; @@ -73,7 +73,7 @@ public class JoinedEventTest { String servicePackageName = "servicePackageName"; long queryId = 1; long eventId = 1; - long rowIndex = 1; + int rowIndex = 1; int type = 1; long eventTimeMillis = 100; long queryTimeMillis = 50; diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/data/events/JoinedTableDaoTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/data/events/JoinedTableDaoTest.java new file mode 100644 index 00000000..549142aa --- /dev/null +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/data/events/JoinedTableDaoTest.java @@ -0,0 +1,219 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.ondevicepersonalization.services.data.events; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; + +import android.content.ContentValues; +import android.content.Context; +import android.database.Cursor; + +import androidx.test.core.app.ApplicationProvider; + +import com.android.ondevicepersonalization.services.data.OnDevicePersonalizationDbHelper; +import com.android.ondevicepersonalization.services.util.OnDevicePersonalizationFlatbufferUtils; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.List; + +@RunWith(JUnit4.class) +public class JoinedTableDaoTest { + + private static final int EVENT_TYPE_B2D = 1; + private static final int EVENT_TYPE_CLICK = 2; + private final Context mContext = ApplicationProvider.getApplicationContext(); + private EventsDao mDao; + + @Before + public void setup() { + mDao = EventsDao.getInstanceForTest(mContext); + } + + @After + public void cleanup() { + OnDevicePersonalizationDbHelper dbHelper = + OnDevicePersonalizationDbHelper.getInstanceForTest(mContext); + dbHelper.getWritableDatabase().close(); + dbHelper.getReadableDatabase().close(); + dbHelper.close(); + } + + @Test + public void invalidProvidedColumns() { + List<ColumnSchema> columnSchemaList = new ArrayList<>(); + columnSchemaList.add(new ColumnSchema.Builder().setName( + JoinedTableDao.SERVICE_PACKAGE_NAME_COL).setType( + ColumnSchema.SQL_DATA_TYPE_INTEGER).build()); + assertThrows(IllegalArgumentException.class, + () -> new JoinedTableDao(columnSchemaList, 0, 0, mContext)); + } + + @Test + public void emptyColumns() { + List<ColumnSchema> columnSchemaList = new ArrayList<>(); + assertThrows(IllegalArgumentException.class, + () -> new JoinedTableDao(columnSchemaList, 0, 0, mContext)); + } + + @Test + public void duplicateProvidedColumnNames() { + List<ColumnSchema> columnSchemaList = new ArrayList<>(); + columnSchemaList.add(new ColumnSchema.Builder().setName("ColumnName").setType( + ColumnSchema.SQL_DATA_TYPE_INTEGER).build()); + columnSchemaList.add(new ColumnSchema.Builder().setName("ColumnName").setType( + ColumnSchema.SQL_DATA_TYPE_BLOB).build()); + assertThrows(IllegalArgumentException.class, + () -> new JoinedTableDao(columnSchemaList, 0, 0, mContext)); + } + + @Test + public void testRawQuery() { + insertEventAndQueryData(); + + List<ColumnSchema> columnSchemaList = new ArrayList<>( + JoinedTableDao.ODP_PROVIDED_COLUMNS.values()); + columnSchemaList.add(new ColumnSchema.Builder().setName("eventCol1").setType( + ColumnSchema.SQL_DATA_TYPE_INTEGER).build()); + columnSchemaList.add(new ColumnSchema.Builder().setName("eventCol2").setType( + ColumnSchema.SQL_DATA_TYPE_TEXT).build()); + columnSchemaList.add(new ColumnSchema.Builder().setName("eventCol4").setType( + ColumnSchema.SQL_DATA_TYPE_REAL).build()); + columnSchemaList.add(new ColumnSchema.Builder().setName("queryCol1").setType( + ColumnSchema.SQL_DATA_TYPE_INTEGER).build()); + + JoinedTableDao joinedTableDao = new JoinedTableDao(columnSchemaList, 0, 0, mContext); + try (Cursor cursor = joinedTableDao.rawQuery( + "SELECT * FROM " + JoinedTableDao.TABLE_NAME + " ORDER BY ROWID")) { + // Assert two rows for the two joined events. two rows for the query. + assertEquals(4, cursor.getCount()); + for (int i = 0; i < 4; i++) { + cursor.moveToNext(); + String servicePackageName = cursor.getString( + cursor.getColumnIndexOrThrow(JoinedTableDao.SERVICE_PACKAGE_NAME_COL)); + int type = cursor.getInt(cursor.getColumnIndexOrThrow(JoinedTableDao.TYPE_COL)); + long eventTimeMillis = cursor.getLong( + cursor.getColumnIndexOrThrow(JoinedTableDao.EVENT_TIME_MILLIS_COL)); + long queryTimeMillis = cursor.getLong( + cursor.getColumnIndexOrThrow(JoinedTableDao.QUERY_TIME_MILLIS_COL)); + int eventCol1 = cursor.getInt(cursor.getColumnIndexOrThrow("eventCol1")); + String eventCol2 = cursor.getString(cursor.getColumnIndexOrThrow("eventCol2")); + double eventCol4 = cursor.getDouble(cursor.getColumnIndexOrThrow("eventCol4")); + int queryCol1 = cursor.getInt(cursor.getColumnIndexOrThrow("queryCol1")); + assertThrows(IllegalArgumentException.class, + () -> cursor.getColumnIndexOrThrow("eventCol3")); + assertThrows(IllegalArgumentException.class, + () -> cursor.getColumnIndexOrThrow("random")); + assertThrows(IllegalArgumentException.class, + () -> cursor.getColumnIndexOrThrow("someCol")); + if (i == 0) { + assertEquals(mContext.getPackageName(), servicePackageName); + assertEquals(EVENT_TYPE_B2D, type); + assertEquals(1L, eventTimeMillis); + assertEquals(1L, queryTimeMillis); + assertEquals(100, eventCol1); + assertEquals("helloWorld", eventCol2); + assertEquals(0.0, eventCol4, 0.001); + assertEquals(1, queryCol1); + } else if (i == 1) { + assertEquals(mContext.getPackageName(), servicePackageName); + assertEquals(EVENT_TYPE_CLICK, type); + assertEquals(2L, eventTimeMillis); + assertEquals(1L, queryTimeMillis); + assertEquals(50, eventCol1); + assertEquals("helloEarth", eventCol2); + assertEquals(2.0, eventCol4, 0.001); + assertEquals(2, queryCol1); + } else if (i == 2) { + assertEquals(mContext.getPackageName(), servicePackageName); + assertEquals(0L, type); + assertEquals(0L, eventTimeMillis); + assertEquals(1L, queryTimeMillis); + assertEquals(0, eventCol1); + assertNull(eventCol2); + assertEquals(0.0, eventCol4, 0.001); + assertEquals(1, queryCol1); + } else if (i == 3) { + assertEquals(mContext.getPackageName(), servicePackageName); + assertEquals(0L, type); + assertEquals(0L, eventTimeMillis); + assertEquals(1L, queryTimeMillis); + assertEquals(0, eventCol1); + assertNull(eventCol2); + assertEquals(0.0, eventCol4, 0.001); + assertEquals(2, queryCol1); + } + } + } + } + + private void insertEventAndQueryData() { + ArrayList<ContentValues> rows = new ArrayList<>(); + ContentValues row = new ContentValues(); + row.put("queryCol1", 1); + rows.add(row); + row = new ContentValues(); + row.put("queryCol1", 2); + rows.add(row); + Query query = new Query.Builder() + .setTimeMillis(1L) + .setServicePackageName(mContext.getPackageName()) + .setQueryData(OnDevicePersonalizationFlatbufferUtils.createQueryData( + mContext.getPackageName(), "AABBCCDD", rows)) + .build(); + long queryId = mDao.insertQuery(query); + + ContentValues eventData = new ContentValues(); + eventData.put("eventCol1", 100); + eventData.put("eventCol2", "helloWorld"); + eventData.put("eventCol3", "unused"); + eventData.put("eventCol4", "wrong_type"); + eventData.put("random", 20); + Event event1 = new Event.Builder() + .setType(EVENT_TYPE_B2D) + .setEventData(OnDevicePersonalizationFlatbufferUtils.createEventData(eventData)) + .setServicePackageName(mContext.getPackageName()) + .setQueryId(queryId) + .setTimeMillis(1L) + .setRowIndex(0) + .build(); + mDao.insertEvent(event1); + + ContentValues eventData2 = new ContentValues(); + eventData2.put("eventCol1", 50); + eventData2.put("eventCol2", "helloEarth"); + eventData2.put("eventCol3", "unused"); + eventData2.put("eventCol4", 2.0); + eventData2.put("someCol", 600); + Event event2 = new Event.Builder() + .setType(EVENT_TYPE_CLICK) + .setEventData(OnDevicePersonalizationFlatbufferUtils.createEventData(eventData2)) + .setServicePackageName(mContext.getPackageName()) + .setQueryId(queryId) + .setTimeMillis(2L) + .setRowIndex(1) + .build(); + mDao.insertEvent(event2); + } +} diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/data/user/UserDataCollectionJobServiceTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/data/user/UserDataCollectionJobServiceTest.java index 549bbe62..d5c2c501 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/data/user/UserDataCollectionJobServiceTest.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/data/user/UserDataCollectionJobServiceTest.java @@ -52,11 +52,14 @@ public class UserDataCollectionJobServiceTest { private final Context mContext = ApplicationProvider.getApplicationContext(); private UserDataCollector mUserDataCollector; private UserDataCollectionJobService mService; + private UserPrivacyStatus mPrivacyStatus = UserPrivacyStatus.getInstance(); @Before public void setup() throws Exception { PhFlagsTestUtil.setUpDeviceConfigPermissions(); PhFlagsTestUtil.disableGlobalKillSwitch(); + PhFlagsTestUtil.disablePersonalizationStatusOverride(); + mPrivacyStatus.setPersonalizationStatusEnabled(true); mUserDataCollector = UserDataCollector.getInstanceForTest(mContext); mService = spy(new UserDataCollectionJobService()); } @@ -97,6 +100,20 @@ public class UserDataCollectionJobServiceTest { } @Test + public void onStartJobTestPersonalizationBlocked() { + mPrivacyStatus.setPersonalizationStatusEnabled(false); + MockitoSession session = ExtendedMockito.mockitoSession().startMocking(); + try { + doNothing().when(mService).jobFinished(any(), anyBoolean()); + boolean result = mService.onStartJob(mock(JobParameters.class)); + assertTrue(result); + verify(mService, times(1)).jobFinished(any(), eq(false)); + } finally { + session.finishMocking(); + } + } + + @Test public void onStopJobTest() { MockitoSession session = ExtendedMockito.mockitoSession().strictness( Strictness.LENIENT).startMocking(); diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/data/user/UserDataCollectorTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/data/user/UserDataCollectorTest.java index 04b1e41e..443a1e80 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/data/user/UserDataCollectorTest.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/data/user/UserDataCollectorTest.java @@ -18,6 +18,7 @@ package com.android.ondevicepersonalization.services.data.user; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; @@ -37,7 +38,6 @@ import java.util.ArrayList; import java.util.Deque; import java.util.HashMap; import java.util.List; -import java.util.Locale; import java.util.TimeZone; @RunWith(JUnit4.class) @@ -50,47 +50,22 @@ public class UserDataCollectorTest { public void setup() { mCollector = UserDataCollector.getInstanceForTest(mContext); mUserData = RawUserData.getInstance(); + TimeZone pstTime = TimeZone.getTimeZone("GMT-08:00"); + TimeZone.setDefault(pstTime); } @Test - public void testUpdateUserData() throws InterruptedException { + public void testUpdateUserData() throws Exception { mCollector.updateUserData(mUserData); // Test initial collection. // TODO(b/261748573): Add manual tests for histogram updates - assertTrue(mUserData.timeMillis > 0); - assertTrue(mUserData.timeMillis <= mCollector.getTimeMillis()); - assertNotNull(mUserData.utcOffset); - assertEquals(mUserData.utcOffset, mCollector.getUtcOffset()); - - assertTrue(mUserData.availableStorageBytes > 0); - assertTrue(mUserData.batteryPercentage > 0); - assertEquals(mUserData.country, mCollector.getCountry()); - assertEquals(mUserData.language, mCollector.getLanguage()); - assertEquals(mUserData.carrier, mCollector.getCarrier()); - assertEquals(mUserData.connectionType, mCollector.getConnectionType()); - assertEquals(mUserData.networkMetered, mCollector.isNetworkMetered()); - - OSVersion osVersions = new OSVersion(); - mCollector.getOSVersions(osVersions); - assertEquals(mUserData.osVersions.major, osVersions.major); - assertEquals(mUserData.osVersions.minor, osVersions.minor); - assertEquals(mUserData.osVersions.micro, osVersions.micro); - - DeviceMetrics deviceMetrics = new DeviceMetrics(); - mCollector.getDeviceMetrics(deviceMetrics); - assertEquals(mUserData.deviceMetrics.make, deviceMetrics.make); - assertEquals(mUserData.deviceMetrics.model, deviceMetrics.model); - assertTrue(mUserData.deviceMetrics.screenHeight > 0); - assertEquals(mUserData.deviceMetrics.screenHeight, deviceMetrics.screenHeight); - assertTrue(mUserData.deviceMetrics.screenWidth > 0); - assertEquals(mUserData.deviceMetrics.screenWidth, deviceMetrics.screenWidth); - assertTrue(mUserData.deviceMetrics.xdpi > 0); - assertEquals(mUserData.deviceMetrics.xdpi, deviceMetrics.xdpi, 0.01); - assertTrue(mUserData.deviceMetrics.ydpi > 0); - assertEquals(mUserData.deviceMetrics.ydpi, deviceMetrics.ydpi, 0.01); - assertTrue(mUserData.deviceMetrics.pxRatio > 0); - assertEquals(mUserData.deviceMetrics.pxRatio, deviceMetrics.pxRatio, 0.01); + assertNotEquals(0, mUserData.utcOffset); + assertTrue(mUserData.availableStorageBytes >= 0); + assertTrue(mUserData.batteryPercentage >= 0); + assertTrue(mUserData.batteryPercentage <= 100); + assertNotNull(mUserData.networkCapabilities); + assertTrue(UserDataCollector.ALLOWED_NETWORK_TYPE.contains(mUserData.dataNetworkType)); List<AppInfo> appsInfo = new ArrayList(); mCollector.getInstalledApps(appsInfo); @@ -105,49 +80,15 @@ public class UserDataCollectorTest { @Test public void testRealTimeUpdate() { - // TODO: test orientation modification. + // TODO (b/307176787): test orientation modification. mCollector.updateUserData(mUserData); - long oldTimeMillis = mUserData.timeMillis; TimeZone tzGmt4 = TimeZone.getTimeZone("GMT+04:00"); TimeZone.setDefault(tzGmt4); mCollector.getRealTimeData(mUserData); - assertTrue(oldTimeMillis <= mUserData.timeMillis); assertEquals(mUserData.utcOffset, 240); } @Test - public void testGetCountry() { - mCollector.setLocale(new Locale("en", "US")); - mCollector.updateUserData(mUserData); - assertNotNull(mUserData.country); - assertEquals(mUserData.country, Country.USA); - } - - @Test - public void testUnknownCountry() { - mCollector.setLocale(new Locale("en")); - mCollector.updateUserData(mUserData); - assertNotNull(mUserData.country); - assertEquals(mUserData.country, Country.UNKNOWN); - } - - @Test - public void testGetLanguage() { - mCollector.setLocale(new Locale("zh", "CN")); - mCollector.updateUserData(mUserData); - assertNotNull(mUserData.language); - assertEquals(mUserData.language, Language.ZH); - } - - @Test - public void testUnknownLanguage() { - mCollector.setLocale(new Locale("nonexist_lang", "CA")); - mCollector.updateUserData(mUserData); - assertNotNull(mUserData.language); - assertEquals(mUserData.language, Language.UNKNOWN); - } - - @Test public void testRecoveryFromSystemCrash() { mCollector.updateUserData(mUserData); // Backup sample answer. diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/display/DisplayHelperTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/display/DisplayHelperTest.java index b092c7a1..cd09ce69 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/display/DisplayHelperTest.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/display/DisplayHelperTest.java @@ -23,6 +23,7 @@ import static org.junit.Assert.assertNotNull; import android.Manifest; import android.adservices.ondevicepersonalization.RenderOutput; +import android.adservices.ondevicepersonalization.RenderOutputParcel; import android.adservices.ondevicepersonalization.RequestLogRecord; import android.content.Context; import android.hardware.display.DisplayManager; @@ -72,7 +73,8 @@ public class DisplayHelperTest { DisplayHelper displayHelper = new DisplayHelper(mContext); RenderOutput renderContentResult = new RenderOutput.Builder() .setContent("html").build(); - assertEquals("html", displayHelper.generateHtml(renderContentResult, + RenderOutputParcel resultParcel = new RenderOutputParcel(renderContentResult); + assertEquals("html", displayHelper.generateHtml(resultParcel, mContext.getPackageName())); } @@ -95,8 +97,9 @@ public class DisplayHelperTest { .setTemplateId("templateId") .setTemplateParams(bundle) .build(); + RenderOutputParcel resultParcel = new RenderOutputParcel(renderContentResult); String expected = "Hello odp! I am 100."; - assertEquals(expected, displayHelper.generateHtml(renderContentResult, + assertEquals(expected, displayHelper.generateHtml(resultParcel, mContext.getPackageName())); } diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/display/OdpWebViewClientTests.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/display/OdpWebViewClientTests.java index c8c1b0c0..0834c938 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/display/OdpWebViewClientTests.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/display/OdpWebViewClientTests.java @@ -48,6 +48,7 @@ import com.android.ondevicepersonalization.services.data.events.EventsDao; import com.android.ondevicepersonalization.services.data.events.Query; import com.android.ondevicepersonalization.services.fbs.EventFields; +import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; import org.junit.After; @@ -63,7 +64,6 @@ import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.Map; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; @RunWith(JUnit4.class) @@ -223,8 +223,8 @@ public class OdpWebViewClientTests { } class TestInjector extends OdpWebViewClient.Injector { - Executor getExecutor() { - return MoreExecutors.directExecutor(); + ListeningExecutorService getExecutor() { + return MoreExecutors.newDirectExecutorService(); } void openUrl(String url, Context context) { diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/download/mdd/MddJobServiceTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/download/mdd/MddJobServiceTest.java index fde6f432..334888f1 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/download/mdd/MddJobServiceTest.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/download/mdd/MddJobServiceTest.java @@ -43,6 +43,7 @@ import androidx.test.core.app.ApplicationProvider; import com.android.dx.mockito.inline.extended.ExtendedMockito; import com.android.ondevicepersonalization.services.OnDevicePersonalizationExecutors; import com.android.ondevicepersonalization.services.PhFlagsTestUtil; +import com.android.ondevicepersonalization.services.data.user.UserPrivacyStatus; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; @@ -60,11 +61,14 @@ public class MddJobServiceTest { private JobScheduler mMockJobScheduler; private MddJobService mSpyService; + private UserPrivacyStatus mUserPrivacyStatus = UserPrivacyStatus.getInstance(); @Before public void setup() throws Exception { PhFlagsTestUtil.setUpDeviceConfigPermissions(); PhFlagsTestUtil.disableGlobalKillSwitch(); + PhFlagsTestUtil.disablePersonalizationStatusOverride(); + mUserPrivacyStatus.setPersonalizationStatusEnabled(true); ListeningExecutorService executorService = MoreExecutors.newDirectExecutorService(); MobileDataDownloadFactory.getMdd(mContext, executorService, executorService); @@ -118,6 +122,21 @@ public class MddJobServiceTest { } @Test + public void onStartJobTestPersonalizationBlocked() { + mUserPrivacyStatus.setPersonalizationStatusEnabled(false); + MockitoSession session = ExtendedMockito.mockitoSession().startMocking(); + try { + doNothing().when(mSpyService).jobFinished(any(), anyBoolean()); + boolean result = mSpyService.onStartJob(mock(JobParameters.class)); + assertTrue(result); + verify(mSpyService, times(1)).jobFinished(any(), eq(false)); + verify(mMockJobScheduler, times(0)).schedule(any()); + } finally { + session.finishMocking(); + } + } + + @Test public void onStartJobNoTaskTagTest() { MockitoSession session = ExtendedMockito.mockitoSession().spyStatic( OnDevicePersonalizationExecutors.class).strictness( diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/federatedcompute/FederatedComputeServiceImplTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/federatedcompute/FederatedComputeServiceImplTest.java new file mode 100644 index 00000000..0fc82ea8 --- /dev/null +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/federatedcompute/FederatedComputeServiceImplTest.java @@ -0,0 +1,244 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.ondevicepersonalization.services.federatedcompute; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import android.adservices.ondevicepersonalization.aidl.IFederatedComputeCallback; +import android.adservices.ondevicepersonalization.aidl.IFederatedComputeService; +import android.content.Context; +import android.federatedcompute.FederatedComputeManager; +import android.federatedcompute.common.ClientConstants; +import android.federatedcompute.common.ScheduleFederatedComputeRequest; +import android.federatedcompute.common.TrainingInterval; +import android.federatedcompute.common.TrainingOptions; +import android.os.OutcomeReceiver; + +import androidx.test.core.app.ApplicationProvider; + +import com.android.compatibility.common.util.ShellUtils; +import com.android.ondevicepersonalization.services.data.OnDevicePersonalizationDbHelper; +import com.android.ondevicepersonalization.services.data.events.EventState; +import com.android.ondevicepersonalization.services.data.events.EventsDao; + +import com.google.common.util.concurrent.ListeningExecutorService; +import com.google.common.util.concurrent.MoreExecutors; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +@RunWith(JUnit4.class) +public class FederatedComputeServiceImplTest { + private static final String FC_SERVER_URL = + "https://google.com"; + private final Context mApplicationContext = ApplicationProvider.getApplicationContext(); + ArgumentCaptor<OutcomeReceiver<Object, Exception>> mCallbackCapture; + ArgumentCaptor<ScheduleFederatedComputeRequest> mRequestCapture; + private TestInjector mInjector = new TestInjector(); + private CountDownLatch mLatch = new CountDownLatch(1); + private int mErrorCode = 0; + private boolean mOnSuccessCalled = false; + private boolean mOnErrorCalled = false; + private FederatedComputeServiceImpl mServiceImpl; + private IFederatedComputeService mServiceProxy; + private FederatedComputeManager mMockManager; + + @Before + public void setup() throws Exception { + mInjector = new TestInjector(); + mMockManager = Mockito.mock(FederatedComputeManager.class); + mCallbackCapture = ArgumentCaptor.forClass(OutcomeReceiver.class); + mRequestCapture = ArgumentCaptor.forClass(ScheduleFederatedComputeRequest.class); + doNothing().when(mMockManager).cancel(any(), any(), mCallbackCapture.capture()); + doNothing().when(mMockManager).schedule(mRequestCapture.capture(), any(), + mCallbackCapture.capture()); + + mServiceImpl = new FederatedComputeServiceImpl( + mApplicationContext.getPackageName(), mApplicationContext, mInjector); + mServiceProxy = IFederatedComputeService.Stub.asInterface(mServiceImpl); + } + + @Test + public void testSchedule() throws Exception { + TrainingInterval interval = new TrainingInterval.Builder() + .setMinimumIntervalMillis(100) + .setSchedulingMode(1) + .build(); + TrainingOptions options = new TrainingOptions.Builder() + .setPopulationName("population") + .setTrainingInterval(interval) + .build(); + mServiceProxy.schedule( + options, + new TestCallback()); + mCallbackCapture.getValue().onResult(null); + var request = mRequestCapture.getValue(); + mLatch.await(1000, TimeUnit.MILLISECONDS); + assertEquals(FC_SERVER_URL, request.getTrainingOptions().getServerAddress()); + assertEquals("population", request.getTrainingOptions().getPopulationName()); + assertTrue(mOnSuccessCalled); + } + + @Test + public void testScheduleUrlOverride() throws Exception { + ShellUtils.runShellCommand( + "setprop debug.ondevicepersonalization.override_fc_server_url_package " + + mApplicationContext.getPackageName()); + String overrideUrl = "https://android.com"; + ShellUtils.runShellCommand( + "setprop debug.ondevicepersonalization.override_fc_server_url " + overrideUrl); + TrainingInterval interval = new TrainingInterval.Builder() + .setMinimumIntervalMillis(100) + .setSchedulingMode(1) + .build(); + TrainingOptions options = new TrainingOptions.Builder() + .setPopulationName("population") + .setTrainingInterval(interval) + .build(); + mServiceProxy.schedule( + options, + new TestCallback()); + mCallbackCapture.getValue().onResult(null); + var request = mRequestCapture.getValue(); + mLatch.await(1000, TimeUnit.MILLISECONDS); + assertEquals(overrideUrl, request.getTrainingOptions().getServerAddress()); + assertEquals("population", request.getTrainingOptions().getPopulationName()); + assertTrue(mOnSuccessCalled); + } + + @Test + public void testScheduleErr() throws Exception { + TrainingInterval interval = new TrainingInterval.Builder() + .setMinimumIntervalMillis(100) + .setSchedulingMode(1) + .build(); + TrainingOptions options = new TrainingOptions.Builder() + .setPopulationName("population") + .setTrainingInterval(interval) + .build(); + mServiceProxy.schedule( + options, + new TestCallback()); + mCallbackCapture.getValue().onError(new Exception()); + mLatch.await(1000, TimeUnit.MILLISECONDS); + assertTrue(mOnErrorCalled); + assertEquals(ClientConstants.STATUS_INTERNAL_ERROR, mErrorCode); + } + + @Test + public void testCancel() throws Exception { + EventsDao.getInstanceForTest(mApplicationContext).updateOrInsertEventState( + new EventState.Builder() + .setServicePackageName(mApplicationContext.getPackageName()) + .setTaskIdentifier("population") + .setToken(new byte[]{}) + .build() + ); + mServiceProxy.cancel( + "population", + new TestCallback()); + mCallbackCapture.getValue().onResult(null); + mLatch.await(1000, TimeUnit.MILLISECONDS); + assertTrue(mOnSuccessCalled); + } + + @Test + public void testCancelNoPopulation() throws Exception { + mServiceProxy.cancel( + "population", + new TestCallback()); + mLatch.await(1000, TimeUnit.MILLISECONDS); + verify(mMockManager, times(0)).cancel(any(), any(), any()); + assertTrue(mOnSuccessCalled); + } + + @Test + public void testCancelErr() throws Exception { + EventsDao.getInstanceForTest(mApplicationContext).updateOrInsertEventState( + new EventState.Builder() + .setServicePackageName(mApplicationContext.getPackageName()) + .setTaskIdentifier("population") + .setToken(new byte[]{}) + .build() + ); + mServiceProxy.cancel( + "population", + new TestCallback()); + mCallbackCapture.getValue().onError(new Exception()); + mLatch.await(1000, TimeUnit.MILLISECONDS); + assertTrue(mOnErrorCalled); + assertEquals(ClientConstants.STATUS_INTERNAL_ERROR, mErrorCode); + } + + @After + public void cleanup() { + ShellUtils.runShellCommand( + "setprop debug.ondevicepersonalization.override_fc_server_url_package \"\""); + ShellUtils.runShellCommand( + "setprop debug.ondevicepersonalization.override_fc_server_url \"\""); + OnDevicePersonalizationDbHelper dbHelper = + OnDevicePersonalizationDbHelper.getInstanceForTest(mApplicationContext); + dbHelper.getWritableDatabase().close(); + dbHelper.getReadableDatabase().close(); + dbHelper.close(); + } + + class TestCallback extends IFederatedComputeCallback.Stub { + @Override + public void onSuccess() { + mOnSuccessCalled = true; + mLatch.countDown(); + } + + @Override + public void onFailure(int i) { + mErrorCode = i; + mOnErrorCalled = true; + mLatch.countDown(); + } + } + + class TestInjector extends FederatedComputeServiceImpl.Injector { + + ListeningExecutorService getExecutor() { + return MoreExecutors.newDirectExecutorService(); + } + + FederatedComputeManager getFederatedComputeManager(Context context) { + return mMockManager; + } + + EventsDao getEventsDao( + Context context + ) { + return EventsDao.getInstanceForTest(context); + } + } +} diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreIteratorFactoryTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreIteratorFactoryTest.java index 98761686..437d14a9 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreIteratorFactoryTest.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreIteratorFactoryTest.java @@ -32,6 +32,9 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import java.util.ArrayList; +import java.util.List; + @RunWith(JUnit4.class) public class OdpExampleStoreIteratorFactoryTest { private final Context mContext = ApplicationProvider.getApplicationContext(); @@ -46,8 +49,12 @@ public class OdpExampleStoreIteratorFactoryTest { @Test public void testNext() { - OdpExampleStoreIterator it = OdpExampleStoreIteratorFactory.getInstance(mContext) - .createIterator(); + List<byte[]> exampleList = new ArrayList<>(); + exampleList.add(new byte[] {1}); + List<byte[]> tokenList = new ArrayList<>(); + tokenList.add(new byte[] {2}); + OdpExampleStoreIterator it = + OdpExampleStoreIteratorFactory.getInstance().createIterator(exampleList, tokenList); it.next(new TestIteratorCallback()); assertTrue(mIteratorCallbackOnSuccessCalled); assertFalse(mIteratorCallbackOnFailureCalled); @@ -63,7 +70,7 @@ public class OdpExampleStoreIteratorFactoryTest { @Override public void onIteratorNextFailure(int errorCode) { - mIteratorCallbackOnSuccessCalled = true; + mIteratorCallbackOnFailureCalled = true; } } } diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreIteratorTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreIteratorTest.java index 07eb45a2..df95652e 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreIteratorTest.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreIteratorTest.java @@ -16,7 +16,13 @@ package com.android.ondevicepersonalization.services.federatedcompute; +import static android.federatedcompute.common.ClientConstants.EXTRA_EXAMPLE_ITERATOR_RESULT; +import static android.federatedcompute.common.ClientConstants.EXTRA_EXAMPLE_ITERATOR_RESUMPTION_TOKEN; + +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import android.federatedcompute.ExampleStoreIterator; @@ -26,6 +32,8 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.CountDownLatch; @RunWith(JUnit4.class) @@ -37,16 +45,51 @@ public class OdpExampleStoreIteratorTest { @Test public void testNext() { - OdpExampleStoreIterator it = new OdpExampleStoreIterator(); - it.next(new TestIteratorCallback()); + List<byte[]> exampleList = new ArrayList<>(); + exampleList.add(new byte[] {1}); + List<byte[]> tokenList = new ArrayList<>(); + tokenList.add(new byte[] {2}); + OdpExampleStoreIterator it = new OdpExampleStoreIterator(exampleList, tokenList); + it.next(new TestIteratorCallback(new byte[] {1}, new byte[] {2})); + assertTrue(mIteratorCallbackOnSuccessCalled); + assertFalse(mIteratorCallbackOnFailureCalled); + mIteratorCallbackOnSuccessCalled = false; + it.next(new TestIteratorCallback(null, null)); assertTrue(mIteratorCallbackOnSuccessCalled); assertFalse(mIteratorCallbackOnFailureCalled); } + @Test + public void testConstructorError() { + List<byte[]> exampleList = new ArrayList<>(); + exampleList.add(new byte[] {1}); + List<byte[]> tokenList = new ArrayList<>(); + assertThrows( + IllegalArgumentException.class, + () -> new OdpExampleStoreIterator(exampleList, tokenList)); + } + public class TestIteratorCallback implements ExampleStoreIterator.IteratorCallback { + byte[] mExpectedExample; + byte[] mExpectedResumptionToken; + + TestIteratorCallback(byte[] expectedExample, byte[] expectedResumptionToken) { + mExpectedExample = expectedExample; + mExpectedResumptionToken = expectedResumptionToken; + } + @Override public boolean onIteratorNextSuccess(Bundle result) { + if (mExpectedExample == null) { + assertNull(result); + } else { + assertArrayEquals( + mExpectedExample, result.getByteArray(EXTRA_EXAMPLE_ITERATOR_RESULT)); + assertArrayEquals( + mExpectedResumptionToken, + result.getByteArray(EXTRA_EXAMPLE_ITERATOR_RESUMPTION_TOKEN)); + } mIteratorCallbackOnSuccessCalled = true; mLatch.countDown(); return true; @@ -54,7 +97,7 @@ public class OdpExampleStoreIteratorTest { @Override public void onIteratorNextFailure(int errorCode) { - mIteratorCallbackOnSuccessCalled = true; + mIteratorCallbackOnFailureCalled = true; mLatch.countDown(); } } diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreServiceTests.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreServiceTests.java index 7f79ea0f..472f2454 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreServiceTests.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreServiceTests.java @@ -17,62 +17,219 @@ package com.android.ondevicepersonalization.services.federatedcompute; import static android.federatedcompute.common.ClientConstants.EXAMPLE_STORE_ACTION; +import static android.federatedcompute.common.ClientConstants.EXTRA_EXAMPLE_ITERATOR_RESULT; +import static android.federatedcompute.common.ClientConstants.EXTRA_EXAMPLE_ITERATOR_RESUMPTION_TOKEN; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.when; +import static org.mockito.MockitoAnnotations.initMocks; import android.content.Context; import android.content.Intent; +import android.content.pm.PackageManager; import android.federatedcompute.aidl.IExampleStoreCallback; import android.federatedcompute.aidl.IExampleStoreIterator; +import android.federatedcompute.aidl.IExampleStoreIteratorCallback; import android.federatedcompute.aidl.IExampleStoreService; -import android.net.Uri; +import android.federatedcompute.common.ClientConstants; import android.os.Bundle; import android.os.IBinder; import android.os.RemoteException; import androidx.test.core.app.ApplicationProvider; -import androidx.test.rule.ServiceTestRule; -import org.junit.Rule; +import com.android.ondevicepersonalization.services.data.OnDevicePersonalizationDbHelper; +import com.android.ondevicepersonalization.services.data.events.EventState; +import com.android.ondevicepersonalization.services.data.events.EventsDao; + +import org.junit.After; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.InjectMocks; +import org.mockito.Mock; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @RunWith(JUnit4.class) public class OdpExampleStoreServiceTests { - @Rule - public final ServiceTestRule serviceRule = new ServiceTestRule(); private final Context mContext = ApplicationProvider.getApplicationContext(); - private final CountDownLatch mLatch = new CountDownLatch(1); + @Mock Context mMockContext; + @InjectMocks OdpExampleStoreService mService; + private CountDownLatch mLatch; + private boolean mIteratorCallbackOnSuccessCalled = false; + private boolean mIteratorCallbackOnFailureCalled = false; private boolean mQueryCallbackOnSuccessCalled = false; private boolean mQueryCallbackOnFailureCalled = false; + private final EventsDao mEventsDao = EventsDao.getInstanceForTest(mContext); + + @Before + public void setUp() { + initMocks(this); + when(mMockContext.getApplicationContext()).thenReturn(mContext); + mQueryCallbackOnSuccessCalled = false; + mQueryCallbackOnFailureCalled = false; + mLatch = new CountDownLatch(1); + } + @Test public void testWithStartQuery() throws Exception { - Intent mIntent = new Intent(); - mIntent.setAction(EXAMPLE_STORE_ACTION).setPackage(mContext.getPackageName()); - mIntent.setData( - new Uri.Builder().scheme("app").authority(mContext.getPackageName()) - .path("collection").build()); - IBinder binder = serviceRule.bindService(mIntent); + mEventsDao.updateOrInsertEventState( + new EventState.Builder() + .setTaskIdentifier("PopulationName") + .setServicePackageName(mContext.getPackageName()) + .setToken() + .build()); + mService.onCreate(); + Intent intent = new Intent(); + intent.setAction(EXAMPLE_STORE_ACTION).setPackage(mContext.getPackageName()); + IExampleStoreService binder = + IExampleStoreService.Stub.asInterface(mService.onBind(intent)); assertNotNull(binder); - ((IExampleStoreService.Stub) binder).startQuery(Bundle.EMPTY, new TestQueryCallback()); - mLatch.await(1000, TimeUnit.MILLISECONDS); + TestQueryCallback callback = new TestQueryCallback(); + Bundle input = new Bundle(); + ContextData contextData = new ContextData(mContext.getPackageName()); + input.putByteArray( + ClientConstants.EXTRA_CONTEXT_DATA, ContextData.toByteArray(contextData)); + input.putString(ClientConstants.EXTRA_POPULATION_NAME, "PopulationName"); + input.putString(ClientConstants.EXTRA_TASK_NAME, "TaskName"); + + binder.startQuery(input, callback); + assertTrue( + "timeout reached while waiting for countdownlatch!", + mLatch.await(1000, TimeUnit.MILLISECONDS)); + assertTrue(mQueryCallbackOnSuccessCalled); assertFalse(mQueryCallbackOnFailureCalled); + + IExampleStoreIterator iterator = callback.getIterator(); + TestIteratorCallback iteratorCallback = new TestIteratorCallback(); + mLatch = new CountDownLatch(1); + iteratorCallback.setExpected(new byte[] {10}, "token1".getBytes()); + iterator.next(iteratorCallback); + assertTrue( + "timeout reached while waiting for countdownlatch!", + mLatch.await(1000, TimeUnit.MILLISECONDS)); + assertTrue(mIteratorCallbackOnSuccessCalled); + assertFalse(mIteratorCallbackOnFailureCalled); + mIteratorCallbackOnSuccessCalled = false; + + mLatch = new CountDownLatch(1); + iteratorCallback.setExpected(new byte[] {20}, "token2".getBytes()); + iterator.next(iteratorCallback); + assertTrue( + "timeout reached while waiting for countdownlatch!", + mLatch.await(1000, TimeUnit.MILLISECONDS)); + assertTrue(mIteratorCallbackOnSuccessCalled); + assertFalse(mIteratorCallbackOnFailureCalled); + } + + @Test + public void testWithStartQueryNotValidJob() throws Exception { + mService.onCreate(); + Intent intent = new Intent(); + intent.setAction(EXAMPLE_STORE_ACTION).setPackage(mContext.getPackageName()); + IExampleStoreService binder = + IExampleStoreService.Stub.asInterface(mService.onBind(intent)); + assertNotNull(binder); + TestQueryCallback callback = new TestQueryCallback(); + Bundle input = new Bundle(); + ContextData contextData = new ContextData(mContext.getPackageName()); + input.putByteArray( + ClientConstants.EXTRA_CONTEXT_DATA, ContextData.toByteArray(contextData)); + input.putString(ClientConstants.EXTRA_POPULATION_NAME, "PopulationName"); + input.putString(ClientConstants.EXTRA_TASK_NAME, "TaskName"); + + ((IExampleStoreService.Stub) binder).startQuery(input, callback); + mLatch.await(1000, TimeUnit.MILLISECONDS); + + assertFalse(mQueryCallbackOnSuccessCalled); + assertTrue(mQueryCallbackOnFailureCalled); + } + + @Test + public void testWithStartQueryBadInput() throws Exception { + mService.onCreate(); + Intent intent = new Intent(); + intent.setAction(EXAMPLE_STORE_ACTION).setPackage(mContext.getPackageName()); + IExampleStoreService binder = + IExampleStoreService.Stub.asInterface(mService.onBind(intent)); + assertNotNull(binder); + TestQueryCallback callback = new TestQueryCallback(); + binder.startQuery(Bundle.EMPTY, callback); + mLatch.await(1000, TimeUnit.MILLISECONDS); + assertFalse(mQueryCallbackOnSuccessCalled); + assertTrue(mQueryCallbackOnFailureCalled); + } + + @Test + public void testFailedPermissionCheck() throws Exception { + when(mMockContext.checkCallingOrSelfPermission( + eq("android.permission.BIND_EXAMPLE_STORE_SERVICE"))) + .thenReturn(PackageManager.PERMISSION_DENIED); + mService.onCreate(); + Intent intent = new Intent(); + intent.setAction(EXAMPLE_STORE_ACTION).setPackage(mContext.getPackageName()); + IExampleStoreService binder = + IExampleStoreService.Stub.asInterface(mService.onBind(intent)); + + assertThrows( + SecurityException.class, + () -> binder.startQuery(Bundle.EMPTY, new TestQueryCallback())); + + mLatch.await(1000, TimeUnit.MILLISECONDS); + assertFalse(mQueryCallbackOnSuccessCalled); + assertFalse(mQueryCallbackOnFailureCalled); + } + + public class TestIteratorCallback implements IExampleStoreIteratorCallback { + byte[] mExpectedExample; + byte[] mExpectedResumptionToken; + + public void setExpected(byte[] expectedExample, byte[] expectedResumptionToken) { + mExpectedExample = expectedExample; + mExpectedResumptionToken = expectedResumptionToken; + } + + @Override + public void onIteratorNextSuccess(Bundle result) throws RemoteException { + assertArrayEquals(mExpectedExample, result.getByteArray(EXTRA_EXAMPLE_ITERATOR_RESULT)); + assertArrayEquals( + mExpectedResumptionToken, + result.getByteArray(EXTRA_EXAMPLE_ITERATOR_RESUMPTION_TOKEN)); + mIteratorCallbackOnSuccessCalled = true; + mLatch.countDown(); + } + + @Override + public void onIteratorNextFailure(int i) throws RemoteException { + mIteratorCallbackOnFailureCalled = true; + mLatch.countDown(); + } + + @Override + public IBinder asBinder() { + return null; + } } public class TestQueryCallback implements IExampleStoreCallback { + private IExampleStoreIterator mIterator; + @Override public void onStartQuerySuccess(IExampleStoreIterator iExampleStoreIterator) throws RemoteException { mQueryCallbackOnSuccessCalled = true; + mIterator = iExampleStoreIterator; mLatch.countDown(); } @@ -86,5 +243,18 @@ public class OdpExampleStoreServiceTests { public IBinder asBinder() { return null; } + + public IExampleStoreIterator getIterator() { + return mIterator; + } + } + + @After + public void cleanup() { + OnDevicePersonalizationDbHelper dbHelper = + OnDevicePersonalizationDbHelper.getInstanceForTest(mContext); + dbHelper.getWritableDatabase().close(); + dbHelper.getReadableDatabase().close(); + dbHelper.close(); } } diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/federatedcompute/OdpFederatedComputeJobServiceTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/federatedcompute/OdpFederatedComputeJobServiceTest.java deleted file mode 100644 index cc5c4546..00000000 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/federatedcompute/OdpFederatedComputeJobServiceTest.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Copyright (C) 2023 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.android.ondevicepersonalization.services.federatedcompute; - -import static org.junit.Assert.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; - -import android.app.job.JobParameters; -import android.content.Context; -import android.federatedcompute.FederatedComputeManager; - -import androidx.test.core.app.ApplicationProvider; - -import com.android.dx.mockito.inline.extended.ExtendedMockito; -import com.android.ondevicepersonalization.services.OnDevicePersonalizationExecutors; -import com.android.ondevicepersonalization.services.PhFlagsTestUtil; - -import com.google.common.util.concurrent.MoreExecutors; - -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.MockitoSession; -import org.mockito.quality.Strictness; - -@RunWith(JUnit4.class) -public class OdpFederatedComputeJobServiceTest { - private final Context mContext = ApplicationProvider.getApplicationContext(); - private OdpFederatedComputeJobService mSpyService; - - @Before - public void setup() throws Exception { - PhFlagsTestUtil.setUpDeviceConfigPermissions(); - PhFlagsTestUtil.disableGlobalKillSwitch(); - mSpyService = spy(new OdpFederatedComputeJobService()); - } - - @Test - public void onStartJobTest() { - MockitoSession session = ExtendedMockito.mockitoSession().spyStatic( - OnDevicePersonalizationExecutors.class).strictness( - Strictness.LENIENT).startMocking(); - try { - FederatedComputeManager mockManager = mock(FederatedComputeManager.class); - doNothing().when(mSpyService).jobFinished(any(), anyBoolean()); - doNothing().when(mockManager).scheduleFederatedCompute(any(), any(), any()); - doReturn(mockManager).when(mSpyService).getSystemService(FederatedComputeManager.class); - ExtendedMockito.doReturn(MoreExecutors.newDirectExecutorService()).when( - OnDevicePersonalizationExecutors::getBackgroundExecutor); - ExtendedMockito.doReturn(MoreExecutors.newDirectExecutorService()).when( - OnDevicePersonalizationExecutors::getLightweightExecutor); - - boolean result = mSpyService.onStartJob(mock(JobParameters.class)); - assertTrue(result); - verify(mSpyService, times(1)).jobFinished(any(), eq(false)); - verify(mockManager, times(1)) - .scheduleFederatedCompute(any(), any(), any()); - } finally { - session.finishMocking(); - } - } - - @Test - public void onStartJobTestKillSwitchEnabled() { - PhFlagsTestUtil.enableGlobalKillSwitch(); - MockitoSession session = ExtendedMockito.mockitoSession().strictness( - Strictness.LENIENT).startMocking(); - try { - FederatedComputeManager mockManager = mock(FederatedComputeManager.class); - doNothing().when(mSpyService).jobFinished(any(), anyBoolean()); - boolean result = mSpyService.onStartJob(mock(JobParameters.class)); - assertTrue(result); - verify(mSpyService, times(1)).jobFinished(any(), eq(false)); - verify(mockManager, times(0)) - .scheduleFederatedCompute(any(), any(), any()); - } finally { - session.finishMocking(); - } - } - - @Test - public void onStopJobTest() { - MockitoSession session = ExtendedMockito.mockitoSession().strictness( - Strictness.LENIENT).startMocking(); - try { - assertTrue(mSpyService.onStopJob(mock(JobParameters.class))); - } finally { - session.finishMocking(); - } - } -} diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/federatedcompute/OdpResultHandlingServiceTests.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/federatedcompute/OdpResultHandlingServiceTests.java index cc16745d..2db657e3 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/federatedcompute/OdpResultHandlingServiceTests.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/federatedcompute/OdpResultHandlingServiceTests.java @@ -17,8 +17,10 @@ package com.android.ondevicepersonalization.services.federatedcompute; import static android.federatedcompute.common.ClientConstants.RESULT_HANDLING_SERVICE_ACTION; -import static android.federatedcompute.common.TrainingInterval.SCHEDULING_MODE_ONE_TIME; +import static android.federatedcompute.common.ClientConstants.STATUS_SUCCESS; +import static android.federatedcompute.common.ClientConstants.STATUS_TRAINING_FAILED; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; @@ -27,23 +29,27 @@ import android.content.Context; import android.content.Intent; import android.federatedcompute.aidl.IFederatedComputeCallback; import android.federatedcompute.aidl.IResultHandlingService; +import android.federatedcompute.common.ClientConstants; import android.federatedcompute.common.ExampleConsumption; -import android.federatedcompute.common.TrainingInterval; -import android.federatedcompute.common.TrainingOptions; -import android.net.Uri; +import android.os.Bundle; import android.os.IBinder; import android.os.RemoteException; import androidx.test.core.app.ApplicationProvider; import androidx.test.rule.ServiceTestRule; -import com.google.common.collect.ImmutableList; +import com.android.ondevicepersonalization.services.data.OnDevicePersonalizationDbHelper; +import com.android.ondevicepersonalization.services.data.events.EventState; +import com.android.ondevicepersonalization.services.data.events.EventsDao; +import org.junit.After; +import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import java.util.ArrayList; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -56,37 +62,85 @@ public class OdpResultHandlingServiceTests { private boolean mCallbackOnSuccessCalled = false; private boolean mCallbackOnFailureCalled = false; + private EventsDao mEventsDao; + + @Before + public void setup() { + mEventsDao = EventsDao.getInstanceForTest(mContext); + } + + @After + public void cleanup() { + OnDevicePersonalizationDbHelper dbHelper = + OnDevicePersonalizationDbHelper.getInstanceForTest(mContext); + dbHelper.getWritableDatabase().close(); + dbHelper.getReadableDatabase().close(); + dbHelper.close(); + } + @Test public void testHandleResult() throws Exception { Intent mIntent = new Intent(); mIntent.setAction(RESULT_HANDLING_SERVICE_ACTION).setPackage(mContext.getPackageName()); - mIntent.setData( - new Uri.Builder() - .scheme("app") - .authority(mContext.getPackageName()) - .path("collection") + IBinder binder = serviceRule.bindService(mIntent); + assertNotNull(binder); + + Bundle input = new Bundle(); + ContextData contextData = new ContextData(mContext.getPackageName()); + input.putByteArray( + ClientConstants.EXTRA_CONTEXT_DATA, ContextData.toByteArray(contextData)); + input.putString(ClientConstants.EXTRA_POPULATION_NAME, "population"); + input.putString(ClientConstants.EXTRA_TASK_NAME, "task_name"); + input.putInt(ClientConstants.EXTRA_COMPUTATION_RESULT, STATUS_SUCCESS); + ArrayList<ExampleConsumption> exampleConsumptions = new ArrayList<>(); + exampleConsumptions.add( + new ExampleConsumption.Builder() + .setTaskName("task_name") + .setExampleCount(100) + .setSelectionCriteria(new byte[] {10, 0, 1}) + .setResumptionToken(new byte[] {10, 0, 1}) .build()); + input.putParcelableArrayList( + ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, exampleConsumptions); + + ((IResultHandlingService.Stub) binder).handleResult(input, new TestCallback()); + mLatch.await(1000, TimeUnit.MILLISECONDS); + assertTrue(mCallbackOnSuccessCalled); + assertFalse(mCallbackOnFailureCalled); + + EventState state1 = + mEventsDao.getEventState( + OdpExampleStoreService.getTaskIdentifier("population", "task_name"), + mContext.getPackageName()); + assertArrayEquals(new byte[] {10, 0, 1}, state1.getToken()); + } + + @Test + public void testHandleResultTrainingFailed() throws Exception { + Intent mIntent = new Intent(); + mIntent.setAction(RESULT_HANDLING_SERVICE_ACTION).setPackage(mContext.getPackageName()); IBinder binder = serviceRule.bindService(mIntent); assertNotNull(binder); - TrainingOptions trainingOptions = - new TrainingOptions.Builder() - .setPopulationName("population") - .setTrainingInterval( - new TrainingInterval.Builder() - .setSchedulingMode(SCHEDULING_MODE_ONE_TIME) - .build()) - .build(); - ImmutableList<ExampleConsumption> exampleConsumptions = - ImmutableList.of( - new ExampleConsumption.Builder() - .setCollectionName("collection") - .setExampleCount(100) - .setSelectionCriteria(new byte[] {10, 0, 1}) - .build()); - - ((IResultHandlingService.Stub) binder) - .handleResult(trainingOptions, true, exampleConsumptions, new TestCallback()); + Bundle input = new Bundle(); + ContextData contextData = new ContextData(mContext.getPackageName()); + input.putByteArray( + ClientConstants.EXTRA_CONTEXT_DATA, ContextData.toByteArray(contextData)); + input.putString(ClientConstants.EXTRA_POPULATION_NAME, "population"); + input.putString(ClientConstants.EXTRA_TASK_NAME, "task_name"); + input.putInt(ClientConstants.EXTRA_COMPUTATION_RESULT, STATUS_TRAINING_FAILED); + ArrayList<ExampleConsumption> exampleConsumptions = new ArrayList<>(); + exampleConsumptions.add( + new ExampleConsumption.Builder() + .setTaskName("task") + .setExampleCount(100) + .setSelectionCriteria(new byte[] {10, 0, 1}) + .setResumptionToken(new byte[] {10, 0, 1}) + .build()); + input.putParcelableArrayList( + ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, exampleConsumptions); + + ((IResultHandlingService.Stub) binder).handleResult(input, new TestCallback()); mLatch.await(1000, TimeUnit.MILLISECONDS); assertTrue(mCallbackOnSuccessCalled); assertFalse(mCallbackOnFailureCalled); diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/maintenance/OnDevicePersonalizationMaintenanceJobServiceTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/maintenance/OnDevicePersonalizationMaintenanceJobServiceTest.java index 01b88e74..44c1977f 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/maintenance/OnDevicePersonalizationMaintenanceJobServiceTest.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/maintenance/OnDevicePersonalizationMaintenanceJobServiceTest.java @@ -17,6 +17,8 @@ package com.android.ondevicepersonalization.services.maintenance; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; @@ -37,6 +39,11 @@ import com.android.dx.mockito.inline.extended.ExtendedMockito; import com.android.ondevicepersonalization.services.OnDevicePersonalizationExecutors; import com.android.ondevicepersonalization.services.PhFlagsTestUtil; import com.android.ondevicepersonalization.services.data.OnDevicePersonalizationDbHelper; +import com.android.ondevicepersonalization.services.data.events.Event; +import com.android.ondevicepersonalization.services.data.events.EventState; +import com.android.ondevicepersonalization.services.data.events.EventsDao; +import com.android.ondevicepersonalization.services.data.events.Query; +import com.android.ondevicepersonalization.services.data.user.UserPrivacyStatus; import com.android.ondevicepersonalization.services.data.vendor.OnDevicePersonalizationVendorDataDao; import com.android.ondevicepersonalization.services.data.vendor.VendorData; import com.android.ondevicepersonalization.services.util.PackageUtils; @@ -51,6 +58,7 @@ import org.junit.runners.JUnit4; import org.mockito.MockitoSession; import org.mockito.quality.Strictness; +import java.nio.charset.StandardCharsets; import java.util.AbstractMap; import java.util.ArrayList; import java.util.List; @@ -60,12 +68,17 @@ import java.util.Map; public class OnDevicePersonalizationMaintenanceJobServiceTest { private static final String TEST_OWNER = "owner"; private static final String TEST_CERT_DIGEST = "certDigest"; + private static final String TASK_IDENTIFIER = "task"; private final Context mContext = ApplicationProvider.getApplicationContext(); private OnDevicePersonalizationVendorDataDao mTestDao; private OnDevicePersonalizationVendorDataDao mDao; + + private EventsDao mEventsDao; private OnDevicePersonalizationMaintenanceJobService mSpyService; + private UserPrivacyStatus mPrivacyStatus = UserPrivacyStatus.getInstance(); private static void addTestData(long timestamp, OnDevicePersonalizationVendorDataDao dao) { + // Add vendor data List<VendorData> dataList = new ArrayList<>(); dataList.add(new VendorData.Builder().setKey("key").setData(new byte[10]).build()); dataList.add(new VendorData.Builder().setKey("key2").setData(new byte[10]).build()); @@ -76,15 +89,44 @@ public class OnDevicePersonalizationMaintenanceJobServiceTest { timestamp)); } + private void addEventData(String packageName, long timestamp) { + Query query = new Query.Builder() + .setTimeMillis(timestamp) + .setServicePackageName(packageName) + .setQueryData("query".getBytes(StandardCharsets.UTF_8)) + .build(); + long queryId = mEventsDao.insertQuery(query); + + Event event = new Event.Builder() + .setType(1) + .setEventData("event".getBytes(StandardCharsets.UTF_8)) + .setServicePackageName(packageName) + .setQueryId(queryId) + .setTimeMillis(timestamp) + .setRowIndex(0) + .build(); + mEventsDao.insertEvent(event); + + EventState eventState = new EventState.Builder() + .setTaskIdentifier(TASK_IDENTIFIER) + .setServicePackageName(packageName) + .setToken(new byte[]{1}) + .build(); + mEventsDao.updateOrInsertEventState(eventState); + } + @Before public void setup() throws Exception { PhFlagsTestUtil.setUpDeviceConfigPermissions(); PhFlagsTestUtil.disableGlobalKillSwitch(); + PhFlagsTestUtil.disablePersonalizationStatusOverride(); + mPrivacyStatus.setPersonalizationStatusEnabled(true); mTestDao = OnDevicePersonalizationVendorDataDao.getInstanceForTest(mContext, TEST_OWNER, TEST_CERT_DIGEST); mDao = OnDevicePersonalizationVendorDataDao.getInstanceForTest(mContext, mContext.getPackageName(), PackageUtils.getCertDigest(mContext, mContext.getPackageName())); + mEventsDao = EventsDao.getInstanceForTest(mContext); mSpyService = spy(new OnDevicePersonalizationMaintenanceJobService()); } @@ -125,6 +167,20 @@ public class OnDevicePersonalizationMaintenanceJobServiceTest { } @Test + public void onStartJobTestPersonalizationBlocked() { + mPrivacyStatus.setPersonalizationStatusEnabled(false); + MockitoSession session = ExtendedMockito.mockitoSession().startMocking(); + try { + doNothing().when(mSpyService).jobFinished(any(), anyBoolean()); + boolean result = mSpyService.onStartJob(mock(JobParameters.class)); + assertTrue(result); + verify(mSpyService, times(1)).jobFinished(any(), eq(false)); + } finally { + session.finishMocking(); + } + } + + @Test public void onStopJobTest() { MockitoSession session = ExtendedMockito.mockitoSession().strictness( Strictness.LENIENT).startMocking(); @@ -139,12 +195,23 @@ public class OnDevicePersonalizationMaintenanceJobServiceTest { public void testVendorDataCleanup() throws Exception { addTestData(System.currentTimeMillis(), mTestDao); addTestData(System.currentTimeMillis(), mDao); + addEventData(mContext.getPackageName(), System.currentTimeMillis()); + addEventData(mContext.getPackageName(), 100L); + addEventData(TEST_OWNER, System.currentTimeMillis()); + OnDevicePersonalizationMaintenanceJobService.cleanupVendorData(mContext); List<Map.Entry<String, String>> vendors = OnDevicePersonalizationVendorDataDao.getVendors( mContext); assertEquals(1, vendors.size()); assertEquals(new AbstractMap.SimpleEntry<>(mContext.getPackageName(), PackageUtils.getCertDigest(mContext, mContext.getPackageName())), vendors.get(0)); + + assertNull(mEventsDao.getEventState(TASK_IDENTIFIER, TEST_OWNER)); + assertNotNull(mEventsDao.getEventState(TASK_IDENTIFIER, mContext.getPackageName())); + + assertEquals(2, mEventsDao.readAllNewRowsForPackage(TEST_OWNER, 0, 0).size()); + assertEquals(2, + mEventsDao.readAllNewRowsForPackage(mContext.getPackageName(), 0, 0).size()); } @After diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/manifest/AppManifestConfigParserTests.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/manifest/AppManifestConfigParserTests.java index bcfaf3a7..80b2cb68 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/manifest/AppManifestConfigParserTests.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/manifest/AppManifestConfigParserTests.java @@ -32,6 +32,7 @@ public class AppManifestConfigParserTests { "<on-device-personalization>" + " <service name=\"com.example.TestService\" >" + " <download-settings url=\"http://example.com/get\" />" + + " <federated-compute-settings url=\"http://google.com/get\" />" + " </service>" + "</on-device-personalization>"; @@ -45,5 +46,6 @@ public class AppManifestConfigParserTests { AppManifestConfig config = AppManifestConfigParser.getConfig(xpp); assertEquals("com.example.TestService", config.getServiceName()); assertEquals("http://example.com/get", config.getDownloadUrl()); + assertEquals("http://google.com/get", config.getFcRemoteServerUrl()); } } diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/manifest/AppManifestConfigTests.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/manifest/AppManifestConfigTests.java index 7f9d94d7..50b1f034 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/manifest/AppManifestConfigTests.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/manifest/AppManifestConfigTests.java @@ -34,6 +34,9 @@ import org.junit.runners.JUnit4; public class AppManifestConfigTests { private static final String BASE_DOWNLOAD_URL = "android.resource://com.android.ondevicepersonalization.servicetests/raw/test_data1"; + + private static final String FC_SERVER_URL = + "https://google.com"; private final Context mContext = ApplicationProvider.getApplicationContext(); @Test @@ -54,6 +57,7 @@ public class AppManifestConfigTests { AppManifestConfigHelper.getAppManifestConfig(mContext, mContext.getPackageName()); assertEquals(BASE_DOWNLOAD_URL, config.getDownloadUrl()); assertEquals("com.test.TestPersonalizationService", config.getServiceName()); + assertEquals(FC_SERVER_URL, config.getFcRemoteServerUrl()); } @Test diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/policyengine/UserDataReaderTest.kt b/tests/servicetests/src/com/android/ondevicepersonalization/services/policyengine/UserDataReaderTest.kt index 217621ba..7353e4da 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/policyengine/UserDataReaderTest.kt +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/policyengine/UserDataReaderTest.kt @@ -26,7 +26,7 @@ import org.junit.Test import android.util.Log -import android.adservices.ondevicepersonalization.AppInstallInfo +import android.adservices.ondevicepersonalization.AppInfo import android.adservices.ondevicepersonalization.AppUsageStatus import android.adservices.ondevicepersonalization.DeviceMetrics import android.adservices.ondevicepersonalization.OSVersion @@ -143,13 +143,13 @@ class UserDataReaderTest : ProcessorNode { @Test fun testAppInstallInfo() { - var appInstallStatus1 = AppInstallInfo.Builder() + var appInstallStatus1 = AppInfo.Builder() .setInstalled(true) .build() var parcel = Parcel.obtain() appInstallStatus1.writeToParcel(parcel, 0) parcel.setDataPosition(0); - var appInstallStatus2 = AppInstallInfo.CREATOR.createFromParcel(parcel) + var appInstallStatus2 = AppInfo.CREATOR.createFromParcel(parcel) assertThat(appInstallStatus1).isEqualTo(appInstallStatus2) assertThat(appInstallStatus1.hashCode()).isEqualTo(appInstallStatus2.hashCode()) assertThat(appInstallStatus1.describeContents()).isEqualTo(0) @@ -242,7 +242,7 @@ class UserDataReaderTest : ProcessorNode { @Test fun testUserData() { - val appInstalledHistory: Map<String, AppInstallInfo> = mapOf<String, AppInstallInfo>(); + val appInstalledHistory: Map<String, AppInfo> = mapOf<String, AppInfo>(); val appUsageHistory: List<AppUsageStatus> = listOf(); var location = Location.Builder() .setTimestampSeconds(111111) @@ -258,10 +258,8 @@ class UserDataReaderTest : ProcessorNode { .setAvailableStorageBytes(222) .setBatteryPercentage(33) .setCarrier("AT_T") - .setConnectionType(2) - .setNetworkConnectionSpeedKbps(666) - .setNetworkMetered(true) - .setAppInstallInfo(appInstalledHistory) + .setDataNetworkType(1) + .setAppInfos(appInstalledHistory) .setAppUsageHistory(appUsageHistory) .setCurrentLocation(location) .setLocationHistory(locationHistory) @@ -282,9 +280,8 @@ class UserDataReaderTest : ProcessorNode { assertThat(userData.getBatteryPercentage()).isEqualTo(ref.batteryPercentage) assertThat(userData.getCarrier()).isEqualTo(ref.carrier.toString()) - assertThat(userData.getConnectionType()).isEqualTo(ref.connectionType.ordinal) - assertThat(userData.getNetworkConnectionSpeedKbps()).isEqualTo(ref.connectionSpeedKbps) - assertThat(userData.isNetworkMetered()).isEqualTo(ref.networkMetered) + assertThat(userData.getNetworkCapabilities()).isEqualTo(ref.networkCapabilities) + assertThat(userData.getDataNetworkType()).isEqualTo(ref.dataNetworkType) val currentLocation: Location = userData.getCurrentLocation() @@ -294,7 +291,7 @@ class UserDataReaderTest : ProcessorNode { assertThat(currentLocation.getLocationProvider()).isEqualTo(rawUserData.currentLocation.provider.ordinal) assertThat(currentLocation.isPreciseLocation()).isEqualTo(rawUserData.currentLocation.isPreciseLocation) - assertThat(userData.getAppInstallInfo().size).isEqualTo(rawUserData.appsInfo.size) + assertThat(userData.getAppInfos().size).isEqualTo(rawUserData.appsInfo.size) assertThat(userData.getAppUsageHistory().size).isEqualTo(rawUserData.appUsageHistory.size) assertThat(userData.getLocationHistory().size).isEqualTo(rawUserData.locationHistory.size) } diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/process/ProcessUtilsTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/process/ProcessRunnerTest.java index 6124a237..7f7bdb66 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/process/ProcessUtilsTest.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/process/ProcessRunnerTest.java @@ -19,6 +19,8 @@ package com.android.ondevicepersonalization.services.process; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import androidx.test.core.app.ApplicationProvider; + import com.android.ondevicepersonalization.libraries.plugin.PluginInfo; import com.google.common.collect.ImmutableList; @@ -28,15 +30,18 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @RunWith(JUnit4.class) -public class ProcessUtilsTest { +public class ProcessRunnerTest { + ProcessRunner mProcessRunner = new ProcessRunner( + ApplicationProvider.getApplicationContext(), + new ProcessRunner.Injector()); @Test public void testGetArchiveList_NullApkList() throws Exception { - assertTrue(ProcessUtils.getArchiveList(null).isEmpty()); + assertTrue(ProcessRunner.getArchiveList(null).isEmpty()); } @Test public void testGetArchiveList() throws Exception { - ImmutableList<PluginInfo.ArchiveInfo> result = ProcessUtils.getArchiveList("fakeApk"); + ImmutableList<PluginInfo.ArchiveInfo> result = ProcessRunner.getArchiveList("fakeApk"); assertEquals(1, result.size()); assertEquals("fakeApk", result.get(0).packageName()); } diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/request/AppRequestFlowTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/request/AppRequestFlowTest.java index 24c8ecfa..88890b26 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/request/AppRequestFlowTest.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/request/AppRequestFlowTest.java @@ -19,17 +19,25 @@ package com.android.ondevicepersonalization.services.request; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import android.adservices.ondevicepersonalization.Constants; import android.adservices.ondevicepersonalization.aidl.IExecuteCallback; import android.content.ComponentName; +import android.content.ContentValues; import android.content.Context; import android.os.PersistableBundle; import androidx.test.core.app.ApplicationProvider; +import com.android.ondevicepersonalization.services.PhFlagsTestUtil; import com.android.ondevicepersonalization.services.data.OnDevicePersonalizationDbHelper; +import com.android.ondevicepersonalization.services.data.events.EventsContract; import com.android.ondevicepersonalization.services.data.events.EventsDao; import com.android.ondevicepersonalization.services.data.events.QueriesContract; +import com.android.ondevicepersonalization.services.data.events.Query; +import com.android.ondevicepersonalization.services.data.user.UserPrivacyStatus; +import com.android.ondevicepersonalization.services.util.OnDevicePersonalizationFlatbufferUtils; +import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; import org.junit.After; @@ -38,6 +46,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import java.util.ArrayList; import java.util.List; import java.util.concurrent.CountDownLatch; @@ -46,6 +55,7 @@ public class AppRequestFlowTest { private final Context mContext = ApplicationProvider.getApplicationContext(); private final CountDownLatch mLatch = new CountDownLatch(1); private OnDevicePersonalizationDbHelper mDbHelper; + private UserPrivacyStatus mUserPrivacyStatus = UserPrivacyStatus.getInstance(); private String mRenderedContent; private boolean mGenerateHtmlCalled; @@ -53,11 +63,26 @@ public class AppRequestFlowTest { private boolean mDisplayHtmlCalled; private boolean mCallbackSuccess; private boolean mCallbackError; + private int mCallbackErrorCode; @Before public void setup() { mDbHelper = OnDevicePersonalizationDbHelper.getInstanceForTest(mContext); + ArrayList<ContentValues> rows = new ArrayList<>(); + ContentValues row1 = new ContentValues(); + row1.put("a", 1); + rows.add(row1); + ContentValues row2 = new ContentValues(); + row2.put("b", 2); + rows.add(row2); + byte[] queryDataBytes = OnDevicePersonalizationFlatbufferUtils.createQueryData( + "com.example.test", "AABBCCDD", rows); + EventsDao.getInstanceForTest(mContext).insertQuery( + new Query.Builder().setServicePackageName(mContext.getPackageName()).setQueryData( + queryDataBytes).build()); EventsDao.getInstanceForTest(mContext); + PhFlagsTestUtil.disablePersonalizationStatusOverride(); + mUserPrivacyStatus.setPersonalizationStatusEnabled(true); } @After @@ -73,23 +98,50 @@ public class AppRequestFlowTest { "abc", new ComponentName(mContext.getPackageName(), "com.test.TestPersonalizationService"), PersistableBundle.EMPTY, - new TestCallback(), mContext, MoreExecutors.newDirectExecutorService()); + new TestCallback(), mContext, 100L, new TestInjector()); appRequestFlow.run(); mLatch.await(); assertTrue(mCallbackSuccess); - assertEquals(1, + assertEquals(2, mDbHelper.getReadableDatabase().query(QueriesContract.QueriesEntry.TABLE_NAME, null, null, null, null, null, null).getCount()); + assertEquals(1, + mDbHelper.getReadableDatabase().query(EventsContract.EventsEntry.TABLE_NAME, null, + null, null, null, null, null).getCount()); + } + + @Test + public void testRunAppRequestFlowPersonalizationDisabled() throws Exception { + mUserPrivacyStatus.setPersonalizationStatusEnabled(false); + AppRequestFlow appRequestFlow = new AppRequestFlow( + "abc", + new ComponentName(mContext.getPackageName(), "com.test.TestPersonalizationService"), + PersistableBundle.EMPTY, + new TestCallback(), mContext, 100L, new TestInjector()); + appRequestFlow.run(); + mLatch.await(); + assertTrue(mCallbackError); + assertEquals(Constants.STATUS_PERSONALIZATION_DISABLED, mCallbackErrorCode); } class TestCallback extends IExecuteCallback.Stub { - @Override public void onSuccess(List<String> tokens) { + @Override + public void onSuccess(List<String> tokens) { mCallbackSuccess = true; mLatch.countDown(); } - @Override public void onError(int errorCode) { + + @Override + public void onError(int errorCode) { mCallbackError = true; + mCallbackErrorCode = errorCode; mLatch.countDown(); } } + + class TestInjector extends AppRequestFlow.Injector { + ListeningExecutorService getExecutor() { + return MoreExecutors.newDirectExecutorService(); + } + } } diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/request/RenderFlowTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/request/RenderFlowTest.java index f254799f..61e2d373 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/request/RenderFlowTest.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/request/RenderFlowTest.java @@ -19,7 +19,8 @@ package com.android.ondevicepersonalization.services.request; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -import android.adservices.ondevicepersonalization.RenderOutput; +import android.adservices.ondevicepersonalization.Constants; +import android.adservices.ondevicepersonalization.RenderOutputParcel; import android.adservices.ondevicepersonalization.RenderingConfig; import android.adservices.ondevicepersonalization.RequestLogRecord; import android.adservices.ondevicepersonalization.aidl.IRequestSurfacePackageCallback; @@ -32,6 +33,8 @@ import android.view.SurfaceControlViewHost.SurfacePackage; import androidx.test.core.app.ApplicationProvider; import com.android.ondevicepersonalization.services.OnDevicePersonalizationExecutors; +import com.android.ondevicepersonalization.services.PhFlagsTestUtil; +import com.android.ondevicepersonalization.services.data.user.UserPrivacyStatus; import com.android.ondevicepersonalization.services.display.DisplayHelper; import com.android.ondevicepersonalization.services.util.CryptUtils; @@ -40,6 +43,7 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -50,6 +54,7 @@ import java.util.concurrent.CountDownLatch; public class RenderFlowTest { private final Context mContext = ApplicationProvider.getApplicationContext(); private final CountDownLatch mLatch = new CountDownLatch(1); + private UserPrivacyStatus mUserPrivacyStatus = UserPrivacyStatus.getInstance(); private String mRenderedContent; private boolean mGenerateHtmlCalled; @@ -57,6 +62,13 @@ public class RenderFlowTest { private boolean mDisplayHtmlCalled; private boolean mCallbackSuccess; private boolean mCallbackError; + private int mCallbackErrorCode; + + @Before + public void setUp() { + PhFlagsTestUtil.disablePersonalizationStatusOverride(); + mUserPrivacyStatus.setPersonalizationStatusEnabled(true); + } @Test public void testRunRenderFlow() throws Exception { @@ -68,6 +80,7 @@ public class RenderFlowTest { 50, new TestCallback(), mContext, + 100L, new TestInjector(), new TestDisplayHelper()); flow.run(); @@ -79,6 +92,26 @@ public class RenderFlowTest { } @Test + public void testRunRenderFlowPersonalizationDisabled() throws Exception { + mUserPrivacyStatus.setPersonalizationStatusEnabled(false); + RenderFlow flow = new RenderFlow( + "token", + new Binder(), + 0, + 100, + 50, + new TestCallback(), + mContext, + 100L, + new TestInjector(), + new TestDisplayHelper()); + flow.run(); + mLatch.await(); + assertTrue(mCallbackError); + assertEquals(Constants.STATUS_PERSONALIZATION_DISABLED, mCallbackErrorCode); + } + + @Test public void testDefaultInjector() throws Exception { RenderFlow.Injector injector = new RenderFlow.Injector(); assertEquals(OnDevicePersonalizationExecutors.getBackgroundExecutor(), @@ -123,7 +156,8 @@ public class RenderFlowTest { super(mContext); } - @Override public String generateHtml(RenderOutput renderContentResult, String packageName) { + @Override public String generateHtml( + RenderOutputParcel renderContentResult, String packageName) { mRenderedContent = renderContentResult.getContent(); mGenerateHtmlCalled = true; return mRenderedContent; @@ -146,6 +180,7 @@ public class RenderFlowTest { } @Override public void onError(int errorCode) { mCallbackError = true; + mCallbackErrorCode = errorCode; mLatch.countDown(); } } diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/util/DebugUtilsTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/util/DebugUtilsTest.java new file mode 100644 index 00000000..61a92722 --- /dev/null +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/util/DebugUtilsTest.java @@ -0,0 +1,96 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.ondevicepersonalization.services.util; + +import static com.android.dx.mockito.inline.extended.ExtendedMockito.doReturn; +import static com.android.dx.mockito.inline.extended.ExtendedMockito.when; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.quality.Strictness.LENIENT; + +import android.content.Context; +import android.os.Build; +import android.provider.Settings; + +import androidx.test.core.app.ApplicationProvider; + +import com.android.dx.mockito.inline.extended.ExtendedMockito; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.MockitoSession; + +public class DebugUtilsTest { + private Context mContext = ApplicationProvider.getApplicationContext(); + private MockitoSession mSession; + + @Before + public void setUp() { + mSession = ExtendedMockito.mockitoSession() + .mockStatic(Build.class) + .mockStatic(Settings.Global.class) + .strictness(LENIENT) + .startMocking(); + } + + @After + public void tearDown() { + mSession.finishMocking(); + } + + @Test + public void isDeveloperModeEnabledReturnsTrueIfDeviceDebuggableAndDevOptionsOff() { + doReturn(true).when(Build::isDebuggable); + disableDeveloperOptions(); + assertTrue(DebugUtils.isDeveloperModeEnabled(mContext)); + } + + @Test + public void isDeveloperModeEnabledReturnsTrueIfDeviceDebuggableAndDevOptionsOn() { + doReturn(true).when(Build::isDebuggable); + enableDeveloperOptions(); + assertTrue(DebugUtils.isDeveloperModeEnabled(mContext)); + } + + @Test + public void isDeveloperModeEnabledReturnsFalseIfDeviceNotDebuggableAndDevOptionsOff() { + doReturn(false).when(Build::isDebuggable); + disableDeveloperOptions(); + assertFalse(DebugUtils.isDeveloperModeEnabled(mContext)); + } + + @Test + public void isDeveloperModeEnabledReturnsTrueIfDeviceNotDebuggableAndDevOptionsOn() { + doReturn(false).when(Build::isDebuggable); + enableDeveloperOptions(); + assertTrue(DebugUtils.isDeveloperModeEnabled(mContext)); + } + + private void enableDeveloperOptions() { + when(Settings.Global.getInt( + mContext.getContentResolver(), Settings.Global.DEVELOPMENT_SETTINGS_ENABLED, 0)) + .thenReturn(1); + } + + private void disableDeveloperOptions() { + when(Settings.Global.getInt( + mContext.getContentResolver(), Settings.Global.DEVELOPMENT_SETTINGS_ENABLED, 0)) + .thenReturn(0); + } +} diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/util/OnDevicePersonalizationFlatbufferUtilsTests.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/util/OnDevicePersonalizationFlatbufferUtilsTests.java index 702df3be..b1b603e0 100644 --- a/tests/servicetests/src/com/android/ondevicepersonalization/services/util/OnDevicePersonalizationFlatbufferUtilsTests.java +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/util/OnDevicePersonalizationFlatbufferUtilsTests.java @@ -161,6 +161,22 @@ public class OnDevicePersonalizationFlatbufferUtilsTests { } @Test + public void testGetContentValuesLengthFromQueryData() { + ArrayList<ContentValues> rows = new ArrayList<>(); + ContentValues row1 = new ContentValues(); + row1.put("a", 1); + rows.add(row1); + ContentValues row2 = new ContentValues(); + row2.put("b", 2); + rows.add(row2); + byte[] queryDataBytes = OnDevicePersonalizationFlatbufferUtils.createQueryData( + "com.example.test", "AABBCCDD", rows); + + assertEquals(2, OnDevicePersonalizationFlatbufferUtils.getContentValuesLengthFromQueryData( + queryDataBytes)); + } + + @Test public void testGetContentValuesFromEventData() { ContentValues data = new ContentValues(); data.put("a", 1); diff --git a/tests/servicetests/src/com/android/ondevicepersonalization/services/util/StatsUtilsTest.java b/tests/servicetests/src/com/android/ondevicepersonalization/services/util/StatsUtilsTest.java new file mode 100644 index 00000000..6eb4b66d --- /dev/null +++ b/tests/servicetests/src/com/android/ondevicepersonalization/services/util/StatsUtilsTest.java @@ -0,0 +1,59 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.ondevicepersonalization.services.util; + +import static org.junit.Assert.assertEquals; + +import android.adservices.ondevicepersonalization.CalleeMetadata; +import android.adservices.ondevicepersonalization.Constants; +import android.os.Bundle; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class StatsUtilsTest { + @Test public void testServiceReturnsElapsedTime() { + CalleeMetadata metadata = new CalleeMetadata.Builder().setElapsedTimeMillis(100).build(); + Bundle bundle = new Bundle(); + bundle.putParcelable(Constants.EXTRA_CALLEE_METADATA, metadata); + assertEquals(50, StatsUtils.getOverheadLatencyMillis(150, bundle)); + } + + @Test public void testServiceReturnsNoResult() { + assertEquals(0, StatsUtils.getOverheadLatencyMillis(150, null)); + } + + @Test public void testServiceReturnsNoMetadata() { + assertEquals(0, StatsUtils.getOverheadLatencyMillis(150, new Bundle())); + } + + @Test public void testServiceReturnsNegativeElapsedTime() { + CalleeMetadata metadata = new CalleeMetadata.Builder().setElapsedTimeMillis(-1).build(); + Bundle bundle = new Bundle(); + bundle.putParcelable(Constants.EXTRA_CALLEE_METADATA, metadata); + assertEquals(0, StatsUtils.getOverheadLatencyMillis(150, bundle)); + } + + @Test public void testServiceReturnsTooHighElapsedTime() { + CalleeMetadata metadata = new CalleeMetadata.Builder().setElapsedTimeMillis(300).build(); + Bundle bundle = new Bundle(); + bundle.putParcelable(Constants.EXTRA_CALLEE_METADATA, metadata); + assertEquals(0, StatsUtils.getOverheadLatencyMillis(150, bundle)); + } +} diff --git a/tests/servicetests/src/com/test/TestPersonalizationHandler.java b/tests/servicetests/src/com/test/TestPersonalizationHandler.java index 63fea98e..bd629c76 100644 --- a/tests/servicetests/src/com/test/TestPersonalizationHandler.java +++ b/tests/servicetests/src/com/test/TestPersonalizationHandler.java @@ -16,9 +16,11 @@ package com.test; -import android.adservices.ondevicepersonalization.DownloadInput; -import android.adservices.ondevicepersonalization.DownloadOutput; +import android.adservices.ondevicepersonalization.DownloadCompletedInput; +import android.adservices.ondevicepersonalization.DownloadCompletedOutput; +import android.adservices.ondevicepersonalization.EventInput; import android.adservices.ondevicepersonalization.EventLogRecord; +import android.adservices.ondevicepersonalization.EventOutput; import android.adservices.ondevicepersonalization.ExecuteInput; import android.adservices.ondevicepersonalization.ExecuteOutput; import android.adservices.ondevicepersonalization.IsolatedWorker; @@ -27,8 +29,8 @@ import android.adservices.ondevicepersonalization.RenderInput; import android.adservices.ondevicepersonalization.RenderOutput; import android.adservices.ondevicepersonalization.RenderingConfig; import android.adservices.ondevicepersonalization.RequestLogRecord; -import android.adservices.ondevicepersonalization.WebViewEventInput; -import android.adservices.ondevicepersonalization.WebViewEventOutput; +import android.adservices.ondevicepersonalization.TrainingExampleInput; +import android.adservices.ondevicepersonalization.TrainingExampleOutput; import android.annotation.NonNull; import android.content.ContentValues; import android.util.Log; @@ -50,63 +52,67 @@ public class TestPersonalizationHandler implements IsolatedWorker { } @Override - public void onDownload(DownloadInput input, Consumer<DownloadOutput> consumer) { + public void onDownloadCompleted( + DownloadCompletedInput input, Consumer<DownloadCompletedOutput> consumer) { try { Log.d(TAG, "Starting filterData."); Log.d(TAG, "Data: " + input.getData()); - Log.d(TAG, "Existing keyExtra: " - + Arrays.toString(mRemoteData.get("keyExtra"))); + Log.d(TAG, "Existing keyExtra: " + Arrays.toString(mRemoteData.get("keyExtra"))); Log.d(TAG, "Existing keySet: " + mRemoteData.keySet()); - List<String> keysToRetain = - getFilteredKeys(input.getData()); + List<String> keysToRetain = getFilteredKeys(input.getData()); keysToRetain.add("keyExtra"); // Get the keys to keep from the downloaded data - DownloadOutput result = - new DownloadOutput.Builder() - .setRetainedKeys(keysToRetain) - .build(); + DownloadCompletedOutput result = + new DownloadCompletedOutput.Builder().setRetainedKeys(keysToRetain).build(); consumer.accept(result); } catch (Exception e) { Log.e(TAG, "Error occurred in onDownload", e); } } - @Override public void onExecute( - @NonNull ExecuteInput input, - @NonNull Consumer<ExecuteOutput> consumer - ) { + @Override + public void onExecute(@NonNull ExecuteInput input, @NonNull Consumer<ExecuteOutput> consumer) { Log.d(TAG, "onExecute() started."); ContentValues logData = new ContentValues(); logData.put("id", "bid1"); logData.put("pr", 5.0); - ExecuteOutput result = new ExecuteOutput.Builder() - .setRequestLogRecord(new RequestLogRecord.Builder().addRow(logData).build()) - .addRenderingConfig( - new RenderingConfig.Builder().addKey("bid1").build() - ) - .build(); + ExecuteOutput result = + new ExecuteOutput.Builder() + .setRequestLogRecord(new RequestLogRecord.Builder().addRow(logData).build()) + .addRenderingConfig(new RenderingConfig.Builder().addKey("bid1").build()) + .addEventLogRecord( + new EventLogRecord.Builder() + .setData(logData) + .setRequestLogRecord( + new RequestLogRecord.Builder() + .addRow(logData) + .addRow(logData) + .setRequestId(1) + .build()) + .setType(1) + .setRowIndex(1) + .build()) + .build(); consumer.accept(result); } - @Override public void onRender( - @NonNull RenderInput input, - @NonNull Consumer<RenderOutput> consumer - ) { + @Override + public void onRender(@NonNull RenderInput input, @NonNull Consumer<RenderOutput> consumer) { Log.d(TAG, "onRender() started."); RenderOutput result = new RenderOutput.Builder() - .setContent("<p>RenderResult: " - + String.join(",", input.getRenderingConfig().getKeys()) + "<p>") - .build(); + .setContent( + "<p>RenderResult: " + + String.join(",", input.getRenderingConfig().getKeys()) + + "<p>") + .build(); consumer.accept(result); } - public void onWebViewEvent( - @NonNull WebViewEventInput input, - @NonNull Consumer<WebViewEventOutput> consumer - ) { + @Override + public void onEvent(@NonNull EventInput input, @NonNull Consumer<EventOutput> consumer) { Log.d(TAG, "onEvent() started."); long longValue = 0; if (input.getParameters() != null) { @@ -114,16 +120,15 @@ public class TestPersonalizationHandler implements IsolatedWorker { } ContentValues logData = new ContentValues(); logData.put("x", longValue); - WebViewEventOutput result = - new WebViewEventOutput.Builder() - .setEventLogRecord( - new EventLogRecord.Builder() - .setType(1) - .setRowIndex(0) - .setData(logData) - .build() - ) - .build(); + EventOutput result = + new EventOutput.Builder() + .setEventLogRecord( + new EventLogRecord.Builder() + .setType(1) + .setRowIndex(0) + .setData(logData) + .build()) + .build(); Log.d(TAG, "onEvent() result: " + result.toString()); consumer.accept(result); } @@ -133,4 +138,27 @@ public class TestPersonalizationHandler implements IsolatedWorker { filteredKeys.remove("key3"); return new ArrayList<>(filteredKeys); } + + @Override + public void onTrainingExample( + @NonNull TrainingExampleInput input, + @NonNull Consumer<TrainingExampleOutput> consumer) { + Log.d(TAG, "onTrainingExample() started."); + Log.d(TAG, "Population name: " + input.getPopulationName()); + Log.d(TAG, "Task name: " + input.getTaskName()); + + List<byte[]> examples = new ArrayList<>(); + List<byte[]> tokens = new ArrayList<>(); + examples.add(new byte[] {10}); + examples.add(new byte[] {20}); + tokens.add("token1".getBytes()); + tokens.add("token2".getBytes()); + + TrainingExampleOutput output = + new TrainingExampleOutput.Builder() + .setTrainingExamples(examples) + .setResumptionTokens(tokens) + .build(); + consumer.accept(output); + } } diff --git a/tests/systemserviceapitests/src/com/android/ondevicepersonalization/systemserviceapitests/OdpSystemServiceApiTest.java b/tests/systemserviceapitests/src/com/android/ondevicepersonalization/systemserviceapitests/OdpSystemServiceApiTest.java index 0d209342..57f1cf02 100644 --- a/tests/systemserviceapitests/src/com/android/ondevicepersonalization/systemserviceapitests/OdpSystemServiceApiTest.java +++ b/tests/systemserviceapitests/src/com/android/ondevicepersonalization/systemserviceapitests/OdpSystemServiceApiTest.java @@ -39,8 +39,10 @@ import java.util.concurrent.CountDownLatch; @RunWith(JUnit4.class) public class OdpSystemServiceApiTest { private final Context mContext = ApplicationProvider.getApplicationContext(); - boolean mOnResultCalled = false; - CountDownLatch mLatch = new CountDownLatch(1); + boolean mOnRequestCalled = false; + boolean mSetPersonalizationStatusCalled = false; + boolean mReadPersonalizationStatusCalled = false; + CountDownLatch mLatch = new CountDownLatch(3); @Test public void testInvokeSystemServerServiceSucceedsOnU() throws Exception { @@ -58,12 +60,44 @@ public class OdpSystemServiceApiTest { new Bundle(), new IOnDevicePersonalizationSystemServiceCallback.Stub() { @Override public void onResult(Bundle result) { - mOnResultCalled = true; + mOnRequestCalled = true; + mLatch.countDown(); + } + @Override + public void onError(int errorCode) { + mOnRequestCalled = true; + mLatch.countDown(); + } + }); + + //TODO(b/302991761): delete the file in system server. + service.setPersonalizationStatus(false, + new IOnDevicePersonalizationSystemServiceCallback.Stub() { + @Override public void onResult(Bundle result) { + mSetPersonalizationStatusCalled = true; + mLatch.countDown(); + } + @Override public void onError(int errorCode) { + mSetPersonalizationStatusCalled = true; + mLatch.countDown(); + } + }); + + service.readPersonalizationStatus( + new IOnDevicePersonalizationSystemServiceCallback.Stub() { + @Override public void onResult(Bundle result) { + mReadPersonalizationStatusCalled = true; + mLatch.countDown(); + } + @Override public void onError(int errorCode) { + mReadPersonalizationStatusCalled = true; mLatch.countDown(); } }); mLatch.await(); - assertTrue(mOnResultCalled); + assertTrue(mOnRequestCalled); + assertTrue(mSetPersonalizationStatusCalled); + assertTrue(mReadPersonalizationStatusCalled); } @Test diff --git a/tests/systemserviceimpltests/Android.bp b/tests/systemserviceimpltests/Android.bp index da45992b..723e49cb 100644 --- a/tests/systemserviceimpltests/Android.bp +++ b/tests/systemserviceimpltests/Android.bp @@ -31,6 +31,7 @@ android_test { "androidx.test.ext.junit", "androidx.test.ext.truth", "androidx.test.rules", + "modules-utils-build", "service-ondevicepersonalization.impl", ], sdk_version: "module_current", diff --git a/tests/systemserviceimpltests/src/com/android/server/ondevicepersonalization/BooleanFileDataStoreTest.java b/tests/systemserviceimpltests/src/com/android/server/ondevicepersonalization/BooleanFileDataStoreTest.java new file mode 100644 index 00000000..87394edc --- /dev/null +++ b/tests/systemserviceimpltests/src/com/android/server/ondevicepersonalization/BooleanFileDataStoreTest.java @@ -0,0 +1,139 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.server.ondevicepersonalization; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import android.content.Context; + +import androidx.test.core.app.ApplicationProvider; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.util.Set; + +public class BooleanFileDataStoreTest { + private static final Context APPLICATION_CONTEXT = ApplicationProvider.getApplicationContext(); + private static final String FILENAME = "BooleanFileDatastoreTest"; + private static final String TEST_KEY = "key"; + private static final int TEST_KEY_COUNT = 10; + + private BooleanFileDataStore mDataStore; + + @Before + public void setup() throws IOException { + mDataStore = new BooleanFileDataStore( + APPLICATION_CONTEXT.getFilesDir().getAbsolutePath(), FILENAME); + mDataStore.initialize(); + } + + @Test + public void testInitializeEmptyBooleanFileDatastore() { + assertTrue(mDataStore.keySet().isEmpty()); + } + + @Test + public void testNullOrEmptyKeyFails() { + assertThrows( + NullPointerException.class, + () -> { + mDataStore.put(null, true); + }); + + assertThrows( + IllegalArgumentException.class, + () -> { + mDataStore.put("", true); + }); + assertThrows( + NullPointerException.class, + () -> { + mDataStore.get(null); + }); + + assertThrows( + IllegalArgumentException.class, + () -> { + mDataStore.get(""); + }); + } + + @Test + public void testPutGetUpdate() throws IOException { + // Empty + assertNull(mDataStore.get(TEST_KEY)); + + // Put + mDataStore.put(TEST_KEY, false); + + // Get + Boolean readValue = mDataStore.get(TEST_KEY); + assertEquals(false, readValue); + + // Update + mDataStore.put(TEST_KEY, true); + readValue = mDataStore.get(TEST_KEY); + assertEquals(true, readValue); + + // Test overwrite + Set<String> keys = mDataStore.keySet(); + assertEquals(keys.size(), 1); + assertTrue(keys.contains(TEST_KEY)); + } + + @Test + public void testClearAll() throws IOException { + for (int i = 0; i < TEST_KEY_COUNT; ++i) { + mDataStore.put(TEST_KEY + i, true); + } + assertEquals(TEST_KEY_COUNT, mDataStore.keySet().size()); + mDataStore.clear(); + mDataStore.initialize(); + assertTrue(mDataStore.keySet().isEmpty()); + } + + @Test + public void testReinitializeFromDisk() throws IOException { + for (int i = 0; i < TEST_KEY_COUNT; ++i) { + mDataStore.put(TEST_KEY + i, true); + } + assertEquals(TEST_KEY_COUNT, mDataStore.keySet().size()); + + // Mock memory crash + mDataStore.clearLocalMapForTesting(); + assertTrue(mDataStore.keySet().isEmpty()); + + // Re-initialize from the file and still be able to recover + mDataStore.initialize(); + assertEquals(TEST_KEY_COUNT, mDataStore.keySet().size()); + for (int i = 0; i < TEST_KEY_COUNT; ++i) { + Boolean readValue = mDataStore.get(TEST_KEY + i); + assertEquals(true, readValue); + } + } + + @After + public void tearDown() { + mDataStore.tearDownForTesting(); + } +} diff --git a/tests/systemserviceimpltests/src/com/android/server/ondevicepersonalization/OdpSystemServiceImplTest.java b/tests/systemserviceimpltests/src/com/android/server/ondevicepersonalization/OdpSystemServiceImplTest.java index 2ca59223..12baa289 100644 --- a/tests/systemserviceimpltests/src/com/android/server/ondevicepersonalization/OdpSystemServiceImplTest.java +++ b/tests/systemserviceimpltests/src/com/android/server/ondevicepersonalization/OdpSystemServiceImplTest.java @@ -16,7 +16,12 @@ package com.android.server.ondevicepersonalization; -import static org.junit.Assert.assertNotEquals; +import static com.android.server.ondevicepersonalization.OnDevicePersonalizationSystemService.KEY_NOT_FOUND_ERROR; +import static com.android.server.ondevicepersonalization.OnDevicePersonalizationSystemService.PERSONALIZATION_STATUS_KEY; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import android.content.Context; @@ -26,6 +31,10 @@ import android.os.Bundle; import androidx.test.core.app.ApplicationProvider; +import com.android.modules.utils.build.SdkLevel; + +import org.junit.After; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -35,25 +44,99 @@ import java.util.concurrent.CountDownLatch; @RunWith(JUnit4.class) public class OdpSystemServiceImplTest { private final Context mContext = ApplicationProvider.getApplicationContext(); - boolean mOnResultCalled = false; - CountDownLatch mLatch = new CountDownLatch(1); + private static final String TEST_CONFIG_FILE_IDENTIFIER = "TEST_CONFIG"; + private static final String BAD_TEST_KEY = "non-exist-key"; + private final BooleanFileDataStore mTestDataStore = + new BooleanFileDataStore(mContext.getFilesDir().getAbsolutePath(), + TEST_CONFIG_FILE_IDENTIFIER); + private boolean mOnResultCalled; + private boolean mOnErrorCalled; + private Bundle mResult; + private int mErrorCode; + private CountDownLatch mLatch; + private OnDevicePersonalizationSystemService mService; + private IOnDevicePersonalizationSystemService mBinder; + private IOnDevicePersonalizationSystemServiceCallback mCallback; + + @Before + public void setUp() throws Exception { + mService = new OnDevicePersonalizationSystemService(mContext, mTestDataStore); + mBinder = IOnDevicePersonalizationSystemService.Stub.asInterface(mService); + mOnResultCalled = false; + mOnErrorCalled = false; + mResult = null; + mErrorCode = 0; + mLatch = new CountDownLatch(1); + mCallback = new IOnDevicePersonalizationSystemServiceCallback.Stub() { + @Override + public void onResult(Bundle bundle) { + mOnResultCalled = true; + mResult = bundle; + mLatch.countDown(); + } + + @Override + public void onError(int errorCode) { + mOnErrorCalled = true; + mErrorCode = errorCode; + mLatch.countDown(); + } + }; + assertNotNull(mBinder); + assertNotNull(mCallback); + } + + @Test + public void testSystemServerServiceOnRequest() throws Exception { + if (!SdkLevel.isAtLeastU()) { + return; + } + mBinder.onRequest(new Bundle(), mCallback); + mLatch.await(); + assertTrue(mOnResultCalled); + assertNull(mResult); + } @Test - public void testSystemServerService() throws Exception { - OnDevicePersonalizationSystemService serviceImpl = - new OnDevicePersonalizationSystemService(mContext); - IOnDevicePersonalizationSystemService service = - IOnDevicePersonalizationSystemService.Stub.asInterface(serviceImpl); - assertNotEquals(null, service); - service.onRequest( - new Bundle(), - new IOnDevicePersonalizationSystemServiceCallback.Stub() { - @Override public void onResult(Bundle result) { - mOnResultCalled = true; - mLatch.countDown(); - } - }); + public void testSystemServerServiceSetPersonalizationStatus() throws Exception { + if (!SdkLevel.isAtLeastU()) { + return; + } + mBinder.setPersonalizationStatus(true, mCallback); mLatch.await(); assertTrue(mOnResultCalled); + assertNotNull(mResult); + boolean inputBool = mResult.getBoolean(PERSONALIZATION_STATUS_KEY); + assertTrue(inputBool); + } + + @Test + public void testSystemServerServiceReadPersonalizationStatusSuccess() throws Exception { + if (!SdkLevel.isAtLeastU()) { + return; + } + mTestDataStore.put(PERSONALIZATION_STATUS_KEY, true); + mBinder.readPersonalizationStatus(mCallback); + assertTrue(mOnResultCalled); + assertNotNull(mResult); + boolean inputBool = mResult.getBoolean(PERSONALIZATION_STATUS_KEY); + assertTrue(inputBool); + } + + @Test + public void testSystemServerServiceReadPersonalizationStatusNotFound() throws Exception { + if (!SdkLevel.isAtLeastU()) { + return; + } + mTestDataStore.put(BAD_TEST_KEY, true); + mBinder.readPersonalizationStatus(mCallback); + assertTrue(mOnErrorCalled); + assertNull(mResult); + assertEquals(mErrorCode, KEY_NOT_FOUND_ERROR); + } + + @After + public void cleanUp() { + mTestDataStore.tearDownForTesting(); } } |