diff options
45 files changed, 659 insertions, 321 deletions
diff --git a/build/config/BUILD.gn b/build/config/BUILD.gn index a68031e8..7309ad8b 100644 --- a/build/config/BUILD.gn +++ b/build/config/BUILD.gn @@ -254,3 +254,12 @@ config("sysroot_runtime_libraries") { } } } + +config("operating_system_defines") { + defines = [] + if (is_linux) { + defines += [ "OS_LINUX" ] + } else if (is_mac) { + defines += [ "MAC_OSX" ] + } +} diff --git a/build/config/BUILDCONFIG.gn b/build/config/BUILDCONFIG.gn index 0fa96935..3b1a06a1 100644 --- a/build/config/BUILDCONFIG.gn +++ b/build/config/BUILDCONFIG.gn @@ -161,6 +161,7 @@ _shared_binary_target_configs = [ "//build/config:compiler_cpu_abi", "//build/config:default_optimization", "//build/config:sysroot_runtime_libraries", + "//build/config:operating_system_defines", ] # Apply that default list to the binary target types. diff --git a/cast/README.md b/cast/README.md index 1b890c5b..a501703b 100644 --- a/cast/README.md +++ b/cast/README.md @@ -3,3 +3,31 @@ libcast is an open source implementation of the Cast procotol supporting Cast applications and streaming to Cast-compatible devices. +## Using the standalone implementations + +To run the standalone sender and receivers together, first you need to install +the following dependencies: FFMPEG, LibVPX, LibOpus, LibSDL2, as well as their +headers (frequently in a seperate -dev package). From here, you need to generate +a RSA private key and create a self signed certificate with that key. + +From there, after building Open Screen the `cast_sender` and `cast_receiver` +executables should be ready to use: +``` + $ /path/to/out/Default/cast_sender -s <certificate> <path/to/video> + ... + $ /path/to/out/Default/cast_receiver <interface> -p <private_key> -s <certificate> +``` + +When running on Mac OS X, also pass the `-x` flag to the cast receiver to +disable DNS-SD/mDNS, since Open Screen does not currently integrate with +Bonjour. + +When connecting to a receiver that's not running on the loopback interface +(typically `lo` or `lo0`), pass the `-r <receiver IP endpoint>` flag to the +`cast_sender` binary. + +An archive containing test running scripts, a video, and a generated RSA +key and certificate is available from google storage. Note that it may require +modification to work on your specific work environment: + +https://storage.googleapis.com/openscreen_standalone/cast_streaming_demo.tar.gz diff --git a/cast/common/certificate/cast_cert_validator_internal.cc b/cast/common/certificate/cast_cert_validator_internal.cc index 569d22b1..e4c689f8 100644 --- a/cast/common/certificate/cast_cert_validator_internal.cc +++ b/cast/common/certificate/cast_cert_validator_internal.cc @@ -14,6 +14,7 @@ #include <vector> #include "cast/common/certificate/types.h" +#include "util/crypto/pem_helpers.h" #include "util/osp_logging.h" namespace openscreen { @@ -95,7 +96,8 @@ bssl::UniquePtr<ASN1_BIT_STRING> GetKeyUsage(X509* cert) { Error::Code VerifyCertificateChain(const std::vector<CertPathStep>& path, uint32_t step_index, - const DateTime& time) { + const DateTime& time, + TrustStore::Mode mode) { // Default max path length is the number of intermediate certificates. int max_pathlen = path.size() - 2; @@ -132,33 +134,37 @@ Error::Code VerifyCertificateChain(const std::vector<CertPathStep>& path, } } - // Check that basicConstraints is present, specifies the CA bit, and use - // pathLenConstraint if present. - const int basic_constraints_index = - X509_get_ext_by_NID(issuer, NID_basic_constraints, -1); - if (basic_constraints_index == -1) { - return Error::Code::kErrCertsVerifyGeneric; - } - X509_EXTENSION* const basic_constraints_extension = - X509_get_ext(issuer, basic_constraints_index); - bssl::UniquePtr<BASIC_CONSTRAINTS> basic_constraints{ - reinterpret_cast<BASIC_CONSTRAINTS*>( - X509V3_EXT_d2i(basic_constraints_extension))}; + // Certificates issued by a valid CA authority shall have the + // basicConstraints property present with the CA bit set. Self-signed + // certificates do not have this property present. + if (mode == TrustStore::Mode::kStrict) { + const int basic_constraints_index = + X509_get_ext_by_NID(issuer, NID_basic_constraints, -1); + if (basic_constraints_index == -1) { + return Error::Code::kErrCertsVerifyGeneric; + } - if (!basic_constraints || !basic_constraints->ca) { - return Error::Code::kErrCertsVerifyGeneric; - } + X509_EXTENSION* const basic_constraints_extension = + X509_get_ext(issuer, basic_constraints_index); + bssl::UniquePtr<BASIC_CONSTRAINTS> basic_constraints{ + reinterpret_cast<BASIC_CONSTRAINTS*>( + X509V3_EXT_d2i(basic_constraints_extension))}; - if (basic_constraints->pathlen) { - if (basic_constraints->pathlen->length != 1) { + if (!basic_constraints || !basic_constraints->ca) { return Error::Code::kErrCertsVerifyGeneric; - } else { - const int pathlen = *basic_constraints->pathlen->data; - if (pathlen < 0) { + } + + if (basic_constraints->pathlen) { + if (basic_constraints->pathlen->length != 1) { return Error::Code::kErrCertsVerifyGeneric; - } - if (pathlen < max_pathlen) { - max_pathlen = pathlen; + } else { + const int pathlen = *basic_constraints->pathlen->data; + if (pathlen < 0) { + return Error::Code::kErrCertsVerifyGeneric; + } + if (pathlen < max_pathlen) { + max_pathlen = pathlen; + } } } } @@ -355,6 +361,21 @@ bool GetCertValidTimeRange(X509* cert, return times_valid; } +// static +TrustStore TrustStore::CreateInstanceFromPemFile(absl::string_view file_path, + TrustStore::Mode mode) { + TrustStore store; + + std::vector<std::string> certs = ReadCertificatesFromPemFile(file_path); + for (const auto& der_cert : certs) { + const uint8_t* data = (const uint8_t*)der_cert.data(); + store.certs.emplace_back(d2i_X509(nullptr, &data, der_cert.size())); + } + + store.mode = mode; + return store; +} + bool VerifySignedData(const EVP_MD* digest, EVP_PKEY* public_key, const ConstDataSpan& data, @@ -374,7 +395,7 @@ Error FindCertificatePath(const std::vector<std::string>& der_certs, CertificatePathResult* result_path, TrustStore* trust_store) { if (der_certs.empty()) { - return Error::Code::kErrCertsMissing; + return Error(Error::Code::kErrCertsMissing, "Missing DER certificates"); } bssl::UniquePtr<X509>& target_cert = result_path->target_cert; @@ -500,7 +521,7 @@ Error FindCertificatePath(const std::vector<std::string>& der_certs, if (last_error == Error::Code::kNone) { OSP_DVLOG << "FindCertificatePath: Failed after trying all " "certificate paths, no matches"; - return Error::Code::kErrCertsVerifyGeneric; + return Error::Code::kErrCertsVerifyUntrustedCert; } return last_error; } else { @@ -512,7 +533,8 @@ Error FindCertificatePath(const std::vector<std::string>& der_certs, } if (path_cert_in_trust_store) { - last_error = VerifyCertificateChain(path, path_index, time); + last_error = + VerifyCertificateChain(path, path_index, time, trust_store->mode); if (last_error != Error::Code::kNone) { CertPathStep& last_step = path[path_index++]; trust_store_index = last_step.trust_store_index; diff --git a/cast/common/certificate/cast_cert_validator_internal.h b/cast/common/certificate/cast_cert_validator_internal.h index f8424b6d..9264418e 100644 --- a/cast/common/certificate/cast_cert_validator_internal.h +++ b/cast/common/certificate/cast_cert_validator_internal.h @@ -7,15 +7,32 @@ #include <openssl/x509.h> +#include <string> #include <vector> +#include "absl/strings/string_view.h" #include "platform/base/error.h" - namespace openscreen { namespace cast { struct TrustStore { + enum class Mode { + // In strict mode, only certificates signed by a CA will be accepted as + // part of authentication. Note that if a self-signed certificate is placed + // in a strict mode TrustStore, it cannot be used for authentication. + kStrict, + + // In allow self signed mode, certificates signed by an arbitrary private + // key that have been placed in this trust store will be allowed. Note + // that certificates must still otherwise be valid. + kAllowSelfSigned + }; + + static TrustStore CreateInstanceFromPemFile(absl::string_view file_path, + Mode mode = Mode::kStrict); + std::vector<bssl::UniquePtr<X509>> certs; + Mode mode = Mode::kStrict; }; // Adds a trust anchor given a DER-encoded certificate from static diff --git a/cast/common/certificate/cast_cert_validator_unittest.cc b/cast/common/certificate/cast_cert_validator_unittest.cc index f7e21d84..53b6f05f 100644 --- a/cast/common/certificate/cast_cert_validator_unittest.cc +++ b/cast/common/certificate/cast_cert_validator_unittest.cc @@ -12,6 +12,7 @@ #include "gtest/gtest.h" #include "openssl/pem.h" #include "platform/test/paths.h" +#include "util/crypto/pem_helpers.h" namespace openscreen { namespace cast { @@ -51,8 +52,7 @@ void RunTest(Error::Code expected_result, const DateTime& time, TrustStoreDependency trust_store_dependency, const std::string& optional_signed_data_file_name) { - std::vector<std::string> certs = - testing::ReadCertificatesFromPemFile(certs_file_name); + std::vector<std::string> certs = ReadCertificatesFromPemFile(certs_file_name); TrustStore* trust_store; std::unique_ptr<TrustStore> fake_trust_store; @@ -94,7 +94,10 @@ void RunTest(Error::Code expected_result, // Test that the context is good. EXPECT_EQ(expected_common_name, context->GetCommonName()); -#define DATA_SPAN_FROM_LITERAL(s) ConstDataSpan{(uint8_t*)s, sizeof(s) - 1} +#define DATA_SPAN_FROM_LITERAL(s) \ + ConstDataSpan{const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(s)), \ + sizeof(s) - 1} + // Test verification of some invalid signatures. EXPECT_FALSE(context->VerifySignatureOverData( DATA_SPAN_FROM_LITERAL("bogus signature"), @@ -233,7 +236,7 @@ TEST(VerifyCastDeviceCertTest, Fugu) { // This is invalid because it does not chain to a trust anchor. TEST(VerifyCastDeviceCertTest, Unchained) { std::string data_path = GetSpecificTestDataPath(); - RunTest(Error::Code::kErrCertsVerifyGeneric, "", + RunTest(Error::Code::kErrCertsVerifyUntrustedCert, "", CastDeviceCertPolicy::kUnrestricted, data_path + "certificates/unchained.pem", AprilFirst2016(), TRUST_STORE_BUILTIN, ""); diff --git a/cast/common/certificate/cast_crl_unittest.cc b/cast/common/certificate/cast_crl_unittest.cc index fe65cce3..c4d3bfc4 100644 --- a/cast/common/certificate/cast_crl_unittest.cc +++ b/cast/common/certificate/cast_crl_unittest.cc @@ -99,9 +99,11 @@ bool RunTest(const DeviceCertTest& test_case) { std::unique_ptr<TrustStore> crl_trust_store; std::unique_ptr<TrustStore> cast_trust_store; if (test_case.use_test_trust_anchors()) { - crl_trust_store = testing::CreateTrustStoreFromPemFile( + crl_trust_store = std::make_unique<TrustStore>(); + cast_trust_store = std::make_unique<TrustStore>(); + *crl_trust_store = TrustStore::CreateInstanceFromPemFile( GetSpecificTestDataPath() + "certificates/cast_crl_test_root_ca.pem"); - cast_trust_store = testing::CreateTrustStoreFromPemFile( + *cast_trust_store = TrustStore::CreateInstanceFromPemFile( GetSpecificTestDataPath() + "certificates/cast_test_root_ca.pem"); EXPECT_FALSE(crl_trust_store->certs.empty()); diff --git a/cast/common/certificate/cast_trust_store.cc b/cast/common/certificate/cast_trust_store.cc index 93db49ba..d8ec513c 100644 --- a/cast/common/certificate/cast_trust_store.cc +++ b/cast/common/certificate/cast_trust_store.cc @@ -4,6 +4,9 @@ #include "cast/common/certificate/cast_trust_store.h" +#include <utility> + +#include "util/crypto/pem_helpers.h" #include "util/osp_logging.h" namespace openscreen { @@ -48,6 +51,16 @@ CastTrustStore* CastTrustStore::CreateInstanceForTest( return store_; } +// static +CastTrustStore* CastTrustStore::CreateInstanceFromPemFile( + absl::string_view file_path, + TrustStore::Mode mode) { + OSP_DCHECK(!store_); + store_ = new CastTrustStore(); + store_->trust_store_ = TrustStore::CreateInstanceFromPemFile(file_path, mode); + return store_; +} + CastTrustStore::CastTrustStore() { trust_store_.certs.emplace_back(MakeTrustAnchor(kCastRootCaDer)); trust_store_.certs.emplace_back(MakeTrustAnchor(kEurekaRootCaDer)); @@ -57,6 +70,9 @@ CastTrustStore::CastTrustStore(const std::vector<uint8_t>& trust_anchor_der) { trust_store_.certs.emplace_back(MakeTrustAnchor(trust_anchor_der)); } +CastTrustStore::CastTrustStore(TrustStore trust_store) + : trust_store_(std::move(trust_store)) {} + CastTrustStore::~CastTrustStore() = default; // static diff --git a/cast/common/certificate/cast_trust_store.h b/cast/common/certificate/cast_trust_store.h index 801d9274..7bd75955 100644 --- a/cast/common/certificate/cast_trust_store.h +++ b/cast/common/certificate/cast_trust_store.h @@ -7,6 +7,7 @@ #include <vector> +#include "absl/strings/string_view.h" #include "cast/common/certificate/cast_cert_validator_internal.h" namespace openscreen { @@ -20,8 +21,13 @@ class CastTrustStore { static CastTrustStore* CreateInstanceForTest( const std::vector<uint8_t>& trust_anchor_der); + static CastTrustStore* CreateInstanceFromPemFile( + absl::string_view file_path, + TrustStore::Mode mode = TrustStore::Mode::kStrict); + CastTrustStore(); explicit CastTrustStore(const std::vector<uint8_t>& trust_anchor_der); + explicit CastTrustStore(TrustStore trust_store); CastTrustStore(const CastTrustStore&) = delete; ~CastTrustStore(); CastTrustStore& operator=(const CastTrustStore&) = delete; diff --git a/cast/common/certificate/testing/test_helpers.cc b/cast/common/certificate/testing/test_helpers.cc index 113a4bc4..c28269de 100644 --- a/cast/common/certificate/testing/test_helpers.cc +++ b/cast/common/certificate/testing/test_helpers.cc @@ -17,58 +17,6 @@ namespace openscreen { namespace cast { namespace testing { -std::vector<std::string> ReadCertificatesFromPemFile( - absl::string_view filename) { - FILE* fp = fopen(filename.data(), "r"); - if (!fp) { - return {}; - } - std::vector<std::string> certs; - char* name; - char* header; - unsigned char* data; - long length; - while (PEM_read(fp, &name, &header, &data, &length) == 1) { - if (absl::StartsWith(name, "CERTIFICATE")) { - certs.emplace_back((char*)data, length); - } - OPENSSL_free(name); - OPENSSL_free(header); - OPENSSL_free(data); - } - fclose(fp); - return certs; -} - -bssl::UniquePtr<EVP_PKEY> ReadKeyFromPemFile(absl::string_view filename) { - FILE* fp = fopen(filename.data(), "r"); - if (!fp) { - return nullptr; - } - bssl::UniquePtr<EVP_PKEY> pkey; - char* name; - char* header; - unsigned char* data; - long length; - while (PEM_read(fp, &name, &header, &data, &length) == 1) { - if (absl::StartsWith(name, "RSA PRIVATE KEY")) { - OSP_DCHECK(!pkey); - CBS cbs; - CBS_init(&cbs, data, length); - RSA* rsa = RSA_parse_private_key(&cbs); - if (rsa) { - pkey.reset(EVP_PKEY_new()); - EVP_PKEY_assign_RSA(pkey.get(), rsa); - } - } - OPENSSL_free(name); - OPENSSL_free(header); - OPENSSL_free(data); - } - fclose(fp); - return pkey; -} - SignatureTestData::SignatureTestData() : message{nullptr, 0}, sha1{nullptr, 0}, sha256{nullptr, 0} {} @@ -85,8 +33,9 @@ SignatureTestData ReadSignatureTestData(absl::string_view filename) { char* name; char* header; unsigned char* data; - long length; - while (PEM_read(fp, &name, &header, &data, &length) == 1) { + int64_t length; + while (PEM_read(fp, &name, &header, &data, + reinterpret_cast<long*>(&length)) == 1) { if (strcmp(name, "MESSAGE") == 0) { OSP_DCHECK(!result.message.data); result.message.data = data; @@ -112,19 +61,6 @@ SignatureTestData ReadSignatureTestData(absl::string_view filename) { return result; } -std::unique_ptr<TrustStore> CreateTrustStoreFromPemFile( - absl::string_view filename) { - std::unique_ptr<TrustStore> store = std::make_unique<TrustStore>(); - - std::vector<std::string> certs = - testing::ReadCertificatesFromPemFile(filename); - for (const auto& der_cert : certs) { - const uint8_t* data = (const uint8_t*)der_cert.data(); - store->certs.emplace_back(d2i_X509(nullptr, &data, der_cert.size())); - } - return store; -} - } // namespace testing } // namespace cast } // namespace openscreen diff --git a/cast/common/certificate/testing/test_helpers.h b/cast/common/certificate/testing/test_helpers.h index c1ff9a25..30715971 100644 --- a/cast/common/certificate/testing/test_helpers.h +++ b/cast/common/certificate/testing/test_helpers.h @@ -18,10 +18,6 @@ namespace openscreen { namespace cast { namespace testing { -std::vector<std::string> ReadCertificatesFromPemFile( - absl::string_view filename); -bssl::UniquePtr<EVP_PKEY> ReadKeyFromPemFile(absl::string_view filename); - class SignatureTestData { public: SignatureTestData(); @@ -34,9 +30,6 @@ class SignatureTestData { SignatureTestData ReadSignatureTestData(absl::string_view filename); -std::unique_ptr<TrustStore> CreateTrustStoreFromPemFile( - absl::string_view filename); - } // namespace testing } // namespace cast } // namespace openscreen diff --git a/cast/common/channel/cast_socket_message_port.cc b/cast/common/channel/cast_socket_message_port.cc index 8d255e6d..2b596830 100644 --- a/cast/common/channel/cast_socket_message_port.cc +++ b/cast/common/channel/cast_socket_message_port.cc @@ -6,12 +6,16 @@ #include <utility> +#include "cast/common/channel/message_util.h" #include "cast/common/channel/proto/cast_channel.pb.h" +#include "cast/common/channel/virtual_connection.h" namespace openscreen { namespace cast { -CastSocketMessagePort::CastSocketMessagePort() = default; +CastSocketMessagePort::CastSocketMessagePort(VirtualConnectionRouter* router) + : router_(router) {} + CastSocketMessagePort::~CastSocketMessagePort() = default; // NOTE: we assume here that this message port is already the client for @@ -20,7 +24,7 @@ CastSocketMessagePort::~CastSocketMessagePort() = default; // client. The consumer of this message port should call SetClient with the new // message port client after setting the socket. void CastSocketMessagePort::SetSocket(WeakPtr<CastSocket> socket) { - client_ = nullptr; + ResetClient(); socket_ = socket; } @@ -28,28 +32,55 @@ int CastSocketMessagePort::GetSocketId() { return socket_ ? socket_->socket_id() : -1; } -void CastSocketMessagePort::SetClient(MessagePort::Client* client) { +void CastSocketMessagePort::SetClient(MessagePort::Client* client, + std::string client_sender_id) { client_ = client; + client_sender_id_ = std::move(client_sender_id); + router_->AddHandlerForLocalId(client_sender_id_, this); +} + +void CastSocketMessagePort::ResetClient() { + client_ = nullptr; + router_->RemoveHandlerForLocalId(client_sender_id_); + client_sender_id_.clear(); } -void CastSocketMessagePort::PostMessage(const std::string& sender_id, - const std::string& message_namespace, - const std::string& message) { +void CastSocketMessagePort::PostMessage( + const std::string& destination_sender_id, + const std::string& message_namespace, + const std::string& message) { ::cast::channel::CastMessage cast_message; - cast_message.set_source_id(sender_id.data(), sender_id.size()); + cast_message.set_protocol_version(::cast::channel::CastMessage::CASTV2_1_0); cast_message.set_namespace_(message_namespace.data(), message_namespace.size()); + cast_message.set_source_id(client_sender_id_.data(), + client_sender_id_.size()); + cast_message.set_destination_id(destination_sender_id.data(), + destination_sender_id.size()); + cast_message.set_payload_type(::cast::channel::CastMessage::STRING); cast_message.set_payload_utf8(message.data(), message.size()); if (!socket_) { client_->OnError(Error::Code::kAlreadyClosed); return; } + + // TODO(jophba): migrate to using VirtualConnectionRouter::Send(). Error error = socket_->Send(cast_message); if (!error.ok()) { client_->OnError(error); } } +void CastSocketMessagePort::OnMessage(VirtualConnectionRouter* router, + CastSocket* socket, + ::cast::channel::CastMessage message) { + OSP_DCHECK(router == router_); + OSP_DCHECK(socket_.get() == socket); + OSP_DVLOG << "Received a cast socket message"; + client_->OnMessage(message.source_id(), message.namespace_(), + message.payload_utf8()); +} + } // namespace cast } // namespace openscreen diff --git a/cast/common/channel/cast_socket_message_port.h b/cast/common/channel/cast_socket_message_port.h index b2aeb96f..4dbd141c 100644 --- a/cast/common/channel/cast_socket_message_port.h +++ b/cast/common/channel/cast_socket_message_port.h @@ -9,6 +9,8 @@ #include <string> #include <vector> +#include "cast/common/channel/cast_message_handler.h" +#include "cast/common/channel/virtual_connection_router.h" #include "cast/common/public/cast_socket.h" #include "cast/common/public/message_port.h" #include "util/weak_ptr.h" @@ -16,9 +18,10 @@ namespace openscreen { namespace cast { -class CastSocketMessagePort : public MessagePort { +class CastSocketMessagePort : public MessagePort, public CastMessageHandler { public: - CastSocketMessagePort(); + // The router is expected to outlive this message port. + explicit CastSocketMessagePort(VirtualConnectionRouter* router); ~CastSocketMessagePort() override; void SetSocket(WeakPtr<CastSocket> socket); @@ -27,12 +30,21 @@ class CastSocketMessagePort : public MessagePort { int GetSocketId(); // MessagePort overrides. - void SetClient(MessagePort::Client* client) override; - void PostMessage(const std::string& sender_id, + void SetClient(MessagePort::Client* client, + std::string client_sender_id) override; + void ResetClient() override; + void PostMessage(const std::string& destination_sender_id, const std::string& message_namespace, const std::string& message) override; + // CastMessageHandler overrides. + void OnMessage(VirtualConnectionRouter* router, + CastSocket* socket, + ::cast::channel::CastMessage message) override; + private: + VirtualConnectionRouter* const router_; + std::string client_sender_id_; MessagePort::Client* client_ = nullptr; WeakPtr<CastSocket> socket_; }; diff --git a/cast/common/public/message_port.h b/cast/common/public/message_port.h index aa614167..0e62dfe6 100644 --- a/cast/common/public/message_port.h +++ b/cast/common/public/message_port.h @@ -14,21 +14,23 @@ namespace cast { // This interface is intended to provide an abstraction for communicating // cast messages across a pipe with guaranteed delivery. This is used to -// decouple the cast receiver session (and potentially other classes) from any +// decouple the cast streaming receiver and sender sessions from the // network implementation. class MessagePort { public: class Client { public: - virtual void OnMessage(const std::string& sender_id, + virtual void OnMessage(const std::string& source_sender_id, const std::string& message_namespace, const std::string& message) = 0; virtual void OnError(Error error) = 0; }; virtual ~MessagePort() = default; - virtual void SetClient(Client* client) = 0; - virtual void PostMessage(const std::string& sender_id, + virtual void SetClient(Client* client, std::string client_sender_id) = 0; + virtual void ResetClient() = 0; + + virtual void PostMessage(const std::string& destination_sender_id, const std::string& message_namespace, const std::string& message) = 0; }; diff --git a/cast/receiver/channel/static_credentials.cc b/cast/receiver/channel/static_credentials.cc index 9883e982..73a5d95f 100644 --- a/cast/receiver/channel/static_credentials.cc +++ b/cast/receiver/channel/static_credentials.cc @@ -5,7 +5,9 @@ #include "cast/receiver/channel/static_credentials.h" #include <openssl/mem.h> +#include <openssl/pem.h> +#include <cstdio> #include <memory> #include <string> #include <utility> @@ -19,47 +21,26 @@ namespace openscreen { namespace cast { namespace { +using FileUniquePtr = std::unique_ptr<FILE, decltype(&fclose)>; + constexpr int kThreeDaysInSeconds = 3 * 24 * 60 * 60; constexpr auto kCertificateDuration = std::chrono::seconds(kThreeDaysInSeconds); -} // namespace - -StaticCredentialsProvider::StaticCredentialsProvider() = default; -StaticCredentialsProvider::StaticCredentialsProvider( - DeviceCredentials device_creds, - std::vector<uint8_t> tls_cert_der) - : device_creds(std::move(device_creds)), - tls_cert_der(std::move(tls_cert_der)) {} - -StaticCredentialsProvider::StaticCredentialsProvider( - StaticCredentialsProvider&&) = default; -StaticCredentialsProvider& StaticCredentialsProvider::operator=( - StaticCredentialsProvider&&) = default; -StaticCredentialsProvider::~StaticCredentialsProvider() = default; - ErrorOr<GeneratedCredentials> GenerateCredentials( - absl::string_view device_certificate_id) { - GeneratedCredentials credentials; - - // Device cert chain generation. - bssl::UniquePtr<EVP_PKEY> root_key = GenerateRsaKeyPair(); + std::string device_certificate_id, + EVP_PKEY* root_key, + X509* root_cert) { + OSP_CHECK(root_key); + OSP_CHECK(root_cert); bssl::UniquePtr<EVP_PKEY> intermediate_key = GenerateRsaKeyPair(); bssl::UniquePtr<EVP_PKEY> device_key = GenerateRsaKeyPair(); - OSP_CHECK(root_key); OSP_CHECK(intermediate_key); OSP_CHECK(device_key); - ErrorOr<bssl::UniquePtr<X509>> root_cert_or_error = - CreateSelfSignedX509Certificate("Cast Root CA", kCertificateDuration, - *root_key, GetWallTimeSinceUnixEpoch(), - true); - OSP_CHECK(root_cert_or_error); - bssl::UniquePtr<X509> root_cert = std::move(root_cert_or_error.value()); - ErrorOr<bssl::UniquePtr<X509>> intermediate_cert_or_error = CreateSelfSignedX509Certificate( "Cast Intermediate", kCertificateDuration, *intermediate_key, - GetWallTimeSinceUnixEpoch(), true, root_cert.get(), root_key.get()); + GetWallTimeSinceUnixEpoch(), true, root_cert, root_key); OSP_CHECK(intermediate_cert_or_error); bssl::UniquePtr<X509> intermediate_cert = std::move(intermediate_cert_or_error.value()); @@ -88,10 +69,10 @@ ErrorOr<GeneratedCredentials> GenerateCredentials( i2d_X509(intermediate_cert.get(), &out); device_creds.certs.emplace_back(std::move(cert_serial)); - cert_length = i2d_X509(root_cert.get(), nullptr); + cert_length = i2d_X509(root_cert, nullptr); std::vector<uint8_t> trust_anchor_der(cert_length); out = &trust_anchor_der[0]; - i2d_X509(root_cert.get(), &out); + i2d_X509(root_cert, &out); // TLS key pair + certificate generation. bssl::UniquePtr<EVP_PKEY> tls_key = GenerateRsaKeyPair(); @@ -136,5 +117,66 @@ ErrorOr<GeneratedCredentials> GenerateCredentials( std::move(trust_anchor_der)}; } +bssl::UniquePtr<X509> GenerateRootCert(const EVP_PKEY& root_key) { + ErrorOr<bssl::UniquePtr<X509>> root_cert_or_error = + CreateSelfSignedX509Certificate("Cast Root CA", kCertificateDuration, + root_key, GetWallTimeSinceUnixEpoch(), + true); + OSP_CHECK(root_cert_or_error); + return std::move(root_cert_or_error.value()); +} +} // namespace + +StaticCredentialsProvider::StaticCredentialsProvider() = default; +StaticCredentialsProvider::StaticCredentialsProvider( + DeviceCredentials device_creds, + std::vector<uint8_t> tls_cert_der) + : device_creds(std::move(device_creds)), + tls_cert_der(std::move(tls_cert_der)) {} + +StaticCredentialsProvider::StaticCredentialsProvider( + StaticCredentialsProvider&&) = default; +StaticCredentialsProvider& StaticCredentialsProvider::operator=( + StaticCredentialsProvider&&) = default; +StaticCredentialsProvider::~StaticCredentialsProvider() = default; + +ErrorOr<GeneratedCredentials> GenerateCredentials( + const std::string& device_certificate_id) { + bssl::UniquePtr<EVP_PKEY> root_key = GenerateRsaKeyPair(); + OSP_CHECK(root_key); + + bssl::UniquePtr<X509> root_cert = GenerateRootCert(*root_key); + OSP_CHECK(root_cert); + + return GenerateCredentials(device_certificate_id, root_key.get(), + root_cert.get()); +} + +ErrorOr<GeneratedCredentials> GenerateCredentials( + const std::string& device_certificate_id, + const std::string& private_key_path, + const std::string& server_certificate_path) { + OSP_CHECK(!private_key_path.empty() && !server_certificate_path.empty()); + + FileUniquePtr key_file(fopen(private_key_path.c_str(), "r"), &fclose); + if (!key_file) { + return Error(Error::Code::kParameterInvalid, + "Missing private key file path"); + } + bssl::UniquePtr<EVP_PKEY> root_key(PEM_read_PrivateKey( + key_file.get(), nullptr /* x */, nullptr /* cb */, nullptr /* u */)); + + FileUniquePtr cert_file(fopen(server_certificate_path.c_str(), "r"), &fclose); + if (!cert_file) { + return Error(Error::Code::kParameterInvalid, + "Missing server certificate file path"); + } + bssl::UniquePtr<X509> root_cert(PEM_read_X509( + cert_file.get(), nullptr /* x */, nullptr /* cb */, nullptr /* u */)); + + return GenerateCredentials(device_certificate_id, root_key.get(), + root_cert.get()); +} + } // namespace cast } // namespace openscreen diff --git a/cast/receiver/channel/static_credentials.h b/cast/receiver/channel/static_credentials.h index e886da7f..97b90cc8 100644 --- a/cast/receiver/channel/static_credentials.h +++ b/cast/receiver/channel/static_credentials.h @@ -6,6 +6,7 @@ #define CAST_RECEIVER_CHANNEL_STATIC_CREDENTIALS_H_ #include <memory> +#include <string> #include <vector> #include "absl/strings/string_view.h" @@ -53,7 +54,12 @@ struct GeneratedCredentials { // stored in private_key_der.h. The certificate is valid for // kCertificateDuration from when this function is called. ErrorOr<GeneratedCredentials> GenerateCredentials( - absl::string_view device_certificate_id); + const std::string& device_certificate_id); + +ErrorOr<GeneratedCredentials> GenerateCredentials( + const std::string& device_certificate_id, + const std::string& private_key_path, + const std::string& server_certificate_path); } // namespace cast } // namespace openscreen diff --git a/cast/receiver/channel/testing/device_auth_test_helpers.cc b/cast/receiver/channel/testing/device_auth_test_helpers.cc index 904f9f80..77237dad 100644 --- a/cast/receiver/channel/testing/device_auth_test_helpers.cc +++ b/cast/receiver/channel/testing/device_auth_test_helpers.cc @@ -9,6 +9,7 @@ #include "cast/common/certificate/testing/test_helpers.h" #include "gtest/gtest.h" +#include "util/crypto/pem_helpers.h" namespace openscreen { namespace cast { @@ -19,10 +20,9 @@ void InitStaticCredentialsFromFiles(StaticCredentialsProvider* creds, absl::string_view privkey_filename, absl::string_view chain_filename, absl::string_view tls_filename) { - auto private_key = testing::ReadKeyFromPemFile(privkey_filename); + auto private_key = ReadKeyFromPemFile(privkey_filename); ASSERT_TRUE(private_key); - std::vector<std::string> certs = - testing::ReadCertificatesFromPemFile(chain_filename); + std::vector<std::string> certs = ReadCertificatesFromPemFile(chain_filename); ASSERT_GT(certs.size(), 1u); // Use the root of the chain as the trust store for the test. @@ -39,7 +39,7 @@ void InitStaticCredentialsFromFiles(StaticCredentialsProvider* creds, std::move(certs), std::move(private_key), std::string()}; const std::vector<std::string> tls_cert = - testing::ReadCertificatesFromPemFile(tls_filename); + ReadCertificatesFromPemFile(tls_filename); ASSERT_EQ(tls_cert.size(), 1u); data = reinterpret_cast<const uint8_t*>(tls_cert[0].data()); if (parsed_cert) { diff --git a/cast/sender/channel/cast_auth_util.cc b/cast/sender/channel/cast_auth_util.cc index 10cbdc45..cb1ced69 100644 --- a/cast/sender/channel/cast_auth_util.cc +++ b/cast/sender/channel/cast_auth_util.cc @@ -7,6 +7,7 @@ #include <openssl/rand.h> #include <algorithm> +#include <memory> #include "cast/common/certificate/cast_cert_validator.h" #include "cast/common/certificate/cast_cert_validator_internal.h" @@ -29,13 +30,13 @@ namespace { #define PARSE_ERROR_PREFIX "Failed to parse auth message: " // The maximum number of days a cert can live for. -const int kMaxSelfSignedCertLifetimeInDays = 4; +constexpr int kMaxSelfSignedCertLifetimeInDays = 4; // The size of the nonce challenge in bytes. -const int kNonceSizeInBytes = 16; +constexpr int kNonceSizeInBytes = 16; // The number of hours after which a nonce is regenerated. -long kNonceExpirationTimeInHours = 24; +constexpr int kNonceExpirationTimeInHours = 24; // Extracts an embedded DeviceAuthMessage payload from an auth challenge reply // message. @@ -122,6 +123,9 @@ Error MapToOpenscreenError(Error::Code error, bool crl_required) { case Error::Code::kErrCertsRestrictions: return Error(Error::Code::kCastV2CertNotSignedByTrustedCa, "Failed certificate restrictions."); + case Error::Code::kErrCertsVerifyUntrustedCert: + return Error(Error::Code::kCastV2CertNotSignedByTrustedCa, + "Failed with untrusted certificate."); case Error::Code::kErrCrlInvalid: // This error is only encountered if |crl_required| is true. OSP_DCHECK(crl_required); diff --git a/cast/sender/channel/cast_auth_util_unittest.cc b/cast/sender/channel/cast_auth_util_unittest.cc index 03655419..acdb07a2 100644 --- a/cast/sender/channel/cast_auth_util_unittest.cc +++ b/cast/sender/channel/cast_auth_util_unittest.cc @@ -15,6 +15,7 @@ #include "platform/api/time.h" #include "platform/test/paths.h" #include "testing/util/read_file.h" +#include "util/crypto/pem_helpers.h" #include "util/osp_logging.h" namespace openscreen { @@ -124,7 +125,7 @@ class CastAuthUtilTest : public ::testing::Test { static AuthResponse CreateAuthResponse( std::vector<uint8_t>* signed_data, ::cast::channel::HashAlgorithm digest_algorithm) { - std::vector<std::string> chain = testing::ReadCertificatesFromPemFile( + std::vector<std::string> chain = ReadCertificatesFromPemFile( GetSpecificTestDataPath() + "certificates/chromecast_gen1.pem"); OSP_CHECK(!chain.empty()); @@ -292,7 +293,7 @@ TEST_F(CastAuthUtilTest, VerifySenderNonceMissing) { } TEST_F(CastAuthUtilTest, VerifyTLSCertificateSuccess) { - std::vector<std::string> tls_cert_der = testing::ReadCertificatesFromPemFile( + std::vector<std::string> tls_cert_der = ReadCertificatesFromPemFile( data_path_ + "certificates/test_tls_cert.pem"); std::string& der_cert = tls_cert_der[0]; const uint8_t* data = (const uint8_t*)der_cert.data(); @@ -310,7 +311,7 @@ TEST_F(CastAuthUtilTest, VerifyTLSCertificateSuccess) { } TEST_F(CastAuthUtilTest, VerifyTLSCertificateTooEarly) { - std::vector<std::string> tls_cert_der = testing::ReadCertificatesFromPemFile( + std::vector<std::string> tls_cert_der = ReadCertificatesFromPemFile( data_path_ + "certificates/test_tls_cert.pem"); std::string& der_cert = tls_cert_der[0]; const uint8_t* data = (const uint8_t*)der_cert.data(); @@ -331,7 +332,7 @@ TEST_F(CastAuthUtilTest, VerifyTLSCertificateTooEarly) { } TEST_F(CastAuthUtilTest, VerifyTLSCertificateTooLate) { - std::vector<std::string> tls_cert_der = testing::ReadCertificatesFromPemFile( + std::vector<std::string> tls_cert_der = ReadCertificatesFromPemFile( data_path_ + "certificates/test_tls_cert.pem"); std::string& der_cert = tls_cert_der[0]; const uint8_t* data = (const uint8_t*)der_cert.data(); @@ -392,16 +393,16 @@ ErrorOr<CastDeviceCertPolicy> TestVerifyRevocation( // Runs a single test case. bool RunTest(const DeviceCertTest& test_case) { - std::unique_ptr<TrustStore> crl_trust_store; - std::unique_ptr<TrustStore> cast_trust_store; + TrustStore crl_trust_store; + TrustStore cast_trust_store; if (test_case.use_test_trust_anchors()) { - crl_trust_store = testing::CreateTrustStoreFromPemFile( + crl_trust_store = TrustStore::CreateInstanceFromPemFile( GetSpecificTestDataPath() + "certificates/cast_crl_test_root_ca.pem"); - cast_trust_store = testing::CreateTrustStoreFromPemFile( + cast_trust_store = TrustStore::CreateInstanceFromPemFile( GetSpecificTestDataPath() + "certificates/cast_test_root_ca.pem"); - EXPECT_FALSE(crl_trust_store->certs.empty()); - EXPECT_FALSE(cast_trust_store->certs.empty()); + EXPECT_FALSE(crl_trust_store.certs.empty()); + EXPECT_FALSE(cast_trust_store.certs.empty()); } std::vector<std::string> certificate_chain; @@ -421,9 +422,9 @@ bool RunTest(const DeviceCertTest& test_case) { ErrorOr<CastDeviceCertPolicy> result(CastDeviceCertPolicy::kUnrestricted); switch (test_case.expected_result()) { case ::cast::certificate::PATH_VERIFICATION_FAILED: - result = TestVerifyRevocation( - certificate_chain, crl_bundle, verification_time, false, - cast_trust_store.get(), crl_trust_store.get()); + result = + TestVerifyRevocation(certificate_chain, crl_bundle, verification_time, + false, &cast_trust_store, &cast_trust_store); EXPECT_EQ(result.error().code(), Error::Code::kCastV2CertNotSignedByTrustedCa); return result.error().code() == @@ -431,9 +432,9 @@ bool RunTest(const DeviceCertTest& test_case) { case ::cast::certificate::CRL_VERIFICATION_FAILED: // Fall-through intended. case ::cast::certificate::REVOCATION_CHECK_FAILED_WITHOUT_CRL: - result = TestVerifyRevocation( - certificate_chain, crl_bundle, verification_time, true, - cast_trust_store.get(), crl_trust_store.get()); + result = + TestVerifyRevocation(certificate_chain, crl_bundle, verification_time, + true, &cast_trust_store, &cast_trust_store); EXPECT_EQ(result.error().code(), Error::Code::kErrCrlInvalid); return result.error().code() == Error::Code::kErrCrlInvalid; case ::cast::certificate::CRL_EXPIRED_AFTER_INITIAL_VERIFICATION: @@ -441,15 +442,15 @@ bool RunTest(const DeviceCertTest& test_case) { // certificate is verified. return true; case ::cast::certificate::REVOCATION_CHECK_FAILED: - result = TestVerifyRevocation( - certificate_chain, crl_bundle, verification_time, true, - cast_trust_store.get(), crl_trust_store.get()); + result = + TestVerifyRevocation(certificate_chain, crl_bundle, verification_time, + true, &cast_trust_store, &cast_trust_store); EXPECT_EQ(result.error().code(), Error::Code::kErrCertsRevoked); return result.error().code() == Error::Code::kErrCertsRevoked; case ::cast::certificate::SUCCESS: - result = TestVerifyRevocation( - certificate_chain, crl_bundle, verification_time, false, - cast_trust_store.get(), crl_trust_store.get()); + result = + TestVerifyRevocation(certificate_chain, crl_bundle, verification_time, + false, &cast_trust_store, &cast_trust_store); EXPECT_EQ(result.error().code(), Error::Code::kCastV2SignedBlobsMismatch); return result.error().code() == Error::Code::kCastV2SignedBlobsMismatch; case ::cast::certificate::UNSPECIFIED: diff --git a/cast/standalone_receiver/cast_agent.cc b/cast/standalone_receiver/cast_agent.cc index 99ff8409..d574d4f7 100644 --- a/cast/standalone_receiver/cast_agent.cc +++ b/cast/standalone_receiver/cast_agent.cc @@ -59,6 +59,8 @@ Error CastAgent::Start() { task_runner_, credentials_provider_); router_ = MakeSerialDelete<VirtualConnectionRouter>(task_runner_, &connection_manager_); + message_port_ = + MakeSerialDelete<CastSocketMessagePort>(task_runner_, router_.get()); router_->AddHandlerForLocalId(kPlatformReceiverId, auth_handler_.get()); socket_factory_ = MakeSerialDelete<ReceiverSocketFactory>( task_runner_, this, router_.get()); @@ -95,12 +97,12 @@ void CastAgent::OnConnected(ReceiverSocketFactory* factory, } OSP_LOG_INFO << "Received connection from peer at: " << endpoint; - message_port_.SetSocket(socket->GetWeakPtr()); + message_port_->SetSocket(socket->GetWeakPtr()); router_->TakeSocket(this, std::move(socket)); controller_ = std::make_unique<StreamingPlaybackController>(task_runner_, this); current_session_ = std::make_unique<ReceiverSession>( - controller_.get(), environment_.get(), &message_port_, + controller_.get(), environment_.get(), message_port_.get(), ReceiverSession::Preferences{}); } @@ -155,10 +157,10 @@ void CastAgent::OnPlaybackError(StreamingPlaybackController* controller, } void CastAgent::StopCurrentSession() { - controller_.reset(); current_session_.reset(); - router_->CloseSocket(message_port_.GetSocketId()); - message_port_.SetSocket(nullptr); + controller_.reset(); + router_->CloseSocket(message_port_->GetSocketId()); + message_port_->SetSocket(nullptr); } } // namespace cast diff --git a/cast/standalone_receiver/cast_agent.h b/cast/standalone_receiver/cast_agent.h index e7b92200..db8cf668 100644 --- a/cast/standalone_receiver/cast_agent.h +++ b/cast/standalone_receiver/cast_agent.h @@ -87,7 +87,6 @@ class CastAgent final : public ReceiverSocketFactory::Client, TaskRunner* const task_runner_; IPEndpoint receive_endpoint_; DeviceAuthNamespaceHandler::CredentialsProvider* credentials_provider_; - CastSocketMessagePort message_port_; TlsCredentials tls_credentials_; // Member variables set as part of starting up. @@ -95,6 +94,7 @@ class CastAgent final : public ReceiverSocketFactory::Client, SerialDeletePtr<TlsConnectionFactory> connection_factory_; VirtualConnectionManager connection_manager_; SerialDeletePtr<VirtualConnectionRouter> router_; + SerialDeletePtr<CastSocketMessagePort> message_port_; SerialDeletePtr<ReceiverSocketFactory> socket_factory_; SerialDeletePtr<ScopedWakeLock> wake_lock_; diff --git a/cast/standalone_receiver/main.cc b/cast/standalone_receiver/main.cc index 270b1f66..1c497f51 100644 --- a/cast/standalone_receiver/main.cc +++ b/cast/standalone_receiver/main.cc @@ -118,7 +118,13 @@ options: interface Specifies the network interface to bind to. The interface is looked up from the system interface registry. - mandatory, as it must be known for publishing discovery. + Mandatory, as it must be known for publishing discovery. + + -p, --private-key=path-to-key: Path to OpenSSL-generated private key to be + used for TLS authentication. + + -s, --server-certificate=path-to-cert: Path to PEM file containing a + server certificate to be used for TLS authentication. -t, --tracing: Enable performance tracing logging. @@ -157,6 +163,8 @@ int RunStandaloneReceiver(int argc, char* argv[]) { // being exposed, consider if it applies to the standalone receiver, // standalone sender, osp demo, and test_main argument options. const struct option kArgumentOptions[] = { + {"private-key", required_argument, nullptr, 'p'}, + {"server-certificate", required_argument, nullptr, 's'}, {"tracing", no_argument, nullptr, 't'}, {"verbose", no_argument, nullptr, 'v'}, {"help", no_argument, nullptr, 'h'}, @@ -168,11 +176,19 @@ int RunStandaloneReceiver(int argc, char* argv[]) { bool is_verbose = false; bool discovery_enabled = true; + std::string private_key_path; + std::string server_certificate_path; std::unique_ptr<openscreen::TextTraceLoggingPlatform> trace_logger; int ch = -1; - while ((ch = getopt_long(argc, argv, "tvhx", kArgumentOptions, nullptr)) != - -1) { + while ((ch = getopt_long(argc, argv, "p:s:tvhx", kArgumentOptions, + nullptr)) != -1) { switch (ch) { + case 'p': + private_key_path = optarg; + break; + case 's': + server_certificate_path = optarg; + break; case 't': trace_logger = std::make_unique<openscreen::TextTraceLoggingPlatform>(); break; @@ -187,6 +203,11 @@ int RunStandaloneReceiver(int argc, char* argv[]) { return 1; } } + if (private_key_path.empty() != server_certificate_path.empty()) { + OSP_LOG_ERROR << "If a private key or server certificate path is provided, " + "both are required."; + return 1; + } SetLogLevel(is_verbose ? openscreen::LogLevel::kVerbose : openscreen::LogLevel::kInfo); @@ -202,8 +223,15 @@ int RunStandaloneReceiver(int argc, char* argv[]) { OSP_CHECK(interface_name && strlen(interface_name) > 0) << "No interface name provided."; - auto creds = GenerateCredentials( - absl::StrCat("Standalone Receiver on ", interface_name)); + std::string device_id = + absl::StrCat("Standalone Receiver on ", interface_name); + ErrorOr<GeneratedCredentials> creds = Error::Code::kEVPInitializationError; + if (private_key_path.empty()) { + creds = GenerateCredentials(device_id); + } else { + creds = GenerateCredentials(device_id, private_key_path, + server_certificate_path); + } OSP_CHECK(creds.is_value()) << creds.error(); task_runner->PostTask( [&, interface = GetInterfaceInfoFromName(interface_name)] { diff --git a/cast/standalone_sender/looping_file_cast_agent.cc b/cast/standalone_sender/looping_file_cast_agent.cc index 4fec432e..c6c9ea7c 100644 --- a/cast/standalone_sender/looping_file_cast_agent.cc +++ b/cast/standalone_sender/looping_file_cast_agent.cc @@ -25,6 +25,8 @@ LoopingFileCastAgent::LoopingFileCastAgent(TaskRunner* task_runner) : task_runner_(task_runner) { router_ = MakeSerialDelete<VirtualConnectionRouter>(task_runner_, &connection_manager_); + message_port_ = + MakeSerialDelete<CastSocketMessagePort>(task_runner_, router_.get()); socket_factory_ = MakeSerialDelete<SenderSocketFactory>(task_runner_, this, task_runner_); connection_factory_ = SerialDeletePtr<TlsConnectionFactory>( @@ -43,6 +45,7 @@ void LoopingFileCastAgent::Connect(ConnectionSettings settings) { : DeviceMediaPolicy::kAudioOnly; task_runner_->PostTask([this, policy] { + wake_lock_ = ScopedWakeLock::Create(task_runner_); socket_factory_->Connect(connection_settings_->receiver_endpoint, policy, router_.get()); }); @@ -68,7 +71,8 @@ void LoopingFileCastAgent::OnConnected(SenderSocketFactory* factory, } OSP_LOG_INFO << "Received connection from peer at: " << endpoint; - message_port_.SetSocket(socket->GetWeakPtr()); + message_port_->SetSocket(socket->GetWeakPtr()); + router_->TakeSocket(this, std::move(socket)); CreateAndStartSession(); } @@ -117,20 +121,27 @@ void LoopingFileCastAgent::CreateAndStartSession() { std::make_unique<Environment>(&Clock::now, task_runner_, IPEndpoint{}); current_session_ = std::make_unique<SenderSession>( connection_settings_->receiver_endpoint.address, this, environment_.get(), - &message_port_); + message_port_.get()); AudioCaptureConfig audio_config; VideoCaptureConfig video_config; // Use default display resolution of 1080P. video_config.resolutions.emplace_back(DisplayResolution{}); - current_session_->Negotiate({audio_config}, {video_config}); + + OSP_VLOG << "Starting session negotiation."; + const Error negotiation_error = + current_session_->Negotiate({audio_config}, {video_config}); + if (!negotiation_error.ok()) { + OSP_LOG_ERROR << "Failed to negotiate a session: " << negotiation_error; + } } void LoopingFileCastAgent::StopCurrentSession() { current_session_.reset(); environment_.reset(); file_sender_.reset(); - message_port_.SetSocket(nullptr); + router_->CloseSocket(message_port_->GetSocketId()); + message_port_->SetSocket(nullptr); } } // namespace cast diff --git a/cast/standalone_sender/looping_file_cast_agent.h b/cast/standalone_sender/looping_file_cast_agent.h index 3a7529b7..abe91c96 100644 --- a/cast/standalone_sender/looping_file_cast_agent.h +++ b/cast/standalone_sender/looping_file_cast_agent.h @@ -94,16 +94,16 @@ class LoopingFileCastAgent final void StopCurrentSession(); // Member variables set as part of construction. - std::unique_ptr<Environment> environment_; + VirtualConnectionManager connection_manager_; TaskRunner* const task_runner_; - CastSocketMessagePort message_port_; + SerialDeletePtr<VirtualConnectionRouter> router_; + SerialDeletePtr<CastSocketMessagePort> message_port_; + SerialDeletePtr<SenderSocketFactory> socket_factory_; + SerialDeletePtr<TlsConnectionFactory> connection_factory_; // Member variables set as part of starting up. + std::unique_ptr<Environment> environment_; absl::optional<ConnectionSettings> connection_settings_; - SerialDeletePtr<VirtualConnectionRouter> router_; - SerialDeletePtr<TlsConnectionFactory> connection_factory_; - VirtualConnectionManager connection_manager_; - SerialDeletePtr<SenderSocketFactory> socket_factory_; SerialDeletePtr<ScopedWakeLock> wake_lock_; // Member variables set as part of a sender connection. diff --git a/cast/standalone_sender/looping_file_sender.cc b/cast/standalone_sender/looping_file_sender.cc index 2ca986c0..ff9bdd05 100644 --- a/cast/standalone_sender/looping_file_sender.cc +++ b/cast/standalone_sender/looping_file_sender.cc @@ -137,8 +137,8 @@ void LoopingFileSender::UpdateStatusOnConsole() { // there might sometimes be old status lines not getting erased (i.e., just // partially overwritten). fprintf(stdout, - "\r\x1b[2K\rAt %01" PRId64 - ".%03ds in file (est. network bandwidth: %d kbps). ", + "\r\x1b[2K\rLoopingFileSender: At %01" PRId64 + ".%03ds in file (est. network bandwidth: %d kbps). \n", static_cast<int64_t>(seconds_part.count()), static_cast<int>(millis_part.count()), bandwidth_estimate_ / 1024); fflush(stdout); diff --git a/cast/standalone_sender/main.cc b/cast/standalone_sender/main.cc index b567b9af..02c2b4ee 100644 --- a/cast/standalone_sender/main.cc +++ b/cast/standalone_sender/main.cc @@ -15,6 +15,7 @@ #include <iostream> #include <sstream> +#include "cast/common/certificate/cast_trust_store.h" #include "cast/standalone_sender/constants.h" #include "cast/standalone_sender/looping_file_cast_agent.h" #include "cast/streaming/constants.h" @@ -57,6 +58,11 @@ void LogUsage(const char* argv0) { Default if not set: )" << kDefaultMaxBitrate << R"(. + -s, --server-certificate=path-to-cert + Specifies the path to the server certificate used by the receiver. + If omitted, only connections to receivers using an official + Google-signed cast certificate chain will be permitted. + -a, --android-hack: Use the wrong RTP payload types, for compatibility with older Android TV receivers. @@ -77,6 +83,7 @@ int StandaloneSenderMain(int argc, char* argv[]) { const struct option kArgumentOptions[] = { {"remote", required_argument, nullptr, 'r'}, {"max-bitrate", required_argument, nullptr, 'm'}, + {"server-certificate", required_argument, nullptr, 's'}, {"android-hack", no_argument, nullptr, 'a'}, {"tracing", no_argument, nullptr, 't'}, {"verbose", no_argument, nullptr, 'v'}, @@ -85,12 +92,13 @@ int StandaloneSenderMain(int argc, char* argv[]) { bool is_verbose = false; IPEndpoint remote_endpoint = GetDefaultEndpoint(); + std::string server_certificate_path; [[maybe_unused]] bool use_android_rtp_hack = false; [[maybe_unused]] int max_bitrate = kDefaultMaxBitrate; std::unique_ptr<TextTraceLoggingPlatform> trace_logger; int ch = -1; - while ((ch = getopt_long(argc, argv, "r:atvh", kArgumentOptions, nullptr)) != - -1) { + while ((ch = getopt_long(argc, argv, "r:m:s:atvh", kArgumentOptions, + nullptr)) != -1) { switch (ch) { case 'r': { const ErrorOr<IPEndpoint> parsed_endpoint = IPEndpoint::Parse(optarg); @@ -117,6 +125,9 @@ int StandaloneSenderMain(int argc, char* argv[]) { return 1; } break; + case 's': + server_certificate_path = optarg; + break; case 'a': use_android_rtp_hack = true; break; @@ -145,6 +156,11 @@ int StandaloneSenderMain(int argc, char* argv[]) { return 1; } + if (!server_certificate_path.empty()) { + CastTrustStore::CreateInstanceFromPemFile( + server_certificate_path, TrustStore::Mode::kAllowSelfSigned); + } + auto* const task_runner = new TaskRunnerImpl(&Clock::now); PlatformClientPosix::Create(Clock::duration{50}, Clock::duration{50}, std::unique_ptr<TaskRunnerImpl>(task_runner)); diff --git a/cast/streaming/constants.h b/cast/streaming/constants.h index 7b32d4d8..65cb2a1a 100644 --- a/cast/streaming/constants.h +++ b/cast/streaming/constants.h @@ -60,11 +60,16 @@ constexpr int kDefaultFrameRate = 30; // The default audio sample rate is 48kHz, slightly higher than standard // consumer audio. -constexpr int kDefaultAudioSampleRate = 480000; +constexpr int kDefaultAudioSampleRate = 48000; // The default audio number of channels is set to stereo. constexpr int kDefaultAudioChannels = 2; +// TODO(jophba): migrate to discovering a randomly generated streaming +// sender id. This will require communicating the ID to the sender so that +// it can send messages appropriately. +constexpr char kDefaultStreamingReceiverSenderId[] = "receiver-12345"; + // Codecs known and understood by cast senders and receivers. Note: receivers // are required to implement the following codecs to be Cast V2 compliant: H264, // VP8, AAC, Opus. Senders have to implement at least one codec for audio and diff --git a/cast/streaming/frame_crypto.h b/cast/streaming/frame_crypto.h index f8d25fcf..a86153e3 100644 --- a/cast/streaming/frame_crypto.h +++ b/cast/streaming/frame_crypto.h @@ -49,7 +49,7 @@ class FrameCrypto { public: // Construct with the given 16-bytes AES key and IV mask. Both arguments // should be randomly-generated for each new streaming session. - // crypto::GenerateRandomBytes() can be used to create them. + // GenerateRandomBytes() can be used to create them. FrameCrypto(const std::array<uint8_t, 16>& aes_key, const std::array<uint8_t, 16>& cast_iv_mask); diff --git a/cast/streaming/frame_crypto_unittest.cc b/cast/streaming/frame_crypto_unittest.cc index 5fa3d7c0..a845ed03 100644 --- a/cast/streaming/frame_crypto_unittest.cc +++ b/cast/streaming/frame_crypto_unittest.cc @@ -29,8 +29,8 @@ TEST(FrameCryptoTest, EncryptsAndDecryptsFrames) { frame1.frame_id = frame0.frame_id + 1; frame1.data = frame0.data; - const std::array<uint8_t, 16> key = crypto::GenerateRandomBytes16(); - const std::array<uint8_t, 16> iv = crypto::GenerateRandomBytes16(); + const std::array<uint8_t, 16> key = GenerateRandomBytes16(); + const std::array<uint8_t, 16> iv = GenerateRandomBytes16(); EXPECT_NE(0, memcmp(key.data(), iv.data(), sizeof(key))); const FrameCrypto crypto(key, iv); diff --git a/cast/streaming/receiver_session.cc b/cast/streaming/receiver_session.cc index 46152131..cea77c50 100644 --- a/cast/streaming/receiver_session.cc +++ b/cast/streaming/receiver_session.cc @@ -102,12 +102,12 @@ ReceiverSession::ReceiverSession(Client* const client, OSP_DCHECK(message_port_); OSP_DCHECK(environment_); - message_port_->SetClient(this); + message_port_->SetClient(this, kDefaultStreamingReceiverSenderId); } ReceiverSession::~ReceiverSession() { ResetReceivers(Client::kEndOfSession); - message_port_->SetClient(nullptr); + message_port_->ResetClient(); } void ReceiverSession::OnMessage(const std::string& sender_id, diff --git a/cast/streaming/rtp_packetizer_unittest.cc b/cast/streaming/rtp_packetizer_unittest.cc index bfea67a9..1c3cd97a 100644 --- a/cast/streaming/rtp_packetizer_unittest.cc +++ b/cast/streaming/rtp_packetizer_unittest.cc @@ -127,8 +127,7 @@ class RtpPacketizerTest : public testing::Test { // The RtpPacketizer instance under test, plus some surrounding dependencies // to generate its input and examine its output. const Ssrc ssrc_{GenerateSsrc(true)}; - const FrameCrypto crypto_{crypto::GenerateRandomBytes16(), - crypto::GenerateRandomBytes16()}; + const FrameCrypto crypto_{GenerateRandomBytes16(), GenerateRandomBytes16()}; RtpPacketizer packetizer_{kPayloadType, ssrc_, kMaxRtpPacketSizeForIpv4UdpOnEthernet}; RtpPacketParser parser_{ssrc_}; diff --git a/cast/streaming/sender_session.cc b/cast/streaming/sender_session.cc index c153d076..897e7560 100644 --- a/cast/streaming/sender_session.cc +++ b/cast/streaming/sender_session.cc @@ -34,19 +34,22 @@ namespace cast { namespace { AudioStream CreateStream(int index, const AudioCaptureConfig& config) { - return AudioStream{Stream{index, - Stream::Type::kAudioSource, - config.channels, - CodecToString(config.codec), - GetPayloadType(config.codec), - GenerateSsrc(true /*high_priority*/), - config.target_playout_delay, - crypto::GenerateRandomBytes16(), - crypto::GenerateRandomBytes16(), - false /* receiver_rtcp_event_log */, - {} /* receiver_rtcp_dscp */, - config.sample_rate}, - config.bit_rate}; + return AudioStream{ + Stream{index, + Stream::Type::kAudioSource, + config.channels, + CodecToString(config.codec), + GetPayloadType(config.codec), + GenerateSsrc(true /*high_priority*/), + config.target_playout_delay, + GenerateRandomBytes16(), + GenerateRandomBytes16(), + false /* receiver_rtcp_event_log */, + {} /* receiver_rtcp_dscp */, + config.sample_rate}, + (config.bit_rate >= capture_recommendations::kDefaultAudioMinBitRate) + ? config.bit_rate + : capture_recommendations::kDefaultAudioMaxBitRate}; } Resolution ToResolution(const DisplayResolution& display_resolution) { @@ -67,17 +70,20 @@ VideoStream CreateStream(int index, const VideoCaptureConfig& config) { GetPayloadType(config.codec), GenerateSsrc(false /*high_priority*/), config.target_playout_delay, - crypto::GenerateRandomBytes16(), - crypto::GenerateRandomBytes16(), + GenerateRandomBytes16(), + GenerateRandomBytes16(), false /* receiver_rtcp_event_log */, {} /* receiver_rtcp_dscp */, kRtpVideoTimebase}, SimpleFraction{config.max_frame_rate.numerator, config.max_frame_rate.denominator}, - config.max_bit_rate, - {}, - {}, - {}, // protection, profile, level + (config.max_bit_rate > + capture_recommendations::kDefaultVideoBitRateLimits.minimum) + ? config.max_bit_rate + : capture_recommendations::kDefaultVideoBitRateLimits.maximum, + {}, // protection + {}, // profile + {}, // protection std::move(resolutions), {} /* error_recovery mode, always "castv2" */ }; @@ -111,7 +117,7 @@ Offer CreateOffer(const std::vector<AudioCaptureConfig>& audio_configs, } bool IsValidAudioCaptureConfig(const AudioCaptureConfig& config) { - return config.channels >= 1 && config.bit_rate > 0; + return config.channels >= 1 && config.bit_rate >= 0; } bool IsValidResolution(const DisplayResolution& resolution) { @@ -121,7 +127,10 @@ bool IsValidResolution(const DisplayResolution& resolution) { bool IsValidVideoCaptureConfig(const VideoCaptureConfig& config) { return config.max_frame_rate.numerator > 0 && - config.max_frame_rate.denominator > 0 && config.max_bit_rate > 0 && + config.max_frame_rate.denominator > 0 && + ((config.max_bit_rate == 0) || + (config.max_bit_rate >= + capture_recommendations::kDefaultVideoBitRateLimits.minimum)) && !config.resolutions.empty() && std::all_of(config.resolutions.begin(), config.resolutions.end(), IsValidResolution); @@ -162,11 +171,11 @@ SenderSession::SenderSession(IPAddress remote_address, OSP_DCHECK(message_port_); OSP_DCHECK(environment_); - message_port_->SetClient(this); + message_port_->SetClient(this, "sender-" + std::to_string(session_id_)); } SenderSession::~SenderSession() { - message_port_->SetClient(nullptr); + message_port_->ResetClient(); } Error SenderSession::Negotiate(std::vector<AudioCaptureConfig> audio_configs, @@ -194,7 +203,10 @@ Error SenderSession::Negotiate(std::vector<AudioCaptureConfig> audio_configs, message_body[kOfferMessageBody] = std::move(json_offer.value()); Message message; - message.sender_id = std::to_string(session_id_); + // Currently we don't have a way to discover the ID of the receiver we + // are connected to, since we have to send the first message. + // TODO(jophba): migrate to discovered receiver ID when available. + message.sender_id = kDefaultStreamingReceiverSenderId; message.message_namespace = kCastWebrtcNamespace; message.body = std::move(message_body); SendMessage(&message); diff --git a/cast/streaming/sender_session_unittest.cc b/cast/streaming/sender_session_unittest.cc index 5a90bee0..7fd428be 100644 --- a/cast/streaming/sender_session_unittest.cc +++ b/cast/streaming/sender_session_unittest.cc @@ -190,6 +190,25 @@ TEST_F(SenderSessionTest, ComplainsIfMissingResolutions) { Error(Error::Code::kParameterInvalid, "Invalid configs provided.")); } +TEST_F(SenderSessionTest, SendsOfferWithZeroBitrateOptions) { + VideoCaptureConfig video_config = kVideoCaptureConfigValid; + video_config.max_bit_rate = 0; + AudioCaptureConfig audio_config = kAudioCaptureConfigValid; + audio_config.bit_rate = 0; + + const Error error = + session_->Negotiate(std::vector<AudioCaptureConfig>{audio_config}, + std::vector<VideoCaptureConfig>{video_config}); + EXPECT_TRUE(error.ok()); + + const auto& messages = message_port_->posted_messages(); + ASSERT_EQ(1u, messages.size()); + auto message_body = json::Parse(messages[0]); + ASSERT_TRUE(message_body.is_value()); + const Json::Value offer = std::move(message_body.value()); + EXPECT_EQ("OFFER", offer["type"].asString()); +} + TEST_F(SenderSessionTest, SendsOfferWithSimpleVideoOnly) { const Error error = session_->Negotiate( std::vector<AudioCaptureConfig>{}, diff --git a/cast/streaming/testing/simple_message_port.h b/cast/streaming/testing/simple_message_port.h index 2d56dd33..7d912918 100644 --- a/cast/streaming/testing/simple_message_port.h +++ b/cast/streaming/testing/simple_message_port.h @@ -20,7 +20,12 @@ namespace cast { class SimpleMessagePort : public MessagePort { public: ~SimpleMessagePort() override {} - void SetClient(MessagePort::Client* client) override { client_ = client; } + void SetClient(MessagePort::Client* client, + std::string client_sender_id) override { + client_ = client; + } + + void ResetClient() override { client_ = nullptr; } void ReceiveMessage(const std::string& message) { ReceiveMessage(kCastWebrtcNamespace, message); diff --git a/cast/test/device_auth_test.cc b/cast/test/device_auth_test.cc index 6bfd9ef7..36a9fc8a 100644 --- a/cast/test/device_auth_test.cc +++ b/cast/test/device_auth_test.cc @@ -4,6 +4,7 @@ #include <stdio.h> +#include "cast/common/certificate/cast_trust_store.h" #include "cast/common/certificate/testing/test_helpers.h" #include "cast/common/channel/proto/cast_channel.pb.h" #include "cast/common/channel/testing/fake_cast_socket.h" @@ -143,61 +144,61 @@ TEST_F(DeviceAuthTest, AuthIntegration) { } TEST_F(DeviceAuthTest, GoodCrl) { - std::unique_ptr<TrustStore> fake_crl_trust_store = - testing::CreateTrustStoreFromPemFile(data_path_ + "crl_root.pem"); + auto fake_crl_trust_store = + TrustStore::CreateInstanceFromPemFile(data_path_ + "crl_root.pem"); RunAuthTest(ReadEntireFileToString(data_path_ + "good_crl.pb"), - fake_crl_trust_store.get()); + &fake_crl_trust_store); } TEST_F(DeviceAuthTest, InvalidCrlTime) { - std::unique_ptr<TrustStore> fake_crl_trust_store = - testing::CreateTrustStoreFromPemFile(data_path_ + "crl_root.pem"); + auto fake_crl_trust_store = + TrustStore::CreateInstanceFromPemFile(data_path_ + "crl_root.pem"); RunAuthTest(ReadEntireFileToString(data_path_ + "invalid_time_crl.pb"), - fake_crl_trust_store.get(), false); + &fake_crl_trust_store, false); } TEST_F(DeviceAuthTest, IssuerRevoked) { - std::unique_ptr<TrustStore> fake_crl_trust_store = - testing::CreateTrustStoreFromPemFile(data_path_ + "crl_root.pem"); + auto fake_crl_trust_store = + TrustStore::CreateInstanceFromPemFile(data_path_ + "crl_root.pem"); RunAuthTest(ReadEntireFileToString(data_path_ + "issuer_revoked_crl.pb"), - fake_crl_trust_store.get(), false); + &fake_crl_trust_store, false); } TEST_F(DeviceAuthTest, DeviceRevoked) { - std::unique_ptr<TrustStore> fake_crl_trust_store = - testing::CreateTrustStoreFromPemFile(data_path_ + "crl_root.pem"); + auto fake_crl_trust_store = + TrustStore::CreateInstanceFromPemFile(data_path_ + "crl_root.pem"); RunAuthTest(ReadEntireFileToString(data_path_ + "device_revoked_crl.pb"), - fake_crl_trust_store.get(), false); + &fake_crl_trust_store, false); } TEST_F(DeviceAuthTest, IssuerSerialRevoked) { - std::unique_ptr<TrustStore> fake_crl_trust_store = - testing::CreateTrustStoreFromPemFile(data_path_ + "crl_root.pem"); + auto fake_crl_trust_store = + TrustStore::CreateInstanceFromPemFile(data_path_ + "crl_root.pem"); RunAuthTest( ReadEntireFileToString(data_path_ + "issuer_serial_revoked_crl.pb"), - fake_crl_trust_store.get(), false); + &fake_crl_trust_store, false); } TEST_F(DeviceAuthTest, DeviceSerialRevoked) { - std::unique_ptr<TrustStore> fake_crl_trust_store = - testing::CreateTrustStoreFromPemFile(data_path_ + "crl_root.pem"); + auto fake_crl_trust_store = + TrustStore::CreateInstanceFromPemFile(data_path_ + "crl_root.pem"); RunAuthTest( ReadEntireFileToString(data_path_ + "device_serial_revoked_crl.pb"), - fake_crl_trust_store.get(), false); + &fake_crl_trust_store, false); } TEST_F(DeviceAuthTest, BadCrlSignerCert) { - std::unique_ptr<TrustStore> fake_crl_trust_store = - testing::CreateTrustStoreFromPemFile(data_path_ + "crl_root.pem"); + auto fake_crl_trust_store = + TrustStore::CreateInstanceFromPemFile(data_path_ + "crl_root.pem"); RunAuthTest(ReadEntireFileToString(data_path_ + "bad_signer_cert_crl.pb"), - fake_crl_trust_store.get(), false); + &fake_crl_trust_store, false); } TEST_F(DeviceAuthTest, BadCrlSignature) { - std::unique_ptr<TrustStore> fake_crl_trust_store = - testing::CreateTrustStoreFromPemFile(data_path_ + "crl_root.pem"); + auto fake_crl_trust_store = + TrustStore::CreateInstanceFromPemFile(data_path_ + "crl_root.pem"); RunAuthTest(ReadEntireFileToString(data_path_ + "bad_signature_crl.pb"), - fake_crl_trust_store.get(), false); + &fake_crl_trust_store, false); } } // namespace diff --git a/cast/test/make_crl_tests.cc b/cast/test/make_crl_tests.cc index 106fe998..fca53176 100644 --- a/cast/test/make_crl_tests.cc +++ b/cast/test/make_crl_tests.cc @@ -11,6 +11,7 @@ #include "platform/test/paths.h" #include "util/crypto/certificate_utils.h" #include "util/crypto/digest_sign.h" +#include "util/crypto/pem_helpers.h" #include "util/crypto/sha2.h" #include "util/osp_logging.h" @@ -88,23 +89,24 @@ void PackCrlIntoFile(const std::string& filename, crl_bundle.SerializeToString(&output); int fd = open(filename.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644); OSP_DCHECK_GE(fd, 0); - OSP_DCHECK_EQ(write(fd, output.data(), output.size()), (int)output.size()); + OSP_DCHECK_EQ(write(fd, output.data(), output.size()), + static_cast<int>(output.size())); close(fd); } int CastMain() { const std::string data_path = GetTestDataPath() + "cast/receiver/channel/"; bssl::UniquePtr<EVP_PKEY> inter_key = - testing::ReadKeyFromPemFile(data_path + "inter_key.pem"); + ReadKeyFromPemFile(data_path + "inter_key.pem"); bssl::UniquePtr<EVP_PKEY> crl_inter_key = - testing::ReadKeyFromPemFile(data_path + "crl_inter_key.pem"); + ReadKeyFromPemFile(data_path + "crl_inter_key.pem"); OSP_DCHECK(inter_key); OSP_DCHECK(crl_inter_key); std::vector<std::string> chain_der = - testing::ReadCertificatesFromPemFile(data_path + "device_chain.pem"); + ReadCertificatesFromPemFile(data_path + "device_chain.pem"); std::vector<std::string> crl_inter_der = - testing::ReadCertificatesFromPemFile(data_path + "crl_inter.pem"); + ReadCertificatesFromPemFile(data_path + "crl_inter.pem"); OSP_DCHECK_EQ(chain_der.size(), 3u); OSP_DCHECK_EQ(crl_inter_der.size(), 1u); diff --git a/platform/base/error.cc b/platform/base/error.cc index 98d72085..f05c1e79 100644 --- a/platform/base/error.cc +++ b/platform/base/error.cc @@ -158,6 +158,8 @@ std::ostream& operator<<(std::ostream& os, const Error::Code& code) { return os << "ErrCertsDateInvalid"; case Error::Code::kErrCertsVerifyGeneric: return os << "ErrCertsVerifyGeneric"; + case Error::Code::kErrCertsVerifyUntrustedCert: + return os << "kErrCertsVerifyUntrustedCert"; case Error::Code::kErrCrlInvalid: return os << "ErrCrlInvalid"; case Error::Code::kErrCertsRevoked: diff --git a/platform/base/error.h b/platform/base/error.h index 61edd0ec..919c6575 100644 --- a/platform/base/error.h +++ b/platform/base/error.h @@ -120,6 +120,9 @@ class Error { // The certificate failed to chain to a trusted root. kErrCertsVerifyGeneric, + // The certificate was not found in the trust store. + kErrCertsVerifyUntrustedCert, + // The CRL is missing or failed to verify. kErrCrlInvalid, diff --git a/platform/impl/udp_socket_posix.cc b/platform/impl/udp_socket_posix.cc index 424e0d9b..393e2727 100644 --- a/platform/impl/udp_socket_posix.cc +++ b/platform/impl/udp_socket_posix.cc @@ -13,6 +13,7 @@ #include <sys/types.h> #include <unistd.h> +#include <algorithm> #include <cstring> #include <memory> #include <sstream> @@ -29,6 +30,11 @@ namespace openscreen { namespace { +// 64 KB is the maximum possible UDP datagram size. +#if !defined(OS_LINUX) +constexpr int kMaxUdpBufferSize = 64 << 10; +#endif + constexpr bool IsPowerOf2(uint32_t x) { return (x > 0) && ((x & (x - 1)) == 0); } @@ -372,29 +378,53 @@ bool IsPacketInfo<in6_pktinfo>(cmsghdr* cmh) { } template <class SockAddrType, class PktInfoType> -Error ReceiveMessageInternal(int fd, UdpPacket* packet) { +ErrorOr<UdpPacket> ReceiveMessageInternal(int fd) { + int upper_bound_bytes; +#if defined(OS_LINUX) + // This should return the exact size of the next message. + upper_bound_bytes = recv(fd, nullptr, 0, MSG_PEEK | MSG_TRUNC); + if (upper_bound_bytes == -1) { + return ChooseError(errno, Error::Code::kSocketReadFailure); + } +#elif defined(MAC_OSX) + // Can't use MSG_TRUNC in recv(). Use the FIONREAD ioctl() to get an + // upper-bound. + if (ioctl(fd, FIONREAD, &upper_bound_bytes) == -1 || upper_bound_bytes < 0) { + return ChooseError(errno, Error::Code::kSocketReadFailure); + } + upper_bound_bytes = std::min(upper_bound_bytes, kMaxUdpBufferSize); +#else // Other POSIX platforms (neither MSG_TRUNC nor FIONREAD available). + upper_bound_bytes = kMaxUdpBufferSize; +#endif + + UdpPacket packet(upper_bound_bytes); + msghdr msg = {}; SockAddrType sa; - iovec iov = {packet->data(), packet->size()}; - alignas(alignof(cmsghdr)) uint8_t control_buffer[1024]; - msghdr msg; msg.msg_name = &sa; msg.msg_namelen = sizeof(sa); + iovec iov = {packet.data(), packet.size()}; msg.msg_iov = &iov; msg.msg_iovlen = 1; + + // Although we don't do anything with the control buffer, on Linux + // it is required for the message to be properly read. +#if defined(OS_LINUX) + alignas(alignof(cmsghdr)) uint8_t control_buffer[1024]; msg.msg_control = control_buffer; msg.msg_controllen = sizeof(control_buffer); - msg.msg_flags = 0; - - ssize_t bytes_received = recvmsg(fd, &msg, 0); +#endif + const ssize_t bytes_received = recvmsg(fd, &msg, 0); if (bytes_received == -1) { + OSP_DVLOG << "Failed to read from socket."; return ChooseError(errno, Error::Code::kSocketReadFailure); } - - OSP_DCHECK_EQ(static_cast<size_t>(bytes_received), packet->size()); + // We may not populate the entire packet. + OSP_DCHECK_LE(static_cast<size_t>(bytes_received), packet.size()); + packet.resize(bytes_received); IPEndpoint source_endpoint = {.address = GetIPAddressFromSockAddr(sa), .port = GetPortFromFromSockAddr(sa)}; - packet->set_source(std::move(source_endpoint)); + packet.set_source(std::move(source_endpoint)); // For multicast sockets, the packet's original destination address may be // the host address (since we called bind()) but it may also be a @@ -412,11 +442,11 @@ Error ReceiveMessageInternal(int fd, UdpPacket* packet) { IPEndpoint destination_endpoint = { .address = GetIPAddressFromPktInfo(*pktinfo), .port = GetPortFromFromSockAddr(sa)}; - packet->set_destination(std::move(destination_endpoint)); + packet.set_destination(std::move(destination_endpoint)); break; } } - return Error::Code::kNone; + return std::move(packet); } } // namespace @@ -436,32 +466,15 @@ void UdpSocketPosix::ReceiveMessage() { return; } - ssize_t bytes_available = recv(handle_.fd, nullptr, 0, MSG_PEEK | MSG_TRUNC); - if (bytes_available == -1) { - task_runner_->PostTask( - [weak_this = weak_factory_.GetWeakPtr(), - error = - ChooseError(errno, Error::Code::kSocketReadFailure)]() mutable { - if (auto* self = weak_this.get()) { - if (auto* client = self->client_) { - client->OnRead(self, std::move(error)); - } - } - }); - return; - } - UdpPacket packet(bytes_available); - packet.set_socket(this); - Error result = Error::Code::kUnknownError; + ErrorOr<UdpPacket> read_result = Error::Code::kUnknownError; switch (local_endpoint_.address.version()) { case UdpSocket::Version::kV4: { - result = - ReceiveMessageInternal<sockaddr_in, in_pktinfo>(handle_.fd, &packet); + read_result = ReceiveMessageInternal<sockaddr_in, in_pktinfo>(handle_.fd); break; } case UdpSocket::Version::kV6: { - result = ReceiveMessageInternal<sockaddr_in6, in6_pktinfo>(handle_.fd, - &packet); + read_result = + ReceiveMessageInternal<sockaddr_in6, in6_pktinfo>(handle_.fd); break; } default: { @@ -469,17 +482,14 @@ void UdpSocketPosix::ReceiveMessage() { } } - task_runner_->PostTask( - [weak_this = weak_factory_.GetWeakPtr(), - read_result = result.ok() - ? ErrorOr<UdpPacket>(std::move(packet)) - : ErrorOr<UdpPacket>(std::move(result))]() mutable { - if (auto* self = weak_this.get()) { - if (auto* client = self->client_) { - client->OnRead(self, std::move(read_result)); - } - } - }); + task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(), + read_result = std::move(read_result)]() mutable { + if (auto* self = weak_this.get()) { + if (auto* client = self->client_) { + client->OnRead(self, std::move(read_result)); + } + } + }); } void UdpSocketPosix::SendMessage(const void* data, diff --git a/util/BUILD.gn b/util/BUILD.gn index d2f756e7..2fdd24c0 100644 --- a/util/BUILD.gn +++ b/util/BUILD.gn @@ -30,6 +30,8 @@ source_set("util") { "crypto/digest_sign.h", "crypto/openssl_util.cc", "crypto/openssl_util.h", + "crypto/pem_helpers.cc", + "crypto/pem_helpers.h", "crypto/random_bytes.cc", "crypto/random_bytes.h", "crypto/rsa_private_key.cc", diff --git a/util/crypto/pem_helpers.cc b/util/crypto/pem_helpers.cc new file mode 100644 index 00000000..471ddc6d --- /dev/null +++ b/util/crypto/pem_helpers.cc @@ -0,0 +1,72 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "util/crypto/pem_helpers.h" + +#include <openssl/bytestring.h> +#include <openssl/pem.h> +#include <openssl/rsa.h> +#include <stdio.h> +#include <string.h> + +#include "absl/strings/match.h" +#include "util/osp_logging.h" + +namespace openscreen { + +std::vector<std::string> ReadCertificatesFromPemFile( + absl::string_view filename) { + FILE* fp = fopen(filename.data(), "r"); + if (!fp) { + return {}; + } + std::vector<std::string> certs; + char* name; + char* header; + unsigned char* data; + int64_t length; + while (PEM_read(fp, &name, &header, &data, + reinterpret_cast<long*>(&length)) == 1) { + if (absl::StartsWith(name, "CERTIFICATE")) { + certs.emplace_back(reinterpret_cast<char*>(data), length); + } + OPENSSL_free(name); + OPENSSL_free(header); + OPENSSL_free(data); + } + fclose(fp); + return certs; +} + +bssl::UniquePtr<EVP_PKEY> ReadKeyFromPemFile(absl::string_view filename) { + FILE* fp = fopen(filename.data(), "r"); + if (!fp) { + return nullptr; + } + bssl::UniquePtr<EVP_PKEY> pkey; + char* name; + char* header; + unsigned char* data; + int64_t length; + while (PEM_read(fp, &name, &header, &data, + reinterpret_cast<long*>(&length)) == 1) { + if (absl::StartsWith(name, "RSA PRIVATE KEY")) { + OSP_DCHECK(!pkey); + CBS cbs; + CBS_init(&cbs, data, length); + RSA* rsa = RSA_parse_private_key(&cbs); + if (rsa) { + pkey.reset(EVP_PKEY_new()); + EVP_PKEY_assign_RSA(pkey.get(), rsa); + } + } + OPENSSL_free(name); + OPENSSL_free(header); + OPENSSL_free(data); + } + fclose(fp); + return pkey; +} + +} // namespace openscreen diff --git a/util/crypto/pem_helpers.h b/util/crypto/pem_helpers.h new file mode 100644 index 00000000..6012b069 --- /dev/null +++ b/util/crypto/pem_helpers.h @@ -0,0 +1,24 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef UTIL_CRYPTO_PEM_HELPERS_H_ +#define UTIL_CRYPTO_PEM_HELPERS_H_ + +#include <openssl/evp.h> + +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" + +namespace openscreen { + +std::vector<std::string> ReadCertificatesFromPemFile( + absl::string_view filename); + +bssl::UniquePtr<EVP_PKEY> ReadKeyFromPemFile(absl::string_view filename); + +} // namespace openscreen + +#endif // UTIL_CRYPTO_PEM_HELPERS_H_ diff --git a/util/crypto/random_bytes.cc b/util/crypto/random_bytes.cc index c090a762..6a4c9dcb 100644 --- a/util/crypto/random_bytes.cc +++ b/util/crypto/random_bytes.cc @@ -8,7 +8,6 @@ #include "util/osp_logging.h" namespace openscreen { -namespace crypto { std::array<uint8_t, 16> GenerateRandomBytes16() { std::array<uint8_t, 16> result; @@ -21,5 +20,4 @@ void GenerateRandomBytes(uint8_t* out, int len) { OSP_CHECK(RAND_bytes(out, len) == 1); } -} // namespace crypto } // namespace openscreen diff --git a/util/crypto/random_bytes.h b/util/crypto/random_bytes.h index be7381f0..3cb2fa8e 100644 --- a/util/crypto/random_bytes.h +++ b/util/crypto/random_bytes.h @@ -8,12 +8,10 @@ #include <array> namespace openscreen { -namespace crypto { std::array<uint8_t, 16> GenerateRandomBytes16(); void GenerateRandomBytes(uint8_t* out, int len); -} // namespace crypto } // namespace openscreen #endif // UTIL_CRYPTO_RANDOM_BYTES_H_ diff --git a/util/crypto/random_bytes_unittest.cc b/util/crypto/random_bytes_unittest.cc index b42e3f08..8129d604 100644 --- a/util/crypto/random_bytes_unittest.cc +++ b/util/crypto/random_bytes_unittest.cc @@ -10,7 +10,6 @@ #include "gtest/gtest.h" namespace openscreen { -namespace crypto { namespace { struct NonZero { @@ -48,5 +47,4 @@ TEST(RandomBytesTest, KeysAreNotIdentical) { std::end(keys)); } -} // namespace crypto } // namespace openscreen |