diff options
author | Carol Zheng <cazheng@google.com> | 2023-12-13 18:11:17 +0000 |
---|---|---|
committer | Android (Google) Code Review <android-gerrit@google.com> | 2023-12-13 18:11:17 +0000 |
commit | 1c22a514b272bedc5201fc00a580777a6ce67de5 (patch) | |
tree | 8a06d928b9939f62a18e0fe3b402fc04082e985d | |
parent | 0b74c968324906c1dc76c48407ce74b9b3ac2fdd (diff) | |
parent | 05fcc96163fd735cabb556ddd5808fa1fbcc1f51 (diff) | |
download | OnDevicePersonalization-1c22a514b272bedc5201fc00a580777a6ce67de5.tar.gz |
Merge "Encrypt FL/FA payload." into udc-mainline-prod
7 files changed, 264 insertions, 38 deletions
diff --git a/federatedcompute/src/com/android/federatedcompute/services/encryption/Encrypter.java b/federatedcompute/src/com/android/federatedcompute/services/encryption/Encrypter.java new file mode 100644 index 00000000..7398e1f5 --- /dev/null +++ b/federatedcompute/src/com/android/federatedcompute/services/encryption/Encrypter.java @@ -0,0 +1,32 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.services.encryption; + +/** Interface for crypto libraries to encrypt data */ +public interface Encrypter { + + /** + * encrypt {@code plainText} to cipher text {@code byte[]}. + * + * @param publicKey the public key used for encryption + * @param plainText the plain text string to encrypt + * @param associatedData additional data used for encryption + * @return the encrypted ciphertext + */ + byte[] encrypt(byte[] publicKey, byte[] plainText, byte[] associatedData); + +} diff --git a/federatedcompute/src/com/android/federatedcompute/services/encryption/HpkeJniEncrypter.java b/federatedcompute/src/com/android/federatedcompute/services/encryption/HpkeJniEncrypter.java new file mode 100644 index 00000000..dd6040c1 --- /dev/null +++ b/federatedcompute/src/com/android/federatedcompute/services/encryption/HpkeJniEncrypter.java @@ -0,0 +1,31 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.federatedcompute.services.encryption; + +import com.android.federatedcompute.services.encryption.jni.HpkeJni; + + +/** + * The implementation of HPKE (Hybrid Public Key Encryption) using BoringSSL JNI. + */ +public class HpkeJniEncrypter implements Encrypter { + + @Override + public byte[] encrypt(byte[] publicKey, byte[] plainText, byte[] associatedData) { + return HpkeJni.encrypt(publicKey, plainText, associatedData); + } +} diff --git a/federatedcompute/src/com/android/federatedcompute/services/http/HttpClientUtil.java b/federatedcompute/src/com/android/federatedcompute/services/http/HttpClientUtil.java index 61d719e9..749fbdc7 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/http/HttpClientUtil.java +++ b/federatedcompute/src/com/android/federatedcompute/services/http/HttpClientUtil.java @@ -21,6 +21,8 @@ import com.android.federatedcompute.internal.util.LogUtil; import com.google.common.collect.ImmutableSet; import com.google.protobuf.ByteString; +import org.json.JSONObject; + import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -49,6 +51,15 @@ public final class HttpClientUtil { POST, PUT, } + public static final class FederatedComputePayloadDataContract { + public static final String KEY_ID = "keyId"; + + public static final String ENCRYPTED_PAYLOAD = "encryptedPayload"; + + public static final String ASSOCIATED_DATA_KEY = "associatedData"; + + public static final byte[] ASSOCIATED_DATA = new JSONObject().toString().getBytes(); + } /** Compresses the input data using Gzip. */ public static byte[] compressWithGzip(byte[] uncompressedData) { diff --git a/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java b/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java index 00fdf1fb..208f2be9 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java +++ b/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java @@ -31,8 +31,12 @@ import static com.android.federatedcompute.services.http.HttpClientUtil.compress import static com.android.federatedcompute.services.http.HttpClientUtil.uncompressWithGzip; import android.os.Trace; +import android.util.Base64; import com.android.federatedcompute.internal.util.LogUtil; +import com.android.federatedcompute.services.data.FederatedComputeEncryptionKey; +import com.android.federatedcompute.services.encryption.Encrypter; +import com.android.federatedcompute.services.http.HttpClientUtil.FederatedComputePayloadDataContract; import com.android.federatedcompute.services.http.HttpClientUtil.HttpMethod; import com.android.federatedcompute.services.training.util.ComputationResult; @@ -56,34 +60,44 @@ import com.google.ondevicepersonalization.federatedcompute.proto.TaskAssignment; import com.google.ondevicepersonalization.federatedcompute.proto.UploadInstruction; import com.google.protobuf.InvalidProtocolBufferException; +import org.json.JSONObject; + import java.util.HashMap; import java.util.UUID; import java.util.concurrent.Callable; + /** Implements a single session of HTTP-based federated compute protocol. */ public final class HttpFederatedProtocol { public static final String TAG = HttpFederatedProtocol.class.getSimpleName(); private final String mClientVersion; private final String mPopulationName; private final HttpClient mHttpClient; + private final ProtocolRequestCreator mTaskAssignmentRequestCreator; + private final Encrypter mEncrypter; private String mTaskId; private String mAggregationId; private String mAssignmentId; - private final ProtocolRequestCreator mTaskAssignmentRequestCreator; @VisibleForTesting HttpFederatedProtocol( - String entryUri, String clientVersion, String populationName, HttpClient httpClient) { + String entryUri, + String clientVersion, + String populationName, + HttpClient httpClient, + Encrypter encrypter) { this.mClientVersion = clientVersion; this.mPopulationName = populationName; this.mHttpClient = httpClient; this.mTaskAssignmentRequestCreator = new ProtocolRequestCreator(entryUri, new HashMap<>()); + this.mEncrypter = encrypter; } /** Creates a HttpFederatedProtocol object. */ public static HttpFederatedProtocol create( - String entryUri, String clientVersion, String populationName) { - return new HttpFederatedProtocol(entryUri, clientVersion, populationName, new HttpClient()); + String entryUri, String clientVersion, String populationName, Encrypter encrypter) { + return new HttpFederatedProtocol( + entryUri, clientVersion, populationName, new HttpClient(), encrypter); } /** Helper function to perform check in and download federated task from remote servers. */ @@ -133,9 +147,12 @@ public final class HttpFederatedProtocol { } /** Helper functions to reporting result and upload result. */ - public FluentFuture<RejectionInfo> reportResult(ComputationResult computationResult) { + public FluentFuture<RejectionInfo> reportResult( + ComputationResult computationResult, FederatedComputeEncryptionKey encryptionKey) { Trace.beginAsyncSection(TRACE_HTTP_REPORT_RESULT, 0); - if (computationResult != null && computationResult.isResultSuccess()) { + if (computationResult != null + && computationResult.isResultSuccess() + && encryptionKey != null) { return FluentFuture.from(performReportResult(computationResult)) .transformAsync( reportResp -> { @@ -147,7 +164,9 @@ public final class HttpFederatedProtocol { } return FluentFuture.from( processReportResultResponseAndUploadResult( - reportResultResponse, computationResult)) + reportResultResponse, + computationResult, + encryptionKey)) .transform( resp -> { validateHttpResponseStatus( @@ -304,8 +323,9 @@ public final class HttpFederatedProtocol { private ListenableFuture<FederatedComputeHttpResponse> processReportResultResponseAndUploadResult( - ReportResultResponse reportResultResponse, - ComputationResult computationResult) { + ReportResultResponse reportResultResponse, + ComputationResult computationResult, + FederatedComputeEncryptionKey encryptionKey) { try { Preconditions.checkArgument( !computationResult.getOutputCheckpointFile().isEmpty(), @@ -314,7 +334,11 @@ public final class HttpFederatedProtocol { Preconditions.checkArgument( !uploadInstruction.getUploadLocation().isEmpty(), "UploadInstruction.upload_location must not be empty"); - byte[] outputBytes = readFileAsByteArray(computationResult.getOutputCheckpointFile()); + byte[] outputBytes = + createEncryptedRequestBody( + computationResult.getOutputCheckpointFile(), + encryptionKey); + // Apply a top-level compression to the payload. if (uploadInstruction.getCompressionFormat() == ResourceCompressionFormat.RESOURCE_COMPRESSION_FORMAT_GZIP) { outputBytes = compressWithGzip(outputBytes); @@ -345,6 +369,31 @@ public final class HttpFederatedProtocol { } } + private byte[] createEncryptedRequestBody( + String filePath, + FederatedComputeEncryptionKey encryptionKey) + throws Exception { + byte[] fileOutputBytes = readFileAsByteArray(filePath); + fileOutputBytes = compressWithGzip(fileOutputBytes); + // encryption + byte[] publicKey = Base64.decode(encryptionKey.getPublicKey(), Base64.NO_WRAP); + + byte[] encryptedOutput = + mEncrypter.encrypt( + publicKey, fileOutputBytes, + FederatedComputePayloadDataContract.ASSOCIATED_DATA); + // create payload + final JSONObject body = new JSONObject(); + body.put(FederatedComputePayloadDataContract.KEY_ID, + encryptionKey.getKeyIdentifier()); + body.put(FederatedComputePayloadDataContract.ENCRYPTED_PAYLOAD, + Base64.encodeToString(encryptedOutput, Base64.NO_WRAP)); + body.put(FederatedComputePayloadDataContract.ASSOCIATED_DATA_KEY, + Base64.encodeToString(FederatedComputePayloadDataContract.ASSOCIATED_DATA, + Base64.NO_WRAP)); + return body.toString().getBytes(); + } + private ReportResultResponse getReportResultResponse(FederatedComputeHttpResponse httpResponse) throws InvalidProtocolBufferException { validateHttpResponseStatus("ReportResult", httpResponse); diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java b/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java index 4244aff4..cb13d87b 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java +++ b/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java @@ -43,8 +43,11 @@ import com.android.federatedcompute.internal.util.AbstractServiceBinder; import com.android.federatedcompute.internal.util.LogUtil; import com.android.federatedcompute.services.common.Constants; import com.android.federatedcompute.services.common.FileUtils; +import com.android.federatedcompute.services.data.FederatedComputeEncryptionKey; import com.android.federatedcompute.services.data.FederatedTrainingTask; import com.android.federatedcompute.services.data.fbs.TrainingConstraints; +import com.android.federatedcompute.services.encryption.FederatedComputeEncryptionKeyManager; +import com.android.federatedcompute.services.encryption.HpkeJniEncrypter; import com.android.federatedcompute.services.examplestore.ExampleConsumptionRecorder; import com.android.federatedcompute.services.http.CheckinResult; import com.android.federatedcompute.services.http.HttpFederatedProtocol; @@ -77,7 +80,9 @@ import com.google.protobuf.InvalidProtocolBufferException; import java.io.IOException; import java.util.ArrayList; +import java.util.List; import java.util.Objects; +import java.util.Random; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; @@ -104,6 +109,10 @@ public class FederatedComputeWorker { private AbstractServiceBinder<IExampleStoreService> mExampleStoreServiceBinder; private AbstractServiceBinder<IIsolatedTrainingService> mIsolatedTrainingServiceBinder; + private FederatedComputeEncryptionKeyManager mEncryptionKeyManager; + + private static final int NUM_ACTIVE_KEYS_TO_CHOOSE_FROM = 5; + @VisibleForTesting public FederatedComputeWorker( Context context, @@ -111,13 +120,15 @@ public class FederatedComputeWorker { TrainingConditionsChecker trainingConditionsChecker, ComputationRunner computationRunner, ResultCallbackHelper resultCallbackHelper, - Injector injector) { + Injector injector, + FederatedComputeEncryptionKeyManager keyManager) { this.mContext = context.getApplicationContext(); this.mJobManager = jobManager; this.mTrainingConditionsChecker = trainingConditionsChecker; this.mComputationRunner = computationRunner; this.mInjector = injector; this.mResultCallbackHelper = resultCallbackHelper; + this.mEncryptionKeyManager = keyManager; } /** Gets an instance of {@link FederatedComputeWorker}. */ @@ -133,7 +144,8 @@ public class FederatedComputeWorker { TrainingConditionsChecker.getInstance(context), new ComputationRunner(), new ResultCallbackHelper(context), - new Injector()); + new Injector(), + FederatedComputeEncryptionKeyManager.getInstance(context)); } } } @@ -252,13 +264,40 @@ public class FederatedComputeWorker { ContributionResult.FAIL); return Futures.immediateFuture(null); } - // 2. Bind to client app implemented ExampleStoreService based on ExampleSelector. - // Set active run's task name. + String taskName = checkinResult.getTaskAssignment().getTaskName(); Preconditions.checkArgument(!taskName.isEmpty(), "Task name should not be empty"); synchronized (mLock) { mActiveRun.mTaskName = taskName; } + // 2. Fetch Active keys to encrypt the computation result. + List<FederatedComputeEncryptionKey> activeKeys = mEncryptionKeyManager + .getOrFetchActiveKeys(NUM_ACTIVE_KEYS_TO_CHOOSE_FROM, FederatedComputeEncryptionKey + .KEY_TYPE_ENCRYPTION); + // select a random key + FederatedComputeEncryptionKey encryptionKey = activeKeys.isEmpty() ? null : + activeKeys.get(new Random().nextInt(activeKeys.size())); + if (encryptionKey == null) { + // no active keys to encrypt the FL/FA computation results, stop the computation run. + ComputationResult failedComputationResult = new ComputationResult( + null, + FLRunnerResult.newBuilder() + .setContributionResult(ContributionResult.FAIL) + .setErrorMessage("No active key available on device.") + .build(), null); + mJobManager.onTrainingCompleted( + run.mJobId, + run.mTask.populationName(), + run.mTask.getTrainingIntervalOptions(), + buildTaskRetry(RejectionInfo.newBuilder().build()), + ContributionResult.FAIL); + var unused = mHttpFederatedProtocol.reportResult(failedComputationResult, null); + return Futures.immediateFailedFuture( + new IllegalStateException("No active key available on device.")); + } + + // 3. Bind to client app implemented ExampleStoreService based on ExampleSelector. + // Set active run's task name. ListenableFuture<IExampleStoreIterator> iteratorFuture = getExampleStoreIterator( run, @@ -270,7 +309,7 @@ public class FederatedComputeWorker { Futures.addCallback( iteratorFuture, serverFailureReportCallback, getLightweightExecutor()); - // 3. Run federated learning or federated analytic depends on task type. Federated + // 4. Run federated learning or federated analytic depends on task type. Federated // learning job will start a new isolated process to run TFLite training. FluentFuture<ComputationResult> computationResultFuture = FluentFuture.from(iteratorFuture) @@ -281,10 +320,11 @@ public class FederatedComputeWorker { computationResultFuture.addCallback( serverFailureReportCallback, getLightweightExecutor()); - // 4. Report computation result to federated compute server. + + // 5. Report computation result to federated compute server. ListenableFuture<RejectionInfo> reportToServerFuture = computationResultFuture.transformAsync( - result -> mHttpFederatedProtocol.reportResult(result), + result -> mHttpFederatedProtocol.reportResult(result, encryptionKey), getLightweightExecutor()); return Futures.whenAllSucceed(reportToServerFuture, computationResultFuture) .call( @@ -319,7 +359,7 @@ public class FederatedComputeWorker { ContributionResult.FAIL); return null; } - // 5. Publish computation result and consumed + // 6. Publish computation result and consumed // examples to client implemented // ResultHandlingService. var unused = @@ -358,7 +398,8 @@ public class FederatedComputeWorker { .setErrorMessage(throwable.getMessage()) .build(), null); - var unused = mHttpFederatedProtocol.reportResult(failedReportComputationResult); + var unused = mHttpFederatedProtocol.reportResult(failedReportComputationResult, + null); } mNumberOfInvocations++; } @@ -451,7 +492,8 @@ public class FederatedComputeWorker { @VisibleForTesting HttpFederatedProtocol getHttpFederatedProtocol(String serverAddress, String populationName) { - return HttpFederatedProtocol.create(serverAddress, "1.0.0.1", populationName); + return HttpFederatedProtocol.create(serverAddress, "1.0.0.1", populationName, + new HpkeJniEncrypter()); } private ExampleSelector getExampleSelector(CheckinResult checkinResult) { diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpFederatedProtocolTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpFederatedProtocolTest.java index e483cf22..48c35f9f 100644 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpFederatedProtocolTest.java +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpFederatedProtocolTest.java @@ -43,12 +43,16 @@ import android.net.Uri; import androidx.test.core.app.ApplicationProvider; +import com.android.federatedcompute.services.data.FederatedComputeEncryptionKey; +import com.android.federatedcompute.services.encryption.HpkeJniEncrypter; import com.android.federatedcompute.services.http.HttpClientUtil.HttpMethod; import com.android.federatedcompute.services.testutils.TrainingTestUtil; import com.android.federatedcompute.services.training.util.ComputationResult; +import com.google.common.collect.BoundType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Range; import com.google.intelligence.fcp.client.FLRunnerResult; import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult; import com.google.internal.federatedcompute.v1.ClientVersion; @@ -101,6 +105,14 @@ public final class HttpFederatedProtocolTest { private static final String ASSIGNMENT_ID = "assignment-id"; private static final String AGGREGATION_ID = "aggregation-id"; private static final String OCTET_STREAM = "application/octet-stream"; + + private static final FederatedComputeEncryptionKey ENCRYPTION_KEY = + new FederatedComputeEncryptionKey.Builder() + .setPublicKey("rSJBSUYG0ebvfW1AXCWO0CMGMJhDzpfQm3eLyw1uxX8=") + .setKeyIdentifier("0962201a-5abd-4e25-a486-2c7bd1ee1887") + .setKeyType(FederatedComputeEncryptionKey.KEY_TYPE_ENCRYPTION) + .setCreationTime(1L) + .setExpiryTime(1L).build(); private static final FLRunnerResult FL_RUNNER_SUCCESS_RESULT = FLRunnerResult.newBuilder().setContributionResult(ContributionResult.SUCCESS).build(); @@ -144,7 +156,8 @@ public final class HttpFederatedProtocolTest { TASK_ASSIGNMENT_TARGET_URI, CLIENT_VERSION, POPULATION_NAME, - mMockHttpClient); + mMockHttpClient, + new HpkeJniEncrypter()); } @Test @@ -280,7 +293,7 @@ public final class HttpFederatedProtocolTest { // Setup task id, aggregation id for report result. mHttpFederatedProtocol.issueCheckin().get(); - mHttpFederatedProtocol.reportResult(computationResult).get(); + mHttpFederatedProtocol.reportResult(computationResult, ENCRYPTION_KEY).get(); // Verify ReportResult request. List<FederatedComputeHttpRequest> actualHttpRequests = mHttpRequestCaptor.getAllValues(); @@ -301,7 +314,7 @@ public final class HttpFederatedProtocolTest { // Setup task id, aggregation id for report result. mHttpFederatedProtocol.issueCheckin().get(); - mHttpFederatedProtocol.reportResult(computationResult).get(); + mHttpFederatedProtocol.reportResult(computationResult, ENCRYPTION_KEY).get(); // Verify ReportResult request. List<FederatedComputeHttpRequest> actualHttpRequests = mHttpRequestCaptor.getAllValues(); @@ -322,14 +335,23 @@ public final class HttpFederatedProtocolTest { assertThat(actualDataUploadRequest.getUri()).isEqualTo(UPLOAD_LOCATION_URI); assertThat(acutalReportResultRequest.getHttpMethod()).isEqualTo(HttpMethod.PUT); expectedHeaders = new HashMap<>(); + expectedHeaders.put(CONTENT_TYPE_HDR, OCTET_STREAM); if (mSupportCompression) { - expectedHeaders.put(CONTENT_LENGTH_HDR, String.valueOf(339)); expectedHeaders.put(CONTENT_ENCODING_HDR, GZIP_ENCODING_HDR); - } else { - expectedHeaders.put(CONTENT_LENGTH_HDR, String.valueOf(31846)); } - expectedHeaders.put(CONTENT_TYPE_HDR, OCTET_STREAM); + + int actualContentLength = Integer + .parseInt(actualDataUploadRequest.getExtraHeaders().remove(CONTENT_LENGTH_HDR)); assertThat(actualDataUploadRequest.getExtraHeaders()).isEqualTo(expectedHeaders); + // The encryption is non-deterministic with BoringSSL JNI. + // Only check the range of the content. + if (mSupportCompression) { + assertThat(actualContentLength) + .isIn(Range.range(500, BoundType.CLOSED, 550, BoundType.CLOSED)); + } else { + assertThat(actualContentLength) + .isIn(Range.range(600, BoundType.CLOSED, 650, BoundType.CLOSED)); + } } @Test @@ -350,7 +372,8 @@ public final class HttpFederatedProtocolTest { ExecutionException exception = assertThrows( ExecutionException.class, - () -> mHttpFederatedProtocol.reportResult(computationResult).get()); + () -> mHttpFederatedProtocol.reportResult(computationResult, + ENCRYPTION_KEY).get()); assertThat(exception.getCause()).isInstanceOf(IllegalStateException.class); assertThat(exception.getCause()).hasMessageThat().isEqualTo("ReportResult failed: 503"); @@ -374,7 +397,8 @@ public final class HttpFederatedProtocolTest { ExecutionException exception = assertThrows( ExecutionException.class, - () -> mHttpFederatedProtocol.reportResult(computationResult).get()); + () -> mHttpFederatedProtocol.reportResult(computationResult, + ENCRYPTION_KEY).get()); assertThat(exception).hasCauseThat().isInstanceOf(IllegalStateException.class); assertThat(exception.getCause()).hasMessageThat().isEqualTo("Upload result failed: 503"); diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java index 71fe9998..8f98c6a0 100644 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java @@ -48,11 +48,14 @@ import android.os.RemoteException; import androidx.test.core.app.ApplicationProvider; import com.android.federatedcompute.services.common.Constants; +import com.android.federatedcompute.services.data.FederatedComputeEncryptionKey; import com.android.federatedcompute.services.data.FederatedTrainingTask; import com.android.federatedcompute.services.data.fbs.SchedulingMode; import com.android.federatedcompute.services.data.fbs.SchedulingReason; import com.android.federatedcompute.services.data.fbs.TrainingConstraints; import com.android.federatedcompute.services.data.fbs.TrainingIntervalOptions; +import com.android.federatedcompute.services.encryption.FederatedComputeEncryptionKeyManager; +import com.android.federatedcompute.services.encryption.HpkeJniEncrypter; import com.android.federatedcompute.services.examplestore.ExampleConsumptionRecorder; import com.android.federatedcompute.services.http.CheckinResult; import com.android.federatedcompute.services.http.HttpFederatedProtocol; @@ -104,6 +107,7 @@ import org.tensorflow.example.Features; import java.util.ArrayList; import java.util.EnumSet; +import java.util.List; import java.util.concurrent.ExecutionException; @RunWith(JUnit4.class) @@ -193,6 +197,16 @@ public final class FederatedComputeWorkerTest { @Mock private ComputationRunner mMockComputationRunner; private ResultCallbackHelper mSpyResultCallbackHelper; + @Mock private FederatedComputeEncryptionKeyManager mMockKeyManager; + + private static final FederatedComputeEncryptionKey ENCRYPTION_KEY = + new FederatedComputeEncryptionKey.Builder() + .setPublicKey("rSJBSUYG0ebvfW1AXCWO0CMGMJhDzpfQm3eLyw1uxX8=") + .setKeyIdentifier("0962201a-5abd-4e25-a486-2c7bd1ee1887") + .setKeyType(FederatedComputeEncryptionKey.KEY_TYPE_ENCRYPTION) + .setCreationTime(1L) + .setExpiryTime(1L).build(); + private static byte[] createTrainingConstraints( boolean requiresSchedulerIdle, boolean requiresSchedulerBatteryNotLow, @@ -220,7 +234,11 @@ public final class FederatedComputeWorkerTest { mContext = ApplicationProvider.getApplicationContext(); mSpyHttpFederatedProtocol = Mockito.spy( - HttpFederatedProtocol.create(SERVER_ADDRESS, "1.0.0.1", POPULATION_NAME)); + HttpFederatedProtocol.create( + SERVER_ADDRESS, + "1.0.0.1", + POPULATION_NAME, + new HpkeJniEncrypter())); mSpyResultCallbackHelper = Mockito.spy(new ResultCallbackHelper(mContext)); mSpyWorker = Mockito.spy( @@ -230,7 +248,8 @@ public final class FederatedComputeWorkerTest { mTrainingConditionsChecker, mMockComputationRunner, mSpyResultCallbackHelper, - new TestInjector())); + new TestInjector(), + mMockKeyManager)); when(mTrainingConditionsChecker.checkAllConditionsForFlTraining(any())) .thenReturn(EnumSet.noneOf(Condition.class)); doReturn(Futures.immediateFuture(CallbackResult.SUCCESS)) @@ -251,6 +270,8 @@ public final class FederatedComputeWorkerTest { any(), any())) .thenReturn(FL_RUNNER_SUCCESS_RESULT); + doReturn(List.of(ENCRYPTION_KEY)).when(mMockKeyManager).getOrFetchActiveKeys(anyInt(), + anyInt()); } @Test @@ -289,7 +310,7 @@ public final class FederatedComputeWorkerTest { .issueCheckin(); doReturn(FluentFuture.from(immediateFuture(null))) .when(mSpyHttpFederatedProtocol) - .reportResult(any()); + .reportResult(any(), any()); assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get()); @@ -324,7 +345,7 @@ public final class FederatedComputeWorkerTest { .issueCheckin(); doReturn(FluentFuture.from(immediateFuture(REJECTION_INFO))) .when(mSpyHttpFederatedProtocol) - .reportResult(any()); + .reportResult(any(), any()); doCallRealMethod().when(mSpyResultCallbackHelper).callHandleResult(any(), any(), any()); ArgumentCaptor<ComputationResult> computationResultCaptor = ArgumentCaptor.forClass(ComputationResult.class); @@ -360,7 +381,7 @@ public final class FederatedComputeWorkerTest { "report result failed", new IllegalStateException("http 404"))))) .when(mSpyHttpFederatedProtocol) - .reportResult(any()); + .reportResult(any(), any()); assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get()); @@ -387,7 +408,7 @@ public final class FederatedComputeWorkerTest { anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL)); verify(mSpyWorker, times(0)).unbindFromExampleStoreService(); verify(mSpyHttpFederatedProtocol, times(1)) - .reportResult(computationResultCaptor.capture()); + .reportResult(computationResultCaptor.capture(), any()); ComputationResult computationResult = computationResultCaptor.getValue(); assertNotNull(computationResult.getFlRunnerResult()); assertEquals( @@ -429,7 +450,7 @@ public final class FederatedComputeWorkerTest { setUpHttpFederatedProtocol(FA_CHECKIN_RESULT); doReturn(FluentFuture.from(immediateFuture(null))) .when(mSpyHttpFederatedProtocol) - .reportResult(any()); + .reportResult(any(), any()); when(mMockComputationRunner.runTaskWithNativeRunner( anyString(), anyString(), @@ -452,7 +473,7 @@ public final class FederatedComputeWorkerTest { anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL)); verify(mSpyWorker).unbindFromExampleStoreService(); verify(mSpyHttpFederatedProtocol, times(1)) - .reportResult(computationResultCaptor.capture()); + .reportResult(computationResultCaptor.capture(), any()); ComputationResult computationResult = computationResultCaptor.getValue(); assertNotNull(computationResult.getFlRunnerResult()); assertEquals( @@ -596,6 +617,22 @@ public final class FederatedComputeWorkerTest { verify(mSpyWorker).unbindFromExampleStoreService(); } + @Test + public void testRunFLComputation_noKey_throws() throws Exception { + setUpHttpFederatedProtocol(FL_CHECKIN_RESULT); + doReturn(new ArrayList<FederatedComputeEncryptionKey>() {}).when(mMockKeyManager) + .getOrFetchActiveKeys(anyInt(), anyInt()); + setUpHttpFederatedProtocol(FA_CHECKIN_RESULT); + + assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get()); + + verify(mMockJobManager) + .onTrainingCompleted( + anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL)); + verify(mSpyHttpFederatedProtocol).reportResult( + any(), eq(null)); + } + private void setUpExampleStoreService() { TestExampleStoreService testExampleStoreService = new TestExampleStoreService(); doReturn(testExampleStoreService).when(mSpyWorker).getExampleStoreService(anyString()); @@ -606,7 +643,7 @@ public final class FederatedComputeWorkerTest { doReturn(immediateFuture(checkinResult)).when(mSpyHttpFederatedProtocol).issueCheckin(); doReturn(FluentFuture.from(immediateFuture(null))) .when(mSpyHttpFederatedProtocol) - .reportResult(any()); + .reportResult(any(), any()); } private static class TestExampleStoreService extends IExampleStoreService.Stub { |