diff --git a/internal/lsp/source/completion.go b/internal/lsp/source/completion.go index bccf80d6b5..6b2a9d1eb1 100644 --- a/internal/lsp/source/completion.go +++ b/internal/lsp/source/completion.go @@ -664,6 +664,10 @@ type typeInference struct { // wantTypeName is true if we expect the name of a type. wantTypeName bool + + // modifiers are prefixes such as "*", "&" or "<-" that influence how + // a candidate type relates to the expected type. + modifiers []typeModifier } // expectedType returns information about the expected type for an expression at @@ -796,25 +800,30 @@ Nodes: } } - if typ != nil { - for _, mod := range modifiers { - switch mod { - case dereference: - // For every "*" deref operator, add another pointer layer to expected type. - typ = types.NewPointer(typ) - case reference: - // For every "&" ref operator, remove a pointer layer from expected type. - typ = deref(typ) - case chanRead: - // For every "<-" operator, add another layer of channelness. - typ = types.NewChan(types.SendRecv, typ) + return typeInference{ + objType: typ, + modifiers: modifiers, + } +} + +// applyTypeModifiers applies the list of type modifiers to a type. +func (ti typeInference) applyTypeModifiers(typ types.Type) types.Type { + for _, mod := range ti.modifiers { + switch mod { + case dereference: + // For every "*" deref operator, remove a pointer layer from candidate type. + typ = deref(typ) + case reference: + // For every "&" ref operator, add another pointer layer to candidate type. + typ = types.NewPointer(typ) + case chanRead: + // For every "<-" operator, remove a layer of channelness. + if ch, ok := typ.(*types.Chan); ok { + typ = ch.Elem() } } } - - return typeInference{ - objType: typ, - } + return typ } // findSwitchStmt returns an *ast.CaseClause's corresponding *ast.SwitchStmt or @@ -916,6 +925,9 @@ func (c *completer) matchingType(obj types.Object) bool { } } + // Take into account any type modifiers on the expected type. + actual = c.expectedType.applyTypeModifiers(actual) + if c.expectedType.objType != nil { // AssignableTo covers the case where the types are equal, but also handles // cases like assigning a concrete type to an interface type. diff --git a/internal/lsp/testdata/channel/channel.go b/internal/lsp/testdata/channel/channel.go index a83b895325..dc559513bf 100644 --- a/internal/lsp/testdata/channel/channel.go +++ b/internal/lsp/testdata/channel/channel.go @@ -20,6 +20,6 @@ func _() { { var foo chan int //@item(channelFoo, "foo", "chan int", "var") wantsInt := func(int) {} //@item(channelWantsInt, "wantsInt", "func(int)", "var") - wantsInt(<-) //@complete(")", channelFoo, channelWantsInt, channelAA, channelAB) + wantsInt(<-) //@complete(")", channelFoo, channelAB, channelWantsInt, channelAA) } } diff --git a/internal/lsp/testdata/interfacerank/interface_rank.go b/internal/lsp/testdata/interfacerank/interface_rank.go index 968c1a6a0d..acb5a42e0a 100644 --- a/internal/lsp/testdata/interfacerank/interface_rank.go +++ b/internal/lsp/testdata/interfacerank/interface_rank.go @@ -17,4 +17,7 @@ func _() { ) wantsFoo(a) //@complete(")", irAB, irAA) + + var ac fooImpl //@item(irAC, "ac", "fooImpl", "var") + wantsFoo(&a) //@complete(")", irAC, irAA, irAB) } diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go index efdd7382b7..cfbb0615c0 100644 --- a/internal/lsp/tests/tests.go +++ b/internal/lsp/tests/tests.go @@ -25,7 +25,7 @@ import ( // We hardcode the expected number of test cases to ensure that all tests // are being executed. If a test is added, this number must be changed. const ( - ExpectedCompletionsCount = 124 + ExpectedCompletionsCount = 125 ExpectedCompletionSnippetCount = 14 ExpectedDiagnosticsCount = 17 ExpectedFormatCount = 5