diff options
Diffstat (limited to 'fcp/aggregation/core/grouping_federated_sum_test.cc')
-rw-r--r-- | fcp/aggregation/core/grouping_federated_sum_test.cc | 129 |
1 files changed, 109 insertions, 20 deletions
diff --git a/fcp/aggregation/core/grouping_federated_sum_test.cc b/fcp/aggregation/core/grouping_federated_sum_test.cc index df6b9bb..7c777f1 100644 --- a/fcp/aggregation/core/grouping_federated_sum_test.cc +++ b/fcp/aggregation/core/grouping_federated_sum_test.cc @@ -13,8 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #include <cstdint> #include <memory> +#include <string> #include <utility> #include "gmock/gmock.h" @@ -38,6 +40,9 @@ namespace { using ::testing::Eq; using testing::HasSubstr; using ::testing::IsTrue; +using testing::TestWithParam; + +using GroupingFederatedSumTest = TestWithParam<bool>; Intrinsic GetDefaultIntrinsic() { // One "GoogleSQL:sum" intrinsic with a single int32 tensor of unknown size. @@ -48,7 +53,7 @@ Intrinsic GetDefaultIntrinsic() { {}}; } -TEST(GroupingFederatedSumTest, ScalarAggregationSucceeds) { +TEST_P(GroupingFederatedSumTest, ScalarAggregationSucceeds) { auto aggregator = CreateTensorAggregator(GetDefaultIntrinsic()).value(); Tensor ordinal = Tensor::Create(DT_INT64, {}, CreateTestData<int64_t>({0})).value(); @@ -57,6 +62,17 @@ TEST(GroupingFederatedSumTest, ScalarAggregationSucceeds) { Tensor t3 = Tensor::Create(DT_INT32, {}, CreateTestData({3})).value(); EXPECT_THAT(aggregator->Accumulate({&ordinal, &t1}), IsOk()); EXPECT_THAT(aggregator->Accumulate({&ordinal, &t2}), IsOk()); + + if (GetParam()) { + auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>( + GetAggregatorFactory(GetDefaultIntrinsic().uri).value()); + auto one_dim_base_aggregator = + std::unique_ptr<OneDimBaseGroupingAggregator>( + dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator.release())); + auto state = std::move(*(one_dim_base_aggregator)).ToProto(); + aggregator = factory->FromProto(GetDefaultIntrinsic(), state).value(); + } + EXPECT_THAT(aggregator->Accumulate({&ordinal, &t3}), IsOk()); EXPECT_THAT(aggregator->CanReport(), IsTrue()); @@ -67,7 +83,7 @@ TEST(GroupingFederatedSumTest, ScalarAggregationSucceeds) { EXPECT_THAT(result.value()[0], IsTensor<int64_t>({1}, {6})); } -TEST(GroupingFederatedSumTest, DenseAggregationSucceeds) { +TEST_P(GroupingFederatedSumTest, DenseAggregationSucceeds) { TensorShape shape{4}; auto aggregator = CreateTensorAggregator(GetDefaultIntrinsic()).value(); Tensor ordinals = @@ -81,6 +97,17 @@ TEST(GroupingFederatedSumTest, DenseAggregationSucceeds) { Tensor::Create(DT_INT32, shape, CreateTestData({3, 11, 7, 20})).value(); EXPECT_THAT(aggregator->Accumulate({&ordinals, &t1}), IsOk()); EXPECT_THAT(aggregator->Accumulate({&ordinals, &t2}), IsOk()); + + if (GetParam()) { + auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>( + GetAggregatorFactory(GetDefaultIntrinsic().uri).value()); + auto one_dim_base_aggregator = + std::unique_ptr<OneDimBaseGroupingAggregator>( + dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator.release())); + auto state = std::move(*(one_dim_base_aggregator)).ToProto(); + aggregator = factory->FromProto(GetDefaultIntrinsic(), state).value(); + } + EXPECT_THAT(aggregator->Accumulate({&ordinals, &t3}), IsOk()); EXPECT_THAT(aggregator->CanReport(), IsTrue()); EXPECT_THAT(aggregator->GetNumInputs(), Eq(3)); @@ -94,15 +121,14 @@ TEST(GroupingFederatedSumTest, DenseAggregationSucceeds) { EXPECT_TRUE(result.value()[0].is_dense()); } -TEST(GroupingFederatedSumTest, DenseAggregationCastToLargerTypeSucceeds) { +TEST_P(GroupingFederatedSumTest, DenseAggregationCastToLargerTypeSucceeds) { TensorShape shape{4}; - auto aggregator = - CreateTensorAggregator(Intrinsic{"GoogleSQL:sum", - {TensorSpec{"foo", DT_INT32, {-1}}}, - {TensorSpec{"foo_out", DT_INT64, {-1}}}, - {}, - {}}) - .value(); + Intrinsic intrinsic{"GoogleSQL:sum", + {TensorSpec{"foo", DT_INT32, {-1}}}, + {TensorSpec{"foo_out", DT_INT64, {-1}}}, + {}, + {}}; + auto aggregator = CreateTensorAggregator(intrinsic).value(); Tensor ordinals = Tensor::Create(DT_INT64, shape, CreateTestData<int64_t>({0, 1, 2, 3})) .value(); @@ -114,6 +140,17 @@ TEST(GroupingFederatedSumTest, DenseAggregationCastToLargerTypeSucceeds) { Tensor::Create(DT_INT32, shape, CreateTestData({3, 11, 7, 20})).value(); EXPECT_THAT(aggregator->Accumulate({&ordinals, &t1}), IsOk()); EXPECT_THAT(aggregator->Accumulate({&ordinals, &t2}), IsOk()); + + if (GetParam()) { + auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>( + GetAggregatorFactory(GetDefaultIntrinsic().uri).value()); + auto one_dim_base_aggregator = + std::unique_ptr<OneDimBaseGroupingAggregator>( + dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator.release())); + auto state = std::move(*(one_dim_base_aggregator)).ToProto(); + aggregator = factory->FromProto(intrinsic, state).value(); + } + EXPECT_THAT(aggregator->Accumulate({&ordinals, &t3}), IsOk()); EXPECT_THAT(aggregator->CanReport(), IsTrue()); EXPECT_THAT(aggregator->GetNumInputs(), Eq(3)); @@ -127,15 +164,15 @@ TEST(GroupingFederatedSumTest, DenseAggregationCastToLargerTypeSucceeds) { EXPECT_TRUE(result.value()[0].is_dense()); } -TEST(GroupingFederatedSumTest, DenseAggregationCastToLargerFloatTypeSucceeds) { +TEST_P(GroupingFederatedSumTest, + DenseAggregationCastToLargerFloatTypeSucceeds) { TensorShape shape{4}; - auto aggregator = - CreateTensorAggregator(Intrinsic{"GoogleSQL:sum", - {TensorSpec{"foo", DT_FLOAT, {-1}}}, - {TensorSpec{"foo_out", DT_DOUBLE, {-1}}}, - {}, - {}}) - .value(); + Intrinsic intrinsic{"GoogleSQL:sum", + {TensorSpec{"foo", DT_FLOAT, {-1}}}, + {TensorSpec{"foo_out", DT_DOUBLE, {-1}}}, + {}, + {}}; + auto aggregator = CreateTensorAggregator(intrinsic).value(); Tensor ordinals = Tensor::Create(DT_INT64, shape, CreateTestData<int64_t>({0, 1, 2, 3})) .value(); @@ -150,6 +187,17 @@ TEST(GroupingFederatedSumTest, DenseAggregationCastToLargerFloatTypeSucceeds) { .value(); EXPECT_THAT(aggregator->Accumulate({&ordinals, &t1}), IsOk()); EXPECT_THAT(aggregator->Accumulate({&ordinals, &t2}), IsOk()); + + if (GetParam()) { + auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>( + GetAggregatorFactory(GetDefaultIntrinsic().uri).value()); + auto one_dim_base_aggregator = + std::unique_ptr<OneDimBaseGroupingAggregator>( + dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator.release())); + auto state = std::move(*(one_dim_base_aggregator)).ToProto(); + aggregator = factory->FromProto(intrinsic, state).value(); + } + EXPECT_THAT(aggregator->Accumulate({&ordinals, &t3}), IsOk()); EXPECT_THAT(aggregator->CanReport(), IsTrue()); EXPECT_THAT(aggregator->GetNumInputs(), Eq(3)); @@ -163,7 +211,7 @@ TEST(GroupingFederatedSumTest, DenseAggregationCastToLargerFloatTypeSucceeds) { EXPECT_TRUE(result.value()[0].is_dense()); } -TEST(GroupingFederatedSumTest, MergeSucceeds) { +TEST_P(GroupingFederatedSumTest, MergeSucceeds) { auto aggregator1 = CreateTensorAggregator(GetDefaultIntrinsic()).value(); auto aggregator2 = CreateTensorAggregator(GetDefaultIntrinsic()).value(); Tensor ordinal = @@ -175,6 +223,21 @@ TEST(GroupingFederatedSumTest, MergeSucceeds) { EXPECT_THAT(aggregator2->Accumulate({&ordinal, &t2}), IsOk()); EXPECT_THAT(aggregator2->Accumulate({&ordinal, &t3}), IsOk()); + if (GetParam()) { + auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>( + GetAggregatorFactory(GetDefaultIntrinsic().uri).value()); + auto one_dim_base_aggregator1 = + std::unique_ptr<OneDimBaseGroupingAggregator>( + dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator1.release())); + auto state = std::move(*(one_dim_base_aggregator1)).ToProto(); + aggregator1 = factory->FromProto(GetDefaultIntrinsic(), state).value(); + auto one_dim_base_aggregator2 = + std::unique_ptr<OneDimBaseGroupingAggregator>( + dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator2.release())); + auto state2 = std::move(*(one_dim_base_aggregator2)).ToProto(); + aggregator2 = factory->FromProto(GetDefaultIntrinsic(), state2).value(); + } + int aggregator2_num_inputs = aggregator2->GetNumInputs(); auto aggregator2_result = std::move(std::move(*aggregator2).Report().value()[0]); @@ -194,7 +257,7 @@ TEST(GroupingFederatedSumTest, MergeSucceeds) { EXPECT_THAT(result.value()[0], IsTensor<int64_t>({1}, {6})); } -TEST(GroupingFederatedSumTest, MergeSucceedsWithNonSharedOrdinals) { +TEST_P(GroupingFederatedSumTest, MergeSucceedsWithNonSharedOrdinals) { auto aggregator1 = CreateTensorAggregator(GetDefaultIntrinsic()).value(); auto aggregator2 = CreateTensorAggregator(GetDefaultIntrinsic()).value(); Tensor ordinal = @@ -206,6 +269,21 @@ TEST(GroupingFederatedSumTest, MergeSucceedsWithNonSharedOrdinals) { EXPECT_THAT(aggregator2->Accumulate({&ordinal, &t2}), IsOk()); EXPECT_THAT(aggregator2->Accumulate({&ordinal, &t3}), IsOk()); + if (GetParam()) { + auto factory = dynamic_cast<const OneDimBaseGroupingAggregatorFactory*>( + GetAggregatorFactory(GetDefaultIntrinsic().uri).value()); + auto one_dim_base_aggregator1 = + std::unique_ptr<OneDimBaseGroupingAggregator>( + dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator1.release())); + auto state = std::move(*(one_dim_base_aggregator1)).ToProto(); + aggregator1 = factory->FromProto(GetDefaultIntrinsic(), state).value(); + auto one_dim_base_aggregator2 = + std::unique_ptr<OneDimBaseGroupingAggregator>( + dynamic_cast<OneDimBaseGroupingAggregator*>(aggregator2.release())); + auto state2 = std::move(*(one_dim_base_aggregator2)).ToProto(); + aggregator2 = factory->FromProto(GetDefaultIntrinsic(), state2).value(); + } + int aggregator2_num_inputs = aggregator2->GetNumInputs(); auto aggregator2_result = std::move(std::move(*aggregator2).Report().value()[0]); @@ -343,6 +421,17 @@ TEST(GroupingFederatedSumTest, CreateUnsupportedStringDataType) { HasSubstr("GroupingFederatedSum isn't supported for DT_STRING datatype")); } +TEST(GroupingFederatedSumTest, Deserialize_Unimplemented) { + Status s = DeserializeTensorAggregator(GetDefaultIntrinsic(), "").status(); + EXPECT_THAT(s, IsCode(UNIMPLEMENTED)); +} + +INSTANTIATE_TEST_SUITE_P( + GroupingFederatedSumTestInstantiation, GroupingFederatedSumTest, + testing::ValuesIn<bool>({false, true}), + [](const testing::TestParamInfo<GroupingFederatedSumTest::ParamType>& + info) { return info.param ? "SaveIntermediateState" : "None"; }); + } // namespace } // namespace aggregation } // namespace fcp |