aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cast/common/channel/cast_socket.cc31
-rw-r--r--cast/common/channel/cast_socket.h8
-rw-r--r--cast/common/channel/cast_socket_unittest.cc30
-rw-r--r--cast/sender/channel/sender_socket_factory.cc5
-rw-r--r--platform/BUILD.gn2
-rw-r--r--platform/api/tls_connection.cc47
-rw-r--r--platform/api/tls_connection.h46
-rw-r--r--platform/api/tls_connection_factory.cc36
-rw-r--r--platform/api/tls_connection_factory.h33
-rw-r--r--platform/api/udp_socket.h4
-rw-r--r--platform/impl/stream_socket_posix.cc4
-rw-r--r--platform/impl/stream_socket_posix.h5
-rw-r--r--platform/impl/tls_connection_factory_posix.cc93
-rw-r--r--platform/impl/tls_connection_factory_posix.h25
-rw-r--r--platform/impl/tls_connection_posix.cc127
-rw-r--r--platform/impl/tls_connection_posix.h60
-rw-r--r--platform/impl/udp_socket_posix.cc35
-rw-r--r--platform/impl/udp_socket_posix.h8
-rw-r--r--platform/impl/weak_ptr.h216
-rw-r--r--platform/impl/weak_ptr_unittest.cc186
20 files changed, 732 insertions, 269 deletions
diff --git a/cast/common/channel/cast_socket.cc b/cast/common/channel/cast_socket.cc
index ce35e54e..aa41bafe 100644
--- a/cast/common/channel/cast_socket.cc
+++ b/cast/common/channel/cast_socket.cc
@@ -28,10 +28,12 @@ CastSocket::CastSocket(std::unique_ptr<TlsConnection> connection,
connection_(std::move(connection)),
socket_id_(socket_id) {
OSP_DCHECK(client);
- connection_->set_client(this);
+ connection_->SetClient(this);
}
-CastSocket::~CastSocket() = default;
+CastSocket::~CastSocket() {
+ connection_->SetClient(nullptr);
+}
Error CastSocket::SendMessage(const CastMessage& message) {
if (state_ == State::kError) {
@@ -53,6 +55,11 @@ Error CastSocket::SendMessage(const CastMessage& message) {
return Error::Code::kNone;
}
+void CastSocket::SetClient(Client* client) {
+ OSP_DCHECK(client);
+ client_ = client;
+}
+
void CastSocket::OnWriteBlocked(TlsConnection* connection) {
if (state_ == State::kOpen) {
state_ = State::kBlocked;
@@ -60,14 +67,20 @@ void CastSocket::OnWriteBlocked(TlsConnection* connection) {
}
void CastSocket::OnWriteUnblocked(TlsConnection* connection) {
- if (state_ == State::kBlocked) {
- state_ = State::kOpen;
- for (const auto& message : message_queue_) {
- connection_->Write(message.data(), message.size());
- }
- OSP_DCHECK(state_ == State::kOpen) << static_cast<int>(state_);
- message_queue_.clear();
+ if (state_ != State::kBlocked) {
+ return;
+ }
+ state_ = State::kOpen;
+
+ // Attempt to write all messages that have been queued-up while the socket was
+ // blocked. Stop if the socket becomes blocked again, or an error occurs.
+ auto it = message_queue_.begin();
+ for (const auto end = message_queue_.end();
+ it != end && state_ == State::kOpen; ++it) {
+ // The following Write() could transition |state_| to kBlocked or kError.
+ connection_->Write(it->data(), it->size());
}
+ message_queue_.erase(message_queue_.begin(), it);
}
void CastSocket::OnError(TlsConnection* connection, Error error) {
diff --git a/cast/common/channel/cast_socket.h b/cast/common/channel/cast_socket.h
index 0c173ef5..6bd099b2 100644
--- a/cast/common/channel/cast_socket.h
+++ b/cast/common/channel/cast_socket.h
@@ -45,10 +45,8 @@ class CastSocket : public TlsConnection::Client {
// write-blocked.
Error SendMessage(const CastMessage& message);
- void set_client(Client* client) {
- OSP_DCHECK(client);
- client_ = client;
- }
+ void SetClient(Client* client);
+
uint32_t socket_id() const { return socket_id_; }
// TlsConnection::Client overrides.
@@ -64,7 +62,7 @@ class CastSocket : public TlsConnection::Client {
kError,
};
- Client* client_;
+ Client* client_; // May never be null.
const std::unique_ptr<TlsConnection> connection_;
std::vector<uint8_t> read_buffer_;
const uint32_t socket_id_;
diff --git a/cast/common/channel/cast_socket_unittest.cc b/cast/common/channel/cast_socket_unittest.cc
index 56a69a88..453ffa52 100644
--- a/cast/common/channel/cast_socket_unittest.cc
+++ b/cast/common/channel/cast_socket_unittest.cc
@@ -29,25 +29,28 @@ class MockTlsConnection final : public TlsConnection {
MockTlsConnection(TaskRunner* task_runner,
IPEndpoint local_address,
IPEndpoint remote_address)
- : TlsConnection(task_runner),
- local_address_(local_address),
- remote_address_(remote_address) {}
+ : local_address_(local_address), remote_address_(remote_address) {}
~MockTlsConnection() override = default;
+ void SetClient(TlsConnection::Client* client) final { client_ = client; }
+
MOCK_METHOD(void, Write, (const void* data, size_t len));
- IPEndpoint local_address() const override { return local_address_; }
- IPEndpoint remote_address() const override { return remote_address_; }
+ IPEndpoint GetLocalEndpoint() const override { return local_address_; }
+ IPEndpoint GetRemoteEndpoint() const override { return remote_address_; }
- void OnWriteBlocked() { TlsConnection::OnWriteBlocked(); }
- void OnWriteUnblocked() { TlsConnection::OnWriteUnblocked(); }
- void OnError(Error error) { TlsConnection::OnError(error); }
- void OnRead(std::vector<uint8_t> block) { TlsConnection::OnRead(block); }
+ void OnWriteBlocked() { client_->OnWriteBlocked(this); }
+ void OnWriteUnblocked() { client_->OnWriteUnblocked(this); }
+ void OnError(Error error) { client_->OnError(this, std::move(error)); }
+ void OnRead(std::vector<uint8_t> block) {
+ client_->OnRead(this, std::move(block));
+ }
private:
const IPEndpoint local_address_;
const IPEndpoint remote_address_;
+ TlsConnection::Client* client_ = nullptr;
};
class MockCastSocketClient final : public CastSocket::Client {
@@ -61,6 +64,7 @@ class MockCastSocketClient final : public CastSocket::Client {
class CastSocketTest : public ::testing::Test {
public:
void SetUp() override {
+ connection_->SetClient(&socket_);
message_.set_protocol_version(CastMessage::CASTV2_1_0);
message_.set_source_id("source");
message_.set_destination_id("destination");
@@ -149,7 +153,6 @@ TEST_F(CastSocketTest, ReadChunkedMessage) {
TEST_F(CastSocketTest, SendMessageWhileBlocked) {
connection_->OnWriteBlocked();
- task_runner_.RunTasksUntilIdle();
EXPECT_CALL(*connection_, Write(_, _)).Times(0);
ASSERT_TRUE(socket_.SendMessage(message_).ok());
@@ -161,18 +164,13 @@ TEST_F(CastSocketTest, SendMessageWhileBlocked) {
reinterpret_cast<const uint8_t*>(data) + len));
}));
connection_->OnWriteUnblocked();
- task_runner_.RunTasksUntilIdle();
- EXPECT_CALL(*connection_, Write(_, _)).Times(0);
connection_->OnWriteBlocked();
- task_runner_.RunTasksUntilIdle();
connection_->OnWriteUnblocked();
- task_runner_.RunTasksUntilIdle();
}
TEST_F(CastSocketTest, ErrorWhileEmptyingQueue) {
connection_->OnWriteBlocked();
- task_runner_.RunTasksUntilIdle();
EXPECT_CALL(*connection_, Write(_, _)).Times(0);
ASSERT_TRUE(socket_.SendMessage(message_).ok());
@@ -185,9 +183,7 @@ TEST_F(CastSocketTest, ErrorWhileEmptyingQueue) {
connection_->OnError(Error::Code::kUnknownError);
}));
connection_->OnWriteUnblocked();
- task_runner_.RunTasksUntilIdle();
- EXPECT_CALL(*connection_, Write(_, _)).Times(0);
ASSERT_FALSE(socket_.SendMessage(message_).ok());
}
diff --git a/cast/sender/channel/sender_socket_factory.cc b/cast/sender/channel/sender_socket_factory.cc
index 83e73370..811f07df 100644
--- a/cast/sender/channel/sender_socket_factory.cc
+++ b/cast/sender/channel/sender_socket_factory.cc
@@ -6,6 +6,7 @@
#include "cast/common/channel/cast_socket.h"
#include "cast/sender/channel/message_util.h"
+#include "platform/base/tls_connect_options.h"
namespace cast {
namespace channel {
@@ -51,7 +52,7 @@ void SenderSocketFactory::OnConnected(
TlsConnectionFactory* factory,
X509* peer_cert,
std::unique_ptr<TlsConnection> connection) {
- const IPEndpoint& endpoint = connection->remote_address();
+ const IPEndpoint& endpoint = connection->GetRemoteEndpoint();
auto it = FindPendingConnection(endpoint);
if (it == pending_connections_.end()) {
OSP_DLOG_ERROR << "TLS connection succeeded for unknown endpoint: "
@@ -158,7 +159,7 @@ void SenderSocketFactory::OnMessage(CastSocket* socket, CastMessage message) {
return;
}
- pending->socket->set_client(pending->client);
+ pending->socket->SetClient(pending->client);
client_->OnConnected(this, pending->endpoint, std::move(pending->socket));
}
diff --git a/platform/BUILD.gn b/platform/BUILD.gn
index 622b2397..7efad0c8 100644
--- a/platform/BUILD.gn
+++ b/platform/BUILD.gn
@@ -68,6 +68,7 @@ source_set("platform") {
"impl/time.cc",
"impl/tls_write_buffer.cc",
"impl/tls_write_buffer.h",
+ "impl/weak_ptr.h",
]
if (is_linux) {
@@ -161,6 +162,7 @@ source_set("unittests") {
"api/socket_integration_unittest.cc",
"impl/task_runner_unittest.cc",
"impl/time_unittest.cc",
+ "impl/weak_ptr_unittest.cc",
]
if (is_posix) {
diff --git a/platform/api/tls_connection.cc b/platform/api/tls_connection.cc
index e233a001..fc3a8c8e 100644
--- a/platform/api/tls_connection.cc
+++ b/platform/api/tls_connection.cc
@@ -4,54 +4,11 @@
#include "platform/api/tls_connection.h"
-#include "platform/api/task_runner.h"
-
namespace openscreen {
namespace platform {
-void TlsConnection::OnWriteBlocked() {
- if (!client_) {
- return;
- }
-
- task_runner_->PostTask([this]() {
- // TODO(crbug.com/openscreen/71): |this| may be invalid at this point.
- this->client_->OnWriteBlocked(this);
- });
-}
-
-void TlsConnection::OnWriteUnblocked() {
- if (!client_) {
- return;
- }
-
- task_runner_->PostTask([this]() {
- // TODO(crbug.com/openscreen/71): |this| may be invalid at this point.
- this->client_->OnWriteUnblocked(this);
- });
-}
-
-void TlsConnection::OnError(Error error) {
- if (!client_) {
- return;
- }
-
- task_runner_->PostTask([e = std::move(error), this]() mutable {
- // TODO(crbug.com/openscreen/71): |this| may be invalid at this point.
- this->client_->OnError(this, std::move(e));
- });
-}
-
-void TlsConnection::OnRead(std::vector<uint8_t> block) {
- if (!client_) {
- return;
- }
-
- task_runner_->PostTask([b = std::move(block), this]() mutable {
- // TODO(crbug.com/openscreen/71): |this| may be invalid at this point.
- this->client_->OnRead(this, std::move(b));
- });
-}
+TlsConnection::TlsConnection() = default;
+TlsConnection::~TlsConnection() = default;
} // namespace platform
} // namespace openscreen
diff --git a/platform/api/tls_connection.h b/platform/api/tls_connection.h
index a1b2965c..7c2777fb 100644
--- a/platform/api/tls_connection.h
+++ b/platform/api/tls_connection.h
@@ -6,23 +6,17 @@
#define PLATFORM_API_TLS_CONNECTION_H_
#include <cstdint>
-#include <memory>
-#include <string>
#include <vector>
-#include "absl/types/optional.h"
-#include "platform/api/network_interface.h"
-#include "platform/api/task_runner.h"
#include "platform/base/error.h"
#include "platform/base/ip_address.h"
-#include "platform/base/macros.h"
namespace openscreen {
namespace platform {
class TlsConnection {
public:
- // Client callbacks are ran on the provided TaskRunner.
+ // Client callbacks are run via the TaskRunner used by TlsConnectionFactory.
class Client {
public:
// Called when |connection| writing is blocked and unblocked, respectively.
@@ -43,41 +37,25 @@ class TlsConnection {
virtual ~Client() = default;
};
+ virtual ~TlsConnection();
+
+ // Sets the Client associated with this instance. This should be called as
+ // soon as the factory provides a new TlsConnection instance via
+ // TlsConnectionFactory::OnAccepted() or OnConnected(). Pass nullptr to unset
+ // the Client.
+ virtual void SetClient(Client* client) = 0;
+
// Sends a message.
virtual void Write(const void* data, size_t len) = 0;
// Get the local address.
- virtual IPEndpoint local_address() const = 0;
+ virtual IPEndpoint GetLocalEndpoint() const = 0;
// Get the connected remote address.
- virtual IPEndpoint remote_address() const = 0;
-
- // Sets the client for this instance.
- void set_client(Client* client) { client_ = client; }
-
- virtual ~TlsConnection() = default;
+ virtual IPEndpoint GetRemoteEndpoint() const = 0;
protected:
- explicit TlsConnection(TaskRunner* task_runner) : task_runner_(task_runner) {}
-
- // Called when |connection| writing is blocked and unblocked, respectively.
- // This call will be proxied to the Client set for this TlsConnection.
- void OnWriteBlocked();
- void OnWriteUnblocked();
-
- // Called when |connection| experiences an error, such as a read error. This
- // call will be proxied to the Client set for this TlsConnection.
- void OnError(Error error);
-
- // Called when a |packet| arrives on |socket|. This call will be proxied to
- // the Client set for this TlsConnection.
- void OnRead(std::vector<uint8_t> message);
-
- private:
- Client* client_;
- TaskRunner* const task_runner_;
-
- OSP_DISALLOW_COPY_AND_ASSIGN(TlsConnection);
+ TlsConnection();
};
} // namespace platform
diff --git a/platform/api/tls_connection_factory.cc b/platform/api/tls_connection_factory.cc
index ee1f9abb..a3bfb287 100644
--- a/platform/api/tls_connection_factory.cc
+++ b/platform/api/tls_connection_factory.cc
@@ -7,40 +7,8 @@
namespace openscreen {
namespace platform {
-void TlsConnectionFactory::OnAccepted(
- X509* peer_cert,
- std::unique_ptr<TlsConnection> connection) {
- task_runner_->PostTask(
- [peer_cert, c = std::move(connection), this]() mutable {
- // TODO(crbug.com/openscreen/71): |this| may be invalid at this point.
- this->client_->OnAccepted(this, peer_cert, std::move(c));
- });
-}
-
-void TlsConnectionFactory::OnConnected(
- X509* peer_cert,
- std::unique_ptr<TlsConnection> connection) {
- task_runner_->PostTask(
- [peer_cert, c = std::move(connection), this]() mutable {
- // TODO(crbug.com/openscreen/71): |this| may be invalid at this point.
- this->client_->OnConnected(this, peer_cert, std::move(c));
- });
-}
-
-void TlsConnectionFactory::OnConnectionFailed(
- const IPEndpoint& remote_address) {
- task_runner_->PostTask([remote_address, this]() {
- // TODO(crbug.com/openscreen/71): |this| may be invalid at this point.
- this->client_->OnConnectionFailed(this, remote_address);
- });
-}
-
-void TlsConnectionFactory::OnError(Error error) {
- task_runner_->PostTask([e = std::move(error), this]() mutable {
- // TODO(crbug.com/openscreen/71): |this| may be invalid at this point.
- this->client_->OnError(this, std::move(e));
- });
-}
+TlsConnectionFactory::TlsConnectionFactory() = default;
+TlsConnectionFactory::~TlsConnectionFactory() = default;
} // namespace platform
} // namespace openscreen
diff --git a/platform/api/tls_connection_factory.h b/platform/api/tls_connection_factory.h
index 81fb76c6..e68d2fc6 100644
--- a/platform/api/tls_connection_factory.h
+++ b/platform/api/tls_connection_factory.h
@@ -8,18 +8,18 @@
#include <openssl/crypto.h>
#include <memory>
-#include <string>
-#include "absl/types/optional.h"
-#include "platform/api/tls_connection.h"
#include "platform/base/ip_address.h"
-#include "platform/base/tls_connect_options.h"
-#include "platform/base/tls_credentials.h"
-#include "platform/base/tls_listen_options.h"
namespace openscreen {
namespace platform {
+class TaskRunner;
+class TlsConnection;
+struct TlsConnectOptions;
+class TlsCredentials;
+struct TlsListenOptions;
+
// We expect a single factory to be able to handle an arbitrary number of
// calls using the same client and task runner.
class TlsConnectionFactory {
@@ -49,10 +49,8 @@ class TlsConnectionFactory {
Client* client,
TaskRunner* task_runner);
- virtual ~TlsConnectionFactory() = default;
+ virtual ~TlsConnectionFactory();
- // TODO(jophba, rwkeane): Determine how to handle multiple connection attempts
- // to the same remote_address, and how to distinguish errors.
// Fires an OnConnected or OnConnectionFailed event.
virtual void Connect(const IPEndpoint& remote_address,
const TlsConnectOptions& options) = 0;
@@ -67,22 +65,7 @@ class TlsConnectionFactory {
const TlsListenOptions& options) = 0;
protected:
- TlsConnectionFactory(Client* client, TaskRunner* task_runner)
- : client_(client), task_runner_(task_runner) {}
-
- // The below methods proxy calls to this TlsConnectionFactory's Client.
- void OnAccepted(X509* peer_cert, std::unique_ptr<TlsConnection> connection);
-
- void OnConnected(X509* peer_cert, std::unique_ptr<TlsConnection> connection);
-
- void OnConnectionFailed(const IPEndpoint& remote_address);
-
- // Called when a non-recoverable error occurs.
- void OnError(Error error);
-
- private:
- Client* client_;
- TaskRunner* task_runner_;
+ TlsConnectionFactory();
};
} // namespace platform
diff --git a/platform/api/udp_socket.h b/platform/api/udp_socket.h
index cc75e839..e76b9eb8 100644
--- a/platform/api/udp_socket.h
+++ b/platform/api/udp_socket.h
@@ -14,7 +14,6 @@
#include "platform/api/network_interface.h"
#include "platform/base/error.h"
#include "platform/base/ip_address.h"
-#include "platform/base/macros.h"
#include "platform/base/udp_packet.h"
namespace openscreen {
@@ -118,9 +117,6 @@ class UdpSocket {
protected:
UdpSocket();
-
- private:
- OSP_DISALLOW_COPY_AND_ASSIGN(UdpSocket);
};
} // namespace platform
diff --git a/platform/impl/stream_socket_posix.cc b/platform/impl/stream_socket_posix.cc
index dd476a34..b60e82ae 100644
--- a/platform/impl/stream_socket_posix.cc
+++ b/platform/impl/stream_socket_posix.cc
@@ -49,6 +49,10 @@ StreamSocketPosix::~StreamSocketPosix() {
}
}
+WeakPtr<StreamSocketPosix> StreamSocketPosix::GetWeakPtr() const {
+ return weak_factory_.GetWeakPtr();
+}
+
ErrorOr<std::unique_ptr<StreamSocket>> StreamSocketPosix::Accept() {
if (!EnsureInitialized()) {
return ReportSocketClosedError();
diff --git a/platform/impl/stream_socket_posix.h b/platform/impl/stream_socket_posix.h
index fe30a2b6..a990ca2b 100644
--- a/platform/impl/stream_socket_posix.h
+++ b/platform/impl/stream_socket_posix.h
@@ -15,6 +15,7 @@
#include "platform/impl/socket_address_posix.h"
#include "platform/impl/socket_handle_posix.h"
#include "platform/impl/stream_socket.h"
+#include "platform/impl/weak_ptr.h"
namespace openscreen {
namespace platform {
@@ -33,6 +34,8 @@ class StreamSocketPosix : public StreamSocket {
StreamSocketPosix& operator=(StreamSocketPosix&& other) = default;
virtual ~StreamSocketPosix();
+ WeakPtr<StreamSocketPosix> GetWeakPtr() const;
+
// StreamSocket overrides.
ErrorOr<std::unique_ptr<StreamSocket>> Accept() override;
Error Bind() override;
@@ -74,6 +77,8 @@ class StreamSocketPosix : public StreamSocket {
bool is_bound_ = false;
bool is_initialized_ = false;
SocketState state_ = SocketState::kNotConnected;
+
+ WeakPtrFactory<StreamSocketPosix> weak_factory_{this};
};
} // namespace platform
diff --git a/platform/impl/tls_connection_factory_posix.cc b/platform/impl/tls_connection_factory_posix.cc
index 0ee528f0..c6b43f94 100644
--- a/platform/impl/tls_connection_factory_posix.cc
+++ b/platform/impl/tls_connection_factory_posix.cc
@@ -15,13 +15,14 @@
#include <unistd.h>
#include <cstring>
-#include <memory>
-#include "absl/types/optional.h"
#include "platform/api/logging.h"
+#include "platform/api/task_runner.h"
#include "platform/api/tls_connection_factory.h"
#include "platform/api/trace_logging.h"
-#include "platform/base/error.h"
+#include "platform/base/tls_connect_options.h"
+#include "platform/base/tls_credentials.h"
+#include "platform/base/tls_listen_options.h"
#include "platform/impl/stream_socket.h"
#include "platform/impl/tls_connection_posix.h"
#include "util/crypto/openssl_util.h"
@@ -40,9 +41,12 @@ TlsConnectionFactoryPosix::TlsConnectionFactoryPosix(
Client* client,
TaskRunner* task_runner,
PlatformClientPosix* platform_client)
- : TlsConnectionFactory(client, task_runner),
+ : client_(client),
task_runner_(task_runner),
- platform_client_(platform_client) {}
+ platform_client_(platform_client) {
+ OSP_DCHECK(client_);
+ OSP_DCHECK(task_runner_);
+}
TlsConnectionFactoryPosix::~TlsConnectionFactoryPosix() = default;
@@ -52,13 +56,13 @@ void TlsConnectionFactoryPosix::Connect(const IPEndpoint& remote_address,
const TlsConnectOptions& options) {
TRACE_SCOPED(TraceCategory::SSL, "TlsConnectionFactoryPosix::Connect");
IPAddress::Version version = remote_address.address.version();
- std::unique_ptr<TlsConnectionPosix> connection =
- std::make_unique<TlsConnectionPosix>(version, task_runner_,
- platform_client_);
+ std::unique_ptr<TlsConnectionPosix> connection(
+ new TlsConnectionPosix(version, task_runner_, platform_client_));
Error connect_error = connection->socket_->Connect(remote_address);
if (!connect_error.ok()) {
TRACE_SET_RESULT(connect_error.error());
- OnConnectionFailed(remote_address);
+ DispatchConnectionFailed(remote_address);
+ return;
}
if (!ConfigureSsl(connection.get())) {
@@ -75,13 +79,18 @@ void TlsConnectionFactoryPosix::Connect(const IPEndpoint& remote_address,
const int connection_status = SSL_connect(connection->ssl_.get());
if (connection_status != 1) {
- OnConnectionFailed(connection->remote_address());
+ DispatchConnectionFailed(connection->GetRemoteEndpoint());
TRACE_SET_RESULT(GetSSLError(connection->ssl_.get(), connection_status));
return;
}
- X509* peer_cert = SSL_get_peer_certificate(connection->ssl_.get());
- OnConnected(peer_cert, std::move(connection));
+ task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
+ moved_connection = std::move(connection)]() mutable {
+ if (auto* self = weak_this.get()) {
+ X509* peer_cert = SSL_get_peer_certificate(moved_connection->ssl_.get());
+ self->client_->OnConnected(self, peer_cert, std::move(moved_connection));
+ }
+ });
}
void TlsConnectionFactoryPosix::SetListenCredentials(
@@ -92,7 +101,7 @@ void TlsConnectionFactoryPosix::SetListenCredentials(
// it, so a const cast is unfortunately necessary.
X509* non_const_cert = const_cast<X509*>(&credentials.certificate());
if (SSL_CTX_use_certificate(ssl_context_.get(), non_const_cert) != 1) {
- OnError(Error::Code::kSocketListenFailure);
+ DispatchError(Error::Code::kSocketListenFailure);
TRACE_SET_RESULT(Error::Code::kSocketListenFailure);
return;
}
@@ -116,31 +125,39 @@ void TlsConnectionFactoryPosix::Listen(const IPEndpoint& local_address,
}
void TlsConnectionFactoryPosix::OnConnectionPending(StreamSocketPosix* socket) {
- task_runner_->PostTask([socket, this]() mutable {
- // TODO(crbug.com/openscreen/71): |this|, |socket| may be invalid at this
- // point.
- ErrorOr<std::unique_ptr<StreamSocket>> accepted = socket->Accept();
+ task_runner_->PostTask([connection_factory_weak_ptr =
+ weak_factory_.GetWeakPtr(),
+ socket_weak_ptr = socket->GetWeakPtr()] {
+ if (!connection_factory_weak_ptr || !socket_weak_ptr) {
+ // Cancel the Accept() since either the factory or the listener socket
+ // went away before this task has run.
+ return;
+ }
+
+ ErrorOr<std::unique_ptr<StreamSocket>> accepted = socket_weak_ptr->Accept();
if (accepted.is_error()) {
// Check for special error code. Because this call doesn't get executed
// until it gets through the task runner, OnConnectionPending may get
// called multiple times. This check ensures only the first such call will
// create a new SSL connection.
if (accepted.error().code() != Error::Code::kAgain) {
- this->OnError(accepted.error());
+ connection_factory_weak_ptr->DispatchError(std::move(accepted.error()));
}
return;
}
- this->OnSocketAccepted(std::move(accepted.value()));
+ connection_factory_weak_ptr->OnSocketAccepted(std::move(accepted.value()));
});
}
void TlsConnectionFactoryPosix::OnSocketAccepted(
std::unique_ptr<StreamSocket> socket) {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
TRACE_SCOPED(TraceCategory::SSL,
"TlsConnectionFactoryPosix::OnSocketAccepted");
- auto connection = std::make_unique<TlsConnectionPosix>(
- std::move(socket), task_runner_, platform_client_);
+ std::unique_ptr<TlsConnectionPosix> connection(new TlsConnectionPosix(
+ std::move(socket), task_runner_, platform_client_));
if (!ConfigureSsl(connection.get())) {
return;
@@ -148,26 +165,31 @@ void TlsConnectionFactoryPosix::OnSocketAccepted(
const int connection_status = SSL_accept(connection->ssl_.get());
if (connection_status != 1) {
- OnConnectionFailed(connection->remote_address());
+ DispatchConnectionFailed(connection->GetRemoteEndpoint());
TRACE_SET_RESULT(GetSSLError(ssl.get(), connection_status));
return;
}
- X509* peer_cert = SSL_get_peer_certificate(connection->ssl_.get());
- OnAccepted(peer_cert, std::move(connection));
+ task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
+ moved_connection = std::move(connection)]() mutable {
+ if (auto* self = weak_this.get()) {
+ X509* peer_cert = SSL_get_peer_certificate(moved_connection->ssl_.get());
+ self->client_->OnAccepted(self, peer_cert, std::move(moved_connection));
+ }
+ });
}
bool TlsConnectionFactoryPosix::ConfigureSsl(TlsConnectionPosix* connection) {
ErrorOr<bssl::UniquePtr<SSL>> connection_result = GetSslConnection();
if (connection_result.is_error()) {
- OnError(connection_result.error());
+ DispatchError(connection_result.error());
TRACE_SET_RESULT(connection_result.error());
return false;
}
bssl::UniquePtr<SSL> ssl = std::move(connection_result.value());
if (!SSL_set_fd(ssl.get(), connection->socket_->socket_handle().fd)) {
- OnConnectionFailed(connection->remote_address());
+ DispatchConnectionFailed(connection->GetRemoteEndpoint());
TRACE_SET_RESULT(Error(Error::Code::kSocketBindFailure));
return false;
}
@@ -206,5 +228,24 @@ void TlsConnectionFactoryPosix::Initialize() {
ssl_context_.reset(context);
}
+void TlsConnectionFactoryPosix::DispatchConnectionFailed(
+ const IPEndpoint& remote_endpoint) {
+ task_runner_->PostTask(
+ [weak_this = weak_factory_.GetWeakPtr(), remote = remote_endpoint] {
+ if (auto* self = weak_this.get()) {
+ self->client_->OnConnectionFailed(self, remote);
+ }
+ });
+}
+
+void TlsConnectionFactoryPosix::DispatchError(Error error) {
+ task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
+ moved_error = std::move(error)]() mutable {
+ if (auto* self = weak_this.get()) {
+ self->client_->OnError(self, std::move(moved_error));
+ }
+ });
+}
+
} // namespace platform
} // namespace openscreen
diff --git a/platform/impl/tls_connection_factory_posix.h b/platform/impl/tls_connection_factory_posix.h
index 7d0970cc..aa2aee1a 100644
--- a/platform/impl/tls_connection_factory_posix.h
+++ b/platform/impl/tls_connection_factory_posix.h
@@ -5,15 +5,14 @@
#ifndef PLATFORM_IMPL_TLS_CONNECTION_FACTORY_POSIX_H_
#define PLATFORM_IMPL_TLS_CONNECTION_FACTORY_POSIX_H_
-#include <future>
#include <memory>
-#include <string>
#include "platform/api/tls_connection.h"
#include "platform/api/tls_connection_factory.h"
-#include "platform/base/socket_state.h"
+#include "platform/base/error.h"
#include "platform/impl/platform_client_posix.h"
#include "platform/impl/tls_data_router_posix.h"
+#include "platform/impl/weak_ptr.h"
namespace openscreen {
namespace platform {
@@ -29,18 +28,20 @@ class TlsConnectionFactoryPosix : public TlsConnectionFactory,
PlatformClientPosix::GetInstance());
~TlsConnectionFactoryPosix() override;
- // TlsConnectionFactory overrides
+ // TlsConnectionFactory overrides.
+ //
+ // TODO(jophba, rwkeane): Determine how to handle multiple connection attempts
+ // to the same remote_address, and how to distinguish errors.
void Connect(const IPEndpoint& remote_address,
const TlsConnectOptions& options) override;
-
void SetListenCredentials(const TlsCredentials& credentials) override;
void Listen(const IPEndpoint& local_address,
const TlsListenOptions& options) override;
+ private:
// TlsDataRouterPosix::SocketObserver overrides.
void OnConnectionPending(StreamSocketPosix* socket) override;
- private:
// Configures a new SSL connection when a StreamSocket connection is accepted.
void OnSocketAccepted(std::unique_ptr<StreamSocket> socket);
@@ -59,6 +60,11 @@ class TlsConnectionFactoryPosix : public TlsConnectionFactory,
// factory.
void Initialize();
+ // Called on any thread, to post a task to notify the Client that a connection
+ // failure or other error has occurred.
+ void DispatchConnectionFailed(const IPEndpoint& remote_endpoint);
+ void DispatchError(Error error);
+
// Thread-safe mechanism to ensure Initialize() is only called once.
std::once_flag init_instance_flag_;
@@ -66,13 +72,16 @@ class TlsConnectionFactoryPosix : public TlsConnectionFactory,
// from the SSL_CTX is non-trivial, so we store a property instead.
bool listen_credentials_set_ = false;
- // The task runner used for internal operations.
+ Client* const client_;
TaskRunner* const task_runner_;
+ PlatformClientPosix* const platform_client_;
// SSL context, for creating SSL Connections via BoringSSL.
bssl::UniquePtr<SSL_CTX> ssl_context_;
- PlatformClientPosix* platform_client_;
+ WeakPtrFactory<TlsConnectionFactoryPosix> weak_factory_{this};
+
+ OSP_DISALLOW_COPY_AND_ASSIGN(TlsConnectionFactoryPosix);
};
} // namespace platform
diff --git a/platform/impl/tls_connection_posix.cc b/platform/impl/tls_connection_posix.cc
index 06058f5f..a67e4f04 100644
--- a/platform/impl/tls_connection_posix.cc
+++ b/platform/impl/tls_connection_posix.cc
@@ -21,6 +21,7 @@
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "platform/api/logging.h"
+#include "platform/api/task_runner.h"
#include "platform/base/error.h"
#include "platform/impl/stream_socket.h"
#include "util/crypto/openssl_util.h"
@@ -32,10 +33,11 @@ namespace platform {
TlsConnectionPosix::TlsConnectionPosix(IPEndpoint local_address,
TaskRunner* task_runner,
PlatformClientPosix* platform_client)
- : TlsConnection(task_runner),
+ : task_runner_(task_runner),
+ platform_client_(platform_client),
socket_(std::make_unique<StreamSocketPosix>(local_address)),
- buffer_(this),
- platform_client_(platform_client) {
+ buffer_(this) {
+ OSP_DCHECK(task_runner_);
if (platform_client_) {
platform_client_->tls_data_router()->RegisterConnection(this);
}
@@ -44,10 +46,11 @@ TlsConnectionPosix::TlsConnectionPosix(IPEndpoint local_address,
TlsConnectionPosix::TlsConnectionPosix(IPAddress::Version version,
TaskRunner* task_runner,
PlatformClientPosix* platform_client)
- : TlsConnection(task_runner),
+ : task_runner_(task_runner),
+ platform_client_(platform_client),
socket_(std::make_unique<StreamSocketPosix>(version)),
- buffer_(this),
- platform_client_(platform_client) {
+ buffer_(this) {
+ OSP_DCHECK(task_runner_);
if (platform_client_) {
platform_client_->tls_data_router()->RegisterConnection(this);
}
@@ -56,10 +59,11 @@ TlsConnectionPosix::TlsConnectionPosix(IPAddress::Version version,
TlsConnectionPosix::TlsConnectionPosix(std::unique_ptr<StreamSocket> socket,
TaskRunner* task_runner,
PlatformClientPosix* platform_client)
- : TlsConnection(task_runner),
+ : task_runner_(task_runner),
+ platform_client_(platform_client),
socket_(std::move(socket)),
- buffer_(this),
- platform_client_(platform_client) {
+ buffer_(this) {
+ OSP_DCHECK(task_runner_);
if (platform_client_) {
platform_client_->tls_data_router()->RegisterConnection(this);
}
@@ -86,42 +90,73 @@ void TlsConnectionPosix::TryReceiveMessage() {
if (bytes_read <= 0) {
const Error error = GetSSLError(ssl_.get(), bytes_read);
if (!error.ok() && (error != Error::Code::kAgain)) {
- OnError(error);
+ DispatchError(error);
}
return;
}
block.resize(bytes_read);
- OnRead(std::move(block));
+
+ task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
+ moved_block = std::move(block)]() mutable {
+ if (auto* self = weak_this.get()) {
+ if (auto* client = self->client_) {
+ client->OnRead(self, std::move(moved_block));
+ }
+ }
+ });
}
}
+void TlsConnectionPosix::SetClient(Client* client) {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+ client_ = client;
+ notified_client_buffer_is_blocked_ = false;
+}
+
void TlsConnectionPosix::Write(const void* data, size_t len) {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
buffer_.Write(data, len);
}
-IPEndpoint TlsConnectionPosix::local_address() const {
+IPEndpoint TlsConnectionPosix::GetLocalEndpoint() const {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
absl::optional<IPEndpoint> endpoint = socket_->local_address();
OSP_DCHECK(endpoint.has_value());
- return std::move(endpoint.value());
+ return endpoint.value();
}
-IPEndpoint TlsConnectionPosix::remote_address() const {
+IPEndpoint TlsConnectionPosix::GetRemoteEndpoint() const {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
absl::optional<IPEndpoint> endpoint = socket_->remote_address();
OSP_DCHECK(endpoint.has_value());
- return std::move(endpoint.value());
+ return endpoint.value();
}
void TlsConnectionPosix::NotifyWriteBufferFill(double fraction) {
+ // WARNING: This method is called on multiple threads.
+ //
+ // The following is very subtle/complex behavior: Only "writes" can increase
+ // the buffer fill, so we expect transitions into the "blocked" state to occur
+ // on the |task_runner_| thread, and |client_| will be notified
+ // *synchronously* when that happens. Likewise, only "reads" can cause
+ // transitions to the "unblocked" state; but these will not occur on the
+ // |task_runner_| thread. Thus, when unblocking, the |client_| will be
+ // notified *asynchronously*; but, that should be acceptable because it's only
+ // a race towards a buffer overrun that is of concern.
+ //
+ // TODO(rwkeane): Have Write() return a bool, and then none of this is needed.
constexpr double kBlockBufferPercentage = 0.5;
- if (fraction > kBlockBufferPercentage && !is_buffer_blocked_) {
- OnWriteBlocked();
- is_buffer_blocked_ = true;
- } else if (fraction < kBlockBufferPercentage && is_buffer_blocked_) {
- OnWriteUnblocked();
- is_buffer_blocked_ = false;
- } else if (fraction >= 0.99 && is_buffer_blocked_) {
- OnError(Error::Code::kInsufficientBuffer);
+ if (fraction > kBlockBufferPercentage &&
+ !notified_client_buffer_is_blocked_) {
+ NotifyClientOfWriteBlockStatusSequentially(true);
+ } else if (fraction < kBlockBufferPercentage &&
+ notified_client_buffer_is_blocked_) {
+ NotifyClientOfWriteBlockStatusSequentially(false);
+ } else if (fraction >= 0.99) {
+ DispatchError(Error::Code::kInsufficientBuffer);
}
}
@@ -136,12 +171,56 @@ void TlsConnectionPosix::SendAvailableBytes() {
if (result <= 0) {
const Error result_error = GetSSLError(ssl_.get(), result);
if (!result_error.ok() && (result_error.code() != Error::Code::kAgain)) {
- OnError(result_error);
+ DispatchError(result_error);
}
} else {
buffer_.Consume(static_cast<size_t>(result));
}
}
+void TlsConnectionPosix::DispatchError(Error error) {
+ task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
+ moved_error = std::move(error)]() mutable {
+ if (auto* self = weak_this.get()) {
+ if (auto* client = self->client_) {
+ client->OnError(self, std::move(moved_error));
+ }
+ }
+ });
+}
+
+void TlsConnectionPosix::NotifyClientOfWriteBlockStatusSequentially(
+ bool is_blocked) {
+ if (!task_runner_->IsRunningOnTaskRunner()) {
+ task_runner_->PostTask(
+ [weak_this = weak_factory_.GetWeakPtr(), is_blocked = is_blocked] {
+ if (auto* self = weak_this.get()) {
+ OSP_DCHECK(self->task_runner_->IsRunningOnTaskRunner());
+ self->NotifyClientOfWriteBlockStatusSequentially(is_blocked);
+ }
+ });
+ return;
+ }
+
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ if (!client_) {
+ return;
+ }
+
+ // Check again, now that the block/unblock state change is happening
+ // in-sequence (it originated from parallel executions).
+ if (notified_client_buffer_is_blocked_ == is_blocked) {
+ return;
+ }
+
+ notified_client_buffer_is_blocked_ = is_blocked;
+ if (is_blocked) {
+ client_->OnWriteBlocked(this);
+ } else {
+ client_->OnWriteUnblocked(this);
+ }
+}
+
} // namespace platform
} // namespace openscreen
diff --git a/platform/impl/tls_connection_posix.h b/platform/impl/tls_connection_posix.h
index 23b46278..c0089844 100644
--- a/platform/impl/tls_connection_posix.h
+++ b/platform/impl/tls_connection_posix.h
@@ -9,22 +9,43 @@
#include <atomic>
#include <memory>
-#include <string>
-#include <vector>
-#include "platform/api/task_runner.h"
#include "platform/api/tls_connection.h"
-#include "platform/base/socket_state.h"
#include "platform/impl/platform_client_posix.h"
#include "platform/impl/stream_socket_posix.h"
#include "platform/impl/tls_write_buffer.h"
+#include "platform/impl/weak_ptr.h"
namespace openscreen {
namespace platform {
+class TaskRunner;
+class TlsConnectionFactoryPosix;
+
class TlsConnectionPosix : public TlsConnection,
public TlsWriteBuffer::Observer {
public:
+ ~TlsConnectionPosix() override;
+
+ // Sends any available bytes from this connection's buffer_.
+ virtual void SendAvailableBytes();
+
+ // Read out a block/message, if one is available, and notify this instance's
+ // TlsConnection::Client.
+ virtual void TryReceiveMessage();
+
+ // TlsConnection overrides.
+ void SetClient(Client* client) override;
+ void Write(const void* data, size_t len) override;
+ IPEndpoint GetLocalEndpoint() const override;
+ IPEndpoint GetRemoteEndpoint() const override;
+
+ // TlsWriteBuffer::Observer overrides.
+ void NotifyWriteBufferFill(double fraction) override;
+
+ protected:
+ friend class TlsConnectionFactoryPosix;
+
TlsConnectionPosix(IPEndpoint local_address,
TaskRunner* task_runner,
PlatformClientPosix* platform_client =
@@ -37,33 +58,32 @@ class TlsConnectionPosix : public TlsConnection,
TaskRunner* task_runner,
PlatformClientPosix* platform_client =
PlatformClientPosix::GetInstance());
- ~TlsConnectionPosix();
- // Sends any available bytes from this connection's buffer_.
- virtual void SendAvailableBytes();
+ private:
+ // Called on any thread, to post a task to notify the Client that an |error|
+ // has occurred.
+ void DispatchError(Error error);
- // Read out a block/message, if one is available, and notify this instance's
- // TlsConnection::Client.
- virtual void TryReceiveMessage();
+ // Helper to call OnWriteBlocked() or OnWriteUnblocked(). If this is not
+ // called within a task run by |task_runner_|, it trampolines by posting a
+ // task to call itself back via |task_runner_|. See comments in implementation
+ // of NotifyWriteBufferFill() for further details.
+ void NotifyClientOfWriteBlockStatusSequentially(bool is_blocked);
- // TlsConnection overrides.
- void Write(const void* data, size_t len) override;
- IPEndpoint local_address() const override;
- IPEndpoint remote_address() const override;
+ TaskRunner* const task_runner_;
+ PlatformClientPosix* const platform_client_;
- // TlsWriteBuffer::Observer overrides.
- void NotifyWriteBufferFill(double fraction) override;
+ Client* client_ = nullptr;
- private:
std::unique_ptr<StreamSocket> socket_;
bssl::UniquePtr<SSL> ssl_;
- std::atomic_bool is_buffer_blocked_{false};
+ std::atomic_bool notified_client_buffer_is_blocked_{false};
TlsWriteBuffer buffer_;
- PlatformClientPosix* platform_client_;
+ WeakPtrFactory<TlsConnectionPosix> weak_factory_{this};
- friend class TlsConnectionFactoryPosix;
+ OSP_DISALLOW_COPY_AND_ASSIGN(TlsConnectionPosix);
};
} // namespace platform
diff --git a/platform/impl/udp_socket_posix.cc b/platform/impl/udp_socket_posix.cc
index f98dc866..c4226024 100644
--- a/platform/impl/udp_socket_posix.cc
+++ b/platform/impl/udp_socket_posix.cc
@@ -425,10 +425,11 @@ void UdpSocketPosix::ReceiveMessage() {
// calling into all the other methods.
if (is_closed()) {
- task_runner_->PostTask([this] {
- // TODO(issues/71): |this| may be invalid at this point.
- if (client_) {
- client_->OnRead(this, Error::Code::kSocketClosedFailure);
+ task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr()] {
+ if (auto* self = weak_this.get()) {
+ if (auto* client = self->client_) {
+ client->OnRead(self, Error::Code::kSocketClosedFailure);
+ }
}
});
return;
@@ -437,11 +438,13 @@ void UdpSocketPosix::ReceiveMessage() {
ssize_t bytes_available = recv(handle_.fd, nullptr, 0, MSG_PEEK | MSG_TRUNC);
if (bytes_available == -1) {
task_runner_->PostTask(
- [this, error = ChooseError(errno,
- Error::Code::kSocketReadFailure)]() mutable {
- // TODO(issues/71): |this| may be invalid at this point.
- if (client_) {
- client_->OnRead(this, std::move(error));
+ [weak_this = weak_factory_.GetWeakPtr(),
+ error =
+ ChooseError(errno, Error::Code::kSocketReadFailure)]() mutable {
+ if (auto* self = weak_this.get()) {
+ if (auto* client = self->client_) {
+ client->OnRead(self, std::move(error));
+ }
}
});
return;
@@ -466,12 +469,14 @@ void UdpSocketPosix::ReceiveMessage() {
}
task_runner_->PostTask(
- [this, read_result = result.ok() ? ErrorOr<UdpPacket>(std::move(packet))
- : ErrorOr<UdpPacket>(
- std::move(result))]() mutable {
- // TODO(issues/71): |this| may be invalid at this point.
- if (client_) {
- client_->OnRead(this, std::move(read_result));
+ [weak_this = weak_factory_.GetWeakPtr(),
+ read_result = result.ok()
+ ? ErrorOr<UdpPacket>(std::move(packet))
+ : ErrorOr<UdpPacket>(std::move(result))]() mutable {
+ if (auto* self = weak_this.get()) {
+ if (auto* client = self->client_) {
+ client->OnRead(self, std::move(read_result));
+ }
}
});
}
diff --git a/platform/impl/udp_socket_posix.h b/platform/impl/udp_socket_posix.h
index dcd614a5..ba44d61a 100644
--- a/platform/impl/udp_socket_posix.h
+++ b/platform/impl/udp_socket_posix.h
@@ -7,8 +7,10 @@
#include "absl/types/optional.h"
#include "platform/api/udp_socket.h"
+#include "platform/base/macros.h"
#include "platform/impl/platform_client_posix.h"
#include "platform/impl/socket_handle_posix.h"
+#include "platform/impl/weak_ptr.h"
namespace openscreen {
namespace platform {
@@ -78,7 +80,11 @@ class UdpSocketPosix : public UdpSocket {
// port is non-zero, it is assumed never to change again.
mutable IPEndpoint local_endpoint_;
- PlatformClientPosix* platform_client_;
+ WeakPtrFactory<UdpSocketPosix> weak_factory_{this};
+
+ PlatformClientPosix* const platform_client_;
+
+ OSP_DISALLOW_COPY_AND_ASSIGN(UdpSocketPosix);
};
} // namespace platform
diff --git a/platform/impl/weak_ptr.h b/platform/impl/weak_ptr.h
new file mode 100644
index 00000000..e065b067
--- /dev/null
+++ b/platform/impl/weak_ptr.h
@@ -0,0 +1,216 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef PLATFORM_IMPL_WEAK_PTR_H_
+#define PLATFORM_IMPL_WEAK_PTR_H_
+
+#include <memory>
+
+#include "platform/api/logging.h"
+
+namespace openscreen {
+
+// Weak pointers are pointers to an object that do not affect its lifetime,
+// and which may be invalidated (i.e. reset to nullptr) by the object, or its
+// owner, at any time; most commonly when the object is about to be deleted.
+//
+// Weak pointers are useful when an object needs to be accessed safely by one
+// or more objects other than its owner, and those callers can cope with the
+// object vanishing and e.g. tasks posted to it being silently dropped.
+// Reference-counting such an object would complicate the ownership graph and
+// make it harder to reason about the object's lifetime.
+//
+// EXAMPLE:
+//
+// class Controller {
+// public:
+// void SpawnWorker() { new Worker(weak_factory_.GetWeakPtr()); }
+// void WorkComplete(const Result& result) { ... }
+// private:
+// // Member variables should appear before the WeakPtrFactory, to ensure
+// // that any WeakPtrs to Controller are invalidated before its members
+// // variable's destructors are executed, rendering them invalid.
+// WeakPtrFactory<Controller> weak_factory_{this};
+// };
+//
+// class Worker {
+// public:
+// explicit Worker(WeakPtr<Controller> controller)
+// : controller_(std::move(controller)) {}
+// private:
+// void DidCompleteAsynchronousProcessing(const Result& result) {
+// if (controller_)
+// controller_->WorkComplete(result);
+// delete this;
+// }
+// const WeakPtr<Controller> controller_;
+// };
+//
+// With this implementation a caller may use SpawnWorker() to dispatch multiple
+// Workers and subsequently delete the Controller, without waiting for all
+// Workers to have completed.
+//
+// ------------------------- IMPORTANT: Thread-safety -------------------------
+//
+// Generally, Open Screen code is meant to be single-threaded. For the few
+// exceptional cases, the following is relevant:
+//
+// WeakPtrs may be created from WeakPtrFactory, and also duplicated/moved on any
+// thread/sequence. However, they may only be dereferenced on the same
+// thread/sequence that will ultimately execute the WeakPtrFactory destructor or
+// call InvalidateWeakPtrs(). Otherwise, use-during-free or use-after-free is
+// possible.
+//
+// openscreen::WeakPtr and WeakPtrFactory are similar, but not identical, to
+// Chromium's base::WeakPtrFactory. Open Screen WeakPtrs may be safely created
+// from WeakPtrFactory on any thread/sequence, since they are backed by the
+// thread-safe bookkeeping of std::shared_ptr<>.
+
+template <typename T>
+class WeakPtrFactory;
+
+template <typename T>
+class WeakPtr {
+ public:
+ WeakPtr() = default;
+ ~WeakPtr() = default;
+
+ // Copy/Move constructors and assignment operators.
+ WeakPtr(const WeakPtr& other) : impl_(other.impl_) {}
+
+ WeakPtr(WeakPtr&& other) noexcept : impl_(std::move(other.impl_)) {}
+
+ WeakPtr& operator=(const WeakPtr& other) {
+ impl_ = other.impl_;
+ return *this;
+ }
+
+ WeakPtr& operator=(WeakPtr&& other) noexcept {
+ impl_ = std::move(other.impl_);
+ return *this;
+ }
+
+ // Create/Assign from nullptr.
+ WeakPtr(std::nullptr_t) {}
+
+ WeakPtr& operator=(std::nullptr_t) {
+ impl_.reset();
+ return *this;
+ }
+
+ // Copy/Move constructors and assignment operators with upcast conversion.
+ template <typename U>
+ WeakPtr(const WeakPtr<U>& other) : impl_(other.as_std_weak_ptr()) {}
+
+ template <typename U>
+ WeakPtr(WeakPtr<U>&& other) noexcept
+ : impl_(std::move(other).as_std_weak_ptr()) {}
+
+ template <typename U>
+ WeakPtr& operator=(const WeakPtr<U>& other) {
+ impl_ = other.as_std_weak_ptr();
+ return *this;
+ }
+
+ template <typename U>
+ WeakPtr& operator=(WeakPtr<U>&& other) noexcept {
+ impl_ = std::move(other).as_std_weak_ptr();
+ return *this;
+ }
+
+ // Accessors.
+ T* get() const { return impl_.lock().get(); }
+
+ T& operator*() const {
+ T* const pointer = get();
+ OSP_DCHECK(pointer);
+ return *pointer;
+ }
+
+ T* operator->() const {
+ T* const pointer = get();
+ OSP_DCHECK(pointer);
+ return pointer;
+ }
+
+ // Allow conditionals to test validity, e.g. if (weak_ptr) {...}
+ explicit operator bool() const { return get() != nullptr; }
+
+ // Conversion to std::weak_ptr<T>. It is unsafe to convert in the other
+ // direction. See comments for private constructors, below.
+ const std::weak_ptr<T>& as_std_weak_ptr() const& { return impl_; }
+ std::weak_ptr<T> as_std_weak_ptr() && { return std::move(impl_); }
+
+ private:
+ friend class WeakPtrFactory<T>;
+
+ // Called by WeakPtrFactory<T> and the WeakPtr<T> upcast conversion
+ // constructors and assigners. These are purposely not being exposed publicly
+ // because that would allow a WeakPtr<T> to be valid/invalid by a different
+ // ownership/threading model than the intended one (see top-level comments).
+ template <typename U>
+ explicit WeakPtr(const std::weak_ptr<U>& other) : impl_(other) {}
+
+ template <typename U>
+ explicit WeakPtr(std::weak_ptr<U>&& other) noexcept
+ : impl_(std::move(other)) {}
+
+ std::weak_ptr<T> impl_;
+};
+
+// Allow callers to compare WeakPtrs against nullptr to test validity.
+template <typename T>
+bool operator!=(const WeakPtr<T>& weak_ptr, std::nullptr_t) {
+ return weak_ptr.get() != nullptr;
+}
+template <typename T>
+bool operator!=(std::nullptr_t, const WeakPtr<T>& weak_ptr) {
+ return weak_ptr.get() != nullptr;
+}
+template <typename T>
+bool operator==(const WeakPtr<T>& weak_ptr, std::nullptr_t) {
+ return weak_ptr.get() == nullptr;
+}
+template <typename T>
+bool operator==(std::nullptr_t, const WeakPtr<T>& weak_ptr) {
+ return weak_ptr == nullptr;
+}
+
+template <typename T>
+class WeakPtrFactory {
+ public:
+ explicit WeakPtrFactory(T* instance) { Reset(instance); }
+ WeakPtrFactory(WeakPtrFactory&& other) noexcept = default;
+ WeakPtrFactory& operator=(WeakPtrFactory&& other) noexcept = default;
+
+ // Thread-safe: WeakPtrs may be created on any thread/seuence. They may also
+ // be copied and moved on any thread/sequence. However, they MUST only be
+ // dereferenced on the same thread/sequence that calls the destructor or
+ // InvalidateWeakPtrs().
+ WeakPtr<T> GetWeakPtr() const {
+ return WeakPtr<T>(std::weak_ptr<T>(bookkeeper_));
+ }
+
+ // Destruction and Invalidation: These must be called on the same
+ // thread/sequence that dereferences any WeakPtrs to avoid use-after-free
+ // bugs.
+ ~WeakPtrFactory() = default;
+ void InvalidateWeakPtrs() { Reset(bookkeeper_.get()); }
+
+ private:
+ WeakPtrFactory(const WeakPtrFactory& other) = delete;
+ WeakPtrFactory& operator=(const WeakPtrFactory& other) = delete;
+
+ void Reset(T* instance) {
+ // T is owned externally to WeakPtrFactory. Thus, provide a no-op Deleter.
+ bookkeeper_ = {instance, [](T* instance) {}};
+ }
+
+ // Manages the std::weak_ptr's referring to T. Does not own T.
+ std::shared_ptr<T> bookkeeper_;
+};
+
+} // namespace openscreen
+
+#endif // PLATFORM_IMPL_WEAK_PTR_H_
diff --git a/platform/impl/weak_ptr_unittest.cc b/platform/impl/weak_ptr_unittest.cc
new file mode 100644
index 00000000..b06574bb
--- /dev/null
+++ b/platform/impl/weak_ptr_unittest.cc
@@ -0,0 +1,186 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "platform/impl/weak_ptr.h"
+
+#include "gtest/gtest.h"
+
+namespace openscreen {
+namespace {
+
+class SomeClass {
+ public:
+ virtual ~SomeClass() = default;
+ virtual int GetValue() const { return 42; }
+};
+
+struct SomeSubclass : public SomeClass {
+ public:
+ ~SomeSubclass() final = default;
+ int GetValue() const override { return 999; }
+};
+
+TEST(WeakPtrTest, InteractsWithNullptr) {
+ WeakPtr<const int> default_constructed; // Invoke default constructor.
+ EXPECT_TRUE(default_constructed == nullptr);
+ EXPECT_TRUE(nullptr == default_constructed);
+ EXPECT_FALSE(default_constructed != nullptr);
+ EXPECT_FALSE(nullptr != default_constructed);
+
+ WeakPtr<const int> null_constructed = nullptr; // Invoke construct-from-null.
+ EXPECT_TRUE(null_constructed == nullptr);
+ EXPECT_TRUE(nullptr == null_constructed);
+ EXPECT_FALSE(null_constructed != nullptr);
+ EXPECT_FALSE(nullptr != null_constructed);
+
+ const int foo = 42;
+ WeakPtrFactory<const int> factory(&foo);
+ WeakPtr<const int> not_null = factory.GetWeakPtr();
+ EXPECT_TRUE(not_null != nullptr);
+ EXPECT_TRUE(nullptr != not_null);
+ EXPECT_FALSE(not_null == nullptr);
+ EXPECT_FALSE(nullptr == not_null);
+}
+
+TEST(WeakPtrTest, CopyConstructsAndAssigns) {
+ SomeSubclass foo;
+ WeakPtrFactory<SomeSubclass> factory(&foo);
+
+ WeakPtr<SomeSubclass> weak_ptr = factory.GetWeakPtr();
+ EXPECT_TRUE(weak_ptr);
+ EXPECT_EQ(&foo, weak_ptr.get());
+
+ // Normal copy constructor.
+ WeakPtr<SomeSubclass> copied0 = weak_ptr;
+ EXPECT_EQ(&foo, weak_ptr.get()); // Did not mutate original.
+ EXPECT_TRUE(copied0);
+ EXPECT_EQ(&foo, copied0.get());
+
+ // Copy constructor, adding const qualifier.
+ WeakPtr<const SomeSubclass> copied1 = weak_ptr;
+ EXPECT_EQ(&foo, weak_ptr.get()); // Did not mutate original.
+ EXPECT_TRUE(copied1);
+ EXPECT_EQ(&foo, copied1.get());
+
+ // Normal copy assignment.
+ WeakPtr<SomeSubclass> assigned0;
+ EXPECT_FALSE(assigned0);
+ assigned0 = copied0;
+ EXPECT_EQ(&foo, copied0.get()); // Did not mutate original.
+ EXPECT_TRUE(assigned0);
+ EXPECT_EQ(&foo, assigned0.get());
+
+ // Copy assignment, adding const qualifier.
+ WeakPtr<const SomeSubclass> assigned1;
+ EXPECT_FALSE(assigned1);
+ assigned1 = copied0;
+ EXPECT_EQ(&foo, copied0.get()); // Did not mutate original.
+ EXPECT_TRUE(assigned1);
+ EXPECT_EQ(&foo, assigned1.get());
+
+ // Upcast copy constructor.
+ WeakPtr<SomeClass> copied2 = weak_ptr;
+ EXPECT_EQ(&foo, weak_ptr.get()); // Did not mutate original.
+ EXPECT_TRUE(copied2);
+ EXPECT_EQ(&foo, copied2.get());
+ EXPECT_EQ(999, (*copied2).GetValue());
+ EXPECT_EQ(999, copied2->GetValue());
+
+ // Upcast copy assignment.
+ WeakPtr<SomeClass> assigned2;
+ EXPECT_FALSE(assigned2);
+ assigned2 = weak_ptr;
+ EXPECT_EQ(&foo, weak_ptr.get()); // Did not mutate original.
+ EXPECT_TRUE(assigned2);
+ EXPECT_EQ(&foo, assigned2.get());
+ EXPECT_EQ(999, (*assigned2).GetValue());
+ EXPECT_EQ(999, assigned2->GetValue());
+}
+
+TEST(WeakPtrTest, MoveConstructsAndAssigns) {
+ SomeSubclass foo;
+ WeakPtrFactory<SomeSubclass> factory(&foo);
+
+ // Normal move constructor.
+ WeakPtr<SomeSubclass> weak_ptr = factory.GetWeakPtr();
+ WeakPtr<SomeSubclass> moved0 = std::move(weak_ptr);
+ EXPECT_FALSE(weak_ptr); // Original becomes null.
+ EXPECT_TRUE(moved0);
+ EXPECT_EQ(&foo, moved0.get());
+
+ // Move constructor, adding const qualifier.
+ weak_ptr = factory.GetWeakPtr();
+ WeakPtr<const SomeSubclass> moved1 = std::move(weak_ptr);
+ EXPECT_FALSE(weak_ptr); // Original becomes null.
+ EXPECT_TRUE(moved1);
+ EXPECT_EQ(&foo, moved1.get());
+
+ // Normal move assignment.
+ weak_ptr = factory.GetWeakPtr();
+ WeakPtr<SomeSubclass> assigned0;
+ EXPECT_FALSE(assigned0);
+ assigned0 = std::move(weak_ptr);
+ EXPECT_FALSE(weak_ptr); // Original becomes null.
+ EXPECT_TRUE(assigned0);
+ EXPECT_EQ(&foo, assigned0.get());
+
+ // Move assignment, adding const qualifier.
+ weak_ptr = factory.GetWeakPtr();
+ WeakPtr<const SomeSubclass> assigned1;
+ EXPECT_FALSE(assigned1);
+ assigned1 = std::move(weak_ptr);
+ EXPECT_FALSE(weak_ptr); // Original becomes null.
+ EXPECT_TRUE(assigned1);
+ EXPECT_EQ(&foo, assigned1.get());
+
+ // Upcast move constructor.
+ weak_ptr = factory.GetWeakPtr();
+ WeakPtr<SomeClass> moved2 = std::move(weak_ptr);
+ EXPECT_FALSE(weak_ptr); // Original becomes null.
+ EXPECT_TRUE(moved2);
+ EXPECT_EQ(&foo, moved2.get());
+ EXPECT_EQ(999, (*moved2).GetValue()); // Result from subclass's GetValue().
+ EXPECT_EQ(999, moved2->GetValue()); // Result from subclass's GetValue().
+
+ // Upcast move assignment.
+ weak_ptr = factory.GetWeakPtr();
+ WeakPtr<SomeClass> assigned2;
+ EXPECT_FALSE(assigned2);
+ assigned2 = std::move(weak_ptr);
+ EXPECT_FALSE(weak_ptr); // Original becomes null.
+ EXPECT_TRUE(assigned2);
+ EXPECT_EQ(&foo, assigned2.get());
+ EXPECT_EQ(999,
+ (*assigned2).GetValue()); // Result from subclass's GetValue().
+ EXPECT_EQ(999, assigned2->GetValue()); // Result from subclass's GetValue().
+}
+
+TEST(WeakPtrTest, InvalidatesWeakPtrs) {
+ const int foo = 1337;
+ WeakPtrFactory<const int> factory(&foo);
+
+ // Thrice: Create weak pointers and invalidate them. This is done more than
+ // once to confirm the factory can create valid WeakPtrs again after each
+ // InvalidateWeakPtrs() call.
+ for (int i = 0; i < 3; ++i) {
+ // Create three WeakPtrs, two from the factory, one as a copy of another
+ // WeakPtr.
+ WeakPtr<const int> ptr0 = factory.GetWeakPtr();
+ WeakPtr<const int> ptr1 = factory.GetWeakPtr();
+ WeakPtr<const int> ptr2 = ptr1;
+ EXPECT_EQ(&foo, ptr0.get());
+ EXPECT_EQ(&foo, ptr1.get());
+ EXPECT_EQ(&foo, ptr2.get());
+
+ // Invalidate all outstanding WeakPtrs from the factory, and confirm all
+ // outstanding WeakPtrs become null.
+ factory.InvalidateWeakPtrs();
+ EXPECT_FALSE(ptr0);
+ EXPECT_FALSE(ptr1);
+ EXPECT_FALSE(ptr2);
+ }
+}
+
+} // namespace
+} // namespace openscreen