Source file src/simd/archsimd/_gen/simdgen/gen_utility.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bufio"
     9  	"bytes"
    10  	"fmt"
    11  	"go/format"
    12  	"log"
    13  	"os"
    14  	"path/filepath"
    15  	"reflect"
    16  	"slices"
    17  	"sort"
    18  	"strings"
    19  	"text/template"
    20  	"unicode"
    21  )
    22  
    23  func templateOf(temp, name string) *template.Template {
    24  	t, err := template.New(name).Parse(temp)
    25  	if err != nil {
    26  		panic(fmt.Errorf("failed to parse template %s: %w", name, err))
    27  	}
    28  	return t
    29  }
    30  
    31  func createPath(goroot string, file string) (*os.File, error) {
    32  	fp := filepath.Join(goroot, file)
    33  	dir := filepath.Dir(fp)
    34  	err := os.MkdirAll(dir, 0755)
    35  	if err != nil {
    36  		return nil, fmt.Errorf("failed to create directory %s: %w", dir, err)
    37  	}
    38  	f, err := os.Create(fp)
    39  	if err != nil {
    40  		return nil, fmt.Errorf("failed to create file %s: %w", fp, err)
    41  	}
    42  	return f, nil
    43  }
    44  
    45  func formatWriteAndClose(out *bytes.Buffer, goroot string, file string) {
    46  	b, err := format.Source(out.Bytes())
    47  	if err != nil {
    48  		fmt.Fprintf(os.Stderr, "%v\n", err)
    49  		fmt.Fprintf(os.Stderr, "%s\n", numberLines(out.Bytes()))
    50  		fmt.Fprintf(os.Stderr, "%v\n", err)
    51  		panic(err)
    52  	} else {
    53  		writeAndClose(b, goroot, file)
    54  	}
    55  }
    56  
    57  func writeAndClose(b []byte, goroot string, file string) {
    58  	ofile, err := createPath(goroot, file)
    59  	if err != nil {
    60  		panic(err)
    61  	}
    62  	ofile.Write(b)
    63  	ofile.Close()
    64  }
    65  
    66  // numberLines takes a slice of bytes, and returns a string where each line
    67  // is numbered, starting from 1.
    68  func numberLines(data []byte) string {
    69  	var buf bytes.Buffer
    70  	r := bytes.NewReader(data)
    71  	s := bufio.NewScanner(r)
    72  	for i := 1; s.Scan(); i++ {
    73  		fmt.Fprintf(&buf, "%d: %s\n", i, s.Text())
    74  	}
    75  	return buf.String()
    76  }
    77  
// inShape classifies the combination of input operand kinds (vector
// registers, K masks, immediates) that an operation takes.
type inShape uint8

// outShape classifies the single output operand of an operation.
type outShape uint8

// maskShape classifies how an operation uses K mask registers.
type maskShape uint8

// immShape classifies where the value of an operation's immediate comes from.
type immShape uint8

// memShape classifies whether an operation has a memory operand.
type memShape uint8

// Input shapes. InvalidIn is the zero value.
const (
	InvalidIn     inShape = iota
	PureVregIn            // vector register input only
	OneKmaskIn            // vector and kmask input
	OneImmIn              // vector and immediate input
	OneKmaskImmIn         // vector, kmask, and immediate inputs
	PureKmaskIn           // only mask inputs.
)

// Output shapes. InvalidOut is the zero value.
const (
	InvalidOut     outShape = iota
	NoOut                   // no output
	OneVregOut              // (one) vector register output
	OneGregOut              // (one) general register output
	OneKmaskOut             // mask output
	OneVregOutAtIn          // the first input is also the output
)

// Mask shapes. InvalidMask is the zero value.
const (
	InvalidMask maskShape = iota
	NoMask                // no mask
	OneMask               // with mask (K1 to K7)
	AllMasks              // a K mask instruction (K0-K7)
)

// Immediate shapes. InvalidImm is the zero value.
const (
	InvalidImm  immShape = iota
	NoImm                // no immediate
	ConstImm             // const only immediate
	VarImm               // pure imm argument provided by the users
	ConstVarImm          // a combination of user arg and const
)

// Memory shapes. InvalidMem is the zero value.
const (
	InvalidMem memShape = iota
	NoMem
	VregMemIn // The instruction contains a mem input which is loading a vreg.
)
   122  
// shape returns several enums describing the shape of the operation,
// and a modified version of the op:
//
//   - shapeIn: which kinds of inputs (vreg, kmask, immediate) the op takes.
//   - shapeOut: the kind of the op's single output.
//   - maskType: whether the op takes no mask, one mask, or only masks.
//   - immType: where the immediate's value comes from, if there is one.
//   - opNoImm: a copy of op whose In slice excludes the immediate.
//
// opNoImm is a shallow copy: when an immediate is present its In is op.In[1:],
// so it shares the backing array (and any pointer fields) with op.
//
// shape panics on any operation simdgen does not support: anything other than
// exactly one output, an output class other than vreg/greg/mask, an immediate
// that is not 8 bits wide, an immediate with neither Const nor ImmOffset set,
// an unsupported input/output register sharing, or more than one mask input
// combined with an immediate or a non-mask input.
//
// This function does not modify op.
func (op *Operation) shape() (shapeIn inShape, shapeOut outShape, maskType maskShape, immType immShape,
	opNoImm Operation) {
	if len(op.Out) > 1 {
		panic(fmt.Errorf("simdgen only supports 1 output: %s", op))
	}
	var outputReg int
	if len(op.Out) == 1 {
		outputReg = op.Out[0].AsmPos
		if op.Out[0].Class == "vreg" {
			shapeOut = OneVregOut
		} else if op.Out[0].Class == "greg" {
			shapeOut = OneGregOut
		} else if op.Out[0].Class == "mask" {
			shapeOut = OneKmaskOut
		} else {
			panic(fmt.Errorf("simdgen only supports output of class vreg or mask: %s", op))
		}
	} else {
		// Zero outputs: this branch always panics, so NoOut is never
		// actually returned to a caller.
		shapeOut = NoOut
		// TODO: are these only Load/Stores?
		// We manually supported two Load and Store, are those enough?
		panic(fmt.Errorf("simdgen only supports 1 output: %s", op))
	}
	hasImm := false
	maskCount := 0
	hasVreg := false
	for _, in := range op.In {
		// An input that occupies the same assembly position as the output is
		// only accepted once, and only when it is a vreg at position 0.
		if in.AsmPos == outputReg {
			if shapeOut != OneVregOutAtIn && in.AsmPos == 0 && in.Class == "vreg" {
				shapeOut = OneVregOutAtIn
			} else {
				panic(fmt.Errorf("simdgen only support output and input sharing the same position case of \"the first input is vreg and the only output\": %s", op))
			}
		}
		if in.Class == "immediate" {
			// A manual check on XED data found that AMD64 SIMD instructions at most
			// have 1 immediates. So we don't need to check this here.
			if *in.Bits != 8 {
				panic(fmt.Errorf("simdgen only supports immediates of 8 bits: %s", op))
			}
			hasImm = true
		} else if in.Class == "mask" {
			maskCount++
		} else {
			// Every class other than immediate and mask (including greg)
			// is counted as a vector input here.
			hasVreg = true
		}
	}
	opNoImm = *op

	// removeImm drops the first input. It, and the op.In[0] checks below,
	// rely on the immediate being the first input (presumably established by
	// sortOperand before shape is called — TODO confirm).
	removeImm := func(o *Operation) {
		o.In = o.In[1:]
	}
	if hasImm {
		removeImm(&opNoImm)
		if op.In[0].Const != nil {
			if op.In[0].ImmOffset != nil {
				immType = ConstVarImm
			} else {
				immType = ConstImm
			}
		} else if op.In[0].ImmOffset != nil {
			immType = VarImm
		} else {
			panic(fmt.Errorf("simdgen requires imm to have at least one of ImmOffset or Const set: %s", op))
		}
	} else {
		immType = NoImm
	}
	if maskCount == 0 {
		maskType = NoMask
	} else {
		maskType = OneMask
	}
	// checkPureMask validates an operation with more than one mask input.
	// It never returns true: it either panics or returns false.
	checkPureMask := func() bool {
		if hasImm {
			panic(fmt.Errorf("simdgen does not support immediates in pure mask operations: %s", op))
		}
		if hasVreg {
			panic(fmt.Errorf("simdgen does not support more than 1 masks in non-pure mask operations: %s", op))
		}
		return false
	}
	if !hasImm && maskCount == 0 {
		shapeIn = PureVregIn
	} else if !hasImm && maskCount > 0 {
		if maskCount == 1 {
			shapeIn = OneKmaskIn
		} else {
			if checkPureMask() {
				return
			}
			shapeIn = PureKmaskIn
			maskType = AllMasks
		}
	} else if hasImm && maskCount == 0 {
		shapeIn = OneImmIn
	} else {
		if maskCount == 1 {
			shapeIn = OneKmaskImmIn
		} else {
			// hasImm is true here, so checkPureMask always panics and the
			// return below is not reached.
			checkPureMask()
			return
		}
	}
	return
}
   235  
   236  // regShape returns a string representation of the register shape.
   237  func (op *Operation) regShape(mem memShape) (string, error) {
   238  	_, _, _, _, gOp := op.shape()
   239  	var regInfo, fixedName string
   240  	var vRegInCnt, gRegInCnt, kMaskInCnt, vRegOutCnt, gRegOutCnt, kMaskOutCnt, memInCnt, memOutCnt int
   241  	for i, in := range gOp.In {
   242  		switch in.Class {
   243  		case "vreg":
   244  			vRegInCnt++
   245  		case "greg":
   246  			gRegInCnt++
   247  		case "mask":
   248  			kMaskInCnt++
   249  		case "memory":
   250  			if mem != VregMemIn {
   251  				panic("simdgen only knows VregMemIn in regShape")
   252  			}
   253  			memInCnt++
   254  			vRegInCnt++
   255  		}
   256  		if in.FixedReg != nil {
   257  			fixedName = fmt.Sprintf("%sAtIn%d", *in.FixedReg, i)
   258  		}
   259  	}
   260  	for i, out := range gOp.Out {
   261  		// If class overwrite is happening, that's not really a mask but a vreg.
   262  		if out.Class == "vreg" || out.OverwriteClass != nil {
   263  			vRegOutCnt++
   264  		} else if out.Class == "greg" {
   265  			gRegOutCnt++
   266  		} else if out.Class == "mask" {
   267  			kMaskOutCnt++
   268  		} else if out.Class == "memory" {
   269  			if mem != VregMemIn {
   270  				panic("simdgen only knows VregMemIn in regShape")
   271  			}
   272  			vRegOutCnt++
   273  			memOutCnt++
   274  		}
   275  		if out.FixedReg != nil {
   276  			fixedName = fmt.Sprintf("%sAtIn%d", *out.FixedReg, i)
   277  		}
   278  	}
   279  	var inRegs, inMasks, outRegs, outMasks string
   280  
   281  	rmAbbrev := func(s string, i int) string {
   282  		if i == 0 {
   283  			return ""
   284  		}
   285  		if i == 1 {
   286  			return s
   287  		}
   288  		return fmt.Sprintf("%s%d", s, i)
   289  
   290  	}
   291  
   292  	inRegs = rmAbbrev("v", vRegInCnt)
   293  	inRegs += rmAbbrev("gp", gRegInCnt)
   294  	inMasks = rmAbbrev("k", kMaskInCnt)
   295  
   296  	outRegs = rmAbbrev("v", vRegOutCnt)
   297  	outRegs += rmAbbrev("gp", gRegOutCnt)
   298  	outMasks = rmAbbrev("k", kMaskOutCnt)
   299  
   300  	if kMaskInCnt == 0 && kMaskOutCnt == 0 && gRegInCnt == 0 && gRegOutCnt == 0 {
   301  		// For pure v we can abbreviate it as v%d%d.
   302  		regInfo = fmt.Sprintf("v%d%d", vRegInCnt, vRegOutCnt)
   303  	} else if kMaskInCnt == 0 && kMaskOutCnt == 0 {
   304  		regInfo = fmt.Sprintf("%s%s", inRegs, outRegs)
   305  	} else {
   306  		regInfo = fmt.Sprintf("%s%s%s%s", inRegs, inMasks, outRegs, outMasks)
   307  	}
   308  	if memInCnt > 0 {
   309  		if memInCnt == 1 {
   310  			regInfo += "load"
   311  		} else {
   312  			panic("simdgen does not understand more than 1 mem op as of now")
   313  		}
   314  	}
   315  	if memOutCnt > 0 {
   316  		panic("simdgen does not understand memory as output as of now")
   317  	}
   318  	regInfo += fixedName
   319  	return regInfo, nil
   320  }
   321  
   322  // sortOperand sorts op.In by putting immediates first, then vreg, and mask the last.
   323  // TODO: verify that this is a safe assumption of the prog structure.
   324  // from my observation looks like in asm, imms are always the first,
   325  // masks are always the last, with vreg in between.
   326  func (op *Operation) sortOperand() {
   327  	priority := map[string]int{"immediate": 0, "vreg": 1, "greg": 1, "mask": 2}
   328  	sort.SliceStable(op.In, func(i, j int) bool {
   329  		pi := priority[op.In[i].Class]
   330  		pj := priority[op.In[j].Class]
   331  		if pi != pj {
   332  			return pi < pj
   333  		}
   334  		return op.In[i].AsmPos < op.In[j].AsmPos
   335  	})
   336  }
   337  
   338  // adjustAsm adjusts the asm to make it align with Go's assembler.
   339  func (op *Operation) adjustAsm() {
   340  	if op.Asm == "VCVTTPD2DQ" || op.Asm == "VCVTTPD2UDQ" ||
   341  		op.Asm == "VCVTQQ2PS" || op.Asm == "VCVTUQQ2PS" ||
   342  		op.Asm == "VCVTPD2PS" {
   343  		switch *op.In[0].Bits {
   344  		case 128:
   345  			op.Asm += "X"
   346  		case 256:
   347  			op.Asm += "Y"
   348  		}
   349  	}
   350  }
   351  
   352  // goNormalType returns the Go type name for the result of an Op that
   353  // does not return a vector, i.e., that returns a result in a general
   354  // register.  Currently there's only one family of Ops in Go's simd library
   355  // that does this (GetElem), and so this is specialized to work for that,
   356  // but the problem (mismatch betwen hardware register width and Go type
   357  // width) seems likely to recur if there are any other cases.
   358  func (op Operation) goNormalType() string {
   359  	if op.Go == "GetElem" {
   360  		// GetElem returns an element of the vector into a general register
   361  		// but as far as the hardware is concerned, that result is either 32
   362  		// or 64 bits wide, no matter what the vector element width is.
   363  		// This is not "wrong" but it is not the right answer for Go source code.
   364  		// To get the Go type right, combine the base type ("int", "uint", "float"),
   365  		// with the input vector element width in bits (8,16,32,64).
   366  
   367  		at := 0 // proper value of at depends on whether immediate was stripped or not
   368  		if op.In[at].Class == "immediate" {
   369  			at++
   370  		}
   371  		return fmt.Sprintf("%s%d", *op.Out[0].Base, *op.In[at].ElemBits)
   372  	}
   373  	panic(fmt.Errorf("Implement goNormalType for %v", op))
   374  }
   375  
   376  // SSAType returns the string for the type reference in SSA generation,
   377  // for example in the intrinsics generating template.
   378  func (op Operation) SSAType() string {
   379  	if op.Out[0].Class == "greg" {
   380  		return fmt.Sprintf("types.Types[types.T%s]", strings.ToUpper(op.goNormalType()))
   381  	}
   382  	return fmt.Sprintf("types.TypeVec%d", *op.Out[0].Bits)
   383  }
   384  
   385  // GoType returns the Go type returned by this operation (relative to the simd package),
   386  // for example "int32" or "Int8x16".  This is used in a template.
   387  func (op Operation) GoType() string {
   388  	if op.Out[0].Class == "greg" {
   389  		return op.goNormalType()
   390  	}
   391  	return *op.Out[0].Go
   392  }
   393  
   394  // ImmName returns the name to use for an operation's immediate operand.
   395  // This can be overriden in the yaml with "name" on an operand,
   396  // otherwise, for now, "constant"
   397  func (op Operation) ImmName() string {
   398  	return op.Op0Name("constant")
   399  }
   400  
   401  func (o Operand) OpName(s string) string {
   402  	if n := o.Name; n != nil {
   403  		return *n
   404  	}
   405  	if o.Class == "mask" {
   406  		return "mask"
   407  	}
   408  	return s
   409  }
   410  
   411  func (o Operand) OpNameAndType(s string) string {
   412  	return o.OpName(s) + " " + *o.Go
   413  }
   414  
   415  // GoExported returns [Go] with first character capitalized.
   416  func (op Operation) GoExported() string {
   417  	return capitalizeFirst(op.Go)
   418  }
   419  
   420  // DocumentationExported returns [Documentation] with method name capitalized.
   421  func (op Operation) DocumentationExported() string {
   422  	return strings.ReplaceAll(op.Documentation, op.Go, op.GoExported())
   423  }
   424  
   425  // Op0Name returns the name to use for the 0 operand,
   426  // if any is present, otherwise the parameter is used.
   427  func (op Operation) Op0Name(s string) string {
   428  	return op.In[0].OpName(s)
   429  }
   430  
   431  // Op1Name returns the name to use for the 1 operand,
   432  // if any is present, otherwise the parameter is used.
   433  func (op Operation) Op1Name(s string) string {
   434  	return op.In[1].OpName(s)
   435  }
   436  
   437  // Op2Name returns the name to use for the 2 operand,
   438  // if any is present, otherwise the parameter is used.
   439  func (op Operation) Op2Name(s string) string {
   440  	return op.In[2].OpName(s)
   441  }
   442  
   443  // Op3Name returns the name to use for the 3 operand,
   444  // if any is present, otherwise the parameter is used.
   445  func (op Operation) Op3Name(s string) string {
   446  	return op.In[3].OpName(s)
   447  }
   448  
   449  // Op0NameAndType returns the name and type to use for
   450  // the 0 operand, if a name is provided, otherwise
   451  // the parameter value is used as the default.
   452  func (op Operation) Op0NameAndType(s string) string {
   453  	return op.In[0].OpNameAndType(s)
   454  }
   455  
   456  // Op1NameAndType returns the name and type to use for
   457  // the 1 operand, if a name is provided, otherwise
   458  // the parameter value is used as the default.
   459  func (op Operation) Op1NameAndType(s string) string {
   460  	return op.In[1].OpNameAndType(s)
   461  }
   462  
   463  // Op2NameAndType returns the name and type to use for
   464  // the 2 operand, if a name is provided, otherwise
   465  // the parameter value is used as the default.
   466  func (op Operation) Op2NameAndType(s string) string {
   467  	return op.In[2].OpNameAndType(s)
   468  }
   469  
   470  // Op3NameAndType returns the name and type to use for
   471  // the 3 operand, if a name is provided, otherwise
   472  // the parameter value is used as the default.
   473  func (op Operation) Op3NameAndType(s string) string {
   474  	return op.In[3].OpNameAndType(s)
   475  }
   476  
   477  // Op4NameAndType returns the name and type to use for
   478  // the 4 operand, if a name is provided, otherwise
   479  // the parameter value is used as the default.
   480  func (op Operation) Op4NameAndType(s string) string {
   481  	return op.In[4].OpNameAndType(s)
   482  }
   483  
   484  var immClasses []string = []string{"BAD0Imm", "BAD1Imm", "op1Imm8", "op2Imm8", "op3Imm8", "op4Imm8"}
   485  var classes []string = []string{"BAD0", "op1", "op2", "op3", "op4"}
   486  
   487  // classifyOp returns a classification string, modified operation, and perhaps error based
   488  // on the stub and intrinsic shape for the operation.
   489  // The classification string is in the regular expression set "op[1234](Imm8)?(_<order>)?"
   490  // where the "<order>" suffix is optionally attached to the Operation in its input yaml.
   491  // The classification string is used to select a template or a clause of a template
   492  // for intrinsics declaration and the ssagen intrinisics glue code in the compiler.
   493  func classifyOp(op Operation) (string, Operation, error) {
   494  	_, _, _, immType, gOp := op.shape()
   495  
   496  	var class string
   497  
   498  	if immType == VarImm || immType == ConstVarImm {
   499  		switch l := len(op.In); l {
   500  		case 1:
   501  			return "", op, fmt.Errorf("simdgen does not recognize this operation of only immediate input: %s", op)
   502  		case 2, 3, 4, 5:
   503  			class = immClasses[l]
   504  		default:
   505  			return "", op, fmt.Errorf("simdgen does not recognize this operation of input length %d: %s", len(op.In), op)
   506  		}
   507  		if order := op.OperandOrder; order != nil {
   508  			class += "_" + *order
   509  		}
   510  		return class, op, nil
   511  	} else {
   512  		switch l := len(gOp.In); l {
   513  		case 1, 2, 3, 4:
   514  			class = classes[l]
   515  		default:
   516  			return "", op, fmt.Errorf("simdgen does not recognize this operation of input length %d: %s", len(op.In), op)
   517  		}
   518  		if order := op.OperandOrder; order != nil {
   519  			class += "_" + *order
   520  		}
   521  		return class, gOp, nil
   522  	}
   523  }
   524  
   525  func checkVecAsScalar(op Operation) (idx int, err error) {
   526  	idx = -1
   527  	sSize := 0
   528  	for i, o := range op.In {
   529  		if o.TreatLikeAScalarOfSize != nil {
   530  			if idx == -1 {
   531  				idx = i
   532  				sSize = *o.TreatLikeAScalarOfSize
   533  			} else {
   534  				err = fmt.Errorf("simdgen only supports one TreatLikeAScalarOfSize in the arg list: %s", op)
   535  				return
   536  			}
   537  		}
   538  	}
   539  	if idx >= 0 {
   540  		if sSize != 8 && sSize != 16 && sSize != 32 && sSize != 64 {
   541  			err = fmt.Errorf("simdgen does not recognize this uint size: %d, %s", sSize, op)
   542  			return
   543  		}
   544  	}
   545  	return
   546  }
   547  
   548  func rewriteVecAsScalarRegInfo(op Operation, regInfo string) (string, error) {
   549  	idx, err := checkVecAsScalar(op)
   550  	if err != nil {
   551  		return "", err
   552  	}
   553  	if idx != -1 {
   554  		if regInfo == "v21" {
   555  			regInfo = "vfpv"
   556  		} else if regInfo == "v2kv" {
   557  			regInfo = "vfpkv"
   558  		} else if regInfo == "v31" {
   559  			regInfo = "v2fpv"
   560  		} else if regInfo == "v3kv" {
   561  			regInfo = "v2fpkv"
   562  		} else {
   563  			return "", fmt.Errorf("simdgen does not recognize uses of treatLikeAScalarOfSize with op regShape %s in op: %s", regInfo, op)
   564  		}
   565  	}
   566  	return regInfo, nil
   567  }
   568  
   569  func rewriteLastVregToMem(op Operation) Operation {
   570  	newIn := make([]Operand, len(op.In))
   571  	lastVregIdx := -1
   572  	for i := range len(op.In) {
   573  		newIn[i] = op.In[i]
   574  		if op.In[i].Class == "vreg" {
   575  			lastVregIdx = i
   576  		}
   577  	}
   578  	// vbcst operations put their mem op always as the last vreg.
   579  	if lastVregIdx == -1 {
   580  		panic("simdgen cannot find one vreg in the mem op vreg original")
   581  	}
   582  	newIn[lastVregIdx].Class = "memory"
   583  	op.In = newIn
   584  
   585  	return op
   586  }
   587  
   588  // dedup is deduping operations in the full structure level.
   589  func dedup(ops []Operation) (deduped []Operation) {
   590  	for _, op := range ops {
   591  		seen := false
   592  		for _, dop := range deduped {
   593  			if reflect.DeepEqual(op, dop) {
   594  				seen = true
   595  				break
   596  			}
   597  		}
   598  		if !seen {
   599  			deduped = append(deduped, op)
   600  		}
   601  	}
   602  	return
   603  }
   604  
   605  func (op Operation) GenericName() string {
   606  	if op.OperandOrder != nil {
   607  		switch *op.OperandOrder {
   608  		case "21Type1", "231Type1":
   609  			// Permute uses operand[1] for method receiver.
   610  			return op.Go + *op.In[1].Go
   611  		}
   612  	}
   613  	if op.In[0].Class == "immediate" {
   614  		return op.Go + *op.In[1].Go
   615  	}
   616  	return op.Go + *op.In[0].Go
   617  }
   618  
   619  // dedupGodef is deduping operations in [Op.Go]+[*Op.In[0].Go] level.
   620  // By deduping, it means picking the least advanced architecture that satisfy the requirement:
   621  // AVX512 will be least preferred.
   622  // If FlagNoDedup is set, it will report the duplicates to the console.
   623  func dedupGodef(ops []Operation) ([]Operation, error) {
   624  	seen := map[string][]Operation{}
   625  	for _, op := range ops {
   626  		_, _, _, _, gOp := op.shape()
   627  
   628  		gN := gOp.GenericName()
   629  		seen[gN] = append(seen[gN], op)
   630  	}
   631  	if *FlagReportDup {
   632  		for gName, dup := range seen {
   633  			if len(dup) > 1 {
   634  				log.Printf("Duplicate for %s:\n", gName)
   635  				for _, op := range dup {
   636  					log.Printf("%s\n", op)
   637  				}
   638  			}
   639  		}
   640  		return ops, nil
   641  	}
   642  	isAVX512 := func(op Operation) bool {
   643  		return strings.Contains(op.CPUFeature, "AVX512")
   644  	}
   645  	deduped := []Operation{}
   646  	for _, dup := range seen {
   647  		if len(dup) > 1 {
   648  			slices.SortFunc(dup, func(i, j Operation) int {
   649  				// Put non-AVX512 candidates at the beginning
   650  				if !isAVX512(i) && isAVX512(j) {
   651  					return -1
   652  				}
   653  				if isAVX512(i) && !isAVX512(j) {
   654  					return 1
   655  				}
   656  				if i.CPUFeature != j.CPUFeature {
   657  					return strings.Compare(i.CPUFeature, j.CPUFeature)
   658  				}
   659  				// Weirdly Intel sometimes has duplicated definitions for the same instruction,
   660  				// this confuses the XED mem-op merge logic: [MemFeature] will only be attached to an instruction
   661  				// for only once, which means that for essentially duplicated instructions only one will have the
   662  				// proper [MemFeature] set. We have to make this sort deterministic for [MemFeature].
   663  				if i.MemFeatures != nil && j.MemFeatures == nil {
   664  					return -1
   665  				}
   666  				if i.MemFeatures == nil && j.MemFeatures != nil {
   667  					return 1
   668  				}
   669  				if i.Commutative != j.Commutative {
   670  					if j.Commutative {
   671  						return -1
   672  					}
   673  					return 1
   674  				}
   675  				// Their order does not matter anymore, at least for now.
   676  				return 0
   677  			})
   678  		}
   679  		deduped = append(deduped, dup[0])
   680  	}
   681  	slices.SortFunc(deduped, compareOperations)
   682  	return deduped, nil
   683  }
   684  
   685  // Copy op.ConstImm to op.In[0].Const
   686  // This is a hack to reduce the size of defs we need for const imm operations.
   687  func copyConstImm(ops []Operation) error {
   688  	for _, op := range ops {
   689  		if op.ConstImm == nil {
   690  			continue
   691  		}
   692  		_, _, _, immType, _ := op.shape()
   693  
   694  		if immType == ConstImm || immType == ConstVarImm {
   695  			op.In[0].Const = op.ConstImm
   696  		}
   697  		// Otherwise, just not port it - e.g. {VPCMP[BWDQ] imm=0} and {VPCMPEQ[BWDQ]} are
   698  		// the same operations "Equal", [dedupgodef] should be able to distinguish them.
   699  	}
   700  	return nil
   701  }
   702  
   703  func capitalizeFirst(s string) string {
   704  	if s == "" {
   705  		return ""
   706  	}
   707  	// Convert the string to a slice of runes to handle multi-byte characters correctly.
   708  	r := []rune(s)
   709  	r[0] = unicode.ToUpper(r[0])
   710  	return string(r)
   711  }
   712  
// overwrite corrects some errors due to:
//   - The XED data is wrong
//   - Go's SIMD API requirement, for example AVX2 compares should also produce masks.
//     This rewrite has strict constraints, please see the error message.
//     These constraints are also exploited in [writeSIMDRules], [writeSIMDMachineOps]
//     and [writeSIMDSSA], please be careful when updating these constraints.
//
// It applies the Overwrite* fields of every input and output operand of every
// operation in ops. The operands are updated in place, through their pointer
// fields and through the operand slices, so the changes are visible in ops.
//
// A class overwrite on an input, or on an operation that also has a mask
// input, is reported as an error. Malformed overwrite settings on a single
// operand panic instead.
func overwrite(ops []Operation) error {
	// hasClassOverwrite is set by the closure below whenever it applies an
	// [OverwriteClass]; it is reset for each operation.
	hasClassOverwrite := false
	// overwrite (shadowing the enclosing function) applies the overwrite
	// fields of operand op[idx]. o is only used in a panic message. The
	// closure always returns nil.
	overwrite := func(op []Operand, idx int, o Operation) error {
		if op[idx].OverwriteElementBits != nil {
			if op[idx].ElemBits == nil {
				panic(fmt.Errorf("ElemBits is nil at operand %d of %v", idx, o))
			}
			// Changing the element width also changes the lane count and
			// therefore the Go type name.
			*op[idx].ElemBits = *op[idx].OverwriteElementBits
			*op[idx].Lanes = *op[idx].Bits / *op[idx].ElemBits
			*op[idx].Go = fmt.Sprintf("%s%dx%d", capitalizeFirst(*op[idx].Base), *op[idx].ElemBits, *op[idx].Lanes)
		}
		if op[idx].OverwriteClass != nil {
			// Only one class overwrite is supported: vreg -> mask, with base int.
			if op[idx].OverwriteBase == nil {
				panic(fmt.Errorf("simdgen: [OverwriteClass] must be set together with [OverwriteBase]: %s", op[idx]))
			}
			oBase := *op[idx].OverwriteBase
			oClass := *op[idx].OverwriteClass
			if oClass != "mask" {
				panic(fmt.Errorf("simdgen: [Class] overwrite only supports overwritting to mask: %s", op[idx]))
			}
			if oBase != "int" {
				panic(fmt.Errorf("simdgen: [Class] overwrite must set [OverwriteBase] to int: %s", op[idx]))
			}
			if op[idx].Class != "vreg" {
				panic(fmt.Errorf("simdgen: [Class] overwrite must be overwriting [Class] from vreg: %s", op[idx]))
			}
			hasClassOverwrite = true
			*op[idx].Base = oBase
			op[idx].Class = oClass
			*op[idx].Go = fmt.Sprintf("Mask%dx%d", *op[idx].ElemBits, *op[idx].Lanes)
		} else if op[idx].OverwriteBase != nil {
			// Base-only overwrite: substitute the base name inside the Go type.
			oBase := *op[idx].OverwriteBase
			*op[idx].Go = strings.ReplaceAll(*op[idx].Go, capitalizeFirst(*op[idx].Base), capitalizeFirst(oBase))
			if op[idx].Class == "greg" {
				// For greg operands the uncapitalized base is replaced as well.
				*op[idx].Go = strings.ReplaceAll(*op[idx].Go, *op[idx].Base, oBase)
			}
			*op[idx].Base = oBase
		}
		return nil
	}
	for i, o := range ops {
		hasClassOverwrite = false
		for j := range ops[i].In {
			if err := overwrite(ops[i].In, j, o); err != nil {
				return err
			}
			// Class overwrites are rejected on inputs; the flag can only have
			// been set by the input just processed.
			if hasClassOverwrite {
				return fmt.Errorf("simdgen does not support [OverwriteClass] in inputs: %s", ops[i])
			}
		}
		for j := range ops[i].Out {
			if err := overwrite(ops[i].Out, j, o); err != nil {
				return err
			}
		}
		// At this point the flag can only come from an output overwrite.
		if hasClassOverwrite {
			for _, in := range ops[i].In {
				if in.Class == "mask" {
					return fmt.Errorf("simdgen only supports [OverwriteClass] for operations without mask inputs")
				}
			}
		}
	}
	return nil
}
   784  
   785  // reportXEDInconsistency reports potential XED inconsistencies.
   786  // We can add more fields to [Operation] to enable more checks and implement it here.
   787  // Supported checks:
   788  // [NameAndSizeCheck]: NAME[BWDQ] should set the elemBits accordingly.
   789  // This check is useful to find inconsistencies, then we can add overwrite fields to
   790  // those defs to correct them manually.
   791  func reportXEDInconsistency(ops []Operation) error {
   792  	for _, o := range ops {
   793  		if o.NameAndSizeCheck != nil {
   794  			suffixSizeMap := map[byte]int{'B': 8, 'W': 16, 'D': 32, 'Q': 64}
   795  			checkOperand := func(opr Operand) error {
   796  				if opr.ElemBits == nil {
   797  					return fmt.Errorf("simdgen expects elemBits to be set when performing NameAndSizeCheck")
   798  				}
   799  				if v, ok := suffixSizeMap[o.Asm[len(o.Asm)-1]]; !ok {
   800  					return fmt.Errorf("simdgen expects asm to end with [BWDQ] when performing NameAndSizeCheck")
   801  				} else {
   802  					if v != *opr.ElemBits {
   803  						return fmt.Errorf("simdgen finds NameAndSizeCheck inconsistency in def: %s", o)
   804  					}
   805  				}
   806  				return nil
   807  			}
   808  			for _, in := range o.In {
   809  				if in.Class != "vreg" && in.Class != "mask" {
   810  					continue
   811  				}
   812  				if in.TreatLikeAScalarOfSize != nil {
   813  					// This is an irregular operand, don't check it.
   814  					continue
   815  				}
   816  				if err := checkOperand(in); err != nil {
   817  					return err
   818  				}
   819  			}
   820  			for _, out := range o.Out {
   821  				if err := checkOperand(out); err != nil {
   822  					return err
   823  				}
   824  			}
   825  		}
   826  	}
   827  	return nil
   828  }
   829  
   830  func (o *Operation) hasMaskedMerging(maskType maskShape, outType outShape) bool {
   831  	// BLEND and VMOVDQU are not user-facing ops so we should filter them out.
   832  	return o.OperandOrder == nil && o.SpecialLower == nil && maskType == OneMask && outType == OneVregOut &&
   833  		len(o.InVariant) == 1 && !strings.Contains(o.Asm, "BLEND") && !strings.Contains(o.Asm, "VMOVDQU")
   834  }
   835  
   836  func getVbcstData(s string) (feat1Match, feat2Match string) {
   837  	_, err := fmt.Sscanf(s, "feat1=%[^;];feat2=%s", &feat1Match, &feat2Match)
   838  	if err != nil {
   839  		panic(err)
   840  	}
   841  	return
   842  }
   843  
   844  func (o Operation) String() string {
   845  	return pprints(o)
   846  }
   847  
   848  func (op Operand) String() string {
   849  	return pprints(op)
   850  }
   851  

View as plain text