aboutsummaryrefslogtreecommitdiff
path: root/cast/sender/channel/sender_socket_factory.cc
diff options
context:
space:
mode:
Diffstat (limited to 'cast/sender/channel/sender_socket_factory.cc')
-rw-r--r--cast/sender/channel/sender_socket_factory.cc166
1 files changed, 166 insertions, 0 deletions
diff --git a/cast/sender/channel/sender_socket_factory.cc b/cast/sender/channel/sender_socket_factory.cc
new file mode 100644
index 00000000..83e73370
--- /dev/null
+++ b/cast/sender/channel/sender_socket_factory.cc
@@ -0,0 +1,166 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/sender/channel/sender_socket_factory.h"
+
+#include "cast/common/channel/cast_socket.h"
+#include "cast/sender/channel/message_util.h"
+
+namespace cast {
+namespace channel {
+
+using openscreen::platform::TlsConnectOptions;
+
+bool operator<(const std::unique_ptr<SenderSocketFactory::PendingAuth>& a,
+ uint32_t b) {
+ return a && a->socket->socket_id() < b;
+}
+
+bool operator<(uint32_t a,
+ const std::unique_ptr<SenderSocketFactory::PendingAuth>& b) {
+ return b && a < b->socket->socket_id();
+}
+
+SenderSocketFactory::SenderSocketFactory(Client* client) : client_(client) {
+ OSP_DCHECK(client);
+}
+
+SenderSocketFactory::~SenderSocketFactory() = default;
+
+void SenderSocketFactory::Connect(const IPEndpoint& endpoint,
+ DeviceMediaPolicy media_policy,
+ CastSocket::Client* client) {
+ OSP_DCHECK(client);
+ auto it = FindPendingConnection(endpoint);
+ if (it == pending_connections_.end()) {
+ pending_connections_.emplace_back(
+ PendingConnection{endpoint, media_policy, client});
+ factory_->Connect(endpoint, TlsConnectOptions{true});
+ }
+}
+
+void SenderSocketFactory::OnAccepted(
+ TlsConnectionFactory* factory,
+ X509* peer_cert,
+ std::unique_ptr<TlsConnection> connection) {
+ OSP_NOTREACHED() << "This factory is connect-only.";
+}
+
+void SenderSocketFactory::OnConnected(
+ TlsConnectionFactory* factory,
+ X509* peer_cert,
+ std::unique_ptr<TlsConnection> connection) {
+ const IPEndpoint& endpoint = connection->remote_address();
+ auto it = FindPendingConnection(endpoint);
+ if (it == pending_connections_.end()) {
+ OSP_DLOG_ERROR << "TLS connection succeeded for unknown endpoint: "
+ << endpoint;
+ return;
+ }
+ DeviceMediaPolicy media_policy = it->media_policy;
+ CastSocket::Client* client = it->client;
+ pending_connections_.erase(it);
+
+ if (!peer_cert) {
+ client_->OnError(this, endpoint, Error::Code::kErrCertsMissing);
+ return;
+ }
+
+ auto socket = std::make_unique<CastSocket>(std::move(connection), this,
+ GetNextSocketId());
+ pending_auth_.emplace_back(new PendingAuth{endpoint, media_policy,
+ std::move(socket), client,
+ AuthContext::Create(), peer_cert});
+ PendingAuth& pending = *pending_auth_.back();
+
+ CastMessage auth_challenge = CreateAuthChallengeMessage(pending.auth_context);
+ Error error = pending.socket->SendMessage(auth_challenge);
+ if (!error.ok()) {
+ pending_auth_.pop_back();
+ client_->OnError(this, endpoint, error);
+ }
+}
+
+void SenderSocketFactory::OnConnectionFailed(TlsConnectionFactory* factory,
+ const IPEndpoint& remote_address) {
+ auto it = FindPendingConnection(remote_address);
+ if (it == pending_connections_.end()) {
+ OSP_DVLOG << "OnConnectionFailed reported for untracked address: "
+ << remote_address;
+ return;
+ }
+ pending_connections_.erase(it);
+ client_->OnError(this, remote_address, Error::Code::kConnectionFailed);
+}
+
+void SenderSocketFactory::OnError(TlsConnectionFactory* factory, Error error) {
+ std::vector<PendingConnection> connections;
+ pending_connections_.swap(connections);
+ for (const PendingConnection& pending : connections) {
+ client_->OnError(this, pending.endpoint, error);
+ }
+}
+
+std::vector<SenderSocketFactory::PendingConnection>::iterator
+SenderSocketFactory::FindPendingConnection(const IPEndpoint& endpoint) {
+ return std::find_if(pending_connections_.begin(), pending_connections_.end(),
+ [&endpoint](const PendingConnection& pending) {
+ return pending.endpoint == endpoint;
+ });
+}
+
+void SenderSocketFactory::OnError(CastSocket* socket, Error error) {
+ auto it = std::find_if(pending_auth_.begin(), pending_auth_.end(),
+ [id = socket->socket_id()](
+ const std::unique_ptr<PendingAuth>& pending_auth) {
+ return pending_auth->socket->socket_id() == id;
+ });
+ if (it == pending_auth_.end()) {
+ OSP_DLOG_ERROR << "Got error for unknown pending socket";
+ return;
+ }
+ IPEndpoint endpoint = (*it)->endpoint;
+ pending_auth_.erase(it);
+ client_->OnError(this, endpoint, error);
+}
+
+void SenderSocketFactory::OnMessage(CastSocket* socket, CastMessage message) {
+ auto it = std::find_if(pending_auth_.begin(), pending_auth_.end(),
+ [id = socket->socket_id()](
+ const std::unique_ptr<PendingAuth>& pending_auth) {
+ return pending_auth->socket->socket_id() == id;
+ });
+ if (it == pending_auth_.end()) {
+ OSP_DLOG_ERROR << "Got message for unknown pending socket";
+ return;
+ }
+
+ std::unique_ptr<PendingAuth> pending = std::move(*it);
+ pending_auth_.erase(it);
+ if (!IsAuthMessage(message)) {
+ client_->OnError(this, pending->endpoint,
+ Error::Code::kCastV2AuthenticationError);
+ return;
+ }
+
+ ErrorOr<CastDeviceCertPolicy> policy_or_error = AuthenticateChallengeReply(
+ message, (*it)->peer_cert, (*it)->auth_context);
+ if (policy_or_error.is_error()) {
+ client_->OnError(this, pending->endpoint, policy_or_error.error());
+ return;
+ }
+
+ if (policy_or_error.value() == CastDeviceCertPolicy::kAudioOnly &&
+ pending->media_policy != DeviceMediaPolicy::kAudioOnly) {
+ client_->OnError(this, pending->endpoint,
+ Error::Code::kCastV2ChannelPolicyMismatch);
+ return;
+ }
+
+ pending->socket->set_client(pending->client);
+ client_->OnConnected(this, pending->endpoint, std::move(pending->socket));
+}
+
+} // namespace channel
+} // namespace cast