From 3aa23228350c41da8b42d199a4904bfd0144a75e Mon Sep 17 00:00:00 2001 From: "mark a. foltz" Date: Fri, 14 Feb 2020 11:04:33 -0800 Subject: [Open Screen] Remove use of atomics in module code. Since OSL APIs are expected to be run on the same sequence, there isn't a need for atomics outside of code in platform/impl and util/ that deal with tasks and threads. This also simplifies the CastSocket ctor to allocate IDs in a fixed sequence (instead of maintaining a list of free IDs), and converts socket IDs from int32_t to int. Change-Id: Icaceabdff02ff205ac59b340413b1e7be06dcb29 Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/2055143 Commit-Queue: mark a. foltz Reviewed-by: Brandon Tolsch --- cast/common/channel/cast_socket.cc | 25 ++++------------------ cast/common/channel/cast_socket.h | 12 +++++------ cast/common/channel/testing/fake_cast_socket.h | 10 ++++----- cast/common/channel/virtual_connection.h | 2 +- cast/common/channel/virtual_connection_manager.cc | 2 +- cast/common/channel/virtual_connection_manager.h | 4 ++-- cast/common/channel/virtual_connection_router.cc | 6 +++--- cast/common/channel/virtual_connection_router.h | 4 ++-- .../channel/virtual_connection_router_unittest.cc | 2 +- cast/receiver/channel/receiver_socket_factory.cc | 4 ++-- cast/sender/channel/sender_socket_factory.cc | 8 +++---- cast/sender/channel/sender_socket_factory.h | 4 ++-- discovery/mdns/mdns_records.cc | 3 +-- osp/public/service_listener.h | 3 +-- osp/public/service_publisher.h | 3 +-- util/crypto/certificate_utils.cc | 3 +-- 16 files changed, 36 insertions(+), 59 deletions(-) diff --git a/cast/common/channel/cast_socket.cc b/cast/common/channel/cast_socket.cc index ebd22540..ba799658 100644 --- a/cast/common/channel/cast_socket.cc +++ b/cast/common/channel/cast_socket.cc @@ -4,8 +4,6 @@ #include "cast/common/channel/cast_socket.h" -#include - #include "cast/common/channel/message_framer.h" #include "util/logging.h" @@ -15,34 +13,17 @@ namespace cast { using ::cast::channel::CastMessage; using message_serialization::DeserializeResult; -static std::vector g_free_ids; -static std::mutex g_free_ids_mutex; - -int32_t GetNextSocketId() { - static int32_t id{1}; - std::lock_guard lock(g_free_ids_mutex); - if (g_free_ids.empty()) { - return id++; - } else { - int32_t id = g_free_ids.back(); - g_free_ids.pop_back(); - return id; - } -} - CastSocket::CastSocket(std::unique_ptr connection, - Client* client, - int32_t socket_id) + Client* client) : connection_(std::move(connection)), client_(client), - socket_id_(socket_id) { + socket_id_(g_next_socket_id_++) { OSP_DCHECK(client); connection_->SetClient(this); } CastSocket::~CastSocket() { connection_->SetClient(nullptr); - g_free_ids.push_back(socket_id_); } Error CastSocket::SendMessage(const CastMessage& message) { @@ -101,5 +82,7 @@ void CastSocket::OnRead(TlsConnection* connection, std::vector block) { client_->OnMessage(this, std::move(message_or_error.value().message)); } +int CastSocket::g_next_socket_id_ = 1; + } // namespace cast } // namespace openscreen diff --git a/cast/common/channel/cast_socket.h b/cast/common/channel/cast_socket.h index 35436b4c..550aa76e 100644 --- a/cast/common/channel/cast_socket.h +++ b/cast/common/channel/cast_socket.h @@ -15,8 +15,6 @@ namespace openscreen { namespace cast { -int32_t GetNextSocketId(); - // Represents a simple message-oriented socket for communicating with the Cast // V2 protocol. It isn't thread-safe, so it should only be used on the same // TaskRunner thread as its TlsConnection. @@ -33,9 +31,7 @@ class CastSocket : public TlsConnection::Client { ::cast::channel::CastMessage message) = 0; }; - CastSocket(std::unique_ptr connection, - Client* client, - int32_t socket_id); + CastSocket(std::unique_ptr connection, Client* client); ~CastSocket(); // Sends |message| immediately unless the underlying TLS connection is @@ -48,7 +44,7 @@ class CastSocket : public TlsConnection::Client { std::array GetSanitizedIpAddress(); - int32_t socket_id() const { return socket_id_; } + int socket_id() const { return socket_id_; } void set_audio_only(bool audio_only) { audio_only_ = audio_only; } bool audio_only() const { return audio_only_; } @@ -63,9 +59,11 @@ class CastSocket : public TlsConnection::Client { kError = false, }; + static int g_next_socket_id_; + const std::unique_ptr connection_; Client* client_; // May never be null. - const int32_t socket_id_; + const int socket_id_; bool audio_only_ = false; std::vector read_buffer_; State state_ = State::kOpen; diff --git a/cast/common/channel/testing/fake_cast_socket.h b/cast/common/channel/testing/fake_cast_socket.h index 0287b0b3..495bdecd 100644 --- a/cast/common/channel/testing/fake_cast_socket.h +++ b/cast/common/channel/testing/fake_cast_socket.h @@ -34,7 +34,7 @@ struct FakeCastSocket { remote(remote), moved_connection(std::make_unique(local, remote)), connection(moved_connection.get()), - socket(std::move(moved_connection), &mock_client, 1) {} + socket(std::move(moved_connection), &mock_client) {} IPEndpoint local; IPEndpoint remote; @@ -56,14 +56,14 @@ struct FakeCastSocketPair { auto moved_connection = std::make_unique<::testing::NiceMock>(local, remote); connection = moved_connection.get(); - socket = std::make_unique(std::move(moved_connection), - &mock_client, 1); + socket = + std::make_unique(std::move(moved_connection), &mock_client); auto moved_peer = std::make_unique<::testing::NiceMock>(remote, local); peer_connection = moved_peer.get(); - peer_socket = std::make_unique(std::move(moved_peer), - &mock_peer_client, 2); + peer_socket = + std::make_unique(std::move(moved_peer), &mock_peer_client); ON_CALL(*connection, Send(_, _)) .WillByDefault(Invoke([this](const void* data, size_t len) { diff --git a/cast/common/channel/virtual_connection.h b/cast/common/channel/virtual_connection.h index 31041bf8..04f3ba06 100644 --- a/cast/common/channel/virtual_connection.h +++ b/cast/common/channel/virtual_connection.h @@ -99,7 +99,7 @@ struct VirtualConnection { // app on the device. std::string local_id; std::string peer_id; - int32_t socket_id; + int socket_id; }; inline bool operator==(const VirtualConnection& a, const VirtualConnection& b) { diff --git a/cast/common/channel/virtual_connection_manager.cc b/cast/common/channel/virtual_connection_manager.cc index d9de82b4..86e06706 100644 --- a/cast/common/channel/virtual_connection_manager.cc +++ b/cast/common/channel/virtual_connection_manager.cc @@ -81,7 +81,7 @@ size_t VirtualConnectionManager::RemoveConnectionsByLocalId( } size_t VirtualConnectionManager::RemoveConnectionsBySocketId( - int32_t socket_id, + int socket_id, VirtualConnection::CloseReason reason) { auto entry = connections_.find(socket_id); if (entry == connections_.end()) { diff --git a/cast/common/channel/virtual_connection_manager.h b/cast/common/channel/virtual_connection_manager.h index e8b1b708..902d2a96 100644 --- a/cast/common/channel/virtual_connection_manager.h +++ b/cast/common/channel/virtual_connection_manager.h @@ -32,7 +32,7 @@ class VirtualConnectionManager { // Returns the number of connections removed. size_t RemoveConnectionsByLocalId(const std::string& local_id, VirtualConnection::CloseReason reason); - size_t RemoveConnectionsBySocketId(int32_t socket_id, + size_t RemoveConnectionsBySocketId(int socket_id, VirtualConnection::CloseReason reason); // Returns the AssociatedData for |virtual_connection| if a connection exists, @@ -50,7 +50,7 @@ class VirtualConnectionManager { VirtualConnection::AssociatedData data; }; - std::map> connections_; }; diff --git a/cast/common/channel/virtual_connection_router.cc b/cast/common/channel/virtual_connection_router.cc index 47650447..7ceace60 100644 --- a/cast/common/channel/virtual_connection_router.cc +++ b/cast/common/channel/virtual_connection_router.cc @@ -37,12 +37,12 @@ bool VirtualConnectionRouter::RemoveHandlerForLocalId( void VirtualConnectionRouter::TakeSocket(SocketErrorHandler* error_handler, std::unique_ptr socket) { - int32_t id = socket->socket_id(); + int id = socket->socket_id(); socket->SetClient(this); sockets_.emplace(id, SocketWithHandler{std::move(socket), error_handler}); } -void VirtualConnectionRouter::CloseSocket(int32_t id) { +void VirtualConnectionRouter::CloseSocket(int id) { auto it = sockets_.find(id); if (it != sockets_.end()) { vc_manager_->RemoveConnectionsBySocketId( @@ -71,7 +71,7 @@ Error VirtualConnectionRouter::SendMessage(VirtualConnection virtual_conn, } void VirtualConnectionRouter::OnError(CastSocket* socket, Error error) { - int32_t id = socket->socket_id(); + int id = socket->socket_id(); auto it = sockets_.find(id); if (it != sockets_.end()) { vc_manager_->RemoveConnectionsBySocketId(id, VirtualConnection::kUnknown); diff --git a/cast/common/channel/virtual_connection_router.h b/cast/common/channel/virtual_connection_router.h index 41f71612..375c24ff 100644 --- a/cast/common/channel/virtual_connection_router.h +++ b/cast/common/channel/virtual_connection_router.h @@ -57,7 +57,7 @@ class VirtualConnectionRouter final : public CastSocket::Client { // |error_handler| must live until either its OnError or OnClose is called. void TakeSocket(SocketErrorHandler* error_handler, std::unique_ptr socket); - void CloseSocket(int32_t id); + void CloseSocket(int id); Error SendMessage(VirtualConnection virtual_conn, ::cast::channel::CastMessage message); @@ -74,7 +74,7 @@ class VirtualConnectionRouter final : public CastSocket::Client { }; VirtualConnectionManager* const vc_manager_; - std::map sockets_; + std::map sockets_; std::map endpoints_; }; diff --git a/cast/common/channel/virtual_connection_router_unittest.cc b/cast/common/channel/virtual_connection_router_unittest.cc index 624c10b2..57e9fa1b 100644 --- a/cast/common/channel/virtual_connection_router_unittest.cc +++ b/cast/common/channel/virtual_connection_router_unittest.cc @@ -123,7 +123,7 @@ TEST_F(VirtualConnectionRouterTest, CloseSocketRemovesVirtualConnections) { VirtualConnection{"receiver-1234", "sender-4321", socket_->socket_id()}, {}); - int32_t id = socket_->socket_id(); + int id = socket_->socket_id(); router_.CloseSocket(id); EXPECT_FALSE(manager_.GetConnectionData( VirtualConnection{"receiver-1234", "sender-4321", id})); diff --git a/cast/receiver/channel/receiver_socket_factory.cc b/cast/receiver/channel/receiver_socket_factory.cc index f4276e4d..d5e3506c 100644 --- a/cast/receiver/channel/receiver_socket_factory.cc +++ b/cast/receiver/channel/receiver_socket_factory.cc @@ -23,8 +23,8 @@ void ReceiverSocketFactory::OnAccepted( std::vector der_x509_peer_cert, std::unique_ptr connection) { IPEndpoint endpoint = connection->GetRemoteEndpoint(); - auto socket = std::make_unique(std::move(connection), - socket_client_, GetNextSocketId()); + auto socket = + std::make_unique(std::move(connection), socket_client_); client_->OnConnected(this, endpoint, std::move(socket)); } diff --git a/cast/sender/channel/sender_socket_factory.cc b/cast/sender/channel/sender_socket_factory.cc index 8259685f..b2b036b1 100644 --- a/cast/sender/channel/sender_socket_factory.cc +++ b/cast/sender/channel/sender_socket_factory.cc @@ -17,11 +17,11 @@ namespace openscreen { namespace cast { bool operator<(const std::unique_ptr& a, - int32_t b) { + int b) { return a && a->socket->socket_id() < b; } -bool operator<(int32_t a, +bool operator<(int a, const std::unique_ptr& b) { return b && a < b->socket->socket_id(); } @@ -81,8 +81,8 @@ void SenderSocketFactory::OnConnected( return; } - auto socket = MakeSerialDelete( - task_runner_, std::move(connection), this, GetNextSocketId()); + auto socket = + MakeSerialDelete(task_runner_, std::move(connection), this); pending_auth_.emplace_back( new PendingAuth{endpoint, media_policy, std::move(socket), client, AuthContext::Create(), std::move(peer_cert.value())}); diff --git a/cast/sender/channel/sender_socket_factory.h b/cast/sender/channel/sender_socket_factory.h index 4b867e2c..7c788536 100644 --- a/cast/sender/channel/sender_socket_factory.h +++ b/cast/sender/channel/sender_socket_factory.h @@ -86,8 +86,8 @@ class SenderSocketFactory final : public TlsConnectionFactory::Client, bssl::UniquePtr peer_cert; }; - friend bool operator<(const std::unique_ptr& a, int32_t b); - friend bool operator<(int32_t a, const std::unique_ptr& b); + friend bool operator<(const std::unique_ptr& a, int b); + friend bool operator<(int a, const std::unique_ptr& b); std::vector::iterator FindPendingConnection( const IPEndpoint& endpoint); diff --git a/discovery/mdns/mdns_records.cc b/discovery/mdns/mdns_records.cc index 135db1d8..af75a6c6 100644 --- a/discovery/mdns/mdns_records.cc +++ b/discovery/mdns/mdns_records.cc @@ -4,7 +4,6 @@ #include "discovery/mdns/mdns_records.h" -#include #include #include "absl/strings/ascii.h" @@ -651,7 +650,7 @@ bool MdnsMessage::CanAddRecord(const MdnsRecord& record) { } uint16_t CreateMessageId() { - static std::atomic id(0); + static uint16_t id(0); return id++; } diff --git a/osp/public/service_listener.h b/osp/public/service_listener.h index 5d81536d..2a11e44c 100644 --- a/osp/public/service_listener.h +++ b/osp/public/service_listener.h @@ -5,7 +5,6 @@ #ifndef OSP_PUBLIC_SERVICE_LISTENER_H_ #define OSP_PUBLIC_SERVICE_LISTENER_H_ -#include #include #include #include @@ -150,7 +149,7 @@ class ServiceListener { protected: ServiceListener(); - std::atomic state_; + State state_; ServiceListenerError last_error_; std::vector observers_; diff --git a/osp/public/service_publisher.h b/osp/public/service_publisher.h index ef94042d..b31f59fc 100644 --- a/osp/public/service_publisher.h +++ b/osp/public/service_publisher.h @@ -5,7 +5,6 @@ #ifndef OSP_PUBLIC_SERVICE_PUBLISHER_H_ #define OSP_PUBLIC_SERVICE_PUBLISHER_H_ -#include #include #include #include @@ -145,7 +144,7 @@ class ServicePublisher { protected: explicit ServicePublisher(Observer* observer); - std::atomic state_; + State state_; ServicePublisherError last_error_; Observer* observer_; diff --git a/util/crypto/certificate_utils.cc b/util/crypto/certificate_utils.cc index 2e90321d..a9844600 100644 --- a/util/crypto/certificate_utils.cc +++ b/util/crypto/certificate_utils.cc @@ -12,7 +12,6 @@ #include #include -#include #include #include "util/crypto/openssl_util.h" @@ -46,7 +45,7 @@ bssl::UniquePtr CreateCertificateInternal( // Serial numbers must be unique for this session. As a pretend CA, we should // not issue certificates with the same serial number in the same session. - static std::atomic_int serial_number(1); + static int serial_number(1); if (ASN1_INTEGER_set(X509_get_serialNumber(certificate.get()), serial_number++) != 1) { return nullptr; -- cgit v1.2.3