diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go index a49060db10..b42d1d8efe 100644 --- a/internal/lsp/lsp_test.go +++ b/internal/lsp/lsp_test.go @@ -93,7 +93,10 @@ func testLSP(t *testing.T, exporter packagestest.Exporter) { expectedDefinitions := make(definitions) expectedTypeDefinitions := make(definitions) expectedHighlights := make(highlights) - expectedSymbols := make(symbols) + expectedSymbols := &symbols{ + m: make(map[span.URI][]protocol.DocumentSymbol), + children: make(map[string][]protocol.DocumentSymbol), + } // Collect any data that needs to be used by subsequent tests. if err := exported.Expect(map[string]interface{}{ @@ -180,8 +183,8 @@ func testLSP(t *testing.T, exporter packagestest.Exporter) { t.Run("Symbols", func(t *testing.T) { t.Helper() if goVersion111 { // TODO(rstambler): Remove this when we no longer support Go 1.10. - if len(expectedSymbols) != expectedSymbolsCount { - t.Errorf("got %v symbols expected %v", len(expectedSymbols), expectedSymbolsCount) + if len(expectedSymbols.m) != expectedSymbolsCount { + t.Errorf("got %v symbols expected %v", len(expectedSymbols.m), expectedSymbolsCount) } } expectedSymbols.test(t, s) @@ -194,7 +197,10 @@ type completions map[token.Position][]token.Pos type formats map[string]string type definitions map[protocol.Location]protocol.Location type highlights map[string][]protocol.Location -type symbols map[span.URI][]protocol.DocumentSymbol +type symbols struct { + m map[span.URI][]protocol.DocumentSymbol + children map[string][]protocol.DocumentSymbol +} func (d diagnostics) test(t *testing.T, v source.View) int { count := 0 @@ -522,7 +528,7 @@ func (h highlights) test(t *testing.T, s *Server) { } } -func (s symbols) collect(e *packagestest.Exported, fset *token.FileSet, name string, rng span.Range, kind int64) { +func (s symbols) collect(e *packagestest.Exported, fset *token.FileSet, name string, rng span.Range, kind int64, parentName string) { f := fset.File(rng.Start) if f == nil { return @@ -544,15 +550,20 @@ func (s symbols) collect(e *packagestest.Exported, fset *token.FileSet, name str return } - s[spn.URI()] = append(s[spn.URI()], protocol.DocumentSymbol{ + sym := protocol.DocumentSymbol{ Name: name, Kind: protocol.SymbolKind(kind), SelectionRange: prng, - }) + } + if parentName == "" { + s.m[spn.URI()] = append(s.m[spn.URI()], sym) + } else { + s.children[parentName] = append(s.children[parentName], sym) + } } func (s symbols) test(t *testing.T, server *Server) { - for uri, expectedSymbols := range s { + for uri, expectedSymbols := range s.m { params := &protocol.DocumentSymbolParams{ TextDocument: protocol.TextDocumentIdentifier{ URI: string(uri), @@ -564,28 +575,58 @@ func (s symbols) test(t *testing.T, server *Server) { } if len(symbols) != len(expectedSymbols) { - t.Errorf("want %d symbols in %v, got %d", len(expectedSymbols), uri, len(symbols)) + t.Errorf("want %d top-level symbols in %v, got %d", len(expectedSymbols), uri, len(symbols)) continue } sort.Slice(symbols, func(i, j int) bool { return symbols[i].Name < symbols[j].Name }) sort.Slice(expectedSymbols, func(i, j int) bool { return expectedSymbols[i].Name < expectedSymbols[j].Name }) - for i, w := range expectedSymbols { - g := symbols[i] - if w.Name != g.Name { - t.Errorf("%s: want symbol %q, got %q", uri, w.Name, g.Name) - continue - } - if w.Kind != g.Kind { - t.Errorf("%s: want kind %v for %s, got %v", uri, w.Kind, w.Name, g.Kind) - } - if w.SelectionRange != g.SelectionRange { - t.Errorf("%s: want selection range %v for %s, got %v", uri, w.SelectionRange, w.Name, g.SelectionRange) - } + for i := range expectedSymbols { + children := s.children[expectedSymbols[i].Name] + sort.Slice(children, func(i, j int) bool { return children[i].Name < children[j].Name }) + expectedSymbols[i].Children = children + } + if diff := diffSymbols(uri, expectedSymbols, symbols); diff != "" { + t.Error(diff) } } } +func diffSymbols(uri span.URI, want, got []protocol.DocumentSymbol) string { + if len(got) != len(want) { + goto Failed + } + for i, w := range want { + g := got[i] + if w.Name != g.Name { + goto Failed + } + if w.Kind != g.Kind { + goto Failed + } + if w.SelectionRange != g.SelectionRange { + goto Failed + } + sort.Slice(g.Children, func(i, j int) bool { return g.Children[i].Name < g.Children[j].Name }) + if msg := diffSymbols(uri, w.Children, g.Children); msg != "" { + return fmt.Sprintf("children of %s: %s", w.Name, msg) + } + } + return "" + +Failed: + msg := &bytes.Buffer{} + fmt.Fprintf(msg, "document symbols failed for %s:\nexpected:\n", uri) + for _, s := range want { + fmt.Fprintf(msg, " %v %v %v\n", s.Name, s.Kind, s.SelectionRange) + } + fmt.Fprintf(msg, "got:\n") + for _, s := range got { + fmt.Fprintf(msg, " %v %v %v\n", s.Name, s.Kind, s.SelectionRange) + } + return msg.String() +} + func testLocation(e *packagestest.Exported, fset *token.FileSet, rng packagestest.Range) (span.Span, *protocol.ColumnMapper) { spn, err := span.NewRange(fset, rng.Start, rng.End).Span() if err != nil { diff --git a/internal/lsp/source/symbols.go b/internal/lsp/source/symbols.go index f95d3f384e..4b4d2b8b1d 100644 --- a/internal/lsp/source/symbols.go +++ b/internal/lsp/source/symbols.go @@ -6,6 +6,7 @@ package source import ( "context" + "errors" "fmt" "go/ast" "go/token" @@ -24,6 +25,10 @@ const ( FunctionSymbol MethodSymbol InterfaceSymbol + NumberSymbol + StringSymbol + BooleanSymbol + FieldSymbol ) type Symbol struct { @@ -42,19 +47,30 @@ func DocumentSymbols(ctx context.Context, f File) []Symbol { info := pkg.GetTypesInfo() q := qualifier(file, pkg.GetTypes(), info) + methodsToReceiver := make(map[types.Type][]Symbol) + symbolsToReceiver := make(map[types.Type]int) var symbols []Symbol for _, decl := range file.Decls { switch decl := decl.(type) { case *ast.FuncDecl: if obj := info.ObjectOf(decl.Name); obj != nil { - symbols = append(symbols, funcSymbol(decl, obj, fset, q)) + if fs := funcSymbol(decl, obj, fset, q); fs.Kind == MethodSymbol { + // Store methods separately, as we want them to appear as children + // of the corresponding type (which we may not have seen yet). + rtype := obj.Type().(*types.Signature).Recv().Type() + methodsToReceiver[rtype] = append(methodsToReceiver[rtype], fs) + } else { + symbols = append(symbols, fs) + } } case *ast.GenDecl: for _, spec := range decl.Specs { switch spec := spec.(type) { case *ast.TypeSpec: if obj := info.ObjectOf(spec.Name); obj != nil { - symbols = append(symbols, typeSymbol(spec, obj, fset, q)) + ts := typeSymbol(spec, obj, fset, q) + symbols = append(symbols, ts) + symbolsToReceiver[obj.Type()] = len(symbols) - 1 } case *ast.ValueSpec: for _, name := range spec.Names { @@ -66,6 +82,21 @@ func DocumentSymbols(ctx context.Context, f File) []Symbol { } } } + + // Attempt to associate methods to the corresponding type symbol. + for typ, methods := range methodsToReceiver { + if ptr, ok := typ.(*types.Pointer); ok { + typ = ptr.Elem() + } + + if i, ok := symbolsToReceiver[typ]; ok { + symbols[i].Children = append(symbols[i].Children, methods...) + } else { + // The type definition for the receiver of these methods was not in the document. + symbols = append(symbols, methods...) + } + } + return symbols } @@ -102,24 +133,88 @@ func funcSymbol(decl *ast.FuncDecl, obj types.Object, fset *token.FileSet, q typ return s } -func typeSymbol(spec *ast.TypeSpec, obj types.Object, fset *token.FileSet, q types.Qualifier) Symbol { - s := Symbol{ - Name: obj.Name(), - Kind: StructSymbol, - } - if types.IsInterface(obj.Type()) { +func setKind(s *Symbol, typ types.Type, q types.Qualifier) { + switch typ := typ.Underlying().(type) { + case *types.Interface: s.Kind = InterfaceSymbol + case *types.Struct: + s.Kind = StructSymbol + case *types.Signature: + s.Kind = FunctionSymbol + if typ.Recv() != nil { + s.Kind = MethodSymbol + } + case *types.Named: + setKind(s, typ.Underlying(), q) + case *types.Basic: + i := typ.Info() + switch { + case i&types.IsNumeric != 0: + s.Kind = NumberSymbol + case i&types.IsBoolean != 0: + s.Kind = BooleanSymbol + case i&types.IsString != 0: + s.Kind = StringSymbol + } + default: + s.Kind = VariableSymbol } +} + +func typeSymbol(spec *ast.TypeSpec, obj types.Object, fset *token.FileSet, q types.Qualifier) Symbol { + s := Symbol{Name: obj.Name()} + s.Detail, _ = formatType(obj.Type(), q) + setKind(&s, obj.Type(), q) + if span, err := nodeSpan(spec, fset); err == nil { s.Span = span } if span, err := nodeSpan(spec.Name, fset); err == nil { s.SelectionSpan = span } - s.Detail, _ = formatType(obj.Type(), q) + + if t, ok := obj.Type().Underlying().(*types.Struct); ok { + st := spec.Type.(*ast.StructType) + for i := 0; i < t.NumFields(); i++ { + f := t.Field(i) + child := Symbol{Name: f.Name(), Kind: FieldSymbol} + child.Detail, _ = formatType(f.Type(), q) + + spanNode, selectionNode := nodesForStructField(i, st) + if span, err := nodeSpan(spanNode, fset); err == nil { + child.Span = span + } + if span, err := nodeSpan(selectionNode, fset); err == nil { + child.SelectionSpan = span + } + + s.Children = append(s.Children, child) + } + } + return s } +func nodesForStructField(i int, st *ast.StructType) (span, selection ast.Node) { + j := 0 + for _, field := range st.Fields.List { + if len(field.Names) == 0 { + if i == j { + return field, field.Type + } + j++ + continue + } + for _, name := range field.Names { + if i == j { + return field, name + } + j++ + } + } + return nil, nil +} + func varSymbol(decl ast.Node, name *ast.Ident, obj types.Object, fset *token.FileSet, q types.Qualifier) Symbol { s := Symbol{ Name: obj.Name(), @@ -139,6 +234,9 @@ func varSymbol(decl ast.Node, name *ast.Ident, obj types.Object, fset *token.Fil } func nodeSpan(n ast.Node, fset *token.FileSet) (span.Span, error) { + if n == nil { + return span.Span{}, errors.New("no span for nil node") + } r := span.NewRange(fset, n.Pos(), n.End()) return r.Span() } diff --git a/internal/lsp/symbols.go b/internal/lsp/symbols.go index ae15e08c65..6ac09e0cf8 100644 --- a/internal/lsp/symbols.go +++ b/internal/lsp/symbols.go @@ -45,6 +45,14 @@ func toProtocolSymbolKind(kind source.SymbolKind) protocol.SymbolKind { return protocol.Method case source.InterfaceSymbol: return protocol.Interface + case source.NumberSymbol: + return protocol.Number + case source.StringSymbol: + return protocol.String + case source.BooleanSymbol: + return protocol.Boolean + case source.FieldSymbol: + return protocol.Field default: return 0 } diff --git a/internal/lsp/testdata/symbols/main.go b/internal/lsp/testdata/symbols/main.go index df11cb3602..93ace00cdf 100644 --- a/internal/lsp/testdata/symbols/main.go +++ b/internal/lsp/testdata/symbols/main.go @@ -1,27 +1,43 @@ package main -var x = 42 //@symbol("x", "x", 13) +import "io" -const y = 43 //@symbol("y", "y", 14) +var x = 42 //@symbol("x", "x", 13, "") -type Foo struct { //@symbol("Foo", "Foo", 23) - Quux - Bar int - baz string +const y = 43 //@symbol("y", "y", 14, "") + +type Number int //@symbol("Number", "Number", 16, "") + +type Alias = string //@symbol("Alias", "Alias", 15, "") + +type NumberAlias = Number //@symbol("NumberAlias", "NumberAlias", 16, "") + +type ( + Boolean bool //@symbol("Boolean", "Boolean", 17, "") + BoolAlias = bool //@symbol("BoolAlias", "BoolAlias", 17, "") +) + +type Foo struct { //@symbol("Foo", "Foo", 23, "") + Quux //@symbol("Quux", "Quux", 8, "Foo") + W io.Writer //@symbol("W" , "W", 8, "Foo") + Bar int //@symbol("Bar", "Bar", 8, "Foo") + baz string //@symbol("baz", "baz", 8, "Foo") } -type Quux struct { //@symbol("Quux", "Quux", 23) - X float64 +type Quux struct { //@symbol("Quux", "Quux", 23, "") + X, Y float64 //@symbol("X", "X", 8, "Quux"), symbol("Y", "Y", 8, "Quux") } -func (f Foo) Baz() string { //@symbol("Baz", "Baz", 6) +func (f Foo) Baz() string { //@symbol("Baz", "Baz", 6, "Foo") return f.baz } -func main() { //@symbol("main", "main", 12) +func (q *Quux) Do() {} //@symbol("Do", "Do", 6, "Quux") + +func main() { //@symbol("main", "main", 12, "") } -type Stringer interface { //@symbol("Stringer", "Stringer", 11) +type Stringer interface { //@symbol("Stringer", "Stringer", 11, "") String() string }