aboutsummaryrefslogtreecommitdiff
path: root/cc/prf
diff options
context:
space:
mode:
authorfelobato <felobato@google.com>2023-01-17 14:54:45 -0800
committerCopybara-Service <copybara-worker@google.com>2023-01-17 14:55:55 -0800
commit67cb689b0fa57095c56414349e15fedf2b7d0c18 (patch)
tree5a8cacb4d1de3a6b51f91bdbf002d0da73660c43 /cc/prf
parentb23975e79c92067bea1c86922c37d678a520abd0 (diff)
downloadtink-67cb689b0fa57095c56414349e15fedf2b7d0c18.tar.gz
[C++] Monitoring for PRF Set
Wrap each PRF in the PRFSet with a `MonitoredPRF`. The `MonitoredPRF` will emit monitoring when enabled. The PRFSet now needs to also keep ownership of each MonitoredPRF. Because we already have a public facing API that returns a map<uint32,PRF*>, we can't remove that contract hence having a duplicate structure to keep track of the new MonitoredPRFs. PiperOrigin-RevId: 502696442
Diffstat (limited to 'cc/prf')
-rw-r--r--cc/prf/BUILD.bazel8
-rw-r--r--cc/prf/CMakeLists.txt8
-rw-r--r--cc/prf/prf_set_wrapper.cc84
-rw-r--r--cc/prf/prf_set_wrapper_test.cc124
4 files changed, 212 insertions, 12 deletions
diff --git a/cc/prf/BUILD.bazel b/cc/prf/BUILD.bazel
index cf4d7b6e9..dbb9e1cfd 100644
--- a/cc/prf/BUILD.bazel
+++ b/cc/prf/BUILD.bazel
@@ -90,11 +90,15 @@ cc_library(
":prf_set",
"//:primitive_set",
"//:primitive_wrapper",
+ "//internal:monitoring_util",
+ "//internal:registry_impl",
+ "//monitoring",
"//proto:tink_cc_proto",
"//util:status",
"//util:statusor",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
],
)
@@ -195,10 +199,14 @@ cc_test(
":prf_set",
":prf_set_wrapper",
"//:primitive_set",
+ "//:registry",
+ "//monitoring:monitoring_client_mocks",
"//proto:tink_cc_proto",
+ "//util:status",
"//util:statusor",
"//util:test_matchers",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
],
diff --git a/cc/prf/CMakeLists.txt b/cc/prf/CMakeLists.txt
index 028587e44..bca726de9 100644
--- a/cc/prf/CMakeLists.txt
+++ b/cc/prf/CMakeLists.txt
@@ -79,8 +79,12 @@ tink_cc_library(
tink::prf::prf_set
absl::memory
absl::status
+ absl::statusor
tink::core::primitive_set
tink::core::primitive_wrapper
+ tink::internal::monitoring_util
+ tink::internal::registry_impl
+ tink::monitoring::monitoring
tink::util::status
tink::util::statusor
tink::proto::tink_cc_proto
@@ -183,8 +187,12 @@ tink_cc_test(
tink::prf::prf_set_wrapper
gmock
absl::memory
+ absl::status
absl::strings
tink::core::primitive_set
+ tink::core::registry
+ tink::monitoring::monitoring_client_mocks
+ tink::util::status
tink::util::statusor
tink::util::test_matchers
tink::proto::tink_cc_proto
diff --git a/cc/prf/prf_set_wrapper.cc b/cc/prf/prf_set_wrapper.cc
index 89adbdaf7..af81b56d2 100644
--- a/cc/prf/prf_set_wrapper.cc
+++ b/cc/prf/prf_set_wrapper.cc
@@ -15,12 +15,20 @@
///////////////////////////////////////////////////////////////////////////////
#include "tink/prf/prf_set_wrapper.h"
+#include <cstdint>
#include <map>
#include <memory>
+#include <string>
#include <utility>
+#include <vector>
#include "absl/memory/memory.h"
#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "tink/internal/monitoring_util.h"
+#include "tink/internal/registry_impl.h"
+#include "tink/monitoring/monitoring.h"
+#include "tink/prf/prf_set.h"
#include "tink/util/status.h"
#include "proto/tink.pb.h"
@@ -31,12 +39,59 @@ using google::crypto::tink::OutputPrefixType;
namespace {
+constexpr absl::string_view kPrimitive = "prf";
+constexpr absl::string_view kComputeApi = "compute";
+
+class MonitoredPrf : public Prf {
+ public:
+ explicit MonitoredPrf(uint32_t key_id, const Prf* prf,
+ MonitoringClient* monitoring_client)
+ : key_id_(key_id), prf_(prf), monitoring_client_(monitoring_client) {}
+ ~MonitoredPrf() override = default;
+
+ MonitoredPrf(MonitoredPrf&& other) = default;
+ MonitoredPrf& operator=(MonitoredPrf&& other) = default;
+
+ MonitoredPrf(const MonitoredPrf&) = delete;
+ MonitoredPrf& operator=(const MonitoredPrf&) = delete;
+
+ util::StatusOr<std::string> Compute(absl::string_view input,
+ size_t output_length) const override {
+ util::StatusOr<std::string> result = prf_->Compute(input, output_length);
+ if (!result.ok()) {
+ if (monitoring_client_ != nullptr) {
+ monitoring_client_->LogFailure();
+ }
+ return result.status();
+ }
+
+ if (monitoring_client_ != nullptr) {
+ monitoring_client_->Log(key_id_, input.size());
+ }
+ return result.value();
+ }
+
+ private:
+ uint32_t key_id_;
+ const Prf* prf_;
+ MonitoringClient* monitoring_client_;
+};
+
class PrfSetPrimitiveWrapper : public PrfSet {
public:
- explicit PrfSetPrimitiveWrapper(std::unique_ptr<PrimitiveSet<Prf>> prf_set)
- : prf_set_(std::move(prf_set)) {
+ explicit PrfSetPrimitiveWrapper(
+ std::unique_ptr<PrimitiveSet<Prf>> prf_set,
+ std::unique_ptr<MonitoringClient> monitoring_client = nullptr)
+ : prf_set_(std::move(prf_set)),
+ monitoring_client_(std::move(monitoring_client)) {
+ wrapped_prfs_.reserve(prf_set_->get_raw_primitives().value()->size());
for (const auto& prf : *prf_set_->get_raw_primitives().value()) {
- prfs_.insert({prf->get_key_id(), &prf->get_primitive()});
+ std::unique_ptr<Prf> wrapped_prf = std::make_unique<MonitoredPrf>(
+ prf->get_key_id(), &prf->get_primitive(),
+ monitoring_client_.get());
+
+ prfs_.insert({prf->get_key_id(), wrapped_prf.get()});
+ wrapped_prfs_.push_back(std::move(wrapped_prf));
}
}
@@ -49,6 +104,8 @@ class PrfSetPrimitiveWrapper : public PrfSet {
private:
std::unique_ptr<PrimitiveSet<Prf>> prf_set_;
+ std::unique_ptr<MonitoringClient> monitoring_client_;
+ std::vector<std::unique_ptr<Prf>> wrapped_prfs_;
std::map<uint32_t, Prf*> prfs_;
};
@@ -76,7 +133,26 @@ util::StatusOr<std::unique_ptr<PrfSet>> PrfSetWrapper::Wrap(
std::unique_ptr<PrimitiveSet<Prf>> prf_set) const {
util::Status status = Validate(prf_set.get());
if (!status.ok()) return status;
- return {absl::make_unique<PrfSetPrimitiveWrapper>(std::move(prf_set))};
+
+ MonitoringClientFactory* const monitoring_factory =
+ internal::RegistryImpl::GlobalInstance().GetMonitoringClientFactory();
+ // Monitoring is not enabled. Create a wrapper without monitoring clients.
+ if (monitoring_factory == nullptr) {
+ return {absl::make_unique<PrfSetPrimitiveWrapper>(std::move(prf_set))};
+ }
+ util::StatusOr<MonitoringKeySetInfo> keyset_info =
+ internal::MonitoringKeySetInfoFromPrimitiveSet(*prf_set);
+ if (!keyset_info.ok()) {
+ return keyset_info.status();
+ }
+ util::StatusOr<std::unique_ptr<MonitoringClient>> monitoring_client =
+ monitoring_factory->New(
+ MonitoringContext(kPrimitive, kComputeApi, *keyset_info));
+ if (!monitoring_client.ok()) {
+ return monitoring_client.status();
+ }
+ return {absl::make_unique<PrfSetPrimitiveWrapper>(
+ std::move(prf_set), *std::move(monitoring_client))};
}
} // namespace tink
diff --git a/cc/prf/prf_set_wrapper_test.cc b/cc/prf/prf_set_wrapper_test.cc
index 095c67288..b100de165 100644
--- a/cc/prf/prf_set_wrapper_test.cc
+++ b/cc/prf/prf_set_wrapper_test.cc
@@ -16,6 +16,7 @@
#include "tink/prf/prf_set_wrapper.h"
#include <cstdint>
+#include <map>
#include <memory>
#include <string>
#include <utility>
@@ -23,9 +24,13 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/memory/memory.h"
+#include "absl/status/status.h"
#include "absl/strings/string_view.h"
+#include "tink/monitoring/monitoring_client_mocks.h"
#include "tink/prf/prf_set.h"
#include "tink/primitive_set.h"
+#include "tink/registry.h"
+#include "tink/util/status.h"
#include "tink/util/statusor.h"
#include "tink/util/test_matchers.h"
#include "proto/tink.pb.h"
@@ -38,11 +43,24 @@ using ::crypto::tink::test::IsOk;
using ::crypto::tink::test::IsOkAndHolds;
using ::google::crypto::tink::KeysetInfo;
using ::google::crypto::tink::KeyStatusType;
+using ::testing::_;
+using ::testing::ByMove;
using ::testing::Key;
+using ::testing::NiceMock;
using ::testing::Not;
+using ::testing::Return;
using ::testing::StrEq;
+using ::testing::Test;
using ::testing::UnorderedElementsAre;
+KeysetInfo::KeyInfo MakeKey(uint32_t id) {
+ KeysetInfo::KeyInfo key;
+ key.set_output_prefix_type(google::crypto::tink::OutputPrefixType::RAW);
+ key.set_key_id(id);
+ key.set_status(KeyStatusType::ENABLED);
+ return key;
+}
+
class FakePrf : public Prf {
public:
explicit FakePrf(const std::string& output) : output_(output) {}
@@ -65,14 +83,6 @@ class PrfSetWrapperTest : public ::testing::Test {
return prf_set_->AddPrimitive(std::move(prf), key_info);
}
- KeysetInfo::KeyInfo MakeKey(uint32_t id) {
- KeysetInfo::KeyInfo key;
- key.set_output_prefix_type(google::crypto::tink::OutputPrefixType::RAW);
- key.set_key_id(id);
- key.set_status(KeyStatusType::ENABLED);
- return key;
- }
-
std::unique_ptr<PrimitiveSet<Prf>>& PrfSet() { return prf_set_; }
private:
@@ -134,6 +144,104 @@ TEST_F(PrfSetWrapperTest, WrapTwo) {
IsOkAndHolds(StrEq("different")));
}
+// Tests for the monitoring behavior.
+class PrfSetWrapperWithMonitoringTest : public Test {
+ protected:
+ // Reset the global registry.
+ void SetUp() override {
+ Registry::Reset();
+ // Setup mocks for catching Monitoring calls.
+ auto monitoring_client_factory =
+ absl::make_unique<MockMonitoringClientFactory>();
+ auto monitoring_client =
+ absl::make_unique<NiceMock<MockMonitoringClient>>();
+ monitoring_client_ref_ = monitoring_client.get();
+ // Monitoring tests expect that the client factory will create the
+ // corresponding MockMonitoringClients.
+ EXPECT_CALL(*monitoring_client_factory, New(_))
+ .WillOnce(
+ Return(ByMove(util::StatusOr<std::unique_ptr<MonitoringClient>>(
+ std::move(monitoring_client)))));
+
+ ASSERT_THAT(internal::RegistryImpl::GlobalInstance()
+ .RegisterMonitoringClientFactory(
+ std::move(monitoring_client_factory)),
+ IsOk());
+ ASSERT_THAT(
+ internal::RegistryImpl::GlobalInstance().GetMonitoringClientFactory(),
+ Not(testing::IsNull()));
+ }
+
+ // Cleanup the registry to avoid mock leaks.
+ ~PrfSetWrapperWithMonitoringTest() override { Registry::Reset(); }
+
+ NiceMock<MockMonitoringClient>* monitoring_client_ref_;
+};
+
+class AlwaysFailingPrf : public Prf {
+ public:
+ AlwaysFailingPrf() = default;
+
+ util::StatusOr<std::string> Compute(absl::string_view input,
+ size_t output_length) const override {
+ return util::Status(absl::StatusCode::kOutOfRange, "AlwaysFailingPrf");
+ }
+};
+
+TEST_F(PrfSetWrapperWithMonitoringTest, WrapKeysetWithMonitoringFailure) {
+ const absl::flat_hash_map<std::string, std::string> annotations = {
+ {"key1", "value1"}, {"key2", "value2"}, {"key3", "value3"}};
+ auto primitive_set = absl::make_unique<PrimitiveSet<Prf>>(annotations);
+ util::StatusOr<PrimitiveSet<Prf>::Entry<Prf>*> entry =
+ primitive_set->AddPrimitive(absl::make_unique<AlwaysFailingPrf>(),
+ MakeKey(/*id=*/1));
+ ASSERT_THAT(entry, IsOk());
+ ASSERT_THAT(primitive_set->set_primary(entry.value()), IsOk());
+ ASSERT_THAT(primitive_set
+ ->AddPrimitive(absl::make_unique<FakePrf>("output"),
+ MakeKey(/*id=*/1))
+ .status(),
+ IsOk());
+ util::StatusOr<std::unique_ptr<PrfSet>> prf_set =
+ PrfSetWrapper().Wrap(std::move(primitive_set));
+ ASSERT_THAT(prf_set, IsOk());
+ EXPECT_CALL(*monitoring_client_ref_, LogFailure());
+ EXPECT_THAT((*prf_set)->ComputePrimary("input", /*output_length=*/16),
+ Not(IsOk()));
+}
+
+TEST_F(PrfSetWrapperWithMonitoringTest, WrapKeysetWithMonitoringVerifySuccess) {
+ const absl::flat_hash_map<std::string, std::string> annotations = {
+ {"key1", "value1"}, {"key2", "value2"}, {"key3", "value3"}};
+ auto primitive_set = absl::make_unique<PrimitiveSet<Prf>>(annotations);
+
+ util::StatusOr<PrimitiveSet<Prf>::Entry<Prf>*> entry =
+ primitive_set->AddPrimitive(absl::make_unique<FakePrf>("output"),
+ MakeKey(/*id=*/1));
+ ASSERT_THAT(entry, IsOk());
+ ASSERT_THAT(primitive_set->set_primary(entry.value()), IsOk());
+ ASSERT_THAT(primitive_set
+ ->AddPrimitive(absl::make_unique<FakePrf>("output"),
+ MakeKey(/*id=*/1))
+ .status(),
+ IsOk());
+
+ util::StatusOr<std::unique_ptr<PrfSet>> prf_set =
+ PrfSetWrapper().Wrap(std::move(primitive_set));
+ ASSERT_THAT(prf_set, IsOk());
+ std::map<uint32_t, Prf*> prf_map = (*prf_set)->GetPrfs();
+ std::string input = "input";
+ for (const auto& entry : prf_map) {
+ EXPECT_CALL(*monitoring_client_ref_, Log(entry.first, input.size()));
+ EXPECT_THAT((entry.second)->Compute(input, /*output_length=*/16).status(),
+ IsOk());
+ }
+ input = "hello_world";
+ EXPECT_CALL(*monitoring_client_ref_,
+ Log((*prf_set)->GetPrimaryId(), input.size()));
+ EXPECT_THAT((*prf_set)->ComputePrimary(input, /*output_length=*/16), IsOk());
+}
+
} // namespace
} // namespace tink
} // namespace crypto