From ae534bcb6ccdd13487d0491c2194d10ebcd30ff3 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 16 Dec 2013 14:43:29 -0800 Subject: [PATCH] astutil: fix a comment corruption case This fixes a case where adding an import when there are is no existing import declaration can corrupt the position of comments attached to types. This was the last known goimports/astutil corruption case. See golang.org/issue/6884 for more details. Unfortunately this requires changing the API to add a *token.FileSet, which we should've had before. I will update goimports (the only user of this API?) immediately after submitting this. This CL also contains a hack (used only in this case of no imports): rather than fix the comment positions by hand (something that only Robert might know how to do), it instead just prints the AST, manipulates the source, and re-parses the AST. We can fix up later. Fixes golang/go#6884 R=golang-dev, gri CC=golang-dev https://golang.org/cl/38270043 --- astutil/imports.go | 56 +++++++++++++++++++++++++++++++++++++---- astutil/imports_test.go | 15 ++++++----- 2 files changed, 58 insertions(+), 13 deletions(-) diff --git a/astutil/imports.go b/astutil/imports.go index 1675aa62cc..37451ec511 100644 --- a/astutil/imports.go +++ b/astutil/imports.go @@ -6,16 +6,22 @@ package astutil import ( + "bufio" + "bytes" + "fmt" "go/ast" + "go/format" + "go/parser" "go/token" + "log" "path" "strconv" "strings" ) // AddImport adds the import path to the file f, if absent. -func AddImport(f *ast.File, ipath string) (added bool) { - return AddNamedImport(f, "", ipath) +func AddImport(fset *token.FileSet, f *ast.File, ipath string) (added bool) { + return AddNamedImport(fset, f, "", ipath) } // AddNamedImport adds the import path to the file f, if absent. @@ -25,7 +31,7 @@ func AddImport(f *ast.File, ipath string) (added bool) { // AddNamedImport(f, "pathpkg", "path") // adds // import pathpkg "path" -func AddNamedImport(f *ast.File, name, ipath string) (added bool) { +func AddNamedImport(fset *token.FileSet, f *ast.File, name, ipath string) (added bool) { if imports(f, ipath) { return false } @@ -46,10 +52,12 @@ func AddNamedImport(f *ast.File, name, ipath string) (added bool) { lastImport = -1 impDecl *ast.GenDecl impIndex = -1 + hasImports = false ) for i, decl := range f.Decls { gen, ok := decl.(*ast.GenDecl) if ok && gen.Tok == token.IMPORT { + hasImports = true lastImport = i // Do not add to import "C", to avoid disrupting the // association with its doc comment, breaking cgo. @@ -72,6 +80,18 @@ func AddNamedImport(f *ast.File, name, ipath string) (added bool) { // If no import decl found, add one after the last import. if impDecl == nil { + // TODO(bradfitz): remove this hack. See comment below on + // addImportViaSourceModification. + if !hasImports { + f2, err := addImportViaSourceModification(fset, f, name, ipath) + if err == nil { + *f = *f2 + return true + } + log.Printf("addImportViaSourceModification error: %v", err) + } + + // TODO(bradfitz): fix above and resume using this old code: impDecl = &ast.GenDecl{ Tok: token.IMPORT, } @@ -110,7 +130,7 @@ func AddNamedImport(f *ast.File, name, ipath string) (added bool) { } // DeleteImport deletes the import path from the file f, if present. -func DeleteImport(f *ast.File, path string) (deleted bool) { +func DeleteImport(fset *token.FileSet, f *ast.File, path string) (deleted bool) { oldImport := importSpec(f, path) // Find the import node that imports path, if any. @@ -163,7 +183,7 @@ func DeleteImport(f *ast.File, path string) (deleted bool) { } // RewriteImport rewrites any import of path oldPath to path newPath. -func RewriteImport(f *ast.File, oldPath, newPath string) (rewrote bool) { +func RewriteImport(fset *token.FileSet, f *ast.File, oldPath, newPath string) (rewrote bool) { for _, imp := range f.Imports { if importPath(imp) == oldPath { rewrote = true @@ -371,3 +391,29 @@ func Imports(fset *token.FileSet, f *ast.File) [][]*ast.ImportSpec { return groups } + +// NOTE(bradfitz): this is a bit of a hack for golang.org/issue/6884 +// because we can't get the comment positions correct. Instead of modifying +// the AST, we print it, modify the text, and re-parse it. Gross. +func addImportViaSourceModification(fset *token.FileSet, f *ast.File, name, ipath string) (*ast.File, error) { + var buf bytes.Buffer + if err := format.Node(&buf, fset, f); err != nil { + return nil, fmt.Errorf("Error formatting ast.File node: %v", err) + } + var out bytes.Buffer + sc := bufio.NewScanner(bytes.NewReader(buf.Bytes())) + didAdd := false + for sc.Scan() { + ln := sc.Text() + out.WriteString(ln) + out.WriteByte('\n') + if !didAdd && strings.HasPrefix(ln, "package ") { + fmt.Fprintf(&out, "\nimport %s %q\n\n", name, ipath) + didAdd = true + } + } + if err := sc.Err(); err != nil { + return nil, err + } + return parser.ParseFile(fset, "", out.Bytes(), parser.ParseComments) +} diff --git a/astutil/imports_test.go b/astutil/imports_test.go index 279b147c30..55e2edcf4f 100644 --- a/astutil/imports_test.go +++ b/astutil/imports_test.go @@ -176,9 +176,8 @@ import ( `, }, { - broken: true, - name: "struct comment", - pkg: "time", + name: "struct comment", + pkg: "time", in: `package main // This is a comment before a struct. @@ -203,7 +202,7 @@ func TestAddImport(t *testing.T) { file := parse(t, test.name, test.in) var before bytes.Buffer ast.Fprint(&before, fset, file, nil) - AddNamedImport(file, test.renamedPkg, test.pkg) + AddNamedImport(fset, file, test.renamedPkg, test.pkg) if got := print(t, test.name, file); got != test.out { if test.broken { t.Logf("%s is known broken:\ngot: %s\nwant: %s", test.name, got, test.out) @@ -220,8 +219,8 @@ func TestAddImport(t *testing.T) { func TestDoubleAddImport(t *testing.T) { file := parse(t, "doubleimport", "package main\n") - AddImport(file, "os") - AddImport(file, "bytes") + AddImport(fset, file, "os") + AddImport(fset, file, "bytes") want := `package main import ( @@ -416,7 +415,7 @@ import ( func TestDeleteImport(t *testing.T) { for _, test := range deleteTests { file := parse(t, test.name, test.in) - DeleteImport(file, test.pkg) + DeleteImport(fset, file, test.pkg) if got := print(t, test.name, file); got != test.out { t.Errorf("%s:\ngot: %s\nwant: %s", test.name, got, test.out) } @@ -545,7 +544,7 @@ var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18 func TestRewriteImport(t *testing.T) { for _, test := range rewriteTests { file := parse(t, test.name, test.in) - RewriteImport(file, test.srcPkg, test.dstPkg) + RewriteImport(fset, file, test.srcPkg, test.dstPkg) if got := print(t, test.name, file); got != test.out { t.Errorf("%s:\ngot: %s\nwant: %s", test.name, got, test.out) }