internal/lsp: derive ASTs from type information

In the case of documentation items for completion items, we should make
sure to use the ASTs and type information for the originating package.
To do this while avoiding race conditions, we have to do this by
breadth-first searching the top-level package and its dependencies.

Change-Id: Id657be969ca3e400bb2bbd769a82d88e91865764
Reviewed-on: https://go-review.googlesource.com/c/tools/+/194477
Run-TryBot: Rebecca Stambler <rstambler@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Cottrell <iancottrell@google.com>
This commit is contained in:
Rebecca Stambler 2019-09-09 20:22:42 -04:00
parent dd2b5c81c5
commit 238129aa63
9 changed files with 158 additions and 109 deletions

View File

@ -7,6 +7,7 @@ package cache
import ( import (
"context" "context"
"go/ast" "go/ast"
"go/token"
"go/types" "go/types"
"sort" "sort"
"sync" "sync"
@ -14,6 +15,8 @@ import (
"golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/packages" "golang.org/x/tools/go/packages"
"golang.org/x/tools/internal/lsp/source" "golang.org/x/tools/internal/lsp/source"
"golang.org/x/tools/internal/span"
errors "golang.org/x/xerrors"
) )
// pkg contains the type information needed by the source package. // pkg contains the type information needed by the source package.
@ -199,3 +202,32 @@ func (pkg *pkg) GetDiagnostics() []source.Diagnostic {
} }
return diags return diags
} }
func (p *pkg) FindFile(ctx context.Context, uri span.URI, pos token.Pos) (source.ParseGoHandle, *ast.File, source.Package, error) {
queue := []*pkg{p}
seen := make(map[string]bool)
for len(queue) > 0 {
pkg := queue[0]
queue = queue[1:]
seen[pkg.ID()] = true
for _, ph := range pkg.files {
if ph.File().Identity().URI == uri {
file, err := ph.Cached(ctx)
if file == nil {
return nil, nil, nil, err
}
if file.Pos() <= pos && pos <= file.End() {
return ph, file, pkg, nil
}
}
}
for _, dep := range pkg.imports {
if !seen[dep.ID()] {
queue = append(queue, dep)
}
}
}
return nil, nil, nil, errors.Errorf("no file for %s", uri)
}

View File

@ -138,11 +138,10 @@ func (pm prefixMatcher) Score(candidateLabel string) float32 {
// completer contains the necessary information for a single completion request. // completer contains the necessary information for a single completion request.
type completer struct { type completer struct {
// Package-specific fields. pkg Package
types *types.Package
info *types.Info qf types.Qualifier
qf types.Qualifier opts CompletionOptions
opts CompletionOptions
// view is the View associated with this completion request. // view is the View associated with this completion request.
view View view View
@ -278,7 +277,7 @@ func (c *completer) getSurrounding() *Selection {
// found adds a candidate completion. We will also search through the object's // found adds a candidate completion. We will also search through the object's
// members for more candidates. // members for more candidates.
func (c *completer) found(obj types.Object, score float64, imp *imports.ImportInfo) { func (c *completer) found(obj types.Object, score float64, imp *imports.ImportInfo) {
if obj.Pkg() != nil && obj.Pkg() != c.types && !obj.Exported() { if obj.Pkg() != nil && obj.Pkg() != c.pkg.GetTypes() && !obj.Exported() {
// obj is not accessible because it lives in another package and is not // obj is not accessible because it lives in another package and is not
// exported. Don't treat it as a completion candidate. // exported. Don't treat it as a completion candidate.
return return
@ -430,8 +429,7 @@ func Completion(ctx context.Context, view View, f GoFile, pos protocol.Position,
clInfo := enclosingCompositeLiteral(path, rng.Start, pkg.GetTypesInfo()) clInfo := enclosingCompositeLiteral(path, rng.Start, pkg.GetTypesInfo())
c := &completer{ c := &completer{
types: pkg.GetTypes(), pkg: pkg,
info: pkg.GetTypesInfo(),
qf: qualifier(file, pkg.GetTypes(), pkg.GetTypesInfo()), qf: qualifier(file, pkg.GetTypes(), pkg.GetTypesInfo()),
view: view, view: view,
ctx: ctx, ctx: ctx,
@ -545,14 +543,14 @@ func (c *completer) wantTypeName() bool {
func (c *completer) selector(sel *ast.SelectorExpr) error { func (c *completer) selector(sel *ast.SelectorExpr) error {
// Is sel a qualified identifier? // Is sel a qualified identifier?
if id, ok := sel.X.(*ast.Ident); ok { if id, ok := sel.X.(*ast.Ident); ok {
if pkgname, ok := c.info.Uses[id].(*types.PkgName); ok { if pkgname, ok := c.pkg.GetTypesInfo().Uses[id].(*types.PkgName); ok {
c.packageMembers(pkgname) c.packageMembers(pkgname)
return nil return nil
} }
} }
// Invariant: sel is a true selector. // Invariant: sel is a true selector.
tv, ok := c.info.Types[sel.X] tv, ok := c.pkg.GetTypesInfo().Types[sel.X]
if !ok { if !ok {
return errors.Errorf("cannot resolve %s", sel.X) return errors.Errorf("cannot resolve %s", sel.X)
} }
@ -601,9 +599,9 @@ func (c *completer) lexical() error {
case *ast.FuncLit: case *ast.FuncLit:
n = node.Type n = node.Type
} }
scopes = append(scopes, c.info.Scopes[n]) scopes = append(scopes, c.pkg.GetTypesInfo().Scopes[n])
} }
scopes = append(scopes, c.types.Scope(), types.Universe) scopes = append(scopes, c.pkg.GetTypes().Scope(), types.Universe)
// Track seen variables to avoid showing completions for shadowed variables. // Track seen variables to avoid showing completions for shadowed variables.
// This works since we look at scopes from innermost to outermost. // This works since we look at scopes from innermost to outermost.
@ -631,7 +629,7 @@ func (c *completer) lexical() error {
node = c.path[i-1] node = c.path[i-1]
} }
if node != nil { if node != nil {
if resolved := resolveInvalid(obj, node, c.info); resolved != nil { if resolved := resolveInvalid(obj, node, c.pkg.GetTypesInfo()); resolved != nil {
obj = resolved obj = resolved
} }
} }
@ -681,7 +679,7 @@ func (c *completer) structLiteralFieldName() error {
} }
if key, ok := kvExpr.Key.(*ast.Ident); ok { if key, ok := kvExpr.Key.(*ast.Ident); ok {
if used, ok := c.info.Uses[key]; ok { if used, ok := c.pkg.GetTypesInfo().Uses[key]; ok {
if usedVar, ok := used.(*types.Var); ok { if usedVar, ok := used.(*types.Var); ok {
addedFields[usedVar] = true addedFields[usedVar] = true
} }
@ -924,7 +922,7 @@ Nodes:
if c.pos < node.OpPos { if c.pos < node.OpPos {
e = node.Y e = node.Y
} }
if tv, ok := c.info.Types[e]; ok { if tv, ok := c.pkg.GetTypesInfo().Types[e]; ok {
typ = tv.Type typ = tv.Type
break Nodes break Nodes
} }
@ -935,7 +933,7 @@ Nodes:
if i >= len(node.Lhs) { if i >= len(node.Lhs) {
i = len(node.Lhs) - 1 i = len(node.Lhs) - 1
} }
if tv, ok := c.info.Types[node.Lhs[i]]; ok { if tv, ok := c.pkg.GetTypesInfo().Types[node.Lhs[i]]; ok {
typ = tv.Type typ = tv.Type
break Nodes break Nodes
} }
@ -946,12 +944,12 @@ Nodes:
if node.Lparen <= c.pos && c.pos <= node.Rparen { if node.Lparen <= c.pos && c.pos <= node.Rparen {
// For type conversions like "int64(foo)" we can only infer our // For type conversions like "int64(foo)" we can only infer our
// desired type is convertible to int64. // desired type is convertible to int64.
if typ := typeConversion(node, c.info); typ != nil { if typ := typeConversion(node, c.pkg.GetTypesInfo()); typ != nil {
convertibleTo = typ convertibleTo = typ
break Nodes break Nodes
} }
if tv, ok := c.info.Types[node.Fun]; ok { if tv, ok := c.pkg.GetTypesInfo().Types[node.Fun]; ok {
if sig, ok := tv.Type.(*types.Signature); ok { if sig, ok := tv.Type.(*types.Signature); ok {
if sig.Params().Len() == 0 { if sig.Params().Len() == 0 {
return typeInference{} return typeInference{}
@ -980,7 +978,7 @@ Nodes:
return typeInference{} return typeInference{}
case *ast.CaseClause: case *ast.CaseClause:
if swtch, ok := findSwitchStmt(c.path[i+1:], c.pos, node).(*ast.SwitchStmt); ok { if swtch, ok := findSwitchStmt(c.path[i+1:], c.pos, node).(*ast.SwitchStmt); ok {
if tv, ok := c.info.Types[swtch.Tag]; ok { if tv, ok := c.pkg.GetTypesInfo().Types[swtch.Tag]; ok {
typ = tv.Type typ = tv.Type
break Nodes break Nodes
} }
@ -996,7 +994,7 @@ Nodes:
case *ast.IndexExpr: case *ast.IndexExpr:
// Make sure position falls within the brackets (e.g. "foo[<>]"). // Make sure position falls within the brackets (e.g. "foo[<>]").
if node.Lbrack < c.pos && c.pos <= node.Rbrack { if node.Lbrack < c.pos && c.pos <= node.Rbrack {
if tv, ok := c.info.Types[node.X]; ok { if tv, ok := c.pkg.GetTypesInfo().Types[node.X]; ok {
switch t := tv.Type.Underlying().(type) { switch t := tv.Type.Underlying().(type) {
case *types.Map: case *types.Map:
typ = t.Key() typ = t.Key()
@ -1012,7 +1010,7 @@ Nodes:
case *ast.SendStmt: case *ast.SendStmt:
// Make sure we are on right side of arrow (e.g. "foo <- <>"). // Make sure we are on right side of arrow (e.g. "foo <- <>").
if c.pos > node.Arrow+1 { if c.pos > node.Arrow+1 {
if tv, ok := c.info.Types[node.Chan]; ok { if tv, ok := c.pkg.GetTypesInfo().Types[node.Chan]; ok {
if ch, ok := tv.Type.Underlying().(*types.Chan); ok { if ch, ok := tv.Type.Underlying().(*types.Chan); ok {
typ = ch.Elem() typ = ch.Elem()
break Nodes break Nodes
@ -1146,7 +1144,7 @@ Nodes:
// The case clause types must be assertable from the type switch parameter. // The case clause types must be assertable from the type switch parameter.
ast.Inspect(swtch.Assign, func(n ast.Node) bool { ast.Inspect(swtch.Assign, func(n ast.Node) bool {
if ta, ok := n.(*ast.TypeAssertExpr); ok { if ta, ok := n.(*ast.TypeAssertExpr); ok {
assertableFrom = c.info.TypeOf(ta.X) assertableFrom = c.pkg.GetTypesInfo().TypeOf(ta.X)
return false return false
} }
return true return true
@ -1159,7 +1157,7 @@ Nodes:
// Expect type names in type assert expressions. // Expect type names in type assert expressions.
if n.Lparen < c.pos && c.pos <= n.Rparen { if n.Lparen < c.pos && c.pos <= n.Rparen {
// The type in parens must be assertable from the expression type. // The type in parens must be assertable from the expression type.
assertableFrom = c.info.TypeOf(n.X) assertableFrom = c.pkg.GetTypesInfo().TypeOf(n.X)
wantTypeName = true wantTypeName = true
break Nodes break Nodes
} }

View File

@ -124,32 +124,11 @@ func (c *completer) item(cand candidate) (CompletionItem, error) {
return item, nil return item, nil
} }
uri := span.FileURI(pos.Filename) uri := span.FileURI(pos.Filename)
f, err := c.view.GetFile(c.ctx, uri) _, file, pkg, err := c.pkg.FindFile(c.ctx, uri, obj.Pos())
if err != nil { if file == nil || pkg == nil {
return item, nil return item, nil
} }
gof, ok := f.(GoFile) ident, err := findIdentifier(c.ctx, c.view, []Package{pkg}, file, obj.Pos())
if !ok {
return item, nil
}
pkg, err := gof.GetCachedPackage(c.ctx)
if err != nil {
return item, nil
}
var ph ParseGoHandle
for _, h := range pkg.GetHandles() {
if h.File().Identity().URI == gof.URI() {
ph = h
}
}
if ph == nil {
return item, nil
}
file, _ := ph.Cached(c.ctx)
if file == nil {
return item, nil
}
ident, err := findIdentifier(c.ctx, c.view, gof, pkg, file, declRange.spanRange.Start)
if err != nil { if err != nil {
return item, nil return item, nil
} }

View File

@ -22,7 +22,7 @@ import (
type IdentifierInfo struct { type IdentifierInfo struct {
Name string Name string
View View View View
File GoFile File ParseGoHandle
mappedRange mappedRange
Type struct { Type struct {
@ -32,7 +32,7 @@ type IdentifierInfo struct {
Declaration Declaration Declaration Declaration
pkg Package pkgs []Package
ident *ast.Ident ident *ast.Ident
wasEmbeddedField bool wasEmbeddedField bool
qf types.Qualifier qf types.Qualifier
@ -48,7 +48,7 @@ type Declaration struct {
// Identifier returns identifier information for a position // Identifier returns identifier information for a position
// in a file, accounting for a potentially incomplete selector. // in a file, accounting for a potentially incomplete selector.
func Identifier(ctx context.Context, view View, f GoFile, pos protocol.Position) (*IdentifierInfo, error) { func Identifier(ctx context.Context, view View, f GoFile, pos protocol.Position) (*IdentifierInfo, error) {
file, pkg, m, err := fileToMapper(ctx, view, f.URI()) file, pkgs, m, err := fileToMapper(ctx, view, f.URI())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -60,17 +60,17 @@ func Identifier(ctx context.Context, view View, f GoFile, pos protocol.Position)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return findIdentifier(ctx, view, f, pkg, file, rng.Start) return findIdentifier(ctx, view, pkgs, file, rng.Start)
} }
func findIdentifier(ctx context.Context, view View, f GoFile, pkg Package, file *ast.File, pos token.Pos) (*IdentifierInfo, error) { func findIdentifier(ctx context.Context, view View, pkgs []Package, file *ast.File, pos token.Pos) (*IdentifierInfo, error) {
if result, err := identifier(ctx, view, f, pkg, file, pos); err != nil || result != nil { if result, err := identifier(ctx, view, pkgs, file, pos); err != nil || result != nil {
return result, err return result, err
} }
// If the position is not an identifier but immediately follows // If the position is not an identifier but immediately follows
// an identifier or selector period (as is common when // an identifier or selector period (as is common when
// requesting a completion), use the path to the preceding node. // requesting a completion), use the path to the preceding node.
ident, err := identifier(ctx, view, f, pkg, file, pos-1) ident, err := identifier(ctx, view, pkgs, file, pos-1)
if ident == nil && err == nil { if ident == nil && err == nil {
err = errors.New("no identifier found") err = errors.New("no identifier found")
} }
@ -78,25 +78,36 @@ func findIdentifier(ctx context.Context, view View, f GoFile, pkg Package, file
} }
// identifier checks a single position for a potential identifier. // identifier checks a single position for a potential identifier.
func identifier(ctx context.Context, view View, f GoFile, pkg Package, file *ast.File, pos token.Pos) (*IdentifierInfo, error) { func identifier(ctx context.Context, view View, pkgs []Package, file *ast.File, pos token.Pos) (*IdentifierInfo, error) {
ctx, done := trace.StartSpan(ctx, "source.identifier") ctx, done := trace.StartSpan(ctx, "source.identifier")
defer done() defer done()
var err error var err error
// Handle import specs separately, as there is no formal position for a package declaration. // Handle import specs separately, as there is no formal position for a package declaration.
if result, err := importSpec(ctx, view, f, file, pkg, pos); result != nil || err != nil { if result, err := importSpec(ctx, view, file, pkgs, pos); result != nil || err != nil {
return result, err return result, err
} }
path, _ := astutil.PathEnclosingInterval(file, pos, pos) path, _ := astutil.PathEnclosingInterval(file, pos, pos)
if path == nil { if path == nil {
return nil, errors.Errorf("can't find node enclosing position") return nil, errors.Errorf("can't find node enclosing position")
} }
uri := span.FileURI(view.Session().Cache().FileSet().Position(pos).Filename)
pkg, err := bestPackage(uri, pkgs)
if err != nil {
return nil, err
}
var ph ParseGoHandle
for _, h := range pkg.GetHandles() {
if h.File().Identity().URI == uri {
ph = h
}
}
result := &IdentifierInfo{ result := &IdentifierInfo{
View: view, View: view,
File: f, File: ph,
qf: qualifier(file, pkg.GetTypes(), pkg.GetTypesInfo()), qf: qualifier(file, pkg.GetTypes(), pkg.GetTypesInfo()),
pkg: pkg, pkgs: pkgs,
} }
switch node := path[0].(type) { switch node := path[0].(type) {
@ -137,7 +148,7 @@ func identifier(ctx context.Context, view View, f GoFile, pkg Package, file *ast
// Handle builtins separately. // Handle builtins separately.
if result.Declaration.obj.Parent() == types.Universe { if result.Declaration.obj.Parent() == types.Universe {
decl, ok := lookupBuiltinDecl(f.View(), result.Name).(ast.Node) decl, ok := lookupBuiltinDecl(view, result.Name).(ast.Node)
if !ok { if !ok {
return nil, errors.Errorf("no declaration for %s", result.Name) return nil, errors.Errorf("no declaration for %s", result.Name)
} }
@ -170,7 +181,7 @@ func identifier(ctx context.Context, view View, f GoFile, pkg Package, file *ast
if result.Declaration.mappedRange, err = objToMappedRange(ctx, view, result.Declaration.obj); err != nil { if result.Declaration.mappedRange, err = objToMappedRange(ctx, view, result.Declaration.obj); err != nil {
return nil, err return nil, err
} }
if result.Declaration.node, err = objToNode(ctx, view, pkg.GetTypes(), result.Declaration.obj, result.Declaration.mappedRange.spanRange); err != nil { if result.Declaration.node, err = objToNode(ctx, view, pkg, result.Declaration.obj); err != nil {
return nil, err return nil, err
} }
typ := pkg.GetTypesInfo().TypeOf(result.ident) typ := pkg.GetTypesInfo().TypeOf(result.ident)
@ -206,35 +217,15 @@ func hasErrorType(obj types.Object) bool {
return types.IsInterface(obj.Type()) && obj.Pkg() == nil && obj.Name() == "error" return types.IsInterface(obj.Type()) && obj.Pkg() == nil && obj.Name() == "error"
} }
func objToNode(ctx context.Context, view View, originPkg *types.Package, obj types.Object, rng span.Range) (ast.Decl, error) { func objToNode(ctx context.Context, view View, pkg Package, obj types.Object) (ast.Decl, error) {
s, err := rng.Span() uri := span.FileURI(view.Session().Cache().FileSet().Position(obj.Pos()).Filename)
if err != nil { _, declAST, _, err := pkg.FindFile(ctx, uri, obj.Pos())
return nil, err
}
f, err := view.GetFile(ctx, s.URI())
if err != nil {
return nil, err
}
declFile, ok := f.(GoFile)
if !ok {
return nil, errors.Errorf("%s is not a Go file", s.URI())
}
declPkg, err := declFile.GetCachedPackage(ctx)
if err != nil {
return nil, err
}
var declAST *ast.File
for _, ph := range declPkg.GetHandles() {
if ph.File().Identity().URI == f.URI() {
declAST, err = ph.Cached(ctx)
}
}
if declAST == nil { if declAST == nil {
return nil, err return nil, err
} }
path, _ := astutil.PathEnclosingInterval(declAST, rng.Start, rng.End) path, _ := astutil.PathEnclosingInterval(declAST, obj.Pos(), obj.Pos())
if path == nil { if path == nil {
return nil, errors.Errorf("no path for range %v", rng) return nil, errors.Errorf("no path for object %v", obj.Name())
} }
for _, node := range path { for _, node := range path {
switch node := node.(type) { switch node := node.(type) {
@ -255,7 +246,7 @@ func objToNode(ctx context.Context, view View, originPkg *types.Package, obj typ
} }
// importSpec handles positions inside of an *ast.ImportSpec. // importSpec handles positions inside of an *ast.ImportSpec.
func importSpec(ctx context.Context, view View, f GoFile, fAST *ast.File, pkg Package, pos token.Pos) (*IdentifierInfo, error) { func importSpec(ctx context.Context, view View, fAST *ast.File, pkgs []Package, pos token.Pos) (*IdentifierInfo, error) {
var imp *ast.ImportSpec var imp *ast.ImportSpec
for _, spec := range fAST.Imports { for _, spec := range fAST.Imports {
if spec.Path.Pos() <= pos && pos < spec.Path.End() { if spec.Path.Pos() <= pos && pos < spec.Path.End() {
@ -269,11 +260,22 @@ func importSpec(ctx context.Context, view View, f GoFile, fAST *ast.File, pkg Pa
if err != nil { if err != nil {
return nil, errors.Errorf("import path not quoted: %s (%v)", imp.Path.Value, err) return nil, errors.Errorf("import path not quoted: %s (%v)", imp.Path.Value, err)
} }
uri := span.FileURI(view.Session().Cache().FileSet().Position(pos).Filename)
pkg, err := bestPackage(uri, pkgs)
if err != nil {
return nil, err
}
var ph ParseGoHandle
for _, h := range pkg.GetHandles() {
if h.File().Identity().URI == uri {
ph = h
}
}
result := &IdentifierInfo{ result := &IdentifierInfo{
View: view, View: view,
File: f, File: ph,
Name: importPath, Name: importPath,
pkg: pkg, pkgs: pkgs,
} }
if result.mappedRange, err = posToRange(ctx, view, imp.Path.Pos(), imp.Path.End()); err != nil { if result.mappedRange, err = posToRange(ctx, view, imp.Path.Pos(), imp.Path.End()); err != nil {
return nil, err return nil, err

View File

@ -34,11 +34,7 @@ func (i *IdentifierInfo) References(ctx context.Context) ([]*ReferenceInfo, erro
if i.Declaration.obj == nil { if i.Declaration.obj == nil {
return nil, errors.Errorf("no references for an import spec") return nil, errors.Errorf("no references for an import spec")
} }
pkgs, err := i.File.GetCachedPackages(ctx) for _, pkg := range i.pkgs {
if err != nil {
return nil, err
}
for _, pkg := range pkgs {
info := pkg.GetTypesInfo() info := pkg.GetTypesInfo()
if info == nil { if info == nil {
return nil, errors.Errorf("package %s has no types info", pkg.PkgPath()) return nil, errors.Errorf("package %s has no types info", pkg.PkgPath())

View File

@ -102,11 +102,15 @@ func (i *IdentifierInfo) Rename(ctx context.Context, view View, newName string)
if i.Declaration.obj.Parent() == types.Universe { if i.Declaration.obj.Parent() == types.Universe {
return nil, errors.Errorf("cannot rename builtin %q", i.Name) return nil, errors.Errorf("cannot rename builtin %q", i.Name)
} }
if i.pkg == nil || i.pkg.IsIllTyped() { pkg, err := bestPackage(i.File.File().Identity().URI, i.pkgs)
return nil, errors.Errorf("package for %s is ill typed", i.File.URI()) if err != nil {
return nil, err
}
if pkg == nil || pkg.IsIllTyped() {
return nil, errors.Errorf("package for %s is ill typed", i.File.File().Identity().URI)
} }
// Do not rename identifiers declared in another package. // Do not rename identifiers declared in another package.
if i.pkg.GetTypes() != i.Declaration.obj.Pkg() { if pkg.GetTypes() != i.Declaration.obj.Pkg() {
return nil, errors.Errorf("failed to rename because %q is declared in package %q", i.Name, i.Declaration.obj.Pkg().Name()) return nil, errors.Errorf("failed to rename because %q is declared in package %q", i.Name, i.Declaration.obj.Pkg().Name())
} }
@ -168,8 +172,12 @@ func (i *IdentifierInfo) getPkgName(ctx context.Context) (*IdentifierInfo, error
file *ast.File file *ast.File
err error err error
) )
for _, ph := range i.pkg.GetHandles() { pkg, err := bestPackage(i.File.File().Identity().URI, i.pkgs)
if ph.File().Identity().URI == i.File.URI() { if err != nil {
return nil, err
}
for _, ph := range pkg.GetHandles() {
if ph.File().Identity().URI == i.File.File().Identity().URI {
file, err = ph.Cached(ctx) file, err = ph.Cached(ctx)
} }
} }
@ -188,13 +196,13 @@ func (i *IdentifierInfo) getPkgName(ctx context.Context) (*IdentifierInfo, error
} }
// Look for the object defined at NamePos. // Look for the object defined at NamePos.
for _, obj := range i.pkg.GetTypesInfo().Defs { for _, obj := range pkg.GetTypesInfo().Defs {
pkgName, ok := obj.(*types.PkgName) pkgName, ok := obj.(*types.PkgName)
if ok && pkgName.Pos() == namePos { if ok && pkgName.Pos() == namePos {
return getPkgNameIdentifier(ctx, i, pkgName) return getPkgNameIdentifier(ctx, i, pkgName)
} }
} }
for _, obj := range i.pkg.GetTypesInfo().Implicits { for _, obj := range pkg.GetTypesInfo().Implicits {
pkgName, ok := obj.(*types.PkgName) pkgName, ok := obj.(*types.PkgName)
if ok && pkgName.Pos() == namePos { if ok && pkgName.Pos() == namePos {
return getPkgNameIdentifier(ctx, i, pkgName) return getPkgNameIdentifier(ctx, i, pkgName)
@ -211,10 +219,14 @@ func getPkgNameIdentifier(ctx context.Context, ident *IdentifierInfo, pkgName *t
wasImplicit: true, wasImplicit: true,
} }
var err error var err error
if decl.mappedRange, err = objToMappedRange(ctx, ident.File.View(), decl.obj); err != nil { if decl.mappedRange, err = objToMappedRange(ctx, ident.View, decl.obj); err != nil {
return nil, err return nil, err
} }
if decl.node, err = objToNode(ctx, ident.File.View(), ident.pkg.GetTypes(), decl.obj, decl.mappedRange.spanRange); err != nil { pkg, err := bestPackage(ident.File.File().Identity().URI, ident.pkgs)
if err != nil {
return nil, err
}
if decl.node, err = objToNode(ctx, ident.View, pkg, decl.obj); err != nil {
return nil, err return nil, err
} }
return &IdentifierInfo{ return &IdentifierInfo{
@ -223,7 +235,7 @@ func getPkgNameIdentifier(ctx context.Context, ident *IdentifierInfo, pkgName *t
mappedRange: decl.mappedRange, mappedRange: decl.mappedRange,
File: ident.File, File: ident.File,
Declaration: decl, Declaration: decl,
pkg: ident.pkg, pkgs: ident.pkgs,
wasEmbeddedField: false, wasEmbeddedField: false,
qf: ident.qf, qf: ident.qf,
}, nil }, nil

View File

@ -31,7 +31,11 @@ func SignatureHelp(ctx context.Context, view View, f GoFile, pos protocol.Positi
ctx, done := trace.StartSpan(ctx, "source.SignatureHelp") ctx, done := trace.StartSpan(ctx, "source.SignatureHelp")
defer done() defer done()
file, pkg, m, err := fileToMapper(ctx, view, f.URI()) file, pkgs, m, err := fileToMapper(ctx, view, f.URI())
if err != nil {
return nil, err
}
pkg, err := bestPackage(f.URI(), pkgs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -105,11 +109,11 @@ FindCall:
comment *ast.CommentGroup comment *ast.CommentGroup
) )
if obj != nil { if obj != nil {
rng, err := objToMappedRange(ctx, view, obj) node, err := objToNode(ctx, f.View(), pkg, obj)
if err != nil { if err != nil {
return nil, err return nil, err
} }
node, err := objToNode(ctx, f.View(), pkg.GetTypes(), obj, rng.spanRange) rng, err := objToMappedRange(ctx, view, obj)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -51,7 +51,25 @@ func (s mappedRange) URI() span.URI {
return s.m.URI return s.m.URI
} }
func fileToMapper(ctx context.Context, view View, uri span.URI) (*ast.File, Package, *protocol.ColumnMapper, error) { // bestCheckPackageHandle picks the "narrowest" package for a given file.
//
// By "narrowest" package, we mean the package with the fewest number of files
// that includes the given file. This solves the problem of test variants,
// as the test will have more files than the non-test package.
func bestPackage(uri span.URI, pkgs []Package) (Package, error) {
var result Package
for _, pkg := range pkgs {
if result == nil || len(pkg.GetHandles()) < len(result.GetHandles()) {
result = pkg
}
}
if result == nil {
return nil, errors.Errorf("no CheckPackageHandle for %s", uri)
}
return result, nil
}
func fileToMapper(ctx context.Context, view View, uri span.URI) (*ast.File, []Package, *protocol.ColumnMapper, error) {
f, err := view.GetFile(ctx, uri) f, err := view.GetFile(ctx, uri)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
@ -60,7 +78,11 @@ func fileToMapper(ctx context.Context, view View, uri span.URI) (*ast.File, Pack
if !ok { if !ok {
return nil, nil, nil, errors.Errorf("%s is not a Go file", f.URI()) return nil, nil, nil, errors.Errorf("%s is not a Go file", f.URI())
} }
pkg, err := gof.GetPackage(ctx) pkgs, err := gof.GetPackages(ctx)
if err != nil {
return nil, nil, nil, err
}
pkg, err := bestPackage(f.URI(), pkgs)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@ -68,7 +90,7 @@ func fileToMapper(ctx context.Context, view View, uri span.URI) (*ast.File, Pack
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
return file, pkg, m, nil return file, pkgs, m, nil
} }
func cachedFileToMapper(ctx context.Context, view View, uri span.URI) (*ast.File, *protocol.ColumnMapper, error) { func cachedFileToMapper(ctx context.Context, view View, uri span.URI) (*ast.File, *protocol.ColumnMapper, error) {

View File

@ -318,4 +318,8 @@ type Package interface {
// GetActionGraph returns the action graph for the given package. // GetActionGraph returns the action graph for the given package.
GetActionGraph(ctx context.Context, a *analysis.Analyzer) (*Action, error) GetActionGraph(ctx context.Context, a *analysis.Analyzer) (*Action, error)
// FindFile returns the AST and type information for a file that may
// belong to or be part of a dependency of the given package.
FindFile(ctx context.Context, uri span.URI, pos token.Pos) (ParseGoHandle, *ast.File, Package, error)
} }