internal/lsp: improve completion for *ast.ArrayTypes

*ast.ArrayTypes are type expressions like "[]foo" or "[2]int". They
show up as standalone types (e.g. "var foo []int") and as part of
composite literals (e.g. "[]int{}"). I made the following
improvements:

- Always expect a type name for array types.
- Add a "type modifier" for array types so completions can be smart
  when we know the expected type. For example:

var foo []int
foo = []i<>

  we know we want a type name, but we also know the expected type is
  "[]int". When evaluating type names such as "int" we turn the type
  into a slice type "[]int" to match against the expected type.
- Tweak the AST fixing to add a phantom selector "_" after a naked
  "[]" so you can complete directly after the right bracket.

I split out the type name related type inference bits into a separate
typeNameInference struct. It had become confusing and complicated,
especially now that you can have an expected type and expect a type
name at the same time.

Change-Id: I00878532187ee5366ab8d681346532e36fa58e5f
Reviewed-on: https://go-review.googlesource.com/c/tools/+/197438
Run-TryBot: Rebecca Stambler <rstambler@golang.org>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
This commit is contained in:
Muir Manders 2019-09-13 12:15:53 -07:00 committed by Rebecca Stambler
parent ff611c50cd
commit 98e333b8b3
6 changed files with 153 additions and 66 deletions

View File

@ -255,9 +255,16 @@ func fixArrayType(bad *ast.BadExpr, parent ast.Node, tok *token.File, src []byte
return errors.Errorf("invalid BadExpr from/to: %d/%d", from, to) return errors.Errorf("invalid BadExpr from/to: %d/%d", from, to)
} }
exprBytes := make([]byte, 0, int(to-from)+2) exprBytes := make([]byte, 0, int(to-from)+3)
// Avoid doing tok.Offset(to) since that panics if badExpr ends at EOF. // Avoid doing tok.Offset(to) since that panics if badExpr ends at EOF.
exprBytes = append(exprBytes, src[tok.Offset(from):tok.Offset(to-1)+1]...) exprBytes = append(exprBytes, src[tok.Offset(from):tok.Offset(to-1)+1]...)
exprBytes = bytes.TrimSpace(exprBytes)
// If our expression ends in "]" (e.g. "[]"), add a phantom selector
// so we can complete directly after the "[]".
if len(exprBytes) > 0 && exprBytes[len(exprBytes)-1] == ']' {
exprBytes = append(exprBytes, '_')
}
// Add "{}" to turn our ArrayType into a CompositeLit. This is to // Add "{}" to turn our ArrayType into a CompositeLit. This is to
// handle the case of "[...]int" where we must make it a composite // handle the case of "[...]int" where we must make it a composite

View File

@ -7,6 +7,7 @@ package source
import ( import (
"context" "context"
"go/ast" "go/ast"
"go/constant"
"go/token" "go/token"
"go/types" "go/types"
"strings" "strings"
@ -561,7 +562,7 @@ func (c *completer) wantStructFieldCompletions() bool {
} }
func (c *completer) wantTypeName() bool { func (c *completer) wantTypeName() bool {
return c.expectedType.wantTypeName return c.expectedType.typeName.wantTypeName
} }
// selector finds completions for the specified selector expression. // selector finds completions for the specified selector expression.
@ -788,8 +789,10 @@ func enclosingCompositeLiteral(path []ast.Node, pos token.Pos, info *types.Info)
// //
// The position is not part of the composite literal unless it falls within the // The position is not part of the composite literal unless it falls within the
// curly braces (e.g. "foo.Foo<>Struct{}"). // curly braces (e.g. "foo.Foo<>Struct{}").
if !(n.Lbrace <= pos && pos <= n.Rbrace) { if !(n.Lbrace < pos && pos <= n.Rbrace) {
return nil // Keep searching since we may yet be inside a composite literal.
// For example "Foo{B: Ba<>{}}".
break
} }
tv, ok := info.Types[n] tv, ok := info.Types[n]
@ -937,12 +940,19 @@ func (c *completer) expectedCompositeLiteralType() types.Type {
} }
// typeModifier represents an operator that changes the expected type. // typeModifier represents an operator that changes the expected type.
type typeModifier int type typeModifier struct {
mod typeMod
arrayLen int64
}
type typeMod int
const ( const (
star typeModifier = iota // dereference operator for expressions, pointer indicator for types star typeMod = iota // dereference operator for expressions, pointer indicator for types
reference // reference ("&") operator reference // reference ("&") operator
chanRead // channel read ("<-") operator chanRead // channel read ("<-") operator
slice // make a slice type ("[]" in "[]int")
array // make an array type ("[2]" in "[2]int")
) )
// typeInference holds information we have inferred about a type that can be // typeInference holds information we have inferred about a type that can be
@ -955,6 +965,21 @@ type typeInference struct {
// variadic param. // variadic param.
variadic bool variadic bool
// modifiers are prefixes such as "*", "&" or "<-" that influence how
// a candidate type relates to the expected type.
modifiers []typeModifier
// convertibleTo is a type our candidate type must be convertible to.
convertibleTo types.Type
// typeName holds information about the expected type name at
// position, if any.
typeName typeNameInference
}
// typeNameInference holds information about the expected type name at
// position.
type typeNameInference struct {
// wantTypeName is true if we expect the name of a type. // wantTypeName is true if we expect the name of a type.
wantTypeName bool wantTypeName bool
@ -964,29 +989,20 @@ type typeInference struct {
// assertableFrom is a type that must be assertable to our candidate type. // assertableFrom is a type that must be assertable to our candidate type.
assertableFrom types.Type assertableFrom types.Type
// convertibleTo is a type our candidate type must be convertible to.
convertibleTo types.Type
} }
// expectedType returns information about the expected type for an expression at // expectedType returns information about the expected type for an expression at
// the query position. // the query position.
func expectedType(c *completer) typeInference { func expectedType(c *completer) typeInference {
if ti := expectTypeName(c); ti.wantTypeName { inf := typeInference{
return ti typeName: expectTypeName(c),
} }
if c.enclosingCompositeLiteral != nil { if c.enclosingCompositeLiteral != nil {
return typeInference{objType: c.expectedCompositeLiteralType()} inf.objType = c.expectedCompositeLiteralType()
return inf
} }
var (
modifiers []typeModifier
variadic bool
typ types.Type
convertibleTo types.Type
)
Nodes: Nodes:
for i, node := range c.path { for i, node := range c.path {
switch node := node.(type) { switch node := node.(type) {
@ -997,7 +1013,7 @@ Nodes:
e = node.Y e = node.Y
} }
if tv, ok := c.pkg.GetTypesInfo().Types[e]; ok { if tv, ok := c.pkg.GetTypesInfo().Types[e]; ok {
typ = tv.Type inf.objType = tv.Type
break Nodes break Nodes
} }
case *ast.AssignStmt: case *ast.AssignStmt:
@ -1008,120 +1024,115 @@ Nodes:
i = len(node.Lhs) - 1 i = len(node.Lhs) - 1
} }
if tv, ok := c.pkg.GetTypesInfo().Types[node.Lhs[i]]; ok { if tv, ok := c.pkg.GetTypesInfo().Types[node.Lhs[i]]; ok {
typ = tv.Type inf.objType = tv.Type
break Nodes break Nodes
} }
} }
return typeInference{} return inf
case *ast.CallExpr: case *ast.CallExpr:
// Only consider CallExpr args if position falls between parens. // Only consider CallExpr args if position falls between parens.
if node.Lparen <= c.pos && c.pos <= node.Rparen { if node.Lparen <= c.pos && c.pos <= node.Rparen {
// For type conversions like "int64(foo)" we can only infer our // For type conversions like "int64(foo)" we can only infer our
// desired type is convertible to int64. // desired type is convertible to int64.
if typ := typeConversion(node, c.pkg.GetTypesInfo()); typ != nil { if typ := typeConversion(node, c.pkg.GetTypesInfo()); typ != nil {
convertibleTo = typ inf.convertibleTo = typ
break Nodes break Nodes
} }
if tv, ok := c.pkg.GetTypesInfo().Types[node.Fun]; ok { if tv, ok := c.pkg.GetTypesInfo().Types[node.Fun]; ok {
if sig, ok := tv.Type.(*types.Signature); ok { if sig, ok := tv.Type.(*types.Signature); ok {
if sig.Params().Len() == 0 { if sig.Params().Len() == 0 {
return typeInference{} return inf
} }
i := indexExprAtPos(c.pos, node.Args) i := indexExprAtPos(c.pos, node.Args)
// Make sure not to run past the end of expected parameters. // Make sure not to run past the end of expected parameters.
if i >= sig.Params().Len() { if i >= sig.Params().Len() {
i = sig.Params().Len() - 1 i = sig.Params().Len() - 1
} }
typ = sig.Params().At(i).Type() inf.objType = sig.Params().At(i).Type()
variadic = sig.Variadic() && i == sig.Params().Len()-1 inf.variadic = sig.Variadic() && i == sig.Params().Len()-1
break Nodes break Nodes
} }
} }
} }
return typeInference{} return inf
case *ast.ReturnStmt: case *ast.ReturnStmt:
if c.enclosingFunc != nil { if c.enclosingFunc != nil {
sig := c.enclosingFunc.sig sig := c.enclosingFunc.sig
// Find signature result that corresponds to our return statement. // Find signature result that corresponds to our return statement.
if resultIdx := indexExprAtPos(c.pos, node.Results); resultIdx < len(node.Results) { if resultIdx := indexExprAtPos(c.pos, node.Results); resultIdx < len(node.Results) {
if resultIdx < sig.Results().Len() { if resultIdx < sig.Results().Len() {
typ = sig.Results().At(resultIdx).Type() inf.objType = sig.Results().At(resultIdx).Type()
break Nodes break Nodes
} }
} }
} }
return typeInference{} return inf
case *ast.CaseClause: case *ast.CaseClause:
if swtch, ok := findSwitchStmt(c.path[i+1:], c.pos, node).(*ast.SwitchStmt); ok { if swtch, ok := findSwitchStmt(c.path[i+1:], c.pos, node).(*ast.SwitchStmt); ok {
if tv, ok := c.pkg.GetTypesInfo().Types[swtch.Tag]; ok { if tv, ok := c.pkg.GetTypesInfo().Types[swtch.Tag]; ok {
typ = tv.Type inf.objType = tv.Type
break Nodes break Nodes
} }
} }
return typeInference{} return inf
case *ast.SliceExpr: case *ast.SliceExpr:
// Make sure position falls within the brackets (e.g. "foo[a:<>]"). // Make sure position falls within the brackets (e.g. "foo[a:<>]").
if node.Lbrack < c.pos && c.pos <= node.Rbrack { if node.Lbrack < c.pos && c.pos <= node.Rbrack {
typ = types.Typ[types.Int] inf.objType = types.Typ[types.Int]
break Nodes break Nodes
} }
return typeInference{} return inf
case *ast.IndexExpr: case *ast.IndexExpr:
// Make sure position falls within the brackets (e.g. "foo[<>]"). // Make sure position falls within the brackets (e.g. "foo[<>]").
if node.Lbrack < c.pos && c.pos <= node.Rbrack { if node.Lbrack < c.pos && c.pos <= node.Rbrack {
if tv, ok := c.pkg.GetTypesInfo().Types[node.X]; ok { if tv, ok := c.pkg.GetTypesInfo().Types[node.X]; ok {
switch t := tv.Type.Underlying().(type) { switch t := tv.Type.Underlying().(type) {
case *types.Map: case *types.Map:
typ = t.Key() inf.objType = t.Key()
case *types.Slice, *types.Array: case *types.Slice, *types.Array:
typ = types.Typ[types.Int] inf.objType = types.Typ[types.Int]
default: default:
return typeInference{} return inf
} }
break Nodes break Nodes
} }
} }
return typeInference{} return inf
case *ast.SendStmt: case *ast.SendStmt:
// Make sure we are on right side of arrow (e.g. "foo <- <>"). // Make sure we are on right side of arrow (e.g. "foo <- <>").
if c.pos > node.Arrow+1 { if c.pos > node.Arrow+1 {
if tv, ok := c.pkg.GetTypesInfo().Types[node.Chan]; ok { if tv, ok := c.pkg.GetTypesInfo().Types[node.Chan]; ok {
if ch, ok := tv.Type.Underlying().(*types.Chan); ok { if ch, ok := tv.Type.Underlying().(*types.Chan); ok {
typ = ch.Elem() inf.objType = ch.Elem()
break Nodes break Nodes
} }
} }
} }
return typeInference{} return inf
case *ast.StarExpr: case *ast.StarExpr:
modifiers = append(modifiers, star) inf.modifiers = append(inf.modifiers, typeModifier{mod: star})
case *ast.UnaryExpr: case *ast.UnaryExpr:
switch node.Op { switch node.Op {
case token.AND: case token.AND:
modifiers = append(modifiers, reference) inf.modifiers = append(inf.modifiers, typeModifier{mod: reference})
case token.ARROW: case token.ARROW:
modifiers = append(modifiers, chanRead) inf.modifiers = append(inf.modifiers, typeModifier{mod: chanRead})
} }
default: default:
if breaksExpectedTypeInference(node) { if breaksExpectedTypeInference(node) {
return typeInference{} return inf
} }
} }
} }
return typeInference{ return inf
variadic: variadic,
objType: typ,
modifiers: modifiers,
convertibleTo: convertibleTo,
}
} }
// applyTypeModifiers applies the list of type modifiers to a type. // applyTypeModifiers applies the list of type modifiers to a type.
func (ti typeInference) applyTypeModifiers(typ types.Type) types.Type { func (ti typeInference) applyTypeModifiers(typ types.Type) types.Type {
for _, mod := range ti.modifiers { for _, mod := range ti.modifiers {
switch mod { switch mod.mod {
case star: case star:
// For every "*" deref operator, remove a pointer layer from candidate type. // For every "*" deref operator, remove a pointer layer from candidate type.
typ = deref(typ) typ = deref(typ)
@ -1140,11 +1151,15 @@ func (ti typeInference) applyTypeModifiers(typ types.Type) types.Type {
// applyTypeNameModifiers applies the list of type modifiers to a type name. // applyTypeNameModifiers applies the list of type modifiers to a type name.
func (ti typeInference) applyTypeNameModifiers(typ types.Type) types.Type { func (ti typeInference) applyTypeNameModifiers(typ types.Type) types.Type {
for _, mod := range ti.modifiers { for _, mod := range ti.typeName.modifiers {
switch mod { switch mod.mod {
case star: case star:
// For every "*" indicator, add a pointer layer to type name. // For every "*" indicator, add a pointer layer to type name.
typ = types.NewPointer(typ) typ = types.NewPointer(typ)
case array:
typ = types.NewArray(typ, mod.arrayLen)
case slice:
typ = types.NewSlice(typ)
} }
} }
return typ return typ
@ -1187,7 +1202,7 @@ func breaksExpectedTypeInference(n ast.Node) bool {
} }
// expectTypeName returns information about the expected type name at position. // expectTypeName returns information about the expected type name at position.
func expectTypeName(c *completer) typeInference { func expectTypeName(c *completer) typeNameInference {
var ( var (
wantTypeName bool wantTypeName bool
modifiers []typeModifier modifiers []typeModifier
@ -1219,7 +1234,7 @@ Nodes:
wantTypeName = true wantTypeName = true
break Nodes break Nodes
} }
return typeInference{} return typeNameInference{}
case *ast.TypeAssertExpr: case *ast.TypeAssertExpr:
// Expect type names in type assert expressions. // Expect type names in type assert expressions.
if n.Lparen < c.pos && c.pos <= n.Rparen { if n.Lparen < c.pos && c.pos <= n.Rparen {
@ -1228,17 +1243,51 @@ Nodes:
wantTypeName = true wantTypeName = true
break Nodes break Nodes
} }
return typeInference{} return typeNameInference{}
case *ast.StarExpr: case *ast.StarExpr:
modifiers = append(modifiers, star) modifiers = append(modifiers, typeModifier{mod: star})
case *ast.CompositeLit:
// We want a type name if position is in the "Type" part of a
// composite literal (e.g. "Foo<>{}").
if n.Type != nil && n.Type.Pos() <= c.pos && c.pos <= n.Type.End() {
wantTypeName = true
}
break Nodes
case *ast.ArrayType:
// If we are inside the "Elt" part of an array type, we want a type name.
if n.Elt.Pos() <= c.pos && c.pos <= n.Elt.End() {
wantTypeName = true
if n.Len == nil {
// No "Len" expression means a slice type.
modifiers = append(modifiers, typeModifier{mod: slice})
} else {
// Try to get the array type using the constant value of "Len".
tv, ok := c.pkg.GetTypesInfo().Types[n.Len]
if ok && tv.Value != nil && tv.Value.Kind() == constant.Int {
if arrayLen, ok := constant.Int64Val(tv.Value); ok {
modifiers = append(modifiers, typeModifier{mod: array, arrayLen: arrayLen})
}
}
}
// ArrayTypes can be nested, so keep going if our parent is an
// ArrayType.
if i < len(c.path)-1 {
if _, ok := c.path[i+1].(*ast.ArrayType); ok {
continue Nodes
}
}
break Nodes
}
default: default:
if breaksExpectedTypeInference(p) { if breaksExpectedTypeInference(p) {
return typeInference{} return typeNameInference{}
} }
} }
} }
return typeInference{ return typeNameInference{
wantTypeName: wantTypeName, wantTypeName: wantTypeName,
modifiers: modifiers, modifiers: modifiers,
assertableFrom: assertableFrom, assertableFrom: assertableFrom,
@ -1256,6 +1305,9 @@ func (c *completer) matchingType(T types.Type) bool {
func (c *completer) matchingCandidate(cand *candidate) bool { func (c *completer) matchingCandidate(cand *candidate) bool {
if isTypeName(cand.obj) { if isTypeName(cand.obj) {
return c.matchingTypeName(cand) return c.matchingTypeName(cand)
} else if c.wantTypeName() {
// If we want a type, a non-type object never matches.
return false
} }
objType := cand.obj.Type() objType := cand.obj.Type()
@ -1323,20 +1375,30 @@ func (c *completer) matchingTypeName(cand *candidate) bool {
// Take into account any type name modifier prefixes. // Take into account any type name modifier prefixes.
actual := c.expectedType.applyTypeNameModifiers(cand.obj.Type()) actual := c.expectedType.applyTypeNameModifiers(cand.obj.Type())
if c.expectedType.assertableFrom != nil { if c.expectedType.typeName.assertableFrom != nil {
// Don't suggest the starting type in type assertions. For example, // Don't suggest the starting type in type assertions. For example,
// if "foo" is an io.Writer, don't suggest "foo.(io.Writer)". // if "foo" is an io.Writer, don't suggest "foo.(io.Writer)".
if types.Identical(c.expectedType.assertableFrom, actual) { if types.Identical(c.expectedType.typeName.assertableFrom, actual) {
return false return false
} }
if intf, ok := c.expectedType.assertableFrom.Underlying().(*types.Interface); ok { if intf, ok := c.expectedType.typeName.assertableFrom.Underlying().(*types.Interface); ok {
if !types.AssertableTo(intf, actual) { if !types.AssertableTo(intf, actual) {
return false return false
} }
} }
} }
// We can expect a type name and have an expected type in cases like:
//
// var foo []int
// foo = []i<>
//
// Where our expected type is "[]int", and we expect a type name.
if c.expectedType.objType != nil {
return types.AssignableTo(actual, c.expectedType.objType)
}
// Default to saying any type name is a match. // Default to saying any type name is a match.
return true return true
} }

View File

@ -15,6 +15,8 @@ func _() {
[]foo.StructFoo //@complete(" //", StructFoo) []foo.StructFoo //@complete(" //", StructFoo)
[]foo.StructFoo(nil) //@complete("(", StructFoo)
[]*foo.StructFoo //@complete(" //", StructFoo) []*foo.StructFoo //@complete(" //", StructFoo)
[...]foo.StructFoo //@complete(" //", StructFoo) [...]foo.StructFoo //@complete(" //", StructFoo)
@ -23,3 +25,19 @@ func _() {
[]struct { f []foo.StructFoo } //@complete(" }", StructFoo) []struct { f []foo.StructFoo } //@complete(" }", StructFoo)
} }
func _() {
type myInt int //@item(atMyInt, "myInt", "int", "type")
var mark []myInt //@item(atMark, "mark", "[]myInt", "var")
var s []myInt //@item(atS, "s", "[]myInt", "var")
s = []m //@complete(" //", atMyInt, atMark)
s = [] //@complete(" //", atMyInt, atMark, atS, PackageFoo)
var a [1]myInt
a = [1]m //@complete(" //", atMyInt, atMark)
var ds [][]myInt
ds = [][]m //@complete(" //", atMyInt, atMark)
}

View File

@ -10,7 +10,7 @@ func helper(i foo.IntFoo) {} //@item(helper, "helper", "func(i foo.IntFoo)", "fu
func _() { func _() {
help //@complete("l", helper) help //@complete("l", helper)
_ = foo.StructFoo{} //@complete("S", Foo, IntFoo, StructFoo) _ = foo.StructFoo{} //@complete("S", IntFoo, StructFoo, Foo)
} }
// Bar is a function. // Bar is a function.

View File

@ -9,6 +9,6 @@ type ncBar struct { //@item(structNCBar, "ncBar", "struct{...}", "struct")
func _() { func _() {
[]ncFoo{} //@item(litNCFoo, "[]ncFoo{}", "", "var") []ncFoo{} //@item(litNCFoo, "[]ncFoo{}", "", "var")
_ := ncBar{ _ := ncBar{
baz: [] //@complete(" //", litNCFoo, structNCBar, structNCFoo) baz: [] //@complete(" //", structNCFoo, structNCBar)
} }
} }

View File

@ -1,5 +1,5 @@
-- summary -- -- summary --
CompletionsCount = 211 CompletionsCount = 216
CompletionSnippetCount = 39 CompletionSnippetCount = 39
UnimportedCompletionsCount = 1 UnimportedCompletionsCount = 1
DeepCompletionsCount = 5 DeepCompletionsCount = 5