1 //! Loop peeling pass.
2 //!
3 //! Peels the first iteration of a counted loop when the body contains
4 //! an equality check against the initial IV value (`if (i == 1)`).
5 //!
6 //! The approach: clone the ENTIRE loop as a subgraph using `clone_loop`,
7 //! then redirect two edges so the clone executes exactly one iteration:
8 //! 1. Clone's latch back-edge → original loop header (passing init+stride)
9 //! 2. Clone's cmp exit → original loop's exit block
10 //!
11 //! This preserves all internal control flow (if/else branches) in the
12 //! peeled iteration. Later const-prop folds `iv=init_const` through the
13 //! clone, turning `if (i==1)` into `if (true)` and eliminating dead code.
14
15 use super::loop_utils::{clone_loop, find_preheader, loop_defined_values, resolve_const_int};
16 use super::pass::Pass;
17 use crate::ir::inst::*;
18 use crate::ir::types::IrType;
19 use crate::ir::walk::{find_natural_loops, inst_uses, predecessors, terminator_uses};
20 use std::collections::HashSet;
21
22 pub struct LoopPeel;
23
24 impl Pass for LoopPeel {
25 fn name(&self) -> &'static str {
26 "loop-peel"
27 }
28
29 fn run(&self, module: &mut Module) -> bool {
30 let mut changed = false;
31 for func in &mut module.functions {
32 if peel_in_function(func) {
33 changed = true;
34 }
35 }
36 changed
37 }
38 }
39
/// Scan `func` for a peelable counted loop and peel the first candidate.
///
/// A loop qualifies only when ALL of the following hold:
/// - it has a unique preheader (single redirectable entry edge);
/// - the header has exactly one block param (taken as the IV);
/// - the preheader passes a compile-time-constant init value;
/// - there is a single latch that feeds `iv + const_stride` back;
/// - no block outside the loop reads a loop-defined SSA value;
/// - the body contains an `if (iv == init)` first-iteration check.
///
/// Peels at most ONE loop per call and returns `true` if it did; the
/// pass driver re-runs to pick up further candidates.
fn peel_in_function(func: &mut Function) -> bool {
    let loops = find_natural_loops(func);
    let preds = predecessors(func);

    for lp in &loops {
        // Unique preheader: the one edge we will redirect into the clone.
        let Some(ph_id) = find_preheader(func, lp, &preds) else {
            continue;
        };

        // Header must have exactly 1 block param (the IV).
        let hdr = func.block(lp.header);
        if hdr.params.len() != 1 {
            continue;
        }
        let iv = hdr.params[0].id;

        // Get the init value from the preheader's branch to header.
        let init_val = match &func.block(ph_id).terminator {
            Some(Terminator::Branch(dest, args)) if *dest == lp.header && args.len() == 1 => {
                args[0]
            }
            _ => continue,
        };

        // Init must be a compile-time constant.
        let Some(init_const) = resolve_const_int(func, init_val) else {
            continue;
        };

        // Must have a single latch.
        if lp.latches.len() != 1 {
            continue;
        }
        let latch_id = lp.latches[0];

        // Latch must branch back to header with one arg.
        let next_iv = match &func.block(latch_id).terminator {
            Some(Terminator::Branch(dest, args)) if *dest == lp.header && args.len() == 1 => {
                args[0]
            }
            _ => continue,
        };

        // Find stride: next_iv = iadd(iv, stride_const).
        // The add may have the IV on either side; the other operand
        // must resolve to a constant.
        let stride_const = {
            let mut found = None;
            for inst in &func.block(latch_id).insts {
                if inst.id == next_iv {
                    if let InstKind::IAdd(a, b) = &inst.kind {
                        if *a == iv {
                            found = resolve_const_int(func, *b);
                        } else if *b == iv {
                            found = resolve_const_int(func, *a);
                        }
                    }
                    // Defining inst located; no need to scan further.
                    break;
                }
            }
            found
        };
        let Some(stride) = stride_const else { continue };

        // Check if body has a FIRST-ITERATION conditional:
        // ICmp(Eq, iv, init_val) feeding a CondBranch.
        // Peeling adds a new predecessor to the exit block, so any
        // direct external use of a loop-defined value would need SSA
        // repair we don't do yet — skip those loops.
        let loop_defs = loop_defined_values(func, lp);
        if has_external_ssa_uses(func, lp, &loop_defs) {
            continue;
        }
        if !has_first_iter_conditional(func, lp, iv, init_const, &loop_defs) {
            continue;
        }

        // Find the loop's exit block (cmp's false-branch target outside the body).
        let exit_block = find_loop_exit(func, lp);
        let Some(exit_id) = exit_block else { continue };

        // Also find the exit args (values passed to exit on the false branch).
        // ---- Perform the peel ----
        do_peel(func, lp, ph_id, init_val, latch_id, exit_id, stride);
        return true; // one at a time
    }
    false
}
123
124 /// Only match `ICmp(Eq, iv, val)` where `val` resolves to the same
125 /// constant as the loop's init value. This is the "first iteration"
126 /// check. Do NOT match bound checks like `ICmp(Le, iv, 10)`.
127 fn has_first_iter_conditional(
128 func: &Function,
129 lp: &crate::ir::walk::NaturalLoop,
130 iv: ValueId,
131 init_const: i64,
132 loop_defs: &HashSet<ValueId>,
133 ) -> bool {
134 for &bid in &lp.body {
135 let block = func.block(bid);
136 for inst in &block.insts {
137 if let InstKind::ICmp(CmpOp::Eq, a, b) = &inst.kind {
138 // One operand must be the IV.
139 let (is_iv, other) = if *a == iv {
140 (true, *b)
141 } else if *b == iv {
142 (true, *a)
143 } else {
144 (false, ValueId(0))
145 };
146 if !is_iv {
147 continue;
148 }
149
150 // The other operand must resolve to the init constant
151 // and be loop-invariant.
152 if loop_defs.contains(&other) {
153 continue;
154 }
155 if let Some(c) = resolve_const_int(func, other) {
156 if c == init_const {
157 // Must feed a CondBranch in this block.
158 if let Some(Terminator::CondBranch { cond, .. }) = &block.terminator {
159 if *cond == inst.id {
160 return true;
161 }
162 }
163 }
164 }
165 }
166 }
167 }
168 false
169 }
170
171 /// Find the loop's exit block (first false-branch target outside the body).
172 fn find_loop_exit(func: &Function, lp: &crate::ir::walk::NaturalLoop) -> Option<BlockId> {
173 for &bid in &lp.body {
174 let block = func.block(bid);
175 if let Some(Terminator::CondBranch { false_dest, .. }) = &block.terminator {
176 if !lp.body.contains(false_dest) {
177 return Some(*false_dest);
178 }
179 }
180 }
181 None
182 }
183
184 /// Peeling adds a new peeled-exit path into the original exit block.
185 /// If downstream blocks still read loop-defined SSA values directly, the
186 /// transform would need extra SSA repair across the new predecessor. Until
187 /// that exists, skip those loops.
188 fn has_external_ssa_uses(
189 func: &Function,
190 lp: &crate::ir::walk::NaturalLoop,
191 loop_defs: &HashSet<ValueId>,
192 ) -> bool {
193 for block in &func.blocks {
194 if lp.body.contains(&block.id) {
195 continue;
196 }
197 for inst in &block.insts {
198 if inst_uses(&inst.kind)
199 .into_iter()
200 .any(|value| loop_defs.contains(&value))
201 {
202 return true;
203 }
204 }
205 if let Some(term) = &block.terminator {
206 if terminator_uses(term)
207 .into_iter()
208 .any(|value| loop_defs.contains(&value))
209 {
210 return true;
211 }
212 }
213 }
214 false
215 }
216
217 /// Get the args passed to the exit block from the cmp's false-branch.
218 fn find_exit_args(
219 func: &Function,
220 lp: &crate::ir::walk::NaturalLoop,
221 exit_id: BlockId,
222 ) -> Vec<ValueId> {
223 for &bid in &lp.body {
224 let block = func.block(bid);
225 if let Some(Terminator::CondBranch {
226 false_dest,
227 false_args,
228 ..
229 }) = &block.terminator
230 {
231 if *false_dest == exit_id {
232 return false_args.clone();
233 }
234 }
235 }
236 Vec::new()
237 }
238
239 /// Perform the peel using clone_loop.
240 ///
241 /// Strategy:
242 /// 1. Clone the entire loop as a subgraph (preserves if/else, etc.)
243 /// 2. Redirect clone's latch → original header (with init+stride)
244 /// 3. Redirect clone's cmp exit → original exit block
245 /// 4. Preheader → clone's header (with init)
246 fn do_peel(
247 func: &mut Function,
248 lp: &crate::ir::walk::NaturalLoop,
249 ph_id: BlockId,
250 init_val: ValueId,
251 latch_id: BlockId,
252 exit_id: BlockId,
253 stride: i64,
254 ) {
255 // Clone the entire loop.
256 let (block_map, _new_blocks) = clone_loop(func, lp);
257
258 // The IV type (needed to emit the init+stride constant).
259 let hdr = func.block(lp.header);
260 let iv_ty = hdr.params[0].ty.clone();
261 let iv_width = match &iv_ty {
262 IrType::Int(w) => *w,
263 _ => return,
264 };
265 let init_const = resolve_const_int(func, init_val).unwrap();
266
267 // Emit init+stride constant in a helper block so it's available.
268 let next_init_id = func.next_value_id();
269 func.register_type(next_init_id, iv_ty.clone());
270 // We'll place this constant in the clone's latch before rewriting it.
271
272 // --- Redirect clone's latch ---
273 // Clone's latch currently branches to clone's header. Redirect it
274 // to the ORIGINAL header, passing init+stride.
275 let clone_latch = block_map[&latch_id];
276 // Add the init+stride constant to the clone's latch.
277 let dummy_span = crate::lexer::Span {
278 file_id: 0,
279 start: crate::lexer::Position { line: 0, col: 0 },
280 end: crate::lexer::Position { line: 0, col: 0 },
281 };
282 func.block_mut(clone_latch).insts.push(Inst {
283 id: next_init_id,
284 kind: InstKind::ConstInt((init_const + stride) as i128, iv_width),
285 ty: iv_ty.clone(),
286 span: dummy_span,
287 });
288 func.block_mut(clone_latch).terminator =
289 Some(Terminator::Branch(lp.header, vec![next_init_id]));
290
291 // --- Redirect clone's cmp exit ---
292 // Find the cmp block in the clone (the one with a CondBranch whose
293 // false-dest was the clone's exit). Redirect its false-dest to the
294 // ORIGINAL exit block with the original exit args.
295 for &orig_bid in &lp.body {
296 let block = func.block(orig_bid);
297 if let Some(Terminator::CondBranch { false_dest, .. }) = &block.terminator {
298 if *false_dest == exit_id {
299 // This is the cmp block in the original. Find its clone.
300 let clone_cmp = block_map[&orig_bid];
301 // The clone's false-dest currently points to the cloned exit
302 // (which doesn't exist as a separate block since exit_id is
303 // outside the loop body and wasn't cloned). Actually, clone_loop
304 // only clones blocks IN the body, so the exit_id wasn't cloned.
305 // The clone's terminator already points to exit_id — but with
306 // remapped args. We need to ensure the exit_args are correct.
307 //
308 // clone_loop's remap_terminator remaps values but not block
309 // targets that are OUTSIDE the body. So the clone's cmp false-dest
310 // already points to the original exit_id. The false_args were
311 // remapped through val_map. This should be correct as-is.
312 //
313 // However, we need to double-check: the exit args in the clone
314 // might reference cloned values that should reference originals.
315 // Actually no — for the peeled iteration, the cloned values ARE
316 // the correct ones (they compute the exit condition for iteration 1).
317 let _ = clone_cmp; // Already correct — exit_id is outside body.
318 break;
319 }
320 }
321 }
322
323 // --- Wire preheader → clone's header ---
324 let clone_header = block_map[&lp.header];
325 func.block_mut(ph_id).terminator = Some(Terminator::Branch(clone_header, vec![init_val]));
326 }
327
328 // ---------------------------------------------------------------------------
329 // Tests
330 // ---------------------------------------------------------------------------
331
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::inst::*;
    use crate::ir::types::{IntWidth, IrType};
    use crate::ir::verify::verify_module;
    use crate::lexer::{Position, Span};
    use crate::opt::pass::Pass;

    // Zero-width dummy span attached to all synthetic test instructions.
    fn span() -> Span {
        let pos = Position { line: 0, col: 0 };
        Span {
            file_id: 0,
            start: pos,
            end: pos,
        }
    }

    // A function whose entry immediately returns contains no loops, so
    // the pass must report "no change".
    #[test]
    fn peel_no_op_on_empty() {
        let mut m = Module::new("test".into());
        let mut f = Function::new("test".into(), vec![], IrType::Void);
        f.block_mut(f.entry).terminator = Some(Terminator::Return(None));
        m.add_function(f);
        let pass = LoopPeel;
        let changed = pass.run(&mut m);
        assert!(!changed, "no loops → no peeling");
    }

    #[test]
    fn peel_no_op_without_eq_check() {
        // Loop with no `if (i == init)` → should not peel.
        // CFG: entry → header(iv) → cmp → {body → latch → header, exit}.
        let mut m = Module::new("test".into());
        let mut f = Function::new("test".into(), vec![], IrType::Void);

        let header = f.create_block("header");
        let cmp = f.create_block("cmp");
        let body = f.create_block("body");
        let latch = f.create_block("latch");
        let exit = f.create_block("exit");
        let entry = f.entry;

        // Constants 1 (init/stride) and 10 (bound) in the entry block.
        let c1 = f.next_value_id();
        f.register_type(c1, IrType::Int(IntWidth::I32));
        f.block_mut(entry).insts.push(Inst {
            id: c1,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::ConstInt(1, IntWidth::I32),
        });
        let c10 = f.next_value_id();
        f.register_type(c10, IrType::Int(IntWidth::I32));
        f.block_mut(entry).insts.push(Inst {
            id: c10,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::ConstInt(10, IntWidth::I32),
        });
        f.block_mut(entry).terminator = Some(Terminator::Branch(header, vec![c1]));

        // Header takes the IV as its single block param.
        let iv = f.next_value_id();
        f.register_type(iv, IrType::Int(IntWidth::I32));
        f.block_mut(header).params.push(BlockParam {
            id: iv,
            ty: IrType::Int(IntWidth::I32),
        });
        f.block_mut(header).terminator = Some(Terminator::Branch(cmp, vec![]));

        // Bound check `iv <= 10` — deliberately Le, not Eq, so the
        // first-iteration pattern is absent.
        let cmp_v = f.next_value_id();
        f.register_type(cmp_v, IrType::Bool);
        f.block_mut(cmp).insts.push(Inst {
            id: cmp_v,
            ty: IrType::Bool,
            span: span(),
            kind: InstKind::ICmp(CmpOp::Le, iv, c10),
        });
        let exit_seed = f.next_value_id();
        f.register_type(exit_seed, IrType::Int(IntWidth::I32));
        f.block_mut(cmp).insts.push(Inst {
            id: exit_seed,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::IAdd(iv, c1),
        });
        f.block_mut(cmp).terminator = Some(Terminator::CondBranch {
            cond: cmp_v,
            true_dest: body,
            true_args: vec![],
            false_dest: exit,
            false_args: vec![],
        });

        // Body has NO equality check — just a store.
        let alloca = f.next_value_id();
        f.register_type(alloca, IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))));
        f.block_mut(body).insts.push(Inst {
            id: alloca,
            ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
            span: span(),
            kind: InstKind::Alloca(IrType::Int(IntWidth::I32)),
        });
        let store_id = f.next_value_id();
        f.register_type(store_id, IrType::Void);
        f.block_mut(body).insts.push(Inst {
            id: store_id,
            ty: IrType::Void,
            span: span(),
            kind: InstKind::Store(iv, alloca),
        });
        f.block_mut(body).terminator = Some(Terminator::Branch(latch, vec![]));

        // Latch: iv + 1, branch back to header — a well-formed counted loop.
        let nxt = f.next_value_id();
        f.register_type(nxt, IrType::Int(IntWidth::I32));
        f.block_mut(latch).insts.push(Inst {
            id: nxt,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::IAdd(iv, c1),
        });
        f.block_mut(latch).terminator = Some(Terminator::Branch(header, vec![nxt]));

        f.block_mut(exit).terminator = Some(Terminator::Return(None));
        m.add_function(f);

        let pass = LoopPeel;
        let changed = pass.run(&mut m);
        assert!(!changed, "loop without i==init check should not be peeled");
    }

    #[test]
    fn peel_skips_loops_with_direct_exit_uses_of_loop_values() {
        // Same loop shape as above, but WITH the `iv == 1` check AND an
        // exit block that reads a loop-defined value (`exit_seed`).
        // The external SSA use must make the pass bail out.
        let mut m = Module::new("test".into());
        let mut f = Function::new("test".into(), vec![], IrType::Void);

        let preheader = f.create_block("preheader");
        let header = f.create_block("header");
        let cmp = f.create_block("cmp");
        let body = f.create_block("body");
        let latch = f.create_block("latch");
        let exit = f.create_block("exit");
        let entry = f.entry;

        // Constants 1 (init/stride) and 10 (bound).
        let c1 = f.next_value_id();
        f.register_type(c1, IrType::Int(IntWidth::I32));
        f.block_mut(entry).insts.push(Inst {
            id: c1,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::ConstInt(1, IntWidth::I32),
        });
        let c10 = f.next_value_id();
        f.register_type(c10, IrType::Int(IntWidth::I32));
        f.block_mut(entry).insts.push(Inst {
            id: c10,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::ConstInt(10, IntWidth::I32),
        });
        // Dedicated preheader branching to header with the init value.
        f.block_mut(entry).terminator = Some(Terminator::Branch(preheader, vec![]));
        f.block_mut(preheader).terminator = Some(Terminator::Branch(header, vec![c1]));

        let iv = f.next_value_id();
        f.register_type(iv, IrType::Int(IntWidth::I32));
        f.block_mut(header).params.push(BlockParam {
            id: iv,
            ty: IrType::Int(IntWidth::I32),
        });
        f.block_mut(header).terminator = Some(Terminator::Branch(cmp, vec![]));

        // Bound check `iv <= 10` guarding the body.
        let cmp_v = f.next_value_id();
        f.register_type(cmp_v, IrType::Bool);
        f.block_mut(cmp).insts.push(Inst {
            id: cmp_v,
            ty: IrType::Bool,
            span: span(),
            kind: InstKind::ICmp(CmpOp::Le, iv, c10),
        });
        // `exit_seed` is loop-defined and later read by the exit block.
        let exit_seed = f.next_value_id();
        f.register_type(exit_seed, IrType::Int(IntWidth::I32));
        f.block_mut(cmp).insts.push(Inst {
            id: exit_seed,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::IAdd(iv, c1),
        });
        f.block_mut(cmp).terminator = Some(Terminator::CondBranch {
            cond: cmp_v,
            true_dest: body,
            true_args: vec![],
            false_dest: exit,
            false_args: vec![],
        });

        // Body DOES have the first-iteration check `iv == 1`.
        let first_iter = f.next_value_id();
        f.register_type(first_iter, IrType::Bool);
        f.block_mut(body).insts.push(Inst {
            id: first_iter,
            ty: IrType::Bool,
            span: span(),
            kind: InstKind::ICmp(CmpOp::Eq, iv, c1),
        });
        let alloca = f.next_value_id();
        f.register_type(alloca, IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))));
        f.block_mut(body).insts.push(Inst {
            id: alloca,
            ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
            span: span(),
            kind: InstKind::Alloca(IrType::Int(IntWidth::I32)),
        });
        let store_id = f.next_value_id();
        f.register_type(store_id, IrType::Void);
        f.block_mut(body).insts.push(Inst {
            id: store_id,
            ty: IrType::Void,
            span: span(),
            kind: InstKind::Store(iv, alloca),
        });
        f.block_mut(body).terminator = Some(Terminator::Branch(latch, vec![]));

        let nxt = f.next_value_id();
        f.register_type(nxt, IrType::Int(IntWidth::I32));
        f.block_mut(latch).insts.push(Inst {
            id: nxt,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::IAdd(iv, c1),
        });
        f.block_mut(latch).terminator = Some(Terminator::Branch(header, vec![nxt]));

        // Exit reads `exit_seed` — the external SSA use that blocks peeling.
        let escaped = f.next_value_id();
        f.register_type(escaped, IrType::Int(IntWidth::I32));
        f.block_mut(exit).insts.push(Inst {
            id: escaped,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::IAdd(exit_seed, c1),
        });
        f.block_mut(exit).terminator = Some(Terminator::Return(None));

        m.add_function(f);
        assert!(
            verify_module(&m).is_empty(),
            "test setup must start valid before peeling"
        );

        let pass = LoopPeel;
        let changed = pass.run(&mut m);
        assert!(
            !changed,
            "peeling should skip loops whose exit reads loop-defined SSA values directly"
        );
        assert!(
            verify_module(&m).is_empty(),
            "skipping the peel should keep the IR valid"
        );
    }
}
589