aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCarol Zheng <cazheng@google.com>2023-12-13 18:11:17 +0000
committerAndroid (Google) Code Review <android-gerrit@google.com>2023-12-13 18:11:17 +0000
commit1c22a514b272bedc5201fc00a580777a6ce67de5 (patch)
tree8a06d928b9939f62a18e0fe3b402fc04082e985d
parent0b74c968324906c1dc76c48407ce74b9b3ac2fdd (diff)
parent05fcc96163fd735cabb556ddd5808fa1fbcc1f51 (diff)
downloadOnDevicePersonalization-1c22a514b272bedc5201fc00a580777a6ce67de5.tar.gz
Merge "Encrypt FL/FA payload." into udc-mainline-prod
-rw-r--r--federatedcompute/src/com/android/federatedcompute/services/encryption/Encrypter.java32
-rw-r--r--federatedcompute/src/com/android/federatedcompute/services/encryption/HpkeJniEncrypter.java31
-rw-r--r--federatedcompute/src/com/android/federatedcompute/services/http/HttpClientUtil.java11
-rw-r--r--federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java69
-rw-r--r--federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java62
-rw-r--r--tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpFederatedProtocolTest.java42
-rw-r--r--tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java55
7 files changed, 264 insertions, 38 deletions
diff --git a/federatedcompute/src/com/android/federatedcompute/services/encryption/Encrypter.java b/federatedcompute/src/com/android/federatedcompute/services/encryption/Encrypter.java
new file mode 100644
index 00000000..7398e1f5
--- /dev/null
+++ b/federatedcompute/src/com/android/federatedcompute/services/encryption/Encrypter.java
@@ -0,0 +1,32 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.federatedcompute.services.encryption;
+
+/** Interface for crypto libraries to encrypt data */
+public interface Encrypter {
+
+ /**
+ * encrypt {@code plainText} to cipher text {@code byte[]}.
+ *
+ * @param publicKey the public key used for encryption
+ * @param plainText the plain text string to encrypt
+ * @param associatedData additional data used for encryption
+ * @return the encrypted ciphertext
+ */
+ byte[] encrypt(byte[] publicKey, byte[] plainText, byte[] associatedData);
+
+}
diff --git a/federatedcompute/src/com/android/federatedcompute/services/encryption/HpkeJniEncrypter.java b/federatedcompute/src/com/android/federatedcompute/services/encryption/HpkeJniEncrypter.java
new file mode 100644
index 00000000..dd6040c1
--- /dev/null
+++ b/federatedcompute/src/com/android/federatedcompute/services/encryption/HpkeJniEncrypter.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.federatedcompute.services.encryption;
+
+import com.android.federatedcompute.services.encryption.jni.HpkeJni;
+
+
+/**
+ * The implementation of HPKE (Hybrid Public Key Encryption) using BoringSSL JNI.
+ */
+public class HpkeJniEncrypter implements Encrypter {
+
+ @Override
+ public byte[] encrypt(byte[] publicKey, byte[] plainText, byte[] associatedData) {
+ return HpkeJni.encrypt(publicKey, plainText, associatedData);
+ }
+}
diff --git a/federatedcompute/src/com/android/federatedcompute/services/http/HttpClientUtil.java b/federatedcompute/src/com/android/federatedcompute/services/http/HttpClientUtil.java
index 61d719e9..749fbdc7 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/http/HttpClientUtil.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/http/HttpClientUtil.java
@@ -21,6 +21,8 @@ import com.android.federatedcompute.internal.util.LogUtil;
import com.google.common.collect.ImmutableSet;
import com.google.protobuf.ByteString;
+import org.json.JSONObject;
+
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
@@ -49,6 +51,15 @@ public final class HttpClientUtil {
POST,
PUT,
}
+ public static final class FederatedComputePayloadDataContract {
+ public static final String KEY_ID = "keyId";
+
+ public static final String ENCRYPTED_PAYLOAD = "encryptedPayload";
+
+ public static final String ASSOCIATED_DATA_KEY = "associatedData";
+
+ public static final byte[] ASSOCIATED_DATA = new JSONObject().toString().getBytes();
+ }
/** Compresses the input data using Gzip. */
public static byte[] compressWithGzip(byte[] uncompressedData) {
diff --git a/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java b/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java
index 00fdf1fb..208f2be9 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java
@@ -31,8 +31,12 @@ import static com.android.federatedcompute.services.http.HttpClientUtil.compress
import static com.android.federatedcompute.services.http.HttpClientUtil.uncompressWithGzip;
import android.os.Trace;
+import android.util.Base64;
import com.android.federatedcompute.internal.util.LogUtil;
+import com.android.federatedcompute.services.data.FederatedComputeEncryptionKey;
+import com.android.federatedcompute.services.encryption.Encrypter;
+import com.android.federatedcompute.services.http.HttpClientUtil.FederatedComputePayloadDataContract;
import com.android.federatedcompute.services.http.HttpClientUtil.HttpMethod;
import com.android.federatedcompute.services.training.util.ComputationResult;
@@ -56,34 +60,44 @@ import com.google.ondevicepersonalization.federatedcompute.proto.TaskAssignment;
import com.google.ondevicepersonalization.federatedcompute.proto.UploadInstruction;
import com.google.protobuf.InvalidProtocolBufferException;
+import org.json.JSONObject;
+
import java.util.HashMap;
import java.util.UUID;
import java.util.concurrent.Callable;
+
/** Implements a single session of HTTP-based federated compute protocol. */
public final class HttpFederatedProtocol {
public static final String TAG = HttpFederatedProtocol.class.getSimpleName();
private final String mClientVersion;
private final String mPopulationName;
private final HttpClient mHttpClient;
+ private final ProtocolRequestCreator mTaskAssignmentRequestCreator;
+ private final Encrypter mEncrypter;
private String mTaskId;
private String mAggregationId;
private String mAssignmentId;
- private final ProtocolRequestCreator mTaskAssignmentRequestCreator;
@VisibleForTesting
HttpFederatedProtocol(
- String entryUri, String clientVersion, String populationName, HttpClient httpClient) {
+ String entryUri,
+ String clientVersion,
+ String populationName,
+ HttpClient httpClient,
+ Encrypter encrypter) {
this.mClientVersion = clientVersion;
this.mPopulationName = populationName;
this.mHttpClient = httpClient;
this.mTaskAssignmentRequestCreator = new ProtocolRequestCreator(entryUri, new HashMap<>());
+ this.mEncrypter = encrypter;
}
/** Creates a HttpFederatedProtocol object. */
public static HttpFederatedProtocol create(
- String entryUri, String clientVersion, String populationName) {
- return new HttpFederatedProtocol(entryUri, clientVersion, populationName, new HttpClient());
+ String entryUri, String clientVersion, String populationName, Encrypter encrypter) {
+ return new HttpFederatedProtocol(
+ entryUri, clientVersion, populationName, new HttpClient(), encrypter);
}
/** Helper function to perform check in and download federated task from remote servers. */
@@ -133,9 +147,12 @@ public final class HttpFederatedProtocol {
}
/** Helper functions to reporting result and upload result. */
- public FluentFuture<RejectionInfo> reportResult(ComputationResult computationResult) {
+ public FluentFuture<RejectionInfo> reportResult(
+ ComputationResult computationResult, FederatedComputeEncryptionKey encryptionKey) {
Trace.beginAsyncSection(TRACE_HTTP_REPORT_RESULT, 0);
- if (computationResult != null && computationResult.isResultSuccess()) {
+ if (computationResult != null
+ && computationResult.isResultSuccess()
+ && encryptionKey != null) {
return FluentFuture.from(performReportResult(computationResult))
.transformAsync(
reportResp -> {
@@ -147,7 +164,9 @@ public final class HttpFederatedProtocol {
}
return FluentFuture.from(
processReportResultResponseAndUploadResult(
- reportResultResponse, computationResult))
+ reportResultResponse,
+ computationResult,
+ encryptionKey))
.transform(
resp -> {
validateHttpResponseStatus(
@@ -304,8 +323,9 @@ public final class HttpFederatedProtocol {
private ListenableFuture<FederatedComputeHttpResponse>
processReportResultResponseAndUploadResult(
- ReportResultResponse reportResultResponse,
- ComputationResult computationResult) {
+ ReportResultResponse reportResultResponse,
+ ComputationResult computationResult,
+ FederatedComputeEncryptionKey encryptionKey) {
try {
Preconditions.checkArgument(
!computationResult.getOutputCheckpointFile().isEmpty(),
@@ -314,7 +334,11 @@ public final class HttpFederatedProtocol {
Preconditions.checkArgument(
!uploadInstruction.getUploadLocation().isEmpty(),
"UploadInstruction.upload_location must not be empty");
- byte[] outputBytes = readFileAsByteArray(computationResult.getOutputCheckpointFile());
+ byte[] outputBytes =
+ createEncryptedRequestBody(
+ computationResult.getOutputCheckpointFile(),
+ encryptionKey);
+ // Apply a top-level compression to the payload.
if (uploadInstruction.getCompressionFormat()
== ResourceCompressionFormat.RESOURCE_COMPRESSION_FORMAT_GZIP) {
outputBytes = compressWithGzip(outputBytes);
@@ -345,6 +369,31 @@ public final class HttpFederatedProtocol {
}
}
+ private byte[] createEncryptedRequestBody(
+ String filePath,
+ FederatedComputeEncryptionKey encryptionKey)
+ throws Exception {
+ byte[] fileOutputBytes = readFileAsByteArray(filePath);
+ fileOutputBytes = compressWithGzip(fileOutputBytes);
+ // encryption
+ byte[] publicKey = Base64.decode(encryptionKey.getPublicKey(), Base64.NO_WRAP);
+
+ byte[] encryptedOutput =
+ mEncrypter.encrypt(
+ publicKey, fileOutputBytes,
+ FederatedComputePayloadDataContract.ASSOCIATED_DATA);
+ // create payload
+ final JSONObject body = new JSONObject();
+ body.put(FederatedComputePayloadDataContract.KEY_ID,
+ encryptionKey.getKeyIdentifier());
+ body.put(FederatedComputePayloadDataContract.ENCRYPTED_PAYLOAD,
+ Base64.encodeToString(encryptedOutput, Base64.NO_WRAP));
+ body.put(FederatedComputePayloadDataContract.ASSOCIATED_DATA_KEY,
+ Base64.encodeToString(FederatedComputePayloadDataContract.ASSOCIATED_DATA,
+ Base64.NO_WRAP));
+ return body.toString().getBytes();
+ }
+
private ReportResultResponse getReportResultResponse(FederatedComputeHttpResponse httpResponse)
throws InvalidProtocolBufferException {
validateHttpResponseStatus("ReportResult", httpResponse);
diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java b/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java
index 4244aff4..cb13d87b 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java
@@ -43,8 +43,11 @@ import com.android.federatedcompute.internal.util.AbstractServiceBinder;
import com.android.federatedcompute.internal.util.LogUtil;
import com.android.federatedcompute.services.common.Constants;
import com.android.federatedcompute.services.common.FileUtils;
+import com.android.federatedcompute.services.data.FederatedComputeEncryptionKey;
import com.android.federatedcompute.services.data.FederatedTrainingTask;
import com.android.federatedcompute.services.data.fbs.TrainingConstraints;
+import com.android.federatedcompute.services.encryption.FederatedComputeEncryptionKeyManager;
+import com.android.federatedcompute.services.encryption.HpkeJniEncrypter;
import com.android.federatedcompute.services.examplestore.ExampleConsumptionRecorder;
import com.android.federatedcompute.services.http.CheckinResult;
import com.android.federatedcompute.services.http.HttpFederatedProtocol;
@@ -77,7 +80,9 @@ import com.google.protobuf.InvalidProtocolBufferException;
import java.io.IOException;
import java.util.ArrayList;
+import java.util.List;
import java.util.Objects;
+import java.util.Random;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
@@ -104,6 +109,10 @@ public class FederatedComputeWorker {
private AbstractServiceBinder<IExampleStoreService> mExampleStoreServiceBinder;
private AbstractServiceBinder<IIsolatedTrainingService> mIsolatedTrainingServiceBinder;
+ private FederatedComputeEncryptionKeyManager mEncryptionKeyManager;
+
+ private static final int NUM_ACTIVE_KEYS_TO_CHOOSE_FROM = 5;
+
@VisibleForTesting
public FederatedComputeWorker(
Context context,
@@ -111,13 +120,15 @@ public class FederatedComputeWorker {
TrainingConditionsChecker trainingConditionsChecker,
ComputationRunner computationRunner,
ResultCallbackHelper resultCallbackHelper,
- Injector injector) {
+ Injector injector,
+ FederatedComputeEncryptionKeyManager keyManager) {
this.mContext = context.getApplicationContext();
this.mJobManager = jobManager;
this.mTrainingConditionsChecker = trainingConditionsChecker;
this.mComputationRunner = computationRunner;
this.mInjector = injector;
this.mResultCallbackHelper = resultCallbackHelper;
+ this.mEncryptionKeyManager = keyManager;
}
/** Gets an instance of {@link FederatedComputeWorker}. */
@@ -133,7 +144,8 @@ public class FederatedComputeWorker {
TrainingConditionsChecker.getInstance(context),
new ComputationRunner(),
new ResultCallbackHelper(context),
- new Injector());
+ new Injector(),
+ FederatedComputeEncryptionKeyManager.getInstance(context));
}
}
}
@@ -252,13 +264,40 @@ public class FederatedComputeWorker {
ContributionResult.FAIL);
return Futures.immediateFuture(null);
}
- // 2. Bind to client app implemented ExampleStoreService based on ExampleSelector.
- // Set active run's task name.
+
String taskName = checkinResult.getTaskAssignment().getTaskName();
Preconditions.checkArgument(!taskName.isEmpty(), "Task name should not be empty");
synchronized (mLock) {
mActiveRun.mTaskName = taskName;
}
+ // 2. Fetch Active keys to encrypt the computation result.
+ List<FederatedComputeEncryptionKey> activeKeys = mEncryptionKeyManager
+ .getOrFetchActiveKeys(NUM_ACTIVE_KEYS_TO_CHOOSE_FROM, FederatedComputeEncryptionKey
+ .KEY_TYPE_ENCRYPTION);
+ // select a random key
+ FederatedComputeEncryptionKey encryptionKey = activeKeys.isEmpty() ? null :
+ activeKeys.get(new Random().nextInt(activeKeys.size()));
+ if (encryptionKey == null) {
+ // no active keys to encrypt the FL/FA computation results, stop the computation run.
+ ComputationResult failedComputationResult = new ComputationResult(
+ null,
+ FLRunnerResult.newBuilder()
+ .setContributionResult(ContributionResult.FAIL)
+ .setErrorMessage("No active key available on device.")
+ .build(), null);
+ mJobManager.onTrainingCompleted(
+ run.mJobId,
+ run.mTask.populationName(),
+ run.mTask.getTrainingIntervalOptions(),
+ buildTaskRetry(RejectionInfo.newBuilder().build()),
+ ContributionResult.FAIL);
+ var unused = mHttpFederatedProtocol.reportResult(failedComputationResult, null);
+ return Futures.immediateFailedFuture(
+ new IllegalStateException("No active key available on device."));
+ }
+
+ // 3. Bind to client app implemented ExampleStoreService based on ExampleSelector.
+ // Set active run's task name.
ListenableFuture<IExampleStoreIterator> iteratorFuture =
getExampleStoreIterator(
run,
@@ -270,7 +309,7 @@ public class FederatedComputeWorker {
Futures.addCallback(
iteratorFuture, serverFailureReportCallback, getLightweightExecutor());
- // 3. Run federated learning or federated analytic depends on task type. Federated
+ // 4. Run federated learning or federated analytic depends on task type. Federated
// learning job will start a new isolated process to run TFLite training.
FluentFuture<ComputationResult> computationResultFuture =
FluentFuture.from(iteratorFuture)
@@ -281,10 +320,11 @@ public class FederatedComputeWorker {
computationResultFuture.addCallback(
serverFailureReportCallback, getLightweightExecutor());
- // 4. Report computation result to federated compute server.
+
+ // 5. Report computation result to federated compute server.
ListenableFuture<RejectionInfo> reportToServerFuture =
computationResultFuture.transformAsync(
- result -> mHttpFederatedProtocol.reportResult(result),
+ result -> mHttpFederatedProtocol.reportResult(result, encryptionKey),
getLightweightExecutor());
return Futures.whenAllSucceed(reportToServerFuture, computationResultFuture)
.call(
@@ -319,7 +359,7 @@ public class FederatedComputeWorker {
ContributionResult.FAIL);
return null;
}
- // 5. Publish computation result and consumed
+ // 6. Publish computation result and consumed
// examples to client implemented
// ResultHandlingService.
var unused =
@@ -358,7 +398,8 @@ public class FederatedComputeWorker {
.setErrorMessage(throwable.getMessage())
.build(),
null);
- var unused = mHttpFederatedProtocol.reportResult(failedReportComputationResult);
+ var unused = mHttpFederatedProtocol.reportResult(failedReportComputationResult,
+ null);
}
mNumberOfInvocations++;
}
@@ -451,7 +492,8 @@ public class FederatedComputeWorker {
@VisibleForTesting
HttpFederatedProtocol getHttpFederatedProtocol(String serverAddress, String populationName) {
- return HttpFederatedProtocol.create(serverAddress, "1.0.0.1", populationName);
+ return HttpFederatedProtocol.create(serverAddress, "1.0.0.1", populationName,
+ new HpkeJniEncrypter());
}
private ExampleSelector getExampleSelector(CheckinResult checkinResult) {
diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpFederatedProtocolTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpFederatedProtocolTest.java
index e483cf22..48c35f9f 100644
--- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpFederatedProtocolTest.java
+++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpFederatedProtocolTest.java
@@ -43,12 +43,16 @@ import android.net.Uri;
import androidx.test.core.app.ApplicationProvider;
+import com.android.federatedcompute.services.data.FederatedComputeEncryptionKey;
+import com.android.federatedcompute.services.encryption.HpkeJniEncrypter;
import com.android.federatedcompute.services.http.HttpClientUtil.HttpMethod;
import com.android.federatedcompute.services.testutils.TrainingTestUtil;
import com.android.federatedcompute.services.training.util.ComputationResult;
+import com.google.common.collect.BoundType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Range;
import com.google.intelligence.fcp.client.FLRunnerResult;
import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult;
import com.google.internal.federatedcompute.v1.ClientVersion;
@@ -101,6 +105,14 @@ public final class HttpFederatedProtocolTest {
private static final String ASSIGNMENT_ID = "assignment-id";
private static final String AGGREGATION_ID = "aggregation-id";
private static final String OCTET_STREAM = "application/octet-stream";
+
+ private static final FederatedComputeEncryptionKey ENCRYPTION_KEY =
+ new FederatedComputeEncryptionKey.Builder()
+ .setPublicKey("rSJBSUYG0ebvfW1AXCWO0CMGMJhDzpfQm3eLyw1uxX8=")
+ .setKeyIdentifier("0962201a-5abd-4e25-a486-2c7bd1ee1887")
+ .setKeyType(FederatedComputeEncryptionKey.KEY_TYPE_ENCRYPTION)
+ .setCreationTime(1L)
+ .setExpiryTime(1L).build();
private static final FLRunnerResult FL_RUNNER_SUCCESS_RESULT =
FLRunnerResult.newBuilder().setContributionResult(ContributionResult.SUCCESS).build();
@@ -144,7 +156,8 @@ public final class HttpFederatedProtocolTest {
TASK_ASSIGNMENT_TARGET_URI,
CLIENT_VERSION,
POPULATION_NAME,
- mMockHttpClient);
+ mMockHttpClient,
+ new HpkeJniEncrypter());
}
@Test
@@ -280,7 +293,7 @@ public final class HttpFederatedProtocolTest {
// Setup task id, aggregation id for report result.
mHttpFederatedProtocol.issueCheckin().get();
- mHttpFederatedProtocol.reportResult(computationResult).get();
+ mHttpFederatedProtocol.reportResult(computationResult, ENCRYPTION_KEY).get();
// Verify ReportResult request.
List<FederatedComputeHttpRequest> actualHttpRequests = mHttpRequestCaptor.getAllValues();
@@ -301,7 +314,7 @@ public final class HttpFederatedProtocolTest {
// Setup task id, aggregation id for report result.
mHttpFederatedProtocol.issueCheckin().get();
- mHttpFederatedProtocol.reportResult(computationResult).get();
+ mHttpFederatedProtocol.reportResult(computationResult, ENCRYPTION_KEY).get();
// Verify ReportResult request.
List<FederatedComputeHttpRequest> actualHttpRequests = mHttpRequestCaptor.getAllValues();
@@ -322,14 +335,23 @@ public final class HttpFederatedProtocolTest {
assertThat(actualDataUploadRequest.getUri()).isEqualTo(UPLOAD_LOCATION_URI);
assertThat(acutalReportResultRequest.getHttpMethod()).isEqualTo(HttpMethod.PUT);
expectedHeaders = new HashMap<>();
+ expectedHeaders.put(CONTENT_TYPE_HDR, OCTET_STREAM);
if (mSupportCompression) {
- expectedHeaders.put(CONTENT_LENGTH_HDR, String.valueOf(339));
expectedHeaders.put(CONTENT_ENCODING_HDR, GZIP_ENCODING_HDR);
- } else {
- expectedHeaders.put(CONTENT_LENGTH_HDR, String.valueOf(31846));
}
- expectedHeaders.put(CONTENT_TYPE_HDR, OCTET_STREAM);
+
+ int actualContentLength = Integer
+ .parseInt(actualDataUploadRequest.getExtraHeaders().remove(CONTENT_LENGTH_HDR));
assertThat(actualDataUploadRequest.getExtraHeaders()).isEqualTo(expectedHeaders);
+ // The encryption is non-deterministic with BoringSSL JNI.
+ // Only check the range of the content.
+ if (mSupportCompression) {
+ assertThat(actualContentLength)
+ .isIn(Range.range(500, BoundType.CLOSED, 550, BoundType.CLOSED));
+ } else {
+ assertThat(actualContentLength)
+ .isIn(Range.range(600, BoundType.CLOSED, 650, BoundType.CLOSED));
+ }
}
@Test
@@ -350,7 +372,8 @@ public final class HttpFederatedProtocolTest {
ExecutionException exception =
assertThrows(
ExecutionException.class,
- () -> mHttpFederatedProtocol.reportResult(computationResult).get());
+ () -> mHttpFederatedProtocol.reportResult(computationResult,
+ ENCRYPTION_KEY).get());
assertThat(exception.getCause()).isInstanceOf(IllegalStateException.class);
assertThat(exception.getCause()).hasMessageThat().isEqualTo("ReportResult failed: 503");
@@ -374,7 +397,8 @@ public final class HttpFederatedProtocolTest {
ExecutionException exception =
assertThrows(
ExecutionException.class,
- () -> mHttpFederatedProtocol.reportResult(computationResult).get());
+ () -> mHttpFederatedProtocol.reportResult(computationResult,
+ ENCRYPTION_KEY).get());
assertThat(exception).hasCauseThat().isInstanceOf(IllegalStateException.class);
assertThat(exception.getCause()).hasMessageThat().isEqualTo("Upload result failed: 503");
diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java
index 71fe9998..8f98c6a0 100644
--- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java
+++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java
@@ -48,11 +48,14 @@ import android.os.RemoteException;
import androidx.test.core.app.ApplicationProvider;
import com.android.federatedcompute.services.common.Constants;
+import com.android.federatedcompute.services.data.FederatedComputeEncryptionKey;
import com.android.federatedcompute.services.data.FederatedTrainingTask;
import com.android.federatedcompute.services.data.fbs.SchedulingMode;
import com.android.federatedcompute.services.data.fbs.SchedulingReason;
import com.android.federatedcompute.services.data.fbs.TrainingConstraints;
import com.android.federatedcompute.services.data.fbs.TrainingIntervalOptions;
+import com.android.federatedcompute.services.encryption.FederatedComputeEncryptionKeyManager;
+import com.android.federatedcompute.services.encryption.HpkeJniEncrypter;
import com.android.federatedcompute.services.examplestore.ExampleConsumptionRecorder;
import com.android.federatedcompute.services.http.CheckinResult;
import com.android.federatedcompute.services.http.HttpFederatedProtocol;
@@ -104,6 +107,7 @@ import org.tensorflow.example.Features;
import java.util.ArrayList;
import java.util.EnumSet;
+import java.util.List;
import java.util.concurrent.ExecutionException;
@RunWith(JUnit4.class)
@@ -193,6 +197,16 @@ public final class FederatedComputeWorkerTest {
@Mock private ComputationRunner mMockComputationRunner;
private ResultCallbackHelper mSpyResultCallbackHelper;
+ @Mock private FederatedComputeEncryptionKeyManager mMockKeyManager;
+
+ private static final FederatedComputeEncryptionKey ENCRYPTION_KEY =
+ new FederatedComputeEncryptionKey.Builder()
+ .setPublicKey("rSJBSUYG0ebvfW1AXCWO0CMGMJhDzpfQm3eLyw1uxX8=")
+ .setKeyIdentifier("0962201a-5abd-4e25-a486-2c7bd1ee1887")
+ .setKeyType(FederatedComputeEncryptionKey.KEY_TYPE_ENCRYPTION)
+ .setCreationTime(1L)
+ .setExpiryTime(1L).build();
+
private static byte[] createTrainingConstraints(
boolean requiresSchedulerIdle,
boolean requiresSchedulerBatteryNotLow,
@@ -220,7 +234,11 @@ public final class FederatedComputeWorkerTest {
mContext = ApplicationProvider.getApplicationContext();
mSpyHttpFederatedProtocol =
Mockito.spy(
- HttpFederatedProtocol.create(SERVER_ADDRESS, "1.0.0.1", POPULATION_NAME));
+ HttpFederatedProtocol.create(
+ SERVER_ADDRESS,
+ "1.0.0.1",
+ POPULATION_NAME,
+ new HpkeJniEncrypter()));
mSpyResultCallbackHelper = Mockito.spy(new ResultCallbackHelper(mContext));
mSpyWorker =
Mockito.spy(
@@ -230,7 +248,8 @@ public final class FederatedComputeWorkerTest {
mTrainingConditionsChecker,
mMockComputationRunner,
mSpyResultCallbackHelper,
- new TestInjector()));
+ new TestInjector(),
+ mMockKeyManager));
when(mTrainingConditionsChecker.checkAllConditionsForFlTraining(any()))
.thenReturn(EnumSet.noneOf(Condition.class));
doReturn(Futures.immediateFuture(CallbackResult.SUCCESS))
@@ -251,6 +270,8 @@ public final class FederatedComputeWorkerTest {
any(),
any()))
.thenReturn(FL_RUNNER_SUCCESS_RESULT);
+ doReturn(List.of(ENCRYPTION_KEY)).when(mMockKeyManager).getOrFetchActiveKeys(anyInt(),
+ anyInt());
}
@Test
@@ -289,7 +310,7 @@ public final class FederatedComputeWorkerTest {
.issueCheckin();
doReturn(FluentFuture.from(immediateFuture(null)))
.when(mSpyHttpFederatedProtocol)
- .reportResult(any());
+ .reportResult(any(), any());
assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get());
@@ -324,7 +345,7 @@ public final class FederatedComputeWorkerTest {
.issueCheckin();
doReturn(FluentFuture.from(immediateFuture(REJECTION_INFO)))
.when(mSpyHttpFederatedProtocol)
- .reportResult(any());
+ .reportResult(any(), any());
doCallRealMethod().when(mSpyResultCallbackHelper).callHandleResult(any(), any(), any());
ArgumentCaptor<ComputationResult> computationResultCaptor =
ArgumentCaptor.forClass(ComputationResult.class);
@@ -360,7 +381,7 @@ public final class FederatedComputeWorkerTest {
"report result failed",
new IllegalStateException("http 404")))))
.when(mSpyHttpFederatedProtocol)
- .reportResult(any());
+ .reportResult(any(), any());
assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get());
@@ -387,7 +408,7 @@ public final class FederatedComputeWorkerTest {
anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL));
verify(mSpyWorker, times(0)).unbindFromExampleStoreService();
verify(mSpyHttpFederatedProtocol, times(1))
- .reportResult(computationResultCaptor.capture());
+ .reportResult(computationResultCaptor.capture(), any());
ComputationResult computationResult = computationResultCaptor.getValue();
assertNotNull(computationResult.getFlRunnerResult());
assertEquals(
@@ -429,7 +450,7 @@ public final class FederatedComputeWorkerTest {
setUpHttpFederatedProtocol(FA_CHECKIN_RESULT);
doReturn(FluentFuture.from(immediateFuture(null)))
.when(mSpyHttpFederatedProtocol)
- .reportResult(any());
+ .reportResult(any(), any());
when(mMockComputationRunner.runTaskWithNativeRunner(
anyString(),
anyString(),
@@ -452,7 +473,7 @@ public final class FederatedComputeWorkerTest {
anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL));
verify(mSpyWorker).unbindFromExampleStoreService();
verify(mSpyHttpFederatedProtocol, times(1))
- .reportResult(computationResultCaptor.capture());
+ .reportResult(computationResultCaptor.capture(), any());
ComputationResult computationResult = computationResultCaptor.getValue();
assertNotNull(computationResult.getFlRunnerResult());
assertEquals(
@@ -596,6 +617,22 @@ public final class FederatedComputeWorkerTest {
verify(mSpyWorker).unbindFromExampleStoreService();
}
+ @Test
+ public void testRunFLComputation_noKey_throws() throws Exception {
+ setUpHttpFederatedProtocol(FL_CHECKIN_RESULT);
+ doReturn(new ArrayList<FederatedComputeEncryptionKey>() {}).when(mMockKeyManager)
+ .getOrFetchActiveKeys(anyInt(), anyInt());
+ setUpHttpFederatedProtocol(FA_CHECKIN_RESULT);
+
+ assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get());
+
+ verify(mMockJobManager)
+ .onTrainingCompleted(
+ anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL));
+ verify(mSpyHttpFederatedProtocol).reportResult(
+ any(), eq(null));
+ }
+
private void setUpExampleStoreService() {
TestExampleStoreService testExampleStoreService = new TestExampleStoreService();
doReturn(testExampleStoreService).when(mSpyWorker).getExampleStoreService(anyString());
@@ -606,7 +643,7 @@ public final class FederatedComputeWorkerTest {
doReturn(immediateFuture(checkinResult)).when(mSpyHttpFederatedProtocol).issueCheckin();
doReturn(FluentFuture.from(immediateFuture(null)))
.when(mSpyHttpFederatedProtocol)
- .reportResult(any());
+ .reportResult(any(), any());
}
private static class TestExampleStoreService extends IExampleStoreService.Stub {