imports: rename to internal/imports

For various reasons we need an internal-facing imports API. Move imports
to internal/imports, leaving behind a small wrapper package. The wrapper
package captures the globals at time of call into the options struct.

Also converts the last goimports tests to use the test helpers, and
fixes go/packages in module mode to work with empty modules, which was
necessary to get those last tests converted.

Change-Id: Ib1212c67908741a1800b992ef1935d563c6ade32
Reviewed-on: https://go-review.googlesource.com/c/tools/+/175437
Run-TryBot: Heschi Kreinick <heschi@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Cottrell <iancottrell@google.com>
This commit is contained in:
Heschi Kreinick 2019-05-06 15:37:46 -04:00
parent 7e7c6e5214
commit 757ca719ca
24 changed files with 220 additions and 130 deletions

View File

@ -56,7 +56,13 @@ func (modules) Finalize(exported *Exported) error {
// other weird stuff, and will be the working dir for the go command. // other weird stuff, and will be the working dir for the go command.
// It depends on all the other modules. // It depends on all the other modules.
primaryDir := primaryDir(exported) primaryDir := primaryDir(exported)
if err := os.MkdirAll(primaryDir, 0755); err != nil {
return err
}
exported.Config.Dir = primaryDir exported.Config.Dir = primaryDir
if exported.written[exported.primary] == nil {
exported.written[exported.primary] = make(map[string]string)
}
exported.written[exported.primary]["go.mod"] = filepath.Join(primaryDir, "go.mod") exported.written[exported.primary]["go.mod"] = filepath.Join(primaryDir, "go.mod")
primaryGomod := "module " + exported.primary + "\nrequire (\n" primaryGomod := "module " + exported.primary + "\nrequire (\n"
for other := range exported.written { for other := range exported.written {

59
imports/forward.go Normal file
View File

@ -0,0 +1,59 @@
// Package imports implements a Go pretty-printer (like package "go/format")
// that also adds or removes import statements as necessary.
package imports // import "golang.org/x/tools/imports"
import (
"go/build"
intimp "golang.org/x/tools/internal/imports"
)
// Options specifies options for processing files.
type Options struct {
Fragment bool // Accept fragment of a source file (no package statement)
AllErrors bool // Report all errors (not just the first 10 on different lines)
Comments bool // Print comments (true if nil *Options provided)
TabIndent bool // Use tabs for indent (true if nil *Options provided)
TabWidth int // Tab width (8 if nil *Options provided)
FormatOnly bool // Disable the insertion and deletion of imports
}
// Debug controls verbose logging.
var Debug = false
// LocalPrefix is a comma-separated string of import path prefixes, which, if
// set, instructs Process to sort the import paths with the given prefixes
// into another group after 3rd-party packages.
var LocalPrefix string
// Process formats and adjusts imports for the provided file.
// If opt is nil the defaults are used.
//
// Note that filename's directory influences which imports can be chosen,
// so it is important that filename be accurate.
// To process data ``as if'' it were in filename, pass the data as a non-nil src.
func Process(filename string, src []byte, opt *Options) ([]byte, error) {
intopt := &intimp.Options{
Env: &intimp.ProcessEnv{
GOPATH: build.Default.GOPATH,
GOROOT: build.Default.GOROOT,
Debug: Debug,
LocalPrefix: LocalPrefix,
},
AllErrors: opt.AllErrors,
Comments: opt.Comments,
FormatOnly: opt.FormatOnly,
Fragment: opt.Fragment,
TabIndent: opt.TabIndent,
TabWidth: opt.TabWidth,
}
return intimp.Process(filename, src, intopt)
}
// VendorlessPath returns the devendorized version of the import path ipath.
// For example, VendorlessPath("foo/bar/vendor/a/b") returns "a/b".
func VendorlessPath(ipath string) string {
return intimp.VendorlessPath(ipath)
}

View File

@ -31,39 +31,27 @@ import (
"golang.org/x/tools/internal/gopathwalk" "golang.org/x/tools/internal/gopathwalk"
) )
// Debug controls verbose logging.
var Debug = false
// LocalPrefix is a comma-separated string of import path prefixes, which, if
// set, instructs Process to sort the import paths with the given prefixes
// into another group after 3rd-party packages.
var LocalPrefix string
func localPrefixes() []string {
if LocalPrefix != "" {
return strings.Split(LocalPrefix, ",")
}
return nil
}
// importToGroup is a list of functions which map from an import path to // importToGroup is a list of functions which map from an import path to
// a group number. // a group number.
var importToGroup = []func(importPath string) (num int, ok bool){ var importToGroup = []func(env *ProcessEnv, importPath string) (num int, ok bool){
func(importPath string) (num int, ok bool) { func(env *ProcessEnv, importPath string) (num int, ok bool) {
for _, p := range localPrefixes() { if env.LocalPrefix == "" {
return
}
for _, p := range strings.Split(env.LocalPrefix, ",") {
if strings.HasPrefix(importPath, p) || strings.TrimSuffix(p, "/") == importPath { if strings.HasPrefix(importPath, p) || strings.TrimSuffix(p, "/") == importPath {
return 3, true return 3, true
} }
} }
return return
}, },
func(importPath string) (num int, ok bool) { func(_ *ProcessEnv, importPath string) (num int, ok bool) {
if strings.HasPrefix(importPath, "appengine") { if strings.HasPrefix(importPath, "appengine") {
return 2, true return 2, true
} }
return return
}, },
func(importPath string) (num int, ok bool) { func(_ *ProcessEnv, importPath string) (num int, ok bool) {
if strings.Contains(importPath, ".") { if strings.Contains(importPath, ".") {
return 1, true return 1, true
} }
@ -71,9 +59,9 @@ var importToGroup = []func(importPath string) (num int, ok bool){
}, },
} }
func importGroup(importPath string) int { func importGroup(env *ProcessEnv, importPath string) int {
for _, fn := range importToGroup { for _, fn := range importToGroup {
if n, ok := fn(importPath); ok { if n, ok := fn(env, importPath); ok {
return n return n
} }
} }
@ -241,7 +229,7 @@ type pass struct {
fset *token.FileSet // fset used to parse f and its siblings. fset *token.FileSet // fset used to parse f and its siblings.
f *ast.File // the file being fixed. f *ast.File // the file being fixed.
srcDir string // the directory containing f. srcDir string // the directory containing f.
fixEnv *fixEnv // the environment to use for go commands, etc. env *ProcessEnv // the environment to use for go commands, etc.
loadRealPackageNames bool // if true, load package names from disk rather than guessing them. loadRealPackageNames bool // if true, load package names from disk rather than guessing them.
otherFiles []*ast.File // sibling files. otherFiles []*ast.File // sibling files.
@ -266,7 +254,7 @@ func (p *pass) loadPackageNames(imports []*importInfo) error {
unknown = append(unknown, imp.importPath) unknown = append(unknown, imp.importPath)
} }
names, err := p.fixEnv.getResolver().loadPackageNames(unknown, p.srcDir) names, err := p.env.getResolver().loadPackageNames(unknown, p.srcDir)
if err != nil { if err != nil {
return err return err
} }
@ -324,7 +312,7 @@ func (p *pass) load() bool {
if p.loadRealPackageNames { if p.loadRealPackageNames {
err := p.loadPackageNames(append(imports, p.candidates...)) err := p.loadPackageNames(append(imports, p.candidates...))
if err != nil { if err != nil {
if Debug { if p.env.Debug {
log.Printf("loading package names: %v", err) log.Printf("loading package names: %v", err)
} }
return false return false
@ -448,13 +436,13 @@ func (p *pass) addCandidate(imp *importInfo, pkg *packageInfo) {
// easily be extended by adding a file with an init function. // easily be extended by adding a file with an init function.
var fixImports = fixImportsDefault var fixImports = fixImportsDefault
func fixImportsDefault(fset *token.FileSet, f *ast.File, filename string, env *fixEnv) error { func fixImportsDefault(fset *token.FileSet, f *ast.File, filename string, env *ProcessEnv) error {
abs, err := filepath.Abs(filename) abs, err := filepath.Abs(filename)
if err != nil { if err != nil {
return err return err
} }
srcDir := filepath.Dir(abs) srcDir := filepath.Dir(abs)
if Debug { if env.Debug {
log.Printf("fixImports(filename=%q), abs=%q, srcDir=%q ...", filename, abs, srcDir) log.Printf("fixImports(filename=%q), abs=%q, srcDir=%q ...", filename, abs, srcDir)
} }
@ -486,7 +474,7 @@ func fixImportsDefault(fset *token.FileSet, f *ast.File, filename string, env *f
// Third pass: get real package names where we had previously used // Third pass: get real package names where we had previously used
// the naive algorithm. This is the first step that will use the // the naive algorithm. This is the first step that will use the
// environment, so we provide it here for the first time. // environment, so we provide it here for the first time.
p = &pass{fset: fset, f: f, srcDir: srcDir, fixEnv: env} p = &pass{fset: fset, f: f, srcDir: srcDir, env: env}
p.loadRealPackageNames = true p.loadRealPackageNames = true
p.otherFiles = otherFiles p.otherFiles = otherFiles
if p.load() { if p.load() {
@ -510,9 +498,12 @@ func fixImportsDefault(fset *token.FileSet, f *ast.File, filename string, env *f
return nil return nil
} }
// fixEnv contains environment variables and settings that affect the use of // ProcessEnv contains environment variables and settings that affect the use of
// the go command, the go/build package, etc. // the go command, the go/build package, etc.
type fixEnv struct { type ProcessEnv struct {
LocalPrefix string
Debug bool
// If non-empty, these will be used instead of the // If non-empty, these will be used instead of the
// process-wide values. // process-wide values.
GOPATH, GOROOT, GO111MODULE, GOPROXY, GOFLAGS string GOPATH, GOROOT, GO111MODULE, GOPROXY, GOFLAGS string
@ -524,7 +515,7 @@ type fixEnv struct {
resolver resolver resolver resolver
} }
func (e *fixEnv) env() []string { func (e *ProcessEnv) env() []string {
env := os.Environ() env := os.Environ()
add := func(k, v string) { add := func(k, v string) {
if v != "" { if v != "" {
@ -542,7 +533,7 @@ func (e *fixEnv) env() []string {
return env return env
} }
func (e *fixEnv) getResolver() resolver { func (e *ProcessEnv) getResolver() resolver {
if e.resolver != nil { if e.resolver != nil {
return e.resolver return e.resolver
} }
@ -557,7 +548,7 @@ func (e *fixEnv) getResolver() resolver {
return &moduleResolver{env: e} return &moduleResolver{env: e}
} }
func (e *fixEnv) newPackagesConfig(mode packages.LoadMode) *packages.Config { func (e *ProcessEnv) newPackagesConfig(mode packages.LoadMode) *packages.Config {
return &packages.Config{ return &packages.Config{
Mode: mode, Mode: mode,
Dir: e.WorkingDir, Dir: e.WorkingDir,
@ -565,14 +556,14 @@ func (e *fixEnv) newPackagesConfig(mode packages.LoadMode) *packages.Config {
} }
} }
func (e *fixEnv) buildContext() *build.Context { func (e *ProcessEnv) buildContext() *build.Context {
ctx := build.Default ctx := build.Default
ctx.GOROOT = e.GOROOT ctx.GOROOT = e.GOROOT
ctx.GOPATH = e.GOPATH ctx.GOPATH = e.GOPATH
return &ctx return &ctx
} }
func (e *fixEnv) invokeGo(args ...string) (*bytes.Buffer, error) { func (e *ProcessEnv) invokeGo(args ...string) (*bytes.Buffer, error) {
cmd := exec.Command("go", args...) cmd := exec.Command("go", args...)
stdout := &bytes.Buffer{} stdout := &bytes.Buffer{}
stderr := &bytes.Buffer{} stderr := &bytes.Buffer{}
@ -581,7 +572,7 @@ func (e *fixEnv) invokeGo(args ...string) (*bytes.Buffer, error) {
cmd.Env = e.env() cmd.Env = e.env()
cmd.Dir = e.WorkingDir cmd.Dir = e.WorkingDir
if Debug { if e.Debug {
defer func(start time.Time) { log.Printf("%s for %v", time.Since(start), cmdDebugStr(cmd)) }(time.Now()) defer func(start time.Time) { log.Printf("%s for %v", time.Since(start), cmdDebugStr(cmd)) }(time.Now())
} }
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
@ -632,7 +623,7 @@ type resolver interface {
// gopathResolver implements resolver for GOPATH and module workspaces using go/packages. // gopathResolver implements resolver for GOPATH and module workspaces using go/packages.
type goPackagesResolver struct { type goPackagesResolver struct {
env *fixEnv env *ProcessEnv
} }
func (r *goPackagesResolver) loadPackageNames(importPaths []string, srcDir string) (map[string]string, error) { func (r *goPackagesResolver) loadPackageNames(importPaths []string, srcDir string) (map[string]string, error) {
@ -680,7 +671,7 @@ func (r *goPackagesResolver) scan(refs references) ([]*pkg, error) {
} }
func addExternalCandidates(pass *pass, refs references, filename string) error { func addExternalCandidates(pass *pass, refs references, filename string) error {
dirScan, err := pass.fixEnv.getResolver().scan(refs) dirScan, err := pass.env.getResolver().scan(refs)
if err != nil { if err != nil {
return err return err
} }
@ -707,7 +698,7 @@ func addExternalCandidates(pass *pass, refs references, filename string) error {
go func(pkgName string, symbols map[string]bool) { go func(pkgName string, symbols map[string]bool) {
defer wg.Done() defer wg.Done()
found, err := findImport(ctx, pass.fixEnv, dirScan, pkgName, symbols, filename) found, err := findImport(ctx, pass.env, dirScan, pkgName, symbols, filename)
if err != nil { if err != nil {
firstErrOnce.Do(func() { firstErrOnce.Do(func() {
@ -778,7 +769,7 @@ func importPathToAssumedName(importPath string) string {
// gopathResolver implements resolver for GOPATH workspaces. // gopathResolver implements resolver for GOPATH workspaces.
type gopathResolver struct { type gopathResolver struct {
env *fixEnv env *ProcessEnv
} }
func (r *gopathResolver) loadPackageNames(importPaths []string, srcDir string) (map[string]string, error) { func (r *gopathResolver) loadPackageNames(importPaths []string, srcDir string) (map[string]string, error) {
@ -791,7 +782,7 @@ func (r *gopathResolver) loadPackageNames(importPaths []string, srcDir string) (
// importPathToNameGoPath finds out the actual package name, as declared in its .go files. // importPathToNameGoPath finds out the actual package name, as declared in its .go files.
// If there's a problem, it returns "". // If there's a problem, it returns "".
func importPathToName(env *fixEnv, importPath, srcDir string) (packageName string) { func importPathToName(env *ProcessEnv, importPath, srcDir string) (packageName string) {
// Fast path for standard library without going to disk. // Fast path for standard library without going to disk.
if _, ok := stdlib[importPath]; ok { if _, ok := stdlib[importPath]; ok {
return path.Base(importPath) // stdlib packages always match their paths. return path.Base(importPath) // stdlib packages always match their paths.
@ -927,7 +918,7 @@ func (r *gopathResolver) scan(_ references) ([]*pkg, error) {
dir: dir, dir: dir,
}) })
} }
gopathwalk.Walk(gopathwalk.SrcDirsRoots(r.env.buildContext()), add, gopathwalk.Options{Debug: Debug, ModulesEnabled: false}) gopathwalk.Walk(gopathwalk.SrcDirsRoots(r.env.buildContext()), add, gopathwalk.Options{Debug: r.env.Debug, ModulesEnabled: false})
return result, nil return result, nil
} }
@ -946,8 +937,8 @@ func VendorlessPath(ipath string) string {
// loadExports returns the set of exported symbols in the package at dir. // loadExports returns the set of exported symbols in the package at dir.
// It returns nil on error or if the package name in dir does not match expectPackage. // It returns nil on error or if the package name in dir does not match expectPackage.
func loadExports(ctx context.Context, env *fixEnv, expectPackage string, pkg *pkg) (map[string]bool, error) { func loadExports(ctx context.Context, env *ProcessEnv, expectPackage string, pkg *pkg) (map[string]bool, error) {
if Debug { if env.Debug {
log.Printf("loading exports in dir %s (seeking package %s)", pkg.dir, expectPackage) log.Printf("loading exports in dir %s (seeking package %s)", pkg.dir, expectPackage)
} }
if pkg.goPackage != nil { if pkg.goPackage != nil {
@ -1020,7 +1011,7 @@ func loadExports(ctx context.Context, env *fixEnv, expectPackage string, pkg *pk
} }
} }
if Debug { if env.Debug {
exportList := make([]string, 0, len(exports)) exportList := make([]string, 0, len(exports))
for k := range exports { for k := range exports {
exportList = append(exportList, k) exportList = append(exportList, k)
@ -1033,7 +1024,7 @@ func loadExports(ctx context.Context, env *fixEnv, expectPackage string, pkg *pk
// findImport searches for a package with the given symbols. // findImport searches for a package with the given symbols.
// If no package is found, findImport returns ("", false, nil) // If no package is found, findImport returns ("", false, nil)
func findImport(ctx context.Context, env *fixEnv, dirScan []*pkg, pkgName string, symbols map[string]bool, filename string) (*pkg, error) { func findImport(ctx context.Context, env *ProcessEnv, dirScan []*pkg, pkgName string, symbols map[string]bool, filename string) (*pkg, error) {
pkgDir, err := filepath.Abs(filename) pkgDir, err := filepath.Abs(filename)
if err != nil { if err != nil {
return nil, err return nil, err
@ -1056,7 +1047,7 @@ func findImport(ctx context.Context, env *fixEnv, dirScan []*pkg, pkgName string
// ones. Note that this sorts by the de-vendored name, so // ones. Note that this sorts by the de-vendored name, so
// there's no "penalty" for vendoring. // there's no "penalty" for vendoring.
sort.Sort(byDistanceOrImportPathShortLength(candidates)) sort.Sort(byDistanceOrImportPathShortLength(candidates))
if Debug { if env.Debug {
for i, c := range candidates { for i, c := range candidates {
log.Printf("%s candidate %d/%d: %v in %v", pkgName, i+1, len(candidates), c.pkg.importPathShort, c.pkg.dir) log.Printf("%s candidate %d/%d: %v in %v", pkgName, i+1, len(candidates), c.pkg.importPathShort, c.pkg.dir)
} }
@ -1097,7 +1088,7 @@ func findImport(ctx context.Context, env *fixEnv, dirScan []*pkg, pkgName string
exports, err := loadExports(ctx, env, pkgName, c.pkg) exports, err := loadExports(ctx, env, pkgName, c.pkg)
if err != nil { if err != nil {
if Debug { if env.Debug {
log.Printf("loading exports in dir %s (seeking package %s): %v", c.pkg.dir, pkgName, err) log.Printf("loading exports in dir %s (seeking package %s): %v", c.pkg.dir, pkgName, err)
} }
resc <- nil resc <- nil

View File

@ -5,6 +5,7 @@
package imports package imports
import ( import (
"flag"
"fmt" "fmt"
"path/filepath" "path/filepath"
"runtime" "runtime"
@ -14,6 +15,8 @@ import (
"golang.org/x/tools/go/packages/packagestest" "golang.org/x/tools/go/packages/packagestest"
) )
var testDebug = flag.Bool("debug", false, "enable debug output")
var tests = []struct { var tests = []struct {
name string name string
formatOnly bool formatOnly bool
@ -1116,8 +1119,7 @@ var _, _ = rand.Read, rand.NewZipf
} }
func TestSimpleCases(t *testing.T) { func TestSimpleCases(t *testing.T) {
defer func(lp string) { LocalPrefix = lp }(LocalPrefix) const localPrefix = "local.com,github.com/local"
LocalPrefix = "local.com,github.com/local"
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
options := &Options{ options := &Options{
@ -1163,7 +1165,11 @@ func TestSimpleCases(t *testing.T) {
Files: fm{"bar/x.go": "package bar\nfunc Bar(){}\n"}, Files: fm{"bar/x.go": "package bar\nfunc Bar(){}\n"},
}, },
}, },
}.processTest(t, "golang.org/fake", "x.go", nil, options, tt.out) }.test(t, func(t *goimportTest) {
t.env.LocalPrefix = localPrefix
t.assertProcessEquals("golang.org/fake", "x.go", nil, options, tt.out)
})
}) })
} }
} }
@ -1485,20 +1491,28 @@ func TestFindStdlib(t *testing.T) {
for _, sym := range tt.symbols { for _, sym := range tt.symbols {
input += fmt.Sprintf("var _ = %s.%s\n", tt.pkg, sym) input += fmt.Sprintf("var _ = %s.%s\n", tt.pkg, sym)
} }
buf, err := Process("x.go", []byte(input), &Options{}) testConfig{
if err != nil { module: packagestest.Module{
t.Fatal(err) Name: "foo.com",
} Files: fm{"x.go": input},
if got := string(buf); !strings.Contains(got, tt.want) { },
t.Errorf("Process(%q) = %q, wanted it to contain %q", input, buf, tt.want) }.test(t, func(t *goimportTest) {
} buf, err := t.process("foo.com", "x.go", nil, nil)
if err != nil {
t.Fatal(err)
}
if got := string(buf); !strings.Contains(got, tt.want) {
t.Errorf("Process(%q) = %q, wanted it to contain %q", input, buf, tt.want)
}
})
} }
} }
type testConfig struct { type testConfig struct {
gopathOnly bool gopathOnly bool
module packagestest.Module goPackagesIncompatible bool
modules []packagestest.Module module packagestest.Module
modules []packagestest.Module
} }
// fm is the type for a packagestest.Module's Files, abbreviated for shorter lines. // fm is the type for a packagestest.Module's Files, abbreviated for shorter lines.
@ -1522,6 +1536,12 @@ func (c testConfig) test(t *testing.T, fn func(*goimportTest)) {
forceGoPackages := false forceGoPackages := false
var exporter packagestest.Exporter var exporter packagestest.Exporter
if c.gopathOnly && strings.HasPrefix(kind, "Modules") {
t.Skip("test marked GOPATH-only")
}
if c.goPackagesIncompatible && strings.HasSuffix(kind, "_GoPackages") {
t.Skip("test marked go/packages-incompatible")
}
switch kind { switch kind {
case "GOPATH": case "GOPATH":
exporter = packagestest.GOPATH exporter = packagestest.GOPATH
@ -1529,14 +1549,8 @@ func (c testConfig) test(t *testing.T, fn func(*goimportTest)) {
exporter = packagestest.GOPATH exporter = packagestest.GOPATH
forceGoPackages = true forceGoPackages = true
case "Modules": case "Modules":
if c.gopathOnly {
t.Skip("test marked GOPATH-only")
}
exporter = packagestest.Modules exporter = packagestest.Modules
case "Modules_GoPackages": case "Modules_GoPackages":
if c.gopathOnly {
t.Skip("test marked GOPATH-only")
}
exporter = packagestest.Modules exporter = packagestest.Modules
forceGoPackages = true forceGoPackages = true
default: default:
@ -1554,12 +1568,13 @@ func (c testConfig) test(t *testing.T, fn func(*goimportTest)) {
it := &goimportTest{ it := &goimportTest{
T: t, T: t,
fixEnv: &fixEnv{ env: &ProcessEnv{
GOROOT: env["GOROOT"], GOROOT: env["GOROOT"],
GOPATH: env["GOPATH"], GOPATH: env["GOPATH"],
GO111MODULE: env["GO111MODULE"], GO111MODULE: env["GO111MODULE"],
WorkingDir: exported.Config.Dir, WorkingDir: exported.Config.Dir,
ForceGoPackages: forceGoPackages, ForceGoPackages: forceGoPackages,
Debug: *testDebug,
}, },
exported: exported, exported: exported,
} }
@ -1572,23 +1587,36 @@ func (c testConfig) processTest(t *testing.T, module, file string, contents []by
t.Helper() t.Helper()
c.test(t, func(t *goimportTest) { c.test(t, func(t *goimportTest) {
t.Helper() t.Helper()
t.process(module, file, contents, opts, want) t.assertProcessEquals(module, file, contents, opts, want)
}) })
} }
type goimportTest struct { type goimportTest struct {
*testing.T *testing.T
fixEnv *fixEnv env *ProcessEnv
exported *packagestest.Exported exported *packagestest.Exported
} }
func (t *goimportTest) process(module, file string, contents []byte, opts *Options, want string) { func (t *goimportTest) process(module, file string, contents []byte, opts *Options) ([]byte, error) {
t.Helper() t.Helper()
f := t.exported.File(module, file) f := t.exported.File(module, file)
if f == "" { if f == "" {
t.Fatalf("%v not found in exported files (typo in filename?)", file) t.Fatalf("%v not found in exported files (typo in filename?)", file)
} }
buf, err := process(f, contents, opts, t.fixEnv) return t.processNonModule(f, contents, opts)
}
func (t *goimportTest) processNonModule(file string, contents []byte, opts *Options) ([]byte, error) {
if opts == nil {
opts = &Options{Comments: true, TabIndent: true, TabWidth: 8}
}
opts.Env = t.env
opts.Env.Debug = *testDebug
return Process(file, contents, opts)
}
func (t *goimportTest) assertProcessEquals(module, file string, contents []byte, opts *Options, want string) {
buf, err := t.process(module, file, contents, opts)
if err != nil { if err != nil {
t.Fatalf("Process() = %v", err) t.Fatalf("Process() = %v", err)
} }
@ -1775,9 +1803,8 @@ const _ = runtime.GOOS
Files: fm{"t.go": tt.src}, Files: fm{"t.go": tt.src},
}}, tt.modules...), }}, tt.modules...),
}.test(t, func(t *goimportTest) { }.test(t, func(t *goimportTest) {
defer func(s string) { LocalPrefix = s }(LocalPrefix) t.env.LocalPrefix = tt.localPrefix
LocalPrefix = tt.localPrefix t.assertProcessEquals("test.com", "t.go", nil, nil, tt.want)
t.process("test.com", "t.go", nil, nil, tt.want)
}) })
}) })
} }
@ -1827,7 +1854,7 @@ func TestImportPathToNameGoPathParse(t *testing.T) {
if strings.Contains(t.Name(), "GoPackages") { if strings.Contains(t.Name(), "GoPackages") {
t.Skip("go/packages does not ignore package main") t.Skip("go/packages does not ignore package main")
} }
r := t.fixEnv.getResolver() r := t.env.getResolver()
srcDir := filepath.Dir(t.exported.File("example.net/pkg", "z.go")) srcDir := filepath.Dir(t.exported.File("example.net/pkg", "z.go"))
names, err := r.loadPackageNames([]string{"example.net/pkg"}, srcDir) names, err := r.loadPackageNames([]string{"example.net/pkg"}, srcDir)
if err != nil { if err != nil {
@ -2220,13 +2247,19 @@ func TestPkgIsCandidate(t *testing.T) {
// Issue 20941: this used to panic on Windows. // Issue 20941: this used to panic on Windows.
func TestProcessStdin(t *testing.T) { func TestProcessStdin(t *testing.T) {
got, err := Process("<standard input>", []byte("package main\nfunc main() {\n\tfmt.Println(123)\n}\n"), nil) testConfig{
if err != nil { module: packagestest.Module{
t.Fatal(err) Name: "foo.com",
} },
if !strings.Contains(string(got), `"fmt"`) { }.test(t, func(t *goimportTest) {
t.Errorf("expected fmt import; got: %s", got) got, err := t.processNonModule("<standard input>", []byte("package main\nfunc main() {\n\tfmt.Println(123)\n}\n"), nil)
} if err != nil {
t.Fatal(err)
}
if !strings.Contains(string(got), `"fmt"`) {
t.Errorf("expected fmt import; got: %s", got)
}
})
} }
// Tests LocalPackagePromotion when there is a local package that matches, it // Tests LocalPackagePromotion when there is a local package that matches, it
@ -2324,14 +2357,21 @@ import "bytes"
var _ = &bytes.Buffer{} var _ = &bytes.Buffer{}
` `
testConfig{
goPackagesIncompatible: true,
module: packagestest.Module{
Name: "mycompany.net",
},
}.test(t, func(t *goimportTest) {
buf, err := t.processNonModule("mycompany.net/tool/main.go", []byte(input), nil)
if err != nil {
t.Fatalf("Process() = %v", err)
}
if string(buf) != want {
t.Errorf("Got:\n%s\nWant:\n%s", buf, want)
}
})
buf, err := Process("mycompany.net/tool/main.go", []byte(input), nil)
if err != nil {
t.Fatalf("Process() = %v", err)
}
if string(buf) != want {
t.Errorf("Got:\n%s\nWant:\n%s", buf, want)
}
} }
// Ensures a token as large as 500000 bytes can be handled // Ensures a token as large as 500000 bytes can be handled

View File

@ -6,14 +6,13 @@
// Package imports implements a Go pretty-printer (like package "go/format") // Package imports implements a Go pretty-printer (like package "go/format")
// that also adds or removes import statements as necessary. // that also adds or removes import statements as necessary.
package imports // import "golang.org/x/tools/imports" package imports
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"fmt" "fmt"
"go/ast" "go/ast"
"go/build"
"go/format" "go/format"
"go/parser" "go/parser"
"go/printer" "go/printer"
@ -27,8 +26,10 @@ import (
"golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/ast/astutil"
) )
// Options specifies options for processing files. // Options is golang.org/x/tools/imports.Options with extra internal-only options.
type Options struct { type Options struct {
Env *ProcessEnv // The environment to use. Note: this contains the cached module and filesystem state.
Fragment bool // Accept fragment of a source file (no package statement) Fragment bool // Accept fragment of a source file (no package statement)
AllErrors bool // Report all errors (not just the first 10 on different lines) AllErrors bool // Report all errors (not just the first 10 on different lines)
@ -39,18 +40,8 @@ type Options struct {
FormatOnly bool // Disable the insertion and deletion of imports FormatOnly bool // Disable the insertion and deletion of imports
} }
// Process formats and adjusts imports for the provided file. // Process implements golang.org/x/tools/imports.Process with explicit context in env.
// If opt is nil the defaults are used.
//
// Note that filename's directory influences which imports can be chosen,
// so it is important that filename be accurate.
// To process data ``as if'' it were in filename, pass the data as a non-nil src.
func Process(filename string, src []byte, opt *Options) ([]byte, error) { func Process(filename string, src []byte, opt *Options) ([]byte, error) {
env := &fixEnv{GOPATH: build.Default.GOPATH, GOROOT: build.Default.GOROOT}
return process(filename, src, opt, env)
}
func process(filename string, src []byte, opt *Options, env *fixEnv) ([]byte, error) {
if opt == nil { if opt == nil {
opt = &Options{Comments: true, TabIndent: true, TabWidth: 8} opt = &Options{Comments: true, TabIndent: true, TabWidth: 8}
} }
@ -69,12 +60,12 @@ func process(filename string, src []byte, opt *Options, env *fixEnv) ([]byte, er
} }
if !opt.FormatOnly { if !opt.FormatOnly {
if err := fixImports(fileSet, file, filename, env); err != nil { if err := fixImports(fileSet, file, filename, opt.Env); err != nil {
return nil, err return nil, err
} }
} }
sortImports(fileSet, file) sortImports(opt.Env, fileSet, file)
imps := astutil.Imports(fileSet, file) imps := astutil.Imports(fileSet, file)
var spacesBefore []string // import paths we need spaces before var spacesBefore []string // import paths we need spaces before
for _, impSection := range imps { for _, impSection := range imps {
@ -85,7 +76,7 @@ func process(filename string, src []byte, opt *Options, env *fixEnv) ([]byte, er
lastGroup := -1 lastGroup := -1
for _, importSpec := range impSection { for _, importSpec := range impSection {
importPath, _ := strconv.Unquote(importSpec.Path.Value) importPath, _ := strconv.Unquote(importSpec.Path.Value)
groupNum := importGroup(importPath) groupNum := importGroup(opt.Env, importPath)
if groupNum != lastGroup && lastGroup != -1 { if groupNum != lastGroup && lastGroup != -1 {
spacesBefore = append(spacesBefore, importPath) spacesBefore = append(spacesBefore, importPath)
} }

View File

@ -8,7 +8,7 @@
// standard library. The file is intended to be built as part of the imports // standard library. The file is intended to be built as part of the imports
// package, so that the package may be used in environments where a GOROOT is // package, so that the package may be used in environments where a GOROOT is
// not available (such as App Engine). // not available (such as App Engine).
package main package imports
import ( import (
"bytes" "bytes"

View File

@ -3,7 +3,7 @@
// mkstdlib generates the zstdlib.go file, containing the Go standard // mkstdlib generates the zstdlib.go file, containing the Go standard
// library API symbols. It's baked into the binary to avoid scanning // library API symbols. It's baked into the binary to avoid scanning
// GOPATH in the common case. // GOPATH in the common case.
package main package imports
import ( import (
"bufio" "bufio"

View File

@ -22,7 +22,7 @@ import (
// moduleResolver implements resolver for modules using the go command as little // moduleResolver implements resolver for modules using the go command as little
// as feasible. // as feasible.
type moduleResolver struct { type moduleResolver struct {
env *fixEnv env *ProcessEnv
initialized bool initialized bool
main *moduleJSON main *moduleJSON
@ -62,7 +62,7 @@ func (r *moduleResolver) init() error {
return err return err
} }
if mod.Dir == "" { if mod.Dir == "" {
if Debug { if r.env.Debug {
log.Printf("module %v has not been downloaded and will be ignored", mod.Path) log.Printf("module %v has not been downloaded and will be ignored", mod.Path)
} }
// Can't do anything with a module that's not downloaded. // Can't do anything with a module that's not downloaded.
@ -253,7 +253,7 @@ func (r *moduleResolver) scan(_ references) ([]*pkg, error) {
matches := modCacheRegexp.FindStringSubmatch(subdir) matches := modCacheRegexp.FindStringSubmatch(subdir)
modPath, err := module.DecodePath(filepath.ToSlash(matches[1])) modPath, err := module.DecodePath(filepath.ToSlash(matches[1]))
if err != nil { if err != nil {
if Debug { if r.env.Debug {
log.Printf("decoding module cache path %q: %v", subdir, err) log.Printf("decoding module cache path %q: %v", subdir, err)
} }
return return
@ -303,7 +303,7 @@ func (r *moduleResolver) scan(_ references) ([]*pkg, error) {
importPathShort: VendorlessPath(importPath), importPathShort: VendorlessPath(importPath),
dir: dir, dir: dir,
}) })
}, gopathwalk.Options{Debug: Debug, ModulesEnabled: true}) }, gopathwalk.Options{Debug: r.env.Debug, ModulesEnabled: true})
return result, nil return result, nil
} }

View File

@ -485,7 +485,7 @@ var proxyDir string
type modTest struct { type modTest struct {
*testing.T *testing.T
env *fixEnv env *ProcessEnv
resolver *moduleResolver resolver *moduleResolver
cleanup func() cleanup func()
} }
@ -515,7 +515,7 @@ func setup(t *testing.T, main, wd string) *modTest {
t.Fatal(err) t.Fatal(err)
} }
env := &fixEnv{ env := &ProcessEnv{
GOROOT: build.Default.GOROOT, GOROOT: build.Default.GOROOT,
GOPATH: filepath.Join(dir, "gopath"), GOPATH: filepath.Join(dir, "gopath"),
GO111MODULE: "on", GO111MODULE: "on",

View File

@ -15,7 +15,7 @@ import (
// sortImports sorts runs of consecutive import lines in import blocks in f. // sortImports sorts runs of consecutive import lines in import blocks in f.
// It also removes duplicate imports when it is possible to do so without data loss. // It also removes duplicate imports when it is possible to do so without data loss.
func sortImports(fset *token.FileSet, f *ast.File) { func sortImports(env *ProcessEnv, fset *token.FileSet, f *ast.File) {
for i, d := range f.Decls { for i, d := range f.Decls {
d, ok := d.(*ast.GenDecl) d, ok := d.(*ast.GenDecl)
if !ok || d.Tok != token.IMPORT { if !ok || d.Tok != token.IMPORT {
@ -40,11 +40,11 @@ func sortImports(fset *token.FileSet, f *ast.File) {
for j, s := range d.Specs { for j, s := range d.Specs {
if j > i && fset.Position(s.Pos()).Line > 1+fset.Position(d.Specs[j-1].End()).Line { if j > i && fset.Position(s.Pos()).Line > 1+fset.Position(d.Specs[j-1].End()).Line {
// j begins a new run. End this one. // j begins a new run. End this one.
specs = append(specs, sortSpecs(fset, f, d.Specs[i:j])...) specs = append(specs, sortSpecs(env, fset, f, d.Specs[i:j])...)
i = j i = j
} }
} }
specs = append(specs, sortSpecs(fset, f, d.Specs[i:])...) specs = append(specs, sortSpecs(env, fset, f, d.Specs[i:])...)
d.Specs = specs d.Specs = specs
// Deduping can leave a blank line before the rparen; clean that up. // Deduping can leave a blank line before the rparen; clean that up.
@ -95,7 +95,7 @@ type posSpan struct {
End token.Pos End token.Pos
} }
func sortSpecs(fset *token.FileSet, f *ast.File, specs []ast.Spec) []ast.Spec { func sortSpecs(env *ProcessEnv, fset *token.FileSet, f *ast.File, specs []ast.Spec) []ast.Spec {
// Can't short-circuit here even if specs are already sorted, // Can't short-circuit here even if specs are already sorted,
// since they might yet need deduplication. // since they might yet need deduplication.
// A lone import, however, may be safely ignored. // A lone import, however, may be safely ignored.
@ -144,7 +144,7 @@ func sortSpecs(fset *token.FileSet, f *ast.File, specs []ast.Spec) []ast.Spec {
// Reassign the import paths to have the same position sequence. // Reassign the import paths to have the same position sequence.
// Reassign each comment to abut the end of its spec. // Reassign each comment to abut the end of its spec.
// Sort the comments by new position. // Sort the comments by new position.
sort.Sort(byImportSpec(specs)) sort.Sort(byImportSpec{env, specs})
// Dedup. Thanks to our sorting, we can just consider // Dedup. Thanks to our sorting, we can just consider
// adjacent pairs of imports. // adjacent pairs of imports.
@ -197,16 +197,19 @@ func sortSpecs(fset *token.FileSet, f *ast.File, specs []ast.Spec) []ast.Spec {
return specs return specs
} }
type byImportSpec []ast.Spec // slice of *ast.ImportSpec type byImportSpec struct {
env *ProcessEnv
specs []ast.Spec // slice of *ast.ImportSpec
}
func (x byImportSpec) Len() int { return len(x) } func (x byImportSpec) Len() int { return len(x.specs) }
func (x byImportSpec) Swap(i, j int) { x[i], x[j] = x[j], x[i] } func (x byImportSpec) Swap(i, j int) { x.specs[i], x.specs[j] = x.specs[j], x.specs[i] }
func (x byImportSpec) Less(i, j int) bool { func (x byImportSpec) Less(i, j int) bool {
ipath := importPath(x[i]) ipath := importPath(x.specs[i])
jpath := importPath(x[j]) jpath := importPath(x.specs[j])
igroup := importGroup(ipath) igroup := importGroup(x.env, ipath)
jgroup := importGroup(jpath) jgroup := importGroup(x.env, jpath)
if igroup != jgroup { if igroup != jgroup {
return igroup < jgroup return igroup < jgroup
} }
@ -214,13 +217,13 @@ func (x byImportSpec) Less(i, j int) bool {
if ipath != jpath { if ipath != jpath {
return ipath < jpath return ipath < jpath
} }
iname := importName(x[i]) iname := importName(x.specs[i])
jname := importName(x[j]) jname := importName(x.specs[j])
if iname != jname { if iname != jname {
return iname < jname return iname < jname
} }
return importComment(x[i]) < importComment(x[j]) return importComment(x.specs[i]) < importComment(x.specs[j])
} }
type byCommentPos []*ast.CommentGroup type byCommentPos []*ast.CommentGroup