1 //! Function inlining pass.
2 //!
3 //! Replaces call sites with the callee's body, enabling downstream
4 //! optimizations (const prop, DSE, LICM) to fire across the former
5 //! call boundary.
6 //!
7 //! Algorithm:
8 //! 1. Build call graph, process in bottom-up order (callees first)
9 //! 2. For each call site passing the cost model, clone the callee's
10 //! blocks into the caller with fresh ValueIds/BlockIds
11 //! 3. Map callee params → caller args
12 //! 4. Replace Return(val) → Branch(post_call_block, [val])
13 //! 5. Split the call-containing block: pre-call instructions +
14 //! branch to cloned entry, post-call block receives return value
15
16 use super::callgraph::CallGraph;
17 use super::loop_utils::{remap_inst_kind, remap_terminator};
18 use super::pass::Pass;
19 use super::pipeline::OptLevel;
20 use crate::ir::inst::*;
21 use crate::ir::types::IrType;
22 use crate::ir::walk::{prune_unreachable, substitute_uses};
23 use std::collections::HashMap;
24
/// Maximum callee instruction count for inlining at -O1 (conservative:
/// only very small callees are worth the code growth).
const INLINE_THRESHOLD_O1: usize = 20;
/// Maximum callee instruction count for inlining at -O2/-Os.
const INLINE_THRESHOLD_O2: usize = 100;
28
/// Function-inlining pass. See the module docs for the algorithm.
pub struct Inline {
    // Maximum callee cost (instruction count, per `CallGraph::inline_cost`)
    // that a call site may have to be inlined; 0 disables the pass entirely.
    threshold: usize,
}
32
33 impl Inline {
34 pub fn for_level(level: OptLevel) -> Self {
35 let threshold = match level {
36 OptLevel::O1 => INLINE_THRESHOLD_O1,
37 OptLevel::O2 | OptLevel::Os => INLINE_THRESHOLD_O2,
38 OptLevel::O3 | OptLevel::Ofast => 200,
39 OptLevel::O0 => 0,
40 };
41 Self { threshold }
42 }
43 }
44
45 impl Pass for Inline {
46 fn name(&self) -> &'static str {
47 "inline"
48 }
49
50 fn run(&self, module: &mut Module) -> bool {
51 if self.threshold == 0 {
52 return false;
53 }
54
55 let cg = CallGraph::build(module);
56 let order = cg.bottom_up_order();
57 let mut changed = false;
58
59 for &caller_idx in &order {
60 if inline_calls_in_function(module, caller_idx, &cg, self.threshold) {
61 changed = true;
62 }
63 }
64
65 changed
66 }
67 }
68
/// Inlines at most one eligible call site in `module.functions[caller_idx]`.
///
/// Returns `true` when a call was inlined (the caller was mutated and its
/// unreachable blocks pruned), `false` when no call site passes the
/// eligibility checks. Remaining sites are picked up when the pass manager's
/// fixpoint loop re-invokes the pass.
fn inline_calls_in_function(
    module: &mut Module,
    caller_idx: u32,
    cg: &CallGraph,
    threshold: usize,
) -> bool {
    // Find call sites eligible for inlining. Each tuple is (block containing
    // the call, index of the call inst within that block, callee function
    // index, call arguments).
    let call_sites: Vec<(BlockId, usize, u32, Vec<ValueId>)> = {
        let caller = &module.functions[caller_idx as usize];
        let mut sites = Vec::new();
        for block in &caller.blocks {
            for (inst_idx, inst) in block.insts.iter().enumerate() {
                // Only direct calls to functions in this module can be
                // inlined; external references are skipped by the pattern.
                if let InstKind::Call(FuncRef::Internal(callee_idx), args) = &inst.kind {
                    let ci = *callee_idx;
                    // Don't inline recursive functions.
                    if cg.is_recursive(ci) {
                        continue;
                    }
                    // Don't self-inline.
                    if ci == caller_idx {
                        continue;
                    }
                    // Cost check.
                    if cg.inline_cost(ci) > threshold {
                        continue;
                    }
                    // Argument/parameter type agreement. When a
                    // Fortran OPTIONAL parameter is absent at a call
                    // site, the caller passes `const_i64 0` as a null
                    // placeholder — that's sound at the call boundary
                    // because the callee wraps every read in
                    // PRESENT(), but inlining would map a `load %b`
                    // from the callee onto a `load <i64 const>` in the
                    // caller, which fails the IR verifier. Refuse to
                    // inline any call whose arg type doesn't match the
                    // callee param type exactly. The same check also
                    // guards any future call-boundary coercion shims.
                    let callee = &module.functions[ci as usize];
                    if callee.params.len() != args.len() {
                        continue;
                    }
                    let mismatched =
                        callee.params.iter().zip(args.iter()).any(|(p, a)| {
                            match caller.value_type(*a) {
                                // An argument with no recorded type is
                                // treated as a mismatch (conservative).
                                Some(ty) => ty != p.ty,
                                None => true,
                            }
                        });
                    if mismatched {
                        continue;
                    }
                    sites.push((block.id, inst_idx, ci, args.clone()));
                }
            }
        }
        sites
    };

    if call_sites.is_empty() {
        return false;
    }

    // Inline one call site, then return true to let the pass manager
    // re-run us. Processing multiple sites in one invocation is unsafe:
    // splitting a block invalidates indices for other call sites in the
    // same block. The pass manager's fixpoint loop handles re-invocation.
    let (call_block_id, call_inst_idx, callee_idx, caller_args) = call_sites[0].clone();
    {
        // Clone the callee's body into the caller.
        // Snapshot everything we need from the callee up front so the
        // mutable borrow of the caller below doesn't conflict with it.
        let callee = &module.functions[callee_idx as usize];
        let callee_entry = callee.entry;
        let callee_blocks: Vec<BasicBlock> = callee.blocks.clone();
        let callee_params: Vec<Param> = callee.params.clone();
        let callee_return_ty = callee.return_type.clone();

        let caller = &mut module.functions[caller_idx as usize];

        // Build value map: callee params → caller args.
        let mut val_map: HashMap<ValueId, ValueId> = HashMap::new();
        for (param, &arg) in callee_params.iter().zip(caller_args.iter()) {
            val_map.insert(param.id, arg);
        }

        // Allocate fresh IDs for all callee values.
        let mut block_map: HashMap<BlockId, BlockId> = HashMap::new();
        for cb in &callee_blocks {
            let new_bid = caller.create_block(&format!("inline_{}", cb.name));
            block_map.insert(cb.id, new_bid);
        }

        // Create post-call block to receive the return value.
        let post_call = caller.create_block("inline_post");
        let has_return_val = !matches!(callee_return_ty, IrType::Void);

        // Non-void callees: give the post-call block a parameter that the
        // rewritten Return branches feed the return value into.
        let result_param_id = if has_return_val {
            let pid = caller.next_value_id();
            caller.register_type(pid, callee_return_ty.clone());
            caller.block_mut(post_call).params.push(BlockParam {
                id: pid,
                ty: callee_return_ty.clone(),
            });
            Some(pid)
        } else {
            None
        };

        // Clone block params and pre-allocate fresh IDs for every
        // callee instruction before remapping any instruction bodies.
        // Valid SSA can use a value defined in a dominating block that
        // appears later in the block vector; remapping instruction
        // operands against a partial map leaves raw callee IDs behind
        // and can alias unrelated caller values.
        for cb in &callee_blocks {
            let new_bid = block_map[&cb.id];
            // Clone block params.
            for bp in &cb.params {
                let new_id = caller.next_value_id();
                caller.register_type(new_id, bp.ty.clone());
                val_map.insert(bp.id, new_id);
                caller.block_mut(new_bid).params.push(BlockParam {
                    id: new_id,
                    ty: bp.ty.clone(),
                });
            }
        }
        // First sweep: allocate an ID for every instruction result so the
        // map is total before any operand remapping happens.
        for cb in &callee_blocks {
            for inst in &cb.insts {
                let new_id = caller.next_value_id();
                caller.register_type(new_id, inst.ty.clone());
                val_map.insert(inst.id, new_id);
            }
        }
        // Second sweep: clone instruction bodies with operands remapped
        // through the now-complete value map.
        for cb in &callee_blocks {
            let new_bid = block_map[&cb.id];
            for inst in &cb.insts {
                let new_id = *val_map
                    .get(&inst.id)
                    .expect("cloned callee instruction id should be preallocated");
                let new_kind = remap_inst_kind(&inst.kind, &val_map);
                caller.block_mut(new_bid).insts.push(Inst {
                    id: new_id,
                    kind: new_kind,
                    ty: inst.ty.clone(),
                    span: inst.span,
                });
            }
        }

        // Clone terminators, replacing Return with Branch to post_call.
        for cb in &callee_blocks {
            let new_bid = block_map[&cb.id];
            let new_term = match &cb.terminator {
                Some(Terminator::Return(Some(val))) => {
                    let remapped = *val_map.get(val).unwrap_or(val);
                    Terminator::Branch(post_call, vec![remapped])
                }
                Some(Terminator::Return(None)) => Terminator::Branch(post_call, vec![]),
                Some(other) => remap_terminator(other, &block_map, &val_map),
                None => Terminator::Unreachable,
            };
            caller.block_mut(new_bid).terminator = Some(new_term);
        }

        // Split the call-containing block: move instructions after the call
        // into the post-call block.
        let call_block = caller.block_mut(call_block_id);
        let call_result_id = call_block.insts[call_inst_idx].id;

        // Move post-call instructions to the new block.
        let post_insts: Vec<Inst> = call_block.insts.split_off(call_inst_idx + 1);
        let old_term = call_block.terminator.take();

        // Remove the call instruction itself. After the split_off above the
        // call is the block's last instruction, so pop() is exact.
        call_block.insts.pop(); // removes the call at call_inst_idx

        // Add branch from call block to inlined entry.
        let inlined_entry = block_map[&callee_entry];
        caller.block_mut(call_block_id).terminator =
            Some(Terminator::Branch(inlined_entry, vec![]));

        // Populate post-call block with remaining instructions and terminator.
        // Remap uses of the call result to the post-call block param.
        let mut post_remap: HashMap<ValueId, ValueId> = HashMap::new();
        if let Some(param_id) = result_param_id {
            post_remap.insert(call_result_id, param_id);
        }

        for inst in post_insts {
            let new_kind = if post_remap.is_empty() {
                inst.kind.clone()
            } else {
                remap_inst_kind(&inst.kind, &post_remap)
            };
            caller.block_mut(post_call).insts.push(Inst {
                id: inst.id,
                kind: new_kind,
                ty: inst.ty,
                span: inst.span,
            });
        }

        if let Some(term) = old_term {
            let new_term = if post_remap.is_empty() {
                term
            } else {
                // NOTE(review): passes an empty block map here — assumes
                // `remap_terminator` leaves block ids it cannot find in the
                // map unchanged; confirm against loop_utils.
                remap_terminator(&term, &HashMap::new(), &post_remap)
            };
            caller.block_mut(post_call).terminator = Some(new_term);
        }

        // The call result may also be used in blocks other than the one that
        // contained the call (any block the call dominated); rewrite those
        // uses function-wide.
        if let Some(param_id) = result_param_id {
            substitute_uses(caller, call_result_id, param_id);
        }
    } // end single inline

    // The inlined entry now has an extra predecessor; callee blocks that
    // ended up unreachable (e.g. behind constant branches) are dropped.
    let caller = &mut module.functions[caller_idx as usize];
    prune_unreachable(caller);
    true
}
288
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::types::{IntWidth, IrType};
    use crate::ir::verify::verify_module;
    use crate::lexer::{Position, Span};
    use crate::opt::pass::Pass;

    // Minimal 1:1 span for synthesized instructions; tests don't care about
    // source locations.
    fn dummy_span() -> Span {
        let p = Position { line: 1, col: 1 };
        Span {
            start: p,
            end: p,
            file_id: 0,
        }
    }

    // Appends an instruction of `kind`/`ty` to `f`'s entry block and returns
    // its freshly allocated value id.
    fn push(f: &mut Function, kind: InstKind, ty: IrType) -> ValueId {
        let id = f.next_value_id();
        let entry = f.entry;
        f.block_mut(entry).insts.push(Inst {
            id,
            kind,
            ty,
            span: dummy_span(),
        });
        id
    }

    // O0 maps to a zero threshold, which must make the pass a no-op.
    #[test]
    fn inline_no_op_at_o0() {
        let mut m = Module::new("test".into());
        let mut f = Function::new("main".into(), vec![], IrType::Void);
        f.block_mut(f.entry).terminator = Some(Terminator::Return(None));
        m.add_function(f);
        let pass = Inline::for_level(OptLevel::O0);
        assert!(!pass.run(&mut m));
    }

    // A module with no internal call sites must report "unchanged".
    #[test]
    fn inline_no_op_without_internal_calls() {
        let mut m = Module::new("test".into());
        let mut f = Function::new("main".into(), vec![], IrType::Void);
        f.block_mut(f.entry).terminator = Some(Terminator::Return(None));
        m.add_function(f);
        let pass = Inline::for_level(OptLevel::O2);
        assert!(!pass.run(&mut m));
    }

    // The call result is used not only after the call in the same block but
    // also in successor blocks (then/else). After inlining, every such use
    // must be rewritten to the post-call block param or the verifier fails.
    #[test]
    fn inline_rewrites_result_uses_in_successor_blocks() {
        let mut m = Module::new("test".into());

        // Callee: branch on a constant, return 1 or 2.
        let mut callee = Function::new("callee".into(), vec![], IrType::Int(IntWidth::I32));
        let then_b = callee.create_block("then");
        let else_b = callee.create_block("else");
        let cond = push(&mut callee, InstKind::ConstBool(true), IrType::Bool);
        let entry = callee.entry;
        callee.block_mut(entry).terminator = Some(Terminator::CondBranch {
            cond,
            true_dest: then_b,
            true_args: vec![],
            false_dest: else_b,
            false_args: vec![],
        });
        let one = callee.next_value_id();
        callee.register_type(one, IrType::Int(IntWidth::I32));
        callee.block_mut(then_b).insts.push(Inst {
            id: one,
            kind: InstKind::ConstInt(1, IntWidth::I32),
            ty: IrType::Int(IntWidth::I32),
            span: dummy_span(),
        });
        callee.block_mut(then_b).terminator = Some(Terminator::Return(Some(one)));
        let two = callee.next_value_id();
        callee.register_type(two, IrType::Int(IntWidth::I32));
        callee.block_mut(else_b).insts.push(Inst {
            id: two,
            kind: InstKind::ConstInt(2, IntWidth::I32),
            ty: IrType::Int(IntWidth::I32),
            span: dummy_span(),
        });
        callee.block_mut(else_b).terminator = Some(Terminator::Return(Some(two)));
        m.add_function(callee);

        // Caller: call result flows into a compare, an add in the then
        // block, and a return in the else block.
        let mut caller = Function::new("caller".into(), vec![], IrType::Int(IntWidth::I32));
        let then_b = caller.create_block("then");
        let else_b = caller.create_block("else");
        let call_id = push(
            &mut caller,
            InstKind::Call(FuncRef::Internal(0), vec![]),
            IrType::Int(IntWidth::I32),
        );
        let zero = push(
            &mut caller,
            InstKind::ConstInt(0, IntWidth::I32),
            IrType::Int(IntWidth::I32),
        );
        let cond = push(
            &mut caller,
            InstKind::ICmp(CmpOp::Gt, call_id, zero),
            IrType::Bool,
        );
        let entry = caller.entry;
        caller.block_mut(entry).terminator = Some(Terminator::CondBranch {
            cond,
            true_dest: then_b,
            true_args: vec![],
            false_dest: else_b,
            false_args: vec![],
        });
        let add = caller.next_value_id();
        caller.register_type(add, IrType::Int(IntWidth::I32));
        caller.block_mut(then_b).insts.push(Inst {
            id: add,
            kind: InstKind::IAdd(call_id, zero),
            ty: IrType::Int(IntWidth::I32),
            span: dummy_span(),
        });
        caller.block_mut(then_b).terminator = Some(Terminator::Return(Some(add)));
        caller.block_mut(else_b).terminator = Some(Terminator::Return(Some(call_id)));
        m.add_function(caller);

        let pass = Inline::for_level(OptLevel::O2);
        assert!(pass.run(&mut m), "expected the internal call to inline");

        let post = verify_module(&m);
        assert!(
            post.is_empty(),
            "inliner left invalid SSA when the call result escaped into successor blocks: {:?}",
            post
        );
    }

    // The callee's block vector deliberately lists the using block before
    // the defining block. The inliner must pre-allocate result ids for all
    // instructions before remapping any operands, or `shared` leaks through
    // as a raw callee id.
    #[test]
    fn inline_preallocates_defs_for_later_vector_blocks() {
        let mut m = Module::new("test".into());

        // Callee: entry → def → use, with "use" created (and stored) first.
        let mut callee = Function::new("callee".into(), vec![], IrType::Int(IntWidth::I32));
        let use_b = callee.create_block("use");
        let def_b = callee.create_block("def");
        let entry = callee.entry;
        callee.block_mut(entry).terminator = Some(Terminator::Branch(def_b, vec![]));

        let shared = callee.next_value_id();
        callee.register_type(shared, IrType::Int(IntWidth::I32));
        callee.block_mut(def_b).insts.push(Inst {
            id: shared,
            kind: InstKind::ConstInt(7, IntWidth::I32),
            ty: IrType::Int(IntWidth::I32),
            span: dummy_span(),
        });
        callee.block_mut(def_b).terminator = Some(Terminator::Branch(use_b, vec![]));

        let doubled = callee.next_value_id();
        callee.register_type(doubled, IrType::Int(IntWidth::I32));
        callee.block_mut(use_b).insts.push(Inst {
            id: doubled,
            kind: InstKind::IAdd(shared, shared),
            ty: IrType::Int(IntWidth::I32),
            span: dummy_span(),
        });
        callee.block_mut(use_b).terminator = Some(Terminator::Return(Some(doubled)));
        m.add_function(callee);

        let mut caller = Function::new("caller".into(), vec![], IrType::Int(IntWidth::I32));
        let call_id = push(
            &mut caller,
            InstKind::Call(FuncRef::Internal(0), vec![]),
            IrType::Int(IntWidth::I32),
        );
        let entry = caller.entry;
        caller.block_mut(entry).terminator = Some(Terminator::Return(Some(call_id)));
        m.add_function(caller);

        let pass = Inline::for_level(OptLevel::O2);
        assert!(pass.run(&mut m), "expected the internal call to inline");

        let post = verify_module(&m);
        assert!(
            post.is_empty(),
            "inliner left invalid IDs when the callee block vector was not in dominance order: {:?}",
            post
        );
    }
}
475