aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMaya Spivak <mspivak@google.com>2024-04-08 15:24:11 -0700
committerCopybara-Service <copybara-worker@google.com>2024-04-08 15:25:07 -0700
commit5fb2bf43a96ef4ba83ed28ad218eb6e036dc55f4 (patch)
tree57b9917559d4dc87b5f4567e985bd911b46a17e7
parent3a84a0625d90657928ea2aed4384e2448d84280f (diff)
downloadfederated-compute-5fb2bf43a96ef4ba83ed28ad218eb6e036dc55f4.tar.gz
Add serialization and deserialization methods to TensorAggregator and TensorAggregatorFactory respectively.
Add serialization implementation for AggVectorAggregator and deserialization implementation for FederatedSum which is currently the only use of AggVectorAggregator. Implementation for other aggregator types will be added in subsequent changes. PiperOrigin-RevId: 622967226
-rw-r--r--fcp/aggregation/core/BUILD16
-rw-r--r--fcp/aggregation/core/agg_core.proto9
-rw-r--r--fcp/aggregation/core/agg_vector_aggregator.h64
-rw-r--r--fcp/aggregation/core/agg_vector_aggregator_test.cc31
-rw-r--r--fcp/aggregation/core/federated_sum.cc33
-rw-r--r--fcp/aggregation/core/federated_sum_test.cc26
-rw-r--r--fcp/aggregation/core/mutable_vector_data.h17
-rw-r--r--fcp/aggregation/core/mutable_vector_data_test.cc16
-rw-r--r--fcp/aggregation/core/tensor_aggregator.h7
-rw-r--r--fcp/aggregation/core/tensor_aggregator_factory.h12
-rw-r--r--fcp/aggregation/core/tensor_aggregator_registry.cc9
-rw-r--r--fcp/aggregation/core/tensor_aggregator_registry.h5
12 files changed, 220 insertions, 25 deletions
diff --git a/fcp/aggregation/core/BUILD b/fcp/aggregation/core/BUILD
index 0e9cf8a..841de08 100644
--- a/fcp/aggregation/core/BUILD
+++ b/fcp/aggregation/core/BUILD
@@ -51,6 +51,16 @@ py_proto_library(
],
)
+proto_library(
+ name = "agg_core_proto",
+ srcs = ["agg_core.proto"],
+)
+
+cc_proto_library(
+ name = "agg_core_cc_proto",
+ deps = [":agg_core_proto"],
+)
+
cc_library(
name = "tensor",
srcs = [
@@ -102,9 +112,12 @@ cc_library(
],
copts = FCP_COPTS,
deps = [
+ ":agg_core_cc_proto",
":tensor",
+ ":tensor_cc_proto",
"//fcp/base",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/strings:cord",
"@com_google_absl//absl/synchronization",
],
)
@@ -131,6 +144,7 @@ cc_library(
],
copts = FCP_COPTS,
deps = [
+ ":agg_core_cc_proto",
":aggregator",
":fedsql_constants",
":tensor",
@@ -184,6 +198,7 @@ cc_test(
],
copts = FCP_COPTS,
deps = [
+ ":agg_core_cc_proto",
":aggregator",
":tensor",
":tensor_cc_proto",
@@ -253,6 +268,7 @@ cc_test(
srcs = ["federated_sum_test.cc"],
copts = FCP_COPTS,
deps = [
+ ":agg_core_cc_proto",
":aggregation_cores",
":aggregator",
":tensor",
diff --git a/fcp/aggregation/core/agg_core.proto b/fcp/aggregation/core/agg_core.proto
new file mode 100644
index 0000000..0ce048e
--- /dev/null
+++ b/fcp/aggregation/core/agg_core.proto
@@ -0,0 +1,9 @@
+syntax = "proto3";
+
+package fcp.aggregation;
+
+// Internal state representation of an AggVectorAggregator.
+message AggVectorAggregatorState {
+ uint64 num_inputs = 1;
+ bytes vector_data = 2;
+}
diff --git a/fcp/aggregation/core/agg_vector_aggregator.h b/fcp/aggregation/core/agg_vector_aggregator.h
index ec40dea..3fbb88f 100644
--- a/fcp/aggregation/core/agg_vector_aggregator.h
+++ b/fcp/aggregation/core/agg_vector_aggregator.h
@@ -17,18 +17,21 @@
#ifndef FCP_AGGREGATION_CORE_AGG_VECTOR_AGGREGATOR_H_
#define FCP_AGGREGATION_CORE_AGG_VECTOR_AGGREGATOR_H_
+#include <cstddef>
#include <cstdint>
#include <memory>
+#include <string>
#include <utility>
#include <vector>
+#include "fcp/aggregation/core/agg_core.pb.h"
#include "fcp/aggregation/core/agg_vector.h"
#include "fcp/aggregation/core/datatype.h"
#include "fcp/aggregation/core/input_tensor_list.h"
#include "fcp/aggregation/core/mutable_vector_data.h"
#include "fcp/aggregation/core/tensor.h"
+#include "fcp/aggregation/core/tensor.pb.h"
#include "fcp/aggregation/core/tensor_aggregator.h"
-#include "fcp/aggregation/core/tensor_data.h"
#include "fcp/aggregation/core/tensor_shape.h"
#include "fcp/base/monitoring.h"
@@ -41,10 +44,21 @@ template <typename T>
class AggVectorAggregator : public TensorAggregator {
public:
AggVectorAggregator(DataType dtype, TensorShape shape)
- : AggVectorAggregator(dtype, shape, CreateData(shape)) {}
+ : AggVectorAggregator(dtype, shape, CreateData(shape), 0) {}
+
+ AggVectorAggregator(DataType dtype, TensorShape shape,
+ std::unique_ptr<MutableVectorData<T>> data,
+ int num_inputs)
+ : dtype_(dtype),
+ shape_(std::move(shape)),
+ data_vector_(std::move(data)),
+ num_inputs_(num_inputs) {
+ FCP_CHECK(internal::TypeTraits<T>::kDataType == dtype)
+ << "Incompatible dtype";
+ }
// Provides mutable access to the aggregator data as a vector<T>
- inline std::vector<T>& data() { return data_vector_; }
+ inline std::vector<T>& data() { return *data_vector_; }
int GetNumInputs() const override { return num_inputs_; }
@@ -58,7 +72,7 @@ class AggVectorAggregator : public TensorAggregator {
<< "AggVectorAggregator::MergeOutputTensors: AggVectorAggregator "
"should produce a single output tensor";
const Tensor& output = output_tensors[0];
- if (output.shape() != result_tensor_.shape()) {
+ if (output.shape() != shape_) {
return FCP_STATUS(INVALID_ARGUMENT)
<< "AggVectorAggregator::MergeOutputTensors: tensor shape "
"mismatch";
@@ -70,6 +84,13 @@ class AggVectorAggregator : public TensorAggregator {
return FCP_STATUS(OK);
}
+ StatusOr<std::string> Serialize() && override {
+ AggVectorAggregatorState aggregator_state;
+ aggregator_state.set_num_inputs(num_inputs_);
+ *(aggregator_state.mutable_vector_data()) = data_vector_->EncodeContent();
+ return aggregator_state.SerializeAsString();
+ }
+
protected:
// Implementation of the tensor aggregation.
Status AggregateTensors(InputTensorList tensors) override {
@@ -81,7 +102,7 @@ class AggVectorAggregator : public TensorAggregator {
return FCP_STATUS(INVALID_ARGUMENT)
<< "AggVectorAggregator::AggregateTensors: dtype mismatch";
}
- if (tensor->shape() != result_tensor_.shape()) {
+ if (tensor->shape() != shape_) {
return FCP_STATUS(INVALID_ARGUMENT)
<< "AggVectorAggregator::AggregateTensors: tensor shape mismatch";
}
@@ -92,11 +113,19 @@ class AggVectorAggregator : public TensorAggregator {
return FCP_STATUS(OK);
}
- Status CheckValid() const override { return result_tensor_.CheckValid(); }
+ Status CheckValid() const override {
+ if (data_vector_ == nullptr) {
+ return FCP_STATUS(FAILED_PRECONDITION)
+ << "AggVectorAggregator::CheckValid: Output has already been "
+ << "consumed.";
+ }
+ return FCP_STATUS(OK);
+ }
OutputTensorList TakeOutputs() && override {
OutputTensorList outputs = std::vector<Tensor>();
- outputs.push_back(std::move(result_tensor_));
+ outputs.push_back(
+ Tensor::Create(dtype_, shape_, std::move(data_vector_)).value());
return outputs;
}
@@ -104,22 +133,12 @@ class AggVectorAggregator : public TensorAggregator {
virtual void AggregateVector(const AggVector<T>& agg_vector) = 0;
private:
- AggVectorAggregator(DataType dtype, TensorShape shape,
- MutableVectorData<T>* data)
- : result_tensor_(
- Tensor::Create(dtype, shape, std::unique_ptr<TensorData>(data))
- .value()),
- data_vector_(*data),
- num_inputs_(0) {
- FCP_CHECK(internal::TypeTraits<T>::kDataType == dtype)
- << "Incompatible dtype";
- }
-
- static MutableVectorData<T>* CreateData(const TensorShape& shape) {
+ static std::unique_ptr<MutableVectorData<T>> CreateData(
+ const TensorShape& shape) {
StatusOr<size_t> num_elements = shape.NumElements();
FCP_CHECK(num_elements.ok()) << "AggVectorAggregator: All dimensions of "
"tensor shape must be known in advance.";
- return new MutableVectorData<T>(num_elements.value());
+ return std::make_unique<MutableVectorData<T>>(num_elements.value());
}
StatusOr<AggVectorAggregator<T>*> CastOther(TensorAggregator& other) {
@@ -134,8 +153,9 @@ class AggVectorAggregator : public TensorAggregator {
return other_ptr;
}
- Tensor result_tensor_;
- std::vector<T>& data_vector_;
+ const DataType dtype_;
+ const TensorShape shape_;
+ std::unique_ptr<MutableVectorData<T>> data_vector_;
int num_inputs_;
};
diff --git a/fcp/aggregation/core/agg_vector_aggregator_test.cc b/fcp/aggregation/core/agg_vector_aggregator_test.cc
index cd1e098..69596cf 100644
--- a/fcp/aggregation/core/agg_vector_aggregator_test.cc
+++ b/fcp/aggregation/core/agg_vector_aggregator_test.cc
@@ -17,10 +17,13 @@
#include "fcp/aggregation/core/agg_vector_aggregator.h"
#include <cstdint>
+#include <string>
#include <utility>
+#include <vector>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "fcp/aggregation/core/agg_core.pb.h"
#include "fcp/aggregation/core/input_tensor_list.h"
#include "fcp/aggregation/core/tensor.h"
#include "fcp/aggregation/core/tensor.pb.h"
@@ -166,6 +169,34 @@ TEST(AggVectorAggregatorTest, TypeCheckFailure) {
EXPECT_DEATH(new SumAggregator<float>(DT_INT32, {}), "Incompatible dtype");
}
+TEST(AggVectorAggregatorTest, Serialization_Succeeds) {
+ const TensorShape shape = {4};
+ SumAggregator<int32_t> aggregator(DT_INT32, shape);
+ Tensor t1 =
+ Tensor::Create(DT_INT32, shape, CreateTestData({1, 3, 15, 27})).value();
+ Tensor t2 =
+ Tensor::Create(DT_INT32, shape, CreateTestData({10, 5, 1, 2})).value();
+ Tensor t3 =
+ Tensor::Create(DT_INT32, shape, CreateTestData({3, 11, 7, 20})).value();
+ EXPECT_THAT(aggregator.Accumulate(t1), IsOk());
+ EXPECT_THAT(aggregator.Accumulate(t2), IsOk());
+ EXPECT_THAT(aggregator.Accumulate(t3), IsOk());
+ EXPECT_THAT(aggregator.CanReport(), IsTrue());
+ EXPECT_THAT(aggregator.GetNumInputs(), Eq(3));
+
+ auto serialized_state = std::move(aggregator).Serialize();
+
+ AggVectorAggregatorState aggregator_state;
+ aggregator_state.ParseFromString(serialized_state.value());
+ EXPECT_THAT(aggregator_state.num_inputs(), Eq(3));
+ const int32_t* vector_data =
+ reinterpret_cast<const int32_t*>(aggregator_state.vector_data().data());
+ std::vector<int32_t> data(
+ vector_data,
+ vector_data + aggregator_state.vector_data().size() / sizeof(int32_t));
+ EXPECT_EQ(data, std::vector<int32_t>({14, 19, 23, 49}));
+}
+
} // namespace
} // namespace aggregation
} // namespace fcp
diff --git a/fcp/aggregation/core/federated_sum.cc b/fcp/aggregation/core/federated_sum.cc
index 68cc0f7..a9a3734 100644
--- a/fcp/aggregation/core/federated_sum.cc
+++ b/fcp/aggregation/core/federated_sum.cc
@@ -18,9 +18,11 @@
#include <string>
#include <utility>
+#include "fcp/aggregation/core/agg_core.pb.h"
#include "fcp/aggregation/core/agg_vector_aggregator.h"
#include "fcp/aggregation/core/datatype.h"
#include "fcp/aggregation/core/intrinsic.h"
+#include "fcp/aggregation/core/mutable_vector_data.h"
#include "fcp/aggregation/core/tensor_aggregator.h"
#include "fcp/aggregation/core/tensor_aggregator_factory.h"
#include "fcp/aggregation/core/tensor_aggregator_registry.h"
@@ -59,6 +61,24 @@ class FederatedSumFactory final : public TensorAggregatorFactory {
StatusOr<std::unique_ptr<TensorAggregator>> Create(
const Intrinsic& intrinsic) const override {
+ return CreateInternal(intrinsic, nullptr);
+ }
+
+ StatusOr<std::unique_ptr<TensorAggregator>> Deserialize(
+ const Intrinsic& intrinsic, std::string serialized_state) const override {
+ AggVectorAggregatorState aggregator_state;
+ if (!aggregator_state.ParseFromString(serialized_state)) {
+ return FCP_STATUS(INVALID_ARGUMENT)
+ << "FederatedSumFactory: Failed to deserialize the "
+ "AggVectorAggregatorState.";
+ }
+ return CreateInternal(intrinsic, &aggregator_state);
+ };
+
+ private:
+ StatusOr<std::unique_ptr<TensorAggregator>> CreateInternal(
+ const Intrinsic& intrinsic,
+ const AggVectorAggregatorState* aggregator_state) const {
// Check that the configuration is valid for federated_sum.
if (kFederatedSumUri != intrinsic.uri) {
return FCP_STATUS(INVALID_ARGUMENT)
@@ -92,10 +112,21 @@ class FederatedSumFactory final : public TensorAggregatorFactory {
"specs.";
}
std::unique_ptr<TensorAggregator> aggregator;
+ if (aggregator_state == nullptr) {
+ NUMERICAL_ONLY_DTYPE_CASES(
+ input_spec.dtype(), T,
+ aggregator = std::make_unique<FederatedSum<T>>(
+ input_spec.dtype(), std::move(input_spec.shape())));
+ return aggregator;
+ }
+
NUMERICAL_ONLY_DTYPE_CASES(
input_spec.dtype(), T,
aggregator = std::make_unique<FederatedSum<T>>(
- input_spec.dtype(), std::move(input_spec.shape())));
+ input_spec.dtype(), std::move(input_spec.shape()),
+ MutableVectorData<T>::CreateFromEncodedContent(
+ aggregator_state->vector_data()),
+ aggregator_state->num_inputs()));
return aggregator;
}
};
diff --git a/fcp/aggregation/core/federated_sum_test.cc b/fcp/aggregation/core/federated_sum_test.cc
index 6a76722..6977a97 100644
--- a/fcp/aggregation/core/federated_sum_test.cc
+++ b/fcp/aggregation/core/federated_sum_test.cc
@@ -19,6 +19,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "fcp/aggregation/core/agg_core.pb.h"
#include "fcp/aggregation/core/tensor.h"
#include "fcp/aggregation/core/tensor_aggregator_factory.h"
#include "fcp/aggregation/core/tensor_aggregator_registry.h"
@@ -111,6 +112,31 @@ TEST(FederatedSumTest, Merge_Succeeds) {
EXPECT_THAT(result.value()[0], IsTensor({}, {6}));
}
+TEST(FederatedSumTest, SerializeDeserialize_Succeeds) {
+ auto aggregator = CreateTensorAggregator(GetDefaultIntrinsic()).value();
+ Tensor t1 = Tensor::Create(DT_INT32, {}, CreateTestData({1})).value();
+ Tensor t2 = Tensor::Create(DT_INT32, {}, CreateTestData({2})).value();
+ Tensor t3 = Tensor::Create(DT_INT32, {}, CreateTestData({3})).value();
+ EXPECT_THAT(aggregator->Accumulate(t1), IsOk());
+ EXPECT_THAT(aggregator->Accumulate(t2), IsOk());
+ EXPECT_THAT(aggregator->CanReport(), IsTrue());
+
+ auto serialized_state = std::move(*aggregator).Serialize();
+ auto deserialized_aggregator =
+ DeserializeTensorAggregator(GetDefaultIntrinsic(),
+ serialized_state.value())
+ .value();
+
+ EXPECT_THAT(deserialized_aggregator->Accumulate(t3), IsOk());
+ EXPECT_THAT(deserialized_aggregator->GetNumInputs(), Eq(3));
+ EXPECT_THAT(deserialized_aggregator->CanReport(), IsTrue());
+
+ auto result = std::move(*deserialized_aggregator).Report();
+ // Verify the resulting tensor.
+ EXPECT_THAT(result.value().size(), Eq(1));
+ EXPECT_THAT(result.value()[0], IsTensor({}, {6}));
+}
+
TEST(FederatedSumTest, Create_WrongUri) {
Intrinsic intrinsic{"wrong_uri",
{TensorSpec{"foo", DT_INT32, {}}},
diff --git a/fcp/aggregation/core/mutable_vector_data.h b/fcp/aggregation/core/mutable_vector_data.h
index c2f43a9..207075b 100644
--- a/fcp/aggregation/core/mutable_vector_data.h
+++ b/fcp/aggregation/core/mutable_vector_data.h
@@ -18,6 +18,8 @@
#define FCP_AGGREGATION_CORE_MUTABLE_VECTOR_DATA_H_
#include <cstddef>
+#include <memory>
+#include <string>
#include <vector>
#include "fcp/aggregation/core/tensor_data.h"
@@ -39,6 +41,21 @@ class MutableVectorData : public std::vector<T>, public TensorData {
// Implementation of the base class methods.
size_t byte_size() const override { return this->size() * sizeof(T); }
const void* data() const override { return this->std::vector<T>::data(); }
+
+ // Copy the MutableVectorData into a string.
+ std::string EncodeContent() {
+ return std::string(reinterpret_cast<const char*>(this->data()),
+ this->byte_size());
+ }
+
+ // Create and return a new MutableVectorData populated with the data from
+ // content.
+ static std::unique_ptr<MutableVectorData<T>> CreateFromEncodedContent(
+ const std::string& content) {
+ const T* data = reinterpret_cast<const T*>(content.data());
+ return std::make_unique<MutableVectorData<T>>(
+ data, data + content.size() / sizeof(T));
+ }
};
} // namespace aggregation
diff --git a/fcp/aggregation/core/mutable_vector_data_test.cc b/fcp/aggregation/core/mutable_vector_data_test.cc
index b799d12..3f6aa1a 100644
--- a/fcp/aggregation/core/mutable_vector_data_test.cc
+++ b/fcp/aggregation/core/mutable_vector_data_test.cc
@@ -16,6 +16,8 @@
#include "fcp/aggregation/core/mutable_vector_data.h"
#include <cstdint>
+#include <string>
+#include <vector>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -33,6 +35,20 @@ TEST(MutableVectorDataTest, MutableVectorDataValid) {
EXPECT_THAT(vector_data.CheckValid<int64_t>(), IsOk());
}
+TEST(MutableVectorDataTest, EncodeDecodeSucceeds) {
+ MutableVectorData<int64_t> vector_data;
+ vector_data.push_back(1);
+ vector_data.push_back(2);
+ vector_data.push_back(3);
+ std::string encoded_vector_data = vector_data.EncodeContent();
+ EXPECT_THAT(vector_data.CheckValid<int64_t>(), IsOk());
+ auto decoded_vector_data =
+ MutableVectorData<int64_t>::CreateFromEncodedContent(encoded_vector_data);
+ EXPECT_THAT(decoded_vector_data->CheckValid<int64_t>(), IsOk());
+ EXPECT_EQ(std::vector<int64_t>(*decoded_vector_data),
+ std::vector<int64_t>({1, 2, 3}));
+}
+
} // namespace
} // namespace aggregation
} // namespace fcp
diff --git a/fcp/aggregation/core/tensor_aggregator.h b/fcp/aggregation/core/tensor_aggregator.h
index d0590be..86e3e03 100644
--- a/fcp/aggregation/core/tensor_aggregator.h
+++ b/fcp/aggregation/core/tensor_aggregator.h
@@ -19,6 +19,7 @@
#include <vector>
+#include "absl/strings/cord.h"
#include "fcp/aggregation/core/aggregator.h"
#include "fcp/aggregation/core/input_tensor_list.h"
#include "fcp/aggregation/core/tensor.h"
@@ -44,6 +45,12 @@ class TensorAggregator
// Returns the number of aggregated inputs.
virtual int GetNumInputs() const = 0;
+ // Serialize the internal state of the TensorAggregator as a string.
+ // TODO: b/331978180 - Make pure virtual once all derived classes implement.
+ virtual StatusOr<std::string> Serialize() && {
+ return FCP_STATUS(UNIMPLEMENTED);
+ };
+
protected:
// Construct TensorAggregator
explicit TensorAggregator() {}
diff --git a/fcp/aggregation/core/tensor_aggregator_factory.h b/fcp/aggregation/core/tensor_aggregator_factory.h
index 1609d28..807807f 100644
--- a/fcp/aggregation/core/tensor_aggregator_factory.h
+++ b/fcp/aggregation/core/tensor_aggregator_factory.h
@@ -18,6 +18,7 @@
#define FCP_AGGREGATION_CORE_TENSOR_AGGREGATOR_FACTORY_H_
#include <memory>
+#include <string>
#include "fcp/aggregation/core/intrinsic.h"
#include "fcp/aggregation/core/tensor_aggregator.h"
@@ -39,6 +40,17 @@ class TensorAggregatorFactory {
// hold pointers referring to the Intrinsic.
virtual StatusOr<std::unique_ptr<TensorAggregator>> Create(
const Intrinsic& intrinsic) const = 0;
+
+ // Creates an instance of a specific aggregator for the specified type of the
+ // aggregation intrinsic and serialized aggregator state.
+ // The lifetime of the provided Intrinsic must outlast that of the returned
+ // TensorAggregator as it is valid for the TensorAggregator implementation to
+ // hold pointers referring to the Intrinsic.
+ // TODO: b/331978180 - Make pure virtual once all derived classes implement.
+ virtual StatusOr<std::unique_ptr<TensorAggregator>> Deserialize(
+ const Intrinsic& intrinsic, std::string serialized_state) const {
+ return FCP_STATUS(UNIMPLEMENTED);
+ };
};
} // namespace aggregation
diff --git a/fcp/aggregation/core/tensor_aggregator_registry.cc b/fcp/aggregation/core/tensor_aggregator_registry.cc
index c87ebac..45c846c 100644
--- a/fcp/aggregation/core/tensor_aggregator_registry.cc
+++ b/fcp/aggregation/core/tensor_aggregator_registry.cc
@@ -17,10 +17,9 @@
#include <memory>
#include <string>
-#include "fcp/aggregation/core/tensor_aggregator_factory.h"
-
#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"
+#include "fcp/aggregation/core/tensor_aggregator_factory.h"
namespace fcp {
namespace aggregation {
@@ -83,5 +82,11 @@ StatusOr<std::unique_ptr<TensorAggregator>> CreateTensorAggregator(
return (*GetAggregatorFactory(intrinsic.uri))->Create(intrinsic);
}
+StatusOr<std::unique_ptr<TensorAggregator>> DeserializeTensorAggregator(
+ const Intrinsic& intrinsic, std::string serialized_state) {
+ return (*GetAggregatorFactory(intrinsic.uri))
+ ->Deserialize(intrinsic, serialized_state);
+}
+
} // namespace aggregation
} // namespace fcp
diff --git a/fcp/aggregation/core/tensor_aggregator_registry.h b/fcp/aggregation/core/tensor_aggregator_registry.h
index 48d240a..547422e 100644
--- a/fcp/aggregation/core/tensor_aggregator_registry.h
+++ b/fcp/aggregation/core/tensor_aggregator_registry.h
@@ -38,6 +38,11 @@ StatusOr<const TensorAggregatorFactory*> GetAggregatorFactory(
StatusOr<std::unique_ptr<TensorAggregator>> CreateTensorAggregator(
const Intrinsic& intrinsic);
+// Creates a TensorAggregator with the given internal state via the factory
+// registered for the given intrinsic
+StatusOr<std::unique_ptr<TensorAggregator>> DeserializeTensorAggregator(
+ const Intrinsic& intrinsic, std::string serialized_state);
+
namespace internal {
template <typename FactoryType>