aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cast/common/channel/cast_socket.cc21
-rw-r--r--cast/common/channel/cast_socket.h8
-rw-r--r--cast/common/channel/virtual_connection.h2
-rw-r--r--cast/common/channel/virtual_connection_manager.cc2
-rw-r--r--cast/common/channel/virtual_connection_manager.h4
-rw-r--r--cast/common/channel/virtual_connection_router.cc9
-rw-r--r--cast/common/channel/virtual_connection_router.h4
-rw-r--r--cast/common/channel/virtual_connection_router_unittest.cc11
-rw-r--r--cast/sender/channel/sender_socket_factory.cc4
-rw-r--r--cast/sender/channel/sender_socket_factory.h4
10 files changed, 47 insertions, 22 deletions
diff --git a/cast/common/channel/cast_socket.cc b/cast/common/channel/cast_socket.cc
index 1b309332..ebd22540 100644
--- a/cast/common/channel/cast_socket.cc
+++ b/cast/common/channel/cast_socket.cc
@@ -4,7 +4,7 @@
#include "cast/common/channel/cast_socket.h"
-#include <atomic>
+#include <mutex>
#include "cast/common/channel/message_framer.h"
#include "util/logging.h"
@@ -15,14 +15,24 @@ namespace cast {
using ::cast::channel::CastMessage;
using message_serialization::DeserializeResult;
-uint32_t GetNextSocketId() {
- static std::atomic<uint32_t> id(1);
- return id++;
+static std::vector<int32_t> g_free_ids;
+static std::mutex g_free_ids_mutex;
+
+int32_t GetNextSocketId() {
+ static int32_t id{1};
+ std::lock_guard<std::mutex> lock(g_free_ids_mutex);
+ if (g_free_ids.empty()) {
+ return id++;
+ } else {
+ int32_t id = g_free_ids.back();
+ g_free_ids.pop_back();
+ return id;
+ }
}
CastSocket::CastSocket(std::unique_ptr<TlsConnection> connection,
Client* client,
- uint32_t socket_id)
+ int32_t socket_id)
: connection_(std::move(connection)),
client_(client),
socket_id_(socket_id) {
@@ -32,6 +42,7 @@ CastSocket::CastSocket(std::unique_ptr<TlsConnection> connection,
CastSocket::~CastSocket() {
connection_->SetClient(nullptr);
+ g_free_ids.push_back(socket_id_);
}
Error CastSocket::SendMessage(const CastMessage& message) {
diff --git a/cast/common/channel/cast_socket.h b/cast/common/channel/cast_socket.h
index 5032db9d..dfdfe633 100644
--- a/cast/common/channel/cast_socket.h
+++ b/cast/common/channel/cast_socket.h
@@ -15,7 +15,7 @@
namespace openscreen {
namespace cast {
-uint32_t GetNextSocketId();
+int32_t GetNextSocketId();
// Represents a simple message-oriented socket for communicating with the Cast
// V2 protocol. It isn't thread-safe, so it should only be used on the same
@@ -35,7 +35,7 @@ class CastSocket : public TlsConnection::Client {
CastSocket(std::unique_ptr<TlsConnection> connection,
Client* client,
- uint32_t socket_id);
+ int32_t socket_id);
~CastSocket();
// Sends |message| immediately unless the underlying TLS connection is
@@ -48,7 +48,7 @@ class CastSocket : public TlsConnection::Client {
std::array<uint8_t, 2> GetSanitizedIpAddress();
- uint32_t socket_id() const { return socket_id_; }
+ int32_t socket_id() const { return socket_id_; }
// TlsConnection::Client overrides.
void OnError(TlsConnection* connection, Error error) override;
@@ -62,7 +62,7 @@ class CastSocket : public TlsConnection::Client {
const std::unique_ptr<TlsConnection> connection_;
Client* client_; // May never be null.
- const uint32_t socket_id_;
+ const int32_t socket_id_;
std::vector<uint8_t> read_buffer_;
State state_ = State::kOpen;
};
diff --git a/cast/common/channel/virtual_connection.h b/cast/common/channel/virtual_connection.h
index a4045c92..31041bf8 100644
--- a/cast/common/channel/virtual_connection.h
+++ b/cast/common/channel/virtual_connection.h
@@ -99,7 +99,7 @@ struct VirtualConnection {
// app on the device.
std::string local_id;
std::string peer_id;
- uint32_t socket_id;
+ int32_t socket_id;
};
inline bool operator==(const VirtualConnection& a, const VirtualConnection& b) {
diff --git a/cast/common/channel/virtual_connection_manager.cc b/cast/common/channel/virtual_connection_manager.cc
index cda15ab8..d9de82b4 100644
--- a/cast/common/channel/virtual_connection_manager.cc
+++ b/cast/common/channel/virtual_connection_manager.cc
@@ -81,7 +81,7 @@ size_t VirtualConnectionManager::RemoveConnectionsByLocalId(
}
size_t VirtualConnectionManager::RemoveConnectionsBySocketId(
- uint32_t socket_id,
+ int32_t socket_id,
VirtualConnection::CloseReason reason) {
auto entry = connections_.find(socket_id);
if (entry == connections_.end()) {
diff --git a/cast/common/channel/virtual_connection_manager.h b/cast/common/channel/virtual_connection_manager.h
index a4d1823b..e8b1b708 100644
--- a/cast/common/channel/virtual_connection_manager.h
+++ b/cast/common/channel/virtual_connection_manager.h
@@ -32,7 +32,7 @@ class VirtualConnectionManager {
// Returns the number of connections removed.
size_t RemoveConnectionsByLocalId(const std::string& local_id,
VirtualConnection::CloseReason reason);
- size_t RemoveConnectionsBySocketId(uint32_t socket_id,
+ size_t RemoveConnectionsBySocketId(int32_t socket_id,
VirtualConnection::CloseReason reason);
// Returns the AssociatedData for |virtual_connection| if a connection exists,
@@ -50,7 +50,7 @@ class VirtualConnectionManager {
VirtualConnection::AssociatedData data;
};
- std::map<uint32_t /* socket_id */,
+ std::map<int32_t /* socket_id */,
std::multimap<std::string /* local_id */, VCTail>>
connections_;
};
diff --git a/cast/common/channel/virtual_connection_router.cc b/cast/common/channel/virtual_connection_router.cc
index fbec2054..47650447 100644
--- a/cast/common/channel/virtual_connection_router.cc
+++ b/cast/common/channel/virtual_connection_router.cc
@@ -37,14 +37,16 @@ bool VirtualConnectionRouter::RemoveHandlerForLocalId(
void VirtualConnectionRouter::TakeSocket(SocketErrorHandler* error_handler,
std::unique_ptr<CastSocket> socket) {
- uint32_t id = socket->socket_id();
+ int32_t id = socket->socket_id();
socket->SetClient(this);
sockets_.emplace(id, SocketWithHandler{std::move(socket), error_handler});
}
-void VirtualConnectionRouter::CloseSocket(uint32_t id) {
+void VirtualConnectionRouter::CloseSocket(int32_t id) {
auto it = sockets_.find(id);
if (it != sockets_.end()) {
+ vc_manager_->RemoveConnectionsBySocketId(
+ id, VirtualConnection::kTransportClosed);
std::unique_ptr<CastSocket> socket = std::move(it->second.socket);
SocketErrorHandler* error_handler = it->second.error_handler;
sockets_.erase(it);
@@ -69,9 +71,10 @@ Error VirtualConnectionRouter::SendMessage(VirtualConnection virtual_conn,
}
void VirtualConnectionRouter::OnError(CastSocket* socket, Error error) {
- uint32_t id = socket->socket_id();
+ int32_t id = socket->socket_id();
auto it = sockets_.find(id);
if (it != sockets_.end()) {
+ vc_manager_->RemoveConnectionsBySocketId(id, VirtualConnection::kUnknown);
std::unique_ptr<CastSocket> socket_owned = std::move(it->second.socket);
SocketErrorHandler* error_handler = it->second.error_handler;
sockets_.erase(it);
diff --git a/cast/common/channel/virtual_connection_router.h b/cast/common/channel/virtual_connection_router.h
index 57ddf904..41f71612 100644
--- a/cast/common/channel/virtual_connection_router.h
+++ b/cast/common/channel/virtual_connection_router.h
@@ -57,7 +57,7 @@ class VirtualConnectionRouter final : public CastSocket::Client {
// |error_handler| must live until either its OnError or OnClose is called.
void TakeSocket(SocketErrorHandler* error_handler,
std::unique_ptr<CastSocket> socket);
- void CloseSocket(uint32_t id);
+ void CloseSocket(int32_t id);
Error SendMessage(VirtualConnection virtual_conn,
::cast::channel::CastMessage message);
@@ -74,7 +74,7 @@ class VirtualConnectionRouter final : public CastSocket::Client {
};
VirtualConnectionManager* const vc_manager_;
- std::map<uint32_t, SocketWithHandler> sockets_;
+ std::map<int32_t, SocketWithHandler> sockets_;
std::map<std::string /* local_id */, CastMessageHandler*> endpoints_;
};
diff --git a/cast/common/channel/virtual_connection_router_unittest.cc b/cast/common/channel/virtual_connection_router_unittest.cc
index 74c6a2cd..624c10b2 100644
--- a/cast/common/channel/virtual_connection_router_unittest.cc
+++ b/cast/common/channel/virtual_connection_router_unittest.cc
@@ -118,5 +118,16 @@ TEST_F(VirtualConnectionRouterTest, SendMessage) {
std::move(message));
}
+TEST_F(VirtualConnectionRouterTest, CloseSocketRemovesVirtualConnections) {
+ manager_.AddConnection(
+ VirtualConnection{"receiver-1234", "sender-4321", socket_->socket_id()},
+ {});
+
+ int32_t id = socket_->socket_id();
+ router_.CloseSocket(id);
+ EXPECT_FALSE(manager_.GetConnectionData(
+ VirtualConnection{"receiver-1234", "sender-4321", id}));
+}
+
} // namespace cast
} // namespace openscreen
diff --git a/cast/sender/channel/sender_socket_factory.cc b/cast/sender/channel/sender_socket_factory.cc
index 8dff73c4..6f8d9a51 100644
--- a/cast/sender/channel/sender_socket_factory.cc
+++ b/cast/sender/channel/sender_socket_factory.cc
@@ -16,11 +16,11 @@ namespace openscreen {
namespace cast {
bool operator<(const std::unique_ptr<SenderSocketFactory::PendingAuth>& a,
- uint32_t b) {
+ int32_t b) {
return a && a->socket->socket_id() < b;
}
-bool operator<(uint32_t a,
+bool operator<(int32_t a,
const std::unique_ptr<SenderSocketFactory::PendingAuth>& b) {
return b && a < b->socket->socket_id();
}
diff --git a/cast/sender/channel/sender_socket_factory.h b/cast/sender/channel/sender_socket_factory.h
index 9cb31066..094530dd 100644
--- a/cast/sender/channel/sender_socket_factory.h
+++ b/cast/sender/channel/sender_socket_factory.h
@@ -79,8 +79,8 @@ class SenderSocketFactory final : public TlsConnectionFactory::Client,
bssl::UniquePtr<X509> peer_cert;
};
- friend bool operator<(const std::unique_ptr<PendingAuth>& a, uint32_t b);
- friend bool operator<(uint32_t a, const std::unique_ptr<PendingAuth>& b);
+ friend bool operator<(const std::unique_ptr<PendingAuth>& a, int32_t b);
+ friend bool operator<(int32_t a, const std::unique_ptr<PendingAuth>& b);
std::vector<PendingConnection>::iterator FindPendingConnection(
const IPEndpoint& endpoint);