Source file src/math/big/internal/asmgen/mul.go
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package asmgen

// mulAddVWW generates mulAddVWW, which does z, c = x*m + a.
func mulAddVWW(a *Asm) {
	f := a.Func("func mulAddVWW(z, x []Word, m, a Word) (c Word)")

	if a.AltCarry().Valid() {
		addMulVirtualCarry(f, 0)
		return
	}
	addMul(f, "", "x", 0)
}

// addMulVVWW generates addMulVVWW, which does z, c = x + y*m + a.
// (A more pedantic name would be addMulAddVVWW.)
func addMulVVWW(a *Asm) {
	f := a.Func("func addMulVVWW(z, x, y []Word, m, a Word) (c Word)")

	// If the architecture has virtual carries, emit that version unconditionally.
	if a.AltCarry().Valid() {
		addMulVirtualCarry(f, 1)
		return
	}

	// If the architecture optionally has two carries, test and emit both versions.
	if a.JmpEnable(OptionAltCarry, "altcarry") {
		regs := a.RegsUsed()
		addMul(f, "x", "y", 1)
		a.Label("altcarry")
		a.SetOption(OptionAltCarry, true)
		a.SetRegsUsed(regs)
		addMulAlt(f)
		a.SetOption(OptionAltCarry, false)
		return
	}

	// Otherwise emit the one-carry form.
	addMul(f, "x", "y", 1)
}

// Computing z = addsrc + m*mulsrc + a, we need:
//
//	for i := range z {
//		lo, hi := m * mulsrc[i]
//		lo, carry = bits.Add(lo, a, 0)
//		lo, carryAlt = bits.Add(lo, addsrc[i], 0)
//		z[i] = lo
//		a = hi + carry + carryAlt // cannot overflow
//	}
//
// The final addition cannot overflow because after processing N words,
// the maximum possible value is (for a 64-bit system):
//
//	(2**(64N) - 1) + (2**64 - 1)*(2**(64N) - 1) + (2**64 - 1)
//		= (2**64)*(2**(64N) - 1) + (2**64 - 1)
//		= 2**(64(N+1)) - 1,
//
// which fits in N+1 words (the high-order one being the new value of a).
//
// (For example, with 3 decimal words, 999 + 9*999 + 9 = 999*10 + 9 = 9999.)
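//
// As a concrete cross-check, the loop above can be written in portable Go
// using math/bits. This is an illustrative sketch, not code emitted by this
// generator: addMulVVWWRef is a hypothetical name, and the uint conversions
// assume Word is one machine word, as in math/big. (mulAddVWW is the
// addsrc-free specialization: drop the second bits.Add.)
//
//	func addMulVVWWRef(z, x, y []Word, m, a Word) (c Word) {
//		for i := range z {
//			hi, lo := bits.Mul(uint(m), uint(y[i])) // lo, hi := m * mulsrc[i]
//			lo, cc := bits.Add(lo, uint(a), 0)
//			lo, ccAlt := bits.Add(lo, uint(x[i]), 0)
//			z[i] = Word(lo)
//			a = Word(hi + cc + ccAlt) // cannot overflow, per the bound above
//		}
//		return a
//	}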
//
// If we unroll the loop a bit, then we can chain the carries in two passes.
// Consider:
//
//	lo0, hi0 := m * mulsrc[i]
//	lo0, carry = bits.Add(lo0, a, 0)
//	lo0, carryAlt = bits.Add(lo0, addsrc[i], 0)
//	z[i] = lo0
//	a = hi0 + carry + carryAlt // cannot overflow
//
//	lo1, hi1 := m * mulsrc[i+1]
//	lo1, carry = bits.Add(lo1, a, 0)
//	lo1, carryAlt = bits.Add(lo1, addsrc[i+1], 0)
//	z[i+1] = lo1
//	a = hi1 + carry + carryAlt // cannot overflow
//
//	lo2, hi2 := m * mulsrc[i+2]
//	lo2, carry = bits.Add(lo2, a, 0)
//	lo2, carryAlt = bits.Add(lo2, addsrc[i+2], 0)
//	z[i+2] = lo2
//	a = hi2 + carry + carryAlt // cannot overflow
//
//	lo3, hi3 := m * mulsrc[i+3]
//	lo3, carry = bits.Add(lo3, a, 0)
//	lo3, carryAlt = bits.Add(lo3, addsrc[i+3], 0)
//	z[i+3] = lo3
//	a = hi3 + carry + carryAlt // cannot overflow
//
// There are three ways we can optimize this sequence.
//
// (1) Reordering, we can chain carries so that we can use one hardware carry flag
// but amortize the cost of saving and restoring it across multiple instructions:
//
//	// multiply
//	lo0, hi0 := m * mulsrc[i]
//	lo1, hi1 := m * mulsrc[i+1]
//	lo2, hi2 := m * mulsrc[i+2]
//	lo3, hi3 := m * mulsrc[i+3]
//
//	lo0, carry = bits.Add(lo0, a, 0)
//	lo1, carry = bits.Add(lo1, hi0, carry)
//	lo2, carry = bits.Add(lo2, hi1, carry)
//	lo3, carry = bits.Add(lo3, hi2, carry)
//	a = hi3 + carry // cannot overflow
//
//	// add
//	lo0, carryAlt = bits.Add(lo0, addsrc[i], 0)
//	lo1, carryAlt = bits.Add(lo1, addsrc[i+1], carryAlt)
//	lo2, carryAlt = bits.Add(lo2, addsrc[i+2], carryAlt)
//	lo3, carryAlt = bits.Add(lo3, addsrc[i+3], carryAlt)
//	a = a + carryAlt // cannot overflow
//
//	z[i] = lo0
//	z[i+1] = lo1
//	z[i+2] = lo2
//	z[i+3] = lo3
//
// addMul takes this approach, using the hardware carry flag
// first for carry and then for carryAlt.
//
// (2) addMulAlt assumes there are two hardware carry flags available.
// It dedicates one each to carry and carryAlt, so that a multi-block
// unrolling can keep the flags in hardware across all the blocks.
// So even if the block size is 1, the code can do:
//
//	// multiply and add
//	lo0, hi0 := m * mulsrc[i]
//	lo0, carry = bits.Add(lo0, a, 0)
//	lo0, carryAlt = bits.Add(lo0, addsrc[i], 0)
//	z[i] = lo0
//
//	lo1, hi1 := m * mulsrc[i+1]
//	lo1, carry = bits.Add(lo1, hi0, carry)
//	lo1, carryAlt = bits.Add(lo1, addsrc[i+1], carryAlt)
//	z[i+1] = lo1
//
//	lo2, hi2 := m * mulsrc[i+2]
//	lo2, carry = bits.Add(lo2, hi1, carry)
//	lo2, carryAlt = bits.Add(lo2, addsrc[i+2], carryAlt)
//	z[i+2] = lo2
//
//	lo3, hi3 := m * mulsrc[i+3]
//	lo3, carry = bits.Add(lo3, hi2, carry)
//	lo3, carryAlt = bits.Add(lo3, addsrc[i+3], carryAlt)
//	z[i+3] = lo3
//
//	a = hi3 + carry + carryAlt // cannot overflow
//
// (3) addMulVirtualCarry optimizes for systems with explicitly computed carry bits
// (loong64, mips, riscv64), cutting the number of actual instructions almost in half.
// Look again at the original word-at-a-time version:
//
//	lo1, hi1 := m * mulsrc[i]
//	lo1, carry = bits.Add(lo1, a, 0)
//	lo1, carryAlt = bits.Add(lo1, addsrc[i], 0)
//	z[i] = lo1
//	a = hi1 + carry + carryAlt // cannot overflow
//
// Although it uses four adds per word, those are cheap adds: the two bits.Add adds
// use two instructions each (ADD+SLTU) and the final + adds only use one ADD each,
// for a total of 6 instructions per word. In contrast, the middle stanzas in (2) use
// only two “adds” per word, but these are SetCarry|UseCarry adds, which compile to
// five instructions each, for a total of 10 instructions per word. So the word-at-a-time
// loop is actually better.
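//
// (Concretely, on these systems a 2-instruction bits.Add(x, y, 0) lowers to
// the shape below, the comparison being the SLTU; this is a sketch of the
// lowering, not literal generator output:
//
//	sum := x + y // ADD
//	carry := 0
//	if sum < x { // SLTU
//		carry = 1
//	}
//
// A SetCarry|UseCarry add must do this twice, once for the carry-in and once
// for the carry-out, plus combine the two carry bits, hence five instructions.)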
// And we can reorder things slightly to use only a single carry bit:
//
//	lo1, hi1 := m * mulsrc[i]
//	lo1, carry = bits.Add(lo1, a, 0)
//	a = hi1 + carry
//	lo1, carry = bits.Add(lo1, addsrc[i], 0)
//	a = a + carry
//	z[i] = lo1
func addMul(f *Func, addsrc, mulsrc string, mulIndex int) {
	a := f.Asm
	mh := HintNone
	if a.Arch == Arch386 && addsrc != "" {
		mh = HintMemOK // too few registers otherwise
	}
	m := f.ArgHint("m", mh)
	c := f.Arg("a")
	n := f.Arg("z_len")

	p := f.Pipe()
	if addsrc != "" {
		p.SetHint(addsrc, HintMemOK)
	}
	p.SetHint(mulsrc, HintMulSrc)
	unroll := []int{1, 4}
	switch a.Arch {
	case Arch386:
		unroll = []int{1} // too few registers
	case ArchARM:
		p.SetMaxColumns(2) // too few registers (but more than 386)
	case ArchARM64:
		unroll = []int{1, 8} // 5% speedup on c4as16
	}

	// See the large comment above for an explanation of the code being generated.
	// This is optimization strategy (1).
	p.Start(n, unroll...)
	p.Loop(func(in, out [][]Reg) {
		a.Comment("multiply")
		prev := c
		flag := SetCarry
		for i, x := range in[mulIndex] {
			hi := a.RegHint(HintMulHi)
			a.MulWide(m, x, x, hi) // x, hi = low and high words of m*x
			a.Add(prev, x, x, flag) // chain in previous high word (initially a)
			flag = UseCarry | SetCarry
			if prev != c {
				a.Free(prev)
			}
			out[0][i] = x
			prev = hi
		}
		a.Add(a.Imm(0), prev, c, UseCarry|SmashCarry)
		if addsrc != "" {
			a.Comment("add")
			flag := SetCarry
			for i, x := range in[0] {
				a.Add(x, out[0][i], out[0][i], flag)
				flag = UseCarry | SetCarry
			}
			a.Add(a.Imm(0), c, c, UseCarry|SmashCarry)
		}
		p.StoreN(out)
	})

	f.StoreArg(c, "c")
	a.Ret()
}

func addMulAlt(f *Func) {
	a := f.Asm
	m := f.ArgHint("m", HintMulSrc)
	c := f.Arg("a")
	n := f.Arg("z_len")

	// On amd64, we need a non-immediate for the AtUnrollEnd adds.
	r0 := a.ZR()
	if !r0.Valid() {
		r0 = a.Reg()
		a.Mov(a.Imm(0), r0)
	}

	p := f.Pipe()
	p.SetLabel("alt")
	p.SetHint("x", HintMemOK)
	p.SetHint("y", HintMemOK)
	if a.Arch == ArchAMD64 {
		p.SetMaxColumns(2)
	}

	// See the large comment above for an explanation of the code being generated.
	// This is optimization strategy (2).
	var hi Reg
	prev := c
	p.Start(n, 1, 8)
	p.AtUnrollStart(func() {
		a.Comment("multiply and add")
		a.ClearCarry(AddCarry | AltCarry)
		a.ClearCarry(AddCarry)
		hi = a.Reg()
	})
	p.AtUnrollEnd(func() {
		a.Add(r0, prev, c, UseCarry|SmashCarry)
		a.Add(r0, c, c, UseCarry|SmashCarry|AltCarry)
		prev = c
	})
	p.Loop(func(in, out [][]Reg) {
		for i, y := range in[1] {
			x := in[0][i]
			lo := y
			if lo.IsMem() {
				lo = a.Reg()
			}
			a.MulWide(m, y, lo, hi)
			a.Add(prev, lo, lo, UseCarry|SetCarry)
			a.Add(x, lo, lo, UseCarry|SetCarry|AltCarry)
			out[0][i] = lo
			prev, hi = hi, prev // this hi feeds the next column; old prev register is recycled
		}
		p.StoreN(out)
	})

	f.StoreArg(c, "c")
	a.Ret()
}
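
// For reference, the reordered single-carry loop above corresponds to the
// following portable-Go sequence per word, with the instruction counts from
// the discussion annotated. This is an illustrative sketch of what
// addMulVirtualCarry emits on a system like riscv64 (where MULHU supplies
// the high word), not literal generator output:
//
//	hi, lo := bits.Mul(uint(m), uint(mulsrc[i])) // MULHU + MUL
//	lo, cc := bits.Add(lo, uint(a), 0)           // ADD + SLTU
//	a = Word(hi + cc)                            // ADD
//	lo, cc = bits.Add(lo, uint(addsrc[i]), 0)    // ADD + SLTU
//	a += Word(cc)                                // ADD
//	z[i] = Word(lo)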

func addMulVirtualCarry(f *Func, mulIndex int) {
	a := f.Asm
	m := f.Arg("m")
	c := f.Arg("a")
	n := f.Arg("z_len")

	// See the large comment above for an explanation of the code being generated.
	// This is optimization strategy (3).
	p := f.Pipe()
	p.Start(n, 1, 4)
	p.Loop(func(in, out [][]Reg) {
		a.Comment("synthetic carry, one column at a time")
		lo, hi := a.Reg(), a.Reg()
		for i, x := range in[mulIndex] {
			a.MulWide(m, x, lo, hi)
			if mulIndex == 1 {
				// z = x + y*m + a: fold the addend column in first
				// (lo += x[i], then hi absorbs the synthetic carry).
				a.Add(in[0][i], lo, lo, SetCarry)
				a.Add(a.Imm(0), hi, hi, UseCarry|SmashCarry)
			}
			a.Add(c, lo, x, SetCarry)
			a.Add(a.Imm(0), hi, c, UseCarry|SmashCarry)
			out[0][i] = x
		}
		p.StoreN(out)
	})
	f.StoreArg(c, "c")
	a.Ret()
}