// Copyright 2019 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "cast/sender/public/sender_socket_factory.h" #include "cast/common/channel/proto/cast_channel.pb.h" #include "cast/sender/channel/cast_auth_util.h" #include "cast/sender/channel/message_util.h" #include "platform/base/tls_connect_options.h" #include "util/crypto/certificate_utils.h" #include "util/osp_logging.h" using ::cast::channel::CastMessage; namespace openscreen { namespace cast { SenderSocketFactory::Client::~Client() = default; bool operator<(const std::unique_ptr& a, int b) { return a && a->socket->socket_id() < b; } bool operator<(int a, const std::unique_ptr& b) { return b && a < b->socket->socket_id(); } SenderSocketFactory::SenderSocketFactory(Client* client, TaskRunner* task_runner) : client_(client), task_runner_(task_runner) { OSP_DCHECK(client); OSP_DCHECK(task_runner); } SenderSocketFactory::~SenderSocketFactory() { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); } void SenderSocketFactory::set_factory(TlsConnectionFactory* factory) { OSP_DCHECK(factory); factory_ = factory; } void SenderSocketFactory::Connect(const IPEndpoint& endpoint, DeviceMediaPolicy media_policy, CastSocket::Client* client) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); OSP_DCHECK(client); auto it = FindPendingConnection(endpoint); if (it == pending_connections_.end()) { pending_connections_.emplace_back( PendingConnection{endpoint, media_policy, client}); factory_->Connect(endpoint, TlsConnectOptions{true}); } } void SenderSocketFactory::OnAccepted( TlsConnectionFactory* factory, std::vector der_x509_peer_cert, std::unique_ptr connection) { OSP_NOTREACHED(); OSP_LOG_FATAL << "This factory is connect-only"; } void SenderSocketFactory::OnConnected( TlsConnectionFactory* factory, std::vector der_x509_peer_cert, std::unique_ptr connection) { const IPEndpoint& endpoint = connection->GetRemoteEndpoint(); auto it = FindPendingConnection(endpoint); if (it == pending_connections_.end()) { OSP_DLOG_ERROR << "TLS connection succeeded for unknown endpoint: " << endpoint; return; } DeviceMediaPolicy media_policy = it->media_policy; CastSocket::Client* client = it->client; pending_connections_.erase(it); ErrorOr> peer_cert = ImportCertificate(der_x509_peer_cert.data(), der_x509_peer_cert.size()); if (!peer_cert) { client_->OnError(this, endpoint, peer_cert.error()); return; } auto socket = MakeSerialDelete(task_runner_, std::move(connection), this); pending_auth_.emplace_back( new PendingAuth{endpoint, media_policy, std::move(socket), client, std::make_unique(AuthContext::Create()), std::move(peer_cert.value())}); PendingAuth& pending = *pending_auth_.back(); CastMessage auth_challenge = CreateAuthChallengeMessage(*pending.auth_context); Error error = pending.socket->Send(auth_challenge); if (!error.ok()) { pending_auth_.pop_back(); client_->OnError(this, endpoint, error); } } void SenderSocketFactory::OnConnectionFailed(TlsConnectionFactory* factory, const IPEndpoint& remote_address) { auto it = FindPendingConnection(remote_address); if (it == pending_connections_.end()) { return; } pending_connections_.erase(it); client_->OnError(this, remote_address, Error::Code::kConnectionFailed); } void SenderSocketFactory::OnError(TlsConnectionFactory* factory, Error error) { std::vector connections; pending_connections_.swap(connections); for (const PendingConnection& pending : connections) { client_->OnError(this, pending.endpoint, error); } } std::vector::iterator SenderSocketFactory::FindPendingConnection(const IPEndpoint& endpoint) { return std::find_if(pending_connections_.begin(), pending_connections_.end(), [&endpoint](const PendingConnection& pending) { return pending.endpoint == endpoint; }); } void SenderSocketFactory::OnError(CastSocket* socket, Error error) { auto it = std::find_if(pending_auth_.begin(), pending_auth_.end(), [id = socket->socket_id()]( const std::unique_ptr& pending_auth) { return pending_auth->socket->socket_id() == id; }); if (it == pending_auth_.end()) { OSP_DLOG_ERROR << "Got error for unknown pending socket"; return; } IPEndpoint endpoint = (*it)->endpoint; pending_auth_.erase(it); client_->OnError(this, endpoint, error); } void SenderSocketFactory::OnMessage(CastSocket* socket, CastMessage message) { auto it = std::find_if(pending_auth_.begin(), pending_auth_.end(), [id = socket->socket_id()]( const std::unique_ptr& pending_auth) { return pending_auth->socket->socket_id() == id; }); if (it == pending_auth_.end()) { OSP_DLOG_ERROR << "Got message for unknown pending socket"; return; } std::unique_ptr pending = std::move(*it); pending_auth_.erase(it); if (!IsAuthMessage(message)) { client_->OnError(this, pending->endpoint, Error::Code::kCastV2AuthenticationError); return; } ErrorOr policy_or_error = AuthenticateChallengeReply( message, pending->peer_cert.get(), *pending->auth_context); if (policy_or_error.is_error()) { OSP_DLOG_WARN << "Authentication failed for " << pending->endpoint << " with error: " << policy_or_error.error(); client_->OnError(this, pending->endpoint, policy_or_error.error()); return; } if (policy_or_error.value() == CastDeviceCertPolicy::kAudioOnly && pending->media_policy == DeviceMediaPolicy::kIncludesVideo) { client_->OnError(this, pending->endpoint, Error::Code::kCastV2ChannelPolicyMismatch); return; } pending->socket->set_audio_only(policy_or_error.value() == CastDeviceCertPolicy::kAudioOnly); pending->socket->SetClient(pending->client); client_->OnConnected(this, pending->endpoint, std::unique_ptr(pending->socket.release())); } } // namespace cast } // namespace openscreen