diff options
author | Cronet Mainline Eng <cronet-mainline-eng+copybara@google.com> | 2023-08-14 17:15:38 +0000 |
---|---|---|
committer | Mohannad Farrag <aymanm@google.com> | 2023-08-14 17:22:36 +0000 |
commit | ec3a8e8db24bb3ce4b078106b358ca1c4389c14f (patch) | |
tree | 823f64849ad509483bfebb2252199a5fe79b8e43 /net/socket | |
parent | d12afe756882b2521faa0b33cbd4813fcea04c22 (diff) | |
download | cronet-ec3a8e8db24bb3ce4b078106b358ca1c4389c14f.tar.gz |
Import Cronet version 117.0.5938.0
Project import generated by Copybara.
FolderOrigin-RevId: /tmp/copybara-origin/src
Change-Id: Ib7683d0ed240e11ed9068152600c8092afba4571
Diffstat (limited to 'net/socket')
32 files changed, 587 insertions, 229 deletions
diff --git a/net/socket/client_socket_handle.cc b/net/socket/client_socket_handle.cc index 653b469fc..35d15c24b 100644 --- a/net/socket/client_socket_handle.cc +++ b/net/socket/client_socket_handle.cc @@ -24,6 +24,7 @@ ClientSocketHandle::ClientSocketHandle() : resolve_error_info_(ResolveErrorInfo(OK)) {} ClientSocketHandle::~ClientSocketHandle() { + weak_factory_.InvalidateWeakPtrs(); Reset(); } diff --git a/net/socket/client_socket_handle.h b/net/socket/client_socket_handle.h index b6b71c284..2c7b2b662 100644 --- a/net/socket/client_socket_handle.h +++ b/net/socket/client_socket_handle.h @@ -12,6 +12,7 @@ #include "base/functional/bind.h" #include "base/memory/raw_ptr.h" #include "base/memory/scoped_refptr.h" +#include "base/memory/weak_ptr.h" #include "base/time/time.h" #include "net/base/ip_endpoint.h" #include "net/base/load_states.h" @@ -209,6 +210,10 @@ class NET_EXPORT ClientSocketHandle { connect_timing_ = connect_timing; } + base::WeakPtr<ClientSocketHandle> GetWeakPtr() { + return weak_factory_.GetWeakPtr(); + } + private: // Called on asynchronous completion of an Init() request. void OnIOComplete(int result); @@ -247,6 +252,8 @@ class NET_EXPORT ClientSocketHandle { // Timing information is set when a connection is successfully established. LoadTimingInfo::ConnectTiming connect_timing_; + + base::WeakPtrFactory<ClientSocketHandle> weak_factory_{this}; }; } // namespace net diff --git a/net/socket/client_socket_pool_base_unittest.cc b/net/socket/client_socket_pool_base_unittest.cc index 212b055c7..eeba63b47 100644 --- a/net/socket/client_socket_pool_base_unittest.cc +++ b/net/socket/client_socket_pool_base_unittest.cc @@ -5708,7 +5708,7 @@ enum class RefreshType { }; // Common base class to test RefreshGroup() when called from either -// OnSSLConfigForServerChanged() matching a specific group or the pool's proxy. +// OnSSLConfigForServersChanged() matching a specific group or the pool's proxy. // // Tests which test behavior specific to one or the other case should use // ClientSocketPoolBaseTest directly. In particular, there is no "other group" @@ -5750,13 +5750,13 @@ class ClientSocketPoolBaseRefreshTest kNetworkAnonymizationKey); } - void OnSSLConfigForServerChanged() { + void OnSSLConfigForServersChanged() { switch (GetParam()) { case RefreshType::kServer: - pool_->OnSSLConfigForServerChanged(HostPortPair("a", 443)); + pool_->OnSSLConfigForServersChanged({HostPortPair("a", 443)}); break; case RefreshType::kProxy: - pool_->OnSSLConfigForServerChanged(HostPortPair("myproxy", 70)); + pool_->OnSSLConfigForServersChanged({HostPortPair("myproxy", 70)}); break; } } @@ -5787,7 +5787,7 @@ TEST_P(ClientSocketPoolBaseRefreshTest, RefreshGroupCreatesNewConnectJobs) { // success. connect_job_factory_->set_job_type(TestConnectJob::kMockJob); - OnSSLConfigForServerChanged(); + OnSSLConfigForServersChanged(); EXPECT_EQ(OK, callback.WaitForResult()); ASSERT_TRUE(handle.socket()); EXPECT_EQ(0, pool_->IdleSocketCount()); @@ -5819,7 +5819,7 @@ TEST_P(ClientSocketPoolBaseRefreshTest, RefreshGroupClosesIdleConnectJobs) { EXPECT_EQ(2u, pool_->IdleSocketCountInGroup(kGroupId)); EXPECT_EQ(2u, pool_->IdleSocketCountInGroup(kGroupIdInPartition)); - OnSSLConfigForServerChanged(); + OnSSLConfigForServersChanged(); EXPECT_EQ(0, pool_->IdleSocketCount()); EXPECT_FALSE(pool_->HasGroupForTesting(kGroupId)); EXPECT_FALSE(pool_->HasGroupForTesting(kGroupIdInPartition)); @@ -5840,7 +5840,7 @@ TEST_F(ClientSocketPoolBaseTest, EXPECT_EQ(2, pool_->IdleSocketCount()); EXPECT_EQ(2u, pool_->IdleSocketCountInGroup(kOtherGroupId)); - pool_->OnSSLConfigForServerChanged(HostPortPair("a", 443)); + pool_->OnSSLConfigForServersChanged({HostPortPair("a", 443)}); ASSERT_TRUE(pool_->HasGroupForTesting(kOtherGroupId)); EXPECT_EQ(2, pool_->IdleSocketCount()); EXPECT_EQ(2u, pool_->IdleSocketCountInGroup(kOtherGroupId)); @@ -5861,7 +5861,7 @@ TEST_P(ClientSocketPoolBaseRefreshTest, RefreshGroupPreventsSocketReuse) { ASSERT_TRUE(pool_->HasGroupForTesting(kGroupId)); EXPECT_EQ(1, pool_->NumActiveSocketsInGroupForTesting(kGroupId)); - OnSSLConfigForServerChanged(); + OnSSLConfigForServersChanged(); handle.Reset(); EXPECT_EQ(0, pool_->IdleSocketCount()); @@ -5887,7 +5887,7 @@ TEST_F(ClientSocketPoolBaseTest, ASSERT_TRUE(pool_->HasGroupForTesting(kOtherGroupId)); EXPECT_EQ(1, pool_->NumActiveSocketsInGroupForTesting(kOtherGroupId)); - pool_->OnSSLConfigForServerChanged(HostPortPair("a", 443)); + pool_->OnSSLConfigForServersChanged({HostPortPair("a", 443)}); handle.Reset(); EXPECT_EQ(1, pool_->IdleSocketCount()); @@ -5910,7 +5910,7 @@ TEST_P(ClientSocketPoolBaseRefreshTest, // This should update the generation, but not cancel the old ConnectJob - it's // not safe to do anything while waiting on the original ConnectJob. - OnSSLConfigForServerChanged(); + OnSSLConfigForServersChanged(); // Providing auth credentials and restarting the request with them will cause // the ConnectJob to complete successfully, but the result will be discarded @@ -5980,7 +5980,7 @@ TEST_F(ClientSocketPoolBaseTest, RefreshProxyRefreshesAllGroups) { // Changes to some other proxy do not affect the pool. The idle socket remains // alive and closing |handle2| makes the socket available for the pool. - pool_->OnSSLConfigForServerChanged(HostPortPair("someotherproxy", 70)); + pool_->OnSSLConfigForServersChanged({HostPortPair("someotherproxy", 70)}); ASSERT_TRUE(pool_->HasGroupForTesting(kGroupId1)); EXPECT_EQ(1, pool_->NumActiveSocketsInGroupForTesting(kGroupId1)); @@ -5994,7 +5994,7 @@ TEST_F(ClientSocketPoolBaseTest, RefreshProxyRefreshesAllGroups) { EXPECT_EQ(1u, pool_->IdleSocketCountInGroup(kGroupId2)); // Changes to the matching proxy refreshes all groups. - pool_->OnSSLConfigForServerChanged(HostPortPair("myproxy", 70)); + pool_->OnSSLConfigForServersChanged({HostPortPair("myproxy", 70)}); // Idle sockets are closed. EXPECT_EQ(0, pool_->IdleSocketCount()); @@ -6049,7 +6049,7 @@ TEST_F(ClientSocketPoolBaseTest, RefreshBothPrivacyAndNormalSockets) { ASSERT_TRUE(pool_->HasGroupForTesting(kOtherGroupId)); EXPECT_EQ(1, pool_->NumActiveSocketsInGroupForTesting(kOtherGroupId)); - pool_->OnSSLConfigForServerChanged(HostPortPair("a", 443)); + pool_->OnSSLConfigForServersChanged({HostPortPair("a", 443)}); // Active sockets continue to be active. ASSERT_TRUE(pool_->HasGroupForTesting(kGroupId)); diff --git a/net/socket/connect_job.h b/net/socket/connect_job.h index 35c0f4992..2f810362a 100644 --- a/net/socket/connect_job.h +++ b/net/socket/connect_job.h @@ -110,7 +110,7 @@ enum class OnHostResolutionCallbackResult { // ConnectJob synchronously, but may signal the ConnectJob may be destroyed // asynchronously. See OnHostResolutionCallbackResult above. // -// |address_list| is the list of addresses the host being connected to was +// `endpoint_results` is the list of endpoints the host being connected to was // resolved to, with the port fields populated to the port being connected to. using OnHostResolutionCallback = base::RepeatingCallback<OnHostResolutionCallbackResult( diff --git a/net/socket/connect_job_factory_unittest.cc b/net/socket/connect_job_factory_unittest.cc index 45d599223..d3b0268af 100644 --- a/net/socket/connect_job_factory_unittest.cc +++ b/net/socket/connect_job_factory_unittest.cc @@ -185,12 +185,11 @@ class ConnectJobFactoryTest : public TestWithTaskEnvironment { /*websocket_endpoint_lock_manager=*/nullptr}; TestConnectJobDelegate delegate_; + std::unique_ptr<ConnectJobFactory> factory_; raw_ptr<TestHttpProxyConnectJobFactory> http_proxy_job_factory_; raw_ptr<TestSocksConnectJobFactory> socks_job_factory_; raw_ptr<TestSslConnectJobFactory> ssl_job_factory_; raw_ptr<TestTransportConnectJobFactory> transport_job_factory_; - - std::unique_ptr<ConnectJobFactory> factory_; }; TEST_F(ConnectJobFactoryTest, CreateConnectJob) { diff --git a/net/socket/socket_bio_adapter.cc b/net/socket/socket_bio_adapter.cc index 8017b593b..4126ea309 100644 --- a/net/socket/socket_bio_adapter.cc +++ b/net/socket/socket_bio_adapter.cc @@ -4,11 +4,13 @@ #include "net/socket/socket_bio_adapter.h" +#include <stdio.h> #include <string.h> #include <algorithm> #include "base/check_op.h" +#include "base/debug/alias.h" #include "base/functional/bind.h" #include "base/location.h" #include "base/notreached.h" @@ -64,9 +66,9 @@ SocketBIOAdapter::SocketBIOAdapter(StreamSocket* socket, read_buffer_capacity_(read_buffer_capacity), write_buffer_capacity_(write_buffer_capacity), delegate_(delegate) { - bio_.reset(BIO_new(&kBIOMethod)); - bio_->ptr = this; - bio_->init = 1; + bio_.reset(BIO_new(BIOMethod())); + BIO_set_data(bio_.get(), this); + BIO_set_init(bio_.get(), 1); read_callback_ = base::BindRepeating(&SocketBIOAdapter::OnSocketReadComplete, weak_factory_.GetWeakPtr()); @@ -75,16 +77,19 @@ SocketBIOAdapter::SocketBIOAdapter(StreamSocket* socket, } SocketBIOAdapter::~SocketBIOAdapter() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); // BIOs are reference-counted and may outlive the adapter. Clear the pointer // so future operations fail. - bio_->ptr = nullptr; + BIO_set_data(bio_.get(), nullptr); } bool SocketBIOAdapter::HasPendingReadData() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); return read_result_ > 0; } size_t SocketBIOAdapter::GetAllocationSize() const { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); size_t buffer_size = 0; if (read_buffer_) buffer_size += read_buffer_capacity_; @@ -95,6 +100,7 @@ size_t SocketBIOAdapter::GetAllocationSize() const { } int SocketBIOAdapter::BIORead(char* out, int len) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (len <= 0) return len; @@ -115,9 +121,10 @@ int SocketBIOAdapter::BIORead(char* out, int len) { // layer reads the record header and body in separate reads to avoid // overreading, but issuing one is more efficient. SSL sockets are not // reused after shutdown for non-SSL traffic, so overreading is fine. - DCHECK(!read_buffer_); - DCHECK_EQ(0, read_offset_); + CHECK(!read_buffer_); + CHECK_EQ(0, read_offset_); read_buffer_ = base::MakeRefCounted<IOBuffer>(read_buffer_capacity_); + read_result_ = ERR_IO_PENDING; int result = socket_->ReadIfReady( read_buffer_.get(), read_buffer_capacity_, base::BindOnce(&SocketBIOAdapter::OnSocketReadIfReadyComplete, @@ -128,9 +135,8 @@ int SocketBIOAdapter::BIORead(char* out, int len) { result = socket_->Read(read_buffer_.get(), read_buffer_capacity_, read_callback_); } - if (result == ERR_IO_PENDING) { - read_result_ = ERR_IO_PENDING; - } else { + if (result != ERR_IO_PENDING) { + // `HandleSocketReadResult` will update `read_result_` based on `result`. HandleSocketReadResult(result); } } @@ -164,7 +170,9 @@ int SocketBIOAdapter::BIORead(char* out, int len) { } void SocketBIOAdapter::HandleSocketReadResult(int result) { - DCHECK_NE(ERR_IO_PENDING, result); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + CHECK_NE(ERR_IO_PENDING, result); + CHECK_EQ(ERR_IO_PENDING, read_result_); // If an EOF, canonicalize to ERR_CONNECTION_CLOSED here, so that higher // levels don't report success. @@ -179,15 +187,17 @@ void SocketBIOAdapter::HandleSocketReadResult(int result) { } void SocketBIOAdapter::OnSocketReadComplete(int result) { - DCHECK_EQ(ERR_IO_PENDING, read_result_); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + CHECK_EQ(ERR_IO_PENDING, read_result_); HandleSocketReadResult(result); delegate_->OnReadReady(); } void SocketBIOAdapter::OnSocketReadIfReadyComplete(int result) { - DCHECK_EQ(ERR_IO_PENDING, read_result_); - DCHECK_GE(OK, result); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + CHECK_EQ(ERR_IO_PENDING, read_result_); + CHECK_GE(OK, result); // Do not use HandleSocketReadResult() because result == OK doesn't mean EOF. read_result_ = result; @@ -196,12 +206,13 @@ void SocketBIOAdapter::OnSocketReadIfReadyComplete(int result) { } int SocketBIOAdapter::BIOWrite(const char* in, int len) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (len <= 0) return len; // If the write buffer is not empty, there must be a pending Write() to flush // it. - DCHECK(write_buffer_used_ == 0 || write_error_ == ERR_IO_PENDING); + CHECK(write_buffer_used_ == 0 || write_error_ == ERR_IO_PENDING); // If a previous Write() failed, report the error. if (write_error_ != OK && write_error_ != ERR_IO_PENDING) { @@ -211,7 +222,7 @@ int SocketBIOAdapter::BIOWrite(const char* in, int len) { // Instantiate the write buffer if needed. if (!write_buffer_) { - DCHECK_EQ(0, write_buffer_used_); + CHECK_EQ(0, write_buffer_used_); write_buffer_ = base::MakeRefCounted<GrowableIOBuffer>(); write_buffer_->SetCapacity(write_buffer_capacity_); } @@ -250,7 +261,7 @@ int SocketBIOAdapter::BIOWrite(const char* in, int len) { } // Either the buffer is now full or there is no more input. - DCHECK(len == 0 || write_buffer_used_ == write_buffer_->capacity()); + CHECK(len == 0 || write_buffer_used_ == write_buffer_->capacity()); // Schedule a socket Write() if necessary. (The ring buffer may previously // have been empty.) @@ -270,22 +281,45 @@ int SocketBIOAdapter::BIOWrite(const char* in, int len) { } void SocketBIOAdapter::SocketWrite() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); while (write_error_ == OK && write_buffer_used_ > 0) { + int write_buffer_used_old = write_buffer_used_; int write_size = std::min(write_buffer_used_, write_buffer_->RemainingCapacity()); + + // TODO(crbug.com/1440692): Remove this once the crash is resolved. + char debug[128]; + snprintf(debug, sizeof(debug), + "offset=%d;remaining=%d;used=%d;write_size=%d", + write_buffer_->offset(), write_buffer_->RemainingCapacity(), + write_buffer_used_, write_size); + base::debug::Alias(debug); + + write_error_ = ERR_IO_PENDING; int result = socket_->Write(write_buffer_.get(), write_size, write_callback_, kTrafficAnnotation); - if (result == ERR_IO_PENDING) { - write_error_ = ERR_IO_PENDING; - return; - } - HandleSocketWriteResult(result); + // TODO(crbug.com/1440692): Remove this once the crash is resolved. + char debug2[32]; + snprintf(debug2, sizeof(debug2), "result=%d", result); + base::debug::Alias(debug2); + + // If `write_buffer_used_` changed across a call to the underlying socket, + // something went very wrong. + // + // TODO(crbug.com/1440692): Remove this once the crash is resolved. + CHECK_EQ(write_buffer_used_old, write_buffer_used_); + if (result != ERR_IO_PENDING) { + // `HandleSocketWriteResult` will update `write_error_` based on `result. + HandleSocketWriteResult(result); + } } } void SocketBIOAdapter::HandleSocketWriteResult(int result) { - DCHECK_NE(ERR_IO_PENDING, result); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + CHECK_NE(ERR_IO_PENDING, result); + CHECK_EQ(ERR_IO_PENDING, write_error_); if (result < 0) { write_error_ = result; @@ -297,6 +331,8 @@ void SocketBIOAdapter::HandleSocketWriteResult(int result) { } // Advance the ring buffer. + CHECK_LE(result, write_buffer_used_); + CHECK_LE(result, write_buffer_->RemainingCapacity()); write_buffer_->set_offset(write_buffer_->offset() + result); write_buffer_used_ -= result; if (write_buffer_->RemainingCapacity() == 0) @@ -309,7 +345,8 @@ void SocketBIOAdapter::HandleSocketWriteResult(int result) { } void SocketBIOAdapter::OnSocketWriteComplete(int result) { - DCHECK_EQ(ERR_IO_PENDING, write_error_); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + CHECK_EQ(ERR_IO_PENDING, write_error_); bool was_full = write_buffer_used_ == write_buffer_->capacity(); @@ -333,15 +370,17 @@ void SocketBIOAdapter::OnSocketWriteComplete(int result) { } void SocketBIOAdapter::CallOnReadReady() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (read_result_ == ERR_IO_PENDING) delegate_->OnReadReady(); } SocketBIOAdapter* SocketBIOAdapter::GetAdapter(BIO* bio) { - DCHECK_EQ(&kBIOMethod, bio->method); - SocketBIOAdapter* adapter = reinterpret_cast<SocketBIOAdapter*>(bio->ptr); - if (adapter) - DCHECK_EQ(bio, adapter->bio()); + SocketBIOAdapter* adapter = + reinterpret_cast<SocketBIOAdapter*>(BIO_get_data(bio)); + if (adapter) { + CHECK_EQ(bio, adapter->bio()); + } return adapter; } @@ -383,17 +422,16 @@ long SocketBIOAdapter::BIOCtrlWrapper(BIO* bio, return 0; } -const BIO_METHOD SocketBIOAdapter::kBIOMethod = { - 0, // type (unused) - nullptr, // name (unused) - SocketBIOAdapter::BIOWriteWrapper, - SocketBIOAdapter::BIOReadWrapper, - nullptr, // puts - nullptr, // gets - SocketBIOAdapter::BIOCtrlWrapper, - nullptr, // create - nullptr, // destroy - nullptr, // callback_ctrl -}; +const BIO_METHOD* SocketBIOAdapter::BIOMethod() { + static const BIO_METHOD* kMethod = []() { + BIO_METHOD* method = BIO_meth_new(0, nullptr); + CHECK(method); + CHECK(BIO_meth_set_write(method, SocketBIOAdapter::BIOWriteWrapper)); + CHECK(BIO_meth_set_read(method, SocketBIOAdapter::BIOReadWrapper)); + CHECK(BIO_meth_set_ctrl(method, SocketBIOAdapter::BIOCtrlWrapper)); + return method; + }(); + return kMethod; +} } // namespace net diff --git a/net/socket/socket_bio_adapter.h b/net/socket/socket_bio_adapter.h index e22ab6992..06212a7a1 100644 --- a/net/socket/socket_bio_adapter.h +++ b/net/socket/socket_bio_adapter.h @@ -8,6 +8,7 @@ #include "base/memory/raw_ptr.h" #include "base/memory/scoped_refptr.h" #include "base/memory/weak_ptr.h" +#include "base/sequence_checker.h" #include "net/base/completion_repeating_callback.h" #include "net/base/net_errors.h" #include "net/base/net_export.h" @@ -104,7 +105,7 @@ class NET_EXPORT_PRIVATE SocketBIOAdapter { static int BIOWriteWrapper(BIO* bio, const char* in, int len); static long BIOCtrlWrapper(BIO* bio, int cmd, long larg, void* parg); - static const BIO_METHOD kBIOMethod; + static const BIO_METHOD* BIOMethod(); bssl::UniquePtr<BIO> bio_; @@ -142,6 +143,7 @@ class NET_EXPORT_PRIVATE SocketBIOAdapter { raw_ptr<Delegate> delegate_; + SEQUENCE_CHECKER(sequence_checker_); base::WeakPtrFactory<SocketBIOAdapter> weak_factory_{this}; }; diff --git a/net/socket/socket_posix.cc b/net/socket/socket_posix.cc index ea3d794d9..14f09938c 100644 --- a/net/socket/socket_posix.cc +++ b/net/socket/socket_posix.cc @@ -325,12 +325,12 @@ int SocketPosix::Write( CompletionOnceCallback callback, const NetworkTrafficAnnotationTag& /* traffic_annotation */) { DCHECK(thread_checker_.CalledOnValidThread()); - DCHECK_NE(kInvalidSocket, socket_fd_); - DCHECK(!waiting_connect_); + CHECK_NE(kInvalidSocket, socket_fd_); + CHECK(!waiting_connect_); CHECK(write_callback_.is_null()); // Synchronous operation not supported - DCHECK(!callback.is_null()); - DCHECK_LT(0, buf_len); + CHECK(!callback.is_null()); + CHECK_LT(0, buf_len); int rv = DoWrite(buf, buf_len); if (rv == ERR_IO_PENDING) @@ -525,6 +525,9 @@ int SocketPosix::DoWrite(IOBuffer* buf, int buf_len) { #else int rv = HANDLE_EINTR(write(socket_fd_, buf->data(), buf_len)); #endif + if (rv >= 0) { + CHECK_LE(rv, buf_len); + } return rv >= 0 ? rv : MapSystemError(errno); } diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index 441af898c..acaf5868c 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -47,6 +47,7 @@ #include "net/traffic_annotation/network_traffic_annotation.h" #include "net/traffic_annotation/network_traffic_annotation_test_helper.h" #include "testing/gtest/include/gtest/gtest.h" +#include "third_party/abseil-cpp/absl/strings/ascii.h" #if BUILDFLAG(IS_ANDROID) #include "base/android/build_info.h" @@ -68,9 +69,7 @@ inline char AsciifyLow(char x) { } inline char Asciify(char x) { - if ((x < 0) || !isprint(x)) - return '.'; - return x; + return absl::ascii_isprint(static_cast<unsigned char>(x)) ? x : '.'; } void DumpData(const char* data, int data_len) { @@ -1443,8 +1442,8 @@ void MockSSLClientSocket::GetSSLCertRequestInfo( cert_request_info->is_proxy = data_->cert_request_info->is_proxy; cert_request_info->cert_authorities = data_->cert_request_info->cert_authorities; - cert_request_info->cert_key_types = - data_->cert_request_info->cert_key_types; + cert_request_info->signature_algorithms = + data_->cert_request_info->signature_algorithms; } else { cert_request_info->Reset(); } diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index cf8115ec1..5119e6619 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -495,9 +495,7 @@ struct SSLSocketDataProvider { SSLInfo ssl_info; // Result for GetSSLCertRequestInfo(). - // This field is not a raw_ptr<> because it was filtered by the rewriter for: - // #union - RAW_PTR_EXCLUSION SSLCertRequestInfo* cert_request_info = nullptr; + scoped_refptr<SSLCertRequestInfo> cert_request_info; // Result for GetECHRetryConfigs(). std::vector<uint8_t> ech_retry_configs; @@ -949,7 +947,7 @@ class MockSSLClientSocket : public AsyncSocket, public SSLClientSocket { bool in_confirm_handshake_ = false; NetLogWithSource net_log_; std::unique_ptr<StreamSocket> stream_socket_; - raw_ptr<SSLSocketDataProvider> data_; + raw_ptr<SSLSocketDataProvider, AcrossTasksDanglingUntriaged> data_; // Address of the "remote" peer we're connected to. IPEndPoint peer_addr_; @@ -1355,8 +1353,10 @@ class MockTaggingClientSocketFactory : public MockClientSocketFactory { MockUDPClientSocket* GetLastProducedUDPSocket() const { return udp_socket_; } private: - raw_ptr<MockTaggingStreamSocket> tcp_socket_ = nullptr; - raw_ptr<MockUDPClientSocket> udp_socket_ = nullptr; + raw_ptr<MockTaggingStreamSocket, AcrossTasksDanglingUntriaged> tcp_socket_ = + nullptr; + raw_ptr<MockUDPClientSocket, AcrossTasksDanglingUntriaged> udp_socket_ = + nullptr; }; // Host / port used for SOCKS4 test strings. diff --git a/net/socket/socks_client_socket.h b/net/socket/socks_client_socket.h index e606c591c..e12475f4e 100644 --- a/net/socket/socks_client_socket.h +++ b/net/socket/socks_client_socket.h @@ -143,7 +143,7 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket { bool was_ever_used_ = false; // Used to resolve the hostname to which the SOCKS proxy will connect. - raw_ptr<HostResolver> host_resolver_; + raw_ptr<HostResolver, DanglingUntriaged> host_resolver_; SecureDnsPolicy secure_dns_policy_; std::unique_ptr<HostResolver::ResolveHostRequest> resolve_host_request_; const HostPortPair destination_; diff --git a/net/socket/ssl_client_socket.cc b/net/socket/ssl_client_socket.cc index 4dd73b5c3..8290b9592 100644 --- a/net/socket/ssl_client_socket.cc +++ b/net/socket/ssl_client_socket.cc @@ -6,6 +6,7 @@ #include <string> +#include "base/containers/flat_tree.h" #include "base/logging.h" #include "base/observer_list.h" #include "net/socket/ssl_client_socket_impl.h" @@ -103,9 +104,9 @@ void SSLClientContext::SetClientCertificate( if (ssl_client_session_cache_) { // Session resumption bypasses client certificate negotiation, so flush all // associated sessions when preferences change. - ssl_client_session_cache_->FlushForServer(server); + ssl_client_session_cache_->FlushForServers({server}); } - NotifySSLConfigForServerChanged(server); + NotifySSLConfigForServersChanged({server}); } bool SSLClientContext::ClearClientCertificate(const HostPortPair& server) { @@ -116,9 +117,9 @@ bool SSLClientContext::ClearClientCertificate(const HostPortPair& server) { if (ssl_client_session_cache_) { // Session resumption bypasses client certificate negotiation, so flush all // associated sessions when preferences change. - ssl_client_session_cache_->FlushForServer(server); + ssl_client_session_cache_->FlushForServers({server}); } - NotifySSLConfigForServerChanged(server); + NotifySSLConfigForServersChanged({server}); return true; } @@ -131,11 +132,10 @@ void SSLClientContext::RemoveObserver(Observer* observer) { } void SSLClientContext::OnSSLContextConfigChanged() { - // TODO(davidben): Should we flush |ssl_client_session_cache_| here? We flush - // the socket pools, but not the session cache. While BoringSSL-based servers - // never change version or cipher negotiation based on client-offered - // sessions, other servers do. config_ = ssl_config_service_->GetSSLContextConfig(); + if (ssl_client_session_cache_) { + ssl_client_session_cache_->Flush(); + } NotifySSLConfigChanged(SSLConfigChangeType::kSSLConfigChanged); } @@ -143,13 +143,18 @@ void SSLClientContext::OnCertVerifierChanged() { NotifySSLConfigChanged(SSLConfigChangeType::kCertVerifierChanged); } -void SSLClientContext::OnCertDBChanged() { - // Both the trust store and client certificate store may have changed. +void SSLClientContext::OnTrustStoreChanged() { + NotifySSLConfigChanged(SSLConfigChangeType::kCertDatabaseChanged); +} + +void SSLClientContext::OnClientCertStoreChanged() { + base::flat_set<HostPortPair> servers = + ssl_client_auth_cache_.GetCachedServers(); ssl_client_auth_cache_.Clear(); if (ssl_client_session_cache_) { - ssl_client_session_cache_->Flush(); + ssl_client_session_cache_->FlushForServers(servers); } - NotifySSLConfigChanged(SSLConfigChangeType::kCertDatabaseChanged); + NotifySSLConfigForServersChanged(servers); } void SSLClientContext::NotifySSLConfigChanged(SSLConfigChangeType change_type) { @@ -158,10 +163,10 @@ void SSLClientContext::NotifySSLConfigChanged(SSLConfigChangeType change_type) { } } -void SSLClientContext::NotifySSLConfigForServerChanged( - const HostPortPair& server) { +void SSLClientContext::NotifySSLConfigForServersChanged( + const base::flat_set<HostPortPair>& servers) { for (Observer& observer : observers_) { - observer.OnSSLConfigForServerChanged(server); + observer.OnSSLConfigForServersChanged(servers); } } diff --git a/net/socket/ssl_client_socket.h b/net/socket/ssl_client_socket.h index ddbbf708c..f34320c57 100644 --- a/net/socket/ssl_client_socket.h +++ b/net/socket/ssl_client_socket.h @@ -10,6 +10,7 @@ #include <memory> #include <vector> +#include "base/containers/flat_set.h" #include "base/gtest_prod_util.h" #include "base/memory/raw_ptr.h" #include "base/observer_list.h" @@ -101,11 +102,13 @@ class NET_EXPORT SSLClientContext : public SSLConfigService::Observer, // Called when SSL configuration for all hosts changed. Newly-created // SSLClientSockets will pick up the new configuration. Note that changes // which only apply to one server will result in a call to - // OnSSLConfigForServerChanged() instead. + // OnSSLConfigForServersChanged() instead. virtual void OnSSLConfigChanged(SSLConfigChangeType change_type) = 0; - // Called when SSL configuration for |server| changed. Newly-created - // SSLClientSockets to |server| will pick up the new configuration. - virtual void OnSSLConfigForServerChanged(const HostPortPair& server) = 0; + // Called when SSL configuration for |servers| changed. Newly-created + // SSLClientSockets to any server in |servers| will pick up the new + // configuration. + virtual void OnSSLConfigForServersChanged( + const base::flat_set<HostPortPair>& servers) = 0; }; // Creates a new SSLClientContext with the specified parameters. The @@ -164,7 +167,7 @@ class NET_EXPORT SSLClientContext : public SSLConfigService::Observer, // |private_key| may be null to indicate that no client certificate should be // sent to |server|. // - // Note this method will synchronously call OnSSLConfigForServerChanged() on + // Note this method will synchronously call OnSSLConfigForServersChanged() on // observers. void SetClientCertificate(const HostPortPair& server, scoped_refptr<X509Certificate> client_cert, @@ -174,10 +177,15 @@ class NET_EXPORT SSLClientContext : public SSLConfigService::Observer, // SetClientCertificate(). Returns true if one was removed and false // otherwise. // - // Note this method will synchronously call OnSSLConfigForServerChanged() on + // Note this method will synchronously call OnSSLConfigForServersChanged() on // observers. bool ClearClientCertificate(const HostPortPair& server); + base::flat_set<HostPortPair> GetClientCertificateCachedServersForTesting() + const { + return ssl_client_auth_cache_.GetCachedServers(); + } + // Add an observer to be notified when configuration has changed. // RemoveObserver() must be called before |observer| is destroyed. void AddObserver(Observer* observer); @@ -192,11 +200,13 @@ class NET_EXPORT SSLClientContext : public SSLConfigService::Observer, void OnCertVerifierChanged() override; // CertDatabase::Observer: - void OnCertDBChanged() override; + void OnTrustStoreChanged() override; + void OnClientCertStoreChanged() override; private: void NotifySSLConfigChanged(SSLConfigChangeType change_type); - void NotifySSLConfigForServerChanged(const HostPortPair& server); + void NotifySSLConfigForServersChanged( + const base::flat_set<HostPortPair>& servers); SSLContextConfig config_; diff --git a/net/socket/ssl_client_socket_impl.cc b/net/socket/ssl_client_socket_impl.cc index 31256e791..9d7134c1c 100644 --- a/net/socket/ssl_client_socket_impl.cc +++ b/net/socket/ssl_client_socket_impl.cc @@ -624,22 +624,11 @@ void SSLClientSocketImpl::GetSSLCertRequestInfo( CRYPTO_BUFFER_len(ca_name)); } - cert_request_info->cert_key_types.clear(); - const uint8_t* client_cert_types; - size_t num_client_cert_types = - SSL_get0_certificate_types(ssl_.get(), &client_cert_types); - for (size_t i = 0; i < num_client_cert_types; i++) { - switch (client_cert_types[i]) { - case static_cast<uint8_t>(SSLClientCertType::kRsaSign): - case static_cast<uint8_t>(SSLClientCertType::kEcdsaSign): - cert_request_info->cert_key_types.push_back( - static_cast<SSLClientCertType>(client_cert_types[i])); - break; - default: - // Unknown client certificate types are ignored. - break; - } - } + const uint16_t* algorithms; + size_t num_algorithms = + SSL_get0_peer_verify_algorithms(ssl_.get(), &algorithms); + cert_request_info->signature_algorithms.assign(algorithms, + algorithms + num_algorithms); } void SSLClientSocketImpl::ApplySocketTag(const SocketTag& tag) { @@ -702,8 +691,10 @@ int SSLClientSocketImpl::Write( if (rv == ERR_IO_PENDING) { user_write_callback_ = std::move(callback); } else { - if (rv > 0) + if (rv > 0) { + CHECK_LE(rv, buf_len); was_ever_used_ = true; + } user_write_buf_ = nullptr; user_write_buf_len_ = 0; } @@ -754,9 +745,8 @@ int SSLClientSocketImpl::Init() { return ERR_UNEXPECTED; } - if (context_->config().post_quantum_enabled && - base::FeatureList::IsEnabled(features::kPostQuantumKyber)) { - static const int kCurves[] = {NID_X25519Kyber768, NID_X25519, + if (context_->config().PostQuantumKeyAgreementEnabled()) { + static const int kCurves[] = {NID_X25519Kyber768Draft00, NID_X25519, NID_X9_62_prime256v1, NID_secp384r1}; if (!SSL_set1_curves(ssl_.get(), kCurves, std::size(kCurves))) { return ERR_UNEXPECTED; @@ -1247,10 +1237,15 @@ ssl_verify_result_t SSLClientSocketImpl::HandleVerifyResult() { } // Enforce keyUsage extension for RSA leaf certificates chaining up to known - // roots. - // TODO(crbug.com/795089): Enforce this unconditionally. + // roots unconditionally. Enforcement for local anchors is, for now, + // conditional on feature flags and external configuration. See + // https://crbug.com/795089. + bool rsa_key_usage_for_local_anchors = + context_->config().rsa_key_usage_for_local_anchors_override.value_or( + base::FeatureList::IsEnabled(features::kRSAKeyUsageForLocalAnchors)); SSL_set_enforce_rsa_key_usage( - ssl_.get(), server_cert_verify_result_.is_issued_by_known_root); + ssl_.get(), rsa_key_usage_for_local_anchors || + server_cert_verify_result_.is_issued_by_known_root); // If the connection was good, check HPKP and CT status simultaneously, // but prefer to treat the HPKP error as more serious, if there was one. @@ -1505,6 +1500,7 @@ int SSLClientSocketImpl::DoPayloadWrite() { int rv = SSL_write(ssl_.get(), user_write_buf_->data(), user_write_buf_len_); if (rv >= 0) { + CHECK_LE(rv, user_write_buf_len_); net_log_.AddByteTransferEvent(NetLogEventType::SSL_SOCKET_BYTES_SENT, rv, user_write_buf_->data()); if (first_post_handshake_write_ && SSL_is_init_finished(ssl_.get())) { @@ -1644,21 +1640,9 @@ int SSLClientSocketImpl::ClientCertRequestCallback(SSL* ssl) { // Clear any currently configured certificates. SSL_certs_clear(ssl_.get()); -#if BUILDFLAG(IS_IOS) - // TODO(droger): Support client auth on iOS. See http://crbug.com/145954). - // - // Historically this was disabled because client auth required - // platform-specific code deep in //net. Nowadays, this is abstracted away and - // we could enable the interfaces on iOS for platform-independence. However, - // merely enabling them changes our behavior from automatically proceeding - // with no client certificate to raising - // `URLRequest::Delegate::OnCertificateRequested`. Callers would need to be - // updated to apply that behavior manually. - // - // If fixing this, re-enable the tests in ssl_client_socket_unittest.cc and - // ssl_server_socket_unittest.cc which are disabled on iOS. +#if !BUILDFLAG(ENABLE_CLIENT_CERTIFICATES) LOG(WARNING) << "Client auth is not supported"; -#else // !BUILDFLAG(IS_IOS) +#else // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES) if (!send_client_cert_) { // First pass: we know that a client certificate is needed, but we do not // have one at hand. Suspend the handshake. SSL_get_error will return @@ -1693,7 +1677,7 @@ int SSLClientSocketImpl::ClientCertRequestCallback(SSL* ssl) { client_cert_->intermediate_buffers().size())); return 1; } -#endif // BUILDFLAG(IS_IOS) +#endif // !BUILDFLAG(ENABLE_CLIENT_CERTIFICATES) // Send no client certificate. net_log_.AddEventWithIntParams(NetLogEventType::SSL_CLIENT_CERT_PROVIDED, @@ -1842,7 +1826,7 @@ void SSLClientSocketImpl::MessageCallback(int is_write, break; case SSL3_RT_CLIENT_HELLO_INNER: DCHECK(is_write); - net_log_.AddEvent(NetLogEventType::SSL_ENCYPTED_CLIENT_HELLO, + net_log_.AddEvent(NetLogEventType::SSL_ENCRYPTED_CLIENT_HELLO, [&](NetLogCaptureMode capture_mode) { return NetLogSSLMessageParams(!!is_write, buf, len, capture_mode); diff --git a/net/socket/ssl_client_socket_impl.h b/net/socket/ssl_client_socket_impl.h index e7636ba29..fc7487077 100644 --- a/net/socket/ssl_client_socket_impl.h +++ b/net/socket/ssl_client_socket_impl.h @@ -29,7 +29,6 @@ #include "net/socket/ssl_client_socket.h" #include "net/socket/stream_socket.h" #include "net/ssl/openssl_ssl_util.h" -#include "net/ssl/ssl_client_cert_type.h" #include "net/ssl/ssl_client_session_cache.h" #include "net/ssl/ssl_config.h" #include "net/traffic_annotation/network_traffic_annotation.h" diff --git a/net/socket/ssl_client_socket_unittest.cc b/net/socket/ssl_client_socket_unittest.cc index 8fd32cd7d..6c30471f2 100644 --- a/net/socket/ssl_client_socket_unittest.cc +++ b/net/socket/ssl_client_socket_unittest.cc @@ -34,6 +34,7 @@ #include "net/base/address_list.h" #include "net/base/completion_once_callback.h" #include "net/base/features.h" +#include "net/base/host_port_pair.h" #include "net/base/io_buffer.h" #include "net/base/ip_address.h" #include "net/base/ip_endpoint.h" @@ -43,6 +44,7 @@ #include "net/base/test_completion_callback.h" #include "net/cert/asn1_util.h" #include "net/cert/cert_and_ct_verifier.h" +#include "net/cert/cert_database.h" #include "net/cert/ct_policy_enforcer.h" #include "net/cert/ct_policy_status.h" #include "net/cert/ct_verifier.h" @@ -583,7 +585,7 @@ class DeleteSocketCallback : public TestCompletionCallbackBase { SetResult(result); } - raw_ptr<StreamSocket> socket_; + raw_ptr<StreamSocket, DanglingUntriaged> socket_; }; // A mock CTVerifier that records every call to Verify but doesn't verify @@ -828,7 +830,7 @@ class SSLClientSocketTest : public PlatformTest, public WithTaskEnvironment { } RecordingNetLogObserver log_observer_; - raw_ptr<ClientSocketFactory> socket_factory_; + raw_ptr<ClientSocketFactory, DanglingUntriaged> socket_factory_; std::unique_ptr<TestSSLConfigService> ssl_config_service_; std::unique_ptr<MockCertVerifier> cert_verifier_; std::unique_ptr<TransportSecurityState> transport_security_state_; @@ -1390,6 +1392,13 @@ class HangingCertVerifier : public CertVerifier { int num_active_requests_ = 0; }; +class MockSSLClientContextObserver : public SSLClientContext::Observer { + public: + MOCK_METHOD1(OnSSLConfigChanged, void(SSLClientContext::SSLConfigChangeType)); + MOCK_METHOD1(OnSSLConfigForServersChanged, + void(const base::flat_set<HostPortPair>&)); +}; + } // namespace INSTANTIATE_TEST_SUITE_P(TLSVersion, @@ -1547,7 +1556,7 @@ TEST_P(SSLClientSocketVersionTest, ConnectBadValidityIgnoreCertErrors) { } // Client certificates are disabled on iOS. -#if !BUILDFLAG(IS_IOS) +#if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES) // Attempt to connect to a page which requests a client certificate. It should // return an error code on connect. TEST_P(SSLClientSocketVersionTest, ConnectClientAuthCertRequested) { @@ -1590,7 +1599,7 @@ TEST_P(SSLClientSocketVersionTest, ConnectClientAuthSendNullCert) { sock_->Disconnect(); EXPECT_FALSE(sock_->IsConnected()); } -#endif // !IS_IOS +#endif // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES) // TODO(wtc): Add unit tests for IsConnectedAndIdle: // - Server closes an SSL connection (with a close_notify alert message). @@ -2659,7 +2668,7 @@ TEST_P(SSLClientSocketVersionTest, VerifyReturnChainProperlyOrdered) { } // Client certificates are disabled on iOS. -#if !BUILDFLAG(IS_IOS) +#if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES) INSTANTIATE_TEST_SUITE_P(TLSVersion, SSLClientSocketCertRequestInfoTest, ValuesIn(GetTLSVersions())); @@ -2724,20 +2733,14 @@ TEST_P(SSLClientSocketCertRequestInfoTest, CertKeyTypes) { config.client_cert_type = SSLServerConfig::OPTIONAL_CLIENT_CERT; ASSERT_TRUE(StartEmbeddedTestServer(EmbeddedTestServer::CERT_OK, config)); scoped_refptr<SSLCertRequestInfo> request_info = GetCertRequest(); - ASSERT_TRUE(request_info.get()); - if (version() >= SSL_PROTOCOL_VERSION_TLS1_3) { - // TLS 1.3 does not use cert_key_types, only signature algorithms. This - // should be migrated to a more modern mechanism. See - // https://crbug.com/1270530. - EXPECT_EQ(0u, request_info->cert_key_types.size()); - } else { - // BoringSSL always sends rsa_sign and ecdsa_sign. - ASSERT_EQ(2u, request_info->cert_key_types.size()); - EXPECT_EQ(SSLClientCertType::kRsaSign, request_info->cert_key_types[0]); - EXPECT_EQ(SSLClientCertType::kEcdsaSign, request_info->cert_key_types[1]); - } + ASSERT_TRUE(request_info); + // Look for some values we expect BoringSSL to always send. + EXPECT_THAT(request_info->signature_algorithms, + testing::Contains(SSL_SIGN_ECDSA_SECP256R1_SHA256)); + EXPECT_THAT(request_info->signature_algorithms, + testing::Contains(SSL_SIGN_RSA_PSS_RSAE_SHA256)); } -#endif // !IS_IOS +#endif // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES) // Tests that the Certificate Transparency (RFC 6962) TLS extension is // supported. @@ -3002,6 +3005,22 @@ TEST_P(SSLClientSocketVersionTest, SessionResumption) { ASSERT_THAT(rv, IsOk()); ASSERT_TRUE(sock_->GetSSLInfo(&ssl_info)); EXPECT_EQ(SSLInfo::HANDSHAKE_FULL, ssl_info.handshake_type); + + // Pick up the ticket again and confirm resumption works. + EXPECT_THAT(MakeHTTPRequest(sock_.get()), IsOk()); + ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv)); + ASSERT_THAT(rv, IsOk()); + ASSERT_TRUE(sock_->GetSSLInfo(&ssl_info)); + EXPECT_EQ(SSLInfo::HANDSHAKE_RESUME, ssl_info.handshake_type); + sock_.reset(); + + // Updating the context-wide configuration should flush the session cache. + SSLContextConfig config; + config.disabled_cipher_suites = {1234}; + ssl_config_service_->UpdateSSLConfigAndNotify(config); + ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv)); + ASSERT_THAT(rv, IsOk()); + ASSERT_TRUE(sock_->GetSSLInfo(&ssl_info)); } namespace { @@ -3590,7 +3609,7 @@ TEST_F(SSLClientSocketTest, AlpnClientDisabled) { } // Client certificates are disabled on iOS. -#if !BUILDFLAG(IS_IOS) +#if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES) // Connect to a server requesting client authentication, do not send // any client certificates. It should refuse the connection. TEST_P(SSLClientSocketVersionTest, NoCert) { @@ -3733,7 +3752,109 @@ TEST_F(SSLClientSocketTest, ClearSessionCacheOnClientCertChange) { ASSERT_TRUE(CreateAndConnectSSLClientSocket(SSLConfig(), &rv)); EXPECT_THAT(rv, IsError(ERR_BAD_SSL_CLIENT_AUTH_CERT)); } -#endif // !IS_IOS + +TEST_F(SSLClientSocketTest, ClearSessionCacheOnClientCertDatabaseChange) { + SSLServerConfig server_config; + // TLS 1.3 reports client certificate errors after the handshake, so test at + // TLS 1.2 for simplicity. + server_config.version_max = SSL_PROTOCOL_VERSION_TLS1_2; + server_config.client_cert_type = SSLServerConfig::REQUIRE_CLIENT_CERT; + ASSERT_TRUE( + StartEmbeddedTestServer(EmbeddedTestServer::CERT_OK, server_config)); + + HostPortPair host_port_pair2("example.com", 42); + testing::StrictMock<MockSSLClientContextObserver> observer; + EXPECT_CALL(observer, OnSSLConfigForServersChanged( + base::flat_set<HostPortPair>({host_port_pair()}))); + EXPECT_CALL(observer, OnSSLConfigForServersChanged( + base::flat_set<HostPortPair>({host_port_pair2}))); + EXPECT_CALL(observer, + OnSSLConfigForServersChanged(base::flat_set<HostPortPair>( + {host_port_pair(), host_port_pair2}))); + + context_->AddObserver(&observer); + + base::FilePath certs_dir = GetTestCertsDirectory(); + context_->SetClientCertificate( + host_port_pair(), ImportCertFromFile(certs_dir, "client_1.pem"), + key_util::LoadPrivateKeyOpenSSL(certs_dir.AppendASCII("client_1.key"))); + + context_->SetClientCertificate( + host_port_pair2, ImportCertFromFile(certs_dir, "client_2.pem"), + key_util::LoadPrivateKeyOpenSSL(certs_dir.AppendASCII("client_2.key"))); + + EXPECT_EQ(2U, context_->GetClientCertificateCachedServersForTesting().size()); + + // Connect to `host_port_pair()` using the client cert. + int rv; + ASSERT_TRUE(CreateAndConnectSSLClientSocket(SSLConfig(), &rv)); + EXPECT_THAT(rv, IsOk()); + EXPECT_TRUE(sock_->IsConnected()); + + EXPECT_EQ(1U, context_->ssl_client_session_cache()->size()); + + CertDatabase::GetInstance()->NotifyObserversClientCertStoreChanged(); + base::RunLoop().RunUntilIdle(); + + EXPECT_EQ(0U, context_->GetClientCertificateCachedServersForTesting().size()); + EXPECT_EQ(0U, context_->ssl_client_session_cache()->size()); + + context_->RemoveObserver(&observer); +} + +TEST_F(SSLClientSocketTest, DontClearSessionCacheOnServerCertDatabaseChange) { + SSLServerConfig server_config; + // TLS 1.3 reports client certificate errors after the handshake, so test at + // TLS 1.2 for simplicity. + server_config.version_max = SSL_PROTOCOL_VERSION_TLS1_2; + server_config.client_cert_type = SSLServerConfig::REQUIRE_CLIENT_CERT; + ASSERT_TRUE( + StartEmbeddedTestServer(EmbeddedTestServer::CERT_OK, server_config)); + + HostPortPair host_port_pair2("example.com", 42); + testing::StrictMock<MockSSLClientContextObserver> observer; + EXPECT_CALL(observer, OnSSLConfigForServersChanged( + base::flat_set<HostPortPair>({host_port_pair()}))); + EXPECT_CALL(observer, OnSSLConfigForServersChanged( + base::flat_set<HostPortPair>({host_port_pair2}))); + EXPECT_CALL(observer, + OnSSLConfigChanged( + SSLClientContext::SSLConfigChangeType::kCertDatabaseChanged)); + + context_->AddObserver(&observer); + + base::FilePath certs_dir = GetTestCertsDirectory(); + context_->SetClientCertificate( + host_port_pair(), ImportCertFromFile(certs_dir, "client_1.pem"), + key_util::LoadPrivateKeyOpenSSL(certs_dir.AppendASCII("client_1.key"))); + + context_->SetClientCertificate( + host_port_pair2, ImportCertFromFile(certs_dir, "client_2.pem"), + key_util::LoadPrivateKeyOpenSSL(certs_dir.AppendASCII("client_2.key"))); + + EXPECT_EQ(2U, context_->GetClientCertificateCachedServersForTesting().size()); + + // Connect to `host_port_pair()` using the client cert. + int rv; + ASSERT_TRUE(CreateAndConnectSSLClientSocket(SSLConfig(), &rv)); + EXPECT_THAT(rv, IsOk()); + EXPECT_TRUE(sock_->IsConnected()); + + EXPECT_EQ(1U, context_->ssl_client_session_cache()->size()); + + CertDatabase::GetInstance()->NotifyObserversTrustStoreChanged(); + base::RunLoop().RunUntilIdle(); + + // The `OnSSLConfigChanged` observer call should be verified by the + // mock observer, but the client auth and client session cache should be + // untouched. + + EXPECT_EQ(2U, context_->GetClientCertificateCachedServersForTesting().size()); + EXPECT_EQ(1U, context_->ssl_client_session_cache()->size()); + + context_->RemoveObserver(&observer); +} +#endif // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES) HashValueVector MakeHashValueVector(uint8_t value) { HashValueVector out; @@ -3831,37 +3952,50 @@ const uint16_t kSigningCipher = kModernTLS12Cipher; struct KeyUsageTest { EmbeddedTestServer::ServerCertificate server_cert; uint16_t cipher_suite; - bool known_root; - bool success; + bool match; }; class SSLClientSocketKeyUsageTest : public SSLClientSocketTest, - public ::testing::WithParamInterface<struct KeyUsageTest> {}; + public ::testing::WithParamInterface< + std::tuple<KeyUsageTest, + bool /*known_root*/, + bool /*rsa_key_usage_for_local_anchors_enabled*/, + bool /*override_feature*/>> {}; -const struct KeyUsageTest kKeyUsageTests[] = { - // Known Root: Success iff keyUsage allows the key exchange method - {EmbeddedTestServer::CERT_KEY_USAGE_RSA_ENCIPHERMENT, kSigningCipher, true, - false}, +const KeyUsageTest kKeyUsageTests[] = { + // keyUsage matches cipher suite. {EmbeddedTestServer::CERT_KEY_USAGE_RSA_DIGITAL_SIGNATURE, kSigningCipher, - true, true}, - {EmbeddedTestServer::CERT_KEY_USAGE_RSA_ENCIPHERMENT, kEncryptingCipher, - true, true}, - {EmbeddedTestServer::CERT_KEY_USAGE_RSA_DIGITAL_SIGNATURE, - kEncryptingCipher, true, false}, - // Unknown Root: Always succeeds - {EmbeddedTestServer::CERT_KEY_USAGE_RSA_ENCIPHERMENT, kSigningCipher, false, true}, - {EmbeddedTestServer::CERT_KEY_USAGE_RSA_DIGITAL_SIGNATURE, kSigningCipher, - false, true}, {EmbeddedTestServer::CERT_KEY_USAGE_RSA_ENCIPHERMENT, kEncryptingCipher, - false, true}, + true}, + // keyUsage does not match cipher suite. + {EmbeddedTestServer::CERT_KEY_USAGE_RSA_ENCIPHERMENT, kSigningCipher, + false}, {EmbeddedTestServer::CERT_KEY_USAGE_RSA_DIGITAL_SIGNATURE, - kEncryptingCipher, false, true}, + kEncryptingCipher, false}, }; -TEST_P(SSLClientSocketKeyUsageTest, RSAKeyUsageEnforcedForKnownRoot) { - const KeyUsageTest test = GetParam(); +TEST_P(SSLClientSocketKeyUsageTest, RSAKeyUsage) { + const auto& [test, known_root, rsa_key_usage_for_local_anchors_enabled, + override_feature] = GetParam(); + bool enable_feature; + if (override_feature) { + // Configure the feature in the opposite way that we intend, to test that + // the configuration overrides it. + enable_feature = !rsa_key_usage_for_local_anchors_enabled; + } else { + enable_feature = rsa_key_usage_for_local_anchors_enabled; + } + base::test::ScopedFeatureList scoped_feature_list; + if (enable_feature) { + scoped_feature_list.InitAndEnableFeature( + features::kRSAKeyUsageForLocalAnchors); + } else { + scoped_feature_list.InitAndDisableFeature( + features::kRSAKeyUsageForLocalAnchors); + } + SSLServerConfig server_config; server_config.version_max = SSL_PROTOCOL_VERSION_TLS1_2; server_config.cipher_suite_for_testing = test.cipher_suite; @@ -3869,9 +4003,16 @@ TEST_P(SSLClientSocketKeyUsageTest, RSAKeyUsageEnforcedForKnownRoot) { scoped_refptr<X509Certificate> server_cert = embedded_test_server()->GetCertificate(); + SSLContextConfig context_config; + if (override_feature) { + context_config.rsa_key_usage_for_local_anchors_override = + rsa_key_usage_for_local_anchors_enabled; + } + ssl_config_service_->UpdateSSLConfigAndNotify(context_config); + // Certificate is trusted. CertVerifyResult verify_result; - verify_result.is_issued_by_known_root = test.known_root; + verify_result.is_issued_by_known_root = known_root; verify_result.verified_cert = server_cert; verify_result.public_key_hashes = MakeHashValueVector(kGoodHashValueVectorInput); @@ -3883,7 +4024,7 @@ TEST_P(SSLClientSocketKeyUsageTest, RSAKeyUsageEnforcedForKnownRoot) { SSLInfo ssl_info; ASSERT_TRUE(sock_->GetSSLInfo(&ssl_info)); - if (test.success) { + if (test.match || (!known_root && !rsa_key_usage_for_local_anchors_enabled)) { EXPECT_THAT(rv, IsOk()); EXPECT_TRUE(sock_->IsConnected()); } else { @@ -3892,9 +4033,10 @@ TEST_P(SSLClientSocketKeyUsageTest, RSAKeyUsageEnforcedForKnownRoot) { } } -INSTANTIATE_TEST_SUITE_P(RSAKeyUsageInstantiation, - SSLClientSocketKeyUsageTest, - ValuesIn(kKeyUsageTests)); +INSTANTIATE_TEST_SUITE_P( + RSAKeyUsageInstantiation, + SSLClientSocketKeyUsageTest, + Combine(ValuesIn(kKeyUsageTests), Bool(), Bool(), Bool())); // Test that when CT is required (in this case, by the delegate), the // absence of CT information is a socket error. diff --git a/net/socket/ssl_server_socket_impl.cc b/net/socket/ssl_server_socket_impl.cc index f5f028847..9f22c20c2 100644 --- a/net/socket/ssl_server_socket_impl.cc +++ b/net/socket/ssl_server_socket_impl.cc @@ -155,19 +155,19 @@ class SSLServerContextImpl::SocketImpl : public SSLServerSocket, void OnHandshakeIOComplete(int result); - int DoPayloadRead(IOBuffer* buf, int buf_len); - int DoPayloadWrite(); + [[nodiscard]] int DoPayloadRead(IOBuffer* buf, int buf_len); + [[nodiscard]] int DoPayloadWrite(); - int DoHandshakeLoop(int last_io_result); - int DoHandshake(); + [[nodiscard]] int DoHandshakeLoop(int last_io_result); + [[nodiscard]] int DoHandshake(); void DoHandshakeCallback(int result); void DoReadCallback(int result); void DoWriteCallback(int result); - int Init(); + [[nodiscard]] int Init(); void ExtractClientCert(); - raw_ptr<SSLServerContextImpl> context_; + raw_ptr<SSLServerContextImpl, DanglingUntriaged> context_; NetLogWithSource net_log_; @@ -312,7 +312,7 @@ void SSLServerContextImpl::SocketImpl::OnPrivateKeyComplete( signature_result_ = error; if (signature_result_ == OK) signature_ = signature; - DoHandshakeLoop(ERR_IO_PENDING); + OnHandshakeIOComplete(ERR_IO_PENDING); } // static diff --git a/net/socket/ssl_server_socket_unittest.cc b/net/socket/ssl_server_socket_unittest.cc index eed11bd33..2cfb6563e 100644 --- a/net/socket/ssl_server_socket_unittest.cc +++ b/net/socket/ssl_server_socket_unittest.cc @@ -34,6 +34,7 @@ #include "base/notreached.h" #include "base/run_loop.h" #include "base/task/single_thread_task_runner.h" +#include "base/test/bind.h" #include "base/test/task_environment.h" #include "build/build_config.h" #include "crypto/rsa_private_key.h" @@ -85,12 +86,12 @@ namespace net { namespace { // Client certificates are disabled on iOS. -#if !BUILDFLAG(IS_IOS) +#if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES) const char kClientCertFileName[] = "client_1.pem"; const char kClientPrivateKeyFileName[] = "client_1.pk8"; const char kWrongClientCertFileName[] = "client_2.pem"; const char kWrongClientPrivateKeyFileName[] = "client_2.pk8"; -#endif // !IS_IOS +#endif // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES) const uint16_t kEcdheCiphers[] = { 0xc007, // ECDHE_ECDSA_WITH_RC4_128_SHA @@ -442,7 +443,7 @@ class SSLServerSocketTest : public PlatformTest, public WithTaskEnvironment { } // Client certificates are disabled on iOS. -#if !BUILDFLAG(IS_IOS) +#if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES) void ConfigureClientCertsForClient(const char* cert_file_name, const char* private_key_file_name) { scoped_refptr<X509Certificate> client_cert = @@ -477,7 +478,7 @@ class SSLServerSocketTest : public PlatformTest, public WithTaskEnvironment { server_ssl_config_.client_cert_verifier = client_cert_verifier_.get(); } -#endif // !IS_IOS +#endif // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES) std::unique_ptr<crypto::RSAPrivateKey> ReadTestKey(base::StringPiece name) { base::FilePath certs_dir(GetTestCertsDirectory()); @@ -708,7 +709,7 @@ TEST_F(SSLServerSocketTest, HandshakeCachedContextSwitch) { } // Client certificates are disabled on iOS. -#if !BUILDFLAG(IS_IOS) +#if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES) // This test executes Connect() on SSLClientSocket and Handshake() on // SSLServerSocket to make sure handshaking between the two sockets is // completed successfully, using client certificate. @@ -1010,7 +1011,7 @@ TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSuppliedCached) { client_ret = read_callback.GetResult(client_ret); EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, client_ret); } -#endif // !IS_IOS +#endif // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES) TEST_P(SSLServerSocketReadTest, DataTransfer) { ASSERT_NO_FATAL_FAILURE(CreateContext()); @@ -1267,6 +1268,69 @@ TEST_F(SSLServerSocketTest, HandshakeServerSSLPrivateKey) { EXPECT_TRUE(is_aead); } +namespace { + +// Helper that wraps an underlying SSLPrivateKey to allow the test to +// do some work immediately before a `Sign()` operation is performed. +class SSLPrivateKeyHook : public SSLPrivateKey { + public: + SSLPrivateKeyHook(scoped_refptr<SSLPrivateKey> private_key, + base::RepeatingClosure on_sign) + : private_key_(std::move(private_key)), on_sign_(std::move(on_sign)) {} + + // SSLPrivateKey implementation. + std::string GetProviderName() override { + return private_key_->GetProviderName(); + } + std::vector<uint16_t> GetAlgorithmPreferences() override { + return private_key_->GetAlgorithmPreferences(); + } + void Sign(uint16_t algorithm, + base::span<const uint8_t> input, + SignCallback callback) override { + on_sign_.Run(); + private_key_->Sign(algorithm, input, std::move(callback)); + } + + private: + ~SSLPrivateKeyHook() override = default; + + const scoped_refptr<SSLPrivateKey> private_key_; + const base::RepeatingClosure on_sign_; +}; + +} // namespace + +// Verifies that if the client disconnects while during private key signing then +// the disconnection is correctly reported to the `Handshake()` completion +// callback, with `ERR_CONNECTION_CLOSED`. +// This is a regression test for crbug.com/1449461. +TEST_F(SSLServerSocketTest, + HandshakeServerSSLPrivateKeyDisconnectDuringSigning_ReturnsError) { + auto on_sign = base::BindLambdaForTesting([&]() { + client_socket_->Disconnect(); + ASSERT_FALSE(client_socket_->IsConnected()); + }); + server_ssl_private_key_ = base::MakeRefCounted<SSLPrivateKeyHook>( + std::move(server_ssl_private_key_), on_sign); + ASSERT_NO_FATAL_FAILURE(CreateContextSSLPrivateKey()); + ASSERT_NO_FATAL_FAILURE(CreateSockets()); + + TestCompletionCallback handshake_callback; + int server_ret = server_socket_->Handshake(handshake_callback.callback()); + ASSERT_EQ(server_ret, net::ERR_IO_PENDING); + + TestCompletionCallback connect_callback; + client_socket_->Connect(connect_callback.callback()); + + // If resuming the handshake after private-key signing is not handled + // correctly as per crbug.com/1449461 then the test will hang and timeout + // at this point, due to the server-side completion callback not being + // correctly invoked. + server_ret = handshake_callback.GetResult(server_ret); + EXPECT_EQ(server_ret, net::ERR_CONNECTION_CLOSED); +} + // Verifies that non-ECDHE ciphers are disabled when using SSLPrivateKey as the // server key. TEST_F(SSLServerSocketTest, HandshakeServerSSLPrivateKeyRequireEcdhe) { diff --git a/net/socket/tcp_client_socket.cc b/net/socket/tcp_client_socket.cc index a5fc05194..47a6b0dfa 100644 --- a/net/socket/tcp_client_socket.cc +++ b/net/socket/tcp_client_socket.cc @@ -61,11 +61,12 @@ TCPClientSocket::TCPClientSocket(std::unique_ptr<TCPSocket> connected_socket, TCPClientSocket::TCPClientSocket( std::unique_ptr<TCPSocket> unconnected_socket, const AddressList& addresses, + std::unique_ptr<IPEndPoint> bound_address, NetworkQualityEstimator* network_quality_estimator) : TCPClientSocket(std::move(unconnected_socket), addresses, -1 /* current_address_index */, - nullptr /* bind_address */, + std::move(bound_address), network_quality_estimator, handles::kInvalidNetworkHandle) {} diff --git a/net/socket/tcp_client_socket.h b/net/socket/tcp_client_socket.h index b761849ae..200246171 100644 --- a/net/socket/tcp_client_socket.h +++ b/net/socket/tcp_client_socket.h @@ -69,10 +69,11 @@ class NET_EXPORT TCPClientSocket : public TransportClientSocket, TCPClientSocket(std::unique_ptr<TCPSocket> connected_socket, const IPEndPoint& peer_address); - // Adopts an unconnected TCPSocket. This function is used by - // TCPClientSocketBrokered. + // Adopts an unconnected TCPSocket. TCPSocket may be bound or unbound. This + // function is used by TCPClientSocketBrokered. TCPClientSocket(std::unique_ptr<TCPSocket> unconnected_socket, const AddressList& addresses, + std::unique_ptr<IPEndPoint> bound_address, NetworkQualityEstimator* network_quality_estimator); // Creates a TCPClientSocket from a bound-but-not-connected socket. diff --git a/net/socket/tcp_socket_win.cc b/net/socket/tcp_socket_win.cc index de99ac85b..355aa7dae 100644 --- a/net/socket/tcp_socket_win.cc +++ b/net/socket/tcp_socket_win.cc @@ -949,13 +949,15 @@ void TCPSocketWin::DidCompleteConnect() { int rv = WSAEnumNetworkEvents(socket_, core_->read_event_, &events); int os_error = WSAGetLastError(); if (rv == SOCKET_ERROR) { - NOTREACHED(); + DLOG(FATAL) + << "WSAEnumNetworkEvents() failed with SOCKET_ERROR, os_error = " + << os_error; result = MapSystemError(os_error); } else if (events.lNetworkEvents & FD_CONNECT) { os_error = events.iErrorCode[FD_CONNECT_BIT]; result = MapConnectError(os_error); } else { - NOTREACHED(); + DLOG(FATAL) << "WSAEnumNetworkEvents() failed, rv = " << rv; result = ERR_UNEXPECTED; } diff --git a/net/socket/transport_client_socket_pool.cc b/net/socket/transport_client_socket_pool.cc index 03180c6c8..8587f4b9d 100644 --- a/net/socket/transport_client_socket_pool.cc +++ b/net/socket/transport_client_socket_pool.cc @@ -830,8 +830,8 @@ void TransportClientSocketPool::OnSSLConfigChanged( } // TODO(crbug.com/1206799): Get `server` as SchemeHostPort? -void TransportClientSocketPool::OnSSLConfigForServerChanged( - const HostPortPair& server) { +void TransportClientSocketPool::OnSSLConfigForServersChanged( + const base::flat_set<HostPortPair>& servers) { // Current time value. Retrieving it once at the function start rather than // inside the inner loop, since it shouldn't change by any meaningful amount. // @@ -844,12 +844,13 @@ void TransportClientSocketPool::OnSSLConfigForServerChanged( // every group. bool proxy_matches = proxy_server_.is_http_like() && !proxy_server_.is_http() && - proxy_server_.host_port_pair() == server; + servers.contains(proxy_server_.host_port_pair()); bool refreshed_any = false; for (auto it = group_map_.begin(); it != group_map_.end();) { if (proxy_matches || (GURL::SchemeIsCryptographic(it->first.destination().scheme()) && - HostPortPair::FromSchemeHostPort(it->first.destination()) == server)) { + servers.contains( + HostPortPair::FromSchemeHostPort(it->first.destination())))) { refreshed_any = true; // Note this call may destroy the group and invalidate |to_refresh|. it = RefreshGroup(it, now, kSslConfigChanged); diff --git a/net/socket/transport_client_socket_pool.h b/net/socket/transport_client_socket_pool.h index 52ff53207..470c878df 100644 --- a/net/socket/transport_client_socket_pool.h +++ b/net/socket/transport_client_socket_pool.h @@ -265,7 +265,8 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool // SSLClientContext::Observer methods. void OnSSLConfigChanged( SSLClientContext::SSLConfigChangeType change_type) override; - void OnSSLConfigForServerChanged(const HostPortPair& server) override; + void OnSSLConfigForServersChanged( + const base::flat_set<HostPortPair>& servers) override; private: // Entry for a persistent socket which became idle at time |start_time|. diff --git a/net/socket/udp_client_socket.cc b/net/socket/udp_client_socket.cc index 8645f1fc5..ca4a277e7 100644 --- a/net/socket/udp_client_socket.cc +++ b/net/socket/udp_client_socket.cc @@ -18,6 +18,11 @@ UDPClientSocket::UDPClientSocket(DatagramSocket::BindType bind_type, handles::NetworkHandle network) : socket_(bind_type, net_log, source), connect_using_network_(network) {} +UDPClientSocket::UDPClientSocket(DatagramSocket::BindType bind_type, + NetLogWithSource source_net_log, + handles::NetworkHandle network) + : socket_(bind_type, source_net_log), connect_using_network_(network) {} + UDPClientSocket::~UDPClientSocket() = default; int UDPClientSocket::Connect(const IPEndPoint& address) { @@ -26,7 +31,10 @@ int UDPClientSocket::Connect(const IPEndPoint& address) { return ConnectUsingNetwork(connect_using_network_, address); connect_called_ = true; - int rv = socket_.Open(address.GetFamily()); + int rv = OK; + if (!adopted_opened_socket_) { + rv = socket_.Open(address.GetFamily()); + } if (rv != OK) return rv; return socket_.Connect(address); @@ -38,9 +46,13 @@ int UDPClientSocket::ConnectUsingNetwork(handles::NetworkHandle network, connect_called_ = true; if (!NetworkChangeNotifier::AreNetworkHandlesSupported()) return ERR_NOT_IMPLEMENTED; - int rv = socket_.Open(address.GetFamily()); - if (rv != OK) + int rv = OK; + if (!adopted_opened_socket_) { + rv = socket_.Open(address.GetFamily()); + } + if (rv != OK) { return rv; + } rv = socket_.BindToNetwork(network); if (rv != OK) return rv; @@ -53,8 +65,10 @@ int UDPClientSocket::ConnectUsingDefaultNetwork(const IPEndPoint& address) { connect_called_ = true; if (!NetworkChangeNotifier::AreNetworkHandlesSupported()) return ERR_NOT_IMPLEMENTED; - int rv; - rv = socket_.Open(address.GetFamily()); + int rv = OK; + if (!adopted_opened_socket_) { + rv = socket_.Open(address.GetFamily()); + } if (rv != OK) return rv; // Calling connect() will bind a socket to the default network, however there @@ -126,6 +140,7 @@ int UDPClientSocket::Write( void UDPClientSocket::Close() { socket_.Close(); + adopted_opened_socket_ = false; } int UDPClientSocket::GetPeerAddress(IPEndPoint* address) const { @@ -178,9 +193,13 @@ void UDPClientSocket::SetIOSNetworkServiceType(int ios_network_service_type) { #endif } -void UDPClientSocket::AdoptOpenedSocket(AddressFamily address_family, - SocketDescriptor socket) { - socket_.AdoptOpenedSocket(address_family, socket); +int UDPClientSocket::AdoptOpenedSocket(AddressFamily address_family, + SocketDescriptor socket) { + int rv = socket_.AdoptOpenedSocket(address_family, socket); + if (rv == OK) { + adopted_opened_socket_ = true; + } + return rv; } } // namespace net diff --git a/net/socket/udp_client_socket.h b/net/socket/udp_client_socket.h index fc695b7cd..6e3875b4f 100644 --- a/net/socket/udp_client_socket.h +++ b/net/socket/udp_client_socket.h @@ -30,6 +30,11 @@ class NET_EXPORT_PRIVATE UDPClientSocket : public DatagramClientSocket { const net::NetLogSource& source, handles::NetworkHandle network = handles::kInvalidNetworkHandle); + UDPClientSocket( + DatagramSocket::BindType bind_type, + NetLogWithSource source_net_log, + handles::NetworkHandle network = handles::kInvalidNetworkHandle); + UDPClientSocket(const UDPClientSocket&) = delete; UDPClientSocket& operator=(const UDPClientSocket&) = delete; @@ -74,14 +79,34 @@ class NET_EXPORT_PRIVATE UDPClientSocket : public DatagramClientSocket { int SetMulticastInterface(uint32_t interface_index) override; void SetIOSNetworkServiceType(int ios_network_service_type) override; - // Takes ownership of an opened but unconnected and unbound `socket`. - void AdoptOpenedSocket(AddressFamily address_family, SocketDescriptor socket); + // Takes ownership of an opened but unconnected and unbound `socket`. This + // method must be called after UseNonBlockingIO, otherwise the adopted socket + // will not have the non-blocking IO flag set. + int AdoptOpenedSocket(AddressFamily address_family, SocketDescriptor socket); + + uint32_t get_multicast_interface_for_testing() { + return socket_.get_multicast_interface_for_testing(); + } +#if !BUILDFLAG(IS_WIN) + bool get_msg_confirm_for_testing() { + return socket_.get_msg_confirm_for_testing(); + } + bool get_recv_optimization_for_testing() { + return socket_.get_experimental_recv_optimization_enabled_for_testing(); + } +#endif +#if BUILDFLAG(IS_WIN) + bool get_use_non_blocking_io_for_testing() { + return socket_.get_use_non_blocking_io_for_testing(); + } +#endif private: UDPSocket socket_; + bool adopted_opened_socket_ = false; bool connect_called_ = false; // The network the socket is currently bound to. - handles::NetworkHandle network_; + handles::NetworkHandle network_ = handles::kInvalidNetworkHandle; handles::NetworkHandle connect_using_network_; }; diff --git a/net/socket/udp_socket_posix.cc b/net/socket/udp_socket_posix.cc index d7fe60ca6..120ce583c 100644 --- a/net/socket/udp_socket_posix.cc +++ b/net/socket/udp_socket_posix.cc @@ -139,6 +139,22 @@ UDPSocketPosix::UDPSocketPosix(DatagramSocket::BindType bind_type, net_log_.BeginEventReferencingSource(NetLogEventType::SOCKET_ALIVE, source); } +UDPSocketPosix::UDPSocketPosix(DatagramSocket::BindType bind_type, + NetLogWithSource source_net_log) + : socket_(kInvalidSocket), + bind_type_(bind_type), + read_socket_watcher_(FROM_HERE), + write_socket_watcher_(FROM_HERE), + read_watcher_(this), + write_watcher_(this), + net_log_(source_net_log), + bound_network_(handles::kInvalidNetworkHandle), + always_update_bytes_received_(base::FeatureList::IsEnabled( + features::kUdpSocketPosixAlwaysUpdateBytesReceived)) { + net_log_.BeginEventReferencingSource(NetLogEventType::SOCKET_ALIVE, + net_log_.source()); +} + UDPSocketPosix::~UDPSocketPosix() { DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); Close(); diff --git a/net/socket/udp_socket_posix.h b/net/socket/udp_socket_posix.h index ed9c7549c..a0cb1e2f5 100644 --- a/net/socket/udp_socket_posix.h +++ b/net/socket/udp_socket_posix.h @@ -72,6 +72,9 @@ class NET_EXPORT UDPSocketPosix { net::NetLog* net_log, const net::NetLogSource& source); + UDPSocketPosix(DatagramSocket::BindType bind_type, + NetLogWithSource source_net_log); + UDPSocketPosix(const UDPSocketPosix&) = delete; UDPSocketPosix& operator=(const UDPSocketPosix&) = delete; @@ -280,6 +283,14 @@ class NET_EXPORT UDPSocketPosix { // not bound or connected to an address. int AdoptOpenedSocket(AddressFamily address_family, int socket); + uint32_t get_multicast_interface_for_testing() { + return multicast_interface_; + } + bool get_msg_confirm_for_testing() { return sendto_flags_; } + bool get_experimental_recv_optimization_enabled_for_testing() { + return experimental_recv_optimization_enabled_; + } + private: enum SocketOptions { SOCKET_OPTION_MULTICAST_LOOP = 1 << 0 diff --git a/net/socket/udp_socket_win.cc b/net/socket/udp_socket_win.cc index dfc234f5d..32926a00c 100644 --- a/net/socket/udp_socket_win.cc +++ b/net/socket/udp_socket_win.cc @@ -249,6 +249,16 @@ UDPSocketWin::UDPSocketWin(DatagramSocket::BindType bind_type, net_log_.BeginEventReferencingSource(NetLogEventType::SOCKET_ALIVE, source); } +UDPSocketWin::UDPSocketWin(DatagramSocket::BindType bind_type, + NetLogWithSource source_net_log) + : socket_(INVALID_SOCKET), + socket_options_(SOCKET_OPTION_MULTICAST_LOOP), + net_log_(source_net_log) { + EnsureWinsockInit(); + net_log_.BeginEventReferencingSource(NetLogEventType::SOCKET_ALIVE, + net_log_.source()); +} + UDPSocketWin::~UDPSocketWin() { DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); Close(); diff --git a/net/socket/udp_socket_win.h b/net/socket/udp_socket_win.h index 857613564..dd27752c6 100644 --- a/net/socket/udp_socket_win.h +++ b/net/socket/udp_socket_win.h @@ -165,6 +165,9 @@ class NET_EXPORT UDPSocketWin : public base::win::ObjectWatcher::Delegate { net::NetLog* net_log, const net::NetLogSource& source); + UDPSocketWin(DatagramSocket::BindType bind_type, + NetLogWithSource source_net_log); + UDPSocketWin(const UDPSocketWin&) = delete; UDPSocketWin& operator=(const UDPSocketWin&) = delete; @@ -346,8 +349,8 @@ class NET_EXPORT UDPSocketWin : public base::win::ObjectWatcher::Delegate { // Resets the thread to be used for thread-safety checks. void DetachFromThread(); - // This class by default uses overlapped IO. Call this method before Open() - // to switch to non-blocking IO. + // This class by default uses overlapped IO. Call this method before Open() or + // AdoptOpenedSocket() to switch to non-blocking IO. void UseNonBlockingIO(); // Apply |tag| to this socket. @@ -355,9 +358,16 @@ class NET_EXPORT UDPSocketWin : public base::win::ObjectWatcher::Delegate { // Takes ownership of `socket`, which should be a socket descriptor opened // with the specified address family. The socket should only be created but - // not bound or connected to an address. + // not bound or connected to an address. This method must be called after + // UseNonBlockingIO, otherwise the adopted socket will not have the + // non-blocking IO flag set. int AdoptOpenedSocket(AddressFamily address_family, SOCKET socket); + uint32_t get_multicast_interface_for_testing() { + return multicast_interface_; + } + bool get_use_non_blocking_io_for_testing() { return use_non_blocking_io_; } + private: enum SocketOptions { SOCKET_OPTION_MULTICAST_LOOP = 1 << 0 diff --git a/net/socket/websocket_transport_client_socket_pool.cc b/net/socket/websocket_transport_client_socket_pool.cc index b8f15ee7d..286652d99 100644 --- a/net/socket/websocket_transport_client_socket_pool.cc +++ b/net/socket/websocket_transport_client_socket_pool.cc @@ -44,10 +44,10 @@ WebSocketTransportClientSocketPool::WebSocketTransportClientSocketPool( WebSocketTransportClientSocketPool::~WebSocketTransportClientSocketPool() { // Clean up any pending connect jobs. FlushWithError(ERR_ABORTED, ""); - DCHECK(pending_connects_.empty()); - DCHECK_EQ(0, handed_out_socket_count_); - DCHECK(stalled_request_queue_.empty()); - DCHECK(stalled_request_map_.empty()); + CHECK(pending_connects_.empty()); + CHECK_EQ(0, handed_out_socket_count_); + CHECK(stalled_request_queue_.empty()); + CHECK(stalled_request_map_.empty()); } // static @@ -160,8 +160,12 @@ void WebSocketTransportClientSocketPool::CancelRequest( if (socket) ReleaseSocket(handle->group_id(), std::move(socket), handle->group_generation()); - if (!DeleteJob(handle)) + if (DeleteJob(handle)) { + CHECK(!base::Contains(pending_callbacks_, + reinterpret_cast<ClientSocketHandleID>(handle))); + } else { pending_callbacks_.erase(reinterpret_cast<ClientSocketHandleID>(handle)); + } ActivateStalledRequest(); } @@ -331,7 +335,7 @@ void WebSocketTransportClientSocketPool::OnConnectJobComplete( ClientSocketHandle* const handle = connect_job_delegate->socket_handle(); bool delete_succeeded = DeleteJob(handle); - DCHECK(delete_succeeded); + CHECK(delete_succeeded); connect_job_delegate = nullptr; @@ -346,21 +350,24 @@ void WebSocketTransportClientSocketPool::InvokeUserCallbackLater( CompletionOnceCallback callback, int rv) { const auto handle_id = reinterpret_cast<ClientSocketHandleID>(handle); - DCHECK(!pending_callbacks_.count(handle_id)); + CHECK(!pending_callbacks_.count(handle_id)); pending_callbacks_.insert(handle_id); base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( FROM_HERE, base::BindOnce(&WebSocketTransportClientSocketPool::InvokeUserCallback, - weak_factory_.GetWeakPtr(), handle_id, std::move(callback), - rv)); + weak_factory_.GetWeakPtr(), handle_id, + handle->GetWeakPtr(), std::move(callback), rv)); } void WebSocketTransportClientSocketPool::InvokeUserCallback( ClientSocketHandleID handle_id, + base::WeakPtr<ClientSocketHandle> weak_handle, CompletionOnceCallback callback, int rv) { - if (pending_callbacks_.erase(handle_id)) + if (pending_callbacks_.erase(handle_id)) { + CHECK(weak_handle); std::move(callback).Run(rv); + } } bool WebSocketTransportClientSocketPool::ReachedMaxSocketsLimit() const { @@ -396,7 +403,7 @@ void WebSocketTransportClientSocketPool::AddJob( pending_connects_ .insert(PendingConnectsMap::value_type(handle, std::move(delegate))) .second; - DCHECK(inserted); + CHECK(inserted); } bool WebSocketTransportClientSocketPool::DeleteJob(ClientSocketHandle* handle) { diff --git a/net/socket/websocket_transport_client_socket_pool.h b/net/socket/websocket_transport_client_socket_pool.h index fea8b4445..2750816af 100644 --- a/net/socket/websocket_transport_client_socket_pool.h +++ b/net/socket/websocket_transport_client_socket_pool.h @@ -191,6 +191,7 @@ class NET_EXPORT_PRIVATE WebSocketTransportClientSocketPool CompletionOnceCallback callback, int rv); void InvokeUserCallback(ClientSocketHandleID handle_id, + base::WeakPtr<ClientSocketHandle> weak_handle, CompletionOnceCallback callback, int rv); bool ReachedMaxSocketsLimit() const; diff --git a/net/socket/websocket_transport_client_socket_pool_unittest.cc b/net/socket/websocket_transport_client_socket_pool_unittest.cc index e29dc56f5..787d08226 100644 --- a/net/socket/websocket_transport_client_socket_pool_unittest.cc +++ b/net/socket/websocket_transport_client_socket_pool_unittest.cc @@ -4,11 +4,11 @@ #include "net/socket/websocket_transport_client_socket_pool.h" +#include <algorithm> #include <memory> #include <utility> #include <vector> -#include "base/cxx17_backports.h" #include "base/functional/bind.h" #include "base/functional/callback.h" #include "base/functional/callback_helpers.h" |