Source file src/math/big/internal/asmgen/pipe.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package asmgen
     6  
     7  import (
     8  	"fmt"
     9  	"math/bits"
    10  	"slices"
    11  )
    12  
    13  // Note: Exported fields and methods are expected to be used
    14  // by function generators (like the ones in add.go and so on).
    15  // Unexported fields and methods should not be.
    16  
    17  // A Pipe manages the input and output data pipelines for a function's
    18  // memory operations.
    19  //
    20  // The input is one or more equal-length slices of words, so collectively
    21  // it can be viewed as a matrix, in which each slice is a row and each column
    22  // is a set of corresponding words from the different slices.
    23  // The output can be viewed the same way, although it is often just one row.
    24  type Pipe struct {
    25  	f               *Func    // function being generated
    26  	label           string   // prefix for loop labels (default "loop")
    27  	backward        bool     // processing columns in reverse
    28  	started         bool     // Start has been called
    29  	loaded          bool     // LoadPtrs has been called
    30  	inPtr           []RegPtr // input slice pointers
    31  	hints           []Hint   // for each inPtr, a register hint to use for its data
    32  	outPtr          []RegPtr // output slice pointers
    33  	index           Reg      // index register, if in use
    34  	useIndexCounter bool     // index counter requested
    35  	indexCounter    int      // index is also counter (386); 0 no, -1 negative counter, +1 positive counter
    36  	readOff         int      // read offset not yet added to index
    37  	writeOff        int      // write offset not yet added to index
    38  	factors         []int    // unrolling factors
    39  	counts          []Reg    // iterations for each factor
    40  	needWrite       bool     // need a write call during Loop1/LoopN
    41  	maxColumns      int      // maximum columns during unrolled loop
    42  	unrollStart     func()   // emit code at start of unrolled body
    43  	unrollEnd       func()   // emit code end of unrolled body
    44  }
    45  
    46  // Pipe creates and returns a new pipe for use in the function f.
    47  func (f *Func) Pipe() *Pipe {
    48  	a := f.Asm
    49  	p := &Pipe{
    50  		f:          f,
    51  		label:      "loop",
    52  		maxColumns: 10000000,
    53  	}
    54  	if m := a.Arch.maxColumns; m != 0 {
    55  		p.maxColumns = m
    56  	}
    57  	return p
    58  }
    59  
    60  // SetBackward sets the pipe to process the input and output columns in reverse order.
    61  // This is needed for left shifts, which might otherwise overwrite data they will read later.
    62  func (p *Pipe) SetBackward() {
    63  	if p.loaded {
    64  		p.f.Asm.Fatalf("SetBackward after Start/LoadPtrs")
    65  	}
    66  	p.backward = true
    67  }
    68  
    69  // SetUseIndexCounter sets the pipe to use an index counter if possible,
    70  // meaning the loop counter is also used as an index for accessing the slice data.
    71  // This clever trick is slower on modern processors, but it is still necessary on 386.
    72  // On non-386 systems, SetUseIndexCounter is a no-op.
    73  func (p *Pipe) SetUseIndexCounter() {
    74  	if p.f.Asm.Arch.memIndex == nil { // need memIndex (only 386 provides it)
    75  		return
    76  	}
    77  	p.useIndexCounter = true
    78  }
    79  
    80  // SetLabel sets the label prefix for the loops emitted by the pipe.
    81  // The default prefix is "loop".
    82  func (p *Pipe) SetLabel(label string) {
    83  	p.label = label
    84  }
    85  
    86  // SetMaxColumns sets the maximum number of
    87  // columns processed in a single loop body call.
    88  func (p *Pipe) SetMaxColumns(m int) {
    89  	p.maxColumns = m
    90  }
    91  
    92  // SetHint records that the inputs from the named vector
    93  // should be allocated with the given register hint.
    94  //
    95  // If the hint indicates a single register on the target architecture,
    96  // then SetHint calls SetMaxColumns(1), since the hinted register
    97  // can only be used for one value at a time.
    98  func (p *Pipe) SetHint(name string, hint Hint) {
    99  	if hint == HintMemOK && !p.f.Asm.Arch.memOK {
   100  		return
   101  	}
   102  	i := slices.Index(p.f.inputs, name)
   103  	if i < 0 {
   104  		p.f.Asm.Fatalf("unknown input name %s", name)
   105  	}
   106  	if p.f.Asm.hint(hint) != "" {
   107  		p.SetMaxColumns(1)
   108  	}
   109  	for len(p.hints) <= i {
   110  		p.hints = append(p.hints, HintNone)
   111  	}
   112  	p.hints[i] = hint
   113  }
   114  
// LoadPtrs loads the slice pointer arguments into registers,
// assuming that the slice length n has already been loaded
// into the register n.
//
// Start will call LoadPtrs if it has not been called already.
// LoadPtrs only needs to be called explicitly when code needs
// to use LoadN before Start, like when the shift.go generators
// read an initial word before the loop.
//
// LoadPtrs also decides how LoadN and StoreN will address memory (see the
// switch below). They use either plain pointer offsets, or pointers indexed
// by n when an index counter was requested via SetUseIndexCounter.
// Depending on that choice and on SetBackward, LoadPtrs may emit code that
// advances every pointer by n words and/or negates n. SetBackward and
// SetUseIndexCounter therefore only take effect if called before LoadPtrs.
// Calling LoadPtrs twice is a fatal error.
func (p *Pipe) LoadPtrs(n Reg) {
	a := p.f.Asm
	if p.loaded {
		a.Fatalf("pointers already loaded")
	}

	// Load the actual pointers, in the order of p.f.inputs and then
	// p.f.outputs; each one comes from the "<name>_base" argument.
	p.loaded = true
	for _, name := range p.f.inputs {
		p.inPtr = append(p.inPtr, RegPtr(p.f.Arg(name+"_base")))
	}
	for _, name := range p.f.outputs {
		p.outPtr = append(p.outPtr, RegPtr(p.f.Arg(name+"_base")))
	}

	// Decide the memory access strategy for LoadN and StoreN.
	switch {
	case p.backward && p.useIndexCounter:
		// Generator wants an index counter, meaning when the iteration counter
		// is AX, we will access the slice with pointer BX using (BX)(AX*WordBytes).
		// The loop is moving backward through the slice, but the counter
		// is also moving backward, so not much to do.
		a.Comment("run loop backward, using counter as positive index")
		p.indexCounter = +1
		p.index = n

	case !p.backward && p.useIndexCounter:
		// Generator wants an index counter, but the loop is moving forward.
		// To make the counter move in the direction of data access,
		// we negate the counter, counting up from -len(z) to -1.
		// To make the index access the right words, we add len(z)*WordBytes
		// to each of the pointers.
		// See comment below about the garbage collector (non-)implications
		// of pointing beyond the slice bounds.
		a.Comment("use counter as negative index")
		p.indexCounter = -1
		p.index = n
		for _, ptr := range p.inPtr {
			a.AddWords(n, ptr, ptr)
		}
		for _, ptr := range p.outPtr {
			a.AddWords(n, ptr, ptr)
		}
		// Negate last: the AddWords calls above need the positive length.
		a.Neg(n, n)

	case p.backward:
		// Generator wants to run the loop backward.
		// We'll decrement the pointers before using them,
		// so position them at the very end of the slices.
		// If we had precise pointer information for assembly,
		// these pointers would cause problems with the garbage collector,
		// since they no longer point into the allocated slice,
		// but the garbage collector ignores unexpected values in assembly stacks,
		// and the actual slice pointers are still in the argument stack slots,
		// so the slices won't be collected early.
		// If we switched to the register ABI, we might have to rethink this.
		// (The same thing happens by the end of forward loops,
		// but it's less important since once the pointers go off the slice
		// in a forward loop, the loop is over and the slice won't be accessed anymore.)
		a.Comment("run loop backward")
		for _, ptr := range p.inPtr {
			a.AddWords(n, ptr, ptr)
		}
		for _, ptr := range p.outPtr {
			a.AddWords(n, ptr, ptr)
		}

	case !p.backward:
		// Nothing to do!
		// Forward loop without an index counter: LoadN and StoreN address
		// words as offsets from the pointers, and advancePtrs moves the
		// pointers after each unrolled iteration.
	}
}
   194  
// LoadN returns the next n columns of input words as a slice of rows.
// There is one row per input slice, each row holds n Regs, and the pipe's
// read position advances by n columns.
//
// On architectures with loadIncN, every Reg is a newly allocated register.
// The words are loaded by loadIncN (or loadDecN for a backward pipe), which
// also advances the pointers. Hints are not consulted on this path.
//
// Otherwise, each word is addressed at the pending read offset, either
// relative to its slice pointer or through the index register when an
// index counter is in use. Inputs hinted with HintMemOK (see SetHint) are
// returned as direct memory references. All other inputs are loaded into
// newly allocated registers, using the input's hint if one was set.
//
// Newly allocated registers must be freed by the caller. When LoadN is
// called by Loop on block's behalf, they are freed after block returns.
func (p *Pipe) LoadN(n int) [][]Reg {
	a := p.f.Asm
	regs := make([][]Reg, len(p.inPtr))
	for i, ptr := range p.inPtr {
		regs[i] = make([]Reg, n)
		switch {
		case a.Arch.loadIncN != nil:
			// Load from memory and advance pointers at the same time.
			for j := range regs[i] {
				regs[i][j] = p.f.Asm.Reg()
			}
			if p.backward {
				a.Arch.loadDecN(a, ptr, regs[i])
			} else {
				a.Arch.loadIncN(a, ptr, regs[i])
			}

		default:
			// Load from memory using offsets.
			// We'll advance the pointers or the index counter later.
			for j := range n {
				off := p.readOff + j
				if p.backward {
					// Backward pipes read below the pointer:
					// column k lives at word offset -(k+1).
					off = -(off + 1)
				}
				var mem Reg
				if p.indexCounter != 0 {
					mem = a.Arch.memIndex(a, off*a.Arch.WordBytes, p.index, ptr)
				} else {
					mem = ptr.mem(off * a.Arch.WordBytes)
				}
				h := HintNone
				if i < len(p.hints) {
					h = p.hints[i]
				}
				if h == HintMemOK {
					// Hand the memory operand straight to the caller.
					regs[i][j] = mem
				} else {
					r := p.f.Asm.RegHint(h)
					a.Mov(mem, r)
					regs[i][j] = r
				}
			}
		}
	}
	p.readOff += n
	return regs
}
   246  
   247  // StoreN writes regs (a slice of rows) to the next n columns of output, where n = len(regs[0]).
   248  func (p *Pipe) StoreN(regs [][]Reg) {
   249  	p.needWrite = false
   250  	a := p.f.Asm
   251  	if len(regs) != len(p.outPtr) {
   252  		p.f.Asm.Fatalf("wrong number of output rows")
   253  	}
   254  	n := len(regs[0])
   255  	for i, ptr := range p.outPtr {
   256  		switch {
   257  		case a.Arch.storeIncN != nil:
   258  			// Store to memory and advance pointers at the same time.
   259  			if p.backward {
   260  				a.Arch.storeDecN(a, ptr, regs[i])
   261  			} else {
   262  				a.Arch.storeIncN(a, ptr, regs[i])
   263  			}
   264  
   265  		default:
   266  			// Store to memory using offsets.
   267  			// We'll advance the pointers or the index counter later.
   268  			for j, r := range regs[i] {
   269  				off := p.writeOff + j
   270  				if p.backward {
   271  					off = -(off + 1)
   272  				}
   273  				var mem Reg
   274  				if p.indexCounter != 0 {
   275  					mem = a.Arch.memIndex(a, off*a.Arch.WordBytes, p.index, ptr)
   276  				} else {
   277  					mem = ptr.mem(off * a.Arch.WordBytes)
   278  				}
   279  				a.Mov(r, mem)
   280  			}
   281  		}
   282  	}
   283  	p.writeOff += n
   284  }
   285  
   286  // advancePtrs advances the pointers by step
   287  // or handles bookkeeping for an imminent index advance by step
   288  // that the caller will do.
   289  func (p *Pipe) advancePtrs(step int) {
   290  	a := p.f.Asm
   291  	switch {
   292  	case a.Arch.loadIncN != nil:
   293  		// nothing to do
   294  
   295  	default:
   296  		// Adjust read/write offsets for pointer advance (or imminent index advance).
   297  		p.readOff -= step
   298  		p.writeOff -= step
   299  
   300  		if p.indexCounter == 0 {
   301  			// Advance pointers.
   302  			if p.backward {
   303  				step = -step
   304  			}
   305  			for _, ptr := range p.inPtr {
   306  				a.Add(a.Imm(step*a.Arch.WordBytes), Reg(ptr), Reg(ptr), KeepCarry)
   307  			}
   308  			for _, ptr := range p.outPtr {
   309  				a.Add(a.Imm(step*a.Arch.WordBytes), Reg(ptr), Reg(ptr), KeepCarry)
   310  			}
   311  		}
   312  	}
   313  }
   314  
   315  // DropInput deletes the named input from the pipe,
   316  // usually because it has been exhausted.
   317  // (This is not used yet but will be used in a future generator.)
   318  func (p *Pipe) DropInput(name string) {
   319  	i := slices.Index(p.f.inputs, name)
   320  	if i < 0 {
   321  		p.f.Asm.Fatalf("unknown input %s", name)
   322  	}
   323  	ptr := p.inPtr[i]
   324  	p.f.Asm.Free(Reg(ptr))
   325  	p.inPtr = slices.Delete(p.inPtr, i, i+1)
   326  	p.f.inputs = slices.Delete(p.f.inputs, i, i+1)
   327  	if len(p.hints) > i {
   328  		p.hints = slices.Delete(p.hints, i, i+1)
   329  	}
   330  }
   331  
   332  // Start prepares to loop over n columns.
   333  // The factors give a sequence of unrolling factors to use,
   334  // which must be either strictly increasing or strictly decreasing
   335  // and must include 1.
   336  // For example, 4, 1 means to process 4 elements at a time
   337  // and then 1 at a time for the final 0-3; specifying 1,4 instead
   338  // handles 0-3 elements first and then 4 at a time.
   339  // Similarly, 32, 4, 1 means to process 32 at a time,
   340  // then 4 at a time, then 1 at a time.
   341  //
   342  // One benefit of using 1, 4 instead of 4, 1 is that the body
   343  // processing 4 at a time needs more registers, and if it is
   344  // the final body, the register holding the fragment count (0-3)
   345  // has been freed and is available for use.
   346  //
   347  // Start may modify the carry flag.
   348  //
   349  // Start must be followed by a call to Loop1 or LoopN,
   350  // but it is permitted to emit other instructions first,
   351  // for example to set an initial carry flag.
   352  func (p *Pipe) Start(n Reg, factors ...int) {
   353  	a := p.f.Asm
   354  	if p.started {
   355  		a.Fatalf("loop already started")
   356  	}
   357  	if p.useIndexCounter && len(factors) > 1 {
   358  		a.Fatalf("cannot call SetUseIndexCounter and then use Start with factors != [1]; have factors = %v", factors)
   359  	}
   360  	p.started = true
   361  	if !p.loaded {
   362  		if len(factors) == 1 {
   363  			p.SetUseIndexCounter()
   364  		}
   365  		p.LoadPtrs(n)
   366  	}
   367  
   368  	// If there were calls to LoadN between LoadPtrs and Start,
   369  	// adjust the loop not to scan those columns, assuming that
   370  	// either the code already called an equivalent StoreN or else
   371  	// that it will do so after the loop.
   372  	if off := p.readOff; off != 0 {
   373  		if p.indexCounter < 0 {
   374  			// Index is negated, so add off instead of subtracting.
   375  			a.Add(a.Imm(off), n, n, SmashCarry)
   376  		} else {
   377  			a.Sub(a.Imm(off), n, n, SmashCarry)
   378  		}
   379  		if p.indexCounter != 0 {
   380  			// n is also the index we are using, so adjust readOff and writeOff
   381  			// to continue to point at the same positions as before we changed n.
   382  			p.readOff -= off
   383  			p.writeOff -= off
   384  		}
   385  	}
   386  
   387  	p.Restart(n, factors...)
   388  }
   389  
   390  // Restart prepares to loop over an additional n columns,
   391  // beyond a previous loop run by p.Start/p.Loop.
   392  func (p *Pipe) Restart(n Reg, factors ...int) {
   393  	a := p.f.Asm
   394  	if !p.started {
   395  		a.Fatalf("pipe not started")
   396  	}
   397  	p.factors = factors
   398  	p.counts = make([]Reg, len(factors))
   399  	if len(factors) == 0 {
   400  		factors = []int{1}
   401  	}
   402  
   403  	// Compute the loop lengths for each unrolled section into separate registers.
   404  	// We compute them all ahead of time in case the computation would smash
   405  	// a carry flag that the loop bodies need preserved.
   406  	if len(factors) > 1 {
   407  		a.Comment("compute unrolled loop lengths")
   408  	}
   409  	switch {
   410  	default:
   411  		a.Fatalf("invalid factors %v", factors)
   412  
   413  	case factors[0] == 1:
   414  		// increasing loop factors
   415  		div := 1
   416  		for i, f := range factors[1:] {
   417  			if f <= factors[i] {
   418  				a.Fatalf("non-increasing factors %v", factors)
   419  			}
   420  			if f&(f-1) != 0 {
   421  				a.Fatalf("non-power-of-two factors %v", factors)
   422  			}
   423  			t := p.f.Asm.Reg()
   424  			f /= div
   425  			a.And(a.Imm(f-1), n, t)
   426  			a.Rsh(a.Imm(bits.TrailingZeros(uint(f))), n, n)
   427  			div *= f
   428  			p.counts[i] = t
   429  		}
   430  		p.counts[len(p.counts)-1] = n
   431  
   432  	case factors[len(factors)-1] == 1:
   433  		// decreasing loop factors
   434  		for i, f := range factors[:len(factors)-1] {
   435  			if f <= factors[i+1] {
   436  				a.Fatalf("non-decreasing factors %v", factors)
   437  			}
   438  			if f&(f-1) != 0 {
   439  				a.Fatalf("non-power-of-two factors %v", factors)
   440  			}
   441  			t := p.f.Asm.Reg()
   442  			a.Rsh(a.Imm(bits.TrailingZeros(uint(f))), n, t)
   443  			a.And(a.Imm(f-1), n, n)
   444  			p.counts[i] = t
   445  		}
   446  		p.counts[len(p.counts)-1] = n
   447  	}
   448  }
   449  
   450  // Done frees all the registers allocated by the pipe.
   451  func (p *Pipe) Done() {
   452  	for _, ptr := range p.inPtr {
   453  		p.f.Asm.Free(Reg(ptr))
   454  	}
   455  	p.inPtr = nil
   456  	for _, ptr := range p.outPtr {
   457  		p.f.Asm.Free(Reg(ptr))
   458  	}
   459  	p.outPtr = nil
   460  	p.index = Reg{}
   461  }
   462  
   463  // Loop emits code for the loop, calling block repeatedly to emit code that
   464  // handles a block of N input columns (for arbitrary N = len(in[0]) chosen by p).
   465  // block must call p.StoreN(out) to write N output columns.
   466  // The out slice is a pre-allocated matrix of uninitialized Reg values.
   467  // block is expected to set each entry to the Reg that should be written
   468  // before calling p.StoreN(out).
   469  //
   470  // For example, if the loop is to be unrolled 4x in blocks of 2 columns each,
   471  // the sequence of calls to emit the unrolled loop body is:
   472  //
   473  //	start()  // set by pAtUnrollStart
   474  //	... reads for 2 columns ...
   475  //	block()
   476  //	... writes for 2 columns ...
   477  //	... reads for 2 columns ...
   478  //	block()
   479  //	... writes for 2 columns ...
   480  //	end()  // set by p.AtUnrollEnd
   481  //
   482  // Any registers allocated during block are freed automatically when block returns.
   483  func (p *Pipe) Loop(block func(in, out [][]Reg)) {
   484  	if p.factors == nil {
   485  		p.f.Asm.Fatalf("Pipe.Start not called")
   486  	}
   487  	for i, factor := range p.factors {
   488  		n := p.counts[i]
   489  		p.unroll(n, factor, block)
   490  		if i < len(p.factors)-1 {
   491  			p.f.Asm.Free(n)
   492  		}
   493  	}
   494  	p.factors = nil
   495  }
   496  
   497  // AtUnrollStart sets a function to call at the start of an unrolled sequence.
   498  // See [Pipe.Loop] for details.
   499  func (p *Pipe) AtUnrollStart(start func()) {
   500  	p.unrollStart = start
   501  }
   502  
   503  // AtUnrollEnd sets a function to call at the end of an unrolled sequence.
   504  // See [Pipe.Loop] for details.
   505  func (p *Pipe) AtUnrollEnd(end func()) {
   506  	p.unrollEnd = end
   507  }
   508  
   509  // unroll emits a single unrolled loop for the given factor, iterating n times.
   510  func (p *Pipe) unroll(n Reg, factor int, block func(in, out [][]Reg)) {
   511  	a := p.f.Asm
   512  	label := fmt.Sprintf("%s%d", p.label, factor)
   513  
   514  	// Top of loop control flow.
   515  	a.Label(label)
   516  	if a.Arch.loopTop != "" {
   517  		a.Printf("\t"+a.Arch.loopTop+"\n", n, label+"done")
   518  	} else {
   519  		a.JmpZero(n, label+"done")
   520  	}
   521  	a.Label(label + "cont")
   522  
   523  	// Unrolled loop body.
   524  	if factor < p.maxColumns {
   525  		a.Comment("unroll %dX", factor)
   526  	} else {
   527  		a.Comment("unroll %dX in batches of %d", factor, p.maxColumns)
   528  	}
   529  	if p.unrollStart != nil {
   530  		p.unrollStart()
   531  	}
   532  	for done := 0; done < factor; {
   533  		batch := min(factor-done, p.maxColumns)
   534  		regs := a.RegsUsed()
   535  		out := make([][]Reg, len(p.outPtr))
   536  		for i := range out {
   537  			out[i] = make([]Reg, batch)
   538  		}
   539  		in := p.LoadN(batch)
   540  		p.needWrite = true
   541  		block(in, out)
   542  		if p.needWrite && len(p.outPtr) > 0 {
   543  			a.Fatalf("missing p.Write1 or p.StoreN")
   544  		}
   545  		a.SetRegsUsed(regs) // free anything block allocated
   546  		done += batch
   547  	}
   548  	if p.unrollEnd != nil {
   549  		p.unrollEnd()
   550  	}
   551  	p.advancePtrs(factor)
   552  
   553  	// Bottom of loop control flow.
   554  	switch {
   555  	case p.indexCounter >= 0 && a.Arch.loopBottom != "":
   556  		a.Printf("\t"+a.Arch.loopBottom+"\n", n, label+"cont")
   557  
   558  	case p.indexCounter >= 0:
   559  		a.Sub(a.Imm(1), n, n, KeepCarry)
   560  		a.JmpNonZero(n, label+"cont")
   561  
   562  	case p.indexCounter < 0 && a.Arch.loopBottomNeg != "":
   563  		a.Printf("\t"+a.Arch.loopBottomNeg+"\n", n, label+"cont")
   564  
   565  	case p.indexCounter < 0:
   566  		a.Add(a.Imm(1), n, n, KeepCarry)
   567  	}
   568  	a.Label(label + "done")
   569  }
   570  

View as plain text