Source file src/cmd/compile/internal/ssa/known_bits.go

     1  // Copyright 2026 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssa
     6  
     7  import "slices"
     8  
// fold computes the known bits of v, memoizing the result in kb.entries.
// The returned pair obeys the knownBitsEntry invariants: every bit set in
// known has a proven value, stored in value; bits clear in known are zero
// in value; both are sign-extended to 64 bits; booleans are 0 or 1.
// known == -1 means v is fully known (a compile-time constant).
func (kb *knownBitsState) fold(v *Value) (value, known int64) {
	if kb.seenValues.Test(uint32(v.ID)) {
		// Either memoized, or currently being folded further up the stack;
		// in the latter case the entry is still the zero value ("nothing
		// known"), which safely breaks cycles.
		return kb.entries[v.ID].value, kb.entries[v.ID].known
	}
	defer func() {
		// maintain the invariants:
		// 3. booleans are stored as 1 byte values that are either 0 or 1.
		if v.Type.IsBoolean() {
			value &= 1
			known |= ^1 // every bit above bit 0 is known (zero)
		}

		// 2. all values are sign-extended to int64 (inspired by RISC-V's xlen=64)
		switch v.Type.Size() {
		case 1:
			value = int64(int8(value))
			known = int64(int8(known))
		case 2:
			value = int64(int16(value))
			known = int64(int16(known))
		case 4:
			value = int64(int32(value))
			known = int64(int32(known))
		case 8:
		default:
			panic("unreachable; unknown integer size")
		}

		// 1. unknown bits are always set to 0 inside value
		value &= known

		if v.Block.Func.pass.debug > 1 {
			v.Block.Func.Warnl(v.Pos, "known bits state %v: k:%d v:%d", v, known, value)
		}
		kb.entries[v.ID].known = known
		kb.entries[v.ID].value = value
	}()
	kb.seenValues.Set(uint32(v.ID)) // set seen early to give up on loops

	switch v.Op {
	// TODO: rotates, ...
	case OpConst64, OpConst32, OpConst16, OpConst8, OpConstBool:
		// Constants are fully known.
		return v.AuxInt, -1
	case OpAnd64, OpAnd32, OpAnd16, OpAnd8, OpAndB:
		x, xk := kb.fold(v.Args[0])
		y, yk := kb.fold(v.Args[1])
		// A result bit is known when it is known-1 in both inputs (unknown
		// bits are zero in value, so x&y only keeps proven ones), or when
		// it is known-0 in either input.
		onesInBoth := x & y
		zerosInX := ^x & xk
		zerosInY := ^y & yk
		return x & y, onesInBoth | zerosInX | zerosInY
	case OpOr64, OpOr32, OpOr16, OpOr8, OpOrB:
		x, xk := kb.fold(v.Args[0])
		y, yk := kb.fold(v.Args[1])
		// A result bit is known when it is known-0 in both inputs, or when
		// it is set in either value (set bits are always proven ones, by
		// invariant 1).
		zerosInBoth := ^x & ^y & (xk & yk)
		onesInX := x
		onesInY := y
		return x | y, onesInX | onesInY | zerosInBoth
	case OpXor64, OpXor32, OpXor16, OpXor8:
		x, xk := kb.fold(v.Args[0])
		y, yk := kb.fold(v.Args[1])
		// An xor bit is only known where both inputs are known.
		return x ^ y, xk & yk
	case OpCom64, OpCom32, OpCom16, OpCom8, OpNot:
		x, xk := kb.fold(v.Args[0])
		// Complement flips bits but preserves knowledge; the deferred
		// cleanup re-zeroes any unknown bits that ^x turned into ones.
		return ^x, xk
	case OpPhi:
		set := false
		for i, arg := range v.Args {
			if !kb.isLiveInEdge(v.Block, uint(i)) {
				// Dead incoming edges cannot contribute a value.
				continue
			}
			a, k := kb.fold(arg)
			if !set {
				value, known = a, k
				set = true
			} else {
				// Intersect with this argument: drop bits that disagree
				// with the accumulated value or are unknown in a.
				known &^= value ^ a
				known &= k
			}
			if known == 0 {
				break // nothing left to learn
			}
		}
		return value, known
	case OpCopy, OpCvtBoolToUint8,
		OpSignExt8to16, OpSignExt8to32, OpSignExt8to64, OpSignExt16to32, OpSignExt16to64, OpSignExt32to64,
		// The defer block handles maintaining the sign-extension invariant using v.Type.Size()
		// thus we can just pass Truncs as-is.
		OpTrunc64to32, OpTrunc64to16, OpTrunc64to8, OpTrunc32to16, OpTrunc32to8, OpTrunc16to8:
		return kb.fold(v.Args[0])
	case OpEq64, OpEq32, OpEq16, OpEq8, OpEqB:
		x, xk := kb.fold(v.Args[0])
		y, yk := kb.fold(v.Args[1])
		differentBits := x ^ y
		if differentBits&xk&yk != 0 {
			// Some bit is known on both sides and differs: definitely unequal.
			return 0, -1
		}
		if xk == -1 && yk == -1 {
			// Both fully known: fold the comparison outright.
			return boolToAuxInt(x == y), -1
		}
		// Result bit 0 is unknown; the upper bits of a boolean are known zeros.
		return 0, -1 << 1
	case OpNeq64, OpNeq32, OpNeq16, OpNeq8, OpNeqB:
		x, xk := kb.fold(v.Args[0])
		y, yk := kb.fold(v.Args[1])
		differentBits := x ^ y
		if differentBits&xk&yk != 0 {
			// Some bit is known on both sides and differs: definitely unequal.
			return 1, -1
		}
		if xk == -1 && yk == -1 {
			// Both fully known: fold the comparison outright.
			return boolToAuxInt(x != y), -1
		}
		// Result bit 0 is unknown; the upper bits of a boolean are known zeros.
		return 0, -1 << 1
	case OpZeroExt8to16, OpZeroExt8to32, OpZeroExt8to64, OpZeroExt16to32, OpZeroExt16to64, OpZeroExt32to64:
		x, k := kb.fold(v.Args[0])
		srcSize := v.Args[0].Type.Size() * 8
		mask := int64(1<<srcSize - 1)
		// The low srcSize bits keep their knowledge; every higher bit
		// becomes a known zero.
		return x & mask, k | ^mask
	case OpLsh8x8, OpLsh16x8, OpLsh32x8, OpLsh64x8,
		OpLsh8x16, OpLsh16x16, OpLsh32x16, OpLsh64x16,
		OpLsh8x32, OpLsh16x32, OpLsh32x32, OpLsh64x32,
		OpLsh8x64, OpLsh16x64, OpLsh32x64, OpLsh64x64:
		// Left shift: shifted-in low bits are known zeros. (Go defines
		// shifts by >= the operand width, so shift == 64 yields all-known
		// zero here.)
		return kb.computeKnownBitsForShift(v, func(x, xk, xSize, shift int64) (value, known int64) {
			return x << shift, xk<<shift | (1<<shift - 1)
		})
	case OpRsh8Ux8, OpRsh16Ux8, OpRsh32Ux8, OpRsh64Ux8,
		OpRsh8Ux16, OpRsh16Ux16, OpRsh32Ux16, OpRsh64Ux16,
		OpRsh8Ux32, OpRsh16Ux32, OpRsh32Ux32, OpRsh64Ux32,
		OpRsh8Ux64, OpRsh16Ux64, OpRsh32Ux64, OpRsh64Ux64:
		// Logical right shift: first normalize x to its unsigned xSize-bit
		// form (bits above xSize become known zeros), then shifted-in high
		// bits are known zeros as well.
		return kb.computeKnownBitsForShift(v, func(x, xk, xSize, shift int64) (value, known int64) {
			x &= (1<<xSize - 1)
			xk |= -1 << xSize
			return int64(uint64(x) >> shift), int64(uint64(xk)>>shift | (^uint64(0) << (64 - shift)))
		})
	case OpRsh8x8, OpRsh16x8, OpRsh32x8, OpRsh64x8,
		OpRsh8x16, OpRsh16x16, OpRsh32x16, OpRsh64x16,
		OpRsh8x32, OpRsh16x32, OpRsh32x32, OpRsh64x32,
		OpRsh8x64, OpRsh16x64, OpRsh32x64, OpRsh64x64:
		// Arithmetic right shift: knowledge (including of the sign bit)
		// propagates downward with the shift.
		return kb.computeKnownBitsForShift(v, func(x, xk, xSize, shift int64) (value, known int64) {
			return x >> shift, xk >> shift
		})
	default:
		// Unhandled op: nothing is known.
		return 0, 0
	}
}
   152  
   153  // knownBits does constant folding across bitfields
   154  func knownBits(f *Func) {
   155  	kb := &knownBitsState{
   156  		entries:         f.Cache.allocKnownBitsEntriesSlice(f.NumValues()),
   157  		seenValues:      f.Cache.allocBitset(f.NumValues()),
   158  		reachableBlocks: f.Cache.allocBitset(f.NumBlocks()),
   159  	}
   160  	defer f.Cache.freeKnownBitsEntriesSlice(kb.entries)
   161  	defer f.Cache.freeBitset(kb.seenValues)
   162  	defer f.Cache.freeBitset(kb.reachableBlocks)
   163  	clear(kb.seenValues)
   164  	clear(kb.entries)
   165  	clear(kb.reachableBlocks)
   166  
   167  	blocks := f.postorder()
   168  	for _, b := range blocks {
   169  		kb.reachableBlocks.Set(uint32(b.ID))
   170  	}
   171  
   172  	for _, b := range slices.Backward(blocks) {
   173  		for _, v := range b.Values {
   174  			if v.Uses == 0 || !(v.Type.IsInteger() || v.Type.IsBoolean()) {
   175  				continue
   176  			}
   177  			switch v.Op {
   178  			case OpConst64, OpConst32, OpConst16, OpConst8, OpConstBool:
   179  				continue
   180  			}
   181  			val, k := kb.fold(v)
   182  			if k != -1 {
   183  				continue
   184  			}
   185  			if f.pass.debug > 0 {
   186  				var pval any = val
   187  				if v.Type.IsBoolean() {
   188  					pval = val != 0
   189  				}
   190  				f.Warnl(v.Pos, "known value of %v (%v): %v", v, v.Op, pval)
   191  			}
   192  			var c *Value
   193  			switch v.Type.Size() {
   194  			case 1:
   195  				if v.Type.IsBoolean() {
   196  					c = f.ConstBool(v.Type, val != 0)
   197  					break
   198  				}
   199  				c = f.ConstInt8(v.Type, int8(val))
   200  			case 2:
   201  				c = f.ConstInt16(v.Type, int16(val))
   202  			case 4:
   203  				c = f.ConstInt32(v.Type, int32(val))
   204  			case 8:
   205  				c = f.ConstInt64(v.Type, val)
   206  			default:
   207  				panic("unreachable; unknown integer size")
   208  			}
   209  			v.copyOf(c)
   210  		}
   211  	}
   212  }
   213  
// knownBitsState holds the per-function scratch state of the knownBits
// pass. All storage is borrowed from the Func cache and returned when the
// pass finishes.
type knownBitsState struct {
	entries         []knownBitsEntry // indexed by Value.ID
	seenValues      bitset           // indexed by Value.ID (at the bit level)
	reachableBlocks bitset           // indexed by Block.ID (at the bit level)
}
   219  
// knownBitsEntry records, for a single SSA value, which bits have a proven
// value (known) and what those bits are (value).
type knownBitsEntry struct {
	// Three invariants:
	// 1. unknown bits are always set to 0 inside value
	// 2. all values are sign-extended to int64 (inspired by RISC-V's xlen=64)
	//    This means let's say you know an 8 bits value is 0b10??????,
	//    known = int64(int8(0b11000000))
	//    value = int64(int8(0b10000000))
	// 3. booleans are stored as 1 byte values that are either 0 or 1.
	known, value int64
}
   230  
   231  func (kb *knownBitsState) isLiveInEdge(b *Block, index uint) bool {
   232  	inEdge := b.Preds[index]
   233  	return kb.isLiveOutEdge(inEdge.b, uint(inEdge.i))
   234  }
   235  
   236  func (kb *knownBitsState) isLiveOutEdge(b *Block, index uint) bool {
   237  	if !kb.reachableBlocks.Test(uint32(b.ID)) {
   238  		return false
   239  	}
   240  
   241  	switch b.Kind {
   242  	case BlockFirst:
   243  		return index == 0
   244  	case BlockPlain, BlockIf, BlockDefer, BlockRet, BlockRetJmp, BlockExit, BlockJumpTable:
   245  		return true
   246  	default:
   247  		panic("unreachable; unknown block kind")
   248  	}
   249  }
   250  
// computeKnownBitsForShift computes the known bits for a shift operation.
// Considering the following piece of code x = x << uint8(i)
// The algorithm is based on two observations:
//
//  1. computing a shift of a lattice by a constant (i) is easy:
//     value, known = x<<i, xk<<i|(1<<i-1)
//     each point in the lattice is shifted by the constant, all new shifted in bits are known zeros.
//
//  2. x = uint8(x) << i is equivalent to
//
//     switch i {
//     case 0:  x0 = x << 0
//     case 1:  x1 = x << 1
//     case 2:  x2 = x << 2
//     case 3:  x3 = x << 3
//     case 4:  x4 = x << 4
//     case 5:  x5 = x << 5
//     case 6:  x6 = x << 6
//     case 7:  x7 = x << 7
//     default: xd = x << 8
//     }
//     x = phi(x0, x1, x2, x3, x4, x5, x6, x7, xd)
//
// The algorithm below then models the phi in the equivalence above using the same intersection algorithm phi uses.
// We also leverage known bits of the shift amount to remove "branches" in the switch that are proved to be impossible.
func (kb *knownBitsState) computeKnownBitsForShift(v *Value, doShiftByAConst func(x, xk, xSize, shift int64) (value, known int64)) (value, known int64) {
	xSize := v.Args[0].Type.Size() * 8 // width of the shifted operand, in bits
	x, xk := kb.fold(v.Args[0])
	y, yk := kb.fold(v.Args[1])
	// Unknown bits are zero inside y, so (as unsigned) y is a lower bound
	// on the runtime shift amount. If even that bound is >= xSize, every
	// bit of x is shifted out; model that as a shift by 64.
	if uint64(y) >= uint64(xSize) {
		return doShiftByAConst(x, xk, xSize, 64)
	}

	set := false
	if v.AuxInt == 0 && uint64(^yk) >= uint64(xSize) {
		// this implements the default case of the equivalent switch above.
		// if the shift isn't bounded and there are unknown bits above the shift size we might completely stomp all bits.

		value, known = doShiftByAConst(x, xk, xSize, 64)
		set = true
	}
	yk &= xSize - 1 // only the low bits of the shift amount pick a case below

	for i := range xSize {
		if i&yk != y {
			// the known bits of the shift amount prove it can't equal i.
			continue
		}
		a, k := doShiftByAConst(x, xk, xSize, int64(i))
		if !set {
			value, known = a, k
			set = true
		} else {
			// Phi-style intersection: keep only bits that agree and are
			// known in both lattices.
			known &^= value ^ a
			known &= k
		}
		if known == 0 {
			break // nothing is known anymore; stop early
		}
	}

	return value & known, known
}
   313  

View as plain text