diff --git a/cmd/callgraph/main.go b/cmd/callgraph/main.go index 204b1160d7..7284c4b386 100644 --- a/cmd/callgraph/main.go +++ b/cmd/callgraph/main.go @@ -182,6 +182,9 @@ func doCallgraph(dir, gopath, algo, format string, tests bool, args []string) er if err != nil { return err } + if packages.PrintErrors(initial) > 0 { + return fmt.Errorf("packages contain errors") + } // Create and build SSA-form program representation. prog, pkgs := ssautil.Packages(initial, 0) diff --git a/cmd/ssadump/main.go b/cmd/ssadump/main.go index 5dcf92b57b..b978249503 100644 --- a/cmd/ssadump/main.go +++ b/cmd/ssadump/main.go @@ -126,6 +126,9 @@ func doMain() error { if len(initial) == 0 { return fmt.Errorf("no packages") } + if packages.PrintErrors(initial) > 0 { + return fmt.Errorf("packages contain errors") + } // Create SSA-form program representation. prog, pkgs := ssautil.Packages(initial, mode) diff --git a/go/packages/doc.go b/go/packages/doc.go index e5d23b50a6..1240974a60 100644 --- a/go/packages/doc.go +++ b/go/packages/doc.go @@ -58,29 +58,8 @@ for details. Most tools should pass their command-line arguments (after any flags) uninterpreted to the loader, so that the loader can interpret them according to the conventions of the underlying build system. -For example, this program prints the names of the source files -for each package listed on the command line: +See the Example function for typical usage. - package main - - import ( - "flag" - "fmt" - "log" - - "golang.org/x/tools/go/packages" - ) - - func main() { - flag.Parse() - pkgs, err := packages.Load(nil, flag.Args()...) - if err != nil { - log.Fatal(err) - } - for _, pkg := range pkgs { - fmt.Print(pkg.ID, pkg.GoFiles) - } - } */ package packages // import "golang.org/x/tools/go/packages" diff --git a/go/packages/example_test.go b/go/packages/example_test.go new file mode 100644 index 0000000000..ad1340f308 --- /dev/null +++ b/go/packages/example_test.go @@ -0,0 +1,34 @@ +package packages_test + +import ( + "flag" + "fmt" + "os" + + "golang.org/x/tools/go/packages" +) + +// Example demonstrates how to load the packages specified on the +// command line from source syntax. +func Example() { + flag.Parse() + + // Many tools pass their command-line arguments (after any flags) + // uninterpreted to packages.Load so that it can interpret them + // according to the conventions of the underlying build system. + cfg := &packages.Config{Mode: packages.LoadSyntax} + pkgs, err := packages.Load(cfg, flag.Args()...) + if err != nil { + fmt.Fprintf(os.Stderr, "load: %v\n", err) + os.Exit(1) + } + if packages.PrintErrors(pkgs) > 0 { + os.Exit(1) + } + + // Print the names of the source files + // for each package listed on the command line. + for _, pkg := range pkgs { + fmt.Println(pkg.ID, pkg.GoFiles) + } +} diff --git a/go/packages/gopackages/main.go b/go/packages/gopackages/main.go index 5b22de5a94..bff298f789 100644 --- a/go/packages/gopackages/main.go +++ b/go/packages/gopackages/main.go @@ -113,7 +113,6 @@ func main() { // Load, parse, and type-check the packages named on the command line. cfg := &packages.Config{ Mode: packages.LoadSyntax, - Error: func(error) {}, // we'll take responsibility for printing errors Tests: *testFlag, BuildFlags: buildFlags, } diff --git a/go/packages/packages.go b/go/packages/packages.go index 4827224b02..62a8085177 100644 --- a/go/packages/packages.go +++ b/go/packages/packages.go @@ -87,14 +87,6 @@ type Config struct { // the build system's query tool. BuildFlags []string - // Error is called for each error encountered during parsing and type-checking. - // It must be safe to call Error simultaneously from multiple goroutines. - // In addition to calling Error, the loader records each error - // in the corresponding Package's Errors list. - // If Error is nil, the loader prints errors to os.Stderr. - // To disable printing of errors, set opt.Error = func(error) {}. - Error func(error) - // Fset provides source position information for syntax trees and types. // If Fset is nil, the loader will create a new FileSet. Fset *token.FileSet @@ -155,6 +147,11 @@ type driverResponse struct { // as defined by the underlying build system. // It may return an empty list of packages without an error, // for instance for an empty expansion of a valid wildcard. +// Errors associated with a particular package are recorded in the +// corresponding Package's Errors list, and do not cause Load to +// return an error. Clients may need to handle such errors before +// proceeding with further analysis. The PrintErrors function is +// provided for convenient display of all errors. func Load(cfg *Config, patterns ...string) ([]*Package, error) { l := newLoader(cfg) response, err := defaultDriver(&l.Config, patterns...) @@ -367,15 +364,8 @@ func newLoader(cfg *Config) *loader { ld.Fset = token.NewFileSet() } - // Error and ParseFile are required even in LoadTypes mode + // ParseFile is required even in LoadTypes mode // because we load source if export data is missing. - - if ld.Error == nil { - ld.Error = func(e error) { - fmt.Fprintln(os.Stderr, e) - } - } - if ld.ParseFile == nil { ld.ParseFile = func(fset *token.FileSet, filename string) (*ast.File, error) { const mode = parser.AllErrors | parser.ParseComments @@ -616,20 +606,6 @@ func (ld *loader) loadPackage(lpkg *loaderPackage) { log.Printf("internal error: error %q (%T) without position", err, err) } - // Allow application to print errors. - // - // TODO(adonovan): the real purpose of this hook is to - // allow (most) applications to _disable_ printing, - // while printing by default. - // Should we remove it, and make clients responsible for - // walking the import graph and printing errors? - // Though convenient for the common case, - // it seems like an unsafe default, and is decidedly less - // convenient for a tool that wants to print the errors. - for _, err := range errs { - ld.Error(err) - } - lpkg.Errors = append(lpkg.Errors, errs...) } diff --git a/go/packages/packages_test.go b/go/packages/packages_test.go index 238c290e64..6e71dc709f 100644 --- a/go/packages/packages_test.go +++ b/go/packages/packages_test.go @@ -19,7 +19,6 @@ import ( "runtime" "sort" "strings" - "sync" "testing" "golang.org/x/tools/go/packages" @@ -546,17 +545,6 @@ package b`, } } -type errCollector struct { - mu sync.Mutex - errors []packages.Error -} - -func (ec *errCollector) add(err error) { - ec.mu.Lock() - ec.errors = append(ec.errors, err.(packages.Error)) - ec.mu.Unlock() -} - func TestLoadTypes(t *testing.T) { // In LoadTypes and LoadSyntax modes, the compiler will // fail to generate an export data file for c, because it has @@ -571,9 +559,8 @@ func TestLoadTypes(t *testing.T) { defer cleanup() cfg := &packages.Config{ - Mode: packages.LoadTypes, - Env: append(os.Environ(), "GOPATH="+tmp, "GO111MODULE=off"), - Error: func(error) {}, + Mode: packages.LoadTypes, + Env: append(os.Environ(), "GOPATH="+tmp, "GO111MODULE=off"), } initial, err := packages.Load(cfg, "a") if err != nil { @@ -638,9 +625,8 @@ func TestLoadSyntaxOK(t *testing.T) { defer cleanup() cfg := &packages.Config{ - Mode: packages.LoadSyntax, - Env: append(os.Environ(), "GOPATH="+tmp, "GO111MODULE=off"), - Error: func(error) {}, + Mode: packages.LoadSyntax, + Env: append(os.Environ(), "GOPATH="+tmp, "GO111MODULE=off"), } initial, err := packages.Load(cfg, "a", "c") if err != nil { @@ -732,32 +718,16 @@ func TestLoadDiamondTypes(t *testing.T) { cfg := &packages.Config{ Mode: packages.LoadSyntax, Env: append(os.Environ(), "GOPATH="+tmp, "GO111MODULE=off"), - Error: func(err error) { - t.Errorf("Error during load: %v", err) - }, } initial, err := packages.Load(cfg, "a") if err != nil { t.Fatal(err) } - - var visit func(pkg *packages.Package) - seen := make(map[string]bool) - visit = func(pkg *packages.Package) { - if seen[pkg.ID] { - return - } - seen[pkg.ID] = true + packages.Visit(initial, nil, func(pkg *packages.Package) { for _, err := range pkg.Errors { - t.Errorf("Error on package %v: %v", pkg.ID, err) + t.Errorf("package %s: %v", pkg.ID, err) } - for _, imp := range pkg.Imports { - visit(imp) - } - } - for _, pkg := range initial { - visit(pkg) - } + }) graph, _ := importGraph(initial) wantGraph := ` @@ -791,9 +761,8 @@ func TestLoadSyntaxError(t *testing.T) { defer cleanup() cfg := &packages.Config{ - Mode: packages.LoadSyntax, - Env: append(os.Environ(), "GOPATH="+tmp, "GO111MODULE=off"), - Error: func(error) {}, + Mode: packages.LoadSyntax, + Env: append(os.Environ(), "GOPATH="+tmp, "GO111MODULE=off"), } initial, err := packages.Load(cfg, "a", "c") if err != nil { @@ -801,18 +770,9 @@ func TestLoadSyntaxError(t *testing.T) { } all := make(map[string]*packages.Package) - var visit func(p *packages.Package) - visit = func(p *packages.Package) { - if all[p.ID] == nil { - all[p.ID] = p - for _, imp := range p.Imports { - visit(imp) - } - } - } - for _, p := range initial { - visit(p) - } + packages.Visit(initial, nil, func(p *packages.Package) { + all[p.ID] = p + }) for _, test := range []struct { id string @@ -896,11 +856,9 @@ func TestLoadAllSyntaxOverlay(t *testing.T) { return parser.ParseFile(fset, filename, src, mode) } } - var errs errCollector cfg := &packages.Config{ Mode: packages.LoadAllSyntax, Env: append(os.Environ(), "GOPATH="+tmp, "GO111MODULE=off"), - Error: errs.add, ParseFile: parseFile, } initial, err := packages.Load(cfg, "a") @@ -916,7 +874,12 @@ func TestLoadAllSyntaxOverlay(t *testing.T) { t.Errorf("%d. a.A: got %s, want %s", i, got, test.want) } - if errs := errorMessages(errs.errors); !reflect.DeepEqual(errs, test.wantErrs) { + // Check errors. + var errors []packages.Error + packages.Visit(initial, nil, func(pkg *packages.Package) { + errors = append(errors, pkg.Errors...) + }) + if errs := errorMessages(errors); !reflect.DeepEqual(errs, test.wantErrs) { t.Errorf("%d. got errors %s, want %s", i, errs, test.wantErrs) } } @@ -948,11 +911,9 @@ import ( os.Mkdir(filepath.Join(tmp, "src/empty"), 0777) // create an existing but empty package - var errs2 errCollector cfg := &packages.Config{ - Mode: packages.LoadAllSyntax, - Env: append(os.Environ(), "GOPATH="+tmp, "GO111MODULE=off"), - Error: errs2.add, + Mode: packages.LoadAllSyntax, + Env: append(os.Environ(), "GOPATH="+tmp, "GO111MODULE=off"), } initial, err := packages.Load(cfg, "root") if err != nil { @@ -1169,39 +1130,21 @@ func TestJSON(t *testing.T) { if err != nil { t.Fatal(err) } + + // Visit and print all packages. buf := &bytes.Buffer{} enc := json.NewEncoder(buf) enc.SetIndent("", "\t") - seen := make(map[string]bool) - var visit func(*packages.Package) - visit = func(pkg *packages.Package) { - if seen[pkg.ID] { - return - } - seen[pkg.ID] = true - - // Trim the source lists for stable results. + packages.Visit(initial, nil, func(pkg *packages.Package) { + // trim the source lists for stable results pkg.GoFiles = cleanPaths(pkg.GoFiles) pkg.CompiledGoFiles = cleanPaths(pkg.CompiledGoFiles) pkg.OtherFiles = cleanPaths(pkg.OtherFiles) - - // Visit imports. - var importPaths []string - for path := range pkg.Imports { - importPaths = append(importPaths, path) - } - sort.Strings(importPaths) // for determinism - for _, path := range importPaths { - visit(pkg.Imports[path]) - } - if err := enc.Encode(pkg); err != nil { t.Fatal(err) } - } - for _, pkg := range initial { - visit(pkg) - } + }) + wantJSON := ` { "ID": "a", @@ -1406,7 +1349,6 @@ func errorMessages(errors []packages.Error) []string { for _, err := range errors { msgs = append(msgs, err.Msg) } - sort.Strings(msgs) return msgs } @@ -1439,6 +1381,8 @@ func importGraph(initial []*packages.Package) (string, map[string]*packages.Pack initialSet[p] = true } + // We can't use Visit because we need to prune + // the traversal of specific edges, not just nodes. var nodes, edges []string res := make(map[string]*packages.Package) seen := make(map[*packages.Package]bool) diff --git a/go/packages/stdlib_test.go b/go/packages/stdlib_test.go index ecaa89d62a..ddecff1c75 100644 --- a/go/packages/stdlib_test.go +++ b/go/packages/stdlib_test.go @@ -30,10 +30,7 @@ func TestStdlibMetadata(t *testing.T) { alloc := memstats.Alloc // Load, parse and type-check the program. - cfg := &packages.Config{ - Mode: packages.LoadAllSyntax, - Error: func(error) {}, - } + cfg := &packages.Config{Mode: packages.LoadAllSyntax} pkgs, err := packages.Load(cfg, "std") if err != nil { t.Fatalf("failed to load metadata: %v", err) @@ -96,10 +93,7 @@ func TestCgoOption(t *testing.T) { {"net", "cgoLookupHost", "cgo_stub.go"}, {"os/user", "current", "lookup_stubs.go"}, } { - cfg := &packages.Config{ - Mode: packages.LoadSyntax, - Error: func(error) {}, - } + cfg := &packages.Config{Mode: packages.LoadSyntax} pkgs, err := packages.Load(cfg, test.pkg) if err != nil { t.Errorf("Load failed: %v", err) diff --git a/go/packages/visit.go b/go/packages/visit.go new file mode 100644 index 0000000000..c1a4b28ca0 --- /dev/null +++ b/go/packages/visit.go @@ -0,0 +1,55 @@ +package packages + +import ( + "fmt" + "os" + "sort" +) + +// Visit visits all the packages in the import graph whose roots are +// pkgs, calling the optional pre function the first time each package +// is encountered (preorder), and the optional post function after a +// package's dependencies have been visited (postorder). +// The boolean result of pre(pkg) determines whether +// the imports of package pkg are visited. +func Visit(pkgs []*Package, pre func(*Package) bool, post func(*Package)) { + seen := make(map[*Package]bool) + var visit func(*Package) + visit = func(pkg *Package) { + if !seen[pkg] { + seen[pkg] = true + + if pre == nil || pre(pkg) { + paths := make([]string, 0, len(pkg.Imports)) + for path := range pkg.Imports { + paths = append(paths, path) + } + sort.Strings(paths) // for determinism + for _, path := range paths { + visit(pkg.Imports[path]) + } + } + + if post != nil { + post(pkg) + } + } + } + for _, pkg := range pkgs { + visit(pkg) + } +} + +// PrintErrors prints to os.Stderr the accumulated errors of all +// packages in the import graph rooted at pkgs, dependencies first. +// PrintErrors returns the number of errors printed. +func PrintErrors(pkgs []*Package) int { + var n int + Visit(pkgs, nil, func(pkg *Package) { + for _, err := range pkg.Errors { + fmt.Fprintln(os.Stderr, err) + n++ + } + }) + return n +} diff --git a/go/ssa/example_test.go b/go/ssa/example_test.go index 8bbe2d634e..6ca777686b 100644 --- a/go/ssa/example_test.go +++ b/go/ssa/example_test.go @@ -125,6 +125,13 @@ func ExampleLoadPackages() { log.Fatal(err) } + // Stop if any package had errors. + // This step is optional; without it, the next step + // will create SSA for only a subset of packages. + if packages.PrintErrors(initial) > 0 { + log.Fatalf("packages contain errors") + } + // Create SSA packages for all well-typed packages. prog, pkgs := ssautil.Packages(initial, ssa.PrintPackages) _ = prog diff --git a/go/ssa/ssautil/load.go b/go/ssa/ssautil/load.go index 9a69034b72..03068d50b2 100644 --- a/go/ssa/ssautil/load.go +++ b/go/ssa/ssautil/load.go @@ -35,30 +35,17 @@ func Packages(initial []*packages.Package, mode ssa.BuilderMode) (*ssa.Program, } prog := ssa.NewProgram(fset, mode) - seen := make(map[*packages.Package]*ssa.Package) - var create func(p *packages.Package) *ssa.Package - create = func(p *packages.Package) *ssa.Package { - ssapkg, ok := seen[p] - if !ok { - if p.Types == nil || p.IllTyped { - // not well typed - seen[p] = nil - return nil - } - ssapkg = prog.CreatePackage(p.Types, p.Syntax, p.TypesInfo, true) - seen[p] = ssapkg - - for _, imp := range p.Imports { - create(imp) - } + ssamap := make(map[*packages.Package]*ssa.Package) + packages.Visit(initial, nil, func(p *packages.Package) { + if p.Types != nil && !p.IllTyped { + ssamap[p] = prog.CreatePackage(p.Types, p.Syntax, p.TypesInfo, true) } - return ssapkg - } + }) var ssapkgs []*ssa.Package for _, p := range initial { - ssapkgs = append(ssapkgs, create(p)) + ssapkgs = append(ssapkgs, ssamap[p]) // may be nil } return prog, ssapkgs } diff --git a/go/ssa/ssautil/load_test.go b/go/ssa/ssautil/load_test.go index 5a1efe623b..2885ed3965 100644 --- a/go/ssa/ssautil/load_test.go +++ b/go/ssa/ssautil/load_test.go @@ -58,6 +58,9 @@ func TestPackages(t *testing.T) { if err != nil { t.Fatal(err) } + if packages.PrintErrors(initial) > 0 { + t.Fatal("there were errors") + } prog, pkgs := ssautil.Packages(initial, 0) bytesNewBuffer := pkgs[0].Func("NewBuffer")