aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2023-12-01 09:21:15 +0000
committerAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2023-12-01 09:21:15 +0000
commit7f90ce3bc69345e8df192183e4b130ceec7e13aa (patch)
treec31e30fc7f909401abb3ccf92cd0db3b4b3cb1ca
parentb2e129469e2398760130fa2e2f6265c7e5390823 (diff)
parent15e3ef7b1a426964cf2c2e080a8cbbd02112f6d0 (diff)
downloadOnDevicePersonalization-android14-mainline-healthfitness-release.tar.gz
Snap for 11166233 from 15e3ef7b1a426964cf2c2e080a8cbbd02112f6d0 to mainline-healthfitness-releaseaml_hef_341415040android14-mainline-healthfitness-release
Change-Id: I376aa60d10b4077dae690afcf6c58b0ed2577684
-rw-r--r--federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java2
-rw-r--r--federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java45
-rw-r--r--tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java75
3 files changed, 107 insertions, 15 deletions
diff --git a/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java b/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java
index d7883542..1bae7ebd 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java
@@ -62,7 +62,7 @@ import java.util.concurrent.Callable;
/** Implements a single session of HTTP-based federated compute protocol. */
public final class HttpFederatedProtocol {
- public static final String TAG = "HttpFederatedProtocol";
+ public static final String TAG = HttpFederatedProtocol.class.getSimpleName();
private final String mClientVersion;
private final String mPopulationName;
private final HttpClient mHttpClient;
diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java b/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java
index ac54dde5..4244aff4 100644
--- a/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java
+++ b/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java
@@ -60,6 +60,7 @@ import com.android.internal.util.Preconditions;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.FluentFuture;
+import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
@@ -264,6 +265,10 @@ public class FederatedComputeWorker {
run.mTask.appPackageName(),
run.mTaskName,
getExampleSelector(checkinResult));
+ // report failure to server if getting iterator failed with any exception.
+ FutureCallback<Object> serverFailureReportCallback = getServerFailureReportCallback();
+ Futures.addCallback(
+ iteratorFuture, serverFailureReportCallback, getLightweightExecutor());
// 3. Run federated learning or federated analytic depends on task type. Federated
// learning job will start a new isolated process to run TFLite training.
@@ -272,6 +277,9 @@ public class FederatedComputeWorker {
.transformAsync(
iterator -> runFederatedComputation(checkinResult, run, iterator),
mInjector.getBgExecutor());
+ // report failure to server if computation failed with any exception.
+ computationResultFuture.addCallback(
+ serverFailureReportCallback, getLightweightExecutor());
// 4. Report computation result to federated compute server.
ListenableFuture<RejectionInfo> reportToServerFuture =
@@ -284,6 +292,8 @@ public class FederatedComputeWorker {
ComputationResult computationResult =
Futures.getDone(computationResultFuture);
RejectionInfo reportToServer = Futures.getDone(reportToServerFuture);
+ // report to Server will hold null in case of success, or rejection info
+ // in case server answered with rejection
if (reportToServer != null) {
ComputationResult failedReportComputationResult =
new ComputationResult(
@@ -320,6 +330,41 @@ public class FederatedComputeWorker {
mInjector.getBgExecutor());
}
+ @androidx.annotation.NonNull
+ private FutureCallback<Object> getServerFailureReportCallback() {
+ return new FutureCallback<Object>() {
+ volatile int mNumberOfInvocations = 0;
+
+ @Override
+ public void onSuccess(Object unused) {
+ // do nothing.
+ }
+
+ // We do not want race condition and repeating reporting failures from computation
+ // failed future in right before case Example Store iterator failed.
+ // Thus method is synchronised.
+ @Override
+ public synchronized void onFailure(Throwable throwable) {
+ if (mNumberOfInvocations < 1) {
+ LogUtil.d(
+ TAG,
+ "Training failed. Reporting failure result to server due to exception.",
+ throwable);
+ ComputationResult failedReportComputationResult =
+ new ComputationResult(
+ null,
+ FLRunnerResult.newBuilder()
+ .setContributionResult(ContributionResult.FAIL)
+ .setErrorMessage(throwable.getMessage())
+ .build(),
+ null);
+ var unused = mHttpFederatedProtocol.reportResult(failedReportComputationResult);
+ }
+ mNumberOfInvocations++;
+ }
+ };
+ }
+
private static TaskRetry buildTaskRetry(RejectionInfo rejectionInfo) {
TaskRetry.Builder taskRetryBuilder =
TaskRetry.newBuilder();
diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java
index acdd3c86..71fe9998 100644
--- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java
+++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java
@@ -189,7 +189,7 @@ public final class FederatedComputeWorkerTest {
@Mock FederatedComputeJobManager mMockJobManager;
private Context mContext;
private FederatedComputeWorker mSpyWorker;
- @Mock private HttpFederatedProtocol mMockHttpFederatedProtocol;
+ private HttpFederatedProtocol mSpyHttpFederatedProtocol;
@Mock private ComputationRunner mMockComputationRunner;
private ResultCallbackHelper mSpyResultCallbackHelper;
@@ -218,6 +218,9 @@ public final class FederatedComputeWorkerTest {
@Before
public void doBeforeEachTest() throws Exception {
mContext = ApplicationProvider.getApplicationContext();
+ mSpyHttpFederatedProtocol =
+ Mockito.spy(
+ HttpFederatedProtocol.create(SERVER_ADDRESS, "1.0.0.1", POPULATION_NAME));
mSpyResultCallbackHelper = Mockito.spy(new ResultCallbackHelper(mContext));
mSpyWorker =
Mockito.spy(
@@ -234,7 +237,7 @@ public final class FederatedComputeWorkerTest {
.when(mSpyResultCallbackHelper)
.callHandleResult(eq(TASK_NAME), any(), any());
when(mMockJobManager.onTrainingStarted(anyInt())).thenReturn(FEDERATED_TRAINING_TASK_1);
- doReturn(mMockHttpFederatedProtocol)
+ doReturn(mSpyHttpFederatedProtocol)
.when(mSpyWorker)
.getHttpFederatedProtocol(anyString(), anyString());
when(mMockComputationRunner.runTaskWithNativeRunner(
@@ -282,10 +285,10 @@ public final class FederatedComputeWorkerTest {
new ExecutionException(
"issue checkin failed",
new IllegalStateException("http 404"))))
- .when(mMockHttpFederatedProtocol)
+ .when(mSpyHttpFederatedProtocol)
.issueCheckin();
doReturn(FluentFuture.from(immediateFuture(null)))
- .when(mMockHttpFederatedProtocol)
+ .when(mSpyHttpFederatedProtocol)
.reportResult(any());
assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get());
@@ -301,7 +304,7 @@ public final class FederatedComputeWorkerTest {
setUpExampleStoreService();
doReturn(
immediateFuture(REJECTION_CHECKIN_RESULT))
- .when(mMockHttpFederatedProtocol)
+ .when(mSpyHttpFederatedProtocol)
.issueCheckin();
FLRunnerResult result = mSpyWorker.startTrainingRun(JOB_ID).get();
@@ -317,10 +320,10 @@ public final class FederatedComputeWorkerTest {
public void testReportResultWithRejection() throws Exception {
setUpExampleStoreService();
doReturn(immediateFuture(FA_CHECKIN_RESULT))
- .when(mMockHttpFederatedProtocol)
+ .when(mSpyHttpFederatedProtocol)
.issueCheckin();
doReturn(FluentFuture.from(immediateFuture(REJECTION_INFO)))
- .when(mMockHttpFederatedProtocol)
+ .when(mSpyHttpFederatedProtocol)
.reportResult(any());
doCallRealMethod().when(mSpyResultCallbackHelper).callHandleResult(any(), any(), any());
ArgumentCaptor<ComputationResult> computationResultCaptor =
@@ -348,7 +351,7 @@ public final class FederatedComputeWorkerTest {
setUpExampleStoreService();
doReturn(immediateFuture(FA_CHECKIN_RESULT))
- .when(mMockHttpFederatedProtocol)
+ .when(mSpyHttpFederatedProtocol)
.issueCheckin();
doReturn(
FluentFuture.from(
@@ -356,7 +359,7 @@ public final class FederatedComputeWorkerTest {
new ExecutionException(
"report result failed",
new IllegalStateException("http 404")))))
- .when(mMockHttpFederatedProtocol)
+ .when(mSpyHttpFederatedProtocol)
.reportResult(any());
assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get());
@@ -371,10 +374,10 @@ public final class FederatedComputeWorkerTest {
@Test
public void testBindToExampleStoreFails_throwsException() throws Exception {
setUpHttpFederatedProtocol(FL_CHECKIN_RESULT);
-
// Mock failure bind to ExampleStoreService.
doReturn(null).when(mSpyWorker).getExampleStoreService(anyString());
- doNothing().when(mSpyWorker).unbindFromExampleStoreService();
+ ArgumentCaptor<ComputationResult> computationResultCaptor =
+ ArgumentCaptor.forClass(ComputationResult.class);
assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get());
@@ -383,6 +386,13 @@ public final class FederatedComputeWorkerTest {
.onTrainingCompleted(
anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL));
verify(mSpyWorker, times(0)).unbindFromExampleStoreService();
+ verify(mSpyHttpFederatedProtocol, times(1))
+ .reportResult(computationResultCaptor.capture());
+ ComputationResult computationResult = computationResultCaptor.getValue();
+ assertNotNull(computationResult.getFlRunnerResult());
+ assertEquals(
+ ContributionResult.FAIL,
+ computationResult.getFlRunnerResult().getContributionResult());
}
@Test
@@ -414,6 +424,43 @@ public final class FederatedComputeWorkerTest {
}
@Test
+ public void testRunFAComputationThrows() throws Exception {
+ setUpExampleStoreService();
+ setUpHttpFederatedProtocol(FA_CHECKIN_RESULT);
+ doReturn(FluentFuture.from(immediateFuture(null)))
+ .when(mSpyHttpFederatedProtocol)
+ .reportResult(any());
+ when(mMockComputationRunner.runTaskWithNativeRunner(
+ anyString(),
+ anyString(),
+ anyString(),
+ anyString(),
+ any(),
+ any(),
+ any(),
+ any(),
+ any()))
+ .thenThrow(new RuntimeException("Test failures!"));
+ ArgumentCaptor<ComputationResult> computationResultCaptor =
+ ArgumentCaptor.forClass(ComputationResult.class);
+
+ assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get());
+
+ mSpyWorker.finish(null, ContributionResult.FAIL, false);
+ verify(mMockJobManager)
+ .onTrainingCompleted(
+ anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL));
+ verify(mSpyWorker).unbindFromExampleStoreService();
+ verify(mSpyHttpFederatedProtocol, times(1))
+ .reportResult(computationResultCaptor.capture());
+ ComputationResult computationResult = computationResultCaptor.getValue();
+ assertNotNull(computationResult.getFlRunnerResult());
+ assertEquals(
+ ContributionResult.FAIL,
+ computationResult.getFlRunnerResult().getContributionResult());
+ }
+
+ @Test
public void testPublishToResultHandlingServiceFails_returnsSuccess() throws Exception {
setUpExampleStoreService();
setUpHttpFederatedProtocol(FA_CHECKIN_RESULT);
@@ -475,7 +522,7 @@ public final class FederatedComputeWorkerTest {
@Test
public void testBindToIsolatedTrainingServiceFail_returnsFail() throws Exception {
doReturn(immediateFuture(FL_CHECKIN_RESULT))
- .when(mMockHttpFederatedProtocol)
+ .when(mSpyHttpFederatedProtocol)
.issueCheckin();
setUpExampleStoreService();
@@ -556,9 +603,9 @@ public final class FederatedComputeWorkerTest {
}
private void setUpHttpFederatedProtocol(CheckinResult checkinResult) {
- doReturn(immediateFuture(checkinResult)).when(mMockHttpFederatedProtocol).issueCheckin();
+ doReturn(immediateFuture(checkinResult)).when(mSpyHttpFederatedProtocol).issueCheckin();
doReturn(FluentFuture.from(immediateFuture(null)))
- .when(mMockHttpFederatedProtocol)
+ .when(mSpyHttpFederatedProtocol)
.reportResult(any());
}