diff --git a/src/math/big/internal/asmgen/arch.go b/src/math/big/internal/asmgen/arch.go
index bcba3992a9..adfcff9384 100644
--- a/src/math/big/internal/asmgen/arch.go
+++ b/src/math/big/internal/asmgen/arch.go
@@ -24,10 +24,15 @@ type Arch struct {
 	// Registers.
 	regs        []string // usable general registers, in allocation order
 	reg0        string   // dedicated zero register
-	regCarry    string   // dedicated carry register
-	regAltCarry string   // dedicated secondary carry register
+	regCarry    string   // dedicated carry register, for systems with no hardware carry bits
+	regAltCarry string   // dedicated secondary carry register, for systems with no hardware carry bits
 	regTmp      string   // dedicated temporary register
+	// regShift indicates that the architecture supports
+	// using REG1>>REG2 and REG1<<REG2 as instruction operands.
+	regShift bool
diff --git a/src/math/big/internal/asmgen/arch_arm.go b/src/math/big/internal/asmgen/arch_arm.go
--- a/src/math/big/internal/asmgen/arch_arm.go
+++ b/src/math/big/internal/asmgen/arch_arm.go
@@ ... @@
-func armRsh(a *Asm, shift, src, dst Reg) {
-	a.Printf("\tMOVW %s>>%s, %s\n", src, strings.TrimPrefix(shift.String(), "$"), dst)
-}
-
 func armMulWide(a *Asm, src1, src2, dstlo, dsthi Reg) {
 	a.Printf("\tMULLU %s, %s, (%s, %s)\n", src1, src2, dsthi, dstlo)
 }
diff --git a/src/math/big/internal/asmgen/asm.go b/src/math/big/internal/asmgen/asm.go
index cc2cfc32d1..d1d8309c8f 100644
--- a/src/math/big/internal/asmgen/asm.go
+++ b/src/math/big/internal/asmgen/asm.go
@@ -328,14 +328,28 @@ func (a *Asm) Neg(src, dst Reg) {
 	}
 }
 
+// HasRegShift reports whether the architecture can use shift expressions as operands.
+func (a *Asm) HasRegShift() bool {
+	return a.Arch.regShift
+}
+
+// LshReg returns a shift-expression operand src<<shift.
+// If a.HasRegShift() == false, LshReg panics.
+func (a *Asm) LshReg(shift, src Reg) Reg {
+	if !a.HasRegShift() {
+		a.Fatalf("no reg shift")
+	}
+	return Reg{fmt.Sprintf("%s<<%s", src, strings.TrimPrefix(shift.name, "$"))}
+}
+
+// RshReg returns a shift-expression operand src>>shift.
+// If a.HasRegShift() == false, RshReg panics.
+func (a *Asm) RshReg(shift, src Reg) Reg {
+	if !a.HasRegShift() {
+		a.Fatalf("no reg shift")
+	}
+	return Reg{fmt.Sprintf("%s>>%s", src, strings.TrimPrefix(shift.name, "$"))}
+}
+
 // Rsh emits dst = src >> shift.
 // It may modify the carry flag.
 func (a *Asm) Rsh(shift, src, dst Reg) {
 	if need := a.hint(HintShiftCount); need != "" && shift.name != need && !shift.IsImm() {
 		a.Fatalf("shift count not in %s", need)
 	}
-	if a.Arch.rshF != nil {
-		a.Arch.rshF(a, shift, src, dst)
+	if a.HasRegShift() {
+		a.Mov(a.RshReg(shift, src), dst)
 		return
 	}
 	a.op3(a.Arch.rsh, shift, src, dst)
diff --git a/src/math/big/internal/asmgen/cheat.go b/src/math/big/internal/asmgen/cheat.go
index 0149d9ac56..9faf6d0483 100644
--- a/src/math/big/internal/asmgen/cheat.go
+++ b/src/math/big/internal/asmgen/cheat.go
@@ -36,16 +36,19 @@ func loop(x int) int {
 	s := 0
 	for i := 1; i < x; i++ {
 		s += i
-		if s == 98 {
+		if s == 98 { // useful for jmpEqual
 			return 99
 		}
 		if s == 99 {
 			return 100
 		}
-		if s == 0 {
+		if s == 0 { // useful for jmpZero
 			return 101
 		}
-		s += 2
+		if s != 0 { // useful for jmpNonZero
+			s *= 3
+		}
+		s += 2 // keep last condition from being inverted
 	}
 	return s
 }
diff --git a/src/math/big/internal/asmgen/main.go b/src/math/big/internal/asmgen/main.go
index 7f7f36c89f..9ca231f24b 100644
--- a/src/math/big/internal/asmgen/main.go
+++ b/src/math/big/internal/asmgen/main.go
@@ -33,5 +33,9 @@ func generate(arch *Arch) (file string, data []byte) {
 	a := NewAsm(arch)
 	addOrSubVV(a, "addVV")
 	addOrSubVV(a, "subVV")
+	shiftVU(a, "lshVU")
+	shiftVU(a, "rshVU")
+	mulAddVWW(a)
+	addMulVVWW(a)
 	return file, a.out.Bytes()
 }
diff --git a/src/math/big/internal/asmgen/mul.go b/src/math/big/internal/asmgen/mul.go
new file mode 100644
index 0000000000..007bfc9060
--- /dev/null
+++ b/src/math/big/internal/asmgen/mul.go
@@ -0,0 +1,320 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package asmgen
+
+// mulAddVWW generates mulAddVWW, which does z, c = x*m + a.
+func mulAddVWW(a *Asm) {
+	f := a.Func("func mulAddVWW(z, x []Word, m, a Word) (c Word)")
+
+	if a.AltCarry().Valid() {
+		addMulVirtualCarry(f, 0)
+		return
+	}
+	addMul(f, "", "x", 0)
+}
+
+// addMulVVWW generates addMulVVWW, which does z, c = x + y*m + a.
+// (A more pedantic name would be addMulAddVVWW.)
+func addMulVVWW(a *Asm) {
+	f := a.Func("func addMulVVWW(z, x, y []Word, m, a Word) (c Word)")
+
+	// If the architecture has virtual carries, emit that version unconditionally.
+	if a.AltCarry().Valid() {
+		addMulVirtualCarry(f, 1)
+		return
+	}
+
+	// If the architecture optionally has two carries, test and emit both versions.
+	if a.JmpEnable(OptionAltCarry, "altcarry") {
+		regs := a.RegsUsed()
+		addMul(f, "x", "y", 1)
+		a.Label("altcarry")
+		a.SetOption(OptionAltCarry, true)
+		a.SetRegsUsed(regs)
+		addMulAlt(f)
+		a.SetOption(OptionAltCarry, false)
+		return
+	}
+
+	// Otherwise emit the one-carry form.
+	addMul(f, "x", "y", 1)
+}
+
+// Computing z = addsrc + m*mulsrc + a, we need:
+//
+//	for i := range z {
+//		lo, hi := m * mulsrc[i]
+//		lo, carry = bits.Add(lo, a, 0)
+//		lo, carryAlt = bits.Add(lo, addsrc[i], 0)
+//		z[i] = lo
+//		a = hi + carry + carryAlt // cannot overflow
+//	}
+//
+// The final addition cannot overflow because after processing N words,
+// the maximum possible value is (for a 64-bit system):
+//
+//	(2**64N - 1) + (2**64 - 1)*(2**64N - 1) + (2**64 - 1)
+//		= (2**64)*(2**64N - 1) + (2**64 - 1)
+//		= 2**64(N+1) - 1,
+//
+// which fits in N+1 words (the high order one being the new value of a).
+//
+// (For example, with 3 decimal words, 999 + 9*999 + 9 = 999*10 + 9 = 9999.)
+//
+// If we unroll the loop a bit, then we can chain the carries in two passes.
+// Consider:
+//
+//	lo0, hi0 := m * mulsrc[i]
+//	lo0, carry = bits.Add(lo0, a, 0)
+//	lo0, carryAlt = bits.Add(lo0, addsrc[i], 0)
+//	z[i] = lo0
+//	a = hi0 + carry + carryAlt // cannot overflow
+//
+//	lo1, hi1 := m * mulsrc[i+1]
+//	lo1, carry = bits.Add(lo1, a, 0)
+//	lo1, carryAlt = bits.Add(lo1, addsrc[i+1], 0)
+//	z[i+1] = lo1
+//	a = hi1 + carry + carryAlt // cannot overflow
+//
+//	lo2, hi2 := m * mulsrc[i+2]
+//	lo2, carry = bits.Add(lo2, a, 0)
+//	lo2, carryAlt = bits.Add(lo2, addsrc[i+2], 0)
+//	z[i+2] = lo2
+//	a = hi2 + carry + carryAlt // cannot overflow
+//
+//	lo3, hi3 := m * mulsrc[i+3]
+//	lo3, carry = bits.Add(lo3, a, 0)
+//	lo3, carryAlt = bits.Add(lo3, addsrc[i+3], 0)
+//	z[i+3] = lo3
+//	a = hi3 + carry + carryAlt // cannot overflow
+//
+// There are three ways we can optimize this sequence.
+//
+// (1) Reordering, we can chain the carries so that we still use a single hardware carry flag
+// but amortize the cost of saving and restoring it across multiple instructions:
+//
+//	// multiply
+//	lo0, hi0 := m * mulsrc[i]
+//	lo1, hi1 := m * mulsrc[i+1]
+//	lo2, hi2 := m * mulsrc[i+2]
+//	lo3, hi3 := m * mulsrc[i+3]
+//
+//	lo0, carry = bits.Add(lo0, a, 0)
+//	lo1, carry = bits.Add(lo1, hi0, carry)
+//	lo2, carry = bits.Add(lo2, hi1, carry)
+//	lo3, carry = bits.Add(lo3, hi2, carry)
+//	a = hi3 + carry // cannot overflow
+//
+//	// add
+//	lo0, carryAlt = bits.Add(lo0, addsrc[i], 0)
+//	lo1, carryAlt = bits.Add(lo1, addsrc[i+1], carryAlt)
+//	lo2, carryAlt = bits.Add(lo2, addsrc[i+2], carryAlt)
+//	lo3, carryAlt = bits.Add(lo3, addsrc[i+3], carryAlt)
+//	a = a + carryAlt // cannot overflow
+//
+//	z[i] = lo0
+//	z[i+1] = lo1
+//	z[i+2] = lo2
+//	z[i+3] = lo3
+//
+// addMul takes this approach, using the hardware carry flag
+// first for carry and then for carryAlt.
+//
+// (2) addMulAlt assumes there are two hardware carry flags available.
+// It dedicates one each to carry and carryAlt, so that a multi-block
+// unrolling can keep the flags in hardware across all the blocks.
+// So even if the block size is 1, the code can do:
+//
+//	// multiply and add
+//	lo0, hi0 := m * mulsrc[i]
+//	lo0, carry = bits.Add(lo0, a, 0)
+//	lo0, carryAlt = bits.Add(lo0, addsrc[i], 0)
+//	z[i] = lo0
+//
+//	lo1, hi1 := m * mulsrc[i+1]
+//	lo1, carry = bits.Add(lo1, hi0, carry)
+//	lo1, carryAlt = bits.Add(lo1, addsrc[i+1], carryAlt)
+//	z[i+1] = lo1
+//
+//	lo2, hi2 := m * mulsrc[i+2]
+//	lo2, carry = bits.Add(lo2, hi1, carry)
+//	lo2, carryAlt = bits.Add(lo2, addsrc[i+2], carryAlt)
+//	z[i+2] = lo2
+//
+//	lo3, hi3 := m * mulsrc[i+3]
+//	lo3, carry = bits.Add(lo3, hi2, carry)
+//	lo3, carryAlt = bits.Add(lo3, addsrc[i+3], carryAlt)
+//	z[i+3] = lo3
+//
+//	a = hi3 + carry + carryAlt // cannot overflow
+//
+// (3) addMulVirtualCarry optimizes for systems with explicitly computed carry bits
+// (loong64, mips, riscv64), cutting the number of actual instructions almost by half.
+// Look again at the original word-at-a-time version:
+//
+//	lo1, hi1 := m * mulsrc[i]
+//	lo1, carry = bits.Add(lo1, a, 0)
+//	lo1, carryAlt = bits.Add(lo1, addsrc[i], 0)
+//	z[i] = lo1
+//	a = hi1 + carry + carryAlt // cannot overflow
+//
+// Although it uses four adds per word, those are cheap adds: the two bits.Add adds
+// use two instructions each (ADD+SLTU) and the final + adds use only one ADD each,
+// for a total of 6 instructions per word. In contrast, the middle stanzas in (2) use
+// only two “adds” per word, but these are SetCarry|UseCarry adds, which compile to
+// five instructions each, for a total of 10 instructions per word. So the word-at-a-time
+// loop is actually better.
+// And we can reorder things slightly to use only a single carry bit:
+//
+//	lo1, hi1 := m * mulsrc[i]
+//	lo1, carry = bits.Add(lo1, a, 0)
+//	a = hi1 + carry
+//	lo1, carry = bits.Add(lo1, addsrc[i], 0)
+//	a = a + carry
+//	z[i] = lo1
+func addMul(f *Func, addsrc, mulsrc string, mulIndex int) {
+	a := f.Asm
+	mh := HintNone
+	if a.Arch == Arch386 && addsrc != "" {
+		mh = HintMemOK // too few registers otherwise
+	}
+	m := f.ArgHint("m", mh)
+	c := f.Arg("a")
+	n := f.Arg("z_len")
+
+	p := f.Pipe()
+	if addsrc != "" {
+		p.SetHint(addsrc, HintMemOK)
+	}
+	p.SetHint(mulsrc, HintMulSrc)
+	unroll := []int{1, 4}
+	switch a.Arch {
+	case Arch386:
+		unroll = []int{1} // too few registers
+	case ArchARM:
+		p.SetMaxColumns(2) // too few registers (but more than 386)
+	case ArchARM64:
+		unroll = []int{1, 8} // 5% speedup on c4as16
+	}
+
+	// See the large comment above for an explanation of the code being generated.
+	// This is optimization strategy (1).
+	p.Start(n, unroll...)
+	p.Loop(func(in, out [][]Reg) {
+		a.Comment("multiply")
+		prev := c
+		flag := SetCarry
+		for i, x := range in[mulIndex] {
+			hi := a.RegHint(HintMulHi)
+			a.MulWide(m, x, x, hi)
+			a.Add(prev, x, x, flag)
+			flag = UseCarry | SetCarry
+			if prev != c {
+				a.Free(prev)
+			}
+			out[0][i] = x
+			prev = hi
+		}
+		a.Add(a.Imm(0), prev, c, UseCarry|SmashCarry)
+		if addsrc != "" {
+			a.Comment("add")
+			flag := SetCarry
+			for i, x := range in[0] {
+				a.Add(x, out[0][i], out[0][i], flag)
+				flag = UseCarry | SetCarry
+			}
+			a.Add(a.Imm(0), c, c, UseCarry|SmashCarry)
+		}
+		p.StoreN(out)
+	})
+
+	f.StoreArg(c, "c")
+	a.Ret()
+}
+
+func addMulAlt(f *Func) {
+	a := f.Asm
+	m := f.ArgHint("m", HintMulSrc)
+	c := f.Arg("a")
+	n := f.Arg("z_len")
+
+	// On amd64, we need a non-immediate for the AtUnrollEnd adds.
+	r0 := a.ZR()
+	if !r0.Valid() {
+		r0 = a.Reg()
+		a.Mov(a.Imm(0), r0)
+	}
+
+	p := f.Pipe()
+	p.SetLabel("alt")
+	p.SetHint("x", HintMemOK)
+	p.SetHint("y", HintMemOK)
+	if a.Arch == ArchAMD64 {
+		p.SetMaxColumns(2)
+	}
+
+	// See the large comment above for an explanation of the code being generated.
+	// This is optimization strategy (2).
+	var hi Reg
+	prev := c
+	p.Start(n, 1, 8)
+	p.AtUnrollStart(func() {
+		a.Comment("multiply and add")
+		a.ClearCarry(AddCarry | AltCarry)
+		a.ClearCarry(AddCarry)
+		hi = a.Reg()
+	})
+	p.AtUnrollEnd(func() {
+		a.Add(r0, prev, c, UseCarry|SmashCarry)
+		a.Add(r0, c, c, UseCarry|SmashCarry|AltCarry)
+		prev = c
+	})
+	p.Loop(func(in, out [][]Reg) {
+		for i, y := range in[1] {
+			x := in[0][i]
+			lo := y
+			if lo.IsMem() {
+				lo = a.Reg()
+			}
+			a.MulWide(m, y, lo, hi)
+			a.Add(prev, lo, lo, UseCarry|SetCarry)
+			a.Add(x, lo, lo, UseCarry|SetCarry|AltCarry)
+			out[0][i] = lo
+			prev, hi = hi, prev
+		}
+		p.StoreN(out)
+	})
+
+	f.StoreArg(c, "c")
+	a.Ret()
+}
+
+func addMulVirtualCarry(f *Func, mulIndex int) {
+	a := f.Asm
+	m := f.Arg("m")
+	c := f.Arg("a")
+	n := f.Arg("z_len")
+
+	// See the large comment above for an explanation of the code being generated.
+	// This is optimization strategy (3).
+	p := f.Pipe()
+	p.Start(n, 1, 4)
+	p.Loop(func(in, out [][]Reg) {
+		a.Comment("synthetic carry, one column at a time")
+		lo, hi := a.Reg(), a.Reg()
+		for i, x := range in[mulIndex] {
+			a.MulWide(m, x, lo, hi)
+			if mulIndex == 1 {
+				a.Add(in[0][i], lo, lo, SetCarry)
+				a.Add(a.Imm(0), hi, hi, UseCarry|SmashCarry)
+			}
+			a.Add(c, lo, x, SetCarry)
+			a.Add(a.Imm(0), hi, c, UseCarry|SmashCarry)
+			out[0][i] = x
+		}
+		p.StoreN(out)
+	})
+	f.StoreArg(c, "c")
+	a.Ret()
+}
diff --git a/src/math/big/internal/asmgen/shift.go b/src/math/big/internal/asmgen/shift.go
new file mode 100644
index 0000000000..6ece599a4b
--- /dev/null
+++ b/src/math/big/internal/asmgen/shift.go
@@ -0,0 +1,135 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package asmgen
+
+// shiftVU generates lshVU and rshVU, which do
+// z, c = x << s and z, c = x >> s, for 0 < s < _W.
+func shiftVU(a *Asm, name string) {
+	// Because these routines can be called for z.Lsh(z, N) and z.Rsh(z, N),
+	// the input and output slices may be aliased at different offsets.
+	// For example (on 64-bit systems), during z.Lsh(z, 65), &z[0] == &x[1],
+	// and during z.Rsh(z, 65), &z[1] == &x[0].
+	// For left shift, we must process the slices from len(z)-1 down to 0,
+	// so that we don't overwrite a word before we need to read it.
+	// For right shift, we must process the slices from 0 up to len(z)-1.
+	// The different traversals at least make the two cases more consistent,
+	// since we're always delaying the output by one word compared
+	// to the input.
+
+	f := a.Func("func " + name + "(z, x []Word, s uint) (c Word)")
+
+	// Check for no input early, since we need to start by reading 1 word.
+	n := f.Arg("z_len")
+	a.JmpZero(n, "ret0")
+
+	// Start loop by reading first input word.
+	s := f.ArgHint("s", HintShiftCount)
+	p := f.Pipe()
+	if name == "lshVU" {
+		p.SetBackward()
+	}
+	unroll := []int{1, 4}
+	if a.Arch == Arch386 {
+		unroll = []int{1} // too few registers for more
+		p.SetUseIndexCounter()
+	}
+	p.LoadPtrs(n)
+	a.Comment("shift first word into carry")
+	prev := p.LoadN(1)[0][0]
+
+	// Decide how to shift. On systems with a wide shift (x86), use that.
+	// Otherwise, we need shift by s and negative (reverse) shift by 64-s or 32-s.
+	shift := a.Lsh
+	shiftWide := a.LshWide
+	negShift := a.Rsh
+	negShiftReg := a.RshReg
+	if name == "rshVU" {
+		shift = a.Rsh
+		shiftWide = a.RshWide
+		negShift = a.Lsh
+		negShiftReg = a.LshReg
+	}
+	if a.Arch.HasShiftWide() {
+		// Use wide shift to avoid needing negative shifts.
+		// The invariant is that prev holds the previous word (not shifted at all),
+		// to be used as input into the wide shift.
+		// After the loop finishes, prev holds the final output word to be written.
+		c := a.Reg()
+		shiftWide(s, prev, a.Imm(0), c)
+		f.StoreArg(c, "c")
+		a.Free(c)
+		a.Comment("shift remaining words")
+		p.Start(n, unroll...)
+		p.Loop(func(in [][]Reg, out [][]Reg) {
+			// We reuse the input registers as output, delayed one cycle; prev is the first output.
+			// After writing the outputs to memory, we can copy the final x value into prev
+			// for the next iteration.
+			old := prev
+			for i, x := range in[0] {
+				shiftWide(s, x, old, old)
+				out[0][i] = old
+				old = x
+			}
+			p.StoreN(out)
+			a.Mov(old, prev)
+		})
+		a.Comment("store final shifted bits")
+		shift(s, prev, prev)
+	} else {
+		// Construct values from x << s and x >> (64-s).
+		// After the first word has been processed, the invariant is that
+		// prev holds the most recent input word shifted by s (x<<s or x>>s);
+		// those bits form one half of the next output word, and the reverse
+		// shift of the next input word supplies the other half.
+		// After the loop finishes, prev holds the final output word to be written.
+		sNeg := a.Reg()
+		a.Mov(a.Imm(a.Arch.WordBits), sNeg)
+		a.Sub(s, sNeg, sNeg, SmashCarry)
+		c := a.Reg()
+		negShift(sNeg, prev, c)
+		shift(s, prev, prev)
+		f.StoreArg(c, "c")
+		a.Free(c)
+		a.Comment("shift remaining words")
+		p.Start(n, unroll...)
+		p.Loop(func(in, out [][]Reg) {
+			if a.HasRegShift() {
+				// ARM (32-bit) allows shifts in most arithmetic instructions,
+				// including OR, letting us combine the negShift and the a.Or
+				// into one instruction.
+				// The simplest way to manage the registers is to do StoreN for
+				// one output at a time, and since we don't use multi-register
+				// stores on ARM, that doesn't hurt us.
+				out[0] = out[0][:1]
+				for _, x := range in[0] {
+					a.Or(negShiftReg(sNeg, x), prev, prev)
+					out[0][0] = prev
+					p.StoreN(out)
+					shift(s, x, prev)
+				}
+				return
+			}
+			// We reuse the input registers as output, delayed one cycle; z0 is the first output.
+			z0 := a.Reg()
+			z := z0
+			for i, x := range in[0] {
+				negShift(sNeg, x, z)
+				a.Or(prev, z, z)
+				shift(s, x, prev)
+				out[0][i] = z
+				z = x
+			}
+			p.StoreN(out)
+		})
+		a.Comment("store final shifted bits")
+	}
+	p.StoreN([][]Reg{{prev}})
+	p.Done()
+	a.Free(s)
+	a.Ret()
+
+	// Return 0, used from above.
+	a.Label("ret0")
+	f.StoreArg(a.Imm(0), "c")
+	a.Ret()
+}
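As a companion to the carry argument in the mul.go comment above, here is a minimal pure-Go sketch of the word-at-a-time recurrence that addMulVVWW implements (z, c = x + y*m + a). The function name addMulVVWWRef, the package main wrapper, and the use of uint words are illustrative assumptions, not part of the patch; only math/bits is required.

package main

import (
	"fmt"
	"math/bits"
)

// addMulVVWWRef is a reference (non-assembly) version of the recurrence
// described in mul.go: z = x + y*m + a, returning the final carry word.
func addMulVVWWRef(z, x, y []uint, m, a uint) (c uint) {
	for i := range z {
		hi, lo := bits.Mul(m, y[i])           // y[i]*m
		lo, carry := bits.Add(lo, a, 0)       // add the running carry word a
		lo, carryAlt := bits.Add(lo, x[i], 0) // add the addend word x[i]
		z[i] = lo
		a = hi + carry + carryAlt // cannot overflow (see the comment in mul.go)
	}
	return a
}

func main() {
	x := []uint{^uint(0), ^uint(0)}
	y := []uint{^uint(0), ^uint(0)}
	z := make([]uint, 2)
	c := addMulVVWWRef(z, x, y, ^uint(0), ^uint(0))
	fmt.Println(z, c)
}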
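The two-pass reordering described as strategy (1) can likewise be sketched in pure Go for a single 4-word block. The helper addMulBlock4 below is hypothetical and only mirrors the pseudocode in the comment; the generated assembly keeps carry and carryAlt in the hardware carry flag rather than in variables.

package main

import (
	"fmt"
	"math/bits"
)

// addMulBlock4 processes one 4-word block as in strategy (1): one multiply
// pass chaining carry, then one add pass chaining carryAlt, so each carry
// chain stays live across a whole pass. It returns the updated carry word a.
func addMulBlock4(z, x, y []uint, m, a uint, i int) uint {
	hi0, lo0 := bits.Mul(m, y[i])
	hi1, lo1 := bits.Mul(m, y[i+1])
	hi2, lo2 := bits.Mul(m, y[i+2])
	hi3, lo3 := bits.Mul(m, y[i+3])

	// multiply pass: thread the running word a through the low halves.
	lo0, carry := bits.Add(lo0, a, 0)
	lo1, carry = bits.Add(lo1, hi0, carry)
	lo2, carry = bits.Add(lo2, hi1, carry)
	lo3, carry = bits.Add(lo3, hi2, carry)
	a = hi3 + carry // cannot overflow

	// add pass: fold in the addend words with the second carry chain.
	lo0, carryAlt := bits.Add(lo0, x[i], 0)
	lo1, carryAlt = bits.Add(lo1, x[i+1], carryAlt)
	lo2, carryAlt = bits.Add(lo2, x[i+2], carryAlt)
	lo3, carryAlt = bits.Add(lo3, x[i+3], carryAlt)
	a += carryAlt // cannot overflow

	z[i], z[i+1], z[i+2], z[i+3] = lo0, lo1, lo2, lo3
	return a
}

func main() {
	x := []uint{1, 2, 3, 4}
	y := []uint{5, 6, 7, 8}
	z := make([]uint, 4)
	fmt.Println(addMulBlock4(z, x, y, 10, 9, 0), z)
}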
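For the shift routines, the aliasing and traversal rules in shiftVU are easier to check against a scalar reference. The sketch below (lshVURef and rshVURef are made-up names, assuming len(z) == len(x) and 0 < s < _W) shows the word stitching x[i]<<s | x[i-1]>>(_W-s) that the generated lshVU and rshVU perform, with the backward traversal that keeps an in-place left shift from overwriting a word before it is read.

package main

import (
	"fmt"
	"math/bits"
)

const _W = bits.UintSize

// lshVURef computes z = x<<s and returns the bits shifted out the top.
// It walks from the top down so that z and x may alias, as in z.Lsh(z, N).
func lshVURef(z, x []uint, s uint) (c uint) {
	c = x[len(x)-1] >> (_W - s)
	for i := len(z) - 1; i > 0; i-- {
		z[i] = x[i]<<s | x[i-1]>>(_W-s)
	}
	z[0] = x[0] << s
	return c
}

// rshVURef computes z = x>>s and returns the bits shifted out the bottom.
// It walks from the bottom up so that z and x may alias, as in z.Rsh(z, N).
func rshVURef(z, x []uint, s uint) (c uint) {
	c = x[0] << (_W - s)
	for i := 0; i < len(z)-1; i++ {
		z[i] = x[i]>>s | x[i+1]<<(_W-s)
	}
	z[len(z)-1] = x[len(x)-1] >> s
	return c
}

func main() {
	z := []uint{1 << (_W - 1), 3}
	c := lshVURef(z, z, 1) // in-place, as in z.Lsh(z, 1)
	fmt.Printf("%x %x\n", z, c)
}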