diff options
-rw-r--r-- | cast/common/channel/cast_socket.cc | 31 | ||||
-rw-r--r-- | cast/common/channel/cast_socket.h | 8 | ||||
-rw-r--r-- | cast/common/channel/cast_socket_unittest.cc | 30 | ||||
-rw-r--r-- | cast/sender/channel/sender_socket_factory.cc | 5 | ||||
-rw-r--r-- | platform/BUILD.gn | 2 | ||||
-rw-r--r-- | platform/api/tls_connection.cc | 47 | ||||
-rw-r--r-- | platform/api/tls_connection.h | 46 | ||||
-rw-r--r-- | platform/api/tls_connection_factory.cc | 36 | ||||
-rw-r--r-- | platform/api/tls_connection_factory.h | 33 | ||||
-rw-r--r-- | platform/api/udp_socket.h | 4 | ||||
-rw-r--r-- | platform/impl/stream_socket_posix.cc | 4 | ||||
-rw-r--r-- | platform/impl/stream_socket_posix.h | 5 | ||||
-rw-r--r-- | platform/impl/tls_connection_factory_posix.cc | 93 | ||||
-rw-r--r-- | platform/impl/tls_connection_factory_posix.h | 25 | ||||
-rw-r--r-- | platform/impl/tls_connection_posix.cc | 127 | ||||
-rw-r--r-- | platform/impl/tls_connection_posix.h | 60 | ||||
-rw-r--r-- | platform/impl/udp_socket_posix.cc | 35 | ||||
-rw-r--r-- | platform/impl/udp_socket_posix.h | 8 | ||||
-rw-r--r-- | platform/impl/weak_ptr.h | 216 | ||||
-rw-r--r-- | platform/impl/weak_ptr_unittest.cc | 186 |
20 files changed, 732 insertions, 269 deletions
diff --git a/cast/common/channel/cast_socket.cc b/cast/common/channel/cast_socket.cc index ce35e54e..aa41bafe 100644 --- a/cast/common/channel/cast_socket.cc +++ b/cast/common/channel/cast_socket.cc @@ -28,10 +28,12 @@ CastSocket::CastSocket(std::unique_ptr<TlsConnection> connection, connection_(std::move(connection)), socket_id_(socket_id) { OSP_DCHECK(client); - connection_->set_client(this); + connection_->SetClient(this); } -CastSocket::~CastSocket() = default; +CastSocket::~CastSocket() { + connection_->SetClient(nullptr); +} Error CastSocket::SendMessage(const CastMessage& message) { if (state_ == State::kError) { @@ -53,6 +55,11 @@ Error CastSocket::SendMessage(const CastMessage& message) { return Error::Code::kNone; } +void CastSocket::SetClient(Client* client) { + OSP_DCHECK(client); + client_ = client; +} + void CastSocket::OnWriteBlocked(TlsConnection* connection) { if (state_ == State::kOpen) { state_ = State::kBlocked; @@ -60,14 +67,20 @@ void CastSocket::OnWriteBlocked(TlsConnection* connection) { } void CastSocket::OnWriteUnblocked(TlsConnection* connection) { - if (state_ == State::kBlocked) { - state_ = State::kOpen; - for (const auto& message : message_queue_) { - connection_->Write(message.data(), message.size()); - } - OSP_DCHECK(state_ == State::kOpen) << static_cast<int>(state_); - message_queue_.clear(); + if (state_ != State::kBlocked) { + return; + } + state_ = State::kOpen; + + // Attempt to write all messages that have been queued-up while the socket was + // blocked. Stop if the socket becomes blocked again, or an error occurs. + auto it = message_queue_.begin(); + for (const auto end = message_queue_.end(); + it != end && state_ == State::kOpen; ++it) { + // The following Write() could transition |state_| to kBlocked or kError. + connection_->Write(it->data(), it->size()); } + message_queue_.erase(message_queue_.begin(), it); } void CastSocket::OnError(TlsConnection* connection, Error error) { diff --git a/cast/common/channel/cast_socket.h b/cast/common/channel/cast_socket.h index 0c173ef5..6bd099b2 100644 --- a/cast/common/channel/cast_socket.h +++ b/cast/common/channel/cast_socket.h @@ -45,10 +45,8 @@ class CastSocket : public TlsConnection::Client { // write-blocked. Error SendMessage(const CastMessage& message); - void set_client(Client* client) { - OSP_DCHECK(client); - client_ = client; - } + void SetClient(Client* client); + uint32_t socket_id() const { return socket_id_; } // TlsConnection::Client overrides. @@ -64,7 +62,7 @@ class CastSocket : public TlsConnection::Client { kError, }; - Client* client_; + Client* client_; // May never be null. const std::unique_ptr<TlsConnection> connection_; std::vector<uint8_t> read_buffer_; const uint32_t socket_id_; diff --git a/cast/common/channel/cast_socket_unittest.cc b/cast/common/channel/cast_socket_unittest.cc index 56a69a88..453ffa52 100644 --- a/cast/common/channel/cast_socket_unittest.cc +++ b/cast/common/channel/cast_socket_unittest.cc @@ -29,25 +29,28 @@ class MockTlsConnection final : public TlsConnection { MockTlsConnection(TaskRunner* task_runner, IPEndpoint local_address, IPEndpoint remote_address) - : TlsConnection(task_runner), - local_address_(local_address), - remote_address_(remote_address) {} + : local_address_(local_address), remote_address_(remote_address) {} ~MockTlsConnection() override = default; + void SetClient(TlsConnection::Client* client) final { client_ = client; } + MOCK_METHOD(void, Write, (const void* data, size_t len)); - IPEndpoint local_address() const override { return local_address_; } - IPEndpoint remote_address() const override { return remote_address_; } + IPEndpoint GetLocalEndpoint() const override { return local_address_; } + IPEndpoint GetRemoteEndpoint() const override { return remote_address_; } - void OnWriteBlocked() { TlsConnection::OnWriteBlocked(); } - void OnWriteUnblocked() { TlsConnection::OnWriteUnblocked(); } - void OnError(Error error) { TlsConnection::OnError(error); } - void OnRead(std::vector<uint8_t> block) { TlsConnection::OnRead(block); } + void OnWriteBlocked() { client_->OnWriteBlocked(this); } + void OnWriteUnblocked() { client_->OnWriteUnblocked(this); } + void OnError(Error error) { client_->OnError(this, std::move(error)); } + void OnRead(std::vector<uint8_t> block) { + client_->OnRead(this, std::move(block)); + } private: const IPEndpoint local_address_; const IPEndpoint remote_address_; + TlsConnection::Client* client_ = nullptr; }; class MockCastSocketClient final : public CastSocket::Client { @@ -61,6 +64,7 @@ class MockCastSocketClient final : public CastSocket::Client { class CastSocketTest : public ::testing::Test { public: void SetUp() override { + connection_->SetClient(&socket_); message_.set_protocol_version(CastMessage::CASTV2_1_0); message_.set_source_id("source"); message_.set_destination_id("destination"); @@ -149,7 +153,6 @@ TEST_F(CastSocketTest, ReadChunkedMessage) { TEST_F(CastSocketTest, SendMessageWhileBlocked) { connection_->OnWriteBlocked(); - task_runner_.RunTasksUntilIdle(); EXPECT_CALL(*connection_, Write(_, _)).Times(0); ASSERT_TRUE(socket_.SendMessage(message_).ok()); @@ -161,18 +164,13 @@ TEST_F(CastSocketTest, SendMessageWhileBlocked) { reinterpret_cast<const uint8_t*>(data) + len)); })); connection_->OnWriteUnblocked(); - task_runner_.RunTasksUntilIdle(); - EXPECT_CALL(*connection_, Write(_, _)).Times(0); connection_->OnWriteBlocked(); - task_runner_.RunTasksUntilIdle(); connection_->OnWriteUnblocked(); - task_runner_.RunTasksUntilIdle(); } TEST_F(CastSocketTest, ErrorWhileEmptyingQueue) { connection_->OnWriteBlocked(); - task_runner_.RunTasksUntilIdle(); EXPECT_CALL(*connection_, Write(_, _)).Times(0); ASSERT_TRUE(socket_.SendMessage(message_).ok()); @@ -185,9 +183,7 @@ TEST_F(CastSocketTest, ErrorWhileEmptyingQueue) { connection_->OnError(Error::Code::kUnknownError); })); connection_->OnWriteUnblocked(); - task_runner_.RunTasksUntilIdle(); - EXPECT_CALL(*connection_, Write(_, _)).Times(0); ASSERT_FALSE(socket_.SendMessage(message_).ok()); } diff --git a/cast/sender/channel/sender_socket_factory.cc b/cast/sender/channel/sender_socket_factory.cc index 83e73370..811f07df 100644 --- a/cast/sender/channel/sender_socket_factory.cc +++ b/cast/sender/channel/sender_socket_factory.cc @@ -6,6 +6,7 @@ #include "cast/common/channel/cast_socket.h" #include "cast/sender/channel/message_util.h" +#include "platform/base/tls_connect_options.h" namespace cast { namespace channel { @@ -51,7 +52,7 @@ void SenderSocketFactory::OnConnected( TlsConnectionFactory* factory, X509* peer_cert, std::unique_ptr<TlsConnection> connection) { - const IPEndpoint& endpoint = connection->remote_address(); + const IPEndpoint& endpoint = connection->GetRemoteEndpoint(); auto it = FindPendingConnection(endpoint); if (it == pending_connections_.end()) { OSP_DLOG_ERROR << "TLS connection succeeded for unknown endpoint: " @@ -158,7 +159,7 @@ void SenderSocketFactory::OnMessage(CastSocket* socket, CastMessage message) { return; } - pending->socket->set_client(pending->client); + pending->socket->SetClient(pending->client); client_->OnConnected(this, pending->endpoint, std::move(pending->socket)); } diff --git a/platform/BUILD.gn b/platform/BUILD.gn index 622b2397..7efad0c8 100644 --- a/platform/BUILD.gn +++ b/platform/BUILD.gn @@ -68,6 +68,7 @@ source_set("platform") { "impl/time.cc", "impl/tls_write_buffer.cc", "impl/tls_write_buffer.h", + "impl/weak_ptr.h", ] if (is_linux) { @@ -161,6 +162,7 @@ source_set("unittests") { "api/socket_integration_unittest.cc", "impl/task_runner_unittest.cc", "impl/time_unittest.cc", + "impl/weak_ptr_unittest.cc", ] if (is_posix) { diff --git a/platform/api/tls_connection.cc b/platform/api/tls_connection.cc index e233a001..fc3a8c8e 100644 --- a/platform/api/tls_connection.cc +++ b/platform/api/tls_connection.cc @@ -4,54 +4,11 @@ #include "platform/api/tls_connection.h" -#include "platform/api/task_runner.h" - namespace openscreen { namespace platform { -void TlsConnection::OnWriteBlocked() { - if (!client_) { - return; - } - - task_runner_->PostTask([this]() { - // TODO(crbug.com/openscreen/71): |this| may be invalid at this point. - this->client_->OnWriteBlocked(this); - }); -} - -void TlsConnection::OnWriteUnblocked() { - if (!client_) { - return; - } - - task_runner_->PostTask([this]() { - // TODO(crbug.com/openscreen/71): |this| may be invalid at this point. - this->client_->OnWriteUnblocked(this); - }); -} - -void TlsConnection::OnError(Error error) { - if (!client_) { - return; - } - - task_runner_->PostTask([e = std::move(error), this]() mutable { - // TODO(crbug.com/openscreen/71): |this| may be invalid at this point. - this->client_->OnError(this, std::move(e)); - }); -} - -void TlsConnection::OnRead(std::vector<uint8_t> block) { - if (!client_) { - return; - } - - task_runner_->PostTask([b = std::move(block), this]() mutable { - // TODO(crbug.com/openscreen/71): |this| may be invalid at this point. - this->client_->OnRead(this, std::move(b)); - }); -} +TlsConnection::TlsConnection() = default; +TlsConnection::~TlsConnection() = default; } // namespace platform } // namespace openscreen diff --git a/platform/api/tls_connection.h b/platform/api/tls_connection.h index a1b2965c..7c2777fb 100644 --- a/platform/api/tls_connection.h +++ b/platform/api/tls_connection.h @@ -6,23 +6,17 @@ #define PLATFORM_API_TLS_CONNECTION_H_ #include <cstdint> -#include <memory> -#include <string> #include <vector> -#include "absl/types/optional.h" -#include "platform/api/network_interface.h" -#include "platform/api/task_runner.h" #include "platform/base/error.h" #include "platform/base/ip_address.h" -#include "platform/base/macros.h" namespace openscreen { namespace platform { class TlsConnection { public: - // Client callbacks are ran on the provided TaskRunner. + // Client callbacks are run via the TaskRunner used by TlsConnectionFactory. class Client { public: // Called when |connection| writing is blocked and unblocked, respectively. @@ -43,41 +37,25 @@ class TlsConnection { virtual ~Client() = default; }; + virtual ~TlsConnection(); + + // Sets the Client associated with this instance. This should be called as + // soon as the factory provides a new TlsConnection instance via + // TlsConnectionFactory::OnAccepted() or OnConnected(). Pass nullptr to unset + // the Client. + virtual void SetClient(Client* client) = 0; + // Sends a message. virtual void Write(const void* data, size_t len) = 0; // Get the local address. - virtual IPEndpoint local_address() const = 0; + virtual IPEndpoint GetLocalEndpoint() const = 0; // Get the connected remote address. - virtual IPEndpoint remote_address() const = 0; - - // Sets the client for this instance. - void set_client(Client* client) { client_ = client; } - - virtual ~TlsConnection() = default; + virtual IPEndpoint GetRemoteEndpoint() const = 0; protected: - explicit TlsConnection(TaskRunner* task_runner) : task_runner_(task_runner) {} - - // Called when |connection| writing is blocked and unblocked, respectively. - // This call will be proxied to the Client set for this TlsConnection. - void OnWriteBlocked(); - void OnWriteUnblocked(); - - // Called when |connection| experiences an error, such as a read error. This - // call will be proxied to the Client set for this TlsConnection. - void OnError(Error error); - - // Called when a |packet| arrives on |socket|. This call will be proxied to - // the Client set for this TlsConnection. - void OnRead(std::vector<uint8_t> message); - - private: - Client* client_; - TaskRunner* const task_runner_; - - OSP_DISALLOW_COPY_AND_ASSIGN(TlsConnection); + TlsConnection(); }; } // namespace platform diff --git a/platform/api/tls_connection_factory.cc b/platform/api/tls_connection_factory.cc index ee1f9abb..a3bfb287 100644 --- a/platform/api/tls_connection_factory.cc +++ b/platform/api/tls_connection_factory.cc @@ -7,40 +7,8 @@ namespace openscreen { namespace platform { -void TlsConnectionFactory::OnAccepted( - X509* peer_cert, - std::unique_ptr<TlsConnection> connection) { - task_runner_->PostTask( - [peer_cert, c = std::move(connection), this]() mutable { - // TODO(crbug.com/openscreen/71): |this| may be invalid at this point. - this->client_->OnAccepted(this, peer_cert, std::move(c)); - }); -} - -void TlsConnectionFactory::OnConnected( - X509* peer_cert, - std::unique_ptr<TlsConnection> connection) { - task_runner_->PostTask( - [peer_cert, c = std::move(connection), this]() mutable { - // TODO(crbug.com/openscreen/71): |this| may be invalid at this point. - this->client_->OnConnected(this, peer_cert, std::move(c)); - }); -} - -void TlsConnectionFactory::OnConnectionFailed( - const IPEndpoint& remote_address) { - task_runner_->PostTask([remote_address, this]() { - // TODO(crbug.com/openscreen/71): |this| may be invalid at this point. - this->client_->OnConnectionFailed(this, remote_address); - }); -} - -void TlsConnectionFactory::OnError(Error error) { - task_runner_->PostTask([e = std::move(error), this]() mutable { - // TODO(crbug.com/openscreen/71): |this| may be invalid at this point. - this->client_->OnError(this, std::move(e)); - }); -} +TlsConnectionFactory::TlsConnectionFactory() = default; +TlsConnectionFactory::~TlsConnectionFactory() = default; } // namespace platform } // namespace openscreen diff --git a/platform/api/tls_connection_factory.h b/platform/api/tls_connection_factory.h index 81fb76c6..e68d2fc6 100644 --- a/platform/api/tls_connection_factory.h +++ b/platform/api/tls_connection_factory.h @@ -8,18 +8,18 @@ #include <openssl/crypto.h> #include <memory> -#include <string> -#include "absl/types/optional.h" -#include "platform/api/tls_connection.h" #include "platform/base/ip_address.h" -#include "platform/base/tls_connect_options.h" -#include "platform/base/tls_credentials.h" -#include "platform/base/tls_listen_options.h" namespace openscreen { namespace platform { +class TaskRunner; +class TlsConnection; +struct TlsConnectOptions; +class TlsCredentials; +struct TlsListenOptions; + // We expect a single factory to be able to handle an arbitrary number of // calls using the same client and task runner. class TlsConnectionFactory { @@ -49,10 +49,8 @@ class TlsConnectionFactory { Client* client, TaskRunner* task_runner); - virtual ~TlsConnectionFactory() = default; + virtual ~TlsConnectionFactory(); - // TODO(jophba, rwkeane): Determine how to handle multiple connection attempts - // to the same remote_address, and how to distinguish errors. // Fires an OnConnected or OnConnectionFailed event. virtual void Connect(const IPEndpoint& remote_address, const TlsConnectOptions& options) = 0; @@ -67,22 +65,7 @@ class TlsConnectionFactory { const TlsListenOptions& options) = 0; protected: - TlsConnectionFactory(Client* client, TaskRunner* task_runner) - : client_(client), task_runner_(task_runner) {} - - // The below methods proxy calls to this TlsConnectionFactory's Client. - void OnAccepted(X509* peer_cert, std::unique_ptr<TlsConnection> connection); - - void OnConnected(X509* peer_cert, std::unique_ptr<TlsConnection> connection); - - void OnConnectionFailed(const IPEndpoint& remote_address); - - // Called when a non-recoverable error occurs. - void OnError(Error error); - - private: - Client* client_; - TaskRunner* task_runner_; + TlsConnectionFactory(); }; } // namespace platform diff --git a/platform/api/udp_socket.h b/platform/api/udp_socket.h index cc75e839..e76b9eb8 100644 --- a/platform/api/udp_socket.h +++ b/platform/api/udp_socket.h @@ -14,7 +14,6 @@ #include "platform/api/network_interface.h" #include "platform/base/error.h" #include "platform/base/ip_address.h" -#include "platform/base/macros.h" #include "platform/base/udp_packet.h" namespace openscreen { @@ -118,9 +117,6 @@ class UdpSocket { protected: UdpSocket(); - - private: - OSP_DISALLOW_COPY_AND_ASSIGN(UdpSocket); }; } // namespace platform diff --git a/platform/impl/stream_socket_posix.cc b/platform/impl/stream_socket_posix.cc index dd476a34..b60e82ae 100644 --- a/platform/impl/stream_socket_posix.cc +++ b/platform/impl/stream_socket_posix.cc @@ -49,6 +49,10 @@ StreamSocketPosix::~StreamSocketPosix() { } } +WeakPtr<StreamSocketPosix> StreamSocketPosix::GetWeakPtr() const { + return weak_factory_.GetWeakPtr(); +} + ErrorOr<std::unique_ptr<StreamSocket>> StreamSocketPosix::Accept() { if (!EnsureInitialized()) { return ReportSocketClosedError(); diff --git a/platform/impl/stream_socket_posix.h b/platform/impl/stream_socket_posix.h index fe30a2b6..a990ca2b 100644 --- a/platform/impl/stream_socket_posix.h +++ b/platform/impl/stream_socket_posix.h @@ -15,6 +15,7 @@ #include "platform/impl/socket_address_posix.h" #include "platform/impl/socket_handle_posix.h" #include "platform/impl/stream_socket.h" +#include "platform/impl/weak_ptr.h" namespace openscreen { namespace platform { @@ -33,6 +34,8 @@ class StreamSocketPosix : public StreamSocket { StreamSocketPosix& operator=(StreamSocketPosix&& other) = default; virtual ~StreamSocketPosix(); + WeakPtr<StreamSocketPosix> GetWeakPtr() const; + // StreamSocket overrides. ErrorOr<std::unique_ptr<StreamSocket>> Accept() override; Error Bind() override; @@ -74,6 +77,8 @@ class StreamSocketPosix : public StreamSocket { bool is_bound_ = false; bool is_initialized_ = false; SocketState state_ = SocketState::kNotConnected; + + WeakPtrFactory<StreamSocketPosix> weak_factory_{this}; }; } // namespace platform diff --git a/platform/impl/tls_connection_factory_posix.cc b/platform/impl/tls_connection_factory_posix.cc index 0ee528f0..c6b43f94 100644 --- a/platform/impl/tls_connection_factory_posix.cc +++ b/platform/impl/tls_connection_factory_posix.cc @@ -15,13 +15,14 @@ #include <unistd.h> #include <cstring> -#include <memory> -#include "absl/types/optional.h" #include "platform/api/logging.h" +#include "platform/api/task_runner.h" #include "platform/api/tls_connection_factory.h" #include "platform/api/trace_logging.h" -#include "platform/base/error.h" +#include "platform/base/tls_connect_options.h" +#include "platform/base/tls_credentials.h" +#include "platform/base/tls_listen_options.h" #include "platform/impl/stream_socket.h" #include "platform/impl/tls_connection_posix.h" #include "util/crypto/openssl_util.h" @@ -40,9 +41,12 @@ TlsConnectionFactoryPosix::TlsConnectionFactoryPosix( Client* client, TaskRunner* task_runner, PlatformClientPosix* platform_client) - : TlsConnectionFactory(client, task_runner), + : client_(client), task_runner_(task_runner), - platform_client_(platform_client) {} + platform_client_(platform_client) { + OSP_DCHECK(client_); + OSP_DCHECK(task_runner_); +} TlsConnectionFactoryPosix::~TlsConnectionFactoryPosix() = default; @@ -52,13 +56,13 @@ void TlsConnectionFactoryPosix::Connect(const IPEndpoint& remote_address, const TlsConnectOptions& options) { TRACE_SCOPED(TraceCategory::SSL, "TlsConnectionFactoryPosix::Connect"); IPAddress::Version version = remote_address.address.version(); - std::unique_ptr<TlsConnectionPosix> connection = - std::make_unique<TlsConnectionPosix>(version, task_runner_, - platform_client_); + std::unique_ptr<TlsConnectionPosix> connection( + new TlsConnectionPosix(version, task_runner_, platform_client_)); Error connect_error = connection->socket_->Connect(remote_address); if (!connect_error.ok()) { TRACE_SET_RESULT(connect_error.error()); - OnConnectionFailed(remote_address); + DispatchConnectionFailed(remote_address); + return; } if (!ConfigureSsl(connection.get())) { @@ -75,13 +79,18 @@ void TlsConnectionFactoryPosix::Connect(const IPEndpoint& remote_address, const int connection_status = SSL_connect(connection->ssl_.get()); if (connection_status != 1) { - OnConnectionFailed(connection->remote_address()); + DispatchConnectionFailed(connection->GetRemoteEndpoint()); TRACE_SET_RESULT(GetSSLError(connection->ssl_.get(), connection_status)); return; } - X509* peer_cert = SSL_get_peer_certificate(connection->ssl_.get()); - OnConnected(peer_cert, std::move(connection)); + task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(), + moved_connection = std::move(connection)]() mutable { + if (auto* self = weak_this.get()) { + X509* peer_cert = SSL_get_peer_certificate(moved_connection->ssl_.get()); + self->client_->OnConnected(self, peer_cert, std::move(moved_connection)); + } + }); } void TlsConnectionFactoryPosix::SetListenCredentials( @@ -92,7 +101,7 @@ void TlsConnectionFactoryPosix::SetListenCredentials( // it, so a const cast is unfortunately necessary. X509* non_const_cert = const_cast<X509*>(&credentials.certificate()); if (SSL_CTX_use_certificate(ssl_context_.get(), non_const_cert) != 1) { - OnError(Error::Code::kSocketListenFailure); + DispatchError(Error::Code::kSocketListenFailure); TRACE_SET_RESULT(Error::Code::kSocketListenFailure); return; } @@ -116,31 +125,39 @@ void TlsConnectionFactoryPosix::Listen(const IPEndpoint& local_address, } void TlsConnectionFactoryPosix::OnConnectionPending(StreamSocketPosix* socket) { - task_runner_->PostTask([socket, this]() mutable { - // TODO(crbug.com/openscreen/71): |this|, |socket| may be invalid at this - // point. - ErrorOr<std::unique_ptr<StreamSocket>> accepted = socket->Accept(); + task_runner_->PostTask([connection_factory_weak_ptr = + weak_factory_.GetWeakPtr(), + socket_weak_ptr = socket->GetWeakPtr()] { + if (!connection_factory_weak_ptr || !socket_weak_ptr) { + // Cancel the Accept() since either the factory or the listener socket + // went away before this task has run. + return; + } + + ErrorOr<std::unique_ptr<StreamSocket>> accepted = socket_weak_ptr->Accept(); if (accepted.is_error()) { // Check for special error code. Because this call doesn't get executed // until it gets through the task runner, OnConnectionPending may get // called multiple times. This check ensures only the first such call will // create a new SSL connection. if (accepted.error().code() != Error::Code::kAgain) { - this->OnError(accepted.error()); + connection_factory_weak_ptr->DispatchError(std::move(accepted.error())); } return; } - this->OnSocketAccepted(std::move(accepted.value())); + connection_factory_weak_ptr->OnSocketAccepted(std::move(accepted.value())); }); } void TlsConnectionFactoryPosix::OnSocketAccepted( std::unique_ptr<StreamSocket> socket) { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + TRACE_SCOPED(TraceCategory::SSL, "TlsConnectionFactoryPosix::OnSocketAccepted"); - auto connection = std::make_unique<TlsConnectionPosix>( - std::move(socket), task_runner_, platform_client_); + std::unique_ptr<TlsConnectionPosix> connection(new TlsConnectionPosix( + std::move(socket), task_runner_, platform_client_)); if (!ConfigureSsl(connection.get())) { return; @@ -148,26 +165,31 @@ void TlsConnectionFactoryPosix::OnSocketAccepted( const int connection_status = SSL_accept(connection->ssl_.get()); if (connection_status != 1) { - OnConnectionFailed(connection->remote_address()); + DispatchConnectionFailed(connection->GetRemoteEndpoint()); TRACE_SET_RESULT(GetSSLError(ssl.get(), connection_status)); return; } - X509* peer_cert = SSL_get_peer_certificate(connection->ssl_.get()); - OnAccepted(peer_cert, std::move(connection)); + task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(), + moved_connection = std::move(connection)]() mutable { + if (auto* self = weak_this.get()) { + X509* peer_cert = SSL_get_peer_certificate(moved_connection->ssl_.get()); + self->client_->OnAccepted(self, peer_cert, std::move(moved_connection)); + } + }); } bool TlsConnectionFactoryPosix::ConfigureSsl(TlsConnectionPosix* connection) { ErrorOr<bssl::UniquePtr<SSL>> connection_result = GetSslConnection(); if (connection_result.is_error()) { - OnError(connection_result.error()); + DispatchError(connection_result.error()); TRACE_SET_RESULT(connection_result.error()); return false; } bssl::UniquePtr<SSL> ssl = std::move(connection_result.value()); if (!SSL_set_fd(ssl.get(), connection->socket_->socket_handle().fd)) { - OnConnectionFailed(connection->remote_address()); + DispatchConnectionFailed(connection->GetRemoteEndpoint()); TRACE_SET_RESULT(Error(Error::Code::kSocketBindFailure)); return false; } @@ -206,5 +228,24 @@ void TlsConnectionFactoryPosix::Initialize() { ssl_context_.reset(context); } +void TlsConnectionFactoryPosix::DispatchConnectionFailed( + const IPEndpoint& remote_endpoint) { + task_runner_->PostTask( + [weak_this = weak_factory_.GetWeakPtr(), remote = remote_endpoint] { + if (auto* self = weak_this.get()) { + self->client_->OnConnectionFailed(self, remote); + } + }); +} + +void TlsConnectionFactoryPosix::DispatchError(Error error) { + task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(), + moved_error = std::move(error)]() mutable { + if (auto* self = weak_this.get()) { + self->client_->OnError(self, std::move(moved_error)); + } + }); +} + } // namespace platform } // namespace openscreen diff --git a/platform/impl/tls_connection_factory_posix.h b/platform/impl/tls_connection_factory_posix.h index 7d0970cc..aa2aee1a 100644 --- a/platform/impl/tls_connection_factory_posix.h +++ b/platform/impl/tls_connection_factory_posix.h @@ -5,15 +5,14 @@ #ifndef PLATFORM_IMPL_TLS_CONNECTION_FACTORY_POSIX_H_ #define PLATFORM_IMPL_TLS_CONNECTION_FACTORY_POSIX_H_ -#include <future> #include <memory> -#include <string> #include "platform/api/tls_connection.h" #include "platform/api/tls_connection_factory.h" -#include "platform/base/socket_state.h" +#include "platform/base/error.h" #include "platform/impl/platform_client_posix.h" #include "platform/impl/tls_data_router_posix.h" +#include "platform/impl/weak_ptr.h" namespace openscreen { namespace platform { @@ -29,18 +28,20 @@ class TlsConnectionFactoryPosix : public TlsConnectionFactory, PlatformClientPosix::GetInstance()); ~TlsConnectionFactoryPosix() override; - // TlsConnectionFactory overrides + // TlsConnectionFactory overrides. + // + // TODO(jophba, rwkeane): Determine how to handle multiple connection attempts + // to the same remote_address, and how to distinguish errors. void Connect(const IPEndpoint& remote_address, const TlsConnectOptions& options) override; - void SetListenCredentials(const TlsCredentials& credentials) override; void Listen(const IPEndpoint& local_address, const TlsListenOptions& options) override; + private: // TlsDataRouterPosix::SocketObserver overrides. void OnConnectionPending(StreamSocketPosix* socket) override; - private: // Configures a new SSL connection when a StreamSocket connection is accepted. void OnSocketAccepted(std::unique_ptr<StreamSocket> socket); @@ -59,6 +60,11 @@ class TlsConnectionFactoryPosix : public TlsConnectionFactory, // factory. void Initialize(); + // Called on any thread, to post a task to notify the Client that a connection + // failure or other error has occurred. + void DispatchConnectionFailed(const IPEndpoint& remote_endpoint); + void DispatchError(Error error); + // Thread-safe mechanism to ensure Initialize() is only called once. std::once_flag init_instance_flag_; @@ -66,13 +72,16 @@ class TlsConnectionFactoryPosix : public TlsConnectionFactory, // from the SSL_CTX is non-trivial, so we store a property instead. bool listen_credentials_set_ = false; - // The task runner used for internal operations. + Client* const client_; TaskRunner* const task_runner_; + PlatformClientPosix* const platform_client_; // SSL context, for creating SSL Connections via BoringSSL. bssl::UniquePtr<SSL_CTX> ssl_context_; - PlatformClientPosix* platform_client_; + WeakPtrFactory<TlsConnectionFactoryPosix> weak_factory_{this}; + + OSP_DISALLOW_COPY_AND_ASSIGN(TlsConnectionFactoryPosix); }; } // namespace platform diff --git a/platform/impl/tls_connection_posix.cc b/platform/impl/tls_connection_posix.cc index 06058f5f..a67e4f04 100644 --- a/platform/impl/tls_connection_posix.cc +++ b/platform/impl/tls_connection_posix.cc @@ -21,6 +21,7 @@ #include "absl/types/optional.h" #include "absl/types/span.h" #include "platform/api/logging.h" +#include "platform/api/task_runner.h" #include "platform/base/error.h" #include "platform/impl/stream_socket.h" #include "util/crypto/openssl_util.h" @@ -32,10 +33,11 @@ namespace platform { TlsConnectionPosix::TlsConnectionPosix(IPEndpoint local_address, TaskRunner* task_runner, PlatformClientPosix* platform_client) - : TlsConnection(task_runner), + : task_runner_(task_runner), + platform_client_(platform_client), socket_(std::make_unique<StreamSocketPosix>(local_address)), - buffer_(this), - platform_client_(platform_client) { + buffer_(this) { + OSP_DCHECK(task_runner_); if (platform_client_) { platform_client_->tls_data_router()->RegisterConnection(this); } @@ -44,10 +46,11 @@ TlsConnectionPosix::TlsConnectionPosix(IPEndpoint local_address, TlsConnectionPosix::TlsConnectionPosix(IPAddress::Version version, TaskRunner* task_runner, PlatformClientPosix* platform_client) - : TlsConnection(task_runner), + : task_runner_(task_runner), + platform_client_(platform_client), socket_(std::make_unique<StreamSocketPosix>(version)), - buffer_(this), - platform_client_(platform_client) { + buffer_(this) { + OSP_DCHECK(task_runner_); if (platform_client_) { platform_client_->tls_data_router()->RegisterConnection(this); } @@ -56,10 +59,11 @@ TlsConnectionPosix::TlsConnectionPosix(IPAddress::Version version, TlsConnectionPosix::TlsConnectionPosix(std::unique_ptr<StreamSocket> socket, TaskRunner* task_runner, PlatformClientPosix* platform_client) - : TlsConnection(task_runner), + : task_runner_(task_runner), + platform_client_(platform_client), socket_(std::move(socket)), - buffer_(this), - platform_client_(platform_client) { + buffer_(this) { + OSP_DCHECK(task_runner_); if (platform_client_) { platform_client_->tls_data_router()->RegisterConnection(this); } @@ -86,42 +90,73 @@ void TlsConnectionPosix::TryReceiveMessage() { if (bytes_read <= 0) { const Error error = GetSSLError(ssl_.get(), bytes_read); if (!error.ok() && (error != Error::Code::kAgain)) { - OnError(error); + DispatchError(error); } return; } block.resize(bytes_read); - OnRead(std::move(block)); + + task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(), + moved_block = std::move(block)]() mutable { + if (auto* self = weak_this.get()) { + if (auto* client = self->client_) { + client->OnRead(self, std::move(moved_block)); + } + } + }); } } +void TlsConnectionPosix::SetClient(Client* client) { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + client_ = client; + notified_client_buffer_is_blocked_ = false; +} + void TlsConnectionPosix::Write(const void* data, size_t len) { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); buffer_.Write(data, len); } -IPEndpoint TlsConnectionPosix::local_address() const { +IPEndpoint TlsConnectionPosix::GetLocalEndpoint() const { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + absl::optional<IPEndpoint> endpoint = socket_->local_address(); OSP_DCHECK(endpoint.has_value()); - return std::move(endpoint.value()); + return endpoint.value(); } -IPEndpoint TlsConnectionPosix::remote_address() const { +IPEndpoint TlsConnectionPosix::GetRemoteEndpoint() const { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + absl::optional<IPEndpoint> endpoint = socket_->remote_address(); OSP_DCHECK(endpoint.has_value()); - return std::move(endpoint.value()); + return endpoint.value(); } void TlsConnectionPosix::NotifyWriteBufferFill(double fraction) { + // WARNING: This method is called on multiple threads. + // + // The following is very subtle/complex behavior: Only "writes" can increase + // the buffer fill, so we expect transitions into the "blocked" state to occur + // on the |task_runner_| thread, and |client_| will be notified + // *synchronously* when that happens. Likewise, only "reads" can cause + // transitions to the "unblocked" state; but these will not occur on the + // |task_runner_| thread. Thus, when unblocking, the |client_| will be + // notified *asynchronously*; but, that should be acceptable because it's only + // a race towards a buffer overrun that is of concern. + // + // TODO(rwkeane): Have Write() return a bool, and then none of this is needed. constexpr double kBlockBufferPercentage = 0.5; - if (fraction > kBlockBufferPercentage && !is_buffer_blocked_) { - OnWriteBlocked(); - is_buffer_blocked_ = true; - } else if (fraction < kBlockBufferPercentage && is_buffer_blocked_) { - OnWriteUnblocked(); - is_buffer_blocked_ = false; - } else if (fraction >= 0.99 && is_buffer_blocked_) { - OnError(Error::Code::kInsufficientBuffer); + if (fraction > kBlockBufferPercentage && + !notified_client_buffer_is_blocked_) { + NotifyClientOfWriteBlockStatusSequentially(true); + } else if (fraction < kBlockBufferPercentage && + notified_client_buffer_is_blocked_) { + NotifyClientOfWriteBlockStatusSequentially(false); + } else if (fraction >= 0.99) { + DispatchError(Error::Code::kInsufficientBuffer); } } @@ -136,12 +171,56 @@ void TlsConnectionPosix::SendAvailableBytes() { if (result <= 0) { const Error result_error = GetSSLError(ssl_.get(), result); if (!result_error.ok() && (result_error.code() != Error::Code::kAgain)) { - OnError(result_error); + DispatchError(result_error); } } else { buffer_.Consume(static_cast<size_t>(result)); } } +void TlsConnectionPosix::DispatchError(Error error) { + task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(), + moved_error = std::move(error)]() mutable { + if (auto* self = weak_this.get()) { + if (auto* client = self->client_) { + client->OnError(self, std::move(moved_error)); + } + } + }); +} + +void TlsConnectionPosix::NotifyClientOfWriteBlockStatusSequentially( + bool is_blocked) { + if (!task_runner_->IsRunningOnTaskRunner()) { + task_runner_->PostTask( + [weak_this = weak_factory_.GetWeakPtr(), is_blocked = is_blocked] { + if (auto* self = weak_this.get()) { + OSP_DCHECK(self->task_runner_->IsRunningOnTaskRunner()); + self->NotifyClientOfWriteBlockStatusSequentially(is_blocked); + } + }); + return; + } + + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + if (!client_) { + return; + } + + // Check again, now that the block/unblock state change is happening + // in-sequence (it originated from parallel executions). + if (notified_client_buffer_is_blocked_ == is_blocked) { + return; + } + + notified_client_buffer_is_blocked_ = is_blocked; + if (is_blocked) { + client_->OnWriteBlocked(this); + } else { + client_->OnWriteUnblocked(this); + } +} + } // namespace platform } // namespace openscreen diff --git a/platform/impl/tls_connection_posix.h b/platform/impl/tls_connection_posix.h index 23b46278..c0089844 100644 --- a/platform/impl/tls_connection_posix.h +++ b/platform/impl/tls_connection_posix.h @@ -9,22 +9,43 @@ #include <atomic> #include <memory> -#include <string> -#include <vector> -#include "platform/api/task_runner.h" #include "platform/api/tls_connection.h" -#include "platform/base/socket_state.h" #include "platform/impl/platform_client_posix.h" #include "platform/impl/stream_socket_posix.h" #include "platform/impl/tls_write_buffer.h" +#include "platform/impl/weak_ptr.h" namespace openscreen { namespace platform { +class TaskRunner; +class TlsConnectionFactoryPosix; + class TlsConnectionPosix : public TlsConnection, public TlsWriteBuffer::Observer { public: + ~TlsConnectionPosix() override; + + // Sends any available bytes from this connection's buffer_. + virtual void SendAvailableBytes(); + + // Read out a block/message, if one is available, and notify this instance's + // TlsConnection::Client. + virtual void TryReceiveMessage(); + + // TlsConnection overrides. + void SetClient(Client* client) override; + void Write(const void* data, size_t len) override; + IPEndpoint GetLocalEndpoint() const override; + IPEndpoint GetRemoteEndpoint() const override; + + // TlsWriteBuffer::Observer overrides. + void NotifyWriteBufferFill(double fraction) override; + + protected: + friend class TlsConnectionFactoryPosix; + TlsConnectionPosix(IPEndpoint local_address, TaskRunner* task_runner, PlatformClientPosix* platform_client = @@ -37,33 +58,32 @@ class TlsConnectionPosix : public TlsConnection, TaskRunner* task_runner, PlatformClientPosix* platform_client = PlatformClientPosix::GetInstance()); - ~TlsConnectionPosix(); - // Sends any available bytes from this connection's buffer_. - virtual void SendAvailableBytes(); + private: + // Called on any thread, to post a task to notify the Client that an |error| + // has occurred. + void DispatchError(Error error); - // Read out a block/message, if one is available, and notify this instance's - // TlsConnection::Client. - virtual void TryReceiveMessage(); + // Helper to call OnWriteBlocked() or OnWriteUnblocked(). If this is not + // called within a task run by |task_runner_|, it trampolines by posting a + // task to call itself back via |task_runner_|. See comments in implementation + // of NotifyWriteBufferFill() for further details. + void NotifyClientOfWriteBlockStatusSequentially(bool is_blocked); - // TlsConnection overrides. - void Write(const void* data, size_t len) override; - IPEndpoint local_address() const override; - IPEndpoint remote_address() const override; + TaskRunner* const task_runner_; + PlatformClientPosix* const platform_client_; - // TlsWriteBuffer::Observer overrides. - void NotifyWriteBufferFill(double fraction) override; + Client* client_ = nullptr; - private: std::unique_ptr<StreamSocket> socket_; bssl::UniquePtr<SSL> ssl_; - std::atomic_bool is_buffer_blocked_{false}; + std::atomic_bool notified_client_buffer_is_blocked_{false}; TlsWriteBuffer buffer_; - PlatformClientPosix* platform_client_; + WeakPtrFactory<TlsConnectionPosix> weak_factory_{this}; - friend class TlsConnectionFactoryPosix; + OSP_DISALLOW_COPY_AND_ASSIGN(TlsConnectionPosix); }; } // namespace platform diff --git a/platform/impl/udp_socket_posix.cc b/platform/impl/udp_socket_posix.cc index f98dc866..c4226024 100644 --- a/platform/impl/udp_socket_posix.cc +++ b/platform/impl/udp_socket_posix.cc @@ -425,10 +425,11 @@ void UdpSocketPosix::ReceiveMessage() { // calling into all the other methods. if (is_closed()) { - task_runner_->PostTask([this] { - // TODO(issues/71): |this| may be invalid at this point. - if (client_) { - client_->OnRead(this, Error::Code::kSocketClosedFailure); + task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr()] { + if (auto* self = weak_this.get()) { + if (auto* client = self->client_) { + client->OnRead(self, Error::Code::kSocketClosedFailure); + } } }); return; @@ -437,11 +438,13 @@ void UdpSocketPosix::ReceiveMessage() { ssize_t bytes_available = recv(handle_.fd, nullptr, 0, MSG_PEEK | MSG_TRUNC); if (bytes_available == -1) { task_runner_->PostTask( - [this, error = ChooseError(errno, - Error::Code::kSocketReadFailure)]() mutable { - // TODO(issues/71): |this| may be invalid at this point. - if (client_) { - client_->OnRead(this, std::move(error)); + [weak_this = weak_factory_.GetWeakPtr(), + error = + ChooseError(errno, Error::Code::kSocketReadFailure)]() mutable { + if (auto* self = weak_this.get()) { + if (auto* client = self->client_) { + client->OnRead(self, std::move(error)); + } } }); return; @@ -466,12 +469,14 @@ void UdpSocketPosix::ReceiveMessage() { } task_runner_->PostTask( - [this, read_result = result.ok() ? ErrorOr<UdpPacket>(std::move(packet)) - : ErrorOr<UdpPacket>( - std::move(result))]() mutable { - // TODO(issues/71): |this| may be invalid at this point. - if (client_) { - client_->OnRead(this, std::move(read_result)); + [weak_this = weak_factory_.GetWeakPtr(), + read_result = result.ok() + ? ErrorOr<UdpPacket>(std::move(packet)) + : ErrorOr<UdpPacket>(std::move(result))]() mutable { + if (auto* self = weak_this.get()) { + if (auto* client = self->client_) { + client->OnRead(self, std::move(read_result)); + } } }); } diff --git a/platform/impl/udp_socket_posix.h b/platform/impl/udp_socket_posix.h index dcd614a5..ba44d61a 100644 --- a/platform/impl/udp_socket_posix.h +++ b/platform/impl/udp_socket_posix.h @@ -7,8 +7,10 @@ #include "absl/types/optional.h" #include "platform/api/udp_socket.h" +#include "platform/base/macros.h" #include "platform/impl/platform_client_posix.h" #include "platform/impl/socket_handle_posix.h" +#include "platform/impl/weak_ptr.h" namespace openscreen { namespace platform { @@ -78,7 +80,11 @@ class UdpSocketPosix : public UdpSocket { // port is non-zero, it is assumed never to change again. mutable IPEndpoint local_endpoint_; - PlatformClientPosix* platform_client_; + WeakPtrFactory<UdpSocketPosix> weak_factory_{this}; + + PlatformClientPosix* const platform_client_; + + OSP_DISALLOW_COPY_AND_ASSIGN(UdpSocketPosix); }; } // namespace platform diff --git a/platform/impl/weak_ptr.h b/platform/impl/weak_ptr.h new file mode 100644 index 00000000..e065b067 --- /dev/null +++ b/platform/impl/weak_ptr.h @@ -0,0 +1,216 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef PLATFORM_IMPL_WEAK_PTR_H_ +#define PLATFORM_IMPL_WEAK_PTR_H_ + +#include <memory> + +#include "platform/api/logging.h" + +namespace openscreen { + +// Weak pointers are pointers to an object that do not affect its lifetime, +// and which may be invalidated (i.e. reset to nullptr) by the object, or its +// owner, at any time; most commonly when the object is about to be deleted. +// +// Weak pointers are useful when an object needs to be accessed safely by one +// or more objects other than its owner, and those callers can cope with the +// object vanishing and e.g. tasks posted to it being silently dropped. +// Reference-counting such an object would complicate the ownership graph and +// make it harder to reason about the object's lifetime. +// +// EXAMPLE: +// +// class Controller { +// public: +// void SpawnWorker() { new Worker(weak_factory_.GetWeakPtr()); } +// void WorkComplete(const Result& result) { ... } +// private: +// // Member variables should appear before the WeakPtrFactory, to ensure +// // that any WeakPtrs to Controller are invalidated before its members +// // variable's destructors are executed, rendering them invalid. +// WeakPtrFactory<Controller> weak_factory_{this}; +// }; +// +// class Worker { +// public: +// explicit Worker(WeakPtr<Controller> controller) +// : controller_(std::move(controller)) {} +// private: +// void DidCompleteAsynchronousProcessing(const Result& result) { +// if (controller_) +// controller_->WorkComplete(result); +// delete this; +// } +// const WeakPtr<Controller> controller_; +// }; +// +// With this implementation a caller may use SpawnWorker() to dispatch multiple +// Workers and subsequently delete the Controller, without waiting for all +// Workers to have completed. +// +// ------------------------- IMPORTANT: Thread-safety ------------------------- +// +// Generally, Open Screen code is meant to be single-threaded. For the few +// exceptional cases, the following is relevant: +// +// WeakPtrs may be created from WeakPtrFactory, and also duplicated/moved on any +// thread/sequence. However, they may only be dereferenced on the same +// thread/sequence that will ultimately execute the WeakPtrFactory destructor or +// call InvalidateWeakPtrs(). Otherwise, use-during-free or use-after-free is +// possible. +// +// openscreen::WeakPtr and WeakPtrFactory are similar, but not identical, to +// Chromium's base::WeakPtrFactory. Open Screen WeakPtrs may be safely created +// from WeakPtrFactory on any thread/sequence, since they are backed by the +// thread-safe bookkeeping of std::shared_ptr<>. + +template <typename T> +class WeakPtrFactory; + +template <typename T> +class WeakPtr { + public: + WeakPtr() = default; + ~WeakPtr() = default; + + // Copy/Move constructors and assignment operators. + WeakPtr(const WeakPtr& other) : impl_(other.impl_) {} + + WeakPtr(WeakPtr&& other) noexcept : impl_(std::move(other.impl_)) {} + + WeakPtr& operator=(const WeakPtr& other) { + impl_ = other.impl_; + return *this; + } + + WeakPtr& operator=(WeakPtr&& other) noexcept { + impl_ = std::move(other.impl_); + return *this; + } + + // Create/Assign from nullptr. + WeakPtr(std::nullptr_t) {} + + WeakPtr& operator=(std::nullptr_t) { + impl_.reset(); + return *this; + } + + // Copy/Move constructors and assignment operators with upcast conversion. + template <typename U> + WeakPtr(const WeakPtr<U>& other) : impl_(other.as_std_weak_ptr()) {} + + template <typename U> + WeakPtr(WeakPtr<U>&& other) noexcept + : impl_(std::move(other).as_std_weak_ptr()) {} + + template <typename U> + WeakPtr& operator=(const WeakPtr<U>& other) { + impl_ = other.as_std_weak_ptr(); + return *this; + } + + template <typename U> + WeakPtr& operator=(WeakPtr<U>&& other) noexcept { + impl_ = std::move(other).as_std_weak_ptr(); + return *this; + } + + // Accessors. + T* get() const { return impl_.lock().get(); } + + T& operator*() const { + T* const pointer = get(); + OSP_DCHECK(pointer); + return *pointer; + } + + T* operator->() const { + T* const pointer = get(); + OSP_DCHECK(pointer); + return pointer; + } + + // Allow conditionals to test validity, e.g. if (weak_ptr) {...} + explicit operator bool() const { return get() != nullptr; } + + // Conversion to std::weak_ptr<T>. It is unsafe to convert in the other + // direction. See comments for private constructors, below. + const std::weak_ptr<T>& as_std_weak_ptr() const& { return impl_; } + std::weak_ptr<T> as_std_weak_ptr() && { return std::move(impl_); } + + private: + friend class WeakPtrFactory<T>; + + // Called by WeakPtrFactory<T> and the WeakPtr<T> upcast conversion + // constructors and assigners. These are purposely not being exposed publicly + // because that would allow a WeakPtr<T> to be valid/invalid by a different + // ownership/threading model than the intended one (see top-level comments). + template <typename U> + explicit WeakPtr(const std::weak_ptr<U>& other) : impl_(other) {} + + template <typename U> + explicit WeakPtr(std::weak_ptr<U>&& other) noexcept + : impl_(std::move(other)) {} + + std::weak_ptr<T> impl_; +}; + +// Allow callers to compare WeakPtrs against nullptr to test validity. +template <typename T> +bool operator!=(const WeakPtr<T>& weak_ptr, std::nullptr_t) { + return weak_ptr.get() != nullptr; +} +template <typename T> +bool operator!=(std::nullptr_t, const WeakPtr<T>& weak_ptr) { + return weak_ptr.get() != nullptr; +} +template <typename T> +bool operator==(const WeakPtr<T>& weak_ptr, std::nullptr_t) { + return weak_ptr.get() == nullptr; +} +template <typename T> +bool operator==(std::nullptr_t, const WeakPtr<T>& weak_ptr) { + return weak_ptr == nullptr; +} + +template <typename T> +class WeakPtrFactory { + public: + explicit WeakPtrFactory(T* instance) { Reset(instance); } + WeakPtrFactory(WeakPtrFactory&& other) noexcept = default; + WeakPtrFactory& operator=(WeakPtrFactory&& other) noexcept = default; + + // Thread-safe: WeakPtrs may be created on any thread/seuence. They may also + // be copied and moved on any thread/sequence. However, they MUST only be + // dereferenced on the same thread/sequence that calls the destructor or + // InvalidateWeakPtrs(). + WeakPtr<T> GetWeakPtr() const { + return WeakPtr<T>(std::weak_ptr<T>(bookkeeper_)); + } + + // Destruction and Invalidation: These must be called on the same + // thread/sequence that dereferences any WeakPtrs to avoid use-after-free + // bugs. + ~WeakPtrFactory() = default; + void InvalidateWeakPtrs() { Reset(bookkeeper_.get()); } + + private: + WeakPtrFactory(const WeakPtrFactory& other) = delete; + WeakPtrFactory& operator=(const WeakPtrFactory& other) = delete; + + void Reset(T* instance) { + // T is owned externally to WeakPtrFactory. Thus, provide a no-op Deleter. + bookkeeper_ = {instance, [](T* instance) {}}; + } + + // Manages the std::weak_ptr's referring to T. Does not own T. + std::shared_ptr<T> bookkeeper_; +}; + +} // namespace openscreen + +#endif // PLATFORM_IMPL_WEAK_PTR_H_ diff --git a/platform/impl/weak_ptr_unittest.cc b/platform/impl/weak_ptr_unittest.cc new file mode 100644 index 00000000..b06574bb --- /dev/null +++ b/platform/impl/weak_ptr_unittest.cc @@ -0,0 +1,186 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "platform/impl/weak_ptr.h" + +#include "gtest/gtest.h" + +namespace openscreen { +namespace { + +class SomeClass { + public: + virtual ~SomeClass() = default; + virtual int GetValue() const { return 42; } +}; + +struct SomeSubclass : public SomeClass { + public: + ~SomeSubclass() final = default; + int GetValue() const override { return 999; } +}; + +TEST(WeakPtrTest, InteractsWithNullptr) { + WeakPtr<const int> default_constructed; // Invoke default constructor. + EXPECT_TRUE(default_constructed == nullptr); + EXPECT_TRUE(nullptr == default_constructed); + EXPECT_FALSE(default_constructed != nullptr); + EXPECT_FALSE(nullptr != default_constructed); + + WeakPtr<const int> null_constructed = nullptr; // Invoke construct-from-null. + EXPECT_TRUE(null_constructed == nullptr); + EXPECT_TRUE(nullptr == null_constructed); + EXPECT_FALSE(null_constructed != nullptr); + EXPECT_FALSE(nullptr != null_constructed); + + const int foo = 42; + WeakPtrFactory<const int> factory(&foo); + WeakPtr<const int> not_null = factory.GetWeakPtr(); + EXPECT_TRUE(not_null != nullptr); + EXPECT_TRUE(nullptr != not_null); + EXPECT_FALSE(not_null == nullptr); + EXPECT_FALSE(nullptr == not_null); +} + +TEST(WeakPtrTest, CopyConstructsAndAssigns) { + SomeSubclass foo; + WeakPtrFactory<SomeSubclass> factory(&foo); + + WeakPtr<SomeSubclass> weak_ptr = factory.GetWeakPtr(); + EXPECT_TRUE(weak_ptr); + EXPECT_EQ(&foo, weak_ptr.get()); + + // Normal copy constructor. + WeakPtr<SomeSubclass> copied0 = weak_ptr; + EXPECT_EQ(&foo, weak_ptr.get()); // Did not mutate original. + EXPECT_TRUE(copied0); + EXPECT_EQ(&foo, copied0.get()); + + // Copy constructor, adding const qualifier. + WeakPtr<const SomeSubclass> copied1 = weak_ptr; + EXPECT_EQ(&foo, weak_ptr.get()); // Did not mutate original. + EXPECT_TRUE(copied1); + EXPECT_EQ(&foo, copied1.get()); + + // Normal copy assignment. + WeakPtr<SomeSubclass> assigned0; + EXPECT_FALSE(assigned0); + assigned0 = copied0; + EXPECT_EQ(&foo, copied0.get()); // Did not mutate original. + EXPECT_TRUE(assigned0); + EXPECT_EQ(&foo, assigned0.get()); + + // Copy assignment, adding const qualifier. + WeakPtr<const SomeSubclass> assigned1; + EXPECT_FALSE(assigned1); + assigned1 = copied0; + EXPECT_EQ(&foo, copied0.get()); // Did not mutate original. + EXPECT_TRUE(assigned1); + EXPECT_EQ(&foo, assigned1.get()); + + // Upcast copy constructor. + WeakPtr<SomeClass> copied2 = weak_ptr; + EXPECT_EQ(&foo, weak_ptr.get()); // Did not mutate original. + EXPECT_TRUE(copied2); + EXPECT_EQ(&foo, copied2.get()); + EXPECT_EQ(999, (*copied2).GetValue()); + EXPECT_EQ(999, copied2->GetValue()); + + // Upcast copy assignment. + WeakPtr<SomeClass> assigned2; + EXPECT_FALSE(assigned2); + assigned2 = weak_ptr; + EXPECT_EQ(&foo, weak_ptr.get()); // Did not mutate original. + EXPECT_TRUE(assigned2); + EXPECT_EQ(&foo, assigned2.get()); + EXPECT_EQ(999, (*assigned2).GetValue()); + EXPECT_EQ(999, assigned2->GetValue()); +} + +TEST(WeakPtrTest, MoveConstructsAndAssigns) { + SomeSubclass foo; + WeakPtrFactory<SomeSubclass> factory(&foo); + + // Normal move constructor. + WeakPtr<SomeSubclass> weak_ptr = factory.GetWeakPtr(); + WeakPtr<SomeSubclass> moved0 = std::move(weak_ptr); + EXPECT_FALSE(weak_ptr); // Original becomes null. + EXPECT_TRUE(moved0); + EXPECT_EQ(&foo, moved0.get()); + + // Move constructor, adding const qualifier. + weak_ptr = factory.GetWeakPtr(); + WeakPtr<const SomeSubclass> moved1 = std::move(weak_ptr); + EXPECT_FALSE(weak_ptr); // Original becomes null. + EXPECT_TRUE(moved1); + EXPECT_EQ(&foo, moved1.get()); + + // Normal move assignment. + weak_ptr = factory.GetWeakPtr(); + WeakPtr<SomeSubclass> assigned0; + EXPECT_FALSE(assigned0); + assigned0 = std::move(weak_ptr); + EXPECT_FALSE(weak_ptr); // Original becomes null. + EXPECT_TRUE(assigned0); + EXPECT_EQ(&foo, assigned0.get()); + + // Move assignment, adding const qualifier. + weak_ptr = factory.GetWeakPtr(); + WeakPtr<const SomeSubclass> assigned1; + EXPECT_FALSE(assigned1); + assigned1 = std::move(weak_ptr); + EXPECT_FALSE(weak_ptr); // Original becomes null. + EXPECT_TRUE(assigned1); + EXPECT_EQ(&foo, assigned1.get()); + + // Upcast move constructor. + weak_ptr = factory.GetWeakPtr(); + WeakPtr<SomeClass> moved2 = std::move(weak_ptr); + EXPECT_FALSE(weak_ptr); // Original becomes null. + EXPECT_TRUE(moved2); + EXPECT_EQ(&foo, moved2.get()); + EXPECT_EQ(999, (*moved2).GetValue()); // Result from subclass's GetValue(). + EXPECT_EQ(999, moved2->GetValue()); // Result from subclass's GetValue(). + + // Upcast move assignment. + weak_ptr = factory.GetWeakPtr(); + WeakPtr<SomeClass> assigned2; + EXPECT_FALSE(assigned2); + assigned2 = std::move(weak_ptr); + EXPECT_FALSE(weak_ptr); // Original becomes null. + EXPECT_TRUE(assigned2); + EXPECT_EQ(&foo, assigned2.get()); + EXPECT_EQ(999, + (*assigned2).GetValue()); // Result from subclass's GetValue(). + EXPECT_EQ(999, assigned2->GetValue()); // Result from subclass's GetValue(). +} + +TEST(WeakPtrTest, InvalidatesWeakPtrs) { + const int foo = 1337; + WeakPtrFactory<const int> factory(&foo); + + // Thrice: Create weak pointers and invalidate them. This is done more than + // once to confirm the factory can create valid WeakPtrs again after each + // InvalidateWeakPtrs() call. + for (int i = 0; i < 3; ++i) { + // Create three WeakPtrs, two from the factory, one as a copy of another + // WeakPtr. + WeakPtr<const int> ptr0 = factory.GetWeakPtr(); + WeakPtr<const int> ptr1 = factory.GetWeakPtr(); + WeakPtr<const int> ptr2 = ptr1; + EXPECT_EQ(&foo, ptr0.get()); + EXPECT_EQ(&foo, ptr1.get()); + EXPECT_EQ(&foo, ptr2.get()); + + // Invalidate all outstanding WeakPtrs from the factory, and confirm all + // outstanding WeakPtrs become null. + factory.InvalidateWeakPtrs(); + EXPECT_FALSE(ptr0); + EXPECT_FALSE(ptr1); + EXPECT_FALSE(ptr2); + } +} + +} // namespace +} // namespace openscreen |