aboutsummaryrefslogtreecommitdiff
path: root/fcp/client/test_helpers.h
diff options
context:
space:
mode:
Diffstat (limited to 'fcp/client/test_helpers.h')
-rw-r--r--fcp/client/test_helpers.h861
1 files changed, 861 insertions, 0 deletions
diff --git a/fcp/client/test_helpers.h b/fcp/client/test_helpers.h
new file mode 100644
index 0000000..30022cf
--- /dev/null
+++ b/fcp/client/test_helpers.h
@@ -0,0 +1,861 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FCP_CLIENT_TEST_HELPERS_H_
+#define FCP_CLIENT_TEST_HELPERS_H_
+
+#include <functional>
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "fcp/base/monitoring.h"
+#include "fcp/client/engine/example_iterator_factory.h"
+#include "fcp/client/event_publisher.h"
+#include "fcp/client/federated_protocol.h"
+#include "fcp/client/federated_select.h"
+#include "fcp/client/flags.h"
+#include "fcp/client/http/http_client.h"
+#include "fcp/client/log_manager.h"
+#include "fcp/client/opstats/opstats_db.h"
+#include "fcp/client/opstats/opstats_logger.h"
+#include "fcp/client/phase_logger.h"
+#include "fcp/client/secagg_event_publisher.h"
+#include "fcp/client/secagg_runner.h"
+#include "fcp/client/simple_task_environment.h"
+#include "gmock/gmock.h"
+#include "google/protobuf/duration.pb.h"
+#include "tensorflow/core/example/example.pb.h"
+#include "tensorflow/core/example/feature.pb.h"
+
+namespace fcp {
+namespace client {
+
+class MockSecAggEventPublisher : public SecAggEventPublisher {
+ public:
+ MOCK_METHOD(void, PublishStateTransition,
+ (::fcp::secagg::ClientState state, size_t last_sent_message_size,
+ size_t last_received_message_size),
+ (override));
+ MOCK_METHOD(void, PublishError, (), (override));
+ MOCK_METHOD(void, PublishAbort,
+ (bool client_initiated, const std::string& error_message),
+ (override));
+ MOCK_METHOD(void, set_execution_session_id, (int64_t execution_session_id),
+ (override));
+};
+
+class MockEventPublisher : public EventPublisher {
+ public:
+ MOCK_METHOD(void, PublishEligibilityEvalCheckin, (), (override));
+ MOCK_METHOD(void, PublishEligibilityEvalPlanUriReceived,
+ (const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishEligibilityEvalPlanReceived,
+ (const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishEligibilityEvalNotConfigured,
+ (const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishEligibilityEvalRejected,
+ (const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishCheckin, (), (override));
+ MOCK_METHOD(void, PublishCheckinFinished,
+ (const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishRejected, (), (override));
+ MOCK_METHOD(void, PublishReportStarted, (int64_t report_size_bytes),
+ (override));
+ MOCK_METHOD(void, PublishReportFinished,
+ (const NetworkStats& network_stats,
+ absl::Duration report_duration),
+ (override));
+ MOCK_METHOD(void, PublishPlanExecutionStarted, (), (override));
+ MOCK_METHOD(void, PublishTensorFlowError,
+ (int example_count, absl::string_view error_message), (override));
+ MOCK_METHOD(void, PublishIoError, (absl::string_view error_message),
+ (override));
+ MOCK_METHOD(void, PublishExampleSelectorError,
+ (int example_count, absl::string_view error_message), (override));
+ MOCK_METHOD(void, PublishInterruption,
+ (const ExampleStats& example_stats, absl::Time start_time),
+ (override));
+ MOCK_METHOD(void, PublishPlanCompleted,
+ (const ExampleStats& example_stats, absl::Time start_time),
+ (override));
+ MOCK_METHOD(void, SetModelIdentifier, (const std::string& model_identifier),
+ (override));
+ MOCK_METHOD(void, PublishTaskNotStarted, (absl::string_view error_message),
+ (override));
+ MOCK_METHOD(void, PublishNonfatalInitializationError,
+ (absl::string_view error_message), (override));
+ MOCK_METHOD(void, PublishFatalInitializationError,
+ (absl::string_view error_message), (override));
+ MOCK_METHOD(void, PublishEligibilityEvalCheckinIoError,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishEligibilityEvalCheckinClientInterrupted,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishEligibilityEvalCheckinServerAborted,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishEligibilityEvalCheckinErrorInvalidPayload,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishEligibilityEvalComputationStarted, (), (override));
+ MOCK_METHOD(void, PublishEligibilityEvalComputationInvalidArgument,
+ (absl::string_view error_message,
+ const ExampleStats& example_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishEligibilityEvalComputationExampleIteratorError,
+ (absl::string_view error_message,
+ const ExampleStats& example_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishEligibilityEvalComputationTensorflowError,
+ (absl::string_view error_message,
+ const ExampleStats& example_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishEligibilityEvalComputationInterrupted,
+ (absl::string_view error_message,
+ const ExampleStats& example_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishEligibilityEvalComputationCompleted,
+ (const ExampleStats& example_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishCheckinIoError,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishCheckinClientInterrupted,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishCheckinServerAborted,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishCheckinInvalidPayload,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishRejected,
+ (const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishCheckinPlanUriReceived,
+ (const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishCheckinFinishedV2,
+ (const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishComputationStarted, (), (override));
+ MOCK_METHOD(void, PublishComputationInvalidArgument,
+ (absl::string_view error_message,
+ const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishComputationIOError,
+ (absl::string_view error_message,
+ const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishComputationExampleIteratorError,
+ (absl::string_view error_message,
+ const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishComputationTensorflowError,
+ (absl::string_view error_message,
+ const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishComputationInterrupted,
+ (absl::string_view error_message,
+ const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishComputationCompleted,
+ (const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishResultUploadStarted, (), (override));
+ MOCK_METHOD(void, PublishResultUploadIOError,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishResultUploadClientInterrupted,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishResultUploadServerAborted,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishResultUploadCompleted,
+ (const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishFailureUploadStarted, (), (override));
+ MOCK_METHOD(void, PublishFailureUploadIOError,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishFailureUploadClientInterrupted,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishFailureUploadServerAborted,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+ MOCK_METHOD(void, PublishFailureUploadCompleted,
+ (const NetworkStats& network_stats,
+ absl::Duration phase_duration),
+ (override));
+
+ SecAggEventPublisher* secagg_event_publisher() override {
+ return &secagg_event_publisher_;
+ }
+
+ private:
+ ::testing::NiceMock<MockSecAggEventPublisher> secagg_event_publisher_;
+};
+
+// A mock FederatedProtocol implementation, which keeps track of the stages in
+// the protocol and returns a different set of network stats and RetryWindow for
+// each stage, making it easier to write accurate assertions in unit tests.
+class MockFederatedProtocol : public FederatedProtocol {
+ public:
+ constexpr static NetworkStats
+ kPostEligibilityCheckinPlanUriReceivedNetworkStats = {
+ .bytes_downloaded = 280,
+ .bytes_uploaded = 380,
+ .network_duration = absl::Milliseconds(25)};
+ constexpr static NetworkStats kPostEligibilityCheckinNetworkStats = {
+ .bytes_downloaded = 300,
+ .bytes_uploaded = 400,
+ .network_duration = absl::Milliseconds(50)};
+ constexpr static NetworkStats kPostReportEligibilityEvalErrorNetworkStats = {
+ .bytes_downloaded = 400,
+ .bytes_uploaded = 500,
+ .network_duration = absl::Milliseconds(150)};
+ constexpr static NetworkStats kPostCheckinPlanUriReceivedNetworkStats = {
+ .bytes_downloaded = 2970,
+ .bytes_uploaded = 3970,
+ .network_duration = absl::Milliseconds(225)};
+ constexpr static NetworkStats kPostCheckinNetworkStats = {
+ .bytes_downloaded = 3000,
+ .bytes_uploaded = 4000,
+ .network_duration = absl::Milliseconds(250)};
+ constexpr static NetworkStats kPostReportCompletedNetworkStats = {
+ .bytes_downloaded = 30000,
+ .bytes_uploaded = 40000,
+ .network_duration = absl::Milliseconds(350)};
+ constexpr static NetworkStats kPostReportNotCompletedNetworkStats = {
+ .bytes_downloaded = 29999,
+ .bytes_uploaded = 39999,
+ .network_duration = absl::Milliseconds(450)};
+
+ static google::internal::federatedml::v2::RetryWindow
+ GetInitialRetryWindow() {
+ google::internal::federatedml::v2::RetryWindow retry_window;
+ retry_window.mutable_delay_min()->set_seconds(0L);
+ retry_window.mutable_delay_max()->set_seconds(1L);
+ *retry_window.mutable_retry_token() = "INITIAL";
+ return retry_window;
+ }
+
+ static google::internal::federatedml::v2::RetryWindow
+ GetPostEligibilityCheckinRetryWindow() {
+ google::internal::federatedml::v2::RetryWindow retry_window;
+ retry_window.mutable_delay_min()->set_seconds(100L);
+ retry_window.mutable_delay_max()->set_seconds(101L);
+ *retry_window.mutable_retry_token() = "POST_ELIGIBILITY";
+ return retry_window;
+ }
+
+ static google::internal::federatedml::v2::RetryWindow
+ GetPostCheckinRetryWindow() {
+ google::internal::federatedml::v2::RetryWindow retry_window;
+ retry_window.mutable_delay_min()->set_seconds(200L);
+ retry_window.mutable_delay_max()->set_seconds(201L);
+ *retry_window.mutable_retry_token() = "POST_CHECKIN";
+ return retry_window;
+ }
+
+ static google::internal::federatedml::v2::RetryWindow
+ GetPostReportCompletedRetryWindow() {
+ google::internal::federatedml::v2::RetryWindow retry_window;
+ retry_window.mutable_delay_min()->set_seconds(300L);
+ retry_window.mutable_delay_max()->set_seconds(301L);
+ *retry_window.mutable_retry_token() = "POST_REPORT_COMPLETED";
+ return retry_window;
+ }
+
+ static google::internal::federatedml::v2::RetryWindow
+ GetPostReportNotCompletedRetryWindow() {
+ google::internal::federatedml::v2::RetryWindow retry_window;
+ retry_window.mutable_delay_min()->set_seconds(400L);
+ retry_window.mutable_delay_max()->set_seconds(401L);
+ *retry_window.mutable_retry_token() = "POST_REPORT_NOT_COMPLETED";
+ return retry_window;
+ }
+
+ explicit MockFederatedProtocol() {}
+
+ // We override the real FederatedProtocol methods so that we can intercept the
+ // progression of protocol stages, and expose dedicate gMock-overridable
+ // methods for use in tests.
+ absl::StatusOr<EligibilityEvalCheckinResult> EligibilityEvalCheckin(
+ std::function<void(const EligibilityEvalTask&)>
+ payload_uris_received_callback) final {
+ absl::StatusOr<EligibilityEvalCheckinResult> result =
+ MockEligibilityEvalCheckin();
+ if (result.ok() &&
+ std::holds_alternative<FederatedProtocol::EligibilityEvalTask>(
+ *result)) {
+ network_stats_ = kPostEligibilityCheckinPlanUriReceivedNetworkStats;
+ payload_uris_received_callback(
+ std::get<FederatedProtocol::EligibilityEvalTask>(*result));
+ }
+ network_stats_ = kPostEligibilityCheckinNetworkStats;
+ retry_window_ = GetPostEligibilityCheckinRetryWindow();
+ return result;
+ };
+ MOCK_METHOD(absl::StatusOr<EligibilityEvalCheckinResult>,
+ MockEligibilityEvalCheckin, ());
+
+ void ReportEligibilityEvalError(absl::Status error_status) final {
+ network_stats_ = kPostReportEligibilityEvalErrorNetworkStats;
+ retry_window_ = GetPostEligibilityCheckinRetryWindow();
+ MockReportEligibilityEvalError(error_status);
+ }
+ MOCK_METHOD(void, MockReportEligibilityEvalError,
+ (absl::Status error_status));
+
+ absl::StatusOr<CheckinResult> Checkin(
+ const std::optional<
+ ::google::internal::federatedml::v2::TaskEligibilityInfo>&
+ task_eligibility_info,
+ std::function<void(const FederatedProtocol::TaskAssignment&)>
+ payload_uris_received_callback) final {
+ absl::StatusOr<CheckinResult> result = MockCheckin(task_eligibility_info);
+ if (result.ok() &&
+ std::holds_alternative<FederatedProtocol::TaskAssignment>(*result)) {
+ network_stats_ = kPostCheckinPlanUriReceivedNetworkStats;
+ payload_uris_received_callback(
+ std::get<FederatedProtocol::TaskAssignment>(*result));
+ }
+ retry_window_ = GetPostCheckinRetryWindow();
+ network_stats_ = kPostCheckinNetworkStats;
+ return result;
+ };
+ MOCK_METHOD(absl::StatusOr<CheckinResult>, MockCheckin,
+ (const std::optional<
+ ::google::internal::federatedml::v2::TaskEligibilityInfo>&
+ task_eligibility_info));
+
+ absl::StatusOr<MultipleTaskAssignments> PerformMultipleTaskAssignments(
+ const std::vector<std::string>& task_names) final {
+ absl::StatusOr<MultipleTaskAssignments> result =
+ MockPerformMultipleTaskAssignments(task_names);
+ retry_window_ = GetPostCheckinRetryWindow();
+ network_stats_ = kPostCheckinPlanUriReceivedNetworkStats;
+ return result;
+ };
+
+ MOCK_METHOD(absl::StatusOr<MultipleTaskAssignments>,
+ MockPerformMultipleTaskAssignments,
+ (const std::vector<std::string>& task_names));
+
+ absl::Status ReportCompleted(
+ ComputationResults results, absl::Duration plan_duration,
+ std::optional<std::string> aggregation_session_id) final {
+ network_stats_ = kPostReportCompletedNetworkStats;
+ retry_window_ = GetPostReportCompletedRetryWindow();
+ return MockReportCompleted(std::move(results), plan_duration,
+ aggregation_session_id);
+ };
+ MOCK_METHOD(absl::Status, MockReportCompleted,
+ (ComputationResults results, absl::Duration plan_duration,
+ std::optional<std::string> aggregation_session_id));
+
+ absl::Status ReportNotCompleted(
+ engine::PhaseOutcome phase_outcome, absl::Duration plan_duration,
+ std::optional<std::string> aggregation_session_id) final {
+ network_stats_ = kPostReportNotCompletedNetworkStats;
+ retry_window_ = GetPostReportNotCompletedRetryWindow();
+ return MockReportNotCompleted(phase_outcome, plan_duration,
+ aggregation_session_id);
+ };
+ MOCK_METHOD(absl::Status, MockReportNotCompleted,
+ (engine::PhaseOutcome phase_outcome, absl::Duration plan_duration,
+ std::optional<std::string> aggregation_session_id));
+
+ ::google::internal::federatedml::v2::RetryWindow GetLatestRetryWindow()
+ final {
+ return retry_window_;
+ }
+
+ NetworkStats GetNetworkStats() final { return network_stats_; }
+
+ private:
+ NetworkStats network_stats_;
+ ::google::internal::federatedml::v2::RetryWindow retry_window_ =
+ GetInitialRetryWindow();
+};
+
+class MockLogManager : public LogManager {
+ public:
+ MOCK_METHOD(void, LogDiag, (ProdDiagCode), (override));
+ MOCK_METHOD(void, LogDiag, (DebugDiagCode), (override));
+ MOCK_METHOD(void, LogToLongHistogram,
+ (fcp::client::HistogramCounters, int, int,
+ fcp::client::engine::DataSourceType, int64_t),
+ (override));
+ MOCK_METHOD(void, SetModelIdentifier, (const std::string&), (override));
+};
+
+class MockOpStatsLogger : public ::fcp::client::opstats::OpStatsLogger {
+ public:
+ MOCK_METHOD(
+ void, AddEventAndSetTaskName,
+ (const std::string& task_name,
+ ::fcp::client::opstats::OperationalStats::Event::EventKind event),
+ (override));
+ MOCK_METHOD(
+ void, AddEvent,
+ (::fcp::client::opstats::OperationalStats::Event::EventKind event),
+ (override));
+ MOCK_METHOD(void, AddEventWithErrorMessage,
+ (::fcp::client::opstats::OperationalStats::Event::EventKind event,
+ const std::string& error_message),
+ (override));
+ MOCK_METHOD(void, UpdateDatasetStats,
+ (const std::string& collection_uri, int additional_example_count,
+ int64_t additional_example_size_bytes),
+ (override));
+ MOCK_METHOD(void, SetNetworkStats, (const NetworkStats& network_stats),
+ (override));
+ MOCK_METHOD(void, SetRetryWindow,
+ (google::internal::federatedml::v2::RetryWindow retry_window),
+ (override));
+ MOCK_METHOD(::fcp::client::opstats::OpStatsDb*, GetOpStatsDb, (), (override));
+ MOCK_METHOD(bool, IsOpStatsEnabled, (), (const override));
+ MOCK_METHOD(absl::Status, CommitToStorage, (), (override));
+ MOCK_METHOD(std::string, GetCurrentTaskName, (), (override));
+};
+
+class MockSimpleTaskEnvironment : public SimpleTaskEnvironment {
+ public:
+ MOCK_METHOD(std::string, GetBaseDir, (), (override));
+ MOCK_METHOD(std::string, GetCacheDir, (), (override));
+ MOCK_METHOD((absl::StatusOr<std::unique_ptr<ExampleIterator>>),
+ CreateExampleIterator,
+ (const google::internal::federated::plan::ExampleSelector&
+ example_selector),
+ (override));
+ MOCK_METHOD((absl::StatusOr<std::unique_ptr<ExampleIterator>>),
+ CreateExampleIterator,
+ (const google::internal::federated::plan::ExampleSelector&
+ example_selector,
+ const SelectorContext& selector_context),
+ (override));
+ MOCK_METHOD(std::unique_ptr<fcp::client::http::HttpClient>, CreateHttpClient,
+ (), (override));
+ MOCK_METHOD(bool, TrainingConditionsSatisfied, (), (override));
+};
+
+class MockExampleIterator : public ExampleIterator {
+ public:
+ MOCK_METHOD(absl::StatusOr<std::string>, Next, (), (override));
+ MOCK_METHOD(void, Close, (), (override));
+};
+
+// An iterator that passes through each example in the dataset once.
+class SimpleExampleIterator : public ExampleIterator {
+ public:
+ // Uses the given bytes as the examples to return.
+ explicit SimpleExampleIterator(std::vector<const char*> examples);
+ // Passes through each of the examples in the `Dataset.client_data.example`
+ // field.
+ explicit SimpleExampleIterator(
+ google::internal::federated::plan::Dataset dataset);
+ // Passes through each of the examples in the
+ // `Dataset.client_data.selected_example.example` field, whose example
+ // collection URI matches the provided `collection_uri`.
+ SimpleExampleIterator(google::internal::federated::plan::Dataset dataset,
+ absl::string_view collection_uri);
+ absl::StatusOr<std::string> Next() override;
+ void Close() override {}
+
+ protected:
+ std::vector<std::string> examples_;
+ int index_ = 0;
+};
+
+struct ComputationArtifacts {
+ // The path to the file containing the plan data.
+ std::string plan_filepath;
+ // The already-parsed plan data.
+ google::internal::federated::plan::ClientOnlyPlan plan;
+ // The test dataset.
+ google::internal::federated::plan::Dataset dataset;
+ // The path to the file containing the initial checkpoint data (not set for
+ // local compute task artifacts).
+ std::string checkpoint_filepath;
+ // The initial checkpoint data, as a string (not set for local compute task
+ // artifacts).
+ std::string checkpoint;
+ // The Federated Select slice data (not set for local compute task artifacts).
+ google::internal::federated::plan::SlicesTestDataset federated_select_slices;
+};
+
+absl::StatusOr<ComputationArtifacts> LoadFlArtifacts();
+
+class MockFlags : public Flags {
+ public:
+ MOCK_METHOD(int64_t, condition_polling_period_millis, (), (const, override));
+ MOCK_METHOD(int64_t, tf_execution_teardown_grace_period_millis, (),
+ (const, override));
+ MOCK_METHOD(int64_t, tf_execution_teardown_extended_period_millis, (),
+ (const, override));
+ MOCK_METHOD(int64_t, grpc_channel_deadline_seconds, (), (const, override));
+ MOCK_METHOD(bool, log_tensorflow_error_messages, (), (const, override));
+ MOCK_METHOD(bool, enable_opstats, (), (const, override));
+ MOCK_METHOD(int64_t, opstats_ttl_days, (), (const, override));
+ MOCK_METHOD(int64_t, opstats_db_size_limit_bytes, (), (const, override));
+ MOCK_METHOD(int64_t, federated_training_transient_errors_retry_delay_secs, (),
+ (const, override));
+ MOCK_METHOD(float,
+ federated_training_transient_errors_retry_delay_jitter_percent,
+ (), (const, override));
+ MOCK_METHOD(int64_t, federated_training_permanent_errors_retry_delay_secs, (),
+ (const, override));
+ MOCK_METHOD(float,
+ federated_training_permanent_errors_retry_delay_jitter_percent,
+ (), (const, override));
+ MOCK_METHOD(std::vector<int32_t>, federated_training_permanent_error_codes,
+ (), (const, override));
+ MOCK_METHOD(bool, use_tflite_training, (), (const, override));
+ MOCK_METHOD(bool, enable_grpc_with_http_resource_support, (),
+ (const, override));
+ MOCK_METHOD(bool, enable_grpc_with_eligibility_eval_http_resource_support, (),
+ (const, override));
+ MOCK_METHOD(bool, ensure_dynamic_tensors_are_released, (), (const, override));
+ MOCK_METHOD(int32_t, large_tensor_threshold_for_dynamic_allocation, (),
+ (const, override));
+ MOCK_METHOD(bool, disable_http_request_body_compression, (),
+ (const, override));
+ MOCK_METHOD(bool, use_http_federated_compute_protocol, (), (const, override));
+ MOCK_METHOD(bool, enable_computation_id, (), (const, override));
+ MOCK_METHOD(int32_t, waiting_period_sec_for_cancellation, (),
+ (const, override));
+ MOCK_METHOD(bool, enable_federated_select, (), (const, override));
+ MOCK_METHOD(int32_t, num_threads_for_tflite, (), (const, override));
+ MOCK_METHOD(bool, disable_tflite_delegate_clustering, (), (const, override));
+ MOCK_METHOD(bool, enable_example_query_plan_engine, (), (const, override));
+ MOCK_METHOD(bool, support_constant_tf_inputs, (), (const, override));
+ MOCK_METHOD(bool, http_protocol_supports_multiple_task_assignments, (),
+ (const, override));
+};
+
+// Helper methods for extracting opstats fields from TF examples.
+std::string ExtractSingleString(const tensorflow::Example& example,
+ const char key[]);
+google::protobuf::RepeatedPtrField<std::string> ExtractRepeatedString(
+ const tensorflow::Example& example, const char key[]);
+int64_t ExtractSingleInt64(const tensorflow::Example& example,
+ const char key[]);
+google::protobuf::RepeatedField<int64_t> ExtractRepeatedInt64(
+ const tensorflow::Example& example, const char key[]);
+
+class MockOpStatsDb : public ::fcp::client::opstats::OpStatsDb {
+ public:
+ MOCK_METHOD(absl::StatusOr<::fcp::client::opstats::OpStatsSequence>, Read, (),
+ (override));
+ MOCK_METHOD(absl::Status, Transform,
+ (std::function<void(::fcp::client::opstats::OpStatsSequence&)>),
+ (override));
+};
+
+class MockPhaseLogger : public PhaseLogger {
+ public:
+ MOCK_METHOD(
+ void, UpdateRetryWindowAndNetworkStats,
+ (const ::google::internal::federatedml::v2::RetryWindow& retry_window,
+ const NetworkStats& network_stats),
+ (override));
+ MOCK_METHOD(void, SetModelIdentifier, (absl::string_view model_identifier),
+ (override));
+ MOCK_METHOD(void, LogTaskNotStarted, (absl::string_view error_message),
+ (override));
+ MOCK_METHOD(void, LogNonfatalInitializationError, (absl::Status error_status),
+ (override));
+ MOCK_METHOD(void, LogFatalInitializationError, (absl::Status error_status),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalCheckinStarted, (), (override));
+ MOCK_METHOD(void, LogEligibilityEvalCheckinIOError,
+ (absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_checkin),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalCheckinInvalidPayloadError,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Time time_before_checkin),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalCheckinClientInterrupted,
+ (absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_checkin),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalCheckinServerAborted,
+ (absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_checkin),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalNotConfigured,
+ (const NetworkStats& network_stats,
+ absl::Time time_before_checkin),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalCheckinTurnedAway,
+ (const NetworkStats& network_stats,
+ absl::Time time_before_checkin),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalCheckinPlanUriReceived,
+ (const NetworkStats& network_stats,
+ absl::Time time_before_checkin),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalCheckinCompleted,
+ (const NetworkStats& network_stats,
+ absl::Time time_before_checkin,
+ absl::Time time_before_plan_download),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalComputationStarted, (), (override));
+ MOCK_METHOD(void, LogEligibilityEvalComputationInvalidArgument,
+ (absl::Status error_status, const ExampleStats& example_stats,
+ absl::Time run_plan_start_time),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalComputationExampleIteratorError,
+ (absl::Status error_status, const ExampleStats& example_stats,
+ absl::Time run_plan_start_time),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalComputationTensorflowError,
+ (absl::Status error_status, const ExampleStats& example_stats,
+ absl::Time run_plan_start_time, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalComputationInterrupted,
+ (absl::Status error_status, const ExampleStats& example_stats,
+ absl::Time run_plan_start_time, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogEligibilityEvalComputationCompleted,
+ (const ExampleStats& example_stats,
+ absl::Time run_plan_start_time, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogCheckinStarted, (), (override));
+ MOCK_METHOD(void, LogCheckinIOError,
+ (absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_checkin, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogCheckinInvalidPayload,
+ (absl::string_view error_message,
+ const NetworkStats& network_stats,
+ absl::Time time_before_checkin, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogCheckinClientInterrupted,
+ (absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_checkin, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogCheckinServerAborted,
+ (absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_checkin, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogCheckinTurnedAway,
+ (const NetworkStats& network_stats,
+ absl::Time time_before_checkin, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogCheckinPlanUriReceived,
+ (absl::string_view task_name, const NetworkStats& network_stats,
+ absl::Time time_before_checkin),
+ (override));
+ MOCK_METHOD(void, LogCheckinCompleted,
+ (absl::string_view task_name, const NetworkStats& network_stats,
+ absl::Time time_before_checkin,
+ absl::Time time_before_plan_download, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogComputationStarted, (), (override));
+ MOCK_METHOD(void, LogComputationInvalidArgument,
+ (absl::Status error_status, const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time),
+ (override));
+ MOCK_METHOD(void, LogComputationExampleIteratorError,
+ (absl::Status error_status, const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time),
+ (override));
+ MOCK_METHOD(void, LogComputationIOError,
+ (absl::Status error_status, const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time),
+ (override));
+ MOCK_METHOD(void, LogComputationTensorflowError,
+ (absl::Status error_status, const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogComputationInterrupted,
+ (absl::Status error_status, const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogComputationCompleted,
+ (const ExampleStats& example_stats,
+ const NetworkStats& network_stats,
+ absl::Time run_plan_start_time, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(absl::Status, LogResultUploadStarted, (), (override));
+ MOCK_METHOD(void, LogResultUploadIOError,
+ (absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_result_upload, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogResultUploadClientInterrupted,
+ (absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_result_upload, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogResultUploadServerAborted,
+ (absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_result_upload, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogResultUploadCompleted,
+ (const NetworkStats& network_stats,
+ absl::Time time_before_result_upload, absl::Time reference_time),
+ (override));
+ MOCK_METHOD(absl::Status, LogFailureUploadStarted, (), (override));
+ MOCK_METHOD(void, LogFailureUploadIOError,
+ (absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_failure_upload,
+ absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogFailureUploadClientInterrupted,
+ (absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_failure_upload,
+ absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogFailureUploadServerAborted,
+ (absl::Status error_status, const NetworkStats& network_stats,
+ absl::Time time_before_failure_upload,
+ absl::Time reference_time),
+ (override));
+ MOCK_METHOD(void, LogFailureUploadCompleted,
+ (const NetworkStats& network_stats,
+ absl::Time time_before_result_upload, absl::Time reference_time),
+ (override));
+};
+
+class MockFederatedSelectManager : public FederatedSelectManager {
+ public:
+ MOCK_METHOD(std::unique_ptr<engine::ExampleIteratorFactory>,
+ CreateExampleIteratorFactoryForUriTemplate,
+ (absl::string_view uri_template), (override));
+
+ MOCK_METHOD(NetworkStats, GetNetworkStats, (), (override));
+};
+
+class MockFederatedSelectExampleIteratorFactory
+ : public FederatedSelectExampleIteratorFactory {
+ public:
+ MOCK_METHOD(absl::StatusOr<std::unique_ptr<ExampleIterator>>,
+ CreateExampleIterator,
+ (const ::google::internal::federated::plan::ExampleSelector&
+ example_selector),
+ (override));
+};
+
+class MockSecAggRunnerFactory : public SecAggRunnerFactory {
+ public:
+ MOCK_METHOD(std::unique_ptr<SecAggRunner>, CreateSecAggRunner,
+ (std::unique_ptr<SecAggSendToServerBase> send_to_server_impl,
+ std::unique_ptr<SecAggProtocolDelegate> protocol_delegate,
+ SecAggEventPublisher* secagg_event_publisher,
+ LogManager* log_manager,
+ InterruptibleRunner* interruptible_runner,
+ int64_t expected_number_of_clients,
+ int64_t minimum_surviving_clients_for_reconstruction),
+ (override));
+};
+
+class MockSecAggRunner : public SecAggRunner {
+ public:
+ MOCK_METHOD(absl::Status, Run, (ComputationResults results), (override));
+};
+
+class MockSecAggSendToServerBase : public SecAggSendToServerBase {
+ MOCK_METHOD(void, Send, (secagg::ClientToServerWrapperMessage * message),
+ (override));
+};
+
+class MockSecAggProtocolDelegate : public SecAggProtocolDelegate {
+ public:
+ MOCK_METHOD(absl::StatusOr<uint64_t>, GetModulus, (const std::string& key),
+ (override));
+ MOCK_METHOD(absl::StatusOr<secagg::ServerToClientWrapperMessage>,
+ ReceiveServerMessage, (), (override));
+ MOCK_METHOD(void, Abort, (), (override));
+};
+
+} // namespace client
+} // namespace fcp
+
+#endif // FCP_CLIENT_TEST_HELPERS_H_