diff --git a/internal/lsp/cmd/check_test.go b/internal/lsp/cmd/check_test.go index d918f0da35..5752b873d8 100644 --- a/internal/lsp/cmd/check_test.go +++ b/internal/lsp/cmd/check_test.go @@ -25,7 +25,7 @@ func (r *runner) Diagnostics(t *testing.T, data tests.Diagnostics) { args := []string{"-remote=internal", "check", fname} app := cmd.New("gopls-test", r.data.Config.Dir, r.data.Exported.Config.Env) out := captureStdOut(t, func() { - tool.Main(r.ctx, app, args) + _ = tool.Run(r.ctx, app, args) }) // parse got into a collection of reports got := map[string]struct{}{} diff --git a/internal/lsp/cmd/cmd.go b/internal/lsp/cmd/cmd.go index a2e2af8c70..3f95add2d2 100644 --- a/internal/lsp/cmd/cmd.go +++ b/internal/lsp/cmd/cmd.go @@ -119,14 +119,12 @@ func (app *Application) Run(ctx context.Context, args ...string) error { export.AddExporters(ocagent.Connect(ocConfig)) app.Serve.app = app if len(args) == 0 { - tool.Main(ctx, &app.Serve, args) - return nil + return tool.Run(ctx, &app.Serve, args) } command, args := args[0], args[1:] for _, c := range app.commands() { if c.Name() == command { - tool.Main(ctx, c, args) - return nil + return tool.Run(ctx, c, args) } } return tool.CommandLineErrorf("Unknown command %v", command) diff --git a/internal/lsp/cmd/definition_test.go b/internal/lsp/cmd/definition_test.go index 480a8cc110..5737fc0776 100644 --- a/internal/lsp/cmd/definition_test.go +++ b/internal/lsp/cmd/definition_test.go @@ -55,7 +55,7 @@ func TestDefinitionHelpExample(t *testing.T) { fmt.Sprintf("%v:#%v", thisFile, cmd.ExampleOffset)} { args := append(baseArgs, query) got := captureStdOut(t, func() { - tool.Main(tests.Context(t), cmd.New("gopls-test", "", nil), args) + _ = tool.Run(tests.Context(t), cmd.New("gopls-test", "", nil), args) }) if !expect.MatchString(got) { t.Errorf("test with %v\nexpected:\n%s\ngot:\n%s", args, expect, got) @@ -84,7 +84,7 @@ func (r *runner) Definition(t *testing.T, data tests.Definitions) { args = append(args, fmt.Sprint(d.Src)) got := captureStdOut(t, func() { app := cmd.New("gopls-test", r.data.Config.Dir, r.data.Exported.Config.Env) - tool.Main(r.ctx, app, args) + _ = tool.Run(r.ctx, app, args) }) got = normalizePaths(r.data, got) if mode&jsonGoDef != 0 && runtime.GOOS == "windows" { diff --git a/internal/lsp/cmd/format_test.go b/internal/lsp/cmd/format_test.go index b7cdabdcfe..878246189d 100644 --- a/internal/lsp/cmd/format_test.go +++ b/internal/lsp/cmd/format_test.go @@ -39,7 +39,7 @@ func (r *runner) Format(t *testing.T, data tests.Formats) { } app := cmd.New("gopls-test", r.data.Config.Dir, r.data.Config.Env) got := captureStdOut(t, func() { - tool.Main(r.ctx, app, append([]string{"-remote=internal", "format"}, args...)) + _ = tool.Run(r.ctx, app, append([]string{"-remote=internal", "format"}, args...)) }) got = normalizePaths(r.data, got) // check the first two lines are the expected file header diff --git a/internal/lsp/cmd/query.go b/internal/lsp/cmd/query.go index 4fb42266b6..037b4d0076 100644 --- a/internal/lsp/cmd/query.go +++ b/internal/lsp/cmd/query.go @@ -56,8 +56,7 @@ func (q *query) Run(ctx context.Context, args ...string) error { mode, args := args[0], args[1:] for _, m := range q.modes() { if m.Name() == mode { - tool.Main(ctx, m, args) - return nil + return tool.Run(ctx, m, args) // pass errors up the chain } } return tool.CommandLineErrorf("unknown command %v", mode) diff --git a/internal/tool/tool.go b/internal/tool/tool.go index d21a099803..b50569aced 100644 --- a/internal/tool/tool.go +++ b/internal/tool/tool.go @@ -78,7 +78,9 @@ func CommandLineErrorf(message string, args ...interface{}) error { } // Main should be invoked directly by main function. -// It will only return if there was no error. +// It will only return if there was no error. If an error +// was encountered it is printed to standard error and the +// application exits with an exit code of 2. func Main(ctx context.Context, app Application, args []string) { s := flag.NewFlagSet(app.Name(), flag.ExitOnError) s.Usage = func() { @@ -86,50 +88,7 @@ func Main(ctx context.Context, app Application, args []string) { fmt.Fprintf(s.Output(), "\n\nUsage: %v [flags] %v\n", app.Name(), app.Usage()) app.DetailedHelp(s) } - p := addFlags(s, reflect.StructField{}, reflect.ValueOf(app)) - s.Parse(args) - err := func() error { - if p != nil && p.CPU != "" { - f, err := os.Create(p.CPU) - if err != nil { - return err - } - if err := pprof.StartCPUProfile(f); err != nil { - return err - } - defer pprof.StopCPUProfile() - } - - if p != nil && p.Trace != "" { - f, err := os.Create(p.Trace) - if err != nil { - return err - } - if err := trace.Start(f); err != nil { - return err - } - defer func() { - trace.Stop() - log.Printf("To view the trace, run:\n$ go tool trace view %s", p.Trace) - }() - } - - if p != nil && p.Memory != "" { - f, err := os.Create(p.Memory) - if err != nil { - return err - } - defer func() { - runtime.GC() // get up-to-date statistics - if err := pprof.WriteHeapProfile(f); err != nil { - log.Printf("Writing memory profile: %v", err) - } - f.Close() - }() - } - return app.Run(ctx, s.Args()...) - }() - if err != nil { + if err := Run(ctx, app, args); err != nil { fmt.Fprintf(s.Output(), "%s: %v\n", app.Name(), err) if _, printHelp := err.(commandLineError); printHelp { s.Usage() @@ -138,6 +97,56 @@ func Main(ctx context.Context, app Application, args []string) { } } +// Run is the inner loop for Main; invoked by Main, recursively by +// Run, and by various tests. It runs the application and returns an +// error. +func Run(ctx context.Context, app Application, args []string) error { + s := flag.NewFlagSet(app.Name(), flag.ExitOnError) + p := addFlags(s, reflect.StructField{}, reflect.ValueOf(app)) + s.Parse(args) + + if p != nil && p.CPU != "" { + f, err := os.Create(p.CPU) + if err != nil { + return err + } + if err := pprof.StartCPUProfile(f); err != nil { + return err + } + defer pprof.StopCPUProfile() + } + + if p != nil && p.Trace != "" { + f, err := os.Create(p.Trace) + if err != nil { + return err + } + if err := trace.Start(f); err != nil { + return err + } + defer func() { + trace.Stop() + log.Printf("To view the trace, run:\n$ go tool trace view %s", p.Trace) + }() + } + + if p != nil && p.Memory != "" { + f, err := os.Create(p.Memory) + if err != nil { + return err + } + defer func() { + runtime.GC() // get up-to-date statistics + if err := pprof.WriteHeapProfile(f); err != nil { + log.Printf("Writing memory profile: %v", err) + } + f.Close() + }() + } + + return app.Run(ctx, s.Args()...) +} + // addFlags scans fields of structs recursively to find things with flag tags // and add them to the flag set. func addFlags(f *flag.FlagSet, field reflect.StructField, value reflect.Value) *Profile {