| 1 | //! Loop unrolling pass (O2+). |
| 2 | //! |
| 3 | //! Fully unrolls simple counted loops whose trip count is statically |
| 4 | //! known and small. Targets the common Fortran `do i = lo, hi` pattern |
| 5 | //! after mem2reg has converted the loop induction variable to an SSA |
| 6 | //! block parameter. |
| 7 | //! |
| 8 | //! ## What constitutes a "simple counted loop" |
| 9 | //! |
| 10 | //! After mem2reg runs, a canonicalized `do i = lo, hi` with stride 1 |
| 11 | //! produces this 2-block structure: |
| 12 | //! |
| 13 | //! ```text |
| 14 | //! preheader: |
| 15 | //! ... |
| 16 | //! br header(const(lo)) |
| 17 | //! |
| 18 | //! header(%i): |
| 19 | //! %cmp = icmp.sle %i, const(hi) ; or icmp.slt, hi+1 |
| 20 | //! condBr %cmp, latch([]), exit([]) |
| 21 | //! |
| 22 | //! latch: |
| 23 | //! ... body instructions using %i ... |
| 24 | //! %i_next = iadd %i, const(1) |
| 25 | //! br header(%i_next) |
| 26 | //! |
| 27 | //! exit: |
| 28 | //! ... |
| 29 | //! ``` |
| 30 | //! |
| 31 | //! We require: |
| 32 | //! - Exactly 1 latch (single back-edge). |
//! - Body is a small linear chain of blocks — header, an optional
//!   comparison/relay block, body blocks, and the latch — with no internal
//!   branching (no IF inside the loop, no nested loops).
//! - Header has 1 block parameter (the IV), or 2 when a reduction
//!   accumulator has been promoted alongside it.
| 36 | //! - IV's initial value, the loop bound, and the stride are all compile-time |
| 37 | //! constants discoverable from instruction operands. |
| 38 | //! - Stride is exactly 1. |
| 39 | //! - Trip count ≤ `FULL_UNROLL_MAX` for ordinary `DO`, or |
| 40 | //! `DO_CONCURRENT_FULL_UNROLL_MAX` for lowered `DO CONCURRENT`. |
| 41 | //! - Body instruction count ≤ `BODY_SIZE_MAX`. |
//! - No value defined inside the loop is used outside the loop (i.e., no
//!   values escape the unrolled body — they'll be dead after unrolling).
//!   This covers array loops (`a(i) = expr`). Reduction loops
//!   (`s = s + a(i)`), where the accumulator is a second header block
//!   param, are supported by threading the accumulator value through the
//!   unrolled iterations instead of letting it escape.
| 47 | //! |
| 48 | //! ## Transformation |
| 49 | //! |
| 50 | //! For trip count TC, we: |
| 51 | //! 1. Clone the latch body TC times into the preheader (substituting the IV |
| 52 | //! with the constant value for that iteration). |
| 53 | //! 2. Rewrite the preheader's terminator to jump directly to `exit`. |
| 54 | //! 3. Call `prune_unreachable` to remove the now-dead header and latch. |
| 55 | //! |
| 56 | //! After this, the loop is entirely gone. |
| 57 | |
| 58 | use super::pass::Pass; |
| 59 | use super::util::{find_natural_loops, predecessors, NaturalLoop}; |
| 60 | use crate::ir::inst::*; |
| 61 | use crate::ir::types::{IntWidth, IrType}; |
| 62 | use crate::ir::walk::prune_unreachable; |
| 63 | use crate::lexer::{Position, Span}; |
| 64 | use std::collections::{HashMap, HashSet}; |
| 65 | |
| 66 | fn dummy_span() -> Span { |
| 67 | let pos = Position { line: 0, col: 0 }; |
| 68 | Span { |
| 69 | file_id: 0, |
| 70 | start: pos, |
| 71 | end: pos, |
| 72 | } |
| 73 | } |
| 74 | |
| 75 | // --------------------------------------------------------------------------- |
| 76 | // Thresholds |
| 77 | // --------------------------------------------------------------------------- |
| 78 | |
| 79 | /// Maximum trip count eligible for full unrolling. |
| 80 | /// |
| 81 | /// Trip counts above this are left for sprint 29.6's partial unroller |
| 82 | /// which can handle dynamic trip counts with remainder loops. |
| 83 | const FULL_UNROLL_MAX: i64 = 8; |
| 84 | |
| 85 | /// `DO CONCURRENT` is a stronger signal that iterations are independent, |
| 86 | /// so we allow a slightly larger full-unroll budget. |
| 87 | const DO_CONCURRENT_FULL_UNROLL_MAX: i64 = 16; |
| 88 | |
| 89 | /// Maximum instruction count in the latch block. |
| 90 | /// |
| 91 | /// Prevents code bloat from unrolling large bodies. |
| 92 | const BODY_SIZE_MAX: usize = 30; |
| 93 | |
| 94 | // --------------------------------------------------------------------------- |
| 95 | // Public pass entry point |
| 96 | // --------------------------------------------------------------------------- |
| 97 | |
| 98 | pub struct LoopUnroll; |
| 99 | |
| 100 | impl Pass for LoopUnroll { |
| 101 | fn name(&self) -> &'static str { |
| 102 | "loop-unroll" |
| 103 | } |
| 104 | |
| 105 | fn run(&self, module: &mut Module) -> bool { |
| 106 | let mut changed = false; |
| 107 | for func in &mut module.functions { |
| 108 | changed |= unroll_in_function(func); |
| 109 | } |
| 110 | changed |
| 111 | } |
| 112 | } |
| 113 | |
| 114 | // --------------------------------------------------------------------------- |
| 115 | // Per-function entry |
| 116 | // --------------------------------------------------------------------------- |
| 117 | |
| 118 | fn unroll_in_function(func: &mut Function) -> bool { |
| 119 | let loops = find_natural_loops(func); |
| 120 | let preds = predecessors(func); |
| 121 | let mut changed = false; |
| 122 | |
| 123 | for nl in &loops { |
| 124 | if let Some(shape) = detect_simple_loop(func, nl, &preds) { |
| 125 | do_unroll(func, shape); |
| 126 | prune_unreachable(func); |
| 127 | changed = true; |
| 128 | // After structural changes, re-running loop detection in the next |
| 129 | // pass manager iteration handles any inner loops that became |
| 130 | // exposed. Don't iterate here — return and let the pass manager |
| 131 | // fix-point loop drive re-runs. |
| 132 | break; |
| 133 | } |
| 134 | // Try partial unrolling for trip > FULL_UNROLL_MAX. Keeps the |
| 135 | // loop intact but clones body U times and bumps step by U. |
| 136 | if let Some(shape) = detect_partial_unroll_loop(func, nl, &preds) { |
| 137 | do_partial_unroll(func, shape); |
| 138 | changed = true; |
| 139 | break; |
| 140 | } |
| 141 | // Same shape but with a runtime (non-constant) bound — emits |
| 142 | // a head_bound computation in the preheader and a scalar |
| 143 | // remainder loop after the unrolled main loop. |
| 144 | if let Some(shape) = detect_partial_unroll_runtime_loop(func, nl, &preds) { |
| 145 | do_partial_unroll_runtime(func, shape); |
| 146 | changed = true; |
| 147 | break; |
| 148 | } |
| 149 | // Multi-block partial unroll: 3-block (or longer) linear-chain |
| 150 | // body. Allocates fresh BlockIds for each cloned iteration. |
| 151 | if let Some(shape) = detect_partial_unroll_multiblock_loop(func, nl, &preds) { |
| 152 | do_partial_unroll_multiblock(func, shape); |
| 153 | changed = true; |
| 154 | break; |
| 155 | } |
| 156 | } |
| 157 | changed |
| 158 | } |
| 159 | |
| 160 | // --------------------------------------------------------------------------- |
| 161 | // Shape of a fully unrollable loop |
| 162 | // --------------------------------------------------------------------------- |
| 163 | |
| 164 | #[derive(Debug)] |
| 165 | struct LoopShape { |
| 166 | preheader: BlockId, |
| 167 | header: BlockId, // block with IV block param; branches to cmp_block |
| 168 | cmp_block: BlockId, // block with icmp + condBr to body/exit |
| 169 | /// The computation blocks (everything that is not header, cmp_block, or |
| 170 | /// latch). Ordered in CFG traversal order from cmp_block's true-successor |
| 171 | /// to latch's predecessor. For most Fortran DO loops this is a single |
| 172 | /// block ("do_body"). Multi-block bodies are supported as long as the |
| 173 | /// sub-CFG is a linear chain with no branches to outside. |
| 174 | body_blocks: Vec<BlockId>, |
| 175 | latch: BlockId, // iadd IV, 1 + br header(IV_next) |
| 176 | exit: BlockId, // cond_br false target |
| 177 | iv_param: ValueId, |
| 178 | iv_ty: IrType, |
| 179 | iv_init: i64, |
| 180 | iv_bound: i64, // inclusive upper bound |
| 181 | trip_count: usize, |
| 182 | exit_args: Vec<ValueId>, // args from cmp_block's false branch to exit |
| 183 | /// Optional second header block-param representing a reduction |
| 184 | /// accumulator (sum, product, ...). When present, the unroller |
| 185 | /// threads its value across iterations: each iteration's body |
| 186 | /// reads the previous iteration's latch result instead of the |
| 187 | /// header's join. |
| 188 | reduction: Option<ReductionInfo>, |
| 189 | } |
| 190 | |
| 191 | /// Per-loop reduction lane. mem2reg promotes a sum/product accumulator |
| 192 | /// into a header block param; the latch's `br header(iv_next, new_acc)` |
| 193 | /// passes the per-iteration update on the second lane. To unroll such a |
| 194 | /// loop we need to know: |
| 195 | /// |
| 196 | /// * which header param carries the accumulator (`acc_param`), |
| 197 | /// * its preheader-supplied initial value (`acc_init`), |
| 198 | /// * what value the latch passes back as the new accumulator |
| 199 | /// (`latch_acc_value`) — that's the SSA value defined inside the |
| 200 | /// body that the unroller substitutes through across iterations. |
| 201 | #[derive(Debug, Clone, Copy)] |
| 202 | struct ReductionInfo { |
| 203 | acc_param: ValueId, |
| 204 | acc_init: ValueId, |
| 205 | latch_acc_value: ValueId, |
| 206 | } |
| 207 | |
| 208 | fn is_do_concurrent_loop( |
| 209 | func: &Function, |
| 210 | header: BlockId, |
| 211 | cmp_block: BlockId, |
| 212 | body_blocks: &[BlockId], |
| 213 | latch: BlockId, |
| 214 | ) -> bool { |
| 215 | std::iter::once(header) |
| 216 | .chain(std::iter::once(cmp_block)) |
| 217 | .chain(body_blocks.iter().copied()) |
| 218 | .chain(std::iter::once(latch)) |
| 219 | .any(|bb| func.block(bb).name.starts_with("doconc_")) |
| 220 | } |
| 221 | |
| 222 | // --------------------------------------------------------------------------- |
| 223 | // Detection helpers |
| 224 | // --------------------------------------------------------------------------- |
| 225 | |
| 226 | fn detect_simple_loop( |
| 227 | func: &Function, |
| 228 | nl: &NaturalLoop, |
| 229 | preds: &HashMap<BlockId, Vec<BlockId>>, |
| 230 | ) -> Option<LoopShape> { |
| 231 | // ---- structural requirements ---------------------------------------- |
| 232 | if nl.latches.len() != 1 { |
| 233 | return None; |
| 234 | } |
| 235 | let latch = nl.latches[0]; |
| 236 | if latch == nl.header { |
| 237 | return None; |
| 238 | } // no self-loop |
| 239 | // Body must have between 2 (header+latch) and 5 (header+cmp+body×N+latch) blocks. |
| 240 | if nl.body.len() < 2 || nl.body.len() > 5 { |
| 241 | return None; |
| 242 | } |
| 243 | |
| 244 | let header = nl.header; |
| 245 | |
| 246 | // ---- header must have 1 (pure IV) or 2 (IV + reduction acc) |
| 247 | // block params --------------------------------------------------- |
| 248 | let hdr = func.block(header); |
| 249 | if hdr.params.is_empty() || hdr.params.len() > 2 { |
| 250 | return None; |
| 251 | } |
| 252 | // mem2reg places the IV at index 0 and the reduction accumulator (if |
| 253 | // any) at index 1. We rely on that ordering — it matches the |
| 254 | // canonical lowering for `do i = 1, n; s = s + a(i); end do`. |
| 255 | let iv_param_idx = 0usize; |
| 256 | let acc_param_idx: Option<usize> = if hdr.params.len() == 2 { Some(1) } else { None }; |
| 257 | let iv_param = hdr.params[iv_param_idx].id; |
| 258 | let iv_ty = hdr.params[iv_param_idx].ty.clone(); |
| 259 | if !matches!(iv_ty, IrType::Int(_)) { |
| 260 | return None; |
| 261 | } |
| 262 | let acc_param: Option<ValueId> = acc_param_idx.map(|i| hdr.params[i].id); |
| 263 | |
| 264 | // ---- find preheader ------------------------------------------------- |
| 265 | let header_preds = preds.get(&header)?; |
| 266 | let mut outside: Vec<BlockId> = header_preds |
| 267 | .iter() |
| 268 | .copied() |
| 269 | .filter(|p| !nl.body.contains(p)) |
| 270 | .collect(); |
| 271 | outside.sort_by_key(|b| b.0); |
| 272 | outside.dedup(); |
| 273 | if outside.len() != 1 { |
| 274 | return None; |
| 275 | } |
| 276 | let preheader = outside[0]; |
| 277 | |
| 278 | // Preheader must branch to header with exactly N args matching the |
| 279 | // header's param count. The IV arg is a const; the accumulator init |
| 280 | // can be any SSA value (typically a ConstInt(0) for sum, but we |
| 281 | // don't enforce that — we only need to capture the value to feed |
| 282 | // into iteration 0's body substitution). |
| 283 | let ph_blk = func.block(preheader); |
| 284 | let expected_ph_args = hdr.params.len(); |
| 285 | let (iv_init, acc_init) = match &ph_blk.terminator { |
| 286 | Some(Terminator::Branch(dest, args)) |
| 287 | if *dest == header && args.len() == expected_ph_args => |
| 288 | { |
| 289 | let iv_init = resolve_const_int(func, args[iv_param_idx])?; |
| 290 | let acc_init = acc_param_idx.map(|i| args[i]); |
| 291 | (iv_init, acc_init) |
| 292 | } |
| 293 | _ => return None, |
| 294 | }; |
| 295 | |
| 296 | // ---- find the comparison block and exit ---------------------------- |
| 297 | // |
| 298 | // Two cases: |
| 299 | // (a) 2-block loop: header IS the comparison block. |
| 300 | // (b) 4-block loop (canonical Fortran DO): header is a transparent |
| 301 | // relay block (0 instructions, br → cmp_block). cmp_block has |
| 302 | // the ICmp + CondBranch. |
| 303 | // |
| 304 | // In both cases we need (cmp_block, iv_bound, exit, exit_args). |
| 305 | |
| 306 | let (cmp_block, iv_bound, exit, exit_args) = { |
| 307 | if let Some((bound, ex, args)) = detect_bound_and_exit(func, header, iv_param) { |
| 308 | // Case (a): header has the comparison directly. |
| 309 | (header, bound, ex, args) |
| 310 | } else { |
| 311 | // Case (b): header must be a transparent relay. |
| 312 | let hdr_blk = func.block(header); |
| 313 | if !hdr_blk.insts.is_empty() { |
| 314 | return None; |
| 315 | } // must have 0 instructions |
| 316 | let relay_target = match &hdr_blk.terminator { |
| 317 | Some(Terminator::Branch(t, args)) if args.is_empty() => *t, |
| 318 | _ => return None, |
| 319 | }; |
| 320 | if !nl.body.contains(&relay_target) { |
| 321 | return None; |
| 322 | } |
| 323 | let (bound, ex, args) = detect_bound_and_exit(func, relay_target, iv_param)?; |
| 324 | (relay_target, bound, ex, args) |
| 325 | } |
| 326 | }; |
| 327 | |
| 328 | if nl.body.contains(&exit) { |
| 329 | return None; |
| 330 | } |
| 331 | |
| 332 | // ---- latch: stride-1 increment, passes iv (and optionally acc) |
| 333 | // back to header ------------------------------------------------ |
| 334 | let latch_info = check_latch(func, latch, header, iv_param, iv_param_idx, acc_param_idx)?; |
| 335 | |
| 336 | // ---- compute body_blocks: body − {header, cmp_block, latch} -------- |
| 337 | // |
| 338 | // Walk from cmp_block's true-successor to latch (exclusive) in CFG |
| 339 | // order. The sub-CFG must be a linear chain — any branch or divergence |
| 340 | // means the loop has internal control flow we don't handle here. |
| 341 | let body_blocks = collect_body_chain(func, cmp_block, latch, &nl.body)?; |
| 342 | |
| 343 | let is_do_concurrent = is_do_concurrent_loop(func, header, cmp_block, &body_blocks, latch); |
| 344 | if !is_do_concurrent && body_blocks.iter().any(|&bb| block_contains_load(func, bb)) { |
| 345 | return None; |
| 346 | } |
| 347 | |
| 348 | // ---- no loop-defined values escape outside the loop ---------------- |
| 349 | // |
| 350 | // We must check ALL blocks that will be pruned after unrolling — not just |
| 351 | // body_blocks. In particular, `header` carries a block param (the IV) that |
| 352 | // mem2reg may have threaded into OUTER loop latches via dominance-frontier |
| 353 | // placement. If that param is used outside nl.body, the header's block param |
| 354 | // becomes undefined after we discard the header block. |
| 355 | for &bb in body_blocks |
| 356 | .iter() |
| 357 | .chain(std::iter::once(&header)) |
| 358 | .chain(std::iter::once(&cmp_block)) |
| 359 | .chain(std::iter::once(&latch)) |
| 360 | { |
| 361 | if has_escaping_values(func, bb, &nl.body, preds) { |
| 362 | return None; |
| 363 | } |
| 364 | } |
| 365 | |
| 366 | // ---- size check ------------------------------------------------------ |
| 367 | let total_insts: usize = body_blocks.iter().map(|&b| func.block(b).insts.len()).sum(); |
| 368 | if total_insts > BODY_SIZE_MAX { |
| 369 | return None; |
| 370 | } |
| 371 | |
| 372 | let full_unroll_max = if is_do_concurrent { |
| 373 | DO_CONCURRENT_FULL_UNROLL_MAX |
| 374 | } else { |
| 375 | FULL_UNROLL_MAX |
| 376 | }; |
| 377 | |
| 378 | // ---- trip count ------------------------------------------------------ |
| 379 | let trip_count_i64 = iv_bound - iv_init + 1; |
| 380 | if trip_count_i64 <= 0 { |
| 381 | return None; |
| 382 | } |
| 383 | if trip_count_i64 > full_unroll_max { |
| 384 | return None; |
| 385 | } |
| 386 | |
| 387 | let reduction = match (acc_param, acc_init, latch_info.latch_acc_value) { |
| 388 | (Some(p), Some(init), Some(latch_val)) => { |
| 389 | // The latch accumulator value must be defined inside one |
| 390 | // of the body_blocks — otherwise it lives in the latch |
| 391 | // (which the unroller doesn't clone) and we'd drop the |
| 392 | // reduction op. Bail in that case rather than miscompile. |
| 393 | let defined_in_body = body_blocks |
| 394 | .iter() |
| 395 | .any(|&b| func.block(b).insts.iter().any(|i| i.id == latch_val)); |
| 396 | if !defined_in_body { |
| 397 | return None; |
| 398 | } |
| 399 | Some(ReductionInfo { |
| 400 | acc_param: p, |
| 401 | acc_init: init, |
| 402 | latch_acc_value: latch_val, |
| 403 | }) |
| 404 | } |
| 405 | (None, None, None) => None, |
| 406 | // Inconsistent state — fall through and reject the loop. Should |
| 407 | // not happen because the lane indices line up at the top. |
| 408 | _ => return None, |
| 409 | }; |
| 410 | |
| 411 | Some(LoopShape { |
| 412 | preheader, |
| 413 | header, |
| 414 | cmp_block, |
| 415 | body_blocks, |
| 416 | latch, |
| 417 | exit, |
| 418 | iv_param, |
| 419 | iv_ty, |
| 420 | iv_init, |
| 421 | iv_bound, |
| 422 | trip_count: trip_count_i64 as usize, |
| 423 | exit_args, |
| 424 | reduction, |
| 425 | }) |
| 426 | } |
| 427 | |
| 428 | /// Walk from the comparison block's true-successor to the latch, collecting |
| 429 | /// the "body" blocks in order. Returns None if the sub-CFG is not a linear |
| 430 | /// chain (no branches to outside, no divergence). |
| 431 | /// |
| 432 | /// Two structural cases: |
| 433 | /// |
| 434 | /// (a) 2-block loop — the cmp_block's true-successor IS the latch. The latch |
| 435 | /// contains both the body instructions and the IV increment. We return |
| 436 | /// `[latch]` so the body is cloned from the latch's instructions. |
| 437 | /// The IV increment (`iadd iv, 1`) will be included in the clone but its |
| 438 | /// result is unused in the cloned block (the terminator is rewritten to |
| 439 | /// jump to the next iteration, discarding the iadd result). DCE cleans it. |
| 440 | /// |
| 441 | /// (b) 4-block loop — cmp_block's true-successor is a separate body block |
| 442 | /// distinct from the latch. We walk the chain through the latch and |
| 443 | /// include it in the clone set because real lowered loops may still |
| 444 | /// carry side effects there in addition to the IV increment. |
| 445 | fn collect_body_chain( |
| 446 | func: &Function, |
| 447 | cmp_block: BlockId, |
| 448 | latch: BlockId, |
| 449 | loop_body: &HashSet<BlockId>, |
| 450 | ) -> Option<Vec<BlockId>> { |
| 451 | // Find the true-successor of cmp_block (the first body block). |
| 452 | let first = match &func.block(cmp_block).terminator { |
| 453 | Some(Terminator::CondBranch { true_dest, .. }) => *true_dest, |
| 454 | _ => return None, |
| 455 | }; |
| 456 | if !loop_body.contains(&first) { |
| 457 | return None; |
| 458 | } |
| 459 | |
| 460 | if first == latch { |
| 461 | // Case (a): the cmp's true-successor is the latch itself. |
| 462 | // The latch IS the body — include it in the clone chain. |
| 463 | return Some(vec![latch]); |
| 464 | } |
| 465 | |
| 466 | // Case (b): walk the chain from first through latch. The latch often |
| 467 | // contains only the IV increment, but real fused/lowered loops can keep |
| 468 | // user-visible side effects there too, so excluding it is unsound. |
| 469 | let mut chain = Vec::new(); |
| 470 | let mut cur = first; |
| 471 | loop { |
| 472 | if !loop_body.contains(&cur) { |
| 473 | return None; |
| 474 | } |
| 475 | chain.push(cur); |
| 476 | if cur == latch { |
| 477 | break; |
| 478 | } |
| 479 | let blk = func.block(cur); |
| 480 | // Each block in the chain must have no params (no join point). |
| 481 | if !blk.params.is_empty() { |
| 482 | return None; |
| 483 | } |
| 484 | // Its terminator must be an unconditional branch to the next block. |
| 485 | match &blk.terminator { |
| 486 | Some(Terminator::Branch(next, args)) if args.is_empty() => { |
| 487 | cur = *next; |
| 488 | } |
| 489 | _ => return None, |
| 490 | } |
| 491 | if chain.len() > 4 { |
| 492 | return None; |
| 493 | } // safety limit |
| 494 | } |
| 495 | Some(chain) |
| 496 | } |
| 497 | |
| 498 | /// Delegate to shared loop utility. |
| 499 | fn resolve_const_int(func: &Function, vid: ValueId) -> Option<i64> { |
| 500 | super::loop_utils::resolve_const_int(func, vid) |
| 501 | } |
| 502 | |
| 503 | /// Inspect the header to find the loop bound and exit block. |
| 504 | /// |
| 505 | /// Accepts: |
| 506 | /// - `ICmp(Sle, iv, const(hi))` → bound = hi, true_dest = latch, false_dest = exit |
| 507 | /// - `ICmp(Slt, iv, const(hi))` → bound = hi-1 (exclusive → inclusive), same |
| 508 | /// - `ICmp(Sge, const(hi), iv)` → bound = hi (commuted) |
| 509 | /// - `ICmp(Sgt, const(hi), iv)` → bound = hi-1 (commuted exclusive) |
| 510 | /// |
| 511 | /// Returns `(inclusive_bound, exit_block_id, exit_args)`. |
| 512 | fn detect_bound_and_exit( |
| 513 | func: &Function, |
| 514 | header: BlockId, |
| 515 | iv: ValueId, |
| 516 | ) -> Option<(i64, BlockId, Vec<ValueId>)> { |
| 517 | let hdr = func.block(header); |
| 518 | |
| 519 | // Find the single ICmp that involves the IV. |
| 520 | let mut cmp_id: Option<ValueId> = None; |
| 521 | let mut bound: Option<i64> = None; |
| 522 | let mut is_lt = false; // true when we need bound-1 |
| 523 | |
| 524 | for inst in &hdr.insts { |
| 525 | if let InstKind::ICmp(op, a, b) = &inst.kind { |
| 526 | let other = if *a == iv { |
| 527 | *b |
| 528 | } else if *b == iv { |
| 529 | *a |
| 530 | } else { |
| 531 | continue; |
| 532 | }; |
| 533 | let c = resolve_const_int(func, other)?; |
| 534 | match op { |
| 535 | CmpOp::Le => { |
| 536 | bound = Some(c); |
| 537 | is_lt = false; |
| 538 | } |
| 539 | CmpOp::Lt => { |
| 540 | bound = Some(c); |
| 541 | is_lt = true; |
| 542 | } |
| 543 | CmpOp::Ge => { |
| 544 | bound = Some(c); |
| 545 | is_lt = false; |
| 546 | } // iv >= c means upper = c when commuted |
| 547 | CmpOp::Gt => { |
| 548 | bound = Some(c); |
| 549 | is_lt = true; |
| 550 | } // iv > c means upper = c-1 commuted |
| 551 | _ => return None, |
| 552 | } |
| 553 | cmp_id = Some(inst.id); |
| 554 | break; |
| 555 | } |
| 556 | } |
| 557 | let cmp_id = cmp_id?; |
| 558 | let raw_bound = bound?; |
| 559 | let inclusive_bound = if is_lt { raw_bound - 1 } else { raw_bound }; |
| 560 | |
| 561 | // Terminator must be CondBranch on that cmp. |
| 562 | match &hdr.terminator { |
| 563 | Some(Terminator::CondBranch { |
| 564 | cond, |
| 565 | true_dest, |
| 566 | true_args, |
| 567 | false_dest, |
| 568 | false_args, |
| 569 | }) if *cond == cmp_id => { |
| 570 | // true → body (latch), false → exit. |
| 571 | // Check which side is the latch. We don't have the latch id here; |
| 572 | // we'll accept either ordering and return the "false" side as exit. |
| 573 | // The caller verifies that the exit is not in the loop body. |
| 574 | let _ = true_args; |
| 575 | Some((inclusive_bound, *false_dest, false_args.clone())) |
| 576 | } |
| 577 | _ => None, |
| 578 | } |
| 579 | } |
| 580 | |
| 581 | /// Verify that the latch block increments the IV by 1 and passes only |
| 582 | /// the new IV value back to the header (stride-1 check). |
| 583 | /// |
| 584 | /// Accepted pattern: |
| 585 | /// ```text |
| 586 | /// %i_next = iadd %iv, const(1) |
| 587 | /// br header(%i_next) |
| 588 | /// ``` |
| 589 | /// Validate the latch terminator and (when applicable) capture the |
| 590 | /// reduction lane. |
| 591 | /// |
| 592 | /// Accepts either: |
| 593 | /// * `br header(%iv_next)` — pure-IV loop, returns |
| 594 | /// `Some(LatchInfo { latch_acc_value: None })`. |
| 595 | /// * `br header(%iv_next, %new_acc)` — reduction loop, where |
| 596 | /// `acc_lane_idx` (passed by the caller as 0 or 1) tells us which |
| 597 | /// branch arg index carries the accumulator. Returns |
| 598 | /// `Some(LatchInfo { latch_acc_value: Some(%new_acc) })`. |
| 599 | /// |
| 600 | /// In both cases the IV stride must be `iadd iv, const(1)` defined in |
| 601 | /// the latch. |
| 602 | fn check_latch( |
| 603 | func: &Function, |
| 604 | latch: BlockId, |
| 605 | header: BlockId, |
| 606 | iv: ValueId, |
| 607 | iv_lane_idx: usize, |
| 608 | acc_lane_idx: Option<usize>, |
| 609 | ) -> Option<LatchInfo> { |
| 610 | let blk = func.block(latch); |
| 611 | |
| 612 | let expected_args = if acc_lane_idx.is_some() { 2 } else { 1 }; |
| 613 | let args = match &blk.terminator { |
| 614 | Some(Terminator::Branch(dest, args)) |
| 615 | if *dest == header && args.len() == expected_args => |
| 616 | { |
| 617 | args.clone() |
| 618 | } |
| 619 | _ => return None, |
| 620 | }; |
| 621 | |
| 622 | let iv_next = args[iv_lane_idx]; |
| 623 | let acc_value = acc_lane_idx.map(|i| args[i]); |
| 624 | |
| 625 | // The iv lane's value must be defined as `iadd %iv, const(1)` |
| 626 | // inside the latch. |
| 627 | let mut iv_ok = false; |
| 628 | for inst in &blk.insts { |
| 629 | if inst.id == iv_next { |
| 630 | if let InstKind::IAdd(a, b) = &inst.kind { |
| 631 | let (is_iv, other) = if *a == iv { |
| 632 | (true, *b) |
| 633 | } else if *b == iv { |
| 634 | (true, *a) |
| 635 | } else { |
| 636 | (false, *a) |
| 637 | }; |
| 638 | if !is_iv { |
| 639 | return None; |
| 640 | } |
| 641 | if resolve_const_int(func, other) != Some(1) { |
| 642 | return None; |
| 643 | } |
| 644 | iv_ok = true; |
| 645 | } else { |
| 646 | return None; |
| 647 | } |
| 648 | break; |
| 649 | } |
| 650 | } |
| 651 | if !iv_ok { |
| 652 | return None; |
| 653 | } |
| 654 | |
| 655 | Some(LatchInfo { |
| 656 | latch_acc_value: acc_value, |
| 657 | }) |
| 658 | } |
| 659 | |
| 660 | #[derive(Debug, Clone, Copy)] |
| 661 | struct LatchInfo { |
| 662 | latch_acc_value: Option<ValueId>, |
| 663 | } |
| 664 | |
| 665 | /// Return true if any instruction result defined in the latch is used |
| 666 | /// in a block that is NOT part of the loop body. |
| 667 | /// |
| 668 | /// This ensures no latch-computed value leaks out into the exit or |
| 669 | /// subsequent blocks. (Such loops have loop-carried dependencies that |
| 670 | /// require different handling.) |
| 671 | fn has_escaping_values( |
| 672 | func: &Function, |
| 673 | latch: BlockId, |
| 674 | body: &HashSet<BlockId>, |
| 675 | preds: &HashMap<BlockId, Vec<BlockId>>, |
| 676 | ) -> bool { |
| 677 | // Collect all ValueIds defined in the block — both block params and |
| 678 | // instruction results. Block params matter because mem2reg places them at |
| 679 | // dominance-frontier blocks; if such a param escapes into an outer-loop |
| 680 | // latch it becomes undefined after unrolling removes the block. |
| 681 | let latch_blk = func.block(latch); |
| 682 | let mut latch_defs: HashSet<ValueId> = HashSet::new(); |
| 683 | for bp in &latch_blk.params { |
| 684 | latch_defs.insert(bp.id); |
| 685 | } |
| 686 | for inst in &latch_blk.insts { |
| 687 | latch_defs.insert(inst.id); |
| 688 | } |
| 689 | |
| 690 | if latch_defs.is_empty() { |
| 691 | return false; |
| 692 | } |
| 693 | |
| 694 | // Check every block outside the loop body for uses of those values. |
| 695 | for block in &func.blocks { |
| 696 | if body.contains(&block.id) { |
| 697 | continue; |
| 698 | } |
| 699 | // Check instruction operands. |
| 700 | for inst in &block.insts { |
| 701 | for use_id in inst_uses(&inst.kind) { |
| 702 | if latch_defs.contains(&use_id) { |
| 703 | return true; |
| 704 | } |
| 705 | } |
| 706 | } |
| 707 | // Check terminator operands. |
| 708 | if let Some(term) = &block.terminator { |
| 709 | for &use_id in terminator_uses_vec(term).iter() { |
| 710 | if latch_defs.contains(&use_id) { |
| 711 | return true; |
| 712 | } |
| 713 | } |
| 714 | } |
| 715 | // Also check block params of outside blocks that receive args from latch. |
| 716 | // (If the latch branches to an outside block with args, those args |
| 717 | // might carry latch values. But `check_latch` already restricts the |
| 718 | // latch terminator to `br header(%iv_next)`, so this can't happen for |
| 719 | // loops we accept. Kept as belt-and-suspenders.) |
| 720 | } |
| 721 | let _ = preds; |
| 722 | false |
| 723 | } |
| 724 | |
| 725 | fn block_contains_load(func: &Function, block: BlockId) -> bool { |
| 726 | func.block(block) |
| 727 | .insts |
| 728 | .iter() |
| 729 | .any(|inst| matches!(inst.kind, InstKind::Load(_))) |
| 730 | } |
| 731 | |
| 732 | /// Collect all ValueId operands in a terminator into a Vec. |
| 733 | /// (Mirrors `terminator_uses` from walk.rs but returns owned storage.) |
| 734 | fn terminator_uses_vec(term: &Terminator) -> Vec<ValueId> { |
| 735 | let mut out = Vec::new(); |
| 736 | match term { |
| 737 | Terminator::Return(Some(v)) => out.push(*v), |
| 738 | Terminator::Branch(_, args) => out.extend(args), |
| 739 | Terminator::CondBranch { |
| 740 | cond, |
| 741 | true_args, |
| 742 | false_args, |
| 743 | .. |
| 744 | } => { |
| 745 | out.push(*cond); |
| 746 | out.extend(true_args); |
| 747 | out.extend(false_args); |
| 748 | } |
| 749 | Terminator::Switch { selector, .. } => out.push(*selector), |
| 750 | _ => {} |
| 751 | } |
| 752 | out |
| 753 | } |
| 754 | |
| 755 | /// Import inst_uses from the walk module for convenience. |
| 756 | fn inst_uses(kind: &InstKind) -> Vec<ValueId> { |
| 757 | crate::ir::walk::inst_uses(kind) |
| 758 | } |
| 759 | |
| 760 | // --------------------------------------------------------------------------- |
| 761 | // Transformation |
| 762 | // --------------------------------------------------------------------------- |
| 763 | |
| 764 | fn do_unroll(func: &mut Function, shape: LoopShape) { |
| 765 | // Collect per-body-block instructions once (before we mutate func). |
| 766 | let body_snapshots: Vec<Vec<Inst>> = shape |
| 767 | .body_blocks |
| 768 | .iter() |
| 769 | .map(|&b| func.block(b).insts.clone()) |
| 770 | .collect(); |
| 771 | |
| 772 | // Determine the IV width. |
| 773 | let iv_width = match &shape.iv_ty { |
| 774 | IrType::Int(w) => *w, |
| 775 | _ => IntWidth::I32, |
| 776 | }; |
| 777 | |
| 778 | // For each iteration, we create one new block per body block, cloned |
| 779 | // with the IV substituted by a constant. The chain is: |
| 780 | // |
| 781 | // preheader → iter0_body0 → iter0_body1 → ... → iter0_bodyN-1 |
| 782 | // → iter1_body0 → ... |
| 783 | // → iter{TC-1}_body0 → ... |
| 784 | // → exit |
| 785 | // |
| 786 | // If body_blocks is empty (cmp_block directly targets latch), the |
| 787 | // iterations have no instructions and we just wire preheader → exit. |
| 788 | |
| 789 | // Accumulate the block ids for each iteration so we can wire terminators. |
| 790 | let tc = shape.trip_count; |
| 791 | let nb = body_snapshots.len(); |
| 792 | |
| 793 | let mut iter_blocks: Vec<Vec<BlockId>> = Vec::with_capacity(tc); |
| 794 | let mut iter_substs: Vec<HashMap<ValueId, ValueId>> = Vec::with_capacity(tc); |
| 795 | |
| 796 | // Reduction threading: prev_acc carries iteration k-1's latch |
| 797 | // accumulator value into iteration k's body substitution. For |
| 798 | // iteration 0 it's seeded from the preheader's branch arg; after |
| 799 | // each iteration completes, it's updated to that iter's body |
| 800 | // output (looked up through the iter's subst by the original |
| 801 | // latch_acc_value's ID). |
| 802 | let mut prev_acc: Option<ValueId> = shape.reduction.as_ref().map(|r| r.acc_init); |
| 803 | |
| 804 | for k in 0..tc { |
| 805 | let iv_val = shape.iv_init + k as i64; |
| 806 | |
| 807 | // Build substitution map seeded with iv_param → const(iv_val) |
| 808 | // and (when present) acc_param → previous iteration's latch |
| 809 | // accumulator value. |
| 810 | let mut subst: HashMap<ValueId, ValueId> = HashMap::new(); |
| 811 | |
| 812 | // Emit the IV constant (will be placed in each new block's first inst). |
| 813 | let iv_const_id = func.next_value_id(); |
| 814 | subst.insert(shape.iv_param, iv_const_id); |
| 815 | |
| 816 | if let (Some(r), Some(prev)) = (shape.reduction.as_ref(), prev_acc) { |
| 817 | subst.insert(r.acc_param, prev); |
| 818 | } |
| 819 | |
| 820 | let mut blk_ids = Vec::with_capacity(nb.max(1)); |
| 821 | if nb == 0 { |
| 822 | // No body blocks — create a single pass-through block. |
| 823 | let bid = func.create_block(&format!("unroll_k{}", k)); |
| 824 | blk_ids.push(bid); |
| 825 | } else { |
| 826 | for (bi, snap) in body_snapshots.iter().enumerate() { |
| 827 | let bid = func.create_block(&format!("unroll_k{}b{}", k, bi)); |
| 828 | // First block of this iter: emit IV constant. |
| 829 | if bi == 0 { |
| 830 | let iv_inst = Inst { |
| 831 | id: iv_const_id, |
| 832 | kind: InstKind::ConstInt(iv_val as i128, iv_width), |
| 833 | ty: shape.iv_ty.clone(), |
| 834 | span: dummy_span(), |
| 835 | }; |
| 836 | func.block_mut(bid).insts.push(iv_inst); |
| 837 | } |
| 838 | // Clone body instructions. |
| 839 | for inst in snap { |
| 840 | let new_id = func.next_value_id(); |
| 841 | subst.insert(inst.id, new_id); |
| 842 | let new_kind = remap_kind(&inst.kind, &subst); |
| 843 | func.block_mut(bid).insts.push(Inst { |
| 844 | id: new_id, |
| 845 | kind: new_kind, |
| 846 | ty: inst.ty.clone(), |
| 847 | span: inst.span, |
| 848 | }); |
| 849 | } |
| 850 | blk_ids.push(bid); |
| 851 | } |
| 852 | } |
| 853 | // After this iteration's body has been cloned, look up the |
| 854 | // post-substitution form of the latch's accumulator value — |
| 855 | // that becomes the seed for the next iteration's acc_param |
| 856 | // substitution. If the latch's acc value wasn't defined in |
| 857 | // this body (e.g. an invariant or pre-loop value), fall back |
| 858 | // to the original ID; the unroller's other safety checks |
| 859 | // ensure that's still well-defined post-pruning. |
| 860 | if let Some(r) = shape.reduction.as_ref() { |
| 861 | prev_acc = Some( |
| 862 | subst |
| 863 | .get(&r.latch_acc_value) |
| 864 | .copied() |
| 865 | .unwrap_or(r.latch_acc_value), |
| 866 | ); |
| 867 | } |
| 868 | iter_blocks.push(blk_ids); |
| 869 | iter_substs.push(subst); |
| 870 | } |
| 871 | |
| 872 | // Wire terminators within each iteration (intra-iter body chain) |
| 873 | // and across iterations (last body block of iter k → first of iter k+1). |
| 874 | let original_body_next: Vec<BlockId> = { |
| 875 | // For each body block b, what block does it branch to? |
| 876 | // (It's either the next body block or the latch.) |
| 877 | let mut nexts = Vec::new(); |
| 878 | for (bi, &b) in shape.body_blocks.iter().enumerate() { |
| 879 | let next = if bi + 1 < shape.body_blocks.len() { |
| 880 | shape.body_blocks[bi + 1] |
| 881 | } else { |
| 882 | shape.latch |
| 883 | }; |
| 884 | let _ = b; |
| 885 | nexts.push(next); |
| 886 | } |
| 887 | nexts |
| 888 | }; |
| 889 | |
| 890 | for k in 0..tc { |
| 891 | let nb_actual = iter_blocks[k].len(); |
| 892 | // Determine which block this iteration chains to after its last body block. |
| 893 | let after_iter = if k + 1 < tc { |
| 894 | // First block of next iteration. |
| 895 | iter_blocks[k + 1][0] |
| 896 | } else { |
| 897 | // Last iteration → exit. |
| 898 | shape.exit |
| 899 | }; |
| 900 | |
| 901 | // Args to pass when branching to the exit block. For |
| 902 | // reduction loops, the original `exit_args` references the |
| 903 | // header's acc_param — pruned along with the rest of the |
| 904 | // header. Replace the acc_param entry with the final |
| 905 | // iteration's accumulator value (`prev_acc`, set up after |
| 906 | // the per-iter cloning loop ran iteration N-1). Other |
| 907 | // entries are left alone, falling through to the verifier |
| 908 | // if they reference values that would survive (today's only |
| 909 | // supported shape has exactly the acc lane). |
| 910 | let exit_args = if after_iter == shape.exit { |
| 911 | if let (Some(r), Some(final_acc)) = (shape.reduction.as_ref(), prev_acc) { |
| 912 | shape |
| 913 | .exit_args |
| 914 | .iter() |
| 915 | .map(|&v| if v == r.acc_param { final_acc } else { v }) |
| 916 | .collect() |
| 917 | } else { |
| 918 | shape.exit_args.clone() |
| 919 | } |
| 920 | } else { |
| 921 | vec![] |
| 922 | }; |
| 923 | |
| 924 | if nb == 0 { |
| 925 | // No body — the single pass-through block jumps to after_iter. |
| 926 | func.block_mut(iter_blocks[k][0]).terminator = |
| 927 | Some(Terminator::Branch(after_iter, exit_args)); |
| 928 | } else { |
| 929 | for bi in 0..nb_actual { |
| 930 | let cur_blk = iter_blocks[k][bi]; |
| 931 | let (next_blk, args) = if bi + 1 < nb_actual { |
| 932 | (iter_blocks[k][bi + 1], vec![]) |
| 933 | } else { |
| 934 | // Last body block of iter k → after_iter. |
| 935 | (after_iter, exit_args.clone()) |
| 936 | }; |
| 937 | let _ = original_body_next[bi]; // consumed above |
| 938 | func.block_mut(cur_blk).terminator = Some(Terminator::Branch(next_blk, args)); |
| 939 | } |
| 940 | } |
| 941 | } |
| 942 | |
| 943 | // Rewrite preheader terminator: skip the loop header, jump to first iter. |
| 944 | let (first_block, preheader_args) = if tc > 0 { |
| 945 | (iter_blocks[0][0], vec![]) |
| 946 | } else { |
| 947 | // Zero-trip loop: jump directly to exit with its required args. |
| 948 | (shape.exit, shape.exit_args.clone()) |
| 949 | }; |
| 950 | func.block_mut(shape.preheader).terminator = |
| 951 | Some(Terminator::Branch(first_block, preheader_args)); |
| 952 | |
| 953 | // Mark loop control blocks as Unreachable → prune_unreachable removes them. |
| 954 | func.block_mut(shape.header).terminator = Some(Terminator::Unreachable); |
| 955 | func.block_mut(shape.cmp_block).terminator = Some(Terminator::Unreachable); |
| 956 | func.block_mut(shape.latch).terminator = Some(Terminator::Unreachable); |
| 957 | for &b in &shape.body_blocks { |
| 958 | func.block_mut(b).terminator = Some(Terminator::Unreachable); |
| 959 | } |
| 960 | } |
| 961 | |
| 962 | /// Deep-clone an InstKind, substituting all ValueId operands according to `subst`. |
| 963 | fn remap_kind(kind: &InstKind, subst: &HashMap<ValueId, ValueId>) -> InstKind { |
| 964 | let r = |v: &ValueId| *subst.get(v).unwrap_or(v); |
| 965 | match kind { |
| 966 | // Constants — no operands. |
| 967 | InstKind::ConstInt(v, w) => InstKind::ConstInt(*v, *w), |
| 968 | InstKind::ConstFloat(v, w) => InstKind::ConstFloat(*v, *w), |
| 969 | InstKind::ConstBool(v) => InstKind::ConstBool(*v), |
| 970 | InstKind::ConstString(v) => InstKind::ConstString(v.clone()), |
| 971 | InstKind::Undef(t) => InstKind::Undef(t.clone()), |
| 972 | InstKind::GlobalAddr(s) => InstKind::GlobalAddr(s.clone()), |
| 973 | |
| 974 | // Integer arithmetic. |
| 975 | InstKind::IAdd(a, b) => InstKind::IAdd(r(a), r(b)), |
| 976 | InstKind::ISub(a, b) => InstKind::ISub(r(a), r(b)), |
| 977 | InstKind::IMul(a, b) => InstKind::IMul(r(a), r(b)), |
| 978 | InstKind::IDiv(a, b) => InstKind::IDiv(r(a), r(b)), |
| 979 | InstKind::IMod(a, b) => InstKind::IMod(r(a), r(b)), |
| 980 | InstKind::INeg(a) => InstKind::INeg(r(a)), |
| 981 | |
| 982 | // Float arithmetic. |
| 983 | InstKind::FAdd(a, b) => InstKind::FAdd(r(a), r(b)), |
| 984 | InstKind::FSub(a, b) => InstKind::FSub(r(a), r(b)), |
| 985 | InstKind::FMul(a, b) => InstKind::FMul(r(a), r(b)), |
| 986 | InstKind::FDiv(a, b) => InstKind::FDiv(r(a), r(b)), |
| 987 | InstKind::FNeg(a) => InstKind::FNeg(r(a)), |
| 988 | InstKind::FAbs(a) => InstKind::FAbs(r(a)), |
| 989 | InstKind::FSqrt(a) => InstKind::FSqrt(r(a)), |
| 990 | InstKind::FPow(a, b) => InstKind::FPow(r(a), r(b)), |
| 991 | |
| 992 | // Comparison. |
| 993 | InstKind::ICmp(op, a, b) => InstKind::ICmp(*op, r(a), r(b)), |
| 994 | InstKind::FCmp(op, a, b) => InstKind::FCmp(*op, r(a), r(b)), |
| 995 | |
| 996 | // Logic. |
| 997 | InstKind::And(a, b) => InstKind::And(r(a), r(b)), |
| 998 | InstKind::Or(a, b) => InstKind::Or(r(a), r(b)), |
| 999 | InstKind::Not(a) => InstKind::Not(r(a)), |
| 1000 | |
| 1001 | // Select. |
| 1002 | InstKind::Select(c, t, f) => InstKind::Select(r(c), r(t), r(f)), |
| 1003 | |
| 1004 | // Bitwise. |
| 1005 | InstKind::BitAnd(a, b) => InstKind::BitAnd(r(a), r(b)), |
| 1006 | InstKind::BitOr(a, b) => InstKind::BitOr(r(a), r(b)), |
| 1007 | InstKind::BitXor(a, b) => InstKind::BitXor(r(a), r(b)), |
| 1008 | InstKind::BitNot(a) => InstKind::BitNot(r(a)), |
| 1009 | InstKind::Shl(a, b) => InstKind::Shl(r(a), r(b)), |
| 1010 | InstKind::LShr(a, b) => InstKind::LShr(r(a), r(b)), |
| 1011 | InstKind::AShr(a, b) => InstKind::AShr(r(a), r(b)), |
| 1012 | InstKind::CountLeadingZeros(a) => InstKind::CountLeadingZeros(r(a)), |
| 1013 | InstKind::CountTrailingZeros(a) => InstKind::CountTrailingZeros(r(a)), |
| 1014 | InstKind::PopCount(a) => InstKind::PopCount(r(a)), |
| 1015 | |
| 1016 | // Conversions. |
| 1017 | InstKind::IntToFloat(a, w) => InstKind::IntToFloat(r(a), *w), |
| 1018 | InstKind::FloatToInt(a, w) => InstKind::FloatToInt(r(a), *w), |
| 1019 | InstKind::FloatExtend(a, w) => InstKind::FloatExtend(r(a), *w), |
| 1020 | InstKind::FloatTrunc(a, w) => InstKind::FloatTrunc(r(a), *w), |
| 1021 | InstKind::IntExtend(a, w, s) => InstKind::IntExtend(r(a), *w, *s), |
| 1022 | InstKind::IntTrunc(a, w) => InstKind::IntTrunc(r(a), *w), |
| 1023 | InstKind::PtrToInt(a) => InstKind::PtrToInt(r(a)), |
| 1024 | InstKind::IntToPtr(a, ty) => InstKind::IntToPtr(r(a), ty.clone()), |
| 1025 | |
| 1026 | // Memory. |
| 1027 | InstKind::Alloca(t) => InstKind::Alloca(t.clone()), |
| 1028 | InstKind::Load(a) => InstKind::Load(r(a)), |
| 1029 | InstKind::Store(v, p) => InstKind::Store(r(v), r(p)), |
| 1030 | InstKind::GetElementPtr(base, idxs) => { |
| 1031 | InstKind::GetElementPtr(r(base), idxs.iter().map(&r).collect()) |
| 1032 | } |
| 1033 | |
| 1034 | // Calls. |
| 1035 | InstKind::Call(f, args) => InstKind::Call(f.clone(), args.iter().map(&r).collect()), |
| 1036 | InstKind::RuntimeCall(f, args) => { |
| 1037 | InstKind::RuntimeCall(f.clone(), args.iter().map(&r).collect()) |
| 1038 | } |
| 1039 | |
| 1040 | // Aggregates. |
| 1041 | InstKind::ExtractField(v, idx) => InstKind::ExtractField(r(v), *idx), |
| 1042 | InstKind::InsertField(v, idx, fld) => InstKind::InsertField(r(v), *idx, r(fld)), |
| 1043 | |
| 1044 | // ---- SIMD vector ops ---- |
| 1045 | InstKind::VAdd(a, b) => InstKind::VAdd(r(a), r(b)), |
| 1046 | InstKind::VSub(a, b) => InstKind::VSub(r(a), r(b)), |
| 1047 | InstKind::VMul(a, b) => InstKind::VMul(r(a), r(b)), |
| 1048 | InstKind::VDiv(a, b) => InstKind::VDiv(r(a), r(b)), |
| 1049 | InstKind::VNeg(a) => InstKind::VNeg(r(a)), |
| 1050 | InstKind::VAbs(a) => InstKind::VAbs(r(a)), |
| 1051 | InstKind::VSqrt(a) => InstKind::VSqrt(r(a)), |
| 1052 | InstKind::VFma(a, b, c) => InstKind::VFma(r(a), r(b), r(c)), |
| 1053 | InstKind::VSelect(m, t, f) => InstKind::VSelect(r(m), r(t), r(f)), |
| 1054 | InstKind::VMin(a, b) => InstKind::VMin(r(a), r(b)), |
| 1055 | InstKind::VMax(a, b) => InstKind::VMax(r(a), r(b)), |
| 1056 | InstKind::VICmp(op, a, b) => InstKind::VICmp(*op, r(a), r(b)), |
| 1057 | InstKind::VFCmp(op, a, b) => InstKind::VFCmp(*op, r(a), r(b)), |
| 1058 | InstKind::VLoad(p) => InstKind::VLoad(r(p)), |
| 1059 | InstKind::VStore(v, p) => InstKind::VStore(r(v), r(p)), |
| 1060 | InstKind::VBitcast(v, ty) => InstKind::VBitcast(r(v), ty.clone()), |
| 1061 | InstKind::VExtract(v, lane) => InstKind::VExtract(r(v), *lane), |
| 1062 | InstKind::VInsert(v, lane, s) => InstKind::VInsert(r(v), *lane, r(s)), |
| 1063 | InstKind::VBroadcast(s) => InstKind::VBroadcast(r(s)), |
| 1064 | InstKind::VReduceSum(v) => InstKind::VReduceSum(r(v)), |
| 1065 | InstKind::VReduceMin(v) => InstKind::VReduceMin(r(v)), |
| 1066 | InstKind::VReduceMax(v) => InstKind::VReduceMax(r(v)), |
| 1067 | } |
| 1068 | } |
| 1069 | |
| 1070 | // --------------------------------------------------------------------------- |
| 1071 | // Partial unrolling |
| 1072 | // --------------------------------------------------------------------------- |
| 1073 | // |
| 1074 | // Loops with statically known trip counts above FULL_UNROLL_MAX don't |
| 1075 | // benefit from full unrolling (code bloat), but a U-way partial unroll |
| 1076 | // (clone body U-1 additional times, step IV by U) can still expose ILP |
| 1077 | // and reduce loop overhead. v0 handles the simplest shape: |
| 1078 | // |
| 1079 | // * 2-block loop (header + latch). |
//  * One or two header block params: the IV alone, or IV plus a
//    reduction accumulator (threaded through the clones via
//    `PartialReductionInfo`).
| 1081 | // * Static trip count > FULL_UNROLL_MAX, divisible by `U`. |
| 1082 | // * Body instruction count × U ≤ PARTIAL_UNROLL_BODY_BUDGET. |
| 1083 | // |
| 1084 | // The transform clones the latch's body insts U-1 more times with |
| 1085 | // `iv` substituted by a freshly-computed `iv + k` (1 ≤ k < U), then |
| 1086 | // rewrites the latch's `iadd iv, 1` to `iadd iv, U`. |
| 1087 | |
/// Largest unrolled-body size in instructions for partial unrolling.
/// The budget applies to the post-unroll body: `body_inst_count * u`
/// must stay at or below this (checked in `detect_partial_unroll_loop`).
const PARTIAL_UNROLL_BODY_BUDGET: usize = 60;
/// Largest unroll factor we'll apply. Factors are tried from this value
/// downward; the first that divides the trip count and fits the body
/// budget wins.
const PARTIAL_UNROLL_MAX_FACTOR: usize = 4;
| 1092 | |
/// Everything `do_partial_unroll` needs, as discovered by
/// `detect_partial_unroll_loop`: the loop's blocks, the induction
/// variable's identity/type/range, the chosen factor, and the latch's
/// stride instruction to rewrite.
#[derive(Debug)]
struct PartialShape {
    /// Loop header (holds the IV block param and the bound compare).
    header: BlockId,
    /// The single latch block (body instructions + back-edge).
    latch: BlockId,
    /// Header block param acting as the induction variable.
    iv_param: ValueId,
    /// IV type (guaranteed `IrType::Int(_)` by detection).
    iv_ty: IrType,
    /// Constant initial IV value (from the preheader branch arg).
    iv_init: i64,
    /// Constant inclusive upper bound on the IV.
    iv_bound: i64,
    /// Unroll factor.
    u: usize,
    /// The latch's `iadd iv, 1` instruction id.
    iadd_id: ValueId,
    /// The const(1) value feeding the iadd.
    step_const_id: ValueId,
    /// If the loop is a reduction (header has 2 params), this carries
    /// the acc param id and the latch's terminator's new-acc value.
    reduction: Option<PartialReductionInfo>,
}
| 1111 | |
/// Accumulator lane of a reduction loop, for the static partial unroller.
#[derive(Debug, Clone, Copy)]
struct PartialReductionInfo {
    /// Header block param carrying the accumulator into each iteration.
    acc_param: ValueId,
    /// The value the latch passes back as the new accumulator
    /// (terminator's args[1]).
    new_acc: ValueId,
}
| 1119 | |
| 1120 | fn detect_partial_unroll_loop( |
| 1121 | func: &Function, |
| 1122 | nl: &NaturalLoop, |
| 1123 | preds: &HashMap<BlockId, Vec<BlockId>>, |
| 1124 | ) -> Option<PartialShape> { |
| 1125 | if nl.latches.len() != 1 || nl.body.len() != 2 { |
| 1126 | return None; |
| 1127 | } |
| 1128 | let header = nl.header; |
| 1129 | let latch = nl.latches[0]; |
| 1130 | if header == latch { |
| 1131 | return None; |
| 1132 | } |
| 1133 | let hdr = func.block(header); |
| 1134 | if hdr.params.is_empty() || hdr.params.len() > 2 { |
| 1135 | return None; |
| 1136 | } |
| 1137 | let iv_param = hdr.params[0].id; |
| 1138 | let iv_ty = hdr.params[0].ty.clone(); |
| 1139 | if !matches!(iv_ty, IrType::Int(_)) { |
| 1140 | return None; |
| 1141 | } |
| 1142 | let acc_param: Option<ValueId> = if hdr.params.len() == 2 { |
| 1143 | Some(hdr.params[1].id) |
| 1144 | } else { |
| 1145 | None |
| 1146 | }; |
| 1147 | // Header has the cmp + cond_br body/exit. |
| 1148 | let (iv_bound, exit, _exit_args) = detect_bound_and_exit(func, header, iv_param)?; |
| 1149 | if nl.body.contains(&exit) { |
| 1150 | return None; |
| 1151 | } |
| 1152 | // Preheader supplies iv_init. |
| 1153 | let header_preds = preds.get(&header)?; |
| 1154 | let mut outside: Vec<BlockId> = header_preds |
| 1155 | .iter() |
| 1156 | .copied() |
| 1157 | .filter(|p| !nl.body.contains(p)) |
| 1158 | .collect(); |
| 1159 | outside.sort_by_key(|b| b.0); |
| 1160 | outside.dedup(); |
| 1161 | if outside.len() != 1 { |
| 1162 | return None; |
| 1163 | } |
| 1164 | let preheader = outside[0]; |
| 1165 | let ph = func.block(preheader); |
| 1166 | let expected_args = if acc_param.is_some() { 2 } else { 1 }; |
| 1167 | let iv_init = match &ph.terminator { |
| 1168 | Some(Terminator::Branch(d, args)) |
| 1169 | if *d == header && args.len() == expected_args => |
| 1170 | { |
| 1171 | resolve_const_int(func, args[0])? |
| 1172 | } |
| 1173 | _ => return None, |
| 1174 | }; |
| 1175 | // Latch: `... insts ...; iadd iv, 1; br header(iv_next [, new_acc])`. |
| 1176 | let latch_blk = func.block(latch); |
| 1177 | let latch_term_args = match &latch_blk.terminator { |
| 1178 | Some(Terminator::Branch(d, args)) |
| 1179 | if *d == header && args.len() == expected_args => |
| 1180 | { |
| 1181 | args.clone() |
| 1182 | } |
| 1183 | _ => return None, |
| 1184 | }; |
| 1185 | let iv_next = latch_term_args[0]; |
| 1186 | let new_acc = if acc_param.is_some() { |
| 1187 | Some(latch_term_args[1]) |
| 1188 | } else { |
| 1189 | None |
| 1190 | }; |
| 1191 | let iadd_inst = latch_blk.insts.iter().find(|i| i.id == iv_next)?; |
| 1192 | let (lhs, rhs) = match iadd_inst.kind { |
| 1193 | InstKind::IAdd(l, r) => (l, r), |
| 1194 | _ => return None, |
| 1195 | }; |
| 1196 | let step_const_id = if lhs == iv_param { |
| 1197 | rhs |
| 1198 | } else if rhs == iv_param { |
| 1199 | lhs |
| 1200 | } else { |
| 1201 | return None; |
| 1202 | }; |
| 1203 | if resolve_const_int(func, step_const_id) != Some(1) { |
| 1204 | return None; |
| 1205 | } |
| 1206 | // No latch-defined value escapes the loop. |
| 1207 | let body_set: HashSet<BlockId> = nl.body.iter().copied().collect(); |
| 1208 | if has_escaping_values(func, latch, &body_set, preds) { |
| 1209 | return None; |
| 1210 | } |
| 1211 | // Trip count gate: > FULL_UNROLL_MAX (full unroller would have |
| 1212 | // taken it otherwise) and divisible by some U in [2, MAX_FACTOR]. |
| 1213 | let trip = iv_bound - iv_init + 1; |
| 1214 | if trip <= FULL_UNROLL_MAX { |
| 1215 | return None; |
| 1216 | } |
| 1217 | let body_inst_count = latch_blk.insts.len(); |
| 1218 | if body_inst_count == 0 { |
| 1219 | return None; |
| 1220 | } |
| 1221 | // Heuristic: only partial-unroll reductions OR non-reduction |
| 1222 | // loops with at most one store. Multi-store init-style loops |
| 1223 | // (e.g., `a(i) = ...; b(i) = ...; c(i) = ...;`) cloned U-way |
| 1224 | // create too much register pressure across the function and |
| 1225 | // expose latent regalloc fragility (Sprint 18 follow-up). |
| 1226 | let store_count = latch_blk |
| 1227 | .insts |
| 1228 | .iter() |
| 1229 | .filter(|i| matches!(i.kind, InstKind::Store(..))) |
| 1230 | .count(); |
| 1231 | if acc_param.is_none() && store_count > 1 { |
| 1232 | return None; |
| 1233 | } |
| 1234 | // Pick the largest U ≤ MAX_FACTOR that divides trip and respects |
| 1235 | // the body-size budget. |
| 1236 | let mut chosen_u: Option<usize> = None; |
| 1237 | for u in (2..=PARTIAL_UNROLL_MAX_FACTOR).rev() { |
| 1238 | if trip % (u as i64) != 0 { |
| 1239 | continue; |
| 1240 | } |
| 1241 | if body_inst_count * u > PARTIAL_UNROLL_BODY_BUDGET { |
| 1242 | continue; |
| 1243 | } |
| 1244 | chosen_u = Some(u); |
| 1245 | break; |
| 1246 | } |
| 1247 | let u = chosen_u?; |
| 1248 | let reduction = match (acc_param, new_acc) { |
| 1249 | (Some(p), Some(v)) => Some(PartialReductionInfo { |
| 1250 | acc_param: p, |
| 1251 | new_acc: v, |
| 1252 | }), |
| 1253 | (None, None) => None, |
| 1254 | _ => return None, |
| 1255 | }; |
| 1256 | Some(PartialShape { |
| 1257 | header, |
| 1258 | latch, |
| 1259 | iv_param, |
| 1260 | iv_ty, |
| 1261 | iv_init, |
| 1262 | iv_bound, |
| 1263 | u, |
| 1264 | iadd_id: iv_next, |
| 1265 | step_const_id, |
| 1266 | reduction, |
| 1267 | }) |
| 1268 | } |
| 1269 | |
| 1270 | fn do_partial_unroll(func: &mut Function, shape: PartialShape) { |
| 1271 | // Snapshot of latch body BEFORE the iadd. We'll clone these |
| 1272 | // instructions U-1 more times with iv substituted. |
| 1273 | let latch_snapshot: Vec<Inst> = func.block(shape.latch).insts.clone(); |
| 1274 | // Precompute body-without-iadd for cloning. |
| 1275 | let body_inst_clone: Vec<Inst> = latch_snapshot |
| 1276 | .iter() |
| 1277 | .filter(|i| i.id != shape.iadd_id) |
| 1278 | .cloned() |
| 1279 | .collect(); |
| 1280 | let iv_width = match shape.iv_ty { |
| 1281 | IrType::Int(w) => w, |
| 1282 | _ => return, |
| 1283 | }; |
| 1284 | let span = dummy_span(); |
| 1285 | // For each k in 1..U: emit `iv_k = iadd iv, k_const`, then clone |
| 1286 | // every body inst (except iadd) with iv → iv_k. Append all of |
| 1287 | // these BEFORE the existing iadd, so the sequence is: |
| 1288 | // <orig body insts> ; uses iv |
| 1289 | // iv_1 = iadd iv, 1 |
| 1290 | // <body cloned, iv → iv_1> |
| 1291 | // iv_2 = iadd iv, 2 |
| 1292 | // <body cloned, iv → iv_2> |
| 1293 | // ... |
| 1294 | // iv_next = iadd iv, U |
| 1295 | // br header(iv_next) |
| 1296 | // |
| 1297 | // We construct the new instruction list off-band, then swap |
| 1298 | // it in. |
| 1299 | let mut new_insts: Vec<Inst> = latch_snapshot |
| 1300 | .iter() |
| 1301 | .filter(|i| i.id != shape.iadd_id) |
| 1302 | .cloned() |
| 1303 | .collect(); |
| 1304 | // For reduction loops, track the running accumulator. After the |
| 1305 | // original (k=0) body, the running acc is the original new_acc. |
| 1306 | // Each subsequent clone substitutes acc_param → prev_acc and |
| 1307 | // produces a fresh new_acc (the cloned counterpart of orig new_acc). |
| 1308 | let mut prev_acc: Option<ValueId> = shape.reduction.map(|r| r.new_acc); |
| 1309 | for k in 1..shape.u { |
| 1310 | // Allocate the const(k) and the iv_k = iadd iv, k. |
| 1311 | let k_const_id = func.next_value_id(); |
| 1312 | func.register_type(k_const_id, shape.iv_ty.clone()); |
| 1313 | let iv_k_id = func.next_value_id(); |
| 1314 | func.register_type(iv_k_id, shape.iv_ty.clone()); |
| 1315 | new_insts.push(Inst { |
| 1316 | id: k_const_id, |
| 1317 | kind: InstKind::ConstInt(k as i128, iv_width), |
| 1318 | ty: shape.iv_ty.clone(), |
| 1319 | span, |
| 1320 | }); |
| 1321 | new_insts.push(Inst { |
| 1322 | id: iv_k_id, |
| 1323 | kind: InstKind::IAdd(shape.iv_param, k_const_id), |
| 1324 | ty: shape.iv_ty.clone(), |
| 1325 | span, |
| 1326 | }); |
| 1327 | // Clone the body (sans iadd) substituting iv → iv_k and |
| 1328 | // (for reductions) acc_param → prev_acc; each inst id → |
| 1329 | // fresh id. |
| 1330 | let mut subst: HashMap<ValueId, ValueId> = HashMap::new(); |
| 1331 | subst.insert(shape.iv_param, iv_k_id); |
| 1332 | if let (Some(r), Some(prev)) = (shape.reduction, prev_acc) { |
| 1333 | subst.insert(r.acc_param, prev); |
| 1334 | } |
| 1335 | for orig in &body_inst_clone { |
| 1336 | let new_id = func.next_value_id(); |
| 1337 | func.register_type(new_id, orig.ty.clone()); |
| 1338 | subst.insert(orig.id, new_id); |
| 1339 | let new_kind = remap_kind(&orig.kind, &subst); |
| 1340 | new_insts.push(Inst { |
| 1341 | id: new_id, |
| 1342 | kind: new_kind, |
| 1343 | ty: orig.ty.clone(), |
| 1344 | span: orig.span, |
| 1345 | }); |
| 1346 | } |
| 1347 | // Walk forward the accumulator: prev_acc becomes the cloned |
| 1348 | // counterpart of orig.new_acc. |
| 1349 | if let Some(r) = shape.reduction { |
| 1350 | prev_acc = subst.get(&r.new_acc).copied(); |
| 1351 | } |
| 1352 | } |
| 1353 | // Finally, the new iadd: iv_next = iadd iv, U. |
| 1354 | let new_step_const = func.next_value_id(); |
| 1355 | func.register_type(new_step_const, shape.iv_ty.clone()); |
| 1356 | new_insts.push(Inst { |
| 1357 | id: new_step_const, |
| 1358 | kind: InstKind::ConstInt(shape.u as i128, iv_width), |
| 1359 | ty: shape.iv_ty.clone(), |
| 1360 | span, |
| 1361 | }); |
| 1362 | // Replace iadd's step constant. Reuse the existing iadd id so |
| 1363 | // the latch terminator (which references it) stays valid. |
| 1364 | let new_iadd = Inst { |
| 1365 | id: shape.iadd_id, |
| 1366 | kind: InstKind::IAdd(shape.iv_param, new_step_const), |
| 1367 | ty: shape.iv_ty.clone(), |
| 1368 | span, |
| 1369 | }; |
| 1370 | new_insts.push(new_iadd); |
| 1371 | |
| 1372 | // Swap the latch's instruction list. |
| 1373 | func.block_mut(shape.latch).insts = new_insts; |
| 1374 | |
| 1375 | // For reduction loops, retarget the latch terminator's args[1] |
| 1376 | // to the FINAL accumulator (after U-1 clones). |
| 1377 | if let (Some(_r), Some(final_acc)) = (shape.reduction, prev_acc) { |
| 1378 | let latch_blk = func.block_mut(shape.latch); |
| 1379 | if let Some(Terminator::Branch(_, args)) = &mut latch_blk.terminator { |
| 1380 | if args.len() == 2 { |
| 1381 | args[1] = final_acc; |
| 1382 | } |
| 1383 | } |
| 1384 | } |
| 1385 | } |
| 1386 | |
| 1387 | // --------------------------------------------------------------------------- |
| 1388 | // Runtime-trip partial unroll |
| 1389 | // --------------------------------------------------------------------------- |
| 1390 | // |
| 1391 | // `detect_partial_unroll_loop` requires both `iv_init` and `iv_bound` |
| 1392 | // to be compile-time constants. Many real Fortran loops have runtime |
| 1393 | // trip counts (e.g., `do i = 1, n` where `n` is a subroutine arg). |
| 1394 | // For those we run a U-way unrolled main loop over a head range |
| 1395 | // `[init, init + head_count - 1]` and then a scalar remainder loop |
| 1396 | // over the leftover `[init + head_count, bound]`. The head_count is |
| 1397 | // computed at the top of the preheader as `((bound - init + 1) / U) |
| 1398 | // * U` using runtime arithmetic. |
| 1399 | // |
| 1400 | // v0 limitations: |
| 1401 | // * `iv_init` is still a compile-time constant. |
//  * Reductions: excluded from the original v0, but now handled when
//    the header carries an acc param — the acc is threaded through the
//    unrolled body and the remainder loop via `RuntimeReduction`.
| 1403 | // * 2-block loop (header + latch). |
| 1404 | // * Body has at most one store (same regalloc-pressure heuristic |
| 1405 | // as the static partial unroller — see commit 1b94a26). |
| 1406 | |
/// Loop description produced by `detect_partial_unroll_runtime_loop`
/// for the runtime-trip partial unroller: a 2-block counted loop whose
/// bound is a loop-invariant runtime SSA value (not a compile-time
/// constant).
#[derive(Debug)]
struct PartialRuntimeShape {
    /// Single out-of-loop predecessor of the header (site of the
    /// head-count computation described in the section comment).
    preheader: BlockId,
    /// Loop header: IV block param, bound compare, cond branch.
    header: BlockId,
    /// Single latch block (body + back-edge).
    latch: BlockId,
    /// Original loop exit (cond_br false-target on the header).
    exit: BlockId,
    /// Header block param acting as the induction variable.
    iv_param: ValueId,
    /// IV type (guaranteed `IrType::Int(_)` by detection).
    iv_ty: IrType,
    /// Constant initial IV value (v0 still requires a constant init).
    iv_init: i64,
    /// Runtime SSA value of the loop bound (the rhs of `icmp.le iv, ?`
    /// or equivalent). Must be loop-invariant.
    iv_bound_v: ValueId,
    /// Integer width unwrapped from `iv_ty`.
    iv_width: IntWidth,
    /// Header's icmp inst id.
    cond_id: ValueId,
    /// Latch's `iadd iv, 1` inst id.
    iadd_id: ValueId,
    /// Const(1) feeding the iadd.
    step_const_id: ValueId,
    /// Unroll factor for the main (head) loop.
    u: usize,
    /// Optional reduction. When `Some`, the header has 2 params
    /// (iv + acc); the latch passes (iv_next, new_acc) to header.
    /// Acc must be threaded through the unrolled body and the
    /// remainder loop.
    reduction: Option<RuntimeReduction>,
}
| 1434 | |
/// Accumulator lane of a reduction loop, for the runtime-trip partial
/// unroller. Unlike the static `PartialReductionInfo`, the accumulator's
/// position among the header params is not fixed (see `acc_idx`).
#[derive(Debug, Clone)]
struct RuntimeReduction {
    /// Header's accumulator block-param value id.
    acc_param: ValueId,
    /// Type of the accumulator.
    acc_ty: IrType,
    /// Initial accumulator value (preheader branch args[acc_idx]).
    acc_init: ValueId,
    /// New accumulator value computed in the latch (latch terminator
    /// args[acc_idx]) — typically a fadd / iadd / select chain.
    new_acc: ValueId,
    /// Position of the accumulator in header.params and the latch /
    /// preheader terminator args. mem2reg's ordering of phi nodes is
    /// not stable across loops — sometimes acc is at idx 0, iv at
    /// idx 1.
    acc_idx: usize,
    /// True when the original header's cond_br false_args carries
    /// the acc as a block-arg to exit. False when the exit references
    /// acc_param via dominance only (no block-arg passing). Affects
    /// how the apply path forwards the final acc to the original
    /// exit after rewiring the CFG through the remainder loop.
    exit_takes_acc: bool,
}
| 1458 | |
| 1459 | fn detect_partial_unroll_runtime_loop( |
| 1460 | func: &Function, |
| 1461 | nl: &NaturalLoop, |
| 1462 | preds: &HashMap<BlockId, Vec<BlockId>>, |
| 1463 | ) -> Option<PartialRuntimeShape> { |
| 1464 | if nl.latches.len() != 1 || nl.body.len() != 2 { |
| 1465 | return None; |
| 1466 | } |
| 1467 | let header = nl.header; |
| 1468 | let latch = nl.latches[0]; |
| 1469 | if header == latch { |
| 1470 | return None; |
| 1471 | } |
| 1472 | let hdr = func.block(header); |
| 1473 | // Header may be iv-only (no reduction) or iv + acc (reduction). |
| 1474 | if hdr.params.len() != 1 && hdr.params.len() != 2 { |
| 1475 | return None; |
| 1476 | } |
| 1477 | // Find the IV by inspecting the header's icmp — the param that |
| 1478 | // appears as the LHS of `icmp.le|.lt _, bound` is the IV. The |
| 1479 | // other param (when 2 params) is the accumulator. |
| 1480 | let icmp_lhs: Option<ValueId> = hdr.terminator.as_ref().and_then(|t| { |
| 1481 | if let Terminator::CondBranch { cond, .. } = t { |
| 1482 | hdr.insts.iter().find(|i| i.id == *cond).and_then(|c| match c.kind { |
| 1483 | InstKind::ICmp(CmpOp::Le, lhs, _) | InstKind::ICmp(CmpOp::Lt, lhs, _) => { |
| 1484 | Some(lhs) |
| 1485 | } |
| 1486 | _ => None, |
| 1487 | }) |
| 1488 | } else { |
| 1489 | None |
| 1490 | } |
| 1491 | }); |
| 1492 | let icmp_lhs = icmp_lhs?; |
| 1493 | let iv_idx = hdr.params.iter().position(|p| p.id == icmp_lhs)?; |
| 1494 | let iv_param = hdr.params[iv_idx].id; |
| 1495 | let iv_ty = hdr.params[iv_idx].ty.clone(); |
| 1496 | let iv_width = match iv_ty { |
| 1497 | IrType::Int(w) => w, |
| 1498 | _ => return None, |
| 1499 | }; |
| 1500 | let acc_param_info: Option<(ValueId, IrType)> = if hdr.params.len() == 2 { |
| 1501 | let other_idx = 1 - iv_idx; |
| 1502 | let p = &hdr.params[other_idx]; |
| 1503 | Some((p.id, p.ty.clone())) |
| 1504 | } else { |
| 1505 | None |
| 1506 | }; |
| 1507 | let acc_idx: Option<usize> = if hdr.params.len() == 2 { Some(1 - iv_idx) } else { None }; |
| 1508 | // Header's icmp pattern. Accept icmp.le or icmp.lt with the IV on |
| 1509 | // the LHS; capture the rhs ValueId regardless of whether it's a |
| 1510 | // const. |
| 1511 | let (cond_id, true_dest, false_dest, rhs_v) = { |
| 1512 | let term = hdr.terminator.as_ref()?; |
| 1513 | let (cond, td, fd) = match term { |
| 1514 | Terminator::CondBranch { |
| 1515 | cond, |
| 1516 | true_dest, |
| 1517 | true_args, |
| 1518 | false_dest, |
| 1519 | false_args, |
| 1520 | } if true_args.is_empty() && false_args.is_empty() => (*cond, *true_dest, *false_dest), |
| 1521 | _ => return None, |
| 1522 | }; |
| 1523 | let cmp = hdr.insts.iter().find(|i| i.id == cond)?; |
| 1524 | let rhs = match cmp.kind { |
| 1525 | InstKind::ICmp(CmpOp::Le, lhs, rhs) if lhs == iv_param => rhs, |
| 1526 | InstKind::ICmp(CmpOp::Lt, lhs, rhs) if lhs == iv_param => rhs, |
| 1527 | _ => return None, |
| 1528 | }; |
| 1529 | (cond, td, fd, rhs) |
| 1530 | }; |
| 1531 | if !nl.body.contains(&true_dest) || nl.body.contains(&false_dest) { |
| 1532 | return None; |
| 1533 | } |
| 1534 | let exit = false_dest; |
| 1535 | // The exit's args from header: must be empty for the iv-only |
| 1536 | // form. For the reduction form, accept either [] (acc_param is |
| 1537 | // referenced by the exit via dominance — common Fortran lowering) |
| 1538 | // OR [acc_param] (canonical mem2reg shape). |
| 1539 | let (exit_args_ok, exit_takes_acc): (bool, bool) = hdr |
| 1540 | .terminator |
| 1541 | .as_ref() |
| 1542 | .map(|t| { |
| 1543 | if let Terminator::CondBranch { false_args, .. } = t { |
| 1544 | if let Some((acc_p, _)) = &acc_param_info { |
| 1545 | if false_args.is_empty() { |
| 1546 | (true, false) |
| 1547 | } else if false_args.len() == 1 && false_args[0] == *acc_p { |
| 1548 | (true, true) |
| 1549 | } else { |
| 1550 | (false, false) |
| 1551 | } |
| 1552 | } else { |
| 1553 | (false_args.is_empty(), false) |
| 1554 | } |
| 1555 | } else { |
| 1556 | (false, false) |
| 1557 | } |
| 1558 | }) |
| 1559 | .unwrap_or((false, false)); |
| 1560 | if !exit_args_ok { |
| 1561 | return None; |
| 1562 | } |
| 1563 | // Bound must NOT be a compile-time constant — that case is handled |
| 1564 | // by the static detector. We require the bound to be loop-invariant |
| 1565 | // (defined outside the loop body). |
| 1566 | if resolve_const_int(func, rhs_v).is_some() { |
| 1567 | return None; |
| 1568 | } |
| 1569 | let body_set: HashSet<BlockId> = nl.body.iter().copied().collect(); |
| 1570 | let bound_in_body = body_set.iter().any(|b| { |
| 1571 | let blk = func.block(*b); |
| 1572 | blk.params.iter().any(|p| p.id == rhs_v) || blk.insts.iter().any(|i| i.id == rhs_v) |
| 1573 | }); |
| 1574 | if bound_in_body { |
| 1575 | return None; |
| 1576 | } |
| 1577 | // iv_init from preheader (still requires constant init for v0). |
| 1578 | let preheader = { |
| 1579 | let header_preds = preds.get(&header)?; |
| 1580 | let mut outside: Vec<BlockId> = header_preds |
| 1581 | .iter() |
| 1582 | .copied() |
| 1583 | .filter(|p| !nl.body.contains(p)) |
| 1584 | .collect(); |
| 1585 | outside.sort_by_key(|b| b.0); |
| 1586 | outside.dedup(); |
| 1587 | if outside.len() != 1 { |
| 1588 | return None; |
| 1589 | } |
| 1590 | outside[0] |
| 1591 | }; |
| 1592 | let ph = func.block(preheader); |
| 1593 | let want_arity = if acc_param_info.is_some() { 2 } else { 1 }; |
| 1594 | let (iv_init, acc_init_v) = match &ph.terminator { |
| 1595 | Some(Terminator::Branch(d, args)) |
| 1596 | if *d == header && args.len() == want_arity => |
| 1597 | { |
| 1598 | let init = resolve_const_int(func, args[iv_idx])?; |
| 1599 | let acc_init = acc_idx.map(|i| args[i]); |
| 1600 | (init, acc_init) |
| 1601 | } |
| 1602 | _ => return None, |
| 1603 | }; |
| 1604 | // Latch must end with `br header(iv_next)` (iv-only) or |
| 1605 | // `br header(iv_next, new_acc)` (reduction) — args ordering |
| 1606 | // matches header.params ordering. |
| 1607 | let latch_blk = func.block(latch); |
| 1608 | let latch_term_args = match &latch_blk.terminator { |
| 1609 | Some(Terminator::Branch(d, args)) |
| 1610 | if *d == header && args.len() == want_arity => |
| 1611 | { |
| 1612 | args.clone() |
| 1613 | } |
| 1614 | _ => return None, |
| 1615 | }; |
| 1616 | let iv_next = latch_term_args[iv_idx]; |
| 1617 | let new_acc_v = acc_idx.map(|i| latch_term_args[i]); |
| 1618 | let iadd_inst = latch_blk.insts.iter().find(|i| i.id == iv_next)?; |
| 1619 | let (lhs, rhs) = match iadd_inst.kind { |
| 1620 | InstKind::IAdd(l, r) => (l, r), |
| 1621 | _ => return None, |
| 1622 | }; |
| 1623 | let step_const_id = if lhs == iv_param { |
| 1624 | rhs |
| 1625 | } else if rhs == iv_param { |
| 1626 | lhs |
| 1627 | } else { |
| 1628 | return None; |
| 1629 | }; |
| 1630 | if resolve_const_int(func, step_const_id) != Some(1) { |
| 1631 | return None; |
| 1632 | } |
| 1633 | if has_escaping_values(func, latch, &body_set, preds) { |
| 1634 | return None; |
| 1635 | } |
| 1636 | let body_inst_count = latch_blk.insts.len(); |
| 1637 | if body_inst_count == 0 { |
| 1638 | return None; |
| 1639 | } |
| 1640 | // Multi-store regalloc-pressure heuristic (mirrors the static path). |
| 1641 | let store_count = latch_blk |
| 1642 | .insts |
| 1643 | .iter() |
| 1644 | .filter(|i| matches!(i.kind, InstKind::Store(..))) |
| 1645 | .count(); |
| 1646 | if store_count > 1 { |
| 1647 | return None; |
| 1648 | } |
| 1649 | // Pick the largest U ≤ MAX_FACTOR within body budget. Trip not |
| 1650 | // required to be divisible — we'll generate a remainder loop. |
| 1651 | let mut chosen_u: Option<usize> = None; |
| 1652 | for u in (2..=PARTIAL_UNROLL_MAX_FACTOR).rev() { |
| 1653 | if body_inst_count * u > PARTIAL_UNROLL_BODY_BUDGET { |
| 1654 | continue; |
| 1655 | } |
| 1656 | chosen_u = Some(u); |
| 1657 | break; |
| 1658 | } |
| 1659 | let u = chosen_u?; |
| 1660 | let reduction = match (acc_param_info, acc_init_v, new_acc_v, acc_idx) { |
| 1661 | (Some((acc_param, acc_ty)), Some(acc_init), Some(new_acc), Some(idx)) => { |
| 1662 | Some(RuntimeReduction { |
| 1663 | acc_param, |
| 1664 | acc_ty, |
| 1665 | acc_init, |
| 1666 | new_acc, |
| 1667 | acc_idx: idx, |
| 1668 | exit_takes_acc, |
| 1669 | }) |
| 1670 | } |
| 1671 | _ => None, |
| 1672 | }; |
| 1673 | Some(PartialRuntimeShape { |
| 1674 | preheader, |
| 1675 | header, |
| 1676 | latch, |
| 1677 | exit, |
| 1678 | iv_param, |
| 1679 | iv_ty, |
| 1680 | iv_init, |
| 1681 | iv_bound_v: rhs_v, |
| 1682 | iv_width, |
| 1683 | cond_id, |
| 1684 | iadd_id: iv_next, |
| 1685 | step_const_id, |
| 1686 | u, |
| 1687 | reduction, |
| 1688 | }) |
| 1689 | } |
| 1690 | |
| 1691 | fn do_partial_unroll_runtime(func: &mut Function, shape: PartialRuntimeShape) { |
| 1692 | let span = dummy_span(); |
| 1693 | let int_ty = IrType::Int(shape.iv_width); |
| 1694 | let bool_ty = IrType::Bool; |
| 1695 | |
| 1696 | // ---- Snapshot the original body BEFORE we modify the latch ---- |
| 1697 | let body_snapshot: Vec<Inst> = func.block(shape.latch).insts.clone(); |
| 1698 | let body_no_iadd: Vec<Inst> = body_snapshot |
| 1699 | .iter() |
| 1700 | .filter(|i| i.id != shape.iadd_id) |
| 1701 | .cloned() |
| 1702 | .collect(); |
| 1703 | |
| 1704 | // ---- 1. Emit head_bound runtime arithmetic in the preheader ---- |
| 1705 | // |
| 1706 | // trip = (bound - init) + 1 |
| 1707 | // head_count = trip - (trip mod U) |
| 1708 | // head_bound = init + (head_count - 1) |
| 1709 | let init_const_id = func.next_value_id(); |
| 1710 | let one_id = func.next_value_id(); |
| 1711 | let u_const_id = func.next_value_id(); |
| 1712 | let trip_minus_init_id = func.next_value_id(); |
| 1713 | let trip_id = func.next_value_id(); |
| 1714 | let mod_id = func.next_value_id(); |
| 1715 | let head_count_id = func.next_value_id(); |
| 1716 | let head_count_minus_one_id = func.next_value_id(); |
| 1717 | let head_bound_id = func.next_value_id(); |
| 1718 | for v in [ |
| 1719 | init_const_id, |
| 1720 | one_id, |
| 1721 | u_const_id, |
| 1722 | trip_minus_init_id, |
| 1723 | trip_id, |
| 1724 | mod_id, |
| 1725 | head_count_id, |
| 1726 | head_count_minus_one_id, |
| 1727 | head_bound_id, |
| 1728 | ] { |
| 1729 | func.register_type(v, int_ty.clone()); |
| 1730 | } |
| 1731 | let preheader_new_insts = vec![ |
| 1732 | Inst { |
| 1733 | id: init_const_id, |
| 1734 | kind: InstKind::ConstInt(shape.iv_init as i128, shape.iv_width), |
| 1735 | ty: int_ty.clone(), |
| 1736 | span, |
| 1737 | }, |
| 1738 | Inst { |
| 1739 | id: one_id, |
| 1740 | kind: InstKind::ConstInt(1, shape.iv_width), |
| 1741 | ty: int_ty.clone(), |
| 1742 | span, |
| 1743 | }, |
| 1744 | Inst { |
| 1745 | id: u_const_id, |
| 1746 | kind: InstKind::ConstInt(shape.u as i128, shape.iv_width), |
| 1747 | ty: int_ty.clone(), |
| 1748 | span, |
| 1749 | }, |
| 1750 | Inst { |
| 1751 | id: trip_minus_init_id, |
| 1752 | kind: InstKind::ISub(shape.iv_bound_v, init_const_id), |
| 1753 | ty: int_ty.clone(), |
| 1754 | span, |
| 1755 | }, |
| 1756 | Inst { |
| 1757 | id: trip_id, |
| 1758 | kind: InstKind::IAdd(trip_minus_init_id, one_id), |
| 1759 | ty: int_ty.clone(), |
| 1760 | span, |
| 1761 | }, |
| 1762 | Inst { |
| 1763 | id: mod_id, |
| 1764 | kind: InstKind::IMod(trip_id, u_const_id), |
| 1765 | ty: int_ty.clone(), |
| 1766 | span, |
| 1767 | }, |
| 1768 | Inst { |
| 1769 | id: head_count_id, |
| 1770 | kind: InstKind::ISub(trip_id, mod_id), |
| 1771 | ty: int_ty.clone(), |
| 1772 | span, |
| 1773 | }, |
| 1774 | Inst { |
| 1775 | id: head_count_minus_one_id, |
| 1776 | kind: InstKind::ISub(head_count_id, one_id), |
| 1777 | ty: int_ty.clone(), |
| 1778 | span, |
| 1779 | }, |
| 1780 | Inst { |
| 1781 | id: head_bound_id, |
| 1782 | kind: InstKind::IAdd(init_const_id, head_count_minus_one_id), |
| 1783 | ty: int_ty.clone(), |
| 1784 | span, |
| 1785 | }, |
| 1786 | ]; |
| 1787 | { |
| 1788 | let pre = func.block_mut(shape.preheader); |
| 1789 | // Insert immediately before the terminator (after any existing insts). |
| 1790 | let pos = pre.insts.len(); |
| 1791 | for (i, inst) in preheader_new_insts.into_iter().enumerate() { |
| 1792 | pre.insts.insert(pos + i, inst); |
| 1793 | } |
| 1794 | } |
| 1795 | |
| 1796 | // ---- 2. Rewire header's icmp RHS to head_bound ---- |
| 1797 | { |
| 1798 | let hdr = func.block_mut(shape.header); |
| 1799 | if let Some(inst) = hdr.insts.iter_mut().find(|i| i.id == shape.cond_id) { |
| 1800 | if let InstKind::ICmp(_, _, ref mut rhs) = inst.kind { |
| 1801 | *rhs = head_bound_id; |
| 1802 | } |
| 1803 | } |
| 1804 | } |
| 1805 | |
| 1806 | // ---- 3. Allocate remainder loop blocks ---- |
| 1807 | let header_remain = func.create_block("partial_remain_header"); |
| 1808 | let latch_remain = func.create_block("partial_remain_latch"); |
| 1809 | let remain_iv_id = func.next_value_id(); |
| 1810 | func.register_type(remain_iv_id, int_ty.clone()); |
| 1811 | func.block_mut(header_remain).params.push(BlockParam { |
| 1812 | id: remain_iv_id, |
| 1813 | ty: int_ty.clone(), |
| 1814 | }); |
| 1815 | // Reduction: header_remain also takes an acc param. |
| 1816 | let remain_acc_id: Option<ValueId> = if let Some(r) = &shape.reduction { |
| 1817 | let id = func.next_value_id(); |
| 1818 | func.register_type(id, r.acc_ty.clone()); |
| 1819 | func.block_mut(header_remain).params.push(BlockParam { |
| 1820 | id, |
| 1821 | ty: r.acc_ty.clone(), |
| 1822 | }); |
| 1823 | Some(id) |
| 1824 | } else { |
| 1825 | None |
| 1826 | }; |
| 1827 | |
| 1828 | // ---- 4. Rewire header's cond_br false-target from exit to header_remain ---- |
| 1829 | // When reducing, also forward the acc value into header_remain. |
| 1830 | // header_remain's params are [iv, acc] in that order. |
| 1831 | let original_exit = shape.exit; |
| 1832 | { |
| 1833 | let hdr = func.block_mut(shape.header); |
| 1834 | if let Some(Terminator::CondBranch { |
| 1835 | false_dest, |
| 1836 | false_args, |
| 1837 | .. |
| 1838 | }) = hdr.terminator.as_mut() |
| 1839 | { |
| 1840 | *false_dest = header_remain; |
| 1841 | *false_args = if let Some(r) = &shape.reduction { |
| 1842 | vec![shape.iv_param, r.acc_param] |
| 1843 | } else { |
| 1844 | vec![shape.iv_param] |
| 1845 | }; |
| 1846 | } |
| 1847 | } |
| 1848 | |
| 1849 | // ---- 5. Build the U-way unrolled main-loop body in the latch ---- |
| 1850 | // When reducing, thread the acc through clones (each clone's |
| 1851 | // acc_param substitution is the previous clone's new_acc). |
| 1852 | let mut new_main_insts: Vec<Inst> = body_no_iadd.clone(); |
| 1853 | let mut prev_acc: Option<ValueId> = shape.reduction.as_ref().map(|r| r.new_acc); |
| 1854 | for k in 1..shape.u { |
| 1855 | let k_const_id = func.next_value_id(); |
| 1856 | func.register_type(k_const_id, int_ty.clone()); |
| 1857 | let iv_k_id = func.next_value_id(); |
| 1858 | func.register_type(iv_k_id, int_ty.clone()); |
| 1859 | new_main_insts.push(Inst { |
| 1860 | id: k_const_id, |
| 1861 | kind: InstKind::ConstInt(k as i128, shape.iv_width), |
| 1862 | ty: int_ty.clone(), |
| 1863 | span, |
| 1864 | }); |
| 1865 | new_main_insts.push(Inst { |
| 1866 | id: iv_k_id, |
| 1867 | kind: InstKind::IAdd(shape.iv_param, k_const_id), |
| 1868 | ty: int_ty.clone(), |
| 1869 | span, |
| 1870 | }); |
| 1871 | let mut subst: HashMap<ValueId, ValueId> = HashMap::new(); |
| 1872 | subst.insert(shape.iv_param, iv_k_id); |
| 1873 | // Reduction: substitute acc_param with the running prev_acc so |
| 1874 | // this clone consumes the previous iteration's accumulator. |
| 1875 | if let (Some(r), Some(prev)) = (shape.reduction.as_ref(), prev_acc) { |
| 1876 | subst.insert(r.acc_param, prev); |
| 1877 | } |
| 1878 | let mut clone_new_acc: Option<ValueId> = None; |
| 1879 | for orig in &body_no_iadd { |
| 1880 | let new_id = func.next_value_id(); |
| 1881 | func.register_type(new_id, orig.ty.clone()); |
| 1882 | subst.insert(orig.id, new_id); |
| 1883 | let new_kind = remap_kind(&orig.kind, &subst); |
| 1884 | new_main_insts.push(Inst { |
| 1885 | id: new_id, |
| 1886 | kind: new_kind, |
| 1887 | ty: orig.ty.clone(), |
| 1888 | span: orig.span, |
| 1889 | }); |
| 1890 | // Track this clone's new_acc id for the next iteration's |
| 1891 | // substitution. |
| 1892 | if let Some(r) = shape.reduction.as_ref() { |
| 1893 | if orig.id == r.new_acc { |
| 1894 | clone_new_acc = Some(new_id); |
| 1895 | } |
| 1896 | } |
| 1897 | } |
| 1898 | if shape.reduction.is_some() { |
| 1899 | prev_acc = clone_new_acc; |
| 1900 | } |
| 1901 | } |
| 1902 | // New iadd: iv_next = iadd iv, U. Reuse the original iadd id. |
| 1903 | let new_step_const = func.next_value_id(); |
| 1904 | func.register_type(new_step_const, int_ty.clone()); |
| 1905 | new_main_insts.push(Inst { |
| 1906 | id: new_step_const, |
| 1907 | kind: InstKind::ConstInt(shape.u as i128, shape.iv_width), |
| 1908 | ty: int_ty.clone(), |
| 1909 | span, |
| 1910 | }); |
| 1911 | new_main_insts.push(Inst { |
| 1912 | id: shape.iadd_id, |
| 1913 | kind: InstKind::IAdd(shape.iv_param, new_step_const), |
| 1914 | ty: int_ty.clone(), |
| 1915 | span, |
| 1916 | }); |
| 1917 | func.block_mut(shape.latch).insts = new_main_insts; |
| 1918 | // When reducing, rewrite the latch terminator's args[acc_idx] to |
| 1919 | // the FINAL clone's new_acc (so each main-loop iteration commits |
| 1920 | // the accumulated U-way result instead of the original first- |
| 1921 | // iteration new_acc). |
| 1922 | if let (Some(r), Some(final_acc)) = (shape.reduction.as_ref(), prev_acc) { |
| 1923 | let latch_blk = func.block_mut(shape.latch); |
| 1924 | if let Some(Terminator::Branch(_, args)) = latch_blk.terminator.as_mut() { |
| 1925 | if args.len() == 2 { |
| 1926 | args[r.acc_idx] = final_acc; |
| 1927 | } |
| 1928 | } |
| 1929 | } |
| 1930 | |
| 1931 | // ---- 6. Build header_remain ---- |
| 1932 | let remain_cmp_id = func.next_value_id(); |
| 1933 | func.register_type(remain_cmp_id, bool_ty.clone()); |
| 1934 | func.block_mut(header_remain).insts.push(Inst { |
| 1935 | id: remain_cmp_id, |
| 1936 | kind: InstKind::ICmp(CmpOp::Le, remain_iv_id, shape.iv_bound_v), |
| 1937 | ty: bool_ty.clone(), |
| 1938 | span, |
| 1939 | }); |
| 1940 | // When the original exit takes acc as a false_arg, header_remain |
| 1941 | // must forward the final acc that way. Otherwise (acc-via- |
| 1942 | // dominance lowering), the exit takes no args and the final acc |
| 1943 | // is forwarded by rewriting acc_param uses in step 8 below. |
| 1944 | let exit_args_remain: Vec<ValueId> = match ( |
| 1945 | remain_acc_id, |
| 1946 | shape.reduction.as_ref().map(|r| r.exit_takes_acc), |
| 1947 | ) { |
| 1948 | (Some(acc), Some(true)) => vec![acc], |
| 1949 | _ => vec![], |
| 1950 | }; |
| 1951 | func.block_mut(header_remain).terminator = Some(Terminator::CondBranch { |
| 1952 | cond: remain_cmp_id, |
| 1953 | true_dest: latch_remain, |
| 1954 | true_args: vec![], |
| 1955 | false_dest: original_exit, |
| 1956 | false_args: exit_args_remain, |
| 1957 | }); |
| 1958 | |
| 1959 | // ---- 7. Build latch_remain (scalar 1-iter copy of original body) ---- |
| 1960 | // Reduction: substitute acc_param → remain_acc_id so the |
| 1961 | // body computes new_acc on the running accumulator. |
| 1962 | let mut subst: HashMap<ValueId, ValueId> = HashMap::new(); |
| 1963 | subst.insert(shape.iv_param, remain_iv_id); |
| 1964 | if let (Some(r), Some(remain_acc)) = (shape.reduction.as_ref(), remain_acc_id) { |
| 1965 | subst.insert(r.acc_param, remain_acc); |
| 1966 | } |
| 1967 | let mut remain_insts: Vec<Inst> = Vec::new(); |
| 1968 | let mut remain_new_acc: Option<ValueId> = None; |
| 1969 | for orig in &body_no_iadd { |
| 1970 | let new_id = func.next_value_id(); |
| 1971 | func.register_type(new_id, orig.ty.clone()); |
| 1972 | subst.insert(orig.id, new_id); |
| 1973 | let new_kind = remap_kind(&orig.kind, &subst); |
| 1974 | remain_insts.push(Inst { |
| 1975 | id: new_id, |
| 1976 | kind: new_kind, |
| 1977 | ty: orig.ty.clone(), |
| 1978 | span: orig.span, |
| 1979 | }); |
| 1980 | if let Some(r) = shape.reduction.as_ref() { |
| 1981 | if orig.id == r.new_acc { |
| 1982 | remain_new_acc = Some(new_id); |
| 1983 | } |
| 1984 | } |
| 1985 | } |
| 1986 | let remain_step_const_id = func.next_value_id(); |
| 1987 | func.register_type(remain_step_const_id, int_ty.clone()); |
| 1988 | let remain_iv_next_id = func.next_value_id(); |
| 1989 | func.register_type(remain_iv_next_id, int_ty.clone()); |
| 1990 | remain_insts.push(Inst { |
| 1991 | id: remain_step_const_id, |
| 1992 | kind: InstKind::ConstInt(1, shape.iv_width), |
| 1993 | ty: int_ty.clone(), |
| 1994 | span, |
| 1995 | }); |
| 1996 | remain_insts.push(Inst { |
| 1997 | id: remain_iv_next_id, |
| 1998 | kind: InstKind::IAdd(remain_iv_id, remain_step_const_id), |
| 1999 | ty: int_ty.clone(), |
| 2000 | span, |
| 2001 | }); |
| 2002 | func.block_mut(latch_remain).insts = remain_insts; |
| 2003 | let latch_remain_args: Vec<ValueId> = if let Some(acc) = remain_new_acc { |
| 2004 | vec![remain_iv_next_id, acc] |
| 2005 | } else { |
| 2006 | vec![remain_iv_next_id] |
| 2007 | }; |
| 2008 | func.block_mut(latch_remain).terminator = |
| 2009 | Some(Terminator::Branch(header_remain, latch_remain_args)); |
| 2010 | |
| 2011 | // ---- 8. When the original exit references acc_param via dominance |
| 2012 | // (rather than via false_args), the head loop's acc value |
| 2013 | // that flows out of `header` is now the U-way unrolled head |
| 2014 | // result — but the remainder loop continues the |
| 2015 | // accumulation. Rewrite uses of acc_param outside the loop |
| 2016 | // body to use `remain_acc_id` (the final acc from |
| 2017 | // header_remain). header_remain dominates everything that |
| 2018 | // used to be dominated by header (along the false branch). |
| 2019 | if let (Some(r), Some(remain_acc)) = (shape.reduction.as_ref(), remain_acc_id) { |
| 2020 | if !r.exit_takes_acc { |
| 2021 | let body_set: HashSet<BlockId> = |
| 2022 | [shape.header, shape.latch].iter().copied().collect(); |
| 2023 | for block in func.blocks.iter_mut() { |
| 2024 | if body_set.contains(&block.id) || block.id == shape.preheader { |
| 2025 | continue; |
| 2026 | } |
| 2027 | if block.id == header_remain || block.id == latch_remain { |
| 2028 | continue; |
| 2029 | } |
| 2030 | for inst in block.insts.iter_mut() { |
| 2031 | inst.kind = remap_kind_single(&inst.kind, r.acc_param, remain_acc); |
| 2032 | } |
| 2033 | if let Some(term) = block.terminator.as_mut() { |
| 2034 | remap_term_single(term, r.acc_param, remain_acc); |
| 2035 | } |
| 2036 | } |
| 2037 | } |
| 2038 | } |
| 2039 | } |
| 2040 | |
| 2041 | /// Replace every occurrence of `from` in an `InstKind`'s value |
| 2042 | /// operands with `to`. Forwards through `remap_kind` by building a |
| 2043 | /// 1-entry HashMap. |
| 2044 | fn remap_kind_single(kind: &InstKind, from: ValueId, to: ValueId) -> InstKind { |
| 2045 | let mut subst: HashMap<ValueId, ValueId> = HashMap::new(); |
| 2046 | subst.insert(from, to); |
| 2047 | remap_kind(kind, &subst) |
| 2048 | } |
| 2049 | |
| 2050 | /// Like `remap_kind_single` but for a `Terminator`. Walks both |
| 2051 | /// args lists for `Branch` / `CondBranch`, plus the `cond` operand |
| 2052 | /// for `CondBranch` and the value operand for `Return`. |
| 2053 | fn remap_term_single(term: &mut Terminator, from: ValueId, to: ValueId) { |
| 2054 | let sub = |v: &mut ValueId| { |
| 2055 | if *v == from { |
| 2056 | *v = to; |
| 2057 | } |
| 2058 | }; |
| 2059 | match term { |
| 2060 | Terminator::Branch(_, args) => { |
| 2061 | for v in args.iter_mut() { |
| 2062 | sub(v); |
| 2063 | } |
| 2064 | } |
| 2065 | Terminator::CondBranch { |
| 2066 | cond, |
| 2067 | true_args, |
| 2068 | false_args, |
| 2069 | .. |
| 2070 | } => { |
| 2071 | sub(cond); |
| 2072 | for v in true_args.iter_mut() { |
| 2073 | sub(v); |
| 2074 | } |
| 2075 | for v in false_args.iter_mut() { |
| 2076 | sub(v); |
| 2077 | } |
| 2078 | } |
| 2079 | Terminator::Return(Some(v)) => sub(v), |
| 2080 | Terminator::Return(None) | Terminator::Unreachable => {} |
| 2081 | Terminator::Switch { selector, .. } => sub(selector), |
| 2082 | } |
| 2083 | } |
| 2084 | |
| 2085 | // --------------------------------------------------------------------------- |
| 2086 | // Multi-block partial unroll |
| 2087 | // --------------------------------------------------------------------------- |
| 2088 | // |
| 2089 | // `detect_partial_unroll_loop` rejects loops whose natural body has |
| 2090 | // more than 2 blocks (i.e., header + extra blocks + latch). At -O3 |
| 2091 | // such shapes are rare — SCCP, jump-threading, and simplify-cfg |
| 2092 | // merge most linear chains — but they do arise in less-aggressively- |
| 2093 | // optimized IR (debug builds, certain control-flow patterns the |
| 2094 | // optimizer can't simplify, or third-party IR consumers). For the |
| 2095 | // linear-chain case we can extend the partial-unroll transform by |
| 2096 | // allocating fresh BlockIds for each cloned body iteration and |
| 2097 | // rewiring branches; the latch's iadd still bumps iv by U. |
| 2098 | // |
| 2099 | // v0 limitations: |
| 2100 | // * 3-block loop only (header + body + latch). Generalizing to |
| 2101 | // longer chains is mechanical (clone N-1 blocks per copy) but |
| 2102 | // not yet wired up. |
| 2103 | // * Static trip count (constant header bound). |
| 2104 | // * Non-reduction (iv-only header). |
| 2105 | // * Body's intermediate block has at most 1 store (regalloc- |
| 2106 | // pressure heuristic, mirrors the 2-block path). |
| 2107 | |
/// Detected shape of a 3-block (header + body + latch) partially-
/// unrollable loop. Produced by `detect_partial_unroll_multiblock_loop`
/// and consumed by `do_partial_unroll_multiblock`.
#[derive(Debug)]
struct PartialMultiBlockShape {
    /// Loop header: owns the single IV block param and the
    /// `condBr cmp, body_block, exit` terminator.
    header: BlockId,
    /// The body block (between header and latch). Has unconditional
    /// br to latch.
    body_block: BlockId,
    /// The single back-edge block, ending in `br header(iv_next)`.
    latch: BlockId,
    /// The header's IV block param.
    iv_param: ValueId,
    /// IV type; detection guarantees this is `IrType::Int(iv_width)`.
    iv_ty: IrType,
    /// Constant initial IV value taken from the preheader's branch arg.
    iv_init: i64,
    /// Constant INCLUSIVE upper bound (an `icmp.lt` bound is
    /// normalized to `le` form by subtracting 1 during detection).
    iv_bound: i64,
    /// Integer width of the IV.
    iv_width: IntWidth,
    /// Latch's `iadd iv, 1` inst id.
    iadd_id: ValueId,
    /// Const(1) feeding the iadd.
    step_const_id: ValueId,
    /// Chosen unroll factor, 2..=PARTIAL_UNROLL_MAX_FACTOR; divides
    /// the trip count exactly (no remainder loop on this path).
    u: usize,
}
| 2126 | |
/// Detect the v0 multi-block (header + body_block + latch)
/// partially-unrollable loop shape described in the section comment
/// above. Returns `None` unless every restriction holds: static
/// constant bounds, stride 1, no escaping values, single store,
/// and a trip count exactly divisible by some factor U within the
/// body-size budget.
fn detect_partial_unroll_multiblock_loop(
    func: &Function,
    nl: &NaturalLoop,
    preds: &HashMap<BlockId, Vec<BlockId>>,
) -> Option<PartialMultiBlockShape> {
    // Exactly one back-edge and a natural body of exactly 3 blocks
    // (header + one intermediate block + latch).
    if nl.latches.len() != 1 || nl.body.len() != 3 {
        return None;
    }
    let header = nl.header;
    let latch = nl.latches[0];
    if header == latch {
        return None;
    }
    // Header must carry exactly one block param: the IV (this is the
    // non-reduction restriction — see "v0 limitations" above).
    let hdr = func.block(header);
    if hdr.params.len() != 1 {
        return None;
    }
    let iv_param = hdr.params[0].id;
    let iv_ty = hdr.params[0].ty.clone();
    let iv_width = match iv_ty {
        IrType::Int(w) => w,
        _ => return None,
    };
    // Header's cmp + cond_br true→body_block, false→exit.
    // Both edge arg lists must be empty (no values threaded out of
    // the header through its terminator).
    let (cond_id, body_block, exit_dest) = match &hdr.terminator {
        Some(Terminator::CondBranch {
            cond,
            true_dest,
            true_args,
            false_dest,
            false_args,
        }) if true_args.is_empty() && false_args.is_empty() => (*cond, *true_dest, *false_dest),
        _ => return None,
    };
    // The true edge stays inside the loop, the false edge leaves it,
    // and the intermediate block is distinct from the latch.
    if !nl.body.contains(&body_block) || nl.body.contains(&exit_dest) || body_block == latch {
        return None;
    }
    // body_block must have no params and unconditional br latch.
    let body_blk = func.block(body_block);
    if !body_blk.params.is_empty() {
        return None;
    }
    match &body_blk.terminator {
        Some(Terminator::Branch(d, args)) if *d == latch && args.is_empty() => {}
        _ => return None,
    }
    // Header's cmp must be `icmp.le iv, const(hi)` or `icmp.lt`.
    // The lt form is normalized to an inclusive bound here;
    // checked_sub also rejects an i64::MIN bound instead of wrapping.
    let cmp_inst = hdr.insts.iter().find(|i| i.id == cond_id)?;
    let iv_bound = match cmp_inst.kind {
        InstKind::ICmp(CmpOp::Le, lhs, rhs) if lhs == iv_param => resolve_const_int(func, rhs)?,
        InstKind::ICmp(CmpOp::Lt, lhs, rhs) if lhs == iv_param => {
            resolve_const_int(func, rhs)?.checked_sub(1)?
        }
        _ => return None,
    };
    // Preheader supplies iv_init. There must be exactly one
    // predecessor of the header outside the loop body.
    let preheader = {
        let header_preds = preds.get(&header)?;
        let mut outside: Vec<BlockId> = header_preds
            .iter()
            .copied()
            .filter(|p| !nl.body.contains(p))
            .collect();
        outside.sort_by_key(|b| b.0);
        outside.dedup();
        if outside.len() != 1 {
            return None;
        }
        outside[0]
    };
    // The preheader's branch into the header carries the (constant)
    // initial IV value as its single arg.
    let ph = func.block(preheader);
    let iv_init = match &ph.terminator {
        Some(Terminator::Branch(d, args)) if *d == header && args.len() == 1 => {
            resolve_const_int(func, args[0])?
        }
        _ => return None,
    };
    // Latch's terminator: iadd iv, 1; br header(iv_next).
    let latch_blk = func.block(latch);
    let latch_term_args = match &latch_blk.terminator {
        Some(Terminator::Branch(d, args)) if *d == header && args.len() == 1 => args.clone(),
        _ => return None,
    };
    let iv_next = latch_term_args[0];
    let iadd_inst = latch_blk.insts.iter().find(|i| i.id == iv_next)?;
    let (lhs, rhs) = match iadd_inst.kind {
        InstKind::IAdd(l, r) => (l, r),
        _ => return None,
    };
    // The iadd may have the IV on either side; the other operand is
    // the step, which must resolve to the constant 1.
    let step_const_id = if lhs == iv_param {
        rhs
    } else if rhs == iv_param {
        lhs
    } else {
        return None;
    };
    if resolve_const_int(func, step_const_id) != Some(1) {
        return None;
    }
    // No value defined in either in-loop block may be used outside
    // the loop (clones would otherwise need merge logic at the exit).
    let body_set: HashSet<BlockId> = nl.body.iter().copied().collect();
    if has_escaping_values(func, latch, &body_set, preds) {
        return None;
    }
    if has_escaping_values(func, body_block, &body_set, preds) {
        return None;
    }
    // Trip / size gates. Use total inst count across body+latch.
    // Small trips (including non-positive ones) are left to the
    // full-unroll path.
    let trip = iv_bound - iv_init + 1;
    if trip <= FULL_UNROLL_MAX {
        return None;
    }
    let body_count = func.block(body_block).insts.len();
    let latch_count = latch_blk.insts.len();
    let total_inst = body_count + latch_count;
    if total_inst == 0 {
        return None;
    }
    // Multi-store gate (across both blocks).
    let store_count = func
        .block(body_block)
        .insts
        .iter()
        .chain(latch_blk.insts.iter())
        .filter(|i| matches!(i.kind, InstKind::Store(..)))
        .count();
    if store_count > 1 {
        return None;
    }
    // Pick the largest U that divides trip exactly (this path emits
    // no remainder loop) and keeps the unrolled body within budget.
    let mut chosen_u: Option<usize> = None;
    for u in (2..=PARTIAL_UNROLL_MAX_FACTOR).rev() {
        if trip % (u as i64) != 0 {
            continue;
        }
        if total_inst * u > PARTIAL_UNROLL_BODY_BUDGET {
            continue;
        }
        chosen_u = Some(u);
        break;
    }
    let u = chosen_u?;
    Some(PartialMultiBlockShape {
        header,
        body_block,
        latch,
        iv_param,
        iv_ty,
        iv_init,
        iv_bound,
        iv_width,
        iadd_id: iv_next,
        step_const_id,
        u,
    })
}
| 2281 | |
| 2282 | fn do_partial_unroll_multiblock(func: &mut Function, shape: PartialMultiBlockShape) { |
| 2283 | let span = dummy_span(); |
| 2284 | let int_ty = IrType::Int(shape.iv_width); |
| 2285 | |
| 2286 | // ---- Snapshot body and latch (sans iadd) BEFORE we modify them ---- |
| 2287 | let body_snapshot: Vec<Inst> = func.block(shape.body_block).insts.clone(); |
| 2288 | let latch_snapshot_no_iadd: Vec<Inst> = func |
| 2289 | .block(shape.latch) |
| 2290 | .insts |
| 2291 | .iter() |
| 2292 | .filter(|i| i.id != shape.iadd_id) |
| 2293 | .cloned() |
| 2294 | .collect(); |
| 2295 | |
| 2296 | // For each clone k=1..U, allocate two fresh BlockIds (body clone + |
| 2297 | // latch-work clone). They form a chain wired between the original |
| 2298 | // latch's body work and the iadd. The original chain runs first; |
| 2299 | // each clone follows. |
| 2300 | let mut clone_blocks: Vec<(BlockId, BlockId)> = Vec::new(); |
| 2301 | for k in 1..shape.u { |
| 2302 | let body_c = func.create_block(&format!("partial_mb_body_c{}", k)); |
| 2303 | let latch_c = func.create_block(&format!("partial_mb_latch_c{}", k)); |
| 2304 | clone_blocks.push((body_c, latch_c)); |
| 2305 | // Pre-fill: each clone needs (k_const, iv_k, body insts subst, br latch_c) |
| 2306 | // (body insts subst, br to next clone or to original iadd block) |
| 2307 | let _ = k; |
| 2308 | let _ = body_c; |
| 2309 | let _ = latch_c; |
| 2310 | } |
| 2311 | |
| 2312 | // ---- 1. Original body_block: keep insts; rewire terminator from |
| 2313 | // `br latch` to `br latch (with body work)` — actually the |
| 2314 | // original body still flows into the original latch's body |
| 2315 | // work, then to the first clone. |
| 2316 | |
| 2317 | // Strategy: keep the original chain (body_block → latch) intact for |
| 2318 | // the work, but split the latch by removing the iadd. The latch's |
| 2319 | // body-work insts run, then we br to the first clone; the iadd |
| 2320 | // moves to the LAST clone's latch_c. |
| 2321 | // |
| 2322 | // After this: |
| 2323 | // header → body_block (orig) → latch (orig, no iadd) → clone1.body |
| 2324 | // → clone1.latch_c → clone2.body → clone2.latch_c → ... → |
| 2325 | // cloneU-1.latch_c (with new iadd*U) → header |
| 2326 | // |
| 2327 | // Each clone_k computes iv_k = iv + k at the top of body_c. |
| 2328 | |
| 2329 | // Remove the iadd from the original latch's insts. |
| 2330 | { |
| 2331 | let lat = func.block_mut(shape.latch); |
| 2332 | lat.insts.retain(|i| i.id != shape.iadd_id); |
| 2333 | } |
| 2334 | |
| 2335 | // ---- 2. Build each clone ---- |
| 2336 | for (k, &(body_c, latch_c)) in clone_blocks.iter().enumerate() { |
| 2337 | let k_val = (k + 1) as i64; // k=0 in Vec → iv+1 |
| 2338 | // body_c: iv_k_const + iv_k_iadd + cloned body insts + br latch_c |
| 2339 | let mut body_c_insts: Vec<Inst> = Vec::new(); |
| 2340 | let k_const_id = func.next_value_id(); |
| 2341 | func.register_type(k_const_id, int_ty.clone()); |
| 2342 | let iv_k_id = func.next_value_id(); |
| 2343 | func.register_type(iv_k_id, int_ty.clone()); |
| 2344 | body_c_insts.push(Inst { |
| 2345 | id: k_const_id, |
| 2346 | kind: InstKind::ConstInt(k_val as i128, shape.iv_width), |
| 2347 | ty: int_ty.clone(), |
| 2348 | span, |
| 2349 | }); |
| 2350 | body_c_insts.push(Inst { |
| 2351 | id: iv_k_id, |
| 2352 | kind: InstKind::IAdd(shape.iv_param, k_const_id), |
| 2353 | ty: int_ty.clone(), |
| 2354 | span, |
| 2355 | }); |
| 2356 | let mut subst: HashMap<ValueId, ValueId> = HashMap::new(); |
| 2357 | subst.insert(shape.iv_param, iv_k_id); |
| 2358 | for orig in &body_snapshot { |
| 2359 | let new_id = func.next_value_id(); |
| 2360 | func.register_type(new_id, orig.ty.clone()); |
| 2361 | subst.insert(orig.id, new_id); |
| 2362 | let new_kind = remap_kind(&orig.kind, &subst); |
| 2363 | body_c_insts.push(Inst { |
| 2364 | id: new_id, |
| 2365 | kind: new_kind, |
| 2366 | ty: orig.ty.clone(), |
| 2367 | span: orig.span, |
| 2368 | }); |
| 2369 | } |
| 2370 | func.block_mut(body_c).insts = body_c_insts; |
| 2371 | func.block_mut(body_c).terminator = Some(Terminator::Branch(latch_c, vec![])); |
| 2372 | |
| 2373 | // latch_c: cloned latch-body insts (sans iadd) with same subst. |
| 2374 | let mut latch_c_insts: Vec<Inst> = Vec::new(); |
| 2375 | for orig in &latch_snapshot_no_iadd { |
| 2376 | let new_id = func.next_value_id(); |
| 2377 | func.register_type(new_id, orig.ty.clone()); |
| 2378 | subst.insert(orig.id, new_id); |
| 2379 | let new_kind = remap_kind(&orig.kind, &subst); |
| 2380 | latch_c_insts.push(Inst { |
| 2381 | id: new_id, |
| 2382 | kind: new_kind, |
| 2383 | ty: orig.ty.clone(), |
| 2384 | span: orig.span, |
| 2385 | }); |
| 2386 | } |
| 2387 | func.block_mut(latch_c).insts = latch_c_insts; |
| 2388 | // Terminator: chain to next clone, or for the LAST clone, |
| 2389 | // emit the new iadd*U + br header. |
| 2390 | if k + 1 < clone_blocks.len() { |
| 2391 | let next_body = clone_blocks[k + 1].0; |
| 2392 | func.block_mut(latch_c).terminator = Some(Terminator::Branch(next_body, vec![])); |
| 2393 | } else { |
| 2394 | // Last clone: append the new iadd and branch to header. |
| 2395 | let new_step_const = func.next_value_id(); |
| 2396 | func.register_type(new_step_const, int_ty.clone()); |
| 2397 | func.block_mut(latch_c).insts.push(Inst { |
| 2398 | id: new_step_const, |
| 2399 | kind: InstKind::ConstInt(shape.u as i128, shape.iv_width), |
| 2400 | ty: int_ty.clone(), |
| 2401 | span, |
| 2402 | }); |
| 2403 | // Reuse the original iadd id so anything referencing it |
| 2404 | // (e.g., dominance maps) stays consistent. |
| 2405 | func.block_mut(latch_c).insts.push(Inst { |
| 2406 | id: shape.iadd_id, |
| 2407 | kind: InstKind::IAdd(shape.iv_param, new_step_const), |
| 2408 | ty: int_ty.clone(), |
| 2409 | span, |
| 2410 | }); |
| 2411 | func.block_mut(latch_c).terminator = |
| 2412 | Some(Terminator::Branch(shape.header, vec![shape.iadd_id])); |
| 2413 | } |
| 2414 | } |
| 2415 | |
| 2416 | // ---- 3. Rewire the original latch's terminator to flow into the |
| 2417 | // first clone (or directly to header if U == 2 and we |
| 2418 | // skipped clones — shouldn't happen since min U=2 means |
| 2419 | // exactly 1 clone). |
| 2420 | if let Some(&(first_clone_body, _)) = clone_blocks.first() { |
| 2421 | func.block_mut(shape.latch).terminator = |
| 2422 | Some(Terminator::Branch(first_clone_body, vec![])); |
| 2423 | } |
| 2424 | } |
| 2425 | |
| 2426 | // --------------------------------------------------------------------------- |
| 2427 | // Unit tests |
| 2428 | // --------------------------------------------------------------------------- |
| 2429 | |
| 2430 | #[cfg(test)] |
| 2431 | mod tests { |
| 2432 | use super::*; |
| 2433 | use crate::ir::types::{IntWidth, IrType}; |
| 2434 | use crate::lexer::Span; |
| 2435 | use crate::opt::pass::Pass; |
| 2436 | |
| 2437 | fn span() -> Span { |
| 2438 | super::dummy_span() |
| 2439 | } |
| 2440 | |
| 2441 | /// Build a minimal Module with one function containing a simple |
| 2442 | /// counted loop: do i = 1, N; a(i_offset) = const(42); end do |
| 2443 | /// |
| 2444 | /// The loop has 4 blocks: entry (preheader), header, latch, exit. |
| 2445 | /// The latch stores a constant to a fixed alloca (simulating a(i) = 42 |
| 2446 | /// where the GEP is not yet computed — we just store to a fixed addr). |
| 2447 | fn build_counted_loop_with_prefix(lo: i64, hi: i64, prefix: &str) -> Module { |
| 2448 | let mut m = Module::new("test".into()); |
| 2449 | let mut f = Function::new("loop_test".into(), vec![], IrType::Void); |
| 2450 | |
| 2451 | // Allocate block IDs. |
| 2452 | let header_id = f.create_block(&format!("{}_check", prefix)); |
| 2453 | let latch_id = f.create_block(&format!("{}_body", prefix)); |
| 2454 | let exit_id = f.create_block("exit"); |
| 2455 | let entry_id = f.entry; // preheader |
| 2456 | |
| 2457 | // ---- entry (preheader) ------------------------------------------- |
| 2458 | // const lo |
| 2459 | let lo_val = f.next_value_id(); |
| 2460 | f.block_mut(entry_id).insts.push(Inst { |
| 2461 | id: lo_val, |
| 2462 | ty: IrType::Int(IntWidth::I64), |
| 2463 | span: span(), |
| 2464 | kind: InstKind::ConstInt(lo as i128, IntWidth::I64), |
| 2465 | }); |
| 2466 | // alloca for the "array" (just a slot we store into) |
| 2467 | let alloca_val = f.next_value_id(); |
| 2468 | f.block_mut(entry_id).insts.push(Inst { |
| 2469 | id: alloca_val, |
| 2470 | ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I64))), |
| 2471 | span: span(), |
| 2472 | kind: InstKind::Alloca(IrType::Int(IntWidth::I64)), |
| 2473 | }); |
| 2474 | // br header(lo) |
| 2475 | f.block_mut(entry_id).terminator = Some(Terminator::Branch(header_id, vec![lo_val])); |
| 2476 | |
| 2477 | // ---- header (%i) ------------------------------------------------- |
| 2478 | let iv_param = f.next_value_id(); |
| 2479 | f.block_mut(header_id).params.push(BlockParam { |
| 2480 | id: iv_param, |
| 2481 | ty: IrType::Int(IntWidth::I64), |
| 2482 | }); |
| 2483 | // const hi |
| 2484 | let hi_val = f.next_value_id(); |
| 2485 | f.block_mut(header_id).insts.push(Inst { |
| 2486 | id: hi_val, |
| 2487 | ty: IrType::Int(IntWidth::I64), |
| 2488 | span: span(), |
| 2489 | kind: InstKind::ConstInt(hi as i128, IntWidth::I64), |
| 2490 | }); |
| 2491 | // %cmp = icmp.sle %i, hi |
| 2492 | let cmp_val = f.next_value_id(); |
| 2493 | f.block_mut(header_id).insts.push(Inst { |
| 2494 | id: cmp_val, |
| 2495 | ty: IrType::Bool, |
| 2496 | span: span(), |
| 2497 | kind: InstKind::ICmp(CmpOp::Le, iv_param, hi_val), |
| 2498 | }); |
| 2499 | // condBr %cmp, latch, exit |
| 2500 | f.block_mut(header_id).terminator = Some(Terminator::CondBranch { |
| 2501 | cond: cmp_val, |
| 2502 | true_dest: latch_id, |
| 2503 | true_args: vec![], |
| 2504 | false_dest: exit_id, |
| 2505 | false_args: vec![], |
| 2506 | }); |
| 2507 | |
| 2508 | // ---- latch ------------------------------------------------------- |
| 2509 | // Store 42 to the alloca (simulates a(i) = 42). |
| 2510 | let const42 = f.next_value_id(); |
| 2511 | f.block_mut(latch_id).insts.push(Inst { |
| 2512 | id: const42, |
| 2513 | ty: IrType::Int(IntWidth::I64), |
| 2514 | span: span(), |
| 2515 | kind: InstKind::ConstInt(42, IntWidth::I64), |
| 2516 | }); |
| 2517 | let store_id = f.next_value_id(); |
| 2518 | f.block_mut(latch_id).insts.push(Inst { |
| 2519 | id: store_id, |
| 2520 | ty: IrType::Void, |
| 2521 | span: span(), |
| 2522 | kind: InstKind::Store(const42, alloca_val), |
| 2523 | }); |
| 2524 | // %i_next = iadd %i, 1 |
| 2525 | let one = f.next_value_id(); |
| 2526 | f.block_mut(latch_id).insts.push(Inst { |
| 2527 | id: one, |
| 2528 | ty: IrType::Int(IntWidth::I64), |
| 2529 | span: span(), |
| 2530 | kind: InstKind::ConstInt(1, IntWidth::I64), |
| 2531 | }); |
| 2532 | let i_next = f.next_value_id(); |
| 2533 | f.block_mut(latch_id).insts.push(Inst { |
| 2534 | id: i_next, |
| 2535 | ty: IrType::Int(IntWidth::I64), |
| 2536 | span: span(), |
| 2537 | kind: InstKind::IAdd(iv_param, one), |
| 2538 | }); |
| 2539 | // br header(%i_next) |
| 2540 | f.block_mut(latch_id).terminator = Some(Terminator::Branch(header_id, vec![i_next])); |
| 2541 | |
| 2542 | // ---- exit -------------------------------------------------------- |
| 2543 | f.block_mut(exit_id).terminator = Some(Terminator::Return(None)); |
| 2544 | |
| 2545 | m.add_function(f); |
| 2546 | m |
| 2547 | } |
| 2548 | |
    /// Convenience wrapper: build an ordinary-`do` counted loop
    /// (block names "do_check"/"do_body", the plain-DO detection path).
    fn build_counted_loop(lo: i64, hi: i64) -> Module {
        build_counted_loop_with_prefix(lo, hi, "do")
    }
| 2552 | |
| 2553 | /// Build a 2-block counted loop where the bound is a function |
| 2554 | /// parameter (runtime). Uses an alloca + store as the body work |
| 2555 | /// (1 store keeps the multi-store regalloc-pressure heuristic |
| 2556 | /// happy) and increments by 1 in the latch. |
| 2557 | fn build_runtime_counted_loop(lo: i64) -> Module { |
| 2558 | let mut m = Module::new("test".into()); |
| 2559 | // Reserve %0 for the n_param so the function params line up. |
| 2560 | let n_param = ValueId(0); |
| 2561 | let params = vec![Param { |
| 2562 | name: "n".into(), |
| 2563 | ty: IrType::Int(IntWidth::I64), |
| 2564 | id: n_param, |
| 2565 | fortran_noalias: false, |
| 2566 | }]; |
| 2567 | let mut f = Function::new("loop_runtime".into(), params, IrType::Void); |
| 2568 | |
| 2569 | let header_id = f.create_block("do_check"); |
| 2570 | let latch_id = f.create_block("do_body"); |
| 2571 | let exit_id = f.create_block("exit"); |
| 2572 | let entry_id = f.entry; |
| 2573 | |
| 2574 | // ---- preheader ---- |
| 2575 | let lo_val = f.next_value_id(); |
| 2576 | f.block_mut(entry_id).insts.push(Inst { |
| 2577 | id: lo_val, |
| 2578 | ty: IrType::Int(IntWidth::I64), |
| 2579 | span: span(), |
| 2580 | kind: InstKind::ConstInt(lo as i128, IntWidth::I64), |
| 2581 | }); |
| 2582 | let alloca_val = f.next_value_id(); |
| 2583 | f.block_mut(entry_id).insts.push(Inst { |
| 2584 | id: alloca_val, |
| 2585 | ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I64))), |
| 2586 | span: span(), |
| 2587 | kind: InstKind::Alloca(IrType::Int(IntWidth::I64)), |
| 2588 | }); |
| 2589 | f.block_mut(entry_id).terminator = Some(Terminator::Branch(header_id, vec![lo_val])); |
| 2590 | |
| 2591 | // ---- header ---- |
| 2592 | let iv_param = f.next_value_id(); |
| 2593 | f.block_mut(header_id).params.push(BlockParam { |
| 2594 | id: iv_param, |
| 2595 | ty: IrType::Int(IntWidth::I64), |
| 2596 | }); |
| 2597 | let cmp_val = f.next_value_id(); |
| 2598 | f.block_mut(header_id).insts.push(Inst { |
| 2599 | id: cmp_val, |
| 2600 | ty: IrType::Bool, |
| 2601 | span: span(), |
| 2602 | kind: InstKind::ICmp(CmpOp::Le, iv_param, n_param), |
| 2603 | }); |
| 2604 | f.block_mut(header_id).terminator = Some(Terminator::CondBranch { |
| 2605 | cond: cmp_val, |
| 2606 | true_dest: latch_id, |
| 2607 | true_args: vec![], |
| 2608 | false_dest: exit_id, |
| 2609 | false_args: vec![], |
| 2610 | }); |
| 2611 | |
| 2612 | // ---- latch (one store as the body work; iadd iv,1; br) ---- |
| 2613 | let const_v = f.next_value_id(); |
| 2614 | f.block_mut(latch_id).insts.push(Inst { |
| 2615 | id: const_v, |
| 2616 | ty: IrType::Int(IntWidth::I64), |
| 2617 | span: span(), |
| 2618 | kind: InstKind::ConstInt(42, IntWidth::I64), |
| 2619 | }); |
| 2620 | let store_id = f.next_value_id(); |
| 2621 | f.block_mut(latch_id).insts.push(Inst { |
| 2622 | id: store_id, |
| 2623 | ty: IrType::Void, |
| 2624 | span: span(), |
| 2625 | kind: InstKind::Store(const_v, alloca_val), |
| 2626 | }); |
| 2627 | let one_val = f.next_value_id(); |
| 2628 | f.block_mut(latch_id).insts.push(Inst { |
| 2629 | id: one_val, |
| 2630 | ty: IrType::Int(IntWidth::I64), |
| 2631 | span: span(), |
| 2632 | kind: InstKind::ConstInt(1, IntWidth::I64), |
| 2633 | }); |
| 2634 | let iv_next = f.next_value_id(); |
| 2635 | f.block_mut(latch_id).insts.push(Inst { |
| 2636 | id: iv_next, |
| 2637 | ty: IrType::Int(IntWidth::I64), |
| 2638 | span: span(), |
| 2639 | kind: InstKind::IAdd(iv_param, one_val), |
| 2640 | }); |
| 2641 | f.block_mut(latch_id).terminator = Some(Terminator::Branch(header_id, vec![iv_next])); |
| 2642 | |
| 2643 | // ---- exit ---- |
| 2644 | f.block_mut(exit_id).terminator = Some(Terminator::Return(None)); |
| 2645 | |
| 2646 | m.add_function(f); |
| 2647 | m |
| 2648 | } |
| 2649 | |
| 2650 | /// Build a 3-block counted loop: header + body_block (unconditional |
| 2651 | /// br to latch) + latch (with iadd + br header). Trip is statically |
| 2652 | /// known (`hi`); body_block stores `42` to a local alloca. |
| 2653 | fn build_3block_counted_loop(lo: i64, hi: i64) -> Module { |
| 2654 | let mut m = Module::new("test".into()); |
| 2655 | let mut f = Function::new("loop_3block".into(), vec![], IrType::Void); |
| 2656 | |
| 2657 | let header_id = f.create_block("do_check"); |
| 2658 | let body_id = f.create_block("do_body_inner"); |
| 2659 | let latch_id = f.create_block("do_body"); |
| 2660 | let exit_id = f.create_block("exit"); |
| 2661 | let entry_id = f.entry; |
| 2662 | |
| 2663 | // ---- preheader ---- |
| 2664 | let lo_val = f.next_value_id(); |
| 2665 | f.block_mut(entry_id).insts.push(Inst { |
| 2666 | id: lo_val, |
| 2667 | ty: IrType::Int(IntWidth::I64), |
| 2668 | span: span(), |
| 2669 | kind: InstKind::ConstInt(lo as i128, IntWidth::I64), |
| 2670 | }); |
| 2671 | let alloca_val = f.next_value_id(); |
| 2672 | f.block_mut(entry_id).insts.push(Inst { |
| 2673 | id: alloca_val, |
| 2674 | ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I64))), |
| 2675 | span: span(), |
| 2676 | kind: InstKind::Alloca(IrType::Int(IntWidth::I64)), |
| 2677 | }); |
| 2678 | f.block_mut(entry_id).terminator = Some(Terminator::Branch(header_id, vec![lo_val])); |
| 2679 | |
| 2680 | // ---- header(iv) ---- |
| 2681 | let iv_param = f.next_value_id(); |
| 2682 | f.block_mut(header_id).params.push(BlockParam { |
| 2683 | id: iv_param, |
| 2684 | ty: IrType::Int(IntWidth::I64), |
| 2685 | }); |
| 2686 | let hi_val = f.next_value_id(); |
| 2687 | f.block_mut(header_id).insts.push(Inst { |
| 2688 | id: hi_val, |
| 2689 | ty: IrType::Int(IntWidth::I64), |
| 2690 | span: span(), |
| 2691 | kind: InstKind::ConstInt(hi as i128, IntWidth::I64), |
| 2692 | }); |
| 2693 | let cmp_val = f.next_value_id(); |
| 2694 | f.block_mut(header_id).insts.push(Inst { |
| 2695 | id: cmp_val, |
| 2696 | ty: IrType::Bool, |
| 2697 | span: span(), |
| 2698 | kind: InstKind::ICmp(CmpOp::Le, iv_param, hi_val), |
| 2699 | }); |
| 2700 | f.block_mut(header_id).terminator = Some(Terminator::CondBranch { |
| 2701 | cond: cmp_val, |
| 2702 | true_dest: body_id, |
| 2703 | true_args: vec![], |
| 2704 | false_dest: exit_id, |
| 2705 | false_args: vec![], |
| 2706 | }); |
| 2707 | |
| 2708 | // ---- body block (one store, br latch) ---- |
| 2709 | let const_v = f.next_value_id(); |
| 2710 | f.block_mut(body_id).insts.push(Inst { |
| 2711 | id: const_v, |
| 2712 | ty: IrType::Int(IntWidth::I64), |
| 2713 | span: span(), |
| 2714 | kind: InstKind::ConstInt(42, IntWidth::I64), |
| 2715 | }); |
| 2716 | let store_id = f.next_value_id(); |
| 2717 | f.block_mut(body_id).insts.push(Inst { |
| 2718 | id: store_id, |
| 2719 | ty: IrType::Void, |
| 2720 | span: span(), |
| 2721 | kind: InstKind::Store(const_v, alloca_val), |
| 2722 | }); |
| 2723 | f.block_mut(body_id).terminator = Some(Terminator::Branch(latch_id, vec![])); |
| 2724 | |
| 2725 | // ---- latch (iadd iv,1; br header) ---- |
| 2726 | let one_val = f.next_value_id(); |
| 2727 | f.block_mut(latch_id).insts.push(Inst { |
| 2728 | id: one_val, |
| 2729 | ty: IrType::Int(IntWidth::I64), |
| 2730 | span: span(), |
| 2731 | kind: InstKind::ConstInt(1, IntWidth::I64), |
| 2732 | }); |
| 2733 | let iv_next = f.next_value_id(); |
| 2734 | f.block_mut(latch_id).insts.push(Inst { |
| 2735 | id: iv_next, |
| 2736 | ty: IrType::Int(IntWidth::I64), |
| 2737 | span: span(), |
| 2738 | kind: InstKind::IAdd(iv_param, one_val), |
| 2739 | }); |
| 2740 | f.block_mut(latch_id).terminator = Some(Terminator::Branch(header_id, vec![iv_next])); |
| 2741 | |
| 2742 | // ---- exit ---- |
| 2743 | f.block_mut(exit_id).terminator = Some(Terminator::Return(None)); |
| 2744 | |
| 2745 | m.add_function(f); |
| 2746 | m |
| 2747 | } |
| 2748 | |
    #[test]
    fn partial_unrolls_3block_loop_trip16() {
        // Trip 16 > FULL_UNROLL_MAX=8. Multi-block detector should pick
        // the largest U that divides 16 and fits the body budget; with
        // a small body this is U=4. The transform allocates 2 fresh
        // BlockIds per clone (body_c + latch_c) and wires them between
        // the original latch and the header.
        let mut m = build_3block_counted_loop(1, 16);
        let blocks_before = m.functions[0].blocks.len();
        let pass = LoopUnroll;
        let changed = pass.run(&mut m);
        assert!(
            changed,
            "3-block trip-16 loop should multi-block partial-unroll"
        );
        let f = &m.functions[0];
        // For U=4, we add (U-1) * 2 = 6 new blocks (body_c + latch_c per clone).
        assert_eq!(
            f.blocks.len(),
            blocks_before + 6,
            "expected 6 new clone blocks; got {} → {}",
            blocks_before,
            f.blocks.len()
        );
        // The original latch lost its iadd; the new iadd lives in the
        // last clone's latch_c. Count IAdd instances across the function:
        // one in the last clone's latch_c. (The original iadd id was
        // reused there.)
        let total_iadd = f
            .blocks
            .iter()
            .flat_map(|b| b.insts.iter())
            .filter(|i| matches!(i.kind, InstKind::IAdd(..)))
            .count();
        // Each clone's body_c has one iadd (computing iv_k). The final
        // clone's latch_c also has one iadd (the loop step). With U=4:
        // 3 body_c iadds + 1 final-step iadd = 4 total.
        assert_eq!(
            total_iadd, 4,
            "expected 4 IAdds total (3 iv_k + 1 step), got {}",
            total_iadd
        );
    }
| 2792 | |
| 2793 | #[test] |
| 2794 | fn partial_unrolls_runtime_trip_2block_loop() { |
| 2795 | // The bound is a function parameter (runtime), so the static |
| 2796 | // partial-unroll detector won't fire — but the runtime detector |
| 2797 | // should pick up the same shape and emit a head_bound |
| 2798 | // computation in the preheader plus a remainder loop after |
| 2799 | // the unrolled main loop. |
| 2800 | let mut m = build_runtime_counted_loop(1); |
| 2801 | let blocks_before = m.functions[0].blocks.len(); |
| 2802 | let pass = LoopUnroll; |
| 2803 | let changed = pass.run(&mut m); |
| 2804 | assert!( |
| 2805 | changed, |
| 2806 | "runtime-trip 2-block loop should partial-unroll" |
| 2807 | ); |
| 2808 | let f = &m.functions[0]; |
| 2809 | // Two new blocks: header_remain + latch_remain. |
| 2810 | assert_eq!( |
| 2811 | f.blocks.len(), |
| 2812 | blocks_before + 2, |
| 2813 | "expected 2 new blocks (remainder header + latch); got {} → {}", |
| 2814 | blocks_before, |
| 2815 | f.blocks.len() |
| 2816 | ); |
| 2817 | // Preheader should contain the runtime arithmetic for head_bound: |
| 2818 | // ISub, IAdd, IMod feeding into the header's icmp via a fresh value. |
| 2819 | let preheader = &f.blocks[0]; |
| 2820 | let kinds: Vec<&InstKind> = preheader.insts.iter().map(|i| &i.kind).collect(); |
| 2821 | let has_imod = kinds |
| 2822 | .iter() |
| 2823 | .any(|k| matches!(k, InstKind::IMod(..))); |
| 2824 | assert!(has_imod, "preheader should compute head_bound via IMod"); |
| 2825 | } |
| 2826 | |
| 2827 | #[test] |
| 2828 | fn unrolls_trip4() { |
| 2829 | let mut m = build_counted_loop(1, 4); |
| 2830 | let pass = LoopUnroll; |
| 2831 | let changed = pass.run(&mut m); |
| 2832 | assert!(changed, "expected unroll to fire"); |
| 2833 | |
| 2834 | let f = &m.functions[0]; |
| 2835 | // After unrolling: preheader, 4 cloned iteration blocks, exit = 6 blocks. |
| 2836 | // The original header and latch are pruned (marked Unreachable then removed). |
| 2837 | assert_eq!(f.blocks.len(), 6, "expected entry + 4 iter blocks + exit"); |
| 2838 | |
| 2839 | // Each iteration block has a ConstInt for that iteration's IV value. |
| 2840 | let all_iv_consts: Vec<i128> = f |
| 2841 | .blocks |
| 2842 | .iter() |
| 2843 | .flat_map(|b| b.insts.iter()) |
| 2844 | .filter_map(|i| { |
| 2845 | if let InstKind::ConstInt(v, _) = i.kind { |
| 2846 | Some(v) |
| 2847 | } else { |
| 2848 | None |
| 2849 | } |
| 2850 | }) |
| 2851 | .collect(); |
| 2852 | assert!(all_iv_consts.contains(&1), "IV=1 should appear"); |
| 2853 | assert!(all_iv_consts.contains(&2), "IV=2 should appear"); |
| 2854 | assert!(all_iv_consts.contains(&3), "IV=3 should appear"); |
| 2855 | assert!(all_iv_consts.contains(&4), "IV=4 should appear"); |
| 2856 | } |
| 2857 | |
| 2858 | #[test] |
| 2859 | fn partial_unrolls_trip10() { |
| 2860 | // Trip 10 > FULL_UNROLL_MAX=8 — full unroll skips it, but |
| 2861 | // partial unroll picks U=2 (largest factor of 10 that fits |
| 2862 | // the body budget) and doubles the body in place. |
| 2863 | let mut m = build_counted_loop(1, 10); |
| 2864 | let pass = LoopUnroll; |
| 2865 | let changed = pass.run(&mut m); |
| 2866 | assert!(changed, "trip-10 loop should partial-unroll"); |
| 2867 | // Same number of blocks (loop is preserved); body inst count |
| 2868 | // grew because we cloned the body once with iv → iv+1. |
| 2869 | let f = &m.functions[0]; |
| 2870 | let block_count = f.blocks.len(); |
| 2871 | assert!( |
| 2872 | block_count >= 3, |
| 2873 | "expected loop structure preserved, got {} blocks", |
| 2874 | block_count |
| 2875 | ); |
| 2876 | } |
| 2877 | |
| 2878 | #[test] |
| 2879 | fn does_not_partial_unroll_trip_above_threshold_with_no_factor() { |
| 2880 | // Trip 11 is prime and > FULL_UNROLL_MAX — neither full nor |
| 2881 | // partial unroll fires. |
| 2882 | let mut m = build_counted_loop(1, 11); |
| 2883 | let pass = LoopUnroll; |
| 2884 | let changed = pass.run(&mut m); |
| 2885 | assert!(!changed, "trip-11 (prime) loop should not unroll"); |
| 2886 | } |
| 2887 | |
| 2888 | #[test] |
| 2889 | fn unrolls_trip10_do_concurrent() { |
| 2890 | let mut m = build_counted_loop_with_prefix(1, 10, "doconc"); |
| 2891 | let pass = LoopUnroll; |
| 2892 | let changed = pass.run(&mut m); |
| 2893 | assert!( |
| 2894 | changed, |
| 2895 | "DO CONCURRENT trip-count-10 loop should fully unroll" |
| 2896 | ); |
| 2897 | |
| 2898 | let f = &m.functions[0]; |
| 2899 | assert_eq!(f.blocks.len(), 12, "expected entry + 10 iter blocks + exit"); |
| 2900 | } |
| 2901 | |
| 2902 | #[test] |
| 2903 | fn does_not_unroll_empty_loop() { |
| 2904 | let mut m = build_counted_loop(5, 3); // lo > hi → trip count 0 |
| 2905 | let pass = LoopUnroll; |
| 2906 | let changed = pass.run(&mut m); |
| 2907 | assert!(!changed, "empty loop should be left for DCE"); |
| 2908 | } |
| 2909 | |
    #[test]
    fn unrolls_reduction_loop_threading_accumulator() {
        // A reduction loop has 2 block params (iv + accumulator).
        // The unroller should clone the body once per iteration and
        // thread the accumulator value across iterations: iter 0
        // reads `acc_init`, iter 1 reads iter 0's `new_acc`, etc. The
        // final branch to `exit` passes the last iteration's
        // accumulator value to the exit block param.
        //
        // IR built here: do i = 1, 4 { acc = acc + i } — i.e. sum 1..=4
        // with both `iv` and `acc` carried as header block params.
        let mut m = Module::new("test".into());
        let mut f = Function::new("reduce".into(), vec![], IrType::Void);
        let header_id = f.create_block("header");
        let latch_id = f.create_block("latch");
        let exit_id = f.create_block("exit");
        let entry_id = f.entry;

        // Entry: br header(1, 0)
        let lo = f.next_value_id();
        f.block_mut(entry_id).insts.push(Inst {
            id: lo,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::ConstInt(1, IntWidth::I64),
        });
        let z = f.next_value_id();
        f.block_mut(entry_id).insts.push(Inst {
            id: z,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::ConstInt(0, IntWidth::I64),
        });
        f.block_mut(entry_id).terminator = Some(Terminator::Branch(header_id, vec![lo, z]));

        // Header: 2 params (iv, acc)
        let iv = f.next_value_id();
        let acc = f.next_value_id();
        f.block_mut(header_id).params.push(BlockParam {
            id: iv,
            ty: IrType::Int(IntWidth::I64),
        });
        f.block_mut(header_id).params.push(BlockParam {
            id: acc,
            ty: IrType::Int(IntWidth::I64),
        });
        let hi = f.next_value_id();
        f.block_mut(header_id).insts.push(Inst {
            id: hi,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::ConstInt(4, IntWidth::I64),
        });
        let cmp = f.next_value_id();
        f.block_mut(header_id).insts.push(Inst {
            id: cmp,
            ty: IrType::Bool,
            span: span(),
            kind: InstKind::ICmp(CmpOp::Le, iv, hi),
        });
        // On loop exit the CondBranch passes `acc` to the exit block param.
        f.block_mut(header_id).terminator = Some(Terminator::CondBranch {
            cond: cmp,
            true_dest: latch_id,
            true_args: vec![],
            false_dest: exit_id,
            false_args: vec![acc],
        });

        // Latch: acc_next = acc + iv; i_next = iv + 1; br header(i_next, acc_next)
        let one = f.next_value_id();
        f.block_mut(latch_id).insts.push(Inst {
            id: one,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::ConstInt(1, IntWidth::I64),
        });
        let acc_next = f.next_value_id();
        f.block_mut(latch_id).insts.push(Inst {
            id: acc_next,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::IAdd(acc, iv),
        });
        let i_next = f.next_value_id();
        f.block_mut(latch_id).insts.push(Inst {
            id: i_next,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::IAdd(iv, one),
        });
        f.block_mut(latch_id).terminator =
            Some(Terminator::Branch(header_id, vec![i_next, acc_next]));

        // Exit: acc param, return
        let acc_out = f.next_value_id();
        f.block_mut(exit_id).params.push(BlockParam {
            id: acc_out,
            ty: IrType::Int(IntWidth::I64),
        });
        f.block_mut(exit_id).terminator = Some(Terminator::Return(None));

        m.add_function(f);

        let pass = LoopUnroll;
        let changed = pass.run(&mut m);
        assert!(changed, "reduction loop should now unroll");

        let f = &m.functions[0];

        // The exit block must still exist and its (only) predecessor
        // path now feeds it a non-acc-param value — i.e. a value
        // defined in one of the cloned iteration blocks.
        let exit_blk = f
            .blocks
            .iter()
            .find(|b| b.id == exit_id)
            .expect("exit survives");
        assert_eq!(exit_blk.params.len(), 1, "exit still has its acc_out param");

        // Find the predecessor block whose terminator branches to the
        // exit block. Its single arg should be the unrolled chain's
        // final accumulator value, NOT the original `acc` header
        // param (which is now unreachable).
        //
        // Only unconditional `Branch` terminators are inspected here:
        // after a full unroll the header's CondBranch is expected to be
        // gone, so the edge into `exit` should be a plain branch from a
        // cloned iteration block.
        let mut found_branch_to_exit = false;
        for blk in &f.blocks {
            if let Some(Terminator::Branch(dest, args)) = &blk.terminator {
                if *dest == exit_id {
                    found_branch_to_exit = true;
                    assert_eq!(args.len(), 1, "exit takes one arg (the accumulator)");
                    assert_ne!(args[0], acc, "must not reference the pruned header acc param");
                    assert_ne!(args[0], iv, "must not reference the pruned header iv param");
                    break;
                }
            }
        }
        assert!(
            found_branch_to_exit,
            "must find an unrolled iter block branching to the original exit"
        );

        // Trip count = 4 → expect 4 cloned IAdd-of-acc instructions in
        // the function (one per iteration). The original latch's IAdd
        // is now under an Unreachable terminator and will be pruned by
        // a later DCE pass; for this test we just count the clones.
        let total_iadds: usize = f
            .blocks
            .iter()
            .flat_map(|b| b.insts.iter())
            .filter(|inst| matches!(inst.kind, InstKind::IAdd(..)))
            .count();
        assert!(
            total_iadds >= 4,
            "expected at least 4 IAdd insts (one acc-update per iteration), found {}",
            total_iadds
        );
    }
| 3063 | |
    /// Negative test: the latch reloads `x` and `y` from memory every
    /// iteration and stores `y + x` back into the `y` slot, so each
    /// iteration's loads observe the previous iteration's store. The pass
    /// must leave this loop alone. (The exact rejection reason lives in
    /// the detector — presumably the loads/extra store fail its body-shape
    /// checks; this test only pins the "no change" outcome.)
    #[test]
    fn does_not_unroll_load_bearing_loop() {
        let mut m = Module::new("test".into());
        let mut f = Function::new("load_loop".into(), vec![], IrType::Void);

        let header_id = f.create_block("header");
        let latch_id = f.create_block("latch");
        let exit_id = f.create_block("exit");
        let entry_id = f.entry;

        // ---- entry: const lo, x/y slots, x := 5, y := 1, br header(lo) ----
        let lo = f.next_value_id();
        f.block_mut(entry_id).insts.push(Inst {
            id: lo,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::ConstInt(1, IntWidth::I64),
        });
        let x_slot = f.next_value_id();
        f.block_mut(entry_id).insts.push(Inst {
            id: x_slot,
            ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I64))),
            span: span(),
            kind: InstKind::Alloca(IrType::Int(IntWidth::I64)),
        });
        let y_slot = f.next_value_id();
        f.block_mut(entry_id).insts.push(Inst {
            id: y_slot,
            ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I64))),
            span: span(),
            kind: InstKind::Alloca(IrType::Int(IntWidth::I64)),
        });
        let x_init = f.next_value_id();
        f.block_mut(entry_id).insts.push(Inst {
            id: x_init,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::ConstInt(5, IntWidth::I64),
        });
        let store_x = f.next_value_id();
        f.block_mut(entry_id).insts.push(Inst {
            id: store_x,
            ty: IrType::Void,
            span: span(),
            kind: InstKind::Store(x_init, x_slot),
        });
        let y_init = f.next_value_id();
        f.block_mut(entry_id).insts.push(Inst {
            id: y_init,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::ConstInt(1, IntWidth::I64),
        });
        let store_y = f.next_value_id();
        f.block_mut(entry_id).insts.push(Inst {
            id: store_y,
            ty: IrType::Void,
            span: span(),
            kind: InstKind::Store(y_init, y_slot),
        });
        f.block_mut(entry_id).terminator = Some(Terminator::Branch(header_id, vec![lo]));

        // ---- header(%iv): icmp.sle %iv, const(4); condBr latch/exit ----
        let iv = f.next_value_id();
        f.block_mut(header_id).params.push(BlockParam {
            id: iv,
            ty: IrType::Int(IntWidth::I64),
        });
        let hi = f.next_value_id();
        f.block_mut(header_id).insts.push(Inst {
            id: hi,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::ConstInt(4, IntWidth::I64),
        });
        let cmp = f.next_value_id();
        f.block_mut(header_id).insts.push(Inst {
            id: cmp,
            ty: IrType::Bool,
            span: span(),
            kind: InstKind::ICmp(CmpOp::Le, iv, hi),
        });
        f.block_mut(header_id).terminator = Some(Terminator::CondBranch {
            cond: cmp,
            true_dest: latch_id,
            true_args: vec![],
            false_dest: exit_id,
            false_args: vec![],
        });

        // ---- latch: y := load(x) + load(y); iv+1; br header ----
        let x_val = f.next_value_id();
        f.block_mut(latch_id).insts.push(Inst {
            id: x_val,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::Load(x_slot),
        });
        let y_val = f.next_value_id();
        f.block_mut(latch_id).insts.push(Inst {
            id: y_val,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::Load(y_slot),
        });
        let sum = f.next_value_id();
        f.block_mut(latch_id).insts.push(Inst {
            id: sum,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::IAdd(x_val, y_val),
        });
        let store_sum = f.next_value_id();
        f.block_mut(latch_id).insts.push(Inst {
            id: store_sum,
            ty: IrType::Void,
            span: span(),
            kind: InstKind::Store(sum, y_slot),
        });
        let one = f.next_value_id();
        f.block_mut(latch_id).insts.push(Inst {
            id: one,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::ConstInt(1, IntWidth::I64),
        });
        let next = f.next_value_id();
        f.block_mut(latch_id).insts.push(Inst {
            id: next,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::IAdd(iv, one),
        });
        f.block_mut(latch_id).terminator = Some(Terminator::Branch(header_id, vec![next]));

        // ---- exit ----
        f.block_mut(exit_id).terminator = Some(Terminator::Return(None));
        m.add_function(f);

        let pass = LoopUnroll;
        let changed = pass.run(&mut m);
        assert!(
            !changed,
            "load-bearing loop should stay out of the full-unroll fast path"
        );
    }
| 3206 | |
| 3207 | #[test] |
| 3208 | fn unrolls_trip1() { |
| 3209 | let mut m = build_counted_loop(3, 3); // lo == hi → 1 iteration |
| 3210 | let pass = LoopUnroll; |
| 3211 | let changed = pass.run(&mut m); |
| 3212 | assert!(changed, "trip-1 loop should unroll"); |
| 3213 | let f = &m.functions[0]; |
| 3214 | // entry + 1 cloned iteration block + exit = 3 blocks |
| 3215 | assert_eq!(f.blocks.len(), 3, "expected entry + 1 iter block + exit"); |
| 3216 | } |
| 3217 | |
| 3218 | /// Regression: mem2reg places block params at dominance-frontier blocks, |
| 3219 | /// which can cause an outer-loop latch to reference the inner loop's header |
| 3220 | /// block param (%iv) in its branch terminator. After unrolling removes the |
| 3221 | /// inner header, %iv becomes undefined → IR verifier panic. |
| 3222 | /// |
| 3223 | /// `has_escaping_values` must check block params (not just instruction |
| 3224 | /// results) AND must be called for the header block (not just body_blocks). |
| 3225 | #[test] |
| 3226 | fn does_not_unroll_when_header_param_escapes_into_outer_block() { |
| 3227 | // Build: inner loop (entry → header(%iv) → latch → header, | → exit) |
| 3228 | // plus an outer_latch block that passes %iv to outer_dest. |
| 3229 | // outer_latch is not part of the inner nl.body but references %iv. |
| 3230 | let mut m = Module::new("test".into()); |
| 3231 | let mut f = Function::new("escape_test".into(), vec![], IrType::Void); |
| 3232 | |
| 3233 | let header_id = f.create_block("header"); |
| 3234 | let latch_id = f.create_block("latch"); |
| 3235 | let exit_id = f.create_block("exit"); |
| 3236 | let outer_latch = f.create_block("outer_latch"); |
| 3237 | let outer_dest = f.create_block("outer_dest"); |
| 3238 | let entry_id = f.entry; |
| 3239 | |
| 3240 | // Entry: const 1; br header(1) |
| 3241 | let lo = f.next_value_id(); |
| 3242 | f.block_mut(entry_id).insts.push(Inst { |
| 3243 | id: lo, |
| 3244 | ty: IrType::Int(IntWidth::I64), |
| 3245 | span: span(), |
| 3246 | kind: InstKind::ConstInt(1, IntWidth::I64), |
| 3247 | }); |
| 3248 | f.block_mut(entry_id).terminator = Some(Terminator::Branch(header_id, vec![lo])); |
| 3249 | |
| 3250 | // Header: block param %iv; icmp %iv ≤ 3; condBr latch, exit |
| 3251 | let iv = f.next_value_id(); |
| 3252 | f.block_mut(header_id).params.push(BlockParam { |
| 3253 | id: iv, |
| 3254 | ty: IrType::Int(IntWidth::I64), |
| 3255 | }); |
| 3256 | let hi_c = f.next_value_id(); |
| 3257 | f.block_mut(header_id).insts.push(Inst { |
| 3258 | id: hi_c, |
| 3259 | ty: IrType::Int(IntWidth::I64), |
| 3260 | span: span(), |
| 3261 | kind: InstKind::ConstInt(3, IntWidth::I64), |
| 3262 | }); |
| 3263 | let cmp = f.next_value_id(); |
| 3264 | f.block_mut(header_id).insts.push(Inst { |
| 3265 | id: cmp, |
| 3266 | ty: IrType::Bool, |
| 3267 | span: span(), |
| 3268 | kind: InstKind::ICmp(CmpOp::Le, iv, hi_c), |
| 3269 | }); |
| 3270 | f.block_mut(header_id).terminator = Some(Terminator::CondBranch { |
| 3271 | cond: cmp, |
| 3272 | true_dest: latch_id, |
| 3273 | true_args: vec![], |
| 3274 | false_dest: exit_id, |
| 3275 | false_args: vec![], |
| 3276 | }); |
| 3277 | |
| 3278 | // Latch: %next = iadd %iv, 1; br header(%next) |
| 3279 | let one = f.next_value_id(); |
| 3280 | f.block_mut(latch_id).insts.push(Inst { |
| 3281 | id: one, |
| 3282 | ty: IrType::Int(IntWidth::I64), |
| 3283 | span: span(), |
| 3284 | kind: InstKind::ConstInt(1, IntWidth::I64), |
| 3285 | }); |
| 3286 | let nxt = f.next_value_id(); |
| 3287 | f.block_mut(latch_id).insts.push(Inst { |
| 3288 | id: nxt, |
| 3289 | ty: IrType::Int(IntWidth::I64), |
| 3290 | span: span(), |
| 3291 | kind: InstKind::IAdd(iv, one), |
| 3292 | }); |
| 3293 | f.block_mut(latch_id).terminator = Some(Terminator::Branch(header_id, vec![nxt])); |
| 3294 | |
| 3295 | // Exit: ret void |
| 3296 | f.block_mut(exit_id).terminator = Some(Terminator::Return(None)); |
| 3297 | |
| 3298 | // outer_latch: br outer_dest(%iv) |
| 3299 | // Simulates an outer-loop latch that mem2reg threaded %iv through. |
| 3300 | // %iv is the header's block param — it escapes into this external block. |
| 3301 | f.block_mut(outer_latch).terminator = Some(Terminator::Branch(outer_dest, vec![iv])); |
| 3302 | |
| 3303 | // outer_dest(%x): ret void |
| 3304 | let x = f.next_value_id(); |
| 3305 | f.block_mut(outer_dest).params.push(BlockParam { |
| 3306 | id: x, |
| 3307 | ty: IrType::Int(IntWidth::I64), |
| 3308 | }); |
| 3309 | f.block_mut(outer_dest).terminator = Some(Terminator::Return(None)); |
| 3310 | |
| 3311 | m.add_function(f); |
| 3312 | |
| 3313 | let pass = LoopUnroll; |
| 3314 | let changed = pass.run(&mut m); |
| 3315 | assert!( |
| 3316 | !changed, |
| 3317 | "must not unroll when header block param escapes into an outer-loop block" |
| 3318 | ); |
| 3319 | } |
| 3320 | } |
| 3321 |