cmd/compile: constant fold 128-bit multiplies

The full 64x64->128 multiply comes up when using bits.Mul64.
The 64x64->64+overflow multiply comes up in unsafe.Slice when using
a constant length.

Change-Id: I298515162ca07d804b2d699d03bc957ca30a4ebc
Reviewed-on: https://go-review.googlesource.com/c/go/+/667175
Reviewed-by: Junyang Shao <shaojunyang@google.com>
Reviewed-by: Keith Randall <khr@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
This commit is contained in:
Keith Randall 2025-04-21 12:44:24 -07:00 committed by Keith Randall
parent 7a177114df
commit 7d0cb2a2ad
4 changed files with 138 additions and 1 deletions

View File

@ -139,6 +139,10 @@
(Mul64 (Const64 [c]) (Const64 [d])) => (Const64 [c*d])
(Mul32F (Const32F [c]) (Const32F [d])) && c*d == c*d => (Const32F [c*d])
(Mul64F (Const64F [c]) (Const64F [d])) && c*d == c*d => (Const64F [c*d])
(Mul32uhilo (Const32 [c]) (Const32 [d])) => (MakeTuple (Const32 <typ.UInt32> [bitsMulU32(c, d).hi]) (Const32 <typ.UInt32> [bitsMulU32(c,d).lo]))
(Mul64uhilo (Const64 [c]) (Const64 [d])) => (MakeTuple (Const64 <typ.UInt64> [bitsMulU64(c, d).hi]) (Const64 <typ.UInt64> [bitsMulU64(c,d).lo]))
(Mul32uover (Const32 [c]) (Const32 [d])) => (MakeTuple (Const32 <typ.UInt32> [bitsMulU32(c, d).lo]) (ConstBool <typ.Bool> [bitsMulU32(c,d).hi != 0]))
(Mul64uover (Const64 [c]) (Const64 [d])) => (MakeTuple (Const64 <typ.UInt64> [bitsMulU64(c, d).lo]) (ConstBool <typ.Bool> [bitsMulU64(c,d).hi != 0]))
(And8 (Const8 [c]) (Const8 [d])) => (Const8 [c&d])
(And16 (Const16 [c]) (Const16 [d])) => (Const16 [c&d])

View File

@ -2584,3 +2584,14 @@ func bitsAdd64(x, y, carry int64) (r struct{ sum, carry int64 }) {
r.sum, r.carry = int64(s), int64(c)
return
}
func bitsMulU64(x, y int64) (r struct{ hi, lo int64 }) {
hi, lo := bits.Mul64(uint64(x), uint64(y))
r.hi, r.lo = int64(hi), int64(lo)
return
}
func bitsMulU32(x, y int32) (r struct{ hi, lo int32 }) {
hi, lo := bits.Mul32(uint32(x), uint32(y))
r.hi, r.lo = int32(hi), int32(lo)
return
}

View File

@ -246,12 +246,16 @@ func rewriteValuegeneric(v *Value) bool {
return rewriteValuegeneric_OpMul32(v)
case OpMul32F:
return rewriteValuegeneric_OpMul32F(v)
case OpMul32uhilo:
return rewriteValuegeneric_OpMul32uhilo(v)
case OpMul32uover:
return rewriteValuegeneric_OpMul32uover(v)
case OpMul64:
return rewriteValuegeneric_OpMul64(v)
case OpMul64F:
return rewriteValuegeneric_OpMul64F(v)
case OpMul64uhilo:
return rewriteValuegeneric_OpMul64uhilo(v)
case OpMul64uover:
return rewriteValuegeneric_OpMul64uover(v)
case OpMul8:
@ -18766,10 +18770,62 @@ func rewriteValuegeneric_OpMul32F(v *Value) bool {
}
return false
}
func rewriteValuegeneric_OpMul32uhilo(v *Value) bool {
v_1 := v.Args[1]
v_0 := v.Args[0]
b := v.Block
typ := &b.Func.Config.Types
// match: (Mul32uhilo (Const32 [c]) (Const32 [d]))
// result: (MakeTuple (Const32 <typ.UInt32> [bitsMulU32(c, d).hi]) (Const32 <typ.UInt32> [bitsMulU32(c,d).lo]))
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpConst32 {
continue
}
c := auxIntToInt32(v_0.AuxInt)
if v_1.Op != OpConst32 {
continue
}
d := auxIntToInt32(v_1.AuxInt)
v.reset(OpMakeTuple)
v0 := b.NewValue0(v.Pos, OpConst32, typ.UInt32)
v0.AuxInt = int32ToAuxInt(bitsMulU32(c, d).hi)
v1 := b.NewValue0(v.Pos, OpConst32, typ.UInt32)
v1.AuxInt = int32ToAuxInt(bitsMulU32(c, d).lo)
v.AddArg2(v0, v1)
return true
}
break
}
return false
}
func rewriteValuegeneric_OpMul32uover(v *Value) bool {
v_1 := v.Args[1]
v_0 := v.Args[0]
b := v.Block
typ := &b.Func.Config.Types
// match: (Mul32uover (Const32 [c]) (Const32 [d]))
// result: (MakeTuple (Const32 <typ.UInt32> [bitsMulU32(c, d).lo]) (ConstBool <typ.Bool> [bitsMulU32(c,d).hi != 0]))
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpConst32 {
continue
}
c := auxIntToInt32(v_0.AuxInt)
if v_1.Op != OpConst32 {
continue
}
d := auxIntToInt32(v_1.AuxInt)
v.reset(OpMakeTuple)
v0 := b.NewValue0(v.Pos, OpConst32, typ.UInt32)
v0.AuxInt = int32ToAuxInt(bitsMulU32(c, d).lo)
v1 := b.NewValue0(v.Pos, OpConstBool, typ.Bool)
v1.AuxInt = boolToAuxInt(bitsMulU32(c, d).hi != 0)
v.AddArg2(v0, v1)
return true
}
break
}
// match: (Mul32uover <t> (Const32 [1]) x)
// result: (MakeTuple x (ConstBool <t.FieldType(1)> [false]))
for {
@ -19082,10 +19138,62 @@ func rewriteValuegeneric_OpMul64F(v *Value) bool {
}
return false
}
func rewriteValuegeneric_OpMul64uhilo(v *Value) bool {
v_1 := v.Args[1]
v_0 := v.Args[0]
b := v.Block
typ := &b.Func.Config.Types
// match: (Mul64uhilo (Const64 [c]) (Const64 [d]))
// result: (MakeTuple (Const64 <typ.UInt64> [bitsMulU64(c, d).hi]) (Const64 <typ.UInt64> [bitsMulU64(c,d).lo]))
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpConst64 {
continue
}
c := auxIntToInt64(v_0.AuxInt)
if v_1.Op != OpConst64 {
continue
}
d := auxIntToInt64(v_1.AuxInt)
v.reset(OpMakeTuple)
v0 := b.NewValue0(v.Pos, OpConst64, typ.UInt64)
v0.AuxInt = int64ToAuxInt(bitsMulU64(c, d).hi)
v1 := b.NewValue0(v.Pos, OpConst64, typ.UInt64)
v1.AuxInt = int64ToAuxInt(bitsMulU64(c, d).lo)
v.AddArg2(v0, v1)
return true
}
break
}
return false
}
func rewriteValuegeneric_OpMul64uover(v *Value) bool {
v_1 := v.Args[1]
v_0 := v.Args[0]
b := v.Block
typ := &b.Func.Config.Types
// match: (Mul64uover (Const64 [c]) (Const64 [d]))
// result: (MakeTuple (Const64 <typ.UInt64> [bitsMulU64(c, d).lo]) (ConstBool <typ.Bool> [bitsMulU64(c,d).hi != 0]))
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpConst64 {
continue
}
c := auxIntToInt64(v_0.AuxInt)
if v_1.Op != OpConst64 {
continue
}
d := auxIntToInt64(v_1.AuxInt)
v.reset(OpMakeTuple)
v0 := b.NewValue0(v.Pos, OpConst64, typ.UInt64)
v0.AuxInt = int64ToAuxInt(bitsMulU64(c, d).lo)
v1 := b.NewValue0(v.Pos, OpConstBool, typ.Bool)
v1.AuxInt = boolToAuxInt(bitsMulU64(c, d).hi != 0)
v.AddArg2(v0, v1)
return true
}
break
}
// match: (Mul64uover <t> (Const64 [1]) x)
// result: (MakeTuple x (ConstBool <t.FieldType(1)> [false]))
for {

View File

@ -6,7 +6,10 @@
package codegen
import "math/bits"
import (
"math/bits"
"unsafe"
)
// ----------------------- //
// bits.LeadingZeros //
@ -957,6 +960,17 @@ func Mul64LoOnly(x, y uint64) uint64 {
return lo
}
func Mul64Const() (uint64, uint64) {
// 7133701809754865664 == 99<<56
// arm64:"MOVD\t[$]7133701809754865664, R1", "MOVD\t[$]88, R0"
return bits.Mul64(99+88<<8, 1<<56)
}
func MulUintOverflow(p *uint64) []uint64 {
// arm64:"CMP\t[$]72"
return unsafe.Slice(p, 9)
}
// --------------- //
// bits.Div* //
// --------------- //