1
2
3
4
5 package main
6
7 import (
8 "bufio"
9 "bytes"
10 "fmt"
11 "go/format"
12 "log"
13 "os"
14 "path/filepath"
15 "reflect"
16 "slices"
17 "sort"
18 "strings"
19 "text/template"
20 "unicode"
21 )
22
// templateOf parses temp as a text/template named name.
// It panics if the template text does not parse.
func templateOf(temp, name string) *template.Template {
	parsed, err := template.New(name).Parse(temp)
	if err == nil {
		return parsed
	}
	panic(fmt.Errorf("failed to parse template %s: %w", name, err))
}
30
// createPath creates (or truncates) the file goroot/file, first creating any
// missing parent directories with mode 0755. The caller owns the returned
// file and must close it.
func createPath(goroot string, file string) (*os.File, error) {
	target := filepath.Join(goroot, file)
	parent := filepath.Dir(target)
	if err := os.MkdirAll(parent, 0755); err != nil {
		return nil, fmt.Errorf("failed to create directory %s: %w", parent, err)
	}
	f, err := os.Create(target)
	if err != nil {
		return nil, fmt.Errorf("failed to create file %s: %w", target, err)
	}
	return f, nil
}
44
45 func formatWriteAndClose(out *bytes.Buffer, goroot string, file string) {
46 b, err := format.Source(out.Bytes())
47 if err != nil {
48 fmt.Fprintf(os.Stderr, "%v\n", err)
49 fmt.Fprintf(os.Stderr, "%s\n", numberLines(out.Bytes()))
50 fmt.Fprintf(os.Stderr, "%v\n", err)
51 panic(err)
52 } else {
53 writeAndClose(b, goroot, file)
54 }
55 }
56
57 func writeAndClose(b []byte, goroot string, file string) {
58 ofile, err := createPath(goroot, file)
59 if err != nil {
60 panic(err)
61 }
62 ofile.Write(b)
63 ofile.Close()
64 }
65
66
67
// numberLines renders data with a 1-based "N: " prefix on every line, for
// diagnostic output (formatWriteAndClose prints it when formatting fails).
//
// Lines are split as bufio.ScanLines splits them: a trailing "\r" is dropped
// and a final newline does not produce an extra empty line.
func numberLines(data []byte) string {
	var buf bytes.Buffer
	s := bufio.NewScanner(bytes.NewReader(data))
	// The default token limit is bufio.MaxScanTokenSize (64 KiB). A longer
	// line makes Scan stop with ErrTooLong and silently truncates the listing.
	// The whole input is already in memory, so allow a token as large as the
	// input itself.
	limit := len(data) + 1
	if limit < bufio.MaxScanTokenSize {
		limit = bufio.MaxScanTokenSize
	}
	s.Buffer(nil, limit)
	for i := 1; s.Scan(); i++ {
		fmt.Fprintf(&buf, "%d: %s\n", i, s.Text())
	}
	return buf.String()
}
77
// Shape classifications used by the simdgen helpers in this file.
// (*Operation).shape produces the in/out/mask/imm shapes; regShape consumes
// a memShape.
type (
	inShape   uint8
	outShape  uint8
	maskShape uint8
	immShape  uint8
	memShape  uint8
)

// Input-operand shapes.
const (
	InvalidIn     inShape = iota // zero value; never produced by shape
	PureVregIn                   // no immediate and no mask inputs
	OneKmaskIn                   // exactly one mask input, no immediate
	OneImmIn                     // one immediate, no mask inputs
	OneKmaskImmIn                // one immediate and exactly one mask input
	PureKmaskIn                  // only mask inputs (more than one)
)

// Output-operand shapes.
const (
	InvalidOut     outShape = iota // zero value
	NoOut                          // no output operand
	OneVregOut                     // a single vreg output
	OneGregOut                     // a single greg output
	OneKmaskOut                    // a single mask output
	OneVregOutAtIn                 // a single output sharing AsmPos 0 with a vreg input
)

// Mask usage.
const (
	InvalidMask maskShape = iota // zero value
	NoMask                       // no mask inputs
	OneMask                      // at least one mask input alongside other inputs
	AllMasks                     // every input is a mask
)

// How an immediate operand is supplied.
const (
	InvalidImm  immShape = iota // zero value
	NoImm                       // no immediate operand
	ConstImm                    // immediate has Const set
	VarImm                      // immediate has ImmOffset set
	ConstVarImm                 // immediate has both Const and ImmOffset set
)

// Memory-operand shapes accepted by regShape.
const (
	InvalidMem memShape = iota // zero value
	NoMem                      // no memory operand
	VregMemIn                  // a memory input standing in for a vreg input
)
122
123
124
125
126
127
128
// shape classifies op's operands for code generation.
//
// Results:
//   - shapeIn: PureVregIn (no immediate, no mask), OneKmaskIn (one mask),
//     OneImmIn (one immediate, no mask), OneKmaskImmIn (immediate plus one
//     mask), or PureKmaskIn (only masks, more than one).
//   - shapeOut: OneVregOut / OneGregOut / OneKmaskOut by the single output's
//     class, or OneVregOutAtIn when that output shares AsmPos 0 with a vreg input.
//   - maskType: NoMask, OneMask, or AllMasks (PureKmaskIn case).
//   - immType: NoImm, or ConstImm / VarImm / ConstVarImm depending on whether
//     In[0].Const and/or In[0].ImmOffset is set.
//   - opNoImm: a shallow copy of *op with In[0] dropped when op has an
//     immediate (otherwise an unmodified shallow copy).
//
// It panics on shapes simdgen does not support: zero or several outputs, an
// output class other than vreg/greg/mask, a non-8-bit immediate, input/output
// position sharing other than "vreg at position 0", an immediate with neither
// Const nor ImmOffset, and multi-mask operations that also have immediates or
// register inputs.
func (op *Operation) shape() (shapeIn inShape, shapeOut outShape, maskType maskShape, immType immShape,
	opNoImm Operation) {
	if len(op.Out) > 1 {
		panic(fmt.Errorf("simdgen only supports 1 output: %s", op))
	}
	var outputReg int
	if len(op.Out) == 1 {
		outputReg = op.Out[0].AsmPos
		if op.Out[0].Class == "vreg" {
			shapeOut = OneVregOut
		} else if op.Out[0].Class == "greg" {
			shapeOut = OneGregOut
		} else if op.Out[0].Class == "mask" {
			shapeOut = OneKmaskOut
		} else {
			// NOTE(review): greg outputs are accepted above, but this message only
			// mentions vreg or mask.
			panic(fmt.Errorf("simdgen only supports output of class vreg or mask: %s", op))
		}
	} else {
		shapeOut = NoOut
		// Operations with no output are rejected outright, so NoOut is never
		// actually returned.
		panic(fmt.Errorf("simdgen only supports 1 output: %s", op))
	}
	hasImm := false
	maskCount := 0
	hasVreg := false
	for _, in := range op.In {
		// An input at the same AsmPos as the output is only supported when it is
		// a vreg at position 0, and only once.
		if in.AsmPos == outputReg {
			if shapeOut != OneVregOutAtIn && in.AsmPos == 0 && in.Class == "vreg" {
				shapeOut = OneVregOutAtIn
			} else {
				panic(fmt.Errorf("simdgen only support output and input sharing the same position case of \"the first input is vreg and the only output\": %s", op))
			}
		}
		if in.Class == "immediate" {
			// Only 8-bit immediates are supported.
			if *in.Bits != 8 {
				panic(fmt.Errorf("simdgen only supports immediates of 8 bits: %s", op))
			}
			hasImm = true
		} else if in.Class == "mask" {
			maskCount++
		} else {
			// Every other class (vreg, greg, ...) counts as a register input.
			hasVreg = true
		}
	}
	opNoImm = *op

	// removeImm drops In[0] from a (shallow) copy of the operation.
	// NOTE(review): assumes the immediate is In[0]; sortOperand orders
	// immediates first — confirm callers sort before calling shape.
	removeImm := func(o *Operation) {
		o.In = o.In[1:]
	}
	if hasImm {
		removeImm(&opNoImm)
		if op.In[0].Const != nil {
			if op.In[0].ImmOffset != nil {
				immType = ConstVarImm
			} else {
				immType = ConstImm
			}
		} else if op.In[0].ImmOffset != nil {
			immType = VarImm
		} else {
			panic(fmt.Errorf("simdgen requires imm to have at least one of ImmOffset or Const set: %s", op))
		}
	} else {
		immType = NoImm
	}
	if maskCount == 0 {
		maskType = NoMask
	} else {
		maskType = OneMask
	}
	// checkPureMask validates a multi-mask operation: it panics if op also has
	// an immediate or a register input. It always returns false when it
	// returns at all.
	checkPureMask := func() bool {
		if hasImm {
			panic(fmt.Errorf("simdgen does not support immediates in pure mask operations: %s", op))
		}
		if hasVreg {
			panic(fmt.Errorf("simdgen does not support more than 1 masks in non-pure mask operations: %s", op))
		}
		return false
	}
	if !hasImm && maskCount == 0 {
		shapeIn = PureVregIn
	} else if !hasImm && maskCount > 0 {
		if maskCount == 1 {
			shapeIn = OneKmaskIn
		} else {
			// checkPureMask never returns true, so this early return is dead code.
			if checkPureMask() {
				return
			}
			shapeIn = PureKmaskIn
			maskType = AllMasks
		}
	} else if hasImm && maskCount == 0 {
		shapeIn = OneImmIn
	} else {
		if maskCount == 1 {
			shapeIn = OneKmaskImmIn
		} else {
			// hasImm is true here, so checkPureMask always panics.
			checkPureMask()
			return
		}
	}
	return
}
235
236
237 func (op *Operation) regShape(mem memShape) (string, error) {
238 _, _, _, _, gOp := op.shape()
239 var regInfo, fixedName string
240 var vRegInCnt, gRegInCnt, kMaskInCnt, vRegOutCnt, gRegOutCnt, kMaskOutCnt, memInCnt, memOutCnt int
241 for i, in := range gOp.In {
242 switch in.Class {
243 case "vreg":
244 vRegInCnt++
245 case "greg":
246 gRegInCnt++
247 case "mask":
248 kMaskInCnt++
249 case "memory":
250 if mem != VregMemIn {
251 panic("simdgen only knows VregMemIn in regShape")
252 }
253 memInCnt++
254 vRegInCnt++
255 }
256 if in.FixedReg != nil {
257 fixedName = fmt.Sprintf("%sAtIn%d", *in.FixedReg, i)
258 }
259 }
260 for i, out := range gOp.Out {
261
262 if out.Class == "vreg" || out.OverwriteClass != nil {
263 vRegOutCnt++
264 } else if out.Class == "greg" {
265 gRegOutCnt++
266 } else if out.Class == "mask" {
267 kMaskOutCnt++
268 } else if out.Class == "memory" {
269 if mem != VregMemIn {
270 panic("simdgen only knows VregMemIn in regShape")
271 }
272 vRegOutCnt++
273 memOutCnt++
274 }
275 if out.FixedReg != nil {
276 fixedName = fmt.Sprintf("%sAtIn%d", *out.FixedReg, i)
277 }
278 }
279 var inRegs, inMasks, outRegs, outMasks string
280
281 rmAbbrev := func(s string, i int) string {
282 if i == 0 {
283 return ""
284 }
285 if i == 1 {
286 return s
287 }
288 return fmt.Sprintf("%s%d", s, i)
289
290 }
291
292 inRegs = rmAbbrev("v", vRegInCnt)
293 inRegs += rmAbbrev("gp", gRegInCnt)
294 inMasks = rmAbbrev("k", kMaskInCnt)
295
296 outRegs = rmAbbrev("v", vRegOutCnt)
297 outRegs += rmAbbrev("gp", gRegOutCnt)
298 outMasks = rmAbbrev("k", kMaskOutCnt)
299
300 if kMaskInCnt == 0 && kMaskOutCnt == 0 && gRegInCnt == 0 && gRegOutCnt == 0 {
301
302 regInfo = fmt.Sprintf("v%d%d", vRegInCnt, vRegOutCnt)
303 } else if kMaskInCnt == 0 && kMaskOutCnt == 0 {
304 regInfo = fmt.Sprintf("%s%s", inRegs, outRegs)
305 } else {
306 regInfo = fmt.Sprintf("%s%s%s%s", inRegs, inMasks, outRegs, outMasks)
307 }
308 if memInCnt > 0 {
309 if memInCnt == 1 {
310 regInfo += "load"
311 } else {
312 panic("simdgen does not understand more than 1 mem op as of now")
313 }
314 }
315 if memOutCnt > 0 {
316 panic("simdgen does not understand memory as output as of now")
317 }
318 regInfo += fixedName
319 return regInfo, nil
320 }
321
322
323
324
325
326 func (op *Operation) sortOperand() {
327 priority := map[string]int{"immediate": 0, "vreg": 1, "greg": 1, "mask": 2}
328 sort.SliceStable(op.In, func(i, j int) bool {
329 pi := priority[op.In[i].Class]
330 pj := priority[op.In[j].Class]
331 if pi != pj {
332 return pi < pj
333 }
334 return op.In[i].AsmPos < op.In[j].AsmPos
335 })
336 }
337
338
339 func (op *Operation) adjustAsm() {
340 if op.Asm == "VCVTTPD2DQ" || op.Asm == "VCVTTPD2UDQ" ||
341 op.Asm == "VCVTQQ2PS" || op.Asm == "VCVTUQQ2PS" ||
342 op.Asm == "VCVTPD2PS" {
343 switch *op.In[0].Bits {
344 case 128:
345 op.Asm += "X"
346 case 256:
347 op.Asm += "Y"
348 }
349 }
350 }
351
352
353
354
355
356
357
358 func (op Operation) goNormalType() string {
359 if op.Go == "GetElem" {
360
361
362
363
364
365
366
367 at := 0
368 if op.In[at].Class == "immediate" {
369 at++
370 }
371 return fmt.Sprintf("%s%d", *op.Out[0].Base, *op.In[at].ElemBits)
372 }
373 panic(fmt.Errorf("Implement goNormalType for %v", op))
374 }
375
376
377
378 func (op Operation) SSAType() string {
379 if op.Out[0].Class == "greg" {
380 return fmt.Sprintf("types.Types[types.T%s]", strings.ToUpper(op.goNormalType()))
381 }
382 return fmt.Sprintf("types.TypeVec%d", *op.Out[0].Bits)
383 }
384
385
386
387 func (op Operation) GoType() string {
388 if op.Out[0].Class == "greg" {
389 return op.goNormalType()
390 }
391 return *op.Out[0].Go
392 }
393
394
395
396
397 func (op Operation) ImmName() string {
398 return op.Op0Name("constant")
399 }
400
401 func (o Operand) OpName(s string) string {
402 if n := o.Name; n != nil {
403 return *n
404 }
405 if o.Class == "mask" {
406 return "mask"
407 }
408 return s
409 }
410
411 func (o Operand) OpNameAndType(s string) string {
412 return o.OpName(s) + " " + *o.Go
413 }
414
415
416 func (op Operation) GoExported() string {
417 return capitalizeFirst(op.Go)
418 }
419
420
421 func (op Operation) DocumentationExported() string {
422 return strings.ReplaceAll(op.Documentation, op.Go, op.GoExported())
423 }
424
425
426
427 func (op Operation) Op0Name(s string) string {
428 return op.In[0].OpName(s)
429 }
430
431
432
433 func (op Operation) Op1Name(s string) string {
434 return op.In[1].OpName(s)
435 }
436
437
438
439 func (op Operation) Op2Name(s string) string {
440 return op.In[2].OpName(s)
441 }
442
443
444
445 func (op Operation) Op3Name(s string) string {
446 return op.In[3].OpName(s)
447 }
448
449
450
451
452 func (op Operation) Op0NameAndType(s string) string {
453 return op.In[0].OpNameAndType(s)
454 }
455
456
457
458
459 func (op Operation) Op1NameAndType(s string) string {
460 return op.In[1].OpNameAndType(s)
461 }
462
463
464
465
466 func (op Operation) Op2NameAndType(s string) string {
467 return op.In[2].OpNameAndType(s)
468 }
469
470
471
472
473 func (op Operation) Op3NameAndType(s string) string {
474 return op.In[3].OpNameAndType(s)
475 }
476
477
478
479
480 func (op Operation) Op4NameAndType(s string) string {
481 return op.In[4].OpNameAndType(s)
482 }
483
// Operation class names indexed by input count (see classifyOp).
// immClasses is indexed by the full input count, immediate included, of an
// op with a variable immediate; classes by the input count after any constant
// immediate is stripped. The "BAD" entries fill indices classifyOp never selects.
var (
	immClasses = []string{"BAD0Imm", "BAD1Imm", "op1Imm8", "op2Imm8", "op3Imm8", "op4Imm8"}
	classes    = []string{"BAD0", "op1", "op2", "op3", "op4"}
)
486
487
488
489
490
491
492
493 func classifyOp(op Operation) (string, Operation, error) {
494 _, _, _, immType, gOp := op.shape()
495
496 var class string
497
498 if immType == VarImm || immType == ConstVarImm {
499 switch l := len(op.In); l {
500 case 1:
501 return "", op, fmt.Errorf("simdgen does not recognize this operation of only immediate input: %s", op)
502 case 2, 3, 4, 5:
503 class = immClasses[l]
504 default:
505 return "", op, fmt.Errorf("simdgen does not recognize this operation of input length %d: %s", len(op.In), op)
506 }
507 if order := op.OperandOrder; order != nil {
508 class += "_" + *order
509 }
510 return class, op, nil
511 } else {
512 switch l := len(gOp.In); l {
513 case 1, 2, 3, 4:
514 class = classes[l]
515 default:
516 return "", op, fmt.Errorf("simdgen does not recognize this operation of input length %d: %s", len(op.In), op)
517 }
518 if order := op.OperandOrder; order != nil {
519 class += "_" + *order
520 }
521 return class, gOp, nil
522 }
523 }
524
525 func checkVecAsScalar(op Operation) (idx int, err error) {
526 idx = -1
527 sSize := 0
528 for i, o := range op.In {
529 if o.TreatLikeAScalarOfSize != nil {
530 if idx == -1 {
531 idx = i
532 sSize = *o.TreatLikeAScalarOfSize
533 } else {
534 err = fmt.Errorf("simdgen only supports one TreatLikeAScalarOfSize in the arg list: %s", op)
535 return
536 }
537 }
538 }
539 if idx >= 0 {
540 if sSize != 8 && sSize != 16 && sSize != 32 && sSize != 64 {
541 err = fmt.Errorf("simdgen does not recognize this uint size: %d, %s", sSize, op)
542 return
543 }
544 }
545 return
546 }
547
548 func rewriteVecAsScalarRegInfo(op Operation, regInfo string) (string, error) {
549 idx, err := checkVecAsScalar(op)
550 if err != nil {
551 return "", err
552 }
553 if idx != -1 {
554 if regInfo == "v21" {
555 regInfo = "vfpv"
556 } else if regInfo == "v2kv" {
557 regInfo = "vfpkv"
558 } else if regInfo == "v31" {
559 regInfo = "v2fpv"
560 } else if regInfo == "v3kv" {
561 regInfo = "v2fpkv"
562 } else {
563 return "", fmt.Errorf("simdgen does not recognize uses of treatLikeAScalarOfSize with op regShape %s in op: %s", regInfo, op)
564 }
565 }
566 return regInfo, nil
567 }
568
569 func rewriteLastVregToMem(op Operation) Operation {
570 newIn := make([]Operand, len(op.In))
571 lastVregIdx := -1
572 for i := range len(op.In) {
573 newIn[i] = op.In[i]
574 if op.In[i].Class == "vreg" {
575 lastVregIdx = i
576 }
577 }
578
579 if lastVregIdx == -1 {
580 panic("simdgen cannot find one vreg in the mem op vreg original")
581 }
582 newIn[lastVregIdx].Class = "memory"
583 op.In = newIn
584
585 return op
586 }
587
588
589 func dedup(ops []Operation) (deduped []Operation) {
590 for _, op := range ops {
591 seen := false
592 for _, dop := range deduped {
593 if reflect.DeepEqual(op, dop) {
594 seen = true
595 break
596 }
597 }
598 if !seen {
599 deduped = append(deduped, op)
600 }
601 }
602 return
603 }
604
605 func (op Operation) GenericName() string {
606 if op.OperandOrder != nil {
607 switch *op.OperandOrder {
608 case "21Type1", "231Type1":
609
610 return op.Go + *op.In[1].Go
611 }
612 }
613 if op.In[0].Class == "immediate" {
614 return op.Go + *op.In[1].Go
615 }
616 return op.Go + *op.In[0].Go
617 }
618
619
620
621
622
623 func dedupGodef(ops []Operation) ([]Operation, error) {
624 seen := map[string][]Operation{}
625 for _, op := range ops {
626 _, _, _, _, gOp := op.shape()
627
628 gN := gOp.GenericName()
629 seen[gN] = append(seen[gN], op)
630 }
631 if *FlagReportDup {
632 for gName, dup := range seen {
633 if len(dup) > 1 {
634 log.Printf("Duplicate for %s:\n", gName)
635 for _, op := range dup {
636 log.Printf("%s\n", op)
637 }
638 }
639 }
640 return ops, nil
641 }
642 isAVX512 := func(op Operation) bool {
643 return strings.Contains(op.CPUFeature, "AVX512")
644 }
645 deduped := []Operation{}
646 for _, dup := range seen {
647 if len(dup) > 1 {
648 slices.SortFunc(dup, func(i, j Operation) int {
649
650 if !isAVX512(i) && isAVX512(j) {
651 return -1
652 }
653 if isAVX512(i) && !isAVX512(j) {
654 return 1
655 }
656 if i.CPUFeature != j.CPUFeature {
657 return strings.Compare(i.CPUFeature, j.CPUFeature)
658 }
659
660
661
662
663 if i.MemFeatures != nil && j.MemFeatures == nil {
664 return -1
665 }
666 if i.MemFeatures == nil && j.MemFeatures != nil {
667 return 1
668 }
669 if i.Commutative != j.Commutative {
670 if j.Commutative {
671 return -1
672 }
673 return 1
674 }
675
676 return 0
677 })
678 }
679 deduped = append(deduped, dup[0])
680 }
681 slices.SortFunc(deduped, compareOperations)
682 return deduped, nil
683 }
684
685
686
687 func copyConstImm(ops []Operation) error {
688 for _, op := range ops {
689 if op.ConstImm == nil {
690 continue
691 }
692 _, _, _, immType, _ := op.shape()
693
694 if immType == ConstImm || immType == ConstVarImm {
695 op.In[0].Const = op.ConstImm
696 }
697
698
699 }
700 return nil
701 }
702
// capitalizeFirst returns s with its first rune upper-cased. The string is
// round-tripped through []rune, so the empty string stays empty.
func capitalizeFirst(s string) string {
	runes := []rune(s)
	if len(runes) > 0 {
		runes[0] = unicode.ToUpper(runes[0])
	}
	return string(runes)
}
712
713
714
715
716
717
718
// overwrite applies the per-operand Overwrite* fields of every operation in
// ops, mutating the operands in place (slice elements and their pointer
// fields).
//
// For each operand:
//   - OverwriteElementBits replaces ElemBits, recomputes Lanes as
//     Bits/ElemBits, and rebuilds Go as "<Base><ElemBits>x<Lanes>" with Base
//     capitalized.
//   - OverwriteClass turns a vreg operand into a mask. It requires
//     OverwriteBase == "int" and OverwriteClass == "mask", and sets Go to
//     "Mask<ElemBits>x<Lanes>".
//   - OverwriteBase alone rewrites the base name inside Go and sets Base.
//
// Malformed overwrite specifications panic. An OverwriteClass on an input,
// or an OverwriteClass output on an operation that has mask inputs, is
// returned as an error.
func overwrite(ops []Operation) error {
	// hasClassOverwrite records whether the closure below applied an
	// OverwriteClass; it is reset for each operation.
	hasClassOverwrite := false
	// overwrite applies the overwrites of operand op[idx]; o is used only for
	// error messages. It panics on invalid specifications and returns nil
	// otherwise.
	overwrite := func(op []Operand, idx int, o Operation) error {
		if op[idx].OverwriteElementBits != nil {
			if op[idx].ElemBits == nil {
				panic(fmt.Errorf("ElemBits is nil at operand %d of %v", idx, o))
			}
			*op[idx].ElemBits = *op[idx].OverwriteElementBits
			*op[idx].Lanes = *op[idx].Bits / *op[idx].ElemBits
			*op[idx].Go = fmt.Sprintf("%s%dx%d", capitalizeFirst(*op[idx].Base), *op[idx].ElemBits, *op[idx].Lanes)
		}
		if op[idx].OverwriteClass != nil {
			if op[idx].OverwriteBase == nil {
				panic(fmt.Errorf("simdgen: [OverwriteClass] must be set together with [OverwriteBase]: %s", op[idx]))
			}
			oBase := *op[idx].OverwriteBase
			oClass := *op[idx].OverwriteClass
			if oClass != "mask" {
				panic(fmt.Errorf("simdgen: [Class] overwrite only supports overwritting to mask: %s", op[idx]))
			}
			if oBase != "int" {
				panic(fmt.Errorf("simdgen: [Class] overwrite must set [OverwriteBase] to int: %s", op[idx]))
			}
			if op[idx].Class != "vreg" {
				panic(fmt.Errorf("simdgen: [Class] overwrite must be overwriting [Class] from vreg: %s", op[idx]))
			}
			hasClassOverwrite = true
			*op[idx].Base = oBase
			op[idx].Class = oClass
			*op[idx].Go = fmt.Sprintf("Mask%dx%d", *op[idx].ElemBits, *op[idx].Lanes)
		} else if op[idx].OverwriteBase != nil {
			oBase := *op[idx].OverwriteBase
			// Replace the capitalized base inside the Go type name.
			*op[idx].Go = strings.ReplaceAll(*op[idx].Go, capitalizeFirst(*op[idx].Base), capitalizeFirst(oBase))
			// Presumably greg Go types spell the base in lowercase (scalar type
			// names), hence the extra lowercase replacement — TODO confirm.
			if op[idx].Class == "greg" {
				*op[idx].Go = strings.ReplaceAll(*op[idx].Go, *op[idx].Base, oBase)
			}
			*op[idx].Base = oBase
		}
		return nil
	}
	for i, o := range ops {
		hasClassOverwrite = false
		for j := range ops[i].In {
			if err := overwrite(ops[i].In, j, o); err != nil {
				return err
			}
			// Class overwrites are only supported on outputs.
			if hasClassOverwrite {
				return fmt.Errorf("simdgen does not support [OverwriteClass] in inputs: %s", ops[i])
			}
		}
		for j := range ops[i].Out {
			if err := overwrite(ops[i].Out, j, o); err != nil {
				return err
			}
		}
		// An output turned into a mask is not allowed on an operation that also
		// takes mask inputs.
		if hasClassOverwrite {
			for _, in := range ops[i].In {
				if in.Class == "mask" {
					return fmt.Errorf("simdgen only supports [OverwriteClass] for operations without mask inputs")
				}
			}
		}
	}
	return nil
}
784
785
786
787
788
789
790
791 func reportXEDInconsistency(ops []Operation) error {
792 for _, o := range ops {
793 if o.NameAndSizeCheck != nil {
794 suffixSizeMap := map[byte]int{'B': 8, 'W': 16, 'D': 32, 'Q': 64}
795 checkOperand := func(opr Operand) error {
796 if opr.ElemBits == nil {
797 return fmt.Errorf("simdgen expects elemBits to be set when performing NameAndSizeCheck")
798 }
799 if v, ok := suffixSizeMap[o.Asm[len(o.Asm)-1]]; !ok {
800 return fmt.Errorf("simdgen expects asm to end with [BWDQ] when performing NameAndSizeCheck")
801 } else {
802 if v != *opr.ElemBits {
803 return fmt.Errorf("simdgen finds NameAndSizeCheck inconsistency in def: %s", o)
804 }
805 }
806 return nil
807 }
808 for _, in := range o.In {
809 if in.Class != "vreg" && in.Class != "mask" {
810 continue
811 }
812 if in.TreatLikeAScalarOfSize != nil {
813
814 continue
815 }
816 if err := checkOperand(in); err != nil {
817 return err
818 }
819 }
820 for _, out := range o.Out {
821 if err := checkOperand(out); err != nil {
822 return err
823 }
824 }
825 }
826 }
827 return nil
828 }
829
830 func (o *Operation) hasMaskedMerging(maskType maskShape, outType outShape) bool {
831
832 return o.OperandOrder == nil && o.SpecialLower == nil && maskType == OneMask && outType == OneVregOut &&
833 len(o.InVariant) == 1 && !strings.Contains(o.Asm, "BLEND") && !strings.Contains(o.Asm, "VMOVDQU")
834 }
835
// getVbcstData parses a string of the form "feat1=<A>;feat2=<B>" and returns
// A and B.
//
// A is everything between "feat1=" and the first ';' and must be non-empty.
// B is the first whitespace-delimited token after "feat2=", which must follow
// that ';' directly. Malformed input panics.
//
// The previous implementation used fmt.Sscanf with "%[^;]". Go's fmt has no
// scanset verbs, so that call always failed with "bad verb '%['" and the
// function always panicked.
func getVbcstData(s string) (feat1Match, feat2Match string) {
	const feat1Prefix, feat2Prefix = "feat1=", "feat2="
	if !strings.HasPrefix(s, feat1Prefix) {
		panic(fmt.Errorf("simdgen: vbcst data %q must start with %q", s, feat1Prefix))
	}
	f1, rest, found := strings.Cut(s[len(feat1Prefix):], ";")
	if !found || f1 == "" {
		panic(fmt.Errorf("simdgen: vbcst data %q needs a non-empty feat1 value terminated by ';'", s))
	}
	if !strings.HasPrefix(rest, feat2Prefix) {
		panic(fmt.Errorf("simdgen: vbcst data %q must have %q after ';'", s, feat2Prefix))
	}
	// Mirror %s scanning semantics: skip leading space, stop at the next space.
	fields := strings.Fields(rest[len(feat2Prefix):])
	if len(fields) == 0 {
		panic(fmt.Errorf("simdgen: vbcst data %q needs a non-empty feat2 value", s))
	}
	return f1, fields[0]
}
843
844 func (o Operation) String() string {
845 return pprints(o)
846 }
847
848 func (op Operand) String() string {
849 return pprints(op)
850 }
851
View as plain text