Source file src/math/big/internal/asmgen/mul.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package asmgen
     6  
     7  // mulAddVWW generates mulAddVWW, which does z, c = x*m + a.
     8  func mulAddVWW(a *Asm) {
     9  	f := a.Func("func mulAddVWW(z, x []Word, m, a Word) (c Word)")
    10  
    11  	if a.AltCarry().Valid() {
    12  		addMulVirtualCarry(f, 0)
    13  		return
    14  	}
    15  	addMul(f, "", "x", 0)
    16  }
    17  
    18  // addMulVVWW generates addMulVVWW which does z, c = x + y*m + a.
    19  // (A more pedantic name would be addMulAddVVWW.)
    20  func addMulVVWW(a *Asm) {
    21  	f := a.Func("func addMulVVWW(z, x, y []Word, m, a Word) (c Word)")
    22  
    23  	// If the architecture has virtual carries, emit that version unconditionally.
    24  	if a.AltCarry().Valid() {
    25  		addMulVirtualCarry(f, 1)
    26  		return
    27  	}
    28  
    29  	// If the architecture optionally has two carries, test and emit both versions.
    30  	if a.JmpEnable(OptionAltCarry, "altcarry") {
    31  		regs := a.RegsUsed()
    32  		addMul(f, "x", "y", 1)
    33  		a.Label("altcarry")
    34  		a.SetOption(OptionAltCarry, true)
    35  		a.SetRegsUsed(regs)
    36  		addMulAlt(f)
    37  		a.SetOption(OptionAltCarry, false)
    38  		return
    39  	}
    40  
    41  	// Otherwise emit the one-carry form.
    42  	addMul(f, "x", "y", 1)
    43  }
    44  
    45  // Computing z = addsrc + m*mulsrc + a, we need:
    46  //
    47  //	for i := range z {
    48  //		lo, hi := m * mulsrc[i]
    49  //		lo, carry = bits.Add(lo, a, 0)
    50  //		lo, carryAlt = bits.Add(lo, addsrc[i], 0)
    51  //		z[i] = lo
    52  //		a = hi + carry + carryAlt  // cannot overflow
    53  //	}
    54  //
    55  // The final addition cannot overflow because after processing N words,
    56  // the maximum possible value is (for a 64-bit system):
    57  //
    58  //	  (2**64N - 1) + (2**64 - 1)*(2**64N - 1) + (2**64 - 1)
    59  //	= (2**64)*(2**64N - 1) + (2**64 - 1)
    60  //	= 2**64(N+1) - 1,
    61  //
    62  // which fits in N+1 words (the high order one being the new value of a).
    63  //
    64  // (For example, with 3 decimal words, 999 + 9*999 + 9 = 999*10 + 9 = 9999.)
    65  //
    66  // If we unroll the loop a bit, then we can chain the carries in two passes.
    67  // Consider:
    68  //
     69  //	lo0, hi0 := m * mulsrc[i]
     70  //	lo0, carry = bits.Add(lo0, a, 0)
     71  //	lo0, carryAlt = bits.Add(lo0, addsrc[i], 0)
     72  //	z[i] = lo0
     73  //	a = hi0 + carry + carryAlt // cannot overflow
     74  //
     75  //	lo1, hi1 := m * mulsrc[i+1]
     76  //	lo1, carry = bits.Add(lo1, a, 0)
     77  //	lo1, carryAlt = bits.Add(lo1, addsrc[i+1], 0)
     78  //	z[i+1] = lo1
     79  //	a = hi1 + carry + carryAlt // cannot overflow
     80  //
     81  //	lo2, hi2 := m * mulsrc[i+2]
     82  //	lo2, carry = bits.Add(lo2, a, 0)
     83  //	lo2, carryAlt = bits.Add(lo2, addsrc[i+2], 0)
     84  //	z[i+2] = lo2
     85  //	a = hi2 + carry + carryAlt // cannot overflow
     86  //
     87  //	lo3, hi3 := m * mulsrc[i+3]
     88  //	lo3, carry = bits.Add(lo3, a, 0)
     89  //	lo3, carryAlt = bits.Add(lo3, addsrc[i+3], 0)
     90  //	z[i+3] = lo3
     91  //	a = hi3 + carry + carryAlt // cannot overflow
    92  //
    93  // There are three ways we can optimize this sequence.
    94  //
    95  // (1) Reordering, we can chain carries so that we can use one hardware carry flag
    96  // but amortize the cost of saving and restoring it across multiple instructions:
    97  //
    98  //	// multiply
    99  //	lo0, hi0 := m * mulsrc[i]
   100  //	lo1, hi1 := m * mulsrc[i+1]
   101  //	lo2, hi2 := m * mulsrc[i+2]
   102  //	lo3, hi3 := m * mulsrc[i+3]
   103  //
   104  //	lo0, carry = bits.Add(lo0, a, 0)
   105  //	lo1, carry = bits.Add(lo1, hi0, carry)
   106  //	lo2, carry = bits.Add(lo2, hi1, carry)
   107  //	lo3, carry = bits.Add(lo3, hi2, carry)
   108  //	a = hi3 + carry // cannot overflow
   109  //
   110  //	// add
   111  //	lo0, carryAlt = bits.Add(lo0, addsrc[i], 0)
   112  //	lo1, carryAlt = bits.Add(lo1, addsrc[i+1], carryAlt)
   113  //	lo2, carryAlt = bits.Add(lo2, addsrc[i+2], carryAlt)
    114  //	lo3, carryAlt = bits.Add(lo3, addsrc[i+3], carryAlt)
   115  //	a = a + carryAlt // cannot overflow
   116  //
   117  //	z[i] = lo0
   118  //	z[i+1] = lo1
   119  //	z[i+2] = lo2
   120  //	z[i+3] = lo3
   121  //
   122  // addMul takes this approach, using the hardware carry flag
   123  // first for carry and then for carryAlt.
   124  //
   125  // (2) addMulAlt assumes there are two hardware carry flags available.
   126  // It dedicates one each to carry and carryAlt, so that a multi-block
   127  // unrolling can keep the flags in hardware across all the blocks.
   128  // So even if the block size is 1, the code can do:
   129  //
   130  //	// multiply and add
   131  //	lo0, hi0 := m * mulsrc[i]
   132  //	lo0, carry = bits.Add(lo0, a, 0)
   133  //	lo0, carryAlt = bits.Add(lo0, addsrc[i], 0)
   134  //	z[i] = lo0
   135  //
   136  //	lo1, hi1 := m * mulsrc[i+1]
   137  //	lo1, carry = bits.Add(lo1, hi0, carry)
   138  //	lo1, carryAlt = bits.Add(lo1, addsrc[i+1], carryAlt)
   139  //	z[i+1] = lo1
   140  //
   141  //	lo2, hi2 := m * mulsrc[i+2]
   142  //	lo2, carry = bits.Add(lo2, hi1, carry)
   143  //	lo2, carryAlt = bits.Add(lo2, addsrc[i+2], carryAlt)
   144  //	z[i+2] = lo2
   145  //
   146  //	lo3, hi3 := m * mulsrc[i+3]
   147  //	lo3, carry = bits.Add(lo3, hi2, carry)
    148  //	lo3, carryAlt = bits.Add(lo3, addsrc[i+3], carryAlt)
    149  //	z[i+3] = lo3
   150  //
   151  //	a = hi3 + carry + carryAlt // cannot overflow
   152  //
   153  // (3) addMulVirtualCarry optimizes for systems with explicitly computed carry bits
   154  // (loong64, mips, riscv64), cutting the number of actual instructions almost by half.
   155  // Look again at the original word-at-a-time version:
   156  //
   157  //	lo1, hi1 := m * mulsrc[i]
   158  //	lo1, carry = bits.Add(lo1, a, 0)
   159  //	lo1, carryAlt = bits.Add(lo1, addsrc[i], 0)
   160  //	z[i] = lo1
    161  //	a = hi1 + carry + carryAlt // cannot overflow
   162  //
   163  // Although it uses four adds per word, those are cheap adds: the two bits.Add adds
   164  // use two instructions each (ADD+SLTU) and the final + adds only use one ADD each,
   165  // for a total of 6 instructions per word. In contrast, the middle stanzas in (2) use
   166  // only two “adds” per word, but these are SetCarry|UseCarry adds, which compile to
   167  // five instruction each, for a total of 10 instructions per word. So the word-at-a-time
   168  // loop is actually better. And we can reorder things slightly to use only a single carry bit:
   169  //
   170  //	lo1, hi1 := m * mulsrc[i]
   171  //	lo1, carry = bits.Add(lo1, a, 0)
    172  //	a = hi1 + carry
   173  //	lo1, carry = bits.Add(lo1, addsrc[i], 0)
   174  //	a = a + carry
   175  //	z[i] = lo1
// addMul generates the body of a function computing
// z = addsrc + m*mulsrc + a using a single hardware carry flag:
// optimization strategy (1) in the comment above, one multiply pass
// followed by one add pass per unrolled block.
//
// addsrc is "" when there is no vector addend (mulAddVWW).
// mulIndex selects which pipe input slice holds mulsrc
// (0 when there is no addend, 1 when addsrc occupies slice 0).
func addMul(f *Func, addsrc, mulsrc string, mulIndex int) {
	a := f.Asm
	mh := HintNone
	if a.Arch == Arch386 && addsrc != "" {
		mh = HintMemOK // too few registers otherwise
	}
	m := f.ArgHint("m", mh)
	c := f.Arg("a") // running carry word, initialized to the addend a
	n := f.Arg("z_len")

	p := f.Pipe()
	if addsrc != "" {
		p.SetHint(addsrc, HintMemOK)
	}
	p.SetHint(mulsrc, HintMulSrc)
	unroll := []int{1, 4}
	switch a.Arch {
	case Arch386:
		unroll = []int{1} // too few registers
	case ArchARM:
		p.SetMaxColumns(2) // too few registers (but more than 386)
	case ArchARM64:
		unroll = []int{1, 8} // 5% speedup on c4as16
	}

	// See the large comment above for an explanation of the code being generated.
	// This is optimization strategy (1).
	p.Start(n, unroll...)
	p.Loop(func(in, out [][]Reg) {
		// Multiply pass: each column adds prev — the previous column's
		// high word, or the incoming carry c for the first column — into
		// its low product word, chaining the hardware carry across the block.
		a.Comment("multiply")
		prev := c
		flag := SetCarry
		for i, x := range in[mulIndex] {
			hi := a.RegHint(HintMulHi)
			a.MulWide(m, x, x, hi) // x, hi = low, high words of m*x
			a.Add(prev, x, x, flag)
			flag = UseCarry | SetCarry
			if prev != c {
				a.Free(prev) // previous high word is consumed; release its register
			}
			out[0][i] = x
			prev = hi
		}
		// Fold the final carry into the last high word, producing the new c.
		a.Add(a.Imm(0), prev, c, UseCarry|SmashCarry)
		if addsrc != "" {
			// Add pass: the carry flag is free again, so chain a second
			// carry through the vector addend, then fold it into c.
			a.Comment("add")
			flag := SetCarry
			for i, x := range in[0] {
				a.Add(x, out[0][i], out[0][i], flag)
				flag = UseCarry | SetCarry
			}
			a.Add(a.Imm(0), c, c, UseCarry|SmashCarry)
		}
		p.StoreN(out)
	})

	f.StoreArg(c, "c")
	a.Ret()
}
   235  
// addMulAlt generates the body of addMulVVWW for architectures with
// two hardware carry flags: optimization strategy (2) in the comment
// above. One flag chains the multiply carries and the other (AltCarry)
// chains the add carries, so both stay in hardware across the whole
// unrolled block.
func addMulAlt(f *Func) {
	a := f.Asm
	m := f.ArgHint("m", HintMulSrc)
	c := f.Arg("a") // running carry word, initialized to the addend a
	n := f.Arg("z_len")

	// On amd64, we need a non-immediate for the AtUnrollEnd adds.
	r0 := a.ZR()
	if !r0.Valid() {
		r0 = a.Reg()
		a.Mov(a.Imm(0), r0)
	}

	p := f.Pipe()
	p.SetLabel("alt")
	p.SetHint("x", HintMemOK)
	p.SetHint("y", HintMemOK)
	if a.Arch == ArchAMD64 {
		p.SetMaxColumns(2)
	}

	// See the large comment above for an explanation of the code being generated.
	// This is optimization strategy (2).
	var hi Reg
	prev := c // prev carries each column's high word into the next column
	p.Start(n, 1, 8)
	p.AtUnrollStart(func() {
		a.Comment("multiply and add")
		// Both carry chains start cleared at the top of each unrolled block.
		// NOTE(review): two ClearCarry calls look deliberate — confirm
		// against ClearCarry's semantics for combined vs. single flags.
		a.ClearCarry(AddCarry | AltCarry)
		a.ClearCarry(AddCarry)
		hi = a.Reg()
	})
	p.AtUnrollEnd(func() {
		// Fold both pending carries into c (cannot overflow; see above).
		a.Add(r0, prev, c, UseCarry|SmashCarry)
		a.Add(r0, c, c, UseCarry|SmashCarry|AltCarry)
		prev = c
	})
	p.Loop(func(in, out [][]Reg) {
		for i, y := range in[1] {
			x := in[0][i]
			lo := y
			if lo.IsMem() {
				lo = a.Reg() // the product's low word needs a register destination
			}
			a.MulWide(m, y, lo, hi)
			a.Add(prev, lo, lo, UseCarry|SetCarry)       // multiply carry chain
			a.Add(x, lo, lo, UseCarry|SetCarry|AltCarry) // add carry chain
			out[0][i] = lo
			// Swap roles: this column's hi feeds the next column as prev,
			// and prev's register is recycled as the next hi.
			prev, hi = hi, prev
		}
		p.StoreN(out)
	})

	f.StoreArg(c, "c")
	a.Ret()
}
   292  
// addMulVirtualCarry generates the body of mulAddVWW (mulIndex == 0)
// or addMulVVWW (mulIndex == 1) for architectures with explicitly
// computed carry bits: optimization strategy (3) in the comment above.
// mulIndex selects which pipe input slice holds the multiply source;
// when it is 1, slice 0 holds the vector addend.
func addMulVirtualCarry(f *Func, mulIndex int) {
	a := f.Asm
	m := f.Arg("m")
	c := f.Arg("a") // running carry word, initialized to the addend a
	n := f.Arg("z_len")

	// See the large comment above for an explanation of the code being generated.
	// This is optimization strategy (3).
	p := f.Pipe()
	p.Start(n, 1, 4)
	p.Loop(func(in, out [][]Reg) {
		a.Comment("synthetic carry, one column at a time")
		lo, hi := a.Reg(), a.Reg()
		for i, x := range in[mulIndex] {
			a.MulWide(m, x, lo, hi) // lo, hi = low, high words of m*x
			if mulIndex == 1 {
				// Fold in the vector addend: lo += addsrc[i]; hi += carry.
				a.Add(in[0][i], lo, lo, SetCarry)
				a.Add(a.Imm(0), hi, hi, UseCarry|SmashCarry)
			}
			// x = lo + c; new c = hi + carry (cannot overflow; see above).
			a.Add(c, lo, x, SetCarry)
			a.Add(a.Imm(0), hi, c, UseCarry|SmashCarry)
			out[0][i] = x
		}
		p.StoreN(out)
	})
	f.StoreArg(c, "c")
	a.Ret()
}
   321  

View as plain text