diff options
author | Yuri Wiitala <miu@chromium.org> | 2020-11-30 08:13:53 -0800 |
---|---|---|
committer | Yuri Wiitala <miu@chromium.org> | 2020-11-30 20:03:13 +0000 |
commit | c6465ca683e686cd7f7dfa347e86451425e7af25 (patch) | |
tree | 71d7a43bb08d7d94d576e55c0e5ef0c0c5876040 /cast | |
parent | 9be11f7d2b44dff83dc88dff054b967a43f319cc (diff) | |
download | openscreen-c6465ca683e686cd7f7dfa347e86451425e7af25.tar.gz |
Remote virtual connections [1/3]: Refactor VCRouter to extend VCManager.
This is a clean-up patch to simplify code structure. Before this patch,
all clients of VirtualConnectionRouter need to instantiate both a
VirtualConnectionManager and a VirtualConnectionRouter, and then pass
the former (by pointer) to the latter. After this patch, VCR extends
VCM, and clients need not worry about a VCM.
Bug: b/162542369
Change-Id: I0dff69819d9b5282a43643a2da0fe33e7bf0a3fd
Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/2546803
Reviewed-by: Brandon Tolsch <btolsch@chromium.org>
Diffstat (limited to 'cast')
23 files changed, 345 insertions, 471 deletions
diff --git a/cast/common/BUILD.gn b/cast/common/BUILD.gn index c1ea0b3c..c1a8acd7 100644 --- a/cast/common/BUILD.gn +++ b/cast/common/BUILD.gn @@ -64,8 +64,6 @@ source_set("channel") { "channel/namespace_router.cc", "channel/namespace_router.h", "channel/virtual_connection.h", - "channel/virtual_connection_manager.cc", - "channel/virtual_connection_manager.h", "channel/virtual_connection_router.cc", "channel/virtual_connection_router.h", "public/cast_socket.h", @@ -152,7 +150,6 @@ source_set("unittests") { "channel/connection_namespace_handler_unittest.cc", "channel/message_framer_unittest.cc", "channel/namespace_router_unittest.cc", - "channel/virtual_connection_manager_unittest.cc", "channel/virtual_connection_router_unittest.cc", "public/service_info_unittest.cc", ] diff --git a/cast/common/channel/cast_socket_message_port.cc b/cast/common/channel/cast_socket_message_port.cc index 3bcdca22..0c51304b 100644 --- a/cast/common/channel/cast_socket_message_port.cc +++ b/cast/common/channel/cast_socket_message_port.cc @@ -9,7 +9,6 @@ #include "cast/common/channel/message_util.h" #include "cast/common/channel/proto/cast_channel.pb.h" #include "cast/common/channel/virtual_connection.h" -#include "cast/common/channel/virtual_connection_manager.h" namespace openscreen { namespace cast { @@ -51,8 +50,7 @@ void CastSocketMessagePort::ResetClient() { client_ = nullptr; router_->RemoveHandlerForLocalId(client_sender_id_); - router_->manager()->RemoveConnectionsByLocalId( - client_sender_id_, VirtualConnection::CloseReason::kClosedBySelf); + router_->RemoveConnectionsByLocalId(client_sender_id_); client_sender_id_.clear(); } @@ -72,9 +70,8 @@ void CastSocketMessagePort::PostMessage( VirtualConnection connection{client_sender_id_, destination_sender_id, socket_->socket_id()}; - if (!router_->manager()->GetConnectionData(connection)) { - router_->manager()->AddConnection(connection, - VirtualConnection::AssociatedData{}); + if (!router_->GetConnectionData(connection)) { + router_->AddConnection(connection, VirtualConnection::AssociatedData{}); } const Error send_error = router_->Send( diff --git a/cast/common/channel/connection_namespace_handler.cc b/cast/common/channel/connection_namespace_handler.cc index a449dcbd..dd459060 100644 --- a/cast/common/channel/connection_namespace_handler.cc +++ b/cast/common/channel/connection_namespace_handler.cc @@ -4,6 +4,7 @@ #include "cast/common/channel/connection_namespace_handler.h" +#include <algorithm> #include <string> #include <type_traits> #include <utility> @@ -12,7 +13,6 @@ #include "cast/common/channel/message_util.h" #include "cast/common/channel/proto/cast_channel.pb.h" #include "cast/common/channel/virtual_connection.h" -#include "cast/common/channel/virtual_connection_manager.h" #include "cast/common/channel/virtual_connection_router.h" #include "cast/common/public/cast_socket.h" #include "util/json/json_serialization.h" @@ -82,11 +82,11 @@ VirtualConnection::CloseReason GetCloseReason( } // namespace ConnectionNamespaceHandler::ConnectionNamespaceHandler( - VirtualConnectionManager* vc_manager, + VirtualConnectionRouter* vc_router, VirtualConnectionPolicy* vc_policy) - : vc_manager_(vc_manager), vc_policy_(vc_policy) { - OSP_DCHECK(vc_manager); - OSP_DCHECK(vc_policy); + : vc_router_(vc_router), vc_policy_(vc_policy) { + OSP_DCHECK(vc_router_); + OSP_DCHECK(vc_policy_); } ConnectionNamespaceHandler::~ConnectionNamespaceHandler() = default; @@ -120,17 +120,16 @@ void ConnectionNamespaceHandler::OnMessage(VirtualConnectionRouter* router, absl::string_view type_str = type.value(); if (type_str == kMessageTypeConnect) { - HandleConnect(router, socket, std::move(message), std::move(value)); + HandleConnect(socket, std::move(message), std::move(value)); } else if (type_str == kMessageTypeClose) { - HandleClose(router, socket, std::move(message), std::move(value)); + HandleClose(socket, std::move(message), std::move(value)); } else { // NOTE: Unknown message type so ignore it. // TODO(btolsch): Should be included in future error reporting. } } -void ConnectionNamespaceHandler::HandleConnect(VirtualConnectionRouter* router, - CastSocket* socket, +void ConnectionNamespaceHandler::HandleConnect(CastSocket* socket, CastMessage message, Json::Value parsed_message) { if (message.destination_id() == kBroadcastId || @@ -142,7 +141,7 @@ void ConnectionNamespaceHandler::HandleConnect(VirtualConnectionRouter* router, std::move(message.source_id()), ToCastSocketId(socket)}; if (!vc_policy_->IsConnectionAllowed(virtual_conn)) { - SendClose(router, std::move(virtual_conn)); + SendClose(virtual_conn); return; } @@ -153,7 +152,7 @@ void ConnectionNamespaceHandler::HandleConnect(VirtualConnectionRouter* router, int int_type = maybe_conn_type.value(); if (int_type < static_cast<int>(VirtualConnection::Type::kMinValue) || int_type > static_cast<int>(VirtualConnection::Type::kMaxValue)) { - SendClose(router, std::move(virtual_conn)); + SendClose(virtual_conn); return; } conn_type = static_cast<VirtualConnection::Type>(int_type); @@ -202,20 +201,19 @@ void ConnectionNamespaceHandler::HandleConnect(VirtualConnectionRouter* router, // maintains compatibility with older senders that don't send a version and // don't expect a response. if (negotiated_version) { - SendConnectedResponse(router, virtual_conn, negotiated_version.value()); + SendConnectedResponse(virtual_conn, negotiated_version.value()); } - vc_manager_->AddConnection(std::move(virtual_conn), std::move(data)); + vc_router_->AddConnection(std::move(virtual_conn), std::move(data)); } -void ConnectionNamespaceHandler::HandleClose(VirtualConnectionRouter* router, - CastSocket* socket, +void ConnectionNamespaceHandler::HandleClose(CastSocket* socket, CastMessage message, Json::Value parsed_message) { VirtualConnection virtual_conn{std::move(message.destination_id()), std::move(message.source_id()), ToCastSocketId(socket)}; - if (!vc_manager_->GetConnectionData(virtual_conn)) { + if (!vc_router_->GetConnectionData(virtual_conn)) { return; } @@ -224,11 +222,11 @@ void ConnectionNamespaceHandler::HandleClose(VirtualConnectionRouter* router, OSP_DVLOG << "Connection closed (reason: " << reason << "): " << virtual_conn.local_id << ", " << virtual_conn.peer_id << ", " << virtual_conn.socket_id; - vc_manager_->RemoveConnection(virtual_conn, reason); + vc_router_->RemoveConnection(virtual_conn, reason); } -void ConnectionNamespaceHandler::SendClose(VirtualConnectionRouter* router, - VirtualConnection virtual_conn) { +void ConnectionNamespaceHandler::SendClose( + const VirtualConnection& virtual_conn) { Json::Value close_message(Json::ValueType::objectValue); close_message[kMessageKeyType] = kMessageTypeClose; @@ -237,13 +235,12 @@ void ConnectionNamespaceHandler::SendClose(VirtualConnectionRouter* router, return; } - router->Send( + vc_router_->Send( std::move(virtual_conn), MakeSimpleUTF8Message(kConnectionNamespace, std::move(result.value()))); } void ConnectionNamespaceHandler::SendConnectedResponse( - VirtualConnectionRouter* router, const VirtualConnection& virtual_conn, int max_protocol_version) { Json::Value connected_message(Json::ValueType::objectValue); @@ -256,8 +253,9 @@ void ConnectionNamespaceHandler::SendConnectedResponse( return; } - router->Send(virtual_conn, MakeSimpleUTF8Message(kConnectionNamespace, - std::move(result.value()))); + vc_router_->Send( + virtual_conn, + MakeSimpleUTF8Message(kConnectionNamespace, std::move(result.value()))); } } // namespace cast diff --git a/cast/common/channel/connection_namespace_handler.h b/cast/common/channel/connection_namespace_handler.h index 5307e896..65388a80 100644 --- a/cast/common/channel/connection_namespace_handler.h +++ b/cast/common/channel/connection_namespace_handler.h @@ -13,7 +13,6 @@ namespace openscreen { namespace cast { struct VirtualConnection; -class VirtualConnectionManager; class VirtualConnectionRouter; // Handles CastMessages in the connection namespace by opening and closing @@ -28,8 +27,8 @@ class ConnectionNamespaceHandler final : public CastMessageHandler { const VirtualConnection& virtual_conn) const = 0; }; - // Both |vc_manager| and |vc_policy| should outlive this object. - ConnectionNamespaceHandler(VirtualConnectionManager* vc_manager, + // Both |vc_router| and |vc_policy| should outlive this object. + ConnectionNamespaceHandler(VirtualConnectionRouter* vc_router, VirtualConnectionPolicy* vc_policy); ~ConnectionNamespaceHandler() override; @@ -39,22 +38,18 @@ class ConnectionNamespaceHandler final : public CastMessageHandler { ::cast::channel::CastMessage message) override; private: - void HandleConnect(VirtualConnectionRouter* router, - CastSocket* socket, + void HandleConnect(CastSocket* socket, ::cast::channel::CastMessage message, Json::Value parsed_message); - void HandleClose(VirtualConnectionRouter* router, - CastSocket* socket, + void HandleClose(CastSocket* socket, ::cast::channel::CastMessage message, Json::Value parsed_message); - void SendClose(VirtualConnectionRouter* router, - VirtualConnection virtual_conn); - void SendConnectedResponse(VirtualConnectionRouter* router, - const VirtualConnection& virtual_conn, + void SendClose(const VirtualConnection& virtual_conn); + void SendConnectedResponse(const VirtualConnection& virtual_conn, int max_protocol_version); - VirtualConnectionManager* const vc_manager_; + VirtualConnectionRouter* const vc_router_; VirtualConnectionPolicy* const vc_policy_; }; diff --git a/cast/common/channel/connection_namespace_handler_unittest.cc b/cast/common/channel/connection_namespace_handler_unittest.cc index f57a4253..00b56f8e 100644 --- a/cast/common/channel/connection_namespace_handler_unittest.cc +++ b/cast/common/channel/connection_namespace_handler_unittest.cc @@ -4,11 +4,14 @@ #include "cast/common/channel/connection_namespace_handler.h" +#include <string> +#include <utility> +#include <vector> + #include "cast/common/channel/message_util.h" #include "cast/common/channel/testing/fake_cast_socket.h" #include "cast/common/channel/testing/mock_socket_error_handler.h" #include "cast/common/channel/virtual_connection.h" -#include "cast/common/channel/virtual_connection_manager.h" #include "cast/common/channel/virtual_connection_router.h" #include "cast/common/public/cast_socket.h" #include "gmock/gmock.h" @@ -139,9 +142,8 @@ class ConnectionNamespaceHandlerTest : public ::testing::Test { CastSocket* socket_; NiceMock<MockVirtualConnectionPolicy> vc_policy_; - VirtualConnectionManager vc_manager_; - VirtualConnectionRouter router_{&vc_manager_}; - ConnectionNamespaceHandler connection_namespace_handler_{&vc_manager_, + VirtualConnectionRouter router_; + ConnectionNamespaceHandler connection_namespace_handler_{&router_, &vc_policy_}; const std::string sender_id_{"sender-5678"}; @@ -151,7 +153,7 @@ class ConnectionNamespaceHandlerTest : public ::testing::Test { TEST_F(ConnectionNamespaceHandlerTest, Connect) { connection_namespace_handler_.OnMessage( &router_, socket_, MakeConnectMessage(sender_id_, receiver_id_)); - EXPECT_TRUE(vc_manager_.GetConnectionData( + EXPECT_TRUE(router_.GetConnectionData( VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()})); EXPECT_CALL(fake_cast_socket_pair_.mock_peer_client, OnMessage(_, _)) @@ -166,7 +168,7 @@ TEST_F(ConnectionNamespaceHandlerTest, PolicyDeniesConnection) { sender_id_); connection_namespace_handler_.OnMessage( &router_, socket_, MakeConnectMessage(sender_id_, receiver_id_)); - EXPECT_FALSE(vc_manager_.GetConnectionData( + EXPECT_FALSE(router_.GetConnectionData( VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()})); } @@ -179,7 +181,7 @@ TEST_F(ConnectionNamespaceHandlerTest, ConnectWithVersion) { MakeVersionedConnectMessage( sender_id_, receiver_id_, ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2, {})); - EXPECT_TRUE(vc_manager_.GetConnectionData( + EXPECT_TRUE(router_.GetConnectionData( VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()})); } @@ -194,31 +196,31 @@ TEST_F(ConnectionNamespaceHandlerTest, ConnectWithVersionList) { ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2, {::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_3, ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0})); - EXPECT_TRUE(vc_manager_.GetConnectionData( + EXPECT_TRUE(router_.GetConnectionData( VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()})); } TEST_F(ConnectionNamespaceHandlerTest, Close) { connection_namespace_handler_.OnMessage( &router_, socket_, MakeConnectMessage(sender_id_, receiver_id_)); - EXPECT_TRUE(vc_manager_.GetConnectionData( + EXPECT_TRUE(router_.GetConnectionData( VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()})); connection_namespace_handler_.OnMessage( &router_, socket_, MakeCloseMessage(sender_id_, receiver_id_)); - EXPECT_FALSE(vc_manager_.GetConnectionData( + EXPECT_FALSE(router_.GetConnectionData( VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()})); } TEST_F(ConnectionNamespaceHandlerTest, CloseUnknown) { connection_namespace_handler_.OnMessage( &router_, socket_, MakeConnectMessage(sender_id_, receiver_id_)); - EXPECT_TRUE(vc_manager_.GetConnectionData( + EXPECT_TRUE(router_.GetConnectionData( VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()})); connection_namespace_handler_.OnMessage( &router_, socket_, MakeCloseMessage(sender_id_ + "098", receiver_id_)); - EXPECT_TRUE(vc_manager_.GetConnectionData( + EXPECT_TRUE(router_.GetConnectionData( VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()})); } diff --git a/cast/common/channel/namespace_router_unittest.cc b/cast/common/channel/namespace_router_unittest.cc index 96907c08..7f4ce17e 100644 --- a/cast/common/channel/namespace_router_unittest.cc +++ b/cast/common/channel/namespace_router_unittest.cc @@ -4,11 +4,12 @@ #include "cast/common/channel/namespace_router.h" +#include <utility> + #include "cast/common/channel/cast_message_handler.h" #include "cast/common/channel/proto/cast_channel.pb.h" #include "cast/common/channel/testing/fake_cast_socket.h" #include "cast/common/channel/testing/mock_cast_message_handler.h" -#include "cast/common/channel/virtual_connection_manager.h" #include "cast/common/channel/virtual_connection_router.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -27,8 +28,7 @@ class NamespaceRouterTest : public ::testing::Test { CastSocket* socket() { return &fake_socket_.socket; } FakeCastSocket fake_socket_; - VirtualConnectionManager vc_manager_; - VirtualConnectionRouter vc_router_{&vc_manager_}; + VirtualConnectionRouter vc_router_; NamespaceRouter router_; }; diff --git a/cast/common/channel/virtual_connection_manager.cc b/cast/common/channel/virtual_connection_manager.cc deleted file mode 100644 index 86e06706..00000000 --- a/cast/common/channel/virtual_connection_manager.cc +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright 2019 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "cast/common/channel/virtual_connection_manager.h" - -#include <type_traits> - -namespace openscreen { -namespace cast { - -VirtualConnectionManager::VirtualConnectionManager() = default; - -VirtualConnectionManager::~VirtualConnectionManager() = default; - -void VirtualConnectionManager::AddConnection( - VirtualConnection virtual_connection, - VirtualConnection::AssociatedData associated_data) { - auto& socket_map = connections_[virtual_connection.socket_id]; - auto local_entries = socket_map.equal_range(virtual_connection.local_id); - auto it = std::find_if( - local_entries.first, local_entries.second, - [&virtual_connection](const std::pair<std::string, VCTail>& entry) { - return entry.second.peer_id == virtual_connection.peer_id; - }); - if (it == socket_map.end()) { - socket_map.emplace(std::move(virtual_connection.local_id), - VCTail{std::move(virtual_connection.peer_id), - std::move(associated_data)}); - } -} - -bool VirtualConnectionManager::RemoveConnection( - const VirtualConnection& virtual_connection, - VirtualConnection::CloseReason reason) { - auto socket_entry = connections_.find(virtual_connection.socket_id); - if (socket_entry == connections_.end()) { - return false; - } - - auto& socket_map = socket_entry->second; - auto local_entries = socket_map.equal_range(virtual_connection.local_id); - if (local_entries.first == socket_map.end()) { - return false; - } - for (auto it = local_entries.first; it != local_entries.second; ++it) { - if (it->second.peer_id == virtual_connection.peer_id) { - socket_map.erase(it); - if (socket_map.empty()) { - connections_.erase(socket_entry); - } - return true; - } - } - return false; -} - -size_t VirtualConnectionManager::RemoveConnectionsByLocalId( - const std::string& local_id, - VirtualConnection::CloseReason reason) { - size_t removed_count = 0; - for (auto socket_entry = connections_.begin(); - socket_entry != connections_.end();) { - auto& socket_map = socket_entry->second; - auto local_entries = socket_map.equal_range(local_id); - if (local_entries.first != socket_map.end()) { - size_t current_count = - std::distance(local_entries.first, local_entries.second); - removed_count += current_count; - socket_map.erase(local_entries.first, local_entries.second); - if (socket_map.empty()) { - socket_entry = connections_.erase(socket_entry); - } else { - ++socket_entry; - } - } else { - ++socket_entry; - } - } - return removed_count; -} - -size_t VirtualConnectionManager::RemoveConnectionsBySocketId( - int socket_id, - VirtualConnection::CloseReason reason) { - auto entry = connections_.find(socket_id); - if (entry == connections_.end()) { - return 0; - } - - size_t removed_count = entry->second.size(); - connections_.erase(entry); - - return removed_count; -} - -absl::optional<const VirtualConnection::AssociatedData*> -VirtualConnectionManager::GetConnectionData( - const VirtualConnection& virtual_connection) const { - auto socket_entry = connections_.find(virtual_connection.socket_id); - if (socket_entry == connections_.end()) { - return absl::nullopt; - } - - auto& socket_map = socket_entry->second; - auto local_entries = socket_map.equal_range(virtual_connection.local_id); - if (local_entries.first == socket_map.end()) { - return absl::nullopt; - } - for (auto it = local_entries.first; it != local_entries.second; ++it) { - if (it->second.peer_id == virtual_connection.peer_id) { - return &it->second.data; - } - } - return absl::nullopt; -} - -} // namespace cast -} // namespace openscreen diff --git a/cast/common/channel/virtual_connection_manager.h b/cast/common/channel/virtual_connection_manager.h deleted file mode 100644 index 902d2a96..00000000 --- a/cast/common/channel/virtual_connection_manager.h +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2019 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef CAST_COMMON_CHANNEL_VIRTUAL_CONNECTION_MANAGER_H_ -#define CAST_COMMON_CHANNEL_VIRTUAL_CONNECTION_MANAGER_H_ - -#include <cstdint> -#include <map> -#include <string> - -#include "absl/types/optional.h" -#include "cast/common/channel/virtual_connection.h" - -namespace openscreen { -namespace cast { - -// Maintains a collection of open VirtualConnections and associated data. -class VirtualConnectionManager { - public: - VirtualConnectionManager(); - ~VirtualConnectionManager(); - - void AddConnection(VirtualConnection virtual_connection, - VirtualConnection::AssociatedData associated_data); - - // Returns true if a connection matching |virtual_connection| was found and - // removed. - bool RemoveConnection(const VirtualConnection& virtual_connection, - VirtualConnection::CloseReason reason); - - // Returns the number of connections removed. - size_t RemoveConnectionsByLocalId(const std::string& local_id, - VirtualConnection::CloseReason reason); - size_t RemoveConnectionsBySocketId(int socket_id, - VirtualConnection::CloseReason reason); - - // Returns the AssociatedData for |virtual_connection| if a connection exists, - // nullopt otherwise. The pointer isn't stable in the long term, so if it - // actually needs to be stored for later, the caller should make a copy. - absl::optional<const VirtualConnection::AssociatedData*> GetConnectionData( - const VirtualConnection& virtual_connection) const; - - private: - // This struct simply stores the remainder of the data {VirtualConnection, - // VirtVirtualConnection::AssociatedData} that is not broken up into map keys - // for |connections_|. - struct VCTail { - std::string peer_id; - VirtualConnection::AssociatedData data; - }; - - std::map<int /* socket_id */, - std::multimap<std::string /* local_id */, VCTail>> - connections_; -}; - -} // namespace cast -} // namespace openscreen - -#endif // CAST_COMMON_CHANNEL_VIRTUAL_CONNECTION_MANAGER_H_ diff --git a/cast/common/channel/virtual_connection_manager_unittest.cc b/cast/common/channel/virtual_connection_manager_unittest.cc deleted file mode 100644 index 963fcac7..00000000 --- a/cast/common/channel/virtual_connection_manager_unittest.cc +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright 2019 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "cast/common/channel/virtual_connection_manager.h" - -#include "cast/common/channel/proto/cast_channel.pb.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -namespace openscreen { -namespace cast { -namespace { - -static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0 == - static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_0), - "V2 1.0 constants must be equal"); -static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_1 == - static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_1), - "V2 1.1 constants must be equal"); -static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2 == - static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_2), - "V2 1.2 constants must be equal"); -static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_3 == - static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_3), - "V2 1.3 constants must be equal"); - -using ::testing::_; -using ::testing::Invoke; - -class VirtualConnectionManagerTest : public ::testing::Test { - protected: - VirtualConnectionManager manager_; - - VirtualConnection vc1_{"local1", "peer1", 75}; - VirtualConnection vc2_{"local2", "peer2", 76}; - VirtualConnection vc3_{"local1", "peer3", 75}; -}; - -} // namespace - -TEST_F(VirtualConnectionManagerTest, NoConnections) { - EXPECT_FALSE(manager_.GetConnectionData(vc1_)); - EXPECT_FALSE(manager_.GetConnectionData(vc2_)); - EXPECT_FALSE(manager_.GetConnectionData(vc3_)); -} - -TEST_F(VirtualConnectionManagerTest, AddConnections) { - VirtualConnection::AssociatedData data1 = {}; - - manager_.AddConnection(vc1_, std::move(data1)); - EXPECT_TRUE(manager_.GetConnectionData(vc1_)); - EXPECT_FALSE(manager_.GetConnectionData(vc2_)); - EXPECT_FALSE(manager_.GetConnectionData(vc3_)); - - VirtualConnection::AssociatedData data2 = {}; - manager_.AddConnection(vc2_, std::move(data2)); - EXPECT_TRUE(manager_.GetConnectionData(vc1_)); - EXPECT_TRUE(manager_.GetConnectionData(vc2_)); - EXPECT_FALSE(manager_.GetConnectionData(vc3_)); - - VirtualConnection::AssociatedData data3 = {}; - manager_.AddConnection(vc3_, std::move(data3)); - EXPECT_TRUE(manager_.GetConnectionData(vc1_)); - EXPECT_TRUE(manager_.GetConnectionData(vc2_)); - EXPECT_TRUE(manager_.GetConnectionData(vc3_)); -} - -TEST_F(VirtualConnectionManagerTest, RemoveConnections) { - VirtualConnection::AssociatedData data1 = {}; - VirtualConnection::AssociatedData data2 = {}; - VirtualConnection::AssociatedData data3 = {}; - - manager_.AddConnection(vc1_, std::move(data1)); - manager_.AddConnection(vc2_, std::move(data2)); - manager_.AddConnection(vc3_, std::move(data3)); - - EXPECT_TRUE(manager_.RemoveConnection( - vc1_, VirtualConnection::CloseReason::kClosedBySelf)); - EXPECT_FALSE(manager_.GetConnectionData(vc1_)); - EXPECT_TRUE(manager_.GetConnectionData(vc2_)); - EXPECT_TRUE(manager_.GetConnectionData(vc3_)); - - EXPECT_TRUE(manager_.RemoveConnection( - vc2_, VirtualConnection::CloseReason::kClosedBySelf)); - EXPECT_FALSE(manager_.GetConnectionData(vc1_)); - EXPECT_FALSE(manager_.GetConnectionData(vc2_)); - EXPECT_TRUE(manager_.GetConnectionData(vc3_)); - - EXPECT_TRUE(manager_.RemoveConnection( - vc3_, VirtualConnection::CloseReason::kClosedBySelf)); - EXPECT_FALSE(manager_.GetConnectionData(vc1_)); - EXPECT_FALSE(manager_.GetConnectionData(vc2_)); - EXPECT_FALSE(manager_.GetConnectionData(vc3_)); - - EXPECT_FALSE(manager_.RemoveConnection( - vc1_, VirtualConnection::CloseReason::kClosedBySelf)); - EXPECT_FALSE(manager_.RemoveConnection( - vc2_, VirtualConnection::CloseReason::kClosedBySelf)); - EXPECT_FALSE(manager_.RemoveConnection( - vc3_, VirtualConnection::CloseReason::kClosedBySelf)); -} - -TEST_F(VirtualConnectionManagerTest, RemoveConnectionsByIds) { - VirtualConnection::AssociatedData data1 = {}; - VirtualConnection::AssociatedData data2 = {}; - VirtualConnection::AssociatedData data3 = {}; - - manager_.AddConnection(vc1_, std::move(data1)); - manager_.AddConnection(vc2_, std::move(data2)); - manager_.AddConnection(vc3_, std::move(data3)); - - EXPECT_EQ(manager_.RemoveConnectionsByLocalId( - "local1", VirtualConnection::CloseReason::kClosedBySelf), - 2u); - EXPECT_FALSE(manager_.GetConnectionData(vc1_)); - EXPECT_TRUE(manager_.GetConnectionData(vc2_)); - EXPECT_FALSE(manager_.GetConnectionData(vc3_)); - - data1 = {}; - data2 = {}; - data3 = {}; - manager_.AddConnection(vc1_, std::move(data1)); - manager_.AddConnection(vc2_, std::move(data2)); - manager_.AddConnection(vc3_, std::move(data3)); - EXPECT_EQ(manager_.RemoveConnectionsBySocketId( - 76, VirtualConnection::CloseReason::kClosedBySelf), - 1u); - EXPECT_TRUE(manager_.GetConnectionData(vc1_)); - EXPECT_FALSE(manager_.GetConnectionData(vc2_)); - EXPECT_TRUE(manager_.GetConnectionData(vc3_)); - - EXPECT_EQ(manager_.RemoveConnectionsBySocketId( - 75, VirtualConnection::CloseReason::kClosedBySelf), - 2u); - EXPECT_FALSE(manager_.GetConnectionData(vc1_)); - EXPECT_FALSE(manager_.GetConnectionData(vc2_)); - EXPECT_FALSE(manager_.GetConnectionData(vc3_)); -} - -} // namespace cast -} // namespace openscreen diff --git a/cast/common/channel/virtual_connection_router.cc b/cast/common/channel/virtual_connection_router.cc index 140ca138..98e6d19d 100644 --- a/cast/common/channel/virtual_connection_router.cc +++ b/cast/common/channel/virtual_connection_router.cc @@ -9,7 +9,6 @@ #include "cast/common/channel/cast_message_handler.h" #include "cast/common/channel/message_util.h" #include "cast/common/channel/proto/cast_channel.pb.h" -#include "cast/common/channel/virtual_connection_manager.h" #include "util/osp_logging.h" namespace openscreen { @@ -17,14 +16,97 @@ namespace cast { using ::cast::channel::CastMessage; -VirtualConnectionRouter::VirtualConnectionRouter( - VirtualConnectionManager* vc_manager) - : vc_manager_(vc_manager) { - OSP_DCHECK(vc_manager); -} +VirtualConnectionRouter::VirtualConnectionRouter() = default; VirtualConnectionRouter::~VirtualConnectionRouter() = default; +void VirtualConnectionRouter::AddConnection( + VirtualConnection virtual_connection, + VirtualConnection::AssociatedData associated_data) { + auto& socket_map = connections_[virtual_connection.socket_id]; + auto local_entries = socket_map.equal_range(virtual_connection.local_id); + auto it = std::find_if( + local_entries.first, local_entries.second, + [&virtual_connection](const std::pair<std::string, VCTail>& entry) { + return entry.second.peer_id == virtual_connection.peer_id; + }); + if (it == socket_map.end()) { + socket_map.emplace(std::move(virtual_connection.local_id), + VCTail{std::move(virtual_connection.peer_id), + std::move(associated_data)}); + } +} + +bool VirtualConnectionRouter::RemoveConnection( + const VirtualConnection& virtual_connection, + VirtualConnection::CloseReason reason) { + auto socket_entry = connections_.find(virtual_connection.socket_id); + if (socket_entry == connections_.end()) { + return false; + } + + auto& socket_map = socket_entry->second; + auto local_entries = socket_map.equal_range(virtual_connection.local_id); + if (local_entries.first == socket_map.end()) { + return false; + } + for (auto it = local_entries.first; it != local_entries.second; ++it) { + if (it->second.peer_id == virtual_connection.peer_id) { + socket_map.erase(it); + if (socket_map.empty()) { + connections_.erase(socket_entry); + } + return true; + } + } + return false; +} + +void VirtualConnectionRouter::RemoveConnectionsByLocalId( + const std::string& local_id) { + for (auto socket_entry = connections_.begin(); + socket_entry != connections_.end();) { + auto& socket_map = socket_entry->second; + auto local_entries = socket_map.equal_range(local_id); + if (local_entries.first != socket_map.end()) { + socket_map.erase(local_entries.first, local_entries.second); + if (socket_map.empty()) { + socket_entry = connections_.erase(socket_entry); + continue; + } + } + ++socket_entry; + } +} + +void VirtualConnectionRouter::RemoveConnectionsBySocketId(int socket_id) { + auto entry = connections_.find(socket_id); + if (entry != connections_.end()) { + connections_.erase(entry); + } +} + +absl::optional<const VirtualConnection::AssociatedData*> +VirtualConnectionRouter::GetConnectionData( + const VirtualConnection& virtual_connection) const { + auto socket_entry = connections_.find(virtual_connection.socket_id); + if (socket_entry == connections_.end()) { + return absl::nullopt; + } + + auto& socket_map = socket_entry->second; + auto local_entries = socket_map.equal_range(virtual_connection.local_id); + if (local_entries.first == socket_map.end()) { + return absl::nullopt; + } + for (auto it = local_entries.first; it != local_entries.second; ++it) { + if (it->second.peer_id == virtual_connection.peer_id) { + return &it->second.data; + } + } + return absl::nullopt; +} + bool VirtualConnectionRouter::AddHandlerForLocalId( std::string local_id, CastMessageHandler* endpoint) { @@ -46,8 +128,7 @@ void VirtualConnectionRouter::TakeSocket(SocketErrorHandler* error_handler, void VirtualConnectionRouter::CloseSocket(int id) { auto it = sockets_.find(id); if (it != sockets_.end()) { - vc_manager_->RemoveConnectionsBySocketId( - id, VirtualConnection::kTransportClosed); + RemoveConnectionsBySocketId(id); std::unique_ptr<CastSocket> socket = std::move(it->second.socket); SocketErrorHandler* error_handler = it->second.error_handler; sockets_.erase(it); @@ -63,7 +144,7 @@ Error VirtualConnectionRouter::Send(VirtualConnection virtual_conn, } if (!IsTransportNamespace(message.namespace_()) && - !vc_manager_->GetConnectionData(virtual_conn)) { + !GetConnectionData(virtual_conn)) { return Error::Code::kNoActiveConnection; } auto it = sockets_.find(virtual_conn.socket_id); @@ -104,7 +185,7 @@ void VirtualConnectionRouter::OnError(CastSocket* socket, Error error) { const int id = socket->socket_id(); auto it = sockets_.find(id); if (it != sockets_.end()) { - vc_manager_->RemoveConnectionsBySocketId(id, VirtualConnection::kUnknown); + RemoveConnectionsBySocketId(id); std::unique_ptr<CastSocket> socket_owned = std::move(it->second.socket); SocketErrorHandler* error_handler = it->second.error_handler; sockets_.erase(it); @@ -123,8 +204,8 @@ void VirtualConnectionRouter::OnMessage(CastSocket* socket, } } else { if (!IsTransportNamespace(message.namespace_()) && - !vc_manager_->GetConnectionData(VirtualConnection{ - local_id, message.source_id(), socket->socket_id()})) { + !GetConnectionData(VirtualConnection{local_id, message.source_id(), + socket->socket_id()})) { return; } auto it = endpoints_.find(local_id); diff --git a/cast/common/channel/virtual_connection_router.h b/cast/common/channel/virtual_connection_router.h index 1bbf2bc1..5080e948 100644 --- a/cast/common/channel/virtual_connection_router.h +++ b/cast/common/channel/virtual_connection_router.h @@ -10,15 +10,15 @@ #include <memory> #include <string> +#include "absl/types/optional.h" #include "cast/common/channel/proto/cast_channel.pb.h" +#include "cast/common/channel/virtual_connection.h" #include "cast/common/public/cast_socket.h" namespace openscreen { namespace cast { class CastMessageHandler; -struct VirtualConnection; -class VirtualConnectionManager; // Handles CastSockets by routing received messages to appropriate message // handlers based on the VirtualConnection's local ID and sending messages over @@ -37,8 +37,13 @@ class VirtualConnectionManager; // // 4. Anything Foo wants to send (launch, app availability, etc.) goes through // VCRouter::Send via an appropriate VC. The virtual connection is not -// created automatically, so Foo should either ensure its existence via logic -// or check with the VirtualConnectionManager first. +// created automatically, so AddConnection() must be called first. +// +// 5. Anything Foo wants to receive must be registered with a handler by calling +// AddHandlerForLocalId(). +// +// 6. Foo is expected to clean-up after itself (#4 and #5) by calling +// RemoveConnection() and RemoveHandlerForLocalId(). class VirtualConnectionRouter final : public CastSocket::Client { public: class SocketErrorHandler { @@ -47,10 +52,39 @@ class VirtualConnectionRouter final : public CastSocket::Client { virtual void OnError(CastSocket* socket, Error error) = 0; }; - explicit VirtualConnectionRouter(VirtualConnectionManager* vc_manager); + VirtualConnectionRouter(); ~VirtualConnectionRouter() override; - // These return whether the given |local_id| was successfully added/removed. + // Adds a VirtualConnection, if one does not already exist, to enable routing + // of peer-to-peer messages. + void AddConnection(VirtualConnection virtual_connection, + VirtualConnection::AssociatedData associated_data); + + // Removes a VirtualConnection and returns true if a connection matching + // |virtual_connection| was found and removed. + bool RemoveConnection(const VirtualConnection& virtual_connection, + VirtualConnection::CloseReason reason); + + // Removes all VirtualConnections whose local endpoint matches the given + // |local_id|. + void RemoveConnectionsByLocalId(const std::string& local_id); + + // Removes all VirtualConnections whose traffic passes over the socket + // referenced by |socket_id|. + void RemoveConnectionsBySocketId(int socket_id); + + // Returns the AssociatedData for a |virtual_connection| if a connection + // exists, nullopt otherwise. The pointer isn't stable in the long term; so, + // if it actually needs to be stored for later, the caller should make a copy. + absl::optional<const VirtualConnection::AssociatedData*> GetConnectionData( + const VirtualConnection& virtual_connection) const; + + // Adds/Removes a CastMessageHandler for all messages destined for the given + // |endpoint| referred to by |local_id|, and returns whether the given + // |local_id| was successfully added/removed. + // + // Note: Clients will need to separately call AddConnection(), and + // RemoveConnection() or RemoveConnectionsByLocalId(). bool AddHandlerForLocalId(std::string local_id, CastMessageHandler* endpoint); bool RemoveHandlerForLocalId(const std::string& local_id); @@ -70,15 +104,23 @@ class VirtualConnectionRouter final : public CastSocket::Client { void OnMessage(CastSocket* socket, ::cast::channel::CastMessage message) override; - VirtualConnectionManager* manager() { return vc_manager_; } - private: + // This struct simply stores the remainder of the data {VirtualConnection, + // VirtualConnection::AssociatedData} that is not broken up into map keys for + // |connections_|. + struct VCTail { + std::string peer_id; + VirtualConnection::AssociatedData data; + }; + struct SocketWithHandler { std::unique_ptr<CastSocket> socket; SocketErrorHandler* error_handler; }; - VirtualConnectionManager* const vc_manager_; + std::map<int /* socket_id */, + std::multimap<std::string /* local_id */, VCTail>> + connections_; std::map<int, SocketWithHandler> sockets_; std::map<std::string /* local_id */, CastMessageHandler*> endpoints_; }; diff --git a/cast/common/channel/virtual_connection_router_unittest.cc b/cast/common/channel/virtual_connection_router_unittest.cc index b05d10e3..69a08c53 100644 --- a/cast/common/channel/virtual_connection_router_unittest.cc +++ b/cast/common/channel/virtual_connection_router_unittest.cc @@ -11,7 +11,6 @@ #include "cast/common/channel/testing/fake_cast_socket.h" #include "cast/common/channel/testing/mock_cast_message_handler.h" #include "cast/common/channel/testing/mock_socket_error_handler.h" -#include "cast/common/channel/virtual_connection_manager.h" #include "cast/common/public/cast_socket.h" #include "gtest/gtest.h" @@ -19,6 +18,19 @@ namespace openscreen { namespace cast { namespace { +static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0 == + static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_0), + "V2 1.0 constants must be equal"); +static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_1 == + static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_1), + "V2 1.1 constants must be equal"); +static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2 == + static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_2), + "V2 1.2 constants must be equal"); +static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_3 == + static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_3), + "V2 1.3 constants must be equal"); + using ::cast::channel::CastMessage; using ::testing::_; using ::testing::Invoke; @@ -44,21 +56,115 @@ class VirtualConnectionRouterTest : public ::testing::Test { MockSocketErrorHandler mock_error_handler_; - VirtualConnectionManager local_manager_; - VirtualConnectionRouter local_router_{&local_manager_}; + VirtualConnectionRouter local_router_; + VirtualConnectionRouter remote_router_; - VirtualConnectionManager remote_manager_; - VirtualConnectionRouter remote_router_{&remote_manager_}; + VirtualConnection vc1_{"local1", "peer1", 75}; + VirtualConnection vc2_{"local2", "peer2", 76}; + VirtualConnection vc3_{"local1", "peer3", 75}; }; } // namespace +TEST_F(VirtualConnectionRouterTest, NoConnections) { + EXPECT_FALSE(local_router_.GetConnectionData(vc1_)); + EXPECT_FALSE(local_router_.GetConnectionData(vc2_)); + EXPECT_FALSE(local_router_.GetConnectionData(vc3_)); +} + +TEST_F(VirtualConnectionRouterTest, AddConnections) { + VirtualConnection::AssociatedData data1 = {}; + + local_router_.AddConnection(vc1_, std::move(data1)); + EXPECT_TRUE(local_router_.GetConnectionData(vc1_)); + EXPECT_FALSE(local_router_.GetConnectionData(vc2_)); + EXPECT_FALSE(local_router_.GetConnectionData(vc3_)); + + VirtualConnection::AssociatedData data2 = {}; + local_router_.AddConnection(vc2_, std::move(data2)); + EXPECT_TRUE(local_router_.GetConnectionData(vc1_)); + EXPECT_TRUE(local_router_.GetConnectionData(vc2_)); + EXPECT_FALSE(local_router_.GetConnectionData(vc3_)); + + VirtualConnection::AssociatedData data3 = {}; + local_router_.AddConnection(vc3_, std::move(data3)); + EXPECT_TRUE(local_router_.GetConnectionData(vc1_)); + EXPECT_TRUE(local_router_.GetConnectionData(vc2_)); + EXPECT_TRUE(local_router_.GetConnectionData(vc3_)); +} + +TEST_F(VirtualConnectionRouterTest, RemoveConnections) { + VirtualConnection::AssociatedData data1 = {}; + VirtualConnection::AssociatedData data2 = {}; + VirtualConnection::AssociatedData data3 = {}; + + local_router_.AddConnection(vc1_, std::move(data1)); + local_router_.AddConnection(vc2_, std::move(data2)); + local_router_.AddConnection(vc3_, std::move(data3)); + + EXPECT_TRUE(local_router_.RemoveConnection( + vc1_, VirtualConnection::CloseReason::kClosedBySelf)); + EXPECT_FALSE(local_router_.GetConnectionData(vc1_)); + EXPECT_TRUE(local_router_.GetConnectionData(vc2_)); + EXPECT_TRUE(local_router_.GetConnectionData(vc3_)); + + EXPECT_TRUE(local_router_.RemoveConnection( + vc2_, VirtualConnection::CloseReason::kClosedBySelf)); + EXPECT_FALSE(local_router_.GetConnectionData(vc1_)); + EXPECT_FALSE(local_router_.GetConnectionData(vc2_)); + EXPECT_TRUE(local_router_.GetConnectionData(vc3_)); + + EXPECT_TRUE(local_router_.RemoveConnection( + vc3_, VirtualConnection::CloseReason::kClosedBySelf)); + EXPECT_FALSE(local_router_.GetConnectionData(vc1_)); + EXPECT_FALSE(local_router_.GetConnectionData(vc2_)); + EXPECT_FALSE(local_router_.GetConnectionData(vc3_)); + + EXPECT_FALSE(local_router_.RemoveConnection( + vc1_, VirtualConnection::CloseReason::kClosedBySelf)); + EXPECT_FALSE(local_router_.RemoveConnection( + vc2_, VirtualConnection::CloseReason::kClosedBySelf)); + EXPECT_FALSE(local_router_.RemoveConnection( + vc3_, VirtualConnection::CloseReason::kClosedBySelf)); +} + +TEST_F(VirtualConnectionRouterTest, RemoveConnectionsByIds) { + VirtualConnection::AssociatedData data1 = {}; + VirtualConnection::AssociatedData data2 = {}; + VirtualConnection::AssociatedData data3 = {}; + + local_router_.AddConnection(vc1_, std::move(data1)); + local_router_.AddConnection(vc2_, std::move(data2)); + local_router_.AddConnection(vc3_, std::move(data3)); + + local_router_.RemoveConnectionsByLocalId("local1"); + EXPECT_FALSE(local_router_.GetConnectionData(vc1_)); + EXPECT_TRUE(local_router_.GetConnectionData(vc2_)); + EXPECT_FALSE(local_router_.GetConnectionData(vc3_)); + + data1 = {}; + data2 = {}; + data3 = {}; + local_router_.AddConnection(vc1_, std::move(data1)); + local_router_.AddConnection(vc2_, std::move(data2)); + local_router_.AddConnection(vc3_, std::move(data3)); + local_router_.RemoveConnectionsBySocketId(76); + EXPECT_TRUE(local_router_.GetConnectionData(vc1_)); + EXPECT_FALSE(local_router_.GetConnectionData(vc2_)); + EXPECT_TRUE(local_router_.GetConnectionData(vc3_)); + + local_router_.RemoveConnectionsBySocketId(75); + EXPECT_FALSE(local_router_.GetConnectionData(vc1_)); + EXPECT_FALSE(local_router_.GetConnectionData(vc2_)); + EXPECT_FALSE(local_router_.GetConnectionData(vc3_)); +} + TEST_F(VirtualConnectionRouterTest, LocalIdHandler) { MockCastMessageHandler mock_message_handler; local_router_.AddHandlerForLocalId("receiver-1234", &mock_message_handler); - local_manager_.AddConnection(VirtualConnection{"receiver-1234", "sender-9873", - local_socket_->socket_id()}, - {}); + local_router_.AddConnection(VirtualConnection{"receiver-1234", "sender-9873", + local_socket_->socket_id()}, + {}); CastMessage message; message.set_protocol_version( @@ -84,9 +190,9 @@ TEST_F(VirtualConnectionRouterTest, LocalIdHandler) { TEST_F(VirtualConnectionRouterTest, RemoveLocalIdHandler) { MockCastMessageHandler mock_message_handler; local_router_.AddHandlerForLocalId("receiver-1234", &mock_message_handler); - local_manager_.AddConnection(VirtualConnection{"receiver-1234", "sender-9873", - local_socket_->socket_id()}, - {}); + local_router_.AddConnection(VirtualConnection{"receiver-1234", "sender-9873", + local_socket_->socket_id()}, + {}); CastMessage message; message.set_protocol_version( @@ -108,16 +214,15 @@ TEST_F(VirtualConnectionRouterTest, RemoveLocalIdHandler) { } TEST_F(VirtualConnectionRouterTest, SendMessage) { - local_manager_.AddConnection(VirtualConnection{"receiver-1234", "sender-4321", - local_socket_->socket_id()}, - {}); + local_router_.AddConnection(VirtualConnection{"receiver-1234", "sender-4321", + local_socket_->socket_id()}, + {}); MockCastMessageHandler destination; remote_router_.AddHandlerForLocalId("sender-4321", &destination); - remote_manager_.AddConnection( - VirtualConnection{"sender-4321", "receiver-1234", - remote_socket_->socket_id()}, - {}); + remote_router_.AddConnection(VirtualConnection{"sender-4321", "receiver-1234", + remote_socket_->socket_id()}, + {}); CastMessage message; message.set_protocol_version( @@ -142,15 +247,15 @@ TEST_F(VirtualConnectionRouterTest, SendMessage) { } TEST_F(VirtualConnectionRouterTest, CloseSocketRemovesVirtualConnections) { - local_manager_.AddConnection(VirtualConnection{"receiver-1234", "sender-4321", - local_socket_->socket_id()}, - {}); + local_router_.AddConnection(VirtualConnection{"receiver-1234", "sender-4321", + local_socket_->socket_id()}, + {}); EXPECT_CALL(mock_error_handler_, OnClose(local_socket_)).Times(1); int id = local_socket_->socket_id(); local_router_.CloseSocket(id); - EXPECT_FALSE(local_manager_.GetConnectionData( + EXPECT_FALSE(local_router_.GetConnectionData( VirtualConnection{"receiver-1234", "sender-4321", id})); } diff --git a/cast/receiver/application_agent.cc b/cast/receiver/application_agent.cc index ec8a44a4..06ca3095 100644 --- a/cast/receiver/application_agent.cc +++ b/cast/receiver/application_agent.cc @@ -50,8 +50,7 @@ ApplicationAgent::ApplicationAgent( DeviceAuthNamespaceHandler::CredentialsProvider* credentials_provider) : task_runner_(task_runner), auth_handler_(credentials_provider), - connection_handler_(&connection_manager_, this), - router_(&connection_manager_), + connection_handler_(&router_, this), message_port_(&router_) { router_.AddHandlerForLocalId(kPlatformReceiverId, this); } diff --git a/cast/receiver/application_agent.h b/cast/receiver/application_agent.h index a6a792aa..10818f99 100644 --- a/cast/receiver/application_agent.h +++ b/cast/receiver/application_agent.h @@ -12,7 +12,6 @@ #include "cast/common/channel/cast_socket_message_port.h" #include "cast/common/channel/connection_namespace_handler.h" -#include "cast/common/channel/virtual_connection_manager.h" #include "cast/common/channel/virtual_connection_router.h" #include "cast/common/public/cast_socket.h" #include "cast/receiver/channel/device_auth_namespace_handler.h" @@ -156,9 +155,8 @@ class ApplicationAgent final TaskRunner* const task_runner_; DeviceAuthNamespaceHandler auth_handler_; - VirtualConnectionManager connection_manager_; - ConnectionNamespaceHandler connection_handler_; VirtualConnectionRouter router_; + ConnectionNamespaceHandler connection_handler_; std::map<std::string, Application*> registered_applications_; Application* idle_screen_app_ = nullptr; diff --git a/cast/receiver/channel/device_auth_namespace_handler_unittest.cc b/cast/receiver/channel/device_auth_namespace_handler_unittest.cc index edfc807b..7d4b736c 100644 --- a/cast/receiver/channel/device_auth_namespace_handler_unittest.cc +++ b/cast/receiver/channel/device_auth_namespace_handler_unittest.cc @@ -11,7 +11,6 @@ #include "cast/common/channel/proto/cast_channel.pb.h" #include "cast/common/channel/testing/fake_cast_socket.h" #include "cast/common/channel/testing/mock_socket_error_handler.h" -#include "cast/common/channel/virtual_connection_manager.h" #include "cast/common/channel/virtual_connection_router.h" #include "cast/common/public/cast_socket.h" #include "cast/receiver/channel/static_credentials.h" @@ -55,8 +54,7 @@ class DeviceAuthNamespaceHandlerTest : public ::testing::Test { CastSocket* socket_; StaticCredentialsProvider creds_; - VirtualConnectionManager manager_; - VirtualConnectionRouter router_{&manager_}; + VirtualConnectionRouter router_; DeviceAuthNamespaceHandler auth_handler_{&creds_}; }; diff --git a/cast/sender/cast_app_discovery_service_impl_unittest.cc b/cast/sender/cast_app_discovery_service_impl_unittest.cc index 9f2f8793..a2eeb046 100644 --- a/cast/sender/cast_app_discovery_service_impl_unittest.cc +++ b/cast/sender/cast_app_discovery_service_impl_unittest.cc @@ -4,9 +4,10 @@ #include "cast/sender/cast_app_discovery_service_impl.h" +#include <utility> + #include "cast/common/channel/testing/fake_cast_socket.h" #include "cast/common/channel/testing/mock_socket_error_handler.h" -#include "cast/common/channel/virtual_connection_manager.h" #include "cast/common/channel/virtual_connection_router.h" #include "cast/common/public/service_info.h" #include "cast/sender/testing/test_helpers.h" @@ -79,12 +80,10 @@ class CastAppDiscoveryServiceImplTest : public ::testing::Test { FakeCastSocketPair fake_cast_socket_pair_; int32_t socket_id_; MockSocketErrorHandler mock_error_handler_; - VirtualConnectionManager manager_; - VirtualConnectionRouter router_{&manager_}; + VirtualConnectionRouter router_; FakeClock clock_{Clock::now()}; FakeTaskRunner task_runner_{&clock_}; - CastPlatformClient platform_client_{&router_, &manager_, &FakeClock::now, - &task_runner_}; + CastPlatformClient platform_client_{&router_, &FakeClock::now, &task_runner_}; CastAppDiscoveryServiceImpl app_discovery_service_{&platform_client_, &FakeClock::now}; diff --git a/cast/sender/cast_platform_client.cc b/cast/sender/cast_platform_client.cc index f57adc8b..c321201a 100644 --- a/cast/sender/cast_platform_client.cc +++ b/cast/sender/cast_platform_client.cc @@ -9,7 +9,6 @@ #include <utility> #include "absl/strings/str_cat.h" -#include "cast/common/channel/virtual_connection_manager.h" #include "cast/common/channel/virtual_connection_router.h" #include "cast/common/public/cast_socket.h" #include "cast/common/public/service_info.h" @@ -23,21 +22,20 @@ namespace cast { static constexpr std::chrono::seconds kRequestTimeout = std::chrono::seconds(5); CastPlatformClient::CastPlatformClient(VirtualConnectionRouter* router, - VirtualConnectionManager* manager, ClockNowFunctionPtr clock, TaskRunner* task_runner) : sender_id_(MakeUniqueSessionId("sender")), virtual_conn_router_(router), - virtual_conn_manager_(manager), clock_(clock), task_runner_(task_runner) { - OSP_DCHECK(virtual_conn_manager_); + OSP_DCHECK(virtual_conn_router_); OSP_DCHECK(clock_); OSP_DCHECK(task_runner_); virtual_conn_router_->AddHandlerForLocalId(sender_id_, this); } CastPlatformClient::~CastPlatformClient() { + virtual_conn_router_->RemoveConnectionsByLocalId(sender_id_); virtual_conn_router_->RemoveHandlerForLocalId(sender_id_); for (auto& pending_requests : pending_requests_by_device_id_) { @@ -73,9 +71,9 @@ absl::optional<int> CastPlatformClient::RequestAppAvailability( request_id, app_id, std::move(timeout), std::move(callback)}); VirtualConnection virtual_conn{sender_id_, kPlatformReceiverId, socket_id}; - if (!virtual_conn_manager_->GetConnectionData(virtual_conn)) { - virtual_conn_manager_->AddConnection(virtual_conn, - VirtualConnection::AssociatedData{}); + if (!virtual_conn_router_->GetConnectionData(virtual_conn)) { + virtual_conn_router_->AddConnection(virtual_conn, + VirtualConnection::AssociatedData{}); } virtual_conn_router_->Send(std::move(virtual_conn), diff --git a/cast/sender/cast_platform_client.h b/cast/sender/cast_platform_client.h index 41ad7fc7..8ea9a99a 100644 --- a/cast/sender/cast_platform_client.h +++ b/cast/sender/cast_platform_client.h @@ -7,7 +7,9 @@ #include <functional> #include <map> +#include <memory> #include <string> +#include <vector> #include "absl/types/optional.h" #include "cast/common/channel/cast_message_handler.h" @@ -19,7 +21,6 @@ namespace openscreen { namespace cast { struct ServiceInfo; -class VirtualConnectionManager; class VirtualConnectionRouter; // This class handles Cast messages that generally relate to the "platform", in @@ -35,7 +36,6 @@ class CastPlatformClient final : public CastMessageHandler { std::function<void(const std::string& app_id, AppAvailabilityResult)>; CastPlatformClient(VirtualConnectionRouter* router, - VirtualConnectionManager* manager, ClockNowFunctionPtr clock, TaskRunner* task_runner); ~CastPlatformClient() override; @@ -82,7 +82,6 @@ class CastPlatformClient final : public CastMessageHandler { const std::string sender_id_; VirtualConnectionRouter* const virtual_conn_router_; - VirtualConnectionManager* const virtual_conn_manager_; std::map<std::string /* device_id */, int> socket_id_by_device_id_; std::map<std::string /* device_id */, PendingRequests> pending_requests_by_device_id_; diff --git a/cast/sender/cast_platform_client_unittest.cc b/cast/sender/cast_platform_client_unittest.cc index eaaf5c6e..ae721a15 100644 --- a/cast/sender/cast_platform_client_unittest.cc +++ b/cast/sender/cast_platform_client_unittest.cc @@ -4,9 +4,10 @@ #include "cast/sender/cast_platform_client.h" +#include <utility> + #include "cast/common/channel/testing/fake_cast_socket.h" #include "cast/common/channel/testing/mock_socket_error_handler.h" -#include "cast/common/channel/virtual_connection_manager.h" #include "cast/common/channel/virtual_connection_router.h" #include "cast/common/public/service_info.h" #include "cast/sender/testing/test_helpers.h" @@ -46,12 +47,10 @@ class CastPlatformClientTest : public ::testing::Test { FakeCastSocketPair fake_cast_socket_pair_; CastSocket* socket_ = nullptr; MockSocketErrorHandler mock_error_handler_; - VirtualConnectionManager manager_; - VirtualConnectionRouter router_{&manager_}; + VirtualConnectionRouter router_; FakeClock clock_{Clock::now()}; FakeTaskRunner task_runner_{&clock_}; - CastPlatformClient platform_client_{&router_, &manager_, &FakeClock::now, - &task_runner_}; + CastPlatformClient platform_client_{&router_, &FakeClock::now, &task_runner_}; ServiceInfo receiver_; }; diff --git a/cast/standalone_sender/looping_file_cast_agent.cc b/cast/standalone_sender/looping_file_cast_agent.cc index a7e99fdb..727941aa 100644 --- a/cast/standalone_sender/looping_file_cast_agent.cc +++ b/cast/standalone_sender/looping_file_cast_agent.cc @@ -24,10 +24,8 @@ using DeviceMediaPolicy = SenderSocketFactory::DeviceMediaPolicy; LoopingFileCastAgent::LoopingFileCastAgent(TaskRunner* task_runner) : task_runner_(task_runner) { - router_ = MakeSerialDelete<VirtualConnectionRouter>(task_runner_, - &connection_manager_); message_port_ = - MakeSerialDelete<CastSocketMessagePort>(task_runner_, router_.get()); + MakeSerialDelete<CastSocketMessagePort>(task_runner_, &router_); socket_factory_ = MakeSerialDelete<SenderSocketFactory>(task_runner_, this, task_runner_); connection_factory_ = SerialDeletePtr<TlsConnectionFactory>( @@ -49,7 +47,7 @@ void LoopingFileCastAgent::Connect(ConnectionSettings settings) { task_runner_->PostTask([this, policy] { wake_lock_ = ScopedWakeLock::Create(task_runner_); socket_factory_->Connect(connection_settings_->receiver_endpoint, policy, - router_.get()); + &router_); }); } @@ -75,7 +73,7 @@ void LoopingFileCastAgent::OnConnected(SenderSocketFactory* factory, OSP_LOG_INFO << "Received connection from peer at: " << endpoint; message_port_->SetSocket(socket->GetWeakPtr()); - router_->TakeSocket(this, std::move(socket)); + router_.TakeSocket(this, std::move(socket)); CreateAndStartSession(); } @@ -144,7 +142,7 @@ void LoopingFileCastAgent::StopCurrentSession() { current_session_.reset(); environment_.reset(); file_sender_.reset(); - router_->CloseSocket(message_port_->GetSocketId()); + router_.CloseSocket(message_port_->GetSocketId()); message_port_->SetSocket(nullptr); } diff --git a/cast/standalone_sender/looping_file_cast_agent.h b/cast/standalone_sender/looping_file_cast_agent.h index abe91c96..eae7a8be 100644 --- a/cast/standalone_sender/looping_file_cast_agent.h +++ b/cast/standalone_sender/looping_file_cast_agent.h @@ -13,7 +13,6 @@ #include "absl/types/optional.h" #include "cast/common/channel/cast_socket_message_port.h" -#include "cast/common/channel/virtual_connection_manager.h" #include "cast/common/channel/virtual_connection_router.h" #include "cast/common/public/cast_socket.h" #include "cast/sender/public/sender_socket_factory.h" @@ -94,9 +93,8 @@ class LoopingFileCastAgent final void StopCurrentSession(); // Member variables set as part of construction. - VirtualConnectionManager connection_manager_; TaskRunner* const task_runner_; - SerialDeletePtr<VirtualConnectionRouter> router_; + VirtualConnectionRouter router_; SerialDeletePtr<CastSocketMessagePort> message_port_; SerialDeletePtr<SenderSocketFactory> socket_factory_; SerialDeletePtr<TlsConnectionFactory> connection_factory_; diff --git a/cast/test/cast_socket_e2e_test.cc b/cast/test/cast_socket_e2e_test.cc index bfd1bbf9..1f3273fa 100644 --- a/cast/test/cast_socket_e2e_test.cc +++ b/cast/test/cast_socket_e2e_test.cc @@ -12,7 +12,6 @@ #include "cast/common/certificate/testing/test_helpers.h" #include "cast/common/channel/connection_namespace_handler.h" #include "cast/common/channel/message_util.h" -#include "cast/common/channel/virtual_connection_manager.h" #include "cast/common/channel/virtual_connection_router.h" #include "cast/common/public/cast_socket.h" #include "cast/receiver/channel/device_auth_namespace_handler.h" @@ -141,8 +140,7 @@ class CastSocketE2ETest : public ::testing::Test { std::chrono::milliseconds(0)); task_runner_ = PlatformClientPosix::GetInstance()->GetTaskRunner(); - sender_router_ = MakeSerialDelete<VirtualConnectionRouter>( - task_runner_, &sender_vc_manager_); + sender_router_ = MakeSerialDelete<VirtualConnectionRouter>(task_runner_); sender_client_ = std::make_unique<StrictMock<SenderSocketsClient>>(sender_router_.get()); sender_factory_ = MakeSerialDelete<SenderSocketFactory>( @@ -161,8 +159,7 @@ class CastSocketE2ETest : public ::testing::Test { CastTrustStore::CreateInstanceForTest(credentials_.root_cert_der); auth_handler_ = MakeSerialDelete<DeviceAuthNamespaceHandler>( task_runner_, credentials_.provider.get()); - receiver_router_ = MakeSerialDelete<VirtualConnectionRouter>( - task_runner_, &receiver_vc_manager_); + receiver_router_ = MakeSerialDelete<VirtualConnectionRouter>(task_runner_); receiver_router_->AddHandlerForLocalId(kPlatformReceiverId, auth_handler_.get()); receiver_client_ = std::make_unique<StrictMock<ReceiverSocketsClient>>( @@ -258,14 +255,12 @@ class CastSocketE2ETest : public ::testing::Test { TaskRunner* task_runner_; // NOTE: Sender components. - VirtualConnectionManager sender_vc_manager_; SerialDeletePtr<VirtualConnectionRouter> sender_router_; std::unique_ptr<StrictMock<SenderSocketsClient>> sender_client_; SerialDeletePtr<SenderSocketFactory> sender_factory_; SerialDeletePtr<TlsConnectionFactory> sender_tls_factory_; // NOTE: Receiver components. - VirtualConnectionManager receiver_vc_manager_; SerialDeletePtr<VirtualConnectionRouter> receiver_router_; GeneratedCredentials credentials_; SerialDeletePtr<DeviceAuthNamespaceHandler> auth_handler_; diff --git a/cast/test/device_auth_test.cc b/cast/test/device_auth_test.cc index 36a9fc8a..0776e57e 100644 --- a/cast/test/device_auth_test.cc +++ b/cast/test/device_auth_test.cc @@ -9,7 +9,6 @@ #include "cast/common/channel/proto/cast_channel.pb.h" #include "cast/common/channel/testing/fake_cast_socket.h" #include "cast/common/channel/testing/mock_socket_error_handler.h" -#include "cast/common/channel/virtual_connection_manager.h" #include "cast/common/channel/virtual_connection_router.h" #include "cast/common/public/cast_socket.h" #include "cast/receiver/channel/device_auth_namespace_handler.h" @@ -127,8 +126,7 @@ class DeviceAuthTest : public ::testing::Test { CastSocket* socket_; StaticCredentialsProvider creds_; - VirtualConnectionManager manager_; - VirtualConnectionRouter router_{&manager_}; + VirtualConnectionRouter router_; DeviceAuthNamespaceHandler auth_handler_{&creds_}; }; |