Source file src/cmd/compile/internal/inline/interleaved/interleaved.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package interleaved implements the interleaved devirtualization and
     6  // inlining pass.
     7  package interleaved
     8  
     9  import (
    10  	"cmd/compile/internal/base"
    11  	"cmd/compile/internal/devirtualize"
    12  	"cmd/compile/internal/inline"
    13  	"cmd/compile/internal/inline/inlheur"
    14  	"cmd/compile/internal/ir"
    15  	"cmd/compile/internal/pgoir"
    16  	"cmd/compile/internal/typecheck"
    17  	"fmt"
    18  )
    19  
// DevirtualizeAndInlinePackage interleaves devirtualization and inlining on
// all functions within pkg.
//
// The work proceeds in phases:
//
//  1. If profile is non-nil and base.Debug.PGODevirtualize > 0, call
//     devirtualize.ProfileGuided on every function, visiting bottom-up.
//  2. If base.Flag.LowerL != 0, call inlheur.SetupScoreAdjustments.
//  3. Compute inlinability for pkg.Funcs. The profile is handed to this
//     step only when base.Debug.PGOInline != 0 (see inlProfile).
//  4. Create one inlClosureState per function in typecheck.Target.Funcs,
//     wrap its call sites in parentheses, and resolve each call site once
//     so the shared callee use counts are populated before any inlining
//     is attempted.
//  5. Visit the functions bottom-up; for every batch handed to the
//     callback, run the inline heuristics and then edit the batch's
//     functions repeatedly until a full sweep makes no replacement.
//  6. Run the heuristics' post-processing and strip the parentheses.
//
// NOTE(review): this function reads both pkg.Funcs and
// typecheck.Target.Funcs; presumably pkg is typecheck.Target — confirm
// against the caller.
func DevirtualizeAndInlinePackage(pkg *ir.Package, profile *pgoir.Profile) {
	if profile != nil && base.Debug.PGODevirtualize > 0 {
		// TODO(mdempsky): Integrate into DevirtualizeAndInlineFunc below.
		ir.VisitFuncsBottomUp(typecheck.Target.Funcs, func(list []*ir.Func, recursive bool) {
			for _, fn := range list {
				devirtualize.ProfileGuided(fn, profile)
			}
		})
		ir.CurFunc = nil
	}

	if base.Flag.LowerL != 0 {
		inlheur.SetupScoreAdjustments()
	}

	// inlProfile stays nil unless PGO inlining is enabled, so the
	// inlinability computation and the post-processing below see no
	// profile when base.Debug.PGOInline == 0.
	var inlProfile *pgoir.Profile // copy of profile for inlining
	if base.Debug.PGOInline != 0 {
		inlProfile = profile
	}

	// First compute inlinability of all functions in the package.
	inline.CanInlineFuncs(pkg.Funcs, inlProfile)

	// inlState maps each function to its per-function editing state;
	// calleeUseCounts is the single use-count map shared by all of them.
	inlState := make(map[*ir.Func]*inlClosureState)
	calleeUseCounts := make(map[*ir.Func]int)

	// Pre-process all the functions, adding parentheses around call sites and starting their "inl state".
	for _, fn := range typecheck.Target.Funcs {
		bigCaller := base.Flag.LowerL != 0 && inline.IsBigFunc(fn)
		if bigCaller && base.Flag.LowerM > 1 {
			fmt.Printf("%v: function %v considered 'big'; reducing max cost of inlinees\n", ir.Line(fn), fn)
		}

		// NOTE(review): the state is given profile, not inlProfile, so the
		// per-call-site inlining calls still receive the profile when
		// base.Debug.PGOInline == 0 — confirm this is intended.
		s := &inlClosureState{bigCaller: bigCaller, profile: profile, fn: fn, callSites: make(map[*ir.ParenExpr]bool), useCounts: calleeUseCounts}
		s.parenthesize()
		inlState[fn] = s

		// Do a first pass at counting call sites.
		for i := range s.parens {
			s.resolve(i)
		}
	}

	ir.VisitFuncsBottomUp(typecheck.Target.Funcs, func(list []*ir.Func, recursive bool) {

		anyInlineHeuristics := false

		// inline heuristics, placed here because they have static state and that's what seems to work.
		for _, fn := range list {
			if base.Flag.LowerL != 0 {
				if inlheur.Enabled() && !fn.Wrapper() {
					inlheur.ScoreCalls(fn)
					anyInlineHeuristics = true
				}
				if base.Debug.DumpInlFuncProps != "" && !fn.Wrapper() {
					inlheur.DumpFuncProps(fn, base.Debug.DumpInlFuncProps)
				}
			}
		}

		// The deferred cleanup runs when this callback returns, i.e. once
		// per batch, after the fixed-point loop below has finished.
		if anyInlineHeuristics {
			defer inlheur.ScoreCallsCleanup()
		}

		// Iterate to a fixed point over all the functions.
		// done is cleared whenever any call site in any function of the
		// batch is replaced, which forces another full sweep over list.
		done := false
		for !done {
			done = true
			for _, fn := range list {
				// NOTE(review): inlState[fn] is dereferenced without a nil
				// check here; this assumes every fn in list was present in
				// typecheck.Target.Funcs during pre-processing — confirm.
				s := inlState[fn]

				ir.WithFunc(fn, func() {
					// [l0, l1) is the window of s.parens handled by the
					// current batch; it starts as every known call site.
					l1 := len(s.parens)
					l0 := 0

					// Batch iterations so that newly discovered call sites are
					// resolved in a batch before inlining attempts.
					// Do this to avoid discovering new closure calls 1 at a time
					// which might cause first call to be seen as a single (high-budget)
					// call before the second is observed.
					for {
						for i := l0; i < l1; i++ { // can't use "range parens" here
							paren := s.parens[i]
							// NOTE(review): the local "new" shadows the
							// predeclared builtin within this if statement.
							if new := s.edit(i); new != nil {
								// Update AST and recursively mark nodes.
								paren.X = new
								ir.EditChildren(new, s.mark) // mark may append to parens
								done = false
							}
						}
						// Advance the window to the call sites appended by
						// mark during this batch; stop when there are none.
						l0, l1 = l1, len(s.parens)
						if l0 == l1 {
							break
						}
						// Resolve the whole new batch before editing any of it.
						for i := l0; i < l1; i++ {
							s.resolve(i)
						}

					}

				}) // WithFunc

			}
		}
	})

	ir.CurFunc = nil

	if base.Flag.LowerL != 0 {
		if base.Debug.DumpInlFuncProps != "" {
			inlheur.DumpFuncProps(nil, base.Debug.DumpInlFuncProps)
		}
		if inlheur.Enabled() {
			inline.PostProcessCallSites(inlProfile)
			inlheur.TearDown()
		}
	}

	// remove parentheses
	// A function with no entry in inlState yields a nil state here;
	// unparenthesize returns immediately on a nil receiver.
	for _, fn := range typecheck.Target.Funcs {
		inlState[fn].unparenthesize()
	}

}
   146  
   147  // DevirtualizeAndInlineFunc interleaves devirtualization and inlining
   148  // on a single function.
   149  func DevirtualizeAndInlineFunc(fn *ir.Func, profile *pgoir.Profile) {
   150  	ir.WithFunc(fn, func() {
   151  		if base.Flag.LowerL != 0 {
   152  			if inlheur.Enabled() && !fn.Wrapper() {
   153  				inlheur.ScoreCalls(fn)
   154  				defer inlheur.ScoreCallsCleanup()
   155  			}
   156  			if base.Debug.DumpInlFuncProps != "" && !fn.Wrapper() {
   157  				inlheur.DumpFuncProps(fn, base.Debug.DumpInlFuncProps)
   158  			}
   159  		}
   160  
   161  		bigCaller := base.Flag.LowerL != 0 && inline.IsBigFunc(fn)
   162  		if bigCaller && base.Flag.LowerM > 1 {
   163  			fmt.Printf("%v: function %v considered 'big'; reducing max cost of inlinees\n", ir.Line(fn), fn)
   164  		}
   165  
   166  		s := &inlClosureState{bigCaller: bigCaller, profile: profile, fn: fn, callSites: make(map[*ir.ParenExpr]bool), useCounts: make(map[*ir.Func]int)}
   167  		s.parenthesize()
   168  		s.fixpoint()
   169  		s.unparenthesize()
   170  	})
   171  }
   172  
// callSite pairs a function with an integer index named whichParen.
//
// NOTE(review): nothing in the visible part of this file references
// callSite; presumably whichParen indexes inlClosureState.parens — confirm
// whether the type is still needed.
type callSite struct {
	fn         *ir.Func
	whichParen int
}
   177  
// inlClosureState carries the per-function state of the interleaved
// devirtualization and inlining pass.
//
//   - fn is the function whose body is being edited.
//   - profile is passed through to inline.InlineCallTarget and
//     inline.TryInlineCall; it may be nil.
//   - callSites records the ParenExprs created by mark, so that a
//     parenthesis it meets again is recognized as already visited.
//   - resolved is indexed like parens and is grown lazily by resolve; a
//     missing or nil entry means "no callee recorded".
//   - useCounts counts resolved call sites per callee. The package-level
//     pass shares one map across all states, while
//     DevirtualizeAndInlineFunc gives its state a private map.
//   - parens lists the parenthesized call sites in the order mark
//     appended them; it only ever grows.
//   - bigCaller is the result of inline.IsBigFunc(fn) when inlining is
//     enabled, and is passed to inline.TryInlineCall.
type inlClosureState struct {
	fn        *ir.Func
	profile   *pgoir.Profile
	callSites map[*ir.ParenExpr]bool // callSites[p] == "p appears in parens" (do not append again)
	resolved  []*ir.Func             // for each call in parens, the resolved target of the call
	useCounts map[*ir.Func]int       // shared among all InlClosureStates
	parens    []*ir.ParenExpr
	bigCaller bool
}
   187  
   188  // resolve attempts to resolve a call to a potentially inlineable callee
   189  // and updates use counts on the callees.  Returns the call site count
   190  // for that callee.
   191  func (s *inlClosureState) resolve(i int) (*ir.Func, int) {
   192  	p := s.parens[i]
   193  	if i < len(s.resolved) {
   194  		if callee := s.resolved[i]; callee != nil {
   195  			return callee, s.useCounts[callee]
   196  		}
   197  	}
   198  	n := p.X
   199  	call, ok := n.(*ir.CallExpr)
   200  	if !ok { // previously inlined
   201  		return nil, -1
   202  	}
   203  	devirtualize.StaticCall(call)
   204  	if callee := inline.InlineCallTarget(s.fn, call, s.profile); callee != nil {
   205  		for len(s.resolved) <= i {
   206  			s.resolved = append(s.resolved, nil)
   207  		}
   208  		s.resolved[i] = callee
   209  		c := s.useCounts[callee] + 1
   210  		s.useCounts[callee] = c
   211  		return callee, c
   212  	}
   213  	return nil, 0
   214  }
   215  
   216  func (s *inlClosureState) edit(i int) ir.Node {
   217  	n := s.parens[i].X
   218  	call, ok := n.(*ir.CallExpr)
   219  	if !ok {
   220  		return nil
   221  	}
   222  	// This is redundant with earlier calls to
   223  	// resolve, but because things can change it
   224  	// must be re-checked.
   225  	callee, count := s.resolve(i)
   226  	if count <= 0 {
   227  		return nil
   228  	}
   229  	if inlCall := inline.TryInlineCall(s.fn, call, s.bigCaller, s.profile, count == 1 && callee.ClosureParent != nil); inlCall != nil {
   230  		return inlCall
   231  	}
   232  	return nil
   233  }
   234  
// mark inserts parentheses, and is called repeatedly.
// These inserted parentheses mark the call sites where
// inlining will be attempted.
//
// It is written to be used as the edit callback of ir.EditChildren: it
// takes a node and returns the node that should stand in its place. Every
// node for which match reports true is wrapped in (or keeps) a ParenExpr,
// which is appended to s.parens. Children are processed before the node
// itself is appended.
func (s *inlClosureState) mark(n ir.Node) ir.Node {
	// Consider the expression "f(g())". We want to be able to replace
	// "g()" in-place with its inlined representation. But if we first
	// replace "f(...)" with its inlined representation, then "g()" will
	// instead appear somewhere within this new AST.
	//
	// To mitigate this, each matched node n is wrapped in a ParenExpr,
	// so we can reliably replace n in-place by assigning ParenExpr.X.
	// It's safe to use ParenExpr here, because typecheck already
	// removed them all.

	// p is non-nil only when n itself is a ParenExpr.
	p, _ := n.(*ir.ParenExpr)
	if p != nil && s.callSites[p] {
		return n // already visited n.X before wrapping
	}

	if isTestingBLoop(n) {
		// No inlining nor devirtualization performed on b.Loop body
		if base.Flag.LowerM > 0 {
			fmt.Printf("%v: skip inlining within testing.B.loop for %v\n", ir.Line(n), n)
		}
		// We still want to explore inlining opportunities in other parts of ForStmt.
		// The type assertion cannot fail here: isTestingBLoop only returns
		// true after its own assertion to *ir.ForStmt has succeeded.
		nFor, _ := n.(*ir.ForStmt)
		nForInit := nFor.Init()
		for i, x := range nForInit {
			if x != nil {
				nForInit[i] = s.mark(x)
			}
		}
		if nFor.Cond != nil {
			nFor.Cond = s.mark(nFor.Cond)
		}
		if nFor.Post != nil {
			nFor.Post = s.mark(nFor.Post)
		}
		// The loop body is deliberately not visited, so no call site
		// inside it is ever added to s.parens.
		return n
	}

	if p != nil {
		n = p.X // in this case p was copied in from a (marked) inlined function, this is a new unvisited node.
	}

	// match only inspects n's dynamic type, so its result is not affected
	// by the child edits performed next.
	ok := match(n)

	ir.EditChildren(n, s.mark)

	if ok {
		if p == nil {
			// Fresh wrapper: give it n's position, type and typecheck
			// state, and remember it as one of our own parentheses.
			p = ir.NewParenExpr(n.Pos(), n)
			p.SetType(n.Type())
			p.SetTypecheck(n.Typecheck())
			s.callSites[p] = true
		}

		// A copied-in parenthesis (p != nil above) is reused as the
		// wrapper and appended to parens, but it is not added to
		// s.callSites here.
		s.parens = append(s.parens, p)
		n = p
	} else if p != nil {
		n = p // didn't change anything, restore n
	}
	return n
}
   299  
   300  // parenthesize applies s.mark to all the nodes within
   301  // s.fn to mark calls and simplify rewriting them in place.
   302  func (s *inlClosureState) parenthesize() {
   303  	ir.EditChildren(s.fn, s.mark)
   304  }
   305  
   306  func (s *inlClosureState) unparenthesize() {
   307  	if s == nil {
   308  		return
   309  	}
   310  	if len(s.parens) == 0 {
   311  		return // short circuit
   312  	}
   313  
   314  	var unparen func(ir.Node) ir.Node
   315  	unparen = func(n ir.Node) ir.Node {
   316  		if paren, ok := n.(*ir.ParenExpr); ok {
   317  			n = paren.X
   318  		}
   319  		ir.EditChildren(n, unparen)
   320  		// special case for tail calls: if the tail call was inlined, transform
   321  		// the tail call to a return stmt if the inlined function was not void,
   322  		// otherwise replace it with the inlined expression followed by a return.
   323  		if tail, ok := n.(*ir.TailCallStmt); ok {
   324  			if inl, done := tail.Call.(*ir.InlinedCallExpr); done {
   325  				if len(inl.ReturnVars) != 0 {
   326  					ret := ir.NewReturnStmt(tail.Pos(), []ir.Node{inl})
   327  					if len(inl.ReturnVars) > 1 {
   328  						typecheck.RewriteMultiValueCall(ret, inl)
   329  					}
   330  					n = ret
   331  				} else {
   332  					ret := ir.NewReturnStmt(tail.Pos(), nil)
   333  					n = ir.NewBlockStmt(tail.Pos(), []ir.Node{inl, ret})
   334  				}
   335  			}
   336  		}
   337  		return n
   338  	}
   339  	ir.EditChildren(s.fn, unparen)
   340  }
   341  
   342  // fixpoint repeatedly edits a function until it stabilizes, returning
   343  // whether anything changed in any of the fixpoint iterations.
   344  //
   345  // It applies s.edit(n) to each node n within the parentheses in s.parens.
   346  // If s.edit(n) returns nil, no change is made. Otherwise, the result
   347  // replaces n in fn's body, and fixpoint iterates at least once more.
   348  //
   349  // After an iteration where all edit calls return nil, fixpoint
   350  // returns.
   351  func (s *inlClosureState) fixpoint() bool {
   352  	changed := false
   353  	ir.WithFunc(s.fn, func() {
   354  		done := false
   355  		for !done {
   356  			done = true
   357  			for i := 0; i < len(s.parens); i++ { // can't use "range parens" here
   358  				paren := s.parens[i]
   359  				if new := s.edit(i); new != nil {
   360  					// Update AST and recursively mark nodes.
   361  					paren.X = new
   362  					ir.EditChildren(new, s.mark) // mark may append to parens
   363  					done = false
   364  					changed = true
   365  				}
   366  			}
   367  		}
   368  	})
   369  	return changed
   370  }
   371  
   372  func match(n ir.Node) bool {
   373  	switch n.(type) {
   374  	case *ir.CallExpr:
   375  		return true
   376  	}
   377  	return false
   378  }
   379  
   380  // isTestingBLoop returns true if it matches the node as a
   381  // testing.(*B).Loop. See issue #61515.
   382  func isTestingBLoop(t ir.Node) bool {
   383  	if t.Op() != ir.OFOR {
   384  		return false
   385  	}
   386  	nFor, ok := t.(*ir.ForStmt)
   387  	if !ok || nFor.Cond == nil || nFor.Cond.Op() != ir.OCALLFUNC {
   388  		return false
   389  	}
   390  	n, ok := nFor.Cond.(*ir.CallExpr)
   391  	if !ok || n.Fun == nil || n.Fun.Op() != ir.OMETHEXPR {
   392  		return false
   393  	}
   394  	name := ir.MethodExprName(n.Fun)
   395  	if name == nil {
   396  		return false
   397  	}
   398  	if fSym := name.Sym(); fSym != nil && name.Class == ir.PFUNC && fSym.Pkg != nil &&
   399  		fSym.Name == "(*B).Loop" && fSym.Pkg.Path == "testing" {
   400  		// Attempting to match a function call to testing.(*B).Loop
   401  		return true
   402  	}
   403  	return false
   404  }
   405  

View as plain text