aboutsummaryrefslogtreecommitdiff
path: root/cast
diff options
context:
space:
mode:
authorbtolsch <btolsch@chromium.org>2019-10-15 12:54:10 -0700
committerCommit Bot <commit-bot@chromium.org>2019-10-15 20:03:29 +0000
commita04b13eaa77d8ca31c0f06e32231bef84f96e32d (patch)
tree5158547e199df455b63cbe7f93c6c99748054181 /cast
parentc56f0388085f0e0d4df657498381f35317fbf041 (diff)
downloadopenscreen-a04b13eaa77d8ca31c0f06e32231bef84f96e32d.tar.gz
Add cast sender socket factory
This change adds a CastSocket factory for the sender-side which performs the sender auth challenge and verification before passing a CastSocket back to the caller. Bug: openscreen:59 Change-Id: Ibbbdb2b8881e385cc0a8defbe309c7f10a2af323 Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/1834457 Commit-Queue: Brandon Tolsch <btolsch@chromium.org> Reviewed-by: Ryan Keane <rwkeane@google.com>
Diffstat (limited to 'cast')
-rw-r--r--cast/common/channel/BUILD.gn1
-rw-r--r--cast/common/channel/cast_socket.cc7
-rw-r--r--cast/common/channel/cast_socket.h8
-rw-r--r--cast/common/channel/message_util.h39
-rw-r--r--cast/sender/channel/BUILD.gn4
-rw-r--r--cast/sender/channel/message_util.cc34
-rw-r--r--cast/sender/channel/message_util.h21
-rw-r--r--cast/sender/channel/sender_socket_factory.cc166
-rw-r--r--cast/sender/channel/sender_socket_factory.h104
9 files changed, 381 insertions, 3 deletions
diff --git a/cast/common/channel/BUILD.gn b/cast/common/channel/BUILD.gn
index 71b393b2..7a362e3c 100644
--- a/cast/common/channel/BUILD.gn
+++ b/cast/common/channel/BUILD.gn
@@ -8,6 +8,7 @@ source_set("channel") {
"cast_socket.h",
"message_framer.cc",
"message_framer.h",
+ "message_util.h",
]
deps = [
diff --git a/cast/common/channel/cast_socket.cc b/cast/common/channel/cast_socket.cc
index 8ad61542..ce35e54e 100644
--- a/cast/common/channel/cast_socket.cc
+++ b/cast/common/channel/cast_socket.cc
@@ -4,6 +4,8 @@
#include "cast/common/channel/cast_socket.h"
+#include <atomic>
+
#include "cast/common/channel/message_framer.h"
#include "platform/api/logging.h"
@@ -14,6 +16,11 @@ using message_serialization::DeserializeResult;
using openscreen::ErrorOr;
using openscreen::platform::TlsConnection;
+uint32_t GetNextSocketId() {
+ static std::atomic<uint32_t> id(1);
+ return id++;
+}
+
CastSocket::CastSocket(std::unique_ptr<TlsConnection> connection,
Client* client,
uint32_t socket_id)
diff --git a/cast/common/channel/cast_socket.h b/cast/common/channel/cast_socket.h
index a8fa3c48..0c173ef5 100644
--- a/cast/common/channel/cast_socket.h
+++ b/cast/common/channel/cast_socket.h
@@ -17,6 +17,8 @@ using TlsConnection = openscreen::platform::TlsConnection;
class CastMessage;
+uint32_t GetNextSocketId();
+
// Represents a simple message-oriented socket for communicating with the Cast
// V2 protocol. It isn't thread-safe, so it should only be used on the same
// TaskRunner thread as its TlsConnection.
@@ -38,9 +40,9 @@ class CastSocket : public TlsConnection::Client {
~CastSocket();
// Sends |message| immediately unless the underlying TLS connection is
- // write-blocked, in which case |message| will be queued. No error is
- // returned for both queueing and successful sending. An error will be
- // returned if |message| cannot be serialized for any reason.
+ // write-blocked, in which case |message| will be queued. An error will be
+ // returned if |message| cannot be serialized for any reason, even while
+ // write-blocked.
Error SendMessage(const CastMessage& message);
void set_client(Client* client) {
diff --git a/cast/common/channel/message_util.h b/cast/common/channel/message_util.h
new file mode 100644
index 00000000..5b84dbd4
--- /dev/null
+++ b/cast/common/channel/message_util.h
@@ -0,0 +1,39 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_COMMON_CHANNEL_MESSAGE_UTIL_H_
+#define CAST_COMMON_CHANNEL_MESSAGE_UTIL_H_
+
+#include "cast/common/channel/proto/cast_channel.pb.h"
+
+namespace cast {
+namespace channel {
+
+// Reserved message namespaces for internal messages.
+static constexpr char kCastInternalNamespacePrefix[] =
+ "urn:x-cast:com.google.cast.";
+static constexpr char kAuthNamespace[] =
+ "urn:x-cast:com.google.cast.tp.deviceauth";
+static constexpr char kHeartbeatNamespace[] =
+ "urn:x-cast:com.google.cast.tp.heartbeat";
+static constexpr char kConnectionNamespace[] =
+ "urn:x-cast:com.google.cast.tp.connection";
+static constexpr char kReceiverNamespace[] =
+ "urn:x-cast:com.google.cast.receiver";
+static constexpr char kBroadcastNamespace[] =
+ "urn:x-cast:com.google.cast.broadcast";
+static constexpr char kMediaNamespace[] = "urn:x-cast:com.google.cast.media";
+
+// Sender and receiver IDs to use for platform messages.
+static constexpr char kPlatformSenderId[] = "sender-0";
+static constexpr char kPlatformReceiverId[] = "receiver-0";
+
+inline bool IsAuthMessage(const CastMessage& message) {
+ return message.namespace_() == kAuthNamespace;
+}
+
+} // namespace channel
+} // namespace cast
+
+#endif // CAST_COMMON_CHANNEL_MESSAGE_UTIL_H_
diff --git a/cast/sender/channel/BUILD.gn b/cast/sender/channel/BUILD.gn
index 85f633b0..5fbed9ce 100644
--- a/cast/sender/channel/BUILD.gn
+++ b/cast/sender/channel/BUILD.gn
@@ -6,6 +6,10 @@ source_set("channel") {
sources = [
"cast_auth_util.cc",
"cast_auth_util.h",
+ "message_util.cc",
+ "message_util.h",
+ "sender_socket_factory.cc",
+ "sender_socket_factory.h",
]
deps = [
diff --git a/cast/sender/channel/message_util.cc b/cast/sender/channel/message_util.cc
new file mode 100644
index 00000000..ab3ed5d8
--- /dev/null
+++ b/cast/sender/channel/message_util.cc
@@ -0,0 +1,34 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/sender/channel/message_util.h"
+
+#include "cast/sender/channel/cast_auth_util.h"
+
+namespace cast {
+namespace channel {
+
+CastMessage CreateAuthChallengeMessage(const AuthContext& auth_context) {
+ CastMessage message;
+ DeviceAuthMessage auth_message;
+
+ AuthChallenge* challenge = auth_message.mutable_challenge();
+ challenge->set_sender_nonce(auth_context.nonce());
+ challenge->set_hash_algorithm(SHA256);
+
+ std::string auth_message_string;
+ auth_message.SerializeToString(&auth_message_string);
+
+ message.set_protocol_version(CastMessage::CASTV2_1_0);
+ message.set_source_id(kPlatformSenderId);
+ message.set_destination_id(kPlatformReceiverId);
+ message.set_namespace_(kAuthNamespace);
+ message.set_payload_type(CastMessage_PayloadType_BINARY);
+ message.set_payload_binary(auth_message_string);
+
+ return message;
+}
+
+} // namespace channel
+} // namespace cast
diff --git a/cast/sender/channel/message_util.h b/cast/sender/channel/message_util.h
new file mode 100644
index 00000000..e2da0cd8
--- /dev/null
+++ b/cast/sender/channel/message_util.h
@@ -0,0 +1,21 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_SENDER_CHANNEL_MESSAGE_UTIL_H_
+#define CAST_SENDER_CHANNEL_MESSAGE_UTIL_H_
+
+#include "cast/common/channel/message_util.h"
+#include "cast/common/channel/proto/cast_channel.pb.h"
+
+namespace cast {
+namespace channel {
+
+class AuthContext;
+
+CastMessage CreateAuthChallengeMessage(const AuthContext& auth_context);
+
+} // namespace channel
+} // namespace cast
+
+#endif // CAST_SENDER_CHANNEL_MESSAGE_UTIL_H_
diff --git a/cast/sender/channel/sender_socket_factory.cc b/cast/sender/channel/sender_socket_factory.cc
new file mode 100644
index 00000000..83e73370
--- /dev/null
+++ b/cast/sender/channel/sender_socket_factory.cc
@@ -0,0 +1,166 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/sender/channel/sender_socket_factory.h"
+
+#include "cast/common/channel/cast_socket.h"
+#include "cast/sender/channel/message_util.h"
+
+namespace cast {
+namespace channel {
+
+using openscreen::platform::TlsConnectOptions;
+
+bool operator<(const std::unique_ptr<SenderSocketFactory::PendingAuth>& a,
+ uint32_t b) {
+ return a && a->socket->socket_id() < b;
+}
+
+bool operator<(uint32_t a,
+ const std::unique_ptr<SenderSocketFactory::PendingAuth>& b) {
+ return b && a < b->socket->socket_id();
+}
+
+SenderSocketFactory::SenderSocketFactory(Client* client) : client_(client) {
+ OSP_DCHECK(client);
+}
+
+SenderSocketFactory::~SenderSocketFactory() = default;
+
+void SenderSocketFactory::Connect(const IPEndpoint& endpoint,
+ DeviceMediaPolicy media_policy,
+ CastSocket::Client* client) {
+ OSP_DCHECK(client);
+ auto it = FindPendingConnection(endpoint);
+ if (it == pending_connections_.end()) {
+ pending_connections_.emplace_back(
+ PendingConnection{endpoint, media_policy, client});
+ factory_->Connect(endpoint, TlsConnectOptions{true});
+ }
+}
+
+void SenderSocketFactory::OnAccepted(
+ TlsConnectionFactory* factory,
+ X509* peer_cert,
+ std::unique_ptr<TlsConnection> connection) {
+ OSP_NOTREACHED() << "This factory is connect-only.";
+}
+
+void SenderSocketFactory::OnConnected(
+ TlsConnectionFactory* factory,
+ X509* peer_cert,
+ std::unique_ptr<TlsConnection> connection) {
+ const IPEndpoint& endpoint = connection->remote_address();
+ auto it = FindPendingConnection(endpoint);
+ if (it == pending_connections_.end()) {
+ OSP_DLOG_ERROR << "TLS connection succeeded for unknown endpoint: "
+ << endpoint;
+ return;
+ }
+ DeviceMediaPolicy media_policy = it->media_policy;
+ CastSocket::Client* client = it->client;
+ pending_connections_.erase(it);
+
+ if (!peer_cert) {
+ client_->OnError(this, endpoint, Error::Code::kErrCertsMissing);
+ return;
+ }
+
+ auto socket = std::make_unique<CastSocket>(std::move(connection), this,
+ GetNextSocketId());
+ pending_auth_.emplace_back(new PendingAuth{endpoint, media_policy,
+ std::move(socket), client,
+ AuthContext::Create(), peer_cert});
+ PendingAuth& pending = *pending_auth_.back();
+
+ CastMessage auth_challenge = CreateAuthChallengeMessage(pending.auth_context);
+ Error error = pending.socket->SendMessage(auth_challenge);
+ if (!error.ok()) {
+ pending_auth_.pop_back();
+ client_->OnError(this, endpoint, error);
+ }
+}
+
+void SenderSocketFactory::OnConnectionFailed(TlsConnectionFactory* factory,
+ const IPEndpoint& remote_address) {
+ auto it = FindPendingConnection(remote_address);
+ if (it == pending_connections_.end()) {
+ OSP_DVLOG << "OnConnectionFailed reported for untracked address: "
+ << remote_address;
+ return;
+ }
+ pending_connections_.erase(it);
+ client_->OnError(this, remote_address, Error::Code::kConnectionFailed);
+}
+
+void SenderSocketFactory::OnError(TlsConnectionFactory* factory, Error error) {
+ std::vector<PendingConnection> connections;
+ pending_connections_.swap(connections);
+ for (const PendingConnection& pending : connections) {
+ client_->OnError(this, pending.endpoint, error);
+ }
+}
+
+std::vector<SenderSocketFactory::PendingConnection>::iterator
+SenderSocketFactory::FindPendingConnection(const IPEndpoint& endpoint) {
+ return std::find_if(pending_connections_.begin(), pending_connections_.end(),
+ [&endpoint](const PendingConnection& pending) {
+ return pending.endpoint == endpoint;
+ });
+}
+
+void SenderSocketFactory::OnError(CastSocket* socket, Error error) {
+ auto it = std::find_if(pending_auth_.begin(), pending_auth_.end(),
+ [id = socket->socket_id()](
+ const std::unique_ptr<PendingAuth>& pending_auth) {
+ return pending_auth->socket->socket_id() == id;
+ });
+ if (it == pending_auth_.end()) {
+ OSP_DLOG_ERROR << "Got error for unknown pending socket";
+ return;
+ }
+ IPEndpoint endpoint = (*it)->endpoint;
+ pending_auth_.erase(it);
+ client_->OnError(this, endpoint, error);
+}
+
+void SenderSocketFactory::OnMessage(CastSocket* socket, CastMessage message) {
+ auto it = std::find_if(pending_auth_.begin(), pending_auth_.end(),
+ [id = socket->socket_id()](
+ const std::unique_ptr<PendingAuth>& pending_auth) {
+ return pending_auth->socket->socket_id() == id;
+ });
+ if (it == pending_auth_.end()) {
+ OSP_DLOG_ERROR << "Got message for unknown pending socket";
+ return;
+ }
+
+ std::unique_ptr<PendingAuth> pending = std::move(*it);
+ pending_auth_.erase(it);
+ if (!IsAuthMessage(message)) {
+ client_->OnError(this, pending->endpoint,
+ Error::Code::kCastV2AuthenticationError);
+ return;
+ }
+
+ ErrorOr<CastDeviceCertPolicy> policy_or_error = AuthenticateChallengeReply(
+ message, (*it)->peer_cert, (*it)->auth_context);
+ if (policy_or_error.is_error()) {
+ client_->OnError(this, pending->endpoint, policy_or_error.error());
+ return;
+ }
+
+ if (policy_or_error.value() == CastDeviceCertPolicy::kAudioOnly &&
+ pending->media_policy != DeviceMediaPolicy::kAudioOnly) {
+ client_->OnError(this, pending->endpoint,
+ Error::Code::kCastV2ChannelPolicyMismatch);
+ return;
+ }
+
+ pending->socket->set_client(pending->client);
+ client_->OnConnected(this, pending->endpoint, std::move(pending->socket));
+}
+
+} // namespace channel
+} // namespace cast
diff --git a/cast/sender/channel/sender_socket_factory.h b/cast/sender/channel/sender_socket_factory.h
new file mode 100644
index 00000000..d5b6622d
--- /dev/null
+++ b/cast/sender/channel/sender_socket_factory.h
@@ -0,0 +1,104 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_SENDER_CHANNEL_SENDER_SOCKET_FACTORY_H_
+#define CAST_SENDER_CHANNEL_SENDER_SOCKET_FACTORY_H_
+
+#include <set>
+#include <utility>
+#include <vector>
+
+#include "cast/common/channel/cast_socket.h"
+#include "cast/sender/channel/cast_auth_util.h"
+#include "platform/api/logging.h"
+#include "platform/api/tls_connection_factory.h"
+#include "platform/base/ip_address.h"
+
+namespace cast {
+namespace channel {
+
+using openscreen::Error;
+using openscreen::IPEndpoint;
+using openscreen::IPEndpointComparator;
+using openscreen::platform::TlsConnection;
+using openscreen::platform::TlsConnectionFactory;
+
+class SenderSocketFactory final : public TlsConnectionFactory::Client,
+ public CastSocket::Client {
+ public:
+ class Client {
+ public:
+ virtual void OnConnected(SenderSocketFactory* factory,
+ const IPEndpoint& endpoint,
+ std::unique_ptr<CastSocket> socket) = 0;
+ virtual void OnError(SenderSocketFactory* factory,
+ const IPEndpoint& endpoint,
+ Error error) = 0;
+ };
+
+ enum class DeviceMediaPolicy {
+ kAudioOnly,
+ kIncludesVideo,
+ };
+
+ // |client| must outlive |this|.
+ explicit SenderSocketFactory(Client* client);
+ ~SenderSocketFactory();
+
+ void set_factory(TlsConnectionFactory* factory) {
+ OSP_DCHECK(factory);
+ factory_ = factory;
+ }
+
+ void Connect(const IPEndpoint& endpoint,
+ DeviceMediaPolicy media_policy,
+ CastSocket::Client* client);
+
+ // TlsConnectionFactory::Client overrides.
+ void OnAccepted(TlsConnectionFactory* factory,
+ X509* peer_cert,
+ std::unique_ptr<TlsConnection> connection) override;
+ void OnConnected(TlsConnectionFactory* factory,
+ X509* peer_cert,
+ std::unique_ptr<TlsConnection> connection) override;
+ void OnConnectionFailed(TlsConnectionFactory* factory,
+ const IPEndpoint& remote_address) override;
+ void OnError(TlsConnectionFactory* factory, Error error) override;
+
+ private:
+ struct PendingConnection {
+ IPEndpoint endpoint;
+ DeviceMediaPolicy media_policy;
+ CastSocket::Client* client;
+ };
+
+ struct PendingAuth {
+ IPEndpoint endpoint;
+ DeviceMediaPolicy media_policy;
+ std::unique_ptr<CastSocket> socket;
+ CastSocket::Client* client;
+ AuthContext auth_context;
+ X509* peer_cert;
+ };
+
+ friend bool operator<(const std::unique_ptr<PendingAuth>& a, uint32_t b);
+ friend bool operator<(uint32_t a, const std::unique_ptr<PendingAuth>& b);
+
+ std::vector<PendingConnection>::iterator FindPendingConnection(
+ const IPEndpoint& endpoint);
+
+ // CastSocket::Client overrides.
+ void OnError(CastSocket* socket, Error error) override;
+ void OnMessage(CastSocket* socket, CastMessage message) override;
+
+ Client* const client_;
+ TlsConnectionFactory* factory_ = nullptr;
+ std::vector<PendingConnection> pending_connections_;
+ std::vector<std::unique_ptr<PendingAuth>> pending_auth_;
+};
+
+} // namespace channel
+} // namespace cast
+
+#endif // CAST_SENDER_CHANNEL_SENDER_SOCKET_FACTORY_H_