Source file src/simd/archsimd/_gen/simdgen/gen_simdMachineOps.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"log"
    11  	"sort"
    12  	"strings"
    13  )
    14  
// simdMachineOpsTmpl is the template rendered by writeSIMDMachineOps (via
// templateOf). It produces a simdAMD64Ops function returning one opData
// literal per machine op. The data value must provide the slices OpsData,
// OpsDataImm, OpsDataLoad, OpsDataImmLoad, OpsDataMerging and
// OpsDataImmMerging, whose elements carry the fields referenced below
// (OpName, OpInLen, RegInfo, Asm, Type, and for most sections Comm and
// ResultInArg0).
//
// Aux kinds per section: immediate ops use "UInt8", load ops use "SymOff",
// and immediate load ops use "SymValAndOff". Merging entries get a "Merging"
// name suffix and are always emitted with commutative: false and
// resultInArg0: true.
//
// The parameters of simdAMD64Ops are the register-constraint names an entry's
// RegInfo may name. Adding a new constraint requires updating this signature
// and the regInfoSet map in writeSIMDMachineOps.
const simdMachineOpsTmpl = `
package main

func simdAMD64Ops(v11, v21, v2k, vkv, v2kv, v2kk, v31, v3kv, vgpv, vgp, vfpv, vfpkv, w11, w21, w2k, wkw, w2kw, w2kk, w31, w3kw, wgpw, wgp, wfpw, wfpkw,
	wkwload, v21load, v31load, v11load, w21load, w31load, w2kload, w2kwload, w11load, w3kwload, w2kkload, v31x0AtIn2 regInfo) []opData {
	return []opData{
{{- range .OpsData }}
		{name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: {{.Comm}}, typ: "{{.Type}}", resultInArg0: {{.ResultInArg0}}},
{{- end }}
{{- range .OpsDataImm }}
		{name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", aux: "UInt8", commutative: {{.Comm}}, typ: "{{.Type}}", resultInArg0: {{.ResultInArg0}}},
{{- end }}
{{- range .OpsDataLoad}}
		{name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: {{.Comm}}, typ: "{{.Type}}", aux: "SymOff", symEffect: "Read", resultInArg0: {{.ResultInArg0}}},
{{- end}}
{{- range .OpsDataImmLoad}}
		{name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: {{.Comm}}, typ: "{{.Type}}", aux: "SymValAndOff", symEffect: "Read", resultInArg0: {{.ResultInArg0}}},
{{- end}}
{{- range .OpsDataMerging }}
		{name: "{{.OpName}}Merging", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: false, typ: "{{.Type}}", resultInArg0: true},
{{- end }}
{{- range .OpsDataImmMerging }}
		{name: "{{.OpName}}Merging", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", aux: "UInt8", commutative: false, typ: "{{.Type}}", resultInArg0: true},
{{- end }}
	}
}
`
    42  
    43  // writeSIMDMachineOps generates the machine ops and writes it to simdAMD64ops.go
    44  // within the specified directory.
    45  func writeSIMDMachineOps(ops []Operation) *bytes.Buffer {
    46  	t := templateOf(simdMachineOpsTmpl, "simdAMD64Ops")
    47  	buffer := new(bytes.Buffer)
    48  	buffer.WriteString(generatedHeader)
    49  
    50  	type opData struct {
    51  		OpName       string
    52  		Asm          string
    53  		OpInLen      int
    54  		RegInfo      string
    55  		Comm         bool
    56  		Type         string
    57  		ResultInArg0 bool
    58  	}
    59  	type machineOpsData struct {
    60  		OpsData           []opData
    61  		OpsDataImm        []opData
    62  		OpsDataLoad       []opData
    63  		OpsDataImmLoad    []opData
    64  		OpsDataMerging    []opData
    65  		OpsDataImmMerging []opData
    66  	}
    67  
    68  	regInfoSet := map[string]bool{
    69  		"v11": true, "v21": true, "v2k": true, "v2kv": true, "v2kk": true, "vkv": true, "v31": true, "v3kv": true, "vgpv": true, "vgp": true, "vfpv": true, "vfpkv": true,
    70  		"w11": true, "w21": true, "w2k": true, "w2kw": true, "w2kk": true, "wkw": true, "w31": true, "w3kw": true, "wgpw": true, "wgp": true, "wfpw": true, "wfpkw": true,
    71  		"wkwload": true, "v21load": true, "v31load": true, "v11load": true, "w21load": true, "w31load": true, "w2kload": true, "w2kwload": true, "w11load": true,
    72  		"w3kwload": true, "w2kkload": true, "v31x0AtIn2": true}
    73  	opsData := make([]opData, 0)
    74  	opsDataImm := make([]opData, 0)
    75  	opsDataLoad := make([]opData, 0)
    76  	opsDataImmLoad := make([]opData, 0)
    77  	opsDataMerging := make([]opData, 0)
    78  	opsDataImmMerging := make([]opData, 0)
    79  
    80  	// Determine the "best" version of an instruction to use
    81  	best := make(map[string]Operation)
    82  	var mOpOrder []string
    83  	countOverrides := func(s []Operand) int {
    84  		a := 0
    85  		for _, o := range s {
    86  			if o.OverwriteBase != nil {
    87  				a++
    88  			}
    89  		}
    90  		return a
    91  	}
    92  	for _, op := range ops {
    93  		_, _, maskType, _, gOp := op.shape()
    94  		asm := machineOpName(maskType, gOp)
    95  		other, ok := best[asm]
    96  		if !ok {
    97  			best[asm] = op
    98  			mOpOrder = append(mOpOrder, asm)
    99  			continue
   100  		}
   101  		if !op.Commutative && other.Commutative { // if there's a non-commutative version of the op, it wins.
   102  			best[asm] = op
   103  			continue
   104  		}
   105  		// see if "op" is better than "other"
   106  		if countOverrides(op.In)+countOverrides(op.Out) < countOverrides(other.In)+countOverrides(other.Out) {
   107  			best[asm] = op
   108  		}
   109  	}
   110  
   111  	regInfoErrs := make([]error, 0)
   112  	regInfoMissing := make(map[string]bool, 0)
   113  	for _, asm := range mOpOrder {
   114  		op := best[asm]
   115  		shapeIn, shapeOut, maskType, _, gOp := op.shape()
   116  
   117  		// TODO: all our masked operations are now zeroing, we need to generate machine ops with merging masks, maybe copy
   118  		// one here with a name suffix "Merging". The rewrite rules will need them.
   119  		makeRegInfo := func(op Operation, mem memShape) (string, error) {
   120  			regInfo, err := op.regShape(mem)
   121  			if err != nil {
   122  				panic(err)
   123  			}
   124  			regInfo, err = rewriteVecAsScalarRegInfo(op, regInfo)
   125  			if err != nil {
   126  				if mem == NoMem || mem == InvalidMem {
   127  					panic(err)
   128  				}
   129  				return "", err
   130  			}
   131  			if regInfo == "v01load" {
   132  				regInfo = "vload"
   133  			}
   134  			// Makes AVX512 operations use upper registers
   135  			if strings.Contains(op.CPUFeature, "AVX512") {
   136  				regInfo = strings.ReplaceAll(regInfo, "v", "w")
   137  			}
   138  			if _, ok := regInfoSet[regInfo]; !ok {
   139  				regInfoErrs = append(regInfoErrs, fmt.Errorf("unsupported register constraint, please update the template and AMD64Ops.go: %s.  Op is %s", regInfo, op))
   140  				regInfoMissing[regInfo] = true
   141  			}
   142  			return regInfo, nil
   143  		}
   144  		regInfo, err := makeRegInfo(op, NoMem)
   145  		if err != nil {
   146  			panic(err)
   147  		}
   148  		var outType string
   149  		if shapeOut == OneVregOut || shapeOut == OneVregOutAtIn || gOp.Out[0].OverwriteClass != nil {
   150  			// If class overwrite is happening, that's not really a mask but a vreg.
   151  			outType = fmt.Sprintf("Vec%d", *gOp.Out[0].Bits)
   152  		} else if shapeOut == OneGregOut {
   153  			outType = gOp.GoType() // this is a straight Go type, not a VecNNN type
   154  		} else if shapeOut == OneKmaskOut {
   155  			outType = "Mask"
   156  		} else {
   157  			panic(fmt.Errorf("simdgen does not recognize this output shape: %d", shapeOut))
   158  		}
   159  		resultInArg0 := false
   160  		if shapeOut == OneVregOutAtIn {
   161  			resultInArg0 = true
   162  		}
   163  		var memOpData *opData
   164  		regInfoMerging := regInfo
   165  		hasMerging := false
   166  		if op.MemFeatures != nil && *op.MemFeatures == "vbcst" {
   167  			// Right now we only have vbcst case
   168  			// Make a full vec memory variant.
   169  			opMem := rewriteLastVregToMem(op)
   170  			regInfo, err := makeRegInfo(opMem, VregMemIn)
   171  			if err != nil {
   172  				// Just skip it if it's non nill.
   173  				// an error could be triggered by [checkVecAsScalar].
   174  				// TODO: make [checkVecAsScalar] aware of mem ops.
   175  				if *Verbose {
   176  					log.Printf("Seen error: %e", err)
   177  				}
   178  			} else {
   179  				memOpData = &opData{asm + "load", gOp.Asm, len(gOp.In) + 1, regInfo, false, outType, resultInArg0}
   180  			}
   181  		}
   182  		hasMerging = gOp.hasMaskedMerging(maskType, shapeOut)
   183  		if hasMerging && !resultInArg0 {
   184  			// We have to copy the slice here becasue the sort will be visible from other
   185  			// aliases when no reslicing is happening.
   186  			newIn := make([]Operand, len(op.In), len(op.In)+1)
   187  			copy(newIn, op.In)
   188  			op.In = newIn
   189  			op.In = append(op.In, op.Out[0])
   190  			op.sortOperand()
   191  			regInfoMerging, err = makeRegInfo(op, NoMem)
   192  			if err != nil {
   193  				panic(err)
   194  			}
   195  		}
   196  
   197  		if shapeIn == OneImmIn || shapeIn == OneKmaskImmIn {
   198  			opsDataImm = append(opsDataImm, opData{asm, gOp.Asm, len(gOp.In), regInfo, gOp.Commutative, outType, resultInArg0})
   199  			if memOpData != nil {
   200  				if *op.MemFeatures != "vbcst" {
   201  					panic("simdgen only knows vbcst for mem ops for now")
   202  				}
   203  				opsDataImmLoad = append(opsDataImmLoad, *memOpData)
   204  			}
   205  			if hasMerging {
   206  				mergingLen := len(gOp.In)
   207  				if !resultInArg0 {
   208  					mergingLen++
   209  				}
   210  				opsDataImmMerging = append(opsDataImmMerging, opData{asm, gOp.Asm, mergingLen, regInfoMerging, gOp.Commutative, outType, resultInArg0})
   211  			}
   212  		} else {
   213  			opsData = append(opsData, opData{asm, gOp.Asm, len(gOp.In), regInfo, gOp.Commutative, outType, resultInArg0})
   214  			if memOpData != nil {
   215  				if *op.MemFeatures != "vbcst" {
   216  					panic("simdgen only knows vbcst for mem ops for now")
   217  				}
   218  				opsDataLoad = append(opsDataLoad, *memOpData)
   219  			}
   220  			if hasMerging {
   221  				mergingLen := len(gOp.In)
   222  				if !resultInArg0 {
   223  					mergingLen++
   224  				}
   225  				opsDataMerging = append(opsDataMerging, opData{asm, gOp.Asm, mergingLen, regInfoMerging, gOp.Commutative, outType, resultInArg0})
   226  			}
   227  		}
   228  	}
   229  	if len(regInfoErrs) != 0 {
   230  		for _, e := range regInfoErrs {
   231  			log.Printf("Errors: %e\n", e)
   232  		}
   233  		panic(fmt.Errorf("these regInfo unseen: %v", regInfoMissing))
   234  	}
   235  	sort.Slice(opsData, func(i, j int) bool {
   236  		return compareNatural(opsData[i].OpName, opsData[j].OpName) < 0
   237  	})
   238  	sort.Slice(opsDataImm, func(i, j int) bool {
   239  		return compareNatural(opsDataImm[i].OpName, opsDataImm[j].OpName) < 0
   240  	})
   241  	sort.Slice(opsDataLoad, func(i, j int) bool {
   242  		return compareNatural(opsDataLoad[i].OpName, opsDataLoad[j].OpName) < 0
   243  	})
   244  	sort.Slice(opsDataImmLoad, func(i, j int) bool {
   245  		return compareNatural(opsDataImmLoad[i].OpName, opsDataImmLoad[j].OpName) < 0
   246  	})
   247  	sort.Slice(opsDataMerging, func(i, j int) bool {
   248  		return compareNatural(opsDataMerging[i].OpName, opsDataMerging[j].OpName) < 0
   249  	})
   250  	sort.Slice(opsDataImmMerging, func(i, j int) bool {
   251  		return compareNatural(opsDataImmMerging[i].OpName, opsDataImmMerging[j].OpName) < 0
   252  	})
   253  	err := t.Execute(buffer, machineOpsData{opsData, opsDataImm, opsDataLoad, opsDataImmLoad,
   254  		opsDataMerging, opsDataImmMerging})
   255  	if err != nil {
   256  		panic(fmt.Errorf("failed to execute template: %w", err))
   257  	}
   258  
   259  	return buffer
   260  }
   261  

View as plain text