diff options
author | felobato <felobato@google.com> | 2023-01-17 14:54:45 -0800 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2023-01-17 14:55:55 -0800 |
commit | 67cb689b0fa57095c56414349e15fedf2b7d0c18 (patch) | |
tree | 5a8cacb4d1de3a6b51f91bdbf002d0da73660c43 /cc/prf | |
parent | b23975e79c92067bea1c86922c37d678a520abd0 (diff) | |
download | tink-67cb689b0fa57095c56414349e15fedf2b7d0c18.tar.gz |
[C++] Monitoring for PRF Set
Wrap each PRF in the PRFSet with a `MonitoredPRF`. The `MonitoredPRF` will emit monitoring when enabled. The PRFSet now needs to also keep ownership of each MonitoredPRF. Because we already have a public facing API that returns a map<uint32,PRF*>, we can't remove that contract hence having a duplicate structure to keep track of the new MonitoredPRFs.
PiperOrigin-RevId: 502696442
Diffstat (limited to 'cc/prf')
-rw-r--r-- | cc/prf/BUILD.bazel | 8 | ||||
-rw-r--r-- | cc/prf/CMakeLists.txt | 8 | ||||
-rw-r--r-- | cc/prf/prf_set_wrapper.cc | 84 | ||||
-rw-r--r-- | cc/prf/prf_set_wrapper_test.cc | 124 |
4 files changed, 212 insertions, 12 deletions
diff --git a/cc/prf/BUILD.bazel b/cc/prf/BUILD.bazel index cf4d7b6e9..dbb9e1cfd 100644 --- a/cc/prf/BUILD.bazel +++ b/cc/prf/BUILD.bazel @@ -90,11 +90,15 @@ cc_library( ":prf_set", "//:primitive_set", "//:primitive_wrapper", + "//internal:monitoring_util", + "//internal:registry_impl", + "//monitoring", "//proto:tink_cc_proto", "//util:status", "//util:statusor", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", ], ) @@ -195,10 +199,14 @@ cc_test( ":prf_set", ":prf_set_wrapper", "//:primitive_set", + "//:registry", + "//monitoring:monitoring_client_mocks", "//proto:tink_cc_proto", + "//util:status", "//util:statusor", "//util:test_matchers", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", ], diff --git a/cc/prf/CMakeLists.txt b/cc/prf/CMakeLists.txt index 028587e44..bca726de9 100644 --- a/cc/prf/CMakeLists.txt +++ b/cc/prf/CMakeLists.txt @@ -79,8 +79,12 @@ tink_cc_library( tink::prf::prf_set absl::memory absl::status + absl::statusor tink::core::primitive_set tink::core::primitive_wrapper + tink::internal::monitoring_util + tink::internal::registry_impl + tink::monitoring::monitoring tink::util::status tink::util::statusor tink::proto::tink_cc_proto @@ -183,8 +187,12 @@ tink_cc_test( tink::prf::prf_set_wrapper gmock absl::memory + absl::status absl::strings tink::core::primitive_set + tink::core::registry + tink::monitoring::monitoring_client_mocks + tink::util::status tink::util::statusor tink::util::test_matchers tink::proto::tink_cc_proto diff --git a/cc/prf/prf_set_wrapper.cc b/cc/prf/prf_set_wrapper.cc index 89adbdaf7..af81b56d2 100644 --- a/cc/prf/prf_set_wrapper.cc +++ b/cc/prf/prf_set_wrapper.cc @@ -15,12 +15,20 @@ /////////////////////////////////////////////////////////////////////////////// #include "tink/prf/prf_set_wrapper.h" +#include <cstdint> #include <map> #include <memory> +#include <string> #include <utility> +#include <vector> #include "absl/memory/memory.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tink/internal/monitoring_util.h" +#include "tink/internal/registry_impl.h" +#include "tink/monitoring/monitoring.h" +#include "tink/prf/prf_set.h" #include "tink/util/status.h" #include "proto/tink.pb.h" @@ -31,12 +39,59 @@ using google::crypto::tink::OutputPrefixType; namespace { +constexpr absl::string_view kPrimitive = "prf"; +constexpr absl::string_view kComputeApi = "compute"; + +class MonitoredPrf : public Prf { + public: + explicit MonitoredPrf(uint32_t key_id, const Prf* prf, + MonitoringClient* monitoring_client) + : key_id_(key_id), prf_(prf), monitoring_client_(monitoring_client) {} + ~MonitoredPrf() override = default; + + MonitoredPrf(MonitoredPrf&& other) = default; + MonitoredPrf& operator=(MonitoredPrf&& other) = default; + + MonitoredPrf(const MonitoredPrf&) = delete; + MonitoredPrf& operator=(const MonitoredPrf&) = delete; + + util::StatusOr<std::string> Compute(absl::string_view input, + size_t output_length) const override { + util::StatusOr<std::string> result = prf_->Compute(input, output_length); + if (!result.ok()) { + if (monitoring_client_ != nullptr) { + monitoring_client_->LogFailure(); + } + return result.status(); + } + + if (monitoring_client_ != nullptr) { + monitoring_client_->Log(key_id_, input.size()); + } + return result.value(); + } + + private: + uint32_t key_id_; + const Prf* prf_; + MonitoringClient* monitoring_client_; +}; + class PrfSetPrimitiveWrapper : public PrfSet { public: - explicit PrfSetPrimitiveWrapper(std::unique_ptr<PrimitiveSet<Prf>> prf_set) - : prf_set_(std::move(prf_set)) { + explicit PrfSetPrimitiveWrapper( + std::unique_ptr<PrimitiveSet<Prf>> prf_set, + std::unique_ptr<MonitoringClient> monitoring_client = nullptr) + : prf_set_(std::move(prf_set)), + monitoring_client_(std::move(monitoring_client)) { + wrapped_prfs_.reserve(prf_set_->get_raw_primitives().value()->size()); for (const auto& prf : *prf_set_->get_raw_primitives().value()) { - prfs_.insert({prf->get_key_id(), &prf->get_primitive()}); + std::unique_ptr<Prf> wrapped_prf = std::make_unique<MonitoredPrf>( + prf->get_key_id(), &prf->get_primitive(), + monitoring_client_.get()); + + prfs_.insert({prf->get_key_id(), wrapped_prf.get()}); + wrapped_prfs_.push_back(std::move(wrapped_prf)); } } @@ -49,6 +104,8 @@ class PrfSetPrimitiveWrapper : public PrfSet { private: std::unique_ptr<PrimitiveSet<Prf>> prf_set_; + std::unique_ptr<MonitoringClient> monitoring_client_; + std::vector<std::unique_ptr<Prf>> wrapped_prfs_; std::map<uint32_t, Prf*> prfs_; }; @@ -76,7 +133,26 @@ util::StatusOr<std::unique_ptr<PrfSet>> PrfSetWrapper::Wrap( std::unique_ptr<PrimitiveSet<Prf>> prf_set) const { util::Status status = Validate(prf_set.get()); if (!status.ok()) return status; - return {absl::make_unique<PrfSetPrimitiveWrapper>(std::move(prf_set))}; + + MonitoringClientFactory* const monitoring_factory = + internal::RegistryImpl::GlobalInstance().GetMonitoringClientFactory(); + // Monitoring is not enabled. Create a wrapper without monitoring clients. + if (monitoring_factory == nullptr) { + return {absl::make_unique<PrfSetPrimitiveWrapper>(std::move(prf_set))}; + } + util::StatusOr<MonitoringKeySetInfo> keyset_info = + internal::MonitoringKeySetInfoFromPrimitiveSet(*prf_set); + if (!keyset_info.ok()) { + return keyset_info.status(); + } + util::StatusOr<std::unique_ptr<MonitoringClient>> monitoring_client = + monitoring_factory->New( + MonitoringContext(kPrimitive, kComputeApi, *keyset_info)); + if (!monitoring_client.ok()) { + return monitoring_client.status(); + } + return {absl::make_unique<PrfSetPrimitiveWrapper>( + std::move(prf_set), *std::move(monitoring_client))}; } } // namespace tink diff --git a/cc/prf/prf_set_wrapper_test.cc b/cc/prf/prf_set_wrapper_test.cc index 095c67288..b100de165 100644 --- a/cc/prf/prf_set_wrapper_test.cc +++ b/cc/prf/prf_set_wrapper_test.cc @@ -16,6 +16,7 @@ #include "tink/prf/prf_set_wrapper.h" #include <cstdint> +#include <map> #include <memory> #include <string> #include <utility> @@ -23,9 +24,13 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "tink/monitoring/monitoring_client_mocks.h" #include "tink/prf/prf_set.h" #include "tink/primitive_set.h" +#include "tink/registry.h" +#include "tink/util/status.h" #include "tink/util/statusor.h" #include "tink/util/test_matchers.h" #include "proto/tink.pb.h" @@ -38,11 +43,24 @@ using ::crypto::tink::test::IsOk; using ::crypto::tink::test::IsOkAndHolds; using ::google::crypto::tink::KeysetInfo; using ::google::crypto::tink::KeyStatusType; +using ::testing::_; +using ::testing::ByMove; using ::testing::Key; +using ::testing::NiceMock; using ::testing::Not; +using ::testing::Return; using ::testing::StrEq; +using ::testing::Test; using ::testing::UnorderedElementsAre; +KeysetInfo::KeyInfo MakeKey(uint32_t id) { + KeysetInfo::KeyInfo key; + key.set_output_prefix_type(google::crypto::tink::OutputPrefixType::RAW); + key.set_key_id(id); + key.set_status(KeyStatusType::ENABLED); + return key; +} + class FakePrf : public Prf { public: explicit FakePrf(const std::string& output) : output_(output) {} @@ -65,14 +83,6 @@ class PrfSetWrapperTest : public ::testing::Test { return prf_set_->AddPrimitive(std::move(prf), key_info); } - KeysetInfo::KeyInfo MakeKey(uint32_t id) { - KeysetInfo::KeyInfo key; - key.set_output_prefix_type(google::crypto::tink::OutputPrefixType::RAW); - key.set_key_id(id); - key.set_status(KeyStatusType::ENABLED); - return key; - } - std::unique_ptr<PrimitiveSet<Prf>>& PrfSet() { return prf_set_; } private: @@ -134,6 +144,104 @@ TEST_F(PrfSetWrapperTest, WrapTwo) { IsOkAndHolds(StrEq("different"))); } +// Tests for the monitoring behavior. +class PrfSetWrapperWithMonitoringTest : public Test { + protected: + // Reset the global registry. + void SetUp() override { + Registry::Reset(); + // Setup mocks for catching Monitoring calls. + auto monitoring_client_factory = + absl::make_unique<MockMonitoringClientFactory>(); + auto monitoring_client = + absl::make_unique<NiceMock<MockMonitoringClient>>(); + monitoring_client_ref_ = monitoring_client.get(); + // Monitoring tests expect that the client factory will create the + // corresponding MockMonitoringClients. + EXPECT_CALL(*monitoring_client_factory, New(_)) + .WillOnce( + Return(ByMove(util::StatusOr<std::unique_ptr<MonitoringClient>>( + std::move(monitoring_client))))); + + ASSERT_THAT(internal::RegistryImpl::GlobalInstance() + .RegisterMonitoringClientFactory( + std::move(monitoring_client_factory)), + IsOk()); + ASSERT_THAT( + internal::RegistryImpl::GlobalInstance().GetMonitoringClientFactory(), + Not(testing::IsNull())); + } + + // Cleanup the registry to avoid mock leaks. + ~PrfSetWrapperWithMonitoringTest() override { Registry::Reset(); } + + NiceMock<MockMonitoringClient>* monitoring_client_ref_; +}; + +class AlwaysFailingPrf : public Prf { + public: + AlwaysFailingPrf() = default; + + util::StatusOr<std::string> Compute(absl::string_view input, + size_t output_length) const override { + return util::Status(absl::StatusCode::kOutOfRange, "AlwaysFailingPrf"); + } +}; + +TEST_F(PrfSetWrapperWithMonitoringTest, WrapKeysetWithMonitoringFailure) { + const absl::flat_hash_map<std::string, std::string> annotations = { + {"key1", "value1"}, {"key2", "value2"}, {"key3", "value3"}}; + auto primitive_set = absl::make_unique<PrimitiveSet<Prf>>(annotations); + util::StatusOr<PrimitiveSet<Prf>::Entry<Prf>*> entry = + primitive_set->AddPrimitive(absl::make_unique<AlwaysFailingPrf>(), + MakeKey(/*id=*/1)); + ASSERT_THAT(entry, IsOk()); + ASSERT_THAT(primitive_set->set_primary(entry.value()), IsOk()); + ASSERT_THAT(primitive_set + ->AddPrimitive(absl::make_unique<FakePrf>("output"), + MakeKey(/*id=*/1)) + .status(), + IsOk()); + util::StatusOr<std::unique_ptr<PrfSet>> prf_set = + PrfSetWrapper().Wrap(std::move(primitive_set)); + ASSERT_THAT(prf_set, IsOk()); + EXPECT_CALL(*monitoring_client_ref_, LogFailure()); + EXPECT_THAT((*prf_set)->ComputePrimary("input", /*output_length=*/16), + Not(IsOk())); +} + +TEST_F(PrfSetWrapperWithMonitoringTest, WrapKeysetWithMonitoringVerifySuccess) { + const absl::flat_hash_map<std::string, std::string> annotations = { + {"key1", "value1"}, {"key2", "value2"}, {"key3", "value3"}}; + auto primitive_set = absl::make_unique<PrimitiveSet<Prf>>(annotations); + + util::StatusOr<PrimitiveSet<Prf>::Entry<Prf>*> entry = + primitive_set->AddPrimitive(absl::make_unique<FakePrf>("output"), + MakeKey(/*id=*/1)); + ASSERT_THAT(entry, IsOk()); + ASSERT_THAT(primitive_set->set_primary(entry.value()), IsOk()); + ASSERT_THAT(primitive_set + ->AddPrimitive(absl::make_unique<FakePrf>("output"), + MakeKey(/*id=*/1)) + .status(), + IsOk()); + + util::StatusOr<std::unique_ptr<PrfSet>> prf_set = + PrfSetWrapper().Wrap(std::move(primitive_set)); + ASSERT_THAT(prf_set, IsOk()); + std::map<uint32_t, Prf*> prf_map = (*prf_set)->GetPrfs(); + std::string input = "input"; + for (const auto& entry : prf_map) { + EXPECT_CALL(*monitoring_client_ref_, Log(entry.first, input.size())); + EXPECT_THAT((entry.second)->Compute(input, /*output_length=*/16).status(), + IsOk()); + } + input = "hello_world"; + EXPECT_CALL(*monitoring_client_ref_, + Log((*prf_set)->GetPrimaryId(), input.size())); + EXPECT_THAT((*prf_set)->ComputePrimary(input, /*output_length=*/16), IsOk()); +} + } // namespace } // namespace tink } // namespace crypto |