diff --git a/internal/lsp/cache/snapshot.go b/internal/lsp/cache/snapshot.go index 766877f661..bece4f1d09 100644 --- a/internal/lsp/cache/snapshot.go +++ b/internal/lsp/cache/snapshot.go @@ -2,6 +2,7 @@ package cache import ( "context" + "fmt" "os" "sync" @@ -9,6 +10,7 @@ import ( "golang.org/x/tools/internal/lsp/protocol" "golang.org/x/tools/internal/lsp/source" "golang.org/x/tools/internal/span" + "golang.org/x/tools/internal/telemetry/log" ) type snapshot struct { @@ -98,6 +100,29 @@ func (s *snapshot) getPackages(uri source.FileURI, m source.ParseMode) (cphs []s return cphs } +func (s *snapshot) KnownPackages(ctx context.Context) []source.Package { + // TODO(matloob): This function exists because KnownImportPaths can't + // determine the import paths of all packages. Remove this function + // if KnownImportPaths gains that ability. That could happen if + // go list or go packages provide that information. + s.mu.Lock() + defer s.mu.Unlock() + + var results []source.Package + for _, cph := range s.packages { + // Check the package now if it's not checked yet. + // TODO(matloob): is this too slow? + pkg, err := cph.check(ctx) + if err != nil { + log.Error(ctx, fmt.Sprintf("cph.Check of %v", cph.m.pkgPath), err) + continue + } + results = append(results, pkg) + } + + return results +} + func (s *snapshot) KnownImportPaths() map[string]source.Package { s.mu.Lock() defer s.mu.Unlock() diff --git a/internal/lsp/source/implementation.go b/internal/lsp/source/implementation.go index 20dd672eb9..2793333a6a 100644 --- a/internal/lsp/source/implementation.go +++ b/internal/lsp/source/implementation.go @@ -32,10 +32,21 @@ func Implementation(ctx context.Context, view View, f File, position protocol.Po } var objs []types.Object + pkgs := map[types.Object]Package{} if res.toMethod != nil { // If we looked up a method, results are in toMethod. for _, s := range res.toMethod { + // Determine package of receiver. + recv := s.Recv() + if p, ok := recv.(*types.Pointer); ok { + recv = p.Elem() + } + if n, ok := recv.(*types.Named); ok { + pkg := res.pkgs[n] + pkgs[s.Obj()] = pkg + } + // Add object to objs. objs = append(objs, s.Obj()) } } else { @@ -46,26 +57,49 @@ func Implementation(ctx context.Context, view View, f File, position protocol.Po t = p.Elem() } if n, ok := t.(*types.Named); ok { + pkg := res.pkgs[n] objs = append(objs, n.Obj()) + pkgs[n.Obj()] = pkg } } } var locations []protocol.Location - ph, pkg, err := view.FindFileInPackage(ctx, f.URI(), ident.pkg) - if err != nil { - return nil, err - } - af, _, _, err := ph.Cached() - if err != nil { - return nil, err - } for _, obj := range objs { - ident, err := findIdentifier(ctx, view.Snapshot(), pkg, af, obj.Pos()) + pkg := pkgs[obj] + if pkgs[obj] == nil || len(pkg.Files()) == 0 { + continue + } + // Search for the identifier in each of the package's files. + var ident *IdentifierInfo + + fset := view.Session().Cache().FileSet() + file := fset.File(obj.Pos()) + var containingFile FileHandle + for _, f := range pkg.Files() { + if f.File().Identity().URI.Filename() == file.Name() { + containingFile = f.File() + } + } + if containingFile == nil { + return nil, fmt.Errorf("Failed to find file %q in package %v", file.Name(), pkg.PkgPath()) + } + + uri := containingFile.Identity().URI + ph, _, err := view.FindFileInPackage(ctx, uri, pkgs[obj]) if err != nil { return nil, err } + astFile, _, _, err := ph.Cached() + if err != nil { + return nil, err + } + ident, err = findIdentifier(ctx, view.Snapshot(), pkg, astFile, obj.Pos()) + if err != nil { + return nil, err + } + decRange, err := ident.Declaration.Range() if err != nil { return nil, err @@ -102,11 +136,15 @@ func (i *IdentifierInfo) implementations(ctx context.Context) (implementsResult, // We ignore aliases 'type M = N' to avoid duplicate // reporting of the Named type N. var allNamed []*types.Named - info := i.pkg.GetTypesInfo() - for _, obj := range info.Defs { - if obj, ok := obj.(*types.TypeName); ok && !obj.IsAlias() { - if named, ok := obj.Type().(*types.Named); ok { - allNamed = append(allNamed, named) + pkgs := map[*types.Named]Package{} + for _, pkg := range i.Snapshot.KnownPackages(ctx) { + info := pkg.GetTypesInfo() + for _, obj := range info.Defs { + if obj, ok := obj.(*types.TypeName); ok && !obj.IsAlias() { + if named, ok := obj.Type().(*types.Named); ok { + allNamed = append(allNamed, named) + pkgs[named] = pkg + } } } } @@ -173,11 +211,12 @@ func (i *IdentifierInfo) implementations(ctx context.Context) (implementsResult, } } - return implementsResult{to, from, fromPtr, toMethod}, nil + return implementsResult{pkgs, to, from, fromPtr, toMethod}, nil } // implementsResult contains the results of an implements query. type implementsResult struct { + pkgs map[*types.Named]Package to []types.Type // named or ptr-to-named types assignable to interface T from []types.Type // named interfaces assignable from T fromPtr []types.Type // named interfaces assignable only from *T diff --git a/internal/lsp/source/view.go b/internal/lsp/source/view.go index 1f3f18d721..842e5e2446 100644 --- a/internal/lsp/source/view.go +++ b/internal/lsp/source/view.go @@ -277,9 +277,12 @@ type Snapshot interface { // that this file belongs to. CheckPackageHandles(ctx context.Context, f File) ([]CheckPackageHandle, error) - // KnownImportPaths returns all the packages loaded in this snapshot, + // KnownImportPaths returns all the imported packages loaded in this snapshot, // indexed by their import path. KnownImportPaths() map[string]Package + + // KnownPackages returns all the packages loaded in this snapshot. + KnownPackages(ctx context.Context) []Package } // File represents a source file of any type. diff --git a/internal/lsp/testdata/implementation/implementation.go b/internal/lsp/testdata/implementation/implementation.go index 6bcb29b948..1dba704fc5 100644 --- a/internal/lsp/testdata/implementation/implementation.go +++ b/internal/lsp/testdata/implementation/implementation.go @@ -11,9 +11,9 @@ func (ImpS) Laugh() { //@mark(LaughS, "Laugh") } type ImpI interface { //@ImpI - Laugh() //@mark(LaughI, "Laugh"),implementations("augh", LaughP),implementations("augh", LaughS),implementations("augh", LaughL) + Laugh() //@mark(LaughI, "Laugh"),implementations("augh", LaughP),implementations("augh", OtherLaughP),implementations("augh", LaughS),implementations("augh", LaughL),implementations("augh", OtherLaughI),implementations("augh", OtherLaughS) } -type Laugher interface { //@Laugher,implementations("augher", ImpP),implementations("augher", ImpI),implementations("augher", ImpS), - Laugh() //@mark(LaughL, "Laugh"),implementations("augh", LaughP),implementations("augh", LaughI),implementations("augh", LaughS) +type Laugher interface { //@Laugher,implementations("augher", ImpP),implementations("augher", OtherImpP),implementations("augher", ImpI),implementations("augher", ImpS),implementations("augher", OtherImpI),implementations("augher", OtherImpS), + Laugh() //@mark(LaughL, "Laugh"),implementations("augh", LaughP),implementations("augh", OtherLaughP),implementations("augh", LaughI),implementations("augh", LaughS),implementations("augh", OtherLaughI),implementations("augh", OtherLaughS) } diff --git a/internal/lsp/testdata/implementation/other/other.go b/internal/lsp/testdata/implementation/other/other.go new file mode 100644 index 0000000000..ae7adf10bf --- /dev/null +++ b/internal/lsp/testdata/implementation/other/other.go @@ -0,0 +1,15 @@ +package other + +type ImpP struct{} //@mark(OtherImpP, "ImpP") + +func (*ImpP) Laugh() { //@mark(OtherLaughP, "Laugh") +} + +type ImpS struct{} //@mark(OtherImpS, "ImpS") + +func (ImpS) Laugh() { //@mark(OtherLaughS, "Laugh") +} + +type ImpI interface { //@mark(OtherImpI, "ImpI") + Laugh() //@mark(OtherLaughI, "Laugh") +}