aboutsummaryrefslogtreecommitdiff
path: root/fcp/aggregation/core/grouping_federated_sum_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'fcp/aggregation/core/grouping_federated_sum_test.cc')
-rw-r--r--fcp/aggregation/core/grouping_federated_sum_test.cc129
1 files changed, 109 insertions, 20 deletions
diff --git a/fcp/aggregation/core/grouping_federated_sum_test.cc b/fcp/aggregation/core/grouping_federated_sum_test.cc
index df6b9bb..7c777f1 100644
--- a/fcp/aggregation/core/grouping_federated_sum_test.cc
+++ b/fcp/aggregation/core/grouping_federated_sum_test.cc
@@ -13,8 +13,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
#include <cstdint>
#include <memory>
+#include <string>
#include <utility>
#include "gmock/gmock.h"
@@ -38,6 +40,9 @@ namespace {
using ::testing::Eq;
using testing::HasSubstr;
using ::testing::IsTrue;
+using testing::TestWithParam;
+
+using GroupingFederatedSumTest = TestWithParam<bool>;
Intrinsic GetDefaultIntrinsic() {
// One "GoogleSQL:sum" intrinsic with a single int32 tensor of unknown size.
@@ -48,7 +53,7 @@ Intrinsic GetDefaultIntrinsic() {
{}};
}
-TEST(GroupingFederatedSumTest, ScalarAggregationSucceeds) {
+TEST_P(GroupingFederatedSumTest, ScalarAggregationSucceeds) {
auto aggregator = CreateTensorAggregator(GetDefaultIntrinsic()).value();
Tensor ordinal =
Tensor::Create(DT_INT64, {}, CreateTestData<int64_t>({0})).value();
@@ -57,6 +62,17 @@ TEST(GroupingFederatedSumTest, ScalarAggregationSucceeds) {
Tensor t3 = Tensor::Create(DT_INT32, {}, CreateTestData({3})).value();
EXPECT_THAT(aggregator->Accumulate({&ordinal, &t1}), IsOk());
EXPECT_THAT(aggregator->Accumulate({&ordinal, &t2}), IsOk());
+
+ if (GetParam()) {
+ auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>(
+ GetAggregatorFactory(GetDefaultIntrinsic().uri).value());
+ auto one_dim_base_aggregator =
+ std::unique_ptr<OneDimBaseGroupingAggregator>(
+ dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator.release()));
+ auto state = std::move(*(one_dim_base_aggregator)).ToProto();
+ aggregator = factory->FromProto(GetDefaultIntrinsic(), state).value();
+ }
+
EXPECT_THAT(aggregator->Accumulate({&ordinal, &t3}), IsOk());
EXPECT_THAT(aggregator->CanReport(), IsTrue());
@@ -67,7 +83,7 @@ TEST(GroupingFederatedSumTest, ScalarAggregationSucceeds) {
EXPECT_THAT(result.value()[0], IsTensor<int64_t>({1}, {6}));
}
-TEST(GroupingFederatedSumTest, DenseAggregationSucceeds) {
+TEST_P(GroupingFederatedSumTest, DenseAggregationSucceeds) {
TensorShape shape{4};
auto aggregator = CreateTensorAggregator(GetDefaultIntrinsic()).value();
Tensor ordinals =
@@ -81,6 +97,17 @@ TEST(GroupingFederatedSumTest, DenseAggregationSucceeds) {
Tensor::Create(DT_INT32, shape, CreateTestData({3, 11, 7, 20})).value();
EXPECT_THAT(aggregator->Accumulate({&ordinals, &t1}), IsOk());
EXPECT_THAT(aggregator->Accumulate({&ordinals, &t2}), IsOk());
+
+ if (GetParam()) {
+ auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>(
+ GetAggregatorFactory(GetDefaultIntrinsic().uri).value());
+ auto one_dim_base_aggregator =
+ std::unique_ptr<OneDimBaseGroupingAggregator>(
+ dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator.release()));
+ auto state = std::move(*(one_dim_base_aggregator)).ToProto();
+ aggregator = factory->FromProto(GetDefaultIntrinsic(), state).value();
+ }
+
EXPECT_THAT(aggregator->Accumulate({&ordinals, &t3}), IsOk());
EXPECT_THAT(aggregator->CanReport(), IsTrue());
EXPECT_THAT(aggregator->GetNumInputs(), Eq(3));
@@ -94,15 +121,14 @@ TEST(GroupingFederatedSumTest, DenseAggregationSucceeds) {
EXPECT_TRUE(result.value()[0].is_dense());
}
-TEST(GroupingFederatedSumTest, DenseAggregationCastToLargerTypeSucceeds) {
+TEST_P(GroupingFederatedSumTest, DenseAggregationCastToLargerTypeSucceeds) {
TensorShape shape{4};
- auto aggregator =
- CreateTensorAggregator(Intrinsic{"GoogleSQL:sum",
- {TensorSpec{"foo", DT_INT32, {-1}}},
- {TensorSpec{"foo_out", DT_INT64, {-1}}},
- {},
- {}})
- .value();
+ Intrinsic intrinsic{"GoogleSQL:sum",
+ {TensorSpec{"foo", DT_INT32, {-1}}},
+ {TensorSpec{"foo_out", DT_INT64, {-1}}},
+ {},
+ {}};
+ auto aggregator = CreateTensorAggregator(intrinsic).value();
Tensor ordinals =
Tensor::Create(DT_INT64, shape, CreateTestData<int64_t>({0, 1, 2, 3}))
.value();
@@ -114,6 +140,17 @@ TEST(GroupingFederatedSumTest, DenseAggregationCastToLargerTypeSucceeds) {
Tensor::Create(DT_INT32, shape, CreateTestData({3, 11, 7, 20})).value();
EXPECT_THAT(aggregator->Accumulate({&ordinals, &t1}), IsOk());
EXPECT_THAT(aggregator->Accumulate({&ordinals, &t2}), IsOk());
+
+ if (GetParam()) {
+ auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>(
+ GetAggregatorFactory(GetDefaultIntrinsic().uri).value());
+ auto one_dim_base_aggregator =
+ std::unique_ptr<OneDimBaseGroupingAggregator>(
+ dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator.release()));
+ auto state = std::move(*(one_dim_base_aggregator)).ToProto();
+ aggregator = factory->FromProto(intrinsic, state).value();
+ }
+
EXPECT_THAT(aggregator->Accumulate({&ordinals, &t3}), IsOk());
EXPECT_THAT(aggregator->CanReport(), IsTrue());
EXPECT_THAT(aggregator->GetNumInputs(), Eq(3));
@@ -127,15 +164,15 @@ TEST(GroupingFederatedSumTest, DenseAggregationCastToLargerTypeSucceeds) {
EXPECT_TRUE(result.value()[0].is_dense());
}
-TEST(GroupingFederatedSumTest, DenseAggregationCastToLargerFloatTypeSucceeds) {
+TEST_P(GroupingFederatedSumTest,
+ DenseAggregationCastToLargerFloatTypeSucceeds) {
TensorShape shape{4};
- auto aggregator =
- CreateTensorAggregator(Intrinsic{"GoogleSQL:sum",
- {TensorSpec{"foo", DT_FLOAT, {-1}}},
- {TensorSpec{"foo_out", DT_DOUBLE, {-1}}},
- {},
- {}})
- .value();
+ Intrinsic intrinsic{"GoogleSQL:sum",
+ {TensorSpec{"foo", DT_FLOAT, {-1}}},
+ {TensorSpec{"foo_out", DT_DOUBLE, {-1}}},
+ {},
+ {}};
+ auto aggregator = CreateTensorAggregator(intrinsic).value();
Tensor ordinals =
Tensor::Create(DT_INT64, shape, CreateTestData<int64_t>({0, 1, 2, 3}))
.value();
@@ -150,6 +187,17 @@ TEST(GroupingFederatedSumTest, DenseAggregationCastToLargerFloatTypeSucceeds) {
.value();
EXPECT_THAT(aggregator->Accumulate({&ordinals, &t1}), IsOk());
EXPECT_THAT(aggregator->Accumulate({&ordinals, &t2}), IsOk());
+
+ if (GetParam()) {
+ auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>(
+ GetAggregatorFactory(GetDefaultIntrinsic().uri).value());
+ auto one_dim_base_aggregator =
+ std::unique_ptr<OneDimBaseGroupingAggregator>(
+ dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator.release()));
+ auto state = std::move(*(one_dim_base_aggregator)).ToProto();
+ aggregator = factory->FromProto(intrinsic, state).value();
+ }
+
EXPECT_THAT(aggregator->Accumulate({&ordinals, &t3}), IsOk());
EXPECT_THAT(aggregator->CanReport(), IsTrue());
EXPECT_THAT(aggregator->GetNumInputs(), Eq(3));
@@ -163,7 +211,7 @@ TEST(GroupingFederatedSumTest, DenseAggregationCastToLargerFloatTypeSucceeds) {
EXPECT_TRUE(result.value()[0].is_dense());
}
-TEST(GroupingFederatedSumTest, MergeSucceeds) {
+TEST_P(GroupingFederatedSumTest, MergeSucceeds) {
auto aggregator1 = CreateTensorAggregator(GetDefaultIntrinsic()).value();
auto aggregator2 = CreateTensorAggregator(GetDefaultIntrinsic()).value();
Tensor ordinal =
@@ -175,6 +223,21 @@ TEST(GroupingFederatedSumTest, MergeSucceeds) {
EXPECT_THAT(aggregator2->Accumulate({&ordinal, &t2}), IsOk());
EXPECT_THAT(aggregator2->Accumulate({&ordinal, &t3}), IsOk());
+ if (GetParam()) {
+ auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>(
+ GetAggregatorFactory(GetDefaultIntrinsic().uri).value());
+ auto one_dim_base_aggregator1 =
+ std::unique_ptr<OneDimBaseGroupingAggregator>(
+ dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator1.release()));
+ auto state = std::move(*(one_dim_base_aggregator1)).ToProto();
+ aggregator1 = factory->FromProto(GetDefaultIntrinsic(), state).value();
+ auto one_dim_base_aggregator2 =
+ std::unique_ptr<OneDimBaseGroupingAggregator>(
+ dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator2.release()));
+ auto state2 = std::move(*(one_dim_base_aggregator2)).ToProto();
+ aggregator2 = factory->FromProto(GetDefaultIntrinsic(), state2).value();
+ }
+
int aggregator2_num_inputs = aggregator2->GetNumInputs();
auto aggregator2_result =
std::move(std::move(*aggregator2).Report().value()[0]);
@@ -194,7 +257,7 @@ TEST(GroupingFederatedSumTest, MergeSucceeds) {
EXPECT_THAT(result.value()[0], IsTensor<int64_t>({1}, {6}));
}
-TEST(GroupingFederatedSumTest, MergeSucceedsWithNonSharedOrdinals) {
+TEST_P(GroupingFederatedSumTest, MergeSucceedsWithNonSharedOrdinals) {
auto aggregator1 = CreateTensorAggregator(GetDefaultIntrinsic()).value();
auto aggregator2 = CreateTensorAggregator(GetDefaultIntrinsic()).value();
Tensor ordinal =
@@ -206,6 +269,21 @@ TEST(GroupingFederatedSumTest, MergeSucceedsWithNonSharedOrdinals) {
EXPECT_THAT(aggregator2->Accumulate({&ordinal, &t2}), IsOk());
EXPECT_THAT(aggregator2->Accumulate({&ordinal, &t3}), IsOk());
+ if (GetParam()) {
+ auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>(
+ GetAggregatorFactory(GetDefaultIntrinsic().uri).value());
+ auto one_dim_base_aggregator1 =
+ std::unique_ptr<OneDimBaseGroupingAggregator>(
+ dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator1.release()));
+ auto state = std::move(*(one_dim_base_aggregator1)).ToProto();
+ aggregator1 = factory->FromProto(GetDefaultIntrinsic(), state).value();
+ auto one_dim_base_aggregator2 =
+ std::unique_ptr<OneDimBaseGroupingAggregator>(
+ dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator2.release()));
+ auto state2 = std::move(*(one_dim_base_aggregator2)).ToProto();
+ aggregator2 = factory->FromProto(GetDefaultIntrinsic(), state2).value();
+ }
+
int aggregator2_num_inputs = aggregator2->GetNumInputs();
auto aggregator2_result =
std::move(std::move(*aggregator2).Report().value()[0]);
@@ -343,6 +421,17 @@ TEST(GroupingFederatedSumTest, CreateUnsupportedStringDataType) {
HasSubstr("GroupingFederatedSum isn't supported for DT_STRING datatype"));
}
+TEST(GroupingFederatedSumTest, Deserialize_Unimplemented) {
+ Status s = DeserializeTensorAggregator(GetDefaultIntrinsic(), "").status();
+ EXPECT_THAT(s, IsCode(UNIMPLEMENTED));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ GroupingFederatedSumTestInstantiation, GroupingFederatedSumTest,
+ testing::ValuesIn<bool>({false, true}),
+ [](const testing::TestParamInfo<GroupingFederatedSumTest::ParamType>&
+ info) { return info.param ? "SaveIntermediateState" : "None"; });
+
} // namespace
} // namespace aggregation
} // namespace fcp