diff --git a/cmd/goimports/goimports.go b/cmd/goimports/goimports.go index d7857d22f4..ed779f613e 100644 --- a/cmd/goimports/goimports.go +++ b/cmd/goimports/goimports.go @@ -11,6 +11,7 @@ import ( "go/scanner" "io" "io/ioutil" + "log" "os" "os/exec" "path/filepath" @@ -22,10 +23,11 @@ import ( var ( // main operation modes - list = flag.Bool("l", false, "list files whose formatting differs from goimport's") - write = flag.Bool("w", false, "write result to (source) file instead of stdout") - doDiff = flag.Bool("d", false, "display diffs instead of rewriting files") - srcdir = flag.String("srcdir", "", "choose imports as if source code is from `dir`") + list = flag.Bool("l", false, "list files whose formatting differs from goimport's") + write = flag.Bool("w", false, "write result to (source) file instead of stdout") + doDiff = flag.Bool("d", false, "display diffs instead of rewriting files") + srcdir = flag.String("srcdir", "", "choose imports as if source code is from `dir`") + verbose = flag.Bool("v", false, "verbose logging") options = &imports.Options{ TabWidth: 8, @@ -154,6 +156,10 @@ func gofmtMain() { flag.Usage = usage paths := parseFlags() + if *verbose { + log.SetFlags(log.LstdFlags | log.Lmicroseconds) + imports.Debug = true + } if options.TabWidth < 0 { fmt.Fprintf(os.Stderr, "negative tabwidth %d\n", options.TabWidth) exitCode = 2 diff --git a/imports/fix.go b/imports/fix.go index 5e260ccc8f..05bc8dc193 100644 --- a/imports/fix.go +++ b/imports/fix.go @@ -10,15 +10,21 @@ import ( "go/build" "go/parser" "go/token" + "io/ioutil" + "log" "os" "path" "path/filepath" + "sort" "strings" "sync" "golang.org/x/tools/go/ast/astutil" ) +// Debug controls verbose logging. +var Debug = false + // importToGroup is a list of functions which map from an import path to // a group number. var importToGroup = []func(importPath string) (num int, ok bool){ @@ -58,7 +64,10 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []stri if err != nil { return nil, err } - srcDir := path.Dir(abs) + srcDir := filepath.Dir(abs) + if Debug { + log.Printf("fixImports(filename=%q), abs=%q, srcDir=%q ...", filename, abs, srcDir) + } // collect potential uses of packages. var visitor visitFn @@ -70,10 +79,14 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []stri case *ast.ImportSpec: if v.Name != nil { decls[v.Name.Name] = v - } else { - local := importPathToName(strings.Trim(v.Path.Value, `\"`), srcDir) - decls[local] = v + break } + ipath := strings.Trim(v.Path.Value, `\"`) + if ipath == "C" { + break + } + local := importPathToName(ipath, srcDir) + decls[local] = v case *ast.SelectorExpr: xident, ok := v.X.(*ast.Ident) if !ok { @@ -114,18 +127,22 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []stri astutil.DeleteNamedImport(fset, f, name, ipath) } + for pkgName, symbols := range refs { + if len(symbols) == 0 { + // skip over packages already imported + delete(refs, pkgName) + } + } + // Search for imports matching potential package references. searches := 0 type result struct { - ipath string - name string + ipath string // import path (if err == nil) + name string // optional name to rename import as err error } results := make(chan result) for pkgName, symbols := range refs { - if len(symbols) == 0 { - continue // skip over packages already imported - } go func(pkgName string, symbols map[string]bool) { ipath, rename, err := findImport(pkgName, symbols, filename) r := result{ipath: ipath, err: err} @@ -155,7 +172,7 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []stri } // importPathToName returns the package name for the given import path. -var importPathToName = importPathToNameGoPath +var importPathToName func(importPath, srcDir string) (packageName string) = importPathToNameGoPath // importPathToNameBasic assumes the package name is the base of import path. func importPathToNameBasic(importPath, srcDir string) (packageName string) { @@ -165,24 +182,76 @@ func importPathToNameBasic(importPath, srcDir string) (packageName string) { // importPathToNameGoPath finds out the actual package name, as declared in its .go files. // If there's a problem, it falls back to using importPathToNameBasic. func importPathToNameGoPath(importPath, srcDir string) (packageName string) { + // Fast path for standard library without going to disk: + if pkg, ok := stdImportPackage[importPath]; ok { + return pkg + } + + // TODO(bradfitz): build.Import does too much work, and + // doesn't cache either in-process or long-lived (anything + // found in the first pass from explicit imports aren't used + // again when scanning all directories). Also, it opens+reads + // *_test.go files too. As a baby step, use a cheaper + // mechanism to start (build.FindOnly), and then just read a + // single file (like parser.ParseFile in loadExportsGoPath) and + // skip only "documentation" but otherwise trust the first matching + // file's package name. + if Debug { + log.Printf("build.Import(ip=%q, srcDir=%q) ...", importPath, srcDir) + } if buildPkg, err := build.Import(importPath, srcDir, 0); err == nil { + if Debug { + log.Printf("build.Import(%q, srcDir=%q) = %q", importPath, srcDir, buildPkg.Name) + } return buildPkg.Name } else { + if Debug { + log.Printf("build.Import(%q, srcDir=%q) error: %v", importPath, srcDir, err) + } return importPathToNameBasic(importPath, srcDir) } } +var stdImportPackage = map[string]string{} // "net/http" => "http" + +func init() { + // Nothing in the standard library has a package name not + // matching its import base name. + for _, pkg := range stdlib { + if _, ok := stdImportPackage[pkg]; !ok { + stdImportPackage[pkg] = path.Base(pkg) + } + } +} + +// Directory-scanning state. +var ( + // scanGoRootOnce guards calling scanGoRoot (for $GOROOT) + scanGoRootOnce = &sync.Once{} + // scanGoPathOnce guards calling scanGoPath (for $GOPATH) + scanGoPathOnce = &sync.Once{} + + dirScanMu sync.RWMutex + dirScan map[string]*pkg // abs dir path => *pkg +) + type pkg struct { - importpath string // full pkg import path, e.g. "net/http" - dir string // absolute file path to pkg directory e.g. "/usr/lib/go/src/fmt" + dir string // absolute file path to pkg directory ("/usr/lib/go/src/net/http") + importPath string // full pkg import path ("net/http", "foo/bar/vendor/a/b") + importPathShort string // vendorless import path ("net/http", "a/b") } -var pkgIndexOnce = &sync.Once{} +// byImportPathShortLength sorts by the short import path length, breaking ties on the +// import string itself. +type byImportPathShortLength []*pkg + +func (s byImportPathShortLength) Len() int { return len(s) } +func (s byImportPathShortLength) Less(i, j int) bool { + vi, vj := s[i].importPathShort, s[j].importPathShort + return len(vi) < len(vj) || (len(vi) == len(vj) && vi < vj) -var pkgIndex struct { - sync.Mutex - m map[string][]pkg // shortname => []pkg, e.g "http" => "net/http" } +func (s byImportPathShortLength) Swap(i, j int) { s[i], s[j] = s[j], s[i] } // gate is a semaphore for limiting concurrency. type gate chan struct{} @@ -212,12 +281,14 @@ func shouldTraverse(dir string, fi os.FileInfo) bool { path := filepath.Join(dir, fi.Name()) target, err := filepath.EvalSymlinks(path) if err != nil { - fmt.Fprint(os.Stderr, err) + if !os.IsNotExist(err) { + fmt.Fprintln(os.Stderr, err) + } return false } ts, err := os.Stat(target) if err != nil { - fmt.Fprint(os.Stderr, err) + fmt.Fprintln(os.Stderr, err) return false } if !ts.IsDir() { @@ -242,46 +313,73 @@ func shouldTraverse(dir string, fi os.FileInfo) bool { return true } -func loadPkgIndex() { - pkgIndex.Lock() - pkgIndex.m = make(map[string][]pkg) - pkgIndex.Unlock() +var testHookScanDir = func(dir string) {} + +func scanGoRoot() { scanGoDirs(true) } +func scanGoPath() { scanGoDirs(false) } + +func scanGoDirs(goRoot bool) { + if Debug { + which := "$GOROOT" + if !goRoot { + which = "$GOPATH" + } + log.Printf("scanning " + which) + defer log.Printf("scanned " + which) + } + dirScanMu.Lock() + if dirScan == nil { + dirScan = make(map[string]*pkg) + } + dirScanMu.Unlock() var wg sync.WaitGroup for _, path := range build.Default.SrcDirs() { + isGoroot := path == filepath.Join(build.Default.GOROOT, "src") + if isGoroot != goRoot { + continue + } fsgate.enter() + testHookScanDir(path) + if Debug { + log.Printf("scanGoDir, open dir: %v\n", path) + } f, err := os.Open(path) if err != nil { fsgate.leave() - fmt.Fprint(os.Stderr, err) + fmt.Fprintf(os.Stderr, "goimports: scanning directories: %v\n", err) continue } children, err := f.Readdir(-1) f.Close() fsgate.leave() if err != nil { - fmt.Fprint(os.Stderr, err) + fmt.Fprintf(os.Stderr, "goimports: scanning directory entries: %v\n", err) continue } for _, child := range children { - if shouldTraverse(path, child) { - wg.Add(1) - go func(path, name string) { - defer wg.Done() - loadPkg(&wg, path, name) - }(path, child.Name()) + if !shouldTraverse(path, child) { + continue } + wg.Add(1) + go func(path, name string) { + defer wg.Done() + scanDir(&wg, path, name) + }(path, child.Name()) } } wg.Wait() } -func loadPkg(wg *sync.WaitGroup, root, pkgrelpath string) { +func scanDir(wg *sync.WaitGroup, root, pkgrelpath string) { importpath := filepath.ToSlash(pkgrelpath) dir := filepath.Join(root, importpath) fsgate.enter() defer fsgate.leave() + if Debug { + log.Printf("scanning dir %s", dir) + } pkgDir, err := os.Open(dir) if err != nil { return @@ -309,60 +407,132 @@ func loadPkg(wg *sync.WaitGroup, root, pkgrelpath string) { wg.Add(1) go func(root, name string) { defer wg.Done() - loadPkg(wg, root, name) + scanDir(wg, root, name) }(root, filepath.Join(importpath, name)) } } if hasGo { - shortName := importPathToName(importpath, "") - pkgIndex.Lock() - pkgIndex.m[shortName] = append(pkgIndex.m[shortName], pkg{ - importpath: importpath, - dir: dir, - }) - pkgIndex.Unlock() + dirScanMu.Lock() + dirScan[dir] = &pkg{ + importPath: importpath, + importPathShort: vendorlessImportPath(importpath), + dir: dir, + } + dirScanMu.Unlock() } - } -// loadExports returns a list exports for a package. -var loadExports = loadExportsGoPath +// vendorlessImportPath returns the devendorized version of the provided import path. +// e.g. "foo/bar/vendor/a/b" => "a/b" +func vendorlessImportPath(ipath string) string { + // Devendorize for use in import statement. + if i := strings.LastIndex(ipath, "/vendor/"); i >= 0 { + return ipath[i+len("/vendor/"):] + } + if strings.HasPrefix(ipath, "vendor/") { + return ipath[len("vendor/"):] + } + return ipath +} -func loadExportsGoPath(dir string) map[string]bool { +// loadExports returns the set of exported symbols in the package at dir. +// It returns nil on error or if the package name in dir does not match expectPackage. +var loadExports func(expectPackage, dir string) map[string]bool = loadExportsGoPath + +func loadExportsGoPath(expectPackage, dir string) map[string]bool { + if Debug { + log.Printf("loading exports in dir %s (seeking package %s)", dir, expectPackage) + } exports := make(map[string]bool) - buildPkg, err := build.ImportDir(dir, 0) - if err != nil { - if strings.Contains(err.Error(), "no buildable Go source files in") { - return nil + + ctx := build.Default + + // ReadDir is like ioutil.ReadDir, but only returns *.go files + // and filters out _test.go files since they're not relevant + // and only slow things down. + ctx.ReadDir = func(dir string) (notTests []os.FileInfo, err error) { + all, err := ioutil.ReadDir(dir) + if err != nil { + return nil, err } - fmt.Fprintf(os.Stderr, "could not import %q: %v\n", dir, err) + notTests = all[:0] + for _, fi := range all { + name := fi.Name() + if strings.HasSuffix(name, ".go") && !strings.HasSuffix(name, "_test.go") { + notTests = append(notTests, fi) + } + } + return notTests, nil + } + + files, err := ctx.ReadDir(dir) + if err != nil { + log.Print(err) return nil } + fset := token.NewFileSet() - for _, files := range [...][]string{buildPkg.GoFiles, buildPkg.CgoFiles} { - for _, file := range files { - f, err := parser.ParseFile(fset, filepath.Join(dir, file), nil, 0) - if err != nil { - fmt.Fprintf(os.Stderr, "could not parse %q: %v\n", file, err) - continue + + for _, fi := range files { + match, err := ctx.MatchFile(dir, fi.Name()) + if err != nil || !match { + continue + } + fullFile := filepath.Join(dir, fi.Name()) + f, err := parser.ParseFile(fset, fullFile, nil, 0) + if err != nil { + if Debug { + log.Printf("Parsing %s: %v", fullFile, err) } - for name := range f.Scope.Objects { - if ast.IsExported(name) { - exports[name] = true - } + return nil + } + pkgName := f.Name.Name + if pkgName == "documentation" { + // Special case from go/build.ImportDir, not + // handled by ctx.MatchFile. + continue + } + if pkgName != expectPackage { + if Debug { + log.Printf("scan of dir %v is not expected package %v (actually %v)", dir, expectPackage, pkgName) + } + return nil + } + for name := range f.Scope.Objects { + if ast.IsExported(name) { + exports[name] = true } } } + + if Debug { + exportList := make([]string, 0, len(exports)) + for k := range exports { + exportList = append(exportList, k) + } + sort.Strings(exportList) + log.Printf("scanned dir %v (package %v): exports = %v", dir, expectPackage, strings.Join(exportList, ", ")) + } return exports } // findImport searches for a package with the given symbols. -// If no package is found, findImport returns "". -// Declared as a variable rather than a function so goimports can be easily -// extended by adding a file with an init function. -var findImport = findImportGoPath +// If no package is found, findImport returns ("", false, nil) +// +// This is declared as a variable rather than a function so goimports +// can be easily extended by adding a file with an init function. +// +// The rename value tells goimports whether to use the package name as +// a local qualifier in an import. For example, if findImports("pkg", +// "X") returns ("foo/bar", rename=true), then goimports adds the +// import line: +// import pkg "foo/bar" +// to satisfy uses of pkg.X in the file. +var findImport func(pkgName string, symbols map[string]bool, filename string) (foundPkg string, rename bool, err error) = findImportGoPath -func findImportGoPath(pkgName string, symbols map[string]bool, filename string) (string, bool, error) { +// findImportGoPath is the normal implementation of findImport. (Some +// companies have their own internally) +func findImportGoPath(pkgName string, symbols map[string]bool, filename string) (foundPkg string, rename bool, err error) { // Fast path for the standard library. // In the common case we hopefully never have to scan the GOPATH, which can // be slow with moving disks. @@ -385,57 +555,103 @@ func findImportGoPath(pkgName string, symbols map[string]bool, filename string) // in the current Go file. Return rename=true when the other Go files // use a renamed package that's also used in the current file. - pkgIndexOnce.Do(loadPkgIndex) + scanGoRootOnce.Do(scanGoRoot) + if !strings.HasPrefix(filename, build.Default.GOROOT) { + scanGoPathOnce.Do(scanGoPath) + } - // Collect exports for packages with matching names. - var ( - wg sync.WaitGroup - mu sync.Mutex - shortest string - ) - pkgIndex.Lock() - for _, pkg := range pkgIndex.m[pkgName] { + var candidates []*pkg + + for _, pkg := range dirScan { + if !strings.Contains(lastTwoComponents(pkg.importPathShort), pkgName) { + // Speed optimization to minimize disk I/O: + // the last two components on disk must contain the + // package name somewhere. + // + // This permits mismatch naming like directory + // "go-foo" being package "foo", or "pkg.v3" being "pkg", + // or directory "google.golang.org/api/cloudbilling/v1" + // being package "cloudbilling", but doesn't + // permit a directory "foo" to be package + // "bar", which is strongly discouraged + // anyway. There's no reason goimports needs + // to be slow just to accomodate that. + continue + } if !canUse(filename, pkg.dir) { continue } - wg.Add(1) - go func(importpath, dir string) { - defer wg.Done() - exports := loadExports(dir) - if exports == nil { + candidates = append(candidates, pkg) + } + + sort.Sort(byImportPathShortLength(candidates)) + if Debug { + for i, pkg := range candidates { + log.Printf("%s candidate %d/%d: %v", pkgName, i+1, len(candidates), pkg.importPathShort) + } + } + + // Collect exports for packages with matching names. + + done := make(chan struct{}) // closed when we find the answer + defer close(done) + + rescv := make([]chan *pkg, len(candidates)) + for i := range candidates { + rescv[i] = make(chan *pkg, 1) + } + const maxConcurrentPackageImport = 4 + loadExportsSem := make(chan struct{}, maxConcurrentPackageImport) + + go func() { + for i, pkg := range candidates { + select { + case loadExportsSem <- struct{}{}: + case <-done: return } - // If it doesn't have the right symbols, stop. - for symbol := range symbols { - if !exports[symbol] { - return + pkg := pkg + resc := rescv[i] + go func() { + defer func() { <-loadExportsSem }() + exports := loadExports(pkgName, pkg.dir) + + // If it doesn't have the right + // symbols, send nil to mean no match. + for symbol := range symbols { + if !exports[symbol] { + pkg = nil + break + } } - } - - // Devendorize for use in import statement. - if i := strings.LastIndex(importpath, "/vendor/"); i >= 0 { - importpath = importpath[i+len("/vendor/"):] - } else if strings.HasPrefix(importpath, "vendor/") { - importpath = importpath[len("vendor/"):] - } - - // Save as the answer. - // If there are multiple candidates, the shortest wins, - // to prefer "bytes" over "github.com/foo/bytes". - mu.Lock() - if shortest == "" || len(importpath) < len(shortest) || len(importpath) == len(shortest) && importpath < shortest { - shortest = importpath - } - mu.Unlock() - }(pkg.importpath, pkg.dir) + resc <- pkg + }() + } + }() + for _, resc := range rescv { + pkg := <-resc + if pkg == nil { + continue + } + // If the package name in the source doesn't match the import path's base, + // return true so the rewriter adds a name (import foo "github.com/bar/go-foo") + needsRename := path.Base(pkg.importPath) != pkgName + return pkg.importPathShort, needsRename, nil } - pkgIndex.Unlock() - wg.Wait() - - return shortest, false, nil + return "", false, nil } +// canUse reports whether the package in dir is usable from filename, +// respecting the Go "internal" and "vendor" visibility rules. func canUse(filename, dir string) bool { + // Fast path check, before any allocations. If it doesn't contain vendor + // or internal, it's not tricky: + // Note that this can false-negative on directories like "notinternal", + // but we check it correctly below. This is just a fast path. + if !strings.Contains(dir, "vendor") && !strings.Contains(dir, "internal") { + return true + } + dirSlash := filepath.ToSlash(dir) if !strings.Contains(dirSlash, "/vendor/") && !strings.Contains(dirSlash, "/internal/") && !strings.HasSuffix(dirSlash, "/internal") { return true @@ -461,6 +677,21 @@ func canUse(filename, dir string) bool { return !strings.Contains(relSlash, "/vendor/") && !strings.Contains(relSlash, "/internal/") && !strings.HasSuffix(relSlash, "/internal") } +// lastTwoComponents returns at most the last two path components +// of v, using either / or \ as the path separator. +func lastTwoComponents(v string) string { + nslash := 0 + for i := len(v) - 1; i >= 0; i-- { + if v[i] == '/' || v[i] == '\\' { + nslash++ + if nslash == 2 { + return v[i:] + } + } + } + return v +} + type visitFn func(node ast.Node) ast.Visitor func (fn visitFn) Visit(node ast.Node) ast.Visitor { diff --git a/imports/fix_test.go b/imports/fix_test.go index 3454e22879..3efdf66ffc 100644 --- a/imports/fix_test.go +++ b/imports/fix_test.go @@ -857,22 +857,17 @@ func TestImportSymlinks(t *testing.T) { t.Fatal(err) } - pkgIndexOnce = &sync.Once{} - oldGOPATH := build.Default.GOPATH - build.Default.GOPATH = newGoPath - defer func() { - build.Default.GOPATH = oldGOPATH - visitedSymlinks.m = nil - }() + withEmptyGoPath(func() { + build.Default.GOPATH = newGoPath - input := `package p + input := `package p var ( _ = fmt.Print _ = mypkg.Foo ) ` - output := `package p + output := `package p import ( "fmt" @@ -884,13 +879,14 @@ var ( _ = mypkg.Foo ) ` - buf, err := Process(newGoPath+"/src/myotherpkg/toformat.go", []byte(input), &Options{}) - if err != nil { - t.Fatal(err) - } - if got := string(buf); got != output { - t.Fatalf("results differ\nGOT:\n%s\nWANT:\n%s\n", got, output) - } + buf, err := Process(newGoPath+"/src/myotherpkg/toformat.go", []byte(input), &Options{}) + if err != nil { + t.Fatal(err) + } + if got := string(buf); got != output { + t.Fatalf("results differ\nGOT:\n%s\nWANT:\n%s\n", got, output) + } + }) } // Test for correctly identifying the name of a vendored package when it @@ -902,30 +898,12 @@ func TestFixImportsVendorPackage(t *testing.T) { if _, err := os.Stat(filepath.Join(runtime.GOROOT(), "src/vendor")); err != nil { t.Skip(err) } - - newGoPath, err := ioutil.TempDir("", "vendortest") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(newGoPath) - - vendoredPath := newGoPath + "/src/mypkg.com/outpkg/vendor/mypkg.com/mypkg.v1" - if err := os.MkdirAll(vendoredPath, 0755); err != nil { - t.Fatal(err) - } - - pkgIndexOnce = &sync.Once{} - oldGOPATH := build.Default.GOPATH - build.Default.GOPATH = newGoPath - defer func() { - build.Default.GOPATH = oldGOPATH - }() - - if err := ioutil.WriteFile(vendoredPath+"/f.go", []byte("package mypkg\nvar Foo = 123\n"), 0666); err != nil { - t.Fatal(err) - } - - input := `package p + testConfig{ + gopathFiles: map[string]string{ + "mypkg.com/outpkg/vendor/mypkg.com/mypkg.v1/f.go": "package mypkg\nvar Foo = 123\n", + }, + }.test(t, func(t *goimportTest) { + input := `package p import ( "fmt" @@ -938,13 +916,14 @@ var ( _ = mypkg.Foo ) ` - buf, err := Process(newGoPath+"/src/mypkg.com/outpkg/toformat.go", []byte(input), &Options{}) - if err != nil { - t.Fatal(err) - } - if got := string(buf); got != input { - t.Fatalf("results differ\nGOT:\n%s\nWANT:\n%s\n", got, input) - } + buf, err := Process(filepath.Join(t.gopath, "src/mypkg.com/outpkg/toformat.go"), []byte(input), &Options{}) + if err != nil { + t.Fatal(err) + } + if got := string(buf); got != input { + t.Fatalf("results differ\nGOT:\n%s\nWANT:\n%s\n", got, input) + } + }) } func TestFindImportGoPath(t *testing.T) { @@ -954,66 +933,69 @@ func TestFindImportGoPath(t *testing.T) { } defer os.RemoveAll(goroot) - pkgIndexOnce = &sync.Once{} - origStdlib := stdlib defer func() { stdlib = origStdlib }() stdlib = nil - // Test against imaginary bits/bytes package in std lib - bytesDir := filepath.Join(goroot, "src", "pkg", "bits", "bytes") - for _, tag := range build.Default.ReleaseTags { - // Go 1.4 rearranged the GOROOT tree to remove the "pkg" path component. - if tag == "go1.4" { - bytesDir = filepath.Join(goroot, "src", "bits", "bytes") + withEmptyGoPath(func() { + // Test against imaginary bits/bytes package in std lib + bytesDir := filepath.Join(goroot, "src", "pkg", "bits", "bytes") + for _, tag := range build.Default.ReleaseTags { + // Go 1.4 rearranged the GOROOT tree to remove the "pkg" path component. + if tag == "go1.4" { + bytesDir = filepath.Join(goroot, "src", "bits", "bytes") + } } - } - if err := os.MkdirAll(bytesDir, 0755); err != nil { - t.Fatal(err) - } - bytesSrcPath := filepath.Join(bytesDir, "bytes.go") - bytesPkgPath := "bits/bytes" - bytesSrc := []byte(`package bytes + if err := os.MkdirAll(bytesDir, 0755); err != nil { + t.Fatal(err) + } + bytesSrcPath := filepath.Join(bytesDir, "bytes.go") + bytesPkgPath := "bits/bytes" + bytesSrc := []byte(`package bytes type Buffer2 struct {} `) - if err := ioutil.WriteFile(bytesSrcPath, bytesSrc, 0775); err != nil { - t.Fatal(err) - } - oldGOROOT := build.Default.GOROOT - oldGOPATH := build.Default.GOPATH - build.Default.GOROOT = goroot - build.Default.GOPATH = "" - defer func() { - build.Default.GOROOT = oldGOROOT - build.Default.GOPATH = oldGOPATH - }() + if err := ioutil.WriteFile(bytesSrcPath, bytesSrc, 0775); err != nil { + t.Fatal(err) + } + build.Default.GOROOT = goroot - got, rename, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true}, "x.go") - if err != nil { - t.Fatal(err) - } - if got != bytesPkgPath || rename { - t.Errorf(`findImportGoPath("bytes", Buffer2 ...)=%q, %t, want "%s", false`, got, rename, bytesPkgPath) - } + got, rename, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true}, "x.go") + if err != nil { + t.Fatal(err) + } + if got != bytesPkgPath || rename { + t.Errorf(`findImportGoPath("bytes", Buffer2 ...)=%q, %t, want "%s", false`, got, rename, bytesPkgPath) + } - got, rename, err = findImportGoPath("bytes", map[string]bool{"Missing": true}, "x.go") - if err != nil { - t.Fatal(err) - } - if got != "" || rename { - t.Errorf(`findImportGoPath("bytes", Missing ...)=%q, %t, want "", false`, got, rename) - } + got, rename, err = findImportGoPath("bytes", map[string]bool{"Missing": true}, "x.go") + if err != nil { + t.Fatal(err) + } + if got != "" || rename { + t.Errorf(`findImportGoPath("bytes", Missing ...)=%q, %t, want "", false`, got, rename) + } + }) } func withEmptyGoPath(fn func()) { - pkgIndexOnce = &sync.Once{} + dirScanMu.Lock() + scanGoRootOnce = &sync.Once{} + scanGoPathOnce = &sync.Once{} + dirScan = nil + dirScanMu.Unlock() + oldGOPATH := build.Default.GOPATH + oldGOROOT := build.Default.GOROOT build.Default.GOPATH = "" + visitedSymlinks.m = nil + testHookScanDir = func(string) {} defer func() { + testHookScanDir = func(string) {} build.Default.GOPATH = oldGOPATH + build.Default.GOROOT = oldGOROOT }() fn() } @@ -1033,7 +1015,7 @@ func TestFindImportInternal(t *testing.T) { t.Fatal(err) } if got != "internal/race" || rename { - t.Errorf(`findImportGoPath("race", Acquire ...)=%q, %t, want "internal/race", false`, got, rename) + t.Errorf(`findImportGoPath("race", Acquire ...) = %q, %t; want "internal/race", false`, got, rename) } // should not be able to use internal from outside that tree @@ -1090,65 +1072,45 @@ func TestFindImportRandRead(t *testing.T) { } func TestFindImportVendor(t *testing.T) { - pkgIndexOnce = &sync.Once{} - oldGOPATH := build.Default.GOPATH - build.Default.GOPATH = "" - defer func() { - build.Default.GOPATH = oldGOPATH - }() - - _, err := os.Stat(filepath.Join(runtime.GOROOT(), "src/vendor")) - if err != nil { - t.Skip(err) - } - - got, rename, err := findImportGoPath("hpack", map[string]bool{"HuffmanDecode": true}, filepath.Join(runtime.GOROOT(), "src/math/x.go")) - if err != nil { - t.Fatal(err) - } - want := "golang.org/x/net/http2/hpack" - // Pre-1.7, we temporarily had this package under "internal" - adjust want accordingly. - _, err = os.Stat(filepath.Join(runtime.GOROOT(), "src/vendor", want)) - if err != nil { - want = filepath.Join("internal", want) - } - if got != want || rename { - t.Errorf(`findImportGoPath("hpack", HuffmanDecode ...)=%q, %t, want %q, false`, got, rename, want) - } - - // should not be able to use vendor from outside that tree - got, rename, err = findImportGoPath("hpack", map[string]bool{"HuffmanDecode": true}, filepath.Join(runtime.GOROOT(), "x.go")) - if err != nil { - t.Fatal(err) - } - if got != "" || rename { - t.Errorf(`findImportGoPath("hpack", HuffmanDecode ...)=%q, %t, want "", false`, got, rename) - } + testConfig{ + gorootFiles: map[string]string{ + "vendor/golang.org/x/net/http2/hpack/huffman.go": "package hpack\nfunc HuffmanDecode() { }\n", + }, + }.test(t, func(t *goimportTest) { + got, rename, err := findImportGoPath("hpack", map[string]bool{"HuffmanDecode": true}, filepath.Join(t.goroot, "src/math/x.go")) + if err != nil { + t.Fatal(err) + } + want := "golang.org/x/net/http2/hpack" + if got != want || rename { + t.Errorf(`findImportGoPath("hpack", HuffmanDecode ...) = %q, %t; want %q, false`, got, rename, want) + } + }) } func TestProcessVendor(t *testing.T) { - pkgIndexOnce = &sync.Once{} - oldGOPATH := build.Default.GOPATH - build.Default.GOPATH = "" - defer func() { - build.Default.GOPATH = oldGOPATH - }() + withEmptyGoPath(func() { + _, err := os.Stat(filepath.Join(runtime.GOROOT(), "src/vendor")) + if err != nil { + t.Skip(err) + } - _, err := os.Stat(filepath.Join(runtime.GOROOT(), "src/vendor")) - if err != nil { - t.Skip(err) - } + target := filepath.Join(runtime.GOROOT(), "src/math/x.go") + out, err := Process(target, []byte("package http\nimport \"bytes\"\nfunc f() { strings.NewReader(); hpack.HuffmanDecode() }\n"), nil) - target := filepath.Join(runtime.GOROOT(), "src/math/x.go") - out, err := Process(target, []byte("package http\nimport \"bytes\"\nfunc f() { strings.NewReader(); hpack.HuffmanDecode() }\n"), nil) + if err != nil { + t.Fatal(err) + } - if err != nil { - t.Fatal(err) - } - want := "golang.org/x/net/http2/hpack" - if !bytes.Contains(out, []byte(want)) { - t.Fatalf("Process(%q) did not add expected hpack import:\n%s", target, out) - } + want := "golang_org/x/net/http2/hpack" + if _, err := os.Stat(filepath.Join(runtime.GOROOT(), "src/vendor", want)); os.IsNotExist(err) { + want = "golang.org/x/net/http2/hpack" + } + + if !bytes.Contains(out, []byte(want)) { + t.Fatalf("Process(%q) did not add expected hpack import %q; got:\n%s", target, want, out) + } + }) } func TestFindImportStdlib(t *testing.T) { @@ -1174,6 +1136,150 @@ func TestFindImportStdlib(t *testing.T) { } } +type testConfig struct { + // gorootFiles optionally specifies the complete contents of GOROOT to use, + // If nil, the normal current $GOROOT is used. + gorootFiles map[string]string // paths relative to $GOROOT/src to contents + + // gopathFiles is like gorootFiles, but for $GOPATH. + // If nil, there is no GOPATH, though. + gopathFiles map[string]string // paths relative to $GOPATH/src to contents +} + +func mustTempDir(t *testing.T, prefix string) string { + dir, err := ioutil.TempDir("", prefix) + if err != nil { + t.Fatal(err) + } + return dir +} + +func mapToDir(destDir string, files map[string]string) error { + for path, contents := range files { + file := filepath.Join(destDir, "src", path) + if err := os.MkdirAll(filepath.Dir(file), 0755); err != nil { + return err + } + if err := ioutil.WriteFile(file, []byte(contents), 0644); err != nil { + return err + } + } + return nil +} + +func (c testConfig) test(t *testing.T, fn func(*goimportTest)) { + var goroot string + var gopath string + + if c.gorootFiles != nil { + goroot = mustTempDir(t, "goroot-") + defer os.RemoveAll(goroot) + if err := mapToDir(goroot, c.gorootFiles); err != nil { + t.Fatal(err) + } + } + if c.gopathFiles != nil { + gopath = mustTempDir(t, "gopath-") + defer os.RemoveAll(gopath) + if err := mapToDir(gopath, c.gopathFiles); err != nil { + t.Fatal(err) + } + } + + withEmptyGoPath(func() { + if goroot != "" { + build.Default.GOROOT = goroot + } + build.Default.GOPATH = gopath + + it := &goimportTest{ + T: t, + goroot: build.Default.GOROOT, + gopath: gopath, + ctx: &build.Default, + } + fn(it) + }) +} + +type goimportTest struct { + *testing.T + ctx *build.Context + goroot string + gopath string +} + +// Tests that added imports are renamed when the import path's base doesn't +// match its package name. For example, we want to generate: +// +// import cloudbilling "google.golang.org/api/cloudbilling/v1" +func TestRenameWhenPackageNameMismatch(t *testing.T) { + testConfig{ + gopathFiles: map[string]string{ + "foo/bar/v1/x.go": "package bar \n const X = 1", + }, + }.test(t, func(t *goimportTest) { + buf, err := Process(t.gopath+"/src/test/t.go", []byte("package main \n const Y = bar.X"), &Options{}) + if err != nil { + t.Fatal(err) + } + const want = `package main + +import bar "foo/bar/v1" + +const Y = bar.X +` + if string(buf) != want { + t.Errorf("Got:\n%s\nWant:\n%s", buf, want) + } + }) +} + +// Tests that running goimport on files in GOROOT (for people hacking +// on Go itself) don't cause the GOPATH to be scanned (which might be +// much bigger). +func TestOptimizationWhenInGoroot(t *testing.T) { + testConfig{ + gopathFiles: map[string]string{ + "foo/foo.go": "package foo\nconst X = 1\n", + }, + }.test(t, func(t *goimportTest) { + testHookScanDir = func(dir string) { + if dir != filepath.Join(build.Default.GOROOT, "src") { + t.Errorf("unexpected dir scan of %s", dir) + } + } + const in = "package foo\n\nconst Y = bar.X\n" + buf, err := Process(t.goroot+"/src/foo/foo.go", []byte(in), nil) + if err != nil { + t.Fatal(err) + } + if string(buf) != in { + t.Errorf("got:\n%q\nwant unchanged:\n%q\n", in, buf) + } + }) +} + +// Tests that "package documentation" files are ignored. +func TestIgnoreDocumentationPackage(t *testing.T) { + testConfig{ + gopathFiles: map[string]string{ + "foo/foo.go": "package foo\nconst X = 1\n", + "foo/doc.go": "package documentation \n // just to confuse things\n", + }, + }.test(t, func(t *goimportTest) { + const in = "package x\n\nconst Y = foo.X\n" + const want = "package x\n\nimport \"foo\"\n\nconst Y = foo.X\n" + buf, err := Process(t.gopath+"/src/x/x.go", []byte(in), nil) + if err != nil { + t.Fatal(err) + } + if string(buf) != want { + t.Errorf("wrong output.\ngot:\n%q\nwant:\n%q\n", in, want) + } + }) +} + func strSet(ss []string) map[string]bool { m := make(map[string]bool) for _, s := range ss { diff --git a/refactor/rename/mvpkg_test.go b/refactor/rename/mvpkg_test.go index 9476c17ae3..1800f6b1de 100644 --- a/refactor/rename/mvpkg_test.go +++ b/refactor/rename/mvpkg_test.go @@ -12,6 +12,7 @@ import ( "path/filepath" "reflect" "regexp" + "runtime" "strings" "testing" @@ -113,6 +114,9 @@ var _ foo.T } func TestMoves(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("broken on Windows; see golang.org/issue/16384") + } tests := []struct { ctxt *build.Context from, to string