diff --git a/cmd/callgraph/main.go b/cmd/callgraph/main.go index 8ef4597c1f..204b1160d7 100644 --- a/cmd/callgraph/main.go +++ b/cmd/callgraph/main.go @@ -37,7 +37,7 @@ import ( "golang.org/x/tools/go/callgraph/cha" "golang.org/x/tools/go/callgraph/rta" "golang.org/x/tools/go/callgraph/static" - "golang.org/x/tools/go/loader" + "golang.org/x/tools/go/packages" "golang.org/x/tools/go/pointer" "golang.org/x/tools/go/ssa" "golang.org/x/tools/go/ssa/ssautil" @@ -67,7 +67,7 @@ const Usage = `callgraph: display the the call graph of a Go program. Usage: - callgraph [-algo=static|cha|rta|pta] [-test] [-format=...] ... + callgraph [-algo=static|cha|rta|pta] [-test] [-format=...] package... Flags: @@ -118,8 +118,6 @@ Flags: import path of the enclosing package. Consult the go/ssa API documentation for details. -` + loader.FromArgsUsage + ` - Examples: Show the call graph of the trivial web server application: @@ -158,7 +156,7 @@ func init() { func main() { flag.Parse() - if err := doCallgraph(&build.Default, *algoFlag, *formatFlag, *testFlag, flag.Args()); err != nil { + if err := doCallgraph("", "", *algoFlag, *formatFlag, *testFlag, flag.Args()); err != nil { fmt.Fprintf(os.Stderr, "callgraph: %s\n", err) os.Exit(1) } @@ -166,28 +164,27 @@ func main() { var stdout io.Writer = os.Stdout -func doCallgraph(ctxt *build.Context, algo, format string, tests bool, args []string) error { - conf := loader.Config{Build: ctxt} - +func doCallgraph(dir, gopath, algo, format string, tests bool, args []string) error { if len(args) == 0 { fmt.Fprintln(os.Stderr, Usage) return nil } - // Use the initial packages from the command line. - _, err := conf.FromArgs(args, tests) - if err != nil { - return err + cfg := &packages.Config{ + Mode: packages.LoadAllSyntax, + Tests: tests, + Dir: dir, } - - // Load, parse and type-check the whole program. - iprog, err := conf.Load() + if gopath != "" { + cfg.Env = append(os.Environ(), "GOPATH="+gopath) // to enable testing + } + initial, err := packages.Load(cfg, args...) if err != nil { return err } // Create and build SSA-form program representation. - prog := ssautil.CreateProgram(iprog, 0) + prog, pkgs := ssautil.Packages(initial, 0) prog.Build() // -- call graph construction ------------------------------------------ @@ -221,7 +218,7 @@ func doCallgraph(ctxt *build.Context, algo, format string, tests bool, args []st } } - mains, err := mainPackages(prog, tests) + mains, err := mainPackages(pkgs) if err != nil { return err } @@ -237,7 +234,7 @@ func doCallgraph(ctxt *build.Context, algo, format string, tests bool, args []st cg = ptares.CallGraph case "rta": - mains, err := mainPackages(prog, tests) + mains, err := mainPackages(pkgs) if err != nil { return err } @@ -305,25 +302,13 @@ func doCallgraph(ctxt *build.Context, algo, format string, tests bool, args []st // mainPackages returns the main packages to analyze. // Each resulting package is named "main" and has a main function. -func mainPackages(prog *ssa.Program, tests bool) ([]*ssa.Package, error) { - pkgs := prog.AllPackages() // TODO(adonovan): use only initial packages - - // If tests, create a "testmain" package for each test. +func mainPackages(pkgs []*ssa.Package) ([]*ssa.Package, error) { var mains []*ssa.Package - if tests { - for _, pkg := range pkgs { - if main := prog.CreateTestMainPackage(pkg); main != nil { - mains = append(mains, main) - } + for _, p := range pkgs { + if p != nil && p.Pkg.Name() == "main" && p.Func("main") != nil { + mains = append(mains, p) } - if mains == nil { - return nil, fmt.Errorf("no tests") - } - return mains, nil } - - // Otherwise, use the main packages. - mains = append(mains, ssautil.MainPackages(pkgs)...) if len(mains) == 0 { return nil, fmt.Errorf("no main packages") } diff --git a/cmd/callgraph/main_test.go b/cmd/callgraph/main_test.go index c42f56dafe..d5aa3230a5 100644 --- a/cmd/callgraph/main_test.go +++ b/cmd/callgraph/main_test.go @@ -11,25 +11,23 @@ package main import ( "bytes" "fmt" - "go/build" - "reflect" - "sort" + "path/filepath" "strings" "testing" ) func TestCallgraph(t *testing.T) { - ctxt := build.Default // copy - ctxt.GOPATH = "testdata" - - const format = "{{.Caller}} --> {{.Callee}}" + gopath, err := filepath.Abs("testdata") + if err != nil { + t.Fatal(err) + } for _, test := range []struct { - algo, format string - tests bool - want []string + algo string + tests bool + want []string }{ - {"rta", format, false, []string{ + {"rta", false, []string{ // rta imprecisely shows cross product of {main,main2} x {C,D} `pkg.main --> (pkg.C).f`, `pkg.main --> (pkg.D).f`, @@ -37,7 +35,7 @@ func TestCallgraph(t *testing.T) { `pkg.main2 --> (pkg.C).f`, `pkg.main2 --> (pkg.D).f`, }}, - {"pta", format, false, []string{ + {"pta", false, []string{ // pta distinguishes main->C, main2->D. Also has a root node. ` --> pkg.init`, ` --> pkg.main`, @@ -45,37 +43,42 @@ func TestCallgraph(t *testing.T) { `pkg.main --> pkg.main2`, `pkg.main2 --> (pkg.D).f`, }}, - // tests: main is not called. - {"rta", format, true, []string{ - `pkg$testmain.init --> pkg.init`, + // tests: both the package's main and the test's main are called. + // The callgraph includes all the guts of the "testing" package. + {"rta", true, []string{ + `pkg.test.main --> testing.MainStart`, + `testing.runExample --> pkg.Example`, `pkg.Example --> (pkg.C).f`, + `pkg.main --> (pkg.C).f`, }}, - {"pta", format, true, []string{ - ` --> pkg$testmain.init`, - ` --> pkg.Example`, - `pkg$testmain.init --> pkg.init`, + {"pta", true, []string{ + ` --> pkg.test.main`, + ` --> pkg.main`, + `pkg.test.main --> testing.MainStart`, + `testing.runExample --> pkg.Example`, `pkg.Example --> (pkg.C).f`, + `pkg.main --> (pkg.C).f`, }}, } { + const format = "{{.Caller}} --> {{.Callee}}" stdout = new(bytes.Buffer) - if err := doCallgraph(&ctxt, test.algo, test.format, test.tests, []string{"pkg"}); err != nil { + if err := doCallgraph("testdata/src", gopath, test.algo, format, test.tests, []string{"pkg"}); err != nil { t.Error(err) continue } - got := sortedLines(fmt.Sprint(stdout)) - if !reflect.DeepEqual(got, test.want) { - t.Errorf("callgraph(%q, %q, %t):\ngot:\n%s\nwant:\n%s", - test.algo, test.format, test.tests, - strings.Join(got, "\n"), - strings.Join(test.want, "\n")) + edges := make(map[string]bool) + for _, line := range strings.Split(fmt.Sprint(stdout), "\n") { + edges[line] = true + } + for _, edge := range test.want { + if !edges[edge] { + t.Errorf("callgraph(%q, %t): missing edge: %s", + test.algo, test.tests, edge) + } + } + if t.Failed() { + t.Log("got:\n", stdout) } } } - -func sortedLines(s string) []string { - s = strings.TrimSpace(s) - lines := strings.Split(s, "\n") - sort.Strings(lines) - return lines -} diff --git a/cmd/callgraph/testdata/src/pkg/pkg_test.go b/cmd/callgraph/testdata/src/pkg/pkg_test.go index d6247577b0..0dae2c3105 100644 --- a/cmd/callgraph/testdata/src/pkg/pkg_test.go +++ b/cmd/callgraph/testdata/src/pkg/pkg_test.go @@ -1,7 +1,10 @@ package main -// Don't import "testing", it adds a lot of callgraph edges. +// An Example function must have an "Output:" comment for the go build +// system to generate a call to it from the test main package. func Example() { C(0).f() + + // Output: }