diff --git a/go/analysis/analysis.go b/go/analysis/analysis.go index e375484fa6..bc58c31c9f 100644 --- a/go/analysis/analysis.go +++ b/go/analysis/analysis.go @@ -128,11 +128,13 @@ type Pass struct { // See comments for ExportObjectFact. ExportPackageFact func(fact Fact) - // AllPackageFacts returns a new slice containing all package facts in unspecified order. + // AllPackageFacts returns a new slice containing all package facts of the analysis's FactTypes + // in unspecified order. // WARNING: This is an experimental API and may change in the future. AllPackageFacts func() []PackageFact - // AllObjectFacts returns a new slice containing all object facts in unspecified order. + // AllObjectFacts returns a new slice containing all object facts of the analysis's FactTypes + // in unspecified order. // WARNING: This is an experimental API and may change in the future. AllObjectFacts func() []ObjectFact diff --git a/go/analysis/internal/facts/facts.go b/go/analysis/internal/facts/facts.go index 86f1ce84a7..dcd4f4da81 100644 --- a/go/analysis/internal/facts/facts.go +++ b/go/analysis/internal/facts/facts.go @@ -99,10 +99,10 @@ func (s *Set) ExportObjectFact(obj types.Object, fact analysis.Fact) { s.mu.Unlock() } -func (s *Set) AllObjectFacts() []analysis.ObjectFact { +func (s *Set) AllObjectFacts(filter map[reflect.Type]bool) []analysis.ObjectFact { var facts []analysis.ObjectFact for k, v := range s.m { - if k.obj != nil { + if k.obj != nil && filter[k.t] { facts = append(facts, analysis.ObjectFact{k.obj, v}) } } @@ -132,10 +132,10 @@ func (s *Set) ExportPackageFact(fact analysis.Fact) { s.mu.Unlock() } -func (s *Set) AllPackageFacts() []analysis.PackageFact { +func (s *Set) AllPackageFacts(filter map[reflect.Type]bool) []analysis.PackageFact { var facts []analysis.PackageFact for k, v := range s.m { - if k.obj == nil { + if k.obj == nil && filter[k.t] { facts = append(facts, analysis.PackageFact{k.pkg, v}) } } diff --git a/go/analysis/internal/facts/facts_test.go b/go/analysis/internal/facts/facts_test.go index e21a4982ba..c345a12c04 100644 --- a/go/analysis/internal/facts/facts_test.go +++ b/go/analysis/internal/facts/facts_test.go @@ -10,6 +10,7 @@ import ( "go/token" "go/types" "os" + "reflect" "testing" "golang.org/x/tools/go/analysis/analysistest" @@ -172,3 +173,52 @@ func load(dir string, path string) (*types.Package, error) { } return pkgs[0].Types, nil } + +type otherFact struct { + S string +} + +func (f *otherFact) String() string { return fmt.Sprintf("otherFact(%s)", f.S) } +func (f *otherFact) AFact() {} + +func TestFactFilter(t *testing.T) { + files := map[string]string{ + "a/a.go": `package a; type A int`, + } + dir, cleanup, err := analysistest.WriteFiles(files) + if err != nil { + t.Fatal(err) + } + defer cleanup() + + pkg, err := load(dir, "a") + if err != nil { + t.Fatal(err) + } + + obj := pkg.Scope().Lookup("A") + s, err := facts.Decode(pkg, func(string) ([]byte, error) { return nil, nil }) + if err != nil { + t.Fatal(err) + } + s.ExportObjectFact(obj, &myFact{"good object fact"}) + s.ExportPackageFact(&myFact{"good package fact"}) + s.ExportObjectFact(obj, &otherFact{"bad object fact"}) + s.ExportPackageFact(&otherFact{"bad package fact"}) + + filter := map[reflect.Type]bool{ + reflect.TypeOf(&myFact{}): true, + } + + pkgFacts := s.AllPackageFacts(filter) + wantPkgFacts := `[{package a ("a") myFact(good package fact)}]` + if got := fmt.Sprintf("%v", pkgFacts); got != wantPkgFacts { + t.Errorf("AllPackageFacts: got %v, want %v", got, wantPkgFacts) + } + + objFacts := s.AllObjectFacts(filter) + wantObjFacts := "[{type a.A int myFact(good object fact)}]" + if got := fmt.Sprintf("%v", objFacts); got != wantObjFacts { + t.Errorf("AllObjectFacts: got %v, want %v", got, wantObjFacts) + } +} diff --git a/go/analysis/unitchecker/unitchecker.go b/go/analysis/unitchecker/unitchecker.go index 87c3160847..2ed274949b 100644 --- a/go/analysis/unitchecker/unitchecker.go +++ b/go/analysis/unitchecker/unitchecker.go @@ -42,6 +42,7 @@ import ( "log" "os" "path/filepath" + "reflect" "sort" "strings" "sync" @@ -322,6 +323,11 @@ func run(fset *token.FileSet, cfg *Config, analyzers []*analysis.Analyzer) ([]re return } + factFilter := make(map[reflect.Type]bool) + for _, f := range a.FactTypes { + factFilter[reflect.TypeOf(f)] = true + } + pass := &analysis.Pass{ Analyzer: a, Fset: fset, @@ -334,10 +340,10 @@ func run(fset *token.FileSet, cfg *Config, analyzers []*analysis.Analyzer) ([]re Report: func(d analysis.Diagnostic) { act.diagnostics = append(act.diagnostics, d) }, ImportObjectFact: facts.ImportObjectFact, ExportObjectFact: facts.ExportObjectFact, - AllObjectFacts: facts.AllObjectFacts, + AllObjectFacts: func() []analysis.ObjectFact { return facts.AllObjectFacts(factFilter) }, ImportPackageFact: facts.ImportPackageFact, ExportPackageFact: facts.ExportPackageFact, - AllPackageFacts: facts.AllPackageFacts, + AllPackageFacts: func() []analysis.PackageFact { return facts.AllPackageFacts(factFilter) }, } t0 := time.Now()