aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorbtolsch <btolsch@chromium.org>2019-10-01 11:39:33 -0700
committerCommit Bot <commit-bot@chromium.org>2019-10-01 20:24:53 +0000
commit9dd4cf87c1214381e72be6b9e1c4af33417e8f99 (patch)
tree9769ab2c497ed03ff6dac98bec5bfc77434fd36e
parent5dc91624a6d9d1369769eab8b603399744cdd63c (diff)
downloadopenscreen-9dd4cf87c1214381e72be6b9e1c4af33417e8f99.tar.gz
Add CastSocket and implementation
This change adds a CastSocket interface to handle sending and receiving CastMessage structures along with an implementation that uses the platform TlsConnection for transport. Bug: openscreen:59 Change-Id: I92dc29b45efb2b1657a2ff6d962f1ed670311370 Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/1825511 Commit-Queue: Brandon Tolsch <btolsch@chromium.org> Reviewed-by: Ryan Keane <rwkeane@google.com>
-rw-r--r--BUILD.gn1
-rw-r--r--cast/common/channel/BUILD.gn40
-rw-r--r--cast/common/channel/cast_socket.cc85
-rw-r--r--cast/common/channel/cast_socket.h76
-rw-r--r--cast/common/channel/cast_socket_unittest.cc195
-rw-r--r--cast/common/channel/message_framer.cc71
-rw-r--r--cast/common/channel/message_framer.h43
-rw-r--r--cast/common/channel/message_framer_unittest.cc153
-rw-r--r--cast/common/channel/proto/BUILD.gn (renamed from cast/sender/channel/proto/BUILD.gn)0
-rw-r--r--cast/common/channel/proto/cast_channel.proto (renamed from cast/sender/channel/proto/cast_channel.proto)0
-rw-r--r--cast/sender/channel/BUILD.gn10
-rw-r--r--cast/sender/channel/cast_auth_util.h2
-rw-r--r--cast/sender/channel/cast_auth_util_unittest.cc2
-rw-r--r--cast/sender/channel/cast_framer.cc94
-rw-r--r--cast/sender/channel/cast_framer.h59
-rw-r--r--cast/sender/channel/cast_framer_unittest.cc179
16 files changed, 668 insertions, 342 deletions
diff --git a/BUILD.gn b/BUILD.gn
index 82c2105c..4940cfc6 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -50,6 +50,7 @@ executable("openscreen_unittests") {
deps = [
"cast/common:unittests",
"cast/common/certificate:unittests",
+ "cast/common/channel:unittests",
"cast/sender/channel:unittests",
"osp:unittests",
"osp/impl/discovery/mdns:unittests",
diff --git a/cast/common/channel/BUILD.gn b/cast/common/channel/BUILD.gn
new file mode 100644
index 00000000..71b393b2
--- /dev/null
+++ b/cast/common/channel/BUILD.gn
@@ -0,0 +1,40 @@
+# Copyright 2019 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+source_set("channel") {
+ sources = [
+ "cast_socket.cc",
+ "cast_socket.h",
+ "message_framer.cc",
+ "message_framer.h",
+ ]
+
+ deps = [
+ "../../../util",
+ "proto",
+ ]
+
+ public_deps = [
+ "../../../platform",
+ "../../../third_party/abseil",
+ ]
+}
+
+source_set("unittests") {
+ testonly = true
+ sources = [
+ "cast_socket_unittest.cc",
+ "message_framer_unittest.cc",
+ ]
+
+ deps = [
+ ":channel",
+ "../../../platform",
+ "../../../third_party/googletest:gmock",
+ "../../../third_party/googletest:gtest",
+ "../../../util",
+ "../../common/certificate/proto:unittest_proto",
+ "proto",
+ ]
+}
diff --git a/cast/common/channel/cast_socket.cc b/cast/common/channel/cast_socket.cc
new file mode 100644
index 00000000..8ad61542
--- /dev/null
+++ b/cast/common/channel/cast_socket.cc
@@ -0,0 +1,85 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/common/channel/cast_socket.h"
+
+#include "cast/common/channel/message_framer.h"
+#include "platform/api/logging.h"
+
+namespace cast {
+namespace channel {
+
+using message_serialization::DeserializeResult;
+using openscreen::ErrorOr;
+using openscreen::platform::TlsConnection;
+
+CastSocket::CastSocket(std::unique_ptr<TlsConnection> connection,
+ Client* client,
+ uint32_t socket_id)
+ : client_(client),
+ connection_(std::move(connection)),
+ socket_id_(socket_id) {
+ OSP_DCHECK(client);
+ connection_->set_client(this);
+}
+
+CastSocket::~CastSocket() = default;
+
+Error CastSocket::SendMessage(const CastMessage& message) {
+ if (state_ == State::kError) {
+ return Error::Code::kSocketClosedFailure;
+ }
+
+ const ErrorOr<std::vector<uint8_t>> out =
+ message_serialization::Serialize(message);
+ if (!out) {
+ return out.error();
+ }
+
+ if (state_ == State::kBlocked) {
+ message_queue_.emplace_back(std::move(out.value()));
+ return Error::Code::kNone;
+ }
+
+ connection_->Write(out.value().data(), out.value().size());
+ return Error::Code::kNone;
+}
+
+void CastSocket::OnWriteBlocked(TlsConnection* connection) {
+ if (state_ == State::kOpen) {
+ state_ = State::kBlocked;
+ }
+}
+
+void CastSocket::OnWriteUnblocked(TlsConnection* connection) {
+ if (state_ == State::kBlocked) {
+ state_ = State::kOpen;
+ for (const auto& message : message_queue_) {
+ connection_->Write(message.data(), message.size());
+ }
+ OSP_DCHECK(state_ == State::kOpen) << static_cast<int>(state_);
+ message_queue_.clear();
+ }
+}
+
+void CastSocket::OnError(TlsConnection* connection, Error error) {
+ state_ = State::kError;
+ client_->OnError(this, error);
+}
+
+void CastSocket::OnRead(TlsConnection* connection, std::vector<uint8_t> block) {
+ read_buffer_.insert(read_buffer_.end(), block.begin(), block.end());
+ ErrorOr<DeserializeResult> message_or_error =
+ message_serialization::TryDeserialize(
+ absl::Span<uint8_t>(&read_buffer_[0], read_buffer_.size()));
+ if (!message_or_error) {
+ return;
+ }
+ read_buffer_.erase(read_buffer_.begin(),
+ read_buffer_.begin() + message_or_error.value().length);
+ client_->OnMessage(this, std::move(message_or_error.value().message));
+}
+
+} // namespace channel
+} // namespace cast
diff --git a/cast/common/channel/cast_socket.h b/cast/common/channel/cast_socket.h
new file mode 100644
index 00000000..a8fa3c48
--- /dev/null
+++ b/cast/common/channel/cast_socket.h
@@ -0,0 +1,76 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_COMMON_CHANNEL_CAST_SOCKET_H_
+#define CAST_COMMON_CHANNEL_CAST_SOCKET_H_
+
+#include <vector>
+
+#include "platform/api/tls_connection.h"
+
+namespace cast {
+namespace channel {
+
+using openscreen::Error;
+using TlsConnection = openscreen::platform::TlsConnection;
+
+class CastMessage;
+
+// Represents a simple message-oriented socket for communicating with the Cast
+// V2 protocol. It isn't thread-safe, so it should only be used on the same
+// TaskRunner thread as its TlsConnection.
+class CastSocket : public TlsConnection::Client {
+ public:
+ class Client {
+ public:
+ virtual ~Client() = default;
+
+ // Called when a terminal error on |socket| has occurred.
+ virtual void OnError(CastSocket* socket, Error error) = 0;
+
+ virtual void OnMessage(CastSocket* socket, CastMessage message) = 0;
+ };
+
+ CastSocket(std::unique_ptr<TlsConnection> connection,
+ Client* client,
+ uint32_t socket_id);
+ ~CastSocket();
+
+ // Sends |message| immediately unless the underlying TLS connection is
+ // write-blocked, in which case |message| will be queued. No error is
+ // returned for both queueing and successful sending. An error will be
+ // returned if |message| cannot be serialized for any reason.
+ Error SendMessage(const CastMessage& message);
+
+ void set_client(Client* client) {
+ OSP_DCHECK(client);
+ client_ = client;
+ }
+ uint32_t socket_id() const { return socket_id_; }
+
+ // TlsConnection::Client overrides.
+ void OnWriteBlocked(TlsConnection* connection) override;
+ void OnWriteUnblocked(TlsConnection* connection) override;
+ void OnError(TlsConnection* connection, Error error) override;
+ void OnRead(TlsConnection* connection, std::vector<uint8_t> block) override;
+
+ private:
+ enum class State {
+ kOpen,
+ kBlocked,
+ kError,
+ };
+
+ Client* client_;
+ const std::unique_ptr<TlsConnection> connection_;
+ std::vector<uint8_t> read_buffer_;
+ const uint32_t socket_id_;
+ State state_ = State::kOpen;
+ std::vector<std::vector<uint8_t>> message_queue_;
+};
+
+} // namespace channel
+} // namespace cast
+
+#endif // CAST_COMMON_CHANNEL_CAST_SOCKET_H_
diff --git a/cast/common/channel/cast_socket_unittest.cc b/cast/common/channel/cast_socket_unittest.cc
new file mode 100644
index 00000000..c16ce4dc
--- /dev/null
+++ b/cast/common/channel/cast_socket_unittest.cc
@@ -0,0 +1,195 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/common/channel/cast_socket.h"
+
+#include "cast/common/channel/message_framer.h"
+#include "cast/common/channel/proto/cast_channel.pb.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "platform/test/fake_task_runner.h"
+
+namespace cast {
+namespace channel {
+namespace {
+
+using openscreen::ErrorOr;
+using openscreen::IPEndpoint;
+using openscreen::platform::FakeClock;
+using openscreen::platform::FakeTaskRunner;
+using openscreen::platform::TaskRunner;
+using openscreen::platform::TlsConnection;
+
+using ::testing::_;
+using ::testing::Invoke;
+
+class MockTlsConnection final : public TlsConnection {
+ public:
+ MockTlsConnection(TaskRunner* task_runner,
+ IPEndpoint local_address,
+ IPEndpoint remote_address)
+ : TlsConnection(task_runner),
+ local_address_(local_address),
+ remote_address_(remote_address) {}
+
+ ~MockTlsConnection() override = default;
+
+ MOCK_METHOD(void, Write, (const void* data, size_t len));
+
+ const IPEndpoint& local_address() const override { return local_address_; }
+ const IPEndpoint& remote_address() const override { return remote_address_; }
+
+ void OnWriteBlocked() { TlsConnection::OnWriteBlocked(); }
+ void OnWriteUnblocked() { TlsConnection::OnWriteUnblocked(); }
+ void OnError(Error error) { TlsConnection::OnError(error); }
+ void OnRead(std::vector<uint8_t> block) { TlsConnection::OnRead(block); }
+
+ private:
+ const IPEndpoint local_address_;
+ const IPEndpoint remote_address_;
+};
+
+class MockCastSocketClient final : public CastSocket::Client {
+ public:
+ ~MockCastSocketClient() override = default;
+
+ MOCK_METHOD(void, OnError, (CastSocket * socket, Error error));
+ MOCK_METHOD(void, OnMessage, (CastSocket * socket, CastMessage message));
+};
+
+class CastSocketTest : public ::testing::Test {
+ public:
+ void SetUp() override {
+ message_.set_protocol_version(CastMessage::CASTV2_1_0);
+ message_.set_source_id("source");
+ message_.set_destination_id("destination");
+ message_.set_namespace_("namespace");
+ message_.set_payload_type(CastMessage::STRING);
+ message_.set_payload_utf8("payload");
+ ErrorOr<std::vector<uint8_t>> serialized_or_error =
+ message_serialization::Serialize(message_);
+ ASSERT_TRUE(serialized_or_error);
+ frame_serial_ = std::move(serialized_or_error.value());
+ }
+
+ protected:
+ FakeClock clock_{openscreen::platform::Clock::now()};
+ FakeTaskRunner task_runner_{&clock_};
+ IPEndpoint local_{{10, 0, 1, 7}, 1234};
+ IPEndpoint remote_{{10, 0, 1, 9}, 4321};
+ std::unique_ptr<MockTlsConnection> moved_connection_{
+ new MockTlsConnection(&task_runner_, local_, remote_)};
+ MockTlsConnection* connection_{moved_connection_.get()};
+ MockCastSocketClient mock_client_;
+ CastSocket socket_{std::move(moved_connection_), &mock_client_, 1};
+ CastMessage message_;
+ std::vector<uint8_t> frame_serial_;
+};
+
+} // namespace
+
+TEST_F(CastSocketTest, SendMessage) {
+ EXPECT_CALL(*connection_, Write(_, _))
+ .WillOnce(Invoke([this](const void* data, size_t len) {
+ EXPECT_EQ(
+ frame_serial_,
+ std::vector<uint8_t>(reinterpret_cast<const uint8_t*>(data),
+ reinterpret_cast<const uint8_t*>(data) + len));
+ }));
+ ASSERT_TRUE(socket_.SendMessage(message_).ok());
+}
+
+TEST_F(CastSocketTest, ReadCompleteMessage) {
+ const uint8_t* data = frame_serial_.data();
+ EXPECT_CALL(mock_client_, OnMessage(_, _))
+ .WillOnce(Invoke([this](CastSocket* socket, CastMessage message) {
+ EXPECT_EQ(message_.SerializeAsString(), message.SerializeAsString());
+ }));
+ connection_->OnRead(std::vector<uint8_t>(data, data + frame_serial_.size()));
+ task_runner_.RunTasksUntilIdle();
+}
+
+TEST_F(CastSocketTest, ReadChunkedMessage) {
+ const uint8_t* data = frame_serial_.data();
+ EXPECT_CALL(mock_client_, OnMessage(_, _)).Times(0);
+ connection_->OnRead(std::vector<uint8_t>(data, data + 10));
+ task_runner_.RunTasksUntilIdle();
+
+ EXPECT_CALL(mock_client_, OnMessage(_, _))
+ .WillOnce(Invoke([this](CastSocket* socket, CastMessage message) {
+ EXPECT_EQ(message_.SerializeAsString(), message.SerializeAsString());
+ }));
+ connection_->OnRead(
+ std::vector<uint8_t>(data + 10, data + frame_serial_.size()));
+ task_runner_.RunTasksUntilIdle();
+
+ std::vector<uint8_t> double_message;
+ double_message.insert(double_message.end(), frame_serial_.begin(),
+ frame_serial_.end());
+ double_message.insert(double_message.end(), frame_serial_.begin(),
+ frame_serial_.end());
+ data = double_message.data();
+ EXPECT_CALL(mock_client_, OnMessage(_, _))
+ .WillOnce(Invoke([this](CastSocket* socket, CastMessage message) {
+ EXPECT_EQ(message_.SerializeAsString(), message.SerializeAsString());
+ }));
+ connection_->OnRead(
+ std::vector<uint8_t>(data, data + frame_serial_.size() + 10));
+ task_runner_.RunTasksUntilIdle();
+
+ EXPECT_CALL(mock_client_, OnMessage(_, _))
+ .WillOnce(Invoke([this](CastSocket* socket, CastMessage message) {
+ EXPECT_EQ(message_.SerializeAsString(), message.SerializeAsString());
+ }));
+ connection_->OnRead(std::vector<uint8_t>(data + frame_serial_.size() + 10,
+ data + double_message.size()));
+ task_runner_.RunTasksUntilIdle();
+}
+
+TEST_F(CastSocketTest, SendMessageWhileBlocked) {
+ connection_->OnWriteBlocked();
+ task_runner_.RunTasksUntilIdle();
+ EXPECT_CALL(*connection_, Write(_, _)).Times(0);
+ ASSERT_TRUE(socket_.SendMessage(message_).ok());
+
+ EXPECT_CALL(*connection_, Write(_, _))
+ .WillOnce(Invoke([this](const void* data, size_t len) {
+ EXPECT_EQ(
+ frame_serial_,
+ std::vector<uint8_t>(reinterpret_cast<const uint8_t*>(data),
+ reinterpret_cast<const uint8_t*>(data) + len));
+ }));
+ connection_->OnWriteUnblocked();
+ task_runner_.RunTasksUntilIdle();
+
+ EXPECT_CALL(*connection_, Write(_, _)).Times(0);
+ connection_->OnWriteBlocked();
+ task_runner_.RunTasksUntilIdle();
+ connection_->OnWriteUnblocked();
+ task_runner_.RunTasksUntilIdle();
+}
+
+TEST_F(CastSocketTest, ErrorWhileEmptyingQueue) {
+ connection_->OnWriteBlocked();
+ task_runner_.RunTasksUntilIdle();
+ EXPECT_CALL(*connection_, Write(_, _)).Times(0);
+ ASSERT_TRUE(socket_.SendMessage(message_).ok());
+
+ EXPECT_CALL(*connection_, Write(_, _))
+ .WillOnce(Invoke([this](const void* data, size_t len) {
+ EXPECT_EQ(
+ frame_serial_,
+ std::vector<uint8_t>(reinterpret_cast<const uint8_t*>(data),
+ reinterpret_cast<const uint8_t*>(data) + len));
+ connection_->OnError(Error::Code::kUnknownError);
+ }));
+ connection_->OnWriteUnblocked();
+ task_runner_.RunTasksUntilIdle();
+
+ EXPECT_CALL(*connection_, Write(_, _)).Times(0);
+ ASSERT_FALSE(socket_.SendMessage(message_).ok());
+}
+
+} // namespace channel
+} // namespace cast
diff --git a/cast/common/channel/message_framer.cc b/cast/common/channel/message_framer.cc
new file mode 100644
index 00000000..97a82466
--- /dev/null
+++ b/cast/common/channel/message_framer.cc
@@ -0,0 +1,71 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/common/channel/message_framer.h"
+
+#include <stdlib.h>
+#include <string.h>
+
+#include <limits>
+
+#include "cast/common/channel/proto/cast_channel.pb.h"
+#include "platform/api/logging.h"
+#include "util/big_endian.h"
+
+namespace cast {
+namespace channel {
+namespace message_serialization {
+
+using openscreen::Error;
+
+namespace {
+
+static constexpr size_t kHeaderSize = sizeof(uint32_t);
+
+// Cast specifies a max message body size of 64 KiB.
+static constexpr size_t kMaxBodySize = 65536;
+
+} // namespace
+
+ErrorOr<std::vector<uint8_t>> Serialize(const CastMessage& message) {
+ const size_t message_size = message.ByteSizeLong();
+ if (message_size > kMaxBodySize || message_size == 0) {
+ return Error::Code::kCastV2InvalidMessage;
+ }
+ std::vector<uint8_t> out(message_size + kHeaderSize, 0);
+ openscreen::WriteBigEndian<uint32_t>(message_size, out.data());
+ if (!message.SerializeToArray(&out[kHeaderSize], message_size)) {
+ return Error::Code::kCastV2InvalidMessage;
+ }
+ return out;
+}
+
+ErrorOr<DeserializeResult> TryDeserialize(absl::Span<uint8_t> input) {
+ if (input.size() < kHeaderSize) {
+ return Error::Code::kInsufficientBuffer;
+ }
+
+ const uint32_t message_size =
+ openscreen::ReadBigEndian<uint32_t>(input.data());
+ if (message_size > kMaxBodySize) {
+ return Error::Code::kCastV2InvalidMessage;
+ }
+
+ if (input.size() < (kHeaderSize + message_size)) {
+ return Error::Code::kInsufficientBuffer;
+ }
+
+ DeserializeResult result;
+ if (!result.message.ParseFromArray(input.data() + kHeaderSize,
+ message_size)) {
+ return Error::Code::kCastV2InvalidMessage;
+ }
+ result.length = kHeaderSize + message_size;
+
+ return result;
+}
+
+} // namespace message_serialization
+} // namespace channel
+} // namespace cast
diff --git a/cast/common/channel/message_framer.h b/cast/common/channel/message_framer.h
new file mode 100644
index 00000000..c092487c
--- /dev/null
+++ b/cast/common/channel/message_framer.h
@@ -0,0 +1,43 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_COMMON_CHANNEL_MESSAGE_FRAMER_H_
+#define CAST_COMMON_CHANNEL_MESSAGE_FRAMER_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <memory>
+#include <vector>
+
+#include "absl/types/span.h"
+#include "cast/common/channel/proto/cast_channel.pb.h"
+#include "platform/base/error.h"
+
+namespace cast {
+namespace channel {
+namespace message_serialization {
+
+using openscreen::ErrorOr;
+
+// Serializes |message_proto| into |message_data|.
+// Returns true if the message was serialized successfully, false otherwise.
+ErrorOr<std::vector<uint8_t>> Serialize(const CastMessage& message);
+
+struct DeserializeResult {
+ CastMessage message;
+ size_t length;
+};
+
+// Reads bytes from |input| and returns a new CastMessage if one is fully
+// read. Returns a parsed CastMessage if a message was received in its
+// entirety, and an error otherwise. The result also contains the number of
+// bytes consumed from |input| when a parse succeeds.
+ErrorOr<DeserializeResult> TryDeserialize(absl::Span<uint8_t> input);
+
+} // namespace message_serialization
+} // namespace channel
+} // namespace cast
+
+#endif // CAST_COMMON_CHANNEL_MESSAGE_FRAMER_H_
diff --git a/cast/common/channel/message_framer_unittest.cc b/cast/common/channel/message_framer_unittest.cc
new file mode 100644
index 00000000..ae2ff33e
--- /dev/null
+++ b/cast/common/channel/message_framer_unittest.cc
@@ -0,0 +1,153 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/common/channel/message_framer.h"
+
+#include <stddef.h>
+
+#include <algorithm>
+#include <string>
+
+#include "cast/common/channel/proto/cast_channel.pb.h"
+#include "gtest/gtest.h"
+#include "util/big_endian.h"
+#include "util/std_util.h"
+
+namespace cast {
+namespace channel {
+namespace message_serialization {
+
+using openscreen::Error;
+
+namespace {
+
+static constexpr size_t kHeaderSize = sizeof(uint32_t);
+
+// Cast specifies a max message body size of 64 KiB.
+static constexpr size_t kMaxBodySize = 65536;
+
+} // namespace
+
+class CastFramerTest : public testing::Test {
+ public:
+ CastFramerTest() : buffer_(kHeaderSize + kMaxBodySize) {}
+
+ void SetUp() override {
+ cast_message_.set_protocol_version(CastMessage::CASTV2_1_0);
+ cast_message_.set_source_id("source");
+ cast_message_.set_destination_id("destination");
+ cast_message_.set_namespace_("namespace");
+ cast_message_.set_payload_type(CastMessage::STRING);
+ cast_message_.set_payload_utf8("payload");
+ ErrorOr<std::vector<uint8_t>> result = Serialize(cast_message_);
+ ASSERT_TRUE(result.is_value());
+ cast_message_serial_ = std::move(result.value());
+ }
+
+ void WriteToBuffer(const std::vector<uint8_t>& data) {
+ memcpy(&buffer_[0], data.data(), data.size());
+ }
+
+ absl::Span<uint8_t> GetSpan(size_t size) {
+ return absl::Span<uint8_t>(&buffer_[0], size);
+ }
+ absl::Span<uint8_t> GetSpan() { return GetSpan(cast_message_serial_.size()); }
+
+ protected:
+ CastMessage cast_message_;
+ std::vector<uint8_t> cast_message_serial_;
+ std::vector<uint8_t> buffer_;
+};
+
+TEST_F(CastFramerTest, TestMessageFramerCompleteMessage) {
+ WriteToBuffer(cast_message_serial_);
+
+ // Receive 1 byte of the header, framer demands 3 more bytes.
+ ErrorOr<DeserializeResult> result = TryDeserialize(GetSpan(1));
+ EXPECT_FALSE(result);
+ EXPECT_EQ(Error::Code::kInsufficientBuffer, result.error().code());
+
+ // TryDeserialize remaining 3, expect that the framer has moved on to
+ // requesting the body contents.
+ result = TryDeserialize(GetSpan(3));
+ EXPECT_FALSE(result);
+ EXPECT_EQ(Error::Code::kInsufficientBuffer, result.error().code());
+
+ // Remainder of packet sent over the wire.
+ result = TryDeserialize(GetSpan());
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.value().length, cast_message_serial_.size());
+ const CastMessage& message = result.value().message;
+ EXPECT_EQ(message.SerializeAsString(), cast_message_.SerializeAsString());
+}
+
+TEST_F(CastFramerTest, TestSerializeErrorMessageTooLarge) {
+ CastMessage big_message;
+ big_message.CopyFrom(cast_message_);
+ std::string payload;
+ payload.append(kMaxBodySize + 1, 'x');
+ big_message.set_payload_utf8(payload);
+ EXPECT_FALSE(Serialize(big_message));
+}
+
+TEST_F(CastFramerTest, TestCompleteMessageAtOnce) {
+ WriteToBuffer(cast_message_serial_);
+
+ ErrorOr<DeserializeResult> result = TryDeserialize(GetSpan());
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.value().length, cast_message_serial_.size());
+ const CastMessage& message = result.value().message;
+ EXPECT_EQ(message.SerializeAsString(), cast_message_.SerializeAsString());
+}
+
+TEST_F(CastFramerTest, TestTryDeserializeIllegalLargeMessage) {
+ std::vector<uint8_t> mangled_cast_message = cast_message_serial_;
+ mangled_cast_message[0] = 88;
+ mangled_cast_message[1] = 88;
+ mangled_cast_message[2] = 88;
+ mangled_cast_message[3] = 88;
+ WriteToBuffer(mangled_cast_message);
+
+ ErrorOr<DeserializeResult> result = TryDeserialize(GetSpan(4));
+ ASSERT_FALSE(result);
+ EXPECT_EQ(Error::Code::kCastV2InvalidMessage, result.error().code());
+}
+
+TEST_F(CastFramerTest, TestTryDeserializeIllegalLargeMessage2) {
+ std::vector<uint8_t> mangled_cast_message = cast_message_serial_;
+ // Header indicates body size is 0x00010001 = 65537
+ mangled_cast_message[0] = 0;
+ mangled_cast_message[1] = 0x1;
+ mangled_cast_message[2] = 0;
+ mangled_cast_message[3] = 0x1;
+ WriteToBuffer(mangled_cast_message);
+
+ ErrorOr<DeserializeResult> result = TryDeserialize(GetSpan(4));
+ ASSERT_FALSE(result);
+ EXPECT_EQ(Error::Code::kCastV2InvalidMessage, result.error().code());
+}
+
+TEST_F(CastFramerTest, TestUnparsableBodyProto) {
+ // Message header is OK, but the body is replaced with "x"es.
+ std::vector<uint8_t> mangled_cast_message = cast_message_serial_;
+ for (size_t i = kHeaderSize; i < mangled_cast_message.size(); ++i) {
+ std::fill(mangled_cast_message.begin() + kHeaderSize,
+ mangled_cast_message.end(), 'x');
+ }
+ WriteToBuffer(mangled_cast_message);
+
+ // Send header.
+ ErrorOr<DeserializeResult> result = TryDeserialize(GetSpan(4));
+ EXPECT_FALSE(result);
+ EXPECT_EQ(Error::Code::kInsufficientBuffer, result.error().code());
+
+ // Send body, expect an error.
+ result = TryDeserialize(GetSpan());
+ ASSERT_FALSE(result);
+ EXPECT_EQ(Error::Code::kCastV2InvalidMessage, result.error().code());
+}
+
+} // namespace message_serialization
+} // namespace channel
+} // namespace cast
diff --git a/cast/sender/channel/proto/BUILD.gn b/cast/common/channel/proto/BUILD.gn
index c3bfa439..c3bfa439 100644
--- a/cast/sender/channel/proto/BUILD.gn
+++ b/cast/common/channel/proto/BUILD.gn
diff --git a/cast/sender/channel/proto/cast_channel.proto b/cast/common/channel/proto/cast_channel.proto
index 57c7b3f3..57c7b3f3 100644
--- a/cast/sender/channel/proto/cast_channel.proto
+++ b/cast/common/channel/proto/cast_channel.proto
diff --git a/cast/sender/channel/BUILD.gn b/cast/sender/channel/BUILD.gn
index 13b061c1..85f633b0 100644
--- a/cast/sender/channel/BUILD.gn
+++ b/cast/sender/channel/BUILD.gn
@@ -6,18 +6,14 @@ source_set("channel") {
sources = [
"cast_auth_util.cc",
"cast_auth_util.h",
- "cast_framer.cc",
- "cast_framer.h",
]
deps = [
- "../../../util",
- "proto",
+ "../../common/channel/proto",
]
public_deps = [
"../../../platform",
- "../../../third_party/abseil",
]
}
@@ -25,15 +21,13 @@ source_set("unittests") {
testonly = true
sources = [
"cast_auth_util_unittest.cc",
- "cast_framer_unittest.cc",
]
deps = [
":channel",
"../../../platform",
"../../../third_party/googletest:gtest",
- "../../../util",
"../../common/certificate/proto:unittest_proto",
- "proto",
+ "../../common/channel/proto",
]
}
diff --git a/cast/sender/channel/cast_auth_util.h b/cast/sender/channel/cast_auth_util.h
index b4a81e0b..35f1d028 100644
--- a/cast/sender/channel/cast_auth_util.h
+++ b/cast/sender/channel/cast_auth_util.h
@@ -10,7 +10,7 @@
#include <string>
#include "cast/common/certificate/cast_cert_validator.h"
-#include "cast/sender/channel/proto/cast_channel.pb.h"
+#include "cast/common/channel/proto/cast_channel.pb.h"
#include "platform/base/error.h"
namespace cast {
diff --git a/cast/sender/channel/cast_auth_util_unittest.cc b/cast/sender/channel/cast_auth_util_unittest.cc
index 10819362..61d49aa1 100644
--- a/cast/sender/channel/cast_auth_util_unittest.cc
+++ b/cast/sender/channel/cast_auth_util_unittest.cc
@@ -10,7 +10,7 @@
#include "cast/common/certificate/cast_crl.h"
#include "cast/common/certificate/proto/test_suite.pb.h"
#include "cast/common/certificate/test_helpers.h"
-#include "cast/sender/channel/proto/cast_channel.pb.h"
+#include "cast/common/channel/proto/cast_channel.pb.h"
#include "gtest/gtest.h"
#include "platform/api/logging.h"
#include "platform/api/time.h"
diff --git a/cast/sender/channel/cast_framer.cc b/cast/sender/channel/cast_framer.cc
deleted file mode 100644
index b5a8acb1..00000000
--- a/cast/sender/channel/cast_framer.cc
+++ /dev/null
@@ -1,94 +0,0 @@
-// Copyright 2019 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#include "cast/sender/channel/cast_framer.h"
-
-#include <stdlib.h>
-#include <string.h>
-
-#include <limits>
-
-#include "cast/sender/channel/proto/cast_channel.pb.h"
-#include "platform/api/logging.h"
-#include "util/big_endian.h"
-#include "util/std_util.h"
-
-namespace cast {
-namespace channel {
-
-using ChannelError = openscreen::Error::Code;
-
-namespace {
-
-static constexpr size_t kHeaderSize = sizeof(uint32_t);
-
-// Cast specifies a max message body size of 64 KiB.
-static constexpr size_t kMaxBodySize = 65536;
-
-} // namespace
-
-MessageFramer::MessageFramer(absl::Span<uint8_t> input_buffer)
- : input_buffer_(input_buffer) {}
-
-MessageFramer::~MessageFramer() = default;
-
-// static
-ErrorOr<std::string> MessageFramer::Serialize(const CastMessage& message) {
- const size_t message_size = message.ByteSizeLong();
- if (message_size > kMaxBodySize || message_size == 0) {
- return ChannelError::kCastV2InvalidMessage;
- }
- std::string out(message_size + kHeaderSize, 0);
- openscreen::WriteBigEndian<uint32_t>(message_size, openscreen::data(out));
- if (!message.SerializeToArray(&out[kHeaderSize], message_size)) {
- return ChannelError::kCastV2InvalidMessage;
- }
- return out;
-}
-
-ErrorOr<size_t> MessageFramer::BytesRequested() const {
- if (message_bytes_received_ < kHeaderSize) {
- return kHeaderSize - message_bytes_received_;
- }
-
- const uint32_t message_size =
- openscreen::ReadBigEndian<uint32_t>(input_buffer_.data());
- if (message_size > kMaxBodySize) {
- return ChannelError::kCastV2InvalidMessage;
- }
- return (kHeaderSize + message_size) - message_bytes_received_;
-}
-
-ErrorOr<CastMessage> MessageFramer::TryDeserialize(size_t byte_count) {
- message_bytes_received_ += byte_count;
- if (message_bytes_received_ > input_buffer_.size()) {
- return ChannelError::kCastV2InvalidMessage;
- }
-
- if (message_bytes_received_ < kHeaderSize) {
- return ChannelError::kInsufficientBuffer;
- }
-
- const uint32_t message_size =
- openscreen::ReadBigEndian<uint32_t>(input_buffer_.data());
- if (message_size > kMaxBodySize) {
- return ChannelError::kCastV2InvalidMessage;
- }
-
- if (message_bytes_received_ < (kHeaderSize + message_size)) {
- return ChannelError::kInsufficientBuffer;
- }
-
- CastMessage parsed_message;
- if (!parsed_message.ParseFromArray(input_buffer_.data() + kHeaderSize,
- message_size)) {
- return ChannelError::kCastV2InvalidMessage;
- }
-
- message_bytes_received_ = 0;
- return parsed_message;
-}
-
-} // namespace channel
-} // namespace cast
diff --git a/cast/sender/channel/cast_framer.h b/cast/sender/channel/cast_framer.h
deleted file mode 100644
index 8fbabfd3..00000000
--- a/cast/sender/channel/cast_framer.h
+++ /dev/null
@@ -1,59 +0,0 @@
-// Copyright 2019 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#ifndef CAST_SENDER_CHANNEL_CAST_FRAMER_H_
-#define CAST_SENDER_CHANNEL_CAST_FRAMER_H_
-
-#include <stddef.h>
-#include <stdint.h>
-
-#include <memory>
-#include <string>
-
-#include "absl/types/span.h"
-#include "platform/base/error.h"
-
-namespace cast {
-namespace channel {
-
-class CastMessage;
-
-using openscreen::ErrorOr;
-
-// Class for constructing and parsing CastMessage packet data.
-class MessageFramer {
- public:
- // Serializes |message_proto| into |message_data|.
- // Returns true if the message was serialized successfully, false otherwise.
- static ErrorOr<std::string> Serialize(const CastMessage& message);
-
- explicit MessageFramer(absl::Span<uint8_t> input_buffer);
- ~MessageFramer();
-
- // The number of bytes required from the next |input_buffer| passed to
- // TryDeserialize to complete the CastMessage being read. Returns zero if
- // there has been a parsing error.
- ErrorOr<size_t> BytesRequested() const;
-
- // Reads bytes from |input_buffer_| and returns a new CastMessage if one is
- // fully read.
- //
- // |byte_count| Number of additional bytes available in |input_buffer_|.
- // Returns a pointer to a parsed CastMessage if a message was received in its
- // entirety, empty unique_ptr if parsing was successful but didn't produce a
- // complete message, and an error otherwise.
- ErrorOr<CastMessage> TryDeserialize(size_t byte_count);
-
- private:
- // Total size of the message received so far in bytes (head + body).
- size_t message_bytes_received_ = 0;
-
- // Data buffer wherein the caller should place message data for ingest.
- absl::Span<uint8_t> input_buffer_;
-};
-
-} // namespace channel
-} // namespace cast
-
-#endif // CAST_SENDER_CHANNEL_CAST_FRAMER_H_
diff --git a/cast/sender/channel/cast_framer_unittest.cc b/cast/sender/channel/cast_framer_unittest.cc
deleted file mode 100644
index 0cd10c71..00000000
--- a/cast/sender/channel/cast_framer_unittest.cc
+++ /dev/null
@@ -1,179 +0,0 @@
-// Copyright 2019 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#include "cast/sender/channel/cast_framer.h"
-
-#include <stddef.h>
-
-#include <algorithm>
-#include <string>
-
-#include "cast/sender/channel/proto/cast_channel.pb.h"
-#include "gtest/gtest.h"
-#include "util/big_endian.h"
-#include "util/std_util.h"
-
-namespace cast {
-namespace channel {
-
-using ChannelError = openscreen::Error::Code;
-
-namespace {
-
-static constexpr size_t kHeaderSize = sizeof(uint32_t);
-
-// Cast specifies a max message body size of 64 KiB.
-static constexpr size_t kMaxBodySize = 65536;
-
-} // namespace
-
-class CastFramerTest : public testing::Test {
- public:
- CastFramerTest()
- : buffer_(kHeaderSize + kMaxBodySize),
- framer_(absl::Span<uint8_t>(&buffer_[0], buffer_.size())) {}
-
- void SetUp() override {
- cast_message_.set_protocol_version(CastMessage::CASTV2_1_0);
- cast_message_.set_source_id("source");
- cast_message_.set_destination_id("destination");
- cast_message_.set_namespace_("namespace");
- cast_message_.set_payload_type(CastMessage::STRING);
- cast_message_.set_payload_utf8("payload");
- ErrorOr<std::string> result = MessageFramer::Serialize(cast_message_);
- ASSERT_TRUE(result.is_value());
- cast_message_str_ = std::move(result.value());
- }
-
- void WriteToBuffer(const std::string& data) {
- memcpy(&buffer_[0], data.data(), data.size());
- }
-
- protected:
- CastMessage cast_message_;
- std::string cast_message_str_;
- std::vector<uint8_t> buffer_;
- MessageFramer framer_;
-};
-
-TEST_F(CastFramerTest, TestMessageFramerCompleteMessage) {
- WriteToBuffer(cast_message_str_);
-
- // Receive 1 byte of the header, framer demands 3 more bytes.
- EXPECT_EQ(4u, framer_.BytesRequested().value());
- ErrorOr<CastMessage> result = framer_.TryDeserialize(1);
- EXPECT_FALSE(result);
- EXPECT_EQ(ChannelError::kInsufficientBuffer, result.error().code());
- EXPECT_EQ(3u, framer_.BytesRequested().value());
-
- // TryDeserialize remaining 3, expect that the framer has moved on to
- // requesting the body contents.
- result = framer_.TryDeserialize(3);
- EXPECT_FALSE(result);
- EXPECT_EQ(ChannelError::kInsufficientBuffer, result.error().code());
- EXPECT_EQ(cast_message_str_.size() - kHeaderSize,
- framer_.BytesRequested().value());
-
- // Remainder of packet sent over the wire.
- result = framer_.TryDeserialize(framer_.BytesRequested().value());
- ASSERT_TRUE(result);
- const CastMessage& message = result.value();
- EXPECT_EQ(message.SerializeAsString(), cast_message_.SerializeAsString());
- EXPECT_EQ(4u, framer_.BytesRequested().value());
-}
-
-TEST_F(CastFramerTest, BigEndianMessageHeader) {
- WriteToBuffer(cast_message_str_);
-
- EXPECT_EQ(4u, framer_.BytesRequested().value());
- ErrorOr<CastMessage> result = framer_.TryDeserialize(4);
- EXPECT_FALSE(result);
- EXPECT_EQ(ChannelError::kInsufficientBuffer, result.error().code());
-
- const uint32_t expected_size =
- openscreen::ReadBigEndian<uint32_t>(openscreen::data(cast_message_str_));
- EXPECT_EQ(expected_size, framer_.BytesRequested().value());
-}
-
-TEST_F(CastFramerTest, TestSerializeErrorMessageTooLarge) {
- CastMessage big_message;
- big_message.CopyFrom(cast_message_);
- std::string payload;
- payload.append(kMaxBodySize + 1, 'x');
- big_message.set_payload_utf8(payload);
- EXPECT_FALSE(MessageFramer::Serialize(big_message));
-}
-
-TEST_F(CastFramerTest, TestCompleteMessageAtOnce) {
- WriteToBuffer(cast_message_str_);
-
- ErrorOr<CastMessage> result =
- framer_.TryDeserialize(cast_message_str_.size());
- ASSERT_TRUE(result);
- const CastMessage& message = result.value();
- EXPECT_EQ(message.SerializeAsString(), cast_message_.SerializeAsString());
- EXPECT_EQ(4u, framer_.BytesRequested().value());
-}
-
-TEST_F(CastFramerTest, TestTryDeserializeIllegalLargeMessage) {
- std::string mangled_cast_message = cast_message_str_;
- mangled_cast_message[0] = 88;
- mangled_cast_message[1] = 88;
- mangled_cast_message[2] = 88;
- mangled_cast_message[3] = 88;
- WriteToBuffer(mangled_cast_message);
-
- EXPECT_EQ(4u, framer_.BytesRequested().value());
- ErrorOr<CastMessage> result = framer_.TryDeserialize(4);
- ASSERT_FALSE(result);
- EXPECT_EQ(ChannelError::kCastV2InvalidMessage, result.error().code());
- ErrorOr<size_t> bytes_requested = framer_.BytesRequested();
- ASSERT_FALSE(bytes_requested);
- EXPECT_EQ(ChannelError::kCastV2InvalidMessage,
- bytes_requested.error().code());
-}
-
-TEST_F(CastFramerTest, TestTryDeserializeIllegalLargeMessage2) {
- std::string mangled_cast_message = cast_message_str_;
- // Header indicates body size is 0x00010001 = 65537
- mangled_cast_message[0] = 0;
- mangled_cast_message[1] = 0x1;
- mangled_cast_message[2] = 0;
- mangled_cast_message[3] = 0x1;
- WriteToBuffer(mangled_cast_message);
-
- EXPECT_EQ(4u, framer_.BytesRequested().value());
- ErrorOr<CastMessage> result = framer_.TryDeserialize(4);
- ASSERT_FALSE(result);
- EXPECT_EQ(ChannelError::kCastV2InvalidMessage, result.error().code());
- ErrorOr<size_t> bytes_requested = framer_.BytesRequested();
- ASSERT_FALSE(bytes_requested);
- EXPECT_EQ(ChannelError::kCastV2InvalidMessage,
- bytes_requested.error().code());
-}
-
-TEST_F(CastFramerTest, TestUnparsableBodyProto) {
- // Message header is OK, but the body is replaced with "x"es.
- std::string mangled_cast_message = cast_message_str_;
- for (size_t i = kHeaderSize; i < mangled_cast_message.size(); ++i) {
- std::fill(mangled_cast_message.begin() + kHeaderSize,
- mangled_cast_message.end(), 'x');
- }
- WriteToBuffer(mangled_cast_message);
-
- // Send header.
- EXPECT_EQ(4u, framer_.BytesRequested().value());
- ErrorOr<CastMessage> result = framer_.TryDeserialize(4);
- EXPECT_FALSE(result);
- EXPECT_EQ(ChannelError::kInsufficientBuffer, result.error().code());
- EXPECT_EQ(cast_message_str_.size() - 4, framer_.BytesRequested().value());
-
- // Send body, expect an error.
- result = framer_.TryDeserialize(framer_.BytesRequested().value());
- ASSERT_FALSE(result);
- EXPECT_EQ(ChannelError::kCastV2InvalidMessage, result.error().code());
-}
-
-} // namespace channel
-} // namespace cast