aboutsummaryrefslogtreecommitdiff
path: root/fcp/aggregation/core/federated_mean.cc
diff options
context:
space:
mode:
Diffstat (limited to 'fcp/aggregation/core/federated_mean.cc')
-rw-r--r--fcp/aggregation/core/federated_mean.cc101
1 files changed, 78 insertions, 23 deletions
diff --git a/fcp/aggregation/core/federated_mean.cc b/fcp/aggregation/core/federated_mean.cc
index 5e3d35c..b7c21d3 100644
--- a/fcp/aggregation/core/federated_mean.cc
+++ b/fcp/aggregation/core/federated_mean.cc
@@ -16,19 +16,21 @@
#include <cstddef>
#include <memory>
+#include <string>
#include <utility>
#include <vector>
+#include "fcp/aggregation/core/agg_core.pb.h"
#include "fcp/aggregation/core/agg_vector.h"
#include "fcp/aggregation/core/datatype.h"
#include "fcp/aggregation/core/input_tensor_list.h"
#include "fcp/aggregation/core/intrinsic.h"
#include "fcp/aggregation/core/mutable_vector_data.h"
#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/core/tensor.pb.h"
#include "fcp/aggregation/core/tensor_aggregator.h"
#include "fcp/aggregation/core/tensor_aggregator_factory.h"
#include "fcp/aggregation/core/tensor_aggregator_registry.h"
-#include "fcp/aggregation/core/tensor_data.h"
#include "fcp/aggregation/core/tensor_shape.h"
#include "fcp/aggregation/core/tensor_spec.h"
#include "fcp/base/monitoring.h"
@@ -42,13 +44,29 @@ constexpr char kFederatedWeightedMeanUri[] = "federated_weighted_mean";
template <typename V, typename W>
class FederatedMean final : public TensorAggregator {
public:
- explicit FederatedMean(DataType dtype, TensorShape shape,
- MutableVectorData<V>* weighted_values_sum)
- : weighted_values_sum_(*weighted_values_sum),
- result_tensor_(
- Tensor::Create(dtype, shape,
- std::unique_ptr<TensorData>(weighted_values_sum))
- .value()) {}
+ explicit FederatedMean(
+ DataType dtype, TensorShape shape,
+ std::unique_ptr<MutableVectorData<V>> weighted_values_sum)
+ : FederatedMean(dtype, shape, std::move(weighted_values_sum), 0, 0) {}
+
+ FederatedMean(DataType dtype, TensorShape shape,
+ std::unique_ptr<MutableVectorData<V>> weighted_values_sum,
+ W weights_sum, int num_inputs)
+ : dtype_(dtype),
+ shape_(std::move(shape)),
+ weighted_values_sum_(std::move(weighted_values_sum)),
+ weights_sum_(weights_sum),
+ num_inputs_(num_inputs) {}
+
+ StatusOr<std::string> Serialize() && override {
+ FederatedMeanAggregatorState aggregator_state;
+ aggregator_state.set_num_inputs(num_inputs_);
+ *(aggregator_state.mutable_weighted_values_sum()) =
+ weighted_values_sum_->EncodeContent();
+ *(aggregator_state.mutable_weights_sum()) = std::string(
+ reinterpret_cast<char*>(&weights_sum_), sizeof(weights_sum_));
+ return aggregator_state.SerializeAsString();
+ }
private:
Status MergeWith(TensorAggregator&& other) override {
@@ -61,16 +79,16 @@ class FederatedMean final : public TensorAggregator {
}
FCP_RETURN_IF_ERROR((*other_ptr).CheckValid());
- std::pair<std::vector<V>, W> other_internal_state =
+ std::pair<std::unique_ptr<MutableVectorData<V>>, W> other_internal_state =
other_ptr->GetInternalState();
- if (other_internal_state.first.size() != weighted_values_sum_.size()) {
+ if (other_internal_state.first->size() != weighted_values_sum_->size()) {
return FCP_STATUS(INVALID_ARGUMENT)
<< "FederatedMean::MergeWith: Can only merge weighted value sum "
"tensors of equal length.";
}
- for (int i = 0; i < weighted_values_sum_.size(); ++i) {
- weighted_values_sum_[i] += other_internal_state.first[i];
+ for (int i = 0; i < weighted_values_sum_->size(); ++i) {
+ (*weighted_values_sum_)[i] += (*other_internal_state.first)[i];
}
weights_sum_ += other_internal_state.second;
num_inputs_ += other_ptr->GetNumInputs();
@@ -101,12 +119,12 @@ class FederatedMean final : public TensorAggregator {
"weights are allowed.";
}
for (auto value : values) {
- weighted_values_sum_[value.index] += value.value * weight;
+ (*weighted_values_sum_)[value.index] += value.value * weight;
}
weights_sum_ += weight;
} else {
for (auto value : values) {
- weighted_values_sum_[value.index] += value.value;
+ (*weighted_values_sum_)[value.index] += value.value;
}
}
num_inputs_++;
@@ -126,29 +144,32 @@ class FederatedMean final : public TensorAggregator {
// Produce the final weighted mean values by dividing the weighted values
// sum by the weights sum (tracked by weights_sum_ in the weighted case and
// num_inputs_ in the non-weighted case).
- for (int i = 0; i < weighted_values_sum_.size(); ++i) {
- weighted_values_sum_[i] /=
+ for (int i = 0; i < weighted_values_sum_->size(); ++i) {
+ (*weighted_values_sum_)[i] /=
(weights_sum_ > 0 ? weights_sum_ : num_inputs_);
}
OutputTensorList outputs = std::vector<Tensor>();
- outputs.push_back(std::move(result_tensor_));
+ outputs.push_back(
+ Tensor::Create(dtype_, shape_, std::move(weighted_values_sum_))
+ .value());
return outputs;
}
int GetNumInputs() const override { return num_inputs_; }
- std::pair<std::vector<V>, W> GetInternalState() {
+ std::pair<std::unique_ptr<MutableVectorData<V>>, W> GetInternalState() {
output_consumed_ = true;
return std::make_pair(std::move(weighted_values_sum_), weights_sum_);
}
bool output_consumed_ = false;
- std::vector<V>& weighted_values_sum_;
+ DataType dtype_;
+ TensorShape shape_;
+ std::unique_ptr<MutableVectorData<V>> weighted_values_sum_;
// In the weighted case, use the weights_sum_ variable to track the total
// weight. Otherwise, just rely on the num_inputs_ variable.
- W weights_sum_ = 0;
- Tensor result_tensor_;
- int num_inputs_ = 0;
+ W weights_sum_;
+ int num_inputs_;
};
// Factory class for the FederatedMean.
@@ -162,6 +183,24 @@ class FederatedMeanFactory final : public TensorAggregatorFactory {
StatusOr<std::unique_ptr<TensorAggregator>> Create(
const Intrinsic& intrinsic) const override {
+ return CreateInternal(intrinsic, nullptr);
+ }
+
+ StatusOr<std::unique_ptr<TensorAggregator>> Deserialize(
+ const Intrinsic& intrinsic, std::string serialized_state) const override {
+ FederatedMeanAggregatorState aggregator_state;
+ if (!aggregator_state.ParseFromString(serialized_state)) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "FederatedMeanFactory::Deserialize: Failed to parse "
+ "FederatedMeanAggregatorState.";
+ }
+ return CreateInternal(intrinsic, &aggregator_state);
+ }
+
+ private:
+ StatusOr<std::unique_ptr<TensorAggregator>> CreateInternal(
+ const Intrinsic& intrinsic,
+ const FederatedMeanAggregatorState* aggregator_state) const {
// Check that the configuration is valid.
if (kFederatedMeanUri == intrinsic.uri) {
if (intrinsic.inputs.size() != 1) {
@@ -232,13 +271,29 @@ class FederatedMeanFactory final : public TensorAggregatorFactory {
}
std::unique_ptr<TensorAggregator> aggregator;
+ if (aggregator_state == nullptr) {
+ FLOATING_ONLY_DTYPE_CASES(
+ input_value_type, V,
+ NUMERICAL_ONLY_DTYPE_CASES(
+ input_weight_type, W,
+ aggregator = (std::make_unique<FederatedMean<V, W>>(
+ input_value_type, input_value_spec.shape(),
+ std::make_unique<MutableVectorData<V>>(
+ value_num_elements.value())))));
+ return aggregator;
+ }
+
FLOATING_ONLY_DTYPE_CASES(
input_value_type, V,
NUMERICAL_ONLY_DTYPE_CASES(
input_weight_type, W,
aggregator = (std::make_unique<FederatedMean<V, W>>(
input_value_type, input_value_spec.shape(),
- new MutableVectorData<V>(value_num_elements.value())))));
+ MutableVectorData<V>::CreateFromEncodedContent(
+ aggregator_state->weighted_values_sum()),
+ *(reinterpret_cast<const W*>(
+ aggregator_state->weights_sum().data())),
+ aggregator_state->num_inputs()))));
return aggregator;
}
};