aboutsummaryrefslogtreecommitdiff
path: root/src/net
diff options
context:
space:
mode:
Diffstat (limited to 'src/net')
-rw-r--r--src/net/cgo_stub.go2
-rw-r--r--src/net/cgo_unix.go17
-rw-r--r--src/net/cgo_unix_cgo.go2
-rw-r--r--src/net/cgo_unix_syscall.go5
-rw-r--r--src/net/conf.go36
-rw-r--r--src/net/conn_test.go78
-rw-r--r--src/net/dial.go32
-rw-r--r--src/net/dial_test.go61
-rw-r--r--src/net/dnsclient.go6
-rw-r--r--src/net/dnsclient_unix.go67
-rw-r--r--src/net/dnsclient_unix_test.go97
-rw-r--r--src/net/dnsconfig_unix.go2
-rw-r--r--src/net/dnsname_test.go2
-rw-r--r--src/net/error_plan9_test.go1
-rw-r--r--src/net/error_posix.go2
-rw-r--r--src/net/error_test.go272
-rw-r--r--src/net/error_unix_test.go1
-rw-r--r--src/net/error_windows_test.go1
-rw-r--r--src/net/external_test.go2
-rw-r--r--src/net/fd_fake.go169
-rw-r--r--src/net/fd_js.go28
-rw-r--r--src/net/fd_wasip1.go159
-rw-r--r--src/net/fd_windows.go48
-rw-r--r--src/net/file_stub.go2
-rw-r--r--src/net/file_test.go10
-rw-r--r--src/net/hook.go9
-rw-r--r--src/net/hook_plan9.go6
-rw-r--r--src/net/hook_unix.go5
-rw-r--r--src/net/hook_windows.go4
-rw-r--r--src/net/hosts.go2
-rw-r--r--src/net/hosts_test.go26
-rw-r--r--src/net/http/cgi/cgi_main.go145
-rw-r--r--src/net/http/cgi/child.go6
-rw-r--r--src/net/http/cgi/host.go14
-rw-r--r--src/net/http/cgi/host_test.go243
-rw-r--r--src/net/http/cgi/integration_test.go67
-rw-r--r--src/net/http/cgi/plan9_test.go17
-rw-r--r--src/net/http/cgi/posix_test.go20
-rwxr-xr-xsrc/net/http/cgi/testdata/test.cgi95
-rw-r--r--src/net/http/client.go120
-rw-r--r--src/net/http/client_test.go51
-rw-r--r--src/net/http/clientserver_test.go6
-rw-r--r--src/net/http/cookie.go4
-rw-r--r--src/net/http/cookiejar/jar.go14
-rw-r--r--src/net/http/cookiejar/jar_test.go10
-rw-r--r--src/net/http/doc.go16
-rw-r--r--src/net/http/export_test.go18
-rw-r--r--src/net/http/fcgi/child.go2
-rw-r--r--src/net/http/filetransport.go25
-rw-r--r--src/net/http/filetransport_test.go42
-rw-r--r--src/net/http/fs.go117
-rw-r--r--src/net/http/fs_test.go113
-rw-r--r--src/net/http/h2_bundle.go205
-rw-r--r--src/net/http/h2_error.go1
-rw-r--r--src/net/http/h2_error_test.go1
-rw-r--r--src/net/http/header.go14
-rw-r--r--src/net/http/http.go6
-rw-r--r--src/net/http/http_test.go7
-rw-r--r--src/net/http/httptest/httptest.go2
-rw-r--r--src/net/http/httptest/recorder.go18
-rw-r--r--src/net/http/httptest/server.go12
-rw-r--r--src/net/http/httptrace/trace.go4
-rw-r--r--src/net/http/httputil/dump.go10
-rw-r--r--src/net/http/httputil/httputil.go2
-rw-r--r--src/net/http/httputil/persist.go32
-rw-r--r--src/net/http/httputil/reverseproxy.go15
-rw-r--r--src/net/http/internal/ascii/print.go2
-rw-r--r--src/net/http/internal/chunked.go45
-rw-r--r--src/net/http/internal/chunked_test.go60
-rw-r--r--src/net/http/mapping.go78
-rw-r--r--src/net/http/mapping_test.go154
-rw-r--r--src/net/http/pattern.go529
-rw-r--r--src/net/http/pattern_test.go494
-rw-r--r--src/net/http/pprof/pprof.go4
-rw-r--r--src/net/http/pprof/pprof_test.go63
-rw-r--r--src/net/http/pprof/testdata/delta_mutex.go43
-rw-r--r--src/net/http/request.go152
-rw-r--r--src/net/http/request_test.go186
-rw-r--r--src/net/http/response.go10
-rw-r--r--src/net/http/response_test.go5
-rw-r--r--src/net/http/responsecontroller.go12
-rw-r--r--src/net/http/responsecontroller_test.go4
-rw-r--r--src/net/http/roundtrip.go4
-rw-r--r--src/net/http/roundtrip_js.go14
-rw-r--r--src/net/http/routing_index.go124
-rw-r--r--src/net/http/routing_index_test.go179
-rw-r--r--src/net/http/routing_tree.go240
-rw-r--r--src/net/http/routing_tree_test.go295
-rw-r--r--src/net/http/serve_test.go664
-rw-r--r--src/net/http/servemux121.go211
-rw-r--r--src/net/http/server.go718
-rw-r--r--src/net/http/server_test.go192
-rw-r--r--src/net/http/transfer.go49
-rw-r--r--src/net/http/transfer_test.go16
-rw-r--r--src/net/http/transport.go67
-rw-r--r--src/net/http/transport_test.go325
-rw-r--r--src/net/http/triv.go2
-rw-r--r--src/net/interface.go4
-rw-r--r--src/net/interface_stub.go2
-rw-r--r--src/net/interface_test.go2
-rw-r--r--src/net/internal/socktest/main_test.go2
-rw-r--r--src/net/internal/socktest/main_windows_test.go22
-rw-r--r--src/net/internal/socktest/switch.go2
-rw-r--r--src/net/internal/socktest/sys_unix.go2
-rw-r--r--src/net/internal/socktest/sys_windows.go44
-rw-r--r--src/net/ip.go14
-rw-r--r--src/net/ip_test.go2
-rw-r--r--src/net/iprawsock.go18
-rw-r--r--src/net/iprawsock_posix.go2
-rw-r--r--src/net/iprawsock_test.go2
-rw-r--r--src/net/ipsock.go8
-rw-r--r--src/net/ipsock_plan9.go1
-rw-r--r--src/net/ipsock_posix.go18
-rw-r--r--src/net/listen_test.go2
-rw-r--r--src/net/lookup.go88
-rw-r--r--src/net/lookup_fake.go58
-rw-r--r--src/net/lookup_plan9.go105
-rw-r--r--src/net/lookup_test.go188
-rw-r--r--src/net/lookup_unix.go30
-rw-r--r--src/net/lookup_windows.go94
-rw-r--r--src/net/mail/message.go30
-rw-r--r--src/net/mail/message_test.go40
-rw-r--r--src/net/main_conf_test.go2
-rw-r--r--src/net/main_noconf_test.go2
-rw-r--r--src/net/main_posix_test.go2
-rw-r--r--src/net/main_test.go17
-rw-r--r--src/net/main_wasm_test.go13
-rw-r--r--src/net/main_windows_test.go3
-rw-r--r--src/net/mockserver_test.go4
-rw-r--r--src/net/net.go72
-rw-r--r--src/net/net_fake.go1280
-rw-r--r--src/net/net_fake_js.go36
-rw-r--r--src/net/net_fake_test.go246
-rw-r--r--src/net/net_test.go24
-rw-r--r--src/net/netip/export_test.go2
-rw-r--r--src/net/netip/leaf_alts.go9
-rw-r--r--src/net/netip/netip.go177
-rw-r--r--src/net/netip/netip_test.go125
-rw-r--r--src/net/packetconn_test.go6
-rw-r--r--src/net/parse.go36
-rw-r--r--src/net/pipe.go2
-rw-r--r--src/net/platform_test.go2
-rw-r--r--src/net/port_unix.go2
-rw-r--r--src/net/protoconn_test.go14
-rw-r--r--src/net/rawconn.go21
-rw-r--r--src/net/rawconn_stub_test.go2
-rw-r--r--src/net/rawconn_test.go6
-rw-r--r--src/net/resolverdialfunc_test.go2
-rw-r--r--src/net/rlimit_js.go13
-rw-r--r--src/net/rlimit_unix.go33
-rw-r--r--src/net/rpc/client.go16
-rw-r--r--src/net/rpc/jsonrpc/client.go4
-rw-r--r--src/net/rpc/jsonrpc/server.go2
-rw-r--r--src/net/rpc/server.go67
-rw-r--r--src/net/sendfile_linux_test.go25
-rw-r--r--src/net/sendfile_stub.go2
-rw-r--r--src/net/sendfile_test.go90
-rw-r--r--src/net/sendfile_unix_alt.go10
-rw-r--r--src/net/server_test.go207
-rw-r--r--src/net/smtp/auth.go4
-rw-r--r--src/net/smtp/smtp.go14
-rw-r--r--src/net/sock_posix.go41
-rw-r--r--src/net/sock_stub.go2
-rw-r--r--src/net/sock_windows.go20
-rw-r--r--src/net/sockaddr_posix.go26
-rw-r--r--src/net/sockopt_fake.go (renamed from src/net/sockopt_stub.go)11
-rw-r--r--src/net/sockopt_posix.go29
-rw-r--r--src/net/sockoptip_stub.go2
-rw-r--r--src/net/sockoptip_windows.go3
-rw-r--r--src/net/splice_linux.go40
-rw-r--r--src/net/splice_stub.go6
-rw-r--r--src/net/splice_test.go21
-rw-r--r--src/net/tcpsock.go49
-rw-r--r--src/net/tcpsock_plan9.go4
-rw-r--r--src/net/tcpsock_posix.go11
-rw-r--r--src/net/tcpsock_test.go54
-rw-r--r--src/net/tcpsock_unix_test.go2
-rw-r--r--src/net/tcpsockopt_stub.go2
-rw-r--r--src/net/textproto/header.go4
-rw-r--r--src/net/textproto/reader.go68
-rw-r--r--src/net/textproto/reader_test.go12
-rw-r--r--src/net/textproto/textproto.go20
-rw-r--r--src/net/textproto/writer.go4
-rw-r--r--src/net/timeout_test.go202
-rw-r--r--src/net/udpsock.go2
-rw-r--r--src/net/udpsock_posix.go2
-rw-r--r--src/net/udpsock_test.go67
-rw-r--r--src/net/unixsock.go39
-rw-r--r--src/net/unixsock_posix.go2
-rw-r--r--src/net/unixsock_readmsg_other.go2
-rw-r--r--src/net/unixsock_test.go11
-rw-r--r--src/net/url/url.go55
-rw-r--r--src/net/url/url_test.go1
-rw-r--r--src/net/writev_test.go10
194 files changed, 8413 insertions, 3354 deletions
diff --git a/src/net/cgo_stub.go b/src/net/cgo_stub.go
index b26b11af8b..a4f6b4b0e8 100644
--- a/src/net/cgo_stub.go
+++ b/src/net/cgo_stub.go
@@ -9,7 +9,7 @@
// (Darwin always provides the cgo functions, in cgo_unix_syscall.go)
// - on wasip1, where cgo is never available
-//go:build (netgo && unix) || (unix && !cgo && !darwin) || wasip1
+//go:build (netgo && unix) || (unix && !cgo && !darwin) || js || wasip1
package net
diff --git a/src/net/cgo_unix.go b/src/net/cgo_unix.go
index f10f3ea60b..7ed5daad73 100644
--- a/src/net/cgo_unix.go
+++ b/src/net/cgo_unix.go
@@ -80,7 +80,7 @@ func cgoLookupHost(ctx context.Context, name string) (hosts []string, err error)
func cgoLookupPort(ctx context.Context, network, service string) (port int, err error) {
var hints _C_struct_addrinfo
switch network {
- case "": // no hints
+ case "ip": // no hints
case "tcp", "tcp4", "tcp6":
*_C_ai_socktype(&hints) = _C_SOCK_STREAM
*_C_ai_protocol(&hints) = _C_IPPROTO_TCP
@@ -120,6 +120,8 @@ func cgoLookupServicePort(hints *_C_struct_addrinfo, network, service string) (p
if err == nil { // see golang.org/issue/6232
err = syscall.EMFILE
}
+ case _C_EAI_SERVICE, _C_EAI_NONAME: // Darwin returns EAI_NONAME.
+ return 0, &DNSError{Err: "unknown port", Name: network + "/" + service, IsNotFound: true}
default:
err = addrinfoErrno(gerrno)
isTemporary = addrinfoErrno(gerrno).Temporary()
@@ -140,7 +142,7 @@ func cgoLookupServicePort(hints *_C_struct_addrinfo, network, service string) (p
return int(p[0])<<8 | int(p[1]), nil
}
}
- return 0, &DNSError{Err: "unknown port", Name: network + "/" + service}
+ return 0, &DNSError{Err: "unknown port", Name: network + "/" + service, IsNotFound: true}
}
func cgoLookupHostIP(network, name string) (addrs []IPAddr, err error) {
@@ -317,8 +319,15 @@ func cgoResSearch(hostname string, rtype, class int) ([]dnsmessage.Resource, err
acquireThread()
defer releaseThread()
- state := (*_C_struct___res_state)(_C_malloc(unsafe.Sizeof(_C_struct___res_state{})))
- defer _C_free(unsafe.Pointer(state))
+ resStateSize := unsafe.Sizeof(_C_struct___res_state{})
+ var state *_C_struct___res_state
+ if resStateSize > 0 {
+ mem := _C_malloc(resStateSize)
+ defer _C_free(mem)
+ memSlice := unsafe.Slice((*byte)(mem), resStateSize)
+ clear(memSlice)
+ state = (*_C_struct___res_state)(unsafe.Pointer(&memSlice[0]))
+ }
if err := _C_res_ninit(state); err != nil {
return nil, errors.New("res_ninit failure: " + err.Error())
}
diff --git a/src/net/cgo_unix_cgo.go b/src/net/cgo_unix_cgo.go
index d11f3e301a..7c609eddbf 100644
--- a/src/net/cgo_unix_cgo.go
+++ b/src/net/cgo_unix_cgo.go
@@ -37,6 +37,7 @@ const (
_C_EAI_AGAIN = C.EAI_AGAIN
_C_EAI_NODATA = C.EAI_NODATA
_C_EAI_NONAME = C.EAI_NONAME
+ _C_EAI_SERVICE = C.EAI_SERVICE
_C_EAI_OVERFLOW = C.EAI_OVERFLOW
_C_EAI_SYSTEM = C.EAI_SYSTEM
_C_IPPROTO_TCP = C.IPPROTO_TCP
@@ -55,7 +56,6 @@ type (
_C_struct_sockaddr = C.struct_sockaddr
)
-func _C_GoString(p *_C_char) string { return C.GoString(p) }
func _C_malloc(n uintptr) unsafe.Pointer { return C.malloc(C.size_t(n)) }
func _C_free(p unsafe.Pointer) { C.free(p) }
diff --git a/src/net/cgo_unix_syscall.go b/src/net/cgo_unix_syscall.go
index 2eb8df1da6..ac9aaa78fe 100644
--- a/src/net/cgo_unix_syscall.go
+++ b/src/net/cgo_unix_syscall.go
@@ -19,6 +19,7 @@ const (
_C_AF_UNSPEC = syscall.AF_UNSPEC
_C_EAI_AGAIN = unix.EAI_AGAIN
_C_EAI_NONAME = unix.EAI_NONAME
+ _C_EAI_SERVICE = unix.EAI_SERVICE
_C_EAI_NODATA = unix.EAI_NODATA
_C_EAI_OVERFLOW = unix.EAI_OVERFLOW
_C_EAI_SYSTEM = unix.EAI_SYSTEM
@@ -39,10 +40,6 @@ type (
_C_struct_sockaddr = syscall.RawSockaddr
)
-func _C_GoString(p *_C_char) string {
- return unix.GoString(p)
-}
-
func _C_free(p unsafe.Pointer) { runtime.KeepAlive(p) }
func _C_malloc(n uintptr) unsafe.Pointer {
diff --git a/src/net/conf.go b/src/net/conf.go
index 77cc635592..15d73cf6ce 100644
--- a/src/net/conf.go
+++ b/src/net/conf.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js
-
package net
import (
@@ -153,7 +151,7 @@ func initConfVal() {
}
}
-// goosPreferCgo reports whether the GOOS value passed in prefers
+// goosPrefersCgo reports whether the GOOS value passed in prefers
// the cgo resolver.
func goosPrefersCgo() bool {
switch runtime.GOOS {
@@ -185,7 +183,24 @@ func goosPrefersCgo() bool {
// required to use the go resolver. The provided Resolver is optional.
// This will report true if the cgo resolver is not available.
func (c *conf) mustUseGoResolver(r *Resolver) bool {
- return c.netGo || r.preferGo() || !cgoAvailable
+ if !cgoAvailable {
+ return true
+ }
+
+ if runtime.GOOS == "plan9" {
+ // TODO(bradfitz): for now we only permit use of the PreferGo
+ // implementation when there's a non-nil Resolver with a
+ // non-nil Dialer. This is a sign that they the code is trying
+ // to use their DNS-speaking net.Conn (such as an in-memory
+ // DNS cache) and they don't want to actually hit the network.
+ // Once we add support for looking the default DNS servers
+ // from plan9, though, then we can relax this.
+ if r == nil || r.Dial == nil {
+ return false
+ }
+ }
+
+ return c.netGo || r.preferGo()
}
// addrLookupOrder determines which strategy to use to resolve addresses.
@@ -221,16 +236,7 @@ func (c *conf) lookupOrder(r *Resolver, hostname string) (ret hostLookupOrder, d
// Go resolver was explicitly requested
// or cgo resolver is not available.
// Figure out the order below.
- switch c.goos {
- case "windows":
- // TODO(bradfitz): implement files-based
- // lookup on Windows too? I guess /etc/hosts
- // kinda exists on Windows. But for now, only
- // do DNS.
- fallbackOrder = hostLookupDNS
- default:
- fallbackOrder = hostLookupFilesDNS
- }
+ fallbackOrder = hostLookupFilesDNS
canUseCgo = false
} else if c.netCgo {
// Cgo resolver was explicitly requested.
@@ -516,7 +522,7 @@ func isGateway(h string) bool {
return stringsEqualFold(h, "_gateway")
}
-// isOutbound reports whether h should be considered a "outbound"
+// isOutbound reports whether h should be considered an "outbound"
// name for the myhostname NSS module.
func isOutbound(h string) bool {
return stringsEqualFold(h, "_outbound")
diff --git a/src/net/conn_test.go b/src/net/conn_test.go
index 4f391b0675..d1e1e7bf1c 100644
--- a/src/net/conn_test.go
+++ b/src/net/conn_test.go
@@ -2,10 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// This file implements API tests across platforms and will never have a build
-// tag.
-
-//go:build !js && !wasip1
+// This file implements API tests across platforms and should never have a build
+// constraint.
package net
@@ -21,44 +19,46 @@ const someTimeout = 1 * time.Hour
func TestConnAndListener(t *testing.T) {
for i, network := range []string{"tcp", "unix", "unixpacket"} {
- if !testableNetwork(network) {
- t.Logf("skipping %s test", network)
- continue
- }
+ i, network := i, network
+ t.Run(network, func(t *testing.T) {
+ if !testableNetwork(network) {
+ t.Skipf("skipping %s test", network)
+ }
- ls := newLocalServer(t, network)
- defer ls.teardown()
- ch := make(chan error, 1)
- handler := func(ls *localServer, ln Listener) { ls.transponder(ln, ch) }
- if err := ls.buildup(handler); err != nil {
- t.Fatal(err)
- }
- if ls.Listener.Addr().Network() != network {
- t.Fatalf("got %s; want %s", ls.Listener.Addr().Network(), network)
- }
+ ls := newLocalServer(t, network)
+ defer ls.teardown()
+ ch := make(chan error, 1)
+ handler := func(ls *localServer, ln Listener) { ls.transponder(ln, ch) }
+ if err := ls.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+ if ls.Listener.Addr().Network() != network {
+ t.Fatalf("got %s; want %s", ls.Listener.Addr().Network(), network)
+ }
- c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
- if err != nil {
- t.Fatal(err)
- }
- defer c.Close()
- if c.LocalAddr().Network() != network || c.RemoteAddr().Network() != network {
- t.Fatalf("got %s->%s; want %s->%s", c.LocalAddr().Network(), c.RemoteAddr().Network(), network, network)
- }
- c.SetDeadline(time.Now().Add(someTimeout))
- c.SetReadDeadline(time.Now().Add(someTimeout))
- c.SetWriteDeadline(time.Now().Add(someTimeout))
+ c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ if c.LocalAddr().Network() != network || c.RemoteAddr().Network() != network {
+ t.Fatalf("got %s->%s; want %s->%s", c.LocalAddr().Network(), c.RemoteAddr().Network(), network, network)
+ }
+ c.SetDeadline(time.Now().Add(someTimeout))
+ c.SetReadDeadline(time.Now().Add(someTimeout))
+ c.SetWriteDeadline(time.Now().Add(someTimeout))
- if _, err := c.Write([]byte("CONN AND LISTENER TEST")); err != nil {
- t.Fatal(err)
- }
- rb := make([]byte, 128)
- if _, err := c.Read(rb); err != nil {
- t.Fatal(err)
- }
+ if _, err := c.Write([]byte("CONN AND LISTENER TEST")); err != nil {
+ t.Fatal(err)
+ }
+ rb := make([]byte, 128)
+ if _, err := c.Read(rb); err != nil {
+ t.Fatal(err)
+ }
- for err := range ch {
- t.Errorf("#%d: %v", i, err)
- }
+ for err := range ch {
+ t.Errorf("#%d: %v", i, err)
+ }
+ })
}
}
diff --git a/src/net/dial.go b/src/net/dial.go
index 79bc4958bb..a6565c3ce5 100644
--- a/src/net/dial.go
+++ b/src/net/dial.go
@@ -6,6 +6,7 @@ package net
import (
"context"
+ "internal/bytealg"
"internal/godebug"
"internal/nettrace"
"syscall"
@@ -64,7 +65,7 @@ func (m *mptcpStatus) set(use bool) {
//
// The zero value for each field is equivalent to dialing
// without that option. Dialing with the zero value of Dialer
-// is therefore equivalent to just calling the Dial function.
+// is therefore equivalent to just calling the [Dial] function.
//
// It is safe to call Dialer's methods concurrently.
type Dialer struct {
@@ -226,7 +227,7 @@ func (d *Dialer) fallbackDelay() time.Duration {
}
func parseNetwork(ctx context.Context, network string, needsProto bool) (afnet string, proto int, err error) {
- i := last(network, ':')
+ i := bytealg.LastIndexByteString(network, ':')
if i < 0 { // no colon
switch network {
case "tcp", "tcp4", "tcp6":
@@ -337,7 +338,7 @@ func (d *Dialer) MultipathTCP() bool {
return d.mptcpStatus.get()
}
-// SetMultipathTCP directs the Dial methods to use, or not use, MPTCP,
+// SetMultipathTCP directs the [Dial] methods to use, or not use, MPTCP,
// if supported by the operating system. This method overrides the
// system default and the GODEBUG=multipathtcp=... setting if any.
//
@@ -362,7 +363,7 @@ func (d *Dialer) SetMultipathTCP(use bool) {
// brackets, as in "[2001:db8::1]:80" or "[fe80::1%zone]:80".
// The zone specifies the scope of the literal IPv6 address as defined
// in RFC 4007.
-// The functions JoinHostPort and SplitHostPort manipulate a pair of
+// The functions [JoinHostPort] and [SplitHostPort] manipulate a pair of
// host and port in this form.
// When using TCP, and the host resolves to multiple IP addresses,
// Dial will try each IP address in order until one succeeds.
@@ -400,7 +401,7 @@ func Dial(network, address string) (Conn, error) {
return d.Dial(network, address)
}
-// DialTimeout acts like Dial but takes a timeout.
+// DialTimeout acts like [Dial] but takes a timeout.
//
// The timeout includes name resolution, if required.
// When using TCP, and the host in the address parameter resolves to
@@ -427,8 +428,8 @@ type sysDialer struct {
// See func Dial for a description of the network and address
// parameters.
//
-// Dial uses context.Background internally; to specify the context, use
-// DialContext.
+// Dial uses [context.Background] internally; to specify the context, use
+// [Dialer.DialContext].
func (d *Dialer) Dial(network, address string) (Conn, error) {
return d.DialContext(context.Background(), network, address)
}
@@ -449,7 +450,7 @@ func (d *Dialer) Dial(network, address string) (Conn, error) {
// the connect to each single address will be given 15 seconds to complete
// before trying the next one.
//
-// See func Dial for a description of the network and address
+// See func [Dial] for a description of the network and address
// parameters.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) {
if ctx == nil {
@@ -457,6 +458,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn
}
deadline := d.deadline(ctx, time.Now())
if !deadline.IsZero() {
+ testHookStepTime()
if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
subCtx, cancel := context.WithDeadline(ctx, deadline)
defer cancel()
@@ -698,7 +700,7 @@ func (lc *ListenConfig) MultipathTCP() bool {
return lc.mptcpStatus.get()
}
-// SetMultipathTCP directs the Listen method to use, or not use, MPTCP,
+// SetMultipathTCP directs the [Listen] method to use, or not use, MPTCP,
// if supported by the operating system. This method overrides the
// system default and the GODEBUG=multipathtcp=... setting if any.
//
@@ -793,14 +795,14 @@ type sysListener struct {
// addresses.
// If the port in the address parameter is empty or "0", as in
// "127.0.0.1:" or "[::1]:0", a port number is automatically chosen.
-// The Addr method of Listener can be used to discover the chosen
+// The [Addr] method of [Listener] can be used to discover the chosen
// port.
//
-// See func Dial for a description of the network and address
+// See func [Dial] for a description of the network and address
// parameters.
//
// Listen uses context.Background internally; to specify the context, use
-// ListenConfig.Listen.
+// [ListenConfig.Listen].
func Listen(network, address string) (Listener, error) {
var lc ListenConfig
return lc.Listen(context.Background(), network, address)
@@ -823,14 +825,14 @@ func Listen(network, address string) (Listener, error) {
// addresses.
// If the port in the address parameter is empty or "0", as in
// "127.0.0.1:" or "[::1]:0", a port number is automatically chosen.
-// The LocalAddr method of PacketConn can be used to discover the
+// The LocalAddr method of [PacketConn] can be used to discover the
// chosen port.
//
-// See func Dial for a description of the network and address
+// See func [Dial] for a description of the network and address
// parameters.
//
// ListenPacket uses context.Background internally; to specify the context, use
-// ListenConfig.ListenPacket.
+// [ListenConfig.ListenPacket].
func ListenPacket(network, address string) (PacketConn, error) {
var lc ListenConfig
return lc.ListenPacket(context.Background(), network, address)
diff --git a/src/net/dial_test.go b/src/net/dial_test.go
index ca9f0da3d3..1d0832e46e 100644
--- a/src/net/dial_test.go
+++ b/src/net/dial_test.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !wasip1
-
package net
import (
@@ -784,6 +782,7 @@ func TestDialCancel(t *testing.T) {
"connection refused",
"unreachable",
"no route to host",
+ "invalid argument",
}
e := err.Error()
for _, ignore := range ignorable {
@@ -982,6 +981,8 @@ func TestDialerControl(t *testing.T) {
switch runtime.GOOS {
case "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
+ case "js", "wasip1":
+ t.Skipf("skipping: fake net does not support Dialer.Control")
}
t.Run("StreamDial", func(t *testing.T) {
@@ -1025,40 +1026,52 @@ func TestDialerControlContext(t *testing.T) {
switch runtime.GOOS {
case "plan9":
t.Skipf("%s does not have full support of socktest", runtime.GOOS)
+ case "js", "wasip1":
+ t.Skipf("skipping: fake net does not support Dialer.ControlContext")
}
t.Run("StreamDial", func(t *testing.T) {
for i, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} {
- if !testableNetwork(network) {
- continue
- }
- ln := newLocalListener(t, network)
- defer ln.Close()
- var id int
- d := Dialer{ControlContext: func(ctx context.Context, network string, address string, c syscall.RawConn) error {
- id = ctx.Value("id").(int)
- return controlOnConnSetup(network, address, c)
- }}
- c, err := d.DialContext(context.WithValue(context.Background(), "id", i+1), network, ln.Addr().String())
- if err != nil {
- t.Error(err)
- continue
- }
- if id != i+1 {
- t.Errorf("got id %d, want %d", id, i+1)
- }
- c.Close()
+ t.Run(network, func(t *testing.T) {
+ if !testableNetwork(network) {
+ t.Skipf("skipping: %s not available", network)
+ }
+
+ ln := newLocalListener(t, network)
+ defer ln.Close()
+ var id int
+ d := Dialer{ControlContext: func(ctx context.Context, network string, address string, c syscall.RawConn) error {
+ id = ctx.Value("id").(int)
+ return controlOnConnSetup(network, address, c)
+ }}
+ c, err := d.DialContext(context.WithValue(context.Background(), "id", i+1), network, ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ if id != i+1 {
+ t.Errorf("got id %d, want %d", id, i+1)
+ }
+ c.Close()
+ })
}
})
}
// mustHaveExternalNetwork is like testenv.MustHaveExternalNetwork
-// except that it won't skip testing on non-mobile builders.
+// except on non-Linux, non-mobile builders it permits the test to
+// run in -short mode.
func mustHaveExternalNetwork(t *testing.T) {
t.Helper()
+ definitelyHasLongtestBuilder := runtime.GOOS == "linux"
mobile := runtime.GOOS == "android" || runtime.GOOS == "ios"
- if testenv.Builder() == "" || mobile {
- testenv.MustHaveExternalNetwork(t)
+ fake := runtime.GOOS == "js" || runtime.GOOS == "wasip1"
+ if testenv.Builder() != "" && !definitelyHasLongtestBuilder && !mobile && !fake {
+ // On a non-Linux, non-mobile builder (e.g., freebsd-amd64-13_0).
+ //
+ // Don't skip testing because otherwise the test may never run on
+ // any builder if this port doesn't also have a -longtest builder.
+ return
}
+ testenv.MustHaveExternalNetwork(t)
}
type contextWithNonZeroDeadline struct {
diff --git a/src/net/dnsclient.go b/src/net/dnsclient.go
index b609dbd468..204620b2ed 100644
--- a/src/net/dnsclient.go
+++ b/src/net/dnsclient.go
@@ -8,15 +8,17 @@ import (
"internal/bytealg"
"internal/itoa"
"sort"
+ _ "unsafe" // for go:linkname
"golang.org/x/net/dns/dnsmessage"
)
// provided by runtime
-func fastrandu() uint
+//go:linkname runtime_rand runtime.rand
+func runtime_rand() uint64
func randInt() int {
- return int(fastrandu() >> 1) // clear sign bit
+ return int(uint(runtime_rand()) >> 1) // clear sign bit
}
func randIntn(n int) int {
diff --git a/src/net/dnsclient_unix.go b/src/net/dnsclient_unix.go
index dab5144e5d..c291d5eb4f 100644
--- a/src/net/dnsclient_unix.go
+++ b/src/net/dnsclient_unix.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js
-
// DNS client: see RFC 1035.
// Has to be linked into package net for Dial.
@@ -17,6 +15,7 @@ package net
import (
"context"
"errors"
+ "internal/bytealg"
"internal/itoa"
"io"
"os"
@@ -205,7 +204,9 @@ func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Que
// checkHeader performs basic sanity checks on the header.
func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error {
- if h.RCode == dnsmessage.RCodeNameError {
+ rcode := extractExtendedRCode(*p, h)
+
+ if rcode == dnsmessage.RCodeNameError {
return errNoSuchHost
}
@@ -216,17 +217,17 @@ func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error {
// libresolv continues to the next server when it receives
// an invalid referral response. See golang.org/issue/15434.
- if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone {
+ if rcode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone {
return errLameReferral
}
- if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError {
+ if rcode != dnsmessage.RCodeSuccess && rcode != dnsmessage.RCodeNameError {
// None of the error codes make sense
// for the query we sent. If we didn't get
// a name error and we didn't get success,
// the server is behaving incorrectly or
// having temporary trouble.
- if h.RCode == dnsmessage.RCodeServerFailure {
+ if rcode == dnsmessage.RCodeServerFailure {
return errServerTemporarilyMisbehaving
}
return errServerMisbehaving
@@ -253,6 +254,23 @@ func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error {
}
}
+// extractExtendedRCode extracts the extended RCode from the OPT resource (EDNS(0))
+// If an OPT record is not found, the RCode from the hdr is returned.
+func extractExtendedRCode(p dnsmessage.Parser, hdr dnsmessage.Header) dnsmessage.RCode {
+ p.SkipAllAnswers()
+ p.SkipAllAuthorities()
+ for {
+ ahdr, err := p.AdditionalHeader()
+ if err != nil {
+ return hdr.RCode
+ }
+ if ahdr.Type == dnsmessage.TypeOPT {
+ return ahdr.ExtendedRCode(hdr.RCode)
+ }
+ p.SkipAdditional()
+ }
+}
+
// Do a lookup for a single name, which must be rooted
// (otherwise answer will not find the answers).
func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
@@ -479,10 +497,6 @@ func avoidDNS(name string) bool {
// nameList returns a list of names for sequential DNS queries.
func (conf *dnsConfig) nameList(name string) []string {
- if avoidDNS(name) {
- return nil
- }
-
// Check name length (see isDomainName).
l := len(name)
rooted := l > 0 && name[l-1] == '.'
@@ -492,27 +506,31 @@ func (conf *dnsConfig) nameList(name string) []string {
// If name is rooted (trailing dot), try only that name.
if rooted {
+ if avoidDNS(name) {
+ return nil
+ }
return []string{name}
}
- hasNdots := count(name, '.') >= conf.ndots
+ hasNdots := bytealg.CountString(name, '.') >= conf.ndots
name += "."
l++
// Build list of search choices.
names := make([]string, 0, 1+len(conf.search))
// If name has enough dots, try unsuffixed first.
- if hasNdots {
+ if hasNdots && !avoidDNS(name) {
names = append(names, name)
}
// Try suffixes that are not too long (see isDomainName).
for _, suffix := range conf.search {
- if l+len(suffix) <= 254 {
- names = append(names, name+suffix)
+ fqdn := name + suffix
+ if !avoidDNS(fqdn) && len(fqdn) <= 254 {
+ names = append(names, fqdn)
}
}
// Try unsuffixed, if not tried first above.
- if !hasNdots {
+ if !hasNdots && !avoidDNS(name) {
names = append(names, name)
}
return names
@@ -586,8 +604,7 @@ func goLookupIPFiles(name string) (addrs []IPAddr, canonical string) {
// goLookupIP is the native Go implementation of LookupIP.
// The libc versions are in cgo_*.go.
-func (r *Resolver) goLookupIP(ctx context.Context, network, host string) (addrs []IPAddr, err error) {
- order, conf := systemConf().hostLookupOrder(r, host)
+func (r *Resolver) goLookupIP(ctx context.Context, network, host string, order hostLookupOrder, conf *dnsConfig) (addrs []IPAddr, err error) {
addrs, _, err = r.goLookupIPCNAMEOrder(ctx, network, host, order, conf)
return
}
@@ -699,7 +716,7 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, network, name strin
h, err := result.p.AnswerHeader()
if err != nil && err != dnsmessage.ErrSectionDone {
lastErr = &DNSError{
- Err: "cannot marshal DNS message",
+ Err: errCannotUnmarshalDNSMessage.Error(),
Name: name,
Server: result.server,
}
@@ -712,7 +729,7 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, network, name strin
a, err := result.p.AResource()
if err != nil {
lastErr = &DNSError{
- Err: "cannot marshal DNS message",
+ Err: errCannotUnmarshalDNSMessage.Error(),
Name: name,
Server: result.server,
}
@@ -727,7 +744,7 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, network, name strin
aaaa, err := result.p.AAAAResource()
if err != nil {
lastErr = &DNSError{
- Err: "cannot marshal DNS message",
+ Err: errCannotUnmarshalDNSMessage.Error(),
Name: name,
Server: result.server,
}
@@ -742,7 +759,7 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, network, name strin
c, err := result.p.CNAMEResource()
if err != nil {
lastErr = &DNSError{
- Err: "cannot marshal DNS message",
+ Err: errCannotUnmarshalDNSMessage.Error(),
Name: name,
Server: result.server,
}
@@ -755,7 +772,7 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, network, name strin
default:
if err := result.p.SkipAnswer(); err != nil {
lastErr = &DNSError{
- Err: "cannot marshal DNS message",
+ Err: errCannotUnmarshalDNSMessage.Error(),
Name: name,
Server: result.server,
}
@@ -847,7 +864,7 @@ func (r *Resolver) goLookupPTR(ctx context.Context, addr string, order hostLooku
}
if err != nil {
return nil, &DNSError{
- Err: "cannot marshal DNS message",
+ Err: errCannotUnmarshalDNSMessage.Error(),
Name: addr,
Server: server,
}
@@ -856,7 +873,7 @@ func (r *Resolver) goLookupPTR(ctx context.Context, addr string, order hostLooku
err := p.SkipAnswer()
if err != nil {
return nil, &DNSError{
- Err: "cannot marshal DNS message",
+ Err: errCannotUnmarshalDNSMessage.Error(),
Name: addr,
Server: server,
}
@@ -866,7 +883,7 @@ func (r *Resolver) goLookupPTR(ctx context.Context, addr string, order hostLooku
ptr, err := p.PTRResource()
if err != nil {
return nil, &DNSError{
- Err: "cannot marshal DNS message",
+ Err: errCannotUnmarshalDNSMessage.Error(),
Name: addr,
Server: server,
}
diff --git a/src/net/dnsclient_unix_test.go b/src/net/dnsclient_unix_test.go
index 8d435a557f..0da36303cc 100644
--- a/src/net/dnsclient_unix_test.go
+++ b/src/net/dnsclient_unix_test.go
@@ -10,12 +10,12 @@ import (
"context"
"errors"
"fmt"
- "internal/testenv"
"os"
"path"
"path/filepath"
"reflect"
"runtime"
+ "slices"
"strings"
"sync"
"sync/atomic"
@@ -191,6 +191,19 @@ func TestAvoidDNSName(t *testing.T) {
}
}
+func TestNameListAvoidDNS(t *testing.T) {
+ c := &dnsConfig{search: []string{"go.dev.", "onion."}}
+ got := c.nameList("www")
+ if !slices.Equal(got, []string{"www.", "www.go.dev."}) {
+ t.Fatalf(`nameList("www") = %v, want "www.", "www.go.dev."`, got)
+ }
+
+ got = c.nameList("www.onion")
+ if !slices.Equal(got, []string{"www.onion.go.dev."}) {
+ t.Fatalf(`nameList("www.onion") = %v, want "www.onion.go.dev."`, got)
+ }
+}
+
var fakeDNSServerSuccessful = fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
r := dnsmessage.Message{
Header: dnsmessage.Header{
@@ -221,7 +234,7 @@ var fakeDNSServerSuccessful = fakeDNSServer{rh: func(_, _ string, q dnsmessage.M
func TestLookupTorOnion(t *testing.T) {
defer dnsWaitGroup.Wait()
r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext}
- addrs, err := r.LookupIPAddr(context.Background(), "foo.onion")
+ addrs, err := r.LookupIPAddr(context.Background(), "foo.onion.")
if err != nil {
t.Fatalf("lookup = %v; want nil", err)
}
@@ -606,8 +619,8 @@ func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
t.Fatal(err)
}
// Redirect host file lookups.
- defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
- testHookHostsPath = "testdata/hosts"
+ defer func(orig string) { hostsFilePath = orig }(hostsFilePath)
+ hostsFilePath = "testdata/hosts"
for _, order := range []hostLookupOrder{hostLookupFilesDNS, hostLookupDNSFiles} {
name := fmt.Sprintf("order %v", order)
@@ -1953,8 +1966,8 @@ func TestCVE202133195(t *testing.T) {
DefaultResolver = &r
defer func() { DefaultResolver = originalDefault }()
// Redirect host file lookups.
- defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
- testHookHostsPath = "testdata/hosts"
+ defer func(orig string) { hostsFilePath = orig }(hostsFilePath)
+ hostsFilePath = "testdata/hosts"
tests := []struct {
name string
@@ -2173,8 +2186,8 @@ func TestRootNS(t *testing.T) {
}
func TestGoLookupIPCNAMEOrderHostsAliasesFilesOnlyMode(t *testing.T) {
- defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
- testHookHostsPath = "testdata/aliases"
+ defer func(orig string) { hostsFilePath = orig }(hostsFilePath)
+ hostsFilePath = "testdata/aliases"
mode := hostLookupFiles
for _, v := range lookupStaticHostAliasesTest {
@@ -2183,8 +2196,8 @@ func TestGoLookupIPCNAMEOrderHostsAliasesFilesOnlyMode(t *testing.T) {
}
func TestGoLookupIPCNAMEOrderHostsAliasesFilesDNSMode(t *testing.T) {
- defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
- testHookHostsPath = "testdata/aliases"
+ defer func(orig string) { hostsFilePath = orig }(hostsFilePath)
+ hostsFilePath = "testdata/aliases"
mode := hostLookupFilesDNS
for _, v := range lookupStaticHostAliasesTest {
@@ -2200,11 +2213,8 @@ var goLookupIPCNAMEOrderDNSFilesModeTests = []struct {
}
func TestGoLookupIPCNAMEOrderHostsAliasesDNSFilesMode(t *testing.T) {
- if testenv.Builder() == "" {
- t.Skip("Makes assumptions about local networks and (re)naming that aren't always true")
- }
- defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
- testHookHostsPath = "testdata/aliases"
+ defer func(orig string) { hostsFilePath = orig }(hostsFilePath)
+ hostsFilePath = "testdata/aliases"
mode := hostLookupDNSFiles
for _, v := range goLookupIPCNAMEOrderDNSFilesModeTests {
@@ -2213,9 +2223,29 @@ func TestGoLookupIPCNAMEOrderHostsAliasesDNSFilesMode(t *testing.T) {
}
func testGoLookupIPCNAMEOrderHostsAliases(t *testing.T, mode hostLookupOrder, lookup, lookupRes string) {
+ fake := fakeDNSServer{
+ rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ var answers []dnsmessage.Resource
+
+ if mode != hostLookupDNSFiles {
+ t.Fatal("received unexpected DNS query")
+ }
+
+ return dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ },
+ Questions: []dnsmessage.Question{q.Questions[0]},
+ Answers: answers,
+ }, nil
+ },
+ }
+
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
ins := []string{lookup, absDomainName(lookup), strings.ToLower(lookup), strings.ToUpper(lookup)}
for _, in := range ins {
- _, res, err := goResolver.goLookupIPCNAMEOrder(context.Background(), "ip", in, mode, nil)
+ _, res, err := r.goLookupIPCNAMEOrder(context.Background(), "ip", in, mode, nil)
if err != nil {
t.Errorf("expected err == nil, but got error: %v", err)
}
@@ -2511,7 +2541,7 @@ func TestDNSConfigNoReload(t *testing.T) {
}
func TestLookupOrderFilesNoSuchHost(t *testing.T) {
- defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
+ defer func(orig string) { hostsFilePath = orig }(hostsFilePath)
if runtime.GOOS != "openbsd" {
defer setSystemNSS(getSystemNSS(), 0)
setSystemNSS(nssStr(t, "hosts: files"), time.Hour)
@@ -2538,7 +2568,7 @@ func TestLookupOrderFilesNoSuchHost(t *testing.T) {
if err := os.WriteFile(tmpFile, []byte{}, 0660); err != nil {
t.Fatal(err)
}
- testHookHostsPath = tmpFile
+ hostsFilePath = tmpFile
const testName = "test.invalid"
@@ -2598,3 +2628,34 @@ func TestLookupOrderFilesNoSuchHost(t *testing.T) {
}
}
}
+
+func TestExtendedRCode(t *testing.T) {
+ fake := fakeDNSServer{
+ rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ fraudSuccessCode := dnsmessage.RCodeSuccess | 1<<10
+
+ var edns0Hdr dnsmessage.ResourceHeader
+ edns0Hdr.SetEDNS0(maxDNSPacketSize, fraudSuccessCode, false)
+
+ return dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: fraudSuccessCode,
+ },
+ Questions: []dnsmessage.Question{q.Questions[0]},
+ Additionals: []dnsmessage.Resource{{
+ Header: edns0Hdr,
+ Body: &dnsmessage.OPTResource{},
+ }},
+ }, nil
+ },
+ }
+
+ r := &Resolver{PreferGo: true, Dial: fake.DialContext}
+ _, _, err := r.tryOneName(context.Background(), getSystemDNSConfig(), "go.dev.", dnsmessage.TypeA)
+ var dnsErr *DNSError
+ if !(errors.As(err, &dnsErr) && dnsErr.Err == errServerMisbehaving.Error()) {
+ t.Fatalf("r.tryOneName(): unexpected error: %v", err)
+ }
+}
diff --git a/src/net/dnsconfig_unix.go b/src/net/dnsconfig_unix.go
index 69b300410a..b0a318279b 100644
--- a/src/net/dnsconfig_unix.go
+++ b/src/net/dnsconfig_unix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !windows
+//go:build !windows
// Read system DNS config from /etc/resolv.conf
diff --git a/src/net/dnsname_test.go b/src/net/dnsname_test.go
index 4a5f01a04a..601a33af9f 100644
--- a/src/net/dnsname_test.go
+++ b/src/net/dnsname_test.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !wasip1
-
package net
import (
diff --git a/src/net/error_plan9_test.go b/src/net/error_plan9_test.go
index 1270af19e5..f86c96c0d2 100644
--- a/src/net/error_plan9_test.go
+++ b/src/net/error_plan9_test.go
@@ -7,7 +7,6 @@ package net
import "syscall"
var (
- errTimedout = syscall.ETIMEDOUT
errOpNotSupported = syscall.EPLAN9
abortedConnRequestErrors []error
diff --git a/src/net/error_posix.go b/src/net/error_posix.go
index c8dc069db4..84f8044045 100644
--- a/src/net/error_posix.go
+++ b/src/net/error_posix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build unix || (js && wasm) || wasip1 || windows
+//go:build unix || js || wasip1 || windows
package net
diff --git a/src/net/error_test.go b/src/net/error_test.go
index 4538765d48..f82e863346 100644
--- a/src/net/error_test.go
+++ b/src/net/error_test.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !wasip1
-
package net
import (
@@ -157,32 +155,33 @@ func TestDialError(t *testing.T) {
d := Dialer{Timeout: someTimeout}
for i, tt := range dialErrorTests {
- c, err := d.Dial(tt.network, tt.address)
- if err == nil {
- t.Errorf("#%d: should fail; %s:%s->%s", i, c.LocalAddr().Network(), c.LocalAddr(), c.RemoteAddr())
- c.Close()
- continue
- }
- if tt.network == "tcp" || tt.network == "udp" {
- nerr := err
- if op, ok := nerr.(*OpError); ok {
- nerr = op.Err
+ i, tt := i, tt
+ t.Run(fmt.Sprint(i), func(t *testing.T) {
+ c, err := d.Dial(tt.network, tt.address)
+ if err == nil {
+ t.Errorf("should fail; %s:%s->%s", c.LocalAddr().Network(), c.LocalAddr(), c.RemoteAddr())
+ c.Close()
+ return
}
- if sys, ok := nerr.(*os.SyscallError); ok {
- nerr = sys.Err
+ if tt.network == "tcp" || tt.network == "udp" {
+ nerr := err
+ if op, ok := nerr.(*OpError); ok {
+ nerr = op.Err
+ }
+ if sys, ok := nerr.(*os.SyscallError); ok {
+ nerr = sys.Err
+ }
+ if nerr == errOpNotSupported {
+ t.Fatalf("should fail without %v; %s:%s->", nerr, tt.network, tt.address)
+ }
}
- if nerr == errOpNotSupported {
- t.Errorf("#%d: should fail without %v; %s:%s->", i, nerr, tt.network, tt.address)
- continue
+ if c != nil {
+ t.Errorf("Dial returned non-nil interface %T(%v) with err != nil", c, c)
}
- }
- if c != nil {
- t.Errorf("Dial returned non-nil interface %T(%v) with err != nil", c, c)
- }
- if err = parseDialError(err); err != nil {
- t.Errorf("#%d: %v", i, err)
- continue
- }
+ if err = parseDialError(err); err != nil {
+ t.Error(err)
+ }
+ })
}
}
@@ -208,10 +207,11 @@ func TestProtocolDialError(t *testing.T) {
t.Errorf("%s: should fail", network)
continue
}
- if err = parseDialError(err); err != nil {
+ if err := parseDialError(err); err != nil {
t.Errorf("%s: %v", network, err)
continue
}
+ t.Logf("%s: error as expected: %v", network, err)
}
}
@@ -220,6 +220,7 @@ func TestDialAddrError(t *testing.T) {
case "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
}
+
if !supportsIPv4() || !supportsIPv6() {
t.Skip("both IPv4 and IPv6 are required")
}
@@ -236,38 +237,42 @@ func TestDialAddrError(t *testing.T) {
// control name resolution.
{"tcp6", "", &TCPAddr{IP: IP{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}}},
} {
- var err error
- var c Conn
- var op string
- if tt.lit != "" {
- c, err = Dial(tt.network, JoinHostPort(tt.lit, "0"))
- op = fmt.Sprintf("Dial(%q, %q)", tt.network, JoinHostPort(tt.lit, "0"))
- } else {
- c, err = DialTCP(tt.network, nil, tt.addr)
- op = fmt.Sprintf("DialTCP(%q, %q)", tt.network, tt.addr)
- }
- if err == nil {
- c.Close()
- t.Errorf("%s succeeded, want error", op)
- continue
- }
- if perr := parseDialError(err); perr != nil {
- t.Errorf("%s: %v", op, perr)
- continue
- }
- operr := err.(*OpError).Err
- aerr, ok := operr.(*AddrError)
- if !ok {
- t.Errorf("%s: %v is %T, want *AddrError", op, err, operr)
- continue
- }
- want := tt.lit
- if tt.lit == "" {
- want = tt.addr.IP.String()
- }
- if aerr.Addr != want {
- t.Errorf("%s: %v, error Addr=%q, want %q", op, err, aerr.Addr, want)
- }
+ desc := tt.lit
+ if desc == "" {
+ desc = tt.addr.String()
+ }
+ t.Run(fmt.Sprintf("%s/%s", tt.network, desc), func(t *testing.T) {
+ var err error
+ var c Conn
+ var op string
+ if tt.lit != "" {
+ c, err = Dial(tt.network, JoinHostPort(tt.lit, "0"))
+ op = fmt.Sprintf("Dial(%q, %q)", tt.network, JoinHostPort(tt.lit, "0"))
+ } else {
+ c, err = DialTCP(tt.network, nil, tt.addr)
+ op = fmt.Sprintf("DialTCP(%q, %q)", tt.network, tt.addr)
+ }
+ t.Logf("%s: %v", op, err)
+ if err == nil {
+ c.Close()
+ t.Fatalf("%s succeeded, want error", op)
+ }
+ if perr := parseDialError(err); perr != nil {
+ t.Fatal(perr)
+ }
+ operr := err.(*OpError).Err
+ aerr, ok := operr.(*AddrError)
+ if !ok {
+ t.Fatalf("OpError.Err is %T, want *AddrError", operr)
+ }
+ want := tt.lit
+ if tt.lit == "" {
+ want = tt.addr.IP.String()
+ }
+ if aerr.Addr != want {
+ t.Errorf("error Addr=%q, want %q", aerr.Addr, want)
+ }
+ })
}
}
@@ -305,32 +310,32 @@ func TestListenError(t *testing.T) {
defer sw.Set(socktest.FilterListen, nil)
for i, tt := range listenErrorTests {
- ln, err := Listen(tt.network, tt.address)
- if err == nil {
- t.Errorf("#%d: should fail; %s:%s->", i, ln.Addr().Network(), ln.Addr())
- ln.Close()
- continue
- }
- if tt.network == "tcp" {
- nerr := err
- if op, ok := nerr.(*OpError); ok {
- nerr = op.Err
+ t.Run(fmt.Sprintf("%s_%s", tt.network, tt.address), func(t *testing.T) {
+ ln, err := Listen(tt.network, tt.address)
+ if err == nil {
+ t.Errorf("#%d: should fail; %s:%s->", i, ln.Addr().Network(), ln.Addr())
+ ln.Close()
+ return
}
- if sys, ok := nerr.(*os.SyscallError); ok {
- nerr = sys.Err
+ if tt.network == "tcp" {
+ nerr := err
+ if op, ok := nerr.(*OpError); ok {
+ nerr = op.Err
+ }
+ if sys, ok := nerr.(*os.SyscallError); ok {
+ nerr = sys.Err
+ }
+ if nerr == errOpNotSupported {
+ t.Fatalf("#%d: should fail without %v; %s:%s->", i, nerr, tt.network, tt.address)
+ }
}
- if nerr == errOpNotSupported {
- t.Errorf("#%d: should fail without %v; %s:%s->", i, nerr, tt.network, tt.address)
- continue
+ if ln != nil {
+ t.Errorf("Listen returned non-nil interface %T(%v) with err != nil", ln, ln)
}
- }
- if ln != nil {
- t.Errorf("Listen returned non-nil interface %T(%v) with err != nil", ln, ln)
- }
- if err = parseDialError(err); err != nil {
- t.Errorf("#%d: %v", i, err)
- continue
- }
+ if err = parseDialError(err); err != nil {
+ t.Errorf("#%d: %v", i, err)
+ }
+ })
}
}
@@ -361,19 +366,20 @@ func TestListenPacketError(t *testing.T) {
}
for i, tt := range listenPacketErrorTests {
- c, err := ListenPacket(tt.network, tt.address)
- if err == nil {
- t.Errorf("#%d: should fail; %s:%s->", i, c.LocalAddr().Network(), c.LocalAddr())
- c.Close()
- continue
- }
- if c != nil {
- t.Errorf("ListenPacket returned non-nil interface %T(%v) with err != nil", c, c)
- }
- if err = parseDialError(err); err != nil {
- t.Errorf("#%d: %v", i, err)
- continue
- }
+ t.Run(fmt.Sprintf("%s_%s", tt.network, tt.address), func(t *testing.T) {
+ c, err := ListenPacket(tt.network, tt.address)
+ if err == nil {
+ t.Errorf("#%d: should fail; %s:%s->", i, c.LocalAddr().Network(), c.LocalAddr())
+ c.Close()
+ return
+ }
+ if c != nil {
+ t.Errorf("ListenPacket returned non-nil interface %T(%v) with err != nil", c, c)
+ }
+ if err = parseDialError(err); err != nil {
+ t.Errorf("#%d: %v", i, err)
+ }
+ })
}
}
@@ -557,49 +563,57 @@ third:
}
func TestCloseError(t *testing.T) {
- ln := newLocalListener(t, "tcp")
- defer ln.Close()
- c, err := Dial(ln.Addr().Network(), ln.Addr().String())
- if err != nil {
- t.Fatal(err)
- }
- defer c.Close()
+ t.Run("tcp", func(t *testing.T) {
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+ c, err := Dial(ln.Addr().Network(), ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
- for i := 0; i < 3; i++ {
- err = c.(*TCPConn).CloseRead()
- if perr := parseCloseError(err, true); perr != nil {
- t.Errorf("#%d: %v", i, perr)
+ for i := 0; i < 3; i++ {
+ err = c.(*TCPConn).CloseRead()
+ if perr := parseCloseError(err, true); perr != nil {
+ t.Errorf("#%d: %v", i, perr)
+ }
}
- }
- for i := 0; i < 3; i++ {
- err = c.(*TCPConn).CloseWrite()
- if perr := parseCloseError(err, true); perr != nil {
- t.Errorf("#%d: %v", i, perr)
+ for i := 0; i < 3; i++ {
+ err = c.(*TCPConn).CloseWrite()
+ if perr := parseCloseError(err, true); perr != nil {
+ t.Errorf("#%d: %v", i, perr)
+ }
}
- }
- for i := 0; i < 3; i++ {
- err = c.Close()
- if perr := parseCloseError(err, false); perr != nil {
- t.Errorf("#%d: %v", i, perr)
+ for i := 0; i < 3; i++ {
+ err = c.Close()
+ if perr := parseCloseError(err, false); perr != nil {
+ t.Errorf("#%d: %v", i, perr)
+ }
+ err = ln.Close()
+ if perr := parseCloseError(err, false); perr != nil {
+ t.Errorf("#%d: %v", i, perr)
+ }
}
- err = ln.Close()
- if perr := parseCloseError(err, false); perr != nil {
- t.Errorf("#%d: %v", i, perr)
+ })
+
+ t.Run("udp", func(t *testing.T) {
+ if !testableNetwork("udp") {
+ t.Skipf("skipping: udp not available")
}
- }
- pc, err := ListenPacket("udp", "127.0.0.1:0")
- if err != nil {
- t.Fatal(err)
- }
- defer pc.Close()
+ pc, err := ListenPacket("udp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer pc.Close()
- for i := 0; i < 3; i++ {
- err = pc.Close()
- if perr := parseCloseError(err, false); perr != nil {
- t.Errorf("#%d: %v", i, perr)
+ for i := 0; i < 3; i++ {
+ err = pc.Close()
+ if perr := parseCloseError(err, false); perr != nil {
+ t.Errorf("#%d: %v", i, perr)
+ }
}
- }
+ })
}
// parseAcceptError parses nestedErr and reports whether it is a valid
diff --git a/src/net/error_unix_test.go b/src/net/error_unix_test.go
index 291a7234f2..963ba21f1a 100644
--- a/src/net/error_unix_test.go
+++ b/src/net/error_unix_test.go
@@ -13,7 +13,6 @@ import (
)
var (
- errTimedout = syscall.ETIMEDOUT
errOpNotSupported = syscall.EOPNOTSUPP
abortedConnRequestErrors = []error{syscall.ECONNABORTED} // see accept in fd_unix.go
diff --git a/src/net/error_windows_test.go b/src/net/error_windows_test.go
index 25825f96f8..7847af0551 100644
--- a/src/net/error_windows_test.go
+++ b/src/net/error_windows_test.go
@@ -10,7 +10,6 @@ import (
)
var (
- errTimedout = syscall.ETIMEDOUT
errOpNotSupported = syscall.EOPNOTSUPP
abortedConnRequestErrors = []error{syscall.ERROR_NETNAME_DELETED, syscall.WSAECONNRESET} // see accept in fd_windows.go
diff --git a/src/net/external_test.go b/src/net/external_test.go
index 0709b9d6f5..38788efc3d 100644
--- a/src/net/external_test.go
+++ b/src/net/external_test.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !wasip1
-
package net
import (
diff --git a/src/net/fd_fake.go b/src/net/fd_fake.go
new file mode 100644
index 0000000000..ae567acc69
--- /dev/null
+++ b/src/net/fd_fake.go
@@ -0,0 +1,169 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build js || wasip1
+
+package net
+
+import (
+ "internal/poll"
+ "runtime"
+ "time"
+)
+
+const (
+ readSyscallName = "fd_read"
+ writeSyscallName = "fd_write"
+)
+
+// Network file descriptor.
+type netFD struct {
+ pfd poll.FD
+
+ // immutable until Close
+ family int
+ sotype int
+ isConnected bool // handshake completed or use of association with peer
+ net string
+ laddr Addr
+ raddr Addr
+
+ // The only networking available in WASI preview 1 is the ability to
+ // sock_accept on a pre-opened socket, and then fd_read, fd_write,
+ // fd_close, and sock_shutdown on the resulting connection. We
+ // intercept applicable netFD calls on this instance, and then pass
+ // the remainder of the netFD calls to fakeNetFD.
+ *fakeNetFD
+}
+
+func newFD(net string, sysfd int) *netFD {
+ return newPollFD(net, poll.FD{
+ Sysfd: sysfd,
+ IsStream: true,
+ ZeroReadIsEOF: true,
+ })
+}
+
+func newPollFD(net string, pfd poll.FD) *netFD {
+ var laddr Addr
+ var raddr Addr
+ // WASI preview 1 does not have functions like getsockname/getpeername,
+ // so we cannot get access to the underlying IP address used by connections.
+ //
+ // However, listeners created by FileListener are of type *TCPListener,
+ // which can be asserted by a Go program. The (*TCPListener).Addr method
+ // documents that the returned value will be of type *TCPAddr, we satisfy
+ // the documented behavior by creating addresses of the expected type here.
+ switch net {
+ case "tcp":
+ laddr = new(TCPAddr)
+ raddr = new(TCPAddr)
+ case "udp":
+ laddr = new(UDPAddr)
+ raddr = new(UDPAddr)
+ default:
+ laddr = unknownAddr{}
+ raddr = unknownAddr{}
+ }
+ return &netFD{
+ pfd: pfd,
+ net: net,
+ laddr: laddr,
+ raddr: raddr,
+ }
+}
+
+func (fd *netFD) init() error {
+ return fd.pfd.Init(fd.net, true)
+}
+
+func (fd *netFD) name() string {
+ return "unknown"
+}
+
+func (fd *netFD) accept() (netfd *netFD, err error) {
+ if fd.fakeNetFD != nil {
+ return fd.fakeNetFD.accept(fd.laddr)
+ }
+ d, _, errcall, err := fd.pfd.Accept()
+ if err != nil {
+ if errcall != "" {
+ err = wrapSyscallError(errcall, err)
+ }
+ return nil, err
+ }
+ netfd = newFD("tcp", d)
+ if err = netfd.init(); err != nil {
+ netfd.Close()
+ return nil, err
+ }
+ return netfd, nil
+}
+
+func (fd *netFD) setAddr(laddr, raddr Addr) {
+ fd.laddr = laddr
+ fd.raddr = raddr
+ runtime.SetFinalizer(fd, (*netFD).Close)
+}
+
+func (fd *netFD) Close() error {
+ if fd.fakeNetFD != nil {
+ return fd.fakeNetFD.Close()
+ }
+ runtime.SetFinalizer(fd, nil)
+ return fd.pfd.Close()
+}
+
+func (fd *netFD) shutdown(how int) error {
+ if fd.fakeNetFD != nil {
+ return nil
+ }
+ err := fd.pfd.Shutdown(how)
+ runtime.KeepAlive(fd)
+ return wrapSyscallError("shutdown", err)
+}
+
+func (fd *netFD) Read(p []byte) (n int, err error) {
+ if fd.fakeNetFD != nil {
+ return fd.fakeNetFD.Read(p)
+ }
+ n, err = fd.pfd.Read(p)
+ runtime.KeepAlive(fd)
+ return n, wrapSyscallError(readSyscallName, err)
+}
+
+func (fd *netFD) Write(p []byte) (nn int, err error) {
+ if fd.fakeNetFD != nil {
+ return fd.fakeNetFD.Write(p)
+ }
+ nn, err = fd.pfd.Write(p)
+ runtime.KeepAlive(fd)
+ return nn, wrapSyscallError(writeSyscallName, err)
+}
+
+func (fd *netFD) SetDeadline(t time.Time) error {
+ if fd.fakeNetFD != nil {
+ return fd.fakeNetFD.SetDeadline(t)
+ }
+ return fd.pfd.SetDeadline(t)
+}
+
+func (fd *netFD) SetReadDeadline(t time.Time) error {
+ if fd.fakeNetFD != nil {
+ return fd.fakeNetFD.SetReadDeadline(t)
+ }
+ return fd.pfd.SetReadDeadline(t)
+}
+
+func (fd *netFD) SetWriteDeadline(t time.Time) error {
+ if fd.fakeNetFD != nil {
+ return fd.fakeNetFD.SetWriteDeadline(t)
+ }
+ return fd.pfd.SetWriteDeadline(t)
+}
+
+type unknownAddr struct{}
+
+func (unknownAddr) Network() string { return "unknown" }
+func (unknownAddr) String() string { return "unknown" }
diff --git a/src/net/fd_js.go b/src/net/fd_js.go
new file mode 100644
index 0000000000..0fce036ef1
--- /dev/null
+++ b/src/net/fd_js.go
@@ -0,0 +1,28 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Fake networking for js/wasm. It is intended to allow tests of other package to pass.
+
+//go:build js
+
+package net
+
+import (
+ "os"
+ "syscall"
+)
+
+func (fd *netFD) closeRead() error {
+ if fd.fakeNetFD != nil {
+ return fd.fakeNetFD.closeRead()
+ }
+ return os.NewSyscallError("closeRead", syscall.ENOTSUP)
+}
+
+func (fd *netFD) closeWrite() error {
+ if fd.fakeNetFD != nil {
+ return fd.fakeNetFD.closeWrite()
+ }
+ return os.NewSyscallError("closeRead", syscall.ENOTSUP)
+}
diff --git a/src/net/fd_wasip1.go b/src/net/fd_wasip1.go
index 74d0b0b2e8..d50effc05d 100644
--- a/src/net/fd_wasip1.go
+++ b/src/net/fd_wasip1.go
@@ -7,124 +7,9 @@
package net
import (
- "internal/poll"
- "runtime"
"syscall"
- "time"
)
-const (
- readSyscallName = "fd_read"
- writeSyscallName = "fd_write"
-)
-
-// Network file descriptor.
-type netFD struct {
- pfd poll.FD
-
- // immutable until Close
- family int
- sotype int
- isConnected bool // handshake completed or use of association with peer
- net string
- laddr Addr
- raddr Addr
-
- // The only networking available in WASI preview 1 is the ability to
- // sock_accept on an pre-opened socket, and then fd_read, fd_write,
- // fd_close, and sock_shutdown on the resulting connection. We
- // intercept applicable netFD calls on this instance, and then pass
- // the remainder of the netFD calls to fakeNetFD.
- *fakeNetFD
-}
-
-func newFD(net string, sysfd int) *netFD {
- return newPollFD(net, poll.FD{
- Sysfd: sysfd,
- IsStream: true,
- ZeroReadIsEOF: true,
- })
-}
-
-func newPollFD(net string, pfd poll.FD) *netFD {
- var laddr Addr
- var raddr Addr
- // WASI preview 1 does not have functions like getsockname/getpeername,
- // so we cannot get access to the underlying IP address used by connections.
- //
- // However, listeners created by FileListener are of type *TCPListener,
- // which can be asserted by a Go program. The (*TCPListener).Addr method
- // documents that the returned value will be of type *TCPAddr, we satisfy
- // the documented behavior by creating addresses of the expected type here.
- switch net {
- case "tcp":
- laddr = new(TCPAddr)
- raddr = new(TCPAddr)
- case "udp":
- laddr = new(UDPAddr)
- raddr = new(UDPAddr)
- default:
- laddr = unknownAddr{}
- raddr = unknownAddr{}
- }
- return &netFD{
- pfd: pfd,
- net: net,
- laddr: laddr,
- raddr: raddr,
- }
-}
-
-func (fd *netFD) init() error {
- return fd.pfd.Init(fd.net, true)
-}
-
-func (fd *netFD) name() string {
- return "unknown"
-}
-
-func (fd *netFD) accept() (netfd *netFD, err error) {
- if fd.fakeNetFD != nil {
- return fd.fakeNetFD.accept()
- }
- d, _, errcall, err := fd.pfd.Accept()
- if err != nil {
- if errcall != "" {
- err = wrapSyscallError(errcall, err)
- }
- return nil, err
- }
- netfd = newFD("tcp", d)
- if err = netfd.init(); err != nil {
- netfd.Close()
- return nil, err
- }
- return netfd, nil
-}
-
-func (fd *netFD) setAddr(laddr, raddr Addr) {
- fd.laddr = laddr
- fd.raddr = raddr
- runtime.SetFinalizer(fd, (*netFD).Close)
-}
-
-func (fd *netFD) Close() error {
- if fd.fakeNetFD != nil {
- return fd.fakeNetFD.Close()
- }
- runtime.SetFinalizer(fd, nil)
- return fd.pfd.Close()
-}
-
-func (fd *netFD) shutdown(how int) error {
- if fd.fakeNetFD != nil {
- return nil
- }
- err := fd.pfd.Shutdown(how)
- runtime.KeepAlive(fd)
- return wrapSyscallError("shutdown", err)
-}
-
func (fd *netFD) closeRead() error {
if fd.fakeNetFD != nil {
return fd.fakeNetFD.closeRead()
@@ -138,47 +23,3 @@ func (fd *netFD) closeWrite() error {
}
return fd.shutdown(syscall.SHUT_WR)
}
-
-func (fd *netFD) Read(p []byte) (n int, err error) {
- if fd.fakeNetFD != nil {
- return fd.fakeNetFD.Read(p)
- }
- n, err = fd.pfd.Read(p)
- runtime.KeepAlive(fd)
- return n, wrapSyscallError(readSyscallName, err)
-}
-
-func (fd *netFD) Write(p []byte) (nn int, err error) {
- if fd.fakeNetFD != nil {
- return fd.fakeNetFD.Write(p)
- }
- nn, err = fd.pfd.Write(p)
- runtime.KeepAlive(fd)
- return nn, wrapSyscallError(writeSyscallName, err)
-}
-
-func (fd *netFD) SetDeadline(t time.Time) error {
- if fd.fakeNetFD != nil {
- return fd.fakeNetFD.SetDeadline(t)
- }
- return fd.pfd.SetDeadline(t)
-}
-
-func (fd *netFD) SetReadDeadline(t time.Time) error {
- if fd.fakeNetFD != nil {
- return fd.fakeNetFD.SetReadDeadline(t)
- }
- return fd.pfd.SetReadDeadline(t)
-}
-
-func (fd *netFD) SetWriteDeadline(t time.Time) error {
- if fd.fakeNetFD != nil {
- return fd.fakeNetFD.SetWriteDeadline(t)
- }
- return fd.pfd.SetWriteDeadline(t)
-}
-
-type unknownAddr struct{}
-
-func (unknownAddr) Network() string { return "unknown" }
-func (unknownAddr) String() string { return "unknown" }
diff --git a/src/net/fd_windows.go b/src/net/fd_windows.go
index eeb994dfd9..45a10cf1eb 100644
--- a/src/net/fd_windows.go
+++ b/src/net/fd_windows.go
@@ -64,10 +64,38 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (syscall.
if err := fd.init(); err != nil {
return nil, err
}
- if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
- fd.pfd.SetWriteDeadline(deadline)
+
+ if ctx.Done() != nil {
+ // Propagate the Context's deadline and cancellation.
+ // If the context is already done, or if it has a nonzero deadline,
+ // ensure that that is applied before the call to ConnectEx begins
+ // so that we don't return spurious connections.
defer fd.pfd.SetWriteDeadline(noDeadline)
+
+ if ctx.Err() != nil {
+ fd.pfd.SetWriteDeadline(aLongTimeAgo)
+ } else {
+ if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
+ fd.pfd.SetWriteDeadline(deadline)
+ }
+
+ done := make(chan struct{})
+ stop := context.AfterFunc(ctx, func() {
+ // Force the runtime's poller to immediately give
+ // up waiting for writability.
+ fd.pfd.SetWriteDeadline(aLongTimeAgo)
+ close(done)
+ })
+ defer func() {
+ if !stop() {
+ // Wait for the call to SetWriteDeadline to complete so that we can
+ // reset the deadline if everything else succeeded.
+ <-done
+ }
+ }()
+ }
}
+
if !canUseConnectEx(fd.net) {
err := connectFunc(fd.pfd.Sysfd, ra)
return nil, os.NewSyscallError("connect", err)
@@ -113,22 +141,6 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (syscall.
_ = fd.pfd.WSAIoctl(windows.SIO_TCP_INITIAL_RTO, (*byte)(unsafe.Pointer(&params)), uint32(unsafe.Sizeof(params)), nil, 0, &out, nil, 0)
}
- // Wait for the goroutine converting context.Done into a write timeout
- // to exist, otherwise our caller might cancel the context and
- // cause fd.setWriteDeadline(aLongTimeAgo) to cancel a successful dial.
- done := make(chan bool) // must be unbuffered
- defer func() { done <- true }()
- go func() {
- select {
- case <-ctx.Done():
- // Force the runtime's poller to immediately give
- // up waiting for writability.
- fd.pfd.SetWriteDeadline(aLongTimeAgo)
- <-done
- case <-done:
- }
- }()
-
// Call ConnectEx API.
if err := fd.pfd.ConnectEx(ra); err != nil {
select {
diff --git a/src/net/file_stub.go b/src/net/file_stub.go
index 91df926a57..6fd3eec48d 100644
--- a/src/net/file_stub.go
+++ b/src/net/file_stub.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build js && wasm
+//go:build js
package net
diff --git a/src/net/file_test.go b/src/net/file_test.go
index 53cd3c1074..c517af50c5 100644
--- a/src/net/file_test.go
+++ b/src/net/file_test.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !wasip1
-
package net
import (
@@ -31,7 +29,7 @@ var fileConnTests = []struct {
func TestFileConn(t *testing.T) {
switch runtime.GOOS {
- case "plan9", "windows":
+ case "plan9", "windows", "js", "wasip1":
t.Skipf("not supported on %s", runtime.GOOS)
}
@@ -132,7 +130,7 @@ var fileListenerTests = []struct {
func TestFileListener(t *testing.T) {
switch runtime.GOOS {
- case "plan9", "windows":
+ case "plan9", "windows", "js", "wasip1":
t.Skipf("not supported on %s", runtime.GOOS)
}
@@ -224,7 +222,7 @@ var filePacketConnTests = []struct {
func TestFilePacketConn(t *testing.T) {
switch runtime.GOOS {
- case "plan9", "windows":
+ case "plan9", "windows", "js", "wasip1":
t.Skipf("not supported on %s", runtime.GOOS)
}
@@ -291,7 +289,7 @@ func TestFilePacketConn(t *testing.T) {
// Issue 24483.
func TestFileCloseRace(t *testing.T) {
switch runtime.GOOS {
- case "plan9", "windows":
+ case "plan9", "windows", "js", "wasip1":
t.Skipf("not supported on %s", runtime.GOOS)
}
if !testableNetwork("tcp") {
diff --git a/src/net/hook.go b/src/net/hook.go
index ea71803e22..eded34d48a 100644
--- a/src/net/hook.go
+++ b/src/net/hook.go
@@ -13,8 +13,7 @@ var (
// if non-nil, overrides dialTCP.
testHookDialTCP func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error)
- testHookHostsPath = "/etc/hosts"
- testHookLookupIP = func(
+ testHookLookupIP = func(
ctx context.Context,
fn func(context.Context, string, string) ([]IPAddr, error),
network string,
@@ -23,4 +22,10 @@ var (
return fn(ctx, network, host)
}
testHookSetKeepAlive = func(time.Duration) {}
+
+ // testHookStepTime sleeps until time has moved forward by a nonzero amount.
+ // This helps to avoid flakes in timeout tests by ensuring that an implausibly
+ // short deadline (such as 1ns in the future) is always expired by the time
+ // a relevant system call occurs.
+ testHookStepTime = func() {}
)
diff --git a/src/net/hook_plan9.go b/src/net/hook_plan9.go
index e053348505..6020d32924 100644
--- a/src/net/hook_plan9.go
+++ b/src/net/hook_plan9.go
@@ -4,6 +4,6 @@
package net
-import "time"
-
-var testHookDialChannel = func() { time.Sleep(time.Millisecond) } // see golang.org/issue/5349
+var (
+ hostsFilePath = "/etc/hosts"
+)
diff --git a/src/net/hook_unix.go b/src/net/hook_unix.go
index 4e20f59218..69b375598d 100644
--- a/src/net/hook_unix.go
+++ b/src/net/hook_unix.go
@@ -2,16 +2,17 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build unix || (js && wasm) || wasip1
+//go:build unix || js || wasip1
package net
import "syscall"
var (
- testHookDialChannel = func() {} // for golang.org/issue/5349
testHookCanceledDial = func() {} // for golang.org/issue/16523
+ hostsFilePath = "/etc/hosts"
+
// Placeholders for socket system calls.
socketFunc func(int, int, int) (int, error) = syscall.Socket
connectFunc func(int, syscall.Sockaddr) error = syscall.Connect
diff --git a/src/net/hook_windows.go b/src/net/hook_windows.go
index ab8656cbbf..f7c5b5af90 100644
--- a/src/net/hook_windows.go
+++ b/src/net/hook_windows.go
@@ -7,14 +7,12 @@ package net
import (
"internal/syscall/windows"
"syscall"
- "time"
)
var (
- testHookDialChannel = func() { time.Sleep(time.Millisecond) } // see golang.org/issue/5349
+ hostsFilePath = windows.GetSystemDirectory() + "/Drivers/etc/hosts"
// Placeholders for socket system calls.
- socketFunc func(int, int, int) (syscall.Handle, error) = syscall.Socket
wsaSocketFunc func(int32, int32, int32, *syscall.WSAProtocolInfo, uint32, uint32) (syscall.Handle, error) = windows.WSASocket
connectFunc func(syscall.Handle, syscall.Sockaddr) error = syscall.Connect
listenFunc func(syscall.Handle, int) error = syscall.Listen
diff --git a/src/net/hosts.go b/src/net/hosts.go
index 56e6674144..73e6fcc7a4 100644
--- a/src/net/hosts.go
+++ b/src/net/hosts.go
@@ -51,7 +51,7 @@ var hosts struct {
func readHosts() {
now := time.Now()
- hp := testHookHostsPath
+ hp := hostsFilePath
if now.Before(hosts.expire) && hosts.path == hp && len(hosts.byName) > 0 {
return
diff --git a/src/net/hosts_test.go b/src/net/hosts_test.go
index b3f189e641..5f22920765 100644
--- a/src/net/hosts_test.go
+++ b/src/net/hosts_test.go
@@ -59,10 +59,10 @@ var lookupStaticHostTests = []struct {
}
func TestLookupStaticHost(t *testing.T) {
- defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
+ defer func(orig string) { hostsFilePath = orig }(hostsFilePath)
for _, tt := range lookupStaticHostTests {
- testHookHostsPath = tt.name
+ hostsFilePath = tt.name
for _, ent := range tt.ents {
testStaticHost(t, tt.name, ent)
}
@@ -128,10 +128,10 @@ var lookupStaticAddrTests = []struct {
}
func TestLookupStaticAddr(t *testing.T) {
- defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
+ defer func(orig string) { hostsFilePath = orig }(hostsFilePath)
for _, tt := range lookupStaticAddrTests {
- testHookHostsPath = tt.name
+ hostsFilePath = tt.name
for _, ent := range tt.ents {
testStaticAddr(t, tt.name, ent)
}
@@ -151,27 +151,27 @@ func testStaticAddr(t *testing.T, hostsPath string, ent staticHostEntry) {
func TestHostCacheModification(t *testing.T) {
// Ensure that programs can't modify the internals of the host cache.
// See https://golang.org/issues/14212.
- defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
+ defer func(orig string) { hostsFilePath = orig }(hostsFilePath)
- testHookHostsPath = "testdata/ipv4-hosts"
+ hostsFilePath = "testdata/ipv4-hosts"
ent := staticHostEntry{"localhost", []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}}
- testStaticHost(t, testHookHostsPath, ent)
+ testStaticHost(t, hostsFilePath, ent)
// Modify the addresses return by lookupStaticHost.
addrs, _ := lookupStaticHost(ent.in)
for i := range addrs {
addrs[i] += "junk"
}
- testStaticHost(t, testHookHostsPath, ent)
+ testStaticHost(t, hostsFilePath, ent)
- testHookHostsPath = "testdata/ipv6-hosts"
+ hostsFilePath = "testdata/ipv6-hosts"
ent = staticHostEntry{"::1", []string{"localhost"}}
- testStaticAddr(t, testHookHostsPath, ent)
+ testStaticAddr(t, hostsFilePath, ent)
// Modify the hosts return by lookupStaticAddr.
hosts := lookupStaticAddr(ent.in)
for i := range hosts {
hosts[i] += "junk"
}
- testStaticAddr(t, testHookHostsPath, ent)
+ testStaticAddr(t, hostsFilePath, ent)
}
var lookupStaticHostAliasesTest = []struct {
@@ -195,9 +195,9 @@ var lookupStaticHostAliasesTest = []struct {
}
func TestLookupStaticHostAliases(t *testing.T) {
- defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
+ defer func(orig string) { hostsFilePath = orig }(hostsFilePath)
- testHookHostsPath = "testdata/aliases"
+ hostsFilePath = "testdata/aliases"
for _, ent := range lookupStaticHostAliasesTest {
testLookupStaticHostAliases(t, ent.lookup, absDomainName(ent.res))
}
diff --git a/src/net/http/cgi/cgi_main.go b/src/net/http/cgi/cgi_main.go
new file mode 100644
index 0000000000..8997d66a11
--- /dev/null
+++ b/src/net/http/cgi/cgi_main.go
@@ -0,0 +1,145 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cgi
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "path"
+ "sort"
+ "strings"
+ "time"
+)
+
+func cgiMain() {
+ switch path.Join(os.Getenv("SCRIPT_NAME"), os.Getenv("PATH_INFO")) {
+ case "/bar", "/test.cgi", "/myscript/bar", "/test.cgi/extrapath":
+ testCGI()
+ return
+ }
+ childCGIProcess()
+}
+
+// testCGI is a CGI program translated from a Perl program to complete host_test.
+// test cases in host_test should be provided by testCGI.
+func testCGI() {
+ req, err := Request()
+ if err != nil {
+ panic(err)
+ }
+
+ err = req.ParseForm()
+ if err != nil {
+ panic(err)
+ }
+
+ params := req.Form
+ if params.Get("loc") != "" {
+ fmt.Printf("Location: %s\r\n\r\n", params.Get("loc"))
+ return
+ }
+
+ fmt.Printf("Content-Type: text/html\r\n")
+ fmt.Printf("X-CGI-Pid: %d\r\n", os.Getpid())
+ fmt.Printf("X-Test-Header: X-Test-Value\r\n")
+ fmt.Printf("\r\n")
+
+ if params.Get("writestderr") != "" {
+ fmt.Fprintf(os.Stderr, "Hello, stderr!\n")
+ }
+
+ if params.Get("bigresponse") != "" {
+ // 17 MB, for OS X: golang.org/issue/4958
+ line := strings.Repeat("A", 1024)
+ for i := 0; i < 17*1024; i++ {
+ fmt.Printf("%s\r\n", line)
+ }
+ return
+ }
+
+ fmt.Printf("test=Hello CGI\r\n")
+
+ keys := make([]string, 0, len(params))
+ for k := range params {
+ keys = append(keys, k)
+ }
+ sort.Strings(keys)
+ for _, key := range keys {
+ fmt.Printf("param-%s=%s\r\n", key, params.Get(key))
+ }
+
+ envs := envMap(os.Environ())
+ keys = make([]string, 0, len(envs))
+ for k := range envs {
+ keys = append(keys, k)
+ }
+ sort.Strings(keys)
+ for _, key := range keys {
+ fmt.Printf("env-%s=%s\r\n", key, envs[key])
+ }
+
+ cwd, _ := os.Getwd()
+ fmt.Printf("cwd=%s\r\n", cwd)
+}
+
+type neverEnding byte
+
+func (b neverEnding) Read(p []byte) (n int, err error) {
+ for i := range p {
+ p[i] = byte(b)
+ }
+ return len(p), nil
+}
+
+// childCGIProcess is used by integration_test to complete unit tests.
+func childCGIProcess() {
+ if os.Getenv("REQUEST_METHOD") == "" {
+ // Not in a CGI environment; skipping test.
+ return
+ }
+ switch os.Getenv("REQUEST_URI") {
+ case "/immediate-disconnect":
+ os.Exit(0)
+ case "/no-content-type":
+ fmt.Printf("Content-Length: 6\n\nHello\n")
+ os.Exit(0)
+ case "/empty-headers":
+ fmt.Printf("\nHello")
+ os.Exit(0)
+ }
+ Serve(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
+ if req.FormValue("nil-request-body") == "1" {
+ fmt.Fprintf(rw, "nil-request-body=%v\n", req.Body == nil)
+ return
+ }
+ rw.Header().Set("X-Test-Header", "X-Test-Value")
+ req.ParseForm()
+ if req.FormValue("no-body") == "1" {
+ return
+ }
+ if eb, ok := req.Form["exact-body"]; ok {
+ io.WriteString(rw, eb[0])
+ return
+ }
+ if req.FormValue("write-forever") == "1" {
+ io.Copy(rw, neverEnding('a'))
+ for {
+ time.Sleep(5 * time.Second) // hang forever, until killed
+ }
+ }
+ fmt.Fprintf(rw, "test=Hello CGI-in-CGI\n")
+ for k, vv := range req.Form {
+ for _, v := range vv {
+ fmt.Fprintf(rw, "param-%s=%s\n", k, v)
+ }
+ }
+ for _, kv := range os.Environ() {
+ fmt.Fprintf(rw, "env-%s\n", kv)
+ }
+ }))
+ os.Exit(0)
+}
diff --git a/src/net/http/cgi/child.go b/src/net/http/cgi/child.go
index 1411f0b8e8..e29fe20d7d 100644
--- a/src/net/http/cgi/child.go
+++ b/src/net/http/cgi/child.go
@@ -46,7 +46,7 @@ func envMap(env []string) map[string]string {
return m
}
-// RequestFromMap creates an http.Request from CGI variables.
+// RequestFromMap creates an [http.Request] from CGI variables.
// The returned Request's Body field is not populated.
func RequestFromMap(params map[string]string) (*http.Request, error) {
r := new(http.Request)
@@ -138,10 +138,10 @@ func RequestFromMap(params map[string]string) (*http.Request, error) {
return r, nil
}
-// Serve executes the provided Handler on the currently active CGI
+// Serve executes the provided [Handler] on the currently active CGI
// request, if any. If there's no current CGI environment
// an error is returned. The provided handler may be nil to use
-// http.DefaultServeMux.
+// [http.DefaultServeMux].
func Serve(handler http.Handler) error {
req, err := Request()
if err != nil {
diff --git a/src/net/http/cgi/host.go b/src/net/http/cgi/host.go
index 073952a7bd..ef222ab73a 100644
--- a/src/net/http/cgi/host.go
+++ b/src/net/http/cgi/host.go
@@ -115,23 +115,19 @@ func removeLeadingDuplicates(env []string) (ret []string) {
}
func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
- root := h.Root
- if root == "" {
- root = "/"
- }
-
if len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked" {
rw.WriteHeader(http.StatusBadRequest)
rw.Write([]byte("Chunked request bodies are not supported by CGI."))
return
}
- pathInfo := req.URL.Path
- if root != "/" && strings.HasPrefix(pathInfo, root) {
- pathInfo = pathInfo[len(root):]
- }
+ root := strings.TrimRight(h.Root, "/")
+ pathInfo := strings.TrimPrefix(req.URL.Path, root)
port := "80"
+ if req.TLS != nil {
+ port = "443"
+ }
if matches := trailingPort.FindStringSubmatch(req.Host); len(matches) != 0 {
port = matches[1]
}
diff --git a/src/net/http/cgi/host_test.go b/src/net/http/cgi/host_test.go
index 860e9b3e8f..f29395fe84 100644
--- a/src/net/http/cgi/host_test.go
+++ b/src/net/http/cgi/host_test.go
@@ -9,21 +9,33 @@ package cgi
import (
"bufio"
"fmt"
+ "internal/testenv"
"io"
"net"
"net/http"
"net/http/httptest"
"os"
- "os/exec"
"path/filepath"
"reflect"
+ "regexp"
"runtime"
- "strconv"
"strings"
"testing"
"time"
)
+// TestMain executes the test binary as the cgi server if
+// SERVER_SOFTWARE is set, and runs the tests otherwise.
+func TestMain(m *testing.M) {
+ // SERVER_SOFTWARE swap variable is set when starting the cgi server.
+ if os.Getenv("SERVER_SOFTWARE") != "" {
+ cgiMain()
+ os.Exit(0)
+ }
+
+ os.Exit(m.Run())
+}
+
func newRequest(httpreq string) *http.Request {
buf := bufio.NewReader(strings.NewReader(httpreq))
req, err := http.ReadRequest(buf)
@@ -88,24 +100,10 @@ readlines:
}
}
-var cgiTested, cgiWorks bool
-
-func check(t *testing.T) {
- if !cgiTested {
- cgiTested = true
- cgiWorks = exec.Command("./testdata/test.cgi").Run() == nil
- }
- if !cgiWorks {
- // No Perl on Windows, needed by test.cgi
- // TODO: make the child process be Go, not Perl.
- t.Skip("Skipping test: test.cgi failed.")
- }
-}
-
func TestCGIBasicGet(t *testing.T) {
- check(t)
+ testenv.MustHaveExec(t)
h := &Handler{
- Path: "testdata/test.cgi",
+ Path: os.Args[0],
Root: "/test.cgi",
}
expectedMap := map[string]string{
@@ -121,7 +119,7 @@ func TestCGIBasicGet(t *testing.T) {
"env-REMOTE_PORT": "1234",
"env-REQUEST_METHOD": "GET",
"env-REQUEST_URI": "/test.cgi?foo=bar&a=b",
- "env-SCRIPT_FILENAME": "testdata/test.cgi",
+ "env-SCRIPT_FILENAME": os.Args[0],
"env-SCRIPT_NAME": "/test.cgi",
"env-SERVER_NAME": "example.com",
"env-SERVER_PORT": "80",
@@ -138,9 +136,9 @@ func TestCGIBasicGet(t *testing.T) {
}
func TestCGIEnvIPv6(t *testing.T) {
- check(t)
+ testenv.MustHaveExec(t)
h := &Handler{
- Path: "testdata/test.cgi",
+ Path: os.Args[0],
Root: "/test.cgi",
}
expectedMap := map[string]string{
@@ -156,7 +154,7 @@ func TestCGIEnvIPv6(t *testing.T) {
"env-REMOTE_PORT": "12345",
"env-REQUEST_METHOD": "GET",
"env-REQUEST_URI": "/test.cgi?foo=bar&a=b",
- "env-SCRIPT_FILENAME": "testdata/test.cgi",
+ "env-SCRIPT_FILENAME": os.Args[0],
"env-SCRIPT_NAME": "/test.cgi",
"env-SERVER_NAME": "example.com",
"env-SERVER_PORT": "80",
@@ -171,27 +169,27 @@ func TestCGIEnvIPv6(t *testing.T) {
}
func TestCGIBasicGetAbsPath(t *testing.T) {
- check(t)
- pwd, err := os.Getwd()
+ absPath, err := filepath.Abs(os.Args[0])
if err != nil {
- t.Fatalf("getwd error: %v", err)
+ t.Fatal(err)
}
+ testenv.MustHaveExec(t)
h := &Handler{
- Path: pwd + "/testdata/test.cgi",
+ Path: absPath,
Root: "/test.cgi",
}
expectedMap := map[string]string{
"env-REQUEST_URI": "/test.cgi?foo=bar&a=b",
- "env-SCRIPT_FILENAME": pwd + "/testdata/test.cgi",
+ "env-SCRIPT_FILENAME": absPath,
"env-SCRIPT_NAME": "/test.cgi",
}
runCgiTest(t, h, "GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
}
func TestPathInfo(t *testing.T) {
- check(t)
+ testenv.MustHaveExec(t)
h := &Handler{
- Path: "testdata/test.cgi",
+ Path: os.Args[0],
Root: "/test.cgi",
}
expectedMap := map[string]string{
@@ -199,36 +197,36 @@ func TestPathInfo(t *testing.T) {
"env-PATH_INFO": "/extrapath",
"env-QUERY_STRING": "a=b",
"env-REQUEST_URI": "/test.cgi/extrapath?a=b",
- "env-SCRIPT_FILENAME": "testdata/test.cgi",
+ "env-SCRIPT_FILENAME": os.Args[0],
"env-SCRIPT_NAME": "/test.cgi",
}
runCgiTest(t, h, "GET /test.cgi/extrapath?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
}
func TestPathInfoDirRoot(t *testing.T) {
- check(t)
+ testenv.MustHaveExec(t)
h := &Handler{
- Path: "testdata/test.cgi",
- Root: "/myscript/",
+ Path: os.Args[0],
+ Root: "/myscript//",
}
expectedMap := map[string]string{
- "env-PATH_INFO": "bar",
+ "env-PATH_INFO": "/bar",
"env-QUERY_STRING": "a=b",
"env-REQUEST_URI": "/myscript/bar?a=b",
- "env-SCRIPT_FILENAME": "testdata/test.cgi",
- "env-SCRIPT_NAME": "/myscript/",
+ "env-SCRIPT_FILENAME": os.Args[0],
+ "env-SCRIPT_NAME": "/myscript",
}
runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
}
func TestDupHeaders(t *testing.T) {
- check(t)
+ testenv.MustHaveExec(t)
h := &Handler{
- Path: "testdata/test.cgi",
+ Path: os.Args[0],
}
expectedMap := map[string]string{
"env-REQUEST_URI": "/myscript/bar?a=b",
- "env-SCRIPT_FILENAME": "testdata/test.cgi",
+ "env-SCRIPT_FILENAME": os.Args[0],
"env-HTTP_COOKIE": "nom=NOM; yum=YUM",
"env-HTTP_X_FOO": "val1, val2",
}
@@ -245,13 +243,13 @@ func TestDupHeaders(t *testing.T) {
// Verify we don't set the HTTP_PROXY environment variable.
// Hope nobody was depending on it. It's not a known header, though.
func TestDropProxyHeader(t *testing.T) {
- check(t)
+ testenv.MustHaveExec(t)
h := &Handler{
- Path: "testdata/test.cgi",
+ Path: os.Args[0],
}
expectedMap := map[string]string{
"env-REQUEST_URI": "/myscript/bar?a=b",
- "env-SCRIPT_FILENAME": "testdata/test.cgi",
+ "env-SCRIPT_FILENAME": os.Args[0],
"env-HTTP_X_FOO": "a",
}
runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\n"+
@@ -267,23 +265,23 @@ func TestDropProxyHeader(t *testing.T) {
}
func TestPathInfoNoRoot(t *testing.T) {
- check(t)
+ testenv.MustHaveExec(t)
h := &Handler{
- Path: "testdata/test.cgi",
+ Path: os.Args[0],
Root: "",
}
expectedMap := map[string]string{
"env-PATH_INFO": "/bar",
"env-QUERY_STRING": "a=b",
"env-REQUEST_URI": "/bar?a=b",
- "env-SCRIPT_FILENAME": "testdata/test.cgi",
- "env-SCRIPT_NAME": "/",
+ "env-SCRIPT_FILENAME": os.Args[0],
+ "env-SCRIPT_NAME": "",
}
runCgiTest(t, h, "GET /bar?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
}
func TestCGIBasicPost(t *testing.T) {
- check(t)
+ testenv.MustHaveExec(t)
postReq := `POST /test.cgi?a=b HTTP/1.0
Host: example.com
Content-Type: application/x-www-form-urlencoded
@@ -291,7 +289,7 @@ Content-Length: 15
postfoo=postbar`
h := &Handler{
- Path: "testdata/test.cgi",
+ Path: os.Args[0],
Root: "/test.cgi",
}
expectedMap := map[string]string{
@@ -310,7 +308,7 @@ func chunk(s string) string {
// The CGI spec doesn't allow chunked requests.
func TestCGIPostChunked(t *testing.T) {
- check(t)
+ testenv.MustHaveExec(t)
postReq := `POST /test.cgi?a=b HTTP/1.1
Host: example.com
Content-Type: application/x-www-form-urlencoded
@@ -319,7 +317,7 @@ Transfer-Encoding: chunked
` + chunk("postfoo") + chunk("=") + chunk("postbar") + chunk("")
h := &Handler{
- Path: "testdata/test.cgi",
+ Path: os.Args[0],
Root: "/test.cgi",
}
expectedMap := map[string]string{}
@@ -331,9 +329,9 @@ Transfer-Encoding: chunked
}
func TestRedirect(t *testing.T) {
- check(t)
+ testenv.MustHaveExec(t)
h := &Handler{
- Path: "testdata/test.cgi",
+ Path: os.Args[0],
Root: "/test.cgi",
}
rec := runCgiTest(t, h, "GET /test.cgi?loc=http://foo.com/ HTTP/1.0\nHost: example.com\n\n", nil)
@@ -346,13 +344,13 @@ func TestRedirect(t *testing.T) {
}
func TestInternalRedirect(t *testing.T) {
- check(t)
+ testenv.MustHaveExec(t)
baseHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
fmt.Fprintf(rw, "basepath=%s\n", req.URL.Path)
fmt.Fprintf(rw, "remoteaddr=%s\n", req.RemoteAddr)
})
h := &Handler{
- Path: "testdata/test.cgi",
+ Path: os.Args[0],
Root: "/test.cgi",
PathLocationHandler: baseHandler,
}
@@ -365,13 +363,14 @@ func TestInternalRedirect(t *testing.T) {
// TestCopyError tests that we kill the process if there's an error copying
// its output. (for example, from the client having gone away)
+//
+// If we fail to do so, the test will time out (and dump its goroutines) with a
+// call to [Handler.ServeHTTP] blocked on a deferred call to [exec.Cmd.Wait].
func TestCopyError(t *testing.T) {
- check(t)
- if runtime.GOOS == "windows" {
- t.Skipf("skipping test on %q", runtime.GOOS)
- }
+ testenv.MustHaveExec(t)
+
h := &Handler{
- Path: "testdata/test.cgi",
+ Path: os.Args[0],
Root: "/test.cgi",
}
ts := httptest.NewServer(h)
@@ -392,118 +391,63 @@ func TestCopyError(t *testing.T) {
t.Fatalf("ReadResponse: %v", err)
}
- pidstr := res.Header.Get("X-CGI-Pid")
- if pidstr == "" {
- t.Fatalf("expected an X-CGI-Pid header in response")
- }
- pid, err := strconv.Atoi(pidstr)
- if err != nil {
- t.Fatalf("invalid X-CGI-Pid value")
- }
-
var buf [5000]byte
n, err := io.ReadFull(res.Body, buf[:])
if err != nil {
t.Fatalf("ReadFull: %d bytes, %v", n, err)
}
- childRunning := func() bool {
- return isProcessRunning(pid)
- }
-
- if !childRunning() {
- t.Fatalf("pre-conn.Close, expected child to be running")
+ if !handlerRunning() {
+ t.Fatalf("pre-conn.Close, expected handler to still be running")
}
conn.Close()
+ closed := time.Now()
- tries := 0
- for tries < 25 && childRunning() {
- time.Sleep(50 * time.Millisecond * time.Duration(tries))
- tries++
- }
- if childRunning() {
- t.Fatalf("post-conn.Close, expected child to be gone")
- }
-}
-
-func TestDirUnix(t *testing.T) {
- check(t)
- if runtime.GOOS == "windows" {
- t.Skipf("skipping test on %q", runtime.GOOS)
- }
- cwd, _ := os.Getwd()
- h := &Handler{
- Path: "testdata/test.cgi",
- Root: "/test.cgi",
- Dir: cwd,
- }
- expectedMap := map[string]string{
- "cwd": cwd,
- }
- runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap)
-
- cwd, _ = os.Getwd()
- cwd = filepath.Join(cwd, "testdata")
- h = &Handler{
- Path: "testdata/test.cgi",
- Root: "/test.cgi",
- }
- expectedMap = map[string]string{
- "cwd": cwd,
+ nextSleep := 1 * time.Millisecond
+ for {
+ time.Sleep(nextSleep)
+ nextSleep *= 2
+ if !handlerRunning() {
+ break
+ }
+ t.Logf("handler still running %v after conn.Close", time.Since(closed))
}
- runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap)
}
-func findPerl(t *testing.T) string {
- t.Helper()
- perl, err := exec.LookPath("perl")
- if err != nil {
- t.Skip("Skipping test: perl not found.")
- }
- perl, _ = filepath.Abs(perl)
-
- cmd := exec.Command(perl, "-e", "print 123")
- cmd.Env = []string{"PATH=/garbage"}
- out, err := cmd.Output()
- if err != nil || string(out) != "123" {
- t.Skipf("Skipping test: %s is not functional", perl)
+// handlerRunning reports whether any goroutine is currently running
+// [Handler.ServeHTTP].
+func handlerRunning() bool {
+ r := regexp.MustCompile(`net/http/cgi\.\(\*Handler\)\.ServeHTTP`)
+ buf := make([]byte, 64<<10)
+ for {
+ n := runtime.Stack(buf, true)
+ if n < len(buf) {
+ return r.Match(buf[:n])
+ }
+ // Buffer wasn't large enough for a full goroutine dump.
+ // Resize it and try again.
+ buf = make([]byte, 2*len(buf))
}
- return perl
}
-func TestDirWindows(t *testing.T) {
- if runtime.GOOS != "windows" {
- t.Skip("Skipping windows specific test.")
- }
-
- cgifile, _ := filepath.Abs("testdata/test.cgi")
-
- perl := findPerl(t)
-
+func TestDir(t *testing.T) {
+ testenv.MustHaveExec(t)
cwd, _ := os.Getwd()
h := &Handler{
- Path: perl,
+ Path: os.Args[0],
Root: "/test.cgi",
Dir: cwd,
- Args: []string{cgifile},
- Env: []string{"SCRIPT_FILENAME=" + cgifile},
}
expectedMap := map[string]string{
"cwd": cwd,
}
runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap)
- // If not specify Dir on windows, working directory should be
- // base directory of perl.
- cwd, _ = filepath.Split(perl)
- if cwd != "" && cwd[len(cwd)-1] == filepath.Separator {
- cwd = cwd[:len(cwd)-1]
- }
+ cwd, _ = os.Getwd()
+ cwd, _ = filepath.Split(os.Args[0])
h = &Handler{
- Path: perl,
+ Path: os.Args[0],
Root: "/test.cgi",
- Args: []string{cgifile},
- Env: []string{"SCRIPT_FILENAME=" + cgifile},
}
expectedMap = map[string]string{
"cwd": cwd,
@@ -512,17 +456,14 @@ func TestDirWindows(t *testing.T) {
}
func TestEnvOverride(t *testing.T) {
- check(t)
+ testenv.MustHaveExec(t)
cgifile, _ := filepath.Abs("testdata/test.cgi")
- perl := findPerl(t)
-
cwd, _ := os.Getwd()
h := &Handler{
- Path: perl,
+ Path: os.Args[0],
Root: "/test.cgi",
Dir: cwd,
- Args: []string{cgifile},
Env: []string{
"SCRIPT_FILENAME=" + cgifile,
"REQUEST_URI=/foo/bar",
@@ -538,10 +479,10 @@ func TestEnvOverride(t *testing.T) {
}
func TestHandlerStderr(t *testing.T) {
- check(t)
+ testenv.MustHaveExec(t)
var stderr strings.Builder
h := &Handler{
- Path: "testdata/test.cgi",
+ Path: os.Args[0],
Root: "/test.cgi",
Stderr: &stderr,
}
diff --git a/src/net/http/cgi/integration_test.go b/src/net/http/cgi/integration_test.go
index ef2eaf748b..68f908e2b2 100644
--- a/src/net/http/cgi/integration_test.go
+++ b/src/net/http/cgi/integration_test.go
@@ -20,7 +20,6 @@ import (
"os"
"strings"
"testing"
- "time"
)
// This test is a CGI host (testing host.go) that runs its own binary
@@ -31,7 +30,6 @@ func TestHostingOurselves(t *testing.T) {
h := &Handler{
Path: os.Args[0],
Root: "/test.go",
- Args: []string{"-test.run=TestBeChildCGIProcess"},
}
expectedMap := map[string]string{
"test": "Hello CGI-in-CGI",
@@ -98,9 +96,8 @@ func TestKillChildAfterCopyError(t *testing.T) {
h := &Handler{
Path: os.Args[0],
Root: "/test.go",
- Args: []string{"-test.run=TestBeChildCGIProcess"},
}
- req, _ := http.NewRequest("GET", "http://example.com/test.cgi?write-forever=1", nil)
+ req, _ := http.NewRequest("GET", "http://example.com/test.go?write-forever=1", nil)
rec := httptest.NewRecorder()
var out bytes.Buffer
const writeLen = 50 << 10
@@ -120,7 +117,6 @@ func TestChildOnlyHeaders(t *testing.T) {
h := &Handler{
Path: os.Args[0],
Root: "/test.go",
- Args: []string{"-test.run=TestBeChildCGIProcess"},
}
expectedMap := map[string]string{
"_body": "",
@@ -139,7 +135,6 @@ func TestNilRequestBody(t *testing.T) {
h := &Handler{
Path: os.Args[0],
Root: "/test.go",
- Args: []string{"-test.run=TestBeChildCGIProcess"},
}
expectedMap := map[string]string{
"nil-request-body": "false",
@@ -154,7 +149,6 @@ func TestChildContentType(t *testing.T) {
h := &Handler{
Path: os.Args[0],
Root: "/test.go",
- Args: []string{"-test.run=TestBeChildCGIProcess"},
}
var tests = []struct {
name string
@@ -202,7 +196,6 @@ func want500Test(t *testing.T, path string) {
h := &Handler{
Path: os.Args[0],
Root: "/test.go",
- Args: []string{"-test.run=TestBeChildCGIProcess"},
}
expectedMap := map[string]string{
"_body": "",
@@ -212,61 +205,3 @@ func want500Test(t *testing.T, path string) {
t.Errorf("Got code %d; want 500", replay.Code)
}
}
-
-type neverEnding byte
-
-func (b neverEnding) Read(p []byte) (n int, err error) {
- for i := range p {
- p[i] = byte(b)
- }
- return len(p), nil
-}
-
-// Note: not actually a test.
-func TestBeChildCGIProcess(t *testing.T) {
- if os.Getenv("REQUEST_METHOD") == "" {
- // Not in a CGI environment; skipping test.
- return
- }
- switch os.Getenv("REQUEST_URI") {
- case "/immediate-disconnect":
- os.Exit(0)
- case "/no-content-type":
- fmt.Printf("Content-Length: 6\n\nHello\n")
- os.Exit(0)
- case "/empty-headers":
- fmt.Printf("\nHello")
- os.Exit(0)
- }
- Serve(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
- if req.FormValue("nil-request-body") == "1" {
- fmt.Fprintf(rw, "nil-request-body=%v\n", req.Body == nil)
- return
- }
- rw.Header().Set("X-Test-Header", "X-Test-Value")
- req.ParseForm()
- if req.FormValue("no-body") == "1" {
- return
- }
- if eb, ok := req.Form["exact-body"]; ok {
- io.WriteString(rw, eb[0])
- return
- }
- if req.FormValue("write-forever") == "1" {
- io.Copy(rw, neverEnding('a'))
- for {
- time.Sleep(5 * time.Second) // hang forever, until killed
- }
- }
- fmt.Fprintf(rw, "test=Hello CGI-in-CGI\n")
- for k, vv := range req.Form {
- for _, v := range vv {
- fmt.Fprintf(rw, "param-%s=%s\n", k, v)
- }
- }
- for _, kv := range os.Environ() {
- fmt.Fprintf(rw, "env-%s\n", kv)
- }
- }))
- os.Exit(0)
-}
diff --git a/src/net/http/cgi/plan9_test.go b/src/net/http/cgi/plan9_test.go
deleted file mode 100644
index b7ace3f81c..0000000000
--- a/src/net/http/cgi/plan9_test.go
+++ /dev/null
@@ -1,17 +0,0 @@
-// Copyright 2013 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build plan9
-
-package cgi
-
-import (
- "os"
- "strconv"
-)
-
-func isProcessRunning(pid int) bool {
- _, err := os.Stat("/proc/" + strconv.Itoa(pid))
- return err == nil
-}
diff --git a/src/net/http/cgi/posix_test.go b/src/net/http/cgi/posix_test.go
deleted file mode 100644
index 49b9470d4a..0000000000
--- a/src/net/http/cgi/posix_test.go
+++ /dev/null
@@ -1,20 +0,0 @@
-// Copyright 2013 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build !plan9
-
-package cgi
-
-import (
- "os"
- "syscall"
-)
-
-func isProcessRunning(pid int) bool {
- p, err := os.FindProcess(pid)
- if err != nil {
- return false
- }
- return p.Signal(syscall.Signal(0)) == nil
-}
diff --git a/src/net/http/cgi/testdata/test.cgi b/src/net/http/cgi/testdata/test.cgi
deleted file mode 100755
index 667fce217e..0000000000
--- a/src/net/http/cgi/testdata/test.cgi
+++ /dev/null
@@ -1,95 +0,0 @@
-#!/usr/bin/perl
-# Copyright 2011 The Go Authors. All rights reserved.
-# Use of this source code is governed by a BSD-style
-# license that can be found in the LICENSE file.
-#
-# Test script run as a child process under cgi_test.go
-
-use strict;
-use Cwd;
-
-binmode STDOUT;
-
-my $q = MiniCGI->new;
-my $params = $q->Vars;
-
-if ($params->{"loc"}) {
- print "Location: $params->{loc}\r\n\r\n";
- exit(0);
-}
-
-print "Content-Type: text/html\r\n";
-print "X-CGI-Pid: $$\r\n";
-print "X-Test-Header: X-Test-Value\r\n";
-print "\r\n";
-
-if ($params->{"writestderr"}) {
- print STDERR "Hello, stderr!\n";
-}
-
-if ($params->{"bigresponse"}) {
- # 17 MB, for OS X: golang.org/issue/4958
- for (1..(17 * 1024)) {
- print "A" x 1024, "\r\n";
- }
- exit 0;
-}
-
-print "test=Hello CGI\r\n";
-
-foreach my $k (sort keys %$params) {
- print "param-$k=$params->{$k}\r\n";
-}
-
-foreach my $k (sort keys %ENV) {
- my $clean_env = $ENV{$k};
- $clean_env =~ s/[\n\r]//g;
- print "env-$k=$clean_env\r\n";
-}
-
-# NOTE: msys perl returns /c/go/src/... not C:\go\....
-my $dir = getcwd();
-if ($^O eq 'MSWin32' || $^O eq 'msys' || $^O eq 'cygwin') {
- if ($dir =~ /^.:/) {
- $dir =~ s!/!\\!g;
- } else {
- my $cmd = $ENV{'COMSPEC'} || 'c:\\windows\\system32\\cmd.exe';
- $cmd =~ s!\\!/!g;
- $dir = `$cmd /c cd`;
- chomp $dir;
- }
-}
-print "cwd=$dir\r\n";
-
-# A minimal version of CGI.pm, for people without the perl-modules
-# package installed. (CGI.pm used to be part of the Perl core, but
-# some distros now bundle perl-base and perl-modules separately...)
-package MiniCGI;
-
-sub new {
- my $class = shift;
- return bless {}, $class;
-}
-
-sub Vars {
- my $self = shift;
- my $pairs;
- if ($ENV{CONTENT_LENGTH}) {
- $pairs = do { local $/; <STDIN> };
- } else {
- $pairs = $ENV{QUERY_STRING};
- }
- my $vars = {};
- foreach my $kv (split(/&/, $pairs)) {
- my ($k, $v) = split(/=/, $kv, 2);
- $vars->{_urldecode($k)} = _urldecode($v);
- }
- return $vars;
-}
-
-sub _urldecode {
- my $v = shift;
- $v =~ tr/+/ /;
- $v =~ s/%([a-fA-F0-9][a-fA-F0-9])/pack("C", hex($1))/eg;
- return $v;
-}
diff --git a/src/net/http/client.go b/src/net/http/client.go
index 2cab53a585..8fc348fe5d 100644
--- a/src/net/http/client.go
+++ b/src/net/http/client.go
@@ -27,34 +27,33 @@ import (
"time"
)
-// A Client is an HTTP client. Its zero value (DefaultClient) is a
-// usable client that uses DefaultTransport.
+// A Client is an HTTP client. Its zero value ([DefaultClient]) is a
+// usable client that uses [DefaultTransport].
//
-// The Client's Transport typically has internal state (cached TCP
+// The [Client.Transport] typically has internal state (cached TCP
// connections), so Clients should be reused instead of created as
// needed. Clients are safe for concurrent use by multiple goroutines.
//
-// A Client is higher-level than a RoundTripper (such as Transport)
+// A Client is higher-level than a [RoundTripper] (such as [Transport])
// and additionally handles HTTP details such as cookies and
// redirects.
//
// When following redirects, the Client will forward all headers set on the
-// initial Request except:
+// initial [Request] except:
//
-// • when forwarding sensitive headers like "Authorization",
-// "WWW-Authenticate", and "Cookie" to untrusted targets.
-// These headers will be ignored when following a redirect to a domain
-// that is not a subdomain match or exact match of the initial domain.
-// For example, a redirect from "foo.com" to either "foo.com" or "sub.foo.com"
-// will forward the sensitive headers, but a redirect to "bar.com" will not.
-//
-// • when forwarding the "Cookie" header with a non-nil cookie Jar.
-// Since each redirect may mutate the state of the cookie jar,
-// a redirect may possibly alter a cookie set in the initial request.
-// When forwarding the "Cookie" header, any mutated cookies will be omitted,
-// with the expectation that the Jar will insert those mutated cookies
-// with the updated values (assuming the origin matches).
-// If Jar is nil, the initial cookies are forwarded without change.
+// - when forwarding sensitive headers like "Authorization",
+// "WWW-Authenticate", and "Cookie" to untrusted targets.
+// These headers will be ignored when following a redirect to a domain
+// that is not a subdomain match or exact match of the initial domain.
+// For example, a redirect from "foo.com" to either "foo.com" or "sub.foo.com"
+// will forward the sensitive headers, but a redirect to "bar.com" will not.
+// - when forwarding the "Cookie" header with a non-nil cookie Jar.
+// Since each redirect may mutate the state of the cookie jar,
+// a redirect may possibly alter a cookie set in the initial request.
+// When forwarding the "Cookie" header, any mutated cookies will be omitted,
+// with the expectation that the Jar will insert those mutated cookies
+// with the updated values (assuming the origin matches).
+// If Jar is nil, the initial cookies are forwarded without change.
type Client struct {
// Transport specifies the mechanism by which individual
// HTTP requests are made.
@@ -106,11 +105,11 @@ type Client struct {
Timeout time.Duration
}
-// DefaultClient is the default Client and is used by Get, Head, and Post.
+// DefaultClient is the default [Client] and is used by [Get], [Head], and [Post].
var DefaultClient = &Client{}
// RoundTripper is an interface representing the ability to execute a
-// single HTTP transaction, obtaining the Response for a given Request.
+// single HTTP transaction, obtaining the [Response] for a given [Request].
//
// A RoundTripper must be safe for concurrent use by multiple
// goroutines.
@@ -440,7 +439,7 @@ func basicAuth(username, password string) string {
//
// An error is returned if there were too many redirects or if there
// was an HTTP protocol error. A non-2xx response doesn't cause an
-// error. Any returned error will be of type *url.Error. The url.Error
+// error. Any returned error will be of type [*url.Error]. The url.Error
// value's Timeout method will report true if the request timed out.
//
// When err is nil, resp always contains a non-nil resp.Body.
@@ -448,10 +447,10 @@ func basicAuth(username, password string) string {
//
// Get is a wrapper around DefaultClient.Get.
//
-// To make a request with custom headers, use NewRequest and
+// To make a request with custom headers, use [NewRequest] and
// DefaultClient.Do.
//
-// To make a request with a specified context.Context, use NewRequestWithContext
+// To make a request with a specified context.Context, use [NewRequestWithContext]
// and DefaultClient.Do.
func Get(url string) (resp *Response, err error) {
return DefaultClient.Get(url)
@@ -459,7 +458,7 @@ func Get(url string) (resp *Response, err error) {
// Get issues a GET to the specified URL. If the response is one of the
// following redirect codes, Get follows the redirect after calling the
-// Client's CheckRedirect function:
+// [Client.CheckRedirect] function:
//
// 301 (Moved Permanently)
// 302 (Found)
@@ -467,18 +466,18 @@ func Get(url string) (resp *Response, err error) {
// 307 (Temporary Redirect)
// 308 (Permanent Redirect)
//
-// An error is returned if the Client's CheckRedirect function fails
+// An error is returned if the [Client.CheckRedirect] function fails
// or if there was an HTTP protocol error. A non-2xx response doesn't
-// cause an error. Any returned error will be of type *url.Error. The
+// cause an error. Any returned error will be of type [*url.Error]. The
// url.Error value's Timeout method will report true if the request
// timed out.
//
// When err is nil, resp always contains a non-nil resp.Body.
// Caller should close resp.Body when done reading from it.
//
-// To make a request with custom headers, use NewRequest and Client.Do.
+// To make a request with custom headers, use [NewRequest] and [Client.Do].
//
-// To make a request with a specified context.Context, use NewRequestWithContext
+// To make a request with a specified context.Context, use [NewRequestWithContext]
// and Client.Do.
func (c *Client) Get(url string) (resp *Response, err error) {
req, err := NewRequest("GET", url, nil)
@@ -559,20 +558,21 @@ func urlErrorOp(method string) string {
// connectivity problem). A non-2xx status code doesn't cause an
// error.
//
-// If the returned error is nil, the Response will contain a non-nil
+// If the returned error is nil, the [Response] will contain a non-nil
// Body which the user is expected to close. If the Body is not both
-// read to EOF and closed, the Client's underlying RoundTripper
-// (typically Transport) may not be able to re-use a persistent TCP
+// read to EOF and closed, the [Client]'s underlying [RoundTripper]
+// (typically [Transport]) may not be able to re-use a persistent TCP
// connection to the server for a subsequent "keep-alive" request.
//
// The request Body, if non-nil, will be closed by the underlying
-// Transport, even on errors.
+// Transport, even on errors. The Body may be closed asynchronously after
+// Do returns.
//
// On error, any Response can be ignored. A non-nil Response with a
// non-nil error only occurs when CheckRedirect fails, and even then
-// the returned Response.Body is already closed.
+// the returned [Response.Body] is already closed.
//
-// Generally Get, Post, or PostForm will be used instead of Do.
+// Generally [Get], [Post], or [PostForm] will be used instead of Do.
//
// If the server replies with a redirect, the Client first uses the
// CheckRedirect function to determine whether the redirect should be
@@ -580,11 +580,11 @@ func urlErrorOp(method string) string {
// subsequent requests to use HTTP method GET
// (or HEAD if the original request was HEAD), with no body.
// A 307 or 308 redirect preserves the original HTTP method and body,
-// provided that the Request.GetBody function is defined.
-// The NewRequest function automatically sets GetBody for common
+// provided that the [Request.GetBody] function is defined.
+// The [NewRequest] function automatically sets GetBody for common
// standard library body types.
//
-// Any returned error will be of type *url.Error. The url.Error
+// Any returned error will be of type [*url.Error]. The url.Error
// value's Timeout method will report true if the request timed out.
func (c *Client) Do(req *Request) (*Response, error) {
return c.do(req)
@@ -818,17 +818,17 @@ func defaultCheckRedirect(req *Request, via []*Request) error {
//
// Caller should close resp.Body when done reading from it.
//
-// If the provided body is an io.Closer, it is closed after the
+// If the provided body is an [io.Closer], it is closed after the
// request.
//
// Post is a wrapper around DefaultClient.Post.
//
-// To set custom headers, use NewRequest and DefaultClient.Do.
+// To set custom headers, use [NewRequest] and DefaultClient.Do.
//
-// See the Client.Do method documentation for details on how redirects
+// See the [Client.Do] method documentation for details on how redirects
// are handled.
//
-// To make a request with a specified context.Context, use NewRequestWithContext
+// To make a request with a specified context.Context, use [NewRequestWithContext]
// and DefaultClient.Do.
func Post(url, contentType string, body io.Reader) (resp *Response, err error) {
return DefaultClient.Post(url, contentType, body)
@@ -838,13 +838,13 @@ func Post(url, contentType string, body io.Reader) (resp *Response, err error) {
//
// Caller should close resp.Body when done reading from it.
//
-// If the provided body is an io.Closer, it is closed after the
+// If the provided body is an [io.Closer], it is closed after the
// request.
//
-// To set custom headers, use NewRequest and Client.Do.
+// To set custom headers, use [NewRequest] and [Client.Do].
//
-// To make a request with a specified context.Context, use NewRequestWithContext
-// and Client.Do.
+// To make a request with a specified context.Context, use [NewRequestWithContext]
+// and [Client.Do].
//
// See the Client.Do method documentation for details on how redirects
// are handled.
@@ -861,17 +861,17 @@ func (c *Client) Post(url, contentType string, body io.Reader) (resp *Response,
// values URL-encoded as the request body.
//
// The Content-Type header is set to application/x-www-form-urlencoded.
-// To set other headers, use NewRequest and DefaultClient.Do.
+// To set other headers, use [NewRequest] and DefaultClient.Do.
//
// When err is nil, resp always contains a non-nil resp.Body.
// Caller should close resp.Body when done reading from it.
//
// PostForm is a wrapper around DefaultClient.PostForm.
//
-// See the Client.Do method documentation for details on how redirects
+// See the [Client.Do] method documentation for details on how redirects
// are handled.
//
-// To make a request with a specified context.Context, use NewRequestWithContext
+// To make a request with a specified [context.Context], use [NewRequestWithContext]
// and DefaultClient.Do.
func PostForm(url string, data url.Values) (resp *Response, err error) {
return DefaultClient.PostForm(url, data)
@@ -881,7 +881,7 @@ func PostForm(url string, data url.Values) (resp *Response, err error) {
// with data's keys and values URL-encoded as the request body.
//
// The Content-Type header is set to application/x-www-form-urlencoded.
-// To set other headers, use NewRequest and Client.Do.
+// To set other headers, use [NewRequest] and [Client.Do].
//
// When err is nil, resp always contains a non-nil resp.Body.
// Caller should close resp.Body when done reading from it.
@@ -889,7 +889,7 @@ func PostForm(url string, data url.Values) (resp *Response, err error) {
// See the Client.Do method documentation for details on how redirects
// are handled.
//
-// To make a request with a specified context.Context, use NewRequestWithContext
+// To make a request with a specified context.Context, use [NewRequestWithContext]
// and Client.Do.
func (c *Client) PostForm(url string, data url.Values) (resp *Response, err error) {
return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
@@ -907,7 +907,7 @@ func (c *Client) PostForm(url string, data url.Values) (resp *Response, err erro
//
// Head is a wrapper around DefaultClient.Head.
//
-// To make a request with a specified context.Context, use NewRequestWithContext
+// To make a request with a specified [context.Context], use [NewRequestWithContext]
// and DefaultClient.Do.
func Head(url string) (resp *Response, err error) {
return DefaultClient.Head(url)
@@ -915,7 +915,7 @@ func Head(url string) (resp *Response, err error) {
// Head issues a HEAD to the specified URL. If the response is one of the
// following redirect codes, Head follows the redirect after calling the
-// Client's CheckRedirect function:
+// [Client.CheckRedirect] function:
//
// 301 (Moved Permanently)
// 302 (Found)
@@ -923,8 +923,8 @@ func Head(url string) (resp *Response, err error) {
// 307 (Temporary Redirect)
// 308 (Permanent Redirect)
//
-// To make a request with a specified context.Context, use NewRequestWithContext
-// and Client.Do.
+// To make a request with a specified [context.Context], use [NewRequestWithContext]
+// and [Client.Do].
func (c *Client) Head(url string) (resp *Response, err error) {
req, err := NewRequest("HEAD", url, nil)
if err != nil {
@@ -933,12 +933,12 @@ func (c *Client) Head(url string) (resp *Response, err error) {
return c.Do(req)
}
-// CloseIdleConnections closes any connections on its Transport which
+// CloseIdleConnections closes any connections on its [Transport] which
// were previously connected from previous requests but are now
// sitting idle in a "keep-alive" state. It does not interrupt any
// connections currently in use.
//
-// If the Client's Transport does not have a CloseIdleConnections method
+// If [Client.Transport] does not have a [Client.CloseIdleConnections] method
// then this method does nothing.
func (c *Client) CloseIdleConnections() {
type closeIdler interface {
@@ -1014,6 +1014,12 @@ func isDomainOrSubdomain(sub, parent string) bool {
if sub == parent {
return true
}
+ // If sub contains a :, it's probably an IPv6 address (and is definitely not a hostname).
+ // Don't check the suffix in this case, to avoid matching the contents of a IPv6 zone.
+ // For example, "::1%.www.example.com" is not a subdomain of "www.example.com".
+ if strings.ContainsAny(sub, ":%") {
+ return false
+ }
// If sub is "foo.example.com" and parent is "example.com",
// that means sub must end in "."+parent.
// Do it without allocating.
diff --git a/src/net/http/client_test.go b/src/net/http/client_test.go
index 0fe555af38..e2a1cbbdea 100644
--- a/src/net/http/client_test.go
+++ b/src/net/http/client_test.go
@@ -60,13 +60,6 @@ func pedanticReadAll(r io.Reader) (b []byte, err error) {
}
}
-type chanWriter chan string
-
-func (w chanWriter) Write(p []byte) (n int, err error) {
- w <- string(p)
- return len(p), nil
-}
-
func TestClient(t *testing.T) { run(t, testClient) }
func testClient(t *testing.T, mode testMode) {
ts := newClientServerTest(t, mode, robotsTxtHandler).ts
@@ -827,12 +820,12 @@ func TestClientInsecureTransport(t *testing.T) {
run(t, testClientInsecureTransport, []testMode{https1Mode, http2Mode})
}
func testClientInsecureTransport(t *testing.T, mode testMode) {
- ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
w.Write([]byte("Hello"))
- })).ts
- errc := make(chanWriter, 10) // but only expecting 1
- ts.Config.ErrorLog = log.New(errc, "", 0)
- defer ts.Close()
+ }))
+ ts := cst.ts
+ errLog := new(strings.Builder)
+ ts.Config.ErrorLog = log.New(errLog, "", 0)
// TODO(bradfitz): add tests for skipping hostname checks too?
// would require a new cert for testing, and probably
@@ -851,15 +844,10 @@ func testClientInsecureTransport(t *testing.T, mode testMode) {
}
}
- select {
- case v := <-errc:
- if !strings.Contains(v, "TLS handshake error") {
- t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v)
- }
- case <-time.After(5 * time.Second):
- t.Errorf("timeout waiting for logged error")
+ cst.close()
+ if !strings.Contains(errLog.String(), "TLS handshake error") {
+ t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", errLog)
}
-
}
func TestClientErrorWithRequestURI(t *testing.T) {
@@ -897,9 +885,10 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) {
run(t, testClientWithIncorrectTLSServerName, []testMode{https1Mode, http2Mode})
}
func testClientWithIncorrectTLSServerName(t *testing.T, mode testMode) {
- ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
- errc := make(chanWriter, 10) // but only expecting 1
- ts.Config.ErrorLog = log.New(errc, "", 0)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}))
+ ts := cst.ts
+ errLog := new(strings.Builder)
+ ts.Config.ErrorLog = log.New(errLog, "", 0)
c := ts.Client()
c.Transport.(*Transport).TLSClientConfig.ServerName = "badserver"
@@ -910,13 +899,10 @@ func testClientWithIncorrectTLSServerName(t *testing.T, mode testMode) {
if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") {
t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err)
}
- select {
- case v := <-errc:
- if !strings.Contains(v, "TLS handshake error") {
- t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v)
- }
- case <-time.After(5 * time.Second):
- t.Errorf("timeout waiting for logged error")
+
+ cst.close()
+ if !strings.Contains(errLog.String(), "TLS handshake error") {
+ t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", errLog)
}
}
@@ -960,7 +946,7 @@ func testResponseSetsTLSConnectionState(t *testing.T, mode testMode) {
c := ts.Client()
tr := c.Transport.(*Transport)
- tr.TLSClientConfig.CipherSuites = []uint16{tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA}
+ tr.TLSClientConfig.CipherSuites = []uint16{tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA}
tr.TLSClientConfig.MaxVersion = tls.VersionTLS12 // to get to pick the cipher suite
tr.Dial = func(netw, addr string) (net.Conn, error) {
return net.Dial(netw, ts.Listener.Addr().String())
@@ -973,7 +959,7 @@ func testResponseSetsTLSConnectionState(t *testing.T, mode testMode) {
if res.TLS == nil {
t.Fatal("Response didn't set TLS Connection State.")
}
- if got, want := res.TLS.CipherSuite, tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA; got != want {
+ if got, want := res.TLS.CipherSuite, tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA; got != want {
t.Errorf("TLS Cipher Suite = %d; want %d", got, want)
}
}
@@ -1725,6 +1711,7 @@ func TestShouldCopyHeaderOnRedirect(t *testing.T) {
{"authorization", "http://foo.com/", "https://foo.com/", true},
{"authorization", "http://foo.com:1234/", "http://foo.com:4321/", true},
{"www-authenticate", "http://foo.com/", "http://bar.com/", false},
+ {"authorization", "http://foo.com/", "http://[::1%25.foo.com]/", false},
// But subdomains should work:
{"www-authenticate", "http://foo.com/", "http://foo.com/", true},
diff --git a/src/net/http/clientserver_test.go b/src/net/http/clientserver_test.go
index 58321532ea..32948f3aed 100644
--- a/src/net/http/clientserver_test.go
+++ b/src/net/http/clientserver_test.go
@@ -1172,16 +1172,12 @@ func testTransportGCRequest(t *testing.T, mode testMode, body bool) {
t.Fatal(err)
}
})()
- timeout := time.NewTimer(5 * time.Second)
- defer timeout.Stop()
for {
select {
case <-didGC:
return
- case <-time.After(100 * time.Millisecond):
+ case <-time.After(1 * time.Millisecond):
runtime.GC()
- case <-timeout.C:
- t.Fatal("never saw GC of request")
}
}
}
diff --git a/src/net/http/cookie.go b/src/net/http/cookie.go
index 912fde6b95..c22897f3f9 100644
--- a/src/net/http/cookie.go
+++ b/src/net/http/cookie.go
@@ -163,7 +163,7 @@ func readSetCookies(h Header) []*Cookie {
return cookies
}
-// SetCookie adds a Set-Cookie header to the provided ResponseWriter's headers.
+// SetCookie adds a Set-Cookie header to the provided [ResponseWriter]'s headers.
// The provided cookie must have a valid Name. Invalid cookies may be
// silently dropped.
func SetCookie(w ResponseWriter, cookie *Cookie) {
@@ -172,7 +172,7 @@ func SetCookie(w ResponseWriter, cookie *Cookie) {
}
}
-// String returns the serialization of the cookie for use in a Cookie
+// String returns the serialization of the cookie for use in a [Cookie]
// header (if only Name and Value are set) or a Set-Cookie response
// header (if other fields are set).
// If c is nil or c.Name is invalid, the empty string is returned.
diff --git a/src/net/http/cookiejar/jar.go b/src/net/http/cookiejar/jar.go
index 273b54c84c..e7f5ddd4d0 100644
--- a/src/net/http/cookiejar/jar.go
+++ b/src/net/http/cookiejar/jar.go
@@ -73,7 +73,7 @@ type Jar struct {
nextSeqNum uint64
}
-// New returns a new cookie jar. A nil *Options is equivalent to a zero
+// New returns a new cookie jar. A nil [*Options] is equivalent to a zero
// Options.
func New(o *Options) (*Jar, error) {
jar := &Jar{
@@ -151,7 +151,7 @@ func hasDotSuffix(s, suffix string) bool {
return len(s) > len(suffix) && s[len(s)-len(suffix)-1] == '.' && s[len(s)-len(suffix):] == suffix
}
-// Cookies implements the Cookies method of the http.CookieJar interface.
+// Cookies implements the Cookies method of the [http.CookieJar] interface.
//
// It returns an empty slice if the URL's scheme is not HTTP or HTTPS.
func (j *Jar) Cookies(u *url.URL) (cookies []*http.Cookie) {
@@ -226,7 +226,7 @@ func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) {
return cookies
}
-// SetCookies implements the SetCookies method of the http.CookieJar interface.
+// SetCookies implements the SetCookies method of the [http.CookieJar] interface.
//
// It does nothing if the URL's scheme is not HTTP or HTTPS.
func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) {
@@ -362,6 +362,13 @@ func jarKey(host string, psl PublicSuffixList) string {
// isIP reports whether host is an IP address.
func isIP(host string) bool {
+ if strings.ContainsAny(host, ":%") {
+ // Probable IPv6 address.
+ // Hostnames can't contain : or %, so this is definitely not a valid host.
+ // Treating it as an IP is the more conservative option, and avoids the risk
+ // of interpeting ::1%.www.example.com as a subtomain of www.example.com.
+ return true
+ }
return net.ParseIP(host) != nil
}
@@ -440,7 +447,6 @@ func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e e
var (
errIllegalDomain = errors.New("cookiejar: illegal cookie domain attribute")
errMalformedDomain = errors.New("cookiejar: malformed cookie domain attribute")
- errNoHostname = errors.New("cookiejar: no host name available (IP only)")
)
// endOfTime is the time when session (non-persistent) cookies expire.
diff --git a/src/net/http/cookiejar/jar_test.go b/src/net/http/cookiejar/jar_test.go
index 56d0695a66..251f7c1617 100644
--- a/src/net/http/cookiejar/jar_test.go
+++ b/src/net/http/cookiejar/jar_test.go
@@ -252,6 +252,7 @@ var isIPTests = map[string]bool{
"127.0.0.1": true,
"1.2.3.4": true,
"2001:4860:0:2001::68": true,
+ "::1%zone": true,
"example.com": false,
"1.1.1.300": false,
"www.foo.bar.net": false,
@@ -629,6 +630,15 @@ var basicsTests = [...]jarTest{
{"http://www.host.test:1234/", "a=1"},
},
},
+ {
+ "IPv6 zone is not treated as a host.",
+ "https://example.com/",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"https://[::1%25.example.com]:80/", ""},
+ },
+ },
}
func TestBasics(t *testing.T) {
diff --git a/src/net/http/doc.go b/src/net/http/doc.go
index d9e6aafb4e..f7ad3ae762 100644
--- a/src/net/http/doc.go
+++ b/src/net/http/doc.go
@@ -5,7 +5,7 @@
/*
Package http provides HTTP client and server implementations.
-Get, Head, Post, and PostForm make HTTP (or HTTPS) requests:
+[Get], [Head], [Post], and [PostForm] make HTTP (or HTTPS) requests:
resp, err := http.Get("http://example.com/")
...
@@ -27,7 +27,7 @@ The caller must close the response body when finished with it:
# Clients and Transports
For control over HTTP client headers, redirect policy, and other
-settings, create a Client:
+settings, create a [Client]:
client := &http.Client{
CheckRedirect: redirectPolicyFunc,
@@ -43,7 +43,7 @@ settings, create a Client:
// ...
For control over proxies, TLS configuration, keep-alives,
-compression, and other settings, create a Transport:
+compression, and other settings, create a [Transport]:
tr := &http.Transport{
MaxIdleConns: 10,
@@ -59,8 +59,8 @@ goroutines and for efficiency should only be created once and re-used.
# Servers
ListenAndServe starts an HTTP server with a given address and handler.
-The handler is usually nil, which means to use DefaultServeMux.
-Handle and HandleFunc add handlers to DefaultServeMux:
+The handler is usually nil, which means to use [DefaultServeMux].
+[Handle] and [HandleFunc] add handlers to [DefaultServeMux]:
http.Handle("/foo", fooHandler)
@@ -86,8 +86,8 @@ custom Server:
Starting with Go 1.6, the http package has transparent support for the
HTTP/2 protocol when using HTTPS. Programs that must disable HTTP/2
-can do so by setting Transport.TLSNextProto (for clients) or
-Server.TLSNextProto (for servers) to a non-nil, empty
+can do so by setting [Transport.TLSNextProto] (for clients) or
+[Server.TLSNextProto] (for servers) to a non-nil, empty
map. Alternatively, the following GODEBUG settings are
currently supported:
@@ -98,7 +98,7 @@ currently supported:
Please report any issues before disabling HTTP/2 support: https://golang.org/s/http2bug
-The http package's Transport and Server both automatically enable
+The http package's [Transport] and [Server] both automatically enable
HTTP/2 support for simple configurations. To enable HTTP/2 for more
complex configurations, to use lower-level HTTP/2 features, or to use
a newer version of Go's http2 package, import "golang.org/x/net/http2"
diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go
index 5d198f3f89..7e6d3d8e30 100644
--- a/src/net/http/export_test.go
+++ b/src/net/http/export_test.go
@@ -315,3 +315,21 @@ func ResponseWriterConnForTesting(w ResponseWriter) (c net.Conn, ok bool) {
}
return nil, false
}
+
+func init() {
+ // Set the default rstAvoidanceDelay to the minimum possible value to shake
+ // out tests that unexpectedly depend on it. Such tests should use
+ // runTimeSensitiveTest and SetRSTAvoidanceDelay to explicitly raise the delay
+ // if needed.
+ rstAvoidanceDelay = 1 * time.Nanosecond
+}
+
+// SetRSTAvoidanceDelay sets how long we are willing to wait between calling
+// CloseWrite on a connection and fully closing the connection.
+func SetRSTAvoidanceDelay(t *testing.T, d time.Duration) {
+ prevDelay := rstAvoidanceDelay
+ t.Cleanup(func() {
+ rstAvoidanceDelay = prevDelay
+ })
+ rstAvoidanceDelay = d
+}
diff --git a/src/net/http/fcgi/child.go b/src/net/http/fcgi/child.go
index dc82bf7c3a..7665e7d252 100644
--- a/src/net/http/fcgi/child.go
+++ b/src/net/http/fcgi/child.go
@@ -335,7 +335,7 @@ func (c *child) cleanUp() {
// goroutine for each. The goroutine reads requests and then calls handler
// to reply to them.
// If l is nil, Serve accepts connections from os.Stdin.
-// If handler is nil, http.DefaultServeMux is used.
+// If handler is nil, [http.DefaultServeMux] is used.
func Serve(l net.Listener, handler http.Handler) error {
if l == nil {
var err error
diff --git a/src/net/http/filetransport.go b/src/net/http/filetransport.go
index 94684b07a1..7384b22fbe 100644
--- a/src/net/http/filetransport.go
+++ b/src/net/http/filetransport.go
@@ -7,6 +7,7 @@ package http
import (
"fmt"
"io"
+ "io/fs"
)
// fileTransport implements RoundTripper for the 'file' protocol.
@@ -14,13 +15,13 @@ type fileTransport struct {
fh fileHandler
}
-// NewFileTransport returns a new RoundTripper, serving the provided
-// FileSystem. The returned RoundTripper ignores the URL host in its
+// NewFileTransport returns a new [RoundTripper], serving the provided
+// [FileSystem]. The returned RoundTripper ignores the URL host in its
// incoming requests, as well as most other properties of the
// request.
//
// The typical use case for NewFileTransport is to register the "file"
-// protocol with a Transport, as in:
+// protocol with a [Transport], as in:
//
// t := &http.Transport{}
// t.RegisterProtocol("file", http.NewFileTransport(http.Dir("/")))
@@ -31,6 +32,24 @@ func NewFileTransport(fs FileSystem) RoundTripper {
return fileTransport{fileHandler{fs}}
}
+// NewFileTransportFS returns a new [RoundTripper], serving the provided
+// file system fsys. The returned RoundTripper ignores the URL host in its
+// incoming requests, as well as most other properties of the
+// request.
+//
+// The typical use case for NewFileTransportFS is to register the "file"
+// protocol with a [Transport], as in:
+//
+// fsys := os.DirFS("/")
+// t := &http.Transport{}
+// t.RegisterProtocol("file", http.NewFileTransportFS(fsys))
+// c := &http.Client{Transport: t}
+// res, err := c.Get("file:///etc/passwd")
+// ...
+func NewFileTransportFS(fsys fs.FS) RoundTripper {
+ return NewFileTransport(FS(fsys))
+}
+
func (t fileTransport) RoundTrip(req *Request) (resp *Response, err error) {
// We start ServeHTTP in a goroutine, which may take a long
// time if the file is large. The newPopulateResponseWriter
diff --git a/src/net/http/filetransport_test.go b/src/net/http/filetransport_test.go
index 77fc8eeccf..b3e3301e10 100644
--- a/src/net/http/filetransport_test.go
+++ b/src/net/http/filetransport_test.go
@@ -9,6 +9,7 @@ import (
"os"
"path/filepath"
"testing"
+ "testing/fstest"
)
func checker(t *testing.T) func(string, error) {
@@ -62,3 +63,44 @@ func TestFileTransport(t *testing.T) {
}
res.Body.Close()
}
+
+func TestFileTransportFS(t *testing.T) {
+ check := checker(t)
+
+ fsys := fstest.MapFS{
+ "index.html": {Data: []byte("index.html says hello")},
+ }
+
+ tr := &Transport{}
+ tr.RegisterProtocol("file", NewFileTransportFS(fsys))
+ c := &Client{Transport: tr}
+
+ for fname, mfile := range fsys {
+ urlstr := "file:///" + fname
+ res, err := c.Get(urlstr)
+ check("Get "+urlstr, err)
+ if res.StatusCode != 200 {
+ t.Errorf("for %s, StatusCode = %d, want 200", urlstr, res.StatusCode)
+ }
+ if res.ContentLength != -1 {
+ t.Errorf("for %s, ContentLength = %d, want -1", urlstr, res.ContentLength)
+ }
+ if res.Body == nil {
+ t.Fatalf("for %s, nil Body", urlstr)
+ }
+ slurp, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ check("ReadAll "+urlstr, err)
+ if string(slurp) != string(mfile.Data) {
+ t.Errorf("for %s, got content %q, want %q", urlstr, string(slurp), "Bar")
+ }
+ }
+
+ const badURL = "file://../no-exist.txt"
+ res, err := c.Get(badURL)
+ check("Get "+badURL, err)
+ if res.StatusCode != 404 {
+ t.Errorf("for %s, StatusCode = %d, want 404", badURL, res.StatusCode)
+ }
+ res.Body.Close()
+}
diff --git a/src/net/http/fs.go b/src/net/http/fs.go
index 41e0b43ac8..af7511a7a4 100644
--- a/src/net/http/fs.go
+++ b/src/net/http/fs.go
@@ -25,12 +25,12 @@ import (
"time"
)
-// A Dir implements FileSystem using the native file system restricted to a
+// A Dir implements [FileSystem] using the native file system restricted to a
// specific directory tree.
//
-// While the FileSystem.Open method takes '/'-separated paths, a Dir's string
+// While the [FileSystem.Open] method takes '/'-separated paths, a Dir's string
// value is a filename on the native file system, not a URL, so it is separated
-// by filepath.Separator, which isn't necessarily '/'.
+// by [filepath.Separator], which isn't necessarily '/'.
//
// Note that Dir could expose sensitive files and directories. Dir will follow
// symlinks pointing out of the directory tree, which can be especially dangerous
@@ -67,7 +67,7 @@ func mapOpenError(originalErr error, name string, sep rune, stat func(string) (f
return originalErr
}
-// Open implements FileSystem using os.Open, opening files for reading rooted
+// Open implements [FileSystem] using [os.Open], opening files for reading rooted
// and relative to the directory d.
func (d Dir) Open(name string) (File, error) {
path, err := safefilepath.FromFS(path.Clean("/" + name))
@@ -89,18 +89,18 @@ func (d Dir) Open(name string) (File, error) {
// A FileSystem implements access to a collection of named files.
// The elements in a file path are separated by slash ('/', U+002F)
// characters, regardless of host operating system convention.
-// See the FileServer function to convert a FileSystem to a Handler.
+// See the [FileServer] function to convert a FileSystem to a [Handler].
//
-// This interface predates the fs.FS interface, which can be used instead:
-// the FS adapter function converts an fs.FS to a FileSystem.
+// This interface predates the [fs.FS] interface, which can be used instead:
+// the [FS] adapter function converts an fs.FS to a FileSystem.
type FileSystem interface {
Open(name string) (File, error)
}
-// A File is returned by a FileSystem's Open method and can be
-// served by the FileServer implementation.
+// A File is returned by a [FileSystem]'s Open method and can be
+// served by the [FileServer] implementation.
//
-// The methods should behave the same as those on an *os.File.
+// The methods should behave the same as those on an [*os.File].
type File interface {
io.Closer
io.Reader
@@ -167,7 +167,7 @@ func dirList(w ResponseWriter, r *Request, f File) {
}
// ServeContent replies to the request using the content in the
-// provided ReadSeeker. The main benefit of ServeContent over io.Copy
+// provided ReadSeeker. The main benefit of ServeContent over [io.Copy]
// is that it handles Range requests properly, sets the MIME type, and
// handles If-Match, If-Unmodified-Since, If-None-Match, If-Modified-Since,
// and If-Range requests.
@@ -175,7 +175,7 @@ func dirList(w ResponseWriter, r *Request, f File) {
// If the response's Content-Type header is not set, ServeContent
// first tries to deduce the type from name's file extension and,
// if that fails, falls back to reading the first block of the content
-// and passing it to DetectContentType.
+// and passing it to [DetectContentType].
// The name is otherwise unused; in particular it can be empty and is
// never sent in the response.
//
@@ -190,7 +190,7 @@ func dirList(w ResponseWriter, r *Request, f File) {
// If the caller has set w's ETag header formatted per RFC 7232, section 2.3,
// ServeContent uses it to handle requests using If-Match, If-None-Match, or If-Range.
//
-// Note that *os.File implements the io.ReadSeeker interface.
+// Note that [*os.File] implements the [io.ReadSeeker] interface.
func ServeContent(w ResponseWriter, req *Request, name string, modtime time.Time, content io.ReadSeeker) {
sizeFunc := func() (int64, error) {
size, err := content.Seek(0, io.SeekEnd)
@@ -343,10 +343,35 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time,
}
w.Header().Set("Accept-Ranges", "bytes")
- if w.Header().Get("Content-Encoding") == "" {
+
+ // We should be able to unconditionally set the Content-Length here.
+ //
+ // However, there is a pattern observed in the wild that this breaks:
+ // The user wraps the ResponseWriter in one which gzips data written to it,
+ // and sets "Content-Encoding: gzip".
+ //
+ // The user shouldn't be doing this; the serveContent path here depends
+ // on serving seekable data with a known length. If you want to compress
+ // on the fly, then you shouldn't be using ServeFile/ServeContent, or
+ // you should compress the entire file up-front and provide a seekable
+ // view of the compressed data.
+ //
+ // However, since we've observed this pattern in the wild, and since
+ // setting Content-Length here breaks code that mostly-works today,
+ // skip setting Content-Length if the user set Content-Encoding.
+ //
+ // If this is a range request, always set Content-Length.
+ // If the user isn't changing the bytes sent in the ResponseWrite,
+ // the Content-Length will be correct.
+ // If the user is changing the bytes sent, then the range request wasn't
+ // going to work properly anyway and we aren't worse off.
+ //
+ // A possible future improvement on this might be to look at the type
+ // of the ResponseWriter, and always set Content-Length if it's one
+ // that we recognize.
+ if len(ranges) > 0 || w.Header().Get("Content-Encoding") == "" {
w.Header().Set("Content-Length", strconv.FormatInt(sendSize, 10))
}
-
w.WriteHeader(code)
if r.Method != "HEAD" {
@@ -716,13 +741,13 @@ func localRedirect(w ResponseWriter, r *Request, newPath string) {
//
// As a precaution, ServeFile will reject requests where r.URL.Path
// contains a ".." path element; this protects against callers who
-// might unsafely use filepath.Join on r.URL.Path without sanitizing
+// might unsafely use [filepath.Join] on r.URL.Path without sanitizing
// it and then use that filepath.Join result as the name argument.
//
// As another special case, ServeFile redirects any request where r.URL.Path
// ends in "/index.html" to the same path, without the final
// "index.html". To avoid such redirects either modify the path or
-// use ServeContent.
+// use [ServeContent].
//
// Outside of those two special cases, ServeFile does not use
// r.URL.Path for selecting the file or directory to serve; only the
@@ -741,6 +766,40 @@ func ServeFile(w ResponseWriter, r *Request, name string) {
serveFile(w, r, Dir(dir), file, false)
}
+// ServeFileFS replies to the request with the contents
+// of the named file or directory from the file system fsys.
+//
+// If the provided file or directory name is a relative path, it is
+// interpreted relative to the current directory and may ascend to
+// parent directories. If the provided name is constructed from user
+// input, it should be sanitized before calling [ServeFile].
+//
+// As a precaution, ServeFile will reject requests where r.URL.Path
+// contains a ".." path element; this protects against callers who
+// might unsafely use [filepath.Join] on r.URL.Path without sanitizing
+// it and then use that filepath.Join result as the name argument.
+//
+// As another special case, ServeFile redirects any request where r.URL.Path
+// ends in "/index.html" to the same path, without the final
+// "index.html". To avoid such redirects either modify the path or
+// use ServeContent.
+//
+// Outside of those two special cases, ServeFile does not use
+// r.URL.Path for selecting the file or directory to serve; only the
+// file or directory provided in the name argument is used.
+func ServeFileFS(w ResponseWriter, r *Request, fsys fs.FS, name string) {
+ if containsDotDot(r.URL.Path) {
+ // Too many programs use r.URL.Path to construct the argument to
+ // serveFile. Reject the request under the assumption that happened
+ // here and ".." may not be wanted.
+ // Note that name might not contain "..", for example if code (still
+ // incorrectly) used filepath.Join(myDir, r.URL.Path).
+ Error(w, "invalid URL path", StatusBadRequest)
+ return
+ }
+ serveFile(w, r, FS(fsys), name, false)
+}
+
func containsDotDot(v string) bool {
if !strings.Contains(v, "..") {
return false
@@ -831,9 +890,9 @@ func (f ioFile) Readdir(count int) ([]fs.FileInfo, error) {
return list, nil
}
-// FS converts fsys to a FileSystem implementation,
-// for use with FileServer and NewFileTransport.
-// The files provided by fsys must implement io.Seeker.
+// FS converts fsys to a [FileSystem] implementation,
+// for use with [FileServer] and [NewFileTransport].
+// The files provided by fsys must implement [io.Seeker].
func FS(fsys fs.FS) FileSystem {
return ioFS{fsys}
}
@@ -846,17 +905,27 @@ func FS(fsys fs.FS) FileSystem {
// "index.html".
//
// To use the operating system's file system implementation,
-// use http.Dir:
+// use [http.Dir]:
//
// http.Handle("/", http.FileServer(http.Dir("/tmp")))
//
-// To use an fs.FS implementation, use http.FS to convert it:
-//
-// http.Handle("/", http.FileServer(http.FS(fsys)))
+// To use an [fs.FS] implementation, use [http.FileServerFS] instead.
func FileServer(root FileSystem) Handler {
return &fileHandler{root}
}
+// FileServerFS returns a handler that serves HTTP requests
+// with the contents of the file system fsys.
+//
+// As a special case, the returned file server redirects any request
+// ending in "/index.html" to the same path, without the final
+// "index.html".
+//
+// http.Handle("/", http.FileServerFS(fsys))
+func FileServerFS(root fs.FS) Handler {
+ return FileServer(FS(root))
+}
+
func (f *fileHandler) ServeHTTP(w ResponseWriter, r *Request) {
upath := r.URL.Path
if !strings.HasPrefix(upath, "/") {
diff --git a/src/net/http/fs_test.go b/src/net/http/fs_test.go
index 3fb9e01235..861e70caf2 100644
--- a/src/net/http/fs_test.go
+++ b/src/net/http/fs_test.go
@@ -7,13 +7,16 @@ package http_test
import (
"bufio"
"bytes"
+ "compress/gzip"
"errors"
"fmt"
+ "internal/testenv"
"io"
"io/fs"
"mime"
"mime/multipart"
"net"
+ "net/http"
. "net/http"
"net/http/httptest"
"net/url"
@@ -26,6 +29,7 @@ import (
"runtime"
"strings"
"testing"
+ "testing/fstest"
"time"
)
@@ -1265,7 +1269,7 @@ func TestLinuxSendfile(t *testing.T) {
defer ln.Close()
// Attempt to run strace, and skip on failure - this test requires SYS_PTRACE.
- if err := exec.Command("strace", "-f", "-q", os.Args[0], "-test.run=^$").Run(); err != nil {
+ if err := testenv.Command(t, "strace", "-f", "-q", os.Args[0], "-test.run=^$").Run(); err != nil {
t.Skipf("skipping; failed to run strace: %v", err)
}
@@ -1278,7 +1282,7 @@ func TestLinuxSendfile(t *testing.T) {
defer os.Remove(filepath)
var buf strings.Builder
- child := exec.Command("strace", "-f", "-q", os.Args[0], "-test.run=TestLinuxSendfileChild")
+ child := testenv.Command(t, "strace", "-f", "-q", os.Args[0], "-test.run=^TestLinuxSendfileChild$")
child.ExtraFiles = append(child.ExtraFiles, lnf)
child.Env = append([]string{"GO_WANT_HELPER_PROCESS=1"}, os.Environ()...)
child.Stdout = &buf
@@ -1559,3 +1563,108 @@ func testFileServerMethods(t *testing.T, mode testMode) {
}
}
}
+
+func TestFileServerFS(t *testing.T) {
+ filename := "index.html"
+ contents := []byte("index.html says hello")
+ fsys := fstest.MapFS{
+ filename: {Data: contents},
+ }
+ ts := newClientServerTest(t, http1Mode, FileServerFS(fsys)).ts
+ defer ts.Close()
+
+ res, err := ts.Client().Get(ts.URL + "/" + filename)
+ if err != nil {
+ t.Fatal(err)
+ }
+ b, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal("reading Body:", err)
+ }
+ if s := string(b); s != string(contents) {
+ t.Errorf("for path %q got %q, want %q", filename, s, contents)
+ }
+ res.Body.Close()
+}
+
+func TestServeFileFS(t *testing.T) {
+ filename := "index.html"
+ contents := []byte("index.html says hello")
+ fsys := fstest.MapFS{
+ filename: {Data: contents},
+ }
+ ts := newClientServerTest(t, http1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ ServeFileFS(w, r, fsys, filename)
+ })).ts
+ defer ts.Close()
+
+ res, err := ts.Client().Get(ts.URL + "/" + filename)
+ if err != nil {
+ t.Fatal(err)
+ }
+ b, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal("reading Body:", err)
+ }
+ if s := string(b); s != string(contents) {
+ t.Errorf("for path %q got %q, want %q", filename, s, contents)
+ }
+ res.Body.Close()
+}
+
+func TestServeFileZippingResponseWriter(t *testing.T) {
+ // This test exercises a pattern which is incorrect,
+ // but has been observed enough in the world that we don't want to break it.
+ //
+ // The server is setting "Content-Encoding: gzip",
+ // wrapping the ResponseWriter in an implementation which gzips data written to it,
+ // and passing this ResponseWriter to ServeFile.
+ //
+ // This means ServeFile cannot properly set a Content-Length header, because it
+ // doesn't know what content it is going to send--the ResponseWriter is modifying
+ // the bytes sent.
+ //
+ // Range requests are always going to be broken in this scenario,
+ // but verify that we can serve non-range requests correctly.
+ filename := "index.html"
+ contents := []byte("contents will be sent with Content-Encoding: gzip")
+ fsys := fstest.MapFS{
+ filename: {Data: contents},
+ }
+ ts := newClientServerTest(t, http1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Encoding", "gzip")
+ gzw := gzip.NewWriter(w)
+ defer gzw.Close()
+ ServeFileFS(gzipResponseWriter{w: gzw, ResponseWriter: w}, r, fsys, filename)
+ })).ts
+ defer ts.Close()
+
+ res, err := ts.Client().Get(ts.URL + "/" + filename)
+ if err != nil {
+ t.Fatal(err)
+ }
+ b, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal("reading Body:", err)
+ }
+ if s := string(b); s != string(contents) {
+ t.Errorf("for path %q got %q, want %q", filename, s, contents)
+ }
+ res.Body.Close()
+}
+
+type gzipResponseWriter struct {
+ ResponseWriter
+ w *gzip.Writer
+}
+
+func (grw gzipResponseWriter) Write(b []byte) (int, error) {
+ return grw.w.Write(b)
+}
+
+func (grw gzipResponseWriter) Flush() {
+ grw.w.Flush()
+ if fw, ok := grw.ResponseWriter.(http.Flusher); ok {
+ fw.Flush()
+ }
+}
diff --git a/src/net/http/h2_bundle.go b/src/net/http/h2_bundle.go
index dd59e1f4f2..ac41144d5b 100644
--- a/src/net/http/h2_bundle.go
+++ b/src/net/http/h2_bundle.go
@@ -1,5 +1,4 @@
//go:build !nethttpomithttp2
-// +build !nethttpomithttp2
// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT.
// $ bundle -o=h2_bundle.go -prefix=http2 -tags=!nethttpomithttp2 golang.org/x/net/http2
@@ -33,6 +32,7 @@ import (
"io/fs"
"log"
"math"
+ "math/bits"
mathrand "math/rand"
"net"
"net/http/httptrace"
@@ -1041,41 +1041,44 @@ func http2shouldRetryDial(call *http2dialCall, req *Request) bool {
// TODO: Benchmark to determine if the pools are necessary. The GC may have
// improved enough that we can instead allocate chunks like this:
// make([]byte, max(16<<10, expectedBytesRemaining))
-var (
- http2dataChunkSizeClasses = []int{
- 1 << 10,
- 2 << 10,
- 4 << 10,
- 8 << 10,
- 16 << 10,
- }
- http2dataChunkPools = [...]sync.Pool{
- {New: func() interface{} { return make([]byte, 1<<10) }},
- {New: func() interface{} { return make([]byte, 2<<10) }},
- {New: func() interface{} { return make([]byte, 4<<10) }},
- {New: func() interface{} { return make([]byte, 8<<10) }},
- {New: func() interface{} { return make([]byte, 16<<10) }},
- }
-)
+var http2dataChunkPools = [...]sync.Pool{
+ {New: func() interface{} { return new([1 << 10]byte) }},
+ {New: func() interface{} { return new([2 << 10]byte) }},
+ {New: func() interface{} { return new([4 << 10]byte) }},
+ {New: func() interface{} { return new([8 << 10]byte) }},
+ {New: func() interface{} { return new([16 << 10]byte) }},
+}
func http2getDataBufferChunk(size int64) []byte {
- i := 0
- for ; i < len(http2dataChunkSizeClasses)-1; i++ {
- if size <= int64(http2dataChunkSizeClasses[i]) {
- break
- }
+ switch {
+ case size <= 1<<10:
+ return http2dataChunkPools[0].Get().(*[1 << 10]byte)[:]
+ case size <= 2<<10:
+ return http2dataChunkPools[1].Get().(*[2 << 10]byte)[:]
+ case size <= 4<<10:
+ return http2dataChunkPools[2].Get().(*[4 << 10]byte)[:]
+ case size <= 8<<10:
+ return http2dataChunkPools[3].Get().(*[8 << 10]byte)[:]
+ default:
+ return http2dataChunkPools[4].Get().(*[16 << 10]byte)[:]
}
- return http2dataChunkPools[i].Get().([]byte)
}
func http2putDataBufferChunk(p []byte) {
- for i, n := range http2dataChunkSizeClasses {
- if len(p) == n {
- http2dataChunkPools[i].Put(p)
- return
- }
+ switch len(p) {
+ case 1 << 10:
+ http2dataChunkPools[0].Put((*[1 << 10]byte)(p))
+ case 2 << 10:
+ http2dataChunkPools[1].Put((*[2 << 10]byte)(p))
+ case 4 << 10:
+ http2dataChunkPools[2].Put((*[4 << 10]byte)(p))
+ case 8 << 10:
+ http2dataChunkPools[3].Put((*[8 << 10]byte)(p))
+ case 16 << 10:
+ http2dataChunkPools[4].Put((*[16 << 10]byte)(p))
+ default:
+ panic(fmt.Sprintf("unexpected buffer len=%v", len(p)))
}
- panic(fmt.Sprintf("unexpected buffer len=%v", len(p)))
}
// dataBuffer is an io.ReadWriter backed by a list of data chunks.
@@ -3058,41 +3061,6 @@ func http2summarizeFrame(f http2Frame) string {
return buf.String()
}
-func http2traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool {
- return trace != nil && trace.WroteHeaderField != nil
-}
-
-func http2traceWroteHeaderField(trace *httptrace.ClientTrace, k, v string) {
- if trace != nil && trace.WroteHeaderField != nil {
- trace.WroteHeaderField(k, []string{v})
- }
-}
-
-func http2traceGot1xxResponseFunc(trace *httptrace.ClientTrace) func(int, textproto.MIMEHeader) error {
- if trace != nil {
- return trace.Got1xxResponse
- }
- return nil
-}
-
-// dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS
-// connection.
-func (t *http2Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (*tls.Conn, error) {
- dialer := &tls.Dialer{
- Config: cfg,
- }
- cn, err := dialer.DialContext(ctx, network, addr)
- if err != nil {
- return nil, err
- }
- tlsCn := cn.(*tls.Conn) // DialContext comment promises this will always succeed
- return tlsCn, nil
-}
-
-func http2tlsUnderlyingConn(tc *tls.Conn) net.Conn {
- return tc.NetConn()
-}
-
var http2DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1"
type http2goroutineLock uint64
@@ -4831,14 +4799,6 @@ func (sc *http2serverConn) serve() {
}
}
-func (sc *http2serverConn) awaitGracefulShutdown(sharedCh <-chan struct{}, privateCh chan struct{}) {
- select {
- case <-sc.doneServing:
- case <-sharedCh:
- close(privateCh)
- }
-}
-
type http2serverMessage int
// Message values sent to serveMsgCh.
@@ -5722,9 +5682,11 @@ func (st *http2stream) copyTrailersToHandlerRequest() {
// onReadTimeout is run on its own goroutine (from time.AfterFunc)
// when the stream's ReadTimeout has fired.
func (st *http2stream) onReadTimeout() {
- // Wrap the ErrDeadlineExceeded to avoid callers depending on us
- // returning the bare error.
- st.body.CloseWithError(fmt.Errorf("%w", os.ErrDeadlineExceeded))
+ if st.body != nil {
+ // Wrap the ErrDeadlineExceeded to avoid callers depending on us
+ // returning the bare error.
+ st.body.CloseWithError(fmt.Errorf("%w", os.ErrDeadlineExceeded))
+ }
}
// onWriteTimeout is run on its own goroutine (from time.AfterFunc)
@@ -5842,9 +5804,7 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error {
// (in Go 1.8), though. That's a more sane option anyway.
if sc.hs.ReadTimeout != 0 {
sc.conn.SetReadDeadline(time.Time{})
- if st.body != nil {
- st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
- }
+ st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
}
return sc.scheduleHandler(id, rw, req, handler)
@@ -6374,7 +6334,6 @@ type http2responseWriterState struct {
wroteHeader bool // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet.
sentHeader bool // have we sent the header frame?
handlerDone bool // handler has finished
- dirty bool // a Write failed; don't reuse this responseWriterState
sentContentLen int64 // non-zero if handler set a Content-Length header
wroteBytes int64
@@ -6494,7 +6453,6 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) {
date: date,
})
if err != nil {
- rws.dirty = true
return 0, err
}
if endStream {
@@ -6515,7 +6473,6 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) {
if len(p) > 0 || endStream {
// only send a 0 byte DATA frame if we're ending the stream.
if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil {
- rws.dirty = true
return 0, err
}
}
@@ -6527,9 +6484,6 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) {
trailers: rws.trailers,
endStream: true,
})
- if err != nil {
- rws.dirty = true
- }
return len(p), err
}
return len(p), nil
@@ -6745,14 +6699,12 @@ func (rws *http2responseWriterState) writeHeader(code int) {
h.Del("Transfer-Encoding")
}
- if rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{
+ rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{
streamID: rws.stream.id,
httpResCode: code,
h: h,
endStream: rws.handlerDone && !rws.hasTrailers(),
- }) != nil {
- rws.dirty = true
- }
+ })
return
}
@@ -6817,19 +6769,10 @@ func (w *http2responseWriter) write(lenData int, dataB []byte, dataS string) (n
func (w *http2responseWriter) handlerDone() {
rws := w.rws
- dirty := rws.dirty
rws.handlerDone = true
w.Flush()
w.rws = nil
- if !dirty {
- // Only recycle the pool if all prior Write calls to
- // the serverConn goroutine completed successfully. If
- // they returned earlier due to resets from the peer
- // there might still be write goroutines outstanding
- // from the serverConn referencing the rws memory. See
- // issue 20704.
- http2responseWriterStatePool.Put(rws)
- }
+ http2responseWriterStatePool.Put(rws)
}
// Push errors.
@@ -7374,8 +7317,7 @@ func (t *http2Transport) initConnPool() {
// HTTP/2 server.
type http2ClientConn struct {
t *http2Transport
- tconn net.Conn // usually *tls.Conn, except specialized impls
- tconnClosed bool
+ tconn net.Conn // usually *tls.Conn, except specialized impls
tlsState *tls.ConnectionState // nil only for specialized impls
reused uint32 // whether conn is being reused; atomic
singleUse bool // whether being used for a single http.Request
@@ -8103,7 +8045,7 @@ func (cc *http2ClientConn) forceCloseConn() {
if !ok {
return
}
- if nc := http2tlsUnderlyingConn(tc); nc != nil {
+ if nc := tc.NetConn(); nc != nil {
nc.Close()
}
}
@@ -8765,7 +8707,28 @@ func (cs *http2clientStream) frameScratchBufferLen(maxFrameSize int) int {
return int(n) // doesn't truncate; max is 512K
}
-var http2bufPool sync.Pool // of *[]byte
+// Seven bufPools manage different frame sizes. This helps to avoid scenarios where long-running
+// streaming requests using small frame sizes occupy large buffers initially allocated for prior
+// requests needing big buffers. The size ranges are as follows:
+// {0 KB, 16 KB], {16 KB, 32 KB], {32 KB, 64 KB], {64 KB, 128 KB], {128 KB, 256 KB],
+// {256 KB, 512 KB], {512 KB, infinity}
+// In practice, the maximum scratch buffer size should not exceed 512 KB due to
+// frameScratchBufferLen(maxFrameSize), thus the "infinity pool" should never be used.
+// It exists mainly as a safety measure, for potential future increases in max buffer size.
+var http2bufPools [7]sync.Pool // of *[]byte
+
+func http2bufPoolIndex(size int) int {
+ if size <= 16384 {
+ return 0
+ }
+ size -= 1
+ bits := bits.Len(uint(size))
+ index := bits - 14
+ if index >= len(http2bufPools) {
+ return len(http2bufPools) - 1
+ }
+ return index
+}
func (cs *http2clientStream) writeRequestBody(req *Request) (err error) {
cc := cs.cc
@@ -8783,12 +8746,13 @@ func (cs *http2clientStream) writeRequestBody(req *Request) (err error) {
// Scratch buffer for reading into & writing from.
scratchLen := cs.frameScratchBufferLen(maxFrameSize)
var buf []byte
- if bp, ok := http2bufPool.Get().(*[]byte); ok && len(*bp) >= scratchLen {
- defer http2bufPool.Put(bp)
+ index := http2bufPoolIndex(scratchLen)
+ if bp, ok := http2bufPools[index].Get().(*[]byte); ok && len(*bp) >= scratchLen {
+ defer http2bufPools[index].Put(bp)
buf = *bp
} else {
buf = make([]byte, scratchLen)
- defer http2bufPool.Put(&buf)
+ defer http2bufPools[index].Put(&buf)
}
var sawEOF bool
@@ -10269,6 +10233,37 @@ func http2traceFirstResponseByte(trace *httptrace.ClientTrace) {
}
}
+func http2traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool {
+ return trace != nil && trace.WroteHeaderField != nil
+}
+
+func http2traceWroteHeaderField(trace *httptrace.ClientTrace, k, v string) {
+ if trace != nil && trace.WroteHeaderField != nil {
+ trace.WroteHeaderField(k, []string{v})
+ }
+}
+
+func http2traceGot1xxResponseFunc(trace *httptrace.ClientTrace) func(int, textproto.MIMEHeader) error {
+ if trace != nil {
+ return trace.Got1xxResponse
+ }
+ return nil
+}
+
+// dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS
+// connection.
+func (t *http2Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (*tls.Conn, error) {
+ dialer := &tls.Dialer{
+ Config: cfg,
+ }
+ cn, err := dialer.DialContext(ctx, network, addr)
+ if err != nil {
+ return nil, err
+ }
+ tlsCn := cn.(*tls.Conn) // DialContext comment promises this will always succeed
+ return tlsCn, nil
+}
+
// writeFramer is implemented by any type that is used to write frames.
type http2writeFramer interface {
writeFrame(http2writeContext) error
diff --git a/src/net/http/h2_error.go b/src/net/http/h2_error.go
index 0391d31e5b..2c0b21ec07 100644
--- a/src/net/http/h2_error.go
+++ b/src/net/http/h2_error.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !nethttpomithttp2
-// +build !nethttpomithttp2
package http
diff --git a/src/net/http/h2_error_test.go b/src/net/http/h2_error_test.go
index 0d85e2f36c..5e400683b4 100644
--- a/src/net/http/h2_error_test.go
+++ b/src/net/http/h2_error_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !nethttpomithttp2
-// +build !nethttpomithttp2
package http
diff --git a/src/net/http/header.go b/src/net/http/header.go
index e0b342c63c..9d0f3a125d 100644
--- a/src/net/http/header.go
+++ b/src/net/http/header.go
@@ -20,13 +20,13 @@ import (
// A Header represents the key-value pairs in an HTTP header.
//
// The keys should be in canonical form, as returned by
-// CanonicalHeaderKey.
+// [CanonicalHeaderKey].
type Header map[string][]string
// Add adds the key, value pair to the header.
// It appends to any existing values associated with key.
// The key is case insensitive; it is canonicalized by
-// CanonicalHeaderKey.
+// [CanonicalHeaderKey].
func (h Header) Add(key, value string) {
textproto.MIMEHeader(h).Add(key, value)
}
@@ -34,7 +34,7 @@ func (h Header) Add(key, value string) {
// Set sets the header entries associated with key to the
// single element value. It replaces any existing values
// associated with key. The key is case insensitive; it is
-// canonicalized by textproto.CanonicalMIMEHeaderKey.
+// canonicalized by [textproto.CanonicalMIMEHeaderKey].
// To use non-canonical keys, assign to the map directly.
func (h Header) Set(key, value string) {
textproto.MIMEHeader(h).Set(key, value)
@@ -42,7 +42,7 @@ func (h Header) Set(key, value string) {
// Get gets the first value associated with the given key. If
// there are no values associated with the key, Get returns "".
-// It is case insensitive; textproto.CanonicalMIMEHeaderKey is
+// It is case insensitive; [textproto.CanonicalMIMEHeaderKey] is
// used to canonicalize the provided key. Get assumes that all
// keys are stored in canonical form. To use non-canonical keys,
// access the map directly.
@@ -51,7 +51,7 @@ func (h Header) Get(key string) string {
}
// Values returns all values associated with the given key.
-// It is case insensitive; textproto.CanonicalMIMEHeaderKey is
+// It is case insensitive; [textproto.CanonicalMIMEHeaderKey] is
// used to canonicalize the provided key. To use non-canonical
// keys, access the map directly.
// The returned slice is not a copy.
@@ -76,7 +76,7 @@ func (h Header) has(key string) bool {
// Del deletes the values associated with key.
// The key is case insensitive; it is canonicalized by
-// CanonicalHeaderKey.
+// [CanonicalHeaderKey].
func (h Header) Del(key string) {
textproto.MIMEHeader(h).Del(key)
}
@@ -125,7 +125,7 @@ var timeFormats = []string{
// ParseTime parses a time header (such as the Date: header),
// trying each of the three formats allowed by HTTP/1.1:
-// TimeFormat, time.RFC850, and time.ANSIC.
+// [TimeFormat], [time.RFC850], and [time.ANSIC].
func ParseTime(text string) (t time.Time, err error) {
for _, layout := range timeFormats {
t, err = time.Parse(layout, text)
diff --git a/src/net/http/http.go b/src/net/http/http.go
index 9b81654fcc..6e2259adbf 100644
--- a/src/net/http/http.go
+++ b/src/net/http/http.go
@@ -103,10 +103,10 @@ func hexEscapeNonASCII(s string) string {
return string(b)
}
-// NoBody is an io.ReadCloser with no bytes. Read always returns EOF
+// NoBody is an [io.ReadCloser] with no bytes. Read always returns EOF
// and Close always returns nil. It can be used in an outgoing client
// request to explicitly signal that a request has zero bytes.
-// An alternative, however, is to simply set Request.Body to nil.
+// An alternative, however, is to simply set [Request.Body] to nil.
var NoBody = noBody{}
type noBody struct{}
@@ -121,7 +121,7 @@ var (
_ io.ReadCloser = NoBody
)
-// PushOptions describes options for Pusher.Push.
+// PushOptions describes options for [Pusher.Push].
type PushOptions struct {
// Method specifies the HTTP method for the promised request.
// If set, it must be "GET" or "HEAD". Empty means "GET".
diff --git a/src/net/http/http_test.go b/src/net/http/http_test.go
index 91bb1b2620..2e7e024e20 100644
--- a/src/net/http/http_test.go
+++ b/src/net/http/http_test.go
@@ -12,7 +12,6 @@ import (
"io/fs"
"net/url"
"os"
- "os/exec"
"reflect"
"regexp"
"strings"
@@ -55,7 +54,7 @@ func TestForeachHeaderElement(t *testing.T) {
func TestCmdGoNoHTTPServer(t *testing.T) {
t.Parallel()
goBin := testenv.GoToolPath(t)
- out, err := exec.Command(goBin, "tool", "nm", goBin).CombinedOutput()
+ out, err := testenv.Command(t, goBin, "tool", "nm", goBin).CombinedOutput()
if err != nil {
t.Fatalf("go tool nm: %v: %s", err, out)
}
@@ -89,7 +88,7 @@ func TestOmitHTTP2(t *testing.T) {
}
t.Parallel()
goTool := testenv.GoToolPath(t)
- out, err := exec.Command(goTool, "test", "-short", "-tags=nethttpomithttp2", "net/http").CombinedOutput()
+ out, err := testenv.Command(t, goTool, "test", "-short", "-tags=nethttpomithttp2", "net/http").CombinedOutput()
if err != nil {
t.Fatalf("go test -short failed: %v, %s", err, out)
}
@@ -101,7 +100,7 @@ func TestOmitHTTP2(t *testing.T) {
func TestOmitHTTP2Vet(t *testing.T) {
t.Parallel()
goTool := testenv.GoToolPath(t)
- out, err := exec.Command(goTool, "vet", "-tags=nethttpomithttp2", "net/http").CombinedOutput()
+ out, err := testenv.Command(t, goTool, "vet", "-tags=nethttpomithttp2", "net/http").CombinedOutput()
if err != nil {
t.Fatalf("go vet failed: %v, %s", err, out)
}
diff --git a/src/net/http/httptest/httptest.go b/src/net/http/httptest/httptest.go
index 9bedefd2bc..f0ca64362d 100644
--- a/src/net/http/httptest/httptest.go
+++ b/src/net/http/httptest/httptest.go
@@ -15,7 +15,7 @@ import (
)
// NewRequest returns a new incoming server Request, suitable
-// for passing to an http.Handler for testing.
+// for passing to an [http.Handler] for testing.
//
// The target is the RFC 7230 "request-target": it may be either a
// path or an absolute URL. If target is an absolute URL, the host name
diff --git a/src/net/http/httptest/recorder.go b/src/net/http/httptest/recorder.go
index 1c1d880155..dd51901b0d 100644
--- a/src/net/http/httptest/recorder.go
+++ b/src/net/http/httptest/recorder.go
@@ -16,7 +16,7 @@ import (
"golang.org/x/net/http/httpguts"
)
-// ResponseRecorder is an implementation of http.ResponseWriter that
+// ResponseRecorder is an implementation of [http.ResponseWriter] that
// records its mutations for later inspection in tests.
type ResponseRecorder struct {
// Code is the HTTP response code set by WriteHeader.
@@ -47,7 +47,7 @@ type ResponseRecorder struct {
wroteHeader bool
}
-// NewRecorder returns an initialized ResponseRecorder.
+// NewRecorder returns an initialized [ResponseRecorder].
func NewRecorder() *ResponseRecorder {
return &ResponseRecorder{
HeaderMap: make(http.Header),
@@ -57,12 +57,12 @@ func NewRecorder() *ResponseRecorder {
}
// DefaultRemoteAddr is the default remote address to return in RemoteAddr if
-// an explicit DefaultRemoteAddr isn't set on ResponseRecorder.
+// an explicit DefaultRemoteAddr isn't set on [ResponseRecorder].
const DefaultRemoteAddr = "1.2.3.4"
-// Header implements http.ResponseWriter. It returns the response
+// Header implements [http.ResponseWriter]. It returns the response
// headers to mutate within a handler. To test the headers that were
-// written after a handler completes, use the Result method and see
+// written after a handler completes, use the [ResponseRecorder.Result] method and see
// the returned Response value's Header.
func (rw *ResponseRecorder) Header() http.Header {
m := rw.HeaderMap
@@ -112,7 +112,7 @@ func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
return len(buf), nil
}
-// WriteString implements io.StringWriter. The data in str is written
+// WriteString implements [io.StringWriter]. The data in str is written
// to rw.Body, if not nil.
func (rw *ResponseRecorder) WriteString(str string) (int, error) {
rw.writeHeader(nil, str)
@@ -139,7 +139,7 @@ func checkWriteHeaderCode(code int) {
}
}
-// WriteHeader implements http.ResponseWriter.
+// WriteHeader implements [http.ResponseWriter].
func (rw *ResponseRecorder) WriteHeader(code int) {
if rw.wroteHeader {
return
@@ -154,7 +154,7 @@ func (rw *ResponseRecorder) WriteHeader(code int) {
rw.snapHeader = rw.HeaderMap.Clone()
}
-// Flush implements http.Flusher. To test whether Flush was
+// Flush implements [http.Flusher]. To test whether Flush was
// called, see rw.Flushed.
func (rw *ResponseRecorder) Flush() {
if !rw.wroteHeader {
@@ -175,7 +175,7 @@ func (rw *ResponseRecorder) Flush() {
// did a write.
//
// The Response.Body is guaranteed to be non-nil and Body.Read call is
-// guaranteed to not return any error other than io.EOF.
+// guaranteed to not return any error other than [io.EOF].
//
// Result must only be called after the handler has finished running.
func (rw *ResponseRecorder) Result() *http.Response {
diff --git a/src/net/http/httptest/server.go b/src/net/http/httptest/server.go
index f254a494d1..5095b438ec 100644
--- a/src/net/http/httptest/server.go
+++ b/src/net/http/httptest/server.go
@@ -77,7 +77,7 @@ func newLocalListener() net.Listener {
// When debugging a particular http server-based test,
// this flag lets you run
//
-// go test -run=BrokenTest -httptest.serve=127.0.0.1:8000
+// go test -run='^BrokenTest$' -httptest.serve=127.0.0.1:8000
//
// to start the broken server so you can interact with it manually.
// We only register this flag if it looks like the caller knows about it
@@ -100,7 +100,7 @@ func strSliceContainsPrefix(v []string, pre string) bool {
return false
}
-// NewServer starts and returns a new Server.
+// NewServer starts and returns a new [Server].
// The caller should call Close when finished, to shut it down.
func NewServer(handler http.Handler) *Server {
ts := NewUnstartedServer(handler)
@@ -108,7 +108,7 @@ func NewServer(handler http.Handler) *Server {
return ts
}
-// NewUnstartedServer returns a new Server but doesn't start it.
+// NewUnstartedServer returns a new [Server] but doesn't start it.
//
// After changing its configuration, the caller should call Start or
// StartTLS.
@@ -144,7 +144,7 @@ func (s *Server) StartTLS() {
panic("Server already started")
}
if s.client == nil {
- s.client = &http.Client{Transport: &http.Transport{}}
+ s.client = &http.Client{}
}
cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
if err != nil {
@@ -185,7 +185,7 @@ func (s *Server) StartTLS() {
s.goServe()
}
-// NewTLSServer starts and returns a new Server using TLS.
+// NewTLSServer starts and returns a new [Server] using TLS.
// The caller should call Close when finished, to shut it down.
func NewTLSServer(handler http.Handler) *Server {
ts := NewUnstartedServer(handler)
@@ -298,7 +298,7 @@ func (s *Server) Certificate() *x509.Certificate {
// Client returns an HTTP client configured for making requests to the server.
// It is configured to trust the server's TLS test certificate and will
-// close its idle connections on Server.Close.
+// close its idle connections on [Server.Close].
func (s *Server) Client() *http.Client {
return s.client
}
diff --git a/src/net/http/httptrace/trace.go b/src/net/http/httptrace/trace.go
index 6af30f78d1..706a432957 100644
--- a/src/net/http/httptrace/trace.go
+++ b/src/net/http/httptrace/trace.go
@@ -19,7 +19,7 @@ import (
// unique type to prevent assignment.
type clientEventContextKey struct{}
-// ContextClientTrace returns the ClientTrace associated with the
+// ContextClientTrace returns the [ClientTrace] associated with the
// provided context. If none, it returns nil.
func ContextClientTrace(ctx context.Context) *ClientTrace {
trace, _ := ctx.Value(clientEventContextKey{}).(*ClientTrace)
@@ -233,7 +233,7 @@ func (t *ClientTrace) hasNetHooks() bool {
return t.DNSStart != nil || t.DNSDone != nil || t.ConnectStart != nil || t.ConnectDone != nil
}
-// GotConnInfo is the argument to the ClientTrace.GotConn function and
+// GotConnInfo is the argument to the [ClientTrace.GotConn] function and
// contains information about the obtained connection.
type GotConnInfo struct {
// Conn is the connection that was obtained. It is owned by
diff --git a/src/net/http/httputil/dump.go b/src/net/http/httputil/dump.go
index 7affe5e61a..2edb9bc98d 100644
--- a/src/net/http/httputil/dump.go
+++ b/src/net/http/httputil/dump.go
@@ -71,8 +71,8 @@ func outgoingLength(req *http.Request) int64 {
return -1
}
-// DumpRequestOut is like DumpRequest but for outgoing client requests. It
-// includes any headers that the standard http.Transport adds, such as
+// DumpRequestOut is like [DumpRequest] but for outgoing client requests. It
+// includes any headers that the standard [http.Transport] adds, such as
// User-Agent.
func DumpRequestOut(req *http.Request, body bool) ([]byte, error) {
save := req.Body
@@ -203,17 +203,17 @@ var reqWriteExcludeHeaderDump = map[string]bool{
// representation. It should only be used by servers to debug client
// requests. The returned representation is an approximation only;
// some details of the initial request are lost while parsing it into
-// an http.Request. In particular, the order and case of header field
+// an [http.Request]. In particular, the order and case of header field
// names are lost. The order of values in multi-valued headers is kept
// intact. HTTP/2 requests are dumped in HTTP/1.x form, not in their
// original binary representations.
//
// If body is true, DumpRequest also returns the body. To do so, it
-// consumes req.Body and then replaces it with a new io.ReadCloser
+// consumes req.Body and then replaces it with a new [io.ReadCloser]
// that yields the same bytes. If DumpRequest returns an error,
// the state of req is undefined.
//
-// The documentation for http.Request.Write details which fields
+// The documentation for [http.Request.Write] details which fields
// of req are included in the dump.
func DumpRequest(req *http.Request, body bool) ([]byte, error) {
var err error
diff --git a/src/net/http/httputil/httputil.go b/src/net/http/httputil/httputil.go
index 09ea74d6d1..431930ea65 100644
--- a/src/net/http/httputil/httputil.go
+++ b/src/net/http/httputil/httputil.go
@@ -13,7 +13,7 @@ import (
// NewChunkedReader returns a new chunkedReader that translates the data read from r
// out of HTTP "chunked" format before returning it.
-// The chunkedReader returns io.EOF when the final 0-length chunk is read.
+// The chunkedReader returns [io.EOF] when the final 0-length chunk is read.
//
// NewChunkedReader is not needed by normal applications. The http package
// automatically decodes chunking when reading response bodies.
diff --git a/src/net/http/httputil/persist.go b/src/net/http/httputil/persist.go
index 84b116df8c..0cbe3ebf10 100644
--- a/src/net/http/httputil/persist.go
+++ b/src/net/http/httputil/persist.go
@@ -33,7 +33,7 @@ var errClosed = errors.New("i/o operation on closed connection")
// It is low-level, old, and unused by Go's current HTTP stack.
// We should have deleted it before Go 1.
//
-// Deprecated: Use the Server in package net/http instead.
+// Deprecated: Use the Server in package [net/http] instead.
type ServerConn struct {
mu sync.Mutex // read-write protects the following fields
c net.Conn
@@ -50,7 +50,7 @@ type ServerConn struct {
// It is low-level, old, and unused by Go's current HTTP stack.
// We should have deleted it before Go 1.
//
-// Deprecated: Use the Server in package net/http instead.
+// Deprecated: Use the Server in package [net/http] instead.
func NewServerConn(c net.Conn, r *bufio.Reader) *ServerConn {
if r == nil {
r = bufio.NewReader(c)
@@ -58,10 +58,10 @@ func NewServerConn(c net.Conn, r *bufio.Reader) *ServerConn {
return &ServerConn{c: c, r: r, pipereq: make(map[*http.Request]uint)}
}
-// Hijack detaches the ServerConn and returns the underlying connection as well
+// Hijack detaches the [ServerConn] and returns the underlying connection as well
// as the read-side bufio which may have some left over data. Hijack may be
// called before Read has signaled the end of the keep-alive logic. The user
-// should not call Hijack while Read or Write is in progress.
+// should not call Hijack while [ServerConn.Read] or [ServerConn.Write] is in progress.
func (sc *ServerConn) Hijack() (net.Conn, *bufio.Reader) {
sc.mu.Lock()
defer sc.mu.Unlock()
@@ -72,7 +72,7 @@ func (sc *ServerConn) Hijack() (net.Conn, *bufio.Reader) {
return c, r
}
-// Close calls Hijack and then also closes the underlying connection.
+// Close calls [ServerConn.Hijack] and then also closes the underlying connection.
func (sc *ServerConn) Close() error {
c, _ := sc.Hijack()
if c != nil {
@@ -81,7 +81,7 @@ func (sc *ServerConn) Close() error {
return nil
}
-// Read returns the next request on the wire. An ErrPersistEOF is returned if
+// Read returns the next request on the wire. An [ErrPersistEOF] is returned if
// it is gracefully determined that there are no more requests (e.g. after the
// first request on an HTTP/1.0 connection, or after a Connection:close on a
// HTTP/1.1 connection).
@@ -171,7 +171,7 @@ func (sc *ServerConn) Pending() int {
// Write writes resp in response to req. To close the connection gracefully, set the
// Response.Close field to true. Write should be considered operational until
-// it returns an error, regardless of any errors returned on the Read side.
+// it returns an error, regardless of any errors returned on the [ServerConn.Read] side.
func (sc *ServerConn) Write(req *http.Request, resp *http.Response) error {
// Retrieve the pipeline ID of this request/response pair
@@ -226,7 +226,7 @@ func (sc *ServerConn) Write(req *http.Request, resp *http.Response) error {
// It is low-level, old, and unused by Go's current HTTP stack.
// We should have deleted it before Go 1.
//
-// Deprecated: Use Client or Transport in package net/http instead.
+// Deprecated: Use Client or Transport in package [net/http] instead.
type ClientConn struct {
mu sync.Mutex // read-write protects the following fields
c net.Conn
@@ -244,7 +244,7 @@ type ClientConn struct {
// It is low-level, old, and unused by Go's current HTTP stack.
// We should have deleted it before Go 1.
//
-// Deprecated: Use the Client or Transport in package net/http instead.
+// Deprecated: Use the Client or Transport in package [net/http] instead.
func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn {
if r == nil {
r = bufio.NewReader(c)
@@ -261,17 +261,17 @@ func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn {
// It is low-level, old, and unused by Go's current HTTP stack.
// We should have deleted it before Go 1.
//
-// Deprecated: Use the Client or Transport in package net/http instead.
+// Deprecated: Use the Client or Transport in package [net/http] instead.
func NewProxyClientConn(c net.Conn, r *bufio.Reader) *ClientConn {
cc := NewClientConn(c, r)
cc.writeReq = (*http.Request).WriteProxy
return cc
}
-// Hijack detaches the ClientConn and returns the underlying connection as well
+// Hijack detaches the [ClientConn] and returns the underlying connection as well
// as the read-side bufio which may have some left over data. Hijack may be
// called before the user or Read have signaled the end of the keep-alive
-// logic. The user should not call Hijack while Read or Write is in progress.
+// logic. The user should not call Hijack while [ClientConn.Read] or ClientConn.Write is in progress.
func (cc *ClientConn) Hijack() (c net.Conn, r *bufio.Reader) {
cc.mu.Lock()
defer cc.mu.Unlock()
@@ -282,7 +282,7 @@ func (cc *ClientConn) Hijack() (c net.Conn, r *bufio.Reader) {
return
}
-// Close calls Hijack and then also closes the underlying connection.
+// Close calls [ClientConn.Hijack] and then also closes the underlying connection.
func (cc *ClientConn) Close() error {
c, _ := cc.Hijack()
if c != nil {
@@ -291,7 +291,7 @@ func (cc *ClientConn) Close() error {
return nil
}
-// Write writes a request. An ErrPersistEOF error is returned if the connection
+// Write writes a request. An [ErrPersistEOF] error is returned if the connection
// has been closed in an HTTP keep-alive sense. If req.Close equals true, the
// keep-alive connection is logically closed after this request and the opposing
// server is informed. An ErrUnexpectedEOF indicates the remote closed the
@@ -357,9 +357,9 @@ func (cc *ClientConn) Pending() int {
}
// Read reads the next response from the wire. A valid response might be
-// returned together with an ErrPersistEOF, which means that the remote
+// returned together with an [ErrPersistEOF], which means that the remote
// requested that this be the last request serviced. Read can be called
-// concurrently with Write, but not with another Read.
+// concurrently with [ClientConn.Write], but not with another Read.
func (cc *ClientConn) Read(req *http.Request) (resp *http.Response, err error) {
// Retrieve the pipeline ID of this request/response pair
cc.mu.Lock()
diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go
index 2a76b0b8dc..5c70f0d27b 100644
--- a/src/net/http/httputil/reverseproxy.go
+++ b/src/net/http/httputil/reverseproxy.go
@@ -26,7 +26,7 @@ import (
"golang.org/x/net/http/httpguts"
)
-// A ProxyRequest contains a request to be rewritten by a ReverseProxy.
+// A ProxyRequest contains a request to be rewritten by a [ReverseProxy].
type ProxyRequest struct {
// In is the request received by the proxy.
// The Rewrite function must not modify In.
@@ -45,7 +45,7 @@ type ProxyRequest struct {
//
// SetURL rewrites the outbound Host header to match the target's host.
// To preserve the inbound request's Host header (the default behavior
-// of NewSingleHostReverseProxy):
+// of [NewSingleHostReverseProxy]):
//
// rewriteFunc := func(r *httputil.ProxyRequest) {
// r.SetURL(url)
@@ -68,7 +68,7 @@ func (r *ProxyRequest) SetURL(target *url.URL) {
// If the outbound request contains an existing X-Forwarded-For header,
// SetXForwarded appends the client IP address to it. To append to the
// inbound request's X-Forwarded-For header (the default behavior of
-// ReverseProxy when using a Director function), copy the header
+// [ReverseProxy] when using a Director function), copy the header
// from the inbound request before calling SetXForwarded:
//
// rewriteFunc := func(r *httputil.ProxyRequest) {
@@ -200,7 +200,7 @@ type ReverseProxy struct {
}
// A BufferPool is an interface for getting and returning temporary
-// byte slices for use by io.CopyBuffer.
+// byte slices for use by [io.CopyBuffer].
type BufferPool interface {
Get() []byte
Put([]byte)
@@ -239,7 +239,7 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) {
return a.Path + b.Path, apath + bpath
}
-// NewSingleHostReverseProxy returns a new ReverseProxy that routes
+// NewSingleHostReverseProxy returns a new [ReverseProxy] that routes
// URLs to the scheme, host, and base path provided in target. If the
// target's path is "/base" and the incoming request was for "/dir",
// the target request will be for /base/dir.
@@ -461,10 +461,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(code)
// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
- for k := range h {
- delete(h, k)
- }
-
+ clear(h)
return nil
},
}
diff --git a/src/net/http/internal/ascii/print.go b/src/net/http/internal/ascii/print.go
index 585e5baba4..98dbf4e3d2 100644
--- a/src/net/http/internal/ascii/print.go
+++ b/src/net/http/internal/ascii/print.go
@@ -9,7 +9,7 @@ import (
"unicode"
)
-// EqualFold is strings.EqualFold, ASCII only. It reports whether s and t
+// EqualFold is [strings.EqualFold], ASCII only. It reports whether s and t
// are equal, ASCII-case-insensitively.
func EqualFold(s, t string) bool {
if len(s) != len(t) {
diff --git a/src/net/http/internal/chunked.go b/src/net/http/internal/chunked.go
index 5a174415dc..196b5d8925 100644
--- a/src/net/http/internal/chunked.go
+++ b/src/net/http/internal/chunked.go
@@ -22,7 +22,7 @@ var ErrLineTooLong = errors.New("header line too long")
// NewChunkedReader returns a new chunkedReader that translates the data read from r
// out of HTTP "chunked" format before returning it.
-// The chunkedReader returns io.EOF when the final 0-length chunk is read.
+// The chunkedReader returns [io.EOF] when the final 0-length chunk is read.
//
// NewChunkedReader is not needed by normal applications. The http package
// automatically decodes chunking when reading response bodies.
@@ -39,7 +39,8 @@ type chunkedReader struct {
n uint64 // unread bytes in chunk
err error
buf [2]byte
- checkEnd bool // whether need to check for \r\n chunk footer
+ checkEnd bool // whether need to check for \r\n chunk footer
+ excess int64 // "excessive" chunk overhead, for malicious sender detection
}
func (cr *chunkedReader) beginChunk() {
@@ -49,10 +50,36 @@ func (cr *chunkedReader) beginChunk() {
if cr.err != nil {
return
}
+ cr.excess += int64(len(line)) + 2 // header, plus \r\n after the chunk data
+ line = trimTrailingWhitespace(line)
+ line, cr.err = removeChunkExtension(line)
+ if cr.err != nil {
+ return
+ }
cr.n, cr.err = parseHexUint(line)
if cr.err != nil {
return
}
+ // A sender who sends one byte per chunk will send 5 bytes of overhead
+ // for every byte of data. ("1\r\nX\r\n" to send "X".)
+ // We want to allow this, since streaming a byte at a time can be legitimate.
+ //
+ // A sender can use chunk extensions to add arbitrary amounts of additional
+ // data per byte read. ("1;very long extension\r\nX\r\n" to send "X".)
+ // We don't want to disallow extensions (although we discard them),
+ // but we also don't want to allow a sender to reduce the signal/noise ratio
+ // arbitrarily.
+ //
+ // We track the amount of excess overhead read,
+ // and produce an error if it grows too large.
+ //
+ // Currently, we say that we're willing to accept 16 bytes of overhead per chunk,
+ // plus twice the amount of real data in the chunk.
+ cr.excess -= 16 + (2 * int64(cr.n))
+ cr.excess = max(cr.excess, 0)
+ if cr.excess > 16*1024 {
+ cr.err = errors.New("chunked encoding contains too much non-data")
+ }
if cr.n == 0 {
cr.err = io.EOF
}
@@ -140,11 +167,6 @@ func readChunkLine(b *bufio.Reader) ([]byte, error) {
if len(p) >= maxLineLength {
return nil, ErrLineTooLong
}
- p = trimTrailingWhitespace(p)
- p, err = removeChunkExtension(p)
- if err != nil {
- return nil, err
- }
return p, nil
}
@@ -199,7 +221,7 @@ type chunkedWriter struct {
// Write the contents of data as one chunk to Wire.
// NOTE: Note that the corresponding chunk-writing procedure in Conn.Write has
-// a bug since it does not check for success of io.WriteString
+// a bug since it does not check for success of [io.WriteString]
func (cw *chunkedWriter) Write(data []byte) (n int, err error) {
// Don't send 0-length data. It looks like EOF for chunked encoding.
@@ -231,9 +253,9 @@ func (cw *chunkedWriter) Close() error {
return err
}
-// FlushAfterChunkWriter signals from the caller of NewChunkedWriter
+// FlushAfterChunkWriter signals from the caller of [NewChunkedWriter]
// that each chunk should be followed by a flush. It is used by the
-// http.Transport code to keep the buffering behavior for headers and
+// [net/http.Transport] code to keep the buffering behavior for headers and
// trailers, but flush out chunks aggressively in the middle for
// request bodies which may be generated slowly. See Issue 6574.
type FlushAfterChunkWriter struct {
@@ -241,6 +263,9 @@ type FlushAfterChunkWriter struct {
}
func parseHexUint(v []byte) (n uint64, err error) {
+ if len(v) == 0 {
+ return 0, errors.New("empty hex number for chunk length")
+ }
for i, b := range v {
switch {
case '0' <= b && b <= '9':
diff --git a/src/net/http/internal/chunked_test.go b/src/net/http/internal/chunked_test.go
index 5e29a786dd..af79711781 100644
--- a/src/net/http/internal/chunked_test.go
+++ b/src/net/http/internal/chunked_test.go
@@ -153,6 +153,7 @@ func TestParseHexUint(t *testing.T) {
{"00000000000000000", 0, "http chunk length too large"}, // could accept if we wanted
{"10000000000000000", 0, "http chunk length too large"},
{"00000000000000001", 0, "http chunk length too large"}, // could accept if we wanted
+ {"", 0, "empty hex number for chunk length"},
}
for i := uint64(0); i <= 1234; i++ {
tests = append(tests, testCase{in: fmt.Sprintf("%x", i), want: i})
@@ -239,3 +240,62 @@ func TestChunkEndReadError(t *testing.T) {
t.Errorf("expected %v, got %v", readErr, err)
}
}
+
+func TestChunkReaderTooMuchOverhead(t *testing.T) {
+ // If the sender is sending 100x as many chunk header bytes as chunk data,
+ // we should reject the stream at some point.
+ chunk := []byte("1;")
+ for i := 0; i < 100; i++ {
+ chunk = append(chunk, 'a') // chunk extension
+ }
+ chunk = append(chunk, "\r\nX\r\n"...)
+ const bodylen = 1 << 20
+ r := NewChunkedReader(&funcReader{f: func(i int) ([]byte, error) {
+ if i < bodylen {
+ return chunk, nil
+ }
+ return []byte("0\r\n"), nil
+ }})
+ _, err := io.ReadAll(r)
+ if err == nil {
+ t.Fatalf("successfully read body with excessive overhead; want error")
+ }
+}
+
+func TestChunkReaderByteAtATime(t *testing.T) {
+ // Sending one byte per chunk should not trip the excess-overhead detection.
+ const bodylen = 1 << 20
+ r := NewChunkedReader(&funcReader{f: func(i int) ([]byte, error) {
+ if i < bodylen {
+ return []byte("1\r\nX\r\n"), nil
+ }
+ return []byte("0\r\n"), nil
+ }})
+ got, err := io.ReadAll(r)
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+ if len(got) != bodylen {
+ t.Errorf("read %v bytes, want %v", len(got), bodylen)
+ }
+}
+
+type funcReader struct {
+ f func(iteration int) ([]byte, error)
+ i int
+ b []byte
+ err error
+}
+
+func (r *funcReader) Read(p []byte) (n int, err error) {
+ if len(r.b) == 0 && r.err == nil {
+ r.b, r.err = r.f(r.i)
+ r.i++
+ }
+ n = copy(p, r.b)
+ r.b = r.b[n:]
+ if len(r.b) > 0 {
+ return n, nil
+ }
+ return n, r.err
+}
diff --git a/src/net/http/mapping.go b/src/net/http/mapping.go
new file mode 100644
index 0000000000..87e6d5ae5d
--- /dev/null
+++ b/src/net/http/mapping.go
@@ -0,0 +1,78 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+// A mapping is a collection of key-value pairs where the keys are unique.
+// A zero mapping is empty and ready to use.
+// A mapping tries to pick a representation that makes [mapping.find] most efficient.
+type mapping[K comparable, V any] struct {
+ s []entry[K, V] // for few pairs
+ m map[K]V // for many pairs
+}
+
+type entry[K comparable, V any] struct {
+ key K
+ value V
+}
+
+// maxSlice is the maximum number of pairs for which a slice is used.
+// It is a variable for benchmarking.
+var maxSlice int = 8
+
+// add adds a key-value pair to the mapping.
+func (h *mapping[K, V]) add(k K, v V) {
+ if h.m == nil && len(h.s) < maxSlice {
+ h.s = append(h.s, entry[K, V]{k, v})
+ } else {
+ if h.m == nil {
+ h.m = map[K]V{}
+ for _, e := range h.s {
+ h.m[e.key] = e.value
+ }
+ h.s = nil
+ }
+ h.m[k] = v
+ }
+}
+
+// find returns the value corresponding to the given key.
+// The second return value is false if there is no value
+// with that key.
+func (h *mapping[K, V]) find(k K) (v V, found bool) {
+ if h == nil {
+ return v, false
+ }
+ if h.m != nil {
+ v, found = h.m[k]
+ return v, found
+ }
+ for _, e := range h.s {
+ if e.key == k {
+ return e.value, true
+ }
+ }
+ return v, false
+}
+
+// eachPair calls f for each pair in the mapping.
+// If f returns false, pairs returns immediately.
+func (h *mapping[K, V]) eachPair(f func(k K, v V) bool) {
+ if h == nil {
+ return
+ }
+ if h.m != nil {
+ for k, v := range h.m {
+ if !f(k, v) {
+ return
+ }
+ }
+ } else {
+ for _, e := range h.s {
+ if !f(e.key, e.value) {
+ return
+ }
+ }
+ }
+}
diff --git a/src/net/http/mapping_test.go b/src/net/http/mapping_test.go
new file mode 100644
index 0000000000..0aed9d9e31
--- /dev/null
+++ b/src/net/http/mapping_test.go
@@ -0,0 +1,154 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "cmp"
+ "fmt"
+ "slices"
+ "strconv"
+ "testing"
+)
+
+func TestMapping(t *testing.T) {
+ var m mapping[int, string]
+ for i := 0; i < maxSlice; i++ {
+ m.add(i, strconv.Itoa(i))
+ }
+ if m.m != nil {
+ t.Fatal("m.m != nil")
+ }
+ for i := 0; i < maxSlice; i++ {
+ g, _ := m.find(i)
+ w := strconv.Itoa(i)
+ if g != w {
+ t.Fatalf("%d: got %s, want %s", i, g, w)
+ }
+ }
+ m.add(4, "4")
+ if m.s != nil {
+ t.Fatal("m.s != nil")
+ }
+ if m.m == nil {
+ t.Fatal("m.m == nil")
+ }
+ g, _ := m.find(4)
+ if w := "4"; g != w {
+ t.Fatalf("got %s, want %s", g, w)
+ }
+}
+
+func TestMappingEachPair(t *testing.T) {
+ var m mapping[int, string]
+ var want []entry[int, string]
+ for i := 0; i < maxSlice*2; i++ {
+ v := strconv.Itoa(i)
+ m.add(i, v)
+ want = append(want, entry[int, string]{i, v})
+
+ }
+
+ var got []entry[int, string]
+ m.eachPair(func(k int, v string) bool {
+ got = append(got, entry[int, string]{k, v})
+ return true
+ })
+ slices.SortFunc(got, func(e1, e2 entry[int, string]) int {
+ return cmp.Compare(e1.key, e2.key)
+ })
+ if !slices.Equal(got, want) {
+ t.Errorf("got %v, want %v", got, want)
+ }
+}
+
+func BenchmarkFindChild(b *testing.B) {
+ key := "articles"
+ children := []string{
+ "*",
+ "cmd.html",
+ "code.html",
+ "contrib.html",
+ "contribute.html",
+ "debugging_with_gdb.html",
+ "docs.html",
+ "effective_go.html",
+ "files.log",
+ "gccgo_contribute.html",
+ "gccgo_install.html",
+ "go-logo-black.png",
+ "go-logo-blue.png",
+ "go-logo-white.png",
+ "go1.1.html",
+ "go1.2.html",
+ "go1.html",
+ "go1compat.html",
+ "go_faq.html",
+ "go_mem.html",
+ "go_spec.html",
+ "help.html",
+ "ie.css",
+ "install-source.html",
+ "install.html",
+ "logo-153x55.png",
+ "Makefile",
+ "root.html",
+ "share.png",
+ "sieve.gif",
+ "tos.html",
+ "articles",
+ }
+ if len(children) != 32 {
+ panic("bad len")
+ }
+ for _, n := range []int{2, 4, 8, 16, 32} {
+ list := children[:n]
+ b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) {
+
+ b.Run("rep=linear", func(b *testing.B) {
+ var entries []entry[string, any]
+ for _, c := range list {
+ entries = append(entries, entry[string, any]{c, nil})
+ }
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ findChildLinear(key, entries)
+ }
+ })
+ b.Run("rep=map", func(b *testing.B) {
+ m := map[string]any{}
+ for _, c := range list {
+ m[c] = nil
+ }
+ var x any
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ x = m[key]
+ }
+ _ = x
+ })
+ b.Run(fmt.Sprintf("rep=hybrid%d", maxSlice), func(b *testing.B) {
+ var h mapping[string, any]
+ for _, c := range list {
+ h.add(c, nil)
+ }
+ var x any
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ x, _ = h.find(key)
+ }
+ _ = x
+ })
+ })
+ }
+}
+
+func findChildLinear(key string, entries []entry[string, any]) any {
+ for _, e := range entries {
+ if key == e.key {
+ return e.value
+ }
+ }
+ return nil
+}
diff --git a/src/net/http/pattern.go b/src/net/http/pattern.go
new file mode 100644
index 0000000000..f6af19b0f4
--- /dev/null
+++ b/src/net/http/pattern.go
@@ -0,0 +1,529 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Patterns for ServeMux routing.
+
+package http
+
+import (
+ "errors"
+ "fmt"
+ "net/url"
+ "strings"
+ "unicode"
+)
+
+// A pattern is something that can be matched against an HTTP request.
+// It has an optional method, an optional host, and a path.
+type pattern struct {
+ str string // original string
+ method string
+ host string
+ // The representation of a path differs from the surface syntax, which
+ // simplifies most algorithms.
+ //
+ // Paths ending in '/' are represented with an anonymous "..." wildcard.
+ // For example, the path "a/" is represented as a literal segment "a" followed
+ // by a segment with multi==true.
+ //
+ // Paths ending in "{$}" are represented with the literal segment "/".
+ // For example, the path "a/{$}" is represented as a literal segment "a" followed
+ // by a literal segment "/".
+ segments []segment
+ loc string // source location of registering call, for helpful messages
+}
+
+func (p *pattern) String() string { return p.str }
+
+func (p *pattern) lastSegment() segment {
+ return p.segments[len(p.segments)-1]
+}
+
+// A segment is a pattern piece that matches one or more path segments, or
+// a trailing slash.
+//
+// If wild is false, it matches a literal segment, or, if s == "/", a trailing slash.
+// Examples:
+//
+// "a" => segment{s: "a"}
+// "/{$}" => segment{s: "/"}
+//
+// If wild is true and multi is false, it matches a single path segment.
+// Example:
+//
+// "{x}" => segment{s: "x", wild: true}
+//
+// If both wild and multi are true, it matches all remaining path segments.
+// Example:
+//
+// "{rest...}" => segment{s: "rest", wild: true, multi: true}
+type segment struct {
+ s string // literal or wildcard name or "/" for "/{$}".
+ wild bool
+ multi bool // "..." wildcard
+}
+
+// parsePattern parses a string into a Pattern.
+// The string's syntax is
+//
+// [METHOD] [HOST]/[PATH]
+//
+// where:
+// - METHOD is an HTTP method
+// - HOST is a hostname
+// - PATH consists of slash-separated segments, where each segment is either
+// a literal or a wildcard of the form "{name}", "{name...}", or "{$}".
+//
+// METHOD, HOST and PATH are all optional; that is, the string can be "/".
+// If METHOD is present, it must be followed by a single space.
+// Wildcard names must be valid Go identifiers.
+// The "{$}" and "{name...}" wildcard must occur at the end of PATH.
+// PATH may end with a '/'.
+// Wildcard names in a path must be distinct.
+func parsePattern(s string) (_ *pattern, err error) {
+ if len(s) == 0 {
+ return nil, errors.New("empty pattern")
+ }
+ off := 0 // offset into string
+ defer func() {
+ if err != nil {
+ err = fmt.Errorf("at offset %d: %w", off, err)
+ }
+ }()
+
+ method, rest, found := strings.Cut(s, " ")
+ if !found {
+ rest = method
+ method = ""
+ }
+ if method != "" && !validMethod(method) {
+ return nil, fmt.Errorf("invalid method %q", method)
+ }
+ p := &pattern{str: s, method: method}
+
+ if found {
+ off = len(method) + 1
+ }
+ i := strings.IndexByte(rest, '/')
+ if i < 0 {
+ return nil, errors.New("host/path missing /")
+ }
+ p.host = rest[:i]
+ rest = rest[i:]
+ if j := strings.IndexByte(p.host, '{'); j >= 0 {
+ off += j
+ return nil, errors.New("host contains '{' (missing initial '/'?)")
+ }
+ // At this point, rest is the path.
+ off += i
+
+ // An unclean path with a method that is not CONNECT can never match,
+ // because paths are cleaned before matching.
+ if method != "" && method != "CONNECT" && rest != cleanPath(rest) {
+ return nil, errors.New("non-CONNECT pattern with unclean path can never match")
+ }
+
+ seenNames := map[string]bool{} // remember wildcard names to catch dups
+ for len(rest) > 0 {
+ // Invariant: rest[0] == '/'.
+ rest = rest[1:]
+ off = len(s) - len(rest)
+ if len(rest) == 0 {
+ // Trailing slash.
+ p.segments = append(p.segments, segment{wild: true, multi: true})
+ break
+ }
+ i := strings.IndexByte(rest, '/')
+ if i < 0 {
+ i = len(rest)
+ }
+ var seg string
+ seg, rest = rest[:i], rest[i:]
+ if i := strings.IndexByte(seg, '{'); i < 0 {
+ // Literal.
+ seg = pathUnescape(seg)
+ p.segments = append(p.segments, segment{s: seg})
+ } else {
+ // Wildcard.
+ if i != 0 {
+ return nil, errors.New("bad wildcard segment (must start with '{')")
+ }
+ if seg[len(seg)-1] != '}' {
+ return nil, errors.New("bad wildcard segment (must end with '}')")
+ }
+ name := seg[1 : len(seg)-1]
+ if name == "$" {
+ if len(rest) != 0 {
+ return nil, errors.New("{$} not at end")
+ }
+ p.segments = append(p.segments, segment{s: "/"})
+ break
+ }
+ name, multi := strings.CutSuffix(name, "...")
+ if multi && len(rest) != 0 {
+ return nil, errors.New("{...} wildcard not at end")
+ }
+ if name == "" {
+ return nil, errors.New("empty wildcard")
+ }
+ if !isValidWildcardName(name) {
+ return nil, fmt.Errorf("bad wildcard name %q", name)
+ }
+ if seenNames[name] {
+ return nil, fmt.Errorf("duplicate wildcard name %q", name)
+ }
+ seenNames[name] = true
+ p.segments = append(p.segments, segment{s: name, wild: true, multi: multi})
+ }
+ }
+ return p, nil
+}
+
+func isValidWildcardName(s string) bool {
+ if s == "" {
+ return false
+ }
+ // Valid Go identifier.
+ for i, c := range s {
+ if !unicode.IsLetter(c) && c != '_' && (i == 0 || !unicode.IsDigit(c)) {
+ return false
+ }
+ }
+ return true
+}
+
+func pathUnescape(path string) string {
+ u, err := url.PathUnescape(path)
+ if err != nil {
+ // Invalidly escaped path; use the original
+ return path
+ }
+ return u
+}
+
+// relationship is a relationship between two patterns, p1 and p2.
+type relationship string
+
+const (
+ equivalent relationship = "equivalent" // both match the same requests
+ moreGeneral relationship = "moreGeneral" // p1 matches everything p2 does & more
+ moreSpecific relationship = "moreSpecific" // p2 matches everything p1 does & more
+ disjoint relationship = "disjoint" // there is no request that both match
+ overlaps relationship = "overlaps" // there is a request that both match, but neither is more specific
+)
+
+// conflictsWith reports whether p1 conflicts with p2, that is, whether
+// there is a request that both match but where neither is higher precedence
+// than the other.
+//
+// Precedence is defined by two rules:
+// 1. Patterns with a host win over patterns without a host.
+// 2. Patterns whose method and path is more specific win. One pattern is more
+// specific than another if the second matches all the (method, path) pairs
+// of the first and more.
+//
+// If rule 1 doesn't apply, then two patterns conflict if their relationship
+// is either equivalence (they match the same set of requests) or overlap
+// (they both match some requests, but neither is more specific than the other).
+func (p1 *pattern) conflictsWith(p2 *pattern) bool {
+ if p1.host != p2.host {
+ // Either one host is empty and the other isn't, in which case the
+ // one with the host wins by rule 1, or neither host is empty
+ // and they differ, so they won't match the same paths.
+ return false
+ }
+ rel := p1.comparePathsAndMethods(p2)
+ return rel == equivalent || rel == overlaps
+}
+
+func (p1 *pattern) comparePathsAndMethods(p2 *pattern) relationship {
+ mrel := p1.compareMethods(p2)
+ // Optimization: avoid a call to comparePaths.
+ if mrel == disjoint {
+ return disjoint
+ }
+ prel := p1.comparePaths(p2)
+ return combineRelationships(mrel, prel)
+}
+
+// compareMethods determines the relationship between the method
+// part of patterns p1 and p2.
+//
+// A method can either be empty, "GET", or something else.
+// The empty string matches any method, so it is the most general.
+// "GET" matches both GET and HEAD.
+// Anything else matches only itself.
+func (p1 *pattern) compareMethods(p2 *pattern) relationship {
+ if p1.method == p2.method {
+ return equivalent
+ }
+ if p1.method == "" {
+ // p1 matches any method, but p2 does not, so p1 is more general.
+ return moreGeneral
+ }
+ if p2.method == "" {
+ return moreSpecific
+ }
+ if p1.method == "GET" && p2.method == "HEAD" {
+ // p1 matches GET and HEAD; p2 matches only HEAD.
+ return moreGeneral
+ }
+ if p2.method == "GET" && p1.method == "HEAD" {
+ return moreSpecific
+ }
+ return disjoint
+}
+
+// comparePaths determines the relationship between the path
+// part of two patterns.
+func (p1 *pattern) comparePaths(p2 *pattern) relationship {
+ // Optimization: if a path pattern doesn't end in a multi ("...") wildcard, then it
+ // can only match paths with the same number of segments.
+ if len(p1.segments) != len(p2.segments) && !p1.lastSegment().multi && !p2.lastSegment().multi {
+ return disjoint
+ }
+
+ // Consider corresponding segments in the two path patterns.
+ var segs1, segs2 []segment
+ rel := equivalent
+ for segs1, segs2 = p1.segments, p2.segments; len(segs1) > 0 && len(segs2) > 0; segs1, segs2 = segs1[1:], segs2[1:] {
+ rel = combineRelationships(rel, compareSegments(segs1[0], segs2[0]))
+ if rel == disjoint {
+ return rel
+ }
+ }
+ // We've reached the end of the corresponding segments of the patterns.
+ // If they have the same number of segments, then we've already determined
+ // their relationship.
+ if len(segs1) == 0 && len(segs2) == 0 {
+ return rel
+ }
+ // Otherwise, the only way they could fail to be disjoint is if the shorter
+ // pattern ends in a multi. In that case, that multi is more general
+ // than the remainder of the longer pattern, so combine those two relationships.
+ if len(segs1) < len(segs2) && p1.lastSegment().multi {
+ return combineRelationships(rel, moreGeneral)
+ }
+ if len(segs2) < len(segs1) && p2.lastSegment().multi {
+ return combineRelationships(rel, moreSpecific)
+ }
+ return disjoint
+}
+
+// compareSegments determines the relationship between two segments.
+func compareSegments(s1, s2 segment) relationship {
+ if s1.multi && s2.multi {
+ return equivalent
+ }
+ if s1.multi {
+ return moreGeneral
+ }
+ if s2.multi {
+ return moreSpecific
+ }
+ if s1.wild && s2.wild {
+ return equivalent
+ }
+ if s1.wild {
+ if s2.s == "/" {
+ // A single wildcard doesn't match a trailing slash.
+ return disjoint
+ }
+ return moreGeneral
+ }
+ if s2.wild {
+ if s1.s == "/" {
+ return disjoint
+ }
+ return moreSpecific
+ }
+ // Both literals.
+ if s1.s == s2.s {
+ return equivalent
+ }
+ return disjoint
+}
+
+// combineRelationships determines the overall relationship of two patterns
+// given the relationships of a partition of the patterns into two parts.
+//
+// For example, if p1 is more general than p2 in one way but equivalent
+// in the other, then it is more general overall.
+//
+// Or if p1 is more general in one way and more specific in the other, then
+// they overlap.
+func combineRelationships(r1, r2 relationship) relationship {
+ switch r1 {
+ case equivalent:
+ return r2
+ case disjoint:
+ return disjoint
+ case overlaps:
+ if r2 == disjoint {
+ return disjoint
+ }
+ return overlaps
+ case moreGeneral, moreSpecific:
+ switch r2 {
+ case equivalent:
+ return r1
+ case inverseRelationship(r1):
+ return overlaps
+ default:
+ return r2
+ }
+ default:
+ panic(fmt.Sprintf("unknown relationship %q", r1))
+ }
+}
+
+// If p1 has relationship `r` to p2, then
+// p2 has inverseRelationship(r) to p1.
+func inverseRelationship(r relationship) relationship {
+ switch r {
+ case moreSpecific:
+ return moreGeneral
+ case moreGeneral:
+ return moreSpecific
+ default:
+ return r
+ }
+}
+
+// isLitOrSingle reports whether the segment is a non-dollar literal or a single wildcard.
+func isLitOrSingle(seg segment) bool {
+ if seg.wild {
+ return !seg.multi
+ }
+ return seg.s != "/"
+}
+
+// describeConflict returns an explanation of why two patterns conflict.
+func describeConflict(p1, p2 *pattern) string {
+ mrel := p1.compareMethods(p2)
+ prel := p1.comparePaths(p2)
+ rel := combineRelationships(mrel, prel)
+ if rel == equivalent {
+ return fmt.Sprintf("%s matches the same requests as %s", p1, p2)
+ }
+ if rel != overlaps {
+ panic("describeConflict called with non-conflicting patterns")
+ }
+ if prel == overlaps {
+ return fmt.Sprintf(`%[1]s and %[2]s both match some paths, like %[3]q.
+But neither is more specific than the other.
+%[1]s matches %[4]q, but %[2]s doesn't.
+%[2]s matches %[5]q, but %[1]s doesn't.`,
+ p1, p2, commonPath(p1, p2), differencePath(p1, p2), differencePath(p2, p1))
+ }
+ if mrel == moreGeneral && prel == moreSpecific {
+ return fmt.Sprintf("%s matches more methods than %s, but has a more specific path pattern", p1, p2)
+ }
+ if mrel == moreSpecific && prel == moreGeneral {
+ return fmt.Sprintf("%s matches fewer methods than %s, but has a more general path pattern", p1, p2)
+ }
+ return fmt.Sprintf("bug: unexpected way for two patterns %s and %s to conflict: methods %s, paths %s", p1, p2, mrel, prel)
+}
+
+// writeMatchingPath writes to b a path that matches the segments.
+func writeMatchingPath(b *strings.Builder, segs []segment) {
+ for _, s := range segs {
+ writeSegment(b, s)
+ }
+}
+
+func writeSegment(b *strings.Builder, s segment) {
+ b.WriteByte('/')
+ if !s.multi && s.s != "/" {
+ b.WriteString(s.s)
+ }
+}
+
+// commonPath returns a path that both p1 and p2 match.
+// It assumes there is such a path.
+func commonPath(p1, p2 *pattern) string {
+ var b strings.Builder
+ var segs1, segs2 []segment
+ for segs1, segs2 = p1.segments, p2.segments; len(segs1) > 0 && len(segs2) > 0; segs1, segs2 = segs1[1:], segs2[1:] {
+ if s1 := segs1[0]; s1.wild {
+ writeSegment(&b, segs2[0])
+ } else {
+ writeSegment(&b, s1)
+ }
+ }
+ if len(segs1) > 0 {
+ writeMatchingPath(&b, segs1)
+ } else if len(segs2) > 0 {
+ writeMatchingPath(&b, segs2)
+ }
+ return b.String()
+}
+
+// differencePath returns a path that p1 matches and p2 doesn't.
+// It assumes there is such a path.
+func differencePath(p1, p2 *pattern) string {
+ var b strings.Builder
+
+ var segs1, segs2 []segment
+ for segs1, segs2 = p1.segments, p2.segments; len(segs1) > 0 && len(segs2) > 0; segs1, segs2 = segs1[1:], segs2[1:] {
+ s1 := segs1[0]
+ s2 := segs2[0]
+ if s1.multi && s2.multi {
+ // From here the patterns match the same paths, so we must have found a difference earlier.
+ b.WriteByte('/')
+ return b.String()
+
+ }
+ if s1.multi && !s2.multi {
+ // s1 ends in a "..." wildcard but s2 does not.
+ // A trailing slash will distinguish them, unless s2 ends in "{$}",
+ // in which case any segment will do; prefer the wildcard name if
+ // it has one.
+ b.WriteByte('/')
+ if s2.s == "/" {
+ if s1.s != "" {
+ b.WriteString(s1.s)
+ } else {
+ b.WriteString("x")
+ }
+ }
+ return b.String()
+ }
+ if !s1.multi && s2.multi {
+ writeSegment(&b, s1)
+ } else if s1.wild && s2.wild {
+ // Both patterns will match whatever we put here; use
+ // the first wildcard name.
+ writeSegment(&b, s1)
+ } else if s1.wild && !s2.wild {
+ // s1 is a wildcard, s2 is a literal.
+ // Any segment other than s2.s will work.
+ // Prefer the wildcard name, but if it's the same as the literal,
+ // tweak the literal.
+ if s1.s != s2.s {
+ writeSegment(&b, s1)
+ } else {
+ b.WriteByte('/')
+ b.WriteString(s2.s + "x")
+ }
+ } else if !s1.wild && s2.wild {
+ writeSegment(&b, s1)
+ } else {
+ // Both are literals. A precondition of this function is that the
+ // patterns overlap, so they must be the same literal. Use it.
+ if s1.s != s2.s {
+ panic(fmt.Sprintf("literals differ: %q and %q", s1.s, s2.s))
+ }
+ writeSegment(&b, s1)
+ }
+ }
+ if len(segs1) > 0 {
+ // p1 is longer than p2, and p2 does not end in a multi.
+ // Anything that matches the rest of p1 will do.
+ writeMatchingPath(&b, segs1)
+ } else if len(segs2) > 0 {
+ writeMatchingPath(&b, segs2)
+ }
+ return b.String()
+}
diff --git a/src/net/http/pattern_test.go b/src/net/http/pattern_test.go
new file mode 100644
index 0000000000..f0c84d243e
--- /dev/null
+++ b/src/net/http/pattern_test.go
@@ -0,0 +1,494 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "slices"
+ "strings"
+ "testing"
+)
+
+func TestParsePattern(t *testing.T) {
+ lit := func(name string) segment {
+ return segment{s: name}
+ }
+
+ wild := func(name string) segment {
+ return segment{s: name, wild: true}
+ }
+
+ multi := func(name string) segment {
+ s := wild(name)
+ s.multi = true
+ return s
+ }
+
+ for _, test := range []struct {
+ in string
+ want pattern
+ }{
+ {"/", pattern{segments: []segment{multi("")}}},
+ {"/a", pattern{segments: []segment{lit("a")}}},
+ {
+ "/a/",
+ pattern{segments: []segment{lit("a"), multi("")}},
+ },
+ {"/path/to/something", pattern{segments: []segment{
+ lit("path"), lit("to"), lit("something"),
+ }}},
+ {
+ "/{w1}/lit/{w2}",
+ pattern{
+ segments: []segment{wild("w1"), lit("lit"), wild("w2")},
+ },
+ },
+ {
+ "/{w1}/lit/{w2}/",
+ pattern{
+ segments: []segment{wild("w1"), lit("lit"), wild("w2"), multi("")},
+ },
+ },
+ {
+ "example.com/",
+ pattern{host: "example.com", segments: []segment{multi("")}},
+ },
+ {
+ "GET /",
+ pattern{method: "GET", segments: []segment{multi("")}},
+ },
+ {
+ "POST example.com/foo/{w}",
+ pattern{
+ method: "POST",
+ host: "example.com",
+ segments: []segment{lit("foo"), wild("w")},
+ },
+ },
+ {
+ "/{$}",
+ pattern{segments: []segment{lit("/")}},
+ },
+ {
+ "DELETE example.com/a/{foo12}/{$}",
+ pattern{method: "DELETE", host: "example.com", segments: []segment{lit("a"), wild("foo12"), lit("/")}},
+ },
+ {
+ "/foo/{$}",
+ pattern{segments: []segment{lit("foo"), lit("/")}},
+ },
+ {
+ "/{a}/foo/{rest...}",
+ pattern{segments: []segment{wild("a"), lit("foo"), multi("rest")}},
+ },
+ {
+ "//",
+ pattern{segments: []segment{lit(""), multi("")}},
+ },
+ {
+ "/foo///./../bar",
+ pattern{segments: []segment{lit("foo"), lit(""), lit(""), lit("."), lit(".."), lit("bar")}},
+ },
+ {
+ "a.com/foo//",
+ pattern{host: "a.com", segments: []segment{lit("foo"), lit(""), multi("")}},
+ },
+ {
+ "/%61%62/%7b/%",
+ pattern{segments: []segment{lit("ab"), lit("{"), lit("%")}},
+ },
+ } {
+ got := mustParsePattern(t, test.in)
+ if !got.equal(&test.want) {
+ t.Errorf("%q:\ngot %#v\nwant %#v", test.in, got, &test.want)
+ }
+ }
+}
+
+func TestParsePatternError(t *testing.T) {
+ for _, test := range []struct {
+ in string
+ contains string
+ }{
+ {"", "empty pattern"},
+ {"A=B /", "at offset 0: invalid method"},
+ {" ", "at offset 1: host/path missing /"},
+ {"/{w}x", "at offset 1: bad wildcard segment"},
+ {"/x{w}", "at offset 1: bad wildcard segment"},
+ {"/{wx", "at offset 1: bad wildcard segment"},
+ {"/a/{/}/c", "at offset 3: bad wildcard segment"},
+ {"/a/{%61}/c", "at offset 3: bad wildcard name"}, // wildcard names aren't unescaped
+ {"/{a$}", "at offset 1: bad wildcard name"},
+ {"/{}", "at offset 1: empty wildcard"},
+ {"POST a.com/x/{}/y", "at offset 13: empty wildcard"},
+ {"/{...}", "at offset 1: empty wildcard"},
+ {"/{$...}", "at offset 1: bad wildcard"},
+ {"/{$}/", "at offset 1: {$} not at end"},
+ {"/{$}/x", "at offset 1: {$} not at end"},
+ {"/abc/{$}/x", "at offset 5: {$} not at end"},
+ {"/{a...}/", "at offset 1: {...} wildcard not at end"},
+ {"/{a...}/x", "at offset 1: {...} wildcard not at end"},
+ {"{a}/b", "at offset 0: host contains '{' (missing initial '/'?)"},
+ {"/a/{x}/b/{x...}", "at offset 9: duplicate wildcard name"},
+ {"GET //", "at offset 4: non-CONNECT pattern with unclean path"},
+ } {
+ _, err := parsePattern(test.in)
+ if err == nil || !strings.Contains(err.Error(), test.contains) {
+ t.Errorf("%q:\ngot %v, want error containing %q", test.in, err, test.contains)
+ }
+ }
+}
+
+func (p1 *pattern) equal(p2 *pattern) bool {
+ return p1.method == p2.method && p1.host == p2.host &&
+ slices.Equal(p1.segments, p2.segments)
+}
+
+func mustParsePattern(tb testing.TB, s string) *pattern {
+ tb.Helper()
+ p, err := parsePattern(s)
+ if err != nil {
+ tb.Fatal(err)
+ }
+ return p
+}
+
+func TestCompareMethods(t *testing.T) {
+ for _, test := range []struct {
+ p1, p2 string
+ want relationship
+ }{
+ {"/", "/", equivalent},
+ {"GET /", "GET /", equivalent},
+ {"HEAD /", "HEAD /", equivalent},
+ {"POST /", "POST /", equivalent},
+ {"GET /", "POST /", disjoint},
+ {"GET /", "/", moreSpecific},
+ {"HEAD /", "/", moreSpecific},
+ {"GET /", "HEAD /", moreGeneral},
+ } {
+ pat1 := mustParsePattern(t, test.p1)
+ pat2 := mustParsePattern(t, test.p2)
+ got := pat1.compareMethods(pat2)
+ if got != test.want {
+ t.Errorf("%s vs %s: got %s, want %s", test.p1, test.p2, got, test.want)
+ }
+ got2 := pat2.compareMethods(pat1)
+ want2 := inverseRelationship(test.want)
+ if got2 != want2 {
+ t.Errorf("%s vs %s: got %s, want %s", test.p2, test.p1, got2, want2)
+ }
+ }
+}
+
+func TestComparePaths(t *testing.T) {
+ for _, test := range []struct {
+ p1, p2 string
+ want relationship
+ }{
+ // A non-final pattern segment can have one of two values: literal or
+ // single wildcard. A final pattern segment can have one of 5: empty
+ // (trailing slash), literal, dollar, single wildcard, or multi
+ // wildcard. Trailing slash and multi wildcard are the same.
+
+ // A literal should be more specific than anything it overlaps, except itself.
+ {"/a", "/a", equivalent},
+ {"/a", "/b", disjoint},
+ {"/a", "/", moreSpecific},
+ {"/a", "/{$}", disjoint},
+ {"/a", "/{x}", moreSpecific},
+ {"/a", "/{x...}", moreSpecific},
+
+ // Adding a segment doesn't change that.
+ {"/b/a", "/b/a", equivalent},
+ {"/b/a", "/b/b", disjoint},
+ {"/b/a", "/b/", moreSpecific},
+ {"/b/a", "/b/{$}", disjoint},
+ {"/b/a", "/b/{x}", moreSpecific},
+ {"/b/a", "/b/{x...}", moreSpecific},
+ {"/{z}/a", "/{z}/a", equivalent},
+ {"/{z}/a", "/{z}/b", disjoint},
+ {"/{z}/a", "/{z}/", moreSpecific},
+ {"/{z}/a", "/{z}/{$}", disjoint},
+ {"/{z}/a", "/{z}/{x}", moreSpecific},
+ {"/{z}/a", "/{z}/{x...}", moreSpecific},
+
+ // Single wildcard on left.
+ {"/{z}", "/a", moreGeneral},
+ {"/{z}", "/a/b", disjoint},
+ {"/{z}", "/{$}", disjoint},
+ {"/{z}", "/{x}", equivalent},
+ {"/{z}", "/", moreSpecific},
+ {"/{z}", "/{x...}", moreSpecific},
+ {"/b/{z}", "/b/a", moreGeneral},
+ {"/b/{z}", "/b/a/b", disjoint},
+ {"/b/{z}", "/b/{$}", disjoint},
+ {"/b/{z}", "/b/{x}", equivalent},
+ {"/b/{z}", "/b/", moreSpecific},
+ {"/b/{z}", "/b/{x...}", moreSpecific},
+
+ // Trailing slash on left.
+ {"/", "/a", moreGeneral},
+ {"/", "/a/b", moreGeneral},
+ {"/", "/{$}", moreGeneral},
+ {"/", "/{x}", moreGeneral},
+ {"/", "/", equivalent},
+ {"/", "/{x...}", equivalent},
+
+ {"/b/", "/b/a", moreGeneral},
+ {"/b/", "/b/a/b", moreGeneral},
+ {"/b/", "/b/{$}", moreGeneral},
+ {"/b/", "/b/{x}", moreGeneral},
+ {"/b/", "/b/", equivalent},
+ {"/b/", "/b/{x...}", equivalent},
+
+ {"/{z}/", "/{z}/a", moreGeneral},
+ {"/{z}/", "/{z}/a/b", moreGeneral},
+ {"/{z}/", "/{z}/{$}", moreGeneral},
+ {"/{z}/", "/{z}/{x}", moreGeneral},
+ {"/{z}/", "/{z}/", equivalent},
+ {"/{z}/", "/a/", moreGeneral},
+ {"/{z}/", "/{z}/{x...}", equivalent},
+ {"/{z}/", "/a/{x...}", moreGeneral},
+ {"/a/{z}/", "/{z}/a/", overlaps},
+ {"/a/{z}/b/", "/{x}/c/{y...}", overlaps},
+
+ // Multi wildcard on left.
+ {"/{m...}", "/a", moreGeneral},
+ {"/{m...}", "/a/b", moreGeneral},
+ {"/{m...}", "/{$}", moreGeneral},
+ {"/{m...}", "/{x}", moreGeneral},
+ {"/{m...}", "/", equivalent},
+ {"/{m...}", "/{x...}", equivalent},
+
+ {"/b/{m...}", "/b/a", moreGeneral},
+ {"/b/{m...}", "/b/a/b", moreGeneral},
+ {"/b/{m...}", "/b/{$}", moreGeneral},
+ {"/b/{m...}", "/b/{x}", moreGeneral},
+ {"/b/{m...}", "/b/", equivalent},
+ {"/b/{m...}", "/b/{x...}", equivalent},
+ {"/b/{m...}", "/a/{x...}", disjoint},
+
+ {"/{z}/{m...}", "/{z}/a", moreGeneral},
+ {"/{z}/{m...}", "/{z}/a/b", moreGeneral},
+ {"/{z}/{m...}", "/{z}/{$}", moreGeneral},
+ {"/{z}/{m...}", "/{z}/{x}", moreGeneral},
+ {"/{z}/{m...}", "/{w}/", equivalent},
+ {"/{z}/{m...}", "/a/", moreGeneral},
+ {"/{z}/{m...}", "/{z}/{x...}", equivalent},
+ {"/{z}/{m...}", "/a/{x...}", moreGeneral},
+ {"/a/{m...}", "/a/b/{y...}", moreGeneral},
+ {"/a/{m...}", "/a/{x}/{y...}", moreGeneral},
+ {"/a/{z}/{m...}", "/a/b/{y...}", moreGeneral},
+ {"/a/{z}/{m...}", "/{z}/a/", overlaps},
+ {"/a/{z}/{m...}", "/{z}/b/{y...}", overlaps},
+ {"/a/{z}/b/{m...}", "/{x}/c/{y...}", overlaps},
+ {"/a/{z}/a/{m...}", "/{x}/b", disjoint},
+
+ // Dollar on left.
+ {"/{$}", "/a", disjoint},
+ {"/{$}", "/a/b", disjoint},
+ {"/{$}", "/{$}", equivalent},
+ {"/{$}", "/{x}", disjoint},
+ {"/{$}", "/", moreSpecific},
+ {"/{$}", "/{x...}", moreSpecific},
+
+ {"/b/{$}", "/b", disjoint},
+ {"/b/{$}", "/b/a", disjoint},
+ {"/b/{$}", "/b/a/b", disjoint},
+ {"/b/{$}", "/b/{$}", equivalent},
+ {"/b/{$}", "/b/{x}", disjoint},
+ {"/b/{$}", "/b/", moreSpecific},
+ {"/b/{$}", "/b/{x...}", moreSpecific},
+ {"/b/{$}", "/b/c/{x...}", disjoint},
+ {"/b/{x}/a/{$}", "/{x}/c/{y...}", overlaps},
+ {"/{x}/b/{$}", "/a/{x}/{y}", disjoint},
+ {"/{x}/b/{$}", "/a/{x}/c", disjoint},
+
+ {"/{z}/{$}", "/{z}/a", disjoint},
+ {"/{z}/{$}", "/{z}/a/b", disjoint},
+ {"/{z}/{$}", "/{z}/{$}", equivalent},
+ {"/{z}/{$}", "/{z}/{x}", disjoint},
+ {"/{z}/{$}", "/{z}/", moreSpecific},
+ {"/{z}/{$}", "/a/", overlaps},
+ {"/{z}/{$}", "/a/{x...}", overlaps},
+ {"/{z}/{$}", "/{z}/{x...}", moreSpecific},
+ {"/a/{z}/{$}", "/{z}/a/", overlaps},
+ } {
+ pat1 := mustParsePattern(t, test.p1)
+ pat2 := mustParsePattern(t, test.p2)
+ if g := pat1.comparePaths(pat1); g != equivalent {
+ t.Errorf("%s does not match itself; got %s", pat1, g)
+ }
+ if g := pat2.comparePaths(pat2); g != equivalent {
+ t.Errorf("%s does not match itself; got %s", pat2, g)
+ }
+ got := pat1.comparePaths(pat2)
+ if got != test.want {
+ t.Errorf("%s vs %s: got %s, want %s", test.p1, test.p2, got, test.want)
+ t.Logf("pat1: %+v\n", pat1.segments)
+ t.Logf("pat2: %+v\n", pat2.segments)
+ }
+ want2 := inverseRelationship(test.want)
+ got2 := pat2.comparePaths(pat1)
+ if got2 != want2 {
+ t.Errorf("%s vs %s: got %s, want %s", test.p2, test.p1, got2, want2)
+ }
+ }
+}
+
+func TestConflictsWith(t *testing.T) {
+ for _, test := range []struct {
+ p1, p2 string
+ want bool
+ }{
+ {"/a", "/a", true},
+ {"/a", "/ab", false},
+ {"/a/b/cd", "/a/b/cd", true},
+ {"/a/b/cd", "/a/b/c", false},
+ {"/a/b/c", "/a/c/c", false},
+ {"/{x}", "/{y}", true},
+ {"/{x}", "/a", false}, // more specific
+ {"/{x}/{y}", "/{x}/a", false},
+ {"/{x}/{y}", "/{x}/a/b", false},
+ {"/{x}", "/a/{y}", false},
+ {"/{x}/{y}", "/{x}/a/", false},
+ {"/{x}", "/a/{y...}", false}, // more specific
+ {"/{x}/a/{y}", "/{x}/a/{y...}", false}, // more specific
+ {"/{x}/{y}", "/{x}/a/{$}", false}, // more specific
+ {"/{x}/{y}/{$}", "/{x}/a/{$}", false},
+ {"/a/{x}", "/{x}/b", true},
+ {"/", "GET /", false},
+ {"/", "GET /foo", false},
+ {"GET /", "GET /foo", false},
+ {"GET /", "/foo", true},
+ {"GET /foo", "HEAD /", true},
+ } {
+ pat1 := mustParsePattern(t, test.p1)
+ pat2 := mustParsePattern(t, test.p2)
+ got := pat1.conflictsWith(pat2)
+ if got != test.want {
+ t.Errorf("%q.ConflictsWith(%q) = %t, want %t",
+ test.p1, test.p2, got, test.want)
+ }
+ // conflictsWith should be commutative.
+ got = pat2.conflictsWith(pat1)
+ if got != test.want {
+ t.Errorf("%q.ConflictsWith(%q) = %t, want %t",
+ test.p2, test.p1, got, test.want)
+ }
+ }
+}
+
+func TestRegisterConflict(t *testing.T) {
+ mux := NewServeMux()
+ pat1 := "/a/{x}/"
+ if err := mux.registerErr(pat1, NotFoundHandler()); err != nil {
+ t.Fatal(err)
+ }
+ pat2 := "/a/{y}/{z...}"
+ err := mux.registerErr(pat2, NotFoundHandler())
+ var got string
+ if err == nil {
+ got = "<nil>"
+ } else {
+ got = err.Error()
+ }
+ want := "matches the same requests as"
+ if !strings.Contains(got, want) {
+ t.Errorf("got\n%s\nwant\n%s", got, want)
+ }
+}
+
+func TestDescribeConflict(t *testing.T) {
+ for _, test := range []struct {
+ p1, p2 string
+ want string
+ }{
+ {"/a/{x}", "/a/{y}", "the same requests"},
+ {"/", "/{m...}", "the same requests"},
+ {"/a/{x}", "/{y}/b", "both match some paths"},
+ {"/a", "GET /{x}", "matches more methods than GET /{x}, but has a more specific path pattern"},
+ {"GET /a", "HEAD /", "matches more methods than HEAD /, but has a more specific path pattern"},
+ {"POST /", "/a", "matches fewer methods than /a, but has a more general path pattern"},
+ } {
+ got := describeConflict(mustParsePattern(t, test.p1), mustParsePattern(t, test.p2))
+ if !strings.Contains(got, test.want) {
+ t.Errorf("%s vs. %s:\ngot:\n%s\nwhich does not contain %q",
+ test.p1, test.p2, got, test.want)
+ }
+ }
+}
+
+func TestCommonPath(t *testing.T) {
+ for _, test := range []struct {
+ p1, p2 string
+ want string
+ }{
+ {"/a/{x}", "/{x}/a", "/a/a"},
+ {"/a/{z}/", "/{z}/a/", "/a/a/"},
+ {"/a/{z}/{m...}", "/{z}/a/", "/a/a/"},
+ {"/{z}/{$}", "/a/", "/a/"},
+ {"/{z}/{$}", "/a/{x...}", "/a/"},
+ {"/a/{z}/{$}", "/{z}/a/", "/a/a/"},
+ {"/a/{x}/b/{y...}", "/{x}/c/{y...}", "/a/c/b/"},
+ {"/a/{x}/b/", "/{x}/c/{y...}", "/a/c/b/"},
+ {"/a/{x}/b/{$}", "/{x}/c/{y...}", "/a/c/b/"},
+ {"/a/{z}/{x...}", "/{z}/b/{y...}", "/a/b/"},
+ } {
+ pat1 := mustParsePattern(t, test.p1)
+ pat2 := mustParsePattern(t, test.p2)
+ if pat1.comparePaths(pat2) != overlaps {
+ t.Fatalf("%s does not overlap %s", test.p1, test.p2)
+ }
+ got := commonPath(pat1, pat2)
+ if got != test.want {
+ t.Errorf("%s vs. %s: got %q, want %q", test.p1, test.p2, got, test.want)
+ }
+ }
+}
+
+func TestDifferencePath(t *testing.T) {
+ for _, test := range []struct {
+ p1, p2 string
+ want string
+ }{
+ {"/a/{x}", "/{x}/a", "/a/x"},
+ {"/{x}/a", "/a/{x}", "/x/a"},
+ {"/a/{z}/", "/{z}/a/", "/a/z/"},
+ {"/{z}/a/", "/a/{z}/", "/z/a/"},
+ {"/{a}/a/", "/a/{z}/", "/ax/a/"},
+ {"/a/{z}/{x...}", "/{z}/b/{y...}", "/a/z/"},
+ {"/{z}/b/{y...}", "/a/{z}/{x...}", "/z/b/"},
+ {"/a/b/", "/a/b/c", "/a/b/"},
+ {"/a/b/{x...}", "/a/b/c", "/a/b/"},
+ {"/a/b/{x...}", "/a/b/c/d", "/a/b/"},
+ {"/a/b/{x...}", "/a/b/c/d/", "/a/b/"},
+ {"/a/{z}/{m...}", "/{z}/a/", "/a/z/"},
+ {"/{z}/a/", "/a/{z}/{m...}", "/z/a/"},
+ {"/{z}/{$}", "/a/", "/z/"},
+ {"/a/", "/{z}/{$}", "/a/x"},
+ {"/{z}/{$}", "/a/{x...}", "/z/"},
+ {"/a/{foo...}", "/{z}/{$}", "/a/foo"},
+ {"/a/{z}/{$}", "/{z}/a/", "/a/z/"},
+ {"/{z}/a/", "/a/{z}/{$}", "/z/a/x"},
+ {"/a/{x}/b/{y...}", "/{x}/c/{y...}", "/a/x/b/"},
+ {"/{x}/c/{y...}", "/a/{x}/b/{y...}", "/x/c/"},
+ {"/a/{c}/b/", "/{x}/c/{y...}", "/a/cx/b/"},
+ {"/{x}/c/{y...}", "/a/{c}/b/", "/x/c/"},
+ {"/a/{x}/b/{$}", "/{x}/c/{y...}", "/a/x/b/"},
+ {"/{x}/c/{y...}", "/a/{x}/b/{$}", "/x/c/"},
+ } {
+ pat1 := mustParsePattern(t, test.p1)
+ pat2 := mustParsePattern(t, test.p2)
+ rel := pat1.comparePaths(pat2)
+ if rel != overlaps && rel != moreGeneral {
+ t.Fatalf("%s vs. %s are %s, need overlaps or moreGeneral", pat1, pat2, rel)
+ }
+ got := differencePath(pat1, pat2)
+ if got != test.want {
+ t.Errorf("%s vs. %s: got %q, want %q", test.p1, test.p2, got, test.want)
+ }
+ }
+}
diff --git a/src/net/http/pprof/pprof.go b/src/net/http/pprof/pprof.go
index bc3225daca..bc48f11834 100644
--- a/src/net/http/pprof/pprof.go
+++ b/src/net/http/pprof/pprof.go
@@ -47,12 +47,12 @@
// go tool pprof http://localhost:6060/debug/pprof/profile?seconds=30
//
// Or to look at the goroutine blocking profile, after calling
-// runtime.SetBlockProfileRate in your program:
+// [runtime.SetBlockProfileRate] in your program:
//
// go tool pprof http://localhost:6060/debug/pprof/block
//
// Or to look at the holders of contended mutexes, after calling
-// runtime.SetMutexProfileFraction in your program:
+// [runtime.SetMutexProfileFraction] in your program:
//
// go tool pprof http://localhost:6060/debug/pprof/mutex
//
diff --git a/src/net/http/pprof/pprof_test.go b/src/net/http/pprof/pprof_test.go
index f82ad45bf6..24ad59ab39 100644
--- a/src/net/http/pprof/pprof_test.go
+++ b/src/net/http/pprof/pprof_test.go
@@ -6,12 +6,14 @@ package pprof
import (
"bytes"
+ "encoding/base64"
"fmt"
"internal/profile"
"internal/testenv"
"io"
"net/http"
"net/http/httptest"
+ "path/filepath"
"runtime"
"runtime/pprof"
"strings"
@@ -261,3 +263,64 @@ func seen(p *profile.Profile, fname string) bool {
}
return false
}
+
+// TestDeltaProfileEmptyBase validates that we still receive a valid delta
+// profile even if the base contains no samples.
+//
+// Regression test for https://go.dev/issue/64566.
+func TestDeltaProfileEmptyBase(t *testing.T) {
+ if testing.Short() {
+ // Delta profile collection has a 1s minimum.
+ t.Skip("skipping in -short mode")
+ }
+
+ testenv.MustHaveGoRun(t)
+
+ gotool, err := testenv.GoTool()
+ if err != nil {
+ t.Fatalf("error finding go tool: %v", err)
+ }
+
+ out, err := testenv.Command(t, gotool, "run", filepath.Join("testdata", "delta_mutex.go")).CombinedOutput()
+ if err != nil {
+ t.Fatalf("error running profile collection: %v\noutput: %s", err, out)
+ }
+
+ // Log the binary output for debugging failures.
+ b64 := make([]byte, base64.StdEncoding.EncodedLen(len(out)))
+ base64.StdEncoding.Encode(b64, out)
+ t.Logf("Output in base64.StdEncoding: %s", b64)
+
+ p, err := profile.Parse(bytes.NewReader(out))
+ if err != nil {
+ t.Fatalf("Parse got err %v want nil", err)
+ }
+
+ t.Logf("Output as parsed Profile: %s", p)
+
+ if len(p.SampleType) != 2 {
+ t.Errorf("len(p.SampleType) got %d want 2", len(p.SampleType))
+ }
+ if p.SampleType[0].Type != "contentions" {
+ t.Errorf(`p.SampleType[0].Type got %q want "contentions"`, p.SampleType[0].Type)
+ }
+ if p.SampleType[0].Unit != "count" {
+ t.Errorf(`p.SampleType[0].Unit got %q want "count"`, p.SampleType[0].Unit)
+ }
+ if p.SampleType[1].Type != "delay" {
+ t.Errorf(`p.SampleType[1].Type got %q want "delay"`, p.SampleType[1].Type)
+ }
+ if p.SampleType[1].Unit != "nanoseconds" {
+ t.Errorf(`p.SampleType[1].Unit got %q want "nanoseconds"`, p.SampleType[1].Unit)
+ }
+
+ if p.PeriodType == nil {
+ t.Fatal("p.PeriodType got nil want not nil")
+ }
+ if p.PeriodType.Type != "contentions" {
+ t.Errorf(`p.PeriodType.Type got %q want "contentions"`, p.PeriodType.Type)
+ }
+ if p.PeriodType.Unit != "count" {
+ t.Errorf(`p.PeriodType.Unit got %q want "count"`, p.PeriodType.Unit)
+ }
+}
diff --git a/src/net/http/pprof/testdata/delta_mutex.go b/src/net/http/pprof/testdata/delta_mutex.go
new file mode 100644
index 0000000000..634069c8a0
--- /dev/null
+++ b/src/net/http/pprof/testdata/delta_mutex.go
@@ -0,0 +1,43 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This binary collects a 1s delta mutex profile and dumps it to os.Stdout.
+//
+// This is in a subprocess because we want the base mutex profile to be empty
+// (as a regression test for https://go.dev/issue/64566) and the only way to
+// force reset the profile is to create a new subprocess.
+//
+// This manually collects the HTTP response and dumps to stdout in order to
+// avoid any flakiness around port selection for a real HTTP server.
+package main
+
+import (
+ "bytes"
+ "fmt"
+ "log"
+ "net/http"
+ "net/http/pprof"
+ "net/http/httptest"
+ "runtime"
+)
+
+func main() {
+ // Disable the mutex profiler. This is the default, but that default is
+ // load-bearing for this test, which needs the base profile to be empty.
+ runtime.SetMutexProfileFraction(0)
+
+ h := pprof.Handler("mutex")
+
+ req := httptest.NewRequest("GET", "/debug/pprof/mutex?seconds=1", nil)
+ rec := httptest.NewRecorder()
+ rec.Body = new(bytes.Buffer)
+
+ h.ServeHTTP(rec, req)
+ resp := rec.Result()
+ if resp.StatusCode != http.StatusOK {
+ log.Fatalf("Request failed: %s\n%s", resp.Status, rec.Body)
+ }
+
+ fmt.Print(rec.Body)
+}
diff --git a/src/net/http/request.go b/src/net/http/request.go
index 81f79566a5..99fdebcf9b 100644
--- a/src/net/http/request.go
+++ b/src/net/http/request.go
@@ -107,14 +107,10 @@ var reqWriteExcludeHeader = map[string]bool{
//
// The field semantics differ slightly between client and server
// usage. In addition to the notes on the fields below, see the
-// documentation for Request.Write and RoundTripper.
+// documentation for [Request.Write] and [RoundTripper].
type Request struct {
// Method specifies the HTTP method (GET, POST, PUT, etc.).
// For client requests, an empty string means GET.
- //
- // Go's HTTP client does not support sending a request with
- // the CONNECT method. See the documentation on Transport for
- // details.
Method string
// URL specifies either the URI being requested (for server
@@ -329,10 +325,15 @@ type Request struct {
// It is unexported to prevent people from using Context wrong
// and mutating the contexts held by callers of the same request.
ctx context.Context
+
+ // The following fields are for requests matched by ServeMux.
+ pat *pattern // the pattern that matched
+ matches []string // values for the matching wildcards in pat
+ otherValues map[string]string // for calls to SetPathValue that don't match a wildcard
}
// Context returns the request's context. To change the context, use
-// Clone or WithContext.
+// [Request.Clone] or [Request.WithContext].
//
// The returned context is always non-nil; it defaults to the
// background context.
@@ -356,8 +357,8 @@ func (r *Request) Context() context.Context {
// lifetime of a request and its response: obtaining a connection,
// sending the request, and reading the response headers and body.
//
-// To create a new request with a context, use NewRequestWithContext.
-// To make a deep copy of a request with a new context, use Request.Clone.
+// To create a new request with a context, use [NewRequestWithContext].
+// To make a deep copy of a request with a new context, use [Request.Clone].
func (r *Request) WithContext(ctx context.Context) *Request {
if ctx == nil {
panic("nil context")
@@ -396,6 +397,20 @@ func (r *Request) Clone(ctx context.Context) *Request {
r2.Form = cloneURLValues(r.Form)
r2.PostForm = cloneURLValues(r.PostForm)
r2.MultipartForm = cloneMultipartForm(r.MultipartForm)
+
+ // Copy matches and otherValues. See issue 61410.
+ if s := r.matches; s != nil {
+ s2 := make([]string, len(s))
+ copy(s2, s)
+ r2.matches = s2
+ }
+ if s := r.otherValues; s != nil {
+ s2 := make(map[string]string, len(s))
+ for k, v := range s {
+ s2[k] = v
+ }
+ r2.otherValues = s2
+ }
return r2
}
@@ -420,7 +435,7 @@ func (r *Request) Cookies() []*Cookie {
var ErrNoCookie = errors.New("http: named cookie not present")
// Cookie returns the named cookie provided in the request or
-// ErrNoCookie if not found.
+// [ErrNoCookie] if not found.
// If multiple cookies match the given name, only one cookie will
// be returned.
func (r *Request) Cookie(name string) (*Cookie, error) {
@@ -434,7 +449,7 @@ func (r *Request) Cookie(name string) (*Cookie, error) {
}
// AddCookie adds a cookie to the request. Per RFC 6265 section 5.4,
-// AddCookie does not attach more than one Cookie header field. That
+// AddCookie does not attach more than one [Cookie] header field. That
// means all cookies, if any, are written into the same line,
// separated by semicolon.
// AddCookie only sanitizes c's name and value, and does not sanitize
@@ -452,7 +467,7 @@ func (r *Request) AddCookie(c *Cookie) {
//
// Referer is misspelled as in the request itself, a mistake from the
// earliest days of HTTP. This value can also be fetched from the
-// Header map as Header["Referer"]; the benefit of making it available
+// [Header] map as Header["Referer"]; the benefit of making it available
// as a method is that the compiler can diagnose programs that use the
// alternate (correct English) spelling req.Referrer() but cannot
// diagnose programs that use Header["Referrer"].
@@ -470,7 +485,7 @@ var multipartByReader = &multipart.Form{
// MultipartReader returns a MIME multipart reader if this is a
// multipart/form-data or a multipart/mixed POST request, else returns nil and an error.
-// Use this function instead of ParseMultipartForm to
+// Use this function instead of [Request.ParseMultipartForm] to
// process the request body as a stream.
func (r *Request) MultipartReader() (*multipart.Reader, error) {
if r.MultipartForm == multipartByReader {
@@ -533,15 +548,15 @@ const defaultUserAgent = "Go-http-client/1.1"
// TransferEncoding
// Body
//
-// If Body is present, Content-Length is <= 0 and TransferEncoding
+// If Body is present, Content-Length is <= 0 and [Request.TransferEncoding]
// hasn't been set to "identity", Write adds "Transfer-Encoding:
// chunked" to the header. Body is closed after it is sent.
func (r *Request) Write(w io.Writer) error {
return r.write(w, false, nil, nil)
}
-// WriteProxy is like Write but writes the request in the form
-// expected by an HTTP proxy. In particular, WriteProxy writes the
+// WriteProxy is like [Request.Write] but writes the request in the form
+// expected by an HTTP proxy. In particular, [Request.WriteProxy] writes the
// initial Request-URI line of the request with an absolute URI, per
// section 5.3 of RFC 7230, including the scheme and host.
// In either case, WriteProxy also writes a Host header, using
@@ -669,6 +684,8 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF
userAgent = r.Header.Get("User-Agent")
}
if userAgent != "" {
+ userAgent = headerNewlineToSpace.Replace(userAgent)
+ userAgent = textproto.TrimString(userAgent)
_, err = fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent)
if err != nil {
return err
@@ -834,32 +851,33 @@ func validMethod(method string) bool {
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
}
-// NewRequest wraps NewRequestWithContext using context.Background.
+// NewRequest wraps [NewRequestWithContext] using [context.Background].
func NewRequest(method, url string, body io.Reader) (*Request, error) {
return NewRequestWithContext(context.Background(), method, url, body)
}
-// NewRequestWithContext returns a new Request given a method, URL, and
+// NewRequestWithContext returns a new [Request] given a method, URL, and
// optional body.
//
-// If the provided body is also an io.Closer, the returned
-// Request.Body is set to body and will be closed by the Client
-// methods Do, Post, and PostForm, and Transport.RoundTrip.
+// If the provided body is also an [io.Closer], the returned
+// [Request.Body] is set to body and will be closed (possibly
+// asynchronously) by the Client methods Do, Post, and PostForm,
+// and [Transport.RoundTrip].
//
// NewRequestWithContext returns a Request suitable for use with
-// Client.Do or Transport.RoundTrip. To create a request for use with
-// testing a Server Handler, either use the NewRequest function in the
-// net/http/httptest package, use ReadRequest, or manually update the
+// [Client.Do] or [Transport.RoundTrip]. To create a request for use with
+// testing a Server Handler, either use the [NewRequest] function in the
+// net/http/httptest package, use [ReadRequest], or manually update the
// Request fields. For an outgoing client request, the context
// controls the entire lifetime of a request and its response:
// obtaining a connection, sending the request, and reading the
// response headers and body. See the Request type's documentation for
// the difference between inbound and outbound request fields.
//
-// If body is of type *bytes.Buffer, *bytes.Reader, or
-// *strings.Reader, the returned request's ContentLength is set to its
+// If body is of type [*bytes.Buffer], [*bytes.Reader], or
+// [*strings.Reader], the returned request's ContentLength is set to its
// exact value (instead of -1), GetBody is populated (so 307 and 308
-// redirects can replay the body), and Body is set to NoBody if the
+// redirects can replay the body), and Body is set to [NoBody] if the
// ContentLength is 0.
func NewRequestWithContext(ctx context.Context, method, url string, body io.Reader) (*Request, error) {
if method == "" {
@@ -983,7 +1001,7 @@ func parseBasicAuth(auth string) (username, password string, ok bool) {
// The username may not contain a colon. Some protocols may impose
// additional requirements on pre-escaping the username and
// password. For instance, when used with OAuth2, both arguments must
-// be URL encoded first with url.QueryEscape.
+// be URL encoded first with [url.QueryEscape].
func (r *Request) SetBasicAuth(username, password string) {
r.Header.Set("Authorization", "Basic "+basicAuth(username, password))
}
@@ -1017,8 +1035,8 @@ func putTextprotoReader(r *textproto.Reader) {
// ReadRequest reads and parses an incoming request from b.
//
// ReadRequest is a low-level function and should only be used for
-// specialized applications; most code should use the Server to read
-// requests and handle them via the Handler interface. ReadRequest
+// specialized applications; most code should use the [Server] to read
+// requests and handle them via the [Handler] interface. ReadRequest
// only supports HTTP/1.x requests. For HTTP/2, use golang.org/x/net/http2.
func ReadRequest(b *bufio.Reader) (*Request, error) {
req, err := readRequest(b)
@@ -1127,15 +1145,15 @@ func readRequest(b *bufio.Reader) (req *Request, err error) {
return req, nil
}
-// MaxBytesReader is similar to io.LimitReader but is intended for
+// MaxBytesReader is similar to [io.LimitReader] but is intended for
// limiting the size of incoming request bodies. In contrast to
// io.LimitReader, MaxBytesReader's result is a ReadCloser, returns a
-// non-nil error of type *MaxBytesError for a Read beyond the limit,
+// non-nil error of type [*MaxBytesError] for a Read beyond the limit,
// and closes the underlying reader when its Close method is called.
//
// MaxBytesReader prevents clients from accidentally or maliciously
// sending a large request and wasting server resources. If possible,
-// it tells the ResponseWriter to close the connection after the limit
+// it tells the [ResponseWriter] to close the connection after the limit
// has been reached.
func MaxBytesReader(w ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser {
if n < 0 { // Treat negative limits as equivalent to 0.
@@ -1144,7 +1162,7 @@ func MaxBytesReader(w ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser {
return &maxBytesReader{w: w, r: r, i: n, n: n}
}
-// MaxBytesError is returned by MaxBytesReader when its read limit is exceeded.
+// MaxBytesError is returned by [MaxBytesReader] when its read limit is exceeded.
type MaxBytesError struct {
Limit int64
}
@@ -1269,14 +1287,14 @@ func parsePostForm(r *Request) (vs url.Values, err error) {
// as a form and puts the results into both r.PostForm and r.Form. Request body
// parameters take precedence over URL query string values in r.Form.
//
-// If the request Body's size has not already been limited by MaxBytesReader,
+// If the request Body's size has not already been limited by [MaxBytesReader],
// the size is capped at 10MB.
//
// For other HTTP methods, or when the Content-Type is not
// application/x-www-form-urlencoded, the request Body is not read, and
// r.PostForm is initialized to a non-nil, empty value.
//
-// ParseMultipartForm calls ParseForm automatically.
+// [Request.ParseMultipartForm] calls ParseForm automatically.
// ParseForm is idempotent.
func (r *Request) ParseForm() error {
var err error
@@ -1317,7 +1335,7 @@ func (r *Request) ParseForm() error {
// The whole request body is parsed and up to a total of maxMemory bytes of
// its file parts are stored in memory, with the remainder stored on
// disk in temporary files.
-// ParseMultipartForm calls ParseForm if necessary.
+// ParseMultipartForm calls [Request.ParseForm] if necessary.
// If ParseForm returns an error, ParseMultipartForm returns it but also
// continues parsing the request body.
// After one call to ParseMultipartForm, subsequent calls have no effect.
@@ -1360,12 +1378,16 @@ func (r *Request) ParseMultipartForm(maxMemory int64) error {
}
// FormValue returns the first value for the named component of the query.
-// POST and PUT body parameters take precedence over URL query string values.
-// FormValue calls ParseMultipartForm and ParseForm if necessary and ignores
-// any errors returned by these functions.
+// The precedence order:
+// 1. application/x-www-form-urlencoded form body (POST, PUT, PATCH only)
+// 2. query parameters (always)
+// 3. multipart/form-data form body (always)
+//
+// FormValue calls [Request.ParseMultipartForm] and [Request.ParseForm]
+// if necessary and ignores any errors returned by these functions.
// If key is not present, FormValue returns the empty string.
// To access multiple values of the same key, call ParseForm and
-// then inspect Request.Form directly.
+// then inspect [Request.Form] directly.
func (r *Request) FormValue(key string) string {
if r.Form == nil {
r.ParseMultipartForm(defaultMaxMemory)
@@ -1377,8 +1399,8 @@ func (r *Request) FormValue(key string) string {
}
// PostFormValue returns the first value for the named component of the POST,
-// PATCH, or PUT request body. URL query parameters are ignored.
-// PostFormValue calls ParseMultipartForm and ParseForm if necessary and ignores
+// PUT, or PATCH request body. URL query parameters are ignored.
+// PostFormValue calls [Request.ParseMultipartForm] and [Request.ParseForm] if necessary and ignores
// any errors returned by these functions.
// If key is not present, PostFormValue returns the empty string.
func (r *Request) PostFormValue(key string) string {
@@ -1392,7 +1414,7 @@ func (r *Request) PostFormValue(key string) string {
}
// FormFile returns the first file for the provided form key.
-// FormFile calls ParseMultipartForm and ParseForm if necessary.
+// FormFile calls [Request.ParseMultipartForm] and [Request.ParseForm] if necessary.
func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, error) {
if r.MultipartForm == multipartByReader {
return nil, nil, errors.New("http: multipart handled by MultipartReader")
@@ -1412,6 +1434,50 @@ func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, e
return nil, nil, ErrMissingFile
}
+// PathValue returns the value for the named path wildcard in the [ServeMux] pattern
+// that matched the request.
+// It returns the empty string if the request was not matched against a pattern
+// or there is no such wildcard in the pattern.
+func (r *Request) PathValue(name string) string {
+ if i := r.patIndex(name); i >= 0 {
+ return r.matches[i]
+ }
+ return r.otherValues[name]
+}
+
+// SetPathValue sets name to value, so that subsequent calls to r.PathValue(name)
+// return value.
+func (r *Request) SetPathValue(name, value string) {
+ if i := r.patIndex(name); i >= 0 {
+ r.matches[i] = value
+ } else {
+ if r.otherValues == nil {
+ r.otherValues = map[string]string{}
+ }
+ r.otherValues[name] = value
+ }
+}
+
+// patIndex returns the index of name in the list of named wildcards of the
+// request's pattern, or -1 if there is no such name.
+func (r *Request) patIndex(name string) int {
+ // The linear search seems expensive compared to a map, but just creating the map
+ // takes a lot of time, and most patterns will just have a couple of wildcards.
+ if r.pat == nil {
+ return -1
+ }
+ i := 0
+ for _, seg := range r.pat.segments {
+ if seg.wild && seg.s != "" {
+ if name == seg.s {
+ return i
+ }
+ i++
+ }
+ }
+ return -1
+}
+
func (r *Request) expectsContinue() bool {
return hasToken(r.Header.get("Expect"), "100-continue")
}
diff --git a/src/net/http/request_test.go b/src/net/http/request_test.go
index a32b583c11..6ce32332e7 100644
--- a/src/net/http/request_test.go
+++ b/src/net/http/request_test.go
@@ -15,7 +15,9 @@ import (
"io"
"math"
"mime/multipart"
+ "net/http"
. "net/http"
+ "net/http/httptest"
"net/url"
"os"
"reflect"
@@ -787,6 +789,25 @@ func TestRequestBadHostHeader(t *testing.T) {
}
}
+func TestRequestBadUserAgent(t *testing.T) {
+ got := []string{}
+ req, err := NewRequest("GET", "http://foo/after", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.Header.Set("User-Agent", "evil\r\nX-Evil: evil")
+ req.Write(logWrites{t, &got})
+ want := []string{
+ "GET /after HTTP/1.1\r\n",
+ "Host: foo\r\n",
+ "User-Agent: evil X-Evil: evil\r\n",
+ "\r\n",
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("Writes = %q\n Want = %q", got, want)
+ }
+}
+
func TestStarRequest(t *testing.T) {
req, err := ReadRequest(bufio.NewReader(strings.NewReader("M-SEARCH * HTTP/1.1\r\n\r\n")))
if err != nil {
@@ -1032,6 +1053,33 @@ func TestRequestCloneTransferEncoding(t *testing.T) {
}
}
+// Ensure that Request.Clone works correctly with PathValue.
+// See issue 64911.
+func TestRequestClonePathValue(t *testing.T) {
+ req, _ := http.NewRequest("GET", "https://example.org/", nil)
+ req.SetPathValue("p1", "orig")
+
+ clonedReq := req.Clone(context.Background())
+ clonedReq.SetPathValue("p2", "copy")
+
+ // Ensure that any modifications to the cloned
+ // request do not pollute the original request.
+ if g, w := req.PathValue("p2"), ""; g != w {
+ t.Fatalf("p2 mismatch got %q, want %q", g, w)
+ }
+ if g, w := req.PathValue("p1"), "orig"; g != w {
+ t.Fatalf("p1 mismatch got %q, want %q", g, w)
+ }
+
+ // Assert on the changes to the cloned request.
+ if g, w := clonedReq.PathValue("p1"), "orig"; g != w {
+ t.Fatalf("p1 mismatch got %q, want %q", g, w)
+ }
+ if g, w := clonedReq.PathValue("p2"), "copy"; g != w {
+ t.Fatalf("p2 mismatch got %q, want %q", g, w)
+ }
+}
+
// Issue 34878: verify we don't panic when including basic auth (Go 1.13 regression)
func TestNoPanicOnRoundTripWithBasicAuth(t *testing.T) { run(t, testNoPanicWithBasicAuth) }
func testNoPanicWithBasicAuth(t *testing.T, mode testMode) {
@@ -1395,3 +1443,141 @@ func TestErrNotSupported(t *testing.T) {
t.Error("errors.Is(ErrNotSupported, errors.ErrUnsupported) failed")
}
}
+
+func TestPathValueNoMatch(t *testing.T) {
+ // Check that PathValue and SetPathValue work on a Request that was never matched.
+ var r Request
+ if g, w := r.PathValue("x"), ""; g != w {
+ t.Errorf("got %q, want %q", g, w)
+ }
+ r.SetPathValue("x", "a")
+ if g, w := r.PathValue("x"), "a"; g != w {
+ t.Errorf("got %q, want %q", g, w)
+ }
+}
+
+func TestPathValue(t *testing.T) {
+ for _, test := range []struct {
+ pattern string
+ url string
+ want map[string]string
+ }{
+ {
+ "/{a}/is/{b}/{c...}",
+ "/now/is/the/time/for/all",
+ map[string]string{
+ "a": "now",
+ "b": "the",
+ "c": "time/for/all",
+ "d": "",
+ },
+ },
+ {
+ "/names/{name}/{other...}",
+ "/names/%2fjohn/address",
+ map[string]string{
+ "name": "/john",
+ "other": "address",
+ },
+ },
+ {
+ "/names/{name}/{other...}",
+ "/names/john%2Fdoe/there/is%2F/more",
+ map[string]string{
+ "name": "john/doe",
+ "other": "there/is//more",
+ },
+ },
+ } {
+ mux := NewServeMux()
+ mux.HandleFunc(test.pattern, func(w ResponseWriter, r *Request) {
+ for name, want := range test.want {
+ got := r.PathValue(name)
+ if got != want {
+ t.Errorf("%q, %q: got %q, want %q", test.pattern, name, got, want)
+ }
+ }
+ })
+ server := httptest.NewServer(mux)
+ defer server.Close()
+ res, err := Get(server.URL + test.url)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ }
+}
+
+func TestSetPathValue(t *testing.T) {
+ mux := NewServeMux()
+ mux.HandleFunc("/a/{b}/c/{d...}", func(_ ResponseWriter, r *Request) {
+ kvs := map[string]string{
+ "b": "X",
+ "d": "Y",
+ "a": "Z",
+ }
+ for k, v := range kvs {
+ r.SetPathValue(k, v)
+ }
+ for k, w := range kvs {
+ if g := r.PathValue(k); g != w {
+ t.Errorf("got %q, want %q", g, w)
+ }
+ }
+ })
+ server := httptest.NewServer(mux)
+ defer server.Close()
+ res, err := Get(server.URL + "/a/b/c/d/e")
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+}
+
+func TestStatus(t *testing.T) {
+ // The main purpose of this test is to check 405 responses and the Allow header.
+ h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
+ mux := NewServeMux()
+ mux.Handle("GET /g", h)
+ mux.Handle("POST /p", h)
+ mux.Handle("PATCH /p", h)
+ mux.Handle("PUT /r", h)
+ mux.Handle("GET /r/", h)
+ server := httptest.NewServer(mux)
+ defer server.Close()
+
+ for _, test := range []struct {
+ method, path string
+ wantStatus int
+ wantAllow string
+ }{
+ {"GET", "/g", 200, ""},
+ {"HEAD", "/g", 200, ""},
+ {"POST", "/g", 405, "GET, HEAD"},
+ {"GET", "/x", 404, ""},
+ {"GET", "/p", 405, "PATCH, POST"},
+ {"GET", "/./p", 405, "PATCH, POST"},
+ {"GET", "/r/", 200, ""},
+ {"GET", "/r", 200, ""}, // redirected
+ {"HEAD", "/r/", 200, ""},
+ {"HEAD", "/r", 200, ""}, // redirected
+ {"PUT", "/r/", 405, "GET, HEAD"},
+ {"PUT", "/r", 200, ""},
+ } {
+ req, err := http.NewRequest(test.method, server.URL+test.path, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res, err := http.DefaultClient.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if g, w := res.StatusCode, test.wantStatus; g != w {
+ t.Errorf("%s %s: got %d, want %d", test.method, test.path, g, w)
+ }
+ if g, w := res.Header.Get("Allow"), test.wantAllow; g != w {
+ t.Errorf("%s %s, Allow: got %q, want %q", test.method, test.path, g, w)
+ }
+ }
+}
diff --git a/src/net/http/response.go b/src/net/http/response.go
index 755c696557..0c3d7f6d85 100644
--- a/src/net/http/response.go
+++ b/src/net/http/response.go
@@ -29,7 +29,7 @@ var respExcludeHeader = map[string]bool{
// Response represents the response from an HTTP request.
//
-// The Client and Transport return Responses from servers once
+// The [Client] and [Transport] return Responses from servers once
// the response headers have been received. The response body
// is streamed on demand as the Body field is read.
type Response struct {
@@ -126,13 +126,13 @@ func (r *Response) Cookies() []*Cookie {
return readSetCookies(r.Header)
}
-// ErrNoLocation is returned by Response's Location method
+// ErrNoLocation is returned by the [Response.Location] method
// when no Location header is present.
var ErrNoLocation = errors.New("http: no Location header in response")
// Location returns the URL of the response's "Location" header,
// if present. Relative redirects are resolved relative to
-// the Response's Request. ErrNoLocation is returned if no
+// [Response.Request]. [ErrNoLocation] is returned if no
// Location header is present.
func (r *Response) Location() (*url.URL, error) {
lv := r.Header.Get("Location")
@@ -146,8 +146,8 @@ func (r *Response) Location() (*url.URL, error) {
}
// ReadResponse reads and returns an HTTP response from r.
-// The req parameter optionally specifies the Request that corresponds
-// to this Response. If nil, a GET request is assumed.
+// The req parameter optionally specifies the [Request] that corresponds
+// to this [Response]. If nil, a GET request is assumed.
// Clients must call resp.Body.Close when finished reading resp.Body.
// After that call, clients can inspect resp.Trailer to find key/value
// pairs included in the response trailer.
diff --git a/src/net/http/response_test.go b/src/net/http/response_test.go
index 19fb48f23c..f3425c3c20 100644
--- a/src/net/http/response_test.go
+++ b/src/net/http/response_test.go
@@ -849,7 +849,7 @@ func TestReadResponseErrors(t *testing.T) {
type testCase struct {
name string // optional, defaults to in
in string
- wantErr any // nil, err value, or string substring
+ wantErr any // nil, err value, bool value, or string substring
}
status := func(s string, wantErr any) testCase {
@@ -883,6 +883,7 @@ func TestReadResponseErrors(t *testing.T) {
}
errMultiCL := "message cannot contain multiple Content-Length headers"
+ errEmptyCL := "invalid empty Content-Length"
tests := []testCase{
{"", "", io.ErrUnexpectedEOF},
@@ -918,7 +919,7 @@ func TestReadResponseErrors(t *testing.T) {
contentLength("200 OK", "Content-Length: 7\r\nContent-Length: 7\r\n\r\nGophers\r\n", nil),
contentLength("201 OK", "Content-Length: 0\r\nContent-Length: 7\r\n\r\nGophers\r\n", errMultiCL),
contentLength("300 OK", "Content-Length: 0\r\nContent-Length: 0 \r\n\r\nGophers\r\n", nil),
- contentLength("200 OK", "Content-Length:\r\nContent-Length:\r\n\r\nGophers\r\n", nil),
+ contentLength("200 OK", "Content-Length:\r\nContent-Length:\r\n\r\nGophers\r\n", errEmptyCL),
contentLength("206 OK", "Content-Length:\r\nContent-Length: 0 \r\nConnection: close\r\n\r\nGophers\r\n", errMultiCL),
// multiple content-length headers for 204 and 304 should still be checked
diff --git a/src/net/http/responsecontroller.go b/src/net/http/responsecontroller.go
index 92276ffaf2..f3f24c1273 100644
--- a/src/net/http/responsecontroller.go
+++ b/src/net/http/responsecontroller.go
@@ -13,14 +13,14 @@ import (
// A ResponseController is used by an HTTP handler to control the response.
//
-// A ResponseController may not be used after the Handler.ServeHTTP method has returned.
+// A ResponseController may not be used after the [Handler.ServeHTTP] method has returned.
type ResponseController struct {
rw ResponseWriter
}
-// NewResponseController creates a ResponseController for a request.
+// NewResponseController creates a [ResponseController] for a request.
//
-// The ResponseWriter should be the original value passed to the Handler.ServeHTTP method,
+// The ResponseWriter should be the original value passed to the [Handler.ServeHTTP] method,
// or have an Unwrap method returning the original ResponseWriter.
//
// If the ResponseWriter implements any of the following methods, the ResponseController
@@ -34,7 +34,7 @@ type ResponseController struct {
// EnableFullDuplex() error
//
// If the ResponseWriter does not support a method, ResponseController returns
-// an error matching ErrNotSupported.
+// an error matching [ErrNotSupported].
func NewResponseController(rw ResponseWriter) *ResponseController {
return &ResponseController{rw}
}
@@ -116,8 +116,8 @@ func (c *ResponseController) SetWriteDeadline(deadline time.Time) error {
}
}
-// EnableFullDuplex indicates that the request handler will interleave reads from Request.Body
-// with writes to the ResponseWriter.
+// EnableFullDuplex indicates that the request handler will interleave reads from [Request.Body]
+// with writes to the [ResponseWriter].
//
// For HTTP/1 requests, the Go HTTP server by default consumes any unread portion of
// the request body before beginning to write the response, preventing handlers from
diff --git a/src/net/http/responsecontroller_test.go b/src/net/http/responsecontroller_test.go
index 5828f3795a..f1dcc79ef8 100644
--- a/src/net/http/responsecontroller_test.go
+++ b/src/net/http/responsecontroller_test.go
@@ -1,3 +1,7 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
package http_test
import (
diff --git a/src/net/http/roundtrip.go b/src/net/http/roundtrip.go
index 49ea1a71ed..08c270179a 100644
--- a/src/net/http/roundtrip.go
+++ b/src/net/http/roundtrip.go
@@ -6,10 +6,10 @@
package http
-// RoundTrip implements the RoundTripper interface.
+// RoundTrip implements the [RoundTripper] interface.
//
// For higher-level HTTP client support (such as handling of cookies
-// and redirects), see Get, Post, and the Client type.
+// and redirects), see [Get], [Post], and the [Client] type.
//
// Like the RoundTripper interface, the error types returned
// by RoundTrip are unspecified.
diff --git a/src/net/http/roundtrip_js.go b/src/net/http/roundtrip_js.go
index 9f9f0cb67d..04c241eb4c 100644
--- a/src/net/http/roundtrip_js.go
+++ b/src/net/http/roundtrip_js.go
@@ -10,6 +10,7 @@ import (
"errors"
"fmt"
"io"
+ "net/http/internal/ascii"
"strconv"
"strings"
"syscall/js"
@@ -55,7 +56,7 @@ var jsFetchMissing = js.Global().Get("fetch").IsUndefined()
var jsFetchDisabled = js.Global().Get("process").Type() == js.TypeObject &&
strings.HasPrefix(js.Global().Get("process").Get("argv0").String(), "node")
-// RoundTrip implements the RoundTripper interface using the WHATWG Fetch API.
+// RoundTrip implements the [RoundTripper] interface using the WHATWG Fetch API.
func (t *Transport) RoundTrip(req *Request) (*Response, error) {
// The Transport has a documented contract that states that if the DialContext or
// DialTLSContext functions are set, they will be used to set up the connections.
@@ -184,11 +185,22 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) {
}
code := result.Get("status").Int()
+
+ uncompressed := false
+ if ascii.EqualFold(header.Get("Content-Encoding"), "gzip") {
+ // The fetch api will decode the gzip, but Content-Encoding not be deleted.
+ header.Del("Content-Encoding")
+ header.Del("Content-Length")
+ contentLength = -1
+ uncompressed = true
+ }
+
respCh <- &Response{
Status: fmt.Sprintf("%d %s", code, StatusText(code)),
StatusCode: code,
Header: header,
ContentLength: contentLength,
+ Uncompressed: uncompressed,
Body: body,
Request: req,
}
diff --git a/src/net/http/routing_index.go b/src/net/http/routing_index.go
new file mode 100644
index 0000000000..9ac42c997d
--- /dev/null
+++ b/src/net/http/routing_index.go
@@ -0,0 +1,124 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import "math"
+
+// A routingIndex optimizes conflict detection by indexing patterns.
+//
+// The basic idea is to rule out patterns that cannot conflict with a given
+// pattern because they have a different literal in a corresponding segment.
+// See the comments in [routingIndex.possiblyConflictingPatterns] for more details.
+type routingIndex struct {
+ // map from a particular segment position and value to all registered patterns
+ // with that value in that position.
+ // For example, the key {1, "b"} would hold the patterns "/a/b" and "/a/b/c"
+ // but not "/a", "b/a", "/a/c" or "/a/{x}".
+ segments map[routingIndexKey][]*pattern
+ // All patterns that end in a multi wildcard (including trailing slash).
+ // We do not try to be clever about indexing multi patterns, because there
+ // are unlikely to be many of them.
+ multis []*pattern
+}
+
+type routingIndexKey struct {
+ pos int // 0-based segment position
+ s string // literal, or empty for wildcard
+}
+
+func (idx *routingIndex) addPattern(pat *pattern) {
+ if pat.lastSegment().multi {
+ idx.multis = append(idx.multis, pat)
+ } else {
+ if idx.segments == nil {
+ idx.segments = map[routingIndexKey][]*pattern{}
+ }
+ for pos, seg := range pat.segments {
+ key := routingIndexKey{pos: pos, s: ""}
+ if !seg.wild {
+ key.s = seg.s
+ }
+ idx.segments[key] = append(idx.segments[key], pat)
+ }
+ }
+}
+
+// possiblyConflictingPatterns calls f on all patterns that might conflict with
+// pat. If f returns a non-nil error, possiblyConflictingPatterns returns immediately
+// with that error.
+//
+// To be correct, possiblyConflictingPatterns must include all patterns that
+// might conflict. But it may also include patterns that cannot conflict.
+// For instance, an implementation that returns all registered patterns is correct.
+// We use this fact throughout, simplifying the implementation by returning more
+// patterns that we might need to.
+func (idx *routingIndex) possiblyConflictingPatterns(pat *pattern, f func(*pattern) error) (err error) {
+ // Terminology:
+ // dollar pattern: one ending in "{$}"
+ // multi pattern: one ending in a trailing slash or "{x...}" wildcard
+ // ordinary pattern: neither of the above
+
+ // apply f to all the pats, stopping on error.
+ apply := func(pats []*pattern) error {
+ if err != nil {
+ return err
+ }
+ for _, p := range pats {
+ err = f(p)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+
+ // Our simple indexing scheme doesn't try to prune multi patterns; assume
+ // any of them can match the argument.
+ if err := apply(idx.multis); err != nil {
+ return err
+ }
+ if pat.lastSegment().s == "/" {
+ // All paths that a dollar pattern matches end in a slash; no paths that
+ // an ordinary pattern matches do. So only other dollar or multi
+ // patterns can conflict with a dollar pattern. Furthermore, conflicting
+ // dollar patterns must have the {$} in the same position.
+ return apply(idx.segments[routingIndexKey{s: "/", pos: len(pat.segments) - 1}])
+ }
+ // For ordinary and multi patterns, the only conflicts can be with a multi,
+ // or a pattern that has the same literal or a wildcard at some literal
+ // position.
+ // We could intersect all the possible matches at each position, but we
+ // do something simpler: we find the position with the fewest patterns.
+ var lmin, wmin []*pattern
+ min := math.MaxInt
+ hasLit := false
+ for i, seg := range pat.segments {
+ if seg.multi {
+ break
+ }
+ if !seg.wild {
+ hasLit = true
+ lpats := idx.segments[routingIndexKey{s: seg.s, pos: i}]
+ wpats := idx.segments[routingIndexKey{s: "", pos: i}]
+ if sum := len(lpats) + len(wpats); sum < min {
+ lmin = lpats
+ wmin = wpats
+ min = sum
+ }
+ }
+ }
+ if hasLit {
+ apply(lmin)
+ apply(wmin)
+ return err
+ }
+
+ // This pattern is all wildcards.
+ // Check it against everything.
+ for _, pats := range idx.segments {
+ apply(pats)
+ }
+ return err
+}
diff --git a/src/net/http/routing_index_test.go b/src/net/http/routing_index_test.go
new file mode 100644
index 0000000000..1ffb9272c6
--- /dev/null
+++ b/src/net/http/routing_index_test.go
@@ -0,0 +1,179 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "fmt"
+ "slices"
+ "sort"
+ "strings"
+ "testing"
+)
+
+func TestIndex(t *testing.T) {
+ // Generate every kind of pattern up to some number of segments,
+ // and compare conflicts found during indexing with those found
+ // by exhaustive comparison.
+ patterns := generatePatterns()
+ var idx routingIndex
+ for i, pat := range patterns {
+ got := indexConflicts(pat, &idx)
+ want := trueConflicts(pat, patterns[:i])
+ if !slices.Equal(got, want) {
+ t.Fatalf("%q:\ngot %q\nwant %q", pat, got, want)
+ }
+ idx.addPattern(pat)
+ }
+}
+
+func trueConflicts(pat *pattern, pats []*pattern) []string {
+ var s []string
+ for _, p := range pats {
+ if pat.conflictsWith(p) {
+ s = append(s, p.String())
+ }
+ }
+ sort.Strings(s)
+ return s
+}
+
+func indexConflicts(pat *pattern, idx *routingIndex) []string {
+ var s []string
+ idx.possiblyConflictingPatterns(pat, func(p *pattern) error {
+ if pat.conflictsWith(p) {
+ s = append(s, p.String())
+ }
+ return nil
+ })
+ sort.Strings(s)
+ return slices.Compact(s)
+}
+
+// generatePatterns generates all possible patterns using a representative
+// sample of parts.
+func generatePatterns() []*pattern {
+ var pats []*pattern
+
+ collect := func(s string) {
+ // Replace duplicate wildcards with unique ones.
+ var b strings.Builder
+ wc := 0
+ for {
+ i := strings.Index(s, "{x}")
+ if i < 0 {
+ b.WriteString(s)
+ break
+ }
+ b.WriteString(s[:i])
+ fmt.Fprintf(&b, "{x%d}", wc)
+ wc++
+ s = s[i+3:]
+ }
+ pat, err := parsePattern(b.String())
+ if err != nil {
+ panic(err)
+ }
+ pats = append(pats, pat)
+ }
+
+ var (
+ methods = []string{"", "GET ", "HEAD ", "POST "}
+ hosts = []string{"", "h1", "h2"}
+ segs = []string{"/a", "/b", "/{x}"}
+ finalSegs = []string{"/a", "/b", "/{f}", "/{m...}", "/{$}"}
+ )
+
+ g := genConcat(
+ genChoice(methods),
+ genChoice(hosts),
+ genStar(3, genChoice(segs)),
+ genChoice(finalSegs))
+ g(collect)
+ return pats
+}
+
+// A generator is a function that calls its argument with the strings that it
+// generates.
+type generator func(collect func(string))
+
+// genConst generates a single constant string.
+func genConst(s string) generator {
+ return func(collect func(string)) {
+ collect(s)
+ }
+}
+
+// genChoice generates all the strings in its argument.
+func genChoice(choices []string) generator {
+ return func(collect func(string)) {
+ for _, c := range choices {
+ collect(c)
+ }
+ }
+}
+
+// genConcat2 generates the cross product of the strings of g1 concatenated
+// with those of g2.
+func genConcat2(g1, g2 generator) generator {
+ return func(collect func(string)) {
+ g1(func(s1 string) {
+ g2(func(s2 string) {
+ collect(s1 + s2)
+ })
+ })
+ }
+}
+
+// genConcat generalizes genConcat2 to any number of generators.
+func genConcat(gs ...generator) generator {
+ if len(gs) == 0 {
+ return genConst("")
+ }
+ return genConcat2(gs[0], genConcat(gs[1:]...))
+}
+
+// genRepeat generates strings of exactly n copies of g's strings.
+func genRepeat(n int, g generator) generator {
+ if n == 0 {
+ return genConst("")
+ }
+ return genConcat(g, genRepeat(n-1, g))
+}
+
+// genStar (named after the Kleene star) generates 0, 1, 2, ..., max
+// copies of the strings of g.
+func genStar(max int, g generator) generator {
+ return func(collect func(string)) {
+ for i := 0; i <= max; i++ {
+ genRepeat(i, g)(collect)
+ }
+ }
+}
+
+func BenchmarkMultiConflicts(b *testing.B) {
+ // How fast is indexing if the corpus is all multis?
+ const nMultis = 1000
+ var pats []*pattern
+ for i := 0; i < nMultis; i++ {
+ pats = append(pats, mustParsePattern(b, fmt.Sprintf("/a/b/{x}/d%d/", i)))
+ }
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ var idx routingIndex
+ for _, p := range pats {
+ got := indexConflicts(p, &idx)
+ if len(got) != 0 {
+ b.Fatalf("got %d conflicts, want 0", len(got))
+ }
+ idx.addPattern(p)
+ }
+ if i == 0 {
+ // Confirm that all the multis ended up where they belong.
+ if g, w := len(idx.multis), nMultis; g != w {
+ b.Fatalf("got %d multis, want %d", g, w)
+ }
+ }
+ }
+}
diff --git a/src/net/http/routing_tree.go b/src/net/http/routing_tree.go
new file mode 100644
index 0000000000..8812ed04e2
--- /dev/null
+++ b/src/net/http/routing_tree.go
@@ -0,0 +1,240 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file implements a decision tree for fast matching of requests to
+// patterns.
+//
+// The root of the tree branches on the host of the request.
+// The next level branches on the method.
+// The remaining levels branch on consecutive segments of the path.
+//
+// The "more specific wins" precedence rule can result in backtracking.
+// For example, given the patterns
+// /a/b/z
+// /a/{x}/c
+// we will first try to match the path "/a/b/c" with /a/b/z, and
+// when that fails we will try against /a/{x}/c.
+
+package http
+
+import (
+ "strings"
+)
+
+// A routingNode is a node in the decision tree.
+// The same struct is used for leaf and interior nodes.
+type routingNode struct {
+ // A leaf node holds a single pattern and the Handler it was registered
+ // with.
+ pattern *pattern
+ handler Handler
+
+ // An interior node maps parts of the incoming request to child nodes.
+ // special children keys:
+ // "/" trailing slash (resulting from {$})
+ // "" single wildcard
+ // "*" multi wildcard
+ children mapping[string, *routingNode]
+ emptyChild *routingNode // optimization: child with key ""
+}
+
+// addPattern adds a pattern and its associated Handler to the tree
+// at root.
+func (root *routingNode) addPattern(p *pattern, h Handler) {
+ // First level of tree is host.
+ n := root.addChild(p.host)
+ // Second level of tree is method.
+ n = n.addChild(p.method)
+ // Remaining levels are path.
+ n.addSegments(p.segments, p, h)
+}
+
+// addSegments adds the given segments to the tree rooted at n.
+// If there are no segments, then n is a leaf node that holds
+// the given pattern and handler.
+func (n *routingNode) addSegments(segs []segment, p *pattern, h Handler) {
+ if len(segs) == 0 {
+ n.set(p, h)
+ return
+ }
+ seg := segs[0]
+ if seg.multi {
+ if len(segs) != 1 {
+ panic("multi wildcard not last")
+ }
+ n.addChild("*").set(p, h)
+ } else if seg.wild {
+ n.addChild("").addSegments(segs[1:], p, h)
+ } else {
+ n.addChild(seg.s).addSegments(segs[1:], p, h)
+ }
+}
+
+// set sets the pattern and handler for n, which
+// must be a leaf node.
+func (n *routingNode) set(p *pattern, h Handler) {
+ if n.pattern != nil || n.handler != nil {
+ panic("non-nil leaf fields")
+ }
+ n.pattern = p
+ n.handler = h
+}
+
+// addChild adds a child node with the given key to n
+// if one does not exist, and returns the child.
+func (n *routingNode) addChild(key string) *routingNode {
+ if key == "" {
+ if n.emptyChild == nil {
+ n.emptyChild = &routingNode{}
+ }
+ return n.emptyChild
+ }
+ if c := n.findChild(key); c != nil {
+ return c
+ }
+ c := &routingNode{}
+ n.children.add(key, c)
+ return c
+}
+
+// findChild returns the child of n with the given key, or nil
+// if there is no child with that key.
+func (n *routingNode) findChild(key string) *routingNode {
+ if key == "" {
+ return n.emptyChild
+ }
+ r, _ := n.children.find(key)
+ return r
+}
+
+// match returns the leaf node under root that matches the arguments, and a list
+// of values for pattern wildcards in the order that the wildcards appear.
+// For example, if the request path is "/a/b/c" and the pattern is "/{x}/b/{y}",
+// then the second return value will be []string{"a", "c"}.
+func (root *routingNode) match(host, method, path string) (*routingNode, []string) {
+ if host != "" {
+ // There is a host. If there is a pattern that specifies that host and it
+ // matches, we are done. If the pattern doesn't match, fall through to
+ // try patterns with no host.
+ if l, m := root.findChild(host).matchMethodAndPath(method, path); l != nil {
+ return l, m
+ }
+ }
+ return root.emptyChild.matchMethodAndPath(method, path)
+}
+
+// matchMethodAndPath matches the method and path.
+// Its return values are the same as [routingNode.match].
+// The receiver should be a child of the root.
+func (n *routingNode) matchMethodAndPath(method, path string) (*routingNode, []string) {
+ if n == nil {
+ return nil, nil
+ }
+ if l, m := n.findChild(method).matchPath(path, nil); l != nil {
+ // Exact match of method name.
+ return l, m
+ }
+ if method == "HEAD" {
+ // GET matches HEAD too.
+ if l, m := n.findChild("GET").matchPath(path, nil); l != nil {
+ return l, m
+ }
+ }
+ // No exact match; try patterns with no method.
+ return n.emptyChild.matchPath(path, nil)
+}
+
+// matchPath matches a path.
+// Its return values are the same as [routingNode.match].
+// matchPath calls itself recursively. The matches argument holds the wildcard matches
+// found so far.
+func (n *routingNode) matchPath(path string, matches []string) (*routingNode, []string) {
+ if n == nil {
+ return nil, nil
+ }
+ // If path is empty, then we are done.
+ // If n is a leaf node, we found a match; return it.
+ // If n is an interior node (which means it has a nil pattern),
+ // then we failed to match.
+ if path == "" {
+ if n.pattern == nil {
+ return nil, nil
+ }
+ return n, matches
+ }
+ // Get the first segment of path.
+ seg, rest := firstSegment(path)
+ // First try matching against patterns that have a literal for this position.
+ // We know by construction that such patterns are more specific than those
+ // with a wildcard at this position (they are either more specific, equivalent,
+ // or overlap, and we ruled out the first two when the patterns were registered).
+ if n, m := n.findChild(seg).matchPath(rest, matches); n != nil {
+ return n, m
+ }
+ // If matching a literal fails, try again with patterns that have a single
+ // wildcard (represented by an empty string in the child mapping).
+ // Again, by construction, patterns with a single wildcard must be more specific than
+ // those with a multi wildcard.
+ // We skip this step if the segment is a trailing slash, because single wildcards
+ // don't match trailing slashes.
+ if seg != "/" {
+ if n, m := n.emptyChild.matchPath(rest, append(matches, seg)); n != nil {
+ return n, m
+ }
+ }
+ // Lastly, match the pattern (there can be at most one) that has a multi
+ // wildcard in this position to the rest of the path.
+ if c := n.findChild("*"); c != nil {
+ // Don't record a match for a nameless wildcard (which arises from a
+ // trailing slash in the pattern).
+ if c.pattern.lastSegment().s != "" {
+ matches = append(matches, pathUnescape(path[1:])) // remove initial slash
+ }
+ return c, matches
+ }
+ return nil, nil
+}
+
+// firstSegment splits path into its first segment, and the rest.
+// The path must begin with "/".
+// If path consists of only a slash, firstSegment returns ("/", "").
+// The segment is returned unescaped, if possible.
+func firstSegment(path string) (seg, rest string) {
+ if path == "/" {
+ return "/", ""
+ }
+ path = path[1:] // drop initial slash
+ i := strings.IndexByte(path, '/')
+ if i < 0 {
+ i = len(path)
+ }
+ return pathUnescape(path[:i]), path[i:]
+}
+
+// matchingMethods adds to methodSet all the methods that would result in a
+// match if passed to routingNode.match with the given host and path.
+func (root *routingNode) matchingMethods(host, path string, methodSet map[string]bool) {
+ if host != "" {
+ root.findChild(host).matchingMethodsPath(path, methodSet)
+ }
+ root.emptyChild.matchingMethodsPath(path, methodSet)
+ if methodSet["GET"] {
+ methodSet["HEAD"] = true
+ }
+}
+
+func (n *routingNode) matchingMethodsPath(path string, set map[string]bool) {
+ if n == nil {
+ return
+ }
+ n.children.eachPair(func(method string, c *routingNode) bool {
+ if p, _ := c.matchPath(path, nil); p != nil {
+ set[method] = true
+ }
+ return true
+ })
+ // Don't look at the empty child. If there were an empty
+ // child, it would match on any method, but we only
+ // call this when we fail to match on a method.
+}
diff --git a/src/net/http/routing_tree_test.go b/src/net/http/routing_tree_test.go
new file mode 100644
index 0000000000..2aac8b6cdf
--- /dev/null
+++ b/src/net/http/routing_tree_test.go
@@ -0,0 +1,295 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "fmt"
+ "io"
+ "sort"
+ "strings"
+ "testing"
+
+ "slices"
+)
+
+func TestRoutingFirstSegment(t *testing.T) {
+ for _, test := range []struct {
+ in string
+ want []string
+ }{
+ {"/a/b/c", []string{"a", "b", "c"}},
+ {"/a/b/", []string{"a", "b", "/"}},
+ {"/", []string{"/"}},
+ {"/a/%62/c", []string{"a", "b", "c"}},
+ {"/a%2Fb%2fc", []string{"a/b/c"}},
+ } {
+ var got []string
+ rest := test.in
+ for len(rest) > 0 {
+ var seg string
+ seg, rest = firstSegment(rest)
+ got = append(got, seg)
+ }
+ if !slices.Equal(got, test.want) {
+ t.Errorf("%q: got %v, want %v", test.in, got, test.want)
+ }
+ }
+}
+
+// TODO: test host and method
+var testTree *routingNode
+
+func getTestTree() *routingNode {
+ if testTree == nil {
+ testTree = buildTree("/a", "/a/b", "/a/{x}",
+ "/g/h/i", "/g/{x}/j",
+ "/a/b/{x...}", "/a/b/{y}", "/a/b/{$}")
+ }
+ return testTree
+}
+
+func buildTree(pats ...string) *routingNode {
+ root := &routingNode{}
+ for _, p := range pats {
+ pat, err := parsePattern(p)
+ if err != nil {
+ panic(err)
+ }
+ root.addPattern(pat, nil)
+ }
+ return root
+}
+
+func TestRoutingAddPattern(t *testing.T) {
+ want := `"":
+ "":
+ "a":
+ "/a"
+ "":
+ "/a/{x}"
+ "b":
+ "/a/b"
+ "":
+ "/a/b/{y}"
+ "*":
+ "/a/b/{x...}"
+ "/":
+ "/a/b/{$}"
+ "g":
+ "":
+ "j":
+ "/g/{x}/j"
+ "h":
+ "i":
+ "/g/h/i"
+`
+
+ var b strings.Builder
+ getTestTree().print(&b, 0)
+ got := b.String()
+ if got != want {
+ t.Errorf("got\n%s\nwant\n%s", got, want)
+ }
+}
+
+type testCase struct {
+ method, host, path string
+ wantPat string // "" for nil (no match)
+ wantMatches []string
+}
+
+func TestRoutingNodeMatch(t *testing.T) {
+
+ test := func(tree *routingNode, tests []testCase) {
+ t.Helper()
+ for _, test := range tests {
+ gotNode, gotMatches := tree.match(test.host, test.method, test.path)
+ got := ""
+ if gotNode != nil {
+ got = gotNode.pattern.String()
+ }
+ if got != test.wantPat {
+ t.Errorf("%s, %s, %s: got %q, want %q", test.host, test.method, test.path, got, test.wantPat)
+ }
+ if !slices.Equal(gotMatches, test.wantMatches) {
+ t.Errorf("%s, %s, %s: got matches %v, want %v", test.host, test.method, test.path, gotMatches, test.wantMatches)
+ }
+ }
+ }
+
+ test(getTestTree(), []testCase{
+ {"GET", "", "/a", "/a", nil},
+ {"Get", "", "/b", "", nil},
+ {"Get", "", "/a/b", "/a/b", nil},
+ {"Get", "", "/a/c", "/a/{x}", []string{"c"}},
+ {"Get", "", "/a/b/", "/a/b/{$}", nil},
+ {"Get", "", "/a/b/c", "/a/b/{y}", []string{"c"}},
+ {"Get", "", "/a/b/c/d", "/a/b/{x...}", []string{"c/d"}},
+ {"Get", "", "/g/h/i", "/g/h/i", nil},
+ {"Get", "", "/g/h/j", "/g/{x}/j", []string{"h"}},
+ })
+
+ tree := buildTree(
+ "/item/",
+ "POST /item/{user}",
+ "GET /item/{user}",
+ "/item/{user}",
+ "/item/{user}/{id}",
+ "/item/{user}/new",
+ "/item/{$}",
+ "POST alt.com/item/{user}",
+ "GET /headwins",
+ "HEAD /headwins",
+ "/path/{p...}")
+
+ test(tree, []testCase{
+ {"GET", "", "/item/jba",
+ "GET /item/{user}", []string{"jba"}},
+ {"POST", "", "/item/jba",
+ "POST /item/{user}", []string{"jba"}},
+ {"HEAD", "", "/item/jba",
+ "GET /item/{user}", []string{"jba"}},
+ {"get", "", "/item/jba",
+ "/item/{user}", []string{"jba"}}, // method matches are case-sensitive
+ {"POST", "", "/item/jba/17",
+ "/item/{user}/{id}", []string{"jba", "17"}},
+ {"GET", "", "/item/jba/new",
+ "/item/{user}/new", []string{"jba"}},
+ {"GET", "", "/item/",
+ "/item/{$}", []string{}},
+ {"GET", "", "/item/jba/17/line2",
+ "/item/", nil},
+ {"POST", "alt.com", "/item/jba",
+ "POST alt.com/item/{user}", []string{"jba"}},
+ {"GET", "alt.com", "/item/jba",
+ "GET /item/{user}", []string{"jba"}},
+ {"GET", "", "/item",
+ "", nil}, // does not match
+ {"GET", "", "/headwins",
+ "GET /headwins", nil},
+ {"HEAD", "", "/headwins", // HEAD is more specific than GET
+ "HEAD /headwins", nil},
+ {"GET", "", "/path/to/file",
+ "/path/{p...}", []string{"to/file"}},
+ })
+
+ // A pattern ending in {$} should only match URLS with a trailing slash.
+ pat1 := "/a/b/{$}"
+ test(buildTree(pat1), []testCase{
+ {"GET", "", "/a/b", "", nil},
+ {"GET", "", "/a/b/", pat1, nil},
+ {"GET", "", "/a/b/c", "", nil},
+ {"GET", "", "/a/b/c/d", "", nil},
+ })
+
+ // A pattern ending in a single wildcard should not match a trailing slash URL.
+ pat2 := "/a/b/{w}"
+ test(buildTree(pat2), []testCase{
+ {"GET", "", "/a/b", "", nil},
+ {"GET", "", "/a/b/", "", nil},
+ {"GET", "", "/a/b/c", pat2, []string{"c"}},
+ {"GET", "", "/a/b/c/d", "", nil},
+ })
+
+ // A pattern ending in a multi wildcard should match both URLs.
+ pat3 := "/a/b/{w...}"
+ test(buildTree(pat3), []testCase{
+ {"GET", "", "/a/b", "", nil},
+ {"GET", "", "/a/b/", pat3, []string{""}},
+ {"GET", "", "/a/b/c", pat3, []string{"c"}},
+ {"GET", "", "/a/b/c/d", pat3, []string{"c/d"}},
+ })
+
+ // All three of the above should work together.
+ test(buildTree(pat1, pat2, pat3), []testCase{
+ {"GET", "", "/a/b", "", nil},
+ {"GET", "", "/a/b/", pat1, nil},
+ {"GET", "", "/a/b/c", pat2, []string{"c"}},
+ {"GET", "", "/a/b/c/d", pat3, []string{"c/d"}},
+ })
+}
+
+func TestMatchingMethods(t *testing.T) {
+ hostTree := buildTree("GET a.com/", "PUT b.com/", "POST /foo/{x}")
+ for _, test := range []struct {
+ name string
+ tree *routingNode
+ host, path string
+ want string
+ }{
+ {
+ "post",
+ buildTree("POST /"), "", "/foo",
+ "POST",
+ },
+ {
+ "get",
+ buildTree("GET /"), "", "/foo",
+ "GET,HEAD",
+ },
+ {
+ "host",
+ hostTree, "", "/foo",
+ "",
+ },
+ {
+ "host",
+ hostTree, "", "/foo/bar",
+ "POST",
+ },
+ {
+ "host2",
+ hostTree, "a.com", "/foo/bar",
+ "GET,HEAD,POST",
+ },
+ {
+ "host3",
+ hostTree, "b.com", "/bar",
+ "PUT",
+ },
+ {
+ // This case shouldn't come up because we only call matchingMethods
+ // when there was no match, but we include it for completeness.
+ "empty",
+ buildTree("/"), "", "/",
+ "",
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ ms := map[string]bool{}
+ test.tree.matchingMethods(test.host, test.path, ms)
+ keys := mapKeys(ms)
+ sort.Strings(keys)
+ got := strings.Join(keys, ",")
+ if got != test.want {
+ t.Errorf("got %s, want %s", got, test.want)
+ }
+ })
+ }
+}
+
+func (n *routingNode) print(w io.Writer, level int) {
+ indent := strings.Repeat(" ", level)
+ if n.pattern != nil {
+ fmt.Fprintf(w, "%s%q\n", indent, n.pattern)
+ }
+ if n.emptyChild != nil {
+ fmt.Fprintf(w, "%s%q:\n", indent, "")
+ n.emptyChild.print(w, level+1)
+ }
+
+ var keys []string
+ n.children.eachPair(func(k string, _ *routingNode) bool {
+ keys = append(keys, k)
+ return true
+ })
+ sort.Strings(keys)
+
+ for _, k := range keys {
+ fmt.Fprintf(w, "%s%q:\n", indent, k)
+ n, _ := n.children.find(k)
+ n.print(w, level+1)
+ }
+}
diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go
index bb380cf4a5..0c76f1bcc4 100644
--- a/src/net/http/serve_test.go
+++ b/src/net/http/serve_test.go
@@ -30,7 +30,6 @@ import (
"net/http/internal/testcert"
"net/url"
"os"
- "os/exec"
"path/filepath"
"reflect"
"regexp"
@@ -108,10 +107,14 @@ type testConn struct {
readMu sync.Mutex // for TestHandlerBodyClose
readBuf bytes.Buffer
writeBuf bytes.Buffer
- closec chan bool // if non-nil, send value to it on close
+ closec chan bool // 1-buffered; receives true when Close is called
noopConn
}
+func newTestConn() *testConn {
+ return &testConn{closec: make(chan bool, 1)}
+}
+
func (c *testConn) Read(b []byte) (int, error) {
c.readMu.Lock()
defer c.readMu.Unlock()
@@ -647,30 +650,27 @@ func benchmarkServeMux(b *testing.B, runHandler bool) {
func TestServerTimeouts(t *testing.T) { run(t, testServerTimeouts, []testMode{http1Mode}) }
func testServerTimeouts(t *testing.T, mode testMode) {
- // Try three times, with increasing timeouts.
- tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second}
- for i, timeout := range tries {
- err := testServerTimeoutsWithTimeout(t, timeout, mode)
- if err == nil {
- return
- }
- t.Logf("failed at %v: %v", timeout, err)
- if i != len(tries)-1 {
- t.Logf("retrying at %v ...", tries[i+1])
- }
- }
- t.Fatal("all attempts failed")
+ runTimeSensitiveTest(t, []time.Duration{
+ 10 * time.Millisecond,
+ 50 * time.Millisecond,
+ 100 * time.Millisecond,
+ 500 * time.Millisecond,
+ 1 * time.Second,
+ }, func(t *testing.T, timeout time.Duration) error {
+ return testServerTimeoutsWithTimeout(t, timeout, mode)
+ })
}
func testServerTimeoutsWithTimeout(t *testing.T, timeout time.Duration, mode testMode) error {
- reqNum := 0
- ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
- reqNum++
- fmt.Fprintf(res, "req=%d", reqNum)
+ var reqNum atomic.Int32
+ cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
+ fmt.Fprintf(res, "req=%d", reqNum.Add(1))
}), func(ts *httptest.Server) {
ts.Config.ReadTimeout = timeout
ts.Config.WriteTimeout = timeout
- }).ts
+ })
+ defer cst.close()
+ ts := cst.ts
// Hit the HTTP server successfully.
c := ts.Client()
@@ -866,16 +866,20 @@ func TestWriteDeadlineEnforcedPerStream(t *testing.T) {
}
func testWriteDeadlineEnforcedPerStream(t *testing.T, mode testMode, timeout time.Duration) error {
- reqNum := 0
- ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
- reqNum++
- if reqNum == 1 {
- return // first request succeeds
+ firstRequest := make(chan bool, 1)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
+ select {
+ case firstRequest <- true:
+ // first request succeeds
+ default:
+ // second request times out
+ time.Sleep(timeout)
}
- time.Sleep(timeout) // second request times out
}), func(ts *httptest.Server) {
ts.Config.WriteTimeout = timeout / 2
- }).ts
+ })
+ defer cst.close()
+ ts := cst.ts
c := ts.Client()
@@ -922,14 +926,18 @@ func TestNoWriteDeadline(t *testing.T) {
}
func testNoWriteDeadline(t *testing.T, mode testMode, timeout time.Duration) error {
- reqNum := 0
- ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
- reqNum++
- if reqNum == 1 {
- return // first request succeeds
+ firstRequest := make(chan bool, 1)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
+ select {
+ case firstRequest <- true:
+ // first request succeeds
+ default:
+ // second request times out
+ time.Sleep(timeout)
}
- time.Sleep(timeout) // second request timesout
- })).ts
+ }))
+ defer cst.close()
+ ts := cst.ts
c := ts.Client()
@@ -1392,27 +1400,28 @@ func TestTLSHandshakeTimeout(t *testing.T) {
run(t, testTLSHandshakeTimeout, []testMode{https1Mode, http2Mode})
}
func testTLSHandshakeTimeout(t *testing.T, mode testMode) {
- errc := make(chanWriter, 10) // but only expecting 1
- ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
+ errLog := new(strings.Builder)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
func(ts *httptest.Server) {
ts.Config.ReadTimeout = 250 * time.Millisecond
- ts.Config.ErrorLog = log.New(errc, "", 0)
+ ts.Config.ErrorLog = log.New(errLog, "", 0)
},
- ).ts
+ )
+ ts := cst.ts
+
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
t.Fatalf("Dial: %v", err)
}
- defer conn.Close()
-
var buf [1]byte
n, err := conn.Read(buf[:])
if err == nil || n != 0 {
t.Errorf("Read = %d, %v; want an error and no bytes", n, err)
}
+ conn.Close()
- v := <-errc
- if !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") {
+ cst.close()
+ if v := errLog.String(); !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") {
t.Errorf("expected a TLS handshake timeout error; got %q", v)
}
}
@@ -2968,15 +2977,36 @@ func (b neverEnding) Read(p []byte) (n int, err error) {
return len(p), nil
}
-type countReader struct {
- r io.Reader
- n *int64
+type bodyLimitReader struct {
+ mu sync.Mutex
+ count int
+ limit int
+ closed chan struct{}
}
-func (cr countReader) Read(p []byte) (n int, err error) {
- n, err = cr.r.Read(p)
- atomic.AddInt64(cr.n, int64(n))
- return
+func (r *bodyLimitReader) Read(p []byte) (int, error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ select {
+ case <-r.closed:
+ return 0, errors.New("closed")
+ default:
+ }
+ if r.count > r.limit {
+ return 0, errors.New("at limit")
+ }
+ r.count += len(p)
+ for i := range p {
+ p[i] = 'a'
+ }
+ return len(p), nil
+}
+
+func (r *bodyLimitReader) Close() error {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ close(r.closed)
+ return nil
}
func TestRequestBodyLimit(t *testing.T) { run(t, testRequestBodyLimit) }
@@ -3000,8 +3030,11 @@ func testRequestBodyLimit(t *testing.T, mode testMode) {
}
}))
- nWritten := new(int64)
- req, _ := NewRequest("POST", cst.ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200))
+ body := &bodyLimitReader{
+ closed: make(chan struct{}),
+ limit: limit * 200,
+ }
+ req, _ := NewRequest("POST", cst.ts.URL, body)
// Send the POST, but don't care it succeeds or not. The
// remote side is going to reply and then close the TCP
@@ -3016,10 +3049,13 @@ func testRequestBodyLimit(t *testing.T, mode testMode) {
if err == nil {
resp.Body.Close()
}
+ // Wait for the Transport to finish writing the request body.
+ // It will close the body when done.
+ <-body.closed
- if atomic.LoadInt64(nWritten) > limit*100 {
+ if body.count > limit*100 {
t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d",
- limit, nWritten)
+ limit, body.count)
}
}
@@ -3075,47 +3111,71 @@ func TestServerBufferedChunking(t *testing.T) {
// closing the TCP connection, causing the client to get a RST.
// See https://golang.org/issue/3595
func TestServerGracefulClose(t *testing.T) {
- run(t, testServerGracefulClose, []testMode{http1Mode})
+ // Not parallel: modifies the global rstAvoidanceDelay.
+ run(t, testServerGracefulClose, []testMode{http1Mode}, testNotParallel)
}
func testServerGracefulClose(t *testing.T, mode testMode) {
- ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
- Error(w, "bye", StatusUnauthorized)
- })).ts
+ runTimeSensitiveTest(t, []time.Duration{
+ 1 * time.Millisecond,
+ 5 * time.Millisecond,
+ 10 * time.Millisecond,
+ 50 * time.Millisecond,
+ 100 * time.Millisecond,
+ 500 * time.Millisecond,
+ time.Second,
+ 5 * time.Second,
+ }, func(t *testing.T, timeout time.Duration) error {
+ SetRSTAvoidanceDelay(t, timeout)
+ t.Logf("set RST avoidance delay to %v", timeout)
- conn, err := net.Dial("tcp", ts.Listener.Addr().String())
- if err != nil {
- t.Fatal(err)
- }
- defer conn.Close()
- const bodySize = 5 << 20
- req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize))
- for i := 0; i < bodySize; i++ {
- req = append(req, 'x')
- }
- writeErr := make(chan error)
- go func() {
- _, err := conn.Write(req)
- writeErr <- err
- }()
- br := bufio.NewReader(conn)
- lineNum := 0
- for {
- line, err := br.ReadString('\n')
- if err == io.EOF {
- break
+ const bodySize = 5 << 20
+ req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize))
+ for i := 0; i < bodySize; i++ {
+ req = append(req, 'x')
}
+
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ Error(w, "bye", StatusUnauthorized)
+ }))
+ // We need to close cst explicitly here so that in-flight server
+ // requests don't race with the call to SetRSTAvoidanceDelay for a retry.
+ defer cst.close()
+ ts := cst.ts
+
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
- t.Fatalf("ReadLine: %v", err)
+ return err
}
- lineNum++
- if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") {
- t.Errorf("Response line = %q; want a 401", line)
+ writeErr := make(chan error)
+ go func() {
+ _, err := conn.Write(req)
+ writeErr <- err
+ }()
+ defer func() {
+ conn.Close()
+ // Wait for write to finish. This is a broken pipe on both
+ // Darwin and Linux, but checking this isn't the point of
+ // the test.
+ <-writeErr
+ }()
+
+ br := bufio.NewReader(conn)
+ lineNum := 0
+ for {
+ line, err := br.ReadString('\n')
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return fmt.Errorf("ReadLine: %v", err)
+ }
+ lineNum++
+ if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") {
+ t.Errorf("Response line = %q; want a 401", line)
+ }
}
- }
- // Wait for write to finish. This is a broken pipe on both
- // Darwin and Linux, but checking this isn't the point of
- // the test.
- <-writeErr
+ return nil
+ })
}
func TestCaseSensitiveMethod(t *testing.T) { run(t, testCaseSensitiveMethod) }
@@ -3897,91 +3957,93 @@ func TestContentTypeOkayOn204(t *testing.T) {
// and the http client), and both think they can close it on failure.
// Therefore, all incoming server requests Bodies need to be thread-safe.
func TestTransportAndServerSharedBodyRace(t *testing.T) {
- run(t, testTransportAndServerSharedBodyRace)
+ run(t, testTransportAndServerSharedBodyRace, testNotParallel)
}
func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) {
- const bodySize = 1 << 20
-
- // errorf is like t.Errorf, but also writes to println. When
- // this test fails, it hangs. This helps debugging and I've
- // added this enough times "temporarily". It now gets added
- // full time.
- errorf := func(format string, args ...any) {
- v := fmt.Sprintf(format, args...)
- println(v)
- t.Error(v)
- }
-
- unblockBackend := make(chan bool)
- backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
- gone := rw.(CloseNotifier).CloseNotify()
- didCopy := make(chan any)
- go func() {
+ // The proxy server in the middle of the stack for this test potentially
+ // from its handler after only reading half of the body.
+ // That can trigger https://go.dev/issue/3595, which is otherwise
+ // irrelevant to this test.
+ runTimeSensitiveTest(t, []time.Duration{
+ 1 * time.Millisecond,
+ 5 * time.Millisecond,
+ 10 * time.Millisecond,
+ 50 * time.Millisecond,
+ 100 * time.Millisecond,
+ 500 * time.Millisecond,
+ time.Second,
+ 5 * time.Second,
+ }, func(t *testing.T, timeout time.Duration) error {
+ SetRSTAvoidanceDelay(t, timeout)
+ t.Logf("set RST avoidance delay to %v", timeout)
+
+ const bodySize = 1 << 20
+
+ var wg sync.WaitGroup
+ backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ // Work around https://go.dev/issue/38370: clientServerTest uses
+ // an httptest.Server under the hood, and in HTTP/2 mode it does not always
+ // “[block] until all outstanding requests on this server have completed”,
+ // causing the call to Logf below to race with the end of the test.
+ //
+ // Since the client doesn't cancel the request until we have copied half
+ // the body, this call to add happens before the test is cleaned up,
+ // preventing the race.
+ wg.Add(1)
+ defer wg.Done()
+
n, err := io.CopyN(rw, req.Body, bodySize)
- didCopy <- []any{n, err}
+ t.Logf("backend CopyN: %v, %v", n, err)
+ <-req.Context().Done()
+ }))
+ // We need to close explicitly here so that in-flight server
+ // requests don't race with the call to SetRSTAvoidanceDelay for a retry.
+ defer func() {
+ wg.Wait()
+ backend.close()
}()
- isGone := false
- Loop:
- for {
- select {
- case <-didCopy:
- break Loop
- case <-gone:
- isGone = true
- case <-time.After(time.Second):
- println("1 second passes in backend, proxygone=", isGone)
- }
- }
- <-unblockBackend
- }))
- defer backend.close()
- backendRespc := make(chan *Response, 1)
- var proxy *clientServerTest
- proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
- req2, _ := NewRequest("POST", backend.ts.URL, req.Body)
- req2.ContentLength = bodySize
- cancel := make(chan struct{})
- req2.Cancel = cancel
+ var proxy *clientServerTest
+ proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ req2, _ := NewRequest("POST", backend.ts.URL, req.Body)
+ req2.ContentLength = bodySize
+ cancel := make(chan struct{})
+ req2.Cancel = cancel
- bresp, err := proxy.c.Do(req2)
- if err != nil {
- errorf("Proxy outbound request: %v", err)
- return
- }
- _, err = io.CopyN(io.Discard, bresp.Body, bodySize/2)
- if err != nil {
- errorf("Proxy copy error: %v", err)
- return
- }
- backendRespc <- bresp // to close later
+ bresp, err := proxy.c.Do(req2)
+ if err != nil {
+ t.Errorf("Proxy outbound request: %v", err)
+ return
+ }
+ _, err = io.CopyN(io.Discard, bresp.Body, bodySize/2)
+ if err != nil {
+ t.Errorf("Proxy copy error: %v", err)
+ return
+ }
+ t.Cleanup(func() { bresp.Body.Close() })
+
+ // Try to cause a race. Canceling the client request will cause the client
+ // transport to close req2.Body. Returning from the server handler will
+ // cause the server to close req.Body. Since they are the same underlying
+ // ReadCloser, that will result in concurrent calls to Close (and possibly a
+ // Read concurrent with a Close).
+ if mode == http2Mode {
+ close(cancel)
+ } else {
+ proxy.c.Transport.(*Transport).CancelRequest(req2)
+ }
+ rw.Write([]byte("OK"))
+ }))
+ defer proxy.close()
- // Try to cause a race: Both the Transport and the proxy handler's Server
- // will try to read/close req.Body (aka req2.Body)
- if mode == http2Mode {
- close(cancel)
- } else {
- proxy.c.Transport.(*Transport).CancelRequest(req2)
+ req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize))
+ res, err := proxy.c.Do(req)
+ if err != nil {
+ return fmt.Errorf("original request: %v", err)
}
- rw.Write([]byte("OK"))
- }))
- defer proxy.close()
-
- defer close(unblockBackend)
- req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize))
- res, err := proxy.c.Do(req)
- if err != nil {
- t.Fatalf("Original request: %v", err)
- }
-
- // Cleanup, so we don't leak goroutines.
- res.Body.Close()
- select {
- case res := <-backendRespc:
res.Body.Close()
- default:
- // We failed earlier. (e.g. on proxy.c.Do(req2))
- }
+ return nil
+ })
}
// Test that a hanging Request.Body.Read from another goroutine can't
@@ -4316,7 +4378,8 @@ func (c *closeWriteTestConn) CloseWrite() error {
}
func TestCloseWrite(t *testing.T) {
- setParallel(t)
+ SetRSTAvoidanceDelay(t, 1*time.Millisecond)
+
var srv Server
var testConn closeWriteTestConn
c := ExportServerNewConn(&srv, &testConn)
@@ -4552,10 +4615,10 @@ Host: foo
}
// If a Handler finishes and there's an unread request body,
-// verify the server try to do implicit read on it before replying.
+// verify the server implicitly tries to do a read on it before replying.
func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) {
setParallel(t)
- conn := &testConn{closec: make(chan bool)}
+ conn := newTestConn()
conn.readBuf.Write([]byte(fmt.Sprintf(
"POST / HTTP/1.1\r\n" +
"Host: test\r\n" +
@@ -4645,7 +4708,7 @@ func TestServerValidatesHostHeader(t *testing.T) {
{"GET / HTTP/3.0", "", 505},
}
for _, tt := range tests {
- conn := &testConn{closec: make(chan bool, 1)}
+ conn := newTestConn()
methodTarget := "GET / "
if !strings.HasPrefix(tt.proto, "HTTP/") {
methodTarget = ""
@@ -4743,7 +4806,7 @@ func TestServerValidatesHeaders(t *testing.T) {
{"foo: foo\xfffoo\r\n", 200}, // non-ASCII high octets in value are fine
}
for _, tt := range tests {
- conn := &testConn{closec: make(chan bool, 1)}
+ conn := newTestConn()
io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n")
ln := &oneConnListener{conn}
@@ -4966,7 +5029,7 @@ func benchmarkClientServerParallel(b *testing.B, parallelism int, mode testMode)
// For use like:
//
// $ go test -c
-// $ ./http.test -test.run=XX -test.bench=BenchmarkServer -test.benchtime=15s -test.cpuprofile=http.prof
+// $ ./http.test -test.run='^$' -test.bench='^BenchmarkServer$' -test.benchtime=15s -test.cpuprofile=http.prof
// $ go tool pprof http.test http.prof
// (pprof) web
func BenchmarkServer(b *testing.B) {
@@ -5005,7 +5068,7 @@ func BenchmarkServer(b *testing.B) {
defer ts.Close()
b.StartTimer()
- cmd := exec.Command(os.Args[0], "-test.run=XXXX", "-test.bench=BenchmarkServer$")
+ cmd := testenv.Command(b, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkServer$")
cmd.Env = append([]string{
fmt.Sprintf("TEST_BENCH_CLIENT_N=%d", b.N),
fmt.Sprintf("TEST_BENCH_SERVER_URL=%s", ts.URL),
@@ -5060,7 +5123,7 @@ func BenchmarkClient(b *testing.B) {
// Start server process.
ctx, cancel := context.WithCancel(context.Background())
- cmd := testenv.CommandContext(b, ctx, os.Args[0], "-test.run=XXXX", "-test.bench=BenchmarkClient$")
+ cmd := testenv.CommandContext(b, ctx, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkClient$")
cmd.Env = append(cmd.Environ(), "TEST_BENCH_SERVER=yes")
cmd.Stderr = os.Stderr
stdout, err := cmd.StdoutPipe()
@@ -5129,11 +5192,7 @@ Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
`)
res := []byte("Hello world!\n")
- conn := &testConn{
- // testConn.Close will not push into the channel
- // if it's full.
- closec: make(chan bool, 1),
- }
+ conn := newTestConn()
handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
rw.Write(res)
@@ -5356,49 +5415,75 @@ func testServerIdleTimeout(t *testing.T, mode testMode) {
if testing.Short() {
t.Skip("skipping in short mode")
}
- ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
- io.Copy(io.Discard, r.Body)
- io.WriteString(w, r.RemoteAddr)
- }), func(ts *httptest.Server) {
- ts.Config.ReadHeaderTimeout = 1 * time.Second
- ts.Config.IdleTimeout = 2 * time.Second
- }).ts
- c := ts.Client()
+ runTimeSensitiveTest(t, []time.Duration{
+ 10 * time.Millisecond,
+ 100 * time.Millisecond,
+ 1 * time.Second,
+ 10 * time.Second,
+ }, func(t *testing.T, readHeaderTimeout time.Duration) error {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.Copy(io.Discard, r.Body)
+ io.WriteString(w, r.RemoteAddr)
+ }), func(ts *httptest.Server) {
+ ts.Config.ReadHeaderTimeout = readHeaderTimeout
+ ts.Config.IdleTimeout = 2 * readHeaderTimeout
+ })
+ defer cst.close()
+ ts := cst.ts
+ t.Logf("ReadHeaderTimeout = %v", ts.Config.ReadHeaderTimeout)
+ t.Logf("IdleTimeout = %v", ts.Config.IdleTimeout)
+ c := ts.Client()
- get := func() string {
- res, err := c.Get(ts.URL)
+ get := func() (string, error) {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ return "", err
+ }
+ defer res.Body.Close()
+ slurp, err := io.ReadAll(res.Body)
+ if err != nil {
+ // If we're at this point the headers have definitely already been
+ // read and the server is not idle, so neither timeout applies:
+ // this should never fail.
+ t.Fatal(err)
+ }
+ return string(slurp), nil
+ }
+
+ a1, err := get()
if err != nil {
- t.Fatal(err)
+ return err
}
- defer res.Body.Close()
- slurp, err := io.ReadAll(res.Body)
+ a2, err := get()
if err != nil {
- t.Fatal(err)
+ return err
+ }
+ if a1 != a2 {
+ return fmt.Errorf("did requests on different connections")
+ }
+ time.Sleep(ts.Config.IdleTimeout * 3 / 2)
+ a3, err := get()
+ if err != nil {
+ return err
+ }
+ if a2 == a3 {
+ return fmt.Errorf("request three unexpectedly on same connection")
}
- return string(slurp)
- }
- a1, a2 := get(), get()
- if a1 != a2 {
- t.Fatalf("did requests on different connections")
- }
- time.Sleep(3 * time.Second)
- a3 := get()
- if a2 == a3 {
- t.Fatal("request three unexpectedly on same connection")
- }
+ // And test that ReadHeaderTimeout still works:
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+ conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n"))
+ time.Sleep(ts.Config.ReadHeaderTimeout * 2)
+ if _, err := io.CopyN(io.Discard, conn, 1); err == nil {
+ return fmt.Errorf("copy byte succeeded; want err")
+ }
- // And test that ReadHeaderTimeout still works:
- conn, err := net.Dial("tcp", ts.Listener.Addr().String())
- if err != nil {
- t.Fatal(err)
- }
- defer conn.Close()
- conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n"))
- time.Sleep(2 * time.Second)
- if _, err := io.CopyN(io.Discard, conn, 1); err == nil {
- t.Fatal("copy byte succeeded; want err")
- }
+ return nil
+ })
}
func get(t *testing.T, c *Client, url string) string {
@@ -5658,7 +5743,7 @@ func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) {
time.Second,
2 * time.Second,
}, func(t *testing.T, timeout time.Duration) error {
- ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
select {
case <-time.After(2 * timeout):
fmt.Fprint(w, "ok")
@@ -5667,7 +5752,9 @@ func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) {
}
}), func(ts *httptest.Server) {
ts.Config.ReadTimeout = timeout
- }).ts
+ })
+ defer cst.close()
+ ts := cst.ts
c := ts.Client()
@@ -5701,10 +5788,12 @@ func testServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T, mode testMode) {
time.Second,
2 * time.Second,
}, func(t *testing.T, timeout time.Duration) error {
- ts := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) {
+ cst := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) {
ts.Config.ReadHeaderTimeout = timeout
ts.Config.IdleTimeout = 0 // disable idle timeout
- }).ts
+ })
+ defer cst.close()
+ ts := cst.ts
// rather than using an http.Client, create a single connection, so that
// we can ensure this connection is not closed.
@@ -5747,9 +5836,10 @@ func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t *
if err == nil {
return
}
- if i == len(durations)-1 {
+ if i == len(durations)-1 || t.Failed() {
t.Fatalf("failed with duration %v: %v", d, err)
}
+ t.Logf("retrying after error with duration %v: %v", d, err)
}
}
@@ -5929,7 +6019,7 @@ func TestServerValidatesMethod(t *testing.T) {
{"GE(T", 400},
}
for _, tt := range tests {
- conn := &testConn{closec: make(chan bool, 1)}
+ conn := newTestConn()
io.WriteString(&conn.readBuf, tt.method+" / HTTP/1.1\r\nHost: foo.example\r\n\r\n")
ln := &oneConnListener{conn}
@@ -6594,7 +6684,7 @@ func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string,
}
func TestMaxBytesHandler(t *testing.T) {
- setParallel(t)
+ // Not parallel: modifies the global rstAvoidanceDelay.
defer afterTest(t)
for _, maxSize := range []int64{100, 1_000, 1_000_000} {
@@ -6603,77 +6693,99 @@ func TestMaxBytesHandler(t *testing.T) {
func(t *testing.T) {
run(t, func(t *testing.T, mode testMode) {
testMaxBytesHandler(t, mode, maxSize, requestSize)
- })
+ }, testNotParallel)
})
}
}
}
func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64) {
- var (
- handlerN int64
- handlerErr error
- )
- echo := HandlerFunc(func(w ResponseWriter, r *Request) {
- var buf bytes.Buffer
- handlerN, handlerErr = io.Copy(&buf, r.Body)
- io.Copy(w, &buf)
- })
-
- ts := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize)).ts
- defer ts.Close()
+ runTimeSensitiveTest(t, []time.Duration{
+ 1 * time.Millisecond,
+ 5 * time.Millisecond,
+ 10 * time.Millisecond,
+ 50 * time.Millisecond,
+ 100 * time.Millisecond,
+ 500 * time.Millisecond,
+ time.Second,
+ 5 * time.Second,
+ }, func(t *testing.T, timeout time.Duration) error {
+ SetRSTAvoidanceDelay(t, timeout)
+ t.Logf("set RST avoidance delay to %v", timeout)
+
+ var (
+ handlerN int64
+ handlerErr error
+ )
+ echo := HandlerFunc(func(w ResponseWriter, r *Request) {
+ var buf bytes.Buffer
+ handlerN, handlerErr = io.Copy(&buf, r.Body)
+ io.Copy(w, &buf)
+ })
- c := ts.Client()
+ cst := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize))
+ // We need to close cst explicitly here so that in-flight server
+ // requests don't race with the call to SetRSTAvoidanceDelay for a retry.
+ defer cst.close()
+ ts := cst.ts
+ c := ts.Client()
- body := strings.Repeat("a", int(requestSize))
- var wg sync.WaitGroup
- defer wg.Wait()
- getBody := func() (io.ReadCloser, error) {
- wg.Add(1)
- body := &wgReadCloser{
- Reader: strings.NewReader(body),
- wg: &wg,
+ body := strings.Repeat("a", int(requestSize))
+ var wg sync.WaitGroup
+ defer wg.Wait()
+ getBody := func() (io.ReadCloser, error) {
+ wg.Add(1)
+ body := &wgReadCloser{
+ Reader: strings.NewReader(body),
+ wg: &wg,
+ }
+ return body, nil
}
- return body, nil
- }
- reqBody, _ := getBody()
- req, err := NewRequest("POST", ts.URL, reqBody)
- if err != nil {
- reqBody.Close()
- t.Fatal(err)
- }
- req.ContentLength = int64(len(body))
- req.GetBody = getBody
- req.Header.Set("Content-Type", "text/plain")
+ reqBody, _ := getBody()
+ req, err := NewRequest("POST", ts.URL, reqBody)
+ if err != nil {
+ reqBody.Close()
+ t.Fatal(err)
+ }
+ req.ContentLength = int64(len(body))
+ req.GetBody = getBody
+ req.Header.Set("Content-Type", "text/plain")
- var buf strings.Builder
- res, err := c.Do(req)
- if err != nil {
- t.Errorf("unexpected connection error: %v", err)
- } else {
- _, err = io.Copy(&buf, res.Body)
- res.Body.Close()
+ var buf strings.Builder
+ res, err := c.Do(req)
if err != nil {
- t.Errorf("unexpected read error: %v", err)
+ return fmt.Errorf("unexpected connection error: %v", err)
+ } else {
+ _, err = io.Copy(&buf, res.Body)
+ res.Body.Close()
+ if err != nil {
+ return fmt.Errorf("unexpected read error: %v", err)
+ }
}
- }
- if handlerN > maxSize {
- t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
- }
- if requestSize > maxSize && handlerErr == nil {
- t.Error("expected error on handler side; got nil")
- }
- if requestSize <= maxSize {
- if handlerErr != nil {
- t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
+ // We don't expect any of the errors after this point to occur due
+ // to rstAvoidanceDelay being too short, so we use t.Errorf for those
+ // instead of returning a (retriable) error.
+
+ if handlerN > maxSize {
+ t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
}
- if handlerN != requestSize {
- t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
+ if requestSize > maxSize && handlerErr == nil {
+ t.Error("expected error on handler side; got nil")
}
- }
- if buf.Len() != int(handlerN) {
- t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
- }
+ if requestSize <= maxSize {
+ if handlerErr != nil {
+ t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
+ }
+ if handlerN != requestSize {
+ t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
+ }
+ }
+ if buf.Len() != int(handlerN) {
+ t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
+ }
+
+ return nil
+ })
}
func TestEarlyHints(t *testing.T) {
diff --git a/src/net/http/servemux121.go b/src/net/http/servemux121.go
new file mode 100644
index 0000000000..c0a4b77010
--- /dev/null
+++ b/src/net/http/servemux121.go
@@ -0,0 +1,211 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+// This file implements ServeMux behavior as in Go 1.21.
+// The behavior is controlled by a GODEBUG setting.
+// Most of this code is derived from commit 08e35cc334.
+// Changes are minimal: aside from the different receiver type,
+// they mostly involve renaming functions, usually by unexporting them.
+
+import (
+ "internal/godebug"
+ "net/url"
+ "sort"
+ "strings"
+ "sync"
+)
+
+var httpmuxgo121 = godebug.New("httpmuxgo121")
+
+var use121 bool
+
+// Read httpmuxgo121 once at startup, since dealing with changes to it during
+// program execution is too complex and error-prone.
+func init() {
+ if httpmuxgo121.Value() == "1" {
+ use121 = true
+ httpmuxgo121.IncNonDefault()
+ }
+}
+
+// serveMux121 holds the state of a ServeMux needed for Go 1.21 behavior.
+type serveMux121 struct {
+ mu sync.RWMutex
+ m map[string]muxEntry
+ es []muxEntry // slice of entries sorted from longest to shortest.
+ hosts bool // whether any patterns contain hostnames
+}
+
+type muxEntry struct {
+ h Handler
+ pattern string
+}
+
+// Formerly ServeMux.Handle.
+func (mux *serveMux121) handle(pattern string, handler Handler) {
+ mux.mu.Lock()
+ defer mux.mu.Unlock()
+
+ if pattern == "" {
+ panic("http: invalid pattern")
+ }
+ if handler == nil {
+ panic("http: nil handler")
+ }
+ if _, exist := mux.m[pattern]; exist {
+ panic("http: multiple registrations for " + pattern)
+ }
+
+ if mux.m == nil {
+ mux.m = make(map[string]muxEntry)
+ }
+ e := muxEntry{h: handler, pattern: pattern}
+ mux.m[pattern] = e
+ if pattern[len(pattern)-1] == '/' {
+ mux.es = appendSorted(mux.es, e)
+ }
+
+ if pattern[0] != '/' {
+ mux.hosts = true
+ }
+}
+
+func appendSorted(es []muxEntry, e muxEntry) []muxEntry {
+ n := len(es)
+ i := sort.Search(n, func(i int) bool {
+ return len(es[i].pattern) < len(e.pattern)
+ })
+ if i == n {
+ return append(es, e)
+ }
+ // we now know that i points at where we want to insert
+ es = append(es, muxEntry{}) // try to grow the slice in place, any entry works.
+ copy(es[i+1:], es[i:]) // Move shorter entries down
+ es[i] = e
+ return es
+}
+
+// Formerly ServeMux.HandleFunc.
+func (mux *serveMux121) handleFunc(pattern string, handler func(ResponseWriter, *Request)) {
+ if handler == nil {
+ panic("http: nil handler")
+ }
+ mux.handle(pattern, HandlerFunc(handler))
+}
+
+// Formerly ServeMux.Handler.
+func (mux *serveMux121) findHandler(r *Request) (h Handler, pattern string) {
+
+ // CONNECT requests are not canonicalized.
+ if r.Method == "CONNECT" {
+ // If r.URL.Path is /tree and its handler is not registered,
+ // the /tree -> /tree/ redirect applies to CONNECT requests
+ // but the path canonicalization does not.
+ if u, ok := mux.redirectToPathSlash(r.URL.Host, r.URL.Path, r.URL); ok {
+ return RedirectHandler(u.String(), StatusMovedPermanently), u.Path
+ }
+
+ return mux.handler(r.Host, r.URL.Path)
+ }
+
+ // All other requests have any port stripped and path cleaned
+ // before passing to mux.handler.
+ host := stripHostPort(r.Host)
+ path := cleanPath(r.URL.Path)
+
+ // If the given path is /tree and its handler is not registered,
+ // redirect for /tree/.
+ if u, ok := mux.redirectToPathSlash(host, path, r.URL); ok {
+ return RedirectHandler(u.String(), StatusMovedPermanently), u.Path
+ }
+
+ if path != r.URL.Path {
+ _, pattern = mux.handler(host, path)
+ u := &url.URL{Path: path, RawQuery: r.URL.RawQuery}
+ return RedirectHandler(u.String(), StatusMovedPermanently), pattern
+ }
+
+ return mux.handler(host, r.URL.Path)
+}
+
+// handler is the main implementation of findHandler.
+// The path is known to be in canonical form, except for CONNECT methods.
+func (mux *serveMux121) handler(host, path string) (h Handler, pattern string) {
+ mux.mu.RLock()
+ defer mux.mu.RUnlock()
+
+ // Host-specific pattern takes precedence over generic ones
+ if mux.hosts {
+ h, pattern = mux.match(host + path)
+ }
+ if h == nil {
+ h, pattern = mux.match(path)
+ }
+ if h == nil {
+ h, pattern = NotFoundHandler(), ""
+ }
+ return
+}
+
+// Find a handler on a handler map given a path string.
+// Most-specific (longest) pattern wins.
+func (mux *serveMux121) match(path string) (h Handler, pattern string) {
+ // Check for exact match first.
+ v, ok := mux.m[path]
+ if ok {
+ return v.h, v.pattern
+ }
+
+ // Check for longest valid match. mux.es contains all patterns
+ // that end in / sorted from longest to shortest.
+ for _, e := range mux.es {
+ if strings.HasPrefix(path, e.pattern) {
+ return e.h, e.pattern
+ }
+ }
+ return nil, ""
+}
+
+// redirectToPathSlash determines if the given path needs appending "/" to it.
+// This occurs when a handler for path + "/" was already registered, but
+// not for path itself. If the path needs appending to, it creates a new
+// URL, setting the path to u.Path + "/" and returning true to indicate so.
+func (mux *serveMux121) redirectToPathSlash(host, path string, u *url.URL) (*url.URL, bool) {
+ mux.mu.RLock()
+ shouldRedirect := mux.shouldRedirectRLocked(host, path)
+ mux.mu.RUnlock()
+ if !shouldRedirect {
+ return u, false
+ }
+ path = path + "/"
+ u = &url.URL{Path: path, RawQuery: u.RawQuery}
+ return u, true
+}
+
+// shouldRedirectRLocked reports whether the given path and host should be redirected to
+// path+"/". This should happen if a handler is registered for path+"/" but
+// not path -- see comments at ServeMux.
+func (mux *serveMux121) shouldRedirectRLocked(host, path string) bool {
+ p := []string{path, host + path}
+
+ for _, c := range p {
+ if _, exist := mux.m[c]; exist {
+ return false
+ }
+ }
+
+ n := len(path)
+ if n == 0 {
+ return false
+ }
+ for _, c := range p {
+ if _, exist := mux.m[c+"/"]; exist {
+ return path[n-1] != '/'
+ }
+ }
+
+ return false
+}
diff --git a/src/net/http/server.go b/src/net/http/server.go
index 8f63a90299..acac78bcd0 100644
--- a/src/net/http/server.go
+++ b/src/net/http/server.go
@@ -61,16 +61,16 @@ var (
// A Handler responds to an HTTP request.
//
-// ServeHTTP should write reply headers and data to the ResponseWriter
+// [Handler.ServeHTTP] should write reply headers and data to the [ResponseWriter]
// and then return. Returning signals that the request is finished; it
-// is not valid to use the ResponseWriter or read from the
-// Request.Body after or concurrently with the completion of the
+// is not valid to use the [ResponseWriter] or read from the
+// [Request.Body] after or concurrently with the completion of the
// ServeHTTP call.
//
// Depending on the HTTP client software, HTTP protocol version, and
// any intermediaries between the client and the Go server, it may not
-// be possible to read from the Request.Body after writing to the
-// ResponseWriter. Cautious handlers should read the Request.Body
+// be possible to read from the [Request.Body] after writing to the
+// [ResponseWriter]. Cautious handlers should read the [Request.Body]
// first, and then reply.
//
// Except for reading the body, handlers should not modify the
@@ -82,7 +82,7 @@ var (
// and either closes the network connection or sends an HTTP/2
// RST_STREAM, depending on the HTTP protocol. To abort a handler so
// the client sees an interrupted response but the server doesn't log
-// an error, panic with the value ErrAbortHandler.
+// an error, panic with the value [ErrAbortHandler].
type Handler interface {
ServeHTTP(ResponseWriter, *Request)
}
@@ -90,15 +90,14 @@ type Handler interface {
// A ResponseWriter interface is used by an HTTP handler to
// construct an HTTP response.
//
-// A ResponseWriter may not be used after the Handler.ServeHTTP method
-// has returned.
+// A ResponseWriter may not be used after [Handler.ServeHTTP] has returned.
type ResponseWriter interface {
// Header returns the header map that will be sent by
- // WriteHeader. The Header map also is the mechanism with which
- // Handlers can set HTTP trailers.
+ // [ResponseWriter.WriteHeader]. The [Header] map also is the mechanism with which
+ // [Handler] implementations can set HTTP trailers.
//
- // Changing the header map after a call to WriteHeader (or
- // Write) has no effect unless the HTTP status code was of the
+ // Changing the header map after a call to [ResponseWriter.WriteHeader] (or
+ // [ResponseWriter.Write]) has no effect unless the HTTP status code was of the
// 1xx class or the modified headers are trailers.
//
// There are two ways to set Trailers. The preferred way is to
@@ -107,9 +106,9 @@ type ResponseWriter interface {
// trailer keys which will come later. In this case, those
// keys of the Header map are treated as if they were
// trailers. See the example. The second way, for trailer
- // keys not known to the Handler until after the first Write,
- // is to prefix the Header map keys with the TrailerPrefix
- // constant value. See TrailerPrefix.
+ // keys not known to the [Handler] until after the first [ResponseWriter.Write],
+ // is to prefix the [Header] map keys with the [TrailerPrefix]
+ // constant value.
//
// To suppress automatic response headers (such as "Date"), set
// their value to nil.
@@ -117,11 +116,11 @@ type ResponseWriter interface {
// Write writes the data to the connection as part of an HTTP reply.
//
- // If WriteHeader has not yet been called, Write calls
+ // If [ResponseWriter.WriteHeader] has not yet been called, Write calls
// WriteHeader(http.StatusOK) before writing the data. If the Header
// does not contain a Content-Type line, Write adds a Content-Type set
// to the result of passing the initial 512 bytes of written data to
- // DetectContentType. Additionally, if the total size of all written
+ // [DetectContentType]. Additionally, if the total size of all written
// data is under a few KB and there are no Flush calls, the
// Content-Length header is added automatically.
//
@@ -162,8 +161,8 @@ type ResponseWriter interface {
// The Flusher interface is implemented by ResponseWriters that allow
// an HTTP handler to flush buffered data to the client.
//
-// The default HTTP/1.x and HTTP/2 ResponseWriter implementations
-// support Flusher, but ResponseWriter wrappers may not. Handlers
+// The default HTTP/1.x and HTTP/2 [ResponseWriter] implementations
+// support [Flusher], but ResponseWriter wrappers may not. Handlers
// should always test for this ability at runtime.
//
// Note that even for ResponseWriters that support Flush,
@@ -178,7 +177,7 @@ type Flusher interface {
// The Hijacker interface is implemented by ResponseWriters that allow
// an HTTP handler to take over the connection.
//
-// The default ResponseWriter for HTTP/1.x connections supports
+// The default [ResponseWriter] for HTTP/1.x connections supports
// Hijacker, but HTTP/2 connections intentionally do not.
// ResponseWriter wrappers may also not support Hijacker. Handlers
// should always test for this ability at runtime.
@@ -212,7 +211,7 @@ type Hijacker interface {
// if the client has disconnected before the response is ready.
//
// Deprecated: the CloseNotifier interface predates Go's context package.
-// New code should use Request.Context instead.
+// New code should use [Request.Context] instead.
type CloseNotifier interface {
// CloseNotify returns a channel that receives at most a
// single value (true) when the client connection has gone
@@ -506,7 +505,7 @@ func (c *response) EnableFullDuplex() error {
return nil
}
-// TrailerPrefix is a magic prefix for ResponseWriter.Header map keys
+// TrailerPrefix is a magic prefix for [ResponseWriter.Header] map keys
// that, if present, signals that the map entry is actually for
// the response trailers, and not the response headers. The prefix
// is stripped after the ServeHTTP call finishes and the values are
@@ -572,13 +571,12 @@ type writerOnly struct {
io.Writer
}
-// ReadFrom is here to optimize copying from an *os.File regular file
-// to a *net.TCPConn with sendfile, or from a supported src type such
+// ReadFrom is here to optimize copying from an [*os.File] regular file
+// to a [*net.TCPConn] with sendfile, or from a supported src type such
// as a *net.TCPConn on Linux with splice.
func (w *response) ReadFrom(src io.Reader) (n int64, err error) {
- bufp := copyBufPool.Get().(*[]byte)
- buf := *bufp
- defer copyBufPool.Put(bufp)
+ buf := getCopyBuf()
+ defer putCopyBuf(buf)
// Our underlying w.conn.rwc is usually a *TCPConn (with its
// own ReadFrom method). If not, just fall back to the normal
@@ -808,11 +806,18 @@ var (
bufioWriter4kPool sync.Pool
)
-var copyBufPool = sync.Pool{
- New: func() any {
- b := make([]byte, 32*1024)
- return &b
- },
+const copyBufPoolSize = 32 * 1024
+
+var copyBufPool = sync.Pool{New: func() any { return new([copyBufPoolSize]byte) }}
+
+func getCopyBuf() []byte {
+ return copyBufPool.Get().(*[copyBufPoolSize]byte)[:]
+}
+func putCopyBuf(b []byte) {
+ if len(b) != copyBufPoolSize {
+ panic("trying to put back buffer of the wrong size in the copyBufPool")
+ }
+ copyBufPool.Put((*[copyBufPoolSize]byte)(b))
}
func bufioWriterPool(size int) *sync.Pool {
@@ -862,7 +867,7 @@ func putBufioWriter(bw *bufio.Writer) {
// DefaultMaxHeaderBytes is the maximum permitted size of the headers
// in an HTTP request.
-// This can be overridden by setting Server.MaxHeaderBytes.
+// This can be overridden by setting [Server.MaxHeaderBytes].
const DefaultMaxHeaderBytes = 1 << 20 // 1 MB
func (srv *Server) maxHeaderBytes() int {
@@ -935,11 +940,11 @@ func (ecr *expectContinueReader) Close() error {
}
// TimeFormat is the time format to use when generating times in HTTP
-// headers. It is like time.RFC1123 but hard-codes GMT as the time
+// headers. It is like [time.RFC1123] but hard-codes GMT as the time
// zone. The time being formatted must be in UTC for Format to
// generate the correct format.
//
-// For parsing this time format, see ParseTime.
+// For parsing this time format, see [ParseTime].
const TimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT"
// appendTime is a non-allocating version of []byte(t.UTC().Format(TimeFormat))
@@ -1585,13 +1590,13 @@ func (w *response) bodyAllowed() bool {
// The Writers are wired together like:
//
// 1. *response (the ResponseWriter) ->
-// 2. (*response).w, a *bufio.Writer of bufferBeforeChunkingSize bytes ->
+// 2. (*response).w, a [*bufio.Writer] of bufferBeforeChunkingSize bytes ->
// 3. chunkWriter.Writer (whose writeHeader finalizes Content-Length/Type)
// and which writes the chunk headers, if needed ->
// 4. conn.bufw, a *bufio.Writer of default (4kB) bytes, writing to ->
// 5. checkConnErrorWriter{c}, which notes any non-nil error on Write
// and populates c.werr with it if so, but otherwise writes to ->
-// 6. the rwc, the net.Conn.
+// 6. the rwc, the [net.Conn].
//
// TODO(bradfitz): short-circuit some of the buffering when the
// initial header contains both a Content-Type and Content-Length.
@@ -1752,8 +1757,12 @@ func (c *conn) close() {
// and processes its final data before they process the subsequent RST
// from closing a connection with known unread data.
// This RST seems to occur mostly on BSD systems. (And Windows?)
-// This timeout is somewhat arbitrary (~latency around the planet).
-const rstAvoidanceDelay = 500 * time.Millisecond
+// This timeout is somewhat arbitrary (~latency around the planet),
+// and may be modified by tests.
+//
+// TODO(bcmills): This should arguably be a server configuration parameter,
+// not a hard-coded value.
+var rstAvoidanceDelay = 500 * time.Millisecond
type closeWriter interface {
CloseWrite() error
@@ -1772,6 +1781,27 @@ func (c *conn) closeWriteAndWait() {
if tcp, ok := c.rwc.(closeWriter); ok {
tcp.CloseWrite()
}
+
+ // When we return from closeWriteAndWait, the caller will fully close the
+ // connection. If client is still writing to the connection, this will cause
+ // the write to fail with ECONNRESET or similar. Unfortunately, many TCP
+ // implementations will also drop unread packets from the client's read buffer
+ // when a write fails, causing our final response to be truncated away too.
+ //
+ // As a result, https://www.rfc-editor.org/rfc/rfc7230#section-6.6 recommends
+ // that “[t]he server … continues to read from the connection until it
+ // receives a corresponding close by the client, or until the server is
+ // reasonably certain that its own TCP stack has received the client's
+ // acknowledgement of the packet(s) containing the server's last response.”
+ //
+ // Unfortunately, we have no straightforward way to be “reasonably certain”
+ // that we have received the client's ACK, and at any rate we don't want to
+ // allow a misbehaving client to soak up server connections indefinitely by
+ // withholding an ACK, nor do we want to go through the complexity or overhead
+ // of using low-level APIs to figure out when a TCP round-trip has completed.
+ //
+ // Instead, we declare that we are “reasonably certain” that we received the
+ // ACK if maxRSTAvoidanceDelay has elapsed.
time.Sleep(rstAvoidanceDelay)
}
@@ -1971,7 +2001,7 @@ func (c *conn) serve(ctx context.Context) {
fmt.Fprintf(c.rwc, "HTTP/1.1 %d %s: %s%s%d %s: %s", v.code, StatusText(v.code), v.text, errorHeaders, v.code, StatusText(v.code), v.text)
return
}
- publicErr := "400 Bad Request"
+ const publicErr = "400 Bad Request"
fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr)
return
}
@@ -2067,8 +2097,8 @@ func (w *response) sendExpectationFailed() {
w.finishRequest()
}
-// Hijack implements the Hijacker.Hijack method. Our response is both a ResponseWriter
-// and a Hijacker.
+// Hijack implements the [Hijacker.Hijack] method. Our response is both a [ResponseWriter]
+// and a [Hijacker].
func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
if w.handlerDone.Load() {
panic("net/http: Hijack called after ServeHTTP finished")
@@ -2128,7 +2158,7 @@ func requestBodyRemains(rc io.ReadCloser) bool {
// The HandlerFunc type is an adapter to allow the use of
// ordinary functions as HTTP handlers. If f is a function
// with the appropriate signature, HandlerFunc(f) is a
-// Handler that calls f.
+// [Handler] that calls f.
type HandlerFunc func(ResponseWriter, *Request)
// ServeHTTP calls f(w, r).
@@ -2187,9 +2217,9 @@ func StripPrefix(prefix string, h Handler) Handler {
// which may be a path relative to the request path.
//
// The provided code should be in the 3xx range and is usually
-// StatusMovedPermanently, StatusFound or StatusSeeOther.
+// [StatusMovedPermanently], [StatusFound] or [StatusSeeOther].
//
-// If the Content-Type header has not been set, Redirect sets it
+// If the Content-Type header has not been set, [Redirect] sets it
// to "text/html; charset=utf-8" and writes a small HTML body.
// Setting the Content-Type header to any value, including nil,
// disables that behavior.
@@ -2277,7 +2307,7 @@ func (rh *redirectHandler) ServeHTTP(w ResponseWriter, r *Request) {
// status code.
//
// The provided code should be in the 3xx range and is usually
-// StatusMovedPermanently, StatusFound or StatusSeeOther.
+// [StatusMovedPermanently], [StatusFound] or [StatusSeeOther].
func RedirectHandler(url string, code int) Handler {
return &redirectHandler{url, code}
}
@@ -2287,52 +2317,132 @@ func RedirectHandler(url string, code int) Handler {
// patterns and calls the handler for the pattern that
// most closely matches the URL.
//
-// Patterns name fixed, rooted paths, like "/favicon.ico",
-// or rooted subtrees, like "/images/" (note the trailing slash).
-// Longer patterns take precedence over shorter ones, so that
-// if there are handlers registered for both "/images/"
-// and "/images/thumbnails/", the latter handler will be
-// called for paths beginning with "/images/thumbnails/" and the
-// former will receive requests for any other paths in the
-// "/images/" subtree.
-//
-// Note that since a pattern ending in a slash names a rooted subtree,
-// the pattern "/" matches all paths not matched by other registered
-// patterns, not just the URL with Path == "/".
-//
-// If a subtree has been registered and a request is received naming the
-// subtree root without its trailing slash, ServeMux redirects that
-// request to the subtree root (adding the trailing slash). This behavior can
-// be overridden with a separate registration for the path without
-// the trailing slash. For example, registering "/images/" causes ServeMux
+// # Patterns
+//
+// Patterns can match the method, host and path of a request.
+// Some examples:
+//
+// - "/index.html" matches the path "/index.html" for any host and method.
+// - "GET /static/" matches a GET request whose path begins with "/static/".
+// - "example.com/" matches any request to the host "example.com".
+// - "example.com/{$}" matches requests with host "example.com" and path "/".
+// - "/b/{bucket}/o/{objectname...}" matches paths whose first segment is "b"
+// and whose third segment is "o". The name "bucket" denotes the second
+// segment and "objectname" denotes the remainder of the path.
+//
+// In general, a pattern looks like
+//
+// [METHOD ][HOST]/[PATH]
+//
+// All three parts are optional; "/" is a valid pattern.
+// If METHOD is present, it must be followed by a single space.
+//
+// Literal (that is, non-wildcard) parts of a pattern match
+// the corresponding parts of a request case-sensitively.
+//
+// A pattern with no method matches every method. A pattern
+// with the method GET matches both GET and HEAD requests.
+// Otherwise, the method must match exactly.
+//
+// A pattern with no host matches every host.
+// A pattern with a host matches URLs on that host only.
+//
+// A path can include wildcard segments of the form {NAME} or {NAME...}.
+// For example, "/b/{bucket}/o/{objectname...}".
+// The wildcard name must be a valid Go identifier.
+// Wildcards must be full path segments: they must be preceded by a slash and followed by
+// either a slash or the end of the string.
+// For example, "/b_{bucket}" is not a valid pattern.
+//
+// Normally a wildcard matches only a single path segment,
+// ending at the next literal slash (not %2F) in the request URL.
+// But if the "..." is present, then the wildcard matches the remainder of the URL path, including slashes.
+// (Therefore it is invalid for a "..." wildcard to appear anywhere but at the end of a pattern.)
+// The match for a wildcard can be obtained by calling [Request.PathValue] with the wildcard's name.
+// A trailing slash in a path acts as an anonymous "..." wildcard.
+//
+// The special wildcard {$} matches only the end of the URL.
+// For example, the pattern "/{$}" matches only the path "/",
+// whereas the pattern "/" matches every path.
+//
+// For matching, both pattern paths and incoming request paths are unescaped segment by segment.
+// So, for example, the path "/a%2Fb/100%25" is treated as having two segments, "a/b" and "100%".
+// The pattern "/a%2fb/" matches it, but the pattern "/a/b/" does not.
+//
+// # Precedence
+//
+// If two or more patterns match a request, then the most specific pattern takes precedence.
+// A pattern P1 is more specific than P2 if P1 matches a strict subset of P2’s requests;
+// that is, if P2 matches all the requests of P1 and more.
+// If neither is more specific, then the patterns conflict.
+// There is one exception to this rule, for backwards compatibility:
+// if two patterns would otherwise conflict and one has a host while the other does not,
+// then the pattern with the host takes precedence.
+// If a pattern passed [ServeMux.Handle] or [ServeMux.HandleFunc] conflicts with
+// another pattern that is already registered, those functions panic.
+//
+// As an example of the general rule, "/images/thumbnails/" is more specific than "/images/",
+// so both can be registered.
+// The former matches paths beginning with "/images/thumbnails/"
+// and the latter will match any other path in the "/images/" subtree.
+//
+// As another example, consider the patterns "GET /" and "/index.html":
+// both match a GET request for "/index.html", but the former pattern
+// matches all other GET and HEAD requests, while the latter matches any
+// request for "/index.html" that uses a different method.
+// The patterns conflict.
+//
+// # Trailing-slash redirection
+//
+// Consider a [ServeMux] with a handler for a subtree, registered using a trailing slash or "..." wildcard.
+// If the ServeMux receives a request for the subtree root without a trailing slash,
+// it redirects the request by adding the trailing slash.
+// This behavior can be overridden with a separate registration for the path without
+// the trailing slash or "..." wildcard. For example, registering "/images/" causes ServeMux
// to redirect a request for "/images" to "/images/", unless "/images" has
// been registered separately.
//
-// Patterns may optionally begin with a host name, restricting matches to
-// URLs on that host only. Host-specific patterns take precedence over
-// general patterns, so that a handler might register for the two patterns
-// "/codesearch" and "codesearch.google.com/" without also taking over
-// requests for "http://www.google.com/".
+// # Request sanitizing
//
// ServeMux also takes care of sanitizing the URL request path and the Host
// header, stripping the port number and redirecting any request containing . or
-// .. elements or repeated slashes to an equivalent, cleaner URL.
+// .. segments or repeated slashes to an equivalent, cleaner URL.
+//
+// # Compatibility
+//
+// The pattern syntax and matching behavior of ServeMux changed significantly
+// in Go 1.22. To restore the old behavior, set the GODEBUG environment variable
+// to "httpmuxgo121=1". This setting is read once, at program startup; changes
+// during execution will be ignored.
+//
+// The backwards-incompatible changes include:
+// - Wildcards are just ordinary literal path segments in 1.21.
+// For example, the pattern "/{x}" will match only that path in 1.21,
+// but will match any one-segment path in 1.22.
+// - In 1.21, no pattern was rejected, unless it was empty or conflicted with an existing pattern.
+// In 1.22, syntactically invalid patterns will cause [ServeMux.Handle] and [ServeMux.HandleFunc] to panic.
+// For example, in 1.21, the patterns "/{" and "/a{x}" match themselves,
+// but in 1.22 they are invalid and will cause a panic when registered.
+// - In 1.22, each segment of a pattern is unescaped; this was not done in 1.21.
+// For example, in 1.22 the pattern "/%61" matches the path "/a" ("%61" being the URL escape sequence for "a"),
+// but in 1.21 it would match only the path "/%2561" (where "%25" is the escape for the percent sign).
+// - When matching patterns to paths, in 1.22 each segment of the path is unescaped; in 1.21, the entire path is unescaped.
+// This change mostly affects how paths with %2F escapes adjacent to slashes are treated.
+// See https://go.dev/issue/21955 for details.
type ServeMux struct {
- mu sync.RWMutex
- m map[string]muxEntry
- es []muxEntry // slice of entries sorted from longest to shortest.
- hosts bool // whether any patterns contain hostnames
+ mu sync.RWMutex
+ tree routingNode
+ index routingIndex
+ patterns []*pattern // TODO(jba): remove if possible
+ mux121 serveMux121 // used only when GODEBUG=httpmuxgo121=1
}
-type muxEntry struct {
- h Handler
- pattern string
+// NewServeMux allocates and returns a new [ServeMux].
+func NewServeMux() *ServeMux {
+ return &ServeMux{}
}
-// NewServeMux allocates and returns a new ServeMux.
-func NewServeMux() *ServeMux { return new(ServeMux) }
-
-// DefaultServeMux is the default ServeMux used by Serve.
+// DefaultServeMux is the default [ServeMux] used by [Serve].
var DefaultServeMux = &defaultServeMux
var defaultServeMux ServeMux
@@ -2372,66 +2482,6 @@ func stripHostPort(h string) string {
return host
}
-// Find a handler on a handler map given a path string.
-// Most-specific (longest) pattern wins.
-func (mux *ServeMux) match(path string) (h Handler, pattern string) {
- // Check for exact match first.
- v, ok := mux.m[path]
- if ok {
- return v.h, v.pattern
- }
-
- // Check for longest valid match. mux.es contains all patterns
- // that end in / sorted from longest to shortest.
- for _, e := range mux.es {
- if strings.HasPrefix(path, e.pattern) {
- return e.h, e.pattern
- }
- }
- return nil, ""
-}
-
-// redirectToPathSlash determines if the given path needs appending "/" to it.
-// This occurs when a handler for path + "/" was already registered, but
-// not for path itself. If the path needs appending to, it creates a new
-// URL, setting the path to u.Path + "/" and returning true to indicate so.
-func (mux *ServeMux) redirectToPathSlash(host, path string, u *url.URL) (*url.URL, bool) {
- mux.mu.RLock()
- shouldRedirect := mux.shouldRedirectRLocked(host, path)
- mux.mu.RUnlock()
- if !shouldRedirect {
- return u, false
- }
- path = path + "/"
- u = &url.URL{Path: path, RawQuery: u.RawQuery}
- return u, true
-}
-
-// shouldRedirectRLocked reports whether the given path and host should be redirected to
-// path+"/". This should happen if a handler is registered for path+"/" but
-// not path -- see comments at ServeMux.
-func (mux *ServeMux) shouldRedirectRLocked(host, path string) bool {
- p := []string{path, host + path}
-
- for _, c := range p {
- if _, exist := mux.m[c]; exist {
- return false
- }
- }
-
- n := len(path)
- if n == 0 {
- return false
- }
- for _, c := range p {
- if _, exist := mux.m[c+"/"]; exist {
- return path[n-1] != '/'
- }
- }
-
- return false
-}
-
// Handler returns the handler to use for the given request,
// consulting r.Method, r.Host, and r.URL.Path. It always returns
// a non-nil handler. If the path is not in its canonical form, the
@@ -2443,61 +2493,175 @@ func (mux *ServeMux) shouldRedirectRLocked(host, path string) bool {
//
// Handler also returns the registered pattern that matches the
// request or, in the case of internally-generated redirects,
-// the pattern that will match after following the redirect.
+// the path that will match after following the redirect.
//
// If there is no registered handler that applies to the request,
// Handler returns a “page not found” handler and an empty pattern.
func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) {
-
+ if use121 {
+ return mux.mux121.findHandler(r)
+ }
+ h, p, _, _ := mux.findHandler(r)
+ return h, p
+}
+
+// findHandler finds a handler for a request.
+// If there is a matching handler, it returns it and the pattern that matched.
+// Otherwise it returns a Redirect or NotFound handler with the path that would match
+// after the redirect.
+func (mux *ServeMux) findHandler(r *Request) (h Handler, patStr string, _ *pattern, matches []string) {
+ var n *routingNode
+ host := r.URL.Host
+ escapedPath := r.URL.EscapedPath()
+ path := escapedPath
// CONNECT requests are not canonicalized.
if r.Method == "CONNECT" {
// If r.URL.Path is /tree and its handler is not registered,
// the /tree -> /tree/ redirect applies to CONNECT requests
// but the path canonicalization does not.
- if u, ok := mux.redirectToPathSlash(r.URL.Host, r.URL.Path, r.URL); ok {
- return RedirectHandler(u.String(), StatusMovedPermanently), u.Path
+ _, _, u := mux.matchOrRedirect(host, r.Method, path, r.URL)
+ if u != nil {
+ return RedirectHandler(u.String(), StatusMovedPermanently), u.Path, nil, nil
+ }
+ // Redo the match, this time with r.Host instead of r.URL.Host.
+ // Pass a nil URL to skip the trailing-slash redirect logic.
+ n, matches, _ = mux.matchOrRedirect(r.Host, r.Method, path, nil)
+ } else {
+ // All other requests have any port stripped and path cleaned
+ // before passing to mux.handler.
+ host = stripHostPort(r.Host)
+ path = cleanPath(path)
+
+ // If the given path is /tree and its handler is not registered,
+ // redirect for /tree/.
+ var u *url.URL
+ n, matches, u = mux.matchOrRedirect(host, r.Method, path, r.URL)
+ if u != nil {
+ return RedirectHandler(u.String(), StatusMovedPermanently), u.Path, nil, nil
+ }
+ if path != escapedPath {
+ // Redirect to cleaned path.
+ patStr := ""
+ if n != nil {
+ patStr = n.pattern.String()
+ }
+ u := &url.URL{Path: path, RawQuery: r.URL.RawQuery}
+ return RedirectHandler(u.String(), StatusMovedPermanently), patStr, nil, nil
}
-
- return mux.handler(r.Host, r.URL.Path)
}
+ if n == nil {
+ // We didn't find a match with the request method. To distinguish between
+ // Not Found and Method Not Allowed, see if there is another pattern that
+ // matches except for the method.
+ allowedMethods := mux.matchingMethods(host, path)
+ if len(allowedMethods) > 0 {
+ return HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Allow", strings.Join(allowedMethods, ", "))
+ Error(w, StatusText(StatusMethodNotAllowed), StatusMethodNotAllowed)
+ }), "", nil, nil
+ }
+ return NotFoundHandler(), "", nil, nil
+ }
+ return n.handler, n.pattern.String(), n.pattern, matches
+}
- // All other requests have any port stripped and path cleaned
- // before passing to mux.handler.
- host := stripHostPort(r.Host)
- path := cleanPath(r.URL.Path)
+// matchOrRedirect looks up a node in the tree that matches the host, method and path.
+//
+// If the url argument is non-nil, handler also deals with trailing-slash
+// redirection: when a path doesn't match exactly, the match is tried again
+// after appending "/" to the path. If that second match succeeds, the last
+// return value is the URL to redirect to.
+func (mux *ServeMux) matchOrRedirect(host, method, path string, u *url.URL) (_ *routingNode, matches []string, redirectTo *url.URL) {
+ mux.mu.RLock()
+ defer mux.mu.RUnlock()
- // If the given path is /tree and its handler is not registered,
- // redirect for /tree/.
- if u, ok := mux.redirectToPathSlash(host, path, r.URL); ok {
- return RedirectHandler(u.String(), StatusMovedPermanently), u.Path
+ n, matches := mux.tree.match(host, method, path)
+ // If we have an exact match, or we were asked not to try trailing-slash redirection,
+ // then we're done.
+ if !exactMatch(n, path) && u != nil {
+ // If there is an exact match with a trailing slash, then redirect.
+ path += "/"
+ n2, _ := mux.tree.match(host, method, path)
+ if exactMatch(n2, path) {
+ return nil, nil, &url.URL{Path: cleanPath(u.Path) + "/", RawQuery: u.RawQuery}
+ }
}
+ return n, matches, nil
+}
- if path != r.URL.Path {
- _, pattern = mux.handler(host, path)
- u := &url.URL{Path: path, RawQuery: r.URL.RawQuery}
- return RedirectHandler(u.String(), StatusMovedPermanently), pattern
+// exactMatch reports whether the node's pattern exactly matches the path.
+// As a special case, if the node is nil, exactMatch return false.
+//
+// Before wildcards were introduced, it was clear that an exact match meant
+// that the pattern and path were the same string. The only other possibility
+// was that a trailing-slash pattern, like "/", matched a path longer than
+// it, like "/a".
+//
+// With wildcards, we define an inexact match as any one where a multi wildcard
+// matches a non-empty string. All other matches are exact.
+// For example, these are all exact matches:
+//
+// pattern path
+// /a /a
+// /{x} /a
+// /a/{$} /a/
+// /a/ /a/
+//
+// The last case has a multi wildcard (implicitly), but the match is exact because
+// the wildcard matches the empty string.
+//
+// Examples of matches that are not exact:
+//
+// pattern path
+// / /a
+// /a/{x...} /a/b
+func exactMatch(n *routingNode, path string) bool {
+ if n == nil {
+ return false
}
+ // We can't directly implement the definition (empty match for multi
+ // wildcard) because we don't record a match for anonymous multis.
- return mux.handler(host, r.URL.Path)
+ // If there is no multi, the match is exact.
+ if !n.pattern.lastSegment().multi {
+ return true
+ }
+
+ // If the path doesn't end in a trailing slash, then the multi match
+ // is non-empty.
+ if len(path) > 0 && path[len(path)-1] != '/' {
+ return false
+ }
+ // Only patterns ending in {$} or a multi wildcard can
+ // match a path with a trailing slash.
+ // For the match to be exact, the number of pattern
+ // segments should be the same as the number of slashes in the path.
+ // E.g. "/a/b/{$}" and "/a/b/{...}" exactly match "/a/b/", but "/a/" does not.
+ return len(n.pattern.segments) == strings.Count(path, "/")
}
-// handler is the main implementation of Handler.
-// The path is known to be in canonical form, except for CONNECT methods.
-func (mux *ServeMux) handler(host, path string) (h Handler, pattern string) {
+// matchingMethods return a sorted list of all methods that would match with the given host and path.
+func (mux *ServeMux) matchingMethods(host, path string) []string {
+ // Hold the read lock for the entire method so that the two matches are done
+ // on the same set of registered patterns.
mux.mu.RLock()
defer mux.mu.RUnlock()
+ ms := map[string]bool{}
+ mux.tree.matchingMethods(host, path, ms)
+ // matchOrRedirect will try appending a trailing slash if there is no match.
+ mux.tree.matchingMethods(host, path+"/", ms)
+ methods := mapKeys(ms)
+ sort.Strings(methods)
+ return methods
+}
- // Host-specific pattern takes precedence over generic ones
- if mux.hosts {
- h, pattern = mux.match(host + path)
- }
- if h == nil {
- h, pattern = mux.match(path)
- }
- if h == nil {
- h, pattern = NotFoundHandler(), ""
+// TODO(jba): replace with maps.Keys when it is defined.
+func mapKeys[K comparable, V any](m map[K]V) []K {
+ var ks []K
+ for k := range m {
+ ks = append(ks, k)
}
- return
+ return ks
}
// ServeHTTP dispatches the request to the handler whose
@@ -2510,82 +2674,117 @@ func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) {
w.WriteHeader(StatusBadRequest)
return
}
- h, _ := mux.Handler(r)
+ var h Handler
+ if use121 {
+ h, _ = mux.mux121.findHandler(r)
+ } else {
+ h, _, r.pat, r.matches = mux.findHandler(r)
+ }
h.ServeHTTP(w, r)
}
+// The four functions below all call ServeMux.register so that callerLocation
+// always refers to user code.
+
// Handle registers the handler for the given pattern.
-// If a handler already exists for pattern, Handle panics.
+// If the given pattern conflicts, with one that is already registered, Handle
+// panics.
func (mux *ServeMux) Handle(pattern string, handler Handler) {
- mux.mu.Lock()
- defer mux.mu.Unlock()
-
- if pattern == "" {
- panic("http: invalid pattern")
- }
- if handler == nil {
- panic("http: nil handler")
- }
- if _, exist := mux.m[pattern]; exist {
- panic("http: multiple registrations for " + pattern)
+ if use121 {
+ mux.mux121.handle(pattern, handler)
+ } else {
+ mux.register(pattern, handler)
}
+}
- if mux.m == nil {
- mux.m = make(map[string]muxEntry)
+// HandleFunc registers the handler function for the given pattern.
+// If the given pattern conflicts, with one that is already registered, HandleFunc
+// panics.
+func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Request)) {
+ if use121 {
+ mux.mux121.handleFunc(pattern, handler)
+ } else {
+ mux.register(pattern, HandlerFunc(handler))
}
- e := muxEntry{h: handler, pattern: pattern}
- mux.m[pattern] = e
- if pattern[len(pattern)-1] == '/' {
- mux.es = appendSorted(mux.es, e)
+}
+
+// Handle registers the handler for the given pattern in [DefaultServeMux].
+// The documentation for [ServeMux] explains how patterns are matched.
+func Handle(pattern string, handler Handler) {
+ if use121 {
+ DefaultServeMux.mux121.handle(pattern, handler)
+ } else {
+ DefaultServeMux.register(pattern, handler)
}
+}
- if pattern[0] != '/' {
- mux.hosts = true
+// HandleFunc registers the handler function for the given pattern in [DefaultServeMux].
+// The documentation for [ServeMux] explains how patterns are matched.
+func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) {
+ if use121 {
+ DefaultServeMux.mux121.handleFunc(pattern, handler)
+ } else {
+ DefaultServeMux.register(pattern, HandlerFunc(handler))
}
}
-func appendSorted(es []muxEntry, e muxEntry) []muxEntry {
- n := len(es)
- i := sort.Search(n, func(i int) bool {
- return len(es[i].pattern) < len(e.pattern)
- })
- if i == n {
- return append(es, e)
+func (mux *ServeMux) register(pattern string, handler Handler) {
+ if err := mux.registerErr(pattern, handler); err != nil {
+ panic(err)
}
- // we now know that i points at where we want to insert
- es = append(es, muxEntry{}) // try to grow the slice in place, any entry works.
- copy(es[i+1:], es[i:]) // Move shorter entries down
- es[i] = e
- return es
}
-// HandleFunc registers the handler function for the given pattern.
-func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Request)) {
+func (mux *ServeMux) registerErr(patstr string, handler Handler) error {
+ if patstr == "" {
+ return errors.New("http: invalid pattern")
+ }
if handler == nil {
- panic("http: nil handler")
+ return errors.New("http: nil handler")
+ }
+ if f, ok := handler.(HandlerFunc); ok && f == nil {
+ return errors.New("http: nil handler")
}
- mux.Handle(pattern, HandlerFunc(handler))
-}
-// Handle registers the handler for the given pattern
-// in the DefaultServeMux.
-// The documentation for ServeMux explains how patterns are matched.
-func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) }
+ pat, err := parsePattern(patstr)
+ if err != nil {
+ return fmt.Errorf("parsing %q: %w", patstr, err)
+ }
-// HandleFunc registers the handler function for the given pattern
-// in the DefaultServeMux.
-// The documentation for ServeMux explains how patterns are matched.
-func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) {
- DefaultServeMux.HandleFunc(pattern, handler)
+ // Get the caller's location, for better conflict error messages.
+ // Skip register and whatever calls it.
+ _, file, line, ok := runtime.Caller(3)
+ if !ok {
+ pat.loc = "unknown location"
+ } else {
+ pat.loc = fmt.Sprintf("%s:%d", file, line)
+ }
+
+ mux.mu.Lock()
+ defer mux.mu.Unlock()
+ // Check for conflict.
+ if err := mux.index.possiblyConflictingPatterns(pat, func(pat2 *pattern) error {
+ if pat.conflictsWith(pat2) {
+ d := describeConflict(pat, pat2)
+ return fmt.Errorf("pattern %q (registered at %s) conflicts with pattern %q (registered at %s):\n%s",
+ pat, pat.loc, pat2, pat2.loc, d)
+ }
+ return nil
+ }); err != nil {
+ return err
+ }
+ mux.tree.addPattern(pat, handler)
+ mux.index.addPattern(pat)
+ mux.patterns = append(mux.patterns, pat)
+ return nil
}
// Serve accepts incoming HTTP connections on the listener l,
// creating a new service goroutine for each. The service goroutines
// read requests and then call handler to reply to them.
//
-// The handler is typically nil, in which case the DefaultServeMux is used.
+// The handler is typically nil, in which case [DefaultServeMux] is used.
//
-// HTTP/2 support is only enabled if the Listener returns *tls.Conn
+// HTTP/2 support is only enabled if the Listener returns [*tls.Conn]
// connections and they were configured with "h2" in the TLS
// Config.NextProtos.
//
@@ -2599,7 +2798,7 @@ func Serve(l net.Listener, handler Handler) error {
// creating a new service goroutine for each. The service goroutines
// read requests and then call handler to reply to them.
//
-// The handler is typically nil, in which case the DefaultServeMux is used.
+// The handler is typically nil, in which case [DefaultServeMux] is used.
//
// Additionally, files containing a certificate and matching private key
// for the server must be provided. If the certificate is signed by a
@@ -2725,13 +2924,13 @@ type Server struct {
}
// Close immediately closes all active net.Listeners and any
-// connections in state StateNew, StateActive, or StateIdle. For a
-// graceful shutdown, use Shutdown.
+// connections in state [StateNew], [StateActive], or [StateIdle]. For a
+// graceful shutdown, use [Server.Shutdown].
//
// Close does not attempt to close (and does not even know about)
// any hijacked connections, such as WebSockets.
//
-// Close returns any error returned from closing the Server's
+// Close returns any error returned from closing the [Server]'s
// underlying Listener(s).
func (srv *Server) Close() error {
srv.inShutdown.Store(true)
@@ -2769,16 +2968,16 @@ const shutdownPollIntervalMax = 500 * time.Millisecond
// indefinitely for connections to return to idle and then shut down.
// If the provided context expires before the shutdown is complete,
// Shutdown returns the context's error, otherwise it returns any
-// error returned from closing the Server's underlying Listener(s).
+// error returned from closing the [Server]'s underlying Listener(s).
//
-// When Shutdown is called, Serve, ListenAndServe, and
-// ListenAndServeTLS immediately return ErrServerClosed. Make sure the
+// When Shutdown is called, [Serve], [ListenAndServe], and
+// [ListenAndServeTLS] immediately return [ErrServerClosed]. Make sure the
// program doesn't exit and waits instead for Shutdown to return.
//
// Shutdown does not attempt to close nor wait for hijacked
// connections such as WebSockets. The caller of Shutdown should
// separately notify such long-lived connections of shutdown and wait
-// for them to close, if desired. See RegisterOnShutdown for a way to
+// for them to close, if desired. See [Server.RegisterOnShutdown] for a way to
// register shutdown notification functions.
//
// Once Shutdown has been called on a server, it may not be reused;
@@ -2821,7 +3020,7 @@ func (srv *Server) Shutdown(ctx context.Context) error {
}
}
-// RegisterOnShutdown registers a function to call on Shutdown.
+// RegisterOnShutdown registers a function to call on [Server.Shutdown].
// This can be used to gracefully shutdown connections that have
// undergone ALPN protocol upgrade or that have been hijacked.
// This function should start protocol-specific graceful shutdown,
@@ -2869,7 +3068,7 @@ func (s *Server) closeListenersLocked() error {
}
// A ConnState represents the state of a client connection to a server.
-// It's used by the optional Server.ConnState hook.
+// It's used by the optional [Server.ConnState] hook.
type ConnState int
const (
@@ -2946,7 +3145,7 @@ func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) {
// behavior doesn't match that of many proxies, and the mismatch can lead to
// security issues.
//
-// AllowQuerySemicolons should be invoked before Request.ParseForm is called.
+// AllowQuerySemicolons should be invoked before [Request.ParseForm] is called.
func AllowQuerySemicolons(h Handler) Handler {
return HandlerFunc(func(w ResponseWriter, r *Request) {
if strings.Contains(r.URL.RawQuery, ";") {
@@ -2963,13 +3162,13 @@ func AllowQuerySemicolons(h Handler) Handler {
}
// ListenAndServe listens on the TCP network address srv.Addr and then
-// calls Serve to handle requests on incoming connections.
+// calls [Serve] to handle requests on incoming connections.
// Accepted connections are configured to enable TCP keep-alives.
//
// If srv.Addr is blank, ":http" is used.
//
-// ListenAndServe always returns a non-nil error. After Shutdown or Close,
-// the returned error is ErrServerClosed.
+// ListenAndServe always returns a non-nil error. After [Server.Shutdown] or [Server.Close],
+// the returned error is [ErrServerClosed].
func (srv *Server) ListenAndServe() error {
if srv.shuttingDown() {
return ErrServerClosed
@@ -3009,20 +3208,20 @@ func (srv *Server) shouldConfigureHTTP2ForServe() bool {
return strSliceContains(srv.TLSConfig.NextProtos, http2NextProtoTLS)
}
-// ErrServerClosed is returned by the Server's Serve, ServeTLS, ListenAndServe,
-// and ListenAndServeTLS methods after a call to Shutdown or Close.
+// ErrServerClosed is returned by the [Server.Serve], [ServeTLS], [ListenAndServe],
+// and [ListenAndServeTLS] methods after a call to [Server.Shutdown] or [Server.Close].
var ErrServerClosed = errors.New("http: Server closed")
// Serve accepts incoming connections on the Listener l, creating a
// new service goroutine for each. The service goroutines read requests and
// then call srv.Handler to reply to them.
//
-// HTTP/2 support is only enabled if the Listener returns *tls.Conn
+// HTTP/2 support is only enabled if the Listener returns [*tls.Conn]
// connections and they were configured with "h2" in the TLS
// Config.NextProtos.
//
// Serve always returns a non-nil error and closes l.
-// After Shutdown or Close, the returned error is ErrServerClosed.
+// After [Server.Shutdown] or [Server.Close], the returned error is [ErrServerClosed].
func (srv *Server) Serve(l net.Listener) error {
if fn := testHookServerServe; fn != nil {
fn(srv, l) // call hook with unwrapped listener
@@ -3092,14 +3291,14 @@ func (srv *Server) Serve(l net.Listener) error {
// setup and then read requests, calling srv.Handler to reply to them.
//
// Files containing a certificate and matching private key for the
-// server must be provided if neither the Server's
+// server must be provided if neither the [Server]'s
// TLSConfig.Certificates nor TLSConfig.GetCertificate are populated.
// If the certificate is signed by a certificate authority, the
// certFile should be the concatenation of the server's certificate,
// any intermediates, and the CA's certificate.
//
-// ServeTLS always returns a non-nil error. After Shutdown or Close, the
-// returned error is ErrServerClosed.
+// ServeTLS always returns a non-nil error. After [Server.Shutdown] or [Server.Close], the
+// returned error is [ErrServerClosed].
func (srv *Server) ServeTLS(l net.Listener, certFile, keyFile string) error {
// Setup HTTP/2 before srv.Serve, to initialize srv.TLSConfig
// before we clone it and create the TLS Listener.
@@ -3228,10 +3427,10 @@ func logf(r *Request, format string, args ...any) {
}
// ListenAndServe listens on the TCP network address addr and then calls
-// Serve with handler to handle requests on incoming connections.
+// [Serve] with handler to handle requests on incoming connections.
// Accepted connections are configured to enable TCP keep-alives.
//
-// The handler is typically nil, in which case the DefaultServeMux is used.
+// The handler is typically nil, in which case [DefaultServeMux] is used.
//
// ListenAndServe always returns a non-nil error.
func ListenAndServe(addr string, handler Handler) error {
@@ -3239,7 +3438,7 @@ func ListenAndServe(addr string, handler Handler) error {
return server.ListenAndServe()
}
-// ListenAndServeTLS acts identically to ListenAndServe, except that it
+// ListenAndServeTLS acts identically to [ListenAndServe], except that it
// expects HTTPS connections. Additionally, files containing a certificate and
// matching private key for the server must be provided. If the certificate
// is signed by a certificate authority, the certFile should be the concatenation
@@ -3250,11 +3449,11 @@ func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error {
}
// ListenAndServeTLS listens on the TCP network address srv.Addr and
-// then calls ServeTLS to handle requests on incoming TLS connections.
+// then calls [ServeTLS] to handle requests on incoming TLS connections.
// Accepted connections are configured to enable TCP keep-alives.
//
// Filenames containing a certificate and matching private key for the
-// server must be provided if neither the Server's TLSConfig.Certificates
+// server must be provided if neither the [Server]'s TLSConfig.Certificates
// nor TLSConfig.GetCertificate are populated. If the certificate is
// signed by a certificate authority, the certFile should be the
// concatenation of the server's certificate, any intermediates, and
@@ -3262,8 +3461,8 @@ func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error {
//
// If srv.Addr is blank, ":https" is used.
//
-// ListenAndServeTLS always returns a non-nil error. After Shutdown or
-// Close, the returned error is ErrServerClosed.
+// ListenAndServeTLS always returns a non-nil error. After [Server.Shutdown] or
+// [Server.Close], the returned error is [ErrServerClosed].
func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error {
if srv.shuttingDown() {
return ErrServerClosed
@@ -3333,17 +3532,17 @@ func (srv *Server) onceSetNextProtoDefaults() {
}
}
-// TimeoutHandler returns a Handler that runs h with the given time limit.
+// TimeoutHandler returns a [Handler] that runs h with the given time limit.
//
// The new Handler calls h.ServeHTTP to handle each request, but if a
// call runs for longer than its time limit, the handler responds with
// a 503 Service Unavailable error and the given message in its body.
// (If msg is empty, a suitable default message will be sent.)
-// After such a timeout, writes by h to its ResponseWriter will return
-// ErrHandlerTimeout.
+// After such a timeout, writes by h to its [ResponseWriter] will return
+// [ErrHandlerTimeout].
//
-// TimeoutHandler supports the Pusher interface but does not support
-// the Hijacker or Flusher interfaces.
+// TimeoutHandler supports the [Pusher] interface but does not support
+// the [Hijacker] or [Flusher] interfaces.
func TimeoutHandler(h Handler, dt time.Duration, msg string) Handler {
return &timeoutHandler{
handler: h,
@@ -3352,7 +3551,7 @@ func TimeoutHandler(h Handler, dt time.Duration, msg string) Handler {
}
}
-// ErrHandlerTimeout is returned on ResponseWriter Write calls
+// ErrHandlerTimeout is returned on [ResponseWriter] Write calls
// in handlers which have timed out.
var ErrHandlerTimeout = errors.New("http: Handler timeout")
@@ -3441,7 +3640,7 @@ type timeoutWriter struct {
var _ Pusher = (*timeoutWriter)(nil)
-// Push implements the Pusher interface.
+// Push implements the [Pusher] interface.
func (tw *timeoutWriter) Push(target string, opts *PushOptions) error {
if pusher, ok := tw.w.(Pusher); ok {
return pusher.Push(target, opts)
@@ -3526,7 +3725,7 @@ type initALPNRequest struct {
h serverHandler
}
-// BaseContext is an exported but unadvertised http.Handler method
+// BaseContext is an exported but unadvertised [http.Handler] method
// recognized by x/net/http2 to pass down a context; the TLSNextProto
// API predates context support so we shoehorn through the only
// interface we have available.
@@ -3613,7 +3812,6 @@ func numLeadingCRorLF(v []byte) (n int) {
break
}
return
-
}
func strSliceContains(ss []string, s string) bool {
@@ -3635,7 +3833,7 @@ func tlsRecordHeaderLooksLikeHTTP(hdr [5]byte) bool {
return false
}
-// MaxBytesHandler returns a Handler that runs h with its ResponseWriter and Request.Body wrapped by a MaxBytesReader.
+// MaxBytesHandler returns a [Handler] that runs h with its [ResponseWriter] and [Request.Body] wrapped by a MaxBytesReader.
func MaxBytesHandler(h Handler, n int64) Handler {
return HandlerFunc(func(w ResponseWriter, r *Request) {
r2 := *r
diff --git a/src/net/http/server_test.go b/src/net/http/server_test.go
index d17c5c1e7e..e81e3bb6b0 100644
--- a/src/net/http/server_test.go
+++ b/src/net/http/server_test.go
@@ -8,6 +8,8 @@ package http
import (
"fmt"
+ "net/url"
+ "regexp"
"testing"
"time"
)
@@ -64,6 +66,190 @@ func TestServerTLSHandshakeTimeout(t *testing.T) {
}
}
+type handler struct{ i int }
+
+func (handler) ServeHTTP(ResponseWriter, *Request) {}
+
+func TestFindHandler(t *testing.T) {
+ mux := NewServeMux()
+ for _, ph := range []struct {
+ pat string
+ h Handler
+ }{
+ {"/", &handler{1}},
+ {"/foo/", &handler{2}},
+ {"/foo", &handler{3}},
+ {"/bar/", &handler{4}},
+ {"//foo", &handler{5}},
+ } {
+ mux.Handle(ph.pat, ph.h)
+ }
+
+ for _, test := range []struct {
+ method string
+ path string
+ wantHandler string
+ }{
+ {"GET", "/", "&http.handler{i:1}"},
+ {"GET", "//", `&http.redirectHandler{url:"/", code:301}`},
+ {"GET", "/foo/../bar/./..//baz", `&http.redirectHandler{url:"/baz", code:301}`},
+ {"GET", "/foo", "&http.handler{i:3}"},
+ {"GET", "/foo/x", "&http.handler{i:2}"},
+ {"GET", "/bar/x", "&http.handler{i:4}"},
+ {"GET", "/bar", `&http.redirectHandler{url:"/bar/", code:301}`},
+ {"CONNECT", "/", "&http.handler{i:1}"},
+ {"CONNECT", "//", "&http.handler{i:1}"},
+ {"CONNECT", "//foo", "&http.handler{i:5}"},
+ {"CONNECT", "/foo/../bar/./..//baz", "&http.handler{i:2}"},
+ {"CONNECT", "/foo", "&http.handler{i:3}"},
+ {"CONNECT", "/foo/x", "&http.handler{i:2}"},
+ {"CONNECT", "/bar/x", "&http.handler{i:4}"},
+ {"CONNECT", "/bar", `&http.redirectHandler{url:"/bar/", code:301}`},
+ } {
+ var r Request
+ r.Method = test.method
+ r.Host = "example.com"
+ r.URL = &url.URL{Path: test.path}
+ gotH, _, _, _ := mux.findHandler(&r)
+ got := fmt.Sprintf("%#v", gotH)
+ if got != test.wantHandler {
+ t.Errorf("%s %q: got %q, want %q", test.method, test.path, got, test.wantHandler)
+ }
+ }
+}
+
+func TestEmptyServeMux(t *testing.T) {
+ // Verify that a ServeMux with nothing registered
+ // doesn't panic.
+ mux := NewServeMux()
+ var r Request
+ r.Method = "GET"
+ r.Host = "example.com"
+ r.URL = &url.URL{Path: "/"}
+ _, p := mux.Handler(&r)
+ if p != "" {
+ t.Errorf(`got %q, want ""`, p)
+ }
+}
+
+func TestRegisterErr(t *testing.T) {
+ mux := NewServeMux()
+ h := &handler{}
+ mux.Handle("/a", h)
+
+ for _, test := range []struct {
+ pattern string
+ handler Handler
+ wantRegexp string
+ }{
+ {"", h, "invalid pattern"},
+ {"/", nil, "nil handler"},
+ {"/", HandlerFunc(nil), "nil handler"},
+ {"/{x", h, `parsing "/\{x": at offset 1: bad wildcard segment`},
+ {"/a", h, `conflicts with pattern.* \(registered at .*/server_test.go:\d+`},
+ } {
+ t.Run(fmt.Sprintf("%s:%#v", test.pattern, test.handler), func(t *testing.T) {
+ err := mux.registerErr(test.pattern, test.handler)
+ if err == nil {
+ t.Fatal("got nil error")
+ }
+ re := regexp.MustCompile(test.wantRegexp)
+ if g := err.Error(); !re.MatchString(g) {
+ t.Errorf("\ngot %q\nwant string matching %q", g, test.wantRegexp)
+ }
+ })
+ }
+}
+
+func TestExactMatch(t *testing.T) {
+ for _, test := range []struct {
+ pattern string
+ path string
+ want bool
+ }{
+ {"", "/a", false},
+ {"/", "/a", false},
+ {"/a", "/a", true},
+ {"/a/{x...}", "/a/b", false},
+ {"/a/{x}", "/a/b", true},
+ {"/a/b/", "/a/b/", true},
+ {"/a/b/{$}", "/a/b/", true},
+ {"/a/", "/a/b/", false},
+ } {
+ var n *routingNode
+ if test.pattern != "" {
+ pat := mustParsePattern(t, test.pattern)
+ n = &routingNode{pattern: pat}
+ }
+ got := exactMatch(n, test.path)
+ if got != test.want {
+ t.Errorf("%q, %s: got %t, want %t", test.pattern, test.path, got, test.want)
+ }
+ }
+}
+
+func TestEscapedPathsAndPatterns(t *testing.T) {
+ matches := []struct {
+ pattern string
+ paths []string // paths that match the pattern
+ paths121 []string // paths that matched the pattern in Go 1.21.
+ }{
+ {
+ "/a", // this pattern matches a path that unescapes to "/a"
+ []string{"/a", "/%61"},
+ []string{"/a", "/%61"},
+ },
+ {
+ "/%62", // patterns are unescaped by segment; matches paths that unescape to "/b"
+ []string{"/b", "/%62"},
+ []string{"/%2562"}, // In 1.21, patterns were not unescaped but paths were.
+ },
+ {
+ "/%7B/%7D", // the only way to write a pattern that matches '{' or '}'
+ []string{"/{/}", "/%7b/}", "/{/%7d", "/%7B/%7D"},
+ []string{"/%257B/%257D"}, // In 1.21, patterns were not unescaped.
+ },
+ {
+ "/%x", // patterns that do not unescape are left unchanged
+ []string{"/%25x"},
+ []string{"/%25x"},
+ },
+ }
+
+ run := func(t *testing.T, test121 bool) {
+ defer func(u bool) { use121 = u }(use121)
+ use121 = test121
+
+ mux := NewServeMux()
+ for _, m := range matches {
+ mux.HandleFunc(m.pattern, func(w ResponseWriter, r *Request) {})
+ }
+
+ for _, m := range matches {
+ paths := m.paths
+ if use121 {
+ paths = m.paths121
+ }
+ for _, p := range paths {
+ u, err := url.ParseRequestURI(p)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req := &Request{
+ URL: u,
+ }
+ _, gotPattern := mux.Handler(req)
+ if g, w := gotPattern, m.pattern; g != w {
+ t.Errorf("%s: pattern: got %q, want %q", p, g, w)
+ }
+ }
+ }
+ }
+
+ t.Run("latest", func(t *testing.T) { run(t, false) })
+ t.Run("1.21", func(t *testing.T) { run(t, true) })
+}
+
func BenchmarkServerMatch(b *testing.B) {
fn := func(w ResponseWriter, r *Request) {
fmt.Fprintf(w, "OK")
@@ -90,7 +276,11 @@ func BenchmarkServerMatch(b *testing.B) {
"/products/", "/products/3/image.jpg"}
b.StartTimer()
for i := 0; i < b.N; i++ {
- if h, p := mux.match(paths[i%len(paths)]); h != nil && p == "" {
+ r, err := NewRequest("GET", "http://example.com/"+paths[i%len(paths)], nil)
+ if err != nil {
+ b.Fatal(err)
+ }
+ if h, p, _, _ := mux.findHandler(r); h != nil && p == "" {
b.Error("impossible")
}
}
diff --git a/src/net/http/transfer.go b/src/net/http/transfer.go
index d6f26a709c..315c6e2723 100644
--- a/src/net/http/transfer.go
+++ b/src/net/http/transfer.go
@@ -9,6 +9,7 @@ import (
"bytes"
"errors"
"fmt"
+ "internal/godebug"
"io"
"net/http/httptrace"
"net/http/internal"
@@ -409,7 +410,10 @@ func (t *transferWriter) writeBody(w io.Writer) (err error) {
//
// This function is only intended for use in writeBody.
func (t *transferWriter) doBodyCopy(dst io.Writer, src io.Reader) (n int64, err error) {
- n, err = io.Copy(dst, src)
+ buf := getCopyBuf()
+ defer putCopyBuf(buf)
+
+ n, err = io.CopyBuffer(dst, src, buf)
if err != nil && err != io.EOF {
t.bodyReadError = err
}
@@ -527,7 +531,7 @@ func readTransfer(msg any, r *bufio.Reader) (err error) {
return err
}
if isResponse && t.RequestMethod == "HEAD" {
- if n, err := parseContentLength(t.Header.get("Content-Length")); err != nil {
+ if n, err := parseContentLength(t.Header["Content-Length"]); err != nil {
return err
} else {
t.ContentLength = n
@@ -707,18 +711,15 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header,
return -1, nil
}
- // Logic based on Content-Length
- var cl string
- if len(contentLens) == 1 {
- cl = textproto.TrimString(contentLens[0])
- }
- if cl != "" {
- n, err := parseContentLength(cl)
+ if len(contentLens) > 0 {
+ // Logic based on Content-Length
+ n, err := parseContentLength(contentLens)
if err != nil {
return -1, err
}
return n, nil
}
+
header.Del("Content-Length")
if isRequest {
@@ -816,10 +817,10 @@ type body struct {
onHitEOF func() // if non-nil, func to call when EOF is Read
}
-// ErrBodyReadAfterClose is returned when reading a Request or Response
+// ErrBodyReadAfterClose is returned when reading a [Request] or [Response]
// Body after the body has been closed. This typically happens when the body is
-// read after an HTTP Handler calls WriteHeader or Write on its
-// ResponseWriter.
+// read after an HTTP [Handler] calls WriteHeader or Write on its
+// [ResponseWriter].
var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed Body")
func (b *body) Read(p []byte) (n int, err error) {
@@ -1038,19 +1039,31 @@ func (bl bodyLocked) Read(p []byte) (n int, err error) {
return bl.b.readLocked(p)
}
-// parseContentLength trims whitespace from s and returns -1 if no value
-// is set, or the value if it's >= 0.
-func parseContentLength(cl string) (int64, error) {
- cl = textproto.TrimString(cl)
- if cl == "" {
+var laxContentLength = godebug.New("httplaxcontentlength")
+
+// parseContentLength checks that the header is valid and then trims
+// whitespace. It returns -1 if no value is set otherwise the value
+// if it's >= 0.
+func parseContentLength(clHeaders []string) (int64, error) {
+ if len(clHeaders) == 0 {
return -1, nil
}
+ cl := textproto.TrimString(clHeaders[0])
+
+ // The Content-Length must be a valid numeric value.
+ // See: https://datatracker.ietf.org/doc/html/rfc2616/#section-14.13
+ if cl == "" {
+ if laxContentLength.Value() == "1" {
+ laxContentLength.IncNonDefault()
+ return -1, nil
+ }
+ return 0, badStringError("invalid empty Content-Length", cl)
+ }
n, err := strconv.ParseUint(cl, 10, 63)
if err != nil {
return 0, badStringError("bad Content-Length", cl)
}
return int64(n), nil
-
}
// finishAsyncByteRead finishes reading the 1-byte sniff
diff --git a/src/net/http/transfer_test.go b/src/net/http/transfer_test.go
index 5e0df896d8..b1a5a93103 100644
--- a/src/net/http/transfer_test.go
+++ b/src/net/http/transfer_test.go
@@ -112,8 +112,8 @@ func (w *mockTransferWriter) Write(p []byte) (int, error) {
}
func TestTransferWriterWriteBodyReaderTypes(t *testing.T) {
- fileType := reflect.TypeOf(&os.File{})
- bufferType := reflect.TypeOf(&bytes.Buffer{})
+ fileType := reflect.TypeFor[*os.File]()
+ bufferType := reflect.TypeFor[*bytes.Buffer]()
nBytes := int64(1 << 10)
newFileFunc := func() (r io.Reader, done func(), err error) {
@@ -264,6 +264,12 @@ func TestTransferWriterWriteBodyReaderTypes(t *testing.T) {
actualReader = reflect.TypeOf(lr.R)
} else {
actualReader = reflect.TypeOf(mw.CalledReader)
+ // We have to handle this special case for genericWriteTo in os,
+ // this struct is introduced to support a zero-copy optimization,
+ // check out https://go.dev/issue/58808 for details.
+ if actualReader.Kind() == reflect.Struct && actualReader.PkgPath() == "os" && actualReader.Name() == "fileWithoutWriteTo" {
+ actualReader = actualReader.Field(1).Type
+ }
}
if tc.expectedReader != actualReader {
@@ -333,6 +339,10 @@ func TestParseContentLength(t *testing.T) {
wantErr error
}{
{
+ cl: "",
+ wantErr: badStringError("invalid empty Content-Length", ""),
+ },
+ {
cl: "3",
wantErr: nil,
},
@@ -356,7 +366,7 @@ func TestParseContentLength(t *testing.T) {
}
for _, tt := range tests {
- if _, gotErr := parseContentLength(tt.cl); !reflect.DeepEqual(gotErr, tt.wantErr) {
+ if _, gotErr := parseContentLength([]string{tt.cl}); !reflect.DeepEqual(gotErr, tt.wantErr) {
t.Errorf("%q:\n\tgot=%v\n\twant=%v", tt.cl, gotErr, tt.wantErr)
}
}
diff --git a/src/net/http/transport.go b/src/net/http/transport.go
index c07352b018..17067ac07c 100644
--- a/src/net/http/transport.go
+++ b/src/net/http/transport.go
@@ -35,8 +35,8 @@ import (
"golang.org/x/net/http/httpproxy"
)
-// DefaultTransport is the default implementation of Transport and is
-// used by DefaultClient. It establishes network connections as needed
+// DefaultTransport is the default implementation of [Transport] and is
+// used by [DefaultClient]. It establishes network connections as needed
// and caches them for reuse by subsequent calls. It uses HTTP proxies
// as directed by the environment variables HTTP_PROXY, HTTPS_PROXY
// and NO_PROXY (or the lowercase versions thereof).
@@ -53,42 +53,42 @@ var DefaultTransport RoundTripper = &Transport{
ExpectContinueTimeout: 1 * time.Second,
}
-// DefaultMaxIdleConnsPerHost is the default value of Transport's
+// DefaultMaxIdleConnsPerHost is the default value of [Transport]'s
// MaxIdleConnsPerHost.
const DefaultMaxIdleConnsPerHost = 2
-// Transport is an implementation of RoundTripper that supports HTTP,
+// Transport is an implementation of [RoundTripper] that supports HTTP,
// HTTPS, and HTTP proxies (for either HTTP or HTTPS with CONNECT).
//
// By default, Transport caches connections for future re-use.
// This may leave many open connections when accessing many hosts.
-// This behavior can be managed using Transport's CloseIdleConnections method
-// and the MaxIdleConnsPerHost and DisableKeepAlives fields.
+// This behavior can be managed using [Transport.CloseIdleConnections] method
+// and the [Transport.MaxIdleConnsPerHost] and [Transport.DisableKeepAlives] fields.
//
// Transports should be reused instead of created as needed.
// Transports are safe for concurrent use by multiple goroutines.
//
// A Transport is a low-level primitive for making HTTP and HTTPS requests.
-// For high-level functionality, such as cookies and redirects, see Client.
+// For high-level functionality, such as cookies and redirects, see [Client].
//
// Transport uses HTTP/1.1 for HTTP URLs and either HTTP/1.1 or HTTP/2
// for HTTPS URLs, depending on whether the server supports HTTP/2,
-// and how the Transport is configured. The DefaultTransport supports HTTP/2.
+// and how the Transport is configured. The [DefaultTransport] supports HTTP/2.
// To explicitly enable HTTP/2 on a transport, use golang.org/x/net/http2
// and call ConfigureTransport. See the package docs for more about HTTP/2.
//
// Responses with status codes in the 1xx range are either handled
// automatically (100 expect-continue) or ignored. The one
// exception is HTTP status code 101 (Switching Protocols), which is
-// considered a terminal status and returned by RoundTrip. To see the
+// considered a terminal status and returned by [Transport.RoundTrip]. To see the
// ignored 1xx responses, use the httptrace trace package's
// ClientTrace.Got1xxResponse.
//
// Transport only retries a request upon encountering a network error
// if the connection has been already been used successfully and if the
-// request is idempotent and either has no body or has its Request.GetBody
+// request is idempotent and either has no body or has its [Request.GetBody]
// defined. HTTP requests are considered idempotent if they have HTTP methods
-// GET, HEAD, OPTIONS, or TRACE; or if their Header map contains an
+// GET, HEAD, OPTIONS, or TRACE; or if their [Header] map contains an
// "Idempotency-Key" or "X-Idempotency-Key" entry. If the idempotency key
// value is a zero-length slice, the request is treated as idempotent but the
// header is not sent on the wire.
@@ -117,6 +117,10 @@ type Transport struct {
// "https", and "socks5" are supported. If the scheme is empty,
// "http" is assumed.
//
+ // If the proxy URL contains a userinfo subcomponent,
+ // the proxy request will pass the username and password
+ // in a Proxy-Authorization header.
+ //
// If Proxy is nil or returns a nil *URL, no proxy is used.
Proxy func(*Request) (*url.URL, error)
@@ -233,7 +237,7 @@ type Transport struct {
// TLSNextProto specifies how the Transport switches to an
// alternate protocol (such as HTTP/2) after a TLS ALPN
- // protocol negotiation. If Transport dials an TLS connection
+ // protocol negotiation. If Transport dials a TLS connection
// with a non-empty protocol name and TLSNextProto contains a
// map entry for that key (such as "h2"), then the func is
// called with the request's authority (such as "example.com"
@@ -449,7 +453,7 @@ func ProxyFromEnvironment(req *Request) (*url.URL, error) {
return envProxyFunc()(req.URL)
}
-// ProxyURL returns a proxy function (for use in a Transport)
+// ProxyURL returns a proxy function (for use in a [Transport])
// that always returns the same URL.
func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) {
return func(*Request) (*url.URL, error) {
@@ -748,14 +752,14 @@ func (pc *persistConn) shouldRetryRequest(req *Request, err error) bool {
var ErrSkipAltProtocol = errors.New("net/http: skip alternate protocol")
// RegisterProtocol registers a new protocol with scheme.
-// The Transport will pass requests using the given scheme to rt.
+// The [Transport] will pass requests using the given scheme to rt.
// It is rt's responsibility to simulate HTTP request semantics.
//
// RegisterProtocol can be used by other packages to provide
// implementations of protocol schemes like "ftp" or "file".
//
-// If rt.RoundTrip returns ErrSkipAltProtocol, the Transport will
-// handle the RoundTrip itself for that one request, as if the
+// If rt.RoundTrip returns [ErrSkipAltProtocol], the Transport will
+// handle the [Transport.RoundTrip] itself for that one request, as if the
// protocol were not registered.
func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) {
t.altMu.Lock()
@@ -795,9 +799,9 @@ func (t *Transport) CloseIdleConnections() {
}
// CancelRequest cancels an in-flight request by closing its connection.
-// CancelRequest should only be called after RoundTrip has returned.
+// CancelRequest should only be called after [Transport.RoundTrip] has returned.
//
-// Deprecated: Use Request.WithContext to create a request with a
+// Deprecated: Use [Request.WithContext] to create a request with a
// cancelable context instead. CancelRequest cannot cancel HTTP/2
// requests.
func (t *Transport) CancelRequest(req *Request) {
@@ -1205,7 +1209,6 @@ func (t *Transport) dial(ctx context.Context, network, addr string) (net.Conn, e
type wantConn struct {
cm connectMethod
key connectMethodKey // cm.key()
- ctx context.Context // context for dial
ready chan struct{} // closed when pc, err pair is delivered
// hooks for testing to know when dials are done
@@ -1214,7 +1217,8 @@ type wantConn struct {
beforeDial func()
afterDial func()
- mu sync.Mutex // protects pc, err, close(ready)
+ mu sync.Mutex // protects ctx, pc, err, close(ready)
+ ctx context.Context // context for dial, cleared after delivered or canceled
pc *persistConn
err error
}
@@ -1229,6 +1233,13 @@ func (w *wantConn) waiting() bool {
}
}
+// getCtxForDial returns context for dial or nil if connection was delivered or canceled.
+func (w *wantConn) getCtxForDial() context.Context {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ return w.ctx
+}
+
// tryDeliver attempts to deliver pc, err to w and reports whether it succeeded.
func (w *wantConn) tryDeliver(pc *persistConn, err error) bool {
w.mu.Lock()
@@ -1238,6 +1249,7 @@ func (w *wantConn) tryDeliver(pc *persistConn, err error) bool {
return false
}
+ w.ctx = nil
w.pc = pc
w.err = err
if w.pc == nil && w.err == nil {
@@ -1255,6 +1267,7 @@ func (w *wantConn) cancel(t *Transport, err error) {
close(w.ready) // catch misbehavior in future delivery
}
pc := w.pc
+ w.ctx = nil
w.pc = nil
w.err = err
w.mu.Unlock()
@@ -1463,8 +1476,13 @@ func (t *Transport) queueForDial(w *wantConn) {
// If the dial is canceled or unsuccessful, dialConnFor decrements t.connCount[w.cm.key()].
func (t *Transport) dialConnFor(w *wantConn) {
defer w.afterDial()
+ ctx := w.getCtxForDial()
+ if ctx == nil {
+ t.decConnsPerHost(w.key)
+ return
+ }
- pc, err := t.dialConn(w.ctx, w.cm)
+ pc, err := t.dialConn(ctx, w.cm)
delivered := w.tryDeliver(pc, err)
if err == nil && (!delivered || pc.alt != nil) {
// pconn was not passed to w,
@@ -1560,6 +1578,11 @@ func (pconn *persistConn) addTLS(ctx context.Context, name string, trace *httptr
}()
if err := <-errc; err != nil {
plainConn.Close()
+ if err == (tlsHandshakeTimeoutError{}) {
+ // Now that we have closed the connection,
+ // wait for the call to HandshakeContext to return.
+ <-errc
+ }
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tls.ConnectionState{}, err)
}
@@ -2248,7 +2271,7 @@ func (pc *persistConn) readLoop() {
}
case <-rc.req.Cancel:
alive = false
- pc.t.CancelRequest(rc.req)
+ pc.t.cancelRequest(rc.cancelKey, errRequestCanceled)
case <-rc.req.Context().Done():
alive = false
pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err())
diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go
index 028fecc961..3fb5624664 100644
--- a/src/net/http/transport_test.go
+++ b/src/net/http/transport_test.go
@@ -730,6 +730,56 @@ func testTransportMaxConnsPerHost(t *testing.T, mode testMode) {
}
}
+func TestTransportMaxConnsPerHostDialCancellation(t *testing.T) {
+ run(t, testTransportMaxConnsPerHostDialCancellation,
+ testNotParallel, // because test uses SetPendingDialHooks
+ []testMode{http1Mode, https1Mode, http2Mode},
+ )
+}
+
+func testTransportMaxConnsPerHostDialCancellation(t *testing.T, mode testMode) {
+ CondSkipHTTP2(t)
+
+ h := HandlerFunc(func(w ResponseWriter, r *Request) {
+ _, err := w.Write([]byte("foo"))
+ if err != nil {
+ t.Fatalf("Write: %v", err)
+ }
+ })
+
+ cst := newClientServerTest(t, mode, h)
+ defer cst.close()
+ ts := cst.ts
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+ tr.MaxConnsPerHost = 1
+
+ // This request is cancelled when dial is queued, which preempts dialing.
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ SetPendingDialHooks(cancel, nil)
+ defer SetPendingDialHooks(nil, nil)
+
+ req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
+ _, err := c.Do(req)
+ if !errors.Is(err, context.Canceled) {
+ t.Errorf("expected error %v, got %v", context.Canceled, err)
+ }
+
+ // This request should succeed.
+ SetPendingDialHooks(nil, nil)
+ req, _ = NewRequest("GET", ts.URL, nil)
+ resp, err := c.Do(req)
+ if err != nil {
+ t.Fatalf("request failed: %v", err)
+ }
+ defer resp.Body.Close()
+ _, err = io.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatalf("read body failed: %v", err)
+ }
+}
+
func TestTransportRemovesDeadIdleConnections(t *testing.T) {
run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode})
}
@@ -2099,25 +2149,50 @@ func testIssue3644(t *testing.T, mode testMode) {
// Test that a client receives a server's reply, even if the server doesn't read
// the entire request body.
-func TestIssue3595(t *testing.T) { run(t, testIssue3595) }
+func TestIssue3595(t *testing.T) {
+ // Not parallel: modifies the global rstAvoidanceDelay.
+ run(t, testIssue3595, testNotParallel)
+}
func testIssue3595(t *testing.T, mode testMode) {
- const deniedMsg = "sorry, denied."
- ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
- Error(w, deniedMsg, StatusUnauthorized)
- })).ts
- c := ts.Client()
- res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
- if err != nil {
- t.Errorf("Post: %v", err)
- return
- }
- got, err := io.ReadAll(res.Body)
- if err != nil {
- t.Fatalf("Body ReadAll: %v", err)
- }
- if !strings.Contains(string(got), deniedMsg) {
- t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
- }
+ runTimeSensitiveTest(t, []time.Duration{
+ 1 * time.Millisecond,
+ 5 * time.Millisecond,
+ 10 * time.Millisecond,
+ 50 * time.Millisecond,
+ 100 * time.Millisecond,
+ 500 * time.Millisecond,
+ time.Second,
+ 5 * time.Second,
+ }, func(t *testing.T, timeout time.Duration) error {
+ SetRSTAvoidanceDelay(t, timeout)
+ t.Logf("set RST avoidance delay to %v", timeout)
+
+ const deniedMsg = "sorry, denied."
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ Error(w, deniedMsg, StatusUnauthorized)
+ }))
+ // We need to close cst explicitly here so that in-flight server
+ // requests don't race with the call to SetRSTAvoidanceDelay for a retry.
+ defer cst.close()
+ ts := cst.ts
+ c := ts.Client()
+
+ res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
+ if err != nil {
+ return fmt.Errorf("Post: %v", err)
+ }
+ got, err := io.ReadAll(res.Body)
+ if err != nil {
+ return fmt.Errorf("Body ReadAll: %v", err)
+ }
+ t.Logf("server response:\n%s", got)
+ if !strings.Contains(string(got), deniedMsg) {
+ // If we got an RST packet too early, we should have seen an error
+ // from io.ReadAll, not a silently-truncated body.
+ t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
+ }
+ return nil
+ })
}
// From https://golang.org/issue/4454 ,
@@ -2440,6 +2515,7 @@ func testTransportCancelRequest(t *testing.T, mode testMode) {
if d > 0 {
t.Logf("pending requests = %d after %v (want 0)", n, d)
}
+ return false
}
return true
})
@@ -2599,6 +2675,65 @@ func testCancelRequestWithChannel(t *testing.T, mode testMode) {
if d > 0 {
t.Logf("pending requests = %d after %v (want 0)", n, d)
}
+ return false
+ }
+ return true
+ })
+}
+
+// Issue 51354
+func TestCancelRequestWithBodyWithChannel(t *testing.T) {
+ run(t, testCancelRequestWithBodyWithChannel, []testMode{http1Mode})
+}
+func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) {
+ if testing.Short() {
+ t.Skip("skipping test in -short mode")
+ }
+
+ const msg = "Hello"
+ unblockc := make(chan struct{})
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.WriteString(w, msg)
+ w.(Flusher).Flush() // send headers and some body
+ <-unblockc
+ })).ts
+ defer close(unblockc)
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+
+ req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody"))
+ cancel := make(chan struct{})
+ req.Cancel = cancel
+
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ body := make([]byte, len(msg))
+ n, _ := io.ReadFull(res.Body, body)
+ if n != len(body) || !bytes.Equal(body, []byte(msg)) {
+ t.Errorf("Body = %q; want %q", body[:n], msg)
+ }
+ close(cancel)
+
+ tail, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != ExportErrRequestCanceled {
+ t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
+ } else if len(tail) > 0 {
+ t.Errorf("Spurious bytes from Body.Read: %q", tail)
+ }
+
+ // Verify no outstanding requests after readLoop/writeLoop
+ // goroutines shut down.
+ waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
+ n := tr.NumPendingRequestsForTesting()
+ if n > 0 {
+ if d > 0 {
+ t.Logf("pending requests = %d after %v (want 0)", n, d)
+ }
+ return false
}
return true
})
@@ -3414,6 +3549,7 @@ func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) {
c net.Conn
}
var getOkay bool
+ var copying sync.WaitGroup
closeConn := func() {
sconn.Lock()
defer sconn.Unlock()
@@ -3425,7 +3561,10 @@ func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) {
}
}
}
- defer closeConn()
+ defer func() {
+ closeConn()
+ copying.Wait()
+ }()
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
if r.Method == "GET" {
@@ -3437,7 +3576,12 @@ func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) {
sconn.c = conn
sconn.Unlock()
conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive
- go io.Copy(io.Discard, conn)
+
+ copying.Add(1)
+ go func() {
+ io.Copy(io.Discard, conn)
+ copying.Done()
+ }()
})).ts
c := ts.Client()
@@ -4267,68 +4411,78 @@ func (c *wgReadCloser) Close() error {
// Issue 11745.
func TestTransportPrefersResponseOverWriteError(t *testing.T) {
- run(t, testTransportPrefersResponseOverWriteError)
+ // Not parallel: modifies the global rstAvoidanceDelay.
+ run(t, testTransportPrefersResponseOverWriteError, testNotParallel)
}
func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) {
if testing.Short() {
t.Skip("skipping in short mode")
}
- const contentLengthLimit = 1024 * 1024 // 1MB
- ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
- if r.ContentLength >= contentLengthLimit {
- w.WriteHeader(StatusBadRequest)
- r.Body.Close()
- return
- }
- w.WriteHeader(StatusOK)
- })).ts
- c := ts.Client()
- fail := 0
- count := 100
+ runTimeSensitiveTest(t, []time.Duration{
+ 1 * time.Millisecond,
+ 5 * time.Millisecond,
+ 10 * time.Millisecond,
+ 50 * time.Millisecond,
+ 100 * time.Millisecond,
+ 500 * time.Millisecond,
+ time.Second,
+ 5 * time.Second,
+ }, func(t *testing.T, timeout time.Duration) error {
+ SetRSTAvoidanceDelay(t, timeout)
+ t.Logf("set RST avoidance delay to %v", timeout)
+
+ const contentLengthLimit = 1024 * 1024 // 1MB
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.ContentLength >= contentLengthLimit {
+ w.WriteHeader(StatusBadRequest)
+ r.Body.Close()
+ return
+ }
+ w.WriteHeader(StatusOK)
+ }))
+ // We need to close cst explicitly here so that in-flight server
+ // requests don't race with the call to SetRSTAvoidanceDelay for a retry.
+ defer cst.close()
+ ts := cst.ts
+ c := ts.Client()
- bigBody := strings.Repeat("a", contentLengthLimit*2)
- var wg sync.WaitGroup
- defer wg.Wait()
- getBody := func() (io.ReadCloser, error) {
- wg.Add(1)
- body := &wgReadCloser{
- Reader: strings.NewReader(bigBody),
- wg: &wg,
- }
- return body, nil
- }
+ count := 100
- for i := 0; i < count; i++ {
- reqBody, _ := getBody()
- req, err := NewRequest("PUT", ts.URL, reqBody)
- if err != nil {
- reqBody.Close()
- t.Fatal(err)
+ bigBody := strings.Repeat("a", contentLengthLimit*2)
+ var wg sync.WaitGroup
+ defer wg.Wait()
+ getBody := func() (io.ReadCloser, error) {
+ wg.Add(1)
+ body := &wgReadCloser{
+ Reader: strings.NewReader(bigBody),
+ wg: &wg,
+ }
+ return body, nil
}
- req.ContentLength = int64(len(bigBody))
- req.GetBody = getBody
- resp, err := c.Do(req)
- if err != nil {
- fail++
- t.Logf("%d = %#v", i, err)
- if ue, ok := err.(*url.Error); ok {
- t.Logf("urlErr = %#v", ue.Err)
- if ne, ok := ue.Err.(*net.OpError); ok {
- t.Logf("netOpError = %#v", ne.Err)
- }
+ for i := 0; i < count; i++ {
+ reqBody, _ := getBody()
+ req, err := NewRequest("PUT", ts.URL, reqBody)
+ if err != nil {
+ reqBody.Close()
+ t.Fatal(err)
}
- } else {
- resp.Body.Close()
- if resp.StatusCode != 400 {
- t.Errorf("Expected status code 400, got %v", resp.Status)
+ req.ContentLength = int64(len(bigBody))
+ req.GetBody = getBody
+
+ resp, err := c.Do(req)
+ if err != nil {
+ return fmt.Errorf("Do %d: %v", i, err)
+ } else {
+ resp.Body.Close()
+ if resp.StatusCode != 400 {
+ t.Errorf("Expected status code 400, got %v", resp.Status)
+ }
}
}
- }
- if fail > 0 {
- t.Errorf("Failed %v out of %v\n", fail, count)
- }
+ return nil
+ })
}
func TestTransportAutomaticHTTP2(t *testing.T) {
@@ -6750,3 +6904,36 @@ func testRequestSanitization(t *testing.T, mode testMode) {
resp.Body.Close()
}
}
+
+func TestProxyAuthHeader(t *testing.T) {
+ // Not parallel: Sets an environment variable.
+ run(t, testProxyAuthHeader, []testMode{http1Mode}, testNotParallel)
+}
+func testProxyAuthHeader(t *testing.T, mode testMode) {
+ const username = "u"
+ const password = "@/?!"
+ cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ // Copy the Proxy-Authorization header to a new Request,
+ // since Request.BasicAuth only parses the Authorization header.
+ var r2 Request
+ r2.Header = Header{
+ "Authorization": req.Header["Proxy-Authorization"],
+ }
+ gotuser, gotpass, ok := r2.BasicAuth()
+ if !ok || gotuser != username || gotpass != password {
+ t.Errorf("req.BasicAuth() = %q, %q, %v; want %q, %q, true", gotuser, gotpass, ok, username, password)
+ }
+ }))
+ u, err := url.Parse(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ u.User = url.UserPassword(username, password)
+ t.Setenv("HTTP_PROXY", u.String())
+ cst.tr.Proxy = ProxyURL(u)
+ resp, err := cst.c.Get("http://_/")
+ if err != nil {
+ t.Fatal(err)
+ }
+ resp.Body.Close()
+}
diff --git a/src/net/http/triv.go b/src/net/http/triv.go
index f614922c24..c1696425cd 100644
--- a/src/net/http/triv.go
+++ b/src/net/http/triv.go
@@ -34,7 +34,7 @@ type Counter struct {
n int
}
-// This makes Counter satisfy the expvar.Var interface, so we can export
+// This makes Counter satisfy the [expvar.Var] interface, so we can export
// it directly.
func (ctr *Counter) String() string {
ctr.mu.Lock()
diff --git a/src/net/interface.go b/src/net/interface.go
index e1c9a2e2ff..20ac07d31a 100644
--- a/src/net/interface.go
+++ b/src/net/interface.go
@@ -114,7 +114,7 @@ func Interfaces() ([]Interface, error) {
// addresses.
//
// The returned list does not identify the associated interface; use
-// Interfaces and Interface.Addrs for more detail.
+// Interfaces and [Interface.Addrs] for more detail.
func InterfaceAddrs() ([]Addr, error) {
ifat, err := interfaceAddrTable(nil)
if err != nil {
@@ -127,7 +127,7 @@ func InterfaceAddrs() ([]Addr, error) {
//
// On Solaris, it returns one of the logical network interfaces
// sharing the logical data link; for more precision use
-// InterfaceByName.
+// [InterfaceByName].
func InterfaceByIndex(index int) (*Interface, error) {
if index <= 0 {
return nil, &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterfaceIndex}
diff --git a/src/net/interface_stub.go b/src/net/interface_stub.go
index 829dbc6938..4c280c6ff2 100644
--- a/src/net/interface_stub.go
+++ b/src/net/interface_stub.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build (js && wasm) || wasip1
+//go:build js || wasip1
package net
diff --git a/src/net/interface_test.go b/src/net/interface_test.go
index 5590b06262..a97d675e7e 100644
--- a/src/net/interface_test.go
+++ b/src/net/interface_test.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !wasip1
-
package net
import (
diff --git a/src/net/internal/socktest/main_test.go b/src/net/internal/socktest/main_test.go
index 0197feb3f1..967ce6795a 100644
--- a/src/net/internal/socktest/main_test.go
+++ b/src/net/internal/socktest/main_test.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !plan9 && !wasip1
+//go:build !js && !plan9 && !wasip1 && !windows
package socktest_test
diff --git a/src/net/internal/socktest/main_windows_test.go b/src/net/internal/socktest/main_windows_test.go
deleted file mode 100644
index df1cb97784..0000000000
--- a/src/net/internal/socktest/main_windows_test.go
+++ /dev/null
@@ -1,22 +0,0 @@
-// Copyright 2015 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package socktest_test
-
-import "syscall"
-
-var (
- socketFunc func(int, int, int) (syscall.Handle, error)
- closeFunc func(syscall.Handle) error
-)
-
-func installTestHooks() {
- socketFunc = sw.Socket
- closeFunc = sw.Closesocket
-}
-
-func uninstallTestHooks() {
- socketFunc = syscall.Socket
- closeFunc = syscall.Closesocket
-}
diff --git a/src/net/internal/socktest/switch.go b/src/net/internal/socktest/switch.go
index 3c37b6ff80..dea6d9288c 100644
--- a/src/net/internal/socktest/switch.go
+++ b/src/net/internal/socktest/switch.go
@@ -133,7 +133,7 @@ const (
// If the filter returns a non-nil error, the execution of system call
// will be canceled and the system call function returns the non-nil
// error.
-// It can return a non-nil AfterFilter for filtering after the
+// It can return a non-nil [AfterFilter] for filtering after the
// execution of the system call.
type Filter func(*Status) (AfterFilter, error)
diff --git a/src/net/internal/socktest/sys_unix.go b/src/net/internal/socktest/sys_unix.go
index 712462abf4..3eef26c70b 100644
--- a/src/net/internal/socktest/sys_unix.go
+++ b/src/net/internal/socktest/sys_unix.go
@@ -8,7 +8,7 @@ package socktest
import "syscall"
-// Socket wraps syscall.Socket.
+// Socket wraps [syscall.Socket].
func (sw *Switch) Socket(family, sotype, proto int) (s int, err error) {
sw.once.Do(sw.init)
diff --git a/src/net/internal/socktest/sys_windows.go b/src/net/internal/socktest/sys_windows.go
index 8c1c862f33..2f02446075 100644
--- a/src/net/internal/socktest/sys_windows.go
+++ b/src/net/internal/socktest/sys_windows.go
@@ -9,39 +9,7 @@ import (
"syscall"
)
-// Socket wraps syscall.Socket.
-func (sw *Switch) Socket(family, sotype, proto int) (s syscall.Handle, err error) {
- sw.once.Do(sw.init)
-
- so := &Status{Cookie: cookie(family, sotype, proto)}
- sw.fmu.RLock()
- f, _ := sw.fltab[FilterSocket]
- sw.fmu.RUnlock()
-
- af, err := f.apply(so)
- if err != nil {
- return syscall.InvalidHandle, err
- }
- s, so.Err = syscall.Socket(family, sotype, proto)
- if err = af.apply(so); err != nil {
- if so.Err == nil {
- syscall.Closesocket(s)
- }
- return syscall.InvalidHandle, err
- }
-
- sw.smu.Lock()
- defer sw.smu.Unlock()
- if so.Err != nil {
- sw.stats.getLocked(so.Cookie).OpenFailed++
- return syscall.InvalidHandle, so.Err
- }
- nso := sw.addLocked(s, family, sotype, proto)
- sw.stats.getLocked(nso.Cookie).Opened++
- return s, nil
-}
-
-// WSASocket wraps syscall.WSASocket.
+// WSASocket wraps [syscall.WSASocket].
func (sw *Switch) WSASocket(family, sotype, proto int32, protinfo *syscall.WSAProtocolInfo, group uint32, flags uint32) (s syscall.Handle, err error) {
sw.once.Do(sw.init)
@@ -73,7 +41,7 @@ func (sw *Switch) WSASocket(family, sotype, proto int32, protinfo *syscall.WSAPr
return s, nil
}
-// Closesocket wraps syscall.Closesocket.
+// Closesocket wraps [syscall.Closesocket].
func (sw *Switch) Closesocket(s syscall.Handle) (err error) {
so := sw.sockso(s)
if so == nil {
@@ -103,7 +71,7 @@ func (sw *Switch) Closesocket(s syscall.Handle) (err error) {
return nil
}
-// Connect wraps syscall.Connect.
+// Connect wraps [syscall.Connect].
func (sw *Switch) Connect(s syscall.Handle, sa syscall.Sockaddr) (err error) {
so := sw.sockso(s)
if so == nil {
@@ -132,7 +100,7 @@ func (sw *Switch) Connect(s syscall.Handle, sa syscall.Sockaddr) (err error) {
return nil
}
-// ConnectEx wraps syscall.ConnectEx.
+// ConnectEx wraps [syscall.ConnectEx].
func (sw *Switch) ConnectEx(s syscall.Handle, sa syscall.Sockaddr, b *byte, n uint32, nwr *uint32, o *syscall.Overlapped) (err error) {
so := sw.sockso(s)
if so == nil {
@@ -161,7 +129,7 @@ func (sw *Switch) ConnectEx(s syscall.Handle, sa syscall.Sockaddr, b *byte, n ui
return nil
}
-// Listen wraps syscall.Listen.
+// Listen wraps [syscall.Listen].
func (sw *Switch) Listen(s syscall.Handle, backlog int) (err error) {
so := sw.sockso(s)
if so == nil {
@@ -190,7 +158,7 @@ func (sw *Switch) Listen(s syscall.Handle, backlog int) (err error) {
return nil
}
-// AcceptEx wraps syscall.AcceptEx.
+// AcceptEx wraps [syscall.AcceptEx].
func (sw *Switch) AcceptEx(ls syscall.Handle, as syscall.Handle, b *byte, rxdatalen uint32, laddrlen uint32, raddrlen uint32, rcvd *uint32, overlapped *syscall.Overlapped) error {
so := sw.sockso(ls)
if so == nil {
diff --git a/src/net/ip.go b/src/net/ip.go
index d51ba10eec..6083dd8bf9 100644
--- a/src/net/ip.go
+++ b/src/net/ip.go
@@ -38,7 +38,7 @@ type IP []byte
// An IPMask is a bitmask that can be used to manipulate
// IP addresses for IP addressing and routing.
//
-// See type IPNet and func ParseCIDR for details.
+// See type [IPNet] and func [ParseCIDR] for details.
type IPMask []byte
// An IPNet represents an IP network.
@@ -72,9 +72,9 @@ func IPv4Mask(a, b, c, d byte) IPMask {
return p
}
-// CIDRMask returns an IPMask consisting of 'ones' 1 bits
+// CIDRMask returns an [IPMask] consisting of 'ones' 1 bits
// followed by 0s up to a total length of 'bits' bits.
-// For a mask of this form, CIDRMask is the inverse of IPMask.Size.
+// For a mask of this form, CIDRMask is the inverse of [IPMask.Size].
func CIDRMask(ones, bits int) IPMask {
if bits != 8*IPv4len && bits != 8*IPv6len {
return nil
@@ -324,8 +324,8 @@ func ipEmptyString(ip IP) string {
return ip.String()
}
-// MarshalText implements the encoding.TextMarshaler interface.
-// The encoding is the same as returned by String, with one exception:
+// MarshalText implements the [encoding.TextMarshaler] interface.
+// The encoding is the same as returned by [IP.String], with one exception:
// When len(ip) is zero, it returns an empty slice.
func (ip IP) MarshalText() ([]byte, error) {
if len(ip) == 0 {
@@ -337,8 +337,8 @@ func (ip IP) MarshalText() ([]byte, error) {
return []byte(ip.String()), nil
}
-// UnmarshalText implements the encoding.TextUnmarshaler interface.
-// The IP address is expected in a form accepted by ParseIP.
+// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
+// The IP address is expected in a form accepted by [ParseIP].
func (ip *IP) UnmarshalText(text []byte) error {
if len(text) == 0 {
*ip = nil
diff --git a/src/net/ip_test.go b/src/net/ip_test.go
index 1373059abe..acc2310be1 100644
--- a/src/net/ip_test.go
+++ b/src/net/ip_test.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !wasip1
-
package net
import (
diff --git a/src/net/iprawsock.go b/src/net/iprawsock.go
index f18331a1fd..4c06b1b5ac 100644
--- a/src/net/iprawsock.go
+++ b/src/net/iprawsock.go
@@ -72,7 +72,7 @@ func (a *IPAddr) opAddr() Addr {
// recommended, because it will return at most one of the host name's
// IP addresses.
//
-// See func Dial for a description of the network and address
+// See func [Dial] for a description of the network and address
// parameters.
func ResolveIPAddr(network, address string) (*IPAddr, error) {
if network == "" { // a hint wildcard for Go 1.0 undocumented behavior
@@ -94,19 +94,19 @@ func ResolveIPAddr(network, address string) (*IPAddr, error) {
return addrs.forResolve(network, address).(*IPAddr), nil
}
-// IPConn is the implementation of the Conn and PacketConn interfaces
+// IPConn is the implementation of the [Conn] and [PacketConn] interfaces
// for IP network connections.
type IPConn struct {
conn
}
// SyscallConn returns a raw network connection.
-// This implements the syscall.Conn interface.
+// This implements the [syscall.Conn] interface.
func (c *IPConn) SyscallConn() (syscall.RawConn, error) {
if !c.ok() {
return nil, syscall.EINVAL
}
- return newRawConn(c.fd)
+ return newRawConn(c.fd), nil
}
// ReadFromIP acts like ReadFrom but returns an IPAddr.
@@ -121,7 +121,7 @@ func (c *IPConn) ReadFromIP(b []byte) (int, *IPAddr, error) {
return n, addr, err
}
-// ReadFrom implements the PacketConn ReadFrom method.
+// ReadFrom implements the [PacketConn] ReadFrom method.
func (c *IPConn) ReadFrom(b []byte) (int, Addr, error) {
if !c.ok() {
return 0, nil, syscall.EINVAL
@@ -154,7 +154,7 @@ func (c *IPConn) ReadMsgIP(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err
return
}
-// WriteToIP acts like WriteTo but takes an IPAddr.
+// WriteToIP acts like [IPConn.WriteTo] but takes an [IPAddr].
func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (int, error) {
if !c.ok() {
return 0, syscall.EINVAL
@@ -166,7 +166,7 @@ func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (int, error) {
return n, err
}
-// WriteTo implements the PacketConn WriteTo method.
+// WriteTo implements the [PacketConn] WriteTo method.
func (c *IPConn) WriteTo(b []byte, addr Addr) (int, error) {
if !c.ok() {
return 0, syscall.EINVAL
@@ -201,7 +201,7 @@ func (c *IPConn) WriteMsgIP(b, oob []byte, addr *IPAddr) (n, oobn int, err error
func newIPConn(fd *netFD) *IPConn { return &IPConn{conn{fd}} }
-// DialIP acts like Dial for IP networks.
+// DialIP acts like [Dial] for IP networks.
//
// The network must be an IP network name; see func Dial for details.
//
@@ -220,7 +220,7 @@ func DialIP(network string, laddr, raddr *IPAddr) (*IPConn, error) {
return c, nil
}
-// ListenIP acts like ListenPacket for IP networks.
+// ListenIP acts like [ListenPacket] for IP networks.
//
// The network must be an IP network name; see func Dial for details.
//
diff --git a/src/net/iprawsock_posix.go b/src/net/iprawsock_posix.go
index 59967eb923..73b41ab522 100644
--- a/src/net/iprawsock_posix.go
+++ b/src/net/iprawsock_posix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build unix || (js && wasm) || wasip1 || windows
+//go:build unix || js || wasip1 || windows
package net
diff --git a/src/net/iprawsock_test.go b/src/net/iprawsock_test.go
index 14c03a1f4d..7f1fc139ab 100644
--- a/src/net/iprawsock_test.go
+++ b/src/net/iprawsock_test.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !wasip1
-
package net
import (
diff --git a/src/net/ipsock.go b/src/net/ipsock.go
index 0f5da2577c..176dbc748e 100644
--- a/src/net/ipsock.go
+++ b/src/net/ipsock.go
@@ -83,10 +83,10 @@ func (addrs addrList) forResolve(network, addr string) Addr {
switch network {
case "ip":
// IPv6 literal (addr does NOT contain a port)
- want6 = count(addr, ':') > 0
+ want6 = bytealg.CountString(addr, ':') > 0
case "tcp", "udp":
// IPv6 literal. (addr contains a port, so look for '[')
- want6 = count(addr, '[') > 0
+ want6 = bytealg.CountString(addr, '[') > 0
}
if want6 {
return addrs.first(isNotIPv4)
@@ -172,7 +172,7 @@ func SplitHostPort(hostport string) (host, port string, err error) {
j, k := 0, 0
// The port starts after the last colon.
- i := last(hostport, ':')
+ i := bytealg.LastIndexByteString(hostport, ':')
if i < 0 {
return addrErr(hostport, missingPort)
}
@@ -219,7 +219,7 @@ func SplitHostPort(hostport string) (host, port string, err error) {
func splitHostZone(s string) (host, zone string) {
// The IPv6 scoped addressing zone identifier starts after the
// last percent sign.
- if i := last(s, '%'); i > 0 {
+ if i := bytealg.LastIndexByteString(s, '%'); i > 0 {
host, zone = s[:i], s[i+1:]
} else {
host = s
diff --git a/src/net/ipsock_plan9.go b/src/net/ipsock_plan9.go
index 43287431c8..c8d0180436 100644
--- a/src/net/ipsock_plan9.go
+++ b/src/net/ipsock_plan9.go
@@ -181,7 +181,6 @@ func dialPlan9(ctx context.Context, net string, laddr, raddr Addr) (fd *netFD, e
}
resc := make(chan res)
go func() {
- testHookDialChannel()
fd, err := dialPlan9Blocking(ctx, net, laddr, raddr)
select {
case resc <- res{fd, err}:
diff --git a/src/net/ipsock_posix.go b/src/net/ipsock_posix.go
index b0a00a6296..67ce1479c6 100644
--- a/src/net/ipsock_posix.go
+++ b/src/net/ipsock_posix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build unix || (js && wasm) || wasip1 || windows
+//go:build unix || js || wasip1 || windows
package net
@@ -25,6 +25,15 @@ import (
// general. Unfortunately, we need to run on kernels built without
// IPv6 support too. So probe the kernel to figure it out.
func (p *ipStackCapabilities) probe() {
+ switch runtime.GOOS {
+ case "js", "wasip1":
+ // Both ipv4 and ipv6 are faked; see net_fake.go.
+ p.ipv4Enabled = true
+ p.ipv6Enabled = true
+ p.ipv4MappedIPv6Enabled = true
+ return
+ }
+
s, err := sysSocket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
switch err {
case syscall.EAFNOSUPPORT, syscall.EPROTONOSUPPORT:
@@ -135,8 +144,11 @@ func favoriteAddrFamily(network string, laddr, raddr sockaddr, mode string) (fam
}
func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (fd *netFD, err error) {
- if (runtime.GOOS == "aix" || runtime.GOOS == "windows" || runtime.GOOS == "openbsd") && mode == "dial" && raddr.isWildcard() {
- raddr = raddr.toLocal(net)
+ switch runtime.GOOS {
+ case "aix", "windows", "openbsd", "js", "wasip1":
+ if mode == "dial" && raddr.isWildcard() {
+ raddr = raddr.toLocal(net)
+ }
}
family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode)
return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr, ctrlCtxFn)
diff --git a/src/net/listen_test.go b/src/net/listen_test.go
index f0a8861370..9100b3d9f7 100644
--- a/src/net/listen_test.go
+++ b/src/net/listen_test.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !plan9 && !wasip1
+//go:build !plan9
package net
diff --git a/src/net/lookup.go b/src/net/lookup.go
index a7133b53ac..3ec2660786 100644
--- a/src/net/lookup.go
+++ b/src/net/lookup.go
@@ -41,19 +41,20 @@ var services = map[string]map[string]int{
"domain": 53,
},
"tcp": {
- "ftp": 21,
- "ftps": 990,
- "gopher": 70, // ʕ◔ϖ◔ʔ
- "http": 80,
- "https": 443,
- "imap2": 143,
- "imap3": 220,
- "imaps": 993,
- "pop3": 110,
- "pop3s": 995,
- "smtp": 25,
- "ssh": 22,
- "telnet": 23,
+ "ftp": 21,
+ "ftps": 990,
+ "gopher": 70, // ʕ◔ϖ◔ʔ
+ "http": 80,
+ "https": 443,
+ "imap2": 143,
+ "imap3": 220,
+ "imaps": 993,
+ "pop3": 110,
+ "pop3s": 995,
+ "smtp": 25,
+ "submissions": 465,
+ "ssh": 22,
+ "telnet": 23,
},
}
@@ -83,12 +84,20 @@ const maxPortBufSize = len("mobility-header") + 10
func lookupPortMap(network, service string) (port int, error error) {
switch network {
- case "tcp4", "tcp6":
- network = "tcp"
- case "udp4", "udp6":
- network = "udp"
+ case "ip": // no hints
+ if p, err := lookupPortMapWithNetwork("tcp", "ip", service); err == nil {
+ return p, nil
+ }
+ return lookupPortMapWithNetwork("udp", "ip", service)
+ case "tcp", "tcp4", "tcp6":
+ return lookupPortMapWithNetwork("tcp", "tcp", service)
+ case "udp", "udp4", "udp6":
+ return lookupPortMapWithNetwork("udp", "udp", service)
}
+ return 0, &DNSError{Err: "unknown network", Name: network + "/" + service}
+}
+func lookupPortMapWithNetwork(network, errNetwork, service string) (port int, error error) {
if m, ok := services[network]; ok {
var lowerService [maxPortBufSize]byte
n := copy(lowerService[:], service)
@@ -96,8 +105,9 @@ func lookupPortMap(network, service string) (port int, error error) {
if port, ok := m[string(lowerService[:n])]; ok && n == len(service) {
return port, nil
}
+ return 0, &DNSError{Err: "unknown port", Name: errNetwork + "/" + service, IsNotFound: true}
}
- return 0, &AddrError{Err: "unknown port", Addr: network + "/" + service}
+ return 0, &DNSError{Err: "unknown network", Name: errNetwork + "/" + service}
}
// ipVersion returns the provided network's IP version: '4', '6' or 0
@@ -171,8 +181,8 @@ func (r *Resolver) getLookupGroup() *singleflight.Group {
// LookupHost looks up the given host using the local resolver.
// It returns a slice of that host's addresses.
//
-// LookupHost uses context.Background internally; to specify the context, use
-// Resolver.LookupHost.
+// LookupHost uses [context.Background] internally; to specify the context, use
+// [Resolver.LookupHost].
func LookupHost(host string) (addrs []string, err error) {
return DefaultResolver.LookupHost(context.Background(), host)
}
@@ -407,18 +417,20 @@ func ipAddrsEface(addrs []IPAddr) []any {
// LookupPort looks up the port for the given network and service.
//
-// LookupPort uses context.Background internally; to specify the context, use
-// Resolver.LookupPort.
+// LookupPort uses [context.Background] internally; to specify the context, use
+// [Resolver.LookupPort].
func LookupPort(network, service string) (port int, err error) {
return DefaultResolver.LookupPort(context.Background(), network, service)
}
// LookupPort looks up the port for the given network and service.
+//
+// The network must be one of "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6" or "ip".
func (r *Resolver) LookupPort(ctx context.Context, network, service string) (port int, err error) {
port, needsLookup := parsePort(service)
if needsLookup {
switch network {
- case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
+ case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6", "ip":
case "": // a hint wildcard for Go 1.0 undocumented behavior
network = "ip"
default:
@@ -437,7 +449,7 @@ func (r *Resolver) LookupPort(ctx context.Context, network, service string) (por
// LookupCNAME returns the canonical name for the given host.
// Callers that do not care about the canonical name can call
-// LookupHost or LookupIP directly; both take care of resolving
+// [LookupHost] or [LookupIP] directly; both take care of resolving
// the canonical name as part of the lookup.
//
// A canonical name is the final name after following zero
@@ -449,15 +461,15 @@ func (r *Resolver) LookupPort(ctx context.Context, network, service string) (por
// The returned canonical name is validated to be a properly
// formatted presentation-format domain name.
//
-// LookupCNAME uses context.Background internally; to specify the context, use
-// Resolver.LookupCNAME.
+// LookupCNAME uses [context.Background] internally; to specify the context, use
+// [Resolver.LookupCNAME].
func LookupCNAME(host string) (cname string, err error) {
return DefaultResolver.LookupCNAME(context.Background(), host)
}
// LookupCNAME returns the canonical name for the given host.
// Callers that do not care about the canonical name can call
-// LookupHost or LookupIP directly; both take care of resolving
+// [LookupHost] or [LookupIP] directly; both take care of resolving
// the canonical name as part of the lookup.
//
// A canonical name is the final name after following zero
@@ -479,7 +491,7 @@ func (r *Resolver) LookupCNAME(ctx context.Context, host string) (string, error)
return cname, nil
}
-// LookupSRV tries to resolve an SRV query of the given service,
+// LookupSRV tries to resolve an [SRV] query of the given service,
// protocol, and domain name. The proto is "tcp" or "udp".
// The returned records are sorted by priority and randomized
// by weight within a priority.
@@ -497,7 +509,7 @@ func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err err
return DefaultResolver.LookupSRV(context.Background(), service, proto, name)
}
-// LookupSRV tries to resolve an SRV query of the given service,
+// LookupSRV tries to resolve an [SRV] query of the given service,
// protocol, and domain name. The proto is "tcp" or "udp".
// The returned records are sorted by priority and randomized
// by weight within a priority.
@@ -542,8 +554,8 @@ func (r *Resolver) LookupSRV(ctx context.Context, service, proto, name string) (
// invalid names, those records are filtered out and an error
// will be returned alongside the remaining results, if any.
//
-// LookupMX uses context.Background internally; to specify the context, use
-// Resolver.LookupMX.
+// LookupMX uses [context.Background] internally; to specify the context, use
+// [Resolver.LookupMX].
func LookupMX(name string) ([]*MX, error) {
return DefaultResolver.LookupMX(context.Background(), name)
}
@@ -582,8 +594,8 @@ func (r *Resolver) LookupMX(ctx context.Context, name string) ([]*MX, error) {
// invalid names, those records are filtered out and an error
// will be returned alongside the remaining results, if any.
//
-// LookupNS uses context.Background internally; to specify the context, use
-// Resolver.LookupNS.
+// LookupNS uses [context.Background] internally; to specify the context, use
+// [Resolver.LookupNS].
func LookupNS(name string) ([]*NS, error) {
return DefaultResolver.LookupNS(context.Background(), name)
}
@@ -617,8 +629,8 @@ func (r *Resolver) LookupNS(ctx context.Context, name string) ([]*NS, error) {
// LookupTXT returns the DNS TXT records for the given domain name.
//
-// LookupTXT uses context.Background internally; to specify the context, use
-// Resolver.LookupTXT.
+// LookupTXT uses [context.Background] internally; to specify the context, use
+// [Resolver.LookupTXT].
func LookupTXT(name string) ([]string, error) {
return DefaultResolver.lookupTXT(context.Background(), name)
}
@@ -636,10 +648,10 @@ func (r *Resolver) LookupTXT(ctx context.Context, name string) ([]string, error)
// out and an error will be returned alongside the remaining results, if any.
//
// When using the host C library resolver, at most one result will be
-// returned. To bypass the host resolver, use a custom Resolver.
+// returned. To bypass the host resolver, use a custom [Resolver].
//
-// LookupAddr uses context.Background internally; to specify the context, use
-// Resolver.LookupAddr.
+// LookupAddr uses [context.Background] internally; to specify the context, use
+// [Resolver.LookupAddr].
func LookupAddr(addr string) (names []string, err error) {
return DefaultResolver.LookupAddr(context.Background(), addr)
}
diff --git a/src/net/lookup_fake.go b/src/net/lookup_fake.go
deleted file mode 100644
index c27eae4ba5..0000000000
--- a/src/net/lookup_fake.go
+++ /dev/null
@@ -1,58 +0,0 @@
-// Copyright 2011 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build js && wasm
-
-package net
-
-import (
- "context"
- "syscall"
-)
-
-func lookupProtocol(ctx context.Context, name string) (proto int, err error) {
- return lookupProtocolMap(name)
-}
-
-func (*Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) {
- return nil, syscall.ENOPROTOOPT
-}
-
-func (*Resolver) lookupIP(ctx context.Context, network, host string) (addrs []IPAddr, err error) {
- return nil, syscall.ENOPROTOOPT
-}
-
-func (*Resolver) lookupPort(ctx context.Context, network, service string) (port int, err error) {
- return goLookupPort(network, service)
-}
-
-func (*Resolver) lookupCNAME(ctx context.Context, name string) (cname string, err error) {
- return "", syscall.ENOPROTOOPT
-}
-
-func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (cname string, srvs []*SRV, err error) {
- return "", nil, syscall.ENOPROTOOPT
-}
-
-func (*Resolver) lookupMX(ctx context.Context, name string) (mxs []*MX, err error) {
- return nil, syscall.ENOPROTOOPT
-}
-
-func (*Resolver) lookupNS(ctx context.Context, name string) (nss []*NS, err error) {
- return nil, syscall.ENOPROTOOPT
-}
-
-func (*Resolver) lookupTXT(ctx context.Context, name string) (txts []string, err error) {
- return nil, syscall.ENOPROTOOPT
-}
-
-func (*Resolver) lookupAddr(ctx context.Context, addr string) (ptrs []string, err error) {
- return nil, syscall.ENOPROTOOPT
-}
-
-// concurrentThreadsLimit returns the number of threads we permit to
-// run concurrently doing DNS lookups.
-func concurrentThreadsLimit() int {
- return 500
-}
diff --git a/src/net/lookup_plan9.go b/src/net/lookup_plan9.go
index 5404b996e4..8cfc4f6bb3 100644
--- a/src/net/lookup_plan9.go
+++ b/src/net/lookup_plan9.go
@@ -106,6 +106,22 @@ func queryDNS(ctx context.Context, addr string, typ string) (res []string, err e
return query(ctx, netdir+"/dns", addr+" "+typ, 1024)
}
+func handlePlan9DNSError(err error, name string) error {
+ if stringsHasSuffix(err.Error(), "dns: name does not exist") ||
+ stringsHasSuffix(err.Error(), "dns: resource does not exist; negrcode 0") ||
+ stringsHasSuffix(err.Error(), "dns: resource does not exist; negrcode") {
+ return &DNSError{
+ Err: errNoSuchHost.Error(),
+ Name: name,
+ IsNotFound: true,
+ }
+ }
+ return &DNSError{
+ Err: err.Error(),
+ Name: name,
+ }
+}
+
// toLower returns a lower-case version of in. Restricting us to
// ASCII is sufficient to handle the IP protocol names and allow
// us to not depend on the strings and unicode packages.
@@ -153,12 +169,10 @@ func (*Resolver) lookupHost(ctx context.Context, host string) (addrs []string, e
// host names in local network (e.g. from /lib/ndb/local)
lines, err := queryCS(ctx, "net", host, "1")
if err != nil {
- dnsError := &DNSError{Err: err.Error(), Name: host}
if stringsHasSuffix(err.Error(), "dns failure") {
- dnsError.Err = errNoSuchHost.Error()
- dnsError.IsNotFound = true
+ return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
}
- return nil, dnsError
+ return nil, handlePlan9DNSError(err, host)
}
loop:
for _, line := range lines {
@@ -184,31 +198,11 @@ loop:
return
}
-// preferGoOverPlan9 reports whether the resolver should use the
-// "PreferGo" implementation rather than asking plan9 services
-// for the answers.
-func (r *Resolver) preferGoOverPlan9() bool {
- _, _, res := r.preferGoOverPlan9WithOrderAndConf()
- return res
-}
-
-func (r *Resolver) preferGoOverPlan9WithOrderAndConf() (hostLookupOrder, *dnsConfig, bool) {
- order, conf := systemConf().hostLookupOrder(r, "") // name is unused
-
- // TODO(bradfitz): for now we only permit use of the PreferGo
- // implementation when there's a non-nil Resolver with a
- // non-nil Dialer. This is a sign that they the code is trying
- // to use their DNS-speaking net.Conn (such as an in-memory
- // DNS cache) and they don't want to actually hit the network.
- // Once we add support for looking the default DNS servers
- // from plan9, though, then we can relax this.
- return order, conf, order != hostLookupCgo && r != nil && r.Dial != nil
-}
-
func (r *Resolver) lookupIP(ctx context.Context, network, host string) (addrs []IPAddr, err error) {
- if r.preferGoOverPlan9() {
- return r.goLookupIP(ctx, network, host)
+ if order, conf := systemConf().hostLookupOrder(r, host); order != hostLookupCgo {
+ return r.goLookupIP(ctx, network, host, order, conf)
}
+
lits, err := r.lookupHost(ctx, host)
if err != nil {
return
@@ -223,24 +217,36 @@ func (r *Resolver) lookupIP(ctx context.Context, network, host string) (addrs []
return
}
-func (*Resolver) lookupPort(ctx context.Context, network, service string) (port int, err error) {
+func (r *Resolver) lookupPort(ctx context.Context, network, service string) (port int, err error) {
switch network {
- case "tcp4", "tcp6":
- network = "tcp"
- case "udp4", "udp6":
- network = "udp"
+ case "ip": // no hints
+ if p, err := r.lookupPortWithNetwork(ctx, "tcp", "ip", service); err == nil {
+ return p, nil
+ }
+ return r.lookupPortWithNetwork(ctx, "udp", "ip", service)
+ case "tcp", "tcp4", "tcp6":
+ return r.lookupPortWithNetwork(ctx, "tcp", "tcp", service)
+ case "udp", "udp4", "udp6":
+ return r.lookupPortWithNetwork(ctx, "udp", "udp", service)
+ default:
+ return 0, &DNSError{Err: "unknown network", Name: network + "/" + service}
}
+}
+
+func (*Resolver) lookupPortWithNetwork(ctx context.Context, network, errNetwork, service string) (port int, err error) {
lines, err := queryCS(ctx, network, "127.0.0.1", toLower(service))
if err != nil {
+ if stringsHasSuffix(err.Error(), "can't translate service") {
+ return 0, &DNSError{Err: "unknown port", Name: errNetwork + "/" + service, IsNotFound: true}
+ }
return
}
- unknownPortError := &AddrError{Err: "unknown port", Addr: network + "/" + service}
if len(lines) == 0 {
- return 0, unknownPortError
+ return 0, &DNSError{Err: "unknown port", Name: errNetwork + "/" + service, IsNotFound: true}
}
f := getFields(lines[0])
if len(f) < 2 {
- return 0, unknownPortError
+ return 0, &DNSError{Err: "unknown port", Name: errNetwork + "/" + service, IsNotFound: true}
}
s := f[1]
if i := bytealg.IndexByteString(s, '!'); i >= 0 {
@@ -249,21 +255,20 @@ func (*Resolver) lookupPort(ctx context.Context, network, service string) (port
if n, _, ok := dtoi(s); ok {
return n, nil
}
- return 0, unknownPortError
+ return 0, &DNSError{Err: "unknown port", Name: errNetwork + "/" + service, IsNotFound: true}
}
func (r *Resolver) lookupCNAME(ctx context.Context, name string) (cname string, err error) {
- if order, conf, preferGo := r.preferGoOverPlan9WithOrderAndConf(); preferGo {
+ if order, conf := systemConf().hostLookupOrder(r, name); order != hostLookupCgo {
return r.goLookupCNAME(ctx, name, order, conf)
}
lines, err := queryDNS(ctx, name, "cname")
if err != nil {
if stringsHasSuffix(err.Error(), "dns failure") || stringsHasSuffix(err.Error(), "resource does not exist; negrcode 0") {
- cname = name + "."
- err = nil
+ return absDomainName(name), nil
}
- return
+ return "", handlePlan9DNSError(err, cname)
}
if len(lines) > 0 {
if f := getFields(lines[0]); len(f) >= 3 {
@@ -274,7 +279,7 @@ func (r *Resolver) lookupCNAME(ctx context.Context, name string) (cname string,
}
func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*SRV, err error) {
- if r.preferGoOverPlan9() {
+ if systemConf().mustUseGoResolver(r) {
return r.goLookupSRV(ctx, service, proto, name)
}
var target string
@@ -285,7 +290,7 @@ func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (
}
lines, err := queryDNS(ctx, target, "srv")
if err != nil {
- return
+ return "", nil, handlePlan9DNSError(err, name)
}
for _, line := range lines {
f := getFields(line)
@@ -306,12 +311,12 @@ func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (
}
func (r *Resolver) lookupMX(ctx context.Context, name string) (mx []*MX, err error) {
- if r.preferGoOverPlan9() {
+ if systemConf().mustUseGoResolver(r) {
return r.goLookupMX(ctx, name)
}
lines, err := queryDNS(ctx, name, "mx")
if err != nil {
- return
+ return nil, handlePlan9DNSError(err, name)
}
for _, line := range lines {
f := getFields(line)
@@ -327,12 +332,12 @@ func (r *Resolver) lookupMX(ctx context.Context, name string) (mx []*MX, err err
}
func (r *Resolver) lookupNS(ctx context.Context, name string) (ns []*NS, err error) {
- if r.preferGoOverPlan9() {
+ if systemConf().mustUseGoResolver(r) {
return r.goLookupNS(ctx, name)
}
lines, err := queryDNS(ctx, name, "ns")
if err != nil {
- return
+ return nil, handlePlan9DNSError(err, name)
}
for _, line := range lines {
f := getFields(line)
@@ -345,12 +350,12 @@ func (r *Resolver) lookupNS(ctx context.Context, name string) (ns []*NS, err err
}
func (r *Resolver) lookupTXT(ctx context.Context, name string) (txt []string, err error) {
- if r.preferGoOverPlan9() {
+ if systemConf().mustUseGoResolver(r) {
return r.goLookupTXT(ctx, name)
}
lines, err := queryDNS(ctx, name, "txt")
if err != nil {
- return
+ return nil, handlePlan9DNSError(err, name)
}
for _, line := range lines {
if i := bytealg.IndexByteString(line, '\t'); i >= 0 {
@@ -361,7 +366,7 @@ func (r *Resolver) lookupTXT(ctx context.Context, name string) (txt []string, er
}
func (r *Resolver) lookupAddr(ctx context.Context, addr string) (name []string, err error) {
- if order, conf, preferGo := r.preferGoOverPlan9WithOrderAndConf(); preferGo {
+ if order, conf := systemConf().addrLookupOrder(r, addr); order != hostLookupCgo {
return r.goLookupPTR(ctx, addr, order, conf)
}
arpa, err := reverseaddr(addr)
@@ -370,7 +375,7 @@ func (r *Resolver) lookupAddr(ctx context.Context, addr string) (name []string,
}
lines, err := queryDNS(ctx, arpa, "ptr")
if err != nil {
- return
+ return nil, handlePlan9DNSError(err, addr)
}
for _, line := range lines {
f := getFields(line)
diff --git a/src/net/lookup_test.go b/src/net/lookup_test.go
index 0689c19c3c..57ac9a933a 100644
--- a/src/net/lookup_test.go
+++ b/src/net/lookup_test.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !wasip1
-
package net
import (
@@ -1462,3 +1460,189 @@ func testLookupNoData(t *testing.T, prefix string) {
return
}
}
+
+func TestLookupPortNotFound(t *testing.T) {
+ allResolvers(t, func(t *testing.T) {
+ _, err := LookupPort("udp", "_-unknown-service-")
+ var dnsErr *DNSError
+ if !errors.As(err, &dnsErr) || !dnsErr.IsNotFound {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ })
+}
+
+// submissions service is only available through a tcp network, see:
+// https://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.xhtml?search=submissions
+var tcpOnlyService = func() string {
+ // plan9 does not have submissions service defined in the service database.
+ if runtime.GOOS == "plan9" {
+ return "https"
+ }
+ return "submissions"
+}()
+
+func TestLookupPortDifferentNetwork(t *testing.T) {
+ allResolvers(t, func(t *testing.T) {
+ _, err := LookupPort("udp", tcpOnlyService)
+ var dnsErr *DNSError
+ if !errors.As(err, &dnsErr) || !dnsErr.IsNotFound {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ })
+}
+
+func TestLookupPortEmptyNetworkString(t *testing.T) {
+ allResolvers(t, func(t *testing.T) {
+ _, err := LookupPort("", tcpOnlyService)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ })
+}
+
+func TestLookupPortIPNetworkString(t *testing.T) {
+ allResolvers(t, func(t *testing.T) {
+ _, err := LookupPort("ip", tcpOnlyService)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ })
+}
+
+func allResolvers(t *testing.T, f func(t *testing.T)) {
+ t.Run("default resolver", f)
+ t.Run("forced go resolver", func(t *testing.T) {
+ if fixup := forceGoDNS(); fixup != nil {
+ defer fixup()
+ f(t)
+ }
+ })
+ t.Run("forced cgo resolver", func(t *testing.T) {
+ if fixup := forceCgoDNS(); fixup != nil {
+ defer fixup()
+ f(t)
+ }
+ })
+}
+
+func TestLookupNoSuchHost(t *testing.T) {
+ mustHaveExternalNetwork(t)
+
+ const testNXDOMAIN = "invalid.invalid."
+ const testNODATA = "_ldap._tcp.google.com."
+
+ tests := []struct {
+ name string
+ query func() error
+ }{
+ {
+ name: "LookupCNAME NXDOMAIN",
+ query: func() error {
+ _, err := LookupCNAME(testNXDOMAIN)
+ return err
+ },
+ },
+ {
+ name: "LookupHost NXDOMAIN",
+ query: func() error {
+ _, err := LookupHost(testNXDOMAIN)
+ return err
+ },
+ },
+ {
+ name: "LookupHost NODATA",
+ query: func() error {
+ _, err := LookupHost(testNODATA)
+ return err
+ },
+ },
+ {
+ name: "LookupMX NXDOMAIN",
+ query: func() error {
+ _, err := LookupMX(testNXDOMAIN)
+ return err
+ },
+ },
+ {
+ name: "LookupMX NODATA",
+ query: func() error {
+ _, err := LookupMX(testNODATA)
+ return err
+ },
+ },
+ {
+ name: "LookupNS NXDOMAIN",
+ query: func() error {
+ _, err := LookupNS(testNXDOMAIN)
+ return err
+ },
+ },
+ {
+ name: "LookupNS NODATA",
+ query: func() error {
+ _, err := LookupNS(testNODATA)
+ return err
+ },
+ },
+ {
+ name: "LookupSRV NXDOMAIN",
+ query: func() error {
+ _, _, err := LookupSRV("unknown", "tcp", testNXDOMAIN)
+ return err
+ },
+ },
+ {
+ name: "LookupTXT NXDOMAIN",
+ query: func() error {
+ _, err := LookupTXT(testNXDOMAIN)
+ return err
+ },
+ },
+ {
+ name: "LookupTXT NODATA",
+ query: func() error {
+ _, err := LookupTXT(testNODATA)
+ return err
+ },
+ },
+ }
+
+ for _, v := range tests {
+ t.Run(v.name, func(t *testing.T) {
+ allResolvers(t, func(t *testing.T) {
+ attempts := 0
+ for {
+ err := v.query()
+ if err == nil {
+ t.Errorf("unexpected success")
+ return
+ }
+ if dnsErr, ok := err.(*DNSError); ok {
+ succeeded := true
+ if !dnsErr.IsNotFound {
+ succeeded = false
+ t.Log("IsNotFound is set to false")
+ }
+ if dnsErr.Err != errNoSuchHost.Error() {
+ succeeded = false
+ t.Logf("error message is not equal to: %v", errNoSuchHost.Error())
+ }
+ if succeeded {
+ return
+ }
+ }
+ testenv.SkipFlakyNet(t)
+ if attempts < len(backoffDuration) {
+ dur := backoffDuration[attempts]
+ t.Logf("backoff %v after failure %v\n", dur, err)
+ time.Sleep(dur)
+ attempts++
+ continue
+ }
+ t.Errorf("unexpected error: %v", err)
+ return
+ }
+ })
+ })
+ }
+}
diff --git a/src/net/lookup_unix.go b/src/net/lookup_unix.go
index 56ae11e961..382a2d44bb 100644
--- a/src/net/lookup_unix.go
+++ b/src/net/lookup_unix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build unix || wasip1
+//go:build unix || js || wasip1
package net
@@ -10,7 +10,6 @@ import (
"context"
"internal/bytealg"
"sync"
- "syscall"
)
var onceReadProtocols sync.Once
@@ -62,9 +61,6 @@ func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string,
}
func (r *Resolver) lookupIP(ctx context.Context, network, host string) (addrs []IPAddr, err error) {
- if r.preferGo() {
- return r.goLookupIP(ctx, network, host)
- }
order, conf := systemConf().hostLookupOrder(r, host)
if order == hostLookupCgo {
return cgoLookupIP(ctx, network, host)
@@ -123,27 +119,3 @@ func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error
}
return r.goLookupPTR(ctx, addr, order, conf)
}
-
-// concurrentThreadsLimit returns the number of threads we permit to
-// run concurrently doing DNS lookups via cgo. A DNS lookup may use a
-// file descriptor so we limit this to less than the number of
-// permitted open files. On some systems, notably Darwin, if
-// getaddrinfo is unable to open a file descriptor it simply returns
-// EAI_NONAME rather than a useful error. Limiting the number of
-// concurrent getaddrinfo calls to less than the permitted number of
-// file descriptors makes that error less likely. We don't bother to
-// apply the same limit to DNS lookups run directly from Go, because
-// there we will return a meaningful "too many open files" error.
-func concurrentThreadsLimit() int {
- var rlim syscall.Rlimit
- if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rlim); err != nil {
- return 500
- }
- r := rlim.Cur
- if r > 500 {
- r = 500
- } else if r > 30 {
- r -= 30
- }
- return int(r)
-}
diff --git a/src/net/lookup_windows.go b/src/net/lookup_windows.go
index 33d5ac5fb4..3048f3269b 100644
--- a/src/net/lookup_windows.go
+++ b/src/net/lookup_windows.go
@@ -20,13 +20,17 @@ import (
const cgoAvailable = true
const (
+ _DNS_ERROR_RCODE_NAME_ERROR = syscall.Errno(9003)
+ _DNS_INFO_NO_RECORDS = syscall.Errno(9501)
+
_WSAHOST_NOT_FOUND = syscall.Errno(11001)
_WSATRY_AGAIN = syscall.Errno(11002)
+ _WSATYPE_NOT_FOUND = syscall.Errno(10109)
)
func winError(call string, err error) error {
switch err {
- case _WSAHOST_NOT_FOUND:
+ case _WSAHOST_NOT_FOUND, _DNS_ERROR_RCODE_NAME_ERROR, _DNS_INFO_NO_RECORDS:
return errNoSuchHost
}
return os.NewSyscallError(call, err)
@@ -91,19 +95,11 @@ func (r *Resolver) lookupHost(ctx context.Context, name string) ([]string, error
return addrs, nil
}
-// preferGoOverWindows reports whether the resolver should use the
-// pure Go implementation rather than making win32 calls to ask the
-// kernel for its answer.
-func (r *Resolver) preferGoOverWindows() bool {
- conf := systemConf()
- order, _ := conf.hostLookupOrder(r, "") // name is unused
- return order != hostLookupCgo
-}
-
func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr, error) {
- if r.preferGoOverWindows() {
- return r.goLookupIP(ctx, network, name)
+ if order, conf := systemConf().hostLookupOrder(r, name); order != hostLookupCgo {
+ return r.goLookupIP(ctx, network, name, order, conf)
}
+
// TODO(bradfitz,brainman): use ctx more. See TODO below.
var family int32 = syscall.AF_UNSPEC
@@ -200,37 +196,51 @@ func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr
}
func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int, error) {
- if r.preferGoOverWindows() {
+ if systemConf().mustUseGoResolver(r) {
return lookupPortMap(network, service)
}
// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
acquireThread()
defer releaseThread()
- var stype int32
+
+ var hints syscall.AddrinfoW
+
switch network {
- case "tcp4", "tcp6":
- stype = syscall.SOCK_STREAM
- case "udp4", "udp6":
- stype = syscall.SOCK_DGRAM
+ case "ip": // no hints
+ case "tcp", "tcp4", "tcp6":
+ hints.Socktype = syscall.SOCK_STREAM
+ hints.Protocol = syscall.IPPROTO_TCP
+ case "udp", "udp4", "udp6":
+ hints.Socktype = syscall.SOCK_DGRAM
+ hints.Protocol = syscall.IPPROTO_UDP
+ default:
+ return 0, &DNSError{Err: "unknown network", Name: network + "/" + service}
}
- hints := syscall.AddrinfoW{
- Family: syscall.AF_UNSPEC,
- Socktype: stype,
- Protocol: syscall.IPPROTO_IP,
+
+ switch ipVersion(network) {
+ case '4':
+ hints.Family = syscall.AF_INET
+ case '6':
+ hints.Family = syscall.AF_INET6
}
+
var result *syscall.AddrinfoW
e := syscall.GetAddrInfoW(nil, syscall.StringToUTF16Ptr(service), &hints, &result)
if e != nil {
if port, err := lookupPortMap(network, service); err == nil {
return port, nil
}
- err := winError("getaddrinfow", e)
- dnsError := &DNSError{Err: err.Error(), Name: network + "/" + service}
- if err == errNoSuchHost {
- dnsError.IsNotFound = true
+
+ // The _WSATYPE_NOT_FOUND error is returned by GetAddrInfoW
+ // when the service name is unknown. We are also checking
+ // for _WSAHOST_NOT_FOUND here to match the cgo (unix) version
+ // cgo_unix.go (cgoLookupServicePort).
+ if e == _WSATYPE_NOT_FOUND || e == _WSAHOST_NOT_FOUND {
+ return 0, &DNSError{Err: "unknown port", Name: network + "/" + service, IsNotFound: true}
}
- return 0, dnsError
+ err := os.NewSyscallError("getaddrinfow", e)
+ return 0, &DNSError{Err: err.Error(), Name: network + "/" + service}
}
defer syscall.FreeAddrInfoW(result)
if result == nil {
@@ -249,7 +259,7 @@ func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int
}
func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error) {
- if order, conf := systemConf().hostLookupOrder(r, ""); order != hostLookupCgo {
+ if order, conf := systemConf().hostLookupOrder(r, name); order != hostLookupCgo {
return r.goLookupCNAME(ctx, name, order, conf)
}
@@ -264,7 +274,8 @@ func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error)
return absDomainName(name), nil
}
if e != nil {
- return "", &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
+ err := winError("dnsquery", e)
+ return "", &DNSError{Err: err.Error(), Name: name, IsNotFound: err == errNoSuchHost}
}
defer syscall.DnsRecordListFree(rec, 1)
@@ -274,7 +285,7 @@ func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error)
}
func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
- if r.preferGoOverWindows() {
+ if systemConf().mustUseGoResolver(r) {
return r.goLookupSRV(ctx, service, proto, name)
}
// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
@@ -289,7 +300,8 @@ func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (
var rec *syscall.DNSRecord
e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &rec, nil)
if e != nil {
- return "", nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: target}
+ err := winError("dnsquery", e)
+ return "", nil, &DNSError{Err: err.Error(), Name: name, IsNotFound: err == errNoSuchHost}
}
defer syscall.DnsRecordListFree(rec, 1)
@@ -303,7 +315,7 @@ func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (
}
func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
- if r.preferGoOverWindows() {
+ if systemConf().mustUseGoResolver(r) {
return r.goLookupMX(ctx, name)
}
// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
@@ -312,7 +324,8 @@ func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
var rec *syscall.DNSRecord
e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &rec, nil)
if e != nil {
- return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
+ err := winError("dnsquery", e)
+ return nil, &DNSError{Err: err.Error(), Name: name, IsNotFound: err == errNoSuchHost}
}
defer syscall.DnsRecordListFree(rec, 1)
@@ -326,7 +339,7 @@ func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
}
func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
- if r.preferGoOverWindows() {
+ if systemConf().mustUseGoResolver(r) {
return r.goLookupNS(ctx, name)
}
// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
@@ -335,7 +348,8 @@ func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
var rec *syscall.DNSRecord
e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &rec, nil)
if e != nil {
- return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
+ err := winError("dnsquery", e)
+ return nil, &DNSError{Err: err.Error(), Name: name, IsNotFound: err == errNoSuchHost}
}
defer syscall.DnsRecordListFree(rec, 1)
@@ -348,7 +362,7 @@ func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
}
func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) {
- if r.preferGoOverWindows() {
+ if systemConf().mustUseGoResolver(r) {
return r.goLookupTXT(ctx, name)
}
// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
@@ -357,7 +371,8 @@ func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error)
var rec *syscall.DNSRecord
e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &rec, nil)
if e != nil {
- return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
+ err := winError("dnsquery", e)
+ return nil, &DNSError{Err: err.Error(), Name: name, IsNotFound: err == errNoSuchHost}
}
defer syscall.DnsRecordListFree(rec, 1)
@@ -374,7 +389,7 @@ func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error)
}
func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) {
- if order, conf := systemConf().hostLookupOrder(r, ""); order != hostLookupCgo {
+ if order, conf := systemConf().addrLookupOrder(r, addr); order != hostLookupCgo {
return r.goLookupPTR(ctx, addr, order, conf)
}
@@ -388,7 +403,8 @@ func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error
var rec *syscall.DNSRecord
e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &rec, nil)
if e != nil {
- return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: addr}
+ err := winError("dnsquery", e)
+ return nil, &DNSError{Err: err.Error(), Name: addr, IsNotFound: err == errNoSuchHost}
}
defer syscall.DnsRecordListFree(rec, 1)
diff --git a/src/net/mail/message.go b/src/net/mail/message.go
index af516fc30f..fc2a9e46f8 100644
--- a/src/net/mail/message.go
+++ b/src/net/mail/message.go
@@ -280,7 +280,7 @@ func (a *Address) String() string {
// Add quotes if needed
quoteLocal := false
for i, r := range local {
- if isAtext(r, false, false) {
+ if isAtext(r, false) {
continue
}
if r == '.' {
@@ -444,7 +444,7 @@ func (p *addrParser) parseAddress(handleGroup bool) ([]*Address, error) {
if !p.consume('<') {
atext := true
for _, r := range displayName {
- if !isAtext(r, true, false) {
+ if !isAtext(r, true) {
atext = false
break
}
@@ -479,7 +479,9 @@ func (p *addrParser) consumeGroupList() ([]*Address, error) {
// handle empty group.
p.skipSpace()
if p.consume(';') {
- p.skipCFWS()
+ if !p.skipCFWS() {
+ return nil, errors.New("mail: misformatted parenthetical comment")
+ }
return group, nil
}
@@ -496,7 +498,9 @@ func (p *addrParser) consumeGroupList() ([]*Address, error) {
return nil, errors.New("mail: misformatted parenthetical comment")
}
if p.consume(';') {
- p.skipCFWS()
+ if !p.skipCFWS() {
+ return nil, errors.New("mail: misformatted parenthetical comment")
+ }
break
}
if !p.consume(',') {
@@ -566,6 +570,12 @@ func (p *addrParser) consumePhrase() (phrase string, err error) {
var words []string
var isPrevEncoded bool
for {
+ // obs-phrase allows CFWS after one word
+ if len(words) > 0 {
+ if !p.skipCFWS() {
+ return "", errors.New("mail: misformatted parenthetical comment")
+ }
+ }
// word = atom / quoted-string
var word string
p.skipSpace()
@@ -661,7 +671,6 @@ Loop:
// If dot is true, consumeAtom parses an RFC 5322 dot-atom instead.
// If permissive is true, consumeAtom will not fail on:
// - leading/trailing/double dots in the atom (see golang.org/issue/4938)
-// - special characters (RFC 5322 3.2.3) except '<', '>', ':' and '"' (see golang.org/issue/21018)
func (p *addrParser) consumeAtom(dot bool, permissive bool) (atom string, err error) {
i := 0
@@ -672,7 +681,7 @@ Loop:
case size == 1 && r == utf8.RuneError:
return "", fmt.Errorf("mail: invalid utf-8 in address: %q", p.s)
- case size == 0 || !isAtext(r, dot, permissive):
+ case size == 0 || !isAtext(r, dot):
break Loop
default:
@@ -850,18 +859,13 @@ func (e charsetError) Error() string {
// isAtext reports whether r is an RFC 5322 atext character.
// If dot is true, period is included.
-// If permissive is true, RFC 5322 3.2.3 specials is included,
-// except '<', '>', ':' and '"'.
-func isAtext(r rune, dot, permissive bool) bool {
+func isAtext(r rune, dot bool) bool {
switch r {
case '.':
return dot
// RFC 5322 3.2.3. specials
- case '(', ')', '[', ']', ';', '@', '\\', ',':
- return permissive
-
- case '<', '>', '"', ':':
+ case '(', ')', '<', '>', '[', ']', ':', ';', '@', '\\', ',', '"': // RFC 5322 3.2.3. specials
return false
}
return isVchar(r)
diff --git a/src/net/mail/message_test.go b/src/net/mail/message_test.go
index 1e1bb4092f..1f2f62afbf 100644
--- a/src/net/mail/message_test.go
+++ b/src/net/mail/message_test.go
@@ -385,8 +385,11 @@ func TestAddressParsingError(t *testing.T) {
13: {"group not closed: null@example.com", "expected comma"},
14: {"group: first@example.com, second@example.com;", "group with multiple addresses"},
15: {"john.doe", "missing '@' or angle-addr"},
- 16: {"john.doe@", "no angle-addr"},
+ 16: {"john.doe@", "missing '@' or angle-addr"},
17: {"John Doe@foo.bar", "no angle-addr"},
+ 18: {" group: null@example.com; (asd", "misformatted parenthetical comment"},
+ 19: {" group: ; (asd", "misformatted parenthetical comment"},
+ 20: {`(John) Doe <jdoe@machine.example>`, "missing word in phrase:"},
}
for i, tc := range mustErrTestCases {
@@ -436,24 +439,19 @@ func TestAddressParsing(t *testing.T) {
Address: "john.q.public@example.com",
}},
},
- {
- `"John (middle) Doe" <jdoe@machine.example>`,
- []*Address{{
- Name: "John (middle) Doe",
- Address: "jdoe@machine.example",
- }},
- },
+ // Comment in display name
{
`John (middle) Doe <jdoe@machine.example>`,
[]*Address{{
- Name: "John (middle) Doe",
+ Name: "John Doe",
Address: "jdoe@machine.example",
}},
},
+ // Display name is quoted string, so comment is not a comment
{
- `John !@M@! Doe <jdoe@machine.example>`,
+ `"John (middle) Doe" <jdoe@machine.example>`,
[]*Address{{
- Name: "John !@M@! Doe",
+ Name: "John (middle) Doe",
Address: "jdoe@machine.example",
}},
},
@@ -788,6 +786,26 @@ func TestAddressParsing(t *testing.T) {
},
},
},
+ // Comment in group display name
+ {
+ `group (comment:): a@example.com, b@example.com;`,
+ []*Address{
+ {
+ Address: "a@example.com",
+ },
+ {
+ Address: "b@example.com",
+ },
+ },
+ },
+ {
+ `x(:"):"@a.example;("@b.example;`,
+ []*Address{
+ {
+ Address: `@a.example;(@b.example`,
+ },
+ },
+ },
}
for _, test := range tests {
if len(test.exp) == 1 {
diff --git a/src/net/main_conf_test.go b/src/net/main_conf_test.go
index 28a1cb8351..307ff5dd8c 100644
--- a/src/net/main_conf_test.go
+++ b/src/net/main_conf_test.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !plan9 && !wasip1
+//go:build !plan9
package net
diff --git a/src/net/main_noconf_test.go b/src/net/main_noconf_test.go
index 077a36e5d6..cdd7c54805 100644
--- a/src/net/main_noconf_test.go
+++ b/src/net/main_noconf_test.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build (js && wasm) || plan9 || wasip1
+//go:build plan9
package net
diff --git a/src/net/main_posix_test.go b/src/net/main_posix_test.go
index a7942ee327..24a2a55660 100644
--- a/src/net/main_posix_test.go
+++ b/src/net/main_posix_test.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !plan9 && !wasip1
+//go:build !plan9
package net
diff --git a/src/net/main_test.go b/src/net/main_test.go
index 9fd5c88543..7dc1e3ee0d 100644
--- a/src/net/main_test.go
+++ b/src/net/main_test.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !wasip1
-
package net
import (
@@ -16,6 +14,7 @@ import (
"strings"
"sync"
"testing"
+ "time"
)
var (
@@ -61,6 +60,20 @@ func TestMain(m *testing.M) {
os.Exit(st)
}
+// mustSetDeadline calls the bound method m to set a deadline on a Conn.
+// If the call fails, mustSetDeadline skips t if the current GOOS is believed
+// not to support deadlines, or fails the test otherwise.
+func mustSetDeadline(t testing.TB, m func(time.Time) error, d time.Duration) {
+ err := m(time.Now().Add(d))
+ if err != nil {
+ t.Helper()
+ if runtime.GOOS == "plan9" {
+ t.Skipf("skipping: %s does not support deadlines", runtime.GOOS)
+ }
+ t.Fatal(err)
+ }
+}
+
type ipv6LinkLocalUnicastTest struct {
network, address string
nameLookup bool
diff --git a/src/net/main_wasm_test.go b/src/net/main_wasm_test.go
new file mode 100644
index 0000000000..b8196bb283
--- /dev/null
+++ b/src/net/main_wasm_test.go
@@ -0,0 +1,13 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build wasip1 || js
+
+package net
+
+func installTestHooks() {}
+
+func uninstallTestHooks() {}
+
+func forceCloseSockets() {}
diff --git a/src/net/main_windows_test.go b/src/net/main_windows_test.go
index 07f21b72eb..bc024c0bbd 100644
--- a/src/net/main_windows_test.go
+++ b/src/net/main_windows_test.go
@@ -8,7 +8,6 @@ import "internal/poll"
var (
// Placeholders for saving original socket system calls.
- origSocket = socketFunc
origWSASocket = wsaSocketFunc
origClosesocket = poll.CloseFunc
origConnect = connectFunc
@@ -18,7 +17,6 @@ var (
)
func installTestHooks() {
- socketFunc = sw.Socket
wsaSocketFunc = sw.WSASocket
poll.CloseFunc = sw.Closesocket
connectFunc = sw.Connect
@@ -28,7 +26,6 @@ func installTestHooks() {
}
func uninstallTestHooks() {
- socketFunc = origSocket
wsaSocketFunc = origWSASocket
poll.CloseFunc = origClosesocket
connectFunc = origConnect
diff --git a/src/net/mockserver_test.go b/src/net/mockserver_test.go
index f86dd66a2d..46b2a57321 100644
--- a/src/net/mockserver_test.go
+++ b/src/net/mockserver_test.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !wasip1
-
package net
import (
@@ -339,6 +337,7 @@ func newLocalPacketListener(t testing.TB, network string, lcOpt ...*ListenConfig
return c
}
+ t.Helper()
switch network {
case "udp":
if supportsIPv4() {
@@ -359,7 +358,6 @@ func newLocalPacketListener(t testing.TB, network string, lcOpt ...*ListenConfig
return listenPacket(network, testUnixAddr(t))
}
- t.Helper()
t.Fatalf("%s is not supported", network)
return nil
}
diff --git a/src/net/net.go b/src/net/net.go
index 5cfc25ffca..c434c96bf8 100644
--- a/src/net/net.go
+++ b/src/net/net.go
@@ -8,8 +8,8 @@ TCP/IP, UDP, domain name resolution, and Unix domain sockets.
Although the package provides access to low-level networking
primitives, most clients will need only the basic interface provided
-by the Dial, Listen, and Accept functions and the associated
-Conn and Listener interfaces. The crypto/tls package uses
+by the [Dial], [Listen], and Accept functions and the associated
+[Conn] and [Listener] interfaces. The crypto/tls package uses
the same interfaces and similar Dial and Listen functions.
The Dial function connects to a server:
@@ -39,7 +39,7 @@ The Listen function creates servers:
# Name Resolution
The method for resolving domain names, whether indirectly with functions like Dial
-or directly with functions like LookupHost and LookupAddr, varies by operating system.
+or directly with functions like [LookupHost] and [LookupAddr], varies by operating system.
On Unix systems, the resolver has two options for resolving names.
It can use a pure Go resolver that sends DNS requests directly to the servers
@@ -95,8 +95,8 @@ import (
// Addr represents a network end point address.
//
-// The two methods Network and String conventionally return strings
-// that can be passed as the arguments to Dial, but the exact form
+// The two methods [Addr.Network] and [Addr.String] conventionally return strings
+// that can be passed as the arguments to [Dial], but the exact form
// and meaning of the strings is up to the implementation.
type Addr interface {
Network() string // name of the network (for example, "tcp", "udp")
@@ -284,7 +284,7 @@ func (c *conn) SetWriteBuffer(bytes int) error {
return nil
}
-// File returns a copy of the underlying os.File.
+// File returns a copy of the underlying [os.File].
// It is the caller's responsibility to close f when finished.
// Closing c does not affect f, and closing f does not affect c.
//
@@ -624,7 +624,11 @@ type DNSError struct {
Server string // server used
IsTimeout bool // if true, timed out; not all timeouts set this
IsTemporary bool // if true, error is temporary; not all errors set this
- IsNotFound bool // if true, host could not be found
+
+ // IsNotFound is set to true when the requested name does not
+ // contain any records of the requested type (data not found),
+ // or the name itself was not found (NXDOMAIN).
+ IsNotFound bool
}
func (e *DNSError) Error() string {
@@ -641,12 +645,12 @@ func (e *DNSError) Error() string {
// Timeout reports whether the DNS lookup is known to have timed out.
// This is not always known; a DNS lookup may fail due to a timeout
-// and return a DNSError for which Timeout returns false.
+// and return a [DNSError] for which Timeout returns false.
func (e *DNSError) Timeout() bool { return e.IsTimeout }
// Temporary reports whether the DNS error is known to be temporary.
// This is not always known; a DNS lookup may fail due to a temporary
-// error and return a DNSError for which Temporary returns false.
+// error and return a [DNSError] for which Temporary returns false.
func (e *DNSError) Temporary() bool { return e.IsTimeout || e.IsTemporary }
// errClosed exists just so that the docs for ErrClosed don't mention
@@ -660,15 +664,53 @@ var errClosed = poll.ErrNetClosing
// errors.Is(err, net.ErrClosed).
var ErrClosed error = errClosed
-type writerOnly struct {
- io.Writer
+// noReadFrom can be embedded alongside another type to
+// hide the ReadFrom method of that other type.
+type noReadFrom struct{}
+
+// ReadFrom hides another ReadFrom method.
+// It should never be called.
+func (noReadFrom) ReadFrom(io.Reader) (int64, error) {
+ panic("can't happen")
+}
+
+// tcpConnWithoutReadFrom implements all the methods of *TCPConn other
+// than ReadFrom. This is used to permit ReadFrom to call io.Copy
+// without leading to a recursive call to ReadFrom.
+type tcpConnWithoutReadFrom struct {
+ noReadFrom
+ *TCPConn
}
// Fallback implementation of io.ReaderFrom's ReadFrom, when sendfile isn't
// applicable.
-func genericReadFrom(w io.Writer, r io.Reader) (n int64, err error) {
+func genericReadFrom(c *TCPConn, r io.Reader) (n int64, err error) {
// Use wrapper to hide existing r.ReadFrom from io.Copy.
- return io.Copy(writerOnly{w}, r)
+ return io.Copy(tcpConnWithoutReadFrom{TCPConn: c}, r)
+}
+
+// noWriteTo can be embedded alongside another type to
+// hide the WriteTo method of that other type.
+type noWriteTo struct{}
+
+// WriteTo hides another WriteTo method.
+// It should never be called.
+func (noWriteTo) WriteTo(io.Writer) (int64, error) {
+ panic("can't happen")
+}
+
+// tcpConnWithoutWriteTo implements all the methods of *TCPConn other
+// than WriteTo. This is used to permit WriteTo to call io.Copy
+// without leading to a recursive call to WriteTo.
+type tcpConnWithoutWriteTo struct {
+ noWriteTo
+ *TCPConn
+}
+
+// Fallback implementation of io.WriterTo's WriteTo, when zero-copy isn't applicable.
+func genericWriteTo(c *TCPConn, w io.Writer) (n int64, err error) {
+ // Use wrapper to hide existing w.WriteTo from io.Copy.
+ return io.Copy(w, tcpConnWithoutWriteTo{TCPConn: c})
}
// Limit the number of concurrent cgo-using goroutines, because
@@ -714,7 +756,7 @@ var (
// WriteTo writes contents of the buffers to w.
//
-// WriteTo implements io.WriterTo for Buffers.
+// WriteTo implements [io.WriterTo] for [Buffers].
//
// WriteTo modifies the slice v as well as v[i] for 0 <= i < len(v),
// but does not modify v[i][j] for any i, j.
@@ -736,7 +778,7 @@ func (v *Buffers) WriteTo(w io.Writer) (n int64, err error) {
// Read from the buffers.
//
-// Read implements io.Reader for Buffers.
+// Read implements [io.Reader] for [Buffers].
//
// Read modifies the slice v as well as v[i] for 0 <= i < len(v),
// but does not modify v[i][j] for any i, j.
diff --git a/src/net/net_fake.go b/src/net/net_fake.go
index 908767a1f6..525ff32296 100644
--- a/src/net/net_fake.go
+++ b/src/net/net_fake.go
@@ -2,405 +2,1169 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Fake networking for js/wasm and wasip1/wasm. It is intended to allow tests of other package to pass.
+// Fake networking for js/wasm and wasip1/wasm.
+// It is intended to allow tests of other package to pass.
-//go:build (js && wasm) || wasip1
+//go:build js || wasip1
package net
import (
"context"
+ "errors"
"io"
"os"
+ "runtime"
"sync"
+ "sync/atomic"
"syscall"
"time"
)
-var listenersMu sync.Mutex
-var listeners = make(map[fakeNetAddr]*netFD)
-
-var portCounterMu sync.Mutex
-var portCounter = 0
+var (
+ sockets sync.Map // fakeSockAddr → *netFD
+ fakeSocketIDs sync.Map // fakeNetFD.id → *netFD
+ fakePorts sync.Map // int (port #) → *netFD
+ nextPortCounter atomic.Int32
+)
-func nextPort() int {
- portCounterMu.Lock()
- defer portCounterMu.Unlock()
- portCounter++
- return portCounter
-}
+const defaultBuffer = 65535
-type fakeNetAddr struct {
- network string
+type fakeSockAddr struct {
+ family int
address string
}
-type fakeNetFD struct {
- listener fakeNetAddr
- r *bufferedPipe
- w *bufferedPipe
- incoming chan *netFD
- closedMu sync.Mutex
- closed bool
+func fakeAddr(sa sockaddr) fakeSockAddr {
+ return fakeSockAddr{
+ family: sa.family(),
+ address: sa.String(),
+ }
}
// socket returns a network file descriptor that is ready for
-// asynchronous I/O using the network poller.
+// I/O using the fake network.
func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (*netFD, error) {
- fd := &netFD{family: family, sotype: sotype, net: net}
- if laddr != nil && raddr == nil {
- return fakelistener(fd, laddr)
+ if raddr != nil && ctrlCtxFn != nil {
+ return nil, os.NewSyscallError("socket", syscall.ENOTSUP)
}
- fd2 := &netFD{family: family, sotype: sotype, net: net}
- return fakeconn(fd, fd2, laddr, raddr)
-}
-
-func fakeIPAndPort(ip IP, port int) (IP, int) {
- if ip == nil {
- ip = IPv4(127, 0, 0, 1)
+ switch sotype {
+ case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET, syscall.SOCK_DGRAM:
+ default:
+ return nil, os.NewSyscallError("socket", syscall.ENOTSUP)
}
- if port == 0 {
- port = nextPort()
+
+ fd := &netFD{
+ family: family,
+ sotype: sotype,
+ net: net,
}
- return ip, port
-}
+ fd.fakeNetFD = newFakeNetFD(fd)
-func fakeTCPAddr(addr *TCPAddr) *TCPAddr {
- var ip IP
- var port int
- var zone string
- if addr != nil {
- ip, port, zone = addr.IP, addr.Port, addr.Zone
+ if raddr == nil {
+ if err := fakeListen(fd, laddr); err != nil {
+ fd.Close()
+ return nil, err
+ }
+ return fd, nil
}
- ip, port = fakeIPAndPort(ip, port)
- return &TCPAddr{IP: ip, Port: port, Zone: zone}
-}
-func fakeUDPAddr(addr *UDPAddr) *UDPAddr {
- var ip IP
- var port int
- var zone string
- if addr != nil {
- ip, port, zone = addr.IP, addr.Port, addr.Zone
+ if err := fakeConnect(ctx, fd, laddr, raddr); err != nil {
+ fd.Close()
+ return nil, err
}
- ip, port = fakeIPAndPort(ip, port)
- return &UDPAddr{IP: ip, Port: port, Zone: zone}
+ return fd, nil
}
-func fakeUnixAddr(sotype int, addr *UnixAddr) *UnixAddr {
- var net, name string
- if addr != nil {
- name = addr.Name
+func validateResolvedAddr(net string, family int, sa sockaddr) error {
+ validateIP := func(ip IP) error {
+ switch family {
+ case syscall.AF_INET:
+ if len(ip) != 4 {
+ return &AddrError{
+ Err: "non-IPv4 address",
+ Addr: ip.String(),
+ }
+ }
+ case syscall.AF_INET6:
+ if len(ip) != 16 {
+ return &AddrError{
+ Err: "non-IPv6 address",
+ Addr: ip.String(),
+ }
+ }
+ default:
+ panic("net: unexpected address family in validateResolvedAddr")
+ }
+ return nil
}
- switch sotype {
- case syscall.SOCK_DGRAM:
- net = "unixgram"
- case syscall.SOCK_SEQPACKET:
- net = "unixpacket"
+
+ switch net {
+ case "tcp", "tcp4", "tcp6":
+ sa, ok := sa.(*TCPAddr)
+ if !ok {
+ return &AddrError{
+ Err: "non-TCP address for " + net + " network",
+ Addr: sa.String(),
+ }
+ }
+ if err := validateIP(sa.IP); err != nil {
+ return err
+ }
+ if sa.Port <= 0 || sa.Port >= 1<<16 {
+ return &AddrError{
+ Err: "port out of range",
+ Addr: sa.String(),
+ }
+ }
+ return nil
+
+ case "udp", "udp4", "udp6":
+ sa, ok := sa.(*UDPAddr)
+ if !ok {
+ return &AddrError{
+ Err: "non-UDP address for " + net + " network",
+ Addr: sa.String(),
+ }
+ }
+ if err := validateIP(sa.IP); err != nil {
+ return err
+ }
+ if sa.Port <= 0 || sa.Port >= 1<<16 {
+ return &AddrError{
+ Err: "port out of range",
+ Addr: sa.String(),
+ }
+ }
+ return nil
+
+ case "unix", "unixgram", "unixpacket":
+ sa, ok := sa.(*UnixAddr)
+ if !ok {
+ return &AddrError{
+ Err: "non-Unix address for " + net + " network",
+ Addr: sa.String(),
+ }
+ }
+ if sa.Name != "" {
+ i := len(sa.Name) - 1
+ for i > 0 && !os.IsPathSeparator(sa.Name[i]) {
+ i--
+ }
+ for i > 0 && os.IsPathSeparator(sa.Name[i]) {
+ i--
+ }
+ if i <= 0 {
+ return &AddrError{
+ Err: "unix socket name missing path component",
+ Addr: sa.Name,
+ }
+ }
+ if _, err := os.Stat(sa.Name[:i+1]); err != nil {
+ return &AddrError{
+ Err: err.Error(),
+ Addr: sa.Name,
+ }
+ }
+ }
+ return nil
+
default:
- net = "unix"
+ return &AddrError{
+ Err: syscall.EAFNOSUPPORT.Error(),
+ Addr: sa.String(),
+ }
}
- return &UnixAddr{Net: net, Name: name}
}
-func fakelistener(fd *netFD, laddr sockaddr) (*netFD, error) {
- switch l := laddr.(type) {
+func matchIPFamily(family int, addr sockaddr) sockaddr {
+ convertIP := func(ip IP) IP {
+ switch family {
+ case syscall.AF_INET:
+ return ip.To4()
+ case syscall.AF_INET6:
+ return ip.To16()
+ default:
+ return ip
+ }
+ }
+
+ switch addr := addr.(type) {
case *TCPAddr:
- laddr = fakeTCPAddr(l)
+ ip := convertIP(addr.IP)
+ if ip == nil || len(ip) == len(addr.IP) {
+ return addr
+ }
+ return &TCPAddr{IP: ip, Port: addr.Port, Zone: addr.Zone}
case *UDPAddr:
- laddr = fakeUDPAddr(l)
- case *UnixAddr:
- if l.Name == "" {
- return nil, syscall.ENOENT
+ ip := convertIP(addr.IP)
+ if ip == nil || len(ip) == len(addr.IP) {
+ return addr
}
- laddr = fakeUnixAddr(fd.sotype, l)
+ return &UDPAddr{IP: ip, Port: addr.Port, Zone: addr.Zone}
default:
- return nil, syscall.EOPNOTSUPP
+ return addr
}
+}
- listener := fakeNetAddr{
- network: laddr.Network(),
- address: laddr.String(),
- }
+type fakeNetFD struct {
+ fd *netFD
+ assignedPort int // 0 if no port has been assigned for this socket
+
+ queue *packetQueue // incoming packets
+ peer *netFD // connected peer (for outgoing packets); nil for listeners and PacketConns
+ readDeadline atomic.Pointer[deadlineTimer]
+ writeDeadline atomic.Pointer[deadlineTimer]
+
+ fakeAddr fakeSockAddr // cached fakeSockAddr equivalent of fd.laddr
- fd.fakeNetFD = &fakeNetFD{
- listener: listener,
- incoming: make(chan *netFD, 1024),
+ // The incoming channels hold incoming connections that have not yet been accepted.
+ // All of these channels are 1-buffered.
+ incoming chan []*netFD // holds the queue when it has >0 but <SOMAXCONN pending connections; closed when the Listener is closed
+ incomingFull chan []*netFD // holds the queue when it has SOMAXCONN pending connections
+ incomingEmpty chan bool // holds true when the incoming queue is empty
+}
+
+func newFakeNetFD(fd *netFD) *fakeNetFD {
+ ffd := &fakeNetFD{fd: fd}
+ ffd.readDeadline.Store(newDeadlineTimer(noDeadline))
+ ffd.writeDeadline.Store(newDeadlineTimer(noDeadline))
+ return ffd
+}
+
+func (ffd *fakeNetFD) Read(p []byte) (n int, err error) {
+ n, _, err = ffd.queue.recvfrom(ffd.readDeadline.Load(), p, false, nil)
+ return n, err
+}
+
+func (ffd *fakeNetFD) Write(p []byte) (nn int, err error) {
+ peer := ffd.peer
+ if peer == nil {
+ if ffd.fd.raddr == nil {
+ return 0, os.NewSyscallError("write", syscall.ENOTCONN)
+ }
+ peeri, _ := sockets.Load(fakeAddr(ffd.fd.raddr.(sockaddr)))
+ if peeri == nil {
+ return 0, os.NewSyscallError("write", syscall.ECONNRESET)
+ }
+ peer = peeri.(*netFD)
+ if peer.queue == nil {
+ return 0, os.NewSyscallError("write", syscall.ECONNRESET)
+ }
}
- fd.laddr = laddr
- listenersMu.Lock()
- defer listenersMu.Unlock()
- if _, exists := listeners[listener]; exists {
- return nil, syscall.EADDRINUSE
+ if peer.fakeNetFD == nil {
+ return 0, os.NewSyscallError("write", syscall.EINVAL)
}
- listeners[listener] = fd
- return fd, nil
+ return peer.queue.write(ffd.writeDeadline.Load(), p, ffd.fd.laddr.(sockaddr))
}
-func fakeconn(fd *netFD, fd2 *netFD, laddr, raddr sockaddr) (*netFD, error) {
- switch r := raddr.(type) {
- case *TCPAddr:
- r = fakeTCPAddr(r)
- raddr = r
- laddr = fakeTCPAddr(laddr.(*TCPAddr))
- case *UDPAddr:
- r = fakeUDPAddr(r)
- raddr = r
- laddr = fakeUDPAddr(laddr.(*UDPAddr))
- case *UnixAddr:
- r = fakeUnixAddr(fd.sotype, r)
- raddr = r
- laddr = &UnixAddr{Net: r.Net, Name: r.Name}
- default:
- return nil, syscall.EAFNOSUPPORT
+func (ffd *fakeNetFD) Close() (err error) {
+ if ffd.fakeAddr != (fakeSockAddr{}) {
+ sockets.CompareAndDelete(ffd.fakeAddr, ffd.fd)
}
- fd.laddr = laddr
- fd.raddr = raddr
- fd.fakeNetFD = &fakeNetFD{
- r: newBufferedPipe(65536),
- w: newBufferedPipe(65536),
+ if ffd.queue != nil {
+ if closeErr := ffd.queue.closeRead(); err == nil {
+ err = closeErr
+ }
}
- fd2.fakeNetFD = &fakeNetFD{
- r: fd.fakeNetFD.w,
- w: fd.fakeNetFD.r,
+ if ffd.peer != nil {
+ if closeErr := ffd.peer.queue.closeWrite(); err == nil {
+ err = closeErr
+ }
}
+ ffd.readDeadline.Load().Reset(noDeadline)
+ ffd.writeDeadline.Load().Reset(noDeadline)
- fd2.laddr = fd.raddr
- fd2.raddr = fd.laddr
+ if ffd.incoming != nil {
+ var (
+ incoming []*netFD
+ ok bool
+ )
+ select {
+ case _, ok = <-ffd.incomingEmpty:
+ case incoming, ok = <-ffd.incoming:
+ case incoming, ok = <-ffd.incomingFull:
+ }
+ if ok {
+ // Sends on ffd.incoming require a receive first.
+ // Since we successfully received, no other goroutine may
+ // send on it at this point, and we may safely close it.
+ close(ffd.incoming)
- listener := fakeNetAddr{
- network: fd.raddr.Network(),
- address: fd.raddr.String(),
+ for _, c := range incoming {
+ c.Close()
+ }
+ }
}
- listenersMu.Lock()
- defer listenersMu.Unlock()
- l, ok := listeners[listener]
- if !ok {
- return nil, syscall.ECONNREFUSED
+
+ if ffd.assignedPort != 0 {
+ fakePorts.CompareAndDelete(ffd.assignedPort, ffd.fd)
}
- l.incoming <- fd2
- return fd, nil
+
+ return err
}
-func (fd *fakeNetFD) Read(p []byte) (n int, err error) {
- return fd.r.Read(p)
+func (ffd *fakeNetFD) closeRead() error {
+ return ffd.queue.closeRead()
}
-func (fd *fakeNetFD) Write(p []byte) (nn int, err error) {
- return fd.w.Write(p)
+func (ffd *fakeNetFD) closeWrite() error {
+ if ffd.peer == nil {
+ return os.NewSyscallError("closeWrite", syscall.ENOTCONN)
+ }
+ return ffd.peer.queue.closeWrite()
}
-func (fd *fakeNetFD) Close() error {
- fd.closedMu.Lock()
- if fd.closed {
- fd.closedMu.Unlock()
- return nil
+func (ffd *fakeNetFD) accept(laddr Addr) (*netFD, error) {
+ if ffd.incoming == nil {
+ return nil, os.NewSyscallError("accept", syscall.EINVAL)
}
- fd.closed = true
- fd.closedMu.Unlock()
- if fd.listener != (fakeNetAddr{}) {
- listenersMu.Lock()
- delete(listeners, fd.listener)
- close(fd.incoming)
- fd.listener = fakeNetAddr{}
- listenersMu.Unlock()
- return nil
+ var (
+ incoming []*netFD
+ ok bool
+ )
+ select {
+ case <-ffd.readDeadline.Load().expired:
+ return nil, os.ErrDeadlineExceeded
+ case incoming, ok = <-ffd.incoming:
+ if !ok {
+ return nil, ErrClosed
+ }
+ case incoming, ok = <-ffd.incomingFull:
}
- fd.r.Close()
- fd.w.Close()
- return nil
+ peer := incoming[0]
+ incoming = incoming[1:]
+ if len(incoming) == 0 {
+ ffd.incomingEmpty <- true
+ } else {
+ ffd.incoming <- incoming
+ }
+ return peer, nil
}
-func (fd *fakeNetFD) closeRead() error {
- fd.r.Close()
+func (ffd *fakeNetFD) SetDeadline(t time.Time) error {
+ err1 := ffd.SetReadDeadline(t)
+ err2 := ffd.SetWriteDeadline(t)
+ if err1 != nil {
+ return err1
+ }
+ return err2
+}
+
+func (ffd *fakeNetFD) SetReadDeadline(t time.Time) error {
+ dt := ffd.readDeadline.Load()
+ if !dt.Reset(t) {
+ ffd.readDeadline.Store(newDeadlineTimer(t))
+ }
return nil
}
-func (fd *fakeNetFD) closeWrite() error {
- fd.w.Close()
+func (ffd *fakeNetFD) SetWriteDeadline(t time.Time) error {
+ dt := ffd.writeDeadline.Load()
+ if !dt.Reset(t) {
+ ffd.writeDeadline.Store(newDeadlineTimer(t))
+ }
return nil
}
-func (fd *fakeNetFD) accept() (*netFD, error) {
- c, ok := <-fd.incoming
- if !ok {
- return nil, syscall.EINVAL
+const maxPacketSize = 65535
+
+type packet struct {
+ buf []byte
+ bufOffset int
+ next *packet
+ from sockaddr
+}
+
+func (p *packet) clear() {
+ p.buf = p.buf[:0]
+ p.bufOffset = 0
+ p.next = nil
+ p.from = nil
+}
+
+var packetPool = sync.Pool{
+ New: func() any { return new(packet) },
+}
+
+type packetQueueState struct {
+ head, tail *packet // unqueued packets
+ nBytes int // number of bytes enqueued in the packet buffers starting from head
+ readBufferBytes int // soft limit on nbytes; no more packets may be enqueued when the limit is exceeded
+ readClosed bool // true if the reader of the queue has stopped reading
+ writeClosed bool // true if the writer of the queue has stopped writing; the reader sees either io.EOF or syscall.ECONNRESET when they have read all buffered packets
+ noLinger bool // if true, the reader sees ECONNRESET instead of EOF
+}
+
+// A packetQueue is a set of 1-buffered channels implementing a FIFO queue
+// of packets.
+type packetQueue struct {
+ empty chan packetQueueState // contains configuration parameters when the queue is empty and not closed
+ ready chan packetQueueState // contains the packets when non-empty or closed
+ full chan packetQueueState // contains the packets when buffer is full and not closed
+}
+
+func newPacketQueue(readBufferBytes int) *packetQueue {
+ pq := &packetQueue{
+ empty: make(chan packetQueueState, 1),
+ ready: make(chan packetQueueState, 1),
+ full: make(chan packetQueueState, 1),
+ }
+ pq.put(packetQueueState{
+ readBufferBytes: readBufferBytes,
+ })
+ return pq
+}
+
+func (pq *packetQueue) get() packetQueueState {
+ var q packetQueueState
+ select {
+ case q = <-pq.empty:
+ case q = <-pq.ready:
+ case q = <-pq.full:
+ }
+ return q
+}
+
+func (pq *packetQueue) put(q packetQueueState) {
+ switch {
+ case q.readClosed || q.writeClosed:
+ pq.ready <- q
+ case q.nBytes >= q.readBufferBytes:
+ pq.full <- q
+ case q.head == nil:
+ if q.nBytes > 0 {
+ defer panic("net: put with nil packet list and nonzero nBytes")
+ }
+ pq.empty <- q
+ default:
+ pq.ready <- q
}
- return c, nil
}
-func (fd *fakeNetFD) SetDeadline(t time.Time) error {
- fd.r.SetReadDeadline(t)
- fd.w.SetWriteDeadline(t)
+func (pq *packetQueue) closeRead() error {
+ q := pq.get()
+
+ // Discard any unread packets.
+ for q.head != nil {
+ p := q.head
+ q.head = p.next
+ p.clear()
+ packetPool.Put(p)
+ }
+ q.nBytes = 0
+
+ q.readClosed = true
+ pq.put(q)
return nil
}
-func (fd *fakeNetFD) SetReadDeadline(t time.Time) error {
- fd.r.SetReadDeadline(t)
+func (pq *packetQueue) closeWrite() error {
+ q := pq.get()
+ q.writeClosed = true
+ pq.put(q)
return nil
}
-func (fd *fakeNetFD) SetWriteDeadline(t time.Time) error {
- fd.w.SetWriteDeadline(t)
+func (pq *packetQueue) setLinger(linger bool) error {
+ q := pq.get()
+ defer func() { pq.put(q) }()
+
+ if q.writeClosed {
+ return ErrClosed
+ }
+ q.noLinger = !linger
return nil
}
-func newBufferedPipe(softLimit int) *bufferedPipe {
- p := &bufferedPipe{softLimit: softLimit}
- p.rCond.L = &p.mu
- p.wCond.L = &p.mu
- return p
+func (pq *packetQueue) write(dt *deadlineTimer, b []byte, from sockaddr) (n int, err error) {
+ for {
+ dn := len(b)
+ if dn > maxPacketSize {
+ dn = maxPacketSize
+ }
+
+ dn, err = pq.send(dt, b[:dn], from, true)
+ n += dn
+ if err != nil {
+ return n, err
+ }
+
+ b = b[dn:]
+ if len(b) == 0 {
+ return n, nil
+ }
+ }
}
-type bufferedPipe struct {
- softLimit int
- mu sync.Mutex
- buf []byte
- closed bool
- rCond sync.Cond
- wCond sync.Cond
- rDeadline time.Time
- wDeadline time.Time
+func (pq *packetQueue) send(dt *deadlineTimer, b []byte, from sockaddr, block bool) (n int, err error) {
+ if from == nil {
+ return 0, os.NewSyscallError("send", syscall.EINVAL)
+ }
+ if len(b) > maxPacketSize {
+ return 0, os.NewSyscallError("send", syscall.EMSGSIZE)
+ }
+
+ var q packetQueueState
+ var full chan packetQueueState
+ if !block {
+ full = pq.full
+ }
+
+ // Before we check dt.expired, yield to other goroutines.
+ // This may help to prevent starvation of the goroutine that runs the
+ // deadlineTimer's time.After callback.
+ //
+ // TODO(#65178): Remove this when the runtime scheduler no longer starves
+ // runnable goroutines.
+ runtime.Gosched()
+
+ select {
+ case <-dt.expired:
+ return 0, os.ErrDeadlineExceeded
+
+ case q = <-full:
+ pq.put(q)
+ return 0, os.NewSyscallError("send", syscall.ENOBUFS)
+
+ case q = <-pq.empty:
+ case q = <-pq.ready:
+ }
+ defer func() { pq.put(q) }()
+
+ // Don't allow a packet to be sent if the deadline has expired,
+ // even if the select above chose a different branch.
+ select {
+ case <-dt.expired:
+ return 0, os.ErrDeadlineExceeded
+ default:
+ }
+ if q.writeClosed {
+ return 0, ErrClosed
+ } else if q.readClosed {
+ return 0, os.NewSyscallError("send", syscall.ECONNRESET)
+ }
+
+ p := packetPool.Get().(*packet)
+ p.buf = append(p.buf[:0], b...)
+ p.from = from
+
+ if q.head == nil {
+ q.head = p
+ } else {
+ q.tail.next = p
+ }
+ q.tail = p
+ q.nBytes += len(p.buf)
+
+ return len(b), nil
}
-func (p *bufferedPipe) Read(b []byte) (int, error) {
- p.mu.Lock()
- defer p.mu.Unlock()
+func (pq *packetQueue) recvfrom(dt *deadlineTimer, b []byte, wholePacket bool, checkFrom func(sockaddr) error) (n int, from sockaddr, err error) {
+ var q packetQueueState
+ var empty chan packetQueueState
+ if len(b) == 0 {
+ // For consistency with the implementation on Unix platforms,
+ // allow a zero-length Read to proceed if the queue is empty.
+ // (Without this, TestZeroByteRead deadlocks.)
+ empty = pq.empty
+ }
- for {
- if p.closed && len(p.buf) == 0 {
- return 0, io.EOF
- }
- if !p.rDeadline.IsZero() {
- d := time.Until(p.rDeadline)
- if d <= 0 {
- return 0, os.ErrDeadlineExceeded
+ // Before we check dt.expired, yield to other goroutines.
+ // This may help to prevent starvation of the goroutine that runs the
+ // deadlineTimer's time.After callback.
+ //
+ // TODO(#65178): Remove this when the runtime scheduler no longer starves
+ // runnable goroutines.
+ runtime.Gosched()
+
+ select {
+ case <-dt.expired:
+ return 0, nil, os.ErrDeadlineExceeded
+ case q = <-empty:
+ case q = <-pq.ready:
+ case q = <-pq.full:
+ }
+ defer func() { pq.put(q) }()
+
+ p := q.head
+ if p == nil {
+ switch {
+ case q.readClosed:
+ return 0, nil, ErrClosed
+ case q.writeClosed:
+ if q.noLinger {
+ return 0, nil, os.NewSyscallError("recvfrom", syscall.ECONNRESET)
}
- time.AfterFunc(d, p.rCond.Broadcast)
+ return 0, nil, io.EOF
+ case len(b) == 0:
+ return 0, nil, nil
+ default:
+ // This should be impossible: pq.full should only contain a non-empty list,
+ // pq.ready should either contain a non-empty list or indicate that the
+ // connection is closed, and we should only receive from pq.empty if
+ // len(b) == 0.
+ panic("net: nil packet list from non-closed packetQueue")
}
- if len(p.buf) > 0 {
- break
+ }
+
+ select {
+ case <-dt.expired:
+ return 0, nil, os.ErrDeadlineExceeded
+ default:
+ }
+
+ if checkFrom != nil {
+ if err := checkFrom(p.from); err != nil {
+ return 0, nil, err
}
- p.rCond.Wait()
}
- n := copy(b, p.buf)
- p.buf = p.buf[n:]
- p.wCond.Broadcast()
- return n, nil
+ n = copy(b, p.buf[p.bufOffset:])
+ from = p.from
+ if wholePacket || p.bufOffset+n == len(p.buf) {
+ q.head = p.next
+ q.nBytes -= len(p.buf)
+ p.clear()
+ packetPool.Put(p)
+ } else {
+ p.bufOffset += n
+ }
+
+ return n, from, nil
+}
+
+// setReadBuffer sets a soft limit on the number of bytes available to read
+// from the pipe.
+func (pq *packetQueue) setReadBuffer(bytes int) error {
+ if bytes <= 0 {
+ return os.NewSyscallError("setReadBuffer", syscall.EINVAL)
+ }
+ q := pq.get() // Use the queue as a lock.
+ q.readBufferBytes = bytes
+ pq.put(q)
+ return nil
+}
+
+type deadlineTimer struct {
+ timer chan *time.Timer
+ expired chan struct{}
}
-func (p *bufferedPipe) Write(b []byte) (int, error) {
- p.mu.Lock()
- defer p.mu.Unlock()
+func newDeadlineTimer(deadline time.Time) *deadlineTimer {
+ dt := &deadlineTimer{
+ timer: make(chan *time.Timer, 1),
+ expired: make(chan struct{}),
+ }
+ dt.timer <- nil
+ dt.Reset(deadline)
+ return dt
+}
- for {
- if p.closed {
- return 0, syscall.ENOTCONN
+// Reset attempts to reset the timer.
+// If the timer has already expired, Reset returns false.
+func (dt *deadlineTimer) Reset(deadline time.Time) bool {
+ timer := <-dt.timer
+ defer func() { dt.timer <- timer }()
+
+ if deadline.Equal(noDeadline) {
+ if timer != nil && timer.Stop() {
+ timer = nil
}
- if !p.wDeadline.IsZero() {
- d := time.Until(p.wDeadline)
- if d <= 0 {
- return 0, os.ErrDeadlineExceeded
+ return timer == nil
+ }
+
+ d := time.Until(deadline)
+ if d < 0 {
+ // Ensure that a deadline in the past takes effect immediately.
+ defer func() { <-dt.expired }()
+ }
+
+ if timer == nil {
+ timer = time.AfterFunc(d, func() { close(dt.expired) })
+ return true
+ }
+ if !timer.Stop() {
+ return false
+ }
+ timer.Reset(d)
+ return true
+}
+
+func sysSocket(family, sotype, proto int) (int, error) {
+ return 0, os.NewSyscallError("sysSocket", syscall.ENOSYS)
+}
+
+func fakeListen(fd *netFD, laddr sockaddr) (err error) {
+ wrapErr := func(err error) error {
+ if errno, ok := err.(syscall.Errno); ok {
+ err = os.NewSyscallError("listen", errno)
+ }
+ if errors.Is(err, syscall.EADDRINUSE) {
+ return err
+ }
+ if laddr != nil {
+ if _, ok := err.(*AddrError); !ok {
+ err = &AddrError{
+ Err: err.Error(),
+ Addr: laddr.String(),
+ }
}
- time.AfterFunc(d, p.wCond.Broadcast)
}
- if len(p.buf) <= p.softLimit {
- break
+ return err
+ }
+
+ ffd := newFakeNetFD(fd)
+ defer func() {
+ if fd.fakeNetFD != ffd {
+ // Failed to register listener; clean up.
+ ffd.Close()
}
- p.wCond.Wait()
+ }()
+
+ if err := ffd.assignFakeAddr(matchIPFamily(fd.family, laddr)); err != nil {
+ return wrapErr(err)
}
- p.buf = append(p.buf, b...)
- p.rCond.Broadcast()
- return len(b), nil
-}
+ ffd.fakeAddr = fakeAddr(fd.laddr.(sockaddr))
+ switch fd.sotype {
+ case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET:
+ ffd.incoming = make(chan []*netFD, 1)
+ ffd.incomingFull = make(chan []*netFD, 1)
+ ffd.incomingEmpty = make(chan bool, 1)
+ ffd.incomingEmpty <- true
+ case syscall.SOCK_DGRAM:
+ ffd.queue = newPacketQueue(defaultBuffer)
+ default:
+ return wrapErr(syscall.EINVAL)
+ }
-func (p *bufferedPipe) Close() {
- p.mu.Lock()
- defer p.mu.Unlock()
+ fd.fakeNetFD = ffd
+ if _, dup := sockets.LoadOrStore(ffd.fakeAddr, fd); dup {
+ fd.fakeNetFD = nil
+ return wrapErr(syscall.EADDRINUSE)
+ }
- p.closed = true
- p.rCond.Broadcast()
- p.wCond.Broadcast()
+ return nil
}
-func (p *bufferedPipe) SetReadDeadline(t time.Time) {
- p.mu.Lock()
- defer p.mu.Unlock()
+func fakeConnect(ctx context.Context, fd *netFD, laddr, raddr sockaddr) error {
+ wrapErr := func(err error) error {
+ if errno, ok := err.(syscall.Errno); ok {
+ err = os.NewSyscallError("connect", errno)
+ }
+ if errors.Is(err, syscall.EADDRINUSE) {
+ return err
+ }
+ if terr, ok := err.(interface{ Timeout() bool }); !ok || !terr.Timeout() {
+ // For consistency with the net implementation on other platforms,
+ // if we don't need to preserve the Timeout-ness of err we should
+ // wrap it in an AddrError. (Unfortunately we can't wrap errors
+ // that convey structured information, because AddrError reduces
+ // the wrapped Err to a flat string.)
+ if _, ok := err.(*AddrError); !ok {
+ err = &AddrError{
+ Err: err.Error(),
+ Addr: raddr.String(),
+ }
+ }
+ }
+ return err
+ }
+
+ if fd.isConnected {
+ return wrapErr(syscall.EISCONN)
+ }
+ if ctx.Err() != nil {
+ return wrapErr(syscall.ETIMEDOUT)
+ }
+
+ fd.raddr = matchIPFamily(fd.family, raddr)
+ if err := validateResolvedAddr(fd.net, fd.family, fd.raddr.(sockaddr)); err != nil {
+ return wrapErr(err)
+ }
+
+ if err := fd.fakeNetFD.assignFakeAddr(laddr); err != nil {
+ return wrapErr(err)
+ }
+ fd.fakeNetFD.queue = newPacketQueue(defaultBuffer)
+
+ switch fd.sotype {
+ case syscall.SOCK_DGRAM:
+ if ua, ok := fd.laddr.(*UnixAddr); !ok || ua.Name != "" {
+ fd.fakeNetFD.fakeAddr = fakeAddr(fd.laddr.(sockaddr))
+ if _, dup := sockets.LoadOrStore(fd.fakeNetFD.fakeAddr, fd); dup {
+ return wrapErr(syscall.EADDRINUSE)
+ }
+ }
+ fd.isConnected = true
+ return nil
+
+ case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET:
+ default:
+ return wrapErr(syscall.EINVAL)
+ }
+
+ fa := fakeAddr(raddr)
+ lni, ok := sockets.Load(fa)
+ if !ok {
+ return wrapErr(syscall.ECONNREFUSED)
+ }
+ ln := lni.(*netFD)
+ if ln.sotype != fd.sotype {
+ return wrapErr(syscall.EPROTOTYPE)
+ }
+ if ln.incoming == nil {
+ return wrapErr(syscall.ECONNREFUSED)
+ }
+
+ peer := &netFD{
+ family: ln.family,
+ sotype: ln.sotype,
+ net: ln.net,
+ laddr: ln.laddr,
+ raddr: fd.laddr,
+ isConnected: true,
+ }
+ peer.fakeNetFD = newFakeNetFD(fd)
+ peer.fakeNetFD.queue = newPacketQueue(defaultBuffer)
+ defer func() {
+ if fd.peer != peer {
+ // Failed to connect; clean up.
+ peer.Close()
+ }
+ }()
+
+ var incoming []*netFD
+ select {
+ case <-ctx.Done():
+ return wrapErr(syscall.ETIMEDOUT)
+ case ok = <-ln.incomingEmpty:
+ case incoming, ok = <-ln.incoming:
+ }
+ if !ok {
+ return wrapErr(syscall.ECONNREFUSED)
+ }
+
+ fd.isConnected = true
+ fd.peer = peer
+ peer.peer = fd
- p.rDeadline = t
- p.rCond.Broadcast()
+ incoming = append(incoming, peer)
+ if len(incoming) >= listenerBacklog() {
+ ln.incomingFull <- incoming
+ } else {
+ ln.incoming <- incoming
+ }
+ return nil
}
-func (p *bufferedPipe) SetWriteDeadline(t time.Time) {
- p.mu.Lock()
- defer p.mu.Unlock()
+func (ffd *fakeNetFD) assignFakeAddr(addr sockaddr) error {
+ validate := func(sa sockaddr) error {
+ if err := validateResolvedAddr(ffd.fd.net, ffd.fd.family, sa); err != nil {
+ return err
+ }
+ ffd.fd.laddr = sa
+ return nil
+ }
+
+ assignIP := func(addr sockaddr) error {
+ var (
+ ip IP
+ port int
+ zone string
+ )
+ switch addr := addr.(type) {
+ case *TCPAddr:
+ if addr != nil {
+ ip = addr.IP
+ port = addr.Port
+ zone = addr.Zone
+ }
+ case *UDPAddr:
+ if addr != nil {
+ ip = addr.IP
+ port = addr.Port
+ zone = addr.Zone
+ }
+ default:
+ return validate(addr)
+ }
+
+ if ip == nil {
+ ip = IPv4(127, 0, 0, 1)
+ }
+ switch ffd.fd.family {
+ case syscall.AF_INET:
+ if ip4 := ip.To4(); ip4 != nil {
+ ip = ip4
+ }
+ case syscall.AF_INET6:
+ if ip16 := ip.To16(); ip16 != nil {
+ ip = ip16
+ }
+ }
+ if ip == nil {
+ return syscall.EINVAL
+ }
+
+ if port == 0 {
+ var prevPort int32
+ portWrapped := false
+ nextPort := func() (int, bool) {
+ for {
+ port := nextPortCounter.Add(1)
+ if port <= 0 || port >= 1<<16 {
+ // nextPortCounter ran off the end of the port space.
+ // Bump it back into range.
+ for {
+ if nextPortCounter.CompareAndSwap(port, 0) {
+ break
+ }
+ if port = nextPortCounter.Load(); port >= 0 && port < 1<<16 {
+ break
+ }
+ }
+ if portWrapped {
+ // This is the second wraparound, so we've scanned the whole port space
+ // at least once already and it's time to give up.
+ return 0, false
+ }
+ portWrapped = true
+ prevPort = 0
+ continue
+ }
+
+ if port <= prevPort {
+ // nextPortCounter has wrapped around since the last time we read it.
+ if portWrapped {
+ // This is the second wraparound, so we've scanned the whole port space
+ // at least once already and it's time to give up.
+ return 0, false
+ } else {
+ portWrapped = true
+ }
+ }
+
+ prevPort = port
+ return int(port), true
+ }
+ }
+
+ for {
+ var ok bool
+ port, ok = nextPort()
+ if !ok {
+ ffd.assignedPort = 0
+ return syscall.EADDRINUSE
+ }
+
+ ffd.assignedPort = int(port)
+ if _, dup := fakePorts.LoadOrStore(ffd.assignedPort, ffd.fd); !dup {
+ break
+ }
+ }
+ }
+
+ switch addr.(type) {
+ case *TCPAddr:
+ return validate(&TCPAddr{IP: ip, Port: port, Zone: zone})
+ case *UDPAddr:
+ return validate(&UDPAddr{IP: ip, Port: port, Zone: zone})
+ default:
+ panic("unreachable")
+ }
+ }
+
+ switch ffd.fd.net {
+ case "tcp", "tcp4", "tcp6":
+ if addr == nil {
+ return assignIP(new(TCPAddr))
+ }
+ return assignIP(addr)
+
+ case "udp", "udp4", "udp6":
+ if addr == nil {
+ return assignIP(new(UDPAddr))
+ }
+ return assignIP(addr)
+
+ case "unix", "unixgram", "unixpacket":
+ uaddr, ok := addr.(*UnixAddr)
+ if !ok && addr != nil {
+ return &AddrError{
+ Err: "non-Unix address for " + ffd.fd.net + " network",
+ Addr: addr.String(),
+ }
+ }
+ if uaddr == nil {
+ return validate(&UnixAddr{Net: ffd.fd.net})
+ }
+ return validate(&UnixAddr{Net: ffd.fd.net, Name: uaddr.Name})
- p.wDeadline = t
- p.wCond.Broadcast()
+ default:
+ return &AddrError{
+ Err: syscall.EAFNOSUPPORT.Error(),
+ Addr: addr.String(),
+ }
+ }
}
-func sysSocket(family, sotype, proto int) (int, error) {
- return 0, syscall.ENOSYS
+func (ffd *fakeNetFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
+ if ffd.queue == nil {
+ return 0, nil, os.NewSyscallError("readFrom", syscall.EINVAL)
+ }
+
+ n, from, err := ffd.queue.recvfrom(ffd.readDeadline.Load(), p, true, nil)
+
+ if from != nil {
+ // Convert the net.sockaddr to a syscall.Sockaddr type.
+ var saErr error
+ sa, saErr = from.sockaddr(ffd.fd.family)
+ if err == nil {
+ err = saErr
+ }
+ }
+
+ return n, sa, err
}
-func (fd *fakeNetFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (syscall.Sockaddr, error) {
- return nil, syscall.ENOSYS
+func (ffd *fakeNetFD) readFromInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) {
+ n, _, err = ffd.queue.recvfrom(ffd.readDeadline.Load(), p, true, func(from sockaddr) error {
+ fromSA, err := from.sockaddr(syscall.AF_INET)
+ if err != nil {
+ return err
+ }
+ if fromSA == nil {
+ return os.NewSyscallError("readFromInet4", syscall.EINVAL)
+ }
+ *sa = *(fromSA.(*syscall.SockaddrInet4))
+ return nil
+ })
+ return n, err
}
-func (fd *fakeNetFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
- return 0, nil, syscall.ENOSYS
+func (ffd *fakeNetFD) readFromInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) {
+ n, _, err = ffd.queue.recvfrom(ffd.readDeadline.Load(), p, true, func(from sockaddr) error {
+ fromSA, err := from.sockaddr(syscall.AF_INET6)
+ if err != nil {
+ return err
+ }
+ if fromSA == nil {
+ return os.NewSyscallError("readFromInet6", syscall.EINVAL)
+ }
+ *sa = *(fromSA.(*syscall.SockaddrInet6))
+ return nil
+ })
+ return n, err
+}
+func (ffd *fakeNetFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) {
+ if flags != 0 {
+ return 0, 0, 0, nil, os.NewSyscallError("readMsg", syscall.ENOTSUP)
+ }
+ n, sa, err = ffd.readFrom(p)
+ return n, 0, 0, sa, err
}
-func (fd *fakeNetFD) readFromInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) {
- return 0, syscall.ENOSYS
+
+func (ffd *fakeNetFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet4) (n, oobn, retflags int, err error) {
+ if flags != 0 {
+ return 0, 0, 0, os.NewSyscallError("readMsgInet4", syscall.ENOTSUP)
+ }
+ n, err = ffd.readFromInet4(p, sa)
+ return n, 0, 0, err
}
-func (fd *fakeNetFD) readFromInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) {
- return 0, syscall.ENOSYS
+func (ffd *fakeNetFD) readMsgInet6(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet6) (n, oobn, retflags int, err error) {
+ if flags != 0 {
+ return 0, 0, 0, os.NewSyscallError("readMsgInet6", syscall.ENOTSUP)
+ }
+ n, err = ffd.readFromInet6(p, sa)
+ return n, 0, 0, err
}
-func (fd *fakeNetFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) {
- return 0, 0, 0, nil, syscall.ENOSYS
+func (ffd *fakeNetFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
+ if len(oob) > 0 {
+ return 0, 0, os.NewSyscallError("writeMsg", syscall.ENOTSUP)
+ }
+ n, err = ffd.writeTo(p, sa)
+ return n, 0, err
}
-func (fd *fakeNetFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet4) (n, oobn, retflags int, err error) {
- return 0, 0, 0, syscall.ENOSYS
+func (ffd *fakeNetFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (n int, oobn int, err error) {
+ return ffd.writeMsg(p, oob, sa)
}
-func (fd *fakeNetFD) readMsgInet6(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet6) (n, oobn, retflags int, err error) {
- return 0, 0, 0, syscall.ENOSYS
+func (ffd *fakeNetFD) writeMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (n int, oobn int, err error) {
+ return ffd.writeMsg(p, oob, sa)
}
-func (fd *fakeNetFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (n int, oobn int, err error) {
- return 0, 0, syscall.ENOSYS
+func (ffd *fakeNetFD) writeTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
+ raddr := ffd.fd.raddr
+ if sa != nil {
+ if ffd.fd.isConnected {
+ return 0, os.NewSyscallError("writeTo", syscall.EISCONN)
+ }
+ raddr = ffd.fd.addrFunc()(sa)
+ }
+ if raddr == nil {
+ return 0, os.NewSyscallError("writeTo", syscall.EINVAL)
+ }
+
+ peeri, _ := sockets.Load(fakeAddr(raddr.(sockaddr)))
+ if peeri == nil {
+ if len(ffd.fd.net) >= 3 && ffd.fd.net[:3] == "udp" {
+ return len(p), nil
+ }
+ return 0, os.NewSyscallError("writeTo", syscall.ECONNRESET)
+ }
+ peer := peeri.(*netFD)
+ if peer.queue == nil {
+ if len(ffd.fd.net) >= 3 && ffd.fd.net[:3] == "udp" {
+ return len(p), nil
+ }
+ return 0, os.NewSyscallError("writeTo", syscall.ECONNRESET)
+ }
+
+ block := true
+ if len(ffd.fd.net) >= 3 && ffd.fd.net[:3] == "udp" {
+ block = false
+ }
+ return peer.queue.send(ffd.writeDeadline.Load(), p, ffd.fd.laddr.(sockaddr), block)
}
-func (fd *fakeNetFD) writeMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (n int, oobn int, err error) {
- return 0, 0, syscall.ENOSYS
+func (ffd *fakeNetFD) writeToInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) {
+ return ffd.writeTo(p, sa)
}
-func (fd *fakeNetFD) writeTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
- return 0, syscall.ENOSYS
+func (ffd *fakeNetFD) writeToInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) {
+ return ffd.writeTo(p, sa)
}
-func (fd *fakeNetFD) writeToInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) {
- return 0, syscall.ENOSYS
+func (ffd *fakeNetFD) dup() (f *os.File, err error) {
+ return nil, os.NewSyscallError("dup", syscall.ENOSYS)
}
-func (fd *fakeNetFD) writeToInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) {
- return 0, syscall.ENOSYS
+func (ffd *fakeNetFD) setReadBuffer(bytes int) error {
+ if ffd.queue == nil {
+ return os.NewSyscallError("setReadBuffer", syscall.EINVAL)
+ }
+ ffd.queue.setReadBuffer(bytes)
+ return nil
}
-func (fd *fakeNetFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
- return 0, 0, syscall.ENOSYS
+func (ffd *fakeNetFD) setWriteBuffer(bytes int) error {
+ return os.NewSyscallError("setWriteBuffer", syscall.ENOTSUP)
}
-func (fd *fakeNetFD) dup() (f *os.File, err error) {
- return nil, syscall.ENOSYS
+func (ffd *fakeNetFD) setLinger(sec int) error {
+ if sec < 0 || ffd.peer == nil {
+ return os.NewSyscallError("setLinger", syscall.EINVAL)
+ }
+ ffd.peer.queue.setLinger(sec > 0)
+ return nil
}
diff --git a/src/net/net_fake_js.go b/src/net/net_fake_js.go
deleted file mode 100644
index 7ba108b664..0000000000
--- a/src/net/net_fake_js.go
+++ /dev/null
@@ -1,36 +0,0 @@
-// Copyright 2023 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// Fake networking for js/wasm. It is intended to allow tests of other package to pass.
-
-//go:build js && wasm
-
-package net
-
-import (
- "context"
- "internal/poll"
-
- "golang.org/x/net/dns/dnsmessage"
-)
-
-// Network file descriptor.
-type netFD struct {
- *fakeNetFD
-
- // immutable until Close
- family int
- sotype int
- net string
- laddr Addr
- raddr Addr
-
- // unused
- pfd poll.FD
- isConnected bool // handshake completed or use of association with peer
-}
-
-func (r *Resolver) lookup(ctx context.Context, name string, qtype dnsmessage.Type, conf *dnsConfig) (dnsmessage.Parser, string, error) {
- panic("unreachable")
-}
diff --git a/src/net/net_fake_test.go b/src/net/net_fake_test.go
index 783304d531..4542228fbc 100644
--- a/src/net/net_fake_test.go
+++ b/src/net/net_fake_test.go
@@ -13,191 +13,95 @@ package net
// The tests in this files are intended to validate the behavior of the fake
// network stack on these platforms.
-import "testing"
-
-func TestFakeConn(t *testing.T) {
- tests := []struct {
- name string
- listen func() (Listener, error)
- dial func(Addr) (Conn, error)
- addr func(*testing.T, Addr)
- }{
- {
- name: "Listener:tcp",
- listen: func() (Listener, error) {
- return Listen("tcp", ":0")
- },
- dial: func(addr Addr) (Conn, error) {
- return Dial(addr.Network(), addr.String())
- },
- addr: testFakeTCPAddr,
- },
-
- {
- name: "ListenTCP:tcp",
- listen: func() (Listener, error) {
- // Creating a listening TCP connection with a nil address must
- // select an IP address on localhost with a random port.
- // This test verifies that the fake network facility does that.
- return ListenTCP("tcp", nil)
- },
- dial: func(addr Addr) (Conn, error) {
- // Connecting a listening TCP connection will select a local
- // address on the local network and connects to the destination
- // address.
- return DialTCP("tcp", nil, addr.(*TCPAddr))
- },
- addr: testFakeTCPAddr,
- },
-
- {
- name: "ListenUnix:unix",
- listen: func() (Listener, error) {
- return ListenUnix("unix", &UnixAddr{Name: "test"})
- },
- dial: func(addr Addr) (Conn, error) {
- return DialUnix("unix", nil, addr.(*UnixAddr))
- },
- addr: testFakeUnixAddr("unix", "test"),
- },
-
- {
- name: "ListenUnix:unixpacket",
- listen: func() (Listener, error) {
- return ListenUnix("unixpacket", &UnixAddr{Name: "test"})
- },
- dial: func(addr Addr) (Conn, error) {
- return DialUnix("unixpacket", nil, addr.(*UnixAddr))
- },
- addr: testFakeUnixAddr("unixpacket", "test"),
- },
+import (
+ "errors"
+ "syscall"
+ "testing"
+)
+
+func TestFakePortExhaustion(t *testing.T) {
+ if testing.Short() {
+ t.Skipf("skipping test that opens 1<<16 connections")
}
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- l, err := test.listen()
- if err != nil {
- t.Fatal(err)
+ ln := newLocalListener(t, "tcp")
+ done := make(chan struct{})
+ go func() {
+ var accepted []Conn
+ defer func() {
+ for _, c := range accepted {
+ c.Close()
}
- defer l.Close()
- test.addr(t, l.Addr())
+ close(done)
+ }()
- c, err := test.dial(l.Addr())
+ for {
+ c, err := ln.Accept()
if err != nil {
- t.Fatal(err)
+ return
}
- defer c.Close()
- test.addr(t, c.LocalAddr())
- test.addr(t, c.RemoteAddr())
- })
- }
-}
-
-func TestFakePacketConn(t *testing.T) {
- tests := []struct {
- name string
- listen func() (PacketConn, error)
- dial func(Addr) (Conn, error)
- addr func(*testing.T, Addr)
- }{
- {
- name: "ListenPacket:udp",
- listen: func() (PacketConn, error) {
- return ListenPacket("udp", ":0")
- },
- dial: func(addr Addr) (Conn, error) {
- return Dial(addr.Network(), addr.String())
- },
- addr: testFakeUDPAddr,
- },
-
- {
- name: "ListenUDP:udp",
- listen: func() (PacketConn, error) {
- // Creating a listening UDP connection with a nil address must
- // select an IP address on localhost with a random port.
- // This test verifies that the fake network facility does that.
- return ListenUDP("udp", nil)
- },
- dial: func(addr Addr) (Conn, error) {
- // Connecting a listening UDP connection will select a local
- // address on the local network and connects to the destination
- // address.
- return DialUDP("udp", nil, addr.(*UDPAddr))
- },
- addr: testFakeUDPAddr,
- },
+ accepted = append(accepted, c)
+ }
+ }()
- {
- name: "ListenUnixgram:unixgram",
- listen: func() (PacketConn, error) {
- return ListenUnixgram("unixgram", &UnixAddr{Name: "test"})
- },
- dial: func(addr Addr) (Conn, error) {
- return DialUnix("unixgram", nil, addr.(*UnixAddr))
- },
- addr: testFakeUnixAddr("unixgram", "test"),
- },
+ var dialed []Conn
+ defer func() {
+ ln.Close()
+ for _, c := range dialed {
+ c.Close()
+ }
+ <-done
+ }()
+
+ // Since this test is not running in parallel, we expect to be able to open
+ // all 65535 valid (fake) ports. The listener is already using one, so
+ // we should be able to Dial the remaining 65534.
+ for len(dialed) < (1<<16)-2 {
+ c, err := Dial(ln.Addr().Network(), ln.Addr().String())
+ if err != nil {
+ t.Fatalf("unexpected error from Dial with %v connections: %v", len(dialed), err)
+ }
+ dialed = append(dialed, c)
+ if testing.Verbose() && len(dialed)%(1<<12) == 0 {
+ t.Logf("dialed %d connections", len(dialed))
+ }
}
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- l, err := test.listen()
- if err != nil {
- t.Fatal(err)
- }
- defer l.Close()
- test.addr(t, l.LocalAddr())
-
- c, err := test.dial(l.LocalAddr())
- if err != nil {
- t.Fatal(err)
- }
- defer c.Close()
- test.addr(t, c.LocalAddr())
- test.addr(t, c.RemoteAddr())
- })
+ t.Logf("dialed %d connections", len(dialed))
+
+ // Now that all of the ports are in use, dialing another should fail due
+ // to port exhaustion, which (for POSIX-like socket APIs) should return
+ // an EADDRINUSE error.
+ c, err := Dial(ln.Addr().Network(), ln.Addr().String())
+ if err == nil {
+ c.Close()
}
-}
-
-func testFakeTCPAddr(t *testing.T, addr Addr) {
- t.Helper()
- if a, ok := addr.(*TCPAddr); !ok {
- t.Errorf("Addr is not *TCPAddr: %T", addr)
+ if errors.Is(err, syscall.EADDRINUSE) {
+ t.Logf("Dial returned expected error: %v", err)
} else {
- testFakeNetAddr(t, a.IP, a.Port)
+ t.Errorf("unexpected error from Dial: %v\nwant: %v", err, syscall.EADDRINUSE)
}
-}
-func testFakeUDPAddr(t *testing.T, addr Addr) {
- t.Helper()
- if a, ok := addr.(*UDPAddr); !ok {
- t.Errorf("Addr is not *UDPAddr: %T", addr)
+ // Opening a Listener should fail at this point too.
+ ln2, err := Listen("tcp", "localhost:0")
+ if err == nil {
+ ln2.Close()
+ }
+ if errors.Is(err, syscall.EADDRINUSE) {
+ t.Logf("Listen returned expected error: %v", err)
} else {
- testFakeNetAddr(t, a.IP, a.Port)
+ t.Errorf("unexpected error from Listen: %v\nwant: %v", err, syscall.EADDRINUSE)
}
-}
-func testFakeNetAddr(t *testing.T, ip IP, port int) {
- t.Helper()
- if port == 0 {
- t.Error("network address is missing port")
- } else if len(ip) == 0 {
- t.Error("network address is missing IP")
- } else if !ip.Equal(IPv4(127, 0, 0, 1)) {
- t.Errorf("network address has wrong IP: %s", ip)
- }
-}
-
-func testFakeUnixAddr(net, name string) func(*testing.T, Addr) {
- return func(t *testing.T, addr Addr) {
- t.Helper()
- if a, ok := addr.(*UnixAddr); !ok {
- t.Errorf("Addr is not *UnixAddr: %T", addr)
- } else if a.Net != net {
- t.Errorf("unix address has wrong net: want=%q got=%q", net, a.Net)
- } else if a.Name != name {
- t.Errorf("unix address has wrong name: want=%q got=%q", name, a.Name)
- }
+ // When we close an arbitrary connection, we should be able to reuse its port
+ // even if the server hasn't yet seen the ECONNRESET for the connection.
+ dialed[0].Close()
+ dialed = dialed[1:]
+ t.Logf("closed one connection")
+ c, err = Dial(ln.Addr().Network(), ln.Addr().String())
+ if err == nil {
+ c.Close()
+ t.Logf("Dial succeeded")
+ } else {
+ t.Errorf("unexpected error from Dial: %v", err)
}
}
diff --git a/src/net/net_test.go b/src/net/net_test.go
index a0ac85f406..b448a79cce 100644
--- a/src/net/net_test.go
+++ b/src/net/net_test.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !wasip1
-
package net
import (
@@ -383,8 +381,16 @@ func TestZeroByteRead(t *testing.T) {
ln := newLocalListener(t, network)
connc := make(chan Conn, 1)
+ defer func() {
+ ln.Close()
+ for c := range connc {
+ if c != nil {
+ c.Close()
+ }
+ }
+ }()
go func() {
- defer ln.Close()
+ defer close(connc)
c, err := ln.Accept()
if err != nil {
t.Error(err)
@@ -440,8 +446,9 @@ func withTCPConnPair(t *testing.T, peer1, peer2 func(c *TCPConn) error) {
errc <- err
return
}
- defer c1.Close()
- errc <- peer1(c1.(*TCPConn))
+ err = peer1(c1.(*TCPConn))
+ c1.Close()
+ errc <- err
}()
go func() {
c2, err := Dial("tcp", ln.Addr().String())
@@ -449,12 +456,13 @@ func withTCPConnPair(t *testing.T, peer1, peer2 func(c *TCPConn) error) {
errc <- err
return
}
- defer c2.Close()
- errc <- peer2(c2.(*TCPConn))
+ err = peer2(c2.(*TCPConn))
+ c2.Close()
+ errc <- err
}()
for i := 0; i < 2; i++ {
if err := <-errc; err != nil {
- t.Fatal(err)
+ t.Error(err)
}
}
}
diff --git a/src/net/netip/export_test.go b/src/net/netip/export_test.go
index 59971fa2e4..72347ee01b 100644
--- a/src/net/netip/export_test.go
+++ b/src/net/netip/export_test.go
@@ -28,3 +28,5 @@ var TestAppendToMarshal = testAppendToMarshal
func (a Addr) IsZero() bool { return a.isZero() }
func (p Prefix) IsZero() bool { return p.isZero() }
+
+func (p Prefix) Compare(p2 Prefix) int { return p.compare(p2) }
diff --git a/src/net/netip/leaf_alts.go b/src/net/netip/leaf_alts.go
index 70513abfd9..d887bed627 100644
--- a/src/net/netip/leaf_alts.go
+++ b/src/net/netip/leaf_alts.go
@@ -7,15 +7,6 @@
package netip
-func stringsLastIndexByte(s string, b byte) int {
- for i := len(s) - 1; i >= 0; i-- {
- if s[i] == b {
- return i
- }
- }
- return -1
-}
-
func beUint64(b []byte) uint64 {
_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
diff --git a/src/net/netip/netip.go b/src/net/netip/netip.go
index a44b094955..7a189e8e16 100644
--- a/src/net/netip/netip.go
+++ b/src/net/netip/netip.go
@@ -12,6 +12,7 @@
package netip
import (
+ "cmp"
"errors"
"math"
"strconv"
@@ -127,7 +128,7 @@ func ParseAddr(s string) (Addr, error) {
return Addr{}, parseAddrError{in: s, msg: "unable to parse IP"}
}
-// MustParseAddr calls ParseAddr(s) and panics on error.
+// MustParseAddr calls [ParseAddr](s) and panics on error.
// It is intended for use in tests with hard-coded strings.
func MustParseAddr(s string) Addr {
ip, err := ParseAddr(s)
@@ -334,8 +335,8 @@ func parseIPv6(in string) (Addr, error) {
}
// AddrFromSlice parses the 4- or 16-byte byte slice as an IPv4 or IPv6 address.
-// Note that a net.IP can be passed directly as the []byte argument.
-// If slice's length is not 4 or 16, AddrFromSlice returns Addr{}, false.
+// Note that a [net.IP] can be passed directly as the []byte argument.
+// If slice's length is not 4 or 16, AddrFromSlice returns [Addr]{}, false.
func AddrFromSlice(slice []byte) (ip Addr, ok bool) {
switch len(slice) {
case 4:
@@ -375,13 +376,13 @@ func (ip Addr) isZero() bool {
return ip.z == z0
}
-// IsValid reports whether the Addr is an initialized address (not the zero Addr).
+// IsValid reports whether the [Addr] is an initialized address (not the zero Addr).
//
// Note that "0.0.0.0" and "::" are both valid values.
func (ip Addr) IsValid() bool { return ip.z != z0 }
// BitLen returns the number of bits in the IP address:
-// 128 for IPv6, 32 for IPv4, and 0 for the zero Addr.
+// 128 for IPv6, 32 for IPv4, and 0 for the zero [Addr].
//
// Note that IPv4-mapped IPv6 addresses are considered IPv6 addresses
// and therefore have bit length 128.
@@ -406,7 +407,7 @@ func (ip Addr) Zone() string {
// Compare returns an integer comparing two IPs.
// The result will be 0 if ip == ip2, -1 if ip < ip2, and +1 if ip > ip2.
-// The definition of "less than" is the same as the Less method.
+// The definition of "less than" is the same as the [Addr.Less] method.
func (ip Addr) Compare(ip2 Addr) int {
f1, f2 := ip.BitLen(), ip2.BitLen()
if f1 < f2 {
@@ -448,7 +449,7 @@ func (ip Addr) Less(ip2 Addr) bool { return ip.Compare(ip2) == -1 }
// Is4 reports whether ip is an IPv4 address.
//
-// It returns false for IPv4-mapped IPv6 addresses. See Addr.Unmap.
+// It returns false for IPv4-mapped IPv6 addresses. See [Addr.Unmap].
func (ip Addr) Is4() bool {
return ip.z == z4
}
@@ -582,7 +583,7 @@ func (ip Addr) IsLinkLocalMulticast() bool {
// IANA-allocated 2000::/3 global unicast space, with the exception of the
// link-local address space. It also returns true even if ip is in the IPv4
// private address space or IPv6 unique local address space.
-// It returns false for the zero Addr.
+// It returns false for the zero [Addr].
//
// For reference, see RFC 1122, RFC 4291, and RFC 4632.
func (ip Addr) IsGlobalUnicast() bool {
@@ -606,7 +607,7 @@ func (ip Addr) IsGlobalUnicast() bool {
// IsPrivate reports whether ip is a private address, according to RFC 1918
// (IPv4 addresses) and RFC 4193 (IPv6 addresses). That is, it reports whether
// ip is in 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, or fc00::/7. This is the
-// same as net.IP.IsPrivate.
+// same as [net.IP.IsPrivate].
func (ip Addr) IsPrivate() bool {
// Match the stdlib's IsPrivate logic.
if ip.Is4() {
@@ -629,14 +630,14 @@ func (ip Addr) IsPrivate() bool {
// IsUnspecified reports whether ip is an unspecified address, either the IPv4
// address "0.0.0.0" or the IPv6 address "::".
//
-// Note that the zero Addr is not an unspecified address.
+// Note that the zero [Addr] is not an unspecified address.
func (ip Addr) IsUnspecified() bool {
return ip == IPv4Unspecified() || ip == IPv6Unspecified()
}
// Prefix keeps only the top b bits of IP, producing a Prefix
// of the specified length.
-// If ip is a zero Addr, Prefix always returns a zero Prefix and a nil error.
+// If ip is a zero [Addr], Prefix always returns a zero Prefix and a nil error.
// Otherwise, if bits is less than zero or greater than ip.BitLen(),
// Prefix returns an error.
func (ip Addr) Prefix(b int) (Prefix, error) {
@@ -661,15 +662,10 @@ func (ip Addr) Prefix(b int) (Prefix, error) {
return PrefixFrom(ip, b), nil
}
-const (
- netIPv4len = 4
- netIPv6len = 16
-)
-
// As16 returns the IP address in its 16-byte representation.
// IPv4 addresses are returned as IPv4-mapped IPv6 addresses.
// IPv6 addresses with zones are returned without their zone (use the
-// Zone method to get it).
+// [Addr.Zone] method to get it).
// The ip zero value returns all zeroes.
func (ip Addr) As16() (a16 [16]byte) {
bePutUint64(a16[:8], ip.addr.hi)
@@ -678,7 +674,7 @@ func (ip Addr) As16() (a16 [16]byte) {
}
// As4 returns an IPv4 or IPv4-in-IPv6 address in its 4-byte representation.
-// If ip is the zero Addr or an IPv6 address, As4 panics.
+// If ip is the zero [Addr] or an IPv6 address, As4 panics.
// Note that 0.0.0.0 is not the zero Addr.
func (ip Addr) As4() (a4 [4]byte) {
if ip.z == z4 || ip.Is4In6() {
@@ -709,7 +705,7 @@ func (ip Addr) AsSlice() []byte {
}
// Next returns the address following ip.
-// If there is none, it returns the zero Addr.
+// If there is none, it returns the zero [Addr].
func (ip Addr) Next() Addr {
ip.addr = ip.addr.addOne()
if ip.Is4() {
@@ -743,10 +739,10 @@ func (ip Addr) Prev() Addr {
// String returns the string form of the IP address ip.
// It returns one of 5 forms:
//
-// - "invalid IP", if ip is the zero Addr
+// - "invalid IP", if ip is the zero [Addr]
// - IPv4 dotted decimal ("192.0.2.1")
// - IPv6 ("2001:db8::1")
-// - "::ffff:1.2.3.4" (if Is4In6)
+// - "::ffff:1.2.3.4" (if [Addr.Is4In6])
// - IPv6 with zone ("fe80:db8::1%eth0")
//
// Note that unlike package net's IP.String method,
@@ -771,7 +767,7 @@ func (ip Addr) String() string {
}
// AppendTo appends a text encoding of ip,
-// as generated by MarshalText,
+// as generated by [Addr.MarshalText],
// to b and returns the extended buffer.
func (ip Addr) AppendTo(b []byte) []byte {
switch ip.z {
@@ -903,7 +899,7 @@ func (ip Addr) appendTo6(ret []byte) []byte {
return ret
}
-// StringExpanded is like String but IPv6 addresses are expanded with leading
+// StringExpanded is like [Addr.String] but IPv6 addresses are expanded with leading
// zeroes and no "::" compression. For example, "2001:db8::1" becomes
// "2001:0db8:0000:0000:0000:0000:0000:0001".
func (ip Addr) StringExpanded() string {
@@ -931,9 +927,9 @@ func (ip Addr) StringExpanded() string {
return string(ret)
}
-// MarshalText implements the encoding.TextMarshaler interface,
-// The encoding is the same as returned by String, with one exception:
-// If ip is the zero Addr, the encoding is the empty string.
+// MarshalText implements the [encoding.TextMarshaler] interface,
+// The encoding is the same as returned by [Addr.String], with one exception:
+// If ip is the zero [Addr], the encoding is the empty string.
func (ip Addr) MarshalText() ([]byte, error) {
switch ip.z {
case z0:
@@ -960,9 +956,9 @@ func (ip Addr) MarshalText() ([]byte, error) {
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
-// The IP address is expected in a form accepted by ParseAddr.
+// The IP address is expected in a form accepted by [ParseAddr].
//
-// If text is empty, UnmarshalText sets *ip to the zero Addr and
+// If text is empty, UnmarshalText sets *ip to the zero [Addr] and
// returns no error.
func (ip *Addr) UnmarshalText(text []byte) error {
if len(text) == 0 {
@@ -992,15 +988,15 @@ func (ip Addr) marshalBinaryWithTrailingBytes(trailingBytes int) []byte {
return b
}
-// MarshalBinary implements the encoding.BinaryMarshaler interface.
-// It returns a zero-length slice for the zero Addr,
+// MarshalBinary implements the [encoding.BinaryMarshaler] interface.
+// It returns a zero-length slice for the zero [Addr],
// the 4-byte form for an IPv4 address,
// and the 16-byte form with zone appended for an IPv6 address.
func (ip Addr) MarshalBinary() ([]byte, error) {
return ip.marshalBinaryWithTrailingBytes(0), nil
}
-// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
+// UnmarshalBinary implements the [encoding.BinaryUnmarshaler] interface.
// It expects data in the form generated by MarshalBinary.
func (ip *Addr) UnmarshalBinary(b []byte) error {
n := len(b)
@@ -1027,7 +1023,7 @@ type AddrPort struct {
port uint16
}
-// AddrPortFrom returns an AddrPort with the provided IP and port.
+// AddrPortFrom returns an [AddrPort] with the provided IP and port.
// It does not allocate.
func AddrPortFrom(ip Addr, port uint16) AddrPort { return AddrPort{ip: ip, port: port} }
@@ -1043,7 +1039,7 @@ func (p AddrPort) Port() uint16 { return p.port }
// ip string should parse as an IPv6 address or an IPv4 address, in
// order for s to be a valid ip:port string.
func splitAddrPort(s string) (ip, port string, v6 bool, err error) {
- i := stringsLastIndexByte(s, ':')
+ i := bytealg.LastIndexByteString(s, ':')
if i == -1 {
return "", "", false, errors.New("not an ip:port")
}
@@ -1066,7 +1062,7 @@ func splitAddrPort(s string) (ip, port string, v6 bool, err error) {
return ip, port, v6, nil
}
-// ParseAddrPort parses s as an AddrPort.
+// ParseAddrPort parses s as an [AddrPort].
//
// It doesn't do any name resolution: both the address and the port
// must be numeric.
@@ -1093,7 +1089,7 @@ func ParseAddrPort(s string) (AddrPort, error) {
return ipp, nil
}
-// MustParseAddrPort calls ParseAddrPort(s) and panics on error.
+// MustParseAddrPort calls [ParseAddrPort](s) and panics on error.
// It is intended for use in tests with hard-coded strings.
func MustParseAddrPort(s string) AddrPort {
ip, err := ParseAddrPort(s)
@@ -1107,36 +1103,35 @@ func MustParseAddrPort(s string) AddrPort {
// All ports are valid, including zero.
func (p AddrPort) IsValid() bool { return p.ip.IsValid() }
+// Compare returns an integer comparing two AddrPorts.
+// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2.
+// AddrPorts sort first by IP address, then port.
+func (p AddrPort) Compare(p2 AddrPort) int {
+ if c := p.Addr().Compare(p2.Addr()); c != 0 {
+ return c
+ }
+ return cmp.Compare(p.Port(), p2.Port())
+}
+
func (p AddrPort) String() string {
switch p.ip.z {
case z0:
return "invalid AddrPort"
case z4:
- a := p.ip.As4()
- buf := make([]byte, 0, 21)
- for i := range a {
- buf = strconv.AppendUint(buf, uint64(a[i]), 10)
- buf = append(buf, "...:"[i])
- }
+ const max = len("255.255.255.255:65535")
+ buf := make([]byte, 0, max)
+ buf = p.ip.appendTo4(buf)
+ buf = append(buf, ':')
buf = strconv.AppendUint(buf, uint64(p.port), 10)
return string(buf)
default:
// TODO: this could be more efficient allocation-wise:
- return joinHostPort(p.ip.String(), itoa.Itoa(int(p.port)))
+ return "[" + p.ip.String() + "]:" + itoa.Uitoa(uint(p.port))
}
}
-func joinHostPort(host, port string) string {
- // We assume that host is a literal IPv6 address if host has
- // colons.
- if bytealg.IndexByteString(host, ':') >= 0 {
- return "[" + host + "]:" + port
- }
- return host + ":" + port
-}
-
// AppendTo appends a text encoding of p,
-// as generated by MarshalText,
+// as generated by [AddrPort.MarshalText],
// to b and returns the extended buffer.
func (p AddrPort) AppendTo(b []byte) []byte {
switch p.ip.z {
@@ -1163,9 +1158,9 @@ func (p AddrPort) AppendTo(b []byte) []byte {
return b
}
-// MarshalText implements the encoding.TextMarshaler interface. The
-// encoding is the same as returned by String, with one exception: if
-// p.Addr() is the zero Addr, the encoding is the empty string.
+// MarshalText implements the [encoding.TextMarshaler] interface. The
+// encoding is the same as returned by [AddrPort.String], with one exception: if
+// p.Addr() is the zero [Addr], the encoding is the empty string.
func (p AddrPort) MarshalText() ([]byte, error) {
var max int
switch p.ip.z {
@@ -1181,8 +1176,8 @@ func (p AddrPort) MarshalText() ([]byte, error) {
}
// UnmarshalText implements the encoding.TextUnmarshaler
-// interface. The AddrPort is expected in a form
-// generated by MarshalText or accepted by ParseAddrPort.
+// interface. The [AddrPort] is expected in a form
+// generated by [AddrPort.MarshalText] or accepted by [ParseAddrPort].
func (p *AddrPort) UnmarshalText(text []byte) error {
if len(text) == 0 {
*p = AddrPort{}
@@ -1193,8 +1188,8 @@ func (p *AddrPort) UnmarshalText(text []byte) error {
return err
}
-// MarshalBinary implements the encoding.BinaryMarshaler interface.
-// It returns Addr.MarshalBinary with an additional two bytes appended
+// MarshalBinary implements the [encoding.BinaryMarshaler] interface.
+// It returns [Addr.MarshalBinary] with an additional two bytes appended
// containing the port in little-endian.
func (p AddrPort) MarshalBinary() ([]byte, error) {
b := p.Addr().marshalBinaryWithTrailingBytes(2)
@@ -1202,8 +1197,8 @@ func (p AddrPort) MarshalBinary() ([]byte, error) {
return b, nil
}
-// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
-// It expects data in the form generated by MarshalBinary.
+// UnmarshalBinary implements the [encoding.BinaryUnmarshaler] interface.
+// It expects data in the form generated by [AddrPort.MarshalBinary].
func (p *AddrPort) UnmarshalBinary(b []byte) error {
if len(b) < 2 {
return errors.New("unexpected slice size")
@@ -1219,7 +1214,7 @@ func (p *AddrPort) UnmarshalBinary(b []byte) error {
// Prefix is an IP address prefix (CIDR) representing an IP network.
//
-// The first Bits() of Addr() are specified. The remaining bits match any address.
+// The first [Prefix.Bits]() of [Addr]() are specified. The remaining bits match any address.
// The range of Bits() is [0,32] for IPv4 or [0,128] for IPv6.
type Prefix struct {
ip Addr
@@ -1229,13 +1224,13 @@ type Prefix struct {
bitsPlusOne uint8
}
-// PrefixFrom returns a Prefix with the provided IP address and bit
+// PrefixFrom returns a [Prefix] with the provided IP address and bit
// prefix length.
//
-// It does not allocate. Unlike Addr.Prefix, PrefixFrom does not mask
+// It does not allocate. Unlike [Addr.Prefix], [PrefixFrom] does not mask
// off the host bits of ip.
//
-// If bits is less than zero or greater than ip.BitLen, Prefix.Bits
+// If bits is less than zero or greater than ip.BitLen, [Prefix.Bits]
// will return an invalid value -1.
func PrefixFrom(ip Addr, bits int) Prefix {
var bitsPlusOne uint8
@@ -1257,8 +1252,8 @@ func (p Prefix) Addr() Addr { return p.ip }
func (p Prefix) Bits() int { return int(p.bitsPlusOne) - 1 }
// IsValid reports whether p.Bits() has a valid range for p.Addr().
-// If p.Addr() is the zero Addr, IsValid returns false.
-// Note that if p is the zero Prefix, then p.IsValid() == false.
+// If p.Addr() is the zero [Addr], IsValid returns false.
+// Note that if p is the zero [Prefix], then p.IsValid() == false.
func (p Prefix) IsValid() bool { return p.bitsPlusOne > 0 }
func (p Prefix) isZero() bool { return p == Prefix{} }
@@ -1266,6 +1261,24 @@ func (p Prefix) isZero() bool { return p == Prefix{} }
// IsSingleIP reports whether p contains exactly one IP.
func (p Prefix) IsSingleIP() bool { return p.IsValid() && p.Bits() == p.ip.BitLen() }
+// compare returns an integer comparing two prefixes.
+// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2.
+// Prefixes sort first by validity (invalid before valid), then
+// address family (IPv4 before IPv6), then prefix length, then
+// address.
+//
+// Unexported for Go 1.22 because we may want to compare by p.Addr first.
+// See post-acceptance discussion on go.dev/issue/61642.
+func (p Prefix) compare(p2 Prefix) int {
+ if c := cmp.Compare(p.Addr().BitLen(), p2.Addr().BitLen()); c != 0 {
+ return c
+ }
+ if c := cmp.Compare(p.Bits(), p2.Bits()); c != 0 {
+ return c
+ }
+ return p.Addr().Compare(p2.Addr())
+}
+
// ParsePrefix parses s as an IP address prefix.
// The string can be in the form "192.168.1.0/24" or "2001:db8::/32",
// the CIDR notation defined in RFC 4632 and RFC 4291.
@@ -1274,7 +1287,7 @@ func (p Prefix) IsSingleIP() bool { return p.IsValid() && p.Bits() == p.ip.BitLe
//
// Note that masked address bits are not zeroed. Use Masked for that.
func ParsePrefix(s string) (Prefix, error) {
- i := stringsLastIndexByte(s, '/')
+ i := bytealg.LastIndexByteString(s, '/')
if i < 0 {
return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + "): no '/'")
}
@@ -1288,6 +1301,12 @@ func ParsePrefix(s string) (Prefix, error) {
}
bitsStr := s[i+1:]
+
+ // strconv.Atoi accepts a leading sign and leading zeroes, but we don't want that.
+ if len(bitsStr) > 1 && (bitsStr[0] < '1' || bitsStr[0] > '9') {
+ return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + "): bad bits after slash: " + strconv.Quote(bitsStr))
+ }
+
bits, err := strconv.Atoi(bitsStr)
if err != nil {
return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + "): bad bits after slash: " + strconv.Quote(bitsStr))
@@ -1302,7 +1321,7 @@ func ParsePrefix(s string) (Prefix, error) {
return PrefixFrom(ip, bits), nil
}
-// MustParsePrefix calls ParsePrefix(s) and panics on error.
+// MustParsePrefix calls [ParsePrefix](s) and panics on error.
// It is intended for use in tests with hard-coded strings.
func MustParsePrefix(s string) Prefix {
ip, err := ParsePrefix(s)
@@ -1315,7 +1334,7 @@ func MustParsePrefix(s string) Prefix {
// Masked returns p in its canonical form, with all but the high
// p.Bits() bits of p.Addr() masked off.
//
-// If p is zero or otherwise invalid, Masked returns the zero Prefix.
+// If p is zero or otherwise invalid, Masked returns the zero [Prefix].
func (p Prefix) Masked() Prefix {
m, _ := p.ip.Prefix(p.Bits())
return m
@@ -1392,7 +1411,7 @@ func (p Prefix) Overlaps(o Prefix) bool {
}
// AppendTo appends a text encoding of p,
-// as generated by MarshalText,
+// as generated by [Prefix.MarshalText],
// to b and returns the extended buffer.
func (p Prefix) AppendTo(b []byte) []byte {
if p.isZero() {
@@ -1419,8 +1438,8 @@ func (p Prefix) AppendTo(b []byte) []byte {
return b
}
-// MarshalText implements the encoding.TextMarshaler interface,
-// The encoding is the same as returned by String, with one exception:
+// MarshalText implements the [encoding.TextMarshaler] interface,
+// The encoding is the same as returned by [Prefix.String], with one exception:
// If p is the zero value, the encoding is the empty string.
func (p Prefix) MarshalText() ([]byte, error) {
var max int
@@ -1437,8 +1456,8 @@ func (p Prefix) MarshalText() ([]byte, error) {
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
-// The IP address is expected in a form accepted by ParsePrefix
-// or generated by MarshalText.
+// The IP address is expected in a form accepted by [ParsePrefix]
+// or generated by [Prefix.MarshalText].
func (p *Prefix) UnmarshalText(text []byte) error {
if len(text) == 0 {
*p = Prefix{}
@@ -1449,8 +1468,8 @@ func (p *Prefix) UnmarshalText(text []byte) error {
return err
}
-// MarshalBinary implements the encoding.BinaryMarshaler interface.
-// It returns Addr.MarshalBinary with an additional byte appended
+// MarshalBinary implements the [encoding.BinaryMarshaler] interface.
+// It returns [Addr.MarshalBinary] with an additional byte appended
// containing the prefix bits.
func (p Prefix) MarshalBinary() ([]byte, error) {
b := p.Addr().withoutZone().marshalBinaryWithTrailingBytes(1)
@@ -1458,8 +1477,8 @@ func (p Prefix) MarshalBinary() ([]byte, error) {
return b, nil
}
-// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
-// It expects data in the form generated by MarshalBinary.
+// UnmarshalBinary implements the [encoding.BinaryUnmarshaler] interface.
+// It expects data in the form generated by [Prefix.MarshalBinary].
func (p *Prefix) UnmarshalBinary(b []byte) error {
if len(b) < 1 {
return errors.New("unexpected slice size")
diff --git a/src/net/netip/netip_test.go b/src/net/netip/netip_test.go
index 0f80bb0ab0..a748ac34f1 100644
--- a/src/net/netip/netip_test.go
+++ b/src/net/netip/netip_test.go
@@ -14,6 +14,7 @@ import (
"net"
. "net/netip"
"reflect"
+ "slices"
"sort"
"strings"
"testing"
@@ -389,6 +390,7 @@ func TestAddrPortMarshalTextString(t *testing.T) {
want string
}{
{mustIPPort("1.2.3.4:80"), "1.2.3.4:80"},
+ {mustIPPort("[::]:80"), "[::]:80"},
{mustIPPort("[1::CAFE]:80"), "[1::cafe]:80"},
{mustIPPort("[1::CAFE%en0]:80"), "[1::cafe%en0]:80"},
{mustIPPort("[::FFFF:192.168.140.255]:80"), "[::ffff:192.168.140.255]:80"},
@@ -812,7 +814,7 @@ func TestAddrWellKnown(t *testing.T) {
}
}
-func TestLessCompare(t *testing.T) {
+func TestAddrLessCompare(t *testing.T) {
tests := []struct {
a, b Addr
want bool
@@ -882,6 +884,109 @@ func TestLessCompare(t *testing.T) {
}
}
+func TestAddrPortCompare(t *testing.T) {
+ tests := []struct {
+ a, b AddrPort
+ want int
+ }{
+ {AddrPort{}, AddrPort{}, 0},
+ {AddrPort{}, mustIPPort("1.2.3.4:80"), -1},
+
+ {mustIPPort("1.2.3.4:80"), mustIPPort("1.2.3.4:80"), 0},
+ {mustIPPort("[::1]:80"), mustIPPort("[::1]:80"), 0},
+
+ {mustIPPort("1.2.3.4:80"), mustIPPort("2.3.4.5:22"), -1},
+ {mustIPPort("[::1]:80"), mustIPPort("[::2]:22"), -1},
+
+ {mustIPPort("1.2.3.4:80"), mustIPPort("1.2.3.4:443"), -1},
+ {mustIPPort("[::1]:80"), mustIPPort("[::1]:443"), -1},
+
+ {mustIPPort("1.2.3.4:80"), mustIPPort("[0102:0304::0]:80"), -1},
+ }
+ for _, tt := range tests {
+ got := tt.a.Compare(tt.b)
+ if got != tt.want {
+ t.Errorf("Compare(%q, %q) = %v; want %v", tt.a, tt.b, got, tt.want)
+ }
+
+ // Also check inverse.
+ if got == tt.want {
+ got2 := tt.b.Compare(tt.a)
+ if want2 := -1 * tt.want; got2 != want2 {
+ t.Errorf("Compare(%q, %q) was correctly %v, but Compare(%q, %q) was %v", tt.a, tt.b, got, tt.b, tt.a, got2)
+ }
+ }
+ }
+
+ // And just sort.
+ values := []AddrPort{
+ mustIPPort("[::1]:80"),
+ mustIPPort("[::2]:80"),
+ AddrPort{},
+ mustIPPort("1.2.3.4:443"),
+ mustIPPort("8.8.8.8:8080"),
+ mustIPPort("[::1%foo]:1024"),
+ }
+ slices.SortFunc(values, func(a, b AddrPort) int { return a.Compare(b) })
+ got := fmt.Sprintf("%s", values)
+ want := `[invalid AddrPort 1.2.3.4:443 8.8.8.8:8080 [::1]:80 [::1%foo]:1024 [::2]:80]`
+ if got != want {
+ t.Errorf("unexpected sort\n got: %s\nwant: %s\n", got, want)
+ }
+}
+
+func TestPrefixCompare(t *testing.T) {
+ tests := []struct {
+ a, b Prefix
+ want int
+ }{
+ {Prefix{}, Prefix{}, 0},
+ {Prefix{}, mustPrefix("1.2.3.0/24"), -1},
+
+ {mustPrefix("1.2.3.0/24"), mustPrefix("1.2.3.0/24"), 0},
+ {mustPrefix("fe80::/64"), mustPrefix("fe80::/64"), 0},
+
+ {mustPrefix("1.2.3.0/24"), mustPrefix("1.2.4.0/24"), -1},
+ {mustPrefix("fe80::/64"), mustPrefix("fe90::/64"), -1},
+
+ {mustPrefix("1.2.0.0/16"), mustPrefix("1.2.0.0/24"), -1},
+ {mustPrefix("fe80::/48"), mustPrefix("fe80::/64"), -1},
+
+ {mustPrefix("1.2.3.0/24"), mustPrefix("fe80::/8"), -1},
+ }
+ for _, tt := range tests {
+ got := tt.a.Compare(tt.b)
+ if got != tt.want {
+ t.Errorf("Compare(%q, %q) = %v; want %v", tt.a, tt.b, got, tt.want)
+ }
+
+ // Also check inverse.
+ if got == tt.want {
+ got2 := tt.b.Compare(tt.a)
+ if want2 := -1 * tt.want; got2 != want2 {
+ t.Errorf("Compare(%q, %q) was correctly %v, but Compare(%q, %q) was %v", tt.a, tt.b, got, tt.b, tt.a, got2)
+ }
+ }
+ }
+
+ // And just sort.
+ values := []Prefix{
+ mustPrefix("1.2.3.0/24"),
+ mustPrefix("fe90::/64"),
+ mustPrefix("fe80::/64"),
+ mustPrefix("1.2.0.0/16"),
+ Prefix{},
+ mustPrefix("fe80::/48"),
+ mustPrefix("1.2.0.0/24"),
+ }
+ slices.SortFunc(values, func(a, b Prefix) int { return a.Compare(b) })
+ got := fmt.Sprintf("%s", values)
+ want := `[invalid Prefix 1.2.0.0/16 1.2.0.0/24 1.2.3.0/24 fe80::/48 fe80::/64 fe90::/64]`
+ if got != want {
+ t.Errorf("unexpected sort\n got: %s\nwant: %s\n", got, want)
+ }
+}
+
func TestIPStringExpanded(t *testing.T) {
tests := []struct {
ip Addr
@@ -1352,7 +1457,7 @@ func TestParsePrefixError(t *testing.T) {
},
{
prefix: "1.1.1.0/-1",
- errstr: "out of range",
+ errstr: "bad bits",
},
{
prefix: "1.1.1.0/33",
@@ -1371,6 +1476,22 @@ func TestParsePrefixError(t *testing.T) {
prefix: "2001:db8::%a/32",
errstr: "zones cannot be present",
},
+ {
+ prefix: "1.1.1.0/+32",
+ errstr: "bad bits",
+ },
+ {
+ prefix: "1.1.1.0/-32",
+ errstr: "bad bits",
+ },
+ {
+ prefix: "1.1.1.0/032",
+ errstr: "bad bits",
+ },
+ {
+ prefix: "1.1.1.0/0032",
+ errstr: "bad bits",
+ },
}
for _, test := range tests {
t.Run(test.prefix, func(t *testing.T) {
diff --git a/src/net/packetconn_test.go b/src/net/packetconn_test.go
index dc0c14b93d..e39e7de5d7 100644
--- a/src/net/packetconn_test.go
+++ b/src/net/packetconn_test.go
@@ -2,10 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// This file implements API tests across platforms and will never have a build
-// tag.
-
-//go:build !js && !wasip1
+// This file implements API tests across platforms and should never have a build
+// constraint.
package net
diff --git a/src/net/parse.go b/src/net/parse.go
index fbc50144c2..29dffad43c 100644
--- a/src/net/parse.go
+++ b/src/net/parse.go
@@ -180,42 +180,6 @@ func xtoi2(s string, e byte) (byte, bool) {
return byte(n), ok && ei == 2
}
-// Convert i to a hexadecimal string. Leading zeros are not printed.
-func appendHex(dst []byte, i uint32) []byte {
- if i == 0 {
- return append(dst, '0')
- }
- for j := 7; j >= 0; j-- {
- v := i >> uint(j*4)
- if v > 0 {
- dst = append(dst, hexDigit[v&0xf])
- }
- }
- return dst
-}
-
-// Number of occurrences of b in s.
-func count(s string, b byte) int {
- n := 0
- for i := 0; i < len(s); i++ {
- if s[i] == b {
- n++
- }
- }
- return n
-}
-
-// Index of rightmost occurrence of b in s.
-func last(s string, b byte) int {
- i := len(s)
- for i--; i >= 0; i-- {
- if s[i] == b {
- break
- }
- }
- return i
-}
-
// hasUpperCase tells whether the given string contains at least one upper-case.
func hasUpperCase(s string) bool {
for i := range s {
diff --git a/src/net/pipe.go b/src/net/pipe.go
index f1741938b0..69955e4617 100644
--- a/src/net/pipe.go
+++ b/src/net/pipe.go
@@ -106,7 +106,7 @@ type pipe struct {
}
// Pipe creates a synchronous, in-memory, full duplex
-// network connection; both ends implement the Conn interface.
+// network connection; both ends implement the [Conn] interface.
// Reads on one end are matched with writes on the other,
// copying data directly between the two; there is no internal
// buffering.
diff --git a/src/net/platform_test.go b/src/net/platform_test.go
index 71e90821ce..709d4a3eb7 100644
--- a/src/net/platform_test.go
+++ b/src/net/platform_test.go
@@ -165,7 +165,7 @@ func condFatalf(t *testing.T, network string, format string, args ...any) {
// A few APIs like File and Read/WriteMsg{UDP,IP} are not
// fully implemented yet on Plan 9 and Windows.
switch runtime.GOOS {
- case "windows":
+ case "windows", "js", "wasip1":
if network == "file+net" {
t.Logf(format, args...)
return
diff --git a/src/net/port_unix.go b/src/net/port_unix.go
index 0b2ea3ec5d..df73dbabb3 100644
--- a/src/net/port_unix.go
+++ b/src/net/port_unix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build unix || (js && wasm) || wasip1
+//go:build unix || js || wasip1
// Read system port mappings from /etc/services
diff --git a/src/net/protoconn_test.go b/src/net/protoconn_test.go
index c5668079a9..a617470580 100644
--- a/src/net/protoconn_test.go
+++ b/src/net/protoconn_test.go
@@ -5,8 +5,6 @@
// This file implements API tests across platforms and will never have a build
// tag.
-//go:build !js && !wasip1
-
package net
import (
@@ -39,7 +37,7 @@ func TestTCPListenerSpecificMethods(t *testing.T) {
}
defer ln.Close()
ln.Addr()
- ln.SetDeadline(time.Now().Add(30 * time.Nanosecond))
+ mustSetDeadline(t, ln.SetDeadline, 30*time.Nanosecond)
if c, err := ln.Accept(); err != nil {
if !err.(Error).Timeout() {
@@ -162,6 +160,10 @@ func TestUDPConnSpecificMethods(t *testing.T) {
}
func TestIPConnSpecificMethods(t *testing.T) {
+ if !testableNetwork("ip4") {
+ t.Skip("skipping: ip4 not supported")
+ }
+
la, err := ResolveIPAddr("ip4", "127.0.0.1")
if err != nil {
t.Fatal(err)
@@ -217,7 +219,7 @@ func TestUnixListenerSpecificMethods(t *testing.T) {
defer ln.Close()
defer os.Remove(addr)
ln.Addr()
- ln.SetDeadline(time.Now().Add(30 * time.Nanosecond))
+ mustSetDeadline(t, ln.SetDeadline, 30*time.Nanosecond)
if c, err := ln.Accept(); err != nil {
if !err.(Error).Timeout() {
@@ -235,7 +237,7 @@ func TestUnixListenerSpecificMethods(t *testing.T) {
}
if f, err := ln.File(); err != nil {
- t.Fatal(err)
+ condFatalf(t, "file+net", "%v", err)
} else {
f.Close()
}
@@ -332,7 +334,7 @@ func TestUnixConnSpecificMethods(t *testing.T) {
}
if f, err := c1.File(); err != nil {
- t.Fatal(err)
+ condFatalf(t, "file+net", "%v", err)
} else {
f.Close()
}
diff --git a/src/net/rawconn.go b/src/net/rawconn.go
index 974320c25f..19228e94ed 100644
--- a/src/net/rawconn.go
+++ b/src/net/rawconn.go
@@ -63,7 +63,7 @@ func (c *rawConn) Write(f func(uintptr) bool) error {
// PollFD returns the poll.FD of the underlying connection.
//
-// Other packages in std that also import internal/poll (such as os)
+// Other packages in std that also import [internal/poll] (such as os)
// can use a type assertion to access this extension method so that
// they can pass the *poll.FD to functions like poll.Splice.
//
@@ -75,8 +75,19 @@ func (c *rawConn) PollFD() *poll.FD {
return &c.fd.pfd
}
-func newRawConn(fd *netFD) (*rawConn, error) {
- return &rawConn{fd: fd}, nil
+func newRawConn(fd *netFD) *rawConn {
+ return &rawConn{fd: fd}
+}
+
+// Network returns the network type of the underlying connection.
+//
+// Other packages in std that import internal/poll and are unable to
+// import net (such as os) can use a type assertion to access this
+// extension method so that they can distinguish different socket types.
+//
+// Network is not intended for use outside the standard library.
+func (c *rawConn) Network() poll.String {
+ return poll.String(c.fd.net)
}
type rawListener struct {
@@ -91,6 +102,6 @@ func (l *rawListener) Write(func(uintptr) bool) error {
return syscall.EINVAL
}
-func newRawListener(fd *netFD) (*rawListener, error) {
- return &rawListener{rawConn{fd: fd}}, nil
+func newRawListener(fd *netFD) *rawListener {
+ return &rawListener{rawConn{fd: fd}}
}
diff --git a/src/net/rawconn_stub_test.go b/src/net/rawconn_stub_test.go
index c8ad80cc84..6d54f2df55 100644
--- a/src/net/rawconn_stub_test.go
+++ b/src/net/rawconn_stub_test.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build (js && wasm) || plan9 || wasip1
+//go:build js || plan9 || wasip1
package net
diff --git a/src/net/rawconn_test.go b/src/net/rawconn_test.go
index 06d5856a9a..70b16c4115 100644
--- a/src/net/rawconn_test.go
+++ b/src/net/rawconn_test.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !wasip1
-
package net
import (
@@ -15,7 +13,7 @@ import (
func TestRawConnReadWrite(t *testing.T) {
switch runtime.GOOS {
- case "plan9":
+ case "plan9", "js", "wasip1":
t.Skipf("not supported on %s", runtime.GOOS)
}
@@ -169,7 +167,7 @@ func TestRawConnReadWrite(t *testing.T) {
func TestRawConnControl(t *testing.T) {
switch runtime.GOOS {
- case "plan9":
+ case "plan9", "js", "wasip1":
t.Skipf("not supported on %s", runtime.GOOS)
}
diff --git a/src/net/resolverdialfunc_test.go b/src/net/resolverdialfunc_test.go
index 1de0402389..1af4199269 100644
--- a/src/net/resolverdialfunc_test.go
+++ b/src/net/resolverdialfunc_test.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !wasip1
-
// Test that Resolver.Dial can be a func returning an in-memory net.Conn
// speaking DNS.
diff --git a/src/net/rlimit_js.go b/src/net/rlimit_js.go
new file mode 100644
index 0000000000..9ee5748b21
--- /dev/null
+++ b/src/net/rlimit_js.go
@@ -0,0 +1,13 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build js
+
+package net
+
+// concurrentThreadsLimit returns the number of threads we permit to
+// run concurrently doing DNS lookups.
+func concurrentThreadsLimit() int {
+ return 500
+}
diff --git a/src/net/rlimit_unix.go b/src/net/rlimit_unix.go
new file mode 100644
index 0000000000..0094756e3a
--- /dev/null
+++ b/src/net/rlimit_unix.go
@@ -0,0 +1,33 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix || wasip1
+
+package net
+
+import "syscall"
+
+// concurrentThreadsLimit returns the number of threads we permit to
+// run concurrently doing DNS lookups via cgo. A DNS lookup may use a
+// file descriptor so we limit this to less than the number of
+// permitted open files. On some systems, notably Darwin, if
+// getaddrinfo is unable to open a file descriptor it simply returns
+// EAI_NONAME rather than a useful error. Limiting the number of
+// concurrent getaddrinfo calls to less than the permitted number of
+// file descriptors makes that error less likely. We don't bother to
+// apply the same limit to DNS lookups run directly from Go, because
+// there we will return a meaningful "too many open files" error.
+func concurrentThreadsLimit() int {
+ var rlim syscall.Rlimit
+ if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rlim); err != nil {
+ return 500
+ }
+ r := rlim.Cur
+ if r > 500 {
+ r = 500
+ } else if r > 30 {
+ r -= 30
+ }
+ return int(r)
+}
diff --git a/src/net/rpc/client.go b/src/net/rpc/client.go
index 42d13519b1..ffdc435965 100644
--- a/src/net/rpc/client.go
+++ b/src/net/rpc/client.go
@@ -53,13 +53,13 @@ type Client struct {
// A ClientCodec implements writing of RPC requests and
// reading of RPC responses for the client side of an RPC session.
-// The client calls WriteRequest to write a request to the connection
-// and calls ReadResponseHeader and ReadResponseBody in pairs
-// to read responses. The client calls Close when finished with the
+// The client calls [ClientCodec.WriteRequest] to write a request to the connection
+// and calls [ClientCodec.ReadResponseHeader] and [ClientCodec.ReadResponseBody] in pairs
+// to read responses. The client calls [ClientCodec.Close] when finished with the
// connection. ReadResponseBody may be called with a nil
// argument to force the body of the response to be read and then
// discarded.
-// See NewClient's comment for information about concurrent access.
+// See [NewClient]'s comment for information about concurrent access.
type ClientCodec interface {
WriteRequest(*Request, any) error
ReadResponseHeader(*Response) error
@@ -181,7 +181,7 @@ func (call *Call) done() {
}
}
-// NewClient returns a new Client to handle requests to the
+// NewClient returns a new [Client] to handle requests to the
// set of services at the other end of the connection.
// It adds a buffer to the write side of the connection so
// the header and payload are sent as a unit.
@@ -196,7 +196,7 @@ func NewClient(conn io.ReadWriteCloser) *Client {
return NewClientWithCodec(client)
}
-// NewClientWithCodec is like NewClient but uses the specified
+// NewClientWithCodec is like [NewClient] but uses the specified
// codec to encode requests and decode responses.
func NewClientWithCodec(codec ClientCodec) *Client {
client := &Client{
@@ -279,7 +279,7 @@ func Dial(network, address string) (*Client, error) {
}
// Close calls the underlying codec's Close method. If the connection is already
-// shutting down, ErrShutdown is returned.
+// shutting down, [ErrShutdown] is returned.
func (client *Client) Close() error {
client.mutex.Lock()
if client.closing {
@@ -291,7 +291,7 @@ func (client *Client) Close() error {
return client.codec.Close()
}
-// Go invokes the function asynchronously. It returns the Call structure representing
+// Go invokes the function asynchronously. It returns the [Call] structure representing
// the invocation. The done channel will signal when the call is complete by returning
// the same Call object. If done is nil, Go will allocate a new channel.
// If non-nil, done must be buffered or Go will deliberately crash.
diff --git a/src/net/rpc/jsonrpc/client.go b/src/net/rpc/jsonrpc/client.go
index c473017d26..1beba0f364 100644
--- a/src/net/rpc/jsonrpc/client.go
+++ b/src/net/rpc/jsonrpc/client.go
@@ -33,7 +33,7 @@ type clientCodec struct {
pending map[uint64]string // map request id to method name
}
-// NewClientCodec returns a new rpc.ClientCodec using JSON-RPC on conn.
+// NewClientCodec returns a new [rpc.ClientCodec] using JSON-RPC on conn.
func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec {
return &clientCodec{
dec: json.NewDecoder(conn),
@@ -108,7 +108,7 @@ func (c *clientCodec) Close() error {
return c.c.Close()
}
-// NewClient returns a new rpc.Client to handle requests to the
+// NewClient returns a new [rpc.Client] to handle requests to the
// set of services at the other end of the connection.
func NewClient(conn io.ReadWriteCloser) *rpc.Client {
return rpc.NewClientWithCodec(NewClientCodec(conn))
diff --git a/src/net/rpc/jsonrpc/server.go b/src/net/rpc/jsonrpc/server.go
index 3ee4ddfef2..57a4de1d0f 100644
--- a/src/net/rpc/jsonrpc/server.go
+++ b/src/net/rpc/jsonrpc/server.go
@@ -33,7 +33,7 @@ type serverCodec struct {
pending map[uint64]*json.RawMessage
}
-// NewServerCodec returns a new rpc.ServerCodec using JSON-RPC on conn.
+// NewServerCodec returns a new [rpc.ServerCodec] using JSON-RPC on conn.
func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec {
return &serverCodec{
dec: json.NewDecoder(conn),
diff --git a/src/net/rpc/server.go b/src/net/rpc/server.go
index 5cea2cc507..1771726a93 100644
--- a/src/net/rpc/server.go
+++ b/src/net/rpc/server.go
@@ -30,17 +30,17 @@ These requirements apply even if a different codec is used.
The method's first argument represents the arguments provided by the caller; the
second argument represents the result parameters to be returned to the caller.
The method's return value, if non-nil, is passed back as a string that the client
-sees as if created by errors.New. If an error is returned, the reply parameter
+sees as if created by [errors.New]. If an error is returned, the reply parameter
will not be sent back to the client.
-The server may handle requests on a single connection by calling ServeConn. More
-typically it will create a network listener and call Accept or, for an HTTP
-listener, HandleHTTP and http.Serve.
+The server may handle requests on a single connection by calling [ServeConn]. More
+typically it will create a network listener and call [Accept] or, for an HTTP
+listener, [HandleHTTP] and [http.Serve].
A client wishing to use the service establishes a connection and then invokes
-NewClient on the connection. The convenience function Dial (DialHTTP) performs
+[NewClient] on the connection. The convenience function [Dial] ([DialHTTP]) performs
both steps for a raw network connection (an HTTP connection). The resulting
-Client object has two methods, Call and Go, that specify the service and method to
+[Client] object has two methods, [Call] and Go, that specify the service and method to
call, a pointer containing the arguments, and a pointer to receive the result
parameters.
@@ -48,7 +48,7 @@ The Call method waits for the remote call to complete while the Go method
launches the call asynchronously and signals completion using the Call
structure's Done channel.
-Unless an explicit codec is set up, package encoding/gob is used to
+Unless an explicit codec is set up, package [encoding/gob] is used to
transport the data.
Here is a simple example. A server wishes to export an object of type Arith:
@@ -146,9 +146,8 @@ const (
DefaultDebugPath = "/debug/rpc"
)
-// Precompute the reflect type for error. Can't use error directly
-// because Typeof takes an empty interface value. This is annoying.
-var typeOfError = reflect.TypeOf((*error)(nil)).Elem()
+// Precompute the reflect type for error.
+var typeOfError = reflect.TypeFor[error]()
type methodType struct {
sync.Mutex // protects counters
@@ -193,12 +192,12 @@ type Server struct {
freeResp *Response
}
-// NewServer returns a new Server.
+// NewServer returns a new [Server].
func NewServer() *Server {
return &Server{}
}
-// DefaultServer is the default instance of *Server.
+// DefaultServer is the default instance of [*Server].
var DefaultServer = NewServer()
// Is this type exported or a builtin?
@@ -226,7 +225,7 @@ func (server *Server) Register(rcvr any) error {
return server.register(rcvr, "", false)
}
-// RegisterName is like Register but uses the provided name for the type
+// RegisterName is like [Register] but uses the provided name for the type
// instead of the receiver's concrete type.
func (server *Server) RegisterName(name string, rcvr any) error {
return server.register(rcvr, name, true)
@@ -441,8 +440,8 @@ func (c *gobServerCodec) Close() error {
// ServeConn blocks, serving the connection until the client hangs up.
// The caller typically invokes ServeConn in a go statement.
// ServeConn uses the gob wire format (see package gob) on the
-// connection. To use an alternate codec, use ServeCodec.
-// See NewClient's comment for information about concurrent access.
+// connection. To use an alternate codec, use [ServeCodec].
+// See [NewClient]'s comment for information about concurrent access.
func (server *Server) ServeConn(conn io.ReadWriteCloser) {
buf := bufio.NewWriter(conn)
srv := &gobServerCodec{
@@ -454,7 +453,7 @@ func (server *Server) ServeConn(conn io.ReadWriteCloser) {
server.ServeCodec(srv)
}
-// ServeCodec is like ServeConn but uses the specified codec to
+// ServeCodec is like [ServeConn] but uses the specified codec to
// decode requests and encode responses.
func (server *Server) ServeCodec(codec ServerCodec) {
sending := new(sync.Mutex)
@@ -484,7 +483,7 @@ func (server *Server) ServeCodec(codec ServerCodec) {
codec.Close()
}
-// ServeRequest is like ServeCodec but synchronously serves a single request.
+// ServeRequest is like [ServeCodec] but synchronously serves a single request.
// It does not close the codec upon completion.
func (server *Server) ServeRequest(codec ServerCodec) error {
sending := new(sync.Mutex)
@@ -636,10 +635,10 @@ func (server *Server) Accept(lis net.Listener) {
}
}
-// Register publishes the receiver's methods in the DefaultServer.
+// Register publishes the receiver's methods in the [DefaultServer].
func Register(rcvr any) error { return DefaultServer.Register(rcvr) }
-// RegisterName is like Register but uses the provided name for the type
+// RegisterName is like [Register] but uses the provided name for the type
// instead of the receiver's concrete type.
func RegisterName(name string, rcvr any) error {
return DefaultServer.RegisterName(name, rcvr)
@@ -647,12 +646,12 @@ func RegisterName(name string, rcvr any) error {
// A ServerCodec implements reading of RPC requests and writing of
// RPC responses for the server side of an RPC session.
-// The server calls ReadRequestHeader and ReadRequestBody in pairs
-// to read requests from the connection, and it calls WriteResponse to
-// write a response back. The server calls Close when finished with the
+// The server calls [ServerCodec.ReadRequestHeader] and [ServerCodec.ReadRequestBody] in pairs
+// to read requests from the connection, and it calls [ServerCodec.WriteResponse] to
+// write a response back. The server calls [ServerCodec.Close] when finished with the
// connection. ReadRequestBody may be called with a nil
// argument to force the body of the request to be read and discarded.
-// See NewClient's comment for information about concurrent access.
+// See [NewClient]'s comment for information about concurrent access.
type ServerCodec interface {
ReadRequestHeader(*Request) error
ReadRequestBody(any) error
@@ -662,37 +661,37 @@ type ServerCodec interface {
Close() error
}
-// ServeConn runs the DefaultServer on a single connection.
+// ServeConn runs the [DefaultServer] on a single connection.
// ServeConn blocks, serving the connection until the client hangs up.
// The caller typically invokes ServeConn in a go statement.
// ServeConn uses the gob wire format (see package gob) on the
-// connection. To use an alternate codec, use ServeCodec.
-// See NewClient's comment for information about concurrent access.
+// connection. To use an alternate codec, use [ServeCodec].
+// See [NewClient]'s comment for information about concurrent access.
func ServeConn(conn io.ReadWriteCloser) {
DefaultServer.ServeConn(conn)
}
-// ServeCodec is like ServeConn but uses the specified codec to
+// ServeCodec is like [ServeConn] but uses the specified codec to
// decode requests and encode responses.
func ServeCodec(codec ServerCodec) {
DefaultServer.ServeCodec(codec)
}
-// ServeRequest is like ServeCodec but synchronously serves a single request.
+// ServeRequest is like [ServeCodec] but synchronously serves a single request.
// It does not close the codec upon completion.
func ServeRequest(codec ServerCodec) error {
return DefaultServer.ServeRequest(codec)
}
// Accept accepts connections on the listener and serves requests
-// to DefaultServer for each incoming connection.
+// to [DefaultServer] for each incoming connection.
// Accept blocks; the caller typically invokes it in a go statement.
func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
// Can connect to RPC service using HTTP CONNECT to rpcPath.
var connected = "200 Connected to Go RPC"
-// ServeHTTP implements an http.Handler that answers RPC requests.
+// ServeHTTP implements an [http.Handler] that answers RPC requests.
func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if req.Method != "CONNECT" {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
@@ -711,15 +710,15 @@ func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// HandleHTTP registers an HTTP handler for RPC messages on rpcPath,
// and a debugging handler on debugPath.
-// It is still necessary to invoke http.Serve(), typically in a go statement.
+// It is still necessary to invoke [http.Serve](), typically in a go statement.
func (server *Server) HandleHTTP(rpcPath, debugPath string) {
http.Handle(rpcPath, server)
http.Handle(debugPath, debugHTTP{server})
}
-// HandleHTTP registers an HTTP handler for RPC messages to DefaultServer
-// on DefaultRPCPath and a debugging handler on DefaultDebugPath.
-// It is still necessary to invoke http.Serve(), typically in a go statement.
+// HandleHTTP registers an HTTP handler for RPC messages to [DefaultServer]
+// on [DefaultRPCPath] and a debugging handler on [DefaultDebugPath].
+// It is still necessary to invoke [http.Serve](), typically in a go statement.
func HandleHTTP() {
DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath)
}
diff --git a/src/net/sendfile_linux_test.go b/src/net/sendfile_linux_test.go
index 8cd6acca17..7a66d3645f 100644
--- a/src/net/sendfile_linux_test.go
+++ b/src/net/sendfile_linux_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build linux
-// +build linux
package net
@@ -15,29 +14,36 @@ import (
)
func BenchmarkSendFile(b *testing.B) {
+ b.Run("file-to-tcp", func(b *testing.B) { benchmarkSendFile(b, "tcp") })
+ b.Run("file-to-unix", func(b *testing.B) { benchmarkSendFile(b, "unix") })
+}
+
+func benchmarkSendFile(b *testing.B, proto string) {
for i := 0; i <= 10; i++ {
size := 1 << (i + 10)
- bench := sendFileBench{chunkSize: size}
+ bench := sendFileBench{
+ proto: proto,
+ chunkSize: size,
+ }
b.Run(strconv.Itoa(size), bench.benchSendFile)
}
}
type sendFileBench struct {
+ proto string
chunkSize int
}
func (bench sendFileBench) benchSendFile(b *testing.B) {
fileSize := b.N * bench.chunkSize
f := createTempFile(b, fileSize)
- fileName := f.Name()
- defer os.Remove(fileName)
- defer f.Close()
- client, server := spliceTestSocketPair(b, "tcp")
+ client, server := spliceTestSocketPair(b, bench.proto)
defer server.Close()
cleanUp, err := startSpliceClient(client, "r", bench.chunkSize, fileSize)
if err != nil {
+ client.Close()
b.Fatal(err)
}
defer cleanUp()
@@ -52,15 +58,18 @@ func (bench sendFileBench) benchSendFile(b *testing.B) {
b.Fatalf("failed to copy data with sendfile, error: %v", err)
}
if sent != int64(fileSize) {
- b.Fatalf("bytes sent mismatch\n\texpect: %d\n\tgot: %d", fileSize, sent)
+ b.Fatalf("bytes sent mismatch, got: %d, want: %d", sent, fileSize)
}
}
func createTempFile(b *testing.B, size int) *os.File {
- f, err := os.CreateTemp("", "linux-sendfile-test")
+ f, err := os.CreateTemp(b.TempDir(), "linux-sendfile-bench")
if err != nil {
b.Fatalf("failed to create temporary file: %v", err)
}
+ b.Cleanup(func() {
+ f.Close()
+ })
data := make([]byte, size)
if _, err := f.Write(data); err != nil {
diff --git a/src/net/sendfile_stub.go b/src/net/sendfile_stub.go
index c7a2e6a1e4..a4fdd99ffe 100644
--- a/src/net/sendfile_stub.go
+++ b/src/net/sendfile_stub.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build aix || (js && wasm) || netbsd || openbsd || ios || wasip1
+//go:build aix || js || netbsd || openbsd || ios || wasip1
package net
diff --git a/src/net/sendfile_test.go b/src/net/sendfile_test.go
index 44a87a1d20..4cba1ed2b1 100644
--- a/src/net/sendfile_test.go
+++ b/src/net/sendfile_test.go
@@ -2,12 +2,11 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !wasip1
-
package net
import (
"bytes"
+ "context"
"crypto/sha256"
"encoding/hex"
"errors"
@@ -209,7 +208,7 @@ func TestSendfileSeeked(t *testing.T) {
// Test that sendfile doesn't put a pipe into blocking mode.
func TestSendfilePipe(t *testing.T) {
switch runtime.GOOS {
- case "plan9", "windows":
+ case "plan9", "windows", "js", "wasip1":
// These systems don't support deadlines on pipes.
t.Skipf("skipping on %s", runtime.GOOS)
}
@@ -362,3 +361,88 @@ func TestSendfileOnWriteTimeoutExceeded(t *testing.T) {
t.Fatal(err)
}
}
+
+func BenchmarkSendfileZeroBytes(b *testing.B) {
+ var (
+ wg sync.WaitGroup
+ ctx, cancel = context.WithCancel(context.Background())
+ )
+
+ defer wg.Wait()
+
+ ln := newLocalListener(b, "tcp")
+ defer ln.Close()
+
+ tempFile, err := os.CreateTemp(b.TempDir(), "test.txt")
+ if err != nil {
+ b.Fatalf("failed to create temp file: %v", err)
+ }
+ defer tempFile.Close()
+
+ fileName := tempFile.Name()
+
+ dataSize := b.N
+ wg.Add(1)
+ go func(f *os.File) {
+ defer wg.Done()
+
+ for i := 0; i < dataSize; i++ {
+ if _, err := f.Write([]byte{1}); err != nil {
+ b.Errorf("failed to write: %v", err)
+ return
+ }
+ if i%1000 == 0 {
+ f.Sync()
+ }
+ }
+ }(tempFile)
+
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ wg.Add(1)
+ go func(ln Listener, fileName string) {
+ defer wg.Done()
+
+ conn, err := ln.Accept()
+ if err != nil {
+ b.Errorf("failed to accept: %v", err)
+ return
+ }
+ defer conn.Close()
+
+ f, err := os.OpenFile(fileName, os.O_RDONLY, 0660)
+ if err != nil {
+ b.Errorf("failed to open file: %v", err)
+ return
+ }
+ defer f.Close()
+
+ for {
+ if ctx.Err() != nil {
+ return
+ }
+
+ if _, err := io.Copy(conn, f); err != nil {
+ b.Errorf("failed to copy: %v", err)
+ return
+ }
+ }
+ }(ln, fileName)
+
+ conn, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ b.Fatalf("failed to dial: %v", err)
+ }
+ defer conn.Close()
+
+ n, err := io.CopyN(io.Discard, conn, int64(dataSize))
+ if err != nil {
+ b.Fatalf("failed to copy: %v", err)
+ }
+ if n != int64(dataSize) {
+ b.Fatalf("expected %d copied bytes, but got %d", dataSize, n)
+ }
+
+ cancel()
+}
diff --git a/src/net/sendfile_unix_alt.go b/src/net/sendfile_unix_alt.go
index b86771721e..5cb65ee767 100644
--- a/src/net/sendfile_unix_alt.go
+++ b/src/net/sendfile_unix_alt.go
@@ -15,8 +15,8 @@ import (
// sendFile copies the contents of r to c using the sendfile
// system call to minimize copies.
//
-// if handled == true, sendFile returns the number of bytes copied and any
-// non-EOF error.
+// if handled == true, sendFile returns the number (potentially zero) of bytes
+// copied and any non-EOF error.
//
// if handled == false, sendFile performed no work.
func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) {
@@ -65,7 +65,7 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) {
var werr error
err = sc.Read(func(fd uintptr) bool {
- written, werr = poll.SendFile(&c.pfd, int(fd), pos, remain)
+ written, werr, handled = poll.SendFile(&c.pfd, int(fd), pos, remain)
return true
})
if err == nil {
@@ -78,8 +78,8 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) {
_, err1 := f.Seek(written, io.SeekCurrent)
if err1 != nil && err == nil {
- return written, err1, written > 0
+ return written, err1, handled
}
- return written, wrapSyscallError("sendfile", err), written > 0
+ return written, wrapSyscallError("sendfile", err), handled
}
diff --git a/src/net/server_test.go b/src/net/server_test.go
index 2ff0689067..eb6b111f1f 100644
--- a/src/net/server_test.go
+++ b/src/net/server_test.go
@@ -2,11 +2,10 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !wasip1
-
package net
import (
+ "fmt"
"os"
"testing"
)
@@ -251,65 +250,80 @@ var udpServerTests = []struct {
func TestUDPServer(t *testing.T) {
for i, tt := range udpServerTests {
- if !testableListenArgs(tt.snet, tt.saddr, tt.taddr) {
- t.Logf("skipping %s test", tt.snet+" "+tt.saddr+"<-"+tt.taddr)
- continue
- }
-
- c1, err := ListenPacket(tt.snet, tt.saddr)
- if err != nil {
- if perr := parseDialError(err); perr != nil {
- t.Error(perr)
+ i, tt := i, tt
+ t.Run(fmt.Sprint(i), func(t *testing.T) {
+ if !testableListenArgs(tt.snet, tt.saddr, tt.taddr) {
+ t.Skipf("skipping %s %s<-%s test", tt.snet, tt.saddr, tt.taddr)
}
- t.Fatal(err)
- }
+ t.Logf("%s %s<-%s", tt.snet, tt.saddr, tt.taddr)
- ls := (&packetListener{PacketConn: c1}).newLocalServer()
- defer ls.teardown()
- tpch := make(chan error, 1)
- handler := func(ls *localPacketServer, c PacketConn) { packetTransponder(c, tpch) }
- if err := ls.buildup(handler); err != nil {
- t.Fatal(err)
- }
-
- trch := make(chan error, 1)
- _, port, err := SplitHostPort(ls.PacketConn.LocalAddr().String())
- if err != nil {
- t.Fatal(err)
- }
- if tt.dial {
- d := Dialer{Timeout: someTimeout}
- c2, err := d.Dial(tt.tnet, JoinHostPort(tt.taddr, port))
+ c1, err := ListenPacket(tt.snet, tt.saddr)
if err != nil {
if perr := parseDialError(err); perr != nil {
t.Error(perr)
}
t.Fatal(err)
}
- defer c2.Close()
- go transceiver(c2, []byte("UDP SERVER TEST"), trch)
- } else {
- c2, err := ListenPacket(tt.tnet, JoinHostPort(tt.taddr, "0"))
- if err != nil {
- if perr := parseDialError(err); perr != nil {
- t.Error(perr)
- }
+
+ ls := (&packetListener{PacketConn: c1}).newLocalServer()
+ defer ls.teardown()
+ tpch := make(chan error, 1)
+ handler := func(ls *localPacketServer, c PacketConn) { packetTransponder(c, tpch) }
+ if err := ls.buildup(handler); err != nil {
t.Fatal(err)
}
- defer c2.Close()
- dst, err := ResolveUDPAddr(tt.tnet, JoinHostPort(tt.taddr, port))
+
+ trch := make(chan error, 1)
+ _, port, err := SplitHostPort(ls.PacketConn.LocalAddr().String())
if err != nil {
t.Fatal(err)
}
- go packetTransceiver(c2, []byte("UDP SERVER TEST"), dst, trch)
- }
+ if tt.dial {
+ d := Dialer{Timeout: someTimeout}
+ c2, err := d.Dial(tt.tnet, JoinHostPort(tt.taddr, port))
+ if err != nil {
+ if perr := parseDialError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ defer c2.Close()
+ go transceiver(c2, []byte("UDP SERVER TEST"), trch)
+ } else {
+ c2, err := ListenPacket(tt.tnet, JoinHostPort(tt.taddr, "0"))
+ if err != nil {
+ if perr := parseDialError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ defer c2.Close()
+ dst, err := ResolveUDPAddr(tt.tnet, JoinHostPort(tt.taddr, port))
+ if err != nil {
+ t.Fatal(err)
+ }
+ go packetTransceiver(c2, []byte("UDP SERVER TEST"), dst, trch)
+ }
- for err := range trch {
- t.Errorf("#%d: %v", i, err)
- }
- for err := range tpch {
- t.Errorf("#%d: %v", i, err)
- }
+ for trch != nil || tpch != nil {
+ select {
+ case err, ok := <-trch:
+ if !ok {
+ trch = nil
+ }
+ if err != nil {
+ t.Errorf("client: %v", err)
+ }
+ case err, ok := <-tpch:
+ if !ok {
+ tpch = nil
+ }
+ if err != nil {
+ t.Errorf("server: %v", err)
+ }
+ }
+ }
+ })
}
}
@@ -326,58 +340,73 @@ func TestUnixgramServer(t *testing.T) {
}
for i, tt := range unixgramServerTests {
- if !testableListenArgs("unixgram", tt.saddr, "") {
- t.Logf("skipping %s test", "unixgram "+tt.saddr+"<-"+tt.caddr)
- continue
- }
-
- c1, err := ListenPacket("unixgram", tt.saddr)
- if err != nil {
- if perr := parseDialError(err); perr != nil {
- t.Error(perr)
+ i, tt := i, tt
+ t.Run(fmt.Sprint(i), func(t *testing.T) {
+ if !testableListenArgs("unixgram", tt.saddr, "") {
+ t.Skipf("skipping unixgram %s<-%s test", tt.saddr, tt.caddr)
}
- t.Fatal(err)
- }
-
- ls := (&packetListener{PacketConn: c1}).newLocalServer()
- defer ls.teardown()
- tpch := make(chan error, 1)
- handler := func(ls *localPacketServer, c PacketConn) { packetTransponder(c, tpch) }
- if err := ls.buildup(handler); err != nil {
- t.Fatal(err)
- }
+ t.Logf("unixgram %s<-%s", tt.saddr, tt.caddr)
- trch := make(chan error, 1)
- if tt.dial {
- d := Dialer{Timeout: someTimeout, LocalAddr: &UnixAddr{Net: "unixgram", Name: tt.caddr}}
- c2, err := d.Dial("unixgram", ls.PacketConn.LocalAddr().String())
+ c1, err := ListenPacket("unixgram", tt.saddr)
if err != nil {
if perr := parseDialError(err); perr != nil {
t.Error(perr)
}
t.Fatal(err)
}
- defer os.Remove(c2.LocalAddr().String())
- defer c2.Close()
- go transceiver(c2, []byte(c2.LocalAddr().String()), trch)
- } else {
- c2, err := ListenPacket("unixgram", tt.caddr)
- if err != nil {
- if perr := parseDialError(err); perr != nil {
- t.Error(perr)
- }
+
+ ls := (&packetListener{PacketConn: c1}).newLocalServer()
+ defer ls.teardown()
+ tpch := make(chan error, 1)
+ handler := func(ls *localPacketServer, c PacketConn) { packetTransponder(c, tpch) }
+ if err := ls.buildup(handler); err != nil {
t.Fatal(err)
}
- defer os.Remove(c2.LocalAddr().String())
- defer c2.Close()
- go packetTransceiver(c2, []byte("UNIXGRAM SERVER TEST"), ls.PacketConn.LocalAddr(), trch)
- }
- for err := range trch {
- t.Errorf("#%d: %v", i, err)
- }
- for err := range tpch {
- t.Errorf("#%d: %v", i, err)
- }
+ trch := make(chan error, 1)
+ if tt.dial {
+ d := Dialer{Timeout: someTimeout, LocalAddr: &UnixAddr{Net: "unixgram", Name: tt.caddr}}
+ c2, err := d.Dial("unixgram", ls.PacketConn.LocalAddr().String())
+ if err != nil {
+ if perr := parseDialError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ defer os.Remove(c2.LocalAddr().String())
+ defer c2.Close()
+ go transceiver(c2, []byte(c2.LocalAddr().String()), trch)
+ } else {
+ c2, err := ListenPacket("unixgram", tt.caddr)
+ if err != nil {
+ if perr := parseDialError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ defer os.Remove(c2.LocalAddr().String())
+ defer c2.Close()
+ go packetTransceiver(c2, []byte("UNIXGRAM SERVER TEST"), ls.PacketConn.LocalAddr(), trch)
+ }
+
+ for trch != nil || tpch != nil {
+ select {
+ case err, ok := <-trch:
+ if !ok {
+ trch = nil
+ }
+ if err != nil {
+ t.Errorf("client: %v", err)
+ }
+ case err, ok := <-tpch:
+ if !ok {
+ tpch = nil
+ }
+ if err != nil {
+ t.Errorf("server: %v", err)
+ }
+ }
+ }
+ })
}
}
diff --git a/src/net/smtp/auth.go b/src/net/smtp/auth.go
index 72eb16671f..6d461acc48 100644
--- a/src/net/smtp/auth.go
+++ b/src/net/smtp/auth.go
@@ -42,7 +42,7 @@ type plainAuth struct {
host string
}
-// PlainAuth returns an Auth that implements the PLAIN authentication
+// PlainAuth returns an [Auth] that implements the PLAIN authentication
// mechanism as defined in RFC 4616. The returned Auth uses the given
// username and password to authenticate to host and act as identity.
// Usually identity should be the empty string, to act as username.
@@ -86,7 +86,7 @@ type cramMD5Auth struct {
username, secret string
}
-// CRAMMD5Auth returns an Auth that implements the CRAM-MD5 authentication
+// CRAMMD5Auth returns an [Auth] that implements the CRAM-MD5 authentication
// mechanism as defined in RFC 2195.
// The returned Auth uses the given username and secret to authenticate
// to the server using the challenge-response mechanism.
diff --git a/src/net/smtp/smtp.go b/src/net/smtp/smtp.go
index b5a025ef2a..b7877936da 100644
--- a/src/net/smtp/smtp.go
+++ b/src/net/smtp/smtp.go
@@ -48,7 +48,7 @@ type Client struct {
helloError error // the error from the hello
}
-// Dial returns a new Client connected to an SMTP server at addr.
+// Dial returns a new [Client] connected to an SMTP server at addr.
// The addr must include a port, as in "mail.example.com:smtp".
func Dial(addr string) (*Client, error) {
conn, err := net.Dial("tcp", addr)
@@ -59,7 +59,7 @@ func Dial(addr string) (*Client, error) {
return NewClient(conn, host)
}
-// NewClient returns a new Client using an existing connection and host as a
+// NewClient returns a new [Client] using an existing connection and host as a
// server name to be used when authenticating.
func NewClient(conn net.Conn, host string) (*Client, error) {
text := textproto.NewConn(conn)
@@ -166,7 +166,7 @@ func (c *Client) StartTLS(config *tls.Config) error {
}
// TLSConnectionState returns the client's TLS connection state.
-// The return values are their zero values if StartTLS did
+// The return values are their zero values if [Client.StartTLS] did
// not succeed.
func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) {
tc, ok := c.conn.(*tls.Conn)
@@ -241,7 +241,7 @@ func (c *Client) Auth(a Auth) error {
// If the server supports the 8BITMIME extension, Mail adds the BODY=8BITMIME
// parameter. If the server supports the SMTPUTF8 extension, Mail adds the
// SMTPUTF8 parameter.
-// This initiates a mail transaction and is followed by one or more Rcpt calls.
+// This initiates a mail transaction and is followed by one or more [Client.Rcpt] calls.
func (c *Client) Mail(from string) error {
if err := validateLine(from); err != nil {
return err
@@ -263,8 +263,8 @@ func (c *Client) Mail(from string) error {
}
// Rcpt issues a RCPT command to the server using the provided email address.
-// A call to Rcpt must be preceded by a call to Mail and may be followed by
-// a Data call or another Rcpt call.
+// A call to Rcpt must be preceded by a call to [Client.Mail] and may be followed by
+// a [Client.Data] call or another Rcpt call.
func (c *Client) Rcpt(to string) error {
if err := validateLine(to); err != nil {
return err
@@ -287,7 +287,7 @@ func (d *dataCloser) Close() error {
// Data issues a DATA command to the server and returns a writer that
// can be used to write the mail headers and body. The caller should
// close the writer before calling any more methods on c. A call to
-// Data must be preceded by one or more calls to Rcpt.
+// Data must be preceded by one or more calls to [Client.Rcpt].
func (c *Client) Data() (io.WriteCloser, error) {
_, _, err := c.cmd(354, "DATA")
if err != nil {
diff --git a/src/net/sock_posix.go b/src/net/sock_posix.go
index b3e1806ba9..d04c26e7ef 100644
--- a/src/net/sock_posix.go
+++ b/src/net/sock_posix.go
@@ -89,38 +89,10 @@ func (fd *netFD) ctrlNetwork() string {
return fd.net + "6"
}
-func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr {
- switch fd.family {
- case syscall.AF_INET, syscall.AF_INET6:
- switch fd.sotype {
- case syscall.SOCK_STREAM:
- return sockaddrToTCP
- case syscall.SOCK_DGRAM:
- return sockaddrToUDP
- case syscall.SOCK_RAW:
- return sockaddrToIP
- }
- case syscall.AF_UNIX:
- switch fd.sotype {
- case syscall.SOCK_STREAM:
- return sockaddrToUnix
- case syscall.SOCK_DGRAM:
- return sockaddrToUnixgram
- case syscall.SOCK_SEQPACKET:
- return sockaddrToUnixpacket
- }
- }
- return func(syscall.Sockaddr) Addr { return nil }
-}
-
func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) error {
var c *rawConn
- var err error
if ctrlCtxFn != nil {
- c, err = newRawConn(fd)
- if err != nil {
- return err
- }
+ c = newRawConn(fd)
var ctrlAddr string
if raddr != nil {
ctrlAddr = raddr.String()
@@ -133,6 +105,7 @@ func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlCtxFn func
}
var lsa syscall.Sockaddr
+ var err error
if laddr != nil {
if lsa, err = laddr.sockaddr(fd.family); err != nil {
return err
@@ -185,10 +158,7 @@ func (fd *netFD) listenStream(ctx context.Context, laddr sockaddr, backlog int,
}
if ctrlCtxFn != nil {
- c, err := newRawConn(fd)
- if err != nil {
- return err
- }
+ c := newRawConn(fd)
if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), laddr.String(), c); err != nil {
return err
}
@@ -239,10 +209,7 @@ func (fd *netFD) listenDatagram(ctx context.Context, laddr sockaddr, ctrlCtxFn f
}
if ctrlCtxFn != nil {
- c, err := newRawConn(fd)
- if err != nil {
- return err
- }
+ c := newRawConn(fd)
if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), laddr.String(), c); err != nil {
return err
}
diff --git a/src/net/sock_stub.go b/src/net/sock_stub.go
index e163755568..fd86fa92dc 100644
--- a/src/net/sock_stub.go
+++ b/src/net/sock_stub.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build aix || (js && wasm) || solaris || wasip1
+//go:build aix || js || solaris || wasip1
package net
diff --git a/src/net/sock_windows.go b/src/net/sock_windows.go
index fa11c7af2e..a519909bb0 100644
--- a/src/net/sock_windows.go
+++ b/src/net/sock_windows.go
@@ -11,29 +11,15 @@ import (
)
func maxListenerBacklog() int {
- // TODO: Implement this
- // NOTE: Never return a number bigger than 1<<16 - 1. See issue 5030.
+ // When the socket backlog is SOMAXCONN, Windows will set the backlog to
+ // "a reasonable maximum value".
+ // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-listen
return syscall.SOMAXCONN
}
func sysSocket(family, sotype, proto int) (syscall.Handle, error) {
s, err := wsaSocketFunc(int32(family), int32(sotype), int32(proto),
nil, 0, windows.WSA_FLAG_OVERLAPPED|windows.WSA_FLAG_NO_HANDLE_INHERIT)
- if err == nil {
- return s, nil
- }
- // WSA_FLAG_NO_HANDLE_INHERIT flag is not supported on some
- // old versions of Windows, see
- // https://msdn.microsoft.com/en-us/library/windows/desktop/ms742212(v=vs.85).aspx
- // for details. Just use syscall.Socket, if windows.WSASocket failed.
-
- // See ../syscall/exec_unix.go for description of ForkLock.
- syscall.ForkLock.RLock()
- s, err = socketFunc(family, sotype, proto)
- if err == nil {
- syscall.CloseOnExec(s)
- }
- syscall.ForkLock.RUnlock()
if err != nil {
return syscall.InvalidHandle, os.NewSyscallError("socket", err)
}
diff --git a/src/net/sockaddr_posix.go b/src/net/sockaddr_posix.go
index e44fc76f4b..c5604fca35 100644
--- a/src/net/sockaddr_posix.go
+++ b/src/net/sockaddr_posix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build unix || (js && wasm) || wasip1 || windows
+//go:build unix || js || wasip1 || windows
package net
@@ -32,3 +32,27 @@ type sockaddr interface {
// toLocal maps the zero address to a local system address (127.0.0.1 or ::1)
toLocal(net string) sockaddr
}
+
+func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr {
+ switch fd.family {
+ case syscall.AF_INET, syscall.AF_INET6:
+ switch fd.sotype {
+ case syscall.SOCK_STREAM:
+ return sockaddrToTCP
+ case syscall.SOCK_DGRAM:
+ return sockaddrToUDP
+ case syscall.SOCK_RAW:
+ return sockaddrToIP
+ }
+ case syscall.AF_UNIX:
+ switch fd.sotype {
+ case syscall.SOCK_STREAM:
+ return sockaddrToUnix
+ case syscall.SOCK_DGRAM:
+ return sockaddrToUnixgram
+ case syscall.SOCK_SEQPACKET:
+ return sockaddrToUnixpacket
+ }
+ }
+ return func(syscall.Sockaddr) Addr { return nil }
+}
diff --git a/src/net/sockopt_stub.go b/src/net/sockopt_fake.go
index 186d8912cb..9d9f7ea951 100644
--- a/src/net/sockopt_stub.go
+++ b/src/net/sockopt_fake.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build (js && wasm) || wasip1
+//go:build js || wasip1
package net
@@ -21,10 +21,16 @@ func setDefaultMulticastSockopts(s int) error {
}
func setReadBuffer(fd *netFD, bytes int) error {
+ if fd.fakeNetFD != nil {
+ return fd.fakeNetFD.setReadBuffer(bytes)
+ }
return syscall.ENOPROTOOPT
}
func setWriteBuffer(fd *netFD, bytes int) error {
+ if fd.fakeNetFD != nil {
+ return fd.fakeNetFD.setWriteBuffer(bytes)
+ }
return syscall.ENOPROTOOPT
}
@@ -33,5 +39,8 @@ func setKeepAlive(fd *netFD, keepalive bool) error {
}
func setLinger(fd *netFD, sec int) error {
+ if fd.fakeNetFD != nil {
+ return fd.fakeNetFD.setLinger(sec)
+ }
return syscall.ENOPROTOOPT
}
diff --git a/src/net/sockopt_posix.go b/src/net/sockopt_posix.go
index 32e8fcd505..a380c7719b 100644
--- a/src/net/sockopt_posix.go
+++ b/src/net/sockopt_posix.go
@@ -20,35 +20,6 @@ func boolint(b bool) int {
return 0
}
-func ipv4AddrToInterface(ip IP) (*Interface, error) {
- ift, err := Interfaces()
- if err != nil {
- return nil, err
- }
- for _, ifi := range ift {
- ifat, err := ifi.Addrs()
- if err != nil {
- return nil, err
- }
- for _, ifa := range ifat {
- switch v := ifa.(type) {
- case *IPAddr:
- if ip.Equal(v.IP) {
- return &ifi, nil
- }
- case *IPNet:
- if ip.Equal(v.IP) {
- return &ifi, nil
- }
- }
- }
- }
- if ip.Equal(IPv4zero) {
- return nil, nil
- }
- return nil, errNoSuchInterface
-}
-
func interfaceToIPv4Addr(ifi *Interface) (IP, error) {
if ifi == nil {
return IPv4zero, nil
diff --git a/src/net/sockoptip_stub.go b/src/net/sockoptip_stub.go
index a37c31223d..23891a865f 100644
--- a/src/net/sockoptip_stub.go
+++ b/src/net/sockoptip_stub.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build (js && wasm) || wasip1
+//go:build js || wasip1
package net
diff --git a/src/net/sockoptip_windows.go b/src/net/sockoptip_windows.go
index 62676039a3..9dfa37c51e 100644
--- a/src/net/sockoptip_windows.go
+++ b/src/net/sockoptip_windows.go
@@ -8,7 +8,6 @@ import (
"os"
"runtime"
"syscall"
- "unsafe"
)
func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
@@ -18,7 +17,7 @@ func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
}
var a [4]byte
copy(a[:], ip.To4())
- err = fd.pfd.Setsockopt(syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, (*byte)(unsafe.Pointer(&a[0])), 4)
+ err = fd.pfd.SetsockoptInet4Addr(syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, a)
runtime.KeepAlive(fd)
return wrapSyscallError("setsockopt", err)
}
diff --git a/src/net/splice_linux.go b/src/net/splice_linux.go
index ab2ab70b28..bdafcb59ab 100644
--- a/src/net/splice_linux.go
+++ b/src/net/splice_linux.go
@@ -9,12 +9,12 @@ import (
"io"
)
-// splice transfers data from r to c using the splice system call to minimize
-// copies from and to userspace. c must be a TCP connection. Currently, splice
-// is only enabled if r is a TCP or a stream-oriented Unix connection.
+// spliceFrom transfers data from r to c using the splice system call to minimize
+// copies from and to userspace. c must be a TCP connection.
+// Currently, spliceFrom is only enabled if r is a TCP or a stream-oriented Unix connection.
//
-// If splice returns handled == false, it has performed no work.
-func splice(c *netFD, r io.Reader) (written int64, err error, handled bool) {
+// If spliceFrom returns handled == false, it has performed no work.
+func spliceFrom(c *netFD, r io.Reader) (written int64, err error, handled bool) {
var remain int64 = 1<<63 - 1 // by default, copy until EOF
lr, ok := r.(*io.LimitedReader)
if ok {
@@ -25,14 +25,17 @@ func splice(c *netFD, r io.Reader) (written int64, err error, handled bool) {
}
var s *netFD
- if tc, ok := r.(*TCPConn); ok {
- s = tc.fd
- } else if uc, ok := r.(*UnixConn); ok {
- if uc.fd.net != "unix" {
+ switch v := r.(type) {
+ case *TCPConn:
+ s = v.fd
+ case tcpConnWithoutWriteTo:
+ s = v.fd
+ case *UnixConn:
+ if v.fd.net != "unix" {
return 0, nil, false
}
- s = uc.fd
- } else {
+ s = v.fd
+ default:
return 0, nil, false
}
@@ -42,3 +45,18 @@ func splice(c *netFD, r io.Reader) (written int64, err error, handled bool) {
}
return written, wrapSyscallError(sc, err), handled
}
+
+// spliceTo transfers data from c to w using the splice system call to minimize
+// copies from and to userspace. c must be a TCP connection.
+// Currently, spliceTo is only enabled if w is a stream-oriented Unix connection.
+//
+// If spliceTo returns handled == false, it has performed no work.
+func spliceTo(w io.Writer, c *netFD) (written int64, err error, handled bool) {
+ uc, ok := w.(*UnixConn)
+ if !ok || uc.fd.net != "unix" {
+ return
+ }
+
+ written, handled, sc, err := poll.Splice(&uc.fd.pfd, &c.pfd, 1<<63-1)
+ return written, wrapSyscallError(sc, err), handled
+}
diff --git a/src/net/splice_stub.go b/src/net/splice_stub.go
index 3cdadb11c5..239227ff88 100644
--- a/src/net/splice_stub.go
+++ b/src/net/splice_stub.go
@@ -8,6 +8,10 @@ package net
import "io"
-func splice(c *netFD, r io.Reader) (int64, error, bool) {
+func spliceFrom(_ *netFD, _ io.Reader) (int64, error, bool) {
+ return 0, nil, false
+}
+
+func spliceTo(_ io.Writer, _ *netFD) (int64, error, bool) {
return 0, nil, false
}
diff --git a/src/net/splice_test.go b/src/net/splice_test.go
index 75a8f274ff..227ddebff4 100644
--- a/src/net/splice_test.go
+++ b/src/net/splice_test.go
@@ -23,6 +23,7 @@ func TestSplice(t *testing.T) {
t.Skip("skipping unix-to-tcp tests")
}
t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") })
+ t.Run("tcp-to-unix", func(t *testing.T) { testSplice(t, "tcp", "unix") })
t.Run("tcp-to-file", func(t *testing.T) { testSpliceToFile(t, "tcp", "file") })
t.Run("unix-to-file", func(t *testing.T) { testSpliceToFile(t, "unix", "file") })
t.Run("no-unixpacket", testSpliceNoUnixpacket)
@@ -159,6 +160,13 @@ func (tc spliceTestCase) testFile(t *testing.T) {
}
func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
+ // UnixConn doesn't implement io.ReaderFrom, which will fail
+ // the following test in asserting a UnixConn to be an io.ReaderFrom,
+ // so skip this test.
+ if upNet == "unix" || downNet == "unix" {
+ t.Skip("skipping test on unix socket")
+ }
+
clientUp, serverUp := spliceTestSocketPair(t, upNet)
defer clientUp.Close()
clientDown, serverDown := spliceTestSocketPair(t, downNet)
@@ -166,16 +174,16 @@ func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
serverUp.Close()
- // We'd like to call net.splice here and check the handled return
+ // We'd like to call net.spliceFrom here and check the handled return
// value, but we disable splice on old Linux kernels.
//
- // In that case, poll.Splice and net.splice return a non-nil error
+ // In that case, poll.Splice and net.spliceFrom return a non-nil error
// and handled == false. We'd ideally like to see handled == true
// because the source reader is at EOF, but if we're running on an old
- // kernel, and splice is disabled, we won't see EOF from net.splice,
+ // kernel, and splice is disabled, we won't see EOF from net.spliceFrom,
// because we won't touch the reader at all.
//
- // Trying to untangle the errors from net.splice and match them
+ // Trying to untangle the errors from net.spliceFrom and match them
// against the errors created by the poll package would be brittle,
// so this is a higher level test.
//
@@ -268,7 +276,7 @@ func testSpliceNoUnixpacket(t *testing.T) {
//
// What we want is err == nil and handled == false, i.e. we never
// called poll.Splice, because we know the unix socket's network.
- _, err, handled := splice(serverDown.(*TCPConn).fd, serverUp)
+ _, err, handled := spliceFrom(serverDown.(*TCPConn).fd, serverUp)
if err != nil || handled != false {
t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
}
@@ -289,7 +297,7 @@ func testSpliceNoUnixgram(t *testing.T) {
defer clientDown.Close()
defer serverDown.Close()
// Analogous to testSpliceNoUnixpacket.
- _, err, handled := splice(serverDown.(*TCPConn).fd, up)
+ _, err, handled := spliceFrom(serverDown.(*TCPConn).fd, up)
if err != nil || handled != false {
t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
}
@@ -300,6 +308,7 @@ func BenchmarkSplice(b *testing.B) {
b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") })
b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") })
+ b.Run("tcp-to-unix", func(b *testing.B) { benchSplice(b, "tcp", "unix") })
}
func benchSplice(b *testing.B, upNet, downNet string) {
diff --git a/src/net/tcpsock.go b/src/net/tcpsock.go
index 358e48723b..590516bff1 100644
--- a/src/net/tcpsock.go
+++ b/src/net/tcpsock.go
@@ -24,7 +24,7 @@ type TCPAddr struct {
Zone string // IPv6 scoped addressing zone
}
-// AddrPort returns the TCPAddr a as a netip.AddrPort.
+// AddrPort returns the [TCPAddr] a as a [netip.AddrPort].
//
// If a.Port does not fit in a uint16, it's silently truncated.
//
@@ -79,7 +79,7 @@ func (a *TCPAddr) opAddr() Addr {
// recommended, because it will return at most one of the host name's
// IP addresses.
//
-// See func Dial for a description of the network and address
+// See func [Dial] for a description of the network and address
// parameters.
func ResolveTCPAddr(network, address string) (*TCPAddr, error) {
switch network {
@@ -96,7 +96,7 @@ func ResolveTCPAddr(network, address string) (*TCPAddr, error) {
return addrs.forResolve(network, address).(*TCPAddr), nil
}
-// TCPAddrFromAddrPort returns addr as a TCPAddr. If addr.IsValid() is false,
+// TCPAddrFromAddrPort returns addr as a [TCPAddr]. If addr.IsValid() is false,
// then the returned TCPAddr will contain a nil IP field, indicating an
// address family-agnostic unspecified address.
func TCPAddrFromAddrPort(addr netip.AddrPort) *TCPAddr {
@@ -107,22 +107,22 @@ func TCPAddrFromAddrPort(addr netip.AddrPort) *TCPAddr {
}
}
-// TCPConn is an implementation of the Conn interface for TCP network
+// TCPConn is an implementation of the [Conn] interface for TCP network
// connections.
type TCPConn struct {
conn
}
// SyscallConn returns a raw network connection.
-// This implements the syscall.Conn interface.
+// This implements the [syscall.Conn] interface.
func (c *TCPConn) SyscallConn() (syscall.RawConn, error) {
if !c.ok() {
return nil, syscall.EINVAL
}
- return newRawConn(c.fd)
+ return newRawConn(c.fd), nil
}
-// ReadFrom implements the io.ReaderFrom ReadFrom method.
+// ReadFrom implements the [io.ReaderFrom] ReadFrom method.
func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) {
if !c.ok() {
return 0, syscall.EINVAL
@@ -134,6 +134,18 @@ func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) {
return n, err
}
+// WriteTo implements the io.WriterTo WriteTo method.
+func (c *TCPConn) WriteTo(w io.Writer) (int64, error) {
+ if !c.ok() {
+ return 0, syscall.EINVAL
+ }
+ n, err := c.writeTo(w)
+ if err != nil && err != io.EOF {
+ err = &OpError{Op: "writeto", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return n, err
+}
+
// CloseRead shuts down the reading side of the TCP connection.
// Most callers should just use Close.
func (c *TCPConn) CloseRead() error {
@@ -250,7 +262,7 @@ func newTCPConn(fd *netFD, keepAlive time.Duration, keepAliveHook func(time.Dura
return &TCPConn{conn{fd}}
}
-// DialTCP acts like Dial for TCP networks.
+// DialTCP acts like [Dial] for TCP networks.
//
// The network must be a TCP network name; see func Dial for details.
//
@@ -275,14 +287,14 @@ func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
}
// TCPListener is a TCP network listener. Clients should typically
-// use variables of type Listener instead of assuming TCP.
+// use variables of type [Listener] instead of assuming TCP.
type TCPListener struct {
fd *netFD
lc ListenConfig
}
// SyscallConn returns a raw network connection.
-// This implements the syscall.Conn interface.
+// This implements the [syscall.Conn] interface.
//
// The returned RawConn only supports calling Control. Read and
// Write return an error.
@@ -290,7 +302,7 @@ func (l *TCPListener) SyscallConn() (syscall.RawConn, error) {
if !l.ok() {
return nil, syscall.EINVAL
}
- return newRawListener(l.fd)
+ return newRawListener(l.fd), nil
}
// AcceptTCP accepts the next incoming call and returns the new
@@ -306,8 +318,8 @@ func (l *TCPListener) AcceptTCP() (*TCPConn, error) {
return c, nil
}
-// Accept implements the Accept method in the Listener interface; it
-// waits for the next call and returns a generic Conn.
+// Accept implements the Accept method in the [Listener] interface; it
+// waits for the next call and returns a generic [Conn].
func (l *TCPListener) Accept() (Conn, error) {
if !l.ok() {
return nil, syscall.EINVAL
@@ -331,7 +343,7 @@ func (l *TCPListener) Close() error {
return nil
}
-// Addr returns the listener's network address, a *TCPAddr.
+// Addr returns the listener's network address, a [*TCPAddr].
// The Addr returned is shared by all invocations of Addr, so
// do not modify it.
func (l *TCPListener) Addr() Addr { return l.fd.laddr }
@@ -342,13 +354,10 @@ func (l *TCPListener) SetDeadline(t time.Time) error {
if !l.ok() {
return syscall.EINVAL
}
- if err := l.fd.pfd.SetDeadline(t); err != nil {
- return &OpError{Op: "set", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err}
- }
- return nil
+ return l.fd.SetDeadline(t)
}
-// File returns a copy of the underlying os.File.
+// File returns a copy of the underlying [os.File].
// It is the caller's responsibility to close f when finished.
// Closing l does not affect f, and closing f does not affect l.
//
@@ -366,7 +375,7 @@ func (l *TCPListener) File() (f *os.File, err error) {
return
}
-// ListenTCP acts like Listen for TCP networks.
+// ListenTCP acts like [Listen] for TCP networks.
//
// The network must be a TCP network name; see func Dial for details.
//
diff --git a/src/net/tcpsock_plan9.go b/src/net/tcpsock_plan9.go
index d55948f69e..463dedcf44 100644
--- a/src/net/tcpsock_plan9.go
+++ b/src/net/tcpsock_plan9.go
@@ -14,6 +14,10 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
return genericReadFrom(c, r)
}
+func (c *TCPConn) writeTo(w io.Writer) (int64, error) {
+ return genericWriteTo(c, w)
+}
+
func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
if h := sd.testHookDialTCP; h != nil {
return h(ctx, sd.network, laddr, raddr)
diff --git a/src/net/tcpsock_posix.go b/src/net/tcpsock_posix.go
index e6f425b1cd..01b5ec9ed0 100644
--- a/src/net/tcpsock_posix.go
+++ b/src/net/tcpsock_posix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build unix || (js && wasm) || wasip1 || windows
+//go:build unix || js || wasip1 || windows
package net
@@ -45,7 +45,7 @@ func (a *TCPAddr) toLocal(net string) sockaddr {
}
func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
- if n, err, handled := splice(c.fd, r); handled {
+ if n, err, handled := spliceFrom(c.fd, r); handled {
return n, err
}
if n, err, handled := sendFile(c.fd, r); handled {
@@ -54,6 +54,13 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
return genericReadFrom(c, r)
}
+func (c *TCPConn) writeTo(w io.Writer) (int64, error) {
+ if n, err, handled := spliceTo(w, c.fd); handled {
+ return n, err
+ }
+ return genericWriteTo(c, w)
+}
+
func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
if h := sd.testHookDialTCP; h != nil {
return h(ctx, sd.network, laddr, raddr)
diff --git a/src/net/tcpsock_test.go b/src/net/tcpsock_test.go
index f720a22519..b37e936ff8 100644
--- a/src/net/tcpsock_test.go
+++ b/src/net/tcpsock_test.go
@@ -2,11 +2,11 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !wasip1
-
package net
import (
+ "context"
+ "errors"
"fmt"
"internal/testenv"
"io"
@@ -670,6 +670,11 @@ func TestTCPBig(t *testing.T) {
}
func TestCopyPipeIntoTCP(t *testing.T) {
+ switch runtime.GOOS {
+ case "js", "wasip1":
+ t.Skipf("skipping: os.Pipe not supported on %s", runtime.GOOS)
+ }
+
ln := newLocalListener(t, "tcp")
defer ln.Close()
@@ -783,3 +788,48 @@ func TestDialTCPDefaultKeepAlive(t *testing.T) {
t.Errorf("got keepalive %v; want %v", got, defaultTCPKeepAlive)
}
}
+
+func TestTCPListenAfterClose(t *testing.T) {
+ // Regression test for https://go.dev/issue/50216:
+ // after calling Close on a Listener, the fake net implementation would
+ // erroneously Accept a connection dialed before the call to Close.
+
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+
+ var wg sync.WaitGroup
+ ctx, cancel := context.WithCancel(context.Background())
+
+ d := &Dialer{}
+ for n := 2; n > 0; n-- {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ c, err := d.DialContext(ctx, ln.Addr().Network(), ln.Addr().String())
+ if err == nil {
+ <-ctx.Done()
+ c.Close()
+ }
+ }()
+ }
+
+ c, err := ln.Accept()
+ if err == nil {
+ c.Close()
+ } else {
+ t.Error(err)
+ }
+ time.Sleep(10 * time.Millisecond)
+ cancel()
+ wg.Wait()
+ ln.Close()
+
+ c, err = ln.Accept()
+ if !errors.Is(err, ErrClosed) {
+ if err == nil {
+ c.Close()
+ }
+ t.Errorf("after l.Close(), l.Accept() = _, %v\nwant %v", err, ErrClosed)
+ }
+}
diff --git a/src/net/tcpsock_unix_test.go b/src/net/tcpsock_unix_test.go
index 35fd937e07..df810a21d8 100644
--- a/src/net/tcpsock_unix_test.go
+++ b/src/net/tcpsock_unix_test.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !plan9 && !wasip1 && !windows
+//go:build !plan9 && !windows
package net
diff --git a/src/net/tcpsockopt_stub.go b/src/net/tcpsockopt_stub.go
index f778143d3b..cef07cd648 100644
--- a/src/net/tcpsockopt_stub.go
+++ b/src/net/tcpsockopt_stub.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build (js && wasm) || wasip1
+//go:build js || wasip1
package net
diff --git a/src/net/textproto/header.go b/src/net/textproto/header.go
index a58df7aebc..689a6827b9 100644
--- a/src/net/textproto/header.go
+++ b/src/net/textproto/header.go
@@ -23,7 +23,7 @@ func (h MIMEHeader) Set(key, value string) {
}
// Get gets the first value associated with the given key.
-// It is case insensitive; CanonicalMIMEHeaderKey is used
+// It is case insensitive; [CanonicalMIMEHeaderKey] is used
// to canonicalize the provided key.
// If there are no values associated with the key, Get returns "".
// To use non-canonical keys, access the map directly.
@@ -39,7 +39,7 @@ func (h MIMEHeader) Get(key string) string {
}
// Values returns all values associated with the given key.
-// It is case insensitive; CanonicalMIMEHeaderKey is
+// It is case insensitive; [CanonicalMIMEHeaderKey] is
// used to canonicalize the provided key. To use non-canonical
// keys, access the map directly.
// The returned slice is not a copy.
diff --git a/src/net/textproto/reader.go b/src/net/textproto/reader.go
index fc2590b1cd..793021101b 100644
--- a/src/net/textproto/reader.go
+++ b/src/net/textproto/reader.go
@@ -16,6 +16,10 @@ import (
"sync"
)
+// TODO: This should be a distinguishable error (ErrMessageTooLarge)
+// to allow mime/multipart to detect it.
+var errMessageTooLarge = errors.New("message too large")
+
// A Reader implements convenience methods for reading requests
// or responses from a text protocol network connection.
type Reader struct {
@@ -24,10 +28,10 @@ type Reader struct {
buf []byte // a re-usable buffer for readContinuedLineSlice
}
-// NewReader returns a new Reader reading from r.
+// NewReader returns a new [Reader] reading from r.
//
-// To avoid denial of service attacks, the provided bufio.Reader
-// should be reading from an io.LimitReader or similar Reader to bound
+// To avoid denial of service attacks, the provided [bufio.Reader]
+// should be reading from an [io.LimitReader] or similar Reader to bound
// the size of responses.
func NewReader(r *bufio.Reader) *Reader {
return &Reader{R: r}
@@ -36,20 +40,23 @@ func NewReader(r *bufio.Reader) *Reader {
// ReadLine reads a single line from r,
// eliding the final \n or \r\n from the returned string.
func (r *Reader) ReadLine() (string, error) {
- line, err := r.readLineSlice()
+ line, err := r.readLineSlice(-1)
return string(line), err
}
-// ReadLineBytes is like ReadLine but returns a []byte instead of a string.
+// ReadLineBytes is like [Reader.ReadLine] but returns a []byte instead of a string.
func (r *Reader) ReadLineBytes() ([]byte, error) {
- line, err := r.readLineSlice()
+ line, err := r.readLineSlice(-1)
if line != nil {
line = bytes.Clone(line)
}
return line, err
}
-func (r *Reader) readLineSlice() ([]byte, error) {
+// readLineSlice reads a single line from r,
+// up to lim bytes long (or unlimited if lim is less than 0),
+// eliding the final \r or \r\n from the returned string.
+func (r *Reader) readLineSlice(lim int64) ([]byte, error) {
r.closeDot()
var line []byte
for {
@@ -57,6 +64,9 @@ func (r *Reader) readLineSlice() ([]byte, error) {
if err != nil {
return nil, err
}
+ if lim >= 0 && int64(len(line))+int64(len(l)) > lim {
+ return nil, errMessageTooLarge
+ }
// Avoid the copy if the first call produced a full line.
if line == nil && !more {
return l, nil
@@ -88,7 +98,7 @@ func (r *Reader) readLineSlice() ([]byte, error) {
//
// Empty lines are never continued.
func (r *Reader) ReadContinuedLine() (string, error) {
- line, err := r.readContinuedLineSlice(noValidation)
+ line, err := r.readContinuedLineSlice(-1, noValidation)
return string(line), err
}
@@ -106,10 +116,10 @@ func trim(s []byte) []byte {
return s[i:n]
}
-// ReadContinuedLineBytes is like ReadContinuedLine but
+// ReadContinuedLineBytes is like [Reader.ReadContinuedLine] but
// returns a []byte instead of a string.
func (r *Reader) ReadContinuedLineBytes() ([]byte, error) {
- line, err := r.readContinuedLineSlice(noValidation)
+ line, err := r.readContinuedLineSlice(-1, noValidation)
if line != nil {
line = bytes.Clone(line)
}
@@ -120,13 +130,14 @@ func (r *Reader) ReadContinuedLineBytes() ([]byte, error) {
// returning a byte slice with all lines. The validateFirstLine function
// is run on the first read line, and if it returns an error then this
// error is returned from readContinuedLineSlice.
-func (r *Reader) readContinuedLineSlice(validateFirstLine func([]byte) error) ([]byte, error) {
+// It reads up to lim bytes of data (or unlimited if lim is less than 0).
+func (r *Reader) readContinuedLineSlice(lim int64, validateFirstLine func([]byte) error) ([]byte, error) {
if validateFirstLine == nil {
return nil, fmt.Errorf("missing validateFirstLine func")
}
// Read the first line.
- line, err := r.readLineSlice()
+ line, err := r.readLineSlice(lim)
if err != nil {
return nil, err
}
@@ -154,13 +165,21 @@ func (r *Reader) readContinuedLineSlice(validateFirstLine func([]byte) error) ([
// copy the slice into buf.
r.buf = append(r.buf[:0], trim(line)...)
+ if lim < 0 {
+ lim = math.MaxInt64
+ }
+ lim -= int64(len(r.buf))
+
// Read continuation lines.
for r.skipSpace() > 0 {
- line, err := r.readLineSlice()
+ r.buf = append(r.buf, ' ')
+ if int64(len(r.buf)) >= lim {
+ return nil, errMessageTooLarge
+ }
+ line, err := r.readLineSlice(lim - int64(len(r.buf)))
if err != nil {
break
}
- r.buf = append(r.buf, ' ')
r.buf = append(r.buf, trim(line)...)
}
return r.buf, nil
@@ -289,7 +308,7 @@ func (r *Reader) ReadResponse(expectCode int) (code int, message string, err err
return
}
-// DotReader returns a new Reader that satisfies Reads using the
+// DotReader returns a new [Reader] that satisfies Reads using the
// decoded text of a dot-encoded block read from r.
// The returned Reader is only valid until the next call
// to a method on r.
@@ -303,7 +322,7 @@ func (r *Reader) ReadResponse(expectCode int) (code int, message string, err err
//
// The decoded form returned by the Reader's Read method
// rewrites the "\r\n" line endings into the simpler "\n",
-// removes leading dot escapes if present, and stops with error io.EOF
+// removes leading dot escapes if present, and stops with error [io.EOF]
// after consuming (and discarding) the end-of-sequence line.
func (r *Reader) DotReader() io.Reader {
r.closeDot()
@@ -420,7 +439,7 @@ func (r *Reader) closeDot() {
// ReadDotBytes reads a dot-encoding and returns the decoded data.
//
-// See the documentation for the DotReader method for details about dot-encoding.
+// See the documentation for the [Reader.DotReader] method for details about dot-encoding.
func (r *Reader) ReadDotBytes() ([]byte, error) {
return io.ReadAll(r.DotReader())
}
@@ -428,7 +447,7 @@ func (r *Reader) ReadDotBytes() ([]byte, error) {
// ReadDotLines reads a dot-encoding and returns a slice
// containing the decoded lines, with the final \r\n or \n elided from each.
//
-// See the documentation for the DotReader method for details about dot-encoding.
+// See the documentation for the [Reader.DotReader] method for details about dot-encoding.
func (r *Reader) ReadDotLines() ([]string, error) {
// We could use ReadDotBytes and then Split it,
// but reading a line at a time avoids needing a
@@ -462,7 +481,7 @@ var colon = []byte(":")
// ReadMIMEHeader reads a MIME-style header from r.
// The header is a sequence of possibly continued Key: Value lines
// ending in a blank line.
-// The returned map m maps CanonicalMIMEHeaderKey(key) to a
+// The returned map m maps [CanonicalMIMEHeaderKey](key) to a
// sequence of values in the same order encountered in the input.
//
// For example, consider this input:
@@ -507,7 +526,8 @@ func readMIMEHeader(r *Reader, maxMemory, maxHeaders int64) (MIMEHeader, error)
// The first line cannot start with a leading space.
if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') {
- line, err := r.readLineSlice()
+ const errorLimit = 80 // arbitrary limit on how much of the line we'll quote
+ line, err := r.readLineSlice(errorLimit)
if err != nil {
return m, err
}
@@ -515,7 +535,7 @@ func readMIMEHeader(r *Reader, maxMemory, maxHeaders int64) (MIMEHeader, error)
}
for {
- kv, err := r.readContinuedLineSlice(mustHaveFieldNameColon)
+ kv, err := r.readContinuedLineSlice(maxMemory, mustHaveFieldNameColon)
if len(kv) == 0 {
return m, err
}
@@ -544,7 +564,7 @@ func readMIMEHeader(r *Reader, maxMemory, maxHeaders int64) (MIMEHeader, error)
maxHeaders--
if maxHeaders < 0 {
- return nil, errors.New("message too large")
+ return nil, errMessageTooLarge
}
// Skip initial spaces in value.
@@ -557,9 +577,7 @@ func readMIMEHeader(r *Reader, maxMemory, maxHeaders int64) (MIMEHeader, error)
}
maxMemory -= int64(len(value))
if maxMemory < 0 {
- // TODO: This should be a distinguishable error (ErrMessageTooLarge)
- // to allow mime/multipart to detect it.
- return m, errors.New("message too large")
+ return m, errMessageTooLarge
}
if vv == nil && len(strs) > 0 {
// More than likely this will be a single-element key.
diff --git a/src/net/textproto/reader_test.go b/src/net/textproto/reader_test.go
index 696ae406f3..26ff617470 100644
--- a/src/net/textproto/reader_test.go
+++ b/src/net/textproto/reader_test.go
@@ -36,6 +36,18 @@ func TestReadLine(t *testing.T) {
}
}
+func TestReadLineLongLine(t *testing.T) {
+ line := strings.Repeat("12345", 10000)
+ r := reader(line + "\r\n")
+ s, err := r.ReadLine()
+ if err != nil {
+ t.Fatalf("Line 1: %v", err)
+ }
+ if s != line {
+ t.Fatalf("%v-byte line does not match expected %v-byte line", len(s), len(line))
+ }
+}
+
func TestReadContinuedLine(t *testing.T) {
r := reader("line1\nline\n 2\nline3\n")
s, err := r.ReadContinuedLine()
diff --git a/src/net/textproto/textproto.go b/src/net/textproto/textproto.go
index 70038d5888..4ae3ecff74 100644
--- a/src/net/textproto/textproto.go
+++ b/src/net/textproto/textproto.go
@@ -7,20 +7,20 @@
//
// The package provides:
//
-// Error, which represents a numeric error response from
+// [Error], which represents a numeric error response from
// a server.
//
-// Pipeline, to manage pipelined requests and responses
+// [Pipeline], to manage pipelined requests and responses
// in a client.
//
-// Reader, to read numeric response code lines,
+// [Reader], to read numeric response code lines,
// key: value headers, lines wrapped with leading spaces
// on continuation lines, and whole text blocks ending
// with a dot on a line by itself.
//
-// Writer, to write dot-encoded text blocks.
+// [Writer], to write dot-encoded text blocks.
//
-// Conn, a convenient packaging of Reader, Writer, and Pipeline for use
+// [Conn], a convenient packaging of [Reader], [Writer], and [Pipeline] for use
// with a single network connection.
package textproto
@@ -50,8 +50,8 @@ func (p ProtocolError) Error() string {
}
// A Conn represents a textual network protocol connection.
-// It consists of a Reader and Writer to manage I/O
-// and a Pipeline to sequence concurrent requests on the connection.
+// It consists of a [Reader] and [Writer] to manage I/O
+// and a [Pipeline] to sequence concurrent requests on the connection.
// These embedded types carry methods with them;
// see the documentation of those types for details.
type Conn struct {
@@ -61,7 +61,7 @@ type Conn struct {
conn io.ReadWriteCloser
}
-// NewConn returns a new Conn using conn for I/O.
+// NewConn returns a new [Conn] using conn for I/O.
func NewConn(conn io.ReadWriteCloser) *Conn {
return &Conn{
Reader: Reader{R: bufio.NewReader(conn)},
@@ -75,8 +75,8 @@ func (c *Conn) Close() error {
return c.conn.Close()
}
-// Dial connects to the given address on the given network using net.Dial
-// and then returns a new Conn for the connection.
+// Dial connects to the given address on the given network using [net.Dial]
+// and then returns a new [Conn] for the connection.
func Dial(network, addr string) (*Conn, error) {
c, err := net.Dial(network, addr)
if err != nil {
diff --git a/src/net/textproto/writer.go b/src/net/textproto/writer.go
index 2ece3f511b..662515fb2c 100644
--- a/src/net/textproto/writer.go
+++ b/src/net/textproto/writer.go
@@ -17,7 +17,7 @@ type Writer struct {
dot *dotWriter
}
-// NewWriter returns a new Writer writing to w.
+// NewWriter returns a new [Writer] writing to w.
func NewWriter(w *bufio.Writer) *Writer {
return &Writer{W: w}
}
@@ -39,7 +39,7 @@ func (w *Writer) PrintfLine(format string, args ...any) error {
// when the DotWriter is closed. The caller should close the
// DotWriter before the next call to a method on w.
//
-// See the documentation for Reader's DotReader method for details about dot-encoding.
+// See the documentation for the [Reader.DotReader] method for details about dot-encoding.
func (w *Writer) DotWriter() io.WriteCloser {
w.closeDot()
w.dot = &dotWriter{w: w}
diff --git a/src/net/timeout_test.go b/src/net/timeout_test.go
index c0bce57b94..ca86f31ef2 100644
--- a/src/net/timeout_test.go
+++ b/src/net/timeout_test.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !wasip1
-
package net
import (
@@ -11,7 +9,6 @@ import (
"fmt"
"internal/testenv"
"io"
- "net/internal/socktest"
"os"
"runtime"
"sync"
@@ -19,65 +16,121 @@ import (
"time"
)
-var dialTimeoutTests = []struct {
- timeout time.Duration
- delta time.Duration // for deadline
+func init() {
+ // Install a hook to ensure that a 1ns timeout will always
+ // be exceeded by the time Dial gets to the relevant system call.
+ //
+ // Without this, systems with a very large timer granularity — such as
+ // Windows — may be able to accept connections without measurably exceeding
+ // even an implausibly short deadline.
+ testHookStepTime = func() {
+ now := time.Now()
+ for time.Since(now) == 0 {
+ time.Sleep(1 * time.Nanosecond)
+ }
+ }
+}
- guard time.Duration
+var dialTimeoutTests = []struct {
+ initialTimeout time.Duration
+ initialDelta time.Duration // for deadline
}{
// Tests that dial timeouts, deadlines in the past work.
- {-5 * time.Second, 0, -5 * time.Second},
- {0, -5 * time.Second, -5 * time.Second},
- {-5 * time.Second, 5 * time.Second, -5 * time.Second}, // timeout over deadline
- {-1 << 63, 0, time.Second},
- {0, -1 << 63, time.Second},
-
- {50 * time.Millisecond, 0, 100 * time.Millisecond},
- {0, 50 * time.Millisecond, 100 * time.Millisecond},
- {50 * time.Millisecond, 5 * time.Second, 100 * time.Millisecond}, // timeout over deadline
+ {-5 * time.Second, 0},
+ {0, -5 * time.Second},
+ {-5 * time.Second, 5 * time.Second}, // timeout over deadline
+ {-1 << 63, 0},
+ {0, -1 << 63},
+
+ {1 * time.Millisecond, 0},
+ {0, 1 * time.Millisecond},
+ {1 * time.Millisecond, 5 * time.Second}, // timeout over deadline
}
func TestDialTimeout(t *testing.T) {
- // Cannot use t.Parallel - modifies global hooks.
- origTestHookDialChannel := testHookDialChannel
- defer func() { testHookDialChannel = origTestHookDialChannel }()
- defer sw.Set(socktest.FilterConnect, nil)
-
- for i, tt := range dialTimeoutTests {
- switch runtime.GOOS {
- case "plan9", "windows":
- testHookDialChannel = func() { time.Sleep(tt.guard) }
- if runtime.GOOS == "plan9" {
- break
- }
- fallthrough
- default:
- sw.Set(socktest.FilterConnect, func(so *socktest.Status) (socktest.AfterFilter, error) {
- time.Sleep(tt.guard)
- return nil, errTimedout
- })
- }
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
- d := Dialer{Timeout: tt.timeout}
- if tt.delta != 0 {
- d.Deadline = time.Now().Add(tt.delta)
- }
+ t.Parallel()
- // This dial never starts to send any TCP SYN
- // segment because of above socket filter and
- // test hook.
- c, err := d.Dial("tcp", "127.0.0.1:0")
- if err == nil {
- err = fmt.Errorf("unexpectedly established: tcp:%s->%s", c.LocalAddr(), c.RemoteAddr())
- c.Close()
+ ln := newLocalListener(t, "tcp")
+ defer func() {
+ if err := ln.Close(); err != nil {
+ t.Error(err)
}
+ }()
- if perr := parseDialError(err); perr != nil {
- t.Errorf("#%d: %v", i, perr)
- }
- if nerr, ok := err.(Error); !ok || !nerr.Timeout() {
- t.Fatalf("#%d: %v", i, err)
- }
+ for _, tt := range dialTimeoutTests {
+ t.Run(fmt.Sprintf("%v/%v", tt.initialTimeout, tt.initialDelta), func(t *testing.T) {
+ // We don't run these subtests in parallel because we don't know how big
+ // the kernel's accept queue is, and we don't want to accidentally saturate
+ // it with concurrent calls. (That could cause the Dial to fail with
+ // ECONNREFUSED or ECONNRESET instead of a timeout error.)
+ d := Dialer{Timeout: tt.initialTimeout}
+ delta := tt.initialDelta
+
+ var (
+ beforeDial time.Time
+ afterDial time.Time
+ err error
+ )
+ for {
+ if delta != 0 {
+ d.Deadline = time.Now().Add(delta)
+ }
+
+ beforeDial = time.Now()
+
+ var c Conn
+ c, err = d.Dial(ln.Addr().Network(), ln.Addr().String())
+ afterDial = time.Now()
+
+ if err != nil {
+ break
+ }
+
+ // Even though we're not calling Accept on the Listener, the kernel may
+ // spuriously accept connections on its behalf. If that happens, we will
+ // close the connection (to try to get it out of the kernel's accept
+ // queue) and try a shorter timeout.
+ //
+ // We assume that we will reach a point where the call actually does
+ // time out, although in theory (since this socket is on a loopback
+ // address) a sufficiently clever kernel could notice that no Accept
+ // call is pending and bypass both the queue and the timeout to return
+ // another error immediately.
+ t.Logf("closing spurious connection from Dial")
+ c.Close()
+
+ if delta <= 1 && d.Timeout <= 1 {
+ t.Fatalf("can't reduce Timeout or Deadline")
+ }
+ if delta > 1 {
+ delta /= 2
+ t.Logf("reducing Deadline delta to %v", delta)
+ }
+ if d.Timeout > 1 {
+ d.Timeout /= 2
+ t.Logf("reducing Timeout to %v", d.Timeout)
+ }
+ }
+
+ if d.Deadline.IsZero() || afterDial.Before(d.Deadline) {
+ delay := afterDial.Sub(beforeDial)
+ if delay < d.Timeout {
+ t.Errorf("Dial returned after %v; want ≥%v", delay, d.Timeout)
+ }
+ }
+
+ if perr := parseDialError(err); perr != nil {
+ t.Errorf("unexpected error from Dial: %v", perr)
+ }
+ if nerr, ok := err.(Error); !ok || !nerr.Timeout() {
+ t.Errorf("Dial: %v, want timeout", err)
+ }
+ })
}
}
@@ -189,35 +242,22 @@ func TestAcceptTimeoutMustReturn(t *testing.T) {
ln := newLocalListener(t, "tcp")
defer ln.Close()
- max := time.NewTimer(time.Second)
- defer max.Stop()
- ch := make(chan error)
- go func() {
- if err := ln.(*TCPListener).SetDeadline(noDeadline); err != nil {
- t.Error(err)
- }
- if err := ln.(*TCPListener).SetDeadline(time.Now().Add(10 * time.Millisecond)); err != nil {
- t.Error(err)
- }
- c, err := ln.Accept()
- if err == nil {
- c.Close()
- }
- ch <- err
- }()
+ if err := ln.(*TCPListener).SetDeadline(noDeadline); err != nil {
+ t.Error(err)
+ }
+ if err := ln.(*TCPListener).SetDeadline(time.Now().Add(10 * time.Millisecond)); err != nil {
+ t.Error(err)
+ }
+ c, err := ln.Accept()
+ if err == nil {
+ c.Close()
+ }
- select {
- case <-max.C:
- ln.Close()
- <-ch // wait for tester goroutine to stop
- t.Fatal("Accept didn't return in an expected time")
- case err := <-ch:
- if perr := parseAcceptError(err); perr != nil {
- t.Error(perr)
- }
- if !isDeadlineExceeded(err) {
- t.Fatal(err)
- }
+ if perr := parseAcceptError(err); perr != nil {
+ t.Error(perr)
+ }
+ if !isDeadlineExceeded(err) {
+ t.Fatal(err)
}
}
@@ -529,7 +569,7 @@ func TestWriteTimeoutMustNotReturn(t *testing.T) {
t.Error(err)
}
maxch <- time.NewTimer(100 * time.Millisecond)
- var b [1]byte
+ var b [1024]byte
for {
if _, err := c.Write(b[:]); err != nil {
ch <- err
diff --git a/src/net/udpsock.go b/src/net/udpsock.go
index e30624dea5..4f8acb7fc8 100644
--- a/src/net/udpsock.go
+++ b/src/net/udpsock.go
@@ -129,7 +129,7 @@ func (c *UDPConn) SyscallConn() (syscall.RawConn, error) {
if !c.ok() {
return nil, syscall.EINVAL
}
- return newRawConn(c.fd)
+ return newRawConn(c.fd), nil
}
// ReadFromUDP acts like ReadFrom but returns a UDPAddr.
diff --git a/src/net/udpsock_posix.go b/src/net/udpsock_posix.go
index f3dbcfec00..5035059831 100644
--- a/src/net/udpsock_posix.go
+++ b/src/net/udpsock_posix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build unix || (js && wasm) || wasip1 || windows
+//go:build unix || js || wasip1 || windows
package net
diff --git a/src/net/udpsock_test.go b/src/net/udpsock_test.go
index 2afd4ac2ae..8a21aa7370 100644
--- a/src/net/udpsock_test.go
+++ b/src/net/udpsock_test.go
@@ -2,12 +2,11 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !wasip1
-
package net
import (
"errors"
+ "fmt"
"internal/testenv"
"net/netip"
"os"
@@ -116,6 +115,10 @@ func TestWriteToUDP(t *testing.T) {
t.Skipf("not supported on %s", runtime.GOOS)
}
+ if !testableNetwork("udp") {
+ t.Skipf("skipping: udp not supported")
+ }
+
c, err := ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
@@ -221,19 +224,29 @@ func TestUDPConnLocalName(t *testing.T) {
testenv.MustHaveExternalNetwork(t)
for _, tt := range udpConnLocalNameTests {
- c, err := ListenUDP(tt.net, tt.laddr)
- if err != nil {
- t.Fatal(err)
- }
- defer c.Close()
- la := c.LocalAddr()
- if a, ok := la.(*UDPAddr); !ok || a.Port == 0 {
- t.Fatalf("got %v; expected a proper address with non-zero port number", la)
- }
+ t.Run(fmt.Sprint(tt.laddr), func(t *testing.T) {
+ if !testableNetwork(tt.net) {
+ t.Skipf("skipping: %s not available", tt.net)
+ }
+
+ c, err := ListenUDP(tt.net, tt.laddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ la := c.LocalAddr()
+ if a, ok := la.(*UDPAddr); !ok || a.Port == 0 {
+ t.Fatalf("got %v; expected a proper address with non-zero port number", la)
+ }
+ })
}
}
func TestUDPConnLocalAndRemoteNames(t *testing.T) {
+ if !testableNetwork("udp") {
+ t.Skipf("skipping: udp not available")
+ }
+
for _, laddr := range []string{"", "127.0.0.1:0"} {
c1, err := ListenPacket("udp", "127.0.0.1:0")
if err != nil {
@@ -330,6 +343,9 @@ func TestUDPZeroBytePayload(t *testing.T) {
case "darwin", "ios":
testenv.SkipFlaky(t, 29225)
}
+ if !testableNetwork("udp") {
+ t.Skipf("skipping: udp not available")
+ }
c := newLocalPacketListener(t, "udp")
defer c.Close()
@@ -363,6 +379,9 @@ func TestUDPZeroByteBuffer(t *testing.T) {
case "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
}
+ if !testableNetwork("udp") {
+ t.Skipf("skipping: udp not available")
+ }
c := newLocalPacketListener(t, "udp")
defer c.Close()
@@ -397,6 +416,9 @@ func TestUDPReadSizeError(t *testing.T) {
case "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
}
+ if !testableNetwork("udp") {
+ t.Skipf("skipping: udp not available")
+ }
c1 := newLocalPacketListener(t, "udp")
defer c1.Close()
@@ -434,6 +456,10 @@ func TestUDPReadSizeError(t *testing.T) {
// TestUDPReadTimeout verifies that ReadFromUDP with timeout returns an error
// without data or an address.
func TestUDPReadTimeout(t *testing.T) {
+ if !testableNetwork("udp4") {
+ t.Skipf("skipping: udp4 not available")
+ }
+
la, err := ResolveUDPAddr("udp4", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
@@ -460,10 +486,14 @@ func TestUDPReadTimeout(t *testing.T) {
func TestAllocs(t *testing.T) {
switch runtime.GOOS {
- case "plan9":
- // Plan9 wasn't optimized.
+ case "plan9", "js", "wasip1":
+ // These implementations have not been optimized.
t.Skipf("skipping on %v", runtime.GOOS)
}
+ if !testableNetwork("udp4") {
+ t.Skipf("skipping: udp4 not available")
+ }
+
// Optimizations are required to remove the allocs.
testenv.SkipIfOptimizationOff(t)
@@ -590,6 +620,10 @@ func TestUDPIPVersionReadMsg(t *testing.T) {
case "plan9":
t.Skipf("skipping on %v", runtime.GOOS)
}
+ if !testableNetwork("udp4") {
+ t.Skipf("skipping: udp4 not available")
+ }
+
conn, err := ListenUDP("udp4", &UDPAddr{IP: IPv4(127, 0, 0, 1)})
if err != nil {
t.Fatal(err)
@@ -625,8 +659,11 @@ func TestUDPIPVersionReadMsg(t *testing.T) {
// WriteMsgUDPAddrPort accepts IPv4, IPv4-mapped IPv6, and IPv6 target addresses
// on a UDPConn listening on "::".
func TestIPv6WriteMsgUDPAddrPortTargetAddrIPVersion(t *testing.T) {
- if !supportsIPv6() {
- t.Skip("IPv6 is not supported")
+ if !testableNetwork("udp4") {
+ t.Skipf("skipping: udp4 not available")
+ }
+ if !testableNetwork("udp6") {
+ t.Skipf("skipping: udp6 not available")
}
switch runtime.GOOS {
diff --git a/src/net/unixsock.go b/src/net/unixsock.go
index 14fbac0932..821be7bf74 100644
--- a/src/net/unixsock.go
+++ b/src/net/unixsock.go
@@ -52,7 +52,7 @@ func (a *UnixAddr) opAddr() Addr {
//
// The network must be a Unix network name.
//
-// See func Dial for a description of the network and address
+// See func [Dial] for a description of the network and address
// parameters.
func ResolveUnixAddr(network, address string) (*UnixAddr, error) {
switch network {
@@ -63,19 +63,19 @@ func ResolveUnixAddr(network, address string) (*UnixAddr, error) {
}
}
-// UnixConn is an implementation of the Conn interface for connections
+// UnixConn is an implementation of the [Conn] interface for connections
// to Unix domain sockets.
type UnixConn struct {
conn
}
// SyscallConn returns a raw network connection.
-// This implements the syscall.Conn interface.
+// This implements the [syscall.Conn] interface.
func (c *UnixConn) SyscallConn() (syscall.RawConn, error) {
if !c.ok() {
return nil, syscall.EINVAL
}
- return newRawConn(c.fd)
+ return newRawConn(c.fd), nil
}
// CloseRead shuts down the reading side of the Unix domain connection.
@@ -102,7 +102,7 @@ func (c *UnixConn) CloseWrite() error {
return nil
}
-// ReadFromUnix acts like ReadFrom but returns a UnixAddr.
+// ReadFromUnix acts like [UnixConn.ReadFrom] but returns a [UnixAddr].
func (c *UnixConn) ReadFromUnix(b []byte) (int, *UnixAddr, error) {
if !c.ok() {
return 0, nil, syscall.EINVAL
@@ -114,7 +114,7 @@ func (c *UnixConn) ReadFromUnix(b []byte) (int, *UnixAddr, error) {
return n, addr, err
}
-// ReadFrom implements the PacketConn ReadFrom method.
+// ReadFrom implements the [PacketConn] ReadFrom method.
func (c *UnixConn) ReadFrom(b []byte) (int, Addr, error) {
if !c.ok() {
return 0, nil, syscall.EINVAL
@@ -147,7 +147,7 @@ func (c *UnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *UnixAdd
return
}
-// WriteToUnix acts like WriteTo but takes a UnixAddr.
+// WriteToUnix acts like [UnixConn.WriteTo] but takes a [UnixAddr].
func (c *UnixConn) WriteToUnix(b []byte, addr *UnixAddr) (int, error) {
if !c.ok() {
return 0, syscall.EINVAL
@@ -159,7 +159,7 @@ func (c *UnixConn) WriteToUnix(b []byte, addr *UnixAddr) (int, error) {
return n, err
}
-// WriteTo implements the PacketConn WriteTo method.
+// WriteTo implements the [PacketConn] WriteTo method.
func (c *UnixConn) WriteTo(b []byte, addr Addr) (int, error) {
if !c.ok() {
return 0, syscall.EINVAL
@@ -194,7 +194,7 @@ func (c *UnixConn) WriteMsgUnix(b, oob []byte, addr *UnixAddr) (n, oobn int, err
func newUnixConn(fd *netFD) *UnixConn { return &UnixConn{conn{fd}} }
-// DialUnix acts like Dial for Unix networks.
+// DialUnix acts like [Dial] for Unix networks.
//
// The network must be a Unix network name; see func Dial for details.
//
@@ -215,7 +215,7 @@ func DialUnix(network string, laddr, raddr *UnixAddr) (*UnixConn, error) {
}
// UnixListener is a Unix domain socket listener. Clients should
-// typically use variables of type Listener instead of assuming Unix
+// typically use variables of type [Listener] instead of assuming Unix
// domain sockets.
type UnixListener struct {
fd *netFD
@@ -227,7 +227,7 @@ type UnixListener struct {
func (ln *UnixListener) ok() bool { return ln != nil && ln.fd != nil }
// SyscallConn returns a raw network connection.
-// This implements the syscall.Conn interface.
+// This implements the [syscall.Conn] interface.
//
// The returned RawConn only supports calling Control. Read and
// Write return an error.
@@ -235,7 +235,7 @@ func (l *UnixListener) SyscallConn() (syscall.RawConn, error) {
if !l.ok() {
return nil, syscall.EINVAL
}
- return newRawListener(l.fd)
+ return newRawListener(l.fd), nil
}
// AcceptUnix accepts the next incoming call and returns the new
@@ -251,8 +251,8 @@ func (l *UnixListener) AcceptUnix() (*UnixConn, error) {
return c, nil
}
-// Accept implements the Accept method in the Listener interface.
-// Returned connections will be of type *UnixConn.
+// Accept implements the Accept method in the [Listener] interface.
+// Returned connections will be of type [*UnixConn].
func (l *UnixListener) Accept() (Conn, error) {
if !l.ok() {
return nil, syscall.EINVAL
@@ -287,13 +287,10 @@ func (l *UnixListener) SetDeadline(t time.Time) error {
if !l.ok() {
return syscall.EINVAL
}
- if err := l.fd.pfd.SetDeadline(t); err != nil {
- return &OpError{Op: "set", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err}
- }
- return nil
+ return l.fd.SetDeadline(t)
}
-// File returns a copy of the underlying os.File.
+// File returns a copy of the underlying [os.File].
// It is the caller's responsibility to close f when finished.
// Closing l does not affect f, and closing f does not affect l.
//
@@ -311,7 +308,7 @@ func (l *UnixListener) File() (f *os.File, err error) {
return
}
-// ListenUnix acts like Listen for Unix networks.
+// ListenUnix acts like [Listen] for Unix networks.
//
// The network must be "unix" or "unixpacket".
func ListenUnix(network string, laddr *UnixAddr) (*UnixListener, error) {
@@ -331,7 +328,7 @@ func ListenUnix(network string, laddr *UnixAddr) (*UnixListener, error) {
return ln, nil
}
-// ListenUnixgram acts like ListenPacket for Unix networks.
+// ListenUnixgram acts like [ListenPacket] for Unix networks.
//
// The network must be "unixgram".
func ListenUnixgram(network string, laddr *UnixAddr) (*UnixConn, error) {
diff --git a/src/net/unixsock_posix.go b/src/net/unixsock_posix.go
index c501b499ed..f6c8e8f0b0 100644
--- a/src/net/unixsock_posix.go
+++ b/src/net/unixsock_posix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build unix || (js && wasm) || wasip1 || windows
+//go:build unix || js || wasip1 || windows
package net
diff --git a/src/net/unixsock_readmsg_other.go b/src/net/unixsock_readmsg_other.go
index 0899a6d3d3..4bef3ee71d 100644
--- a/src/net/unixsock_readmsg_other.go
+++ b/src/net/unixsock_readmsg_other.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build (js && wasm) || wasip1 || windows
+//go:build js || wasip1 || windows
package net
diff --git a/src/net/unixsock_test.go b/src/net/unixsock_test.go
index 8402519a0d..6906ecc046 100644
--- a/src/net/unixsock_test.go
+++ b/src/net/unixsock_test.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !plan9 && !wasip1 && !windows
+//go:build !plan9 && !windows
package net
@@ -21,6 +21,10 @@ func TestReadUnixgramWithUnnamedSocket(t *testing.T) {
if !testableNetwork("unixgram") {
t.Skip("unixgram test")
}
+ switch runtime.GOOS {
+ case "js", "wasip1":
+ t.Skipf("skipping: syscall.Socket not implemented on %s", runtime.GOOS)
+ }
if runtime.GOOS == "openbsd" {
testenv.SkipFlaky(t, 15157)
}
@@ -359,6 +363,11 @@ func TestUnixUnlink(t *testing.T) {
if !testableNetwork("unix") {
t.Skip("unix test")
}
+ switch runtime.GOOS {
+ case "js", "wasip1":
+ t.Skipf("skipping: %s does not support Unlink", runtime.GOOS)
+ }
+
name := testUnixAddr(t)
listen := func(t *testing.T) *UnixListener {
diff --git a/src/net/url/url.go b/src/net/url/url.go
index 501b263e87..f362958edd 100644
--- a/src/net/url/url.go
+++ b/src/net/url/url.go
@@ -175,7 +175,7 @@ func shouldEscape(c byte, mode encoding) bool {
return true
}
-// QueryUnescape does the inverse transformation of QueryEscape,
+// QueryUnescape does the inverse transformation of [QueryEscape],
// converting each 3-byte encoded substring of the form "%AB" into the
// hex-decoded byte 0xAB.
// It returns an error if any % is not followed by two hexadecimal
@@ -184,12 +184,12 @@ func QueryUnescape(s string) (string, error) {
return unescape(s, encodeQueryComponent)
}
-// PathUnescape does the inverse transformation of PathEscape,
+// PathUnescape does the inverse transformation of [PathEscape],
// converting each 3-byte encoded substring of the form "%AB" into the
// hex-decoded byte 0xAB. It returns an error if any % is not followed
// by two hexadecimal digits.
//
-// PathUnescape is identical to QueryUnescape except that it does not
+// PathUnescape is identical to [QueryUnescape] except that it does not
// unescape '+' to ' ' (space).
func PathUnescape(s string) (string, error) {
return unescape(s, encodePathSegment)
@@ -271,12 +271,12 @@ func unescape(s string, mode encoding) (string, error) {
}
// QueryEscape escapes the string so it can be safely placed
-// inside a URL query.
+// inside a [URL] query.
func QueryEscape(s string) string {
return escape(s, encodeQueryComponent)
}
-// PathEscape escapes the string so it can be safely placed inside a URL path segment,
+// PathEscape escapes the string so it can be safely placed inside a [URL] path segment,
// replacing special characters (including /) with %XX sequences as needed.
func PathEscape(s string) string {
return escape(s, encodePathSegment)
@@ -348,10 +348,17 @@ func escape(s string, mode encoding) string {
//
// scheme:opaque[?query][#fragment]
//
+// The Host field contains the host and port subcomponents of the URL.
+// When the port is present, it is separated from the host with a colon.
+// When the host is an IPv6 address, it must be enclosed in square brackets:
+// "[fe80::1]:80". The [net.JoinHostPort] function combines a host and port
+// into a string suitable for the Host field, adding square brackets to
+// the host when necessary.
+//
// Note that the Path field is stored in decoded form: /%47%6f%2f becomes /Go/.
// A consequence is that it is impossible to tell which slashes in the Path were
// slashes in the raw URL and which were %2f. This distinction is rarely important,
-// but when it is, the code should use the EscapedPath method, which preserves
+// but when it is, the code should use the [URL.EscapedPath] method, which preserves
// the original encoding of Path.
//
// The RawPath field is an optional field which is only set when the default
@@ -363,7 +370,7 @@ type URL struct {
Scheme string
Opaque string // encoded opaque data
User *Userinfo // username and password information
- Host string // host or host:port
+ Host string // host or host:port (see Hostname and Port methods)
Path string // path (relative paths may omit leading slash)
RawPath string // encoded path hint (see EscapedPath method)
OmitHost bool // do not emit empty host (authority)
@@ -373,13 +380,13 @@ type URL struct {
RawFragment string // encoded fragment hint (see EscapedFragment method)
}
-// User returns a Userinfo containing the provided username
+// User returns a [Userinfo] containing the provided username
// and no password set.
func User(username string) *Userinfo {
return &Userinfo{username, "", false}
}
-// UserPassword returns a Userinfo containing the provided username
+// UserPassword returns a [Userinfo] containing the provided username
// and password.
//
// This functionality should only be used with legacy web sites.
@@ -392,7 +399,7 @@ func UserPassword(username, password string) *Userinfo {
}
// The Userinfo type is an immutable encapsulation of username and
-// password details for a URL. An existing Userinfo value is guaranteed
+// password details for a [URL]. An existing Userinfo value is guaranteed
// to have a username set (potentially empty, as allowed by RFC 2396),
// and optionally a password.
type Userinfo struct {
@@ -457,7 +464,7 @@ func getScheme(rawURL string) (scheme, path string, err error) {
return "", rawURL, nil
}
-// Parse parses a raw url into a URL structure.
+// Parse parses a raw url into a [URL] structure.
//
// The url may be relative (a path, without a host) or absolute
// (starting with a scheme). Trying to parse a hostname and path
@@ -479,7 +486,7 @@ func Parse(rawURL string) (*URL, error) {
return url, nil
}
-// ParseRequestURI parses a raw url into a URL structure. It assumes that
+// ParseRequestURI parses a raw url into a [URL] structure. It assumes that
// url was received in an HTTP request, so the url is interpreted
// only as an absolute URI or an absolute path.
// The string url is assumed not to have a #fragment suffix.
@@ -690,7 +697,7 @@ func (u *URL) setPath(p string) error {
// EscapedPath returns u.RawPath when it is a valid escaping of u.Path.
// Otherwise EscapedPath ignores u.RawPath and computes an escaped
// form on its own.
-// The String and RequestURI methods use EscapedPath to construct
+// The [URL.String] and [URL.RequestURI] methods use EscapedPath to construct
// their results.
// In general, code should call EscapedPath instead of
// reading u.RawPath directly.
@@ -754,7 +761,7 @@ func (u *URL) setFragment(f string) error {
// EscapedFragment returns u.RawFragment when it is a valid escaping of u.Fragment.
// Otherwise EscapedFragment ignores u.RawFragment and computes an escaped
// form on its own.
-// The String method uses EscapedFragment to construct its result.
+// The [URL.String] method uses EscapedFragment to construct its result.
// In general, code should call EscapedFragment instead of
// reading u.RawFragment directly.
func (u *URL) EscapedFragment() string {
@@ -784,7 +791,7 @@ func validOptionalPort(port string) bool {
return true
}
-// String reassembles the URL into a valid URL string.
+// String reassembles the [URL] into a valid URL string.
// The general form of the result is one of:
//
// scheme:opaque?query#fragment
@@ -858,7 +865,7 @@ func (u *URL) String() string {
return buf.String()
}
-// Redacted is like String but replaces any password with "xxxxx".
+// Redacted is like [URL.String] but replaces any password with "xxxxx".
// Only the password in u.User is redacted.
func (u *URL) Redacted() string {
if u == nil {
@@ -963,7 +970,7 @@ func parseQuery(m Values, query string) (err error) {
// Encode encodes the values into “URL encoded” form
// ("bar=baz&foo=quux") sorted by key.
func (v Values) Encode() string {
- if v == nil {
+ if len(v) == 0 {
return ""
}
var buf strings.Builder
@@ -1053,15 +1060,15 @@ func resolvePath(base, ref string) string {
return r
}
-// IsAbs reports whether the URL is absolute.
+// IsAbs reports whether the [URL] is absolute.
// Absolute means that it has a non-empty scheme.
func (u *URL) IsAbs() bool {
return u.Scheme != ""
}
-// Parse parses a URL in the context of the receiver. The provided URL
+// Parse parses a [URL] in the context of the receiver. The provided URL
// may be relative or absolute. Parse returns nil, err on parse
-// failure, otherwise its return value is the same as ResolveReference.
+// failure, otherwise its return value is the same as [URL.ResolveReference].
func (u *URL) Parse(ref string) (*URL, error) {
refURL, err := Parse(ref)
if err != nil {
@@ -1073,7 +1080,7 @@ func (u *URL) Parse(ref string) (*URL, error) {
// ResolveReference resolves a URI reference to an absolute URI from
// an absolute base URI u, per RFC 3986 Section 5.2. The URI reference
// may be relative or absolute. ResolveReference always returns a new
-// URL instance, even if the returned URL is identical to either the
+// [URL] instance, even if the returned URL is identical to either the
// base or reference. If ref is an absolute URL, then ResolveReference
// ignores base and returns a copy of ref.
func (u *URL) ResolveReference(ref *URL) *URL {
@@ -1110,7 +1117,7 @@ func (u *URL) ResolveReference(ref *URL) *URL {
// Query parses RawQuery and returns the corresponding values.
// It silently discards malformed value pairs.
-// To check errors use ParseQuery.
+// To check errors use [ParseQuery].
func (u *URL) Query() Values {
v, _ := ParseQuery(u.RawQuery)
return v
@@ -1187,7 +1194,7 @@ func (u *URL) UnmarshalBinary(text []byte) error {
return nil
}
-// JoinPath returns a new URL with the provided path elements joined to
+// JoinPath returns a new [URL] with the provided path elements joined to
// any existing path and the resulting path cleaned of any ./ or ../ elements.
// Any sequences of multiple / characters will be reduced to a single /.
func (u *URL) JoinPath(elem ...string) *URL {
@@ -1253,7 +1260,7 @@ func stringContainsCTLByte(s string) bool {
return false
}
-// JoinPath returns a URL string with the provided path elements joined to
+// JoinPath returns a [URL] string with the provided path elements joined to
// the existing path of base and the resulting path cleaned of any ./ or ../ elements.
func JoinPath(base string, elem ...string) (result string, err error) {
url, err := Parse(base)
diff --git a/src/net/url/url_test.go b/src/net/url/url_test.go
index 23c5c581c5..4aa20bb95f 100644
--- a/src/net/url/url_test.go
+++ b/src/net/url/url_test.go
@@ -1072,6 +1072,7 @@ type EncodeQueryTest struct {
var encodeQueryTests = []EncodeQueryTest{
{nil, ""},
+ {Values{}, ""},
{Values{"q": {"puppies"}, "oe": {"utf8"}}, "oe=utf8&q=puppies"},
{Values{"q": {"dogs", "&", "7"}}, "q=dogs&q=%26&q=7"},
{Values{
diff --git a/src/net/writev_test.go b/src/net/writev_test.go
index 8722c0f920..e4e88c4fac 100644
--- a/src/net/writev_test.go
+++ b/src/net/writev_test.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !wasip1
-
package net
import (
@@ -187,9 +185,15 @@ func TestWritevError(t *testing.T) {
}
ln := newLocalListener(t, "tcp")
- defer ln.Close()
ch := make(chan Conn, 1)
+ defer func() {
+ ln.Close()
+ for c := range ch {
+ c.Close()
+ }
+ }()
+
go func() {
defer close(ch)
c, err := ln.Accept()