diff options
author | qiaoli <qiaoli@google.com> | 2023-05-22 23:24:26 +0000 |
---|---|---|
committer | qiaoli <qiaoli@google.com> | 2023-06-08 23:13:52 +0000 |
commit | 1e493e03e249aec6c4cb9f3d91e7b7bcb1157649 (patch) | |
tree | bf0a8f49bc4f85e552fa9be7567b3560e1c016b8 | |
parent | 46621e2009439889a6f96f66b9de71c447e3bafc (diff) | |
download | federated-compute-1e493e03e249aec6c4cb9f3d91e7b7bcb1157649.tar.gz |
Add build rules for federated compute library
Bug: 242229007
Test: mma
Change-Id: If6681e8ece5ebe2821028d7de4865504e17a5920
-rw-r--r-- | Android.bp | 167 | ||||
-rw-r--r-- | fcp/client/fcp_runner.cc | 290 | ||||
-rw-r--r-- | fcp/client/fcp_runner.h | 46 | ||||
-rw-r--r-- | fcp/protos/federatedcompute/common.proto | 2 |
4 files changed, 504 insertions, 1 deletions
diff --git a/Android.bp b/Android.bp new file mode 100644 index 0000000..c3d2e4a --- /dev/null +++ b/Android.bp @@ -0,0 +1,167 @@ +// Copyright (C) 2023 The Android Open Source Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package { + default_applicable_licenses: ["external_federated_compute_license"], +} + +license { + name: "external_federated_compute_license", + visibility: [":__subpackages__"], + license_kinds: [ + "SPDX-license-identifier-Apache-2.0", + ], + license_text: [ + "LICENSE", + ], +} + +cc_library_headers { + name: "libfederatedcompute_headers", + export_include_dirs: ["."], + sdk_version: "current", + min_sdk_version: "33", + apex_available: ["com.android.ondevicepersonalization"], +} + +java_library_static { + name: "federated-compute-java-proto-lite", + proto: { + type: "lite", + canonical_path_from_root: false, + include_dirs: [ + "external/protobuf/src", + "external/protobuf/java", + "external/tensorflow", + ], + }, + srcs: [ + "fcp/secagg/shared/secagg_messages.proto", + "fcp/protos/federatedcompute/aggregations.proto", + "fcp/protos/federatedcompute/eligibility_eval_tasks.proto", + "fcp/protos/federatedcompute/task_assignments.proto", + "fcp/protos/federatedcompute/common.proto", + "fcp/protos/plan.proto", + "fcp/protos/federated_api.proto", + "fcp/client/**/*.proto", + ":libprotobuf-internal-protos", + ], + static_libs: [ + "libprotobuf-java-lite", + "tensorflow_core_proto_java_lite", + ], + sdk_version: "current", + 
min_sdk_version: "33", + apex_available: ["com.android.ondevicepersonalization"], +} + +cc_library { + name: "federated-compute-cc-proto-lite", + srcs: [ + "fcp/secagg/shared/secagg_messages.proto", + "fcp/secagg/server/secagg_server_enums.proto", + "fcp/client/**/*.proto", + "fcp/protos/**/*.proto", + "fcp/dictionary/*.proto", + ":libprotobuf-internal-protos", + ], + proto: { + type: "lite", + export_proto_headers: true, + canonical_path_from_root: false, + include_dirs: [ + "external/protobuf/src", + "external/tensorflow", + ], + }, + static_libs: [ + "tensorflow_core_proto_cpp_lite", + "libprotobuf-cpp-lite-ndk", + ], + shared_libs: [ + "liblog", + ], + stl: "libc++_static", + apex_available: ["com.android.ondevicepersonalization"], + sdk_version: "current", + min_sdk_version: "33", +} + +cc_library_static { + name: "libfederatedcompute", + srcs: [ + "fcp/client/fcp_runner.cc", + "fcp/client/interruptible_runner.cc", + "fcp/client/simple_task_environment.cc", + "fcp/client/engine/*.cc", + "fcp/tensorflow/*.cc", + "fcp/base/*.cc", + ], + exclude_srcs: [ + "fcp/**/*test*.cc", + "fcp/client/fake_*.cc", + "fcp/client/engine/tf_*.cc", + "fcp/tensorflow/tf_session.cc", + "fcp/base/string_stream.cc", + "fcp/base/status_converters.cc", + ], + whole_static_libs: [ + "federated-compute-cc-proto-lite", + "libtflite_flex_delegate", + ], + header_libs: [ + "flatbuffer_headers", + "libeigen", + "libtextclassifier_hash_headers", + ], + shared_libs: [ + "libcrypto", + ], + visibility: [ + "//packages/modules/OnDevicePersonalization:__subpackages__", + ], + cflags: [ + "-DNAMESPACE_FOR_HASH_FUNCTIONS=farmhash", + "-Wno-ignored-qualifiers", + "-Wno-unused-parameter", + "-Wno-missing-field-initializers", + "-Wno-defaulted-function-deleted", + "-Wno-deprecated-declarations", + ], + stl: "libc++_static", + sdk_version: "current", + apex_available: ["com.android.ondevicepersonalization"], + min_sdk_version: "33", +} + +filegroup { + name: "fcp_native_wrapper", + srcs: 
["fcp/java_src/main/java/com/google/fcp/client/CallFromNativeWrapper.java"], + visibility: [ + "//packages/modules/OnDevicePersonalization:__subpackages__" + ], +} + +filegroup { + name: "fcp_artifacts_testdata", + srcs: [ + "fcp/testdata/federation_client_only_plan.pb", + "fcp/testdata/federation_proxy_train_examples.pb", + "fcp/testdata/federation_test_checkpoint.client.ckp", + "fcp/testdata/federation_test_select_checkpoints.pb", + ], + visibility: [ + "//packages/modules/OnDevicePersonalization:__subpackages__" + ], +} diff --git a/fcp/client/fcp_runner.cc b/fcp/client/fcp_runner.cc new file mode 100644 index 0000000..4736b6a --- /dev/null +++ b/fcp/client/fcp_runner.cc @@ -0,0 +1,290 @@ +#include "fcp/client/fcp_runner.h" + +#include "fcp/client/engine/example_iterator_factory.h" +#include "fcp/client/engine/example_query_plan_engine.h" +#include "fcp/client/engine/plan_engine_helpers.h" +#include "fcp/client/engine/tflite_plan_engine.h" +#include "fcp/client/fl_runner.pb.h" +#include "fcp/client/opstats/opstats_logger.h" +#include "fcp/protos/plan.pb.h" + +namespace fcp { +namespace client { + +using ::fcp::client::opstats::OpStatsLogger; +using ::google::internal::federated::plan::AggregationConfig; +using ::google::internal::federated::plan::ClientOnlyPlan; +using ::google::internal::federated::plan::FederatedComputeIORouter; +using ::google::internal::federated::plan::TensorflowSpec; + +using TfLiteInputs = absl::flat_hash_map<std::string, std::string>; +namespace { + +// Creates an ExampleIteratorFactory that routes queries to the +// SimpleTaskEnvironment::CreateExampleIterator() method. 
+std::unique_ptr<engine::ExampleIteratorFactory> +CreateSimpleTaskEnvironmentIteratorFactory( + SimpleTaskEnvironment* task_env, const SelectorContext& selector_context) { + return std::make_unique<engine::FunctionalExampleIteratorFactory>( + /*can_handle_func=*/ + [](const google::internal::federated::plan::ExampleSelector&) { + // The SimpleTaskEnvironment-based ExampleIteratorFactory should + // be the catch-all factory that is able to handle all queries + // that no other ExampleIteratorFactory is able to handle. + return true; + }, + /*create_iterator_func=*/ + [task_env, selector_context]( + const google::internal::federated::plan::ExampleSelector& + example_selector) { + return task_env->CreateExampleIterator(example_selector, + selector_context); + }, + /*should_collect_stats=*/true); +} + +std::unique_ptr<TfLiteInputs> ConstructTFLiteInputsForTensorflowSpecPlan( + const FederatedComputeIORouter& io_router, + const std::string& checkpoint_input_filename, + const std::string& checkpoint_output_filename) { + auto inputs = std::make_unique<TfLiteInputs>(); + if (!io_router.input_filepath_tensor_name().empty()) { + (*inputs)[io_router.input_filepath_tensor_name()] = + checkpoint_input_filename; + } + + if (!io_router.output_filepath_tensor_name().empty()) { + (*inputs)[io_router.output_filepath_tensor_name()] = + checkpoint_output_filename; + } + + return inputs; +} + +absl::StatusOr<std::vector<std::string>> ConstructOutputsWithDeterministicOrder( + const TensorflowSpec& tensorflow_spec, + const FederatedComputeIORouter& io_router) { + std::vector<std::string> output_names; + // The order of output tensor names should match the order in TensorflowSpec. 
+ for (const auto& output_tensor_spec : tensorflow_spec.output_tensor_specs()) { + std::string tensor_name = output_tensor_spec.name(); + if (!io_router.aggregations().contains(tensor_name) || + !io_router.aggregations().at(tensor_name).has_secure_aggregation()) { + return absl::InvalidArgumentError( + "Output tensor is missing in AggregationConfig, or has unsupported " + "aggregation type."); + } + output_names.push_back(tensor_name); + } + + return output_names; +} + +struct PlanResultAndCheckpointFile { + explicit PlanResultAndCheckpointFile(engine::PlanResult plan_result) + : plan_result(std::move(plan_result)) {} + engine::PlanResult plan_result; + std::string checkpoint_file; + + PlanResultAndCheckpointFile(PlanResultAndCheckpointFile&&) = default; + PlanResultAndCheckpointFile& operator=(PlanResultAndCheckpointFile&&) = + default; + + // Disallow copy and assign. + PlanResultAndCheckpointFile(const PlanResultAndCheckpointFile&) = delete; + PlanResultAndCheckpointFile& operator=(const PlanResultAndCheckpointFile&) = + delete; +}; + +PlanResultAndCheckpointFile RunPlanWithExampleQuerySpec( + std::vector<engine::ExampleIteratorFactory*> example_iterator_factories, + OpStatsLogger* opstats_logger, const Flags* flags, + const ClientOnlyPlan& client_plan, + const std::string& checkpoint_output_filename) { + if (!client_plan.phase().has_example_query_spec()) { + return PlanResultAndCheckpointFile(engine::PlanResult( + engine::PlanOutcome::kInvalidArgument, + absl::InvalidArgumentError("Plan must include ExampleQuerySpec"))); + } + if (!flags->enable_example_query_plan_engine()) { + // Example query plan received while the flag is off. 
+ return PlanResultAndCheckpointFile(engine::PlanResult( + engine::PlanOutcome::kInvalidArgument, + absl::InvalidArgumentError( + "Example query plan received while the flag is off"))); + } + if (!client_plan.phase().has_federated_example_query()) { + return PlanResultAndCheckpointFile(engine::PlanResult( + engine::PlanOutcome::kInvalidArgument, + absl::InvalidArgumentError("Invalid ExampleQuerySpec-based plan"))); + } + for (const auto& example_query : + client_plan.phase().example_query_spec().example_queries()) { + for (auto const& [vector_name, spec] : + example_query.output_vector_specs()) { + const auto& aggregations = + client_plan.phase().federated_example_query().aggregations(); + if ((aggregations.find(vector_name) == aggregations.end()) || + !aggregations.at(vector_name).has_tf_v1_checkpoint_aggregation()) { + return PlanResultAndCheckpointFile(engine::PlanResult( + engine::PlanOutcome::kInvalidArgument, + absl::InvalidArgumentError("Output vector is missing in " + "AggregationConfig, or has unsupported " + "aggregation type."))); + } + } + } + + engine::ExampleQueryPlanEngine plan_engine(example_iterator_factories, + opstats_logger); + engine::PlanResult plan_result = plan_engine.RunPlan( + client_plan.phase().example_query_spec(), checkpoint_output_filename); + PlanResultAndCheckpointFile result(std::move(plan_result)); + result.checkpoint_file = checkpoint_output_filename; + return result; +} + +PlanResultAndCheckpointFile RunPlanWithTensorflowSpec( + std::vector<engine::ExampleIteratorFactory*> example_iterator_factories, + std::function<bool()> should_abort, LogManager* log_manager, + OpStatsLogger* opstats_logger, const Flags* flags, + const ClientOnlyPlan& client_plan, + const std::string& checkpoint_input_filename, + const std::string& checkpoint_output_filename, + const fcp::client::InterruptibleRunner::TimingConfig& timing_config) { + if (!client_plan.phase().has_tensorflow_spec()) { + return PlanResultAndCheckpointFile(engine::PlanResult( + 
engine::PlanOutcome::kInvalidArgument, + absl::InvalidArgumentError("Plan must include TensorflowSpec."))); + } + if (!client_plan.phase().has_federated_compute()) { + return PlanResultAndCheckpointFile(engine::PlanResult( + engine::PlanOutcome::kInvalidArgument, + absl::InvalidArgumentError("Invalid TensorflowSpec-based plan"))); + } + + // Get the output tensor names. + absl::StatusOr<std::vector<std::string>> output_names; + output_names = ConstructOutputsWithDeterministicOrder( + client_plan.phase().tensorflow_spec(), + client_plan.phase().federated_compute()); + if (!output_names.ok()) { + return PlanResultAndCheckpointFile(engine::PlanResult( + engine::PlanOutcome::kInvalidArgument, output_names.status())); + } + + // Run plan and get a set of output tensors back. + if (flags->use_tflite_training() && !client_plan.tflite_graph().empty()) { + std::unique_ptr<TfLiteInputs> tflite_inputs = + ConstructTFLiteInputsForTensorflowSpecPlan( + client_plan.phase().federated_compute(), checkpoint_input_filename, + checkpoint_output_filename); + engine::TfLitePlanEngine plan_engine(example_iterator_factories, + should_abort, log_manager, + opstats_logger, flags, &timing_config); + engine::PlanResult plan_result = plan_engine.RunPlan( + client_plan.phase().tensorflow_spec(), client_plan.tflite_graph(), + std::move(tflite_inputs), *output_names); + PlanResultAndCheckpointFile result(std::move(plan_result)); + result.checkpoint_file = checkpoint_output_filename; + + return result; + } + + return PlanResultAndCheckpointFile( + engine::PlanResult(engine::PlanOutcome::kTensorflowError, + absl::InternalError("No plan engine enabled"))); +} +} // namespace + +absl::StatusOr<FLRunnerResult> RunFederatedComputation( + SimpleTaskEnvironment* env_deps, LogManager* log_manager, + const Flags* flags, + const google::internal::federated::plan::ClientOnlyPlan& client_plan, + const std::string& checkpoint_input_filename, + const std::string& checkpoint_output_filename, + const 
std::string& session_name, const std::string& population_name, + const std::string& task_name, + const fcp::client::InterruptibleRunner::TimingConfig& timing_config) { + SelectorContext federated_selector_context; + federated_selector_context.mutable_computation_properties()->set_session_name( + session_name); + FederatedComputation federated_computation; + federated_computation.set_population_name(population_name); + *federated_selector_context.mutable_computation_properties() + ->mutable_federated() = federated_computation; + federated_selector_context.mutable_computation_properties() + ->mutable_federated() + ->set_task_name(task_name); + if (client_plan.phase().has_example_query_spec()) { + federated_selector_context.mutable_computation_properties() + ->set_example_iterator_output_format( + ::fcp::client::QueryTimeComputationProperties:: + EXAMPLE_QUERY_RESULT); + } else { + const auto& federated_compute_io_router = + client_plan.phase().federated_compute(); + const bool has_simpleagg_tensors = + !federated_compute_io_router.output_filepath_tensor_name().empty(); + bool all_aggregations_are_secagg = true; + for (const auto& aggregation : federated_compute_io_router.aggregations()) { + all_aggregations_are_secagg &= + aggregation.second.protocol_config_case() == + AggregationConfig::kSecureAggregation; + } + if (!has_simpleagg_tensors && all_aggregations_are_secagg) { + federated_selector_context.mutable_computation_properties() + ->mutable_federated() + ->mutable_secure_aggregation() + ->set_minimum_clients_in_server_visible_aggregate(100); + } else { + // Has an output checkpoint, so some tensors must be simply aggregated. 
+ *(federated_selector_context.mutable_computation_properties() + ->mutable_federated() + ->mutable_simple_aggregation()) = SimpleAggregation(); + } + } + + auto opstats_logger = + engine::CreateOpStatsLogger(env_deps->GetBaseDir(), flags, log_manager, + session_name, population_name); + + // Check if the device conditions allow for checking in with the server + // and running a federated computation. If not, bail early with the + // transient error retry window. + std::function<bool()> should_abort = [env_deps, &timing_config]() { + return env_deps->ShouldAbort(absl::Now(), timing_config.polling_period); + }; + + // Regular plans can use example iterators from the SimpleTaskEnvironment, + // those reading the OpStats DB, or those serving Federated Select slices. + std::unique_ptr<engine::ExampleIteratorFactory> env_example_iterator_factory = + CreateSimpleTaskEnvironmentIteratorFactory(env_deps, + federated_selector_context); + std::vector<engine::ExampleIteratorFactory*> example_iterator_factories{ + env_example_iterator_factory.get()}; + PlanResultAndCheckpointFile plan_result_and_checkpoint_file = + client_plan.phase().has_example_query_spec() + ? 
RunPlanWithExampleQuerySpec(example_iterator_factories, + opstats_logger.get(), flags, + client_plan, checkpoint_output_filename) + : RunPlanWithTensorflowSpec(example_iterator_factories, should_abort, + log_manager, opstats_logger.get(), flags, + client_plan, checkpoint_input_filename, + checkpoint_output_filename, + timing_config); + auto outcome = plan_result_and_checkpoint_file.plan_result.outcome; + FLRunnerResult fl_runner_result; + + if (outcome == engine::PlanOutcome::kSuccess) { + fl_runner_result.set_contribution_result(FLRunnerResult::SUCCESS); + } else { + fl_runner_result.set_contribution_result(FLRunnerResult::FAIL); + std::string error_message = std::string{ + plan_result_and_checkpoint_file.plan_result.original_status.message()}; + } + return fl_runner_result; +} + +} // namespace client +} // namespace fcp
\ No newline at end of file diff --git a/fcp/client/fcp_runner.h b/fcp/client/fcp_runner.h new file mode 100644 index 0000000..0e6d369 --- /dev/null +++ b/fcp/client/fcp_runner.h @@ -0,0 +1,46 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FCP_CLIENT_FCP_RUNNER_H_ +#define FCP_CLIENT_FCP_RUNNER_H_ + +#include <string> + +#include "fcp/client/fl_runner.pb.h" +#include "fcp/client/flags.h" +#include "fcp/client/interruptible_runner.h" +#include "fcp/client/log_manager.h" +#include "fcp/client/simple_task_environment.h" +#include "fcp/protos/plan.pb.h" + +namespace fcp { +namespace client { + +// Runs a federated computation plan directly on the plan engine, without the +// HTTP protocol parts (exposed for callers that only need to run a plan). +absl::StatusOr<FLRunnerResult> RunFederatedComputation( + SimpleTaskEnvironment* env_deps, LogManager* log_manager, + const Flags* flags, + const google::internal::federated::plan::ClientOnlyPlan& client_plan, + const std::string& checkpoint_input_filename, + const std::string& checkpoint_output_filename, + const std::string& session_name, const std::string& population_name, + const std::string& task_name, + const fcp::client::InterruptibleRunner::TimingConfig& timing_config); + +} // namespace client +} // namespace fcp + +#endif // FCP_CLIENT_FCP_RUNNER_H_
\ No newline at end of file diff --git a/fcp/protos/federatedcompute/common.proto b/fcp/protos/federatedcompute/common.proto index e67bec9..0e6ed03 100644 --- a/fcp/protos/federatedcompute/common.proto +++ b/fcp/protos/federatedcompute/common.proto @@ -150,7 +150,7 @@ message ByteStreamResource { // Copied from //google/rpc/status.proto. message Status { // The status code, which should be an enum value of [google.rpc.Code][]. - int32 code = 1; + Code code = 1; string message = 2; } |