diff options
-rw-r--r-- | cast/common/channel/cast_socket.cc | 21 | ||||
-rw-r--r-- | cast/common/channel/cast_socket.h | 8 | ||||
-rw-r--r-- | cast/common/channel/virtual_connection.h | 2 | ||||
-rw-r--r-- | cast/common/channel/virtual_connection_manager.cc | 2 | ||||
-rw-r--r-- | cast/common/channel/virtual_connection_manager.h | 4 | ||||
-rw-r--r-- | cast/common/channel/virtual_connection_router.cc | 9 | ||||
-rw-r--r-- | cast/common/channel/virtual_connection_router.h | 4 | ||||
-rw-r--r-- | cast/common/channel/virtual_connection_router_unittest.cc | 11 | ||||
-rw-r--r-- | cast/sender/channel/sender_socket_factory.cc | 4 | ||||
-rw-r--r-- | cast/sender/channel/sender_socket_factory.h | 4 |
10 files changed, 47 insertions, 22 deletions
diff --git a/cast/common/channel/cast_socket.cc b/cast/common/channel/cast_socket.cc index 1b309332..ebd22540 100644 --- a/cast/common/channel/cast_socket.cc +++ b/cast/common/channel/cast_socket.cc @@ -4,7 +4,7 @@ #include "cast/common/channel/cast_socket.h" -#include <atomic> +#include <mutex> #include "cast/common/channel/message_framer.h" #include "util/logging.h" @@ -15,14 +15,24 @@ namespace cast { using ::cast::channel::CastMessage; using message_serialization::DeserializeResult; -uint32_t GetNextSocketId() { - static std::atomic<uint32_t> id(1); - return id++; +static std::vector<int32_t> g_free_ids; +static std::mutex g_free_ids_mutex; + +int32_t GetNextSocketId() { + static int32_t id{1}; + std::lock_guard<std::mutex> lock(g_free_ids_mutex); + if (g_free_ids.empty()) { + return id++; + } else { + int32_t id = g_free_ids.back(); + g_free_ids.pop_back(); + return id; + } } CastSocket::CastSocket(std::unique_ptr<TlsConnection> connection, Client* client, - uint32_t socket_id) + int32_t socket_id) : connection_(std::move(connection)), client_(client), socket_id_(socket_id) { @@ -32,6 +42,7 @@ CastSocket::CastSocket(std::unique_ptr<TlsConnection> connection, CastSocket::~CastSocket() { connection_->SetClient(nullptr); + g_free_ids.push_back(socket_id_); } Error CastSocket::SendMessage(const CastMessage& message) { diff --git a/cast/common/channel/cast_socket.h b/cast/common/channel/cast_socket.h index 5032db9d..dfdfe633 100644 --- a/cast/common/channel/cast_socket.h +++ b/cast/common/channel/cast_socket.h @@ -15,7 +15,7 @@ namespace openscreen { namespace cast { -uint32_t GetNextSocketId(); +int32_t GetNextSocketId(); // Represents a simple message-oriented socket for communicating with the Cast // V2 protocol. It isn't thread-safe, so it should only be used on the same @@ -35,7 +35,7 @@ class CastSocket : public TlsConnection::Client { CastSocket(std::unique_ptr<TlsConnection> connection, Client* client, - uint32_t socket_id); + int32_t socket_id); ~CastSocket(); // Sends |message| immediately unless the underlying TLS connection is @@ -48,7 +48,7 @@ class CastSocket : public TlsConnection::Client { std::array<uint8_t, 2> GetSanitizedIpAddress(); - uint32_t socket_id() const { return socket_id_; } + int32_t socket_id() const { return socket_id_; } // TlsConnection::Client overrides. void OnError(TlsConnection* connection, Error error) override; @@ -62,7 +62,7 @@ class CastSocket : public TlsConnection::Client { const std::unique_ptr<TlsConnection> connection_; Client* client_; // May never be null. - const uint32_t socket_id_; + const int32_t socket_id_; std::vector<uint8_t> read_buffer_; State state_ = State::kOpen; }; diff --git a/cast/common/channel/virtual_connection.h b/cast/common/channel/virtual_connection.h index a4045c92..31041bf8 100644 --- a/cast/common/channel/virtual_connection.h +++ b/cast/common/channel/virtual_connection.h @@ -99,7 +99,7 @@ struct VirtualConnection { // app on the device. std::string local_id; std::string peer_id; - uint32_t socket_id; + int32_t socket_id; }; inline bool operator==(const VirtualConnection& a, const VirtualConnection& b) { diff --git a/cast/common/channel/virtual_connection_manager.cc b/cast/common/channel/virtual_connection_manager.cc index cda15ab8..d9de82b4 100644 --- a/cast/common/channel/virtual_connection_manager.cc +++ b/cast/common/channel/virtual_connection_manager.cc @@ -81,7 +81,7 @@ size_t VirtualConnectionManager::RemoveConnectionsByLocalId( } size_t VirtualConnectionManager::RemoveConnectionsBySocketId( - uint32_t socket_id, + int32_t socket_id, VirtualConnection::CloseReason reason) { auto entry = connections_.find(socket_id); if (entry == connections_.end()) { diff --git a/cast/common/channel/virtual_connection_manager.h b/cast/common/channel/virtual_connection_manager.h index a4d1823b..e8b1b708 100644 --- a/cast/common/channel/virtual_connection_manager.h +++ b/cast/common/channel/virtual_connection_manager.h @@ -32,7 +32,7 @@ class VirtualConnectionManager { // Returns the number of connections removed. size_t RemoveConnectionsByLocalId(const std::string& local_id, VirtualConnection::CloseReason reason); - size_t RemoveConnectionsBySocketId(uint32_t socket_id, + size_t RemoveConnectionsBySocketId(int32_t socket_id, VirtualConnection::CloseReason reason); // Returns the AssociatedData for |virtual_connection| if a connection exists, @@ -50,7 +50,7 @@ class VirtualConnectionManager { VirtualConnection::AssociatedData data; }; - std::map<uint32_t /* socket_id */, + std::map<int32_t /* socket_id */, std::multimap<std::string /* local_id */, VCTail>> connections_; }; diff --git a/cast/common/channel/virtual_connection_router.cc b/cast/common/channel/virtual_connection_router.cc index fbec2054..47650447 100644 --- a/cast/common/channel/virtual_connection_router.cc +++ b/cast/common/channel/virtual_connection_router.cc @@ -37,14 +37,16 @@ bool VirtualConnectionRouter::RemoveHandlerForLocalId( void VirtualConnectionRouter::TakeSocket(SocketErrorHandler* error_handler, std::unique_ptr<CastSocket> socket) { - uint32_t id = socket->socket_id(); + int32_t id = socket->socket_id(); socket->SetClient(this); sockets_.emplace(id, SocketWithHandler{std::move(socket), error_handler}); } -void VirtualConnectionRouter::CloseSocket(uint32_t id) { +void VirtualConnectionRouter::CloseSocket(int32_t id) { auto it = sockets_.find(id); if (it != sockets_.end()) { + vc_manager_->RemoveConnectionsBySocketId( + id, VirtualConnection::kTransportClosed); std::unique_ptr<CastSocket> socket = std::move(it->second.socket); SocketErrorHandler* error_handler = it->second.error_handler; sockets_.erase(it); @@ -69,9 +71,10 @@ Error VirtualConnectionRouter::SendMessage(VirtualConnection virtual_conn, } void VirtualConnectionRouter::OnError(CastSocket* socket, Error error) { - uint32_t id = socket->socket_id(); + int32_t id = socket->socket_id(); auto it = sockets_.find(id); if (it != sockets_.end()) { + vc_manager_->RemoveConnectionsBySocketId(id, VirtualConnection::kUnknown); std::unique_ptr<CastSocket> socket_owned = std::move(it->second.socket); SocketErrorHandler* error_handler = it->second.error_handler; sockets_.erase(it); diff --git a/cast/common/channel/virtual_connection_router.h b/cast/common/channel/virtual_connection_router.h index 57ddf904..41f71612 100644 --- a/cast/common/channel/virtual_connection_router.h +++ b/cast/common/channel/virtual_connection_router.h @@ -57,7 +57,7 @@ class VirtualConnectionRouter final : public CastSocket::Client { // |error_handler| must live until either its OnError or OnClose is called. void TakeSocket(SocketErrorHandler* error_handler, std::unique_ptr<CastSocket> socket); - void CloseSocket(uint32_t id); + void CloseSocket(int32_t id); Error SendMessage(VirtualConnection virtual_conn, ::cast::channel::CastMessage message); @@ -74,7 +74,7 @@ class VirtualConnectionRouter final : public CastSocket::Client { }; VirtualConnectionManager* const vc_manager_; - std::map<uint32_t, SocketWithHandler> sockets_; + std::map<int32_t, SocketWithHandler> sockets_; std::map<std::string /* local_id */, CastMessageHandler*> endpoints_; }; diff --git a/cast/common/channel/virtual_connection_router_unittest.cc b/cast/common/channel/virtual_connection_router_unittest.cc index 74c6a2cd..624c10b2 100644 --- a/cast/common/channel/virtual_connection_router_unittest.cc +++ b/cast/common/channel/virtual_connection_router_unittest.cc @@ -118,5 +118,16 @@ TEST_F(VirtualConnectionRouterTest, SendMessage) { std::move(message)); } +TEST_F(VirtualConnectionRouterTest, CloseSocketRemovesVirtualConnections) { + manager_.AddConnection( + VirtualConnection{"receiver-1234", "sender-4321", socket_->socket_id()}, + {}); + + int32_t id = socket_->socket_id(); + router_.CloseSocket(id); + EXPECT_FALSE(manager_.GetConnectionData( + VirtualConnection{"receiver-1234", "sender-4321", id})); +} + } // namespace cast } // namespace openscreen diff --git a/cast/sender/channel/sender_socket_factory.cc b/cast/sender/channel/sender_socket_factory.cc index 8dff73c4..6f8d9a51 100644 --- a/cast/sender/channel/sender_socket_factory.cc +++ b/cast/sender/channel/sender_socket_factory.cc @@ -16,11 +16,11 @@ namespace openscreen { namespace cast { bool operator<(const std::unique_ptr<SenderSocketFactory::PendingAuth>& a, - uint32_t b) { + int32_t b) { return a && a->socket->socket_id() < b; } -bool operator<(uint32_t a, +bool operator<(int32_t a, const std::unique_ptr<SenderSocketFactory::PendingAuth>& b) { return b && a < b->socket->socket_id(); } diff --git a/cast/sender/channel/sender_socket_factory.h b/cast/sender/channel/sender_socket_factory.h index 9cb31066..094530dd 100644 --- a/cast/sender/channel/sender_socket_factory.h +++ b/cast/sender/channel/sender_socket_factory.h @@ -79,8 +79,8 @@ class SenderSocketFactory final : public TlsConnectionFactory::Client, bssl::UniquePtr<X509> peer_cert; }; - friend bool operator<(const std::unique_ptr<PendingAuth>& a, uint32_t b); - friend bool operator<(uint32_t a, const std::unique_ptr<PendingAuth>& b); + friend bool operator<(const std::unique_ptr<PendingAuth>& a, int32_t b); + friend bool operator<(int32_t a, const std::unique_ptr<PendingAuth>& b); std::vector<PendingConnection>::iterator FindPendingConnection( const IPEndpoint& endpoint); |