diff options
Diffstat (limited to 'cast')
-rw-r--r-- | cast/common/certificate/cast_cert_validator_internal.cc | 16 | ||||
-rw-r--r-- | cast/common/discovery/e2e_test/tests.cc | 12 | ||||
-rw-r--r-- | cast/common/public/cast_socket.h | 5 | ||||
-rw-r--r-- | cast/standalone_receiver/BUILD.gn | 27 | ||||
-rw-r--r-- | cast/standalone_receiver/cast_agent.cc | 140 | ||||
-rw-r--r-- | cast/standalone_receiver/cast_agent.h | 44 | ||||
-rw-r--r-- | cast/standalone_receiver/cast_agent_integration_tests.cc | 142 | ||||
-rw-r--r-- | cast/standalone_receiver/cast_socket_message_port.cc | 26 | ||||
-rw-r--r-- | cast/standalone_receiver/cast_socket_message_port.h | 15 | ||||
-rw-r--r-- | cast/standalone_receiver/main.cc | 15 | ||||
-rw-r--r-- | cast/standalone_receiver/static_credentials.cc | 138 | ||||
-rw-r--r-- | cast/standalone_receiver/static_credentials.h | 60 | ||||
-rw-r--r-- | cast/streaming/receiver_session.h | 2 | ||||
-rw-r--r-- | cast/test/cast_socket_e2e_test.cc | 10 |
14 files changed, 529 insertions, 123 deletions
diff --git a/cast/common/certificate/cast_cert_validator_internal.cc b/cast/common/certificate/cast_cert_validator_internal.cc index 63457298..8284802e 100644 --- a/cast/common/certificate/cast_cert_validator_internal.cc +++ b/cast/common/certificate/cast_cert_validator_internal.cc @@ -381,11 +381,15 @@ Error FindCertificatePath(const std::vector<std::string>& der_certs, result_path->intermediate_certs; target_cert.reset(ParseX509Der(der_certs[0])); if (!target_cert) { + OSP_DVLOG << "FindCertificatePath: Invalid target certificate"; return Error::Code::kErrCertsParse; } for (size_t i = 1; i < der_certs.size(); ++i) { intermediate_certs.emplace_back(ParseX509Der(der_certs[i])); if (!intermediate_certs.back()) { + OSP_DVLOG + << "FindCertificatePath: Failed to parse intermediate certificate " + << i << " of " << der_certs.size(); return Error::Code::kErrCertsParse; } } @@ -393,10 +397,12 @@ Error FindCertificatePath(const std::vector<std::string>& der_certs, // Basic checks on the target certificate. Error::Code error = VerifyCertTime(target_cert.get(), time); if (error != Error::Code::kNone) { + OSP_DVLOG << "FindCertificatePath: Failed to verify certificate time"; return error; } bssl::UniquePtr<EVP_PKEY> public_key{X509_get_pubkey(target_cert.get())}; if (!VerifyPublicKeyLength(public_key.get())) { + OSP_DVLOG << "FindCertificatePath: Failed with invalid public key length"; return Error::Code::kErrCertsVerifyGeneric; } if (X509_ALGOR_cmp(target_cert.get()->sig_alg, @@ -405,11 +411,13 @@ Error FindCertificatePath(const std::vector<std::string>& der_certs, } bssl::UniquePtr<ASN1_BIT_STRING> key_usage = GetKeyUsage(target_cert.get()); if (!key_usage) { + OSP_DVLOG << "FindCertificatePath: Failed with no key usage"; return Error::Code::kErrCertsRestrictions; } int bit = ASN1_BIT_STRING_get_bit(key_usage.get(), KeyUsageBits::kDigitalSignature); if (bit == 0) { + OSP_DVLOG << "FindCertificatePath: Failed to get digital signature"; return Error::Code::kErrCertsRestrictions; } @@ -443,6 +451,8 @@ Error FindCertificatePath(const std::vector<std::string>& der_certs, Error::Code last_error = Error::Code::kNone; for (;;) { X509_NAME* target_issuer_name = X509_get_issuer_name(path_head); + OSP_DVLOG << "FindCertificatePath: Target certificate issuer name: " + << X509_NAME_oneline(target_issuer_name, 0, 0); // The next issuer certificate to add to the current path. X509* next_issuer = nullptr; @@ -451,6 +461,8 @@ Error FindCertificatePath(const std::vector<std::string>& der_certs, X509* trust_store_cert = trust_store->certs[i].get(); X509_NAME* trust_store_cert_name = X509_get_subject_name(trust_store_cert); + OSP_DVLOG << "FindCertificatePath: Trust store certificate issuer name: " + << X509_NAME_oneline(trust_store_cert_name, 0, 0); if (X509_NAME_cmp(trust_store_cert_name, target_issuer_name) == 0) { CertPathStep& next_step = path[--path_index]; next_step.cert = trust_store_cert; @@ -485,6 +497,8 @@ Error FindCertificatePath(const std::vector<std::string>& der_certs, if (path_index == first_index) { // There are no more paths to try. Ensure an error is returned. if (last_error == Error::Code::kNone) { + OSP_DVLOG << "FindCertificatePath: Failed after trying all " + "certificate paths, no matches"; return Error::Code::kErrCertsVerifyGeneric; } return last_error; @@ -515,6 +529,8 @@ Error FindCertificatePath(const std::vector<std::string>& der_certs, result_path->path.push_back(path[i].cert); } + OSP_DVLOG + << "FindCertificatePath: Succeeded at validating receiver certificates"; return Error::Code::kNone; } diff --git a/cast/common/discovery/e2e_test/tests.cc b/cast/common/discovery/e2e_test/tests.cc index 97f4baf8..31541758 100644 --- a/cast/common/discovery/e2e_test/tests.cc +++ b/cast/common/discovery/e2e_test/tests.cc @@ -36,7 +36,6 @@ constexpr std::chrono::milliseconds kCheckLoopSleepTime = std::chrono::milliseconds(100); constexpr int kMaxCheckLoopIterations = 25; -} // namespace // Publishes new service instances. class Publisher : public discovery::DnsSdServicePublisher<ServiceInfo> { @@ -66,9 +65,9 @@ class Publisher : public discovery::DnsSdServicePublisher<ServiceInfo> { }; // Receives incoming services and outputs their results to stdout. -class Receiver : public discovery::DnsSdServiceWatcher<ServiceInfo> { +class ServiceReceiver : public discovery::DnsSdServiceWatcher<ServiceInfo> { public: - explicit Receiver(discovery::DnsSdService* service) + explicit ServiceReceiver(discovery::DnsSdService* service) : discovery::DnsSdServiceWatcher<ServiceInfo>( service, kCastV2ServiceId, @@ -77,7 +76,7 @@ class Receiver : public discovery::DnsSdServiceWatcher<ServiceInfo> { std::vector<std::reference_wrapper<const ServiceInfo>> infos) { ProcessResults(std::move(infos)); }) { - OSP_LOG << "Initializing Receiver..."; + OSP_LOG << "Initializing ServiceReceiver..."; } bool IsServiceFound(const ServiceInfo& check_service) { @@ -161,7 +160,7 @@ class DiscoveryE2ETest : public testing::Test { task_runner_->PostTask([this, &config, &done]() { dnssd_service_ = discovery::CreateDnsSdService( task_runner_, &reporting_client_, config); - receiver_ = std::make_unique<Receiver>(dnssd_service_.get()); + receiver_ = std::make_unique<ServiceReceiver>(dnssd_service_.get()); publisher_ = std::make_unique<Publisher>(dnssd_service_.get()); done = true; }); @@ -264,7 +263,7 @@ class DiscoveryE2ETest : public testing::Test { TaskRunner* task_runner_; FailOnErrorReporting reporting_client_; SerialDeletePtr<discovery::DnsSdService> dnssd_service_; - std::unique_ptr<Receiver> receiver_; + std::unique_ptr<ServiceReceiver> receiver_; std::unique_ptr<Publisher> publisher_; private: @@ -578,5 +577,6 @@ TEST_F(DiscoveryE2ETest, ValidateRefreshFlow) { WaitUntilSeen(true, &found); } +} // namespace } // namespace cast } // namespace openscreen diff --git a/cast/common/public/cast_socket.h b/cast/common/public/cast_socket.h index 0fd065d6..d7ac683f 100644 --- a/cast/common/public/cast_socket.h +++ b/cast/common/public/cast_socket.h @@ -10,6 +10,7 @@ #include <vector> #include "platform/api/tls_connection.h" +#include "util/weak_ptr.h" namespace cast { namespace channel { @@ -58,6 +59,8 @@ class CastSocket : public TlsConnection::Client { void OnError(TlsConnection* connection, Error error) override; void OnRead(TlsConnection* connection, std::vector<uint8_t> block) override; + WeakPtr<CastSocket> GetWeakPtr() const { return weak_factory_.GetWeakPtr(); } + private: enum class State : bool { kOpen = true, @@ -72,6 +75,8 @@ class CastSocket : public TlsConnection::Client { bool audio_only_ = false; std::vector<uint8_t> read_buffer_; State state_ = State::kOpen; + + WeakPtrFactory<CastSocket> weak_factory_{this}; }; } // namespace cast diff --git a/cast/standalone_receiver/BUILD.gn b/cast/standalone_receiver/BUILD.gn index ba54d33f..46ab0cad 100644 --- a/cast/standalone_receiver/BUILD.gn +++ b/cast/standalone_receiver/BUILD.gn @@ -9,16 +9,18 @@ import("//build_overrides/build.gni") # standalone platform implementation; since this is itself a standalone # application. if (!build_with_chromium) { - executable("cast_receiver") { + source_set("standalone_receiver") { sources = [ "cast_agent.cc", "cast_agent.h", "cast_socket_message_port.cc", "cast_socket_message_port.h", - "main.cc", + "static_credentials.cc", + "static_credentials.h", "streaming_playback_controller.cc", "streaming_playback_controller.h", ] + deps = [ "../../platform", "../../third_party/jsoncpp", @@ -57,4 +59,25 @@ if (!build_with_chromium) { ] } } + + source_set("e2e_tests") { + testonly = true + sources = [ "cast_agent_integration_tests.cc" ] + + deps = [ + ":standalone_receiver", + "../../third_party/boringssl", + "../../third_party/googletest:gtest", + "../receiver:channel", + ] + } + + executable("cast_receiver") { + sources = [ "main.cc" ] + + deps = [ + ":standalone_receiver", + "../receiver:channel", + ] + } } diff --git a/cast/standalone_receiver/cast_agent.cc b/cast/standalone_receiver/cast_agent.cc index dfccd351..ad22f2d6 100644 --- a/cast/standalone_receiver/cast_agent.cc +++ b/cast/standalone_receiver/cast_agent.cc @@ -11,13 +11,12 @@ #include <vector> #include "absl/strings/str_cat.h" +#include "cast/common/channel/message_util.h" #include "cast/standalone_receiver/cast_socket_message_port.h" -#include "cast/standalone_receiver/private_key_der.h" #include "cast/streaming/constants.h" #include "cast/streaming/offer_messages.h" #include "platform/base/tls_credentials.h" #include "platform/base/tls_listen_options.h" -#include "util/crypto/certificate_utils.h" #include "util/json/json_serialization.h" #include "util/osp_logging.h" @@ -28,100 +27,77 @@ namespace { constexpr int kDefaultMaxBacklogSize = 64; const TlsListenOptions kDefaultListenOptions{kDefaultMaxBacklogSize}; -constexpr int kThreeDaysInSeconds = 3 * 24 * 60 * 60; -constexpr auto kCertificateDuration = std::chrono::seconds(kThreeDaysInSeconds); - -// Generates a valid set of credentials for use with the TLS Server socket, -// including a generated X509 certificate generated from the static private key -// stored in private_key_der.h. The certificate is valid for -// kCertificateDuration from when this function is called. -ErrorOr<TlsCredentials> CreateCredentials(const IPEndpoint& endpoint) { - ErrorOr<bssl::UniquePtr<EVP_PKEY>> private_key = - ImportRSAPrivateKey(kPrivateKeyDer.data(), kPrivateKeyDer.size()); - OSP_CHECK(private_key); - - ErrorOr<bssl::UniquePtr<X509>> cert = CreateSelfSignedX509Certificate( - endpoint.ToString(), kCertificateDuration, *private_key.value()); - if (!cert) { - return cert.error(); - } - - auto cert_bytes = ExportX509CertificateToDer(*cert.value()); - if (!cert_bytes) { - return cert_bytes.error(); - } - - // TODO(jophba): either refactor the TLS server socket to use the public key - // and add a valid key here, or remove from the TlsCredentials struct. - return TlsCredentials( - std::vector<uint8_t>(kPrivateKeyDer.begin(), kPrivateKeyDer.end()), - std::vector<uint8_t>{}, std::move(cert_bytes.value())); -} - } // namespace -CastAgent::CastAgent(TaskRunner* task_runner, InterfaceInfo interface) - : task_runner_(task_runner) { - // Create the Environment that holds the required injected dependencies - // (clock, task runner) used throughout the system, and owns the UDP socket - // over which all communication occurs with the Sender. - IPEndpoint receive_endpoint{IPAddress::kV4LoopbackAddress, kDefaultCastPort}; - receive_endpoint.address = interface.GetIpAddressV4() - ? interface.GetIpAddressV4() - : interface.GetIpAddressV6(); - OSP_DCHECK(receive_endpoint.address); - environment_ = std::make_unique<Environment>(&Clock::now, task_runner_, - receive_endpoint); - receive_endpoint_ = std::move(receive_endpoint); +CastAgent::CastAgent( + TaskRunner* task_runner, + InterfaceInfo interface, + DeviceAuthNamespaceHandler::CredentialsProvider* credentials_provider, + TlsCredentials tls_credentials) + : task_runner_(task_runner), + credentials_provider_(credentials_provider), + tls_credentials_(tls_credentials) { + const IPAddress address = interface.GetIpAddressV4() + ? interface.GetIpAddressV4() + : interface.GetIpAddressV6(); + OSP_CHECK(address); + environment_ = std::make_unique<Environment>( + &Clock::now, task_runner_, + IPEndpoint{address, kDefaultCastStreamingPort}); + receive_endpoint_ = IPEndpoint{address, kDefaultCastPort}; } CastAgent::~CastAgent() = default; Error CastAgent::Start() { - OSP_DCHECK(!current_session_); - - task_runner_->PostTask( - [this] { this->wake_lock_ = ScopedWakeLock::Create(); }); - - // TODO(jophba): add command line argument for setting the private key. - ErrorOr<TlsCredentials> credentials = CreateCredentials(receive_endpoint_); - if (!credentials) { - return credentials.error(); - } - - // TODO(jophba, rwkeane): begin discovery process before creating TLS - // connection factory instance. - socket_factory_ = - std::make_unique<ReceiverSocketFactory>(this, &message_port_); - task_runner_->PostTask([this, creds = std::move(credentials.value())] { - connection_factory_ = TlsConnectionFactory::CreateFactory( - socket_factory_.get(), task_runner_); - connection_factory_->SetListenCredentials(creds); + OSP_CHECK(!current_session_); + + auth_handler_ = MakeSerialDelete<DeviceAuthNamespaceHandler>( + task_runner_, credentials_provider_); + router_ = MakeSerialDelete<VirtualConnectionRouter>(task_runner_, + &connection_manager_); + router_->AddHandlerForLocalId(kPlatformReceiverId, auth_handler_.get()); + socket_factory_ = MakeSerialDelete<ReceiverSocketFactory>(task_runner_, this, + router_.get()); + + task_runner_->PostTask([this] { + wake_lock_ = ScopedWakeLock::Create(task_runner_); + + connection_factory_ = SerialDeletePtr<TlsConnectionFactory>( + task_runner_, + TlsConnectionFactory::CreateFactory(socket_factory_.get(), task_runner_) + .release()); + connection_factory_->SetListenCredentials(tls_credentials_); connection_factory_->Listen(receive_endpoint_, kDefaultListenOptions); + OSP_LOG_INFO << "Listening for connections at: " << receive_endpoint_; }); - OSP_LOG_INFO << "Listening for connections at: " << receive_endpoint_; return Error::None(); } Error CastAgent::Stop() { - controller_.reset(); - current_session_.reset(); + task_runner_->PostTask([this] { + router_.reset(); + connection_factory_.reset(); + controller_.reset(); + current_session_.reset(); + socket_factory_.reset(); + wake_lock_.reset(); + }); return Error::None(); } void CastAgent::OnConnected(ReceiverSocketFactory* factory, const IPEndpoint& endpoint, std::unique_ptr<CastSocket> socket) { - OSP_DCHECK(factory); - if (current_session_) { OSP_LOG_WARN << "Already connected, dropping peer at: " << endpoint; return; } OSP_LOG_INFO << "Received connection from peer at: " << endpoint; - message_port_.SetSocket(std::move(socket)); + message_port_.SetSocket(socket->GetWeakPtr()); + router_->TakeSocket(this, std::move(socket)); controller_ = std::make_unique<StreamingPlaybackController>(task_runner_, this); current_session_ = std::make_unique<ReceiverSession>( @@ -131,6 +107,17 @@ void CastAgent::OnConnected(ReceiverSocketFactory* factory, void CastAgent::OnError(ReceiverSocketFactory* factory, Error error) { OSP_LOG_ERROR << "Cast agent received socket factory error: " << error; + StopCurrentSession(); +} + +void CastAgent::OnClose(CastSocket* cast_socket) { + OSP_VLOG << "Cast agent socket closed."; + StopCurrentSession(); +} + +void CastAgent::OnError(CastSocket* socket, Error error) { + OSP_LOG_ERROR << "Cast agent received socket error: " << error; + StopCurrentSession(); } // Currently we don't do anything with the receiver output--the session @@ -139,23 +126,30 @@ void CastAgent::OnError(ReceiverSocketFactory* factory, Error error) { // about the receiver configurations we will have to handle OnNegotiated here. void CastAgent::OnNegotiated(const ReceiverSession* session, ReceiverSession::ConfiguredReceivers receivers) { - OSP_LOG_INFO << "Successfully negotiated with sender."; + OSP_VLOG << "Successfully negotiated with sender."; } void CastAgent::OnConfiguredReceiversDestroyed(const ReceiverSession* session) { - OSP_LOG_INFO << "Receiver instances destroyed."; + OSP_VLOG << "Receiver instances destroyed."; } // Currently, we just kill the session if an error is encountered. void CastAgent::OnError(const ReceiverSession* session, Error error) { OSP_LOG_ERROR << "Cast agent received receiver session error: " << error; - current_session_.reset(); + StopCurrentSession(); } void CastAgent::OnPlaybackError(StreamingPlaybackController* controller, Error error) { OSP_LOG_ERROR << "Cast agent received playback error: " << error; + StopCurrentSession(); +} + +void CastAgent::StopCurrentSession() { + controller_.reset(); current_session_.reset(); + router_->CloseSocket(message_port_.GetSocketId()); + message_port_.SetSocket(nullptr); } } // namespace cast diff --git a/cast/standalone_receiver/cast_agent.h b/cast/standalone_receiver/cast_agent.h index d9932b9d..b4fca60b 100644 --- a/cast/standalone_receiver/cast_agent.h +++ b/cast/standalone_receiver/cast_agent.h @@ -8,10 +8,15 @@ #include <openssl/x509.h> #include <memory> +#include <vector> +#include "cast/common/channel/virtual_connection_manager.h" +#include "cast/common/channel/virtual_connection_router.h" #include "cast/common/public/cast_socket.h" +#include "cast/receiver/channel/device_auth_namespace_handler.h" #include "cast/receiver/public/receiver_socket_factory.h" #include "cast/standalone_receiver/cast_socket_message_port.h" +#include "cast/standalone_receiver/static_credentials.h" #include "cast/standalone_receiver/streaming_playback_controller.h" #include "cast/streaming/environment.h" #include "cast/streaming/receiver_session.h" @@ -19,6 +24,7 @@ #include "platform/api/serial_delete_ptr.h" #include "platform/base/error.h" #include "platform/base/interface_info.h" +#include "platform/base/tls_credentials.h" #include "platform/impl/task_runner.h" namespace openscreen { @@ -29,13 +35,19 @@ namespace cast { // received, and linking Receivers to the output decoder and SDL visualizer. // // Consumers of this class are expected to provide a single threaded task runner -// implementation, and a network interface information struct that will be used -// both for TLS listening and UDP messaging. -class CastAgent : public ReceiverSocketFactory::Client, - public ReceiverSession::Client, - public StreamingPlaybackController::Client { +// implementation, a network interface information struct that will be used +// both for TLS listening and UDP messaging, and a credentials provider used +// for TLS listening. +class CastAgent final : public ReceiverSocketFactory::Client, + public VirtualConnectionRouter::SocketErrorHandler, + public ReceiverSession::Client, + public StreamingPlaybackController::Client { public: - CastAgent(TaskRunner* task_runner, InterfaceInfo interface); + CastAgent( + TaskRunner* task_runner, + InterfaceInfo interface, + DeviceAuthNamespaceHandler::CredentialsProvider* credentials_provider, + TlsCredentials tls_credentials); ~CastAgent(); // Initialization occurs as part of construction, however to actually bind @@ -49,6 +61,10 @@ class CastAgent : public ReceiverSocketFactory::Client, std::unique_ptr<CastSocket> socket) override; void OnError(ReceiverSocketFactory* factory, Error error) override; + // VirtualConnectionRouter::SocketErrorHandler overrides. + void OnClose(CastSocket* cast_socket) override; + void OnError(CastSocket* socket, Error error) override; + // ReceiverSession::Client overrides. void OnNegotiated(const ReceiverSession* session, ReceiverSession::ConfiguredReceivers receivers) override; @@ -60,16 +76,26 @@ class CastAgent : public ReceiverSocketFactory::Client, Error error) override; private: + // Helper for stopping the current session. This is useful for when we don't + // want to completely stop (e.g. an issue with a specific Sender) but need + // to terminate the current connection. + void StopCurrentSession(); + // Member variables set as part of construction. std::unique_ptr<Environment> environment_; TaskRunner* const task_runner_; IPEndpoint receive_endpoint_; + DeviceAuthNamespaceHandler::CredentialsProvider* credentials_provider_; CastSocketMessagePort message_port_; + TlsCredentials tls_credentials_; // Member variables set as part of starting up. - std::unique_ptr<TlsConnectionFactory> connection_factory_; - std::unique_ptr<ReceiverSocketFactory> socket_factory_; - std::unique_ptr<ScopedWakeLock> wake_lock_; + SerialDeletePtr<DeviceAuthNamespaceHandler> auth_handler_; + SerialDeletePtr<TlsConnectionFactory> connection_factory_; + VirtualConnectionManager connection_manager_; + SerialDeletePtr<VirtualConnectionRouter> router_; + SerialDeletePtr<ReceiverSocketFactory> socket_factory_; + SerialDeletePtr<ScopedWakeLock> wake_lock_; // Member variables set as part of a sender connection. // NOTE: currently we only support a single sender connection and a diff --git a/cast/standalone_receiver/cast_agent_integration_tests.cc b/cast/standalone_receiver/cast_agent_integration_tests.cc new file mode 100644 index 00000000..49b1ced4 --- /dev/null +++ b/cast/standalone_receiver/cast_agent_integration_tests.cc @@ -0,0 +1,142 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/common/certificate/cast_trust_store.h" +#include "cast/common/certificate/testing/test_helpers.h" +#include "cast/common/channel/virtual_connection_manager.h" +#include "cast/common/channel/virtual_connection_router.h" +#include "cast/sender/public/sender_socket_factory.h" +#include "cast/standalone_receiver/static_credentials.h" +#include "cast_agent.h" +#include "gtest/gtest.h" +#include "platform/api/serial_delete_ptr.h" +#include "platform/api/time.h" +#include "platform/impl/network_interface.h" +#include "platform/impl/platform_client_posix.h" +#include "platform/impl/task_runner.h" + +namespace openscreen { +namespace cast { +namespace { + +// Based heavily on SenderSocketsClient from cast_socket_e2e_test.cc. +class MockSender final : public SenderSocketFactory::Client, + public VirtualConnectionRouter::SocketErrorHandler { + public: + explicit MockSender(VirtualConnectionRouter* router) : router_(router) {} + ~MockSender() = default; + + CastSocket* socket() const { return socket_; } + + // SenderSocketFactory::Client overrides. + void OnConnected(SenderSocketFactory* factory, + const IPEndpoint& endpoint, + std::unique_ptr<CastSocket> socket) override { + ASSERT_FALSE(socket_); + OSP_LOG_INFO << "Sender connected to endpoint: " << endpoint; + socket_ = socket.get(); + router_->TakeSocket(this, std::move(socket)); + } + + void OnError(SenderSocketFactory* factory, + const IPEndpoint& endpoint, + Error error) override { + FAIL() << error; + } + + // VirtualConnectionRouter::SocketErrorHandler overrides. + void OnClose(CastSocket* socket) override {} + void OnError(CastSocket* socket, Error error) override { FAIL() << error; } + + private: + VirtualConnectionRouter* const router_; + std::atomic<CastSocket*> socket_{nullptr}; +}; + +class CastAgentIntegrationTest : public ::testing::Test { + public: + void SetUp() override { + PlatformClientPosix::Create(std::chrono::milliseconds{50}, + std::chrono::milliseconds{50}); + task_runner_ = reinterpret_cast<TaskRunnerImpl*>( + PlatformClientPosix::GetInstance()->GetTaskRunner()); + + sender_router_ = MakeSerialDelete<VirtualConnectionRouter>( + task_runner_, &sender_vc_manager_); + sender_client_ = std::make_unique<MockSender>(sender_router_.get()); + sender_factory_ = MakeSerialDelete<SenderSocketFactory>( + task_runner_, sender_client_.get(), task_runner_); + sender_tls_factory_ = SerialDeletePtr<TlsConnectionFactory>( + task_runner_, + TlsConnectionFactory::CreateFactory(sender_factory_.get(), task_runner_) + .release()); + sender_factory_->set_factory(sender_tls_factory_.get()); + } + + void TearDown() override { + sender_router_.reset(); + sender_tls_factory_.reset(); + sender_factory_.reset(); + PlatformClientPosix::ShutDown(); + // Must be shut down after platform client, so joined tasks + // depending on certs are called correctly. + CastTrustStore::ResetInstance(); + } + + void WaitAndAssertSenderSocketConnected() { + constexpr int kMaxAttempts = 10; + constexpr std::chrono::milliseconds kSocketWaitDelay(250); + for (int i = 0; i < kMaxAttempts; ++i) { + OSP_LOG_INFO << "\tChecking for CastSocket, attempt " << i + 1 << "/" + << kMaxAttempts; + if (sender_client_->socket()) { + break; + } + std::this_thread::sleep_for(kSocketWaitDelay); + } + ASSERT_TRUE(sender_client_->socket()); + } + + void AssertConnect(const IPAddress& address) { + OSP_LOG_INFO << "Sending connect task"; + task_runner_->PostTask( + [this, &address, port = (static_cast<uint16_t>(kDefaultCastPort))]() { + OSP_LOG_INFO << "Calling SenderSocketFactory::Connect"; + sender_factory_->Connect( + IPEndpoint{address, port}, + SenderSocketFactory::DeviceMediaPolicy::kNone, + sender_router_.get()); + }); + WaitAndAssertSenderSocketConnected(); + } + + TaskRunnerImpl* task_runner_; + // Cast socket sender components, used in conjuction to mock a Libcast sender. + VirtualConnectionManager sender_vc_manager_; + SerialDeletePtr<VirtualConnectionRouter> sender_router_; + std::unique_ptr<MockSender> sender_client_; + SerialDeletePtr<SenderSocketFactory> sender_factory_; + SerialDeletePtr<TlsConnectionFactory> sender_tls_factory_; +}; + +TEST_F(CastAgentIntegrationTest, StartsListeningProperly) { + absl::optional<InterfaceInfo> loopback = GetLoopbackInterfaceForTesting(); + ASSERT_TRUE(loopback.has_value()); + + ErrorOr<GeneratedCredentials> creds = + GenerateCredentials("Test Device Certificate"); + ASSERT_TRUE(creds.is_value()); + CastTrustStore::CreateInstanceForTest(creds.value().root_cert_der); + + auto agent = MakeSerialDelete<CastAgent>( + task_runner_, task_runner_, loopback.value(), + creds.value().provider.get(), creds.value().tls_credentials); + EXPECT_TRUE(agent->Start().ok()); + AssertConnect(loopback.value().GetIpAddressV4()); + EXPECT_TRUE(agent->Stop().ok()); +} + +} // namespace +} // namespace cast +} // namespace openscreen diff --git a/cast/standalone_receiver/cast_socket_message_port.cc b/cast/standalone_receiver/cast_socket_message_port.cc index d5540f28..6f3c55c8 100644 --- a/cast/standalone_receiver/cast_socket_message_port.cc +++ b/cast/standalone_receiver/cast_socket_message_port.cc @@ -19,27 +19,17 @@ CastSocketMessagePort::~CastSocketMessagePort() = default; // since sockets should map one to one with receiver sessions, we reset our // client. The consumer of this message port should call SetClient with the new // message port client after setting the socket. -void CastSocketMessagePort::SetSocket(std::unique_ptr<CastSocket> socket) { +void CastSocketMessagePort::SetSocket(WeakPtr<CastSocket> socket) { client_ = nullptr; - socket_ = std::move(socket); + socket_ = socket; } -void CastSocketMessagePort::SetClient(MessagePort::Client* client) { - client_ = client; -} - -void CastSocketMessagePort::OnError(CastSocket* socket, Error error) { - if (client_) { - client_->OnError(error); - } +int CastSocketMessagePort::GetSocketId() { + return socket_ ? socket_->socket_id() : -1; } -void CastSocketMessagePort::OnMessage(CastSocket* socket, - ::cast::channel::CastMessage message) { - if (client_) { - client_->OnMessage(message.source_id(), message.namespace_(), - message.payload_utf8()); - } +void CastSocketMessagePort::SetClient(MessagePort::Client* client) { + client_ = client; } void CastSocketMessagePort::PostMessage(absl::string_view sender_id, @@ -51,6 +41,10 @@ void CastSocketMessagePort::PostMessage(absl::string_view sender_id, message_namespace.size()); cast_message.set_payload_utf8(message.data(), message.size()); + if (!socket_) { + client_->OnError(Error::Code::kAlreadyClosed); + return; + } Error error = socket_->Send(cast_message); if (!error.ok()) { client_->OnError(error); diff --git a/cast/standalone_receiver/cast_socket_message_port.h b/cast/standalone_receiver/cast_socket_message_port.h index 98fc47f6..67d037e9 100644 --- a/cast/standalone_receiver/cast_socket_message_port.h +++ b/cast/standalone_receiver/cast_socket_message_port.h @@ -11,16 +11,20 @@ #include "cast/common/public/cast_socket.h" #include "cast/streaming/receiver_session.h" +#include "util/weak_ptr.h" namespace openscreen { namespace cast { -class CastSocketMessagePort : public MessagePort, public CastSocket::Client { +class CastSocketMessagePort : public MessagePort { public: CastSocketMessagePort(); ~CastSocketMessagePort() override; - void SetSocket(std::unique_ptr<CastSocket> socket); + void SetSocket(WeakPtr<CastSocket> socket); + + // Returns current socket identifier, or -1 if not connected. + int GetSocketId(); // MessagePort overrides. void SetClient(MessagePort::Client* client) override; @@ -28,14 +32,9 @@ class CastSocketMessagePort : public MessagePort, public CastSocket::Client { absl::string_view message_namespace, absl::string_view message) override; - // CastSocket::Client overrides. - void OnError(CastSocket* socket, Error error) override; - void OnMessage(CastSocket* socket, - ::cast::channel::CastMessage message) override; - private: MessagePort::Client* client_ = nullptr; - std::unique_ptr<CastSocket> socket_; + WeakPtr<CastSocket> socket_; }; } // namespace cast diff --git a/cast/standalone_receiver/main.cc b/cast/standalone_receiver/main.cc index 9109d501..027edd81 100644 --- a/cast/standalone_receiver/main.cc +++ b/cast/standalone_receiver/main.cc @@ -8,8 +8,10 @@ #include <chrono> // NOLINT #include <iostream> +#include "absl/strings/str_cat.h" #include "cast/common/public/service_info.h" #include "cast/standalone_receiver/cast_agent.h" +#include "cast/standalone_receiver/static_credentials.h" #include "cast/streaming/ssrc.h" #include "discovery/common/config.h" #include "discovery/common/reporting_client.h" @@ -93,8 +95,11 @@ ErrorOr<std::unique_ptr<DiscoveryState>> StartDiscovery( return state; } -void StartCastAgent(TaskRunnerImpl* task_runner, InterfaceInfo interface) { - CastAgent agent(task_runner, interface); +void StartCastAgent(TaskRunnerImpl* task_runner, + InterfaceInfo interface, + GeneratedCredentials* creds) { + CastAgent agent(task_runner, interface, creds->provider.get(), + creds->tls_credentials); const auto error = agent.Start(); if (!error.ok()) { OSP_LOG_ERROR << "Error occurred while starting agent: " << error; @@ -179,9 +184,13 @@ int RunStandaloneReceiver(int argc, char* argv[]) { auto discovery_state = StartDiscovery(task_runner, interface_info); OSP_CHECK(discovery_state.is_value()) << "Failed to start discovery."; + auto creds = GenerateCredentials( + absl::StrCat("Standalone Receiver on ", argv[optind])); + OSP_CHECK(creds.is_value()); + // Runs until the process is interrupted. Safe to pass |task_runner| as it // will not be destroyed by ShutDown() until this exits. - StartCastAgent(task_runner, interface_info); + StartCastAgent(task_runner, interface_info, &(creds.value())); // The task runner must be deleted after all serial delete pointers, such // as the one stored in the discovery state. diff --git a/cast/standalone_receiver/static_credentials.cc b/cast/standalone_receiver/static_credentials.cc new file mode 100644 index 00000000..2ced690b --- /dev/null +++ b/cast/standalone_receiver/static_credentials.cc @@ -0,0 +1,138 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/standalone_receiver/static_credentials.h" + +#include <openssl/mem.h> + +#include <memory> +#include <utility> +#include <vector> + +#include "cast/standalone_receiver/private_key_der.h" +#include "platform/base/tls_credentials.h" +#include "util/crypto/certificate_utils.h" +#include "util/osp_logging.h" + +namespace openscreen { +namespace cast { +namespace { + +constexpr int kThreeDaysInSeconds = 3 * 24 * 60 * 60; +constexpr auto kCertificateDuration = std::chrono::seconds(kThreeDaysInSeconds); + +} // namespace + +StaticCredentialsProvider::StaticCredentialsProvider() = default; +StaticCredentialsProvider::StaticCredentialsProvider( + DeviceCredentials device_creds, + std::vector<uint8_t> tls_cert_der) + : device_creds(std::move(device_creds)), + tls_cert_der(std::move(tls_cert_der)) {} + +StaticCredentialsProvider::StaticCredentialsProvider( + StaticCredentialsProvider&&) = default; +StaticCredentialsProvider& StaticCredentialsProvider::operator=( + StaticCredentialsProvider&&) = default; +StaticCredentialsProvider::~StaticCredentialsProvider() = default; + +ErrorOr<GeneratedCredentials> GenerateCredentials( + absl::string_view device_certificate_id) { + GeneratedCredentials credentials; + + bssl::UniquePtr<EVP_PKEY> root_key = GenerateRsaKeyPair(); + bssl::UniquePtr<EVP_PKEY> intermediate_key = GenerateRsaKeyPair(); + bssl::UniquePtr<EVP_PKEY> device_key = GenerateRsaKeyPair(); + OSP_CHECK(root_key); + OSP_CHECK(intermediate_key); + OSP_CHECK(device_key); + + ErrorOr<bssl::UniquePtr<X509>> root_cert_or_error = + CreateSelfSignedX509Certificate("Cast Root CA", kCertificateDuration, + *root_key, GetWallTimeSinceUnixEpoch(), + true); + OSP_CHECK(root_cert_or_error); + bssl::UniquePtr<X509> root_cert = std::move(root_cert_or_error.value()); + + ErrorOr<bssl::UniquePtr<X509>> intermediate_cert_or_error = + CreateSelfSignedX509Certificate( + "Cast Intermediate", kCertificateDuration, *intermediate_key, + GetWallTimeSinceUnixEpoch(), true, root_cert.get(), root_key.get()); + OSP_CHECK(intermediate_cert_or_error); + bssl::UniquePtr<X509> intermediate_cert = + std::move(intermediate_cert_or_error.value()); + + ErrorOr<bssl::UniquePtr<X509>> device_cert_or_error = + CreateSelfSignedX509Certificate( + device_certificate_id, kCertificateDuration, *device_key, + GetWallTimeSinceUnixEpoch(), false, intermediate_cert.get(), + intermediate_key.get()); + OSP_CHECK(device_cert_or_error); + bssl::UniquePtr<X509> device_cert = std::move(device_cert_or_error.value()); + + // NOTE: Device cert chain plumbing + serialization. + DeviceCredentials device_creds; + device_creds.private_key = std::move(device_key); + + int cert_length = i2d_X509(device_cert.get(), nullptr); + std::string cert_serial(cert_length, 0); + uint8_t* out = reinterpret_cast<uint8_t*>(&cert_serial[0]); + i2d_X509(device_cert.get(), &out); + device_creds.certs.emplace_back(std::move(cert_serial)); + + cert_length = i2d_X509(intermediate_cert.get(), nullptr); + cert_serial.resize(cert_length); + out = reinterpret_cast<uint8_t*>(&cert_serial[0]); + i2d_X509(intermediate_cert.get(), &out); + device_creds.certs.emplace_back(std::move(cert_serial)); + + cert_length = i2d_X509(root_cert.get(), nullptr); + std::vector<uint8_t> trust_anchor_der(cert_length); + out = &trust_anchor_der[0]; + i2d_X509(root_cert.get(), &out); + + // NOTE: TLS key pair + certificate generation. + bssl::UniquePtr<EVP_PKEY> tls_key = GenerateRsaKeyPair(); + OSP_CHECK_EQ(EVP_PKEY_id(tls_key.get()), EVP_PKEY_RSA); + ErrorOr<bssl::UniquePtr<X509>> tls_cert_or_error = + CreateSelfSignedX509Certificate("Test Device TLS", kCertificateDuration, + *tls_key, GetWallTimeSinceUnixEpoch()); + OSP_CHECK(tls_cert_or_error); + bssl::UniquePtr<X509> tls_cert = std::move(tls_cert_or_error.value()); + + // NOTE: TLS private key serialization. + RSA* rsa_key = EVP_PKEY_get0_RSA(tls_key.get()); + size_t pkey_len = 0; + uint8_t* pkey_bytes = nullptr; + OSP_CHECK(RSA_private_key_to_bytes(&pkey_bytes, &pkey_len, rsa_key)); + OSP_CHECK_GT(pkey_len, 0u); + std::vector<uint8_t> tls_key_serial(pkey_bytes, pkey_bytes + pkey_len); + OPENSSL_free(pkey_bytes); + + // NOTE: TLS public key serialization. + pkey_len = 0; + pkey_bytes = nullptr; + OSP_CHECK(RSA_public_key_to_bytes(&pkey_bytes, &pkey_len, rsa_key)); + OSP_CHECK_GT(pkey_len, 0u); + std::vector<uint8_t> tls_pub_serial(pkey_bytes, pkey_bytes + pkey_len); + OPENSSL_free(pkey_bytes); + + // NOTE: TLS cert serialization. + cert_length = 0; + cert_length = i2d_X509(tls_cert.get(), nullptr); + OSP_CHECK_GT(cert_length, 0); + std::vector<uint8_t> tls_cert_serial(cert_length); + out = &tls_cert_serial[0]; + i2d_X509(tls_cert.get(), &out); + + return GeneratedCredentials{ + std::make_unique<StaticCredentialsProvider>(std::move(device_creds), + tls_cert_serial), + TlsCredentials{std::move(tls_key_serial), std::move(tls_pub_serial), + std::move(tls_cert_serial)}, + std::move(trust_anchor_der)}; +} + +} // namespace cast +} // namespace openscreen diff --git a/cast/standalone_receiver/static_credentials.h b/cast/standalone_receiver/static_credentials.h new file mode 100644 index 00000000..4707f5f4 --- /dev/null +++ b/cast/standalone_receiver/static_credentials.h @@ -0,0 +1,60 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STANDALONE_RECEIVER_STATIC_CREDENTIALS_H_ +#define CAST_STANDALONE_RECEIVER_STATIC_CREDENTIALS_H_ + +#include <memory> +#include <vector> + +#include "absl/strings/string_view.h" +#include "cast/receiver/channel/device_auth_namespace_handler.h" +#include "platform/base/error.h" +#include "platform/base/tls_credentials.h" + +namespace openscreen { +namespace cast { + +class StaticCredentialsProvider final + : public DeviceAuthNamespaceHandler::CredentialsProvider { + public: + StaticCredentialsProvider(); + StaticCredentialsProvider(DeviceCredentials device_creds, + std::vector<uint8_t> tls_cert_der); + + StaticCredentialsProvider(const StaticCredentialsProvider&) = delete; + StaticCredentialsProvider(StaticCredentialsProvider&&); + StaticCredentialsProvider& operator=(const StaticCredentialsProvider&) = + delete; + StaticCredentialsProvider& operator=(StaticCredentialsProvider&&); + ~StaticCredentialsProvider(); + + absl::Span<const uint8_t> GetCurrentTlsCertAsDer() override { + return absl::Span<uint8_t>(tls_cert_der); + } + const DeviceCredentials& GetCurrentDeviceCredentials() override { + return device_creds; + } + + DeviceCredentials device_creds; + std::vector<uint8_t> tls_cert_der; +}; + +struct GeneratedCredentials { + std::unique_ptr<StaticCredentialsProvider> provider; + TlsCredentials tls_credentials; + std::vector<uint8_t> root_cert_der; +}; + +// Generates a valid set of credentials for use with the TLS Server socket, +// including a generated X509 certificate generated from the static private key +// stored in private_key_der.h. The certificate is valid for +// kCertificateDuration from when this function is called. +ErrorOr<GeneratedCredentials> GenerateCredentials( + absl::string_view device_certificate_id); + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STANDALONE_RECEIVER_STATIC_CREDENTIALS_H_ diff --git a/cast/streaming/receiver_session.h b/cast/streaming/receiver_session.h index be59f4ec..d41af31c 100644 --- a/cast/streaming/receiver_session.h +++ b/cast/streaming/receiver_session.h @@ -23,7 +23,7 @@ class CastSocket; class Environment; class Receiver; class VirtualConnectionRouter; -class VirtualConnection; +struct VirtualConnection; class ReceiverSession final : public MessagePort::Client { public: diff --git a/cast/test/cast_socket_e2e_test.cc b/cast/test/cast_socket_e2e_test.cc index d64322d3..28c7b4fb 100644 --- a/cast/test/cast_socket_e2e_test.cc +++ b/cast/test/cast_socket_e2e_test.cc @@ -138,14 +138,14 @@ class CastSocketE2ETest : public ::testing::Test { ASSERT_TRUE(device_key); ErrorOr<bssl::UniquePtr<X509>> root_cert_or_error = - CreateSelfSignedX509CertificateForTest( - "Cast Root CA", kCertificateDuration, *root_key, - GetWallTimeSinceUnixEpoch(), true); + CreateSelfSignedX509Certificate("Cast Root CA", kCertificateDuration, + *root_key, GetWallTimeSinceUnixEpoch(), + true); ASSERT_TRUE(root_cert_or_error); bssl::UniquePtr<X509> root_cert = std::move(root_cert_or_error.value()); ErrorOr<bssl::UniquePtr<X509>> intermediate_cert_or_error = - CreateSelfSignedX509CertificateForTest( + CreateSelfSignedX509Certificate( "Cast Intermediate", kCertificateDuration, *intermediate_key, GetWallTimeSinceUnixEpoch(), true, root_cert.get(), root_key.get()); ASSERT_TRUE(intermediate_cert_or_error); @@ -153,7 +153,7 @@ class CastSocketE2ETest : public ::testing::Test { std::move(intermediate_cert_or_error.value()); ErrorOr<bssl::UniquePtr<X509>> device_cert_or_error = - CreateSelfSignedX509CertificateForTest( + CreateSelfSignedX509Certificate( "Test Device", kCertificateDuration, *device_key, GetWallTimeSinceUnixEpoch(), false, intermediate_cert.get(), intermediate_key.get()); |