aboutsummaryrefslogtreecommitdiff
path: root/cast/sender
diff options
context:
space:
mode:
authorbtolsch <btolsch@chromium.org>2019-10-15 12:54:10 -0700
committerCommit Bot <commit-bot@chromium.org>2019-10-15 20:03:29 +0000
commita04b13eaa77d8ca31c0f06e32231bef84f96e32d (patch)
tree5158547e199df455b63cbe7f93c6c99748054181 /cast/sender
parentc56f0388085f0e0d4df657498381f35317fbf041 (diff)
downloadopenscreen-a04b13eaa77d8ca31c0f06e32231bef84f96e32d.tar.gz
Add cast sender socket factory
This change adds a CastSocket factory for the sender-side which performs the sender auth challenge and verification before passing a CastSocket back to the caller. Bug: openscreen:59 Change-Id: Ibbbdb2b8881e385cc0a8defbe309c7f10a2af323 Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/1834457 Commit-Queue: Brandon Tolsch <btolsch@chromium.org> Reviewed-by: Ryan Keane <rwkeane@google.com>
Diffstat (limited to 'cast/sender')
-rw-r--r--cast/sender/channel/BUILD.gn4
-rw-r--r--cast/sender/channel/message_util.cc34
-rw-r--r--cast/sender/channel/message_util.h21
-rw-r--r--cast/sender/channel/sender_socket_factory.cc166
-rw-r--r--cast/sender/channel/sender_socket_factory.h104
5 files changed, 329 insertions, 0 deletions
diff --git a/cast/sender/channel/BUILD.gn b/cast/sender/channel/BUILD.gn
index 85f633b0..5fbed9ce 100644
--- a/cast/sender/channel/BUILD.gn
+++ b/cast/sender/channel/BUILD.gn
@@ -6,6 +6,10 @@ source_set("channel") {
sources = [
"cast_auth_util.cc",
"cast_auth_util.h",
+ "message_util.cc",
+ "message_util.h",
+ "sender_socket_factory.cc",
+ "sender_socket_factory.h",
]
deps = [
diff --git a/cast/sender/channel/message_util.cc b/cast/sender/channel/message_util.cc
new file mode 100644
index 00000000..ab3ed5d8
--- /dev/null
+++ b/cast/sender/channel/message_util.cc
@@ -0,0 +1,34 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/sender/channel/message_util.h"
+
+#include "cast/sender/channel/cast_auth_util.h"
+
+namespace cast {
+namespace channel {
+
+CastMessage CreateAuthChallengeMessage(const AuthContext& auth_context) {
+ CastMessage message;
+ DeviceAuthMessage auth_message;
+
+ AuthChallenge* challenge = auth_message.mutable_challenge();
+ challenge->set_sender_nonce(auth_context.nonce());
+ challenge->set_hash_algorithm(SHA256);
+
+ std::string auth_message_string;
+ auth_message.SerializeToString(&auth_message_string);
+
+ message.set_protocol_version(CastMessage::CASTV2_1_0);
+ message.set_source_id(kPlatformSenderId);
+ message.set_destination_id(kPlatformReceiverId);
+ message.set_namespace_(kAuthNamespace);
+ message.set_payload_type(CastMessage_PayloadType_BINARY);
+ message.set_payload_binary(auth_message_string);
+
+ return message;
+}
+
+} // namespace channel
+} // namespace cast
diff --git a/cast/sender/channel/message_util.h b/cast/sender/channel/message_util.h
new file mode 100644
index 00000000..e2da0cd8
--- /dev/null
+++ b/cast/sender/channel/message_util.h
@@ -0,0 +1,21 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_SENDER_CHANNEL_MESSAGE_UTIL_H_
+#define CAST_SENDER_CHANNEL_MESSAGE_UTIL_H_
+
+#include "cast/common/channel/message_util.h"
+#include "cast/common/channel/proto/cast_channel.pb.h"
+
+namespace cast {
+namespace channel {
+
+class AuthContext;
+
+CastMessage CreateAuthChallengeMessage(const AuthContext& auth_context);
+
+} // namespace channel
+} // namespace cast
+
+#endif // CAST_SENDER_CHANNEL_MESSAGE_UTIL_H_
diff --git a/cast/sender/channel/sender_socket_factory.cc b/cast/sender/channel/sender_socket_factory.cc
new file mode 100644
index 00000000..83e73370
--- /dev/null
+++ b/cast/sender/channel/sender_socket_factory.cc
@@ -0,0 +1,166 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/sender/channel/sender_socket_factory.h"
+
+#include "cast/common/channel/cast_socket.h"
+#include "cast/sender/channel/message_util.h"
+
+namespace cast {
+namespace channel {
+
+using openscreen::platform::TlsConnectOptions;
+
+bool operator<(const std::unique_ptr<SenderSocketFactory::PendingAuth>& a,
+ uint32_t b) {
+ return a && a->socket->socket_id() < b;
+}
+
+bool operator<(uint32_t a,
+ const std::unique_ptr<SenderSocketFactory::PendingAuth>& b) {
+ return b && a < b->socket->socket_id();
+}
+
+SenderSocketFactory::SenderSocketFactory(Client* client) : client_(client) {
+ OSP_DCHECK(client);
+}
+
+SenderSocketFactory::~SenderSocketFactory() = default;
+
+void SenderSocketFactory::Connect(const IPEndpoint& endpoint,
+ DeviceMediaPolicy media_policy,
+ CastSocket::Client* client) {
+ OSP_DCHECK(client);
+ auto it = FindPendingConnection(endpoint);
+ if (it == pending_connections_.end()) {
+ pending_connections_.emplace_back(
+ PendingConnection{endpoint, media_policy, client});
+ factory_->Connect(endpoint, TlsConnectOptions{true});
+ }
+}
+
+void SenderSocketFactory::OnAccepted(
+ TlsConnectionFactory* factory,
+ X509* peer_cert,
+ std::unique_ptr<TlsConnection> connection) {
+ OSP_NOTREACHED() << "This factory is connect-only.";
+}
+
+void SenderSocketFactory::OnConnected(
+ TlsConnectionFactory* factory,
+ X509* peer_cert,
+ std::unique_ptr<TlsConnection> connection) {
+ const IPEndpoint& endpoint = connection->remote_address();
+ auto it = FindPendingConnection(endpoint);
+ if (it == pending_connections_.end()) {
+ OSP_DLOG_ERROR << "TLS connection succeeded for unknown endpoint: "
+ << endpoint;
+ return;
+ }
+ DeviceMediaPolicy media_policy = it->media_policy;
+ CastSocket::Client* client = it->client;
+ pending_connections_.erase(it);
+
+ if (!peer_cert) {
+ client_->OnError(this, endpoint, Error::Code::kErrCertsMissing);
+ return;
+ }
+
+ auto socket = std::make_unique<CastSocket>(std::move(connection), this,
+ GetNextSocketId());
+ pending_auth_.emplace_back(new PendingAuth{endpoint, media_policy,
+ std::move(socket), client,
+ AuthContext::Create(), peer_cert});
+ PendingAuth& pending = *pending_auth_.back();
+
+ CastMessage auth_challenge = CreateAuthChallengeMessage(pending.auth_context);
+ Error error = pending.socket->SendMessage(auth_challenge);
+ if (!error.ok()) {
+ pending_auth_.pop_back();
+ client_->OnError(this, endpoint, error);
+ }
+}
+
+void SenderSocketFactory::OnConnectionFailed(TlsConnectionFactory* factory,
+ const IPEndpoint& remote_address) {
+ auto it = FindPendingConnection(remote_address);
+ if (it == pending_connections_.end()) {
+ OSP_DVLOG << "OnConnectionFailed reported for untracked address: "
+ << remote_address;
+ return;
+ }
+ pending_connections_.erase(it);
+ client_->OnError(this, remote_address, Error::Code::kConnectionFailed);
+}
+
+void SenderSocketFactory::OnError(TlsConnectionFactory* factory, Error error) {
+ std::vector<PendingConnection> connections;
+ pending_connections_.swap(connections);
+ for (const PendingConnection& pending : connections) {
+ client_->OnError(this, pending.endpoint, error);
+ }
+}
+
+std::vector<SenderSocketFactory::PendingConnection>::iterator
+SenderSocketFactory::FindPendingConnection(const IPEndpoint& endpoint) {
+ return std::find_if(pending_connections_.begin(), pending_connections_.end(),
+ [&endpoint](const PendingConnection& pending) {
+ return pending.endpoint == endpoint;
+ });
+}
+
+void SenderSocketFactory::OnError(CastSocket* socket, Error error) {
+ auto it = std::find_if(pending_auth_.begin(), pending_auth_.end(),
+ [id = socket->socket_id()](
+ const std::unique_ptr<PendingAuth>& pending_auth) {
+ return pending_auth->socket->socket_id() == id;
+ });
+ if (it == pending_auth_.end()) {
+ OSP_DLOG_ERROR << "Got error for unknown pending socket";
+ return;
+ }
+ IPEndpoint endpoint = (*it)->endpoint;
+ pending_auth_.erase(it);
+ client_->OnError(this, endpoint, error);
+}
+
+void SenderSocketFactory::OnMessage(CastSocket* socket, CastMessage message) {
+ auto it = std::find_if(pending_auth_.begin(), pending_auth_.end(),
+ [id = socket->socket_id()](
+ const std::unique_ptr<PendingAuth>& pending_auth) {
+ return pending_auth->socket->socket_id() == id;
+ });
+ if (it == pending_auth_.end()) {
+ OSP_DLOG_ERROR << "Got message for unknown pending socket";
+ return;
+ }
+
+ std::unique_ptr<PendingAuth> pending = std::move(*it);
+ pending_auth_.erase(it);
+ if (!IsAuthMessage(message)) {
+ client_->OnError(this, pending->endpoint,
+ Error::Code::kCastV2AuthenticationError);
+ return;
+ }
+
+ ErrorOr<CastDeviceCertPolicy> policy_or_error = AuthenticateChallengeReply(
+ message, (*it)->peer_cert, (*it)->auth_context);
+ if (policy_or_error.is_error()) {
+ client_->OnError(this, pending->endpoint, policy_or_error.error());
+ return;
+ }
+
+ if (policy_or_error.value() == CastDeviceCertPolicy::kAudioOnly &&
+ pending->media_policy != DeviceMediaPolicy::kAudioOnly) {
+ client_->OnError(this, pending->endpoint,
+ Error::Code::kCastV2ChannelPolicyMismatch);
+ return;
+ }
+
+ pending->socket->set_client(pending->client);
+ client_->OnConnected(this, pending->endpoint, std::move(pending->socket));
+}
+
+} // namespace channel
+} // namespace cast
diff --git a/cast/sender/channel/sender_socket_factory.h b/cast/sender/channel/sender_socket_factory.h
new file mode 100644
index 00000000..d5b6622d
--- /dev/null
+++ b/cast/sender/channel/sender_socket_factory.h
@@ -0,0 +1,104 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_SENDER_CHANNEL_SENDER_SOCKET_FACTORY_H_
+#define CAST_SENDER_CHANNEL_SENDER_SOCKET_FACTORY_H_
+
+#include <set>
+#include <utility>
+#include <vector>
+
+#include "cast/common/channel/cast_socket.h"
+#include "cast/sender/channel/cast_auth_util.h"
+#include "platform/api/logging.h"
+#include "platform/api/tls_connection_factory.h"
+#include "platform/base/ip_address.h"
+
+namespace cast {
+namespace channel {
+
+using openscreen::Error;
+using openscreen::IPEndpoint;
+using openscreen::IPEndpointComparator;
+using openscreen::platform::TlsConnection;
+using openscreen::platform::TlsConnectionFactory;
+
+class SenderSocketFactory final : public TlsConnectionFactory::Client,
+ public CastSocket::Client {
+ public:
+ class Client {
+ public:
+ virtual void OnConnected(SenderSocketFactory* factory,
+ const IPEndpoint& endpoint,
+ std::unique_ptr<CastSocket> socket) = 0;
+ virtual void OnError(SenderSocketFactory* factory,
+ const IPEndpoint& endpoint,
+ Error error) = 0;
+ };
+
+ enum class DeviceMediaPolicy {
+ kAudioOnly,
+ kIncludesVideo,
+ };
+
+ // |client| must outlive |this|.
+ explicit SenderSocketFactory(Client* client);
+ ~SenderSocketFactory();
+
+ void set_factory(TlsConnectionFactory* factory) {
+ OSP_DCHECK(factory);
+ factory_ = factory;
+ }
+
+ void Connect(const IPEndpoint& endpoint,
+ DeviceMediaPolicy media_policy,
+ CastSocket::Client* client);
+
+ // TlsConnectionFactory::Client overrides.
+ void OnAccepted(TlsConnectionFactory* factory,
+ X509* peer_cert,
+ std::unique_ptr<TlsConnection> connection) override;
+ void OnConnected(TlsConnectionFactory* factory,
+ X509* peer_cert,
+ std::unique_ptr<TlsConnection> connection) override;
+ void OnConnectionFailed(TlsConnectionFactory* factory,
+ const IPEndpoint& remote_address) override;
+ void OnError(TlsConnectionFactory* factory, Error error) override;
+
+ private:
+ struct PendingConnection {
+ IPEndpoint endpoint;
+ DeviceMediaPolicy media_policy;
+ CastSocket::Client* client;
+ };
+
+ struct PendingAuth {
+ IPEndpoint endpoint;
+ DeviceMediaPolicy media_policy;
+ std::unique_ptr<CastSocket> socket;
+ CastSocket::Client* client;
+ AuthContext auth_context;
+ X509* peer_cert;
+ };
+
+ friend bool operator<(const std::unique_ptr<PendingAuth>& a, uint32_t b);
+ friend bool operator<(uint32_t a, const std::unique_ptr<PendingAuth>& b);
+
+ std::vector<PendingConnection>::iterator FindPendingConnection(
+ const IPEndpoint& endpoint);
+
+ // CastSocket::Client overrides.
+ void OnError(CastSocket* socket, Error error) override;
+ void OnMessage(CastSocket* socket, CastMessage message) override;
+
+ Client* const client_;
+ TlsConnectionFactory* factory_ = nullptr;
+ std::vector<PendingConnection> pending_connections_;
+ std::vector<std::unique_ptr<PendingAuth>> pending_auth_;
+};
+
+} // namespace channel
+} // namespace cast
+
+#endif // CAST_SENDER_CHANNEL_SENDER_SOCKET_FACTORY_H_