diff --git a/src/cmd/compile/internal/noder/transform.go b/src/cmd/compile/internal/noder/transform.go index 6b17ab283a..ddbccf4ff4 100644 --- a/src/cmd/compile/internal/noder/transform.go +++ b/src/cmd/compile/internal/noder/transform.go @@ -162,6 +162,7 @@ func transformCall(n *ir.CallExpr) { ir.SetPos(n) // n.Type() can be nil for calls with no return value assert(n.Typecheck() == 1) + typecheck.RewriteNonNameCall(n) transformArgs(n) l := n.X t := l.Type() diff --git a/src/cmd/compile/internal/typecheck/const.go b/src/cmd/compile/internal/typecheck/const.go index a626c000be..22fa9e7d95 100644 --- a/src/cmd/compile/internal/typecheck/const.go +++ b/src/cmd/compile/internal/typecheck/const.go @@ -734,35 +734,40 @@ func IndexConst(n ir.Node) int64 { return ir.IntVal(types.Types[types.TINT], v) } +// callOrChan reports whether n is a call or channel operation. +func callOrChan(n ir.Node) bool { + switch n.Op() { + case ir.OAPPEND, + ir.OCALL, + ir.OCALLFUNC, + ir.OCALLINTER, + ir.OCALLMETH, + ir.OCAP, + ir.OCLOSE, + ir.OCOMPLEX, + ir.OCOPY, + ir.ODELETE, + ir.OIMAG, + ir.OLEN, + ir.OMAKE, + ir.ONEW, + ir.OPANIC, + ir.OPRINT, + ir.OPRINTN, + ir.OREAL, + ir.ORECOVER, + ir.ORECV, + ir.OUNSAFEADD, + ir.OUNSAFESLICE: + return true + } + return false +} + // anyCallOrChan reports whether n contains any calls or channel operations. func anyCallOrChan(n ir.Node) bool { return ir.Any(n, func(n ir.Node) bool { - switch n.Op() { - case ir.OAPPEND, - ir.OCALL, - ir.OCALLFUNC, - ir.OCALLINTER, - ir.OCALLMETH, - ir.OCAP, - ir.OCLOSE, - ir.OCOMPLEX, - ir.OCOPY, - ir.ODELETE, - ir.OIMAG, - ir.OLEN, - ir.OMAKE, - ir.ONEW, - ir.OPANIC, - ir.OPRINT, - ir.OPRINTN, - ir.OREAL, - ir.ORECOVER, - ir.ORECV, - ir.OUNSAFEADD, - ir.OUNSAFESLICE: - return true - } - return false + return callOrChan(n) }) } diff --git a/src/cmd/compile/internal/typecheck/func.go b/src/cmd/compile/internal/typecheck/func.go index 9d55d73592..0988ce8dc7 100644 --- a/src/cmd/compile/internal/typecheck/func.go +++ b/src/cmd/compile/internal/typecheck/func.go @@ -343,6 +343,7 @@ func tcCall(n *ir.CallExpr, top int) ir.Node { return tcConv(n) } + RewriteNonNameCall(n) typecheckargs(n) t := l.Type() if t == nil { diff --git a/src/cmd/compile/internal/typecheck/typecheck.go b/src/cmd/compile/internal/typecheck/typecheck.go index 06d7f5dc82..3b0c1f734e 100644 --- a/src/cmd/compile/internal/typecheck/typecheck.go +++ b/src/cmd/compile/internal/typecheck/typecheck.go @@ -867,6 +867,42 @@ func typecheckargs(n ir.InitNode) { RewriteMultiValueCall(n, list[0]) } +// RewriteNonNameCall replaces non-Name call expressions with temps, +// rewriting f()(...) to t0 := f(); t0(...). +func RewriteNonNameCall(n *ir.CallExpr) { + np := &n.X + if inst, ok := (*np).(*ir.InstExpr); ok && inst.Op() == ir.OFUNCINST { + np = &inst.X + } + if dot, ok := (*np).(*ir.SelectorExpr); ok && (dot.Op() == ir.ODOTMETH || dot.Op() == ir.ODOTINTER || dot.Op() == ir.OMETHVALUE) { + np = &dot.X // peel away method selector + } + + // Check for side effects in the callee expression. + // We explicitly special case new(T) though, because it doesn't have + // observable side effects, and keeping it in place allows better escape analysis. + if !ir.Any(*np, func(n ir.Node) bool { return n.Op() != ir.ONEW && callOrChan(n) }) { + return + } + + // See comment (1) in RewriteMultiValueCall. + static := ir.CurFunc == nil + if static { + ir.CurFunc = InitTodoFunc + } + + tmp := Temp((*np).Type()) + as := ir.NewAssignStmt(base.Pos, tmp, *np) + as.Def = true + *np = tmp + + if static { + ir.CurFunc = nil + } + + n.PtrInit().Append(Stmt(as)) +} + // RewriteMultiValueCall rewrites multi-valued f() to use temporaries, // so the backend wouldn't need to worry about tuple-valued expressions. func RewriteMultiValueCall(n ir.InitNode, call ir.Node) { @@ -874,7 +910,7 @@ func RewriteMultiValueCall(n ir.InitNode, call ir.Node) { // be executed during the generated init function. However, // init.go hasn't yet created it. Instead, associate the // temporary variables with InitTodoFunc for now, and init.go - // will reassociate them later when it's appropriate. + // will reassociate them later when it's appropriate. (1) static := ir.CurFunc == nil if static { ir.CurFunc = InitTodoFunc diff --git a/test/fixedbugs/issue50672.go b/test/fixedbugs/issue50672.go new file mode 100644 index 0000000000..178786a104 --- /dev/null +++ b/test/fixedbugs/issue50672.go @@ -0,0 +1,105 @@ +// run + +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +var ok = false + +func f() func(int, int) int { + ok = true + return func(int, int) int { return 0 } +} + +func g() (int, int) { + if !ok { + panic("FAIL") + } + return 0, 0 +} + +var _ = f()(g()) + +func main() { + f1() + f2() + f3() + f4() +} + +func f1() { + ok := false + + f := func() func(int, int) { + ok = true + return func(int, int) {} + } + g := func() (int, int) { + if !ok { + panic("FAIL") + } + return 0, 0 + } + + f()(g()) +} + +type S struct{} + +func (S) f(int, int) {} + +func f2() { + ok := false + + f := func() S { + ok = true + return S{} + } + g := func() (int, int) { + if !ok { + panic("FAIL") + } + return 0, 0 + } + + f().f(g()) +} + +func f3() { + ok := false + + f := func() []func(int, int) { + ok = true + return []func(int, int){func(int, int) {}} + } + g := func() (int, int) { + if !ok { + panic("FAIL") + } + return 0, 0 + } + f()[0](g()) +} + +type G[T any] struct{} + +func (G[T]) f(int, int) {} + +func f4() { + ok := false + + f := func() G[int] { + ok = true + return G[int]{} + } + g := func() (int, int) { + if !ok { + panic("FAIL") + } + return 0, 0 + } + + f().f(g()) +}