1
2
3
4
5 package main
6
7 import (
8 "bytes"
9 "fmt"
10 "log"
11 "sort"
12 "strings"
13 )
14
// simdMachineOpsTmpl is the template for the generated simdAMD64Ops function.
//
// The generated function takes one regInfo parameter per supported register
// constraint name and returns a single []opData literal. The template ranges
// over six lists, in this order:
//
//   - .OpsData:           plain ops
//   - .OpsDataImm:        ops emitted with aux "UInt8"
//   - .OpsDataLoad:       ops emitted with aux "SymOff" and symEffect "Read"
//   - .OpsDataImmLoad:    ops emitted with aux "SymValAndOff" and symEffect "Read"
//   - .OpsDataMerging:    ops emitted with a "Merging" name suffix
//   - .OpsDataImmMerging: "Merging" ops emitted with aux "UInt8"
//
// The two merging sections hard-code commutative: false and
// resultInArg0: true and do not read .Comm or .ResultInArg0 from the data.
const simdMachineOpsTmpl = `
package main

func simdAMD64Ops(v11, v21, v2k, vkv, v2kv, v2kk, v31, v3kv, vgpv, vgp, vfpv, vfpkv, w11, w21, w2k, wkw, w2kw, w2kk, w31, w3kw, wgpw, wgp, wfpw, wfpkw,
wkwload, v21load, v31load, v11load, w21load, w31load, w2kload, w2kwload, w11load, w3kwload, w2kkload, v31x0AtIn2 regInfo) []opData {
return []opData{
{{- range .OpsData }}
{name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: {{.Comm}}, typ: "{{.Type}}", resultInArg0: {{.ResultInArg0}}},
{{- end }}
{{- range .OpsDataImm }}
{name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", aux: "UInt8", commutative: {{.Comm}}, typ: "{{.Type}}", resultInArg0: {{.ResultInArg0}}},
{{- end }}
{{- range .OpsDataLoad}}
{name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: {{.Comm}}, typ: "{{.Type}}", aux: "SymOff", symEffect: "Read", resultInArg0: {{.ResultInArg0}}},
{{- end}}
{{- range .OpsDataImmLoad}}
{name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: {{.Comm}}, typ: "{{.Type}}", aux: "SymValAndOff", symEffect: "Read", resultInArg0: {{.ResultInArg0}}},
{{- end}}
{{- range .OpsDataMerging }}
{name: "{{.OpName}}Merging", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: false, typ: "{{.Type}}", resultInArg0: true},
{{- end }}
{{- range .OpsDataImmMerging }}
{name: "{{.OpName}}Merging", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", aux: "UInt8", commutative: false, typ: "{{.Type}}", resultInArg0: true},
{{- end }}
}
}
`
42
43
44
45 func writeSIMDMachineOps(ops []Operation) *bytes.Buffer {
46 t := templateOf(simdMachineOpsTmpl, "simdAMD64Ops")
47 buffer := new(bytes.Buffer)
48 buffer.WriteString(generatedHeader)
49
50 type opData struct {
51 OpName string
52 Asm string
53 OpInLen int
54 RegInfo string
55 Comm bool
56 Type string
57 ResultInArg0 bool
58 }
59 type machineOpsData struct {
60 OpsData []opData
61 OpsDataImm []opData
62 OpsDataLoad []opData
63 OpsDataImmLoad []opData
64 OpsDataMerging []opData
65 OpsDataImmMerging []opData
66 }
67
68 regInfoSet := map[string]bool{
69 "v11": true, "v21": true, "v2k": true, "v2kv": true, "v2kk": true, "vkv": true, "v31": true, "v3kv": true, "vgpv": true, "vgp": true, "vfpv": true, "vfpkv": true,
70 "w11": true, "w21": true, "w2k": true, "w2kw": true, "w2kk": true, "wkw": true, "w31": true, "w3kw": true, "wgpw": true, "wgp": true, "wfpw": true, "wfpkw": true,
71 "wkwload": true, "v21load": true, "v31load": true, "v11load": true, "w21load": true, "w31load": true, "w2kload": true, "w2kwload": true, "w11load": true,
72 "w3kwload": true, "w2kkload": true, "v31x0AtIn2": true}
73 opsData := make([]opData, 0)
74 opsDataImm := make([]opData, 0)
75 opsDataLoad := make([]opData, 0)
76 opsDataImmLoad := make([]opData, 0)
77 opsDataMerging := make([]opData, 0)
78 opsDataImmMerging := make([]opData, 0)
79
80
81 best := make(map[string]Operation)
82 var mOpOrder []string
83 countOverrides := func(s []Operand) int {
84 a := 0
85 for _, o := range s {
86 if o.OverwriteBase != nil {
87 a++
88 }
89 }
90 return a
91 }
92 for _, op := range ops {
93 _, _, maskType, _, gOp := op.shape()
94 asm := machineOpName(maskType, gOp)
95 other, ok := best[asm]
96 if !ok {
97 best[asm] = op
98 mOpOrder = append(mOpOrder, asm)
99 continue
100 }
101 if !op.Commutative && other.Commutative {
102 best[asm] = op
103 continue
104 }
105
106 if countOverrides(op.In)+countOverrides(op.Out) < countOverrides(other.In)+countOverrides(other.Out) {
107 best[asm] = op
108 }
109 }
110
111 regInfoErrs := make([]error, 0)
112 regInfoMissing := make(map[string]bool, 0)
113 for _, asm := range mOpOrder {
114 op := best[asm]
115 shapeIn, shapeOut, maskType, _, gOp := op.shape()
116
117
118
119 makeRegInfo := func(op Operation, mem memShape) (string, error) {
120 regInfo, err := op.regShape(mem)
121 if err != nil {
122 panic(err)
123 }
124 regInfo, err = rewriteVecAsScalarRegInfo(op, regInfo)
125 if err != nil {
126 if mem == NoMem || mem == InvalidMem {
127 panic(err)
128 }
129 return "", err
130 }
131 if regInfo == "v01load" {
132 regInfo = "vload"
133 }
134
135 if strings.Contains(op.CPUFeature, "AVX512") {
136 regInfo = strings.ReplaceAll(regInfo, "v", "w")
137 }
138 if _, ok := regInfoSet[regInfo]; !ok {
139 regInfoErrs = append(regInfoErrs, fmt.Errorf("unsupported register constraint, please update the template and AMD64Ops.go: %s. Op is %s", regInfo, op))
140 regInfoMissing[regInfo] = true
141 }
142 return regInfo, nil
143 }
144 regInfo, err := makeRegInfo(op, NoMem)
145 if err != nil {
146 panic(err)
147 }
148 var outType string
149 if shapeOut == OneVregOut || shapeOut == OneVregOutAtIn || gOp.Out[0].OverwriteClass != nil {
150
151 outType = fmt.Sprintf("Vec%d", *gOp.Out[0].Bits)
152 } else if shapeOut == OneGregOut {
153 outType = gOp.GoType()
154 } else if shapeOut == OneKmaskOut {
155 outType = "Mask"
156 } else {
157 panic(fmt.Errorf("simdgen does not recognize this output shape: %d", shapeOut))
158 }
159 resultInArg0 := false
160 if shapeOut == OneVregOutAtIn {
161 resultInArg0 = true
162 }
163 var memOpData *opData
164 regInfoMerging := regInfo
165 hasMerging := false
166 if op.MemFeatures != nil && *op.MemFeatures == "vbcst" {
167
168
169 opMem := rewriteLastVregToMem(op)
170 regInfo, err := makeRegInfo(opMem, VregMemIn)
171 if err != nil {
172
173
174
175 if *Verbose {
176 log.Printf("Seen error: %e", err)
177 }
178 } else {
179 memOpData = &opData{asm + "load", gOp.Asm, len(gOp.In) + 1, regInfo, false, outType, resultInArg0}
180 }
181 }
182 hasMerging = gOp.hasMaskedMerging(maskType, shapeOut)
183 if hasMerging && !resultInArg0 {
184
185
186 newIn := make([]Operand, len(op.In), len(op.In)+1)
187 copy(newIn, op.In)
188 op.In = newIn
189 op.In = append(op.In, op.Out[0])
190 op.sortOperand()
191 regInfoMerging, err = makeRegInfo(op, NoMem)
192 if err != nil {
193 panic(err)
194 }
195 }
196
197 if shapeIn == OneImmIn || shapeIn == OneKmaskImmIn {
198 opsDataImm = append(opsDataImm, opData{asm, gOp.Asm, len(gOp.In), regInfo, gOp.Commutative, outType, resultInArg0})
199 if memOpData != nil {
200 if *op.MemFeatures != "vbcst" {
201 panic("simdgen only knows vbcst for mem ops for now")
202 }
203 opsDataImmLoad = append(opsDataImmLoad, *memOpData)
204 }
205 if hasMerging {
206 mergingLen := len(gOp.In)
207 if !resultInArg0 {
208 mergingLen++
209 }
210 opsDataImmMerging = append(opsDataImmMerging, opData{asm, gOp.Asm, mergingLen, regInfoMerging, gOp.Commutative, outType, resultInArg0})
211 }
212 } else {
213 opsData = append(opsData, opData{asm, gOp.Asm, len(gOp.In), regInfo, gOp.Commutative, outType, resultInArg0})
214 if memOpData != nil {
215 if *op.MemFeatures != "vbcst" {
216 panic("simdgen only knows vbcst for mem ops for now")
217 }
218 opsDataLoad = append(opsDataLoad, *memOpData)
219 }
220 if hasMerging {
221 mergingLen := len(gOp.In)
222 if !resultInArg0 {
223 mergingLen++
224 }
225 opsDataMerging = append(opsDataMerging, opData{asm, gOp.Asm, mergingLen, regInfoMerging, gOp.Commutative, outType, resultInArg0})
226 }
227 }
228 }
229 if len(regInfoErrs) != 0 {
230 for _, e := range regInfoErrs {
231 log.Printf("Errors: %e\n", e)
232 }
233 panic(fmt.Errorf("these regInfo unseen: %v", regInfoMissing))
234 }
235 sort.Slice(opsData, func(i, j int) bool {
236 return compareNatural(opsData[i].OpName, opsData[j].OpName) < 0
237 })
238 sort.Slice(opsDataImm, func(i, j int) bool {
239 return compareNatural(opsDataImm[i].OpName, opsDataImm[j].OpName) < 0
240 })
241 sort.Slice(opsDataLoad, func(i, j int) bool {
242 return compareNatural(opsDataLoad[i].OpName, opsDataLoad[j].OpName) < 0
243 })
244 sort.Slice(opsDataImmLoad, func(i, j int) bool {
245 return compareNatural(opsDataImmLoad[i].OpName, opsDataImmLoad[j].OpName) < 0
246 })
247 sort.Slice(opsDataMerging, func(i, j int) bool {
248 return compareNatural(opsDataMerging[i].OpName, opsDataMerging[j].OpName) < 0
249 })
250 sort.Slice(opsDataImmMerging, func(i, j int) bool {
251 return compareNatural(opsDataImmMerging[i].OpName, opsDataImmMerging[j].OpName) < 0
252 })
253 err := t.Execute(buffer, machineOpsData{opsData, opsDataImm, opsDataLoad, opsDataImmLoad,
254 opsDataMerging, opsDataImmMerging})
255 if err != nil {
256 panic(fmt.Errorf("failed to execute template: %w", err))
257 }
258
259 return buffer
260 }
261
View as plain text