diff options
Diffstat (limited to 'src/net')
194 files changed, 8413 insertions, 3354 deletions
diff --git a/src/net/cgo_stub.go b/src/net/cgo_stub.go index b26b11af8b..a4f6b4b0e8 100644 --- a/src/net/cgo_stub.go +++ b/src/net/cgo_stub.go @@ -9,7 +9,7 @@ // (Darwin always provides the cgo functions, in cgo_unix_syscall.go) // - on wasip1, where cgo is never available -//go:build (netgo && unix) || (unix && !cgo && !darwin) || wasip1 +//go:build (netgo && unix) || (unix && !cgo && !darwin) || js || wasip1 package net diff --git a/src/net/cgo_unix.go b/src/net/cgo_unix.go index f10f3ea60b..7ed5daad73 100644 --- a/src/net/cgo_unix.go +++ b/src/net/cgo_unix.go @@ -80,7 +80,7 @@ func cgoLookupHost(ctx context.Context, name string) (hosts []string, err error) func cgoLookupPort(ctx context.Context, network, service string) (port int, err error) { var hints _C_struct_addrinfo switch network { - case "": // no hints + case "ip": // no hints case "tcp", "tcp4", "tcp6": *_C_ai_socktype(&hints) = _C_SOCK_STREAM *_C_ai_protocol(&hints) = _C_IPPROTO_TCP @@ -120,6 +120,8 @@ func cgoLookupServicePort(hints *_C_struct_addrinfo, network, service string) (p if err == nil { // see golang.org/issue/6232 err = syscall.EMFILE } + case _C_EAI_SERVICE, _C_EAI_NONAME: // Darwin returns EAI_NONAME. + return 0, &DNSError{Err: "unknown port", Name: network + "/" + service, IsNotFound: true} default: err = addrinfoErrno(gerrno) isTemporary = addrinfoErrno(gerrno).Temporary() @@ -140,7 +142,7 @@ func cgoLookupServicePort(hints *_C_struct_addrinfo, network, service string) (p return int(p[0])<<8 | int(p[1]), nil } } - return 0, &DNSError{Err: "unknown port", Name: network + "/" + service} + return 0, &DNSError{Err: "unknown port", Name: network + "/" + service, IsNotFound: true} } func cgoLookupHostIP(network, name string) (addrs []IPAddr, err error) { @@ -317,8 +319,15 @@ func cgoResSearch(hostname string, rtype, class int) ([]dnsmessage.Resource, err acquireThread() defer releaseThread() - state := (*_C_struct___res_state)(_C_malloc(unsafe.Sizeof(_C_struct___res_state{}))) - defer _C_free(unsafe.Pointer(state)) + resStateSize := unsafe.Sizeof(_C_struct___res_state{}) + var state *_C_struct___res_state + if resStateSize > 0 { + mem := _C_malloc(resStateSize) + defer _C_free(mem) + memSlice := unsafe.Slice((*byte)(mem), resStateSize) + clear(memSlice) + state = (*_C_struct___res_state)(unsafe.Pointer(&memSlice[0])) + } if err := _C_res_ninit(state); err != nil { return nil, errors.New("res_ninit failure: " + err.Error()) } diff --git a/src/net/cgo_unix_cgo.go b/src/net/cgo_unix_cgo.go index d11f3e301a..7c609eddbf 100644 --- a/src/net/cgo_unix_cgo.go +++ b/src/net/cgo_unix_cgo.go @@ -37,6 +37,7 @@ const ( _C_EAI_AGAIN = C.EAI_AGAIN _C_EAI_NODATA = C.EAI_NODATA _C_EAI_NONAME = C.EAI_NONAME + _C_EAI_SERVICE = C.EAI_SERVICE _C_EAI_OVERFLOW = C.EAI_OVERFLOW _C_EAI_SYSTEM = C.EAI_SYSTEM _C_IPPROTO_TCP = C.IPPROTO_TCP @@ -55,7 +56,6 @@ type ( _C_struct_sockaddr = C.struct_sockaddr ) -func _C_GoString(p *_C_char) string { return C.GoString(p) } func _C_malloc(n uintptr) unsafe.Pointer { return C.malloc(C.size_t(n)) } func _C_free(p unsafe.Pointer) { C.free(p) } diff --git a/src/net/cgo_unix_syscall.go b/src/net/cgo_unix_syscall.go index 2eb8df1da6..ac9aaa78fe 100644 --- a/src/net/cgo_unix_syscall.go +++ b/src/net/cgo_unix_syscall.go @@ -19,6 +19,7 @@ const ( _C_AF_UNSPEC = syscall.AF_UNSPEC _C_EAI_AGAIN = unix.EAI_AGAIN _C_EAI_NONAME = unix.EAI_NONAME + _C_EAI_SERVICE = unix.EAI_SERVICE _C_EAI_NODATA = unix.EAI_NODATA _C_EAI_OVERFLOW = unix.EAI_OVERFLOW _C_EAI_SYSTEM = unix.EAI_SYSTEM @@ -39,10 +40,6 @@ type ( _C_struct_sockaddr = syscall.RawSockaddr ) -func _C_GoString(p *_C_char) string { - return unix.GoString(p) -} - func _C_free(p unsafe.Pointer) { runtime.KeepAlive(p) } func _C_malloc(n uintptr) unsafe.Pointer { diff --git a/src/net/conf.go b/src/net/conf.go index 77cc635592..15d73cf6ce 100644 --- a/src/net/conf.go +++ b/src/net/conf.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js - package net import ( @@ -153,7 +151,7 @@ func initConfVal() { } } -// goosPreferCgo reports whether the GOOS value passed in prefers +// goosPrefersCgo reports whether the GOOS value passed in prefers // the cgo resolver. func goosPrefersCgo() bool { switch runtime.GOOS { @@ -185,7 +183,24 @@ func goosPrefersCgo() bool { // required to use the go resolver. The provided Resolver is optional. // This will report true if the cgo resolver is not available. func (c *conf) mustUseGoResolver(r *Resolver) bool { - return c.netGo || r.preferGo() || !cgoAvailable + if !cgoAvailable { + return true + } + + if runtime.GOOS == "plan9" { + // TODO(bradfitz): for now we only permit use of the PreferGo + // implementation when there's a non-nil Resolver with a + // non-nil Dialer. This is a sign that they the code is trying + // to use their DNS-speaking net.Conn (such as an in-memory + // DNS cache) and they don't want to actually hit the network. + // Once we add support for looking the default DNS servers + // from plan9, though, then we can relax this. + if r == nil || r.Dial == nil { + return false + } + } + + return c.netGo || r.preferGo() } // addrLookupOrder determines which strategy to use to resolve addresses. @@ -221,16 +236,7 @@ func (c *conf) lookupOrder(r *Resolver, hostname string) (ret hostLookupOrder, d // Go resolver was explicitly requested // or cgo resolver is not available. // Figure out the order below. - switch c.goos { - case "windows": - // TODO(bradfitz): implement files-based - // lookup on Windows too? I guess /etc/hosts - // kinda exists on Windows. But for now, only - // do DNS. - fallbackOrder = hostLookupDNS - default: - fallbackOrder = hostLookupFilesDNS - } + fallbackOrder = hostLookupFilesDNS canUseCgo = false } else if c.netCgo { // Cgo resolver was explicitly requested. @@ -516,7 +522,7 @@ func isGateway(h string) bool { return stringsEqualFold(h, "_gateway") } -// isOutbound reports whether h should be considered a "outbound" +// isOutbound reports whether h should be considered an "outbound" // name for the myhostname NSS module. func isOutbound(h string) bool { return stringsEqualFold(h, "_outbound") diff --git a/src/net/conn_test.go b/src/net/conn_test.go index 4f391b0675..d1e1e7bf1c 100644 --- a/src/net/conn_test.go +++ b/src/net/conn_test.go @@ -2,10 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This file implements API tests across platforms and will never have a build -// tag. - -//go:build !js && !wasip1 +// This file implements API tests across platforms and should never have a build +// constraint. package net @@ -21,44 +19,46 @@ const someTimeout = 1 * time.Hour func TestConnAndListener(t *testing.T) { for i, network := range []string{"tcp", "unix", "unixpacket"} { - if !testableNetwork(network) { - t.Logf("skipping %s test", network) - continue - } + i, network := i, network + t.Run(network, func(t *testing.T) { + if !testableNetwork(network) { + t.Skipf("skipping %s test", network) + } - ls := newLocalServer(t, network) - defer ls.teardown() - ch := make(chan error, 1) - handler := func(ls *localServer, ln Listener) { ls.transponder(ln, ch) } - if err := ls.buildup(handler); err != nil { - t.Fatal(err) - } - if ls.Listener.Addr().Network() != network { - t.Fatalf("got %s; want %s", ls.Listener.Addr().Network(), network) - } + ls := newLocalServer(t, network) + defer ls.teardown() + ch := make(chan error, 1) + handler := func(ls *localServer, ln Listener) { ls.transponder(ln, ch) } + if err := ls.buildup(handler); err != nil { + t.Fatal(err) + } + if ls.Listener.Addr().Network() != network { + t.Fatalf("got %s; want %s", ls.Listener.Addr().Network(), network) + } - c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String()) - if err != nil { - t.Fatal(err) - } - defer c.Close() - if c.LocalAddr().Network() != network || c.RemoteAddr().Network() != network { - t.Fatalf("got %s->%s; want %s->%s", c.LocalAddr().Network(), c.RemoteAddr().Network(), network, network) - } - c.SetDeadline(time.Now().Add(someTimeout)) - c.SetReadDeadline(time.Now().Add(someTimeout)) - c.SetWriteDeadline(time.Now().Add(someTimeout)) + c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer c.Close() + if c.LocalAddr().Network() != network || c.RemoteAddr().Network() != network { + t.Fatalf("got %s->%s; want %s->%s", c.LocalAddr().Network(), c.RemoteAddr().Network(), network, network) + } + c.SetDeadline(time.Now().Add(someTimeout)) + c.SetReadDeadline(time.Now().Add(someTimeout)) + c.SetWriteDeadline(time.Now().Add(someTimeout)) - if _, err := c.Write([]byte("CONN AND LISTENER TEST")); err != nil { - t.Fatal(err) - } - rb := make([]byte, 128) - if _, err := c.Read(rb); err != nil { - t.Fatal(err) - } + if _, err := c.Write([]byte("CONN AND LISTENER TEST")); err != nil { + t.Fatal(err) + } + rb := make([]byte, 128) + if _, err := c.Read(rb); err != nil { + t.Fatal(err) + } - for err := range ch { - t.Errorf("#%d: %v", i, err) - } + for err := range ch { + t.Errorf("#%d: %v", i, err) + } + }) } } diff --git a/src/net/dial.go b/src/net/dial.go index 79bc4958bb..a6565c3ce5 100644 --- a/src/net/dial.go +++ b/src/net/dial.go @@ -6,6 +6,7 @@ package net import ( "context" + "internal/bytealg" "internal/godebug" "internal/nettrace" "syscall" @@ -64,7 +65,7 @@ func (m *mptcpStatus) set(use bool) { // // The zero value for each field is equivalent to dialing // without that option. Dialing with the zero value of Dialer -// is therefore equivalent to just calling the Dial function. +// is therefore equivalent to just calling the [Dial] function. // // It is safe to call Dialer's methods concurrently. type Dialer struct { @@ -226,7 +227,7 @@ func (d *Dialer) fallbackDelay() time.Duration { } func parseNetwork(ctx context.Context, network string, needsProto bool) (afnet string, proto int, err error) { - i := last(network, ':') + i := bytealg.LastIndexByteString(network, ':') if i < 0 { // no colon switch network { case "tcp", "tcp4", "tcp6": @@ -337,7 +338,7 @@ func (d *Dialer) MultipathTCP() bool { return d.mptcpStatus.get() } -// SetMultipathTCP directs the Dial methods to use, or not use, MPTCP, +// SetMultipathTCP directs the [Dial] methods to use, or not use, MPTCP, // if supported by the operating system. This method overrides the // system default and the GODEBUG=multipathtcp=... setting if any. // @@ -362,7 +363,7 @@ func (d *Dialer) SetMultipathTCP(use bool) { // brackets, as in "[2001:db8::1]:80" or "[fe80::1%zone]:80". // The zone specifies the scope of the literal IPv6 address as defined // in RFC 4007. -// The functions JoinHostPort and SplitHostPort manipulate a pair of +// The functions [JoinHostPort] and [SplitHostPort] manipulate a pair of // host and port in this form. // When using TCP, and the host resolves to multiple IP addresses, // Dial will try each IP address in order until one succeeds. @@ -400,7 +401,7 @@ func Dial(network, address string) (Conn, error) { return d.Dial(network, address) } -// DialTimeout acts like Dial but takes a timeout. +// DialTimeout acts like [Dial] but takes a timeout. // // The timeout includes name resolution, if required. // When using TCP, and the host in the address parameter resolves to @@ -427,8 +428,8 @@ type sysDialer struct { // See func Dial for a description of the network and address // parameters. // -// Dial uses context.Background internally; to specify the context, use -// DialContext. +// Dial uses [context.Background] internally; to specify the context, use +// [Dialer.DialContext]. func (d *Dialer) Dial(network, address string) (Conn, error) { return d.DialContext(context.Background(), network, address) } @@ -449,7 +450,7 @@ func (d *Dialer) Dial(network, address string) (Conn, error) { // the connect to each single address will be given 15 seconds to complete // before trying the next one. // -// See func Dial for a description of the network and address +// See func [Dial] for a description of the network and address // parameters. func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) { if ctx == nil { @@ -457,6 +458,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn } deadline := d.deadline(ctx, time.Now()) if !deadline.IsZero() { + testHookStepTime() if d, ok := ctx.Deadline(); !ok || deadline.Before(d) { subCtx, cancel := context.WithDeadline(ctx, deadline) defer cancel() @@ -698,7 +700,7 @@ func (lc *ListenConfig) MultipathTCP() bool { return lc.mptcpStatus.get() } -// SetMultipathTCP directs the Listen method to use, or not use, MPTCP, +// SetMultipathTCP directs the [Listen] method to use, or not use, MPTCP, // if supported by the operating system. This method overrides the // system default and the GODEBUG=multipathtcp=... setting if any. // @@ -793,14 +795,14 @@ type sysListener struct { // addresses. // If the port in the address parameter is empty or "0", as in // "127.0.0.1:" or "[::1]:0", a port number is automatically chosen. -// The Addr method of Listener can be used to discover the chosen +// The [Addr] method of [Listener] can be used to discover the chosen // port. // -// See func Dial for a description of the network and address +// See func [Dial] for a description of the network and address // parameters. // // Listen uses context.Background internally; to specify the context, use -// ListenConfig.Listen. +// [ListenConfig.Listen]. func Listen(network, address string) (Listener, error) { var lc ListenConfig return lc.Listen(context.Background(), network, address) @@ -823,14 +825,14 @@ func Listen(network, address string) (Listener, error) { // addresses. // If the port in the address parameter is empty or "0", as in // "127.0.0.1:" or "[::1]:0", a port number is automatically chosen. -// The LocalAddr method of PacketConn can be used to discover the +// The LocalAddr method of [PacketConn] can be used to discover the // chosen port. // -// See func Dial for a description of the network and address +// See func [Dial] for a description of the network and address // parameters. // // ListenPacket uses context.Background internally; to specify the context, use -// ListenConfig.ListenPacket. +// [ListenConfig.ListenPacket]. func ListenPacket(network, address string) (PacketConn, error) { var lc ListenConfig return lc.ListenPacket(context.Background(), network, address) diff --git a/src/net/dial_test.go b/src/net/dial_test.go index ca9f0da3d3..1d0832e46e 100644 --- a/src/net/dial_test.go +++ b/src/net/dial_test.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !wasip1 - package net import ( @@ -784,6 +782,7 @@ func TestDialCancel(t *testing.T) { "connection refused", "unreachable", "no route to host", + "invalid argument", } e := err.Error() for _, ignore := range ignorable { @@ -982,6 +981,8 @@ func TestDialerControl(t *testing.T) { switch runtime.GOOS { case "plan9": t.Skipf("not supported on %s", runtime.GOOS) + case "js", "wasip1": + t.Skipf("skipping: fake net does not support Dialer.Control") } t.Run("StreamDial", func(t *testing.T) { @@ -1025,40 +1026,52 @@ func TestDialerControlContext(t *testing.T) { switch runtime.GOOS { case "plan9": t.Skipf("%s does not have full support of socktest", runtime.GOOS) + case "js", "wasip1": + t.Skipf("skipping: fake net does not support Dialer.ControlContext") } t.Run("StreamDial", func(t *testing.T) { for i, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} { - if !testableNetwork(network) { - continue - } - ln := newLocalListener(t, network) - defer ln.Close() - var id int - d := Dialer{ControlContext: func(ctx context.Context, network string, address string, c syscall.RawConn) error { - id = ctx.Value("id").(int) - return controlOnConnSetup(network, address, c) - }} - c, err := d.DialContext(context.WithValue(context.Background(), "id", i+1), network, ln.Addr().String()) - if err != nil { - t.Error(err) - continue - } - if id != i+1 { - t.Errorf("got id %d, want %d", id, i+1) - } - c.Close() + t.Run(network, func(t *testing.T) { + if !testableNetwork(network) { + t.Skipf("skipping: %s not available", network) + } + + ln := newLocalListener(t, network) + defer ln.Close() + var id int + d := Dialer{ControlContext: func(ctx context.Context, network string, address string, c syscall.RawConn) error { + id = ctx.Value("id").(int) + return controlOnConnSetup(network, address, c) + }} + c, err := d.DialContext(context.WithValue(context.Background(), "id", i+1), network, ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + if id != i+1 { + t.Errorf("got id %d, want %d", id, i+1) + } + c.Close() + }) } }) } // mustHaveExternalNetwork is like testenv.MustHaveExternalNetwork -// except that it won't skip testing on non-mobile builders. +// except on non-Linux, non-mobile builders it permits the test to +// run in -short mode. func mustHaveExternalNetwork(t *testing.T) { t.Helper() + definitelyHasLongtestBuilder := runtime.GOOS == "linux" mobile := runtime.GOOS == "android" || runtime.GOOS == "ios" - if testenv.Builder() == "" || mobile { - testenv.MustHaveExternalNetwork(t) + fake := runtime.GOOS == "js" || runtime.GOOS == "wasip1" + if testenv.Builder() != "" && !definitelyHasLongtestBuilder && !mobile && !fake { + // On a non-Linux, non-mobile builder (e.g., freebsd-amd64-13_0). + // + // Don't skip testing because otherwise the test may never run on + // any builder if this port doesn't also have a -longtest builder. + return } + testenv.MustHaveExternalNetwork(t) } type contextWithNonZeroDeadline struct { diff --git a/src/net/dnsclient.go b/src/net/dnsclient.go index b609dbd468..204620b2ed 100644 --- a/src/net/dnsclient.go +++ b/src/net/dnsclient.go @@ -8,15 +8,17 @@ import ( "internal/bytealg" "internal/itoa" "sort" + _ "unsafe" // for go:linkname "golang.org/x/net/dns/dnsmessage" ) // provided by runtime -func fastrandu() uint +//go:linkname runtime_rand runtime.rand +func runtime_rand() uint64 func randInt() int { - return int(fastrandu() >> 1) // clear sign bit + return int(uint(runtime_rand()) >> 1) // clear sign bit } func randIntn(n int) int { diff --git a/src/net/dnsclient_unix.go b/src/net/dnsclient_unix.go index dab5144e5d..c291d5eb4f 100644 --- a/src/net/dnsclient_unix.go +++ b/src/net/dnsclient_unix.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js - // DNS client: see RFC 1035. // Has to be linked into package net for Dial. @@ -17,6 +15,7 @@ package net import ( "context" "errors" + "internal/bytealg" "internal/itoa" "io" "os" @@ -205,7 +204,9 @@ func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Que // checkHeader performs basic sanity checks on the header. func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error { - if h.RCode == dnsmessage.RCodeNameError { + rcode := extractExtendedRCode(*p, h) + + if rcode == dnsmessage.RCodeNameError { return errNoSuchHost } @@ -216,17 +217,17 @@ func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error { // libresolv continues to the next server when it receives // an invalid referral response. See golang.org/issue/15434. - if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone { + if rcode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone { return errLameReferral } - if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError { + if rcode != dnsmessage.RCodeSuccess && rcode != dnsmessage.RCodeNameError { // None of the error codes make sense // for the query we sent. If we didn't get // a name error and we didn't get success, // the server is behaving incorrectly or // having temporary trouble. - if h.RCode == dnsmessage.RCodeServerFailure { + if rcode == dnsmessage.RCodeServerFailure { return errServerTemporarilyMisbehaving } return errServerMisbehaving @@ -253,6 +254,23 @@ func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error { } } +// extractExtendedRCode extracts the extended RCode from the OPT resource (EDNS(0)) +// If an OPT record is not found, the RCode from the hdr is returned. +func extractExtendedRCode(p dnsmessage.Parser, hdr dnsmessage.Header) dnsmessage.RCode { + p.SkipAllAnswers() + p.SkipAllAuthorities() + for { + ahdr, err := p.AdditionalHeader() + if err != nil { + return hdr.RCode + } + if ahdr.Type == dnsmessage.TypeOPT { + return ahdr.ExtendedRCode(hdr.RCode) + } + p.SkipAdditional() + } +} + // Do a lookup for a single name, which must be rooted // (otherwise answer will not find the answers). func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) { @@ -479,10 +497,6 @@ func avoidDNS(name string) bool { // nameList returns a list of names for sequential DNS queries. func (conf *dnsConfig) nameList(name string) []string { - if avoidDNS(name) { - return nil - } - // Check name length (see isDomainName). l := len(name) rooted := l > 0 && name[l-1] == '.' @@ -492,27 +506,31 @@ func (conf *dnsConfig) nameList(name string) []string { // If name is rooted (trailing dot), try only that name. if rooted { + if avoidDNS(name) { + return nil + } return []string{name} } - hasNdots := count(name, '.') >= conf.ndots + hasNdots := bytealg.CountString(name, '.') >= conf.ndots name += "." l++ // Build list of search choices. names := make([]string, 0, 1+len(conf.search)) // If name has enough dots, try unsuffixed first. - if hasNdots { + if hasNdots && !avoidDNS(name) { names = append(names, name) } // Try suffixes that are not too long (see isDomainName). for _, suffix := range conf.search { - if l+len(suffix) <= 254 { - names = append(names, name+suffix) + fqdn := name + suffix + if !avoidDNS(fqdn) && len(fqdn) <= 254 { + names = append(names, fqdn) } } // Try unsuffixed, if not tried first above. - if !hasNdots { + if !hasNdots && !avoidDNS(name) { names = append(names, name) } return names @@ -586,8 +604,7 @@ func goLookupIPFiles(name string) (addrs []IPAddr, canonical string) { // goLookupIP is the native Go implementation of LookupIP. // The libc versions are in cgo_*.go. -func (r *Resolver) goLookupIP(ctx context.Context, network, host string) (addrs []IPAddr, err error) { - order, conf := systemConf().hostLookupOrder(r, host) +func (r *Resolver) goLookupIP(ctx context.Context, network, host string, order hostLookupOrder, conf *dnsConfig) (addrs []IPAddr, err error) { addrs, _, err = r.goLookupIPCNAMEOrder(ctx, network, host, order, conf) return } @@ -699,7 +716,7 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, network, name strin h, err := result.p.AnswerHeader() if err != nil && err != dnsmessage.ErrSectionDone { lastErr = &DNSError{ - Err: "cannot marshal DNS message", + Err: errCannotUnmarshalDNSMessage.Error(), Name: name, Server: result.server, } @@ -712,7 +729,7 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, network, name strin a, err := result.p.AResource() if err != nil { lastErr = &DNSError{ - Err: "cannot marshal DNS message", + Err: errCannotUnmarshalDNSMessage.Error(), Name: name, Server: result.server, } @@ -727,7 +744,7 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, network, name strin aaaa, err := result.p.AAAAResource() if err != nil { lastErr = &DNSError{ - Err: "cannot marshal DNS message", + Err: errCannotUnmarshalDNSMessage.Error(), Name: name, Server: result.server, } @@ -742,7 +759,7 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, network, name strin c, err := result.p.CNAMEResource() if err != nil { lastErr = &DNSError{ - Err: "cannot marshal DNS message", + Err: errCannotUnmarshalDNSMessage.Error(), Name: name, Server: result.server, } @@ -755,7 +772,7 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, network, name strin default: if err := result.p.SkipAnswer(); err != nil { lastErr = &DNSError{ - Err: "cannot marshal DNS message", + Err: errCannotUnmarshalDNSMessage.Error(), Name: name, Server: result.server, } @@ -847,7 +864,7 @@ func (r *Resolver) goLookupPTR(ctx context.Context, addr string, order hostLooku } if err != nil { return nil, &DNSError{ - Err: "cannot marshal DNS message", + Err: errCannotUnmarshalDNSMessage.Error(), Name: addr, Server: server, } @@ -856,7 +873,7 @@ func (r *Resolver) goLookupPTR(ctx context.Context, addr string, order hostLooku err := p.SkipAnswer() if err != nil { return nil, &DNSError{ - Err: "cannot marshal DNS message", + Err: errCannotUnmarshalDNSMessage.Error(), Name: addr, Server: server, } @@ -866,7 +883,7 @@ func (r *Resolver) goLookupPTR(ctx context.Context, addr string, order hostLooku ptr, err := p.PTRResource() if err != nil { return nil, &DNSError{ - Err: "cannot marshal DNS message", + Err: errCannotUnmarshalDNSMessage.Error(), Name: addr, Server: server, } diff --git a/src/net/dnsclient_unix_test.go b/src/net/dnsclient_unix_test.go index 8d435a557f..0da36303cc 100644 --- a/src/net/dnsclient_unix_test.go +++ b/src/net/dnsclient_unix_test.go @@ -10,12 +10,12 @@ import ( "context" "errors" "fmt" - "internal/testenv" "os" "path" "path/filepath" "reflect" "runtime" + "slices" "strings" "sync" "sync/atomic" @@ -191,6 +191,19 @@ func TestAvoidDNSName(t *testing.T) { } } +func TestNameListAvoidDNS(t *testing.T) { + c := &dnsConfig{search: []string{"go.dev.", "onion."}} + got := c.nameList("www") + if !slices.Equal(got, []string{"www.", "www.go.dev."}) { + t.Fatalf(`nameList("www") = %v, want "www.", "www.go.dev."`, got) + } + + got = c.nameList("www.onion") + if !slices.Equal(got, []string{"www.onion.go.dev."}) { + t.Fatalf(`nameList("www.onion") = %v, want "www.onion.go.dev."`, got) + } +} + var fakeDNSServerSuccessful = fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { r := dnsmessage.Message{ Header: dnsmessage.Header{ @@ -221,7 +234,7 @@ var fakeDNSServerSuccessful = fakeDNSServer{rh: func(_, _ string, q dnsmessage.M func TestLookupTorOnion(t *testing.T) { defer dnsWaitGroup.Wait() r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext} - addrs, err := r.LookupIPAddr(context.Background(), "foo.onion") + addrs, err := r.LookupIPAddr(context.Background(), "foo.onion.") if err != nil { t.Fatalf("lookup = %v; want nil", err) } @@ -606,8 +619,8 @@ func TestGoLookupIPOrderFallbackToFile(t *testing.T) { t.Fatal(err) } // Redirect host file lookups. - defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath) - testHookHostsPath = "testdata/hosts" + defer func(orig string) { hostsFilePath = orig }(hostsFilePath) + hostsFilePath = "testdata/hosts" for _, order := range []hostLookupOrder{hostLookupFilesDNS, hostLookupDNSFiles} { name := fmt.Sprintf("order %v", order) @@ -1953,8 +1966,8 @@ func TestCVE202133195(t *testing.T) { DefaultResolver = &r defer func() { DefaultResolver = originalDefault }() // Redirect host file lookups. - defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath) - testHookHostsPath = "testdata/hosts" + defer func(orig string) { hostsFilePath = orig }(hostsFilePath) + hostsFilePath = "testdata/hosts" tests := []struct { name string @@ -2173,8 +2186,8 @@ func TestRootNS(t *testing.T) { } func TestGoLookupIPCNAMEOrderHostsAliasesFilesOnlyMode(t *testing.T) { - defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath) - testHookHostsPath = "testdata/aliases" + defer func(orig string) { hostsFilePath = orig }(hostsFilePath) + hostsFilePath = "testdata/aliases" mode := hostLookupFiles for _, v := range lookupStaticHostAliasesTest { @@ -2183,8 +2196,8 @@ func TestGoLookupIPCNAMEOrderHostsAliasesFilesOnlyMode(t *testing.T) { } func TestGoLookupIPCNAMEOrderHostsAliasesFilesDNSMode(t *testing.T) { - defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath) - testHookHostsPath = "testdata/aliases" + defer func(orig string) { hostsFilePath = orig }(hostsFilePath) + hostsFilePath = "testdata/aliases" mode := hostLookupFilesDNS for _, v := range lookupStaticHostAliasesTest { @@ -2200,11 +2213,8 @@ var goLookupIPCNAMEOrderDNSFilesModeTests = []struct { } func TestGoLookupIPCNAMEOrderHostsAliasesDNSFilesMode(t *testing.T) { - if testenv.Builder() == "" { - t.Skip("Makes assumptions about local networks and (re)naming that aren't always true") - } - defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath) - testHookHostsPath = "testdata/aliases" + defer func(orig string) { hostsFilePath = orig }(hostsFilePath) + hostsFilePath = "testdata/aliases" mode := hostLookupDNSFiles for _, v := range goLookupIPCNAMEOrderDNSFilesModeTests { @@ -2213,9 +2223,29 @@ func TestGoLookupIPCNAMEOrderHostsAliasesDNSFilesMode(t *testing.T) { } func testGoLookupIPCNAMEOrderHostsAliases(t *testing.T, mode hostLookupOrder, lookup, lookupRes string) { + fake := fakeDNSServer{ + rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { + var answers []dnsmessage.Resource + + if mode != hostLookupDNSFiles { + t.Fatal("received unexpected DNS query") + } + + return dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: q.Header.ID, + Response: true, + }, + Questions: []dnsmessage.Question{q.Questions[0]}, + Answers: answers, + }, nil + }, + } + + r := Resolver{PreferGo: true, Dial: fake.DialContext} ins := []string{lookup, absDomainName(lookup), strings.ToLower(lookup), strings.ToUpper(lookup)} for _, in := range ins { - _, res, err := goResolver.goLookupIPCNAMEOrder(context.Background(), "ip", in, mode, nil) + _, res, err := r.goLookupIPCNAMEOrder(context.Background(), "ip", in, mode, nil) if err != nil { t.Errorf("expected err == nil, but got error: %v", err) } @@ -2511,7 +2541,7 @@ func TestDNSConfigNoReload(t *testing.T) { } func TestLookupOrderFilesNoSuchHost(t *testing.T) { - defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath) + defer func(orig string) { hostsFilePath = orig }(hostsFilePath) if runtime.GOOS != "openbsd" { defer setSystemNSS(getSystemNSS(), 0) setSystemNSS(nssStr(t, "hosts: files"), time.Hour) @@ -2538,7 +2568,7 @@ func TestLookupOrderFilesNoSuchHost(t *testing.T) { if err := os.WriteFile(tmpFile, []byte{}, 0660); err != nil { t.Fatal(err) } - testHookHostsPath = tmpFile + hostsFilePath = tmpFile const testName = "test.invalid" @@ -2598,3 +2628,34 @@ func TestLookupOrderFilesNoSuchHost(t *testing.T) { } } } + +func TestExtendedRCode(t *testing.T) { + fake := fakeDNSServer{ + rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { + fraudSuccessCode := dnsmessage.RCodeSuccess | 1<<10 + + var edns0Hdr dnsmessage.ResourceHeader + edns0Hdr.SetEDNS0(maxDNSPacketSize, fraudSuccessCode, false) + + return dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: q.Header.ID, + Response: true, + RCode: fraudSuccessCode, + }, + Questions: []dnsmessage.Question{q.Questions[0]}, + Additionals: []dnsmessage.Resource{{ + Header: edns0Hdr, + Body: &dnsmessage.OPTResource{}, + }}, + }, nil + }, + } + + r := &Resolver{PreferGo: true, Dial: fake.DialContext} + _, _, err := r.tryOneName(context.Background(), getSystemDNSConfig(), "go.dev.", dnsmessage.TypeA) + var dnsErr *DNSError + if !(errors.As(err, &dnsErr) && dnsErr.Err == errServerMisbehaving.Error()) { + t.Fatalf("r.tryOneName(): unexpected error: %v", err) + } +} diff --git a/src/net/dnsconfig_unix.go b/src/net/dnsconfig_unix.go index 69b300410a..b0a318279b 100644 --- a/src/net/dnsconfig_unix.go +++ b/src/net/dnsconfig_unix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !windows +//go:build !windows // Read system DNS config from /etc/resolv.conf diff --git a/src/net/dnsname_test.go b/src/net/dnsname_test.go index 4a5f01a04a..601a33af9f 100644 --- a/src/net/dnsname_test.go +++ b/src/net/dnsname_test.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !wasip1 - package net import ( diff --git a/src/net/error_plan9_test.go b/src/net/error_plan9_test.go index 1270af19e5..f86c96c0d2 100644 --- a/src/net/error_plan9_test.go +++ b/src/net/error_plan9_test.go @@ -7,7 +7,6 @@ package net import "syscall" var ( - errTimedout = syscall.ETIMEDOUT errOpNotSupported = syscall.EPLAN9 abortedConnRequestErrors []error diff --git a/src/net/error_posix.go b/src/net/error_posix.go index c8dc069db4..84f8044045 100644 --- a/src/net/error_posix.go +++ b/src/net/error_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build unix || (js && wasm) || wasip1 || windows +//go:build unix || js || wasip1 || windows package net diff --git a/src/net/error_test.go b/src/net/error_test.go index 4538765d48..f82e863346 100644 --- a/src/net/error_test.go +++ b/src/net/error_test.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !wasip1 - package net import ( @@ -157,32 +155,33 @@ func TestDialError(t *testing.T) { d := Dialer{Timeout: someTimeout} for i, tt := range dialErrorTests { - c, err := d.Dial(tt.network, tt.address) - if err == nil { - t.Errorf("#%d: should fail; %s:%s->%s", i, c.LocalAddr().Network(), c.LocalAddr(), c.RemoteAddr()) - c.Close() - continue - } - if tt.network == "tcp" || tt.network == "udp" { - nerr := err - if op, ok := nerr.(*OpError); ok { - nerr = op.Err + i, tt := i, tt + t.Run(fmt.Sprint(i), func(t *testing.T) { + c, err := d.Dial(tt.network, tt.address) + if err == nil { + t.Errorf("should fail; %s:%s->%s", c.LocalAddr().Network(), c.LocalAddr(), c.RemoteAddr()) + c.Close() + return } - if sys, ok := nerr.(*os.SyscallError); ok { - nerr = sys.Err + if tt.network == "tcp" || tt.network == "udp" { + nerr := err + if op, ok := nerr.(*OpError); ok { + nerr = op.Err + } + if sys, ok := nerr.(*os.SyscallError); ok { + nerr = sys.Err + } + if nerr == errOpNotSupported { + t.Fatalf("should fail without %v; %s:%s->", nerr, tt.network, tt.address) + } } - if nerr == errOpNotSupported { - t.Errorf("#%d: should fail without %v; %s:%s->", i, nerr, tt.network, tt.address) - continue + if c != nil { + t.Errorf("Dial returned non-nil interface %T(%v) with err != nil", c, c) } - } - if c != nil { - t.Errorf("Dial returned non-nil interface %T(%v) with err != nil", c, c) - } - if err = parseDialError(err); err != nil { - t.Errorf("#%d: %v", i, err) - continue - } + if err = parseDialError(err); err != nil { + t.Error(err) + } + }) } } @@ -208,10 +207,11 @@ func TestProtocolDialError(t *testing.T) { t.Errorf("%s: should fail", network) continue } - if err = parseDialError(err); err != nil { + if err := parseDialError(err); err != nil { t.Errorf("%s: %v", network, err) continue } + t.Logf("%s: error as expected: %v", network, err) } } @@ -220,6 +220,7 @@ func TestDialAddrError(t *testing.T) { case "plan9": t.Skipf("not supported on %s", runtime.GOOS) } + if !supportsIPv4() || !supportsIPv6() { t.Skip("both IPv4 and IPv6 are required") } @@ -236,38 +237,42 @@ func TestDialAddrError(t *testing.T) { // control name resolution. {"tcp6", "", &TCPAddr{IP: IP{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}}}, } { - var err error - var c Conn - var op string - if tt.lit != "" { - c, err = Dial(tt.network, JoinHostPort(tt.lit, "0")) - op = fmt.Sprintf("Dial(%q, %q)", tt.network, JoinHostPort(tt.lit, "0")) - } else { - c, err = DialTCP(tt.network, nil, tt.addr) - op = fmt.Sprintf("DialTCP(%q, %q)", tt.network, tt.addr) - } - if err == nil { - c.Close() - t.Errorf("%s succeeded, want error", op) - continue - } - if perr := parseDialError(err); perr != nil { - t.Errorf("%s: %v", op, perr) - continue - } - operr := err.(*OpError).Err - aerr, ok := operr.(*AddrError) - if !ok { - t.Errorf("%s: %v is %T, want *AddrError", op, err, operr) - continue - } - want := tt.lit - if tt.lit == "" { - want = tt.addr.IP.String() - } - if aerr.Addr != want { - t.Errorf("%s: %v, error Addr=%q, want %q", op, err, aerr.Addr, want) - } + desc := tt.lit + if desc == "" { + desc = tt.addr.String() + } + t.Run(fmt.Sprintf("%s/%s", tt.network, desc), func(t *testing.T) { + var err error + var c Conn + var op string + if tt.lit != "" { + c, err = Dial(tt.network, JoinHostPort(tt.lit, "0")) + op = fmt.Sprintf("Dial(%q, %q)", tt.network, JoinHostPort(tt.lit, "0")) + } else { + c, err = DialTCP(tt.network, nil, tt.addr) + op = fmt.Sprintf("DialTCP(%q, %q)", tt.network, tt.addr) + } + t.Logf("%s: %v", op, err) + if err == nil { + c.Close() + t.Fatalf("%s succeeded, want error", op) + } + if perr := parseDialError(err); perr != nil { + t.Fatal(perr) + } + operr := err.(*OpError).Err + aerr, ok := operr.(*AddrError) + if !ok { + t.Fatalf("OpError.Err is %T, want *AddrError", operr) + } + want := tt.lit + if tt.lit == "" { + want = tt.addr.IP.String() + } + if aerr.Addr != want { + t.Errorf("error Addr=%q, want %q", aerr.Addr, want) + } + }) } } @@ -305,32 +310,32 @@ func TestListenError(t *testing.T) { defer sw.Set(socktest.FilterListen, nil) for i, tt := range listenErrorTests { - ln, err := Listen(tt.network, tt.address) - if err == nil { - t.Errorf("#%d: should fail; %s:%s->", i, ln.Addr().Network(), ln.Addr()) - ln.Close() - continue - } - if tt.network == "tcp" { - nerr := err - if op, ok := nerr.(*OpError); ok { - nerr = op.Err + t.Run(fmt.Sprintf("%s_%s", tt.network, tt.address), func(t *testing.T) { + ln, err := Listen(tt.network, tt.address) + if err == nil { + t.Errorf("#%d: should fail; %s:%s->", i, ln.Addr().Network(), ln.Addr()) + ln.Close() + return } - if sys, ok := nerr.(*os.SyscallError); ok { - nerr = sys.Err + if tt.network == "tcp" { + nerr := err + if op, ok := nerr.(*OpError); ok { + nerr = op.Err + } + if sys, ok := nerr.(*os.SyscallError); ok { + nerr = sys.Err + } + if nerr == errOpNotSupported { + t.Fatalf("#%d: should fail without %v; %s:%s->", i, nerr, tt.network, tt.address) + } } - if nerr == errOpNotSupported { - t.Errorf("#%d: should fail without %v; %s:%s->", i, nerr, tt.network, tt.address) - continue + if ln != nil { + t.Errorf("Listen returned non-nil interface %T(%v) with err != nil", ln, ln) } - } - if ln != nil { - t.Errorf("Listen returned non-nil interface %T(%v) with err != nil", ln, ln) - } - if err = parseDialError(err); err != nil { - t.Errorf("#%d: %v", i, err) - continue - } + if err = parseDialError(err); err != nil { + t.Errorf("#%d: %v", i, err) + } + }) } } @@ -361,19 +366,20 @@ func TestListenPacketError(t *testing.T) { } for i, tt := range listenPacketErrorTests { - c, err := ListenPacket(tt.network, tt.address) - if err == nil { - t.Errorf("#%d: should fail; %s:%s->", i, c.LocalAddr().Network(), c.LocalAddr()) - c.Close() - continue - } - if c != nil { - t.Errorf("ListenPacket returned non-nil interface %T(%v) with err != nil", c, c) - } - if err = parseDialError(err); err != nil { - t.Errorf("#%d: %v", i, err) - continue - } + t.Run(fmt.Sprintf("%s_%s", tt.network, tt.address), func(t *testing.T) { + c, err := ListenPacket(tt.network, tt.address) + if err == nil { + t.Errorf("#%d: should fail; %s:%s->", i, c.LocalAddr().Network(), c.LocalAddr()) + c.Close() + return + } + if c != nil { + t.Errorf("ListenPacket returned non-nil interface %T(%v) with err != nil", c, c) + } + if err = parseDialError(err); err != nil { + t.Errorf("#%d: %v", i, err) + } + }) } } @@ -557,49 +563,57 @@ third: } func TestCloseError(t *testing.T) { - ln := newLocalListener(t, "tcp") - defer ln.Close() - c, err := Dial(ln.Addr().Network(), ln.Addr().String()) - if err != nil { - t.Fatal(err) - } - defer c.Close() + t.Run("tcp", func(t *testing.T) { + ln := newLocalListener(t, "tcp") + defer ln.Close() + c, err := Dial(ln.Addr().Network(), ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer c.Close() - for i := 0; i < 3; i++ { - err = c.(*TCPConn).CloseRead() - if perr := parseCloseError(err, true); perr != nil { - t.Errorf("#%d: %v", i, perr) + for i := 0; i < 3; i++ { + err = c.(*TCPConn).CloseRead() + if perr := parseCloseError(err, true); perr != nil { + t.Errorf("#%d: %v", i, perr) + } } - } - for i := 0; i < 3; i++ { - err = c.(*TCPConn).CloseWrite() - if perr := parseCloseError(err, true); perr != nil { - t.Errorf("#%d: %v", i, perr) + for i := 0; i < 3; i++ { + err = c.(*TCPConn).CloseWrite() + if perr := parseCloseError(err, true); perr != nil { + t.Errorf("#%d: %v", i, perr) + } } - } - for i := 0; i < 3; i++ { - err = c.Close() - if perr := parseCloseError(err, false); perr != nil { - t.Errorf("#%d: %v", i, perr) + for i := 0; i < 3; i++ { + err = c.Close() + if perr := parseCloseError(err, false); perr != nil { + t.Errorf("#%d: %v", i, perr) + } + err = ln.Close() + if perr := parseCloseError(err, false); perr != nil { + t.Errorf("#%d: %v", i, perr) + } } - err = ln.Close() - if perr := parseCloseError(err, false); perr != nil { - t.Errorf("#%d: %v", i, perr) + }) + + t.Run("udp", func(t *testing.T) { + if !testableNetwork("udp") { + t.Skipf("skipping: udp not available") } - } - pc, err := ListenPacket("udp", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - defer pc.Close() + pc, err := ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer pc.Close() - for i := 0; i < 3; i++ { - err = pc.Close() - if perr := parseCloseError(err, false); perr != nil { - t.Errorf("#%d: %v", i, perr) + for i := 0; i < 3; i++ { + err = pc.Close() + if perr := parseCloseError(err, false); perr != nil { + t.Errorf("#%d: %v", i, perr) + } } - } + }) } // parseAcceptError parses nestedErr and reports whether it is a valid diff --git a/src/net/error_unix_test.go b/src/net/error_unix_test.go index 291a7234f2..963ba21f1a 100644 --- a/src/net/error_unix_test.go +++ b/src/net/error_unix_test.go @@ -13,7 +13,6 @@ import ( ) var ( - errTimedout = syscall.ETIMEDOUT errOpNotSupported = syscall.EOPNOTSUPP abortedConnRequestErrors = []error{syscall.ECONNABORTED} // see accept in fd_unix.go diff --git a/src/net/error_windows_test.go b/src/net/error_windows_test.go index 25825f96f8..7847af0551 100644 --- a/src/net/error_windows_test.go +++ b/src/net/error_windows_test.go @@ -10,7 +10,6 @@ import ( ) var ( - errTimedout = syscall.ETIMEDOUT errOpNotSupported = syscall.EOPNOTSUPP abortedConnRequestErrors = []error{syscall.ERROR_NETNAME_DELETED, syscall.WSAECONNRESET} // see accept in fd_windows.go diff --git a/src/net/external_test.go b/src/net/external_test.go index 0709b9d6f5..38788efc3d 100644 --- a/src/net/external_test.go +++ b/src/net/external_test.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !wasip1 - package net import ( diff --git a/src/net/fd_fake.go b/src/net/fd_fake.go new file mode 100644 index 0000000000..ae567acc69 --- /dev/null +++ b/src/net/fd_fake.go @@ -0,0 +1,169 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build js || wasip1 + +package net + +import ( + "internal/poll" + "runtime" + "time" +) + +const ( + readSyscallName = "fd_read" + writeSyscallName = "fd_write" +) + +// Network file descriptor. +type netFD struct { + pfd poll.FD + + // immutable until Close + family int + sotype int + isConnected bool // handshake completed or use of association with peer + net string + laddr Addr + raddr Addr + + // The only networking available in WASI preview 1 is the ability to + // sock_accept on a pre-opened socket, and then fd_read, fd_write, + // fd_close, and sock_shutdown on the resulting connection. We + // intercept applicable netFD calls on this instance, and then pass + // the remainder of the netFD calls to fakeNetFD. + *fakeNetFD +} + +func newFD(net string, sysfd int) *netFD { + return newPollFD(net, poll.FD{ + Sysfd: sysfd, + IsStream: true, + ZeroReadIsEOF: true, + }) +} + +func newPollFD(net string, pfd poll.FD) *netFD { + var laddr Addr + var raddr Addr + // WASI preview 1 does not have functions like getsockname/getpeername, + // so we cannot get access to the underlying IP address used by connections. + // + // However, listeners created by FileListener are of type *TCPListener, + // which can be asserted by a Go program. The (*TCPListener).Addr method + // documents that the returned value will be of type *TCPAddr, we satisfy + // the documented behavior by creating addresses of the expected type here. + switch net { + case "tcp": + laddr = new(TCPAddr) + raddr = new(TCPAddr) + case "udp": + laddr = new(UDPAddr) + raddr = new(UDPAddr) + default: + laddr = unknownAddr{} + raddr = unknownAddr{} + } + return &netFD{ + pfd: pfd, + net: net, + laddr: laddr, + raddr: raddr, + } +} + +func (fd *netFD) init() error { + return fd.pfd.Init(fd.net, true) +} + +func (fd *netFD) name() string { + return "unknown" +} + +func (fd *netFD) accept() (netfd *netFD, err error) { + if fd.fakeNetFD != nil { + return fd.fakeNetFD.accept(fd.laddr) + } + d, _, errcall, err := fd.pfd.Accept() + if err != nil { + if errcall != "" { + err = wrapSyscallError(errcall, err) + } + return nil, err + } + netfd = newFD("tcp", d) + if err = netfd.init(); err != nil { + netfd.Close() + return nil, err + } + return netfd, nil +} + +func (fd *netFD) setAddr(laddr, raddr Addr) { + fd.laddr = laddr + fd.raddr = raddr + runtime.SetFinalizer(fd, (*netFD).Close) +} + +func (fd *netFD) Close() error { + if fd.fakeNetFD != nil { + return fd.fakeNetFD.Close() + } + runtime.SetFinalizer(fd, nil) + return fd.pfd.Close() +} + +func (fd *netFD) shutdown(how int) error { + if fd.fakeNetFD != nil { + return nil + } + err := fd.pfd.Shutdown(how) + runtime.KeepAlive(fd) + return wrapSyscallError("shutdown", err) +} + +func (fd *netFD) Read(p []byte) (n int, err error) { + if fd.fakeNetFD != nil { + return fd.fakeNetFD.Read(p) + } + n, err = fd.pfd.Read(p) + runtime.KeepAlive(fd) + return n, wrapSyscallError(readSyscallName, err) +} + +func (fd *netFD) Write(p []byte) (nn int, err error) { + if fd.fakeNetFD != nil { + return fd.fakeNetFD.Write(p) + } + nn, err = fd.pfd.Write(p) + runtime.KeepAlive(fd) + return nn, wrapSyscallError(writeSyscallName, err) +} + +func (fd *netFD) SetDeadline(t time.Time) error { + if fd.fakeNetFD != nil { + return fd.fakeNetFD.SetDeadline(t) + } + return fd.pfd.SetDeadline(t) +} + +func (fd *netFD) SetReadDeadline(t time.Time) error { + if fd.fakeNetFD != nil { + return fd.fakeNetFD.SetReadDeadline(t) + } + return fd.pfd.SetReadDeadline(t) +} + +func (fd *netFD) SetWriteDeadline(t time.Time) error { + if fd.fakeNetFD != nil { + return fd.fakeNetFD.SetWriteDeadline(t) + } + return fd.pfd.SetWriteDeadline(t) +} + +type unknownAddr struct{} + +func (unknownAddr) Network() string { return "unknown" } +func (unknownAddr) String() string { return "unknown" } diff --git a/src/net/fd_js.go b/src/net/fd_js.go new file mode 100644 index 0000000000..0fce036ef1 --- /dev/null +++ b/src/net/fd_js.go @@ -0,0 +1,28 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Fake networking for js/wasm. It is intended to allow tests of other package to pass. + +//go:build js + +package net + +import ( + "os" + "syscall" +) + +func (fd *netFD) closeRead() error { + if fd.fakeNetFD != nil { + return fd.fakeNetFD.closeRead() + } + return os.NewSyscallError("closeRead", syscall.ENOTSUP) +} + +func (fd *netFD) closeWrite() error { + if fd.fakeNetFD != nil { + return fd.fakeNetFD.closeWrite() + } + return os.NewSyscallError("closeRead", syscall.ENOTSUP) +} diff --git a/src/net/fd_wasip1.go b/src/net/fd_wasip1.go index 74d0b0b2e8..d50effc05d 100644 --- a/src/net/fd_wasip1.go +++ b/src/net/fd_wasip1.go @@ -7,124 +7,9 @@ package net import ( - "internal/poll" - "runtime" "syscall" - "time" ) -const ( - readSyscallName = "fd_read" - writeSyscallName = "fd_write" -) - -// Network file descriptor. -type netFD struct { - pfd poll.FD - - // immutable until Close - family int - sotype int - isConnected bool // handshake completed or use of association with peer - net string - laddr Addr - raddr Addr - - // The only networking available in WASI preview 1 is the ability to - // sock_accept on an pre-opened socket, and then fd_read, fd_write, - // fd_close, and sock_shutdown on the resulting connection. We - // intercept applicable netFD calls on this instance, and then pass - // the remainder of the netFD calls to fakeNetFD. - *fakeNetFD -} - -func newFD(net string, sysfd int) *netFD { - return newPollFD(net, poll.FD{ - Sysfd: sysfd, - IsStream: true, - ZeroReadIsEOF: true, - }) -} - -func newPollFD(net string, pfd poll.FD) *netFD { - var laddr Addr - var raddr Addr - // WASI preview 1 does not have functions like getsockname/getpeername, - // so we cannot get access to the underlying IP address used by connections. - // - // However, listeners created by FileListener are of type *TCPListener, - // which can be asserted by a Go program. The (*TCPListener).Addr method - // documents that the returned value will be of type *TCPAddr, we satisfy - // the documented behavior by creating addresses of the expected type here. - switch net { - case "tcp": - laddr = new(TCPAddr) - raddr = new(TCPAddr) - case "udp": - laddr = new(UDPAddr) - raddr = new(UDPAddr) - default: - laddr = unknownAddr{} - raddr = unknownAddr{} - } - return &netFD{ - pfd: pfd, - net: net, - laddr: laddr, - raddr: raddr, - } -} - -func (fd *netFD) init() error { - return fd.pfd.Init(fd.net, true) -} - -func (fd *netFD) name() string { - return "unknown" -} - -func (fd *netFD) accept() (netfd *netFD, err error) { - if fd.fakeNetFD != nil { - return fd.fakeNetFD.accept() - } - d, _, errcall, err := fd.pfd.Accept() - if err != nil { - if errcall != "" { - err = wrapSyscallError(errcall, err) - } - return nil, err - } - netfd = newFD("tcp", d) - if err = netfd.init(); err != nil { - netfd.Close() - return nil, err - } - return netfd, nil -} - -func (fd *netFD) setAddr(laddr, raddr Addr) { - fd.laddr = laddr - fd.raddr = raddr - runtime.SetFinalizer(fd, (*netFD).Close) -} - -func (fd *netFD) Close() error { - if fd.fakeNetFD != nil { - return fd.fakeNetFD.Close() - } - runtime.SetFinalizer(fd, nil) - return fd.pfd.Close() -} - -func (fd *netFD) shutdown(how int) error { - if fd.fakeNetFD != nil { - return nil - } - err := fd.pfd.Shutdown(how) - runtime.KeepAlive(fd) - return wrapSyscallError("shutdown", err) -} - func (fd *netFD) closeRead() error { if fd.fakeNetFD != nil { return fd.fakeNetFD.closeRead() @@ -138,47 +23,3 @@ func (fd *netFD) closeWrite() error { } return fd.shutdown(syscall.SHUT_WR) } - -func (fd *netFD) Read(p []byte) (n int, err error) { - if fd.fakeNetFD != nil { - return fd.fakeNetFD.Read(p) - } - n, err = fd.pfd.Read(p) - runtime.KeepAlive(fd) - return n, wrapSyscallError(readSyscallName, err) -} - -func (fd *netFD) Write(p []byte) (nn int, err error) { - if fd.fakeNetFD != nil { - return fd.fakeNetFD.Write(p) - } - nn, err = fd.pfd.Write(p) - runtime.KeepAlive(fd) - return nn, wrapSyscallError(writeSyscallName, err) -} - -func (fd *netFD) SetDeadline(t time.Time) error { - if fd.fakeNetFD != nil { - return fd.fakeNetFD.SetDeadline(t) - } - return fd.pfd.SetDeadline(t) -} - -func (fd *netFD) SetReadDeadline(t time.Time) error { - if fd.fakeNetFD != nil { - return fd.fakeNetFD.SetReadDeadline(t) - } - return fd.pfd.SetReadDeadline(t) -} - -func (fd *netFD) SetWriteDeadline(t time.Time) error { - if fd.fakeNetFD != nil { - return fd.fakeNetFD.SetWriteDeadline(t) - } - return fd.pfd.SetWriteDeadline(t) -} - -type unknownAddr struct{} - -func (unknownAddr) Network() string { return "unknown" } -func (unknownAddr) String() string { return "unknown" } diff --git a/src/net/fd_windows.go b/src/net/fd_windows.go index eeb994dfd9..45a10cf1eb 100644 --- a/src/net/fd_windows.go +++ b/src/net/fd_windows.go @@ -64,10 +64,38 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (syscall. if err := fd.init(); err != nil { return nil, err } - if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() { - fd.pfd.SetWriteDeadline(deadline) + + if ctx.Done() != nil { + // Propagate the Context's deadline and cancellation. + // If the context is already done, or if it has a nonzero deadline, + // ensure that that is applied before the call to ConnectEx begins + // so that we don't return spurious connections. defer fd.pfd.SetWriteDeadline(noDeadline) + + if ctx.Err() != nil { + fd.pfd.SetWriteDeadline(aLongTimeAgo) + } else { + if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() { + fd.pfd.SetWriteDeadline(deadline) + } + + done := make(chan struct{}) + stop := context.AfterFunc(ctx, func() { + // Force the runtime's poller to immediately give + // up waiting for writability. + fd.pfd.SetWriteDeadline(aLongTimeAgo) + close(done) + }) + defer func() { + if !stop() { + // Wait for the call to SetWriteDeadline to complete so that we can + // reset the deadline if everything else succeeded. + <-done + } + }() + } } + if !canUseConnectEx(fd.net) { err := connectFunc(fd.pfd.Sysfd, ra) return nil, os.NewSyscallError("connect", err) @@ -113,22 +141,6 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (syscall. _ = fd.pfd.WSAIoctl(windows.SIO_TCP_INITIAL_RTO, (*byte)(unsafe.Pointer(¶ms)), uint32(unsafe.Sizeof(params)), nil, 0, &out, nil, 0) } - // Wait for the goroutine converting context.Done into a write timeout - // to exist, otherwise our caller might cancel the context and - // cause fd.setWriteDeadline(aLongTimeAgo) to cancel a successful dial. - done := make(chan bool) // must be unbuffered - defer func() { done <- true }() - go func() { - select { - case <-ctx.Done(): - // Force the runtime's poller to immediately give - // up waiting for writability. - fd.pfd.SetWriteDeadline(aLongTimeAgo) - <-done - case <-done: - } - }() - // Call ConnectEx API. if err := fd.pfd.ConnectEx(ra); err != nil { select { diff --git a/src/net/file_stub.go b/src/net/file_stub.go index 91df926a57..6fd3eec48d 100644 --- a/src/net/file_stub.go +++ b/src/net/file_stub.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build js && wasm +//go:build js package net diff --git a/src/net/file_test.go b/src/net/file_test.go index 53cd3c1074..c517af50c5 100644 --- a/src/net/file_test.go +++ b/src/net/file_test.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !wasip1 - package net import ( @@ -31,7 +29,7 @@ var fileConnTests = []struct { func TestFileConn(t *testing.T) { switch runtime.GOOS { - case "plan9", "windows": + case "plan9", "windows", "js", "wasip1": t.Skipf("not supported on %s", runtime.GOOS) } @@ -132,7 +130,7 @@ var fileListenerTests = []struct { func TestFileListener(t *testing.T) { switch runtime.GOOS { - case "plan9", "windows": + case "plan9", "windows", "js", "wasip1": t.Skipf("not supported on %s", runtime.GOOS) } @@ -224,7 +222,7 @@ var filePacketConnTests = []struct { func TestFilePacketConn(t *testing.T) { switch runtime.GOOS { - case "plan9", "windows": + case "plan9", "windows", "js", "wasip1": t.Skipf("not supported on %s", runtime.GOOS) } @@ -291,7 +289,7 @@ func TestFilePacketConn(t *testing.T) { // Issue 24483. func TestFileCloseRace(t *testing.T) { switch runtime.GOOS { - case "plan9", "windows": + case "plan9", "windows", "js", "wasip1": t.Skipf("not supported on %s", runtime.GOOS) } if !testableNetwork("tcp") { diff --git a/src/net/hook.go b/src/net/hook.go index ea71803e22..eded34d48a 100644 --- a/src/net/hook.go +++ b/src/net/hook.go @@ -13,8 +13,7 @@ var ( // if non-nil, overrides dialTCP. testHookDialTCP func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) - testHookHostsPath = "/etc/hosts" - testHookLookupIP = func( + testHookLookupIP = func( ctx context.Context, fn func(context.Context, string, string) ([]IPAddr, error), network string, @@ -23,4 +22,10 @@ var ( return fn(ctx, network, host) } testHookSetKeepAlive = func(time.Duration) {} + + // testHookStepTime sleeps until time has moved forward by a nonzero amount. + // This helps to avoid flakes in timeout tests by ensuring that an implausibly + // short deadline (such as 1ns in the future) is always expired by the time + // a relevant system call occurs. + testHookStepTime = func() {} ) diff --git a/src/net/hook_plan9.go b/src/net/hook_plan9.go index e053348505..6020d32924 100644 --- a/src/net/hook_plan9.go +++ b/src/net/hook_plan9.go @@ -4,6 +4,6 @@ package net -import "time" - -var testHookDialChannel = func() { time.Sleep(time.Millisecond) } // see golang.org/issue/5349 +var ( + hostsFilePath = "/etc/hosts" +) diff --git a/src/net/hook_unix.go b/src/net/hook_unix.go index 4e20f59218..69b375598d 100644 --- a/src/net/hook_unix.go +++ b/src/net/hook_unix.go @@ -2,16 +2,17 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build unix || (js && wasm) || wasip1 +//go:build unix || js || wasip1 package net import "syscall" var ( - testHookDialChannel = func() {} // for golang.org/issue/5349 testHookCanceledDial = func() {} // for golang.org/issue/16523 + hostsFilePath = "/etc/hosts" + // Placeholders for socket system calls. socketFunc func(int, int, int) (int, error) = syscall.Socket connectFunc func(int, syscall.Sockaddr) error = syscall.Connect diff --git a/src/net/hook_windows.go b/src/net/hook_windows.go index ab8656cbbf..f7c5b5af90 100644 --- a/src/net/hook_windows.go +++ b/src/net/hook_windows.go @@ -7,14 +7,12 @@ package net import ( "internal/syscall/windows" "syscall" - "time" ) var ( - testHookDialChannel = func() { time.Sleep(time.Millisecond) } // see golang.org/issue/5349 + hostsFilePath = windows.GetSystemDirectory() + "/Drivers/etc/hosts" // Placeholders for socket system calls. - socketFunc func(int, int, int) (syscall.Handle, error) = syscall.Socket wsaSocketFunc func(int32, int32, int32, *syscall.WSAProtocolInfo, uint32, uint32) (syscall.Handle, error) = windows.WSASocket connectFunc func(syscall.Handle, syscall.Sockaddr) error = syscall.Connect listenFunc func(syscall.Handle, int) error = syscall.Listen diff --git a/src/net/hosts.go b/src/net/hosts.go index 56e6674144..73e6fcc7a4 100644 --- a/src/net/hosts.go +++ b/src/net/hosts.go @@ -51,7 +51,7 @@ var hosts struct { func readHosts() { now := time.Now() - hp := testHookHostsPath + hp := hostsFilePath if now.Before(hosts.expire) && hosts.path == hp && len(hosts.byName) > 0 { return diff --git a/src/net/hosts_test.go b/src/net/hosts_test.go index b3f189e641..5f22920765 100644 --- a/src/net/hosts_test.go +++ b/src/net/hosts_test.go @@ -59,10 +59,10 @@ var lookupStaticHostTests = []struct { } func TestLookupStaticHost(t *testing.T) { - defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath) + defer func(orig string) { hostsFilePath = orig }(hostsFilePath) for _, tt := range lookupStaticHostTests { - testHookHostsPath = tt.name + hostsFilePath = tt.name for _, ent := range tt.ents { testStaticHost(t, tt.name, ent) } @@ -128,10 +128,10 @@ var lookupStaticAddrTests = []struct { } func TestLookupStaticAddr(t *testing.T) { - defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath) + defer func(orig string) { hostsFilePath = orig }(hostsFilePath) for _, tt := range lookupStaticAddrTests { - testHookHostsPath = tt.name + hostsFilePath = tt.name for _, ent := range tt.ents { testStaticAddr(t, tt.name, ent) } @@ -151,27 +151,27 @@ func testStaticAddr(t *testing.T, hostsPath string, ent staticHostEntry) { func TestHostCacheModification(t *testing.T) { // Ensure that programs can't modify the internals of the host cache. // See https://golang.org/issues/14212. - defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath) + defer func(orig string) { hostsFilePath = orig }(hostsFilePath) - testHookHostsPath = "testdata/ipv4-hosts" + hostsFilePath = "testdata/ipv4-hosts" ent := staticHostEntry{"localhost", []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}} - testStaticHost(t, testHookHostsPath, ent) + testStaticHost(t, hostsFilePath, ent) // Modify the addresses return by lookupStaticHost. addrs, _ := lookupStaticHost(ent.in) for i := range addrs { addrs[i] += "junk" } - testStaticHost(t, testHookHostsPath, ent) + testStaticHost(t, hostsFilePath, ent) - testHookHostsPath = "testdata/ipv6-hosts" + hostsFilePath = "testdata/ipv6-hosts" ent = staticHostEntry{"::1", []string{"localhost"}} - testStaticAddr(t, testHookHostsPath, ent) + testStaticAddr(t, hostsFilePath, ent) // Modify the hosts return by lookupStaticAddr. hosts := lookupStaticAddr(ent.in) for i := range hosts { hosts[i] += "junk" } - testStaticAddr(t, testHookHostsPath, ent) + testStaticAddr(t, hostsFilePath, ent) } var lookupStaticHostAliasesTest = []struct { @@ -195,9 +195,9 @@ var lookupStaticHostAliasesTest = []struct { } func TestLookupStaticHostAliases(t *testing.T) { - defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath) + defer func(orig string) { hostsFilePath = orig }(hostsFilePath) - testHookHostsPath = "testdata/aliases" + hostsFilePath = "testdata/aliases" for _, ent := range lookupStaticHostAliasesTest { testLookupStaticHostAliases(t, ent.lookup, absDomainName(ent.res)) } diff --git a/src/net/http/cgi/cgi_main.go b/src/net/http/cgi/cgi_main.go new file mode 100644 index 0000000000..8997d66a11 --- /dev/null +++ b/src/net/http/cgi/cgi_main.go @@ -0,0 +1,145 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cgi + +import ( + "fmt" + "io" + "net/http" + "os" + "path" + "sort" + "strings" + "time" +) + +func cgiMain() { + switch path.Join(os.Getenv("SCRIPT_NAME"), os.Getenv("PATH_INFO")) { + case "/bar", "/test.cgi", "/myscript/bar", "/test.cgi/extrapath": + testCGI() + return + } + childCGIProcess() +} + +// testCGI is a CGI program translated from a Perl program to complete host_test. +// test cases in host_test should be provided by testCGI. +func testCGI() { + req, err := Request() + if err != nil { + panic(err) + } + + err = req.ParseForm() + if err != nil { + panic(err) + } + + params := req.Form + if params.Get("loc") != "" { + fmt.Printf("Location: %s\r\n\r\n", params.Get("loc")) + return + } + + fmt.Printf("Content-Type: text/html\r\n") + fmt.Printf("X-CGI-Pid: %d\r\n", os.Getpid()) + fmt.Printf("X-Test-Header: X-Test-Value\r\n") + fmt.Printf("\r\n") + + if params.Get("writestderr") != "" { + fmt.Fprintf(os.Stderr, "Hello, stderr!\n") + } + + if params.Get("bigresponse") != "" { + // 17 MB, for OS X: golang.org/issue/4958 + line := strings.Repeat("A", 1024) + for i := 0; i < 17*1024; i++ { + fmt.Printf("%s\r\n", line) + } + return + } + + fmt.Printf("test=Hello CGI\r\n") + + keys := make([]string, 0, len(params)) + for k := range params { + keys = append(keys, k) + } + sort.Strings(keys) + for _, key := range keys { + fmt.Printf("param-%s=%s\r\n", key, params.Get(key)) + } + + envs := envMap(os.Environ()) + keys = make([]string, 0, len(envs)) + for k := range envs { + keys = append(keys, k) + } + sort.Strings(keys) + for _, key := range keys { + fmt.Printf("env-%s=%s\r\n", key, envs[key]) + } + + cwd, _ := os.Getwd() + fmt.Printf("cwd=%s\r\n", cwd) +} + +type neverEnding byte + +func (b neverEnding) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = byte(b) + } + return len(p), nil +} + +// childCGIProcess is used by integration_test to complete unit tests. +func childCGIProcess() { + if os.Getenv("REQUEST_METHOD") == "" { + // Not in a CGI environment; skipping test. + return + } + switch os.Getenv("REQUEST_URI") { + case "/immediate-disconnect": + os.Exit(0) + case "/no-content-type": + fmt.Printf("Content-Length: 6\n\nHello\n") + os.Exit(0) + case "/empty-headers": + fmt.Printf("\nHello") + os.Exit(0) + } + Serve(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.FormValue("nil-request-body") == "1" { + fmt.Fprintf(rw, "nil-request-body=%v\n", req.Body == nil) + return + } + rw.Header().Set("X-Test-Header", "X-Test-Value") + req.ParseForm() + if req.FormValue("no-body") == "1" { + return + } + if eb, ok := req.Form["exact-body"]; ok { + io.WriteString(rw, eb[0]) + return + } + if req.FormValue("write-forever") == "1" { + io.Copy(rw, neverEnding('a')) + for { + time.Sleep(5 * time.Second) // hang forever, until killed + } + } + fmt.Fprintf(rw, "test=Hello CGI-in-CGI\n") + for k, vv := range req.Form { + for _, v := range vv { + fmt.Fprintf(rw, "param-%s=%s\n", k, v) + } + } + for _, kv := range os.Environ() { + fmt.Fprintf(rw, "env-%s\n", kv) + } + })) + os.Exit(0) +} diff --git a/src/net/http/cgi/child.go b/src/net/http/cgi/child.go index 1411f0b8e8..e29fe20d7d 100644 --- a/src/net/http/cgi/child.go +++ b/src/net/http/cgi/child.go @@ -46,7 +46,7 @@ func envMap(env []string) map[string]string { return m } -// RequestFromMap creates an http.Request from CGI variables. +// RequestFromMap creates an [http.Request] from CGI variables. // The returned Request's Body field is not populated. func RequestFromMap(params map[string]string) (*http.Request, error) { r := new(http.Request) @@ -138,10 +138,10 @@ func RequestFromMap(params map[string]string) (*http.Request, error) { return r, nil } -// Serve executes the provided Handler on the currently active CGI +// Serve executes the provided [Handler] on the currently active CGI // request, if any. If there's no current CGI environment // an error is returned. The provided handler may be nil to use -// http.DefaultServeMux. +// [http.DefaultServeMux]. func Serve(handler http.Handler) error { req, err := Request() if err != nil { diff --git a/src/net/http/cgi/host.go b/src/net/http/cgi/host.go index 073952a7bd..ef222ab73a 100644 --- a/src/net/http/cgi/host.go +++ b/src/net/http/cgi/host.go @@ -115,23 +115,19 @@ func removeLeadingDuplicates(env []string) (ret []string) { } func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - root := h.Root - if root == "" { - root = "/" - } - if len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked" { rw.WriteHeader(http.StatusBadRequest) rw.Write([]byte("Chunked request bodies are not supported by CGI.")) return } - pathInfo := req.URL.Path - if root != "/" && strings.HasPrefix(pathInfo, root) { - pathInfo = pathInfo[len(root):] - } + root := strings.TrimRight(h.Root, "/") + pathInfo := strings.TrimPrefix(req.URL.Path, root) port := "80" + if req.TLS != nil { + port = "443" + } if matches := trailingPort.FindStringSubmatch(req.Host); len(matches) != 0 { port = matches[1] } diff --git a/src/net/http/cgi/host_test.go b/src/net/http/cgi/host_test.go index 860e9b3e8f..f29395fe84 100644 --- a/src/net/http/cgi/host_test.go +++ b/src/net/http/cgi/host_test.go @@ -9,21 +9,33 @@ package cgi import ( "bufio" "fmt" + "internal/testenv" "io" "net" "net/http" "net/http/httptest" "os" - "os/exec" "path/filepath" "reflect" + "regexp" "runtime" - "strconv" "strings" "testing" "time" ) +// TestMain executes the test binary as the cgi server if +// SERVER_SOFTWARE is set, and runs the tests otherwise. +func TestMain(m *testing.M) { + // SERVER_SOFTWARE swap variable is set when starting the cgi server. + if os.Getenv("SERVER_SOFTWARE") != "" { + cgiMain() + os.Exit(0) + } + + os.Exit(m.Run()) +} + func newRequest(httpreq string) *http.Request { buf := bufio.NewReader(strings.NewReader(httpreq)) req, err := http.ReadRequest(buf) @@ -88,24 +100,10 @@ readlines: } } -var cgiTested, cgiWorks bool - -func check(t *testing.T) { - if !cgiTested { - cgiTested = true - cgiWorks = exec.Command("./testdata/test.cgi").Run() == nil - } - if !cgiWorks { - // No Perl on Windows, needed by test.cgi - // TODO: make the child process be Go, not Perl. - t.Skip("Skipping test: test.cgi failed.") - } -} - func TestCGIBasicGet(t *testing.T) { - check(t) + testenv.MustHaveExec(t) h := &Handler{ - Path: "testdata/test.cgi", + Path: os.Args[0], Root: "/test.cgi", } expectedMap := map[string]string{ @@ -121,7 +119,7 @@ func TestCGIBasicGet(t *testing.T) { "env-REMOTE_PORT": "1234", "env-REQUEST_METHOD": "GET", "env-REQUEST_URI": "/test.cgi?foo=bar&a=b", - "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-SCRIPT_FILENAME": os.Args[0], "env-SCRIPT_NAME": "/test.cgi", "env-SERVER_NAME": "example.com", "env-SERVER_PORT": "80", @@ -138,9 +136,9 @@ func TestCGIBasicGet(t *testing.T) { } func TestCGIEnvIPv6(t *testing.T) { - check(t) + testenv.MustHaveExec(t) h := &Handler{ - Path: "testdata/test.cgi", + Path: os.Args[0], Root: "/test.cgi", } expectedMap := map[string]string{ @@ -156,7 +154,7 @@ func TestCGIEnvIPv6(t *testing.T) { "env-REMOTE_PORT": "12345", "env-REQUEST_METHOD": "GET", "env-REQUEST_URI": "/test.cgi?foo=bar&a=b", - "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-SCRIPT_FILENAME": os.Args[0], "env-SCRIPT_NAME": "/test.cgi", "env-SERVER_NAME": "example.com", "env-SERVER_PORT": "80", @@ -171,27 +169,27 @@ func TestCGIEnvIPv6(t *testing.T) { } func TestCGIBasicGetAbsPath(t *testing.T) { - check(t) - pwd, err := os.Getwd() + absPath, err := filepath.Abs(os.Args[0]) if err != nil { - t.Fatalf("getwd error: %v", err) + t.Fatal(err) } + testenv.MustHaveExec(t) h := &Handler{ - Path: pwd + "/testdata/test.cgi", + Path: absPath, Root: "/test.cgi", } expectedMap := map[string]string{ "env-REQUEST_URI": "/test.cgi?foo=bar&a=b", - "env-SCRIPT_FILENAME": pwd + "/testdata/test.cgi", + "env-SCRIPT_FILENAME": absPath, "env-SCRIPT_NAME": "/test.cgi", } runCgiTest(t, h, "GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) } func TestPathInfo(t *testing.T) { - check(t) + testenv.MustHaveExec(t) h := &Handler{ - Path: "testdata/test.cgi", + Path: os.Args[0], Root: "/test.cgi", } expectedMap := map[string]string{ @@ -199,36 +197,36 @@ func TestPathInfo(t *testing.T) { "env-PATH_INFO": "/extrapath", "env-QUERY_STRING": "a=b", "env-REQUEST_URI": "/test.cgi/extrapath?a=b", - "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-SCRIPT_FILENAME": os.Args[0], "env-SCRIPT_NAME": "/test.cgi", } runCgiTest(t, h, "GET /test.cgi/extrapath?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) } func TestPathInfoDirRoot(t *testing.T) { - check(t) + testenv.MustHaveExec(t) h := &Handler{ - Path: "testdata/test.cgi", - Root: "/myscript/", + Path: os.Args[0], + Root: "/myscript//", } expectedMap := map[string]string{ - "env-PATH_INFO": "bar", + "env-PATH_INFO": "/bar", "env-QUERY_STRING": "a=b", "env-REQUEST_URI": "/myscript/bar?a=b", - "env-SCRIPT_FILENAME": "testdata/test.cgi", - "env-SCRIPT_NAME": "/myscript/", + "env-SCRIPT_FILENAME": os.Args[0], + "env-SCRIPT_NAME": "/myscript", } runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) } func TestDupHeaders(t *testing.T) { - check(t) + testenv.MustHaveExec(t) h := &Handler{ - Path: "testdata/test.cgi", + Path: os.Args[0], } expectedMap := map[string]string{ "env-REQUEST_URI": "/myscript/bar?a=b", - "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-SCRIPT_FILENAME": os.Args[0], "env-HTTP_COOKIE": "nom=NOM; yum=YUM", "env-HTTP_X_FOO": "val1, val2", } @@ -245,13 +243,13 @@ func TestDupHeaders(t *testing.T) { // Verify we don't set the HTTP_PROXY environment variable. // Hope nobody was depending on it. It's not a known header, though. func TestDropProxyHeader(t *testing.T) { - check(t) + testenv.MustHaveExec(t) h := &Handler{ - Path: "testdata/test.cgi", + Path: os.Args[0], } expectedMap := map[string]string{ "env-REQUEST_URI": "/myscript/bar?a=b", - "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-SCRIPT_FILENAME": os.Args[0], "env-HTTP_X_FOO": "a", } runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\n"+ @@ -267,23 +265,23 @@ func TestDropProxyHeader(t *testing.T) { } func TestPathInfoNoRoot(t *testing.T) { - check(t) + testenv.MustHaveExec(t) h := &Handler{ - Path: "testdata/test.cgi", + Path: os.Args[0], Root: "", } expectedMap := map[string]string{ "env-PATH_INFO": "/bar", "env-QUERY_STRING": "a=b", "env-REQUEST_URI": "/bar?a=b", - "env-SCRIPT_FILENAME": "testdata/test.cgi", - "env-SCRIPT_NAME": "/", + "env-SCRIPT_FILENAME": os.Args[0], + "env-SCRIPT_NAME": "", } runCgiTest(t, h, "GET /bar?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) } func TestCGIBasicPost(t *testing.T) { - check(t) + testenv.MustHaveExec(t) postReq := `POST /test.cgi?a=b HTTP/1.0 Host: example.com Content-Type: application/x-www-form-urlencoded @@ -291,7 +289,7 @@ Content-Length: 15 postfoo=postbar` h := &Handler{ - Path: "testdata/test.cgi", + Path: os.Args[0], Root: "/test.cgi", } expectedMap := map[string]string{ @@ -310,7 +308,7 @@ func chunk(s string) string { // The CGI spec doesn't allow chunked requests. func TestCGIPostChunked(t *testing.T) { - check(t) + testenv.MustHaveExec(t) postReq := `POST /test.cgi?a=b HTTP/1.1 Host: example.com Content-Type: application/x-www-form-urlencoded @@ -319,7 +317,7 @@ Transfer-Encoding: chunked ` + chunk("postfoo") + chunk("=") + chunk("postbar") + chunk("") h := &Handler{ - Path: "testdata/test.cgi", + Path: os.Args[0], Root: "/test.cgi", } expectedMap := map[string]string{} @@ -331,9 +329,9 @@ Transfer-Encoding: chunked } func TestRedirect(t *testing.T) { - check(t) + testenv.MustHaveExec(t) h := &Handler{ - Path: "testdata/test.cgi", + Path: os.Args[0], Root: "/test.cgi", } rec := runCgiTest(t, h, "GET /test.cgi?loc=http://foo.com/ HTTP/1.0\nHost: example.com\n\n", nil) @@ -346,13 +344,13 @@ func TestRedirect(t *testing.T) { } func TestInternalRedirect(t *testing.T) { - check(t) + testenv.MustHaveExec(t) baseHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { fmt.Fprintf(rw, "basepath=%s\n", req.URL.Path) fmt.Fprintf(rw, "remoteaddr=%s\n", req.RemoteAddr) }) h := &Handler{ - Path: "testdata/test.cgi", + Path: os.Args[0], Root: "/test.cgi", PathLocationHandler: baseHandler, } @@ -365,13 +363,14 @@ func TestInternalRedirect(t *testing.T) { // TestCopyError tests that we kill the process if there's an error copying // its output. (for example, from the client having gone away) +// +// If we fail to do so, the test will time out (and dump its goroutines) with a +// call to [Handler.ServeHTTP] blocked on a deferred call to [exec.Cmd.Wait]. func TestCopyError(t *testing.T) { - check(t) - if runtime.GOOS == "windows" { - t.Skipf("skipping test on %q", runtime.GOOS) - } + testenv.MustHaveExec(t) + h := &Handler{ - Path: "testdata/test.cgi", + Path: os.Args[0], Root: "/test.cgi", } ts := httptest.NewServer(h) @@ -392,118 +391,63 @@ func TestCopyError(t *testing.T) { t.Fatalf("ReadResponse: %v", err) } - pidstr := res.Header.Get("X-CGI-Pid") - if pidstr == "" { - t.Fatalf("expected an X-CGI-Pid header in response") - } - pid, err := strconv.Atoi(pidstr) - if err != nil { - t.Fatalf("invalid X-CGI-Pid value") - } - var buf [5000]byte n, err := io.ReadFull(res.Body, buf[:]) if err != nil { t.Fatalf("ReadFull: %d bytes, %v", n, err) } - childRunning := func() bool { - return isProcessRunning(pid) - } - - if !childRunning() { - t.Fatalf("pre-conn.Close, expected child to be running") + if !handlerRunning() { + t.Fatalf("pre-conn.Close, expected handler to still be running") } conn.Close() + closed := time.Now() - tries := 0 - for tries < 25 && childRunning() { - time.Sleep(50 * time.Millisecond * time.Duration(tries)) - tries++ - } - if childRunning() { - t.Fatalf("post-conn.Close, expected child to be gone") - } -} - -func TestDirUnix(t *testing.T) { - check(t) - if runtime.GOOS == "windows" { - t.Skipf("skipping test on %q", runtime.GOOS) - } - cwd, _ := os.Getwd() - h := &Handler{ - Path: "testdata/test.cgi", - Root: "/test.cgi", - Dir: cwd, - } - expectedMap := map[string]string{ - "cwd": cwd, - } - runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) - - cwd, _ = os.Getwd() - cwd = filepath.Join(cwd, "testdata") - h = &Handler{ - Path: "testdata/test.cgi", - Root: "/test.cgi", - } - expectedMap = map[string]string{ - "cwd": cwd, + nextSleep := 1 * time.Millisecond + for { + time.Sleep(nextSleep) + nextSleep *= 2 + if !handlerRunning() { + break + } + t.Logf("handler still running %v after conn.Close", time.Since(closed)) } - runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) } -func findPerl(t *testing.T) string { - t.Helper() - perl, err := exec.LookPath("perl") - if err != nil { - t.Skip("Skipping test: perl not found.") - } - perl, _ = filepath.Abs(perl) - - cmd := exec.Command(perl, "-e", "print 123") - cmd.Env = []string{"PATH=/garbage"} - out, err := cmd.Output() - if err != nil || string(out) != "123" { - t.Skipf("Skipping test: %s is not functional", perl) +// handlerRunning reports whether any goroutine is currently running +// [Handler.ServeHTTP]. +func handlerRunning() bool { + r := regexp.MustCompile(`net/http/cgi\.\(\*Handler\)\.ServeHTTP`) + buf := make([]byte, 64<<10) + for { + n := runtime.Stack(buf, true) + if n < len(buf) { + return r.Match(buf[:n]) + } + // Buffer wasn't large enough for a full goroutine dump. + // Resize it and try again. + buf = make([]byte, 2*len(buf)) } - return perl } -func TestDirWindows(t *testing.T) { - if runtime.GOOS != "windows" { - t.Skip("Skipping windows specific test.") - } - - cgifile, _ := filepath.Abs("testdata/test.cgi") - - perl := findPerl(t) - +func TestDir(t *testing.T) { + testenv.MustHaveExec(t) cwd, _ := os.Getwd() h := &Handler{ - Path: perl, + Path: os.Args[0], Root: "/test.cgi", Dir: cwd, - Args: []string{cgifile}, - Env: []string{"SCRIPT_FILENAME=" + cgifile}, } expectedMap := map[string]string{ "cwd": cwd, } runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) - // If not specify Dir on windows, working directory should be - // base directory of perl. - cwd, _ = filepath.Split(perl) - if cwd != "" && cwd[len(cwd)-1] == filepath.Separator { - cwd = cwd[:len(cwd)-1] - } + cwd, _ = os.Getwd() + cwd, _ = filepath.Split(os.Args[0]) h = &Handler{ - Path: perl, + Path: os.Args[0], Root: "/test.cgi", - Args: []string{cgifile}, - Env: []string{"SCRIPT_FILENAME=" + cgifile}, } expectedMap = map[string]string{ "cwd": cwd, @@ -512,17 +456,14 @@ func TestDirWindows(t *testing.T) { } func TestEnvOverride(t *testing.T) { - check(t) + testenv.MustHaveExec(t) cgifile, _ := filepath.Abs("testdata/test.cgi") - perl := findPerl(t) - cwd, _ := os.Getwd() h := &Handler{ - Path: perl, + Path: os.Args[0], Root: "/test.cgi", Dir: cwd, - Args: []string{cgifile}, Env: []string{ "SCRIPT_FILENAME=" + cgifile, "REQUEST_URI=/foo/bar", @@ -538,10 +479,10 @@ func TestEnvOverride(t *testing.T) { } func TestHandlerStderr(t *testing.T) { - check(t) + testenv.MustHaveExec(t) var stderr strings.Builder h := &Handler{ - Path: "testdata/test.cgi", + Path: os.Args[0], Root: "/test.cgi", Stderr: &stderr, } diff --git a/src/net/http/cgi/integration_test.go b/src/net/http/cgi/integration_test.go index ef2eaf748b..68f908e2b2 100644 --- a/src/net/http/cgi/integration_test.go +++ b/src/net/http/cgi/integration_test.go @@ -20,7 +20,6 @@ import ( "os" "strings" "testing" - "time" ) // This test is a CGI host (testing host.go) that runs its own binary @@ -31,7 +30,6 @@ func TestHostingOurselves(t *testing.T) { h := &Handler{ Path: os.Args[0], Root: "/test.go", - Args: []string{"-test.run=TestBeChildCGIProcess"}, } expectedMap := map[string]string{ "test": "Hello CGI-in-CGI", @@ -98,9 +96,8 @@ func TestKillChildAfterCopyError(t *testing.T) { h := &Handler{ Path: os.Args[0], Root: "/test.go", - Args: []string{"-test.run=TestBeChildCGIProcess"}, } - req, _ := http.NewRequest("GET", "http://example.com/test.cgi?write-forever=1", nil) + req, _ := http.NewRequest("GET", "http://example.com/test.go?write-forever=1", nil) rec := httptest.NewRecorder() var out bytes.Buffer const writeLen = 50 << 10 @@ -120,7 +117,6 @@ func TestChildOnlyHeaders(t *testing.T) { h := &Handler{ Path: os.Args[0], Root: "/test.go", - Args: []string{"-test.run=TestBeChildCGIProcess"}, } expectedMap := map[string]string{ "_body": "", @@ -139,7 +135,6 @@ func TestNilRequestBody(t *testing.T) { h := &Handler{ Path: os.Args[0], Root: "/test.go", - Args: []string{"-test.run=TestBeChildCGIProcess"}, } expectedMap := map[string]string{ "nil-request-body": "false", @@ -154,7 +149,6 @@ func TestChildContentType(t *testing.T) { h := &Handler{ Path: os.Args[0], Root: "/test.go", - Args: []string{"-test.run=TestBeChildCGIProcess"}, } var tests = []struct { name string @@ -202,7 +196,6 @@ func want500Test(t *testing.T, path string) { h := &Handler{ Path: os.Args[0], Root: "/test.go", - Args: []string{"-test.run=TestBeChildCGIProcess"}, } expectedMap := map[string]string{ "_body": "", @@ -212,61 +205,3 @@ func want500Test(t *testing.T, path string) { t.Errorf("Got code %d; want 500", replay.Code) } } - -type neverEnding byte - -func (b neverEnding) Read(p []byte) (n int, err error) { - for i := range p { - p[i] = byte(b) - } - return len(p), nil -} - -// Note: not actually a test. -func TestBeChildCGIProcess(t *testing.T) { - if os.Getenv("REQUEST_METHOD") == "" { - // Not in a CGI environment; skipping test. - return - } - switch os.Getenv("REQUEST_URI") { - case "/immediate-disconnect": - os.Exit(0) - case "/no-content-type": - fmt.Printf("Content-Length: 6\n\nHello\n") - os.Exit(0) - case "/empty-headers": - fmt.Printf("\nHello") - os.Exit(0) - } - Serve(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - if req.FormValue("nil-request-body") == "1" { - fmt.Fprintf(rw, "nil-request-body=%v\n", req.Body == nil) - return - } - rw.Header().Set("X-Test-Header", "X-Test-Value") - req.ParseForm() - if req.FormValue("no-body") == "1" { - return - } - if eb, ok := req.Form["exact-body"]; ok { - io.WriteString(rw, eb[0]) - return - } - if req.FormValue("write-forever") == "1" { - io.Copy(rw, neverEnding('a')) - for { - time.Sleep(5 * time.Second) // hang forever, until killed - } - } - fmt.Fprintf(rw, "test=Hello CGI-in-CGI\n") - for k, vv := range req.Form { - for _, v := range vv { - fmt.Fprintf(rw, "param-%s=%s\n", k, v) - } - } - for _, kv := range os.Environ() { - fmt.Fprintf(rw, "env-%s\n", kv) - } - })) - os.Exit(0) -} diff --git a/src/net/http/cgi/plan9_test.go b/src/net/http/cgi/plan9_test.go deleted file mode 100644 index b7ace3f81c..0000000000 --- a/src/net/http/cgi/plan9_test.go +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright 2013 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build plan9 - -package cgi - -import ( - "os" - "strconv" -) - -func isProcessRunning(pid int) bool { - _, err := os.Stat("/proc/" + strconv.Itoa(pid)) - return err == nil -} diff --git a/src/net/http/cgi/posix_test.go b/src/net/http/cgi/posix_test.go deleted file mode 100644 index 49b9470d4a..0000000000 --- a/src/net/http/cgi/posix_test.go +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright 2013 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build !plan9 - -package cgi - -import ( - "os" - "syscall" -) - -func isProcessRunning(pid int) bool { - p, err := os.FindProcess(pid) - if err != nil { - return false - } - return p.Signal(syscall.Signal(0)) == nil -} diff --git a/src/net/http/cgi/testdata/test.cgi b/src/net/http/cgi/testdata/test.cgi deleted file mode 100755 index 667fce217e..0000000000 --- a/src/net/http/cgi/testdata/test.cgi +++ /dev/null @@ -1,95 +0,0 @@ -#!/usr/bin/perl -# Copyright 2011 The Go Authors. All rights reserved. -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file. -# -# Test script run as a child process under cgi_test.go - -use strict; -use Cwd; - -binmode STDOUT; - -my $q = MiniCGI->new; -my $params = $q->Vars; - -if ($params->{"loc"}) { - print "Location: $params->{loc}\r\n\r\n"; - exit(0); -} - -print "Content-Type: text/html\r\n"; -print "X-CGI-Pid: $$\r\n"; -print "X-Test-Header: X-Test-Value\r\n"; -print "\r\n"; - -if ($params->{"writestderr"}) { - print STDERR "Hello, stderr!\n"; -} - -if ($params->{"bigresponse"}) { - # 17 MB, for OS X: golang.org/issue/4958 - for (1..(17 * 1024)) { - print "A" x 1024, "\r\n"; - } - exit 0; -} - -print "test=Hello CGI\r\n"; - -foreach my $k (sort keys %$params) { - print "param-$k=$params->{$k}\r\n"; -} - -foreach my $k (sort keys %ENV) { - my $clean_env = $ENV{$k}; - $clean_env =~ s/[\n\r]//g; - print "env-$k=$clean_env\r\n"; -} - -# NOTE: msys perl returns /c/go/src/... not C:\go\.... -my $dir = getcwd(); -if ($^O eq 'MSWin32' || $^O eq 'msys' || $^O eq 'cygwin') { - if ($dir =~ /^.:/) { - $dir =~ s!/!\\!g; - } else { - my $cmd = $ENV{'COMSPEC'} || 'c:\\windows\\system32\\cmd.exe'; - $cmd =~ s!\\!/!g; - $dir = `$cmd /c cd`; - chomp $dir; - } -} -print "cwd=$dir\r\n"; - -# A minimal version of CGI.pm, for people without the perl-modules -# package installed. (CGI.pm used to be part of the Perl core, but -# some distros now bundle perl-base and perl-modules separately...) -package MiniCGI; - -sub new { - my $class = shift; - return bless {}, $class; -} - -sub Vars { - my $self = shift; - my $pairs; - if ($ENV{CONTENT_LENGTH}) { - $pairs = do { local $/; <STDIN> }; - } else { - $pairs = $ENV{QUERY_STRING}; - } - my $vars = {}; - foreach my $kv (split(/&/, $pairs)) { - my ($k, $v) = split(/=/, $kv, 2); - $vars->{_urldecode($k)} = _urldecode($v); - } - return $vars; -} - -sub _urldecode { - my $v = shift; - $v =~ tr/+/ /; - $v =~ s/%([a-fA-F0-9][a-fA-F0-9])/pack("C", hex($1))/eg; - return $v; -} diff --git a/src/net/http/client.go b/src/net/http/client.go index 2cab53a585..8fc348fe5d 100644 --- a/src/net/http/client.go +++ b/src/net/http/client.go @@ -27,34 +27,33 @@ import ( "time" ) -// A Client is an HTTP client. Its zero value (DefaultClient) is a -// usable client that uses DefaultTransport. +// A Client is an HTTP client. Its zero value ([DefaultClient]) is a +// usable client that uses [DefaultTransport]. // -// The Client's Transport typically has internal state (cached TCP +// The [Client.Transport] typically has internal state (cached TCP // connections), so Clients should be reused instead of created as // needed. Clients are safe for concurrent use by multiple goroutines. // -// A Client is higher-level than a RoundTripper (such as Transport) +// A Client is higher-level than a [RoundTripper] (such as [Transport]) // and additionally handles HTTP details such as cookies and // redirects. // // When following redirects, the Client will forward all headers set on the -// initial Request except: +// initial [Request] except: // -// • when forwarding sensitive headers like "Authorization", -// "WWW-Authenticate", and "Cookie" to untrusted targets. -// These headers will be ignored when following a redirect to a domain -// that is not a subdomain match or exact match of the initial domain. -// For example, a redirect from "foo.com" to either "foo.com" or "sub.foo.com" -// will forward the sensitive headers, but a redirect to "bar.com" will not. -// -// • when forwarding the "Cookie" header with a non-nil cookie Jar. -// Since each redirect may mutate the state of the cookie jar, -// a redirect may possibly alter a cookie set in the initial request. -// When forwarding the "Cookie" header, any mutated cookies will be omitted, -// with the expectation that the Jar will insert those mutated cookies -// with the updated values (assuming the origin matches). -// If Jar is nil, the initial cookies are forwarded without change. +// - when forwarding sensitive headers like "Authorization", +// "WWW-Authenticate", and "Cookie" to untrusted targets. +// These headers will be ignored when following a redirect to a domain +// that is not a subdomain match or exact match of the initial domain. +// For example, a redirect from "foo.com" to either "foo.com" or "sub.foo.com" +// will forward the sensitive headers, but a redirect to "bar.com" will not. +// - when forwarding the "Cookie" header with a non-nil cookie Jar. +// Since each redirect may mutate the state of the cookie jar, +// a redirect may possibly alter a cookie set in the initial request. +// When forwarding the "Cookie" header, any mutated cookies will be omitted, +// with the expectation that the Jar will insert those mutated cookies +// with the updated values (assuming the origin matches). +// If Jar is nil, the initial cookies are forwarded without change. type Client struct { // Transport specifies the mechanism by which individual // HTTP requests are made. @@ -106,11 +105,11 @@ type Client struct { Timeout time.Duration } -// DefaultClient is the default Client and is used by Get, Head, and Post. +// DefaultClient is the default [Client] and is used by [Get], [Head], and [Post]. var DefaultClient = &Client{} // RoundTripper is an interface representing the ability to execute a -// single HTTP transaction, obtaining the Response for a given Request. +// single HTTP transaction, obtaining the [Response] for a given [Request]. // // A RoundTripper must be safe for concurrent use by multiple // goroutines. @@ -440,7 +439,7 @@ func basicAuth(username, password string) string { // // An error is returned if there were too many redirects or if there // was an HTTP protocol error. A non-2xx response doesn't cause an -// error. Any returned error will be of type *url.Error. The url.Error +// error. Any returned error will be of type [*url.Error]. The url.Error // value's Timeout method will report true if the request timed out. // // When err is nil, resp always contains a non-nil resp.Body. @@ -448,10 +447,10 @@ func basicAuth(username, password string) string { // // Get is a wrapper around DefaultClient.Get. // -// To make a request with custom headers, use NewRequest and +// To make a request with custom headers, use [NewRequest] and // DefaultClient.Do. // -// To make a request with a specified context.Context, use NewRequestWithContext +// To make a request with a specified context.Context, use [NewRequestWithContext] // and DefaultClient.Do. func Get(url string) (resp *Response, err error) { return DefaultClient.Get(url) @@ -459,7 +458,7 @@ func Get(url string) (resp *Response, err error) { // Get issues a GET to the specified URL. If the response is one of the // following redirect codes, Get follows the redirect after calling the -// Client's CheckRedirect function: +// [Client.CheckRedirect] function: // // 301 (Moved Permanently) // 302 (Found) @@ -467,18 +466,18 @@ func Get(url string) (resp *Response, err error) { // 307 (Temporary Redirect) // 308 (Permanent Redirect) // -// An error is returned if the Client's CheckRedirect function fails +// An error is returned if the [Client.CheckRedirect] function fails // or if there was an HTTP protocol error. A non-2xx response doesn't -// cause an error. Any returned error will be of type *url.Error. The +// cause an error. Any returned error will be of type [*url.Error]. The // url.Error value's Timeout method will report true if the request // timed out. // // When err is nil, resp always contains a non-nil resp.Body. // Caller should close resp.Body when done reading from it. // -// To make a request with custom headers, use NewRequest and Client.Do. +// To make a request with custom headers, use [NewRequest] and [Client.Do]. // -// To make a request with a specified context.Context, use NewRequestWithContext +// To make a request with a specified context.Context, use [NewRequestWithContext] // and Client.Do. func (c *Client) Get(url string) (resp *Response, err error) { req, err := NewRequest("GET", url, nil) @@ -559,20 +558,21 @@ func urlErrorOp(method string) string { // connectivity problem). A non-2xx status code doesn't cause an // error. // -// If the returned error is nil, the Response will contain a non-nil +// If the returned error is nil, the [Response] will contain a non-nil // Body which the user is expected to close. If the Body is not both -// read to EOF and closed, the Client's underlying RoundTripper -// (typically Transport) may not be able to re-use a persistent TCP +// read to EOF and closed, the [Client]'s underlying [RoundTripper] +// (typically [Transport]) may not be able to re-use a persistent TCP // connection to the server for a subsequent "keep-alive" request. // // The request Body, if non-nil, will be closed by the underlying -// Transport, even on errors. +// Transport, even on errors. The Body may be closed asynchronously after +// Do returns. // // On error, any Response can be ignored. A non-nil Response with a // non-nil error only occurs when CheckRedirect fails, and even then -// the returned Response.Body is already closed. +// the returned [Response.Body] is already closed. // -// Generally Get, Post, or PostForm will be used instead of Do. +// Generally [Get], [Post], or [PostForm] will be used instead of Do. // // If the server replies with a redirect, the Client first uses the // CheckRedirect function to determine whether the redirect should be @@ -580,11 +580,11 @@ func urlErrorOp(method string) string { // subsequent requests to use HTTP method GET // (or HEAD if the original request was HEAD), with no body. // A 307 or 308 redirect preserves the original HTTP method and body, -// provided that the Request.GetBody function is defined. -// The NewRequest function automatically sets GetBody for common +// provided that the [Request.GetBody] function is defined. +// The [NewRequest] function automatically sets GetBody for common // standard library body types. // -// Any returned error will be of type *url.Error. The url.Error +// Any returned error will be of type [*url.Error]. The url.Error // value's Timeout method will report true if the request timed out. func (c *Client) Do(req *Request) (*Response, error) { return c.do(req) @@ -818,17 +818,17 @@ func defaultCheckRedirect(req *Request, via []*Request) error { // // Caller should close resp.Body when done reading from it. // -// If the provided body is an io.Closer, it is closed after the +// If the provided body is an [io.Closer], it is closed after the // request. // // Post is a wrapper around DefaultClient.Post. // -// To set custom headers, use NewRequest and DefaultClient.Do. +// To set custom headers, use [NewRequest] and DefaultClient.Do. // -// See the Client.Do method documentation for details on how redirects +// See the [Client.Do] method documentation for details on how redirects // are handled. // -// To make a request with a specified context.Context, use NewRequestWithContext +// To make a request with a specified context.Context, use [NewRequestWithContext] // and DefaultClient.Do. func Post(url, contentType string, body io.Reader) (resp *Response, err error) { return DefaultClient.Post(url, contentType, body) @@ -838,13 +838,13 @@ func Post(url, contentType string, body io.Reader) (resp *Response, err error) { // // Caller should close resp.Body when done reading from it. // -// If the provided body is an io.Closer, it is closed after the +// If the provided body is an [io.Closer], it is closed after the // request. // -// To set custom headers, use NewRequest and Client.Do. +// To set custom headers, use [NewRequest] and [Client.Do]. // -// To make a request with a specified context.Context, use NewRequestWithContext -// and Client.Do. +// To make a request with a specified context.Context, use [NewRequestWithContext] +// and [Client.Do]. // // See the Client.Do method documentation for details on how redirects // are handled. @@ -861,17 +861,17 @@ func (c *Client) Post(url, contentType string, body io.Reader) (resp *Response, // values URL-encoded as the request body. // // The Content-Type header is set to application/x-www-form-urlencoded. -// To set other headers, use NewRequest and DefaultClient.Do. +// To set other headers, use [NewRequest] and DefaultClient.Do. // // When err is nil, resp always contains a non-nil resp.Body. // Caller should close resp.Body when done reading from it. // // PostForm is a wrapper around DefaultClient.PostForm. // -// See the Client.Do method documentation for details on how redirects +// See the [Client.Do] method documentation for details on how redirects // are handled. // -// To make a request with a specified context.Context, use NewRequestWithContext +// To make a request with a specified [context.Context], use [NewRequestWithContext] // and DefaultClient.Do. func PostForm(url string, data url.Values) (resp *Response, err error) { return DefaultClient.PostForm(url, data) @@ -881,7 +881,7 @@ func PostForm(url string, data url.Values) (resp *Response, err error) { // with data's keys and values URL-encoded as the request body. // // The Content-Type header is set to application/x-www-form-urlencoded. -// To set other headers, use NewRequest and Client.Do. +// To set other headers, use [NewRequest] and [Client.Do]. // // When err is nil, resp always contains a non-nil resp.Body. // Caller should close resp.Body when done reading from it. @@ -889,7 +889,7 @@ func PostForm(url string, data url.Values) (resp *Response, err error) { // See the Client.Do method documentation for details on how redirects // are handled. // -// To make a request with a specified context.Context, use NewRequestWithContext +// To make a request with a specified context.Context, use [NewRequestWithContext] // and Client.Do. func (c *Client) PostForm(url string, data url.Values) (resp *Response, err error) { return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) @@ -907,7 +907,7 @@ func (c *Client) PostForm(url string, data url.Values) (resp *Response, err erro // // Head is a wrapper around DefaultClient.Head. // -// To make a request with a specified context.Context, use NewRequestWithContext +// To make a request with a specified [context.Context], use [NewRequestWithContext] // and DefaultClient.Do. func Head(url string) (resp *Response, err error) { return DefaultClient.Head(url) @@ -915,7 +915,7 @@ func Head(url string) (resp *Response, err error) { // Head issues a HEAD to the specified URL. If the response is one of the // following redirect codes, Head follows the redirect after calling the -// Client's CheckRedirect function: +// [Client.CheckRedirect] function: // // 301 (Moved Permanently) // 302 (Found) @@ -923,8 +923,8 @@ func Head(url string) (resp *Response, err error) { // 307 (Temporary Redirect) // 308 (Permanent Redirect) // -// To make a request with a specified context.Context, use NewRequestWithContext -// and Client.Do. +// To make a request with a specified [context.Context], use [NewRequestWithContext] +// and [Client.Do]. func (c *Client) Head(url string) (resp *Response, err error) { req, err := NewRequest("HEAD", url, nil) if err != nil { @@ -933,12 +933,12 @@ func (c *Client) Head(url string) (resp *Response, err error) { return c.Do(req) } -// CloseIdleConnections closes any connections on its Transport which +// CloseIdleConnections closes any connections on its [Transport] which // were previously connected from previous requests but are now // sitting idle in a "keep-alive" state. It does not interrupt any // connections currently in use. // -// If the Client's Transport does not have a CloseIdleConnections method +// If [Client.Transport] does not have a [Client.CloseIdleConnections] method // then this method does nothing. func (c *Client) CloseIdleConnections() { type closeIdler interface { @@ -1014,6 +1014,12 @@ func isDomainOrSubdomain(sub, parent string) bool { if sub == parent { return true } + // If sub contains a :, it's probably an IPv6 address (and is definitely not a hostname). + // Don't check the suffix in this case, to avoid matching the contents of a IPv6 zone. + // For example, "::1%.www.example.com" is not a subdomain of "www.example.com". + if strings.ContainsAny(sub, ":%") { + return false + } // If sub is "foo.example.com" and parent is "example.com", // that means sub must end in "."+parent. // Do it without allocating. diff --git a/src/net/http/client_test.go b/src/net/http/client_test.go index 0fe555af38..e2a1cbbdea 100644 --- a/src/net/http/client_test.go +++ b/src/net/http/client_test.go @@ -60,13 +60,6 @@ func pedanticReadAll(r io.Reader) (b []byte, err error) { } } -type chanWriter chan string - -func (w chanWriter) Write(p []byte) (n int, err error) { - w <- string(p) - return len(p), nil -} - func TestClient(t *testing.T) { run(t, testClient) } func testClient(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, robotsTxtHandler).ts @@ -827,12 +820,12 @@ func TestClientInsecureTransport(t *testing.T) { run(t, testClientInsecureTransport, []testMode{https1Mode, http2Mode}) } func testClientInsecureTransport(t *testing.T, mode testMode) { - ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello")) - })).ts - errc := make(chanWriter, 10) // but only expecting 1 - ts.Config.ErrorLog = log.New(errc, "", 0) - defer ts.Close() + })) + ts := cst.ts + errLog := new(strings.Builder) + ts.Config.ErrorLog = log.New(errLog, "", 0) // TODO(bradfitz): add tests for skipping hostname checks too? // would require a new cert for testing, and probably @@ -851,15 +844,10 @@ func testClientInsecureTransport(t *testing.T, mode testMode) { } } - select { - case v := <-errc: - if !strings.Contains(v, "TLS handshake error") { - t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v) - } - case <-time.After(5 * time.Second): - t.Errorf("timeout waiting for logged error") + cst.close() + if !strings.Contains(errLog.String(), "TLS handshake error") { + t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", errLog) } - } func TestClientErrorWithRequestURI(t *testing.T) { @@ -897,9 +885,10 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) { run(t, testClientWithIncorrectTLSServerName, []testMode{https1Mode, http2Mode}) } func testClientWithIncorrectTLSServerName(t *testing.T, mode testMode) { - ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts - errc := make(chanWriter, 10) // but only expecting 1 - ts.Config.ErrorLog = log.New(errc, "", 0) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})) + ts := cst.ts + errLog := new(strings.Builder) + ts.Config.ErrorLog = log.New(errLog, "", 0) c := ts.Client() c.Transport.(*Transport).TLSClientConfig.ServerName = "badserver" @@ -910,13 +899,10 @@ func testClientWithIncorrectTLSServerName(t *testing.T, mode testMode) { if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") { t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err) } - select { - case v := <-errc: - if !strings.Contains(v, "TLS handshake error") { - t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v) - } - case <-time.After(5 * time.Second): - t.Errorf("timeout waiting for logged error") + + cst.close() + if !strings.Contains(errLog.String(), "TLS handshake error") { + t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", errLog) } } @@ -960,7 +946,7 @@ func testResponseSetsTLSConnectionState(t *testing.T, mode testMode) { c := ts.Client() tr := c.Transport.(*Transport) - tr.TLSClientConfig.CipherSuites = []uint16{tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA} + tr.TLSClientConfig.CipherSuites = []uint16{tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA} tr.TLSClientConfig.MaxVersion = tls.VersionTLS12 // to get to pick the cipher suite tr.Dial = func(netw, addr string) (net.Conn, error) { return net.Dial(netw, ts.Listener.Addr().String()) @@ -973,7 +959,7 @@ func testResponseSetsTLSConnectionState(t *testing.T, mode testMode) { if res.TLS == nil { t.Fatal("Response didn't set TLS Connection State.") } - if got, want := res.TLS.CipherSuite, tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA; got != want { + if got, want := res.TLS.CipherSuite, tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA; got != want { t.Errorf("TLS Cipher Suite = %d; want %d", got, want) } } @@ -1725,6 +1711,7 @@ func TestShouldCopyHeaderOnRedirect(t *testing.T) { {"authorization", "http://foo.com/", "https://foo.com/", true}, {"authorization", "http://foo.com:1234/", "http://foo.com:4321/", true}, {"www-authenticate", "http://foo.com/", "http://bar.com/", false}, + {"authorization", "http://foo.com/", "http://[::1%25.foo.com]/", false}, // But subdomains should work: {"www-authenticate", "http://foo.com/", "http://foo.com/", true}, diff --git a/src/net/http/clientserver_test.go b/src/net/http/clientserver_test.go index 58321532ea..32948f3aed 100644 --- a/src/net/http/clientserver_test.go +++ b/src/net/http/clientserver_test.go @@ -1172,16 +1172,12 @@ func testTransportGCRequest(t *testing.T, mode testMode, body bool) { t.Fatal(err) } })() - timeout := time.NewTimer(5 * time.Second) - defer timeout.Stop() for { select { case <-didGC: return - case <-time.After(100 * time.Millisecond): + case <-time.After(1 * time.Millisecond): runtime.GC() - case <-timeout.C: - t.Fatal("never saw GC of request") } } } diff --git a/src/net/http/cookie.go b/src/net/http/cookie.go index 912fde6b95..c22897f3f9 100644 --- a/src/net/http/cookie.go +++ b/src/net/http/cookie.go @@ -163,7 +163,7 @@ func readSetCookies(h Header) []*Cookie { return cookies } -// SetCookie adds a Set-Cookie header to the provided ResponseWriter's headers. +// SetCookie adds a Set-Cookie header to the provided [ResponseWriter]'s headers. // The provided cookie must have a valid Name. Invalid cookies may be // silently dropped. func SetCookie(w ResponseWriter, cookie *Cookie) { @@ -172,7 +172,7 @@ func SetCookie(w ResponseWriter, cookie *Cookie) { } } -// String returns the serialization of the cookie for use in a Cookie +// String returns the serialization of the cookie for use in a [Cookie] // header (if only Name and Value are set) or a Set-Cookie response // header (if other fields are set). // If c is nil or c.Name is invalid, the empty string is returned. diff --git a/src/net/http/cookiejar/jar.go b/src/net/http/cookiejar/jar.go index 273b54c84c..e7f5ddd4d0 100644 --- a/src/net/http/cookiejar/jar.go +++ b/src/net/http/cookiejar/jar.go @@ -73,7 +73,7 @@ type Jar struct { nextSeqNum uint64 } -// New returns a new cookie jar. A nil *Options is equivalent to a zero +// New returns a new cookie jar. A nil [*Options] is equivalent to a zero // Options. func New(o *Options) (*Jar, error) { jar := &Jar{ @@ -151,7 +151,7 @@ func hasDotSuffix(s, suffix string) bool { return len(s) > len(suffix) && s[len(s)-len(suffix)-1] == '.' && s[len(s)-len(suffix):] == suffix } -// Cookies implements the Cookies method of the http.CookieJar interface. +// Cookies implements the Cookies method of the [http.CookieJar] interface. // // It returns an empty slice if the URL's scheme is not HTTP or HTTPS. func (j *Jar) Cookies(u *url.URL) (cookies []*http.Cookie) { @@ -226,7 +226,7 @@ func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) { return cookies } -// SetCookies implements the SetCookies method of the http.CookieJar interface. +// SetCookies implements the SetCookies method of the [http.CookieJar] interface. // // It does nothing if the URL's scheme is not HTTP or HTTPS. func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) { @@ -362,6 +362,13 @@ func jarKey(host string, psl PublicSuffixList) string { // isIP reports whether host is an IP address. func isIP(host string) bool { + if strings.ContainsAny(host, ":%") { + // Probable IPv6 address. + // Hostnames can't contain : or %, so this is definitely not a valid host. + // Treating it as an IP is the more conservative option, and avoids the risk + // of interpeting ::1%.www.example.com as a subtomain of www.example.com. + return true + } return net.ParseIP(host) != nil } @@ -440,7 +447,6 @@ func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e e var ( errIllegalDomain = errors.New("cookiejar: illegal cookie domain attribute") errMalformedDomain = errors.New("cookiejar: malformed cookie domain attribute") - errNoHostname = errors.New("cookiejar: no host name available (IP only)") ) // endOfTime is the time when session (non-persistent) cookies expire. diff --git a/src/net/http/cookiejar/jar_test.go b/src/net/http/cookiejar/jar_test.go index 56d0695a66..251f7c1617 100644 --- a/src/net/http/cookiejar/jar_test.go +++ b/src/net/http/cookiejar/jar_test.go @@ -252,6 +252,7 @@ var isIPTests = map[string]bool{ "127.0.0.1": true, "1.2.3.4": true, "2001:4860:0:2001::68": true, + "::1%zone": true, "example.com": false, "1.1.1.300": false, "www.foo.bar.net": false, @@ -629,6 +630,15 @@ var basicsTests = [...]jarTest{ {"http://www.host.test:1234/", "a=1"}, }, }, + { + "IPv6 zone is not treated as a host.", + "https://example.com/", + []string{"a=1"}, + "a=1", + []query{ + {"https://[::1%25.example.com]:80/", ""}, + }, + }, } func TestBasics(t *testing.T) { diff --git a/src/net/http/doc.go b/src/net/http/doc.go index d9e6aafb4e..f7ad3ae762 100644 --- a/src/net/http/doc.go +++ b/src/net/http/doc.go @@ -5,7 +5,7 @@ /* Package http provides HTTP client and server implementations. -Get, Head, Post, and PostForm make HTTP (or HTTPS) requests: +[Get], [Head], [Post], and [PostForm] make HTTP (or HTTPS) requests: resp, err := http.Get("http://example.com/") ... @@ -27,7 +27,7 @@ The caller must close the response body when finished with it: # Clients and Transports For control over HTTP client headers, redirect policy, and other -settings, create a Client: +settings, create a [Client]: client := &http.Client{ CheckRedirect: redirectPolicyFunc, @@ -43,7 +43,7 @@ settings, create a Client: // ... For control over proxies, TLS configuration, keep-alives, -compression, and other settings, create a Transport: +compression, and other settings, create a [Transport]: tr := &http.Transport{ MaxIdleConns: 10, @@ -59,8 +59,8 @@ goroutines and for efficiency should only be created once and re-used. # Servers ListenAndServe starts an HTTP server with a given address and handler. -The handler is usually nil, which means to use DefaultServeMux. -Handle and HandleFunc add handlers to DefaultServeMux: +The handler is usually nil, which means to use [DefaultServeMux]. +[Handle] and [HandleFunc] add handlers to [DefaultServeMux]: http.Handle("/foo", fooHandler) @@ -86,8 +86,8 @@ custom Server: Starting with Go 1.6, the http package has transparent support for the HTTP/2 protocol when using HTTPS. Programs that must disable HTTP/2 -can do so by setting Transport.TLSNextProto (for clients) or -Server.TLSNextProto (for servers) to a non-nil, empty +can do so by setting [Transport.TLSNextProto] (for clients) or +[Server.TLSNextProto] (for servers) to a non-nil, empty map. Alternatively, the following GODEBUG settings are currently supported: @@ -98,7 +98,7 @@ currently supported: Please report any issues before disabling HTTP/2 support: https://golang.org/s/http2bug -The http package's Transport and Server both automatically enable +The http package's [Transport] and [Server] both automatically enable HTTP/2 support for simple configurations. To enable HTTP/2 for more complex configurations, to use lower-level HTTP/2 features, or to use a newer version of Go's http2 package, import "golang.org/x/net/http2" diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go index 5d198f3f89..7e6d3d8e30 100644 --- a/src/net/http/export_test.go +++ b/src/net/http/export_test.go @@ -315,3 +315,21 @@ func ResponseWriterConnForTesting(w ResponseWriter) (c net.Conn, ok bool) { } return nil, false } + +func init() { + // Set the default rstAvoidanceDelay to the minimum possible value to shake + // out tests that unexpectedly depend on it. Such tests should use + // runTimeSensitiveTest and SetRSTAvoidanceDelay to explicitly raise the delay + // if needed. + rstAvoidanceDelay = 1 * time.Nanosecond +} + +// SetRSTAvoidanceDelay sets how long we are willing to wait between calling +// CloseWrite on a connection and fully closing the connection. +func SetRSTAvoidanceDelay(t *testing.T, d time.Duration) { + prevDelay := rstAvoidanceDelay + t.Cleanup(func() { + rstAvoidanceDelay = prevDelay + }) + rstAvoidanceDelay = d +} diff --git a/src/net/http/fcgi/child.go b/src/net/http/fcgi/child.go index dc82bf7c3a..7665e7d252 100644 --- a/src/net/http/fcgi/child.go +++ b/src/net/http/fcgi/child.go @@ -335,7 +335,7 @@ func (c *child) cleanUp() { // goroutine for each. The goroutine reads requests and then calls handler // to reply to them. // If l is nil, Serve accepts connections from os.Stdin. -// If handler is nil, http.DefaultServeMux is used. +// If handler is nil, [http.DefaultServeMux] is used. func Serve(l net.Listener, handler http.Handler) error { if l == nil { var err error diff --git a/src/net/http/filetransport.go b/src/net/http/filetransport.go index 94684b07a1..7384b22fbe 100644 --- a/src/net/http/filetransport.go +++ b/src/net/http/filetransport.go @@ -7,6 +7,7 @@ package http import ( "fmt" "io" + "io/fs" ) // fileTransport implements RoundTripper for the 'file' protocol. @@ -14,13 +15,13 @@ type fileTransport struct { fh fileHandler } -// NewFileTransport returns a new RoundTripper, serving the provided -// FileSystem. The returned RoundTripper ignores the URL host in its +// NewFileTransport returns a new [RoundTripper], serving the provided +// [FileSystem]. The returned RoundTripper ignores the URL host in its // incoming requests, as well as most other properties of the // request. // // The typical use case for NewFileTransport is to register the "file" -// protocol with a Transport, as in: +// protocol with a [Transport], as in: // // t := &http.Transport{} // t.RegisterProtocol("file", http.NewFileTransport(http.Dir("/"))) @@ -31,6 +32,24 @@ func NewFileTransport(fs FileSystem) RoundTripper { return fileTransport{fileHandler{fs}} } +// NewFileTransportFS returns a new [RoundTripper], serving the provided +// file system fsys. The returned RoundTripper ignores the URL host in its +// incoming requests, as well as most other properties of the +// request. +// +// The typical use case for NewFileTransportFS is to register the "file" +// protocol with a [Transport], as in: +// +// fsys := os.DirFS("/") +// t := &http.Transport{} +// t.RegisterProtocol("file", http.NewFileTransportFS(fsys)) +// c := &http.Client{Transport: t} +// res, err := c.Get("file:///etc/passwd") +// ... +func NewFileTransportFS(fsys fs.FS) RoundTripper { + return NewFileTransport(FS(fsys)) +} + func (t fileTransport) RoundTrip(req *Request) (resp *Response, err error) { // We start ServeHTTP in a goroutine, which may take a long // time if the file is large. The newPopulateResponseWriter diff --git a/src/net/http/filetransport_test.go b/src/net/http/filetransport_test.go index 77fc8eeccf..b3e3301e10 100644 --- a/src/net/http/filetransport_test.go +++ b/src/net/http/filetransport_test.go @@ -9,6 +9,7 @@ import ( "os" "path/filepath" "testing" + "testing/fstest" ) func checker(t *testing.T) func(string, error) { @@ -62,3 +63,44 @@ func TestFileTransport(t *testing.T) { } res.Body.Close() } + +func TestFileTransportFS(t *testing.T) { + check := checker(t) + + fsys := fstest.MapFS{ + "index.html": {Data: []byte("index.html says hello")}, + } + + tr := &Transport{} + tr.RegisterProtocol("file", NewFileTransportFS(fsys)) + c := &Client{Transport: tr} + + for fname, mfile := range fsys { + urlstr := "file:///" + fname + res, err := c.Get(urlstr) + check("Get "+urlstr, err) + if res.StatusCode != 200 { + t.Errorf("for %s, StatusCode = %d, want 200", urlstr, res.StatusCode) + } + if res.ContentLength != -1 { + t.Errorf("for %s, ContentLength = %d, want -1", urlstr, res.ContentLength) + } + if res.Body == nil { + t.Fatalf("for %s, nil Body", urlstr) + } + slurp, err := io.ReadAll(res.Body) + res.Body.Close() + check("ReadAll "+urlstr, err) + if string(slurp) != string(mfile.Data) { + t.Errorf("for %s, got content %q, want %q", urlstr, string(slurp), "Bar") + } + } + + const badURL = "file://../no-exist.txt" + res, err := c.Get(badURL) + check("Get "+badURL, err) + if res.StatusCode != 404 { + t.Errorf("for %s, StatusCode = %d, want 404", badURL, res.StatusCode) + } + res.Body.Close() +} diff --git a/src/net/http/fs.go b/src/net/http/fs.go index 41e0b43ac8..af7511a7a4 100644 --- a/src/net/http/fs.go +++ b/src/net/http/fs.go @@ -25,12 +25,12 @@ import ( "time" ) -// A Dir implements FileSystem using the native file system restricted to a +// A Dir implements [FileSystem] using the native file system restricted to a // specific directory tree. // -// While the FileSystem.Open method takes '/'-separated paths, a Dir's string +// While the [FileSystem.Open] method takes '/'-separated paths, a Dir's string // value is a filename on the native file system, not a URL, so it is separated -// by filepath.Separator, which isn't necessarily '/'. +// by [filepath.Separator], which isn't necessarily '/'. // // Note that Dir could expose sensitive files and directories. Dir will follow // symlinks pointing out of the directory tree, which can be especially dangerous @@ -67,7 +67,7 @@ func mapOpenError(originalErr error, name string, sep rune, stat func(string) (f return originalErr } -// Open implements FileSystem using os.Open, opening files for reading rooted +// Open implements [FileSystem] using [os.Open], opening files for reading rooted // and relative to the directory d. func (d Dir) Open(name string) (File, error) { path, err := safefilepath.FromFS(path.Clean("/" + name)) @@ -89,18 +89,18 @@ func (d Dir) Open(name string) (File, error) { // A FileSystem implements access to a collection of named files. // The elements in a file path are separated by slash ('/', U+002F) // characters, regardless of host operating system convention. -// See the FileServer function to convert a FileSystem to a Handler. +// See the [FileServer] function to convert a FileSystem to a [Handler]. // -// This interface predates the fs.FS interface, which can be used instead: -// the FS adapter function converts an fs.FS to a FileSystem. +// This interface predates the [fs.FS] interface, which can be used instead: +// the [FS] adapter function converts an fs.FS to a FileSystem. type FileSystem interface { Open(name string) (File, error) } -// A File is returned by a FileSystem's Open method and can be -// served by the FileServer implementation. +// A File is returned by a [FileSystem]'s Open method and can be +// served by the [FileServer] implementation. // -// The methods should behave the same as those on an *os.File. +// The methods should behave the same as those on an [*os.File]. type File interface { io.Closer io.Reader @@ -167,7 +167,7 @@ func dirList(w ResponseWriter, r *Request, f File) { } // ServeContent replies to the request using the content in the -// provided ReadSeeker. The main benefit of ServeContent over io.Copy +// provided ReadSeeker. The main benefit of ServeContent over [io.Copy] // is that it handles Range requests properly, sets the MIME type, and // handles If-Match, If-Unmodified-Since, If-None-Match, If-Modified-Since, // and If-Range requests. @@ -175,7 +175,7 @@ func dirList(w ResponseWriter, r *Request, f File) { // If the response's Content-Type header is not set, ServeContent // first tries to deduce the type from name's file extension and, // if that fails, falls back to reading the first block of the content -// and passing it to DetectContentType. +// and passing it to [DetectContentType]. // The name is otherwise unused; in particular it can be empty and is // never sent in the response. // @@ -190,7 +190,7 @@ func dirList(w ResponseWriter, r *Request, f File) { // If the caller has set w's ETag header formatted per RFC 7232, section 2.3, // ServeContent uses it to handle requests using If-Match, If-None-Match, or If-Range. // -// Note that *os.File implements the io.ReadSeeker interface. +// Note that [*os.File] implements the [io.ReadSeeker] interface. func ServeContent(w ResponseWriter, req *Request, name string, modtime time.Time, content io.ReadSeeker) { sizeFunc := func() (int64, error) { size, err := content.Seek(0, io.SeekEnd) @@ -343,10 +343,35 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, } w.Header().Set("Accept-Ranges", "bytes") - if w.Header().Get("Content-Encoding") == "" { + + // We should be able to unconditionally set the Content-Length here. + // + // However, there is a pattern observed in the wild that this breaks: + // The user wraps the ResponseWriter in one which gzips data written to it, + // and sets "Content-Encoding: gzip". + // + // The user shouldn't be doing this; the serveContent path here depends + // on serving seekable data with a known length. If you want to compress + // on the fly, then you shouldn't be using ServeFile/ServeContent, or + // you should compress the entire file up-front and provide a seekable + // view of the compressed data. + // + // However, since we've observed this pattern in the wild, and since + // setting Content-Length here breaks code that mostly-works today, + // skip setting Content-Length if the user set Content-Encoding. + // + // If this is a range request, always set Content-Length. + // If the user isn't changing the bytes sent in the ResponseWrite, + // the Content-Length will be correct. + // If the user is changing the bytes sent, then the range request wasn't + // going to work properly anyway and we aren't worse off. + // + // A possible future improvement on this might be to look at the type + // of the ResponseWriter, and always set Content-Length if it's one + // that we recognize. + if len(ranges) > 0 || w.Header().Get("Content-Encoding") == "" { w.Header().Set("Content-Length", strconv.FormatInt(sendSize, 10)) } - w.WriteHeader(code) if r.Method != "HEAD" { @@ -716,13 +741,13 @@ func localRedirect(w ResponseWriter, r *Request, newPath string) { // // As a precaution, ServeFile will reject requests where r.URL.Path // contains a ".." path element; this protects against callers who -// might unsafely use filepath.Join on r.URL.Path without sanitizing +// might unsafely use [filepath.Join] on r.URL.Path without sanitizing // it and then use that filepath.Join result as the name argument. // // As another special case, ServeFile redirects any request where r.URL.Path // ends in "/index.html" to the same path, without the final // "index.html". To avoid such redirects either modify the path or -// use ServeContent. +// use [ServeContent]. // // Outside of those two special cases, ServeFile does not use // r.URL.Path for selecting the file or directory to serve; only the @@ -741,6 +766,40 @@ func ServeFile(w ResponseWriter, r *Request, name string) { serveFile(w, r, Dir(dir), file, false) } +// ServeFileFS replies to the request with the contents +// of the named file or directory from the file system fsys. +// +// If the provided file or directory name is a relative path, it is +// interpreted relative to the current directory and may ascend to +// parent directories. If the provided name is constructed from user +// input, it should be sanitized before calling [ServeFile]. +// +// As a precaution, ServeFile will reject requests where r.URL.Path +// contains a ".." path element; this protects against callers who +// might unsafely use [filepath.Join] on r.URL.Path without sanitizing +// it and then use that filepath.Join result as the name argument. +// +// As another special case, ServeFile redirects any request where r.URL.Path +// ends in "/index.html" to the same path, without the final +// "index.html". To avoid such redirects either modify the path or +// use ServeContent. +// +// Outside of those two special cases, ServeFile does not use +// r.URL.Path for selecting the file or directory to serve; only the +// file or directory provided in the name argument is used. +func ServeFileFS(w ResponseWriter, r *Request, fsys fs.FS, name string) { + if containsDotDot(r.URL.Path) { + // Too many programs use r.URL.Path to construct the argument to + // serveFile. Reject the request under the assumption that happened + // here and ".." may not be wanted. + // Note that name might not contain "..", for example if code (still + // incorrectly) used filepath.Join(myDir, r.URL.Path). + Error(w, "invalid URL path", StatusBadRequest) + return + } + serveFile(w, r, FS(fsys), name, false) +} + func containsDotDot(v string) bool { if !strings.Contains(v, "..") { return false @@ -831,9 +890,9 @@ func (f ioFile) Readdir(count int) ([]fs.FileInfo, error) { return list, nil } -// FS converts fsys to a FileSystem implementation, -// for use with FileServer and NewFileTransport. -// The files provided by fsys must implement io.Seeker. +// FS converts fsys to a [FileSystem] implementation, +// for use with [FileServer] and [NewFileTransport]. +// The files provided by fsys must implement [io.Seeker]. func FS(fsys fs.FS) FileSystem { return ioFS{fsys} } @@ -846,17 +905,27 @@ func FS(fsys fs.FS) FileSystem { // "index.html". // // To use the operating system's file system implementation, -// use http.Dir: +// use [http.Dir]: // // http.Handle("/", http.FileServer(http.Dir("/tmp"))) // -// To use an fs.FS implementation, use http.FS to convert it: -// -// http.Handle("/", http.FileServer(http.FS(fsys))) +// To use an [fs.FS] implementation, use [http.FileServerFS] instead. func FileServer(root FileSystem) Handler { return &fileHandler{root} } +// FileServerFS returns a handler that serves HTTP requests +// with the contents of the file system fsys. +// +// As a special case, the returned file server redirects any request +// ending in "/index.html" to the same path, without the final +// "index.html". +// +// http.Handle("/", http.FileServerFS(fsys)) +func FileServerFS(root fs.FS) Handler { + return FileServer(FS(root)) +} + func (f *fileHandler) ServeHTTP(w ResponseWriter, r *Request) { upath := r.URL.Path if !strings.HasPrefix(upath, "/") { diff --git a/src/net/http/fs_test.go b/src/net/http/fs_test.go index 3fb9e01235..861e70caf2 100644 --- a/src/net/http/fs_test.go +++ b/src/net/http/fs_test.go @@ -7,13 +7,16 @@ package http_test import ( "bufio" "bytes" + "compress/gzip" "errors" "fmt" + "internal/testenv" "io" "io/fs" "mime" "mime/multipart" "net" + "net/http" . "net/http" "net/http/httptest" "net/url" @@ -26,6 +29,7 @@ import ( "runtime" "strings" "testing" + "testing/fstest" "time" ) @@ -1265,7 +1269,7 @@ func TestLinuxSendfile(t *testing.T) { defer ln.Close() // Attempt to run strace, and skip on failure - this test requires SYS_PTRACE. - if err := exec.Command("strace", "-f", "-q", os.Args[0], "-test.run=^$").Run(); err != nil { + if err := testenv.Command(t, "strace", "-f", "-q", os.Args[0], "-test.run=^$").Run(); err != nil { t.Skipf("skipping; failed to run strace: %v", err) } @@ -1278,7 +1282,7 @@ func TestLinuxSendfile(t *testing.T) { defer os.Remove(filepath) var buf strings.Builder - child := exec.Command("strace", "-f", "-q", os.Args[0], "-test.run=TestLinuxSendfileChild") + child := testenv.Command(t, "strace", "-f", "-q", os.Args[0], "-test.run=^TestLinuxSendfileChild$") child.ExtraFiles = append(child.ExtraFiles, lnf) child.Env = append([]string{"GO_WANT_HELPER_PROCESS=1"}, os.Environ()...) child.Stdout = &buf @@ -1559,3 +1563,108 @@ func testFileServerMethods(t *testing.T, mode testMode) { } } } + +func TestFileServerFS(t *testing.T) { + filename := "index.html" + contents := []byte("index.html says hello") + fsys := fstest.MapFS{ + filename: {Data: contents}, + } + ts := newClientServerTest(t, http1Mode, FileServerFS(fsys)).ts + defer ts.Close() + + res, err := ts.Client().Get(ts.URL + "/" + filename) + if err != nil { + t.Fatal(err) + } + b, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal("reading Body:", err) + } + if s := string(b); s != string(contents) { + t.Errorf("for path %q got %q, want %q", filename, s, contents) + } + res.Body.Close() +} + +func TestServeFileFS(t *testing.T) { + filename := "index.html" + contents := []byte("index.html says hello") + fsys := fstest.MapFS{ + filename: {Data: contents}, + } + ts := newClientServerTest(t, http1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + ServeFileFS(w, r, fsys, filename) + })).ts + defer ts.Close() + + res, err := ts.Client().Get(ts.URL + "/" + filename) + if err != nil { + t.Fatal(err) + } + b, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal("reading Body:", err) + } + if s := string(b); s != string(contents) { + t.Errorf("for path %q got %q, want %q", filename, s, contents) + } + res.Body.Close() +} + +func TestServeFileZippingResponseWriter(t *testing.T) { + // This test exercises a pattern which is incorrect, + // but has been observed enough in the world that we don't want to break it. + // + // The server is setting "Content-Encoding: gzip", + // wrapping the ResponseWriter in an implementation which gzips data written to it, + // and passing this ResponseWriter to ServeFile. + // + // This means ServeFile cannot properly set a Content-Length header, because it + // doesn't know what content it is going to send--the ResponseWriter is modifying + // the bytes sent. + // + // Range requests are always going to be broken in this scenario, + // but verify that we can serve non-range requests correctly. + filename := "index.html" + contents := []byte("contents will be sent with Content-Encoding: gzip") + fsys := fstest.MapFS{ + filename: {Data: contents}, + } + ts := newClientServerTest(t, http1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Encoding", "gzip") + gzw := gzip.NewWriter(w) + defer gzw.Close() + ServeFileFS(gzipResponseWriter{w: gzw, ResponseWriter: w}, r, fsys, filename) + })).ts + defer ts.Close() + + res, err := ts.Client().Get(ts.URL + "/" + filename) + if err != nil { + t.Fatal(err) + } + b, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal("reading Body:", err) + } + if s := string(b); s != string(contents) { + t.Errorf("for path %q got %q, want %q", filename, s, contents) + } + res.Body.Close() +} + +type gzipResponseWriter struct { + ResponseWriter + w *gzip.Writer +} + +func (grw gzipResponseWriter) Write(b []byte) (int, error) { + return grw.w.Write(b) +} + +func (grw gzipResponseWriter) Flush() { + grw.w.Flush() + if fw, ok := grw.ResponseWriter.(http.Flusher); ok { + fw.Flush() + } +} diff --git a/src/net/http/h2_bundle.go b/src/net/http/h2_bundle.go index dd59e1f4f2..ac41144d5b 100644 --- a/src/net/http/h2_bundle.go +++ b/src/net/http/h2_bundle.go @@ -1,5 +1,4 @@ //go:build !nethttpomithttp2 -// +build !nethttpomithttp2 // Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT. // $ bundle -o=h2_bundle.go -prefix=http2 -tags=!nethttpomithttp2 golang.org/x/net/http2 @@ -33,6 +32,7 @@ import ( "io/fs" "log" "math" + "math/bits" mathrand "math/rand" "net" "net/http/httptrace" @@ -1041,41 +1041,44 @@ func http2shouldRetryDial(call *http2dialCall, req *Request) bool { // TODO: Benchmark to determine if the pools are necessary. The GC may have // improved enough that we can instead allocate chunks like this: // make([]byte, max(16<<10, expectedBytesRemaining)) -var ( - http2dataChunkSizeClasses = []int{ - 1 << 10, - 2 << 10, - 4 << 10, - 8 << 10, - 16 << 10, - } - http2dataChunkPools = [...]sync.Pool{ - {New: func() interface{} { return make([]byte, 1<<10) }}, - {New: func() interface{} { return make([]byte, 2<<10) }}, - {New: func() interface{} { return make([]byte, 4<<10) }}, - {New: func() interface{} { return make([]byte, 8<<10) }}, - {New: func() interface{} { return make([]byte, 16<<10) }}, - } -) +var http2dataChunkPools = [...]sync.Pool{ + {New: func() interface{} { return new([1 << 10]byte) }}, + {New: func() interface{} { return new([2 << 10]byte) }}, + {New: func() interface{} { return new([4 << 10]byte) }}, + {New: func() interface{} { return new([8 << 10]byte) }}, + {New: func() interface{} { return new([16 << 10]byte) }}, +} func http2getDataBufferChunk(size int64) []byte { - i := 0 - for ; i < len(http2dataChunkSizeClasses)-1; i++ { - if size <= int64(http2dataChunkSizeClasses[i]) { - break - } + switch { + case size <= 1<<10: + return http2dataChunkPools[0].Get().(*[1 << 10]byte)[:] + case size <= 2<<10: + return http2dataChunkPools[1].Get().(*[2 << 10]byte)[:] + case size <= 4<<10: + return http2dataChunkPools[2].Get().(*[4 << 10]byte)[:] + case size <= 8<<10: + return http2dataChunkPools[3].Get().(*[8 << 10]byte)[:] + default: + return http2dataChunkPools[4].Get().(*[16 << 10]byte)[:] } - return http2dataChunkPools[i].Get().([]byte) } func http2putDataBufferChunk(p []byte) { - for i, n := range http2dataChunkSizeClasses { - if len(p) == n { - http2dataChunkPools[i].Put(p) - return - } + switch len(p) { + case 1 << 10: + http2dataChunkPools[0].Put((*[1 << 10]byte)(p)) + case 2 << 10: + http2dataChunkPools[1].Put((*[2 << 10]byte)(p)) + case 4 << 10: + http2dataChunkPools[2].Put((*[4 << 10]byte)(p)) + case 8 << 10: + http2dataChunkPools[3].Put((*[8 << 10]byte)(p)) + case 16 << 10: + http2dataChunkPools[4].Put((*[16 << 10]byte)(p)) + default: + panic(fmt.Sprintf("unexpected buffer len=%v", len(p))) } - panic(fmt.Sprintf("unexpected buffer len=%v", len(p))) } // dataBuffer is an io.ReadWriter backed by a list of data chunks. @@ -3058,41 +3061,6 @@ func http2summarizeFrame(f http2Frame) string { return buf.String() } -func http2traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool { - return trace != nil && trace.WroteHeaderField != nil -} - -func http2traceWroteHeaderField(trace *httptrace.ClientTrace, k, v string) { - if trace != nil && trace.WroteHeaderField != nil { - trace.WroteHeaderField(k, []string{v}) - } -} - -func http2traceGot1xxResponseFunc(trace *httptrace.ClientTrace) func(int, textproto.MIMEHeader) error { - if trace != nil { - return trace.Got1xxResponse - } - return nil -} - -// dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS -// connection. -func (t *http2Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (*tls.Conn, error) { - dialer := &tls.Dialer{ - Config: cfg, - } - cn, err := dialer.DialContext(ctx, network, addr) - if err != nil { - return nil, err - } - tlsCn := cn.(*tls.Conn) // DialContext comment promises this will always succeed - return tlsCn, nil -} - -func http2tlsUnderlyingConn(tc *tls.Conn) net.Conn { - return tc.NetConn() -} - var http2DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1" type http2goroutineLock uint64 @@ -4831,14 +4799,6 @@ func (sc *http2serverConn) serve() { } } -func (sc *http2serverConn) awaitGracefulShutdown(sharedCh <-chan struct{}, privateCh chan struct{}) { - select { - case <-sc.doneServing: - case <-sharedCh: - close(privateCh) - } -} - type http2serverMessage int // Message values sent to serveMsgCh. @@ -5722,9 +5682,11 @@ func (st *http2stream) copyTrailersToHandlerRequest() { // onReadTimeout is run on its own goroutine (from time.AfterFunc) // when the stream's ReadTimeout has fired. func (st *http2stream) onReadTimeout() { - // Wrap the ErrDeadlineExceeded to avoid callers depending on us - // returning the bare error. - st.body.CloseWithError(fmt.Errorf("%w", os.ErrDeadlineExceeded)) + if st.body != nil { + // Wrap the ErrDeadlineExceeded to avoid callers depending on us + // returning the bare error. + st.body.CloseWithError(fmt.Errorf("%w", os.ErrDeadlineExceeded)) + } } // onWriteTimeout is run on its own goroutine (from time.AfterFunc) @@ -5842,9 +5804,7 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { // (in Go 1.8), though. That's a more sane option anyway. if sc.hs.ReadTimeout != 0 { sc.conn.SetReadDeadline(time.Time{}) - if st.body != nil { - st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout) - } + st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout) } return sc.scheduleHandler(id, rw, req, handler) @@ -6374,7 +6334,6 @@ type http2responseWriterState struct { wroteHeader bool // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet. sentHeader bool // have we sent the header frame? handlerDone bool // handler has finished - dirty bool // a Write failed; don't reuse this responseWriterState sentContentLen int64 // non-zero if handler set a Content-Length header wroteBytes int64 @@ -6494,7 +6453,6 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { date: date, }) if err != nil { - rws.dirty = true return 0, err } if endStream { @@ -6515,7 +6473,6 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { if len(p) > 0 || endStream { // only send a 0 byte DATA frame if we're ending the stream. if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil { - rws.dirty = true return 0, err } } @@ -6527,9 +6484,6 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { trailers: rws.trailers, endStream: true, }) - if err != nil { - rws.dirty = true - } return len(p), err } return len(p), nil @@ -6745,14 +6699,12 @@ func (rws *http2responseWriterState) writeHeader(code int) { h.Del("Transfer-Encoding") } - if rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{ + rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{ streamID: rws.stream.id, httpResCode: code, h: h, endStream: rws.handlerDone && !rws.hasTrailers(), - }) != nil { - rws.dirty = true - } + }) return } @@ -6817,19 +6769,10 @@ func (w *http2responseWriter) write(lenData int, dataB []byte, dataS string) (n func (w *http2responseWriter) handlerDone() { rws := w.rws - dirty := rws.dirty rws.handlerDone = true w.Flush() w.rws = nil - if !dirty { - // Only recycle the pool if all prior Write calls to - // the serverConn goroutine completed successfully. If - // they returned earlier due to resets from the peer - // there might still be write goroutines outstanding - // from the serverConn referencing the rws memory. See - // issue 20704. - http2responseWriterStatePool.Put(rws) - } + http2responseWriterStatePool.Put(rws) } // Push errors. @@ -7374,8 +7317,7 @@ func (t *http2Transport) initConnPool() { // HTTP/2 server. type http2ClientConn struct { t *http2Transport - tconn net.Conn // usually *tls.Conn, except specialized impls - tconnClosed bool + tconn net.Conn // usually *tls.Conn, except specialized impls tlsState *tls.ConnectionState // nil only for specialized impls reused uint32 // whether conn is being reused; atomic singleUse bool // whether being used for a single http.Request @@ -8103,7 +8045,7 @@ func (cc *http2ClientConn) forceCloseConn() { if !ok { return } - if nc := http2tlsUnderlyingConn(tc); nc != nil { + if nc := tc.NetConn(); nc != nil { nc.Close() } } @@ -8765,7 +8707,28 @@ func (cs *http2clientStream) frameScratchBufferLen(maxFrameSize int) int { return int(n) // doesn't truncate; max is 512K } -var http2bufPool sync.Pool // of *[]byte +// Seven bufPools manage different frame sizes. This helps to avoid scenarios where long-running +// streaming requests using small frame sizes occupy large buffers initially allocated for prior +// requests needing big buffers. The size ranges are as follows: +// {0 KB, 16 KB], {16 KB, 32 KB], {32 KB, 64 KB], {64 KB, 128 KB], {128 KB, 256 KB], +// {256 KB, 512 KB], {512 KB, infinity} +// In practice, the maximum scratch buffer size should not exceed 512 KB due to +// frameScratchBufferLen(maxFrameSize), thus the "infinity pool" should never be used. +// It exists mainly as a safety measure, for potential future increases in max buffer size. +var http2bufPools [7]sync.Pool // of *[]byte + +func http2bufPoolIndex(size int) int { + if size <= 16384 { + return 0 + } + size -= 1 + bits := bits.Len(uint(size)) + index := bits - 14 + if index >= len(http2bufPools) { + return len(http2bufPools) - 1 + } + return index +} func (cs *http2clientStream) writeRequestBody(req *Request) (err error) { cc := cs.cc @@ -8783,12 +8746,13 @@ func (cs *http2clientStream) writeRequestBody(req *Request) (err error) { // Scratch buffer for reading into & writing from. scratchLen := cs.frameScratchBufferLen(maxFrameSize) var buf []byte - if bp, ok := http2bufPool.Get().(*[]byte); ok && len(*bp) >= scratchLen { - defer http2bufPool.Put(bp) + index := http2bufPoolIndex(scratchLen) + if bp, ok := http2bufPools[index].Get().(*[]byte); ok && len(*bp) >= scratchLen { + defer http2bufPools[index].Put(bp) buf = *bp } else { buf = make([]byte, scratchLen) - defer http2bufPool.Put(&buf) + defer http2bufPools[index].Put(&buf) } var sawEOF bool @@ -10269,6 +10233,37 @@ func http2traceFirstResponseByte(trace *httptrace.ClientTrace) { } } +func http2traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool { + return trace != nil && trace.WroteHeaderField != nil +} + +func http2traceWroteHeaderField(trace *httptrace.ClientTrace, k, v string) { + if trace != nil && trace.WroteHeaderField != nil { + trace.WroteHeaderField(k, []string{v}) + } +} + +func http2traceGot1xxResponseFunc(trace *httptrace.ClientTrace) func(int, textproto.MIMEHeader) error { + if trace != nil { + return trace.Got1xxResponse + } + return nil +} + +// dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS +// connection. +func (t *http2Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (*tls.Conn, error) { + dialer := &tls.Dialer{ + Config: cfg, + } + cn, err := dialer.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + tlsCn := cn.(*tls.Conn) // DialContext comment promises this will always succeed + return tlsCn, nil +} + // writeFramer is implemented by any type that is used to write frames. type http2writeFramer interface { writeFrame(http2writeContext) error diff --git a/src/net/http/h2_error.go b/src/net/http/h2_error.go index 0391d31e5b..2c0b21ec07 100644 --- a/src/net/http/h2_error.go +++ b/src/net/http/h2_error.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !nethttpomithttp2 -// +build !nethttpomithttp2 package http diff --git a/src/net/http/h2_error_test.go b/src/net/http/h2_error_test.go index 0d85e2f36c..5e400683b4 100644 --- a/src/net/http/h2_error_test.go +++ b/src/net/http/h2_error_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !nethttpomithttp2 -// +build !nethttpomithttp2 package http diff --git a/src/net/http/header.go b/src/net/http/header.go index e0b342c63c..9d0f3a125d 100644 --- a/src/net/http/header.go +++ b/src/net/http/header.go @@ -20,13 +20,13 @@ import ( // A Header represents the key-value pairs in an HTTP header. // // The keys should be in canonical form, as returned by -// CanonicalHeaderKey. +// [CanonicalHeaderKey]. type Header map[string][]string // Add adds the key, value pair to the header. // It appends to any existing values associated with key. // The key is case insensitive; it is canonicalized by -// CanonicalHeaderKey. +// [CanonicalHeaderKey]. func (h Header) Add(key, value string) { textproto.MIMEHeader(h).Add(key, value) } @@ -34,7 +34,7 @@ func (h Header) Add(key, value string) { // Set sets the header entries associated with key to the // single element value. It replaces any existing values // associated with key. The key is case insensitive; it is -// canonicalized by textproto.CanonicalMIMEHeaderKey. +// canonicalized by [textproto.CanonicalMIMEHeaderKey]. // To use non-canonical keys, assign to the map directly. func (h Header) Set(key, value string) { textproto.MIMEHeader(h).Set(key, value) @@ -42,7 +42,7 @@ func (h Header) Set(key, value string) { // Get gets the first value associated with the given key. If // there are no values associated with the key, Get returns "". -// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +// It is case insensitive; [textproto.CanonicalMIMEHeaderKey] is // used to canonicalize the provided key. Get assumes that all // keys are stored in canonical form. To use non-canonical keys, // access the map directly. @@ -51,7 +51,7 @@ func (h Header) Get(key string) string { } // Values returns all values associated with the given key. -// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +// It is case insensitive; [textproto.CanonicalMIMEHeaderKey] is // used to canonicalize the provided key. To use non-canonical // keys, access the map directly. // The returned slice is not a copy. @@ -76,7 +76,7 @@ func (h Header) has(key string) bool { // Del deletes the values associated with key. // The key is case insensitive; it is canonicalized by -// CanonicalHeaderKey. +// [CanonicalHeaderKey]. func (h Header) Del(key string) { textproto.MIMEHeader(h).Del(key) } @@ -125,7 +125,7 @@ var timeFormats = []string{ // ParseTime parses a time header (such as the Date: header), // trying each of the three formats allowed by HTTP/1.1: -// TimeFormat, time.RFC850, and time.ANSIC. +// [TimeFormat], [time.RFC850], and [time.ANSIC]. func ParseTime(text string) (t time.Time, err error) { for _, layout := range timeFormats { t, err = time.Parse(layout, text) diff --git a/src/net/http/http.go b/src/net/http/http.go index 9b81654fcc..6e2259adbf 100644 --- a/src/net/http/http.go +++ b/src/net/http/http.go @@ -103,10 +103,10 @@ func hexEscapeNonASCII(s string) string { return string(b) } -// NoBody is an io.ReadCloser with no bytes. Read always returns EOF +// NoBody is an [io.ReadCloser] with no bytes. Read always returns EOF // and Close always returns nil. It can be used in an outgoing client // request to explicitly signal that a request has zero bytes. -// An alternative, however, is to simply set Request.Body to nil. +// An alternative, however, is to simply set [Request.Body] to nil. var NoBody = noBody{} type noBody struct{} @@ -121,7 +121,7 @@ var ( _ io.ReadCloser = NoBody ) -// PushOptions describes options for Pusher.Push. +// PushOptions describes options for [Pusher.Push]. type PushOptions struct { // Method specifies the HTTP method for the promised request. // If set, it must be "GET" or "HEAD". Empty means "GET". diff --git a/src/net/http/http_test.go b/src/net/http/http_test.go index 91bb1b2620..2e7e024e20 100644 --- a/src/net/http/http_test.go +++ b/src/net/http/http_test.go @@ -12,7 +12,6 @@ import ( "io/fs" "net/url" "os" - "os/exec" "reflect" "regexp" "strings" @@ -55,7 +54,7 @@ func TestForeachHeaderElement(t *testing.T) { func TestCmdGoNoHTTPServer(t *testing.T) { t.Parallel() goBin := testenv.GoToolPath(t) - out, err := exec.Command(goBin, "tool", "nm", goBin).CombinedOutput() + out, err := testenv.Command(t, goBin, "tool", "nm", goBin).CombinedOutput() if err != nil { t.Fatalf("go tool nm: %v: %s", err, out) } @@ -89,7 +88,7 @@ func TestOmitHTTP2(t *testing.T) { } t.Parallel() goTool := testenv.GoToolPath(t) - out, err := exec.Command(goTool, "test", "-short", "-tags=nethttpomithttp2", "net/http").CombinedOutput() + out, err := testenv.Command(t, goTool, "test", "-short", "-tags=nethttpomithttp2", "net/http").CombinedOutput() if err != nil { t.Fatalf("go test -short failed: %v, %s", err, out) } @@ -101,7 +100,7 @@ func TestOmitHTTP2(t *testing.T) { func TestOmitHTTP2Vet(t *testing.T) { t.Parallel() goTool := testenv.GoToolPath(t) - out, err := exec.Command(goTool, "vet", "-tags=nethttpomithttp2", "net/http").CombinedOutput() + out, err := testenv.Command(t, goTool, "vet", "-tags=nethttpomithttp2", "net/http").CombinedOutput() if err != nil { t.Fatalf("go vet failed: %v, %s", err, out) } diff --git a/src/net/http/httptest/httptest.go b/src/net/http/httptest/httptest.go index 9bedefd2bc..f0ca64362d 100644 --- a/src/net/http/httptest/httptest.go +++ b/src/net/http/httptest/httptest.go @@ -15,7 +15,7 @@ import ( ) // NewRequest returns a new incoming server Request, suitable -// for passing to an http.Handler for testing. +// for passing to an [http.Handler] for testing. // // The target is the RFC 7230 "request-target": it may be either a // path or an absolute URL. If target is an absolute URL, the host name diff --git a/src/net/http/httptest/recorder.go b/src/net/http/httptest/recorder.go index 1c1d880155..dd51901b0d 100644 --- a/src/net/http/httptest/recorder.go +++ b/src/net/http/httptest/recorder.go @@ -16,7 +16,7 @@ import ( "golang.org/x/net/http/httpguts" ) -// ResponseRecorder is an implementation of http.ResponseWriter that +// ResponseRecorder is an implementation of [http.ResponseWriter] that // records its mutations for later inspection in tests. type ResponseRecorder struct { // Code is the HTTP response code set by WriteHeader. @@ -47,7 +47,7 @@ type ResponseRecorder struct { wroteHeader bool } -// NewRecorder returns an initialized ResponseRecorder. +// NewRecorder returns an initialized [ResponseRecorder]. func NewRecorder() *ResponseRecorder { return &ResponseRecorder{ HeaderMap: make(http.Header), @@ -57,12 +57,12 @@ func NewRecorder() *ResponseRecorder { } // DefaultRemoteAddr is the default remote address to return in RemoteAddr if -// an explicit DefaultRemoteAddr isn't set on ResponseRecorder. +// an explicit DefaultRemoteAddr isn't set on [ResponseRecorder]. const DefaultRemoteAddr = "1.2.3.4" -// Header implements http.ResponseWriter. It returns the response +// Header implements [http.ResponseWriter]. It returns the response // headers to mutate within a handler. To test the headers that were -// written after a handler completes, use the Result method and see +// written after a handler completes, use the [ResponseRecorder.Result] method and see // the returned Response value's Header. func (rw *ResponseRecorder) Header() http.Header { m := rw.HeaderMap @@ -112,7 +112,7 @@ func (rw *ResponseRecorder) Write(buf []byte) (int, error) { return len(buf), nil } -// WriteString implements io.StringWriter. The data in str is written +// WriteString implements [io.StringWriter]. The data in str is written // to rw.Body, if not nil. func (rw *ResponseRecorder) WriteString(str string) (int, error) { rw.writeHeader(nil, str) @@ -139,7 +139,7 @@ func checkWriteHeaderCode(code int) { } } -// WriteHeader implements http.ResponseWriter. +// WriteHeader implements [http.ResponseWriter]. func (rw *ResponseRecorder) WriteHeader(code int) { if rw.wroteHeader { return @@ -154,7 +154,7 @@ func (rw *ResponseRecorder) WriteHeader(code int) { rw.snapHeader = rw.HeaderMap.Clone() } -// Flush implements http.Flusher. To test whether Flush was +// Flush implements [http.Flusher]. To test whether Flush was // called, see rw.Flushed. func (rw *ResponseRecorder) Flush() { if !rw.wroteHeader { @@ -175,7 +175,7 @@ func (rw *ResponseRecorder) Flush() { // did a write. // // The Response.Body is guaranteed to be non-nil and Body.Read call is -// guaranteed to not return any error other than io.EOF. +// guaranteed to not return any error other than [io.EOF]. // // Result must only be called after the handler has finished running. func (rw *ResponseRecorder) Result() *http.Response { diff --git a/src/net/http/httptest/server.go b/src/net/http/httptest/server.go index f254a494d1..5095b438ec 100644 --- a/src/net/http/httptest/server.go +++ b/src/net/http/httptest/server.go @@ -77,7 +77,7 @@ func newLocalListener() net.Listener { // When debugging a particular http server-based test, // this flag lets you run // -// go test -run=BrokenTest -httptest.serve=127.0.0.1:8000 +// go test -run='^BrokenTest$' -httptest.serve=127.0.0.1:8000 // // to start the broken server so you can interact with it manually. // We only register this flag if it looks like the caller knows about it @@ -100,7 +100,7 @@ func strSliceContainsPrefix(v []string, pre string) bool { return false } -// NewServer starts and returns a new Server. +// NewServer starts and returns a new [Server]. // The caller should call Close when finished, to shut it down. func NewServer(handler http.Handler) *Server { ts := NewUnstartedServer(handler) @@ -108,7 +108,7 @@ func NewServer(handler http.Handler) *Server { return ts } -// NewUnstartedServer returns a new Server but doesn't start it. +// NewUnstartedServer returns a new [Server] but doesn't start it. // // After changing its configuration, the caller should call Start or // StartTLS. @@ -144,7 +144,7 @@ func (s *Server) StartTLS() { panic("Server already started") } if s.client == nil { - s.client = &http.Client{Transport: &http.Transport{}} + s.client = &http.Client{} } cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey) if err != nil { @@ -185,7 +185,7 @@ func (s *Server) StartTLS() { s.goServe() } -// NewTLSServer starts and returns a new Server using TLS. +// NewTLSServer starts and returns a new [Server] using TLS. // The caller should call Close when finished, to shut it down. func NewTLSServer(handler http.Handler) *Server { ts := NewUnstartedServer(handler) @@ -298,7 +298,7 @@ func (s *Server) Certificate() *x509.Certificate { // Client returns an HTTP client configured for making requests to the server. // It is configured to trust the server's TLS test certificate and will -// close its idle connections on Server.Close. +// close its idle connections on [Server.Close]. func (s *Server) Client() *http.Client { return s.client } diff --git a/src/net/http/httptrace/trace.go b/src/net/http/httptrace/trace.go index 6af30f78d1..706a432957 100644 --- a/src/net/http/httptrace/trace.go +++ b/src/net/http/httptrace/trace.go @@ -19,7 +19,7 @@ import ( // unique type to prevent assignment. type clientEventContextKey struct{} -// ContextClientTrace returns the ClientTrace associated with the +// ContextClientTrace returns the [ClientTrace] associated with the // provided context. If none, it returns nil. func ContextClientTrace(ctx context.Context) *ClientTrace { trace, _ := ctx.Value(clientEventContextKey{}).(*ClientTrace) @@ -233,7 +233,7 @@ func (t *ClientTrace) hasNetHooks() bool { return t.DNSStart != nil || t.DNSDone != nil || t.ConnectStart != nil || t.ConnectDone != nil } -// GotConnInfo is the argument to the ClientTrace.GotConn function and +// GotConnInfo is the argument to the [ClientTrace.GotConn] function and // contains information about the obtained connection. type GotConnInfo struct { // Conn is the connection that was obtained. It is owned by diff --git a/src/net/http/httputil/dump.go b/src/net/http/httputil/dump.go index 7affe5e61a..2edb9bc98d 100644 --- a/src/net/http/httputil/dump.go +++ b/src/net/http/httputil/dump.go @@ -71,8 +71,8 @@ func outgoingLength(req *http.Request) int64 { return -1 } -// DumpRequestOut is like DumpRequest but for outgoing client requests. It -// includes any headers that the standard http.Transport adds, such as +// DumpRequestOut is like [DumpRequest] but for outgoing client requests. It +// includes any headers that the standard [http.Transport] adds, such as // User-Agent. func DumpRequestOut(req *http.Request, body bool) ([]byte, error) { save := req.Body @@ -203,17 +203,17 @@ var reqWriteExcludeHeaderDump = map[string]bool{ // representation. It should only be used by servers to debug client // requests. The returned representation is an approximation only; // some details of the initial request are lost while parsing it into -// an http.Request. In particular, the order and case of header field +// an [http.Request]. In particular, the order and case of header field // names are lost. The order of values in multi-valued headers is kept // intact. HTTP/2 requests are dumped in HTTP/1.x form, not in their // original binary representations. // // If body is true, DumpRequest also returns the body. To do so, it -// consumes req.Body and then replaces it with a new io.ReadCloser +// consumes req.Body and then replaces it with a new [io.ReadCloser] // that yields the same bytes. If DumpRequest returns an error, // the state of req is undefined. // -// The documentation for http.Request.Write details which fields +// The documentation for [http.Request.Write] details which fields // of req are included in the dump. func DumpRequest(req *http.Request, body bool) ([]byte, error) { var err error diff --git a/src/net/http/httputil/httputil.go b/src/net/http/httputil/httputil.go index 09ea74d6d1..431930ea65 100644 --- a/src/net/http/httputil/httputil.go +++ b/src/net/http/httputil/httputil.go @@ -13,7 +13,7 @@ import ( // NewChunkedReader returns a new chunkedReader that translates the data read from r // out of HTTP "chunked" format before returning it. -// The chunkedReader returns io.EOF when the final 0-length chunk is read. +// The chunkedReader returns [io.EOF] when the final 0-length chunk is read. // // NewChunkedReader is not needed by normal applications. The http package // automatically decodes chunking when reading response bodies. diff --git a/src/net/http/httputil/persist.go b/src/net/http/httputil/persist.go index 84b116df8c..0cbe3ebf10 100644 --- a/src/net/http/httputil/persist.go +++ b/src/net/http/httputil/persist.go @@ -33,7 +33,7 @@ var errClosed = errors.New("i/o operation on closed connection") // It is low-level, old, and unused by Go's current HTTP stack. // We should have deleted it before Go 1. // -// Deprecated: Use the Server in package net/http instead. +// Deprecated: Use the Server in package [net/http] instead. type ServerConn struct { mu sync.Mutex // read-write protects the following fields c net.Conn @@ -50,7 +50,7 @@ type ServerConn struct { // It is low-level, old, and unused by Go's current HTTP stack. // We should have deleted it before Go 1. // -// Deprecated: Use the Server in package net/http instead. +// Deprecated: Use the Server in package [net/http] instead. func NewServerConn(c net.Conn, r *bufio.Reader) *ServerConn { if r == nil { r = bufio.NewReader(c) @@ -58,10 +58,10 @@ func NewServerConn(c net.Conn, r *bufio.Reader) *ServerConn { return &ServerConn{c: c, r: r, pipereq: make(map[*http.Request]uint)} } -// Hijack detaches the ServerConn and returns the underlying connection as well +// Hijack detaches the [ServerConn] and returns the underlying connection as well // as the read-side bufio which may have some left over data. Hijack may be // called before Read has signaled the end of the keep-alive logic. The user -// should not call Hijack while Read or Write is in progress. +// should not call Hijack while [ServerConn.Read] or [ServerConn.Write] is in progress. func (sc *ServerConn) Hijack() (net.Conn, *bufio.Reader) { sc.mu.Lock() defer sc.mu.Unlock() @@ -72,7 +72,7 @@ func (sc *ServerConn) Hijack() (net.Conn, *bufio.Reader) { return c, r } -// Close calls Hijack and then also closes the underlying connection. +// Close calls [ServerConn.Hijack] and then also closes the underlying connection. func (sc *ServerConn) Close() error { c, _ := sc.Hijack() if c != nil { @@ -81,7 +81,7 @@ func (sc *ServerConn) Close() error { return nil } -// Read returns the next request on the wire. An ErrPersistEOF is returned if +// Read returns the next request on the wire. An [ErrPersistEOF] is returned if // it is gracefully determined that there are no more requests (e.g. after the // first request on an HTTP/1.0 connection, or after a Connection:close on a // HTTP/1.1 connection). @@ -171,7 +171,7 @@ func (sc *ServerConn) Pending() int { // Write writes resp in response to req. To close the connection gracefully, set the // Response.Close field to true. Write should be considered operational until -// it returns an error, regardless of any errors returned on the Read side. +// it returns an error, regardless of any errors returned on the [ServerConn.Read] side. func (sc *ServerConn) Write(req *http.Request, resp *http.Response) error { // Retrieve the pipeline ID of this request/response pair @@ -226,7 +226,7 @@ func (sc *ServerConn) Write(req *http.Request, resp *http.Response) error { // It is low-level, old, and unused by Go's current HTTP stack. // We should have deleted it before Go 1. // -// Deprecated: Use Client or Transport in package net/http instead. +// Deprecated: Use Client or Transport in package [net/http] instead. type ClientConn struct { mu sync.Mutex // read-write protects the following fields c net.Conn @@ -244,7 +244,7 @@ type ClientConn struct { // It is low-level, old, and unused by Go's current HTTP stack. // We should have deleted it before Go 1. // -// Deprecated: Use the Client or Transport in package net/http instead. +// Deprecated: Use the Client or Transport in package [net/http] instead. func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn { if r == nil { r = bufio.NewReader(c) @@ -261,17 +261,17 @@ func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn { // It is low-level, old, and unused by Go's current HTTP stack. // We should have deleted it before Go 1. // -// Deprecated: Use the Client or Transport in package net/http instead. +// Deprecated: Use the Client or Transport in package [net/http] instead. func NewProxyClientConn(c net.Conn, r *bufio.Reader) *ClientConn { cc := NewClientConn(c, r) cc.writeReq = (*http.Request).WriteProxy return cc } -// Hijack detaches the ClientConn and returns the underlying connection as well +// Hijack detaches the [ClientConn] and returns the underlying connection as well // as the read-side bufio which may have some left over data. Hijack may be // called before the user or Read have signaled the end of the keep-alive -// logic. The user should not call Hijack while Read or Write is in progress. +// logic. The user should not call Hijack while [ClientConn.Read] or ClientConn.Write is in progress. func (cc *ClientConn) Hijack() (c net.Conn, r *bufio.Reader) { cc.mu.Lock() defer cc.mu.Unlock() @@ -282,7 +282,7 @@ func (cc *ClientConn) Hijack() (c net.Conn, r *bufio.Reader) { return } -// Close calls Hijack and then also closes the underlying connection. +// Close calls [ClientConn.Hijack] and then also closes the underlying connection. func (cc *ClientConn) Close() error { c, _ := cc.Hijack() if c != nil { @@ -291,7 +291,7 @@ func (cc *ClientConn) Close() error { return nil } -// Write writes a request. An ErrPersistEOF error is returned if the connection +// Write writes a request. An [ErrPersistEOF] error is returned if the connection // has been closed in an HTTP keep-alive sense. If req.Close equals true, the // keep-alive connection is logically closed after this request and the opposing // server is informed. An ErrUnexpectedEOF indicates the remote closed the @@ -357,9 +357,9 @@ func (cc *ClientConn) Pending() int { } // Read reads the next response from the wire. A valid response might be -// returned together with an ErrPersistEOF, which means that the remote +// returned together with an [ErrPersistEOF], which means that the remote // requested that this be the last request serviced. Read can be called -// concurrently with Write, but not with another Read. +// concurrently with [ClientConn.Write], but not with another Read. func (cc *ClientConn) Read(req *http.Request) (resp *http.Response, err error) { // Retrieve the pipeline ID of this request/response pair cc.mu.Lock() diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go index 2a76b0b8dc..5c70f0d27b 100644 --- a/src/net/http/httputil/reverseproxy.go +++ b/src/net/http/httputil/reverseproxy.go @@ -26,7 +26,7 @@ import ( "golang.org/x/net/http/httpguts" ) -// A ProxyRequest contains a request to be rewritten by a ReverseProxy. +// A ProxyRequest contains a request to be rewritten by a [ReverseProxy]. type ProxyRequest struct { // In is the request received by the proxy. // The Rewrite function must not modify In. @@ -45,7 +45,7 @@ type ProxyRequest struct { // // SetURL rewrites the outbound Host header to match the target's host. // To preserve the inbound request's Host header (the default behavior -// of NewSingleHostReverseProxy): +// of [NewSingleHostReverseProxy]): // // rewriteFunc := func(r *httputil.ProxyRequest) { // r.SetURL(url) @@ -68,7 +68,7 @@ func (r *ProxyRequest) SetURL(target *url.URL) { // If the outbound request contains an existing X-Forwarded-For header, // SetXForwarded appends the client IP address to it. To append to the // inbound request's X-Forwarded-For header (the default behavior of -// ReverseProxy when using a Director function), copy the header +// [ReverseProxy] when using a Director function), copy the header // from the inbound request before calling SetXForwarded: // // rewriteFunc := func(r *httputil.ProxyRequest) { @@ -200,7 +200,7 @@ type ReverseProxy struct { } // A BufferPool is an interface for getting and returning temporary -// byte slices for use by io.CopyBuffer. +// byte slices for use by [io.CopyBuffer]. type BufferPool interface { Get() []byte Put([]byte) @@ -239,7 +239,7 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) { return a.Path + b.Path, apath + bpath } -// NewSingleHostReverseProxy returns a new ReverseProxy that routes +// NewSingleHostReverseProxy returns a new [ReverseProxy] that routes // URLs to the scheme, host, and base path provided in target. If the // target's path is "/base" and the incoming request was for "/dir", // the target request will be for /base/dir. @@ -461,10 +461,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(code) // Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses - for k := range h { - delete(h, k) - } - + clear(h) return nil }, } diff --git a/src/net/http/internal/ascii/print.go b/src/net/http/internal/ascii/print.go index 585e5baba4..98dbf4e3d2 100644 --- a/src/net/http/internal/ascii/print.go +++ b/src/net/http/internal/ascii/print.go @@ -9,7 +9,7 @@ import ( "unicode" ) -// EqualFold is strings.EqualFold, ASCII only. It reports whether s and t +// EqualFold is [strings.EqualFold], ASCII only. It reports whether s and t // are equal, ASCII-case-insensitively. func EqualFold(s, t string) bool { if len(s) != len(t) { diff --git a/src/net/http/internal/chunked.go b/src/net/http/internal/chunked.go index 5a174415dc..196b5d8925 100644 --- a/src/net/http/internal/chunked.go +++ b/src/net/http/internal/chunked.go @@ -22,7 +22,7 @@ var ErrLineTooLong = errors.New("header line too long") // NewChunkedReader returns a new chunkedReader that translates the data read from r // out of HTTP "chunked" format before returning it. -// The chunkedReader returns io.EOF when the final 0-length chunk is read. +// The chunkedReader returns [io.EOF] when the final 0-length chunk is read. // // NewChunkedReader is not needed by normal applications. The http package // automatically decodes chunking when reading response bodies. @@ -39,7 +39,8 @@ type chunkedReader struct { n uint64 // unread bytes in chunk err error buf [2]byte - checkEnd bool // whether need to check for \r\n chunk footer + checkEnd bool // whether need to check for \r\n chunk footer + excess int64 // "excessive" chunk overhead, for malicious sender detection } func (cr *chunkedReader) beginChunk() { @@ -49,10 +50,36 @@ func (cr *chunkedReader) beginChunk() { if cr.err != nil { return } + cr.excess += int64(len(line)) + 2 // header, plus \r\n after the chunk data + line = trimTrailingWhitespace(line) + line, cr.err = removeChunkExtension(line) + if cr.err != nil { + return + } cr.n, cr.err = parseHexUint(line) if cr.err != nil { return } + // A sender who sends one byte per chunk will send 5 bytes of overhead + // for every byte of data. ("1\r\nX\r\n" to send "X".) + // We want to allow this, since streaming a byte at a time can be legitimate. + // + // A sender can use chunk extensions to add arbitrary amounts of additional + // data per byte read. ("1;very long extension\r\nX\r\n" to send "X".) + // We don't want to disallow extensions (although we discard them), + // but we also don't want to allow a sender to reduce the signal/noise ratio + // arbitrarily. + // + // We track the amount of excess overhead read, + // and produce an error if it grows too large. + // + // Currently, we say that we're willing to accept 16 bytes of overhead per chunk, + // plus twice the amount of real data in the chunk. + cr.excess -= 16 + (2 * int64(cr.n)) + cr.excess = max(cr.excess, 0) + if cr.excess > 16*1024 { + cr.err = errors.New("chunked encoding contains too much non-data") + } if cr.n == 0 { cr.err = io.EOF } @@ -140,11 +167,6 @@ func readChunkLine(b *bufio.Reader) ([]byte, error) { if len(p) >= maxLineLength { return nil, ErrLineTooLong } - p = trimTrailingWhitespace(p) - p, err = removeChunkExtension(p) - if err != nil { - return nil, err - } return p, nil } @@ -199,7 +221,7 @@ type chunkedWriter struct { // Write the contents of data as one chunk to Wire. // NOTE: Note that the corresponding chunk-writing procedure in Conn.Write has -// a bug since it does not check for success of io.WriteString +// a bug since it does not check for success of [io.WriteString] func (cw *chunkedWriter) Write(data []byte) (n int, err error) { // Don't send 0-length data. It looks like EOF for chunked encoding. @@ -231,9 +253,9 @@ func (cw *chunkedWriter) Close() error { return err } -// FlushAfterChunkWriter signals from the caller of NewChunkedWriter +// FlushAfterChunkWriter signals from the caller of [NewChunkedWriter] // that each chunk should be followed by a flush. It is used by the -// http.Transport code to keep the buffering behavior for headers and +// [net/http.Transport] code to keep the buffering behavior for headers and // trailers, but flush out chunks aggressively in the middle for // request bodies which may be generated slowly. See Issue 6574. type FlushAfterChunkWriter struct { @@ -241,6 +263,9 @@ type FlushAfterChunkWriter struct { } func parseHexUint(v []byte) (n uint64, err error) { + if len(v) == 0 { + return 0, errors.New("empty hex number for chunk length") + } for i, b := range v { switch { case '0' <= b && b <= '9': diff --git a/src/net/http/internal/chunked_test.go b/src/net/http/internal/chunked_test.go index 5e29a786dd..af79711781 100644 --- a/src/net/http/internal/chunked_test.go +++ b/src/net/http/internal/chunked_test.go @@ -153,6 +153,7 @@ func TestParseHexUint(t *testing.T) { {"00000000000000000", 0, "http chunk length too large"}, // could accept if we wanted {"10000000000000000", 0, "http chunk length too large"}, {"00000000000000001", 0, "http chunk length too large"}, // could accept if we wanted + {"", 0, "empty hex number for chunk length"}, } for i := uint64(0); i <= 1234; i++ { tests = append(tests, testCase{in: fmt.Sprintf("%x", i), want: i}) @@ -239,3 +240,62 @@ func TestChunkEndReadError(t *testing.T) { t.Errorf("expected %v, got %v", readErr, err) } } + +func TestChunkReaderTooMuchOverhead(t *testing.T) { + // If the sender is sending 100x as many chunk header bytes as chunk data, + // we should reject the stream at some point. + chunk := []byte("1;") + for i := 0; i < 100; i++ { + chunk = append(chunk, 'a') // chunk extension + } + chunk = append(chunk, "\r\nX\r\n"...) + const bodylen = 1 << 20 + r := NewChunkedReader(&funcReader{f: func(i int) ([]byte, error) { + if i < bodylen { + return chunk, nil + } + return []byte("0\r\n"), nil + }}) + _, err := io.ReadAll(r) + if err == nil { + t.Fatalf("successfully read body with excessive overhead; want error") + } +} + +func TestChunkReaderByteAtATime(t *testing.T) { + // Sending one byte per chunk should not trip the excess-overhead detection. + const bodylen = 1 << 20 + r := NewChunkedReader(&funcReader{f: func(i int) ([]byte, error) { + if i < bodylen { + return []byte("1\r\nX\r\n"), nil + } + return []byte("0\r\n"), nil + }}) + got, err := io.ReadAll(r) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if len(got) != bodylen { + t.Errorf("read %v bytes, want %v", len(got), bodylen) + } +} + +type funcReader struct { + f func(iteration int) ([]byte, error) + i int + b []byte + err error +} + +func (r *funcReader) Read(p []byte) (n int, err error) { + if len(r.b) == 0 && r.err == nil { + r.b, r.err = r.f(r.i) + r.i++ + } + n = copy(p, r.b) + r.b = r.b[n:] + if len(r.b) > 0 { + return n, nil + } + return n, r.err +} diff --git a/src/net/http/mapping.go b/src/net/http/mapping.go new file mode 100644 index 0000000000..87e6d5ae5d --- /dev/null +++ b/src/net/http/mapping.go @@ -0,0 +1,78 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +// A mapping is a collection of key-value pairs where the keys are unique. +// A zero mapping is empty and ready to use. +// A mapping tries to pick a representation that makes [mapping.find] most efficient. +type mapping[K comparable, V any] struct { + s []entry[K, V] // for few pairs + m map[K]V // for many pairs +} + +type entry[K comparable, V any] struct { + key K + value V +} + +// maxSlice is the maximum number of pairs for which a slice is used. +// It is a variable for benchmarking. +var maxSlice int = 8 + +// add adds a key-value pair to the mapping. +func (h *mapping[K, V]) add(k K, v V) { + if h.m == nil && len(h.s) < maxSlice { + h.s = append(h.s, entry[K, V]{k, v}) + } else { + if h.m == nil { + h.m = map[K]V{} + for _, e := range h.s { + h.m[e.key] = e.value + } + h.s = nil + } + h.m[k] = v + } +} + +// find returns the value corresponding to the given key. +// The second return value is false if there is no value +// with that key. +func (h *mapping[K, V]) find(k K) (v V, found bool) { + if h == nil { + return v, false + } + if h.m != nil { + v, found = h.m[k] + return v, found + } + for _, e := range h.s { + if e.key == k { + return e.value, true + } + } + return v, false +} + +// eachPair calls f for each pair in the mapping. +// If f returns false, pairs returns immediately. +func (h *mapping[K, V]) eachPair(f func(k K, v V) bool) { + if h == nil { + return + } + if h.m != nil { + for k, v := range h.m { + if !f(k, v) { + return + } + } + } else { + for _, e := range h.s { + if !f(e.key, e.value) { + return + } + } + } +} diff --git a/src/net/http/mapping_test.go b/src/net/http/mapping_test.go new file mode 100644 index 0000000000..0aed9d9e31 --- /dev/null +++ b/src/net/http/mapping_test.go @@ -0,0 +1,154 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "cmp" + "fmt" + "slices" + "strconv" + "testing" +) + +func TestMapping(t *testing.T) { + var m mapping[int, string] + for i := 0; i < maxSlice; i++ { + m.add(i, strconv.Itoa(i)) + } + if m.m != nil { + t.Fatal("m.m != nil") + } + for i := 0; i < maxSlice; i++ { + g, _ := m.find(i) + w := strconv.Itoa(i) + if g != w { + t.Fatalf("%d: got %s, want %s", i, g, w) + } + } + m.add(4, "4") + if m.s != nil { + t.Fatal("m.s != nil") + } + if m.m == nil { + t.Fatal("m.m == nil") + } + g, _ := m.find(4) + if w := "4"; g != w { + t.Fatalf("got %s, want %s", g, w) + } +} + +func TestMappingEachPair(t *testing.T) { + var m mapping[int, string] + var want []entry[int, string] + for i := 0; i < maxSlice*2; i++ { + v := strconv.Itoa(i) + m.add(i, v) + want = append(want, entry[int, string]{i, v}) + + } + + var got []entry[int, string] + m.eachPair(func(k int, v string) bool { + got = append(got, entry[int, string]{k, v}) + return true + }) + slices.SortFunc(got, func(e1, e2 entry[int, string]) int { + return cmp.Compare(e1.key, e2.key) + }) + if !slices.Equal(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func BenchmarkFindChild(b *testing.B) { + key := "articles" + children := []string{ + "*", + "cmd.html", + "code.html", + "contrib.html", + "contribute.html", + "debugging_with_gdb.html", + "docs.html", + "effective_go.html", + "files.log", + "gccgo_contribute.html", + "gccgo_install.html", + "go-logo-black.png", + "go-logo-blue.png", + "go-logo-white.png", + "go1.1.html", + "go1.2.html", + "go1.html", + "go1compat.html", + "go_faq.html", + "go_mem.html", + "go_spec.html", + "help.html", + "ie.css", + "install-source.html", + "install.html", + "logo-153x55.png", + "Makefile", + "root.html", + "share.png", + "sieve.gif", + "tos.html", + "articles", + } + if len(children) != 32 { + panic("bad len") + } + for _, n := range []int{2, 4, 8, 16, 32} { + list := children[:n] + b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) { + + b.Run("rep=linear", func(b *testing.B) { + var entries []entry[string, any] + for _, c := range list { + entries = append(entries, entry[string, any]{c, nil}) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + findChildLinear(key, entries) + } + }) + b.Run("rep=map", func(b *testing.B) { + m := map[string]any{} + for _, c := range list { + m[c] = nil + } + var x any + b.ResetTimer() + for i := 0; i < b.N; i++ { + x = m[key] + } + _ = x + }) + b.Run(fmt.Sprintf("rep=hybrid%d", maxSlice), func(b *testing.B) { + var h mapping[string, any] + for _, c := range list { + h.add(c, nil) + } + var x any + b.ResetTimer() + for i := 0; i < b.N; i++ { + x, _ = h.find(key) + } + _ = x + }) + }) + } +} + +func findChildLinear(key string, entries []entry[string, any]) any { + for _, e := range entries { + if key == e.key { + return e.value + } + } + return nil +} diff --git a/src/net/http/pattern.go b/src/net/http/pattern.go new file mode 100644 index 0000000000..f6af19b0f4 --- /dev/null +++ b/src/net/http/pattern.go @@ -0,0 +1,529 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Patterns for ServeMux routing. + +package http + +import ( + "errors" + "fmt" + "net/url" + "strings" + "unicode" +) + +// A pattern is something that can be matched against an HTTP request. +// It has an optional method, an optional host, and a path. +type pattern struct { + str string // original string + method string + host string + // The representation of a path differs from the surface syntax, which + // simplifies most algorithms. + // + // Paths ending in '/' are represented with an anonymous "..." wildcard. + // For example, the path "a/" is represented as a literal segment "a" followed + // by a segment with multi==true. + // + // Paths ending in "{$}" are represented with the literal segment "/". + // For example, the path "a/{$}" is represented as a literal segment "a" followed + // by a literal segment "/". + segments []segment + loc string // source location of registering call, for helpful messages +} + +func (p *pattern) String() string { return p.str } + +func (p *pattern) lastSegment() segment { + return p.segments[len(p.segments)-1] +} + +// A segment is a pattern piece that matches one or more path segments, or +// a trailing slash. +// +// If wild is false, it matches a literal segment, or, if s == "/", a trailing slash. +// Examples: +// +// "a" => segment{s: "a"} +// "/{$}" => segment{s: "/"} +// +// If wild is true and multi is false, it matches a single path segment. +// Example: +// +// "{x}" => segment{s: "x", wild: true} +// +// If both wild and multi are true, it matches all remaining path segments. +// Example: +// +// "{rest...}" => segment{s: "rest", wild: true, multi: true} +type segment struct { + s string // literal or wildcard name or "/" for "/{$}". + wild bool + multi bool // "..." wildcard +} + +// parsePattern parses a string into a Pattern. +// The string's syntax is +// +// [METHOD] [HOST]/[PATH] +// +// where: +// - METHOD is an HTTP method +// - HOST is a hostname +// - PATH consists of slash-separated segments, where each segment is either +// a literal or a wildcard of the form "{name}", "{name...}", or "{$}". +// +// METHOD, HOST and PATH are all optional; that is, the string can be "/". +// If METHOD is present, it must be followed by a single space. +// Wildcard names must be valid Go identifiers. +// The "{$}" and "{name...}" wildcard must occur at the end of PATH. +// PATH may end with a '/'. +// Wildcard names in a path must be distinct. +func parsePattern(s string) (_ *pattern, err error) { + if len(s) == 0 { + return nil, errors.New("empty pattern") + } + off := 0 // offset into string + defer func() { + if err != nil { + err = fmt.Errorf("at offset %d: %w", off, err) + } + }() + + method, rest, found := strings.Cut(s, " ") + if !found { + rest = method + method = "" + } + if method != "" && !validMethod(method) { + return nil, fmt.Errorf("invalid method %q", method) + } + p := &pattern{str: s, method: method} + + if found { + off = len(method) + 1 + } + i := strings.IndexByte(rest, '/') + if i < 0 { + return nil, errors.New("host/path missing /") + } + p.host = rest[:i] + rest = rest[i:] + if j := strings.IndexByte(p.host, '{'); j >= 0 { + off += j + return nil, errors.New("host contains '{' (missing initial '/'?)") + } + // At this point, rest is the path. + off += i + + // An unclean path with a method that is not CONNECT can never match, + // because paths are cleaned before matching. + if method != "" && method != "CONNECT" && rest != cleanPath(rest) { + return nil, errors.New("non-CONNECT pattern with unclean path can never match") + } + + seenNames := map[string]bool{} // remember wildcard names to catch dups + for len(rest) > 0 { + // Invariant: rest[0] == '/'. + rest = rest[1:] + off = len(s) - len(rest) + if len(rest) == 0 { + // Trailing slash. + p.segments = append(p.segments, segment{wild: true, multi: true}) + break + } + i := strings.IndexByte(rest, '/') + if i < 0 { + i = len(rest) + } + var seg string + seg, rest = rest[:i], rest[i:] + if i := strings.IndexByte(seg, '{'); i < 0 { + // Literal. + seg = pathUnescape(seg) + p.segments = append(p.segments, segment{s: seg}) + } else { + // Wildcard. + if i != 0 { + return nil, errors.New("bad wildcard segment (must start with '{')") + } + if seg[len(seg)-1] != '}' { + return nil, errors.New("bad wildcard segment (must end with '}')") + } + name := seg[1 : len(seg)-1] + if name == "$" { + if len(rest) != 0 { + return nil, errors.New("{$} not at end") + } + p.segments = append(p.segments, segment{s: "/"}) + break + } + name, multi := strings.CutSuffix(name, "...") + if multi && len(rest) != 0 { + return nil, errors.New("{...} wildcard not at end") + } + if name == "" { + return nil, errors.New("empty wildcard") + } + if !isValidWildcardName(name) { + return nil, fmt.Errorf("bad wildcard name %q", name) + } + if seenNames[name] { + return nil, fmt.Errorf("duplicate wildcard name %q", name) + } + seenNames[name] = true + p.segments = append(p.segments, segment{s: name, wild: true, multi: multi}) + } + } + return p, nil +} + +func isValidWildcardName(s string) bool { + if s == "" { + return false + } + // Valid Go identifier. + for i, c := range s { + if !unicode.IsLetter(c) && c != '_' && (i == 0 || !unicode.IsDigit(c)) { + return false + } + } + return true +} + +func pathUnescape(path string) string { + u, err := url.PathUnescape(path) + if err != nil { + // Invalidly escaped path; use the original + return path + } + return u +} + +// relationship is a relationship between two patterns, p1 and p2. +type relationship string + +const ( + equivalent relationship = "equivalent" // both match the same requests + moreGeneral relationship = "moreGeneral" // p1 matches everything p2 does & more + moreSpecific relationship = "moreSpecific" // p2 matches everything p1 does & more + disjoint relationship = "disjoint" // there is no request that both match + overlaps relationship = "overlaps" // there is a request that both match, but neither is more specific +) + +// conflictsWith reports whether p1 conflicts with p2, that is, whether +// there is a request that both match but where neither is higher precedence +// than the other. +// +// Precedence is defined by two rules: +// 1. Patterns with a host win over patterns without a host. +// 2. Patterns whose method and path is more specific win. One pattern is more +// specific than another if the second matches all the (method, path) pairs +// of the first and more. +// +// If rule 1 doesn't apply, then two patterns conflict if their relationship +// is either equivalence (they match the same set of requests) or overlap +// (they both match some requests, but neither is more specific than the other). +func (p1 *pattern) conflictsWith(p2 *pattern) bool { + if p1.host != p2.host { + // Either one host is empty and the other isn't, in which case the + // one with the host wins by rule 1, or neither host is empty + // and they differ, so they won't match the same paths. + return false + } + rel := p1.comparePathsAndMethods(p2) + return rel == equivalent || rel == overlaps +} + +func (p1 *pattern) comparePathsAndMethods(p2 *pattern) relationship { + mrel := p1.compareMethods(p2) + // Optimization: avoid a call to comparePaths. + if mrel == disjoint { + return disjoint + } + prel := p1.comparePaths(p2) + return combineRelationships(mrel, prel) +} + +// compareMethods determines the relationship between the method +// part of patterns p1 and p2. +// +// A method can either be empty, "GET", or something else. +// The empty string matches any method, so it is the most general. +// "GET" matches both GET and HEAD. +// Anything else matches only itself. +func (p1 *pattern) compareMethods(p2 *pattern) relationship { + if p1.method == p2.method { + return equivalent + } + if p1.method == "" { + // p1 matches any method, but p2 does not, so p1 is more general. + return moreGeneral + } + if p2.method == "" { + return moreSpecific + } + if p1.method == "GET" && p2.method == "HEAD" { + // p1 matches GET and HEAD; p2 matches only HEAD. + return moreGeneral + } + if p2.method == "GET" && p1.method == "HEAD" { + return moreSpecific + } + return disjoint +} + +// comparePaths determines the relationship between the path +// part of two patterns. +func (p1 *pattern) comparePaths(p2 *pattern) relationship { + // Optimization: if a path pattern doesn't end in a multi ("...") wildcard, then it + // can only match paths with the same number of segments. + if len(p1.segments) != len(p2.segments) && !p1.lastSegment().multi && !p2.lastSegment().multi { + return disjoint + } + + // Consider corresponding segments in the two path patterns. + var segs1, segs2 []segment + rel := equivalent + for segs1, segs2 = p1.segments, p2.segments; len(segs1) > 0 && len(segs2) > 0; segs1, segs2 = segs1[1:], segs2[1:] { + rel = combineRelationships(rel, compareSegments(segs1[0], segs2[0])) + if rel == disjoint { + return rel + } + } + // We've reached the end of the corresponding segments of the patterns. + // If they have the same number of segments, then we've already determined + // their relationship. + if len(segs1) == 0 && len(segs2) == 0 { + return rel + } + // Otherwise, the only way they could fail to be disjoint is if the shorter + // pattern ends in a multi. In that case, that multi is more general + // than the remainder of the longer pattern, so combine those two relationships. + if len(segs1) < len(segs2) && p1.lastSegment().multi { + return combineRelationships(rel, moreGeneral) + } + if len(segs2) < len(segs1) && p2.lastSegment().multi { + return combineRelationships(rel, moreSpecific) + } + return disjoint +} + +// compareSegments determines the relationship between two segments. +func compareSegments(s1, s2 segment) relationship { + if s1.multi && s2.multi { + return equivalent + } + if s1.multi { + return moreGeneral + } + if s2.multi { + return moreSpecific + } + if s1.wild && s2.wild { + return equivalent + } + if s1.wild { + if s2.s == "/" { + // A single wildcard doesn't match a trailing slash. + return disjoint + } + return moreGeneral + } + if s2.wild { + if s1.s == "/" { + return disjoint + } + return moreSpecific + } + // Both literals. + if s1.s == s2.s { + return equivalent + } + return disjoint +} + +// combineRelationships determines the overall relationship of two patterns +// given the relationships of a partition of the patterns into two parts. +// +// For example, if p1 is more general than p2 in one way but equivalent +// in the other, then it is more general overall. +// +// Or if p1 is more general in one way and more specific in the other, then +// they overlap. +func combineRelationships(r1, r2 relationship) relationship { + switch r1 { + case equivalent: + return r2 + case disjoint: + return disjoint + case overlaps: + if r2 == disjoint { + return disjoint + } + return overlaps + case moreGeneral, moreSpecific: + switch r2 { + case equivalent: + return r1 + case inverseRelationship(r1): + return overlaps + default: + return r2 + } + default: + panic(fmt.Sprintf("unknown relationship %q", r1)) + } +} + +// If p1 has relationship `r` to p2, then +// p2 has inverseRelationship(r) to p1. +func inverseRelationship(r relationship) relationship { + switch r { + case moreSpecific: + return moreGeneral + case moreGeneral: + return moreSpecific + default: + return r + } +} + +// isLitOrSingle reports whether the segment is a non-dollar literal or a single wildcard. +func isLitOrSingle(seg segment) bool { + if seg.wild { + return !seg.multi + } + return seg.s != "/" +} + +// describeConflict returns an explanation of why two patterns conflict. +func describeConflict(p1, p2 *pattern) string { + mrel := p1.compareMethods(p2) + prel := p1.comparePaths(p2) + rel := combineRelationships(mrel, prel) + if rel == equivalent { + return fmt.Sprintf("%s matches the same requests as %s", p1, p2) + } + if rel != overlaps { + panic("describeConflict called with non-conflicting patterns") + } + if prel == overlaps { + return fmt.Sprintf(`%[1]s and %[2]s both match some paths, like %[3]q. +But neither is more specific than the other. +%[1]s matches %[4]q, but %[2]s doesn't. +%[2]s matches %[5]q, but %[1]s doesn't.`, + p1, p2, commonPath(p1, p2), differencePath(p1, p2), differencePath(p2, p1)) + } + if mrel == moreGeneral && prel == moreSpecific { + return fmt.Sprintf("%s matches more methods than %s, but has a more specific path pattern", p1, p2) + } + if mrel == moreSpecific && prel == moreGeneral { + return fmt.Sprintf("%s matches fewer methods than %s, but has a more general path pattern", p1, p2) + } + return fmt.Sprintf("bug: unexpected way for two patterns %s and %s to conflict: methods %s, paths %s", p1, p2, mrel, prel) +} + +// writeMatchingPath writes to b a path that matches the segments. +func writeMatchingPath(b *strings.Builder, segs []segment) { + for _, s := range segs { + writeSegment(b, s) + } +} + +func writeSegment(b *strings.Builder, s segment) { + b.WriteByte('/') + if !s.multi && s.s != "/" { + b.WriteString(s.s) + } +} + +// commonPath returns a path that both p1 and p2 match. +// It assumes there is such a path. +func commonPath(p1, p2 *pattern) string { + var b strings.Builder + var segs1, segs2 []segment + for segs1, segs2 = p1.segments, p2.segments; len(segs1) > 0 && len(segs2) > 0; segs1, segs2 = segs1[1:], segs2[1:] { + if s1 := segs1[0]; s1.wild { + writeSegment(&b, segs2[0]) + } else { + writeSegment(&b, s1) + } + } + if len(segs1) > 0 { + writeMatchingPath(&b, segs1) + } else if len(segs2) > 0 { + writeMatchingPath(&b, segs2) + } + return b.String() +} + +// differencePath returns a path that p1 matches and p2 doesn't. +// It assumes there is such a path. +func differencePath(p1, p2 *pattern) string { + var b strings.Builder + + var segs1, segs2 []segment + for segs1, segs2 = p1.segments, p2.segments; len(segs1) > 0 && len(segs2) > 0; segs1, segs2 = segs1[1:], segs2[1:] { + s1 := segs1[0] + s2 := segs2[0] + if s1.multi && s2.multi { + // From here the patterns match the same paths, so we must have found a difference earlier. + b.WriteByte('/') + return b.String() + + } + if s1.multi && !s2.multi { + // s1 ends in a "..." wildcard but s2 does not. + // A trailing slash will distinguish them, unless s2 ends in "{$}", + // in which case any segment will do; prefer the wildcard name if + // it has one. + b.WriteByte('/') + if s2.s == "/" { + if s1.s != "" { + b.WriteString(s1.s) + } else { + b.WriteString("x") + } + } + return b.String() + } + if !s1.multi && s2.multi { + writeSegment(&b, s1) + } else if s1.wild && s2.wild { + // Both patterns will match whatever we put here; use + // the first wildcard name. + writeSegment(&b, s1) + } else if s1.wild && !s2.wild { + // s1 is a wildcard, s2 is a literal. + // Any segment other than s2.s will work. + // Prefer the wildcard name, but if it's the same as the literal, + // tweak the literal. + if s1.s != s2.s { + writeSegment(&b, s1) + } else { + b.WriteByte('/') + b.WriteString(s2.s + "x") + } + } else if !s1.wild && s2.wild { + writeSegment(&b, s1) + } else { + // Both are literals. A precondition of this function is that the + // patterns overlap, so they must be the same literal. Use it. + if s1.s != s2.s { + panic(fmt.Sprintf("literals differ: %q and %q", s1.s, s2.s)) + } + writeSegment(&b, s1) + } + } + if len(segs1) > 0 { + // p1 is longer than p2, and p2 does not end in a multi. + // Anything that matches the rest of p1 will do. + writeMatchingPath(&b, segs1) + } else if len(segs2) > 0 { + writeMatchingPath(&b, segs2) + } + return b.String() +} diff --git a/src/net/http/pattern_test.go b/src/net/http/pattern_test.go new file mode 100644 index 0000000000..f0c84d243e --- /dev/null +++ b/src/net/http/pattern_test.go @@ -0,0 +1,494 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "slices" + "strings" + "testing" +) + +func TestParsePattern(t *testing.T) { + lit := func(name string) segment { + return segment{s: name} + } + + wild := func(name string) segment { + return segment{s: name, wild: true} + } + + multi := func(name string) segment { + s := wild(name) + s.multi = true + return s + } + + for _, test := range []struct { + in string + want pattern + }{ + {"/", pattern{segments: []segment{multi("")}}}, + {"/a", pattern{segments: []segment{lit("a")}}}, + { + "/a/", + pattern{segments: []segment{lit("a"), multi("")}}, + }, + {"/path/to/something", pattern{segments: []segment{ + lit("path"), lit("to"), lit("something"), + }}}, + { + "/{w1}/lit/{w2}", + pattern{ + segments: []segment{wild("w1"), lit("lit"), wild("w2")}, + }, + }, + { + "/{w1}/lit/{w2}/", + pattern{ + segments: []segment{wild("w1"), lit("lit"), wild("w2"), multi("")}, + }, + }, + { + "example.com/", + pattern{host: "example.com", segments: []segment{multi("")}}, + }, + { + "GET /", + pattern{method: "GET", segments: []segment{multi("")}}, + }, + { + "POST example.com/foo/{w}", + pattern{ + method: "POST", + host: "example.com", + segments: []segment{lit("foo"), wild("w")}, + }, + }, + { + "/{$}", + pattern{segments: []segment{lit("/")}}, + }, + { + "DELETE example.com/a/{foo12}/{$}", + pattern{method: "DELETE", host: "example.com", segments: []segment{lit("a"), wild("foo12"), lit("/")}}, + }, + { + "/foo/{$}", + pattern{segments: []segment{lit("foo"), lit("/")}}, + }, + { + "/{a}/foo/{rest...}", + pattern{segments: []segment{wild("a"), lit("foo"), multi("rest")}}, + }, + { + "//", + pattern{segments: []segment{lit(""), multi("")}}, + }, + { + "/foo///./../bar", + pattern{segments: []segment{lit("foo"), lit(""), lit(""), lit("."), lit(".."), lit("bar")}}, + }, + { + "a.com/foo//", + pattern{host: "a.com", segments: []segment{lit("foo"), lit(""), multi("")}}, + }, + { + "/%61%62/%7b/%", + pattern{segments: []segment{lit("ab"), lit("{"), lit("%")}}, + }, + } { + got := mustParsePattern(t, test.in) + if !got.equal(&test.want) { + t.Errorf("%q:\ngot %#v\nwant %#v", test.in, got, &test.want) + } + } +} + +func TestParsePatternError(t *testing.T) { + for _, test := range []struct { + in string + contains string + }{ + {"", "empty pattern"}, + {"A=B /", "at offset 0: invalid method"}, + {" ", "at offset 1: host/path missing /"}, + {"/{w}x", "at offset 1: bad wildcard segment"}, + {"/x{w}", "at offset 1: bad wildcard segment"}, + {"/{wx", "at offset 1: bad wildcard segment"}, + {"/a/{/}/c", "at offset 3: bad wildcard segment"}, + {"/a/{%61}/c", "at offset 3: bad wildcard name"}, // wildcard names aren't unescaped + {"/{a$}", "at offset 1: bad wildcard name"}, + {"/{}", "at offset 1: empty wildcard"}, + {"POST a.com/x/{}/y", "at offset 13: empty wildcard"}, + {"/{...}", "at offset 1: empty wildcard"}, + {"/{$...}", "at offset 1: bad wildcard"}, + {"/{$}/", "at offset 1: {$} not at end"}, + {"/{$}/x", "at offset 1: {$} not at end"}, + {"/abc/{$}/x", "at offset 5: {$} not at end"}, + {"/{a...}/", "at offset 1: {...} wildcard not at end"}, + {"/{a...}/x", "at offset 1: {...} wildcard not at end"}, + {"{a}/b", "at offset 0: host contains '{' (missing initial '/'?)"}, + {"/a/{x}/b/{x...}", "at offset 9: duplicate wildcard name"}, + {"GET //", "at offset 4: non-CONNECT pattern with unclean path"}, + } { + _, err := parsePattern(test.in) + if err == nil || !strings.Contains(err.Error(), test.contains) { + t.Errorf("%q:\ngot %v, want error containing %q", test.in, err, test.contains) + } + } +} + +func (p1 *pattern) equal(p2 *pattern) bool { + return p1.method == p2.method && p1.host == p2.host && + slices.Equal(p1.segments, p2.segments) +} + +func mustParsePattern(tb testing.TB, s string) *pattern { + tb.Helper() + p, err := parsePattern(s) + if err != nil { + tb.Fatal(err) + } + return p +} + +func TestCompareMethods(t *testing.T) { + for _, test := range []struct { + p1, p2 string + want relationship + }{ + {"/", "/", equivalent}, + {"GET /", "GET /", equivalent}, + {"HEAD /", "HEAD /", equivalent}, + {"POST /", "POST /", equivalent}, + {"GET /", "POST /", disjoint}, + {"GET /", "/", moreSpecific}, + {"HEAD /", "/", moreSpecific}, + {"GET /", "HEAD /", moreGeneral}, + } { + pat1 := mustParsePattern(t, test.p1) + pat2 := mustParsePattern(t, test.p2) + got := pat1.compareMethods(pat2) + if got != test.want { + t.Errorf("%s vs %s: got %s, want %s", test.p1, test.p2, got, test.want) + } + got2 := pat2.compareMethods(pat1) + want2 := inverseRelationship(test.want) + if got2 != want2 { + t.Errorf("%s vs %s: got %s, want %s", test.p2, test.p1, got2, want2) + } + } +} + +func TestComparePaths(t *testing.T) { + for _, test := range []struct { + p1, p2 string + want relationship + }{ + // A non-final pattern segment can have one of two values: literal or + // single wildcard. A final pattern segment can have one of 5: empty + // (trailing slash), literal, dollar, single wildcard, or multi + // wildcard. Trailing slash and multi wildcard are the same. + + // A literal should be more specific than anything it overlaps, except itself. + {"/a", "/a", equivalent}, + {"/a", "/b", disjoint}, + {"/a", "/", moreSpecific}, + {"/a", "/{$}", disjoint}, + {"/a", "/{x}", moreSpecific}, + {"/a", "/{x...}", moreSpecific}, + + // Adding a segment doesn't change that. + {"/b/a", "/b/a", equivalent}, + {"/b/a", "/b/b", disjoint}, + {"/b/a", "/b/", moreSpecific}, + {"/b/a", "/b/{$}", disjoint}, + {"/b/a", "/b/{x}", moreSpecific}, + {"/b/a", "/b/{x...}", moreSpecific}, + {"/{z}/a", "/{z}/a", equivalent}, + {"/{z}/a", "/{z}/b", disjoint}, + {"/{z}/a", "/{z}/", moreSpecific}, + {"/{z}/a", "/{z}/{$}", disjoint}, + {"/{z}/a", "/{z}/{x}", moreSpecific}, + {"/{z}/a", "/{z}/{x...}", moreSpecific}, + + // Single wildcard on left. + {"/{z}", "/a", moreGeneral}, + {"/{z}", "/a/b", disjoint}, + {"/{z}", "/{$}", disjoint}, + {"/{z}", "/{x}", equivalent}, + {"/{z}", "/", moreSpecific}, + {"/{z}", "/{x...}", moreSpecific}, + {"/b/{z}", "/b/a", moreGeneral}, + {"/b/{z}", "/b/a/b", disjoint}, + {"/b/{z}", "/b/{$}", disjoint}, + {"/b/{z}", "/b/{x}", equivalent}, + {"/b/{z}", "/b/", moreSpecific}, + {"/b/{z}", "/b/{x...}", moreSpecific}, + + // Trailing slash on left. + {"/", "/a", moreGeneral}, + {"/", "/a/b", moreGeneral}, + {"/", "/{$}", moreGeneral}, + {"/", "/{x}", moreGeneral}, + {"/", "/", equivalent}, + {"/", "/{x...}", equivalent}, + + {"/b/", "/b/a", moreGeneral}, + {"/b/", "/b/a/b", moreGeneral}, + {"/b/", "/b/{$}", moreGeneral}, + {"/b/", "/b/{x}", moreGeneral}, + {"/b/", "/b/", equivalent}, + {"/b/", "/b/{x...}", equivalent}, + + {"/{z}/", "/{z}/a", moreGeneral}, + {"/{z}/", "/{z}/a/b", moreGeneral}, + {"/{z}/", "/{z}/{$}", moreGeneral}, + {"/{z}/", "/{z}/{x}", moreGeneral}, + {"/{z}/", "/{z}/", equivalent}, + {"/{z}/", "/a/", moreGeneral}, + {"/{z}/", "/{z}/{x...}", equivalent}, + {"/{z}/", "/a/{x...}", moreGeneral}, + {"/a/{z}/", "/{z}/a/", overlaps}, + {"/a/{z}/b/", "/{x}/c/{y...}", overlaps}, + + // Multi wildcard on left. + {"/{m...}", "/a", moreGeneral}, + {"/{m...}", "/a/b", moreGeneral}, + {"/{m...}", "/{$}", moreGeneral}, + {"/{m...}", "/{x}", moreGeneral}, + {"/{m...}", "/", equivalent}, + {"/{m...}", "/{x...}", equivalent}, + + {"/b/{m...}", "/b/a", moreGeneral}, + {"/b/{m...}", "/b/a/b", moreGeneral}, + {"/b/{m...}", "/b/{$}", moreGeneral}, + {"/b/{m...}", "/b/{x}", moreGeneral}, + {"/b/{m...}", "/b/", equivalent}, + {"/b/{m...}", "/b/{x...}", equivalent}, + {"/b/{m...}", "/a/{x...}", disjoint}, + + {"/{z}/{m...}", "/{z}/a", moreGeneral}, + {"/{z}/{m...}", "/{z}/a/b", moreGeneral}, + {"/{z}/{m...}", "/{z}/{$}", moreGeneral}, + {"/{z}/{m...}", "/{z}/{x}", moreGeneral}, + {"/{z}/{m...}", "/{w}/", equivalent}, + {"/{z}/{m...}", "/a/", moreGeneral}, + {"/{z}/{m...}", "/{z}/{x...}", equivalent}, + {"/{z}/{m...}", "/a/{x...}", moreGeneral}, + {"/a/{m...}", "/a/b/{y...}", moreGeneral}, + {"/a/{m...}", "/a/{x}/{y...}", moreGeneral}, + {"/a/{z}/{m...}", "/a/b/{y...}", moreGeneral}, + {"/a/{z}/{m...}", "/{z}/a/", overlaps}, + {"/a/{z}/{m...}", "/{z}/b/{y...}", overlaps}, + {"/a/{z}/b/{m...}", "/{x}/c/{y...}", overlaps}, + {"/a/{z}/a/{m...}", "/{x}/b", disjoint}, + + // Dollar on left. + {"/{$}", "/a", disjoint}, + {"/{$}", "/a/b", disjoint}, + {"/{$}", "/{$}", equivalent}, + {"/{$}", "/{x}", disjoint}, + {"/{$}", "/", moreSpecific}, + {"/{$}", "/{x...}", moreSpecific}, + + {"/b/{$}", "/b", disjoint}, + {"/b/{$}", "/b/a", disjoint}, + {"/b/{$}", "/b/a/b", disjoint}, + {"/b/{$}", "/b/{$}", equivalent}, + {"/b/{$}", "/b/{x}", disjoint}, + {"/b/{$}", "/b/", moreSpecific}, + {"/b/{$}", "/b/{x...}", moreSpecific}, + {"/b/{$}", "/b/c/{x...}", disjoint}, + {"/b/{x}/a/{$}", "/{x}/c/{y...}", overlaps}, + {"/{x}/b/{$}", "/a/{x}/{y}", disjoint}, + {"/{x}/b/{$}", "/a/{x}/c", disjoint}, + + {"/{z}/{$}", "/{z}/a", disjoint}, + {"/{z}/{$}", "/{z}/a/b", disjoint}, + {"/{z}/{$}", "/{z}/{$}", equivalent}, + {"/{z}/{$}", "/{z}/{x}", disjoint}, + {"/{z}/{$}", "/{z}/", moreSpecific}, + {"/{z}/{$}", "/a/", overlaps}, + {"/{z}/{$}", "/a/{x...}", overlaps}, + {"/{z}/{$}", "/{z}/{x...}", moreSpecific}, + {"/a/{z}/{$}", "/{z}/a/", overlaps}, + } { + pat1 := mustParsePattern(t, test.p1) + pat2 := mustParsePattern(t, test.p2) + if g := pat1.comparePaths(pat1); g != equivalent { + t.Errorf("%s does not match itself; got %s", pat1, g) + } + if g := pat2.comparePaths(pat2); g != equivalent { + t.Errorf("%s does not match itself; got %s", pat2, g) + } + got := pat1.comparePaths(pat2) + if got != test.want { + t.Errorf("%s vs %s: got %s, want %s", test.p1, test.p2, got, test.want) + t.Logf("pat1: %+v\n", pat1.segments) + t.Logf("pat2: %+v\n", pat2.segments) + } + want2 := inverseRelationship(test.want) + got2 := pat2.comparePaths(pat1) + if got2 != want2 { + t.Errorf("%s vs %s: got %s, want %s", test.p2, test.p1, got2, want2) + } + } +} + +func TestConflictsWith(t *testing.T) { + for _, test := range []struct { + p1, p2 string + want bool + }{ + {"/a", "/a", true}, + {"/a", "/ab", false}, + {"/a/b/cd", "/a/b/cd", true}, + {"/a/b/cd", "/a/b/c", false}, + {"/a/b/c", "/a/c/c", false}, + {"/{x}", "/{y}", true}, + {"/{x}", "/a", false}, // more specific + {"/{x}/{y}", "/{x}/a", false}, + {"/{x}/{y}", "/{x}/a/b", false}, + {"/{x}", "/a/{y}", false}, + {"/{x}/{y}", "/{x}/a/", false}, + {"/{x}", "/a/{y...}", false}, // more specific + {"/{x}/a/{y}", "/{x}/a/{y...}", false}, // more specific + {"/{x}/{y}", "/{x}/a/{$}", false}, // more specific + {"/{x}/{y}/{$}", "/{x}/a/{$}", false}, + {"/a/{x}", "/{x}/b", true}, + {"/", "GET /", false}, + {"/", "GET /foo", false}, + {"GET /", "GET /foo", false}, + {"GET /", "/foo", true}, + {"GET /foo", "HEAD /", true}, + } { + pat1 := mustParsePattern(t, test.p1) + pat2 := mustParsePattern(t, test.p2) + got := pat1.conflictsWith(pat2) + if got != test.want { + t.Errorf("%q.ConflictsWith(%q) = %t, want %t", + test.p1, test.p2, got, test.want) + } + // conflictsWith should be commutative. + got = pat2.conflictsWith(pat1) + if got != test.want { + t.Errorf("%q.ConflictsWith(%q) = %t, want %t", + test.p2, test.p1, got, test.want) + } + } +} + +func TestRegisterConflict(t *testing.T) { + mux := NewServeMux() + pat1 := "/a/{x}/" + if err := mux.registerErr(pat1, NotFoundHandler()); err != nil { + t.Fatal(err) + } + pat2 := "/a/{y}/{z...}" + err := mux.registerErr(pat2, NotFoundHandler()) + var got string + if err == nil { + got = "<nil>" + } else { + got = err.Error() + } + want := "matches the same requests as" + if !strings.Contains(got, want) { + t.Errorf("got\n%s\nwant\n%s", got, want) + } +} + +func TestDescribeConflict(t *testing.T) { + for _, test := range []struct { + p1, p2 string + want string + }{ + {"/a/{x}", "/a/{y}", "the same requests"}, + {"/", "/{m...}", "the same requests"}, + {"/a/{x}", "/{y}/b", "both match some paths"}, + {"/a", "GET /{x}", "matches more methods than GET /{x}, but has a more specific path pattern"}, + {"GET /a", "HEAD /", "matches more methods than HEAD /, but has a more specific path pattern"}, + {"POST /", "/a", "matches fewer methods than /a, but has a more general path pattern"}, + } { + got := describeConflict(mustParsePattern(t, test.p1), mustParsePattern(t, test.p2)) + if !strings.Contains(got, test.want) { + t.Errorf("%s vs. %s:\ngot:\n%s\nwhich does not contain %q", + test.p1, test.p2, got, test.want) + } + } +} + +func TestCommonPath(t *testing.T) { + for _, test := range []struct { + p1, p2 string + want string + }{ + {"/a/{x}", "/{x}/a", "/a/a"}, + {"/a/{z}/", "/{z}/a/", "/a/a/"}, + {"/a/{z}/{m...}", "/{z}/a/", "/a/a/"}, + {"/{z}/{$}", "/a/", "/a/"}, + {"/{z}/{$}", "/a/{x...}", "/a/"}, + {"/a/{z}/{$}", "/{z}/a/", "/a/a/"}, + {"/a/{x}/b/{y...}", "/{x}/c/{y...}", "/a/c/b/"}, + {"/a/{x}/b/", "/{x}/c/{y...}", "/a/c/b/"}, + {"/a/{x}/b/{$}", "/{x}/c/{y...}", "/a/c/b/"}, + {"/a/{z}/{x...}", "/{z}/b/{y...}", "/a/b/"}, + } { + pat1 := mustParsePattern(t, test.p1) + pat2 := mustParsePattern(t, test.p2) + if pat1.comparePaths(pat2) != overlaps { + t.Fatalf("%s does not overlap %s", test.p1, test.p2) + } + got := commonPath(pat1, pat2) + if got != test.want { + t.Errorf("%s vs. %s: got %q, want %q", test.p1, test.p2, got, test.want) + } + } +} + +func TestDifferencePath(t *testing.T) { + for _, test := range []struct { + p1, p2 string + want string + }{ + {"/a/{x}", "/{x}/a", "/a/x"}, + {"/{x}/a", "/a/{x}", "/x/a"}, + {"/a/{z}/", "/{z}/a/", "/a/z/"}, + {"/{z}/a/", "/a/{z}/", "/z/a/"}, + {"/{a}/a/", "/a/{z}/", "/ax/a/"}, + {"/a/{z}/{x...}", "/{z}/b/{y...}", "/a/z/"}, + {"/{z}/b/{y...}", "/a/{z}/{x...}", "/z/b/"}, + {"/a/b/", "/a/b/c", "/a/b/"}, + {"/a/b/{x...}", "/a/b/c", "/a/b/"}, + {"/a/b/{x...}", "/a/b/c/d", "/a/b/"}, + {"/a/b/{x...}", "/a/b/c/d/", "/a/b/"}, + {"/a/{z}/{m...}", "/{z}/a/", "/a/z/"}, + {"/{z}/a/", "/a/{z}/{m...}", "/z/a/"}, + {"/{z}/{$}", "/a/", "/z/"}, + {"/a/", "/{z}/{$}", "/a/x"}, + {"/{z}/{$}", "/a/{x...}", "/z/"}, + {"/a/{foo...}", "/{z}/{$}", "/a/foo"}, + {"/a/{z}/{$}", "/{z}/a/", "/a/z/"}, + {"/{z}/a/", "/a/{z}/{$}", "/z/a/x"}, + {"/a/{x}/b/{y...}", "/{x}/c/{y...}", "/a/x/b/"}, + {"/{x}/c/{y...}", "/a/{x}/b/{y...}", "/x/c/"}, + {"/a/{c}/b/", "/{x}/c/{y...}", "/a/cx/b/"}, + {"/{x}/c/{y...}", "/a/{c}/b/", "/x/c/"}, + {"/a/{x}/b/{$}", "/{x}/c/{y...}", "/a/x/b/"}, + {"/{x}/c/{y...}", "/a/{x}/b/{$}", "/x/c/"}, + } { + pat1 := mustParsePattern(t, test.p1) + pat2 := mustParsePattern(t, test.p2) + rel := pat1.comparePaths(pat2) + if rel != overlaps && rel != moreGeneral { + t.Fatalf("%s vs. %s are %s, need overlaps or moreGeneral", pat1, pat2, rel) + } + got := differencePath(pat1, pat2) + if got != test.want { + t.Errorf("%s vs. %s: got %q, want %q", test.p1, test.p2, got, test.want) + } + } +} diff --git a/src/net/http/pprof/pprof.go b/src/net/http/pprof/pprof.go index bc3225daca..bc48f11834 100644 --- a/src/net/http/pprof/pprof.go +++ b/src/net/http/pprof/pprof.go @@ -47,12 +47,12 @@ // go tool pprof http://localhost:6060/debug/pprof/profile?seconds=30 // // Or to look at the goroutine blocking profile, after calling -// runtime.SetBlockProfileRate in your program: +// [runtime.SetBlockProfileRate] in your program: // // go tool pprof http://localhost:6060/debug/pprof/block // // Or to look at the holders of contended mutexes, after calling -// runtime.SetMutexProfileFraction in your program: +// [runtime.SetMutexProfileFraction] in your program: // // go tool pprof http://localhost:6060/debug/pprof/mutex // diff --git a/src/net/http/pprof/pprof_test.go b/src/net/http/pprof/pprof_test.go index f82ad45bf6..24ad59ab39 100644 --- a/src/net/http/pprof/pprof_test.go +++ b/src/net/http/pprof/pprof_test.go @@ -6,12 +6,14 @@ package pprof import ( "bytes" + "encoding/base64" "fmt" "internal/profile" "internal/testenv" "io" "net/http" "net/http/httptest" + "path/filepath" "runtime" "runtime/pprof" "strings" @@ -261,3 +263,64 @@ func seen(p *profile.Profile, fname string) bool { } return false } + +// TestDeltaProfileEmptyBase validates that we still receive a valid delta +// profile even if the base contains no samples. +// +// Regression test for https://go.dev/issue/64566. +func TestDeltaProfileEmptyBase(t *testing.T) { + if testing.Short() { + // Delta profile collection has a 1s minimum. + t.Skip("skipping in -short mode") + } + + testenv.MustHaveGoRun(t) + + gotool, err := testenv.GoTool() + if err != nil { + t.Fatalf("error finding go tool: %v", err) + } + + out, err := testenv.Command(t, gotool, "run", filepath.Join("testdata", "delta_mutex.go")).CombinedOutput() + if err != nil { + t.Fatalf("error running profile collection: %v\noutput: %s", err, out) + } + + // Log the binary output for debugging failures. + b64 := make([]byte, base64.StdEncoding.EncodedLen(len(out))) + base64.StdEncoding.Encode(b64, out) + t.Logf("Output in base64.StdEncoding: %s", b64) + + p, err := profile.Parse(bytes.NewReader(out)) + if err != nil { + t.Fatalf("Parse got err %v want nil", err) + } + + t.Logf("Output as parsed Profile: %s", p) + + if len(p.SampleType) != 2 { + t.Errorf("len(p.SampleType) got %d want 2", len(p.SampleType)) + } + if p.SampleType[0].Type != "contentions" { + t.Errorf(`p.SampleType[0].Type got %q want "contentions"`, p.SampleType[0].Type) + } + if p.SampleType[0].Unit != "count" { + t.Errorf(`p.SampleType[0].Unit got %q want "count"`, p.SampleType[0].Unit) + } + if p.SampleType[1].Type != "delay" { + t.Errorf(`p.SampleType[1].Type got %q want "delay"`, p.SampleType[1].Type) + } + if p.SampleType[1].Unit != "nanoseconds" { + t.Errorf(`p.SampleType[1].Unit got %q want "nanoseconds"`, p.SampleType[1].Unit) + } + + if p.PeriodType == nil { + t.Fatal("p.PeriodType got nil want not nil") + } + if p.PeriodType.Type != "contentions" { + t.Errorf(`p.PeriodType.Type got %q want "contentions"`, p.PeriodType.Type) + } + if p.PeriodType.Unit != "count" { + t.Errorf(`p.PeriodType.Unit got %q want "count"`, p.PeriodType.Unit) + } +} diff --git a/src/net/http/pprof/testdata/delta_mutex.go b/src/net/http/pprof/testdata/delta_mutex.go new file mode 100644 index 0000000000..634069c8a0 --- /dev/null +++ b/src/net/http/pprof/testdata/delta_mutex.go @@ -0,0 +1,43 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This binary collects a 1s delta mutex profile and dumps it to os.Stdout. +// +// This is in a subprocess because we want the base mutex profile to be empty +// (as a regression test for https://go.dev/issue/64566) and the only way to +// force reset the profile is to create a new subprocess. +// +// This manually collects the HTTP response and dumps to stdout in order to +// avoid any flakiness around port selection for a real HTTP server. +package main + +import ( + "bytes" + "fmt" + "log" + "net/http" + "net/http/pprof" + "net/http/httptest" + "runtime" +) + +func main() { + // Disable the mutex profiler. This is the default, but that default is + // load-bearing for this test, which needs the base profile to be empty. + runtime.SetMutexProfileFraction(0) + + h := pprof.Handler("mutex") + + req := httptest.NewRequest("GET", "/debug/pprof/mutex?seconds=1", nil) + rec := httptest.NewRecorder() + rec.Body = new(bytes.Buffer) + + h.ServeHTTP(rec, req) + resp := rec.Result() + if resp.StatusCode != http.StatusOK { + log.Fatalf("Request failed: %s\n%s", resp.Status, rec.Body) + } + + fmt.Print(rec.Body) +} diff --git a/src/net/http/request.go b/src/net/http/request.go index 81f79566a5..99fdebcf9b 100644 --- a/src/net/http/request.go +++ b/src/net/http/request.go @@ -107,14 +107,10 @@ var reqWriteExcludeHeader = map[string]bool{ // // The field semantics differ slightly between client and server // usage. In addition to the notes on the fields below, see the -// documentation for Request.Write and RoundTripper. +// documentation for [Request.Write] and [RoundTripper]. type Request struct { // Method specifies the HTTP method (GET, POST, PUT, etc.). // For client requests, an empty string means GET. - // - // Go's HTTP client does not support sending a request with - // the CONNECT method. See the documentation on Transport for - // details. Method string // URL specifies either the URI being requested (for server @@ -329,10 +325,15 @@ type Request struct { // It is unexported to prevent people from using Context wrong // and mutating the contexts held by callers of the same request. ctx context.Context + + // The following fields are for requests matched by ServeMux. + pat *pattern // the pattern that matched + matches []string // values for the matching wildcards in pat + otherValues map[string]string // for calls to SetPathValue that don't match a wildcard } // Context returns the request's context. To change the context, use -// Clone or WithContext. +// [Request.Clone] or [Request.WithContext]. // // The returned context is always non-nil; it defaults to the // background context. @@ -356,8 +357,8 @@ func (r *Request) Context() context.Context { // lifetime of a request and its response: obtaining a connection, // sending the request, and reading the response headers and body. // -// To create a new request with a context, use NewRequestWithContext. -// To make a deep copy of a request with a new context, use Request.Clone. +// To create a new request with a context, use [NewRequestWithContext]. +// To make a deep copy of a request with a new context, use [Request.Clone]. func (r *Request) WithContext(ctx context.Context) *Request { if ctx == nil { panic("nil context") @@ -396,6 +397,20 @@ func (r *Request) Clone(ctx context.Context) *Request { r2.Form = cloneURLValues(r.Form) r2.PostForm = cloneURLValues(r.PostForm) r2.MultipartForm = cloneMultipartForm(r.MultipartForm) + + // Copy matches and otherValues. See issue 61410. + if s := r.matches; s != nil { + s2 := make([]string, len(s)) + copy(s2, s) + r2.matches = s2 + } + if s := r.otherValues; s != nil { + s2 := make(map[string]string, len(s)) + for k, v := range s { + s2[k] = v + } + r2.otherValues = s2 + } return r2 } @@ -420,7 +435,7 @@ func (r *Request) Cookies() []*Cookie { var ErrNoCookie = errors.New("http: named cookie not present") // Cookie returns the named cookie provided in the request or -// ErrNoCookie if not found. +// [ErrNoCookie] if not found. // If multiple cookies match the given name, only one cookie will // be returned. func (r *Request) Cookie(name string) (*Cookie, error) { @@ -434,7 +449,7 @@ func (r *Request) Cookie(name string) (*Cookie, error) { } // AddCookie adds a cookie to the request. Per RFC 6265 section 5.4, -// AddCookie does not attach more than one Cookie header field. That +// AddCookie does not attach more than one [Cookie] header field. That // means all cookies, if any, are written into the same line, // separated by semicolon. // AddCookie only sanitizes c's name and value, and does not sanitize @@ -452,7 +467,7 @@ func (r *Request) AddCookie(c *Cookie) { // // Referer is misspelled as in the request itself, a mistake from the // earliest days of HTTP. This value can also be fetched from the -// Header map as Header["Referer"]; the benefit of making it available +// [Header] map as Header["Referer"]; the benefit of making it available // as a method is that the compiler can diagnose programs that use the // alternate (correct English) spelling req.Referrer() but cannot // diagnose programs that use Header["Referrer"]. @@ -470,7 +485,7 @@ var multipartByReader = &multipart.Form{ // MultipartReader returns a MIME multipart reader if this is a // multipart/form-data or a multipart/mixed POST request, else returns nil and an error. -// Use this function instead of ParseMultipartForm to +// Use this function instead of [Request.ParseMultipartForm] to // process the request body as a stream. func (r *Request) MultipartReader() (*multipart.Reader, error) { if r.MultipartForm == multipartByReader { @@ -533,15 +548,15 @@ const defaultUserAgent = "Go-http-client/1.1" // TransferEncoding // Body // -// If Body is present, Content-Length is <= 0 and TransferEncoding +// If Body is present, Content-Length is <= 0 and [Request.TransferEncoding] // hasn't been set to "identity", Write adds "Transfer-Encoding: // chunked" to the header. Body is closed after it is sent. func (r *Request) Write(w io.Writer) error { return r.write(w, false, nil, nil) } -// WriteProxy is like Write but writes the request in the form -// expected by an HTTP proxy. In particular, WriteProxy writes the +// WriteProxy is like [Request.Write] but writes the request in the form +// expected by an HTTP proxy. In particular, [Request.WriteProxy] writes the // initial Request-URI line of the request with an absolute URI, per // section 5.3 of RFC 7230, including the scheme and host. // In either case, WriteProxy also writes a Host header, using @@ -669,6 +684,8 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF userAgent = r.Header.Get("User-Agent") } if userAgent != "" { + userAgent = headerNewlineToSpace.Replace(userAgent) + userAgent = textproto.TrimString(userAgent) _, err = fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent) if err != nil { return err @@ -834,32 +851,33 @@ func validMethod(method string) bool { return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 } -// NewRequest wraps NewRequestWithContext using context.Background. +// NewRequest wraps [NewRequestWithContext] using [context.Background]. func NewRequest(method, url string, body io.Reader) (*Request, error) { return NewRequestWithContext(context.Background(), method, url, body) } -// NewRequestWithContext returns a new Request given a method, URL, and +// NewRequestWithContext returns a new [Request] given a method, URL, and // optional body. // -// If the provided body is also an io.Closer, the returned -// Request.Body is set to body and will be closed by the Client -// methods Do, Post, and PostForm, and Transport.RoundTrip. +// If the provided body is also an [io.Closer], the returned +// [Request.Body] is set to body and will be closed (possibly +// asynchronously) by the Client methods Do, Post, and PostForm, +// and [Transport.RoundTrip]. // // NewRequestWithContext returns a Request suitable for use with -// Client.Do or Transport.RoundTrip. To create a request for use with -// testing a Server Handler, either use the NewRequest function in the -// net/http/httptest package, use ReadRequest, or manually update the +// [Client.Do] or [Transport.RoundTrip]. To create a request for use with +// testing a Server Handler, either use the [NewRequest] function in the +// net/http/httptest package, use [ReadRequest], or manually update the // Request fields. For an outgoing client request, the context // controls the entire lifetime of a request and its response: // obtaining a connection, sending the request, and reading the // response headers and body. See the Request type's documentation for // the difference between inbound and outbound request fields. // -// If body is of type *bytes.Buffer, *bytes.Reader, or -// *strings.Reader, the returned request's ContentLength is set to its +// If body is of type [*bytes.Buffer], [*bytes.Reader], or +// [*strings.Reader], the returned request's ContentLength is set to its // exact value (instead of -1), GetBody is populated (so 307 and 308 -// redirects can replay the body), and Body is set to NoBody if the +// redirects can replay the body), and Body is set to [NoBody] if the // ContentLength is 0. func NewRequestWithContext(ctx context.Context, method, url string, body io.Reader) (*Request, error) { if method == "" { @@ -983,7 +1001,7 @@ func parseBasicAuth(auth string) (username, password string, ok bool) { // The username may not contain a colon. Some protocols may impose // additional requirements on pre-escaping the username and // password. For instance, when used with OAuth2, both arguments must -// be URL encoded first with url.QueryEscape. +// be URL encoded first with [url.QueryEscape]. func (r *Request) SetBasicAuth(username, password string) { r.Header.Set("Authorization", "Basic "+basicAuth(username, password)) } @@ -1017,8 +1035,8 @@ func putTextprotoReader(r *textproto.Reader) { // ReadRequest reads and parses an incoming request from b. // // ReadRequest is a low-level function and should only be used for -// specialized applications; most code should use the Server to read -// requests and handle them via the Handler interface. ReadRequest +// specialized applications; most code should use the [Server] to read +// requests and handle them via the [Handler] interface. ReadRequest // only supports HTTP/1.x requests. For HTTP/2, use golang.org/x/net/http2. func ReadRequest(b *bufio.Reader) (*Request, error) { req, err := readRequest(b) @@ -1127,15 +1145,15 @@ func readRequest(b *bufio.Reader) (req *Request, err error) { return req, nil } -// MaxBytesReader is similar to io.LimitReader but is intended for +// MaxBytesReader is similar to [io.LimitReader] but is intended for // limiting the size of incoming request bodies. In contrast to // io.LimitReader, MaxBytesReader's result is a ReadCloser, returns a -// non-nil error of type *MaxBytesError for a Read beyond the limit, +// non-nil error of type [*MaxBytesError] for a Read beyond the limit, // and closes the underlying reader when its Close method is called. // // MaxBytesReader prevents clients from accidentally or maliciously // sending a large request and wasting server resources. If possible, -// it tells the ResponseWriter to close the connection after the limit +// it tells the [ResponseWriter] to close the connection after the limit // has been reached. func MaxBytesReader(w ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser { if n < 0 { // Treat negative limits as equivalent to 0. @@ -1144,7 +1162,7 @@ func MaxBytesReader(w ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser { return &maxBytesReader{w: w, r: r, i: n, n: n} } -// MaxBytesError is returned by MaxBytesReader when its read limit is exceeded. +// MaxBytesError is returned by [MaxBytesReader] when its read limit is exceeded. type MaxBytesError struct { Limit int64 } @@ -1269,14 +1287,14 @@ func parsePostForm(r *Request) (vs url.Values, err error) { // as a form and puts the results into both r.PostForm and r.Form. Request body // parameters take precedence over URL query string values in r.Form. // -// If the request Body's size has not already been limited by MaxBytesReader, +// If the request Body's size has not already been limited by [MaxBytesReader], // the size is capped at 10MB. // // For other HTTP methods, or when the Content-Type is not // application/x-www-form-urlencoded, the request Body is not read, and // r.PostForm is initialized to a non-nil, empty value. // -// ParseMultipartForm calls ParseForm automatically. +// [Request.ParseMultipartForm] calls ParseForm automatically. // ParseForm is idempotent. func (r *Request) ParseForm() error { var err error @@ -1317,7 +1335,7 @@ func (r *Request) ParseForm() error { // The whole request body is parsed and up to a total of maxMemory bytes of // its file parts are stored in memory, with the remainder stored on // disk in temporary files. -// ParseMultipartForm calls ParseForm if necessary. +// ParseMultipartForm calls [Request.ParseForm] if necessary. // If ParseForm returns an error, ParseMultipartForm returns it but also // continues parsing the request body. // After one call to ParseMultipartForm, subsequent calls have no effect. @@ -1360,12 +1378,16 @@ func (r *Request) ParseMultipartForm(maxMemory int64) error { } // FormValue returns the first value for the named component of the query. -// POST and PUT body parameters take precedence over URL query string values. -// FormValue calls ParseMultipartForm and ParseForm if necessary and ignores -// any errors returned by these functions. +// The precedence order: +// 1. application/x-www-form-urlencoded form body (POST, PUT, PATCH only) +// 2. query parameters (always) +// 3. multipart/form-data form body (always) +// +// FormValue calls [Request.ParseMultipartForm] and [Request.ParseForm] +// if necessary and ignores any errors returned by these functions. // If key is not present, FormValue returns the empty string. // To access multiple values of the same key, call ParseForm and -// then inspect Request.Form directly. +// then inspect [Request.Form] directly. func (r *Request) FormValue(key string) string { if r.Form == nil { r.ParseMultipartForm(defaultMaxMemory) @@ -1377,8 +1399,8 @@ func (r *Request) FormValue(key string) string { } // PostFormValue returns the first value for the named component of the POST, -// PATCH, or PUT request body. URL query parameters are ignored. -// PostFormValue calls ParseMultipartForm and ParseForm if necessary and ignores +// PUT, or PATCH request body. URL query parameters are ignored. +// PostFormValue calls [Request.ParseMultipartForm] and [Request.ParseForm] if necessary and ignores // any errors returned by these functions. // If key is not present, PostFormValue returns the empty string. func (r *Request) PostFormValue(key string) string { @@ -1392,7 +1414,7 @@ func (r *Request) PostFormValue(key string) string { } // FormFile returns the first file for the provided form key. -// FormFile calls ParseMultipartForm and ParseForm if necessary. +// FormFile calls [Request.ParseMultipartForm] and [Request.ParseForm] if necessary. func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, error) { if r.MultipartForm == multipartByReader { return nil, nil, errors.New("http: multipart handled by MultipartReader") @@ -1412,6 +1434,50 @@ func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, e return nil, nil, ErrMissingFile } +// PathValue returns the value for the named path wildcard in the [ServeMux] pattern +// that matched the request. +// It returns the empty string if the request was not matched against a pattern +// or there is no such wildcard in the pattern. +func (r *Request) PathValue(name string) string { + if i := r.patIndex(name); i >= 0 { + return r.matches[i] + } + return r.otherValues[name] +} + +// SetPathValue sets name to value, so that subsequent calls to r.PathValue(name) +// return value. +func (r *Request) SetPathValue(name, value string) { + if i := r.patIndex(name); i >= 0 { + r.matches[i] = value + } else { + if r.otherValues == nil { + r.otherValues = map[string]string{} + } + r.otherValues[name] = value + } +} + +// patIndex returns the index of name in the list of named wildcards of the +// request's pattern, or -1 if there is no such name. +func (r *Request) patIndex(name string) int { + // The linear search seems expensive compared to a map, but just creating the map + // takes a lot of time, and most patterns will just have a couple of wildcards. + if r.pat == nil { + return -1 + } + i := 0 + for _, seg := range r.pat.segments { + if seg.wild && seg.s != "" { + if name == seg.s { + return i + } + i++ + } + } + return -1 +} + func (r *Request) expectsContinue() bool { return hasToken(r.Header.get("Expect"), "100-continue") } diff --git a/src/net/http/request_test.go b/src/net/http/request_test.go index a32b583c11..6ce32332e7 100644 --- a/src/net/http/request_test.go +++ b/src/net/http/request_test.go @@ -15,7 +15,9 @@ import ( "io" "math" "mime/multipart" + "net/http" . "net/http" + "net/http/httptest" "net/url" "os" "reflect" @@ -787,6 +789,25 @@ func TestRequestBadHostHeader(t *testing.T) { } } +func TestRequestBadUserAgent(t *testing.T) { + got := []string{} + req, err := NewRequest("GET", "http://foo/after", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("User-Agent", "evil\r\nX-Evil: evil") + req.Write(logWrites{t, &got}) + want := []string{ + "GET /after HTTP/1.1\r\n", + "Host: foo\r\n", + "User-Agent: evil X-Evil: evil\r\n", + "\r\n", + } + if !reflect.DeepEqual(got, want) { + t.Errorf("Writes = %q\n Want = %q", got, want) + } +} + func TestStarRequest(t *testing.T) { req, err := ReadRequest(bufio.NewReader(strings.NewReader("M-SEARCH * HTTP/1.1\r\n\r\n"))) if err != nil { @@ -1032,6 +1053,33 @@ func TestRequestCloneTransferEncoding(t *testing.T) { } } +// Ensure that Request.Clone works correctly with PathValue. +// See issue 64911. +func TestRequestClonePathValue(t *testing.T) { + req, _ := http.NewRequest("GET", "https://example.org/", nil) + req.SetPathValue("p1", "orig") + + clonedReq := req.Clone(context.Background()) + clonedReq.SetPathValue("p2", "copy") + + // Ensure that any modifications to the cloned + // request do not pollute the original request. + if g, w := req.PathValue("p2"), ""; g != w { + t.Fatalf("p2 mismatch got %q, want %q", g, w) + } + if g, w := req.PathValue("p1"), "orig"; g != w { + t.Fatalf("p1 mismatch got %q, want %q", g, w) + } + + // Assert on the changes to the cloned request. + if g, w := clonedReq.PathValue("p1"), "orig"; g != w { + t.Fatalf("p1 mismatch got %q, want %q", g, w) + } + if g, w := clonedReq.PathValue("p2"), "copy"; g != w { + t.Fatalf("p2 mismatch got %q, want %q", g, w) + } +} + // Issue 34878: verify we don't panic when including basic auth (Go 1.13 regression) func TestNoPanicOnRoundTripWithBasicAuth(t *testing.T) { run(t, testNoPanicWithBasicAuth) } func testNoPanicWithBasicAuth(t *testing.T, mode testMode) { @@ -1395,3 +1443,141 @@ func TestErrNotSupported(t *testing.T) { t.Error("errors.Is(ErrNotSupported, errors.ErrUnsupported) failed") } } + +func TestPathValueNoMatch(t *testing.T) { + // Check that PathValue and SetPathValue work on a Request that was never matched. + var r Request + if g, w := r.PathValue("x"), ""; g != w { + t.Errorf("got %q, want %q", g, w) + } + r.SetPathValue("x", "a") + if g, w := r.PathValue("x"), "a"; g != w { + t.Errorf("got %q, want %q", g, w) + } +} + +func TestPathValue(t *testing.T) { + for _, test := range []struct { + pattern string + url string + want map[string]string + }{ + { + "/{a}/is/{b}/{c...}", + "/now/is/the/time/for/all", + map[string]string{ + "a": "now", + "b": "the", + "c": "time/for/all", + "d": "", + }, + }, + { + "/names/{name}/{other...}", + "/names/%2fjohn/address", + map[string]string{ + "name": "/john", + "other": "address", + }, + }, + { + "/names/{name}/{other...}", + "/names/john%2Fdoe/there/is%2F/more", + map[string]string{ + "name": "john/doe", + "other": "there/is//more", + }, + }, + } { + mux := NewServeMux() + mux.HandleFunc(test.pattern, func(w ResponseWriter, r *Request) { + for name, want := range test.want { + got := r.PathValue(name) + if got != want { + t.Errorf("%q, %q: got %q, want %q", test.pattern, name, got, want) + } + } + }) + server := httptest.NewServer(mux) + defer server.Close() + res, err := Get(server.URL + test.url) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + } +} + +func TestSetPathValue(t *testing.T) { + mux := NewServeMux() + mux.HandleFunc("/a/{b}/c/{d...}", func(_ ResponseWriter, r *Request) { + kvs := map[string]string{ + "b": "X", + "d": "Y", + "a": "Z", + } + for k, v := range kvs { + r.SetPathValue(k, v) + } + for k, w := range kvs { + if g := r.PathValue(k); g != w { + t.Errorf("got %q, want %q", g, w) + } + } + }) + server := httptest.NewServer(mux) + defer server.Close() + res, err := Get(server.URL + "/a/b/c/d/e") + if err != nil { + t.Fatal(err) + } + res.Body.Close() +} + +func TestStatus(t *testing.T) { + // The main purpose of this test is to check 405 responses and the Allow header. + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + mux := NewServeMux() + mux.Handle("GET /g", h) + mux.Handle("POST /p", h) + mux.Handle("PATCH /p", h) + mux.Handle("PUT /r", h) + mux.Handle("GET /r/", h) + server := httptest.NewServer(mux) + defer server.Close() + + for _, test := range []struct { + method, path string + wantStatus int + wantAllow string + }{ + {"GET", "/g", 200, ""}, + {"HEAD", "/g", 200, ""}, + {"POST", "/g", 405, "GET, HEAD"}, + {"GET", "/x", 404, ""}, + {"GET", "/p", 405, "PATCH, POST"}, + {"GET", "/./p", 405, "PATCH, POST"}, + {"GET", "/r/", 200, ""}, + {"GET", "/r", 200, ""}, // redirected + {"HEAD", "/r/", 200, ""}, + {"HEAD", "/r", 200, ""}, // redirected + {"PUT", "/r/", 405, "GET, HEAD"}, + {"PUT", "/r", 200, ""}, + } { + req, err := http.NewRequest(test.method, server.URL+test.path, nil) + if err != nil { + t.Fatal(err) + } + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if g, w := res.StatusCode, test.wantStatus; g != w { + t.Errorf("%s %s: got %d, want %d", test.method, test.path, g, w) + } + if g, w := res.Header.Get("Allow"), test.wantAllow; g != w { + t.Errorf("%s %s, Allow: got %q, want %q", test.method, test.path, g, w) + } + } +} diff --git a/src/net/http/response.go b/src/net/http/response.go index 755c696557..0c3d7f6d85 100644 --- a/src/net/http/response.go +++ b/src/net/http/response.go @@ -29,7 +29,7 @@ var respExcludeHeader = map[string]bool{ // Response represents the response from an HTTP request. // -// The Client and Transport return Responses from servers once +// The [Client] and [Transport] return Responses from servers once // the response headers have been received. The response body // is streamed on demand as the Body field is read. type Response struct { @@ -126,13 +126,13 @@ func (r *Response) Cookies() []*Cookie { return readSetCookies(r.Header) } -// ErrNoLocation is returned by Response's Location method +// ErrNoLocation is returned by the [Response.Location] method // when no Location header is present. var ErrNoLocation = errors.New("http: no Location header in response") // Location returns the URL of the response's "Location" header, // if present. Relative redirects are resolved relative to -// the Response's Request. ErrNoLocation is returned if no +// [Response.Request]. [ErrNoLocation] is returned if no // Location header is present. func (r *Response) Location() (*url.URL, error) { lv := r.Header.Get("Location") @@ -146,8 +146,8 @@ func (r *Response) Location() (*url.URL, error) { } // ReadResponse reads and returns an HTTP response from r. -// The req parameter optionally specifies the Request that corresponds -// to this Response. If nil, a GET request is assumed. +// The req parameter optionally specifies the [Request] that corresponds +// to this [Response]. If nil, a GET request is assumed. // Clients must call resp.Body.Close when finished reading resp.Body. // After that call, clients can inspect resp.Trailer to find key/value // pairs included in the response trailer. diff --git a/src/net/http/response_test.go b/src/net/http/response_test.go index 19fb48f23c..f3425c3c20 100644 --- a/src/net/http/response_test.go +++ b/src/net/http/response_test.go @@ -849,7 +849,7 @@ func TestReadResponseErrors(t *testing.T) { type testCase struct { name string // optional, defaults to in in string - wantErr any // nil, err value, or string substring + wantErr any // nil, err value, bool value, or string substring } status := func(s string, wantErr any) testCase { @@ -883,6 +883,7 @@ func TestReadResponseErrors(t *testing.T) { } errMultiCL := "message cannot contain multiple Content-Length headers" + errEmptyCL := "invalid empty Content-Length" tests := []testCase{ {"", "", io.ErrUnexpectedEOF}, @@ -918,7 +919,7 @@ func TestReadResponseErrors(t *testing.T) { contentLength("200 OK", "Content-Length: 7\r\nContent-Length: 7\r\n\r\nGophers\r\n", nil), contentLength("201 OK", "Content-Length: 0\r\nContent-Length: 7\r\n\r\nGophers\r\n", errMultiCL), contentLength("300 OK", "Content-Length: 0\r\nContent-Length: 0 \r\n\r\nGophers\r\n", nil), - contentLength("200 OK", "Content-Length:\r\nContent-Length:\r\n\r\nGophers\r\n", nil), + contentLength("200 OK", "Content-Length:\r\nContent-Length:\r\n\r\nGophers\r\n", errEmptyCL), contentLength("206 OK", "Content-Length:\r\nContent-Length: 0 \r\nConnection: close\r\n\r\nGophers\r\n", errMultiCL), // multiple content-length headers for 204 and 304 should still be checked diff --git a/src/net/http/responsecontroller.go b/src/net/http/responsecontroller.go index 92276ffaf2..f3f24c1273 100644 --- a/src/net/http/responsecontroller.go +++ b/src/net/http/responsecontroller.go @@ -13,14 +13,14 @@ import ( // A ResponseController is used by an HTTP handler to control the response. // -// A ResponseController may not be used after the Handler.ServeHTTP method has returned. +// A ResponseController may not be used after the [Handler.ServeHTTP] method has returned. type ResponseController struct { rw ResponseWriter } -// NewResponseController creates a ResponseController for a request. +// NewResponseController creates a [ResponseController] for a request. // -// The ResponseWriter should be the original value passed to the Handler.ServeHTTP method, +// The ResponseWriter should be the original value passed to the [Handler.ServeHTTP] method, // or have an Unwrap method returning the original ResponseWriter. // // If the ResponseWriter implements any of the following methods, the ResponseController @@ -34,7 +34,7 @@ type ResponseController struct { // EnableFullDuplex() error // // If the ResponseWriter does not support a method, ResponseController returns -// an error matching ErrNotSupported. +// an error matching [ErrNotSupported]. func NewResponseController(rw ResponseWriter) *ResponseController { return &ResponseController{rw} } @@ -116,8 +116,8 @@ func (c *ResponseController) SetWriteDeadline(deadline time.Time) error { } } -// EnableFullDuplex indicates that the request handler will interleave reads from Request.Body -// with writes to the ResponseWriter. +// EnableFullDuplex indicates that the request handler will interleave reads from [Request.Body] +// with writes to the [ResponseWriter]. // // For HTTP/1 requests, the Go HTTP server by default consumes any unread portion of // the request body before beginning to write the response, preventing handlers from diff --git a/src/net/http/responsecontroller_test.go b/src/net/http/responsecontroller_test.go index 5828f3795a..f1dcc79ef8 100644 --- a/src/net/http/responsecontroller_test.go +++ b/src/net/http/responsecontroller_test.go @@ -1,3 +1,7 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package http_test import ( diff --git a/src/net/http/roundtrip.go b/src/net/http/roundtrip.go index 49ea1a71ed..08c270179a 100644 --- a/src/net/http/roundtrip.go +++ b/src/net/http/roundtrip.go @@ -6,10 +6,10 @@ package http -// RoundTrip implements the RoundTripper interface. +// RoundTrip implements the [RoundTripper] interface. // // For higher-level HTTP client support (such as handling of cookies -// and redirects), see Get, Post, and the Client type. +// and redirects), see [Get], [Post], and the [Client] type. // // Like the RoundTripper interface, the error types returned // by RoundTrip are unspecified. diff --git a/src/net/http/roundtrip_js.go b/src/net/http/roundtrip_js.go index 9f9f0cb67d..04c241eb4c 100644 --- a/src/net/http/roundtrip_js.go +++ b/src/net/http/roundtrip_js.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "io" + "net/http/internal/ascii" "strconv" "strings" "syscall/js" @@ -55,7 +56,7 @@ var jsFetchMissing = js.Global().Get("fetch").IsUndefined() var jsFetchDisabled = js.Global().Get("process").Type() == js.TypeObject && strings.HasPrefix(js.Global().Get("process").Get("argv0").String(), "node") -// RoundTrip implements the RoundTripper interface using the WHATWG Fetch API. +// RoundTrip implements the [RoundTripper] interface using the WHATWG Fetch API. func (t *Transport) RoundTrip(req *Request) (*Response, error) { // The Transport has a documented contract that states that if the DialContext or // DialTLSContext functions are set, they will be used to set up the connections. @@ -184,11 +185,22 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { } code := result.Get("status").Int() + + uncompressed := false + if ascii.EqualFold(header.Get("Content-Encoding"), "gzip") { + // The fetch api will decode the gzip, but Content-Encoding not be deleted. + header.Del("Content-Encoding") + header.Del("Content-Length") + contentLength = -1 + uncompressed = true + } + respCh <- &Response{ Status: fmt.Sprintf("%d %s", code, StatusText(code)), StatusCode: code, Header: header, ContentLength: contentLength, + Uncompressed: uncompressed, Body: body, Request: req, } diff --git a/src/net/http/routing_index.go b/src/net/http/routing_index.go new file mode 100644 index 0000000000..9ac42c997d --- /dev/null +++ b/src/net/http/routing_index.go @@ -0,0 +1,124 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import "math" + +// A routingIndex optimizes conflict detection by indexing patterns. +// +// The basic idea is to rule out patterns that cannot conflict with a given +// pattern because they have a different literal in a corresponding segment. +// See the comments in [routingIndex.possiblyConflictingPatterns] for more details. +type routingIndex struct { + // map from a particular segment position and value to all registered patterns + // with that value in that position. + // For example, the key {1, "b"} would hold the patterns "/a/b" and "/a/b/c" + // but not "/a", "b/a", "/a/c" or "/a/{x}". + segments map[routingIndexKey][]*pattern + // All patterns that end in a multi wildcard (including trailing slash). + // We do not try to be clever about indexing multi patterns, because there + // are unlikely to be many of them. + multis []*pattern +} + +type routingIndexKey struct { + pos int // 0-based segment position + s string // literal, or empty for wildcard +} + +func (idx *routingIndex) addPattern(pat *pattern) { + if pat.lastSegment().multi { + idx.multis = append(idx.multis, pat) + } else { + if idx.segments == nil { + idx.segments = map[routingIndexKey][]*pattern{} + } + for pos, seg := range pat.segments { + key := routingIndexKey{pos: pos, s: ""} + if !seg.wild { + key.s = seg.s + } + idx.segments[key] = append(idx.segments[key], pat) + } + } +} + +// possiblyConflictingPatterns calls f on all patterns that might conflict with +// pat. If f returns a non-nil error, possiblyConflictingPatterns returns immediately +// with that error. +// +// To be correct, possiblyConflictingPatterns must include all patterns that +// might conflict. But it may also include patterns that cannot conflict. +// For instance, an implementation that returns all registered patterns is correct. +// We use this fact throughout, simplifying the implementation by returning more +// patterns that we might need to. +func (idx *routingIndex) possiblyConflictingPatterns(pat *pattern, f func(*pattern) error) (err error) { + // Terminology: + // dollar pattern: one ending in "{$}" + // multi pattern: one ending in a trailing slash or "{x...}" wildcard + // ordinary pattern: neither of the above + + // apply f to all the pats, stopping on error. + apply := func(pats []*pattern) error { + if err != nil { + return err + } + for _, p := range pats { + err = f(p) + if err != nil { + return err + } + } + return nil + } + + // Our simple indexing scheme doesn't try to prune multi patterns; assume + // any of them can match the argument. + if err := apply(idx.multis); err != nil { + return err + } + if pat.lastSegment().s == "/" { + // All paths that a dollar pattern matches end in a slash; no paths that + // an ordinary pattern matches do. So only other dollar or multi + // patterns can conflict with a dollar pattern. Furthermore, conflicting + // dollar patterns must have the {$} in the same position. + return apply(idx.segments[routingIndexKey{s: "/", pos: len(pat.segments) - 1}]) + } + // For ordinary and multi patterns, the only conflicts can be with a multi, + // or a pattern that has the same literal or a wildcard at some literal + // position. + // We could intersect all the possible matches at each position, but we + // do something simpler: we find the position with the fewest patterns. + var lmin, wmin []*pattern + min := math.MaxInt + hasLit := false + for i, seg := range pat.segments { + if seg.multi { + break + } + if !seg.wild { + hasLit = true + lpats := idx.segments[routingIndexKey{s: seg.s, pos: i}] + wpats := idx.segments[routingIndexKey{s: "", pos: i}] + if sum := len(lpats) + len(wpats); sum < min { + lmin = lpats + wmin = wpats + min = sum + } + } + } + if hasLit { + apply(lmin) + apply(wmin) + return err + } + + // This pattern is all wildcards. + // Check it against everything. + for _, pats := range idx.segments { + apply(pats) + } + return err +} diff --git a/src/net/http/routing_index_test.go b/src/net/http/routing_index_test.go new file mode 100644 index 0000000000..1ffb9272c6 --- /dev/null +++ b/src/net/http/routing_index_test.go @@ -0,0 +1,179 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "fmt" + "slices" + "sort" + "strings" + "testing" +) + +func TestIndex(t *testing.T) { + // Generate every kind of pattern up to some number of segments, + // and compare conflicts found during indexing with those found + // by exhaustive comparison. + patterns := generatePatterns() + var idx routingIndex + for i, pat := range patterns { + got := indexConflicts(pat, &idx) + want := trueConflicts(pat, patterns[:i]) + if !slices.Equal(got, want) { + t.Fatalf("%q:\ngot %q\nwant %q", pat, got, want) + } + idx.addPattern(pat) + } +} + +func trueConflicts(pat *pattern, pats []*pattern) []string { + var s []string + for _, p := range pats { + if pat.conflictsWith(p) { + s = append(s, p.String()) + } + } + sort.Strings(s) + return s +} + +func indexConflicts(pat *pattern, idx *routingIndex) []string { + var s []string + idx.possiblyConflictingPatterns(pat, func(p *pattern) error { + if pat.conflictsWith(p) { + s = append(s, p.String()) + } + return nil + }) + sort.Strings(s) + return slices.Compact(s) +} + +// generatePatterns generates all possible patterns using a representative +// sample of parts. +func generatePatterns() []*pattern { + var pats []*pattern + + collect := func(s string) { + // Replace duplicate wildcards with unique ones. + var b strings.Builder + wc := 0 + for { + i := strings.Index(s, "{x}") + if i < 0 { + b.WriteString(s) + break + } + b.WriteString(s[:i]) + fmt.Fprintf(&b, "{x%d}", wc) + wc++ + s = s[i+3:] + } + pat, err := parsePattern(b.String()) + if err != nil { + panic(err) + } + pats = append(pats, pat) + } + + var ( + methods = []string{"", "GET ", "HEAD ", "POST "} + hosts = []string{"", "h1", "h2"} + segs = []string{"/a", "/b", "/{x}"} + finalSegs = []string{"/a", "/b", "/{f}", "/{m...}", "/{$}"} + ) + + g := genConcat( + genChoice(methods), + genChoice(hosts), + genStar(3, genChoice(segs)), + genChoice(finalSegs)) + g(collect) + return pats +} + +// A generator is a function that calls its argument with the strings that it +// generates. +type generator func(collect func(string)) + +// genConst generates a single constant string. +func genConst(s string) generator { + return func(collect func(string)) { + collect(s) + } +} + +// genChoice generates all the strings in its argument. +func genChoice(choices []string) generator { + return func(collect func(string)) { + for _, c := range choices { + collect(c) + } + } +} + +// genConcat2 generates the cross product of the strings of g1 concatenated +// with those of g2. +func genConcat2(g1, g2 generator) generator { + return func(collect func(string)) { + g1(func(s1 string) { + g2(func(s2 string) { + collect(s1 + s2) + }) + }) + } +} + +// genConcat generalizes genConcat2 to any number of generators. +func genConcat(gs ...generator) generator { + if len(gs) == 0 { + return genConst("") + } + return genConcat2(gs[0], genConcat(gs[1:]...)) +} + +// genRepeat generates strings of exactly n copies of g's strings. +func genRepeat(n int, g generator) generator { + if n == 0 { + return genConst("") + } + return genConcat(g, genRepeat(n-1, g)) +} + +// genStar (named after the Kleene star) generates 0, 1, 2, ..., max +// copies of the strings of g. +func genStar(max int, g generator) generator { + return func(collect func(string)) { + for i := 0; i <= max; i++ { + genRepeat(i, g)(collect) + } + } +} + +func BenchmarkMultiConflicts(b *testing.B) { + // How fast is indexing if the corpus is all multis? + const nMultis = 1000 + var pats []*pattern + for i := 0; i < nMultis; i++ { + pats = append(pats, mustParsePattern(b, fmt.Sprintf("/a/b/{x}/d%d/", i))) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + var idx routingIndex + for _, p := range pats { + got := indexConflicts(p, &idx) + if len(got) != 0 { + b.Fatalf("got %d conflicts, want 0", len(got)) + } + idx.addPattern(p) + } + if i == 0 { + // Confirm that all the multis ended up where they belong. + if g, w := len(idx.multis), nMultis; g != w { + b.Fatalf("got %d multis, want %d", g, w) + } + } + } +} diff --git a/src/net/http/routing_tree.go b/src/net/http/routing_tree.go new file mode 100644 index 0000000000..8812ed04e2 --- /dev/null +++ b/src/net/http/routing_tree.go @@ -0,0 +1,240 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements a decision tree for fast matching of requests to +// patterns. +// +// The root of the tree branches on the host of the request. +// The next level branches on the method. +// The remaining levels branch on consecutive segments of the path. +// +// The "more specific wins" precedence rule can result in backtracking. +// For example, given the patterns +// /a/b/z +// /a/{x}/c +// we will first try to match the path "/a/b/c" with /a/b/z, and +// when that fails we will try against /a/{x}/c. + +package http + +import ( + "strings" +) + +// A routingNode is a node in the decision tree. +// The same struct is used for leaf and interior nodes. +type routingNode struct { + // A leaf node holds a single pattern and the Handler it was registered + // with. + pattern *pattern + handler Handler + + // An interior node maps parts of the incoming request to child nodes. + // special children keys: + // "/" trailing slash (resulting from {$}) + // "" single wildcard + // "*" multi wildcard + children mapping[string, *routingNode] + emptyChild *routingNode // optimization: child with key "" +} + +// addPattern adds a pattern and its associated Handler to the tree +// at root. +func (root *routingNode) addPattern(p *pattern, h Handler) { + // First level of tree is host. + n := root.addChild(p.host) + // Second level of tree is method. + n = n.addChild(p.method) + // Remaining levels are path. + n.addSegments(p.segments, p, h) +} + +// addSegments adds the given segments to the tree rooted at n. +// If there are no segments, then n is a leaf node that holds +// the given pattern and handler. +func (n *routingNode) addSegments(segs []segment, p *pattern, h Handler) { + if len(segs) == 0 { + n.set(p, h) + return + } + seg := segs[0] + if seg.multi { + if len(segs) != 1 { + panic("multi wildcard not last") + } + n.addChild("*").set(p, h) + } else if seg.wild { + n.addChild("").addSegments(segs[1:], p, h) + } else { + n.addChild(seg.s).addSegments(segs[1:], p, h) + } +} + +// set sets the pattern and handler for n, which +// must be a leaf node. +func (n *routingNode) set(p *pattern, h Handler) { + if n.pattern != nil || n.handler != nil { + panic("non-nil leaf fields") + } + n.pattern = p + n.handler = h +} + +// addChild adds a child node with the given key to n +// if one does not exist, and returns the child. +func (n *routingNode) addChild(key string) *routingNode { + if key == "" { + if n.emptyChild == nil { + n.emptyChild = &routingNode{} + } + return n.emptyChild + } + if c := n.findChild(key); c != nil { + return c + } + c := &routingNode{} + n.children.add(key, c) + return c +} + +// findChild returns the child of n with the given key, or nil +// if there is no child with that key. +func (n *routingNode) findChild(key string) *routingNode { + if key == "" { + return n.emptyChild + } + r, _ := n.children.find(key) + return r +} + +// match returns the leaf node under root that matches the arguments, and a list +// of values for pattern wildcards in the order that the wildcards appear. +// For example, if the request path is "/a/b/c" and the pattern is "/{x}/b/{y}", +// then the second return value will be []string{"a", "c"}. +func (root *routingNode) match(host, method, path string) (*routingNode, []string) { + if host != "" { + // There is a host. If there is a pattern that specifies that host and it + // matches, we are done. If the pattern doesn't match, fall through to + // try patterns with no host. + if l, m := root.findChild(host).matchMethodAndPath(method, path); l != nil { + return l, m + } + } + return root.emptyChild.matchMethodAndPath(method, path) +} + +// matchMethodAndPath matches the method and path. +// Its return values are the same as [routingNode.match]. +// The receiver should be a child of the root. +func (n *routingNode) matchMethodAndPath(method, path string) (*routingNode, []string) { + if n == nil { + return nil, nil + } + if l, m := n.findChild(method).matchPath(path, nil); l != nil { + // Exact match of method name. + return l, m + } + if method == "HEAD" { + // GET matches HEAD too. + if l, m := n.findChild("GET").matchPath(path, nil); l != nil { + return l, m + } + } + // No exact match; try patterns with no method. + return n.emptyChild.matchPath(path, nil) +} + +// matchPath matches a path. +// Its return values are the same as [routingNode.match]. +// matchPath calls itself recursively. The matches argument holds the wildcard matches +// found so far. +func (n *routingNode) matchPath(path string, matches []string) (*routingNode, []string) { + if n == nil { + return nil, nil + } + // If path is empty, then we are done. + // If n is a leaf node, we found a match; return it. + // If n is an interior node (which means it has a nil pattern), + // then we failed to match. + if path == "" { + if n.pattern == nil { + return nil, nil + } + return n, matches + } + // Get the first segment of path. + seg, rest := firstSegment(path) + // First try matching against patterns that have a literal for this position. + // We know by construction that such patterns are more specific than those + // with a wildcard at this position (they are either more specific, equivalent, + // or overlap, and we ruled out the first two when the patterns were registered). + if n, m := n.findChild(seg).matchPath(rest, matches); n != nil { + return n, m + } + // If matching a literal fails, try again with patterns that have a single + // wildcard (represented by an empty string in the child mapping). + // Again, by construction, patterns with a single wildcard must be more specific than + // those with a multi wildcard. + // We skip this step if the segment is a trailing slash, because single wildcards + // don't match trailing slashes. + if seg != "/" { + if n, m := n.emptyChild.matchPath(rest, append(matches, seg)); n != nil { + return n, m + } + } + // Lastly, match the pattern (there can be at most one) that has a multi + // wildcard in this position to the rest of the path. + if c := n.findChild("*"); c != nil { + // Don't record a match for a nameless wildcard (which arises from a + // trailing slash in the pattern). + if c.pattern.lastSegment().s != "" { + matches = append(matches, pathUnescape(path[1:])) // remove initial slash + } + return c, matches + } + return nil, nil +} + +// firstSegment splits path into its first segment, and the rest. +// The path must begin with "/". +// If path consists of only a slash, firstSegment returns ("/", ""). +// The segment is returned unescaped, if possible. +func firstSegment(path string) (seg, rest string) { + if path == "/" { + return "/", "" + } + path = path[1:] // drop initial slash + i := strings.IndexByte(path, '/') + if i < 0 { + i = len(path) + } + return pathUnescape(path[:i]), path[i:] +} + +// matchingMethods adds to methodSet all the methods that would result in a +// match if passed to routingNode.match with the given host and path. +func (root *routingNode) matchingMethods(host, path string, methodSet map[string]bool) { + if host != "" { + root.findChild(host).matchingMethodsPath(path, methodSet) + } + root.emptyChild.matchingMethodsPath(path, methodSet) + if methodSet["GET"] { + methodSet["HEAD"] = true + } +} + +func (n *routingNode) matchingMethodsPath(path string, set map[string]bool) { + if n == nil { + return + } + n.children.eachPair(func(method string, c *routingNode) bool { + if p, _ := c.matchPath(path, nil); p != nil { + set[method] = true + } + return true + }) + // Don't look at the empty child. If there were an empty + // child, it would match on any method, but we only + // call this when we fail to match on a method. +} diff --git a/src/net/http/routing_tree_test.go b/src/net/http/routing_tree_test.go new file mode 100644 index 0000000000..2aac8b6cdf --- /dev/null +++ b/src/net/http/routing_tree_test.go @@ -0,0 +1,295 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "fmt" + "io" + "sort" + "strings" + "testing" + + "slices" +) + +func TestRoutingFirstSegment(t *testing.T) { + for _, test := range []struct { + in string + want []string + }{ + {"/a/b/c", []string{"a", "b", "c"}}, + {"/a/b/", []string{"a", "b", "/"}}, + {"/", []string{"/"}}, + {"/a/%62/c", []string{"a", "b", "c"}}, + {"/a%2Fb%2fc", []string{"a/b/c"}}, + } { + var got []string + rest := test.in + for len(rest) > 0 { + var seg string + seg, rest = firstSegment(rest) + got = append(got, seg) + } + if !slices.Equal(got, test.want) { + t.Errorf("%q: got %v, want %v", test.in, got, test.want) + } + } +} + +// TODO: test host and method +var testTree *routingNode + +func getTestTree() *routingNode { + if testTree == nil { + testTree = buildTree("/a", "/a/b", "/a/{x}", + "/g/h/i", "/g/{x}/j", + "/a/b/{x...}", "/a/b/{y}", "/a/b/{$}") + } + return testTree +} + +func buildTree(pats ...string) *routingNode { + root := &routingNode{} + for _, p := range pats { + pat, err := parsePattern(p) + if err != nil { + panic(err) + } + root.addPattern(pat, nil) + } + return root +} + +func TestRoutingAddPattern(t *testing.T) { + want := `"": + "": + "a": + "/a" + "": + "/a/{x}" + "b": + "/a/b" + "": + "/a/b/{y}" + "*": + "/a/b/{x...}" + "/": + "/a/b/{$}" + "g": + "": + "j": + "/g/{x}/j" + "h": + "i": + "/g/h/i" +` + + var b strings.Builder + getTestTree().print(&b, 0) + got := b.String() + if got != want { + t.Errorf("got\n%s\nwant\n%s", got, want) + } +} + +type testCase struct { + method, host, path string + wantPat string // "" for nil (no match) + wantMatches []string +} + +func TestRoutingNodeMatch(t *testing.T) { + + test := func(tree *routingNode, tests []testCase) { + t.Helper() + for _, test := range tests { + gotNode, gotMatches := tree.match(test.host, test.method, test.path) + got := "" + if gotNode != nil { + got = gotNode.pattern.String() + } + if got != test.wantPat { + t.Errorf("%s, %s, %s: got %q, want %q", test.host, test.method, test.path, got, test.wantPat) + } + if !slices.Equal(gotMatches, test.wantMatches) { + t.Errorf("%s, %s, %s: got matches %v, want %v", test.host, test.method, test.path, gotMatches, test.wantMatches) + } + } + } + + test(getTestTree(), []testCase{ + {"GET", "", "/a", "/a", nil}, + {"Get", "", "/b", "", nil}, + {"Get", "", "/a/b", "/a/b", nil}, + {"Get", "", "/a/c", "/a/{x}", []string{"c"}}, + {"Get", "", "/a/b/", "/a/b/{$}", nil}, + {"Get", "", "/a/b/c", "/a/b/{y}", []string{"c"}}, + {"Get", "", "/a/b/c/d", "/a/b/{x...}", []string{"c/d"}}, + {"Get", "", "/g/h/i", "/g/h/i", nil}, + {"Get", "", "/g/h/j", "/g/{x}/j", []string{"h"}}, + }) + + tree := buildTree( + "/item/", + "POST /item/{user}", + "GET /item/{user}", + "/item/{user}", + "/item/{user}/{id}", + "/item/{user}/new", + "/item/{$}", + "POST alt.com/item/{user}", + "GET /headwins", + "HEAD /headwins", + "/path/{p...}") + + test(tree, []testCase{ + {"GET", "", "/item/jba", + "GET /item/{user}", []string{"jba"}}, + {"POST", "", "/item/jba", + "POST /item/{user}", []string{"jba"}}, + {"HEAD", "", "/item/jba", + "GET /item/{user}", []string{"jba"}}, + {"get", "", "/item/jba", + "/item/{user}", []string{"jba"}}, // method matches are case-sensitive + {"POST", "", "/item/jba/17", + "/item/{user}/{id}", []string{"jba", "17"}}, + {"GET", "", "/item/jba/new", + "/item/{user}/new", []string{"jba"}}, + {"GET", "", "/item/", + "/item/{$}", []string{}}, + {"GET", "", "/item/jba/17/line2", + "/item/", nil}, + {"POST", "alt.com", "/item/jba", + "POST alt.com/item/{user}", []string{"jba"}}, + {"GET", "alt.com", "/item/jba", + "GET /item/{user}", []string{"jba"}}, + {"GET", "", "/item", + "", nil}, // does not match + {"GET", "", "/headwins", + "GET /headwins", nil}, + {"HEAD", "", "/headwins", // HEAD is more specific than GET + "HEAD /headwins", nil}, + {"GET", "", "/path/to/file", + "/path/{p...}", []string{"to/file"}}, + }) + + // A pattern ending in {$} should only match URLS with a trailing slash. + pat1 := "/a/b/{$}" + test(buildTree(pat1), []testCase{ + {"GET", "", "/a/b", "", nil}, + {"GET", "", "/a/b/", pat1, nil}, + {"GET", "", "/a/b/c", "", nil}, + {"GET", "", "/a/b/c/d", "", nil}, + }) + + // A pattern ending in a single wildcard should not match a trailing slash URL. + pat2 := "/a/b/{w}" + test(buildTree(pat2), []testCase{ + {"GET", "", "/a/b", "", nil}, + {"GET", "", "/a/b/", "", nil}, + {"GET", "", "/a/b/c", pat2, []string{"c"}}, + {"GET", "", "/a/b/c/d", "", nil}, + }) + + // A pattern ending in a multi wildcard should match both URLs. + pat3 := "/a/b/{w...}" + test(buildTree(pat3), []testCase{ + {"GET", "", "/a/b", "", nil}, + {"GET", "", "/a/b/", pat3, []string{""}}, + {"GET", "", "/a/b/c", pat3, []string{"c"}}, + {"GET", "", "/a/b/c/d", pat3, []string{"c/d"}}, + }) + + // All three of the above should work together. + test(buildTree(pat1, pat2, pat3), []testCase{ + {"GET", "", "/a/b", "", nil}, + {"GET", "", "/a/b/", pat1, nil}, + {"GET", "", "/a/b/c", pat2, []string{"c"}}, + {"GET", "", "/a/b/c/d", pat3, []string{"c/d"}}, + }) +} + +func TestMatchingMethods(t *testing.T) { + hostTree := buildTree("GET a.com/", "PUT b.com/", "POST /foo/{x}") + for _, test := range []struct { + name string + tree *routingNode + host, path string + want string + }{ + { + "post", + buildTree("POST /"), "", "/foo", + "POST", + }, + { + "get", + buildTree("GET /"), "", "/foo", + "GET,HEAD", + }, + { + "host", + hostTree, "", "/foo", + "", + }, + { + "host", + hostTree, "", "/foo/bar", + "POST", + }, + { + "host2", + hostTree, "a.com", "/foo/bar", + "GET,HEAD,POST", + }, + { + "host3", + hostTree, "b.com", "/bar", + "PUT", + }, + { + // This case shouldn't come up because we only call matchingMethods + // when there was no match, but we include it for completeness. + "empty", + buildTree("/"), "", "/", + "", + }, + } { + t.Run(test.name, func(t *testing.T) { + ms := map[string]bool{} + test.tree.matchingMethods(test.host, test.path, ms) + keys := mapKeys(ms) + sort.Strings(keys) + got := strings.Join(keys, ",") + if got != test.want { + t.Errorf("got %s, want %s", got, test.want) + } + }) + } +} + +func (n *routingNode) print(w io.Writer, level int) { + indent := strings.Repeat(" ", level) + if n.pattern != nil { + fmt.Fprintf(w, "%s%q\n", indent, n.pattern) + } + if n.emptyChild != nil { + fmt.Fprintf(w, "%s%q:\n", indent, "") + n.emptyChild.print(w, level+1) + } + + var keys []string + n.children.eachPair(func(k string, _ *routingNode) bool { + keys = append(keys, k) + return true + }) + sort.Strings(keys) + + for _, k := range keys { + fmt.Fprintf(w, "%s%q:\n", indent, k) + n, _ := n.children.find(k) + n.print(w, level+1) + } +} diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index bb380cf4a5..0c76f1bcc4 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -30,7 +30,6 @@ import ( "net/http/internal/testcert" "net/url" "os" - "os/exec" "path/filepath" "reflect" "regexp" @@ -108,10 +107,14 @@ type testConn struct { readMu sync.Mutex // for TestHandlerBodyClose readBuf bytes.Buffer writeBuf bytes.Buffer - closec chan bool // if non-nil, send value to it on close + closec chan bool // 1-buffered; receives true when Close is called noopConn } +func newTestConn() *testConn { + return &testConn{closec: make(chan bool, 1)} +} + func (c *testConn) Read(b []byte) (int, error) { c.readMu.Lock() defer c.readMu.Unlock() @@ -647,30 +650,27 @@ func benchmarkServeMux(b *testing.B, runHandler bool) { func TestServerTimeouts(t *testing.T) { run(t, testServerTimeouts, []testMode{http1Mode}) } func testServerTimeouts(t *testing.T, mode testMode) { - // Try three times, with increasing timeouts. - tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second} - for i, timeout := range tries { - err := testServerTimeoutsWithTimeout(t, timeout, mode) - if err == nil { - return - } - t.Logf("failed at %v: %v", timeout, err) - if i != len(tries)-1 { - t.Logf("retrying at %v ...", tries[i+1]) - } - } - t.Fatal("all attempts failed") + runTimeSensitiveTest(t, []time.Duration{ + 10 * time.Millisecond, + 50 * time.Millisecond, + 100 * time.Millisecond, + 500 * time.Millisecond, + 1 * time.Second, + }, func(t *testing.T, timeout time.Duration) error { + return testServerTimeoutsWithTimeout(t, timeout, mode) + }) } func testServerTimeoutsWithTimeout(t *testing.T, timeout time.Duration, mode testMode) error { - reqNum := 0 - ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) { - reqNum++ - fmt.Fprintf(res, "req=%d", reqNum) + var reqNum atomic.Int32 + cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) { + fmt.Fprintf(res, "req=%d", reqNum.Add(1)) }), func(ts *httptest.Server) { ts.Config.ReadTimeout = timeout ts.Config.WriteTimeout = timeout - }).ts + }) + defer cst.close() + ts := cst.ts // Hit the HTTP server successfully. c := ts.Client() @@ -866,16 +866,20 @@ func TestWriteDeadlineEnforcedPerStream(t *testing.T) { } func testWriteDeadlineEnforcedPerStream(t *testing.T, mode testMode, timeout time.Duration) error { - reqNum := 0 - ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) { - reqNum++ - if reqNum == 1 { - return // first request succeeds + firstRequest := make(chan bool, 1) + cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) { + select { + case firstRequest <- true: + // first request succeeds + default: + // second request times out + time.Sleep(timeout) } - time.Sleep(timeout) // second request times out }), func(ts *httptest.Server) { ts.Config.WriteTimeout = timeout / 2 - }).ts + }) + defer cst.close() + ts := cst.ts c := ts.Client() @@ -922,14 +926,18 @@ func TestNoWriteDeadline(t *testing.T) { } func testNoWriteDeadline(t *testing.T, mode testMode, timeout time.Duration) error { - reqNum := 0 - ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) { - reqNum++ - if reqNum == 1 { - return // first request succeeds + firstRequest := make(chan bool, 1) + cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) { + select { + case firstRequest <- true: + // first request succeeds + default: + // second request times out + time.Sleep(timeout) } - time.Sleep(timeout) // second request timesout - })).ts + })) + defer cst.close() + ts := cst.ts c := ts.Client() @@ -1392,27 +1400,28 @@ func TestTLSHandshakeTimeout(t *testing.T) { run(t, testTLSHandshakeTimeout, []testMode{https1Mode, http2Mode}) } func testTLSHandshakeTimeout(t *testing.T, mode testMode) { - errc := make(chanWriter, 10) // but only expecting 1 - ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), + errLog := new(strings.Builder) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), func(ts *httptest.Server) { ts.Config.ReadTimeout = 250 * time.Millisecond - ts.Config.ErrorLog = log.New(errc, "", 0) + ts.Config.ErrorLog = log.New(errLog, "", 0) }, - ).ts + ) + ts := cst.ts + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("Dial: %v", err) } - defer conn.Close() - var buf [1]byte n, err := conn.Read(buf[:]) if err == nil || n != 0 { t.Errorf("Read = %d, %v; want an error and no bytes", n, err) } + conn.Close() - v := <-errc - if !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") { + cst.close() + if v := errLog.String(); !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") { t.Errorf("expected a TLS handshake timeout error; got %q", v) } } @@ -2968,15 +2977,36 @@ func (b neverEnding) Read(p []byte) (n int, err error) { return len(p), nil } -type countReader struct { - r io.Reader - n *int64 +type bodyLimitReader struct { + mu sync.Mutex + count int + limit int + closed chan struct{} } -func (cr countReader) Read(p []byte) (n int, err error) { - n, err = cr.r.Read(p) - atomic.AddInt64(cr.n, int64(n)) - return +func (r *bodyLimitReader) Read(p []byte) (int, error) { + r.mu.Lock() + defer r.mu.Unlock() + select { + case <-r.closed: + return 0, errors.New("closed") + default: + } + if r.count > r.limit { + return 0, errors.New("at limit") + } + r.count += len(p) + for i := range p { + p[i] = 'a' + } + return len(p), nil +} + +func (r *bodyLimitReader) Close() error { + r.mu.Lock() + defer r.mu.Unlock() + close(r.closed) + return nil } func TestRequestBodyLimit(t *testing.T) { run(t, testRequestBodyLimit) } @@ -3000,8 +3030,11 @@ func testRequestBodyLimit(t *testing.T, mode testMode) { } })) - nWritten := new(int64) - req, _ := NewRequest("POST", cst.ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200)) + body := &bodyLimitReader{ + closed: make(chan struct{}), + limit: limit * 200, + } + req, _ := NewRequest("POST", cst.ts.URL, body) // Send the POST, but don't care it succeeds or not. The // remote side is going to reply and then close the TCP @@ -3016,10 +3049,13 @@ func testRequestBodyLimit(t *testing.T, mode testMode) { if err == nil { resp.Body.Close() } + // Wait for the Transport to finish writing the request body. + // It will close the body when done. + <-body.closed - if atomic.LoadInt64(nWritten) > limit*100 { + if body.count > limit*100 { t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d", - limit, nWritten) + limit, body.count) } } @@ -3075,47 +3111,71 @@ func TestServerBufferedChunking(t *testing.T) { // closing the TCP connection, causing the client to get a RST. // See https://golang.org/issue/3595 func TestServerGracefulClose(t *testing.T) { - run(t, testServerGracefulClose, []testMode{http1Mode}) + // Not parallel: modifies the global rstAvoidanceDelay. + run(t, testServerGracefulClose, []testMode{http1Mode}, testNotParallel) } func testServerGracefulClose(t *testing.T, mode testMode) { - ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { - Error(w, "bye", StatusUnauthorized) - })).ts + runTimeSensitiveTest(t, []time.Duration{ + 1 * time.Millisecond, + 5 * time.Millisecond, + 10 * time.Millisecond, + 50 * time.Millisecond, + 100 * time.Millisecond, + 500 * time.Millisecond, + time.Second, + 5 * time.Second, + }, func(t *testing.T, timeout time.Duration) error { + SetRSTAvoidanceDelay(t, timeout) + t.Logf("set RST avoidance delay to %v", timeout) - conn, err := net.Dial("tcp", ts.Listener.Addr().String()) - if err != nil { - t.Fatal(err) - } - defer conn.Close() - const bodySize = 5 << 20 - req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize)) - for i := 0; i < bodySize; i++ { - req = append(req, 'x') - } - writeErr := make(chan error) - go func() { - _, err := conn.Write(req) - writeErr <- err - }() - br := bufio.NewReader(conn) - lineNum := 0 - for { - line, err := br.ReadString('\n') - if err == io.EOF { - break + const bodySize = 5 << 20 + req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize)) + for i := 0; i < bodySize; i++ { + req = append(req, 'x') } + + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + Error(w, "bye", StatusUnauthorized) + })) + // We need to close cst explicitly here so that in-flight server + // requests don't race with the call to SetRSTAvoidanceDelay for a retry. + defer cst.close() + ts := cst.ts + + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { - t.Fatalf("ReadLine: %v", err) + return err } - lineNum++ - if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") { - t.Errorf("Response line = %q; want a 401", line) + writeErr := make(chan error) + go func() { + _, err := conn.Write(req) + writeErr <- err + }() + defer func() { + conn.Close() + // Wait for write to finish. This is a broken pipe on both + // Darwin and Linux, but checking this isn't the point of + // the test. + <-writeErr + }() + + br := bufio.NewReader(conn) + lineNum := 0 + for { + line, err := br.ReadString('\n') + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("ReadLine: %v", err) + } + lineNum++ + if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") { + t.Errorf("Response line = %q; want a 401", line) + } } - } - // Wait for write to finish. This is a broken pipe on both - // Darwin and Linux, but checking this isn't the point of - // the test. - <-writeErr + return nil + }) } func TestCaseSensitiveMethod(t *testing.T) { run(t, testCaseSensitiveMethod) } @@ -3897,91 +3957,93 @@ func TestContentTypeOkayOn204(t *testing.T) { // and the http client), and both think they can close it on failure. // Therefore, all incoming server requests Bodies need to be thread-safe. func TestTransportAndServerSharedBodyRace(t *testing.T) { - run(t, testTransportAndServerSharedBodyRace) + run(t, testTransportAndServerSharedBodyRace, testNotParallel) } func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) { - const bodySize = 1 << 20 - - // errorf is like t.Errorf, but also writes to println. When - // this test fails, it hangs. This helps debugging and I've - // added this enough times "temporarily". It now gets added - // full time. - errorf := func(format string, args ...any) { - v := fmt.Sprintf(format, args...) - println(v) - t.Error(v) - } - - unblockBackend := make(chan bool) - backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { - gone := rw.(CloseNotifier).CloseNotify() - didCopy := make(chan any) - go func() { + // The proxy server in the middle of the stack for this test potentially + // from its handler after only reading half of the body. + // That can trigger https://go.dev/issue/3595, which is otherwise + // irrelevant to this test. + runTimeSensitiveTest(t, []time.Duration{ + 1 * time.Millisecond, + 5 * time.Millisecond, + 10 * time.Millisecond, + 50 * time.Millisecond, + 100 * time.Millisecond, + 500 * time.Millisecond, + time.Second, + 5 * time.Second, + }, func(t *testing.T, timeout time.Duration) error { + SetRSTAvoidanceDelay(t, timeout) + t.Logf("set RST avoidance delay to %v", timeout) + + const bodySize = 1 << 20 + + var wg sync.WaitGroup + backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { + // Work around https://go.dev/issue/38370: clientServerTest uses + // an httptest.Server under the hood, and in HTTP/2 mode it does not always + // “[block] until all outstanding requests on this server have completed”, + // causing the call to Logf below to race with the end of the test. + // + // Since the client doesn't cancel the request until we have copied half + // the body, this call to add happens before the test is cleaned up, + // preventing the race. + wg.Add(1) + defer wg.Done() + n, err := io.CopyN(rw, req.Body, bodySize) - didCopy <- []any{n, err} + t.Logf("backend CopyN: %v, %v", n, err) + <-req.Context().Done() + })) + // We need to close explicitly here so that in-flight server + // requests don't race with the call to SetRSTAvoidanceDelay for a retry. + defer func() { + wg.Wait() + backend.close() }() - isGone := false - Loop: - for { - select { - case <-didCopy: - break Loop - case <-gone: - isGone = true - case <-time.After(time.Second): - println("1 second passes in backend, proxygone=", isGone) - } - } - <-unblockBackend - })) - defer backend.close() - backendRespc := make(chan *Response, 1) - var proxy *clientServerTest - proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { - req2, _ := NewRequest("POST", backend.ts.URL, req.Body) - req2.ContentLength = bodySize - cancel := make(chan struct{}) - req2.Cancel = cancel + var proxy *clientServerTest + proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { + req2, _ := NewRequest("POST", backend.ts.URL, req.Body) + req2.ContentLength = bodySize + cancel := make(chan struct{}) + req2.Cancel = cancel - bresp, err := proxy.c.Do(req2) - if err != nil { - errorf("Proxy outbound request: %v", err) - return - } - _, err = io.CopyN(io.Discard, bresp.Body, bodySize/2) - if err != nil { - errorf("Proxy copy error: %v", err) - return - } - backendRespc <- bresp // to close later + bresp, err := proxy.c.Do(req2) + if err != nil { + t.Errorf("Proxy outbound request: %v", err) + return + } + _, err = io.CopyN(io.Discard, bresp.Body, bodySize/2) + if err != nil { + t.Errorf("Proxy copy error: %v", err) + return + } + t.Cleanup(func() { bresp.Body.Close() }) + + // Try to cause a race. Canceling the client request will cause the client + // transport to close req2.Body. Returning from the server handler will + // cause the server to close req.Body. Since they are the same underlying + // ReadCloser, that will result in concurrent calls to Close (and possibly a + // Read concurrent with a Close). + if mode == http2Mode { + close(cancel) + } else { + proxy.c.Transport.(*Transport).CancelRequest(req2) + } + rw.Write([]byte("OK")) + })) + defer proxy.close() - // Try to cause a race: Both the Transport and the proxy handler's Server - // will try to read/close req.Body (aka req2.Body) - if mode == http2Mode { - close(cancel) - } else { - proxy.c.Transport.(*Transport).CancelRequest(req2) + req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize)) + res, err := proxy.c.Do(req) + if err != nil { + return fmt.Errorf("original request: %v", err) } - rw.Write([]byte("OK")) - })) - defer proxy.close() - - defer close(unblockBackend) - req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize)) - res, err := proxy.c.Do(req) - if err != nil { - t.Fatalf("Original request: %v", err) - } - - // Cleanup, so we don't leak goroutines. - res.Body.Close() - select { - case res := <-backendRespc: res.Body.Close() - default: - // We failed earlier. (e.g. on proxy.c.Do(req2)) - } + return nil + }) } // Test that a hanging Request.Body.Read from another goroutine can't @@ -4316,7 +4378,8 @@ func (c *closeWriteTestConn) CloseWrite() error { } func TestCloseWrite(t *testing.T) { - setParallel(t) + SetRSTAvoidanceDelay(t, 1*time.Millisecond) + var srv Server var testConn closeWriteTestConn c := ExportServerNewConn(&srv, &testConn) @@ -4552,10 +4615,10 @@ Host: foo } // If a Handler finishes and there's an unread request body, -// verify the server try to do implicit read on it before replying. +// verify the server implicitly tries to do a read on it before replying. func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) { setParallel(t) - conn := &testConn{closec: make(chan bool)} + conn := newTestConn() conn.readBuf.Write([]byte(fmt.Sprintf( "POST / HTTP/1.1\r\n" + "Host: test\r\n" + @@ -4645,7 +4708,7 @@ func TestServerValidatesHostHeader(t *testing.T) { {"GET / HTTP/3.0", "", 505}, } for _, tt := range tests { - conn := &testConn{closec: make(chan bool, 1)} + conn := newTestConn() methodTarget := "GET / " if !strings.HasPrefix(tt.proto, "HTTP/") { methodTarget = "" @@ -4743,7 +4806,7 @@ func TestServerValidatesHeaders(t *testing.T) { {"foo: foo\xfffoo\r\n", 200}, // non-ASCII high octets in value are fine } for _, tt := range tests { - conn := &testConn{closec: make(chan bool, 1)} + conn := newTestConn() io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n") ln := &oneConnListener{conn} @@ -4966,7 +5029,7 @@ func benchmarkClientServerParallel(b *testing.B, parallelism int, mode testMode) // For use like: // // $ go test -c -// $ ./http.test -test.run=XX -test.bench=BenchmarkServer -test.benchtime=15s -test.cpuprofile=http.prof +// $ ./http.test -test.run='^$' -test.bench='^BenchmarkServer$' -test.benchtime=15s -test.cpuprofile=http.prof // $ go tool pprof http.test http.prof // (pprof) web func BenchmarkServer(b *testing.B) { @@ -5005,7 +5068,7 @@ func BenchmarkServer(b *testing.B) { defer ts.Close() b.StartTimer() - cmd := exec.Command(os.Args[0], "-test.run=XXXX", "-test.bench=BenchmarkServer$") + cmd := testenv.Command(b, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkServer$") cmd.Env = append([]string{ fmt.Sprintf("TEST_BENCH_CLIENT_N=%d", b.N), fmt.Sprintf("TEST_BENCH_SERVER_URL=%s", ts.URL), @@ -5060,7 +5123,7 @@ func BenchmarkClient(b *testing.B) { // Start server process. ctx, cancel := context.WithCancel(context.Background()) - cmd := testenv.CommandContext(b, ctx, os.Args[0], "-test.run=XXXX", "-test.bench=BenchmarkClient$") + cmd := testenv.CommandContext(b, ctx, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkClient$") cmd.Env = append(cmd.Environ(), "TEST_BENCH_SERVER=yes") cmd.Stderr = os.Stderr stdout, err := cmd.StdoutPipe() @@ -5129,11 +5192,7 @@ Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3 `) res := []byte("Hello world!\n") - conn := &testConn{ - // testConn.Close will not push into the channel - // if it's full. - closec: make(chan bool, 1), - } + conn := newTestConn() handler := HandlerFunc(func(rw ResponseWriter, r *Request) { rw.Header().Set("Content-Type", "text/html; charset=utf-8") rw.Write(res) @@ -5356,49 +5415,75 @@ func testServerIdleTimeout(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { - io.Copy(io.Discard, r.Body) - io.WriteString(w, r.RemoteAddr) - }), func(ts *httptest.Server) { - ts.Config.ReadHeaderTimeout = 1 * time.Second - ts.Config.IdleTimeout = 2 * time.Second - }).ts - c := ts.Client() + runTimeSensitiveTest(t, []time.Duration{ + 10 * time.Millisecond, + 100 * time.Millisecond, + 1 * time.Second, + 10 * time.Second, + }, func(t *testing.T, readHeaderTimeout time.Duration) error { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + io.Copy(io.Discard, r.Body) + io.WriteString(w, r.RemoteAddr) + }), func(ts *httptest.Server) { + ts.Config.ReadHeaderTimeout = readHeaderTimeout + ts.Config.IdleTimeout = 2 * readHeaderTimeout + }) + defer cst.close() + ts := cst.ts + t.Logf("ReadHeaderTimeout = %v", ts.Config.ReadHeaderTimeout) + t.Logf("IdleTimeout = %v", ts.Config.IdleTimeout) + c := ts.Client() - get := func() string { - res, err := c.Get(ts.URL) + get := func() (string, error) { + res, err := c.Get(ts.URL) + if err != nil { + return "", err + } + defer res.Body.Close() + slurp, err := io.ReadAll(res.Body) + if err != nil { + // If we're at this point the headers have definitely already been + // read and the server is not idle, so neither timeout applies: + // this should never fail. + t.Fatal(err) + } + return string(slurp), nil + } + + a1, err := get() if err != nil { - t.Fatal(err) + return err } - defer res.Body.Close() - slurp, err := io.ReadAll(res.Body) + a2, err := get() if err != nil { - t.Fatal(err) + return err + } + if a1 != a2 { + return fmt.Errorf("did requests on different connections") + } + time.Sleep(ts.Config.IdleTimeout * 3 / 2) + a3, err := get() + if err != nil { + return err + } + if a2 == a3 { + return fmt.Errorf("request three unexpectedly on same connection") } - return string(slurp) - } - a1, a2 := get(), get() - if a1 != a2 { - t.Fatalf("did requests on different connections") - } - time.Sleep(3 * time.Second) - a3 := get() - if a2 == a3 { - t.Fatal("request three unexpectedly on same connection") - } + // And test that ReadHeaderTimeout still works: + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + return err + } + defer conn.Close() + conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n")) + time.Sleep(ts.Config.ReadHeaderTimeout * 2) + if _, err := io.CopyN(io.Discard, conn, 1); err == nil { + return fmt.Errorf("copy byte succeeded; want err") + } - // And test that ReadHeaderTimeout still works: - conn, err := net.Dial("tcp", ts.Listener.Addr().String()) - if err != nil { - t.Fatal(err) - } - defer conn.Close() - conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n")) - time.Sleep(2 * time.Second) - if _, err := io.CopyN(io.Discard, conn, 1); err == nil { - t.Fatal("copy byte succeeded; want err") - } + return nil + }) } func get(t *testing.T, c *Client, url string) string { @@ -5658,7 +5743,7 @@ func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) { time.Second, 2 * time.Second, }, func(t *testing.T, timeout time.Duration) error { - ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { select { case <-time.After(2 * timeout): fmt.Fprint(w, "ok") @@ -5667,7 +5752,9 @@ func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) { } }), func(ts *httptest.Server) { ts.Config.ReadTimeout = timeout - }).ts + }) + defer cst.close() + ts := cst.ts c := ts.Client() @@ -5701,10 +5788,12 @@ func testServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T, mode testMode) { time.Second, 2 * time.Second, }, func(t *testing.T, timeout time.Duration) error { - ts := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) { + cst := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) { ts.Config.ReadHeaderTimeout = timeout ts.Config.IdleTimeout = 0 // disable idle timeout - }).ts + }) + defer cst.close() + ts := cst.ts // rather than using an http.Client, create a single connection, so that // we can ensure this connection is not closed. @@ -5747,9 +5836,10 @@ func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t * if err == nil { return } - if i == len(durations)-1 { + if i == len(durations)-1 || t.Failed() { t.Fatalf("failed with duration %v: %v", d, err) } + t.Logf("retrying after error with duration %v: %v", d, err) } } @@ -5929,7 +6019,7 @@ func TestServerValidatesMethod(t *testing.T) { {"GE(T", 400}, } for _, tt := range tests { - conn := &testConn{closec: make(chan bool, 1)} + conn := newTestConn() io.WriteString(&conn.readBuf, tt.method+" / HTTP/1.1\r\nHost: foo.example\r\n\r\n") ln := &oneConnListener{conn} @@ -6594,7 +6684,7 @@ func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, } func TestMaxBytesHandler(t *testing.T) { - setParallel(t) + // Not parallel: modifies the global rstAvoidanceDelay. defer afterTest(t) for _, maxSize := range []int64{100, 1_000, 1_000_000} { @@ -6603,77 +6693,99 @@ func TestMaxBytesHandler(t *testing.T) { func(t *testing.T) { run(t, func(t *testing.T, mode testMode) { testMaxBytesHandler(t, mode, maxSize, requestSize) - }) + }, testNotParallel) }) } } } func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64) { - var ( - handlerN int64 - handlerErr error - ) - echo := HandlerFunc(func(w ResponseWriter, r *Request) { - var buf bytes.Buffer - handlerN, handlerErr = io.Copy(&buf, r.Body) - io.Copy(w, &buf) - }) - - ts := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize)).ts - defer ts.Close() + runTimeSensitiveTest(t, []time.Duration{ + 1 * time.Millisecond, + 5 * time.Millisecond, + 10 * time.Millisecond, + 50 * time.Millisecond, + 100 * time.Millisecond, + 500 * time.Millisecond, + time.Second, + 5 * time.Second, + }, func(t *testing.T, timeout time.Duration) error { + SetRSTAvoidanceDelay(t, timeout) + t.Logf("set RST avoidance delay to %v", timeout) + + var ( + handlerN int64 + handlerErr error + ) + echo := HandlerFunc(func(w ResponseWriter, r *Request) { + var buf bytes.Buffer + handlerN, handlerErr = io.Copy(&buf, r.Body) + io.Copy(w, &buf) + }) - c := ts.Client() + cst := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize)) + // We need to close cst explicitly here so that in-flight server + // requests don't race with the call to SetRSTAvoidanceDelay for a retry. + defer cst.close() + ts := cst.ts + c := ts.Client() - body := strings.Repeat("a", int(requestSize)) - var wg sync.WaitGroup - defer wg.Wait() - getBody := func() (io.ReadCloser, error) { - wg.Add(1) - body := &wgReadCloser{ - Reader: strings.NewReader(body), - wg: &wg, + body := strings.Repeat("a", int(requestSize)) + var wg sync.WaitGroup + defer wg.Wait() + getBody := func() (io.ReadCloser, error) { + wg.Add(1) + body := &wgReadCloser{ + Reader: strings.NewReader(body), + wg: &wg, + } + return body, nil } - return body, nil - } - reqBody, _ := getBody() - req, err := NewRequest("POST", ts.URL, reqBody) - if err != nil { - reqBody.Close() - t.Fatal(err) - } - req.ContentLength = int64(len(body)) - req.GetBody = getBody - req.Header.Set("Content-Type", "text/plain") + reqBody, _ := getBody() + req, err := NewRequest("POST", ts.URL, reqBody) + if err != nil { + reqBody.Close() + t.Fatal(err) + } + req.ContentLength = int64(len(body)) + req.GetBody = getBody + req.Header.Set("Content-Type", "text/plain") - var buf strings.Builder - res, err := c.Do(req) - if err != nil { - t.Errorf("unexpected connection error: %v", err) - } else { - _, err = io.Copy(&buf, res.Body) - res.Body.Close() + var buf strings.Builder + res, err := c.Do(req) if err != nil { - t.Errorf("unexpected read error: %v", err) + return fmt.Errorf("unexpected connection error: %v", err) + } else { + _, err = io.Copy(&buf, res.Body) + res.Body.Close() + if err != nil { + return fmt.Errorf("unexpected read error: %v", err) + } } - } - if handlerN > maxSize { - t.Errorf("expected max request body %d; got %d", maxSize, handlerN) - } - if requestSize > maxSize && handlerErr == nil { - t.Error("expected error on handler side; got nil") - } - if requestSize <= maxSize { - if handlerErr != nil { - t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr) + // We don't expect any of the errors after this point to occur due + // to rstAvoidanceDelay being too short, so we use t.Errorf for those + // instead of returning a (retriable) error. + + if handlerN > maxSize { + t.Errorf("expected max request body %d; got %d", maxSize, handlerN) } - if handlerN != requestSize { - t.Errorf("expected request of size %d; got %d", requestSize, handlerN) + if requestSize > maxSize && handlerErr == nil { + t.Error("expected error on handler side; got nil") } - } - if buf.Len() != int(handlerN) { - t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len()) - } + if requestSize <= maxSize { + if handlerErr != nil { + t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr) + } + if handlerN != requestSize { + t.Errorf("expected request of size %d; got %d", requestSize, handlerN) + } + } + if buf.Len() != int(handlerN) { + t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len()) + } + + return nil + }) } func TestEarlyHints(t *testing.T) { diff --git a/src/net/http/servemux121.go b/src/net/http/servemux121.go new file mode 100644 index 0000000000..c0a4b77010 --- /dev/null +++ b/src/net/http/servemux121.go @@ -0,0 +1,211 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +// This file implements ServeMux behavior as in Go 1.21. +// The behavior is controlled by a GODEBUG setting. +// Most of this code is derived from commit 08e35cc334. +// Changes are minimal: aside from the different receiver type, +// they mostly involve renaming functions, usually by unexporting them. + +import ( + "internal/godebug" + "net/url" + "sort" + "strings" + "sync" +) + +var httpmuxgo121 = godebug.New("httpmuxgo121") + +var use121 bool + +// Read httpmuxgo121 once at startup, since dealing with changes to it during +// program execution is too complex and error-prone. +func init() { + if httpmuxgo121.Value() == "1" { + use121 = true + httpmuxgo121.IncNonDefault() + } +} + +// serveMux121 holds the state of a ServeMux needed for Go 1.21 behavior. +type serveMux121 struct { + mu sync.RWMutex + m map[string]muxEntry + es []muxEntry // slice of entries sorted from longest to shortest. + hosts bool // whether any patterns contain hostnames +} + +type muxEntry struct { + h Handler + pattern string +} + +// Formerly ServeMux.Handle. +func (mux *serveMux121) handle(pattern string, handler Handler) { + mux.mu.Lock() + defer mux.mu.Unlock() + + if pattern == "" { + panic("http: invalid pattern") + } + if handler == nil { + panic("http: nil handler") + } + if _, exist := mux.m[pattern]; exist { + panic("http: multiple registrations for " + pattern) + } + + if mux.m == nil { + mux.m = make(map[string]muxEntry) + } + e := muxEntry{h: handler, pattern: pattern} + mux.m[pattern] = e + if pattern[len(pattern)-1] == '/' { + mux.es = appendSorted(mux.es, e) + } + + if pattern[0] != '/' { + mux.hosts = true + } +} + +func appendSorted(es []muxEntry, e muxEntry) []muxEntry { + n := len(es) + i := sort.Search(n, func(i int) bool { + return len(es[i].pattern) < len(e.pattern) + }) + if i == n { + return append(es, e) + } + // we now know that i points at where we want to insert + es = append(es, muxEntry{}) // try to grow the slice in place, any entry works. + copy(es[i+1:], es[i:]) // Move shorter entries down + es[i] = e + return es +} + +// Formerly ServeMux.HandleFunc. +func (mux *serveMux121) handleFunc(pattern string, handler func(ResponseWriter, *Request)) { + if handler == nil { + panic("http: nil handler") + } + mux.handle(pattern, HandlerFunc(handler)) +} + +// Formerly ServeMux.Handler. +func (mux *serveMux121) findHandler(r *Request) (h Handler, pattern string) { + + // CONNECT requests are not canonicalized. + if r.Method == "CONNECT" { + // If r.URL.Path is /tree and its handler is not registered, + // the /tree -> /tree/ redirect applies to CONNECT requests + // but the path canonicalization does not. + if u, ok := mux.redirectToPathSlash(r.URL.Host, r.URL.Path, r.URL); ok { + return RedirectHandler(u.String(), StatusMovedPermanently), u.Path + } + + return mux.handler(r.Host, r.URL.Path) + } + + // All other requests have any port stripped and path cleaned + // before passing to mux.handler. + host := stripHostPort(r.Host) + path := cleanPath(r.URL.Path) + + // If the given path is /tree and its handler is not registered, + // redirect for /tree/. + if u, ok := mux.redirectToPathSlash(host, path, r.URL); ok { + return RedirectHandler(u.String(), StatusMovedPermanently), u.Path + } + + if path != r.URL.Path { + _, pattern = mux.handler(host, path) + u := &url.URL{Path: path, RawQuery: r.URL.RawQuery} + return RedirectHandler(u.String(), StatusMovedPermanently), pattern + } + + return mux.handler(host, r.URL.Path) +} + +// handler is the main implementation of findHandler. +// The path is known to be in canonical form, except for CONNECT methods. +func (mux *serveMux121) handler(host, path string) (h Handler, pattern string) { + mux.mu.RLock() + defer mux.mu.RUnlock() + + // Host-specific pattern takes precedence over generic ones + if mux.hosts { + h, pattern = mux.match(host + path) + } + if h == nil { + h, pattern = mux.match(path) + } + if h == nil { + h, pattern = NotFoundHandler(), "" + } + return +} + +// Find a handler on a handler map given a path string. +// Most-specific (longest) pattern wins. +func (mux *serveMux121) match(path string) (h Handler, pattern string) { + // Check for exact match first. + v, ok := mux.m[path] + if ok { + return v.h, v.pattern + } + + // Check for longest valid match. mux.es contains all patterns + // that end in / sorted from longest to shortest. + for _, e := range mux.es { + if strings.HasPrefix(path, e.pattern) { + return e.h, e.pattern + } + } + return nil, "" +} + +// redirectToPathSlash determines if the given path needs appending "/" to it. +// This occurs when a handler for path + "/" was already registered, but +// not for path itself. If the path needs appending to, it creates a new +// URL, setting the path to u.Path + "/" and returning true to indicate so. +func (mux *serveMux121) redirectToPathSlash(host, path string, u *url.URL) (*url.URL, bool) { + mux.mu.RLock() + shouldRedirect := mux.shouldRedirectRLocked(host, path) + mux.mu.RUnlock() + if !shouldRedirect { + return u, false + } + path = path + "/" + u = &url.URL{Path: path, RawQuery: u.RawQuery} + return u, true +} + +// shouldRedirectRLocked reports whether the given path and host should be redirected to +// path+"/". This should happen if a handler is registered for path+"/" but +// not path -- see comments at ServeMux. +func (mux *serveMux121) shouldRedirectRLocked(host, path string) bool { + p := []string{path, host + path} + + for _, c := range p { + if _, exist := mux.m[c]; exist { + return false + } + } + + n := len(path) + if n == 0 { + return false + } + for _, c := range p { + if _, exist := mux.m[c+"/"]; exist { + return path[n-1] != '/' + } + } + + return false +} diff --git a/src/net/http/server.go b/src/net/http/server.go index 8f63a90299..acac78bcd0 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -61,16 +61,16 @@ var ( // A Handler responds to an HTTP request. // -// ServeHTTP should write reply headers and data to the ResponseWriter +// [Handler.ServeHTTP] should write reply headers and data to the [ResponseWriter] // and then return. Returning signals that the request is finished; it -// is not valid to use the ResponseWriter or read from the -// Request.Body after or concurrently with the completion of the +// is not valid to use the [ResponseWriter] or read from the +// [Request.Body] after or concurrently with the completion of the // ServeHTTP call. // // Depending on the HTTP client software, HTTP protocol version, and // any intermediaries between the client and the Go server, it may not -// be possible to read from the Request.Body after writing to the -// ResponseWriter. Cautious handlers should read the Request.Body +// be possible to read from the [Request.Body] after writing to the +// [ResponseWriter]. Cautious handlers should read the [Request.Body] // first, and then reply. // // Except for reading the body, handlers should not modify the @@ -82,7 +82,7 @@ var ( // and either closes the network connection or sends an HTTP/2 // RST_STREAM, depending on the HTTP protocol. To abort a handler so // the client sees an interrupted response but the server doesn't log -// an error, panic with the value ErrAbortHandler. +// an error, panic with the value [ErrAbortHandler]. type Handler interface { ServeHTTP(ResponseWriter, *Request) } @@ -90,15 +90,14 @@ type Handler interface { // A ResponseWriter interface is used by an HTTP handler to // construct an HTTP response. // -// A ResponseWriter may not be used after the Handler.ServeHTTP method -// has returned. +// A ResponseWriter may not be used after [Handler.ServeHTTP] has returned. type ResponseWriter interface { // Header returns the header map that will be sent by - // WriteHeader. The Header map also is the mechanism with which - // Handlers can set HTTP trailers. + // [ResponseWriter.WriteHeader]. The [Header] map also is the mechanism with which + // [Handler] implementations can set HTTP trailers. // - // Changing the header map after a call to WriteHeader (or - // Write) has no effect unless the HTTP status code was of the + // Changing the header map after a call to [ResponseWriter.WriteHeader] (or + // [ResponseWriter.Write]) has no effect unless the HTTP status code was of the // 1xx class or the modified headers are trailers. // // There are two ways to set Trailers. The preferred way is to @@ -107,9 +106,9 @@ type ResponseWriter interface { // trailer keys which will come later. In this case, those // keys of the Header map are treated as if they were // trailers. See the example. The second way, for trailer - // keys not known to the Handler until after the first Write, - // is to prefix the Header map keys with the TrailerPrefix - // constant value. See TrailerPrefix. + // keys not known to the [Handler] until after the first [ResponseWriter.Write], + // is to prefix the [Header] map keys with the [TrailerPrefix] + // constant value. // // To suppress automatic response headers (such as "Date"), set // their value to nil. @@ -117,11 +116,11 @@ type ResponseWriter interface { // Write writes the data to the connection as part of an HTTP reply. // - // If WriteHeader has not yet been called, Write calls + // If [ResponseWriter.WriteHeader] has not yet been called, Write calls // WriteHeader(http.StatusOK) before writing the data. If the Header // does not contain a Content-Type line, Write adds a Content-Type set // to the result of passing the initial 512 bytes of written data to - // DetectContentType. Additionally, if the total size of all written + // [DetectContentType]. Additionally, if the total size of all written // data is under a few KB and there are no Flush calls, the // Content-Length header is added automatically. // @@ -162,8 +161,8 @@ type ResponseWriter interface { // The Flusher interface is implemented by ResponseWriters that allow // an HTTP handler to flush buffered data to the client. // -// The default HTTP/1.x and HTTP/2 ResponseWriter implementations -// support Flusher, but ResponseWriter wrappers may not. Handlers +// The default HTTP/1.x and HTTP/2 [ResponseWriter] implementations +// support [Flusher], but ResponseWriter wrappers may not. Handlers // should always test for this ability at runtime. // // Note that even for ResponseWriters that support Flush, @@ -178,7 +177,7 @@ type Flusher interface { // The Hijacker interface is implemented by ResponseWriters that allow // an HTTP handler to take over the connection. // -// The default ResponseWriter for HTTP/1.x connections supports +// The default [ResponseWriter] for HTTP/1.x connections supports // Hijacker, but HTTP/2 connections intentionally do not. // ResponseWriter wrappers may also not support Hijacker. Handlers // should always test for this ability at runtime. @@ -212,7 +211,7 @@ type Hijacker interface { // if the client has disconnected before the response is ready. // // Deprecated: the CloseNotifier interface predates Go's context package. -// New code should use Request.Context instead. +// New code should use [Request.Context] instead. type CloseNotifier interface { // CloseNotify returns a channel that receives at most a // single value (true) when the client connection has gone @@ -506,7 +505,7 @@ func (c *response) EnableFullDuplex() error { return nil } -// TrailerPrefix is a magic prefix for ResponseWriter.Header map keys +// TrailerPrefix is a magic prefix for [ResponseWriter.Header] map keys // that, if present, signals that the map entry is actually for // the response trailers, and not the response headers. The prefix // is stripped after the ServeHTTP call finishes and the values are @@ -572,13 +571,12 @@ type writerOnly struct { io.Writer } -// ReadFrom is here to optimize copying from an *os.File regular file -// to a *net.TCPConn with sendfile, or from a supported src type such +// ReadFrom is here to optimize copying from an [*os.File] regular file +// to a [*net.TCPConn] with sendfile, or from a supported src type such // as a *net.TCPConn on Linux with splice. func (w *response) ReadFrom(src io.Reader) (n int64, err error) { - bufp := copyBufPool.Get().(*[]byte) - buf := *bufp - defer copyBufPool.Put(bufp) + buf := getCopyBuf() + defer putCopyBuf(buf) // Our underlying w.conn.rwc is usually a *TCPConn (with its // own ReadFrom method). If not, just fall back to the normal @@ -808,11 +806,18 @@ var ( bufioWriter4kPool sync.Pool ) -var copyBufPool = sync.Pool{ - New: func() any { - b := make([]byte, 32*1024) - return &b - }, +const copyBufPoolSize = 32 * 1024 + +var copyBufPool = sync.Pool{New: func() any { return new([copyBufPoolSize]byte) }} + +func getCopyBuf() []byte { + return copyBufPool.Get().(*[copyBufPoolSize]byte)[:] +} +func putCopyBuf(b []byte) { + if len(b) != copyBufPoolSize { + panic("trying to put back buffer of the wrong size in the copyBufPool") + } + copyBufPool.Put((*[copyBufPoolSize]byte)(b)) } func bufioWriterPool(size int) *sync.Pool { @@ -862,7 +867,7 @@ func putBufioWriter(bw *bufio.Writer) { // DefaultMaxHeaderBytes is the maximum permitted size of the headers // in an HTTP request. -// This can be overridden by setting Server.MaxHeaderBytes. +// This can be overridden by setting [Server.MaxHeaderBytes]. const DefaultMaxHeaderBytes = 1 << 20 // 1 MB func (srv *Server) maxHeaderBytes() int { @@ -935,11 +940,11 @@ func (ecr *expectContinueReader) Close() error { } // TimeFormat is the time format to use when generating times in HTTP -// headers. It is like time.RFC1123 but hard-codes GMT as the time +// headers. It is like [time.RFC1123] but hard-codes GMT as the time // zone. The time being formatted must be in UTC for Format to // generate the correct format. // -// For parsing this time format, see ParseTime. +// For parsing this time format, see [ParseTime]. const TimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT" // appendTime is a non-allocating version of []byte(t.UTC().Format(TimeFormat)) @@ -1585,13 +1590,13 @@ func (w *response) bodyAllowed() bool { // The Writers are wired together like: // // 1. *response (the ResponseWriter) -> -// 2. (*response).w, a *bufio.Writer of bufferBeforeChunkingSize bytes -> +// 2. (*response).w, a [*bufio.Writer] of bufferBeforeChunkingSize bytes -> // 3. chunkWriter.Writer (whose writeHeader finalizes Content-Length/Type) // and which writes the chunk headers, if needed -> // 4. conn.bufw, a *bufio.Writer of default (4kB) bytes, writing to -> // 5. checkConnErrorWriter{c}, which notes any non-nil error on Write // and populates c.werr with it if so, but otherwise writes to -> -// 6. the rwc, the net.Conn. +// 6. the rwc, the [net.Conn]. // // TODO(bradfitz): short-circuit some of the buffering when the // initial header contains both a Content-Type and Content-Length. @@ -1752,8 +1757,12 @@ func (c *conn) close() { // and processes its final data before they process the subsequent RST // from closing a connection with known unread data. // This RST seems to occur mostly on BSD systems. (And Windows?) -// This timeout is somewhat arbitrary (~latency around the planet). -const rstAvoidanceDelay = 500 * time.Millisecond +// This timeout is somewhat arbitrary (~latency around the planet), +// and may be modified by tests. +// +// TODO(bcmills): This should arguably be a server configuration parameter, +// not a hard-coded value. +var rstAvoidanceDelay = 500 * time.Millisecond type closeWriter interface { CloseWrite() error @@ -1772,6 +1781,27 @@ func (c *conn) closeWriteAndWait() { if tcp, ok := c.rwc.(closeWriter); ok { tcp.CloseWrite() } + + // When we return from closeWriteAndWait, the caller will fully close the + // connection. If client is still writing to the connection, this will cause + // the write to fail with ECONNRESET or similar. Unfortunately, many TCP + // implementations will also drop unread packets from the client's read buffer + // when a write fails, causing our final response to be truncated away too. + // + // As a result, https://www.rfc-editor.org/rfc/rfc7230#section-6.6 recommends + // that “[t]he server … continues to read from the connection until it + // receives a corresponding close by the client, or until the server is + // reasonably certain that its own TCP stack has received the client's + // acknowledgement of the packet(s) containing the server's last response.” + // + // Unfortunately, we have no straightforward way to be “reasonably certain” + // that we have received the client's ACK, and at any rate we don't want to + // allow a misbehaving client to soak up server connections indefinitely by + // withholding an ACK, nor do we want to go through the complexity or overhead + // of using low-level APIs to figure out when a TCP round-trip has completed. + // + // Instead, we declare that we are “reasonably certain” that we received the + // ACK if maxRSTAvoidanceDelay has elapsed. time.Sleep(rstAvoidanceDelay) } @@ -1971,7 +2001,7 @@ func (c *conn) serve(ctx context.Context) { fmt.Fprintf(c.rwc, "HTTP/1.1 %d %s: %s%s%d %s: %s", v.code, StatusText(v.code), v.text, errorHeaders, v.code, StatusText(v.code), v.text) return } - publicErr := "400 Bad Request" + const publicErr = "400 Bad Request" fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr) return } @@ -2067,8 +2097,8 @@ func (w *response) sendExpectationFailed() { w.finishRequest() } -// Hijack implements the Hijacker.Hijack method. Our response is both a ResponseWriter -// and a Hijacker. +// Hijack implements the [Hijacker.Hijack] method. Our response is both a [ResponseWriter] +// and a [Hijacker]. func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { if w.handlerDone.Load() { panic("net/http: Hijack called after ServeHTTP finished") @@ -2128,7 +2158,7 @@ func requestBodyRemains(rc io.ReadCloser) bool { // The HandlerFunc type is an adapter to allow the use of // ordinary functions as HTTP handlers. If f is a function // with the appropriate signature, HandlerFunc(f) is a -// Handler that calls f. +// [Handler] that calls f. type HandlerFunc func(ResponseWriter, *Request) // ServeHTTP calls f(w, r). @@ -2187,9 +2217,9 @@ func StripPrefix(prefix string, h Handler) Handler { // which may be a path relative to the request path. // // The provided code should be in the 3xx range and is usually -// StatusMovedPermanently, StatusFound or StatusSeeOther. +// [StatusMovedPermanently], [StatusFound] or [StatusSeeOther]. // -// If the Content-Type header has not been set, Redirect sets it +// If the Content-Type header has not been set, [Redirect] sets it // to "text/html; charset=utf-8" and writes a small HTML body. // Setting the Content-Type header to any value, including nil, // disables that behavior. @@ -2277,7 +2307,7 @@ func (rh *redirectHandler) ServeHTTP(w ResponseWriter, r *Request) { // status code. // // The provided code should be in the 3xx range and is usually -// StatusMovedPermanently, StatusFound or StatusSeeOther. +// [StatusMovedPermanently], [StatusFound] or [StatusSeeOther]. func RedirectHandler(url string, code int) Handler { return &redirectHandler{url, code} } @@ -2287,52 +2317,132 @@ func RedirectHandler(url string, code int) Handler { // patterns and calls the handler for the pattern that // most closely matches the URL. // -// Patterns name fixed, rooted paths, like "/favicon.ico", -// or rooted subtrees, like "/images/" (note the trailing slash). -// Longer patterns take precedence over shorter ones, so that -// if there are handlers registered for both "/images/" -// and "/images/thumbnails/", the latter handler will be -// called for paths beginning with "/images/thumbnails/" and the -// former will receive requests for any other paths in the -// "/images/" subtree. -// -// Note that since a pattern ending in a slash names a rooted subtree, -// the pattern "/" matches all paths not matched by other registered -// patterns, not just the URL with Path == "/". -// -// If a subtree has been registered and a request is received naming the -// subtree root without its trailing slash, ServeMux redirects that -// request to the subtree root (adding the trailing slash). This behavior can -// be overridden with a separate registration for the path without -// the trailing slash. For example, registering "/images/" causes ServeMux +// # Patterns +// +// Patterns can match the method, host and path of a request. +// Some examples: +// +// - "/index.html" matches the path "/index.html" for any host and method. +// - "GET /static/" matches a GET request whose path begins with "/static/". +// - "example.com/" matches any request to the host "example.com". +// - "example.com/{$}" matches requests with host "example.com" and path "/". +// - "/b/{bucket}/o/{objectname...}" matches paths whose first segment is "b" +// and whose third segment is "o". The name "bucket" denotes the second +// segment and "objectname" denotes the remainder of the path. +// +// In general, a pattern looks like +// +// [METHOD ][HOST]/[PATH] +// +// All three parts are optional; "/" is a valid pattern. +// If METHOD is present, it must be followed by a single space. +// +// Literal (that is, non-wildcard) parts of a pattern match +// the corresponding parts of a request case-sensitively. +// +// A pattern with no method matches every method. A pattern +// with the method GET matches both GET and HEAD requests. +// Otherwise, the method must match exactly. +// +// A pattern with no host matches every host. +// A pattern with a host matches URLs on that host only. +// +// A path can include wildcard segments of the form {NAME} or {NAME...}. +// For example, "/b/{bucket}/o/{objectname...}". +// The wildcard name must be a valid Go identifier. +// Wildcards must be full path segments: they must be preceded by a slash and followed by +// either a slash or the end of the string. +// For example, "/b_{bucket}" is not a valid pattern. +// +// Normally a wildcard matches only a single path segment, +// ending at the next literal slash (not %2F) in the request URL. +// But if the "..." is present, then the wildcard matches the remainder of the URL path, including slashes. +// (Therefore it is invalid for a "..." wildcard to appear anywhere but at the end of a pattern.) +// The match for a wildcard can be obtained by calling [Request.PathValue] with the wildcard's name. +// A trailing slash in a path acts as an anonymous "..." wildcard. +// +// The special wildcard {$} matches only the end of the URL. +// For example, the pattern "/{$}" matches only the path "/", +// whereas the pattern "/" matches every path. +// +// For matching, both pattern paths and incoming request paths are unescaped segment by segment. +// So, for example, the path "/a%2Fb/100%25" is treated as having two segments, "a/b" and "100%". +// The pattern "/a%2fb/" matches it, but the pattern "/a/b/" does not. +// +// # Precedence +// +// If two or more patterns match a request, then the most specific pattern takes precedence. +// A pattern P1 is more specific than P2 if P1 matches a strict subset of P2’s requests; +// that is, if P2 matches all the requests of P1 and more. +// If neither is more specific, then the patterns conflict. +// There is one exception to this rule, for backwards compatibility: +// if two patterns would otherwise conflict and one has a host while the other does not, +// then the pattern with the host takes precedence. +// If a pattern passed [ServeMux.Handle] or [ServeMux.HandleFunc] conflicts with +// another pattern that is already registered, those functions panic. +// +// As an example of the general rule, "/images/thumbnails/" is more specific than "/images/", +// so both can be registered. +// The former matches paths beginning with "/images/thumbnails/" +// and the latter will match any other path in the "/images/" subtree. +// +// As another example, consider the patterns "GET /" and "/index.html": +// both match a GET request for "/index.html", but the former pattern +// matches all other GET and HEAD requests, while the latter matches any +// request for "/index.html" that uses a different method. +// The patterns conflict. +// +// # Trailing-slash redirection +// +// Consider a [ServeMux] with a handler for a subtree, registered using a trailing slash or "..." wildcard. +// If the ServeMux receives a request for the subtree root without a trailing slash, +// it redirects the request by adding the trailing slash. +// This behavior can be overridden with a separate registration for the path without +// the trailing slash or "..." wildcard. For example, registering "/images/" causes ServeMux // to redirect a request for "/images" to "/images/", unless "/images" has // been registered separately. // -// Patterns may optionally begin with a host name, restricting matches to -// URLs on that host only. Host-specific patterns take precedence over -// general patterns, so that a handler might register for the two patterns -// "/codesearch" and "codesearch.google.com/" without also taking over -// requests for "http://www.google.com/". +// # Request sanitizing // // ServeMux also takes care of sanitizing the URL request path and the Host // header, stripping the port number and redirecting any request containing . or -// .. elements or repeated slashes to an equivalent, cleaner URL. +// .. segments or repeated slashes to an equivalent, cleaner URL. +// +// # Compatibility +// +// The pattern syntax and matching behavior of ServeMux changed significantly +// in Go 1.22. To restore the old behavior, set the GODEBUG environment variable +// to "httpmuxgo121=1". This setting is read once, at program startup; changes +// during execution will be ignored. +// +// The backwards-incompatible changes include: +// - Wildcards are just ordinary literal path segments in 1.21. +// For example, the pattern "/{x}" will match only that path in 1.21, +// but will match any one-segment path in 1.22. +// - In 1.21, no pattern was rejected, unless it was empty or conflicted with an existing pattern. +// In 1.22, syntactically invalid patterns will cause [ServeMux.Handle] and [ServeMux.HandleFunc] to panic. +// For example, in 1.21, the patterns "/{" and "/a{x}" match themselves, +// but in 1.22 they are invalid and will cause a panic when registered. +// - In 1.22, each segment of a pattern is unescaped; this was not done in 1.21. +// For example, in 1.22 the pattern "/%61" matches the path "/a" ("%61" being the URL escape sequence for "a"), +// but in 1.21 it would match only the path "/%2561" (where "%25" is the escape for the percent sign). +// - When matching patterns to paths, in 1.22 each segment of the path is unescaped; in 1.21, the entire path is unescaped. +// This change mostly affects how paths with %2F escapes adjacent to slashes are treated. +// See https://go.dev/issue/21955 for details. type ServeMux struct { - mu sync.RWMutex - m map[string]muxEntry - es []muxEntry // slice of entries sorted from longest to shortest. - hosts bool // whether any patterns contain hostnames + mu sync.RWMutex + tree routingNode + index routingIndex + patterns []*pattern // TODO(jba): remove if possible + mux121 serveMux121 // used only when GODEBUG=httpmuxgo121=1 } -type muxEntry struct { - h Handler - pattern string +// NewServeMux allocates and returns a new [ServeMux]. +func NewServeMux() *ServeMux { + return &ServeMux{} } -// NewServeMux allocates and returns a new ServeMux. -func NewServeMux() *ServeMux { return new(ServeMux) } - -// DefaultServeMux is the default ServeMux used by Serve. +// DefaultServeMux is the default [ServeMux] used by [Serve]. var DefaultServeMux = &defaultServeMux var defaultServeMux ServeMux @@ -2372,66 +2482,6 @@ func stripHostPort(h string) string { return host } -// Find a handler on a handler map given a path string. -// Most-specific (longest) pattern wins. -func (mux *ServeMux) match(path string) (h Handler, pattern string) { - // Check for exact match first. - v, ok := mux.m[path] - if ok { - return v.h, v.pattern - } - - // Check for longest valid match. mux.es contains all patterns - // that end in / sorted from longest to shortest. - for _, e := range mux.es { - if strings.HasPrefix(path, e.pattern) { - return e.h, e.pattern - } - } - return nil, "" -} - -// redirectToPathSlash determines if the given path needs appending "/" to it. -// This occurs when a handler for path + "/" was already registered, but -// not for path itself. If the path needs appending to, it creates a new -// URL, setting the path to u.Path + "/" and returning true to indicate so. -func (mux *ServeMux) redirectToPathSlash(host, path string, u *url.URL) (*url.URL, bool) { - mux.mu.RLock() - shouldRedirect := mux.shouldRedirectRLocked(host, path) - mux.mu.RUnlock() - if !shouldRedirect { - return u, false - } - path = path + "/" - u = &url.URL{Path: path, RawQuery: u.RawQuery} - return u, true -} - -// shouldRedirectRLocked reports whether the given path and host should be redirected to -// path+"/". This should happen if a handler is registered for path+"/" but -// not path -- see comments at ServeMux. -func (mux *ServeMux) shouldRedirectRLocked(host, path string) bool { - p := []string{path, host + path} - - for _, c := range p { - if _, exist := mux.m[c]; exist { - return false - } - } - - n := len(path) - if n == 0 { - return false - } - for _, c := range p { - if _, exist := mux.m[c+"/"]; exist { - return path[n-1] != '/' - } - } - - return false -} - // Handler returns the handler to use for the given request, // consulting r.Method, r.Host, and r.URL.Path. It always returns // a non-nil handler. If the path is not in its canonical form, the @@ -2443,61 +2493,175 @@ func (mux *ServeMux) shouldRedirectRLocked(host, path string) bool { // // Handler also returns the registered pattern that matches the // request or, in the case of internally-generated redirects, -// the pattern that will match after following the redirect. +// the path that will match after following the redirect. // // If there is no registered handler that applies to the request, // Handler returns a “page not found” handler and an empty pattern. func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) { - + if use121 { + return mux.mux121.findHandler(r) + } + h, p, _, _ := mux.findHandler(r) + return h, p +} + +// findHandler finds a handler for a request. +// If there is a matching handler, it returns it and the pattern that matched. +// Otherwise it returns a Redirect or NotFound handler with the path that would match +// after the redirect. +func (mux *ServeMux) findHandler(r *Request) (h Handler, patStr string, _ *pattern, matches []string) { + var n *routingNode + host := r.URL.Host + escapedPath := r.URL.EscapedPath() + path := escapedPath // CONNECT requests are not canonicalized. if r.Method == "CONNECT" { // If r.URL.Path is /tree and its handler is not registered, // the /tree -> /tree/ redirect applies to CONNECT requests // but the path canonicalization does not. - if u, ok := mux.redirectToPathSlash(r.URL.Host, r.URL.Path, r.URL); ok { - return RedirectHandler(u.String(), StatusMovedPermanently), u.Path + _, _, u := mux.matchOrRedirect(host, r.Method, path, r.URL) + if u != nil { + return RedirectHandler(u.String(), StatusMovedPermanently), u.Path, nil, nil + } + // Redo the match, this time with r.Host instead of r.URL.Host. + // Pass a nil URL to skip the trailing-slash redirect logic. + n, matches, _ = mux.matchOrRedirect(r.Host, r.Method, path, nil) + } else { + // All other requests have any port stripped and path cleaned + // before passing to mux.handler. + host = stripHostPort(r.Host) + path = cleanPath(path) + + // If the given path is /tree and its handler is not registered, + // redirect for /tree/. + var u *url.URL + n, matches, u = mux.matchOrRedirect(host, r.Method, path, r.URL) + if u != nil { + return RedirectHandler(u.String(), StatusMovedPermanently), u.Path, nil, nil + } + if path != escapedPath { + // Redirect to cleaned path. + patStr := "" + if n != nil { + patStr = n.pattern.String() + } + u := &url.URL{Path: path, RawQuery: r.URL.RawQuery} + return RedirectHandler(u.String(), StatusMovedPermanently), patStr, nil, nil } - - return mux.handler(r.Host, r.URL.Path) } + if n == nil { + // We didn't find a match with the request method. To distinguish between + // Not Found and Method Not Allowed, see if there is another pattern that + // matches except for the method. + allowedMethods := mux.matchingMethods(host, path) + if len(allowedMethods) > 0 { + return HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Allow", strings.Join(allowedMethods, ", ")) + Error(w, StatusText(StatusMethodNotAllowed), StatusMethodNotAllowed) + }), "", nil, nil + } + return NotFoundHandler(), "", nil, nil + } + return n.handler, n.pattern.String(), n.pattern, matches +} - // All other requests have any port stripped and path cleaned - // before passing to mux.handler. - host := stripHostPort(r.Host) - path := cleanPath(r.URL.Path) +// matchOrRedirect looks up a node in the tree that matches the host, method and path. +// +// If the url argument is non-nil, handler also deals with trailing-slash +// redirection: when a path doesn't match exactly, the match is tried again +// after appending "/" to the path. If that second match succeeds, the last +// return value is the URL to redirect to. +func (mux *ServeMux) matchOrRedirect(host, method, path string, u *url.URL) (_ *routingNode, matches []string, redirectTo *url.URL) { + mux.mu.RLock() + defer mux.mu.RUnlock() - // If the given path is /tree and its handler is not registered, - // redirect for /tree/. - if u, ok := mux.redirectToPathSlash(host, path, r.URL); ok { - return RedirectHandler(u.String(), StatusMovedPermanently), u.Path + n, matches := mux.tree.match(host, method, path) + // If we have an exact match, or we were asked not to try trailing-slash redirection, + // then we're done. + if !exactMatch(n, path) && u != nil { + // If there is an exact match with a trailing slash, then redirect. + path += "/" + n2, _ := mux.tree.match(host, method, path) + if exactMatch(n2, path) { + return nil, nil, &url.URL{Path: cleanPath(u.Path) + "/", RawQuery: u.RawQuery} + } } + return n, matches, nil +} - if path != r.URL.Path { - _, pattern = mux.handler(host, path) - u := &url.URL{Path: path, RawQuery: r.URL.RawQuery} - return RedirectHandler(u.String(), StatusMovedPermanently), pattern +// exactMatch reports whether the node's pattern exactly matches the path. +// As a special case, if the node is nil, exactMatch return false. +// +// Before wildcards were introduced, it was clear that an exact match meant +// that the pattern and path were the same string. The only other possibility +// was that a trailing-slash pattern, like "/", matched a path longer than +// it, like "/a". +// +// With wildcards, we define an inexact match as any one where a multi wildcard +// matches a non-empty string. All other matches are exact. +// For example, these are all exact matches: +// +// pattern path +// /a /a +// /{x} /a +// /a/{$} /a/ +// /a/ /a/ +// +// The last case has a multi wildcard (implicitly), but the match is exact because +// the wildcard matches the empty string. +// +// Examples of matches that are not exact: +// +// pattern path +// / /a +// /a/{x...} /a/b +func exactMatch(n *routingNode, path string) bool { + if n == nil { + return false } + // We can't directly implement the definition (empty match for multi + // wildcard) because we don't record a match for anonymous multis. - return mux.handler(host, r.URL.Path) + // If there is no multi, the match is exact. + if !n.pattern.lastSegment().multi { + return true + } + + // If the path doesn't end in a trailing slash, then the multi match + // is non-empty. + if len(path) > 0 && path[len(path)-1] != '/' { + return false + } + // Only patterns ending in {$} or a multi wildcard can + // match a path with a trailing slash. + // For the match to be exact, the number of pattern + // segments should be the same as the number of slashes in the path. + // E.g. "/a/b/{$}" and "/a/b/{...}" exactly match "/a/b/", but "/a/" does not. + return len(n.pattern.segments) == strings.Count(path, "/") } -// handler is the main implementation of Handler. -// The path is known to be in canonical form, except for CONNECT methods. -func (mux *ServeMux) handler(host, path string) (h Handler, pattern string) { +// matchingMethods return a sorted list of all methods that would match with the given host and path. +func (mux *ServeMux) matchingMethods(host, path string) []string { + // Hold the read lock for the entire method so that the two matches are done + // on the same set of registered patterns. mux.mu.RLock() defer mux.mu.RUnlock() + ms := map[string]bool{} + mux.tree.matchingMethods(host, path, ms) + // matchOrRedirect will try appending a trailing slash if there is no match. + mux.tree.matchingMethods(host, path+"/", ms) + methods := mapKeys(ms) + sort.Strings(methods) + return methods +} - // Host-specific pattern takes precedence over generic ones - if mux.hosts { - h, pattern = mux.match(host + path) - } - if h == nil { - h, pattern = mux.match(path) - } - if h == nil { - h, pattern = NotFoundHandler(), "" +// TODO(jba): replace with maps.Keys when it is defined. +func mapKeys[K comparable, V any](m map[K]V) []K { + var ks []K + for k := range m { + ks = append(ks, k) } - return + return ks } // ServeHTTP dispatches the request to the handler whose @@ -2510,82 +2674,117 @@ func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) { w.WriteHeader(StatusBadRequest) return } - h, _ := mux.Handler(r) + var h Handler + if use121 { + h, _ = mux.mux121.findHandler(r) + } else { + h, _, r.pat, r.matches = mux.findHandler(r) + } h.ServeHTTP(w, r) } +// The four functions below all call ServeMux.register so that callerLocation +// always refers to user code. + // Handle registers the handler for the given pattern. -// If a handler already exists for pattern, Handle panics. +// If the given pattern conflicts, with one that is already registered, Handle +// panics. func (mux *ServeMux) Handle(pattern string, handler Handler) { - mux.mu.Lock() - defer mux.mu.Unlock() - - if pattern == "" { - panic("http: invalid pattern") - } - if handler == nil { - panic("http: nil handler") - } - if _, exist := mux.m[pattern]; exist { - panic("http: multiple registrations for " + pattern) + if use121 { + mux.mux121.handle(pattern, handler) + } else { + mux.register(pattern, handler) } +} - if mux.m == nil { - mux.m = make(map[string]muxEntry) +// HandleFunc registers the handler function for the given pattern. +// If the given pattern conflicts, with one that is already registered, HandleFunc +// panics. +func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { + if use121 { + mux.mux121.handleFunc(pattern, handler) + } else { + mux.register(pattern, HandlerFunc(handler)) } - e := muxEntry{h: handler, pattern: pattern} - mux.m[pattern] = e - if pattern[len(pattern)-1] == '/' { - mux.es = appendSorted(mux.es, e) +} + +// Handle registers the handler for the given pattern in [DefaultServeMux]. +// The documentation for [ServeMux] explains how patterns are matched. +func Handle(pattern string, handler Handler) { + if use121 { + DefaultServeMux.mux121.handle(pattern, handler) + } else { + DefaultServeMux.register(pattern, handler) } +} - if pattern[0] != '/' { - mux.hosts = true +// HandleFunc registers the handler function for the given pattern in [DefaultServeMux]. +// The documentation for [ServeMux] explains how patterns are matched. +func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { + if use121 { + DefaultServeMux.mux121.handleFunc(pattern, handler) + } else { + DefaultServeMux.register(pattern, HandlerFunc(handler)) } } -func appendSorted(es []muxEntry, e muxEntry) []muxEntry { - n := len(es) - i := sort.Search(n, func(i int) bool { - return len(es[i].pattern) < len(e.pattern) - }) - if i == n { - return append(es, e) +func (mux *ServeMux) register(pattern string, handler Handler) { + if err := mux.registerErr(pattern, handler); err != nil { + panic(err) } - // we now know that i points at where we want to insert - es = append(es, muxEntry{}) // try to grow the slice in place, any entry works. - copy(es[i+1:], es[i:]) // Move shorter entries down - es[i] = e - return es } -// HandleFunc registers the handler function for the given pattern. -func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { +func (mux *ServeMux) registerErr(patstr string, handler Handler) error { + if patstr == "" { + return errors.New("http: invalid pattern") + } if handler == nil { - panic("http: nil handler") + return errors.New("http: nil handler") + } + if f, ok := handler.(HandlerFunc); ok && f == nil { + return errors.New("http: nil handler") } - mux.Handle(pattern, HandlerFunc(handler)) -} -// Handle registers the handler for the given pattern -// in the DefaultServeMux. -// The documentation for ServeMux explains how patterns are matched. -func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) } + pat, err := parsePattern(patstr) + if err != nil { + return fmt.Errorf("parsing %q: %w", patstr, err) + } -// HandleFunc registers the handler function for the given pattern -// in the DefaultServeMux. -// The documentation for ServeMux explains how patterns are matched. -func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { - DefaultServeMux.HandleFunc(pattern, handler) + // Get the caller's location, for better conflict error messages. + // Skip register and whatever calls it. + _, file, line, ok := runtime.Caller(3) + if !ok { + pat.loc = "unknown location" + } else { + pat.loc = fmt.Sprintf("%s:%d", file, line) + } + + mux.mu.Lock() + defer mux.mu.Unlock() + // Check for conflict. + if err := mux.index.possiblyConflictingPatterns(pat, func(pat2 *pattern) error { + if pat.conflictsWith(pat2) { + d := describeConflict(pat, pat2) + return fmt.Errorf("pattern %q (registered at %s) conflicts with pattern %q (registered at %s):\n%s", + pat, pat.loc, pat2, pat2.loc, d) + } + return nil + }); err != nil { + return err + } + mux.tree.addPattern(pat, handler) + mux.index.addPattern(pat) + mux.patterns = append(mux.patterns, pat) + return nil } // Serve accepts incoming HTTP connections on the listener l, // creating a new service goroutine for each. The service goroutines // read requests and then call handler to reply to them. // -// The handler is typically nil, in which case the DefaultServeMux is used. +// The handler is typically nil, in which case [DefaultServeMux] is used. // -// HTTP/2 support is only enabled if the Listener returns *tls.Conn +// HTTP/2 support is only enabled if the Listener returns [*tls.Conn] // connections and they were configured with "h2" in the TLS // Config.NextProtos. // @@ -2599,7 +2798,7 @@ func Serve(l net.Listener, handler Handler) error { // creating a new service goroutine for each. The service goroutines // read requests and then call handler to reply to them. // -// The handler is typically nil, in which case the DefaultServeMux is used. +// The handler is typically nil, in which case [DefaultServeMux] is used. // // Additionally, files containing a certificate and matching private key // for the server must be provided. If the certificate is signed by a @@ -2725,13 +2924,13 @@ type Server struct { } // Close immediately closes all active net.Listeners and any -// connections in state StateNew, StateActive, or StateIdle. For a -// graceful shutdown, use Shutdown. +// connections in state [StateNew], [StateActive], or [StateIdle]. For a +// graceful shutdown, use [Server.Shutdown]. // // Close does not attempt to close (and does not even know about) // any hijacked connections, such as WebSockets. // -// Close returns any error returned from closing the Server's +// Close returns any error returned from closing the [Server]'s // underlying Listener(s). func (srv *Server) Close() error { srv.inShutdown.Store(true) @@ -2769,16 +2968,16 @@ const shutdownPollIntervalMax = 500 * time.Millisecond // indefinitely for connections to return to idle and then shut down. // If the provided context expires before the shutdown is complete, // Shutdown returns the context's error, otherwise it returns any -// error returned from closing the Server's underlying Listener(s). +// error returned from closing the [Server]'s underlying Listener(s). // -// When Shutdown is called, Serve, ListenAndServe, and -// ListenAndServeTLS immediately return ErrServerClosed. Make sure the +// When Shutdown is called, [Serve], [ListenAndServe], and +// [ListenAndServeTLS] immediately return [ErrServerClosed]. Make sure the // program doesn't exit and waits instead for Shutdown to return. // // Shutdown does not attempt to close nor wait for hijacked // connections such as WebSockets. The caller of Shutdown should // separately notify such long-lived connections of shutdown and wait -// for them to close, if desired. See RegisterOnShutdown for a way to +// for them to close, if desired. See [Server.RegisterOnShutdown] for a way to // register shutdown notification functions. // // Once Shutdown has been called on a server, it may not be reused; @@ -2821,7 +3020,7 @@ func (srv *Server) Shutdown(ctx context.Context) error { } } -// RegisterOnShutdown registers a function to call on Shutdown. +// RegisterOnShutdown registers a function to call on [Server.Shutdown]. // This can be used to gracefully shutdown connections that have // undergone ALPN protocol upgrade or that have been hijacked. // This function should start protocol-specific graceful shutdown, @@ -2869,7 +3068,7 @@ func (s *Server) closeListenersLocked() error { } // A ConnState represents the state of a client connection to a server. -// It's used by the optional Server.ConnState hook. +// It's used by the optional [Server.ConnState] hook. type ConnState int const ( @@ -2946,7 +3145,7 @@ func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) { // behavior doesn't match that of many proxies, and the mismatch can lead to // security issues. // -// AllowQuerySemicolons should be invoked before Request.ParseForm is called. +// AllowQuerySemicolons should be invoked before [Request.ParseForm] is called. func AllowQuerySemicolons(h Handler) Handler { return HandlerFunc(func(w ResponseWriter, r *Request) { if strings.Contains(r.URL.RawQuery, ";") { @@ -2963,13 +3162,13 @@ func AllowQuerySemicolons(h Handler) Handler { } // ListenAndServe listens on the TCP network address srv.Addr and then -// calls Serve to handle requests on incoming connections. +// calls [Serve] to handle requests on incoming connections. // Accepted connections are configured to enable TCP keep-alives. // // If srv.Addr is blank, ":http" is used. // -// ListenAndServe always returns a non-nil error. After Shutdown or Close, -// the returned error is ErrServerClosed. +// ListenAndServe always returns a non-nil error. After [Server.Shutdown] or [Server.Close], +// the returned error is [ErrServerClosed]. func (srv *Server) ListenAndServe() error { if srv.shuttingDown() { return ErrServerClosed @@ -3009,20 +3208,20 @@ func (srv *Server) shouldConfigureHTTP2ForServe() bool { return strSliceContains(srv.TLSConfig.NextProtos, http2NextProtoTLS) } -// ErrServerClosed is returned by the Server's Serve, ServeTLS, ListenAndServe, -// and ListenAndServeTLS methods after a call to Shutdown or Close. +// ErrServerClosed is returned by the [Server.Serve], [ServeTLS], [ListenAndServe], +// and [ListenAndServeTLS] methods after a call to [Server.Shutdown] or [Server.Close]. var ErrServerClosed = errors.New("http: Server closed") // Serve accepts incoming connections on the Listener l, creating a // new service goroutine for each. The service goroutines read requests and // then call srv.Handler to reply to them. // -// HTTP/2 support is only enabled if the Listener returns *tls.Conn +// HTTP/2 support is only enabled if the Listener returns [*tls.Conn] // connections and they were configured with "h2" in the TLS // Config.NextProtos. // // Serve always returns a non-nil error and closes l. -// After Shutdown or Close, the returned error is ErrServerClosed. +// After [Server.Shutdown] or [Server.Close], the returned error is [ErrServerClosed]. func (srv *Server) Serve(l net.Listener) error { if fn := testHookServerServe; fn != nil { fn(srv, l) // call hook with unwrapped listener @@ -3092,14 +3291,14 @@ func (srv *Server) Serve(l net.Listener) error { // setup and then read requests, calling srv.Handler to reply to them. // // Files containing a certificate and matching private key for the -// server must be provided if neither the Server's +// server must be provided if neither the [Server]'s // TLSConfig.Certificates nor TLSConfig.GetCertificate are populated. // If the certificate is signed by a certificate authority, the // certFile should be the concatenation of the server's certificate, // any intermediates, and the CA's certificate. // -// ServeTLS always returns a non-nil error. After Shutdown or Close, the -// returned error is ErrServerClosed. +// ServeTLS always returns a non-nil error. After [Server.Shutdown] or [Server.Close], the +// returned error is [ErrServerClosed]. func (srv *Server) ServeTLS(l net.Listener, certFile, keyFile string) error { // Setup HTTP/2 before srv.Serve, to initialize srv.TLSConfig // before we clone it and create the TLS Listener. @@ -3228,10 +3427,10 @@ func logf(r *Request, format string, args ...any) { } // ListenAndServe listens on the TCP network address addr and then calls -// Serve with handler to handle requests on incoming connections. +// [Serve] with handler to handle requests on incoming connections. // Accepted connections are configured to enable TCP keep-alives. // -// The handler is typically nil, in which case the DefaultServeMux is used. +// The handler is typically nil, in which case [DefaultServeMux] is used. // // ListenAndServe always returns a non-nil error. func ListenAndServe(addr string, handler Handler) error { @@ -3239,7 +3438,7 @@ func ListenAndServe(addr string, handler Handler) error { return server.ListenAndServe() } -// ListenAndServeTLS acts identically to ListenAndServe, except that it +// ListenAndServeTLS acts identically to [ListenAndServe], except that it // expects HTTPS connections. Additionally, files containing a certificate and // matching private key for the server must be provided. If the certificate // is signed by a certificate authority, the certFile should be the concatenation @@ -3250,11 +3449,11 @@ func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error { } // ListenAndServeTLS listens on the TCP network address srv.Addr and -// then calls ServeTLS to handle requests on incoming TLS connections. +// then calls [ServeTLS] to handle requests on incoming TLS connections. // Accepted connections are configured to enable TCP keep-alives. // // Filenames containing a certificate and matching private key for the -// server must be provided if neither the Server's TLSConfig.Certificates +// server must be provided if neither the [Server]'s TLSConfig.Certificates // nor TLSConfig.GetCertificate are populated. If the certificate is // signed by a certificate authority, the certFile should be the // concatenation of the server's certificate, any intermediates, and @@ -3262,8 +3461,8 @@ func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error { // // If srv.Addr is blank, ":https" is used. // -// ListenAndServeTLS always returns a non-nil error. After Shutdown or -// Close, the returned error is ErrServerClosed. +// ListenAndServeTLS always returns a non-nil error. After [Server.Shutdown] or +// [Server.Close], the returned error is [ErrServerClosed]. func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { if srv.shuttingDown() { return ErrServerClosed @@ -3333,17 +3532,17 @@ func (srv *Server) onceSetNextProtoDefaults() { } } -// TimeoutHandler returns a Handler that runs h with the given time limit. +// TimeoutHandler returns a [Handler] that runs h with the given time limit. // // The new Handler calls h.ServeHTTP to handle each request, but if a // call runs for longer than its time limit, the handler responds with // a 503 Service Unavailable error and the given message in its body. // (If msg is empty, a suitable default message will be sent.) -// After such a timeout, writes by h to its ResponseWriter will return -// ErrHandlerTimeout. +// After such a timeout, writes by h to its [ResponseWriter] will return +// [ErrHandlerTimeout]. // -// TimeoutHandler supports the Pusher interface but does not support -// the Hijacker or Flusher interfaces. +// TimeoutHandler supports the [Pusher] interface but does not support +// the [Hijacker] or [Flusher] interfaces. func TimeoutHandler(h Handler, dt time.Duration, msg string) Handler { return &timeoutHandler{ handler: h, @@ -3352,7 +3551,7 @@ func TimeoutHandler(h Handler, dt time.Duration, msg string) Handler { } } -// ErrHandlerTimeout is returned on ResponseWriter Write calls +// ErrHandlerTimeout is returned on [ResponseWriter] Write calls // in handlers which have timed out. var ErrHandlerTimeout = errors.New("http: Handler timeout") @@ -3441,7 +3640,7 @@ type timeoutWriter struct { var _ Pusher = (*timeoutWriter)(nil) -// Push implements the Pusher interface. +// Push implements the [Pusher] interface. func (tw *timeoutWriter) Push(target string, opts *PushOptions) error { if pusher, ok := tw.w.(Pusher); ok { return pusher.Push(target, opts) @@ -3526,7 +3725,7 @@ type initALPNRequest struct { h serverHandler } -// BaseContext is an exported but unadvertised http.Handler method +// BaseContext is an exported but unadvertised [http.Handler] method // recognized by x/net/http2 to pass down a context; the TLSNextProto // API predates context support so we shoehorn through the only // interface we have available. @@ -3613,7 +3812,6 @@ func numLeadingCRorLF(v []byte) (n int) { break } return - } func strSliceContains(ss []string, s string) bool { @@ -3635,7 +3833,7 @@ func tlsRecordHeaderLooksLikeHTTP(hdr [5]byte) bool { return false } -// MaxBytesHandler returns a Handler that runs h with its ResponseWriter and Request.Body wrapped by a MaxBytesReader. +// MaxBytesHandler returns a [Handler] that runs h with its [ResponseWriter] and [Request.Body] wrapped by a MaxBytesReader. func MaxBytesHandler(h Handler, n int64) Handler { return HandlerFunc(func(w ResponseWriter, r *Request) { r2 := *r diff --git a/src/net/http/server_test.go b/src/net/http/server_test.go index d17c5c1e7e..e81e3bb6b0 100644 --- a/src/net/http/server_test.go +++ b/src/net/http/server_test.go @@ -8,6 +8,8 @@ package http import ( "fmt" + "net/url" + "regexp" "testing" "time" ) @@ -64,6 +66,190 @@ func TestServerTLSHandshakeTimeout(t *testing.T) { } } +type handler struct{ i int } + +func (handler) ServeHTTP(ResponseWriter, *Request) {} + +func TestFindHandler(t *testing.T) { + mux := NewServeMux() + for _, ph := range []struct { + pat string + h Handler + }{ + {"/", &handler{1}}, + {"/foo/", &handler{2}}, + {"/foo", &handler{3}}, + {"/bar/", &handler{4}}, + {"//foo", &handler{5}}, + } { + mux.Handle(ph.pat, ph.h) + } + + for _, test := range []struct { + method string + path string + wantHandler string + }{ + {"GET", "/", "&http.handler{i:1}"}, + {"GET", "//", `&http.redirectHandler{url:"/", code:301}`}, + {"GET", "/foo/../bar/./..//baz", `&http.redirectHandler{url:"/baz", code:301}`}, + {"GET", "/foo", "&http.handler{i:3}"}, + {"GET", "/foo/x", "&http.handler{i:2}"}, + {"GET", "/bar/x", "&http.handler{i:4}"}, + {"GET", "/bar", `&http.redirectHandler{url:"/bar/", code:301}`}, + {"CONNECT", "/", "&http.handler{i:1}"}, + {"CONNECT", "//", "&http.handler{i:1}"}, + {"CONNECT", "//foo", "&http.handler{i:5}"}, + {"CONNECT", "/foo/../bar/./..//baz", "&http.handler{i:2}"}, + {"CONNECT", "/foo", "&http.handler{i:3}"}, + {"CONNECT", "/foo/x", "&http.handler{i:2}"}, + {"CONNECT", "/bar/x", "&http.handler{i:4}"}, + {"CONNECT", "/bar", `&http.redirectHandler{url:"/bar/", code:301}`}, + } { + var r Request + r.Method = test.method + r.Host = "example.com" + r.URL = &url.URL{Path: test.path} + gotH, _, _, _ := mux.findHandler(&r) + got := fmt.Sprintf("%#v", gotH) + if got != test.wantHandler { + t.Errorf("%s %q: got %q, want %q", test.method, test.path, got, test.wantHandler) + } + } +} + +func TestEmptyServeMux(t *testing.T) { + // Verify that a ServeMux with nothing registered + // doesn't panic. + mux := NewServeMux() + var r Request + r.Method = "GET" + r.Host = "example.com" + r.URL = &url.URL{Path: "/"} + _, p := mux.Handler(&r) + if p != "" { + t.Errorf(`got %q, want ""`, p) + } +} + +func TestRegisterErr(t *testing.T) { + mux := NewServeMux() + h := &handler{} + mux.Handle("/a", h) + + for _, test := range []struct { + pattern string + handler Handler + wantRegexp string + }{ + {"", h, "invalid pattern"}, + {"/", nil, "nil handler"}, + {"/", HandlerFunc(nil), "nil handler"}, + {"/{x", h, `parsing "/\{x": at offset 1: bad wildcard segment`}, + {"/a", h, `conflicts with pattern.* \(registered at .*/server_test.go:\d+`}, + } { + t.Run(fmt.Sprintf("%s:%#v", test.pattern, test.handler), func(t *testing.T) { + err := mux.registerErr(test.pattern, test.handler) + if err == nil { + t.Fatal("got nil error") + } + re := regexp.MustCompile(test.wantRegexp) + if g := err.Error(); !re.MatchString(g) { + t.Errorf("\ngot %q\nwant string matching %q", g, test.wantRegexp) + } + }) + } +} + +func TestExactMatch(t *testing.T) { + for _, test := range []struct { + pattern string + path string + want bool + }{ + {"", "/a", false}, + {"/", "/a", false}, + {"/a", "/a", true}, + {"/a/{x...}", "/a/b", false}, + {"/a/{x}", "/a/b", true}, + {"/a/b/", "/a/b/", true}, + {"/a/b/{$}", "/a/b/", true}, + {"/a/", "/a/b/", false}, + } { + var n *routingNode + if test.pattern != "" { + pat := mustParsePattern(t, test.pattern) + n = &routingNode{pattern: pat} + } + got := exactMatch(n, test.path) + if got != test.want { + t.Errorf("%q, %s: got %t, want %t", test.pattern, test.path, got, test.want) + } + } +} + +func TestEscapedPathsAndPatterns(t *testing.T) { + matches := []struct { + pattern string + paths []string // paths that match the pattern + paths121 []string // paths that matched the pattern in Go 1.21. + }{ + { + "/a", // this pattern matches a path that unescapes to "/a" + []string{"/a", "/%61"}, + []string{"/a", "/%61"}, + }, + { + "/%62", // patterns are unescaped by segment; matches paths that unescape to "/b" + []string{"/b", "/%62"}, + []string{"/%2562"}, // In 1.21, patterns were not unescaped but paths were. + }, + { + "/%7B/%7D", // the only way to write a pattern that matches '{' or '}' + []string{"/{/}", "/%7b/}", "/{/%7d", "/%7B/%7D"}, + []string{"/%257B/%257D"}, // In 1.21, patterns were not unescaped. + }, + { + "/%x", // patterns that do not unescape are left unchanged + []string{"/%25x"}, + []string{"/%25x"}, + }, + } + + run := func(t *testing.T, test121 bool) { + defer func(u bool) { use121 = u }(use121) + use121 = test121 + + mux := NewServeMux() + for _, m := range matches { + mux.HandleFunc(m.pattern, func(w ResponseWriter, r *Request) {}) + } + + for _, m := range matches { + paths := m.paths + if use121 { + paths = m.paths121 + } + for _, p := range paths { + u, err := url.ParseRequestURI(p) + if err != nil { + t.Fatal(err) + } + req := &Request{ + URL: u, + } + _, gotPattern := mux.Handler(req) + if g, w := gotPattern, m.pattern; g != w { + t.Errorf("%s: pattern: got %q, want %q", p, g, w) + } + } + } + } + + t.Run("latest", func(t *testing.T) { run(t, false) }) + t.Run("1.21", func(t *testing.T) { run(t, true) }) +} + func BenchmarkServerMatch(b *testing.B) { fn := func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "OK") @@ -90,7 +276,11 @@ func BenchmarkServerMatch(b *testing.B) { "/products/", "/products/3/image.jpg"} b.StartTimer() for i := 0; i < b.N; i++ { - if h, p := mux.match(paths[i%len(paths)]); h != nil && p == "" { + r, err := NewRequest("GET", "http://example.com/"+paths[i%len(paths)], nil) + if err != nil { + b.Fatal(err) + } + if h, p, _, _ := mux.findHandler(r); h != nil && p == "" { b.Error("impossible") } } diff --git a/src/net/http/transfer.go b/src/net/http/transfer.go index d6f26a709c..315c6e2723 100644 --- a/src/net/http/transfer.go +++ b/src/net/http/transfer.go @@ -9,6 +9,7 @@ import ( "bytes" "errors" "fmt" + "internal/godebug" "io" "net/http/httptrace" "net/http/internal" @@ -409,7 +410,10 @@ func (t *transferWriter) writeBody(w io.Writer) (err error) { // // This function is only intended for use in writeBody. func (t *transferWriter) doBodyCopy(dst io.Writer, src io.Reader) (n int64, err error) { - n, err = io.Copy(dst, src) + buf := getCopyBuf() + defer putCopyBuf(buf) + + n, err = io.CopyBuffer(dst, src, buf) if err != nil && err != io.EOF { t.bodyReadError = err } @@ -527,7 +531,7 @@ func readTransfer(msg any, r *bufio.Reader) (err error) { return err } if isResponse && t.RequestMethod == "HEAD" { - if n, err := parseContentLength(t.Header.get("Content-Length")); err != nil { + if n, err := parseContentLength(t.Header["Content-Length"]); err != nil { return err } else { t.ContentLength = n @@ -707,18 +711,15 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header, return -1, nil } - // Logic based on Content-Length - var cl string - if len(contentLens) == 1 { - cl = textproto.TrimString(contentLens[0]) - } - if cl != "" { - n, err := parseContentLength(cl) + if len(contentLens) > 0 { + // Logic based on Content-Length + n, err := parseContentLength(contentLens) if err != nil { return -1, err } return n, nil } + header.Del("Content-Length") if isRequest { @@ -816,10 +817,10 @@ type body struct { onHitEOF func() // if non-nil, func to call when EOF is Read } -// ErrBodyReadAfterClose is returned when reading a Request or Response +// ErrBodyReadAfterClose is returned when reading a [Request] or [Response] // Body after the body has been closed. This typically happens when the body is -// read after an HTTP Handler calls WriteHeader or Write on its -// ResponseWriter. +// read after an HTTP [Handler] calls WriteHeader or Write on its +// [ResponseWriter]. var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed Body") func (b *body) Read(p []byte) (n int, err error) { @@ -1038,19 +1039,31 @@ func (bl bodyLocked) Read(p []byte) (n int, err error) { return bl.b.readLocked(p) } -// parseContentLength trims whitespace from s and returns -1 if no value -// is set, or the value if it's >= 0. -func parseContentLength(cl string) (int64, error) { - cl = textproto.TrimString(cl) - if cl == "" { +var laxContentLength = godebug.New("httplaxcontentlength") + +// parseContentLength checks that the header is valid and then trims +// whitespace. It returns -1 if no value is set otherwise the value +// if it's >= 0. +func parseContentLength(clHeaders []string) (int64, error) { + if len(clHeaders) == 0 { return -1, nil } + cl := textproto.TrimString(clHeaders[0]) + + // The Content-Length must be a valid numeric value. + // See: https://datatracker.ietf.org/doc/html/rfc2616/#section-14.13 + if cl == "" { + if laxContentLength.Value() == "1" { + laxContentLength.IncNonDefault() + return -1, nil + } + return 0, badStringError("invalid empty Content-Length", cl) + } n, err := strconv.ParseUint(cl, 10, 63) if err != nil { return 0, badStringError("bad Content-Length", cl) } return int64(n), nil - } // finishAsyncByteRead finishes reading the 1-byte sniff diff --git a/src/net/http/transfer_test.go b/src/net/http/transfer_test.go index 5e0df896d8..b1a5a93103 100644 --- a/src/net/http/transfer_test.go +++ b/src/net/http/transfer_test.go @@ -112,8 +112,8 @@ func (w *mockTransferWriter) Write(p []byte) (int, error) { } func TestTransferWriterWriteBodyReaderTypes(t *testing.T) { - fileType := reflect.TypeOf(&os.File{}) - bufferType := reflect.TypeOf(&bytes.Buffer{}) + fileType := reflect.TypeFor[*os.File]() + bufferType := reflect.TypeFor[*bytes.Buffer]() nBytes := int64(1 << 10) newFileFunc := func() (r io.Reader, done func(), err error) { @@ -264,6 +264,12 @@ func TestTransferWriterWriteBodyReaderTypes(t *testing.T) { actualReader = reflect.TypeOf(lr.R) } else { actualReader = reflect.TypeOf(mw.CalledReader) + // We have to handle this special case for genericWriteTo in os, + // this struct is introduced to support a zero-copy optimization, + // check out https://go.dev/issue/58808 for details. + if actualReader.Kind() == reflect.Struct && actualReader.PkgPath() == "os" && actualReader.Name() == "fileWithoutWriteTo" { + actualReader = actualReader.Field(1).Type + } } if tc.expectedReader != actualReader { @@ -333,6 +339,10 @@ func TestParseContentLength(t *testing.T) { wantErr error }{ { + cl: "", + wantErr: badStringError("invalid empty Content-Length", ""), + }, + { cl: "3", wantErr: nil, }, @@ -356,7 +366,7 @@ func TestParseContentLength(t *testing.T) { } for _, tt := range tests { - if _, gotErr := parseContentLength(tt.cl); !reflect.DeepEqual(gotErr, tt.wantErr) { + if _, gotErr := parseContentLength([]string{tt.cl}); !reflect.DeepEqual(gotErr, tt.wantErr) { t.Errorf("%q:\n\tgot=%v\n\twant=%v", tt.cl, gotErr, tt.wantErr) } } diff --git a/src/net/http/transport.go b/src/net/http/transport.go index c07352b018..17067ac07c 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -35,8 +35,8 @@ import ( "golang.org/x/net/http/httpproxy" ) -// DefaultTransport is the default implementation of Transport and is -// used by DefaultClient. It establishes network connections as needed +// DefaultTransport is the default implementation of [Transport] and is +// used by [DefaultClient]. It establishes network connections as needed // and caches them for reuse by subsequent calls. It uses HTTP proxies // as directed by the environment variables HTTP_PROXY, HTTPS_PROXY // and NO_PROXY (or the lowercase versions thereof). @@ -53,42 +53,42 @@ var DefaultTransport RoundTripper = &Transport{ ExpectContinueTimeout: 1 * time.Second, } -// DefaultMaxIdleConnsPerHost is the default value of Transport's +// DefaultMaxIdleConnsPerHost is the default value of [Transport]'s // MaxIdleConnsPerHost. const DefaultMaxIdleConnsPerHost = 2 -// Transport is an implementation of RoundTripper that supports HTTP, +// Transport is an implementation of [RoundTripper] that supports HTTP, // HTTPS, and HTTP proxies (for either HTTP or HTTPS with CONNECT). // // By default, Transport caches connections for future re-use. // This may leave many open connections when accessing many hosts. -// This behavior can be managed using Transport's CloseIdleConnections method -// and the MaxIdleConnsPerHost and DisableKeepAlives fields. +// This behavior can be managed using [Transport.CloseIdleConnections] method +// and the [Transport.MaxIdleConnsPerHost] and [Transport.DisableKeepAlives] fields. // // Transports should be reused instead of created as needed. // Transports are safe for concurrent use by multiple goroutines. // // A Transport is a low-level primitive for making HTTP and HTTPS requests. -// For high-level functionality, such as cookies and redirects, see Client. +// For high-level functionality, such as cookies and redirects, see [Client]. // // Transport uses HTTP/1.1 for HTTP URLs and either HTTP/1.1 or HTTP/2 // for HTTPS URLs, depending on whether the server supports HTTP/2, -// and how the Transport is configured. The DefaultTransport supports HTTP/2. +// and how the Transport is configured. The [DefaultTransport] supports HTTP/2. // To explicitly enable HTTP/2 on a transport, use golang.org/x/net/http2 // and call ConfigureTransport. See the package docs for more about HTTP/2. // // Responses with status codes in the 1xx range are either handled // automatically (100 expect-continue) or ignored. The one // exception is HTTP status code 101 (Switching Protocols), which is -// considered a terminal status and returned by RoundTrip. To see the +// considered a terminal status and returned by [Transport.RoundTrip]. To see the // ignored 1xx responses, use the httptrace trace package's // ClientTrace.Got1xxResponse. // // Transport only retries a request upon encountering a network error // if the connection has been already been used successfully and if the -// request is idempotent and either has no body or has its Request.GetBody +// request is idempotent and either has no body or has its [Request.GetBody] // defined. HTTP requests are considered idempotent if they have HTTP methods -// GET, HEAD, OPTIONS, or TRACE; or if their Header map contains an +// GET, HEAD, OPTIONS, or TRACE; or if their [Header] map contains an // "Idempotency-Key" or "X-Idempotency-Key" entry. If the idempotency key // value is a zero-length slice, the request is treated as idempotent but the // header is not sent on the wire. @@ -117,6 +117,10 @@ type Transport struct { // "https", and "socks5" are supported. If the scheme is empty, // "http" is assumed. // + // If the proxy URL contains a userinfo subcomponent, + // the proxy request will pass the username and password + // in a Proxy-Authorization header. + // // If Proxy is nil or returns a nil *URL, no proxy is used. Proxy func(*Request) (*url.URL, error) @@ -233,7 +237,7 @@ type Transport struct { // TLSNextProto specifies how the Transport switches to an // alternate protocol (such as HTTP/2) after a TLS ALPN - // protocol negotiation. If Transport dials an TLS connection + // protocol negotiation. If Transport dials a TLS connection // with a non-empty protocol name and TLSNextProto contains a // map entry for that key (such as "h2"), then the func is // called with the request's authority (such as "example.com" @@ -449,7 +453,7 @@ func ProxyFromEnvironment(req *Request) (*url.URL, error) { return envProxyFunc()(req.URL) } -// ProxyURL returns a proxy function (for use in a Transport) +// ProxyURL returns a proxy function (for use in a [Transport]) // that always returns the same URL. func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) { return func(*Request) (*url.URL, error) { @@ -748,14 +752,14 @@ func (pc *persistConn) shouldRetryRequest(req *Request, err error) bool { var ErrSkipAltProtocol = errors.New("net/http: skip alternate protocol") // RegisterProtocol registers a new protocol with scheme. -// The Transport will pass requests using the given scheme to rt. +// The [Transport] will pass requests using the given scheme to rt. // It is rt's responsibility to simulate HTTP request semantics. // // RegisterProtocol can be used by other packages to provide // implementations of protocol schemes like "ftp" or "file". // -// If rt.RoundTrip returns ErrSkipAltProtocol, the Transport will -// handle the RoundTrip itself for that one request, as if the +// If rt.RoundTrip returns [ErrSkipAltProtocol], the Transport will +// handle the [Transport.RoundTrip] itself for that one request, as if the // protocol were not registered. func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) { t.altMu.Lock() @@ -795,9 +799,9 @@ func (t *Transport) CloseIdleConnections() { } // CancelRequest cancels an in-flight request by closing its connection. -// CancelRequest should only be called after RoundTrip has returned. +// CancelRequest should only be called after [Transport.RoundTrip] has returned. // -// Deprecated: Use Request.WithContext to create a request with a +// Deprecated: Use [Request.WithContext] to create a request with a // cancelable context instead. CancelRequest cannot cancel HTTP/2 // requests. func (t *Transport) CancelRequest(req *Request) { @@ -1205,7 +1209,6 @@ func (t *Transport) dial(ctx context.Context, network, addr string) (net.Conn, e type wantConn struct { cm connectMethod key connectMethodKey // cm.key() - ctx context.Context // context for dial ready chan struct{} // closed when pc, err pair is delivered // hooks for testing to know when dials are done @@ -1214,7 +1217,8 @@ type wantConn struct { beforeDial func() afterDial func() - mu sync.Mutex // protects pc, err, close(ready) + mu sync.Mutex // protects ctx, pc, err, close(ready) + ctx context.Context // context for dial, cleared after delivered or canceled pc *persistConn err error } @@ -1229,6 +1233,13 @@ func (w *wantConn) waiting() bool { } } +// getCtxForDial returns context for dial or nil if connection was delivered or canceled. +func (w *wantConn) getCtxForDial() context.Context { + w.mu.Lock() + defer w.mu.Unlock() + return w.ctx +} + // tryDeliver attempts to deliver pc, err to w and reports whether it succeeded. func (w *wantConn) tryDeliver(pc *persistConn, err error) bool { w.mu.Lock() @@ -1238,6 +1249,7 @@ func (w *wantConn) tryDeliver(pc *persistConn, err error) bool { return false } + w.ctx = nil w.pc = pc w.err = err if w.pc == nil && w.err == nil { @@ -1255,6 +1267,7 @@ func (w *wantConn) cancel(t *Transport, err error) { close(w.ready) // catch misbehavior in future delivery } pc := w.pc + w.ctx = nil w.pc = nil w.err = err w.mu.Unlock() @@ -1463,8 +1476,13 @@ func (t *Transport) queueForDial(w *wantConn) { // If the dial is canceled or unsuccessful, dialConnFor decrements t.connCount[w.cm.key()]. func (t *Transport) dialConnFor(w *wantConn) { defer w.afterDial() + ctx := w.getCtxForDial() + if ctx == nil { + t.decConnsPerHost(w.key) + return + } - pc, err := t.dialConn(w.ctx, w.cm) + pc, err := t.dialConn(ctx, w.cm) delivered := w.tryDeliver(pc, err) if err == nil && (!delivered || pc.alt != nil) { // pconn was not passed to w, @@ -1560,6 +1578,11 @@ func (pconn *persistConn) addTLS(ctx context.Context, name string, trace *httptr }() if err := <-errc; err != nil { plainConn.Close() + if err == (tlsHandshakeTimeoutError{}) { + // Now that we have closed the connection, + // wait for the call to HandshakeContext to return. + <-errc + } if trace != nil && trace.TLSHandshakeDone != nil { trace.TLSHandshakeDone(tls.ConnectionState{}, err) } @@ -2248,7 +2271,7 @@ func (pc *persistConn) readLoop() { } case <-rc.req.Cancel: alive = false - pc.t.CancelRequest(rc.req) + pc.t.cancelRequest(rc.cancelKey, errRequestCanceled) case <-rc.req.Context().Done(): alive = false pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err()) diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index 028fecc961..3fb5624664 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -730,6 +730,56 @@ func testTransportMaxConnsPerHost(t *testing.T, mode testMode) { } } +func TestTransportMaxConnsPerHostDialCancellation(t *testing.T) { + run(t, testTransportMaxConnsPerHostDialCancellation, + testNotParallel, // because test uses SetPendingDialHooks + []testMode{http1Mode, https1Mode, http2Mode}, + ) +} + +func testTransportMaxConnsPerHostDialCancellation(t *testing.T, mode testMode) { + CondSkipHTTP2(t) + + h := HandlerFunc(func(w ResponseWriter, r *Request) { + _, err := w.Write([]byte("foo")) + if err != nil { + t.Fatalf("Write: %v", err) + } + }) + + cst := newClientServerTest(t, mode, h) + defer cst.close() + ts := cst.ts + c := ts.Client() + tr := c.Transport.(*Transport) + tr.MaxConnsPerHost = 1 + + // This request is cancelled when dial is queued, which preempts dialing. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + SetPendingDialHooks(cancel, nil) + defer SetPendingDialHooks(nil, nil) + + req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil) + _, err := c.Do(req) + if !errors.Is(err, context.Canceled) { + t.Errorf("expected error %v, got %v", context.Canceled, err) + } + + // This request should succeed. + SetPendingDialHooks(nil, nil) + req, _ = NewRequest("GET", ts.URL, nil) + resp, err := c.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body failed: %v", err) + } +} + func TestTransportRemovesDeadIdleConnections(t *testing.T) { run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode}) } @@ -2099,25 +2149,50 @@ func testIssue3644(t *testing.T, mode testMode) { // Test that a client receives a server's reply, even if the server doesn't read // the entire request body. -func TestIssue3595(t *testing.T) { run(t, testIssue3595) } +func TestIssue3595(t *testing.T) { + // Not parallel: modifies the global rstAvoidanceDelay. + run(t, testIssue3595, testNotParallel) +} func testIssue3595(t *testing.T, mode testMode) { - const deniedMsg = "sorry, denied." - ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { - Error(w, deniedMsg, StatusUnauthorized) - })).ts - c := ts.Client() - res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a')) - if err != nil { - t.Errorf("Post: %v", err) - return - } - got, err := io.ReadAll(res.Body) - if err != nil { - t.Fatalf("Body ReadAll: %v", err) - } - if !strings.Contains(string(got), deniedMsg) { - t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg) - } + runTimeSensitiveTest(t, []time.Duration{ + 1 * time.Millisecond, + 5 * time.Millisecond, + 10 * time.Millisecond, + 50 * time.Millisecond, + 100 * time.Millisecond, + 500 * time.Millisecond, + time.Second, + 5 * time.Second, + }, func(t *testing.T, timeout time.Duration) error { + SetRSTAvoidanceDelay(t, timeout) + t.Logf("set RST avoidance delay to %v", timeout) + + const deniedMsg = "sorry, denied." + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + Error(w, deniedMsg, StatusUnauthorized) + })) + // We need to close cst explicitly here so that in-flight server + // requests don't race with the call to SetRSTAvoidanceDelay for a retry. + defer cst.close() + ts := cst.ts + c := ts.Client() + + res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a')) + if err != nil { + return fmt.Errorf("Post: %v", err) + } + got, err := io.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("Body ReadAll: %v", err) + } + t.Logf("server response:\n%s", got) + if !strings.Contains(string(got), deniedMsg) { + // If we got an RST packet too early, we should have seen an error + // from io.ReadAll, not a silently-truncated body. + t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg) + } + return nil + }) } // From https://golang.org/issue/4454 , @@ -2440,6 +2515,7 @@ func testTransportCancelRequest(t *testing.T, mode testMode) { if d > 0 { t.Logf("pending requests = %d after %v (want 0)", n, d) } + return false } return true }) @@ -2599,6 +2675,65 @@ func testCancelRequestWithChannel(t *testing.T, mode testMode) { if d > 0 { t.Logf("pending requests = %d after %v (want 0)", n, d) } + return false + } + return true + }) +} + +// Issue 51354 +func TestCancelRequestWithBodyWithChannel(t *testing.T) { + run(t, testCancelRequestWithBodyWithChannel, []testMode{http1Mode}) +} +func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) { + if testing.Short() { + t.Skip("skipping test in -short mode") + } + + const msg = "Hello" + unblockc := make(chan struct{}) + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + io.WriteString(w, msg) + w.(Flusher).Flush() // send headers and some body + <-unblockc + })).ts + defer close(unblockc) + + c := ts.Client() + tr := c.Transport.(*Transport) + + req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody")) + cancel := make(chan struct{}) + req.Cancel = cancel + + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + body := make([]byte, len(msg)) + n, _ := io.ReadFull(res.Body, body) + if n != len(body) || !bytes.Equal(body, []byte(msg)) { + t.Errorf("Body = %q; want %q", body[:n], msg) + } + close(cancel) + + tail, err := io.ReadAll(res.Body) + res.Body.Close() + if err != ExportErrRequestCanceled { + t.Errorf("Body.Read error = %v; want errRequestCanceled", err) + } else if len(tail) > 0 { + t.Errorf("Spurious bytes from Body.Read: %q", tail) + } + + // Verify no outstanding requests after readLoop/writeLoop + // goroutines shut down. + waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool { + n := tr.NumPendingRequestsForTesting() + if n > 0 { + if d > 0 { + t.Logf("pending requests = %d after %v (want 0)", n, d) + } + return false } return true }) @@ -3414,6 +3549,7 @@ func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) { c net.Conn } var getOkay bool + var copying sync.WaitGroup closeConn := func() { sconn.Lock() defer sconn.Unlock() @@ -3425,7 +3561,10 @@ func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) { } } } - defer closeConn() + defer func() { + closeConn() + copying.Wait() + }() ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method == "GET" { @@ -3437,7 +3576,12 @@ func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) { sconn.c = conn sconn.Unlock() conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive - go io.Copy(io.Discard, conn) + + copying.Add(1) + go func() { + io.Copy(io.Discard, conn) + copying.Done() + }() })).ts c := ts.Client() @@ -4267,68 +4411,78 @@ func (c *wgReadCloser) Close() error { // Issue 11745. func TestTransportPrefersResponseOverWriteError(t *testing.T) { - run(t, testTransportPrefersResponseOverWriteError) + // Not parallel: modifies the global rstAvoidanceDelay. + run(t, testTransportPrefersResponseOverWriteError, testNotParallel) } func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - const contentLengthLimit = 1024 * 1024 // 1MB - ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { - if r.ContentLength >= contentLengthLimit { - w.WriteHeader(StatusBadRequest) - r.Body.Close() - return - } - w.WriteHeader(StatusOK) - })).ts - c := ts.Client() - fail := 0 - count := 100 + runTimeSensitiveTest(t, []time.Duration{ + 1 * time.Millisecond, + 5 * time.Millisecond, + 10 * time.Millisecond, + 50 * time.Millisecond, + 100 * time.Millisecond, + 500 * time.Millisecond, + time.Second, + 5 * time.Second, + }, func(t *testing.T, timeout time.Duration) error { + SetRSTAvoidanceDelay(t, timeout) + t.Logf("set RST avoidance delay to %v", timeout) + + const contentLengthLimit = 1024 * 1024 // 1MB + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + if r.ContentLength >= contentLengthLimit { + w.WriteHeader(StatusBadRequest) + r.Body.Close() + return + } + w.WriteHeader(StatusOK) + })) + // We need to close cst explicitly here so that in-flight server + // requests don't race with the call to SetRSTAvoidanceDelay for a retry. + defer cst.close() + ts := cst.ts + c := ts.Client() - bigBody := strings.Repeat("a", contentLengthLimit*2) - var wg sync.WaitGroup - defer wg.Wait() - getBody := func() (io.ReadCloser, error) { - wg.Add(1) - body := &wgReadCloser{ - Reader: strings.NewReader(bigBody), - wg: &wg, - } - return body, nil - } + count := 100 - for i := 0; i < count; i++ { - reqBody, _ := getBody() - req, err := NewRequest("PUT", ts.URL, reqBody) - if err != nil { - reqBody.Close() - t.Fatal(err) + bigBody := strings.Repeat("a", contentLengthLimit*2) + var wg sync.WaitGroup + defer wg.Wait() + getBody := func() (io.ReadCloser, error) { + wg.Add(1) + body := &wgReadCloser{ + Reader: strings.NewReader(bigBody), + wg: &wg, + } + return body, nil } - req.ContentLength = int64(len(bigBody)) - req.GetBody = getBody - resp, err := c.Do(req) - if err != nil { - fail++ - t.Logf("%d = %#v", i, err) - if ue, ok := err.(*url.Error); ok { - t.Logf("urlErr = %#v", ue.Err) - if ne, ok := ue.Err.(*net.OpError); ok { - t.Logf("netOpError = %#v", ne.Err) - } + for i := 0; i < count; i++ { + reqBody, _ := getBody() + req, err := NewRequest("PUT", ts.URL, reqBody) + if err != nil { + reqBody.Close() + t.Fatal(err) } - } else { - resp.Body.Close() - if resp.StatusCode != 400 { - t.Errorf("Expected status code 400, got %v", resp.Status) + req.ContentLength = int64(len(bigBody)) + req.GetBody = getBody + + resp, err := c.Do(req) + if err != nil { + return fmt.Errorf("Do %d: %v", i, err) + } else { + resp.Body.Close() + if resp.StatusCode != 400 { + t.Errorf("Expected status code 400, got %v", resp.Status) + } } } - } - if fail > 0 { - t.Errorf("Failed %v out of %v\n", fail, count) - } + return nil + }) } func TestTransportAutomaticHTTP2(t *testing.T) { @@ -6750,3 +6904,36 @@ func testRequestSanitization(t *testing.T, mode testMode) { resp.Body.Close() } } + +func TestProxyAuthHeader(t *testing.T) { + // Not parallel: Sets an environment variable. + run(t, testProxyAuthHeader, []testMode{http1Mode}, testNotParallel) +} +func testProxyAuthHeader(t *testing.T, mode testMode) { + const username = "u" + const password = "@/?!" + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { + // Copy the Proxy-Authorization header to a new Request, + // since Request.BasicAuth only parses the Authorization header. + var r2 Request + r2.Header = Header{ + "Authorization": req.Header["Proxy-Authorization"], + } + gotuser, gotpass, ok := r2.BasicAuth() + if !ok || gotuser != username || gotpass != password { + t.Errorf("req.BasicAuth() = %q, %q, %v; want %q, %q, true", gotuser, gotpass, ok, username, password) + } + })) + u, err := url.Parse(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + u.User = url.UserPassword(username, password) + t.Setenv("HTTP_PROXY", u.String()) + cst.tr.Proxy = ProxyURL(u) + resp, err := cst.c.Get("http://_/") + if err != nil { + t.Fatal(err) + } + resp.Body.Close() +} diff --git a/src/net/http/triv.go b/src/net/http/triv.go index f614922c24..c1696425cd 100644 --- a/src/net/http/triv.go +++ b/src/net/http/triv.go @@ -34,7 +34,7 @@ type Counter struct { n int } -// This makes Counter satisfy the expvar.Var interface, so we can export +// This makes Counter satisfy the [expvar.Var] interface, so we can export // it directly. func (ctr *Counter) String() string { ctr.mu.Lock() diff --git a/src/net/interface.go b/src/net/interface.go index e1c9a2e2ff..20ac07d31a 100644 --- a/src/net/interface.go +++ b/src/net/interface.go @@ -114,7 +114,7 @@ func Interfaces() ([]Interface, error) { // addresses. // // The returned list does not identify the associated interface; use -// Interfaces and Interface.Addrs for more detail. +// Interfaces and [Interface.Addrs] for more detail. func InterfaceAddrs() ([]Addr, error) { ifat, err := interfaceAddrTable(nil) if err != nil { @@ -127,7 +127,7 @@ func InterfaceAddrs() ([]Addr, error) { // // On Solaris, it returns one of the logical network interfaces // sharing the logical data link; for more precision use -// InterfaceByName. +// [InterfaceByName]. func InterfaceByIndex(index int) (*Interface, error) { if index <= 0 { return nil, &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterfaceIndex} diff --git a/src/net/interface_stub.go b/src/net/interface_stub.go index 829dbc6938..4c280c6ff2 100644 --- a/src/net/interface_stub.go +++ b/src/net/interface_stub.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build (js && wasm) || wasip1 +//go:build js || wasip1 package net diff --git a/src/net/interface_test.go b/src/net/interface_test.go index 5590b06262..a97d675e7e 100644 --- a/src/net/interface_test.go +++ b/src/net/interface_test.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !wasip1 - package net import ( diff --git a/src/net/internal/socktest/main_test.go b/src/net/internal/socktest/main_test.go index 0197feb3f1..967ce6795a 100644 --- a/src/net/internal/socktest/main_test.go +++ b/src/net/internal/socktest/main_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !plan9 && !wasip1 +//go:build !js && !plan9 && !wasip1 && !windows package socktest_test diff --git a/src/net/internal/socktest/main_windows_test.go b/src/net/internal/socktest/main_windows_test.go deleted file mode 100644 index df1cb97784..0000000000 --- a/src/net/internal/socktest/main_windows_test.go +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright 2015 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package socktest_test - -import "syscall" - -var ( - socketFunc func(int, int, int) (syscall.Handle, error) - closeFunc func(syscall.Handle) error -) - -func installTestHooks() { - socketFunc = sw.Socket - closeFunc = sw.Closesocket -} - -func uninstallTestHooks() { - socketFunc = syscall.Socket - closeFunc = syscall.Closesocket -} diff --git a/src/net/internal/socktest/switch.go b/src/net/internal/socktest/switch.go index 3c37b6ff80..dea6d9288c 100644 --- a/src/net/internal/socktest/switch.go +++ b/src/net/internal/socktest/switch.go @@ -133,7 +133,7 @@ const ( // If the filter returns a non-nil error, the execution of system call // will be canceled and the system call function returns the non-nil // error. -// It can return a non-nil AfterFilter for filtering after the +// It can return a non-nil [AfterFilter] for filtering after the // execution of the system call. type Filter func(*Status) (AfterFilter, error) diff --git a/src/net/internal/socktest/sys_unix.go b/src/net/internal/socktest/sys_unix.go index 712462abf4..3eef26c70b 100644 --- a/src/net/internal/socktest/sys_unix.go +++ b/src/net/internal/socktest/sys_unix.go @@ -8,7 +8,7 @@ package socktest import "syscall" -// Socket wraps syscall.Socket. +// Socket wraps [syscall.Socket]. func (sw *Switch) Socket(family, sotype, proto int) (s int, err error) { sw.once.Do(sw.init) diff --git a/src/net/internal/socktest/sys_windows.go b/src/net/internal/socktest/sys_windows.go index 8c1c862f33..2f02446075 100644 --- a/src/net/internal/socktest/sys_windows.go +++ b/src/net/internal/socktest/sys_windows.go @@ -9,39 +9,7 @@ import ( "syscall" ) -// Socket wraps syscall.Socket. -func (sw *Switch) Socket(family, sotype, proto int) (s syscall.Handle, err error) { - sw.once.Do(sw.init) - - so := &Status{Cookie: cookie(family, sotype, proto)} - sw.fmu.RLock() - f, _ := sw.fltab[FilterSocket] - sw.fmu.RUnlock() - - af, err := f.apply(so) - if err != nil { - return syscall.InvalidHandle, err - } - s, so.Err = syscall.Socket(family, sotype, proto) - if err = af.apply(so); err != nil { - if so.Err == nil { - syscall.Closesocket(s) - } - return syscall.InvalidHandle, err - } - - sw.smu.Lock() - defer sw.smu.Unlock() - if so.Err != nil { - sw.stats.getLocked(so.Cookie).OpenFailed++ - return syscall.InvalidHandle, so.Err - } - nso := sw.addLocked(s, family, sotype, proto) - sw.stats.getLocked(nso.Cookie).Opened++ - return s, nil -} - -// WSASocket wraps syscall.WSASocket. +// WSASocket wraps [syscall.WSASocket]. func (sw *Switch) WSASocket(family, sotype, proto int32, protinfo *syscall.WSAProtocolInfo, group uint32, flags uint32) (s syscall.Handle, err error) { sw.once.Do(sw.init) @@ -73,7 +41,7 @@ func (sw *Switch) WSASocket(family, sotype, proto int32, protinfo *syscall.WSAPr return s, nil } -// Closesocket wraps syscall.Closesocket. +// Closesocket wraps [syscall.Closesocket]. func (sw *Switch) Closesocket(s syscall.Handle) (err error) { so := sw.sockso(s) if so == nil { @@ -103,7 +71,7 @@ func (sw *Switch) Closesocket(s syscall.Handle) (err error) { return nil } -// Connect wraps syscall.Connect. +// Connect wraps [syscall.Connect]. func (sw *Switch) Connect(s syscall.Handle, sa syscall.Sockaddr) (err error) { so := sw.sockso(s) if so == nil { @@ -132,7 +100,7 @@ func (sw *Switch) Connect(s syscall.Handle, sa syscall.Sockaddr) (err error) { return nil } -// ConnectEx wraps syscall.ConnectEx. +// ConnectEx wraps [syscall.ConnectEx]. func (sw *Switch) ConnectEx(s syscall.Handle, sa syscall.Sockaddr, b *byte, n uint32, nwr *uint32, o *syscall.Overlapped) (err error) { so := sw.sockso(s) if so == nil { @@ -161,7 +129,7 @@ func (sw *Switch) ConnectEx(s syscall.Handle, sa syscall.Sockaddr, b *byte, n ui return nil } -// Listen wraps syscall.Listen. +// Listen wraps [syscall.Listen]. func (sw *Switch) Listen(s syscall.Handle, backlog int) (err error) { so := sw.sockso(s) if so == nil { @@ -190,7 +158,7 @@ func (sw *Switch) Listen(s syscall.Handle, backlog int) (err error) { return nil } -// AcceptEx wraps syscall.AcceptEx. +// AcceptEx wraps [syscall.AcceptEx]. func (sw *Switch) AcceptEx(ls syscall.Handle, as syscall.Handle, b *byte, rxdatalen uint32, laddrlen uint32, raddrlen uint32, rcvd *uint32, overlapped *syscall.Overlapped) error { so := sw.sockso(ls) if so == nil { diff --git a/src/net/ip.go b/src/net/ip.go index d51ba10eec..6083dd8bf9 100644 --- a/src/net/ip.go +++ b/src/net/ip.go @@ -38,7 +38,7 @@ type IP []byte // An IPMask is a bitmask that can be used to manipulate // IP addresses for IP addressing and routing. // -// See type IPNet and func ParseCIDR for details. +// See type [IPNet] and func [ParseCIDR] for details. type IPMask []byte // An IPNet represents an IP network. @@ -72,9 +72,9 @@ func IPv4Mask(a, b, c, d byte) IPMask { return p } -// CIDRMask returns an IPMask consisting of 'ones' 1 bits +// CIDRMask returns an [IPMask] consisting of 'ones' 1 bits // followed by 0s up to a total length of 'bits' bits. -// For a mask of this form, CIDRMask is the inverse of IPMask.Size. +// For a mask of this form, CIDRMask is the inverse of [IPMask.Size]. func CIDRMask(ones, bits int) IPMask { if bits != 8*IPv4len && bits != 8*IPv6len { return nil @@ -324,8 +324,8 @@ func ipEmptyString(ip IP) string { return ip.String() } -// MarshalText implements the encoding.TextMarshaler interface. -// The encoding is the same as returned by String, with one exception: +// MarshalText implements the [encoding.TextMarshaler] interface. +// The encoding is the same as returned by [IP.String], with one exception: // When len(ip) is zero, it returns an empty slice. func (ip IP) MarshalText() ([]byte, error) { if len(ip) == 0 { @@ -337,8 +337,8 @@ func (ip IP) MarshalText() ([]byte, error) { return []byte(ip.String()), nil } -// UnmarshalText implements the encoding.TextUnmarshaler interface. -// The IP address is expected in a form accepted by ParseIP. +// UnmarshalText implements the [encoding.TextUnmarshaler] interface. +// The IP address is expected in a form accepted by [ParseIP]. func (ip *IP) UnmarshalText(text []byte) error { if len(text) == 0 { *ip = nil diff --git a/src/net/ip_test.go b/src/net/ip_test.go index 1373059abe..acc2310be1 100644 --- a/src/net/ip_test.go +++ b/src/net/ip_test.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !wasip1 - package net import ( diff --git a/src/net/iprawsock.go b/src/net/iprawsock.go index f18331a1fd..4c06b1b5ac 100644 --- a/src/net/iprawsock.go +++ b/src/net/iprawsock.go @@ -72,7 +72,7 @@ func (a *IPAddr) opAddr() Addr { // recommended, because it will return at most one of the host name's // IP addresses. // -// See func Dial for a description of the network and address +// See func [Dial] for a description of the network and address // parameters. func ResolveIPAddr(network, address string) (*IPAddr, error) { if network == "" { // a hint wildcard for Go 1.0 undocumented behavior @@ -94,19 +94,19 @@ func ResolveIPAddr(network, address string) (*IPAddr, error) { return addrs.forResolve(network, address).(*IPAddr), nil } -// IPConn is the implementation of the Conn and PacketConn interfaces +// IPConn is the implementation of the [Conn] and [PacketConn] interfaces // for IP network connections. type IPConn struct { conn } // SyscallConn returns a raw network connection. -// This implements the syscall.Conn interface. +// This implements the [syscall.Conn] interface. func (c *IPConn) SyscallConn() (syscall.RawConn, error) { if !c.ok() { return nil, syscall.EINVAL } - return newRawConn(c.fd) + return newRawConn(c.fd), nil } // ReadFromIP acts like ReadFrom but returns an IPAddr. @@ -121,7 +121,7 @@ func (c *IPConn) ReadFromIP(b []byte) (int, *IPAddr, error) { return n, addr, err } -// ReadFrom implements the PacketConn ReadFrom method. +// ReadFrom implements the [PacketConn] ReadFrom method. func (c *IPConn) ReadFrom(b []byte) (int, Addr, error) { if !c.ok() { return 0, nil, syscall.EINVAL @@ -154,7 +154,7 @@ func (c *IPConn) ReadMsgIP(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err return } -// WriteToIP acts like WriteTo but takes an IPAddr. +// WriteToIP acts like [IPConn.WriteTo] but takes an [IPAddr]. func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (int, error) { if !c.ok() { return 0, syscall.EINVAL @@ -166,7 +166,7 @@ func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (int, error) { return n, err } -// WriteTo implements the PacketConn WriteTo method. +// WriteTo implements the [PacketConn] WriteTo method. func (c *IPConn) WriteTo(b []byte, addr Addr) (int, error) { if !c.ok() { return 0, syscall.EINVAL @@ -201,7 +201,7 @@ func (c *IPConn) WriteMsgIP(b, oob []byte, addr *IPAddr) (n, oobn int, err error func newIPConn(fd *netFD) *IPConn { return &IPConn{conn{fd}} } -// DialIP acts like Dial for IP networks. +// DialIP acts like [Dial] for IP networks. // // The network must be an IP network name; see func Dial for details. // @@ -220,7 +220,7 @@ func DialIP(network string, laddr, raddr *IPAddr) (*IPConn, error) { return c, nil } -// ListenIP acts like ListenPacket for IP networks. +// ListenIP acts like [ListenPacket] for IP networks. // // The network must be an IP network name; see func Dial for details. // diff --git a/src/net/iprawsock_posix.go b/src/net/iprawsock_posix.go index 59967eb923..73b41ab522 100644 --- a/src/net/iprawsock_posix.go +++ b/src/net/iprawsock_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build unix || (js && wasm) || wasip1 || windows +//go:build unix || js || wasip1 || windows package net diff --git a/src/net/iprawsock_test.go b/src/net/iprawsock_test.go index 14c03a1f4d..7f1fc139ab 100644 --- a/src/net/iprawsock_test.go +++ b/src/net/iprawsock_test.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !wasip1 - package net import ( diff --git a/src/net/ipsock.go b/src/net/ipsock.go index 0f5da2577c..176dbc748e 100644 --- a/src/net/ipsock.go +++ b/src/net/ipsock.go @@ -83,10 +83,10 @@ func (addrs addrList) forResolve(network, addr string) Addr { switch network { case "ip": // IPv6 literal (addr does NOT contain a port) - want6 = count(addr, ':') > 0 + want6 = bytealg.CountString(addr, ':') > 0 case "tcp", "udp": // IPv6 literal. (addr contains a port, so look for '[') - want6 = count(addr, '[') > 0 + want6 = bytealg.CountString(addr, '[') > 0 } if want6 { return addrs.first(isNotIPv4) @@ -172,7 +172,7 @@ func SplitHostPort(hostport string) (host, port string, err error) { j, k := 0, 0 // The port starts after the last colon. - i := last(hostport, ':') + i := bytealg.LastIndexByteString(hostport, ':') if i < 0 { return addrErr(hostport, missingPort) } @@ -219,7 +219,7 @@ func SplitHostPort(hostport string) (host, port string, err error) { func splitHostZone(s string) (host, zone string) { // The IPv6 scoped addressing zone identifier starts after the // last percent sign. - if i := last(s, '%'); i > 0 { + if i := bytealg.LastIndexByteString(s, '%'); i > 0 { host, zone = s[:i], s[i+1:] } else { host = s diff --git a/src/net/ipsock_plan9.go b/src/net/ipsock_plan9.go index 43287431c8..c8d0180436 100644 --- a/src/net/ipsock_plan9.go +++ b/src/net/ipsock_plan9.go @@ -181,7 +181,6 @@ func dialPlan9(ctx context.Context, net string, laddr, raddr Addr) (fd *netFD, e } resc := make(chan res) go func() { - testHookDialChannel() fd, err := dialPlan9Blocking(ctx, net, laddr, raddr) select { case resc <- res{fd, err}: diff --git a/src/net/ipsock_posix.go b/src/net/ipsock_posix.go index b0a00a6296..67ce1479c6 100644 --- a/src/net/ipsock_posix.go +++ b/src/net/ipsock_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build unix || (js && wasm) || wasip1 || windows +//go:build unix || js || wasip1 || windows package net @@ -25,6 +25,15 @@ import ( // general. Unfortunately, we need to run on kernels built without // IPv6 support too. So probe the kernel to figure it out. func (p *ipStackCapabilities) probe() { + switch runtime.GOOS { + case "js", "wasip1": + // Both ipv4 and ipv6 are faked; see net_fake.go. + p.ipv4Enabled = true + p.ipv6Enabled = true + p.ipv4MappedIPv6Enabled = true + return + } + s, err := sysSocket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP) switch err { case syscall.EAFNOSUPPORT, syscall.EPROTONOSUPPORT: @@ -135,8 +144,11 @@ func favoriteAddrFamily(network string, laddr, raddr sockaddr, mode string) (fam } func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (fd *netFD, err error) { - if (runtime.GOOS == "aix" || runtime.GOOS == "windows" || runtime.GOOS == "openbsd") && mode == "dial" && raddr.isWildcard() { - raddr = raddr.toLocal(net) + switch runtime.GOOS { + case "aix", "windows", "openbsd", "js", "wasip1": + if mode == "dial" && raddr.isWildcard() { + raddr = raddr.toLocal(net) + } } family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode) return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr, ctrlCtxFn) diff --git a/src/net/listen_test.go b/src/net/listen_test.go index f0a8861370..9100b3d9f7 100644 --- a/src/net/listen_test.go +++ b/src/net/listen_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !plan9 && !wasip1 +//go:build !plan9 package net diff --git a/src/net/lookup.go b/src/net/lookup.go index a7133b53ac..3ec2660786 100644 --- a/src/net/lookup.go +++ b/src/net/lookup.go @@ -41,19 +41,20 @@ var services = map[string]map[string]int{ "domain": 53, }, "tcp": { - "ftp": 21, - "ftps": 990, - "gopher": 70, // ʕ◔ϖ◔ʔ - "http": 80, - "https": 443, - "imap2": 143, - "imap3": 220, - "imaps": 993, - "pop3": 110, - "pop3s": 995, - "smtp": 25, - "ssh": 22, - "telnet": 23, + "ftp": 21, + "ftps": 990, + "gopher": 70, // ʕ◔ϖ◔ʔ + "http": 80, + "https": 443, + "imap2": 143, + "imap3": 220, + "imaps": 993, + "pop3": 110, + "pop3s": 995, + "smtp": 25, + "submissions": 465, + "ssh": 22, + "telnet": 23, }, } @@ -83,12 +84,20 @@ const maxPortBufSize = len("mobility-header") + 10 func lookupPortMap(network, service string) (port int, error error) { switch network { - case "tcp4", "tcp6": - network = "tcp" - case "udp4", "udp6": - network = "udp" + case "ip": // no hints + if p, err := lookupPortMapWithNetwork("tcp", "ip", service); err == nil { + return p, nil + } + return lookupPortMapWithNetwork("udp", "ip", service) + case "tcp", "tcp4", "tcp6": + return lookupPortMapWithNetwork("tcp", "tcp", service) + case "udp", "udp4", "udp6": + return lookupPortMapWithNetwork("udp", "udp", service) } + return 0, &DNSError{Err: "unknown network", Name: network + "/" + service} +} +func lookupPortMapWithNetwork(network, errNetwork, service string) (port int, error error) { if m, ok := services[network]; ok { var lowerService [maxPortBufSize]byte n := copy(lowerService[:], service) @@ -96,8 +105,9 @@ func lookupPortMap(network, service string) (port int, error error) { if port, ok := m[string(lowerService[:n])]; ok && n == len(service) { return port, nil } + return 0, &DNSError{Err: "unknown port", Name: errNetwork + "/" + service, IsNotFound: true} } - return 0, &AddrError{Err: "unknown port", Addr: network + "/" + service} + return 0, &DNSError{Err: "unknown network", Name: errNetwork + "/" + service} } // ipVersion returns the provided network's IP version: '4', '6' or 0 @@ -171,8 +181,8 @@ func (r *Resolver) getLookupGroup() *singleflight.Group { // LookupHost looks up the given host using the local resolver. // It returns a slice of that host's addresses. // -// LookupHost uses context.Background internally; to specify the context, use -// Resolver.LookupHost. +// LookupHost uses [context.Background] internally; to specify the context, use +// [Resolver.LookupHost]. func LookupHost(host string) (addrs []string, err error) { return DefaultResolver.LookupHost(context.Background(), host) } @@ -407,18 +417,20 @@ func ipAddrsEface(addrs []IPAddr) []any { // LookupPort looks up the port for the given network and service. // -// LookupPort uses context.Background internally; to specify the context, use -// Resolver.LookupPort. +// LookupPort uses [context.Background] internally; to specify the context, use +// [Resolver.LookupPort]. func LookupPort(network, service string) (port int, err error) { return DefaultResolver.LookupPort(context.Background(), network, service) } // LookupPort looks up the port for the given network and service. +// +// The network must be one of "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6" or "ip". func (r *Resolver) LookupPort(ctx context.Context, network, service string) (port int, err error) { port, needsLookup := parsePort(service) if needsLookup { switch network { - case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6": + case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6", "ip": case "": // a hint wildcard for Go 1.0 undocumented behavior network = "ip" default: @@ -437,7 +449,7 @@ func (r *Resolver) LookupPort(ctx context.Context, network, service string) (por // LookupCNAME returns the canonical name for the given host. // Callers that do not care about the canonical name can call -// LookupHost or LookupIP directly; both take care of resolving +// [LookupHost] or [LookupIP] directly; both take care of resolving // the canonical name as part of the lookup. // // A canonical name is the final name after following zero @@ -449,15 +461,15 @@ func (r *Resolver) LookupPort(ctx context.Context, network, service string) (por // The returned canonical name is validated to be a properly // formatted presentation-format domain name. // -// LookupCNAME uses context.Background internally; to specify the context, use -// Resolver.LookupCNAME. +// LookupCNAME uses [context.Background] internally; to specify the context, use +// [Resolver.LookupCNAME]. func LookupCNAME(host string) (cname string, err error) { return DefaultResolver.LookupCNAME(context.Background(), host) } // LookupCNAME returns the canonical name for the given host. // Callers that do not care about the canonical name can call -// LookupHost or LookupIP directly; both take care of resolving +// [LookupHost] or [LookupIP] directly; both take care of resolving // the canonical name as part of the lookup. // // A canonical name is the final name after following zero @@ -479,7 +491,7 @@ func (r *Resolver) LookupCNAME(ctx context.Context, host string) (string, error) return cname, nil } -// LookupSRV tries to resolve an SRV query of the given service, +// LookupSRV tries to resolve an [SRV] query of the given service, // protocol, and domain name. The proto is "tcp" or "udp". // The returned records are sorted by priority and randomized // by weight within a priority. @@ -497,7 +509,7 @@ func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err err return DefaultResolver.LookupSRV(context.Background(), service, proto, name) } -// LookupSRV tries to resolve an SRV query of the given service, +// LookupSRV tries to resolve an [SRV] query of the given service, // protocol, and domain name. The proto is "tcp" or "udp". // The returned records are sorted by priority and randomized // by weight within a priority. @@ -542,8 +554,8 @@ func (r *Resolver) LookupSRV(ctx context.Context, service, proto, name string) ( // invalid names, those records are filtered out and an error // will be returned alongside the remaining results, if any. // -// LookupMX uses context.Background internally; to specify the context, use -// Resolver.LookupMX. +// LookupMX uses [context.Background] internally; to specify the context, use +// [Resolver.LookupMX]. func LookupMX(name string) ([]*MX, error) { return DefaultResolver.LookupMX(context.Background(), name) } @@ -582,8 +594,8 @@ func (r *Resolver) LookupMX(ctx context.Context, name string) ([]*MX, error) { // invalid names, those records are filtered out and an error // will be returned alongside the remaining results, if any. // -// LookupNS uses context.Background internally; to specify the context, use -// Resolver.LookupNS. +// LookupNS uses [context.Background] internally; to specify the context, use +// [Resolver.LookupNS]. func LookupNS(name string) ([]*NS, error) { return DefaultResolver.LookupNS(context.Background(), name) } @@ -617,8 +629,8 @@ func (r *Resolver) LookupNS(ctx context.Context, name string) ([]*NS, error) { // LookupTXT returns the DNS TXT records for the given domain name. // -// LookupTXT uses context.Background internally; to specify the context, use -// Resolver.LookupTXT. +// LookupTXT uses [context.Background] internally; to specify the context, use +// [Resolver.LookupTXT]. func LookupTXT(name string) ([]string, error) { return DefaultResolver.lookupTXT(context.Background(), name) } @@ -636,10 +648,10 @@ func (r *Resolver) LookupTXT(ctx context.Context, name string) ([]string, error) // out and an error will be returned alongside the remaining results, if any. // // When using the host C library resolver, at most one result will be -// returned. To bypass the host resolver, use a custom Resolver. +// returned. To bypass the host resolver, use a custom [Resolver]. // -// LookupAddr uses context.Background internally; to specify the context, use -// Resolver.LookupAddr. +// LookupAddr uses [context.Background] internally; to specify the context, use +// [Resolver.LookupAddr]. func LookupAddr(addr string) (names []string, err error) { return DefaultResolver.LookupAddr(context.Background(), addr) } diff --git a/src/net/lookup_fake.go b/src/net/lookup_fake.go deleted file mode 100644 index c27eae4ba5..0000000000 --- a/src/net/lookup_fake.go +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build js && wasm - -package net - -import ( - "context" - "syscall" -) - -func lookupProtocol(ctx context.Context, name string) (proto int, err error) { - return lookupProtocolMap(name) -} - -func (*Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) { - return nil, syscall.ENOPROTOOPT -} - -func (*Resolver) lookupIP(ctx context.Context, network, host string) (addrs []IPAddr, err error) { - return nil, syscall.ENOPROTOOPT -} - -func (*Resolver) lookupPort(ctx context.Context, network, service string) (port int, err error) { - return goLookupPort(network, service) -} - -func (*Resolver) lookupCNAME(ctx context.Context, name string) (cname string, err error) { - return "", syscall.ENOPROTOOPT -} - -func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (cname string, srvs []*SRV, err error) { - return "", nil, syscall.ENOPROTOOPT -} - -func (*Resolver) lookupMX(ctx context.Context, name string) (mxs []*MX, err error) { - return nil, syscall.ENOPROTOOPT -} - -func (*Resolver) lookupNS(ctx context.Context, name string) (nss []*NS, err error) { - return nil, syscall.ENOPROTOOPT -} - -func (*Resolver) lookupTXT(ctx context.Context, name string) (txts []string, err error) { - return nil, syscall.ENOPROTOOPT -} - -func (*Resolver) lookupAddr(ctx context.Context, addr string) (ptrs []string, err error) { - return nil, syscall.ENOPROTOOPT -} - -// concurrentThreadsLimit returns the number of threads we permit to -// run concurrently doing DNS lookups. -func concurrentThreadsLimit() int { - return 500 -} diff --git a/src/net/lookup_plan9.go b/src/net/lookup_plan9.go index 5404b996e4..8cfc4f6bb3 100644 --- a/src/net/lookup_plan9.go +++ b/src/net/lookup_plan9.go @@ -106,6 +106,22 @@ func queryDNS(ctx context.Context, addr string, typ string) (res []string, err e return query(ctx, netdir+"/dns", addr+" "+typ, 1024) } +func handlePlan9DNSError(err error, name string) error { + if stringsHasSuffix(err.Error(), "dns: name does not exist") || + stringsHasSuffix(err.Error(), "dns: resource does not exist; negrcode 0") || + stringsHasSuffix(err.Error(), "dns: resource does not exist; negrcode") { + return &DNSError{ + Err: errNoSuchHost.Error(), + Name: name, + IsNotFound: true, + } + } + return &DNSError{ + Err: err.Error(), + Name: name, + } +} + // toLower returns a lower-case version of in. Restricting us to // ASCII is sufficient to handle the IP protocol names and allow // us to not depend on the strings and unicode packages. @@ -153,12 +169,10 @@ func (*Resolver) lookupHost(ctx context.Context, host string) (addrs []string, e // host names in local network (e.g. from /lib/ndb/local) lines, err := queryCS(ctx, "net", host, "1") if err != nil { - dnsError := &DNSError{Err: err.Error(), Name: host} if stringsHasSuffix(err.Error(), "dns failure") { - dnsError.Err = errNoSuchHost.Error() - dnsError.IsNotFound = true + return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} } - return nil, dnsError + return nil, handlePlan9DNSError(err, host) } loop: for _, line := range lines { @@ -184,31 +198,11 @@ loop: return } -// preferGoOverPlan9 reports whether the resolver should use the -// "PreferGo" implementation rather than asking plan9 services -// for the answers. -func (r *Resolver) preferGoOverPlan9() bool { - _, _, res := r.preferGoOverPlan9WithOrderAndConf() - return res -} - -func (r *Resolver) preferGoOverPlan9WithOrderAndConf() (hostLookupOrder, *dnsConfig, bool) { - order, conf := systemConf().hostLookupOrder(r, "") // name is unused - - // TODO(bradfitz): for now we only permit use of the PreferGo - // implementation when there's a non-nil Resolver with a - // non-nil Dialer. This is a sign that they the code is trying - // to use their DNS-speaking net.Conn (such as an in-memory - // DNS cache) and they don't want to actually hit the network. - // Once we add support for looking the default DNS servers - // from plan9, though, then we can relax this. - return order, conf, order != hostLookupCgo && r != nil && r.Dial != nil -} - func (r *Resolver) lookupIP(ctx context.Context, network, host string) (addrs []IPAddr, err error) { - if r.preferGoOverPlan9() { - return r.goLookupIP(ctx, network, host) + if order, conf := systemConf().hostLookupOrder(r, host); order != hostLookupCgo { + return r.goLookupIP(ctx, network, host, order, conf) } + lits, err := r.lookupHost(ctx, host) if err != nil { return @@ -223,24 +217,36 @@ func (r *Resolver) lookupIP(ctx context.Context, network, host string) (addrs [] return } -func (*Resolver) lookupPort(ctx context.Context, network, service string) (port int, err error) { +func (r *Resolver) lookupPort(ctx context.Context, network, service string) (port int, err error) { switch network { - case "tcp4", "tcp6": - network = "tcp" - case "udp4", "udp6": - network = "udp" + case "ip": // no hints + if p, err := r.lookupPortWithNetwork(ctx, "tcp", "ip", service); err == nil { + return p, nil + } + return r.lookupPortWithNetwork(ctx, "udp", "ip", service) + case "tcp", "tcp4", "tcp6": + return r.lookupPortWithNetwork(ctx, "tcp", "tcp", service) + case "udp", "udp4", "udp6": + return r.lookupPortWithNetwork(ctx, "udp", "udp", service) + default: + return 0, &DNSError{Err: "unknown network", Name: network + "/" + service} } +} + +func (*Resolver) lookupPortWithNetwork(ctx context.Context, network, errNetwork, service string) (port int, err error) { lines, err := queryCS(ctx, network, "127.0.0.1", toLower(service)) if err != nil { + if stringsHasSuffix(err.Error(), "can't translate service") { + return 0, &DNSError{Err: "unknown port", Name: errNetwork + "/" + service, IsNotFound: true} + } return } - unknownPortError := &AddrError{Err: "unknown port", Addr: network + "/" + service} if len(lines) == 0 { - return 0, unknownPortError + return 0, &DNSError{Err: "unknown port", Name: errNetwork + "/" + service, IsNotFound: true} } f := getFields(lines[0]) if len(f) < 2 { - return 0, unknownPortError + return 0, &DNSError{Err: "unknown port", Name: errNetwork + "/" + service, IsNotFound: true} } s := f[1] if i := bytealg.IndexByteString(s, '!'); i >= 0 { @@ -249,21 +255,20 @@ func (*Resolver) lookupPort(ctx context.Context, network, service string) (port if n, _, ok := dtoi(s); ok { return n, nil } - return 0, unknownPortError + return 0, &DNSError{Err: "unknown port", Name: errNetwork + "/" + service, IsNotFound: true} } func (r *Resolver) lookupCNAME(ctx context.Context, name string) (cname string, err error) { - if order, conf, preferGo := r.preferGoOverPlan9WithOrderAndConf(); preferGo { + if order, conf := systemConf().hostLookupOrder(r, name); order != hostLookupCgo { return r.goLookupCNAME(ctx, name, order, conf) } lines, err := queryDNS(ctx, name, "cname") if err != nil { if stringsHasSuffix(err.Error(), "dns failure") || stringsHasSuffix(err.Error(), "resource does not exist; negrcode 0") { - cname = name + "." - err = nil + return absDomainName(name), nil } - return + return "", handlePlan9DNSError(err, cname) } if len(lines) > 0 { if f := getFields(lines[0]); len(f) >= 3 { @@ -274,7 +279,7 @@ func (r *Resolver) lookupCNAME(ctx context.Context, name string) (cname string, } func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*SRV, err error) { - if r.preferGoOverPlan9() { + if systemConf().mustUseGoResolver(r) { return r.goLookupSRV(ctx, service, proto, name) } var target string @@ -285,7 +290,7 @@ func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) ( } lines, err := queryDNS(ctx, target, "srv") if err != nil { - return + return "", nil, handlePlan9DNSError(err, name) } for _, line := range lines { f := getFields(line) @@ -306,12 +311,12 @@ func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) ( } func (r *Resolver) lookupMX(ctx context.Context, name string) (mx []*MX, err error) { - if r.preferGoOverPlan9() { + if systemConf().mustUseGoResolver(r) { return r.goLookupMX(ctx, name) } lines, err := queryDNS(ctx, name, "mx") if err != nil { - return + return nil, handlePlan9DNSError(err, name) } for _, line := range lines { f := getFields(line) @@ -327,12 +332,12 @@ func (r *Resolver) lookupMX(ctx context.Context, name string) (mx []*MX, err err } func (r *Resolver) lookupNS(ctx context.Context, name string) (ns []*NS, err error) { - if r.preferGoOverPlan9() { + if systemConf().mustUseGoResolver(r) { return r.goLookupNS(ctx, name) } lines, err := queryDNS(ctx, name, "ns") if err != nil { - return + return nil, handlePlan9DNSError(err, name) } for _, line := range lines { f := getFields(line) @@ -345,12 +350,12 @@ func (r *Resolver) lookupNS(ctx context.Context, name string) (ns []*NS, err err } func (r *Resolver) lookupTXT(ctx context.Context, name string) (txt []string, err error) { - if r.preferGoOverPlan9() { + if systemConf().mustUseGoResolver(r) { return r.goLookupTXT(ctx, name) } lines, err := queryDNS(ctx, name, "txt") if err != nil { - return + return nil, handlePlan9DNSError(err, name) } for _, line := range lines { if i := bytealg.IndexByteString(line, '\t'); i >= 0 { @@ -361,7 +366,7 @@ func (r *Resolver) lookupTXT(ctx context.Context, name string) (txt []string, er } func (r *Resolver) lookupAddr(ctx context.Context, addr string) (name []string, err error) { - if order, conf, preferGo := r.preferGoOverPlan9WithOrderAndConf(); preferGo { + if order, conf := systemConf().addrLookupOrder(r, addr); order != hostLookupCgo { return r.goLookupPTR(ctx, addr, order, conf) } arpa, err := reverseaddr(addr) @@ -370,7 +375,7 @@ func (r *Resolver) lookupAddr(ctx context.Context, addr string) (name []string, } lines, err := queryDNS(ctx, arpa, "ptr") if err != nil { - return + return nil, handlePlan9DNSError(err, addr) } for _, line := range lines { f := getFields(line) diff --git a/src/net/lookup_test.go b/src/net/lookup_test.go index 0689c19c3c..57ac9a933a 100644 --- a/src/net/lookup_test.go +++ b/src/net/lookup_test.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !wasip1 - package net import ( @@ -1462,3 +1460,189 @@ func testLookupNoData(t *testing.T, prefix string) { return } } + +func TestLookupPortNotFound(t *testing.T) { + allResolvers(t, func(t *testing.T) { + _, err := LookupPort("udp", "_-unknown-service-") + var dnsErr *DNSError + if !errors.As(err, &dnsErr) || !dnsErr.IsNotFound { + t.Fatalf("unexpected error: %v", err) + } + }) +} + +// submissions service is only available through a tcp network, see: +// https://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.xhtml?search=submissions +var tcpOnlyService = func() string { + // plan9 does not have submissions service defined in the service database. + if runtime.GOOS == "plan9" { + return "https" + } + return "submissions" +}() + +func TestLookupPortDifferentNetwork(t *testing.T) { + allResolvers(t, func(t *testing.T) { + _, err := LookupPort("udp", tcpOnlyService) + var dnsErr *DNSError + if !errors.As(err, &dnsErr) || !dnsErr.IsNotFound { + t.Fatalf("unexpected error: %v", err) + } + }) +} + +func TestLookupPortEmptyNetworkString(t *testing.T) { + allResolvers(t, func(t *testing.T) { + _, err := LookupPort("", tcpOnlyService) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) +} + +func TestLookupPortIPNetworkString(t *testing.T) { + allResolvers(t, func(t *testing.T) { + _, err := LookupPort("ip", tcpOnlyService) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) +} + +func allResolvers(t *testing.T, f func(t *testing.T)) { + t.Run("default resolver", f) + t.Run("forced go resolver", func(t *testing.T) { + if fixup := forceGoDNS(); fixup != nil { + defer fixup() + f(t) + } + }) + t.Run("forced cgo resolver", func(t *testing.T) { + if fixup := forceCgoDNS(); fixup != nil { + defer fixup() + f(t) + } + }) +} + +func TestLookupNoSuchHost(t *testing.T) { + mustHaveExternalNetwork(t) + + const testNXDOMAIN = "invalid.invalid." + const testNODATA = "_ldap._tcp.google.com." + + tests := []struct { + name string + query func() error + }{ + { + name: "LookupCNAME NXDOMAIN", + query: func() error { + _, err := LookupCNAME(testNXDOMAIN) + return err + }, + }, + { + name: "LookupHost NXDOMAIN", + query: func() error { + _, err := LookupHost(testNXDOMAIN) + return err + }, + }, + { + name: "LookupHost NODATA", + query: func() error { + _, err := LookupHost(testNODATA) + return err + }, + }, + { + name: "LookupMX NXDOMAIN", + query: func() error { + _, err := LookupMX(testNXDOMAIN) + return err + }, + }, + { + name: "LookupMX NODATA", + query: func() error { + _, err := LookupMX(testNODATA) + return err + }, + }, + { + name: "LookupNS NXDOMAIN", + query: func() error { + _, err := LookupNS(testNXDOMAIN) + return err + }, + }, + { + name: "LookupNS NODATA", + query: func() error { + _, err := LookupNS(testNODATA) + return err + }, + }, + { + name: "LookupSRV NXDOMAIN", + query: func() error { + _, _, err := LookupSRV("unknown", "tcp", testNXDOMAIN) + return err + }, + }, + { + name: "LookupTXT NXDOMAIN", + query: func() error { + _, err := LookupTXT(testNXDOMAIN) + return err + }, + }, + { + name: "LookupTXT NODATA", + query: func() error { + _, err := LookupTXT(testNODATA) + return err + }, + }, + } + + for _, v := range tests { + t.Run(v.name, func(t *testing.T) { + allResolvers(t, func(t *testing.T) { + attempts := 0 + for { + err := v.query() + if err == nil { + t.Errorf("unexpected success") + return + } + if dnsErr, ok := err.(*DNSError); ok { + succeeded := true + if !dnsErr.IsNotFound { + succeeded = false + t.Log("IsNotFound is set to false") + } + if dnsErr.Err != errNoSuchHost.Error() { + succeeded = false + t.Logf("error message is not equal to: %v", errNoSuchHost.Error()) + } + if succeeded { + return + } + } + testenv.SkipFlakyNet(t) + if attempts < len(backoffDuration) { + dur := backoffDuration[attempts] + t.Logf("backoff %v after failure %v\n", dur, err) + time.Sleep(dur) + attempts++ + continue + } + t.Errorf("unexpected error: %v", err) + return + } + }) + }) + } +} diff --git a/src/net/lookup_unix.go b/src/net/lookup_unix.go index 56ae11e961..382a2d44bb 100644 --- a/src/net/lookup_unix.go +++ b/src/net/lookup_unix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build unix || wasip1 +//go:build unix || js || wasip1 package net @@ -10,7 +10,6 @@ import ( "context" "internal/bytealg" "sync" - "syscall" ) var onceReadProtocols sync.Once @@ -62,9 +61,6 @@ func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, } func (r *Resolver) lookupIP(ctx context.Context, network, host string) (addrs []IPAddr, err error) { - if r.preferGo() { - return r.goLookupIP(ctx, network, host) - } order, conf := systemConf().hostLookupOrder(r, host) if order == hostLookupCgo { return cgoLookupIP(ctx, network, host) @@ -123,27 +119,3 @@ func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error } return r.goLookupPTR(ctx, addr, order, conf) } - -// concurrentThreadsLimit returns the number of threads we permit to -// run concurrently doing DNS lookups via cgo. A DNS lookup may use a -// file descriptor so we limit this to less than the number of -// permitted open files. On some systems, notably Darwin, if -// getaddrinfo is unable to open a file descriptor it simply returns -// EAI_NONAME rather than a useful error. Limiting the number of -// concurrent getaddrinfo calls to less than the permitted number of -// file descriptors makes that error less likely. We don't bother to -// apply the same limit to DNS lookups run directly from Go, because -// there we will return a meaningful "too many open files" error. -func concurrentThreadsLimit() int { - var rlim syscall.Rlimit - if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rlim); err != nil { - return 500 - } - r := rlim.Cur - if r > 500 { - r = 500 - } else if r > 30 { - r -= 30 - } - return int(r) -} diff --git a/src/net/lookup_windows.go b/src/net/lookup_windows.go index 33d5ac5fb4..3048f3269b 100644 --- a/src/net/lookup_windows.go +++ b/src/net/lookup_windows.go @@ -20,13 +20,17 @@ import ( const cgoAvailable = true const ( + _DNS_ERROR_RCODE_NAME_ERROR = syscall.Errno(9003) + _DNS_INFO_NO_RECORDS = syscall.Errno(9501) + _WSAHOST_NOT_FOUND = syscall.Errno(11001) _WSATRY_AGAIN = syscall.Errno(11002) + _WSATYPE_NOT_FOUND = syscall.Errno(10109) ) func winError(call string, err error) error { switch err { - case _WSAHOST_NOT_FOUND: + case _WSAHOST_NOT_FOUND, _DNS_ERROR_RCODE_NAME_ERROR, _DNS_INFO_NO_RECORDS: return errNoSuchHost } return os.NewSyscallError(call, err) @@ -91,19 +95,11 @@ func (r *Resolver) lookupHost(ctx context.Context, name string) ([]string, error return addrs, nil } -// preferGoOverWindows reports whether the resolver should use the -// pure Go implementation rather than making win32 calls to ask the -// kernel for its answer. -func (r *Resolver) preferGoOverWindows() bool { - conf := systemConf() - order, _ := conf.hostLookupOrder(r, "") // name is unused - return order != hostLookupCgo -} - func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr, error) { - if r.preferGoOverWindows() { - return r.goLookupIP(ctx, network, name) + if order, conf := systemConf().hostLookupOrder(r, name); order != hostLookupCgo { + return r.goLookupIP(ctx, network, name, order, conf) } + // TODO(bradfitz,brainman): use ctx more. See TODO below. var family int32 = syscall.AF_UNSPEC @@ -200,37 +196,51 @@ func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr } func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int, error) { - if r.preferGoOverWindows() { + if systemConf().mustUseGoResolver(r) { return lookupPortMap(network, service) } // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() - var stype int32 + + var hints syscall.AddrinfoW + switch network { - case "tcp4", "tcp6": - stype = syscall.SOCK_STREAM - case "udp4", "udp6": - stype = syscall.SOCK_DGRAM + case "ip": // no hints + case "tcp", "tcp4", "tcp6": + hints.Socktype = syscall.SOCK_STREAM + hints.Protocol = syscall.IPPROTO_TCP + case "udp", "udp4", "udp6": + hints.Socktype = syscall.SOCK_DGRAM + hints.Protocol = syscall.IPPROTO_UDP + default: + return 0, &DNSError{Err: "unknown network", Name: network + "/" + service} } - hints := syscall.AddrinfoW{ - Family: syscall.AF_UNSPEC, - Socktype: stype, - Protocol: syscall.IPPROTO_IP, + + switch ipVersion(network) { + case '4': + hints.Family = syscall.AF_INET + case '6': + hints.Family = syscall.AF_INET6 } + var result *syscall.AddrinfoW e := syscall.GetAddrInfoW(nil, syscall.StringToUTF16Ptr(service), &hints, &result) if e != nil { if port, err := lookupPortMap(network, service); err == nil { return port, nil } - err := winError("getaddrinfow", e) - dnsError := &DNSError{Err: err.Error(), Name: network + "/" + service} - if err == errNoSuchHost { - dnsError.IsNotFound = true + + // The _WSATYPE_NOT_FOUND error is returned by GetAddrInfoW + // when the service name is unknown. We are also checking + // for _WSAHOST_NOT_FOUND here to match the cgo (unix) version + // cgo_unix.go (cgoLookupServicePort). + if e == _WSATYPE_NOT_FOUND || e == _WSAHOST_NOT_FOUND { + return 0, &DNSError{Err: "unknown port", Name: network + "/" + service, IsNotFound: true} } - return 0, dnsError + err := os.NewSyscallError("getaddrinfow", e) + return 0, &DNSError{Err: err.Error(), Name: network + "/" + service} } defer syscall.FreeAddrInfoW(result) if result == nil { @@ -249,7 +259,7 @@ func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int } func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error) { - if order, conf := systemConf().hostLookupOrder(r, ""); order != hostLookupCgo { + if order, conf := systemConf().hostLookupOrder(r, name); order != hostLookupCgo { return r.goLookupCNAME(ctx, name, order, conf) } @@ -264,7 +274,8 @@ func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error) return absDomainName(name), nil } if e != nil { - return "", &DNSError{Err: winError("dnsquery", e).Error(), Name: name} + err := winError("dnsquery", e) + return "", &DNSError{Err: err.Error(), Name: name, IsNotFound: err == errNoSuchHost} } defer syscall.DnsRecordListFree(rec, 1) @@ -274,7 +285,7 @@ func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error) } func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) { - if r.preferGoOverWindows() { + if systemConf().mustUseGoResolver(r) { return r.goLookupSRV(ctx, service, proto, name) } // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. @@ -289,7 +300,8 @@ func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) ( var rec *syscall.DNSRecord e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &rec, nil) if e != nil { - return "", nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: target} + err := winError("dnsquery", e) + return "", nil, &DNSError{Err: err.Error(), Name: name, IsNotFound: err == errNoSuchHost} } defer syscall.DnsRecordListFree(rec, 1) @@ -303,7 +315,7 @@ func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) ( } func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) { - if r.preferGoOverWindows() { + if systemConf().mustUseGoResolver(r) { return r.goLookupMX(ctx, name) } // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. @@ -312,7 +324,8 @@ func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) { var rec *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &rec, nil) if e != nil { - return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name} + err := winError("dnsquery", e) + return nil, &DNSError{Err: err.Error(), Name: name, IsNotFound: err == errNoSuchHost} } defer syscall.DnsRecordListFree(rec, 1) @@ -326,7 +339,7 @@ func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) { } func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) { - if r.preferGoOverWindows() { + if systemConf().mustUseGoResolver(r) { return r.goLookupNS(ctx, name) } // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. @@ -335,7 +348,8 @@ func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) { var rec *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &rec, nil) if e != nil { - return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name} + err := winError("dnsquery", e) + return nil, &DNSError{Err: err.Error(), Name: name, IsNotFound: err == errNoSuchHost} } defer syscall.DnsRecordListFree(rec, 1) @@ -348,7 +362,7 @@ func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) { } func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) { - if r.preferGoOverWindows() { + if systemConf().mustUseGoResolver(r) { return r.goLookupTXT(ctx, name) } // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. @@ -357,7 +371,8 @@ func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) var rec *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &rec, nil) if e != nil { - return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name} + err := winError("dnsquery", e) + return nil, &DNSError{Err: err.Error(), Name: name, IsNotFound: err == errNoSuchHost} } defer syscall.DnsRecordListFree(rec, 1) @@ -374,7 +389,7 @@ func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) } func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) { - if order, conf := systemConf().hostLookupOrder(r, ""); order != hostLookupCgo { + if order, conf := systemConf().addrLookupOrder(r, addr); order != hostLookupCgo { return r.goLookupPTR(ctx, addr, order, conf) } @@ -388,7 +403,8 @@ func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error var rec *syscall.DNSRecord e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &rec, nil) if e != nil { - return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: addr} + err := winError("dnsquery", e) + return nil, &DNSError{Err: err.Error(), Name: addr, IsNotFound: err == errNoSuchHost} } defer syscall.DnsRecordListFree(rec, 1) diff --git a/src/net/mail/message.go b/src/net/mail/message.go index af516fc30f..fc2a9e46f8 100644 --- a/src/net/mail/message.go +++ b/src/net/mail/message.go @@ -280,7 +280,7 @@ func (a *Address) String() string { // Add quotes if needed quoteLocal := false for i, r := range local { - if isAtext(r, false, false) { + if isAtext(r, false) { continue } if r == '.' { @@ -444,7 +444,7 @@ func (p *addrParser) parseAddress(handleGroup bool) ([]*Address, error) { if !p.consume('<') { atext := true for _, r := range displayName { - if !isAtext(r, true, false) { + if !isAtext(r, true) { atext = false break } @@ -479,7 +479,9 @@ func (p *addrParser) consumeGroupList() ([]*Address, error) { // handle empty group. p.skipSpace() if p.consume(';') { - p.skipCFWS() + if !p.skipCFWS() { + return nil, errors.New("mail: misformatted parenthetical comment") + } return group, nil } @@ -496,7 +498,9 @@ func (p *addrParser) consumeGroupList() ([]*Address, error) { return nil, errors.New("mail: misformatted parenthetical comment") } if p.consume(';') { - p.skipCFWS() + if !p.skipCFWS() { + return nil, errors.New("mail: misformatted parenthetical comment") + } break } if !p.consume(',') { @@ -566,6 +570,12 @@ func (p *addrParser) consumePhrase() (phrase string, err error) { var words []string var isPrevEncoded bool for { + // obs-phrase allows CFWS after one word + if len(words) > 0 { + if !p.skipCFWS() { + return "", errors.New("mail: misformatted parenthetical comment") + } + } // word = atom / quoted-string var word string p.skipSpace() @@ -661,7 +671,6 @@ Loop: // If dot is true, consumeAtom parses an RFC 5322 dot-atom instead. // If permissive is true, consumeAtom will not fail on: // - leading/trailing/double dots in the atom (see golang.org/issue/4938) -// - special characters (RFC 5322 3.2.3) except '<', '>', ':' and '"' (see golang.org/issue/21018) func (p *addrParser) consumeAtom(dot bool, permissive bool) (atom string, err error) { i := 0 @@ -672,7 +681,7 @@ Loop: case size == 1 && r == utf8.RuneError: return "", fmt.Errorf("mail: invalid utf-8 in address: %q", p.s) - case size == 0 || !isAtext(r, dot, permissive): + case size == 0 || !isAtext(r, dot): break Loop default: @@ -850,18 +859,13 @@ func (e charsetError) Error() string { // isAtext reports whether r is an RFC 5322 atext character. // If dot is true, period is included. -// If permissive is true, RFC 5322 3.2.3 specials is included, -// except '<', '>', ':' and '"'. -func isAtext(r rune, dot, permissive bool) bool { +func isAtext(r rune, dot bool) bool { switch r { case '.': return dot // RFC 5322 3.2.3. specials - case '(', ')', '[', ']', ';', '@', '\\', ',': - return permissive - - case '<', '>', '"', ':': + case '(', ')', '<', '>', '[', ']', ':', ';', '@', '\\', ',', '"': // RFC 5322 3.2.3. specials return false } return isVchar(r) diff --git a/src/net/mail/message_test.go b/src/net/mail/message_test.go index 1e1bb4092f..1f2f62afbf 100644 --- a/src/net/mail/message_test.go +++ b/src/net/mail/message_test.go @@ -385,8 +385,11 @@ func TestAddressParsingError(t *testing.T) { 13: {"group not closed: null@example.com", "expected comma"}, 14: {"group: first@example.com, second@example.com;", "group with multiple addresses"}, 15: {"john.doe", "missing '@' or angle-addr"}, - 16: {"john.doe@", "no angle-addr"}, + 16: {"john.doe@", "missing '@' or angle-addr"}, 17: {"John Doe@foo.bar", "no angle-addr"}, + 18: {" group: null@example.com; (asd", "misformatted parenthetical comment"}, + 19: {" group: ; (asd", "misformatted parenthetical comment"}, + 20: {`(John) Doe <jdoe@machine.example>`, "missing word in phrase:"}, } for i, tc := range mustErrTestCases { @@ -436,24 +439,19 @@ func TestAddressParsing(t *testing.T) { Address: "john.q.public@example.com", }}, }, - { - `"John (middle) Doe" <jdoe@machine.example>`, - []*Address{{ - Name: "John (middle) Doe", - Address: "jdoe@machine.example", - }}, - }, + // Comment in display name { `John (middle) Doe <jdoe@machine.example>`, []*Address{{ - Name: "John (middle) Doe", + Name: "John Doe", Address: "jdoe@machine.example", }}, }, + // Display name is quoted string, so comment is not a comment { - `John !@M@! Doe <jdoe@machine.example>`, + `"John (middle) Doe" <jdoe@machine.example>`, []*Address{{ - Name: "John !@M@! Doe", + Name: "John (middle) Doe", Address: "jdoe@machine.example", }}, }, @@ -788,6 +786,26 @@ func TestAddressParsing(t *testing.T) { }, }, }, + // Comment in group display name + { + `group (comment:): a@example.com, b@example.com;`, + []*Address{ + { + Address: "a@example.com", + }, + { + Address: "b@example.com", + }, + }, + }, + { + `x(:"):"@a.example;("@b.example;`, + []*Address{ + { + Address: `@a.example;(@b.example`, + }, + }, + }, } for _, test := range tests { if len(test.exp) == 1 { diff --git a/src/net/main_conf_test.go b/src/net/main_conf_test.go index 28a1cb8351..307ff5dd8c 100644 --- a/src/net/main_conf_test.go +++ b/src/net/main_conf_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !plan9 && !wasip1 +//go:build !plan9 package net diff --git a/src/net/main_noconf_test.go b/src/net/main_noconf_test.go index 077a36e5d6..cdd7c54805 100644 --- a/src/net/main_noconf_test.go +++ b/src/net/main_noconf_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build (js && wasm) || plan9 || wasip1 +//go:build plan9 package net diff --git a/src/net/main_posix_test.go b/src/net/main_posix_test.go index a7942ee327..24a2a55660 100644 --- a/src/net/main_posix_test.go +++ b/src/net/main_posix_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !plan9 && !wasip1 +//go:build !plan9 package net diff --git a/src/net/main_test.go b/src/net/main_test.go index 9fd5c88543..7dc1e3ee0d 100644 --- a/src/net/main_test.go +++ b/src/net/main_test.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !wasip1 - package net import ( @@ -16,6 +14,7 @@ import ( "strings" "sync" "testing" + "time" ) var ( @@ -61,6 +60,20 @@ func TestMain(m *testing.M) { os.Exit(st) } +// mustSetDeadline calls the bound method m to set a deadline on a Conn. +// If the call fails, mustSetDeadline skips t if the current GOOS is believed +// not to support deadlines, or fails the test otherwise. +func mustSetDeadline(t testing.TB, m func(time.Time) error, d time.Duration) { + err := m(time.Now().Add(d)) + if err != nil { + t.Helper() + if runtime.GOOS == "plan9" { + t.Skipf("skipping: %s does not support deadlines", runtime.GOOS) + } + t.Fatal(err) + } +} + type ipv6LinkLocalUnicastTest struct { network, address string nameLookup bool diff --git a/src/net/main_wasm_test.go b/src/net/main_wasm_test.go new file mode 100644 index 0000000000..b8196bb283 --- /dev/null +++ b/src/net/main_wasm_test.go @@ -0,0 +1,13 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build wasip1 || js + +package net + +func installTestHooks() {} + +func uninstallTestHooks() {} + +func forceCloseSockets() {} diff --git a/src/net/main_windows_test.go b/src/net/main_windows_test.go index 07f21b72eb..bc024c0bbd 100644 --- a/src/net/main_windows_test.go +++ b/src/net/main_windows_test.go @@ -8,7 +8,6 @@ import "internal/poll" var ( // Placeholders for saving original socket system calls. - origSocket = socketFunc origWSASocket = wsaSocketFunc origClosesocket = poll.CloseFunc origConnect = connectFunc @@ -18,7 +17,6 @@ var ( ) func installTestHooks() { - socketFunc = sw.Socket wsaSocketFunc = sw.WSASocket poll.CloseFunc = sw.Closesocket connectFunc = sw.Connect @@ -28,7 +26,6 @@ func installTestHooks() { } func uninstallTestHooks() { - socketFunc = origSocket wsaSocketFunc = origWSASocket poll.CloseFunc = origClosesocket connectFunc = origConnect diff --git a/src/net/mockserver_test.go b/src/net/mockserver_test.go index f86dd66a2d..46b2a57321 100644 --- a/src/net/mockserver_test.go +++ b/src/net/mockserver_test.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !wasip1 - package net import ( @@ -339,6 +337,7 @@ func newLocalPacketListener(t testing.TB, network string, lcOpt ...*ListenConfig return c } + t.Helper() switch network { case "udp": if supportsIPv4() { @@ -359,7 +358,6 @@ func newLocalPacketListener(t testing.TB, network string, lcOpt ...*ListenConfig return listenPacket(network, testUnixAddr(t)) } - t.Helper() t.Fatalf("%s is not supported", network) return nil } diff --git a/src/net/net.go b/src/net/net.go index 5cfc25ffca..c434c96bf8 100644 --- a/src/net/net.go +++ b/src/net/net.go @@ -8,8 +8,8 @@ TCP/IP, UDP, domain name resolution, and Unix domain sockets. Although the package provides access to low-level networking primitives, most clients will need only the basic interface provided -by the Dial, Listen, and Accept functions and the associated -Conn and Listener interfaces. The crypto/tls package uses +by the [Dial], [Listen], and Accept functions and the associated +[Conn] and [Listener] interfaces. The crypto/tls package uses the same interfaces and similar Dial and Listen functions. The Dial function connects to a server: @@ -39,7 +39,7 @@ The Listen function creates servers: # Name Resolution The method for resolving domain names, whether indirectly with functions like Dial -or directly with functions like LookupHost and LookupAddr, varies by operating system. +or directly with functions like [LookupHost] and [LookupAddr], varies by operating system. On Unix systems, the resolver has two options for resolving names. It can use a pure Go resolver that sends DNS requests directly to the servers @@ -95,8 +95,8 @@ import ( // Addr represents a network end point address. // -// The two methods Network and String conventionally return strings -// that can be passed as the arguments to Dial, but the exact form +// The two methods [Addr.Network] and [Addr.String] conventionally return strings +// that can be passed as the arguments to [Dial], but the exact form // and meaning of the strings is up to the implementation. type Addr interface { Network() string // name of the network (for example, "tcp", "udp") @@ -284,7 +284,7 @@ func (c *conn) SetWriteBuffer(bytes int) error { return nil } -// File returns a copy of the underlying os.File. +// File returns a copy of the underlying [os.File]. // It is the caller's responsibility to close f when finished. // Closing c does not affect f, and closing f does not affect c. // @@ -624,7 +624,11 @@ type DNSError struct { Server string // server used IsTimeout bool // if true, timed out; not all timeouts set this IsTemporary bool // if true, error is temporary; not all errors set this - IsNotFound bool // if true, host could not be found + + // IsNotFound is set to true when the requested name does not + // contain any records of the requested type (data not found), + // or the name itself was not found (NXDOMAIN). + IsNotFound bool } func (e *DNSError) Error() string { @@ -641,12 +645,12 @@ func (e *DNSError) Error() string { // Timeout reports whether the DNS lookup is known to have timed out. // This is not always known; a DNS lookup may fail due to a timeout -// and return a DNSError for which Timeout returns false. +// and return a [DNSError] for which Timeout returns false. func (e *DNSError) Timeout() bool { return e.IsTimeout } // Temporary reports whether the DNS error is known to be temporary. // This is not always known; a DNS lookup may fail due to a temporary -// error and return a DNSError for which Temporary returns false. +// error and return a [DNSError] for which Temporary returns false. func (e *DNSError) Temporary() bool { return e.IsTimeout || e.IsTemporary } // errClosed exists just so that the docs for ErrClosed don't mention @@ -660,15 +664,53 @@ var errClosed = poll.ErrNetClosing // errors.Is(err, net.ErrClosed). var ErrClosed error = errClosed -type writerOnly struct { - io.Writer +// noReadFrom can be embedded alongside another type to +// hide the ReadFrom method of that other type. +type noReadFrom struct{} + +// ReadFrom hides another ReadFrom method. +// It should never be called. +func (noReadFrom) ReadFrom(io.Reader) (int64, error) { + panic("can't happen") +} + +// tcpConnWithoutReadFrom implements all the methods of *TCPConn other +// than ReadFrom. This is used to permit ReadFrom to call io.Copy +// without leading to a recursive call to ReadFrom. +type tcpConnWithoutReadFrom struct { + noReadFrom + *TCPConn } // Fallback implementation of io.ReaderFrom's ReadFrom, when sendfile isn't // applicable. -func genericReadFrom(w io.Writer, r io.Reader) (n int64, err error) { +func genericReadFrom(c *TCPConn, r io.Reader) (n int64, err error) { // Use wrapper to hide existing r.ReadFrom from io.Copy. - return io.Copy(writerOnly{w}, r) + return io.Copy(tcpConnWithoutReadFrom{TCPConn: c}, r) +} + +// noWriteTo can be embedded alongside another type to +// hide the WriteTo method of that other type. +type noWriteTo struct{} + +// WriteTo hides another WriteTo method. +// It should never be called. +func (noWriteTo) WriteTo(io.Writer) (int64, error) { + panic("can't happen") +} + +// tcpConnWithoutWriteTo implements all the methods of *TCPConn other +// than WriteTo. This is used to permit WriteTo to call io.Copy +// without leading to a recursive call to WriteTo. +type tcpConnWithoutWriteTo struct { + noWriteTo + *TCPConn +} + +// Fallback implementation of io.WriterTo's WriteTo, when zero-copy isn't applicable. +func genericWriteTo(c *TCPConn, w io.Writer) (n int64, err error) { + // Use wrapper to hide existing w.WriteTo from io.Copy. + return io.Copy(w, tcpConnWithoutWriteTo{TCPConn: c}) } // Limit the number of concurrent cgo-using goroutines, because @@ -714,7 +756,7 @@ var ( // WriteTo writes contents of the buffers to w. // -// WriteTo implements io.WriterTo for Buffers. +// WriteTo implements [io.WriterTo] for [Buffers]. // // WriteTo modifies the slice v as well as v[i] for 0 <= i < len(v), // but does not modify v[i][j] for any i, j. @@ -736,7 +778,7 @@ func (v *Buffers) WriteTo(w io.Writer) (n int64, err error) { // Read from the buffers. // -// Read implements io.Reader for Buffers. +// Read implements [io.Reader] for [Buffers]. // // Read modifies the slice v as well as v[i] for 0 <= i < len(v), // but does not modify v[i][j] for any i, j. diff --git a/src/net/net_fake.go b/src/net/net_fake.go index 908767a1f6..525ff32296 100644 --- a/src/net/net_fake.go +++ b/src/net/net_fake.go @@ -2,405 +2,1169 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Fake networking for js/wasm and wasip1/wasm. It is intended to allow tests of other package to pass. +// Fake networking for js/wasm and wasip1/wasm. +// It is intended to allow tests of other package to pass. -//go:build (js && wasm) || wasip1 +//go:build js || wasip1 package net import ( "context" + "errors" "io" "os" + "runtime" "sync" + "sync/atomic" "syscall" "time" ) -var listenersMu sync.Mutex -var listeners = make(map[fakeNetAddr]*netFD) - -var portCounterMu sync.Mutex -var portCounter = 0 +var ( + sockets sync.Map // fakeSockAddr → *netFD + fakeSocketIDs sync.Map // fakeNetFD.id → *netFD + fakePorts sync.Map // int (port #) → *netFD + nextPortCounter atomic.Int32 +) -func nextPort() int { - portCounterMu.Lock() - defer portCounterMu.Unlock() - portCounter++ - return portCounter -} +const defaultBuffer = 65535 -type fakeNetAddr struct { - network string +type fakeSockAddr struct { + family int address string } -type fakeNetFD struct { - listener fakeNetAddr - r *bufferedPipe - w *bufferedPipe - incoming chan *netFD - closedMu sync.Mutex - closed bool +func fakeAddr(sa sockaddr) fakeSockAddr { + return fakeSockAddr{ + family: sa.family(), + address: sa.String(), + } } // socket returns a network file descriptor that is ready for -// asynchronous I/O using the network poller. +// I/O using the fake network. func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (*netFD, error) { - fd := &netFD{family: family, sotype: sotype, net: net} - if laddr != nil && raddr == nil { - return fakelistener(fd, laddr) + if raddr != nil && ctrlCtxFn != nil { + return nil, os.NewSyscallError("socket", syscall.ENOTSUP) } - fd2 := &netFD{family: family, sotype: sotype, net: net} - return fakeconn(fd, fd2, laddr, raddr) -} - -func fakeIPAndPort(ip IP, port int) (IP, int) { - if ip == nil { - ip = IPv4(127, 0, 0, 1) + switch sotype { + case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET, syscall.SOCK_DGRAM: + default: + return nil, os.NewSyscallError("socket", syscall.ENOTSUP) } - if port == 0 { - port = nextPort() + + fd := &netFD{ + family: family, + sotype: sotype, + net: net, } - return ip, port -} + fd.fakeNetFD = newFakeNetFD(fd) -func fakeTCPAddr(addr *TCPAddr) *TCPAddr { - var ip IP - var port int - var zone string - if addr != nil { - ip, port, zone = addr.IP, addr.Port, addr.Zone + if raddr == nil { + if err := fakeListen(fd, laddr); err != nil { + fd.Close() + return nil, err + } + return fd, nil } - ip, port = fakeIPAndPort(ip, port) - return &TCPAddr{IP: ip, Port: port, Zone: zone} -} -func fakeUDPAddr(addr *UDPAddr) *UDPAddr { - var ip IP - var port int - var zone string - if addr != nil { - ip, port, zone = addr.IP, addr.Port, addr.Zone + if err := fakeConnect(ctx, fd, laddr, raddr); err != nil { + fd.Close() + return nil, err } - ip, port = fakeIPAndPort(ip, port) - return &UDPAddr{IP: ip, Port: port, Zone: zone} + return fd, nil } -func fakeUnixAddr(sotype int, addr *UnixAddr) *UnixAddr { - var net, name string - if addr != nil { - name = addr.Name +func validateResolvedAddr(net string, family int, sa sockaddr) error { + validateIP := func(ip IP) error { + switch family { + case syscall.AF_INET: + if len(ip) != 4 { + return &AddrError{ + Err: "non-IPv4 address", + Addr: ip.String(), + } + } + case syscall.AF_INET6: + if len(ip) != 16 { + return &AddrError{ + Err: "non-IPv6 address", + Addr: ip.String(), + } + } + default: + panic("net: unexpected address family in validateResolvedAddr") + } + return nil } - switch sotype { - case syscall.SOCK_DGRAM: - net = "unixgram" - case syscall.SOCK_SEQPACKET: - net = "unixpacket" + + switch net { + case "tcp", "tcp4", "tcp6": + sa, ok := sa.(*TCPAddr) + if !ok { + return &AddrError{ + Err: "non-TCP address for " + net + " network", + Addr: sa.String(), + } + } + if err := validateIP(sa.IP); err != nil { + return err + } + if sa.Port <= 0 || sa.Port >= 1<<16 { + return &AddrError{ + Err: "port out of range", + Addr: sa.String(), + } + } + return nil + + case "udp", "udp4", "udp6": + sa, ok := sa.(*UDPAddr) + if !ok { + return &AddrError{ + Err: "non-UDP address for " + net + " network", + Addr: sa.String(), + } + } + if err := validateIP(sa.IP); err != nil { + return err + } + if sa.Port <= 0 || sa.Port >= 1<<16 { + return &AddrError{ + Err: "port out of range", + Addr: sa.String(), + } + } + return nil + + case "unix", "unixgram", "unixpacket": + sa, ok := sa.(*UnixAddr) + if !ok { + return &AddrError{ + Err: "non-Unix address for " + net + " network", + Addr: sa.String(), + } + } + if sa.Name != "" { + i := len(sa.Name) - 1 + for i > 0 && !os.IsPathSeparator(sa.Name[i]) { + i-- + } + for i > 0 && os.IsPathSeparator(sa.Name[i]) { + i-- + } + if i <= 0 { + return &AddrError{ + Err: "unix socket name missing path component", + Addr: sa.Name, + } + } + if _, err := os.Stat(sa.Name[:i+1]); err != nil { + return &AddrError{ + Err: err.Error(), + Addr: sa.Name, + } + } + } + return nil + default: - net = "unix" + return &AddrError{ + Err: syscall.EAFNOSUPPORT.Error(), + Addr: sa.String(), + } } - return &UnixAddr{Net: net, Name: name} } -func fakelistener(fd *netFD, laddr sockaddr) (*netFD, error) { - switch l := laddr.(type) { +func matchIPFamily(family int, addr sockaddr) sockaddr { + convertIP := func(ip IP) IP { + switch family { + case syscall.AF_INET: + return ip.To4() + case syscall.AF_INET6: + return ip.To16() + default: + return ip + } + } + + switch addr := addr.(type) { case *TCPAddr: - laddr = fakeTCPAddr(l) + ip := convertIP(addr.IP) + if ip == nil || len(ip) == len(addr.IP) { + return addr + } + return &TCPAddr{IP: ip, Port: addr.Port, Zone: addr.Zone} case *UDPAddr: - laddr = fakeUDPAddr(l) - case *UnixAddr: - if l.Name == "" { - return nil, syscall.ENOENT + ip := convertIP(addr.IP) + if ip == nil || len(ip) == len(addr.IP) { + return addr } - laddr = fakeUnixAddr(fd.sotype, l) + return &UDPAddr{IP: ip, Port: addr.Port, Zone: addr.Zone} default: - return nil, syscall.EOPNOTSUPP + return addr } +} - listener := fakeNetAddr{ - network: laddr.Network(), - address: laddr.String(), - } +type fakeNetFD struct { + fd *netFD + assignedPort int // 0 if no port has been assigned for this socket + + queue *packetQueue // incoming packets + peer *netFD // connected peer (for outgoing packets); nil for listeners and PacketConns + readDeadline atomic.Pointer[deadlineTimer] + writeDeadline atomic.Pointer[deadlineTimer] + + fakeAddr fakeSockAddr // cached fakeSockAddr equivalent of fd.laddr - fd.fakeNetFD = &fakeNetFD{ - listener: listener, - incoming: make(chan *netFD, 1024), + // The incoming channels hold incoming connections that have not yet been accepted. + // All of these channels are 1-buffered. + incoming chan []*netFD // holds the queue when it has >0 but <SOMAXCONN pending connections; closed when the Listener is closed + incomingFull chan []*netFD // holds the queue when it has SOMAXCONN pending connections + incomingEmpty chan bool // holds true when the incoming queue is empty +} + +func newFakeNetFD(fd *netFD) *fakeNetFD { + ffd := &fakeNetFD{fd: fd} + ffd.readDeadline.Store(newDeadlineTimer(noDeadline)) + ffd.writeDeadline.Store(newDeadlineTimer(noDeadline)) + return ffd +} + +func (ffd *fakeNetFD) Read(p []byte) (n int, err error) { + n, _, err = ffd.queue.recvfrom(ffd.readDeadline.Load(), p, false, nil) + return n, err +} + +func (ffd *fakeNetFD) Write(p []byte) (nn int, err error) { + peer := ffd.peer + if peer == nil { + if ffd.fd.raddr == nil { + return 0, os.NewSyscallError("write", syscall.ENOTCONN) + } + peeri, _ := sockets.Load(fakeAddr(ffd.fd.raddr.(sockaddr))) + if peeri == nil { + return 0, os.NewSyscallError("write", syscall.ECONNRESET) + } + peer = peeri.(*netFD) + if peer.queue == nil { + return 0, os.NewSyscallError("write", syscall.ECONNRESET) + } } - fd.laddr = laddr - listenersMu.Lock() - defer listenersMu.Unlock() - if _, exists := listeners[listener]; exists { - return nil, syscall.EADDRINUSE + if peer.fakeNetFD == nil { + return 0, os.NewSyscallError("write", syscall.EINVAL) } - listeners[listener] = fd - return fd, nil + return peer.queue.write(ffd.writeDeadline.Load(), p, ffd.fd.laddr.(sockaddr)) } -func fakeconn(fd *netFD, fd2 *netFD, laddr, raddr sockaddr) (*netFD, error) { - switch r := raddr.(type) { - case *TCPAddr: - r = fakeTCPAddr(r) - raddr = r - laddr = fakeTCPAddr(laddr.(*TCPAddr)) - case *UDPAddr: - r = fakeUDPAddr(r) - raddr = r - laddr = fakeUDPAddr(laddr.(*UDPAddr)) - case *UnixAddr: - r = fakeUnixAddr(fd.sotype, r) - raddr = r - laddr = &UnixAddr{Net: r.Net, Name: r.Name} - default: - return nil, syscall.EAFNOSUPPORT +func (ffd *fakeNetFD) Close() (err error) { + if ffd.fakeAddr != (fakeSockAddr{}) { + sockets.CompareAndDelete(ffd.fakeAddr, ffd.fd) } - fd.laddr = laddr - fd.raddr = raddr - fd.fakeNetFD = &fakeNetFD{ - r: newBufferedPipe(65536), - w: newBufferedPipe(65536), + if ffd.queue != nil { + if closeErr := ffd.queue.closeRead(); err == nil { + err = closeErr + } } - fd2.fakeNetFD = &fakeNetFD{ - r: fd.fakeNetFD.w, - w: fd.fakeNetFD.r, + if ffd.peer != nil { + if closeErr := ffd.peer.queue.closeWrite(); err == nil { + err = closeErr + } } + ffd.readDeadline.Load().Reset(noDeadline) + ffd.writeDeadline.Load().Reset(noDeadline) - fd2.laddr = fd.raddr - fd2.raddr = fd.laddr + if ffd.incoming != nil { + var ( + incoming []*netFD + ok bool + ) + select { + case _, ok = <-ffd.incomingEmpty: + case incoming, ok = <-ffd.incoming: + case incoming, ok = <-ffd.incomingFull: + } + if ok { + // Sends on ffd.incoming require a receive first. + // Since we successfully received, no other goroutine may + // send on it at this point, and we may safely close it. + close(ffd.incoming) - listener := fakeNetAddr{ - network: fd.raddr.Network(), - address: fd.raddr.String(), + for _, c := range incoming { + c.Close() + } + } } - listenersMu.Lock() - defer listenersMu.Unlock() - l, ok := listeners[listener] - if !ok { - return nil, syscall.ECONNREFUSED + + if ffd.assignedPort != 0 { + fakePorts.CompareAndDelete(ffd.assignedPort, ffd.fd) } - l.incoming <- fd2 - return fd, nil + + return err } -func (fd *fakeNetFD) Read(p []byte) (n int, err error) { - return fd.r.Read(p) +func (ffd *fakeNetFD) closeRead() error { + return ffd.queue.closeRead() } -func (fd *fakeNetFD) Write(p []byte) (nn int, err error) { - return fd.w.Write(p) +func (ffd *fakeNetFD) closeWrite() error { + if ffd.peer == nil { + return os.NewSyscallError("closeWrite", syscall.ENOTCONN) + } + return ffd.peer.queue.closeWrite() } -func (fd *fakeNetFD) Close() error { - fd.closedMu.Lock() - if fd.closed { - fd.closedMu.Unlock() - return nil +func (ffd *fakeNetFD) accept(laddr Addr) (*netFD, error) { + if ffd.incoming == nil { + return nil, os.NewSyscallError("accept", syscall.EINVAL) } - fd.closed = true - fd.closedMu.Unlock() - if fd.listener != (fakeNetAddr{}) { - listenersMu.Lock() - delete(listeners, fd.listener) - close(fd.incoming) - fd.listener = fakeNetAddr{} - listenersMu.Unlock() - return nil + var ( + incoming []*netFD + ok bool + ) + select { + case <-ffd.readDeadline.Load().expired: + return nil, os.ErrDeadlineExceeded + case incoming, ok = <-ffd.incoming: + if !ok { + return nil, ErrClosed + } + case incoming, ok = <-ffd.incomingFull: } - fd.r.Close() - fd.w.Close() - return nil + peer := incoming[0] + incoming = incoming[1:] + if len(incoming) == 0 { + ffd.incomingEmpty <- true + } else { + ffd.incoming <- incoming + } + return peer, nil } -func (fd *fakeNetFD) closeRead() error { - fd.r.Close() +func (ffd *fakeNetFD) SetDeadline(t time.Time) error { + err1 := ffd.SetReadDeadline(t) + err2 := ffd.SetWriteDeadline(t) + if err1 != nil { + return err1 + } + return err2 +} + +func (ffd *fakeNetFD) SetReadDeadline(t time.Time) error { + dt := ffd.readDeadline.Load() + if !dt.Reset(t) { + ffd.readDeadline.Store(newDeadlineTimer(t)) + } return nil } -func (fd *fakeNetFD) closeWrite() error { - fd.w.Close() +func (ffd *fakeNetFD) SetWriteDeadline(t time.Time) error { + dt := ffd.writeDeadline.Load() + if !dt.Reset(t) { + ffd.writeDeadline.Store(newDeadlineTimer(t)) + } return nil } -func (fd *fakeNetFD) accept() (*netFD, error) { - c, ok := <-fd.incoming - if !ok { - return nil, syscall.EINVAL +const maxPacketSize = 65535 + +type packet struct { + buf []byte + bufOffset int + next *packet + from sockaddr +} + +func (p *packet) clear() { + p.buf = p.buf[:0] + p.bufOffset = 0 + p.next = nil + p.from = nil +} + +var packetPool = sync.Pool{ + New: func() any { return new(packet) }, +} + +type packetQueueState struct { + head, tail *packet // unqueued packets + nBytes int // number of bytes enqueued in the packet buffers starting from head + readBufferBytes int // soft limit on nbytes; no more packets may be enqueued when the limit is exceeded + readClosed bool // true if the reader of the queue has stopped reading + writeClosed bool // true if the writer of the queue has stopped writing; the reader sees either io.EOF or syscall.ECONNRESET when they have read all buffered packets + noLinger bool // if true, the reader sees ECONNRESET instead of EOF +} + +// A packetQueue is a set of 1-buffered channels implementing a FIFO queue +// of packets. +type packetQueue struct { + empty chan packetQueueState // contains configuration parameters when the queue is empty and not closed + ready chan packetQueueState // contains the packets when non-empty or closed + full chan packetQueueState // contains the packets when buffer is full and not closed +} + +func newPacketQueue(readBufferBytes int) *packetQueue { + pq := &packetQueue{ + empty: make(chan packetQueueState, 1), + ready: make(chan packetQueueState, 1), + full: make(chan packetQueueState, 1), + } + pq.put(packetQueueState{ + readBufferBytes: readBufferBytes, + }) + return pq +} + +func (pq *packetQueue) get() packetQueueState { + var q packetQueueState + select { + case q = <-pq.empty: + case q = <-pq.ready: + case q = <-pq.full: + } + return q +} + +func (pq *packetQueue) put(q packetQueueState) { + switch { + case q.readClosed || q.writeClosed: + pq.ready <- q + case q.nBytes >= q.readBufferBytes: + pq.full <- q + case q.head == nil: + if q.nBytes > 0 { + defer panic("net: put with nil packet list and nonzero nBytes") + } + pq.empty <- q + default: + pq.ready <- q } - return c, nil } -func (fd *fakeNetFD) SetDeadline(t time.Time) error { - fd.r.SetReadDeadline(t) - fd.w.SetWriteDeadline(t) +func (pq *packetQueue) closeRead() error { + q := pq.get() + + // Discard any unread packets. + for q.head != nil { + p := q.head + q.head = p.next + p.clear() + packetPool.Put(p) + } + q.nBytes = 0 + + q.readClosed = true + pq.put(q) return nil } -func (fd *fakeNetFD) SetReadDeadline(t time.Time) error { - fd.r.SetReadDeadline(t) +func (pq *packetQueue) closeWrite() error { + q := pq.get() + q.writeClosed = true + pq.put(q) return nil } -func (fd *fakeNetFD) SetWriteDeadline(t time.Time) error { - fd.w.SetWriteDeadline(t) +func (pq *packetQueue) setLinger(linger bool) error { + q := pq.get() + defer func() { pq.put(q) }() + + if q.writeClosed { + return ErrClosed + } + q.noLinger = !linger return nil } -func newBufferedPipe(softLimit int) *bufferedPipe { - p := &bufferedPipe{softLimit: softLimit} - p.rCond.L = &p.mu - p.wCond.L = &p.mu - return p +func (pq *packetQueue) write(dt *deadlineTimer, b []byte, from sockaddr) (n int, err error) { + for { + dn := len(b) + if dn > maxPacketSize { + dn = maxPacketSize + } + + dn, err = pq.send(dt, b[:dn], from, true) + n += dn + if err != nil { + return n, err + } + + b = b[dn:] + if len(b) == 0 { + return n, nil + } + } } -type bufferedPipe struct { - softLimit int - mu sync.Mutex - buf []byte - closed bool - rCond sync.Cond - wCond sync.Cond - rDeadline time.Time - wDeadline time.Time +func (pq *packetQueue) send(dt *deadlineTimer, b []byte, from sockaddr, block bool) (n int, err error) { + if from == nil { + return 0, os.NewSyscallError("send", syscall.EINVAL) + } + if len(b) > maxPacketSize { + return 0, os.NewSyscallError("send", syscall.EMSGSIZE) + } + + var q packetQueueState + var full chan packetQueueState + if !block { + full = pq.full + } + + // Before we check dt.expired, yield to other goroutines. + // This may help to prevent starvation of the goroutine that runs the + // deadlineTimer's time.After callback. + // + // TODO(#65178): Remove this when the runtime scheduler no longer starves + // runnable goroutines. + runtime.Gosched() + + select { + case <-dt.expired: + return 0, os.ErrDeadlineExceeded + + case q = <-full: + pq.put(q) + return 0, os.NewSyscallError("send", syscall.ENOBUFS) + + case q = <-pq.empty: + case q = <-pq.ready: + } + defer func() { pq.put(q) }() + + // Don't allow a packet to be sent if the deadline has expired, + // even if the select above chose a different branch. + select { + case <-dt.expired: + return 0, os.ErrDeadlineExceeded + default: + } + if q.writeClosed { + return 0, ErrClosed + } else if q.readClosed { + return 0, os.NewSyscallError("send", syscall.ECONNRESET) + } + + p := packetPool.Get().(*packet) + p.buf = append(p.buf[:0], b...) + p.from = from + + if q.head == nil { + q.head = p + } else { + q.tail.next = p + } + q.tail = p + q.nBytes += len(p.buf) + + return len(b), nil } -func (p *bufferedPipe) Read(b []byte) (int, error) { - p.mu.Lock() - defer p.mu.Unlock() +func (pq *packetQueue) recvfrom(dt *deadlineTimer, b []byte, wholePacket bool, checkFrom func(sockaddr) error) (n int, from sockaddr, err error) { + var q packetQueueState + var empty chan packetQueueState + if len(b) == 0 { + // For consistency with the implementation on Unix platforms, + // allow a zero-length Read to proceed if the queue is empty. + // (Without this, TestZeroByteRead deadlocks.) + empty = pq.empty + } - for { - if p.closed && len(p.buf) == 0 { - return 0, io.EOF - } - if !p.rDeadline.IsZero() { - d := time.Until(p.rDeadline) - if d <= 0 { - return 0, os.ErrDeadlineExceeded + // Before we check dt.expired, yield to other goroutines. + // This may help to prevent starvation of the goroutine that runs the + // deadlineTimer's time.After callback. + // + // TODO(#65178): Remove this when the runtime scheduler no longer starves + // runnable goroutines. + runtime.Gosched() + + select { + case <-dt.expired: + return 0, nil, os.ErrDeadlineExceeded + case q = <-empty: + case q = <-pq.ready: + case q = <-pq.full: + } + defer func() { pq.put(q) }() + + p := q.head + if p == nil { + switch { + case q.readClosed: + return 0, nil, ErrClosed + case q.writeClosed: + if q.noLinger { + return 0, nil, os.NewSyscallError("recvfrom", syscall.ECONNRESET) } - time.AfterFunc(d, p.rCond.Broadcast) + return 0, nil, io.EOF + case len(b) == 0: + return 0, nil, nil + default: + // This should be impossible: pq.full should only contain a non-empty list, + // pq.ready should either contain a non-empty list or indicate that the + // connection is closed, and we should only receive from pq.empty if + // len(b) == 0. + panic("net: nil packet list from non-closed packetQueue") } - if len(p.buf) > 0 { - break + } + + select { + case <-dt.expired: + return 0, nil, os.ErrDeadlineExceeded + default: + } + + if checkFrom != nil { + if err := checkFrom(p.from); err != nil { + return 0, nil, err } - p.rCond.Wait() } - n := copy(b, p.buf) - p.buf = p.buf[n:] - p.wCond.Broadcast() - return n, nil + n = copy(b, p.buf[p.bufOffset:]) + from = p.from + if wholePacket || p.bufOffset+n == len(p.buf) { + q.head = p.next + q.nBytes -= len(p.buf) + p.clear() + packetPool.Put(p) + } else { + p.bufOffset += n + } + + return n, from, nil +} + +// setReadBuffer sets a soft limit on the number of bytes available to read +// from the pipe. +func (pq *packetQueue) setReadBuffer(bytes int) error { + if bytes <= 0 { + return os.NewSyscallError("setReadBuffer", syscall.EINVAL) + } + q := pq.get() // Use the queue as a lock. + q.readBufferBytes = bytes + pq.put(q) + return nil +} + +type deadlineTimer struct { + timer chan *time.Timer + expired chan struct{} } -func (p *bufferedPipe) Write(b []byte) (int, error) { - p.mu.Lock() - defer p.mu.Unlock() +func newDeadlineTimer(deadline time.Time) *deadlineTimer { + dt := &deadlineTimer{ + timer: make(chan *time.Timer, 1), + expired: make(chan struct{}), + } + dt.timer <- nil + dt.Reset(deadline) + return dt +} - for { - if p.closed { - return 0, syscall.ENOTCONN +// Reset attempts to reset the timer. +// If the timer has already expired, Reset returns false. +func (dt *deadlineTimer) Reset(deadline time.Time) bool { + timer := <-dt.timer + defer func() { dt.timer <- timer }() + + if deadline.Equal(noDeadline) { + if timer != nil && timer.Stop() { + timer = nil } - if !p.wDeadline.IsZero() { - d := time.Until(p.wDeadline) - if d <= 0 { - return 0, os.ErrDeadlineExceeded + return timer == nil + } + + d := time.Until(deadline) + if d < 0 { + // Ensure that a deadline in the past takes effect immediately. + defer func() { <-dt.expired }() + } + + if timer == nil { + timer = time.AfterFunc(d, func() { close(dt.expired) }) + return true + } + if !timer.Stop() { + return false + } + timer.Reset(d) + return true +} + +func sysSocket(family, sotype, proto int) (int, error) { + return 0, os.NewSyscallError("sysSocket", syscall.ENOSYS) +} + +func fakeListen(fd *netFD, laddr sockaddr) (err error) { + wrapErr := func(err error) error { + if errno, ok := err.(syscall.Errno); ok { + err = os.NewSyscallError("listen", errno) + } + if errors.Is(err, syscall.EADDRINUSE) { + return err + } + if laddr != nil { + if _, ok := err.(*AddrError); !ok { + err = &AddrError{ + Err: err.Error(), + Addr: laddr.String(), + } } - time.AfterFunc(d, p.wCond.Broadcast) } - if len(p.buf) <= p.softLimit { - break + return err + } + + ffd := newFakeNetFD(fd) + defer func() { + if fd.fakeNetFD != ffd { + // Failed to register listener; clean up. + ffd.Close() } - p.wCond.Wait() + }() + + if err := ffd.assignFakeAddr(matchIPFamily(fd.family, laddr)); err != nil { + return wrapErr(err) } - p.buf = append(p.buf, b...) - p.rCond.Broadcast() - return len(b), nil -} + ffd.fakeAddr = fakeAddr(fd.laddr.(sockaddr)) + switch fd.sotype { + case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET: + ffd.incoming = make(chan []*netFD, 1) + ffd.incomingFull = make(chan []*netFD, 1) + ffd.incomingEmpty = make(chan bool, 1) + ffd.incomingEmpty <- true + case syscall.SOCK_DGRAM: + ffd.queue = newPacketQueue(defaultBuffer) + default: + return wrapErr(syscall.EINVAL) + } -func (p *bufferedPipe) Close() { - p.mu.Lock() - defer p.mu.Unlock() + fd.fakeNetFD = ffd + if _, dup := sockets.LoadOrStore(ffd.fakeAddr, fd); dup { + fd.fakeNetFD = nil + return wrapErr(syscall.EADDRINUSE) + } - p.closed = true - p.rCond.Broadcast() - p.wCond.Broadcast() + return nil } -func (p *bufferedPipe) SetReadDeadline(t time.Time) { - p.mu.Lock() - defer p.mu.Unlock() +func fakeConnect(ctx context.Context, fd *netFD, laddr, raddr sockaddr) error { + wrapErr := func(err error) error { + if errno, ok := err.(syscall.Errno); ok { + err = os.NewSyscallError("connect", errno) + } + if errors.Is(err, syscall.EADDRINUSE) { + return err + } + if terr, ok := err.(interface{ Timeout() bool }); !ok || !terr.Timeout() { + // For consistency with the net implementation on other platforms, + // if we don't need to preserve the Timeout-ness of err we should + // wrap it in an AddrError. (Unfortunately we can't wrap errors + // that convey structured information, because AddrError reduces + // the wrapped Err to a flat string.) + if _, ok := err.(*AddrError); !ok { + err = &AddrError{ + Err: err.Error(), + Addr: raddr.String(), + } + } + } + return err + } + + if fd.isConnected { + return wrapErr(syscall.EISCONN) + } + if ctx.Err() != nil { + return wrapErr(syscall.ETIMEDOUT) + } + + fd.raddr = matchIPFamily(fd.family, raddr) + if err := validateResolvedAddr(fd.net, fd.family, fd.raddr.(sockaddr)); err != nil { + return wrapErr(err) + } + + if err := fd.fakeNetFD.assignFakeAddr(laddr); err != nil { + return wrapErr(err) + } + fd.fakeNetFD.queue = newPacketQueue(defaultBuffer) + + switch fd.sotype { + case syscall.SOCK_DGRAM: + if ua, ok := fd.laddr.(*UnixAddr); !ok || ua.Name != "" { + fd.fakeNetFD.fakeAddr = fakeAddr(fd.laddr.(sockaddr)) + if _, dup := sockets.LoadOrStore(fd.fakeNetFD.fakeAddr, fd); dup { + return wrapErr(syscall.EADDRINUSE) + } + } + fd.isConnected = true + return nil + + case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET: + default: + return wrapErr(syscall.EINVAL) + } + + fa := fakeAddr(raddr) + lni, ok := sockets.Load(fa) + if !ok { + return wrapErr(syscall.ECONNREFUSED) + } + ln := lni.(*netFD) + if ln.sotype != fd.sotype { + return wrapErr(syscall.EPROTOTYPE) + } + if ln.incoming == nil { + return wrapErr(syscall.ECONNREFUSED) + } + + peer := &netFD{ + family: ln.family, + sotype: ln.sotype, + net: ln.net, + laddr: ln.laddr, + raddr: fd.laddr, + isConnected: true, + } + peer.fakeNetFD = newFakeNetFD(fd) + peer.fakeNetFD.queue = newPacketQueue(defaultBuffer) + defer func() { + if fd.peer != peer { + // Failed to connect; clean up. + peer.Close() + } + }() + + var incoming []*netFD + select { + case <-ctx.Done(): + return wrapErr(syscall.ETIMEDOUT) + case ok = <-ln.incomingEmpty: + case incoming, ok = <-ln.incoming: + } + if !ok { + return wrapErr(syscall.ECONNREFUSED) + } + + fd.isConnected = true + fd.peer = peer + peer.peer = fd - p.rDeadline = t - p.rCond.Broadcast() + incoming = append(incoming, peer) + if len(incoming) >= listenerBacklog() { + ln.incomingFull <- incoming + } else { + ln.incoming <- incoming + } + return nil } -func (p *bufferedPipe) SetWriteDeadline(t time.Time) { - p.mu.Lock() - defer p.mu.Unlock() +func (ffd *fakeNetFD) assignFakeAddr(addr sockaddr) error { + validate := func(sa sockaddr) error { + if err := validateResolvedAddr(ffd.fd.net, ffd.fd.family, sa); err != nil { + return err + } + ffd.fd.laddr = sa + return nil + } + + assignIP := func(addr sockaddr) error { + var ( + ip IP + port int + zone string + ) + switch addr := addr.(type) { + case *TCPAddr: + if addr != nil { + ip = addr.IP + port = addr.Port + zone = addr.Zone + } + case *UDPAddr: + if addr != nil { + ip = addr.IP + port = addr.Port + zone = addr.Zone + } + default: + return validate(addr) + } + + if ip == nil { + ip = IPv4(127, 0, 0, 1) + } + switch ffd.fd.family { + case syscall.AF_INET: + if ip4 := ip.To4(); ip4 != nil { + ip = ip4 + } + case syscall.AF_INET6: + if ip16 := ip.To16(); ip16 != nil { + ip = ip16 + } + } + if ip == nil { + return syscall.EINVAL + } + + if port == 0 { + var prevPort int32 + portWrapped := false + nextPort := func() (int, bool) { + for { + port := nextPortCounter.Add(1) + if port <= 0 || port >= 1<<16 { + // nextPortCounter ran off the end of the port space. + // Bump it back into range. + for { + if nextPortCounter.CompareAndSwap(port, 0) { + break + } + if port = nextPortCounter.Load(); port >= 0 && port < 1<<16 { + break + } + } + if portWrapped { + // This is the second wraparound, so we've scanned the whole port space + // at least once already and it's time to give up. + return 0, false + } + portWrapped = true + prevPort = 0 + continue + } + + if port <= prevPort { + // nextPortCounter has wrapped around since the last time we read it. + if portWrapped { + // This is the second wraparound, so we've scanned the whole port space + // at least once already and it's time to give up. + return 0, false + } else { + portWrapped = true + } + } + + prevPort = port + return int(port), true + } + } + + for { + var ok bool + port, ok = nextPort() + if !ok { + ffd.assignedPort = 0 + return syscall.EADDRINUSE + } + + ffd.assignedPort = int(port) + if _, dup := fakePorts.LoadOrStore(ffd.assignedPort, ffd.fd); !dup { + break + } + } + } + + switch addr.(type) { + case *TCPAddr: + return validate(&TCPAddr{IP: ip, Port: port, Zone: zone}) + case *UDPAddr: + return validate(&UDPAddr{IP: ip, Port: port, Zone: zone}) + default: + panic("unreachable") + } + } + + switch ffd.fd.net { + case "tcp", "tcp4", "tcp6": + if addr == nil { + return assignIP(new(TCPAddr)) + } + return assignIP(addr) + + case "udp", "udp4", "udp6": + if addr == nil { + return assignIP(new(UDPAddr)) + } + return assignIP(addr) + + case "unix", "unixgram", "unixpacket": + uaddr, ok := addr.(*UnixAddr) + if !ok && addr != nil { + return &AddrError{ + Err: "non-Unix address for " + ffd.fd.net + " network", + Addr: addr.String(), + } + } + if uaddr == nil { + return validate(&UnixAddr{Net: ffd.fd.net}) + } + return validate(&UnixAddr{Net: ffd.fd.net, Name: uaddr.Name}) - p.wDeadline = t - p.wCond.Broadcast() + default: + return &AddrError{ + Err: syscall.EAFNOSUPPORT.Error(), + Addr: addr.String(), + } + } } -func sysSocket(family, sotype, proto int) (int, error) { - return 0, syscall.ENOSYS +func (ffd *fakeNetFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) { + if ffd.queue == nil { + return 0, nil, os.NewSyscallError("readFrom", syscall.EINVAL) + } + + n, from, err := ffd.queue.recvfrom(ffd.readDeadline.Load(), p, true, nil) + + if from != nil { + // Convert the net.sockaddr to a syscall.Sockaddr type. + var saErr error + sa, saErr = from.sockaddr(ffd.fd.family) + if err == nil { + err = saErr + } + } + + return n, sa, err } -func (fd *fakeNetFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (syscall.Sockaddr, error) { - return nil, syscall.ENOSYS +func (ffd *fakeNetFD) readFromInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) { + n, _, err = ffd.queue.recvfrom(ffd.readDeadline.Load(), p, true, func(from sockaddr) error { + fromSA, err := from.sockaddr(syscall.AF_INET) + if err != nil { + return err + } + if fromSA == nil { + return os.NewSyscallError("readFromInet4", syscall.EINVAL) + } + *sa = *(fromSA.(*syscall.SockaddrInet4)) + return nil + }) + return n, err } -func (fd *fakeNetFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) { - return 0, nil, syscall.ENOSYS +func (ffd *fakeNetFD) readFromInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) { + n, _, err = ffd.queue.recvfrom(ffd.readDeadline.Load(), p, true, func(from sockaddr) error { + fromSA, err := from.sockaddr(syscall.AF_INET6) + if err != nil { + return err + } + if fromSA == nil { + return os.NewSyscallError("readFromInet6", syscall.EINVAL) + } + *sa = *(fromSA.(*syscall.SockaddrInet6)) + return nil + }) + return n, err +} +func (ffd *fakeNetFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) { + if flags != 0 { + return 0, 0, 0, nil, os.NewSyscallError("readMsg", syscall.ENOTSUP) + } + n, sa, err = ffd.readFrom(p) + return n, 0, 0, sa, err } -func (fd *fakeNetFD) readFromInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) { - return 0, syscall.ENOSYS + +func (ffd *fakeNetFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet4) (n, oobn, retflags int, err error) { + if flags != 0 { + return 0, 0, 0, os.NewSyscallError("readMsgInet4", syscall.ENOTSUP) + } + n, err = ffd.readFromInet4(p, sa) + return n, 0, 0, err } -func (fd *fakeNetFD) readFromInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) { - return 0, syscall.ENOSYS +func (ffd *fakeNetFD) readMsgInet6(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet6) (n, oobn, retflags int, err error) { + if flags != 0 { + return 0, 0, 0, os.NewSyscallError("readMsgInet6", syscall.ENOTSUP) + } + n, err = ffd.readFromInet6(p, sa) + return n, 0, 0, err } -func (fd *fakeNetFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) { - return 0, 0, 0, nil, syscall.ENOSYS +func (ffd *fakeNetFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) { + if len(oob) > 0 { + return 0, 0, os.NewSyscallError("writeMsg", syscall.ENOTSUP) + } + n, err = ffd.writeTo(p, sa) + return n, 0, err } -func (fd *fakeNetFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet4) (n, oobn, retflags int, err error) { - return 0, 0, 0, syscall.ENOSYS +func (ffd *fakeNetFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (n int, oobn int, err error) { + return ffd.writeMsg(p, oob, sa) } -func (fd *fakeNetFD) readMsgInet6(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet6) (n, oobn, retflags int, err error) { - return 0, 0, 0, syscall.ENOSYS +func (ffd *fakeNetFD) writeMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (n int, oobn int, err error) { + return ffd.writeMsg(p, oob, sa) } -func (fd *fakeNetFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (n int, oobn int, err error) { - return 0, 0, syscall.ENOSYS +func (ffd *fakeNetFD) writeTo(p []byte, sa syscall.Sockaddr) (n int, err error) { + raddr := ffd.fd.raddr + if sa != nil { + if ffd.fd.isConnected { + return 0, os.NewSyscallError("writeTo", syscall.EISCONN) + } + raddr = ffd.fd.addrFunc()(sa) + } + if raddr == nil { + return 0, os.NewSyscallError("writeTo", syscall.EINVAL) + } + + peeri, _ := sockets.Load(fakeAddr(raddr.(sockaddr))) + if peeri == nil { + if len(ffd.fd.net) >= 3 && ffd.fd.net[:3] == "udp" { + return len(p), nil + } + return 0, os.NewSyscallError("writeTo", syscall.ECONNRESET) + } + peer := peeri.(*netFD) + if peer.queue == nil { + if len(ffd.fd.net) >= 3 && ffd.fd.net[:3] == "udp" { + return len(p), nil + } + return 0, os.NewSyscallError("writeTo", syscall.ECONNRESET) + } + + block := true + if len(ffd.fd.net) >= 3 && ffd.fd.net[:3] == "udp" { + block = false + } + return peer.queue.send(ffd.writeDeadline.Load(), p, ffd.fd.laddr.(sockaddr), block) } -func (fd *fakeNetFD) writeMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (n int, oobn int, err error) { - return 0, 0, syscall.ENOSYS +func (ffd *fakeNetFD) writeToInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) { + return ffd.writeTo(p, sa) } -func (fd *fakeNetFD) writeTo(p []byte, sa syscall.Sockaddr) (n int, err error) { - return 0, syscall.ENOSYS +func (ffd *fakeNetFD) writeToInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) { + return ffd.writeTo(p, sa) } -func (fd *fakeNetFD) writeToInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) { - return 0, syscall.ENOSYS +func (ffd *fakeNetFD) dup() (f *os.File, err error) { + return nil, os.NewSyscallError("dup", syscall.ENOSYS) } -func (fd *fakeNetFD) writeToInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) { - return 0, syscall.ENOSYS +func (ffd *fakeNetFD) setReadBuffer(bytes int) error { + if ffd.queue == nil { + return os.NewSyscallError("setReadBuffer", syscall.EINVAL) + } + ffd.queue.setReadBuffer(bytes) + return nil } -func (fd *fakeNetFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) { - return 0, 0, syscall.ENOSYS +func (ffd *fakeNetFD) setWriteBuffer(bytes int) error { + return os.NewSyscallError("setWriteBuffer", syscall.ENOTSUP) } -func (fd *fakeNetFD) dup() (f *os.File, err error) { - return nil, syscall.ENOSYS +func (ffd *fakeNetFD) setLinger(sec int) error { + if sec < 0 || ffd.peer == nil { + return os.NewSyscallError("setLinger", syscall.EINVAL) + } + ffd.peer.queue.setLinger(sec > 0) + return nil } diff --git a/src/net/net_fake_js.go b/src/net/net_fake_js.go deleted file mode 100644 index 7ba108b664..0000000000 --- a/src/net/net_fake_js.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2023 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Fake networking for js/wasm. It is intended to allow tests of other package to pass. - -//go:build js && wasm - -package net - -import ( - "context" - "internal/poll" - - "golang.org/x/net/dns/dnsmessage" -) - -// Network file descriptor. -type netFD struct { - *fakeNetFD - - // immutable until Close - family int - sotype int - net string - laddr Addr - raddr Addr - - // unused - pfd poll.FD - isConnected bool // handshake completed or use of association with peer -} - -func (r *Resolver) lookup(ctx context.Context, name string, qtype dnsmessage.Type, conf *dnsConfig) (dnsmessage.Parser, string, error) { - panic("unreachable") -} diff --git a/src/net/net_fake_test.go b/src/net/net_fake_test.go index 783304d531..4542228fbc 100644 --- a/src/net/net_fake_test.go +++ b/src/net/net_fake_test.go @@ -13,191 +13,95 @@ package net // The tests in this files are intended to validate the behavior of the fake // network stack on these platforms. -import "testing" - -func TestFakeConn(t *testing.T) { - tests := []struct { - name string - listen func() (Listener, error) - dial func(Addr) (Conn, error) - addr func(*testing.T, Addr) - }{ - { - name: "Listener:tcp", - listen: func() (Listener, error) { - return Listen("tcp", ":0") - }, - dial: func(addr Addr) (Conn, error) { - return Dial(addr.Network(), addr.String()) - }, - addr: testFakeTCPAddr, - }, - - { - name: "ListenTCP:tcp", - listen: func() (Listener, error) { - // Creating a listening TCP connection with a nil address must - // select an IP address on localhost with a random port. - // This test verifies that the fake network facility does that. - return ListenTCP("tcp", nil) - }, - dial: func(addr Addr) (Conn, error) { - // Connecting a listening TCP connection will select a local - // address on the local network and connects to the destination - // address. - return DialTCP("tcp", nil, addr.(*TCPAddr)) - }, - addr: testFakeTCPAddr, - }, - - { - name: "ListenUnix:unix", - listen: func() (Listener, error) { - return ListenUnix("unix", &UnixAddr{Name: "test"}) - }, - dial: func(addr Addr) (Conn, error) { - return DialUnix("unix", nil, addr.(*UnixAddr)) - }, - addr: testFakeUnixAddr("unix", "test"), - }, - - { - name: "ListenUnix:unixpacket", - listen: func() (Listener, error) { - return ListenUnix("unixpacket", &UnixAddr{Name: "test"}) - }, - dial: func(addr Addr) (Conn, error) { - return DialUnix("unixpacket", nil, addr.(*UnixAddr)) - }, - addr: testFakeUnixAddr("unixpacket", "test"), - }, +import ( + "errors" + "syscall" + "testing" +) + +func TestFakePortExhaustion(t *testing.T) { + if testing.Short() { + t.Skipf("skipping test that opens 1<<16 connections") } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - l, err := test.listen() - if err != nil { - t.Fatal(err) + ln := newLocalListener(t, "tcp") + done := make(chan struct{}) + go func() { + var accepted []Conn + defer func() { + for _, c := range accepted { + c.Close() } - defer l.Close() - test.addr(t, l.Addr()) + close(done) + }() - c, err := test.dial(l.Addr()) + for { + c, err := ln.Accept() if err != nil { - t.Fatal(err) + return } - defer c.Close() - test.addr(t, c.LocalAddr()) - test.addr(t, c.RemoteAddr()) - }) - } -} - -func TestFakePacketConn(t *testing.T) { - tests := []struct { - name string - listen func() (PacketConn, error) - dial func(Addr) (Conn, error) - addr func(*testing.T, Addr) - }{ - { - name: "ListenPacket:udp", - listen: func() (PacketConn, error) { - return ListenPacket("udp", ":0") - }, - dial: func(addr Addr) (Conn, error) { - return Dial(addr.Network(), addr.String()) - }, - addr: testFakeUDPAddr, - }, - - { - name: "ListenUDP:udp", - listen: func() (PacketConn, error) { - // Creating a listening UDP connection with a nil address must - // select an IP address on localhost with a random port. - // This test verifies that the fake network facility does that. - return ListenUDP("udp", nil) - }, - dial: func(addr Addr) (Conn, error) { - // Connecting a listening UDP connection will select a local - // address on the local network and connects to the destination - // address. - return DialUDP("udp", nil, addr.(*UDPAddr)) - }, - addr: testFakeUDPAddr, - }, + accepted = append(accepted, c) + } + }() - { - name: "ListenUnixgram:unixgram", - listen: func() (PacketConn, error) { - return ListenUnixgram("unixgram", &UnixAddr{Name: "test"}) - }, - dial: func(addr Addr) (Conn, error) { - return DialUnix("unixgram", nil, addr.(*UnixAddr)) - }, - addr: testFakeUnixAddr("unixgram", "test"), - }, + var dialed []Conn + defer func() { + ln.Close() + for _, c := range dialed { + c.Close() + } + <-done + }() + + // Since this test is not running in parallel, we expect to be able to open + // all 65535 valid (fake) ports. The listener is already using one, so + // we should be able to Dial the remaining 65534. + for len(dialed) < (1<<16)-2 { + c, err := Dial(ln.Addr().Network(), ln.Addr().String()) + if err != nil { + t.Fatalf("unexpected error from Dial with %v connections: %v", len(dialed), err) + } + dialed = append(dialed, c) + if testing.Verbose() && len(dialed)%(1<<12) == 0 { + t.Logf("dialed %d connections", len(dialed)) + } } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - l, err := test.listen() - if err != nil { - t.Fatal(err) - } - defer l.Close() - test.addr(t, l.LocalAddr()) - - c, err := test.dial(l.LocalAddr()) - if err != nil { - t.Fatal(err) - } - defer c.Close() - test.addr(t, c.LocalAddr()) - test.addr(t, c.RemoteAddr()) - }) + t.Logf("dialed %d connections", len(dialed)) + + // Now that all of the ports are in use, dialing another should fail due + // to port exhaustion, which (for POSIX-like socket APIs) should return + // an EADDRINUSE error. + c, err := Dial(ln.Addr().Network(), ln.Addr().String()) + if err == nil { + c.Close() } -} - -func testFakeTCPAddr(t *testing.T, addr Addr) { - t.Helper() - if a, ok := addr.(*TCPAddr); !ok { - t.Errorf("Addr is not *TCPAddr: %T", addr) + if errors.Is(err, syscall.EADDRINUSE) { + t.Logf("Dial returned expected error: %v", err) } else { - testFakeNetAddr(t, a.IP, a.Port) + t.Errorf("unexpected error from Dial: %v\nwant: %v", err, syscall.EADDRINUSE) } -} -func testFakeUDPAddr(t *testing.T, addr Addr) { - t.Helper() - if a, ok := addr.(*UDPAddr); !ok { - t.Errorf("Addr is not *UDPAddr: %T", addr) + // Opening a Listener should fail at this point too. + ln2, err := Listen("tcp", "localhost:0") + if err == nil { + ln2.Close() + } + if errors.Is(err, syscall.EADDRINUSE) { + t.Logf("Listen returned expected error: %v", err) } else { - testFakeNetAddr(t, a.IP, a.Port) + t.Errorf("unexpected error from Listen: %v\nwant: %v", err, syscall.EADDRINUSE) } -} -func testFakeNetAddr(t *testing.T, ip IP, port int) { - t.Helper() - if port == 0 { - t.Error("network address is missing port") - } else if len(ip) == 0 { - t.Error("network address is missing IP") - } else if !ip.Equal(IPv4(127, 0, 0, 1)) { - t.Errorf("network address has wrong IP: %s", ip) - } -} - -func testFakeUnixAddr(net, name string) func(*testing.T, Addr) { - return func(t *testing.T, addr Addr) { - t.Helper() - if a, ok := addr.(*UnixAddr); !ok { - t.Errorf("Addr is not *UnixAddr: %T", addr) - } else if a.Net != net { - t.Errorf("unix address has wrong net: want=%q got=%q", net, a.Net) - } else if a.Name != name { - t.Errorf("unix address has wrong name: want=%q got=%q", name, a.Name) - } + // When we close an arbitrary connection, we should be able to reuse its port + // even if the server hasn't yet seen the ECONNRESET for the connection. + dialed[0].Close() + dialed = dialed[1:] + t.Logf("closed one connection") + c, err = Dial(ln.Addr().Network(), ln.Addr().String()) + if err == nil { + c.Close() + t.Logf("Dial succeeded") + } else { + t.Errorf("unexpected error from Dial: %v", err) } } diff --git a/src/net/net_test.go b/src/net/net_test.go index a0ac85f406..b448a79cce 100644 --- a/src/net/net_test.go +++ b/src/net/net_test.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !wasip1 - package net import ( @@ -383,8 +381,16 @@ func TestZeroByteRead(t *testing.T) { ln := newLocalListener(t, network) connc := make(chan Conn, 1) + defer func() { + ln.Close() + for c := range connc { + if c != nil { + c.Close() + } + } + }() go func() { - defer ln.Close() + defer close(connc) c, err := ln.Accept() if err != nil { t.Error(err) @@ -440,8 +446,9 @@ func withTCPConnPair(t *testing.T, peer1, peer2 func(c *TCPConn) error) { errc <- err return } - defer c1.Close() - errc <- peer1(c1.(*TCPConn)) + err = peer1(c1.(*TCPConn)) + c1.Close() + errc <- err }() go func() { c2, err := Dial("tcp", ln.Addr().String()) @@ -449,12 +456,13 @@ func withTCPConnPair(t *testing.T, peer1, peer2 func(c *TCPConn) error) { errc <- err return } - defer c2.Close() - errc <- peer2(c2.(*TCPConn)) + err = peer2(c2.(*TCPConn)) + c2.Close() + errc <- err }() for i := 0; i < 2; i++ { if err := <-errc; err != nil { - t.Fatal(err) + t.Error(err) } } } diff --git a/src/net/netip/export_test.go b/src/net/netip/export_test.go index 59971fa2e4..72347ee01b 100644 --- a/src/net/netip/export_test.go +++ b/src/net/netip/export_test.go @@ -28,3 +28,5 @@ var TestAppendToMarshal = testAppendToMarshal func (a Addr) IsZero() bool { return a.isZero() } func (p Prefix) IsZero() bool { return p.isZero() } + +func (p Prefix) Compare(p2 Prefix) int { return p.compare(p2) } diff --git a/src/net/netip/leaf_alts.go b/src/net/netip/leaf_alts.go index 70513abfd9..d887bed627 100644 --- a/src/net/netip/leaf_alts.go +++ b/src/net/netip/leaf_alts.go @@ -7,15 +7,6 @@ package netip -func stringsLastIndexByte(s string, b byte) int { - for i := len(s) - 1; i >= 0; i-- { - if s[i] == b { - return i - } - } - return -1 -} - func beUint64(b []byte) uint64 { _ = b[7] // bounds check hint to compiler; see golang.org/issue/14808 return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 | diff --git a/src/net/netip/netip.go b/src/net/netip/netip.go index a44b094955..7a189e8e16 100644 --- a/src/net/netip/netip.go +++ b/src/net/netip/netip.go @@ -12,6 +12,7 @@ package netip import ( + "cmp" "errors" "math" "strconv" @@ -127,7 +128,7 @@ func ParseAddr(s string) (Addr, error) { return Addr{}, parseAddrError{in: s, msg: "unable to parse IP"} } -// MustParseAddr calls ParseAddr(s) and panics on error. +// MustParseAddr calls [ParseAddr](s) and panics on error. // It is intended for use in tests with hard-coded strings. func MustParseAddr(s string) Addr { ip, err := ParseAddr(s) @@ -334,8 +335,8 @@ func parseIPv6(in string) (Addr, error) { } // AddrFromSlice parses the 4- or 16-byte byte slice as an IPv4 or IPv6 address. -// Note that a net.IP can be passed directly as the []byte argument. -// If slice's length is not 4 or 16, AddrFromSlice returns Addr{}, false. +// Note that a [net.IP] can be passed directly as the []byte argument. +// If slice's length is not 4 or 16, AddrFromSlice returns [Addr]{}, false. func AddrFromSlice(slice []byte) (ip Addr, ok bool) { switch len(slice) { case 4: @@ -375,13 +376,13 @@ func (ip Addr) isZero() bool { return ip.z == z0 } -// IsValid reports whether the Addr is an initialized address (not the zero Addr). +// IsValid reports whether the [Addr] is an initialized address (not the zero Addr). // // Note that "0.0.0.0" and "::" are both valid values. func (ip Addr) IsValid() bool { return ip.z != z0 } // BitLen returns the number of bits in the IP address: -// 128 for IPv6, 32 for IPv4, and 0 for the zero Addr. +// 128 for IPv6, 32 for IPv4, and 0 for the zero [Addr]. // // Note that IPv4-mapped IPv6 addresses are considered IPv6 addresses // and therefore have bit length 128. @@ -406,7 +407,7 @@ func (ip Addr) Zone() string { // Compare returns an integer comparing two IPs. // The result will be 0 if ip == ip2, -1 if ip < ip2, and +1 if ip > ip2. -// The definition of "less than" is the same as the Less method. +// The definition of "less than" is the same as the [Addr.Less] method. func (ip Addr) Compare(ip2 Addr) int { f1, f2 := ip.BitLen(), ip2.BitLen() if f1 < f2 { @@ -448,7 +449,7 @@ func (ip Addr) Less(ip2 Addr) bool { return ip.Compare(ip2) == -1 } // Is4 reports whether ip is an IPv4 address. // -// It returns false for IPv4-mapped IPv6 addresses. See Addr.Unmap. +// It returns false for IPv4-mapped IPv6 addresses. See [Addr.Unmap]. func (ip Addr) Is4() bool { return ip.z == z4 } @@ -582,7 +583,7 @@ func (ip Addr) IsLinkLocalMulticast() bool { // IANA-allocated 2000::/3 global unicast space, with the exception of the // link-local address space. It also returns true even if ip is in the IPv4 // private address space or IPv6 unique local address space. -// It returns false for the zero Addr. +// It returns false for the zero [Addr]. // // For reference, see RFC 1122, RFC 4291, and RFC 4632. func (ip Addr) IsGlobalUnicast() bool { @@ -606,7 +607,7 @@ func (ip Addr) IsGlobalUnicast() bool { // IsPrivate reports whether ip is a private address, according to RFC 1918 // (IPv4 addresses) and RFC 4193 (IPv6 addresses). That is, it reports whether // ip is in 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, or fc00::/7. This is the -// same as net.IP.IsPrivate. +// same as [net.IP.IsPrivate]. func (ip Addr) IsPrivate() bool { // Match the stdlib's IsPrivate logic. if ip.Is4() { @@ -629,14 +630,14 @@ func (ip Addr) IsPrivate() bool { // IsUnspecified reports whether ip is an unspecified address, either the IPv4 // address "0.0.0.0" or the IPv6 address "::". // -// Note that the zero Addr is not an unspecified address. +// Note that the zero [Addr] is not an unspecified address. func (ip Addr) IsUnspecified() bool { return ip == IPv4Unspecified() || ip == IPv6Unspecified() } // Prefix keeps only the top b bits of IP, producing a Prefix // of the specified length. -// If ip is a zero Addr, Prefix always returns a zero Prefix and a nil error. +// If ip is a zero [Addr], Prefix always returns a zero Prefix and a nil error. // Otherwise, if bits is less than zero or greater than ip.BitLen(), // Prefix returns an error. func (ip Addr) Prefix(b int) (Prefix, error) { @@ -661,15 +662,10 @@ func (ip Addr) Prefix(b int) (Prefix, error) { return PrefixFrom(ip, b), nil } -const ( - netIPv4len = 4 - netIPv6len = 16 -) - // As16 returns the IP address in its 16-byte representation. // IPv4 addresses are returned as IPv4-mapped IPv6 addresses. // IPv6 addresses with zones are returned without their zone (use the -// Zone method to get it). +// [Addr.Zone] method to get it). // The ip zero value returns all zeroes. func (ip Addr) As16() (a16 [16]byte) { bePutUint64(a16[:8], ip.addr.hi) @@ -678,7 +674,7 @@ func (ip Addr) As16() (a16 [16]byte) { } // As4 returns an IPv4 or IPv4-in-IPv6 address in its 4-byte representation. -// If ip is the zero Addr or an IPv6 address, As4 panics. +// If ip is the zero [Addr] or an IPv6 address, As4 panics. // Note that 0.0.0.0 is not the zero Addr. func (ip Addr) As4() (a4 [4]byte) { if ip.z == z4 || ip.Is4In6() { @@ -709,7 +705,7 @@ func (ip Addr) AsSlice() []byte { } // Next returns the address following ip. -// If there is none, it returns the zero Addr. +// If there is none, it returns the zero [Addr]. func (ip Addr) Next() Addr { ip.addr = ip.addr.addOne() if ip.Is4() { @@ -743,10 +739,10 @@ func (ip Addr) Prev() Addr { // String returns the string form of the IP address ip. // It returns one of 5 forms: // -// - "invalid IP", if ip is the zero Addr +// - "invalid IP", if ip is the zero [Addr] // - IPv4 dotted decimal ("192.0.2.1") // - IPv6 ("2001:db8::1") -// - "::ffff:1.2.3.4" (if Is4In6) +// - "::ffff:1.2.3.4" (if [Addr.Is4In6]) // - IPv6 with zone ("fe80:db8::1%eth0") // // Note that unlike package net's IP.String method, @@ -771,7 +767,7 @@ func (ip Addr) String() string { } // AppendTo appends a text encoding of ip, -// as generated by MarshalText, +// as generated by [Addr.MarshalText], // to b and returns the extended buffer. func (ip Addr) AppendTo(b []byte) []byte { switch ip.z { @@ -903,7 +899,7 @@ func (ip Addr) appendTo6(ret []byte) []byte { return ret } -// StringExpanded is like String but IPv6 addresses are expanded with leading +// StringExpanded is like [Addr.String] but IPv6 addresses are expanded with leading // zeroes and no "::" compression. For example, "2001:db8::1" becomes // "2001:0db8:0000:0000:0000:0000:0000:0001". func (ip Addr) StringExpanded() string { @@ -931,9 +927,9 @@ func (ip Addr) StringExpanded() string { return string(ret) } -// MarshalText implements the encoding.TextMarshaler interface, -// The encoding is the same as returned by String, with one exception: -// If ip is the zero Addr, the encoding is the empty string. +// MarshalText implements the [encoding.TextMarshaler] interface, +// The encoding is the same as returned by [Addr.String], with one exception: +// If ip is the zero [Addr], the encoding is the empty string. func (ip Addr) MarshalText() ([]byte, error) { switch ip.z { case z0: @@ -960,9 +956,9 @@ func (ip Addr) MarshalText() ([]byte, error) { } // UnmarshalText implements the encoding.TextUnmarshaler interface. -// The IP address is expected in a form accepted by ParseAddr. +// The IP address is expected in a form accepted by [ParseAddr]. // -// If text is empty, UnmarshalText sets *ip to the zero Addr and +// If text is empty, UnmarshalText sets *ip to the zero [Addr] and // returns no error. func (ip *Addr) UnmarshalText(text []byte) error { if len(text) == 0 { @@ -992,15 +988,15 @@ func (ip Addr) marshalBinaryWithTrailingBytes(trailingBytes int) []byte { return b } -// MarshalBinary implements the encoding.BinaryMarshaler interface. -// It returns a zero-length slice for the zero Addr, +// MarshalBinary implements the [encoding.BinaryMarshaler] interface. +// It returns a zero-length slice for the zero [Addr], // the 4-byte form for an IPv4 address, // and the 16-byte form with zone appended for an IPv6 address. func (ip Addr) MarshalBinary() ([]byte, error) { return ip.marshalBinaryWithTrailingBytes(0), nil } -// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. +// UnmarshalBinary implements the [encoding.BinaryUnmarshaler] interface. // It expects data in the form generated by MarshalBinary. func (ip *Addr) UnmarshalBinary(b []byte) error { n := len(b) @@ -1027,7 +1023,7 @@ type AddrPort struct { port uint16 } -// AddrPortFrom returns an AddrPort with the provided IP and port. +// AddrPortFrom returns an [AddrPort] with the provided IP and port. // It does not allocate. func AddrPortFrom(ip Addr, port uint16) AddrPort { return AddrPort{ip: ip, port: port} } @@ -1043,7 +1039,7 @@ func (p AddrPort) Port() uint16 { return p.port } // ip string should parse as an IPv6 address or an IPv4 address, in // order for s to be a valid ip:port string. func splitAddrPort(s string) (ip, port string, v6 bool, err error) { - i := stringsLastIndexByte(s, ':') + i := bytealg.LastIndexByteString(s, ':') if i == -1 { return "", "", false, errors.New("not an ip:port") } @@ -1066,7 +1062,7 @@ func splitAddrPort(s string) (ip, port string, v6 bool, err error) { return ip, port, v6, nil } -// ParseAddrPort parses s as an AddrPort. +// ParseAddrPort parses s as an [AddrPort]. // // It doesn't do any name resolution: both the address and the port // must be numeric. @@ -1093,7 +1089,7 @@ func ParseAddrPort(s string) (AddrPort, error) { return ipp, nil } -// MustParseAddrPort calls ParseAddrPort(s) and panics on error. +// MustParseAddrPort calls [ParseAddrPort](s) and panics on error. // It is intended for use in tests with hard-coded strings. func MustParseAddrPort(s string) AddrPort { ip, err := ParseAddrPort(s) @@ -1107,36 +1103,35 @@ func MustParseAddrPort(s string) AddrPort { // All ports are valid, including zero. func (p AddrPort) IsValid() bool { return p.ip.IsValid() } +// Compare returns an integer comparing two AddrPorts. +// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2. +// AddrPorts sort first by IP address, then port. +func (p AddrPort) Compare(p2 AddrPort) int { + if c := p.Addr().Compare(p2.Addr()); c != 0 { + return c + } + return cmp.Compare(p.Port(), p2.Port()) +} + func (p AddrPort) String() string { switch p.ip.z { case z0: return "invalid AddrPort" case z4: - a := p.ip.As4() - buf := make([]byte, 0, 21) - for i := range a { - buf = strconv.AppendUint(buf, uint64(a[i]), 10) - buf = append(buf, "...:"[i]) - } + const max = len("255.255.255.255:65535") + buf := make([]byte, 0, max) + buf = p.ip.appendTo4(buf) + buf = append(buf, ':') buf = strconv.AppendUint(buf, uint64(p.port), 10) return string(buf) default: // TODO: this could be more efficient allocation-wise: - return joinHostPort(p.ip.String(), itoa.Itoa(int(p.port))) + return "[" + p.ip.String() + "]:" + itoa.Uitoa(uint(p.port)) } } -func joinHostPort(host, port string) string { - // We assume that host is a literal IPv6 address if host has - // colons. - if bytealg.IndexByteString(host, ':') >= 0 { - return "[" + host + "]:" + port - } - return host + ":" + port -} - // AppendTo appends a text encoding of p, -// as generated by MarshalText, +// as generated by [AddrPort.MarshalText], // to b and returns the extended buffer. func (p AddrPort) AppendTo(b []byte) []byte { switch p.ip.z { @@ -1163,9 +1158,9 @@ func (p AddrPort) AppendTo(b []byte) []byte { return b } -// MarshalText implements the encoding.TextMarshaler interface. The -// encoding is the same as returned by String, with one exception: if -// p.Addr() is the zero Addr, the encoding is the empty string. +// MarshalText implements the [encoding.TextMarshaler] interface. The +// encoding is the same as returned by [AddrPort.String], with one exception: if +// p.Addr() is the zero [Addr], the encoding is the empty string. func (p AddrPort) MarshalText() ([]byte, error) { var max int switch p.ip.z { @@ -1181,8 +1176,8 @@ func (p AddrPort) MarshalText() ([]byte, error) { } // UnmarshalText implements the encoding.TextUnmarshaler -// interface. The AddrPort is expected in a form -// generated by MarshalText or accepted by ParseAddrPort. +// interface. The [AddrPort] is expected in a form +// generated by [AddrPort.MarshalText] or accepted by [ParseAddrPort]. func (p *AddrPort) UnmarshalText(text []byte) error { if len(text) == 0 { *p = AddrPort{} @@ -1193,8 +1188,8 @@ func (p *AddrPort) UnmarshalText(text []byte) error { return err } -// MarshalBinary implements the encoding.BinaryMarshaler interface. -// It returns Addr.MarshalBinary with an additional two bytes appended +// MarshalBinary implements the [encoding.BinaryMarshaler] interface. +// It returns [Addr.MarshalBinary] with an additional two bytes appended // containing the port in little-endian. func (p AddrPort) MarshalBinary() ([]byte, error) { b := p.Addr().marshalBinaryWithTrailingBytes(2) @@ -1202,8 +1197,8 @@ func (p AddrPort) MarshalBinary() ([]byte, error) { return b, nil } -// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. -// It expects data in the form generated by MarshalBinary. +// UnmarshalBinary implements the [encoding.BinaryUnmarshaler] interface. +// It expects data in the form generated by [AddrPort.MarshalBinary]. func (p *AddrPort) UnmarshalBinary(b []byte) error { if len(b) < 2 { return errors.New("unexpected slice size") @@ -1219,7 +1214,7 @@ func (p *AddrPort) UnmarshalBinary(b []byte) error { // Prefix is an IP address prefix (CIDR) representing an IP network. // -// The first Bits() of Addr() are specified. The remaining bits match any address. +// The first [Prefix.Bits]() of [Addr]() are specified. The remaining bits match any address. // The range of Bits() is [0,32] for IPv4 or [0,128] for IPv6. type Prefix struct { ip Addr @@ -1229,13 +1224,13 @@ type Prefix struct { bitsPlusOne uint8 } -// PrefixFrom returns a Prefix with the provided IP address and bit +// PrefixFrom returns a [Prefix] with the provided IP address and bit // prefix length. // -// It does not allocate. Unlike Addr.Prefix, PrefixFrom does not mask +// It does not allocate. Unlike [Addr.Prefix], [PrefixFrom] does not mask // off the host bits of ip. // -// If bits is less than zero or greater than ip.BitLen, Prefix.Bits +// If bits is less than zero or greater than ip.BitLen, [Prefix.Bits] // will return an invalid value -1. func PrefixFrom(ip Addr, bits int) Prefix { var bitsPlusOne uint8 @@ -1257,8 +1252,8 @@ func (p Prefix) Addr() Addr { return p.ip } func (p Prefix) Bits() int { return int(p.bitsPlusOne) - 1 } // IsValid reports whether p.Bits() has a valid range for p.Addr(). -// If p.Addr() is the zero Addr, IsValid returns false. -// Note that if p is the zero Prefix, then p.IsValid() == false. +// If p.Addr() is the zero [Addr], IsValid returns false. +// Note that if p is the zero [Prefix], then p.IsValid() == false. func (p Prefix) IsValid() bool { return p.bitsPlusOne > 0 } func (p Prefix) isZero() bool { return p == Prefix{} } @@ -1266,6 +1261,24 @@ func (p Prefix) isZero() bool { return p == Prefix{} } // IsSingleIP reports whether p contains exactly one IP. func (p Prefix) IsSingleIP() bool { return p.IsValid() && p.Bits() == p.ip.BitLen() } +// compare returns an integer comparing two prefixes. +// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2. +// Prefixes sort first by validity (invalid before valid), then +// address family (IPv4 before IPv6), then prefix length, then +// address. +// +// Unexported for Go 1.22 because we may want to compare by p.Addr first. +// See post-acceptance discussion on go.dev/issue/61642. +func (p Prefix) compare(p2 Prefix) int { + if c := cmp.Compare(p.Addr().BitLen(), p2.Addr().BitLen()); c != 0 { + return c + } + if c := cmp.Compare(p.Bits(), p2.Bits()); c != 0 { + return c + } + return p.Addr().Compare(p2.Addr()) +} + // ParsePrefix parses s as an IP address prefix. // The string can be in the form "192.168.1.0/24" or "2001:db8::/32", // the CIDR notation defined in RFC 4632 and RFC 4291. @@ -1274,7 +1287,7 @@ func (p Prefix) IsSingleIP() bool { return p.IsValid() && p.Bits() == p.ip.BitLe // // Note that masked address bits are not zeroed. Use Masked for that. func ParsePrefix(s string) (Prefix, error) { - i := stringsLastIndexByte(s, '/') + i := bytealg.LastIndexByteString(s, '/') if i < 0 { return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + "): no '/'") } @@ -1288,6 +1301,12 @@ func ParsePrefix(s string) (Prefix, error) { } bitsStr := s[i+1:] + + // strconv.Atoi accepts a leading sign and leading zeroes, but we don't want that. + if len(bitsStr) > 1 && (bitsStr[0] < '1' || bitsStr[0] > '9') { + return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + "): bad bits after slash: " + strconv.Quote(bitsStr)) + } + bits, err := strconv.Atoi(bitsStr) if err != nil { return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + "): bad bits after slash: " + strconv.Quote(bitsStr)) @@ -1302,7 +1321,7 @@ func ParsePrefix(s string) (Prefix, error) { return PrefixFrom(ip, bits), nil } -// MustParsePrefix calls ParsePrefix(s) and panics on error. +// MustParsePrefix calls [ParsePrefix](s) and panics on error. // It is intended for use in tests with hard-coded strings. func MustParsePrefix(s string) Prefix { ip, err := ParsePrefix(s) @@ -1315,7 +1334,7 @@ func MustParsePrefix(s string) Prefix { // Masked returns p in its canonical form, with all but the high // p.Bits() bits of p.Addr() masked off. // -// If p is zero or otherwise invalid, Masked returns the zero Prefix. +// If p is zero or otherwise invalid, Masked returns the zero [Prefix]. func (p Prefix) Masked() Prefix { m, _ := p.ip.Prefix(p.Bits()) return m @@ -1392,7 +1411,7 @@ func (p Prefix) Overlaps(o Prefix) bool { } // AppendTo appends a text encoding of p, -// as generated by MarshalText, +// as generated by [Prefix.MarshalText], // to b and returns the extended buffer. func (p Prefix) AppendTo(b []byte) []byte { if p.isZero() { @@ -1419,8 +1438,8 @@ func (p Prefix) AppendTo(b []byte) []byte { return b } -// MarshalText implements the encoding.TextMarshaler interface, -// The encoding is the same as returned by String, with one exception: +// MarshalText implements the [encoding.TextMarshaler] interface, +// The encoding is the same as returned by [Prefix.String], with one exception: // If p is the zero value, the encoding is the empty string. func (p Prefix) MarshalText() ([]byte, error) { var max int @@ -1437,8 +1456,8 @@ func (p Prefix) MarshalText() ([]byte, error) { } // UnmarshalText implements the encoding.TextUnmarshaler interface. -// The IP address is expected in a form accepted by ParsePrefix -// or generated by MarshalText. +// The IP address is expected in a form accepted by [ParsePrefix] +// or generated by [Prefix.MarshalText]. func (p *Prefix) UnmarshalText(text []byte) error { if len(text) == 0 { *p = Prefix{} @@ -1449,8 +1468,8 @@ func (p *Prefix) UnmarshalText(text []byte) error { return err } -// MarshalBinary implements the encoding.BinaryMarshaler interface. -// It returns Addr.MarshalBinary with an additional byte appended +// MarshalBinary implements the [encoding.BinaryMarshaler] interface. +// It returns [Addr.MarshalBinary] with an additional byte appended // containing the prefix bits. func (p Prefix) MarshalBinary() ([]byte, error) { b := p.Addr().withoutZone().marshalBinaryWithTrailingBytes(1) @@ -1458,8 +1477,8 @@ func (p Prefix) MarshalBinary() ([]byte, error) { return b, nil } -// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. -// It expects data in the form generated by MarshalBinary. +// UnmarshalBinary implements the [encoding.BinaryUnmarshaler] interface. +// It expects data in the form generated by [Prefix.MarshalBinary]. func (p *Prefix) UnmarshalBinary(b []byte) error { if len(b) < 1 { return errors.New("unexpected slice size") diff --git a/src/net/netip/netip_test.go b/src/net/netip/netip_test.go index 0f80bb0ab0..a748ac34f1 100644 --- a/src/net/netip/netip_test.go +++ b/src/net/netip/netip_test.go @@ -14,6 +14,7 @@ import ( "net" . "net/netip" "reflect" + "slices" "sort" "strings" "testing" @@ -389,6 +390,7 @@ func TestAddrPortMarshalTextString(t *testing.T) { want string }{ {mustIPPort("1.2.3.4:80"), "1.2.3.4:80"}, + {mustIPPort("[::]:80"), "[::]:80"}, {mustIPPort("[1::CAFE]:80"), "[1::cafe]:80"}, {mustIPPort("[1::CAFE%en0]:80"), "[1::cafe%en0]:80"}, {mustIPPort("[::FFFF:192.168.140.255]:80"), "[::ffff:192.168.140.255]:80"}, @@ -812,7 +814,7 @@ func TestAddrWellKnown(t *testing.T) { } } -func TestLessCompare(t *testing.T) { +func TestAddrLessCompare(t *testing.T) { tests := []struct { a, b Addr want bool @@ -882,6 +884,109 @@ func TestLessCompare(t *testing.T) { } } +func TestAddrPortCompare(t *testing.T) { + tests := []struct { + a, b AddrPort + want int + }{ + {AddrPort{}, AddrPort{}, 0}, + {AddrPort{}, mustIPPort("1.2.3.4:80"), -1}, + + {mustIPPort("1.2.3.4:80"), mustIPPort("1.2.3.4:80"), 0}, + {mustIPPort("[::1]:80"), mustIPPort("[::1]:80"), 0}, + + {mustIPPort("1.2.3.4:80"), mustIPPort("2.3.4.5:22"), -1}, + {mustIPPort("[::1]:80"), mustIPPort("[::2]:22"), -1}, + + {mustIPPort("1.2.3.4:80"), mustIPPort("1.2.3.4:443"), -1}, + {mustIPPort("[::1]:80"), mustIPPort("[::1]:443"), -1}, + + {mustIPPort("1.2.3.4:80"), mustIPPort("[0102:0304::0]:80"), -1}, + } + for _, tt := range tests { + got := tt.a.Compare(tt.b) + if got != tt.want { + t.Errorf("Compare(%q, %q) = %v; want %v", tt.a, tt.b, got, tt.want) + } + + // Also check inverse. + if got == tt.want { + got2 := tt.b.Compare(tt.a) + if want2 := -1 * tt.want; got2 != want2 { + t.Errorf("Compare(%q, %q) was correctly %v, but Compare(%q, %q) was %v", tt.a, tt.b, got, tt.b, tt.a, got2) + } + } + } + + // And just sort. + values := []AddrPort{ + mustIPPort("[::1]:80"), + mustIPPort("[::2]:80"), + AddrPort{}, + mustIPPort("1.2.3.4:443"), + mustIPPort("8.8.8.8:8080"), + mustIPPort("[::1%foo]:1024"), + } + slices.SortFunc(values, func(a, b AddrPort) int { return a.Compare(b) }) + got := fmt.Sprintf("%s", values) + want := `[invalid AddrPort 1.2.3.4:443 8.8.8.8:8080 [::1]:80 [::1%foo]:1024 [::2]:80]` + if got != want { + t.Errorf("unexpected sort\n got: %s\nwant: %s\n", got, want) + } +} + +func TestPrefixCompare(t *testing.T) { + tests := []struct { + a, b Prefix + want int + }{ + {Prefix{}, Prefix{}, 0}, + {Prefix{}, mustPrefix("1.2.3.0/24"), -1}, + + {mustPrefix("1.2.3.0/24"), mustPrefix("1.2.3.0/24"), 0}, + {mustPrefix("fe80::/64"), mustPrefix("fe80::/64"), 0}, + + {mustPrefix("1.2.3.0/24"), mustPrefix("1.2.4.0/24"), -1}, + {mustPrefix("fe80::/64"), mustPrefix("fe90::/64"), -1}, + + {mustPrefix("1.2.0.0/16"), mustPrefix("1.2.0.0/24"), -1}, + {mustPrefix("fe80::/48"), mustPrefix("fe80::/64"), -1}, + + {mustPrefix("1.2.3.0/24"), mustPrefix("fe80::/8"), -1}, + } + for _, tt := range tests { + got := tt.a.Compare(tt.b) + if got != tt.want { + t.Errorf("Compare(%q, %q) = %v; want %v", tt.a, tt.b, got, tt.want) + } + + // Also check inverse. + if got == tt.want { + got2 := tt.b.Compare(tt.a) + if want2 := -1 * tt.want; got2 != want2 { + t.Errorf("Compare(%q, %q) was correctly %v, but Compare(%q, %q) was %v", tt.a, tt.b, got, tt.b, tt.a, got2) + } + } + } + + // And just sort. + values := []Prefix{ + mustPrefix("1.2.3.0/24"), + mustPrefix("fe90::/64"), + mustPrefix("fe80::/64"), + mustPrefix("1.2.0.0/16"), + Prefix{}, + mustPrefix("fe80::/48"), + mustPrefix("1.2.0.0/24"), + } + slices.SortFunc(values, func(a, b Prefix) int { return a.Compare(b) }) + got := fmt.Sprintf("%s", values) + want := `[invalid Prefix 1.2.0.0/16 1.2.0.0/24 1.2.3.0/24 fe80::/48 fe80::/64 fe90::/64]` + if got != want { + t.Errorf("unexpected sort\n got: %s\nwant: %s\n", got, want) + } +} + func TestIPStringExpanded(t *testing.T) { tests := []struct { ip Addr @@ -1352,7 +1457,7 @@ func TestParsePrefixError(t *testing.T) { }, { prefix: "1.1.1.0/-1", - errstr: "out of range", + errstr: "bad bits", }, { prefix: "1.1.1.0/33", @@ -1371,6 +1476,22 @@ func TestParsePrefixError(t *testing.T) { prefix: "2001:db8::%a/32", errstr: "zones cannot be present", }, + { + prefix: "1.1.1.0/+32", + errstr: "bad bits", + }, + { + prefix: "1.1.1.0/-32", + errstr: "bad bits", + }, + { + prefix: "1.1.1.0/032", + errstr: "bad bits", + }, + { + prefix: "1.1.1.0/0032", + errstr: "bad bits", + }, } for _, test := range tests { t.Run(test.prefix, func(t *testing.T) { diff --git a/src/net/packetconn_test.go b/src/net/packetconn_test.go index dc0c14b93d..e39e7de5d7 100644 --- a/src/net/packetconn_test.go +++ b/src/net/packetconn_test.go @@ -2,10 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This file implements API tests across platforms and will never have a build -// tag. - -//go:build !js && !wasip1 +// This file implements API tests across platforms and should never have a build +// constraint. package net diff --git a/src/net/parse.go b/src/net/parse.go index fbc50144c2..29dffad43c 100644 --- a/src/net/parse.go +++ b/src/net/parse.go @@ -180,42 +180,6 @@ func xtoi2(s string, e byte) (byte, bool) { return byte(n), ok && ei == 2 } -// Convert i to a hexadecimal string. Leading zeros are not printed. -func appendHex(dst []byte, i uint32) []byte { - if i == 0 { - return append(dst, '0') - } - for j := 7; j >= 0; j-- { - v := i >> uint(j*4) - if v > 0 { - dst = append(dst, hexDigit[v&0xf]) - } - } - return dst -} - -// Number of occurrences of b in s. -func count(s string, b byte) int { - n := 0 - for i := 0; i < len(s); i++ { - if s[i] == b { - n++ - } - } - return n -} - -// Index of rightmost occurrence of b in s. -func last(s string, b byte) int { - i := len(s) - for i--; i >= 0; i-- { - if s[i] == b { - break - } - } - return i -} - // hasUpperCase tells whether the given string contains at least one upper-case. func hasUpperCase(s string) bool { for i := range s { diff --git a/src/net/pipe.go b/src/net/pipe.go index f1741938b0..69955e4617 100644 --- a/src/net/pipe.go +++ b/src/net/pipe.go @@ -106,7 +106,7 @@ type pipe struct { } // Pipe creates a synchronous, in-memory, full duplex -// network connection; both ends implement the Conn interface. +// network connection; both ends implement the [Conn] interface. // Reads on one end are matched with writes on the other, // copying data directly between the two; there is no internal // buffering. diff --git a/src/net/platform_test.go b/src/net/platform_test.go index 71e90821ce..709d4a3eb7 100644 --- a/src/net/platform_test.go +++ b/src/net/platform_test.go @@ -165,7 +165,7 @@ func condFatalf(t *testing.T, network string, format string, args ...any) { // A few APIs like File and Read/WriteMsg{UDP,IP} are not // fully implemented yet on Plan 9 and Windows. switch runtime.GOOS { - case "windows": + case "windows", "js", "wasip1": if network == "file+net" { t.Logf(format, args...) return diff --git a/src/net/port_unix.go b/src/net/port_unix.go index 0b2ea3ec5d..df73dbabb3 100644 --- a/src/net/port_unix.go +++ b/src/net/port_unix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build unix || (js && wasm) || wasip1 +//go:build unix || js || wasip1 // Read system port mappings from /etc/services diff --git a/src/net/protoconn_test.go b/src/net/protoconn_test.go index c5668079a9..a617470580 100644 --- a/src/net/protoconn_test.go +++ b/src/net/protoconn_test.go @@ -5,8 +5,6 @@ // This file implements API tests across platforms and will never have a build // tag. -//go:build !js && !wasip1 - package net import ( @@ -39,7 +37,7 @@ func TestTCPListenerSpecificMethods(t *testing.T) { } defer ln.Close() ln.Addr() - ln.SetDeadline(time.Now().Add(30 * time.Nanosecond)) + mustSetDeadline(t, ln.SetDeadline, 30*time.Nanosecond) if c, err := ln.Accept(); err != nil { if !err.(Error).Timeout() { @@ -162,6 +160,10 @@ func TestUDPConnSpecificMethods(t *testing.T) { } func TestIPConnSpecificMethods(t *testing.T) { + if !testableNetwork("ip4") { + t.Skip("skipping: ip4 not supported") + } + la, err := ResolveIPAddr("ip4", "127.0.0.1") if err != nil { t.Fatal(err) @@ -217,7 +219,7 @@ func TestUnixListenerSpecificMethods(t *testing.T) { defer ln.Close() defer os.Remove(addr) ln.Addr() - ln.SetDeadline(time.Now().Add(30 * time.Nanosecond)) + mustSetDeadline(t, ln.SetDeadline, 30*time.Nanosecond) if c, err := ln.Accept(); err != nil { if !err.(Error).Timeout() { @@ -235,7 +237,7 @@ func TestUnixListenerSpecificMethods(t *testing.T) { } if f, err := ln.File(); err != nil { - t.Fatal(err) + condFatalf(t, "file+net", "%v", err) } else { f.Close() } @@ -332,7 +334,7 @@ func TestUnixConnSpecificMethods(t *testing.T) { } if f, err := c1.File(); err != nil { - t.Fatal(err) + condFatalf(t, "file+net", "%v", err) } else { f.Close() } diff --git a/src/net/rawconn.go b/src/net/rawconn.go index 974320c25f..19228e94ed 100644 --- a/src/net/rawconn.go +++ b/src/net/rawconn.go @@ -63,7 +63,7 @@ func (c *rawConn) Write(f func(uintptr) bool) error { // PollFD returns the poll.FD of the underlying connection. // -// Other packages in std that also import internal/poll (such as os) +// Other packages in std that also import [internal/poll] (such as os) // can use a type assertion to access this extension method so that // they can pass the *poll.FD to functions like poll.Splice. // @@ -75,8 +75,19 @@ func (c *rawConn) PollFD() *poll.FD { return &c.fd.pfd } -func newRawConn(fd *netFD) (*rawConn, error) { - return &rawConn{fd: fd}, nil +func newRawConn(fd *netFD) *rawConn { + return &rawConn{fd: fd} +} + +// Network returns the network type of the underlying connection. +// +// Other packages in std that import internal/poll and are unable to +// import net (such as os) can use a type assertion to access this +// extension method so that they can distinguish different socket types. +// +// Network is not intended for use outside the standard library. +func (c *rawConn) Network() poll.String { + return poll.String(c.fd.net) } type rawListener struct { @@ -91,6 +102,6 @@ func (l *rawListener) Write(func(uintptr) bool) error { return syscall.EINVAL } -func newRawListener(fd *netFD) (*rawListener, error) { - return &rawListener{rawConn{fd: fd}}, nil +func newRawListener(fd *netFD) *rawListener { + return &rawListener{rawConn{fd: fd}} } diff --git a/src/net/rawconn_stub_test.go b/src/net/rawconn_stub_test.go index c8ad80cc84..6d54f2df55 100644 --- a/src/net/rawconn_stub_test.go +++ b/src/net/rawconn_stub_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build (js && wasm) || plan9 || wasip1 +//go:build js || plan9 || wasip1 package net diff --git a/src/net/rawconn_test.go b/src/net/rawconn_test.go index 06d5856a9a..70b16c4115 100644 --- a/src/net/rawconn_test.go +++ b/src/net/rawconn_test.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !wasip1 - package net import ( @@ -15,7 +13,7 @@ import ( func TestRawConnReadWrite(t *testing.T) { switch runtime.GOOS { - case "plan9": + case "plan9", "js", "wasip1": t.Skipf("not supported on %s", runtime.GOOS) } @@ -169,7 +167,7 @@ func TestRawConnReadWrite(t *testing.T) { func TestRawConnControl(t *testing.T) { switch runtime.GOOS { - case "plan9": + case "plan9", "js", "wasip1": t.Skipf("not supported on %s", runtime.GOOS) } diff --git a/src/net/resolverdialfunc_test.go b/src/net/resolverdialfunc_test.go index 1de0402389..1af4199269 100644 --- a/src/net/resolverdialfunc_test.go +++ b/src/net/resolverdialfunc_test.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !wasip1 - // Test that Resolver.Dial can be a func returning an in-memory net.Conn // speaking DNS. diff --git a/src/net/rlimit_js.go b/src/net/rlimit_js.go new file mode 100644 index 0000000000..9ee5748b21 --- /dev/null +++ b/src/net/rlimit_js.go @@ -0,0 +1,13 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build js + +package net + +// concurrentThreadsLimit returns the number of threads we permit to +// run concurrently doing DNS lookups. +func concurrentThreadsLimit() int { + return 500 +} diff --git a/src/net/rlimit_unix.go b/src/net/rlimit_unix.go new file mode 100644 index 0000000000..0094756e3a --- /dev/null +++ b/src/net/rlimit_unix.go @@ -0,0 +1,33 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build unix || wasip1 + +package net + +import "syscall" + +// concurrentThreadsLimit returns the number of threads we permit to +// run concurrently doing DNS lookups via cgo. A DNS lookup may use a +// file descriptor so we limit this to less than the number of +// permitted open files. On some systems, notably Darwin, if +// getaddrinfo is unable to open a file descriptor it simply returns +// EAI_NONAME rather than a useful error. Limiting the number of +// concurrent getaddrinfo calls to less than the permitted number of +// file descriptors makes that error less likely. We don't bother to +// apply the same limit to DNS lookups run directly from Go, because +// there we will return a meaningful "too many open files" error. +func concurrentThreadsLimit() int { + var rlim syscall.Rlimit + if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rlim); err != nil { + return 500 + } + r := rlim.Cur + if r > 500 { + r = 500 + } else if r > 30 { + r -= 30 + } + return int(r) +} diff --git a/src/net/rpc/client.go b/src/net/rpc/client.go index 42d13519b1..ffdc435965 100644 --- a/src/net/rpc/client.go +++ b/src/net/rpc/client.go @@ -53,13 +53,13 @@ type Client struct { // A ClientCodec implements writing of RPC requests and // reading of RPC responses for the client side of an RPC session. -// The client calls WriteRequest to write a request to the connection -// and calls ReadResponseHeader and ReadResponseBody in pairs -// to read responses. The client calls Close when finished with the +// The client calls [ClientCodec.WriteRequest] to write a request to the connection +// and calls [ClientCodec.ReadResponseHeader] and [ClientCodec.ReadResponseBody] in pairs +// to read responses. The client calls [ClientCodec.Close] when finished with the // connection. ReadResponseBody may be called with a nil // argument to force the body of the response to be read and then // discarded. -// See NewClient's comment for information about concurrent access. +// See [NewClient]'s comment for information about concurrent access. type ClientCodec interface { WriteRequest(*Request, any) error ReadResponseHeader(*Response) error @@ -181,7 +181,7 @@ func (call *Call) done() { } } -// NewClient returns a new Client to handle requests to the +// NewClient returns a new [Client] to handle requests to the // set of services at the other end of the connection. // It adds a buffer to the write side of the connection so // the header and payload are sent as a unit. @@ -196,7 +196,7 @@ func NewClient(conn io.ReadWriteCloser) *Client { return NewClientWithCodec(client) } -// NewClientWithCodec is like NewClient but uses the specified +// NewClientWithCodec is like [NewClient] but uses the specified // codec to encode requests and decode responses. func NewClientWithCodec(codec ClientCodec) *Client { client := &Client{ @@ -279,7 +279,7 @@ func Dial(network, address string) (*Client, error) { } // Close calls the underlying codec's Close method. If the connection is already -// shutting down, ErrShutdown is returned. +// shutting down, [ErrShutdown] is returned. func (client *Client) Close() error { client.mutex.Lock() if client.closing { @@ -291,7 +291,7 @@ func (client *Client) Close() error { return client.codec.Close() } -// Go invokes the function asynchronously. It returns the Call structure representing +// Go invokes the function asynchronously. It returns the [Call] structure representing // the invocation. The done channel will signal when the call is complete by returning // the same Call object. If done is nil, Go will allocate a new channel. // If non-nil, done must be buffered or Go will deliberately crash. diff --git a/src/net/rpc/jsonrpc/client.go b/src/net/rpc/jsonrpc/client.go index c473017d26..1beba0f364 100644 --- a/src/net/rpc/jsonrpc/client.go +++ b/src/net/rpc/jsonrpc/client.go @@ -33,7 +33,7 @@ type clientCodec struct { pending map[uint64]string // map request id to method name } -// NewClientCodec returns a new rpc.ClientCodec using JSON-RPC on conn. +// NewClientCodec returns a new [rpc.ClientCodec] using JSON-RPC on conn. func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec { return &clientCodec{ dec: json.NewDecoder(conn), @@ -108,7 +108,7 @@ func (c *clientCodec) Close() error { return c.c.Close() } -// NewClient returns a new rpc.Client to handle requests to the +// NewClient returns a new [rpc.Client] to handle requests to the // set of services at the other end of the connection. func NewClient(conn io.ReadWriteCloser) *rpc.Client { return rpc.NewClientWithCodec(NewClientCodec(conn)) diff --git a/src/net/rpc/jsonrpc/server.go b/src/net/rpc/jsonrpc/server.go index 3ee4ddfef2..57a4de1d0f 100644 --- a/src/net/rpc/jsonrpc/server.go +++ b/src/net/rpc/jsonrpc/server.go @@ -33,7 +33,7 @@ type serverCodec struct { pending map[uint64]*json.RawMessage } -// NewServerCodec returns a new rpc.ServerCodec using JSON-RPC on conn. +// NewServerCodec returns a new [rpc.ServerCodec] using JSON-RPC on conn. func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec { return &serverCodec{ dec: json.NewDecoder(conn), diff --git a/src/net/rpc/server.go b/src/net/rpc/server.go index 5cea2cc507..1771726a93 100644 --- a/src/net/rpc/server.go +++ b/src/net/rpc/server.go @@ -30,17 +30,17 @@ These requirements apply even if a different codec is used. The method's first argument represents the arguments provided by the caller; the second argument represents the result parameters to be returned to the caller. The method's return value, if non-nil, is passed back as a string that the client -sees as if created by errors.New. If an error is returned, the reply parameter +sees as if created by [errors.New]. If an error is returned, the reply parameter will not be sent back to the client. -The server may handle requests on a single connection by calling ServeConn. More -typically it will create a network listener and call Accept or, for an HTTP -listener, HandleHTTP and http.Serve. +The server may handle requests on a single connection by calling [ServeConn]. More +typically it will create a network listener and call [Accept] or, for an HTTP +listener, [HandleHTTP] and [http.Serve]. A client wishing to use the service establishes a connection and then invokes -NewClient on the connection. The convenience function Dial (DialHTTP) performs +[NewClient] on the connection. The convenience function [Dial] ([DialHTTP]) performs both steps for a raw network connection (an HTTP connection). The resulting -Client object has two methods, Call and Go, that specify the service and method to +[Client] object has two methods, [Call] and Go, that specify the service and method to call, a pointer containing the arguments, and a pointer to receive the result parameters. @@ -48,7 +48,7 @@ The Call method waits for the remote call to complete while the Go method launches the call asynchronously and signals completion using the Call structure's Done channel. -Unless an explicit codec is set up, package encoding/gob is used to +Unless an explicit codec is set up, package [encoding/gob] is used to transport the data. Here is a simple example. A server wishes to export an object of type Arith: @@ -146,9 +146,8 @@ const ( DefaultDebugPath = "/debug/rpc" ) -// Precompute the reflect type for error. Can't use error directly -// because Typeof takes an empty interface value. This is annoying. -var typeOfError = reflect.TypeOf((*error)(nil)).Elem() +// Precompute the reflect type for error. +var typeOfError = reflect.TypeFor[error]() type methodType struct { sync.Mutex // protects counters @@ -193,12 +192,12 @@ type Server struct { freeResp *Response } -// NewServer returns a new Server. +// NewServer returns a new [Server]. func NewServer() *Server { return &Server{} } -// DefaultServer is the default instance of *Server. +// DefaultServer is the default instance of [*Server]. var DefaultServer = NewServer() // Is this type exported or a builtin? @@ -226,7 +225,7 @@ func (server *Server) Register(rcvr any) error { return server.register(rcvr, "", false) } -// RegisterName is like Register but uses the provided name for the type +// RegisterName is like [Register] but uses the provided name for the type // instead of the receiver's concrete type. func (server *Server) RegisterName(name string, rcvr any) error { return server.register(rcvr, name, true) @@ -441,8 +440,8 @@ func (c *gobServerCodec) Close() error { // ServeConn blocks, serving the connection until the client hangs up. // The caller typically invokes ServeConn in a go statement. // ServeConn uses the gob wire format (see package gob) on the -// connection. To use an alternate codec, use ServeCodec. -// See NewClient's comment for information about concurrent access. +// connection. To use an alternate codec, use [ServeCodec]. +// See [NewClient]'s comment for information about concurrent access. func (server *Server) ServeConn(conn io.ReadWriteCloser) { buf := bufio.NewWriter(conn) srv := &gobServerCodec{ @@ -454,7 +453,7 @@ func (server *Server) ServeConn(conn io.ReadWriteCloser) { server.ServeCodec(srv) } -// ServeCodec is like ServeConn but uses the specified codec to +// ServeCodec is like [ServeConn] but uses the specified codec to // decode requests and encode responses. func (server *Server) ServeCodec(codec ServerCodec) { sending := new(sync.Mutex) @@ -484,7 +483,7 @@ func (server *Server) ServeCodec(codec ServerCodec) { codec.Close() } -// ServeRequest is like ServeCodec but synchronously serves a single request. +// ServeRequest is like [ServeCodec] but synchronously serves a single request. // It does not close the codec upon completion. func (server *Server) ServeRequest(codec ServerCodec) error { sending := new(sync.Mutex) @@ -636,10 +635,10 @@ func (server *Server) Accept(lis net.Listener) { } } -// Register publishes the receiver's methods in the DefaultServer. +// Register publishes the receiver's methods in the [DefaultServer]. func Register(rcvr any) error { return DefaultServer.Register(rcvr) } -// RegisterName is like Register but uses the provided name for the type +// RegisterName is like [Register] but uses the provided name for the type // instead of the receiver's concrete type. func RegisterName(name string, rcvr any) error { return DefaultServer.RegisterName(name, rcvr) @@ -647,12 +646,12 @@ func RegisterName(name string, rcvr any) error { // A ServerCodec implements reading of RPC requests and writing of // RPC responses for the server side of an RPC session. -// The server calls ReadRequestHeader and ReadRequestBody in pairs -// to read requests from the connection, and it calls WriteResponse to -// write a response back. The server calls Close when finished with the +// The server calls [ServerCodec.ReadRequestHeader] and [ServerCodec.ReadRequestBody] in pairs +// to read requests from the connection, and it calls [ServerCodec.WriteResponse] to +// write a response back. The server calls [ServerCodec.Close] when finished with the // connection. ReadRequestBody may be called with a nil // argument to force the body of the request to be read and discarded. -// See NewClient's comment for information about concurrent access. +// See [NewClient]'s comment for information about concurrent access. type ServerCodec interface { ReadRequestHeader(*Request) error ReadRequestBody(any) error @@ -662,37 +661,37 @@ type ServerCodec interface { Close() error } -// ServeConn runs the DefaultServer on a single connection. +// ServeConn runs the [DefaultServer] on a single connection. // ServeConn blocks, serving the connection until the client hangs up. // The caller typically invokes ServeConn in a go statement. // ServeConn uses the gob wire format (see package gob) on the -// connection. To use an alternate codec, use ServeCodec. -// See NewClient's comment for information about concurrent access. +// connection. To use an alternate codec, use [ServeCodec]. +// See [NewClient]'s comment for information about concurrent access. func ServeConn(conn io.ReadWriteCloser) { DefaultServer.ServeConn(conn) } -// ServeCodec is like ServeConn but uses the specified codec to +// ServeCodec is like [ServeConn] but uses the specified codec to // decode requests and encode responses. func ServeCodec(codec ServerCodec) { DefaultServer.ServeCodec(codec) } -// ServeRequest is like ServeCodec but synchronously serves a single request. +// ServeRequest is like [ServeCodec] but synchronously serves a single request. // It does not close the codec upon completion. func ServeRequest(codec ServerCodec) error { return DefaultServer.ServeRequest(codec) } // Accept accepts connections on the listener and serves requests -// to DefaultServer for each incoming connection. +// to [DefaultServer] for each incoming connection. // Accept blocks; the caller typically invokes it in a go statement. func Accept(lis net.Listener) { DefaultServer.Accept(lis) } // Can connect to RPC service using HTTP CONNECT to rpcPath. var connected = "200 Connected to Go RPC" -// ServeHTTP implements an http.Handler that answers RPC requests. +// ServeHTTP implements an [http.Handler] that answers RPC requests. func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { if req.Method != "CONNECT" { w.Header().Set("Content-Type", "text/plain; charset=utf-8") @@ -711,15 +710,15 @@ func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { // HandleHTTP registers an HTTP handler for RPC messages on rpcPath, // and a debugging handler on debugPath. -// It is still necessary to invoke http.Serve(), typically in a go statement. +// It is still necessary to invoke [http.Serve](), typically in a go statement. func (server *Server) HandleHTTP(rpcPath, debugPath string) { http.Handle(rpcPath, server) http.Handle(debugPath, debugHTTP{server}) } -// HandleHTTP registers an HTTP handler for RPC messages to DefaultServer -// on DefaultRPCPath and a debugging handler on DefaultDebugPath. -// It is still necessary to invoke http.Serve(), typically in a go statement. +// HandleHTTP registers an HTTP handler for RPC messages to [DefaultServer] +// on [DefaultRPCPath] and a debugging handler on [DefaultDebugPath]. +// It is still necessary to invoke [http.Serve](), typically in a go statement. func HandleHTTP() { DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath) } diff --git a/src/net/sendfile_linux_test.go b/src/net/sendfile_linux_test.go index 8cd6acca17..7a66d3645f 100644 --- a/src/net/sendfile_linux_test.go +++ b/src/net/sendfile_linux_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build linux -// +build linux package net @@ -15,29 +14,36 @@ import ( ) func BenchmarkSendFile(b *testing.B) { + b.Run("file-to-tcp", func(b *testing.B) { benchmarkSendFile(b, "tcp") }) + b.Run("file-to-unix", func(b *testing.B) { benchmarkSendFile(b, "unix") }) +} + +func benchmarkSendFile(b *testing.B, proto string) { for i := 0; i <= 10; i++ { size := 1 << (i + 10) - bench := sendFileBench{chunkSize: size} + bench := sendFileBench{ + proto: proto, + chunkSize: size, + } b.Run(strconv.Itoa(size), bench.benchSendFile) } } type sendFileBench struct { + proto string chunkSize int } func (bench sendFileBench) benchSendFile(b *testing.B) { fileSize := b.N * bench.chunkSize f := createTempFile(b, fileSize) - fileName := f.Name() - defer os.Remove(fileName) - defer f.Close() - client, server := spliceTestSocketPair(b, "tcp") + client, server := spliceTestSocketPair(b, bench.proto) defer server.Close() cleanUp, err := startSpliceClient(client, "r", bench.chunkSize, fileSize) if err != nil { + client.Close() b.Fatal(err) } defer cleanUp() @@ -52,15 +58,18 @@ func (bench sendFileBench) benchSendFile(b *testing.B) { b.Fatalf("failed to copy data with sendfile, error: %v", err) } if sent != int64(fileSize) { - b.Fatalf("bytes sent mismatch\n\texpect: %d\n\tgot: %d", fileSize, sent) + b.Fatalf("bytes sent mismatch, got: %d, want: %d", sent, fileSize) } } func createTempFile(b *testing.B, size int) *os.File { - f, err := os.CreateTemp("", "linux-sendfile-test") + f, err := os.CreateTemp(b.TempDir(), "linux-sendfile-bench") if err != nil { b.Fatalf("failed to create temporary file: %v", err) } + b.Cleanup(func() { + f.Close() + }) data := make([]byte, size) if _, err := f.Write(data); err != nil { diff --git a/src/net/sendfile_stub.go b/src/net/sendfile_stub.go index c7a2e6a1e4..a4fdd99ffe 100644 --- a/src/net/sendfile_stub.go +++ b/src/net/sendfile_stub.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build aix || (js && wasm) || netbsd || openbsd || ios || wasip1 +//go:build aix || js || netbsd || openbsd || ios || wasip1 package net diff --git a/src/net/sendfile_test.go b/src/net/sendfile_test.go index 44a87a1d20..4cba1ed2b1 100644 --- a/src/net/sendfile_test.go +++ b/src/net/sendfile_test.go @@ -2,12 +2,11 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !wasip1 - package net import ( "bytes" + "context" "crypto/sha256" "encoding/hex" "errors" @@ -209,7 +208,7 @@ func TestSendfileSeeked(t *testing.T) { // Test that sendfile doesn't put a pipe into blocking mode. func TestSendfilePipe(t *testing.T) { switch runtime.GOOS { - case "plan9", "windows": + case "plan9", "windows", "js", "wasip1": // These systems don't support deadlines on pipes. t.Skipf("skipping on %s", runtime.GOOS) } @@ -362,3 +361,88 @@ func TestSendfileOnWriteTimeoutExceeded(t *testing.T) { t.Fatal(err) } } + +func BenchmarkSendfileZeroBytes(b *testing.B) { + var ( + wg sync.WaitGroup + ctx, cancel = context.WithCancel(context.Background()) + ) + + defer wg.Wait() + + ln := newLocalListener(b, "tcp") + defer ln.Close() + + tempFile, err := os.CreateTemp(b.TempDir(), "test.txt") + if err != nil { + b.Fatalf("failed to create temp file: %v", err) + } + defer tempFile.Close() + + fileName := tempFile.Name() + + dataSize := b.N + wg.Add(1) + go func(f *os.File) { + defer wg.Done() + + for i := 0; i < dataSize; i++ { + if _, err := f.Write([]byte{1}); err != nil { + b.Errorf("failed to write: %v", err) + return + } + if i%1000 == 0 { + f.Sync() + } + } + }(tempFile) + + b.ResetTimer() + b.ReportAllocs() + + wg.Add(1) + go func(ln Listener, fileName string) { + defer wg.Done() + + conn, err := ln.Accept() + if err != nil { + b.Errorf("failed to accept: %v", err) + return + } + defer conn.Close() + + f, err := os.OpenFile(fileName, os.O_RDONLY, 0660) + if err != nil { + b.Errorf("failed to open file: %v", err) + return + } + defer f.Close() + + for { + if ctx.Err() != nil { + return + } + + if _, err := io.Copy(conn, f); err != nil { + b.Errorf("failed to copy: %v", err) + return + } + } + }(ln, fileName) + + conn, err := Dial("tcp", ln.Addr().String()) + if err != nil { + b.Fatalf("failed to dial: %v", err) + } + defer conn.Close() + + n, err := io.CopyN(io.Discard, conn, int64(dataSize)) + if err != nil { + b.Fatalf("failed to copy: %v", err) + } + if n != int64(dataSize) { + b.Fatalf("expected %d copied bytes, but got %d", dataSize, n) + } + + cancel() +} diff --git a/src/net/sendfile_unix_alt.go b/src/net/sendfile_unix_alt.go index b86771721e..5cb65ee767 100644 --- a/src/net/sendfile_unix_alt.go +++ b/src/net/sendfile_unix_alt.go @@ -15,8 +15,8 @@ import ( // sendFile copies the contents of r to c using the sendfile // system call to minimize copies. // -// if handled == true, sendFile returns the number of bytes copied and any -// non-EOF error. +// if handled == true, sendFile returns the number (potentially zero) of bytes +// copied and any non-EOF error. // // if handled == false, sendFile performed no work. func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { @@ -65,7 +65,7 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { var werr error err = sc.Read(func(fd uintptr) bool { - written, werr = poll.SendFile(&c.pfd, int(fd), pos, remain) + written, werr, handled = poll.SendFile(&c.pfd, int(fd), pos, remain) return true }) if err == nil { @@ -78,8 +78,8 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { _, err1 := f.Seek(written, io.SeekCurrent) if err1 != nil && err == nil { - return written, err1, written > 0 + return written, err1, handled } - return written, wrapSyscallError("sendfile", err), written > 0 + return written, wrapSyscallError("sendfile", err), handled } diff --git a/src/net/server_test.go b/src/net/server_test.go index 2ff0689067..eb6b111f1f 100644 --- a/src/net/server_test.go +++ b/src/net/server_test.go @@ -2,11 +2,10 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !wasip1 - package net import ( + "fmt" "os" "testing" ) @@ -251,65 +250,80 @@ var udpServerTests = []struct { func TestUDPServer(t *testing.T) { for i, tt := range udpServerTests { - if !testableListenArgs(tt.snet, tt.saddr, tt.taddr) { - t.Logf("skipping %s test", tt.snet+" "+tt.saddr+"<-"+tt.taddr) - continue - } - - c1, err := ListenPacket(tt.snet, tt.saddr) - if err != nil { - if perr := parseDialError(err); perr != nil { - t.Error(perr) + i, tt := i, tt + t.Run(fmt.Sprint(i), func(t *testing.T) { + if !testableListenArgs(tt.snet, tt.saddr, tt.taddr) { + t.Skipf("skipping %s %s<-%s test", tt.snet, tt.saddr, tt.taddr) } - t.Fatal(err) - } + t.Logf("%s %s<-%s", tt.snet, tt.saddr, tt.taddr) - ls := (&packetListener{PacketConn: c1}).newLocalServer() - defer ls.teardown() - tpch := make(chan error, 1) - handler := func(ls *localPacketServer, c PacketConn) { packetTransponder(c, tpch) } - if err := ls.buildup(handler); err != nil { - t.Fatal(err) - } - - trch := make(chan error, 1) - _, port, err := SplitHostPort(ls.PacketConn.LocalAddr().String()) - if err != nil { - t.Fatal(err) - } - if tt.dial { - d := Dialer{Timeout: someTimeout} - c2, err := d.Dial(tt.tnet, JoinHostPort(tt.taddr, port)) + c1, err := ListenPacket(tt.snet, tt.saddr) if err != nil { if perr := parseDialError(err); perr != nil { t.Error(perr) } t.Fatal(err) } - defer c2.Close() - go transceiver(c2, []byte("UDP SERVER TEST"), trch) - } else { - c2, err := ListenPacket(tt.tnet, JoinHostPort(tt.taddr, "0")) - if err != nil { - if perr := parseDialError(err); perr != nil { - t.Error(perr) - } + + ls := (&packetListener{PacketConn: c1}).newLocalServer() + defer ls.teardown() + tpch := make(chan error, 1) + handler := func(ls *localPacketServer, c PacketConn) { packetTransponder(c, tpch) } + if err := ls.buildup(handler); err != nil { t.Fatal(err) } - defer c2.Close() - dst, err := ResolveUDPAddr(tt.tnet, JoinHostPort(tt.taddr, port)) + + trch := make(chan error, 1) + _, port, err := SplitHostPort(ls.PacketConn.LocalAddr().String()) if err != nil { t.Fatal(err) } - go packetTransceiver(c2, []byte("UDP SERVER TEST"), dst, trch) - } + if tt.dial { + d := Dialer{Timeout: someTimeout} + c2, err := d.Dial(tt.tnet, JoinHostPort(tt.taddr, port)) + if err != nil { + if perr := parseDialError(err); perr != nil { + t.Error(perr) + } + t.Fatal(err) + } + defer c2.Close() + go transceiver(c2, []byte("UDP SERVER TEST"), trch) + } else { + c2, err := ListenPacket(tt.tnet, JoinHostPort(tt.taddr, "0")) + if err != nil { + if perr := parseDialError(err); perr != nil { + t.Error(perr) + } + t.Fatal(err) + } + defer c2.Close() + dst, err := ResolveUDPAddr(tt.tnet, JoinHostPort(tt.taddr, port)) + if err != nil { + t.Fatal(err) + } + go packetTransceiver(c2, []byte("UDP SERVER TEST"), dst, trch) + } - for err := range trch { - t.Errorf("#%d: %v", i, err) - } - for err := range tpch { - t.Errorf("#%d: %v", i, err) - } + for trch != nil || tpch != nil { + select { + case err, ok := <-trch: + if !ok { + trch = nil + } + if err != nil { + t.Errorf("client: %v", err) + } + case err, ok := <-tpch: + if !ok { + tpch = nil + } + if err != nil { + t.Errorf("server: %v", err) + } + } + } + }) } } @@ -326,58 +340,73 @@ func TestUnixgramServer(t *testing.T) { } for i, tt := range unixgramServerTests { - if !testableListenArgs("unixgram", tt.saddr, "") { - t.Logf("skipping %s test", "unixgram "+tt.saddr+"<-"+tt.caddr) - continue - } - - c1, err := ListenPacket("unixgram", tt.saddr) - if err != nil { - if perr := parseDialError(err); perr != nil { - t.Error(perr) + i, tt := i, tt + t.Run(fmt.Sprint(i), func(t *testing.T) { + if !testableListenArgs("unixgram", tt.saddr, "") { + t.Skipf("skipping unixgram %s<-%s test", tt.saddr, tt.caddr) } - t.Fatal(err) - } - - ls := (&packetListener{PacketConn: c1}).newLocalServer() - defer ls.teardown() - tpch := make(chan error, 1) - handler := func(ls *localPacketServer, c PacketConn) { packetTransponder(c, tpch) } - if err := ls.buildup(handler); err != nil { - t.Fatal(err) - } + t.Logf("unixgram %s<-%s", tt.saddr, tt.caddr) - trch := make(chan error, 1) - if tt.dial { - d := Dialer{Timeout: someTimeout, LocalAddr: &UnixAddr{Net: "unixgram", Name: tt.caddr}} - c2, err := d.Dial("unixgram", ls.PacketConn.LocalAddr().String()) + c1, err := ListenPacket("unixgram", tt.saddr) if err != nil { if perr := parseDialError(err); perr != nil { t.Error(perr) } t.Fatal(err) } - defer os.Remove(c2.LocalAddr().String()) - defer c2.Close() - go transceiver(c2, []byte(c2.LocalAddr().String()), trch) - } else { - c2, err := ListenPacket("unixgram", tt.caddr) - if err != nil { - if perr := parseDialError(err); perr != nil { - t.Error(perr) - } + + ls := (&packetListener{PacketConn: c1}).newLocalServer() + defer ls.teardown() + tpch := make(chan error, 1) + handler := func(ls *localPacketServer, c PacketConn) { packetTransponder(c, tpch) } + if err := ls.buildup(handler); err != nil { t.Fatal(err) } - defer os.Remove(c2.LocalAddr().String()) - defer c2.Close() - go packetTransceiver(c2, []byte("UNIXGRAM SERVER TEST"), ls.PacketConn.LocalAddr(), trch) - } - for err := range trch { - t.Errorf("#%d: %v", i, err) - } - for err := range tpch { - t.Errorf("#%d: %v", i, err) - } + trch := make(chan error, 1) + if tt.dial { + d := Dialer{Timeout: someTimeout, LocalAddr: &UnixAddr{Net: "unixgram", Name: tt.caddr}} + c2, err := d.Dial("unixgram", ls.PacketConn.LocalAddr().String()) + if err != nil { + if perr := parseDialError(err); perr != nil { + t.Error(perr) + } + t.Fatal(err) + } + defer os.Remove(c2.LocalAddr().String()) + defer c2.Close() + go transceiver(c2, []byte(c2.LocalAddr().String()), trch) + } else { + c2, err := ListenPacket("unixgram", tt.caddr) + if err != nil { + if perr := parseDialError(err); perr != nil { + t.Error(perr) + } + t.Fatal(err) + } + defer os.Remove(c2.LocalAddr().String()) + defer c2.Close() + go packetTransceiver(c2, []byte("UNIXGRAM SERVER TEST"), ls.PacketConn.LocalAddr(), trch) + } + + for trch != nil || tpch != nil { + select { + case err, ok := <-trch: + if !ok { + trch = nil + } + if err != nil { + t.Errorf("client: %v", err) + } + case err, ok := <-tpch: + if !ok { + tpch = nil + } + if err != nil { + t.Errorf("server: %v", err) + } + } + } + }) } } diff --git a/src/net/smtp/auth.go b/src/net/smtp/auth.go index 72eb16671f..6d461acc48 100644 --- a/src/net/smtp/auth.go +++ b/src/net/smtp/auth.go @@ -42,7 +42,7 @@ type plainAuth struct { host string } -// PlainAuth returns an Auth that implements the PLAIN authentication +// PlainAuth returns an [Auth] that implements the PLAIN authentication // mechanism as defined in RFC 4616. The returned Auth uses the given // username and password to authenticate to host and act as identity. // Usually identity should be the empty string, to act as username. @@ -86,7 +86,7 @@ type cramMD5Auth struct { username, secret string } -// CRAMMD5Auth returns an Auth that implements the CRAM-MD5 authentication +// CRAMMD5Auth returns an [Auth] that implements the CRAM-MD5 authentication // mechanism as defined in RFC 2195. // The returned Auth uses the given username and secret to authenticate // to the server using the challenge-response mechanism. diff --git a/src/net/smtp/smtp.go b/src/net/smtp/smtp.go index b5a025ef2a..b7877936da 100644 --- a/src/net/smtp/smtp.go +++ b/src/net/smtp/smtp.go @@ -48,7 +48,7 @@ type Client struct { helloError error // the error from the hello } -// Dial returns a new Client connected to an SMTP server at addr. +// Dial returns a new [Client] connected to an SMTP server at addr. // The addr must include a port, as in "mail.example.com:smtp". func Dial(addr string) (*Client, error) { conn, err := net.Dial("tcp", addr) @@ -59,7 +59,7 @@ func Dial(addr string) (*Client, error) { return NewClient(conn, host) } -// NewClient returns a new Client using an existing connection and host as a +// NewClient returns a new [Client] using an existing connection and host as a // server name to be used when authenticating. func NewClient(conn net.Conn, host string) (*Client, error) { text := textproto.NewConn(conn) @@ -166,7 +166,7 @@ func (c *Client) StartTLS(config *tls.Config) error { } // TLSConnectionState returns the client's TLS connection state. -// The return values are their zero values if StartTLS did +// The return values are their zero values if [Client.StartTLS] did // not succeed. func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) { tc, ok := c.conn.(*tls.Conn) @@ -241,7 +241,7 @@ func (c *Client) Auth(a Auth) error { // If the server supports the 8BITMIME extension, Mail adds the BODY=8BITMIME // parameter. If the server supports the SMTPUTF8 extension, Mail adds the // SMTPUTF8 parameter. -// This initiates a mail transaction and is followed by one or more Rcpt calls. +// This initiates a mail transaction and is followed by one or more [Client.Rcpt] calls. func (c *Client) Mail(from string) error { if err := validateLine(from); err != nil { return err @@ -263,8 +263,8 @@ func (c *Client) Mail(from string) error { } // Rcpt issues a RCPT command to the server using the provided email address. -// A call to Rcpt must be preceded by a call to Mail and may be followed by -// a Data call or another Rcpt call. +// A call to Rcpt must be preceded by a call to [Client.Mail] and may be followed by +// a [Client.Data] call or another Rcpt call. func (c *Client) Rcpt(to string) error { if err := validateLine(to); err != nil { return err @@ -287,7 +287,7 @@ func (d *dataCloser) Close() error { // Data issues a DATA command to the server and returns a writer that // can be used to write the mail headers and body. The caller should // close the writer before calling any more methods on c. A call to -// Data must be preceded by one or more calls to Rcpt. +// Data must be preceded by one or more calls to [Client.Rcpt]. func (c *Client) Data() (io.WriteCloser, error) { _, _, err := c.cmd(354, "DATA") if err != nil { diff --git a/src/net/sock_posix.go b/src/net/sock_posix.go index b3e1806ba9..d04c26e7ef 100644 --- a/src/net/sock_posix.go +++ b/src/net/sock_posix.go @@ -89,38 +89,10 @@ func (fd *netFD) ctrlNetwork() string { return fd.net + "6" } -func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr { - switch fd.family { - case syscall.AF_INET, syscall.AF_INET6: - switch fd.sotype { - case syscall.SOCK_STREAM: - return sockaddrToTCP - case syscall.SOCK_DGRAM: - return sockaddrToUDP - case syscall.SOCK_RAW: - return sockaddrToIP - } - case syscall.AF_UNIX: - switch fd.sotype { - case syscall.SOCK_STREAM: - return sockaddrToUnix - case syscall.SOCK_DGRAM: - return sockaddrToUnixgram - case syscall.SOCK_SEQPACKET: - return sockaddrToUnixpacket - } - } - return func(syscall.Sockaddr) Addr { return nil } -} - func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) error { var c *rawConn - var err error if ctrlCtxFn != nil { - c, err = newRawConn(fd) - if err != nil { - return err - } + c = newRawConn(fd) var ctrlAddr string if raddr != nil { ctrlAddr = raddr.String() @@ -133,6 +105,7 @@ func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlCtxFn func } var lsa syscall.Sockaddr + var err error if laddr != nil { if lsa, err = laddr.sockaddr(fd.family); err != nil { return err @@ -185,10 +158,7 @@ func (fd *netFD) listenStream(ctx context.Context, laddr sockaddr, backlog int, } if ctrlCtxFn != nil { - c, err := newRawConn(fd) - if err != nil { - return err - } + c := newRawConn(fd) if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), laddr.String(), c); err != nil { return err } @@ -239,10 +209,7 @@ func (fd *netFD) listenDatagram(ctx context.Context, laddr sockaddr, ctrlCtxFn f } if ctrlCtxFn != nil { - c, err := newRawConn(fd) - if err != nil { - return err - } + c := newRawConn(fd) if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), laddr.String(), c); err != nil { return err } diff --git a/src/net/sock_stub.go b/src/net/sock_stub.go index e163755568..fd86fa92dc 100644 --- a/src/net/sock_stub.go +++ b/src/net/sock_stub.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build aix || (js && wasm) || solaris || wasip1 +//go:build aix || js || solaris || wasip1 package net diff --git a/src/net/sock_windows.go b/src/net/sock_windows.go index fa11c7af2e..a519909bb0 100644 --- a/src/net/sock_windows.go +++ b/src/net/sock_windows.go @@ -11,29 +11,15 @@ import ( ) func maxListenerBacklog() int { - // TODO: Implement this - // NOTE: Never return a number bigger than 1<<16 - 1. See issue 5030. + // When the socket backlog is SOMAXCONN, Windows will set the backlog to + // "a reasonable maximum value". + // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-listen return syscall.SOMAXCONN } func sysSocket(family, sotype, proto int) (syscall.Handle, error) { s, err := wsaSocketFunc(int32(family), int32(sotype), int32(proto), nil, 0, windows.WSA_FLAG_OVERLAPPED|windows.WSA_FLAG_NO_HANDLE_INHERIT) - if err == nil { - return s, nil - } - // WSA_FLAG_NO_HANDLE_INHERIT flag is not supported on some - // old versions of Windows, see - // https://msdn.microsoft.com/en-us/library/windows/desktop/ms742212(v=vs.85).aspx - // for details. Just use syscall.Socket, if windows.WSASocket failed. - - // See ../syscall/exec_unix.go for description of ForkLock. - syscall.ForkLock.RLock() - s, err = socketFunc(family, sotype, proto) - if err == nil { - syscall.CloseOnExec(s) - } - syscall.ForkLock.RUnlock() if err != nil { return syscall.InvalidHandle, os.NewSyscallError("socket", err) } diff --git a/src/net/sockaddr_posix.go b/src/net/sockaddr_posix.go index e44fc76f4b..c5604fca35 100644 --- a/src/net/sockaddr_posix.go +++ b/src/net/sockaddr_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build unix || (js && wasm) || wasip1 || windows +//go:build unix || js || wasip1 || windows package net @@ -32,3 +32,27 @@ type sockaddr interface { // toLocal maps the zero address to a local system address (127.0.0.1 or ::1) toLocal(net string) sockaddr } + +func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr { + switch fd.family { + case syscall.AF_INET, syscall.AF_INET6: + switch fd.sotype { + case syscall.SOCK_STREAM: + return sockaddrToTCP + case syscall.SOCK_DGRAM: + return sockaddrToUDP + case syscall.SOCK_RAW: + return sockaddrToIP + } + case syscall.AF_UNIX: + switch fd.sotype { + case syscall.SOCK_STREAM: + return sockaddrToUnix + case syscall.SOCK_DGRAM: + return sockaddrToUnixgram + case syscall.SOCK_SEQPACKET: + return sockaddrToUnixpacket + } + } + return func(syscall.Sockaddr) Addr { return nil } +} diff --git a/src/net/sockopt_stub.go b/src/net/sockopt_fake.go index 186d8912cb..9d9f7ea951 100644 --- a/src/net/sockopt_stub.go +++ b/src/net/sockopt_fake.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build (js && wasm) || wasip1 +//go:build js || wasip1 package net @@ -21,10 +21,16 @@ func setDefaultMulticastSockopts(s int) error { } func setReadBuffer(fd *netFD, bytes int) error { + if fd.fakeNetFD != nil { + return fd.fakeNetFD.setReadBuffer(bytes) + } return syscall.ENOPROTOOPT } func setWriteBuffer(fd *netFD, bytes int) error { + if fd.fakeNetFD != nil { + return fd.fakeNetFD.setWriteBuffer(bytes) + } return syscall.ENOPROTOOPT } @@ -33,5 +39,8 @@ func setKeepAlive(fd *netFD, keepalive bool) error { } func setLinger(fd *netFD, sec int) error { + if fd.fakeNetFD != nil { + return fd.fakeNetFD.setLinger(sec) + } return syscall.ENOPROTOOPT } diff --git a/src/net/sockopt_posix.go b/src/net/sockopt_posix.go index 32e8fcd505..a380c7719b 100644 --- a/src/net/sockopt_posix.go +++ b/src/net/sockopt_posix.go @@ -20,35 +20,6 @@ func boolint(b bool) int { return 0 } -func ipv4AddrToInterface(ip IP) (*Interface, error) { - ift, err := Interfaces() - if err != nil { - return nil, err - } - for _, ifi := range ift { - ifat, err := ifi.Addrs() - if err != nil { - return nil, err - } - for _, ifa := range ifat { - switch v := ifa.(type) { - case *IPAddr: - if ip.Equal(v.IP) { - return &ifi, nil - } - case *IPNet: - if ip.Equal(v.IP) { - return &ifi, nil - } - } - } - } - if ip.Equal(IPv4zero) { - return nil, nil - } - return nil, errNoSuchInterface -} - func interfaceToIPv4Addr(ifi *Interface) (IP, error) { if ifi == nil { return IPv4zero, nil diff --git a/src/net/sockoptip_stub.go b/src/net/sockoptip_stub.go index a37c31223d..23891a865f 100644 --- a/src/net/sockoptip_stub.go +++ b/src/net/sockoptip_stub.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build (js && wasm) || wasip1 +//go:build js || wasip1 package net diff --git a/src/net/sockoptip_windows.go b/src/net/sockoptip_windows.go index 62676039a3..9dfa37c51e 100644 --- a/src/net/sockoptip_windows.go +++ b/src/net/sockoptip_windows.go @@ -8,7 +8,6 @@ import ( "os" "runtime" "syscall" - "unsafe" ) func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error { @@ -18,7 +17,7 @@ func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error { } var a [4]byte copy(a[:], ip.To4()) - err = fd.pfd.Setsockopt(syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, (*byte)(unsafe.Pointer(&a[0])), 4) + err = fd.pfd.SetsockoptInet4Addr(syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, a) runtime.KeepAlive(fd) return wrapSyscallError("setsockopt", err) } diff --git a/src/net/splice_linux.go b/src/net/splice_linux.go index ab2ab70b28..bdafcb59ab 100644 --- a/src/net/splice_linux.go +++ b/src/net/splice_linux.go @@ -9,12 +9,12 @@ import ( "io" ) -// splice transfers data from r to c using the splice system call to minimize -// copies from and to userspace. c must be a TCP connection. Currently, splice -// is only enabled if r is a TCP or a stream-oriented Unix connection. +// spliceFrom transfers data from r to c using the splice system call to minimize +// copies from and to userspace. c must be a TCP connection. +// Currently, spliceFrom is only enabled if r is a TCP or a stream-oriented Unix connection. // -// If splice returns handled == false, it has performed no work. -func splice(c *netFD, r io.Reader) (written int64, err error, handled bool) { +// If spliceFrom returns handled == false, it has performed no work. +func spliceFrom(c *netFD, r io.Reader) (written int64, err error, handled bool) { var remain int64 = 1<<63 - 1 // by default, copy until EOF lr, ok := r.(*io.LimitedReader) if ok { @@ -25,14 +25,17 @@ func splice(c *netFD, r io.Reader) (written int64, err error, handled bool) { } var s *netFD - if tc, ok := r.(*TCPConn); ok { - s = tc.fd - } else if uc, ok := r.(*UnixConn); ok { - if uc.fd.net != "unix" { + switch v := r.(type) { + case *TCPConn: + s = v.fd + case tcpConnWithoutWriteTo: + s = v.fd + case *UnixConn: + if v.fd.net != "unix" { return 0, nil, false } - s = uc.fd - } else { + s = v.fd + default: return 0, nil, false } @@ -42,3 +45,18 @@ func splice(c *netFD, r io.Reader) (written int64, err error, handled bool) { } return written, wrapSyscallError(sc, err), handled } + +// spliceTo transfers data from c to w using the splice system call to minimize +// copies from and to userspace. c must be a TCP connection. +// Currently, spliceTo is only enabled if w is a stream-oriented Unix connection. +// +// If spliceTo returns handled == false, it has performed no work. +func spliceTo(w io.Writer, c *netFD) (written int64, err error, handled bool) { + uc, ok := w.(*UnixConn) + if !ok || uc.fd.net != "unix" { + return + } + + written, handled, sc, err := poll.Splice(&uc.fd.pfd, &c.pfd, 1<<63-1) + return written, wrapSyscallError(sc, err), handled +} diff --git a/src/net/splice_stub.go b/src/net/splice_stub.go index 3cdadb11c5..239227ff88 100644 --- a/src/net/splice_stub.go +++ b/src/net/splice_stub.go @@ -8,6 +8,10 @@ package net import "io" -func splice(c *netFD, r io.Reader) (int64, error, bool) { +func spliceFrom(_ *netFD, _ io.Reader) (int64, error, bool) { + return 0, nil, false +} + +func spliceTo(_ io.Writer, _ *netFD) (int64, error, bool) { return 0, nil, false } diff --git a/src/net/splice_test.go b/src/net/splice_test.go index 75a8f274ff..227ddebff4 100644 --- a/src/net/splice_test.go +++ b/src/net/splice_test.go @@ -23,6 +23,7 @@ func TestSplice(t *testing.T) { t.Skip("skipping unix-to-tcp tests") } t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") }) + t.Run("tcp-to-unix", func(t *testing.T) { testSplice(t, "tcp", "unix") }) t.Run("tcp-to-file", func(t *testing.T) { testSpliceToFile(t, "tcp", "file") }) t.Run("unix-to-file", func(t *testing.T) { testSpliceToFile(t, "unix", "file") }) t.Run("no-unixpacket", testSpliceNoUnixpacket) @@ -159,6 +160,13 @@ func (tc spliceTestCase) testFile(t *testing.T) { } func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) { + // UnixConn doesn't implement io.ReaderFrom, which will fail + // the following test in asserting a UnixConn to be an io.ReaderFrom, + // so skip this test. + if upNet == "unix" || downNet == "unix" { + t.Skip("skipping test on unix socket") + } + clientUp, serverUp := spliceTestSocketPair(t, upNet) defer clientUp.Close() clientDown, serverDown := spliceTestSocketPair(t, downNet) @@ -166,16 +174,16 @@ func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) { serverUp.Close() - // We'd like to call net.splice here and check the handled return + // We'd like to call net.spliceFrom here and check the handled return // value, but we disable splice on old Linux kernels. // - // In that case, poll.Splice and net.splice return a non-nil error + // In that case, poll.Splice and net.spliceFrom return a non-nil error // and handled == false. We'd ideally like to see handled == true // because the source reader is at EOF, but if we're running on an old - // kernel, and splice is disabled, we won't see EOF from net.splice, + // kernel, and splice is disabled, we won't see EOF from net.spliceFrom, // because we won't touch the reader at all. // - // Trying to untangle the errors from net.splice and match them + // Trying to untangle the errors from net.spliceFrom and match them // against the errors created by the poll package would be brittle, // so this is a higher level test. // @@ -268,7 +276,7 @@ func testSpliceNoUnixpacket(t *testing.T) { // // What we want is err == nil and handled == false, i.e. we never // called poll.Splice, because we know the unix socket's network. - _, err, handled := splice(serverDown.(*TCPConn).fd, serverUp) + _, err, handled := spliceFrom(serverDown.(*TCPConn).fd, serverUp) if err != nil || handled != false { t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled) } @@ -289,7 +297,7 @@ func testSpliceNoUnixgram(t *testing.T) { defer clientDown.Close() defer serverDown.Close() // Analogous to testSpliceNoUnixpacket. - _, err, handled := splice(serverDown.(*TCPConn).fd, up) + _, err, handled := spliceFrom(serverDown.(*TCPConn).fd, up) if err != nil || handled != false { t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled) } @@ -300,6 +308,7 @@ func BenchmarkSplice(b *testing.B) { b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") }) b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") }) + b.Run("tcp-to-unix", func(b *testing.B) { benchSplice(b, "tcp", "unix") }) } func benchSplice(b *testing.B, upNet, downNet string) { diff --git a/src/net/tcpsock.go b/src/net/tcpsock.go index 358e48723b..590516bff1 100644 --- a/src/net/tcpsock.go +++ b/src/net/tcpsock.go @@ -24,7 +24,7 @@ type TCPAddr struct { Zone string // IPv6 scoped addressing zone } -// AddrPort returns the TCPAddr a as a netip.AddrPort. +// AddrPort returns the [TCPAddr] a as a [netip.AddrPort]. // // If a.Port does not fit in a uint16, it's silently truncated. // @@ -79,7 +79,7 @@ func (a *TCPAddr) opAddr() Addr { // recommended, because it will return at most one of the host name's // IP addresses. // -// See func Dial for a description of the network and address +// See func [Dial] for a description of the network and address // parameters. func ResolveTCPAddr(network, address string) (*TCPAddr, error) { switch network { @@ -96,7 +96,7 @@ func ResolveTCPAddr(network, address string) (*TCPAddr, error) { return addrs.forResolve(network, address).(*TCPAddr), nil } -// TCPAddrFromAddrPort returns addr as a TCPAddr. If addr.IsValid() is false, +// TCPAddrFromAddrPort returns addr as a [TCPAddr]. If addr.IsValid() is false, // then the returned TCPAddr will contain a nil IP field, indicating an // address family-agnostic unspecified address. func TCPAddrFromAddrPort(addr netip.AddrPort) *TCPAddr { @@ -107,22 +107,22 @@ func TCPAddrFromAddrPort(addr netip.AddrPort) *TCPAddr { } } -// TCPConn is an implementation of the Conn interface for TCP network +// TCPConn is an implementation of the [Conn] interface for TCP network // connections. type TCPConn struct { conn } // SyscallConn returns a raw network connection. -// This implements the syscall.Conn interface. +// This implements the [syscall.Conn] interface. func (c *TCPConn) SyscallConn() (syscall.RawConn, error) { if !c.ok() { return nil, syscall.EINVAL } - return newRawConn(c.fd) + return newRawConn(c.fd), nil } -// ReadFrom implements the io.ReaderFrom ReadFrom method. +// ReadFrom implements the [io.ReaderFrom] ReadFrom method. func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) { if !c.ok() { return 0, syscall.EINVAL @@ -134,6 +134,18 @@ func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) { return n, err } +// WriteTo implements the io.WriterTo WriteTo method. +func (c *TCPConn) WriteTo(w io.Writer) (int64, error) { + if !c.ok() { + return 0, syscall.EINVAL + } + n, err := c.writeTo(w) + if err != nil && err != io.EOF { + err = &OpError{Op: "writeto", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err} + } + return n, err +} + // CloseRead shuts down the reading side of the TCP connection. // Most callers should just use Close. func (c *TCPConn) CloseRead() error { @@ -250,7 +262,7 @@ func newTCPConn(fd *netFD, keepAlive time.Duration, keepAliveHook func(time.Dura return &TCPConn{conn{fd}} } -// DialTCP acts like Dial for TCP networks. +// DialTCP acts like [Dial] for TCP networks. // // The network must be a TCP network name; see func Dial for details. // @@ -275,14 +287,14 @@ func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) { } // TCPListener is a TCP network listener. Clients should typically -// use variables of type Listener instead of assuming TCP. +// use variables of type [Listener] instead of assuming TCP. type TCPListener struct { fd *netFD lc ListenConfig } // SyscallConn returns a raw network connection. -// This implements the syscall.Conn interface. +// This implements the [syscall.Conn] interface. // // The returned RawConn only supports calling Control. Read and // Write return an error. @@ -290,7 +302,7 @@ func (l *TCPListener) SyscallConn() (syscall.RawConn, error) { if !l.ok() { return nil, syscall.EINVAL } - return newRawListener(l.fd) + return newRawListener(l.fd), nil } // AcceptTCP accepts the next incoming call and returns the new @@ -306,8 +318,8 @@ func (l *TCPListener) AcceptTCP() (*TCPConn, error) { return c, nil } -// Accept implements the Accept method in the Listener interface; it -// waits for the next call and returns a generic Conn. +// Accept implements the Accept method in the [Listener] interface; it +// waits for the next call and returns a generic [Conn]. func (l *TCPListener) Accept() (Conn, error) { if !l.ok() { return nil, syscall.EINVAL @@ -331,7 +343,7 @@ func (l *TCPListener) Close() error { return nil } -// Addr returns the listener's network address, a *TCPAddr. +// Addr returns the listener's network address, a [*TCPAddr]. // The Addr returned is shared by all invocations of Addr, so // do not modify it. func (l *TCPListener) Addr() Addr { return l.fd.laddr } @@ -342,13 +354,10 @@ func (l *TCPListener) SetDeadline(t time.Time) error { if !l.ok() { return syscall.EINVAL } - if err := l.fd.pfd.SetDeadline(t); err != nil { - return &OpError{Op: "set", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err} - } - return nil + return l.fd.SetDeadline(t) } -// File returns a copy of the underlying os.File. +// File returns a copy of the underlying [os.File]. // It is the caller's responsibility to close f when finished. // Closing l does not affect f, and closing f does not affect l. // @@ -366,7 +375,7 @@ func (l *TCPListener) File() (f *os.File, err error) { return } -// ListenTCP acts like Listen for TCP networks. +// ListenTCP acts like [Listen] for TCP networks. // // The network must be a TCP network name; see func Dial for details. // diff --git a/src/net/tcpsock_plan9.go b/src/net/tcpsock_plan9.go index d55948f69e..463dedcf44 100644 --- a/src/net/tcpsock_plan9.go +++ b/src/net/tcpsock_plan9.go @@ -14,6 +14,10 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) { return genericReadFrom(c, r) } +func (c *TCPConn) writeTo(w io.Writer) (int64, error) { + return genericWriteTo(c, w) +} + func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) { if h := sd.testHookDialTCP; h != nil { return h(ctx, sd.network, laddr, raddr) diff --git a/src/net/tcpsock_posix.go b/src/net/tcpsock_posix.go index e6f425b1cd..01b5ec9ed0 100644 --- a/src/net/tcpsock_posix.go +++ b/src/net/tcpsock_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build unix || (js && wasm) || wasip1 || windows +//go:build unix || js || wasip1 || windows package net @@ -45,7 +45,7 @@ func (a *TCPAddr) toLocal(net string) sockaddr { } func (c *TCPConn) readFrom(r io.Reader) (int64, error) { - if n, err, handled := splice(c.fd, r); handled { + if n, err, handled := spliceFrom(c.fd, r); handled { return n, err } if n, err, handled := sendFile(c.fd, r); handled { @@ -54,6 +54,13 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) { return genericReadFrom(c, r) } +func (c *TCPConn) writeTo(w io.Writer) (int64, error) { + if n, err, handled := spliceTo(w, c.fd); handled { + return n, err + } + return genericWriteTo(c, w) +} + func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) { if h := sd.testHookDialTCP; h != nil { return h(ctx, sd.network, laddr, raddr) diff --git a/src/net/tcpsock_test.go b/src/net/tcpsock_test.go index f720a22519..b37e936ff8 100644 --- a/src/net/tcpsock_test.go +++ b/src/net/tcpsock_test.go @@ -2,11 +2,11 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !wasip1 - package net import ( + "context" + "errors" "fmt" "internal/testenv" "io" @@ -670,6 +670,11 @@ func TestTCPBig(t *testing.T) { } func TestCopyPipeIntoTCP(t *testing.T) { + switch runtime.GOOS { + case "js", "wasip1": + t.Skipf("skipping: os.Pipe not supported on %s", runtime.GOOS) + } + ln := newLocalListener(t, "tcp") defer ln.Close() @@ -783,3 +788,48 @@ func TestDialTCPDefaultKeepAlive(t *testing.T) { t.Errorf("got keepalive %v; want %v", got, defaultTCPKeepAlive) } } + +func TestTCPListenAfterClose(t *testing.T) { + // Regression test for https://go.dev/issue/50216: + // after calling Close on a Listener, the fake net implementation would + // erroneously Accept a connection dialed before the call to Close. + + ln := newLocalListener(t, "tcp") + defer ln.Close() + + var wg sync.WaitGroup + ctx, cancel := context.WithCancel(context.Background()) + + d := &Dialer{} + for n := 2; n > 0; n-- { + wg.Add(1) + go func() { + defer wg.Done() + + c, err := d.DialContext(ctx, ln.Addr().Network(), ln.Addr().String()) + if err == nil { + <-ctx.Done() + c.Close() + } + }() + } + + c, err := ln.Accept() + if err == nil { + c.Close() + } else { + t.Error(err) + } + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + ln.Close() + + c, err = ln.Accept() + if !errors.Is(err, ErrClosed) { + if err == nil { + c.Close() + } + t.Errorf("after l.Close(), l.Accept() = _, %v\nwant %v", err, ErrClosed) + } +} diff --git a/src/net/tcpsock_unix_test.go b/src/net/tcpsock_unix_test.go index 35fd937e07..df810a21d8 100644 --- a/src/net/tcpsock_unix_test.go +++ b/src/net/tcpsock_unix_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !plan9 && !wasip1 && !windows +//go:build !plan9 && !windows package net diff --git a/src/net/tcpsockopt_stub.go b/src/net/tcpsockopt_stub.go index f778143d3b..cef07cd648 100644 --- a/src/net/tcpsockopt_stub.go +++ b/src/net/tcpsockopt_stub.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build (js && wasm) || wasip1 +//go:build js || wasip1 package net diff --git a/src/net/textproto/header.go b/src/net/textproto/header.go index a58df7aebc..689a6827b9 100644 --- a/src/net/textproto/header.go +++ b/src/net/textproto/header.go @@ -23,7 +23,7 @@ func (h MIMEHeader) Set(key, value string) { } // Get gets the first value associated with the given key. -// It is case insensitive; CanonicalMIMEHeaderKey is used +// It is case insensitive; [CanonicalMIMEHeaderKey] is used // to canonicalize the provided key. // If there are no values associated with the key, Get returns "". // To use non-canonical keys, access the map directly. @@ -39,7 +39,7 @@ func (h MIMEHeader) Get(key string) string { } // Values returns all values associated with the given key. -// It is case insensitive; CanonicalMIMEHeaderKey is +// It is case insensitive; [CanonicalMIMEHeaderKey] is // used to canonicalize the provided key. To use non-canonical // keys, access the map directly. // The returned slice is not a copy. diff --git a/src/net/textproto/reader.go b/src/net/textproto/reader.go index fc2590b1cd..793021101b 100644 --- a/src/net/textproto/reader.go +++ b/src/net/textproto/reader.go @@ -16,6 +16,10 @@ import ( "sync" ) +// TODO: This should be a distinguishable error (ErrMessageTooLarge) +// to allow mime/multipart to detect it. +var errMessageTooLarge = errors.New("message too large") + // A Reader implements convenience methods for reading requests // or responses from a text protocol network connection. type Reader struct { @@ -24,10 +28,10 @@ type Reader struct { buf []byte // a re-usable buffer for readContinuedLineSlice } -// NewReader returns a new Reader reading from r. +// NewReader returns a new [Reader] reading from r. // -// To avoid denial of service attacks, the provided bufio.Reader -// should be reading from an io.LimitReader or similar Reader to bound +// To avoid denial of service attacks, the provided [bufio.Reader] +// should be reading from an [io.LimitReader] or similar Reader to bound // the size of responses. func NewReader(r *bufio.Reader) *Reader { return &Reader{R: r} @@ -36,20 +40,23 @@ func NewReader(r *bufio.Reader) *Reader { // ReadLine reads a single line from r, // eliding the final \n or \r\n from the returned string. func (r *Reader) ReadLine() (string, error) { - line, err := r.readLineSlice() + line, err := r.readLineSlice(-1) return string(line), err } -// ReadLineBytes is like ReadLine but returns a []byte instead of a string. +// ReadLineBytes is like [Reader.ReadLine] but returns a []byte instead of a string. func (r *Reader) ReadLineBytes() ([]byte, error) { - line, err := r.readLineSlice() + line, err := r.readLineSlice(-1) if line != nil { line = bytes.Clone(line) } return line, err } -func (r *Reader) readLineSlice() ([]byte, error) { +// readLineSlice reads a single line from r, +// up to lim bytes long (or unlimited if lim is less than 0), +// eliding the final \r or \r\n from the returned string. +func (r *Reader) readLineSlice(lim int64) ([]byte, error) { r.closeDot() var line []byte for { @@ -57,6 +64,9 @@ func (r *Reader) readLineSlice() ([]byte, error) { if err != nil { return nil, err } + if lim >= 0 && int64(len(line))+int64(len(l)) > lim { + return nil, errMessageTooLarge + } // Avoid the copy if the first call produced a full line. if line == nil && !more { return l, nil @@ -88,7 +98,7 @@ func (r *Reader) readLineSlice() ([]byte, error) { // // Empty lines are never continued. func (r *Reader) ReadContinuedLine() (string, error) { - line, err := r.readContinuedLineSlice(noValidation) + line, err := r.readContinuedLineSlice(-1, noValidation) return string(line), err } @@ -106,10 +116,10 @@ func trim(s []byte) []byte { return s[i:n] } -// ReadContinuedLineBytes is like ReadContinuedLine but +// ReadContinuedLineBytes is like [Reader.ReadContinuedLine] but // returns a []byte instead of a string. func (r *Reader) ReadContinuedLineBytes() ([]byte, error) { - line, err := r.readContinuedLineSlice(noValidation) + line, err := r.readContinuedLineSlice(-1, noValidation) if line != nil { line = bytes.Clone(line) } @@ -120,13 +130,14 @@ func (r *Reader) ReadContinuedLineBytes() ([]byte, error) { // returning a byte slice with all lines. The validateFirstLine function // is run on the first read line, and if it returns an error then this // error is returned from readContinuedLineSlice. -func (r *Reader) readContinuedLineSlice(validateFirstLine func([]byte) error) ([]byte, error) { +// It reads up to lim bytes of data (or unlimited if lim is less than 0). +func (r *Reader) readContinuedLineSlice(lim int64, validateFirstLine func([]byte) error) ([]byte, error) { if validateFirstLine == nil { return nil, fmt.Errorf("missing validateFirstLine func") } // Read the first line. - line, err := r.readLineSlice() + line, err := r.readLineSlice(lim) if err != nil { return nil, err } @@ -154,13 +165,21 @@ func (r *Reader) readContinuedLineSlice(validateFirstLine func([]byte) error) ([ // copy the slice into buf. r.buf = append(r.buf[:0], trim(line)...) + if lim < 0 { + lim = math.MaxInt64 + } + lim -= int64(len(r.buf)) + // Read continuation lines. for r.skipSpace() > 0 { - line, err := r.readLineSlice() + r.buf = append(r.buf, ' ') + if int64(len(r.buf)) >= lim { + return nil, errMessageTooLarge + } + line, err := r.readLineSlice(lim - int64(len(r.buf))) if err != nil { break } - r.buf = append(r.buf, ' ') r.buf = append(r.buf, trim(line)...) } return r.buf, nil @@ -289,7 +308,7 @@ func (r *Reader) ReadResponse(expectCode int) (code int, message string, err err return } -// DotReader returns a new Reader that satisfies Reads using the +// DotReader returns a new [Reader] that satisfies Reads using the // decoded text of a dot-encoded block read from r. // The returned Reader is only valid until the next call // to a method on r. @@ -303,7 +322,7 @@ func (r *Reader) ReadResponse(expectCode int) (code int, message string, err err // // The decoded form returned by the Reader's Read method // rewrites the "\r\n" line endings into the simpler "\n", -// removes leading dot escapes if present, and stops with error io.EOF +// removes leading dot escapes if present, and stops with error [io.EOF] // after consuming (and discarding) the end-of-sequence line. func (r *Reader) DotReader() io.Reader { r.closeDot() @@ -420,7 +439,7 @@ func (r *Reader) closeDot() { // ReadDotBytes reads a dot-encoding and returns the decoded data. // -// See the documentation for the DotReader method for details about dot-encoding. +// See the documentation for the [Reader.DotReader] method for details about dot-encoding. func (r *Reader) ReadDotBytes() ([]byte, error) { return io.ReadAll(r.DotReader()) } @@ -428,7 +447,7 @@ func (r *Reader) ReadDotBytes() ([]byte, error) { // ReadDotLines reads a dot-encoding and returns a slice // containing the decoded lines, with the final \r\n or \n elided from each. // -// See the documentation for the DotReader method for details about dot-encoding. +// See the documentation for the [Reader.DotReader] method for details about dot-encoding. func (r *Reader) ReadDotLines() ([]string, error) { // We could use ReadDotBytes and then Split it, // but reading a line at a time avoids needing a @@ -462,7 +481,7 @@ var colon = []byte(":") // ReadMIMEHeader reads a MIME-style header from r. // The header is a sequence of possibly continued Key: Value lines // ending in a blank line. -// The returned map m maps CanonicalMIMEHeaderKey(key) to a +// The returned map m maps [CanonicalMIMEHeaderKey](key) to a // sequence of values in the same order encountered in the input. // // For example, consider this input: @@ -507,7 +526,8 @@ func readMIMEHeader(r *Reader, maxMemory, maxHeaders int64) (MIMEHeader, error) // The first line cannot start with a leading space. if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') { - line, err := r.readLineSlice() + const errorLimit = 80 // arbitrary limit on how much of the line we'll quote + line, err := r.readLineSlice(errorLimit) if err != nil { return m, err } @@ -515,7 +535,7 @@ func readMIMEHeader(r *Reader, maxMemory, maxHeaders int64) (MIMEHeader, error) } for { - kv, err := r.readContinuedLineSlice(mustHaveFieldNameColon) + kv, err := r.readContinuedLineSlice(maxMemory, mustHaveFieldNameColon) if len(kv) == 0 { return m, err } @@ -544,7 +564,7 @@ func readMIMEHeader(r *Reader, maxMemory, maxHeaders int64) (MIMEHeader, error) maxHeaders-- if maxHeaders < 0 { - return nil, errors.New("message too large") + return nil, errMessageTooLarge } // Skip initial spaces in value. @@ -557,9 +577,7 @@ func readMIMEHeader(r *Reader, maxMemory, maxHeaders int64) (MIMEHeader, error) } maxMemory -= int64(len(value)) if maxMemory < 0 { - // TODO: This should be a distinguishable error (ErrMessageTooLarge) - // to allow mime/multipart to detect it. - return m, errors.New("message too large") + return m, errMessageTooLarge } if vv == nil && len(strs) > 0 { // More than likely this will be a single-element key. diff --git a/src/net/textproto/reader_test.go b/src/net/textproto/reader_test.go index 696ae406f3..26ff617470 100644 --- a/src/net/textproto/reader_test.go +++ b/src/net/textproto/reader_test.go @@ -36,6 +36,18 @@ func TestReadLine(t *testing.T) { } } +func TestReadLineLongLine(t *testing.T) { + line := strings.Repeat("12345", 10000) + r := reader(line + "\r\n") + s, err := r.ReadLine() + if err != nil { + t.Fatalf("Line 1: %v", err) + } + if s != line { + t.Fatalf("%v-byte line does not match expected %v-byte line", len(s), len(line)) + } +} + func TestReadContinuedLine(t *testing.T) { r := reader("line1\nline\n 2\nline3\n") s, err := r.ReadContinuedLine() diff --git a/src/net/textproto/textproto.go b/src/net/textproto/textproto.go index 70038d5888..4ae3ecff74 100644 --- a/src/net/textproto/textproto.go +++ b/src/net/textproto/textproto.go @@ -7,20 +7,20 @@ // // The package provides: // -// Error, which represents a numeric error response from +// [Error], which represents a numeric error response from // a server. // -// Pipeline, to manage pipelined requests and responses +// [Pipeline], to manage pipelined requests and responses // in a client. // -// Reader, to read numeric response code lines, +// [Reader], to read numeric response code lines, // key: value headers, lines wrapped with leading spaces // on continuation lines, and whole text blocks ending // with a dot on a line by itself. // -// Writer, to write dot-encoded text blocks. +// [Writer], to write dot-encoded text blocks. // -// Conn, a convenient packaging of Reader, Writer, and Pipeline for use +// [Conn], a convenient packaging of [Reader], [Writer], and [Pipeline] for use // with a single network connection. package textproto @@ -50,8 +50,8 @@ func (p ProtocolError) Error() string { } // A Conn represents a textual network protocol connection. -// It consists of a Reader and Writer to manage I/O -// and a Pipeline to sequence concurrent requests on the connection. +// It consists of a [Reader] and [Writer] to manage I/O +// and a [Pipeline] to sequence concurrent requests on the connection. // These embedded types carry methods with them; // see the documentation of those types for details. type Conn struct { @@ -61,7 +61,7 @@ type Conn struct { conn io.ReadWriteCloser } -// NewConn returns a new Conn using conn for I/O. +// NewConn returns a new [Conn] using conn for I/O. func NewConn(conn io.ReadWriteCloser) *Conn { return &Conn{ Reader: Reader{R: bufio.NewReader(conn)}, @@ -75,8 +75,8 @@ func (c *Conn) Close() error { return c.conn.Close() } -// Dial connects to the given address on the given network using net.Dial -// and then returns a new Conn for the connection. +// Dial connects to the given address on the given network using [net.Dial] +// and then returns a new [Conn] for the connection. func Dial(network, addr string) (*Conn, error) { c, err := net.Dial(network, addr) if err != nil { diff --git a/src/net/textproto/writer.go b/src/net/textproto/writer.go index 2ece3f511b..662515fb2c 100644 --- a/src/net/textproto/writer.go +++ b/src/net/textproto/writer.go @@ -17,7 +17,7 @@ type Writer struct { dot *dotWriter } -// NewWriter returns a new Writer writing to w. +// NewWriter returns a new [Writer] writing to w. func NewWriter(w *bufio.Writer) *Writer { return &Writer{W: w} } @@ -39,7 +39,7 @@ func (w *Writer) PrintfLine(format string, args ...any) error { // when the DotWriter is closed. The caller should close the // DotWriter before the next call to a method on w. // -// See the documentation for Reader's DotReader method for details about dot-encoding. +// See the documentation for the [Reader.DotReader] method for details about dot-encoding. func (w *Writer) DotWriter() io.WriteCloser { w.closeDot() w.dot = &dotWriter{w: w} diff --git a/src/net/timeout_test.go b/src/net/timeout_test.go index c0bce57b94..ca86f31ef2 100644 --- a/src/net/timeout_test.go +++ b/src/net/timeout_test.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !wasip1 - package net import ( @@ -11,7 +9,6 @@ import ( "fmt" "internal/testenv" "io" - "net/internal/socktest" "os" "runtime" "sync" @@ -19,65 +16,121 @@ import ( "time" ) -var dialTimeoutTests = []struct { - timeout time.Duration - delta time.Duration // for deadline +func init() { + // Install a hook to ensure that a 1ns timeout will always + // be exceeded by the time Dial gets to the relevant system call. + // + // Without this, systems with a very large timer granularity — such as + // Windows — may be able to accept connections without measurably exceeding + // even an implausibly short deadline. + testHookStepTime = func() { + now := time.Now() + for time.Since(now) == 0 { + time.Sleep(1 * time.Nanosecond) + } + } +} - guard time.Duration +var dialTimeoutTests = []struct { + initialTimeout time.Duration + initialDelta time.Duration // for deadline }{ // Tests that dial timeouts, deadlines in the past work. - {-5 * time.Second, 0, -5 * time.Second}, - {0, -5 * time.Second, -5 * time.Second}, - {-5 * time.Second, 5 * time.Second, -5 * time.Second}, // timeout over deadline - {-1 << 63, 0, time.Second}, - {0, -1 << 63, time.Second}, - - {50 * time.Millisecond, 0, 100 * time.Millisecond}, - {0, 50 * time.Millisecond, 100 * time.Millisecond}, - {50 * time.Millisecond, 5 * time.Second, 100 * time.Millisecond}, // timeout over deadline + {-5 * time.Second, 0}, + {0, -5 * time.Second}, + {-5 * time.Second, 5 * time.Second}, // timeout over deadline + {-1 << 63, 0}, + {0, -1 << 63}, + + {1 * time.Millisecond, 0}, + {0, 1 * time.Millisecond}, + {1 * time.Millisecond, 5 * time.Second}, // timeout over deadline } func TestDialTimeout(t *testing.T) { - // Cannot use t.Parallel - modifies global hooks. - origTestHookDialChannel := testHookDialChannel - defer func() { testHookDialChannel = origTestHookDialChannel }() - defer sw.Set(socktest.FilterConnect, nil) - - for i, tt := range dialTimeoutTests { - switch runtime.GOOS { - case "plan9", "windows": - testHookDialChannel = func() { time.Sleep(tt.guard) } - if runtime.GOOS == "plan9" { - break - } - fallthrough - default: - sw.Set(socktest.FilterConnect, func(so *socktest.Status) (socktest.AfterFilter, error) { - time.Sleep(tt.guard) - return nil, errTimedout - }) - } + switch runtime.GOOS { + case "plan9": + t.Skipf("not supported on %s", runtime.GOOS) + } - d := Dialer{Timeout: tt.timeout} - if tt.delta != 0 { - d.Deadline = time.Now().Add(tt.delta) - } + t.Parallel() - // This dial never starts to send any TCP SYN - // segment because of above socket filter and - // test hook. - c, err := d.Dial("tcp", "127.0.0.1:0") - if err == nil { - err = fmt.Errorf("unexpectedly established: tcp:%s->%s", c.LocalAddr(), c.RemoteAddr()) - c.Close() + ln := newLocalListener(t, "tcp") + defer func() { + if err := ln.Close(); err != nil { + t.Error(err) } + }() - if perr := parseDialError(err); perr != nil { - t.Errorf("#%d: %v", i, perr) - } - if nerr, ok := err.(Error); !ok || !nerr.Timeout() { - t.Fatalf("#%d: %v", i, err) - } + for _, tt := range dialTimeoutTests { + t.Run(fmt.Sprintf("%v/%v", tt.initialTimeout, tt.initialDelta), func(t *testing.T) { + // We don't run these subtests in parallel because we don't know how big + // the kernel's accept queue is, and we don't want to accidentally saturate + // it with concurrent calls. (That could cause the Dial to fail with + // ECONNREFUSED or ECONNRESET instead of a timeout error.) + d := Dialer{Timeout: tt.initialTimeout} + delta := tt.initialDelta + + var ( + beforeDial time.Time + afterDial time.Time + err error + ) + for { + if delta != 0 { + d.Deadline = time.Now().Add(delta) + } + + beforeDial = time.Now() + + var c Conn + c, err = d.Dial(ln.Addr().Network(), ln.Addr().String()) + afterDial = time.Now() + + if err != nil { + break + } + + // Even though we're not calling Accept on the Listener, the kernel may + // spuriously accept connections on its behalf. If that happens, we will + // close the connection (to try to get it out of the kernel's accept + // queue) and try a shorter timeout. + // + // We assume that we will reach a point where the call actually does + // time out, although in theory (since this socket is on a loopback + // address) a sufficiently clever kernel could notice that no Accept + // call is pending and bypass both the queue and the timeout to return + // another error immediately. + t.Logf("closing spurious connection from Dial") + c.Close() + + if delta <= 1 && d.Timeout <= 1 { + t.Fatalf("can't reduce Timeout or Deadline") + } + if delta > 1 { + delta /= 2 + t.Logf("reducing Deadline delta to %v", delta) + } + if d.Timeout > 1 { + d.Timeout /= 2 + t.Logf("reducing Timeout to %v", d.Timeout) + } + } + + if d.Deadline.IsZero() || afterDial.Before(d.Deadline) { + delay := afterDial.Sub(beforeDial) + if delay < d.Timeout { + t.Errorf("Dial returned after %v; want ≥%v", delay, d.Timeout) + } + } + + if perr := parseDialError(err); perr != nil { + t.Errorf("unexpected error from Dial: %v", perr) + } + if nerr, ok := err.(Error); !ok || !nerr.Timeout() { + t.Errorf("Dial: %v, want timeout", err) + } + }) } } @@ -189,35 +242,22 @@ func TestAcceptTimeoutMustReturn(t *testing.T) { ln := newLocalListener(t, "tcp") defer ln.Close() - max := time.NewTimer(time.Second) - defer max.Stop() - ch := make(chan error) - go func() { - if err := ln.(*TCPListener).SetDeadline(noDeadline); err != nil { - t.Error(err) - } - if err := ln.(*TCPListener).SetDeadline(time.Now().Add(10 * time.Millisecond)); err != nil { - t.Error(err) - } - c, err := ln.Accept() - if err == nil { - c.Close() - } - ch <- err - }() + if err := ln.(*TCPListener).SetDeadline(noDeadline); err != nil { + t.Error(err) + } + if err := ln.(*TCPListener).SetDeadline(time.Now().Add(10 * time.Millisecond)); err != nil { + t.Error(err) + } + c, err := ln.Accept() + if err == nil { + c.Close() + } - select { - case <-max.C: - ln.Close() - <-ch // wait for tester goroutine to stop - t.Fatal("Accept didn't return in an expected time") - case err := <-ch: - if perr := parseAcceptError(err); perr != nil { - t.Error(perr) - } - if !isDeadlineExceeded(err) { - t.Fatal(err) - } + if perr := parseAcceptError(err); perr != nil { + t.Error(perr) + } + if !isDeadlineExceeded(err) { + t.Fatal(err) } } @@ -529,7 +569,7 @@ func TestWriteTimeoutMustNotReturn(t *testing.T) { t.Error(err) } maxch <- time.NewTimer(100 * time.Millisecond) - var b [1]byte + var b [1024]byte for { if _, err := c.Write(b[:]); err != nil { ch <- err diff --git a/src/net/udpsock.go b/src/net/udpsock.go index e30624dea5..4f8acb7fc8 100644 --- a/src/net/udpsock.go +++ b/src/net/udpsock.go @@ -129,7 +129,7 @@ func (c *UDPConn) SyscallConn() (syscall.RawConn, error) { if !c.ok() { return nil, syscall.EINVAL } - return newRawConn(c.fd) + return newRawConn(c.fd), nil } // ReadFromUDP acts like ReadFrom but returns a UDPAddr. diff --git a/src/net/udpsock_posix.go b/src/net/udpsock_posix.go index f3dbcfec00..5035059831 100644 --- a/src/net/udpsock_posix.go +++ b/src/net/udpsock_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build unix || (js && wasm) || wasip1 || windows +//go:build unix || js || wasip1 || windows package net diff --git a/src/net/udpsock_test.go b/src/net/udpsock_test.go index 2afd4ac2ae..8a21aa7370 100644 --- a/src/net/udpsock_test.go +++ b/src/net/udpsock_test.go @@ -2,12 +2,11 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !wasip1 - package net import ( "errors" + "fmt" "internal/testenv" "net/netip" "os" @@ -116,6 +115,10 @@ func TestWriteToUDP(t *testing.T) { t.Skipf("not supported on %s", runtime.GOOS) } + if !testableNetwork("udp") { + t.Skipf("skipping: udp not supported") + } + c, err := ListenPacket("udp", "127.0.0.1:0") if err != nil { t.Fatal(err) @@ -221,19 +224,29 @@ func TestUDPConnLocalName(t *testing.T) { testenv.MustHaveExternalNetwork(t) for _, tt := range udpConnLocalNameTests { - c, err := ListenUDP(tt.net, tt.laddr) - if err != nil { - t.Fatal(err) - } - defer c.Close() - la := c.LocalAddr() - if a, ok := la.(*UDPAddr); !ok || a.Port == 0 { - t.Fatalf("got %v; expected a proper address with non-zero port number", la) - } + t.Run(fmt.Sprint(tt.laddr), func(t *testing.T) { + if !testableNetwork(tt.net) { + t.Skipf("skipping: %s not available", tt.net) + } + + c, err := ListenUDP(tt.net, tt.laddr) + if err != nil { + t.Fatal(err) + } + defer c.Close() + la := c.LocalAddr() + if a, ok := la.(*UDPAddr); !ok || a.Port == 0 { + t.Fatalf("got %v; expected a proper address with non-zero port number", la) + } + }) } } func TestUDPConnLocalAndRemoteNames(t *testing.T) { + if !testableNetwork("udp") { + t.Skipf("skipping: udp not available") + } + for _, laddr := range []string{"", "127.0.0.1:0"} { c1, err := ListenPacket("udp", "127.0.0.1:0") if err != nil { @@ -330,6 +343,9 @@ func TestUDPZeroBytePayload(t *testing.T) { case "darwin", "ios": testenv.SkipFlaky(t, 29225) } + if !testableNetwork("udp") { + t.Skipf("skipping: udp not available") + } c := newLocalPacketListener(t, "udp") defer c.Close() @@ -363,6 +379,9 @@ func TestUDPZeroByteBuffer(t *testing.T) { case "plan9": t.Skipf("not supported on %s", runtime.GOOS) } + if !testableNetwork("udp") { + t.Skipf("skipping: udp not available") + } c := newLocalPacketListener(t, "udp") defer c.Close() @@ -397,6 +416,9 @@ func TestUDPReadSizeError(t *testing.T) { case "plan9": t.Skipf("not supported on %s", runtime.GOOS) } + if !testableNetwork("udp") { + t.Skipf("skipping: udp not available") + } c1 := newLocalPacketListener(t, "udp") defer c1.Close() @@ -434,6 +456,10 @@ func TestUDPReadSizeError(t *testing.T) { // TestUDPReadTimeout verifies that ReadFromUDP with timeout returns an error // without data or an address. func TestUDPReadTimeout(t *testing.T) { + if !testableNetwork("udp4") { + t.Skipf("skipping: udp4 not available") + } + la, err := ResolveUDPAddr("udp4", "127.0.0.1:0") if err != nil { t.Fatal(err) @@ -460,10 +486,14 @@ func TestUDPReadTimeout(t *testing.T) { func TestAllocs(t *testing.T) { switch runtime.GOOS { - case "plan9": - // Plan9 wasn't optimized. + case "plan9", "js", "wasip1": + // These implementations have not been optimized. t.Skipf("skipping on %v", runtime.GOOS) } + if !testableNetwork("udp4") { + t.Skipf("skipping: udp4 not available") + } + // Optimizations are required to remove the allocs. testenv.SkipIfOptimizationOff(t) @@ -590,6 +620,10 @@ func TestUDPIPVersionReadMsg(t *testing.T) { case "plan9": t.Skipf("skipping on %v", runtime.GOOS) } + if !testableNetwork("udp4") { + t.Skipf("skipping: udp4 not available") + } + conn, err := ListenUDP("udp4", &UDPAddr{IP: IPv4(127, 0, 0, 1)}) if err != nil { t.Fatal(err) @@ -625,8 +659,11 @@ func TestUDPIPVersionReadMsg(t *testing.T) { // WriteMsgUDPAddrPort accepts IPv4, IPv4-mapped IPv6, and IPv6 target addresses // on a UDPConn listening on "::". func TestIPv6WriteMsgUDPAddrPortTargetAddrIPVersion(t *testing.T) { - if !supportsIPv6() { - t.Skip("IPv6 is not supported") + if !testableNetwork("udp4") { + t.Skipf("skipping: udp4 not available") + } + if !testableNetwork("udp6") { + t.Skipf("skipping: udp6 not available") } switch runtime.GOOS { diff --git a/src/net/unixsock.go b/src/net/unixsock.go index 14fbac0932..821be7bf74 100644 --- a/src/net/unixsock.go +++ b/src/net/unixsock.go @@ -52,7 +52,7 @@ func (a *UnixAddr) opAddr() Addr { // // The network must be a Unix network name. // -// See func Dial for a description of the network and address +// See func [Dial] for a description of the network and address // parameters. func ResolveUnixAddr(network, address string) (*UnixAddr, error) { switch network { @@ -63,19 +63,19 @@ func ResolveUnixAddr(network, address string) (*UnixAddr, error) { } } -// UnixConn is an implementation of the Conn interface for connections +// UnixConn is an implementation of the [Conn] interface for connections // to Unix domain sockets. type UnixConn struct { conn } // SyscallConn returns a raw network connection. -// This implements the syscall.Conn interface. +// This implements the [syscall.Conn] interface. func (c *UnixConn) SyscallConn() (syscall.RawConn, error) { if !c.ok() { return nil, syscall.EINVAL } - return newRawConn(c.fd) + return newRawConn(c.fd), nil } // CloseRead shuts down the reading side of the Unix domain connection. @@ -102,7 +102,7 @@ func (c *UnixConn) CloseWrite() error { return nil } -// ReadFromUnix acts like ReadFrom but returns a UnixAddr. +// ReadFromUnix acts like [UnixConn.ReadFrom] but returns a [UnixAddr]. func (c *UnixConn) ReadFromUnix(b []byte) (int, *UnixAddr, error) { if !c.ok() { return 0, nil, syscall.EINVAL @@ -114,7 +114,7 @@ func (c *UnixConn) ReadFromUnix(b []byte) (int, *UnixAddr, error) { return n, addr, err } -// ReadFrom implements the PacketConn ReadFrom method. +// ReadFrom implements the [PacketConn] ReadFrom method. func (c *UnixConn) ReadFrom(b []byte) (int, Addr, error) { if !c.ok() { return 0, nil, syscall.EINVAL @@ -147,7 +147,7 @@ func (c *UnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *UnixAdd return } -// WriteToUnix acts like WriteTo but takes a UnixAddr. +// WriteToUnix acts like [UnixConn.WriteTo] but takes a [UnixAddr]. func (c *UnixConn) WriteToUnix(b []byte, addr *UnixAddr) (int, error) { if !c.ok() { return 0, syscall.EINVAL @@ -159,7 +159,7 @@ func (c *UnixConn) WriteToUnix(b []byte, addr *UnixAddr) (int, error) { return n, err } -// WriteTo implements the PacketConn WriteTo method. +// WriteTo implements the [PacketConn] WriteTo method. func (c *UnixConn) WriteTo(b []byte, addr Addr) (int, error) { if !c.ok() { return 0, syscall.EINVAL @@ -194,7 +194,7 @@ func (c *UnixConn) WriteMsgUnix(b, oob []byte, addr *UnixAddr) (n, oobn int, err func newUnixConn(fd *netFD) *UnixConn { return &UnixConn{conn{fd}} } -// DialUnix acts like Dial for Unix networks. +// DialUnix acts like [Dial] for Unix networks. // // The network must be a Unix network name; see func Dial for details. // @@ -215,7 +215,7 @@ func DialUnix(network string, laddr, raddr *UnixAddr) (*UnixConn, error) { } // UnixListener is a Unix domain socket listener. Clients should -// typically use variables of type Listener instead of assuming Unix +// typically use variables of type [Listener] instead of assuming Unix // domain sockets. type UnixListener struct { fd *netFD @@ -227,7 +227,7 @@ type UnixListener struct { func (ln *UnixListener) ok() bool { return ln != nil && ln.fd != nil } // SyscallConn returns a raw network connection. -// This implements the syscall.Conn interface. +// This implements the [syscall.Conn] interface. // // The returned RawConn only supports calling Control. Read and // Write return an error. @@ -235,7 +235,7 @@ func (l *UnixListener) SyscallConn() (syscall.RawConn, error) { if !l.ok() { return nil, syscall.EINVAL } - return newRawListener(l.fd) + return newRawListener(l.fd), nil } // AcceptUnix accepts the next incoming call and returns the new @@ -251,8 +251,8 @@ func (l *UnixListener) AcceptUnix() (*UnixConn, error) { return c, nil } -// Accept implements the Accept method in the Listener interface. -// Returned connections will be of type *UnixConn. +// Accept implements the Accept method in the [Listener] interface. +// Returned connections will be of type [*UnixConn]. func (l *UnixListener) Accept() (Conn, error) { if !l.ok() { return nil, syscall.EINVAL @@ -287,13 +287,10 @@ func (l *UnixListener) SetDeadline(t time.Time) error { if !l.ok() { return syscall.EINVAL } - if err := l.fd.pfd.SetDeadline(t); err != nil { - return &OpError{Op: "set", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err} - } - return nil + return l.fd.SetDeadline(t) } -// File returns a copy of the underlying os.File. +// File returns a copy of the underlying [os.File]. // It is the caller's responsibility to close f when finished. // Closing l does not affect f, and closing f does not affect l. // @@ -311,7 +308,7 @@ func (l *UnixListener) File() (f *os.File, err error) { return } -// ListenUnix acts like Listen for Unix networks. +// ListenUnix acts like [Listen] for Unix networks. // // The network must be "unix" or "unixpacket". func ListenUnix(network string, laddr *UnixAddr) (*UnixListener, error) { @@ -331,7 +328,7 @@ func ListenUnix(network string, laddr *UnixAddr) (*UnixListener, error) { return ln, nil } -// ListenUnixgram acts like ListenPacket for Unix networks. +// ListenUnixgram acts like [ListenPacket] for Unix networks. // // The network must be "unixgram". func ListenUnixgram(network string, laddr *UnixAddr) (*UnixConn, error) { diff --git a/src/net/unixsock_posix.go b/src/net/unixsock_posix.go index c501b499ed..f6c8e8f0b0 100644 --- a/src/net/unixsock_posix.go +++ b/src/net/unixsock_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build unix || (js && wasm) || wasip1 || windows +//go:build unix || js || wasip1 || windows package net diff --git a/src/net/unixsock_readmsg_other.go b/src/net/unixsock_readmsg_other.go index 0899a6d3d3..4bef3ee71d 100644 --- a/src/net/unixsock_readmsg_other.go +++ b/src/net/unixsock_readmsg_other.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build (js && wasm) || wasip1 || windows +//go:build js || wasip1 || windows package net diff --git a/src/net/unixsock_test.go b/src/net/unixsock_test.go index 8402519a0d..6906ecc046 100644 --- a/src/net/unixsock_test.go +++ b/src/net/unixsock_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !plan9 && !wasip1 && !windows +//go:build !plan9 && !windows package net @@ -21,6 +21,10 @@ func TestReadUnixgramWithUnnamedSocket(t *testing.T) { if !testableNetwork("unixgram") { t.Skip("unixgram test") } + switch runtime.GOOS { + case "js", "wasip1": + t.Skipf("skipping: syscall.Socket not implemented on %s", runtime.GOOS) + } if runtime.GOOS == "openbsd" { testenv.SkipFlaky(t, 15157) } @@ -359,6 +363,11 @@ func TestUnixUnlink(t *testing.T) { if !testableNetwork("unix") { t.Skip("unix test") } + switch runtime.GOOS { + case "js", "wasip1": + t.Skipf("skipping: %s does not support Unlink", runtime.GOOS) + } + name := testUnixAddr(t) listen := func(t *testing.T) *UnixListener { diff --git a/src/net/url/url.go b/src/net/url/url.go index 501b263e87..f362958edd 100644 --- a/src/net/url/url.go +++ b/src/net/url/url.go @@ -175,7 +175,7 @@ func shouldEscape(c byte, mode encoding) bool { return true } -// QueryUnescape does the inverse transformation of QueryEscape, +// QueryUnescape does the inverse transformation of [QueryEscape], // converting each 3-byte encoded substring of the form "%AB" into the // hex-decoded byte 0xAB. // It returns an error if any % is not followed by two hexadecimal @@ -184,12 +184,12 @@ func QueryUnescape(s string) (string, error) { return unescape(s, encodeQueryComponent) } -// PathUnescape does the inverse transformation of PathEscape, +// PathUnescape does the inverse transformation of [PathEscape], // converting each 3-byte encoded substring of the form "%AB" into the // hex-decoded byte 0xAB. It returns an error if any % is not followed // by two hexadecimal digits. // -// PathUnescape is identical to QueryUnescape except that it does not +// PathUnescape is identical to [QueryUnescape] except that it does not // unescape '+' to ' ' (space). func PathUnescape(s string) (string, error) { return unescape(s, encodePathSegment) @@ -271,12 +271,12 @@ func unescape(s string, mode encoding) (string, error) { } // QueryEscape escapes the string so it can be safely placed -// inside a URL query. +// inside a [URL] query. func QueryEscape(s string) string { return escape(s, encodeQueryComponent) } -// PathEscape escapes the string so it can be safely placed inside a URL path segment, +// PathEscape escapes the string so it can be safely placed inside a [URL] path segment, // replacing special characters (including /) with %XX sequences as needed. func PathEscape(s string) string { return escape(s, encodePathSegment) @@ -348,10 +348,17 @@ func escape(s string, mode encoding) string { // // scheme:opaque[?query][#fragment] // +// The Host field contains the host and port subcomponents of the URL. +// When the port is present, it is separated from the host with a colon. +// When the host is an IPv6 address, it must be enclosed in square brackets: +// "[fe80::1]:80". The [net.JoinHostPort] function combines a host and port +// into a string suitable for the Host field, adding square brackets to +// the host when necessary. +// // Note that the Path field is stored in decoded form: /%47%6f%2f becomes /Go/. // A consequence is that it is impossible to tell which slashes in the Path were // slashes in the raw URL and which were %2f. This distinction is rarely important, -// but when it is, the code should use the EscapedPath method, which preserves +// but when it is, the code should use the [URL.EscapedPath] method, which preserves // the original encoding of Path. // // The RawPath field is an optional field which is only set when the default @@ -363,7 +370,7 @@ type URL struct { Scheme string Opaque string // encoded opaque data User *Userinfo // username and password information - Host string // host or host:port + Host string // host or host:port (see Hostname and Port methods) Path string // path (relative paths may omit leading slash) RawPath string // encoded path hint (see EscapedPath method) OmitHost bool // do not emit empty host (authority) @@ -373,13 +380,13 @@ type URL struct { RawFragment string // encoded fragment hint (see EscapedFragment method) } -// User returns a Userinfo containing the provided username +// User returns a [Userinfo] containing the provided username // and no password set. func User(username string) *Userinfo { return &Userinfo{username, "", false} } -// UserPassword returns a Userinfo containing the provided username +// UserPassword returns a [Userinfo] containing the provided username // and password. // // This functionality should only be used with legacy web sites. @@ -392,7 +399,7 @@ func UserPassword(username, password string) *Userinfo { } // The Userinfo type is an immutable encapsulation of username and -// password details for a URL. An existing Userinfo value is guaranteed +// password details for a [URL]. An existing Userinfo value is guaranteed // to have a username set (potentially empty, as allowed by RFC 2396), // and optionally a password. type Userinfo struct { @@ -457,7 +464,7 @@ func getScheme(rawURL string) (scheme, path string, err error) { return "", rawURL, nil } -// Parse parses a raw url into a URL structure. +// Parse parses a raw url into a [URL] structure. // // The url may be relative (a path, without a host) or absolute // (starting with a scheme). Trying to parse a hostname and path @@ -479,7 +486,7 @@ func Parse(rawURL string) (*URL, error) { return url, nil } -// ParseRequestURI parses a raw url into a URL structure. It assumes that +// ParseRequestURI parses a raw url into a [URL] structure. It assumes that // url was received in an HTTP request, so the url is interpreted // only as an absolute URI or an absolute path. // The string url is assumed not to have a #fragment suffix. @@ -690,7 +697,7 @@ func (u *URL) setPath(p string) error { // EscapedPath returns u.RawPath when it is a valid escaping of u.Path. // Otherwise EscapedPath ignores u.RawPath and computes an escaped // form on its own. -// The String and RequestURI methods use EscapedPath to construct +// The [URL.String] and [URL.RequestURI] methods use EscapedPath to construct // their results. // In general, code should call EscapedPath instead of // reading u.RawPath directly. @@ -754,7 +761,7 @@ func (u *URL) setFragment(f string) error { // EscapedFragment returns u.RawFragment when it is a valid escaping of u.Fragment. // Otherwise EscapedFragment ignores u.RawFragment and computes an escaped // form on its own. -// The String method uses EscapedFragment to construct its result. +// The [URL.String] method uses EscapedFragment to construct its result. // In general, code should call EscapedFragment instead of // reading u.RawFragment directly. func (u *URL) EscapedFragment() string { @@ -784,7 +791,7 @@ func validOptionalPort(port string) bool { return true } -// String reassembles the URL into a valid URL string. +// String reassembles the [URL] into a valid URL string. // The general form of the result is one of: // // scheme:opaque?query#fragment @@ -858,7 +865,7 @@ func (u *URL) String() string { return buf.String() } -// Redacted is like String but replaces any password with "xxxxx". +// Redacted is like [URL.String] but replaces any password with "xxxxx". // Only the password in u.User is redacted. func (u *URL) Redacted() string { if u == nil { @@ -963,7 +970,7 @@ func parseQuery(m Values, query string) (err error) { // Encode encodes the values into “URL encoded” form // ("bar=baz&foo=quux") sorted by key. func (v Values) Encode() string { - if v == nil { + if len(v) == 0 { return "" } var buf strings.Builder @@ -1053,15 +1060,15 @@ func resolvePath(base, ref string) string { return r } -// IsAbs reports whether the URL is absolute. +// IsAbs reports whether the [URL] is absolute. // Absolute means that it has a non-empty scheme. func (u *URL) IsAbs() bool { return u.Scheme != "" } -// Parse parses a URL in the context of the receiver. The provided URL +// Parse parses a [URL] in the context of the receiver. The provided URL // may be relative or absolute. Parse returns nil, err on parse -// failure, otherwise its return value is the same as ResolveReference. +// failure, otherwise its return value is the same as [URL.ResolveReference]. func (u *URL) Parse(ref string) (*URL, error) { refURL, err := Parse(ref) if err != nil { @@ -1073,7 +1080,7 @@ func (u *URL) Parse(ref string) (*URL, error) { // ResolveReference resolves a URI reference to an absolute URI from // an absolute base URI u, per RFC 3986 Section 5.2. The URI reference // may be relative or absolute. ResolveReference always returns a new -// URL instance, even if the returned URL is identical to either the +// [URL] instance, even if the returned URL is identical to either the // base or reference. If ref is an absolute URL, then ResolveReference // ignores base and returns a copy of ref. func (u *URL) ResolveReference(ref *URL) *URL { @@ -1110,7 +1117,7 @@ func (u *URL) ResolveReference(ref *URL) *URL { // Query parses RawQuery and returns the corresponding values. // It silently discards malformed value pairs. -// To check errors use ParseQuery. +// To check errors use [ParseQuery]. func (u *URL) Query() Values { v, _ := ParseQuery(u.RawQuery) return v @@ -1187,7 +1194,7 @@ func (u *URL) UnmarshalBinary(text []byte) error { return nil } -// JoinPath returns a new URL with the provided path elements joined to +// JoinPath returns a new [URL] with the provided path elements joined to // any existing path and the resulting path cleaned of any ./ or ../ elements. // Any sequences of multiple / characters will be reduced to a single /. func (u *URL) JoinPath(elem ...string) *URL { @@ -1253,7 +1260,7 @@ func stringContainsCTLByte(s string) bool { return false } -// JoinPath returns a URL string with the provided path elements joined to +// JoinPath returns a [URL] string with the provided path elements joined to // the existing path of base and the resulting path cleaned of any ./ or ../ elements. func JoinPath(base string, elem ...string) (result string, err error) { url, err := Parse(base) diff --git a/src/net/url/url_test.go b/src/net/url/url_test.go index 23c5c581c5..4aa20bb95f 100644 --- a/src/net/url/url_test.go +++ b/src/net/url/url_test.go @@ -1072,6 +1072,7 @@ type EncodeQueryTest struct { var encodeQueryTests = []EncodeQueryTest{ {nil, ""}, + {Values{}, ""}, {Values{"q": {"puppies"}, "oe": {"utf8"}}, "oe=utf8&q=puppies"}, {Values{"q": {"dogs", "&", "7"}}, "q=dogs&q=%26&q=7"}, {Values{ diff --git a/src/net/writev_test.go b/src/net/writev_test.go index 8722c0f920..e4e88c4fac 100644 --- a/src/net/writev_test.go +++ b/src/net/writev_test.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js && !wasip1 - package net import ( @@ -187,9 +185,15 @@ func TestWritevError(t *testing.T) { } ln := newLocalListener(t, "tcp") - defer ln.Close() ch := make(chan Conn, 1) + defer func() { + ln.Close() + for c := range ch { + c.Close() + } + }() + go func() { defer close(ch) c, err := ln.Accept() |