diff options
Diffstat (limited to 'fcp/client/engine/example_query_plan_engine_test.cc')
-rw-r--r-- | fcp/client/engine/example_query_plan_engine_test.cc | 547 |
1 files changed, 547 insertions, 0 deletions
diff --git a/fcp/client/engine/example_query_plan_engine_test.cc b/fcp/client/engine/example_query_plan_engine_test.cc new file mode 100644 index 0000000..dbad82d --- /dev/null +++ b/fcp/client/engine/example_query_plan_engine_test.cc @@ -0,0 +1,547 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "fcp/client/engine/example_query_plan_engine.h" + +#include <fcntl.h> + +#include <cstdint> +#include <filesystem> +#include <functional> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "fcp/client/client_runner.h" +#include "fcp/client/engine/common.h" +#include "fcp/client/example_query_result.pb.h" +#include "fcp/client/test_helpers.h" +#include "fcp/protos/plan.pb.h" +#include "fcp/testing/testing.h" +#include "tensorflow/c/checkpoint_reader.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace fcp { +namespace client { +namespace engine { +namespace { + +namespace tf = ::tensorflow; + +using ::fcp::client::ExampleQueryResult; +using ::google::internal::federated::plan::AggregationConfig; +using ::google::internal::federated::plan::ClientOnlyPlan; +using ::google::internal::federated::plan::Dataset; +using ::google::internal::federated::plan::ExampleQuerySpec; +using ::google::internal::federated::plan::ExampleSelector; +using ::testing::StrictMock; + +const char* const kCollectionUri = "app:/test_collection"; +const char* const kOutputStringVectorName = "vector1"; +const char* const kOutputIntVectorName = "vector2"; +const char* const kOutputStringTensorName = "tensor1"; +const char* const kOutputIntTensorName = "tensor2"; + +class InvalidExampleIteratorFactory : public ExampleIteratorFactory { + public: + InvalidExampleIteratorFactory() = default; + + bool CanHandle(const google::internal::federated::plan::ExampleSelector& + example_selector) override { + return false; + } + + absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator( + const ExampleSelector& example_selector) override { + absl::Status error(absl::StatusCode::kInternal, ""); + return error; + } + + bool ShouldCollectStats() override { return false; } +}; + +class NoIteratorExampleIteratorFactory : public ExampleIteratorFactory { + public: + NoIteratorExampleIteratorFactory() = default; + + bool CanHandle(const google::internal::federated::plan::ExampleSelector& + example_selector) override { + return true; + } + + absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator( + const ExampleSelector& example_selector) override { + absl::Status error(absl::StatusCode::kInternal, ""); + return error; + } + + bool ShouldCollectStats() override { return false; } +}; + +class TwoExampleIteratorsFactory : public ExampleIteratorFactory { + public: + explicit TwoExampleIteratorsFactory( + std::function<absl::StatusOr<std::unique_ptr<ExampleIterator>>( + const google::internal::federated::plan::ExampleSelector& + + )> + create_first_iterator_func, + std::function<absl::StatusOr<std::unique_ptr<ExampleIterator>>( + const google::internal::federated::plan::ExampleSelector& + + )> + create_second_iterator_func, + const std::string& first_collection_uri, + const std::string& second_collection_uri) + : create_first_iterator_func_(create_first_iterator_func), + create_second_iterator_func_(create_second_iterator_func), + first_collection_uri_(first_collection_uri), + second_collection_uri_(second_collection_uri) {} + + bool CanHandle(const google::internal::federated::plan::ExampleSelector& + example_selector) override { + return true; + } + + absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator( + const google::internal::federated::plan::ExampleSelector& + example_selector) override { + if (example_selector.collection_uri() == first_collection_uri_) { + return create_first_iterator_func_(example_selector); + } else if (example_selector.collection_uri() == second_collection_uri_) { + return create_second_iterator_func_(example_selector); + } + return absl::InvalidArgumentError("Unknown collection URI"); + } + + bool ShouldCollectStats() override { return false; } + + private: + std::function<absl::StatusOr<std::unique_ptr<ExampleIterator>>( + const google::internal::federated::plan::ExampleSelector&)> + create_first_iterator_func_; + std::function<absl::StatusOr<std::unique_ptr<ExampleIterator>>( + const google::internal::federated::plan::ExampleSelector&)> + create_second_iterator_func_; + std::string first_collection_uri_; + std::string second_collection_uri_; +}; + +absl::StatusOr<absl::flat_hash_map<std::string, tf::Tensor>> ReadTensors( + std::string checkpoint_path) { + absl::flat_hash_map<std::string, tf::Tensor> tensors; + tf::TF_StatusPtr tf_status(TF_NewStatus()); + tf::checkpoint::CheckpointReader tf_checkpoint_reader(checkpoint_path, + tf_status.get()); + if (TF_GetCode(tf_status.get()) != TF_OK) { + return absl::NotFoundError("Couldn't read an input checkpoint"); + } + for (const auto& [name, tf_dtype] : + tf_checkpoint_reader.GetVariableToDataTypeMap()) { + std::unique_ptr<tf::Tensor> tensor; + tf_checkpoint_reader.GetTensor(name, &tensor, tf_status.get()); + if (TF_GetCode(tf_status.get()) != TF_OK) { + return absl::NotFoundError( + absl::StrFormat("Checkpoint doesn't have tensor %s", name)); + } + tensors[name] = *tensor; + } + + return tensors; +} + +class ExampleQueryPlanEngineTest : public testing::Test { + protected: + void Initialize() { + std::filesystem::path root_dir(testing::TempDir()); + std::filesystem::path output_path = root_dir / std::string("output.ckpt"); + output_checkpoint_filename_ = output_path.string(); + + ExampleQuerySpec::OutputVectorSpec string_vector_spec; + string_vector_spec.set_vector_name(kOutputStringVectorName); + string_vector_spec.set_data_type( + ExampleQuerySpec::OutputVectorSpec::STRING); + ExampleQuerySpec::OutputVectorSpec int_vector_spec; + int_vector_spec.set_vector_name(kOutputIntVectorName); + int_vector_spec.set_data_type(ExampleQuerySpec::OutputVectorSpec::INT64); + + ExampleQuerySpec::ExampleQuery example_query; + example_query.mutable_example_selector()->set_collection_uri( + kCollectionUri); + (*example_query.mutable_output_vector_specs())[kOutputStringTensorName] = + string_vector_spec; + (*example_query.mutable_output_vector_specs())[kOutputIntTensorName] = + int_vector_spec; + client_only_plan_.mutable_phase() + ->mutable_example_query_spec() + ->mutable_example_queries() + ->Add(std::move(example_query)); + + AggregationConfig aggregation_config; + aggregation_config.mutable_tf_v1_checkpoint_aggregation(); + (*client_only_plan_.mutable_phase() + ->mutable_federated_example_query() + ->mutable_aggregations())[kOutputStringTensorName] = + aggregation_config; + (*client_only_plan_.mutable_phase() + ->mutable_federated_example_query() + ->mutable_aggregations())[kOutputIntTensorName] = aggregation_config; + + ExampleQueryResult::VectorData::Values int_values; + int_values.mutable_int64_values()->add_value(42); + int_values.mutable_int64_values()->add_value(24); + (*example_query_result_.mutable_vector_data() + ->mutable_vectors())[kOutputIntVectorName] = int_values; + ExampleQueryResult::VectorData::Values string_values; + string_values.mutable_string_values()->add_value("value1"); + string_values.mutable_string_values()->add_value("value2"); + (*example_query_result_.mutable_vector_data() + ->mutable_vectors())[kOutputStringVectorName] = string_values; + std::string example = example_query_result_.SerializeAsString(); + + Dataset::ClientDataset client_dataset; + client_dataset.set_client_id("client_id"); + client_dataset.add_example(example); + dataset_.mutable_client_data()->Add(std::move(client_dataset)); + + num_examples_ = 1; + example_bytes_ = example.size(); + + example_iterator_factory_ = + std::make_unique<FunctionalExampleIteratorFactory>( + [&dataset = dataset_]( + const google::internal::federated::plan::ExampleSelector& + selector) { + return std::make_unique<SimpleExampleIterator>(dataset); + }); + } + + fcp::client::FilesImpl files_impl_; + StrictMock<MockOpStatsLogger> mock_opstats_logger_; + std::unique_ptr<ExampleIteratorFactory> example_iterator_factory_; + + ExampleQueryResult example_query_result_; + ClientOnlyPlan client_only_plan_; + Dataset dataset_; + std::string output_checkpoint_filename_; + + int num_examples_ = 0; + int64_t example_bytes_ = 0; +}; + +TEST_F(ExampleQueryPlanEngineTest, PlanSucceeds) { + Initialize(); + + EXPECT_CALL( + mock_opstats_logger_, + UpdateDatasetStats(kCollectionUri, num_examples_, example_bytes_)); + + ExampleQueryPlanEngine plan_engine({example_iterator_factory_.get()}, + &mock_opstats_logger_); + engine::PlanResult result = + plan_engine.RunPlan(client_only_plan_.phase().example_query_spec(), + output_checkpoint_filename_); + + EXPECT_THAT(result.outcome, PlanOutcome::kSuccess); + + auto tensors = ReadTensors(output_checkpoint_filename_); + ASSERT_OK(tensors); + tf::Tensor int_tensor = tensors.value()[kOutputIntTensorName]; + ASSERT_EQ(int_tensor.shape(), tf::TensorShape({2})); + ASSERT_EQ(int_tensor.dtype(), tf::DT_INT64); + auto int_data = static_cast<int64_t*>(int_tensor.data()); + std::vector<int64_t> expected_int_data({42, 24}); + for (int i = 0; i < 2; ++i) { + ASSERT_EQ(int_data[i], expected_int_data[i]); + } + + tf::Tensor string_tensor = tensors.value()[kOutputStringTensorName]; + ASSERT_EQ(string_tensor.shape(), tf::TensorShape({2})); + ASSERT_EQ(string_tensor.dtype(), tf::DT_STRING); + auto string_data = static_cast<tf::tstring*>(string_tensor.data()); + std::vector<std::string> expected_string_data({"value1", "value2"}); + for (int i = 0; i < 2; ++i) { + ASSERT_EQ(static_cast<std::string>(string_data[i]), + expected_string_data[i]); + } +} + +TEST_F(ExampleQueryPlanEngineTest, MultipleQueries) { + Initialize(); + + ExampleQuerySpec::OutputVectorSpec float_vector_spec; + float_vector_spec.set_vector_name("float_vector"); + float_vector_spec.set_data_type(ExampleQuerySpec::OutputVectorSpec::FLOAT); + ExampleQuerySpec::OutputVectorSpec string_vector_spec; + // Same vector name as in the other ExampleQuery, but with a different output + // one to make sure these vectors are distinguished in + // example_query_plan_engine. + string_vector_spec.set_vector_name(kOutputStringVectorName); + string_vector_spec.set_data_type(ExampleQuerySpec::OutputVectorSpec::STRING); + + ExampleQuerySpec::ExampleQuery second_example_query; + second_example_query.mutable_example_selector()->set_collection_uri( + "app:/second_collection"); + (*second_example_query.mutable_output_vector_specs())["float_tensor"] = + float_vector_spec; + (*second_example_query + .mutable_output_vector_specs())["another_string_tensor"] = + string_vector_spec; + client_only_plan_.mutable_phase() + ->mutable_example_query_spec() + ->mutable_example_queries() + ->Add(std::move(second_example_query)); + + AggregationConfig aggregation_config; + aggregation_config.mutable_tf_v1_checkpoint_aggregation(); + (*client_only_plan_.mutable_phase() + ->mutable_federated_example_query() + ->mutable_aggregations())["float_tensor"] = aggregation_config; + + ExampleQueryResult second_example_query_result; + ExampleQueryResult::VectorData::Values float_values; + float_values.mutable_float_values()->add_value(0.24f); + float_values.mutable_float_values()->add_value(0.42f); + float_values.mutable_float_values()->add_value(0.33f); + ExampleQueryResult::VectorData::Values string_values; + string_values.mutable_string_values()->add_value("another_string_value"); + (*second_example_query_result.mutable_vector_data() + ->mutable_vectors())["float_vector"] = float_values; + (*second_example_query_result.mutable_vector_data() + ->mutable_vectors())[kOutputStringVectorName] = string_values; + std::string example = second_example_query_result.SerializeAsString(); + + Dataset::ClientDataset dataset; + dataset.set_client_id("second_client_id"); + dataset.add_example(example); + Dataset second_dataset; + second_dataset.mutable_client_data()->Add(std::move(dataset)); + + example_iterator_factory_ = std::make_unique<TwoExampleIteratorsFactory>( + [&dataset = dataset_]( + const google::internal::federated::plan::ExampleSelector& selector) { + return std::make_unique<SimpleExampleIterator>(dataset); + }, + [&dataset = second_dataset]( + const google::internal::federated::plan::ExampleSelector& selector) { + return std::make_unique<SimpleExampleIterator>(dataset); + }, + kCollectionUri, "app:/second_collection"); + + ExampleQueryPlanEngine plan_engine({example_iterator_factory_.get()}, + &mock_opstats_logger_); + engine::PlanResult result = + plan_engine.RunPlan(client_only_plan_.phase().example_query_spec(), + output_checkpoint_filename_); + + EXPECT_THAT(result.outcome, PlanOutcome::kSuccess); + + auto tensors = ReadTensors(output_checkpoint_filename_); + ASSERT_OK(tensors); + tf::Tensor int_tensor = tensors.value()[kOutputIntTensorName]; + ASSERT_EQ(int_tensor.shape(), tf::TensorShape({2})); + ASSERT_EQ(int_tensor.dtype(), tf::DT_INT64); + auto int_data = static_cast<int64_t*>(int_tensor.data()); + std::vector<int64_t> expected_int_data({42, 24}); + for (int i = 0; i < 2; ++i) { + ASSERT_EQ(int_data[i], expected_int_data[i]); + } + + tf::Tensor string_tensor = tensors.value()[kOutputStringTensorName]; + ASSERT_EQ(string_tensor.shape(), tf::TensorShape({2})); + ASSERT_EQ(string_tensor.dtype(), tf::DT_STRING); + auto string_data = static_cast<tf::tstring*>(string_tensor.data()); + std::vector<std::string> expected_string_data({"value1", "value2"}); + for (int i = 0; i < 2; ++i) { + ASSERT_EQ(static_cast<std::string>(string_data[i]), + expected_string_data[i]); + } + + tf::Tensor float_tensor = tensors.value()["float_tensor"]; + ASSERT_EQ(float_tensor.shape(), tf::TensorShape({3})); + ASSERT_EQ(float_tensor.dtype(), tf::DT_FLOAT); + auto float_data = static_cast<float*>(float_tensor.data()); + std::vector<float> expected_float_data({0.24f, 0.42f, 0.33f}); + for (int i = 0; i < 3; ++i) { + ASSERT_EQ(float_data[i], expected_float_data[i]); + } + + tf::Tensor second_query_string_tensor = + tensors.value()["another_string_tensor"]; + ASSERT_EQ(second_query_string_tensor.shape(), tf::TensorShape({1})); + ASSERT_EQ(second_query_string_tensor.dtype(), tf::DT_STRING); + auto second_query_string_data = + static_cast<tf::tstring*>(second_query_string_tensor.data()); + ASSERT_EQ(static_cast<std::string>(*second_query_string_data), + "another_string_value"); +} + +TEST_F(ExampleQueryPlanEngineTest, OutputVectorSpecMissingInResult) { + Initialize(); + + ExampleQuerySpec::OutputVectorSpec new_vector_spec; + new_vector_spec.set_vector_name("new_vector"); + new_vector_spec.set_data_type(ExampleQuerySpec::OutputVectorSpec::DOUBLE); + + ExampleQuerySpec::ExampleQuery example_query = + client_only_plan_.phase().example_query_spec().example_queries().at(0); + (*example_query.mutable_output_vector_specs())["new_tensor"] = + new_vector_spec; + client_only_plan_.mutable_phase() + ->mutable_example_query_spec() + ->clear_example_queries(); + client_only_plan_.mutable_phase() + ->mutable_example_query_spec() + ->mutable_example_queries() + ->Add(std::move(example_query)); + + ExampleQueryResult example_query_result; + ExampleQueryResult::VectorData::Values bool_values; + bool_values.mutable_bool_values()->add_value(true); + (*example_query_result_.mutable_vector_data() + ->mutable_vectors())["new_vector"] = bool_values; + std::string example = example_query_result_.SerializeAsString(); + + Dataset::ClientDataset client_dataset; + client_dataset.set_client_id("client_id"); + client_dataset.add_example(example); + dataset_.clear_client_data(); + dataset_.mutable_client_data()->Add(std::move(client_dataset)); + + num_examples_ = 1; + example_bytes_ = example.size(); + + example_iterator_factory_ = + std::make_unique<FunctionalExampleIteratorFactory>( + [&dataset = dataset_]( + const google::internal::federated::plan::ExampleSelector& + selector) { + return std::make_unique<SimpleExampleIterator>(dataset); + }); + + EXPECT_CALL( + mock_opstats_logger_, + UpdateDatasetStats(kCollectionUri, num_examples_, example_bytes_)); + + ExampleQueryPlanEngine plan_engine({example_iterator_factory_.get()}, + &mock_opstats_logger_); + engine::PlanResult result = + plan_engine.RunPlan(client_only_plan_.phase().example_query_spec(), + output_checkpoint_filename_); + + EXPECT_THAT(result.outcome, PlanOutcome::kExampleIteratorError); +} + +TEST_F(ExampleQueryPlanEngineTest, OutputVectorSpecTypeMismatch) { + Initialize(); + + ExampleQuerySpec::OutputVectorSpec new_vector_spec; + new_vector_spec.set_vector_name("new_vector"); + new_vector_spec.set_data_type(ExampleQuerySpec::OutputVectorSpec::DOUBLE); + + ExampleQuerySpec::ExampleQuery example_query = + client_only_plan_.phase().example_query_spec().example_queries().at(0); + (*example_query.mutable_output_vector_specs())["new_tensor"] = + new_vector_spec; + client_only_plan_.mutable_phase() + ->mutable_example_query_spec() + ->clear_example_queries(); + client_only_plan_.mutable_phase() + ->mutable_example_query_spec() + ->mutable_example_queries() + ->Add(std::move(example_query)); + + EXPECT_CALL( + mock_opstats_logger_, + UpdateDatasetStats(kCollectionUri, num_examples_, example_bytes_)); + + ExampleQueryPlanEngine plan_engine({example_iterator_factory_.get()}, + &mock_opstats_logger_); + engine::PlanResult result = + plan_engine.RunPlan(client_only_plan_.phase().example_query_spec(), + output_checkpoint_filename_); + + EXPECT_THAT(result.outcome, PlanOutcome::kExampleIteratorError); +} + +TEST_F(ExampleQueryPlanEngineTest, FactoryNotFound) { + Initialize(); + auto invalid_example_factory = + std::make_unique<InvalidExampleIteratorFactory>(); + + ExampleQueryPlanEngine plan_engine({invalid_example_factory.get()}, + &mock_opstats_logger_); + engine::PlanResult result = + plan_engine.RunPlan(client_only_plan_.phase().example_query_spec(), + output_checkpoint_filename_); + + EXPECT_THAT(result.outcome, PlanOutcome::kExampleIteratorError); +} + +TEST_F(ExampleQueryPlanEngineTest, NoIteratorCreated) { + Initialize(); + auto invalid_example_factory = + std::make_unique<NoIteratorExampleIteratorFactory>(); + + ExampleQueryPlanEngine plan_engine({invalid_example_factory.get()}, + &mock_opstats_logger_); + engine::PlanResult result = + plan_engine.RunPlan(client_only_plan_.phase().example_query_spec(), + output_checkpoint_filename_); + + EXPECT_THAT(result.outcome, PlanOutcome::kExampleIteratorError); +} + +TEST_F(ExampleQueryPlanEngineTest, InvalidExampleQueryResultFormat) { + Initialize(); + std::string invalid_example = "invalid_example"; + Dataset::ClientDataset client_dataset; + client_dataset.add_example(invalid_example); + dataset_.clear_client_data(); + dataset_.mutable_client_data()->Add(std::move(client_dataset)); + example_iterator_factory_ = + std::make_unique<FunctionalExampleIteratorFactory>( + [&dataset = dataset_]( + const google::internal::federated::plan::ExampleSelector& + selector) { + return std::make_unique<SimpleExampleIterator>(dataset); + }); + EXPECT_CALL(mock_opstats_logger_, + UpdateDatasetStats(kCollectionUri, 1, invalid_example.size())); + + ExampleQueryPlanEngine plan_engine({example_iterator_factory_.get()}, + &mock_opstats_logger_); + engine::PlanResult result = + plan_engine.RunPlan(client_only_plan_.phase().example_query_spec(), + output_checkpoint_filename_); + + EXPECT_THAT(result.outcome, PlanOutcome::kExampleIteratorError); +} + +} // anonymous namespace +} // namespace engine +} // namespace client +} // namespace fcp |