diff options
Diffstat (limited to 'net/socket')
22 files changed, 656 insertions, 187 deletions
diff --git a/net/socket/client_socket_pool_manager.cc b/net/socket/client_socket_pool_manager.cc index 9e0107dc0..34f87868c 100644 --- a/net/socket/client_socket_pool_manager.cc +++ b/net/socket/client_socket_pool_manager.cc @@ -284,9 +284,6 @@ int PreconnectSocketsForHttpRequest( const NetLogWithSource& net_log, int num_preconnect_streams, CompletionOnceCallback callback) { - // QUIC proxies are currently not supported through this method. - DCHECK(proxy_info.is_direct() || !proxy_info.proxy_chain().Last().is_quic()); - // Expect websocket schemes (ws and wss) to be converted to the http(s) // equivalent. DCHECK(endpoint.scheme() == url::kHttpScheme || diff --git a/net/socket/connect_job_factory.cc b/net/socket/connect_job_factory.cc index fa437fb3e..860da9f1a 100644 --- a/net/socket/connect_job_factory.cc +++ b/net/socket/connect_job_factory.cc @@ -131,36 +131,29 @@ std::unique_ptr<ConnectJob> ConnectJobFactory::CreateConnectJob( disable_cert_network_fetches, common_connect_job_params, proxy_dns_network_anonymization_key_); - if (holds_alternative<scoped_refptr<SSLSocketParams>>(connect_job_params)) { + if (connect_job_params.is_ssl()) { return ssl_connect_job_factory_->Create( request_priority, socket_tag, common_connect_job_params, - get<scoped_refptr<SSLSocketParams>>(std::move(connect_job_params)), - delegate, /*net_log=*/nullptr); + connect_job_params.take_ssl(), delegate, /*net_log=*/nullptr); } - if (holds_alternative<scoped_refptr<TransportSocketParams>>( - connect_job_params)) { + if (connect_job_params.is_transport()) { return transport_connect_job_factory_->Create( request_priority, socket_tag, common_connect_job_params, - get<scoped_refptr<TransportSocketParams>>( - std::move(connect_job_params)), - delegate, /*net_log=*/nullptr); + connect_job_params.take_transport(), delegate, /*net_log=*/nullptr); } - if (holds_alternative<scoped_refptr<HttpProxySocketParams>>( - connect_job_params)) { + if (connect_job_params.is_http_proxy()) { return http_proxy_connect_job_factory_->Create( request_priority, socket_tag, common_connect_job_params, - get<scoped_refptr<HttpProxySocketParams>>(connect_job_params), delegate, + connect_job_params.take_http_proxy(), delegate, /*net_log=*/nullptr); } - CHECK( - holds_alternative<scoped_refptr<SOCKSSocketParams>>(connect_job_params)); + CHECK(connect_job_params.is_socks()); return socks_connect_job_factory_->Create( request_priority, socket_tag, common_connect_job_params, - get<scoped_refptr<SOCKSSocketParams>>(std::move(connect_job_params)), - delegate, /*net_log=*/nullptr); + connect_job_params.take_socks(), delegate, /*net_log=*/nullptr); } } // namespace net diff --git a/net/socket/connect_job_params.cc b/net/socket/connect_job_params.cc new file mode 100644 index 000000000..9efdd40ce --- /dev/null +++ b/net/socket/connect_job_params.cc @@ -0,0 +1,31 @@ +// Copyright 2024 The Chromium Authors +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/connect_job_params.h" + +#include "net/http/http_proxy_connect_job.h" +#include "net/socket/socks_connect_job.h" +#include "net/socket/ssl_connect_job.h" +#include "net/socket/transport_connect_job.h" + +namespace net { + +ConnectJobParams::ConnectJobParams() = default; +ConnectJobParams::ConnectJobParams(scoped_refptr<HttpProxySocketParams> params) + : params_(params) {} +ConnectJobParams::ConnectJobParams(scoped_refptr<SOCKSSocketParams> params) + : params_(params) {} +ConnectJobParams::ConnectJobParams(scoped_refptr<TransportSocketParams> params) + : params_(params) {} +ConnectJobParams::ConnectJobParams(scoped_refptr<SSLSocketParams> params) + : params_(params) {} + +ConnectJobParams::~ConnectJobParams() = default; + +ConnectJobParams::ConnectJobParams(ConnectJobParams&) = default; +ConnectJobParams& ConnectJobParams::operator=(ConnectJobParams&) = default; +ConnectJobParams::ConnectJobParams(ConnectJobParams&&) = default; +ConnectJobParams& ConnectJobParams::operator=(ConnectJobParams&&) = default; + +} // namespace net diff --git a/net/socket/connect_job_params.h b/net/socket/connect_job_params.h new file mode 100644 index 000000000..d89655cba --- /dev/null +++ b/net/socket/connect_job_params.h @@ -0,0 +1,90 @@ +// Copyright 2024 The Chromium Authors +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_CONNECT_JOB_PARAMS_H_ +#define NET_SOCKET_CONNECT_JOB_PARAMS_H_ + +#include "base/memory/scoped_refptr.h" +#include "net/base/net_export.h" +#include "third_party/abseil-cpp/absl/types/variant.h" + +namespace net { + +class HttpProxySocketParams; +class SOCKSSocketParams; +class TransportSocketParams; +class SSLSocketParams; + +// Abstraction over the param types for various connect jobs. +class NET_EXPORT_PRIVATE ConnectJobParams { + public: + ConnectJobParams(); + explicit ConnectJobParams(scoped_refptr<HttpProxySocketParams> params); + explicit ConnectJobParams(scoped_refptr<SOCKSSocketParams> params); + explicit ConnectJobParams(scoped_refptr<TransportSocketParams> params); + explicit ConnectJobParams(scoped_refptr<SSLSocketParams> params); + ~ConnectJobParams(); + + ConnectJobParams(ConnectJobParams&); + ConnectJobParams& operator=(ConnectJobParams&); + ConnectJobParams(ConnectJobParams&&); + ConnectJobParams& operator=(ConnectJobParams&&); + + bool is_http_proxy() const { + return absl::holds_alternative<scoped_refptr<HttpProxySocketParams>>( + params_); + } + + bool is_socks() const { + return absl::holds_alternative<scoped_refptr<SOCKSSocketParams>>(params_); + } + + bool is_transport() const { + return absl::holds_alternative<scoped_refptr<TransportSocketParams>>( + params_); + } + + bool is_ssl() const { + return absl::holds_alternative<scoped_refptr<SSLSocketParams>>(params_); + } + + // Get lvalue references to the contained params. + const scoped_refptr<HttpProxySocketParams>& http_proxy() const { + return get<scoped_refptr<HttpProxySocketParams>>(params_); + } + const scoped_refptr<SOCKSSocketParams>& socks() const { + return get<scoped_refptr<SOCKSSocketParams>>(params_); + } + const scoped_refptr<TransportSocketParams>& transport() const { + return get<scoped_refptr<TransportSocketParams>>(params_); + } + const scoped_refptr<SSLSocketParams>& ssl() const { + return get<scoped_refptr<SSLSocketParams>>(params_); + } + + // Take params out of this value. + scoped_refptr<HttpProxySocketParams>&& take_http_proxy() { + return get<scoped_refptr<HttpProxySocketParams>>(std::move(params_)); + } + scoped_refptr<SOCKSSocketParams>&& take_socks() { + return get<scoped_refptr<SOCKSSocketParams>>(std::move(params_)); + } + scoped_refptr<TransportSocketParams>&& take_transport() { + return get<scoped_refptr<TransportSocketParams>>(std::move(params_)); + } + scoped_refptr<SSLSocketParams>&& take_ssl() { + return get<scoped_refptr<SSLSocketParams>>(std::move(params_)); + } + + private: + absl::variant<scoped_refptr<HttpProxySocketParams>, + scoped_refptr<SOCKSSocketParams>, + scoped_refptr<TransportSocketParams>, + scoped_refptr<SSLSocketParams>> + params_; +}; + +} // namespace net + +#endif // NET_SOCKET_CONNECT_JOB_PARAMS_H_ diff --git a/net/socket/connect_job_params_factory.cc b/net/socket/connect_job_params_factory.cc index 99a26835f..53cd4a807 100644 --- a/net/socket/connect_job_params_factory.cc +++ b/net/socket/connect_job_params_factory.cc @@ -18,6 +18,7 @@ #include "net/base/request_priority.h" #include "net/dns/public/secure_dns_policy.h" #include "net/http/http_proxy_connect_job.h" +#include "net/socket/connect_job_params.h" #include "net/socket/next_proto.h" #include "net/socket/socket_tag.h" #include "net/socket/socks_connect_job.h" @@ -140,53 +141,13 @@ bool UsingSsl(const ConnectJobFactory::Endpoint& endpoint) { return absl::get<ConnectJobFactory::SchemelessEndpoint>(endpoint).using_ssl; } -scoped_refptr<HttpProxySocketParams> MaybeHttpProxySocketParams( - const ConnectJobParams& params) { - if (auto p = get_if<scoped_refptr<HttpProxySocketParams>>(¶ms)) { - return *p; - } - return nullptr; -} - -scoped_refptr<SOCKSSocketParams> MaybeSOCKSSocketParams( - const ConnectJobParams& params) { - if (auto p = get_if<scoped_refptr<SOCKSSocketParams>>(¶ms)) { - return *p; - } - return nullptr; -} - -scoped_refptr<TransportSocketParams> MaybeTransportSocketParams( - const ConnectJobParams& params) { - if (auto p = get_if<scoped_refptr<TransportSocketParams>>(¶ms)) { - return *p; - } - return nullptr; -} - -scoped_refptr<SSLSocketParams> MaybeSSLSocketParams( - const ConnectJobParams& params) { - if (auto p = get_if<scoped_refptr<SSLSocketParams>>(¶ms)) { - return *p; - } - return nullptr; -} - ConnectJobParams MakeSSLSocketParams( ConnectJobParams params, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const NetworkAnonymizationKey& network_anonymization_key) { - scoped_refptr<TransportSocketParams> transport_socket_params = - MaybeTransportSocketParams(params); - scoped_refptr<HttpProxySocketParams> http_proxy_socket_params = - MaybeHttpProxySocketParams(params); - scoped_refptr<SOCKSSocketParams> socks_socket_params = - MaybeSOCKSSocketParams(params); return ConnectJobParams(base::MakeRefCounted<SSLSocketParams>( - std::move(transport_socket_params), std::move(socks_socket_params), - std::move(http_proxy_socket_params), host_and_port, ssl_config, - network_anonymization_key)); + std::move(params), host_and_port, ssl_config, network_anonymization_key)); } // Recursively generate the params for a proxy at `host_port_pair` and the given @@ -229,7 +190,21 @@ ConnectJobParams CreateProxyParams( // Create the nested parameters over which the connection to the proxy // will be made. ConnectJobParams params; - if (proxy_chain_index == 0) { + + if (proxy_server.is_quic()) { + // If this and all proxies earlier in the chain are QUIC, then we can hand + // off the remainder of the proxy connecting work to the QuicSocketPool, so + // no further recursion is required. If any proxies earlier in the chain are + // not QUIC, then the chain is unsupported. Such ProxyChains cannot be + // constructed, so this is just a double-check. + for (size_t i = 0; i < proxy_chain_index; i++) { + CHECK(proxy_chain.GetProxyServer(i).is_quic()); + } + return ConnectJobParams(base::MakeRefCounted<HttpProxySocketParams>( + std::move(proxy_server_ssl_config), host_port_pair, proxy_chain, + proxy_chain_index, should_tunnel, *proxy_annotation_tag, + network_anonymization_key, secure_dns_policy)); + } else if (proxy_chain_index == 0) { // At the beginning of the chain, create the only TransportSocketParams // object, corresponding to the transport socket we want to create to the // first proxy. @@ -241,10 +216,6 @@ ConnectJobParams CreateProxyParams( secure_dns_policy, resolution_callback, SupportedProtocolsFromSSLConfig(proxy_server_ssl_config))); } else { - // TODO(https://crbug.com/1491092): For now we will assume that proxy - // chains with multiple proxies must all use HTTPS. - CHECK(proxy_chain.GetProxyServer(proxy_chain_index - 1) - .is_secure_http_like()); params = CreateProxyParams( proxy_server.host_port_pair(), true, endpoint, proxy_chain, proxy_chain_index - 1, proxy_annotation_tag, resolution_callback, @@ -263,33 +234,18 @@ ConnectJobParams CreateProxyParams( // Further wrap the underlying connection params, or the SSL params wrapping // them, with the proxy params. if (proxy_server.is_http_like()) { - scoped_refptr<TransportSocketParams> transport_socket_params = - MaybeTransportSocketParams(params); - scoped_refptr<SSLSocketParams> ssl_socket_params = - MaybeSSLSocketParams(params); - std::optional<SSLConfig> quic_ssl_config; - if (proxy_server.is_quic()) { - // For QUIC, we only need the SSL config, not the full SSLSocketParams. - // A subsequent CL will remove the redundant SSLSocketParams creation. - quic_ssl_config = ssl_socket_params->ssl_config(); - ssl_socket_params = nullptr; - } + CHECK(!proxy_server.is_quic()); params = ConnectJobParams(base::MakeRefCounted<HttpProxySocketParams>( - std::move(transport_socket_params), std::move(ssl_socket_params), - std::move(quic_ssl_config), host_port_pair, proxy_chain, - proxy_chain_index, should_tunnel, *proxy_annotation_tag, - network_anonymization_key, secure_dns_policy)); + std::move(params), host_port_pair, proxy_chain, proxy_chain_index, + should_tunnel, *proxy_annotation_tag, network_anonymization_key, + secure_dns_policy)); } else { DCHECK(proxy_server.is_socks()); DCHECK_EQ(1u, proxy_chain.length()); - scoped_refptr<TransportSocketParams> transport_socket_params = - MaybeTransportSocketParams(params); - DCHECK(transport_socket_params); // TODO(crbug.com/1206799): Pass `endpoint` directly (preserving scheme // when available)? params = ConnectJobParams(base::MakeRefCounted<SOCKSSocketParams>( - std::move(transport_socket_params), - proxy_server.scheme() == ProxyServer::SCHEME_SOCKS5, + std::move(params), proxy_server.scheme() == ProxyServer::SCHEME_SOCKS5, ToHostPortPair(endpoint), network_anonymization_key, *proxy_annotation_tag)); } diff --git a/net/socket/connect_job_params_factory.h b/net/socket/connect_job_params_factory.h index 73394f72f..796f68dd7 100644 --- a/net/socket/connect_job_params_factory.h +++ b/net/socket/connect_job_params_factory.h @@ -17,6 +17,7 @@ #include "net/http/http_proxy_connect_job.h" #include "net/socket/connect_job.h" #include "net/socket/connect_job_factory.h" +#include "net/socket/connect_job_params.h" #include "net/socket/socks_connect_job.h" #include "net/socket/ssl_connect_job.h" #include "net/socket/transport_connect_job.h" @@ -29,12 +30,6 @@ struct NetworkTrafficAnnotationTag; class ProxyChain; struct SSLConfig; -// Abstraction over the param types for various connect jobs. -using ConnectJobParams = absl::variant<scoped_refptr<HttpProxySocketParams>, - scoped_refptr<SOCKSSocketParams>, - scoped_refptr<TransportSocketParams>, - scoped_refptr<SSLSocketParams>>; - NET_EXPORT_PRIVATE ConnectJobParams ConstructConnectJobParams( const ConnectJobFactory::Endpoint& endpoint, const ProxyChain& proxy_chain, diff --git a/net/socket/connect_job_params_factory_unittest.cc b/net/socket/connect_job_params_factory_unittest.cc index 7219455c1..715b536f1 100644 --- a/net/socket/connect_job_params_factory_unittest.cc +++ b/net/socket/connect_job_params_factory_unittest.cc @@ -74,16 +74,16 @@ std::ostream& operator<<(std::ostream& os, const TestParams& test_params) { // Get a string describing the params variant. const char* ParamsName(ConnectJobParams& params) { - if (absl::holds_alternative<scoped_refptr<HttpProxySocketParams>>(params)) { + if (params.is_http_proxy()) { return "HttpProxySocketParams"; } - if (absl::holds_alternative<scoped_refptr<SOCKSSocketParams>>(params)) { + if (params.is_socks()) { return "SOCKSSocketParams"; } - if (absl::holds_alternative<scoped_refptr<SSLSocketParams>>(params)) { + if (params.is_ssl()) { return "SSLSocketParams"; } - if (absl::holds_alternative<scoped_refptr<TransportSocketParams>>(params)) { + if (params.is_transport()) { return "TransportSocketParams"; } return "Unknown"; @@ -91,15 +91,16 @@ const char* ParamsName(ConnectJobParams& params) { scoped_refptr<HttpProxySocketParams> ExpectHttpProxySocketParams( ConnectJobParams params) { - EXPECT_TRUE( - absl::holds_alternative<scoped_refptr<HttpProxySocketParams>>(params)) + EXPECT_TRUE(params.is_http_proxy()) << "Expected HttpProxySocketParams, got " << ParamsName(params); - return absl::get<scoped_refptr<HttpProxySocketParams>>(params); + return params.take_http_proxy(); } void VerifyHttpProxySocketParams( scoped_refptr<HttpProxySocketParams> params, const char* description, + // Only QUIC proxies have a quic_ssl_config. + std::optional<SSLConfig> quic_ssl_config, const HostPortPair& endpoint, const ProxyChain& proxy_chain, size_t proxy_chain_index, @@ -107,6 +108,16 @@ void VerifyHttpProxySocketParams( const NetworkAnonymizationKey& network_anonymization_key, const SecureDnsPolicy secure_dns_policy) { SCOPED_TRACE(testing::Message() << "Verifying " << description); + if (quic_ssl_config) { + // Only examine the values used for QUIC connections. + ASSERT_TRUE(params->quic_ssl_config().has_value()); + EXPECT_EQ(params->quic_ssl_config()->privacy_mode, + quic_ssl_config->privacy_mode); + EXPECT_EQ(params->quic_ssl_config()->GetCertVerifyFlags(), + quic_ssl_config->GetCertVerifyFlags()); + } else { + EXPECT_FALSE(params->quic_ssl_config().has_value()); + } EXPECT_EQ(params->endpoint(), endpoint); EXPECT_EQ(params->proxy_chain(), proxy_chain); EXPECT_EQ(params->proxy_chain_index(), proxy_chain_index); @@ -117,9 +128,9 @@ void VerifyHttpProxySocketParams( scoped_refptr<SOCKSSocketParams> ExpectSOCKSSocketParams( ConnectJobParams params) { - EXPECT_TRUE(absl::holds_alternative<scoped_refptr<SOCKSSocketParams>>(params)) + EXPECT_TRUE(params.is_socks()) << "Expected SOCKSSocketParams, got " << ParamsName(params); - return absl::get<scoped_refptr<SOCKSSocketParams>>(params); + return params.take_socks(); } // Verify the properties of SOCKSSocketParams. @@ -138,10 +149,9 @@ void VerifySOCKSSocketParams( // Assert that the params are TransportSocketParams and return them. scoped_refptr<TransportSocketParams> ExpectTransportSocketParams( ConnectJobParams params) { - EXPECT_TRUE( - absl::holds_alternative<scoped_refptr<TransportSocketParams>>(params)) + EXPECT_TRUE(params.is_transport()) << "Expected TransportSocketParams, got " << ParamsName(params); - return absl::get<scoped_refptr<TransportSocketParams>>(params); + return params.take_transport(); } // Verify the properties of TransportSocketParams. @@ -161,9 +171,9 @@ void VerifyTransportSocketParams( // Assert that the params are SSLSocketParams and return them. scoped_refptr<SSLSocketParams> ExpectSSLSocketParams(ConnectJobParams params) { - EXPECT_TRUE(absl::holds_alternative<scoped_refptr<SSLSocketParams>>(params)) + EXPECT_TRUE(params.is_ssl()) << "Expected SSLSocketParams, got " << ParamsName(params); - return absl::get<scoped_refptr<SSLSocketParams>>(params); + return params.take_ssl(); } // Verify the properties of SSLSocketParams. @@ -441,6 +451,7 @@ TEST_P(ConnectJobParamsFactoryTest, HttpEndpointViaHttpsProxy) { ExpectHttpProxySocketParams(params); VerifyHttpProxySocketParams( http_proxy_socket_params, "http_proxy_socket_params", + /*quic_ssl_config=*/std::nullopt, HostPortPair::FromSchemeHostPort(kEndpoint), proxy_chain, /*proxy_chain_index=*/0, /*tunnel=*/false, kTestNak, secure_dns_policy()); @@ -461,6 +472,32 @@ TEST_P(ConnectJobParamsFactoryTest, HttpEndpointViaHttpsProxy) { AlpnProtoStringsForMode(alpn_mode())); } +// A connection to an HTTP endpoint via an QUIC proxy sets up an +// HttpProxySocketParams, wrapping almost-unused SSLSocketParams, intending to +// use GET to the proxy. This is not tunneled. +TEST_P(ConnectJobParamsFactoryTest, HttpEndpointViaQuicProxy) { + const url::SchemeHostPort kEndpoint(url::kHttpScheme, "test", 82); + ProxyChain proxy_chain = ProxyChain::ForIpProtection({ + ProxyServer::FromSchemeHostAndPort(ProxyServer::SCHEME_QUIC, "proxy", + 443), + }); + ConnectJobParams params = ConstructConnectJobParams( + kEndpoint, proxy_chain, TRAFFIC_ANNOTATION_FOR_TESTS, + /*allowed_bad_certs=*/{}, alpn_mode(), + /*force_tunnel=*/false, privacy_mode(), OnHostResolutionCallback(), + kTestNak, secure_dns_policy(), disable_cert_network_fetches(), + &common_connect_job_params_, kProxyDnsNak); + + auto http_proxy_socket_params = ExpectHttpProxySocketParams(params); + SSLConfig quic_ssl_config = SSLConfigForProxy(); + // Traffic always tunnels over QUIC proxies. + const bool tunnel = true; + VerifyHttpProxySocketParams( + http_proxy_socket_params, "http_proxy_socket_params", quic_ssl_config, + HostPortPair::FromSchemeHostPort(kEndpoint), proxy_chain, + /*proxy_chain_index=*/0, tunnel, kTestNak, secure_dns_policy()); +} + // A connection to an HTTPS endpoint via an HTTPS proxy, // sets up an SSLSocketParams, wrapping HttpProxySocketParams, wrapping // SSLSocketParams, wrapping TransportSocketParams. This is always tunneled. @@ -492,6 +529,7 @@ TEST_P(ConnectJobParamsFactoryTest, HttpsEndpointViaHttpsProxy) { endpoint_ssl_socket_params->GetHttpProxyConnectionParams(); VerifyHttpProxySocketParams( http_proxy_socket_params, "http_proxy_socket_params", + /*quic_ssl_config=*/std::nullopt, HostPortPair::FromSchemeHostPort(kEndpoint), proxy_chain, /*proxy_chain_index=*/0, /*tunnel=*/true, kTestNak, secure_dns_policy()); @@ -512,6 +550,44 @@ TEST_P(ConnectJobParamsFactoryTest, HttpsEndpointViaHttpsProxy) { AlpnProtoStringsForMode(alpn_mode())); } +// A connection to an HTTPS endpoint via a QUIC proxy, +// sets up an SSLSocketParams, wrapping HttpProxySocketParams, wrapping +// SSLSocketParams. This is always tunneled. +TEST_P(ConnectJobParamsFactoryTest, HttpsEndpointViaQuicProxy) { + // HTTPS endpoints are not supported without ALPN. + if (alpn_mode() == ConnectJobFactory::AlpnMode::kDisabled) { + return; + } + + const url::SchemeHostPort kEndpoint(url::kHttpsScheme, "test", 82); + ProxyChain proxy_chain = ProxyChain::ForIpProtection({ + ProxyServer::FromSchemeHostAndPort(ProxyServer::SCHEME_QUIC, "proxy", + 443), + }); + ConnectJobParams params = ConstructConnectJobParams( + kEndpoint, proxy_chain, TRAFFIC_ANNOTATION_FOR_TESTS, + /*allowed_bad_certs=*/{}, alpn_mode(), + /*force_tunnel=*/false, privacy_mode(), OnHostResolutionCallback(), + kTestNak, secure_dns_policy(), disable_cert_network_fetches(), + &common_connect_job_params_, kProxyDnsNak); + + auto endpoint_ssl_socket_params = ExpectSSLSocketParams(params); + SSLConfig endpoint_ssl_config = SSLConfigForEndpoint(); + VerifySSLSocketParams(endpoint_ssl_socket_params, + "endpoint_ssl_socket_params", + HostPortPair::FromSchemeHostPort(kEndpoint), + endpoint_ssl_config, privacy_mode(), kTestNak); + + auto http_proxy_socket_params = + endpoint_ssl_socket_params->GetHttpProxyConnectionParams(); + SSLConfig quic_ssl_config = SSLConfigForProxy(); + VerifyHttpProxySocketParams( + http_proxy_socket_params, "http_proxy_socket_params", quic_ssl_config, + HostPortPair::FromSchemeHostPort(kEndpoint), proxy_chain, + /*proxy_chain_index=*/0, + /*tunnel=*/true, kTestNak, secure_dns_policy()); +} + // A connection to an HTTPS endpoint via an HTTP proxy // sets up an SSLSocketParams, wrapping HttpProxySocketParams, wrapping // TransportSocketParams. This is always tunneled. @@ -542,6 +618,7 @@ TEST_P(ConnectJobParamsFactoryTest, HttpsEndpointViaHttpProxy) { endpoint_ssl_socket_params->GetHttpProxyConnectionParams(); VerifyHttpProxySocketParams( http_proxy_socket_params, "http_proxy_socket_params", + /*quic_ssl_config=*/std::nullopt, HostPortPair::FromSchemeHostPort(kEndpoint), proxy_chain, /*proxy_chain_index=*/0, /*tunnel=*/true, kTestNak, secure_dns_policy()); @@ -644,6 +721,7 @@ TEST_P(ConnectJobParamsFactoryTest, HttpEndpointViaHttpsProxyViaHttpsProxy) { ExpectHttpProxySocketParams(params); VerifyHttpProxySocketParams( http_proxy_socket_params_b, "http_proxy_socket_params_b", + /*quic_ssl_config=*/std::nullopt, HostPortPair::FromSchemeHostPort(kEndpoint), proxy_chain, /*proxy_chain_index=*/1, /*tunnel=*/true, kTestNak, secure_dns_policy()); @@ -661,6 +739,7 @@ TEST_P(ConnectJobParamsFactoryTest, HttpEndpointViaHttpsProxyViaHttpsProxy) { proxy_ssl_socket_params_b->GetHttpProxyConnectionParams(); VerifyHttpProxySocketParams(http_proxy_socket_params_a, "http_proxy_socket_params_a", + /*quic_ssl_config=*/std::nullopt, HostPortPair("proxyb", 443), proxy_chain, /*proxy_chain_index=*/0, /*tunnel=*/true, kTestNak, secure_dns_policy()); @@ -715,6 +794,7 @@ TEST_P(ConnectJobParamsFactoryTest, HttpsEndpointViaHttpsProxyViaHttpsProxy) { endpoint_ssl_socket_params->GetHttpProxyConnectionParams(); VerifyHttpProxySocketParams( http_proxy_socket_params_b, "http_proxy_socket_params_b", + /*quic_ssl_config=*/std::nullopt, HostPortPair::FromSchemeHostPort(kEndpoint), proxy_chain, /*proxy_chain_index=*/1, /*tunnel=*/true, kTestNak, secure_dns_policy()); @@ -732,6 +812,7 @@ TEST_P(ConnectJobParamsFactoryTest, HttpsEndpointViaHttpsProxyViaHttpsProxy) { proxy_ssl_socket_params_b->GetHttpProxyConnectionParams(); VerifyHttpProxySocketParams(http_proxy_socket_params_a, "http_proxy_socket_params_a", + /*quic_ssl_config=*/std::nullopt, HostPortPair("proxyb", 443), proxy_chain, /*proxy_chain_index=*/0, /*tunnel=*/true, kTestNak, secure_dns_policy()); @@ -752,6 +833,124 @@ TEST_P(ConnectJobParamsFactoryTest, HttpsEndpointViaHttpsProxyViaHttpsProxy) { AlpnProtoStringsForMode(alpn_mode())); } +// A connection to an HTTPS endpoint via a two-proxy QUIC chain +// sets up the required parameters. +TEST_P(ConnectJobParamsFactoryTest, HttpsEndpointViaQuicProxyViaQuicProxy) { + // HTTPS endpoints are not supported without ALPN. + if (alpn_mode() == ConnectJobFactory::AlpnMode::kDisabled) { + return; + } + + const url::SchemeHostPort kEndpoint(url::kHttpsScheme, "test", 82); + ProxyChain proxy_chain = ProxyChain::ForIpProtection({ + ProxyServer::FromSchemeHostAndPort(ProxyServer::SCHEME_QUIC, "proxya", + 443), + ProxyServer::FromSchemeHostAndPort(ProxyServer::SCHEME_QUIC, "proxyb", + 443), + }); + ConnectJobParams params = ConstructConnectJobParams( + kEndpoint, proxy_chain, TRAFFIC_ANNOTATION_FOR_TESTS, + /*allowed_bad_certs=*/{}, alpn_mode(), + /*force_tunnel=*/false, privacy_mode(), OnHostResolutionCallback(), + kTestNak, secure_dns_policy(), disable_cert_network_fetches(), + &common_connect_job_params_, kProxyDnsNak); + + auto endpoint_ssl_socket_params = ExpectSSLSocketParams(params); + SSLConfig endpoint_ssl_config = SSLConfigForEndpoint(); + VerifySSLSocketParams(endpoint_ssl_socket_params, + "endpoint_ssl_socket_params", + HostPortPair::FromSchemeHostPort(kEndpoint), + endpoint_ssl_config, privacy_mode(), kTestNak); + + auto http_proxy_socket_params_b = + endpoint_ssl_socket_params->GetHttpProxyConnectionParams(); + SSLConfig quic_ssl_config_b = SSLConfigForProxy(); + VerifyHttpProxySocketParams(http_proxy_socket_params_b, + "http_proxy_socket_params_b", quic_ssl_config_b, + HostPortPair::FromSchemeHostPort(kEndpoint), + proxy_chain, + /*proxy_chain_index=*/1, + /*tunnel=*/true, kTestNak, secure_dns_policy()); +} + +// A connection to an HTTPS endpoint via a proxy chain with two HTTPS proxies +// and two QUIC proxies. +TEST_P(ConnectJobParamsFactoryTest, HttpsEndpointViaMixedProxyChain) { + // HTTPS endpoints are not supported without ALPN. + if (alpn_mode() == ConnectJobFactory::AlpnMode::kDisabled) { + return; + } + + const url::SchemeHostPort kEndpoint(url::kHttpsScheme, "test", 82); + ProxyChain proxy_chain = ProxyChain::ForIpProtection({ + ProxyServer::FromSchemeHostAndPort(ProxyServer::SCHEME_QUIC, "proxya", + 443), + ProxyServer::FromSchemeHostAndPort(ProxyServer::SCHEME_QUIC, "proxyb", + 443), + ProxyServer::FromSchemeHostAndPort(ProxyServer::SCHEME_HTTPS, "proxyc", + 443), + ProxyServer::FromSchemeHostAndPort(ProxyServer::SCHEME_HTTPS, "proxyd", + 443), + }); + ConnectJobParams params = ConstructConnectJobParams( + kEndpoint, proxy_chain, TRAFFIC_ANNOTATION_FOR_TESTS, + /*allowed_bad_certs=*/{}, alpn_mode(), + /*force_tunnel=*/false, privacy_mode(), OnHostResolutionCallback(), + kTestNak, secure_dns_policy(), disable_cert_network_fetches(), + &common_connect_job_params_, kProxyDnsNak); + + auto endpoint_ssl_socket_params = ExpectSSLSocketParams(params); + SSLConfig endpoint_ssl_config = SSLConfigForEndpoint(); + VerifySSLSocketParams(endpoint_ssl_socket_params, + "endpoint_ssl_socket_params", + HostPortPair::FromSchemeHostPort(kEndpoint), + endpoint_ssl_config, privacy_mode(), kTestNak); + + scoped_refptr<HttpProxySocketParams> http_proxy_socket_params_d = + endpoint_ssl_socket_params->GetHttpProxyConnectionParams(); + VerifyHttpProxySocketParams( + http_proxy_socket_params_d, "http_proxy_socket_params_d", + /*quic_ssl_config=*/std::nullopt, + HostPortPair::FromSchemeHostPort(kEndpoint), proxy_chain, + /*proxy_chain_index=*/3, + /*tunnel=*/true, kTestNak, secure_dns_policy()); + + scoped_refptr<SSLSocketParams> proxy_ssl_socket_params_d = + http_proxy_socket_params_d->ssl_params(); + ASSERT_TRUE(proxy_ssl_socket_params_d); + SSLConfig proxy_ssl_config = SSLConfigForProxy(); + VerifySSLSocketParams(proxy_ssl_socket_params_d, "proxy_ssl_socket_params_d", + HostPortPair::FromString("proxyd:443"), + proxy_ssl_config, PrivacyMode::PRIVACY_MODE_DISABLED, + kTestNak); + + scoped_refptr<HttpProxySocketParams> http_proxy_socket_params_c = + proxy_ssl_socket_params_d->GetHttpProxyConnectionParams(); + VerifyHttpProxySocketParams(http_proxy_socket_params_c, + "http_proxy_socket_params_c", + /*quic_ssl_config=*/std::nullopt, + HostPortPair("proxyd", 443), proxy_chain, + /*proxy_chain_index=*/2, + /*tunnel=*/true, kTestNak, secure_dns_policy()); + + scoped_refptr<SSLSocketParams> proxy_ssl_socket_params_c = + http_proxy_socket_params_c->ssl_params(); + ASSERT_TRUE(proxy_ssl_socket_params_c); + VerifySSLSocketParams(proxy_ssl_socket_params_c, "proxy_ssl_socket_params_c", + HostPortPair::FromString("proxyc:443"), + proxy_ssl_config, PrivacyMode::PRIVACY_MODE_DISABLED, + kTestNak); + + auto http_proxy_socket_params_b = + proxy_ssl_socket_params_c->GetHttpProxyConnectionParams(); + SSLConfig quic_ssl_config_b = SSLConfigForProxy(); + VerifyHttpProxySocketParams(http_proxy_socket_params_b, + "http_proxy_socket_params_b", quic_ssl_config_b, + HostPortPair("proxyc", 443), proxy_chain, + /*proxy_chain_index=*/1, + /*tunnel=*/true, kTestNak, secure_dns_policy()); +} + INSTANTIATE_TEST_SUITE_P( All, ConnectJobParamsFactoryTest, diff --git a/net/socket/socket_descriptor.cc b/net/socket/socket_descriptor.cc index f6db3d82e..3e252d07d 100644 --- a/net/socket/socket_descriptor.cc +++ b/net/socket/socket_descriptor.cc @@ -8,6 +8,7 @@ #if BUILDFLAG(IS_WIN) #include <ws2tcpip.h> + #include "net/base/winsock_init.h" #elif BUILDFLAG(IS_POSIX) || BUILDFLAG(IS_FUCHSIA) #include <sys/socket.h> diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index 40e3d2ea5..e394cd42a 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -1470,7 +1470,6 @@ MockSSLClientSocket::GetPeerApplicationSettings() const { } bool MockSSLClientSocket::GetSSLInfo(SSLInfo* requested_ssl_info) { - requested_ssl_info->Reset(); *requested_ssl_info = data_->ssl_info; return true; } diff --git a/net/socket/socks_connect_job.cc b/net/socket/socks_connect_job.cc index b920c28af..d8ed2e825 100644 --- a/net/socket/socks_connect_job.cc +++ b/net/socket/socks_connect_job.cc @@ -13,6 +13,7 @@ #include "net/log/net_log_with_source.h" #include "net/socket/client_socket_factory.h" #include "net/socket/client_socket_handle.h" +#include "net/socket/connect_job_params.h" #include "net/socket/socks5_client_socket.h" #include "net/socket/socks_client_socket.h" #include "net/socket/transport_connect_job.h" @@ -23,12 +24,12 @@ namespace net { static constexpr base::TimeDelta kSOCKSConnectJobTimeout = base::Seconds(30); SOCKSSocketParams::SOCKSSocketParams( - scoped_refptr<TransportSocketParams> proxy_server_params, + ConnectJobParams nested_params, bool socks_v5, const HostPortPair& host_port_pair, const NetworkAnonymizationKey& network_anonymization_key, const NetworkTrafficAnnotationTag& traffic_annotation) - : transport_params_(std::move(proxy_server_params)), + : transport_params_(nested_params.take_transport()), destination_(host_port_pair), socks_v5_(socks_v5), network_anonymization_key_(network_anonymization_key), diff --git a/net/socket/socks_connect_job.h b/net/socket/socks_connect_job.h index 5e9de9c7d..3fc7dc02d 100644 --- a/net/socket/socks_connect_job.h +++ b/net/socket/socks_connect_job.h @@ -17,6 +17,7 @@ #include "net/base/request_priority.h" #include "net/dns/public/resolve_error_info.h" #include "net/socket/connect_job.h" +#include "net/socket/connect_job_params.h" #include "net/socket/socks_client_socket.h" #include "net/traffic_annotation/network_traffic_annotation.h" @@ -29,7 +30,7 @@ class TransportSocketParams; class NET_EXPORT_PRIVATE SOCKSSocketParams : public base::RefCounted<SOCKSSocketParams> { public: - SOCKSSocketParams(scoped_refptr<TransportSocketParams> proxy_server_params, + SOCKSSocketParams(ConnectJobParams nested_params, bool socks_v5, const HostPortPair& host_port_pair, const NetworkAnonymizationKey& network_anonymization_key, diff --git a/net/socket/socks_connect_job_unittest.cc b/net/socket/socks_connect_job_unittest.cc index 830015e49..6821b8254 100644 --- a/net/socket/socks_connect_job_unittest.cc +++ b/net/socket/socks_connect_job_unittest.cc @@ -77,10 +77,10 @@ class SOCKSConnectJobTest : public testing::Test, public WithTaskEnvironment { SOCKSVersion socks_version, SecureDnsPolicy secure_dns_policy = SecureDnsPolicy::kAllow) { return base::MakeRefCounted<SOCKSSocketParams>( - base::MakeRefCounted<TransportSocketParams>( + ConnectJobParams(base::MakeRefCounted<TransportSocketParams>( HostPortPair(kProxyHostName, kProxyPort), NetworkAnonymizationKey(), secure_dns_policy, OnHostResolutionCallback(), - /*supported_alpns=*/base::flat_set<std::string>()), + /*supported_alpns=*/base::flat_set<std::string>())), socks_version == SOCKSVersion::V5, socks_version == SOCKSVersion::V4 ? HostPortPair(kSOCKS4TestHost, kSOCKS4TestPort) @@ -126,11 +126,11 @@ TEST_F(SOCKSConnectJobTest, HostResolutionFailureSOCKS4Endpoint) { scoped_refptr<SOCKSSocketParams> socket_params = base::MakeRefCounted<SOCKSSocketParams>( - base::MakeRefCounted<TransportSocketParams>( + ConnectJobParams(base::MakeRefCounted<TransportSocketParams>( HostPortPair(kProxyHostName, kProxyPort), NetworkAnonymizationKey(), SecureDnsPolicy::kAllow, OnHostResolutionCallback(), - /*supported_alpns=*/base::flat_set<std::string>()), + /*supported_alpns=*/base::flat_set<std::string>())), false /* socks_v5 */, HostPortPair(hostname, kSOCKS4TestPort), NetworkAnonymizationKey(), TRAFFIC_ANNOTATION_FOR_TESTS); diff --git a/net/socket/ssl_client_socket.cc b/net/socket/ssl_client_socket.cc index c207b3aef..f5ac37f5d 100644 --- a/net/socket/ssl_client_socket.cc +++ b/net/socket/ssl_client_socket.cc @@ -25,10 +25,13 @@ namespace { // Returns true if |first_cert| and |second_cert| represent the same certificate // (with the same chain), or if they're both NULL. bool AreCertificatesEqual(const scoped_refptr<X509Certificate>& first_cert, - const scoped_refptr<X509Certificate>& second_cert) { + const scoped_refptr<X509Certificate>& second_cert, + bool include_chain = true) { return (!first_cert && !second_cert) || (first_cert && second_cert && - first_cert->EqualsIncludingChain(second_cert.get())); + (include_chain + ? first_cert->EqualsIncludingChain(second_cert.get()) + : first_cert->EqualsExcludingChain(second_cert.get()))); } // Returns a base::Value::Dict value NetLog parameter with the expected format @@ -44,6 +47,22 @@ base::Value::Dict NetLogClearCachedClientCertParams( .Set("is_cleared", is_cleared); } +// Returns a base::Value::Dict value NetLog parameter with the expected format +// for events of type CLEAR_MATCHING_CACHED_CLIENT_CERT. +base::Value::Dict NetLogClearMatchingCachedClientCertParams( + const base::flat_set<net::HostPortPair>& hosts, + const scoped_refptr<net::X509Certificate>& cert) { + base::Value::List hosts_values; + for (const auto& host : hosts) { + hosts_values.Append(host.ToString()); + } + + return base::Value::Dict() + .Set("hosts", base::Value(std::move(hosts_values))) + .Set("certificates", cert ? net::NetLogX509CertificateList(cert.get()) + : base::Value(base::Value::List())); +} + } // namespace SSLClientSocket::SSLClientSocket() = default; @@ -216,6 +235,43 @@ void SSLClientContext::ClearClientCertificateIfNeeded( NotifySSLConfigForServersChanged({host}); } +void SSLClientContext::ClearMatchingClientCertificate( + const scoped_refptr<net::X509Certificate>& certificate) { + CHECK(certificate); + + base::flat_set<HostPortPair> cleared_servers; + for (const auto& server : ssl_client_auth_cache_.GetCachedServers()) { + scoped_refptr<X509Certificate> cached_certificate; + scoped_refptr<SSLPrivateKey> cached_private_key; + if (ssl_client_auth_cache_.Lookup(server, &cached_certificate, + &cached_private_key) && + AreCertificatesEqual(cached_certificate, certificate, + /*include_chain=*/false)) { + cleared_servers.insert(cleared_servers.end(), server); + } + } + + net::NetLog::Get()->AddGlobalEntry( + NetLogEventType::CLEAR_MATCHING_CACHED_CLIENT_CERT, [&]() { + return NetLogClearMatchingCachedClientCertParams(cleared_servers, + certificate); + }); + + if (cleared_servers.empty()) { + return; + } + + for (const auto& server_to_clear : cleared_servers) { + ssl_client_auth_cache_.Remove(server_to_clear); + } + + if (ssl_client_session_cache_) { + ssl_client_session_cache_->FlushForServers(cleared_servers); + } + + NotifySSLConfigForServersChanged(cleared_servers); +} + void SSLClientContext::NotifySSLConfigChanged(SSLConfigChangeType change_type) { for (Observer& observer : observers_) { observer.OnSSLConfigChanged(change_type); diff --git a/net/socket/ssl_client_socket.h b/net/socket/ssl_client_socket.h index 6b5a991f9..b63c471a9 100644 --- a/net/socket/ssl_client_socket.h +++ b/net/socket/ssl_client_socket.h @@ -188,6 +188,14 @@ class NET_EXPORT SSLClientContext : public SSLConfigService::Observer, const net::HostPortPair& host, const scoped_refptr<net::X509Certificate>& certificate); + // Clears a client certificate preference, set by SetClientCertificate(), + // for all hosts whose cached certificate matches |certificate|. + // + // Note this method will synchronously call OnSSLConfigForServersChanged() on + // observers. + void ClearMatchingClientCertificate( + const scoped_refptr<net::X509Certificate>& certificate); + base::flat_set<HostPortPair> GetClientCertificateCachedServersForTesting() const { return ssl_client_auth_cache_.GetCachedServers(); diff --git a/net/socket/ssl_client_socket_unittest.cc b/net/socket/ssl_client_socket_unittest.cc index d2e5ee80d..879ad2e78 100644 --- a/net/socket/ssl_client_socket_unittest.cc +++ b/net/socket/ssl_client_socket_unittest.cc @@ -3964,6 +3964,163 @@ TEST_F(SSLClientSocketTest, DontClearClientCertificatesWithNullCerts) { EXPECT_FALSE(GetBooleanValueFromParams(entries[0], "is_cleared")); } +TEST_F(SSLClientSocketTest, ClearMatchingCertDontClearEmptyClientCertCache) { + SSLServerConfig server_config; + server_config.client_cert_type = SSLServerConfig::REQUIRE_CLIENT_CERT; + ASSERT_TRUE( + StartEmbeddedTestServer(EmbeddedTestServer::CERT_OK, server_config)); + + // No cached client certs and no open session. + ASSERT_TRUE(context_->GetClientCertificateCachedServersForTesting().empty()); + ASSERT_EQ(context_->ssl_client_session_cache()->size(), 0U); + + base::FilePath certs_dir = GetTestCertsDirectory(); + scoped_refptr<net::X509Certificate> certificate1 = + ImportCertFromFile(certs_dir, "client_1.pem"); + context_->ClearMatchingClientCertificate(certificate1); + base::RunLoop().RunUntilIdle(); + + EXPECT_TRUE(context_->GetClientCertificateCachedServersForTesting().empty()); + EXPECT_EQ(context_->ssl_client_session_cache()->size(), 0U); + + auto entries = log_observer_.GetEntriesWithType( + NetLogEventType::CLEAR_MATCHING_CACHED_CLIENT_CERT); + ASSERT_EQ(1u, entries.size()); + + const auto& log_entry = entries[0]; + ASSERT_FALSE(log_entry.params.empty()); + + const base::Value::List* hosts_values = + log_entry.params.FindListByDottedPath("hosts"); + ASSERT_TRUE(hosts_values); + ASSERT_TRUE(hosts_values->empty()); + + const base::Value::List* certificates_values = + log_entry.params.FindListByDottedPath("certificates"); + ASSERT_TRUE(certificates_values); + EXPECT_FALSE(certificates_values->empty()); +} + +TEST_F(SSLClientSocketTest, ClearMatchingCertSingleNotMatching) { + SSLServerConfig server_config; + // TLS 1.3 reports client certificate errors after the handshake, so test at + // TLS 1.2 for simplicity. + server_config.version_max = SSL_PROTOCOL_VERSION_TLS1_2; + server_config.client_cert_type = SSLServerConfig::REQUIRE_CLIENT_CERT; + ASSERT_TRUE( + StartEmbeddedTestServer(EmbeddedTestServer::CERT_OK, server_config)); + + // Add a client cert decision to the cache. + base::FilePath certs_dir = GetTestCertsDirectory(); + scoped_refptr<net::X509Certificate> certificate1 = + ImportCertFromFile(certs_dir, "client_1.pem"); + scoped_refptr<net::SSLPrivateKey> private_key1 = + key_util::LoadPrivateKeyOpenSSL(certs_dir.AppendASCII("client_1.key")); + context_->SetClientCertificate(host_port_pair(), certificate1, private_key1); + ASSERT_EQ(context_->GetClientCertificateCachedServersForTesting().size(), 1U); + + // Create a connection to `host_port_pair()`. + int rv; + ASSERT_TRUE(CreateAndConnectSSLClientSocket(SSLConfig(), &rv)); + EXPECT_THAT(rv, IsOk()); + EXPECT_TRUE(sock_->IsConnected()); + EXPECT_EQ(context_->ssl_client_session_cache()->size(), 1U); + + scoped_refptr<net::X509Certificate> certificate2 = + ImportCertFromFile(certs_dir, "client_2.pem"); + context_->ClearMatchingClientCertificate(certificate2); + base::RunLoop().RunUntilIdle(); + + // Verify that calling with an unused certificate should not invalidate the + // cache, but will still log an event with no hosts. + EXPECT_EQ(context_->GetClientCertificateCachedServersForTesting().size(), 1U); + EXPECT_EQ(context_->ssl_client_session_cache()->size(), 1U); + + auto entries = log_observer_.GetEntriesWithType( + NetLogEventType::CLEAR_MATCHING_CACHED_CLIENT_CERT); + ASSERT_EQ(1u, entries.size()); + + const auto& log_entry = entries[0]; + ASSERT_FALSE(log_entry.params.empty()); + + const base::Value::List* hosts_values = + log_entry.params.FindListByDottedPath("hosts"); + ASSERT_TRUE(hosts_values); + ASSERT_TRUE(hosts_values->empty()); + + const base::Value::List* certificates_values = + log_entry.params.FindListByDottedPath("certificates"); + ASSERT_TRUE(certificates_values); + EXPECT_FALSE(certificates_values->empty()); +} + +TEST_F(SSLClientSocketTest, ClearMatchingCertSingleMatching) { + SSLServerConfig server_config; + // TLS 1.3 reports client certificate errors after the handshake, so test at + // TLS 1.2 for simplicity. + server_config.version_max = SSL_PROTOCOL_VERSION_TLS1_2; + server_config.client_cert_type = SSLServerConfig::REQUIRE_CLIENT_CERT; + ASSERT_TRUE( + StartEmbeddedTestServer(EmbeddedTestServer::CERT_OK, server_config)); + + // Add a couple of client cert decision to the cache. + base::FilePath certs_dir = GetTestCertsDirectory(); + scoped_refptr<net::X509Certificate> certificate1 = + ImportCertFromFile(certs_dir, "client_1.pem"); + scoped_refptr<net::SSLPrivateKey> private_key1 = + key_util::LoadPrivateKeyOpenSSL(certs_dir.AppendASCII("client_1.key")); + context_->SetClientCertificate(host_port_pair(), certificate1, private_key1); + + HostPortPair host_port_pair2("example.com", 42); + scoped_refptr<net::X509Certificate> certificate2 = + ImportCertFromFile(certs_dir, "client_2.pem"); + scoped_refptr<net::SSLPrivateKey> private_key2 = + key_util::LoadPrivateKeyOpenSSL(certs_dir.AppendASCII("client_2.key")); + context_->SetClientCertificate(host_port_pair2, certificate2, private_key2); + ASSERT_EQ(context_->GetClientCertificateCachedServersForTesting().size(), 2U); + + // Create a connection to `host_port_pair()`. + int rv; + ASSERT_TRUE(CreateAndConnectSSLClientSocket(SSLConfig(), &rv)); + EXPECT_THAT(rv, IsOk()); + EXPECT_TRUE(sock_->IsConnected()); + EXPECT_EQ(context_->ssl_client_session_cache()->size(), 1U); + + testing::StrictMock<MockSSLClientContextObserver> observer; + EXPECT_CALL(observer, OnSSLConfigForServersChanged( + base::flat_set<HostPortPair>({host_port_pair()}))); + context_->AddObserver(&observer); + + context_->ClearMatchingClientCertificate(certificate1); + base::RunLoop().RunUntilIdle(); + + context_->RemoveObserver(&observer); + auto cached_servers_with_decision = + context_->GetClientCertificateCachedServersForTesting(); + EXPECT_EQ(cached_servers_with_decision.size(), 1U); + EXPECT_TRUE(cached_servers_with_decision.contains(host_port_pair2)); + + EXPECT_EQ(context_->ssl_client_session_cache()->size(), 0U); + + auto entries = log_observer_.GetEntriesWithType( + NetLogEventType::CLEAR_MATCHING_CACHED_CLIENT_CERT); + ASSERT_EQ(1u, entries.size()); + + const auto& log_entry = entries[0]; + ASSERT_FALSE(log_entry.params.empty()); + + const base::Value::List* hosts_values = + log_entry.params.FindListByDottedPath("hosts"); + ASSERT_TRUE(hosts_values); + ASSERT_EQ(hosts_values->size(), 1U); + EXPECT_EQ(hosts_values->front().GetString(), host_port_pair().ToString()); + + const base::Value::List* certificates_values = + log_entry.params.FindListByDottedPath("certificates"); + ASSERT_TRUE(certificates_values); + EXPECT_FALSE(certificates_values->empty()); +} + TEST_F(SSLClientSocketTest, DontClearSessionCacheOnServerCertDatabaseChange) { SSLServerConfig server_config; // TLS 1.3 reports client certificate errors after the handshake, so test at diff --git a/net/socket/ssl_connect_job.cc b/net/socket/ssl_connect_job.cc index c1878e2a8..442891723 100644 --- a/net/socket/ssl_connect_job.cc +++ b/net/socket/ssl_connect_job.cc @@ -33,6 +33,7 @@ #include "net/ssl/ssl_cert_request_info.h" #include "net/ssl/ssl_connection_status_flags.h" #include "net/ssl/ssl_info.h" +#include "third_party/abseil-cpp/absl/types/variant.h" #include "third_party/boringssl/src/include/openssl/pool.h" #include "third_party/boringssl/src/include/openssl/ssl.h" @@ -46,58 +47,27 @@ constexpr base::TimeDelta kSSLHandshakeTimeout(base::Seconds(30)); } // namespace SSLSocketParams::SSLSocketParams( - scoped_refptr<TransportSocketParams> direct_params, - scoped_refptr<SOCKSSocketParams> socks_proxy_params, - scoped_refptr<HttpProxySocketParams> http_proxy_params, + ConnectJobParams nested_params, const HostPortPair& host_and_port, const SSLConfig& ssl_config, NetworkAnonymizationKey network_anonymization_key) - : direct_params_(std::move(direct_params)), - socks_proxy_params_(std::move(socks_proxy_params)), - http_proxy_params_(std::move(http_proxy_params)), + : nested_params_(nested_params), host_and_port_(host_and_port), ssl_config_(ssl_config), network_anonymization_key_(network_anonymization_key) { - // Only one set of lower level ConnectJob params should be non-NULL. - DCHECK((direct_params_ && !socks_proxy_params_ && !http_proxy_params_) || - (!direct_params_ && socks_proxy_params_ && !http_proxy_params_) || - (!direct_params_ && !socks_proxy_params_ && http_proxy_params_)); + CHECK(!nested_params_.is_ssl()); } SSLSocketParams::~SSLSocketParams() = default; SSLSocketParams::ConnectionType SSLSocketParams::GetConnectionType() const { - if (direct_params_.get()) { - DCHECK(!socks_proxy_params_.get()); - DCHECK(!http_proxy_params_.get()); - return DIRECT; - } - - if (socks_proxy_params_.get()) { - DCHECK(!http_proxy_params_.get()); + if (nested_params_.is_socks()) { return SOCKS_PROXY; } - - DCHECK(http_proxy_params_.get()); - return HTTP_PROXY; -} - -const scoped_refptr<TransportSocketParams>& -SSLSocketParams::GetDirectConnectionParams() const { - DCHECK_EQ(GetConnectionType(), DIRECT); - return direct_params_; -} - -const scoped_refptr<SOCKSSocketParams>& -SSLSocketParams::GetSocksProxyConnectionParams() const { - DCHECK_EQ(GetConnectionType(), SOCKS_PROXY); - return socks_proxy_params_; -} - -const scoped_refptr<HttpProxySocketParams>& -SSLSocketParams::GetHttpProxyConnectionParams() const { - DCHECK_EQ(GetConnectionType(), HTTP_PROXY); - return http_proxy_params_; + if (nested_params_.is_http_proxy()) { + return HTTP_PROXY; + } + return DIRECT; } std::unique_ptr<SSLConnectJob> SSLConnectJob::Factory::Create( @@ -149,8 +119,9 @@ LoadState SSLConnectJob::GetLoadState() const { case STATE_SOCKS_CONNECT_COMPLETE: return nested_connect_job_->GetLoadState(); case STATE_TUNNEL_CONNECT_COMPLETE: - if (nested_socket_) + if (nested_socket_) { return LOAD_STATE_ESTABLISHING_PROXY_TUNNEL; + } return nested_connect_job_->GetLoadState(); case STATE_SSL_CONNECT: case STATE_SSL_CONNECT_COMPLETE: @@ -163,8 +134,9 @@ LoadState SSLConnectJob::GetLoadState() const { bool SSLConnectJob::HasEstablishedConnection() const { // If waiting on a nested ConnectJob, defer to that ConnectJob's state. - if (nested_connect_job_) + if (nested_connect_job_) { return nested_connect_job_->HasEstablishedConnection(); + } // Otherwise, return true if a socket has been created. return nested_socket_ || ssl_socket_; } @@ -213,8 +185,9 @@ base::TimeDelta SSLConnectJob::HandshakeTimeoutForTesting() { void SSLConnectJob::OnIOComplete(int result) { int rv = DoLoop(result); - if (rv != ERR_IO_PENDING) + if (rv != ERR_IO_PENDING) { NotifyDelegateOfCompletion(rv); // Deletes |this|. + } } int SSLConnectJob::DoLoop(int result) { @@ -574,8 +547,9 @@ void SSLConnectJob::ResetStateForRestart() { } void SSLConnectJob::ChangePriorityInternal(RequestPriority priority) { - if (nested_connect_job_) + if (nested_connect_job_) { nested_connect_job_->ChangePriority(priority); + } } } // namespace net diff --git a/net/socket/ssl_connect_job.h b/net/socket/ssl_connect_job.h index 3df1d8719..842193487 100644 --- a/net/socket/ssl_connect_job.h +++ b/net/socket/ssl_connect_job.h @@ -22,6 +22,7 @@ #include "net/dns/public/host_resolver_results.h" #include "net/dns/public/resolve_error_info.h" #include "net/socket/connect_job.h" +#include "net/socket/connect_job_params.h" #include "net/socket/connection_attempts.h" #include "net/socket/ssl_client_socket.h" #include "net/ssl/ssl_cert_request_info.h" @@ -42,9 +43,7 @@ class NET_EXPORT_PRIVATE SSLSocketParams // Exactly one of |direct_params|, |socks_proxy_params|, and // |http_proxy_params| must be non-NULL. - SSLSocketParams(scoped_refptr<TransportSocketParams> direct_params, - scoped_refptr<SOCKSSocketParams> socks_proxy_params, - scoped_refptr<HttpProxySocketParams> http_proxy_params, + SSLSocketParams(ConnectJobParams params, const HostPortPair& host_and_port, const SSLConfig& ssl_config, NetworkAnonymizationKey network_anonymization_key); @@ -56,14 +55,22 @@ class NET_EXPORT_PRIVATE SSLSocketParams ConnectionType GetConnectionType() const; // Must be called only when GetConnectionType() returns DIRECT. - const scoped_refptr<TransportSocketParams>& GetDirectConnectionParams() const; + const scoped_refptr<TransportSocketParams>& GetDirectConnectionParams() + const { + return nested_params_.transport(); + } // Must be called only when GetConnectionType() returns SOCKS_PROXY. - const scoped_refptr<SOCKSSocketParams>& GetSocksProxyConnectionParams() const; + const scoped_refptr<SOCKSSocketParams>& GetSocksProxyConnectionParams() + const { + return nested_params_.socks(); + } // Must be called only when GetConnectionType() returns HTTP_PROXY. const scoped_refptr<HttpProxySocketParams>& GetHttpProxyConnectionParams() - const; + const { + return nested_params_.http_proxy(); + } const HostPortPair& host_and_port() const { return host_and_port_; } const SSLConfig& ssl_config() const { return ssl_config_; } @@ -75,9 +82,7 @@ class NET_EXPORT_PRIVATE SSLSocketParams friend class base::RefCounted<SSLSocketParams>; ~SSLSocketParams(); - const scoped_refptr<TransportSocketParams> direct_params_; - const scoped_refptr<SOCKSSocketParams> socks_proxy_params_; - const scoped_refptr<HttpProxySocketParams> http_proxy_params_; + const ConnectJobParams nested_params_; const HostPortPair host_and_port_; const SSLConfig ssl_config_; const NetworkAnonymizationKey network_anonymization_key_; diff --git a/net/socket/ssl_connect_job_unittest.cc b/net/socket/ssl_connect_job_unittest.cc index 71b7b2b5e..52cd6bfc4 100644 --- a/net/socket/ssl_connect_job_unittest.cc +++ b/net/socket/ssl_connect_job_unittest.cc @@ -137,7 +137,7 @@ class SSLConnectJobTest : public WithTaskEnvironment, public testing::Test { scoped_refptr<SOCKSSocketParams> CreateSOCKSSocketParams( SecureDnsPolicy secure_dns_policy) { return base::MakeRefCounted<SOCKSSocketParams>( - CreateProxyTransportSocketParams(secure_dns_policy), + ConnectJobParams(CreateProxyTransportSocketParams(secure_dns_policy)), kSocksProxyServer.scheme() == ProxyServer::SCHEME_SOCKS5, kSocksProxyServer.host_port_pair(), NetworkAnonymizationKey(), TRAFFIC_ANNOTATION_FOR_TESTS); @@ -146,9 +146,8 @@ class SSLConnectJobTest : public WithTaskEnvironment, public testing::Test { scoped_refptr<HttpProxySocketParams> CreateHttpProxySocketParams( SecureDnsPolicy secure_dns_policy) { return base::MakeRefCounted<HttpProxySocketParams>( - CreateProxyTransportSocketParams(secure_dns_policy), - /*ssl_params=*/nullptr, /*quic_ssl_config=*/std::nullopt, kHostHttp, - kHttpProxyChain, + ConnectJobParams(CreateProxyTransportSocketParams(secure_dns_policy)), + kHostHttp, kHttpProxyChain, /*proxy_server_index=*/0, /*tunnel=*/true, TRAFFIC_ANNOTATION_FOR_TESTS, NetworkAnonymizationKey(), secure_dns_policy); @@ -170,16 +169,15 @@ class SSLConnectJobTest : public WithTaskEnvironment, public testing::Test { SecureDnsPolicy secure_dns_policy) { return base::MakeRefCounted<SSLSocketParams>( proxy_chain == ProxyChain::Direct() - ? CreateDirectTransportSocketParams(secure_dns_policy) - : nullptr, - proxy_chain.is_single_proxy() && + ? ConnectJobParams( + CreateDirectTransportSocketParams(secure_dns_policy)) + : proxy_chain.is_single_proxy() && proxy_chain.First().scheme() == ProxyServer::SCHEME_SOCKS5 - ? CreateSOCKSSocketParams(secure_dns_policy) - : nullptr, - proxy_chain.is_single_proxy() && + ? ConnectJobParams(CreateSOCKSSocketParams(secure_dns_policy)) + : proxy_chain.is_single_proxy() && proxy_chain.First().scheme() == ProxyServer::SCHEME_HTTP - ? CreateHttpProxySocketParams(secure_dns_policy) - : nullptr, + ? ConnectJobParams(CreateHttpProxySocketParams(secure_dns_policy)) + : ConnectJobParams(), HostPortPair::FromSchemeHostPort(kHostHttps), SSLConfig(), NetworkAnonymizationKey()); } @@ -442,8 +440,9 @@ TEST_F(SSLConnectJobTest, RequestPriority) { for (int new_priority = MINIMUM_PRIORITY; new_priority <= MAXIMUM_PRIORITY; ++new_priority) { SCOPED_TRACE(new_priority); - if (initial_priority == new_priority) + if (initial_priority == new_priority) { continue; + } TestConnectJobDelegate test_delegate; std::unique_ptr<ConnectJob> ssl_connect_job = CreateConnectJob(&test_delegate, ProxyChain::Direct(), @@ -724,8 +723,9 @@ TEST_F(SSLConnectJobTest, SOCKSRequestPriority) { for (int new_priority = MINIMUM_PRIORITY; new_priority <= MAXIMUM_PRIORITY; ++new_priority) { SCOPED_TRACE(new_priority); - if (initial_priority == new_priority) + if (initial_priority == new_priority) { continue; + } TestConnectJobDelegate test_delegate; std::unique_ptr<ConnectJob> ssl_connect_job = CreateConnectJob( &test_delegate, PacResultElementToProxyChain("SOCKS5 foo:333"), @@ -872,8 +872,9 @@ TEST_F(SSLConnectJobTest, HttpProxyRequestPriority) { for (int new_priority = MINIMUM_PRIORITY; new_priority <= MAXIMUM_PRIORITY; ++new_priority) { SCOPED_TRACE(new_priority); - if (initial_priority == new_priority) + if (initial_priority == new_priority) { continue; + } TestConnectJobDelegate test_delegate; std::unique_ptr<ConnectJob> ssl_connect_job = CreateConnectJob( &test_delegate, PacResultElementToProxyChain("PROXY foo:444"), diff --git a/net/socket/ssl_server_socket_unittest.cc b/net/socket/ssl_server_socket_unittest.cc index 95ab2c63e..4eaf05f34 100644 --- a/net/socket/ssl_server_socket_unittest.cc +++ b/net/socket/ssl_server_socket_unittest.cc @@ -57,6 +57,7 @@ #include "net/socket/socket_test_util.h" #include "net/socket/ssl_client_socket.h" #include "net/socket/stream_socket.h" +#include "net/ssl/openssl_private_key.h" #include "net/ssl/ssl_cert_request_info.h" #include "net/ssl/ssl_cipher_suite_names.h" #include "net/ssl/ssl_client_session_cache.h" @@ -65,7 +66,6 @@ #include "net/ssl/ssl_private_key.h" #include "net/ssl/ssl_server_config.h" #include "net/ssl/test_ssl_config_service.h" -#include "net/ssl/test_ssl_private_key.h" #include "net/test/cert_test_util.h" #include "net/test/gtest_util.h" #include "net/test/test_data_directory.h" @@ -502,11 +502,13 @@ class SSLServerSocketTest : public PlatformTest, public WithTaskEnvironment { std::unique_ptr<FakeDataChannel> channel_1_; std::unique_ptr<FakeDataChannel> channel_2_; - SSLConfig client_ssl_config_; - SSLServerConfig server_ssl_config_; std::unique_ptr<TestSSLConfigService> ssl_config_service_; std::unique_ptr<MockCertVerifier> cert_verifier_; std::unique_ptr<MockClientCertVerifier> client_cert_verifier_; + SSLConfig client_ssl_config_; + // Note that this has a pointer to the `cert_verifier_`, so must be destroyed + // before that is. + SSLServerConfig server_ssl_config_; std::unique_ptr<TransportSecurityState> transport_security_state_; std::unique_ptr<SSLClientSessionCache> ssl_client_session_cache_; std::unique_ptr<SSLClientContext> client_context_; diff --git a/net/socket/tcp_socket_win.h b/net/socket/tcp_socket_win.h index f6cc34e91..5dd11fad5 100644 --- a/net/socket/tcp_socket_win.h +++ b/net/socket/tcp_socket_win.h @@ -5,9 +5,10 @@ #ifndef NET_SOCKET_TCP_SOCKET_WIN_H_ #define NET_SOCKET_TCP_SOCKET_WIN_H_ -#include <stdint.h> #include <winsock2.h> +#include <stdint.h> + #include <memory> #include "base/memory/raw_ptr.h" diff --git a/net/socket/udp_socket_win.cc b/net/socket/udp_socket_win.cc index 2e28570f9..b172fc368 100644 --- a/net/socket/udp_socket_win.cc +++ b/net/socket/udp_socket_win.cc @@ -4,9 +4,10 @@ #include "net/socket/udp_socket_win.h" -#include <mstcpip.h> #include <winsock2.h> +#include <mstcpip.h> + #include <memory> #include "base/check_op.h" diff --git a/net/socket/udp_socket_win.h b/net/socket/udp_socket_win.h index b4008db4c..97dd80630 100644 --- a/net/socket/udp_socket_win.h +++ b/net/socket/udp_socket_win.h @@ -5,9 +5,10 @@ #ifndef NET_SOCKET_UDP_SOCKET_WIN_H_ #define NET_SOCKET_UDP_SOCKET_WIN_H_ +#include <winsock2.h> + #include <qos2.h> #include <stdint.h> -#include <winsock2.h> // Must be after winsock2.h: #include <MSWSock.h> |