diff --git a/src/cmd/compile/internal/gc/ssa.go b/src/cmd/compile/internal/gc/ssa.go index e2fbd6f096..e01ebd6e89 100644 --- a/src/cmd/compile/internal/gc/ssa.go +++ b/src/cmd/compile/internal/gc/ssa.go @@ -2556,7 +2556,7 @@ func (s *state) expr(n *Node) *ssa.Value { return s.addr(n.Left) case ORESULT: - if s.prevCall == nil || s.prevCall.Op != ssa.OpStaticLECall { + if s.prevCall == nil || s.prevCall.Op != ssa.OpStaticLECall && s.prevCall.Op != ssa.OpInterLECall && s.prevCall.Op != ssa.OpClosureLECall { // Do the old thing addr := s.constOffPtrSP(types.NewPtr(n.Type), n.Xoffset) return s.rawLoad(n.Type, addr) @@ -4409,6 +4409,9 @@ func (s *state) call(n *Node, k callKind, returnResultAddr bool) *ssa.Value { iclosure, rcvr = s.getClosureAndRcvr(fn) if k == callNormal { codeptr = s.load(types.Types[TUINTPTR], iclosure) + if ssa.LateCallExpansionEnabledWithin(s.f) { + testLateExpansion = true + } } else { closure = iclosure } @@ -4555,16 +4558,17 @@ func (s *state) call(n *Node, k callKind, returnResultAddr bool) *ssa.Value { codeptr = s.rawLoad(types.Types[TUINTPTR], closure) call = s.newValue3A(ssa.OpClosureCall, types.TypeMem, ssa.ClosureAuxCall(ACArgs, ACResults), codeptr, closure, s.mem()) case codeptr != nil: - call = s.newValue2A(ssa.OpInterCall, types.TypeMem, ssa.InterfaceAuxCall(ACArgs, ACResults), codeptr, s.mem()) + if testLateExpansion { + aux := ssa.InterfaceAuxCall(ACArgs, ACResults) + call = s.newValue1A(ssa.OpInterLECall, aux.LateExpansionResultType(), aux, codeptr) + call.AddArgs(callArgs...) + } else { + call = s.newValue2A(ssa.OpInterCall, types.TypeMem, ssa.InterfaceAuxCall(ACArgs, ACResults), codeptr, s.mem()) + } case sym != nil: if testLateExpansion { - var tys []*types.Type aux := ssa.StaticAuxCall(sym.Linksym(), ACArgs, ACResults) - for i := int64(0); i < aux.NResults(); i++ { - tys = append(tys, aux.TypeOfResult(i)) - } - tys = append(tys, types.TypeMem) - call = s.newValue0A(ssa.OpStaticLECall, types.NewResults(tys), aux) + call = s.newValue0A(ssa.OpStaticLECall, aux.LateExpansionResultType(), aux) call.AddArgs(callArgs...) } else { call = s.newValue1A(ssa.OpStaticCall, types.TypeMem, ssa.StaticAuxCall(sym.Linksym(), ACArgs, ACResults), s.mem()) @@ -4713,7 +4717,7 @@ func (s *state) addr(n *Node) *ssa.Value { } case ORESULT: // load return from callee - if s.prevCall == nil || s.prevCall.Op != ssa.OpStaticLECall { + if s.prevCall == nil || s.prevCall.Op != ssa.OpStaticLECall && s.prevCall.Op != ssa.OpInterLECall && s.prevCall.Op != ssa.OpClosureLECall { return s.constOffPtrSP(t, n.Xoffset) } which := s.prevCall.Aux.(*ssa.AuxCall).ResultForOffset(n.Xoffset) diff --git a/src/cmd/compile/internal/ssa/expand_calls.go b/src/cmd/compile/internal/ssa/expand_calls.go index 8456dbab8d..7b1d656b64 100644 --- a/src/cmd/compile/internal/ssa/expand_calls.go +++ b/src/cmd/compile/internal/ssa/expand_calls.go @@ -38,6 +38,7 @@ func expandCalls(f *Func) { } else { hiOffset = 4 } + pairTypes := func(et types.EType) (tHi, tLo *types.Type) { tHi = tUint32 if et == types.TINT64 { @@ -231,46 +232,64 @@ func expandCalls(f *Func) { return x } + rewriteArgs := func(v *Value, firstArg int) *Value { + // Thread the stores on the memory arg + aux := v.Aux.(*AuxCall) + pos := v.Pos.WithNotStmt() + m0 := v.Args[len(v.Args)-1] + mem := m0 + for i, a := range v.Args { + if i < firstArg { + continue + } + if a == m0 { // mem is last. + break + } + auxI := int64(i - firstArg) + if a.Op == OpDereference { + if a.MemoryArg() != m0 { + f.Fatalf("Op...LECall and OpDereference have mismatched mem, %s and %s", v.LongString(), a.LongString()) + } + // "Dereference" of addressed (probably not-SSA-eligible) value becomes Move + // TODO this will be more complicated with registers in the picture. + src := a.Args[0] + dst := f.ConstOffPtrSP(src.Type, aux.OffsetOfArg(auxI), sp) + if a.Uses == 1 { + a.reset(OpMove) + a.Pos = pos + a.Type = types.TypeMem + a.Aux = aux.TypeOfArg(auxI) + a.AuxInt = aux.SizeOfArg(auxI) + a.SetArgs3(dst, src, mem) + mem = a + } else { + mem = a.Block.NewValue3A(pos, OpMove, types.TypeMem, aux.TypeOfArg(auxI), dst, src, mem) + mem.AuxInt = aux.SizeOfArg(auxI) + } + } else { + mem = storeArg(pos, v.Block, a, aux.TypeOfArg(auxI), aux.OffsetOfArg(auxI), mem) + } + } + v.resetArgs() + return mem + } + // Step 0: rewrite the calls to convert incoming args to stores. for _, b := range f.Blocks { for _, v := range b.Values { switch v.Op { case OpStaticLECall: - // Thread the stores on the memory arg - m0 := v.MemoryArg() - mem := m0 - pos := v.Pos.WithNotStmt() - aux := v.Aux.(*AuxCall) - for i, a := range v.Args { - if a == m0 { // mem is last. - break - } - if a.Op == OpDereference { - // "Dereference" of addressed (probably not-SSA-eligible) value becomes Move - // TODO this will be more complicated with registers in the picture. - if a.MemoryArg() != m0 { - f.Fatalf("Op...LECall and OpDereference have mismatched mem, %s and %s", v.LongString(), a.LongString()) - } - src := a.Args[0] - dst := f.ConstOffPtrSP(src.Type, aux.OffsetOfArg(int64(i)), sp) - if a.Uses == 1 { - a.reset(OpMove) - a.Pos = pos - a.Type = types.TypeMem - a.Aux = aux.TypeOfArg(int64(i)) - a.AuxInt = aux.SizeOfArg(int64(i)) - a.SetArgs3(dst, src, mem) - mem = a - } else { - mem = a.Block.NewValue3A(pos, OpMove, types.TypeMem, aux.TypeOfArg(int64(i)), dst, src, mem) - mem.AuxInt = aux.SizeOfArg(int64(i)) - } - } else { - mem = storeArg(pos, b, a, aux.TypeOfArg(int64(i)), aux.OffsetOfArg(int64(i)), mem) - } - } - v.resetArgs() + mem := rewriteArgs(v, 0) v.SetArgs1(mem) + case OpClosureLECall: + code := v.Args[0] + context := v.Args[1] + mem := rewriteArgs(v, 2) + v.SetArgs3(code, context, mem) + case OpInterLECall: + code := v.Args[0] + mem := rewriteArgs(v, 1) + v.SetArgs2(code, mem) } } } @@ -370,6 +389,12 @@ func expandCalls(f *Func) { case OpStaticLECall: v.Op = OpStaticCall v.Type = types.TypeMem + case OpClosureLECall: + v.Op = OpClosureCall + v.Type = types.TypeMem + case OpInterLECall: + v.Op = OpInterCall + v.Type = types.TypeMem } } } diff --git a/src/cmd/compile/internal/ssa/gen/generic.rules b/src/cmd/compile/internal/ssa/gen/generic.rules index 39f8cc8889..588077422c 100644 --- a/src/cmd/compile/internal/ssa/gen/generic.rules +++ b/src/cmd/compile/internal/ssa/gen/generic.rules @@ -2024,6 +2024,13 @@ (InterCall [argsize] {auxCall} (Load (OffPtr [off] (ITab (IMake (Addr {itab} (SB)) _))) _) mem) && devirt(v, auxCall, itab, off) != nil => (StaticCall [int32(argsize)] {devirt(v, auxCall, itab, off)} mem) +// De-virtualize late-expanded interface calls into late-expanded static calls. +// Note that (ITab (IMake)) doesn't get rewritten until after the first opt pass, +// so this rule should trigger reliably. +// devirtLECall removes the first argument, adds the devirtualized symbol to the AuxCall, and changes the opcode +(InterLECall [argsize] {auxCall} (Load (OffPtr [off] (ITab (IMake (Addr {itab} (SB)) _))) _) ___) && devirtLESym(v, auxCall, itab, off) != + nil => devirtLECall(v, devirtLESym(v, auxCall, itab, off)) + // Move and Zero optimizations. // Move source and destination may overlap. diff --git a/src/cmd/compile/internal/ssa/gen/genericOps.go b/src/cmd/compile/internal/ssa/gen/genericOps.go index 95edff4c8c..3518dd1e3c 100644 --- a/src/cmd/compile/internal/ssa/gen/genericOps.go +++ b/src/cmd/compile/internal/ssa/gen/genericOps.go @@ -389,10 +389,12 @@ var genericOps = []opData{ // TODO(josharian): ClosureCall and InterCall should have Int32 aux // to match StaticCall's 32 bit arg size limit. // TODO(drchase,josharian): could the arg size limit be bundled into the rules for CallOff? - {name: "ClosureCall", argLength: 3, aux: "CallOff", call: true}, // arg0=code pointer, arg1=context ptr, arg2=memory. auxint=arg size. Returns memory. - {name: "StaticCall", argLength: 1, aux: "CallOff", call: true}, // call function aux.(*obj.LSym), arg0=memory. auxint=arg size. Returns memory. - {name: "InterCall", argLength: 2, aux: "CallOff", call: true}, // interface call. arg0=code pointer, arg1=memory, auxint=arg size. Returns memory. - {name: "StaticLECall", argLength: -1, aux: "CallOff", call: true}, // late-expanded static call function aux.(*ssa.AuxCall.Fn). arg0..argN-1 are inputs, argN is mem. auxint = arg size. Result is tuple of result(s), plus mem. + {name: "ClosureCall", argLength: 3, aux: "CallOff", call: true}, // arg0=code pointer, arg1=context ptr, arg2=memory. auxint=arg size. Returns memory. + {name: "StaticCall", argLength: 1, aux: "CallOff", call: true}, // call function aux.(*obj.LSym), arg0=memory. auxint=arg size. Returns memory. + {name: "InterCall", argLength: 2, aux: "CallOff", call: true}, // interface call. arg0=code pointer, arg1=memory, auxint=arg size. Returns memory. + {name: "ClosureLECall", argLength: -1, aux: "CallOff", call: true}, // late-expanded closure call. arg0=code pointer, arg1=context ptr, arg2..argN-1 are inputs, argN is mem. auxint = arg size. Result is tuple of result(s), plus mem. + {name: "StaticLECall", argLength: -1, aux: "CallOff", call: true}, // late-expanded static call function aux.(*ssa.AuxCall.Fn). arg0..argN-1 are inputs, argN is mem. auxint = arg size. Result is tuple of result(s), plus mem. + {name: "InterLECall", argLength: -1, aux: "CallOff", call: true}, // late-expanded interface call. arg0=code pointer, arg1..argN-1 are inputs, argN is mem. auxint = arg size. Result is tuple of result(s), plus mem. // Conversions: signed extensions, zero (unsigned) extensions, truncations {name: "SignExt8to16", argLength: 1, typ: "Int16"}, diff --git a/src/cmd/compile/internal/ssa/gen/rulegen.go b/src/cmd/compile/internal/ssa/gen/rulegen.go index be51a7c5f8..504ee2bd0a 100644 --- a/src/cmd/compile/internal/ssa/gen/rulegen.go +++ b/src/cmd/compile/internal/ssa/gen/rulegen.go @@ -50,8 +50,12 @@ import ( // variable ::= some token // opcode ::= one of the opcodes from the *Ops.go files +// special rules: trailing ellipsis "..." (in the outermost sexpr?) must match on both sides of a rule. +// trailing three underscore "___" in the outermost match sexpr indicate the presence of +// extra ignored args that need not appear in the replacement + // extra conditions is just a chunk of Go that evaluates to a boolean. It may use -// variables declared in the matching sexpr. The variable "v" is predefined to be +// variables declared in the matching tsexpr. The variable "v" is predefined to be // the value matched by the entire rule. // If multiple rules match, the first one in file order is selected. @@ -1019,6 +1023,19 @@ func genMatch0(rr *RuleRewrite, arch arch, match, v string, cnt map[string]int, pos = v + ".Pos" } + // If the last argument is ___, it means "don't care about trailing arguments, really" + // The likely/intended use is for rewrites that are too tricky to express in the existing pattern language + // Do a length check early because long patterns fed short (ultimately not-matching) inputs will + // do an indexing error in pattern-matching. + if op.argLength == -1 { + l := len(args) + if l == 0 || args[l-1] != "___" { + rr.add(breakf("len(%s.Args) != %d", v, l)) + } else if l > 1 && args[l-1] == "___" { + rr.add(breakf("len(%s.Args) < %d", v, l-1)) + } + } + for _, e := range []struct { name, field, dclType string }{ @@ -1159,9 +1176,6 @@ func genMatch0(rr *RuleRewrite, arch arch, match, v string, cnt map[string]int, } } - if op.argLength == -1 { - rr.add(breakf("len(%s.Args) != %d", v, len(args))) - } return pos, checkOp } diff --git a/src/cmd/compile/internal/ssa/op.go b/src/cmd/compile/internal/ssa/op.go index 9b45dd53c7..62f5cddcfc 100644 --- a/src/cmd/compile/internal/ssa/op.go +++ b/src/cmd/compile/internal/ssa/op.go @@ -127,6 +127,17 @@ func (a *AuxCall) NResults() int64 { return int64(len(a.results)) } +// LateExpansionResultType returns the result type (including trailing mem) +// for a call that will be expanded later in the SSA phase. +func (a *AuxCall) LateExpansionResultType() *types.Type { + var tys []*types.Type + for i := int64(0); i < a.NResults(); i++ { + tys = append(tys, a.TypeOfResult(i)) + } + tys = append(tys, types.TypeMem) + return types.NewResults(tys) +} + // NArgs returns the number of arguments func (a *AuxCall) NArgs() int64 { return int64(len(a.args)) diff --git a/src/cmd/compile/internal/ssa/opGen.go b/src/cmd/compile/internal/ssa/opGen.go index 1fe00c7026..9fe943c2e0 100644 --- a/src/cmd/compile/internal/ssa/opGen.go +++ b/src/cmd/compile/internal/ssa/opGen.go @@ -2732,7 +2732,9 @@ const ( OpClosureCall OpStaticCall OpInterCall + OpClosureLECall OpStaticLECall + OpInterLECall OpSignExt8to16 OpSignExt8to32 OpSignExt8to64 @@ -34851,6 +34853,13 @@ var opcodeTable = [...]opInfo{ call: true, generic: true, }, + { + name: "ClosureLECall", + auxType: auxCallOff, + argLen: -1, + call: true, + generic: true, + }, { name: "StaticLECall", auxType: auxCallOff, @@ -34858,6 +34867,13 @@ var opcodeTable = [...]opInfo{ call: true, generic: true, }, + { + name: "InterLECall", + auxType: auxCallOff, + argLen: -1, + call: true, + generic: true, + }, { name: "SignExt8to16", argLen: 1, diff --git a/src/cmd/compile/internal/ssa/rewrite.go b/src/cmd/compile/internal/ssa/rewrite.go index d9c3e455a0..9f4de83a77 100644 --- a/src/cmd/compile/internal/ssa/rewrite.go +++ b/src/cmd/compile/internal/ssa/rewrite.go @@ -764,6 +764,36 @@ func devirt(v *Value, aux interface{}, sym Sym, offset int64) *AuxCall { return StaticAuxCall(lsym, va.args, va.results) } +// de-virtualize an InterLECall +// 'sym' is the symbol for the itab +func devirtLESym(v *Value, aux interface{}, sym Sym, offset int64) *obj.LSym { + n, ok := sym.(*obj.LSym) + if !ok { + return nil + } + + f := v.Block.Func + lsym := f.fe.DerefItab(n, offset) + if f.pass.debug > 0 { + if lsym != nil { + f.Warnl(v.Pos, "de-virtualizing call") + } else { + f.Warnl(v.Pos, "couldn't de-virtualize call") + } + } + if lsym == nil { + return nil + } + return lsym +} + +func devirtLECall(v *Value, sym *obj.LSym) *Value { + v.Op = OpStaticLECall + v.Aux.(*AuxCall).Fn = sym + v.RemoveArg(0) + return v +} + // isSamePtr reports whether p1 and p2 point to the same address. func isSamePtr(p1, p2 *Value) bool { if p1 == p2 { diff --git a/src/cmd/compile/internal/ssa/rewritegeneric.go b/src/cmd/compile/internal/ssa/rewritegeneric.go index 925ff53fd1..ade0a69a10 100644 --- a/src/cmd/compile/internal/ssa/rewritegeneric.go +++ b/src/cmd/compile/internal/ssa/rewritegeneric.go @@ -124,6 +124,8 @@ func rewriteValuegeneric(v *Value) bool { return rewriteValuegeneric_OpIMake(v) case OpInterCall: return rewriteValuegeneric_OpInterCall(v) + case OpInterLECall: + return rewriteValuegeneric_OpInterLECall(v) case OpIsInBounds: return rewriteValuegeneric_OpIsInBounds(v) case OpIsNonNil: @@ -8522,6 +8524,46 @@ func rewriteValuegeneric_OpInterCall(v *Value) bool { } return false } +func rewriteValuegeneric_OpInterLECall(v *Value) bool { + // match: (InterLECall [argsize] {auxCall} (Load (OffPtr [off] (ITab (IMake (Addr {itab} (SB)) _))) _) ___) + // cond: devirtLESym(v, auxCall, itab, off) != nil + // result: devirtLECall(v, devirtLESym(v, auxCall, itab, off)) + for { + if len(v.Args) < 1 { + break + } + auxCall := auxToCall(v.Aux) + v_0 := v.Args[0] + if v_0.Op != OpLoad { + break + } + v_0_0 := v_0.Args[0] + if v_0_0.Op != OpOffPtr { + break + } + off := auxIntToInt64(v_0_0.AuxInt) + v_0_0_0 := v_0_0.Args[0] + if v_0_0_0.Op != OpITab { + break + } + v_0_0_0_0 := v_0_0_0.Args[0] + if v_0_0_0_0.Op != OpIMake { + break + } + v_0_0_0_0_0 := v_0_0_0_0.Args[0] + if v_0_0_0_0_0.Op != OpAddr { + break + } + itab := auxToSym(v_0_0_0_0_0.Aux) + v_0_0_0_0_0_0 := v_0_0_0_0_0.Args[0] + if v_0_0_0_0_0_0.Op != OpSB || !(devirtLESym(v, auxCall, itab, off) != nil) { + break + } + v.copyOf(devirtLECall(v, devirtLESym(v, auxCall, itab, off))) + return true + } + return false +} func rewriteValuegeneric_OpIsInBounds(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] @@ -18549,6 +18591,9 @@ func rewriteValuegeneric_OpPhi(v *Value) bool { // match: (Phi (Const8 [c]) (Const8 [c])) // result: (Const8 [c]) for { + if len(v.Args) != 2 { + break + } _ = v.Args[1] v_0 := v.Args[0] if v_0.Op != OpConst8 { @@ -18556,7 +18601,7 @@ func rewriteValuegeneric_OpPhi(v *Value) bool { } c := auxIntToInt8(v_0.AuxInt) v_1 := v.Args[1] - if v_1.Op != OpConst8 || auxIntToInt8(v_1.AuxInt) != c || len(v.Args) != 2 { + if v_1.Op != OpConst8 || auxIntToInt8(v_1.AuxInt) != c { break } v.reset(OpConst8) @@ -18566,6 +18611,9 @@ func rewriteValuegeneric_OpPhi(v *Value) bool { // match: (Phi (Const16 [c]) (Const16 [c])) // result: (Const16 [c]) for { + if len(v.Args) != 2 { + break + } _ = v.Args[1] v_0 := v.Args[0] if v_0.Op != OpConst16 { @@ -18573,7 +18621,7 @@ func rewriteValuegeneric_OpPhi(v *Value) bool { } c := auxIntToInt16(v_0.AuxInt) v_1 := v.Args[1] - if v_1.Op != OpConst16 || auxIntToInt16(v_1.AuxInt) != c || len(v.Args) != 2 { + if v_1.Op != OpConst16 || auxIntToInt16(v_1.AuxInt) != c { break } v.reset(OpConst16) @@ -18583,6 +18631,9 @@ func rewriteValuegeneric_OpPhi(v *Value) bool { // match: (Phi (Const32 [c]) (Const32 [c])) // result: (Const32 [c]) for { + if len(v.Args) != 2 { + break + } _ = v.Args[1] v_0 := v.Args[0] if v_0.Op != OpConst32 { @@ -18590,7 +18641,7 @@ func rewriteValuegeneric_OpPhi(v *Value) bool { } c := auxIntToInt32(v_0.AuxInt) v_1 := v.Args[1] - if v_1.Op != OpConst32 || auxIntToInt32(v_1.AuxInt) != c || len(v.Args) != 2 { + if v_1.Op != OpConst32 || auxIntToInt32(v_1.AuxInt) != c { break } v.reset(OpConst32) @@ -18600,6 +18651,9 @@ func rewriteValuegeneric_OpPhi(v *Value) bool { // match: (Phi (Const64 [c]) (Const64 [c])) // result: (Const64 [c]) for { + if len(v.Args) != 2 { + break + } _ = v.Args[1] v_0 := v.Args[0] if v_0.Op != OpConst64 { @@ -18607,7 +18661,7 @@ func rewriteValuegeneric_OpPhi(v *Value) bool { } c := auxIntToInt64(v_0.AuxInt) v_1 := v.Args[1] - if v_1.Op != OpConst64 || auxIntToInt64(v_1.AuxInt) != c || len(v.Args) != 2 { + if v_1.Op != OpConst64 || auxIntToInt64(v_1.AuxInt) != c { break } v.reset(OpConst64) diff --git a/src/cmd/compile/internal/ssa/value.go b/src/cmd/compile/internal/ssa/value.go index 94b8763d5d..edc43aaae7 100644 --- a/src/cmd/compile/internal/ssa/value.go +++ b/src/cmd/compile/internal/ssa/value.go @@ -348,6 +348,9 @@ func (v *Value) reset(op Op) { // It modifies v to be (Copy a). //go:noinline func (v *Value) copyOf(a *Value) { + if v == a { + return + } if v.InCache { v.Block.Func.unCache(v) }