diff --git a/internal/lsp/cmd/format.go b/internal/lsp/cmd/format.go index bbe9f7a212..c78d2fab54 100644 --- a/internal/lsp/cmd/format.go +++ b/internal/lsp/cmd/format.go @@ -10,9 +10,9 @@ import ( "fmt" "io/ioutil" - "golang.org/x/tools/internal/lsp" "golang.org/x/tools/internal/lsp/diff" "golang.org/x/tools/internal/lsp/protocol" + "golang.org/x/tools/internal/lsp/source" "golang.org/x/tools/internal/span" errors "golang.org/x/xerrors" ) @@ -76,7 +76,7 @@ func (f *format) Run(ctx context.Context, args ...string) error { if err != nil { return errors.Errorf("%v: %v", spn, err) } - sedits, err := lsp.FromProtocolEdits(file.mapper, edits) + sedits, err := source.FromProtocolEdits(file.mapper, edits) if err != nil { return errors.Errorf("%v: %v", spn, err) } diff --git a/internal/lsp/code_action.go b/internal/lsp/code_action.go index 9b6fee47b8..e7fb25bc90 100644 --- a/internal/lsp/code_action.go +++ b/internal/lsp/code_action.go @@ -44,10 +44,6 @@ func (s *Server) codeAction(ctx context.Context, params *protocol.CodeActionPara if err != nil { return nil, err } - m, err := getMapper(ctx, f) - if err != nil { - return nil, err - } // Determine the supported actions for this file kind. fileKind := f.Handle(ctx).Kind() @@ -71,14 +67,9 @@ func (s *Server) codeAction(ctx context.Context, params *protocol.CodeActionPara return nil, errors.Errorf("no supported code action to execute for %s, wanted %v", uri, params.Context.Only) } - spn, err := m.RangeSpan(params.Range) - if err != nil { - return nil, err - } - var codeActions []protocol.CodeAction - edits, editsPerFix, err := organizeImports(ctx, view, spn) + edits, editsPerFix, err := source.AllImportsFixes(ctx, view, f) if err != nil { return nil, err } @@ -105,13 +96,13 @@ func (s *Server) codeAction(ctx context.Context, params *protocol.CodeActionPara // each action is the addition, removal, or renaming of one import. for _, importFix := range editsPerFix { // Get the diagnostics this fix would affect. - if fixDiagnostics := importDiagnostics(importFix.fix, params.Context.Diagnostics); len(fixDiagnostics) > 0 { + if fixDiagnostics := importDiagnostics(importFix.Fix, params.Context.Diagnostics); len(fixDiagnostics) > 0 { codeActions = append(codeActions, protocol.CodeAction{ - Title: importFixTitle(importFix.fix), + Title: importFixTitle(importFix.Fix), Kind: protocol.QuickFix, Edit: &protocol.WorkspaceEdit{ Changes: &map[string][]protocol.TextEdit{ - string(uri): importFix.edits, + string(uri): importFix.Edits, }, }, Diagnostics: fixDiagnostics, @@ -128,7 +119,7 @@ func (s *Server) codeAction(ctx context.Context, params *protocol.CodeActionPara Kind: protocol.SourceOrganizeImports, Edit: &protocol.WorkspaceEdit{ Changes: &map[string][]protocol.TextEdit{ - string(spn.URI()): edits, + string(uri): edits, }, }, }) @@ -142,35 +133,6 @@ type protocolImportFix struct { edits []protocol.TextEdit } -func organizeImports(ctx context.Context, view source.View, s span.Span) ([]protocol.TextEdit, []*protocolImportFix, error) { - f, m, rng, err := spanToRange(ctx, view, s) - if err != nil { - return nil, nil, err - } - edits, editsPerFix, err := source.AllImportsFixes(ctx, view, f, rng) - if err != nil { - return nil, nil, err - } - // Convert all source edits to protocol edits. - pEdits, err := source.ToProtocolEdits(m, edits) - if err != nil { - return nil, nil, err - } - - pEditsPerFix := make([]*protocolImportFix, len(editsPerFix)) - for i, fix := range editsPerFix { - pEdits, err := source.ToProtocolEdits(m, fix.Edits) - if err != nil { - return nil, nil, err - } - pEditsPerFix[i] = &protocolImportFix{ - fix: fix.Fix, - edits: pEdits, - } - } - return pEdits, pEditsPerFix, nil -} - // findImports determines if a given diagnostic represents an error that could // be fixed by organizing imports. // TODO(rstambler): We need a better way to check this than string matching. diff --git a/internal/lsp/format.go b/internal/lsp/format.go index d819afd248..33e9b8d633 100644 --- a/internal/lsp/format.go +++ b/internal/lsp/format.go @@ -7,7 +7,6 @@ package lsp import ( "context" - "golang.org/x/tools/internal/lsp/diff" "golang.org/x/tools/internal/lsp/protocol" "golang.org/x/tools/internal/lsp/source" "golang.org/x/tools/internal/span" @@ -16,56 +15,9 @@ import ( func (s *Server) formatting(ctx context.Context, params *protocol.DocumentFormattingParams) ([]protocol.TextEdit, error) { uri := span.NewURI(params.TextDocument.URI) view := s.session.ViewOf(uri) - spn := span.New(uri, span.Point{}, span.Point{}) - f, m, rng, err := spanToRange(ctx, view, spn) + f, err := view.GetFile(ctx, uri) if err != nil { return nil, err } - edits, err := source.Format(ctx, f, rng) - if err != nil { - return nil, err - } - return source.ToProtocolEdits(m, edits) -} - -func spanToRange(ctx context.Context, view source.View, spn span.Span) (source.GoFile, *protocol.ColumnMapper, span.Range, error) { - f, err := getGoFile(ctx, view, spn.URI()) - if err != nil { - return nil, nil, span.Range{}, err - } - m, err := getMapper(ctx, f) - if err != nil { - return nil, nil, span.Range{}, err - } - rng, err := spn.Range(m.Converter) - if err != nil { - return nil, nil, span.Range{}, err - } - if rng.Start == rng.End { - // If we have a single point, assume we want the whole file. - tok, err := f.GetToken(ctx) - if err != nil { - return nil, nil, span.Range{}, err - } - rng.End = tok.Pos(tok.Size()) - } - return f, m, rng, nil -} - -func FromProtocolEdits(m *protocol.ColumnMapper, edits []protocol.TextEdit) ([]diff.TextEdit, error) { - if edits == nil { - return nil, nil - } - result := make([]diff.TextEdit, len(edits)) - for i, edit := range edits { - spn, err := m.RangeSpan(edit.Range) - if err != nil { - return nil, err - } - result[i] = diff.TextEdit{ - Span: spn, - NewText: edit.NewText, - } - } - return result, nil + return source.Format(ctx, view, f) } diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go index ba1c0266cb..cb13886a1f 100644 --- a/internal/lsp/lsp_test.go +++ b/internal/lsp/lsp_test.go @@ -447,7 +447,7 @@ func (r *runner) Format(t *testing.T, data tests.Formats) { if err != nil { t.Fatal(err) } - sedits, err := FromProtocolEdits(m, edits) + sedits, err := source.FromProtocolEdits(m, edits) if err != nil { t.Error(err) } @@ -493,7 +493,7 @@ func (r *runner) Import(t *testing.T, data tests.Imports) { edits = (*a.Edit.Changes)[string(uri)] } } - sedits, err := FromProtocolEdits(m, edits) + sedits, err := source.FromProtocolEdits(m, edits) if err != nil { t.Error(err) } @@ -679,7 +679,7 @@ func (r *runner) Rename(t *testing.T, data tests.Renames) { t.Fatal(err) } - sedits, err := FromProtocolEdits(m, edits) + sedits, err := source.FromProtocolEdits(m, edits) if err != nil { t.Error(err) } diff --git a/internal/lsp/signature_help.go b/internal/lsp/signature_help.go index f846f94b30..d33774b64b 100644 --- a/internal/lsp/signature_help.go +++ b/internal/lsp/signature_help.go @@ -21,21 +21,9 @@ func (s *Server) signatureHelp(ctx context.Context, params *protocol.TextDocumen if err != nil { return nil, err } - m, err := getMapper(ctx, f) - if err != nil { - return nil, err - } - spn, err := m.PointSpan(params.Position) - if err != nil { - return nil, err - } - rng, err := spn.Range(m.Converter) - if err != nil { - return nil, err - } info, err := source.SignatureHelp(ctx, view, f, params.Position) if err != nil { - log.Print(ctx, "no signature help", tag.Of("At", rng), tag.Of("Failure", err)) + log.Print(ctx, "no signature help", tag.Of("At", params.Position), tag.Of("Failure", err)) return nil, nil } return toProtocolSignatureHelp(info), nil diff --git a/internal/lsp/source/format.go b/internal/lsp/source/format.go index ae49558c03..ec3c6caf5b 100644 --- a/internal/lsp/source/format.go +++ b/internal/lsp/source/format.go @@ -9,28 +9,31 @@ import ( "bytes" "context" "go/format" + "go/token" - "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/packages" "golang.org/x/tools/internal/imports" "golang.org/x/tools/internal/lsp/diff" "golang.org/x/tools/internal/lsp/protocol" "golang.org/x/tools/internal/span" - "golang.org/x/tools/internal/telemetry/log" "golang.org/x/tools/internal/telemetry/trace" errors "golang.org/x/xerrors" ) // Format formats a file with a given range. -func Format(ctx context.Context, f GoFile, rng span.Range) ([]diff.TextEdit, error) { +func Format(ctx context.Context, view View, f File) ([]protocol.TextEdit, error) { ctx, done := trace.StartSpan(ctx, "source.Format") defer done() - file, err := f.GetAST(ctx, ParseFull) + gof, ok := f.(GoFile) + if !ok { + return nil, errors.Errorf("formatting is not supported for non-Go files") + } + file, err := gof.GetAST(ctx, ParseFull) if file == nil { return nil, err } - pkg, err := f.GetPackage(ctx) + pkg, err := gof.GetPackage(ctx) if err != nil { return nil, err } @@ -43,13 +46,8 @@ func Format(ctx context.Context, f GoFile, rng span.Range) ([]diff.TextEdit, err if err != nil { return nil, err } - return computeTextEdits(ctx, f, string(formatted)), nil + return computeTextEdits(ctx, view.Session().Cache().FileSet(), f, string(formatted)) } - path, exact := astutil.PathEnclosingInterval(file, rng.Start, rng.End) - if !exact || len(path) == 0 { - return nil, errors.Errorf("no exact AST node matching the specified range") - } - node := path[0] fset := f.FileSet() buf := &bytes.Buffer{} @@ -58,10 +56,10 @@ func Format(ctx context.Context, f GoFile, rng span.Range) ([]diff.TextEdit, err // of Go used to build the LSP server will determine how it formats code. // This should be acceptable for all users, who likely be prompted to rebuild // the LSP server on each Go release. - if err := format.Node(buf, fset, node); err != nil { + if err := format.Node(buf, fset, file); err != nil { return nil, err } - return computeTextEdits(ctx, f, buf.String()), nil + return computeTextEdits(ctx, view.Session().Cache().FileSet(), f, buf.String()) } func formatSource(ctx context.Context, file File) ([]byte, error) { @@ -75,7 +73,7 @@ func formatSource(ctx context.Context, file File) ([]byte, error) { } // Imports formats a file using the goimports tool. -func Imports(ctx context.Context, view View, f GoFile, rng span.Range) ([]diff.TextEdit, error) { +func Imports(ctx context.Context, view View, f GoFile, rng span.Range) ([]protocol.TextEdit, error) { ctx, done := trace.StartSpan(ctx, "source.Imports") defer done() data, _, err := f.Handle(ctx).Read(ctx) @@ -108,26 +106,32 @@ func Imports(ctx context.Context, view View, f GoFile, rng span.Range) ([]diff.T if err != nil { return nil, err } - return computeTextEdits(ctx, f, string(formatted)), nil + return computeTextEdits(ctx, view.Session().Cache().FileSet(), f, string(formatted)) } type ImportFix struct { Fix *imports.ImportFix - Edits []diff.TextEdit + Edits []protocol.TextEdit } // AllImportsFixes formats f for each possible fix to the imports. // In addition to returning the result of applying all edits, // it returns a list of fixes that could be applied to the file, with the // corresponding TextEdits that would be needed to apply that fix. -func AllImportsFixes(ctx context.Context, view View, f GoFile, rng span.Range) (edits []diff.TextEdit, editsPerFix []*ImportFix, err error) { +func AllImportsFixes(ctx context.Context, view View, f File) (edits []protocol.TextEdit, editsPerFix []*ImportFix, err error) { ctx, done := trace.StartSpan(ctx, "source.AllImportsFixes") defer done() + + gof, ok := f.(GoFile) + if !ok { + return nil, nil, errors.Errorf("no imports fixes for non-Go files: %v", err) + } + data, _, err := f.Handle(ctx).Read(ctx) if err != nil { return nil, nil, err } - pkg, err := f.GetPackage(ctx) + pkg, err := gof.GetPackage(ctx) if err != nil { return nil, nil, err } @@ -153,7 +157,10 @@ func AllImportsFixes(ctx context.Context, view View, f GoFile, rng span.Range) ( if err != nil { return err } - edits = computeTextEdits(ctx, f, string(formatted)) + edits, err = computeTextEdits(ctx, view.Session().Cache().FileSet(), f, string(formatted)) + if err != nil { + return err + } // Add the edits for each fix to the result. editsPerFix = make([]*ImportFix, len(fixes)) for i, fix := range fixes { @@ -161,12 +168,16 @@ func AllImportsFixes(ctx context.Context, view View, f GoFile, rng span.Range) ( if err != nil { return err } + edits, err := computeTextEdits(ctx, view.Session().Cache().FileSet(), f, string(formatted)) + if err != nil { + return err + } editsPerFix[i] = &ImportFix{ Fix: fix, - Edits: computeTextEdits(ctx, f, string(formatted)), + Edits: edits, } } - return err + return nil } err = view.RunProcessEnvFunc(ctx, importFn, options) if err != nil { @@ -225,15 +236,19 @@ func hasListErrors(errors []packages.Error) bool { return false } -func computeTextEdits(ctx context.Context, file File, formatted string) (edits []diff.TextEdit) { +func computeTextEdits(ctx context.Context, fset *token.FileSet, f File, formatted string) ([]protocol.TextEdit, error) { ctx, done := trace.StartSpan(ctx, "source.computeTextEdits") defer done() - data, _, err := file.Handle(ctx).Read(ctx) + + data, _, err := f.Handle(ctx).Read(ctx) if err != nil { - log.Error(ctx, "Cannot compute text edits", err) - return nil + return nil, err } - return diff.ComputeEdits(file.URI(), string(data), formatted) + edits := diff.ComputeEdits(f.URI(), string(data), formatted) + m := protocol.NewColumnMapper(f.URI(), f.URI().Filename(), fset, nil, data) + + return ToProtocolEdits(m, edits) + } func ToProtocolEdits(m *protocol.ColumnMapper, edits []diff.TextEdit) ([]protocol.TextEdit, error) { @@ -253,3 +268,21 @@ func ToProtocolEdits(m *protocol.ColumnMapper, edits []diff.TextEdit) ([]protoco } return result, nil } + +func FromProtocolEdits(m *protocol.ColumnMapper, edits []protocol.TextEdit) ([]diff.TextEdit, error) { + if edits == nil { + return nil, nil + } + result := make([]diff.TextEdit, len(edits)) + for i, edit := range edits { + spn, err := m.RangeSpan(edit.Range) + if err != nil { + return nil, err + } + result[i] = diff.TextEdit{ + Span: spn, + NewText: edit.NewText, + } + } + return result, nil +} diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go index 9865a0e8e4..5881cd529a 100644 --- a/internal/lsp/source/source_test.go +++ b/internal/lsp/source/source_test.go @@ -415,15 +415,7 @@ func (r *runner) Format(t *testing.T, data tests.Formats) { if err != nil { t.Fatalf("failed for %v: %v", spn, err) } - tok, err := f.(source.GoFile).GetToken(ctx) - if err != nil { - t.Fatalf("failed to get token for %s: %v", spn.URI(), err) - } - rng, err := spn.Range(span.NewTokenConverter(f.FileSet(), tok)) - if err != nil { - t.Fatalf("failed for %v: %v", spn, err) - } - edits, err := source.Format(ctx, f.(source.GoFile), rng) + edits, err := source.Format(ctx, r.view, f) if err != nil { if gofmted != "" { t.Error(err) @@ -435,7 +427,12 @@ func (r *runner) Format(t *testing.T, data tests.Formats) { t.Error(err) continue } - got := diff.ApplyEdits(string(data), edits) + m := protocol.NewColumnMapper(uri, filename, r.view.Session().Cache().FileSet(), nil, data) + diffEdits, err := source.FromProtocolEdits(m, edits) + if err != nil { + t.Error(err) + } + got := diff.ApplyEdits(string(data), diffEdits) if gofmted != got { t.Errorf("format failed for %s, expected:\n%v\ngot:\n%v", filename, gofmted, got) } @@ -476,7 +473,12 @@ func (r *runner) Import(t *testing.T, data tests.Imports) { t.Error(err) continue } - got := diff.ApplyEdits(string(data), edits) + m := protocol.NewColumnMapper(uri, filename, r.view.Session().Cache().FileSet(), nil, data) + diffEdits, err := source.FromProtocolEdits(m, edits) + if err != nil { + t.Error(err) + } + got := diff.ApplyEdits(string(data), diffEdits) if goimported != got { t.Errorf("import failed for %s, expected:\n%v\ngot:\n%v", filename, goimported, got) }