// Copyright 2018 The Chromium Authors // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/websockets/websocket_basic_stream_adapters.h" #include #include #include #include #include "base/check.h" #include "base/containers/span.h" #include "base/functional/bind.h" #include "base/functional/callback.h" #include "base/memory/raw_ptr.h" #include "base/memory/scoped_refptr.h" #include "base/memory/weak_ptr.h" #include "base/run_loop.h" #include "base/strings/strcat.h" #include "base/strings/string_piece.h" #include "base/task/single_thread_task_runner.h" #include "base/time/default_tick_clock.h" #include "base/time/time.h" #include "net/base/host_port_pair.h" #include "net/base/io_buffer.h" #include "net/base/ip_address.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" #include "net/base/network_anonymization_key.h" #include "net/base/network_handle.h" #include "net/base/privacy_mode.h" #include "net/base/proxy_chain.h" #include "net/base/request_priority.h" #include "net/base/test_completion_callback.h" #include "net/cert/cert_verify_result.h" #include "net/dns/public/host_resolver_results.h" #include "net/dns/public/secure_dns_policy.h" #include "net/http/http_network_session.h" #include "net/http/transport_security_state.h" #include "net/log/net_log.h" #include "net/log/net_log_with_source.h" #include "net/quic/address_utils.h" #include "net/quic/crypto/proof_verifier_chromium.h" #include "net/quic/mock_crypto_client_stream_factory.h" #include "net/quic/mock_quic_data.h" #include "net/quic/quic_chromium_alarm_factory.h" #include "net/quic/quic_chromium_client_session.h" #include "net/quic/quic_chromium_client_session_peer.h" #include "net/quic/quic_chromium_connection_helper.h" #include "net/quic/quic_chromium_packet_reader.h" #include "net/quic/quic_chromium_packet_writer.h" #include "net/quic/quic_context.h" #include "net/quic/quic_http_utils.h" #include "net/quic/quic_server_info.h" #include "net/quic/quic_session_key.h" #include "net/quic/quic_test_packet_maker.h" #include "net/quic/test_quic_crypto_client_config_handle.h" #include "net/quic/test_task_runner.h" #include "net/socket/client_socket_handle.h" #include "net/socket/client_socket_pool.h" #include "net/socket/next_proto.h" #include "net/socket/socket_tag.h" #include "net/socket/socket_test_util.h" #include "net/socket/stream_socket.h" #include "net/spdy/spdy_session_key.h" #include "net/spdy/spdy_test_util_common.h" #include "net/ssl/ssl_config.h" #include "net/ssl/ssl_config_service_defaults.h" #include "net/ssl/ssl_info.h" #include "net/test/cert_test_util.h" #include "net/test/gtest_util.h" #include "net/test/test_data_directory.h" #include "net/test/test_with_task_environment.h" #include "net/third_party/quiche/src/quiche/common/http/http_header_block.h" #include "net/third_party/quiche/src/quiche/common/platform/api/quiche_flags.h" #include "net/third_party/quiche/src/quiche/common/quiche_buffer_allocator.h" #include "net/third_party/quiche/src/quiche/common/simple_buffer_allocator.h" #include "net/third_party/quiche/src/quiche/quic/core/crypto/quic_crypto_client_config.h" #include "net/third_party/quiche/src/quiche/quic/core/http/http_encoder.h" #include "net/third_party/quiche/src/quiche/quic/core/qpack/qpack_decoder.h" #include "net/third_party/quiche/src/quiche/quic/core/quic_connection.h" #include "net/third_party/quiche/src/quiche/quic/core/quic_connection_id.h" #include "net/third_party/quiche/src/quiche/quic/core/quic_error_codes.h" #include "net/third_party/quiche/src/quiche/quic/core/quic_packets.h" #include "net/third_party/quiche/src/quiche/quic/core/quic_time.h" #include "net/third_party/quiche/src/quiche/quic/core/quic_types.h" #include "net/third_party/quiche/src/quiche/quic/core/quic_utils.h" #include "net/third_party/quiche/src/quiche/quic/core/quic_versions.h" #include "net/third_party/quiche/src/quiche/quic/platform/api/quic_socket_address.h" #include "net/third_party/quiche/src/quiche/quic/test_tools/crypto_test_utils.h" #include "net/third_party/quiche/src/quiche/quic/test_tools/mock_clock.h" #include "net/third_party/quiche/src/quiche/quic/test_tools/mock_connection_id_generator.h" #include "net/third_party/quiche/src/quiche/quic/test_tools/mock_random.h" #include "net/third_party/quiche/src/quiche/quic/test_tools/qpack/qpack_test_utils.h" #include "net/third_party/quiche/src/quiche/quic/test_tools/quic_test_utils.h" #include "net/third_party/quiche/src/quiche/spdy/core/spdy_protocol.h" #include "net/traffic_annotation/network_traffic_annotation_test_helper.h" #include "net/websockets/websocket_test_util.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" #include "url/gurl.h" #include "url/scheme_host_port.h" #include "url/url_constants.h" namespace net { class QuicChromiumClientStream; class SpdySession; class WebSocketEndpointLockManager; class X509Certificate; } // namespace net using testing::_; using testing::AnyNumber; using testing::Invoke; using testing::Return; using testing::StrictMock; using testing::Test; namespace net::test { class WebSocketClientSocketHandleAdapterTest : public TestWithTaskEnvironment { protected: WebSocketClientSocketHandleAdapterTest() : network_session_( SpdySessionDependencies::SpdyCreateSession(&session_deps_)), websocket_endpoint_lock_manager_( network_session_->websocket_endpoint_lock_manager()) {} ~WebSocketClientSocketHandleAdapterTest() override = default; bool InitClientSocketHandle(ClientSocketHandle* connection) { auto ssl_config_for_origin = std::make_unique(); ssl_config_for_origin->alpn_protos = {kProtoHTTP11}; scoped_refptr socks_params = base::MakeRefCounted( std::move(ssl_config_for_origin)); TestCompletionCallback callback; int rv = connection->Init( ClientSocketPool::GroupId( url::SchemeHostPort(url::kHttpsScheme, "www.example.org", 443), PrivacyMode::PRIVACY_MODE_DISABLED, NetworkAnonymizationKey(), SecureDnsPolicy::kAllow, /*disable_cert_network_fetches=*/false), socks_params, /*proxy_annotation_tag=*/TRAFFIC_ANNOTATION_FOR_TESTS, MEDIUM, SocketTag(), ClientSocketPool::RespectLimits::ENABLED, callback.callback(), ClientSocketPool::ProxyAuthCallback(), network_session_->GetSocketPool(HttpNetworkSession::NORMAL_SOCKET_POOL, ProxyChain::Direct()), NetLogWithSource()); rv = callback.GetResult(rv); return rv == OK; } SpdySessionDependencies session_deps_; std::unique_ptr network_session_; raw_ptr websocket_endpoint_lock_manager_; }; TEST_F(WebSocketClientSocketHandleAdapterTest, Uninitialized) { auto connection = std::make_unique(); WebSocketClientSocketHandleAdapter adapter(std::move(connection)); EXPECT_FALSE(adapter.is_initialized()); } TEST_F(WebSocketClientSocketHandleAdapterTest, IsInitialized) { StaticSocketDataProvider data; session_deps_.socket_factory->AddSocketDataProvider(&data); SSLSocketDataProvider ssl_socket_data(ASYNC, OK); session_deps_.socket_factory->AddSSLSocketDataProvider(&ssl_socket_data); auto connection = std::make_unique(); ClientSocketHandle* const connection_ptr = connection.get(); WebSocketClientSocketHandleAdapter adapter(std::move(connection)); EXPECT_FALSE(adapter.is_initialized()); EXPECT_TRUE(InitClientSocketHandle(connection_ptr)); EXPECT_TRUE(adapter.is_initialized()); } TEST_F(WebSocketClientSocketHandleAdapterTest, Disconnect) { StaticSocketDataProvider data; session_deps_.socket_factory->AddSocketDataProvider(&data); SSLSocketDataProvider ssl_socket_data(ASYNC, OK); session_deps_.socket_factory->AddSSLSocketDataProvider(&ssl_socket_data); auto connection = std::make_unique(); EXPECT_TRUE(InitClientSocketHandle(connection.get())); StreamSocket* const socket = connection->socket(); WebSocketClientSocketHandleAdapter adapter(std::move(connection)); EXPECT_TRUE(adapter.is_initialized()); EXPECT_TRUE(socket->IsConnected()); adapter.Disconnect(); EXPECT_FALSE(socket->IsConnected()); } TEST_F(WebSocketClientSocketHandleAdapterTest, Read) { MockRead reads[] = {MockRead(SYNCHRONOUS, "foo"), MockRead("bar")}; StaticSocketDataProvider data(reads, base::span()); session_deps_.socket_factory->AddSocketDataProvider(&data); SSLSocketDataProvider ssl_socket_data(ASYNC, OK); session_deps_.socket_factory->AddSSLSocketDataProvider(&ssl_socket_data); auto connection = std::make_unique(); EXPECT_TRUE(InitClientSocketHandle(connection.get())); WebSocketClientSocketHandleAdapter adapter(std::move(connection)); EXPECT_TRUE(adapter.is_initialized()); // Buffer larger than each MockRead. const int kReadBufSize = 1024; auto read_buf = base::MakeRefCounted(kReadBufSize); int rv = adapter.Read(read_buf.get(), kReadBufSize, CompletionOnceCallback()); ASSERT_EQ(3, rv); EXPECT_EQ("foo", base::StringPiece(read_buf->data(), rv)); TestCompletionCallback callback; rv = adapter.Read(read_buf.get(), kReadBufSize, callback.callback()); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); rv = callback.WaitForResult(); ASSERT_EQ(3, rv); EXPECT_EQ("bar", base::StringPiece(read_buf->data(), rv)); EXPECT_TRUE(data.AllReadDataConsumed()); EXPECT_TRUE(data.AllWriteDataConsumed()); } TEST_F(WebSocketClientSocketHandleAdapterTest, ReadIntoSmallBuffer) { MockRead reads[] = {MockRead(SYNCHRONOUS, "foo"), MockRead("bar")}; StaticSocketDataProvider data(reads, base::span()); session_deps_.socket_factory->AddSocketDataProvider(&data); SSLSocketDataProvider ssl_socket_data(ASYNC, OK); session_deps_.socket_factory->AddSSLSocketDataProvider(&ssl_socket_data); auto connection = std::make_unique(); EXPECT_TRUE(InitClientSocketHandle(connection.get())); WebSocketClientSocketHandleAdapter adapter(std::move(connection)); EXPECT_TRUE(adapter.is_initialized()); // Buffer smaller than each MockRead. const int kReadBufSize = 2; auto read_buf = base::MakeRefCounted(kReadBufSize); int rv = adapter.Read(read_buf.get(), kReadBufSize, CompletionOnceCallback()); ASSERT_EQ(2, rv); EXPECT_EQ("fo", base::StringPiece(read_buf->data(), rv)); rv = adapter.Read(read_buf.get(), kReadBufSize, CompletionOnceCallback()); ASSERT_EQ(1, rv); EXPECT_EQ("o", base::StringPiece(read_buf->data(), rv)); TestCompletionCallback callback1; rv = adapter.Read(read_buf.get(), kReadBufSize, callback1.callback()); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); rv = callback1.WaitForResult(); ASSERT_EQ(2, rv); EXPECT_EQ("ba", base::StringPiece(read_buf->data(), rv)); rv = adapter.Read(read_buf.get(), kReadBufSize, CompletionOnceCallback()); ASSERT_EQ(1, rv); EXPECT_EQ("r", base::StringPiece(read_buf->data(), rv)); EXPECT_TRUE(data.AllReadDataConsumed()); EXPECT_TRUE(data.AllWriteDataConsumed()); } TEST_F(WebSocketClientSocketHandleAdapterTest, Write) { MockWrite writes[] = {MockWrite(SYNCHRONOUS, "foo"), MockWrite("bar")}; StaticSocketDataProvider data(base::span(), writes); session_deps_.socket_factory->AddSocketDataProvider(&data); SSLSocketDataProvider ssl_socket_data(ASYNC, OK); session_deps_.socket_factory->AddSSLSocketDataProvider(&ssl_socket_data); auto connection = std::make_unique(); EXPECT_TRUE(InitClientSocketHandle(connection.get())); WebSocketClientSocketHandleAdapter adapter(std::move(connection)); EXPECT_TRUE(adapter.is_initialized()); auto write_buf1 = base::MakeRefCounted("foo"); int rv = adapter.Write(write_buf1.get(), write_buf1->size(), CompletionOnceCallback(), TRAFFIC_ANNOTATION_FOR_TESTS); ASSERT_EQ(3, rv); auto write_buf2 = base::MakeRefCounted("bar"); TestCompletionCallback callback; rv = adapter.Write(write_buf2.get(), write_buf2->size(), callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); rv = callback.WaitForResult(); ASSERT_EQ(3, rv); EXPECT_TRUE(data.AllReadDataConsumed()); EXPECT_TRUE(data.AllWriteDataConsumed()); } // Test that if both Read() and Write() returns asynchronously, // the two callbacks are handled correctly. TEST_F(WebSocketClientSocketHandleAdapterTest, AsyncReadAndWrite) { MockRead reads[] = {MockRead("foobar")}; MockWrite writes[] = {MockWrite("baz")}; StaticSocketDataProvider data(reads, writes); session_deps_.socket_factory->AddSocketDataProvider(&data); SSLSocketDataProvider ssl_socket_data(ASYNC, OK); session_deps_.socket_factory->AddSSLSocketDataProvider(&ssl_socket_data); auto connection = std::make_unique(); EXPECT_TRUE(InitClientSocketHandle(connection.get())); WebSocketClientSocketHandleAdapter adapter(std::move(connection)); EXPECT_TRUE(adapter.is_initialized()); const int kReadBufSize = 1024; auto read_buf = base::MakeRefCounted(kReadBufSize); TestCompletionCallback read_callback; int rv = adapter.Read(read_buf.get(), kReadBufSize, read_callback.callback()); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); auto write_buf = base::MakeRefCounted("baz"); TestCompletionCallback write_callback; rv = adapter.Write(write_buf.get(), write_buf->size(), write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); rv = read_callback.WaitForResult(); ASSERT_EQ(6, rv); EXPECT_EQ("foobar", base::StringPiece(read_buf->data(), rv)); rv = write_callback.WaitForResult(); ASSERT_EQ(3, rv); EXPECT_TRUE(data.AllReadDataConsumed()); EXPECT_TRUE(data.AllWriteDataConsumed()); } class MockDelegate : public WebSocketSpdyStreamAdapter::Delegate { public: ~MockDelegate() override = default; MOCK_METHOD(void, OnHeadersSent, (), (override)); MOCK_METHOD(void, OnHeadersReceived, (const spdy::Http2HeaderBlock&), (override)); MOCK_METHOD(void, OnClose, (int), (override)); }; class WebSocketSpdyStreamAdapterTest : public TestWithTaskEnvironment { protected: WebSocketSpdyStreamAdapterTest() : url_("wss://www.example.org/"), key_(HostPortPair::FromURL(url_), ProxyChain::Direct(), PRIVACY_MODE_DISABLED, SpdySessionKey::IsProxySession::kFalse, SocketTag(), NetworkAnonymizationKey(), SecureDnsPolicy::kAllow), session_(SpdySessionDependencies::SpdyCreateSession(&session_deps_)), ssl_(SYNCHRONOUS, OK) {} ~WebSocketSpdyStreamAdapterTest() override = default; static spdy::Http2HeaderBlock RequestHeaders() { return WebSocketHttp2Request("/", "www.example.org:443", "http://www.example.org", {}); } static spdy::Http2HeaderBlock ResponseHeaders() { return WebSocketHttp2Response({}); } void AddSocketData(SocketDataProvider* data) { session_deps_.socket_factory->AddSocketDataProvider(data); } void AddSSLSocketData() { ssl_.ssl_info.cert = ImportCertFromFile(GetTestCertsDirectory(), "wildcard.pem"); ASSERT_TRUE(ssl_.ssl_info.cert); session_deps_.socket_factory->AddSSLSocketDataProvider(&ssl_); } base::WeakPtr CreateSpdySession() { return ::net::CreateSpdySession(session_.get(), key_, NetLogWithSource()); } base::WeakPtr CreateSpdyStream( base::WeakPtr session) { return CreateStreamSynchronously(SPDY_BIDIRECTIONAL_STREAM, session, url_, LOWEST, NetLogWithSource()); } SpdyTestUtil spdy_util_; StrictMock mock_delegate_; private: const GURL url_; const SpdySessionKey key_; SpdySessionDependencies session_deps_; std::unique_ptr session_; SSLSocketDataProvider ssl_; }; TEST_F(WebSocketSpdyStreamAdapterTest, Disconnect) { MockRead reads[] = {MockRead(ASYNC, ERR_IO_PENDING, 0), MockRead(ASYNC, 0, 1)}; SequencedSocketData data(reads, base::span()); AddSocketData(&data); AddSSLSocketData(); base::WeakPtr session = CreateSpdySession(); base::WeakPtr stream = CreateSpdyStream(session); WebSocketSpdyStreamAdapter adapter(stream, &mock_delegate_, NetLogWithSource()); EXPECT_TRUE(adapter.is_initialized()); base::RunLoop().RunUntilIdle(); EXPECT_TRUE(stream); adapter.Disconnect(); EXPECT_FALSE(stream); // Read EOF. EXPECT_TRUE(session); data.Resume(); base::RunLoop().RunUntilIdle(); EXPECT_FALSE(session); EXPECT_TRUE(data.AllReadDataConsumed()); EXPECT_TRUE(data.AllWriteDataConsumed()); } TEST_F(WebSocketSpdyStreamAdapterTest, SendRequestHeadersThenDisconnect) { MockRead reads[] = {MockRead(ASYNC, ERR_IO_PENDING, 0), MockRead(ASYNC, 0, 3)}; spdy::SpdySerializedFrame headers(spdy_util_.ConstructSpdyHeaders( 1, RequestHeaders(), DEFAULT_PRIORITY, false)); spdy::SpdySerializedFrame rst( spdy_util_.ConstructSpdyRstStream(1, spdy::ERROR_CODE_CANCEL)); MockWrite writes[] = {CreateMockWrite(headers, 1), CreateMockWrite(rst, 2)}; SequencedSocketData data(reads, writes); AddSocketData(&data); AddSSLSocketData(); base::WeakPtr session = CreateSpdySession(); base::WeakPtr stream = CreateSpdyStream(session); WebSocketSpdyStreamAdapter adapter(stream, &mock_delegate_, NetLogWithSource()); EXPECT_TRUE(adapter.is_initialized()); int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); // First read is a pause and it has lower sequence number than first write. // Therefore writing headers does not complete while |data| is paused. base::RunLoop().RunUntilIdle(); // Reset the stream before writing completes. // OnHeadersSent() will never be called. EXPECT_TRUE(stream); adapter.Disconnect(); EXPECT_FALSE(stream); // Resume |data|, finish writing headers, and read EOF. EXPECT_TRUE(session); data.Resume(); base::RunLoop().RunUntilIdle(); EXPECT_FALSE(session); EXPECT_TRUE(data.AllReadDataConsumed()); EXPECT_TRUE(data.AllWriteDataConsumed()); } TEST_F(WebSocketSpdyStreamAdapterTest, OnHeadersSentThenDisconnect) { MockRead reads[] = {MockRead(ASYNC, 0, 2)}; spdy::SpdySerializedFrame headers(spdy_util_.ConstructSpdyHeaders( 1, RequestHeaders(), DEFAULT_PRIORITY, false)); spdy::SpdySerializedFrame rst( spdy_util_.ConstructSpdyRstStream(1, spdy::ERROR_CODE_CANCEL)); MockWrite writes[] = {CreateMockWrite(headers, 0), CreateMockWrite(rst, 1)}; SequencedSocketData data(reads, writes); AddSocketData(&data); AddSSLSocketData(); EXPECT_CALL(mock_delegate_, OnHeadersSent()); base::WeakPtr session = CreateSpdySession(); base::WeakPtr stream = CreateSpdyStream(session); WebSocketSpdyStreamAdapter adapter(stream, &mock_delegate_, NetLogWithSource()); EXPECT_TRUE(adapter.is_initialized()); int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); // Finish asynchronous write of headers. This calls OnHeadersSent(). base::RunLoop().RunUntilIdle(); EXPECT_TRUE(stream); adapter.Disconnect(); EXPECT_FALSE(stream); // Read EOF. EXPECT_TRUE(session); base::RunLoop().RunUntilIdle(); EXPECT_FALSE(session); EXPECT_TRUE(data.AllReadDataConsumed()); EXPECT_TRUE(data.AllWriteDataConsumed()); } TEST_F(WebSocketSpdyStreamAdapterTest, OnHeadersReceivedThenDisconnect) { spdy::SpdySerializedFrame response_headers( spdy_util_.ConstructSpdyResponseHeaders(1, ResponseHeaders(), false)); MockRead reads[] = {CreateMockRead(response_headers, 1), MockRead(ASYNC, 0, 3)}; spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyHeaders( 1, RequestHeaders(), DEFAULT_PRIORITY, false)); spdy::SpdySerializedFrame rst( spdy_util_.ConstructSpdyRstStream(1, spdy::ERROR_CODE_CANCEL)); MockWrite writes[] = {CreateMockWrite(request_headers, 0), CreateMockWrite(rst, 2)}; SequencedSocketData data(reads, writes); AddSocketData(&data); AddSSLSocketData(); EXPECT_CALL(mock_delegate_, OnHeadersSent()); EXPECT_CALL(mock_delegate_, OnHeadersReceived(_)); base::WeakPtr session = CreateSpdySession(); base::WeakPtr stream = CreateSpdyStream(session); WebSocketSpdyStreamAdapter adapter(stream, &mock_delegate_, NetLogWithSource()); EXPECT_TRUE(adapter.is_initialized()); int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); base::RunLoop().RunUntilIdle(); EXPECT_TRUE(stream); adapter.Disconnect(); EXPECT_FALSE(stream); // Read EOF. EXPECT_TRUE(session); base::RunLoop().RunUntilIdle(); EXPECT_FALSE(session); EXPECT_TRUE(data.AllReadDataConsumed()); EXPECT_TRUE(data.AllWriteDataConsumed()); } TEST_F(WebSocketSpdyStreamAdapterTest, ServerClosesConnection) { MockRead reads[] = {MockRead(ASYNC, 0, 0)}; SequencedSocketData data(reads, base::span()); AddSocketData(&data); AddSSLSocketData(); EXPECT_CALL(mock_delegate_, OnClose(ERR_CONNECTION_CLOSED)); base::WeakPtr session = CreateSpdySession(); base::WeakPtr stream = CreateSpdyStream(session); WebSocketSpdyStreamAdapter adapter(stream, &mock_delegate_, NetLogWithSource()); EXPECT_TRUE(adapter.is_initialized()); EXPECT_TRUE(session); EXPECT_TRUE(stream); base::RunLoop().RunUntilIdle(); EXPECT_FALSE(session); EXPECT_FALSE(stream); EXPECT_TRUE(data.AllReadDataConsumed()); EXPECT_TRUE(data.AllWriteDataConsumed()); } TEST_F(WebSocketSpdyStreamAdapterTest, SendRequestHeadersThenServerClosesConnection) { MockRead reads[] = {MockRead(ASYNC, 0, 1)}; spdy::SpdySerializedFrame headers(spdy_util_.ConstructSpdyHeaders( 1, RequestHeaders(), DEFAULT_PRIORITY, false)); MockWrite writes[] = {CreateMockWrite(headers, 0)}; SequencedSocketData data(reads, writes); AddSocketData(&data); AddSSLSocketData(); EXPECT_CALL(mock_delegate_, OnHeadersSent()); EXPECT_CALL(mock_delegate_, OnClose(ERR_CONNECTION_CLOSED)); base::WeakPtr session = CreateSpdySession(); base::WeakPtr stream = CreateSpdyStream(session); WebSocketSpdyStreamAdapter adapter(stream, &mock_delegate_, NetLogWithSource()); EXPECT_TRUE(adapter.is_initialized()); int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); EXPECT_TRUE(session); EXPECT_TRUE(stream); base::RunLoop().RunUntilIdle(); EXPECT_FALSE(session); EXPECT_FALSE(stream); EXPECT_TRUE(data.AllReadDataConsumed()); EXPECT_TRUE(data.AllWriteDataConsumed()); } TEST_F(WebSocketSpdyStreamAdapterTest, OnHeadersReceivedThenServerClosesConnection) { spdy::SpdySerializedFrame response_headers( spdy_util_.ConstructSpdyResponseHeaders(1, ResponseHeaders(), false)); MockRead reads[] = {CreateMockRead(response_headers, 1), MockRead(ASYNC, 0, 2)}; spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyHeaders( 1, RequestHeaders(), DEFAULT_PRIORITY, false)); MockWrite writes[] = {CreateMockWrite(request_headers, 0)}; SequencedSocketData data(reads, writes); AddSocketData(&data); AddSSLSocketData(); EXPECT_CALL(mock_delegate_, OnHeadersSent()); EXPECT_CALL(mock_delegate_, OnHeadersReceived(_)); EXPECT_CALL(mock_delegate_, OnClose(ERR_CONNECTION_CLOSED)); base::WeakPtr session = CreateSpdySession(); base::WeakPtr stream = CreateSpdyStream(session); WebSocketSpdyStreamAdapter adapter(stream, &mock_delegate_, NetLogWithSource()); EXPECT_TRUE(adapter.is_initialized()); int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); EXPECT_TRUE(session); EXPECT_TRUE(stream); base::RunLoop().RunUntilIdle(); EXPECT_FALSE(session); EXPECT_FALSE(stream); EXPECT_TRUE(data.AllReadDataConsumed()); EXPECT_TRUE(data.AllWriteDataConsumed()); } // Previously we failed to detect a half-close by the server that indicated the // stream should be closed. This test ensures a half-close is correctly // detected. See https://crbug.com/1151393. TEST_F(WebSocketSpdyStreamAdapterTest, OnHeadersReceivedThenStreamEnd) { spdy::SpdySerializedFrame response_headers( spdy_util_.ConstructSpdyResponseHeaders(1, ResponseHeaders(), false)); spdy::SpdySerializedFrame stream_end( spdy_util_.ConstructSpdyDataFrame(1, "", true)); MockRead reads[] = {CreateMockRead(response_headers, 1), CreateMockRead(stream_end, 2), MockRead(ASYNC, ERR_IO_PENDING, 3), // pause here MockRead(ASYNC, 0, 4)}; spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyHeaders( 1, RequestHeaders(), DEFAULT_PRIORITY, /* fin = */ false)); MockWrite writes[] = {CreateMockWrite(request_headers, 0)}; SequencedSocketData data(reads, writes); AddSocketData(&data); AddSSLSocketData(); EXPECT_CALL(mock_delegate_, OnHeadersSent()); EXPECT_CALL(mock_delegate_, OnHeadersReceived(_)); EXPECT_CALL(mock_delegate_, OnClose(ERR_CONNECTION_CLOSED)); base::WeakPtr session = CreateSpdySession(); base::WeakPtr stream = CreateSpdyStream(session); WebSocketSpdyStreamAdapter adapter(stream, &mock_delegate_, NetLogWithSource()); EXPECT_TRUE(adapter.is_initialized()); int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); constexpr int kReadBufSize = 1024; auto read_buf = base::MakeRefCounted(kReadBufSize); TestCompletionCallback read_callback; rv = adapter.Read(read_buf.get(), kReadBufSize, read_callback.callback()); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); EXPECT_TRUE(session); EXPECT_TRUE(stream); rv = read_callback.WaitForResult(); EXPECT_EQ(ERR_CONNECTION_CLOSED, rv); EXPECT_TRUE(session); EXPECT_FALSE(stream); // Close the session. data.Resume(); base::RunLoop().RunUntilIdle(); EXPECT_TRUE(data.AllReadDataConsumed()); EXPECT_TRUE(data.AllWriteDataConsumed()); } TEST_F(WebSocketSpdyStreamAdapterTest, DetachDelegate) { spdy::SpdySerializedFrame response_headers( spdy_util_.ConstructSpdyResponseHeaders(1, ResponseHeaders(), false)); MockRead reads[] = {CreateMockRead(response_headers, 1), MockRead(ASYNC, 0, 2)}; spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyHeaders( 1, RequestHeaders(), DEFAULT_PRIORITY, false)); MockWrite writes[] = {CreateMockWrite(request_headers, 0)}; SequencedSocketData data(reads, writes); AddSocketData(&data); AddSSLSocketData(); base::WeakPtr session = CreateSpdySession(); base::WeakPtr stream = CreateSpdyStream(session); WebSocketSpdyStreamAdapter adapter(stream, &mock_delegate_, NetLogWithSource()); EXPECT_TRUE(adapter.is_initialized()); // No Delegate methods shall be called after this. adapter.DetachDelegate(); int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); EXPECT_TRUE(session); EXPECT_TRUE(stream); base::RunLoop().RunUntilIdle(); EXPECT_FALSE(session); EXPECT_FALSE(stream); EXPECT_TRUE(data.AllReadDataConsumed()); EXPECT_TRUE(data.AllWriteDataConsumed()); } TEST_F(WebSocketSpdyStreamAdapterTest, Read) { spdy::SpdySerializedFrame response_headers( spdy_util_.ConstructSpdyResponseHeaders(1, ResponseHeaders(), false)); // First read is the same size as the buffer, next is smaller, last is larger. spdy::SpdySerializedFrame data_frame1( spdy_util_.ConstructSpdyDataFrame(1, "foo", false)); spdy::SpdySerializedFrame data_frame2( spdy_util_.ConstructSpdyDataFrame(1, "ba", false)); spdy::SpdySerializedFrame data_frame3( spdy_util_.ConstructSpdyDataFrame(1, "rbaz", true)); MockRead reads[] = {CreateMockRead(response_headers, 1), CreateMockRead(data_frame1, 2), CreateMockRead(data_frame2, 3), CreateMockRead(data_frame3, 4), MockRead(ASYNC, 0, 5)}; spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyHeaders( 1, RequestHeaders(), DEFAULT_PRIORITY, false)); MockWrite writes[] = {CreateMockWrite(request_headers, 0)}; SequencedSocketData data(reads, writes); AddSocketData(&data); AddSSLSocketData(); EXPECT_CALL(mock_delegate_, OnHeadersSent()); EXPECT_CALL(mock_delegate_, OnHeadersReceived(_)); base::WeakPtr session = CreateSpdySession(); base::WeakPtr stream = CreateSpdyStream(session); WebSocketSpdyStreamAdapter adapter(stream, &mock_delegate_, NetLogWithSource()); EXPECT_TRUE(adapter.is_initialized()); int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); const int kReadBufSize = 3; auto read_buf = base::MakeRefCounted(kReadBufSize); TestCompletionCallback callback; rv = adapter.Read(read_buf.get(), kReadBufSize, callback.callback()); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); rv = callback.WaitForResult(); ASSERT_EQ(3, rv); EXPECT_EQ("foo", base::StringPiece(read_buf->data(), rv)); // Read EOF to destroy the connection and the stream. // This calls SpdySession::Delegate::OnClose(). EXPECT_TRUE(session); EXPECT_TRUE(stream); base::RunLoop().RunUntilIdle(); EXPECT_FALSE(session); EXPECT_FALSE(stream); // Two socket reads are concatenated by WebSocketSpdyStreamAdapter. rv = adapter.Read(read_buf.get(), kReadBufSize, CompletionOnceCallback()); ASSERT_EQ(3, rv); EXPECT_EQ("bar", base::StringPiece(read_buf->data(), rv)); rv = adapter.Read(read_buf.get(), kReadBufSize, CompletionOnceCallback()); ASSERT_EQ(3, rv); EXPECT_EQ("baz", base::StringPiece(read_buf->data(), rv)); // Even though connection and stream are already closed, // WebSocketSpdyStreamAdapter::Delegate::OnClose() is only called after all // buffered data are read. EXPECT_CALL(mock_delegate_, OnClose(ERR_CONNECTION_CLOSED)); base::RunLoop().RunUntilIdle(); EXPECT_TRUE(data.AllReadDataConsumed()); EXPECT_TRUE(data.AllWriteDataConsumed()); } TEST_F(WebSocketSpdyStreamAdapterTest, CallDelegateOnCloseShouldNotCrash) { spdy::SpdySerializedFrame response_headers( spdy_util_.ConstructSpdyResponseHeaders(1, ResponseHeaders(), false)); spdy::SpdySerializedFrame data_frame1( spdy_util_.ConstructSpdyDataFrame(1, "foo", false)); spdy::SpdySerializedFrame data_frame2( spdy_util_.ConstructSpdyDataFrame(1, "bar", false)); spdy::SpdySerializedFrame rst( spdy_util_.ConstructSpdyRstStream(1, spdy::ERROR_CODE_CANCEL)); MockRead reads[] = {CreateMockRead(response_headers, 1), CreateMockRead(data_frame1, 2), CreateMockRead(data_frame2, 3), CreateMockRead(rst, 4), MockRead(ASYNC, 0, 5)}; spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyHeaders( 1, RequestHeaders(), DEFAULT_PRIORITY, false)); MockWrite writes[] = {CreateMockWrite(request_headers, 0)}; SequencedSocketData data(reads, writes); AddSocketData(&data); AddSSLSocketData(); EXPECT_CALL(mock_delegate_, OnHeadersSent()); EXPECT_CALL(mock_delegate_, OnHeadersReceived(_)); base::WeakPtr session = CreateSpdySession(); base::WeakPtr stream = CreateSpdyStream(session); WebSocketSpdyStreamAdapter adapter(stream, &mock_delegate_, NetLogWithSource()); EXPECT_TRUE(adapter.is_initialized()); int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); // Buffer larger than each MockRead. const int kReadBufSize = 1024; auto read_buf = base::MakeRefCounted(kReadBufSize); TestCompletionCallback callback; rv = adapter.Read(read_buf.get(), kReadBufSize, callback.callback()); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); rv = callback.WaitForResult(); ASSERT_EQ(3, rv); EXPECT_EQ("foo", base::StringPiece(read_buf->data(), rv)); // Read RST_STREAM to destroy the stream. // This calls SpdySession::Delegate::OnClose(). EXPECT_TRUE(session); EXPECT_TRUE(stream); base::RunLoop().RunUntilIdle(); EXPECT_FALSE(session); EXPECT_FALSE(stream); // Read remaining buffered data. This will PostTask CallDelegateOnClose(). rv = adapter.Read(read_buf.get(), kReadBufSize, CompletionOnceCallback()); ASSERT_EQ(3, rv); EXPECT_EQ("bar", base::StringPiece(read_buf->data(), rv)); adapter.DetachDelegate(); // Run CallDelegateOnClose(), which should not crash // even if |delegate_| is null. base::RunLoop().RunUntilIdle(); EXPECT_TRUE(data.AllReadDataConsumed()); EXPECT_TRUE(data.AllWriteDataConsumed()); } TEST_F(WebSocketSpdyStreamAdapterTest, Write) { spdy::SpdySerializedFrame response_headers( spdy_util_.ConstructSpdyResponseHeaders(1, ResponseHeaders(), false)); MockRead reads[] = {CreateMockRead(response_headers, 1), MockRead(ASYNC, 0, 3)}; spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyHeaders( 1, RequestHeaders(), DEFAULT_PRIORITY, false)); spdy::SpdySerializedFrame data_frame( spdy_util_.ConstructSpdyDataFrame(1, "foo", false)); MockWrite writes[] = {CreateMockWrite(request_headers, 0), CreateMockWrite(data_frame, 2)}; SequencedSocketData data(reads, writes); AddSocketData(&data); AddSSLSocketData(); base::WeakPtr session = CreateSpdySession(); base::WeakPtr stream = CreateSpdyStream(session); WebSocketSpdyStreamAdapter adapter(stream, nullptr, NetLogWithSource()); EXPECT_TRUE(adapter.is_initialized()); int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); base::RunLoop().RunUntilIdle(); auto write_buf = base::MakeRefCounted("foo"); TestCompletionCallback callback; rv = adapter.Write(write_buf.get(), write_buf->size(), callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); rv = callback.WaitForResult(); ASSERT_EQ(3, rv); // Read EOF. base::RunLoop().RunUntilIdle(); EXPECT_TRUE(data.AllReadDataConsumed()); EXPECT_TRUE(data.AllWriteDataConsumed()); } // Test that if both Read() and Write() returns asynchronously, // the two callbacks are handled correctly. TEST_F(WebSocketSpdyStreamAdapterTest, AsyncReadAndWrite) { spdy::SpdySerializedFrame response_headers( spdy_util_.ConstructSpdyResponseHeaders(1, ResponseHeaders(), false)); spdy::SpdySerializedFrame read_data_frame( spdy_util_.ConstructSpdyDataFrame(1, "foobar", true)); MockRead reads[] = {CreateMockRead(response_headers, 1), CreateMockRead(read_data_frame, 3), MockRead(ASYNC, 0, 4)}; spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyHeaders( 1, RequestHeaders(), DEFAULT_PRIORITY, false)); spdy::SpdySerializedFrame write_data_frame( spdy_util_.ConstructSpdyDataFrame(1, "baz", false)); MockWrite writes[] = {CreateMockWrite(request_headers, 0), CreateMockWrite(write_data_frame, 2)}; SequencedSocketData data(reads, writes); AddSocketData(&data); AddSSLSocketData(); base::WeakPtr session = CreateSpdySession(); base::WeakPtr stream = CreateSpdyStream(session); WebSocketSpdyStreamAdapter adapter(stream, nullptr, NetLogWithSource()); EXPECT_TRUE(adapter.is_initialized()); int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); base::RunLoop().RunUntilIdle(); const int kReadBufSize = 1024; auto read_buf = base::MakeRefCounted(kReadBufSize); TestCompletionCallback read_callback; rv = adapter.Read(read_buf.get(), kReadBufSize, read_callback.callback()); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); auto write_buf = base::MakeRefCounted("baz"); TestCompletionCallback write_callback; rv = adapter.Write(write_buf.get(), write_buf->size(), write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); rv = read_callback.WaitForResult(); ASSERT_EQ(6, rv); EXPECT_EQ("foobar", base::StringPiece(read_buf->data(), rv)); rv = write_callback.WaitForResult(); ASSERT_EQ(3, rv); // Read EOF. base::RunLoop().RunUntilIdle(); EXPECT_TRUE(data.AllReadDataConsumed()); EXPECT_TRUE(data.AllWriteDataConsumed()); } // A helper class that will delete |adapter| when the callback is invoked. class KillerCallback : public TestCompletionCallbackBase { public: explicit KillerCallback(std::unique_ptr adapter) : adapter_(std::move(adapter)) {} ~KillerCallback() override = default; CompletionOnceCallback callback() { return base::BindOnce(&KillerCallback::OnComplete, base::Unretained(this)); } private: void OnComplete(int result) { adapter_.reset(); SetResult(result); } std::unique_ptr adapter_; }; TEST_F(WebSocketSpdyStreamAdapterTest, ReadCallbackDestroysAdapter) { spdy::SpdySerializedFrame response_headers( spdy_util_.ConstructSpdyResponseHeaders(1, ResponseHeaders(), false)); MockRead reads[] = {CreateMockRead(response_headers, 1), MockRead(ASYNC, ERR_IO_PENDING, 2), MockRead(ASYNC, 0, 3)}; spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyHeaders( 1, RequestHeaders(), DEFAULT_PRIORITY, false)); MockWrite writes[] = {CreateMockWrite(request_headers, 0)}; SequencedSocketData data(reads, writes); AddSocketData(&data); AddSSLSocketData(); EXPECT_CALL(mock_delegate_, OnHeadersSent()); EXPECT_CALL(mock_delegate_, OnHeadersReceived(_)); base::WeakPtr session = CreateSpdySession(); base::WeakPtr stream = CreateSpdyStream(session); auto adapter = std::make_unique( stream, &mock_delegate_, NetLogWithSource()); EXPECT_TRUE(adapter->is_initialized()); int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); // Send headers. base::RunLoop().RunUntilIdle(); WebSocketSpdyStreamAdapter* adapter_raw = adapter.get(); KillerCallback callback(std::move(adapter)); const int kReadBufSize = 1024; auto read_buf = base::MakeRefCounted(kReadBufSize); rv = adapter_raw->Read(read_buf.get(), kReadBufSize, callback.callback()); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); // Read EOF while read is pending. WebSocketSpdyStreamAdapter::OnClose() // should not crash if read callback destroys |adapter|. data.Resume(); rv = callback.WaitForResult(); EXPECT_THAT(rv, IsError(ERR_CONNECTION_CLOSED)); base::RunLoop().RunUntilIdle(); EXPECT_FALSE(session); EXPECT_FALSE(stream); EXPECT_TRUE(data.AllReadDataConsumed()); EXPECT_TRUE(data.AllWriteDataConsumed()); } TEST_F(WebSocketSpdyStreamAdapterTest, WriteCallbackDestroysAdapter) { spdy::SpdySerializedFrame response_headers( spdy_util_.ConstructSpdyResponseHeaders(1, ResponseHeaders(), false)); MockRead reads[] = {CreateMockRead(response_headers, 1), MockRead(ASYNC, ERR_IO_PENDING, 2), MockRead(ASYNC, 0, 3)}; spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyHeaders( 1, RequestHeaders(), DEFAULT_PRIORITY, false)); MockWrite writes[] = {CreateMockWrite(request_headers, 0)}; SequencedSocketData data(reads, writes); AddSocketData(&data); AddSSLSocketData(); EXPECT_CALL(mock_delegate_, OnHeadersSent()); EXPECT_CALL(mock_delegate_, OnHeadersReceived(_)); base::WeakPtr session = CreateSpdySession(); base::WeakPtr stream = CreateSpdyStream(session); auto adapter = std::make_unique( stream, &mock_delegate_, NetLogWithSource()); EXPECT_TRUE(adapter->is_initialized()); int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); // Send headers. base::RunLoop().RunUntilIdle(); WebSocketSpdyStreamAdapter* adapter_raw = adapter.get(); KillerCallback callback(std::move(adapter)); auto write_buf = base::MakeRefCounted("foo"); rv = adapter_raw->Write(write_buf.get(), write_buf->size(), callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); // Read EOF while write is pending. WebSocketSpdyStreamAdapter::OnClose() // should not crash if write callback destroys |adapter|. data.Resume(); rv = callback.WaitForResult(); EXPECT_THAT(rv, IsError(ERR_CONNECTION_CLOSED)); base::RunLoop().RunUntilIdle(); EXPECT_FALSE(session); EXPECT_FALSE(stream); EXPECT_TRUE(data.AllReadDataConsumed()); EXPECT_TRUE(data.AllWriteDataConsumed()); } TEST_F(WebSocketSpdyStreamAdapterTest, OnCloseOkShouldBeTranslatedToConnectionClose) { spdy::SpdySerializedFrame response_headers( spdy_util_.ConstructSpdyResponseHeaders(1, ResponseHeaders(), false)); spdy::SpdySerializedFrame close( spdy_util_.ConstructSpdyRstStream(1, spdy::ERROR_CODE_NO_ERROR)); MockRead reads[] = {CreateMockRead(response_headers, 1), CreateMockRead(close, 2), MockRead(ASYNC, 0, 3)}; spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyHeaders( 1, RequestHeaders(), DEFAULT_PRIORITY, false)); MockWrite writes[] = {CreateMockWrite(request_headers, 0)}; SequencedSocketData data(reads, writes); AddSocketData(&data); AddSSLSocketData(); EXPECT_CALL(mock_delegate_, OnHeadersSent()); EXPECT_CALL(mock_delegate_, OnHeadersReceived(_)); base::WeakPtr session = CreateSpdySession(); base::WeakPtr stream = CreateSpdyStream(session); WebSocketSpdyStreamAdapter adapter(stream, &mock_delegate_, NetLogWithSource()); EXPECT_TRUE(adapter.is_initialized()); EXPECT_CALL(mock_delegate_, OnClose(ERR_CONNECTION_CLOSED)); int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); const int kReadBufSize = 1024; auto read_buf = base::MakeRefCounted(kReadBufSize); TestCompletionCallback callback; rv = adapter.Read(read_buf.get(), kReadBufSize, callback.callback()); EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); rv = callback.WaitForResult(); ASSERT_EQ(ERR_CONNECTION_CLOSED, rv); } class MockQuicDelegate : public WebSocketQuicStreamAdapter::Delegate { public: ~MockQuicDelegate() override = default; MOCK_METHOD(void, OnHeadersSent, (), (override)); MOCK_METHOD(void, OnHeadersReceived, (const spdy::Http2HeaderBlock&), (override)); MOCK_METHOD(void, OnClose, (int), (override)); }; class WebSocketQuicStreamAdapterTest : public TestWithTaskEnvironment, public ::testing::WithParamInterface { protected: static spdy::Http2HeaderBlock RequestHeaders() { return WebSocketHttp2Request("/", "www.example.org:443", "http://www.example.org", {}); } WebSocketQuicStreamAdapterTest() : version_(GetParam()), mock_quic_data_(version_), client_data_stream_id1_(quic::QuicUtils::GetFirstBidirectionalStreamId( version_.transport_version, quic::Perspective::IS_CLIENT)), crypto_config_( quic::test::crypto_test_utils::ProofVerifierForTesting()), connection_id_(quic::test::TestConnectionId(2)), client_maker_(version_, connection_id_, &clock_, "mail.example.org", quic::Perspective::IS_CLIENT), server_maker_(version_, connection_id_, &clock_, "mail.example.org", quic::Perspective::IS_SERVER), peer_addr_(IPAddress(192, 0, 2, 23), 443), destination_endpoint_(url::kHttpsScheme, "mail.example.org", 80) {} ~WebSocketQuicStreamAdapterTest() override = default; void SetUp() override { FLAGS_quic_enable_http3_grease_randomness = false; clock_.AdvanceTime(quic::QuicTime::Delta::FromMilliseconds(20)); quic::QuicEnableVersion(version_); } void TearDown() override { EXPECT_TRUE(mock_quic_data_.AllReadDataConsumed()); EXPECT_TRUE(mock_quic_data_.AllWriteDataConsumed()); } net::QuicChromiumClientSession::Handle* GetQuicSessionHandle() { return session_handle_.get(); } // Helper functions for constructing packets sent by the client std::unique_ptr ConstructSettingsPacket( uint64_t packet_number) { return client_maker_.MakeInitialSettingsPacket(packet_number); } std::unique_ptr ConstructServerDataPacket( uint64_t packet_number, base::StringPiece data) { quiche::QuicheBuffer buffer = quic::HttpEncoder::SerializeDataFrameHeader( data.size(), quiche::SimpleBufferAllocator::Get()); return server_maker_.MakeDataPacket( packet_number, client_data_stream_id1_, /*fin=*/false, base::StrCat({base::StringPiece(buffer.data(), buffer.size()), data})); } std::unique_ptr ConstructRstPacket( uint64_t packet_number, quic::QuicRstStreamErrorCode error_code) { return client_maker_.MakeRstPacket(packet_number, client_data_stream_id1_, error_code, /*include_stop_sending_if_v99=*/true); } std::unique_ptr ConstructClientAckPacket( uint64_t packet_number, uint64_t largest_received, uint64_t smallest_received) { return client_maker_.MakeAckPacket(packet_number, largest_received, smallest_received); } std::unique_ptr ConstructAckAndRstPacket( uint64_t packet_number, quic::QuicRstStreamErrorCode error_code, uint64_t largest_received, uint64_t smallest_received) { return client_maker_.MakeAckAndRstPacket( packet_number, client_data_stream_id1_, error_code, largest_received, smallest_received, /*include_stop_sending_if_v99=*/true); } void Initialize() { auto socket = std::make_unique( mock_quic_data_.InitializeAndGetSequencedSocketData(), NetLog::Get()); socket->Connect(peer_addr_); runner_ = base::MakeRefCounted(&clock_); helper_ = std::make_unique( &clock_, &random_generator_); alarm_factory_ = std::make_unique(runner_.get(), &clock_); // Ownership of 'writer' is passed to 'QuicConnection'. QuicChromiumPacketWriter* writer = new QuicChromiumPacketWriter( socket.get(), base::SingleThreadTaskRunner::GetCurrentDefault().get()); quic::QuicConnection* connection = new quic::QuicConnection( connection_id_, quic::QuicSocketAddress(), net::ToQuicSocketAddress(peer_addr_), helper_.get(), alarm_factory_.get(), writer, true /* owns_writer */, quic::Perspective::IS_CLIENT, quic::test::SupportedVersions(version_), connection_id_generator_); connection->set_visitor(&visitor_); // Load a certificate that is valid for *.example.org scoped_refptr test_cert( ImportCertFromFile(GetTestCertsDirectory(), "wildcard.pem")); EXPECT_TRUE(test_cert.get()); verify_details_.cert_verify_result.verified_cert = test_cert; verify_details_.cert_verify_result.is_issued_by_known_root = true; crypto_client_stream_factory_.AddProofVerifyDetails(&verify_details_); base::TimeTicks dns_end = base::TimeTicks::Now(); base::TimeTicks dns_start = dns_end - base::Milliseconds(1); session_ = std::make_unique( connection, std::move(socket), /*stream_factory=*/nullptr, &crypto_client_stream_factory_, &clock_, &transport_security_state_, &ssl_config_service_, /*server_info=*/nullptr, QuicSessionKey("mail.example.org", 80, PRIVACY_MODE_DISABLED, SocketTag(), NetworkAnonymizationKey(), SecureDnsPolicy::kAllow, /*require_dns_https_alpn=*/false), /*require_confirmation=*/false, /*migrate_session_early_v2=*/false, /*migrate_session_on_network_change_v2=*/false, /*default_network=*/handles::kInvalidNetworkHandle, quic::QuicTime::Delta::FromMilliseconds( kDefaultRetransmittableOnWireTimeout.InMilliseconds()), /*migrate_idle_session=*/true, /*allow_port_migration=*/false, kDefaultIdleSessionMigrationPeriod, /*multi_port_probing_interval=*/0, kMaxTimeOnNonDefaultNetwork, kMaxMigrationsToNonDefaultNetworkOnWriteError, kMaxMigrationsToNonDefaultNetworkOnPathDegrading, kQuicYieldAfterPacketsRead, quic::QuicTime::Delta::FromMilliseconds( kQuicYieldAfterDurationMilliseconds), /*cert_verify_flags=*/0, quic::test::DefaultQuicConfig(), std::make_unique(&crypto_config_), dns_start, dns_end, base::DefaultTickClock::GetInstance(), base::SingleThreadTaskRunner::GetCurrentDefault().get(), /*socket_performance_watcher=*/nullptr, HostResolverEndpointResult(), NetLog::Get()); session_->Initialize(); // Blackhole QPACK decoder stream instead of constructing mock writes. session_->qpack_decoder()->set_qpack_stream_sender_delegate( &noop_qpack_stream_sender_delegate_); TestCompletionCallback callback; EXPECT_THAT(session_->CryptoConnect(callback.callback()), IsOk()); EXPECT_TRUE(session_->OneRttKeysAvailable()); session_handle_ = session_->CreateHandle( url::SchemeHostPort(url::kHttpsScheme, "mail.example.org", 80)); } const quic::ParsedQuicVersion version_; MockQuicData mock_quic_data_; StrictMock mock_delegate_; const quic::QuicStreamId client_data_stream_id1_; private: quic::QuicCryptoClientConfig crypto_config_; const quic::QuicConnectionId connection_id_; protected: QuicTestPacketMaker client_maker_; QuicTestPacketMaker server_maker_; std::unique_ptr session_; private: quic::MockClock clock_; std::unique_ptr session_handle_; scoped_refptr runner_; ProofVerifyDetailsChromium verify_details_; MockCryptoClientStreamFactory crypto_client_stream_factory_; SSLConfigServiceDefaults ssl_config_service_; quic::test::MockConnectionIdGenerator connection_id_generator_; std::unique_ptr helper_; std::unique_ptr alarm_factory_; testing::StrictMock visitor_; TransportSecurityState transport_security_state_; IPAddress ip_; IPEndPoint peer_addr_; quic::test::MockRandom random_generator_{0}; url::SchemeHostPort destination_endpoint_; quic::test::NoopQpackStreamSenderDelegate noop_qpack_stream_sender_delegate_; }; // Like net::TestCompletionCallback, but for a callback that takes an unbound // parameter of type WebSocketQuicStreamAdapter. struct WebSocketQuicStreamAdapterIsPendingHelper { bool operator()( const std::unique_ptr& adapter) const { return !adapter; } }; using TestWebSocketQuicStreamAdapterCompletionCallbackBase = net::internal::TestCompletionCallbackTemplate< std::unique_ptr, WebSocketQuicStreamAdapterIsPendingHelper>; class TestWebSocketQuicStreamAdapterCompletionCallback : public TestWebSocketQuicStreamAdapterCompletionCallbackBase { public: base::OnceCallback)> callback(); }; base::OnceCallback)> TestWebSocketQuicStreamAdapterCompletionCallback::callback() { return base::BindOnce( &TestWebSocketQuicStreamAdapterCompletionCallback::SetResult, base::Unretained(this)); } INSTANTIATE_TEST_SUITE_P(QuicVersion, WebSocketQuicStreamAdapterTest, ::testing::ValuesIn(AllSupportedQuicVersions()), ::testing::PrintToStringParamName()); TEST_P(WebSocketQuicStreamAdapterTest, Disconnect) { int packet_number = 1; mock_quic_data_.AddWrite(SYNCHRONOUS, ConstructSettingsPacket(packet_number++)); mock_quic_data_.AddWrite( SYNCHRONOUS, ConstructRstPacket(packet_number++, quic::QUIC_STREAM_CANCELLED)); Initialize(); net::QuicChromiumClientSession::Handle* session_handle = GetQuicSessionHandle(); ASSERT_TRUE(session_handle); TestWebSocketQuicStreamAdapterCompletionCallback callback; std::unique_ptr adapter = session_handle->CreateWebSocketQuicStreamAdapter( &mock_delegate_, callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS); ASSERT_TRUE(adapter); EXPECT_TRUE(adapter->is_initialized()); adapter->Disconnect(); // TODO(momoka): Add tests to test both destruction orders. } TEST_P(WebSocketQuicStreamAdapterTest, AsyncAdapterCreation) { const size_t kMaxOpenStreams = 50; int packet_number = 1; mock_quic_data_.AddWrite(SYNCHRONOUS, ConstructSettingsPacket(packet_number++)); mock_quic_data_.AddWrite(SYNCHRONOUS, client_maker_.MakeStreamsBlockedPacket( packet_number++, kMaxOpenStreams, /* unidirectional = */ false)); mock_quic_data_.AddRead( ASYNC, server_maker_.MakeMaxStreamsPacket(1, kMaxOpenStreams + 2, /* unidirectional = */ false)); mock_quic_data_.AddRead(ASYNC, ERR_IO_PENDING); mock_quic_data_.AddRead(ASYNC, ERR_CONNECTION_CLOSED); Initialize(); std::vector streams; for (size_t i = 0; i < kMaxOpenStreams; i++) { QuicChromiumClientStream* stream = QuicChromiumClientSessionPeer::CreateOutgoingStream(session_.get()); ASSERT_TRUE(stream); streams.push_back(stream); EXPECT_EQ(i + 1, session_->GetNumActiveStreams()); } net::QuicChromiumClientSession::Handle* session_handle = GetQuicSessionHandle(); ASSERT_TRUE(session_handle); // Creating an adapter should fail because of the stream limit. TestWebSocketQuicStreamAdapterCompletionCallback callback; std::unique_ptr adapter = session_handle->CreateWebSocketQuicStreamAdapter( &mock_delegate_, callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS); ASSERT_EQ(adapter, nullptr); EXPECT_FALSE(callback.have_result()); EXPECT_EQ(kMaxOpenStreams, session_->GetNumActiveStreams()); // Read MAX_STREAMS frame that makes it possible to open WebSocket stream. session_->StartReading(); callback.WaitForResult(); EXPECT_EQ(kMaxOpenStreams + 1, session_->GetNumActiveStreams()); // Close connection. mock_quic_data_.Resume(); base::RunLoop().RunUntilIdle(); } TEST_P(WebSocketQuicStreamAdapterTest, SendRequestHeadersThenDisconnect) { int packet_number = 1; mock_quic_data_.AddWrite(SYNCHRONOUS, ConstructSettingsPacket(packet_number++)); SpdyTestUtil spdy_util; spdy::Http2HeaderBlock request_header_block = WebSocketHttp2Request( "/", "www.example.org:443", "http://www.example.org", {}); mock_quic_data_.AddWrite( SYNCHRONOUS, client_maker_.MakeRequestHeadersPacket( packet_number++, client_data_stream_id1_, /*fin=*/false, ConvertRequestPriorityToQuicPriority(LOWEST), std::move(request_header_block), nullptr)); mock_quic_data_.AddWrite( SYNCHRONOUS, ConstructRstPacket(packet_number++, quic::QUIC_STREAM_CANCELLED)); Initialize(); net::QuicChromiumClientSession::Handle* session_handle = GetQuicSessionHandle(); ASSERT_TRUE(session_handle); TestWebSocketQuicStreamAdapterCompletionCallback callback; std::unique_ptr adapter = session_handle->CreateWebSocketQuicStreamAdapter( &mock_delegate_, callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS); ASSERT_TRUE(adapter); EXPECT_TRUE(adapter->is_initialized()); adapter->WriteHeaders(RequestHeaders(), false); adapter->Disconnect(); } TEST_P(WebSocketQuicStreamAdapterTest, OnHeadersReceivedThenDisconnect) { int packet_number = 1; mock_quic_data_.AddWrite(SYNCHRONOUS, ConstructSettingsPacket(packet_number++)); SpdyTestUtil spdy_util; spdy::Http2HeaderBlock request_header_block = WebSocketHttp2Request( "/", "www.example.org:443", "http://www.example.org", {}); mock_quic_data_.AddWrite( SYNCHRONOUS, client_maker_.MakeRequestHeadersPacket( packet_number++, client_data_stream_id1_, /*fin=*/false, ConvertRequestPriorityToQuicPriority(LOWEST), std::move(request_header_block), nullptr)); spdy::Http2HeaderBlock response_header_block = WebSocketHttp2Response({}); mock_quic_data_.AddRead( ASYNC, server_maker_.MakeResponseHeadersPacket( /*packet_number=*/1, client_data_stream_id1_, /*fin=*/false, std::move(response_header_block), /*spdy_headers_frame_length=*/nullptr)); mock_quic_data_.AddRead(SYNCHRONOUS, ERR_IO_PENDING); mock_quic_data_.AddWrite( SYNCHRONOUS, ConstructAckAndRstPacket(packet_number++, quic::QUIC_STREAM_CANCELLED, 1, 0)); base::RunLoop run_loop; auto quit_closure = run_loop.QuitClosure(); EXPECT_CALL(mock_delegate_, OnHeadersReceived(_)).WillOnce(Invoke([&]() { std::move(quit_closure).Run(); })); Initialize(); net::QuicChromiumClientSession::Handle* session_handle = GetQuicSessionHandle(); ASSERT_TRUE(session_handle); TestWebSocketQuicStreamAdapterCompletionCallback callback; std::unique_ptr adapter = session_handle->CreateWebSocketQuicStreamAdapter( &mock_delegate_, callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS); ASSERT_TRUE(adapter); EXPECT_TRUE(adapter->is_initialized()); adapter->WriteHeaders(RequestHeaders(), false); session_->StartReading(); run_loop.Run(); adapter->Disconnect(); } TEST_P(WebSocketQuicStreamAdapterTest, Read) { int packet_number = 1; mock_quic_data_.AddWrite(SYNCHRONOUS, ConstructSettingsPacket(packet_number++)); SpdyTestUtil spdy_util; spdy::Http2HeaderBlock request_header_block = WebSocketHttp2Request( "/", "www.example.org:443", "http://www.example.org", {}); mock_quic_data_.AddWrite( SYNCHRONOUS, client_maker_.MakeRequestHeadersPacket( packet_number++, client_data_stream_id1_, /*fin=*/false, ConvertRequestPriorityToQuicPriority(LOWEST), std::move(request_header_block), nullptr)); spdy::Http2HeaderBlock response_header_block = WebSocketHttp2Response({}); mock_quic_data_.AddRead( ASYNC, server_maker_.MakeResponseHeadersPacket( /*packet_number=*/1, client_data_stream_id1_, /*fin=*/false, std::move(response_header_block), /*spdy_headers_frame_length=*/nullptr)); mock_quic_data_.AddRead(ASYNC, ERR_IO_PENDING); mock_quic_data_.AddRead(ASYNC, ConstructServerDataPacket(2, "foo")); mock_quic_data_.AddRead(SYNCHRONOUS, ConstructServerDataPacket(3, "hogehoge")); mock_quic_data_.AddRead(SYNCHRONOUS, ERR_IO_PENDING); mock_quic_data_.AddWrite(ASYNC, ConstructClientAckPacket(packet_number++, 2, 0)); mock_quic_data_.AddWrite( SYNCHRONOUS, ConstructAckAndRstPacket(packet_number++, quic::QUIC_STREAM_CANCELLED, 3, 0)); base::RunLoop run_loop; EXPECT_CALL(mock_delegate_, OnHeadersReceived(_)).WillOnce(Invoke([&]() { run_loop.Quit(); })); Initialize(); net::QuicChromiumClientSession::Handle* session_handle = GetQuicSessionHandle(); ASSERT_TRUE(session_handle); TestWebSocketQuicStreamAdapterCompletionCallback callback; std::unique_ptr adapter = session_handle->CreateWebSocketQuicStreamAdapter( &mock_delegate_, callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS); ASSERT_TRUE(adapter); EXPECT_TRUE(adapter->is_initialized()); adapter->WriteHeaders(RequestHeaders(), false); session_->StartReading(); run_loop.Run(); // Buffer larger than each MockRead. const int kReadBufSize = 1024; auto read_buf = base::MakeRefCounted(kReadBufSize); TestCompletionCallback read_callback; int rv = adapter->Read(read_buf.get(), kReadBufSize, read_callback.callback()); ASSERT_EQ(ERR_IO_PENDING, rv); mock_quic_data_.GetSequencedSocketData()->Resume(); base::RunLoop().RunUntilIdle(); rv = read_callback.WaitForResult(); ASSERT_EQ(3, rv); EXPECT_EQ("foo", base::StringPiece(read_buf->data(), rv)); rv = adapter->Read(read_buf.get(), kReadBufSize, CompletionOnceCallback()); ASSERT_EQ(8, rv); EXPECT_EQ("hogehoge", base::StringPiece(read_buf->data(), rv)); adapter->Disconnect(); EXPECT_TRUE(mock_quic_data_.AllReadDataConsumed()); EXPECT_TRUE(mock_quic_data_.AllWriteDataConsumed()); } TEST_P(WebSocketQuicStreamAdapterTest, ReadIntoSmallBuffer) { int packet_number = 1; mock_quic_data_.AddWrite(SYNCHRONOUS, ConstructSettingsPacket(packet_number++)); SpdyTestUtil spdy_util; spdy::Http2HeaderBlock request_header_block = WebSocketHttp2Request( "/", "www.example.org:443", "http://www.example.org", {}); mock_quic_data_.AddWrite( SYNCHRONOUS, client_maker_.MakeRequestHeadersPacket( packet_number++, client_data_stream_id1_, /*fin=*/false, ConvertRequestPriorityToQuicPriority(LOWEST), std::move(request_header_block), nullptr)); spdy::Http2HeaderBlock response_header_block = WebSocketHttp2Response({}); mock_quic_data_.AddRead( ASYNC, server_maker_.MakeResponseHeadersPacket( /*packet_number=*/1, client_data_stream_id1_, /*fin=*/false, std::move(response_header_block), /*spdy_headers_frame_length=*/nullptr)); mock_quic_data_.AddRead(ASYNC, ERR_IO_PENDING); // First read is the same size as the buffer, next is smaller, last is larger. mock_quic_data_.AddRead(ASYNC, ConstructServerDataPacket(2, "abc")); mock_quic_data_.AddRead(SYNCHRONOUS, ConstructServerDataPacket(3, "12")); mock_quic_data_.AddRead(SYNCHRONOUS, ConstructServerDataPacket(4, "ABCD")); mock_quic_data_.AddRead(SYNCHRONOUS, ERR_IO_PENDING); mock_quic_data_.AddWrite(ASYNC, ConstructClientAckPacket(packet_number++, 2, 0)); mock_quic_data_.AddWrite( SYNCHRONOUS, ConstructAckAndRstPacket(packet_number++, quic::QUIC_STREAM_CANCELLED, 4, 0)); base::RunLoop run_loop; EXPECT_CALL(mock_delegate_, OnHeadersReceived(_)).WillOnce(Invoke([&]() { run_loop.Quit(); })); Initialize(); net::QuicChromiumClientSession::Handle* session_handle = GetQuicSessionHandle(); ASSERT_TRUE(session_handle); TestWebSocketQuicStreamAdapterCompletionCallback callback; std::unique_ptr adapter = session_handle->CreateWebSocketQuicStreamAdapter( &mock_delegate_, callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS); ASSERT_TRUE(adapter); EXPECT_TRUE(adapter->is_initialized()); adapter->WriteHeaders(RequestHeaders(), false); session_->StartReading(); run_loop.Run(); const int kReadBufSize = 3; auto read_buf = base::MakeRefCounted(kReadBufSize); TestCompletionCallback read_callback; int rv = adapter->Read(read_buf.get(), kReadBufSize, read_callback.callback()); ASSERT_EQ(ERR_IO_PENDING, rv); mock_quic_data_.GetSequencedSocketData()->Resume(); base::RunLoop().RunUntilIdle(); rv = read_callback.WaitForResult(); ASSERT_EQ(3, rv); EXPECT_EQ("abc", base::StringPiece(read_buf->data(), rv)); rv = adapter->Read(read_buf.get(), kReadBufSize, CompletionOnceCallback()); ASSERT_EQ(3, rv); EXPECT_EQ("12A", base::StringPiece(read_buf->data(), rv)); rv = adapter->Read(read_buf.get(), kReadBufSize, CompletionOnceCallback()); ASSERT_EQ(3, rv); EXPECT_EQ("BCD", base::StringPiece(read_buf->data(), rv)); adapter->Disconnect(); EXPECT_TRUE(mock_quic_data_.AllReadDataConsumed()); EXPECT_TRUE(mock_quic_data_.AllWriteDataConsumed()); } } // namespace net::test