Rust · 122759 bytes Raw Blame History
1 //! Loop unrolling pass (O2+).
2 //!
3 //! Fully unrolls simple counted loops whose trip count is statically
4 //! known and small. Targets the common Fortran `do i = lo, hi` pattern
5 //! after mem2reg has converted the loop induction variable to an SSA
6 //! block parameter.
7 //!
8 //! ## What constitutes a "simple counted loop"
9 //!
10 //! After mem2reg runs, a canonicalized `do i = lo, hi` with stride 1
11 //! produces this 2-block structure:
12 //!
13 //! ```text
14 //! preheader:
15 //! ...
16 //! br header(const(lo))
17 //!
18 //! header(%i):
19 //! %cmp = icmp.sle %i, const(hi) ; or icmp.slt, hi+1
20 //! condBr %cmp, latch([]), exit([])
21 //!
22 //! latch:
23 //! ... body instructions using %i ...
24 //! %i_next = iadd %i, const(1)
25 //! br header(%i_next)
26 //!
27 //! exit:
28 //! ...
29 //! ```
30 //!
31 //! We require:
32 //! - Exactly 1 latch (single back-edge).
33 //! - Body is exactly `{header, latch}` — no additional blocks (i.e., no IF
34 //! inside the loop, no nested loops).
35 //! - Header has exactly 1 block parameter (the IV).
36 //! - IV's initial value, the loop bound, and the stride are all compile-time
37 //! constants discoverable from instruction operands.
38 //! - Stride is exactly 1.
39 //! - Trip count ≤ `FULL_UNROLL_MAX` for ordinary `DO`, or
40 //! `DO_CONCURRENT_FULL_UNROLL_MAX` for lowered `DO CONCURRENT`.
41 //! - Body instruction count ≤ `BODY_SIZE_MAX`.
42 //! - No instruction defined in the latch is used outside the loop (i.e., no
43 //! values escape the unrolled body — they'll be dead after unrolling). This
44 //! covers array loops (`a(i) = expr`) but excludes reduction loops
45 //! (`s = s + a(i)`) where the accumulator IS a second block param and
46 //! would escape.
47 //!
48 //! ## Transformation
49 //!
50 //! For trip count TC, we:
51 //! 1. Clone the latch body TC times into the preheader (substituting the IV
52 //! with the constant value for that iteration).
53 //! 2. Rewrite the preheader's terminator to jump directly to `exit`.
54 //! 3. Call `prune_unreachable` to remove the now-dead header and latch.
55 //!
56 //! After this, the loop is entirely gone.
57
58 use super::pass::Pass;
59 use super::util::{find_natural_loops, predecessors, NaturalLoop};
60 use crate::ir::inst::*;
61 use crate::ir::types::{IntWidth, IrType};
62 use crate::ir::walk::prune_unreachable;
63 use crate::lexer::{Position, Span};
64 use std::collections::{HashMap, HashSet};
65
66 fn dummy_span() -> Span {
67 let pos = Position { line: 0, col: 0 };
68 Span {
69 file_id: 0,
70 start: pos,
71 end: pos,
72 }
73 }
74
75 // ---------------------------------------------------------------------------
76 // Thresholds
77 // ---------------------------------------------------------------------------
78
/// Maximum trip count eligible for full unrolling.
///
/// Trip counts above this are left for sprint 29.6's partial unroller
/// which can handle dynamic trip counts with remainder loops.
const FULL_UNROLL_MAX: i64 = 8;

/// `DO CONCURRENT` is a stronger signal that iterations are independent,
/// so we allow a slightly larger full-unroll budget.
const DO_CONCURRENT_FULL_UNROLL_MAX: i64 = 16;

/// Maximum instruction count in the latch block.
///
/// Prevents code bloat from unrolling large bodies.
/// (Checked against the sum of all body-block instruction counts in
/// `detect_simple_loop`, not just the latch itself.)
const BODY_SIZE_MAX: usize = 30;
93
94 // ---------------------------------------------------------------------------
95 // Public pass entry point
96 // ---------------------------------------------------------------------------
97
98 pub struct LoopUnroll;
99
100 impl Pass for LoopUnroll {
101 fn name(&self) -> &'static str {
102 "loop-unroll"
103 }
104
105 fn run(&self, module: &mut Module) -> bool {
106 let mut changed = false;
107 for func in &mut module.functions {
108 changed |= unroll_in_function(func);
109 }
110 changed
111 }
112 }
113
114 // ---------------------------------------------------------------------------
115 // Per-function entry
116 // ---------------------------------------------------------------------------
117
118 fn unroll_in_function(func: &mut Function) -> bool {
119 let loops = find_natural_loops(func);
120 let preds = predecessors(func);
121 let mut changed = false;
122
123 for nl in &loops {
124 if let Some(shape) = detect_simple_loop(func, nl, &preds) {
125 do_unroll(func, shape);
126 prune_unreachable(func);
127 changed = true;
128 // After structural changes, re-running loop detection in the next
129 // pass manager iteration handles any inner loops that became
130 // exposed. Don't iterate here — return and let the pass manager
131 // fix-point loop drive re-runs.
132 break;
133 }
134 // Try partial unrolling for trip > FULL_UNROLL_MAX. Keeps the
135 // loop intact but clones body U times and bumps step by U.
136 if let Some(shape) = detect_partial_unroll_loop(func, nl, &preds) {
137 do_partial_unroll(func, shape);
138 changed = true;
139 break;
140 }
141 // Same shape but with a runtime (non-constant) bound — emits
142 // a head_bound computation in the preheader and a scalar
143 // remainder loop after the unrolled main loop.
144 if let Some(shape) = detect_partial_unroll_runtime_loop(func, nl, &preds) {
145 do_partial_unroll_runtime(func, shape);
146 changed = true;
147 break;
148 }
149 // Multi-block partial unroll: 3-block (or longer) linear-chain
150 // body. Allocates fresh BlockIds for each cloned iteration.
151 if let Some(shape) = detect_partial_unroll_multiblock_loop(func, nl, &preds) {
152 do_partial_unroll_multiblock(func, shape);
153 changed = true;
154 break;
155 }
156 }
157 changed
158 }
159
160 // ---------------------------------------------------------------------------
161 // Shape of a fully unrollable loop
162 // ---------------------------------------------------------------------------
163
#[derive(Debug)]
struct LoopShape {
    preheader: BlockId, // unique out-of-loop predecessor; branches into header
    header: BlockId,    // block with IV block param; branches to cmp_block
    cmp_block: BlockId, // block with icmp + condBr to body/exit
    /// The computation blocks (everything that is not header, cmp_block, or
    /// latch). Ordered in CFG traversal order from cmp_block's true-successor
    /// to latch's predecessor. For most Fortran DO loops this is a single
    /// block ("do_body"). Multi-block bodies are supported as long as the
    /// sub-CFG is a linear chain with no branches to outside.
    body_blocks: Vec<BlockId>,
    latch: BlockId, // iadd IV, 1 + br header(IV_next)
    exit: BlockId,  // cond_br false target
    iv_param: ValueId, // the header's index-0 block param (the IV)
    iv_ty: IrType,     // integer type of the IV (checked IrType::Int)
    iv_init: i64,      // constant initial IV value from the preheader branch
    iv_bound: i64, // inclusive upper bound
    trip_count: usize, // iv_bound - iv_init + 1, known positive and small
    exit_args: Vec<ValueId>, // args from cmp_block's false branch to exit
    /// Optional second header block-param representing a reduction
    /// accumulator (sum, product, ...). When present, the unroller
    /// threads its value across iterations: each iteration's body
    /// reads the previous iteration's latch result instead of the
    /// header's join.
    reduction: Option<ReductionInfo>,
}

/// Per-loop reduction lane. mem2reg promotes a sum/product accumulator
/// into a header block param; the latch's `br header(iv_next, new_acc)`
/// passes the per-iteration update on the second lane. To unroll such a
/// loop we need to know:
///
/// * which header param carries the accumulator (`acc_param`),
/// * its preheader-supplied initial value (`acc_init`),
/// * what value the latch passes back as the new accumulator
///   (`latch_acc_value`) — that's the SSA value defined inside the
///   body that the unroller substitutes through across iterations.
#[derive(Debug, Clone, Copy)]
struct ReductionInfo {
    acc_param: ValueId,
    acc_init: ValueId,
    latch_acc_value: ValueId,
}
207
208 fn is_do_concurrent_loop(
209 func: &Function,
210 header: BlockId,
211 cmp_block: BlockId,
212 body_blocks: &[BlockId],
213 latch: BlockId,
214 ) -> bool {
215 std::iter::once(header)
216 .chain(std::iter::once(cmp_block))
217 .chain(body_blocks.iter().copied())
218 .chain(std::iter::once(latch))
219 .any(|bb| func.block(bb).name.starts_with("doconc_"))
220 }
221
222 // ---------------------------------------------------------------------------
223 // Detection helpers
224 // ---------------------------------------------------------------------------
225
226 fn detect_simple_loop(
227 func: &Function,
228 nl: &NaturalLoop,
229 preds: &HashMap<BlockId, Vec<BlockId>>,
230 ) -> Option<LoopShape> {
231 // ---- structural requirements ----------------------------------------
232 if nl.latches.len() != 1 {
233 return None;
234 }
235 let latch = nl.latches[0];
236 if latch == nl.header {
237 return None;
238 } // no self-loop
239 // Body must have between 2 (header+latch) and 5 (header+cmp+body×N+latch) blocks.
240 if nl.body.len() < 2 || nl.body.len() > 5 {
241 return None;
242 }
243
244 let header = nl.header;
245
246 // ---- header must have 1 (pure IV) or 2 (IV + reduction acc)
247 // block params ---------------------------------------------------
248 let hdr = func.block(header);
249 if hdr.params.is_empty() || hdr.params.len() > 2 {
250 return None;
251 }
252 // mem2reg places the IV at index 0 and the reduction accumulator (if
253 // any) at index 1. We rely on that ordering — it matches the
254 // canonical lowering for `do i = 1, n; s = s + a(i); end do`.
255 let iv_param_idx = 0usize;
256 let acc_param_idx: Option<usize> = if hdr.params.len() == 2 { Some(1) } else { None };
257 let iv_param = hdr.params[iv_param_idx].id;
258 let iv_ty = hdr.params[iv_param_idx].ty.clone();
259 if !matches!(iv_ty, IrType::Int(_)) {
260 return None;
261 }
262 let acc_param: Option<ValueId> = acc_param_idx.map(|i| hdr.params[i].id);
263
264 // ---- find preheader -------------------------------------------------
265 let header_preds = preds.get(&header)?;
266 let mut outside: Vec<BlockId> = header_preds
267 .iter()
268 .copied()
269 .filter(|p| !nl.body.contains(p))
270 .collect();
271 outside.sort_by_key(|b| b.0);
272 outside.dedup();
273 if outside.len() != 1 {
274 return None;
275 }
276 let preheader = outside[0];
277
278 // Preheader must branch to header with exactly N args matching the
279 // header's param count. The IV arg is a const; the accumulator init
280 // can be any SSA value (typically a ConstInt(0) for sum, but we
281 // don't enforce that — we only need to capture the value to feed
282 // into iteration 0's body substitution).
283 let ph_blk = func.block(preheader);
284 let expected_ph_args = hdr.params.len();
285 let (iv_init, acc_init) = match &ph_blk.terminator {
286 Some(Terminator::Branch(dest, args))
287 if *dest == header && args.len() == expected_ph_args =>
288 {
289 let iv_init = resolve_const_int(func, args[iv_param_idx])?;
290 let acc_init = acc_param_idx.map(|i| args[i]);
291 (iv_init, acc_init)
292 }
293 _ => return None,
294 };
295
296 // ---- find the comparison block and exit ----------------------------
297 //
298 // Two cases:
299 // (a) 2-block loop: header IS the comparison block.
300 // (b) 4-block loop (canonical Fortran DO): header is a transparent
301 // relay block (0 instructions, br → cmp_block). cmp_block has
302 // the ICmp + CondBranch.
303 //
304 // In both cases we need (cmp_block, iv_bound, exit, exit_args).
305
306 let (cmp_block, iv_bound, exit, exit_args) = {
307 if let Some((bound, ex, args)) = detect_bound_and_exit(func, header, iv_param) {
308 // Case (a): header has the comparison directly.
309 (header, bound, ex, args)
310 } else {
311 // Case (b): header must be a transparent relay.
312 let hdr_blk = func.block(header);
313 if !hdr_blk.insts.is_empty() {
314 return None;
315 } // must have 0 instructions
316 let relay_target = match &hdr_blk.terminator {
317 Some(Terminator::Branch(t, args)) if args.is_empty() => *t,
318 _ => return None,
319 };
320 if !nl.body.contains(&relay_target) {
321 return None;
322 }
323 let (bound, ex, args) = detect_bound_and_exit(func, relay_target, iv_param)?;
324 (relay_target, bound, ex, args)
325 }
326 };
327
328 if nl.body.contains(&exit) {
329 return None;
330 }
331
332 // ---- latch: stride-1 increment, passes iv (and optionally acc)
333 // back to header ------------------------------------------------
334 let latch_info = check_latch(func, latch, header, iv_param, iv_param_idx, acc_param_idx)?;
335
336 // ---- compute body_blocks: body − {header, cmp_block, latch} --------
337 //
338 // Walk from cmp_block's true-successor to latch (exclusive) in CFG
339 // order. The sub-CFG must be a linear chain — any branch or divergence
340 // means the loop has internal control flow we don't handle here.
341 let body_blocks = collect_body_chain(func, cmp_block, latch, &nl.body)?;
342
343 let is_do_concurrent = is_do_concurrent_loop(func, header, cmp_block, &body_blocks, latch);
344 if !is_do_concurrent && body_blocks.iter().any(|&bb| block_contains_load(func, bb)) {
345 return None;
346 }
347
348 // ---- no loop-defined values escape outside the loop ----------------
349 //
350 // We must check ALL blocks that will be pruned after unrolling — not just
351 // body_blocks. In particular, `header` carries a block param (the IV) that
352 // mem2reg may have threaded into OUTER loop latches via dominance-frontier
353 // placement. If that param is used outside nl.body, the header's block param
354 // becomes undefined after we discard the header block.
355 for &bb in body_blocks
356 .iter()
357 .chain(std::iter::once(&header))
358 .chain(std::iter::once(&cmp_block))
359 .chain(std::iter::once(&latch))
360 {
361 if has_escaping_values(func, bb, &nl.body, preds) {
362 return None;
363 }
364 }
365
366 // ---- size check ------------------------------------------------------
367 let total_insts: usize = body_blocks.iter().map(|&b| func.block(b).insts.len()).sum();
368 if total_insts > BODY_SIZE_MAX {
369 return None;
370 }
371
372 let full_unroll_max = if is_do_concurrent {
373 DO_CONCURRENT_FULL_UNROLL_MAX
374 } else {
375 FULL_UNROLL_MAX
376 };
377
378 // ---- trip count ------------------------------------------------------
379 let trip_count_i64 = iv_bound - iv_init + 1;
380 if trip_count_i64 <= 0 {
381 return None;
382 }
383 if trip_count_i64 > full_unroll_max {
384 return None;
385 }
386
387 let reduction = match (acc_param, acc_init, latch_info.latch_acc_value) {
388 (Some(p), Some(init), Some(latch_val)) => {
389 // The latch accumulator value must be defined inside one
390 // of the body_blocks — otherwise it lives in the latch
391 // (which the unroller doesn't clone) and we'd drop the
392 // reduction op. Bail in that case rather than miscompile.
393 let defined_in_body = body_blocks
394 .iter()
395 .any(|&b| func.block(b).insts.iter().any(|i| i.id == latch_val));
396 if !defined_in_body {
397 return None;
398 }
399 Some(ReductionInfo {
400 acc_param: p,
401 acc_init: init,
402 latch_acc_value: latch_val,
403 })
404 }
405 (None, None, None) => None,
406 // Inconsistent state — fall through and reject the loop. Should
407 // not happen because the lane indices line up at the top.
408 _ => return None,
409 };
410
411 Some(LoopShape {
412 preheader,
413 header,
414 cmp_block,
415 body_blocks,
416 latch,
417 exit,
418 iv_param,
419 iv_ty,
420 iv_init,
421 iv_bound,
422 trip_count: trip_count_i64 as usize,
423 exit_args,
424 reduction,
425 })
426 }
427
428 /// Walk from the comparison block's true-successor to the latch, collecting
429 /// the "body" blocks in order. Returns None if the sub-CFG is not a linear
430 /// chain (no branches to outside, no divergence).
431 ///
432 /// Two structural cases:
433 ///
434 /// (a) 2-block loop — the cmp_block's true-successor IS the latch. The latch
435 /// contains both the body instructions and the IV increment. We return
436 /// `[latch]` so the body is cloned from the latch's instructions.
437 /// The IV increment (`iadd iv, 1`) will be included in the clone but its
438 /// result is unused in the cloned block (the terminator is rewritten to
439 /// jump to the next iteration, discarding the iadd result). DCE cleans it.
440 ///
441 /// (b) 4-block loop — cmp_block's true-successor is a separate body block
442 /// distinct from the latch. We walk the chain through the latch and
443 /// include it in the clone set because real lowered loops may still
444 /// carry side effects there in addition to the IV increment.
445 fn collect_body_chain(
446 func: &Function,
447 cmp_block: BlockId,
448 latch: BlockId,
449 loop_body: &HashSet<BlockId>,
450 ) -> Option<Vec<BlockId>> {
451 // Find the true-successor of cmp_block (the first body block).
452 let first = match &func.block(cmp_block).terminator {
453 Some(Terminator::CondBranch { true_dest, .. }) => *true_dest,
454 _ => return None,
455 };
456 if !loop_body.contains(&first) {
457 return None;
458 }
459
460 if first == latch {
461 // Case (a): the cmp's true-successor is the latch itself.
462 // The latch IS the body — include it in the clone chain.
463 return Some(vec![latch]);
464 }
465
466 // Case (b): walk the chain from first through latch. The latch often
467 // contains only the IV increment, but real fused/lowered loops can keep
468 // user-visible side effects there too, so excluding it is unsound.
469 let mut chain = Vec::new();
470 let mut cur = first;
471 loop {
472 if !loop_body.contains(&cur) {
473 return None;
474 }
475 chain.push(cur);
476 if cur == latch {
477 break;
478 }
479 let blk = func.block(cur);
480 // Each block in the chain must have no params (no join point).
481 if !blk.params.is_empty() {
482 return None;
483 }
484 // Its terminator must be an unconditional branch to the next block.
485 match &blk.terminator {
486 Some(Terminator::Branch(next, args)) if args.is_empty() => {
487 cur = *next;
488 }
489 _ => return None,
490 }
491 if chain.len() > 4 {
492 return None;
493 } // safety limit
494 }
495 Some(chain)
496 }
497
/// Resolve `vid` to a compile-time integer constant, if possible.
/// Thin delegate to the shared helper in `loop_utils` so all loop
/// passes agree on what counts as a constant.
fn resolve_const_int(func: &Function, vid: ValueId) -> Option<i64> {
    super::loop_utils::resolve_const_int(func, vid)
}
502
503 /// Inspect the header to find the loop bound and exit block.
504 ///
505 /// Accepts:
506 /// - `ICmp(Sle, iv, const(hi))` → bound = hi, true_dest = latch, false_dest = exit
507 /// - `ICmp(Slt, iv, const(hi))` → bound = hi-1 (exclusive → inclusive), same
508 /// - `ICmp(Sge, const(hi), iv)` → bound = hi (commuted)
509 /// - `ICmp(Sgt, const(hi), iv)` → bound = hi-1 (commuted exclusive)
510 ///
511 /// Returns `(inclusive_bound, exit_block_id, exit_args)`.
512 fn detect_bound_and_exit(
513 func: &Function,
514 header: BlockId,
515 iv: ValueId,
516 ) -> Option<(i64, BlockId, Vec<ValueId>)> {
517 let hdr = func.block(header);
518
519 // Find the single ICmp that involves the IV.
520 let mut cmp_id: Option<ValueId> = None;
521 let mut bound: Option<i64> = None;
522 let mut is_lt = false; // true when we need bound-1
523
524 for inst in &hdr.insts {
525 if let InstKind::ICmp(op, a, b) = &inst.kind {
526 let other = if *a == iv {
527 *b
528 } else if *b == iv {
529 *a
530 } else {
531 continue;
532 };
533 let c = resolve_const_int(func, other)?;
534 match op {
535 CmpOp::Le => {
536 bound = Some(c);
537 is_lt = false;
538 }
539 CmpOp::Lt => {
540 bound = Some(c);
541 is_lt = true;
542 }
543 CmpOp::Ge => {
544 bound = Some(c);
545 is_lt = false;
546 } // iv >= c means upper = c when commuted
547 CmpOp::Gt => {
548 bound = Some(c);
549 is_lt = true;
550 } // iv > c means upper = c-1 commuted
551 _ => return None,
552 }
553 cmp_id = Some(inst.id);
554 break;
555 }
556 }
557 let cmp_id = cmp_id?;
558 let raw_bound = bound?;
559 let inclusive_bound = if is_lt { raw_bound - 1 } else { raw_bound };
560
561 // Terminator must be CondBranch on that cmp.
562 match &hdr.terminator {
563 Some(Terminator::CondBranch {
564 cond,
565 true_dest,
566 true_args,
567 false_dest,
568 false_args,
569 }) if *cond == cmp_id => {
570 // true → body (latch), false → exit.
571 // Check which side is the latch. We don't have the latch id here;
572 // we'll accept either ordering and return the "false" side as exit.
573 // The caller verifies that the exit is not in the loop body.
574 let _ = true_args;
575 Some((inclusive_bound, *false_dest, false_args.clone()))
576 }
577 _ => None,
578 }
579 }
580
581 /// Verify that the latch block increments the IV by 1 and passes only
582 /// the new IV value back to the header (stride-1 check).
583 ///
584 /// Accepted pattern:
585 /// ```text
586 /// %i_next = iadd %iv, const(1)
587 /// br header(%i_next)
588 /// ```
589 /// Validate the latch terminator and (when applicable) capture the
590 /// reduction lane.
591 ///
592 /// Accepts either:
593 /// * `br header(%iv_next)` — pure-IV loop, returns
594 /// `Some(LatchInfo { latch_acc_value: None })`.
595 /// * `br header(%iv_next, %new_acc)` — reduction loop, where
596 /// `acc_lane_idx` (passed by the caller as 0 or 1) tells us which
597 /// branch arg index carries the accumulator. Returns
598 /// `Some(LatchInfo { latch_acc_value: Some(%new_acc) })`.
599 ///
600 /// In both cases the IV stride must be `iadd iv, const(1)` defined in
601 /// the latch.
602 fn check_latch(
603 func: &Function,
604 latch: BlockId,
605 header: BlockId,
606 iv: ValueId,
607 iv_lane_idx: usize,
608 acc_lane_idx: Option<usize>,
609 ) -> Option<LatchInfo> {
610 let blk = func.block(latch);
611
612 let expected_args = if acc_lane_idx.is_some() { 2 } else { 1 };
613 let args = match &blk.terminator {
614 Some(Terminator::Branch(dest, args))
615 if *dest == header && args.len() == expected_args =>
616 {
617 args.clone()
618 }
619 _ => return None,
620 };
621
622 let iv_next = args[iv_lane_idx];
623 let acc_value = acc_lane_idx.map(|i| args[i]);
624
625 // The iv lane's value must be defined as `iadd %iv, const(1)`
626 // inside the latch.
627 let mut iv_ok = false;
628 for inst in &blk.insts {
629 if inst.id == iv_next {
630 if let InstKind::IAdd(a, b) = &inst.kind {
631 let (is_iv, other) = if *a == iv {
632 (true, *b)
633 } else if *b == iv {
634 (true, *a)
635 } else {
636 (false, *a)
637 };
638 if !is_iv {
639 return None;
640 }
641 if resolve_const_int(func, other) != Some(1) {
642 return None;
643 }
644 iv_ok = true;
645 } else {
646 return None;
647 }
648 break;
649 }
650 }
651 if !iv_ok {
652 return None;
653 }
654
655 Some(LatchInfo {
656 latch_acc_value: acc_value,
657 })
658 }
659
/// Result of `check_latch`: the SSA value the latch passes back as the
/// updated reduction accumulator, or `None` for a pure-IV loop.
#[derive(Debug, Clone, Copy)]
struct LatchInfo {
    latch_acc_value: Option<ValueId>,
}
664
665 /// Return true if any instruction result defined in the latch is used
666 /// in a block that is NOT part of the loop body.
667 ///
668 /// This ensures no latch-computed value leaks out into the exit or
669 /// subsequent blocks. (Such loops have loop-carried dependencies that
670 /// require different handling.)
671 fn has_escaping_values(
672 func: &Function,
673 latch: BlockId,
674 body: &HashSet<BlockId>,
675 preds: &HashMap<BlockId, Vec<BlockId>>,
676 ) -> bool {
677 // Collect all ValueIds defined in the block — both block params and
678 // instruction results. Block params matter because mem2reg places them at
679 // dominance-frontier blocks; if such a param escapes into an outer-loop
680 // latch it becomes undefined after unrolling removes the block.
681 let latch_blk = func.block(latch);
682 let mut latch_defs: HashSet<ValueId> = HashSet::new();
683 for bp in &latch_blk.params {
684 latch_defs.insert(bp.id);
685 }
686 for inst in &latch_blk.insts {
687 latch_defs.insert(inst.id);
688 }
689
690 if latch_defs.is_empty() {
691 return false;
692 }
693
694 // Check every block outside the loop body for uses of those values.
695 for block in &func.blocks {
696 if body.contains(&block.id) {
697 continue;
698 }
699 // Check instruction operands.
700 for inst in &block.insts {
701 for use_id in inst_uses(&inst.kind) {
702 if latch_defs.contains(&use_id) {
703 return true;
704 }
705 }
706 }
707 // Check terminator operands.
708 if let Some(term) = &block.terminator {
709 for &use_id in terminator_uses_vec(term).iter() {
710 if latch_defs.contains(&use_id) {
711 return true;
712 }
713 }
714 }
715 // Also check block params of outside blocks that receive args from latch.
716 // (If the latch branches to an outside block with args, those args
717 // might carry latch values. But `check_latch` already restricts the
718 // latch terminator to `br header(%iv_next)`, so this can't happen for
719 // loops we accept. Kept as belt-and-suspenders.)
720 }
721 let _ = preds;
722 false
723 }
724
725 fn block_contains_load(func: &Function, block: BlockId) -> bool {
726 func.block(block)
727 .insts
728 .iter()
729 .any(|inst| matches!(inst.kind, InstKind::Load(_)))
730 }
731
732 /// Collect all ValueId operands in a terminator into a Vec.
733 /// (Mirrors `terminator_uses` from walk.rs but returns owned storage.)
734 fn terminator_uses_vec(term: &Terminator) -> Vec<ValueId> {
735 let mut out = Vec::new();
736 match term {
737 Terminator::Return(Some(v)) => out.push(*v),
738 Terminator::Branch(_, args) => out.extend(args),
739 Terminator::CondBranch {
740 cond,
741 true_args,
742 false_args,
743 ..
744 } => {
745 out.push(*cond);
746 out.extend(true_args);
747 out.extend(false_args);
748 }
749 Terminator::Switch { selector, .. } => out.push(*selector),
750 _ => {}
751 }
752 out
753 }
754
/// Import inst_uses from the walk module for convenience.
/// Returns every ValueId operand read by `kind`.
fn inst_uses(kind: &InstKind) -> Vec<ValueId> {
    crate::ir::walk::inst_uses(kind)
}
759
760 // ---------------------------------------------------------------------------
761 // Transformation
762 // ---------------------------------------------------------------------------
763
/// Perform the full unroll described by `shape`.
///
/// Clones each body block `trip_count` times into a fresh chain of blocks,
/// substituting the IV with a per-iteration constant (and, for reduction
/// loops, threading the accumulator across iterations). Then wires
/// preheader → iter0 → … → iter{TC-1} → exit and marks the original loop
/// blocks `Unreachable` so the caller's `prune_unreachable` deletes them.
fn do_unroll(func: &mut Function, shape: LoopShape) {
    // Collect per-body-block instructions once (before we mutate func).
    let body_snapshots: Vec<Vec<Inst>> = shape
        .body_blocks
        .iter()
        .map(|&b| func.block(b).insts.clone())
        .collect();

    // Determine the IV width. (Non-Int falls back to I32; detection has
    // already guaranteed IrType::Int, so the fallback arm is defensive.)
    let iv_width = match &shape.iv_ty {
        IrType::Int(w) => *w,
        _ => IntWidth::I32,
    };

    // For each iteration, we create one new block per body block, cloned
    // with the IV substituted by a constant. The chain is:
    //
    //   preheader → iter0_body0 → iter0_body1 → ... → iter0_bodyN-1
    //             → iter1_body0 → ...
    //             → iter{TC-1}_body0 → ...
    //             → exit
    //
    // If body_blocks is empty (cmp_block directly targets latch), the
    // iterations have no instructions and we just wire preheader → exit.

    // Accumulate the block ids for each iteration so we can wire terminators.
    let tc = shape.trip_count;
    let nb = body_snapshots.len();

    let mut iter_blocks: Vec<Vec<BlockId>> = Vec::with_capacity(tc);
    let mut iter_substs: Vec<HashMap<ValueId, ValueId>> = Vec::with_capacity(tc);

    // Reduction threading: prev_acc carries iteration k-1's latch
    // accumulator value into iteration k's body substitution. For
    // iteration 0 it's seeded from the preheader's branch arg; after
    // each iteration completes, it's updated to that iter's body
    // output (looked up through the iter's subst by the original
    // latch_acc_value's ID).
    let mut prev_acc: Option<ValueId> = shape.reduction.as_ref().map(|r| r.acc_init);

    for k in 0..tc {
        // Constant IV value for this iteration (init + k, stride 1).
        let iv_val = shape.iv_init + k as i64;

        // Build substitution map seeded with iv_param → const(iv_val)
        // and (when present) acc_param → previous iteration's latch
        // accumulator value.
        let mut subst: HashMap<ValueId, ValueId> = HashMap::new();

        // Emit the IV constant (will be placed in each new block's first inst).
        let iv_const_id = func.next_value_id();
        subst.insert(shape.iv_param, iv_const_id);

        if let (Some(r), Some(prev)) = (shape.reduction.as_ref(), prev_acc) {
            subst.insert(r.acc_param, prev);
        }

        let mut blk_ids = Vec::with_capacity(nb.max(1));
        if nb == 0 {
            // No body blocks — create a single pass-through block.
            // (iv_const_id was allocated but never defined; with no body
            // there are no uses of it, so that's harmless.)
            let bid = func.create_block(&format!("unroll_k{}", k));
            blk_ids.push(bid);
        } else {
            for (bi, snap) in body_snapshots.iter().enumerate() {
                let bid = func.create_block(&format!("unroll_k{}b{}", k, bi));
                // First block of this iter: emit IV constant.
                if bi == 0 {
                    let iv_inst = Inst {
                        id: iv_const_id,
                        kind: InstKind::ConstInt(iv_val as i128, iv_width),
                        ty: shape.iv_ty.clone(),
                        span: dummy_span(),
                    };
                    func.block_mut(bid).insts.push(iv_inst);
                }
                // Clone body instructions. Fresh result IDs are recorded
                // in subst AS we go, so later instructions in the same
                // iteration see the remapped earlier results.
                for inst in snap {
                    let new_id = func.next_value_id();
                    subst.insert(inst.id, new_id);
                    let new_kind = remap_kind(&inst.kind, &subst);
                    func.block_mut(bid).insts.push(Inst {
                        id: new_id,
                        kind: new_kind,
                        ty: inst.ty.clone(),
                        span: inst.span,
                    });
                }
                blk_ids.push(bid);
            }
        }
        // After this iteration's body has been cloned, look up the
        // post-substitution form of the latch's accumulator value —
        // that becomes the seed for the next iteration's acc_param
        // substitution. If the latch's acc value wasn't defined in
        // this body (e.g. an invariant or pre-loop value), fall back
        // to the original ID; the unroller's other safety checks
        // ensure that's still well-defined post-pruning.
        if let Some(r) = shape.reduction.as_ref() {
            prev_acc = Some(
                subst
                    .get(&r.latch_acc_value)
                    .copied()
                    .unwrap_or(r.latch_acc_value),
            );
        }
        iter_blocks.push(blk_ids);
        iter_substs.push(subst);
    }

    // Wire terminators within each iteration (intra-iter body chain)
    // and across iterations (last body block of iter k → first of iter k+1).
    let original_body_next: Vec<BlockId> = {
        // For each body block b, what block does it branch to?
        // (It's either the next body block or the latch.)
        let mut nexts = Vec::new();
        for (bi, &b) in shape.body_blocks.iter().enumerate() {
            let next = if bi + 1 < shape.body_blocks.len() {
                shape.body_blocks[bi + 1]
            } else {
                shape.latch
            };
            let _ = b;
            nexts.push(next);
        }
        nexts
    };

    for k in 0..tc {
        let nb_actual = iter_blocks[k].len();
        // Determine which block this iteration chains to after its last body block.
        let after_iter = if k + 1 < tc {
            // First block of next iteration.
            iter_blocks[k + 1][0]
        } else {
            // Last iteration → exit.
            shape.exit
        };

        // Args to pass when branching to the exit block. For
        // reduction loops, the original `exit_args` references the
        // header's acc_param — pruned along with the rest of the
        // header. Replace the acc_param entry with the final
        // iteration's accumulator value (`prev_acc`, set up after
        // the per-iter cloning loop ran iteration N-1). Other
        // entries are left alone, falling through to the verifier
        // if they reference values that would survive (today's only
        // supported shape has exactly the acc lane).
        let exit_args = if after_iter == shape.exit {
            if let (Some(r), Some(final_acc)) = (shape.reduction.as_ref(), prev_acc) {
                shape
                    .exit_args
                    .iter()
                    .map(|&v| if v == r.acc_param { final_acc } else { v })
                    .collect()
            } else {
                shape.exit_args.clone()
            }
        } else {
            vec![]
        };

        if nb == 0 {
            // No body — the single pass-through block jumps to after_iter.
            func.block_mut(iter_blocks[k][0]).terminator =
                Some(Terminator::Branch(after_iter, exit_args));
        } else {
            for bi in 0..nb_actual {
                let cur_blk = iter_blocks[k][bi];
                let (next_blk, args) = if bi + 1 < nb_actual {
                    (iter_blocks[k][bi + 1], vec![])
                } else {
                    // Last body block of iter k → after_iter.
                    (after_iter, exit_args.clone())
                };
                let _ = original_body_next[bi]; // consumed above
                func.block_mut(cur_blk).terminator = Some(Terminator::Branch(next_blk, args));
            }
        }
    }

    // Rewrite preheader terminator: skip the loop header, jump to first iter.
    let (first_block, preheader_args) = if tc > 0 {
        (iter_blocks[0][0], vec![])
    } else {
        // Zero-trip loop: jump directly to exit with its required args.
        // NOTE(review): detection rejects trip_count <= 0, so this arm
        // looks unreachable today — kept as a safeguard.
        (shape.exit, shape.exit_args.clone())
    };
    func.block_mut(shape.preheader).terminator =
        Some(Terminator::Branch(first_block, preheader_args));

    // Mark loop control blocks as Unreachable → prune_unreachable removes them.
    func.block_mut(shape.header).terminator = Some(Terminator::Unreachable);
    func.block_mut(shape.cmp_block).terminator = Some(Terminator::Unreachable);
    func.block_mut(shape.latch).terminator = Some(Terminator::Unreachable);
    for &b in &shape.body_blocks {
        func.block_mut(b).terminator = Some(Terminator::Unreachable);
    }
}
961
/// Deep-clone an InstKind, substituting all ValueId operands according to `subst`.
///
/// Operands found in `subst` are replaced with their mapped ids; operands
/// absent from the map are copied through unchanged. Non-operand payloads
/// (constant values, widths, types, field/lane indices, callee names) are
/// cloned verbatim. The unrollers use this to re-emit loop-body
/// instructions with the IV (and, for reductions, the accumulator)
/// redirected to per-iteration values.
fn remap_kind(kind: &InstKind, subst: &HashMap<ValueId, ValueId>) -> InstKind {
    // Rewrite a single operand: mapped value if present, else identity.
    let r = |v: &ValueId| *subst.get(v).unwrap_or(v);
    match kind {
        // Constants — no operands.
        InstKind::ConstInt(v, w) => InstKind::ConstInt(*v, *w),
        InstKind::ConstFloat(v, w) => InstKind::ConstFloat(*v, *w),
        InstKind::ConstBool(v) => InstKind::ConstBool(*v),
        InstKind::ConstString(v) => InstKind::ConstString(v.clone()),
        InstKind::Undef(t) => InstKind::Undef(t.clone()),
        InstKind::GlobalAddr(s) => InstKind::GlobalAddr(s.clone()),

        // Integer arithmetic.
        InstKind::IAdd(a, b) => InstKind::IAdd(r(a), r(b)),
        InstKind::ISub(a, b) => InstKind::ISub(r(a), r(b)),
        InstKind::IMul(a, b) => InstKind::IMul(r(a), r(b)),
        InstKind::IDiv(a, b) => InstKind::IDiv(r(a), r(b)),
        InstKind::IMod(a, b) => InstKind::IMod(r(a), r(b)),
        InstKind::INeg(a) => InstKind::INeg(r(a)),

        // Float arithmetic.
        InstKind::FAdd(a, b) => InstKind::FAdd(r(a), r(b)),
        InstKind::FSub(a, b) => InstKind::FSub(r(a), r(b)),
        InstKind::FMul(a, b) => InstKind::FMul(r(a), r(b)),
        InstKind::FDiv(a, b) => InstKind::FDiv(r(a), r(b)),
        InstKind::FNeg(a) => InstKind::FNeg(r(a)),
        InstKind::FAbs(a) => InstKind::FAbs(r(a)),
        InstKind::FSqrt(a) => InstKind::FSqrt(r(a)),
        InstKind::FPow(a, b) => InstKind::FPow(r(a), r(b)),

        // Comparison.
        InstKind::ICmp(op, a, b) => InstKind::ICmp(*op, r(a), r(b)),
        InstKind::FCmp(op, a, b) => InstKind::FCmp(*op, r(a), r(b)),

        // Logic.
        InstKind::And(a, b) => InstKind::And(r(a), r(b)),
        InstKind::Or(a, b) => InstKind::Or(r(a), r(b)),
        InstKind::Not(a) => InstKind::Not(r(a)),

        // Select.
        InstKind::Select(c, t, f) => InstKind::Select(r(c), r(t), r(f)),

        // Bitwise.
        InstKind::BitAnd(a, b) => InstKind::BitAnd(r(a), r(b)),
        InstKind::BitOr(a, b) => InstKind::BitOr(r(a), r(b)),
        InstKind::BitXor(a, b) => InstKind::BitXor(r(a), r(b)),
        InstKind::BitNot(a) => InstKind::BitNot(r(a)),
        InstKind::Shl(a, b) => InstKind::Shl(r(a), r(b)),
        InstKind::LShr(a, b) => InstKind::LShr(r(a), r(b)),
        InstKind::AShr(a, b) => InstKind::AShr(r(a), r(b)),
        InstKind::CountLeadingZeros(a) => InstKind::CountLeadingZeros(r(a)),
        InstKind::CountTrailingZeros(a) => InstKind::CountTrailingZeros(r(a)),
        InstKind::PopCount(a) => InstKind::PopCount(r(a)),

        // Conversions.
        InstKind::IntToFloat(a, w) => InstKind::IntToFloat(r(a), *w),
        InstKind::FloatToInt(a, w) => InstKind::FloatToInt(r(a), *w),
        InstKind::FloatExtend(a, w) => InstKind::FloatExtend(r(a), *w),
        InstKind::FloatTrunc(a, w) => InstKind::FloatTrunc(r(a), *w),
        InstKind::IntExtend(a, w, s) => InstKind::IntExtend(r(a), *w, *s),
        InstKind::IntTrunc(a, w) => InstKind::IntTrunc(r(a), *w),
        InstKind::PtrToInt(a) => InstKind::PtrToInt(r(a)),
        InstKind::IntToPtr(a, ty) => InstKind::IntToPtr(r(a), ty.clone()),

        // Memory.
        InstKind::Alloca(t) => InstKind::Alloca(t.clone()),
        InstKind::Load(a) => InstKind::Load(r(a)),
        InstKind::Store(v, p) => InstKind::Store(r(v), r(p)),
        InstKind::GetElementPtr(base, idxs) => {
            InstKind::GetElementPtr(r(base), idxs.iter().map(&r).collect())
        }

        // Calls.
        InstKind::Call(f, args) => InstKind::Call(f.clone(), args.iter().map(&r).collect()),
        InstKind::RuntimeCall(f, args) => {
            InstKind::RuntimeCall(f.clone(), args.iter().map(&r).collect())
        }

        // Aggregates.
        InstKind::ExtractField(v, idx) => InstKind::ExtractField(r(v), *idx),
        InstKind::InsertField(v, idx, fld) => InstKind::InsertField(r(v), *idx, r(fld)),

        // ---- SIMD vector ops ----
        InstKind::VAdd(a, b) => InstKind::VAdd(r(a), r(b)),
        InstKind::VSub(a, b) => InstKind::VSub(r(a), r(b)),
        InstKind::VMul(a, b) => InstKind::VMul(r(a), r(b)),
        InstKind::VDiv(a, b) => InstKind::VDiv(r(a), r(b)),
        InstKind::VNeg(a) => InstKind::VNeg(r(a)),
        InstKind::VAbs(a) => InstKind::VAbs(r(a)),
        InstKind::VSqrt(a) => InstKind::VSqrt(r(a)),
        InstKind::VFma(a, b, c) => InstKind::VFma(r(a), r(b), r(c)),
        InstKind::VSelect(m, t, f) => InstKind::VSelect(r(m), r(t), r(f)),
        InstKind::VMin(a, b) => InstKind::VMin(r(a), r(b)),
        InstKind::VMax(a, b) => InstKind::VMax(r(a), r(b)),
        InstKind::VICmp(op, a, b) => InstKind::VICmp(*op, r(a), r(b)),
        InstKind::VFCmp(op, a, b) => InstKind::VFCmp(*op, r(a), r(b)),
        InstKind::VLoad(p) => InstKind::VLoad(r(p)),
        InstKind::VStore(v, p) => InstKind::VStore(r(v), r(p)),
        InstKind::VBitcast(v, ty) => InstKind::VBitcast(r(v), ty.clone()),
        InstKind::VExtract(v, lane) => InstKind::VExtract(r(v), *lane),
        InstKind::VInsert(v, lane, s) => InstKind::VInsert(r(v), *lane, r(s)),
        InstKind::VBroadcast(s) => InstKind::VBroadcast(r(s)),
        InstKind::VReduceSum(v) => InstKind::VReduceSum(r(v)),
        InstKind::VReduceMin(v) => InstKind::VReduceMin(r(v)),
        InstKind::VReduceMax(v) => InstKind::VReduceMax(r(v)),
    }
}
1069
1070 // ---------------------------------------------------------------------------
1071 // Partial unrolling
1072 // ---------------------------------------------------------------------------
1073 //
1074 // Loops with statically known trip counts above FULL_UNROLL_MAX don't
1075 // benefit from full unrolling (code bloat), but a U-way partial unroll
1076 // (clone body U-1 additional times, step IV by U) can still expose ILP
1077 // and reduce loop overhead. v0 handles the simplest shape:
1078 //
1079 // * 2-block loop (header + latch).
1080 // * Single block param (no reduction).
1081 // * Static trip count > FULL_UNROLL_MAX, divisible by `U`.
1082 // * Body instruction count × U ≤ PARTIAL_UNROLL_BODY_BUDGET.
1083 //
1084 // The transform clones the latch's body insts U-1 more times with
1085 // `iv` substituted by a freshly-computed `iv + k` (1 ≤ k < U), then
1086 // rewrites the latch's `iadd iv, 1` to `iadd iv, U`.
1087
/// Largest unrolled-body size in instructions for partial unrolling.
/// Candidate factors are rejected when `body_inst_count * U` exceeds this.
const PARTIAL_UNROLL_BODY_BUDGET: usize = 60;
/// Largest unroll factor we'll apply.
/// The detectors try factors from this value down to 2 and take the first fit.
const PARTIAL_UNROLL_MAX_FACTOR: usize = 4;
1092
/// Shape of a loop eligible for *static* partial unrolling: a 2-block
/// (header + latch) counted loop whose init, bound, and stride are
/// compile-time constants. Produced by `detect_partial_unroll_loop`,
/// consumed by `do_partial_unroll`.
#[derive(Debug)]
struct PartialShape {
    /// Loop header block (holds the icmp + cond_br).
    header: BlockId,
    /// The single latch block (body + back-edge).
    latch: BlockId,
    /// Header block-param (params[0]) carrying the induction variable.
    iv_param: ValueId,
    /// Integer type of the IV.
    iv_ty: IrType,
    /// Constant initial IV value (from the preheader's branch args).
    iv_init: i64,
    /// Constant inclusive upper bound (trip = bound - init + 1).
    iv_bound: i64,
    /// Unroll factor.
    u: usize,
    /// The latch's `iadd iv, 1` instruction id.
    iadd_id: ValueId,
    /// The const(1) value feeding the iadd.
    step_const_id: ValueId,
    /// If the loop is a reduction (header has 2 params), this carries
    /// the acc param id and the latch's terminator's new-acc value.
    reduction: Option<PartialReductionInfo>,
}
1111
/// Accumulator bookkeeping for a reduction-shaped static partial-unroll
/// candidate (header has two block params: IV at index 0, acc at index 1).
#[derive(Debug, Clone, Copy)]
struct PartialReductionInfo {
    /// Header block-param (params[1]) carrying the accumulator.
    acc_param: ValueId,
    /// The value the latch passes back as the new accumulator
    /// (terminator's args[1]).
    new_acc: ValueId,
}
1119
/// Detect a loop eligible for *static* partial unrolling (both init
/// and bound are compile-time constants). Returns the `PartialShape`
/// consumed by `do_partial_unroll`, or `None` on any structural
/// mismatch.
///
/// Requirements checked here, in order:
/// * single back-edge; loop body is exactly {header, latch};
/// * header has 1 param (IV) or 2 params (IV + accumulator) and the
///   IV is an integer;
/// * header's compare/exit found by `detect_bound_and_exit`, with the
///   exit block outside the loop;
/// * a unique out-of-loop preheader branches to the header with a
///   constant IV init;
/// * latch ends `br header(iv_next[, new_acc])` where
///   `iv_next = iadd iv, 1`;
/// * no latch-defined value escapes the loop;
/// * trip count > FULL_UNROLL_MAX and divisible by some
///   U ∈ [2, PARTIAL_UNROLL_MAX_FACTOR] within the body-size budget;
/// * non-reduction bodies contain at most one store (regalloc-pressure
///   heuristic).
fn detect_partial_unroll_loop(
    func: &Function,
    nl: &NaturalLoop,
    preds: &HashMap<BlockId, Vec<BlockId>>,
) -> Option<PartialShape> {
    // Structure gate: one back-edge, exactly two blocks in the body.
    if nl.latches.len() != 1 || nl.body.len() != 2 {
        return None;
    }
    let header = nl.header;
    let latch = nl.latches[0];
    if header == latch {
        return None;
    }
    let hdr = func.block(header);
    // 1 param = plain counted loop; 2 params = reduction (IV + acc).
    if hdr.params.is_empty() || hdr.params.len() > 2 {
        return None;
    }
    let iv_param = hdr.params[0].id;
    let iv_ty = hdr.params[0].ty.clone();
    if !matches!(iv_ty, IrType::Int(_)) {
        return None;
    }
    let acc_param: Option<ValueId> = if hdr.params.len() == 2 {
        Some(hdr.params[1].id)
    } else {
        None
    };
    // Header has the cmp + cond_br body/exit.
    let (iv_bound, exit, _exit_args) = detect_bound_and_exit(func, header, iv_param)?;
    if nl.body.contains(&exit) {
        return None;
    }
    // Preheader supplies iv_init. It must be the unique header
    // predecessor from outside the loop body.
    let header_preds = preds.get(&header)?;
    let mut outside: Vec<BlockId> = header_preds
        .iter()
        .copied()
        .filter(|p| !nl.body.contains(p))
        .collect();
    outside.sort_by_key(|b| b.0);
    outside.dedup();
    if outside.len() != 1 {
        return None;
    }
    let preheader = outside[0];
    let ph = func.block(preheader);
    let expected_args = if acc_param.is_some() { 2 } else { 1 };
    let iv_init = match &ph.terminator {
        Some(Terminator::Branch(d, args))
            if *d == header && args.len() == expected_args =>
        {
            resolve_const_int(func, args[0])?
        }
        _ => return None,
    };
    // Latch: `... insts ...; iadd iv, 1; br header(iv_next [, new_acc])`.
    let latch_blk = func.block(latch);
    let latch_term_args = match &latch_blk.terminator {
        Some(Terminator::Branch(d, args))
            if *d == header && args.len() == expected_args =>
        {
            args.clone()
        }
        _ => return None,
    };
    let iv_next = latch_term_args[0];
    let new_acc = if acc_param.is_some() {
        Some(latch_term_args[1])
    } else {
        None
    };
    let iadd_inst = latch_blk.insts.iter().find(|i| i.id == iv_next)?;
    let (lhs, rhs) = match iadd_inst.kind {
        InstKind::IAdd(l, r) => (l, r),
        _ => return None,
    };
    // The iadd may carry the IV on either side; the other operand is
    // the step, which must resolve to the constant 1.
    let step_const_id = if lhs == iv_param {
        rhs
    } else if rhs == iv_param {
        lhs
    } else {
        return None;
    };
    if resolve_const_int(func, step_const_id) != Some(1) {
        return None;
    }
    // No latch-defined value escapes the loop.
    let body_set: HashSet<BlockId> = nl.body.iter().copied().collect();
    if has_escaping_values(func, latch, &body_set, preds) {
        return None;
    }
    // Trip count gate: > FULL_UNROLL_MAX (full unroller would have
    // taken it otherwise) and divisible by some U in [2, MAX_FACTOR].
    let trip = iv_bound - iv_init + 1;
    if trip <= FULL_UNROLL_MAX {
        return None;
    }
    let body_inst_count = latch_blk.insts.len();
    if body_inst_count == 0 {
        return None;
    }
    // Heuristic: only partial-unroll reductions OR non-reduction
    // loops with at most one store. Multi-store init-style loops
    // (e.g., `a(i) = ...; b(i) = ...; c(i) = ...;`) cloned U-way
    // create too much register pressure across the function and
    // expose latent regalloc fragility (Sprint 18 follow-up).
    let store_count = latch_blk
        .insts
        .iter()
        .filter(|i| matches!(i.kind, InstKind::Store(..)))
        .count();
    if acc_param.is_none() && store_count > 1 {
        return None;
    }
    // Pick the largest U ≤ MAX_FACTOR that divides trip and respects
    // the body-size budget.
    let mut chosen_u: Option<usize> = None;
    for u in (2..=PARTIAL_UNROLL_MAX_FACTOR).rev() {
        if trip % (u as i64) != 0 {
            continue;
        }
        if body_inst_count * u > PARTIAL_UNROLL_BODY_BUDGET {
            continue;
        }
        chosen_u = Some(u);
        break;
    }
    let u = chosen_u?;
    // Reduction info is all-or-nothing; a partial match means the acc
    // tracking above went wrong — bail rather than mis-transform.
    let reduction = match (acc_param, new_acc) {
        (Some(p), Some(v)) => Some(PartialReductionInfo {
            acc_param: p,
            new_acc: v,
        }),
        (None, None) => None,
        _ => return None,
    };
    Some(PartialShape {
        header,
        latch,
        iv_param,
        iv_ty,
        iv_init,
        iv_bound,
        u,
        iadd_id: iv_next,
        step_const_id,
        reduction,
    })
}
1269
1270 fn do_partial_unroll(func: &mut Function, shape: PartialShape) {
1271 // Snapshot of latch body BEFORE the iadd. We'll clone these
1272 // instructions U-1 more times with iv substituted.
1273 let latch_snapshot: Vec<Inst> = func.block(shape.latch).insts.clone();
1274 // Precompute body-without-iadd for cloning.
1275 let body_inst_clone: Vec<Inst> = latch_snapshot
1276 .iter()
1277 .filter(|i| i.id != shape.iadd_id)
1278 .cloned()
1279 .collect();
1280 let iv_width = match shape.iv_ty {
1281 IrType::Int(w) => w,
1282 _ => return,
1283 };
1284 let span = dummy_span();
1285 // For each k in 1..U: emit `iv_k = iadd iv, k_const`, then clone
1286 // every body inst (except iadd) with iv → iv_k. Append all of
1287 // these BEFORE the existing iadd, so the sequence is:
1288 // <orig body insts> ; uses iv
1289 // iv_1 = iadd iv, 1
1290 // <body cloned, iv → iv_1>
1291 // iv_2 = iadd iv, 2
1292 // <body cloned, iv → iv_2>
1293 // ...
1294 // iv_next = iadd iv, U
1295 // br header(iv_next)
1296 //
1297 // We construct the new instruction list off-band, then swap
1298 // it in.
1299 let mut new_insts: Vec<Inst> = latch_snapshot
1300 .iter()
1301 .filter(|i| i.id != shape.iadd_id)
1302 .cloned()
1303 .collect();
1304 // For reduction loops, track the running accumulator. After the
1305 // original (k=0) body, the running acc is the original new_acc.
1306 // Each subsequent clone substitutes acc_param → prev_acc and
1307 // produces a fresh new_acc (the cloned counterpart of orig new_acc).
1308 let mut prev_acc: Option<ValueId> = shape.reduction.map(|r| r.new_acc);
1309 for k in 1..shape.u {
1310 // Allocate the const(k) and the iv_k = iadd iv, k.
1311 let k_const_id = func.next_value_id();
1312 func.register_type(k_const_id, shape.iv_ty.clone());
1313 let iv_k_id = func.next_value_id();
1314 func.register_type(iv_k_id, shape.iv_ty.clone());
1315 new_insts.push(Inst {
1316 id: k_const_id,
1317 kind: InstKind::ConstInt(k as i128, iv_width),
1318 ty: shape.iv_ty.clone(),
1319 span,
1320 });
1321 new_insts.push(Inst {
1322 id: iv_k_id,
1323 kind: InstKind::IAdd(shape.iv_param, k_const_id),
1324 ty: shape.iv_ty.clone(),
1325 span,
1326 });
1327 // Clone the body (sans iadd) substituting iv → iv_k and
1328 // (for reductions) acc_param → prev_acc; each inst id →
1329 // fresh id.
1330 let mut subst: HashMap<ValueId, ValueId> = HashMap::new();
1331 subst.insert(shape.iv_param, iv_k_id);
1332 if let (Some(r), Some(prev)) = (shape.reduction, prev_acc) {
1333 subst.insert(r.acc_param, prev);
1334 }
1335 for orig in &body_inst_clone {
1336 let new_id = func.next_value_id();
1337 func.register_type(new_id, orig.ty.clone());
1338 subst.insert(orig.id, new_id);
1339 let new_kind = remap_kind(&orig.kind, &subst);
1340 new_insts.push(Inst {
1341 id: new_id,
1342 kind: new_kind,
1343 ty: orig.ty.clone(),
1344 span: orig.span,
1345 });
1346 }
1347 // Walk forward the accumulator: prev_acc becomes the cloned
1348 // counterpart of orig.new_acc.
1349 if let Some(r) = shape.reduction {
1350 prev_acc = subst.get(&r.new_acc).copied();
1351 }
1352 }
1353 // Finally, the new iadd: iv_next = iadd iv, U.
1354 let new_step_const = func.next_value_id();
1355 func.register_type(new_step_const, shape.iv_ty.clone());
1356 new_insts.push(Inst {
1357 id: new_step_const,
1358 kind: InstKind::ConstInt(shape.u as i128, iv_width),
1359 ty: shape.iv_ty.clone(),
1360 span,
1361 });
1362 // Replace iadd's step constant. Reuse the existing iadd id so
1363 // the latch terminator (which references it) stays valid.
1364 let new_iadd = Inst {
1365 id: shape.iadd_id,
1366 kind: InstKind::IAdd(shape.iv_param, new_step_const),
1367 ty: shape.iv_ty.clone(),
1368 span,
1369 };
1370 new_insts.push(new_iadd);
1371
1372 // Swap the latch's instruction list.
1373 func.block_mut(shape.latch).insts = new_insts;
1374
1375 // For reduction loops, retarget the latch terminator's args[1]
1376 // to the FINAL accumulator (after U-1 clones).
1377 if let (Some(_r), Some(final_acc)) = (shape.reduction, prev_acc) {
1378 let latch_blk = func.block_mut(shape.latch);
1379 if let Some(Terminator::Branch(_, args)) = &mut latch_blk.terminator {
1380 if args.len() == 2 {
1381 args[1] = final_acc;
1382 }
1383 }
1384 }
1385 }
1386
1387 // ---------------------------------------------------------------------------
1388 // Runtime-trip partial unroll
1389 // ---------------------------------------------------------------------------
1390 //
1391 // `detect_partial_unroll_loop` requires both `iv_init` and `iv_bound`
1392 // to be compile-time constants. Many real Fortran loops have runtime
1393 // trip counts (e.g., `do i = 1, n` where `n` is a subroutine arg).
1394 // For those we run a U-way unrolled main loop over a head range
1395 // `[init, init + head_count - 1]` and then a scalar remainder loop
1396 // over the leftover `[init + head_count, bound]`. The head_count is
1397 // computed at the top of the preheader as `((bound - init + 1) / U)
1398 // * U` using runtime arithmetic.
1399 //
1400 // v0 limitations:
1401 // * `iv_init` is still a compile-time constant.
1402 // * Non-reduction only (no acc threading across remainder).
1403 // * 2-block loop (header + latch).
1404 // * Body has at most one store (same regalloc-pressure heuristic
1405 // as the static partial unroller — see commit 1b94a26).
1406
/// Shape of a loop eligible for *runtime-trip* partial unrolling:
/// constant IV init, loop-invariant runtime bound. Produced by
/// `detect_partial_unroll_runtime_loop`, consumed by
/// `do_partial_unroll_runtime`.
#[derive(Debug)]
struct PartialRuntimeShape {
    /// Unique out-of-loop predecessor of the header; the head-bound
    /// runtime arithmetic is inserted here.
    preheader: BlockId,
    /// Loop header block (icmp + cond_br).
    header: BlockId,
    /// The single latch block.
    latch: BlockId,
    /// Original loop exit (cond_br false-target on the header).
    exit: BlockId,
    /// Header block-param carrying the induction variable.
    iv_param: ValueId,
    /// Integer type of the IV.
    iv_ty: IrType,
    /// Constant initial IV value (v0 still requires a constant init).
    iv_init: i64,
    /// Runtime SSA value of the loop bound (the rhs of `icmp.le iv, ?`
    /// or equivalent). Must be loop-invariant.
    iv_bound_v: ValueId,
    /// Bit width of the IV's integer type.
    iv_width: IntWidth,
    /// Header's icmp inst id.
    cond_id: ValueId,
    /// Latch's `iadd iv, 1` inst id.
    iadd_id: ValueId,
    /// Const(1) feeding the iadd.
    step_const_id: ValueId,
    /// Unroll factor.
    u: usize,
    /// Optional reduction. When `Some`, the header has 2 params
    /// (iv + acc); the latch passes (iv_next, new_acc) to header.
    /// Acc must be threaded through the unrolled body and the
    /// remainder loop.
    reduction: Option<RuntimeReduction>,
}
1434
/// Accumulator bookkeeping for a runtime-trip reduction loop.
/// Unlike the static `PartialReductionInfo`, this also records the
/// acc's position (mem2reg's param ordering is not stable) and how
/// the original exit consumed the acc.
#[derive(Debug, Clone)]
struct RuntimeReduction {
    /// Header's accumulator block-param value id.
    acc_param: ValueId,
    /// Type of the accumulator.
    acc_ty: IrType,
    /// Initial accumulator value (preheader branch args[acc_idx]).
    acc_init: ValueId,
    /// New accumulator value computed in the latch (latch terminator
    /// args[acc_idx]) — typically a fadd / iadd / select chain.
    new_acc: ValueId,
    /// Position of the accumulator in header.params and the latch /
    /// preheader terminator args. mem2reg's ordering of phi nodes is
    /// not stable across loops — sometimes acc is at idx 0, iv at
    /// idx 1.
    acc_idx: usize,
    /// True when the original header's cond_br false_args carries
    /// the acc as a block-arg to exit. False when the exit references
    /// acc_param via dominance only (no block-arg passing). Affects
    /// how the apply path forwards the final acc to the original
    /// exit after rewiring the CFG through the remainder loop.
    exit_takes_acc: bool,
}
1458
1459 fn detect_partial_unroll_runtime_loop(
1460 func: &Function,
1461 nl: &NaturalLoop,
1462 preds: &HashMap<BlockId, Vec<BlockId>>,
1463 ) -> Option<PartialRuntimeShape> {
1464 if nl.latches.len() != 1 || nl.body.len() != 2 {
1465 return None;
1466 }
1467 let header = nl.header;
1468 let latch = nl.latches[0];
1469 if header == latch {
1470 return None;
1471 }
1472 let hdr = func.block(header);
1473 // Header may be iv-only (no reduction) or iv + acc (reduction).
1474 if hdr.params.len() != 1 && hdr.params.len() != 2 {
1475 return None;
1476 }
1477 // Find the IV by inspecting the header's icmp — the param that
1478 // appears as the LHS of `icmp.le|.lt _, bound` is the IV. The
1479 // other param (when 2 params) is the accumulator.
1480 let icmp_lhs: Option<ValueId> = hdr.terminator.as_ref().and_then(|t| {
1481 if let Terminator::CondBranch { cond, .. } = t {
1482 hdr.insts.iter().find(|i| i.id == *cond).and_then(|c| match c.kind {
1483 InstKind::ICmp(CmpOp::Le, lhs, _) | InstKind::ICmp(CmpOp::Lt, lhs, _) => {
1484 Some(lhs)
1485 }
1486 _ => None,
1487 })
1488 } else {
1489 None
1490 }
1491 });
1492 let icmp_lhs = icmp_lhs?;
1493 let iv_idx = hdr.params.iter().position(|p| p.id == icmp_lhs)?;
1494 let iv_param = hdr.params[iv_idx].id;
1495 let iv_ty = hdr.params[iv_idx].ty.clone();
1496 let iv_width = match iv_ty {
1497 IrType::Int(w) => w,
1498 _ => return None,
1499 };
1500 let acc_param_info: Option<(ValueId, IrType)> = if hdr.params.len() == 2 {
1501 let other_idx = 1 - iv_idx;
1502 let p = &hdr.params[other_idx];
1503 Some((p.id, p.ty.clone()))
1504 } else {
1505 None
1506 };
1507 let acc_idx: Option<usize> = if hdr.params.len() == 2 { Some(1 - iv_idx) } else { None };
1508 // Header's icmp pattern. Accept icmp.le or icmp.lt with the IV on
1509 // the LHS; capture the rhs ValueId regardless of whether it's a
1510 // const.
1511 let (cond_id, true_dest, false_dest, rhs_v) = {
1512 let term = hdr.terminator.as_ref()?;
1513 let (cond, td, fd) = match term {
1514 Terminator::CondBranch {
1515 cond,
1516 true_dest,
1517 true_args,
1518 false_dest,
1519 false_args,
1520 } if true_args.is_empty() && false_args.is_empty() => (*cond, *true_dest, *false_dest),
1521 _ => return None,
1522 };
1523 let cmp = hdr.insts.iter().find(|i| i.id == cond)?;
1524 let rhs = match cmp.kind {
1525 InstKind::ICmp(CmpOp::Le, lhs, rhs) if lhs == iv_param => rhs,
1526 InstKind::ICmp(CmpOp::Lt, lhs, rhs) if lhs == iv_param => rhs,
1527 _ => return None,
1528 };
1529 (cond, td, fd, rhs)
1530 };
1531 if !nl.body.contains(&true_dest) || nl.body.contains(&false_dest) {
1532 return None;
1533 }
1534 let exit = false_dest;
1535 // The exit's args from header: must be empty for the iv-only
1536 // form. For the reduction form, accept either [] (acc_param is
1537 // referenced by the exit via dominance — common Fortran lowering)
1538 // OR [acc_param] (canonical mem2reg shape).
1539 let (exit_args_ok, exit_takes_acc): (bool, bool) = hdr
1540 .terminator
1541 .as_ref()
1542 .map(|t| {
1543 if let Terminator::CondBranch { false_args, .. } = t {
1544 if let Some((acc_p, _)) = &acc_param_info {
1545 if false_args.is_empty() {
1546 (true, false)
1547 } else if false_args.len() == 1 && false_args[0] == *acc_p {
1548 (true, true)
1549 } else {
1550 (false, false)
1551 }
1552 } else {
1553 (false_args.is_empty(), false)
1554 }
1555 } else {
1556 (false, false)
1557 }
1558 })
1559 .unwrap_or((false, false));
1560 if !exit_args_ok {
1561 return None;
1562 }
1563 // Bound must NOT be a compile-time constant — that case is handled
1564 // by the static detector. We require the bound to be loop-invariant
1565 // (defined outside the loop body).
1566 if resolve_const_int(func, rhs_v).is_some() {
1567 return None;
1568 }
1569 let body_set: HashSet<BlockId> = nl.body.iter().copied().collect();
1570 let bound_in_body = body_set.iter().any(|b| {
1571 let blk = func.block(*b);
1572 blk.params.iter().any(|p| p.id == rhs_v) || blk.insts.iter().any(|i| i.id == rhs_v)
1573 });
1574 if bound_in_body {
1575 return None;
1576 }
1577 // iv_init from preheader (still requires constant init for v0).
1578 let preheader = {
1579 let header_preds = preds.get(&header)?;
1580 let mut outside: Vec<BlockId> = header_preds
1581 .iter()
1582 .copied()
1583 .filter(|p| !nl.body.contains(p))
1584 .collect();
1585 outside.sort_by_key(|b| b.0);
1586 outside.dedup();
1587 if outside.len() != 1 {
1588 return None;
1589 }
1590 outside[0]
1591 };
1592 let ph = func.block(preheader);
1593 let want_arity = if acc_param_info.is_some() { 2 } else { 1 };
1594 let (iv_init, acc_init_v) = match &ph.terminator {
1595 Some(Terminator::Branch(d, args))
1596 if *d == header && args.len() == want_arity =>
1597 {
1598 let init = resolve_const_int(func, args[iv_idx])?;
1599 let acc_init = acc_idx.map(|i| args[i]);
1600 (init, acc_init)
1601 }
1602 _ => return None,
1603 };
1604 // Latch must end with `br header(iv_next)` (iv-only) or
1605 // `br header(iv_next, new_acc)` (reduction) — args ordering
1606 // matches header.params ordering.
1607 let latch_blk = func.block(latch);
1608 let latch_term_args = match &latch_blk.terminator {
1609 Some(Terminator::Branch(d, args))
1610 if *d == header && args.len() == want_arity =>
1611 {
1612 args.clone()
1613 }
1614 _ => return None,
1615 };
1616 let iv_next = latch_term_args[iv_idx];
1617 let new_acc_v = acc_idx.map(|i| latch_term_args[i]);
1618 let iadd_inst = latch_blk.insts.iter().find(|i| i.id == iv_next)?;
1619 let (lhs, rhs) = match iadd_inst.kind {
1620 InstKind::IAdd(l, r) => (l, r),
1621 _ => return None,
1622 };
1623 let step_const_id = if lhs == iv_param {
1624 rhs
1625 } else if rhs == iv_param {
1626 lhs
1627 } else {
1628 return None;
1629 };
1630 if resolve_const_int(func, step_const_id) != Some(1) {
1631 return None;
1632 }
1633 if has_escaping_values(func, latch, &body_set, preds) {
1634 return None;
1635 }
1636 let body_inst_count = latch_blk.insts.len();
1637 if body_inst_count == 0 {
1638 return None;
1639 }
1640 // Multi-store regalloc-pressure heuristic (mirrors the static path).
1641 let store_count = latch_blk
1642 .insts
1643 .iter()
1644 .filter(|i| matches!(i.kind, InstKind::Store(..)))
1645 .count();
1646 if store_count > 1 {
1647 return None;
1648 }
1649 // Pick the largest U ≤ MAX_FACTOR within body budget. Trip not
1650 // required to be divisible — we'll generate a remainder loop.
1651 let mut chosen_u: Option<usize> = None;
1652 for u in (2..=PARTIAL_UNROLL_MAX_FACTOR).rev() {
1653 if body_inst_count * u > PARTIAL_UNROLL_BODY_BUDGET {
1654 continue;
1655 }
1656 chosen_u = Some(u);
1657 break;
1658 }
1659 let u = chosen_u?;
1660 let reduction = match (acc_param_info, acc_init_v, new_acc_v, acc_idx) {
1661 (Some((acc_param, acc_ty)), Some(acc_init), Some(new_acc), Some(idx)) => {
1662 Some(RuntimeReduction {
1663 acc_param,
1664 acc_ty,
1665 acc_init,
1666 new_acc,
1667 acc_idx: idx,
1668 exit_takes_acc,
1669 })
1670 }
1671 _ => None,
1672 };
1673 Some(PartialRuntimeShape {
1674 preheader,
1675 header,
1676 latch,
1677 exit,
1678 iv_param,
1679 iv_ty,
1680 iv_init,
1681 iv_bound_v: rhs_v,
1682 iv_width,
1683 cond_id,
1684 iadd_id: iv_next,
1685 step_const_id,
1686 u,
1687 reduction,
1688 })
1689 }
1690
1691 fn do_partial_unroll_runtime(func: &mut Function, shape: PartialRuntimeShape) {
1692 let span = dummy_span();
1693 let int_ty = IrType::Int(shape.iv_width);
1694 let bool_ty = IrType::Bool;
1695
1696 // ---- Snapshot the original body BEFORE we modify the latch ----
1697 let body_snapshot: Vec<Inst> = func.block(shape.latch).insts.clone();
1698 let body_no_iadd: Vec<Inst> = body_snapshot
1699 .iter()
1700 .filter(|i| i.id != shape.iadd_id)
1701 .cloned()
1702 .collect();
1703
1704 // ---- 1. Emit head_bound runtime arithmetic in the preheader ----
1705 //
1706 // trip = (bound - init) + 1
1707 // head_count = trip - (trip mod U)
1708 // head_bound = init + (head_count - 1)
1709 let init_const_id = func.next_value_id();
1710 let one_id = func.next_value_id();
1711 let u_const_id = func.next_value_id();
1712 let trip_minus_init_id = func.next_value_id();
1713 let trip_id = func.next_value_id();
1714 let mod_id = func.next_value_id();
1715 let head_count_id = func.next_value_id();
1716 let head_count_minus_one_id = func.next_value_id();
1717 let head_bound_id = func.next_value_id();
1718 for v in [
1719 init_const_id,
1720 one_id,
1721 u_const_id,
1722 trip_minus_init_id,
1723 trip_id,
1724 mod_id,
1725 head_count_id,
1726 head_count_minus_one_id,
1727 head_bound_id,
1728 ] {
1729 func.register_type(v, int_ty.clone());
1730 }
1731 let preheader_new_insts = vec![
1732 Inst {
1733 id: init_const_id,
1734 kind: InstKind::ConstInt(shape.iv_init as i128, shape.iv_width),
1735 ty: int_ty.clone(),
1736 span,
1737 },
1738 Inst {
1739 id: one_id,
1740 kind: InstKind::ConstInt(1, shape.iv_width),
1741 ty: int_ty.clone(),
1742 span,
1743 },
1744 Inst {
1745 id: u_const_id,
1746 kind: InstKind::ConstInt(shape.u as i128, shape.iv_width),
1747 ty: int_ty.clone(),
1748 span,
1749 },
1750 Inst {
1751 id: trip_minus_init_id,
1752 kind: InstKind::ISub(shape.iv_bound_v, init_const_id),
1753 ty: int_ty.clone(),
1754 span,
1755 },
1756 Inst {
1757 id: trip_id,
1758 kind: InstKind::IAdd(trip_minus_init_id, one_id),
1759 ty: int_ty.clone(),
1760 span,
1761 },
1762 Inst {
1763 id: mod_id,
1764 kind: InstKind::IMod(trip_id, u_const_id),
1765 ty: int_ty.clone(),
1766 span,
1767 },
1768 Inst {
1769 id: head_count_id,
1770 kind: InstKind::ISub(trip_id, mod_id),
1771 ty: int_ty.clone(),
1772 span,
1773 },
1774 Inst {
1775 id: head_count_minus_one_id,
1776 kind: InstKind::ISub(head_count_id, one_id),
1777 ty: int_ty.clone(),
1778 span,
1779 },
1780 Inst {
1781 id: head_bound_id,
1782 kind: InstKind::IAdd(init_const_id, head_count_minus_one_id),
1783 ty: int_ty.clone(),
1784 span,
1785 },
1786 ];
1787 {
1788 let pre = func.block_mut(shape.preheader);
1789 // Insert immediately before the terminator (after any existing insts).
1790 let pos = pre.insts.len();
1791 for (i, inst) in preheader_new_insts.into_iter().enumerate() {
1792 pre.insts.insert(pos + i, inst);
1793 }
1794 }
1795
1796 // ---- 2. Rewire header's icmp RHS to head_bound ----
1797 {
1798 let hdr = func.block_mut(shape.header);
1799 if let Some(inst) = hdr.insts.iter_mut().find(|i| i.id == shape.cond_id) {
1800 if let InstKind::ICmp(_, _, ref mut rhs) = inst.kind {
1801 *rhs = head_bound_id;
1802 }
1803 }
1804 }
1805
1806 // ---- 3. Allocate remainder loop blocks ----
1807 let header_remain = func.create_block("partial_remain_header");
1808 let latch_remain = func.create_block("partial_remain_latch");
1809 let remain_iv_id = func.next_value_id();
1810 func.register_type(remain_iv_id, int_ty.clone());
1811 func.block_mut(header_remain).params.push(BlockParam {
1812 id: remain_iv_id,
1813 ty: int_ty.clone(),
1814 });
1815 // Reduction: header_remain also takes an acc param.
1816 let remain_acc_id: Option<ValueId> = if let Some(r) = &shape.reduction {
1817 let id = func.next_value_id();
1818 func.register_type(id, r.acc_ty.clone());
1819 func.block_mut(header_remain).params.push(BlockParam {
1820 id,
1821 ty: r.acc_ty.clone(),
1822 });
1823 Some(id)
1824 } else {
1825 None
1826 };
1827
1828 // ---- 4. Rewire header's cond_br false-target from exit to header_remain ----
1829 // When reducing, also forward the acc value into header_remain.
1830 // header_remain's params are [iv, acc] in that order.
1831 let original_exit = shape.exit;
1832 {
1833 let hdr = func.block_mut(shape.header);
1834 if let Some(Terminator::CondBranch {
1835 false_dest,
1836 false_args,
1837 ..
1838 }) = hdr.terminator.as_mut()
1839 {
1840 *false_dest = header_remain;
1841 *false_args = if let Some(r) = &shape.reduction {
1842 vec![shape.iv_param, r.acc_param]
1843 } else {
1844 vec![shape.iv_param]
1845 };
1846 }
1847 }
1848
1849 // ---- 5. Build the U-way unrolled main-loop body in the latch ----
1850 // When reducing, thread the acc through clones (each clone's
1851 // acc_param substitution is the previous clone's new_acc).
1852 let mut new_main_insts: Vec<Inst> = body_no_iadd.clone();
1853 let mut prev_acc: Option<ValueId> = shape.reduction.as_ref().map(|r| r.new_acc);
1854 for k in 1..shape.u {
1855 let k_const_id = func.next_value_id();
1856 func.register_type(k_const_id, int_ty.clone());
1857 let iv_k_id = func.next_value_id();
1858 func.register_type(iv_k_id, int_ty.clone());
1859 new_main_insts.push(Inst {
1860 id: k_const_id,
1861 kind: InstKind::ConstInt(k as i128, shape.iv_width),
1862 ty: int_ty.clone(),
1863 span,
1864 });
1865 new_main_insts.push(Inst {
1866 id: iv_k_id,
1867 kind: InstKind::IAdd(shape.iv_param, k_const_id),
1868 ty: int_ty.clone(),
1869 span,
1870 });
1871 let mut subst: HashMap<ValueId, ValueId> = HashMap::new();
1872 subst.insert(shape.iv_param, iv_k_id);
1873 // Reduction: substitute acc_param with the running prev_acc so
1874 // this clone consumes the previous iteration's accumulator.
1875 if let (Some(r), Some(prev)) = (shape.reduction.as_ref(), prev_acc) {
1876 subst.insert(r.acc_param, prev);
1877 }
1878 let mut clone_new_acc: Option<ValueId> = None;
1879 for orig in &body_no_iadd {
1880 let new_id = func.next_value_id();
1881 func.register_type(new_id, orig.ty.clone());
1882 subst.insert(orig.id, new_id);
1883 let new_kind = remap_kind(&orig.kind, &subst);
1884 new_main_insts.push(Inst {
1885 id: new_id,
1886 kind: new_kind,
1887 ty: orig.ty.clone(),
1888 span: orig.span,
1889 });
1890 // Track this clone's new_acc id for the next iteration's
1891 // substitution.
1892 if let Some(r) = shape.reduction.as_ref() {
1893 if orig.id == r.new_acc {
1894 clone_new_acc = Some(new_id);
1895 }
1896 }
1897 }
1898 if shape.reduction.is_some() {
1899 prev_acc = clone_new_acc;
1900 }
1901 }
1902 // New iadd: iv_next = iadd iv, U. Reuse the original iadd id.
1903 let new_step_const = func.next_value_id();
1904 func.register_type(new_step_const, int_ty.clone());
1905 new_main_insts.push(Inst {
1906 id: new_step_const,
1907 kind: InstKind::ConstInt(shape.u as i128, shape.iv_width),
1908 ty: int_ty.clone(),
1909 span,
1910 });
1911 new_main_insts.push(Inst {
1912 id: shape.iadd_id,
1913 kind: InstKind::IAdd(shape.iv_param, new_step_const),
1914 ty: int_ty.clone(),
1915 span,
1916 });
1917 func.block_mut(shape.latch).insts = new_main_insts;
1918 // When reducing, rewrite the latch terminator's args[acc_idx] to
1919 // the FINAL clone's new_acc (so each main-loop iteration commits
1920 // the accumulated U-way result instead of the original first-
1921 // iteration new_acc).
1922 if let (Some(r), Some(final_acc)) = (shape.reduction.as_ref(), prev_acc) {
1923 let latch_blk = func.block_mut(shape.latch);
1924 if let Some(Terminator::Branch(_, args)) = latch_blk.terminator.as_mut() {
1925 if args.len() == 2 {
1926 args[r.acc_idx] = final_acc;
1927 }
1928 }
1929 }
1930
1931 // ---- 6. Build header_remain ----
1932 let remain_cmp_id = func.next_value_id();
1933 func.register_type(remain_cmp_id, bool_ty.clone());
1934 func.block_mut(header_remain).insts.push(Inst {
1935 id: remain_cmp_id,
1936 kind: InstKind::ICmp(CmpOp::Le, remain_iv_id, shape.iv_bound_v),
1937 ty: bool_ty.clone(),
1938 span,
1939 });
1940 // When the original exit takes acc as a false_arg, header_remain
1941 // must forward the final acc that way. Otherwise (acc-via-
1942 // dominance lowering), the exit takes no args and the final acc
1943 // is forwarded by rewriting acc_param uses in step 8 below.
1944 let exit_args_remain: Vec<ValueId> = match (
1945 remain_acc_id,
1946 shape.reduction.as_ref().map(|r| r.exit_takes_acc),
1947 ) {
1948 (Some(acc), Some(true)) => vec![acc],
1949 _ => vec![],
1950 };
1951 func.block_mut(header_remain).terminator = Some(Terminator::CondBranch {
1952 cond: remain_cmp_id,
1953 true_dest: latch_remain,
1954 true_args: vec![],
1955 false_dest: original_exit,
1956 false_args: exit_args_remain,
1957 });
1958
1959 // ---- 7. Build latch_remain (scalar 1-iter copy of original body) ----
1960 // Reduction: substitute acc_param → remain_acc_id so the
1961 // body computes new_acc on the running accumulator.
1962 let mut subst: HashMap<ValueId, ValueId> = HashMap::new();
1963 subst.insert(shape.iv_param, remain_iv_id);
1964 if let (Some(r), Some(remain_acc)) = (shape.reduction.as_ref(), remain_acc_id) {
1965 subst.insert(r.acc_param, remain_acc);
1966 }
1967 let mut remain_insts: Vec<Inst> = Vec::new();
1968 let mut remain_new_acc: Option<ValueId> = None;
1969 for orig in &body_no_iadd {
1970 let new_id = func.next_value_id();
1971 func.register_type(new_id, orig.ty.clone());
1972 subst.insert(orig.id, new_id);
1973 let new_kind = remap_kind(&orig.kind, &subst);
1974 remain_insts.push(Inst {
1975 id: new_id,
1976 kind: new_kind,
1977 ty: orig.ty.clone(),
1978 span: orig.span,
1979 });
1980 if let Some(r) = shape.reduction.as_ref() {
1981 if orig.id == r.new_acc {
1982 remain_new_acc = Some(new_id);
1983 }
1984 }
1985 }
1986 let remain_step_const_id = func.next_value_id();
1987 func.register_type(remain_step_const_id, int_ty.clone());
1988 let remain_iv_next_id = func.next_value_id();
1989 func.register_type(remain_iv_next_id, int_ty.clone());
1990 remain_insts.push(Inst {
1991 id: remain_step_const_id,
1992 kind: InstKind::ConstInt(1, shape.iv_width),
1993 ty: int_ty.clone(),
1994 span,
1995 });
1996 remain_insts.push(Inst {
1997 id: remain_iv_next_id,
1998 kind: InstKind::IAdd(remain_iv_id, remain_step_const_id),
1999 ty: int_ty.clone(),
2000 span,
2001 });
2002 func.block_mut(latch_remain).insts = remain_insts;
2003 let latch_remain_args: Vec<ValueId> = if let Some(acc) = remain_new_acc {
2004 vec![remain_iv_next_id, acc]
2005 } else {
2006 vec![remain_iv_next_id]
2007 };
2008 func.block_mut(latch_remain).terminator =
2009 Some(Terminator::Branch(header_remain, latch_remain_args));
2010
2011 // ---- 8. When the original exit references acc_param via dominance
2012 // (rather than via false_args), the head loop's acc value
2013 // that flows out of `header` is now the U-way unrolled head
2014 // result — but the remainder loop continues the
2015 // accumulation. Rewrite uses of acc_param outside the loop
2016 // body to use `remain_acc_id` (the final acc from
2017 // header_remain). header_remain dominates everything that
2018 // used to be dominated by header (along the false branch).
2019 if let (Some(r), Some(remain_acc)) = (shape.reduction.as_ref(), remain_acc_id) {
2020 if !r.exit_takes_acc {
2021 let body_set: HashSet<BlockId> =
2022 [shape.header, shape.latch].iter().copied().collect();
2023 for block in func.blocks.iter_mut() {
2024 if body_set.contains(&block.id) || block.id == shape.preheader {
2025 continue;
2026 }
2027 if block.id == header_remain || block.id == latch_remain {
2028 continue;
2029 }
2030 for inst in block.insts.iter_mut() {
2031 inst.kind = remap_kind_single(&inst.kind, r.acc_param, remain_acc);
2032 }
2033 if let Some(term) = block.terminator.as_mut() {
2034 remap_term_single(term, r.acc_param, remain_acc);
2035 }
2036 }
2037 }
2038 }
2039 }
2040
2041 /// Replace every occurrence of `from` in an `InstKind`'s value
2042 /// operands with `to`. Forwards through `remap_kind` by building a
2043 /// 1-entry HashMap.
2044 fn remap_kind_single(kind: &InstKind, from: ValueId, to: ValueId) -> InstKind {
2045 let mut subst: HashMap<ValueId, ValueId> = HashMap::new();
2046 subst.insert(from, to);
2047 remap_kind(kind, &subst)
2048 }
2049
2050 /// Like `remap_kind_single` but for a `Terminator`. Walks both
2051 /// args lists for `Branch` / `CondBranch`, plus the `cond` operand
2052 /// for `CondBranch` and the value operand for `Return`.
2053 fn remap_term_single(term: &mut Terminator, from: ValueId, to: ValueId) {
2054 let sub = |v: &mut ValueId| {
2055 if *v == from {
2056 *v = to;
2057 }
2058 };
2059 match term {
2060 Terminator::Branch(_, args) => {
2061 for v in args.iter_mut() {
2062 sub(v);
2063 }
2064 }
2065 Terminator::CondBranch {
2066 cond,
2067 true_args,
2068 false_args,
2069 ..
2070 } => {
2071 sub(cond);
2072 for v in true_args.iter_mut() {
2073 sub(v);
2074 }
2075 for v in false_args.iter_mut() {
2076 sub(v);
2077 }
2078 }
2079 Terminator::Return(Some(v)) => sub(v),
2080 Terminator::Return(None) | Terminator::Unreachable => {}
2081 Terminator::Switch { selector, .. } => sub(selector),
2082 }
2083 }
2084
2085 // ---------------------------------------------------------------------------
2086 // Multi-block partial unroll
2087 // ---------------------------------------------------------------------------
2088 //
2089 // `detect_partial_unroll_loop` rejects loops whose natural body has
2090 // more than 2 blocks (i.e., header + extra blocks + latch). At -O3
2091 // such shapes are rare — SCCP, jump-threading, and simplify-cfg
2092 // merge most linear chains — but they do arise in less-aggressively-
2093 // optimized IR (debug builds, certain control-flow patterns the
2094 // optimizer can't simplify, or third-party IR consumers). For the
2095 // linear-chain case we can extend the partial-unroll transform by
2096 // allocating fresh BlockIds for each cloned body iteration and
2097 // rewiring branches; the latch's iadd still bumps iv by U.
2098 //
2099 // v0 limitations:
2100 // * 3-block loop only (header + body + latch). Generalizing to
2101 // longer chains is mechanical (clone N-1 blocks per copy) but
2102 // not yet wired up.
2103 // * Static trip count (constant header bound).
2104 // * Non-reduction (iv-only header).
// * At most 1 store across the loop body (body block + latch
//   combined) — regalloc-pressure heuristic, mirrors the 2-block path.
2107
/// Shape of a 3-block (header + body + latch) counted loop accepted by
/// `detect_partial_unroll_multiblock_loop` and consumed by
/// `do_partial_unroll_multiblock`.
#[derive(Debug)]
struct PartialMultiBlockShape {
    /// Loop header: holds the single IV block param and the exiting
    /// cond_br.
    header: BlockId,
    /// The body block (between header and latch). Has unconditional
    /// br to latch.
    body_block: BlockId,
    /// Back-edge block: runs `iadd iv, 1` then branches to header.
    latch: BlockId,
    /// The header's single block parameter (the induction variable).
    iv_param: ValueId,
    /// IV's IR type (an `IrType::Int(iv_width)`).
    iv_ty: IrType,
    /// Constant initial IV value supplied by the preheader's branch arg.
    iv_init: i64,
    /// Inclusive upper bound (an `icmp.lt` bound is normalized to hi-1).
    iv_bound: i64,
    /// Integer width of the IV.
    iv_width: IntWidth,
    /// Latch's `iadd iv, 1` inst id.
    iadd_id: ValueId,
    /// Const(1) feeding the iadd.
    step_const_id: ValueId,
    /// Chosen unroll factor: 2..=PARTIAL_UNROLL_MAX_FACTOR, evenly
    /// divides the trip count.
    u: usize,
}
2126
2127 fn detect_partial_unroll_multiblock_loop(
2128 func: &Function,
2129 nl: &NaturalLoop,
2130 preds: &HashMap<BlockId, Vec<BlockId>>,
2131 ) -> Option<PartialMultiBlockShape> {
2132 if nl.latches.len() != 1 || nl.body.len() != 3 {
2133 return None;
2134 }
2135 let header = nl.header;
2136 let latch = nl.latches[0];
2137 if header == latch {
2138 return None;
2139 }
2140 let hdr = func.block(header);
2141 if hdr.params.len() != 1 {
2142 return None;
2143 }
2144 let iv_param = hdr.params[0].id;
2145 let iv_ty = hdr.params[0].ty.clone();
2146 let iv_width = match iv_ty {
2147 IrType::Int(w) => w,
2148 _ => return None,
2149 };
2150 // Header's cmp + cond_br true→body_block, false→exit.
2151 let (cond_id, body_block, exit_dest) = match &hdr.terminator {
2152 Some(Terminator::CondBranch {
2153 cond,
2154 true_dest,
2155 true_args,
2156 false_dest,
2157 false_args,
2158 }) if true_args.is_empty() && false_args.is_empty() => (*cond, *true_dest, *false_dest),
2159 _ => return None,
2160 };
2161 if !nl.body.contains(&body_block) || nl.body.contains(&exit_dest) || body_block == latch {
2162 return None;
2163 }
2164 // body_block must have no params and unconditional br latch.
2165 let body_blk = func.block(body_block);
2166 if !body_blk.params.is_empty() {
2167 return None;
2168 }
2169 match &body_blk.terminator {
2170 Some(Terminator::Branch(d, args)) if *d == latch && args.is_empty() => {}
2171 _ => return None,
2172 }
2173 // Header's cmp must be `icmp.le iv, const(hi)` or `icmp.lt`.
2174 let cmp_inst = hdr.insts.iter().find(|i| i.id == cond_id)?;
2175 let iv_bound = match cmp_inst.kind {
2176 InstKind::ICmp(CmpOp::Le, lhs, rhs) if lhs == iv_param => resolve_const_int(func, rhs)?,
2177 InstKind::ICmp(CmpOp::Lt, lhs, rhs) if lhs == iv_param => {
2178 resolve_const_int(func, rhs)?.checked_sub(1)?
2179 }
2180 _ => return None,
2181 };
2182 // Preheader supplies iv_init.
2183 let preheader = {
2184 let header_preds = preds.get(&header)?;
2185 let mut outside: Vec<BlockId> = header_preds
2186 .iter()
2187 .copied()
2188 .filter(|p| !nl.body.contains(p))
2189 .collect();
2190 outside.sort_by_key(|b| b.0);
2191 outside.dedup();
2192 if outside.len() != 1 {
2193 return None;
2194 }
2195 outside[0]
2196 };
2197 let ph = func.block(preheader);
2198 let iv_init = match &ph.terminator {
2199 Some(Terminator::Branch(d, args)) if *d == header && args.len() == 1 => {
2200 resolve_const_int(func, args[0])?
2201 }
2202 _ => return None,
2203 };
2204 // Latch's terminator: iadd iv, 1; br header(iv_next).
2205 let latch_blk = func.block(latch);
2206 let latch_term_args = match &latch_blk.terminator {
2207 Some(Terminator::Branch(d, args)) if *d == header && args.len() == 1 => args.clone(),
2208 _ => return None,
2209 };
2210 let iv_next = latch_term_args[0];
2211 let iadd_inst = latch_blk.insts.iter().find(|i| i.id == iv_next)?;
2212 let (lhs, rhs) = match iadd_inst.kind {
2213 InstKind::IAdd(l, r) => (l, r),
2214 _ => return None,
2215 };
2216 let step_const_id = if lhs == iv_param {
2217 rhs
2218 } else if rhs == iv_param {
2219 lhs
2220 } else {
2221 return None;
2222 };
2223 if resolve_const_int(func, step_const_id) != Some(1) {
2224 return None;
2225 }
2226 let body_set: HashSet<BlockId> = nl.body.iter().copied().collect();
2227 if has_escaping_values(func, latch, &body_set, preds) {
2228 return None;
2229 }
2230 if has_escaping_values(func, body_block, &body_set, preds) {
2231 return None;
2232 }
2233 // Trip / size gates. Use total inst count across body+latch.
2234 let trip = iv_bound - iv_init + 1;
2235 if trip <= FULL_UNROLL_MAX {
2236 return None;
2237 }
2238 let body_count = func.block(body_block).insts.len();
2239 let latch_count = latch_blk.insts.len();
2240 let total_inst = body_count + latch_count;
2241 if total_inst == 0 {
2242 return None;
2243 }
2244 // Multi-store gate (across both blocks).
2245 let store_count = func
2246 .block(body_block)
2247 .insts
2248 .iter()
2249 .chain(latch_blk.insts.iter())
2250 .filter(|i| matches!(i.kind, InstKind::Store(..)))
2251 .count();
2252 if store_count > 1 {
2253 return None;
2254 }
2255 let mut chosen_u: Option<usize> = None;
2256 for u in (2..=PARTIAL_UNROLL_MAX_FACTOR).rev() {
2257 if trip % (u as i64) != 0 {
2258 continue;
2259 }
2260 if total_inst * u > PARTIAL_UNROLL_BODY_BUDGET {
2261 continue;
2262 }
2263 chosen_u = Some(u);
2264 break;
2265 }
2266 let u = chosen_u?;
2267 Some(PartialMultiBlockShape {
2268 header,
2269 body_block,
2270 latch,
2271 iv_param,
2272 iv_ty,
2273 iv_init,
2274 iv_bound,
2275 iv_width,
2276 iadd_id: iv_next,
2277 step_const_id,
2278 u,
2279 })
2280 }
2281
2282 fn do_partial_unroll_multiblock(func: &mut Function, shape: PartialMultiBlockShape) {
2283 let span = dummy_span();
2284 let int_ty = IrType::Int(shape.iv_width);
2285
2286 // ---- Snapshot body and latch (sans iadd) BEFORE we modify them ----
2287 let body_snapshot: Vec<Inst> = func.block(shape.body_block).insts.clone();
2288 let latch_snapshot_no_iadd: Vec<Inst> = func
2289 .block(shape.latch)
2290 .insts
2291 .iter()
2292 .filter(|i| i.id != shape.iadd_id)
2293 .cloned()
2294 .collect();
2295
2296 // For each clone k=1..U, allocate two fresh BlockIds (body clone +
2297 // latch-work clone). They form a chain wired between the original
2298 // latch's body work and the iadd. The original chain runs first;
2299 // each clone follows.
2300 let mut clone_blocks: Vec<(BlockId, BlockId)> = Vec::new();
2301 for k in 1..shape.u {
2302 let body_c = func.create_block(&format!("partial_mb_body_c{}", k));
2303 let latch_c = func.create_block(&format!("partial_mb_latch_c{}", k));
2304 clone_blocks.push((body_c, latch_c));
2305 // Pre-fill: each clone needs (k_const, iv_k, body insts subst, br latch_c)
2306 // (body insts subst, br to next clone or to original iadd block)
2307 let _ = k;
2308 let _ = body_c;
2309 let _ = latch_c;
2310 }
2311
2312 // ---- 1. Original body_block: keep insts; rewire terminator from
2313 // `br latch` to `br latch (with body work)` — actually the
2314 // original body still flows into the original latch's body
2315 // work, then to the first clone.
2316
2317 // Strategy: keep the original chain (body_block → latch) intact for
2318 // the work, but split the latch by removing the iadd. The latch's
2319 // body-work insts run, then we br to the first clone; the iadd
2320 // moves to the LAST clone's latch_c.
2321 //
2322 // After this:
2323 // header → body_block (orig) → latch (orig, no iadd) → clone1.body
2324 // → clone1.latch_c → clone2.body → clone2.latch_c → ... →
2325 // cloneU-1.latch_c (with new iadd*U) → header
2326 //
2327 // Each clone_k computes iv_k = iv + k at the top of body_c.
2328
2329 // Remove the iadd from the original latch's insts.
2330 {
2331 let lat = func.block_mut(shape.latch);
2332 lat.insts.retain(|i| i.id != shape.iadd_id);
2333 }
2334
2335 // ---- 2. Build each clone ----
2336 for (k, &(body_c, latch_c)) in clone_blocks.iter().enumerate() {
2337 let k_val = (k + 1) as i64; // k=0 in Vec → iv+1
2338 // body_c: iv_k_const + iv_k_iadd + cloned body insts + br latch_c
2339 let mut body_c_insts: Vec<Inst> = Vec::new();
2340 let k_const_id = func.next_value_id();
2341 func.register_type(k_const_id, int_ty.clone());
2342 let iv_k_id = func.next_value_id();
2343 func.register_type(iv_k_id, int_ty.clone());
2344 body_c_insts.push(Inst {
2345 id: k_const_id,
2346 kind: InstKind::ConstInt(k_val as i128, shape.iv_width),
2347 ty: int_ty.clone(),
2348 span,
2349 });
2350 body_c_insts.push(Inst {
2351 id: iv_k_id,
2352 kind: InstKind::IAdd(shape.iv_param, k_const_id),
2353 ty: int_ty.clone(),
2354 span,
2355 });
2356 let mut subst: HashMap<ValueId, ValueId> = HashMap::new();
2357 subst.insert(shape.iv_param, iv_k_id);
2358 for orig in &body_snapshot {
2359 let new_id = func.next_value_id();
2360 func.register_type(new_id, orig.ty.clone());
2361 subst.insert(orig.id, new_id);
2362 let new_kind = remap_kind(&orig.kind, &subst);
2363 body_c_insts.push(Inst {
2364 id: new_id,
2365 kind: new_kind,
2366 ty: orig.ty.clone(),
2367 span: orig.span,
2368 });
2369 }
2370 func.block_mut(body_c).insts = body_c_insts;
2371 func.block_mut(body_c).terminator = Some(Terminator::Branch(latch_c, vec![]));
2372
2373 // latch_c: cloned latch-body insts (sans iadd) with same subst.
2374 let mut latch_c_insts: Vec<Inst> = Vec::new();
2375 for orig in &latch_snapshot_no_iadd {
2376 let new_id = func.next_value_id();
2377 func.register_type(new_id, orig.ty.clone());
2378 subst.insert(orig.id, new_id);
2379 let new_kind = remap_kind(&orig.kind, &subst);
2380 latch_c_insts.push(Inst {
2381 id: new_id,
2382 kind: new_kind,
2383 ty: orig.ty.clone(),
2384 span: orig.span,
2385 });
2386 }
2387 func.block_mut(latch_c).insts = latch_c_insts;
2388 // Terminator: chain to next clone, or for the LAST clone,
2389 // emit the new iadd*U + br header.
2390 if k + 1 < clone_blocks.len() {
2391 let next_body = clone_blocks[k + 1].0;
2392 func.block_mut(latch_c).terminator = Some(Terminator::Branch(next_body, vec![]));
2393 } else {
2394 // Last clone: append the new iadd and branch to header.
2395 let new_step_const = func.next_value_id();
2396 func.register_type(new_step_const, int_ty.clone());
2397 func.block_mut(latch_c).insts.push(Inst {
2398 id: new_step_const,
2399 kind: InstKind::ConstInt(shape.u as i128, shape.iv_width),
2400 ty: int_ty.clone(),
2401 span,
2402 });
2403 // Reuse the original iadd id so anything referencing it
2404 // (e.g., dominance maps) stays consistent.
2405 func.block_mut(latch_c).insts.push(Inst {
2406 id: shape.iadd_id,
2407 kind: InstKind::IAdd(shape.iv_param, new_step_const),
2408 ty: int_ty.clone(),
2409 span,
2410 });
2411 func.block_mut(latch_c).terminator =
2412 Some(Terminator::Branch(shape.header, vec![shape.iadd_id]));
2413 }
2414 }
2415
2416 // ---- 3. Rewire the original latch's terminator to flow into the
2417 // first clone (or directly to header if U == 2 and we
2418 // skipped clones — shouldn't happen since min U=2 means
2419 // exactly 1 clone).
2420 if let Some(&(first_clone_body, _)) = clone_blocks.first() {
2421 func.block_mut(shape.latch).terminator =
2422 Some(Terminator::Branch(first_clone_body, vec![]));
2423 }
2424 }
2425
2426 // ---------------------------------------------------------------------------
2427 // Unit tests
2428 // ---------------------------------------------------------------------------
2429
2430 #[cfg(test)]
2431 mod tests {
2432 use super::*;
2433 use crate::ir::types::{IntWidth, IrType};
2434 use crate::lexer::Span;
2435 use crate::opt::pass::Pass;
2436
2437 fn span() -> Span {
2438 super::dummy_span()
2439 }
2440
    /// Build a minimal Module with one function containing a simple
    /// counted loop: do i = lo, hi; a(i_offset) = const(42); end do
    ///
    /// The loop has 4 blocks: entry (preheader), header, latch, exit.
    /// The latch stores a constant to a fixed alloca (simulating a(i) = 42
    /// where the GEP is not yet computed — we just store to a fixed addr).
    ///
    /// `prefix` names the loop blocks ("{prefix}_check" / "{prefix}_body")
    /// so callers can vary block naming.
    fn build_counted_loop_with_prefix(lo: i64, hi: i64, prefix: &str) -> Module {
        let mut m = Module::new("test".into());
        let mut f = Function::new("loop_test".into(), vec![], IrType::Void);

        // Allocate block IDs.
        let header_id = f.create_block(&format!("{}_check", prefix));
        let latch_id = f.create_block(&format!("{}_body", prefix));
        let exit_id = f.create_block("exit");
        let entry_id = f.entry; // preheader

        // ---- entry (preheader) -------------------------------------------
        // const lo — the IV's initial value, passed as the header branch arg.
        let lo_val = f.next_value_id();
        f.block_mut(entry_id).insts.push(Inst {
            id: lo_val,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::ConstInt(lo as i128, IntWidth::I64),
        });
        // alloca for the "array" (just a slot we store into)
        let alloca_val = f.next_value_id();
        f.block_mut(entry_id).insts.push(Inst {
            id: alloca_val,
            ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I64))),
            span: span(),
            kind: InstKind::Alloca(IrType::Int(IntWidth::I64)),
        });
        // br header(lo)
        f.block_mut(entry_id).terminator = Some(Terminator::Branch(header_id, vec![lo_val]));

        // ---- header (%i) -------------------------------------------------
        let iv_param = f.next_value_id();
        f.block_mut(header_id).params.push(BlockParam {
            id: iv_param,
            ty: IrType::Int(IntWidth::I64),
        });
        // const hi — kept in the header so the detector can resolve the
        // compare's bound operand as a constant.
        let hi_val = f.next_value_id();
        f.block_mut(header_id).insts.push(Inst {
            id: hi_val,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::ConstInt(hi as i128, IntWidth::I64),
        });
        // %cmp = icmp.sle %i, hi
        let cmp_val = f.next_value_id();
        f.block_mut(header_id).insts.push(Inst {
            id: cmp_val,
            ty: IrType::Bool,
            span: span(),
            kind: InstKind::ICmp(CmpOp::Le, iv_param, hi_val),
        });
        // condBr %cmp, latch, exit
        f.block_mut(header_id).terminator = Some(Terminator::CondBranch {
            cond: cmp_val,
            true_dest: latch_id,
            true_args: vec![],
            false_dest: exit_id,
            false_args: vec![],
        });

        // ---- latch -------------------------------------------------------
        // Store 42 to the alloca (simulates a(i) = 42).
        let const42 = f.next_value_id();
        f.block_mut(latch_id).insts.push(Inst {
            id: const42,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::ConstInt(42, IntWidth::I64),
        });
        let store_id = f.next_value_id();
        f.block_mut(latch_id).insts.push(Inst {
            id: store_id,
            ty: IrType::Void,
            span: span(),
            kind: InstKind::Store(const42, alloca_val),
        });
        // %i_next = iadd %i, 1
        let one = f.next_value_id();
        f.block_mut(latch_id).insts.push(Inst {
            id: one,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::ConstInt(1, IntWidth::I64),
        });
        let i_next = f.next_value_id();
        f.block_mut(latch_id).insts.push(Inst {
            id: i_next,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::IAdd(iv_param, one),
        });
        // br header(%i_next)
        f.block_mut(latch_id).terminator = Some(Terminator::Branch(header_id, vec![i_next]));

        // ---- exit --------------------------------------------------------
        f.block_mut(exit_id).terminator = Some(Terminator::Return(None));

        m.add_function(f);
        m
    }
2548
    /// Counted loop with the default `"do"` block-name prefix.
    fn build_counted_loop(lo: i64, hi: i64) -> Module {
        build_counted_loop_with_prefix(lo, hi, "do")
    }
2552
2553 /// Build a 2-block counted loop where the bound is a function
2554 /// parameter (runtime). Uses an alloca + store as the body work
2555 /// (1 store keeps the multi-store regalloc-pressure heuristic
2556 /// happy) and increments by 1 in the latch.
2557 fn build_runtime_counted_loop(lo: i64) -> Module {
2558 let mut m = Module::new("test".into());
2559 // Reserve %0 for the n_param so the function params line up.
2560 let n_param = ValueId(0);
2561 let params = vec![Param {
2562 name: "n".into(),
2563 ty: IrType::Int(IntWidth::I64),
2564 id: n_param,
2565 fortran_noalias: false,
2566 }];
2567 let mut f = Function::new("loop_runtime".into(), params, IrType::Void);
2568
2569 let header_id = f.create_block("do_check");
2570 let latch_id = f.create_block("do_body");
2571 let exit_id = f.create_block("exit");
2572 let entry_id = f.entry;
2573
2574 // ---- preheader ----
2575 let lo_val = f.next_value_id();
2576 f.block_mut(entry_id).insts.push(Inst {
2577 id: lo_val,
2578 ty: IrType::Int(IntWidth::I64),
2579 span: span(),
2580 kind: InstKind::ConstInt(lo as i128, IntWidth::I64),
2581 });
2582 let alloca_val = f.next_value_id();
2583 f.block_mut(entry_id).insts.push(Inst {
2584 id: alloca_val,
2585 ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I64))),
2586 span: span(),
2587 kind: InstKind::Alloca(IrType::Int(IntWidth::I64)),
2588 });
2589 f.block_mut(entry_id).terminator = Some(Terminator::Branch(header_id, vec![lo_val]));
2590
2591 // ---- header ----
2592 let iv_param = f.next_value_id();
2593 f.block_mut(header_id).params.push(BlockParam {
2594 id: iv_param,
2595 ty: IrType::Int(IntWidth::I64),
2596 });
2597 let cmp_val = f.next_value_id();
2598 f.block_mut(header_id).insts.push(Inst {
2599 id: cmp_val,
2600 ty: IrType::Bool,
2601 span: span(),
2602 kind: InstKind::ICmp(CmpOp::Le, iv_param, n_param),
2603 });
2604 f.block_mut(header_id).terminator = Some(Terminator::CondBranch {
2605 cond: cmp_val,
2606 true_dest: latch_id,
2607 true_args: vec![],
2608 false_dest: exit_id,
2609 false_args: vec![],
2610 });
2611
2612 // ---- latch (one store as the body work; iadd iv,1; br) ----
2613 let const_v = f.next_value_id();
2614 f.block_mut(latch_id).insts.push(Inst {
2615 id: const_v,
2616 ty: IrType::Int(IntWidth::I64),
2617 span: span(),
2618 kind: InstKind::ConstInt(42, IntWidth::I64),
2619 });
2620 let store_id = f.next_value_id();
2621 f.block_mut(latch_id).insts.push(Inst {
2622 id: store_id,
2623 ty: IrType::Void,
2624 span: span(),
2625 kind: InstKind::Store(const_v, alloca_val),
2626 });
2627 let one_val = f.next_value_id();
2628 f.block_mut(latch_id).insts.push(Inst {
2629 id: one_val,
2630 ty: IrType::Int(IntWidth::I64),
2631 span: span(),
2632 kind: InstKind::ConstInt(1, IntWidth::I64),
2633 });
2634 let iv_next = f.next_value_id();
2635 f.block_mut(latch_id).insts.push(Inst {
2636 id: iv_next,
2637 ty: IrType::Int(IntWidth::I64),
2638 span: span(),
2639 kind: InstKind::IAdd(iv_param, one_val),
2640 });
2641 f.block_mut(latch_id).terminator = Some(Terminator::Branch(header_id, vec![iv_next]));
2642
2643 // ---- exit ----
2644 f.block_mut(exit_id).terminator = Some(Terminator::Return(None));
2645
2646 m.add_function(f);
2647 m
2648 }
2649
2650 /// Build a 3-block counted loop: header + body_block (unconditional
2651 /// br to latch) + latch (with iadd + br header). Trip is statically
2652 /// known (`hi`); body_block stores `42` to a local alloca.
2653 fn build_3block_counted_loop(lo: i64, hi: i64) -> Module {
2654 let mut m = Module::new("test".into());
2655 let mut f = Function::new("loop_3block".into(), vec![], IrType::Void);
2656
2657 let header_id = f.create_block("do_check");
2658 let body_id = f.create_block("do_body_inner");
2659 let latch_id = f.create_block("do_body");
2660 let exit_id = f.create_block("exit");
2661 let entry_id = f.entry;
2662
2663 // ---- preheader ----
2664 let lo_val = f.next_value_id();
2665 f.block_mut(entry_id).insts.push(Inst {
2666 id: lo_val,
2667 ty: IrType::Int(IntWidth::I64),
2668 span: span(),
2669 kind: InstKind::ConstInt(lo as i128, IntWidth::I64),
2670 });
2671 let alloca_val = f.next_value_id();
2672 f.block_mut(entry_id).insts.push(Inst {
2673 id: alloca_val,
2674 ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I64))),
2675 span: span(),
2676 kind: InstKind::Alloca(IrType::Int(IntWidth::I64)),
2677 });
2678 f.block_mut(entry_id).terminator = Some(Terminator::Branch(header_id, vec![lo_val]));
2679
2680 // ---- header(iv) ----
2681 let iv_param = f.next_value_id();
2682 f.block_mut(header_id).params.push(BlockParam {
2683 id: iv_param,
2684 ty: IrType::Int(IntWidth::I64),
2685 });
2686 let hi_val = f.next_value_id();
2687 f.block_mut(header_id).insts.push(Inst {
2688 id: hi_val,
2689 ty: IrType::Int(IntWidth::I64),
2690 span: span(),
2691 kind: InstKind::ConstInt(hi as i128, IntWidth::I64),
2692 });
2693 let cmp_val = f.next_value_id();
2694 f.block_mut(header_id).insts.push(Inst {
2695 id: cmp_val,
2696 ty: IrType::Bool,
2697 span: span(),
2698 kind: InstKind::ICmp(CmpOp::Le, iv_param, hi_val),
2699 });
2700 f.block_mut(header_id).terminator = Some(Terminator::CondBranch {
2701 cond: cmp_val,
2702 true_dest: body_id,
2703 true_args: vec![],
2704 false_dest: exit_id,
2705 false_args: vec![],
2706 });
2707
2708 // ---- body block (one store, br latch) ----
2709 let const_v = f.next_value_id();
2710 f.block_mut(body_id).insts.push(Inst {
2711 id: const_v,
2712 ty: IrType::Int(IntWidth::I64),
2713 span: span(),
2714 kind: InstKind::ConstInt(42, IntWidth::I64),
2715 });
2716 let store_id = f.next_value_id();
2717 f.block_mut(body_id).insts.push(Inst {
2718 id: store_id,
2719 ty: IrType::Void,
2720 span: span(),
2721 kind: InstKind::Store(const_v, alloca_val),
2722 });
2723 f.block_mut(body_id).terminator = Some(Terminator::Branch(latch_id, vec![]));
2724
2725 // ---- latch (iadd iv,1; br header) ----
2726 let one_val = f.next_value_id();
2727 f.block_mut(latch_id).insts.push(Inst {
2728 id: one_val,
2729 ty: IrType::Int(IntWidth::I64),
2730 span: span(),
2731 kind: InstKind::ConstInt(1, IntWidth::I64),
2732 });
2733 let iv_next = f.next_value_id();
2734 f.block_mut(latch_id).insts.push(Inst {
2735 id: iv_next,
2736 ty: IrType::Int(IntWidth::I64),
2737 span: span(),
2738 kind: InstKind::IAdd(iv_param, one_val),
2739 });
2740 f.block_mut(latch_id).terminator = Some(Terminator::Branch(header_id, vec![iv_next]));
2741
2742 // ---- exit ----
2743 f.block_mut(exit_id).terminator = Some(Terminator::Return(None));
2744
2745 m.add_function(f);
2746 m
2747 }
2748
2749 #[test]
2750 fn partial_unrolls_3block_loop_trip16() {
2751 // Trip 16 > FULL_UNROLL_MAX=8. Multi-block detector should pick
2752 // the largest U that divides 16 and fits the body budget; with
2753 // a small body this is U=4. The transform allocates 3 fresh
2754 // BlockIds per clone (body_c + latch_c) and wires them between
2755 // the original latch and the header.
2756 let mut m = build_3block_counted_loop(1, 16);
2757 let blocks_before = m.functions[0].blocks.len();
2758 let pass = LoopUnroll;
2759 let changed = pass.run(&mut m);
2760 assert!(
2761 changed,
2762 "3-block trip-16 loop should multi-block partial-unroll"
2763 );
2764 let f = &m.functions[0];
2765 // For U=4, we add (U-1) * 2 = 6 new blocks (body_c + latch_c per clone).
2766 assert_eq!(
2767 f.blocks.len(),
2768 blocks_before + 6,
2769 "expected 6 new clone blocks; got {} → {}",
2770 blocks_before,
2771 f.blocks.len()
2772 );
2773 // The original latch lost its iadd; the new iadd lives in the
2774 // last clone's latch_c. Count IAdd instances across the function:
2775 // one in the last clone's latch_c. (The original iadd id was
2776 // reused there.)
2777 let total_iadd = f
2778 .blocks
2779 .iter()
2780 .flat_map(|b| b.insts.iter())
2781 .filter(|i| matches!(i.kind, InstKind::IAdd(..)))
2782 .count();
2783 // Each clone's body_c has one iadd (computing iv_k). The final
2784 // clone's latch_c also has one iadd (the loop step). With U=4:
2785 // 3 body_c iadds + 1 final-step iadd = 4 total.
2786 assert_eq!(
2787 total_iadd, 4,
2788 "expected 4 IAdds total (3 iv_k + 1 step), got {}",
2789 total_iadd
2790 );
2791 }
2792
2793 #[test]
2794 fn partial_unrolls_runtime_trip_2block_loop() {
2795 // The bound is a function parameter (runtime), so the static
2796 // partial-unroll detector won't fire — but the runtime detector
2797 // should pick up the same shape and emit a head_bound
2798 // computation in the preheader plus a remainder loop after
2799 // the unrolled main loop.
2800 let mut m = build_runtime_counted_loop(1);
2801 let blocks_before = m.functions[0].blocks.len();
2802 let pass = LoopUnroll;
2803 let changed = pass.run(&mut m);
2804 assert!(
2805 changed,
2806 "runtime-trip 2-block loop should partial-unroll"
2807 );
2808 let f = &m.functions[0];
2809 // Two new blocks: header_remain + latch_remain.
2810 assert_eq!(
2811 f.blocks.len(),
2812 blocks_before + 2,
2813 "expected 2 new blocks (remainder header + latch); got {} → {}",
2814 blocks_before,
2815 f.blocks.len()
2816 );
2817 // Preheader should contain the runtime arithmetic for head_bound:
2818 // ISub, IAdd, IMod feeding into the header's icmp via a fresh value.
2819 let preheader = &f.blocks[0];
2820 let kinds: Vec<&InstKind> = preheader.insts.iter().map(|i| &i.kind).collect();
2821 let has_imod = kinds
2822 .iter()
2823 .any(|k| matches!(k, InstKind::IMod(..)));
2824 assert!(has_imod, "preheader should compute head_bound via IMod");
2825 }
2826
2827 #[test]
2828 fn unrolls_trip4() {
2829 let mut m = build_counted_loop(1, 4);
2830 let pass = LoopUnroll;
2831 let changed = pass.run(&mut m);
2832 assert!(changed, "expected unroll to fire");
2833
2834 let f = &m.functions[0];
2835 // After unrolling: preheader, 4 cloned iteration blocks, exit = 6 blocks.
2836 // The original header and latch are pruned (marked Unreachable then removed).
2837 assert_eq!(f.blocks.len(), 6, "expected entry + 4 iter blocks + exit");
2838
2839 // Each iteration block has a ConstInt for that iteration's IV value.
2840 let all_iv_consts: Vec<i128> = f
2841 .blocks
2842 .iter()
2843 .flat_map(|b| b.insts.iter())
2844 .filter_map(|i| {
2845 if let InstKind::ConstInt(v, _) = i.kind {
2846 Some(v)
2847 } else {
2848 None
2849 }
2850 })
2851 .collect();
2852 assert!(all_iv_consts.contains(&1), "IV=1 should appear");
2853 assert!(all_iv_consts.contains(&2), "IV=2 should appear");
2854 assert!(all_iv_consts.contains(&3), "IV=3 should appear");
2855 assert!(all_iv_consts.contains(&4), "IV=4 should appear");
2856 }
2857
2858 #[test]
2859 fn partial_unrolls_trip10() {
2860 // Trip 10 > FULL_UNROLL_MAX=8 — full unroll skips it, but
2861 // partial unroll picks U=2 (largest factor of 10 that fits
2862 // the body budget) and doubles the body in place.
2863 let mut m = build_counted_loop(1, 10);
2864 let pass = LoopUnroll;
2865 let changed = pass.run(&mut m);
2866 assert!(changed, "trip-10 loop should partial-unroll");
2867 // Same number of blocks (loop is preserved); body inst count
2868 // grew because we cloned the body once with iv → iv+1.
2869 let f = &m.functions[0];
2870 let block_count = f.blocks.len();
2871 assert!(
2872 block_count >= 3,
2873 "expected loop structure preserved, got {} blocks",
2874 block_count
2875 );
2876 }
2877
2878 #[test]
2879 fn does_not_partial_unroll_trip_above_threshold_with_no_factor() {
2880 // Trip 11 is prime and > FULL_UNROLL_MAX — neither full nor
2881 // partial unroll fires.
2882 let mut m = build_counted_loop(1, 11);
2883 let pass = LoopUnroll;
2884 let changed = pass.run(&mut m);
2885 assert!(!changed, "trip-11 (prime) loop should not unroll");
2886 }
2887
2888 #[test]
2889 fn unrolls_trip10_do_concurrent() {
2890 let mut m = build_counted_loop_with_prefix(1, 10, "doconc");
2891 let pass = LoopUnroll;
2892 let changed = pass.run(&mut m);
2893 assert!(
2894 changed,
2895 "DO CONCURRENT trip-count-10 loop should fully unroll"
2896 );
2897
2898 let f = &m.functions[0];
2899 assert_eq!(f.blocks.len(), 12, "expected entry + 10 iter blocks + exit");
2900 }
2901
2902 #[test]
2903 fn does_not_unroll_empty_loop() {
2904 let mut m = build_counted_loop(5, 3); // lo > hi → trip count 0
2905 let pass = LoopUnroll;
2906 let changed = pass.run(&mut m);
2907 assert!(!changed, "empty loop should be left for DCE");
2908 }
2909
    #[test]
    fn unrolls_reduction_loop_threading_accumulator() {
        // A reduction loop has 2 block params (iv + accumulator).
        // The unroller should clone the body once per iteration and
        // thread the accumulator value across iterations: iter 0
        // reads `acc_init`, iter 1 reads iter 0's `new_acc`, etc. The
        // final branch to `exit` passes the last iteration's
        // accumulator value to the exit block param.
        //
        // NOTE: value ids come from sequential `next_value_id()` calls,
        // and the assertions at the bottom compare branch args against
        // the exact `iv`/`acc` ids captured here — the emission order
        // in this builder is load-bearing.
        let mut m = Module::new("test".into());
        let mut f = Function::new("reduce".into(), vec![], IrType::Void);
        let header_id = f.create_block("header");
        let latch_id = f.create_block("latch");
        let exit_id = f.create_block("exit");
        let entry_id = f.entry;

        // Entry: br header(1, 0) — IV starts at 1, accumulator at 0.
        let lo = f.next_value_id();
        f.block_mut(entry_id).insts.push(Inst {
            id: lo,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::ConstInt(1, IntWidth::I64),
        });
        let z = f.next_value_id();
        f.block_mut(entry_id).insts.push(Inst {
            id: z,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::ConstInt(0, IntWidth::I64),
        });
        f.block_mut(entry_id).terminator = Some(Terminator::Branch(header_id, vec![lo, z]));

        // Header: 2 params (iv, acc)
        let iv = f.next_value_id();
        let acc = f.next_value_id();
        f.block_mut(header_id).params.push(BlockParam {
            id: iv,
            ty: IrType::Int(IntWidth::I64),
        });
        f.block_mut(header_id).params.push(BlockParam {
            id: acc,
            ty: IrType::Int(IntWidth::I64),
        });
        // hi = 4 → trip count 4 (iv runs 1..=4 with the sle compare below).
        let hi = f.next_value_id();
        f.block_mut(header_id).insts.push(Inst {
            id: hi,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::ConstInt(4, IntWidth::I64),
        });
        let cmp = f.next_value_id();
        f.block_mut(header_id).insts.push(Inst {
            id: cmp,
            ty: IrType::Bool,
            span: span(),
            kind: InstKind::ICmp(CmpOp::Le, iv, hi),
        });
        // On exit, the accumulator is passed to the exit block param —
        // this is the "escaping" value the unroller must thread.
        f.block_mut(header_id).terminator = Some(Terminator::CondBranch {
            cond: cmp,
            true_dest: latch_id,
            true_args: vec![],
            false_dest: exit_id,
            false_args: vec![acc],
        });

        // Latch: acc_next = acc + iv; i_next = iv + 1; br header(i_next, acc_next)
        let one = f.next_value_id();
        f.block_mut(latch_id).insts.push(Inst {
            id: one,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::ConstInt(1, IntWidth::I64),
        });
        let acc_next = f.next_value_id();
        f.block_mut(latch_id).insts.push(Inst {
            id: acc_next,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::IAdd(acc, iv),
        });
        let i_next = f.next_value_id();
        f.block_mut(latch_id).insts.push(Inst {
            id: i_next,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::IAdd(iv, one),
        });
        f.block_mut(latch_id).terminator =
            Some(Terminator::Branch(header_id, vec![i_next, acc_next]));

        // Exit: acc param, return
        let acc_out = f.next_value_id();
        f.block_mut(exit_id).params.push(BlockParam {
            id: acc_out,
            ty: IrType::Int(IntWidth::I64),
        });
        f.block_mut(exit_id).terminator = Some(Terminator::Return(None));

        m.add_function(f);

        let pass = LoopUnroll;
        let changed = pass.run(&mut m);
        assert!(changed, "reduction loop should now unroll");

        let f = &m.functions[0];

        // The exit block must still exist and its (only) predecessor
        // path now feeds it a non-acc-param value — i.e. a value
        // defined in one of the cloned iteration blocks.
        let exit_blk = f
            .blocks
            .iter()
            .find(|b| b.id == exit_id)
            .expect("exit survives");
        assert_eq!(exit_blk.params.len(), 1, "exit still has its acc_out param");

        // Find the predecessor block whose terminator branches to the
        // exit block. Its single arg should be the unrolled chain's
        // final accumulator value, NOT the original `acc` header
        // param (which is now unreachable).
        let mut found_branch_to_exit = false;
        for blk in &f.blocks {
            if let Some(Terminator::Branch(dest, args)) = &blk.terminator {
                if *dest == exit_id {
                    found_branch_to_exit = true;
                    assert_eq!(args.len(), 1, "exit takes one arg (the accumulator)");
                    assert_ne!(args[0], acc, "must not reference the pruned header acc param");
                    assert_ne!(args[0], iv, "must not reference the pruned header iv param");
                    break;
                }
            }
        }
        assert!(
            found_branch_to_exit,
            "must find an unrolled iter block branching to the original exit"
        );

        // Trip count = 4 → expect 4 cloned IAdd-of-acc instructions in
        // the function (one per iteration). The original latch's IAdd
        // is now under an Unreachable terminator and will be pruned by
        // a later DCE pass; for this test we just count the clones.
        let total_iadds: usize = f
            .blocks
            .iter()
            .flat_map(|b| b.insts.iter())
            .filter(|inst| matches!(inst.kind, InstKind::IAdd(..)))
            .count();
        assert!(
            total_iadds >= 4,
            "expected at least 4 IAdd insts (one acc-update per iteration), found {}",
            total_iadds
        );
    }
3063
3064 #[test]
3065 fn does_not_unroll_load_bearing_loop() {
3066 let mut m = Module::new("test".into());
3067 let mut f = Function::new("load_loop".into(), vec![], IrType::Void);
3068
3069 let header_id = f.create_block("header");
3070 let latch_id = f.create_block("latch");
3071 let exit_id = f.create_block("exit");
3072 let entry_id = f.entry;
3073
3074 let lo = f.next_value_id();
3075 f.block_mut(entry_id).insts.push(Inst {
3076 id: lo,
3077 ty: IrType::Int(IntWidth::I64),
3078 span: span(),
3079 kind: InstKind::ConstInt(1, IntWidth::I64),
3080 });
3081 let x_slot = f.next_value_id();
3082 f.block_mut(entry_id).insts.push(Inst {
3083 id: x_slot,
3084 ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I64))),
3085 span: span(),
3086 kind: InstKind::Alloca(IrType::Int(IntWidth::I64)),
3087 });
3088 let y_slot = f.next_value_id();
3089 f.block_mut(entry_id).insts.push(Inst {
3090 id: y_slot,
3091 ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I64))),
3092 span: span(),
3093 kind: InstKind::Alloca(IrType::Int(IntWidth::I64)),
3094 });
3095 let x_init = f.next_value_id();
3096 f.block_mut(entry_id).insts.push(Inst {
3097 id: x_init,
3098 ty: IrType::Int(IntWidth::I64),
3099 span: span(),
3100 kind: InstKind::ConstInt(5, IntWidth::I64),
3101 });
3102 let store_x = f.next_value_id();
3103 f.block_mut(entry_id).insts.push(Inst {
3104 id: store_x,
3105 ty: IrType::Void,
3106 span: span(),
3107 kind: InstKind::Store(x_init, x_slot),
3108 });
3109 let y_init = f.next_value_id();
3110 f.block_mut(entry_id).insts.push(Inst {
3111 id: y_init,
3112 ty: IrType::Int(IntWidth::I64),
3113 span: span(),
3114 kind: InstKind::ConstInt(1, IntWidth::I64),
3115 });
3116 let store_y = f.next_value_id();
3117 f.block_mut(entry_id).insts.push(Inst {
3118 id: store_y,
3119 ty: IrType::Void,
3120 span: span(),
3121 kind: InstKind::Store(y_init, y_slot),
3122 });
3123 f.block_mut(entry_id).terminator = Some(Terminator::Branch(header_id, vec![lo]));
3124
3125 let iv = f.next_value_id();
3126 f.block_mut(header_id).params.push(BlockParam {
3127 id: iv,
3128 ty: IrType::Int(IntWidth::I64),
3129 });
3130 let hi = f.next_value_id();
3131 f.block_mut(header_id).insts.push(Inst {
3132 id: hi,
3133 ty: IrType::Int(IntWidth::I64),
3134 span: span(),
3135 kind: InstKind::ConstInt(4, IntWidth::I64),
3136 });
3137 let cmp = f.next_value_id();
3138 f.block_mut(header_id).insts.push(Inst {
3139 id: cmp,
3140 ty: IrType::Bool,
3141 span: span(),
3142 kind: InstKind::ICmp(CmpOp::Le, iv, hi),
3143 });
3144 f.block_mut(header_id).terminator = Some(Terminator::CondBranch {
3145 cond: cmp,
3146 true_dest: latch_id,
3147 true_args: vec![],
3148 false_dest: exit_id,
3149 false_args: vec![],
3150 });
3151
3152 let x_val = f.next_value_id();
3153 f.block_mut(latch_id).insts.push(Inst {
3154 id: x_val,
3155 ty: IrType::Int(IntWidth::I64),
3156 span: span(),
3157 kind: InstKind::Load(x_slot),
3158 });
3159 let y_val = f.next_value_id();
3160 f.block_mut(latch_id).insts.push(Inst {
3161 id: y_val,
3162 ty: IrType::Int(IntWidth::I64),
3163 span: span(),
3164 kind: InstKind::Load(y_slot),
3165 });
3166 let sum = f.next_value_id();
3167 f.block_mut(latch_id).insts.push(Inst {
3168 id: sum,
3169 ty: IrType::Int(IntWidth::I64),
3170 span: span(),
3171 kind: InstKind::IAdd(x_val, y_val),
3172 });
3173 let store_sum = f.next_value_id();
3174 f.block_mut(latch_id).insts.push(Inst {
3175 id: store_sum,
3176 ty: IrType::Void,
3177 span: span(),
3178 kind: InstKind::Store(sum, y_slot),
3179 });
3180 let one = f.next_value_id();
3181 f.block_mut(latch_id).insts.push(Inst {
3182 id: one,
3183 ty: IrType::Int(IntWidth::I64),
3184 span: span(),
3185 kind: InstKind::ConstInt(1, IntWidth::I64),
3186 });
3187 let next = f.next_value_id();
3188 f.block_mut(latch_id).insts.push(Inst {
3189 id: next,
3190 ty: IrType::Int(IntWidth::I64),
3191 span: span(),
3192 kind: InstKind::IAdd(iv, one),
3193 });
3194 f.block_mut(latch_id).terminator = Some(Terminator::Branch(header_id, vec![next]));
3195
3196 f.block_mut(exit_id).terminator = Some(Terminator::Return(None));
3197 m.add_function(f);
3198
3199 let pass = LoopUnroll;
3200 let changed = pass.run(&mut m);
3201 assert!(
3202 !changed,
3203 "load-bearing loop should stay out of the full-unroll fast path"
3204 );
3205 }
3206
3207 #[test]
3208 fn unrolls_trip1() {
3209 let mut m = build_counted_loop(3, 3); // lo == hi → 1 iteration
3210 let pass = LoopUnroll;
3211 let changed = pass.run(&mut m);
3212 assert!(changed, "trip-1 loop should unroll");
3213 let f = &m.functions[0];
3214 // entry + 1 cloned iteration block + exit = 3 blocks
3215 assert_eq!(f.blocks.len(), 3, "expected entry + 1 iter block + exit");
3216 }
3217
3218 /// Regression: mem2reg places block params at dominance-frontier blocks,
3219 /// which can cause an outer-loop latch to reference the inner loop's header
3220 /// block param (%iv) in its branch terminator. After unrolling removes the
3221 /// inner header, %iv becomes undefined → IR verifier panic.
3222 ///
3223 /// `has_escaping_values` must check block params (not just instruction
3224 /// results) AND must be called for the header block (not just body_blocks).
3225 #[test]
3226 fn does_not_unroll_when_header_param_escapes_into_outer_block() {
3227 // Build: inner loop (entry → header(%iv) → latch → header, | → exit)
3228 // plus an outer_latch block that passes %iv to outer_dest.
3229 // outer_latch is not part of the inner nl.body but references %iv.
3230 let mut m = Module::new("test".into());
3231 let mut f = Function::new("escape_test".into(), vec![], IrType::Void);
3232
3233 let header_id = f.create_block("header");
3234 let latch_id = f.create_block("latch");
3235 let exit_id = f.create_block("exit");
3236 let outer_latch = f.create_block("outer_latch");
3237 let outer_dest = f.create_block("outer_dest");
3238 let entry_id = f.entry;
3239
3240 // Entry: const 1; br header(1)
3241 let lo = f.next_value_id();
3242 f.block_mut(entry_id).insts.push(Inst {
3243 id: lo,
3244 ty: IrType::Int(IntWidth::I64),
3245 span: span(),
3246 kind: InstKind::ConstInt(1, IntWidth::I64),
3247 });
3248 f.block_mut(entry_id).terminator = Some(Terminator::Branch(header_id, vec![lo]));
3249
3250 // Header: block param %iv; icmp %iv ≤ 3; condBr latch, exit
3251 let iv = f.next_value_id();
3252 f.block_mut(header_id).params.push(BlockParam {
3253 id: iv,
3254 ty: IrType::Int(IntWidth::I64),
3255 });
3256 let hi_c = f.next_value_id();
3257 f.block_mut(header_id).insts.push(Inst {
3258 id: hi_c,
3259 ty: IrType::Int(IntWidth::I64),
3260 span: span(),
3261 kind: InstKind::ConstInt(3, IntWidth::I64),
3262 });
3263 let cmp = f.next_value_id();
3264 f.block_mut(header_id).insts.push(Inst {
3265 id: cmp,
3266 ty: IrType::Bool,
3267 span: span(),
3268 kind: InstKind::ICmp(CmpOp::Le, iv, hi_c),
3269 });
3270 f.block_mut(header_id).terminator = Some(Terminator::CondBranch {
3271 cond: cmp,
3272 true_dest: latch_id,
3273 true_args: vec![],
3274 false_dest: exit_id,
3275 false_args: vec![],
3276 });
3277
3278 // Latch: %next = iadd %iv, 1; br header(%next)
3279 let one = f.next_value_id();
3280 f.block_mut(latch_id).insts.push(Inst {
3281 id: one,
3282 ty: IrType::Int(IntWidth::I64),
3283 span: span(),
3284 kind: InstKind::ConstInt(1, IntWidth::I64),
3285 });
3286 let nxt = f.next_value_id();
3287 f.block_mut(latch_id).insts.push(Inst {
3288 id: nxt,
3289 ty: IrType::Int(IntWidth::I64),
3290 span: span(),
3291 kind: InstKind::IAdd(iv, one),
3292 });
3293 f.block_mut(latch_id).terminator = Some(Terminator::Branch(header_id, vec![nxt]));
3294
3295 // Exit: ret void
3296 f.block_mut(exit_id).terminator = Some(Terminator::Return(None));
3297
3298 // outer_latch: br outer_dest(%iv)
3299 // Simulates an outer-loop latch that mem2reg threaded %iv through.
3300 // %iv is the header's block param — it escapes into this external block.
3301 f.block_mut(outer_latch).terminator = Some(Terminator::Branch(outer_dest, vec![iv]));
3302
3303 // outer_dest(%x): ret void
3304 let x = f.next_value_id();
3305 f.block_mut(outer_dest).params.push(BlockParam {
3306 id: x,
3307 ty: IrType::Int(IntWidth::I64),
3308 });
3309 f.block_mut(outer_dest).terminator = Some(Terminator::Return(None));
3310
3311 m.add_function(f);
3312
3313 let pass = LoopUnroll;
3314 let changed = pass.run(&mut m);
3315 assert!(
3316 !changed,
3317 "must not unroll when header block param escapes into an outer-loop block"
3318 );
3319 }
3320 }
3321