diff options
author | btolsch <btolsch@chromium.org> | 2020-04-16 11:44:36 -0700 |
---|---|---|
committer | Commit Bot <commit-bot@chromium.org> | 2020-04-16 19:14:17 +0000 |
commit | 371c663c9113479198564307868e7fef918478e2 (patch) | |
tree | 3c4d4704ee740835f53b884206f8a0d3a49eca20 /cast/common | |
parent | c117a70b3913598e3dfe37b2668a64c0eb5634c7 (diff) | |
download | openscreen-371c663c9113479198564307868e7fef918478e2.tar.gz |
Read all available CastMessages in CastSocket
This change makes CastSocket repeatedly deserialize CastMessages
whenever it gets a new block from TLS. This fixes the case where
multiple messages are received in one TLS read.
Bug: 1050913
Change-Id: Ia29e2c82c29d921073eaa20874a835d33b2eb4bb
Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/2151857
Commit-Queue: Brandon Tolsch <btolsch@chromium.org>
Reviewed-by: Ryan Keane <rwkeane@google.com>
Diffstat (limited to 'cast/common')
-rw-r--r-- | cast/common/channel/cast_socket.cc | 22 | ||||
-rw-r--r-- | cast/common/channel/cast_socket_unittest.cc | 28 |
2 files changed, 41 insertions, 9 deletions
diff --git a/cast/common/channel/cast_socket.cc b/cast/common/channel/cast_socket.cc index 86880e23..070c4b53 100644 --- a/cast/common/channel/cast_socket.cc +++ b/cast/common/channel/cast_socket.cc @@ -71,15 +71,19 @@ void CastSocket::OnError(TlsConnection* connection, Error error) { void CastSocket::OnRead(TlsConnection* connection, std::vector<uint8_t> block) { read_buffer_.insert(read_buffer_.end(), block.begin(), block.end()); - ErrorOr<DeserializeResult> message_or_error = - message_serialization::TryDeserialize( - absl::Span<uint8_t>(&read_buffer_[0], read_buffer_.size())); - if (!message_or_error) { - return; - } - read_buffer_.erase(read_buffer_.begin(), - read_buffer_.begin() + message_or_error.value().length); - client_->OnMessage(this, std::move(message_or_error.value().message)); + // NOTE: Read as many messages as possible out of |read_buffer_| since we only + // get one callback opportunity for this. + do { + ErrorOr<DeserializeResult> message_or_error = + message_serialization::TryDeserialize( + absl::Span<uint8_t>(&read_buffer_[0], read_buffer_.size())); + if (!message_or_error) { + return; + } + read_buffer_.erase(read_buffer_.begin(), + read_buffer_.begin() + message_or_error.value().length); + client_->OnMessage(this, std::move(message_or_error.value().message)); + } while (!read_buffer_.empty()); } int CastSocket::g_next_socket_id_ = 1; diff --git a/cast/common/channel/cast_socket_unittest.cc b/cast/common/channel/cast_socket_unittest.cc index decfb041..351224fd 100644 --- a/cast/common/channel/cast_socket_unittest.cc +++ b/cast/common/channel/cast_socket_unittest.cc @@ -121,6 +121,34 @@ TEST_F(CastSocketTest, ReadChunkedMessage) { data + double_message.size())); } +TEST_F(CastSocketTest, ReadMultipleMessagesPerBlock) { + CastMessage message2; + std::vector<uint8_t> frame_serial2; + message2.set_protocol_version(CastMessage::CASTV2_1_0); + message2.set_source_id("alt-source"); + message2.set_destination_id("alt-destination"); + message2.set_namespace_("alt-namespace"); + message2.set_payload_type(CastMessage::STRING); + message2.set_payload_utf8("alternate payload"); + ErrorOr<std::vector<uint8_t>> serialized_or_error = + message_serialization::Serialize(message2); + ASSERT_TRUE(serialized_or_error); + frame_serial2 = std::move(serialized_or_error.value()); + + std::vector<uint8_t> send_data; + send_data.reserve(frame_serial_.size() + frame_serial2.size()); + send_data.insert(send_data.end(), frame_serial_.begin(), frame_serial_.end()); + send_data.insert(send_data.end(), frame_serial2.begin(), frame_serial2.end()); + EXPECT_CALL(mock_client(), OnMessage(_, _)) + .WillOnce(Invoke([this](CastSocket* socket, CastMessage message) { + EXPECT_EQ(message_.SerializeAsString(), message.SerializeAsString()); + })) + .WillOnce([message2](CastSocket* socket, CastMessage message) { + EXPECT_EQ(message2.SerializeAsString(), message.SerializeAsString()); + }); + connection().OnRead(std::move(send_data)); +} + TEST_F(CastSocketTest, SanitizedAddress) { std::array<uint8_t, 2> result1 = socket().GetSanitizedIpAddress(); EXPECT_EQ(result1[0], 1u); |