From a04b13eaa77d8ca31c0f06e32231bef84f96e32d Mon Sep 17 00:00:00 2001 From: btolsch Date: Tue, 15 Oct 2019 12:54:10 -0700 Subject: Add cast sender socket factory This change adds a CastSocket factory for the sender-side which performs the sender auth challenge and verification before passing a CastSocket back to the caller. Bug: openscreen:59 Change-Id: Ibbbdb2b8881e385cc0a8defbe309c7f10a2af323 Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/1834457 Commit-Queue: Brandon Tolsch Reviewed-by: Ryan Keane --- cast/sender/channel/BUILD.gn | 4 + cast/sender/channel/message_util.cc | 34 ++++++ cast/sender/channel/message_util.h | 21 ++++ cast/sender/channel/sender_socket_factory.cc | 166 +++++++++++++++++++++++++++ cast/sender/channel/sender_socket_factory.h | 104 +++++++++++++++++ 5 files changed, 329 insertions(+) create mode 100644 cast/sender/channel/message_util.cc create mode 100644 cast/sender/channel/message_util.h create mode 100644 cast/sender/channel/sender_socket_factory.cc create mode 100644 cast/sender/channel/sender_socket_factory.h (limited to 'cast/sender') diff --git a/cast/sender/channel/BUILD.gn b/cast/sender/channel/BUILD.gn index 85f633b0..5fbed9ce 100644 --- a/cast/sender/channel/BUILD.gn +++ b/cast/sender/channel/BUILD.gn @@ -6,6 +6,10 @@ source_set("channel") { sources = [ "cast_auth_util.cc", "cast_auth_util.h", + "message_util.cc", + "message_util.h", + "sender_socket_factory.cc", + "sender_socket_factory.h", ] deps = [ diff --git a/cast/sender/channel/message_util.cc b/cast/sender/channel/message_util.cc new file mode 100644 index 00000000..ab3ed5d8 --- /dev/null +++ b/cast/sender/channel/message_util.cc @@ -0,0 +1,34 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/sender/channel/message_util.h" + +#include "cast/sender/channel/cast_auth_util.h" + +namespace cast { +namespace channel { + +CastMessage CreateAuthChallengeMessage(const AuthContext& auth_context) { + CastMessage message; + DeviceAuthMessage auth_message; + + AuthChallenge* challenge = auth_message.mutable_challenge(); + challenge->set_sender_nonce(auth_context.nonce()); + challenge->set_hash_algorithm(SHA256); + + std::string auth_message_string; + auth_message.SerializeToString(&auth_message_string); + + message.set_protocol_version(CastMessage::CASTV2_1_0); + message.set_source_id(kPlatformSenderId); + message.set_destination_id(kPlatformReceiverId); + message.set_namespace_(kAuthNamespace); + message.set_payload_type(CastMessage_PayloadType_BINARY); + message.set_payload_binary(auth_message_string); + + return message; +} + +} // namespace channel +} // namespace cast diff --git a/cast/sender/channel/message_util.h b/cast/sender/channel/message_util.h new file mode 100644 index 00000000..e2da0cd8 --- /dev/null +++ b/cast/sender/channel/message_util.h @@ -0,0 +1,21 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_SENDER_CHANNEL_MESSAGE_UTIL_H_ +#define CAST_SENDER_CHANNEL_MESSAGE_UTIL_H_ + +#include "cast/common/channel/message_util.h" +#include "cast/common/channel/proto/cast_channel.pb.h" + +namespace cast { +namespace channel { + +class AuthContext; + +CastMessage CreateAuthChallengeMessage(const AuthContext& auth_context); + +} // namespace channel +} // namespace cast + +#endif // CAST_SENDER_CHANNEL_MESSAGE_UTIL_H_ diff --git a/cast/sender/channel/sender_socket_factory.cc b/cast/sender/channel/sender_socket_factory.cc new file mode 100644 index 00000000..83e73370 --- /dev/null +++ b/cast/sender/channel/sender_socket_factory.cc @@ -0,0 +1,166 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/sender/channel/sender_socket_factory.h" + +#include "cast/common/channel/cast_socket.h" +#include "cast/sender/channel/message_util.h" + +namespace cast { +namespace channel { + +using openscreen::platform::TlsConnectOptions; + +bool operator<(const std::unique_ptr& a, + uint32_t b) { + return a && a->socket->socket_id() < b; +} + +bool operator<(uint32_t a, + const std::unique_ptr& b) { + return b && a < b->socket->socket_id(); +} + +SenderSocketFactory::SenderSocketFactory(Client* client) : client_(client) { + OSP_DCHECK(client); +} + +SenderSocketFactory::~SenderSocketFactory() = default; + +void SenderSocketFactory::Connect(const IPEndpoint& endpoint, + DeviceMediaPolicy media_policy, + CastSocket::Client* client) { + OSP_DCHECK(client); + auto it = FindPendingConnection(endpoint); + if (it == pending_connections_.end()) { + pending_connections_.emplace_back( + PendingConnection{endpoint, media_policy, client}); + factory_->Connect(endpoint, TlsConnectOptions{true}); + } +} + +void SenderSocketFactory::OnAccepted( + TlsConnectionFactory* factory, + X509* peer_cert, + std::unique_ptr connection) { + OSP_NOTREACHED() << "This factory is connect-only."; +} + +void SenderSocketFactory::OnConnected( + TlsConnectionFactory* factory, + X509* peer_cert, + std::unique_ptr connection) { + const IPEndpoint& endpoint = connection->remote_address(); + auto it = FindPendingConnection(endpoint); + if (it == pending_connections_.end()) { + OSP_DLOG_ERROR << "TLS connection succeeded for unknown endpoint: " + << endpoint; + return; + } + DeviceMediaPolicy media_policy = it->media_policy; + CastSocket::Client* client = it->client; + pending_connections_.erase(it); + + if (!peer_cert) { + client_->OnError(this, endpoint, Error::Code::kErrCertsMissing); + return; + } + + auto socket = std::make_unique(std::move(connection), this, + GetNextSocketId()); + pending_auth_.emplace_back(new PendingAuth{endpoint, media_policy, + std::move(socket), client, + AuthContext::Create(), peer_cert}); + PendingAuth& pending = *pending_auth_.back(); + + CastMessage auth_challenge = CreateAuthChallengeMessage(pending.auth_context); + Error error = pending.socket->SendMessage(auth_challenge); + if (!error.ok()) { + pending_auth_.pop_back(); + client_->OnError(this, endpoint, error); + } +} + +void SenderSocketFactory::OnConnectionFailed(TlsConnectionFactory* factory, + const IPEndpoint& remote_address) { + auto it = FindPendingConnection(remote_address); + if (it == pending_connections_.end()) { + OSP_DVLOG << "OnConnectionFailed reported for untracked address: " + << remote_address; + return; + } + pending_connections_.erase(it); + client_->OnError(this, remote_address, Error::Code::kConnectionFailed); +} + +void SenderSocketFactory::OnError(TlsConnectionFactory* factory, Error error) { + std::vector connections; + pending_connections_.swap(connections); + for (const PendingConnection& pending : connections) { + client_->OnError(this, pending.endpoint, error); + } +} + +std::vector::iterator +SenderSocketFactory::FindPendingConnection(const IPEndpoint& endpoint) { + return std::find_if(pending_connections_.begin(), pending_connections_.end(), + [&endpoint](const PendingConnection& pending) { + return pending.endpoint == endpoint; + }); +} + +void SenderSocketFactory::OnError(CastSocket* socket, Error error) { + auto it = std::find_if(pending_auth_.begin(), pending_auth_.end(), + [id = socket->socket_id()]( + const std::unique_ptr& pending_auth) { + return pending_auth->socket->socket_id() == id; + }); + if (it == pending_auth_.end()) { + OSP_DLOG_ERROR << "Got error for unknown pending socket"; + return; + } + IPEndpoint endpoint = (*it)->endpoint; + pending_auth_.erase(it); + client_->OnError(this, endpoint, error); +} + +void SenderSocketFactory::OnMessage(CastSocket* socket, CastMessage message) { + auto it = std::find_if(pending_auth_.begin(), pending_auth_.end(), + [id = socket->socket_id()]( + const std::unique_ptr& pending_auth) { + return pending_auth->socket->socket_id() == id; + }); + if (it == pending_auth_.end()) { + OSP_DLOG_ERROR << "Got message for unknown pending socket"; + return; + } + + std::unique_ptr pending = std::move(*it); + pending_auth_.erase(it); + if (!IsAuthMessage(message)) { + client_->OnError(this, pending->endpoint, + Error::Code::kCastV2AuthenticationError); + return; + } + + ErrorOr policy_or_error = AuthenticateChallengeReply( + message, (*it)->peer_cert, (*it)->auth_context); + if (policy_or_error.is_error()) { + client_->OnError(this, pending->endpoint, policy_or_error.error()); + return; + } + + if (policy_or_error.value() == CastDeviceCertPolicy::kAudioOnly && + pending->media_policy != DeviceMediaPolicy::kAudioOnly) { + client_->OnError(this, pending->endpoint, + Error::Code::kCastV2ChannelPolicyMismatch); + return; + } + + pending->socket->set_client(pending->client); + client_->OnConnected(this, pending->endpoint, std::move(pending->socket)); +} + +} // namespace channel +} // namespace cast diff --git a/cast/sender/channel/sender_socket_factory.h b/cast/sender/channel/sender_socket_factory.h new file mode 100644 index 00000000..d5b6622d --- /dev/null +++ b/cast/sender/channel/sender_socket_factory.h @@ -0,0 +1,104 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_SENDER_CHANNEL_SENDER_SOCKET_FACTORY_H_ +#define CAST_SENDER_CHANNEL_SENDER_SOCKET_FACTORY_H_ + +#include +#include +#include + +#include "cast/common/channel/cast_socket.h" +#include "cast/sender/channel/cast_auth_util.h" +#include "platform/api/logging.h" +#include "platform/api/tls_connection_factory.h" +#include "platform/base/ip_address.h" + +namespace cast { +namespace channel { + +using openscreen::Error; +using openscreen::IPEndpoint; +using openscreen::IPEndpointComparator; +using openscreen::platform::TlsConnection; +using openscreen::platform::TlsConnectionFactory; + +class SenderSocketFactory final : public TlsConnectionFactory::Client, + public CastSocket::Client { + public: + class Client { + public: + virtual void OnConnected(SenderSocketFactory* factory, + const IPEndpoint& endpoint, + std::unique_ptr socket) = 0; + virtual void OnError(SenderSocketFactory* factory, + const IPEndpoint& endpoint, + Error error) = 0; + }; + + enum class DeviceMediaPolicy { + kAudioOnly, + kIncludesVideo, + }; + + // |client| must outlive |this|. + explicit SenderSocketFactory(Client* client); + ~SenderSocketFactory(); + + void set_factory(TlsConnectionFactory* factory) { + OSP_DCHECK(factory); + factory_ = factory; + } + + void Connect(const IPEndpoint& endpoint, + DeviceMediaPolicy media_policy, + CastSocket::Client* client); + + // TlsConnectionFactory::Client overrides. + void OnAccepted(TlsConnectionFactory* factory, + X509* peer_cert, + std::unique_ptr connection) override; + void OnConnected(TlsConnectionFactory* factory, + X509* peer_cert, + std::unique_ptr connection) override; + void OnConnectionFailed(TlsConnectionFactory* factory, + const IPEndpoint& remote_address) override; + void OnError(TlsConnectionFactory* factory, Error error) override; + + private: + struct PendingConnection { + IPEndpoint endpoint; + DeviceMediaPolicy media_policy; + CastSocket::Client* client; + }; + + struct PendingAuth { + IPEndpoint endpoint; + DeviceMediaPolicy media_policy; + std::unique_ptr socket; + CastSocket::Client* client; + AuthContext auth_context; + X509* peer_cert; + }; + + friend bool operator<(const std::unique_ptr& a, uint32_t b); + friend bool operator<(uint32_t a, const std::unique_ptr& b); + + std::vector::iterator FindPendingConnection( + const IPEndpoint& endpoint); + + // CastSocket::Client overrides. + void OnError(CastSocket* socket, Error error) override; + void OnMessage(CastSocket* socket, CastMessage message) override; + + Client* const client_; + TlsConnectionFactory* factory_ = nullptr; + std::vector pending_connections_; + std::vector> pending_auth_; +}; + +} // namespace channel +} // namespace cast + +#endif // CAST_SENDER_CHANNEL_SENDER_SOCKET_FACTORY_H_ -- cgit v1.2.3