Rust · 20299 bytes Raw Blame History
1 //! Loop unswitching pass.
2 //!
3 //! Hoists loop-invariant conditionals out of loops by cloning the loop
4 //! into two versions — one for the true branch, one for the false
5 //! branch. Each clone has the internal conditional replaced with an
6 //! unconditional branch, eliminating the branch from the hot path.
7 //!
//! ```text
//! Before:
//!   do i = 1, n
//!     if (flag) then        ← flag is loop-invariant
//!       a(i) = b(i)
//!     else
//!       a(i) = c(i)
//!     end if
//!   end do
//!
//! After:
//!   if (flag) then
//!     do i = 1, n; a(i) = b(i); end do
//!   else
//!     do i = 1, n; a(i) = c(i); end do
//!   end if
//! ```
25 //!
26 //! Gated on body size (≤ UNSWITCH_MAX_BODY instructions) to prevent
27 //! code bloat. Fires at O2+.
28
29 use super::loop_utils::{find_preheader, loop_defined_values};
30 use super::pass::Pass;
31 use crate::ir::inst::*;
32 use crate::ir::walk::{
33 find_natural_loops, inst_uses, predecessors, prune_unreachable, terminator_uses,
34 };
35 use std::collections::{HashMap, HashSet};
36
/// Largest loop the pass will unswitch, measured as the total number of
/// instructions across all body blocks (block terminators are not counted).
/// Each unswitch emits two copies of the body, so the limit is kept small.
const UNSWITCH_MAX_BODY: usize = 50;
40
/// Loop-unswitching pass: hoists a loop-invariant `CondBranch` out of a loop
/// by specializing the loop into one copy per branch direction.
pub struct LoopUnswitch;
42
43 type UnswitchCandidate = (
44 BlockId,
45 ValueId,
46 BlockId,
47 Vec<ValueId>,
48 BlockId,
49 Vec<ValueId>,
50 );
51
52 impl Pass for LoopUnswitch {
53 fn name(&self) -> &'static str {
54 "loop-unswitch"
55 }
56
57 fn run(&self, module: &mut Module) -> bool {
58 let mut changed = false;
59 for func in &mut module.functions {
60 if unswitch_in_function(func) {
61 changed = true;
62 }
63 }
64 changed
65 }
66 }
67
68 /// Attempt to unswitch one loop in the function. Returns true if a
69 /// transformation was applied. We process one loop per call and let
70 /// the pass manager's fixpoint loop handle cascading opportunities.
71 fn unswitch_in_function(func: &mut Function) -> bool {
72 let loops = find_natural_loops(func);
73 let preds = predecessors(func);
74
75 for lp in &loops {
76 let Some(ph_id) = find_preheader(func, lp, &preds) else {
77 continue;
78 };
79
80 // Size guard.
81 let total_insts: usize = lp.body.iter().map(|&b| func.block(b).insts.len()).sum();
82 if total_insts > UNSWITCH_MAX_BODY {
83 continue;
84 }
85
86 // Find a CondBranch inside the loop whose condition is invariant.
87 let loop_defs = loop_defined_values(func, lp);
88 if has_external_ssa_uses(func, lp, &loop_defs) {
89 continue;
90 }
91
92 let candidate = find_unswitch_candidate(func, lp, &loop_defs);
93 let Some((cond_block, cond_val, true_dest, true_args, false_dest, false_args)) = candidate
94 else {
95 continue;
96 };
97
98 // Both successors must be inside the loop (otherwise it's the
99 // loop exit condition, not an unswitchable interior branch).
100 if !lp.body.contains(&true_dest) || !lp.body.contains(&false_dest) {
101 continue;
102 }
103
104 // Clone the entire loop body into two copies.
105 let (true_map, true_blocks) = clone_loop(func, lp);
106 let (false_map, false_blocks) = clone_loop(func, lp);
107
108 // In the true clone: replace the CondBranch with an unconditional
109 // branch to the true successor.
110 let true_cond_block = true_map[&cond_block];
111 let true_true_dest = true_map[&true_dest];
112 let true_val_map = build_value_map(func, lp, &true_blocks, &true_map);
113 let remapped_true_args: Vec<ValueId> = true_args
114 .iter()
115 .map(|v| *true_val_map.get(v).unwrap_or(v))
116 .collect();
117 func.block_mut(true_cond_block).terminator =
118 Some(Terminator::Branch(true_true_dest, remapped_true_args));
119
120 // In the false clone: replace the CondBranch with an unconditional
121 // branch to the false successor.
122 let false_cond_block = false_map[&cond_block];
123 let false_false_dest = false_map[&false_dest];
124 let false_val_map = build_value_map(func, lp, &false_blocks, &false_map);
125 let remapped_false_args: Vec<ValueId> = false_args
126 .iter()
127 .map(|v| *false_val_map.get(v).unwrap_or(v))
128 .collect();
129 func.block_mut(false_cond_block).terminator =
130 Some(Terminator::Branch(false_false_dest, remapped_false_args));
131
132 // Rewrite the preheader to test the condition and branch to
133 // the appropriate clone's header.
134 let true_header = true_map[&lp.header];
135 let false_header = false_map[&lp.header];
136
137 // Get the preheader's original branch args to the header.
138 let ph_args = match &func.block(ph_id).terminator {
139 Some(Terminator::Branch(_, args)) => args.clone(),
140 _ => vec![],
141 };
142
143 func.block_mut(ph_id).terminator = Some(Terminator::CondBranch {
144 cond: cond_val,
145 true_dest: true_header,
146 true_args: ph_args.clone(),
147 false_dest: false_header,
148 false_args: ph_args,
149 });
150
151 // Mark original loop blocks as unreachable so prune removes them.
152 for &bid in &lp.body {
153 func.block_mut(bid).terminator = Some(Terminator::Unreachable);
154 }
155 prune_unreachable(func);
156
157 return true; // one at a time
158 }
159 false
160 }
161
162 /// Find a CondBranch in the loop body whose condition is loop-invariant.
163 fn find_unswitch_candidate(
164 func: &Function,
165 lp: &crate::ir::walk::NaturalLoop,
166 loop_defs: &HashSet<ValueId>,
167 ) -> Option<UnswitchCandidate> {
168 for &bid in &lp.body {
169 let block = func.block(bid);
170 if let Some(Terminator::CondBranch {
171 cond,
172 true_dest,
173 true_args,
174 false_dest,
175 false_args,
176 }) = &block.terminator
177 {
178 // Condition must be loop-invariant (not defined in the loop).
179 if !loop_defs.contains(cond) {
180 // Both targets must be in the loop (not the loop exit).
181 if lp.body.contains(true_dest) && lp.body.contains(false_dest) {
182 return Some((
183 bid,
184 *cond,
185 *true_dest,
186 true_args.clone(),
187 *false_dest,
188 false_args.clone(),
189 ));
190 }
191 }
192 }
193 }
194 None
195 }
196
197 /// Unswitching clones the loop body and removes the original loop blocks.
198 /// If later blocks still read a loop-defined SSA value directly, the transform
199 /// would need extra SSA repair on the outside uses. Until that exists, bail out.
200 fn has_external_ssa_uses(
201 func: &Function,
202 lp: &crate::ir::walk::NaturalLoop,
203 loop_defs: &HashSet<ValueId>,
204 ) -> bool {
205 for block in &func.blocks {
206 if lp.body.contains(&block.id) {
207 continue;
208 }
209 for inst in &block.insts {
210 if inst_uses(&inst.kind)
211 .into_iter()
212 .any(|value| loop_defs.contains(&value))
213 {
214 return true;
215 }
216 }
217 if let Some(term) = &block.terminator {
218 if terminator_uses(term)
219 .into_iter()
220 .any(|value| loop_defs.contains(&value))
221 {
222 return true;
223 }
224 }
225 }
226 false
227 }
228
229 /// Delegate to shared loop utility.
230 fn clone_loop(
231 func: &mut Function,
232 lp: &crate::ir::walk::NaturalLoop,
233 ) -> (HashMap<BlockId, BlockId>, Vec<BlockId>) {
234 super::loop_utils::clone_loop(func, lp)
235 }
236
237 /// Delegate to shared loop utility.
238 fn build_value_map(
239 func: &Function,
240 lp: &crate::ir::walk::NaturalLoop,
241 _new_blocks: &[BlockId],
242 block_map: &HashMap<BlockId, BlockId>,
243 ) -> HashMap<ValueId, ValueId> {
244 super::loop_utils::build_value_map(func, lp, block_map)
245 }
246
247 // remap_inst_kind and remap_terminator have been moved to loop_utils.rs.
248 // The unswitching pass's clone_loop delegate above calls them transitively.
249
250 // ---------------------------------------------------------------------------
251 // Tests
252 // ---------------------------------------------------------------------------
253
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::types::{IntWidth, IrType};
    use crate::ir::verify::verify_module;
    use crate::lexer::{Position, Span};
    use crate::opt::pass::Pass;

    fn span() -> Span {
        let pos = Position { line: 0, col: 0 };
        Span {
            file_id: 0,
            start: pos,
            end: pos,
        }
    }

    fn i32_ty() -> IrType {
        IrType::Int(IntWidth::I32)
    }

    fn bool_ty() -> IrType {
        IrType::Bool
    }

    /// Allocate a fresh value, register its type, and append the instruction
    /// `kind` defining it to the end of `block`. `ty` is a constructor so the
    /// type can be built twice (registry + instruction) without needing `Clone`.
    fn push_inst(
        f: &mut Function,
        block: BlockId,
        ty: fn() -> IrType,
        kind: InstKind,
    ) -> ValueId {
        let id = f.next_value_id();
        f.register_type(id, ty());
        f.block_mut(block).insts.push(Inst {
            id,
            ty: ty(),
            span: span(),
            kind,
        });
        id
    }

    /// Allocate a fresh value, register its type, and append it as a block
    /// parameter of `block`.
    fn push_param(f: &mut Function, block: BlockId, ty: fn() -> IrType) -> ValueId {
        let id = f.next_value_id();
        f.register_type(id, ty());
        f.block_mut(block).params.push(BlockParam { id, ty: ty() });
        id
    }

    /// Build: entry → preheader → header(%i) → cmp → body → cond_branch(flag, t_body, f_body) →
    /// t_body → latch; f_body → latch; latch → header
    ///
    /// `flag` is defined in entry (loop-invariant).
    fn build_unswitchable_loop() -> Module {
        let mut m = Module::new("test".into());
        let mut f = Function::new("test".into(), vec![], IrType::Void);

        let preheader = f.create_block("preheader");
        let header = f.create_block("header");
        let cmp_blk = f.create_block("cmp");
        let body = f.create_block("body");
        let t_body = f.create_block("t_body");
        let f_body = f.create_block("f_body");
        let latch = f.create_block("latch");
        let exit = f.create_block("exit");
        let entry = f.entry;

        // Entry: flag = const_bool(true), c1 = const 1, c10 = const 10
        let flag = push_inst(&mut f, entry, bool_ty, InstKind::ConstBool(true));
        let c1 = push_inst(&mut f, entry, i32_ty, InstKind::ConstInt(1, IntWidth::I32));
        let c10 = push_inst(&mut f, entry, i32_ty, InstKind::ConstInt(10, IntWidth::I32));
        f.block_mut(entry).terminator = Some(Terminator::Branch(preheader, vec![]));

        // Preheader → header(c1)
        f.block_mut(preheader).terminator = Some(Terminator::Branch(header, vec![c1]));

        // Header(%i) → cmp
        let iv = push_param(&mut f, header, i32_ty);
        f.block_mut(header).terminator = Some(Terminator::Branch(cmp_blk, vec![]));

        // Cmp: icmp le %i, 10; condBr body, exit
        let cmp_v = push_inst(&mut f, cmp_blk, bool_ty, InstKind::ICmp(CmpOp::Le, iv, c10));
        f.block_mut(cmp_blk).terminator = Some(Terminator::CondBranch {
            cond: cmp_v,
            true_dest: body,
            true_args: vec![],
            false_dest: exit,
            false_args: vec![],
        });

        // Body: condBr flag, t_body, f_body ← the unswitchable conditional
        f.block_mut(body).terminator = Some(Terminator::CondBranch {
            cond: flag,
            true_dest: t_body,
            true_args: vec![],
            false_dest: f_body,
            false_args: vec![],
        });

        // t_body → latch; f_body → latch
        f.block_mut(t_body).terminator = Some(Terminator::Branch(latch, vec![]));
        f.block_mut(f_body).terminator = Some(Terminator::Branch(latch, vec![]));

        // Latch: iadd + br header
        let nxt = push_inst(&mut f, latch, i32_ty, InstKind::IAdd(iv, c1));
        f.block_mut(latch).terminator = Some(Terminator::Branch(header, vec![nxt]));

        // Exit
        f.block_mut(exit).terminator = Some(Terminator::Return(None));

        m.add_function(f);
        m
    }

    #[test]
    fn unswitches_invariant_conditional() {
        let mut m = build_unswitchable_loop();
        assert!(
            verify_module(&m).is_empty(),
            "test setup must start valid before unswitch"
        );

        let pass = LoopUnswitch;
        let changed = pass.run(&mut m);
        assert!(changed, "should unswitch the invariant conditional");
        assert!(
            verify_module(&m).is_empty(),
            "unswitching must leave the IR valid"
        );

        // After unswitching, the preheader should have a CondBranch (not a Branch).
        let f = &m.functions[0];
        let preheader = f
            .blocks
            .iter()
            .find(|b| b.name.contains("preheader"))
            .unwrap();
        let Some(Terminator::CondBranch { cond: hoisted, .. }) = &preheader.terminator else {
            panic!(
                "preheader should now have a CondBranch: {:?}",
                preheader.terminator
            );
        };

        // The invariant test now lives only in the preheader: inside each
        // clone it must have been replaced by an unconditional branch.
        let branches_on_flag = f
            .blocks
            .iter()
            .filter(|b| {
                matches!(
                    &b.terminator,
                    Some(Terminator::CondBranch { cond, .. }) if cond == hoisted
                )
            })
            .count();
        assert_eq!(
            branches_on_flag, 1,
            "only the preheader should still branch on the invariant flag"
        );
    }

    #[test]
    fn does_not_unswitch_variant_conditional() {
        // Build a loop where the condition IS loop-defined (the IV).
        let mut m = Module::new("test".into());
        let mut f = Function::new("test".into(), vec![], IrType::Void);

        let header = f.create_block("header");
        let body = f.create_block("body");
        let latch = f.create_block("latch");
        let exit = f.create_block("exit");
        let entry = f.entry;

        let c1 = push_inst(&mut f, entry, i32_ty, InstKind::ConstInt(1, IntWidth::I32));
        let c10 = push_inst(&mut f, entry, i32_ty, InstKind::ConstInt(10, IntWidth::I32));
        f.block_mut(entry).terminator = Some(Terminator::Branch(header, vec![c1]));

        let iv = push_param(&mut f, header, i32_ty);
        let cmp_v = push_inst(&mut f, header, bool_ty, InstKind::ICmp(CmpOp::Le, iv, c10));
        f.block_mut(header).terminator = Some(Terminator::CondBranch {
            cond: cmp_v,
            true_dest: body,
            true_args: vec![],
            false_dest: exit,
            false_args: vec![],
        });

        // Body: the "conditional" uses the IV (loop-variant).
        let iv_cmp = push_inst(&mut f, body, bool_ty, InstKind::ICmp(CmpOp::Le, iv, c1));
        f.block_mut(body).terminator = Some(Terminator::CondBranch {
            cond: iv_cmp,
            true_dest: latch,
            true_args: vec![],
            false_dest: latch,
            false_args: vec![],
        });

        let nxt = push_inst(&mut f, latch, i32_ty, InstKind::IAdd(iv, c1));
        f.block_mut(latch).terminator = Some(Terminator::Branch(header, vec![nxt]));
        f.block_mut(exit).terminator = Some(Terminator::Return(None));

        m.add_function(f);

        let pass = LoopUnswitch;
        let changed = pass.run(&mut m);
        assert!(!changed, "should not unswitch a loop-variant conditional");
    }

    #[test]
    fn does_not_unswitch_when_loop_values_escape_directly() {
        let mut m = Module::new("test".into());
        let mut f = Function::new("test".into(), vec![], IrType::Void);

        let preheader = f.create_block("preheader");
        let header = f.create_block("header");
        let body = f.create_block("body");
        let t_body = f.create_block("t_body");
        let f_body = f.create_block("f_body");
        let latch = f.create_block("latch");
        let exit = f.create_block("exit");
        let entry = f.entry;

        let flag = push_inst(&mut f, entry, bool_ty, InstKind::ConstBool(true));
        let c1 = push_inst(&mut f, entry, i32_ty, InstKind::ConstInt(1, IntWidth::I32));
        let c10 = push_inst(&mut f, entry, i32_ty, InstKind::ConstInt(10, IntWidth::I32));
        f.block_mut(entry).terminator = Some(Terminator::Branch(preheader, vec![]));

        f.block_mut(preheader).terminator = Some(Terminator::Branch(header, vec![c1]));

        let iv = push_param(&mut f, header, i32_ty);
        f.block_mut(header).terminator = Some(Terminator::Branch(body, vec![]));

        f.block_mut(body).terminator = Some(Terminator::CondBranch {
            cond: flag,
            true_dest: t_body,
            true_args: vec![],
            false_dest: f_body,
            false_args: vec![],
        });
        f.block_mut(t_body).terminator = Some(Terminator::Branch(latch, vec![]));
        f.block_mut(f_body).terminator = Some(Terminator::Branch(latch, vec![]));

        let nxt = push_inst(&mut f, latch, i32_ty, InstKind::IAdd(iv, c1));
        let cmp_v = push_inst(&mut f, latch, bool_ty, InstKind::ICmp(CmpOp::Le, nxt, c10));
        f.block_mut(latch).terminator = Some(Terminator::CondBranch {
            cond: cmp_v,
            true_dest: header,
            true_args: vec![nxt],
            false_dest: exit,
            false_args: vec![],
        });

        // `nxt` is defined in the latch but read directly in `exit`: an
        // escaping loop value that unswitching cannot repair.
        push_inst(&mut f, exit, i32_ty, InstKind::IAdd(nxt, c1));
        f.block_mut(exit).terminator = Some(Terminator::Return(None));

        m.add_function(f);
        assert!(
            verify_module(&m).is_empty(),
            "test setup must start valid before unswitch"
        );

        let pass = LoopUnswitch;
        let changed = pass.run(&mut m);
        assert!(
            !changed,
            "unswitch should bail when loop-defined values escape directly"
        );
        assert!(
            verify_module(&m).is_empty(),
            "bailing out should keep the IR valid"
        );
    }
}
599