diff options
-rw-r--r-- | cast/common/channel/cast_message_handler.h | 1 | ||||
-rw-r--r-- | cast/common/channel/cast_socket_message_port.cc | 2 | ||||
-rw-r--r-- | cast/common/channel/connection_namespace_handler.cc | 12 | ||||
-rw-r--r-- | cast/common/channel/virtual_connection.h | 2 | ||||
-rw-r--r-- | cast/common/channel/virtual_connection_router.cc | 61 | ||||
-rw-r--r-- | cast/common/channel/virtual_connection_router.h | 3 | ||||
-rw-r--r-- | cast/common/channel/virtual_connection_router_unittest.cc | 248 | ||||
-rw-r--r-- | cast/common/public/cast_socket.h | 5 | ||||
-rw-r--r-- | cast/receiver/channel/device_auth_namespace_handler.cc | 6 | ||||
-rw-r--r-- | cast/sender/cast_platform_client.cc | 9 |
10 files changed, 283 insertions, 66 deletions
diff --git a/cast/common/channel/cast_message_handler.h b/cast/common/channel/cast_message_handler.h index cd0d13e6..e478d156 100644 --- a/cast/common/channel/cast_message_handler.h +++ b/cast/common/channel/cast_message_handler.h @@ -17,6 +17,7 @@ class CastMessageHandler { public: virtual ~CastMessageHandler() = default; + // |socket| is null if the source of the message is a local peer. virtual void OnMessage(VirtualConnectionRouter* router, CastSocket* socket, ::cast::channel::CastMessage message) = 0; diff --git a/cast/common/channel/cast_socket_message_port.cc b/cast/common/channel/cast_socket_message_port.cc index c3ca0df0..b6d65123 100644 --- a/cast/common/channel/cast_socket_message_port.cc +++ b/cast/common/channel/cast_socket_message_port.cc @@ -32,7 +32,7 @@ void CastSocketMessagePort::SetSocket(WeakPtr<CastSocket> socket) { } int CastSocketMessagePort::GetSocketId() { - return socket_ ? socket_->socket_id() : -1; + return ToCastSocketId(socket_.get()); } void CastSocketMessagePort::SetClient(MessagePort::Client* client, diff --git a/cast/common/channel/connection_namespace_handler.cc b/cast/common/channel/connection_namespace_handler.cc index 396b5d53..a449dcbd 100644 --- a/cast/common/channel/connection_namespace_handler.cc +++ b/cast/common/channel/connection_namespace_handler.cc @@ -4,7 +4,9 @@ #include "cast/common/channel/connection_namespace_handler.h" +#include <string> #include <type_traits> +#include <utility> #include "absl/types/optional.h" #include "cast/common/channel/message_util.h" @@ -138,7 +140,7 @@ void ConnectionNamespaceHandler::HandleConnect(VirtualConnectionRouter* router, VirtualConnection virtual_conn{std::move(message.destination_id()), std::move(message.source_id()), - socket->socket_id()}; + ToCastSocketId(socket)}; if (!vc_policy_->IsConnectionAllowed(virtual_conn)) { SendClose(router, std::move(virtual_conn)); return; @@ -187,7 +189,11 @@ void ConnectionNamespaceHandler::HandleConnect(VirtualConnectionRouter* router, data.max_protocol_version = VirtualConnection::ProtocolVersion::kV2_1_0; } - data.ip_fragment = socket->GetSanitizedIpAddress(); + if (socket) { + data.ip_fragment = socket->GetSanitizedIpAddress(); + } else { + data.ip_fragment = {}; + } OSP_DVLOG << "Connection opened: " << virtual_conn.local_id << ", " << virtual_conn.peer_id << ", " << virtual_conn.socket_id; @@ -208,7 +214,7 @@ void ConnectionNamespaceHandler::HandleClose(VirtualConnectionRouter* router, Json::Value parsed_message) { VirtualConnection virtual_conn{std::move(message.destination_id()), std::move(message.source_id()), - socket->socket_id()}; + ToCastSocketId(socket)}; if (!vc_manager_->GetConnectionData(virtual_conn)) { return; } diff --git a/cast/common/channel/virtual_connection.h b/cast/common/channel/virtual_connection.h index 04f3ba06..6f8b2cb8 100644 --- a/cast/common/channel/virtual_connection.h +++ b/cast/common/channel/virtual_connection.h @@ -97,6 +97,8 @@ struct VirtualConnection { // generated and intended to be unique within that device. // - GUID-style hex string: Random string identifying a particular receiver // app on the device. + // + // Additionally, |peer_id| can be an asterisk when broadcast-sending. std::string local_id; std::string peer_id; int socket_id; diff --git a/cast/common/channel/virtual_connection_router.cc b/cast/common/channel/virtual_connection_router.cc index 74efcd89..140ca138 100644 --- a/cast/common/channel/virtual_connection_router.cc +++ b/cast/common/channel/virtual_connection_router.cc @@ -4,6 +4,8 @@ #include "cast/common/channel/virtual_connection_router.h" +#include <utility> + #include "cast/common/channel/cast_message_handler.h" #include "cast/common/channel/message_util.h" #include "cast/common/channel/proto/cast_channel.pb.h" @@ -55,7 +57,11 @@ void VirtualConnectionRouter::CloseSocket(int id) { Error VirtualConnectionRouter::Send(VirtualConnection virtual_conn, CastMessage message) { - // TODO(btolsch): Check for broadcast message. + if (virtual_conn.peer_id == kBroadcastId) { + return BroadcastFromLocalPeer(std::move(virtual_conn.local_id), + std::move(message)); + } + if (!IsTransportNamespace(message.namespace_()) && !vc_manager_->GetConnectionData(virtual_conn)) { return Error::Code::kNoActiveConnection; @@ -69,8 +75,33 @@ Error VirtualConnectionRouter::Send(VirtualConnection virtual_conn, return it->second.socket->Send(message); } +Error VirtualConnectionRouter::BroadcastFromLocalPeer( + std::string local_id, + ::cast::channel::CastMessage message) { + message.set_source_id(std::move(local_id)); + message.set_destination_id(kBroadcastId); + + // Broadcast to local endpoints. + for (const auto& entry : endpoints_) { + if (entry.first != message.source_id()) { + entry.second->OnMessage(this, nullptr, message); + } + } + + // Broadcast to remote endpoints. If an Error occurs, continue broadcasting, + // and later return the first Error that occurred. + Error error; + for (const auto& entry : sockets_) { + auto result = entry.second.socket->Send(message); + if (!result.ok() && error.ok()) { + error = std::move(result); + } + } + return error; +} + void VirtualConnectionRouter::OnError(CastSocket* socket, Error error) { - int id = socket->socket_id(); + const int id = socket->socket_id(); auto it = sockets_.find(id); if (it != sockets_.end()) { vc_manager_->RemoveConnectionsBySocketId(id, VirtualConnection::kUnknown); @@ -83,17 +114,23 @@ void VirtualConnectionRouter::OnError(CastSocket* socket, Error error) { void VirtualConnectionRouter::OnMessage(CastSocket* socket, CastMessage message) { - // TODO(btolsch): Check for broadcast message. - VirtualConnection virtual_conn{message.destination_id(), message.source_id(), - socket->socket_id()}; - if (!IsTransportNamespace(message.namespace_()) && - !vc_manager_->GetConnectionData(virtual_conn)) { - return; - } + OSP_DCHECK(socket); + const std::string& local_id = message.destination_id(); - auto it = endpoints_.find(local_id); - if (it != endpoints_.end()) { - it->second->OnMessage(this, socket, std::move(message)); + if (local_id == kBroadcastId) { + for (const auto& entry : endpoints_) { + entry.second->OnMessage(this, socket, message); + } + } else { + if (!IsTransportNamespace(message.namespace_()) && + !vc_manager_->GetConnectionData(VirtualConnection{ + local_id, message.source_id(), socket->socket_id()})) { + return; + } + auto it = endpoints_.find(local_id); + if (it != endpoints_.end()) { + it->second->OnMessage(this, socket, std::move(message)); + } } } diff --git a/cast/common/channel/virtual_connection_router.h b/cast/common/channel/virtual_connection_router.h index 3238d5aa..1bbf2bc1 100644 --- a/cast/common/channel/virtual_connection_router.h +++ b/cast/common/channel/virtual_connection_router.h @@ -62,6 +62,9 @@ class VirtualConnectionRouter final : public CastSocket::Client { Error Send(VirtualConnection virtual_conn, ::cast::channel::CastMessage message); + Error BroadcastFromLocalPeer(std::string local_id, + ::cast::channel::CastMessage message); + // CastSocket::Client overrides. void OnError(CastSocket* socket, Error error) override; void OnMessage(CastSocket* socket, diff --git a/cast/common/channel/virtual_connection_router_unittest.cc b/cast/common/channel/virtual_connection_router_unittest.cc index 6b1f0055..b05d10e3 100644 --- a/cast/common/channel/virtual_connection_router_unittest.cc +++ b/cast/common/channel/virtual_connection_router_unittest.cc @@ -4,6 +4,9 @@ #include "cast/common/channel/virtual_connection_router.h" +#include <utility> + +#include "cast/common/channel/message_util.h" #include "cast/common/channel/proto/cast_channel.pb.h" #include "cast/common/channel/testing/fake_cast_socket.h" #include "cast/common/channel/testing/mock_cast_message_handler.h" @@ -19,35 +22,43 @@ namespace { using ::cast::channel::CastMessage; using ::testing::_; using ::testing::Invoke; +using ::testing::SaveArg; +using ::testing::WithArg; class VirtualConnectionRouterTest : public ::testing::Test { public: void SetUp() override { - socket_ = fake_cast_socket_pair_.socket.get(); - router_.TakeSocket(&mock_error_handler_, - std::move(fake_cast_socket_pair_.socket)); + local_socket_ = fake_cast_socket_pair_.socket.get(); + local_router_.TakeSocket(&mock_error_handler_, + std::move(fake_cast_socket_pair_.socket)); + + remote_socket_ = fake_cast_socket_pair_.peer_socket.get(); + remote_router_.TakeSocket(&mock_error_handler_, + std::move(fake_cast_socket_pair_.peer_socket)); } protected: - CastSocket& peer_socket() { return *fake_cast_socket_pair_.peer_socket; } - FakeCastSocketPair fake_cast_socket_pair_; - CastSocket* socket_; + CastSocket* local_socket_; + CastSocket* remote_socket_; MockSocketErrorHandler mock_error_handler_; - MockCastMessageHandler mock_message_handler_; - VirtualConnectionManager manager_; - VirtualConnectionRouter router_{&manager_}; + VirtualConnectionManager local_manager_; + VirtualConnectionRouter local_router_{&local_manager_}; + + VirtualConnectionManager remote_manager_; + VirtualConnectionRouter remote_router_{&remote_manager_}; }; } // namespace TEST_F(VirtualConnectionRouterTest, LocalIdHandler) { - router_.AddHandlerForLocalId("receiver-1234", &mock_message_handler_); - manager_.AddConnection( - VirtualConnection{"receiver-1234", "sender-9873", socket_->socket_id()}, - {}); + MockCastMessageHandler mock_message_handler; + local_router_.AddHandlerForLocalId("receiver-1234", &mock_message_handler); + local_manager_.AddConnection(VirtualConnection{"receiver-1234", "sender-9873", + local_socket_->socket_id()}, + {}); CastMessage message; message.set_protocol_version( @@ -57,22 +68,25 @@ TEST_F(VirtualConnectionRouterTest, LocalIdHandler) { message.set_destination_id("receiver-1234"); message.set_payload_type(CastMessage::STRING); message.set_payload_utf8("cnlybnq"); - EXPECT_CALL(mock_message_handler_, OnMessage(_, socket_, _)); - EXPECT_TRUE(peer_socket().Send(message).ok()); + EXPECT_CALL(mock_message_handler, OnMessage(_, local_socket_, _)); + EXPECT_TRUE(remote_socket_->Send(message).ok()); - EXPECT_CALL(mock_message_handler_, OnMessage(_, socket_, _)); - EXPECT_TRUE(peer_socket().Send(message).ok()); + EXPECT_CALL(mock_message_handler, OnMessage(_, local_socket_, _)); + EXPECT_TRUE(remote_socket_->Send(message).ok()); message.set_destination_id("receiver-4321"); - EXPECT_CALL(mock_message_handler_, OnMessage(_, _, _)).Times(0); - EXPECT_TRUE(peer_socket().Send(message).ok()); + EXPECT_CALL(mock_message_handler, OnMessage(_, _, _)).Times(0); + EXPECT_TRUE(remote_socket_->Send(message).ok()); + + local_router_.RemoveHandlerForLocalId("receiver-1234"); } TEST_F(VirtualConnectionRouterTest, RemoveLocalIdHandler) { - router_.AddHandlerForLocalId("receiver-1234", &mock_message_handler_); - manager_.AddConnection( - VirtualConnection{"receiver-1234", "sender-9873", socket_->socket_id()}, - {}); + MockCastMessageHandler mock_message_handler; + local_router_.AddHandlerForLocalId("receiver-1234", &mock_message_handler); + local_manager_.AddConnection(VirtualConnection{"receiver-1234", "sender-9873", + local_socket_->socket_id()}, + {}); CastMessage message; message.set_protocol_version( @@ -82,18 +96,27 @@ TEST_F(VirtualConnectionRouterTest, RemoveLocalIdHandler) { message.set_destination_id("receiver-1234"); message.set_payload_type(CastMessage::STRING); message.set_payload_utf8("cnlybnq"); - EXPECT_CALL(mock_message_handler_, OnMessage(_, socket_, _)); - EXPECT_TRUE(peer_socket().Send(message).ok()); + EXPECT_CALL(mock_message_handler, OnMessage(_, local_socket_, _)); + EXPECT_TRUE(remote_socket_->Send(message).ok()); + + local_router_.RemoveHandlerForLocalId("receiver-1234"); - router_.RemoveHandlerForLocalId("receiver-1234"); + EXPECT_CALL(mock_message_handler, OnMessage(_, local_socket_, _)).Times(0); + EXPECT_TRUE(remote_socket_->Send(message).ok()); - EXPECT_CALL(mock_message_handler_, OnMessage(_, socket_, _)).Times(0); - EXPECT_TRUE(peer_socket().Send(message).ok()); + local_router_.RemoveHandlerForLocalId("receiver-1234"); } TEST_F(VirtualConnectionRouterTest, SendMessage) { - manager_.AddConnection( - VirtualConnection{"receiver-1234", "sender-4321", socket_->socket_id()}, + local_manager_.AddConnection(VirtualConnection{"receiver-1234", "sender-4321", + local_socket_->socket_id()}, + {}); + + MockCastMessageHandler destination; + remote_router_.AddHandlerForLocalId("sender-4321", &destination); + remote_manager_.AddConnection( + VirtualConnection{"sender-4321", "receiver-1234", + remote_socket_->socket_id()}, {}); CastMessage message; @@ -104,30 +127,159 @@ TEST_F(VirtualConnectionRouterTest, SendMessage) { message.set_destination_id("sender-4321"); message.set_payload_type(CastMessage::STRING); message.set_payload_utf8("cnlybnq"); - EXPECT_CALL(fake_cast_socket_pair_.mock_peer_client, OnMessage(_, _)) - .WillOnce(Invoke([](CastSocket* socket, CastMessage message) { - EXPECT_EQ(message.namespace_(), "zrqvn"); - EXPECT_EQ(message.source_id(), "receiver-1234"); - EXPECT_EQ(message.destination_id(), "sender-4321"); - ASSERT_EQ(message.payload_type(), - ::cast::channel::CastMessage_PayloadType_STRING); - EXPECT_EQ(message.payload_utf8(), "cnlybnq"); - })); - router_.Send( - VirtualConnection{"receiver-1234", "sender-4321", socket_->socket_id()}, - std::move(message)); + ASSERT_TRUE(message.IsInitialized()); + + EXPECT_CALL(destination, OnMessage(&remote_router_, remote_socket_, _)) + .WillOnce( + WithArg<2>(Invoke([&message](CastMessage message_at_destination) { + ASSERT_TRUE(message_at_destination.IsInitialized()); + EXPECT_EQ(message.SerializeAsString(), + message_at_destination.SerializeAsString()); + }))); + local_router_.Send(VirtualConnection{"receiver-1234", "sender-4321", + local_socket_->socket_id()}, + message); } TEST_F(VirtualConnectionRouterTest, CloseSocketRemovesVirtualConnections) { - manager_.AddConnection( - VirtualConnection{"receiver-1234", "sender-4321", socket_->socket_id()}, - {}); + local_manager_.AddConnection(VirtualConnection{"receiver-1234", "sender-4321", + local_socket_->socket_id()}, + {}); + + EXPECT_CALL(mock_error_handler_, OnClose(local_socket_)).Times(1); - int id = socket_->socket_id(); - router_.CloseSocket(id); - EXPECT_FALSE(manager_.GetConnectionData( + int id = local_socket_->socket_id(); + local_router_.CloseSocket(id); + EXPECT_FALSE(local_manager_.GetConnectionData( VirtualConnection{"receiver-1234", "sender-4321", id})); } +// Tests that VirtualConnectionRouter::Send() broadcasts a message from a local +// source to both: 1) all other local peers; and 2) all remote peers. +TEST_F(VirtualConnectionRouterTest, BroadcastsFromLocalSource) { + // Local peers. + MockCastMessageHandler alice, bob; + local_router_.AddHandlerForLocalId("alice", &alice); + local_router_.AddHandlerForLocalId("bob", &bob); + + // Remote peers. + MockCastMessageHandler charlie, dave, eve; + remote_router_.AddHandlerForLocalId("charlie", &charlie); + remote_router_.AddHandlerForLocalId("dave", &dave); + remote_router_.AddHandlerForLocalId("eve", &eve); + + // The local broadcaster, which should never receive her own messages. + MockCastMessageHandler wendy; + local_router_.AddHandlerForLocalId("wendy", &wendy); + EXPECT_CALL(wendy, OnMessage(_, _, _)).Times(0); + + CastMessage message; + message.set_protocol_version( + ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0); + message.set_namespace_("zrqvn"); + message.set_payload_type(CastMessage::STRING); + message.set_payload_utf8("cnlybnq"); + + CastMessage message_alice_got, message_bob_got, message_charlie_got, + message_dave_got, message_eve_got; + EXPECT_CALL(alice, OnMessage(&local_router_, nullptr, _)) + .WillOnce(SaveArg<2>(&message_alice_got)) + .RetiresOnSaturation(); + EXPECT_CALL(bob, OnMessage(&local_router_, nullptr, _)) + .WillOnce(SaveArg<2>(&message_bob_got)) + .RetiresOnSaturation(); + EXPECT_CALL(charlie, OnMessage(&remote_router_, remote_socket_, _)) + .WillOnce(SaveArg<2>(&message_charlie_got)) + .RetiresOnSaturation(); + EXPECT_CALL(dave, OnMessage(&remote_router_, remote_socket_, _)) + .WillOnce(SaveArg<2>(&message_dave_got)) + .RetiresOnSaturation(); + EXPECT_CALL(eve, OnMessage(&remote_router_, remote_socket_, _)) + .WillOnce(SaveArg<2>(&message_eve_got)) + .RetiresOnSaturation(); + ASSERT_TRUE(local_router_.BroadcastFromLocalPeer("wendy", message).ok()); + + // Confirm message data is correct. + message.set_source_id("wendy"); + message.set_destination_id(kBroadcastId); + ASSERT_TRUE(message.IsInitialized()); + ASSERT_TRUE(message_alice_got.IsInitialized()); + EXPECT_EQ(message.SerializeAsString(), message_alice_got.SerializeAsString()); + ASSERT_TRUE(message_bob_got.IsInitialized()); + EXPECT_EQ(message.SerializeAsString(), message_bob_got.SerializeAsString()); + ASSERT_TRUE(message_charlie_got.IsInitialized()); + EXPECT_EQ(message.SerializeAsString(), + message_charlie_got.SerializeAsString()); + ASSERT_TRUE(message_dave_got.IsInitialized()); + EXPECT_EQ(message.SerializeAsString(), message_dave_got.SerializeAsString()); + ASSERT_TRUE(message_eve_got.IsInitialized()); + EXPECT_EQ(message.SerializeAsString(), message_eve_got.SerializeAsString()); + + // Remove one local peer and one remote peer, and confirm only the correct + // entities receive a broadcast message. + local_router_.RemoveHandlerForLocalId("bob"); + remote_router_.RemoveHandlerForLocalId("charlie"); + EXPECT_CALL(alice, OnMessage(&local_router_, nullptr, _)).Times(1); + EXPECT_CALL(bob, OnMessage(_, _, _)).Times(0); + EXPECT_CALL(charlie, OnMessage(_, _, _)).Times(0); + EXPECT_CALL(dave, OnMessage(&remote_router_, remote_socket_, _)).Times(1); + EXPECT_CALL(eve, OnMessage(&remote_router_, remote_socket_, _)).Times(1); + ASSERT_TRUE(local_router_.BroadcastFromLocalPeer("wendy", message).ok()); +} + +// Tests that VirtualConnectionRouter::OnMessage() broadcasts a message from a +// remote source to all local peers. +TEST_F(VirtualConnectionRouterTest, BroadcastsFromRemoteSource) { + // Local peers. + MockCastMessageHandler alice, bob, charlie; + local_router_.AddHandlerForLocalId("alice", &alice); + local_router_.AddHandlerForLocalId("bob", &bob); + local_router_.AddHandlerForLocalId("charlie", &charlie); + + // The remote broadcaster, which should never receive her own messages. + MockCastMessageHandler wendy; + remote_router_.AddHandlerForLocalId("wendy", &wendy); + EXPECT_CALL(wendy, OnMessage(_, _, _)).Times(0); + + CastMessage message; + message.set_protocol_version( + ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0); + message.set_namespace_("zrqvn"); + message.set_payload_type(CastMessage::STRING); + message.set_payload_utf8("cnlybnq"); + + CastMessage message_alice_got, message_bob_got, message_charlie_got; + EXPECT_CALL(alice, OnMessage(&local_router_, local_socket_, _)) + .WillOnce(SaveArg<2>(&message_alice_got)) + .RetiresOnSaturation(); + EXPECT_CALL(bob, OnMessage(&local_router_, local_socket_, _)) + .WillOnce(SaveArg<2>(&message_bob_got)) + .RetiresOnSaturation(); + EXPECT_CALL(charlie, OnMessage(&local_router_, local_socket_, _)) + .WillOnce(SaveArg<2>(&message_charlie_got)) + .RetiresOnSaturation(); + ASSERT_TRUE(remote_router_.BroadcastFromLocalPeer("wendy", message).ok()); + + // Confirm message data is correct. + message.set_source_id("wendy"); + message.set_destination_id(kBroadcastId); + ASSERT_TRUE(message.IsInitialized()); + ASSERT_TRUE(message_alice_got.IsInitialized()); + EXPECT_EQ(message.SerializeAsString(), message_alice_got.SerializeAsString()); + ASSERT_TRUE(message_bob_got.IsInitialized()); + EXPECT_EQ(message.SerializeAsString(), message_bob_got.SerializeAsString()); + ASSERT_TRUE(message_charlie_got.IsInitialized()); + EXPECT_EQ(message.SerializeAsString(), + message_charlie_got.SerializeAsString()); + + // Remove one local peer, and confirm only the two remaining local peers + // receive a broadcast message from the remote source. + local_router_.RemoveHandlerForLocalId("bob"); + EXPECT_CALL(alice, OnMessage(&local_router_, local_socket_, _)).Times(1); + EXPECT_CALL(bob, OnMessage(_, _, _)).Times(0); + EXPECT_CALL(charlie, OnMessage(&local_router_, local_socket_, _)).Times(1); + ASSERT_TRUE(remote_router_.BroadcastFromLocalPeer("wendy", message).ok()); +} + } // namespace cast } // namespace openscreen diff --git a/cast/common/public/cast_socket.h b/cast/common/public/cast_socket.h index d7ac683f..2a67b659 100644 --- a/cast/common/public/cast_socket.h +++ b/cast/common/public/cast_socket.h @@ -79,6 +79,11 @@ class CastSocket : public TlsConnection::Client { WeakPtrFactory<CastSocket> weak_factory_{this}; }; +// Returns socket->socket_id() if |socket| is not null, otherwise 0. +inline int ToCastSocketId(CastSocket* socket) { + return socket ? socket->socket_id() : 0; +} + } // namespace cast } // namespace openscreen diff --git a/cast/receiver/channel/device_auth_namespace_handler.cc b/cast/receiver/channel/device_auth_namespace_handler.cc index 239459a0..17aca182 100644 --- a/cast/receiver/channel/device_auth_namespace_handler.cc +++ b/cast/receiver/channel/device_auth_namespace_handler.cc @@ -6,6 +6,9 @@ #include <openssl/evp.h> +#include <memory> +#include <utility> + #include "cast/common/certificate/cast_cert_validator.h" #include "cast/common/channel/message_util.h" #include "cast/common/channel/proto/cast_channel.pb.h" @@ -54,6 +57,9 @@ DeviceAuthNamespaceHandler::~DeviceAuthNamespaceHandler() = default; void DeviceAuthNamespaceHandler::OnMessage(VirtualConnectionRouter* router, CastSocket* socket, CastMessage message) { + if (!socket) { + return; // Don't handle auth messages from local senders. That's nonsense. + } if (message.payload_type() != ::cast::channel::CastMessage_PayloadType_BINARY) { return; diff --git a/cast/sender/cast_platform_client.cc b/cast/sender/cast_platform_client.cc index 4d59c65b..224a58a4 100644 --- a/cast/sender/cast_platform_client.cc +++ b/cast/sender/cast_platform_client.cc @@ -4,7 +4,9 @@ #include "cast/sender/cast_platform_client.h" +#include <memory> #include <random> +#include <utility> #include "absl/strings/str_cat.h" #include "cast/common/channel/virtual_connection_manager.h" @@ -22,6 +24,8 @@ static constexpr std::chrono::seconds kRequestTimeout = std::chrono::seconds(5); namespace { +// TODO(miu): This is duplicated in another teammate's WIP CL. De-dupe this by +// placing the utility in cast/common. std::string MakeRandomSenderId() { static auto& rd = *new std::random_device(); static auto& gen = *new std::mt19937(rd()); @@ -149,8 +153,9 @@ void CastPlatformClient::OnMessage(VirtualConnectionRouter* router, if (request_id) { auto entry = std::find_if( socket_id_by_device_id_.begin(), socket_id_by_device_id_.end(), - [socket](const std::pair<std::string, int>& entry) { - return entry.second == socket->socket_id(); + [socket_id = + ToCastSocketId(socket)](const std::pair<std::string, int>& entry) { + return entry.second == socket_id; }); if (entry != socket_id_by_device_id_.end()) { HandleResponse(entry->first, request_id.value(), dict); |