diff options
author | ambrosin <ambrosin@google.com> | 2023-01-20 21:08:46 -0800 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2023-01-20 21:10:26 -0800 |
commit | 8caaaca75a7317ddccc22723e50197040ffb2a92 (patch) | |
tree | d31fe048f89a453d7772b6f63cdaacd7ed5f21f4 /cc/util | |
parent | c97b0eca3cb8a56e4942606c8cef1c47684d7e41 (diff) | |
download | tink-8caaaca75a7317ddccc22723e50197040ffb2a92.tar.gz |
Fix DummyDecryptingRandomAccessStream and add unit tests
DummyDecryptingRandomAccessStream behaves incorrectly when reading all the content from the stream returns an EOF, and > 0 bytes are added to the buffer.
PiperOrigin-RevId: 503585078
Diffstat (limited to 'cc/util')
-rw-r--r-- | cc/util/BUILD.bazel | 8 | ||||
-rw-r--r-- | cc/util/CMakeLists.txt | 8 | ||||
-rw-r--r-- | cc/util/test_util.h | 57 | ||||
-rw-r--r-- | cc/util/test_util_test.cc | 181 |
4 files changed, 230 insertions, 24 deletions
diff --git a/cc/util/BUILD.bazel b/cc/util/BUILD.bazel index c95ca0c30..5b74a21e1 100644 --- a/cc/util/BUILD.bazel +++ b/cc/util/BUILD.bazel @@ -537,11 +537,19 @@ cc_test( name = "test_util_test", srcs = ["test_util_test.cc"], deps = [ + ":buffer", + ":ostream_output_stream", + ":statusor", ":test_matchers", ":test_util", + "//:output_stream", + "//:random_access_stream", + "//internal:test_random_access_stream", "//proto:aes_gcm_cc_proto", "//proto:tink_cc_proto", "//subtle", + "//subtle:test_util", + "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", ], ) diff --git a/cc/util/CMakeLists.txt b/cc/util/CMakeLists.txt index 8d64d7d93..a5f8adbd9 100644 --- a/cc/util/CMakeLists.txt +++ b/cc/util/CMakeLists.txt @@ -370,10 +370,18 @@ tink_cc_test( SRCS test_util_test.cc DEPS + tink::util::buffer + tink::util::ostream_output_stream + tink::util::statusor tink::util::test_matchers tink::util::test_util gmock + absl::strings + tink::core::output_stream + tink::core::random_access_stream + tink::internal::test_random_access_stream tink::subtle::subtle + tink::subtle::test_util tink::proto::aes_gcm_cc_proto tink::proto::tink_cc_proto ) diff --git a/cc/util/test_util.h b/cc/util/test_util.h index 5f97f5509..4e2657d86 100644 --- a/cc/util/test_util.h +++ b/cc/util/test_util.h @@ -495,40 +495,35 @@ class DummyStreamingAead : public StreamingAead { util::Status status_; }; // class DummyDecryptingStream - // Upon first call to PRead() tries to read from 'ct_source' a header - // that is expected to be equal to 'expected_header'. If this + // Upon first call to PRead() tries to read from `ct_source` a header + // that is expected to be equal to `expected_header`. If this // header matching succeeds, all subsequent method calls are forwarded - // to the corresponding methods of 'cd_source'. + // to `ct_source->PRead`. class DummyDecryptingRandomAccessStream : public crypto::tink::RandomAccessStream { public: DummyDecryptingRandomAccessStream( std::unique_ptr<crypto::tink::RandomAccessStream> ct_source, absl::string_view expected_header) - : ct_source_(std::move(ct_source)), - exp_header_(expected_header), - status_(util::Status(absl::StatusCode::kUnavailable, - "not initialized")) {} + : ct_source_(std::move(ct_source)), exp_header_(expected_header) {} crypto::tink::util::Status PRead( int64_t position, int count, crypto::tink::util::Buffer* dest_buffer) override { - { // Initialize, if not initialized yet. - absl::MutexLock lock(&status_mutex_); - if (status_.code() == absl::StatusCode::kUnavailable) Initialize(); - if (!status_.ok()) return status_; + util::Status status = CheckHeader(); + if (!status.ok()) { + return status; } - auto status = dest_buffer->set_size(0); + status = dest_buffer->set_size(0); if (!status.ok()) return status; return ct_source_->PRead(position + exp_header_.size(), count, dest_buffer); } util::StatusOr<int64_t> size() override { - { // Initialize, if not initialized yet. - absl::MutexLock lock(&status_mutex_); - if (status_.code() == absl::StatusCode::kUnavailable) Initialize(); - if (!status_.ok()) return status_; + util::Status status = CheckHeader(); + if (!status.ok()) { + return status; } auto ct_size_result = ct_source_->size(); if (!ct_size_result.ok()) return ct_size_result.status(); @@ -538,25 +533,39 @@ class DummyStreamingAead : public StreamingAead { } private: - void Initialize() ABSL_EXCLUSIVE_LOCKS_REQUIRED(status_mutex_) { + util::Status CheckHeader() + ABSL_LOCKS_EXCLUDED(header_check_status_mutex_) { + absl::MutexLock lock(&header_check_status_mutex_); + if (header_check_status_.code() != absl::StatusCode::kUnavailable) { + return header_check_status_; + } auto buf = std::move(util::Buffer::New(exp_header_.size()).value()); - status_ = ct_source_->PRead(0, exp_header_.size(), buf.get()); - if (!status_.ok() && status_.code() != absl::StatusCode::kOutOfRange) - return; + header_check_status_ = + ct_source_->PRead(0, exp_header_.size(), buf.get()); + if (!header_check_status_.ok() && + header_check_status_.code() != absl::StatusCode::kOutOfRange) { + return header_check_status_; + } + // EOF or Ok indicate a valid read has happened. + header_check_status_ = util::OkStatus(); + // Invalid header. if (buf->size() < exp_header_.size()) { - status_ = util::Status(absl::StatusCode::kInvalidArgument, + header_check_status_ = util::Status(absl::StatusCode::kInvalidArgument, "Could not read header"); } else if (memcmp(buf->get_mem_block(), exp_header_.data(), static_cast<int>(exp_header_.size()))) { - status_ = util::Status(absl::StatusCode::kInvalidArgument, + header_check_status_ = util::Status(absl::StatusCode::kInvalidArgument, "Corrupted header"); } + return header_check_status_; } std::unique_ptr<crypto::tink::RandomAccessStream> ct_source_; std::string exp_header_; - mutable absl::Mutex status_mutex_; - util::Status status_ ABSL_GUARDED_BY(status_mutex_); + mutable absl::Mutex header_check_status_mutex_; + util::Status header_check_status_ + ABSL_GUARDED_BY(header_check_status_mutex_) = + util::Status(absl::StatusCode::kUnavailable, "Uninitialized"); }; // class DummyDecryptingRandomAccessStream private: diff --git a/cc/util/test_util_test.cc b/cc/util/test_util_test.cc index 4b8fdf9e0..34171c334 100644 --- a/cc/util/test_util_test.cc +++ b/cc/util/test_util_test.cc @@ -15,9 +15,24 @@ /////////////////////////////////////////////////////////////////////////////// #include "tink/util/test_util.h" +#include <algorithm> +#include <cstdint> +#include <memory> +#include <sstream> +#include <string> +#include <utility> + #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/strings/string_view.h" +#include "tink/internal/test_random_access_stream.h" +#include "tink/output_stream.h" +#include "tink/random_access_stream.h" #include "tink/subtle/random.h" +#include "tink/subtle/test_util.h" +#include "tink/util/buffer.h" +#include "tink/util/ostream_output_stream.h" +#include "tink/util/statusor.h" #include "tink/util/test_matchers.h" #include "proto/aes_gcm.pb.h" #include "proto/tink.pb.h" @@ -27,6 +42,8 @@ namespace tink { namespace test { namespace { +using ::crypto::tink::internal::TestRandomAccessStream; +using ::crypto::tink::test::StatusIs; using ::google::crypto::tink::AesGcmKey; using ::google::crypto::tink::KeyData; using ::testing::Eq; @@ -107,6 +124,170 @@ TEST(ZTests, AutocorrelationUniformString) { IsOk()); } +TEST(DummyStreamingAead, DummyDecryptingStreamPreadAllAtOnceSucceeds) { + const int stream_size = 1024; + std::string stream_content = subtle::Random::GetRandomBytes(stream_size); + + auto ostream = std::make_unique<std::ostringstream>(); + auto string_stream_buffer = ostream->rdbuf(); + auto output_stream = + std::make_unique<util::OstreamOutputStream>(std::move(ostream)); + + DummyStreamingAead streaming_aead("Some AEAD"); + util::StatusOr<std::unique_ptr<OutputStream>> encrypting_output_stream = + streaming_aead.NewEncryptingStream(std::move(output_stream), "Some AAD"); + ASSERT_THAT(encrypting_output_stream.status(), IsOk()); + ASSERT_THAT(subtle::test::WriteToStream( + encrypting_output_stream.value().get(), stream_content), + IsOk()); + + std::string ciphertext = string_stream_buffer->str(); + auto test_random_access_stream = + std::make_unique<TestRandomAccessStream>(ciphertext); + util::StatusOr<std::unique_ptr<RandomAccessStream>> + decrypting_random_access_stream = + streaming_aead.NewDecryptingRandomAccessStream( + std::move(test_random_access_stream), "Some AAD"); + ASSERT_THAT(decrypting_random_access_stream.status(), IsOk()); + + auto buffer = util::Buffer::New(ciphertext.size()); + EXPECT_THAT((*decrypting_random_access_stream) + ->PRead(/*position=*/0, ciphertext.size(), buffer->get()), + StatusIs(absl::StatusCode::kOutOfRange)); + EXPECT_EQ(stream_content, + std::string((*buffer)->get_mem_block(), (*buffer)->size())); +} + +TEST(DummyStreamingAead, DummyDecryptingStreamPreadInChunksSucceeds) { + const int stream_size = 1024; + std::string stream_content = subtle::Random::GetRandomBytes(stream_size); + + auto ostream = std::make_unique<std::ostringstream>(); + auto string_stream_buffer = ostream->rdbuf(); + auto output_stream = + std::make_unique<util::OstreamOutputStream>(std::move(ostream)); + + DummyStreamingAead streaming_aead("Some AEAD"); + util::StatusOr<std::unique_ptr<OutputStream>> encrypting_output_stream = + streaming_aead.NewEncryptingStream(std::move(output_stream), "Some AAD"); + ASSERT_THAT(encrypting_output_stream.status(), IsOk()); + ASSERT_THAT(subtle::test::WriteToStream( + encrypting_output_stream.value().get(), stream_content), + IsOk()); + + std::string ciphertext = string_stream_buffer->str(); + auto test_random_access_stream = + std::make_unique<TestRandomAccessStream>(ciphertext); + util::StatusOr<std::unique_ptr<RandomAccessStream>> + decrypting_random_access_stream = + streaming_aead.NewDecryptingRandomAccessStream( + std::move(test_random_access_stream), "Some AAD"); + ASSERT_THAT(decrypting_random_access_stream.status(), IsOk()); + + int chunk_size = 10; + auto buffer = util::Buffer::New(chunk_size); + std::string plaintext; + int64_t position = 0; + util::Status status = (*decrypting_random_access_stream) + ->PRead(position, chunk_size, buffer->get()); + while (status.ok()) { + plaintext.append((*buffer)->get_mem_block(), (*buffer)->size()); + position += (*buffer)->size(); + status = (*decrypting_random_access_stream) + ->PRead(position, chunk_size, buffer->get()); + } + EXPECT_THAT(status, StatusIs(absl::StatusCode::kOutOfRange)); + plaintext.append((*buffer)->get_mem_block(), (*buffer)->size()); + EXPECT_EQ(stream_content, plaintext); +} + +TEST(DummyStreamingAead, DummyDecryptingStreamPreadWithSmallerHeaderFails) { + const int stream_size = 1024; + std::string stream_content = subtle::Random::GetRandomBytes(stream_size); + + auto ostream = std::make_unique<std::ostringstream>(); + auto output_stream = + std::make_unique<util::OstreamOutputStream>(std::move(ostream)); + + constexpr absl::string_view kStreamingAeadName = "Some AEAD"; + constexpr absl::string_view kStreamingAeadAad = "Some associated data"; + + DummyStreamingAead streaming_aead(kStreamingAeadName); + util::StatusOr<std::unique_ptr<OutputStream>> encrypting_output_stream = + streaming_aead.NewEncryptingStream(std::move(output_stream), + kStreamingAeadAad); + ASSERT_THAT(encrypting_output_stream.status(), IsOk()); + ASSERT_THAT(subtle::test::WriteToStream( + encrypting_output_stream.value().get(), stream_content), + IsOk()); + // Stream content size is too small; DummyDecryptingStream expects + // absl::StrCat(kStreamingAeadName, kStreamingAeadAad). + std::string ciphertext = "Invalid header"; + auto test_random_access_stream = + std::make_unique<TestRandomAccessStream>(ciphertext); + util::StatusOr<std::unique_ptr<RandomAccessStream>> + decrypting_random_access_stream = + streaming_aead.NewDecryptingRandomAccessStream( + std::move(test_random_access_stream), kStreamingAeadAad); + ASSERT_THAT(decrypting_random_access_stream.status(), IsOk()); + + int chunk_size = 10; + auto buffer = util::Buffer::New(chunk_size); + EXPECT_THAT( + (*decrypting_random_access_stream) + ->PRead(/*position=*/0, chunk_size, buffer->get()), + StatusIs(absl::StatusCode::kInvalidArgument, "Could not read header")); + EXPECT_THAT( + (*decrypting_random_access_stream) + ->PRead(/*position=*/0, chunk_size, buffer->get()), + StatusIs(absl::StatusCode::kInvalidArgument, "Could not read header")); + EXPECT_THAT( + (*decrypting_random_access_stream)->size().status(), + StatusIs(absl::StatusCode::kInvalidArgument, "Could not read header")); +} + +TEST(DummyStreamingAead, DummyDecryptingStreamPreadWithCorruptedAadFails) { + const int stream_size = 1024; + std::string stream_content = subtle::Random::GetRandomBytes(stream_size); + + auto ostream = std::make_unique<std::ostringstream>(); + auto string_stream_buffer = ostream->rdbuf(); + auto output_stream = + std::make_unique<util::OstreamOutputStream>(std::move(ostream)); + + constexpr absl::string_view kStreamingAeadName = "Some AEAD"; + constexpr absl::string_view kStreamingAeadAad = "Some associated data"; + + DummyStreamingAead streaming_aead(kStreamingAeadName); + util::StatusOr<std::unique_ptr<OutputStream>> encrypting_output_stream = + streaming_aead.NewEncryptingStream(std::move(output_stream), + kStreamingAeadAad); + ASSERT_THAT(encrypting_output_stream.status(), IsOk()); + ASSERT_THAT(subtle::test::WriteToStream( + encrypting_output_stream.value().get(), stream_content), + IsOk()); + // Invalid associated data. + std::string ciphertext = string_stream_buffer->str(); + auto test_random_access_stream = + std::make_unique<TestRandomAccessStream>(ciphertext); + util::StatusOr<std::unique_ptr<RandomAccessStream>> + decrypting_random_access_stream = + streaming_aead.NewDecryptingRandomAccessStream( + std::move(test_random_access_stream), "Some wrong AAD"); + ASSERT_THAT(decrypting_random_access_stream.status(), IsOk()); + + int chunk_size = 10; + auto buffer = util::Buffer::New(chunk_size); + EXPECT_THAT((*decrypting_random_access_stream) + ->PRead(/*position=*/0, chunk_size, buffer->get()), + StatusIs(absl::StatusCode::kInvalidArgument, "Corrupted header")); + EXPECT_THAT((*decrypting_random_access_stream) + ->PRead(/*position=*/0, chunk_size, buffer->get()), + StatusIs(absl::StatusCode::kInvalidArgument, "Corrupted header")); + EXPECT_THAT((*decrypting_random_access_stream)->size().status(), + StatusIs(absl::StatusCode::kInvalidArgument, "Corrupted header")); +} + } // namespace } // namespace test } // namespace tink |