internal/lsp: cancel early

This change allows us to hanel cancel messages as they go into the queue, and
cancel messages that are ahead of them in the queue but not being processed yet.
This should reduce the amount of redundant work that we do when we are handling
a cancel storm.

Change-Id: Id1a58991407d75b68d65bacf96350a4dd69d4d2b
Reviewed-on: https://go-review.googlesource.com/c/tools/+/200766
Run-TryBot: Ian Cottrell <iancottrell@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
This commit is contained in:
Ian Cottrell 2019-10-11 16:08:39 -04:00
parent 7178990c25
commit 774d2ec196
8 changed files with 91 additions and 79 deletions

View File

@ -38,9 +38,9 @@ type Handler interface {
// response // response
// Request is called near the start of processing any request. // Request is called near the start of processing any request.
Request(ctx context.Context, direction Direction, r *WireRequest) context.Context Request(ctx context.Context, conn *Conn, direction Direction, r *WireRequest) context.Context
// Response is called near the start of processing any response. // Response is called near the start of processing any response.
Response(ctx context.Context, direction Direction, r *WireResponse) context.Context Response(ctx context.Context, conn *Conn, direction Direction, r *WireResponse) context.Context
// Done is called when any request is fully processed. // Done is called when any request is fully processed.
// For calls, this means the response has also been processed, for notifies // For calls, this means the response has also been processed, for notifies
// this is as soon as the message has been written to the stream. // this is as soon as the message has been written to the stream.
@ -90,11 +90,11 @@ func (EmptyHandler) Cancel(ctx context.Context, conn *Conn, id ID, cancelled boo
return false return false
} }
func (EmptyHandler) Request(ctx context.Context, direction Direction, r *WireRequest) context.Context { func (EmptyHandler) Request(ctx context.Context, conn *Conn, direction Direction, r *WireRequest) context.Context {
return ctx return ctx
} }
func (EmptyHandler) Response(ctx context.Context, direction Direction, r *WireResponse) context.Context { func (EmptyHandler) Response(ctx context.Context, conn *Conn, direction Direction, r *WireResponse) context.Context {
return ctx return ctx
} }

View File

@ -110,7 +110,7 @@ func (c *Conn) Notify(ctx context.Context, method string, params interface{}) (e
return fmt.Errorf("marshalling notify request: %v", err) return fmt.Errorf("marshalling notify request: %v", err)
} }
for _, h := range c.handlers { for _, h := range c.handlers {
ctx = h.Request(ctx, Send, request) ctx = h.Request(ctx, c, Send, request)
} }
defer func() { defer func() {
for _, h := range c.handlers { for _, h := range c.handlers {
@ -145,7 +145,7 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface
return fmt.Errorf("marshalling call request: %v", err) return fmt.Errorf("marshalling call request: %v", err)
} }
for _, h := range c.handlers { for _, h := range c.handlers {
ctx = h.Request(ctx, Send, request) ctx = h.Request(ctx, c, Send, request)
} }
// we have to add ourselves to the pending map before we send, otherwise we // we have to add ourselves to the pending map before we send, otherwise we
// are racing the response // are racing the response
@ -175,7 +175,7 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface
select { select {
case response := <-rchan: case response := <-rchan:
for _, h := range c.handlers { for _, h := range c.handlers {
ctx = h.Response(ctx, Receive, response) ctx = h.Response(ctx, c, Receive, response)
} }
// is it an error response? // is it an error response?
if response.Error != nil { if response.Error != nil {
@ -262,7 +262,7 @@ func (r *Request) Reply(ctx context.Context, result interface{}, err error) erro
return err return err
} }
for _, h := range r.conn.handlers { for _, h := range r.conn.handlers {
ctx = h.Response(ctx, Send, response) ctx = h.Response(ctx, r.conn, Send, response)
} }
n, err := r.conn.stream.Write(ctx, data) n, err := r.conn.stream.Write(ctx, data)
for _, h := range r.conn.handlers { for _, h := range r.conn.handlers {
@ -347,7 +347,7 @@ func (c *Conn) Run(runCtx context.Context) error {
}, },
} }
for _, h := range c.handlers { for _, h := range c.handlers {
reqCtx = h.Request(reqCtx, Receive, &req.WireRequest) reqCtx = h.Request(reqCtx, c, Receive, &req.WireRequest)
reqCtx = h.Read(reqCtx, n) reqCtx = h.Read(reqCtx, n)
} }
c.setHandling(req, true) c.setHandling(req, true)

View File

@ -164,7 +164,7 @@ func (h *handle) Cancel(ctx context.Context, conn *jsonrpc2.Conn, id jsonrpc2.ID
return false return false
} }
func (h *handle) Request(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireRequest) context.Context { func (h *handle) Request(ctx context.Context, conn *jsonrpc2.Conn, direction jsonrpc2.Direction, r *jsonrpc2.WireRequest) context.Context {
if h.log { if h.log {
if r.ID != nil { if r.ID != nil {
log.Printf("%v call [%v] %s %v", direction, r.ID, r.Method, r.Params) log.Printf("%v call [%v] %s %v", direction, r.ID, r.Method, r.Params)
@ -177,7 +177,7 @@ func (h *handle) Request(ctx context.Context, direction jsonrpc2.Direction, r *j
return ctx return ctx
} }
func (h *handle) Response(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireResponse) context.Context { func (h *handle) Response(ctx context.Context, conn *jsonrpc2.Conn, direction jsonrpc2.Direction, r *jsonrpc2.WireResponse) context.Context {
if h.log { if h.log {
method := ctx.Value("method") method := ctx.Value("method")
elapsed := time.Since(ctx.Value("start").(time.Time)) elapsed := time.Since(ctx.Value("start").(time.Time))

View File

@ -149,7 +149,7 @@ func (h *handler) Cancel(ctx context.Context, conn *jsonrpc2.Conn, id jsonrpc2.I
return false return false
} }
func (h *handler) Request(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireRequest) context.Context { func (h *handler) Request(ctx context.Context, conn *jsonrpc2.Conn, direction jsonrpc2.Direction, r *jsonrpc2.WireRequest) context.Context {
if r.Method == "" { if r.Method == "" {
panic("no method in rpc stats") panic("no method in rpc stats")
} }
@ -174,7 +174,7 @@ func (h *handler) Request(ctx context.Context, direction jsonrpc2.Direction, r *
return ctx return ctx
} }
func (h *handler) Response(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireResponse) context.Context { func (h *handler) Response(ctx context.Context, conn *jsonrpc2.Conn, direction jsonrpc2.Direction, r *jsonrpc2.WireResponse) context.Context {
return ctx return ctx
} }

View File

@ -6,6 +6,7 @@ package protocol
import ( import (
"context" "context"
"encoding/json"
"golang.org/x/tools/internal/jsonrpc2" "golang.org/x/tools/internal/jsonrpc2"
"golang.org/x/tools/internal/telemetry/log" "golang.org/x/tools/internal/telemetry/log"
@ -13,6 +14,11 @@ import (
"golang.org/x/tools/internal/xcontext" "golang.org/x/tools/internal/xcontext"
) )
const (
// RequestCancelledError should be used when a request is cancelled early.
RequestCancelledError = -32800
)
type DocumentUri = string type DocumentUri = string
type canceller struct{ jsonrpc2.EmptyHandler } type canceller struct{ jsonrpc2.EmptyHandler }
@ -27,6 +33,18 @@ type serverHandler struct {
server Server server Server
} }
func (canceller) Request(ctx context.Context, conn *jsonrpc2.Conn, direction jsonrpc2.Direction, r *jsonrpc2.WireRequest) context.Context {
if direction == jsonrpc2.Receive && r.Method == "$/cancelRequest" {
var params CancelParams
if err := json.Unmarshal(*r.Params, &params); err != nil {
log.Error(ctx, "", err)
} else {
conn.Cancel(params.ID)
}
}
return ctx
}
func (canceller) Cancel(ctx context.Context, conn *jsonrpc2.Conn, id jsonrpc2.ID, cancelled bool) bool { func (canceller) Cancel(ctx context.Context, conn *jsonrpc2.Conn, id jsonrpc2.ID, cancelled bool) bool {
if cancelled { if cancelled {
return false return false

View File

@ -8,6 +8,7 @@ import (
"golang.org/x/tools/internal/jsonrpc2" "golang.org/x/tools/internal/jsonrpc2"
"golang.org/x/tools/internal/telemetry/log" "golang.org/x/tools/internal/telemetry/log"
"golang.org/x/tools/internal/xcontext"
) )
type Client interface { type Client interface {
@ -27,15 +28,12 @@ func (h clientHandler) Deliver(ctx context.Context, r *jsonrpc2.Request, deliver
if delivered { if delivered {
return false return false
} }
switch r.Method { if ctx.Err() != nil {
case "$/cancelRequest": ctx := xcontext.Detach(ctx)
var params CancelParams r.Reply(ctx, nil, jsonrpc2.NewErrorf(RequestCancelledError, ""))
if err := json.Unmarshal(*r.Params, &params); err != nil {
sendParseError(ctx, r, err)
return true
}
r.Conn().Cancel(params.ID)
return true return true
}
switch r.Method {
case "window/showMessage": // notif case "window/showMessage": // notif
var params ShowMessageParams var params ShowMessageParams
if err := json.Unmarshal(*r.Params, &params); err != nil { if err := json.Unmarshal(*r.Params, &params); err != nil {

View File

@ -8,6 +8,7 @@ import (
"golang.org/x/tools/internal/jsonrpc2" "golang.org/x/tools/internal/jsonrpc2"
"golang.org/x/tools/internal/telemetry/log" "golang.org/x/tools/internal/telemetry/log"
"golang.org/x/tools/internal/xcontext"
) )
type Server interface { type Server interface {
@ -46,13 +47,13 @@ type Server interface {
Symbol(context.Context, *WorkspaceSymbolParams) ([]SymbolInformation, error) Symbol(context.Context, *WorkspaceSymbolParams) ([]SymbolInformation, error)
CodeLens(context.Context, *CodeLensParams) ([]CodeLens, error) CodeLens(context.Context, *CodeLensParams) ([]CodeLens, error)
ResolveCodeLens(context.Context, *CodeLens) (*CodeLens, error) ResolveCodeLens(context.Context, *CodeLens) (*CodeLens, error)
DocumentLink(context.Context, *DocumentLinkParams) ([]DocumentLink, error)
ResolveDocumentLink(context.Context, *DocumentLink) (*DocumentLink, error)
Formatting(context.Context, *DocumentFormattingParams) ([]TextEdit, error) Formatting(context.Context, *DocumentFormattingParams) ([]TextEdit, error)
RangeFormatting(context.Context, *DocumentRangeFormattingParams) ([]TextEdit, error) RangeFormatting(context.Context, *DocumentRangeFormattingParams) ([]TextEdit, error)
OnTypeFormatting(context.Context, *DocumentOnTypeFormattingParams) ([]TextEdit, error) OnTypeFormatting(context.Context, *DocumentOnTypeFormattingParams) ([]TextEdit, error)
Rename(context.Context, *RenameParams) (*WorkspaceEdit, error) Rename(context.Context, *RenameParams) (*WorkspaceEdit, error)
PrepareRename(context.Context, *PrepareRenameParams) (*Range, error) PrepareRename(context.Context, *PrepareRenameParams) (*Range, error)
DocumentLink(context.Context, *DocumentLinkParams) ([]DocumentLink, error)
ResolveDocumentLink(context.Context, *DocumentLink) (*DocumentLink, error)
ExecuteCommand(context.Context, *ExecuteCommandParams) (interface{}, error) ExecuteCommand(context.Context, *ExecuteCommandParams) (interface{}, error)
} }
@ -60,15 +61,12 @@ func (h serverHandler) Deliver(ctx context.Context, r *jsonrpc2.Request, deliver
if delivered { if delivered {
return false return false
} }
switch r.Method { if ctx.Err() != nil {
case "$/cancelRequest": ctx := xcontext.Detach(ctx)
var params CancelParams r.Reply(ctx, nil, jsonrpc2.NewErrorf(RequestCancelledError, ""))
if err := json.Unmarshal(*r.Params, &params); err != nil {
sendParseError(ctx, r, err)
return true
}
r.Conn().Cancel(params.ID)
return true return true
}
switch r.Method {
case "workspace/didChangeWorkspaceFolders": // notif case "workspace/didChangeWorkspaceFolders": // notif
var params DidChangeWorkspaceFoldersParams var params DidChangeWorkspaceFoldersParams
if err := json.Unmarshal(*r.Params, &params); err != nil { if err := json.Unmarshal(*r.Params, &params); err != nil {
@ -435,6 +433,28 @@ func (h serverHandler) Deliver(ctx context.Context, r *jsonrpc2.Request, deliver
log.Error(ctx, "", err) log.Error(ctx, "", err)
} }
return true return true
case "textDocument/documentLink": // req
var params DocumentLinkParams
if err := json.Unmarshal(*r.Params, &params); err != nil {
sendParseError(ctx, r, err)
return true
}
resp, err := h.server.DocumentLink(ctx, &params)
if err := r.Reply(ctx, resp, err); err != nil {
log.Error(ctx, "", err)
}
return true
case "documentLink/resolve": // req
var params DocumentLink
if err := json.Unmarshal(*r.Params, &params); err != nil {
sendParseError(ctx, r, err)
return true
}
resp, err := h.server.ResolveDocumentLink(ctx, &params)
if err := r.Reply(ctx, resp, err); err != nil {
log.Error(ctx, "", err)
}
return true
case "textDocument/formatting": // req case "textDocument/formatting": // req
var params DocumentFormattingParams var params DocumentFormattingParams
if err := json.Unmarshal(*r.Params, &params); err != nil { if err := json.Unmarshal(*r.Params, &params); err != nil {
@ -490,28 +510,6 @@ func (h serverHandler) Deliver(ctx context.Context, r *jsonrpc2.Request, deliver
log.Error(ctx, "", err) log.Error(ctx, "", err)
} }
return true return true
case "textDocument/documentLink": // req
var params DocumentLinkParams
if err := json.Unmarshal(*r.Params, &params); err != nil {
sendParseError(ctx, r, err)
return true
}
resp, err := h.server.DocumentLink(ctx, &params)
if err := r.Reply(ctx, resp, err); err != nil {
log.Error(ctx, "", err)
}
return true
case "documentLink/resolve": // req
var params DocumentLink
if err := json.Unmarshal(*r.Params, &params); err != nil {
sendParseError(ctx, r, err)
return true
}
resp, err := h.server.ResolveDocumentLink(ctx, &params)
if err := r.Reply(ctx, resp, err); err != nil {
log.Error(ctx, "", err)
}
return true
case "workspace/executeCommand": // req case "workspace/executeCommand": // req
var params ExecuteCommandParams var params ExecuteCommandParams
if err := json.Unmarshal(*r.Params, &params); err != nil { if err := json.Unmarshal(*r.Params, &params); err != nil {
@ -756,6 +754,22 @@ func (s *serverDispatcher) ResolveCodeLens(ctx context.Context, params *CodeLens
return &result, nil return &result, nil
} }
func (s *serverDispatcher) DocumentLink(ctx context.Context, params *DocumentLinkParams) ([]DocumentLink, error) {
var result []DocumentLink
if err := s.Conn.Call(ctx, "textDocument/documentLink", params, &result); err != nil {
return nil, err
}
return result, nil
}
func (s *serverDispatcher) ResolveDocumentLink(ctx context.Context, params *DocumentLink) (*DocumentLink, error) {
var result DocumentLink
if err := s.Conn.Call(ctx, "documentLink/resolve", params, &result); err != nil {
return nil, err
}
return &result, nil
}
func (s *serverDispatcher) Formatting(ctx context.Context, params *DocumentFormattingParams) ([]TextEdit, error) { func (s *serverDispatcher) Formatting(ctx context.Context, params *DocumentFormattingParams) ([]TextEdit, error) {
var result []TextEdit var result []TextEdit
if err := s.Conn.Call(ctx, "textDocument/formatting", params, &result); err != nil { if err := s.Conn.Call(ctx, "textDocument/formatting", params, &result); err != nil {
@ -796,22 +810,6 @@ func (s *serverDispatcher) PrepareRename(ctx context.Context, params *PrepareRen
return &result, nil return &result, nil
} }
func (s *serverDispatcher) DocumentLink(ctx context.Context, params *DocumentLinkParams) ([]DocumentLink, error) {
var result []DocumentLink
if err := s.Conn.Call(ctx, "textDocument/documentLink", params, &result); err != nil {
return nil, err
}
return result, nil
}
func (s *serverDispatcher) ResolveDocumentLink(ctx context.Context, params *DocumentLink) (*DocumentLink, error) {
var result DocumentLink
if err := s.Conn.Call(ctx, "documentLink/resolve", params, &result); err != nil {
return nil, err
}
return &result, nil
}
func (s *serverDispatcher) ExecuteCommand(ctx context.Context, params *ExecuteCommandParams) (interface{}, error) { func (s *serverDispatcher) ExecuteCommand(ctx context.Context, params *ExecuteCommandParams) (interface{}, error) {
var result interface{} var result interface{}
if err := s.Conn.Call(ctx, "workspace/executeCommand", params, &result); err != nil { if err := s.Conn.Call(ctx, "workspace/executeCommand", params, &result); err != nil {

View File

@ -224,7 +224,8 @@ function output(side: side) {
"golang.org/x/tools/internal/jsonrpc2" "golang.org/x/tools/internal/jsonrpc2"
"golang.org/x/tools/internal/telemetry/log" "golang.org/x/tools/internal/telemetry/log"
) "golang.org/x/tools/internal/xcontext"
)
`); `);
const a = side.name[0].toUpperCase() + side.name.substring(1) const a = side.name[0].toUpperCase() + side.name.substring(1)
f(`type ${a} interface {`); f(`type ${a} interface {`);
@ -235,15 +236,12 @@ function output(side: side) {
if delivered { if delivered {
return false return false
} }
switch r.Method { if ctx.Err() != nil {
case "$/cancelRequest": ctx := xcontext.Detach(ctx)
var params CancelParams r.Reply(ctx, nil, jsonrpc2.NewErrorf(RequestCancelledError, ""))
if err := json.Unmarshal(*r.Params, &params); err != nil { return true
sendParseError(ctx, r, err) }
return true switch r.Method {`);
}
r.Conn().Cancel(params.ID)
return true`);
side.cases.forEach((v) => {f(v)}); side.cases.forEach((v) => {f(v)});
f(` f(`
default: default: