aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cast/common/BUILD.gn3
-rw-r--r--cast/common/channel/cast_socket_message_port.cc9
-rw-r--r--cast/common/channel/connection_namespace_handler.cc44
-rw-r--r--cast/common/channel/connection_namespace_handler.h19
-rw-r--r--cast/common/channel/connection_namespace_handler_unittest.cc26
-rw-r--r--cast/common/channel/namespace_router_unittest.cc6
-rw-r--r--cast/common/channel/virtual_connection_manager.cc119
-rw-r--r--cast/common/channel/virtual_connection_manager.h61
-rw-r--r--cast/common/channel/virtual_connection_manager_unittest.cc142
-rw-r--r--cast/common/channel/virtual_connection_router.cc105
-rw-r--r--cast/common/channel/virtual_connection_router.h60
-rw-r--r--cast/common/channel/virtual_connection_router_unittest.cc149
-rw-r--r--cast/receiver/application_agent.cc3
-rw-r--r--cast/receiver/application_agent.h4
-rw-r--r--cast/receiver/channel/device_auth_namespace_handler_unittest.cc4
-rw-r--r--cast/sender/cast_app_discovery_service_impl_unittest.cc9
-rw-r--r--cast/sender/cast_platform_client.cc12
-rw-r--r--cast/sender/cast_platform_client.h5
-rw-r--r--cast/sender/cast_platform_client_unittest.cc9
-rw-r--r--cast/standalone_sender/looping_file_cast_agent.cc10
-rw-r--r--cast/standalone_sender/looping_file_cast_agent.h4
-rw-r--r--cast/test/cast_socket_e2e_test.cc9
-rw-r--r--cast/test/device_auth_test.cc4
23 files changed, 345 insertions, 471 deletions
diff --git a/cast/common/BUILD.gn b/cast/common/BUILD.gn
index c1ea0b3c..c1a8acd7 100644
--- a/cast/common/BUILD.gn
+++ b/cast/common/BUILD.gn
@@ -64,8 +64,6 @@ source_set("channel") {
"channel/namespace_router.cc",
"channel/namespace_router.h",
"channel/virtual_connection.h",
- "channel/virtual_connection_manager.cc",
- "channel/virtual_connection_manager.h",
"channel/virtual_connection_router.cc",
"channel/virtual_connection_router.h",
"public/cast_socket.h",
@@ -152,7 +150,6 @@ source_set("unittests") {
"channel/connection_namespace_handler_unittest.cc",
"channel/message_framer_unittest.cc",
"channel/namespace_router_unittest.cc",
- "channel/virtual_connection_manager_unittest.cc",
"channel/virtual_connection_router_unittest.cc",
"public/service_info_unittest.cc",
]
diff --git a/cast/common/channel/cast_socket_message_port.cc b/cast/common/channel/cast_socket_message_port.cc
index 3bcdca22..0c51304b 100644
--- a/cast/common/channel/cast_socket_message_port.cc
+++ b/cast/common/channel/cast_socket_message_port.cc
@@ -9,7 +9,6 @@
#include "cast/common/channel/message_util.h"
#include "cast/common/channel/proto/cast_channel.pb.h"
#include "cast/common/channel/virtual_connection.h"
-#include "cast/common/channel/virtual_connection_manager.h"
namespace openscreen {
namespace cast {
@@ -51,8 +50,7 @@ void CastSocketMessagePort::ResetClient() {
client_ = nullptr;
router_->RemoveHandlerForLocalId(client_sender_id_);
- router_->manager()->RemoveConnectionsByLocalId(
- client_sender_id_, VirtualConnection::CloseReason::kClosedBySelf);
+ router_->RemoveConnectionsByLocalId(client_sender_id_);
client_sender_id_.clear();
}
@@ -72,9 +70,8 @@ void CastSocketMessagePort::PostMessage(
VirtualConnection connection{client_sender_id_, destination_sender_id,
socket_->socket_id()};
- if (!router_->manager()->GetConnectionData(connection)) {
- router_->manager()->AddConnection(connection,
- VirtualConnection::AssociatedData{});
+ if (!router_->GetConnectionData(connection)) {
+ router_->AddConnection(connection, VirtualConnection::AssociatedData{});
}
const Error send_error = router_->Send(
diff --git a/cast/common/channel/connection_namespace_handler.cc b/cast/common/channel/connection_namespace_handler.cc
index a449dcbd..dd459060 100644
--- a/cast/common/channel/connection_namespace_handler.cc
+++ b/cast/common/channel/connection_namespace_handler.cc
@@ -4,6 +4,7 @@
#include "cast/common/channel/connection_namespace_handler.h"
+#include <algorithm>
#include <string>
#include <type_traits>
#include <utility>
@@ -12,7 +13,6 @@
#include "cast/common/channel/message_util.h"
#include "cast/common/channel/proto/cast_channel.pb.h"
#include "cast/common/channel/virtual_connection.h"
-#include "cast/common/channel/virtual_connection_manager.h"
#include "cast/common/channel/virtual_connection_router.h"
#include "cast/common/public/cast_socket.h"
#include "util/json/json_serialization.h"
@@ -82,11 +82,11 @@ VirtualConnection::CloseReason GetCloseReason(
} // namespace
ConnectionNamespaceHandler::ConnectionNamespaceHandler(
- VirtualConnectionManager* vc_manager,
+ VirtualConnectionRouter* vc_router,
VirtualConnectionPolicy* vc_policy)
- : vc_manager_(vc_manager), vc_policy_(vc_policy) {
- OSP_DCHECK(vc_manager);
- OSP_DCHECK(vc_policy);
+ : vc_router_(vc_router), vc_policy_(vc_policy) {
+ OSP_DCHECK(vc_router_);
+ OSP_DCHECK(vc_policy_);
}
ConnectionNamespaceHandler::~ConnectionNamespaceHandler() = default;
@@ -120,17 +120,16 @@ void ConnectionNamespaceHandler::OnMessage(VirtualConnectionRouter* router,
absl::string_view type_str = type.value();
if (type_str == kMessageTypeConnect) {
- HandleConnect(router, socket, std::move(message), std::move(value));
+ HandleConnect(socket, std::move(message), std::move(value));
} else if (type_str == kMessageTypeClose) {
- HandleClose(router, socket, std::move(message), std::move(value));
+ HandleClose(socket, std::move(message), std::move(value));
} else {
// NOTE: Unknown message type so ignore it.
// TODO(btolsch): Should be included in future error reporting.
}
}
-void ConnectionNamespaceHandler::HandleConnect(VirtualConnectionRouter* router,
- CastSocket* socket,
+void ConnectionNamespaceHandler::HandleConnect(CastSocket* socket,
CastMessage message,
Json::Value parsed_message) {
if (message.destination_id() == kBroadcastId ||
@@ -142,7 +141,7 @@ void ConnectionNamespaceHandler::HandleConnect(VirtualConnectionRouter* router,
std::move(message.source_id()),
ToCastSocketId(socket)};
if (!vc_policy_->IsConnectionAllowed(virtual_conn)) {
- SendClose(router, std::move(virtual_conn));
+ SendClose(virtual_conn);
return;
}
@@ -153,7 +152,7 @@ void ConnectionNamespaceHandler::HandleConnect(VirtualConnectionRouter* router,
int int_type = maybe_conn_type.value();
if (int_type < static_cast<int>(VirtualConnection::Type::kMinValue) ||
int_type > static_cast<int>(VirtualConnection::Type::kMaxValue)) {
- SendClose(router, std::move(virtual_conn));
+ SendClose(virtual_conn);
return;
}
conn_type = static_cast<VirtualConnection::Type>(int_type);
@@ -202,20 +201,19 @@ void ConnectionNamespaceHandler::HandleConnect(VirtualConnectionRouter* router,
// maintains compatibility with older senders that don't send a version and
// don't expect a response.
if (negotiated_version) {
- SendConnectedResponse(router, virtual_conn, negotiated_version.value());
+ SendConnectedResponse(virtual_conn, negotiated_version.value());
}
- vc_manager_->AddConnection(std::move(virtual_conn), std::move(data));
+ vc_router_->AddConnection(std::move(virtual_conn), std::move(data));
}
-void ConnectionNamespaceHandler::HandleClose(VirtualConnectionRouter* router,
- CastSocket* socket,
+void ConnectionNamespaceHandler::HandleClose(CastSocket* socket,
CastMessage message,
Json::Value parsed_message) {
VirtualConnection virtual_conn{std::move(message.destination_id()),
std::move(message.source_id()),
ToCastSocketId(socket)};
- if (!vc_manager_->GetConnectionData(virtual_conn)) {
+ if (!vc_router_->GetConnectionData(virtual_conn)) {
return;
}
@@ -224,11 +222,11 @@ void ConnectionNamespaceHandler::HandleClose(VirtualConnectionRouter* router,
OSP_DVLOG << "Connection closed (reason: " << reason
<< "): " << virtual_conn.local_id << ", " << virtual_conn.peer_id
<< ", " << virtual_conn.socket_id;
- vc_manager_->RemoveConnection(virtual_conn, reason);
+ vc_router_->RemoveConnection(virtual_conn, reason);
}
-void ConnectionNamespaceHandler::SendClose(VirtualConnectionRouter* router,
- VirtualConnection virtual_conn) {
+void ConnectionNamespaceHandler::SendClose(
+ const VirtualConnection& virtual_conn) {
Json::Value close_message(Json::ValueType::objectValue);
close_message[kMessageKeyType] = kMessageTypeClose;
@@ -237,13 +235,12 @@ void ConnectionNamespaceHandler::SendClose(VirtualConnectionRouter* router,
return;
}
- router->Send(
+ vc_router_->Send(
std::move(virtual_conn),
MakeSimpleUTF8Message(kConnectionNamespace, std::move(result.value())));
}
void ConnectionNamespaceHandler::SendConnectedResponse(
- VirtualConnectionRouter* router,
const VirtualConnection& virtual_conn,
int max_protocol_version) {
Json::Value connected_message(Json::ValueType::objectValue);
@@ -256,8 +253,9 @@ void ConnectionNamespaceHandler::SendConnectedResponse(
return;
}
- router->Send(virtual_conn, MakeSimpleUTF8Message(kConnectionNamespace,
- std::move(result.value())));
+ vc_router_->Send(
+ virtual_conn,
+ MakeSimpleUTF8Message(kConnectionNamespace, std::move(result.value())));
}
} // namespace cast
diff --git a/cast/common/channel/connection_namespace_handler.h b/cast/common/channel/connection_namespace_handler.h
index 5307e896..65388a80 100644
--- a/cast/common/channel/connection_namespace_handler.h
+++ b/cast/common/channel/connection_namespace_handler.h
@@ -13,7 +13,6 @@ namespace openscreen {
namespace cast {
struct VirtualConnection;
-class VirtualConnectionManager;
class VirtualConnectionRouter;
// Handles CastMessages in the connection namespace by opening and closing
@@ -28,8 +27,8 @@ class ConnectionNamespaceHandler final : public CastMessageHandler {
const VirtualConnection& virtual_conn) const = 0;
};
- // Both |vc_manager| and |vc_policy| should outlive this object.
- ConnectionNamespaceHandler(VirtualConnectionManager* vc_manager,
+ // Both |vc_router| and |vc_policy| should outlive this object.
+ ConnectionNamespaceHandler(VirtualConnectionRouter* vc_router,
VirtualConnectionPolicy* vc_policy);
~ConnectionNamespaceHandler() override;
@@ -39,22 +38,18 @@ class ConnectionNamespaceHandler final : public CastMessageHandler {
::cast::channel::CastMessage message) override;
private:
- void HandleConnect(VirtualConnectionRouter* router,
- CastSocket* socket,
+ void HandleConnect(CastSocket* socket,
::cast::channel::CastMessage message,
Json::Value parsed_message);
- void HandleClose(VirtualConnectionRouter* router,
- CastSocket* socket,
+ void HandleClose(CastSocket* socket,
::cast::channel::CastMessage message,
Json::Value parsed_message);
- void SendClose(VirtualConnectionRouter* router,
- VirtualConnection virtual_conn);
- void SendConnectedResponse(VirtualConnectionRouter* router,
- const VirtualConnection& virtual_conn,
+ void SendClose(const VirtualConnection& virtual_conn);
+ void SendConnectedResponse(const VirtualConnection& virtual_conn,
int max_protocol_version);
- VirtualConnectionManager* const vc_manager_;
+ VirtualConnectionRouter* const vc_router_;
VirtualConnectionPolicy* const vc_policy_;
};
diff --git a/cast/common/channel/connection_namespace_handler_unittest.cc b/cast/common/channel/connection_namespace_handler_unittest.cc
index f57a4253..00b56f8e 100644
--- a/cast/common/channel/connection_namespace_handler_unittest.cc
+++ b/cast/common/channel/connection_namespace_handler_unittest.cc
@@ -4,11 +4,14 @@
#include "cast/common/channel/connection_namespace_handler.h"
+#include <string>
+#include <utility>
+#include <vector>
+
#include "cast/common/channel/message_util.h"
#include "cast/common/channel/testing/fake_cast_socket.h"
#include "cast/common/channel/testing/mock_socket_error_handler.h"
#include "cast/common/channel/virtual_connection.h"
-#include "cast/common/channel/virtual_connection_manager.h"
#include "cast/common/channel/virtual_connection_router.h"
#include "cast/common/public/cast_socket.h"
#include "gmock/gmock.h"
@@ -139,9 +142,8 @@ class ConnectionNamespaceHandlerTest : public ::testing::Test {
CastSocket* socket_;
NiceMock<MockVirtualConnectionPolicy> vc_policy_;
- VirtualConnectionManager vc_manager_;
- VirtualConnectionRouter router_{&vc_manager_};
- ConnectionNamespaceHandler connection_namespace_handler_{&vc_manager_,
+ VirtualConnectionRouter router_;
+ ConnectionNamespaceHandler connection_namespace_handler_{&router_,
&vc_policy_};
const std::string sender_id_{"sender-5678"};
@@ -151,7 +153,7 @@ class ConnectionNamespaceHandlerTest : public ::testing::Test {
TEST_F(ConnectionNamespaceHandlerTest, Connect) {
connection_namespace_handler_.OnMessage(
&router_, socket_, MakeConnectMessage(sender_id_, receiver_id_));
- EXPECT_TRUE(vc_manager_.GetConnectionData(
+ EXPECT_TRUE(router_.GetConnectionData(
VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()}));
EXPECT_CALL(fake_cast_socket_pair_.mock_peer_client, OnMessage(_, _))
@@ -166,7 +168,7 @@ TEST_F(ConnectionNamespaceHandlerTest, PolicyDeniesConnection) {
sender_id_);
connection_namespace_handler_.OnMessage(
&router_, socket_, MakeConnectMessage(sender_id_, receiver_id_));
- EXPECT_FALSE(vc_manager_.GetConnectionData(
+ EXPECT_FALSE(router_.GetConnectionData(
VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()}));
}
@@ -179,7 +181,7 @@ TEST_F(ConnectionNamespaceHandlerTest, ConnectWithVersion) {
MakeVersionedConnectMessage(
sender_id_, receiver_id_,
::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2, {}));
- EXPECT_TRUE(vc_manager_.GetConnectionData(
+ EXPECT_TRUE(router_.GetConnectionData(
VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()}));
}
@@ -194,31 +196,31 @@ TEST_F(ConnectionNamespaceHandlerTest, ConnectWithVersionList) {
::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2,
{::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_3,
::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0}));
- EXPECT_TRUE(vc_manager_.GetConnectionData(
+ EXPECT_TRUE(router_.GetConnectionData(
VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()}));
}
TEST_F(ConnectionNamespaceHandlerTest, Close) {
connection_namespace_handler_.OnMessage(
&router_, socket_, MakeConnectMessage(sender_id_, receiver_id_));
- EXPECT_TRUE(vc_manager_.GetConnectionData(
+ EXPECT_TRUE(router_.GetConnectionData(
VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()}));
connection_namespace_handler_.OnMessage(
&router_, socket_, MakeCloseMessage(sender_id_, receiver_id_));
- EXPECT_FALSE(vc_manager_.GetConnectionData(
+ EXPECT_FALSE(router_.GetConnectionData(
VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()}));
}
TEST_F(ConnectionNamespaceHandlerTest, CloseUnknown) {
connection_namespace_handler_.OnMessage(
&router_, socket_, MakeConnectMessage(sender_id_, receiver_id_));
- EXPECT_TRUE(vc_manager_.GetConnectionData(
+ EXPECT_TRUE(router_.GetConnectionData(
VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()}));
connection_namespace_handler_.OnMessage(
&router_, socket_, MakeCloseMessage(sender_id_ + "098", receiver_id_));
- EXPECT_TRUE(vc_manager_.GetConnectionData(
+ EXPECT_TRUE(router_.GetConnectionData(
VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()}));
}
diff --git a/cast/common/channel/namespace_router_unittest.cc b/cast/common/channel/namespace_router_unittest.cc
index 96907c08..7f4ce17e 100644
--- a/cast/common/channel/namespace_router_unittest.cc
+++ b/cast/common/channel/namespace_router_unittest.cc
@@ -4,11 +4,12 @@
#include "cast/common/channel/namespace_router.h"
+#include <utility>
+
#include "cast/common/channel/cast_message_handler.h"
#include "cast/common/channel/proto/cast_channel.pb.h"
#include "cast/common/channel/testing/fake_cast_socket.h"
#include "cast/common/channel/testing/mock_cast_message_handler.h"
-#include "cast/common/channel/virtual_connection_manager.h"
#include "cast/common/channel/virtual_connection_router.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -27,8 +28,7 @@ class NamespaceRouterTest : public ::testing::Test {
CastSocket* socket() { return &fake_socket_.socket; }
FakeCastSocket fake_socket_;
- VirtualConnectionManager vc_manager_;
- VirtualConnectionRouter vc_router_{&vc_manager_};
+ VirtualConnectionRouter vc_router_;
NamespaceRouter router_;
};
diff --git a/cast/common/channel/virtual_connection_manager.cc b/cast/common/channel/virtual_connection_manager.cc
deleted file mode 100644
index 86e06706..00000000
--- a/cast/common/channel/virtual_connection_manager.cc
+++ /dev/null
@@ -1,119 +0,0 @@
-// Copyright 2019 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#include "cast/common/channel/virtual_connection_manager.h"
-
-#include <type_traits>
-
-namespace openscreen {
-namespace cast {
-
-VirtualConnectionManager::VirtualConnectionManager() = default;
-
-VirtualConnectionManager::~VirtualConnectionManager() = default;
-
-void VirtualConnectionManager::AddConnection(
- VirtualConnection virtual_connection,
- VirtualConnection::AssociatedData associated_data) {
- auto& socket_map = connections_[virtual_connection.socket_id];
- auto local_entries = socket_map.equal_range(virtual_connection.local_id);
- auto it = std::find_if(
- local_entries.first, local_entries.second,
- [&virtual_connection](const std::pair<std::string, VCTail>& entry) {
- return entry.second.peer_id == virtual_connection.peer_id;
- });
- if (it == socket_map.end()) {
- socket_map.emplace(std::move(virtual_connection.local_id),
- VCTail{std::move(virtual_connection.peer_id),
- std::move(associated_data)});
- }
-}
-
-bool VirtualConnectionManager::RemoveConnection(
- const VirtualConnection& virtual_connection,
- VirtualConnection::CloseReason reason) {
- auto socket_entry = connections_.find(virtual_connection.socket_id);
- if (socket_entry == connections_.end()) {
- return false;
- }
-
- auto& socket_map = socket_entry->second;
- auto local_entries = socket_map.equal_range(virtual_connection.local_id);
- if (local_entries.first == socket_map.end()) {
- return false;
- }
- for (auto it = local_entries.first; it != local_entries.second; ++it) {
- if (it->second.peer_id == virtual_connection.peer_id) {
- socket_map.erase(it);
- if (socket_map.empty()) {
- connections_.erase(socket_entry);
- }
- return true;
- }
- }
- return false;
-}
-
-size_t VirtualConnectionManager::RemoveConnectionsByLocalId(
- const std::string& local_id,
- VirtualConnection::CloseReason reason) {
- size_t removed_count = 0;
- for (auto socket_entry = connections_.begin();
- socket_entry != connections_.end();) {
- auto& socket_map = socket_entry->second;
- auto local_entries = socket_map.equal_range(local_id);
- if (local_entries.first != socket_map.end()) {
- size_t current_count =
- std::distance(local_entries.first, local_entries.second);
- removed_count += current_count;
- socket_map.erase(local_entries.first, local_entries.second);
- if (socket_map.empty()) {
- socket_entry = connections_.erase(socket_entry);
- } else {
- ++socket_entry;
- }
- } else {
- ++socket_entry;
- }
- }
- return removed_count;
-}
-
-size_t VirtualConnectionManager::RemoveConnectionsBySocketId(
- int socket_id,
- VirtualConnection::CloseReason reason) {
- auto entry = connections_.find(socket_id);
- if (entry == connections_.end()) {
- return 0;
- }
-
- size_t removed_count = entry->second.size();
- connections_.erase(entry);
-
- return removed_count;
-}
-
-absl::optional<const VirtualConnection::AssociatedData*>
-VirtualConnectionManager::GetConnectionData(
- const VirtualConnection& virtual_connection) const {
- auto socket_entry = connections_.find(virtual_connection.socket_id);
- if (socket_entry == connections_.end()) {
- return absl::nullopt;
- }
-
- auto& socket_map = socket_entry->second;
- auto local_entries = socket_map.equal_range(virtual_connection.local_id);
- if (local_entries.first == socket_map.end()) {
- return absl::nullopt;
- }
- for (auto it = local_entries.first; it != local_entries.second; ++it) {
- if (it->second.peer_id == virtual_connection.peer_id) {
- return &it->second.data;
- }
- }
- return absl::nullopt;
-}
-
-} // namespace cast
-} // namespace openscreen
diff --git a/cast/common/channel/virtual_connection_manager.h b/cast/common/channel/virtual_connection_manager.h
deleted file mode 100644
index 902d2a96..00000000
--- a/cast/common/channel/virtual_connection_manager.h
+++ /dev/null
@@ -1,61 +0,0 @@
-// Copyright 2019 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#ifndef CAST_COMMON_CHANNEL_VIRTUAL_CONNECTION_MANAGER_H_
-#define CAST_COMMON_CHANNEL_VIRTUAL_CONNECTION_MANAGER_H_
-
-#include <cstdint>
-#include <map>
-#include <string>
-
-#include "absl/types/optional.h"
-#include "cast/common/channel/virtual_connection.h"
-
-namespace openscreen {
-namespace cast {
-
-// Maintains a collection of open VirtualConnections and associated data.
-class VirtualConnectionManager {
- public:
- VirtualConnectionManager();
- ~VirtualConnectionManager();
-
- void AddConnection(VirtualConnection virtual_connection,
- VirtualConnection::AssociatedData associated_data);
-
- // Returns true if a connection matching |virtual_connection| was found and
- // removed.
- bool RemoveConnection(const VirtualConnection& virtual_connection,
- VirtualConnection::CloseReason reason);
-
- // Returns the number of connections removed.
- size_t RemoveConnectionsByLocalId(const std::string& local_id,
- VirtualConnection::CloseReason reason);
- size_t RemoveConnectionsBySocketId(int socket_id,
- VirtualConnection::CloseReason reason);
-
- // Returns the AssociatedData for |virtual_connection| if a connection exists,
- // nullopt otherwise. The pointer isn't stable in the long term, so if it
- // actually needs to be stored for later, the caller should make a copy.
- absl::optional<const VirtualConnection::AssociatedData*> GetConnectionData(
- const VirtualConnection& virtual_connection) const;
-
- private:
- // This struct simply stores the remainder of the data {VirtualConnection,
- // VirtVirtualConnection::AssociatedData} that is not broken up into map keys
- // for |connections_|.
- struct VCTail {
- std::string peer_id;
- VirtualConnection::AssociatedData data;
- };
-
- std::map<int /* socket_id */,
- std::multimap<std::string /* local_id */, VCTail>>
- connections_;
-};
-
-} // namespace cast
-} // namespace openscreen
-
-#endif // CAST_COMMON_CHANNEL_VIRTUAL_CONNECTION_MANAGER_H_
diff --git a/cast/common/channel/virtual_connection_manager_unittest.cc b/cast/common/channel/virtual_connection_manager_unittest.cc
deleted file mode 100644
index 963fcac7..00000000
--- a/cast/common/channel/virtual_connection_manager_unittest.cc
+++ /dev/null
@@ -1,142 +0,0 @@
-// Copyright 2019 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#include "cast/common/channel/virtual_connection_manager.h"
-
-#include "cast/common/channel/proto/cast_channel.pb.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace openscreen {
-namespace cast {
-namespace {
-
-static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0 ==
- static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_0),
- "V2 1.0 constants must be equal");
-static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_1 ==
- static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_1),
- "V2 1.1 constants must be equal");
-static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2 ==
- static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_2),
- "V2 1.2 constants must be equal");
-static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_3 ==
- static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_3),
- "V2 1.3 constants must be equal");
-
-using ::testing::_;
-using ::testing::Invoke;
-
-class VirtualConnectionManagerTest : public ::testing::Test {
- protected:
- VirtualConnectionManager manager_;
-
- VirtualConnection vc1_{"local1", "peer1", 75};
- VirtualConnection vc2_{"local2", "peer2", 76};
- VirtualConnection vc3_{"local1", "peer3", 75};
-};
-
-} // namespace
-
-TEST_F(VirtualConnectionManagerTest, NoConnections) {
- EXPECT_FALSE(manager_.GetConnectionData(vc1_));
- EXPECT_FALSE(manager_.GetConnectionData(vc2_));
- EXPECT_FALSE(manager_.GetConnectionData(vc3_));
-}
-
-TEST_F(VirtualConnectionManagerTest, AddConnections) {
- VirtualConnection::AssociatedData data1 = {};
-
- manager_.AddConnection(vc1_, std::move(data1));
- EXPECT_TRUE(manager_.GetConnectionData(vc1_));
- EXPECT_FALSE(manager_.GetConnectionData(vc2_));
- EXPECT_FALSE(manager_.GetConnectionData(vc3_));
-
- VirtualConnection::AssociatedData data2 = {};
- manager_.AddConnection(vc2_, std::move(data2));
- EXPECT_TRUE(manager_.GetConnectionData(vc1_));
- EXPECT_TRUE(manager_.GetConnectionData(vc2_));
- EXPECT_FALSE(manager_.GetConnectionData(vc3_));
-
- VirtualConnection::AssociatedData data3 = {};
- manager_.AddConnection(vc3_, std::move(data3));
- EXPECT_TRUE(manager_.GetConnectionData(vc1_));
- EXPECT_TRUE(manager_.GetConnectionData(vc2_));
- EXPECT_TRUE(manager_.GetConnectionData(vc3_));
-}
-
-TEST_F(VirtualConnectionManagerTest, RemoveConnections) {
- VirtualConnection::AssociatedData data1 = {};
- VirtualConnection::AssociatedData data2 = {};
- VirtualConnection::AssociatedData data3 = {};
-
- manager_.AddConnection(vc1_, std::move(data1));
- manager_.AddConnection(vc2_, std::move(data2));
- manager_.AddConnection(vc3_, std::move(data3));
-
- EXPECT_TRUE(manager_.RemoveConnection(
- vc1_, VirtualConnection::CloseReason::kClosedBySelf));
- EXPECT_FALSE(manager_.GetConnectionData(vc1_));
- EXPECT_TRUE(manager_.GetConnectionData(vc2_));
- EXPECT_TRUE(manager_.GetConnectionData(vc3_));
-
- EXPECT_TRUE(manager_.RemoveConnection(
- vc2_, VirtualConnection::CloseReason::kClosedBySelf));
- EXPECT_FALSE(manager_.GetConnectionData(vc1_));
- EXPECT_FALSE(manager_.GetConnectionData(vc2_));
- EXPECT_TRUE(manager_.GetConnectionData(vc3_));
-
- EXPECT_TRUE(manager_.RemoveConnection(
- vc3_, VirtualConnection::CloseReason::kClosedBySelf));
- EXPECT_FALSE(manager_.GetConnectionData(vc1_));
- EXPECT_FALSE(manager_.GetConnectionData(vc2_));
- EXPECT_FALSE(manager_.GetConnectionData(vc3_));
-
- EXPECT_FALSE(manager_.RemoveConnection(
- vc1_, VirtualConnection::CloseReason::kClosedBySelf));
- EXPECT_FALSE(manager_.RemoveConnection(
- vc2_, VirtualConnection::CloseReason::kClosedBySelf));
- EXPECT_FALSE(manager_.RemoveConnection(
- vc3_, VirtualConnection::CloseReason::kClosedBySelf));
-}
-
-TEST_F(VirtualConnectionManagerTest, RemoveConnectionsByIds) {
- VirtualConnection::AssociatedData data1 = {};
- VirtualConnection::AssociatedData data2 = {};
- VirtualConnection::AssociatedData data3 = {};
-
- manager_.AddConnection(vc1_, std::move(data1));
- manager_.AddConnection(vc2_, std::move(data2));
- manager_.AddConnection(vc3_, std::move(data3));
-
- EXPECT_EQ(manager_.RemoveConnectionsByLocalId(
- "local1", VirtualConnection::CloseReason::kClosedBySelf),
- 2u);
- EXPECT_FALSE(manager_.GetConnectionData(vc1_));
- EXPECT_TRUE(manager_.GetConnectionData(vc2_));
- EXPECT_FALSE(manager_.GetConnectionData(vc3_));
-
- data1 = {};
- data2 = {};
- data3 = {};
- manager_.AddConnection(vc1_, std::move(data1));
- manager_.AddConnection(vc2_, std::move(data2));
- manager_.AddConnection(vc3_, std::move(data3));
- EXPECT_EQ(manager_.RemoveConnectionsBySocketId(
- 76, VirtualConnection::CloseReason::kClosedBySelf),
- 1u);
- EXPECT_TRUE(manager_.GetConnectionData(vc1_));
- EXPECT_FALSE(manager_.GetConnectionData(vc2_));
- EXPECT_TRUE(manager_.GetConnectionData(vc3_));
-
- EXPECT_EQ(manager_.RemoveConnectionsBySocketId(
- 75, VirtualConnection::CloseReason::kClosedBySelf),
- 2u);
- EXPECT_FALSE(manager_.GetConnectionData(vc1_));
- EXPECT_FALSE(manager_.GetConnectionData(vc2_));
- EXPECT_FALSE(manager_.GetConnectionData(vc3_));
-}
-
-} // namespace cast
-} // namespace openscreen
diff --git a/cast/common/channel/virtual_connection_router.cc b/cast/common/channel/virtual_connection_router.cc
index 140ca138..98e6d19d 100644
--- a/cast/common/channel/virtual_connection_router.cc
+++ b/cast/common/channel/virtual_connection_router.cc
@@ -9,7 +9,6 @@
#include "cast/common/channel/cast_message_handler.h"
#include "cast/common/channel/message_util.h"
#include "cast/common/channel/proto/cast_channel.pb.h"
-#include "cast/common/channel/virtual_connection_manager.h"
#include "util/osp_logging.h"
namespace openscreen {
@@ -17,14 +16,97 @@ namespace cast {
using ::cast::channel::CastMessage;
-VirtualConnectionRouter::VirtualConnectionRouter(
- VirtualConnectionManager* vc_manager)
- : vc_manager_(vc_manager) {
- OSP_DCHECK(vc_manager);
-}
+VirtualConnectionRouter::VirtualConnectionRouter() = default;
VirtualConnectionRouter::~VirtualConnectionRouter() = default;
+void VirtualConnectionRouter::AddConnection(
+ VirtualConnection virtual_connection,
+ VirtualConnection::AssociatedData associated_data) {
+ auto& socket_map = connections_[virtual_connection.socket_id];
+ auto local_entries = socket_map.equal_range(virtual_connection.local_id);
+ auto it = std::find_if(
+ local_entries.first, local_entries.second,
+ [&virtual_connection](const std::pair<std::string, VCTail>& entry) {
+ return entry.second.peer_id == virtual_connection.peer_id;
+ });
+ if (it == socket_map.end()) {
+ socket_map.emplace(std::move(virtual_connection.local_id),
+ VCTail{std::move(virtual_connection.peer_id),
+ std::move(associated_data)});
+ }
+}
+
+bool VirtualConnectionRouter::RemoveConnection(
+ const VirtualConnection& virtual_connection,
+ VirtualConnection::CloseReason reason) {
+ auto socket_entry = connections_.find(virtual_connection.socket_id);
+ if (socket_entry == connections_.end()) {
+ return false;
+ }
+
+ auto& socket_map = socket_entry->second;
+ auto local_entries = socket_map.equal_range(virtual_connection.local_id);
+ if (local_entries.first == socket_map.end()) {
+ return false;
+ }
+ for (auto it = local_entries.first; it != local_entries.second; ++it) {
+ if (it->second.peer_id == virtual_connection.peer_id) {
+ socket_map.erase(it);
+ if (socket_map.empty()) {
+ connections_.erase(socket_entry);
+ }
+ return true;
+ }
+ }
+ return false;
+}
+
+void VirtualConnectionRouter::RemoveConnectionsByLocalId(
+ const std::string& local_id) {
+ for (auto socket_entry = connections_.begin();
+ socket_entry != connections_.end();) {
+ auto& socket_map = socket_entry->second;
+ auto local_entries = socket_map.equal_range(local_id);
+ if (local_entries.first != socket_map.end()) {
+ socket_map.erase(local_entries.first, local_entries.second);
+ if (socket_map.empty()) {
+ socket_entry = connections_.erase(socket_entry);
+ continue;
+ }
+ }
+ ++socket_entry;
+ }
+}
+
+void VirtualConnectionRouter::RemoveConnectionsBySocketId(int socket_id) {
+ auto entry = connections_.find(socket_id);
+ if (entry != connections_.end()) {
+ connections_.erase(entry);
+ }
+}
+
+absl::optional<const VirtualConnection::AssociatedData*>
+VirtualConnectionRouter::GetConnectionData(
+ const VirtualConnection& virtual_connection) const {
+ auto socket_entry = connections_.find(virtual_connection.socket_id);
+ if (socket_entry == connections_.end()) {
+ return absl::nullopt;
+ }
+
+ auto& socket_map = socket_entry->second;
+ auto local_entries = socket_map.equal_range(virtual_connection.local_id);
+ if (local_entries.first == socket_map.end()) {
+ return absl::nullopt;
+ }
+ for (auto it = local_entries.first; it != local_entries.second; ++it) {
+ if (it->second.peer_id == virtual_connection.peer_id) {
+ return &it->second.data;
+ }
+ }
+ return absl::nullopt;
+}
+
bool VirtualConnectionRouter::AddHandlerForLocalId(
std::string local_id,
CastMessageHandler* endpoint) {
@@ -46,8 +128,7 @@ void VirtualConnectionRouter::TakeSocket(SocketErrorHandler* error_handler,
void VirtualConnectionRouter::CloseSocket(int id) {
auto it = sockets_.find(id);
if (it != sockets_.end()) {
- vc_manager_->RemoveConnectionsBySocketId(
- id, VirtualConnection::kTransportClosed);
+ RemoveConnectionsBySocketId(id);
std::unique_ptr<CastSocket> socket = std::move(it->second.socket);
SocketErrorHandler* error_handler = it->second.error_handler;
sockets_.erase(it);
@@ -63,7 +144,7 @@ Error VirtualConnectionRouter::Send(VirtualConnection virtual_conn,
}
if (!IsTransportNamespace(message.namespace_()) &&
- !vc_manager_->GetConnectionData(virtual_conn)) {
+ !GetConnectionData(virtual_conn)) {
return Error::Code::kNoActiveConnection;
}
auto it = sockets_.find(virtual_conn.socket_id);
@@ -104,7 +185,7 @@ void VirtualConnectionRouter::OnError(CastSocket* socket, Error error) {
const int id = socket->socket_id();
auto it = sockets_.find(id);
if (it != sockets_.end()) {
- vc_manager_->RemoveConnectionsBySocketId(id, VirtualConnection::kUnknown);
+ RemoveConnectionsBySocketId(id);
std::unique_ptr<CastSocket> socket_owned = std::move(it->second.socket);
SocketErrorHandler* error_handler = it->second.error_handler;
sockets_.erase(it);
@@ -123,8 +204,8 @@ void VirtualConnectionRouter::OnMessage(CastSocket* socket,
}
} else {
if (!IsTransportNamespace(message.namespace_()) &&
- !vc_manager_->GetConnectionData(VirtualConnection{
- local_id, message.source_id(), socket->socket_id()})) {
+ !GetConnectionData(VirtualConnection{local_id, message.source_id(),
+ socket->socket_id()})) {
return;
}
auto it = endpoints_.find(local_id);
diff --git a/cast/common/channel/virtual_connection_router.h b/cast/common/channel/virtual_connection_router.h
index 1bbf2bc1..5080e948 100644
--- a/cast/common/channel/virtual_connection_router.h
+++ b/cast/common/channel/virtual_connection_router.h
@@ -10,15 +10,15 @@
#include <memory>
#include <string>
+#include "absl/types/optional.h"
#include "cast/common/channel/proto/cast_channel.pb.h"
+#include "cast/common/channel/virtual_connection.h"
#include "cast/common/public/cast_socket.h"
namespace openscreen {
namespace cast {
class CastMessageHandler;
-struct VirtualConnection;
-class VirtualConnectionManager;
// Handles CastSockets by routing received messages to appropriate message
// handlers based on the VirtualConnection's local ID and sending messages over
@@ -37,8 +37,13 @@ class VirtualConnectionManager;
//
// 4. Anything Foo wants to send (launch, app availability, etc.) goes through
// VCRouter::Send via an appropriate VC. The virtual connection is not
-// created automatically, so Foo should either ensure its existence via logic
-// or check with the VirtualConnectionManager first.
+// created automatically, so AddConnection() must be called first.
+//
+// 5. Anything Foo wants to receive must be registered with a handler by calling
+// AddHandlerForLocalId().
+//
+// 6. Foo is expected to clean-up after itself (#4 and #5) by calling
+// RemoveConnection() and RemoveHandlerForLocalId().
class VirtualConnectionRouter final : public CastSocket::Client {
public:
class SocketErrorHandler {
@@ -47,10 +52,39 @@ class VirtualConnectionRouter final : public CastSocket::Client {
virtual void OnError(CastSocket* socket, Error error) = 0;
};
- explicit VirtualConnectionRouter(VirtualConnectionManager* vc_manager);
+ VirtualConnectionRouter();
~VirtualConnectionRouter() override;
- // These return whether the given |local_id| was successfully added/removed.
+ // Adds a VirtualConnection, if one does not already exist, to enable routing
+ // of peer-to-peer messages.
+ void AddConnection(VirtualConnection virtual_connection,
+ VirtualConnection::AssociatedData associated_data);
+
+ // Removes a VirtualConnection and returns true if a connection matching
+ // |virtual_connection| was found and removed.
+ bool RemoveConnection(const VirtualConnection& virtual_connection,
+ VirtualConnection::CloseReason reason);
+
+ // Removes all VirtualConnections whose local endpoint matches the given
+ // |local_id|.
+ void RemoveConnectionsByLocalId(const std::string& local_id);
+
+ // Removes all VirtualConnections whose traffic passes over the socket
+ // referenced by |socket_id|.
+ void RemoveConnectionsBySocketId(int socket_id);
+
+ // Returns the AssociatedData for a |virtual_connection| if a connection
+ // exists, nullopt otherwise. The pointer isn't stable in the long term; so,
+ // if it actually needs to be stored for later, the caller should make a copy.
+ absl::optional<const VirtualConnection::AssociatedData*> GetConnectionData(
+ const VirtualConnection& virtual_connection) const;
+
+ // Adds/Removes a CastMessageHandler for all messages destined for the given
+ // |endpoint| referred to by |local_id|, and returns whether the given
+ // |local_id| was successfully added/removed.
+ //
+ // Note: Clients will need to separately call AddConnection(), and
+ // RemoveConnection() or RemoveConnectionsByLocalId().
bool AddHandlerForLocalId(std::string local_id, CastMessageHandler* endpoint);
bool RemoveHandlerForLocalId(const std::string& local_id);
@@ -70,15 +104,23 @@ class VirtualConnectionRouter final : public CastSocket::Client {
void OnMessage(CastSocket* socket,
::cast::channel::CastMessage message) override;
- VirtualConnectionManager* manager() { return vc_manager_; }
-
private:
+ // This struct simply stores the remainder of the data {VirtualConnection,
+ // VirtualConnection::AssociatedData} that is not broken up into map keys for
+ // |connections_|.
+ struct VCTail {
+ std::string peer_id;
+ VirtualConnection::AssociatedData data;
+ };
+
struct SocketWithHandler {
std::unique_ptr<CastSocket> socket;
SocketErrorHandler* error_handler;
};
- VirtualConnectionManager* const vc_manager_;
+ std::map<int /* socket_id */,
+ std::multimap<std::string /* local_id */, VCTail>>
+ connections_;
std::map<int, SocketWithHandler> sockets_;
std::map<std::string /* local_id */, CastMessageHandler*> endpoints_;
};
diff --git a/cast/common/channel/virtual_connection_router_unittest.cc b/cast/common/channel/virtual_connection_router_unittest.cc
index b05d10e3..69a08c53 100644
--- a/cast/common/channel/virtual_connection_router_unittest.cc
+++ b/cast/common/channel/virtual_connection_router_unittest.cc
@@ -11,7 +11,6 @@
#include "cast/common/channel/testing/fake_cast_socket.h"
#include "cast/common/channel/testing/mock_cast_message_handler.h"
#include "cast/common/channel/testing/mock_socket_error_handler.h"
-#include "cast/common/channel/virtual_connection_manager.h"
#include "cast/common/public/cast_socket.h"
#include "gtest/gtest.h"
@@ -19,6 +18,19 @@ namespace openscreen {
namespace cast {
namespace {
+static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0 ==
+ static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_0),
+ "V2 1.0 constants must be equal");
+static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_1 ==
+ static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_1),
+ "V2 1.1 constants must be equal");
+static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2 ==
+ static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_2),
+ "V2 1.2 constants must be equal");
+static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_3 ==
+ static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_3),
+ "V2 1.3 constants must be equal");
+
using ::cast::channel::CastMessage;
using ::testing::_;
using ::testing::Invoke;
@@ -44,21 +56,115 @@ class VirtualConnectionRouterTest : public ::testing::Test {
MockSocketErrorHandler mock_error_handler_;
- VirtualConnectionManager local_manager_;
- VirtualConnectionRouter local_router_{&local_manager_};
+ VirtualConnectionRouter local_router_;
+ VirtualConnectionRouter remote_router_;
- VirtualConnectionManager remote_manager_;
- VirtualConnectionRouter remote_router_{&remote_manager_};
+ VirtualConnection vc1_{"local1", "peer1", 75};
+ VirtualConnection vc2_{"local2", "peer2", 76};
+ VirtualConnection vc3_{"local1", "peer3", 75};
};
} // namespace
+TEST_F(VirtualConnectionRouterTest, NoConnections) {
+ EXPECT_FALSE(local_router_.GetConnectionData(vc1_));
+ EXPECT_FALSE(local_router_.GetConnectionData(vc2_));
+ EXPECT_FALSE(local_router_.GetConnectionData(vc3_));
+}
+
+TEST_F(VirtualConnectionRouterTest, AddConnections) {
+ VirtualConnection::AssociatedData data1 = {};
+
+ local_router_.AddConnection(vc1_, std::move(data1));
+ EXPECT_TRUE(local_router_.GetConnectionData(vc1_));
+ EXPECT_FALSE(local_router_.GetConnectionData(vc2_));
+ EXPECT_FALSE(local_router_.GetConnectionData(vc3_));
+
+ VirtualConnection::AssociatedData data2 = {};
+ local_router_.AddConnection(vc2_, std::move(data2));
+ EXPECT_TRUE(local_router_.GetConnectionData(vc1_));
+ EXPECT_TRUE(local_router_.GetConnectionData(vc2_));
+ EXPECT_FALSE(local_router_.GetConnectionData(vc3_));
+
+ VirtualConnection::AssociatedData data3 = {};
+ local_router_.AddConnection(vc3_, std::move(data3));
+ EXPECT_TRUE(local_router_.GetConnectionData(vc1_));
+ EXPECT_TRUE(local_router_.GetConnectionData(vc2_));
+ EXPECT_TRUE(local_router_.GetConnectionData(vc3_));
+}
+
+TEST_F(VirtualConnectionRouterTest, RemoveConnections) {
+ VirtualConnection::AssociatedData data1 = {};
+ VirtualConnection::AssociatedData data2 = {};
+ VirtualConnection::AssociatedData data3 = {};
+
+ local_router_.AddConnection(vc1_, std::move(data1));
+ local_router_.AddConnection(vc2_, std::move(data2));
+ local_router_.AddConnection(vc3_, std::move(data3));
+
+ EXPECT_TRUE(local_router_.RemoveConnection(
+ vc1_, VirtualConnection::CloseReason::kClosedBySelf));
+ EXPECT_FALSE(local_router_.GetConnectionData(vc1_));
+ EXPECT_TRUE(local_router_.GetConnectionData(vc2_));
+ EXPECT_TRUE(local_router_.GetConnectionData(vc3_));
+
+ EXPECT_TRUE(local_router_.RemoveConnection(
+ vc2_, VirtualConnection::CloseReason::kClosedBySelf));
+ EXPECT_FALSE(local_router_.GetConnectionData(vc1_));
+ EXPECT_FALSE(local_router_.GetConnectionData(vc2_));
+ EXPECT_TRUE(local_router_.GetConnectionData(vc3_));
+
+ EXPECT_TRUE(local_router_.RemoveConnection(
+ vc3_, VirtualConnection::CloseReason::kClosedBySelf));
+ EXPECT_FALSE(local_router_.GetConnectionData(vc1_));
+ EXPECT_FALSE(local_router_.GetConnectionData(vc2_));
+ EXPECT_FALSE(local_router_.GetConnectionData(vc3_));
+
+ EXPECT_FALSE(local_router_.RemoveConnection(
+ vc1_, VirtualConnection::CloseReason::kClosedBySelf));
+ EXPECT_FALSE(local_router_.RemoveConnection(
+ vc2_, VirtualConnection::CloseReason::kClosedBySelf));
+ EXPECT_FALSE(local_router_.RemoveConnection(
+ vc3_, VirtualConnection::CloseReason::kClosedBySelf));
+}
+
+TEST_F(VirtualConnectionRouterTest, RemoveConnectionsByIds) {
+ VirtualConnection::AssociatedData data1 = {};
+ VirtualConnection::AssociatedData data2 = {};
+ VirtualConnection::AssociatedData data3 = {};
+
+ local_router_.AddConnection(vc1_, std::move(data1));
+ local_router_.AddConnection(vc2_, std::move(data2));
+ local_router_.AddConnection(vc3_, std::move(data3));
+
+ local_router_.RemoveConnectionsByLocalId("local1");
+ EXPECT_FALSE(local_router_.GetConnectionData(vc1_));
+ EXPECT_TRUE(local_router_.GetConnectionData(vc2_));
+ EXPECT_FALSE(local_router_.GetConnectionData(vc3_));
+
+ data1 = {};
+ data2 = {};
+ data3 = {};
+ local_router_.AddConnection(vc1_, std::move(data1));
+ local_router_.AddConnection(vc2_, std::move(data2));
+ local_router_.AddConnection(vc3_, std::move(data3));
+ local_router_.RemoveConnectionsBySocketId(76);
+ EXPECT_TRUE(local_router_.GetConnectionData(vc1_));
+ EXPECT_FALSE(local_router_.GetConnectionData(vc2_));
+ EXPECT_TRUE(local_router_.GetConnectionData(vc3_));
+
+ local_router_.RemoveConnectionsBySocketId(75);
+ EXPECT_FALSE(local_router_.GetConnectionData(vc1_));
+ EXPECT_FALSE(local_router_.GetConnectionData(vc2_));
+ EXPECT_FALSE(local_router_.GetConnectionData(vc3_));
+}
+
TEST_F(VirtualConnectionRouterTest, LocalIdHandler) {
MockCastMessageHandler mock_message_handler;
local_router_.AddHandlerForLocalId("receiver-1234", &mock_message_handler);
- local_manager_.AddConnection(VirtualConnection{"receiver-1234", "sender-9873",
- local_socket_->socket_id()},
- {});
+ local_router_.AddConnection(VirtualConnection{"receiver-1234", "sender-9873",
+ local_socket_->socket_id()},
+ {});
CastMessage message;
message.set_protocol_version(
@@ -84,9 +190,9 @@ TEST_F(VirtualConnectionRouterTest, LocalIdHandler) {
TEST_F(VirtualConnectionRouterTest, RemoveLocalIdHandler) {
MockCastMessageHandler mock_message_handler;
local_router_.AddHandlerForLocalId("receiver-1234", &mock_message_handler);
- local_manager_.AddConnection(VirtualConnection{"receiver-1234", "sender-9873",
- local_socket_->socket_id()},
- {});
+ local_router_.AddConnection(VirtualConnection{"receiver-1234", "sender-9873",
+ local_socket_->socket_id()},
+ {});
CastMessage message;
message.set_protocol_version(
@@ -108,16 +214,15 @@ TEST_F(VirtualConnectionRouterTest, RemoveLocalIdHandler) {
}
TEST_F(VirtualConnectionRouterTest, SendMessage) {
- local_manager_.AddConnection(VirtualConnection{"receiver-1234", "sender-4321",
- local_socket_->socket_id()},
- {});
+ local_router_.AddConnection(VirtualConnection{"receiver-1234", "sender-4321",
+ local_socket_->socket_id()},
+ {});
MockCastMessageHandler destination;
remote_router_.AddHandlerForLocalId("sender-4321", &destination);
- remote_manager_.AddConnection(
- VirtualConnection{"sender-4321", "receiver-1234",
- remote_socket_->socket_id()},
- {});
+ remote_router_.AddConnection(VirtualConnection{"sender-4321", "receiver-1234",
+ remote_socket_->socket_id()},
+ {});
CastMessage message;
message.set_protocol_version(
@@ -142,15 +247,15 @@ TEST_F(VirtualConnectionRouterTest, SendMessage) {
}
TEST_F(VirtualConnectionRouterTest, CloseSocketRemovesVirtualConnections) {
- local_manager_.AddConnection(VirtualConnection{"receiver-1234", "sender-4321",
- local_socket_->socket_id()},
- {});
+ local_router_.AddConnection(VirtualConnection{"receiver-1234", "sender-4321",
+ local_socket_->socket_id()},
+ {});
EXPECT_CALL(mock_error_handler_, OnClose(local_socket_)).Times(1);
int id = local_socket_->socket_id();
local_router_.CloseSocket(id);
- EXPECT_FALSE(local_manager_.GetConnectionData(
+ EXPECT_FALSE(local_router_.GetConnectionData(
VirtualConnection{"receiver-1234", "sender-4321", id}));
}
diff --git a/cast/receiver/application_agent.cc b/cast/receiver/application_agent.cc
index ec8a44a4..06ca3095 100644
--- a/cast/receiver/application_agent.cc
+++ b/cast/receiver/application_agent.cc
@@ -50,8 +50,7 @@ ApplicationAgent::ApplicationAgent(
DeviceAuthNamespaceHandler::CredentialsProvider* credentials_provider)
: task_runner_(task_runner),
auth_handler_(credentials_provider),
- connection_handler_(&connection_manager_, this),
- router_(&connection_manager_),
+ connection_handler_(&router_, this),
message_port_(&router_) {
router_.AddHandlerForLocalId(kPlatformReceiverId, this);
}
diff --git a/cast/receiver/application_agent.h b/cast/receiver/application_agent.h
index a6a792aa..10818f99 100644
--- a/cast/receiver/application_agent.h
+++ b/cast/receiver/application_agent.h
@@ -12,7 +12,6 @@
#include "cast/common/channel/cast_socket_message_port.h"
#include "cast/common/channel/connection_namespace_handler.h"
-#include "cast/common/channel/virtual_connection_manager.h"
#include "cast/common/channel/virtual_connection_router.h"
#include "cast/common/public/cast_socket.h"
#include "cast/receiver/channel/device_auth_namespace_handler.h"
@@ -156,9 +155,8 @@ class ApplicationAgent final
TaskRunner* const task_runner_;
DeviceAuthNamespaceHandler auth_handler_;
- VirtualConnectionManager connection_manager_;
- ConnectionNamespaceHandler connection_handler_;
VirtualConnectionRouter router_;
+ ConnectionNamespaceHandler connection_handler_;
std::map<std::string, Application*> registered_applications_;
Application* idle_screen_app_ = nullptr;
diff --git a/cast/receiver/channel/device_auth_namespace_handler_unittest.cc b/cast/receiver/channel/device_auth_namespace_handler_unittest.cc
index edfc807b..7d4b736c 100644
--- a/cast/receiver/channel/device_auth_namespace_handler_unittest.cc
+++ b/cast/receiver/channel/device_auth_namespace_handler_unittest.cc
@@ -11,7 +11,6 @@
#include "cast/common/channel/proto/cast_channel.pb.h"
#include "cast/common/channel/testing/fake_cast_socket.h"
#include "cast/common/channel/testing/mock_socket_error_handler.h"
-#include "cast/common/channel/virtual_connection_manager.h"
#include "cast/common/channel/virtual_connection_router.h"
#include "cast/common/public/cast_socket.h"
#include "cast/receiver/channel/static_credentials.h"
@@ -55,8 +54,7 @@ class DeviceAuthNamespaceHandlerTest : public ::testing::Test {
CastSocket* socket_;
StaticCredentialsProvider creds_;
- VirtualConnectionManager manager_;
- VirtualConnectionRouter router_{&manager_};
+ VirtualConnectionRouter router_;
DeviceAuthNamespaceHandler auth_handler_{&creds_};
};
diff --git a/cast/sender/cast_app_discovery_service_impl_unittest.cc b/cast/sender/cast_app_discovery_service_impl_unittest.cc
index 9f2f8793..a2eeb046 100644
--- a/cast/sender/cast_app_discovery_service_impl_unittest.cc
+++ b/cast/sender/cast_app_discovery_service_impl_unittest.cc
@@ -4,9 +4,10 @@
#include "cast/sender/cast_app_discovery_service_impl.h"
+#include <utility>
+
#include "cast/common/channel/testing/fake_cast_socket.h"
#include "cast/common/channel/testing/mock_socket_error_handler.h"
-#include "cast/common/channel/virtual_connection_manager.h"
#include "cast/common/channel/virtual_connection_router.h"
#include "cast/common/public/service_info.h"
#include "cast/sender/testing/test_helpers.h"
@@ -79,12 +80,10 @@ class CastAppDiscoveryServiceImplTest : public ::testing::Test {
FakeCastSocketPair fake_cast_socket_pair_;
int32_t socket_id_;
MockSocketErrorHandler mock_error_handler_;
- VirtualConnectionManager manager_;
- VirtualConnectionRouter router_{&manager_};
+ VirtualConnectionRouter router_;
FakeClock clock_{Clock::now()};
FakeTaskRunner task_runner_{&clock_};
- CastPlatformClient platform_client_{&router_, &manager_, &FakeClock::now,
- &task_runner_};
+ CastPlatformClient platform_client_{&router_, &FakeClock::now, &task_runner_};
CastAppDiscoveryServiceImpl app_discovery_service_{&platform_client_,
&FakeClock::now};
diff --git a/cast/sender/cast_platform_client.cc b/cast/sender/cast_platform_client.cc
index f57adc8b..c321201a 100644
--- a/cast/sender/cast_platform_client.cc
+++ b/cast/sender/cast_platform_client.cc
@@ -9,7 +9,6 @@
#include <utility>
#include "absl/strings/str_cat.h"
-#include "cast/common/channel/virtual_connection_manager.h"
#include "cast/common/channel/virtual_connection_router.h"
#include "cast/common/public/cast_socket.h"
#include "cast/common/public/service_info.h"
@@ -23,21 +22,20 @@ namespace cast {
static constexpr std::chrono::seconds kRequestTimeout = std::chrono::seconds(5);
CastPlatformClient::CastPlatformClient(VirtualConnectionRouter* router,
- VirtualConnectionManager* manager,
ClockNowFunctionPtr clock,
TaskRunner* task_runner)
: sender_id_(MakeUniqueSessionId("sender")),
virtual_conn_router_(router),
- virtual_conn_manager_(manager),
clock_(clock),
task_runner_(task_runner) {
- OSP_DCHECK(virtual_conn_manager_);
+ OSP_DCHECK(virtual_conn_router_);
OSP_DCHECK(clock_);
OSP_DCHECK(task_runner_);
virtual_conn_router_->AddHandlerForLocalId(sender_id_, this);
}
CastPlatformClient::~CastPlatformClient() {
+ virtual_conn_router_->RemoveConnectionsByLocalId(sender_id_);
virtual_conn_router_->RemoveHandlerForLocalId(sender_id_);
for (auto& pending_requests : pending_requests_by_device_id_) {
@@ -73,9 +71,9 @@ absl::optional<int> CastPlatformClient::RequestAppAvailability(
request_id, app_id, std::move(timeout), std::move(callback)});
VirtualConnection virtual_conn{sender_id_, kPlatformReceiverId, socket_id};
- if (!virtual_conn_manager_->GetConnectionData(virtual_conn)) {
- virtual_conn_manager_->AddConnection(virtual_conn,
- VirtualConnection::AssociatedData{});
+ if (!virtual_conn_router_->GetConnectionData(virtual_conn)) {
+ virtual_conn_router_->AddConnection(virtual_conn,
+ VirtualConnection::AssociatedData{});
}
virtual_conn_router_->Send(std::move(virtual_conn),
diff --git a/cast/sender/cast_platform_client.h b/cast/sender/cast_platform_client.h
index 41ad7fc7..8ea9a99a 100644
--- a/cast/sender/cast_platform_client.h
+++ b/cast/sender/cast_platform_client.h
@@ -7,7 +7,9 @@
#include <functional>
#include <map>
+#include <memory>
#include <string>
+#include <vector>
#include "absl/types/optional.h"
#include "cast/common/channel/cast_message_handler.h"
@@ -19,7 +21,6 @@ namespace openscreen {
namespace cast {
struct ServiceInfo;
-class VirtualConnectionManager;
class VirtualConnectionRouter;
// This class handles Cast messages that generally relate to the "platform", in
@@ -35,7 +36,6 @@ class CastPlatformClient final : public CastMessageHandler {
std::function<void(const std::string& app_id, AppAvailabilityResult)>;
CastPlatformClient(VirtualConnectionRouter* router,
- VirtualConnectionManager* manager,
ClockNowFunctionPtr clock,
TaskRunner* task_runner);
~CastPlatformClient() override;
@@ -82,7 +82,6 @@ class CastPlatformClient final : public CastMessageHandler {
const std::string sender_id_;
VirtualConnectionRouter* const virtual_conn_router_;
- VirtualConnectionManager* const virtual_conn_manager_;
std::map<std::string /* device_id */, int> socket_id_by_device_id_;
std::map<std::string /* device_id */, PendingRequests>
pending_requests_by_device_id_;
diff --git a/cast/sender/cast_platform_client_unittest.cc b/cast/sender/cast_platform_client_unittest.cc
index eaaf5c6e..ae721a15 100644
--- a/cast/sender/cast_platform_client_unittest.cc
+++ b/cast/sender/cast_platform_client_unittest.cc
@@ -4,9 +4,10 @@
#include "cast/sender/cast_platform_client.h"
+#include <utility>
+
#include "cast/common/channel/testing/fake_cast_socket.h"
#include "cast/common/channel/testing/mock_socket_error_handler.h"
-#include "cast/common/channel/virtual_connection_manager.h"
#include "cast/common/channel/virtual_connection_router.h"
#include "cast/common/public/service_info.h"
#include "cast/sender/testing/test_helpers.h"
@@ -46,12 +47,10 @@ class CastPlatformClientTest : public ::testing::Test {
FakeCastSocketPair fake_cast_socket_pair_;
CastSocket* socket_ = nullptr;
MockSocketErrorHandler mock_error_handler_;
- VirtualConnectionManager manager_;
- VirtualConnectionRouter router_{&manager_};
+ VirtualConnectionRouter router_;
FakeClock clock_{Clock::now()};
FakeTaskRunner task_runner_{&clock_};
- CastPlatformClient platform_client_{&router_, &manager_, &FakeClock::now,
- &task_runner_};
+ CastPlatformClient platform_client_{&router_, &FakeClock::now, &task_runner_};
ServiceInfo receiver_;
};
diff --git a/cast/standalone_sender/looping_file_cast_agent.cc b/cast/standalone_sender/looping_file_cast_agent.cc
index a7e99fdb..727941aa 100644
--- a/cast/standalone_sender/looping_file_cast_agent.cc
+++ b/cast/standalone_sender/looping_file_cast_agent.cc
@@ -24,10 +24,8 @@ using DeviceMediaPolicy = SenderSocketFactory::DeviceMediaPolicy;
LoopingFileCastAgent::LoopingFileCastAgent(TaskRunner* task_runner)
: task_runner_(task_runner) {
- router_ = MakeSerialDelete<VirtualConnectionRouter>(task_runner_,
- &connection_manager_);
message_port_ =
- MakeSerialDelete<CastSocketMessagePort>(task_runner_, router_.get());
+ MakeSerialDelete<CastSocketMessagePort>(task_runner_, &router_);
socket_factory_ =
MakeSerialDelete<SenderSocketFactory>(task_runner_, this, task_runner_);
connection_factory_ = SerialDeletePtr<TlsConnectionFactory>(
@@ -49,7 +47,7 @@ void LoopingFileCastAgent::Connect(ConnectionSettings settings) {
task_runner_->PostTask([this, policy] {
wake_lock_ = ScopedWakeLock::Create(task_runner_);
socket_factory_->Connect(connection_settings_->receiver_endpoint, policy,
- router_.get());
+ &router_);
});
}
@@ -75,7 +73,7 @@ void LoopingFileCastAgent::OnConnected(SenderSocketFactory* factory,
OSP_LOG_INFO << "Received connection from peer at: " << endpoint;
message_port_->SetSocket(socket->GetWeakPtr());
- router_->TakeSocket(this, std::move(socket));
+ router_.TakeSocket(this, std::move(socket));
CreateAndStartSession();
}
@@ -144,7 +142,7 @@ void LoopingFileCastAgent::StopCurrentSession() {
current_session_.reset();
environment_.reset();
file_sender_.reset();
- router_->CloseSocket(message_port_->GetSocketId());
+ router_.CloseSocket(message_port_->GetSocketId());
message_port_->SetSocket(nullptr);
}
diff --git a/cast/standalone_sender/looping_file_cast_agent.h b/cast/standalone_sender/looping_file_cast_agent.h
index abe91c96..eae7a8be 100644
--- a/cast/standalone_sender/looping_file_cast_agent.h
+++ b/cast/standalone_sender/looping_file_cast_agent.h
@@ -13,7 +13,6 @@
#include "absl/types/optional.h"
#include "cast/common/channel/cast_socket_message_port.h"
-#include "cast/common/channel/virtual_connection_manager.h"
#include "cast/common/channel/virtual_connection_router.h"
#include "cast/common/public/cast_socket.h"
#include "cast/sender/public/sender_socket_factory.h"
@@ -94,9 +93,8 @@ class LoopingFileCastAgent final
void StopCurrentSession();
// Member variables set as part of construction.
- VirtualConnectionManager connection_manager_;
TaskRunner* const task_runner_;
- SerialDeletePtr<VirtualConnectionRouter> router_;
+ VirtualConnectionRouter router_;
SerialDeletePtr<CastSocketMessagePort> message_port_;
SerialDeletePtr<SenderSocketFactory> socket_factory_;
SerialDeletePtr<TlsConnectionFactory> connection_factory_;
diff --git a/cast/test/cast_socket_e2e_test.cc b/cast/test/cast_socket_e2e_test.cc
index bfd1bbf9..1f3273fa 100644
--- a/cast/test/cast_socket_e2e_test.cc
+++ b/cast/test/cast_socket_e2e_test.cc
@@ -12,7 +12,6 @@
#include "cast/common/certificate/testing/test_helpers.h"
#include "cast/common/channel/connection_namespace_handler.h"
#include "cast/common/channel/message_util.h"
-#include "cast/common/channel/virtual_connection_manager.h"
#include "cast/common/channel/virtual_connection_router.h"
#include "cast/common/public/cast_socket.h"
#include "cast/receiver/channel/device_auth_namespace_handler.h"
@@ -141,8 +140,7 @@ class CastSocketE2ETest : public ::testing::Test {
std::chrono::milliseconds(0));
task_runner_ = PlatformClientPosix::GetInstance()->GetTaskRunner();
- sender_router_ = MakeSerialDelete<VirtualConnectionRouter>(
- task_runner_, &sender_vc_manager_);
+ sender_router_ = MakeSerialDelete<VirtualConnectionRouter>(task_runner_);
sender_client_ =
std::make_unique<StrictMock<SenderSocketsClient>>(sender_router_.get());
sender_factory_ = MakeSerialDelete<SenderSocketFactory>(
@@ -161,8 +159,7 @@ class CastSocketE2ETest : public ::testing::Test {
CastTrustStore::CreateInstanceForTest(credentials_.root_cert_der);
auth_handler_ = MakeSerialDelete<DeviceAuthNamespaceHandler>(
task_runner_, credentials_.provider.get());
- receiver_router_ = MakeSerialDelete<VirtualConnectionRouter>(
- task_runner_, &receiver_vc_manager_);
+ receiver_router_ = MakeSerialDelete<VirtualConnectionRouter>(task_runner_);
receiver_router_->AddHandlerForLocalId(kPlatformReceiverId,
auth_handler_.get());
receiver_client_ = std::make_unique<StrictMock<ReceiverSocketsClient>>(
@@ -258,14 +255,12 @@ class CastSocketE2ETest : public ::testing::Test {
TaskRunner* task_runner_;
// NOTE: Sender components.
- VirtualConnectionManager sender_vc_manager_;
SerialDeletePtr<VirtualConnectionRouter> sender_router_;
std::unique_ptr<StrictMock<SenderSocketsClient>> sender_client_;
SerialDeletePtr<SenderSocketFactory> sender_factory_;
SerialDeletePtr<TlsConnectionFactory> sender_tls_factory_;
// NOTE: Receiver components.
- VirtualConnectionManager receiver_vc_manager_;
SerialDeletePtr<VirtualConnectionRouter> receiver_router_;
GeneratedCredentials credentials_;
SerialDeletePtr<DeviceAuthNamespaceHandler> auth_handler_;
diff --git a/cast/test/device_auth_test.cc b/cast/test/device_auth_test.cc
index 36a9fc8a..0776e57e 100644
--- a/cast/test/device_auth_test.cc
+++ b/cast/test/device_auth_test.cc
@@ -9,7 +9,6 @@
#include "cast/common/channel/proto/cast_channel.pb.h"
#include "cast/common/channel/testing/fake_cast_socket.h"
#include "cast/common/channel/testing/mock_socket_error_handler.h"
-#include "cast/common/channel/virtual_connection_manager.h"
#include "cast/common/channel/virtual_connection_router.h"
#include "cast/common/public/cast_socket.h"
#include "cast/receiver/channel/device_auth_namespace_handler.h"
@@ -127,8 +126,7 @@ class DeviceAuthTest : public ::testing::Test {
CastSocket* socket_;
StaticCredentialsProvider creds_;
- VirtualConnectionManager manager_;
- VirtualConnectionRouter router_{&manager_};
+ VirtualConnectionRouter router_;
DeviceAuthNamespaceHandler auth_handler_{&creds_};
};