diff options
Diffstat (limited to 'fcp/client/http/http_secagg_send_to_server_impl.h')
-rw-r--r-- | fcp/client/http/http_secagg_send_to_server_impl.h | 181 |
1 files changed, 181 insertions, 0 deletions
diff --git a/fcp/client/http/http_secagg_send_to_server_impl.h b/fcp/client/http/http_secagg_send_to_server_impl.h new file mode 100644 index 0000000..8fc61a8 --- /dev/null +++ b/fcp/client/http/http_secagg_send_to_server_impl.h @@ -0,0 +1,181 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FCP_CLIENT_HTTP_HTTP_SECAGG_SEND_TO_SERVER_IMPL_H_ +#define FCP_CLIENT_HTTP_HTTP_SECAGG_SEND_TO_SERVER_IMPL_H_ + +#include <functional> +#include <memory> +#include <optional> +#include <string> +#include <utility> + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "fcp/client/http/protocol_request_helper.h" +#include "fcp/client/secagg_event_publisher.h" +#include "fcp/client/secagg_runner.h" +#include "fcp/protos/federatedcompute/common.pb.h" +#include "fcp/protos/federatedcompute/secure_aggregations.pb.h" +#include "fcp/secagg/shared/secagg_messages.pb.h" + +namespace fcp { +namespace client { +namespace http { + +// Implementation of SecAggSendToServerBase for HTTP federated protocol. +class HttpSecAggSendToServerImpl : public SecAggSendToServerBase { + public: + // Create an instance of HttpSecAggSendToServerImpl. + // This method returns error status when failed to create + // ProtocolRequestCreator based on the input ForwardingInfo or + // ByteStreamResources. + static absl::StatusOr<std::unique_ptr<HttpSecAggSendToServerImpl>> Create( + absl::string_view api_key, Clock* clock, + ProtocolRequestHelper* request_helper, + InterruptibleRunner* interruptible_runner, + std::function<std::unique_ptr<InterruptibleRunner>(absl::Time)> + delayed_interruptible_runner_creator, + absl::StatusOr<secagg::ServerToClientWrapperMessage>* + server_response_holder, + absl::string_view aggregation_id, absl::string_view client_token, + const google::internal::federatedcompute::v1::ForwardingInfo& + secagg_upload_forwarding_info, + const google::internal::federatedcompute::v1::ByteStreamResource& + masked_result_resource, + const google::internal::federatedcompute::v1::ByteStreamResource& + nonmasked_result_resource, + std::optional<std::string> tf_checkpoint, + bool disable_request_body_compression, + absl::Duration waiting_period_for_cancellation); + ~HttpSecAggSendToServerImpl() override = default; + + // Sends a client to server request based on the + // secagg::ClientToServerWrapperMessage, waits for the response, and set it to + // the server response holder. + void Send(secagg::ClientToServerWrapperMessage* message) override; + + private: + HttpSecAggSendToServerImpl( + absl::string_view api_key, Clock* clock, + ProtocolRequestHelper* request_helper, + InterruptibleRunner* interruptible_runner, + std::function<std::unique_ptr<InterruptibleRunner>(absl::Time)> + delayed_interruptible_runner_creator, + absl::StatusOr<secagg::ServerToClientWrapperMessage>* + server_response_holder, + absl::string_view aggregation_id, absl::string_view client_token, + absl::string_view masked_resource_name, + absl::string_view nonmasked_resource_name, + std::unique_ptr<ProtocolRequestCreator> secagg_request_creator, + std::unique_ptr<ProtocolRequestCreator> + masked_result_upload_request_creator, + std::unique_ptr<ProtocolRequestCreator> + nonmasked_result_upload_request_creator, + std::optional<std::string> tf_checkpoint, + absl::Duration waiting_period_for_cancellation) + : api_key_(api_key), + clock_(*clock), + request_helper_(*request_helper), + interruptible_runner_(*interruptible_runner), + delayed_interruptible_runner_creator_( + delayed_interruptible_runner_creator), + server_response_holder_(*server_response_holder), + aggregation_id_(std::string(aggregation_id)), + client_token_(std::string(client_token)), + masked_resource_name_(std::string(masked_resource_name)), + nonmasked_resource_name_(std::string(nonmasked_resource_name)), + secagg_request_creator_(std::move(secagg_request_creator)), + masked_result_upload_request_creator_( + std::move(masked_result_upload_request_creator)), + nonmasked_result_upload_request_creator_( + std::move(nonmasked_result_upload_request_creator)), + tf_checkpoint_(std::move(tf_checkpoint)), + waiting_period_for_cancellation_(waiting_period_for_cancellation) {} + + // Sends an AbortSecureAggregationRequest. + absl::StatusOr<secagg::ServerToClientWrapperMessage> AbortSecureAggregation( + secagg::AbortMessage abort_message); + // Sends an AdvertiseKeysRequest and waits for the AdvertiseKeysResponse, + // polling the corresponding LRO if needed. + absl::StatusOr<secagg::ServerToClientWrapperMessage> DoR0AdvertiseKeys( + secagg::AdvertiseKeys advertise_keys); + // Sends an ShareKeysRequest and waits for the ShareKeysResponse, polling + // the corresponding LRO if needed. + absl::StatusOr<secagg::ServerToClientWrapperMessage> DoR1ShareKeys( + secagg::ShareKeysResponse share_keys_response); + // Uploads masked resource and (optional) nonmasked resource. After successful + // upload, sends an SubmitSecureAggregationResultRequest and waits for the + // SubmitSecureAggregationResultResponse, polling the corresponding LRO if + // needed. + absl::StatusOr<secagg::ServerToClientWrapperMessage> + DoR2SubmitSecureAggregationResult( + secagg::MaskedInputCollectionResponse masked_input_response); + // Sends an UnmaskRequest and waits for the UnmaskResponse. + absl::StatusOr<secagg::ServerToClientWrapperMessage> DoR3Unmask( + secagg::UnmaskingResponse unmasking_response); + const std::string api_key_; + Clock& clock_; + ProtocolRequestHelper& request_helper_; + InterruptibleRunner& interruptible_runner_; + std::function<std::unique_ptr<InterruptibleRunner>(absl::Time)> + delayed_interruptible_runner_creator_; + absl::StatusOr<secagg::ServerToClientWrapperMessage>& server_response_holder_; + std::string aggregation_id_; + std::string client_token_; + std::string masked_resource_name_; + std::string nonmasked_resource_name_; + std::unique_ptr<ProtocolRequestCreator> secagg_request_creator_; + std::unique_ptr<ProtocolRequestCreator> masked_result_upload_request_creator_; + std::unique_ptr<ProtocolRequestCreator> + nonmasked_result_upload_request_creator_; + std::optional<std::string> tf_checkpoint_; + absl::Duration waiting_period_for_cancellation_; +}; + +// Implementation of SecAggProtocolDelegate for the HTTP federated protocol. +class HttpSecAggProtocolDelegate : public SecAggProtocolDelegate { + public: + HttpSecAggProtocolDelegate( + google::protobuf::Map< + std::string, + google::internal::federatedcompute::v1::SecureAggregandExecutionInfo> + secure_aggregands, + absl::StatusOr<secagg::ServerToClientWrapperMessage>* + server_response_holder) + : secure_aggregands_(std::move(secure_aggregands)), + server_response_holder_(*server_response_holder) {} + // Retrieve the modulus for a given SecAgg vector. + absl::StatusOr<uint64_t> GetModulus(const std::string& key) override; + // Receive Server message. + absl::StatusOr<secagg::ServerToClientWrapperMessage> ReceiveServerMessage() + override; + // Called when the SecAgg protocol is interrupted. + void Abort() override; + size_t last_received_message_size() override; + + private: + google::protobuf::Map< + std::string, + google::internal::federatedcompute::v1::SecureAggregandExecutionInfo> + secure_aggregands_; + absl::StatusOr<secagg::ServerToClientWrapperMessage>& server_response_holder_; +}; + +} // namespace http +} // namespace client +} // namespace fcp + +#endif // FCP_CLIENT_HTTP_HTTP_SECAGG_SEND_TO_SERVER_IMPL_H_ |