aboutsummaryrefslogtreecommitdiff
path: root/cast/common
diff options
context:
space:
mode:
authorbtolsch <btolsch@chromium.org>2020-04-16 11:44:36 -0700
committerCommit Bot <commit-bot@chromium.org>2020-04-16 19:14:17 +0000
commit371c663c9113479198564307868e7fef918478e2 (patch)
tree3c4d4704ee740835f53b884206f8a0d3a49eca20 /cast/common
parentc117a70b3913598e3dfe37b2668a64c0eb5634c7 (diff)
downloadopenscreen-371c663c9113479198564307868e7fef918478e2.tar.gz
Read all available CastMessages in CastSocket
This change makes CastSocket repeatedly deserialize CastMessages whenever it gets a new block from TLS. This fixes the case where multiple messages are received in one TLS read. Bug: 1050913 Change-Id: Ia29e2c82c29d921073eaa20874a835d33b2eb4bb Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/2151857 Commit-Queue: Brandon Tolsch <btolsch@chromium.org> Reviewed-by: Ryan Keane <rwkeane@google.com>
Diffstat (limited to 'cast/common')
-rw-r--r--cast/common/channel/cast_socket.cc22
-rw-r--r--cast/common/channel/cast_socket_unittest.cc28
2 files changed, 41 insertions, 9 deletions
diff --git a/cast/common/channel/cast_socket.cc b/cast/common/channel/cast_socket.cc
index 86880e23..070c4b53 100644
--- a/cast/common/channel/cast_socket.cc
+++ b/cast/common/channel/cast_socket.cc
@@ -71,15 +71,19 @@ void CastSocket::OnError(TlsConnection* connection, Error error) {
void CastSocket::OnRead(TlsConnection* connection, std::vector<uint8_t> block) {
read_buffer_.insert(read_buffer_.end(), block.begin(), block.end());
- ErrorOr<DeserializeResult> message_or_error =
- message_serialization::TryDeserialize(
- absl::Span<uint8_t>(&read_buffer_[0], read_buffer_.size()));
- if (!message_or_error) {
- return;
- }
- read_buffer_.erase(read_buffer_.begin(),
- read_buffer_.begin() + message_or_error.value().length);
- client_->OnMessage(this, std::move(message_or_error.value().message));
+ // NOTE: Read as many messages as possible out of |read_buffer_| since we only
+ // get one callback opportunity for this.
+ do {
+ ErrorOr<DeserializeResult> message_or_error =
+ message_serialization::TryDeserialize(
+ absl::Span<uint8_t>(&read_buffer_[0], read_buffer_.size()));
+ if (!message_or_error) {
+ return;
+ }
+ read_buffer_.erase(read_buffer_.begin(),
+ read_buffer_.begin() + message_or_error.value().length);
+ client_->OnMessage(this, std::move(message_or_error.value().message));
+ } while (!read_buffer_.empty());
}
int CastSocket::g_next_socket_id_ = 1;
diff --git a/cast/common/channel/cast_socket_unittest.cc b/cast/common/channel/cast_socket_unittest.cc
index decfb041..351224fd 100644
--- a/cast/common/channel/cast_socket_unittest.cc
+++ b/cast/common/channel/cast_socket_unittest.cc
@@ -121,6 +121,34 @@ TEST_F(CastSocketTest, ReadChunkedMessage) {
data + double_message.size()));
}
+TEST_F(CastSocketTest, ReadMultipleMessagesPerBlock) {
+ CastMessage message2;
+ std::vector<uint8_t> frame_serial2;
+ message2.set_protocol_version(CastMessage::CASTV2_1_0);
+ message2.set_source_id("alt-source");
+ message2.set_destination_id("alt-destination");
+ message2.set_namespace_("alt-namespace");
+ message2.set_payload_type(CastMessage::STRING);
+ message2.set_payload_utf8("alternate payload");
+ ErrorOr<std::vector<uint8_t>> serialized_or_error =
+ message_serialization::Serialize(message2);
+ ASSERT_TRUE(serialized_or_error);
+ frame_serial2 = std::move(serialized_or_error.value());
+
+ std::vector<uint8_t> send_data;
+ send_data.reserve(frame_serial_.size() + frame_serial2.size());
+ send_data.insert(send_data.end(), frame_serial_.begin(), frame_serial_.end());
+ send_data.insert(send_data.end(), frame_serial2.begin(), frame_serial2.end());
+ EXPECT_CALL(mock_client(), OnMessage(_, _))
+ .WillOnce(Invoke([this](CastSocket* socket, CastMessage message) {
+ EXPECT_EQ(message_.SerializeAsString(), message.SerializeAsString());
+ }))
+ .WillOnce([message2](CastSocket* socket, CastMessage message) {
+ EXPECT_EQ(message2.SerializeAsString(), message.SerializeAsString());
+ });
+ connection().OnRead(std::move(send_data));
+}
+
TEST_F(CastSocketTest, SanitizedAddress) {
std::array<uint8_t, 2> result1 = socket().GetSanitizedIpAddress();
EXPECT_EQ(result1[0], 1u);