diff options
author | Cronet Mainline Eng <cronet-mainline-eng+copybara@google.com> | 2024-06-07 18:59:21 +0900 |
---|---|---|
committer | Motomu Utsumi <motomuman@google.com> | 2024-06-07 19:00:37 +0900 |
commit | 93dc77d4cfa4a2996ac5bf4c67b0d4223847eb65 (patch) | |
tree | 92c5aba3655194ff55d27c652a265bd6f87b0470 /net/socket | |
parent | b66ce594f84a102bf71c3e2754d9c0bfdd620b85 (diff) | |
download | cronet-93dc77d4cfa4a2996ac5bf4c67b0d4223847eb65.tar.gz |
Import Cronet version 124.0.6367.42
FolderOrigin-RevId: /tmp/copybara-origin/src
Change-Id: I727d2277512236d7d0db42e102d291b6204b38e5
Diffstat (limited to 'net/socket')
-rw-r--r-- | net/socket/connect_job_factory.cc | 293 | ||||
-rw-r--r-- | net/socket/connect_job_params_factory.cc | 368 | ||||
-rw-r--r-- | net/socket/connect_job_params_factory.h | 55 | ||||
-rw-r--r-- | net/socket/connect_job_params_factory_unittest.cc | 770 | ||||
-rw-r--r-- | net/socket/socket_test_util.cc | 67 | ||||
-rw-r--r-- | net/socket/socket_test_util.h | 26 | ||||
-rw-r--r-- | net/socket/ssl_client_socket_impl.cc | 138 | ||||
-rw-r--r-- | net/socket/ssl_client_socket_unittest.cc | 49 | ||||
-rw-r--r-- | net/socket/ssl_connect_job.cc | 61 | ||||
-rw-r--r-- | net/socket/ssl_connect_job.h | 4 | ||||
-rw-r--r-- | net/socket/ssl_connect_job_unittest.cc | 115 | ||||
-rw-r--r-- | net/socket/transport_client_socket_pool.h | 2 | ||||
-rw-r--r-- | net/socket/udp_socket_unittest.cc | 108 | ||||
-rw-r--r-- | net/socket/udp_socket_win.cc | 224 | ||||
-rw-r--r-- | net/socket/udp_socket_win.h | 50 |
15 files changed, 1655 insertions, 675 deletions
diff --git a/net/socket/connect_job_factory.cc b/net/socket/connect_job_factory.cc index 2ab34acfd..fa437fb3e 100644 --- a/net/socket/connect_job_factory.cc +++ b/net/socket/connect_job_factory.cc @@ -10,18 +10,16 @@ #include <vector> #include "base/check.h" -#include "base/containers/flat_set.h" #include "base/memory/scoped_refptr.h" #include "net/base/host_port_pair.h" #include "net/base/network_anonymization_key.h" #include "net/base/privacy_mode.h" #include "net/base/proxy_chain.h" -#include "net/base/proxy_server.h" #include "net/base/request_priority.h" #include "net/dns/public/secure_dns_policy.h" #include "net/http/http_proxy_connect_job.h" #include "net/socket/connect_job.h" -#include "net/socket/next_proto.h" +#include "net/socket/connect_job_params_factory.h" #include "net/socket/socket_tag.h" #include "net/socket/socks_connect_job.h" #include "net/socket/ssl_connect_job.h" @@ -29,7 +27,6 @@ #include "net/ssl/ssl_config.h" #include "net/traffic_annotation/network_traffic_annotation.h" #include "third_party/abseil-cpp/absl/types/variant.h" -#include "url/gurl.h" #include "url/scheme_host_port.h" namespace net { @@ -38,116 +35,10 @@ namespace { template <typename T> std::unique_ptr<T> CreateFactoryIfNull(std::unique_ptr<T> in) { - if (in) + if (in) { return in; - return std::make_unique<T>(); -} - -bool UsingSsl(const ConnectJobFactory::Endpoint& endpoint) { - if (absl::holds_alternative<url::SchemeHostPort>(endpoint)) { - return GURL::SchemeIsCryptographic( - base::ToLowerASCII(absl::get<url::SchemeHostPort>(endpoint).scheme())); - } - - DCHECK( - absl::holds_alternative<ConnectJobFactory::SchemelessEndpoint>(endpoint)); - return absl::get<ConnectJobFactory::SchemelessEndpoint>(endpoint).using_ssl; -} - -HostPortPair ToHostPortPair(const ConnectJobFactory::Endpoint& endpoint) { - if (absl::holds_alternative<url::SchemeHostPort>(endpoint)) { - return HostPortPair::FromSchemeHostPort( - absl::get<url::SchemeHostPort>(endpoint)); - } - - DCHECK( - absl::holds_alternative<ConnectJobFactory::SchemelessEndpoint>(endpoint)); - return absl::get<ConnectJobFactory::SchemelessEndpoint>(endpoint) - .host_port_pair; -} - -TransportSocketParams::Endpoint ToTransportEndpoint( - const ConnectJobFactory::Endpoint& endpoint) { - if (absl::holds_alternative<url::SchemeHostPort>(endpoint)) - return absl::get<url::SchemeHostPort>(endpoint); - - DCHECK( - absl::holds_alternative<ConnectJobFactory::SchemelessEndpoint>(endpoint)); - return absl::get<ConnectJobFactory::SchemelessEndpoint>(endpoint) - .host_port_pair; -} - -base::flat_set<std::string> SupportedProtocolsFromSSLConfig( - const SSLConfig& config) { - // We convert because `SSLConfig` uses `NextProto` for ALPN protocols while - // `TransportConnectJob` and DNS logic needs `std::string`. See - // https://crbug.com/1286835. - return base::MakeFlatSet<std::string>(config.alpn_protos, /*comp=*/{}, - NextProtoToString); -} - -// Populates `ssl_config's` ALPN-related fields. Namely, `alpn_protos`, -// `application_settings`, `renego_allowed_default`, and -// `renego_allowed_for_protos`. -// -// In the case of AlpnMode::kDisabled, clears all of the fields. -// -// In the case of AlpnMode::kHttp11Only sets `alpn_protos` to only allow -// HTTP/1.1 negotiation. -// -// In the case of AlpnMode::kHttpAll, copying `alpn_protos` from -// `common_connect_job_params`, and gives HttpServerProperties a chance to force -// use of HTTP/1.1 only. -// -// If `alpn_mode` is not AlpnMode::kDisabled, then `server` must be a -// SchemeHostPort, as it makes no sense to negotiate ALPN when the scheme isn't -// known. -void ConfigureAlpn( - const ConnectJobFactory::Endpoint& endpoint, - ConnectJobFactory::AlpnMode alpn_mode, - const net::NetworkAnonymizationKey& network_anonymization_key, - const CommonConnectJobParams& common_connect_job_params, - SSLConfig& ssl_config, - bool renego_allowed) { - if (alpn_mode == ConnectJobFactory::AlpnMode::kDisabled) { - ssl_config.alpn_protos = {}; - ssl_config.application_settings = {}; - ssl_config.renego_allowed_default = false; - return; - } - - DCHECK(absl::holds_alternative<url::SchemeHostPort>(endpoint)); - - if (alpn_mode == ConnectJobFactory::AlpnMode::kHttp11Only) { - ssl_config.alpn_protos = {kProtoHTTP11}; - ssl_config.application_settings = - *common_connect_job_params.application_settings; - } else { - DCHECK_EQ(alpn_mode, ConnectJobFactory::AlpnMode::kHttpAll); - DCHECK(absl::holds_alternative<url::SchemeHostPort>(endpoint)); - ssl_config.alpn_protos = *common_connect_job_params.alpn_protos; - ssl_config.application_settings = - *common_connect_job_params.application_settings; - if (common_connect_job_params.http_server_properties) { - common_connect_job_params.http_server_properties->MaybeForceHTTP11( - absl::get<url::SchemeHostPort>(endpoint), network_anonymization_key, - &ssl_config); - } - } - - // Prior to HTTP/2 and SPDY, some servers used TLS renegotiation to request - // TLS client authentication after the HTTP request was sent. Allow - // renegotiation for only those connections. - // - // Note that this does NOT implement the provision in - // https://http2.github.io/http2-spec/#rfc.section.9.2.1 which allows the - // server to request a renegotiation immediately before sending the - // connection preface as waiting for the preface would cost the round trip - // that False Start otherwise saves. - ssl_config.renego_allowed_default = renego_allowed; - if (renego_allowed) { - ssl_config.renego_allowed_for_protos = {kProtoHTTP11}; } + return std::make_unique<T>(); } } // namespace @@ -233,177 +124,43 @@ std::unique_ptr<ConnectJob> ConnectJobFactory::CreateConnectJob( bool disable_cert_network_fetches, const CommonConnectJobParams* common_connect_job_params, ConnectJob::Delegate* delegate) const { - scoped_refptr<HttpProxySocketParams> http_proxy_params; - scoped_refptr<SOCKSSocketParams> socks_params; - base::flat_set<std::string> no_alpn_protocols; - - DCHECK(proxy_chain.IsValid()); - if (!proxy_chain.is_direct()) { - // The first iteration of this loop is taken for all types of proxies and - // creates a TransportSocketParams and other socket params based on the - // proxy type. For nested proxies, we then create additional SSLSocketParam - // and HttpProxySocketParam objects for the remaining hops. This is done by - // working backwards through the proxy chain and creating socket params - // such that connect jobs will be created recursively with dependencies in - // the correct order (in other words, the inner-most connect job will - // establish a connection to the first proxy, and then that connection - // will get used to establish a connection to the second proxy). - for (size_t proxy_index = 0; proxy_index < proxy_chain.length(); - ++proxy_index) { - const ProxyServer& proxy_server = proxy_chain.GetProxyServer(proxy_index); - - SSLConfig proxy_server_ssl_config; - if (proxy_server.is_secure_http_like()) { - // Disable cert verification network fetches for secure proxies, since - // those network requests are probably going to need to go through the - // proxy chain too. - // - // Any proxy-specific SSL behavior here should also be configured for - // QUIC proxies. - // - proxy_server_ssl_config.disable_cert_verification_network_fetches = - true; - ConfigureAlpn(url::SchemeHostPort(url::kHttpsScheme, - proxy_server.host_port_pair().host(), - proxy_server.host_port_pair().port()), - // Always enable ALPN for proxies. - ConnectJobFactory::AlpnMode::kHttpAll, - network_anonymization_key, *common_connect_job_params, - proxy_server_ssl_config, /*renego_allowed=*/false); - } - - scoped_refptr<TransportSocketParams> proxy_tcp_params; - if (proxy_index == 0) { - // In the first iteration create the only TransportSocketParams object, - // corresponding to the transport socket we want to create to the first - // proxy. - // TODO(crbug.com/1206799): For an http-like proxy, should this pass a - // `SchemeHostPort`, so proxies can participate in ECH? Note doing so - // with `SCHEME_HTTP` requires handling the HTTPS record upgrade. - proxy_tcp_params = base::MakeRefCounted<TransportSocketParams>( - proxy_server.host_port_pair(), proxy_dns_network_anonymization_key_, - secure_dns_policy, resolution_callback, - proxy_server.is_secure_http_like() - ? SupportedProtocolsFromSSLConfig(proxy_server_ssl_config) - : no_alpn_protocols); - } else { - // TODO(https://crbug.com/1491092): For now we will assume that proxy - // chains with multiple proxies must all use HTTPS. - CHECK(http_proxy_params); - CHECK(http_proxy_params->ssl_params()); - CHECK( - proxy_chain.GetProxyServer(proxy_index - 1).is_secure_http_like()); - } - - if (proxy_server.is_http_like()) { - scoped_refptr<SSLSocketParams> ssl_params; - if (proxy_server.is_secure_http_like()) { - // Set `ssl_params`, and unset `proxy_tcp_params`. - ssl_params = base::MakeRefCounted<SSLSocketParams>( - std::move(proxy_tcp_params), /*socks_proxy_params=*/nullptr, - std::move(http_proxy_params), proxy_server.host_port_pair(), - proxy_server_ssl_config, PRIVACY_MODE_DISABLED, - network_anonymization_key); - proxy_tcp_params = nullptr; - } - - // The endpoint parameter for this HttpProxySocketParams, which is what - // we will CONNECT to, should correspond to either `endpoint` (for - // one-hop proxies) or the proxy server at index 1 (for n-hop proxies). - HostPortPair connect_host_port_pair; - bool should_tunnel; - if (proxy_index + 1 == proxy_chain.length()) { - connect_host_port_pair = ToHostPortPair(endpoint); - should_tunnel = force_tunnel || UsingSsl(endpoint) || - !proxy_chain.is_get_to_proxy_allowed(); - } else { - const auto& next_proxy_server = - proxy_chain.GetProxyServer(proxy_index + 1); - connect_host_port_pair = next_proxy_server.host_port_pair(); - // TODO(https://crbug.com/1491092): For now we will assume that proxy - // chains with multiple proxies must all use HTTPS. - CHECK(next_proxy_server.is_secure_http_like()); - should_tunnel = true; - } - - // TODO(crbug.com/1206799): Pass `endpoint` directly (preserving - // scheme when available)? - http_proxy_params = base::MakeRefCounted<HttpProxySocketParams>( - std::move(proxy_tcp_params), std::move(ssl_params), - connect_host_port_pair, proxy_chain, proxy_index, should_tunnel, - *proxy_annotation_tag, network_anonymization_key, - secure_dns_policy); - } else { - DCHECK(proxy_server.is_socks()); - DCHECK_EQ(1u, proxy_chain.length()); - // TODO(crbug.com/1206799): Pass `endpoint` directly (preserving scheme - // when available)? - socks_params = base::MakeRefCounted<SOCKSSocketParams>( - std::move(proxy_tcp_params), - proxy_server.scheme() == ProxyServer::SCHEME_SOCKS5, - ToHostPortPair(endpoint), network_anonymization_key, - *proxy_annotation_tag); - } - } - } - - // Deal with SSL - which layers on top of any given proxy. - if (UsingSsl(endpoint)) { - scoped_refptr<TransportSocketParams> ssl_tcp_params; - - SSLConfig ssl_config; - ssl_config.allowed_bad_certs = allowed_bad_certs; - - ConfigureAlpn(endpoint, alpn_mode, network_anonymization_key, - *common_connect_job_params, ssl_config, - /*renego_allowed=*/true); - - ssl_config.disable_cert_verification_network_fetches = - disable_cert_network_fetches; - - // TODO(https://crbug.com/964642): Also enable 0-RTT for TLS proxies. - ssl_config.early_data_enabled = - *common_connect_job_params->enable_early_data; + ConnectJobParams connect_job_params = ConstructConnectJobParams( + endpoint, proxy_chain, proxy_annotation_tag, allowed_bad_certs, alpn_mode, + force_tunnel, privacy_mode, resolution_callback, + network_anonymization_key, secure_dns_policy, + disable_cert_network_fetches, common_connect_job_params, + proxy_dns_network_anonymization_key_); - if (proxy_chain.is_direct()) { - ssl_tcp_params = base::MakeRefCounted<TransportSocketParams>( - ToTransportEndpoint(endpoint), network_anonymization_key, - secure_dns_policy, resolution_callback, - SupportedProtocolsFromSSLConfig(ssl_config)); - } - // TODO(crbug.com/1206799): Pass `endpoint` directly (preserving scheme - // when available)? - auto ssl_params = base::MakeRefCounted<SSLSocketParams>( - std::move(ssl_tcp_params), std::move(socks_params), - std::move(http_proxy_params), ToHostPortPair(endpoint), ssl_config, - privacy_mode, network_anonymization_key); + if (holds_alternative<scoped_refptr<SSLSocketParams>>(connect_job_params)) { return ssl_connect_job_factory_->Create( request_priority, socket_tag, common_connect_job_params, - std::move(ssl_params), delegate, /*net_log=*/nullptr); + get<scoped_refptr<SSLSocketParams>>(std::move(connect_job_params)), + delegate, /*net_log=*/nullptr); } - // Only SSL/TLS-based endpoints have ALPN protocols. - if (proxy_chain.is_direct()) { - auto tcp_params = base::MakeRefCounted<TransportSocketParams>( - ToTransportEndpoint(endpoint), network_anonymization_key, - secure_dns_policy, resolution_callback, no_alpn_protocols); + if (holds_alternative<scoped_refptr<TransportSocketParams>>( + connect_job_params)) { return transport_connect_job_factory_->Create( - request_priority, socket_tag, common_connect_job_params, tcp_params, + request_priority, socket_tag, common_connect_job_params, + get<scoped_refptr<TransportSocketParams>>( + std::move(connect_job_params)), delegate, /*net_log=*/nullptr); } - const ProxyServer& last_proxy_server = proxy_chain.Last(); - if (http_proxy_params) { - DCHECK(last_proxy_server.is_http_like()); + if (holds_alternative<scoped_refptr<HttpProxySocketParams>>( + connect_job_params)) { return http_proxy_connect_job_factory_->Create( request_priority, socket_tag, common_connect_job_params, - std::move(http_proxy_params), delegate, /*net_log=*/nullptr); + get<scoped_refptr<HttpProxySocketParams>>(connect_job_params), delegate, + /*net_log=*/nullptr); } - DCHECK(last_proxy_server.is_socks()); + CHECK( + holds_alternative<scoped_refptr<SOCKSSocketParams>>(connect_job_params)); return socks_connect_job_factory_->Create( request_priority, socket_tag, common_connect_job_params, - std::move(socks_params), delegate, /*net_log=*/nullptr); + get<scoped_refptr<SOCKSSocketParams>>(std::move(connect_job_params)), + delegate, /*net_log=*/nullptr); } } // namespace net diff --git a/net/socket/connect_job_params_factory.cc b/net/socket/connect_job_params_factory.cc new file mode 100644 index 000000000..99a26835f --- /dev/null +++ b/net/socket/connect_job_params_factory.cc @@ -0,0 +1,368 @@ +// Copyright 2024 The Chromium Authors +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/connect_job_params_factory.h" + +#include <optional> +#include <vector> + +#include "base/check.h" +#include "base/containers/flat_set.h" +#include "base/memory/scoped_refptr.h" +#include "net/base/host_port_pair.h" +#include "net/base/network_anonymization_key.h" +#include "net/base/privacy_mode.h" +#include "net/base/proxy_chain.h" +#include "net/base/proxy_server.h" +#include "net/base/request_priority.h" +#include "net/dns/public/secure_dns_policy.h" +#include "net/http/http_proxy_connect_job.h" +#include "net/socket/next_proto.h" +#include "net/socket/socket_tag.h" +#include "net/socket/socks_connect_job.h" +#include "net/socket/ssl_connect_job.h" +#include "net/socket/transport_connect_job.h" +#include "net/ssl/ssl_config.h" +#include "net/traffic_annotation/network_traffic_annotation.h" +#include "third_party/abseil-cpp/absl/types/variant.h" +#include "url/gurl.h" +#include "url/scheme_host_port.h" + +namespace net { + +namespace { + +// Populates `ssl_config's` ALPN-related fields. Namely, `alpn_protos`, +// `application_settings`, `renego_allowed_default`, and +// `renego_allowed_for_protos`. +// +// In the case of `AlpnMode::kDisabled`, clears all of the fields. +// +// In the case of `AlpnMode::kHttp11Only`, sets `alpn_protos` to only allow +// HTTP/1.1 negotiation. +// +// In the case of `AlpnMode::kHttpAll`, copies `alpn_protos` from +// `common_connect_job_params`, and gives `HttpServerProperties` a chance to +// force use of HTTP/1.1 only. +// +// If `alpn_mode` is not `AlpnMode::kDisabled`, then `server` must be a +// `SchemeHostPort`, as it makes no sense to negotiate ALPN when the scheme +// isn't known. +void ConfigureAlpn(const ConnectJobFactory::Endpoint& endpoint, + ConnectJobFactory::AlpnMode alpn_mode, + const NetworkAnonymizationKey& network_anonymization_key, + const CommonConnectJobParams& common_connect_job_params, + SSLConfig& ssl_config, + bool renego_allowed) { + if (alpn_mode == ConnectJobFactory::AlpnMode::kDisabled) { + ssl_config.alpn_protos = {}; + ssl_config.application_settings = {}; + ssl_config.renego_allowed_default = false; + return; + } + + DCHECK(absl::holds_alternative<url::SchemeHostPort>(endpoint)); + + if (alpn_mode == ConnectJobFactory::AlpnMode::kHttp11Only) { + ssl_config.alpn_protos = {kProtoHTTP11}; + ssl_config.application_settings = + *common_connect_job_params.application_settings; + } else { + DCHECK_EQ(alpn_mode, ConnectJobFactory::AlpnMode::kHttpAll); + DCHECK(absl::holds_alternative<url::SchemeHostPort>(endpoint)); + ssl_config.alpn_protos = *common_connect_job_params.alpn_protos; + ssl_config.application_settings = + *common_connect_job_params.application_settings; + if (common_connect_job_params.http_server_properties) { + common_connect_job_params.http_server_properties->MaybeForceHTTP11( + absl::get<url::SchemeHostPort>(endpoint), network_anonymization_key, + &ssl_config); + } + } + + // Prior to HTTP/2 and SPDY, some servers used TLS renegotiation to request + // TLS client authentication after the HTTP request was sent. Allow + // renegotiation for only those connections. + // + // Note that this does NOT implement the provision in + // https://http2.github.io/http2-spec/#rfc.section.9.2.1 which allows the + // server to request a renegotiation immediately before sending the + // connection preface as waiting for the preface would cost the round trip + // that False Start otherwise saves. + ssl_config.renego_allowed_default = renego_allowed; + if (renego_allowed) { + ssl_config.renego_allowed_for_protos = {kProtoHTTP11}; + } +} + +base::flat_set<std::string> SupportedProtocolsFromSSLConfig( + const SSLConfig& config) { + // We convert because `SSLConfig` uses `NextProto` for ALPN protocols while + // `TransportConnectJob` and DNS logic needs `std::string`. See + // https://crbug.com/1286835. + return base::MakeFlatSet<std::string>(config.alpn_protos, /*comp=*/{}, + NextProtoToString); +} + +HostPortPair ToHostPortPair(const ConnectJobFactory::Endpoint& endpoint) { + if (absl::holds_alternative<url::SchemeHostPort>(endpoint)) { + return HostPortPair::FromSchemeHostPort( + absl::get<url::SchemeHostPort>(endpoint)); + } + + DCHECK( + absl::holds_alternative<ConnectJobFactory::SchemelessEndpoint>(endpoint)); + return absl::get<ConnectJobFactory::SchemelessEndpoint>(endpoint) + .host_port_pair; +} + +TransportSocketParams::Endpoint ToTransportEndpoint( + const ConnectJobFactory::Endpoint& endpoint) { + if (absl::holds_alternative<url::SchemeHostPort>(endpoint)) { + return absl::get<url::SchemeHostPort>(endpoint); + } + + DCHECK( + absl::holds_alternative<ConnectJobFactory::SchemelessEndpoint>(endpoint)); + return absl::get<ConnectJobFactory::SchemelessEndpoint>(endpoint) + .host_port_pair; +} + +bool UsingSsl(const ConnectJobFactory::Endpoint& endpoint) { + if (absl::holds_alternative<url::SchemeHostPort>(endpoint)) { + return GURL::SchemeIsCryptographic( + base::ToLowerASCII(absl::get<url::SchemeHostPort>(endpoint).scheme())); + } + + DCHECK( + absl::holds_alternative<ConnectJobFactory::SchemelessEndpoint>(endpoint)); + return absl::get<ConnectJobFactory::SchemelessEndpoint>(endpoint).using_ssl; +} + +scoped_refptr<HttpProxySocketParams> MaybeHttpProxySocketParams( + const ConnectJobParams& params) { + if (auto p = get_if<scoped_refptr<HttpProxySocketParams>>(¶ms)) { + return *p; + } + return nullptr; +} + +scoped_refptr<SOCKSSocketParams> MaybeSOCKSSocketParams( + const ConnectJobParams& params) { + if (auto p = get_if<scoped_refptr<SOCKSSocketParams>>(¶ms)) { + return *p; + } + return nullptr; +} + +scoped_refptr<TransportSocketParams> MaybeTransportSocketParams( + const ConnectJobParams& params) { + if (auto p = get_if<scoped_refptr<TransportSocketParams>>(¶ms)) { + return *p; + } + return nullptr; +} + +scoped_refptr<SSLSocketParams> MaybeSSLSocketParams( + const ConnectJobParams& params) { + if (auto p = get_if<scoped_refptr<SSLSocketParams>>(¶ms)) { + return *p; + } + return nullptr; +} + +ConnectJobParams MakeSSLSocketParams( + ConnectJobParams params, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + const NetworkAnonymizationKey& network_anonymization_key) { + scoped_refptr<TransportSocketParams> transport_socket_params = + MaybeTransportSocketParams(params); + scoped_refptr<HttpProxySocketParams> http_proxy_socket_params = + MaybeHttpProxySocketParams(params); + scoped_refptr<SOCKSSocketParams> socks_socket_params = + MaybeSOCKSSocketParams(params); + return ConnectJobParams(base::MakeRefCounted<SSLSocketParams>( + std::move(transport_socket_params), std::move(socks_socket_params), + std::move(http_proxy_socket_params), host_and_port, ssl_config, + network_anonymization_key)); +} + +// Recursively generate the params for a proxy at `host_port_pair` and the given +// index in the proxy chain. This proceeds from the end of the proxy chain back +// to the first proxy server. +ConnectJobParams CreateProxyParams( + HostPortPair host_port_pair, + bool should_tunnel, + const ConnectJobFactory::Endpoint& endpoint, + const ProxyChain& proxy_chain, + size_t proxy_chain_index, + const std::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag, + const OnHostResolutionCallback& resolution_callback, + const NetworkAnonymizationKey& network_anonymization_key, + SecureDnsPolicy secure_dns_policy, + const CommonConnectJobParams* common_connect_job_params, + const NetworkAnonymizationKey& proxy_dns_network_anonymization_key) { + const ProxyServer& proxy_server = + proxy_chain.GetProxyServer(proxy_chain_index); + + // Set up the SSLConfig if using SSL to the proxy. + SSLConfig proxy_server_ssl_config; + if (proxy_server.is_secure_http_like()) { + // Disable cert verification network fetches for secure proxies, since + // those network requests are probably going to need to go through the + // proxy chain too. + // + // Any proxy-specific SSL behavior here should also be configured for + // QUIC proxies. + proxy_server_ssl_config.disable_cert_verification_network_fetches = true; + ConfigureAlpn(url::SchemeHostPort(url::kHttpsScheme, + proxy_server.host_port_pair().host(), + proxy_server.host_port_pair().port()), + // Always enable ALPN for proxies. + ConnectJobFactory::AlpnMode::kHttpAll, + network_anonymization_key, *common_connect_job_params, + proxy_server_ssl_config, /*renego_allowed=*/false); + } + + // Create the nested parameters over which the connection to the proxy + // will be made. + ConnectJobParams params; + if (proxy_chain_index == 0) { + // At the beginning of the chain, create the only TransportSocketParams + // object, corresponding to the transport socket we want to create to the + // first proxy. + // TODO(crbug.com/1206799): For an http-like proxy, should this pass a + // `SchemeHostPort`, so proxies can participate in ECH? Note doing so + // with `SCHEME_HTTP` requires handling the HTTPS record upgrade. + params = ConnectJobParams(base::MakeRefCounted<TransportSocketParams>( + proxy_server.host_port_pair(), proxy_dns_network_anonymization_key, + secure_dns_policy, resolution_callback, + SupportedProtocolsFromSSLConfig(proxy_server_ssl_config))); + } else { + // TODO(https://crbug.com/1491092): For now we will assume that proxy + // chains with multiple proxies must all use HTTPS. + CHECK(proxy_chain.GetProxyServer(proxy_chain_index - 1) + .is_secure_http_like()); + params = CreateProxyParams( + proxy_server.host_port_pair(), true, endpoint, proxy_chain, + proxy_chain_index - 1, proxy_annotation_tag, resolution_callback, + network_anonymization_key, secure_dns_policy, common_connect_job_params, + proxy_dns_network_anonymization_key); + } + + // For secure connections, wrap the underlying connection params in SSL + // params. + if (proxy_server.is_secure_http_like()) { + params = + MakeSSLSocketParams(std::move(params), proxy_server.host_port_pair(), + proxy_server_ssl_config, network_anonymization_key); + } + + // Further wrap the underlying connection params, or the SSL params wrapping + // them, with the proxy params. + if (proxy_server.is_http_like()) { + scoped_refptr<TransportSocketParams> transport_socket_params = + MaybeTransportSocketParams(params); + scoped_refptr<SSLSocketParams> ssl_socket_params = + MaybeSSLSocketParams(params); + std::optional<SSLConfig> quic_ssl_config; + if (proxy_server.is_quic()) { + // For QUIC, we only need the SSL config, not the full SSLSocketParams. + // A subsequent CL will remove the redundant SSLSocketParams creation. + quic_ssl_config = ssl_socket_params->ssl_config(); + ssl_socket_params = nullptr; + } + params = ConnectJobParams(base::MakeRefCounted<HttpProxySocketParams>( + std::move(transport_socket_params), std::move(ssl_socket_params), + std::move(quic_ssl_config), host_port_pair, proxy_chain, + proxy_chain_index, should_tunnel, *proxy_annotation_tag, + network_anonymization_key, secure_dns_policy)); + } else { + DCHECK(proxy_server.is_socks()); + DCHECK_EQ(1u, proxy_chain.length()); + scoped_refptr<TransportSocketParams> transport_socket_params = + MaybeTransportSocketParams(params); + DCHECK(transport_socket_params); + // TODO(crbug.com/1206799): Pass `endpoint` directly (preserving scheme + // when available)? + params = ConnectJobParams(base::MakeRefCounted<SOCKSSocketParams>( + std::move(transport_socket_params), + proxy_server.scheme() == ProxyServer::SCHEME_SOCKS5, + ToHostPortPair(endpoint), network_anonymization_key, + *proxy_annotation_tag)); + } + + return params; +} + +} // namespace + +ConnectJobParams ConstructConnectJobParams( + const ConnectJobFactory::Endpoint& endpoint, + const ProxyChain& proxy_chain, + const std::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag, + const std::vector<SSLConfig::CertAndStatus>& allowed_bad_certs, + ConnectJobFactory::AlpnMode alpn_mode, + bool force_tunnel, + PrivacyMode privacy_mode, + const OnHostResolutionCallback& resolution_callback, + const NetworkAnonymizationKey& network_anonymization_key, + SecureDnsPolicy secure_dns_policy, + bool disable_cert_network_fetches, + const CommonConnectJobParams* common_connect_job_params, + const NetworkAnonymizationKey& proxy_dns_network_anonymization_key) { + DCHECK(proxy_chain.IsValid()); + + // Set up `ssl_config` if using SSL to the endpoint. + SSLConfig ssl_config; + if (UsingSsl(endpoint)) { + ssl_config.allowed_bad_certs = allowed_bad_certs; + ssl_config.privacy_mode = privacy_mode; + + ConfigureAlpn(endpoint, alpn_mode, network_anonymization_key, + *common_connect_job_params, ssl_config, + /*renego_allowed=*/true); + + ssl_config.disable_cert_verification_network_fetches = + disable_cert_network_fetches; + + // TODO(https://crbug.com/964642): Also enable 0-RTT for TLS proxies. + ssl_config.early_data_enabled = + *common_connect_job_params->enable_early_data; + } + + // Create the nested parameters over which the connection to the endpoint + // will be made. + ConnectJobParams params; + if (proxy_chain.is_direct()) { + params = ConnectJobParams(base::MakeRefCounted<TransportSocketParams>( + ToTransportEndpoint(endpoint), network_anonymization_key, + secure_dns_policy, resolution_callback, + SupportedProtocolsFromSSLConfig(ssl_config))); + } else { + bool should_tunnel = force_tunnel || UsingSsl(endpoint) || + !proxy_chain.is_get_to_proxy_allowed(); + // Begin creating params for the last proxy in the chain. This will + // recursively create params "backward" through the chain to the first. + params = CreateProxyParams( + ToHostPortPair(endpoint), should_tunnel, endpoint, proxy_chain, + /*proxy_chain_index=*/proxy_chain.length() - 1, proxy_annotation_tag, + resolution_callback, network_anonymization_key, secure_dns_policy, + common_connect_job_params, proxy_dns_network_anonymization_key); + } + + if (UsingSsl(endpoint)) { + // Wrap the final params (which includes connections through zero or more + // proxies) in SSLSocketParams to handle SSL to to the endpoint. + // TODO(crbug.com/1206799): Pass `endpoint` directly (preserving scheme + // when available)? + params = MakeSSLSocketParams(std::move(params), ToHostPortPair(endpoint), + ssl_config, network_anonymization_key); + } + + return params; +} + +} // namespace net diff --git a/net/socket/connect_job_params_factory.h b/net/socket/connect_job_params_factory.h new file mode 100644 index 000000000..73394f72f --- /dev/null +++ b/net/socket/connect_job_params_factory.h @@ -0,0 +1,55 @@ +// Copyright 2024 The Chromium Authors +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_CONNECT_JOB_PARAMS_FACTORY_H_ +#define NET_SOCKET_CONNECT_JOB_PARAMS_FACTORY_H_ + +#include <memory> +#include <optional> +#include <vector> + +#include "net/base/host_port_pair.h" +#include "net/base/network_anonymization_key.h" +#include "net/base/privacy_mode.h" +#include "net/base/request_priority.h" +#include "net/dns/public/secure_dns_policy.h" +#include "net/http/http_proxy_connect_job.h" +#include "net/socket/connect_job.h" +#include "net/socket/connect_job_factory.h" +#include "net/socket/socks_connect_job.h" +#include "net/socket/ssl_connect_job.h" +#include "net/socket/transport_connect_job.h" +#include "third_party/abseil-cpp/absl/types/variant.h" + +namespace net { + +class NetworkAnonymizationKey; +struct NetworkTrafficAnnotationTag; +class ProxyChain; +struct SSLConfig; + +// Abstraction over the param types for various connect jobs. +using ConnectJobParams = absl::variant<scoped_refptr<HttpProxySocketParams>, + scoped_refptr<SOCKSSocketParams>, + scoped_refptr<TransportSocketParams>, + scoped_refptr<SSLSocketParams>>; + +NET_EXPORT_PRIVATE ConnectJobParams ConstructConnectJobParams( + const ConnectJobFactory::Endpoint& endpoint, + const ProxyChain& proxy_chain, + const std::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag, + const std::vector<SSLConfig::CertAndStatus>& allowed_bad_certs, + ConnectJobFactory::AlpnMode alpn_mode, + bool force_tunnel, + PrivacyMode privacy_mode, + const OnHostResolutionCallback& resolution_callback, + const NetworkAnonymizationKey& network_anonymization_key, + SecureDnsPolicy secure_dns_policy, + bool disable_cert_network_fetches, + const CommonConnectJobParams* common_connect_job_params, + const NetworkAnonymizationKey& proxy_dns_network_anonymization_key); + +} // namespace net + +#endif // NET_SOCKET_CONNECT_JOB_PARAMS_FACTORY_H_ diff --git a/net/socket/connect_job_params_factory_unittest.cc b/net/socket/connect_job_params_factory_unittest.cc new file mode 100644 index 000000000..7219455c1 --- /dev/null +++ b/net/socket/connect_job_params_factory_unittest.cc @@ -0,0 +1,770 @@ +// Copyright 2024 The Chromium Authors +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/connect_job_params_factory.h" + +#include <ostream> +#include <tuple> + +#include "base/containers/flat_set.h" +#include "base/memory/scoped_refptr.h" +#include "net/base/host_port_pair.h" +#include "net/base/network_anonymization_key.h" +#include "net/base/privacy_mode.h" +#include "net/base/proxy_chain.h" +#include "net/base/proxy_server.h" +#include "net/base/schemeful_site.h" +#include "net/dns/public/secure_dns_policy.h" +#include "net/http/http_proxy_connect_job.h" +#include "net/socket/connect_job_factory.h" +#include "net/socket/next_proto.h" +#include "net/socket/socks_connect_job.h" +#include "net/socket/ssl_connect_job.h" +#include "net/socket/transport_connect_job.h" +#include "net/ssl/ssl_config.h" +#include "net/traffic_annotation/network_traffic_annotation_test_helper.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "third_party/abseil-cpp/absl/types/variant.h" +#include "url/gurl.h" +#include "url/scheme_host_port.h" + +namespace net { + +namespace { + +struct TestParams { + using ParamTuple = std::tuple<bool, + PrivacyMode, + SecureDnsPolicy, + ConnectJobFactory::AlpnMode, + bool>; + + explicit TestParams(ParamTuple tup) + : disable_cert_network_fetches(std::get<0>(tup)), + privacy_mode(std::get<1>(tup)), + secure_dns_policy(std::get<2>(tup)), + alpn_mode(std::get<3>(tup)), + enable_early_data(std::get<4>(tup)) {} + + bool disable_cert_network_fetches; + PrivacyMode privacy_mode; + SecureDnsPolicy secure_dns_policy; + ConnectJobFactory::AlpnMode alpn_mode; + bool enable_early_data; +}; + +std::ostream& operator<<(std::ostream& os, const TestParams& test_params) { + os << "TestParams {.disable_cert_network_fetches=" + << test_params.disable_cert_network_fetches; + os << ", .privacy_mode=" << test_params.privacy_mode; + os << ", .secure_dns_policy=" + << (test_params.secure_dns_policy == SecureDnsPolicy::kAllow ? "kAllow" + : "kDisable"); + os << ", .alpn_mode=" + << (test_params.alpn_mode == ConnectJobFactory::AlpnMode::kDisabled + ? "kDisabled" + : test_params.alpn_mode == ConnectJobFactory::AlpnMode::kHttp11Only + ? "kHttp11Only" + : "kHttpAll"); + os << ", .enable_early_data=" << test_params.enable_early_data; + os << "}"; + return os; +} + +// Get a string describing the params variant. +const char* ParamsName(ConnectJobParams& params) { + if (absl::holds_alternative<scoped_refptr<HttpProxySocketParams>>(params)) { + return "HttpProxySocketParams"; + } + if (absl::holds_alternative<scoped_refptr<SOCKSSocketParams>>(params)) { + return "SOCKSSocketParams"; + } + if (absl::holds_alternative<scoped_refptr<SSLSocketParams>>(params)) { + return "SSLSocketParams"; + } + if (absl::holds_alternative<scoped_refptr<TransportSocketParams>>(params)) { + return "TransportSocketParams"; + } + return "Unknown"; +} + +scoped_refptr<HttpProxySocketParams> ExpectHttpProxySocketParams( + ConnectJobParams params) { + EXPECT_TRUE( + absl::holds_alternative<scoped_refptr<HttpProxySocketParams>>(params)) + << "Expected HttpProxySocketParams, got " << ParamsName(params); + return absl::get<scoped_refptr<HttpProxySocketParams>>(params); +} + +void VerifyHttpProxySocketParams( + scoped_refptr<HttpProxySocketParams> params, + const char* description, + const HostPortPair& endpoint, + const ProxyChain& proxy_chain, + size_t proxy_chain_index, + bool tunnel, + const NetworkAnonymizationKey& network_anonymization_key, + const SecureDnsPolicy secure_dns_policy) { + SCOPED_TRACE(testing::Message() << "Verifying " << description); + EXPECT_EQ(params->endpoint(), endpoint); + EXPECT_EQ(params->proxy_chain(), proxy_chain); + EXPECT_EQ(params->proxy_chain_index(), proxy_chain_index); + EXPECT_EQ(params->tunnel(), tunnel); + EXPECT_EQ(params->network_anonymization_key(), network_anonymization_key); + EXPECT_EQ(params->secure_dns_policy(), secure_dns_policy); +} + +scoped_refptr<SOCKSSocketParams> ExpectSOCKSSocketParams( + ConnectJobParams params) { + EXPECT_TRUE(absl::holds_alternative<scoped_refptr<SOCKSSocketParams>>(params)) + << "Expected SOCKSSocketParams, got " << ParamsName(params); + return absl::get<scoped_refptr<SOCKSSocketParams>>(params); +} + +// Verify the properties of SOCKSSocketParams. +void VerifySOCKSSocketParams( + scoped_refptr<SOCKSSocketParams>& params, + const char* description, + bool is_socks_v5, + const HostPortPair& destination, + const NetworkAnonymizationKey& network_anonymization_key) { + SCOPED_TRACE(testing::Message() << "Verifying " << description); + EXPECT_EQ(params->is_socks_v5(), is_socks_v5); + EXPECT_EQ(params->destination(), destination); + EXPECT_EQ(params->network_anonymization_key(), network_anonymization_key); +} + +// Assert that the params are TransportSocketParams and return them. +scoped_refptr<TransportSocketParams> ExpectTransportSocketParams( + ConnectJobParams params) { + EXPECT_TRUE( + absl::holds_alternative<scoped_refptr<TransportSocketParams>>(params)) + << "Expected TransportSocketParams, got " << ParamsName(params); + return absl::get<scoped_refptr<TransportSocketParams>>(params); +} + +// Verify the properties of TransportSocketParams. +void VerifyTransportSocketParams( + scoped_refptr<TransportSocketParams>& params, + const char* description, + const TransportSocketParams::Endpoint destination, + const SecureDnsPolicy secure_dns_policy, + const NetworkAnonymizationKey& network_anonymization_key, + const base::flat_set<std::string>& supported_alpns) { + SCOPED_TRACE(testing::Message() << "Verifying " << description); + EXPECT_EQ(params->destination(), destination); + EXPECT_EQ(params->secure_dns_policy(), secure_dns_policy); + EXPECT_EQ(params->network_anonymization_key(), network_anonymization_key); + EXPECT_EQ(params->supported_alpns(), supported_alpns); +} + +// Assert that the params are SSLSocketParams and return them. +scoped_refptr<SSLSocketParams> ExpectSSLSocketParams(ConnectJobParams params) { + EXPECT_TRUE(absl::holds_alternative<scoped_refptr<SSLSocketParams>>(params)) + << "Expected SSLSocketParams, got " << ParamsName(params); + return absl::get<scoped_refptr<SSLSocketParams>>(params); +} + +// Verify the properties of SSLSocketParams. +void VerifySSLSocketParams( + scoped_refptr<SSLSocketParams>& params, + const char* description, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + PrivacyMode privacy_mode, + const NetworkAnonymizationKey& network_anonymization_key) { + SCOPED_TRACE(testing::Message() << "Verifying " << description); + EXPECT_EQ(params->host_and_port(), host_and_port); + // SSLConfig doesn't implement operator==, so just check the properties the + // factory uses. + EXPECT_EQ(params->ssl_config().disable_cert_verification_network_fetches, + ssl_config.disable_cert_verification_network_fetches); + EXPECT_EQ(params->ssl_config().alpn_protos, ssl_config.alpn_protos); + EXPECT_EQ(params->ssl_config().application_settings, + ssl_config.application_settings); + EXPECT_EQ(params->ssl_config().renego_allowed_default, + ssl_config.renego_allowed_default); + EXPECT_EQ(params->ssl_config().renego_allowed_for_protos, + ssl_config.renego_allowed_for_protos); + EXPECT_EQ(params->ssl_config().privacy_mode, privacy_mode); + EXPECT_EQ(params->network_anonymization_key(), network_anonymization_key); +} + +// Calculate the ALPN protocols for the given ALPN mode. +base::flat_set<std::string> AlpnProtoStringsForMode( + ConnectJobFactory::AlpnMode alpn_mode) { + switch (alpn_mode) { + case ConnectJobFactory::AlpnMode::kDisabled: + return {}; + case ConnectJobFactory::AlpnMode::kHttp11Only: + return {"http/1.1"}; + case ConnectJobFactory::AlpnMode::kHttpAll: + return {"h2", "http/1.1"}; + } +} + +class ConnectJobParamsFactoryTest : public testing::TestWithParam<TestParams> { + public: + ConnectJobParamsFactoryTest() { + early_data_enabled_ = enable_early_data(); + switch (alpn_mode()) { + case ConnectJobFactory::AlpnMode::kDisabled: + alpn_protos_ = {}; + application_settings_ = {}; + break; + case ConnectJobFactory::AlpnMode::kHttp11Only: + alpn_protos_ = {kProtoHTTP11}; + application_settings_ = {}; + break; + case ConnectJobFactory::AlpnMode::kHttpAll: + alpn_protos_ = {kProtoHTTP2, kProtoHTTP11}; + application_settings_ = {{kProtoHTTP2, {}}}; + break; + } + } + + protected: + // Parameter accessors. + bool disable_cert_network_fetches() const { + return GetParam().disable_cert_network_fetches; + } + PrivacyMode privacy_mode() const { return GetParam().privacy_mode; } + SecureDnsPolicy secure_dns_policy() const { + return GetParam().secure_dns_policy; + } + ConnectJobFactory::AlpnMode alpn_mode() const { return GetParam().alpn_mode; } + bool enable_early_data() const { return GetParam().enable_early_data; } + + // Create an SSL config for connection to the endpoint, based on the test + // parameters. + SSLConfig SSLConfigForEndpoint() const { + SSLConfig endpoint_ssl_config; + endpoint_ssl_config.disable_cert_verification_network_fetches = + disable_cert_network_fetches(); + endpoint_ssl_config.early_data_enabled = enable_early_data(); + switch (alpn_mode()) { + case ConnectJobFactory::AlpnMode::kDisabled: + endpoint_ssl_config.alpn_protos = {}; + endpoint_ssl_config.application_settings = {}; + endpoint_ssl_config.renego_allowed_default = false; + endpoint_ssl_config.renego_allowed_for_protos = {}; + break; + case ConnectJobFactory::AlpnMode::kHttp11Only: + endpoint_ssl_config.alpn_protos = {kProtoHTTP11}; + endpoint_ssl_config.application_settings = {}; + endpoint_ssl_config.renego_allowed_default = true; + endpoint_ssl_config.renego_allowed_for_protos = {kProtoHTTP11}; + break; + case ConnectJobFactory::AlpnMode::kHttpAll: + endpoint_ssl_config.alpn_protos = {kProtoHTTP2, kProtoHTTP11}; + endpoint_ssl_config.application_settings = {{kProtoHTTP2, {}}}; + endpoint_ssl_config.renego_allowed_default = true; + endpoint_ssl_config.renego_allowed_for_protos = {kProtoHTTP11}; + break; + } + return endpoint_ssl_config; + } + + // Create an SSL config for connection to an HTTPS proxy, based on the test + // parameters. + SSLConfig SSLConfigForProxy() const { + SSLConfig proxy_ssl_config; + proxy_ssl_config.disable_cert_verification_network_fetches = true; + proxy_ssl_config.early_data_enabled = true; + proxy_ssl_config.renego_allowed_default = false; + proxy_ssl_config.renego_allowed_for_protos = {}; + switch (alpn_mode()) { + case ConnectJobFactory::AlpnMode::kDisabled: + proxy_ssl_config.alpn_protos = {}; + proxy_ssl_config.application_settings = {}; + break; + case ConnectJobFactory::AlpnMode::kHttp11Only: + proxy_ssl_config.alpn_protos = {kProtoHTTP11}; + proxy_ssl_config.application_settings = {}; + break; + case ConnectJobFactory::AlpnMode::kHttpAll: + proxy_ssl_config.alpn_protos = {kProtoHTTP2, kProtoHTTP11}; + proxy_ssl_config.application_settings = {{kProtoHTTP2, {}}}; + break; + } + return proxy_ssl_config; + } + + NextProtoVector alpn_protos_; + SSLConfig::ApplicationSettings application_settings_; + bool early_data_enabled_; + const CommonConnectJobParams common_connect_job_params_{ + /*client_socket_factory=*/nullptr, + /*host_resolver=*/nullptr, + /*http_auth_cache=*/nullptr, + /*http_auth_handler_factory=*/nullptr, + /*spdy_session_pool=*/nullptr, + /*quic_supported_versions=*/nullptr, + /*quic_session_pool=*/nullptr, + /*proxy_delegate=*/nullptr, + /*http_user_agent_settings=*/nullptr, + /*ssl_client_context=*/nullptr, + /*socket_performance_watcher_factory=*/nullptr, + /*network_quality_estimator=*/nullptr, + /*net_log=*/nullptr, + /*websocket_endpoint_lock_manager=*/nullptr, + /*http_server_properties=*/nullptr, + &alpn_protos_, + &application_settings_, + /*ignore_certificate_errors=*/nullptr, + &early_data_enabled_}; + + const NetworkAnonymizationKey kTestNak = + NetworkAnonymizationKey::CreateSameSite( + net::SchemefulSite(GURL("http://example.test"))); + const NetworkAnonymizationKey kProxyDnsNak = + NetworkAnonymizationKey::CreateSameSite( + net::SchemefulSite(GURL("http://example-dns.test"))); +}; + +// A connect to a simple HTTP endpoint produces just transport params. +TEST_P(ConnectJobParamsFactoryTest, HttpEndpoint) { + const url::SchemeHostPort kEndpoint(url::kHttpScheme, "test", 82); + ConnectJobParams params = ConstructConnectJobParams( + kEndpoint, ProxyChain::Direct(), + /*proxy_annotation_tag=*/std::nullopt, + /*allowed_bad_certs=*/{}, alpn_mode(), + /*force_tunnel=*/false, privacy_mode(), OnHostResolutionCallback(), + kTestNak, secure_dns_policy(), disable_cert_network_fetches(), + &common_connect_job_params_, kProxyDnsNak); + + scoped_refptr<TransportSocketParams> transport_socket_params = + ExpectTransportSocketParams(params); + VerifyTransportSocketParams( + transport_socket_params, "transport_socket_params", kEndpoint, + secure_dns_policy(), kTestNak, base::flat_set<std::string>()); +} + +// A connect to a endpoint without SSL, specified as a `SchemelessEndpoint`, +// produces just transport params. +TEST_P(ConnectJobParamsFactoryTest, UnencryptedEndpointWithoutScheme) { + const ConnectJobFactory::SchemelessEndpoint kEndpoint{ + /*using_ssl=*/false, HostPortPair("test", 82)}; + ConnectJobParams params = ConstructConnectJobParams( + kEndpoint, ProxyChain::Direct(), + /*proxy_annotation_tag=*/std::nullopt, + /*allowed_bad_certs=*/{}, alpn_mode(), + /*force_tunnel=*/false, privacy_mode(), OnHostResolutionCallback(), + kTestNak, secure_dns_policy(), disable_cert_network_fetches(), + &common_connect_job_params_, kProxyDnsNak); + + scoped_refptr<TransportSocketParams> transport_socket_params = + ExpectTransportSocketParams(params); + VerifyTransportSocketParams(transport_socket_params, + "transport_socket_params", + HostPortPair("test", 82), secure_dns_policy(), + kTestNak, base::flat_set<std::string>()); +} + +// A connect to a simple HTTPS endpoint produces SSL params wrapping transport +// params. +TEST_P(ConnectJobParamsFactoryTest, HttpsEndpoint) { + // HTTPS endpoints are not supported without ALPN. + if (alpn_mode() == ConnectJobFactory::AlpnMode::kDisabled) { + return; + } + + const url::SchemeHostPort kEndpoint(url::kHttpsScheme, "test", 82); + ConnectJobParams params = ConstructConnectJobParams( + kEndpoint, ProxyChain::Direct(), TRAFFIC_ANNOTATION_FOR_TESTS, + /*allowed_bad_certs=*/{}, alpn_mode(), + /*force_tunnel=*/false, privacy_mode(), OnHostResolutionCallback(), + kTestNak, secure_dns_policy(), disable_cert_network_fetches(), + &common_connect_job_params_, kProxyDnsNak); + + scoped_refptr<SSLSocketParams> ssl_socket_params = + ExpectSSLSocketParams(params); + SSLConfig ssl_config = SSLConfigForEndpoint(); + VerifySSLSocketParams(ssl_socket_params, "ssl_socket_params", + HostPortPair::FromSchemeHostPort(kEndpoint), ssl_config, + privacy_mode(), kTestNak); + scoped_refptr<TransportSocketParams> transport_socket_params = + ssl_socket_params->GetDirectConnectionParams(); + VerifyTransportSocketParams( + transport_socket_params, "transport_socket_params", kEndpoint, + secure_dns_policy(), kTestNak, AlpnProtoStringsForMode(alpn_mode())); +} + +// A connect to a endpoint SSL, specified as a `SchemelessEndpoint`, +// produces just transport params. +TEST_P(ConnectJobParamsFactoryTest, EncryptedEndpointWithoutScheme) { + // Encrypted endpoints without scheme are only supported without ALPN. + if (alpn_mode() != ConnectJobFactory::AlpnMode::kDisabled) { + return; + } + + const ConnectJobFactory::SchemelessEndpoint kEndpoint{ + /*using_ssl=*/true, HostPortPair("test", 4433)}; + ConnectJobParams params = ConstructConnectJobParams( + kEndpoint, ProxyChain::Direct(), TRAFFIC_ANNOTATION_FOR_TESTS, + /*allowed_bad_certs=*/{}, alpn_mode(), + /*force_tunnel=*/false, privacy_mode(), OnHostResolutionCallback(), + kTestNak, secure_dns_policy(), disable_cert_network_fetches(), + &common_connect_job_params_, kProxyDnsNak); + + scoped_refptr<SSLSocketParams> ssl_socket_params = + ExpectSSLSocketParams(params); + SSLConfig ssl_config = SSLConfigForEndpoint(); + VerifySSLSocketParams(ssl_socket_params, "ssl_socket_params", + HostPortPair("test", 4433), ssl_config, privacy_mode(), + kTestNak); + scoped_refptr<TransportSocketParams> transport_socket_params = + ssl_socket_params->GetDirectConnectionParams(); + VerifyTransportSocketParams(transport_socket_params, + "transport_socket_params", + HostPortPair("test", 4433), secure_dns_policy(), + kTestNak, AlpnProtoStringsForMode(alpn_mode())); +} + +// A connection to an HTTP endpoint via an HTTPS proxy, without forcing a +// tunnel, sets up an HttpProxySocketParams, wrapping SSLSocketParams wrapping +// TransportSocketParams, intending to use GET to the proxy. This is not +// tunneled. +TEST_P(ConnectJobParamsFactoryTest, HttpEndpointViaHttpsProxy) { + const url::SchemeHostPort kEndpoint(url::kHttpScheme, "test", 82); + ProxyChain proxy_chain = ProxyChain::FromSchemeHostAndPort( + ProxyServer::SCHEME_HTTPS, "proxy", 443); + ConnectJobParams params = ConstructConnectJobParams( + kEndpoint, proxy_chain, TRAFFIC_ANNOTATION_FOR_TESTS, + /*allowed_bad_certs=*/{}, alpn_mode(), + /*force_tunnel=*/false, privacy_mode(), OnHostResolutionCallback(), + kTestNak, secure_dns_policy(), disable_cert_network_fetches(), + &common_connect_job_params_, kProxyDnsNak); + + scoped_refptr<HttpProxySocketParams> http_proxy_socket_params = + ExpectHttpProxySocketParams(params); + VerifyHttpProxySocketParams( + http_proxy_socket_params, "http_proxy_socket_params", + HostPortPair::FromSchemeHostPort(kEndpoint), proxy_chain, + /*proxy_chain_index=*/0, + /*tunnel=*/false, kTestNak, secure_dns_policy()); + + scoped_refptr<SSLSocketParams> ssl_socket_params = + http_proxy_socket_params->ssl_params(); + ASSERT_TRUE(ssl_socket_params); + SSLConfig ssl_config = SSLConfigForProxy(); + VerifySSLSocketParams(ssl_socket_params, "ssl_socket_params", + HostPortPair::FromString("proxy:443"), ssl_config, + PrivacyMode::PRIVACY_MODE_DISABLED, kTestNak); + + scoped_refptr<TransportSocketParams> transport_socket_params = + ssl_socket_params->GetDirectConnectionParams(); + VerifyTransportSocketParams( + transport_socket_params, "transport_socket_params", + HostPortPair("proxy", 443), secure_dns_policy(), kProxyDnsNak, + AlpnProtoStringsForMode(alpn_mode())); +} + +// A connection to an HTTPS endpoint via an HTTPS proxy, +// sets up an SSLSocketParams, wrapping HttpProxySocketParams, wrapping +// SSLSocketParams, wrapping TransportSocketParams. This is always tunneled. +TEST_P(ConnectJobParamsFactoryTest, HttpsEndpointViaHttpsProxy) { + // HTTPS endpoints are not supported without ALPN. + if (alpn_mode() == ConnectJobFactory::AlpnMode::kDisabled) { + return; + } + + const url::SchemeHostPort kEndpoint(url::kHttpsScheme, "test", 82); + ProxyChain proxy_chain = ProxyChain::FromSchemeHostAndPort( + ProxyServer::SCHEME_HTTPS, "proxy", 443); + ConnectJobParams params = ConstructConnectJobParams( + kEndpoint, proxy_chain, TRAFFIC_ANNOTATION_FOR_TESTS, + /*allowed_bad_certs=*/{}, alpn_mode(), + /*force_tunnel=*/false, privacy_mode(), OnHostResolutionCallback(), + kTestNak, secure_dns_policy(), disable_cert_network_fetches(), + &common_connect_job_params_, kProxyDnsNak); + + scoped_refptr<SSLSocketParams> endpoint_ssl_socket_params = + ExpectSSLSocketParams(params); + SSLConfig endpoint_ssl_config = SSLConfigForEndpoint(); + VerifySSLSocketParams(endpoint_ssl_socket_params, + "endpoint_ssl_socket_params", + HostPortPair::FromSchemeHostPort(kEndpoint), + endpoint_ssl_config, privacy_mode(), kTestNak); + + scoped_refptr<HttpProxySocketParams> http_proxy_socket_params = + endpoint_ssl_socket_params->GetHttpProxyConnectionParams(); + VerifyHttpProxySocketParams( + http_proxy_socket_params, "http_proxy_socket_params", + HostPortPair::FromSchemeHostPort(kEndpoint), proxy_chain, + /*proxy_chain_index=*/0, + /*tunnel=*/true, kTestNak, secure_dns_policy()); + + scoped_refptr<SSLSocketParams> proxy_ssl_socket_params = + http_proxy_socket_params->ssl_params(); + ASSERT_TRUE(proxy_ssl_socket_params); + SSLConfig proxy_ssl_config = SSLConfigForProxy(); + VerifySSLSocketParams(proxy_ssl_socket_params, "proxy_ssl_socket_params", + HostPortPair::FromString("proxy:443"), proxy_ssl_config, + PrivacyMode::PRIVACY_MODE_DISABLED, kTestNak); + + scoped_refptr<TransportSocketParams> transport_socket_params = + proxy_ssl_socket_params->GetDirectConnectionParams(); + VerifyTransportSocketParams( + transport_socket_params, "transport_socket_params", + HostPortPair("proxy", 443), secure_dns_policy(), kProxyDnsNak, + AlpnProtoStringsForMode(alpn_mode())); +} + +// A connection to an HTTPS endpoint via an HTTP proxy +// sets up an SSLSocketParams, wrapping HttpProxySocketParams, wrapping +// TransportSocketParams. This is always tunneled. +TEST_P(ConnectJobParamsFactoryTest, HttpsEndpointViaHttpProxy) { + if (alpn_mode() == ConnectJobFactory::AlpnMode::kDisabled) { + return; + } + + const url::SchemeHostPort kEndpoint(url::kHttpsScheme, "test", 82); + ProxyChain proxy_chain = + ProxyChain::FromSchemeHostAndPort(ProxyServer::SCHEME_HTTP, "proxy", 80); + ConnectJobParams params = ConstructConnectJobParams( + kEndpoint, proxy_chain, TRAFFIC_ANNOTATION_FOR_TESTS, + /*allowed_bad_certs=*/{}, alpn_mode(), + /*force_tunnel=*/false, privacy_mode(), OnHostResolutionCallback(), + kTestNak, secure_dns_policy(), disable_cert_network_fetches(), + &common_connect_job_params_, kProxyDnsNak); + + scoped_refptr<SSLSocketParams> endpoint_ssl_socket_params = + ExpectSSLSocketParams(params); + SSLConfig endpoint_ssl_config = SSLConfigForEndpoint(); + VerifySSLSocketParams(endpoint_ssl_socket_params, + "endpoint_ssl_socket_params", + HostPortPair::FromSchemeHostPort(kEndpoint), + endpoint_ssl_config, privacy_mode(), kTestNak); + + scoped_refptr<HttpProxySocketParams> http_proxy_socket_params = + endpoint_ssl_socket_params->GetHttpProxyConnectionParams(); + VerifyHttpProxySocketParams( + http_proxy_socket_params, "http_proxy_socket_params", + HostPortPair::FromSchemeHostPort(kEndpoint), proxy_chain, + /*proxy_chain_index=*/0, + /*tunnel=*/true, kTestNak, secure_dns_policy()); + + scoped_refptr<TransportSocketParams> transport_socket_params = + http_proxy_socket_params->transport_params(); + ASSERT_TRUE(transport_socket_params); + VerifyTransportSocketParams(transport_socket_params, + "transport_socket_params", + HostPortPair("proxy", 80), secure_dns_policy(), + kProxyDnsNak, base::flat_set<std::string>({})); +} + +// A connection to an HTTP endpoint via a SOCKS proxy, +// sets up an SOCKSSocketParams wrapping TransportSocketParams. +TEST_P(ConnectJobParamsFactoryTest, HttpEndpointViaSOCKSProxy) { + const url::SchemeHostPort kEndpoint(url::kHttpScheme, "test", 82); + ProxyChain proxy_chain = ProxyChain::FromSchemeHostAndPort( + ProxyServer::SCHEME_SOCKS4, "proxy", 999); + ConnectJobParams params = ConstructConnectJobParams( + kEndpoint, proxy_chain, TRAFFIC_ANNOTATION_FOR_TESTS, + /*allowed_bad_certs=*/{}, alpn_mode(), + /*force_tunnel=*/false, privacy_mode(), OnHostResolutionCallback(), + kTestNak, secure_dns_policy(), disable_cert_network_fetches(), + &common_connect_job_params_, kProxyDnsNak); + + scoped_refptr<SOCKSSocketParams> socks_socket_params = + ExpectSOCKSSocketParams(params); + VerifySOCKSSocketParams(socks_socket_params, "socks_socket_params", + /*is_socks_v5=*/false, + HostPortPair::FromSchemeHostPort(kEndpoint), + kTestNak); + + scoped_refptr<TransportSocketParams> transport_socket_params = + socks_socket_params->transport_params(); + VerifyTransportSocketParams( + transport_socket_params, "transport_socket_params", + HostPortPair("proxy", 999), secure_dns_policy(), kProxyDnsNak, {}); +} + +// A connection to an HTTPS endpoint via a SOCKS proxy, +// sets up an SSLSocketParams wrapping SOCKSSocketParams wrapping +// TransportSocketParams. +TEST_P(ConnectJobParamsFactoryTest, HttpsEndpointViaSOCKSProxy) { + // HTTPS endpoints are not supported without ALPN. + if (alpn_mode() == ConnectJobFactory::AlpnMode::kDisabled) { + return; + } + + const url::SchemeHostPort kEndpoint(url::kHttpsScheme, "test", 82); + ProxyChain proxy_chain = ProxyChain::FromSchemeHostAndPort( + ProxyServer::SCHEME_SOCKS5, "proxy", 999); + ConnectJobParams params = ConstructConnectJobParams( + kEndpoint, proxy_chain, TRAFFIC_ANNOTATION_FOR_TESTS, + /*allowed_bad_certs=*/{}, alpn_mode(), + /*force_tunnel=*/false, privacy_mode(), OnHostResolutionCallback(), + kTestNak, secure_dns_policy(), disable_cert_network_fetches(), + &common_connect_job_params_, kProxyDnsNak); + + scoped_refptr<SSLSocketParams> endpoint_ssl_socket_params = + ExpectSSLSocketParams(params); + SSLConfig endpoint_ssl_config = SSLConfigForEndpoint(); + VerifySSLSocketParams(endpoint_ssl_socket_params, + "endpoint_ssl_socket_params", + HostPortPair::FromSchemeHostPort(kEndpoint), + endpoint_ssl_config, privacy_mode(), kTestNak); + + scoped_refptr<SOCKSSocketParams> socks_socket_params = + endpoint_ssl_socket_params->GetSocksProxyConnectionParams(); + VerifySOCKSSocketParams(socks_socket_params, "socks_socket_params", + /*is_socks_v5=*/true, + HostPortPair::FromSchemeHostPort(kEndpoint), + kTestNak); + + scoped_refptr<TransportSocketParams> transport_socket_params = + socks_socket_params->transport_params(); + VerifyTransportSocketParams( + transport_socket_params, "transport_socket_params", + HostPortPair("proxy", 999), secure_dns_policy(), kProxyDnsNak, {}); +} + +// A connection to an HTTP endpoint via a two-proxy HTTPS chain +// sets up the required parameters. +TEST_P(ConnectJobParamsFactoryTest, HttpEndpointViaHttpsProxyViaHttpsProxy) { + const url::SchemeHostPort kEndpoint(url::kHttpScheme, "test", 82); + ProxyChain proxy_chain = ProxyChain::ForIpProtection({ + ProxyServer::FromSchemeHostAndPort(ProxyServer::SCHEME_HTTPS, "proxya", + 443), + ProxyServer::FromSchemeHostAndPort(ProxyServer::SCHEME_HTTPS, "proxyb", + 443), + }); + ConnectJobParams params = ConstructConnectJobParams( + kEndpoint, proxy_chain, TRAFFIC_ANNOTATION_FOR_TESTS, + /*allowed_bad_certs=*/{}, alpn_mode(), + /*force_tunnel=*/false, privacy_mode(), OnHostResolutionCallback(), + kTestNak, secure_dns_policy(), disable_cert_network_fetches(), + &common_connect_job_params_, kProxyDnsNak); + + scoped_refptr<HttpProxySocketParams> http_proxy_socket_params_b = + ExpectHttpProxySocketParams(params); + VerifyHttpProxySocketParams( + http_proxy_socket_params_b, "http_proxy_socket_params_b", + HostPortPair::FromSchemeHostPort(kEndpoint), proxy_chain, + /*proxy_chain_index=*/1, + /*tunnel=*/true, kTestNak, secure_dns_policy()); + + scoped_refptr<SSLSocketParams> proxy_ssl_socket_params_b = + http_proxy_socket_params_b->ssl_params(); + ASSERT_TRUE(proxy_ssl_socket_params_b); + SSLConfig proxy_ssl_config = SSLConfigForProxy(); + VerifySSLSocketParams(proxy_ssl_socket_params_b, "proxy_ssl_socket_params_b", + HostPortPair::FromString("proxyb:443"), + proxy_ssl_config, PrivacyMode::PRIVACY_MODE_DISABLED, + kTestNak); + + scoped_refptr<HttpProxySocketParams> http_proxy_socket_params_a = + proxy_ssl_socket_params_b->GetHttpProxyConnectionParams(); + VerifyHttpProxySocketParams(http_proxy_socket_params_a, + "http_proxy_socket_params_a", + HostPortPair("proxyb", 443), proxy_chain, + /*proxy_chain_index=*/0, + /*tunnel=*/true, kTestNak, secure_dns_policy()); + + scoped_refptr<SSLSocketParams> proxy_ssl_socket_params_a = + http_proxy_socket_params_a->ssl_params(); + ASSERT_TRUE(proxy_ssl_socket_params_a); + VerifySSLSocketParams(proxy_ssl_socket_params_a, "proxy_ssl_socket_params_a", + HostPortPair::FromString("proxya:443"), + proxy_ssl_config, PrivacyMode::PRIVACY_MODE_DISABLED, + kTestNak); + + scoped_refptr<TransportSocketParams> transport_socket_params = + proxy_ssl_socket_params_a->GetDirectConnectionParams(); + VerifyTransportSocketParams( + transport_socket_params, "transport_socket_params", + HostPortPair("proxya", 443), secure_dns_policy(), kProxyDnsNak, + AlpnProtoStringsForMode(alpn_mode())); +} + +// A connection to an HTTPS endpoint via a two-proxy HTTPS chain +// sets up the required parameters. +TEST_P(ConnectJobParamsFactoryTest, HttpsEndpointViaHttpsProxyViaHttpsProxy) { + // HTTPS endpoints are not supported without ALPN. + if (alpn_mode() == ConnectJobFactory::AlpnMode::kDisabled) { + return; + } + + const url::SchemeHostPort kEndpoint(url::kHttpsScheme, "test", 82); + ProxyChain proxy_chain = ProxyChain::ForIpProtection({ + ProxyServer::FromSchemeHostAndPort(ProxyServer::SCHEME_HTTPS, "proxya", + 443), + ProxyServer::FromSchemeHostAndPort(ProxyServer::SCHEME_HTTPS, "proxyb", + 443), + }); + ConnectJobParams params = ConstructConnectJobParams( + kEndpoint, proxy_chain, TRAFFIC_ANNOTATION_FOR_TESTS, + /*allowed_bad_certs=*/{}, alpn_mode(), + /*force_tunnel=*/false, privacy_mode(), OnHostResolutionCallback(), + kTestNak, secure_dns_policy(), disable_cert_network_fetches(), + &common_connect_job_params_, kProxyDnsNak); + + scoped_refptr<SSLSocketParams> endpoint_ssl_socket_params = + ExpectSSLSocketParams(params); + SSLConfig endpoint_ssl_config = SSLConfigForEndpoint(); + VerifySSLSocketParams(endpoint_ssl_socket_params, + "endpoint_ssl_socket_params", + HostPortPair::FromSchemeHostPort(kEndpoint), + endpoint_ssl_config, privacy_mode(), kTestNak); + + scoped_refptr<HttpProxySocketParams> http_proxy_socket_params_b = + endpoint_ssl_socket_params->GetHttpProxyConnectionParams(); + VerifyHttpProxySocketParams( + http_proxy_socket_params_b, "http_proxy_socket_params_b", + HostPortPair::FromSchemeHostPort(kEndpoint), proxy_chain, + /*proxy_chain_index=*/1, + /*tunnel=*/true, kTestNak, secure_dns_policy()); + + scoped_refptr<SSLSocketParams> proxy_ssl_socket_params_b = + http_proxy_socket_params_b->ssl_params(); + ASSERT_TRUE(proxy_ssl_socket_params_b); + SSLConfig proxy_ssl_config = SSLConfigForProxy(); + VerifySSLSocketParams(proxy_ssl_socket_params_b, "proxy_ssl_socket_params_b", + HostPortPair::FromString("proxyb:443"), + proxy_ssl_config, PrivacyMode::PRIVACY_MODE_DISABLED, + kTestNak); + + scoped_refptr<HttpProxySocketParams> http_proxy_socket_params_a = + proxy_ssl_socket_params_b->GetHttpProxyConnectionParams(); + VerifyHttpProxySocketParams(http_proxy_socket_params_a, + "http_proxy_socket_params_a", + HostPortPair("proxyb", 443), proxy_chain, + /*proxy_chain_index=*/0, + /*tunnel=*/true, kTestNak, secure_dns_policy()); + + scoped_refptr<SSLSocketParams> proxy_ssl_socket_params_a = + http_proxy_socket_params_a->ssl_params(); + ASSERT_TRUE(proxy_ssl_socket_params_a); + VerifySSLSocketParams(proxy_ssl_socket_params_a, "proxy_ssl_socket_params_a", + HostPortPair::FromString("proxya:443"), + proxy_ssl_config, PrivacyMode::PRIVACY_MODE_DISABLED, + kTestNak); + + scoped_refptr<TransportSocketParams> transport_socket_params = + proxy_ssl_socket_params_a->GetDirectConnectionParams(); + VerifyTransportSocketParams( + transport_socket_params, "transport_socket_params", + HostPortPair("proxya", 443), secure_dns_policy(), kProxyDnsNak, + AlpnProtoStringsForMode(alpn_mode())); +} + +INSTANTIATE_TEST_SUITE_P( + All, + ConnectJobParamsFactoryTest, + testing::ConvertGenerator<TestParams::ParamTuple>(testing::Combine( + testing::Values(false, true), + testing::Values(PrivacyMode::PRIVACY_MODE_ENABLED, + PrivacyMode::PRIVACY_MODE_DISABLED), + testing::Values(SecureDnsPolicy::kAllow, SecureDnsPolicy::kDisable), + testing::Values(ConnectJobFactory::AlpnMode::kDisabled, + ConnectJobFactory::AlpnMode::kHttp11Only, + ConnectJobFactory::AlpnMode::kHttpAll), + testing::Values(false, true)))); + +} // namespace + +} // namespace net diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index 21366f033..eb3614379 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -10,6 +10,7 @@ #include <stdio.h> #include <memory> +#include <ostream> #include <string> #include <utility> #include <vector> @@ -250,6 +251,60 @@ bool StaticSocketDataHelper::VerifyWriteData(const std::string& data, return expected_data == actual_data; } +void StaticSocketDataHelper::ExpectAllReadDataConsumed( + SocketDataPrinter* printer) const { + if (AllReadDataConsumed()) { + return; + } + + std::ostringstream msg; + if (read_index_ < read_count()) { + msg << "Unconsumed reads:\n"; + for (size_t i = read_index_; i < read_count(); i++) { + msg << (reads_[i].mode == ASYNC ? "ASYNC" : "SYNC") << " MockRead seq " + << reads_[i].sequence_number << ":\n"; + if (reads_[i].result != OK) { + msg << "Result: " << reads_[i].result << "\n"; + } + if (reads_[i].data) { + std::string data(reads_[i].data, reads_[i].data_len); + if (printer) { + msg << printer->PrintWrite(data); + } + msg << HexDump(data); + } + } + } + EXPECT_TRUE(AllReadDataConsumed()) << msg.str(); +} + +void StaticSocketDataHelper::ExpectAllWriteDataConsumed( + SocketDataPrinter* printer) const { + if (AllWriteDataConsumed()) { + return; + } + + std::ostringstream msg; + if (write_index_ < write_count()) { + msg << "Unconsumed writes:\n"; + for (size_t i = write_index_; i < write_count(); i++) { + msg << (writes_[i].mode == ASYNC ? "ASYNC" : "SYNC") << " MockWrite seq " + << writes_[i].sequence_number << ":\n"; + if (writes_[i].result != OK) { + msg << "Result: " << writes_[i].result << "\n"; + } + if (writes_[i].data) { + std::string data(writes_[i].data, writes_[i].data_len); + if (printer) { + msg << printer->PrintWrite(data); + } + msg << HexDump(data); + } + } + } + EXPECT_TRUE(AllWriteDataConsumed()) << msg.str(); +} + const MockWrite& StaticSocketDataHelper::PeekRealWrite() const { for (size_t i = write_index_; i < write_count(); i++) { if (writes_[i].mode != ASYNC || writes_[i].result != ERR_IO_PENDING) @@ -548,6 +603,14 @@ bool SequencedSocketData::AllWriteDataConsumed() const { return helper_.AllWriteDataConsumed(); } +void SequencedSocketData::ExpectAllReadDataConsumed() const { + helper_.ExpectAllReadDataConsumed(printer_.get()); +} + +void SequencedSocketData::ExpectAllWriteDataConsumed() const { + helper_.ExpectAllWriteDataConsumed(printer_.get()); +} + bool SequencedSocketData::IsIdle() const { // If |busy_before_sync_reads_| is not set, always considered idle. If // no reads left, or the next operation is a write, also consider it idle. @@ -847,10 +910,6 @@ std::unique_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket( EXPECT_EQ(*next_ssl_data->expected_network_anonymization_key, ssl_config.network_anonymization_key); } - if (next_ssl_data->expected_disable_sha1_server_signatures) { - EXPECT_EQ(*next_ssl_data->expected_disable_sha1_server_signatures, - ssl_config.disable_sha1_server_signatures); - } if (next_ssl_data->expected_ech_config_list) { EXPECT_EQ(*next_ssl_data->expected_ech_config_list, ssl_config.ech_config_list); diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index 13b39cb91..5550b6a18 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -232,6 +232,15 @@ struct MockWriteResult { int result; }; +class SocketDataPrinter { + public: + ~SocketDataPrinter() = default; + + // Prints the write in |data| using some sort of protocol-specific + // format. + virtual std::string PrintWrite(const std::string& data) = 0; +}; + // The SocketDataProvider is an interface used by the MockClientSocket // for getting data about individual reads and writes on the socket. Can be // used with at most one socket at a time. @@ -382,15 +391,6 @@ class AsyncSocket { virtual void OnDataProviderDestroyed() = 0; }; -class SocketDataPrinter { - public: - ~SocketDataPrinter() = default; - - // Prints the write in |data| using some sort of protocol-specific - // format. - virtual std::string PrintWrite(const std::string& data) = 0; -}; - // StaticSocketDataHelper manages a list of reads and writes. class StaticSocketDataHelper { public: @@ -427,6 +427,9 @@ class StaticSocketDataHelper { bool AllReadDataConsumed() const { return read_index() >= read_count(); } bool AllWriteDataConsumed() const { return write_index() >= write_count(); } + void ExpectAllReadDataConsumed(SocketDataPrinter* printer) const; + void ExpectAllWriteDataConsumed(SocketDataPrinter* printer) const; + private: // Returns the next available read or write that is not a pause event. CHECK // fails if no data is available. @@ -532,7 +535,6 @@ struct SSLSocketDataProvider { std::optional<HostPortPair> expected_host_and_port; std::optional<bool> expected_ignore_certificate_errors; std::optional<NetworkAnonymizationKey> expected_network_anonymization_key; - std::optional<bool> expected_disable_sha1_server_signatures; std::optional<std::vector<uint8_t>> expected_ech_config_list; bool is_connect_data_consumed = false; @@ -571,6 +573,10 @@ class SequencedSocketData : public SocketDataProvider { bool IsIdle() const override; void CancelPendingRead() override; + // EXPECTs that all data has been consumed, printing any un-consumed data. + void ExpectAllReadDataConsumed() const; + void ExpectAllWriteDataConsumed() const; + // An ASYNC read event with a return value of ERR_IO_PENDING will cause the // socket data to pause at that event, and advance no further, until Resume is // invoked. At that point, the socket will continue at the next event in the diff --git a/net/socket/ssl_client_socket_impl.cc b/net/socket/ssl_client_socket_impl.cc index 236de0c0b..e036710b7 100644 --- a/net/socket/ssl_client_socket_impl.cc +++ b/net/socket/ssl_client_socket_impl.cc @@ -65,8 +65,6 @@ #include "third_party/boringssl/src/include/openssl/evp.h" #include "third_party/boringssl/src/include/openssl/mem.h" #include "third_party/boringssl/src/include/openssl/ssl.h" -#include "third_party/boringssl/src/pki/parse_certificate.h" -#include "third_party/boringssl/src/pki/parse_values.h" namespace net { @@ -148,96 +146,6 @@ base::Value::Dict NetLogSSLMessageParams(bool is_write, return dict; } -// This enum is used in histograms, so values may not be reused. -enum class RSAKeyUsage { - // The TLS cipher suite was not RSA or ECDHE_RSA. - kNotRSA = 0, - // The Key Usage extension is not present, which is consistent with TLS usage. - kOKNoExtension = 1, - // The Key Usage extension has both the digitalSignature and keyEncipherment - // bits, which is consistent with TLS usage. - kOKHaveBoth = 2, - // The Key Usage extension contains only the digitalSignature bit, which is - // consistent with TLS usage. - kOKHaveDigitalSignature = 3, - // The Key Usage extension contains only the keyEncipherment bit, which is - // consistent with TLS usage. - kOKHaveKeyEncipherment = 4, - // The Key Usage extension is missing the digitalSignature bit. - kMissingDigitalSignature = 5, - // The Key Usage extension is missing the keyEncipherment bit. - kMissingKeyEncipherment = 6, - // There was an error processing the certificate. - kError = 7, - - kLastValue = kError, -}; - -RSAKeyUsage CheckRSAKeyUsage(const X509Certificate* cert, - const SSL_CIPHER* cipher) { - bool need_key_encipherment = false; - switch (SSL_CIPHER_get_kx_nid(cipher)) { - case NID_kx_rsa: - need_key_encipherment = true; - break; - case NID_kx_ecdhe: - if (SSL_CIPHER_get_auth_nid(cipher) != NID_auth_rsa) { - return RSAKeyUsage::kNotRSA; - } - break; - default: - return RSAKeyUsage::kNotRSA; - } - - const CRYPTO_BUFFER* buffer = cert->cert_buffer(); - bssl::der::Input tbs_certificate_tlv; - bssl::der::Input signature_algorithm_tlv; - bssl::der::BitString signature_value; - bssl::ParsedTbsCertificate tbs; - if (!bssl::ParseCertificate(bssl::der::Input(CRYPTO_BUFFER_data(buffer), - CRYPTO_BUFFER_len(buffer)), - &tbs_certificate_tlv, &signature_algorithm_tlv, - &signature_value, nullptr) || - !ParseTbsCertificate(tbs_certificate_tlv, - x509_util::DefaultParseCertificateOptions(), &tbs, - nullptr)) { - return RSAKeyUsage::kError; - } - - if (!tbs.extensions_tlv) { - return RSAKeyUsage::kOKNoExtension; - } - - std::map<bssl::der::Input, bssl::ParsedExtension> extensions; - if (!ParseExtensions(tbs.extensions_tlv.value(), &extensions)) { - return RSAKeyUsage::kError; - } - bssl::ParsedExtension key_usage_ext; - if (!ConsumeExtension(bssl::der::Input(bssl::kKeyUsageOid), &extensions, - &key_usage_ext)) { - return RSAKeyUsage::kOKNoExtension; - } - bssl::der::BitString key_usage; - if (!bssl::ParseKeyUsage(key_usage_ext.value, &key_usage)) { - return RSAKeyUsage::kError; - } - - bool have_digital_signature = - key_usage.AssertsBit(bssl::KEY_USAGE_BIT_DIGITAL_SIGNATURE); - bool have_key_encipherment = - key_usage.AssertsBit(bssl::KEY_USAGE_BIT_KEY_ENCIPHERMENT); - if (have_digital_signature && have_key_encipherment) { - return RSAKeyUsage::kOKHaveBoth; - } - - if (need_key_encipherment) { - return have_key_encipherment ? RSAKeyUsage::kOKHaveKeyEncipherment - : RSAKeyUsage::kMissingKeyEncipherment; - } - return have_digital_signature ? RSAKeyUsage::kOKHaveDigitalSignature - : RSAKeyUsage::kMissingDigitalSignature; -} - bool HostIsIPAddressNoBrackets(base::StringPiece host) { // Note this cannot directly call url::HostIsIPAddress, because that function // expects bracketed IPv6 literals. By the time hosts reach SSLClientSocket, @@ -836,17 +744,18 @@ int SSLClientSocketImpl::Init() { return ERR_UNEXPECTED; } - if (ssl_config_.disable_sha1_server_signatures) { - static const uint16_t kVerifyPrefs[] = { - SSL_SIGN_ECDSA_SECP256R1_SHA256, SSL_SIGN_RSA_PSS_RSAE_SHA256, - SSL_SIGN_RSA_PKCS1_SHA256, SSL_SIGN_ECDSA_SECP384R1_SHA384, - SSL_SIGN_RSA_PSS_RSAE_SHA384, SSL_SIGN_RSA_PKCS1_SHA384, - SSL_SIGN_RSA_PSS_RSAE_SHA512, SSL_SIGN_RSA_PKCS1_SHA512, - }; - if (!SSL_set_verify_algorithm_prefs(ssl_.get(), kVerifyPrefs, - std::size(kVerifyPrefs))) { - return ERR_UNEXPECTED; - } + // Disable SHA-1 server signatures. + // TODO(crbug.com/boringssl/699): Once the default is flipped in BoringSSL, we + // no longer need to override it. + static const uint16_t kVerifyPrefs[] = { + SSL_SIGN_ECDSA_SECP256R1_SHA256, SSL_SIGN_RSA_PSS_RSAE_SHA256, + SSL_SIGN_RSA_PKCS1_SHA256, SSL_SIGN_ECDSA_SECP384R1_SHA384, + SSL_SIGN_RSA_PSS_RSAE_SHA384, SSL_SIGN_RSA_PKCS1_SHA384, + SSL_SIGN_RSA_PSS_RSAE_SHA512, SSL_SIGN_RSA_PKCS1_SHA512, + }; + if (!SSL_set_verify_algorithm_prefs(ssl_.get(), kVerifyPrefs, + std::size(kVerifyPrefs))) { + return ERR_UNEXPECTED; } SSL_set_alps_use_new_codepoint( @@ -1027,17 +936,6 @@ int SSLClientSocketImpl::DoHandshakeComplete(int result) { // in server_cert_. CHECK(ok); - // See how feasible enforcing RSA key usage would be. See - // https://crbug.com/795089. - if (!server_cert_verify_result_.is_issued_by_known_root) { - RSAKeyUsage rsa_key_usage = CheckRSAKeyUsage( - server_cert_.get(), SSL_get_current_cipher(ssl_.get())); - if (rsa_key_usage != RSAKeyUsage::kNotRSA) { - UMA_HISTOGRAM_ENUMERATION("Net.SSLRSAKeyUsage.UnknownRoot", rsa_key_usage, - static_cast<int>(RSAKeyUsage::kLastValue) + 1); - } - } - SSLHandshakeDetails details; if (SSL_version(ssl_.get()) < TLS1_3_VERSION) { if (SSL_session_reused(ssl_.get())) { @@ -1225,17 +1123,6 @@ ssl_verify_result_t SSLClientSocketImpl::HandleVerifyResult() { cert_verifier_request_.reset(); - // Enforce keyUsage extension for RSA leaf certificates chaining up to known - // roots unconditionally. Enforcement for local anchors is, for now, - // conditional on feature flags and external configuration. See - // https://crbug.com/795089. - bool rsa_key_usage_for_local_anchors = - context_->config().rsa_key_usage_for_local_anchors_override.value_or( - base::FeatureList::IsEnabled(features::kRSAKeyUsageForLocalAnchors)); - SSL_set_enforce_rsa_key_usage( - ssl_.get(), rsa_key_usage_for_local_anchors || - server_cert_verify_result_.is_issued_by_known_root); - // If the connection was good, check HPKP and CT status simultaneously, // but prefer to treat the HPKP error as more serious, if there was one. if (result == OK) { @@ -1683,7 +1570,6 @@ SSLClientSessionCache::Key SSLClientSocketImpl::GetSessionCacheKey( key.network_anonymization_key = ssl_config_.network_anonymization_key; } key.privacy_mode = ssl_config_.privacy_mode; - key.disable_legacy_crypto = ssl_config_.disable_sha1_server_signatures; return key; } diff --git a/net/socket/ssl_client_socket_unittest.cc b/net/socket/ssl_client_socket_unittest.cc index 5f2dd445d..97493a29d 100644 --- a/net/socket/ssl_client_socket_unittest.cc +++ b/net/socket/ssl_client_socket_unittest.cc @@ -3223,14 +3223,10 @@ TEST_F(SSLClientSocketTest, SHA1) { ASSERT_TRUE( StartEmbeddedTestServer(EmbeddedTestServer::CERT_OK, server_config)); + // SHA-1 server signatures are always disabled. int rv; ASSERT_TRUE(CreateAndConnectSSLClientSocket(SSLConfig(), &rv)); - EXPECT_THAT(rv, IsOk()); - - SSLConfig config; - config.disable_sha1_server_signatures = true; - ASSERT_TRUE(CreateAndConnectSSLClientSocket(config, &rv)); - EXPECT_THAT(rv, IsError(ERR_SSL_PROTOCOL_ERROR)); + EXPECT_THAT(rv, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH)); } TEST_F(SSLClientSocketFalseStartTest, FalseStartEnabled) { @@ -4124,10 +4120,7 @@ struct KeyUsageTest { class SSLClientSocketKeyUsageTest : public SSLClientSocketTest, public ::testing::WithParamInterface< - std::tuple<KeyUsageTest, - bool /*known_root*/, - bool /*rsa_key_usage_for_local_anchors_enabled*/, - bool /*override_feature*/>> {}; + std::tuple<KeyUsageTest, bool /*known_root*/>> {}; const KeyUsageTest kKeyUsageTests[] = { // keyUsage matches cipher suite. @@ -4143,25 +4136,7 @@ const KeyUsageTest kKeyUsageTests[] = { }; TEST_P(SSLClientSocketKeyUsageTest, RSAKeyUsage) { - const auto& [test, known_root, rsa_key_usage_for_local_anchors_enabled, - override_feature] = GetParam(); - bool enable_feature; - if (override_feature) { - // Configure the feature in the opposite way that we intend, to test that - // the configuration overrides it. - enable_feature = !rsa_key_usage_for_local_anchors_enabled; - } else { - enable_feature = rsa_key_usage_for_local_anchors_enabled; - } - base::test::ScopedFeatureList scoped_feature_list; - if (enable_feature) { - scoped_feature_list.InitAndEnableFeature( - features::kRSAKeyUsageForLocalAnchors); - } else { - scoped_feature_list.InitAndDisableFeature( - features::kRSAKeyUsageForLocalAnchors); - } - + const auto& [test, known_root] = GetParam(); SSLServerConfig server_config; server_config.version_max = SSL_PROTOCOL_VERSION_TLS1_2; server_config.cipher_suite_for_testing = test.cipher_suite; @@ -4169,13 +4144,6 @@ TEST_P(SSLClientSocketKeyUsageTest, RSAKeyUsage) { scoped_refptr<X509Certificate> server_cert = embedded_test_server()->GetCertificate(); - SSLContextConfig context_config; - if (override_feature) { - context_config.rsa_key_usage_for_local_anchors_override = - rsa_key_usage_for_local_anchors_enabled; - } - ssl_config_service_->UpdateSSLConfigAndNotify(context_config); - // Certificate is trusted. CertVerifyResult verify_result; verify_result.is_issued_by_known_root = known_root; @@ -4190,7 +4158,7 @@ TEST_P(SSLClientSocketKeyUsageTest, RSAKeyUsage) { SSLInfo ssl_info; ASSERT_TRUE(sock_->GetSSLInfo(&ssl_info)); - if (test.match || (!known_root && !rsa_key_usage_for_local_anchors_enabled)) { + if (test.match) { EXPECT_THAT(rv, IsOk()); EXPECT_TRUE(sock_->IsConnected()); } else { @@ -4199,10 +4167,9 @@ TEST_P(SSLClientSocketKeyUsageTest, RSAKeyUsage) { } } -INSTANTIATE_TEST_SUITE_P( - RSAKeyUsageInstantiation, - SSLClientSocketKeyUsageTest, - Combine(ValuesIn(kKeyUsageTests), Bool(), Bool(), Bool())); +INSTANTIATE_TEST_SUITE_P(RSAKeyUsageInstantiation, + SSLClientSocketKeyUsageTest, + Combine(ValuesIn(kKeyUsageTests), Bool())); // Test that when CT is required (in this case, by the delegate), the // absence of CT information is a socket error. diff --git a/net/socket/ssl_connect_job.cc b/net/socket/ssl_connect_job.cc index f80d60d32..c1878e2a8 100644 --- a/net/socket/ssl_connect_job.cc +++ b/net/socket/ssl_connect_job.cc @@ -33,7 +33,6 @@ #include "net/ssl/ssl_cert_request_info.h" #include "net/ssl/ssl_connection_status_flags.h" #include "net/ssl/ssl_info.h" -#include "net/ssl/ssl_legacy_crypto_fallback.h" #include "third_party/boringssl/src/include/openssl/pool.h" #include "third_party/boringssl/src/include/openssl/ssl.h" @@ -52,14 +51,12 @@ SSLSocketParams::SSLSocketParams( scoped_refptr<HttpProxySocketParams> http_proxy_params, const HostPortPair& host_and_port, const SSLConfig& ssl_config, - PrivacyMode privacy_mode, NetworkAnonymizationKey network_anonymization_key) : direct_params_(std::move(direct_params)), socks_proxy_params_(std::move(socks_proxy_params)), http_proxy_params_(std::move(http_proxy_params)), host_and_port_(host_and_port), ssl_config_(ssl_config), - privacy_mode_(privacy_mode), network_anonymization_key_(network_anonymization_key) { // Only one set of lower level ConnectJob params should be non-NULL. DCHECK((direct_params_ && !socks_proxy_params_ && !http_proxy_params_) || @@ -389,13 +386,6 @@ int SSLConnectJob::DoSSLConnect() { ssl_config.ignore_certificate_errors = *common_connect_job_params()->ignore_certificate_errors; ssl_config.network_anonymization_key = params_->network_anonymization_key(); - ssl_config.privacy_mode = params_->privacy_mode(); - // We do the fallback in both cases here to ensure we separate the effect of - // disabling sha1 from the effect of having a single automatic retry - // on a potentially unreliably network connection. - ssl_config.disable_sha1_server_signatures = - disable_legacy_crypto_with_fallback_ || - !ssl_client_context()->config().InsecureHashesInTLSHandshakesEnabled(); if (ssl_client_context()->config().ech_enabled) { if (ech_retry_configs_) { @@ -425,15 +415,16 @@ int SSLConnectJob::DoSSLConnectComplete(int result) { server_address_ = IPEndPoint(); } - // Many servers which negotiate SHA-1 server signatures in TLS 1.2 actually - // support SHA-2 but preferentially sign SHA-1 if available. + // Historically, many servers which negotiated SHA-1 server signatures in + // TLS 1.2 actually support SHA-2 but preferentially sign SHA-1 if available. + // In order to get accurate metrics while deprecating SHA-1, we initially + // connected with SHA-1 disabled and then retried with enabled. // - // To get more accurate metrics, initially connect with SHA-1 disabled. If - // this fails, retry with them enabled. This keeps the legacy algorithms - // working for now, but they will only appear in metrics and DevTools if the - // site relies on them. + // SHA-1 is now always disabled, but we retained the fallback to separate the + // effect of disabling SHA-1 from the effect of having a single automatic + // retry on a potentially unreliably network connection. // - // See https://crbug.com/658905. + // TODO(https://crbug.com/658905): Remove this now redundant retry. if (disable_legacy_crypto_with_fallback_ && (result == ERR_CONNECTION_CLOSED || result == ERR_CONNECTION_RESET || result == ERR_SSL_PROTOCOL_ERROR || @@ -535,42 +526,6 @@ int SSLConnectJob::DoSSLConnectComplete(int result) { base::UmaHistogramSparse("Net.SSL_KeyExchange.ECDHE", ssl_info.key_exchange_group); } - - // Classify whether the connection required the legacy crypto fallback. - SSLLegacyCryptoFallback fallback = SSLLegacyCryptoFallback::kNoFallback; - if (!disable_legacy_crypto_with_fallback_) { - // Some servers, though they do not negotiate SHA-1, still fail the - // connection when SHA-1 is not offered. We believe these are servers - // which match the sent certificates against the ClientHello and then - // are configured with a SHA-1 certificate. - // - // SHA-1 certificate chains are no longer accepted, however servers may - // send extra unused certificates, most commonly a copy of the trust - // anchor. We only need to check for RSASSA-PKCS1-v1_5 signatures, because - // other SHA-1 signature types have already been removed from the - // ClientHello. - bool sent_sha1_cert = ssl_info.unverified_cert && - x509_util::HasRsaPkcs1Sha1Signature( - ssl_info.unverified_cert->cert_buffer()); - if (!sent_sha1_cert && ssl_info.unverified_cert) { - for (const auto& cert : - ssl_info.unverified_cert->intermediate_buffers()) { - if (x509_util::HasRsaPkcs1Sha1Signature(cert.get())) { - sent_sha1_cert = true; - break; - } - } - } - if (ssl_info.peer_signature_algorithm == SSL_SIGN_RSA_PKCS1_SHA1) { - fallback = sent_sha1_cert - ? SSLLegacyCryptoFallback::kSentSHA1CertAndUsedSHA1 - : SSLLegacyCryptoFallback::kUsedSHA1; - } else { - fallback = sent_sha1_cert ? SSLLegacyCryptoFallback::kSentSHA1Cert - : SSLLegacyCryptoFallback::kUnknownReason; - } - } - UMA_HISTOGRAM_ENUMERATION("Net.SSLLegacyCryptoFallback2", fallback); } base::UmaHistogramSparse("Net.SSL_Connection_Error", std::abs(result)); diff --git a/net/socket/ssl_connect_job.h b/net/socket/ssl_connect_job.h index fe308b3db..3df1d8719 100644 --- a/net/socket/ssl_connect_job.h +++ b/net/socket/ssl_connect_job.h @@ -19,7 +19,6 @@ #include "net/base/completion_repeating_callback.h" #include "net/base/net_export.h" #include "net/base/network_anonymization_key.h" -#include "net/base/privacy_mode.h" #include "net/dns/public/host_resolver_results.h" #include "net/dns/public/resolve_error_info.h" #include "net/socket/connect_job.h" @@ -48,7 +47,6 @@ class NET_EXPORT_PRIVATE SSLSocketParams scoped_refptr<HttpProxySocketParams> http_proxy_params, const HostPortPair& host_and_port, const SSLConfig& ssl_config, - PrivacyMode privacy_mode, NetworkAnonymizationKey network_anonymization_key); SSLSocketParams(const SSLSocketParams&) = delete; @@ -69,7 +67,6 @@ class NET_EXPORT_PRIVATE SSLSocketParams const HostPortPair& host_and_port() const { return host_and_port_; } const SSLConfig& ssl_config() const { return ssl_config_; } - PrivacyMode privacy_mode() const { return privacy_mode_; } const NetworkAnonymizationKey& network_anonymization_key() const { return network_anonymization_key_; } @@ -83,7 +80,6 @@ class NET_EXPORT_PRIVATE SSLSocketParams const scoped_refptr<HttpProxySocketParams> http_proxy_params_; const HostPortPair host_and_port_; const SSLConfig ssl_config_; - const PrivacyMode privacy_mode_; const NetworkAnonymizationKey network_anonymization_key_; }; diff --git a/net/socket/ssl_connect_job_unittest.cc b/net/socket/ssl_connect_job_unittest.cc index 056f0d95f..71b7b2b5e 100644 --- a/net/socket/ssl_connect_job_unittest.cc +++ b/net/socket/ssl_connect_job_unittest.cc @@ -48,7 +48,6 @@ #include "net/socket/transport_connect_job.h" #include "net/ssl/ssl_config_service_defaults.h" #include "net/ssl/ssl_connection_status_flags.h" -#include "net/ssl/ssl_legacy_crypto_fallback.h" #include "net/ssl/test_ssl_config_service.h" #include "net/test/cert_test_util.h" #include "net/test/gtest_util.h" @@ -148,7 +147,8 @@ class SSLConnectJobTest : public WithTaskEnvironment, public testing::Test { SecureDnsPolicy secure_dns_policy) { return base::MakeRefCounted<HttpProxySocketParams>( CreateProxyTransportSocketParams(secure_dns_policy), - /*ssl_params=*/nullptr, kHostHttp, kHttpProxyChain, + /*ssl_params=*/nullptr, /*quic_ssl_config=*/std::nullopt, kHostHttp, + kHttpProxyChain, /*proxy_server_index=*/0, /*tunnel=*/true, TRAFFIC_ANNOTATION_FOR_TESTS, NetworkAnonymizationKey(), secure_dns_policy); @@ -181,7 +181,7 @@ class SSLConnectJobTest : public WithTaskEnvironment, public testing::Test { ? CreateHttpProxySocketParams(secure_dns_policy) : nullptr, HostPortPair::FromSchemeHostPort(kHostHttps), SSLConfig(), - PRIVACY_MODE_DISABLED, NetworkAnonymizationKey()); + NetworkAnonymizationKey()); } void AddAuthToCache() { @@ -551,109 +551,6 @@ TEST_F(SSLConnectJobTest, DirectSSLError) { test::IsError(ERR_BAD_SSL_CLIENT_AUTH_CERT)); } -TEST_F(SSLConnectJobTest, LegacyCryptoFallbackHistograms) { - base::FilePath certs_dir = GetTestCertsDirectory(); - - scoped_refptr<X509Certificate> sha1_leaf = - ImportCertFromFile(certs_dir, "sha1_leaf.pem"); - ASSERT_TRUE(sha1_leaf); - - scoped_refptr<X509Certificate> ok_cert = - ImportCertFromFile(certs_dir, "ok_cert.pem"); - ASSERT_TRUE(ok_cert); - - // Make a copy of |ok_cert| with an unused |sha1_leaf| in the intermediate - // list. - std::vector<bssl::UniquePtr<CRYPTO_BUFFER>> intermediates; - for (const auto& cert : ok_cert->intermediate_buffers()) { - intermediates.push_back(bssl::UpRef(cert)); - } - intermediates.push_back(bssl::UpRef(sha1_leaf->cert_buffer())); - scoped_refptr<X509Certificate> ok_with_unused_sha1 = - X509Certificate::CreateFromBuffer(bssl::UpRef(ok_cert->cert_buffer()), - std::move(intermediates)); - ASSERT_TRUE(ok_with_unused_sha1); - - // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 - const uint16_t kModernCipher = 0xc02f; - - struct HistogramTest { - SSLLegacyCryptoFallback expected; - Error first_attempt; - uint16_t cipher_suite; - uint16_t peer_signature_algorithm; - scoped_refptr<X509Certificate> unverified_cert; - }; - - const HistogramTest kHistogramTests[] = { - // Connections not using the fallback map to kNoFallback. - {SSLLegacyCryptoFallback::kNoFallback, OK, kModernCipher, - SSL_SIGN_RSA_PSS_RSAE_SHA256, ok_cert}, - {SSLLegacyCryptoFallback::kNoFallback, OK, kModernCipher, - SSL_SIGN_RSA_PSS_RSAE_SHA256, sha1_leaf}, - {SSLLegacyCryptoFallback::kNoFallback, OK, kModernCipher, - SSL_SIGN_RSA_PSS_RSAE_SHA256, ok_with_unused_sha1}, - - // Connections using SHA-1 map to kUsedSHA1 or kSentSHA1CertAndUsedSHA1. - {SSLLegacyCryptoFallback::kUsedSHA1, ERR_SSL_PROTOCOL_ERROR, - kModernCipher, SSL_SIGN_RSA_PKCS1_SHA1, ok_cert}, - {SSLLegacyCryptoFallback::kSentSHA1CertAndUsedSHA1, - ERR_SSL_PROTOCOL_ERROR, kModernCipher, SSL_SIGN_RSA_PKCS1_SHA1, - sha1_leaf}, - {SSLLegacyCryptoFallback::kSentSHA1CertAndUsedSHA1, - ERR_SSL_PROTOCOL_ERROR, kModernCipher, SSL_SIGN_RSA_PKCS1_SHA1, - ok_with_unused_sha1}, - - // Connections using neither map to kUnknownReason or kSentSHA1Cert. - {SSLLegacyCryptoFallback::kUnknownReason, ERR_SSL_PROTOCOL_ERROR, - kModernCipher, SSL_SIGN_RSA_PSS_RSAE_SHA256, ok_cert}, - {SSLLegacyCryptoFallback::kSentSHA1Cert, ERR_SSL_PROTOCOL_ERROR, - kModernCipher, SSL_SIGN_RSA_PSS_RSAE_SHA256, sha1_leaf}, - {SSLLegacyCryptoFallback::kSentSHA1Cert, ERR_SSL_PROTOCOL_ERROR, - kModernCipher, SSL_SIGN_RSA_PSS_RSAE_SHA256, ok_with_unused_sha1}, - }; - for (size_t i = 0; i < std::size(kHistogramTests); i++) { - SCOPED_TRACE(i); - const auto& test = kHistogramTests[i]; - - base::HistogramTester tester; - - SSLInfo ssl_info; - SSLConnectionStatusSetVersion(SSL_CONNECTION_VERSION_TLS1_2, - &ssl_info.connection_status); - SSLConnectionStatusSetCipherSuite(test.cipher_suite, - &ssl_info.connection_status); - ssl_info.peer_signature_algorithm = test.peer_signature_algorithm; - ssl_info.unverified_cert = test.unverified_cert; - - StaticSocketDataProvider data; - socket_factory_.AddSocketDataProvider(&data); - SSLSocketDataProvider ssl(ASYNC, test.first_attempt); - socket_factory_.AddSSLSocketDataProvider(&ssl); - ssl.expected_disable_sha1_server_signatures = true; - - StaticSocketDataProvider data2; - SSLSocketDataProvider ssl2(ASYNC, OK); - if (test.first_attempt != OK) { - socket_factory_.AddSocketDataProvider(&data2); - socket_factory_.AddSSLSocketDataProvider(&ssl2); - ssl2.ssl_info = ssl_info; - ssl2.expected_disable_sha1_server_signatures = true; - } else { - ssl.ssl_info = ssl_info; - } - - TestConnectJobDelegate test_delegate; - std::unique_ptr<ConnectJob> ssl_connect_job = - CreateConnectJob(&test_delegate); - - test_delegate.StartJobExpectingResult(ssl_connect_job.get(), OK, - /*expect_sync_result=*/false); - - tester.ExpectUniqueSample("Net.SSLLegacyCryptoFallback2", test.expected, 1); - } -} - TEST_F(SSLConnectJobTest, DirectWithNPN) { StaticSocketDataProvider data; socket_factory_.AddSocketDataProvider(&data); @@ -1549,7 +1446,6 @@ TEST_F(SSLConnectJobTest, ECHRecoveryThenLegacyCrypto) { // The handshake will then fail, and provide retry configs. SSLSocketDataProvider ssl2(ASYNC, ERR_ECH_NOT_NEGOTIATED); ssl2.expected_ech_config_list = ech_config_list2; - ssl2.expected_disable_sha1_server_signatures = true; ssl2.ech_retry_configs = ech_config_list3; socket_factory_.AddSSLSocketDataProvider(&ssl2); // The third connection attempt should skip `endpoint1` and retry with only @@ -1562,7 +1458,6 @@ TEST_F(SSLConnectJobTest, ECHRecoveryThenLegacyCrypto) { // further but trigger the legacy crypto fallback. SSLSocketDataProvider ssl3(ASYNC, ERR_SSL_PROTOCOL_ERROR); ssl3.expected_ech_config_list = ech_config_list3; - ssl3.expected_disable_sha1_server_signatures = true; socket_factory_.AddSSLSocketDataProvider(&ssl3); // The third connection attempt should still skip `endpoint1` and retry with // only `endpoint2`. @@ -1574,7 +1469,6 @@ TEST_F(SSLConnectJobTest, ECHRecoveryThenLegacyCrypto) { // connection enables legacy crypto and succeeds. SSLSocketDataProvider ssl4(ASYNC, OK); ssl4.expected_ech_config_list = ech_config_list3; - ssl4.expected_disable_sha1_server_signatures = true; socket_factory_.AddSSLSocketDataProvider(&ssl4); // The connection should ultimately succeed. @@ -1624,7 +1518,6 @@ TEST_F(SSLConnectJobTest, LegacyCryptoThenECHRecovery) { // The handshake will then fail, and trigger the legacy cryptography fallback. SSLSocketDataProvider ssl2(ASYNC, ERR_SSL_PROTOCOL_ERROR); ssl2.expected_ech_config_list = ech_config_list2; - ssl2.expected_disable_sha1_server_signatures = true; socket_factory_.AddSSLSocketDataProvider(&ssl2); // The third and fourth connection attempts proceed as before, but with legacy // cryptography enabled. @@ -1639,7 +1532,6 @@ TEST_F(SSLConnectJobTest, LegacyCryptoThenECHRecovery) { // The handshake enables legacy crypto. Now ECH fails with retry configs. SSLSocketDataProvider ssl4(ASYNC, ERR_ECH_NOT_NEGOTIATED); ssl4.expected_ech_config_list = ech_config_list2; - ssl4.expected_disable_sha1_server_signatures = true; ssl4.ech_retry_configs = ech_config_list3; socket_factory_.AddSSLSocketDataProvider(&ssl4); // The fourth connection attempt should still skip `endpoint1` and retry with @@ -1652,7 +1544,6 @@ TEST_F(SSLConnectJobTest, LegacyCryptoThenECHRecovery) { // cryptography. SSLSocketDataProvider ssl5(ASYNC, OK); ssl5.expected_ech_config_list = ech_config_list3; - ssl5.expected_disable_sha1_server_signatures = true; socket_factory_.AddSSLSocketDataProvider(&ssl5); // The connection should ultimately succeed. diff --git a/net/socket/transport_client_socket_pool.h b/net/socket/transport_client_socket_pool.h index c9896e507..48dcfe4e8 100644 --- a/net/socket/transport_client_socket_pool.h +++ b/net/socket/transport_client_socket_pool.h @@ -540,7 +540,7 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool // pointer of each element of |jobs_| stored either in // |unassigned_jobs_|, or as the associated |job_| of an // element of |unbound_requests_|. - std::list<ConnectJob*> unassigned_jobs_; + std::list<raw_ptr<ConnectJob, CtnExperimental>> unassigned_jobs_; RequestQueue unbound_requests_; int active_socket_count_ = 0; // number of active sockets used by clients // A timer for when to start the backup job. diff --git a/net/socket/udp_socket_unittest.cc b/net/socket/udp_socket_unittest.cc index b594662ec..db2d7cd87 100644 --- a/net/socket/udp_socket_unittest.cc +++ b/net/socket/udp_socket_unittest.cc @@ -132,7 +132,13 @@ class UDPSocketTest : public PlatformTest, public WithTaskEnvironment { rv = callback.GetResult(rv); if (rv < 0) return std::string(); +#if BUILDFLAG(IS_WIN) + // The DSCP value is not populated on Windows, in order to avoid incurring + // an extra system call. + EXPECT_EQ(socket->GetLastTos().dscp, DSCP_DEFAULT); +#else EXPECT_EQ(socket->GetLastTos().dscp, dscp); +#endif EXPECT_EQ(socket->GetLastTos().ecn, ecn); return std::string(buffer_->data(), rv); } @@ -875,45 +881,109 @@ TEST_F(UDPSocketTest, SetDSCP) { client.Close(); } -// Send DSCP + ECN marked packets from server to client and verity the TOS +// Send DSCP + ECN marked packets from server to client and verify the TOS // bytes that arrive. -#if !BUILDFLAG(IS_WIN) // TODO(crbug.com/1521435): No windows support yet. TEST_F(UDPSocketTest, VerifyDscpAndEcnExchange) { IPEndPoint server_address(IPAddress::IPv4Localhost(), 0); UDPServerSocket server(nullptr, NetLogSource()); + server.AllowAddressReuse(); + ASSERT_THAT(server.Listen(server_address), IsOk()); + // Get bound port. + ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk()); UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource()); + client.Connect(server_address); + EXPECT_EQ(client.SetRecvTos(), 0); + IPEndPoint client_address; + client.GetLocalAddress(&client_address); + + EXPECT_EQ(server.SetTos(DSCP_AF41, ECN_ECT1), 0); + std::string first_message = "foobar"; + EXPECT_EQ(SendToSocket(&server, first_message, client_address), + static_cast<int>(first_message.length())); + EXPECT_EQ(ReadSocket(&client, DSCP_AF41, ECN_ECT1), first_message.data()); + + std::string second_message = "foo"; + EXPECT_EQ(server.SetTos(DSCP_CS2, ECN_ECT0), 0); + EXPECT_EQ(SendToSocket(&server, second_message, client_address), + static_cast<int>(second_message.length())); + EXPECT_EQ(ReadSocket(&client, DSCP_CS2, ECN_ECT0), second_message.data()); + +#if BUILDFLAG(IS_WIN) + // The Windows sendmsg API does not allow setting ECN_CE as the outgoing mark. + EcnCodePoint final_ecn = ECN_ECT1; +#else + EcnCodePoint final_ecn = ECN_CE; +#endif + + EXPECT_EQ(server.SetTos(DSCP_NO_CHANGE, final_ecn), 0); + EXPECT_EQ(SendToSocket(&server, second_message, client_address), + static_cast<int>(second_message.length())); + EXPECT_EQ(ReadSocket(&client, DSCP_CS2, final_ecn), second_message.data()); + + EXPECT_EQ(server.SetTos(DSCP_AF41, ECN_NO_CHANGE), 0); + EXPECT_EQ(SendToSocket(&server, second_message, client_address), + static_cast<int>(second_message.length())); + EXPECT_EQ(ReadSocket(&client, DSCP_AF41, final_ecn), second_message.data()); + + EXPECT_EQ(server.SetTos(DSCP_NO_CHANGE, ECN_NO_CHANGE), 0); + EXPECT_EQ(SendToSocket(&server, second_message, client_address), + static_cast<int>(second_message.length())); + EXPECT_EQ(ReadSocket(&client, DSCP_AF41, final_ecn), second_message.data()); + + server.Close(); + client.Close(); +} + +// For windows, test with Nonblocking sockets. For other platforms, this test +// is identical to VerifyDscpAndEcnExchange, above. +TEST_F(UDPSocketTest, VerifyDscpAndEcnExchangeNonBlocking) { + IPEndPoint server_address(IPAddress::IPv4Localhost(), 0); + UDPServerSocket server(nullptr, NetLogSource()); + server.UseNonBlockingIO(); server.AllowAddressReuse(); ASSERT_THAT(server.Listen(server_address), IsOk()); + // Get bound port. ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk()); + UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource()); + client.UseNonBlockingIO(); client.Connect(server_address); - client.SetRecvTos(); + EXPECT_EQ(client.SetRecvTos(), 0); IPEndPoint client_address; client.GetLocalAddress(&client_address); - server.SetTos(DSCP_AF41, ECN_ECT1); - SendToSocket(&server, "foo", client_address); - ReadSocket(&client, DSCP_AF41, ECN_ECT1); + EXPECT_EQ(server.SetTos(DSCP_AF41, ECN_ECT1), 0); + std::string first_message = "foobar"; + EXPECT_EQ(SendToSocket(&server, first_message, client_address), + static_cast<int>(first_message.length())); + EXPECT_EQ(ReadSocket(&client, DSCP_AF41, ECN_ECT1), first_message.data()); + + std::string second_message = "foo"; + EXPECT_EQ(server.SetTos(DSCP_CS2, ECN_ECT0), 0); + EXPECT_EQ(SendToSocket(&server, second_message, client_address), + static_cast<int>(second_message.length())); + EXPECT_EQ(ReadSocket(&client, DSCP_CS2, ECN_ECT0), second_message.data()); - server.SetTos(DSCP_CS2, ECN_ECT0); - SendToSocket(&server, "foo", client_address); - ReadSocket(&client, DSCP_CS2, ECN_ECT0); + // The Windows sendmsg API does not allow setting ECN_CE as the outgoing mark. + EcnCodePoint final_ecn = ECN_ECT1; - server.SetTos(DSCP_NO_CHANGE, ECN_CE); - SendToSocket(&server, "foo", client_address); - ReadSocket(&client, DSCP_CS2, ECN_CE); + EXPECT_EQ(server.SetTos(DSCP_NO_CHANGE, final_ecn), 0); + EXPECT_EQ(SendToSocket(&server, second_message, client_address), + static_cast<int>(second_message.length())); + EXPECT_EQ(ReadSocket(&client, DSCP_CS2, final_ecn), second_message.data()); - server.SetTos(DSCP_AF41, ECN_NO_CHANGE); - SendToSocket(&server, "foo", client_address); - ReadSocket(&client, DSCP_AF41, ECN_CE); + EXPECT_EQ(server.SetTos(DSCP_AF41, ECN_NO_CHANGE), 0); + EXPECT_EQ(SendToSocket(&server, second_message, client_address), + static_cast<int>(second_message.length())); + EXPECT_EQ(ReadSocket(&client, DSCP_AF41, final_ecn), second_message.data()); - server.SetTos(DSCP_NO_CHANGE, ECN_NO_CHANGE); - SendToSocket(&server, "foo", client_address); - ReadSocket(&client, DSCP_AF41, ECN_CE); + EXPECT_EQ(server.SetTos(DSCP_NO_CHANGE, ECN_NO_CHANGE), 0); + EXPECT_EQ(SendToSocket(&server, second_message, client_address), + static_cast<int>(second_message.length())); + EXPECT_EQ(ReadSocket(&client, DSCP_AF41, final_ecn), second_message.data()); server.Close(); client.Close(); } -#endif TEST_F(UDPSocketTest, ConnectUsingNetwork) { // The specific value of this address doesn't really matter, and no diff --git a/net/socket/udp_socket_win.cc b/net/socket/udp_socket_win.cc index 20d094fec..2e28570f9 100644 --- a/net/socket/udp_socket_win.cc +++ b/net/socket/udp_socket_win.cc @@ -66,6 +66,10 @@ class UDPSocketWin::Core : public base::RefCounted<Core> { // The buffers used in Read() and Write(). scoped_refptr<IOBuffer> read_iobuffer_; scoped_refptr<IOBuffer> write_iobuffer_; + // The struct for packet metadata passed to WSARecvMsg(). + std::unique_ptr<WSAMSG> read_message_ = nullptr; + // Big enough for IP_ECN or IPV6_ECN, nothing more. + char read_control_buffer_[WSA_CMSG_SPACE(sizeof(int))]; // The address storage passed to WSARecvFrom(). SockaddrStorage recv_addr_storage_; @@ -589,21 +593,49 @@ int UDPSocketWin::SetDoNotFragment() { return rv == 0 ? OK : MapSystemError(WSAGetLastError()); } +LPFN_WSARECVMSG UDPSocketWin::GetRecvMsgPointer() { + LPFN_WSARECVMSG rv; + GUID message_code = WSAID_WSARECVMSG; + DWORD size; + if (WSAIoctl(socket_, SIO_GET_EXTENSION_FUNCTION_POINTER, &message_code, + sizeof(message_code), &rv, sizeof(rv), &size, NULL, + NULL) == SOCKET_ERROR) { + return nullptr; + } + return rv; +} + +LPFN_WSASENDMSG UDPSocketWin::GetSendMsgPointer() { + LPFN_WSASENDMSG rv; + GUID message_code = WSAID_WSASENDMSG; + DWORD size; + if (WSAIoctl(socket_, SIO_GET_EXTENSION_FUNCTION_POINTER, &message_code, + sizeof(message_code), &rv, sizeof(rv), &size, NULL, + NULL) == SOCKET_ERROR) { + return nullptr; + } + return rv; +} + int UDPSocketWin::SetRecvTos() { DCHECK_NE(socket_, INVALID_SOCKET); DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); - - int rv; - unsigned int ecn = 1; - if (addr_family_ == AF_INET6) { - rv = setsockopt(socket_, IPPROTO_IPV6, IPV6_RECVTCLASS, - reinterpret_cast<const char*>(&ecn), sizeof(ecn)); - } else { - DCHECK_EQ(addr_family_, AF_INET); - rv = setsockopt(socket_, IPPROTO_IP, IP_RECVTOS, - reinterpret_cast<const char*>(&ecn), sizeof(ecn)); + int rv = WSASetRecvIPEcn(socket_, TRUE); + if (rv != 0) { + int os_error = WSAGetLastError(); + int result = MapSystemError(os_error); + LogRead(result, nullptr, nullptr); + return result; } - return rv == 0 ? OK : MapSystemError(WSAGetLastError()); + wsa_recv_msg_ = GetRecvMsgPointer(); + if (wsa_recv_msg_ == nullptr) { + int os_error = WSAGetLastError(); + int result = MapSystemError(os_error); + LogRead(result, nullptr, nullptr); + return result; + } + report_ecn_ = true; + return 0; } void UDPSocketWin::SetMsgConfirm(bool confirm) {} @@ -671,9 +703,13 @@ void UDPSocketWin::DidCompleteRead() { } else { result = ERR_ADDRESS_INVALID; } + if (core_->read_message_ != nullptr) { + SetLastTosFromWSAMSG(*core_->read_message_); + } } LogRead(result, core_->read_iobuffer_->data(), address_to_log); core_->read_iobuffer_ = nullptr; + core_->read_message_ = nullptr; recv_from_address_ = nullptr; DoReadCallback(result); } @@ -797,10 +833,50 @@ void UDPSocketWin::LogWrite(int result, } } +void UDPSocketWin::PopulateWSAMSG(WSAMSG& message, + SockaddrStorage& storage, + WSABUF* data_buffer, + WSABUF& control_buffer, + bool send) { + bool is_ipv6 = addr_family_ == AF_INET6; + message.name = storage.addr; + message.namelen = storage.addr_len; + message.lpBuffers = data_buffer; + message.dwBufferCount = 1; + message.Control.buf = control_buffer.buf; + message.dwFlags = 0; + if (send) { + message.Control.len = 0; + WSACMSGHDR* cmsg; + message.Control.len += WSA_CMSG_SPACE(sizeof(int)); + cmsg = WSA_CMSG_FIRSTHDR(&message); + cmsg->cmsg_len = WSA_CMSG_LEN(sizeof(int)); + cmsg->cmsg_level = is_ipv6 ? IPPROTO_IPV6 : IPPROTO_IP; + cmsg->cmsg_type = is_ipv6 ? IPV6_ECN : IP_ECN; + *(int*)WSA_CMSG_DATA(cmsg) = static_cast<int>(send_ecn_); + } else { + message.Control.len = control_buffer.len; + } +} + +void UDPSocketWin::SetLastTosFromWSAMSG(WSAMSG& message) { + int ecn = 0; + for (WSACMSGHDR* cmsg = WSA_CMSG_FIRSTHDR(&message); cmsg != NULL; + cmsg = WSA_CMSG_NXTHDR(&message, cmsg)) { + if ((cmsg->cmsg_level == IPPROTO_IPV6 && cmsg->cmsg_type == IPV6_ECN) || + (cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_ECN)) { + ecn = *(int*)WSA_CMSG_DATA(cmsg); + break; + } + } + last_tos_.ecn = static_cast<EcnCodePoint>(ecn); +} + int UDPSocketWin::InternalRecvFromOverlapped(IOBuffer* buf, int buf_len, IPEndPoint* address) { DCHECK(!core_->read_iobuffer_.get()); + DCHECK(!core_->read_message_.get()); SockaddrStorage& storage = core_->recv_addr_storage_; storage.addr_len = sizeof(storage.addr_storage); @@ -811,8 +887,26 @@ int UDPSocketWin::InternalRecvFromOverlapped(IOBuffer* buf, DWORD flags = 0; DWORD num; CHECK_NE(INVALID_SOCKET, socket_); - int rv = WSARecvFrom(socket_, &read_buffer, 1, &num, &flags, storage.addr, - &storage.addr_len, &core_->read_overlapped_, nullptr); + int rv; + std::unique_ptr<WSAMSG> message; + if (report_ecn_) { + WSABUF control_buffer; + control_buffer.buf = core_->read_control_buffer_; + control_buffer.len = sizeof(core_->read_control_buffer_); + message = std::make_unique<WSAMSG>(); + if (message == nullptr) { + return WSA_NOT_ENOUGH_MEMORY; + } + PopulateWSAMSG(*message, storage, &read_buffer, control_buffer, false); + rv = wsa_recv_msg_(socket_, message.get(), &num, &core_->read_overlapped_, + nullptr); + if (rv == 0) { + SetLastTosFromWSAMSG(*message); + } + } else { + rv = WSARecvFrom(socket_, &read_buffer, 1, &num, &flags, storage.addr, + &storage.addr_len, &core_->read_overlapped_, nullptr); + } if (rv == 0) { if (ResetEventIfSignaled(core_->read_overlapped_.hEvent)) { int result = num; @@ -842,6 +936,7 @@ int UDPSocketWin::InternalRecvFromOverlapped(IOBuffer* buf, } core_->WatchForRead(); core_->read_iobuffer_ = buf; + core_->read_message_ = std::move(message); return ERR_IO_PENDING; } @@ -869,8 +964,20 @@ int UDPSocketWin::InternalSendToOverlapped(IOBuffer* buf, DWORD flags = 0; DWORD num; - int rv = WSASendTo(socket_, &write_buffer, 1, &num, flags, addr, - storage.addr_len, &core_->write_overlapped_, nullptr); + int rv; + if (send_ecn_ != ECN_NOT_ECT) { + WSABUF control_buffer; + char raw_control_buffer[WSA_CMSG_SPACE(sizeof(int))]; + control_buffer.buf = raw_control_buffer; + control_buffer.len = sizeof(raw_control_buffer); + WSAMSG message; + PopulateWSAMSG(message, storage, &write_buffer, control_buffer, true); + rv = wsa_send_msg_(socket_, &message, flags, &num, + &core_->write_overlapped_, nullptr); + } else { + rv = WSASendTo(socket_, &write_buffer, 1, &num, flags, addr, + storage.addr_len, &core_->write_overlapped_, nullptr); + } if (rv == 0) { if (ResetEventIfSignaled(core_->write_overlapped_.hEvent)) { int result = num; @@ -899,8 +1006,29 @@ int UDPSocketWin::InternalRecvFromNonBlocking(IOBuffer* buf, storage.addr_len = sizeof(storage.addr_storage); CHECK_NE(INVALID_SOCKET, socket_); - int rv = recvfrom(socket_, buf->data(), buf_len, 0, storage.addr, - &storage.addr_len); + + int rv; + if (report_ecn_) { + WSABUF read_buffer; + read_buffer.buf = buf->data(); + read_buffer.len = buf_len; + WSABUF control_buffer; + char raw_control_buffer[WSA_CMSG_SPACE(sizeof(INT))]; + control_buffer.buf = raw_control_buffer; + control_buffer.len = sizeof(raw_control_buffer); + WSAMSG message; + DWORD bytes_read; + PopulateWSAMSG(message, storage, &read_buffer, control_buffer, false); + rv = wsa_recv_msg_(socket_, &message, &bytes_read, nullptr, nullptr); + SetLastTosFromWSAMSG(message); + if (rv == 0) { + rv = bytes_read; // WSARecvMsg() returns zero on delivery, but recvfrom + // returns the number of bytes received. + } + } else { + rv = recvfrom(socket_, buf->data(), buf_len, 0, storage.addr, + &storage.addr_len); + } if (rv == SOCKET_ERROR) { int os_error = WSAGetLastError(); if (os_error == WSAEWOULDBLOCK) { @@ -946,7 +1074,25 @@ int UDPSocketWin::InternalSendToNonBlocking(IOBuffer* buf, storage.addr_len = 0; } - int rv = sendto(socket_, buf->data(), buf_len, 0, addr, storage.addr_len); + int rv; + if (send_ecn_ != ECN_NOT_ECT) { + char raw_control_buffer[WSA_CMSG_SPACE(sizeof(INT))]; + WSABUF write_buffer; + write_buffer.buf = buf->data(); + write_buffer.len = buf_len; + WSABUF control_buffer; + control_buffer.buf = raw_control_buffer; + control_buffer.len = sizeof(raw_control_buffer); + WSAMSG message; + DWORD bytes_read; + PopulateWSAMSG(message, storage, &write_buffer, control_buffer, true); + rv = wsa_send_msg_(socket_, &message, 0, &bytes_read, nullptr, nullptr); + if (rv == 0) { + rv = bytes_read; + } + } else { + rv = sendto(socket_, buf->data(), buf_len, 0, addr, storage.addr_len); + } if (rv == SOCKET_ERROR) { int os_error = WSAGetLastError(); if (os_error == WSAEWOULDBLOCK) { @@ -1191,32 +1337,42 @@ QOS_TRAFFIC_TYPE DscpToTrafficType(DiffServCodePoint dscp) { } int UDPSocketWin::SetDiffServCodePoint(DiffServCodePoint dscp) { - if (dscp == DSCP_NO_CHANGE) - return OK; + return SetTos(dscp, ECN_NO_CHANGE); +} +int UDPSocketWin::SetTos(DiffServCodePoint dscp, EcnCodePoint ecn) { if (!is_connected()) return ERR_SOCKET_NOT_CONNECTED; - QwaveApi* api = GetQwaveApi(); - - if (!api->qwave_supported()) - return ERR_NOT_IMPLEMENTED; + if (dscp != DSCP_NO_CHANGE) { + QwaveApi* api = GetQwaveApi(); - if (!dscp_manager_) - dscp_manager_ = std::make_unique<DscpManager>(api, socket_); + if (!api->qwave_supported()) { + return ERR_NOT_IMPLEMENTED; + } - dscp_manager_->Set(dscp); - if (remote_address_) - return dscp_manager_->PrepareForSend(*remote_address_.get()); + if (!dscp_manager_) { + dscp_manager_ = std::make_unique<DscpManager>(api, socket_); + } + dscp_manager_->Set(dscp); + if (remote_address_) { + int rv = dscp_manager_->PrepareForSend(*remote_address_.get()); + if (rv != OK) { + return rv; + } + } + } + if (ecn == ECN_NO_CHANGE) { + return OK; + } + if (wsa_send_msg_ == nullptr) { + wsa_send_msg_ = GetSendMsgPointer(); + } + send_ecn_ = ecn; return OK; } -// TODO(crbug.com/1521435): a stub for future ECN support in Windows. -int UDPSocketWin::SetTos(DiffServCodePoint dscp, EcnCodePoint ecn) { - return SetDiffServCodePoint(dscp); -} - int UDPSocketWin::SetIPv6Only(bool ipv6_only) { DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); if (is_connected()) { diff --git a/net/socket/udp_socket_win.h b/net/socket/udp_socket_win.h index 111a8013c..b4008db4c 100644 --- a/net/socket/udp_socket_win.h +++ b/net/socket/udp_socket_win.h @@ -9,6 +9,9 @@ #include <stdint.h> #include <winsock2.h> +// Must be after winsock2.h: +#include <MSWSock.h> + #include <atomic> #include <memory> #include <set> @@ -26,6 +29,7 @@ #include "net/base/ip_endpoint.h" #include "net/base/net_export.h" #include "net/base/network_handle.h" +#include "net/base/sockaddr_storage.h" #include "net/log/net_log_with_source.h" #include "net/socket/datagram_socket.h" #include "net/socket/diff_serv_code_point.h" @@ -380,9 +384,14 @@ class NET_EXPORT UDPSocketWin : public base::win::ObjectWatcher::Delegate { } bool get_use_non_blocking_io_for_testing() { return use_non_blocking_io_; } - // For now, no support for Windows TOS reporting to Quiche - // TODO(crbug.com/1521435): Add windows support for ECN. - DscpAndEcn GetLastTos() const { return {DSCP_DEFAULT, ECN_DEFAULT}; } + // Because the windows API separates out DSCP and ECN better than Posix, this + // function does not actually return the correct DSCP value, instead always + // returning DSCP_DEFAULT rather than the last incoming value. + // If a use case arises for reading the incoming DSCP value, it would only + // then worth be executing the system call. + // However, the ECN member of the return value is correct if SetRecvTos() + // was called previously on the socket. + DscpAndEcn GetLastTos() const { return last_tos_; } private: enum SocketOptions { @@ -419,6 +428,27 @@ class NET_EXPORT UDPSocketWin : public base::win::ObjectWatcher::Delegate { int InternalConnect(const IPEndPoint& address); + // Returns a function pointer to the platform's instantiation of WSARecvMsg() + // or WSASendMsg(). + LPFN_WSARECVMSG GetRecvMsgPointer(); + LPFN_WSASENDMSG GetSendMsgPointer(); + + // Populates |message| with |storage|, |data_buffer|, and |control_buffer| to + // use ECN before calls to either WSASendMsg() (if |send| is true) or + // WSARecvMsg(). + // |data_buffer| is the datagram. |control_buffer| is the storage + // space for cmsgs. If |send| is false for an overlapped socket, the caller + // must retain a reference to |msghdr|, |storage|, and the buf members of + // |data_buffer| and |control_buffer|, in case WSARecvMsg() returns IO_PENDING + // and the result is delivered asynchronously. + void PopulateWSAMSG(WSAMSG& message, + SockaddrStorage& storage, + WSABUF* data_buffer, + WSABUF& control_buffer, + bool send); + // Sets last_tos_ to the last ECN codepoint contained in |message|. + void SetLastTosFromWSAMSG(WSAMSG& message); + // Version for using overlapped IO. int InternalRecvFromOverlapped(IOBuffer* buf, int buf_len, @@ -510,6 +540,20 @@ class NET_EXPORT UDPSocketWin : public base::win::ObjectWatcher::Delegate { // UDPSocket is destroyed. OwnedUDPSocketCount owned_socket_count_; + DscpAndEcn last_tos_ = {DSCP_DEFAULT, ECN_DEFAULT}; + + // If true, the socket has been configured to report ECN on incoming + // datagrams. + bool report_ecn_ = false; + + // Function pointers to the platform implementations of WSARecvMsg() and + // WSASendMsg(). + LPFN_WSARECVMSG wsa_recv_msg_ = nullptr; + LPFN_WSASENDMSG wsa_send_msg_ = nullptr; + + // The ECN codepoint to send on outgoing packets. + EcnCodePoint send_ecn_ = ECN_NOT_ECT; + THREAD_CHECKER(thread_checker_); // Used to prevent null dereferences in OnObjectSignaled, when passing an |