diff options
Diffstat (limited to 'fcp/aggregation/core/federated_mean.cc')
-rw-r--r-- | fcp/aggregation/core/federated_mean.cc | 101 |
1 files changed, 78 insertions, 23 deletions
diff --git a/fcp/aggregation/core/federated_mean.cc b/fcp/aggregation/core/federated_mean.cc index 5e3d35c..b7c21d3 100644 --- a/fcp/aggregation/core/federated_mean.cc +++ b/fcp/aggregation/core/federated_mean.cc @@ -16,19 +16,21 @@ #include <cstddef> #include <memory> +#include <string> #include <utility> #include <vector> +#include "fcp/aggregation/core/agg_core.pb.h" #include "fcp/aggregation/core/agg_vector.h" #include "fcp/aggregation/core/datatype.h" #include "fcp/aggregation/core/input_tensor_list.h" #include "fcp/aggregation/core/intrinsic.h" #include "fcp/aggregation/core/mutable_vector_data.h" #include "fcp/aggregation/core/tensor.h" +#include "fcp/aggregation/core/tensor.pb.h" #include "fcp/aggregation/core/tensor_aggregator.h" #include "fcp/aggregation/core/tensor_aggregator_factory.h" #include "fcp/aggregation/core/tensor_aggregator_registry.h" -#include "fcp/aggregation/core/tensor_data.h" #include "fcp/aggregation/core/tensor_shape.h" #include "fcp/aggregation/core/tensor_spec.h" #include "fcp/base/monitoring.h" @@ -42,13 +44,29 @@ constexpr char kFederatedWeightedMeanUri[] = "federated_weighted_mean"; template <typename V, typename W> class FederatedMean final : public TensorAggregator { public: - explicit FederatedMean(DataType dtype, TensorShape shape, - MutableVectorData<V>* weighted_values_sum) - : weighted_values_sum_(*weighted_values_sum), - result_tensor_( - Tensor::Create(dtype, shape, - std::unique_ptr<TensorData>(weighted_values_sum)) - .value()) {} + explicit FederatedMean( + DataType dtype, TensorShape shape, + std::unique_ptr<MutableVectorData<V>> weighted_values_sum) + : FederatedMean(dtype, shape, std::move(weighted_values_sum), 0, 0) {} + + FederatedMean(DataType dtype, TensorShape shape, + std::unique_ptr<MutableVectorData<V>> weighted_values_sum, + W weights_sum, int num_inputs) + : dtype_(dtype), + shape_(std::move(shape)), + weighted_values_sum_(std::move(weighted_values_sum)), + weights_sum_(weights_sum), + num_inputs_(num_inputs) {} + + StatusOr<std::string> Serialize() && override { + FederatedMeanAggregatorState aggregator_state; + aggregator_state.set_num_inputs(num_inputs_); + *(aggregator_state.mutable_weighted_values_sum()) = + weighted_values_sum_->EncodeContent(); + *(aggregator_state.mutable_weights_sum()) = std::string( + reinterpret_cast<char*>(&weights_sum_), sizeof(weights_sum_)); + return aggregator_state.SerializeAsString(); + } private: Status MergeWith(TensorAggregator&& other) override { @@ -61,16 +79,16 @@ class FederatedMean final : public TensorAggregator { } FCP_RETURN_IF_ERROR((*other_ptr).CheckValid()); - std::pair<std::vector<V>, W> other_internal_state = + std::pair<std::unique_ptr<MutableVectorData<V>>, W> other_internal_state = other_ptr->GetInternalState(); - if (other_internal_state.first.size() != weighted_values_sum_.size()) { + if (other_internal_state.first->size() != weighted_values_sum_->size()) { return FCP_STATUS(INVALID_ARGUMENT) << "FederatedMean::MergeWith: Can only merge weighted value sum " "tensors of equal length."; } - for (int i = 0; i < weighted_values_sum_.size(); ++i) { - weighted_values_sum_[i] += other_internal_state.first[i]; + for (int i = 0; i < weighted_values_sum_->size(); ++i) { + (*weighted_values_sum_)[i] += (*other_internal_state.first)[i]; } weights_sum_ += other_internal_state.second; num_inputs_ += other_ptr->GetNumInputs(); @@ -101,12 +119,12 @@ class FederatedMean final : public TensorAggregator { "weights are allowed."; } for (auto value : values) { - weighted_values_sum_[value.index] += value.value * weight; + (*weighted_values_sum_)[value.index] += value.value * weight; } weights_sum_ += weight; } else { for (auto value : values) { - weighted_values_sum_[value.index] += value.value; + (*weighted_values_sum_)[value.index] += value.value; } } num_inputs_++; @@ -126,29 +144,32 @@ class FederatedMean final : public TensorAggregator { // Produce the final weighted mean values by dividing the weighted values // sum by the weights sum (tracked by weights_sum_ in the weighted case and // num_inputs_ in the non-weighted case). - for (int i = 0; i < weighted_values_sum_.size(); ++i) { - weighted_values_sum_[i] /= + for (int i = 0; i < weighted_values_sum_->size(); ++i) { + (*weighted_values_sum_)[i] /= (weights_sum_ > 0 ? weights_sum_ : num_inputs_); } OutputTensorList outputs = std::vector<Tensor>(); - outputs.push_back(std::move(result_tensor_)); + outputs.push_back( + Tensor::Create(dtype_, shape_, std::move(weighted_values_sum_)) + .value()); return outputs; } int GetNumInputs() const override { return num_inputs_; } - std::pair<std::vector<V>, W> GetInternalState() { + std::pair<std::unique_ptr<MutableVectorData<V>>, W> GetInternalState() { output_consumed_ = true; return std::make_pair(std::move(weighted_values_sum_), weights_sum_); } bool output_consumed_ = false; - std::vector<V>& weighted_values_sum_; + DataType dtype_; + TensorShape shape_; + std::unique_ptr<MutableVectorData<V>> weighted_values_sum_; // In the weighted case, use the weights_sum_ variable to track the total // weight. Otherwise, just rely on the num_inputs_ variable. - W weights_sum_ = 0; - Tensor result_tensor_; - int num_inputs_ = 0; + W weights_sum_; + int num_inputs_; }; // Factory class for the FederatedMean. @@ -162,6 +183,24 @@ class FederatedMeanFactory final : public TensorAggregatorFactory { StatusOr<std::unique_ptr<TensorAggregator>> Create( const Intrinsic& intrinsic) const override { + return CreateInternal(intrinsic, nullptr); + } + + StatusOr<std::unique_ptr<TensorAggregator>> Deserialize( + const Intrinsic& intrinsic, std::string serialized_state) const override { + FederatedMeanAggregatorState aggregator_state; + if (!aggregator_state.ParseFromString(serialized_state)) { + return FCP_STATUS(INVALID_ARGUMENT) + << "FederatedMeanFactory::Deserialize: Failed to parse " + "FederatedMeanAggregatorState."; + } + return CreateInternal(intrinsic, &aggregator_state); + } + + private: + StatusOr<std::unique_ptr<TensorAggregator>> CreateInternal( + const Intrinsic& intrinsic, + const FederatedMeanAggregatorState* aggregator_state) const { // Check that the configuration is valid. if (kFederatedMeanUri == intrinsic.uri) { if (intrinsic.inputs.size() != 1) { @@ -232,13 +271,29 @@ class FederatedMeanFactory final : public TensorAggregatorFactory { } std::unique_ptr<TensorAggregator> aggregator; + if (aggregator_state == nullptr) { + FLOATING_ONLY_DTYPE_CASES( + input_value_type, V, + NUMERICAL_ONLY_DTYPE_CASES( + input_weight_type, W, + aggregator = (std::make_unique<FederatedMean<V, W>>( + input_value_type, input_value_spec.shape(), + std::make_unique<MutableVectorData<V>>( + value_num_elements.value()))))); + return aggregator; + } + FLOATING_ONLY_DTYPE_CASES( input_value_type, V, NUMERICAL_ONLY_DTYPE_CASES( input_weight_type, W, aggregator = (std::make_unique<FederatedMean<V, W>>( input_value_type, input_value_spec.shape(), - new MutableVectorData<V>(value_num_elements.value()))))); + MutableVectorData<V>::CreateFromEncodedContent( + aggregator_state->weighted_values_sum()), + *(reinterpret_cast<const W*>( + aggregator_state->weights_sum().data())), + aggregator_state->num_inputs())))); return aggregator; } }; |