From 3baf53aec6c2209562495d4ac1dc035c2881f6eb Mon Sep 17 00:00:00 2001 From: Keith Randall Date: Thu, 24 Apr 2025 18:37:49 -0700 Subject: [PATCH] cmd/compile: derive bounds on signed %N for N a power of 2 -N+1 <= x % N <= N-1 This is useful for cases like: func setBit(b []byte, i int) { b[i/8] |= 1<<(i%8) } The shift does not need protection against larger-than-7 cases. (It does still need protection against <0 cases.) Change-Id: Idf83101386af538548bfeb6e2928cea855610ce2 Reviewed-on: https://go-review.googlesource.com/c/go/+/672995 Reviewed-by: Jorropo Reviewed-by: David Chase LUCI-TryBot-Result: Go LUCI Reviewed-by: Michael Knyszek --- src/cmd/compile/internal/ssa/prove.go | 120 +++++++++++++++++++++++++- test/codegen/shift.go | 9 ++ 2 files changed, 128 insertions(+), 1 deletion(-) diff --git a/src/cmd/compile/internal/ssa/prove.go b/src/cmd/compile/internal/ssa/prove.go index 94f23a84aa..5617edb21f 100644 --- a/src/cmd/compile/internal/ssa/prove.go +++ b/src/cmd/compile/internal/ssa/prove.go @@ -1757,7 +1757,9 @@ func (ft *factsTable) flowLimit(v *Value) bool { case OpSub64, OpSub32, OpSub16, OpSub8: a := ft.limits[v.Args[0].ID] b := ft.limits[v.Args[1].ID] - return ft.newLimit(v, a.sub(b, uint(v.Type.Size())*8)) + sub := ft.newLimit(v, a.sub(b, uint(v.Type.Size())*8)) + mod := ft.detectSignedMod(v) + return sub || mod case OpNeg64, OpNeg32, OpNeg16, OpNeg8: a := ft.limits[v.Args[0].ID] bitsize := uint(v.Type.Size()) * 8 @@ -1913,6 +1915,122 @@ func (ft *factsTable) flowLimit(v *Value) bool { return false } +// See if we can get any facts because v is the result of signed mod by a constant. +// The mod operation has already been rewritten, so we have to try and reconstruct it. +// x % d +// is rewritten as +// x - (x / d) * d +// furthermore, the divide itself gets rewritten. If d is a power of 2 (d == 1<> k) << k +// = (x + adj) & (-1<>(w-1))>>>(w-k)) & (-1<> = signed shift, >>> = unsigned shift). + // See ./_gen/generic.rules, search for "Signed divide by power of 2". + + var w int64 + var addOp, andOp, constOp, sshiftOp, ushiftOp Op + switch v.Op { + case OpSub64: + w = 64 + addOp = OpAdd64 + andOp = OpAnd64 + constOp = OpConst64 + sshiftOp = OpRsh64x64 + ushiftOp = OpRsh64Ux64 + case OpSub32: + w = 32 + addOp = OpAdd32 + andOp = OpAnd32 + constOp = OpConst32 + sshiftOp = OpRsh32x64 + ushiftOp = OpRsh32Ux64 + case OpSub16: + w = 16 + addOp = OpAdd16 + andOp = OpAnd16 + constOp = OpConst16 + sshiftOp = OpRsh16x64 + ushiftOp = OpRsh16Ux64 + case OpSub8: + w = 8 + addOp = OpAdd8 + andOp = OpAnd8 + constOp = OpConst8 + sshiftOp = OpRsh8x64 + ushiftOp = OpRsh8Ux64 + default: + return false + } + + x := v.Args[0] + and := v.Args[1] + if and.Op != andOp { + return false + } + var add, mask *Value + if and.Args[0].Op == addOp && and.Args[1].Op == constOp { + add = and.Args[0] + mask = and.Args[1] + } else if and.Args[1].Op == addOp && and.Args[0].Op == constOp { + add = and.Args[1] + mask = and.Args[0] + } else { + return false + } + var ushift *Value + if add.Args[0] == x { + ushift = add.Args[1] + } else if add.Args[1] == x { + ushift = add.Args[0] + } else { + return false + } + if ushift.Op != ushiftOp { + return false + } + if ushift.Args[1].Op != OpConst64 { + return false + } + k := w - ushift.Args[1].AuxInt // Now we know k! + d := int64(1) << k // divisor + sshift := ushift.Args[0] + if sshift.Op != sshiftOp { + return false + } + if sshift.Args[0] != x { + return false + } + if sshift.Args[1].Op != OpConst64 || sshift.Args[1].AuxInt != w-1 { + return false + } + if mask.AuxInt != -d { + return false + } + + // All looks ok. x % d is at most +/- d-1. + return ft.signedMinMax(v, -d+1, d-1) +} + // getBranch returns the range restrictions added by p // when reaching b. p is the immediate dominator of b. func getBranch(sdom SparseTree, p *Block, b *Block) branch { diff --git a/test/codegen/shift.go b/test/codegen/shift.go index 98d621d352..56e8d354e6 100644 --- a/test/codegen/shift.go +++ b/test/codegen/shift.go @@ -656,3 +656,12 @@ func rsh64to8(v int64) int8 { } return x } + +// We don't need to worry about shifting +// more than the type size. +// (There is still a negative shift test, but +// no shift-too-big test.) +func signedModShift(i int) int64 { + // arm64:-"CMP",-"CSEL" + return 1 << (i % 64) +}