diff options
author | btolsch <btolsch@chromium.org> | 2019-10-01 11:39:33 -0700 |
---|---|---|
committer | Commit Bot <commit-bot@chromium.org> | 2019-10-01 20:24:53 +0000 |
commit | 9dd4cf87c1214381e72be6b9e1c4af33417e8f99 (patch) | |
tree | 9769ab2c497ed03ff6dac98bec5bfc77434fd36e | |
parent | 5dc91624a6d9d1369769eab8b603399744cdd63c (diff) | |
download | openscreen-9dd4cf87c1214381e72be6b9e1c4af33417e8f99.tar.gz |
Add CastSocket and implementation
This change adds a CastSocket interface to handle sending and receiving
CastMessage structures along with an implementation that uses the
platform TlsConnection for transport.
Bug: openscreen:59
Change-Id: I92dc29b45efb2b1657a2ff6d962f1ed670311370
Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/1825511
Commit-Queue: Brandon Tolsch <btolsch@chromium.org>
Reviewed-by: Ryan Keane <rwkeane@google.com>
-rw-r--r-- | BUILD.gn | 1 | ||||
-rw-r--r-- | cast/common/channel/BUILD.gn | 40 | ||||
-rw-r--r-- | cast/common/channel/cast_socket.cc | 85 | ||||
-rw-r--r-- | cast/common/channel/cast_socket.h | 76 | ||||
-rw-r--r-- | cast/common/channel/cast_socket_unittest.cc | 195 | ||||
-rw-r--r-- | cast/common/channel/message_framer.cc | 71 | ||||
-rw-r--r-- | cast/common/channel/message_framer.h | 43 | ||||
-rw-r--r-- | cast/common/channel/message_framer_unittest.cc | 153 | ||||
-rw-r--r-- | cast/common/channel/proto/BUILD.gn (renamed from cast/sender/channel/proto/BUILD.gn) | 0 | ||||
-rw-r--r-- | cast/common/channel/proto/cast_channel.proto (renamed from cast/sender/channel/proto/cast_channel.proto) | 0 | ||||
-rw-r--r-- | cast/sender/channel/BUILD.gn | 10 | ||||
-rw-r--r-- | cast/sender/channel/cast_auth_util.h | 2 | ||||
-rw-r--r-- | cast/sender/channel/cast_auth_util_unittest.cc | 2 | ||||
-rw-r--r-- | cast/sender/channel/cast_framer.cc | 94 | ||||
-rw-r--r-- | cast/sender/channel/cast_framer.h | 59 | ||||
-rw-r--r-- | cast/sender/channel/cast_framer_unittest.cc | 179 |
16 files changed, 668 insertions, 342 deletions
@@ -50,6 +50,7 @@ executable("openscreen_unittests") { deps = [ "cast/common:unittests", "cast/common/certificate:unittests", + "cast/common/channel:unittests", "cast/sender/channel:unittests", "osp:unittests", "osp/impl/discovery/mdns:unittests", diff --git a/cast/common/channel/BUILD.gn b/cast/common/channel/BUILD.gn new file mode 100644 index 00000000..71b393b2 --- /dev/null +++ b/cast/common/channel/BUILD.gn @@ -0,0 +1,40 @@ +# Copyright 2019 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +source_set("channel") { + sources = [ + "cast_socket.cc", + "cast_socket.h", + "message_framer.cc", + "message_framer.h", + ] + + deps = [ + "../../../util", + "proto", + ] + + public_deps = [ + "../../../platform", + "../../../third_party/abseil", + ] +} + +source_set("unittests") { + testonly = true + sources = [ + "cast_socket_unittest.cc", + "message_framer_unittest.cc", + ] + + deps = [ + ":channel", + "../../../platform", + "../../../third_party/googletest:gmock", + "../../../third_party/googletest:gtest", + "../../../util", + "../../common/certificate/proto:unittest_proto", + "proto", + ] +} diff --git a/cast/common/channel/cast_socket.cc b/cast/common/channel/cast_socket.cc new file mode 100644 index 00000000..8ad61542 --- /dev/null +++ b/cast/common/channel/cast_socket.cc @@ -0,0 +1,85 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/common/channel/cast_socket.h" + +#include "cast/common/channel/message_framer.h" +#include "platform/api/logging.h" + +namespace cast { +namespace channel { + +using message_serialization::DeserializeResult; +using openscreen::ErrorOr; +using openscreen::platform::TlsConnection; + +CastSocket::CastSocket(std::unique_ptr<TlsConnection> connection, + Client* client, + uint32_t socket_id) + : client_(client), + connection_(std::move(connection)), + socket_id_(socket_id) { + OSP_DCHECK(client); + connection_->set_client(this); +} + +CastSocket::~CastSocket() = default; + +Error CastSocket::SendMessage(const CastMessage& message) { + if (state_ == State::kError) { + return Error::Code::kSocketClosedFailure; + } + + const ErrorOr<std::vector<uint8_t>> out = + message_serialization::Serialize(message); + if (!out) { + return out.error(); + } + + if (state_ == State::kBlocked) { + message_queue_.emplace_back(std::move(out.value())); + return Error::Code::kNone; + } + + connection_->Write(out.value().data(), out.value().size()); + return Error::Code::kNone; +} + +void CastSocket::OnWriteBlocked(TlsConnection* connection) { + if (state_ == State::kOpen) { + state_ = State::kBlocked; + } +} + +void CastSocket::OnWriteUnblocked(TlsConnection* connection) { + if (state_ == State::kBlocked) { + state_ = State::kOpen; + for (const auto& message : message_queue_) { + connection_->Write(message.data(), message.size()); + } + OSP_DCHECK(state_ == State::kOpen) << static_cast<int>(state_); + message_queue_.clear(); + } +} + +void CastSocket::OnError(TlsConnection* connection, Error error) { + state_ = State::kError; + client_->OnError(this, error); +} + +void CastSocket::OnRead(TlsConnection* connection, std::vector<uint8_t> block) { + read_buffer_.insert(read_buffer_.end(), block.begin(), block.end()); + ErrorOr<DeserializeResult> message_or_error = + message_serialization::TryDeserialize( + absl::Span<uint8_t>(&read_buffer_[0], read_buffer_.size())); + if (!message_or_error) { + return; + } + read_buffer_.erase(read_buffer_.begin(), + read_buffer_.begin() + message_or_error.value().length); + client_->OnMessage(this, std::move(message_or_error.value().message)); +} + +} // namespace channel +} // namespace cast diff --git a/cast/common/channel/cast_socket.h b/cast/common/channel/cast_socket.h new file mode 100644 index 00000000..a8fa3c48 --- /dev/null +++ b/cast/common/channel/cast_socket.h @@ -0,0 +1,76 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_COMMON_CHANNEL_CAST_SOCKET_H_ +#define CAST_COMMON_CHANNEL_CAST_SOCKET_H_ + +#include <vector> + +#include "platform/api/tls_connection.h" + +namespace cast { +namespace channel { + +using openscreen::Error; +using TlsConnection = openscreen::platform::TlsConnection; + +class CastMessage; + +// Represents a simple message-oriented socket for communicating with the Cast +// V2 protocol. It isn't thread-safe, so it should only be used on the same +// TaskRunner thread as its TlsConnection. +class CastSocket : public TlsConnection::Client { + public: + class Client { + public: + virtual ~Client() = default; + + // Called when a terminal error on |socket| has occurred. + virtual void OnError(CastSocket* socket, Error error) = 0; + + virtual void OnMessage(CastSocket* socket, CastMessage message) = 0; + }; + + CastSocket(std::unique_ptr<TlsConnection> connection, + Client* client, + uint32_t socket_id); + ~CastSocket(); + + // Sends |message| immediately unless the underlying TLS connection is + // write-blocked, in which case |message| will be queued. No error is + // returned for both queueing and successful sending. An error will be + // returned if |message| cannot be serialized for any reason. + Error SendMessage(const CastMessage& message); + + void set_client(Client* client) { + OSP_DCHECK(client); + client_ = client; + } + uint32_t socket_id() const { return socket_id_; } + + // TlsConnection::Client overrides. + void OnWriteBlocked(TlsConnection* connection) override; + void OnWriteUnblocked(TlsConnection* connection) override; + void OnError(TlsConnection* connection, Error error) override; + void OnRead(TlsConnection* connection, std::vector<uint8_t> block) override; + + private: + enum class State { + kOpen, + kBlocked, + kError, + }; + + Client* client_; + const std::unique_ptr<TlsConnection> connection_; + std::vector<uint8_t> read_buffer_; + const uint32_t socket_id_; + State state_ = State::kOpen; + std::vector<std::vector<uint8_t>> message_queue_; +}; + +} // namespace channel +} // namespace cast + +#endif // CAST_COMMON_CHANNEL_CAST_SOCKET_H_ diff --git a/cast/common/channel/cast_socket_unittest.cc b/cast/common/channel/cast_socket_unittest.cc new file mode 100644 index 00000000..c16ce4dc --- /dev/null +++ b/cast/common/channel/cast_socket_unittest.cc @@ -0,0 +1,195 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/common/channel/cast_socket.h" + +#include "cast/common/channel/message_framer.h" +#include "cast/common/channel/proto/cast_channel.pb.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "platform/test/fake_task_runner.h" + +namespace cast { +namespace channel { +namespace { + +using openscreen::ErrorOr; +using openscreen::IPEndpoint; +using openscreen::platform::FakeClock; +using openscreen::platform::FakeTaskRunner; +using openscreen::platform::TaskRunner; +using openscreen::platform::TlsConnection; + +using ::testing::_; +using ::testing::Invoke; + +class MockTlsConnection final : public TlsConnection { + public: + MockTlsConnection(TaskRunner* task_runner, + IPEndpoint local_address, + IPEndpoint remote_address) + : TlsConnection(task_runner), + local_address_(local_address), + remote_address_(remote_address) {} + + ~MockTlsConnection() override = default; + + MOCK_METHOD(void, Write, (const void* data, size_t len)); + + const IPEndpoint& local_address() const override { return local_address_; } + const IPEndpoint& remote_address() const override { return remote_address_; } + + void OnWriteBlocked() { TlsConnection::OnWriteBlocked(); } + void OnWriteUnblocked() { TlsConnection::OnWriteUnblocked(); } + void OnError(Error error) { TlsConnection::OnError(error); } + void OnRead(std::vector<uint8_t> block) { TlsConnection::OnRead(block); } + + private: + const IPEndpoint local_address_; + const IPEndpoint remote_address_; +}; + +class MockCastSocketClient final : public CastSocket::Client { + public: + ~MockCastSocketClient() override = default; + + MOCK_METHOD(void, OnError, (CastSocket * socket, Error error)); + MOCK_METHOD(void, OnMessage, (CastSocket * socket, CastMessage message)); +}; + +class CastSocketTest : public ::testing::Test { + public: + void SetUp() override { + message_.set_protocol_version(CastMessage::CASTV2_1_0); + message_.set_source_id("source"); + message_.set_destination_id("destination"); + message_.set_namespace_("namespace"); + message_.set_payload_type(CastMessage::STRING); + message_.set_payload_utf8("payload"); + ErrorOr<std::vector<uint8_t>> serialized_or_error = + message_serialization::Serialize(message_); + ASSERT_TRUE(serialized_or_error); + frame_serial_ = std::move(serialized_or_error.value()); + } + + protected: + FakeClock clock_{openscreen::platform::Clock::now()}; + FakeTaskRunner task_runner_{&clock_}; + IPEndpoint local_{{10, 0, 1, 7}, 1234}; + IPEndpoint remote_{{10, 0, 1, 9}, 4321}; + std::unique_ptr<MockTlsConnection> moved_connection_{ + new MockTlsConnection(&task_runner_, local_, remote_)}; + MockTlsConnection* connection_{moved_connection_.get()}; + MockCastSocketClient mock_client_; + CastSocket socket_{std::move(moved_connection_), &mock_client_, 1}; + CastMessage message_; + std::vector<uint8_t> frame_serial_; +}; + +} // namespace + +TEST_F(CastSocketTest, SendMessage) { + EXPECT_CALL(*connection_, Write(_, _)) + .WillOnce(Invoke([this](const void* data, size_t len) { + EXPECT_EQ( + frame_serial_, + std::vector<uint8_t>(reinterpret_cast<const uint8_t*>(data), + reinterpret_cast<const uint8_t*>(data) + len)); + })); + ASSERT_TRUE(socket_.SendMessage(message_).ok()); +} + +TEST_F(CastSocketTest, ReadCompleteMessage) { + const uint8_t* data = frame_serial_.data(); + EXPECT_CALL(mock_client_, OnMessage(_, _)) + .WillOnce(Invoke([this](CastSocket* socket, CastMessage message) { + EXPECT_EQ(message_.SerializeAsString(), message.SerializeAsString()); + })); + connection_->OnRead(std::vector<uint8_t>(data, data + frame_serial_.size())); + task_runner_.RunTasksUntilIdle(); +} + +TEST_F(CastSocketTest, ReadChunkedMessage) { + const uint8_t* data = frame_serial_.data(); + EXPECT_CALL(mock_client_, OnMessage(_, _)).Times(0); + connection_->OnRead(std::vector<uint8_t>(data, data + 10)); + task_runner_.RunTasksUntilIdle(); + + EXPECT_CALL(mock_client_, OnMessage(_, _)) + .WillOnce(Invoke([this](CastSocket* socket, CastMessage message) { + EXPECT_EQ(message_.SerializeAsString(), message.SerializeAsString()); + })); + connection_->OnRead( + std::vector<uint8_t>(data + 10, data + frame_serial_.size())); + task_runner_.RunTasksUntilIdle(); + + std::vector<uint8_t> double_message; + double_message.insert(double_message.end(), frame_serial_.begin(), + frame_serial_.end()); + double_message.insert(double_message.end(), frame_serial_.begin(), + frame_serial_.end()); + data = double_message.data(); + EXPECT_CALL(mock_client_, OnMessage(_, _)) + .WillOnce(Invoke([this](CastSocket* socket, CastMessage message) { + EXPECT_EQ(message_.SerializeAsString(), message.SerializeAsString()); + })); + connection_->OnRead( + std::vector<uint8_t>(data, data + frame_serial_.size() + 10)); + task_runner_.RunTasksUntilIdle(); + + EXPECT_CALL(mock_client_, OnMessage(_, _)) + .WillOnce(Invoke([this](CastSocket* socket, CastMessage message) { + EXPECT_EQ(message_.SerializeAsString(), message.SerializeAsString()); + })); + connection_->OnRead(std::vector<uint8_t>(data + frame_serial_.size() + 10, + data + double_message.size())); + task_runner_.RunTasksUntilIdle(); +} + +TEST_F(CastSocketTest, SendMessageWhileBlocked) { + connection_->OnWriteBlocked(); + task_runner_.RunTasksUntilIdle(); + EXPECT_CALL(*connection_, Write(_, _)).Times(0); + ASSERT_TRUE(socket_.SendMessage(message_).ok()); + + EXPECT_CALL(*connection_, Write(_, _)) + .WillOnce(Invoke([this](const void* data, size_t len) { + EXPECT_EQ( + frame_serial_, + std::vector<uint8_t>(reinterpret_cast<const uint8_t*>(data), + reinterpret_cast<const uint8_t*>(data) + len)); + })); + connection_->OnWriteUnblocked(); + task_runner_.RunTasksUntilIdle(); + + EXPECT_CALL(*connection_, Write(_, _)).Times(0); + connection_->OnWriteBlocked(); + task_runner_.RunTasksUntilIdle(); + connection_->OnWriteUnblocked(); + task_runner_.RunTasksUntilIdle(); +} + +TEST_F(CastSocketTest, ErrorWhileEmptyingQueue) { + connection_->OnWriteBlocked(); + task_runner_.RunTasksUntilIdle(); + EXPECT_CALL(*connection_, Write(_, _)).Times(0); + ASSERT_TRUE(socket_.SendMessage(message_).ok()); + + EXPECT_CALL(*connection_, Write(_, _)) + .WillOnce(Invoke([this](const void* data, size_t len) { + EXPECT_EQ( + frame_serial_, + std::vector<uint8_t>(reinterpret_cast<const uint8_t*>(data), + reinterpret_cast<const uint8_t*>(data) + len)); + connection_->OnError(Error::Code::kUnknownError); + })); + connection_->OnWriteUnblocked(); + task_runner_.RunTasksUntilIdle(); + + EXPECT_CALL(*connection_, Write(_, _)).Times(0); + ASSERT_FALSE(socket_.SendMessage(message_).ok()); +} + +} // namespace channel +} // namespace cast diff --git a/cast/common/channel/message_framer.cc b/cast/common/channel/message_framer.cc new file mode 100644 index 00000000..97a82466 --- /dev/null +++ b/cast/common/channel/message_framer.cc @@ -0,0 +1,71 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/common/channel/message_framer.h" + +#include <stdlib.h> +#include <string.h> + +#include <limits> + +#include "cast/common/channel/proto/cast_channel.pb.h" +#include "platform/api/logging.h" +#include "util/big_endian.h" + +namespace cast { +namespace channel { +namespace message_serialization { + +using openscreen::Error; + +namespace { + +static constexpr size_t kHeaderSize = sizeof(uint32_t); + +// Cast specifies a max message body size of 64 KiB. +static constexpr size_t kMaxBodySize = 65536; + +} // namespace + +ErrorOr<std::vector<uint8_t>> Serialize(const CastMessage& message) { + const size_t message_size = message.ByteSizeLong(); + if (message_size > kMaxBodySize || message_size == 0) { + return Error::Code::kCastV2InvalidMessage; + } + std::vector<uint8_t> out(message_size + kHeaderSize, 0); + openscreen::WriteBigEndian<uint32_t>(message_size, out.data()); + if (!message.SerializeToArray(&out[kHeaderSize], message_size)) { + return Error::Code::kCastV2InvalidMessage; + } + return out; +} + +ErrorOr<DeserializeResult> TryDeserialize(absl::Span<uint8_t> input) { + if (input.size() < kHeaderSize) { + return Error::Code::kInsufficientBuffer; + } + + const uint32_t message_size = + openscreen::ReadBigEndian<uint32_t>(input.data()); + if (message_size > kMaxBodySize) { + return Error::Code::kCastV2InvalidMessage; + } + + if (input.size() < (kHeaderSize + message_size)) { + return Error::Code::kInsufficientBuffer; + } + + DeserializeResult result; + if (!result.message.ParseFromArray(input.data() + kHeaderSize, + message_size)) { + return Error::Code::kCastV2InvalidMessage; + } + result.length = kHeaderSize + message_size; + + return result; +} + +} // namespace message_serialization +} // namespace channel +} // namespace cast diff --git a/cast/common/channel/message_framer.h b/cast/common/channel/message_framer.h new file mode 100644 index 00000000..c092487c --- /dev/null +++ b/cast/common/channel/message_framer.h @@ -0,0 +1,43 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_COMMON_CHANNEL_MESSAGE_FRAMER_H_ +#define CAST_COMMON_CHANNEL_MESSAGE_FRAMER_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <memory> +#include <vector> + +#include "absl/types/span.h" +#include "cast/common/channel/proto/cast_channel.pb.h" +#include "platform/base/error.h" + +namespace cast { +namespace channel { +namespace message_serialization { + +using openscreen::ErrorOr; + +// Serializes |message_proto| into |message_data|. +// Returns true if the message was serialized successfully, false otherwise. +ErrorOr<std::vector<uint8_t>> Serialize(const CastMessage& message); + +struct DeserializeResult { + CastMessage message; + size_t length; +}; + +// Reads bytes from |input| and returns a new CastMessage if one is fully +// read. Returns a parsed CastMessage if a message was received in its +// entirety, and an error otherwise. The result also contains the number of +// bytes consumed from |input| when a parse succeeds. +ErrorOr<DeserializeResult> TryDeserialize(absl::Span<uint8_t> input); + +} // namespace message_serialization +} // namespace channel +} // namespace cast + +#endif // CAST_COMMON_CHANNEL_MESSAGE_FRAMER_H_ diff --git a/cast/common/channel/message_framer_unittest.cc b/cast/common/channel/message_framer_unittest.cc new file mode 100644 index 00000000..ae2ff33e --- /dev/null +++ b/cast/common/channel/message_framer_unittest.cc @@ -0,0 +1,153 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/common/channel/message_framer.h" + +#include <stddef.h> + +#include <algorithm> +#include <string> + +#include "cast/common/channel/proto/cast_channel.pb.h" +#include "gtest/gtest.h" +#include "util/big_endian.h" +#include "util/std_util.h" + +namespace cast { +namespace channel { +namespace message_serialization { + +using openscreen::Error; + +namespace { + +static constexpr size_t kHeaderSize = sizeof(uint32_t); + +// Cast specifies a max message body size of 64 KiB. +static constexpr size_t kMaxBodySize = 65536; + +} // namespace + +class CastFramerTest : public testing::Test { + public: + CastFramerTest() : buffer_(kHeaderSize + kMaxBodySize) {} + + void SetUp() override { + cast_message_.set_protocol_version(CastMessage::CASTV2_1_0); + cast_message_.set_source_id("source"); + cast_message_.set_destination_id("destination"); + cast_message_.set_namespace_("namespace"); + cast_message_.set_payload_type(CastMessage::STRING); + cast_message_.set_payload_utf8("payload"); + ErrorOr<std::vector<uint8_t>> result = Serialize(cast_message_); + ASSERT_TRUE(result.is_value()); + cast_message_serial_ = std::move(result.value()); + } + + void WriteToBuffer(const std::vector<uint8_t>& data) { + memcpy(&buffer_[0], data.data(), data.size()); + } + + absl::Span<uint8_t> GetSpan(size_t size) { + return absl::Span<uint8_t>(&buffer_[0], size); + } + absl::Span<uint8_t> GetSpan() { return GetSpan(cast_message_serial_.size()); } + + protected: + CastMessage cast_message_; + std::vector<uint8_t> cast_message_serial_; + std::vector<uint8_t> buffer_; +}; + +TEST_F(CastFramerTest, TestMessageFramerCompleteMessage) { + WriteToBuffer(cast_message_serial_); + + // Receive 1 byte of the header, framer demands 3 more bytes. + ErrorOr<DeserializeResult> result = TryDeserialize(GetSpan(1)); + EXPECT_FALSE(result); + EXPECT_EQ(Error::Code::kInsufficientBuffer, result.error().code()); + + // TryDeserialize remaining 3, expect that the framer has moved on to + // requesting the body contents. + result = TryDeserialize(GetSpan(3)); + EXPECT_FALSE(result); + EXPECT_EQ(Error::Code::kInsufficientBuffer, result.error().code()); + + // Remainder of packet sent over the wire. + result = TryDeserialize(GetSpan()); + ASSERT_TRUE(result); + EXPECT_EQ(result.value().length, cast_message_serial_.size()); + const CastMessage& message = result.value().message; + EXPECT_EQ(message.SerializeAsString(), cast_message_.SerializeAsString()); +} + +TEST_F(CastFramerTest, TestSerializeErrorMessageTooLarge) { + CastMessage big_message; + big_message.CopyFrom(cast_message_); + std::string payload; + payload.append(kMaxBodySize + 1, 'x'); + big_message.set_payload_utf8(payload); + EXPECT_FALSE(Serialize(big_message)); +} + +TEST_F(CastFramerTest, TestCompleteMessageAtOnce) { + WriteToBuffer(cast_message_serial_); + + ErrorOr<DeserializeResult> result = TryDeserialize(GetSpan()); + ASSERT_TRUE(result); + EXPECT_EQ(result.value().length, cast_message_serial_.size()); + const CastMessage& message = result.value().message; + EXPECT_EQ(message.SerializeAsString(), cast_message_.SerializeAsString()); +} + +TEST_F(CastFramerTest, TestTryDeserializeIllegalLargeMessage) { + std::vector<uint8_t> mangled_cast_message = cast_message_serial_; + mangled_cast_message[0] = 88; + mangled_cast_message[1] = 88; + mangled_cast_message[2] = 88; + mangled_cast_message[3] = 88; + WriteToBuffer(mangled_cast_message); + + ErrorOr<DeserializeResult> result = TryDeserialize(GetSpan(4)); + ASSERT_FALSE(result); + EXPECT_EQ(Error::Code::kCastV2InvalidMessage, result.error().code()); +} + +TEST_F(CastFramerTest, TestTryDeserializeIllegalLargeMessage2) { + std::vector<uint8_t> mangled_cast_message = cast_message_serial_; + // Header indicates body size is 0x00010001 = 65537 + mangled_cast_message[0] = 0; + mangled_cast_message[1] = 0x1; + mangled_cast_message[2] = 0; + mangled_cast_message[3] = 0x1; + WriteToBuffer(mangled_cast_message); + + ErrorOr<DeserializeResult> result = TryDeserialize(GetSpan(4)); + ASSERT_FALSE(result); + EXPECT_EQ(Error::Code::kCastV2InvalidMessage, result.error().code()); +} + +TEST_F(CastFramerTest, TestUnparsableBodyProto) { + // Message header is OK, but the body is replaced with "x"es. + std::vector<uint8_t> mangled_cast_message = cast_message_serial_; + for (size_t i = kHeaderSize; i < mangled_cast_message.size(); ++i) { + std::fill(mangled_cast_message.begin() + kHeaderSize, + mangled_cast_message.end(), 'x'); + } + WriteToBuffer(mangled_cast_message); + + // Send header. + ErrorOr<DeserializeResult> result = TryDeserialize(GetSpan(4)); + EXPECT_FALSE(result); + EXPECT_EQ(Error::Code::kInsufficientBuffer, result.error().code()); + + // Send body, expect an error. + result = TryDeserialize(GetSpan()); + ASSERT_FALSE(result); + EXPECT_EQ(Error::Code::kCastV2InvalidMessage, result.error().code()); +} + +} // namespace message_serialization +} // namespace channel +} // namespace cast diff --git a/cast/sender/channel/proto/BUILD.gn b/cast/common/channel/proto/BUILD.gn index c3bfa439..c3bfa439 100644 --- a/cast/sender/channel/proto/BUILD.gn +++ b/cast/common/channel/proto/BUILD.gn diff --git a/cast/sender/channel/proto/cast_channel.proto b/cast/common/channel/proto/cast_channel.proto index 57c7b3f3..57c7b3f3 100644 --- a/cast/sender/channel/proto/cast_channel.proto +++ b/cast/common/channel/proto/cast_channel.proto diff --git a/cast/sender/channel/BUILD.gn b/cast/sender/channel/BUILD.gn index 13b061c1..85f633b0 100644 --- a/cast/sender/channel/BUILD.gn +++ b/cast/sender/channel/BUILD.gn @@ -6,18 +6,14 @@ source_set("channel") { sources = [ "cast_auth_util.cc", "cast_auth_util.h", - "cast_framer.cc", - "cast_framer.h", ] deps = [ - "../../../util", - "proto", + "../../common/channel/proto", ] public_deps = [ "../../../platform", - "../../../third_party/abseil", ] } @@ -25,15 +21,13 @@ source_set("unittests") { testonly = true sources = [ "cast_auth_util_unittest.cc", - "cast_framer_unittest.cc", ] deps = [ ":channel", "../../../platform", "../../../third_party/googletest:gtest", - "../../../util", "../../common/certificate/proto:unittest_proto", - "proto", + "../../common/channel/proto", ] } diff --git a/cast/sender/channel/cast_auth_util.h b/cast/sender/channel/cast_auth_util.h index b4a81e0b..35f1d028 100644 --- a/cast/sender/channel/cast_auth_util.h +++ b/cast/sender/channel/cast_auth_util.h @@ -10,7 +10,7 @@ #include <string> #include "cast/common/certificate/cast_cert_validator.h" -#include "cast/sender/channel/proto/cast_channel.pb.h" +#include "cast/common/channel/proto/cast_channel.pb.h" #include "platform/base/error.h" namespace cast { diff --git a/cast/sender/channel/cast_auth_util_unittest.cc b/cast/sender/channel/cast_auth_util_unittest.cc index 10819362..61d49aa1 100644 --- a/cast/sender/channel/cast_auth_util_unittest.cc +++ b/cast/sender/channel/cast_auth_util_unittest.cc @@ -10,7 +10,7 @@ #include "cast/common/certificate/cast_crl.h" #include "cast/common/certificate/proto/test_suite.pb.h" #include "cast/common/certificate/test_helpers.h" -#include "cast/sender/channel/proto/cast_channel.pb.h" +#include "cast/common/channel/proto/cast_channel.pb.h" #include "gtest/gtest.h" #include "platform/api/logging.h" #include "platform/api/time.h" diff --git a/cast/sender/channel/cast_framer.cc b/cast/sender/channel/cast_framer.cc deleted file mode 100644 index b5a8acb1..00000000 --- a/cast/sender/channel/cast_framer.cc +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright 2019 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "cast/sender/channel/cast_framer.h" - -#include <stdlib.h> -#include <string.h> - -#include <limits> - -#include "cast/sender/channel/proto/cast_channel.pb.h" -#include "platform/api/logging.h" -#include "util/big_endian.h" -#include "util/std_util.h" - -namespace cast { -namespace channel { - -using ChannelError = openscreen::Error::Code; - -namespace { - -static constexpr size_t kHeaderSize = sizeof(uint32_t); - -// Cast specifies a max message body size of 64 KiB. -static constexpr size_t kMaxBodySize = 65536; - -} // namespace - -MessageFramer::MessageFramer(absl::Span<uint8_t> input_buffer) - : input_buffer_(input_buffer) {} - -MessageFramer::~MessageFramer() = default; - -// static -ErrorOr<std::string> MessageFramer::Serialize(const CastMessage& message) { - const size_t message_size = message.ByteSizeLong(); - if (message_size > kMaxBodySize || message_size == 0) { - return ChannelError::kCastV2InvalidMessage; - } - std::string out(message_size + kHeaderSize, 0); - openscreen::WriteBigEndian<uint32_t>(message_size, openscreen::data(out)); - if (!message.SerializeToArray(&out[kHeaderSize], message_size)) { - return ChannelError::kCastV2InvalidMessage; - } - return out; -} - -ErrorOr<size_t> MessageFramer::BytesRequested() const { - if (message_bytes_received_ < kHeaderSize) { - return kHeaderSize - message_bytes_received_; - } - - const uint32_t message_size = - openscreen::ReadBigEndian<uint32_t>(input_buffer_.data()); - if (message_size > kMaxBodySize) { - return ChannelError::kCastV2InvalidMessage; - } - return (kHeaderSize + message_size) - message_bytes_received_; -} - -ErrorOr<CastMessage> MessageFramer::TryDeserialize(size_t byte_count) { - message_bytes_received_ += byte_count; - if (message_bytes_received_ > input_buffer_.size()) { - return ChannelError::kCastV2InvalidMessage; - } - - if (message_bytes_received_ < kHeaderSize) { - return ChannelError::kInsufficientBuffer; - } - - const uint32_t message_size = - openscreen::ReadBigEndian<uint32_t>(input_buffer_.data()); - if (message_size > kMaxBodySize) { - return ChannelError::kCastV2InvalidMessage; - } - - if (message_bytes_received_ < (kHeaderSize + message_size)) { - return ChannelError::kInsufficientBuffer; - } - - CastMessage parsed_message; - if (!parsed_message.ParseFromArray(input_buffer_.data() + kHeaderSize, - message_size)) { - return ChannelError::kCastV2InvalidMessage; - } - - message_bytes_received_ = 0; - return parsed_message; -} - -} // namespace channel -} // namespace cast diff --git a/cast/sender/channel/cast_framer.h b/cast/sender/channel/cast_framer.h deleted file mode 100644 index 8fbabfd3..00000000 --- a/cast/sender/channel/cast_framer.h +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2019 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef CAST_SENDER_CHANNEL_CAST_FRAMER_H_ -#define CAST_SENDER_CHANNEL_CAST_FRAMER_H_ - -#include <stddef.h> -#include <stdint.h> - -#include <memory> -#include <string> - -#include "absl/types/span.h" -#include "platform/base/error.h" - -namespace cast { -namespace channel { - -class CastMessage; - -using openscreen::ErrorOr; - -// Class for constructing and parsing CastMessage packet data. -class MessageFramer { - public: - // Serializes |message_proto| into |message_data|. - // Returns true if the message was serialized successfully, false otherwise. - static ErrorOr<std::string> Serialize(const CastMessage& message); - - explicit MessageFramer(absl::Span<uint8_t> input_buffer); - ~MessageFramer(); - - // The number of bytes required from the next |input_buffer| passed to - // TryDeserialize to complete the CastMessage being read. Returns zero if - // there has been a parsing error. - ErrorOr<size_t> BytesRequested() const; - - // Reads bytes from |input_buffer_| and returns a new CastMessage if one is - // fully read. - // - // |byte_count| Number of additional bytes available in |input_buffer_|. - // Returns a pointer to a parsed CastMessage if a message was received in its - // entirety, empty unique_ptr if parsing was successful but didn't produce a - // complete message, and an error otherwise. - ErrorOr<CastMessage> TryDeserialize(size_t byte_count); - - private: - // Total size of the message received so far in bytes (head + body). - size_t message_bytes_received_ = 0; - - // Data buffer wherein the caller should place message data for ingest. - absl::Span<uint8_t> input_buffer_; -}; - -} // namespace channel -} // namespace cast - -#endif // CAST_SENDER_CHANNEL_CAST_FRAMER_H_ diff --git a/cast/sender/channel/cast_framer_unittest.cc b/cast/sender/channel/cast_framer_unittest.cc deleted file mode 100644 index 0cd10c71..00000000 --- a/cast/sender/channel/cast_framer_unittest.cc +++ /dev/null @@ -1,179 +0,0 @@ -// Copyright 2019 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "cast/sender/channel/cast_framer.h" - -#include <stddef.h> - -#include <algorithm> -#include <string> - -#include "cast/sender/channel/proto/cast_channel.pb.h" -#include "gtest/gtest.h" -#include "util/big_endian.h" -#include "util/std_util.h" - -namespace cast { -namespace channel { - -using ChannelError = openscreen::Error::Code; - -namespace { - -static constexpr size_t kHeaderSize = sizeof(uint32_t); - -// Cast specifies a max message body size of 64 KiB. -static constexpr size_t kMaxBodySize = 65536; - -} // namespace - -class CastFramerTest : public testing::Test { - public: - CastFramerTest() - : buffer_(kHeaderSize + kMaxBodySize), - framer_(absl::Span<uint8_t>(&buffer_[0], buffer_.size())) {} - - void SetUp() override { - cast_message_.set_protocol_version(CastMessage::CASTV2_1_0); - cast_message_.set_source_id("source"); - cast_message_.set_destination_id("destination"); - cast_message_.set_namespace_("namespace"); - cast_message_.set_payload_type(CastMessage::STRING); - cast_message_.set_payload_utf8("payload"); - ErrorOr<std::string> result = MessageFramer::Serialize(cast_message_); - ASSERT_TRUE(result.is_value()); - cast_message_str_ = std::move(result.value()); - } - - void WriteToBuffer(const std::string& data) { - memcpy(&buffer_[0], data.data(), data.size()); - } - - protected: - CastMessage cast_message_; - std::string cast_message_str_; - std::vector<uint8_t> buffer_; - MessageFramer framer_; -}; - -TEST_F(CastFramerTest, TestMessageFramerCompleteMessage) { - WriteToBuffer(cast_message_str_); - - // Receive 1 byte of the header, framer demands 3 more bytes. - EXPECT_EQ(4u, framer_.BytesRequested().value()); - ErrorOr<CastMessage> result = framer_.TryDeserialize(1); - EXPECT_FALSE(result); - EXPECT_EQ(ChannelError::kInsufficientBuffer, result.error().code()); - EXPECT_EQ(3u, framer_.BytesRequested().value()); - - // TryDeserialize remaining 3, expect that the framer has moved on to - // requesting the body contents. - result = framer_.TryDeserialize(3); - EXPECT_FALSE(result); - EXPECT_EQ(ChannelError::kInsufficientBuffer, result.error().code()); - EXPECT_EQ(cast_message_str_.size() - kHeaderSize, - framer_.BytesRequested().value()); - - // Remainder of packet sent over the wire. - result = framer_.TryDeserialize(framer_.BytesRequested().value()); - ASSERT_TRUE(result); - const CastMessage& message = result.value(); - EXPECT_EQ(message.SerializeAsString(), cast_message_.SerializeAsString()); - EXPECT_EQ(4u, framer_.BytesRequested().value()); -} - -TEST_F(CastFramerTest, BigEndianMessageHeader) { - WriteToBuffer(cast_message_str_); - - EXPECT_EQ(4u, framer_.BytesRequested().value()); - ErrorOr<CastMessage> result = framer_.TryDeserialize(4); - EXPECT_FALSE(result); - EXPECT_EQ(ChannelError::kInsufficientBuffer, result.error().code()); - - const uint32_t expected_size = - openscreen::ReadBigEndian<uint32_t>(openscreen::data(cast_message_str_)); - EXPECT_EQ(expected_size, framer_.BytesRequested().value()); -} - -TEST_F(CastFramerTest, TestSerializeErrorMessageTooLarge) { - CastMessage big_message; - big_message.CopyFrom(cast_message_); - std::string payload; - payload.append(kMaxBodySize + 1, 'x'); - big_message.set_payload_utf8(payload); - EXPECT_FALSE(MessageFramer::Serialize(big_message)); -} - -TEST_F(CastFramerTest, TestCompleteMessageAtOnce) { - WriteToBuffer(cast_message_str_); - - ErrorOr<CastMessage> result = - framer_.TryDeserialize(cast_message_str_.size()); - ASSERT_TRUE(result); - const CastMessage& message = result.value(); - EXPECT_EQ(message.SerializeAsString(), cast_message_.SerializeAsString()); - EXPECT_EQ(4u, framer_.BytesRequested().value()); -} - -TEST_F(CastFramerTest, TestTryDeserializeIllegalLargeMessage) { - std::string mangled_cast_message = cast_message_str_; - mangled_cast_message[0] = 88; - mangled_cast_message[1] = 88; - mangled_cast_message[2] = 88; - mangled_cast_message[3] = 88; - WriteToBuffer(mangled_cast_message); - - EXPECT_EQ(4u, framer_.BytesRequested().value()); - ErrorOr<CastMessage> result = framer_.TryDeserialize(4); - ASSERT_FALSE(result); - EXPECT_EQ(ChannelError::kCastV2InvalidMessage, result.error().code()); - ErrorOr<size_t> bytes_requested = framer_.BytesRequested(); - ASSERT_FALSE(bytes_requested); - EXPECT_EQ(ChannelError::kCastV2InvalidMessage, - bytes_requested.error().code()); -} - -TEST_F(CastFramerTest, TestTryDeserializeIllegalLargeMessage2) { - std::string mangled_cast_message = cast_message_str_; - // Header indicates body size is 0x00010001 = 65537 - mangled_cast_message[0] = 0; - mangled_cast_message[1] = 0x1; - mangled_cast_message[2] = 0; - mangled_cast_message[3] = 0x1; - WriteToBuffer(mangled_cast_message); - - EXPECT_EQ(4u, framer_.BytesRequested().value()); - ErrorOr<CastMessage> result = framer_.TryDeserialize(4); - ASSERT_FALSE(result); - EXPECT_EQ(ChannelError::kCastV2InvalidMessage, result.error().code()); - ErrorOr<size_t> bytes_requested = framer_.BytesRequested(); - ASSERT_FALSE(bytes_requested); - EXPECT_EQ(ChannelError::kCastV2InvalidMessage, - bytes_requested.error().code()); -} - -TEST_F(CastFramerTest, TestUnparsableBodyProto) { - // Message header is OK, but the body is replaced with "x"es. - std::string mangled_cast_message = cast_message_str_; - for (size_t i = kHeaderSize; i < mangled_cast_message.size(); ++i) { - std::fill(mangled_cast_message.begin() + kHeaderSize, - mangled_cast_message.end(), 'x'); - } - WriteToBuffer(mangled_cast_message); - - // Send header. - EXPECT_EQ(4u, framer_.BytesRequested().value()); - ErrorOr<CastMessage> result = framer_.TryDeserialize(4); - EXPECT_FALSE(result); - EXPECT_EQ(ChannelError::kInsufficientBuffer, result.error().code()); - EXPECT_EQ(cast_message_str_.size() - 4, framer_.BytesRequested().value()); - - // Send body, expect an error. - result = framer_.TryDeserialize(framer_.BytesRequested().value()); - ASSERT_FALSE(result); - EXPECT_EQ(ChannelError::kCastV2InvalidMessage, result.error().code()); -} - -} // namespace channel -} // namespace cast |