aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--federatedcompute/src/com/android/federatedcompute/services/encryption/Encrypter.java32
-rw-r--r--federatedcompute/src/com/android/federatedcompute/services/encryption/HpkeJniEncrypter.java31
-rw-r--r--federatedcompute/src/com/android/federatedcompute/services/http/HttpClientUtil.java11
-rw-r--r--federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java69
-rw-r--r--federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java64
-rw-r--r--tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpFederatedProtocolTest.java42
-rw-r--r--tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java55
-rw-r--r--tests/perftests/scenarios/src/android/federatedcompute/test/scenario/federatedcompute/ScheduleNonExistentPopulationForTraining.java64
-rw-r--r--tests/perftests/scenarios/src/android/federatedcompute/test/scenario/federatedcompute/TestHelper.java37
-rw-r--r--tests/perftests/scenarios/tests/src/android/federatedcompute/test/scenario/federatedcompute/ScheduleNonExistentPopulationForTrainingMicrobenchmark.java36
10 files changed, 176 insertions, 265 deletions
diff --git a/federatedcompute/src/com/android/federatedcompute/services/encryption/Encrypter.java b/federatedcompute/src/com/android/federatedcompute/services/encryption/Encrypter.java
deleted file mode 100644
index 7398e1f5..00000000
--- a/federatedcompute/src/com/android/federatedcompute/services/encryption/Encrypter.java
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Copyright (C) 2023 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.federatedcompute.services.encryption;
-
-/** Interface for crypto libraries to encrypt data */
-public interface Encrypter {
-
- /**
- * encrypt {@code plainText} to cipher text {@code byte[]}.
- *
- * @param publicKey the public key used for encryption
- * @param plainText the plain text string to encrypt
- * @param associatedData additional data used for encryption
- * @return the encrypted ciphertext
- */
- byte[] encrypt(byte[] publicKey, byte[] plainText, byte[] associatedData);
-
-}
diff --git a/federatedcompute/src/com/android/federatedcompute/services/encryption/HpkeJniEncrypter.java b/federatedcompute/src/com/android/federatedcompute/services/encryption/HpkeJniEncrypter.java
deleted file mode 100644
index dd6040c1..00000000
--- a/federatedcompute/src/com/android/federatedcompute/services/encryption/HpkeJniEncrypter.java
+++ /dev/null
@@ -1,31 +0,0 @@
-/*
- * Copyright (C) 2023 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.federatedcompute.services.encryption;
-
-import com.android.federatedcompute.services.encryption.jni.HpkeJni;
-
-
-/**
- * The implementation of HPKE (Hybrid Public Key Encryption) using BoringSSL JNI.
- */
-public class HpkeJniEncrypter implements Encrypter {
-
- @Override
- public byte[] encrypt(byte[] publicKey, byte[] plainText, byte[] associatedData) {
- return HpkeJni.encrypt(publicKey, plainText, associatedData);
- }
-}
diff --git a/federatedcompute/src/com/android/federatedcompute/services/http/HttpClientUtil.java b/federatedcompute/src/com/android/federatedcompute/services/http/HttpClientUtil.java
index 749fbdc7..61d719e9 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/http/HttpClientUtil.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/http/HttpClientUtil.java
@@ -21,8 +21,6 @@ import com.android.federatedcompute.internal.util.LogUtil;
import com.google.common.collect.ImmutableSet;
import com.google.protobuf.ByteString;
-import org.json.JSONObject;
-
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
@@ -51,15 +49,6 @@ public final class HttpClientUtil {
POST,
PUT,
}
- public static final class FederatedComputePayloadDataContract {
- public static final String KEY_ID = "keyId";
-
- public static final String ENCRYPTED_PAYLOAD = "encryptedPayload";
-
- public static final String ASSOCIATED_DATA_KEY = "associatedData";
-
- public static final byte[] ASSOCIATED_DATA = new JSONObject().toString().getBytes();
- }
/** Compresses the input data using Gzip. */
public static byte[] compressWithGzip(byte[] uncompressedData) {
diff --git a/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java b/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java
index 208f2be9..00fdf1fb 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java
@@ -31,12 +31,8 @@ import static com.android.federatedcompute.services.http.HttpClientUtil.compress
import static com.android.federatedcompute.services.http.HttpClientUtil.uncompressWithGzip;
import android.os.Trace;
-import android.util.Base64;
import com.android.federatedcompute.internal.util.LogUtil;
-import com.android.federatedcompute.services.data.FederatedComputeEncryptionKey;
-import com.android.federatedcompute.services.encryption.Encrypter;
-import com.android.federatedcompute.services.http.HttpClientUtil.FederatedComputePayloadDataContract;
import com.android.federatedcompute.services.http.HttpClientUtil.HttpMethod;
import com.android.federatedcompute.services.training.util.ComputationResult;
@@ -60,44 +56,34 @@ import com.google.ondevicepersonalization.federatedcompute.proto.TaskAssignment;
import com.google.ondevicepersonalization.federatedcompute.proto.UploadInstruction;
import com.google.protobuf.InvalidProtocolBufferException;
-import org.json.JSONObject;
-
import java.util.HashMap;
import java.util.UUID;
import java.util.concurrent.Callable;
-
/** Implements a single session of HTTP-based federated compute protocol. */
public final class HttpFederatedProtocol {
public static final String TAG = HttpFederatedProtocol.class.getSimpleName();
private final String mClientVersion;
private final String mPopulationName;
private final HttpClient mHttpClient;
- private final ProtocolRequestCreator mTaskAssignmentRequestCreator;
- private final Encrypter mEncrypter;
private String mTaskId;
private String mAggregationId;
private String mAssignmentId;
+ private final ProtocolRequestCreator mTaskAssignmentRequestCreator;
@VisibleForTesting
HttpFederatedProtocol(
- String entryUri,
- String clientVersion,
- String populationName,
- HttpClient httpClient,
- Encrypter encrypter) {
+ String entryUri, String clientVersion, String populationName, HttpClient httpClient) {
this.mClientVersion = clientVersion;
this.mPopulationName = populationName;
this.mHttpClient = httpClient;
this.mTaskAssignmentRequestCreator = new ProtocolRequestCreator(entryUri, new HashMap<>());
- this.mEncrypter = encrypter;
}
/** Creates a HttpFederatedProtocol object. */
public static HttpFederatedProtocol create(
- String entryUri, String clientVersion, String populationName, Encrypter encrypter) {
- return new HttpFederatedProtocol(
- entryUri, clientVersion, populationName, new HttpClient(), encrypter);
+ String entryUri, String clientVersion, String populationName) {
+ return new HttpFederatedProtocol(entryUri, clientVersion, populationName, new HttpClient());
}
/** Helper function to perform check in and download federated task from remote servers. */
@@ -147,12 +133,9 @@ public final class HttpFederatedProtocol {
}
/** Helper functions to reporting result and upload result. */
- public FluentFuture<RejectionInfo> reportResult(
- ComputationResult computationResult, FederatedComputeEncryptionKey encryptionKey) {
+ public FluentFuture<RejectionInfo> reportResult(ComputationResult computationResult) {
Trace.beginAsyncSection(TRACE_HTTP_REPORT_RESULT, 0);
- if (computationResult != null
- && computationResult.isResultSuccess()
- && encryptionKey != null) {
+ if (computationResult != null && computationResult.isResultSuccess()) {
return FluentFuture.from(performReportResult(computationResult))
.transformAsync(
reportResp -> {
@@ -164,9 +147,7 @@ public final class HttpFederatedProtocol {
}
return FluentFuture.from(
processReportResultResponseAndUploadResult(
- reportResultResponse,
- computationResult,
- encryptionKey))
+ reportResultResponse, computationResult))
.transform(
resp -> {
validateHttpResponseStatus(
@@ -323,9 +304,8 @@ public final class HttpFederatedProtocol {
private ListenableFuture<FederatedComputeHttpResponse>
processReportResultResponseAndUploadResult(
- ReportResultResponse reportResultResponse,
- ComputationResult computationResult,
- FederatedComputeEncryptionKey encryptionKey) {
+ ReportResultResponse reportResultResponse,
+ ComputationResult computationResult) {
try {
Preconditions.checkArgument(
!computationResult.getOutputCheckpointFile().isEmpty(),
@@ -334,11 +314,7 @@ public final class HttpFederatedProtocol {
Preconditions.checkArgument(
!uploadInstruction.getUploadLocation().isEmpty(),
"UploadInstruction.upload_location must not be empty");
- byte[] outputBytes =
- createEncryptedRequestBody(
- computationResult.getOutputCheckpointFile(),
- encryptionKey);
- // Apply a top-level compression to the payload.
+ byte[] outputBytes = readFileAsByteArray(computationResult.getOutputCheckpointFile());
if (uploadInstruction.getCompressionFormat()
== ResourceCompressionFormat.RESOURCE_COMPRESSION_FORMAT_GZIP) {
outputBytes = compressWithGzip(outputBytes);
@@ -369,31 +345,6 @@ public final class HttpFederatedProtocol {
}
}
- private byte[] createEncryptedRequestBody(
- String filePath,
- FederatedComputeEncryptionKey encryptionKey)
- throws Exception {
- byte[] fileOutputBytes = readFileAsByteArray(filePath);
- fileOutputBytes = compressWithGzip(fileOutputBytes);
- // encryption
- byte[] publicKey = Base64.decode(encryptionKey.getPublicKey(), Base64.NO_WRAP);
-
- byte[] encryptedOutput =
- mEncrypter.encrypt(
- publicKey, fileOutputBytes,
- FederatedComputePayloadDataContract.ASSOCIATED_DATA);
- // create payload
- final JSONObject body = new JSONObject();
- body.put(FederatedComputePayloadDataContract.KEY_ID,
- encryptionKey.getKeyIdentifier());
- body.put(FederatedComputePayloadDataContract.ENCRYPTED_PAYLOAD,
- Base64.encodeToString(encryptedOutput, Base64.NO_WRAP));
- body.put(FederatedComputePayloadDataContract.ASSOCIATED_DATA_KEY,
- Base64.encodeToString(FederatedComputePayloadDataContract.ASSOCIATED_DATA,
- Base64.NO_WRAP));
- return body.toString().getBytes();
- }
-
private ReportResultResponse getReportResultResponse(FederatedComputeHttpResponse httpResponse)
throws InvalidProtocolBufferException {
validateHttpResponseStatus("ReportResult", httpResponse);
diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java b/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java
index cb13d87b..a6583bab 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java
@@ -43,11 +43,8 @@ import com.android.federatedcompute.internal.util.AbstractServiceBinder;
import com.android.federatedcompute.internal.util.LogUtil;
import com.android.federatedcompute.services.common.Constants;
import com.android.federatedcompute.services.common.FileUtils;
-import com.android.federatedcompute.services.data.FederatedComputeEncryptionKey;
import com.android.federatedcompute.services.data.FederatedTrainingTask;
import com.android.federatedcompute.services.data.fbs.TrainingConstraints;
-import com.android.federatedcompute.services.encryption.FederatedComputeEncryptionKeyManager;
-import com.android.federatedcompute.services.encryption.HpkeJniEncrypter;
import com.android.federatedcompute.services.examplestore.ExampleConsumptionRecorder;
import com.android.federatedcompute.services.http.CheckinResult;
import com.android.federatedcompute.services.http.HttpFederatedProtocol;
@@ -80,9 +77,7 @@ import com.google.protobuf.InvalidProtocolBufferException;
import java.io.IOException;
import java.util.ArrayList;
-import java.util.List;
import java.util.Objects;
-import java.util.Random;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
@@ -109,10 +104,6 @@ public class FederatedComputeWorker {
private AbstractServiceBinder<IExampleStoreService> mExampleStoreServiceBinder;
private AbstractServiceBinder<IIsolatedTrainingService> mIsolatedTrainingServiceBinder;
- private FederatedComputeEncryptionKeyManager mEncryptionKeyManager;
-
- private static final int NUM_ACTIVE_KEYS_TO_CHOOSE_FROM = 5;
-
@VisibleForTesting
public FederatedComputeWorker(
Context context,
@@ -120,15 +111,13 @@ public class FederatedComputeWorker {
TrainingConditionsChecker trainingConditionsChecker,
ComputationRunner computationRunner,
ResultCallbackHelper resultCallbackHelper,
- Injector injector,
- FederatedComputeEncryptionKeyManager keyManager) {
+ Injector injector) {
this.mContext = context.getApplicationContext();
this.mJobManager = jobManager;
this.mTrainingConditionsChecker = trainingConditionsChecker;
this.mComputationRunner = computationRunner;
this.mInjector = injector;
this.mResultCallbackHelper = resultCallbackHelper;
- this.mEncryptionKeyManager = keyManager;
}
/** Gets an instance of {@link FederatedComputeWorker}. */
@@ -144,8 +133,7 @@ public class FederatedComputeWorker {
TrainingConditionsChecker.getInstance(context),
new ComputationRunner(),
new ResultCallbackHelper(context),
- new Injector(),
- FederatedComputeEncryptionKeyManager.getInstance(context));
+ new Injector());
}
}
}
@@ -256,6 +244,8 @@ public class FederatedComputeWorker {
TrainingRun run, CheckinResult checkinResult) {
// Stop processing if have rejection Info
if (checkinResult.getRejectionInfo() != null) {
+ LogUtil.d(TAG, "job %d was rejected during check in, reason %s",
+ run.mTask.jobId(), checkinResult.getRejectionInfo().getReason());
mJobManager.onTrainingCompleted(
run.mTask.jobId(),
run.mTask.populationName(),
@@ -264,40 +254,13 @@ public class FederatedComputeWorker {
ContributionResult.FAIL);
return Futures.immediateFuture(null);
}
-
+ // 2. Bind to client app implemented ExampleStoreService based on ExampleSelector.
+ // Set active run's task name.
String taskName = checkinResult.getTaskAssignment().getTaskName();
Preconditions.checkArgument(!taskName.isEmpty(), "Task name should not be empty");
synchronized (mLock) {
mActiveRun.mTaskName = taskName;
}
- // 2. Fetch Active keys to encrypt the computation result.
- List<FederatedComputeEncryptionKey> activeKeys = mEncryptionKeyManager
- .getOrFetchActiveKeys(NUM_ACTIVE_KEYS_TO_CHOOSE_FROM, FederatedComputeEncryptionKey
- .KEY_TYPE_ENCRYPTION);
- // select a random key
- FederatedComputeEncryptionKey encryptionKey = activeKeys.isEmpty() ? null :
- activeKeys.get(new Random().nextInt(activeKeys.size()));
- if (encryptionKey == null) {
- // no active keys to encrypt the FL/FA computation results, stop the computation run.
- ComputationResult failedComputationResult = new ComputationResult(
- null,
- FLRunnerResult.newBuilder()
- .setContributionResult(ContributionResult.FAIL)
- .setErrorMessage("No active key available on device.")
- .build(), null);
- mJobManager.onTrainingCompleted(
- run.mJobId,
- run.mTask.populationName(),
- run.mTask.getTrainingIntervalOptions(),
- buildTaskRetry(RejectionInfo.newBuilder().build()),
- ContributionResult.FAIL);
- var unused = mHttpFederatedProtocol.reportResult(failedComputationResult, null);
- return Futures.immediateFailedFuture(
- new IllegalStateException("No active key available on device."));
- }
-
- // 3. Bind to client app implemented ExampleStoreService based on ExampleSelector.
- // Set active run's task name.
ListenableFuture<IExampleStoreIterator> iteratorFuture =
getExampleStoreIterator(
run,
@@ -309,7 +272,7 @@ public class FederatedComputeWorker {
Futures.addCallback(
iteratorFuture, serverFailureReportCallback, getLightweightExecutor());
- // 4. Run federated learning or federated analytic depends on task type. Federated
+ // 3. Run federated learning or federated analytic depends on task type. Federated
// learning job will start a new isolated process to run TFLite training.
FluentFuture<ComputationResult> computationResultFuture =
FluentFuture.from(iteratorFuture)
@@ -320,11 +283,10 @@ public class FederatedComputeWorker {
computationResultFuture.addCallback(
serverFailureReportCallback, getLightweightExecutor());
-
- // 5. Report computation result to federated compute server.
+ // 4. Report computation result to federated compute server.
ListenableFuture<RejectionInfo> reportToServerFuture =
computationResultFuture.transformAsync(
- result -> mHttpFederatedProtocol.reportResult(result, encryptionKey),
+ result -> mHttpFederatedProtocol.reportResult(result),
getLightweightExecutor());
return Futures.whenAllSucceed(reportToServerFuture, computationResultFuture)
.call(
@@ -359,7 +321,7 @@ public class FederatedComputeWorker {
ContributionResult.FAIL);
return null;
}
- // 6. Publish computation result and consumed
+ // 5. Publish computation result and consumed
// examples to client implemented
// ResultHandlingService.
var unused =
@@ -398,8 +360,7 @@ public class FederatedComputeWorker {
.setErrorMessage(throwable.getMessage())
.build(),
null);
- var unused = mHttpFederatedProtocol.reportResult(failedReportComputationResult,
- null);
+ var unused = mHttpFederatedProtocol.reportResult(failedReportComputationResult);
}
mNumberOfInvocations++;
}
@@ -492,8 +453,7 @@ public class FederatedComputeWorker {
@VisibleForTesting
HttpFederatedProtocol getHttpFederatedProtocol(String serverAddress, String populationName) {
- return HttpFederatedProtocol.create(serverAddress, "1.0.0.1", populationName,
- new HpkeJniEncrypter());
+ return HttpFederatedProtocol.create(serverAddress, "1.0.0.1", populationName);
}
private ExampleSelector getExampleSelector(CheckinResult checkinResult) {
diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpFederatedProtocolTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpFederatedProtocolTest.java
index 48c35f9f..e483cf22 100644
--- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpFederatedProtocolTest.java
+++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/http/HttpFederatedProtocolTest.java
@@ -43,16 +43,12 @@ import android.net.Uri;
import androidx.test.core.app.ApplicationProvider;
-import com.android.federatedcompute.services.data.FederatedComputeEncryptionKey;
-import com.android.federatedcompute.services.encryption.HpkeJniEncrypter;
import com.android.federatedcompute.services.http.HttpClientUtil.HttpMethod;
import com.android.federatedcompute.services.testutils.TrainingTestUtil;
import com.android.federatedcompute.services.training.util.ComputationResult;
-import com.google.common.collect.BoundType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
-import com.google.common.collect.Range;
import com.google.intelligence.fcp.client.FLRunnerResult;
import com.google.intelligence.fcp.client.FLRunnerResult.ContributionResult;
import com.google.internal.federatedcompute.v1.ClientVersion;
@@ -105,14 +101,6 @@ public final class HttpFederatedProtocolTest {
private static final String ASSIGNMENT_ID = "assignment-id";
private static final String AGGREGATION_ID = "aggregation-id";
private static final String OCTET_STREAM = "application/octet-stream";
-
- private static final FederatedComputeEncryptionKey ENCRYPTION_KEY =
- new FederatedComputeEncryptionKey.Builder()
- .setPublicKey("rSJBSUYG0ebvfW1AXCWO0CMGMJhDzpfQm3eLyw1uxX8=")
- .setKeyIdentifier("0962201a-5abd-4e25-a486-2c7bd1ee1887")
- .setKeyType(FederatedComputeEncryptionKey.KEY_TYPE_ENCRYPTION)
- .setCreationTime(1L)
- .setExpiryTime(1L).build();
private static final FLRunnerResult FL_RUNNER_SUCCESS_RESULT =
FLRunnerResult.newBuilder().setContributionResult(ContributionResult.SUCCESS).build();
@@ -156,8 +144,7 @@ public final class HttpFederatedProtocolTest {
TASK_ASSIGNMENT_TARGET_URI,
CLIENT_VERSION,
POPULATION_NAME,
- mMockHttpClient,
- new HpkeJniEncrypter());
+ mMockHttpClient);
}
@Test
@@ -293,7 +280,7 @@ public final class HttpFederatedProtocolTest {
// Setup task id, aggregation id for report result.
mHttpFederatedProtocol.issueCheckin().get();
- mHttpFederatedProtocol.reportResult(computationResult, ENCRYPTION_KEY).get();
+ mHttpFederatedProtocol.reportResult(computationResult).get();
// Verify ReportResult request.
List<FederatedComputeHttpRequest> actualHttpRequests = mHttpRequestCaptor.getAllValues();
@@ -314,7 +301,7 @@ public final class HttpFederatedProtocolTest {
// Setup task id, aggregation id for report result.
mHttpFederatedProtocol.issueCheckin().get();
- mHttpFederatedProtocol.reportResult(computationResult, ENCRYPTION_KEY).get();
+ mHttpFederatedProtocol.reportResult(computationResult).get();
// Verify ReportResult request.
List<FederatedComputeHttpRequest> actualHttpRequests = mHttpRequestCaptor.getAllValues();
@@ -335,23 +322,14 @@ public final class HttpFederatedProtocolTest {
assertThat(actualDataUploadRequest.getUri()).isEqualTo(UPLOAD_LOCATION_URI);
assertThat(acutalReportResultRequest.getHttpMethod()).isEqualTo(HttpMethod.PUT);
expectedHeaders = new HashMap<>();
- expectedHeaders.put(CONTENT_TYPE_HDR, OCTET_STREAM);
if (mSupportCompression) {
+ expectedHeaders.put(CONTENT_LENGTH_HDR, String.valueOf(339));
expectedHeaders.put(CONTENT_ENCODING_HDR, GZIP_ENCODING_HDR);
- }
-
- int actualContentLength = Integer
- .parseInt(actualDataUploadRequest.getExtraHeaders().remove(CONTENT_LENGTH_HDR));
- assertThat(actualDataUploadRequest.getExtraHeaders()).isEqualTo(expectedHeaders);
- // The encryption is non-deterministic with BoringSSL JNI.
- // Only check the range of the content.
- if (mSupportCompression) {
- assertThat(actualContentLength)
- .isIn(Range.range(500, BoundType.CLOSED, 550, BoundType.CLOSED));
} else {
- assertThat(actualContentLength)
- .isIn(Range.range(600, BoundType.CLOSED, 650, BoundType.CLOSED));
+ expectedHeaders.put(CONTENT_LENGTH_HDR, String.valueOf(31846));
}
+ expectedHeaders.put(CONTENT_TYPE_HDR, OCTET_STREAM);
+ assertThat(actualDataUploadRequest.getExtraHeaders()).isEqualTo(expectedHeaders);
}
@Test
@@ -372,8 +350,7 @@ public final class HttpFederatedProtocolTest {
ExecutionException exception =
assertThrows(
ExecutionException.class,
- () -> mHttpFederatedProtocol.reportResult(computationResult,
- ENCRYPTION_KEY).get());
+ () -> mHttpFederatedProtocol.reportResult(computationResult).get());
assertThat(exception.getCause()).isInstanceOf(IllegalStateException.class);
assertThat(exception.getCause()).hasMessageThat().isEqualTo("ReportResult failed: 503");
@@ -397,8 +374,7 @@ public final class HttpFederatedProtocolTest {
ExecutionException exception =
assertThrows(
ExecutionException.class,
- () -> mHttpFederatedProtocol.reportResult(computationResult,
- ENCRYPTION_KEY).get());
+ () -> mHttpFederatedProtocol.reportResult(computationResult).get());
assertThat(exception).hasCauseThat().isInstanceOf(IllegalStateException.class);
assertThat(exception.getCause()).hasMessageThat().isEqualTo("Upload result failed: 503");
diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java
index 8f98c6a0..71fe9998 100644
--- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java
+++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java
@@ -48,14 +48,11 @@ import android.os.RemoteException;
import androidx.test.core.app.ApplicationProvider;
import com.android.federatedcompute.services.common.Constants;
-import com.android.federatedcompute.services.data.FederatedComputeEncryptionKey;
import com.android.federatedcompute.services.data.FederatedTrainingTask;
import com.android.federatedcompute.services.data.fbs.SchedulingMode;
import com.android.federatedcompute.services.data.fbs.SchedulingReason;
import com.android.federatedcompute.services.data.fbs.TrainingConstraints;
import com.android.federatedcompute.services.data.fbs.TrainingIntervalOptions;
-import com.android.federatedcompute.services.encryption.FederatedComputeEncryptionKeyManager;
-import com.android.federatedcompute.services.encryption.HpkeJniEncrypter;
import com.android.federatedcompute.services.examplestore.ExampleConsumptionRecorder;
import com.android.federatedcompute.services.http.CheckinResult;
import com.android.federatedcompute.services.http.HttpFederatedProtocol;
@@ -107,7 +104,6 @@ import org.tensorflow.example.Features;
import java.util.ArrayList;
import java.util.EnumSet;
-import java.util.List;
import java.util.concurrent.ExecutionException;
@RunWith(JUnit4.class)
@@ -197,16 +193,6 @@ public final class FederatedComputeWorkerTest {
@Mock private ComputationRunner mMockComputationRunner;
private ResultCallbackHelper mSpyResultCallbackHelper;
- @Mock private FederatedComputeEncryptionKeyManager mMockKeyManager;
-
- private static final FederatedComputeEncryptionKey ENCRYPTION_KEY =
- new FederatedComputeEncryptionKey.Builder()
- .setPublicKey("rSJBSUYG0ebvfW1AXCWO0CMGMJhDzpfQm3eLyw1uxX8=")
- .setKeyIdentifier("0962201a-5abd-4e25-a486-2c7bd1ee1887")
- .setKeyType(FederatedComputeEncryptionKey.KEY_TYPE_ENCRYPTION)
- .setCreationTime(1L)
- .setExpiryTime(1L).build();
-
private static byte[] createTrainingConstraints(
boolean requiresSchedulerIdle,
boolean requiresSchedulerBatteryNotLow,
@@ -234,11 +220,7 @@ public final class FederatedComputeWorkerTest {
mContext = ApplicationProvider.getApplicationContext();
mSpyHttpFederatedProtocol =
Mockito.spy(
- HttpFederatedProtocol.create(
- SERVER_ADDRESS,
- "1.0.0.1",
- POPULATION_NAME,
- new HpkeJniEncrypter()));
+ HttpFederatedProtocol.create(SERVER_ADDRESS, "1.0.0.1", POPULATION_NAME));
mSpyResultCallbackHelper = Mockito.spy(new ResultCallbackHelper(mContext));
mSpyWorker =
Mockito.spy(
@@ -248,8 +230,7 @@ public final class FederatedComputeWorkerTest {
mTrainingConditionsChecker,
mMockComputationRunner,
mSpyResultCallbackHelper,
- new TestInjector(),
- mMockKeyManager));
+ new TestInjector()));
when(mTrainingConditionsChecker.checkAllConditionsForFlTraining(any()))
.thenReturn(EnumSet.noneOf(Condition.class));
doReturn(Futures.immediateFuture(CallbackResult.SUCCESS))
@@ -270,8 +251,6 @@ public final class FederatedComputeWorkerTest {
any(),
any()))
.thenReturn(FL_RUNNER_SUCCESS_RESULT);
- doReturn(List.of(ENCRYPTION_KEY)).when(mMockKeyManager).getOrFetchActiveKeys(anyInt(),
- anyInt());
}
@Test
@@ -310,7 +289,7 @@ public final class FederatedComputeWorkerTest {
.issueCheckin();
doReturn(FluentFuture.from(immediateFuture(null)))
.when(mSpyHttpFederatedProtocol)
- .reportResult(any(), any());
+ .reportResult(any());
assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get());
@@ -345,7 +324,7 @@ public final class FederatedComputeWorkerTest {
.issueCheckin();
doReturn(FluentFuture.from(immediateFuture(REJECTION_INFO)))
.when(mSpyHttpFederatedProtocol)
- .reportResult(any(), any());
+ .reportResult(any());
doCallRealMethod().when(mSpyResultCallbackHelper).callHandleResult(any(), any(), any());
ArgumentCaptor<ComputationResult> computationResultCaptor =
ArgumentCaptor.forClass(ComputationResult.class);
@@ -381,7 +360,7 @@ public final class FederatedComputeWorkerTest {
"report result failed",
new IllegalStateException("http 404")))))
.when(mSpyHttpFederatedProtocol)
- .reportResult(any(), any());
+ .reportResult(any());
assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get());
@@ -408,7 +387,7 @@ public final class FederatedComputeWorkerTest {
anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL));
verify(mSpyWorker, times(0)).unbindFromExampleStoreService();
verify(mSpyHttpFederatedProtocol, times(1))
- .reportResult(computationResultCaptor.capture(), any());
+ .reportResult(computationResultCaptor.capture());
ComputationResult computationResult = computationResultCaptor.getValue();
assertNotNull(computationResult.getFlRunnerResult());
assertEquals(
@@ -450,7 +429,7 @@ public final class FederatedComputeWorkerTest {
setUpHttpFederatedProtocol(FA_CHECKIN_RESULT);
doReturn(FluentFuture.from(immediateFuture(null)))
.when(mSpyHttpFederatedProtocol)
- .reportResult(any(), any());
+ .reportResult(any());
when(mMockComputationRunner.runTaskWithNativeRunner(
anyString(),
anyString(),
@@ -473,7 +452,7 @@ public final class FederatedComputeWorkerTest {
anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL));
verify(mSpyWorker).unbindFromExampleStoreService();
verify(mSpyHttpFederatedProtocol, times(1))
- .reportResult(computationResultCaptor.capture(), any());
+ .reportResult(computationResultCaptor.capture());
ComputationResult computationResult = computationResultCaptor.getValue();
assertNotNull(computationResult.getFlRunnerResult());
assertEquals(
@@ -617,22 +596,6 @@ public final class FederatedComputeWorkerTest {
verify(mSpyWorker).unbindFromExampleStoreService();
}
- @Test
- public void testRunFLComputation_noKey_throws() throws Exception {
- setUpHttpFederatedProtocol(FL_CHECKIN_RESULT);
- doReturn(new ArrayList<FederatedComputeEncryptionKey>() {}).when(mMockKeyManager)
- .getOrFetchActiveKeys(anyInt(), anyInt());
- setUpHttpFederatedProtocol(FA_CHECKIN_RESULT);
-
- assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get());
-
- verify(mMockJobManager)
- .onTrainingCompleted(
- anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL));
- verify(mSpyHttpFederatedProtocol).reportResult(
- any(), eq(null));
- }
-
private void setUpExampleStoreService() {
TestExampleStoreService testExampleStoreService = new TestExampleStoreService();
doReturn(testExampleStoreService).when(mSpyWorker).getExampleStoreService(anyString());
@@ -643,7 +606,7 @@ public final class FederatedComputeWorkerTest {
doReturn(immediateFuture(checkinResult)).when(mSpyHttpFederatedProtocol).issueCheckin();
doReturn(FluentFuture.from(immediateFuture(null)))
.when(mSpyHttpFederatedProtocol)
- .reportResult(any(), any());
+ .reportResult(any());
}
private static class TestExampleStoreService extends IExampleStoreService.Stub {
diff --git a/tests/perftests/scenarios/src/android/federatedcompute/test/scenario/federatedcompute/ScheduleNonExistentPopulationForTraining.java b/tests/perftests/scenarios/src/android/federatedcompute/test/scenario/federatedcompute/ScheduleNonExistentPopulationForTraining.java
new file mode 100644
index 00000000..27ee61c7
--- /dev/null
+++ b/tests/perftests/scenarios/src/android/federatedcompute/test/scenario/federatedcompute/ScheduleNonExistentPopulationForTraining.java
@@ -0,0 +1,64 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package android.federatedcompute.test.scenario.federatedcompute;
+
+import android.platform.test.scenario.annotation.Scenario;
+
+import org.junit.AfterClass;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.io.IOException;
+
+@Scenario
+@RunWith(JUnit4.class)
+/**
+ * Schedule a non-existent population training task from Odp Test app UI
+ * Force the task execution through ADB commands and verify error handling and exit behavior
+ */
+public class ScheduleNonExistentPopulationForTraining {
+ private TestHelper mTestHelper = new TestHelper();
+
+ /** Prepare the device before entering the test class */
+ @BeforeClass
+ public static void prepareDevice() throws IOException {
+ TestHelper.initialize();
+ TestHelper.killRunningProcess();
+ }
+
+ @Before
+ public void setup() throws IOException {
+ mTestHelper.pressHome();
+ mTestHelper.openTestApp();
+ mTestHelper.inputNonExistentPopulationForScheduleTraining();
+ }
+
+ @Test
+ public void testScheduleNonExistentPopulationForTraining() throws IOException {
+ mTestHelper.clickScheduleTraining();
+ mTestHelper.forceExecuteTrainingForNonExistentPopulation();
+ }
+
+ /** Return device to original state after test exeuction */
+ @AfterClass
+ public static void tearDown() throws IOException {
+ TestHelper.pressHome();
+ TestHelper.wrapUp();
+ }
+}
diff --git a/tests/perftests/scenarios/src/android/federatedcompute/test/scenario/federatedcompute/TestHelper.java b/tests/perftests/scenarios/src/android/federatedcompute/test/scenario/federatedcompute/TestHelper.java
index 07ec1f87..4c135533 100644
--- a/tests/perftests/scenarios/src/android/federatedcompute/test/scenario/federatedcompute/TestHelper.java
+++ b/tests/perftests/scenarios/src/android/federatedcompute/test/scenario/federatedcompute/TestHelper.java
@@ -34,6 +34,7 @@ public class TestHelper {
private static UiDevice sUiDevice;
private static final long UI_FIND_RESOURCE_TIMEOUT = 5000;
private static final long TRAINING_TASK_COMPLETION_TIMEOUT = 120_000;
+ private static final long CHECKIN_REJECTION_COMPLETION_TIMEOUT = 20_000;
private static final String ODP_CLIENT_TEST_APP_PACKAGE_NAME = "com.example.odpclient";
private static final String SCHEDULE_TRAINING_BUTTON_RESOURCE_ID = "schedule_training_button";
private static final String SCHEDULE_TRAINING_TEXT_BOX_RESOURCE_ID =
@@ -42,6 +43,10 @@ public class TestHelper {
private static final String ODP_TEST_APP_TRAINING_TASK_JOB_ID = "-630817781";
private static final String FEDERATED_TRAINING_JOB_SUCCESS_LOG =
"FederatedJobService - Federated computation job -630817781 is done";
+ private static final String NON_EXISTENT_POPULATION_NAME = "test_non_existent_population";
+ private static final String NON_EXISTENT_POPULATION_TASK_JOB_ID = "1892833995";
+ private static final String NON_EXISTENT_POPULATION_JOB_FAILURE_LOG =
+ "job 1892833995 was rejected during check in, reason NO_TASK_AVAILABLE";
public static void pressHome() {
getUiDevice().pressHome();
@@ -119,6 +124,13 @@ public class TestHelper {
scheduleTrainingTextBox.setText(ODP_TEST_APP_POPULATION_NAME);
}
+ /** Put a test non existent population name down for training */
+ public void inputNonExistentPopulationForScheduleTraining() {
+ UiObject2 scheduleTrainingTextBox = getScheduleTrainingTextBox();
+ assertNotNull("Schedule Training text box not found", scheduleTrainingTextBox);
+ scheduleTrainingTextBox.setText(NON_EXISTENT_POPULATION_NAME);
+ }
+
/** Click Schedule Training button. */
public void clickScheduleTraining() {
UiObject2 scheduleTrainingButton = getScheduleTrainingButton();
@@ -143,11 +155,34 @@ public class TestHelper {
if (!foundTrainingJobSuccessLog) {
Assert.fail(String.format(
- "Failed to find federated training job success log within test window %d ms",
+ "Failed to find federated training job success log %s within test window %d ms",
+ FEDERATED_TRAINING_JOB_SUCCESS_LOG,
TRAINING_TASK_COMPLETION_TIMEOUT));
}
}
+ /** Force the JobScheduler to execute the training task for non existent population */
+ public void forceExecuteTrainingForNonExistentPopulation() throws IOException {
+ executeShellCommand("logcat -c"); // Cleans the log buffer
+ executeShellCommand("logcat -G 32M"); // Set log buffer to 32MB
+ executeShellCommand(
+ "cmd jobscheduler run -f com.google.android.federatedcompute "
+ + NON_EXISTENT_POPULATION_TASK_JOB_ID);
+ SystemClock.sleep(10000);
+
+ boolean foundTrainingFailureLog = findLog(
+ NON_EXISTENT_POPULATION_JOB_FAILURE_LOG,
+ CHECKIN_REJECTION_COMPLETION_TIMEOUT,
+ 5000);
+
+ if (!foundTrainingFailureLog) {
+ Assert.fail(String.format(
+ "Failed to find federated training failure log: %s within test window %d ms",
+ NON_EXISTENT_POPULATION_JOB_FAILURE_LOG,
+ CHECKIN_REJECTION_COMPLETION_TIMEOUT));
+ }
+ }
+
/** Attempt to find a specific log entry within the timeout window */
private boolean findLog(final String targetLog, long timeoutMillis,
long queryIntervalMillis) throws IOException {
diff --git a/tests/perftests/scenarios/tests/src/android/federatedcompute/test/scenario/federatedcompute/ScheduleNonExistentPopulationForTrainingMicrobenchmark.java b/tests/perftests/scenarios/tests/src/android/federatedcompute/test/scenario/federatedcompute/ScheduleNonExistentPopulationForTrainingMicrobenchmark.java
new file mode 100644
index 00000000..8150bf3d
--- /dev/null
+++ b/tests/perftests/scenarios/tests/src/android/federatedcompute/test/scenario/federatedcompute/ScheduleNonExistentPopulationForTrainingMicrobenchmark.java
@@ -0,0 +1,36 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.federatedcompute.test.scenario.federatedcompute;
+
+import android.platform.test.microbenchmark.Microbenchmark;
+import android.platform.test.rule.DropCachesRule;
+import android.platform.test.rule.KillAppsRule;
+import android.platform.test.rule.PressHomeRule;
+
+import org.junit.Rule;
+import org.junit.rules.RuleChain;
+import org.junit.runner.RunWith;
+
+@RunWith(Microbenchmark.class)
+public class ScheduleNonExistentPopulationForTrainingMicrobenchmark
+ extends ScheduleNonExistentPopulationForTraining {
+
+ @Rule
+ public RuleChain rules = RuleChain.outerRule(new DropCachesRule())
+ .around(new KillAppsRule("com.example.odpclient"))
+ .around(new PressHomeRule());
+}