diff --git a/internal/lsp/cache/parse.go b/internal/lsp/cache/parse.go index 537460aa0b..724746ae42 100644 --- a/internal/lsp/cache/parse.go +++ b/internal/lsp/cache/parse.go @@ -402,15 +402,11 @@ FindTo: case *ast.GoStmt: stmt.Call = call } - switch parent := parent.(type) { - case *ast.BlockStmt: - for i, s := range parent.List { - if s == bad { - parent.List[i] = stmt - break - } - } + + if !replaceNode(parent, bad, stmt) { + return errors.Errorf("couldn't replace %T in %T", stmt, parent) } + return nil } @@ -444,3 +440,56 @@ func offsetPositions(expr ast.Expr, offset token.Pos) { return true }) } + +// replaceNode updates parent's child oldChild to be newChild. It +// retuns whether it replaced successfully. +func replaceNode(parent, oldChild, newChild ast.Node) bool { + if parent == nil || oldChild == nil || newChild == nil { + return false + } + + parentVal := reflect.ValueOf(parent).Elem() + if parentVal.Kind() != reflect.Struct { + return false + } + + newChildVal := reflect.ValueOf(newChild) + + tryReplace := func(v reflect.Value) bool { + if !v.CanSet() || !v.CanInterface() { + return false + } + + // If the existing value is oldChild, we found our child. Make + // sure our newChild is assignable and then make the swap. + if v.Interface() == oldChild && newChildVal.Type().AssignableTo(v.Type()) { + v.Set(newChildVal) + return true + } + + return false + } + + // Loop over parent's struct fields. + for i := 0; i < parentVal.NumField(); i++ { + f := parentVal.Field(i) + + switch f.Kind() { + // Check interface and pointer fields. + case reflect.Interface, reflect.Ptr: + if tryReplace(f) { + return true + } + + // Search through any slice fields. + case reflect.Slice: + for i := 0; i < f.Len(); i++ { + if tryReplace(f.Index(i)) { + return true + } + } + } + } + + return false +} diff --git a/internal/lsp/testdata/badstmt/badstmt.go.in b/internal/lsp/testdata/badstmt/badstmt.go.in index 3aae7db679..05b2c9a38d 100644 --- a/internal/lsp/testdata/badstmt/badstmt.go.in +++ b/internal/lsp/testdata/badstmt/badstmt.go.in @@ -10,6 +10,13 @@ func _() { defer foo.F //@complete(" //", Foo) } +func _() { + switch true { + case true: + go foo.F //@complete(" //", Foo) + } +} + func _() { defer func() { foo.F //@complete(" //", Foo),snippet(" //", Foo, "Foo()", "Foo()") diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go index 97e57095c3..85e54fe4ba 100644 --- a/internal/lsp/tests/tests.go +++ b/internal/lsp/tests/tests.go @@ -30,7 +30,7 @@ import ( // We hardcode the expected number of test cases to ensure that all tests // are being executed. If a test is added, this number must be changed. const ( - ExpectedCompletionsCount = 164 + ExpectedCompletionsCount = 165 ExpectedCompletionSnippetCount = 35 ExpectedDiagnosticsCount = 21 ExpectedFormatCount = 6