aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cast/common/channel/cast_message_handler.h1
-rw-r--r--cast/common/channel/cast_socket_message_port.cc2
-rw-r--r--cast/common/channel/connection_namespace_handler.cc12
-rw-r--r--cast/common/channel/virtual_connection.h2
-rw-r--r--cast/common/channel/virtual_connection_router.cc61
-rw-r--r--cast/common/channel/virtual_connection_router.h3
-rw-r--r--cast/common/channel/virtual_connection_router_unittest.cc248
-rw-r--r--cast/common/public/cast_socket.h5
-rw-r--r--cast/receiver/channel/device_auth_namespace_handler.cc6
-rw-r--r--cast/sender/cast_platform_client.cc9
10 files changed, 283 insertions, 66 deletions
diff --git a/cast/common/channel/cast_message_handler.h b/cast/common/channel/cast_message_handler.h
index cd0d13e6..e478d156 100644
--- a/cast/common/channel/cast_message_handler.h
+++ b/cast/common/channel/cast_message_handler.h
@@ -17,6 +17,7 @@ class CastMessageHandler {
public:
virtual ~CastMessageHandler() = default;
+ // |socket| is null if the source of the message is a local peer.
virtual void OnMessage(VirtualConnectionRouter* router,
CastSocket* socket,
::cast::channel::CastMessage message) = 0;
diff --git a/cast/common/channel/cast_socket_message_port.cc b/cast/common/channel/cast_socket_message_port.cc
index c3ca0df0..b6d65123 100644
--- a/cast/common/channel/cast_socket_message_port.cc
+++ b/cast/common/channel/cast_socket_message_port.cc
@@ -32,7 +32,7 @@ void CastSocketMessagePort::SetSocket(WeakPtr<CastSocket> socket) {
}
int CastSocketMessagePort::GetSocketId() {
- return socket_ ? socket_->socket_id() : -1;
+ return ToCastSocketId(socket_.get());
}
void CastSocketMessagePort::SetClient(MessagePort::Client* client,
diff --git a/cast/common/channel/connection_namespace_handler.cc b/cast/common/channel/connection_namespace_handler.cc
index 396b5d53..a449dcbd 100644
--- a/cast/common/channel/connection_namespace_handler.cc
+++ b/cast/common/channel/connection_namespace_handler.cc
@@ -4,7 +4,9 @@
#include "cast/common/channel/connection_namespace_handler.h"
+#include <string>
#include <type_traits>
+#include <utility>
#include "absl/types/optional.h"
#include "cast/common/channel/message_util.h"
@@ -138,7 +140,7 @@ void ConnectionNamespaceHandler::HandleConnect(VirtualConnectionRouter* router,
VirtualConnection virtual_conn{std::move(message.destination_id()),
std::move(message.source_id()),
- socket->socket_id()};
+ ToCastSocketId(socket)};
if (!vc_policy_->IsConnectionAllowed(virtual_conn)) {
SendClose(router, std::move(virtual_conn));
return;
@@ -187,7 +189,11 @@ void ConnectionNamespaceHandler::HandleConnect(VirtualConnectionRouter* router,
data.max_protocol_version = VirtualConnection::ProtocolVersion::kV2_1_0;
}
- data.ip_fragment = socket->GetSanitizedIpAddress();
+ if (socket) {
+ data.ip_fragment = socket->GetSanitizedIpAddress();
+ } else {
+ data.ip_fragment = {};
+ }
OSP_DVLOG << "Connection opened: " << virtual_conn.local_id << ", "
<< virtual_conn.peer_id << ", " << virtual_conn.socket_id;
@@ -208,7 +214,7 @@ void ConnectionNamespaceHandler::HandleClose(VirtualConnectionRouter* router,
Json::Value parsed_message) {
VirtualConnection virtual_conn{std::move(message.destination_id()),
std::move(message.source_id()),
- socket->socket_id()};
+ ToCastSocketId(socket)};
if (!vc_manager_->GetConnectionData(virtual_conn)) {
return;
}
diff --git a/cast/common/channel/virtual_connection.h b/cast/common/channel/virtual_connection.h
index 04f3ba06..6f8b2cb8 100644
--- a/cast/common/channel/virtual_connection.h
+++ b/cast/common/channel/virtual_connection.h
@@ -97,6 +97,8 @@ struct VirtualConnection {
// generated and intended to be unique within that device.
// - GUID-style hex string: Random string identifying a particular receiver
// app on the device.
+ //
+ // Additionally, |peer_id| can be an asterisk when broadcast-sending.
std::string local_id;
std::string peer_id;
int socket_id;
diff --git a/cast/common/channel/virtual_connection_router.cc b/cast/common/channel/virtual_connection_router.cc
index 74efcd89..140ca138 100644
--- a/cast/common/channel/virtual_connection_router.cc
+++ b/cast/common/channel/virtual_connection_router.cc
@@ -4,6 +4,8 @@
#include "cast/common/channel/virtual_connection_router.h"
+#include <utility>
+
#include "cast/common/channel/cast_message_handler.h"
#include "cast/common/channel/message_util.h"
#include "cast/common/channel/proto/cast_channel.pb.h"
@@ -55,7 +57,11 @@ void VirtualConnectionRouter::CloseSocket(int id) {
Error VirtualConnectionRouter::Send(VirtualConnection virtual_conn,
CastMessage message) {
- // TODO(btolsch): Check for broadcast message.
+ if (virtual_conn.peer_id == kBroadcastId) {
+ return BroadcastFromLocalPeer(std::move(virtual_conn.local_id),
+ std::move(message));
+ }
+
if (!IsTransportNamespace(message.namespace_()) &&
!vc_manager_->GetConnectionData(virtual_conn)) {
return Error::Code::kNoActiveConnection;
@@ -69,8 +75,33 @@ Error VirtualConnectionRouter::Send(VirtualConnection virtual_conn,
return it->second.socket->Send(message);
}
+Error VirtualConnectionRouter::BroadcastFromLocalPeer(
+ std::string local_id,
+ ::cast::channel::CastMessage message) {
+ message.set_source_id(std::move(local_id));
+ message.set_destination_id(kBroadcastId);
+
+ // Broadcast to local endpoints.
+ for (const auto& entry : endpoints_) {
+ if (entry.first != message.source_id()) {
+ entry.second->OnMessage(this, nullptr, message);
+ }
+ }
+
+ // Broadcast to remote endpoints. If an Error occurs, continue broadcasting,
+ // and later return the first Error that occurred.
+ Error error;
+ for (const auto& entry : sockets_) {
+ auto result = entry.second.socket->Send(message);
+ if (!result.ok() && error.ok()) {
+ error = std::move(result);
+ }
+ }
+ return error;
+}
+
void VirtualConnectionRouter::OnError(CastSocket* socket, Error error) {
- int id = socket->socket_id();
+ const int id = socket->socket_id();
auto it = sockets_.find(id);
if (it != sockets_.end()) {
vc_manager_->RemoveConnectionsBySocketId(id, VirtualConnection::kUnknown);
@@ -83,17 +114,23 @@ void VirtualConnectionRouter::OnError(CastSocket* socket, Error error) {
void VirtualConnectionRouter::OnMessage(CastSocket* socket,
CastMessage message) {
- // TODO(btolsch): Check for broadcast message.
- VirtualConnection virtual_conn{message.destination_id(), message.source_id(),
- socket->socket_id()};
- if (!IsTransportNamespace(message.namespace_()) &&
- !vc_manager_->GetConnectionData(virtual_conn)) {
- return;
- }
+ OSP_DCHECK(socket);
+
const std::string& local_id = message.destination_id();
- auto it = endpoints_.find(local_id);
- if (it != endpoints_.end()) {
- it->second->OnMessage(this, socket, std::move(message));
+ if (local_id == kBroadcastId) {
+ for (const auto& entry : endpoints_) {
+ entry.second->OnMessage(this, socket, message);
+ }
+ } else {
+ if (!IsTransportNamespace(message.namespace_()) &&
+ !vc_manager_->GetConnectionData(VirtualConnection{
+ local_id, message.source_id(), socket->socket_id()})) {
+ return;
+ }
+ auto it = endpoints_.find(local_id);
+ if (it != endpoints_.end()) {
+ it->second->OnMessage(this, socket, std::move(message));
+ }
}
}
diff --git a/cast/common/channel/virtual_connection_router.h b/cast/common/channel/virtual_connection_router.h
index 3238d5aa..1bbf2bc1 100644
--- a/cast/common/channel/virtual_connection_router.h
+++ b/cast/common/channel/virtual_connection_router.h
@@ -62,6 +62,9 @@ class VirtualConnectionRouter final : public CastSocket::Client {
Error Send(VirtualConnection virtual_conn,
::cast::channel::CastMessage message);
+ Error BroadcastFromLocalPeer(std::string local_id,
+ ::cast::channel::CastMessage message);
+
// CastSocket::Client overrides.
void OnError(CastSocket* socket, Error error) override;
void OnMessage(CastSocket* socket,
diff --git a/cast/common/channel/virtual_connection_router_unittest.cc b/cast/common/channel/virtual_connection_router_unittest.cc
index 6b1f0055..b05d10e3 100644
--- a/cast/common/channel/virtual_connection_router_unittest.cc
+++ b/cast/common/channel/virtual_connection_router_unittest.cc
@@ -4,6 +4,9 @@
#include "cast/common/channel/virtual_connection_router.h"
+#include <utility>
+
+#include "cast/common/channel/message_util.h"
#include "cast/common/channel/proto/cast_channel.pb.h"
#include "cast/common/channel/testing/fake_cast_socket.h"
#include "cast/common/channel/testing/mock_cast_message_handler.h"
@@ -19,35 +22,43 @@ namespace {
using ::cast::channel::CastMessage;
using ::testing::_;
using ::testing::Invoke;
+using ::testing::SaveArg;
+using ::testing::WithArg;
class VirtualConnectionRouterTest : public ::testing::Test {
public:
void SetUp() override {
- socket_ = fake_cast_socket_pair_.socket.get();
- router_.TakeSocket(&mock_error_handler_,
- std::move(fake_cast_socket_pair_.socket));
+ local_socket_ = fake_cast_socket_pair_.socket.get();
+ local_router_.TakeSocket(&mock_error_handler_,
+ std::move(fake_cast_socket_pair_.socket));
+
+ remote_socket_ = fake_cast_socket_pair_.peer_socket.get();
+ remote_router_.TakeSocket(&mock_error_handler_,
+ std::move(fake_cast_socket_pair_.peer_socket));
}
protected:
- CastSocket& peer_socket() { return *fake_cast_socket_pair_.peer_socket; }
-
FakeCastSocketPair fake_cast_socket_pair_;
- CastSocket* socket_;
+ CastSocket* local_socket_;
+ CastSocket* remote_socket_;
MockSocketErrorHandler mock_error_handler_;
- MockCastMessageHandler mock_message_handler_;
- VirtualConnectionManager manager_;
- VirtualConnectionRouter router_{&manager_};
+ VirtualConnectionManager local_manager_;
+ VirtualConnectionRouter local_router_{&local_manager_};
+
+ VirtualConnectionManager remote_manager_;
+ VirtualConnectionRouter remote_router_{&remote_manager_};
};
} // namespace
TEST_F(VirtualConnectionRouterTest, LocalIdHandler) {
- router_.AddHandlerForLocalId("receiver-1234", &mock_message_handler_);
- manager_.AddConnection(
- VirtualConnection{"receiver-1234", "sender-9873", socket_->socket_id()},
- {});
+ MockCastMessageHandler mock_message_handler;
+ local_router_.AddHandlerForLocalId("receiver-1234", &mock_message_handler);
+ local_manager_.AddConnection(VirtualConnection{"receiver-1234", "sender-9873",
+ local_socket_->socket_id()},
+ {});
CastMessage message;
message.set_protocol_version(
@@ -57,22 +68,25 @@ TEST_F(VirtualConnectionRouterTest, LocalIdHandler) {
message.set_destination_id("receiver-1234");
message.set_payload_type(CastMessage::STRING);
message.set_payload_utf8("cnlybnq");
- EXPECT_CALL(mock_message_handler_, OnMessage(_, socket_, _));
- EXPECT_TRUE(peer_socket().Send(message).ok());
+ EXPECT_CALL(mock_message_handler, OnMessage(_, local_socket_, _));
+ EXPECT_TRUE(remote_socket_->Send(message).ok());
- EXPECT_CALL(mock_message_handler_, OnMessage(_, socket_, _));
- EXPECT_TRUE(peer_socket().Send(message).ok());
+ EXPECT_CALL(mock_message_handler, OnMessage(_, local_socket_, _));
+ EXPECT_TRUE(remote_socket_->Send(message).ok());
message.set_destination_id("receiver-4321");
- EXPECT_CALL(mock_message_handler_, OnMessage(_, _, _)).Times(0);
- EXPECT_TRUE(peer_socket().Send(message).ok());
+ EXPECT_CALL(mock_message_handler, OnMessage(_, _, _)).Times(0);
+ EXPECT_TRUE(remote_socket_->Send(message).ok());
+
+ local_router_.RemoveHandlerForLocalId("receiver-1234");
}
TEST_F(VirtualConnectionRouterTest, RemoveLocalIdHandler) {
- router_.AddHandlerForLocalId("receiver-1234", &mock_message_handler_);
- manager_.AddConnection(
- VirtualConnection{"receiver-1234", "sender-9873", socket_->socket_id()},
- {});
+ MockCastMessageHandler mock_message_handler;
+ local_router_.AddHandlerForLocalId("receiver-1234", &mock_message_handler);
+ local_manager_.AddConnection(VirtualConnection{"receiver-1234", "sender-9873",
+ local_socket_->socket_id()},
+ {});
CastMessage message;
message.set_protocol_version(
@@ -82,18 +96,27 @@ TEST_F(VirtualConnectionRouterTest, RemoveLocalIdHandler) {
message.set_destination_id("receiver-1234");
message.set_payload_type(CastMessage::STRING);
message.set_payload_utf8("cnlybnq");
- EXPECT_CALL(mock_message_handler_, OnMessage(_, socket_, _));
- EXPECT_TRUE(peer_socket().Send(message).ok());
+ EXPECT_CALL(mock_message_handler, OnMessage(_, local_socket_, _));
+ EXPECT_TRUE(remote_socket_->Send(message).ok());
+
+ local_router_.RemoveHandlerForLocalId("receiver-1234");
- router_.RemoveHandlerForLocalId("receiver-1234");
+ EXPECT_CALL(mock_message_handler, OnMessage(_, local_socket_, _)).Times(0);
+ EXPECT_TRUE(remote_socket_->Send(message).ok());
- EXPECT_CALL(mock_message_handler_, OnMessage(_, socket_, _)).Times(0);
- EXPECT_TRUE(peer_socket().Send(message).ok());
+ local_router_.RemoveHandlerForLocalId("receiver-1234");
}
TEST_F(VirtualConnectionRouterTest, SendMessage) {
- manager_.AddConnection(
- VirtualConnection{"receiver-1234", "sender-4321", socket_->socket_id()},
+ local_manager_.AddConnection(VirtualConnection{"receiver-1234", "sender-4321",
+ local_socket_->socket_id()},
+ {});
+
+ MockCastMessageHandler destination;
+ remote_router_.AddHandlerForLocalId("sender-4321", &destination);
+ remote_manager_.AddConnection(
+ VirtualConnection{"sender-4321", "receiver-1234",
+ remote_socket_->socket_id()},
{});
CastMessage message;
@@ -104,30 +127,159 @@ TEST_F(VirtualConnectionRouterTest, SendMessage) {
message.set_destination_id("sender-4321");
message.set_payload_type(CastMessage::STRING);
message.set_payload_utf8("cnlybnq");
- EXPECT_CALL(fake_cast_socket_pair_.mock_peer_client, OnMessage(_, _))
- .WillOnce(Invoke([](CastSocket* socket, CastMessage message) {
- EXPECT_EQ(message.namespace_(), "zrqvn");
- EXPECT_EQ(message.source_id(), "receiver-1234");
- EXPECT_EQ(message.destination_id(), "sender-4321");
- ASSERT_EQ(message.payload_type(),
- ::cast::channel::CastMessage_PayloadType_STRING);
- EXPECT_EQ(message.payload_utf8(), "cnlybnq");
- }));
- router_.Send(
- VirtualConnection{"receiver-1234", "sender-4321", socket_->socket_id()},
- std::move(message));
+ ASSERT_TRUE(message.IsInitialized());
+
+ EXPECT_CALL(destination, OnMessage(&remote_router_, remote_socket_, _))
+ .WillOnce(
+ WithArg<2>(Invoke([&message](CastMessage message_at_destination) {
+ ASSERT_TRUE(message_at_destination.IsInitialized());
+ EXPECT_EQ(message.SerializeAsString(),
+ message_at_destination.SerializeAsString());
+ })));
+ local_router_.Send(VirtualConnection{"receiver-1234", "sender-4321",
+ local_socket_->socket_id()},
+ message);
}
TEST_F(VirtualConnectionRouterTest, CloseSocketRemovesVirtualConnections) {
- manager_.AddConnection(
- VirtualConnection{"receiver-1234", "sender-4321", socket_->socket_id()},
- {});
+ local_manager_.AddConnection(VirtualConnection{"receiver-1234", "sender-4321",
+ local_socket_->socket_id()},
+ {});
+
+ EXPECT_CALL(mock_error_handler_, OnClose(local_socket_)).Times(1);
- int id = socket_->socket_id();
- router_.CloseSocket(id);
- EXPECT_FALSE(manager_.GetConnectionData(
+ int id = local_socket_->socket_id();
+ local_router_.CloseSocket(id);
+ EXPECT_FALSE(local_manager_.GetConnectionData(
VirtualConnection{"receiver-1234", "sender-4321", id}));
}
+// Tests that VirtualConnectionRouter::Send() broadcasts a message from a local
+// source to both: 1) all other local peers; and 2) all remote peers.
+TEST_F(VirtualConnectionRouterTest, BroadcastsFromLocalSource) {
+ // Local peers.
+ MockCastMessageHandler alice, bob;
+ local_router_.AddHandlerForLocalId("alice", &alice);
+ local_router_.AddHandlerForLocalId("bob", &bob);
+
+ // Remote peers.
+ MockCastMessageHandler charlie, dave, eve;
+ remote_router_.AddHandlerForLocalId("charlie", &charlie);
+ remote_router_.AddHandlerForLocalId("dave", &dave);
+ remote_router_.AddHandlerForLocalId("eve", &eve);
+
+ // The local broadcaster, which should never receive her own messages.
+ MockCastMessageHandler wendy;
+ local_router_.AddHandlerForLocalId("wendy", &wendy);
+ EXPECT_CALL(wendy, OnMessage(_, _, _)).Times(0);
+
+ CastMessage message;
+ message.set_protocol_version(
+ ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0);
+ message.set_namespace_("zrqvn");
+ message.set_payload_type(CastMessage::STRING);
+ message.set_payload_utf8("cnlybnq");
+
+ CastMessage message_alice_got, message_bob_got, message_charlie_got,
+ message_dave_got, message_eve_got;
+ EXPECT_CALL(alice, OnMessage(&local_router_, nullptr, _))
+ .WillOnce(SaveArg<2>(&message_alice_got))
+ .RetiresOnSaturation();
+ EXPECT_CALL(bob, OnMessage(&local_router_, nullptr, _))
+ .WillOnce(SaveArg<2>(&message_bob_got))
+ .RetiresOnSaturation();
+ EXPECT_CALL(charlie, OnMessage(&remote_router_, remote_socket_, _))
+ .WillOnce(SaveArg<2>(&message_charlie_got))
+ .RetiresOnSaturation();
+ EXPECT_CALL(dave, OnMessage(&remote_router_, remote_socket_, _))
+ .WillOnce(SaveArg<2>(&message_dave_got))
+ .RetiresOnSaturation();
+ EXPECT_CALL(eve, OnMessage(&remote_router_, remote_socket_, _))
+ .WillOnce(SaveArg<2>(&message_eve_got))
+ .RetiresOnSaturation();
+ ASSERT_TRUE(local_router_.BroadcastFromLocalPeer("wendy", message).ok());
+
+ // Confirm message data is correct.
+ message.set_source_id("wendy");
+ message.set_destination_id(kBroadcastId);
+ ASSERT_TRUE(message.IsInitialized());
+ ASSERT_TRUE(message_alice_got.IsInitialized());
+ EXPECT_EQ(message.SerializeAsString(), message_alice_got.SerializeAsString());
+ ASSERT_TRUE(message_bob_got.IsInitialized());
+ EXPECT_EQ(message.SerializeAsString(), message_bob_got.SerializeAsString());
+ ASSERT_TRUE(message_charlie_got.IsInitialized());
+ EXPECT_EQ(message.SerializeAsString(),
+ message_charlie_got.SerializeAsString());
+ ASSERT_TRUE(message_dave_got.IsInitialized());
+ EXPECT_EQ(message.SerializeAsString(), message_dave_got.SerializeAsString());
+ ASSERT_TRUE(message_eve_got.IsInitialized());
+ EXPECT_EQ(message.SerializeAsString(), message_eve_got.SerializeAsString());
+
+ // Remove one local peer and one remote peer, and confirm only the correct
+ // entities receive a broadcast message.
+ local_router_.RemoveHandlerForLocalId("bob");
+ remote_router_.RemoveHandlerForLocalId("charlie");
+ EXPECT_CALL(alice, OnMessage(&local_router_, nullptr, _)).Times(1);
+ EXPECT_CALL(bob, OnMessage(_, _, _)).Times(0);
+ EXPECT_CALL(charlie, OnMessage(_, _, _)).Times(0);
+ EXPECT_CALL(dave, OnMessage(&remote_router_, remote_socket_, _)).Times(1);
+ EXPECT_CALL(eve, OnMessage(&remote_router_, remote_socket_, _)).Times(1);
+ ASSERT_TRUE(local_router_.BroadcastFromLocalPeer("wendy", message).ok());
+}
+
+// Tests that VirtualConnectionRouter::OnMessage() broadcasts a message from a
+// remote source to all local peers.
+TEST_F(VirtualConnectionRouterTest, BroadcastsFromRemoteSource) {
+ // Local peers.
+ MockCastMessageHandler alice, bob, charlie;
+ local_router_.AddHandlerForLocalId("alice", &alice);
+ local_router_.AddHandlerForLocalId("bob", &bob);
+ local_router_.AddHandlerForLocalId("charlie", &charlie);
+
+ // The remote broadcaster, which should never receive her own messages.
+ MockCastMessageHandler wendy;
+ remote_router_.AddHandlerForLocalId("wendy", &wendy);
+ EXPECT_CALL(wendy, OnMessage(_, _, _)).Times(0);
+
+ CastMessage message;
+ message.set_protocol_version(
+ ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0);
+ message.set_namespace_("zrqvn");
+ message.set_payload_type(CastMessage::STRING);
+ message.set_payload_utf8("cnlybnq");
+
+ CastMessage message_alice_got, message_bob_got, message_charlie_got;
+ EXPECT_CALL(alice, OnMessage(&local_router_, local_socket_, _))
+ .WillOnce(SaveArg<2>(&message_alice_got))
+ .RetiresOnSaturation();
+ EXPECT_CALL(bob, OnMessage(&local_router_, local_socket_, _))
+ .WillOnce(SaveArg<2>(&message_bob_got))
+ .RetiresOnSaturation();
+ EXPECT_CALL(charlie, OnMessage(&local_router_, local_socket_, _))
+ .WillOnce(SaveArg<2>(&message_charlie_got))
+ .RetiresOnSaturation();
+ ASSERT_TRUE(remote_router_.BroadcastFromLocalPeer("wendy", message).ok());
+
+ // Confirm message data is correct.
+ message.set_source_id("wendy");
+ message.set_destination_id(kBroadcastId);
+ ASSERT_TRUE(message.IsInitialized());
+ ASSERT_TRUE(message_alice_got.IsInitialized());
+ EXPECT_EQ(message.SerializeAsString(), message_alice_got.SerializeAsString());
+ ASSERT_TRUE(message_bob_got.IsInitialized());
+ EXPECT_EQ(message.SerializeAsString(), message_bob_got.SerializeAsString());
+ ASSERT_TRUE(message_charlie_got.IsInitialized());
+ EXPECT_EQ(message.SerializeAsString(),
+ message_charlie_got.SerializeAsString());
+
+ // Remove one local peer, and confirm only the two remaining local peers
+ // receive a broadcast message from the remote source.
+ local_router_.RemoveHandlerForLocalId("bob");
+ EXPECT_CALL(alice, OnMessage(&local_router_, local_socket_, _)).Times(1);
+ EXPECT_CALL(bob, OnMessage(_, _, _)).Times(0);
+ EXPECT_CALL(charlie, OnMessage(&local_router_, local_socket_, _)).Times(1);
+ ASSERT_TRUE(remote_router_.BroadcastFromLocalPeer("wendy", message).ok());
+}
+
} // namespace cast
} // namespace openscreen
diff --git a/cast/common/public/cast_socket.h b/cast/common/public/cast_socket.h
index d7ac683f..2a67b659 100644
--- a/cast/common/public/cast_socket.h
+++ b/cast/common/public/cast_socket.h
@@ -79,6 +79,11 @@ class CastSocket : public TlsConnection::Client {
WeakPtrFactory<CastSocket> weak_factory_{this};
};
+// Returns socket->socket_id() if |socket| is not null, otherwise 0.
+inline int ToCastSocketId(CastSocket* socket) {
+ return socket ? socket->socket_id() : 0;
+}
+
} // namespace cast
} // namespace openscreen
diff --git a/cast/receiver/channel/device_auth_namespace_handler.cc b/cast/receiver/channel/device_auth_namespace_handler.cc
index 239459a0..17aca182 100644
--- a/cast/receiver/channel/device_auth_namespace_handler.cc
+++ b/cast/receiver/channel/device_auth_namespace_handler.cc
@@ -6,6 +6,9 @@
#include <openssl/evp.h>
+#include <memory>
+#include <utility>
+
#include "cast/common/certificate/cast_cert_validator.h"
#include "cast/common/channel/message_util.h"
#include "cast/common/channel/proto/cast_channel.pb.h"
@@ -54,6 +57,9 @@ DeviceAuthNamespaceHandler::~DeviceAuthNamespaceHandler() = default;
void DeviceAuthNamespaceHandler::OnMessage(VirtualConnectionRouter* router,
CastSocket* socket,
CastMessage message) {
+ if (!socket) {
+ return; // Don't handle auth messages from local senders. That's nonsense.
+ }
if (message.payload_type() !=
::cast::channel::CastMessage_PayloadType_BINARY) {
return;
diff --git a/cast/sender/cast_platform_client.cc b/cast/sender/cast_platform_client.cc
index 4d59c65b..224a58a4 100644
--- a/cast/sender/cast_platform_client.cc
+++ b/cast/sender/cast_platform_client.cc
@@ -4,7 +4,9 @@
#include "cast/sender/cast_platform_client.h"
+#include <memory>
#include <random>
+#include <utility>
#include "absl/strings/str_cat.h"
#include "cast/common/channel/virtual_connection_manager.h"
@@ -22,6 +24,8 @@ static constexpr std::chrono::seconds kRequestTimeout = std::chrono::seconds(5);
namespace {
+// TODO(miu): This is duplicated in another teammate's WIP CL. De-dupe this by
+// placing the utility in cast/common.
std::string MakeRandomSenderId() {
static auto& rd = *new std::random_device();
static auto& gen = *new std::mt19937(rd());
@@ -149,8 +153,9 @@ void CastPlatformClient::OnMessage(VirtualConnectionRouter* router,
if (request_id) {
auto entry = std::find_if(
socket_id_by_device_id_.begin(), socket_id_by_device_id_.end(),
- [socket](const std::pair<std::string, int>& entry) {
- return entry.second == socket->socket_id();
+ [socket_id =
+ ToCastSocketId(socket)](const std::pair<std::string, int>& entry) {
+ return entry.second == socket_id;
});
if (entry != socket_id_by_device_id_.end()) {
HandleResponse(entry->first, request_id.value(), dict);