cmd/compile: derive bounds on signed %N for N a power of 2

-N+1 <= x % N <= N-1

This is useful for cases like:

func setBit(b []byte, i int) {
    b[i/8] |= 1<<(i%8)
}

The shift does not need protection against larger-than-7 cases.
(It does still need protection against <0 cases.)

Change-Id: Idf83101386af538548bfeb6e2928cea855610ce2
Reviewed-on: https://go-review.googlesource.com/c/go/+/672995
Reviewed-by: Jorropo <jorropo.pgm@gmail.com>
Reviewed-by: David Chase <drchase@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
This commit is contained in:
Keith Randall 2025-04-24 18:37:49 -07:00
parent 498899e205
commit 3baf53aec6
2 changed files with 128 additions and 1 deletions

View File

@ -1757,7 +1757,9 @@ func (ft *factsTable) flowLimit(v *Value) bool {
case OpSub64, OpSub32, OpSub16, OpSub8:
a := ft.limits[v.Args[0].ID]
b := ft.limits[v.Args[1].ID]
return ft.newLimit(v, a.sub(b, uint(v.Type.Size())*8))
sub := ft.newLimit(v, a.sub(b, uint(v.Type.Size())*8))
mod := ft.detectSignedMod(v)
return sub || mod
case OpNeg64, OpNeg32, OpNeg16, OpNeg8:
a := ft.limits[v.Args[0].ID]
bitsize := uint(v.Type.Size()) * 8
@ -1913,6 +1915,122 @@ func (ft *factsTable) flowLimit(v *Value) bool {
return false
}
// See if we can get any facts because v is the result of signed mod by a constant.
// The mod operation has already been rewritten, so we have to try and reconstruct it.
// x % d
// is rewritten as
// x - (x / d) * d
// furthermore, the divide itself gets rewritten. If d is a power of 2 (d == 1<<k), we do
// (x / d) * d = ((x + adj) >> k) << k
// = (x + adj) & (-1<<k)
// with adj being an adjustment in case x is negative (see below).
// if d is not a power of 2, we do
// x / d = ... TODO ...
func (ft *factsTable) detectSignedMod(v *Value) bool {
if ft.detectSignedModByPowerOfTwo(v) {
return true
}
// TODO: non-powers-of-2
return false
}
func (ft *factsTable) detectSignedModByPowerOfTwo(v *Value) bool {
// We're looking for:
//
// x % d ==
// x - (x / d) * d
//
// which for d a power of 2, d == 1<<k, is done as
//
// x - ((x + (x>>(w-1))>>>(w-k)) & (-1<<k))
//
// w = bit width of x.
// (>> = signed shift, >>> = unsigned shift).
// See ./_gen/generic.rules, search for "Signed divide by power of 2".
var w int64
var addOp, andOp, constOp, sshiftOp, ushiftOp Op
switch v.Op {
case OpSub64:
w = 64
addOp = OpAdd64
andOp = OpAnd64
constOp = OpConst64
sshiftOp = OpRsh64x64
ushiftOp = OpRsh64Ux64
case OpSub32:
w = 32
addOp = OpAdd32
andOp = OpAnd32
constOp = OpConst32
sshiftOp = OpRsh32x64
ushiftOp = OpRsh32Ux64
case OpSub16:
w = 16
addOp = OpAdd16
andOp = OpAnd16
constOp = OpConst16
sshiftOp = OpRsh16x64
ushiftOp = OpRsh16Ux64
case OpSub8:
w = 8
addOp = OpAdd8
andOp = OpAnd8
constOp = OpConst8
sshiftOp = OpRsh8x64
ushiftOp = OpRsh8Ux64
default:
return false
}
x := v.Args[0]
and := v.Args[1]
if and.Op != andOp {
return false
}
var add, mask *Value
if and.Args[0].Op == addOp && and.Args[1].Op == constOp {
add = and.Args[0]
mask = and.Args[1]
} else if and.Args[1].Op == addOp && and.Args[0].Op == constOp {
add = and.Args[1]
mask = and.Args[0]
} else {
return false
}
var ushift *Value
if add.Args[0] == x {
ushift = add.Args[1]
} else if add.Args[1] == x {
ushift = add.Args[0]
} else {
return false
}
if ushift.Op != ushiftOp {
return false
}
if ushift.Args[1].Op != OpConst64 {
return false
}
k := w - ushift.Args[1].AuxInt // Now we know k!
d := int64(1) << k // divisor
sshift := ushift.Args[0]
if sshift.Op != sshiftOp {
return false
}
if sshift.Args[0] != x {
return false
}
if sshift.Args[1].Op != OpConst64 || sshift.Args[1].AuxInt != w-1 {
return false
}
if mask.AuxInt != -d {
return false
}
// All looks ok. x % d is at most +/- d-1.
return ft.signedMinMax(v, -d+1, d-1)
}
// getBranch returns the range restrictions added by p
// when reaching b. p is the immediate dominator of b.
func getBranch(sdom SparseTree, p *Block, b *Block) branch {

View File

@ -656,3 +656,12 @@ func rsh64to8(v int64) int8 {
}
return x
}
// We don't need to worry about shifting
// more than the type size.
// (There is still a negative shift test, but
// no shift-too-big test.)
func signedModShift(i int) int64 {
// arm64:-"CMP",-"CSEL"
return 1 << (i % 64)
}