diff options
author | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2023-12-01 09:21:15 +0000 |
---|---|---|
committer | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2023-12-01 09:21:15 +0000 |
commit | 7f90ce3bc69345e8df192183e4b130ceec7e13aa (patch) | |
tree | c31e30fc7f909401abb3ccf92cd0db3b4b3cb1ca | |
parent | b2e129469e2398760130fa2e2f6265c7e5390823 (diff) | |
parent | 15e3ef7b1a426964cf2c2e080a8cbbd02112f6d0 (diff) | |
download | OnDevicePersonalization-android14-mainline-healthfitness-release.tar.gz |
Snap for 11166233 from 15e3ef7b1a426964cf2c2e080a8cbbd02112f6d0 to mainline-healthfitness-release [tags: aml_hef_341415040, android14-mainline-healthfitness-release]
Change-Id: I376aa60d10b4077dae690afcf6c58b0ed2577684
3 files changed, 107 insertions, 15 deletions
diff --git a/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java b/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java index d7883542..1bae7ebd 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java +++ b/federatedcompute/src/com/android/federatedcompute/services/http/HttpFederatedProtocol.java @@ -62,7 +62,7 @@ import java.util.concurrent.Callable; /** Implements a single session of HTTP-based federated compute protocol. */ public final class HttpFederatedProtocol { - public static final String TAG = "HttpFederatedProtocol"; + public static final String TAG = HttpFederatedProtocol.class.getSimpleName(); private final String mClientVersion; private final String mPopulationName; private final HttpClient mHttpClient; diff --git a/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java b/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java index ac54dde5..4244aff4 100644 --- a/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java +++ b/federatedcompute/src/com/android/federatedcompute/services/training/FederatedComputeWorker.java @@ -60,6 +60,7 @@ import com.android.internal.util.Preconditions; import com.google.common.annotations.VisibleForTesting; import com.google.common.util.concurrent.FluentFuture; +import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; @@ -264,6 +265,10 @@ public class FederatedComputeWorker { run.mTask.appPackageName(), run.mTaskName, getExampleSelector(checkinResult)); + // report failure to server if getting iterator failed with any exception. 
+ FutureCallback<Object> serverFailureReportCallback = getServerFailureReportCallback(); + Futures.addCallback( + iteratorFuture, serverFailureReportCallback, getLightweightExecutor()); // 3. Run federated learning or federated analytic depends on task type. Federated // learning job will start a new isolated process to run TFLite training. @@ -272,6 +277,9 @@ public class FederatedComputeWorker { .transformAsync( iterator -> runFederatedComputation(checkinResult, run, iterator), mInjector.getBgExecutor()); + // report failure to server if computation failed with any exception. + computationResultFuture.addCallback( + serverFailureReportCallback, getLightweightExecutor()); // 4. Report computation result to federated compute server. ListenableFuture<RejectionInfo> reportToServerFuture = @@ -284,6 +292,8 @@ public class FederatedComputeWorker { ComputationResult computationResult = Futures.getDone(computationResultFuture); RejectionInfo reportToServer = Futures.getDone(reportToServerFuture); + // report to Server will hold null in case of success, or rejection info + // in case server answered with rejection if (reportToServer != null) { ComputationResult failedReportComputationResult = new ComputationResult( @@ -320,6 +330,41 @@ public class FederatedComputeWorker { mInjector.getBgExecutor()); } + @androidx.annotation.NonNull + private FutureCallback<Object> getServerFailureReportCallback() { + return new FutureCallback<Object>() { + volatile int mNumberOfInvocations = 0; + + @Override + public void onSuccess(Object unused) { + // do nothing. + } + + // We do not want race condition and repeating reporting failures from computation + // failed future in right before case Example Store iterator failed. + // Thus method is synchronised. + @Override + public synchronized void onFailure(Throwable throwable) { + if (mNumberOfInvocations < 1) { + LogUtil.d( + TAG, + "Training failed. 
Reporting failure result to server due to exception.", + throwable); + ComputationResult failedReportComputationResult = + new ComputationResult( + null, + FLRunnerResult.newBuilder() + .setContributionResult(ContributionResult.FAIL) + .setErrorMessage(throwable.getMessage()) + .build(), + null); + var unused = mHttpFederatedProtocol.reportResult(failedReportComputationResult); + } + mNumberOfInvocations++; + } + }; + } + private static TaskRetry buildTaskRetry(RejectionInfo rejectionInfo) { TaskRetry.Builder taskRetryBuilder = TaskRetry.newBuilder(); diff --git a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java index acdd3c86..71fe9998 100644 --- a/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java +++ b/tests/federatedcomputetests/src/com/android/federatedcompute/services/training/FederatedComputeWorkerTest.java @@ -189,7 +189,7 @@ public final class FederatedComputeWorkerTest { @Mock FederatedComputeJobManager mMockJobManager; private Context mContext; private FederatedComputeWorker mSpyWorker; - @Mock private HttpFederatedProtocol mMockHttpFederatedProtocol; + private HttpFederatedProtocol mSpyHttpFederatedProtocol; @Mock private ComputationRunner mMockComputationRunner; private ResultCallbackHelper mSpyResultCallbackHelper; @@ -218,6 +218,9 @@ public final class FederatedComputeWorkerTest { @Before public void doBeforeEachTest() throws Exception { mContext = ApplicationProvider.getApplicationContext(); + mSpyHttpFederatedProtocol = + Mockito.spy( + HttpFederatedProtocol.create(SERVER_ADDRESS, "1.0.0.1", POPULATION_NAME)); mSpyResultCallbackHelper = Mockito.spy(new ResultCallbackHelper(mContext)); mSpyWorker = Mockito.spy( @@ -234,7 +237,7 @@ public final class FederatedComputeWorkerTest { .when(mSpyResultCallbackHelper) 
.callHandleResult(eq(TASK_NAME), any(), any()); when(mMockJobManager.onTrainingStarted(anyInt())).thenReturn(FEDERATED_TRAINING_TASK_1); - doReturn(mMockHttpFederatedProtocol) + doReturn(mSpyHttpFederatedProtocol) .when(mSpyWorker) .getHttpFederatedProtocol(anyString(), anyString()); when(mMockComputationRunner.runTaskWithNativeRunner( @@ -282,10 +285,10 @@ public final class FederatedComputeWorkerTest { new ExecutionException( "issue checkin failed", new IllegalStateException("http 404")))) - .when(mMockHttpFederatedProtocol) + .when(mSpyHttpFederatedProtocol) .issueCheckin(); doReturn(FluentFuture.from(immediateFuture(null))) - .when(mMockHttpFederatedProtocol) + .when(mSpyHttpFederatedProtocol) .reportResult(any()); assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get()); @@ -301,7 +304,7 @@ public final class FederatedComputeWorkerTest { setUpExampleStoreService(); doReturn( immediateFuture(REJECTION_CHECKIN_RESULT)) - .when(mMockHttpFederatedProtocol) + .when(mSpyHttpFederatedProtocol) .issueCheckin(); FLRunnerResult result = mSpyWorker.startTrainingRun(JOB_ID).get(); @@ -317,10 +320,10 @@ public final class FederatedComputeWorkerTest { public void testReportResultWithRejection() throws Exception { setUpExampleStoreService(); doReturn(immediateFuture(FA_CHECKIN_RESULT)) - .when(mMockHttpFederatedProtocol) + .when(mSpyHttpFederatedProtocol) .issueCheckin(); doReturn(FluentFuture.from(immediateFuture(REJECTION_INFO))) - .when(mMockHttpFederatedProtocol) + .when(mSpyHttpFederatedProtocol) .reportResult(any()); doCallRealMethod().when(mSpyResultCallbackHelper).callHandleResult(any(), any(), any()); ArgumentCaptor<ComputationResult> computationResultCaptor = @@ -348,7 +351,7 @@ public final class FederatedComputeWorkerTest { setUpExampleStoreService(); doReturn(immediateFuture(FA_CHECKIN_RESULT)) - .when(mMockHttpFederatedProtocol) + .when(mSpyHttpFederatedProtocol) .issueCheckin(); doReturn( FluentFuture.from( @@ -356,7 +359,7 @@ 
public final class FederatedComputeWorkerTest { new ExecutionException( "report result failed", new IllegalStateException("http 404"))))) - .when(mMockHttpFederatedProtocol) + .when(mSpyHttpFederatedProtocol) .reportResult(any()); assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get()); @@ -371,10 +374,10 @@ public final class FederatedComputeWorkerTest { @Test public void testBindToExampleStoreFails_throwsException() throws Exception { setUpHttpFederatedProtocol(FL_CHECKIN_RESULT); - // Mock failure bind to ExampleStoreService. doReturn(null).when(mSpyWorker).getExampleStoreService(anyString()); - doNothing().when(mSpyWorker).unbindFromExampleStoreService(); + ArgumentCaptor<ComputationResult> computationResultCaptor = + ArgumentCaptor.forClass(ComputationResult.class); assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get()); @@ -383,6 +386,13 @@ public final class FederatedComputeWorkerTest { .onTrainingCompleted( anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL)); verify(mSpyWorker, times(0)).unbindFromExampleStoreService(); + verify(mSpyHttpFederatedProtocol, times(1)) + .reportResult(computationResultCaptor.capture()); + ComputationResult computationResult = computationResultCaptor.getValue(); + assertNotNull(computationResult.getFlRunnerResult()); + assertEquals( + ContributionResult.FAIL, + computationResult.getFlRunnerResult().getContributionResult()); } @Test @@ -414,6 +424,43 @@ public final class FederatedComputeWorkerTest { } @Test + public void testRunFAComputationThrows() throws Exception { + setUpExampleStoreService(); + setUpHttpFederatedProtocol(FA_CHECKIN_RESULT); + doReturn(FluentFuture.from(immediateFuture(null))) + .when(mSpyHttpFederatedProtocol) + .reportResult(any()); + when(mMockComputationRunner.runTaskWithNativeRunner( + anyString(), + anyString(), + anyString(), + anyString(), + any(), + any(), + any(), + any(), + any())) + .thenThrow(new 
RuntimeException("Test failures!")); + ArgumentCaptor<ComputationResult> computationResultCaptor = + ArgumentCaptor.forClass(ComputationResult.class); + + assertThrows(ExecutionException.class, () -> mSpyWorker.startTrainingRun(JOB_ID).get()); + + mSpyWorker.finish(null, ContributionResult.FAIL, false); + verify(mMockJobManager) + .onTrainingCompleted( + anyInt(), anyString(), any(), any(), eq(ContributionResult.FAIL)); + verify(mSpyWorker).unbindFromExampleStoreService(); + verify(mSpyHttpFederatedProtocol, times(1)) + .reportResult(computationResultCaptor.capture()); + ComputationResult computationResult = computationResultCaptor.getValue(); + assertNotNull(computationResult.getFlRunnerResult()); + assertEquals( + ContributionResult.FAIL, + computationResult.getFlRunnerResult().getContributionResult()); + } + + @Test public void testPublishToResultHandlingServiceFails_returnsSuccess() throws Exception { setUpExampleStoreService(); setUpHttpFederatedProtocol(FA_CHECKIN_RESULT); @@ -475,7 +522,7 @@ public final class FederatedComputeWorkerTest { @Test public void testBindToIsolatedTrainingServiceFail_returnsFail() throws Exception { doReturn(immediateFuture(FL_CHECKIN_RESULT)) - .when(mMockHttpFederatedProtocol) + .when(mSpyHttpFederatedProtocol) .issueCheckin(); setUpExampleStoreService(); @@ -556,9 +603,9 @@ public final class FederatedComputeWorkerTest { } private void setUpHttpFederatedProtocol(CheckinResult checkinResult) { - doReturn(immediateFuture(checkinResult)).when(mMockHttpFederatedProtocol).issueCheckin(); + doReturn(immediateFuture(checkinResult)).when(mSpyHttpFederatedProtocol).issueCheckin(); doReturn(FluentFuture.from(immediateFuture(null))) - .when(mMockHttpFederatedProtocol) + .when(mSpyHttpFederatedProtocol) .reportResult(any()); } |