Diffstat (limited to 'internal/jsonrpc2_v2')
 internal/jsonrpc2_v2/conn.go          | 954
 internal/jsonrpc2_v2/frame.go         |  22
 internal/jsonrpc2_v2/jsonrpc2.go      |  12
 internal/jsonrpc2_v2/jsonrpc2_test.go |  55
 internal/jsonrpc2_v2/messages.go      |  12
 internal/jsonrpc2_v2/net.go           |  35
 internal/jsonrpc2_v2/serve.go         | 367
 internal/jsonrpc2_v2/serve_go116.go   |  19
 internal/jsonrpc2_v2/serve_pre116.go  |  30
 internal/jsonrpc2_v2/serve_test.go    | 284
 internal/jsonrpc2_v2/wire.go          |  12
11 files changed, 1224 insertions(+), 578 deletions(-)
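The patch below reworks the jsonrpc2_v2 API surface: Binder.Bind no longer returns an error, Serve is renamed to NewServer, and BinderFunc/PreempterFunc adapters are added. As a rough client-side sketch of the resulting API (not part of the patch; the "tcp" address "localhost:8888" and the "ping" method are placeholders for illustration):

package main

import (
	"context"
	"log"
	"net"

	jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2"
)

func main() {
	ctx := context.Background()

	// BinderFunc (added by this patch) adapts a plain function to the Binder
	// interface; note that Bind now returns ConnectionOptions without an error.
	binder := jsonrpc2.BinderFunc(func(ctx context.Context, conn *jsonrpc2.Connection) jsonrpc2.ConnectionOptions {
		return jsonrpc2.ConnectionOptions{
			Handler: jsonrpc2.HandlerFunc(func(ctx context.Context, req *jsonrpc2.Request) (interface{}, error) {
				// Unhandled methods are reported back as ErrMethodNotFound.
				return nil, jsonrpc2.ErrNotHandled
			}),
		}
	})

	// "localhost:8888" is a placeholder address for this sketch.
	conn, err := jsonrpc2.Dial(ctx, jsonrpc2.NetDialer("tcp", "localhost:8888", net.Dialer{}), binder)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	// Call returns an *AsyncCall immediately; Await blocks for the response.
	var result interface{}
	if err := conn.Call(ctx, "ping", nil).Await(ctx, &result); err != nil {
		log.Fatal(err)
	}
}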
diff --git a/internal/jsonrpc2_v2/conn.go b/internal/jsonrpc2_v2/conn.go
index 018175e88..04d1445cc 100644
--- a/internal/jsonrpc2_v2/conn.go
+++ b/internal/jsonrpc2_v2/conn.go
@@ -7,14 +7,17 @@ package jsonrpc2
 import (
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
+	"sync"
 	"sync/atomic"
+	"time"
 
 	"golang.org/x/tools/internal/event"
+	"golang.org/x/tools/internal/event/keys"
 	"golang.org/x/tools/internal/event/label"
-	"golang.org/x/tools/internal/lsp/debug/tag"
-	errors "golang.org/x/xerrors"
+	"golang.org/x/tools/internal/event/tag"
 )
 
 // Binder builds a connection configuration.
@@ -24,10 +27,21 @@ import (
 type Binder interface {
 	// Bind returns the ConnectionOptions to use when establishing the passed-in
 	// Connection.
-	// The connection is not ready to use when Bind is called.
-	Bind(context.Context, *Connection) (ConnectionOptions, error)
+	//
+	// The connection is not ready to use when Bind is called,
+	// but Bind may close it without reading or writing to it.
+	Bind(context.Context, *Connection) ConnectionOptions
 }
 
+// A BinderFunc implements the Binder interface for a standalone Bind function.
+type BinderFunc func(context.Context, *Connection) ConnectionOptions
+
+func (f BinderFunc) Bind(ctx context.Context, c *Connection) ConnectionOptions {
+	return f(ctx, c)
+}
+
+var _ Binder = BinderFunc(nil)
+
 // ConnectionOptions holds the options for new connections.
 type ConnectionOptions struct {
 	// Framer allows control over the message framing and encoding.
@@ -39,6 +53,10 @@ type ConnectionOptions struct {
 	// Handler is used as the queued message handler for inbound messages.
 	// If nil, all responses will be ErrNotHandled.
 	Handler Handler
+	// OnInternalError, if non-nil, is called with any internal errors that occur
+	// while serving the connection, such as protocol errors or invariant
+	// violations. (If nil, internal errors result in panics.)
+	OnInternalError func(error)
 }
 
 // Connection manages the jsonrpc2 protocol, connecting responses back to their
@@ -46,102 +64,244 @@ type ConnectionOptions struct {
 // Connection is bidirectional; it does not have a designated server or client
 // end.
 type Connection struct {
-	seq         int64 // must only be accessed using atomic operations
-	closer      io.Closer
-	writerBox   chan Writer
-	outgoingBox chan map[ID]chan<- *Response
-	incomingBox chan map[ID]*incoming
-	async       *async
+	seq int64 // must only be accessed using atomic operations
+
+	stateMu sync.Mutex
+	state   inFlightState // accessed only in updateInFlight
+	done    chan struct{} // closed (under stateMu) when state.closed is true and all goroutines have completed
+
+	writer chan Writer // 1-buffered; stores the writer when not in use
+
+	handler Handler
+
+	onInternalError func(error)
+	onDone          func()
 }
 
-type AsyncCall struct {
-	id        ID
-	response  chan *Response // the channel a response will be delivered on
-	resultBox chan asyncResult
-	endSpan   func() // close the tracing span when all processing for the message is complete
+// inFlightState records the state of the incoming and outgoing calls on a
+// Connection.
+type inFlightState struct {
+	connClosing bool  // true when the Connection's Close method has been called
+	reading     bool  // true while the readIncoming goroutine is running
+	readErr     error // non-nil when the readIncoming goroutine exits (typically io.EOF)
+	writeErr    error // non-nil if a call to the Writer has failed with a non-canceled Context
+
+	// closer shuts down and cleans up the Reader and Writer state, ideally
+	// interrupting any Read or Write call that is currently blocked. It is closed
+	// when the state is idle and one of: connClosing is true, readErr is non-nil,
+	// or writeErr is non-nil.
+	//
+	// After the closer has been invoked, the closer field is set to nil
+	// and the closeErr field is simultaneously set to its result.
+	closer   io.Closer
+	closeErr error // error returned from closer.Close
+
+	outgoingCalls         map[ID]*AsyncCall // calls only
+	outgoingNotifications int               // # of notifications awaiting "write"
+
+	// incoming stores the total number of incoming calls and notifications
+	// that have not yet written or processed a result.
+	incoming int
+
+	incomingByID map[ID]*incomingRequest // calls only
+
+	// handlerQueue stores the backlog of calls and notifications that were not
+	// already handled by a preempter.
+	// The queue does not include the request currently being handled (if any).
+	handlerQueue   []*incomingRequest
+	handlerRunning bool
+}
+
+// updateInFlight locks the state of the connection's in-flight requests, allows
+// f to mutate that state, and closes the connection if it is idle and either
+// is closing or has a read or write error.
+func (c *Connection) updateInFlight(f func(*inFlightState)) {
+	c.stateMu.Lock()
+	defer c.stateMu.Unlock()
+
+	s := &c.state
+
+	f(s)
+
+	select {
+	case <-c.done:
+		// The connection was already completely done at the start of this call to
+		// updateInFlight, so it must remain so. (The call to f should have noticed
+		// that and avoided making any updates that would cause the state to be
+		// non-idle.)
+		if !s.idle() {
+			panic("jsonrpc2_v2: updateInFlight transitioned to non-idle when already done")
+		}
+		return
+	default:
+	}
+
+	if s.idle() && s.shuttingDown(ErrUnknown) != nil {
+		if s.closer != nil {
+			s.closeErr = s.closer.Close()
+			s.closer = nil // prevent duplicate Close calls
+		}
+		if s.reading {
+			// The readIncoming goroutine is still running. Our call to Close should
+			// cause it to exit soon, at which point it will make another call to
+			// updateInFlight, set s.reading to false, and mark the Connection done.
+		} else {
+			// The readIncoming goroutine has exited, or never started to begin with.
+			// Since everything else is idle, we're completely done.
+			if c.onDone != nil {
+				c.onDone()
+			}
+			close(c.done)
+		}
+	}
 }
 
-type asyncResult struct {
-	result []byte
-	err    error
+// idle reports whether the connection is in a state with no pending calls or
+// notifications.
+//
+// If idle returns true, the readIncoming goroutine may still be running,
+// but no other goroutines are doing work on behalf of the connection.
+func (s *inFlightState) idle() bool {
+	return len(s.outgoingCalls) == 0 && s.outgoingNotifications == 0 && s.incoming == 0 && !s.handlerRunning
 }
 
-// incoming is used to track an incoming request as it is being handled
-type incoming struct {
-	request   *Request        // the request being processed
-	baseCtx   context.Context // a base context for the message processing
-	done      func()          // a function called when all processing for the message is complete
-	handleCtx context.Context // the context for handling the message, child of baseCtx
-	cancel    func()          // a function that cancels the handling context
+// shuttingDown reports whether the connection is in a state that should
+// disallow new (incoming and outgoing) calls. It returns either nil or
+// an error that is or wraps the provided errClosing.
+func (s *inFlightState) shuttingDown(errClosing error) error {
+	if s.connClosing {
+		// If Close has been called explicitly, it doesn't matter what state the
+		// Reader and Writer are in: we shouldn't be starting new work because the
+		// caller told us not to start new work.
+		return errClosing
+	}
+	if s.readErr != nil {
+		// If the read side of the connection is broken, we cannot read new call
+		// requests, and cannot read responses to our outgoing calls.
+		return fmt.Errorf("%w: %v", errClosing, s.readErr)
+	}
+	if s.writeErr != nil {
+		// If the write side of the connection is broken, we cannot write responses
+		// for incoming calls, and cannot write requests for outgoing calls.
+		return fmt.Errorf("%w: %v", errClosing, s.writeErr)
+	}
+	return nil
+}
+
+// incomingRequest is used to track an incoming request as it is being handled
+type incomingRequest struct {
+	*Request // the request being processed
+	ctx      context.Context
+	cancel   context.CancelFunc
+	endSpan  func() // called (and set to nil) when the response is sent
 }
 
 // Bind returns the options unmodified.
-func (o ConnectionOptions) Bind(context.Context, *Connection) (ConnectionOptions, error) {
-	return o, nil
+func (o ConnectionOptions) Bind(context.Context, *Connection) ConnectionOptions {
+	return o
 }
 
 // newConnection creates a new connection and runs it.
+//
 // This is used by the Dial and Serve functions to build the actual connection.
-func newConnection(ctx context.Context, rwc io.ReadWriteCloser, binder Binder) (*Connection, error) {
+//
+// The connection is closed automatically (and its resources cleaned up) when
+// the last request has completed after the underlying ReadWriteCloser breaks,
+// but it may be stopped earlier by calling Close (for a clean shutdown).
+func newConnection(bindCtx context.Context, rwc io.ReadWriteCloser, binder Binder, onDone func()) *Connection {
+	// TODO: Should we create a new event span here?
+	// This will propagate cancellation from ctx; should it?
+	ctx := notDone{bindCtx}
+
 	c := &Connection{
-		closer:      rwc,
-		writerBox:   make(chan Writer, 1),
-		outgoingBox: make(chan map[ID]chan<- *Response, 1),
-		incomingBox: make(chan map[ID]*incoming, 1),
-		async:       newAsync(),
+		state:  inFlightState{closer: rwc},
+		done:   make(chan struct{}),
+		writer: make(chan Writer, 1),
+		onDone: onDone,
 	}
+	// It's tempting to set a finalizer on c to verify that the state has gone
+	// idle when the connection becomes unreachable. Unfortunately, the Binder
+	// interface makes that unsafe: it allows the Handler to close over the
+	// Connection, which could create a reference cycle that would cause the
+	// Connection to become uncollectable.
-	options, err := binder.Bind(ctx, c)
-	if err != nil {
-		return nil, err
-	}
-	if options.Framer == nil {
-		options.Framer = HeaderFramer()
-	}
-	if options.Preempter == nil {
-		options.Preempter = defaultHandler{}
-	}
-	if options.Handler == nil {
-		options.Handler = defaultHandler{}
-	}
-	c.outgoingBox <- make(map[ID]chan<- *Response)
-	c.incomingBox <- make(map[ID]*incoming)
-	// the goroutines started here will continue until the underlying stream is closed
-	reader := options.Framer.Reader(rwc)
-	readToQueue := make(chan *incoming)
-	queueToDeliver := make(chan *incoming)
-	go c.readIncoming(ctx, reader, readToQueue)
-	go c.manageQueue(ctx, options.Preempter, readToQueue, queueToDeliver)
-	go c.deliverMessages(ctx, options.Handler, queueToDeliver)
-
-	// releaseing the writer must be the last thing we do in case any requests
-	// are blocked waiting for the connection to be ready
-	c.writerBox <- options.Framer.Writer(rwc)
-	return c, nil
+	options := binder.Bind(bindCtx, c)
+	framer := options.Framer
+	if framer == nil {
+		framer = HeaderFramer()
+	}
+	c.handler = options.Handler
+	if c.handler == nil {
+		c.handler = defaultHandler{}
+	}
+	c.onInternalError = options.OnInternalError
+
+	c.writer <- framer.Writer(rwc)
+	reader := framer.Reader(rwc)
+
+	c.updateInFlight(func(s *inFlightState) {
+		select {
+		case <-c.done:
+			// Bind already closed the connection; don't start a goroutine to read it.
+			return
+		default:
+		}
+
+		// The goroutine started here will continue until the underlying stream is closed.
+		//
+		// (If the Binder closed the Connection already, this should error out and
+		// return almost immediately.)
+		s.reading = true
+		go c.readIncoming(ctx, reader, options.Preempter)
+	})
+	return c
 }
 
 // Notify invokes the target method but does not wait for a response.
 // The params will be marshaled to JSON before sending over the wire, and will
 // be handed to the method invoked.
-func (c *Connection) Notify(ctx context.Context, method string, params interface{}) error {
-	notify, err := NewNotification(method, params)
-	if err != nil {
-		return errors.Errorf("marshaling notify parameters: %v", err)
-	}
+func (c *Connection) Notify(ctx context.Context, method string, params interface{}) (err error) {
 	ctx, done := event.Start(ctx, method,
 		tag.Method.Of(method),
 		tag.RPCDirection.Of(tag.Outbound),
 	)
-	event.Metric(ctx, tag.Started.Of(1))
-	err = c.write(ctx, notify)
-	switch {
-	case err != nil:
-		event.Label(ctx, tag.StatusCode.Of("ERROR"))
-	default:
-		event.Label(ctx, tag.StatusCode.Of("OK"))
+	attempted := false
+
+	defer func() {
+		labelStatus(ctx, err)
+		done()
+		if attempted {
+			c.updateInFlight(func(s *inFlightState) {
+				s.outgoingNotifications--
+			})
+		}
+	}()
+
+	c.updateInFlight(func(s *inFlightState) {
+		// If the connection is shutting down, allow outgoing notifications only if
+		// there is at least one call still in flight. The number of calls in flight
+		// cannot increase once shutdown begins, and allowing outgoing notifications
+		// may permit notifications that will cancel in-flight calls.
+		if len(s.outgoingCalls) == 0 && len(s.incomingByID) == 0 {
+			err = s.shuttingDown(ErrClientClosing)
+			if err != nil {
+				return
+			}
+		}
+		s.outgoingNotifications++
+		attempted = true
+	})
+	if err != nil {
+		return err
 	}
-	done()
-	return err
+
+	notify, err := NewNotification(method, params)
+	if err != nil {
+		return fmt.Errorf("marshaling notify parameters: %v", err)
+	}
+
+	event.Metric(ctx, tag.Started.Of(1))
+	return c.write(ctx, notify)
 }
 
 // Call invokes the target method and returns an object that can be used to await the response.
@@ -150,339 +310,503 @@ func (c *Connection) Notify(ctx context.Context, method string, params interface
 // You do not have to wait for the response, it can just be ignored if not needed.
 // If sending the call failed, the response will be ready and have the error in it.
 func (c *Connection) Call(ctx context.Context, method string, params interface{}) *AsyncCall {
-	result := &AsyncCall{
-		id:        Int64ID(atomic.AddInt64(&c.seq, 1)),
-		resultBox: make(chan asyncResult, 1),
-	}
-	// generate a new request identifier
-	call, err := NewCall(result.id, method, params)
-	if err != nil {
-		//set the result to failed
-		result.resultBox <- asyncResult{err: errors.Errorf("marshaling call parameters: %w", err)}
-		return result
-	}
+	// Generate a new request identifier.
+	id := Int64ID(atomic.AddInt64(&c.seq, 1))
 	ctx, endSpan := event.Start(ctx, method,
 		tag.Method.Of(method),
 		tag.RPCDirection.Of(tag.Outbound),
-		tag.RPCID.Of(fmt.Sprintf("%q", result.id)),
+		tag.RPCID.Of(fmt.Sprintf("%q", id)),
 	)
-	result.endSpan = endSpan
+
+	ac := &AsyncCall{
+		id:      id,
+		ready:   make(chan struct{}),
+		ctx:     ctx,
+		endSpan: endSpan,
+	}
+	// When this method returns, either ac is retired, or the request has been
+	// written successfully and the call is awaiting a response (to be provided by
+	// the readIncoming goroutine).
+
+	call, err := NewCall(ac.id, method, params)
+	if err != nil {
+		ac.retire(&Response{ID: id, Error: fmt.Errorf("marshaling call parameters: %w", err)})
+		return ac
+	}
+
+	c.updateInFlight(func(s *inFlightState) {
+		err = s.shuttingDown(ErrClientClosing)
+		if err != nil {
+			return
+		}
+		if s.outgoingCalls == nil {
+			s.outgoingCalls = make(map[ID]*AsyncCall)
+		}
+		s.outgoingCalls[ac.id] = ac
+	})
+	if err != nil {
+		ac.retire(&Response{ID: id, Error: err})
+		return ac
+	}
+
 	event.Metric(ctx, tag.Started.Of(1))
-	// We have to add ourselves to the pending map before we send, otherwise we
-	// are racing the response.
-	// rchan is buffered in case the response arrives without a listener.
-	result.response = make(chan *Response, 1)
-	pending := <-c.outgoingBox
-	pending[result.id] = result.response
-	c.outgoingBox <- pending
-	// now we are ready to send
 	if err := c.write(ctx, call); err != nil {
-		// sending failed, we will never get a response, so deliver a fake one
-		r, _ := NewResponse(result.id, nil, err)
-		c.incomingResponse(r)
+		// Sending failed. We will never get a response, so deliver a fake one if it
+		// wasn't already retired by the connection breaking.
+		c.updateInFlight(func(s *inFlightState) {
+			if s.outgoingCalls[ac.id] == ac {
+				delete(s.outgoingCalls, ac.id)
+				ac.retire(&Response{ID: id, Error: err})
+			} else {
+				// ac was already retired by the readIncoming goroutine:
+				// perhaps our write raced with the Read side of the connection breaking.
+			}
+		})
 	}
-	return result
+	return ac
+}
+
+type AsyncCall struct {
+	id       ID
+	ready    chan struct{} // closed after response has been set and span has been ended
+	response *Response
+	ctx      context.Context // for event logging only
+	endSpan  func()          // close the tracing span when all processing for the message is complete
 }
 
 // ID used for this call.
 // This can be used to cancel the call if needed.
-func (a *AsyncCall) ID() ID { return a.id }
+func (ac *AsyncCall) ID() ID { return ac.id }
 
 // IsReady can be used to check if the result is already prepared.
 // This is guaranteed to return true on a result for which Await has already
 // returned, or a call that failed to send in the first place.
-func (a *AsyncCall) IsReady() bool {
+func (ac *AsyncCall) IsReady() bool {
 	select {
-	case r := <-a.resultBox:
-		a.resultBox <- r
+	case <-ac.ready:
 		return true
 	default:
 		return false
 	}
 }
 
-// Await the results of a Call.
+// retire processes the response to the call.
+func (ac *AsyncCall) retire(response *Response) {
+	select {
+	case <-ac.ready:
+		panic(fmt.Sprintf("jsonrpc2: retire called twice for ID %v", ac.id))
+	default:
+	}
+
+	ac.response = response
+	labelStatus(ac.ctx, response.Error)
+	ac.endSpan()
+	// Allow the trace context, which may retain a lot of reachable values,
+	// to be garbage-collected.
+	ac.ctx, ac.endSpan = nil, nil
+
+	close(ac.ready)
+}
+
+// Await waits for (and decodes) the results of a Call.
 // The response will be unmarshaled from JSON into the result.
-func (a *AsyncCall) Await(ctx context.Context, result interface{}) error {
-	defer a.endSpan()
-	var r asyncResult
+func (ac *AsyncCall) Await(ctx context.Context, result interface{}) error {
 	select {
-	case response := <-a.response:
-		// response just arrived, prepare the result
-		switch {
-		case response.Error != nil:
-			r.err = response.Error
-			event.Label(ctx, tag.StatusCode.Of("ERROR"))
-		default:
-			r.result = response.Result
-			event.Label(ctx, tag.StatusCode.Of("OK"))
-		}
-	case r = <-a.resultBox:
-		// result already available
 	case <-ctx.Done():
-		event.Label(ctx, tag.StatusCode.Of("CANCELLED"))
 		return ctx.Err()
+	case <-ac.ready:
 	}
-	// refill the box for the next caller
-	a.resultBox <- r
-	// and unpack the result
-	if r.err != nil {
-		return r.err
+	if ac.response.Error != nil {
+		return ac.response.Error
 	}
-	if result == nil || len(r.result) == 0 {
+	if result == nil {
 		return nil
 	}
-	return json.Unmarshal(r.result, result)
+	return json.Unmarshal(ac.response.Result, result)
 }
 
 // Respond delivers a response to an incoming Call.
 //
 // Respond must be called exactly once for any message for which a handler
 // returns ErrAsyncResponse. It must not be called for any other message.
-func (c *Connection) Respond(id ID, result interface{}, rerr error) error {
-	pending := <-c.incomingBox
-	defer func() { c.incomingBox <- pending }()
-	entry, found := pending[id]
-	if !found {
-		return nil
+func (c *Connection) Respond(id ID, result interface{}, err error) error {
+	var req *incomingRequest
+	c.updateInFlight(func(s *inFlightState) {
+		req = s.incomingByID[id]
+	})
+	if req == nil {
+		return c.internalErrorf("Request not found for ID %v", id)
+	}
+
+	if err == ErrAsyncResponse {
+		// Respond is supposed to supply the asynchronous response, so it would be
+		// confusing to call Respond with an error that promises to call Respond
+		// again.
+		err = c.internalErrorf("Respond called with ErrAsyncResponse for %q", req.Method)
 	}
-	delete(pending, id)
-	return c.respond(entry, result, rerr)
+	return c.processResult("Respond", req, result, err)
 }
 
-// Cancel is used to cancel an inbound message by ID, it does not cancel
-// outgoing messages.
-// This is only used inside a message handler that is layering a
-// cancellation protocol on top of JSON RPC 2.
-// It will not complain if the ID is not a currently active message, and it will
-// not cause any messages that have not arrived yet with that ID to be
+// Cancel cancels the Context passed to the Handle call for the inbound message
+// with the given ID.
+//
+// Cancel will not complain if the ID is not a currently active message, and it
+// will not cause any messages that have not arrived yet with that ID to be
 // cancelled.
 func (c *Connection) Cancel(id ID) {
-	pending := <-c.incomingBox
-	defer func() { c.incomingBox <- pending }()
-	if entry, found := pending[id]; found && entry.cancel != nil {
-		entry.cancel()
-		entry.cancel = nil
+	var req *incomingRequest
+	c.updateInFlight(func(s *inFlightState) {
+		req = s.incomingByID[id]
+	})
+	if req != nil {
+		req.cancel()
 	}
 }
 
 // Wait blocks until the connection is fully closed, but does not close it.
 func (c *Connection) Wait() error {
-	return c.async.wait()
+	var err error
+	<-c.done
+	c.updateInFlight(func(s *inFlightState) {
+		err = s.closeErr
+	})
+	return err
 }
 
-// Close can be used to close the underlying stream, and then wait for the connection to
-// fully shut down.
-// This does not cancel in flight requests, but waits for them to gracefully complete.
+// Close stops accepting new requests, waits for in-flight requests and enqueued
+// Handle calls to complete, and then closes the underlying stream.
+//
+// After the start of a Close, notification requests (that lack IDs and do not
+// receive responses) will continue to be passed to the Preempter, but calls
+// with IDs will receive immediate responses with ErrServerClosing, and no new
+// requests (not even notifications!) will be enqueued to the Handler.
 func (c *Connection) Close() error {
-	// close the underlying stream
-	if err := c.closer.Close(); err != nil && !isClosingError(err) {
-		return err
-	}
-	// and then wait for it to cause the connection to close
-	if err := c.Wait(); err != nil && !isClosingError(err) {
-		return err
-	}
-	return nil
+	// Stop handling new requests, and interrupt the reader (by closing the
+	// connection) as soon as the active requests finish.
+	c.updateInFlight(func(s *inFlightState) { s.connClosing = true })
+
+	return c.Wait()
 }
 
 // readIncoming collects inbound messages from the reader and delivers them, either responding
 // to outgoing calls or feeding requests to the queue.
-func (c *Connection) readIncoming(ctx context.Context, reader Reader, toQueue chan<- *incoming) {
-	defer close(toQueue)
+func (c *Connection) readIncoming(ctx context.Context, reader Reader, preempter Preempter) {
+	var err error
 	for {
-		// get the next message
-		// no lock is needed, this is the only reader
-		msg, n, err := reader.Read(ctx)
+		var (
+			msg Message
+			n   int64
+		)
+		msg, n, err = reader.Read(ctx)
 		if err != nil {
-			// The stream failed, we cannot continue
-			c.async.setError(err)
-			return
+			break
 		}
+
 		switch msg := msg.(type) {
 		case *Request:
-			entry := &incoming{
-				request: msg,
-			}
-			// add a span to the context for this request
-			labels := append(make([]label.Label, 0, 3), // make space for the id if present
-				tag.Method.Of(msg.Method),
-				tag.RPCDirection.Of(tag.Inbound),
-			)
-			if msg.IsCall() {
-				labels = append(labels, tag.RPCID.Of(fmt.Sprintf("%q", msg.ID)))
-			}
-			entry.baseCtx, entry.done = event.Start(ctx, msg.Method, labels...)
-			event.Metric(entry.baseCtx,
-				tag.Started.Of(1),
-				tag.ReceivedBytes.Of(n))
-			// in theory notifications cannot be cancelled, but we build them a cancel context anyway
-			entry.handleCtx, entry.cancel = context.WithCancel(entry.baseCtx)
-			// if the request is a call, add it to the incoming map so it can be
-			// cancelled by id
-			if msg.IsCall() {
-				pending := <-c.incomingBox
-				pending[msg.ID] = entry
-				c.incomingBox <- pending
-			}
-			// send the message to the incoming queue
-			toQueue <- entry
+			c.acceptRequest(ctx, msg, n, preempter)
+
 		case *Response:
-			// If method is not set, this should be a response, in which case we must
-			// have an id to send the response back to the caller.
-			c.incomingResponse(msg)
+			c.updateInFlight(func(s *inFlightState) {
+				if ac, ok := s.outgoingCalls[msg.ID]; ok {
+					delete(s.outgoingCalls, msg.ID)
+					ac.retire(msg)
+				} else {
+					// TODO: How should we report unexpected responses?
+				}
+			})
+
+		default:
+			c.internalErrorf("Read returned an unexpected message of type %T", msg)
 		}
 	}
+
+	c.updateInFlight(func(s *inFlightState) {
+		s.reading = false
+		s.readErr = err
+
+		// Retire any outgoing requests that were still in flight: with the Reader no
+		// longer being processed, they necessarily cannot receive a response.
+		for id, ac := range s.outgoingCalls {
+			ac.retire(&Response{ID: id, Error: err})
+		}
+		s.outgoingCalls = nil
+	})
 }
 
-func (c *Connection) incomingResponse(msg *Response) {
-	pending := <-c.outgoingBox
-	response, ok := pending[msg.ID]
-	if ok {
-		delete(pending, msg.ID)
+// acceptRequest either handles msg synchronously or enqueues it to be handled
+// asynchronously.
+func (c *Connection) acceptRequest(ctx context.Context, msg *Request, msgBytes int64, preempter Preempter) {
+	// Add a span to the context for this request.
+	labels := append(make([]label.Label, 0, 3), // Make space for the ID if present.
+		tag.Method.Of(msg.Method),
+		tag.RPCDirection.Of(tag.Inbound),
+	)
+	if msg.IsCall() {
+		labels = append(labels, tag.RPCID.Of(fmt.Sprintf("%q", msg.ID)))
 	}
-	c.outgoingBox <- pending
-	if response != nil {
-		response <- msg
+	ctx, endSpan := event.Start(ctx, msg.Method, labels...)
+	event.Metric(ctx,
+		tag.Started.Of(1),
+		tag.ReceivedBytes.Of(msgBytes))
+
+	// In theory notifications cannot be cancelled, but we build them a cancel
+	// context anyway.
+	ctx, cancel := context.WithCancel(ctx)
+	req := &incomingRequest{
+		Request: msg,
+		ctx:     ctx,
+		cancel:  cancel,
+		endSpan: endSpan,
+	}
-}
 
-// manageQueue reads incoming requests, attempts to process them with the preempter, or queue them
-// up for normal handling.
-func (c *Connection) manageQueue(ctx context.Context, preempter Preempter, fromRead <-chan *incoming, toDeliver chan<- *incoming) {
-	defer close(toDeliver)
-	q := []*incoming{}
-	ok := true
-	for {
-		var nextReq *incoming
-		if len(q) == 0 {
-			// no messages in the queue
-			// if we were closing, then we are done
-			if !ok {
+	// If the request is a call, add it to the incoming map so it can be
+	// cancelled (or responded) by ID.
+	var err error
+	c.updateInFlight(func(s *inFlightState) {
+		s.incoming++
+
+		if req.IsCall() {
+			if s.incomingByID[req.ID] != nil {
+				err = fmt.Errorf("%w: request ID %v already in use", ErrInvalidRequest, req.ID)
+				req.ID = ID{} // Don't misattribute this error to the existing request.
 				return
 			}
-			// not closing, but nothing in the queue, so just block waiting for a read
-			nextReq, ok = <-fromRead
-		} else {
-			// we have a non empty queue, so pick whichever of reading or delivering
-			// that we can make progress on
-			select {
-			case nextReq, ok = <-fromRead:
-			case toDeliver <- q[0]:
-				//TODO: this causes a lot of shuffling, should we use a growing ring buffer? compaction?
-				q = q[1:]
+
+			if s.incomingByID == nil {
+				s.incomingByID = make(map[ID]*incomingRequest)
 			}
+			s.incomingByID[req.ID] = req
+
+			// When shutting down, reject all new Call requests, even if they could
+			// theoretically be handled by the preempter. The preempter could return
+			// ErrAsyncResponse, which would increase the amount of work in flight
+			// when we're trying to ensure that it strictly decreases.
+			err = s.shuttingDown(ErrServerClosing)
 		}
-		if nextReq != nil {
-			// TODO: should we allow to limit the queue size?
-			var result interface{}
-			rerr := nextReq.handleCtx.Err()
-			if rerr == nil {
-				// only preempt if not already cancelled
-				result, rerr = preempter.Preempt(nextReq.handleCtx, nextReq.request)
-			}
-			switch {
-			case rerr == ErrNotHandled:
-				// message not handled, add it to the queue for the main handler
-				q = append(q, nextReq)
-			case rerr == ErrAsyncResponse:
-				// message handled but the response will come later
-			default:
-				// anything else means the message is fully handled
-				c.reply(nextReq, result, rerr)
-			}
+	})
+	if err != nil {
+		c.processResult("acceptRequest", req, nil, err)
+		return
+	}
+
+	if preempter != nil {
+		result, err := preempter.Preempt(req.ctx, req.Request)
+
+		if req.IsCall() && errors.Is(err, ErrAsyncResponse) {
+			// This request will remain in flight until Respond is called for it.
+			return
+		}
+
+		if !errors.Is(err, ErrNotHandled) {
+			c.processResult("Preempt", req, result, err)
+			return
+		}
+	}
+
+	c.updateInFlight(func(s *inFlightState) {
+		// If the connection is shutting down, don't enqueue anything to the
+		// handler — not even notifications. That ensures that if the handler
+		// continues to make progress, it will eventually become idle and
+		// close the connection.
+		err = s.shuttingDown(ErrServerClosing)
+		if err != nil {
+			return
 		}
+
+		// We enqueue requests that have not been preempted to an unbounded slice.
+		// Unfortunately, we cannot in general limit the size of the handler
+		// queue: we have to read every response that comes in on the wire
+		// (because it may be responding to a request issued by, say, an
+		// asynchronous handler), and in order to get to that response we have
+		// to read all of the requests that came in ahead of it.
+		s.handlerQueue = append(s.handlerQueue, req)
+		if !s.handlerRunning {
+			// We start the handleAsync goroutine when it has work to do, and let it
+			// exit when the queue empties.
+			//
+			// Otherwise, in order to synchronize the handler we would need some other
+			// goroutine (probably readIncoming?) to explicitly wait for handleAsync
+			// to finish, and that would complicate error reporting: either the error
+			// report from the goroutine would be blocked on the handler emptying its
+			// queue (which was tried, and introduced a deadlock detected by
+			// TestCloseCallRace), or the error would need to be reported separately
+			// from synchronizing completion. Allowing the handler goroutine to exit
+			// when idle seems simpler than trying to implement either of those
+			// alternatives correctly.
+			s.handlerRunning = true
+			go c.handleAsync()
+		}
+	})
+	if err != nil {
+		c.processResult("acceptRequest", req, nil, err)
 	}
 }
 
-func (c *Connection) deliverMessages(ctx context.Context, handler Handler, fromQueue <-chan *incoming) {
-	defer c.async.done()
-	for entry := range fromQueue {
-		// cancel any messages in the queue that we have a pending cancel for
-		var result interface{}
-		rerr := entry.handleCtx.Err()
-		if rerr == nil {
-			// only deliver if not already cancelled
-			result, rerr = handler.Handle(entry.handleCtx, entry.request)
+// handleAsync invokes the handler on the requests in the handler queue
+// sequentially until the queue is empty.
+func (c *Connection) handleAsync() {
+	for {
+		var req *incomingRequest
+		c.updateInFlight(func(s *inFlightState) {
+			if len(s.handlerQueue) > 0 {
+				req, s.handlerQueue = s.handlerQueue[0], s.handlerQueue[1:]
+			} else {
+				s.handlerRunning = false
+			}
+		})
+		if req == nil {
+			return
 		}
-		switch {
-		case rerr == ErrNotHandled:
-			// message not handled, report it back to the caller as an error
-			c.reply(entry, nil, errors.Errorf("%w: %q", ErrMethodNotFound, entry.request.Method))
-		case rerr == ErrAsyncResponse:
-			// message handled but the response will come later
-		default:
-			c.reply(entry, result, rerr)
+
+		// Only deliver to the Handler if not already canceled.
+		if err := req.ctx.Err(); err != nil {
+			c.updateInFlight(func(s *inFlightState) {
+				if s.writeErr != nil {
+					// Assume that req.ctx was canceled due to s.writeErr.
+					// TODO(#51365): use a Context API to plumb this through req.ctx.
+					err = fmt.Errorf("%w: %v", ErrServerClosing, s.writeErr)
+				}
+			})
+			c.processResult("handleAsync", req, nil, err)
+			continue
 		}
+
+		result, err := c.handler.Handle(req.ctx, req.Request)
+		c.processResult(c.handler, req, result, err)
 	}
 }
 
-// reply is used to reply to an incoming request that has just been handled
-func (c *Connection) reply(entry *incoming, result interface{}, rerr error) {
-	if entry.request.IsCall() {
-		// we have a call finishing, remove it from the incoming map
-		pending := <-c.incomingBox
-		defer func() { c.incomingBox <- pending }()
-		delete(pending, entry.request.ID)
+// processResult processes the result of a request and, if appropriate, sends a response.
+func (c *Connection) processResult(from interface{}, req *incomingRequest, result interface{}, err error) error {
+	switch err {
+	case ErrAsyncResponse:
+		if !req.IsCall() {
+			return c.internalErrorf("%#v returned ErrAsyncResponse for a %q Request without an ID", from, req.Method)
+		}
+		return nil // This request is still in flight, so don't record the result yet.
+	case ErrNotHandled, ErrMethodNotFound:
+		// Add detail describing the unhandled method.
+		err = fmt.Errorf("%w: %q", ErrMethodNotFound, req.Method)
 	}
-	if err := c.respond(entry, result, rerr); err != nil {
-		// no way to propagate this error
-		//TODO: should we do more than just log it?
-		event.Error(entry.baseCtx, "jsonrpc2 message delivery failed", err)
+
+	if req.endSpan == nil {
+		return c.internalErrorf("%#v produced a duplicate %q Response", from, req.Method)
 	}
-}
 
-// respond sends a response.
-// This is the code shared between reply and SendResponse.
-func (c *Connection) respond(entry *incoming, result interface{}, rerr error) error {
-	var err error
-	if entry.request.IsCall() {
-		// send the response
-		if result == nil && rerr == nil {
-			// call with no response, send an error anyway
-			rerr = errors.Errorf("%w: %q produced no response", ErrInternal, entry.request.Method)
+	if result != nil && err != nil {
+		c.internalErrorf("%#v returned a non-nil result with a non-nil error for %s:\n%v\n%#v", from, req.Method, err, result)
+		result = nil // Discard the spurious result and respond with err.
+	}
+
+	if req.IsCall() {
+		if result == nil && err == nil {
+			err = c.internalErrorf("%#v returned a nil result and nil error for a %q Request that requires a Response", from, req.Method)
 		}
-		var response *Response
-		response, err = NewResponse(entry.request.ID, result, rerr)
-		if err == nil {
-			// we write the response with the base context, in case the message was cancelled
-			err = c.write(entry.baseCtx, response)
+
+		response, respErr := NewResponse(req.ID, result, err)
+
+		// The caller could theoretically reuse the request's ID as soon as we've
+		// sent the response, so ensure that it is removed from the incoming map
+		// before sending.
+		c.updateInFlight(func(s *inFlightState) {
+			delete(s.incomingByID, req.ID)
+		})
+		if respErr == nil {
+			writeErr := c.write(notDone{req.ctx}, response)
+			if err == nil {
+				err = writeErr
+			}
+		} else {
+			err = c.internalErrorf("%#v returned a malformed result for %q: %w", from, req.Method, respErr)
 		}
-	} else {
-		switch {
-		case rerr != nil:
-			// notification failed
-			err = errors.Errorf("%w: %q notification failed: %v", ErrInternal, entry.request.Method, rerr)
-			rerr = nil
-		case result != nil:
-			//notification produced a response, which is an error
-			err = errors.Errorf("%w: %q produced unwanted response", ErrInternal, entry.request.Method)
-		default:
-			// normal notification finish
+	} else { // req is a notification
+		if result != nil {
+			err = c.internalErrorf("%#v returned a non-nil result for a %q Request without an ID", from, req.Method)
+		} else if err != nil {
+			err = fmt.Errorf("%w: %q notification failed: %v", ErrInternal, req.Method, err)
+		}
+		if err != nil {
+			// TODO: can/should we do anything with this error beyond writing it to the event log?
+			// (Is this the right label to attach to the log?)
+			event.Label(req.ctx, keys.Err.Of(err))
 		}
 	}
-	switch {
-	case rerr != nil || err != nil:
-		event.Label(entry.baseCtx, tag.StatusCode.Of("ERROR"))
-	default:
-		event.Label(entry.baseCtx, tag.StatusCode.Of("OK"))
-	}
-	// and just to be clean, invoke and clear the cancel if needed
-	if entry.cancel != nil {
-		entry.cancel()
-		entry.cancel = nil
-	}
-	// mark the entire request processing as done
-	entry.done()
-	return err
+
+	labelStatus(req.ctx, err)
+
+	// Cancel the request and finalize the event span to free any associated resources.
+	req.cancel()
+	req.endSpan()
+	req.endSpan = nil
+	c.updateInFlight(func(s *inFlightState) {
+		if s.incoming == 0 {
+			panic("jsonrpc2_v2: processResult called when incoming count is already zero")
+		}
+		s.incoming--
+	})
+	return nil
 }
 
 // write is used by all things that write outgoing messages, including replies.
 // it makes sure that writes are atomic
 func (c *Connection) write(ctx context.Context, msg Message) error {
-	writer := <-c.writerBox
-	defer func() { c.writerBox <- writer }()
+	writer := <-c.writer
+	defer func() { c.writer <- writer }()
 	n, err := writer.Write(ctx, msg)
 	event.Metric(ctx, tag.SentBytes.Of(n))
+
+	if err != nil && ctx.Err() == nil {
+		// The call to Write failed, and since ctx.Err() is nil we can't attribute
+		// the failure (even indirectly) to Context cancellation. The writer appears
+		// to be broken, and future writes are likely to also fail.
+		//
+		// If the read side of the connection is also broken, we might not even be
+		// able to receive cancellation notifications. Since we can't reliably write
+		// the results of incoming calls and can't receive explicit cancellations,
+		// cancel the calls now.
+		c.updateInFlight(func(s *inFlightState) {
+			if s.writeErr == nil {
+				s.writeErr = err
+				for _, r := range s.incomingByID {
+					r.cancel()
+				}
+			}
+		})
+	}
+
 	return err
 }
+
+// internalErrorf reports an internal error. By default it panics, but if
+// c.onInternalError is non-nil it instead calls that and returns an error
+// wrapping ErrInternal.
+func (c *Connection) internalErrorf(format string, args ...interface{}) error {
+	err := fmt.Errorf(format, args...)
+	if c.onInternalError == nil {
+		panic("jsonrpc2: " + err.Error())
+	}
+	c.onInternalError(err)
+
+	return fmt.Errorf("%w: %v", ErrInternal, err)
+}
+
+// labelStatus labels the status of the event in ctx based on whether err is nil.
+func labelStatus(ctx context.Context, err error) {
+	if err == nil {
+		event.Label(ctx, tag.StatusCode.Of("OK"))
+	} else {
+		event.Label(ctx, tag.StatusCode.Of("ERROR"))
+	}
+}
+
+// notDone is a context.Context wrapper that returns a nil Done channel.
+type notDone struct{ ctx context.Context }
+
+func (ic notDone) Value(key interface{}) interface{} {
+	return ic.ctx.Value(key)
+}
+
+func (notDone) Done() <-chan struct{}       { return nil }
+func (notDone) Err() error                  { return nil }
+func (notDone) Deadline() (time.Time, bool) { return time.Time{}, false }
diff --git a/internal/jsonrpc2_v2/frame.go b/internal/jsonrpc2_v2/frame.go
index 634717c73..e42483281 100644
--- a/internal/jsonrpc2_v2/frame.go
+++ b/internal/jsonrpc2_v2/frame.go
@@ -12,8 +12,6 @@ import (
 	"io"
 	"strconv"
 	"strings"
-
-	errors "golang.org/x/xerrors"
 )
 
 // Reader abstracts the transport mechanics from the JSON RPC protocol.
@@ -87,7 +85,7 @@ func (w *rawWriter) Write(ctx context.Context, msg Message) (int64, error) {
 	}
 	data, err := EncodeMessage(msg)
 	if err != nil {
-		return 0, errors.Errorf("marshaling message: %v", err)
+		return 0, fmt.Errorf("marshaling message: %v", err)
 	}
 	n, err := w.out.Write(data)
 	return int64(n), err
@@ -122,7 +120,13 @@ func (r *headerReader) Read(ctx context.Context) (Message, int64, error) {
 		line, err := r.in.ReadString('\n')
 		total += int64(len(line))
 		if err != nil {
-			return nil, total, errors.Errorf("failed reading header line: %w", err)
+			if err == io.EOF {
+				if total == 0 {
+					return nil, 0, io.EOF
+				}
+				err = io.ErrUnexpectedEOF
+			}
+			return nil, total, fmt.Errorf("failed reading header line: %w", err)
 		}
 		line = strings.TrimSpace(line)
 		// check we have a header line
@@ -131,23 +135,23 @@ func (r *headerReader) Read(ctx context.Context) (Message, int64, error) {
 		}
 		colon := strings.IndexRune(line, ':')
 		if colon < 0 {
-			return nil, total, errors.Errorf("invalid header line %q", line)
+			return nil, total, fmt.Errorf("invalid header line %q", line)
 		}
 		name, value := line[:colon], strings.TrimSpace(line[colon+1:])
 		switch name {
 		case "Content-Length":
 			if length, err = strconv.ParseInt(value, 10, 32); err != nil {
-				return nil, total, errors.Errorf("failed parsing Content-Length: %v", value)
+				return nil, total, fmt.Errorf("failed parsing Content-Length: %v", value)
 			}
 			if length <= 0 {
-				return nil, total, errors.Errorf("invalid Content-Length: %v", length)
+				return nil, total, fmt.Errorf("invalid Content-Length: %v", length)
 			}
 		default:
 			// ignoring unknown headers
 		}
 	}
 	if length == 0 {
-		return nil, total, errors.Errorf("missing Content-Length header")
+		return nil, total, fmt.Errorf("missing Content-Length header")
 	}
 	data := make([]byte, length)
 	n, err := io.ReadFull(r.in, data)
@@ -167,7 +171,7 @@ func (w *headerWriter) Write(ctx context.Context, msg Message) (int64, error) {
 	}
 	data, err := EncodeMessage(msg)
 	if err != nil {
-		return 0, errors.Errorf("marshaling message: %v", err)
+		return 0, fmt.Errorf("marshaling message: %v", err)
 	}
 	n, err := fmt.Fprintf(w.out, "Content-Length: %v\r\n\r\n", len(data))
 	total := int64(n)
diff --git a/internal/jsonrpc2_v2/jsonrpc2.go b/internal/jsonrpc2_v2/jsonrpc2.go
index e68558442..e9164b0bc 100644
--- a/internal/jsonrpc2_v2/jsonrpc2.go
+++ b/internal/jsonrpc2_v2/jsonrpc2.go
@@ -47,6 +47,15 @@ type Preempter interface {
 	Preempt(ctx context.Context, req *Request) (result interface{}, err error)
 }
 
+// A PreempterFunc implements the Preempter interface for a standalone Preempt function.
+type PreempterFunc func(ctx context.Context, req *Request) (interface{}, error)
+
+func (f PreempterFunc) Preempt(ctx context.Context, req *Request) (interface{}, error) {
+	return f(ctx, req)
+}
+
+var _ Preempter = PreempterFunc(nil)
+
 // Handler handles messages on a connection.
 type Handler interface {
 	// Handle is invoked sequentially for each incoming request that has not
@@ -75,12 +84,15 @@ func (defaultHandler) Handle(context.Context, *Request) (interface{}, error) {
 	return nil, ErrNotHandled
 }
 
+// A HandlerFunc implements the Handler interface for a standalone Handle function.
 type HandlerFunc func(ctx context.Context, req *Request) (interface{}, error)
 
 func (f HandlerFunc) Handle(ctx context.Context, req *Request) (interface{}, error) {
 	return f(ctx, req)
 }
 
+var _ Handler = HandlerFunc(nil)
+
 // async is a small helper for operations with an asynchronous result that you
 // can wait for.
 type async struct {
diff --git a/internal/jsonrpc2_v2/jsonrpc2_test.go b/internal/jsonrpc2_v2/jsonrpc2_test.go
index 4f4b7d9b9..dd8d09c88 100644
--- a/internal/jsonrpc2_v2/jsonrpc2_test.go
+++ b/internal/jsonrpc2_v2/jsonrpc2_test.go
@@ -11,12 +11,10 @@ import (
 	"path"
 	"reflect"
 	"testing"
-	"time"
 
 	"golang.org/x/tools/internal/event/export/eventtest"
 	jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2"
 	"golang.org/x/tools/internal/stack/stacktest"
-	errors "golang.org/x/xerrors"
 )
 
 var callTests = []invoker{
@@ -78,7 +76,7 @@ type binder struct {
 type handler struct {
 	conn        *jsonrpc2.Connection
 	accumulator int
-	waitersBox  chan map[string]chan struct{}
+	waiters     chan map[string]chan struct{}
 	calls       map[string]*jsonrpc2.AsyncCall
 }
 
@@ -138,10 +136,7 @@ func testConnection(t *testing.T, framer jsonrpc2.Framer) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	server, err := jsonrpc2.Serve(ctx, listener, binder{framer, nil})
-	if err != nil {
-		t.Fatal(err)
-	}
+	server := jsonrpc2.NewServer(ctx, listener, binder{framer, nil})
 	defer func() {
 		listener.Close()
 		server.Wait()
@@ -255,13 +250,13 @@ func verifyResults(t *testing.T, method string, results interface{}, expect inte
 	}
 }
 
-func (b binder) Bind(ctx context.Context, conn *jsonrpc2.Connection) (jsonrpc2.ConnectionOptions, error) {
+func (b binder) Bind(ctx context.Context, conn *jsonrpc2.Connection) jsonrpc2.ConnectionOptions {
 	h := &handler{
-		conn:       conn,
-		waitersBox: make(chan map[string]chan struct{}, 1),
-		calls:      make(map[string]*jsonrpc2.AsyncCall),
+		conn:    conn,
+		waiters: make(chan map[string]chan struct{}, 1),
+		calls:   make(map[string]*jsonrpc2.AsyncCall),
 	}
-	h.waitersBox <- make(map[string]chan struct{})
+	h.waiters <- make(map[string]chan struct{})
 	if b.runTest != nil {
 		go b.runTest(h)
 	}
@@ -269,12 +264,12 @@ func (b binder) Bind(ctx context.Context, conn *jsonrpc2.Connection) (jsonrpc2.C
 		Framer:    b.framer,
 		Preempter: h,
 		Handler:   h,
-	}, nil
+	}
 }
 
 func (h *handler) waiter(name string) chan struct{} {
-	waiters := <-h.waitersBox
-	defer func() { h.waitersBox <- waiters }()
+	waiters := <-h.waiters
+	defer func() { h.waiters <- waiters }()
 	waiter, found := waiters[name]
 	if !found {
 		waiter = make(chan struct{})
@@ -288,19 +283,19 @@ func (h *handler) Preempt(ctx context.Context, req *jsonrpc2.Request) (interface
 	case "unblock":
 		var name string
 		if err := json.Unmarshal(req.Params, &name); err != nil {
-			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
+			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
 		}
 		close(h.waiter(name))
 		return nil, nil
 	case "peek":
 		if len(req.Params) > 0 {
-			return nil, errors.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)
+			return nil, fmt.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)
 		}
 		return h.accumulator, nil
 	case "cancel":
 		var params cancelParams
 		if err := json.Unmarshal(req.Params, &params); err != nil {
-			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
+			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
 		}
 		h.conn.Cancel(jsonrpc2.Int64ID(params.ID))
 		return nil, nil
@@ -313,50 +308,50 @@ func (h *handler) Handle(ctx context.Context, req *jsonrpc2.Request) (interface{
 	switch req.Method {
 	case "no_args":
 		if len(req.Params) > 0 {
-			return nil, errors.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)
+			return nil, fmt.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)
 		}
 		return true, nil
 	case "one_string":
 		var v string
 		if err := json.Unmarshal(req.Params, &v); err != nil {
-			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
+			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
 		}
 		return "got:" + v, nil
 	case "one_number":
 		var v int
 		if err := json.Unmarshal(req.Params, &v); err != nil {
-			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
+			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
 		}
 		return fmt.Sprintf("got:%d", v), nil
 	case "set":
 		var v int
 		if err := json.Unmarshal(req.Params, &v); err != nil {
-			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
+			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
 		}
 		h.accumulator = v
 		return nil, nil
 	case "add":
 		var v int
 		if err := json.Unmarshal(req.Params, &v); err != nil {
-			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
+			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
 		}
 		h.accumulator += v
 		return nil, nil
 	case "get":
 		if len(req.Params) > 0 {
-			return nil, errors.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)
+			return nil, fmt.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)
 		}
 		return h.accumulator, nil
 	case "join":
 		var v []string
 		if err := json.Unmarshal(req.Params, &v); err != nil {
-			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
+			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
 		}
 		return path.Join(v...), nil
 	case "echo":
 		var v []interface{}
 		if err := json.Unmarshal(req.Params, &v); err != nil {
-			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
+			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
 		}
 		var result interface{}
 		err := h.conn.Call(ctx, v[0].(string), v[1]).Await(ctx, &result)
@@ -364,20 +359,18 @@ func (h *handler) Handle(ctx context.Context, req *jsonrpc2.Request) (interface{
 	case "wait":
 		var name string
 		if err := json.Unmarshal(req.Params, &name); err != nil {
-			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
+			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
 		}
 		select {
 		case <-h.waiter(name):
 			return true, nil
 		case <-ctx.Done():
 			return nil, ctx.Err()
-		case <-time.After(time.Second):
-			return nil, errors.Errorf("wait for %q timed out", name)
 		}
 	case "fork":
 		var name string
 		if err := json.Unmarshal(req.Params, &name); err != nil {
-			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
+			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
 		}
 		waitFor := h.waiter(name)
 		go func() {
@@ -386,8 +379,6 @@ func (h *handler) Handle(ctx context.Context, req *jsonrpc2.Request) (interface{
 				h.conn.Respond(req.ID, true, nil)
 			case <-ctx.Done():
 				h.conn.Respond(req.ID, nil, ctx.Err())
-			case <-time.After(time.Second):
-				h.conn.Respond(req.ID, nil, errors.Errorf("wait for %q timed out", name))
 			}
 		}()
 		return nil, jsonrpc2.ErrAsyncResponse
diff --git a/internal/jsonrpc2_v2/messages.go b/internal/jsonrpc2_v2/messages.go
index 652ac817a..af145641d 100644
--- a/internal/jsonrpc2_v2/messages.go
+++ b/internal/jsonrpc2_v2/messages.go
@@ -6,8 +6,8 @@ package jsonrpc2
 
 import (
 	"encoding/json"
-
-	errors "golang.org/x/xerrors"
+	"errors"
+	"fmt"
 )
 
 // ID is a Request identifier.
@@ -120,7 +120,7 @@ func EncodeMessage(msg Message) ([]byte, error) {
 	msg.marshal(&wire)
 	data, err := json.Marshal(&wire)
 	if err != nil {
-		return data, errors.Errorf("marshaling jsonrpc message: %w", err)
+		return data, fmt.Errorf("marshaling jsonrpc message: %w", err)
 	}
 	return data, nil
 }
@@ -128,10 +128,10 @@ func EncodeMessage(msg Message) ([]byte, error) {
 func DecodeMessage(data []byte) (Message, error) {
 	msg := wireCombined{}
 	if err := json.Unmarshal(data, &msg); err != nil {
-		return nil, errors.Errorf("unmarshaling jsonrpc message: %w", err)
+		return nil, fmt.Errorf("unmarshaling jsonrpc message: %w", err)
 	}
 	if msg.VersionTag != wireVersion {
-		return nil, errors.Errorf("invalid message version tag %s expected %s", msg.VersionTag, wireVersion)
+		return nil, fmt.Errorf("invalid message version tag %s expected %s", msg.VersionTag, wireVersion)
 	}
 	id := ID{}
 	switch v := msg.ID.(type) {
@@ -144,7 +144,7 @@ func DecodeMessage(data []byte) (Message, error) {
 	case string:
 		id = StringID(v)
 	default:
-		return nil, errors.Errorf("invalid message id type <%T>%v", v, v)
+		return nil, fmt.Errorf("invalid message id type <%T>%v", v, v)
 	}
 	if msg.Method != "" {
 		// has a method, must be a call
diff --git a/internal/jsonrpc2_v2/net.go b/internal/jsonrpc2_v2/net.go
index 4f2082599..15d0aea3a 100644
--- a/internal/jsonrpc2_v2/net.go
+++ b/internal/jsonrpc2_v2/net.go
@@ -9,7 +9,6 @@ import (
 	"io"
 	"net"
 	"os"
-	"time"
 )
 
 // This file contains implementations of the transport primitives that use the standard network
@@ -36,7 +35,7 @@ type netListener struct {
 }
 
 // Accept blocks waiting for an incoming connection to the listener.
-func (l *netListener) Accept(ctx context.Context) (io.ReadWriteCloser, error) {
+func (l *netListener) Accept(context.Context) (io.ReadWriteCloser, error) {
 	return l.net.Accept()
 }
 
@@ -56,9 +55,7 @@ func (l *netListener) Close() error {
 
 // Dialer returns a dialer that can be used to connect to the listener.
 func (l *netListener) Dialer() Dialer {
-	return NetDialer(l.net.Addr().Network(), l.net.Addr().String(), net.Dialer{
-		Timeout: 5 * time.Second,
-	})
+	return NetDialer(l.net.Addr().Network(), l.net.Addr().String(), net.Dialer{})
 }
 
 // NetDialer returns a Dialer using the supplied standard network dialer.
@@ -81,7 +78,7 @@ func (n *netDialer) Dial(ctx context.Context) (io.ReadWriteCloser, error) {
 }
 
 // NetPipeListener returns a new Listener that listens using net.Pipe.
-// It is only possibly to connect to it using the Dialier returned by the
+// It is only possible to connect to it using the Dialer returned by the
 // Dialer method, each call to that method will generate a new pipe the other
 // side of which will be returned from the Accept call.
 func NetPipeListener(ctx context.Context) (Listener, error) {
@@ -98,15 +95,19 @@ type netPiper struct {
 }
 
 // Accept blocks waiting for an incoming connection to the listener.
-func (l *netPiper) Accept(ctx context.Context) (io.ReadWriteCloser, error) {
-	// block until we have a listener, or are closed or cancelled
+func (l *netPiper) Accept(context.Context) (io.ReadWriteCloser, error) {
+	// Block until the pipe is dialed or the listener is closed,
+	// preferring the latter if already closed at the start of Accept.
+	select {
+	case <-l.done:
+		return nil, errClosed
+	default:
+	}
 	select {
 	case rwc := <-l.dialed:
 		return rwc, nil
 	case <-l.done:
-		return nil, io.EOF
-	case <-ctx.Done():
-		return nil, ctx.Err()
+		return nil, errClosed
 	}
 }
 
@@ -124,6 +125,14 @@ func (l *netPiper) Dialer() Dialer {
 
 func (l *netPiper) Dial(ctx context.Context) (io.ReadWriteCloser, error) {
 	client, server := net.Pipe()
-	l.dialed <- server
-	return client, nil
+
+	select {
+	case l.dialed <- server:
+		return client, nil
+
+	case <-l.done:
+		client.Close()
+		server.Close()
+		return nil, errClosed
+	}
 }
diff --git a/internal/jsonrpc2_v2/serve.go b/internal/jsonrpc2_v2/serve.go
index fb3516635..5e0827354 100644
--- a/internal/jsonrpc2_v2/serve.go
+++ b/internal/jsonrpc2_v2/serve.go
@@ -6,14 +6,12 @@ package jsonrpc2
 
 import (
 	"context"
+	"fmt"
 	"io"
 	"runtime"
-	"strings"
 	"sync"
-	"syscall"
+	"sync/atomic"
 	"time"
-
-	errors "golang.org/x/xerrors"
 )
 
 // Listener is implemented by protocols to accept new inbound connections.
@@ -43,35 +41,43 @@ type Server struct {
 	listener Listener
 	binder   Binder
 	async    *async
+
+	shutdownOnce sync.Once
+	closing      int32 // atomic: set to nonzero when Shutdown is called
 }
 
 // Dial uses the dialer to make a new connection, wraps the returned
 // reader and writer using the framer to make a stream, and then builds
 // a connection on top of that stream using the binder.
+//
+// The returned Connection will operate independently using the Preempter and/or
+// Handler provided by the Binder, and will release its own resources when the
+// connection is broken, but the caller may Close it earlier to stop accepting
+// (or sending) new requests.
 func Dial(ctx context.Context, dialer Dialer, binder Binder) (*Connection, error) {
 	// dial a server
 	rwc, err := dialer.Dial(ctx)
 	if err != nil {
 		return nil, err
 	}
-	return newConnection(ctx, rwc, binder)
+	return newConnection(ctx, rwc, binder, nil), nil
 }
 
-// Serve starts a new server listening for incoming connections and returns
+// NewServer starts a new server listening for incoming connections and returns
 // it.
 // This returns a fully running and connected server, it does not block on
 // the listener.
 // You can call Wait to block on the server, or Shutdown to get the server to
 // terminate gracefully.
 // To notice incoming connections, use an intercepting Binder.
-func Serve(ctx context.Context, listener Listener, binder Binder) (*Server, error) {
+func NewServer(ctx context.Context, listener Listener, binder Binder) *Server {
 	server := &Server{
 		listener: listener,
 		binder:   binder,
 		async:    newAsync(),
 	}
 	go server.run(ctx)
-	return server, nil
+	return server
 }
 
 // Wait returns only when the server has shut down.
@@ -79,173 +85,160 @@ func (s *Server) Wait() error {
 	return s.async.wait()
 }
 
+// Shutdown informs the server to stop accepting new connections.
+func (s *Server) Shutdown() {
+	s.shutdownOnce.Do(func() {
+		atomic.StoreInt32(&s.closing, 1)
+		s.listener.Close()
+	})
+}
+
 // run accepts incoming connections from the listener,
 // If IdleTimeout is non-zero, run exits after there are no clients for this
 // duration, otherwise it exits only on error.
 func (s *Server) run(ctx context.Context) {
 	defer s.async.done()
-	var activeConns []*Connection
+
+	var activeConns sync.WaitGroup
 	for {
-		// we never close the accepted connection, we rely on the other end
-		// closing or the socket closing itself naturally
 		rwc, err := s.listener.Accept(ctx)
 		if err != nil {
-			if !isClosingError(err) {
+			// Only Shutdown closes the listener. If we get an error after Shutdown is
+			// called, assume that that was the cause and don't report the error;
+			// otherwise, report the error in case it is unexpected.
+			if atomic.LoadInt32(&s.closing) == 0 {
 				s.async.setError(err)
 			}
-			// we are done generating new connections for good
+			// We are done generating new connections for good.
 			break
 		}
-		// see if any connections were closed while we were waiting
-		activeConns = onlyActive(activeConns)
-
-		// a new inbound connection,
-		conn, err := newConnection(ctx, rwc, s.binder)
-		if err != nil {
-			if !isClosingError(err) {
-				s.async.setError(err)
-			}
-			continue
-		}
-		activeConns = append(activeConns, conn)
-	}
-
-	// wait for all active conns to finish
-	for _, c := range activeConns {
-		c.Wait()
+		// A new inbound connection.
+		activeConns.Add(1)
+		_ = newConnection(ctx, rwc, s.binder, activeConns.Done) // unregisters itself when done
 	}
+	activeConns.Wait()
 }
 
-func onlyActive(conns []*Connection) []*Connection {
-	i := 0
-	for _, c := range conns {
-		if !c.async.isDone() {
-			conns[i] = c
-			i++
-		}
+// NewIdleListener wraps a listener with an idle timeout.
+//
+// When there are no active connections for at least the timeout duration,
+// calls to Accept will fail with ErrIdleTimeout.
+//
+// A connection is considered inactive as soon as its Close method is called.
+func NewIdleListener(timeout time.Duration, wrap Listener) Listener {
+	l := &idleListener{
+		wrapped:   wrap,
+		timeout:   timeout,
+		active:    make(chan int, 1),
+		timedOut:  make(chan struct{}),
+		idleTimer: make(chan *time.Timer, 1),
 	}
-	// trim the slice down
-	return conns[:i]
+	l.idleTimer <- time.AfterFunc(l.timeout, l.timerExpired)
+	return l
 }
 
-// isClosingError reports if the error occurs normally during the process of
-// closing a network connection. It uses imperfect heuristics that err on the
-// side of false negatives, and should not be used for anything critical.
-func isClosingError(err error) bool {
-	if err == nil {
-		return false
-	}
-	// Fully unwrap the error, so the following tests work.
-	for wrapped := err; wrapped != nil; wrapped = errors.Unwrap(err) {
-		err = wrapped
-	}
-
-	// Was it based on an EOF error?
-	if err == io.EOF {
-		return true
-	}
+type idleListener struct {
+	wrapped Listener
+	timeout time.Duration
 
-	// Was it based on a closed pipe?
-	if err == io.ErrClosedPipe {
-		return true
-	}
+	// Only one of these channels is receivable at any given time.
+	active    chan int         // count of active connections; closed when Close is called if not timed out
+	timedOut  chan struct{}    // closed when the idle timer expires
+	idleTimer chan *time.Timer // holds the timer only when idle
+}
 
-	// Per https://github.com/golang/go/issues/4373, this error string should not
-	// change. This is not ideal, but since the worst that could happen here is
-	// some superfluous logging, it is acceptable.
-	if err.Error() == "use of closed network connection" {
-		return true
-	}
+// Accept accepts an incoming connection.
+//
+// If an incoming connection is accepted concurrent to the listener being closed
+// due to idleness, the new connection is immediately closed.
+func (l *idleListener) Accept(ctx context.Context) (io.ReadWriteCloser, error) {
+	rwc, err := l.wrapped.Accept(ctx)
 
-	if runtime.GOOS == "plan9" {
-		// Error reading from a closed connection.
-		if err == syscall.EINVAL {
-			return true
+	select {
+	case n, ok := <-l.active:
+		if err != nil {
+			if ok {
+				l.active <- n
+			}
+			return nil, err
 		}
-		// Error trying to accept a new connection from a closed listener.
- if strings.HasSuffix(err.Error(), " listen hungup") {
- return true
+ if ok {
+ l.active <- n + 1
+ } else {
+ // l.wrapped.Close has been called, but Accept returned a
+ // connection. This race can occur with concurrent Accept and Close calls
+ // with any net.Listener, and it is benign: since the listener was closed
+ // explicitly, it can't have also timed out.
 }
- return false
-}
+ return l.newConn(rwc), nil

-// NewIdleListener wraps a listener with an idle timeout.
-// When there are no active connections for at least the timeout duration a
-// call to accept will fail with ErrIdleTimeout.
-func NewIdleListener(timeout time.Duration, wrap Listener) Listener {
- l := &idleListener{
- timeout: timeout,
- wrapped: wrap,
- newConns: make(chan *idleCloser),
- closed: make(chan struct{}),
- wasTimeout: make(chan struct{}),
- }
- go l.run()
- return l
-}
+ case <-l.timedOut:
+ if err == nil {
+ // Keeping the connection open would leave the listener simultaneously
+ // active and closed due to idleness, which would be contradictory and
+ // confusing. Close the connection and pretend that it never happened.
+ rwc.Close()
+ } else {
+ // In theory the timeout could have raced with an unrelated error return
+ // from Accept. However, ErrIdleTimeout is arguably still valid (since we
+ // would have closed due to the timeout independent of the error), and the
+ // harm from returning a spurious ErrIdleTimeout is negligible anyway.
+ }
+ return nil, ErrIdleTimeout

-type idleListener struct {
- wrapped Listener
- timeout time.Duration
- newConns chan *idleCloser
- closed chan struct{}
- wasTimeout chan struct{}
- closeOnce sync.Once
-}
+ case timer := <-l.idleTimer:
+ if err != nil {
+ // The idle timer doesn't run until it receives itself from the idleTimer
+ // channel, so it can't have called l.wrapped.Close yet and thus err can't
+ // be ErrIdleTimeout. Leave the idle timer as it was and return whatever
+ // error we got.
+ l.idleTimer <- timer
+ return nil, err
+ }

-type idleCloser struct {
- wrapped io.ReadWriteCloser
- closed chan struct{}
- closeOnce sync.Once
-}
+ if !timer.Stop() {
+ // Failed to stop the timer: the timer goroutine is in the process of
+ // firing. Send the timer back to the timer goroutine so that it can
+ // safely close the timedOut channel, and then wait for the listener to
+ // actually be closed before we return ErrIdleTimeout.
+ l.idleTimer <- timer
+ rwc.Close()
+ <-l.timedOut
+ return nil, ErrIdleTimeout
+ }

-func (c *idleCloser) Read(p []byte) (int, error) {
- n, err := c.wrapped.Read(p)
- if err != nil && isClosingError(err) {
- c.closeOnce.Do(func() { close(c.closed) })
+ l.active <- 1
+ return l.newConn(rwc), nil
 }
- return n, err
}

-func (c *idleCloser) Write(p []byte) (int, error) {
- // we do not close on write failure, we rely on the wrapped writer to do that
- // if it is appropriate, which we will detect in the next read.
- return c.wrapped.Write(p)
-}
+func (l *idleListener) Close() error {
+ select {
+ case _, ok := <-l.active:
+ if ok {
+ close(l.active)
+ }

-func (c *idleCloser) Close() error {
- // we rely on closing the wrapped stream to signal to the next read that we
- // are closed, rather than triggering the closed signal directly
- return c.wrapped.Close()
-}
+ case <-l.timedOut:
+ // Already closed by the timer; take care not to double-close if the caller
+ // only explicitly invokes this Close method once, since the io.Closer
+ // interface explicitly leaves doubled Close calls undefined.
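The if !timer.Stop() branch above leans on the documented contract of (*time.Timer).Stop: a false result means the timer already fired or its AfterFunc callback is about to run, so the caller must synchronize with the callback through some other channel before relying on its effects. A minimal standalone demonstration of that contract (all names here are illustrative):

package main

import (
	"fmt"
	"time"
)

func main() {
	fired := make(chan struct{})
	t := time.AfterFunc(10*time.Millisecond, func() { close(fired) })

	time.Sleep(20 * time.Millisecond) // let the timer win the race

	if !t.Stop() {
		// Stop lost the race: the callback is running (or done). Wait for its
		// signal so that everything it did is visible before proceeding.
		<-fired
		fmt.Println("timer fired; synchronized with the callback")
	} else {
		fmt.Println("stopped the timer before it fired")
	}
}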
+ return ErrIdleTimeout -func (l *idleListener) Accept(ctx context.Context) (io.ReadWriteCloser, error) { - rwc, err := l.wrapped.Accept(ctx) - if err != nil { - if isClosingError(err) { - // underlying listener was closed - l.closeOnce.Do(func() { close(l.closed) }) - // was it closed because of the idle timeout? - select { - case <-l.wasTimeout: - err = ErrIdleTimeout - default: - } + case timer := <-l.idleTimer: + if !timer.Stop() { + // Couldn't stop the timer. It shouldn't take long to run, so just wait + // (so that the Listener is guaranteed to be closed before we return) + // and pretend that this call happened afterward. + // That way we won't leak any timers or goroutines when Close returns. + l.idleTimer <- timer + <-l.timedOut + return ErrIdleTimeout } - return nil, err + close(l.active) } - conn := &idleCloser{ - wrapped: rwc, - closed: make(chan struct{}), - } - l.newConns <- conn - return conn, err -} -func (l *idleListener) Close() error { - defer l.closeOnce.Do(func() { close(l.closed) }) return l.wrapped.Close() } @@ -253,31 +246,83 @@ func (l *idleListener) Dialer() Dialer { return l.wrapped.Dialer() } -func (l *idleListener) run() { - var conns []*idleCloser - for { - var firstClosed chan struct{} // left at nil if there are no active conns - var timeout <-chan time.Time // left at nil if there are active conns - if len(conns) > 0 { - firstClosed = conns[0].closed +func (l *idleListener) timerExpired() { + select { + case n, ok := <-l.active: + if ok { + panic(fmt.Sprintf("jsonrpc2: idleListener idle timer fired with %d connections still active", n)) } else { - timeout = time.After(l.timeout) + panic("jsonrpc2: Close finished with idle timer still running") } - select { - case <-l.closed: - // the main listener closed, no need to keep going + + case <-l.timedOut: + panic("jsonrpc2: idleListener idle timer fired more than once") + + case <-l.idleTimer: + // The timer for this very call! + } + + // Close the Listener with all channels still blocked to ensure that this call + // to l.wrapped.Close doesn't race with the one in l.Close. + defer close(l.timedOut) + l.wrapped.Close() +} + +func (l *idleListener) connClosed() { + select { + case n, ok := <-l.active: + if !ok { + // l is already closed, so it can't close due to idleness, + // and we don't need to track the number of active connections any more. 
return
- case conn := <-l.newConns:
- // a new conn arrived, add it to the list
- conns = append(conns, conn)
- case <-timeout:
- // we timed out, only happens when there are no active conns
- // close the underlying listener, and allow the normal closing process to happen
- close(l.wasTimeout)
- l.wrapped.Close()
- case <-firstClosed:
- // a conn closed, remove it from the active list
- conns = conns[:copy(conns, conns[1:])]
 }
+ n--
+ if n == 0 {
+ l.idleTimer <- time.AfterFunc(l.timeout, l.timerExpired)
+ } else {
+ l.active <- n
+ }
+
+ case <-l.timedOut:
+ panic("jsonrpc2: idleListener idle timer fired before last active connection was closed")
+
+ case <-l.idleTimer:
+ panic("jsonrpc2: idleListener idle timer active before last active connection was closed")
 }
}
+
+type idleListenerConn struct {
+ wrapped io.ReadWriteCloser
+ l *idleListener
+ closeOnce sync.Once
+}
+
+func (l *idleListener) newConn(rwc io.ReadWriteCloser) *idleListenerConn {
+ c := &idleListenerConn{
+ wrapped: rwc,
+ l: l,
+ }
+
+ // A caller that forgets to call Close may disrupt the idleListener's
+ // accounting, even though the file descriptor for the underlying connection
+ // may eventually be garbage-collected anyway.
+ //
+ // Set a (best-effort) finalizer to verify that a Close call always occurs.
+ // (We will clear the finalizer explicitly in Close.)
+ runtime.SetFinalizer(c, func(c *idleListenerConn) {
+ panic("jsonrpc2: IdleListener connection became unreachable without a call to Close")
+ })
+
+ return c
+}
+
+func (c *idleListenerConn) Read(p []byte) (int, error) { return c.wrapped.Read(p) }
+func (c *idleListenerConn) Write(p []byte) (int, error) { return c.wrapped.Write(p) }
+
+func (c *idleListenerConn) Close() error {
+ defer c.closeOnce.Do(func() {
+ c.l.connClosed()
+ runtime.SetFinalizer(c, nil)
+ })
+ return c.wrapped.Close()
+}
diff --git a/internal/jsonrpc2_v2/serve_go116.go b/internal/jsonrpc2_v2/serve_go116.go
new file mode 100644
index 000000000..29549f105
--- /dev/null
+++ b/internal/jsonrpc2_v2/serve_go116.go
@@ -0,0 +1,19 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.16
+// +build go1.16
+
+package jsonrpc2
+
+import (
+ "errors"
+ "net"
+)
+
+var errClosed = net.ErrClosed
+
+func isErrClosed(err error) bool {
+ return errors.Is(err, errClosed)
+}
diff --git a/internal/jsonrpc2_v2/serve_pre116.go b/internal/jsonrpc2_v2/serve_pre116.go
new file mode 100644
index 000000000..a1801d8a2
--- /dev/null
+++ b/internal/jsonrpc2_v2/serve_pre116.go
@@ -0,0 +1,30 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !go1.16
+// +build !go1.16
+
+package jsonrpc2
+
+import (
+ "errors"
+ "strings"
+)
+
+// errClosed is an error with the same string as net.ErrClosed,
+// which was added in Go 1.16.
+var errClosed = errors.New("use of closed network connection")
+
+// isErrClosed reports whether err ends in the same string as errClosed.
+func isErrClosed(err error) bool {
+ // As of Go 1.16, this could be 'errors.Is(err, net.ErrClosed)', but
+ // unfortunately gopls still requires compatibility with
+ // (otherwise-unsupported) older Go versions.
+ //
+ // In the meantime, this error string has not changed on any supported Go
+ // version, and is not expected to change in the future.
+ // This is not ideal, but since the worst that could happen here is some
+ // superfluous logging, it is acceptable.
+ return strings.HasSuffix(err.Error(), "use of closed network connection")
+}
diff --git a/internal/jsonrpc2_v2/serve_test.go b/internal/jsonrpc2_v2/serve_test.go
index 26cf6a58c..88ac66b7e 100644
--- a/internal/jsonrpc2_v2/serve_test.go
+++ b/internal/jsonrpc2_v2/serve_test.go
@@ -7,6 +7,8 @@ package jsonrpc2_test
 import (
 "context"
 "errors"
+ "fmt"
+ "runtime/debug"
 "testing"
 "time"
@@ -16,48 +18,125 @@ import (
 func TestIdleTimeout(t *testing.T) {
 stacktest.NoLeak(t)
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer cancel()
- listener, err := jsonrpc2.NetListener(ctx, "tcp", "localhost:0", jsonrpc2.NetListenOptions{})
- if err != nil {
- t.Fatal(err)
- }
- listener = jsonrpc2.NewIdleListener(100*time.Millisecond, listener)
- defer listener.Close()
- server, err := jsonrpc2.Serve(ctx, listener, jsonrpc2.ConnectionOptions{})
- if err != nil {
- t.Fatal(err)
- }
+ // Use a panicking time.AfterFunc instead of context.WithTimeout so that we
+ // get a goroutine dump on failure. We expect the test to take on the order of
+ // a few tens of milliseconds at most, so 10s should be several orders of
+ // magnitude of headroom.
+ timer := time.AfterFunc(10*time.Second, func() {
+ debug.SetTraceback("all")
+ panic("TestIdleTimeout deadlocked")
+ })
+ defer timer.Stop()

- connect := func() *jsonrpc2.Connection {
- client, err := jsonrpc2.Dial(ctx,
- listener.Dialer(),
- jsonrpc2.ConnectionOptions{})
+ ctx := context.Background()
+
+ try := func(d time.Duration) (longEnough bool) {
+ listener, err := jsonrpc2.NetListener(ctx, "tcp", "localhost:0", jsonrpc2.NetListenOptions{})
 if err != nil {
 t.Fatal(err)
 }
- return client
- }

- // Exercise some connection/disconnection patterns, and then assert that when
- // our timer fires, the server exits.
- conn1 := connect()
- conn2 := connect()
- if err := conn1.Close(); err != nil {
- t.Fatalf("conn1.Close failed with error: %v", err)
- }
- if err := conn2.Close(); err != nil {
- t.Fatalf("conn2.Close failed with error: %v", err)
- }
- conn3 := connect()
- if err := conn3.Close(); err != nil {
- t.Fatalf("conn3.Close failed with error: %v", err)
- }
- serverError := server.Wait()
+ idleStart := time.Now()
+ listener = jsonrpc2.NewIdleListener(d, listener)
+ defer listener.Close()

- if !errors.Is(serverError, jsonrpc2.ErrIdleTimeout) {
- t.Errorf("run() returned error %v, want %v", serverError, jsonrpc2.ErrIdleTimeout)
+ server := jsonrpc2.NewServer(ctx, listener, jsonrpc2.ConnectionOptions{})
+
+ // Exercise some connection/disconnection patterns, and then assert that when
+ // our timer fires, the server exits.
+ conn1, err := jsonrpc2.Dial(ctx, listener.Dialer(), jsonrpc2.ConnectionOptions{})
+ if err != nil {
+ if since := time.Since(idleStart); since < d {
+ t.Fatalf("conn1 failed to connect after %v: %v", since, err)
+ }
+ t.Log("jsonrpc2.Dial:", err)
+ return false // Took too long to dial, so the failure could have been due to the idle timeout.
+ }
+ // On the server side, Accept can race with the connection timing out.
+ // Send a call and wait for the response to ensure that the connection was
+ // actually fully accepted.
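A note on the "ping" idiom this test is about to use: neither end registers a handler, so a call that completes a full round trip fails with ErrMethodNotFound rather than succeeding. That makes an unhandled method a convenient liveness probe: ErrMethodNotFound proves the peer accepted the connection and responded, while any other error indicates a broken or timed-out connection. A sketch of the idiom as a standalone helper (ping is hypothetical, not part of the package):

package jsonrpc2_test

import (
	"context"
	"errors"
	"fmt"

	jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2"
)

// ping reports whether a fresh connection to listener completes a round trip.
func ping(ctx context.Context, listener jsonrpc2.Listener) error {
	conn, err := jsonrpc2.Dial(ctx, listener.Dialer(), jsonrpc2.ConnectionOptions{})
	if err != nil {
		return err
	}
	defer conn.Close()

	// With no Handler on the other side, a healthy round trip reports
	// ErrMethodNotFound; anything else means the connection is broken.
	if err := conn.Call(ctx, "ping", nil).Await(ctx, nil); !errors.Is(err, jsonrpc2.ErrMethodNotFound) {
		return fmt.Errorf("connection not healthy: %w", err)
	}
	return nil
}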
+ ac := conn1.Call(ctx, "ping", nil)
+ if err := ac.Await(ctx, nil); !errors.Is(err, jsonrpc2.ErrMethodNotFound) {
+ if since := time.Since(idleStart); since < d {
+ t.Fatalf("conn1 broken after %v: %v", since, err)
+ }
+ t.Log(`conn1.Call(ctx, "ping", nil):`, err)
+ conn1.Close()
+ return false
+ }
+
+ // Since conn1 was successfully accepted and remains open, the server is
+ // definitely non-idle. Dialing another simultaneous connection should
+ // succeed.
+ conn2, err := jsonrpc2.Dial(ctx, listener.Dialer(), jsonrpc2.ConnectionOptions{})
+ if err != nil {
+ conn1.Close()
+ t.Fatalf("conn2 failed to connect while non-idle after %v: %v", time.Since(idleStart), err)
+ return false
+ }
+ // Ensure that conn2 is also accepted on the server side before we close
+ // conn1. Otherwise, the connection can appear idle if the server processes
+ // the closure of conn1 and the idle timeout before it finally notices conn2
+ // in the accept queue.
+ // (That failure mode may explain the failure noted in
+ // https://go.dev/issue/49387#issuecomment-1303979877.)
+ ac = conn2.Call(ctx, "ping", nil)
+ if err := ac.Await(ctx, nil); !errors.Is(err, jsonrpc2.ErrMethodNotFound) {
+ t.Fatalf("conn2 broken while non-idle after %v: %v", time.Since(idleStart), err)
+ }
+
+ if err := conn1.Close(); err != nil {
+ t.Fatalf("conn1.Close failed with error: %v", err)
+ }
+ idleStart = time.Now()
+ if err := conn2.Close(); err != nil {
+ t.Fatalf("conn2.Close failed with error: %v", err)
+ }
+
+ conn3, err := jsonrpc2.Dial(ctx, listener.Dialer(), jsonrpc2.ConnectionOptions{})
+ if err != nil {
+ if since := time.Since(idleStart); since < d {
+ t.Fatalf("conn3 failed to connect after %v: %v", since, err)
+ }
+ t.Log("jsonrpc2.Dial:", err)
+ return false // Took too long to dial, so the failure could have been due to the idle timeout.
+ }
+
+ ac = conn3.Call(ctx, "ping", nil)
+ if err := ac.Await(ctx, nil); !errors.Is(err, jsonrpc2.ErrMethodNotFound) {
+ if since := time.Since(idleStart); since < d {
+ t.Fatalf("conn3 broken after %v: %v", since, err)
+ }
+ t.Log(`conn3.Call(ctx, "ping", nil):`, err)
+ conn3.Close()
+ return false
+ }
+
+ idleStart = time.Now()
+ if err := conn3.Close(); err != nil {
+ t.Fatalf("conn3.Close failed with error: %v", err)
+ }
+
+ serverError := server.Wait()
+
+ if !errors.Is(serverError, jsonrpc2.ErrIdleTimeout) {
+ t.Errorf("run() returned error %v, want %v", serverError, jsonrpc2.ErrIdleTimeout)
+ }
+ if since := time.Since(idleStart); since < d {
+ t.Errorf("server shut down after %v idle; want at least %v", since, d)
+ }
+ return true
+ }
+
+ d := 1 * time.Millisecond
+ for {
+ t.Logf("testing with idle timeout %v", d)
+ if !try(d) {
+ d *= 2
+ continue
+ }
+ break
 }
}
@@ -78,8 +157,7 @@ func (fakeHandler) Handle(ctx context.Context, req *jsonrpc2.Request) (interface
 func TestServe(t *testing.T) {
 stacktest.NoLeak(t)
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
+ ctx := context.Background()
 tests := []struct {
 name string
@@ -116,13 +194,9 @@ func TestServe(t *testing.T) {
 }
 func newFake(t *testing.T, ctx context.Context, l jsonrpc2.Listener) (*jsonrpc2.Connection, func(), error) {
- l = jsonrpc2.NewIdleListener(100*time.Millisecond, l)
- server, err := jsonrpc2.Serve(ctx, l, jsonrpc2.ConnectionOptions{
+ server := jsonrpc2.NewServer(ctx, l, jsonrpc2.ConnectionOptions{
 Handler: fakeHandler{},
 })
- if err != nil {
- return nil, nil, err
- }
 client, err := jsonrpc2.Dial(ctx,
 l.Dialer(),
@@ -142,3 +216,129 @@ func newFake(t *testing.T, ctx context.Context, l jsonrpc2.Listener) (*jsonrpc2.
 server.Wait()
 }, nil
}
+
+// TestIdleListenerAcceptCloseRace checks for the Accept/Close race fixed in CL 388597.
+//
+// (A bug in the idleListener implementation caused a successful Accept to block
+// on sending to a background goroutine that could have already exited.)
+func TestIdleListenerAcceptCloseRace(t *testing.T) {
+ ctx := context.Background()
+
+ n := 10
+
+ // Each iteration of the loop appears to take around a millisecond, so to
+ // avoid spurious failures we'll set the watchdog for three orders of
+ // magnitude longer. When the bug was present, this reproduced the deadlock
+ // reliably on a Linux workstation when run with -count=100, which should be
+ // frequent enough to show up on the Go build dashboard if it regresses.
+ watchdog := time.Duration(n) * 1000 * time.Millisecond
+ timer := time.AfterFunc(watchdog, func() {
+ debug.SetTraceback("all")
+ panic(fmt.Sprintf("%s deadlocked after %v", t.Name(), watchdog))
+ })
+ defer timer.Stop()
+
+ for ; n > 0; n-- {
+ listener, err := jsonrpc2.NetPipeListener(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ listener = jsonrpc2.NewIdleListener(24*time.Hour, listener)
+
+ done := make(chan struct{})
+ go func() {
+ conn, err := jsonrpc2.Dial(ctx, listener.Dialer(), jsonrpc2.ConnectionOptions{})
+ listener.Close()
+ if err == nil {
+ conn.Close()
+ }
+ close(done)
+ }()
+
+ // Accept may return a non-nil error if Close closes the underlying network
+ // connection before the wrapped Accept call unblocks. However, it must not
+ // deadlock!
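For orientation, the server lifecycle these tests exercise composes as follows. This is a minimal sketch modeled on newFake and TestIdleTimeout above; the address and the one-minute timeout are arbitrary placeholders, and since the package lives under internal/ it can only be imported from within x/tools itself.

package main

import (
	"context"
	"errors"
	"log"
	"time"

	jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2"
)

func main() {
	ctx := context.Background()

	listener, err := jsonrpc2.NetListener(ctx, "tcp", "localhost:0", jsonrpc2.NetListenOptions{})
	if err != nil {
		log.Fatal(err)
	}
	// Let the server shut itself down after a minute with no connections.
	listener = jsonrpc2.NewIdleListener(time.Minute, listener)

	// ConnectionOptions implements Binder; a real server would set a Handler.
	server := jsonrpc2.NewServer(ctx, listener, jsonrpc2.ConnectionOptions{})

	// ... accept and serve traffic ...

	// Stop accepting new connections explicitly (idempotent), then wait for
	// in-flight connections to drain. Had the idle timer fired first, Wait
	// would report ErrIdleTimeout instead.
	server.Shutdown()
	if err := server.Wait(); err != nil && !errors.Is(err, jsonrpc2.ErrIdleTimeout) {
		log.Println("server exited with error:", err)
	}
}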
+ c, err := listener.Accept(ctx)
+ if err == nil {
+ c.Close()
+ }
+ <-done
+ }
+}
+
+// TestCloseCallRace checks for a race resulting in a deadlock when a Call on
+// one side of the connection races with a Close (or otherwise broken
+// connection) initiated from the other side.
+//
+// (The Call method was waiting for a result from the Read goroutine to
+// determine which error value to return, but the Read goroutine was waiting for
+// in-flight calls to complete before reporting that result.)
+func TestCloseCallRace(t *testing.T) {
+ ctx := context.Background()
+ n := 10
+
+ watchdog := time.Duration(n) * 1000 * time.Millisecond
+ timer := time.AfterFunc(watchdog, func() {
+ debug.SetTraceback("all")
+ panic(fmt.Sprintf("%s deadlocked after %v", t.Name(), watchdog))
+ })
+ defer timer.Stop()
+
+ for ; n > 0; n-- {
+ listener, err := jsonrpc2.NetPipeListener(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ pokec := make(chan *jsonrpc2.AsyncCall, 1)
+
+ s := jsonrpc2.NewServer(ctx, listener, jsonrpc2.BinderFunc(func(_ context.Context, srvConn *jsonrpc2.Connection) jsonrpc2.ConnectionOptions {
+ h := jsonrpc2.HandlerFunc(func(ctx context.Context, _ *jsonrpc2.Request) (interface{}, error) {
+ // Start a concurrent call from the server to the client.
+ // The point of this test is to ensure this doesn't deadlock
+ // if the client shuts down the connection concurrently.
+ //
+ // The racing Call may or may not receive a response: it should get a
+ // response if it is sent before the client closes the connection, and
+ // it should fail with some kind of "connection closed" error otherwise.
+ go func() {
+ pokec <- srvConn.Call(ctx, "poke", nil)
+ }()
+
+ return &msg{"pong"}, nil
+ })
+ return jsonrpc2.ConnectionOptions{Handler: h}
+ }))
+
+ dialConn, err := jsonrpc2.Dial(ctx, listener.Dialer(), jsonrpc2.ConnectionOptions{})
+ if err != nil {
+ listener.Close()
+ s.Wait()
+ t.Fatal(err)
+ }
+
+ // Calling any method on the server should provoke it to asynchronously call
+ // us back. While it is starting that call, we will close the connection.
+ if err := dialConn.Call(ctx, "ping", nil).Await(ctx, nil); err != nil {
+ t.Error(err)
+ }
+ if err := dialConn.Close(); err != nil {
+ t.Error(err)
+ }
+
+ // Ensure that the Call on the server side did not block forever when the
+ // connection closed.
+ pokeCall := <-pokec
+ if err := pokeCall.Await(ctx, nil); err == nil {
+ t.Errorf("unexpected nil error from server-initiated call")
+ } else if errors.Is(err, jsonrpc2.ErrMethodNotFound) {
+ // The call completed before the Close reached the handler.
+ } else {
+ // The error was something else.
+ t.Logf("server-initiated call completed with expected error: %v", err)
+ }
+
+ listener.Close()
+ s.Wait()
+ }
+}
diff --git a/internal/jsonrpc2_v2/wire.go b/internal/jsonrpc2_v2/wire.go
index 4da129ae6..c8dc9ebf1 100644
--- a/internal/jsonrpc2_v2/wire.go
+++ b/internal/jsonrpc2_v2/wire.go
@@ -33,6 +33,10 @@ var (
 ErrServerOverloaded = NewError(-32000, "JSON RPC overloaded")
 // ErrUnknown should be used for all non-coded errors.
 ErrUnknown = NewError(-32001, "JSON RPC unknown error")
+ // ErrServerClosing is returned for calls that arrive while the server is closing.
+ ErrServerClosing = NewError(-32002, "JSON RPC server is closing")
+ // ErrClientClosing is a dummy error returned for calls initiated while the client is closing.
+ ErrClientClosing = NewError(-32003, "JSON RPC client is closing")
)
 const wireVersion = "2.0"
@@ -72,3 +76,11 @@ func NewError(code int64, message string) error {
 func (err *wireError) Error() string {
 return err.Message
}
+
+func (err *wireError) Is(other error) bool {
+ w, ok := other.(*wireError)
+ if !ok {
+ return false
+ }
+ return err.Code == w.Code
+}
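The new wireError.Is method is what lets the tests above use errors.Is against errors that were decoded from the wire: two distinct *wireError values compare as equivalent when their JSON-RPC codes match, regardless of message text or pointer identity. A small sketch, assuming ErrMethodNotFound carries the standard "method not found" code -32601:

package main

import (
	"errors"
	"fmt"

	jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2"
)

func main() {
	// An error decoded from a response is a fresh Go value, not the package
	// sentinel, but it carries the same code...
	received := jsonrpc2.NewError(-32601, "no such method: ping")

	// ...so errors.Is matches it via wireError.Is, which compares codes only.
	fmt.Println(errors.Is(received, jsonrpc2.ErrMethodNotFound)) // true
}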