diff --git a/internal/imports/fix.go b/internal/imports/fix.go index 72b43bd884..fb70790fc8 100644 --- a/internal/imports/fix.go +++ b/internal/imports/fix.go @@ -585,49 +585,39 @@ func getFixes(fset *token.FileSet, f *ast.File, filename string, env *ProcessEnv return fixes, nil } -// getAllCandidates gets all of the candidates to be imported, regardless of if they are needed. -func getAllCandidates(filename string, env *ProcessEnv) ([]ImportFix, error) { +// getCandidatePkgs returns the list of pkgs that are accessible from filename, +// optionall filtered to only packages named pkgName. +func getCandidatePkgs(pkgName, filename string, env *ProcessEnv) ([]*pkg, error) { // TODO(heschi): filter out current package. (Don't forget x_test can import x.) + var result []*pkg // Start off with the standard library. - var imports []ImportFix for importPath := range stdlib { - imports = append(imports, ImportFix{ - StmtInfo: ImportInfo{ - ImportPath: importPath, - }, - IdentName: path.Base(importPath), - FixType: AddImport, + if pkgName != "" && path.Base(importPath) != pkgName { + continue + } + result = append(result, &pkg{ + dir: filepath.Join(env.GOROOT, "src", importPath), + importPathShort: importPath, + packageName: path.Base(importPath), + relevance: 0, }) } - // Sort the stdlib bits solely by name. - sort.Slice(imports, func(i int, j int) bool { - return imports[i].StmtInfo.ImportPath < imports[j].StmtInfo.ImportPath - }) // Exclude goroot results -- getting them is relatively expensive, not cached, // and generally redundant with the in-memory version. exclude := []gopathwalk.RootType{gopathwalk.RootGOROOT} // Only the go/packages resolver uses the first argument, and nobody uses that resolver. - pkgs, err := env.GetResolver().scan(nil, true, exclude) + scannedPkgs, err := env.GetResolver().scan(nil, true, exclude) if err != nil { return nil, err } - // Sort first by relevance, then by name, so that when we add them they're - // still in order. - sort.Slice(pkgs, func(i, j int) bool { - pi, pj := pkgs[i], pkgs[j] - if pi.relevance < pj.relevance { - return true - } - if pi.relevance > pj.relevance { - return false - } - return pi.packageName < pj.packageName - }) dupCheck := map[string]struct{}{} - for _, pkg := range pkgs { + for _, pkg := range scannedPkgs { + if pkgName != "" && pkg.packageName != pkgName { + continue + } if !canUse(filename, pkg.dir) { continue } @@ -635,7 +625,33 @@ func getAllCandidates(filename string, env *ProcessEnv) ([]ImportFix, error) { continue } dupCheck[pkg.importPathShort] = struct{}{} - imports = append(imports, ImportFix{ + result = append(result, pkg) + } + + // Sort first by relevance, then by package name, with import path as a tiebreaker. + sort.Slice(result, func(i, j int) bool { + pi, pj := result[i], result[j] + if pi.relevance != pj.relevance { + return pi.relevance < pj.relevance + } + if pi.packageName != pj.packageName { + return pi.packageName < pj.packageName + } + return pi.importPathShort < pj.importPathShort + }) + + return result, nil +} + +// getAllCandidates gets all of the candidates to be imported, regardless of if they are needed. +func getAllCandidates(filename string, env *ProcessEnv) ([]ImportFix, error) { + pkgs, err := getCandidatePkgs("", filename, env) + if err != nil { + return nil, err + } + result := make([]ImportFix, 0, len(pkgs)) + for _, pkg := range pkgs { + result = append(result, ImportFix{ StmtInfo: ImportInfo{ ImportPath: pkg.importPathShort, }, @@ -643,7 +659,54 @@ func getAllCandidates(filename string, env *ProcessEnv) ([]ImportFix, error) { FixType: AddImport, }) } - return imports, nil + return result, nil +} + +// A PackageExport is a package and its exports. +type PackageExport struct { + Fix *ImportFix + Exports []string +} + +func getPackageExports(completePackage, filename string, env *ProcessEnv) ([]PackageExport, error) { + pkgs, err := getCandidatePkgs(completePackage, filename, env) + if err != nil { + return nil, err + } + + results := make([]PackageExport, 0, len(pkgs)) + for _, pkg := range pkgs { + fix := &ImportFix{ + StmtInfo: ImportInfo{ + ImportPath: pkg.importPathShort, + }, + IdentName: pkg.packageName, + FixType: AddImport, + } + var exportsMap map[string]bool + if e, ok := stdlib[pkg.importPathShort]; ok { + exportsMap = e + } else { + exportsMap, err = env.GetResolver().loadExports(context.TODO(), completePackage, pkg) + if err != nil { + if env.Debug { + env.Logf("while completing %q, error loading exports from %q: %v", completePackage, pkg.importPathShort, err) + } + continue + } + } + var exports []string + for export := range exportsMap { + exports = append(exports, export) + } + sort.Strings(exports) + results = append(results, PackageExport{ + Fix: fix, + Exports: exports, + }) + } + + return results, nil } // ProcessEnv contains environment variables and settings that affect the use of diff --git a/internal/imports/fix_test.go b/internal/imports/fix_test.go index a29fc6e591..f58cc3a0fe 100644 --- a/internal/imports/fix_test.go +++ b/internal/imports/fix_test.go @@ -2522,9 +2522,9 @@ func TestGetCandidates(t *testing.T) { } want := []res{ {"bytes", "bytes"}, + {"http", "net/http"}, {"rand", "crypto/rand"}, {"rand", "math/rand"}, - {"http", "net/http"}, {"bar", "bar.com/bar"}, {"foo", "foo.com/foo"}, } @@ -2560,6 +2560,45 @@ func TestGetCandidates(t *testing.T) { }) } +func TestGetPackageCompletions(t *testing.T) { + type res struct { + name, path, symbol string + } + want := []res{ + {"rand", "crypto/rand", "Prime"}, + {"rand", "math/rand", "Seed"}, + {"rand", "bar.com/rand", "Bar"}, + } + + testConfig{ + modules: []packagestest.Module{ + { + Name: "bar.com", + Files: fm{"rand/bar.go": "package rand\nvar Bar int\n"}, + }, + }, + goPackagesIncompatible: true, // getPackageCompletions doesn't support the go/packages resolver. + }.test(t, func(t *goimportTest) { + candidates, err := getPackageExports("rand", "x.go", t.env) + if err != nil { + t.Fatalf("getPackageCompletions() = %v", err) + } + var got []res + for _, c := range candidates { + for _, csym := range c.Exports { + for _, w := range want { + if c.Fix.StmtInfo.ImportPath == w.path && csym == w.symbol { + got = append(got, res{c.Fix.IdentName, c.Fix.StmtInfo.ImportPath, csym}) + } + } + } + } + if !reflect.DeepEqual(want, got) { + t.Errorf("wanted stdlib results in order %v, got %v", want, got) + } + }) +} + // Tests #34895: process should not panic on concurrent calls. func TestConcurrentProcess(t *testing.T) { testConfig{ diff --git a/internal/imports/imports.go b/internal/imports/imports.go index 2c074cb2db..ed3867bb59 100644 --- a/internal/imports/imports.go +++ b/internal/imports/imports.go @@ -105,13 +105,22 @@ func ApplyFixes(fixes []*ImportFix, filename string, src []byte, opt *Options) ( // GetAllCandidates gets all of the standard library candidate packages to import in // sorted order on import path. func GetAllCandidates(filename string, opt *Options) (pkgs []ImportFix, err error) { - _, opt, err = initialize(filename, []byte{}, opt) + _, opt, err = initialize(filename, nil, opt) if err != nil { return nil, err } return getAllCandidates(filename, opt.Env) } +// GetPackageExports returns all known packages with name pkg and their exports. +func GetPackageExports(pkg, filename string, opt *Options) (exports []PackageExport, err error) { + _, opt, err = initialize(filename, nil, opt) + if err != nil { + return nil, err + } + return getPackageExports(pkg, filename, opt.Env) +} + // initialize sets the values for opt and src. // If they are provided, they are not changed. Otherwise opt is set to the // default values and src is read from the file system.