diff options
10 files changed, 176 insertions, 265 deletions
diff --git a/federatedcompute/src/com/android/federatedcompute/services/encryption/Encrypter.java b/federatedcompute/src/com/android/federatedcompute/services/encryption/Encrypter.java deleted file mode 100644 index 7398e1f5..00000000 --- a/federatedcompute/src/com/android/federatedcompute/services/encryption/Encrypter.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (C) 2023 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.android.federatedcompute.services.encryption; - -/** Interface for crypto libraries to encrypt data */ -public interface Encrypter { - - /** - * encrypt {@code plainText} to cipher text {@code byte[]}. - * - * @param publicKey the public key used for encryption - * @param plainText the plain text string to encrypt - * @param associatedData additional data used for encryption - * @return the encrypted ciphertext - */ - byte[] encrypt(byte[] publicKey, byte[] plainText, byte[] associatedData); - -} diff --git a/federatedcompute/src/com/android/federatedcompute/services/encryption/HpkeJniEncrypter.java b/federatedcompute/src/com/android/federatedcompute/services/encryption/HpkeJniEncrypter.java deleted file mode 100644 index dd6040c1..00000000 --- a/federatedcompute/src/com/android/federatedcompute/services/encryption/HpkeJniEncrypter.java +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (C) 2023 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.android.federatedcompute.services.encryption; - -import com.android.federatedcompute.services.encryption.jni.HpkeJni; - - -/** - * The implementation of HPKE (Hybrid Public Key Encryption) using BoringSSL JNI. - */ -public class HpkeJniEncrypter implements Encrypter { - - @Override - public byte[] encrypt(byte[] publicKey, byte[] plainText, byte[] associatedData) { - return HpkeJni.encrypt(publicKey, plainText, associatedData); - } -} diff --git a/federatedcompute/src/com/android/federatedcompute/services/http/HttpClientUtil.java b/federatedcompute/src/com/android/federatedcompute/services/http/HttpClientUtil.java index 749fbdc7..61d719e9 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/http/HttpClientUtil.java +++ b/federatedcompute/src/com/android/federatedcompute/services/http/HttpClientUtil.java @@ -21,8 +21,6 @@ import com.android.federatedcompute.internal.util.LogUtil; import com.google.common.collect.ImmutableSet; import com.google.protobuf.ByteString; -import org.json.JSONObject; - import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -51,15 +49,6 @@ public final class HttpClientUtil { POST, PUT, } - public static final class FederatedComputePayloadDataContract { - public static final String KEY_ID = "keyId"; - - public static final String ENCRYPTED_PAYLOAD = "encryptedPayload"; - - public static final String ASSOCIATED_DATA_KEY = "associatedData"; - - public static final byte[] ASSOCIATED_DATA = new JSONObject().toString().getBytes(); - } /** Compresses the input data using Gzip. */ public static byte[] compressWithGzip(byte[] uncompressedData) { diff --git a/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java b/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java index 208f2be9..00fdf1fb 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java +++ b/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java @@ -31,12 +31,8 @@ import static com.android.federatedcompute.services.http.HttpClientUtil.compress import static com.android.federatedcompute.services.http.HttpClientUtil.uncompressWithGzip; import android.os.Trace; -import android.util.Base64; import com.android.federatedcompute.internal.util.LogUtil; -import com.android.federatedcompute.services.data.FederatedComputeEncryptionKey; -import com.android.federatedcompute.services.encryption.Encrypter; -import com.android.federatedcompute.services.http.HttpClientUtil.FederatedComputePayloadDataContract; import com.android.federatedcompute.services.http.HttpClientUtil.HttpMethod; import com.android.federatedcompute.services.training.util.ComputationResult; @@ -60,44 +56,34 @@ import com.google.ondevicepersonalization.federatedcompute.proto.TaskAssignment; import com.google.ondevicepersonalization.federatedcompute.proto.UploadInstruction; import com.google.protobuf.InvalidProtocolBufferException; -import org.json.JSONObject; - import java.util.HashMap; import java.util.UUID; import java.util.concurrent.Callable; - /** Implements a single session of HTTP-based federated compute protocol. */ public final class HttpFederatedProtocol { public static final String TAG = HttpFederatedProtocol.class.getSimpleName(); private final String mClientVersion; private final String mPopulationName; private final HttpClient mHttpClient; - private final ProtocolRequestCreator mTaskAssignmentRequestCreator; - private final Encrypter mEncrypter; private String mTaskId; private String mAggregationId; private String mAssignmentId; + private final ProtocolRequestCreator mTaskAssignmentRequestCreator; @VisibleForTesting HttpFederatedProtocol( - String entryUri, - String clientVersion, - String populationName, - HttpClient httpClient, - Encrypter encrypter) { + String entryUri, String clientVersion, String populationName, HttpClient httpClient) { this.mClientVersion = clientVersion; this.mPopulationName = populationName; this.mHttpClient = httpClient; this.mTaskAssignmentRequestCreator = new ProtocolRequestCreator(entryUri, new HashMap<>()); - this.mEncrypter = encrypter; } /** Creates a HttpFederatedProtocol object. */ public static HttpFederatedProtocol create( - String entryUri, String clientVersion, String populationName, Encrypter encrypter) { - return new HttpFederatedProtocol( - entryUri, clientVersion, populationName, new HttpClient(), encrypter); + String entryUri, String clientVersion, String populationName) { + return new HttpFederatedProtocol(entryUri, clientVersion, populationName, new HttpClient()); } /** Helper function to perform check in and download federated task from remote servers. */ @@ -147,12 +133,9 @@ public final class HttpFederatedProtocol { } /** Helper functions to reporting result and upload result. */ - public FluentFuture<RejectionInfo> reportResult( - ComputationResult computationResult, FederatedComputeEncryptionKey encryptionKey) { + public FluentFuture<RejectionInfo> reportResult(ComputationResult computationResult) { Trace.beginAsyncSection(TRACE_HTTP_REPORT_RESULT, 0); - if (computationResult != null - && computationResult.isResultSuccess() - && encryptionKey != null) { + if (computationResult != null && computationResult.isResultSuccess()) { return FluentFuture.from(performReportResult(computationResult)) .transformAsync( reportResp -> { @@ -164,9 +147,7 @@ public final class HttpFederatedProtocol { } return FluentFuture.from( processReportResultResponseAndUploadResult( - reportResultResponse, - computationResult, - encryptionKey)) + reportResultResponse, computationResult)) .transform( resp -> { validateHttpResponseStatus( @@ -323,9 +304,8 @@ public final class HttpFederatedProtocol { private ListenableFuture<FederatedComputeHttpResponse> processReportResultResponseAndUploadResult( - ReportResultResponse reportResultResponse, - ComputationResult computationResult, - FederatedComputeEncryptionKey encryptionKey) { + ReportResultResponse reportResultResponse, + ComputationResult computationResult) { try { Preconditions.checkArgument( !computationResult.getOutputCheckpointFile().isEmpty(), @@ -334,11 +314,7 @@ public final class HttpFederatedProtocol { Preconditions.checkArgument( !uploadInstruction.getUploadLocation().isEmpty(), "UploadInstruction.upload_location must not be empty"); - byte[] outputBytes = - createEncryptedRequestBody( - computationResult.getOutputCheckpointFile(), - encryptionKey); - // Apply a top-level compression to the payload. + byte[] outputBytes = readFileAsByteArray(computationResult.getOutputCheckpointFile()); if (uploadInstruction.getCompressionFormat() == ResourceCompressionFormat.RESOURCE_COMPRESSION_FORMAT_GZIP) { outputBytes = compressWithGzip(outputBytes); @@ -369,31 +345,6 @@ public final class HttpFederatedProtocol { } } - private byte[] createEncryptedRequestBody( - String filePath, - FederatedComputeEncryptionKey encryptionKey) - throws Exception { - byte[] fileOutputBytes = readFileAsByteArray(filePath); - fileOutputBytes = compressWithGzip(fileOutputBytes); - // encryption - byte[] publicKey = Base64.decode(encryptionKey.getPublicKey(), Base64.NO_WRAP); - - byte[] encryptedOutput = - mEncrypter.encrypt( - publicKey, fileOutputBytes, - FederatedComputePayloadDataContract.ASSOCIATED_DATA); - // create payload - final JSONObject body = new JSONObject(); - body.put(FederatedComputePayloadDataContract.KEY_ID, - encryptionKey.getKeyIdentifier()); - body.put(FederatedComputePayloadDataContract.ENCRYPTED_PAYLOAD, - Base64.encodeToString(encryptedOutput, Base64.NO_WRAP)); - body.put(FederatedComputePayloadDataContract.ASSOCIATED_DATA_KEY, - Base64.encodeToString(FederatedComputePayloadDataContract.ASSOCIATED_DATA, - Base64.NO_WRAP)); - return body.toString().getBytes(); - } - private ReportResultResponse getReportResultResponse(FederatedComputeHttpResponse httpResponse) throws InvalidProtocolBufferException { validateHttpResponseStatus("ReportResult", httpResponse); diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java b/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java index cb13d87b..a6583bab 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java +++ b/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java @@ -43,11 +43,8 @@ import com.android.federatedcompute.internal.util.AbstractServiceBinder; import com.android.federatedcompute.internal.util.LogUtil; import com.android.federatedcompute.services.common.Constants; import com.android.federatedcompute.services.common.FileUtils; -import com.android.federatedcompute.services.data.FederatedComputeEncryptionKey; import com.android.federatedcompute.services.data.FederatedTrainingTask; import com.android.federatedcompute.services.data.fbs.TrainingConstraints; -import com.android.federatedcompute.services.encryption.FederatedComputeEncryptionKeyManager; -import com.android.federatedcompute.services.encryption.HpkeJniEncrypter; import com.android.federatedcompute.services.examplestore.ExampleConsumptionRecorder; import com.android.federatedcompute.services.http.CheckinResult; import com.android.federatedcompute.services.http.HttpFederatedProtocol; @@ -80,9 +77,7 @@ import com.google.protobuf.InvalidProtocolBufferException; import java.io.IOException; import java.util.ArrayList; -import java.util.List; import java.util.Objects; -import java.util.Random; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; @@ -109,10 +104,6 @@ public class FederatedComputeWorker { private AbstractServiceBinder<IExampleStoreService> mExampleStoreServiceBinder; private AbstractServiceBinder<IIsolatedTrainingService> mIsolatedTrainingServiceBinder; - private FederatedComputeEncryptionKeyManager mEncryptionKeyManager; - - private static final int NUM_ACTIVE_KEYS_TO_CHOOSE_FROM = 5; - @VisibleForTesting public FederatedComputeWorker( Context context, @@ -120,15 +111,13 @@ public class FederatedComputeWorker { TrainingConditionsChecker trainingConditionsChecker, ComputationRunner computationRunner, ResultCallbackHelper resultCallbackHelper, - Injector injector, - FederatedComputeEncryptionKeyManager keyManager) { + Injector injector) { this.mContext = context.getApplicationContext(); this.mJobManager = jobManager; this.mTrainingConditionsChecker = trainingConditionsChecker; this.mComputationRunner = computationRunner; this.mInjector = injector; this.mResultCallbackHelper = resultCallbackHelper; - this.mEncryptionKeyManager = keyManager; } /** Gets an instance of {@link FederatedComputeWorker}. */ @@ -144,8 +133,7 @@ public class FederatedComputeWorker { TrainingConditionsChecker.getInstance(context), new ComputationRunner(), new ResultCallbackHelper(context), - new Injector(), - FederatedComputeEncryptionKeyManager.getInstance(context)); + new Injector()); } } } @@ -256,6 +244,8 @@ public class FederatedComputeWorker { TrainingRun run, CheckinResult checkinResult) { // Stop processing if have rejection Info if (checkinResult.getRejectionInfo() != null) { + LogUtil.d(TAG, "job %d was rejected during check in, reason %s", + run.mTask.jobId(), checkinResult.getRejectionInfo().getReason()); mJobManager.onTrainingCompleted( run.mTask.jobId(), run.mTask.populationName(), @@ -264,40 +254,13 @@ public class FederatedComputeWorker { ContributionResult.FAIL); return Futures.immediateFuture(null); } - + // 2. Bind to client app implemented ExampleStoreService based on ExampleSelector. + // Set active run's task name. String taskName = checkinResult.getTaskAssignment().getTaskName(); Preconditions.checkArgument(!taskName.isEmpty(), "Task name should not be empty"); synchronized (mLock) { mActiveRun.mTaskName = taskName; } - // 2. Fetch Active keys to encrypt the computation result. - List<FederatedComputeEncryptionKey> activeKeys = mEncryptionKeyManager - .getOrFetchActiveKeys(NUM_ACTIVE_KEYS_TO_CHOOSE_FROM, FederatedComputeEncryptionKey - .KEY_TYPE_ENCRYPTION); - // select a random key - FederatedComputeEncryptionKey encryptionKey = activeKeys.isEmpty() ? null : - activeKeys.get(new Random().nextInt(activeKeys.size())); - if (encryptionKey == null) { - // no active keys to encrypt the FL/FA computation results, stop the computation run. - ComputationResult failedComputationResult = new ComputationResult( - null, - FLRunnerResult.newBuilder() - .setContributionResult(ContributionResult.FAIL) - .setErrorMessage("No active key available on device.") - .build(), null); - mJobManager.onTrainingCompleted( - run.mJobId, - run.mTask.populationName(), - run.mTask.getTrainingIntervalOptions(), - buildTaskRetry(RejectionInfo.newBuilder().build()), - ContributionResult.FAIL); - var unused = mHttpFederatedProtocol.reportResult(failedComputationResult, null); - return Futures.immediateFailedFuture( - new IllegalStateException("No active key available on device.")); - } - - // 3. Bind to client app implemented ExampleStoreService based on ExampleSelector. - // Set active run's task name. ListenableFuture<IExampleStoreIterator> iteratorFuture = getExampleStoreIterator( run, @@ -309,7 +272,7 @@ public class FederatedComputeWorker { Futures.addCallback( iteratorFuture, serverFailureReportCallback, getLightweightExecutor()); - // 4. Run federated learning or federated analytic depends on task type. Federated + // 3. Run federated learning or federated analytic depends on task type. Federated // learning job will start a new isolated process to run TFLite training. FluentFuture<ComputationResult> computationResultFuture = FluentFuture.from(iteratorFuture) @@ -320,11 +283,10 @@ public class FederatedComputeWorker { computationResultFuture.addCallback( serverFailureReportCallback, getLightweightExecutor()); - - // 5. Report computation result to federated compute server. + // 4. Report computation result to federated compute server. ListenableFuture<RejectionInfo> reportToServerFuture = computationResultFuture.transformAsync( - result -> mHttpFederatedProtocol.reportResult(result, encryptionKey), + result -> mHttpFederatedProtocol.reportResult(result), getLightweightExecutor()); return Futures.whenAllSucceed(reportToServerFuture, computationResultFuture) .call( @@ -359,7 +321,7 @@ public class FederatedComputeWorker { ContributionResult.FAIL); return null; } - // 6. Publish computation result and consumed + // 5. Publish computation result and consumed // examples to client implemented // ResultHandlingService. var unused = @@ -398,8 +360,7 @@ public class FederatedComputeWorker { .setErrorMessage(throwable.getMessage()) .build(), null); - var unused = mHttpFederatedProtocol.reportResult(failedReportComputationResult, - null); + var unused = mHttpFederatedProtocol.reportResult(failedReportComputationResult); } mNumberOfInvocations++; } @@ -492,8 +453,7 @@ public class FederatedComputeWorker { @VisibleForTesting HttpFederatedProtocol getHttpFederatedProtocol(String serverAddress, String populationName) { - return HttpFederatedProtocol.create(serverAddress, "1.0.0.1", populationName, - new HpkeJniEncrypter()); + return HttpFederatedProtocol.create(serverAddress, "1.0.0.1", populationName); } private ExampleSelector getExampleSelector(CheckinResult checkinResult) { diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpFederatedProtocolTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpFederatedProtocolTest.java index 48c35f9f..e483cf22 100644 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpFederatedProtocolTest.java +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpFederatedProtocolTest.java @@ -43,16 +43,12 @@ import android.net.Uri; import androidx.test.core.app.ApplicationProvider; -import com.android.federatedcompute.services.data.FederatedComputeEncryptionKey; -import com.android.federatedcompute.services.encryption.HpkeJniEncrypter; import com.android.federatedcompute.services.http.HttpClientUtil.HttpMethod; import com.android.federatedcompute.services.testutils.TrainingTestUtil; import com.android.federatedcompute.services.training.util.ComputationResult; -import com.google.common.collect.BoundType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Range; import com.google.intelligence.fcp.client.FLRunnerResult; import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult; import com.google.internal.federatedcompute.v1.ClientVersion; @@ -105,14 +101,6 @@ public final class HttpFederatedProtocolTest { private static final String ASSIGNMENT_ID = "assignment-id"; private static final String AGGREGATION_ID = "aggregation-id"; private static final String OCTET_STREAM = "application/octet-stream"; - - private static final FederatedComputeEncryptionKey ENCRYPTION_KEY = - new FederatedComputeEncryptionKey.Builder() - .setPublicKey("rSJBSUYG0ebvfW1AXCWO0CMGMJhDzpfQm3eLyw1uxX8=") - .setKeyIdentifier("0962201a-5abd-4e25-a486-2c7bd1ee1887") - .setKeyType(FederatedComputeEncryptionKey.KEY_TYPE_ENCRYPTION) - .setCreationTime(1L) - .setExpiryTime(1L).build(); private static final FLRunnerResult FL_RUNNER_SUCCESS_RESULT = FLRunnerResult.newBuilder().setContributionResult(ContributionResult.SUCCESS).build(); @@ -156,8 +144,7 @@ public final class HttpFederatedProtocolTest { TASK_ASSIGNMENT_TARGET_URI, CLIENT_VERSION, POPULATION_NAME, - mMockHttpClient, - new HpkeJniEncrypter()); + mMockHttpClient); } @Test @@ -293,7 +280,7 @@ public final class HttpFederatedProtocolTest { // Setup task id, aggregation id for report result. mHttpFederatedProtocol.issueCheckin().get(); - mHttpFederatedProtocol.reportResult(computationResult, ENCRYPTION_KEY).get(); + mHttpFederatedProtocol.reportResult(computationResult).get(); // Verify ReportResult request. List<FederatedComputeHttpRequest> actualHttpRequests = mHttpRequestCaptor.getAllValues(); @@ -314,7 +301,7 @@ public final class HttpFederatedProtocolTest { // Setup task id, aggregation id for report result. mHttpFederatedProtocol.issueCheckin().get(); - mHttpFederatedProtocol.reportResult(computationResult, ENCRYPTION_KEY).get(); + mHttpFederatedProtocol.reportResult(computationResult).get(); // Verify ReportResult request. List<FederatedComputeHttpRequest> actualHttpRequests = mHttpRequestCaptor.getAllValues(); @@ -335,23 +322,14 @@ public final class HttpFederatedProtocolTest { assertThat(actualDataUploadRequest.getUri()).isEqualTo(UPLOAD_LOCATION_URI); assertThat(acutalReportResultRequest.getHttpMethod()).isEqualTo(HttpMethod.PUT); expectedHeaders = new HashMap<>(); - expectedHeaders.put(CONTENT_TYPE_HDR, OCTET_STREAM); if (mSupportCompression) { + expectedHeaders.put(CONTENT_LENGTH_HDR, String.valueOf(339)); expectedHeaders.put(CONTENT_ENCODING_HDR, GZIP_ENCODING_HDR); - } - - int actualContentLength = Integer - .parseInt(actualDataUploadRequest.getExtraHeaders().remove(CONTENT_LENGTH_HDR)); - assertThat(actualDataUploadRequest.getExtraHeaders()).isEqualTo(expectedHeaders); - // The encryption is non-deterministic with BoringSSL JNI. - // Only check the range of the content. - if (mSupportCompression) { - assertThat(actualContentLength) - .isIn(Range.range(500, BoundType.CLOSED, 550, BoundType.CLOSED)); } else { - assertThat(actualContentLength) - .isIn(Range.range(600, BoundType.CLOSED, 650, BoundType.CLOSED)); + expectedHeaders.put(CONTENT_LENGTH_HDR, String.valueOf(31846)); } + expectedHeaders.put(CONTENT_TYPE_HDR, OCTET_STREAM); + assertThat(actualDataUploadRequest.getExtraHeaders()).isEqualTo(expectedHeaders); } @Test @@ -372,8 +350,7 @@ public final class HttpFederatedProtocolTest { ExecutionException exception = assertThrows( ExecutionException.class, - () -> mHttpFederatedProtocol.reportResult(computationResult, - ENCRYPTION_KEY).get()); + () -> mHttpFederatedProtocol.reportResult(computationResult).get()); assertThat(exception.getCause()).isInstanceOf(IllegalStateException.class); assertThat(exception.getCause()).hasMessageThat().isEqualTo("ReportResult failed: 503"); @@ -397,8 +374,7 @@ public final class HttpFederatedProtocolTest { ExecutionException exception = assertThrows( ExecutionException.class, - () -> mHttpFederatedProtocol.reportResult(computationResult, - ENCRYPTION_KEY).get()); + () -> mHttpFederatedProtocol.reportResult(computationResult).get()); assertThat(exception).hasCauseThat().isInstanceOf(IllegalStateException.class); assertThat(exception.getCause()).hasMessageThat().isEqualTo("Upload result failed: 503"); diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java index 8f98c6a0..71fe9998 100644 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java @@ -48,14 +48,11 @@ import android.os.RemoteException; import androidx.test.core.app.ApplicationProvider; import com.android.federatedcompute.services.common.Constants; -import com.android.federatedcompute.services.data.FederatedComputeEncryptionKey; import com.android.federatedcompute.services.data.FederatedTrainingTask; import com.android.federatedcompute.services.data.fbs.SchedulingMode; import com.android.federatedcompute.services.data.fbs.SchedulingReason; import com.android.federatedcompute.services.data.fbs.TrainingConstraints; import com.android.federatedcompute.services.data.fbs.TrainingIntervalOptions; -import com.android.federatedcompute.services.encryption.FederatedComputeEncryptionKeyManager; -import com.android.federatedcompute.services.encryption.HpkeJniEncrypter; import com.android.federatedcompute.services.examplestore.ExampleConsumptionRecorder; import com.android.federatedcompute.services.http.CheckinResult; import com.android.federatedcompute.services.http.HttpFederatedProtocol; @@ -107,7 +104,6 @@ import org.tensorflow.example.Features; import java.util.ArrayList; import java.util.EnumSet; -import java.util.List; import java.util.concurrent.ExecutionException; @RunWith(JUnit4.class) @@ -197,16 +193,6 @@ public final class FederatedComputeWorkerTest { @Mock private ComputationRunner mMockComputationRunner; private ResultCallbackHelper mSpyResultCallbackHelper; - @Mock private FederatedComputeEncryptionKeyManager mMockKeyManager; - - private static final FederatedComputeEncryptionKey ENCRYPTION_KEY = - new FederatedComputeEncryptionKey.Builder() - .setPublicKey("rSJBSUYG0ebvfW1AXCWO0CMGMJhDzpfQm3eLyw1uxX8=") - .setKeyIdentifier("0962201a-5abd-4e25-a486-2c7bd1ee1887") - .setKeyType(FederatedComputeEncryptionKey.KEY_TYPE_ENCRYPTION) - .setCreationTime(1L) - .setExpiryTime(1L).build(); - private static byte[] createTrainingConstraints( boolean requiresSchedulerIdle, boolean requiresSchedulerBatteryNotLow, @@ -234,11 +220,7 @@ public final class FederatedComputeWorkerTest { mContext = ApplicationProvider.getApplicationContext(); mSpyHttpFederatedProtocol = Mockito.spy( - HttpFederatedProtocol.create( - SERVER_ADDRESS, - "1.0.0.1", - POPULATION_NAME, - new HpkeJniEncrypter())); + HttpFederatedProtocol.create(SERVER_ADDRESS, "1.0.0.1", POPULATION_NAME)); mSpyResultCallbackHelper = Mockito.spy(new ResultCallbackHelper(mContext)); mSpyWorker = Mockito.spy( @@ -248,8 +230,7 @@ public final class FederatedComputeWorkerTest { mTrainingConditionsChecker, mMockComputationRunner, mSpyResultCallbackHelper, - new TestInjector(), - mMockKeyManager)); + new TestInjector())); when(mTrainingConditionsChecker.checkAllConditionsForFlTraining(any())) .thenReturn(EnumSet.noneOf(Condition.class)); doReturn(Futures.immediateFuture(CallbackResult.SUCCESS)) @@ -270,8 +251,6 @@ public final class FederatedComputeWorkerTest { any(), any())) .thenReturn(FL_RUNNER_SUCCESS_RESULT); - doReturn(List.of(ENCRYPTION_KEY)).when(mMockKeyManager).getOrFetchActiveKeys(anyInt(), - anyInt()); } @Test @@ -310,7 +289,7 @@ public final class FederatedComputeWorkerTest { .issueCheckin(); doReturn(FluentFuture.from(immediateFuture(null))) .when(mSpyHttpFederatedProtocol) - .reportResult(any(), any()); + .reportResult(any()); assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get()); @@ -345,7 +324,7 @@ public final class FederatedComputeWorkerTest { .issueCheckin(); doReturn(FluentFuture.from(immediateFuture(REJECTION_INFO))) .when(mSpyHttpFederatedProtocol) - .reportResult(any(), any()); + .reportResult(any()); doCallRealMethod().when(mSpyResultCallbackHelper).callHandleResult(any(), any(), any()); ArgumentCaptor<ComputationResult> computationResultCaptor = ArgumentCaptor.forClass(ComputationResult.class); @@ -381,7 +360,7 @@ public final class FederatedComputeWorkerTest { "report result failed", new IllegalStateException("http 404"))))) .when(mSpyHttpFederatedProtocol) - .reportResult(any(), any()); + .reportResult(any()); assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get()); @@ -408,7 +387,7 @@ public final class FederatedComputeWorkerTest { anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL)); verify(mSpyWorker, times(0)).unbindFromExampleStoreService(); verify(mSpyHttpFederatedProtocol, times(1)) - .reportResult(computationResultCaptor.capture(), any()); + .reportResult(computationResultCaptor.capture()); ComputationResult computationResult = computationResultCaptor.getValue(); assertNotNull(computationResult.getFlRunnerResult()); assertEquals( @@ -450,7 +429,7 @@ public final class FederatedComputeWorkerTest { setUpHttpFederatedProtocol(FA_CHECKIN_RESULT); doReturn(FluentFuture.from(immediateFuture(null))) .when(mSpyHttpFederatedProtocol) - .reportResult(any(), any()); + .reportResult(any()); when(mMockComputationRunner.runTaskWithNativeRunner( anyString(), anyString(), @@ -473,7 +452,7 @@ public final class FederatedComputeWorkerTest { anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL)); verify(mSpyWorker).unbindFromExampleStoreService(); verify(mSpyHttpFederatedProtocol, times(1)) - .reportResult(computationResultCaptor.capture(), any()); + .reportResult(computationResultCaptor.capture()); ComputationResult computationResult = computationResultCaptor.getValue(); assertNotNull(computationResult.getFlRunnerResult()); assertEquals( @@ -617,22 +596,6 @@ public final class FederatedComputeWorkerTest { verify(mSpyWorker).unbindFromExampleStoreService(); } - @Test - public void testRunFLComputation_noKey_throws() throws Exception { - setUpHttpFederatedProtocol(FL_CHECKIN_RESULT); - doReturn(new ArrayList<FederatedComputeEncryptionKey>() {}).when(mMockKeyManager) - .getOrFetchActiveKeys(anyInt(), anyInt()); - setUpHttpFederatedProtocol(FA_CHECKIN_RESULT); - - assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get()); - - verify(mMockJobManager) - .onTrainingCompleted( - anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL)); - verify(mSpyHttpFederatedProtocol).reportResult( - any(), eq(null)); - } - private void setUpExampleStoreService() { TestExampleStoreService testExampleStoreService = new TestExampleStoreService(); doReturn(testExampleStoreService).when(mSpyWorker).getExampleStoreService(anyString()); @@ -643,7 +606,7 @@ public final class FederatedComputeWorkerTest { doReturn(immediateFuture(checkinResult)).when(mSpyHttpFederatedProtocol).issueCheckin(); doReturn(FluentFuture.from(immediateFuture(null))) .when(mSpyHttpFederatedProtocol) - .reportResult(any(), any()); + .reportResult(any()); } private static class TestExampleStoreService extends IExampleStoreService.Stub { diff --git a/tests/perftests/scenarios/src/android/federatedcompute/test/scenario/federatedcompute/ScheduleNonExistentPopulationForTraining.java b/tests/perftests/scenarios/src/android/federatedcompute/test/scenario/federatedcompute/ScheduleNonExistentPopulationForTraining.java new file mode 100644 index 00000000..27ee61c7 --- /dev/null +++ b/tests/perftests/scenarios/src/android/federatedcompute/test/scenario/federatedcompute/ScheduleNonExistentPopulationForTraining.java @@ -0,0 +1,64 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package android.federatedcompute.test.scenario.federatedcompute; + +import android.platform.test.scenario.annotation.Scenario; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; + +@Scenario +@RunWith(JUnit4.class) +/** + * Schedule a non-existent population training task from Odp Test app UI + * Force the task execution through ADB commands and verify error handling and exit behavior + */ +public class ScheduleNonExistentPopulationForTraining { + private TestHelper mTestHelper = new TestHelper(); + + /** Prepare the device before entering the test class */ + @BeforeClass + public static void prepareDevice() throws IOException { + TestHelper.initialize(); + TestHelper.killRunningProcess(); + } + + @Before + public void setup() throws IOException { + mTestHelper.pressHome(); + mTestHelper.openTestApp(); + mTestHelper.inputNonExistentPopulationForScheduleTraining(); + } + + @Test + public void testScheduleNonExistentPopulationForTraining() throws IOException { + mTestHelper.clickScheduleTraining(); + mTestHelper.forceExecuteTrainingForNonExistentPopulation(); + } + + /** Return device to original state after test exeuction */ + @AfterClass + public static void tearDown() throws IOException { + TestHelper.pressHome(); + TestHelper.wrapUp(); + } +} diff --git a/tests/perftests/scenarios/src/android/federatedcompute/test/scenario/federatedcompute/TestHelper.java b/tests/perftests/scenarios/src/android/federatedcompute/test/scenario/federatedcompute/TestHelper.java index 07ec1f87..4c135533 100644 --- a/tests/perftests/scenarios/src/android/federatedcompute/test/scenario/federatedcompute/TestHelper.java +++ b/tests/perftests/scenarios/src/android/federatedcompute/test/scenario/federatedcompute/TestHelper.java @@ -34,6 +34,7 @@ public class TestHelper { private static UiDevice sUiDevice; private static final long UI_FIND_RESOURCE_TIMEOUT = 5000; private static final long TRAINING_TASK_COMPLETION_TIMEOUT = 120_000; + private static final long CHECKIN_REJECTION_COMPLETION_TIMEOUT = 20_000; private static final String ODP_CLIENT_TEST_APP_PACKAGE_NAME = "com.example.odpclient"; private static final String SCHEDULE_TRAINING_BUTTON_RESOURCE_ID = "schedule_training_button"; private static final String SCHEDULE_TRAINING_TEXT_BOX_RESOURCE_ID = @@ -42,6 +43,10 @@ public class TestHelper { private static final String ODP_TEST_APP_TRAINING_TASK_JOB_ID = "-630817781"; private static final String FEDERATED_TRAINING_JOB_SUCCESS_LOG = "FederatedJobService - Federated computation job -630817781 is done"; + private static final String NON_EXISTENT_POPULATION_NAME = "test_non_existent_population"; + private static final String NON_EXISTENT_POPULATION_TASK_JOB_ID = "1892833995"; + private static final String NON_EXISTENT_POPULATION_JOB_FAILURE_LOG = + "job 1892833995 was rejected during check in, reason NO_TASK_AVAILABLE"; public static void pressHome() { getUiDevice().pressHome(); @@ -119,6 +124,13 @@ public class TestHelper { scheduleTrainingTextBox.setText(ODP_TEST_APP_POPULATION_NAME); } + /** Put a test non existent population name down for training */ + public void inputNonExistentPopulationForScheduleTraining() { + UiObject2 scheduleTrainingTextBox = getScheduleTrainingTextBox(); + assertNotNull("Schedule Training text box not found", scheduleTrainingTextBox); + scheduleTrainingTextBox.setText(NON_EXISTENT_POPULATION_NAME); + } + /** Click Schedule Training button. */ public void clickScheduleTraining() { UiObject2 scheduleTrainingButton = getScheduleTrainingButton(); @@ -143,11 +155,34 @@ public class TestHelper { if (!foundTrainingJobSuccessLog) { Assert.fail(String.format( - "Failed to find federated training job success log within test window %d ms", + "Failed to find federated training job success log %s within test window %d ms", + FEDERATED_TRAINING_JOB_SUCCESS_LOG, TRAINING_TASK_COMPLETION_TIMEOUT)); } } + /** Force the JobScheduler to execute the training task for non existent population */ + public void forceExecuteTrainingForNonExistentPopulation() throws IOException { + executeShellCommand("logcat -c"); // Cleans the log buffer + executeShellCommand("logcat -G 32M"); // Set log buffer to 32MB + executeShellCommand( + "cmd jobscheduler run -f com.google.android.federatedcompute " + + NON_EXISTENT_POPULATION_TASK_JOB_ID); + SystemClock.sleep(10000); + + boolean foundTrainingFailureLog = findLog( + NON_EXISTENT_POPULATION_JOB_FAILURE_LOG, + CHECKIN_REJECTION_COMPLETION_TIMEOUT, + 5000); + + if (!foundTrainingFailureLog) { + Assert.fail(String.format( + "Failed to find federated training failure log: %s within test window %d ms", + NON_EXISTENT_POPULATION_JOB_FAILURE_LOG, + CHECKIN_REJECTION_COMPLETION_TIMEOUT)); + } + } + /** Attempt to find a specific log entry within the timeout window */ private boolean findLog(final String targetLog, long timeoutMillis, long queryIntervalMillis) throws IOException { diff --git a/tests/perftests/scenarios/tests/src/android/federatedcompute/test/scenario/federatedcompute/ScheduleNonExistentPopulationForTrainingMicrobenchmark.java b/tests/perftests/scenarios/tests/src/android/federatedcompute/test/scenario/federatedcompute/ScheduleNonExistentPopulationForTrainingMicrobenchmark.java new file mode 100644 index 00000000..8150bf3d --- /dev/null +++ b/tests/perftests/scenarios/tests/src/android/federatedcompute/test/scenario/federatedcompute/ScheduleNonExistentPopulationForTrainingMicrobenchmark.java @@ -0,0 +1,36 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.federatedcompute.test.scenario.federatedcompute; + +import android.platform.test.microbenchmark.Microbenchmark; +import android.platform.test.rule.DropCachesRule; +import android.platform.test.rule.KillAppsRule; +import android.platform.test.rule.PressHomeRule; + +import org.junit.Rule; +import org.junit.rules.RuleChain; +import org.junit.runner.RunWith; + +@RunWith(Microbenchmark.class) +public class ScheduleNonExistentPopulationForTrainingMicrobenchmark + extends ScheduleNonExistentPopulationForTraining { + + @Rule + public RuleChain rules = RuleChain.outerRule(new DropCachesRule()) + .around(new KillAppsRule("com.example.odpclient")) + .around(new PressHomeRule()); +} |