aboutsummaryrefslogtreecommitdiff
path: root/cast
diff options
context:
space:
mode:
authorbtolsch <btolsch@chromium.org>2020-03-16 18:06:20 -0700
committerCommit Bot <commit-bot@chromium.org>2020-03-17 01:38:27 +0000
commit1aa882626c3dbd6c9c3349dd3f99f7da0b2d7028 (patch)
treee36fde1a39bf8ebaed85f7d47906105899ed34b8 /cast
parent9170376d36ad66badba549759de9f10c61afd12a (diff)
downloadopenscreen-1aa882626c3dbd6c9c3349dd3f99f7da0b2d7028.tar.gz
Add first CastSocket E2E test
This change adds the first E2E test for CastSocket, which includes the support code necessary for generating certificate chains at runtime and loading an alternate Cast root certificate for testing. Bug: openscreen:59 Change-Id: I3362c555e63d64700e06abdd452bdbf7eb1ac204 Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/2099442 Commit-Queue: Brandon Tolsch <btolsch@chromium.org> Reviewed-by: Ryan Keane <rwkeane@google.com> Reviewed-by: Max Yakimakha <yakimakha@chromium.org>
Diffstat (limited to 'cast')
-rw-r--r--cast/common/BUILD.gn2
-rw-r--r--cast/common/certificate/cast_cert_validator.cc37
-rw-r--r--cast/common/certificate/cast_cert_validator_internal.h5
-rw-r--r--cast/common/certificate/cast_trust_store.cc66
-rw-r--r--cast/common/certificate/cast_trust_store.h39
-rw-r--r--cast/receiver/channel/receiver_socket_factory.cc2
-rw-r--r--cast/receiver/channel/receiver_socket_factory.h1
-rw-r--r--cast/sender/channel/sender_socket_factory.cc5
-rw-r--r--cast/test/BUILD.gn21
-rw-r--r--cast/test/DEPS2
-rw-r--r--cast/test/cast_socket_e2e_test.cc345
11 files changed, 489 insertions, 36 deletions
diff --git a/cast/common/BUILD.gn b/cast/common/BUILD.gn
index 5abbbf8a..f753198c 100644
--- a/cast/common/BUILD.gn
+++ b/cast/common/BUILD.gn
@@ -14,6 +14,8 @@ source_set("certificate") {
"certificate/cast_cert_validator_internal.h",
"certificate/cast_crl.cc",
"certificate/cast_crl.h",
+ "certificate/cast_trust_store.cc",
+ "certificate/cast_trust_store.h",
"certificate/types.cc",
"certificate/types.h",
]
diff --git a/cast/common/certificate/cast_cert_validator.cc b/cast/common/certificate/cast_cert_validator.cc
index 5645c9d6..6d2a12e7 100644
--- a/cast/common/certificate/cast_cert_validator.cc
+++ b/cast/common/certificate/cast_cert_validator.cc
@@ -17,25 +17,13 @@
#include "cast/common/certificate/cast_cert_validator_internal.h"
#include "cast/common/certificate/cast_crl.h"
+#include "cast/common/certificate/cast_trust_store.h"
+#include "util/logging.h"
namespace openscreen {
namespace cast {
namespace {
-// -------------------------------------------------------------------------
-// Cast trust anchors.
-// -------------------------------------------------------------------------
-
-// There are two trusted roots for Cast certificate chains:
-//
-// (1) CN=Cast Root CA (kCastRootCaDer)
-// (2) CN=Eureka Root CA (kEurekaRootCaDer)
-//
-// These constants are defined by the files included next:
-
-#include "cast/common/certificate/cast_root_ca_cert_der-inc.h"
-#include "cast/common/certificate/eureka_root_ca_der-inc.h"
-
// Returns the OID for the Audio-Only Cast policy
// (1.3.6.1.4.1.11129.2.5.2) in DER form.
const ConstDataSpan& AudioOnlyPolicyOid() {
@@ -141,27 +129,6 @@ CastDeviceCertPolicy GetAudioPolicy(const std::vector<X509*>& path) {
} // namespace
-class CastTrustStore {
- public:
- // Singleton for the Cast trust store for legacy networkingPrivate use.
- static CastTrustStore* GetInstance() {
- static CastTrustStore* store = new CastTrustStore();
- return store;
- }
-
- CastTrustStore() {
- trust_store_.certs.emplace_back(MakeTrustAnchor(kCastRootCaDer));
- trust_store_.certs.emplace_back(MakeTrustAnchor(kEurekaRootCaDer));
- }
- ~CastTrustStore() = default;
-
- TrustStore* trust_store() { return &trust_store_; }
-
- private:
- TrustStore trust_store_;
- OSP_DISALLOW_COPY_AND_ASSIGN(CastTrustStore);
-};
-
Error VerifyDeviceCert(const std::vector<std::string>& der_certs,
const DateTime& time,
std::unique_ptr<CertVerificationContext>* context,
diff --git a/cast/common/certificate/cast_cert_validator_internal.h b/cast/common/certificate/cast_cert_validator_internal.h
index c5b383ff..f8424b6d 100644
--- a/cast/common/certificate/cast_cert_validator_internal.h
+++ b/cast/common/certificate/cast_cert_validator_internal.h
@@ -26,6 +26,11 @@ bssl::UniquePtr<X509> MakeTrustAnchor(const uint8_t (&data)[N]) {
return bssl::UniquePtr<X509>{d2i_X509(nullptr, &dptr, N)};
}
+inline bssl::UniquePtr<X509> MakeTrustAnchor(const std::vector<uint8_t>& data) {
+ const uint8_t* dptr = data.data();
+ return bssl::UniquePtr<X509>{d2i_X509(nullptr, &dptr, data.size())};
+}
+
struct ConstDataSpan;
struct DateTime;
diff --git a/cast/common/certificate/cast_trust_store.cc b/cast/common/certificate/cast_trust_store.cc
new file mode 100644
index 00000000..8c9e5e24
--- /dev/null
+++ b/cast/common/certificate/cast_trust_store.cc
@@ -0,0 +1,66 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/common/certificate/cast_trust_store.h"
+
+#include "util/logging.h"
+
+namespace openscreen {
+namespace cast {
+namespace {
+
+// -------------------------------------------------------------------------
+// Cast trust anchors.
+// -------------------------------------------------------------------------
+
+// There are two trusted roots for Cast certificate chains:
+//
+// (1) CN=Cast Root CA (kCastRootCaDer)
+// (2) CN=Eureka Root CA (kEurekaRootCaDer)
+//
+// These constants are defined by the files included next:
+
+#include "cast/common/certificate/cast_root_ca_cert_der-inc.h"
+#include "cast/common/certificate/eureka_root_ca_der-inc.h"
+
+} // namespace
+
+// static
+CastTrustStore* CastTrustStore::GetInstance() {
+ if (!store_) {
+ store_ = new CastTrustStore();
+ }
+ return store_;
+}
+
+// static
+void CastTrustStore::ResetInstance() {
+ delete store_;
+ store_ = nullptr;
+}
+
+// static
+CastTrustStore* CastTrustStore::CreateInstanceForTest(
+ const std::vector<uint8_t>& trust_anchor_der) {
+ OSP_DCHECK(!store_);
+ store_ = new CastTrustStore(trust_anchor_der);
+ return store_;
+}
+
+CastTrustStore::CastTrustStore() {
+ trust_store_.certs.emplace_back(MakeTrustAnchor(kCastRootCaDer));
+ trust_store_.certs.emplace_back(MakeTrustAnchor(kEurekaRootCaDer));
+}
+
+CastTrustStore::CastTrustStore(const std::vector<uint8_t>& trust_anchor_der) {
+ trust_store_.certs.emplace_back(MakeTrustAnchor(trust_anchor_der));
+}
+
+CastTrustStore::~CastTrustStore() = default;
+
+// static
+CastTrustStore* CastTrustStore::store_ = nullptr;
+
+} // namespace cast
+} // namespace openscreen
diff --git a/cast/common/certificate/cast_trust_store.h b/cast/common/certificate/cast_trust_store.h
new file mode 100644
index 00000000..8aac9d39
--- /dev/null
+++ b/cast/common/certificate/cast_trust_store.h
@@ -0,0 +1,39 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_COMMON_CERTIFICATE_CAST_TRUST_STORE_H_
+#define CAST_COMMON_CERTIFICATE_CAST_TRUST_STORE_H_
+
+#include <vector>
+
+#include "cast/common/certificate/cast_cert_validator_internal.h"
+
+namespace openscreen {
+namespace cast {
+
+class CastTrustStore {
+ public:
+ static CastTrustStore* GetInstance();
+ static void ResetInstance();
+
+ static CastTrustStore* CreateInstanceForTest(
+ const std::vector<uint8_t>& trust_anchor_der);
+
+ CastTrustStore();
+ CastTrustStore(const std::vector<uint8_t>& trust_anchor_der);
+ CastTrustStore(const CastTrustStore&) = delete;
+ ~CastTrustStore();
+ CastTrustStore& operator=(const CastTrustStore&) = delete;
+
+ TrustStore* trust_store() { return &trust_store_; }
+
+ private:
+ static CastTrustStore* store_;
+ TrustStore trust_store_;
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_COMMON_CERTIFICATE_CAST_TRUST_STORE_H_
diff --git a/cast/receiver/channel/receiver_socket_factory.cc b/cast/receiver/channel/receiver_socket_factory.cc
index d5e3506c..f0a642e7 100644
--- a/cast/receiver/channel/receiver_socket_factory.cc
+++ b/cast/receiver/channel/receiver_socket_factory.cc
@@ -39,6 +39,8 @@ void ReceiverSocketFactory::OnConnectionFailed(
TlsConnectionFactory* factory,
const IPEndpoint& remote_address) {
OSP_DVLOG << "Receiving connection from endpoint failed: " << remote_address;
+ client_->OnError(this, Error(Error::Code::kConnectionFailed,
+ "Accepting connection failed."));
}
void ReceiverSocketFactory::OnError(TlsConnectionFactory* factory,
diff --git a/cast/receiver/channel/receiver_socket_factory.h b/cast/receiver/channel/receiver_socket_factory.h
index 71a66cb2..d8bde8ca 100644
--- a/cast/receiver/channel/receiver_socket_factory.h
+++ b/cast/receiver/channel/receiver_socket_factory.h
@@ -25,6 +25,7 @@ class ReceiverSocketFactory final : public TlsConnectionFactory::Client {
};
// |client| and |socket_client| must outlive |this|.
+ // TODO(btolsch): Add TaskRunner argument just for sequence checking.
ReceiverSocketFactory(Client* client, CastSocket::Client* socket_client);
~ReceiverSocketFactory();
diff --git a/cast/sender/channel/sender_socket_factory.cc b/cast/sender/channel/sender_socket_factory.cc
index b2b036b1..bf89de88 100644
--- a/cast/sender/channel/sender_socket_factory.cc
+++ b/cast/sender/channel/sender_socket_factory.cc
@@ -33,7 +33,9 @@ SenderSocketFactory::SenderSocketFactory(Client* client,
OSP_DCHECK(task_runner);
}
-SenderSocketFactory::~SenderSocketFactory() = default;
+SenderSocketFactory::~SenderSocketFactory() {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+}
void SenderSocketFactory::set_factory(TlsConnectionFactory* factory) {
OSP_DCHECK(factory);
@@ -43,6 +45,7 @@ void SenderSocketFactory::set_factory(TlsConnectionFactory* factory) {
void SenderSocketFactory::Connect(const IPEndpoint& endpoint,
DeviceMediaPolicy media_policy,
CastSocket::Client* client) {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
OSP_DCHECK(client);
auto it = FindPendingConnection(endpoint);
if (it == pending_connections_.end()) {
diff --git a/cast/test/BUILD.gn b/cast/test/BUILD.gn
index 5a3e6778..83950d2f 100644
--- a/cast/test/BUILD.gn
+++ b/cast/test/BUILD.gn
@@ -24,6 +24,27 @@ source_set("unittests") {
}
if (is_posix && !build_with_chromium) {
+ source_set("e2e_tests") {
+ testonly = true
+ sources = [
+ "cast_socket_e2e_test.cc",
+ ]
+
+ deps = [
+ "../../platform",
+ "../../third_party/abseil",
+ "../../third_party/boringssl",
+ "../../third_party/googletest:gtest",
+ "../../util",
+ "../common:certificate",
+ "../common:channel",
+ "../common:test_helpers",
+ "../receiver:channel",
+ "../receiver:test_helpers",
+ "../sender:channel",
+ ]
+ }
+
executable("make_crl_tests") {
testonly = true
sources = [
diff --git a/cast/test/DEPS b/cast/test/DEPS
index e43aeddb..fb626694 100644
--- a/cast/test/DEPS
+++ b/cast/test/DEPS
@@ -2,4 +2,6 @@ include_rules = [
'+cast/common',
'+cast/receiver',
'+cast/sender',
+
+ '+platform/impl',
]
diff --git a/cast/test/cast_socket_e2e_test.cc b/cast/test/cast_socket_e2e_test.cc
new file mode 100644
index 00000000..ae5331e1
--- /dev/null
+++ b/cast/test/cast_socket_e2e_test.cc
@@ -0,0 +1,345 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <openssl/evp.h>
+#include <openssl/mem.h>
+
+#include <atomic>
+#include <chrono>
+
+#include "cast/common/certificate/cast_trust_store.h"
+#include "cast/common/certificate/testing/test_helpers.h"
+#include "cast/common/channel/cast_socket.h"
+#include "cast/common/channel/connection_namespace_handler.h"
+#include "cast/common/channel/message_util.h"
+#include "cast/common/channel/virtual_connection_manager.h"
+#include "cast/common/channel/virtual_connection_router.h"
+#include "cast/receiver/channel/device_auth_namespace_handler.h"
+#include "cast/receiver/channel/receiver_socket_factory.h"
+#include "cast/receiver/channel/testing/device_auth_test_helpers.h"
+#include "cast/sender/channel/sender_socket_factory.h"
+#include "gtest/gtest.h"
+#include "platform/api/tls_connection_factory.h"
+#include "platform/base/tls_connect_options.h"
+#include "platform/base/tls_credentials.h"
+#include "platform/base/tls_listen_options.h"
+#include "platform/impl/logging.h"
+#include "platform/impl/network_interface.h"
+#include "platform/impl/platform_client_posix.h"
+#include "util/crypto/certificate_utils.h"
+#include "util/logging.h"
+#include "util/serial_delete_ptr.h"
+
+namespace openscreen {
+namespace cast {
+
+constexpr auto kCertificateDuration = std::chrono::hours(24);
+
+class SenderSocketsClient final
+ : public SenderSocketFactory::Client,
+ public VirtualConnectionRouter::SocketErrorHandler {
+ public:
+ SenderSocketsClient(VirtualConnectionRouter* router) : router_(router) {}
+ ~SenderSocketsClient() = default;
+
+ CastSocket* socket() const { return socket_; }
+
+ // SenderSocketFactory::Client overrides.
+ void OnConnected(SenderSocketFactory* factory,
+ const IPEndpoint& endpoint,
+ std::unique_ptr<CastSocket> socket) {
+ OSP_DCHECK(!socket_);
+ OSP_LOG_INFO << "\tSender connected to endpoint: " << endpoint;
+ socket_ = socket.get();
+ router_->TakeSocket(this, std::move(socket));
+ }
+
+ void OnError(SenderSocketFactory* factory,
+ const IPEndpoint& endpoint,
+ Error error) override {
+ OSP_NOTREACHED() << error;
+ }
+
+ // VirtualConnectionRouter::SocketErrorHandler overrides.
+ void OnClose(CastSocket* socket) override {}
+ void OnError(CastSocket* socket, Error error) override {
+ OSP_NOTREACHED() << error;
+ }
+
+ private:
+ VirtualConnectionRouter* const router_;
+ std::atomic<CastSocket*> socket_{nullptr};
+};
+
+class ReceiverSocketsClient final
+ : public ReceiverSocketFactory::Client,
+ public VirtualConnectionRouter::SocketErrorHandler {
+ public:
+ explicit ReceiverSocketsClient(VirtualConnectionRouter* router)
+ : router_(router) {}
+ ~ReceiverSocketsClient() = default;
+
+ const IPEndpoint& endpoint() const { return endpoint_; }
+ CastSocket* socket() const { return socket_; }
+
+ // ReceiverSocketFactory::Client overrides.
+ void OnConnected(ReceiverSocketFactory* factory,
+ const IPEndpoint& endpoint,
+ std::unique_ptr<CastSocket> socket) override {
+ OSP_DCHECK(!socket_);
+ OSP_LOG_INFO << "\tReceiver got connection from endpoint: " << endpoint;
+ endpoint_ = endpoint;
+ socket_ = socket.get();
+ router_->TakeSocket(this, std::move(socket));
+ }
+
+ void OnError(ReceiverSocketFactory* factory, Error error) override {
+ OSP_NOTREACHED() << error;
+ }
+
+ // VirtualConnectionRouter::SocketErrorHandler overrides.
+ void OnClose(CastSocket* socket) override {}
+ void OnError(CastSocket* socket, Error error) override {
+ OSP_NOTREACHED() << error;
+ }
+
+ private:
+ VirtualConnectionRouter* router_;
+ IPEndpoint endpoint_;
+ std::atomic<CastSocket*> socket_{nullptr};
+};
+
+class CastSocketE2ETest : public ::testing::Test {
+ public:
+ void SetUp() override {
+ SetLogLevel(LogLevel::kInfo);
+
+ PlatformClientPosix::Create(Clock::duration{50}, Clock::duration{50});
+ task_runner_ = PlatformClientPosix::GetInstance()->GetTaskRunner();
+
+ sender_router_ = MakeSerialDelete<VirtualConnectionRouter>(
+ task_runner_, &sender_vc_manager_);
+ sender_client_ =
+ std::make_unique<SenderSocketsClient>(sender_router_.get());
+ sender_factory_ = MakeSerialDelete<SenderSocketFactory>(
+ task_runner_, sender_client_.get(), task_runner_);
+ sender_tls_factory_ = SerialDeletePtr<TlsConnectionFactory>(
+ task_runner_,
+ TlsConnectionFactory::CreateFactory(sender_factory_.get(), task_runner_)
+ .release());
+ sender_factory_->set_factory(sender_tls_factory_.get());
+
+ // NOTE: Device cert chain generation.
+ bssl::UniquePtr<EVP_PKEY> root_key = GenerateRsaKeyPair();
+ bssl::UniquePtr<EVP_PKEY> intermediate_key = GenerateRsaKeyPair();
+ bssl::UniquePtr<EVP_PKEY> device_key = GenerateRsaKeyPair();
+ ASSERT_TRUE(root_key);
+ ASSERT_TRUE(intermediate_key);
+ ASSERT_TRUE(device_key);
+
+ ErrorOr<bssl::UniquePtr<X509>> root_cert_or_error =
+ CreateCertificateForTest("Cast Root CA", kCertificateDuration,
+ *root_key, GetWallTimeSinceUnixEpoch(), true);
+ ASSERT_TRUE(root_cert_or_error);
+ bssl::UniquePtr<X509> root_cert = std::move(root_cert_or_error.value());
+
+ ErrorOr<bssl::UniquePtr<X509>> intermediate_cert_or_error =
+ CreateCertificateForTest("Cast Intermediate", kCertificateDuration,
+ *intermediate_key, GetWallTimeSinceUnixEpoch(),
+ true, root_cert.get(), root_key.get());
+ ASSERT_TRUE(intermediate_cert_or_error);
+ bssl::UniquePtr<X509> intermediate_cert =
+ std::move(intermediate_cert_or_error.value());
+
+ ErrorOr<bssl::UniquePtr<X509>> device_cert_or_error =
+ CreateCertificateForTest("Test Device", kCertificateDuration,
+ *device_key, GetWallTimeSinceUnixEpoch(),
+ false, intermediate_cert.get(),
+ intermediate_key.get());
+ ASSERT_TRUE(device_cert_or_error);
+ bssl::UniquePtr<X509> device_cert = std::move(device_cert_or_error.value());
+
+ // NOTE: Device cert chain plumbing + serialization.
+ receiver_creds_provider_.device_creds.private_key = std::move(device_key);
+
+ int cert_length = i2d_X509(device_cert.get(), nullptr);
+ std::string cert_serial(cert_length, 0);
+ uint8_t* out = reinterpret_cast<uint8_t*>(&cert_serial[0]);
+ i2d_X509(device_cert.get(), &out);
+ receiver_creds_provider_.device_creds.certs.emplace_back(
+ std::move(cert_serial));
+
+ cert_length = i2d_X509(intermediate_cert.get(), nullptr);
+ cert_serial.resize(cert_length);
+ out = reinterpret_cast<uint8_t*>(&cert_serial[0]);
+ i2d_X509(intermediate_cert.get(), &out);
+ receiver_creds_provider_.device_creds.certs.emplace_back(
+ std::move(cert_serial));
+
+ cert_length = i2d_X509(root_cert.get(), nullptr);
+ std::vector<uint8_t> trust_anchor_der(cert_length);
+ out = &trust_anchor_der[0];
+ i2d_X509(root_cert.get(), &out);
+ CastTrustStore::CreateInstanceForTest(trust_anchor_der);
+
+ // NOTE: TLS key pair + certificate generation.
+ bssl::UniquePtr<EVP_PKEY> tls_key = GenerateRsaKeyPair();
+ ASSERT_EQ(EVP_PKEY_id(tls_key.get()), EVP_PKEY_RSA);
+ ErrorOr<bssl::UniquePtr<X509>> tls_cert_or_error =
+ CreateCertificate("Test Device TLS", kCertificateDuration, *tls_key,
+ GetWallTimeSinceUnixEpoch());
+ ASSERT_TRUE(tls_cert_or_error);
+ bssl::UniquePtr<X509> tls_cert = std::move(tls_cert_or_error.value());
+
+ // NOTE: TLS private key serialization.
+ RSA* rsa_key = EVP_PKEY_get0_RSA(tls_key.get());
+ size_t pkey_len = 0;
+ uint8_t* pkey_bytes = nullptr;
+ ASSERT_TRUE(RSA_private_key_to_bytes(&pkey_bytes, &pkey_len, rsa_key));
+ ASSERT_GT(pkey_len, 0u);
+ std::vector<uint8_t> tls_key_serial(pkey_bytes, pkey_bytes + pkey_len);
+ OPENSSL_free(pkey_bytes);
+ receiver_tls_creds_.der_rsa_private_key = std::move(tls_key_serial);
+
+ // NOTE: TLS public key serialization.
+ pkey_len = 0;
+ pkey_bytes = nullptr;
+ ASSERT_TRUE(RSA_public_key_to_bytes(&pkey_bytes, &pkey_len, rsa_key));
+ ASSERT_GT(pkey_len, 0u);
+ std::vector<uint8_t> tls_pub_serial(pkey_bytes, pkey_bytes + pkey_len);
+ OPENSSL_free(pkey_bytes);
+ receiver_tls_creds_.der_rsa_public_key = std::move(tls_pub_serial);
+
+ // NOTE: TLS cert serialization.
+ cert_length = 0;
+ cert_length = i2d_X509(tls_cert.get(), nullptr);
+ ASSERT_GT(cert_length, 0);
+ std::vector<uint8_t> tls_cert_serial(cert_length);
+ out = &tls_cert_serial[0];
+ i2d_X509(tls_cert.get(), &out);
+ receiver_creds_provider_.tls_cert_der = tls_cert_serial;
+ receiver_tls_creds_.der_x509_cert = std::move(tls_cert_serial);
+
+ auth_handler_ = MakeSerialDelete<DeviceAuthNamespaceHandler>(
+ task_runner_, &receiver_creds_provider_);
+ receiver_router_ = MakeSerialDelete<VirtualConnectionRouter>(
+ task_runner_, &receiver_vc_manager_);
+ receiver_router_->AddHandlerForLocalId(kPlatformReceiverId,
+ auth_handler_.get());
+ receiver_client_ =
+ std::make_unique<ReceiverSocketsClient>(receiver_router_.get());
+ receiver_factory_ = MakeSerialDelete<ReceiverSocketFactory>(
+ task_runner_, receiver_client_.get(), receiver_router_.get());
+
+ receiver_tls_factory_ = SerialDeletePtr<TlsConnectionFactory>(
+ task_runner_, TlsConnectionFactory::CreateFactory(
+ receiver_factory_.get(), task_runner_)
+ .release());
+ }
+
+ void TearDown() override {
+ OSP_LOG_INFO << "Shutting down";
+ sender_router_.reset();
+ receiver_router_.reset();
+ receiver_tls_factory_.reset();
+ receiver_factory_.reset();
+ auth_handler_.reset();
+ sender_tls_factory_.reset();
+ sender_factory_.reset();
+ CastTrustStore::ResetInstance();
+ PlatformClientPosix::ShutDown();
+ }
+
+ protected:
+ IPAddress GetLoopbackV4Address() {
+ absl::optional<InterfaceInfo> loopback = GetLoopbackInterfaceForTesting();
+ OSP_DCHECK(loopback);
+ auto address = loopback->GetIpAddressV4();
+ OSP_DCHECK(address);
+ return address;
+ }
+
+ IPAddress GetLoopbackV6Address() {
+ absl::optional<InterfaceInfo> loopback = GetLoopbackInterfaceForTesting();
+ OSP_DCHECK(loopback);
+ auto address = loopback->GetIpAddressV6();
+ OSP_DCHECK(address);
+ return address;
+ }
+
+ void WaitForCastSocket() {
+ int attempts = 1;
+ constexpr int kMaxAttempts = 8;
+ constexpr std::chrono::milliseconds kSocketWaitDelay(250);
+ do {
+ OSP_LOG_INFO << "\tChecking for CastSocket, attempt " << attempts << "/"
+ << kMaxAttempts;
+ if (sender_client_->socket()) {
+ break;
+ }
+ std::this_thread::sleep_for(kSocketWaitDelay);
+ } while (attempts++ < kMaxAttempts);
+ ASSERT_TRUE(sender_client_->socket());
+ }
+
+ void Connect(const IPAddress& address) {
+ uint16_t port = 65321;
+ OSP_LOG_INFO << "\tStarting socket factories";
+ task_runner_->PostTask([this, &address, port]() {
+ OSP_LOG_INFO << "\tReceiver TLS factory Listen()";
+ receiver_tls_factory_->SetListenCredentials(receiver_tls_creds_);
+ receiver_tls_factory_->Listen(IPEndpoint{address, port},
+ TlsListenOptions{1u});
+ });
+
+ task_runner_->PostTask([this, &address, port]() {
+ OSP_LOG_INFO << "\tSender CastSocket factory Connect()";
+ sender_factory_->Connect(IPEndpoint{address, port},
+ SenderSocketFactory::DeviceMediaPolicy::kNone,
+ sender_router_.get());
+ });
+
+ WaitForCastSocket();
+ }
+
+ TaskRunner* task_runner_;
+
+ // NOTE: Sender components.
+ VirtualConnectionManager sender_vc_manager_;
+ SerialDeletePtr<VirtualConnectionRouter> sender_router_;
+ std::unique_ptr<SenderSocketsClient> sender_client_;
+ SerialDeletePtr<SenderSocketFactory> sender_factory_;
+ SerialDeletePtr<TlsConnectionFactory> sender_tls_factory_;
+
+ // NOTE: Receiver components.
+ VirtualConnectionManager receiver_vc_manager_;
+ SerialDeletePtr<VirtualConnectionRouter> receiver_router_;
+ StaticCredentialsProvider receiver_creds_provider_;
+ SerialDeletePtr<DeviceAuthNamespaceHandler> auth_handler_;
+ std::unique_ptr<ReceiverSocketsClient> receiver_client_;
+ SerialDeletePtr<ReceiverSocketFactory> receiver_factory_;
+ TlsCredentials receiver_tls_creds_;
+ SerialDeletePtr<TlsConnectionFactory> receiver_tls_factory_;
+};
+
+// These test the most basic setup of a complete CastSocket. This means
+// constructing both a SenderSocketFactory and ReceiverSocketFactory, making a
+// TLS connection to a known port over the loopback device, and checking device
+// authentication.
+TEST_F(CastSocketE2ETest, ConnectV4) {
+ OSP_LOG_INFO << "Getting loopback IPv4 address";
+ IPAddress loopback_address = GetLoopbackV4Address();
+ OSP_LOG_INFO << "Connecting CastSockets";
+ Connect(loopback_address);
+}
+
+TEST_F(CastSocketE2ETest, ConnectV6) {
+ OSP_LOG_INFO << "Getting loopback IPv6 address";
+ IPAddress loopback_address = GetLoopbackV6Address();
+ OSP_LOG_INFO << "Connecting CastSockets";
+ Connect(loopback_address);
+}
+
+} // namespace cast
+} // namespace openscreen