diff --git a/internal/lsp/source/format.go b/internal/lsp/source/format.go index 4a8974f3bd..c6aefeed23 100644 --- a/internal/lsp/source/format.go +++ b/internal/lsp/source/format.go @@ -8,7 +8,10 @@ package source import ( "bytes" "context" + "go/ast" "go/format" + "go/parser" + "go/token" "golang.org/x/tools/internal/imports" "golang.org/x/tools/internal/lsp/diff" @@ -84,64 +87,6 @@ func formatSource(ctx context.Context, s Snapshot, f File) ([]byte, error) { return format.Source(data) } -// Imports formats a file using the goimports tool. -func Imports(ctx context.Context, view View, f File) ([]protocol.TextEdit, error) { - ctx, done := trace.StartSpan(ctx, "source.Imports") - defer done() - - _, cphs, err := view.CheckPackageHandles(ctx, f) - if err != nil { - return nil, err - } - cph, err := NarrowestCheckPackageHandle(cphs) - if err != nil { - return nil, err - } - pkg, err := cph.Check(ctx) - if err != nil { - return nil, err - } - if hasListErrors(pkg) { - return nil, errors.Errorf("%s has list errors, not running goimports", f.URI()) - } - ph, err := pkg.File(f.URI()) - if err != nil { - return nil, err - } - // Be extra careful that the file's ParseMode is correct, - // otherwise we might replace the user's code with a trimmed AST. - if ph.Mode() != ParseFull { - return nil, errors.Errorf("%s was parsed in the incorrect mode", ph.File().Identity().URI) - } - options := &imports.Options{ - // Defaults. - AllErrors: true, - Comments: true, - Fragment: true, - FormatOnly: false, - TabIndent: true, - TabWidth: 8, - } - var formatted []byte - importFn := func(opts *imports.Options) error { - data, _, err := ph.File().Read(ctx) - if err != nil { - return err - } - formatted, err = imports.Process(ph.File().Identity().URI.Filename(), data, opts) - return err - } - err = view.RunProcessEnvFunc(ctx, importFn, options) - if err != nil { - return nil, err - } - _, m, _, err := ph.Parse(ctx) - if err != nil { - return nil, err - } - return computeTextEdits(ctx, view, ph.File(), m, string(formatted)) -} - type ImportFix struct { Fix *imports.ImportFix Edits []protocol.TextEdit @@ -151,7 +96,7 @@ type ImportFix struct { // In addition to returning the result of applying all edits, // it returns a list of fixes that could be applied to the file, with the // corresponding TextEdits that would be needed to apply that fix. -func AllImportsFixes(ctx context.Context, view View, f File) (edits []protocol.TextEdit, editsPerFix []*ImportFix, err error) { +func AllImportsFixes(ctx context.Context, view View, f File) (allFixEdits []protocol.TextEdit, editsPerFix []*ImportFix, err error) { ctx, done := trace.StartSpan(ctx, "source.AllImportsFixes") defer done() @@ -170,6 +115,16 @@ func AllImportsFixes(ctx context.Context, view View, f File) (edits []protocol.T if hasListErrors(pkg) { return nil, nil, errors.Errorf("%s has list errors, not running goimports", f.URI()) } + var ph ParseGoHandle + for _, h := range pkg.Files() { + if h.File().Identity().URI == f.URI() { + ph = h + } + } + if ph == nil { + return nil, nil, errors.Errorf("no ParseGoHandle for %s", f.URI()) + } + options := &imports.Options{ // Defaults. AllErrors: true, @@ -179,65 +134,150 @@ func AllImportsFixes(ctx context.Context, view View, f File) (edits []protocol.T TabIndent: true, TabWidth: 8, } - importFn := func(opts *imports.Options) error { - var ph ParseGoHandle - for _, h := range pkg.Files() { - if h.File().Identity().URI == f.URI() { - ph = h - } - } - if ph == nil { - return errors.Errorf("no ParseGoHandle for %s", f.URI()) - } - data, _, err := ph.File().Read(ctx) - if err != nil { - return err - } - fixes, err := imports.FixImports(f.URI().Filename(), data, opts) - if err != nil { - return err - } - // Do not change the file if there are no import fixes. - if len(fixes) == 0 { - return nil - } - // Apply all of the import fixes to the file. - formatted, err := imports.ApplyFixes(fixes, f.URI().Filename(), data, options) - if err != nil { - return err - } - _, m, _, err := ph.Parse(ctx) - if err != nil { - return err - } - edits, err = computeTextEdits(ctx, view, ph.File(), m, string(formatted)) - if err != nil { - return err - } - // Add the edits for each fix to the result. - editsPerFix = make([]*ImportFix, len(fixes)) - for i, fix := range fixes { - formatted, err := imports.ApplyFixes([]*imports.ImportFix{fix}, f.URI().Filename(), data, options) - if err != nil { - return err - } - edits, err := computeTextEdits(ctx, view, ph.File(), m, string(formatted)) - if err != nil { - return err - } - editsPerFix[i] = &ImportFix{ - Fix: fix, - Edits: edits, - } - } - return nil - } - err = view.RunProcessEnvFunc(ctx, importFn, options) + err = view.RunProcessEnvFunc(ctx, func(opts *imports.Options) error { + allFixEdits, editsPerFix, err = computeImportEdits(ctx, view, ph, opts) + return err + }, options) if err != nil { return nil, nil, err } - return edits, editsPerFix, nil + return allFixEdits, editsPerFix, nil +} + +// computeImportEdits computes a set of edits that perform one or all of the +// necessary import fixes. +func computeImportEdits(ctx context.Context, view View, ph ParseGoHandle, options *imports.Options) (allFixEdits []protocol.TextEdit, editsPerFix []*ImportFix, err error) { + filename := ph.File().Identity().URI.Filename() + + // Build up basic information about the original file. + origData, _, err := ph.File().Read(ctx) + if err != nil { + return nil, nil, err + } + origAST, origMapper, _, err := ph.Parse(ctx) + if err != nil { + return nil, nil, err + } + origImports, origImportOffset := trimToImports(view.Session().Cache().FileSet(), origAST, origData) + + computeFixEdits := func(fixes []*imports.ImportFix) ([]protocol.TextEdit, error) { + // Apply the fixes and re-parse the file so that we can locate the + // new imports. + fixedData, err := imports.ApplyFixes(fixes, filename, origData, options) + if err != nil { + return nil, err + } + fixedFset := token.NewFileSet() + fixedAST, err := parser.ParseFile(fixedFset, filename, fixedData, parser.ImportsOnly) + if err != nil { + return nil, err + } + fixedImports, fixedImportsOffset := trimToImports(fixedFset, fixedAST, fixedData) + + // Prepare the diff. If both sides had import statements, we can diff + // just those sections against each other, then shift the resulting + // edits to the right lines in the original file. + left, right := origImports, fixedImports + converter := span.NewContentConverter(filename, origImports) + offset := origImportOffset + + // If one side or the other has no imports, we won't know where to + // anchor the diffs. Instead, use the beginning of the file, up to its + // first non-imports decl. We know the imports code will insert + // somewhere before that. + if origImportOffset == 0 || fixedImportsOffset == 0 { + left = trimToFirstNonImport(view.Session().Cache().FileSet(), origAST, origData) + // We need the whole AST here, not just the ImportsOnly AST we parsed above. + fixedAST, err = parser.ParseFile(fixedFset, filename, fixedData, 0) + if err != nil { + return nil, err + } + right = trimToFirstNonImport(fixedFset, fixedAST, fixedData) + // We're now working with a prefix of the original file, so we can + // use the original converter, and there is no offset on the edits. + converter = origMapper.Converter + offset = 0 + } + + // Perform the diff and adjust the results for the trimming, if any. + edits := view.Options().ComputeEdits(ph.File().Identity().URI, string(left), string(right)) + for i := range edits { + s, err := edits[i].Span.WithPosition(converter) + if err != nil { + return nil, err + } + start := span.NewPoint(s.Start().Line()+offset, s.Start().Column(), -1) + end := span.NewPoint(s.End().Line()+offset, s.End().Column(), -1) + edits[i].Span = span.New(s.URI(), start, end) + } + return ToProtocolEdits(origMapper, edits) + } + + allFixes, err := imports.FixImports(filename, origData, options) + if err != nil { + return nil, nil, err + } + + allFixEdits, err = computeFixEdits(allFixes) + if err != nil { + return nil, nil, err + } + + // Apply all of the import fixes to the file. + // Add the edits for each fix to the result. + for _, fix := range allFixes { + edits, err := computeFixEdits([]*imports.ImportFix{fix}) + if err != nil { + return nil, nil, err + } + editsPerFix = append(editsPerFix, &ImportFix{ + Fix: fix, + Edits: edits, + }) + } + return allFixEdits, editsPerFix, nil +} + +// trimToImports returns a section of the source file that covers all of the +// import declarations, and the line offset into the file that section starts at. +func trimToImports(fset *token.FileSet, f *ast.File, src []byte) ([]byte, int) { + var firstImport, lastImport ast.Decl + for _, decl := range f.Decls { + if gen, ok := decl.(*ast.GenDecl); ok && gen.Tok == token.IMPORT { + if firstImport == nil { + firstImport = decl + } + lastImport = decl + } + } + + if firstImport == nil { + return nil, 0 + } + start := firstImport.Pos() + end := fset.File(f.Pos()).LineStart(fset.Position(lastImport.End()).Line + 1) + startLineOffset := fset.Position(start).Line - 1 // lines are 1-indexed. + return src[fset.Position(firstImport.Pos()).Offset:fset.Position(end).Offset], startLineOffset +} + +// trimToFirstNonImport returns src from the beginning to the first non-import +// declaration, or the end of the file if there is no such decl. +func trimToFirstNonImport(fset *token.FileSet, f *ast.File, src []byte) []byte { + var firstDecl ast.Decl + for _, decl := range f.Decls { + if gen, ok := decl.(*ast.GenDecl); ok && gen.Tok == token.IMPORT { + continue + } + firstDecl = decl + break + } + + end := f.End() + if firstDecl != nil { + end = fset.File(f.Pos()).LineStart(fset.Position(firstDecl.Pos()).Line - 1) + } + return src[fset.Position(f.Pos()).Offset:fset.Position(end).Offset] } // CandidateImports returns every import that could be added to filename. diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go index d3590e532e..65a7371d59 100644 --- a/internal/lsp/source/source_test.go +++ b/internal/lsp/source/source_test.go @@ -450,22 +450,14 @@ func (r *runner) Import(t *testing.T, spn span.Span) { ctx := r.ctx uri := spn.URI() filename := uri.Filename() - goimported := string(r.data.Golden("goimports", filename, func() ([]byte, error) { - cmd := exec.Command("goimports", filename) - out, _ := cmd.Output() // ignore error, sometimes we have intentionally ungofmt-able files - return out, nil - })) f, err := r.view.GetFile(ctx, uri) if err != nil { t.Fatalf("failed for %v: %v", spn, err) } fh := r.view.Snapshot().Handle(r.ctx, f) - edits, err := source.Imports(ctx, r.view, f) + edits, _, err := source.AllImportsFixes(ctx, r.view, f) if err != nil { - if goimported != "" { - t.Error(err) - } - return + t.Error(err) } data, _, err := fh.Read(ctx) if err != nil { @@ -480,8 +472,11 @@ func (r *runner) Import(t *testing.T, spn span.Span) { t.Error(err) } got := diff.ApplyEdits(string(data), diffEdits) - if goimported != got { - t.Errorf("import failed for %s, expected:\n%v\ngot:\n%v", filename, goimported, got) + want := string(r.data.Golden("goimports", filename, func() ([]byte, error) { + return []byte(got), nil + })) + if want != got { + t.Errorf("import failed for %s, expected:\n%v\ngot:\n%v", filename, want, got) } } diff --git a/internal/lsp/testdata/imports/add_import.go.golden b/internal/lsp/testdata/imports/add_import.go.golden new file mode 100644 index 0000000000..16af110a07 --- /dev/null +++ b/internal/lsp/testdata/imports/add_import.go.golden @@ -0,0 +1,13 @@ +-- goimports -- +package imports //@import("package") + +import ( + "bytes" + "fmt" +) + +func _() { + fmt.Println("") + bytes.NewBuffer(nil) +} + diff --git a/internal/lsp/testdata/imports/add_import.go.in b/internal/lsp/testdata/imports/add_import.go.in new file mode 100644 index 0000000000..7928e6f710 --- /dev/null +++ b/internal/lsp/testdata/imports/add_import.go.in @@ -0,0 +1,10 @@ +package imports //@import("package") + +import ( + "fmt" +) + +func _() { + fmt.Println("") + bytes.NewBuffer(nil) +} diff --git a/internal/lsp/testdata/imports/good_imports.go.golden b/internal/lsp/testdata/imports/good_imports.go.golden index d37a6c7511..2abdae4d72 100644 --- a/internal/lsp/testdata/imports/good_imports.go.golden +++ b/internal/lsp/testdata/imports/good_imports.go.golden @@ -4,6 +4,6 @@ package imports //@import("package") import "fmt" func _() { - fmt.Println("") +fmt.Println("") } diff --git a/internal/lsp/testdata/imports/good_imports.go.in b/internal/lsp/testdata/imports/good_imports.go.in new file mode 100644 index 0000000000..a03c06c6d9 --- /dev/null +++ b/internal/lsp/testdata/imports/good_imports.go.in @@ -0,0 +1,7 @@ +package imports //@import("package") + +import "fmt" + +func _() { +fmt.Println("") +} diff --git a/internal/lsp/testdata/imports/multiple_blocks.go.golden b/internal/lsp/testdata/imports/multiple_blocks.go.golden new file mode 100644 index 0000000000..d37a6c7511 --- /dev/null +++ b/internal/lsp/testdata/imports/multiple_blocks.go.golden @@ -0,0 +1,9 @@ +-- goimports -- +package imports //@import("package") + +import "fmt" + +func _() { + fmt.Println("") +} + diff --git a/internal/lsp/testdata/imports/good_imports.go b/internal/lsp/testdata/imports/multiple_blocks.go.in similarity index 83% rename from internal/lsp/testdata/imports/good_imports.go rename to internal/lsp/testdata/imports/multiple_blocks.go.in index 40283fa15d..3f2fb99ea2 100644 --- a/internal/lsp/testdata/imports/good_imports.go +++ b/internal/lsp/testdata/imports/multiple_blocks.go.in @@ -2,6 +2,8 @@ package imports //@import("package") import "fmt" +import "bytes" + func _() { fmt.Println("") } diff --git a/internal/lsp/testdata/imports/needs_imports.go b/internal/lsp/testdata/imports/needs_imports.go.in similarity index 100% rename from internal/lsp/testdata/imports/needs_imports.go rename to internal/lsp/testdata/imports/needs_imports.go.in diff --git a/internal/lsp/testdata/imports/remove_import.go.golden b/internal/lsp/testdata/imports/remove_import.go.golden new file mode 100644 index 0000000000..3df80882ca --- /dev/null +++ b/internal/lsp/testdata/imports/remove_import.go.golden @@ -0,0 +1,11 @@ +-- goimports -- +package imports //@import("package") + +import ( + "fmt" +) + +func _() { + fmt.Println("") +} + diff --git a/internal/lsp/testdata/imports/remove_import.go.in b/internal/lsp/testdata/imports/remove_import.go.in new file mode 100644 index 0000000000..09060bada4 --- /dev/null +++ b/internal/lsp/testdata/imports/remove_import.go.in @@ -0,0 +1,10 @@ +package imports //@import("package") + +import ( + "bytes" + "fmt" +) + +func _() { + fmt.Println("") +} diff --git a/internal/lsp/testdata/imports/remove_imports.go.golden b/internal/lsp/testdata/imports/remove_imports.go.golden new file mode 100644 index 0000000000..530c8c09fe --- /dev/null +++ b/internal/lsp/testdata/imports/remove_imports.go.golden @@ -0,0 +1,6 @@ +-- goimports -- +package imports //@import("package") + +func _() { +} + diff --git a/internal/lsp/testdata/imports/remove_imports.go.in b/internal/lsp/testdata/imports/remove_imports.go.in new file mode 100644 index 0000000000..44d065f258 --- /dev/null +++ b/internal/lsp/testdata/imports/remove_imports.go.in @@ -0,0 +1,9 @@ +package imports //@import("package") + +import ( + "bytes" + "fmt" +) + +func _() { +} diff --git a/internal/lsp/testdata/summary.txt.golden b/internal/lsp/testdata/summary.txt.golden index 0dd57347ff..4acd1f3adc 100644 --- a/internal/lsp/testdata/summary.txt.golden +++ b/internal/lsp/testdata/summary.txt.golden @@ -9,7 +9,7 @@ CaseSensitiveCompletionsCount = 4 DiagnosticsCount = 22 FoldingRangesCount = 2 FormatCount = 6 -ImportCount = 2 +ImportCount = 6 SuggestedFixCount = 1 DefinitionsCount = 38 TypeDefinitionsCount = 2