summaryrefslogtreecommitdiff
path: root/net/socket
diff options
context:
space:
mode:
authorCronet Mainline Eng <cronet-mainline-eng+copybara@google.com>2023-08-14 17:15:38 +0000
committerMohannad Farrag <aymanm@google.com>2023-08-14 17:22:36 +0000
commitec3a8e8db24bb3ce4b078106b358ca1c4389c14f (patch)
tree823f64849ad509483bfebb2252199a5fe79b8e43 /net/socket
parentd12afe756882b2521faa0b33cbd4813fcea04c22 (diff)
downloadcronet-ec3a8e8db24bb3ce4b078106b358ca1c4389c14f.tar.gz
Import Cronet version 117.0.5938.0
Project import generated by Copybara. FolderOrigin-RevId: /tmp/copybara-origin/src Change-Id: Ib7683d0ed240e11ed9068152600c8092afba4571
Diffstat (limited to 'net/socket')
-rw-r--r--net/socket/client_socket_handle.cc1
-rw-r--r--net/socket/client_socket_handle.h7
-rw-r--r--net/socket/client_socket_pool_base_unittest.cc26
-rw-r--r--net/socket/connect_job.h2
-rw-r--r--net/socket/connect_job_factory_unittest.cc3
-rw-r--r--net/socket/socket_bio_adapter.cc116
-rw-r--r--net/socket/socket_bio_adapter.h4
-rw-r--r--net/socket/socket_posix.cc11
-rw-r--r--net/socket/socket_test_util.cc9
-rw-r--r--net/socket/socket_test_util.h12
-rw-r--r--net/socket/socks_client_socket.h2
-rw-r--r--net/socket/ssl_client_socket.cc35
-rw-r--r--net/socket/ssl_client_socket.h26
-rw-r--r--net/socket/ssl_client_socket_impl.cc62
-rw-r--r--net/socket/ssl_client_socket_impl.h1
-rw-r--r--net/socket/ssl_client_socket_unittest.cc232
-rw-r--r--net/socket/ssl_server_socket_impl.cc14
-rw-r--r--net/socket/ssl_server_socket_unittest.cc76
-rw-r--r--net/socket/tcp_client_socket.cc3
-rw-r--r--net/socket/tcp_client_socket.h5
-rw-r--r--net/socket/tcp_socket_win.cc6
-rw-r--r--net/socket/transport_client_socket_pool.cc9
-rw-r--r--net/socket/transport_client_socket_pool.h3
-rw-r--r--net/socket/udp_client_socket.cc35
-rw-r--r--net/socket/udp_client_socket.h31
-rw-r--r--net/socket/udp_socket_posix.cc16
-rw-r--r--net/socket/udp_socket_posix.h11
-rw-r--r--net/socket/udp_socket_win.cc10
-rw-r--r--net/socket/udp_socket_win.h16
-rw-r--r--net/socket/websocket_transport_client_socket_pool.cc29
-rw-r--r--net/socket/websocket_transport_client_socket_pool.h1
-rw-r--r--net/socket/websocket_transport_client_socket_pool_unittest.cc2
32 files changed, 587 insertions, 229 deletions
diff --git a/net/socket/client_socket_handle.cc b/net/socket/client_socket_handle.cc
index 653b469fc..35d15c24b 100644
--- a/net/socket/client_socket_handle.cc
+++ b/net/socket/client_socket_handle.cc
@@ -24,6 +24,7 @@ ClientSocketHandle::ClientSocketHandle()
: resolve_error_info_(ResolveErrorInfo(OK)) {}
ClientSocketHandle::~ClientSocketHandle() {
+ weak_factory_.InvalidateWeakPtrs();
Reset();
}
diff --git a/net/socket/client_socket_handle.h b/net/socket/client_socket_handle.h
index b6b71c284..2c7b2b662 100644
--- a/net/socket/client_socket_handle.h
+++ b/net/socket/client_socket_handle.h
@@ -12,6 +12,7 @@
#include "base/functional/bind.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/scoped_refptr.h"
+#include "base/memory/weak_ptr.h"
#include "base/time/time.h"
#include "net/base/ip_endpoint.h"
#include "net/base/load_states.h"
@@ -209,6 +210,10 @@ class NET_EXPORT ClientSocketHandle {
connect_timing_ = connect_timing;
}
+ base::WeakPtr<ClientSocketHandle> GetWeakPtr() {
+ return weak_factory_.GetWeakPtr();
+ }
+
private:
// Called on asynchronous completion of an Init() request.
void OnIOComplete(int result);
@@ -247,6 +252,8 @@ class NET_EXPORT ClientSocketHandle {
// Timing information is set when a connection is successfully established.
LoadTimingInfo::ConnectTiming connect_timing_;
+
+ base::WeakPtrFactory<ClientSocketHandle> weak_factory_{this};
};
} // namespace net
diff --git a/net/socket/client_socket_pool_base_unittest.cc b/net/socket/client_socket_pool_base_unittest.cc
index 212b055c7..eeba63b47 100644
--- a/net/socket/client_socket_pool_base_unittest.cc
+++ b/net/socket/client_socket_pool_base_unittest.cc
@@ -5708,7 +5708,7 @@ enum class RefreshType {
};
// Common base class to test RefreshGroup() when called from either
-// OnSSLConfigForServerChanged() matching a specific group or the pool's proxy.
+// OnSSLConfigForServersChanged() matching a specific group or the pool's proxy.
//
// Tests which test behavior specific to one or the other case should use
// ClientSocketPoolBaseTest directly. In particular, there is no "other group"
@@ -5750,13 +5750,13 @@ class ClientSocketPoolBaseRefreshTest
kNetworkAnonymizationKey);
}
- void OnSSLConfigForServerChanged() {
+ void OnSSLConfigForServersChanged() {
switch (GetParam()) {
case RefreshType::kServer:
- pool_->OnSSLConfigForServerChanged(HostPortPair("a", 443));
+ pool_->OnSSLConfigForServersChanged({HostPortPair("a", 443)});
break;
case RefreshType::kProxy:
- pool_->OnSSLConfigForServerChanged(HostPortPair("myproxy", 70));
+ pool_->OnSSLConfigForServersChanged({HostPortPair("myproxy", 70)});
break;
}
}
@@ -5787,7 +5787,7 @@ TEST_P(ClientSocketPoolBaseRefreshTest, RefreshGroupCreatesNewConnectJobs) {
// success.
connect_job_factory_->set_job_type(TestConnectJob::kMockJob);
- OnSSLConfigForServerChanged();
+ OnSSLConfigForServersChanged();
EXPECT_EQ(OK, callback.WaitForResult());
ASSERT_TRUE(handle.socket());
EXPECT_EQ(0, pool_->IdleSocketCount());
@@ -5819,7 +5819,7 @@ TEST_P(ClientSocketPoolBaseRefreshTest, RefreshGroupClosesIdleConnectJobs) {
EXPECT_EQ(2u, pool_->IdleSocketCountInGroup(kGroupId));
EXPECT_EQ(2u, pool_->IdleSocketCountInGroup(kGroupIdInPartition));
- OnSSLConfigForServerChanged();
+ OnSSLConfigForServersChanged();
EXPECT_EQ(0, pool_->IdleSocketCount());
EXPECT_FALSE(pool_->HasGroupForTesting(kGroupId));
EXPECT_FALSE(pool_->HasGroupForTesting(kGroupIdInPartition));
@@ -5840,7 +5840,7 @@ TEST_F(ClientSocketPoolBaseTest,
EXPECT_EQ(2, pool_->IdleSocketCount());
EXPECT_EQ(2u, pool_->IdleSocketCountInGroup(kOtherGroupId));
- pool_->OnSSLConfigForServerChanged(HostPortPair("a", 443));
+ pool_->OnSSLConfigForServersChanged({HostPortPair("a", 443)});
ASSERT_TRUE(pool_->HasGroupForTesting(kOtherGroupId));
EXPECT_EQ(2, pool_->IdleSocketCount());
EXPECT_EQ(2u, pool_->IdleSocketCountInGroup(kOtherGroupId));
@@ -5861,7 +5861,7 @@ TEST_P(ClientSocketPoolBaseRefreshTest, RefreshGroupPreventsSocketReuse) {
ASSERT_TRUE(pool_->HasGroupForTesting(kGroupId));
EXPECT_EQ(1, pool_->NumActiveSocketsInGroupForTesting(kGroupId));
- OnSSLConfigForServerChanged();
+ OnSSLConfigForServersChanged();
handle.Reset();
EXPECT_EQ(0, pool_->IdleSocketCount());
@@ -5887,7 +5887,7 @@ TEST_F(ClientSocketPoolBaseTest,
ASSERT_TRUE(pool_->HasGroupForTesting(kOtherGroupId));
EXPECT_EQ(1, pool_->NumActiveSocketsInGroupForTesting(kOtherGroupId));
- pool_->OnSSLConfigForServerChanged(HostPortPair("a", 443));
+ pool_->OnSSLConfigForServersChanged({HostPortPair("a", 443)});
handle.Reset();
EXPECT_EQ(1, pool_->IdleSocketCount());
@@ -5910,7 +5910,7 @@ TEST_P(ClientSocketPoolBaseRefreshTest,
// This should update the generation, but not cancel the old ConnectJob - it's
// not safe to do anything while waiting on the original ConnectJob.
- OnSSLConfigForServerChanged();
+ OnSSLConfigForServersChanged();
// Providing auth credentials and restarting the request with them will cause
// the ConnectJob to complete successfully, but the result will be discarded
@@ -5980,7 +5980,7 @@ TEST_F(ClientSocketPoolBaseTest, RefreshProxyRefreshesAllGroups) {
// Changes to some other proxy do not affect the pool. The idle socket remains
// alive and closing |handle2| makes the socket available for the pool.
- pool_->OnSSLConfigForServerChanged(HostPortPair("someotherproxy", 70));
+ pool_->OnSSLConfigForServersChanged({HostPortPair("someotherproxy", 70)});
ASSERT_TRUE(pool_->HasGroupForTesting(kGroupId1));
EXPECT_EQ(1, pool_->NumActiveSocketsInGroupForTesting(kGroupId1));
@@ -5994,7 +5994,7 @@ TEST_F(ClientSocketPoolBaseTest, RefreshProxyRefreshesAllGroups) {
EXPECT_EQ(1u, pool_->IdleSocketCountInGroup(kGroupId2));
// Changes to the matching proxy refreshes all groups.
- pool_->OnSSLConfigForServerChanged(HostPortPair("myproxy", 70));
+ pool_->OnSSLConfigForServersChanged({HostPortPair("myproxy", 70)});
// Idle sockets are closed.
EXPECT_EQ(0, pool_->IdleSocketCount());
@@ -6049,7 +6049,7 @@ TEST_F(ClientSocketPoolBaseTest, RefreshBothPrivacyAndNormalSockets) {
ASSERT_TRUE(pool_->HasGroupForTesting(kOtherGroupId));
EXPECT_EQ(1, pool_->NumActiveSocketsInGroupForTesting(kOtherGroupId));
- pool_->OnSSLConfigForServerChanged(HostPortPair("a", 443));
+ pool_->OnSSLConfigForServersChanged({HostPortPair("a", 443)});
// Active sockets continue to be active.
ASSERT_TRUE(pool_->HasGroupForTesting(kGroupId));
diff --git a/net/socket/connect_job.h b/net/socket/connect_job.h
index 35c0f4992..2f810362a 100644
--- a/net/socket/connect_job.h
+++ b/net/socket/connect_job.h
@@ -110,7 +110,7 @@ enum class OnHostResolutionCallbackResult {
// ConnectJob synchronously, but may signal the ConnectJob may be destroyed
// asynchronously. See OnHostResolutionCallbackResult above.
//
-// |address_list| is the list of addresses the host being connected to was
+// `endpoint_results` is the list of endpoints the host being connected to was
// resolved to, with the port fields populated to the port being connected to.
using OnHostResolutionCallback =
base::RepeatingCallback<OnHostResolutionCallbackResult(
diff --git a/net/socket/connect_job_factory_unittest.cc b/net/socket/connect_job_factory_unittest.cc
index 45d599223..d3b0268af 100644
--- a/net/socket/connect_job_factory_unittest.cc
+++ b/net/socket/connect_job_factory_unittest.cc
@@ -185,12 +185,11 @@ class ConnectJobFactoryTest : public TestWithTaskEnvironment {
/*websocket_endpoint_lock_manager=*/nullptr};
TestConnectJobDelegate delegate_;
+ std::unique_ptr<ConnectJobFactory> factory_;
raw_ptr<TestHttpProxyConnectJobFactory> http_proxy_job_factory_;
raw_ptr<TestSocksConnectJobFactory> socks_job_factory_;
raw_ptr<TestSslConnectJobFactory> ssl_job_factory_;
raw_ptr<TestTransportConnectJobFactory> transport_job_factory_;
-
- std::unique_ptr<ConnectJobFactory> factory_;
};
TEST_F(ConnectJobFactoryTest, CreateConnectJob) {
diff --git a/net/socket/socket_bio_adapter.cc b/net/socket/socket_bio_adapter.cc
index 8017b593b..4126ea309 100644
--- a/net/socket/socket_bio_adapter.cc
+++ b/net/socket/socket_bio_adapter.cc
@@ -4,11 +4,13 @@
#include "net/socket/socket_bio_adapter.h"
+#include <stdio.h>
#include <string.h>
#include <algorithm>
#include "base/check_op.h"
+#include "base/debug/alias.h"
#include "base/functional/bind.h"
#include "base/location.h"
#include "base/notreached.h"
@@ -64,9 +66,9 @@ SocketBIOAdapter::SocketBIOAdapter(StreamSocket* socket,
read_buffer_capacity_(read_buffer_capacity),
write_buffer_capacity_(write_buffer_capacity),
delegate_(delegate) {
- bio_.reset(BIO_new(&kBIOMethod));
- bio_->ptr = this;
- bio_->init = 1;
+ bio_.reset(BIO_new(BIOMethod()));
+ BIO_set_data(bio_.get(), this);
+ BIO_set_init(bio_.get(), 1);
read_callback_ = base::BindRepeating(&SocketBIOAdapter::OnSocketReadComplete,
weak_factory_.GetWeakPtr());
@@ -75,16 +77,19 @@ SocketBIOAdapter::SocketBIOAdapter(StreamSocket* socket,
}
SocketBIOAdapter::~SocketBIOAdapter() {
+ DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
// BIOs are reference-counted and may outlive the adapter. Clear the pointer
// so future operations fail.
- bio_->ptr = nullptr;
+ BIO_set_data(bio_.get(), nullptr);
}
bool SocketBIOAdapter::HasPendingReadData() {
+ DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
return read_result_ > 0;
}
size_t SocketBIOAdapter::GetAllocationSize() const {
+ DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
size_t buffer_size = 0;
if (read_buffer_)
buffer_size += read_buffer_capacity_;
@@ -95,6 +100,7 @@ size_t SocketBIOAdapter::GetAllocationSize() const {
}
int SocketBIOAdapter::BIORead(char* out, int len) {
+ DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (len <= 0)
return len;
@@ -115,9 +121,10 @@ int SocketBIOAdapter::BIORead(char* out, int len) {
// layer reads the record header and body in separate reads to avoid
// overreading, but issuing one is more efficient. SSL sockets are not
// reused after shutdown for non-SSL traffic, so overreading is fine.
- DCHECK(!read_buffer_);
- DCHECK_EQ(0, read_offset_);
+ CHECK(!read_buffer_);
+ CHECK_EQ(0, read_offset_);
read_buffer_ = base::MakeRefCounted<IOBuffer>(read_buffer_capacity_);
+ read_result_ = ERR_IO_PENDING;
int result = socket_->ReadIfReady(
read_buffer_.get(), read_buffer_capacity_,
base::BindOnce(&SocketBIOAdapter::OnSocketReadIfReadyComplete,
@@ -128,9 +135,8 @@ int SocketBIOAdapter::BIORead(char* out, int len) {
result = socket_->Read(read_buffer_.get(), read_buffer_capacity_,
read_callback_);
}
- if (result == ERR_IO_PENDING) {
- read_result_ = ERR_IO_PENDING;
- } else {
+ if (result != ERR_IO_PENDING) {
+ // `HandleSocketReadResult` will update `read_result_` based on `result`.
HandleSocketReadResult(result);
}
}
@@ -164,7 +170,9 @@ int SocketBIOAdapter::BIORead(char* out, int len) {
}
void SocketBIOAdapter::HandleSocketReadResult(int result) {
- DCHECK_NE(ERR_IO_PENDING, result);
+ DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
+ CHECK_NE(ERR_IO_PENDING, result);
+ CHECK_EQ(ERR_IO_PENDING, read_result_);
// If an EOF, canonicalize to ERR_CONNECTION_CLOSED here, so that higher
// levels don't report success.
@@ -179,15 +187,17 @@ void SocketBIOAdapter::HandleSocketReadResult(int result) {
}
void SocketBIOAdapter::OnSocketReadComplete(int result) {
- DCHECK_EQ(ERR_IO_PENDING, read_result_);
+ DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
+ CHECK_EQ(ERR_IO_PENDING, read_result_);
HandleSocketReadResult(result);
delegate_->OnReadReady();
}
void SocketBIOAdapter::OnSocketReadIfReadyComplete(int result) {
- DCHECK_EQ(ERR_IO_PENDING, read_result_);
- DCHECK_GE(OK, result);
+ DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
+ CHECK_EQ(ERR_IO_PENDING, read_result_);
+ CHECK_GE(OK, result);
// Do not use HandleSocketReadResult() because result == OK doesn't mean EOF.
read_result_ = result;
@@ -196,12 +206,13 @@ void SocketBIOAdapter::OnSocketReadIfReadyComplete(int result) {
}
int SocketBIOAdapter::BIOWrite(const char* in, int len) {
+ DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (len <= 0)
return len;
// If the write buffer is not empty, there must be a pending Write() to flush
// it.
- DCHECK(write_buffer_used_ == 0 || write_error_ == ERR_IO_PENDING);
+ CHECK(write_buffer_used_ == 0 || write_error_ == ERR_IO_PENDING);
// If a previous Write() failed, report the error.
if (write_error_ != OK && write_error_ != ERR_IO_PENDING) {
@@ -211,7 +222,7 @@ int SocketBIOAdapter::BIOWrite(const char* in, int len) {
// Instantiate the write buffer if needed.
if (!write_buffer_) {
- DCHECK_EQ(0, write_buffer_used_);
+ CHECK_EQ(0, write_buffer_used_);
write_buffer_ = base::MakeRefCounted<GrowableIOBuffer>();
write_buffer_->SetCapacity(write_buffer_capacity_);
}
@@ -250,7 +261,7 @@ int SocketBIOAdapter::BIOWrite(const char* in, int len) {
}
// Either the buffer is now full or there is no more input.
- DCHECK(len == 0 || write_buffer_used_ == write_buffer_->capacity());
+ CHECK(len == 0 || write_buffer_used_ == write_buffer_->capacity());
// Schedule a socket Write() if necessary. (The ring buffer may previously
// have been empty.)
@@ -270,22 +281,45 @@ int SocketBIOAdapter::BIOWrite(const char* in, int len) {
}
void SocketBIOAdapter::SocketWrite() {
+ DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
while (write_error_ == OK && write_buffer_used_ > 0) {
+ int write_buffer_used_old = write_buffer_used_;
int write_size =
std::min(write_buffer_used_, write_buffer_->RemainingCapacity());
+
+ // TODO(crbug.com/1440692): Remove this once the crash is resolved.
+ char debug[128];
+ snprintf(debug, sizeof(debug),
+ "offset=%d;remaining=%d;used=%d;write_size=%d",
+ write_buffer_->offset(), write_buffer_->RemainingCapacity(),
+ write_buffer_used_, write_size);
+ base::debug::Alias(debug);
+
+ write_error_ = ERR_IO_PENDING;
int result = socket_->Write(write_buffer_.get(), write_size,
write_callback_, kTrafficAnnotation);
- if (result == ERR_IO_PENDING) {
- write_error_ = ERR_IO_PENDING;
- return;
- }
- HandleSocketWriteResult(result);
+ // TODO(crbug.com/1440692): Remove this once the crash is resolved.
+ char debug2[32];
+ snprintf(debug2, sizeof(debug2), "result=%d", result);
+ base::debug::Alias(debug2);
+
+ // If `write_buffer_used_` changed across a call to the underlying socket,
+ // something went very wrong.
+ //
+ // TODO(crbug.com/1440692): Remove this once the crash is resolved.
+ CHECK_EQ(write_buffer_used_old, write_buffer_used_);
+ if (result != ERR_IO_PENDING) {
+ // `HandleSocketWriteResult` will update `write_error_` based on `result.
+ HandleSocketWriteResult(result);
+ }
}
}
void SocketBIOAdapter::HandleSocketWriteResult(int result) {
- DCHECK_NE(ERR_IO_PENDING, result);
+ DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
+ CHECK_NE(ERR_IO_PENDING, result);
+ CHECK_EQ(ERR_IO_PENDING, write_error_);
if (result < 0) {
write_error_ = result;
@@ -297,6 +331,8 @@ void SocketBIOAdapter::HandleSocketWriteResult(int result) {
}
// Advance the ring buffer.
+ CHECK_LE(result, write_buffer_used_);
+ CHECK_LE(result, write_buffer_->RemainingCapacity());
write_buffer_->set_offset(write_buffer_->offset() + result);
write_buffer_used_ -= result;
if (write_buffer_->RemainingCapacity() == 0)
@@ -309,7 +345,8 @@ void SocketBIOAdapter::HandleSocketWriteResult(int result) {
}
void SocketBIOAdapter::OnSocketWriteComplete(int result) {
- DCHECK_EQ(ERR_IO_PENDING, write_error_);
+ DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
+ CHECK_EQ(ERR_IO_PENDING, write_error_);
bool was_full = write_buffer_used_ == write_buffer_->capacity();
@@ -333,15 +370,17 @@ void SocketBIOAdapter::OnSocketWriteComplete(int result) {
}
void SocketBIOAdapter::CallOnReadReady() {
+ DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (read_result_ == ERR_IO_PENDING)
delegate_->OnReadReady();
}
SocketBIOAdapter* SocketBIOAdapter::GetAdapter(BIO* bio) {
- DCHECK_EQ(&kBIOMethod, bio->method);
- SocketBIOAdapter* adapter = reinterpret_cast<SocketBIOAdapter*>(bio->ptr);
- if (adapter)
- DCHECK_EQ(bio, adapter->bio());
+ SocketBIOAdapter* adapter =
+ reinterpret_cast<SocketBIOAdapter*>(BIO_get_data(bio));
+ if (adapter) {
+ CHECK_EQ(bio, adapter->bio());
+ }
return adapter;
}
@@ -383,17 +422,16 @@ long SocketBIOAdapter::BIOCtrlWrapper(BIO* bio,
return 0;
}
-const BIO_METHOD SocketBIOAdapter::kBIOMethod = {
- 0, // type (unused)
- nullptr, // name (unused)
- SocketBIOAdapter::BIOWriteWrapper,
- SocketBIOAdapter::BIOReadWrapper,
- nullptr, // puts
- nullptr, // gets
- SocketBIOAdapter::BIOCtrlWrapper,
- nullptr, // create
- nullptr, // destroy
- nullptr, // callback_ctrl
-};
+const BIO_METHOD* SocketBIOAdapter::BIOMethod() {
+ static const BIO_METHOD* kMethod = []() {
+ BIO_METHOD* method = BIO_meth_new(0, nullptr);
+ CHECK(method);
+ CHECK(BIO_meth_set_write(method, SocketBIOAdapter::BIOWriteWrapper));
+ CHECK(BIO_meth_set_read(method, SocketBIOAdapter::BIOReadWrapper));
+ CHECK(BIO_meth_set_ctrl(method, SocketBIOAdapter::BIOCtrlWrapper));
+ return method;
+ }();
+ return kMethod;
+}
} // namespace net
diff --git a/net/socket/socket_bio_adapter.h b/net/socket/socket_bio_adapter.h
index e22ab6992..06212a7a1 100644
--- a/net/socket/socket_bio_adapter.h
+++ b/net/socket/socket_bio_adapter.h
@@ -8,6 +8,7 @@
#include "base/memory/raw_ptr.h"
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
+#include "base/sequence_checker.h"
#include "net/base/completion_repeating_callback.h"
#include "net/base/net_errors.h"
#include "net/base/net_export.h"
@@ -104,7 +105,7 @@ class NET_EXPORT_PRIVATE SocketBIOAdapter {
static int BIOWriteWrapper(BIO* bio, const char* in, int len);
static long BIOCtrlWrapper(BIO* bio, int cmd, long larg, void* parg);
- static const BIO_METHOD kBIOMethod;
+ static const BIO_METHOD* BIOMethod();
bssl::UniquePtr<BIO> bio_;
@@ -142,6 +143,7 @@ class NET_EXPORT_PRIVATE SocketBIOAdapter {
raw_ptr<Delegate> delegate_;
+ SEQUENCE_CHECKER(sequence_checker_);
base::WeakPtrFactory<SocketBIOAdapter> weak_factory_{this};
};
diff --git a/net/socket/socket_posix.cc b/net/socket/socket_posix.cc
index ea3d794d9..14f09938c 100644
--- a/net/socket/socket_posix.cc
+++ b/net/socket/socket_posix.cc
@@ -325,12 +325,12 @@ int SocketPosix::Write(
CompletionOnceCallback callback,
const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
DCHECK(thread_checker_.CalledOnValidThread());
- DCHECK_NE(kInvalidSocket, socket_fd_);
- DCHECK(!waiting_connect_);
+ CHECK_NE(kInvalidSocket, socket_fd_);
+ CHECK(!waiting_connect_);
CHECK(write_callback_.is_null());
// Synchronous operation not supported
- DCHECK(!callback.is_null());
- DCHECK_LT(0, buf_len);
+ CHECK(!callback.is_null());
+ CHECK_LT(0, buf_len);
int rv = DoWrite(buf, buf_len);
if (rv == ERR_IO_PENDING)
@@ -525,6 +525,9 @@ int SocketPosix::DoWrite(IOBuffer* buf, int buf_len) {
#else
int rv = HANDLE_EINTR(write(socket_fd_, buf->data(), buf_len));
#endif
+ if (rv >= 0) {
+ CHECK_LE(rv, buf_len);
+ }
return rv >= 0 ? rv : MapSystemError(errno);
}
diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc
index 441af898c..acaf5868c 100644
--- a/net/socket/socket_test_util.cc
+++ b/net/socket/socket_test_util.cc
@@ -47,6 +47,7 @@
#include "net/traffic_annotation/network_traffic_annotation.h"
#include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
#include "testing/gtest/include/gtest/gtest.h"
+#include "third_party/abseil-cpp/absl/strings/ascii.h"
#if BUILDFLAG(IS_ANDROID)
#include "base/android/build_info.h"
@@ -68,9 +69,7 @@ inline char AsciifyLow(char x) {
}
inline char Asciify(char x) {
- if ((x < 0) || !isprint(x))
- return '.';
- return x;
+ return absl::ascii_isprint(static_cast<unsigned char>(x)) ? x : '.';
}
void DumpData(const char* data, int data_len) {
@@ -1443,8 +1442,8 @@ void MockSSLClientSocket::GetSSLCertRequestInfo(
cert_request_info->is_proxy = data_->cert_request_info->is_proxy;
cert_request_info->cert_authorities =
data_->cert_request_info->cert_authorities;
- cert_request_info->cert_key_types =
- data_->cert_request_info->cert_key_types;
+ cert_request_info->signature_algorithms =
+ data_->cert_request_info->signature_algorithms;
} else {
cert_request_info->Reset();
}
diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h
index cf8115ec1..5119e6619 100644
--- a/net/socket/socket_test_util.h
+++ b/net/socket/socket_test_util.h
@@ -495,9 +495,7 @@ struct SSLSocketDataProvider {
SSLInfo ssl_info;
// Result for GetSSLCertRequestInfo().
- // This field is not a raw_ptr<> because it was filtered by the rewriter for:
- // #union
- RAW_PTR_EXCLUSION SSLCertRequestInfo* cert_request_info = nullptr;
+ scoped_refptr<SSLCertRequestInfo> cert_request_info;
// Result for GetECHRetryConfigs().
std::vector<uint8_t> ech_retry_configs;
@@ -949,7 +947,7 @@ class MockSSLClientSocket : public AsyncSocket, public SSLClientSocket {
bool in_confirm_handshake_ = false;
NetLogWithSource net_log_;
std::unique_ptr<StreamSocket> stream_socket_;
- raw_ptr<SSLSocketDataProvider> data_;
+ raw_ptr<SSLSocketDataProvider, AcrossTasksDanglingUntriaged> data_;
// Address of the "remote" peer we're connected to.
IPEndPoint peer_addr_;
@@ -1355,8 +1353,10 @@ class MockTaggingClientSocketFactory : public MockClientSocketFactory {
MockUDPClientSocket* GetLastProducedUDPSocket() const { return udp_socket_; }
private:
- raw_ptr<MockTaggingStreamSocket> tcp_socket_ = nullptr;
- raw_ptr<MockUDPClientSocket> udp_socket_ = nullptr;
+ raw_ptr<MockTaggingStreamSocket, AcrossTasksDanglingUntriaged> tcp_socket_ =
+ nullptr;
+ raw_ptr<MockUDPClientSocket, AcrossTasksDanglingUntriaged> udp_socket_ =
+ nullptr;
};
// Host / port used for SOCKS4 test strings.
diff --git a/net/socket/socks_client_socket.h b/net/socket/socks_client_socket.h
index e606c591c..e12475f4e 100644
--- a/net/socket/socks_client_socket.h
+++ b/net/socket/socks_client_socket.h
@@ -143,7 +143,7 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket {
bool was_ever_used_ = false;
// Used to resolve the hostname to which the SOCKS proxy will connect.
- raw_ptr<HostResolver> host_resolver_;
+ raw_ptr<HostResolver, DanglingUntriaged> host_resolver_;
SecureDnsPolicy secure_dns_policy_;
std::unique_ptr<HostResolver::ResolveHostRequest> resolve_host_request_;
const HostPortPair destination_;
diff --git a/net/socket/ssl_client_socket.cc b/net/socket/ssl_client_socket.cc
index 4dd73b5c3..8290b9592 100644
--- a/net/socket/ssl_client_socket.cc
+++ b/net/socket/ssl_client_socket.cc
@@ -6,6 +6,7 @@
#include <string>
+#include "base/containers/flat_tree.h"
#include "base/logging.h"
#include "base/observer_list.h"
#include "net/socket/ssl_client_socket_impl.h"
@@ -103,9 +104,9 @@ void SSLClientContext::SetClientCertificate(
if (ssl_client_session_cache_) {
// Session resumption bypasses client certificate negotiation, so flush all
// associated sessions when preferences change.
- ssl_client_session_cache_->FlushForServer(server);
+ ssl_client_session_cache_->FlushForServers({server});
}
- NotifySSLConfigForServerChanged(server);
+ NotifySSLConfigForServersChanged({server});
}
bool SSLClientContext::ClearClientCertificate(const HostPortPair& server) {
@@ -116,9 +117,9 @@ bool SSLClientContext::ClearClientCertificate(const HostPortPair& server) {
if (ssl_client_session_cache_) {
// Session resumption bypasses client certificate negotiation, so flush all
// associated sessions when preferences change.
- ssl_client_session_cache_->FlushForServer(server);
+ ssl_client_session_cache_->FlushForServers({server});
}
- NotifySSLConfigForServerChanged(server);
+ NotifySSLConfigForServersChanged({server});
return true;
}
@@ -131,11 +132,10 @@ void SSLClientContext::RemoveObserver(Observer* observer) {
}
void SSLClientContext::OnSSLContextConfigChanged() {
- // TODO(davidben): Should we flush |ssl_client_session_cache_| here? We flush
- // the socket pools, but not the session cache. While BoringSSL-based servers
- // never change version or cipher negotiation based on client-offered
- // sessions, other servers do.
config_ = ssl_config_service_->GetSSLContextConfig();
+ if (ssl_client_session_cache_) {
+ ssl_client_session_cache_->Flush();
+ }
NotifySSLConfigChanged(SSLConfigChangeType::kSSLConfigChanged);
}
@@ -143,13 +143,18 @@ void SSLClientContext::OnCertVerifierChanged() {
NotifySSLConfigChanged(SSLConfigChangeType::kCertVerifierChanged);
}
-void SSLClientContext::OnCertDBChanged() {
- // Both the trust store and client certificate store may have changed.
+void SSLClientContext::OnTrustStoreChanged() {
+ NotifySSLConfigChanged(SSLConfigChangeType::kCertDatabaseChanged);
+}
+
+void SSLClientContext::OnClientCertStoreChanged() {
+ base::flat_set<HostPortPair> servers =
+ ssl_client_auth_cache_.GetCachedServers();
ssl_client_auth_cache_.Clear();
if (ssl_client_session_cache_) {
- ssl_client_session_cache_->Flush();
+ ssl_client_session_cache_->FlushForServers(servers);
}
- NotifySSLConfigChanged(SSLConfigChangeType::kCertDatabaseChanged);
+ NotifySSLConfigForServersChanged(servers);
}
void SSLClientContext::NotifySSLConfigChanged(SSLConfigChangeType change_type) {
@@ -158,10 +163,10 @@ void SSLClientContext::NotifySSLConfigChanged(SSLConfigChangeType change_type) {
}
}
-void SSLClientContext::NotifySSLConfigForServerChanged(
- const HostPortPair& server) {
+void SSLClientContext::NotifySSLConfigForServersChanged(
+ const base::flat_set<HostPortPair>& servers) {
for (Observer& observer : observers_) {
- observer.OnSSLConfigForServerChanged(server);
+ observer.OnSSLConfigForServersChanged(servers);
}
}
diff --git a/net/socket/ssl_client_socket.h b/net/socket/ssl_client_socket.h
index ddbbf708c..f34320c57 100644
--- a/net/socket/ssl_client_socket.h
+++ b/net/socket/ssl_client_socket.h
@@ -10,6 +10,7 @@
#include <memory>
#include <vector>
+#include "base/containers/flat_set.h"
#include "base/gtest_prod_util.h"
#include "base/memory/raw_ptr.h"
#include "base/observer_list.h"
@@ -101,11 +102,13 @@ class NET_EXPORT SSLClientContext : public SSLConfigService::Observer,
// Called when SSL configuration for all hosts changed. Newly-created
// SSLClientSockets will pick up the new configuration. Note that changes
// which only apply to one server will result in a call to
- // OnSSLConfigForServerChanged() instead.
+ // OnSSLConfigForServersChanged() instead.
virtual void OnSSLConfigChanged(SSLConfigChangeType change_type) = 0;
- // Called when SSL configuration for |server| changed. Newly-created
- // SSLClientSockets to |server| will pick up the new configuration.
- virtual void OnSSLConfigForServerChanged(const HostPortPair& server) = 0;
+ // Called when SSL configuration for |servers| changed. Newly-created
+ // SSLClientSockets to any server in |servers| will pick up the new
+ // configuration.
+ virtual void OnSSLConfigForServersChanged(
+ const base::flat_set<HostPortPair>& servers) = 0;
};
// Creates a new SSLClientContext with the specified parameters. The
@@ -164,7 +167,7 @@ class NET_EXPORT SSLClientContext : public SSLConfigService::Observer,
// |private_key| may be null to indicate that no client certificate should be
// sent to |server|.
//
- // Note this method will synchronously call OnSSLConfigForServerChanged() on
+ // Note this method will synchronously call OnSSLConfigForServersChanged() on
// observers.
void SetClientCertificate(const HostPortPair& server,
scoped_refptr<X509Certificate> client_cert,
@@ -174,10 +177,15 @@ class NET_EXPORT SSLClientContext : public SSLConfigService::Observer,
// SetClientCertificate(). Returns true if one was removed and false
// otherwise.
//
- // Note this method will synchronously call OnSSLConfigForServerChanged() on
+ // Note this method will synchronously call OnSSLConfigForServersChanged() on
// observers.
bool ClearClientCertificate(const HostPortPair& server);
+ base::flat_set<HostPortPair> GetClientCertificateCachedServersForTesting()
+ const {
+ return ssl_client_auth_cache_.GetCachedServers();
+ }
+
// Add an observer to be notified when configuration has changed.
// RemoveObserver() must be called before |observer| is destroyed.
void AddObserver(Observer* observer);
@@ -192,11 +200,13 @@ class NET_EXPORT SSLClientContext : public SSLConfigService::Observer,
void OnCertVerifierChanged() override;
// CertDatabase::Observer:
- void OnCertDBChanged() override;
+ void OnTrustStoreChanged() override;
+ void OnClientCertStoreChanged() override;
private:
void NotifySSLConfigChanged(SSLConfigChangeType change_type);
- void NotifySSLConfigForServerChanged(const HostPortPair& server);
+ void NotifySSLConfigForServersChanged(
+ const base::flat_set<HostPortPair>& servers);
SSLContextConfig config_;
diff --git a/net/socket/ssl_client_socket_impl.cc b/net/socket/ssl_client_socket_impl.cc
index 31256e791..9d7134c1c 100644
--- a/net/socket/ssl_client_socket_impl.cc
+++ b/net/socket/ssl_client_socket_impl.cc
@@ -624,22 +624,11 @@ void SSLClientSocketImpl::GetSSLCertRequestInfo(
CRYPTO_BUFFER_len(ca_name));
}
- cert_request_info->cert_key_types.clear();
- const uint8_t* client_cert_types;
- size_t num_client_cert_types =
- SSL_get0_certificate_types(ssl_.get(), &client_cert_types);
- for (size_t i = 0; i < num_client_cert_types; i++) {
- switch (client_cert_types[i]) {
- case static_cast<uint8_t>(SSLClientCertType::kRsaSign):
- case static_cast<uint8_t>(SSLClientCertType::kEcdsaSign):
- cert_request_info->cert_key_types.push_back(
- static_cast<SSLClientCertType>(client_cert_types[i]));
- break;
- default:
- // Unknown client certificate types are ignored.
- break;
- }
- }
+ const uint16_t* algorithms;
+ size_t num_algorithms =
+ SSL_get0_peer_verify_algorithms(ssl_.get(), &algorithms);
+ cert_request_info->signature_algorithms.assign(algorithms,
+ algorithms + num_algorithms);
}
void SSLClientSocketImpl::ApplySocketTag(const SocketTag& tag) {
@@ -702,8 +691,10 @@ int SSLClientSocketImpl::Write(
if (rv == ERR_IO_PENDING) {
user_write_callback_ = std::move(callback);
} else {
- if (rv > 0)
+ if (rv > 0) {
+ CHECK_LE(rv, buf_len);
was_ever_used_ = true;
+ }
user_write_buf_ = nullptr;
user_write_buf_len_ = 0;
}
@@ -754,9 +745,8 @@ int SSLClientSocketImpl::Init() {
return ERR_UNEXPECTED;
}
- if (context_->config().post_quantum_enabled &&
- base::FeatureList::IsEnabled(features::kPostQuantumKyber)) {
- static const int kCurves[] = {NID_X25519Kyber768, NID_X25519,
+ if (context_->config().PostQuantumKeyAgreementEnabled()) {
+ static const int kCurves[] = {NID_X25519Kyber768Draft00, NID_X25519,
NID_X9_62_prime256v1, NID_secp384r1};
if (!SSL_set1_curves(ssl_.get(), kCurves, std::size(kCurves))) {
return ERR_UNEXPECTED;
@@ -1247,10 +1237,15 @@ ssl_verify_result_t SSLClientSocketImpl::HandleVerifyResult() {
}
// Enforce keyUsage extension for RSA leaf certificates chaining up to known
- // roots.
- // TODO(crbug.com/795089): Enforce this unconditionally.
+ // roots unconditionally. Enforcement for local anchors is, for now,
+ // conditional on feature flags and external configuration. See
+ // https://crbug.com/795089.
+ bool rsa_key_usage_for_local_anchors =
+ context_->config().rsa_key_usage_for_local_anchors_override.value_or(
+ base::FeatureList::IsEnabled(features::kRSAKeyUsageForLocalAnchors));
SSL_set_enforce_rsa_key_usage(
- ssl_.get(), server_cert_verify_result_.is_issued_by_known_root);
+ ssl_.get(), rsa_key_usage_for_local_anchors ||
+ server_cert_verify_result_.is_issued_by_known_root);
// If the connection was good, check HPKP and CT status simultaneously,
// but prefer to treat the HPKP error as more serious, if there was one.
@@ -1505,6 +1500,7 @@ int SSLClientSocketImpl::DoPayloadWrite() {
int rv = SSL_write(ssl_.get(), user_write_buf_->data(), user_write_buf_len_);
if (rv >= 0) {
+ CHECK_LE(rv, user_write_buf_len_);
net_log_.AddByteTransferEvent(NetLogEventType::SSL_SOCKET_BYTES_SENT, rv,
user_write_buf_->data());
if (first_post_handshake_write_ && SSL_is_init_finished(ssl_.get())) {
@@ -1644,21 +1640,9 @@ int SSLClientSocketImpl::ClientCertRequestCallback(SSL* ssl) {
// Clear any currently configured certificates.
SSL_certs_clear(ssl_.get());
-#if BUILDFLAG(IS_IOS)
- // TODO(droger): Support client auth on iOS. See http://crbug.com/145954).
- //
- // Historically this was disabled because client auth required
- // platform-specific code deep in //net. Nowadays, this is abstracted away and
- // we could enable the interfaces on iOS for platform-independence. However,
- // merely enabling them changes our behavior from automatically proceeding
- // with no client certificate to raising
- // `URLRequest::Delegate::OnCertificateRequested`. Callers would need to be
- // updated to apply that behavior manually.
- //
- // If fixing this, re-enable the tests in ssl_client_socket_unittest.cc and
- // ssl_server_socket_unittest.cc which are disabled on iOS.
+#if !BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
LOG(WARNING) << "Client auth is not supported";
-#else // !BUILDFLAG(IS_IOS)
+#else // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
if (!send_client_cert_) {
// First pass: we know that a client certificate is needed, but we do not
// have one at hand. Suspend the handshake. SSL_get_error will return
@@ -1693,7 +1677,7 @@ int SSLClientSocketImpl::ClientCertRequestCallback(SSL* ssl) {
client_cert_->intermediate_buffers().size()));
return 1;
}
-#endif // BUILDFLAG(IS_IOS)
+#endif // !BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
// Send no client certificate.
net_log_.AddEventWithIntParams(NetLogEventType::SSL_CLIENT_CERT_PROVIDED,
@@ -1842,7 +1826,7 @@ void SSLClientSocketImpl::MessageCallback(int is_write,
break;
case SSL3_RT_CLIENT_HELLO_INNER:
DCHECK(is_write);
- net_log_.AddEvent(NetLogEventType::SSL_ENCYPTED_CLIENT_HELLO,
+ net_log_.AddEvent(NetLogEventType::SSL_ENCRYPTED_CLIENT_HELLO,
[&](NetLogCaptureMode capture_mode) {
return NetLogSSLMessageParams(!!is_write, buf, len,
capture_mode);
diff --git a/net/socket/ssl_client_socket_impl.h b/net/socket/ssl_client_socket_impl.h
index e7636ba29..fc7487077 100644
--- a/net/socket/ssl_client_socket_impl.h
+++ b/net/socket/ssl_client_socket_impl.h
@@ -29,7 +29,6 @@
#include "net/socket/ssl_client_socket.h"
#include "net/socket/stream_socket.h"
#include "net/ssl/openssl_ssl_util.h"
-#include "net/ssl/ssl_client_cert_type.h"
#include "net/ssl/ssl_client_session_cache.h"
#include "net/ssl/ssl_config.h"
#include "net/traffic_annotation/network_traffic_annotation.h"
diff --git a/net/socket/ssl_client_socket_unittest.cc b/net/socket/ssl_client_socket_unittest.cc
index 8fd32cd7d..6c30471f2 100644
--- a/net/socket/ssl_client_socket_unittest.cc
+++ b/net/socket/ssl_client_socket_unittest.cc
@@ -34,6 +34,7 @@
#include "net/base/address_list.h"
#include "net/base/completion_once_callback.h"
#include "net/base/features.h"
+#include "net/base/host_port_pair.h"
#include "net/base/io_buffer.h"
#include "net/base/ip_address.h"
#include "net/base/ip_endpoint.h"
@@ -43,6 +44,7 @@
#include "net/base/test_completion_callback.h"
#include "net/cert/asn1_util.h"
#include "net/cert/cert_and_ct_verifier.h"
+#include "net/cert/cert_database.h"
#include "net/cert/ct_policy_enforcer.h"
#include "net/cert/ct_policy_status.h"
#include "net/cert/ct_verifier.h"
@@ -583,7 +585,7 @@ class DeleteSocketCallback : public TestCompletionCallbackBase {
SetResult(result);
}
- raw_ptr<StreamSocket> socket_;
+ raw_ptr<StreamSocket, DanglingUntriaged> socket_;
};
// A mock CTVerifier that records every call to Verify but doesn't verify
@@ -828,7 +830,7 @@ class SSLClientSocketTest : public PlatformTest, public WithTaskEnvironment {
}
RecordingNetLogObserver log_observer_;
- raw_ptr<ClientSocketFactory> socket_factory_;
+ raw_ptr<ClientSocketFactory, DanglingUntriaged> socket_factory_;
std::unique_ptr<TestSSLConfigService> ssl_config_service_;
std::unique_ptr<MockCertVerifier> cert_verifier_;
std::unique_ptr<TransportSecurityState> transport_security_state_;
@@ -1390,6 +1392,13 @@ class HangingCertVerifier : public CertVerifier {
int num_active_requests_ = 0;
};
+class MockSSLClientContextObserver : public SSLClientContext::Observer {
+ public:
+ MOCK_METHOD1(OnSSLConfigChanged, void(SSLClientContext::SSLConfigChangeType));
+ MOCK_METHOD1(OnSSLConfigForServersChanged,
+ void(const base::flat_set<HostPortPair>&));
+};
+
} // namespace
INSTANTIATE_TEST_SUITE_P(TLSVersion,
@@ -1547,7 +1556,7 @@ TEST_P(SSLClientSocketVersionTest, ConnectBadValidityIgnoreCertErrors) {
}
// Client certificates are disabled on iOS.
-#if !BUILDFLAG(IS_IOS)
+#if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
// Attempt to connect to a page which requests a client certificate. It should
// return an error code on connect.
TEST_P(SSLClientSocketVersionTest, ConnectClientAuthCertRequested) {
@@ -1590,7 +1599,7 @@ TEST_P(SSLClientSocketVersionTest, ConnectClientAuthSendNullCert) {
sock_->Disconnect();
EXPECT_FALSE(sock_->IsConnected());
}
-#endif // !IS_IOS
+#endif // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
// TODO(wtc): Add unit tests for IsConnectedAndIdle:
// - Server closes an SSL connection (with a close_notify alert message).
@@ -2659,7 +2668,7 @@ TEST_P(SSLClientSocketVersionTest, VerifyReturnChainProperlyOrdered) {
}
// Client certificates are disabled on iOS.
-#if !BUILDFLAG(IS_IOS)
+#if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
INSTANTIATE_TEST_SUITE_P(TLSVersion,
SSLClientSocketCertRequestInfoTest,
ValuesIn(GetTLSVersions()));
@@ -2724,20 +2733,14 @@ TEST_P(SSLClientSocketCertRequestInfoTest, CertKeyTypes) {
config.client_cert_type = SSLServerConfig::OPTIONAL_CLIENT_CERT;
ASSERT_TRUE(StartEmbeddedTestServer(EmbeddedTestServer::CERT_OK, config));
scoped_refptr<SSLCertRequestInfo> request_info = GetCertRequest();
- ASSERT_TRUE(request_info.get());
- if (version() >= SSL_PROTOCOL_VERSION_TLS1_3) {
- // TLS 1.3 does not use cert_key_types, only signature algorithms. This
- // should be migrated to a more modern mechanism. See
- // https://crbug.com/1270530.
- EXPECT_EQ(0u, request_info->cert_key_types.size());
- } else {
- // BoringSSL always sends rsa_sign and ecdsa_sign.
- ASSERT_EQ(2u, request_info->cert_key_types.size());
- EXPECT_EQ(SSLClientCertType::kRsaSign, request_info->cert_key_types[0]);
- EXPECT_EQ(SSLClientCertType::kEcdsaSign, request_info->cert_key_types[1]);
- }
+ ASSERT_TRUE(request_info);
+ // Look for some values we expect BoringSSL to always send.
+ EXPECT_THAT(request_info->signature_algorithms,
+ testing::Contains(SSL_SIGN_ECDSA_SECP256R1_SHA256));
+ EXPECT_THAT(request_info->signature_algorithms,
+ testing::Contains(SSL_SIGN_RSA_PSS_RSAE_SHA256));
}
-#endif // !IS_IOS
+#endif // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
// Tests that the Certificate Transparency (RFC 6962) TLS extension is
// supported.
@@ -3002,6 +3005,22 @@ TEST_P(SSLClientSocketVersionTest, SessionResumption) {
ASSERT_THAT(rv, IsOk());
ASSERT_TRUE(sock_->GetSSLInfo(&ssl_info));
EXPECT_EQ(SSLInfo::HANDSHAKE_FULL, ssl_info.handshake_type);
+
+ // Pick up the ticket again and confirm resumption works.
+ EXPECT_THAT(MakeHTTPRequest(sock_.get()), IsOk());
+ ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv));
+ ASSERT_THAT(rv, IsOk());
+ ASSERT_TRUE(sock_->GetSSLInfo(&ssl_info));
+ EXPECT_EQ(SSLInfo::HANDSHAKE_RESUME, ssl_info.handshake_type);
+ sock_.reset();
+
+ // Updating the context-wide configuration should flush the session cache.
+ SSLContextConfig config;
+ config.disabled_cipher_suites = {1234};
+ ssl_config_service_->UpdateSSLConfigAndNotify(config);
+ ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv));
+ ASSERT_THAT(rv, IsOk());
+ ASSERT_TRUE(sock_->GetSSLInfo(&ssl_info));
}
namespace {
@@ -3590,7 +3609,7 @@ TEST_F(SSLClientSocketTest, AlpnClientDisabled) {
}
// Client certificates are disabled on iOS.
-#if !BUILDFLAG(IS_IOS)
+#if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
// Connect to a server requesting client authentication, do not send
// any client certificates. It should refuse the connection.
TEST_P(SSLClientSocketVersionTest, NoCert) {
@@ -3733,7 +3752,109 @@ TEST_F(SSLClientSocketTest, ClearSessionCacheOnClientCertChange) {
ASSERT_TRUE(CreateAndConnectSSLClientSocket(SSLConfig(), &rv));
EXPECT_THAT(rv, IsError(ERR_BAD_SSL_CLIENT_AUTH_CERT));
}
-#endif // !IS_IOS
+
+TEST_F(SSLClientSocketTest, ClearSessionCacheOnClientCertDatabaseChange) {
+ SSLServerConfig server_config;
+ // TLS 1.3 reports client certificate errors after the handshake, so test at
+ // TLS 1.2 for simplicity.
+ server_config.version_max = SSL_PROTOCOL_VERSION_TLS1_2;
+ server_config.client_cert_type = SSLServerConfig::REQUIRE_CLIENT_CERT;
+ ASSERT_TRUE(
+ StartEmbeddedTestServer(EmbeddedTestServer::CERT_OK, server_config));
+
+ HostPortPair host_port_pair2("example.com", 42);
+ testing::StrictMock<MockSSLClientContextObserver> observer;
+ EXPECT_CALL(observer, OnSSLConfigForServersChanged(
+ base::flat_set<HostPortPair>({host_port_pair()})));
+ EXPECT_CALL(observer, OnSSLConfigForServersChanged(
+ base::flat_set<HostPortPair>({host_port_pair2})));
+ EXPECT_CALL(observer,
+ OnSSLConfigForServersChanged(base::flat_set<HostPortPair>(
+ {host_port_pair(), host_port_pair2})));
+
+ context_->AddObserver(&observer);
+
+ base::FilePath certs_dir = GetTestCertsDirectory();
+ context_->SetClientCertificate(
+ host_port_pair(), ImportCertFromFile(certs_dir, "client_1.pem"),
+ key_util::LoadPrivateKeyOpenSSL(certs_dir.AppendASCII("client_1.key")));
+
+ context_->SetClientCertificate(
+ host_port_pair2, ImportCertFromFile(certs_dir, "client_2.pem"),
+ key_util::LoadPrivateKeyOpenSSL(certs_dir.AppendASCII("client_2.key")));
+
+ EXPECT_EQ(2U, context_->GetClientCertificateCachedServersForTesting().size());
+
+ // Connect to `host_port_pair()` using the client cert.
+ int rv;
+ ASSERT_TRUE(CreateAndConnectSSLClientSocket(SSLConfig(), &rv));
+ EXPECT_THAT(rv, IsOk());
+ EXPECT_TRUE(sock_->IsConnected());
+
+ EXPECT_EQ(1U, context_->ssl_client_session_cache()->size());
+
+ CertDatabase::GetInstance()->NotifyObserversClientCertStoreChanged();
+ base::RunLoop().RunUntilIdle();
+
+ EXPECT_EQ(0U, context_->GetClientCertificateCachedServersForTesting().size());
+ EXPECT_EQ(0U, context_->ssl_client_session_cache()->size());
+
+ context_->RemoveObserver(&observer);
+}
+
+TEST_F(SSLClientSocketTest, DontClearSessionCacheOnServerCertDatabaseChange) {
+ SSLServerConfig server_config;
+ // TLS 1.3 reports client certificate errors after the handshake, so test at
+ // TLS 1.2 for simplicity.
+ server_config.version_max = SSL_PROTOCOL_VERSION_TLS1_2;
+ server_config.client_cert_type = SSLServerConfig::REQUIRE_CLIENT_CERT;
+ ASSERT_TRUE(
+ StartEmbeddedTestServer(EmbeddedTestServer::CERT_OK, server_config));
+
+ HostPortPair host_port_pair2("example.com", 42);
+ testing::StrictMock<MockSSLClientContextObserver> observer;
+ EXPECT_CALL(observer, OnSSLConfigForServersChanged(
+ base::flat_set<HostPortPair>({host_port_pair()})));
+ EXPECT_CALL(observer, OnSSLConfigForServersChanged(
+ base::flat_set<HostPortPair>({host_port_pair2})));
+ EXPECT_CALL(observer,
+ OnSSLConfigChanged(
+ SSLClientContext::SSLConfigChangeType::kCertDatabaseChanged));
+
+ context_->AddObserver(&observer);
+
+ base::FilePath certs_dir = GetTestCertsDirectory();
+ context_->SetClientCertificate(
+ host_port_pair(), ImportCertFromFile(certs_dir, "client_1.pem"),
+ key_util::LoadPrivateKeyOpenSSL(certs_dir.AppendASCII("client_1.key")));
+
+ context_->SetClientCertificate(
+ host_port_pair2, ImportCertFromFile(certs_dir, "client_2.pem"),
+ key_util::LoadPrivateKeyOpenSSL(certs_dir.AppendASCII("client_2.key")));
+
+ EXPECT_EQ(2U, context_->GetClientCertificateCachedServersForTesting().size());
+
+ // Connect to `host_port_pair()` using the client cert.
+ int rv;
+ ASSERT_TRUE(CreateAndConnectSSLClientSocket(SSLConfig(), &rv));
+ EXPECT_THAT(rv, IsOk());
+ EXPECT_TRUE(sock_->IsConnected());
+
+ EXPECT_EQ(1U, context_->ssl_client_session_cache()->size());
+
+ CertDatabase::GetInstance()->NotifyObserversTrustStoreChanged();
+ base::RunLoop().RunUntilIdle();
+
+ // The `OnSSLConfigChanged` observer call should be verified by the
+ // mock observer, but the client auth and client session cache should be
+ // untouched.
+
+ EXPECT_EQ(2U, context_->GetClientCertificateCachedServersForTesting().size());
+ EXPECT_EQ(1U, context_->ssl_client_session_cache()->size());
+
+ context_->RemoveObserver(&observer);
+}
+#endif // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
HashValueVector MakeHashValueVector(uint8_t value) {
HashValueVector out;
@@ -3831,37 +3952,50 @@ const uint16_t kSigningCipher = kModernTLS12Cipher;
struct KeyUsageTest {
EmbeddedTestServer::ServerCertificate server_cert;
uint16_t cipher_suite;
- bool known_root;
- bool success;
+ bool match;
};
class SSLClientSocketKeyUsageTest
: public SSLClientSocketTest,
- public ::testing::WithParamInterface<struct KeyUsageTest> {};
+ public ::testing::WithParamInterface<
+ std::tuple<KeyUsageTest,
+ bool /*known_root*/,
+ bool /*rsa_key_usage_for_local_anchors_enabled*/,
+ bool /*override_feature*/>> {};
-const struct KeyUsageTest kKeyUsageTests[] = {
- // Known Root: Success iff keyUsage allows the key exchange method
- {EmbeddedTestServer::CERT_KEY_USAGE_RSA_ENCIPHERMENT, kSigningCipher, true,
- false},
+const KeyUsageTest kKeyUsageTests[] = {
+ // keyUsage matches cipher suite.
{EmbeddedTestServer::CERT_KEY_USAGE_RSA_DIGITAL_SIGNATURE, kSigningCipher,
- true, true},
- {EmbeddedTestServer::CERT_KEY_USAGE_RSA_ENCIPHERMENT, kEncryptingCipher,
- true, true},
- {EmbeddedTestServer::CERT_KEY_USAGE_RSA_DIGITAL_SIGNATURE,
- kEncryptingCipher, true, false},
- // Unknown Root: Always succeeds
- {EmbeddedTestServer::CERT_KEY_USAGE_RSA_ENCIPHERMENT, kSigningCipher, false,
true},
- {EmbeddedTestServer::CERT_KEY_USAGE_RSA_DIGITAL_SIGNATURE, kSigningCipher,
- false, true},
{EmbeddedTestServer::CERT_KEY_USAGE_RSA_ENCIPHERMENT, kEncryptingCipher,
- false, true},
+ true},
+ // keyUsage does not match cipher suite.
+ {EmbeddedTestServer::CERT_KEY_USAGE_RSA_ENCIPHERMENT, kSigningCipher,
+ false},
{EmbeddedTestServer::CERT_KEY_USAGE_RSA_DIGITAL_SIGNATURE,
- kEncryptingCipher, false, true},
+ kEncryptingCipher, false},
};
-TEST_P(SSLClientSocketKeyUsageTest, RSAKeyUsageEnforcedForKnownRoot) {
- const KeyUsageTest test = GetParam();
+TEST_P(SSLClientSocketKeyUsageTest, RSAKeyUsage) {
+ const auto& [test, known_root, rsa_key_usage_for_local_anchors_enabled,
+ override_feature] = GetParam();
+ bool enable_feature;
+ if (override_feature) {
+ // Configure the feature in the opposite way that we intend, to test that
+ // the configuration overrides it.
+ enable_feature = !rsa_key_usage_for_local_anchors_enabled;
+ } else {
+ enable_feature = rsa_key_usage_for_local_anchors_enabled;
+ }
+ base::test::ScopedFeatureList scoped_feature_list;
+ if (enable_feature) {
+ scoped_feature_list.InitAndEnableFeature(
+ features::kRSAKeyUsageForLocalAnchors);
+ } else {
+ scoped_feature_list.InitAndDisableFeature(
+ features::kRSAKeyUsageForLocalAnchors);
+ }
+
SSLServerConfig server_config;
server_config.version_max = SSL_PROTOCOL_VERSION_TLS1_2;
server_config.cipher_suite_for_testing = test.cipher_suite;
@@ -3869,9 +4003,16 @@ TEST_P(SSLClientSocketKeyUsageTest, RSAKeyUsageEnforcedForKnownRoot) {
scoped_refptr<X509Certificate> server_cert =
embedded_test_server()->GetCertificate();
+ SSLContextConfig context_config;
+ if (override_feature) {
+ context_config.rsa_key_usage_for_local_anchors_override =
+ rsa_key_usage_for_local_anchors_enabled;
+ }
+ ssl_config_service_->UpdateSSLConfigAndNotify(context_config);
+
// Certificate is trusted.
CertVerifyResult verify_result;
- verify_result.is_issued_by_known_root = test.known_root;
+ verify_result.is_issued_by_known_root = known_root;
verify_result.verified_cert = server_cert;
verify_result.public_key_hashes =
MakeHashValueVector(kGoodHashValueVectorInput);
@@ -3883,7 +4024,7 @@ TEST_P(SSLClientSocketKeyUsageTest, RSAKeyUsageEnforcedForKnownRoot) {
SSLInfo ssl_info;
ASSERT_TRUE(sock_->GetSSLInfo(&ssl_info));
- if (test.success) {
+ if (test.match || (!known_root && !rsa_key_usage_for_local_anchors_enabled)) {
EXPECT_THAT(rv, IsOk());
EXPECT_TRUE(sock_->IsConnected());
} else {
@@ -3892,9 +4033,10 @@ TEST_P(SSLClientSocketKeyUsageTest, RSAKeyUsageEnforcedForKnownRoot) {
}
}
-INSTANTIATE_TEST_SUITE_P(RSAKeyUsageInstantiation,
- SSLClientSocketKeyUsageTest,
- ValuesIn(kKeyUsageTests));
+INSTANTIATE_TEST_SUITE_P(
+ RSAKeyUsageInstantiation,
+ SSLClientSocketKeyUsageTest,
+ Combine(ValuesIn(kKeyUsageTests), Bool(), Bool(), Bool()));
// Test that when CT is required (in this case, by the delegate), the
// absence of CT information is a socket error.
diff --git a/net/socket/ssl_server_socket_impl.cc b/net/socket/ssl_server_socket_impl.cc
index f5f028847..9f22c20c2 100644
--- a/net/socket/ssl_server_socket_impl.cc
+++ b/net/socket/ssl_server_socket_impl.cc
@@ -155,19 +155,19 @@ class SSLServerContextImpl::SocketImpl : public SSLServerSocket,
void OnHandshakeIOComplete(int result);
- int DoPayloadRead(IOBuffer* buf, int buf_len);
- int DoPayloadWrite();
+ [[nodiscard]] int DoPayloadRead(IOBuffer* buf, int buf_len);
+ [[nodiscard]] int DoPayloadWrite();
- int DoHandshakeLoop(int last_io_result);
- int DoHandshake();
+ [[nodiscard]] int DoHandshakeLoop(int last_io_result);
+ [[nodiscard]] int DoHandshake();
void DoHandshakeCallback(int result);
void DoReadCallback(int result);
void DoWriteCallback(int result);
- int Init();
+ [[nodiscard]] int Init();
void ExtractClientCert();
- raw_ptr<SSLServerContextImpl> context_;
+ raw_ptr<SSLServerContextImpl, DanglingUntriaged> context_;
NetLogWithSource net_log_;
@@ -312,7 +312,7 @@ void SSLServerContextImpl::SocketImpl::OnPrivateKeyComplete(
signature_result_ = error;
if (signature_result_ == OK)
signature_ = signature;
- DoHandshakeLoop(ERR_IO_PENDING);
+ OnHandshakeIOComplete(ERR_IO_PENDING);
}
// static
diff --git a/net/socket/ssl_server_socket_unittest.cc b/net/socket/ssl_server_socket_unittest.cc
index eed11bd33..2cfb6563e 100644
--- a/net/socket/ssl_server_socket_unittest.cc
+++ b/net/socket/ssl_server_socket_unittest.cc
@@ -34,6 +34,7 @@
#include "base/notreached.h"
#include "base/run_loop.h"
#include "base/task/single_thread_task_runner.h"
+#include "base/test/bind.h"
#include "base/test/task_environment.h"
#include "build/build_config.h"
#include "crypto/rsa_private_key.h"
@@ -85,12 +86,12 @@ namespace net {
namespace {
// Client certificates are disabled on iOS.
-#if !BUILDFLAG(IS_IOS)
+#if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
const char kClientCertFileName[] = "client_1.pem";
const char kClientPrivateKeyFileName[] = "client_1.pk8";
const char kWrongClientCertFileName[] = "client_2.pem";
const char kWrongClientPrivateKeyFileName[] = "client_2.pk8";
-#endif // !IS_IOS
+#endif // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
const uint16_t kEcdheCiphers[] = {
0xc007, // ECDHE_ECDSA_WITH_RC4_128_SHA
@@ -442,7 +443,7 @@ class SSLServerSocketTest : public PlatformTest, public WithTaskEnvironment {
}
// Client certificates are disabled on iOS.
-#if !BUILDFLAG(IS_IOS)
+#if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
void ConfigureClientCertsForClient(const char* cert_file_name,
const char* private_key_file_name) {
scoped_refptr<X509Certificate> client_cert =
@@ -477,7 +478,7 @@ class SSLServerSocketTest : public PlatformTest, public WithTaskEnvironment {
server_ssl_config_.client_cert_verifier = client_cert_verifier_.get();
}
-#endif // !IS_IOS
+#endif // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
std::unique_ptr<crypto::RSAPrivateKey> ReadTestKey(base::StringPiece name) {
base::FilePath certs_dir(GetTestCertsDirectory());
@@ -708,7 +709,7 @@ TEST_F(SSLServerSocketTest, HandshakeCachedContextSwitch) {
}
// Client certificates are disabled on iOS.
-#if !BUILDFLAG(IS_IOS)
+#if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
// This test executes Connect() on SSLClientSocket and Handshake() on
// SSLServerSocket to make sure handshaking between the two sockets is
// completed successfully, using client certificate.
@@ -1010,7 +1011,7 @@ TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSuppliedCached) {
client_ret = read_callback.GetResult(client_ret);
EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, client_ret);
}
-#endif // !IS_IOS
+#endif // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
TEST_P(SSLServerSocketReadTest, DataTransfer) {
ASSERT_NO_FATAL_FAILURE(CreateContext());
@@ -1267,6 +1268,69 @@ TEST_F(SSLServerSocketTest, HandshakeServerSSLPrivateKey) {
EXPECT_TRUE(is_aead);
}
+namespace {
+
+// Helper that wraps an underlying SSLPrivateKey to allow the test to
+// do some work immediately before a `Sign()` operation is performed.
+class SSLPrivateKeyHook : public SSLPrivateKey {
+ public:
+ SSLPrivateKeyHook(scoped_refptr<SSLPrivateKey> private_key,
+ base::RepeatingClosure on_sign)
+ : private_key_(std::move(private_key)), on_sign_(std::move(on_sign)) {}
+
+ // SSLPrivateKey implementation.
+ std::string GetProviderName() override {
+ return private_key_->GetProviderName();
+ }
+ std::vector<uint16_t> GetAlgorithmPreferences() override {
+ return private_key_->GetAlgorithmPreferences();
+ }
+ void Sign(uint16_t algorithm,
+ base::span<const uint8_t> input,
+ SignCallback callback) override {
+ on_sign_.Run();
+ private_key_->Sign(algorithm, input, std::move(callback));
+ }
+
+ private:
+ ~SSLPrivateKeyHook() override = default;
+
+ const scoped_refptr<SSLPrivateKey> private_key_;
+ const base::RepeatingClosure on_sign_;
+};
+
+} // namespace
+
+// Verifies that if the client disconnects while during private key signing then
+// the disconnection is correctly reported to the `Handshake()` completion
+// callback, with `ERR_CONNECTION_CLOSED`.
+// This is a regression test for crbug.com/1449461.
+TEST_F(SSLServerSocketTest,
+ HandshakeServerSSLPrivateKeyDisconnectDuringSigning_ReturnsError) {
+ auto on_sign = base::BindLambdaForTesting([&]() {
+ client_socket_->Disconnect();
+ ASSERT_FALSE(client_socket_->IsConnected());
+ });
+ server_ssl_private_key_ = base::MakeRefCounted<SSLPrivateKeyHook>(
+ std::move(server_ssl_private_key_), on_sign);
+ ASSERT_NO_FATAL_FAILURE(CreateContextSSLPrivateKey());
+ ASSERT_NO_FATAL_FAILURE(CreateSockets());
+
+ TestCompletionCallback handshake_callback;
+ int server_ret = server_socket_->Handshake(handshake_callback.callback());
+ ASSERT_EQ(server_ret, net::ERR_IO_PENDING);
+
+ TestCompletionCallback connect_callback;
+ client_socket_->Connect(connect_callback.callback());
+
+ // If resuming the handshake after private-key signing is not handled
+ // correctly as per crbug.com/1449461 then the test will hang and timeout
+ // at this point, due to the server-side completion callback not being
+ // correctly invoked.
+ server_ret = handshake_callback.GetResult(server_ret);
+ EXPECT_EQ(server_ret, net::ERR_CONNECTION_CLOSED);
+}
+
// Verifies that non-ECDHE ciphers are disabled when using SSLPrivateKey as the
// server key.
TEST_F(SSLServerSocketTest, HandshakeServerSSLPrivateKeyRequireEcdhe) {
diff --git a/net/socket/tcp_client_socket.cc b/net/socket/tcp_client_socket.cc
index a5fc05194..47a6b0dfa 100644
--- a/net/socket/tcp_client_socket.cc
+++ b/net/socket/tcp_client_socket.cc
@@ -61,11 +61,12 @@ TCPClientSocket::TCPClientSocket(std::unique_ptr<TCPSocket> connected_socket,
TCPClientSocket::TCPClientSocket(
std::unique_ptr<TCPSocket> unconnected_socket,
const AddressList& addresses,
+ std::unique_ptr<IPEndPoint> bound_address,
NetworkQualityEstimator* network_quality_estimator)
: TCPClientSocket(std::move(unconnected_socket),
addresses,
-1 /* current_address_index */,
- nullptr /* bind_address */,
+ std::move(bound_address),
network_quality_estimator,
handles::kInvalidNetworkHandle) {}
diff --git a/net/socket/tcp_client_socket.h b/net/socket/tcp_client_socket.h
index b761849ae..200246171 100644
--- a/net/socket/tcp_client_socket.h
+++ b/net/socket/tcp_client_socket.h
@@ -69,10 +69,11 @@ class NET_EXPORT TCPClientSocket : public TransportClientSocket,
TCPClientSocket(std::unique_ptr<TCPSocket> connected_socket,
const IPEndPoint& peer_address);
- // Adopts an unconnected TCPSocket. This function is used by
- // TCPClientSocketBrokered.
+ // Adopts an unconnected TCPSocket. TCPSocket may be bound or unbound. This
+ // function is used by TCPClientSocketBrokered.
TCPClientSocket(std::unique_ptr<TCPSocket> unconnected_socket,
const AddressList& addresses,
+ std::unique_ptr<IPEndPoint> bound_address,
NetworkQualityEstimator* network_quality_estimator);
// Creates a TCPClientSocket from a bound-but-not-connected socket.
diff --git a/net/socket/tcp_socket_win.cc b/net/socket/tcp_socket_win.cc
index de99ac85b..355aa7dae 100644
--- a/net/socket/tcp_socket_win.cc
+++ b/net/socket/tcp_socket_win.cc
@@ -949,13 +949,15 @@ void TCPSocketWin::DidCompleteConnect() {
int rv = WSAEnumNetworkEvents(socket_, core_->read_event_, &events);
int os_error = WSAGetLastError();
if (rv == SOCKET_ERROR) {
- NOTREACHED();
+ DLOG(FATAL)
+ << "WSAEnumNetworkEvents() failed with SOCKET_ERROR, os_error = "
+ << os_error;
result = MapSystemError(os_error);
} else if (events.lNetworkEvents & FD_CONNECT) {
os_error = events.iErrorCode[FD_CONNECT_BIT];
result = MapConnectError(os_error);
} else {
- NOTREACHED();
+ DLOG(FATAL) << "WSAEnumNetworkEvents() failed, rv = " << rv;
result = ERR_UNEXPECTED;
}
diff --git a/net/socket/transport_client_socket_pool.cc b/net/socket/transport_client_socket_pool.cc
index 03180c6c8..8587f4b9d 100644
--- a/net/socket/transport_client_socket_pool.cc
+++ b/net/socket/transport_client_socket_pool.cc
@@ -830,8 +830,8 @@ void TransportClientSocketPool::OnSSLConfigChanged(
}
// TODO(crbug.com/1206799): Get `server` as SchemeHostPort?
-void TransportClientSocketPool::OnSSLConfigForServerChanged(
- const HostPortPair& server) {
+void TransportClientSocketPool::OnSSLConfigForServersChanged(
+ const base::flat_set<HostPortPair>& servers) {
// Current time value. Retrieving it once at the function start rather than
// inside the inner loop, since it shouldn't change by any meaningful amount.
//
@@ -844,12 +844,13 @@ void TransportClientSocketPool::OnSSLConfigForServerChanged(
// every group.
bool proxy_matches = proxy_server_.is_http_like() &&
!proxy_server_.is_http() &&
- proxy_server_.host_port_pair() == server;
+ servers.contains(proxy_server_.host_port_pair());
bool refreshed_any = false;
for (auto it = group_map_.begin(); it != group_map_.end();) {
if (proxy_matches ||
(GURL::SchemeIsCryptographic(it->first.destination().scheme()) &&
- HostPortPair::FromSchemeHostPort(it->first.destination()) == server)) {
+ servers.contains(
+ HostPortPair::FromSchemeHostPort(it->first.destination())))) {
refreshed_any = true;
// Note this call may destroy the group and invalidate |to_refresh|.
it = RefreshGroup(it, now, kSslConfigChanged);
diff --git a/net/socket/transport_client_socket_pool.h b/net/socket/transport_client_socket_pool.h
index 52ff53207..470c878df 100644
--- a/net/socket/transport_client_socket_pool.h
+++ b/net/socket/transport_client_socket_pool.h
@@ -265,7 +265,8 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool
// SSLClientContext::Observer methods.
void OnSSLConfigChanged(
SSLClientContext::SSLConfigChangeType change_type) override;
- void OnSSLConfigForServerChanged(const HostPortPair& server) override;
+ void OnSSLConfigForServersChanged(
+ const base::flat_set<HostPortPair>& servers) override;
private:
// Entry for a persistent socket which became idle at time |start_time|.
diff --git a/net/socket/udp_client_socket.cc b/net/socket/udp_client_socket.cc
index 8645f1fc5..ca4a277e7 100644
--- a/net/socket/udp_client_socket.cc
+++ b/net/socket/udp_client_socket.cc
@@ -18,6 +18,11 @@ UDPClientSocket::UDPClientSocket(DatagramSocket::BindType bind_type,
handles::NetworkHandle network)
: socket_(bind_type, net_log, source), connect_using_network_(network) {}
+UDPClientSocket::UDPClientSocket(DatagramSocket::BindType bind_type,
+ NetLogWithSource source_net_log,
+ handles::NetworkHandle network)
+ : socket_(bind_type, source_net_log), connect_using_network_(network) {}
+
UDPClientSocket::~UDPClientSocket() = default;
int UDPClientSocket::Connect(const IPEndPoint& address) {
@@ -26,7 +31,10 @@ int UDPClientSocket::Connect(const IPEndPoint& address) {
return ConnectUsingNetwork(connect_using_network_, address);
connect_called_ = true;
- int rv = socket_.Open(address.GetFamily());
+ int rv = OK;
+ if (!adopted_opened_socket_) {
+ rv = socket_.Open(address.GetFamily());
+ }
if (rv != OK)
return rv;
return socket_.Connect(address);
@@ -38,9 +46,13 @@ int UDPClientSocket::ConnectUsingNetwork(handles::NetworkHandle network,
connect_called_ = true;
if (!NetworkChangeNotifier::AreNetworkHandlesSupported())
return ERR_NOT_IMPLEMENTED;
- int rv = socket_.Open(address.GetFamily());
- if (rv != OK)
+ int rv = OK;
+ if (!adopted_opened_socket_) {
+ rv = socket_.Open(address.GetFamily());
+ }
+ if (rv != OK) {
return rv;
+ }
rv = socket_.BindToNetwork(network);
if (rv != OK)
return rv;
@@ -53,8 +65,10 @@ int UDPClientSocket::ConnectUsingDefaultNetwork(const IPEndPoint& address) {
connect_called_ = true;
if (!NetworkChangeNotifier::AreNetworkHandlesSupported())
return ERR_NOT_IMPLEMENTED;
- int rv;
- rv = socket_.Open(address.GetFamily());
+ int rv = OK;
+ if (!adopted_opened_socket_) {
+ rv = socket_.Open(address.GetFamily());
+ }
if (rv != OK)
return rv;
// Calling connect() will bind a socket to the default network, however there
@@ -126,6 +140,7 @@ int UDPClientSocket::Write(
void UDPClientSocket::Close() {
socket_.Close();
+ adopted_opened_socket_ = false;
}
int UDPClientSocket::GetPeerAddress(IPEndPoint* address) const {
@@ -178,9 +193,13 @@ void UDPClientSocket::SetIOSNetworkServiceType(int ios_network_service_type) {
#endif
}
-void UDPClientSocket::AdoptOpenedSocket(AddressFamily address_family,
- SocketDescriptor socket) {
- socket_.AdoptOpenedSocket(address_family, socket);
+int UDPClientSocket::AdoptOpenedSocket(AddressFamily address_family,
+ SocketDescriptor socket) {
+ int rv = socket_.AdoptOpenedSocket(address_family, socket);
+ if (rv == OK) {
+ adopted_opened_socket_ = true;
+ }
+ return rv;
}
} // namespace net
diff --git a/net/socket/udp_client_socket.h b/net/socket/udp_client_socket.h
index fc695b7cd..6e3875b4f 100644
--- a/net/socket/udp_client_socket.h
+++ b/net/socket/udp_client_socket.h
@@ -30,6 +30,11 @@ class NET_EXPORT_PRIVATE UDPClientSocket : public DatagramClientSocket {
const net::NetLogSource& source,
handles::NetworkHandle network = handles::kInvalidNetworkHandle);
+ UDPClientSocket(
+ DatagramSocket::BindType bind_type,
+ NetLogWithSource source_net_log,
+ handles::NetworkHandle network = handles::kInvalidNetworkHandle);
+
UDPClientSocket(const UDPClientSocket&) = delete;
UDPClientSocket& operator=(const UDPClientSocket&) = delete;
@@ -74,14 +79,34 @@ class NET_EXPORT_PRIVATE UDPClientSocket : public DatagramClientSocket {
int SetMulticastInterface(uint32_t interface_index) override;
void SetIOSNetworkServiceType(int ios_network_service_type) override;
- // Takes ownership of an opened but unconnected and unbound `socket`.
- void AdoptOpenedSocket(AddressFamily address_family, SocketDescriptor socket);
+ // Takes ownership of an opened but unconnected and unbound `socket`. This
+ // method must be called after UseNonBlockingIO, otherwise the adopted socket
+ // will not have the non-blocking IO flag set.
+ int AdoptOpenedSocket(AddressFamily address_family, SocketDescriptor socket);
+
+ uint32_t get_multicast_interface_for_testing() {
+ return socket_.get_multicast_interface_for_testing();
+ }
+#if !BUILDFLAG(IS_WIN)
+ bool get_msg_confirm_for_testing() {
+ return socket_.get_msg_confirm_for_testing();
+ }
+ bool get_recv_optimization_for_testing() {
+ return socket_.get_experimental_recv_optimization_enabled_for_testing();
+ }
+#endif
+#if BUILDFLAG(IS_WIN)
+ bool get_use_non_blocking_io_for_testing() {
+ return socket_.get_use_non_blocking_io_for_testing();
+ }
+#endif
private:
UDPSocket socket_;
+ bool adopted_opened_socket_ = false;
bool connect_called_ = false;
// The network the socket is currently bound to.
- handles::NetworkHandle network_;
+ handles::NetworkHandle network_ = handles::kInvalidNetworkHandle;
handles::NetworkHandle connect_using_network_;
};
diff --git a/net/socket/udp_socket_posix.cc b/net/socket/udp_socket_posix.cc
index d7fe60ca6..120ce583c 100644
--- a/net/socket/udp_socket_posix.cc
+++ b/net/socket/udp_socket_posix.cc
@@ -139,6 +139,22 @@ UDPSocketPosix::UDPSocketPosix(DatagramSocket::BindType bind_type,
net_log_.BeginEventReferencingSource(NetLogEventType::SOCKET_ALIVE, source);
}
+UDPSocketPosix::UDPSocketPosix(DatagramSocket::BindType bind_type,
+ NetLogWithSource source_net_log)
+ : socket_(kInvalidSocket),
+ bind_type_(bind_type),
+ read_socket_watcher_(FROM_HERE),
+ write_socket_watcher_(FROM_HERE),
+ read_watcher_(this),
+ write_watcher_(this),
+ net_log_(source_net_log),
+ bound_network_(handles::kInvalidNetworkHandle),
+ always_update_bytes_received_(base::FeatureList::IsEnabled(
+ features::kUdpSocketPosixAlwaysUpdateBytesReceived)) {
+ net_log_.BeginEventReferencingSource(NetLogEventType::SOCKET_ALIVE,
+ net_log_.source());
+}
+
UDPSocketPosix::~UDPSocketPosix() {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
Close();
diff --git a/net/socket/udp_socket_posix.h b/net/socket/udp_socket_posix.h
index ed9c7549c..a0cb1e2f5 100644
--- a/net/socket/udp_socket_posix.h
+++ b/net/socket/udp_socket_posix.h
@@ -72,6 +72,9 @@ class NET_EXPORT UDPSocketPosix {
net::NetLog* net_log,
const net::NetLogSource& source);
+ UDPSocketPosix(DatagramSocket::BindType bind_type,
+ NetLogWithSource source_net_log);
+
UDPSocketPosix(const UDPSocketPosix&) = delete;
UDPSocketPosix& operator=(const UDPSocketPosix&) = delete;
@@ -280,6 +283,14 @@ class NET_EXPORT UDPSocketPosix {
// not bound or connected to an address.
int AdoptOpenedSocket(AddressFamily address_family, int socket);
+ uint32_t get_multicast_interface_for_testing() {
+ return multicast_interface_;
+ }
+ bool get_msg_confirm_for_testing() { return sendto_flags_; }
+ bool get_experimental_recv_optimization_enabled_for_testing() {
+ return experimental_recv_optimization_enabled_;
+ }
+
private:
enum SocketOptions {
SOCKET_OPTION_MULTICAST_LOOP = 1 << 0
diff --git a/net/socket/udp_socket_win.cc b/net/socket/udp_socket_win.cc
index dfc234f5d..32926a00c 100644
--- a/net/socket/udp_socket_win.cc
+++ b/net/socket/udp_socket_win.cc
@@ -249,6 +249,16 @@ UDPSocketWin::UDPSocketWin(DatagramSocket::BindType bind_type,
net_log_.BeginEventReferencingSource(NetLogEventType::SOCKET_ALIVE, source);
}
+UDPSocketWin::UDPSocketWin(DatagramSocket::BindType bind_type,
+ NetLogWithSource source_net_log)
+ : socket_(INVALID_SOCKET),
+ socket_options_(SOCKET_OPTION_MULTICAST_LOOP),
+ net_log_(source_net_log) {
+ EnsureWinsockInit();
+ net_log_.BeginEventReferencingSource(NetLogEventType::SOCKET_ALIVE,
+ net_log_.source());
+}
+
UDPSocketWin::~UDPSocketWin() {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
Close();
diff --git a/net/socket/udp_socket_win.h b/net/socket/udp_socket_win.h
index 857613564..dd27752c6 100644
--- a/net/socket/udp_socket_win.h
+++ b/net/socket/udp_socket_win.h
@@ -165,6 +165,9 @@ class NET_EXPORT UDPSocketWin : public base::win::ObjectWatcher::Delegate {
net::NetLog* net_log,
const net::NetLogSource& source);
+ UDPSocketWin(DatagramSocket::BindType bind_type,
+ NetLogWithSource source_net_log);
+
UDPSocketWin(const UDPSocketWin&) = delete;
UDPSocketWin& operator=(const UDPSocketWin&) = delete;
@@ -346,8 +349,8 @@ class NET_EXPORT UDPSocketWin : public base::win::ObjectWatcher::Delegate {
// Resets the thread to be used for thread-safety checks.
void DetachFromThread();
- // This class by default uses overlapped IO. Call this method before Open()
- // to switch to non-blocking IO.
+ // This class by default uses overlapped IO. Call this method before Open() or
+ // AdoptOpenedSocket() to switch to non-blocking IO.
void UseNonBlockingIO();
// Apply |tag| to this socket.
@@ -355,9 +358,16 @@ class NET_EXPORT UDPSocketWin : public base::win::ObjectWatcher::Delegate {
// Takes ownership of `socket`, which should be a socket descriptor opened
// with the specified address family. The socket should only be created but
- // not bound or connected to an address.
+ // not bound or connected to an address. This method must be called after
+ // UseNonBlockingIO, otherwise the adopted socket will not have the
+ // non-blocking IO flag set.
int AdoptOpenedSocket(AddressFamily address_family, SOCKET socket);
+ uint32_t get_multicast_interface_for_testing() {
+ return multicast_interface_;
+ }
+ bool get_use_non_blocking_io_for_testing() { return use_non_blocking_io_; }
+
private:
enum SocketOptions {
SOCKET_OPTION_MULTICAST_LOOP = 1 << 0
diff --git a/net/socket/websocket_transport_client_socket_pool.cc b/net/socket/websocket_transport_client_socket_pool.cc
index b8f15ee7d..286652d99 100644
--- a/net/socket/websocket_transport_client_socket_pool.cc
+++ b/net/socket/websocket_transport_client_socket_pool.cc
@@ -44,10 +44,10 @@ WebSocketTransportClientSocketPool::WebSocketTransportClientSocketPool(
WebSocketTransportClientSocketPool::~WebSocketTransportClientSocketPool() {
// Clean up any pending connect jobs.
FlushWithError(ERR_ABORTED, "");
- DCHECK(pending_connects_.empty());
- DCHECK_EQ(0, handed_out_socket_count_);
- DCHECK(stalled_request_queue_.empty());
- DCHECK(stalled_request_map_.empty());
+ CHECK(pending_connects_.empty());
+ CHECK_EQ(0, handed_out_socket_count_);
+ CHECK(stalled_request_queue_.empty());
+ CHECK(stalled_request_map_.empty());
}
// static
@@ -160,8 +160,12 @@ void WebSocketTransportClientSocketPool::CancelRequest(
if (socket)
ReleaseSocket(handle->group_id(), std::move(socket),
handle->group_generation());
- if (!DeleteJob(handle))
+ if (DeleteJob(handle)) {
+ CHECK(!base::Contains(pending_callbacks_,
+ reinterpret_cast<ClientSocketHandleID>(handle)));
+ } else {
pending_callbacks_.erase(reinterpret_cast<ClientSocketHandleID>(handle));
+ }
ActivateStalledRequest();
}
@@ -331,7 +335,7 @@ void WebSocketTransportClientSocketPool::OnConnectJobComplete(
ClientSocketHandle* const handle = connect_job_delegate->socket_handle();
bool delete_succeeded = DeleteJob(handle);
- DCHECK(delete_succeeded);
+ CHECK(delete_succeeded);
connect_job_delegate = nullptr;
@@ -346,21 +350,24 @@ void WebSocketTransportClientSocketPool::InvokeUserCallbackLater(
CompletionOnceCallback callback,
int rv) {
const auto handle_id = reinterpret_cast<ClientSocketHandleID>(handle);
- DCHECK(!pending_callbacks_.count(handle_id));
+ CHECK(!pending_callbacks_.count(handle_id));
pending_callbacks_.insert(handle_id);
base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE,
base::BindOnce(&WebSocketTransportClientSocketPool::InvokeUserCallback,
- weak_factory_.GetWeakPtr(), handle_id, std::move(callback),
- rv));
+ weak_factory_.GetWeakPtr(), handle_id,
+ handle->GetWeakPtr(), std::move(callback), rv));
}
void WebSocketTransportClientSocketPool::InvokeUserCallback(
ClientSocketHandleID handle_id,
+ base::WeakPtr<ClientSocketHandle> weak_handle,
CompletionOnceCallback callback,
int rv) {
- if (pending_callbacks_.erase(handle_id))
+ if (pending_callbacks_.erase(handle_id)) {
+ CHECK(weak_handle);
std::move(callback).Run(rv);
+ }
}
bool WebSocketTransportClientSocketPool::ReachedMaxSocketsLimit() const {
@@ -396,7 +403,7 @@ void WebSocketTransportClientSocketPool::AddJob(
pending_connects_
.insert(PendingConnectsMap::value_type(handle, std::move(delegate)))
.second;
- DCHECK(inserted);
+ CHECK(inserted);
}
bool WebSocketTransportClientSocketPool::DeleteJob(ClientSocketHandle* handle) {
diff --git a/net/socket/websocket_transport_client_socket_pool.h b/net/socket/websocket_transport_client_socket_pool.h
index fea8b4445..2750816af 100644
--- a/net/socket/websocket_transport_client_socket_pool.h
+++ b/net/socket/websocket_transport_client_socket_pool.h
@@ -191,6 +191,7 @@ class NET_EXPORT_PRIVATE WebSocketTransportClientSocketPool
CompletionOnceCallback callback,
int rv);
void InvokeUserCallback(ClientSocketHandleID handle_id,
+ base::WeakPtr<ClientSocketHandle> weak_handle,
CompletionOnceCallback callback,
int rv);
bool ReachedMaxSocketsLimit() const;
diff --git a/net/socket/websocket_transport_client_socket_pool_unittest.cc b/net/socket/websocket_transport_client_socket_pool_unittest.cc
index e29dc56f5..787d08226 100644
--- a/net/socket/websocket_transport_client_socket_pool_unittest.cc
+++ b/net/socket/websocket_transport_client_socket_pool_unittest.cc
@@ -4,11 +4,11 @@
#include "net/socket/websocket_transport_client_socket_pool.h"
+#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
-#include "base/cxx17_backports.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/functional/callback_helpers.h"