aboutsummaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorRob Findley <rfindley@google.com>2020-02-19 10:18:21 -0500
committerRobert Findley <rfindley@google.com>2020-02-24 22:51:04 +0000
commit20f46356b33885092df84870bb0f5da7f7b29d6d (patch)
treee81fb77db003a581e8506cc342130cb164cf253f /internal
parente02f5847d1aae0470827ff54a24022013aa4d0d0 (diff)
downloadgolang-x-tools-20f46356b33885092df84870bb0f5da7f7b29d6d.tar.gz
internal/lsp/lsprpc: add a handshake between forwarder and remote
In the ideal future, users will have one or more gopls instances, each serving potentially many LSP clients. In order to have any hope of navigating this web, clients and servers must know about eachother. To allow for such an exchange of information, this CL adds an additional handler layer to the serving configured in the lsprpc package. For now, forwarders just use this layer to execute a handshake with the LSP server, communicating the location of their logs and debug addresses. Updates golang/go#34111 Change-Id: Ic7432062c01a8bbd52fb4a058a95bbf5dc26baa3 Reviewed-on: https://go-review.googlesource.com/c/tools/+/220081 Run-TryBot: Robert Findley <rfindley@google.com> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Heschi Kreinick <heschi@google.com>
Diffstat (limited to 'internal')
-rw-r--r--internal/lsp/cache/cache.go2
-rw-r--r--internal/lsp/cache/debug.go12
-rw-r--r--internal/lsp/cache/session.go2
-rw-r--r--internal/lsp/cmd/cmd_test.go6
-rw-r--r--internal/lsp/cmd/serve.go4
-rw-r--r--internal/lsp/debug/serve.go8
-rw-r--r--internal/lsp/lsprpc/lsprpc.go184
-rw-r--r--internal/lsp/lsprpc/lsprpc_test.go65
-rw-r--r--internal/lsp/regtest/env.go11
-rw-r--r--internal/lsp/regtest/reg_test.go14
10 files changed, 266 insertions, 42 deletions
diff --git a/internal/lsp/cache/cache.go b/internal/lsp/cache/cache.go
index be451b3a5..4b71ebc5f 100644
--- a/internal/lsp/cache/cache.go
+++ b/internal/lsp/cache/cache.go
@@ -87,7 +87,7 @@ func (c *Cache) NewSession() *Session {
options: source.DefaultOptions(),
overlays: make(map[span.URI]*overlay),
}
- c.debug.AddSession(debugSession{s})
+ c.debug.AddSession(DebugSession{s})
return s
}
diff --git a/internal/lsp/cache/debug.go b/internal/lsp/cache/debug.go
index 65dcb5ecf..d4ea52841 100644
--- a/internal/lsp/cache/debug.go
+++ b/internal/lsp/cache/debug.go
@@ -14,14 +14,14 @@ import (
type debugView struct{ *view }
func (v debugView) ID() string { return v.id }
-func (v debugView) Session() debug.Session { return debugSession{v.session} }
+func (v debugView) Session() debug.Session { return DebugSession{v.session} }
func (v debugView) Env() []string { return v.Options().Env }
-type debugSession struct{ *Session }
+type DebugSession struct{ *Session }
-func (s debugSession) ID() string { return s.id }
-func (s debugSession) Cache() debug.Cache { return debugCache{s.cache} }
-func (s debugSession) Files() []*debug.File {
+func (s DebugSession) ID() string { return s.id }
+func (s DebugSession) Cache() debug.Cache { return debugCache{s.cache} }
+func (s DebugSession) Files() []*debug.File {
var files []*debug.File
seen := make(map[span.URI]*debug.File)
s.overlayMu.Lock()
@@ -43,7 +43,7 @@ func (s debugSession) Files() []*debug.File {
return files
}
-func (s debugSession) File(hash string) *debug.File {
+func (s DebugSession) File(hash string) *debug.File {
s.overlayMu.Lock()
defer s.overlayMu.Unlock()
for _, overlay := range s.overlays {
diff --git a/internal/lsp/cache/session.go b/internal/lsp/cache/session.go
index 2630f74be..9df76aa34 100644
--- a/internal/lsp/cache/session.go
+++ b/internal/lsp/cache/session.go
@@ -77,7 +77,7 @@ func (s *Session) Shutdown(ctx context.Context) {
}
s.views = nil
s.viewMap = nil
- s.cache.debug.DropSession(debugSession{s})
+ s.cache.debug.DropSession(DebugSession{s})
}
func (s *Session) Cache() source.Cache {
diff --git a/internal/lsp/cmd/cmd_test.go b/internal/lsp/cmd/cmd_test.go
index 16b6c3258..bc56a3a20 100644
--- a/internal/lsp/cmd/cmd_test.go
+++ b/internal/lsp/cmd/cmd_test.go
@@ -18,6 +18,7 @@ import (
"golang.org/x/tools/internal/lsp/cache"
"golang.org/x/tools/internal/lsp/cmd"
cmdtest "golang.org/x/tools/internal/lsp/cmd/test"
+ "golang.org/x/tools/internal/lsp/debug"
"golang.org/x/tools/internal/lsp/lsprpc"
"golang.org/x/tools/internal/lsp/tests"
"golang.org/x/tools/internal/testenv"
@@ -46,8 +47,9 @@ func testCommandLine(t *testing.T, exporter packagestest.Exporter) {
}
func testServer(ctx context.Context) *servertest.TCPServer {
- cache := cache.New(nil, nil)
- ss := lsprpc.NewStreamServer(cache, false)
+ di := debug.NewInstance("", "")
+ cache := cache.New(nil, di.State)
+ ss := lsprpc.NewStreamServer(cache, false, di)
return servertest.NewTCPServer(ctx, ss)
}
diff --git a/internal/lsp/cmd/serve.go b/internal/lsp/cmd/serve.go
index 81fbb278b..907407fbb 100644
--- a/internal/lsp/cmd/serve.go
+++ b/internal/lsp/cmd/serve.go
@@ -66,9 +66,9 @@ func (s *Serve) Run(ctx context.Context, args ...string) error {
var ss jsonrpc2.StreamServer
if s.app.Remote != "" {
network, addr := parseAddr(s.app.Remote)
- ss = lsprpc.NewForwarder(network, addr, true)
+ ss = lsprpc.NewForwarder(network, addr, true, s.app.debug)
} else {
- ss = lsprpc.NewStreamServer(cache.New(s.app.options, s.app.debug.State), true)
+ ss = lsprpc.NewStreamServer(cache.New(s.app.options, s.app.debug.State), true, s.app.debug)
}
if s.Address != "" {
diff --git a/internal/lsp/debug/serve.go b/internal/lsp/debug/serve.go
index 221431be8..c6304b538 100644
--- a/internal/lsp/debug/serve.go
+++ b/internal/lsp/debug/serve.go
@@ -154,7 +154,7 @@ func (st *State) Servers() []Server {
type Client interface {
ID() string
Session() Session
- DebugAddr() string
+ DebugAddress() string
Logfile() string
ServerID() string
}
@@ -162,7 +162,7 @@ type Client interface {
// A Server is an outgoing connection to a remote LSP server.
type Server interface {
ID() string
- DebugAddr() string
+ DebugAddress() string
Logfile() string
ClientID() string
}
@@ -644,7 +644,7 @@ var clientTmpl = template.Must(template.Must(baseTemplate.Clone()).Parse(`
{{define "title"}}Client {{.ID}}{{end}}
{{define "body"}}
Using session: <b>{{template "sessionlink" .Session.ID}}</b><br>
-Debug this client at: <a href="http://{{url .DebugAddr}}">{{.DebugAddr}}</a><br>
+Debug this client at: <a href="http://{{url .DebugAddress}}">{{.DebugAddress}}</a><br>
Logfile: {{.Logfile}}<br>
{{end}}
`))
@@ -652,7 +652,7 @@ Logfile: {{.Logfile}}<br>
var serverTmpl = template.Must(template.Must(baseTemplate.Clone()).Parse(`
{{define "title"}}Server {{.ID}}{{end}}
{{define "body"}}
-Debug this server at: <a href="http://{{.DebugAddr}}">{{.DebugAddr}}</a><br>
+Debug this server at: <a href="http://{{.DebugAddress}}">{{.DebugAddress}}</a><br>
Logfile: {{.Logfile}}<br>
{{end}}
`))
diff --git a/internal/lsp/lsprpc/lsprpc.go b/internal/lsp/lsprpc/lsprpc.go
index fb00ad022..2c0ae46f8 100644
--- a/internal/lsp/lsprpc/lsprpc.go
+++ b/internal/lsp/lsprpc/lsprpc.go
@@ -8,54 +8,126 @@ package lsprpc
import (
"context"
+ "encoding/json"
"fmt"
- "log"
"net"
"os"
+ "strconv"
+ "sync/atomic"
"time"
"golang.org/x/sync/errgroup"
"golang.org/x/tools/internal/jsonrpc2"
"golang.org/x/tools/internal/lsp"
"golang.org/x/tools/internal/lsp/cache"
+ "golang.org/x/tools/internal/lsp/debug"
"golang.org/x/tools/internal/lsp/protocol"
+ "golang.org/x/tools/internal/telemetry/log"
)
// The StreamServer type is a jsonrpc2.StreamServer that handles incoming
// streams as a new LSP session, using a shared cache.
type StreamServer struct {
withTelemetry bool
+ debug *debug.Instance
cache *cache.Cache
- // If set, serverForTest is used instead of an actual lsp.Server.
+ // serverForTest may be set to a test fake for testing.
serverForTest protocol.Server
}
+var clientIndex, serverIndex int64
+
// NewStreamServer creates a StreamServer using the shared cache. If
// withTelemetry is true, each session is instrumented with telemetry that
// records RPC statistics.
-func NewStreamServer(cache *cache.Cache, withTelemetry bool) *StreamServer {
+func NewStreamServer(cache *cache.Cache, withTelemetry bool, debugInstance *debug.Instance) *StreamServer {
s := &StreamServer{
withTelemetry: withTelemetry,
+ debug: debugInstance,
cache: cache,
}
return s
}
+// debugInstance is the common functionality shared between client and server
+// gopls instances.
+type debugInstance struct {
+ id string
+ debugAddress string
+ logfile string
+}
+
+func (d debugInstance) ID() string {
+ return d.id
+}
+
+func (d debugInstance) DebugAddress() string {
+ return d.debugAddress
+}
+
+func (d debugInstance) Logfile() string {
+ return d.logfile
+}
+
+// A debugServer is held by the client to identity the remove server to which
+// it is connected.
+type debugServer struct {
+ debugInstance
+ // clientID is the id of this client on the server.
+ clientID string
+}
+
+func (s debugServer) ClientID() string {
+ return s.clientID
+}
+
+// A debugClient is held by the server to identify an incoming client
+// connection.
+type debugClient struct {
+ debugInstance
+ // session is the session serving this client.
+ session *cache.Session
+ // serverID is this id of this server on the client.
+ serverID string
+}
+
+func (c debugClient) Session() debug.Session {
+ return cache.DebugSession{Session: c.session}
+}
+
+func (c debugClient) ServerID() string {
+ return c.serverID
+}
+
// ServeStream implements the jsonrpc2.StreamServer interface, by handling
// incoming streams using a new lsp server.
func (s *StreamServer) ServeStream(ctx context.Context, stream jsonrpc2.Stream) error {
+ index := atomic.AddInt64(&clientIndex, 1)
+
conn := jsonrpc2.NewConn(stream)
client := protocol.ClientDispatcher(conn)
+ session := s.cache.NewSession()
+ dc := &debugClient{
+ debugInstance: debugInstance{
+ id: strconv.FormatInt(index, 10),
+ },
+ session: session,
+ }
+ s.debug.State.AddClient(dc)
server := s.serverForTest
if server == nil {
- server = lsp.NewServer(s.cache.NewSession(), client)
+ server = lsp.NewServer(session, client)
}
conn.AddHandler(protocol.ServerHandler(server))
conn.AddHandler(protocol.Canceller{})
if s.withTelemetry {
conn.AddHandler(telemetryHandler{})
}
+ conn.AddHandler(&handshaker{
+ client: dc,
+ debug: s.debug,
+ })
return conn.Run(protocol.WithClient(ctx, client))
}
@@ -73,17 +145,19 @@ type Forwarder struct {
withTelemetry bool
dialTimeout time.Duration
retries int
+ debug *debug.Instance
}
// NewForwarder creates a new Forwarder, ready to forward connections to the
// remote server specified by network and addr.
-func NewForwarder(network, addr string, withTelemetry bool) *Forwarder {
+func NewForwarder(network, addr string, withTelemetry bool, debugInstance *debug.Instance) *Forwarder {
return &Forwarder{
network: network,
addr: addr,
withTelemetry: withTelemetry,
dialTimeout: 1 * time.Second,
retries: 5,
+ debug: debugInstance,
}
}
@@ -106,7 +180,7 @@ func (f *Forwarder) ServeStream(ctx context.Context, stream jsonrpc2.Stream) err
if err == nil {
break
}
- log.Printf("failed an attempt to connect to remote: %v\n", err)
+ log.Print(ctx, fmt.Sprintf("failed an attempt to connect to remote: %v\n", err))
// In case our failure was a fast-failure, ensure we wait at least
// f.dialTimeout before trying again.
if attempt != f.retries {
@@ -128,7 +202,6 @@ func (f *Forwarder) ServeStream(ctx context.Context, stream jsonrpc2.Stream) err
if f.withTelemetry {
clientConn.AddHandler(telemetryHandler{})
}
-
g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
return serverConn.Run(ctx)
@@ -136,6 +209,29 @@ func (f *Forwarder) ServeStream(ctx context.Context, stream jsonrpc2.Stream) err
g.Go(func() error {
return clientConn.Run(ctx)
})
+
+ // Do a handshake with the server instance to exchange debug information.
+ index := atomic.AddInt64(&serverIndex, 1)
+ serverID := strconv.FormatInt(index, 10)
+ var (
+ hreq = handshakeRequest{
+ ServerID: serverID,
+ Logfile: f.debug.Logfile,
+ DebugAddr: f.debug.DebugAddress,
+ }
+ hresp handshakeResponse
+ )
+ if err := serverConn.Call(ctx, handshakeMethod, hreq, &hresp); err != nil {
+ log.Error(ctx, "gopls handshake failed", err)
+ }
+ f.debug.State.AddServer(debugServer{
+ debugInstance: debugInstance{
+ id: serverID,
+ logfile: hresp.Logfile,
+ debugAddress: hresp.DebugAddr,
+ },
+ clientID: hresp.ClientID,
+ })
return g.Wait()
}
@@ -143,6 +239,26 @@ func (f *Forwarder) ServeStream(ctx context.Context, stream jsonrpc2.Stream) err
// testing purposes.
var ForwarderExitFunc = os.Exit
+// OverrideExitFuncsForTest can be used from test code to prevent the test
+// process from exiting on server shutdown. The returned func reverts the exit
+// funcs to their previous state.
+func OverrideExitFuncsForTest() func() {
+ // Override functions that would shut down the test process
+ cleanup := func(lspExit, forwarderExit func(code int)) func() {
+ return func() {
+ lsp.ServerExitFunc = lspExit
+ ForwarderExitFunc = forwarderExit
+ }
+ }(lsp.ServerExitFunc, ForwarderExitFunc)
+ // It is an error for a test to shutdown a server process.
+ lsp.ServerExitFunc = func(code int) {
+ panic(fmt.Sprintf("LSP server exited with code %d", code))
+ }
+ // We don't want our forwarders to exit, but it's OK if they would have.
+ ForwarderExitFunc = func(code int) {}
+ return cleanup
+}
+
// forwarderHandler intercepts 'exit' messages to prevent the shared gopls
// instance from exiting. In the future it may also intercept 'shutdown' to
// provide more graceful shutdown of the client connection.
@@ -162,3 +278,57 @@ func (forwarderHandler) Deliver(ctx context.Context, r *jsonrpc2.Request, delive
}
return false
}
+
+type handshaker struct {
+ jsonrpc2.EmptyHandler
+ client *debugClient
+ debug *debug.Instance
+}
+
+type handshakeRequest struct {
+ ServerID string `json:"serverID"`
+ Logfile string `json:"logfile"`
+ DebugAddr string `json:"debugAddr"`
+}
+
+type handshakeResponse struct {
+ ClientID string `json:"clientID"`
+ SessionID string `json:"sessionID"`
+ Logfile string `json:"logfile"`
+ DebugAddr string `json:"debugAddr"`
+}
+
+const handshakeMethod = "gopls/handshake"
+
+func (h *handshaker) Deliver(ctx context.Context, r *jsonrpc2.Request, delivered bool) bool {
+ if r.Method == handshakeMethod {
+ var req handshakeRequest
+ if err := json.Unmarshal(*r.Params, &req); err != nil {
+ sendError(ctx, r, err)
+ return true
+ }
+ h.client.debugAddress = req.DebugAddr
+ h.client.logfile = req.Logfile
+ h.client.serverID = req.ServerID
+ resp := handshakeResponse{
+ ClientID: h.client.id,
+ SessionID: cache.DebugSession{Session: h.client.session}.ID(),
+ Logfile: h.debug.Logfile,
+ DebugAddr: h.debug.DebugAddress,
+ }
+ if err := r.Reply(ctx, resp, nil); err != nil {
+ log.Error(ctx, "replying to handshake", err)
+ }
+ return true
+ }
+ return false
+}
+
+func sendError(ctx context.Context, req *jsonrpc2.Request, err error) {
+ if _, ok := err.(*jsonrpc2.Error); !ok {
+ err = jsonrpc2.NewErrorf(jsonrpc2.CodeParseError, "%v", err)
+ }
+ if err := req.Reply(ctx, nil, err); err != nil {
+ log.Error(ctx, "", err)
+ }
+}
diff --git a/internal/lsp/lsprpc/lsprpc_test.go b/internal/lsp/lsprpc/lsprpc_test.go
index b20b92d94..215d936b5 100644
--- a/internal/lsp/lsprpc/lsprpc_test.go
+++ b/internal/lsp/lsprpc/lsprpc_test.go
@@ -13,6 +13,8 @@ import (
"golang.org/x/tools/internal/jsonrpc2/servertest"
"golang.org/x/tools/internal/lsp/cache"
+ "golang.org/x/tools/internal/lsp/debug"
+ "golang.org/x/tools/internal/lsp/fake"
"golang.org/x/tools/internal/lsp/protocol"
"golang.org/x/tools/internal/telemetry/log"
)
@@ -42,7 +44,8 @@ func TestClientLogging(t *testing.T) {
server := pingServer{}
client := fakeClient{logs: make(chan string, 10)}
- ss := NewStreamServer(cache.New(nil, nil), false)
+ di := debug.NewInstance("", "")
+ ss := NewStreamServer(cache.New(nil, di.State), false, di)
ss.serverForTest = server
ts := servertest.NewPipeServer(ctx, ss)
cc := ts.Connect(ctx)
@@ -92,12 +95,13 @@ func TestRequestCancellation(t *testing.T) {
server := waitableServer{
started: make(chan struct{}),
}
- ss := NewStreamServer(cache.New(nil, nil), false)
+ diserve := debug.NewInstance("", "")
+ ss := NewStreamServer(cache.New(nil, diserve.State), false, diserve)
ss.serverForTest = server
ctx := context.Background()
tsDirect := servertest.NewTCPServer(ctx, ss)
- forwarder := NewForwarder("tcp", tsDirect.Addr, false)
+ forwarder := NewForwarder("tcp", tsDirect.Addr, false, debug.NewInstance("", ""))
tsForwarded := servertest.NewPipeServer(ctx, forwarder)
tests := []struct {
@@ -158,4 +162,59 @@ func main() {
fmt.Println("Hello World.")
}`
+func TestDebugInfoLifecycle(t *testing.T) {
+ resetExitFuncs := OverrideExitFuncsForTest()
+ defer resetExitFuncs()
+
+ clientDebug := debug.NewInstance("", "")
+ serverDebug := debug.NewInstance("", "")
+
+ cache := cache.New(nil, serverDebug.State)
+ ss := NewStreamServer(cache, false, serverDebug)
+ ctx := context.Background()
+ tsBackend := servertest.NewTCPServer(ctx, ss)
+
+ forwarder := NewForwarder("tcp", tsBackend.Addr, false, clientDebug)
+ tsForwarder := servertest.NewPipeServer(ctx, forwarder)
+
+ ws, err := fake.NewWorkspace("gopls-lsprpc-test", []byte(exampleProgram))
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ws.Close()
+
+ conn1 := tsForwarder.Connect(ctx)
+ ed1, err := fake.NewConnectedEditor(ctx, ws, conn1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ed1.Shutdown(ctx)
+ conn2 := tsBackend.Connect(ctx)
+ ed2, err := fake.NewConnectedEditor(ctx, ws, conn2)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ed2.Shutdown(ctx)
+
+ if got, want := len(serverDebug.State.Clients()), 2; got != want {
+ t.Errorf("len(server:Clients) = %d, want %d", got, want)
+ }
+ if got, want := len(serverDebug.State.Sessions()), 2; got != want {
+ t.Errorf("len(server:Sessions) = %d, want %d", got, want)
+ }
+ if got, want := len(clientDebug.State.Servers()), 1; got != want {
+ t.Errorf("len(client:Servers) = %d, want %d", got, want)
+ }
+ // Close one of the connections to verify that the client and session were
+ // dropped.
+ if err := ed1.Shutdown(ctx); err != nil {
+ t.Fatal(err)
+ }
+ if got, want := len(serverDebug.State.Sessions()), 1; got != want {
+ t.Errorf("len(server:Sessions()) = %d, want %d", got, want)
+ }
+ // TODO(rfindley): once disconnection works, assert that len(Clients) == 1
+ // (as of writing, it is still 2)
+}
+
// TODO: add a test for telemetry.
diff --git a/internal/lsp/regtest/env.go b/internal/lsp/regtest/env.go
index b1a042019..036edd9bb 100644
--- a/internal/lsp/regtest/env.go
+++ b/internal/lsp/regtest/env.go
@@ -20,6 +20,7 @@ import (
"golang.org/x/tools/internal/jsonrpc2/servertest"
"golang.org/x/tools/internal/lsp/cache"
+ "golang.org/x/tools/internal/lsp/debug"
"golang.org/x/tools/internal/lsp/fake"
"golang.org/x/tools/internal/lsp/lsprpc"
"golang.org/x/tools/internal/lsp/protocol"
@@ -79,7 +80,8 @@ func (r *Runner) getTestServer() *servertest.TCPServer {
r.mu.Lock()
defer r.mu.Unlock()
if r.ts == nil {
- ss := lsprpc.NewStreamServer(cache.New(nil, nil), false)
+ di := debug.NewInstance("", "")
+ ss := lsprpc.NewStreamServer(cache.New(nil, di.State), false, di)
r.ts = servertest.NewTCPServer(context.Background(), ss)
}
return r.ts
@@ -184,7 +186,8 @@ func (r *Runner) RunInMode(modes EnvMode, t *testing.T, filedata string, test fu
}
func (r *Runner) singletonEnv(ctx context.Context, t *testing.T) (servertest.Connector, func()) {
- ss := lsprpc.NewStreamServer(cache.New(nil, nil), false)
+ di := debug.NewInstance("", "")
+ ss := lsprpc.NewStreamServer(cache.New(nil, di.State), false, di)
ts := servertest.NewPipeServer(ctx, ss)
cleanup := func() {
ts.Close()
@@ -198,7 +201,7 @@ func (r *Runner) sharedEnv(ctx context.Context, t *testing.T) (servertest.Connec
func (r *Runner) forwardedEnv(ctx context.Context, t *testing.T) (servertest.Connector, func()) {
ts := r.getTestServer()
- forwarder := lsprpc.NewForwarder("tcp", ts.Addr, false)
+ forwarder := lsprpc.NewForwarder("tcp", ts.Addr, false, debug.NewInstance("", ""))
ts2 := servertest.NewPipeServer(ctx, forwarder)
cleanup := func() {
ts2.Close()
@@ -208,7 +211,7 @@ func (r *Runner) forwardedEnv(ctx context.Context, t *testing.T) (servertest.Con
func (r *Runner) separateProcessEnv(ctx context.Context, t *testing.T) (servertest.Connector, func()) {
socket := r.getRemoteSocket(t)
- forwarder := lsprpc.NewForwarder("unix", socket, false)
+ forwarder := lsprpc.NewForwarder("unix", socket, false, debug.NewInstance("", ""))
ts2 := servertest.NewPipeServer(ctx, forwarder)
cleanup := func() {
ts2.Close()
diff --git a/internal/lsp/regtest/reg_test.go b/internal/lsp/regtest/reg_test.go
index 542b7fa45..2c57119f8 100644
--- a/internal/lsp/regtest/reg_test.go
+++ b/internal/lsp/regtest/reg_test.go
@@ -13,7 +13,6 @@ import (
"testing"
"time"
- "golang.org/x/tools/internal/lsp"
"golang.org/x/tools/internal/lsp/cmd"
"golang.org/x/tools/internal/lsp/lsprpc"
"golang.org/x/tools/internal/tool"
@@ -32,17 +31,8 @@ func TestMain(m *testing.M) {
tool.Main(context.Background(), cmd.New("gopls", "", nil, nil), os.Args[1:])
os.Exit(0)
}
- // Override functions that would shut down the test process
- defer func(lspExit, forwarderExit func(code int)) {
- lsp.ServerExitFunc = lspExit
- lsprpc.ForwarderExitFunc = forwarderExit
- }(lsp.ServerExitFunc, lsprpc.ForwarderExitFunc)
- // None of these regtests should be able to shut down a server process.
- lsp.ServerExitFunc = func(code int) {
- panic(fmt.Sprintf("LSP server exited with code %d", code))
- }
- // We don't want our forwarders to exit, but it's OK if they would have.
- lsprpc.ForwarderExitFunc = func(code int) {}
+ resetExitFuncs := lsprpc.OverrideExitFuncsForTest()
+ defer resetExitFuncs()
const testTimeout = 60 * time.Second
if *runSubprocessTests {