aboutsummaryrefslogtreecommitdiff
path: root/cc/util
diff options
context:
space:
mode:
authorambrosin <ambrosin@google.com>2023-01-20 21:08:46 -0800
committerCopybara-Service <copybara-worker@google.com>2023-01-20 21:10:26 -0800
commit8caaaca75a7317ddccc22723e50197040ffb2a92 (patch)
treed31fe048f89a453d7772b6f63cdaacd7ed5f21f4 /cc/util
parentc97b0eca3cb8a56e4942606c8cef1c47684d7e41 (diff)
downloadtink-8caaaca75a7317ddccc22723e50197040ffb2a92.tar.gz
Fix DummyDecryptingRandomAccessStream and add unit tests
DummyDecryptingRandomAccessStream behaves incorrectly when reading all the content from the stream returns an EOF, and > 0 bytes are added to the buffer. PiperOrigin-RevId: 503585078
Diffstat (limited to 'cc/util')
-rw-r--r--cc/util/BUILD.bazel8
-rw-r--r--cc/util/CMakeLists.txt8
-rw-r--r--cc/util/test_util.h57
-rw-r--r--cc/util/test_util_test.cc181
4 files changed, 230 insertions, 24 deletions
diff --git a/cc/util/BUILD.bazel b/cc/util/BUILD.bazel
index c95ca0c30..5b74a21e1 100644
--- a/cc/util/BUILD.bazel
+++ b/cc/util/BUILD.bazel
@@ -537,11 +537,19 @@ cc_test(
name = "test_util_test",
srcs = ["test_util_test.cc"],
deps = [
+ ":buffer",
+ ":ostream_output_stream",
+ ":statusor",
":test_matchers",
":test_util",
+ "//:output_stream",
+ "//:random_access_stream",
+ "//internal:test_random_access_stream",
"//proto:aes_gcm_cc_proto",
"//proto:tink_cc_proto",
"//subtle",
+ "//subtle:test_util",
+ "@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
],
)
diff --git a/cc/util/CMakeLists.txt b/cc/util/CMakeLists.txt
index 8d64d7d93..a5f8adbd9 100644
--- a/cc/util/CMakeLists.txt
+++ b/cc/util/CMakeLists.txt
@@ -370,10 +370,18 @@ tink_cc_test(
SRCS
test_util_test.cc
DEPS
+ tink::util::buffer
+ tink::util::ostream_output_stream
+ tink::util::statusor
tink::util::test_matchers
tink::util::test_util
gmock
+ absl::strings
+ tink::core::output_stream
+ tink::core::random_access_stream
+ tink::internal::test_random_access_stream
tink::subtle::subtle
+ tink::subtle::test_util
tink::proto::aes_gcm_cc_proto
tink::proto::tink_cc_proto
)
diff --git a/cc/util/test_util.h b/cc/util/test_util.h
index 5f97f5509..4e2657d86 100644
--- a/cc/util/test_util.h
+++ b/cc/util/test_util.h
@@ -495,40 +495,35 @@ class DummyStreamingAead : public StreamingAead {
util::Status status_;
}; // class DummyDecryptingStream
- // Upon first call to PRead() tries to read from 'ct_source' a header
- // that is expected to be equal to 'expected_header'. If this
+ // Upon first call to PRead() tries to read from `ct_source` a header
+ // that is expected to be equal to `expected_header`. If this
// header matching succeeds, all subsequent method calls are forwarded
- // to the corresponding methods of 'cd_source'.
+ // to `ct_source->PRead`.
class DummyDecryptingRandomAccessStream
: public crypto::tink::RandomAccessStream {
public:
DummyDecryptingRandomAccessStream(
std::unique_ptr<crypto::tink::RandomAccessStream> ct_source,
absl::string_view expected_header)
- : ct_source_(std::move(ct_source)),
- exp_header_(expected_header),
- status_(util::Status(absl::StatusCode::kUnavailable,
- "not initialized")) {}
+ : ct_source_(std::move(ct_source)), exp_header_(expected_header) {}
crypto::tink::util::Status PRead(
int64_t position, int count,
crypto::tink::util::Buffer* dest_buffer) override {
- { // Initialize, if not initialized yet.
- absl::MutexLock lock(&status_mutex_);
- if (status_.code() == absl::StatusCode::kUnavailable) Initialize();
- if (!status_.ok()) return status_;
+ util::Status status = CheckHeader();
+ if (!status.ok()) {
+ return status;
}
- auto status = dest_buffer->set_size(0);
+ status = dest_buffer->set_size(0);
if (!status.ok()) return status;
return ct_source_->PRead(position + exp_header_.size(), count,
dest_buffer);
}
util::StatusOr<int64_t> size() override {
- { // Initialize, if not initialized yet.
- absl::MutexLock lock(&status_mutex_);
- if (status_.code() == absl::StatusCode::kUnavailable) Initialize();
- if (!status_.ok()) return status_;
+ util::Status status = CheckHeader();
+ if (!status.ok()) {
+ return status;
}
auto ct_size_result = ct_source_->size();
if (!ct_size_result.ok()) return ct_size_result.status();
@@ -538,25 +533,39 @@ class DummyStreamingAead : public StreamingAead {
}
private:
- void Initialize() ABSL_EXCLUSIVE_LOCKS_REQUIRED(status_mutex_) {
+ util::Status CheckHeader()
+ ABSL_LOCKS_EXCLUDED(header_check_status_mutex_) {
+ absl::MutexLock lock(&header_check_status_mutex_);
+ if (header_check_status_.code() != absl::StatusCode::kUnavailable) {
+ return header_check_status_;
+ }
auto buf = std::move(util::Buffer::New(exp_header_.size()).value());
- status_ = ct_source_->PRead(0, exp_header_.size(), buf.get());
- if (!status_.ok() && status_.code() != absl::StatusCode::kOutOfRange)
- return;
+ header_check_status_ =
+ ct_source_->PRead(0, exp_header_.size(), buf.get());
+ if (!header_check_status_.ok() &&
+ header_check_status_.code() != absl::StatusCode::kOutOfRange) {
+ return header_check_status_;
+ }
+ // EOF or Ok indicate a valid read has happened.
+ header_check_status_ = util::OkStatus();
+ // Invalid header.
if (buf->size() < exp_header_.size()) {
- status_ = util::Status(absl::StatusCode::kInvalidArgument,
+ header_check_status_ = util::Status(absl::StatusCode::kInvalidArgument,
"Could not read header");
} else if (memcmp(buf->get_mem_block(), exp_header_.data(),
static_cast<int>(exp_header_.size()))) {
- status_ = util::Status(absl::StatusCode::kInvalidArgument,
+ header_check_status_ = util::Status(absl::StatusCode::kInvalidArgument,
"Corrupted header");
}
+ return header_check_status_;
}
std::unique_ptr<crypto::tink::RandomAccessStream> ct_source_;
std::string exp_header_;
- mutable absl::Mutex status_mutex_;
- util::Status status_ ABSL_GUARDED_BY(status_mutex_);
+ mutable absl::Mutex header_check_status_mutex_;
+ util::Status header_check_status_
+ ABSL_GUARDED_BY(header_check_status_mutex_) =
+ util::Status(absl::StatusCode::kUnavailable, "Uninitialized");
}; // class DummyDecryptingRandomAccessStream
private:
diff --git a/cc/util/test_util_test.cc b/cc/util/test_util_test.cc
index 4b8fdf9e0..34171c334 100644
--- a/cc/util/test_util_test.cc
+++ b/cc/util/test_util_test.cc
@@ -15,9 +15,24 @@
///////////////////////////////////////////////////////////////////////////////
#include "tink/util/test_util.h"
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <utility>
+
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "absl/strings/string_view.h"
+#include "tink/internal/test_random_access_stream.h"
+#include "tink/output_stream.h"
+#include "tink/random_access_stream.h"
#include "tink/subtle/random.h"
+#include "tink/subtle/test_util.h"
+#include "tink/util/buffer.h"
+#include "tink/util/ostream_output_stream.h"
+#include "tink/util/statusor.h"
#include "tink/util/test_matchers.h"
#include "proto/aes_gcm.pb.h"
#include "proto/tink.pb.h"
@@ -27,6 +42,8 @@ namespace tink {
namespace test {
namespace {
+using ::crypto::tink::internal::TestRandomAccessStream;
+using ::crypto::tink::test::StatusIs;
using ::google::crypto::tink::AesGcmKey;
using ::google::crypto::tink::KeyData;
using ::testing::Eq;
@@ -107,6 +124,170 @@ TEST(ZTests, AutocorrelationUniformString) {
IsOk());
}
+TEST(DummyStreamingAead, DummyDecryptingStreamPreadAllAtOnceSucceeds) {
+ const int stream_size = 1024;
+ std::string stream_content = subtle::Random::GetRandomBytes(stream_size);
+
+ auto ostream = std::make_unique<std::ostringstream>();
+ auto string_stream_buffer = ostream->rdbuf();
+ auto output_stream =
+ std::make_unique<util::OstreamOutputStream>(std::move(ostream));
+
+ DummyStreamingAead streaming_aead("Some AEAD");
+ util::StatusOr<std::unique_ptr<OutputStream>> encrypting_output_stream =
+ streaming_aead.NewEncryptingStream(std::move(output_stream), "Some AAD");
+ ASSERT_THAT(encrypting_output_stream.status(), IsOk());
+ ASSERT_THAT(subtle::test::WriteToStream(
+ encrypting_output_stream.value().get(), stream_content),
+ IsOk());
+
+ std::string ciphertext = string_stream_buffer->str();
+ auto test_random_access_stream =
+ std::make_unique<TestRandomAccessStream>(ciphertext);
+ util::StatusOr<std::unique_ptr<RandomAccessStream>>
+ decrypting_random_access_stream =
+ streaming_aead.NewDecryptingRandomAccessStream(
+ std::move(test_random_access_stream), "Some AAD");
+ ASSERT_THAT(decrypting_random_access_stream.status(), IsOk());
+
+ auto buffer = util::Buffer::New(ciphertext.size());
+ EXPECT_THAT((*decrypting_random_access_stream)
+ ->PRead(/*position=*/0, ciphertext.size(), buffer->get()),
+ StatusIs(absl::StatusCode::kOutOfRange));
+ EXPECT_EQ(stream_content,
+ std::string((*buffer)->get_mem_block(), (*buffer)->size()));
+}
+
+TEST(DummyStreamingAead, DummyDecryptingStreamPreadInChunksSucceeds) {
+ const int stream_size = 1024;
+ std::string stream_content = subtle::Random::GetRandomBytes(stream_size);
+
+ auto ostream = std::make_unique<std::ostringstream>();
+ auto string_stream_buffer = ostream->rdbuf();
+ auto output_stream =
+ std::make_unique<util::OstreamOutputStream>(std::move(ostream));
+
+ DummyStreamingAead streaming_aead("Some AEAD");
+ util::StatusOr<std::unique_ptr<OutputStream>> encrypting_output_stream =
+ streaming_aead.NewEncryptingStream(std::move(output_stream), "Some AAD");
+ ASSERT_THAT(encrypting_output_stream.status(), IsOk());
+ ASSERT_THAT(subtle::test::WriteToStream(
+ encrypting_output_stream.value().get(), stream_content),
+ IsOk());
+
+ std::string ciphertext = string_stream_buffer->str();
+ auto test_random_access_stream =
+ std::make_unique<TestRandomAccessStream>(ciphertext);
+ util::StatusOr<std::unique_ptr<RandomAccessStream>>
+ decrypting_random_access_stream =
+ streaming_aead.NewDecryptingRandomAccessStream(
+ std::move(test_random_access_stream), "Some AAD");
+ ASSERT_THAT(decrypting_random_access_stream.status(), IsOk());
+
+ int chunk_size = 10;
+ auto buffer = util::Buffer::New(chunk_size);
+ std::string plaintext;
+ int64_t position = 0;
+ util::Status status = (*decrypting_random_access_stream)
+ ->PRead(position, chunk_size, buffer->get());
+ while (status.ok()) {
+ plaintext.append((*buffer)->get_mem_block(), (*buffer)->size());
+ position += (*buffer)->size();
+ status = (*decrypting_random_access_stream)
+ ->PRead(position, chunk_size, buffer->get());
+ }
+ EXPECT_THAT(status, StatusIs(absl::StatusCode::kOutOfRange));
+ plaintext.append((*buffer)->get_mem_block(), (*buffer)->size());
+ EXPECT_EQ(stream_content, plaintext);
+}
+
+TEST(DummyStreamingAead, DummyDecryptingStreamPreadWithSmallerHeaderFails) {
+ const int stream_size = 1024;
+ std::string stream_content = subtle::Random::GetRandomBytes(stream_size);
+
+ auto ostream = std::make_unique<std::ostringstream>();
+ auto output_stream =
+ std::make_unique<util::OstreamOutputStream>(std::move(ostream));
+
+ constexpr absl::string_view kStreamingAeadName = "Some AEAD";
+ constexpr absl::string_view kStreamingAeadAad = "Some associated data";
+
+ DummyStreamingAead streaming_aead(kStreamingAeadName);
+ util::StatusOr<std::unique_ptr<OutputStream>> encrypting_output_stream =
+ streaming_aead.NewEncryptingStream(std::move(output_stream),
+ kStreamingAeadAad);
+ ASSERT_THAT(encrypting_output_stream.status(), IsOk());
+ ASSERT_THAT(subtle::test::WriteToStream(
+ encrypting_output_stream.value().get(), stream_content),
+ IsOk());
+ // Stream content size is too small; DummyDecryptingStream expects
+ // absl::StrCat(kStreamingAeadName, kStreamingAeadAad).
+ std::string ciphertext = "Invalid header";
+ auto test_random_access_stream =
+ std::make_unique<TestRandomAccessStream>(ciphertext);
+ util::StatusOr<std::unique_ptr<RandomAccessStream>>
+ decrypting_random_access_stream =
+ streaming_aead.NewDecryptingRandomAccessStream(
+ std::move(test_random_access_stream), kStreamingAeadAad);
+ ASSERT_THAT(decrypting_random_access_stream.status(), IsOk());
+
+ int chunk_size = 10;
+ auto buffer = util::Buffer::New(chunk_size);
+ EXPECT_THAT(
+ (*decrypting_random_access_stream)
+ ->PRead(/*position=*/0, chunk_size, buffer->get()),
+ StatusIs(absl::StatusCode::kInvalidArgument, "Could not read header"));
+ EXPECT_THAT(
+ (*decrypting_random_access_stream)
+ ->PRead(/*position=*/0, chunk_size, buffer->get()),
+ StatusIs(absl::StatusCode::kInvalidArgument, "Could not read header"));
+ EXPECT_THAT(
+ (*decrypting_random_access_stream)->size().status(),
+ StatusIs(absl::StatusCode::kInvalidArgument, "Could not read header"));
+}
+
+TEST(DummyStreamingAead, DummyDecryptingStreamPreadWithCorruptedAadFails) {
+ const int stream_size = 1024;
+ std::string stream_content = subtle::Random::GetRandomBytes(stream_size);
+
+ auto ostream = std::make_unique<std::ostringstream>();
+ auto string_stream_buffer = ostream->rdbuf();
+ auto output_stream =
+ std::make_unique<util::OstreamOutputStream>(std::move(ostream));
+
+ constexpr absl::string_view kStreamingAeadName = "Some AEAD";
+ constexpr absl::string_view kStreamingAeadAad = "Some associated data";
+
+ DummyStreamingAead streaming_aead(kStreamingAeadName);
+ util::StatusOr<std::unique_ptr<OutputStream>> encrypting_output_stream =
+ streaming_aead.NewEncryptingStream(std::move(output_stream),
+ kStreamingAeadAad);
+ ASSERT_THAT(encrypting_output_stream.status(), IsOk());
+ ASSERT_THAT(subtle::test::WriteToStream(
+ encrypting_output_stream.value().get(), stream_content),
+ IsOk());
+ // Invalid associated data.
+ std::string ciphertext = string_stream_buffer->str();
+ auto test_random_access_stream =
+ std::make_unique<TestRandomAccessStream>(ciphertext);
+ util::StatusOr<std::unique_ptr<RandomAccessStream>>
+ decrypting_random_access_stream =
+ streaming_aead.NewDecryptingRandomAccessStream(
+ std::move(test_random_access_stream), "Some wrong AAD");
+ ASSERT_THAT(decrypting_random_access_stream.status(), IsOk());
+
+ int chunk_size = 10;
+ auto buffer = util::Buffer::New(chunk_size);
+ EXPECT_THAT((*decrypting_random_access_stream)
+ ->PRead(/*position=*/0, chunk_size, buffer->get()),
+ StatusIs(absl::StatusCode::kInvalidArgument, "Corrupted header"));
+ EXPECT_THAT((*decrypting_random_access_stream)
+ ->PRead(/*position=*/0, chunk_size, buffer->get()),
+ StatusIs(absl::StatusCode::kInvalidArgument, "Corrupted header"));
+ EXPECT_THAT((*decrypting_random_access_stream)->size().status(),
+ StatusIs(absl::StatusCode::kInvalidArgument, "Corrupted header"));
+}
+
} // namespace
} // namespace test
} // namespace tink