1
2
3
4
5 package asmgen
6
7 import (
8 "fmt"
9 "math/bits"
10 "slices"
11 )
12
13
14
15
16
17
18
19
20
21
22
23
24 type Pipe struct {
25 f *Func
26 label string
27 backward bool
28 started bool
29 loaded bool
30 inPtr []RegPtr
31 hints []Hint
32 outPtr []RegPtr
33 index Reg
34 useIndexCounter bool
35 indexCounter int
36 readOff int
37 writeOff int
38 factors []int
39 counts []Reg
40 needWrite bool
41 maxColumns int
42 unrollStart func()
43 unrollEnd func()
44 }
45
46
47 func (f *Func) Pipe() *Pipe {
48 a := f.Asm
49 p := &Pipe{
50 f: f,
51 label: "loop",
52 maxColumns: 10000000,
53 }
54 if m := a.Arch.maxColumns; m != 0 {
55 p.maxColumns = m
56 }
57 return p
58 }
59
60
61
62 func (p *Pipe) SetBackward() {
63 if p.loaded {
64 p.f.Asm.Fatalf("SetBackward after Start/LoadPtrs")
65 }
66 p.backward = true
67 }
68
69
70
71
72
73 func (p *Pipe) SetUseIndexCounter() {
74 if p.f.Asm.Arch.memIndex == nil {
75 return
76 }
77 p.useIndexCounter = true
78 }
79
80
81
82 func (p *Pipe) SetLabel(label string) {
83 p.label = label
84 }
85
86
87
88 func (p *Pipe) SetMaxColumns(m int) {
89 p.maxColumns = m
90 }
91
92
93
94
95
96
97
98 func (p *Pipe) SetHint(name string, hint Hint) {
99 if hint == HintMemOK && !p.f.Asm.Arch.memOK {
100 return
101 }
102 i := slices.Index(p.f.inputs, name)
103 if i < 0 {
104 p.f.Asm.Fatalf("unknown input name %s", name)
105 }
106 if p.f.Asm.hint(hint) != "" {
107 p.SetMaxColumns(1)
108 }
109 for len(p.hints) <= i {
110 p.hints = append(p.hints, HintNone)
111 }
112 p.hints[i] = hint
113 }
114
115
116
117
118
119
120
121
122
123 func (p *Pipe) LoadPtrs(n Reg) {
124 a := p.f.Asm
125 if p.loaded {
126 a.Fatalf("pointers already loaded")
127 }
128
129
130 p.loaded = true
131 for _, name := range p.f.inputs {
132 p.inPtr = append(p.inPtr, RegPtr(p.f.Arg(name+"_base")))
133 }
134 for _, name := range p.f.outputs {
135 p.outPtr = append(p.outPtr, RegPtr(p.f.Arg(name+"_base")))
136 }
137
138
139 switch {
140 case p.backward && p.useIndexCounter:
141
142
143
144
145 a.Comment("run loop backward, using counter as positive index")
146 p.indexCounter = +1
147 p.index = n
148
149 case !p.backward && p.useIndexCounter:
150
151
152
153
154
155
156
157 a.Comment("use counter as negative index")
158 p.indexCounter = -1
159 p.index = n
160 for _, ptr := range p.inPtr {
161 a.AddWords(n, ptr, ptr)
162 }
163 for _, ptr := range p.outPtr {
164 a.AddWords(n, ptr, ptr)
165 }
166 a.Neg(n, n)
167
168 case p.backward:
169
170
171
172
173
174
175
176
177
178
179
180
181
182 a.Comment("run loop backward")
183 for _, ptr := range p.inPtr {
184 a.AddWords(n, ptr, ptr)
185 }
186 for _, ptr := range p.outPtr {
187 a.AddWords(n, ptr, ptr)
188 }
189
190 case !p.backward:
191
192 }
193 }
194
195
196
197
198 func (p *Pipe) LoadN(n int) [][]Reg {
199 a := p.f.Asm
200 regs := make([][]Reg, len(p.inPtr))
201 for i, ptr := range p.inPtr {
202 regs[i] = make([]Reg, n)
203 switch {
204 case a.Arch.loadIncN != nil:
205
206 for j := range regs[i] {
207 regs[i][j] = p.f.Asm.Reg()
208 }
209 if p.backward {
210 a.Arch.loadDecN(a, ptr, regs[i])
211 } else {
212 a.Arch.loadIncN(a, ptr, regs[i])
213 }
214
215 default:
216
217
218 for j := range n {
219 off := p.readOff + j
220 if p.backward {
221 off = -(off + 1)
222 }
223 var mem Reg
224 if p.indexCounter != 0 {
225 mem = a.Arch.memIndex(a, off*a.Arch.WordBytes, p.index, ptr)
226 } else {
227 mem = ptr.mem(off * a.Arch.WordBytes)
228 }
229 h := HintNone
230 if i < len(p.hints) {
231 h = p.hints[i]
232 }
233 if h == HintMemOK {
234 regs[i][j] = mem
235 } else {
236 r := p.f.Asm.RegHint(h)
237 a.Mov(mem, r)
238 regs[i][j] = r
239 }
240 }
241 }
242 }
243 p.readOff += n
244 return regs
245 }
246
247
248 func (p *Pipe) StoreN(regs [][]Reg) {
249 p.needWrite = false
250 a := p.f.Asm
251 if len(regs) != len(p.outPtr) {
252 p.f.Asm.Fatalf("wrong number of output rows")
253 }
254 n := len(regs[0])
255 for i, ptr := range p.outPtr {
256 switch {
257 case a.Arch.storeIncN != nil:
258
259 if p.backward {
260 a.Arch.storeDecN(a, ptr, regs[i])
261 } else {
262 a.Arch.storeIncN(a, ptr, regs[i])
263 }
264
265 default:
266
267
268 for j, r := range regs[i] {
269 off := p.writeOff + j
270 if p.backward {
271 off = -(off + 1)
272 }
273 var mem Reg
274 if p.indexCounter != 0 {
275 mem = a.Arch.memIndex(a, off*a.Arch.WordBytes, p.index, ptr)
276 } else {
277 mem = ptr.mem(off * a.Arch.WordBytes)
278 }
279 a.Mov(r, mem)
280 }
281 }
282 }
283 p.writeOff += n
284 }
285
286
287
288
289 func (p *Pipe) advancePtrs(step int) {
290 a := p.f.Asm
291 switch {
292 case a.Arch.loadIncN != nil:
293
294
295 default:
296
297 p.readOff -= step
298 p.writeOff -= step
299
300 if p.indexCounter == 0 {
301
302 if p.backward {
303 step = -step
304 }
305 for _, ptr := range p.inPtr {
306 a.Add(a.Imm(step*a.Arch.WordBytes), Reg(ptr), Reg(ptr), KeepCarry)
307 }
308 for _, ptr := range p.outPtr {
309 a.Add(a.Imm(step*a.Arch.WordBytes), Reg(ptr), Reg(ptr), KeepCarry)
310 }
311 }
312 }
313 }
314
315
316
317
318 func (p *Pipe) DropInput(name string) {
319 i := slices.Index(p.f.inputs, name)
320 if i < 0 {
321 p.f.Asm.Fatalf("unknown input %s", name)
322 }
323 ptr := p.inPtr[i]
324 p.f.Asm.Free(Reg(ptr))
325 p.inPtr = slices.Delete(p.inPtr, i, i+1)
326 p.f.inputs = slices.Delete(p.f.inputs, i, i+1)
327 if len(p.hints) > i {
328 p.hints = slices.Delete(p.hints, i, i+1)
329 }
330 }
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352 func (p *Pipe) Start(n Reg, factors ...int) {
353 a := p.f.Asm
354 if p.started {
355 a.Fatalf("loop already started")
356 }
357 if p.useIndexCounter && len(factors) > 1 {
358 a.Fatalf("cannot call SetUseIndexCounter and then use Start with factors != [1]; have factors = %v", factors)
359 }
360 p.started = true
361 if !p.loaded {
362 if len(factors) == 1 {
363 p.SetUseIndexCounter()
364 }
365 p.LoadPtrs(n)
366 }
367
368
369
370
371
372 if off := p.readOff; off != 0 {
373 if p.indexCounter < 0 {
374
375 a.Add(a.Imm(off), n, n, SmashCarry)
376 } else {
377 a.Sub(a.Imm(off), n, n, SmashCarry)
378 }
379 if p.indexCounter != 0 {
380
381
382 p.readOff -= off
383 p.writeOff -= off
384 }
385 }
386
387 p.Restart(n, factors...)
388 }
389
390
391
392 func (p *Pipe) Restart(n Reg, factors ...int) {
393 a := p.f.Asm
394 if !p.started {
395 a.Fatalf("pipe not started")
396 }
397 p.factors = factors
398 p.counts = make([]Reg, len(factors))
399 if len(factors) == 0 {
400 factors = []int{1}
401 }
402
403
404
405
406 if len(factors) > 1 {
407 a.Comment("compute unrolled loop lengths")
408 }
409 switch {
410 default:
411 a.Fatalf("invalid factors %v", factors)
412
413 case factors[0] == 1:
414
415 div := 1
416 for i, f := range factors[1:] {
417 if f <= factors[i] {
418 a.Fatalf("non-increasing factors %v", factors)
419 }
420 if f&(f-1) != 0 {
421 a.Fatalf("non-power-of-two factors %v", factors)
422 }
423 t := p.f.Asm.Reg()
424 f /= div
425 a.And(a.Imm(f-1), n, t)
426 a.Rsh(a.Imm(bits.TrailingZeros(uint(f))), n, n)
427 div *= f
428 p.counts[i] = t
429 }
430 p.counts[len(p.counts)-1] = n
431
432 case factors[len(factors)-1] == 1:
433
434 for i, f := range factors[:len(factors)-1] {
435 if f <= factors[i+1] {
436 a.Fatalf("non-decreasing factors %v", factors)
437 }
438 if f&(f-1) != 0 {
439 a.Fatalf("non-power-of-two factors %v", factors)
440 }
441 t := p.f.Asm.Reg()
442 a.Rsh(a.Imm(bits.TrailingZeros(uint(f))), n, t)
443 a.And(a.Imm(f-1), n, n)
444 p.counts[i] = t
445 }
446 p.counts[len(p.counts)-1] = n
447 }
448 }
449
450
451 func (p *Pipe) Done() {
452 for _, ptr := range p.inPtr {
453 p.f.Asm.Free(Reg(ptr))
454 }
455 p.inPtr = nil
456 for _, ptr := range p.outPtr {
457 p.f.Asm.Free(Reg(ptr))
458 }
459 p.outPtr = nil
460 p.index = Reg{}
461 }
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483 func (p *Pipe) Loop(block func(in, out [][]Reg)) {
484 if p.factors == nil {
485 p.f.Asm.Fatalf("Pipe.Start not called")
486 }
487 for i, factor := range p.factors {
488 n := p.counts[i]
489 p.unroll(n, factor, block)
490 if i < len(p.factors)-1 {
491 p.f.Asm.Free(n)
492 }
493 }
494 p.factors = nil
495 }
496
497
498
499 func (p *Pipe) AtUnrollStart(start func()) {
500 p.unrollStart = start
501 }
502
503
504
505 func (p *Pipe) AtUnrollEnd(end func()) {
506 p.unrollEnd = end
507 }
508
509
510 func (p *Pipe) unroll(n Reg, factor int, block func(in, out [][]Reg)) {
511 a := p.f.Asm
512 label := fmt.Sprintf("%s%d", p.label, factor)
513
514
515 a.Label(label)
516 if a.Arch.loopTop != "" {
517 a.Printf("\t"+a.Arch.loopTop+"\n", n, label+"done")
518 } else {
519 a.JmpZero(n, label+"done")
520 }
521 a.Label(label + "cont")
522
523
524 if factor < p.maxColumns {
525 a.Comment("unroll %dX", factor)
526 } else {
527 a.Comment("unroll %dX in batches of %d", factor, p.maxColumns)
528 }
529 if p.unrollStart != nil {
530 p.unrollStart()
531 }
532 for done := 0; done < factor; {
533 batch := min(factor-done, p.maxColumns)
534 regs := a.RegsUsed()
535 out := make([][]Reg, len(p.outPtr))
536 for i := range out {
537 out[i] = make([]Reg, batch)
538 }
539 in := p.LoadN(batch)
540 p.needWrite = true
541 block(in, out)
542 if p.needWrite && len(p.outPtr) > 0 {
543 a.Fatalf("missing p.Write1 or p.StoreN")
544 }
545 a.SetRegsUsed(regs)
546 done += batch
547 }
548 if p.unrollEnd != nil {
549 p.unrollEnd()
550 }
551 p.advancePtrs(factor)
552
553
554 switch {
555 case p.indexCounter >= 0 && a.Arch.loopBottom != "":
556 a.Printf("\t"+a.Arch.loopBottom+"\n", n, label+"cont")
557
558 case p.indexCounter >= 0:
559 a.Sub(a.Imm(1), n, n, KeepCarry)
560 a.JmpNonZero(n, label+"cont")
561
562 case p.indexCounter < 0 && a.Arch.loopBottomNeg != "":
563 a.Printf("\t"+a.Arch.loopBottomNeg+"\n", n, label+"cont")
564
565 case p.indexCounter < 0:
566 a.Add(a.Imm(1), n, n, KeepCarry)
567 }
568 a.Label(label + "done")
569 }
570
View as plain text