| 1 | //! Loop fission (distribution) pass. |
| 2 | //! |
| 3 | //! Splits a loop with independent statement groups into two loops. |
| 4 | //! Uses the LLVM-inspired "clone + replace-with-undef" pattern: |
| 5 | //! 1. Clone the entire loop |
| 6 | //! 2. In the original: replace group B's stores with no-ops (Undef stores) |
| 7 | //! 3. In the clone: replace group A's stores with no-ops |
| 8 | //! 4. Wire original exit → clone's preheader |
| 9 | //! 5. Let DCE clean up the dead instructions in later passes |
| 10 | //! |
| 11 | //! This avoids SSA domination bugs from selective instruction removal. |
| 12 | |
| 13 | use super::dep_analysis::{collect_mem_refs, test_dependence}; |
| 14 | use super::loop_utils::{clone_loop, find_preheader}; |
| 15 | use super::pass::Pass; |
| 16 | use crate::ir::inst::*; |
| 17 | use crate::ir::walk::{find_natural_loops, predecessors}; |
| 18 | use std::collections::HashSet; |
| 19 | |
/// Minimum instruction count in the computation block for fission to be
/// attempted — loops smaller than this aren't worth splitting.
const FISSION_MIN_BODY: usize = 4;

/// Loop-distribution pass; fissions at most one loop per `run` call.
pub struct LoopFission;
| 23 | |
| 24 | impl Pass for LoopFission { |
| 25 | fn name(&self) -> &'static str { |
| 26 | "loop-fission" |
| 27 | } |
| 28 | |
| 29 | fn run(&self, module: &mut Module) -> bool { |
| 30 | let mut changed = false; |
| 31 | for func in &mut module.functions { |
| 32 | if fission_in_function(func) { |
| 33 | changed = true; |
| 34 | } |
| 35 | } |
| 36 | changed |
| 37 | } |
| 38 | } |
| 39 | |
| 40 | fn fission_in_function(func: &mut Function) -> bool { |
| 41 | let loops = find_natural_loops(func); |
| 42 | let preds = predecessors(func); |
| 43 | |
| 44 | for lp in &loops { |
| 45 | let Some(_ph_id) = find_preheader(func, lp, &preds) else { |
| 46 | continue; |
| 47 | }; |
| 48 | if lp.latches.len() != 1 { |
| 49 | continue; |
| 50 | } |
| 51 | let latch_id = lp.latches[0]; |
| 52 | |
| 53 | let hdr = func.block(lp.header); |
| 54 | if hdr.params.len() != 1 { |
| 55 | continue; |
| 56 | } |
| 57 | let iv = hdr.params[0].id; |
| 58 | |
| 59 | // Find the single computation body block. |
| 60 | let body_block = find_computation_block(func, lp, latch_id); |
| 61 | let Some(body_bid) = body_block else { continue }; |
| 62 | |
| 63 | let block = func.block(body_bid); |
| 64 | if block.insts.len() < FISSION_MIN_BODY { |
| 65 | continue; |
| 66 | } |
| 67 | |
| 68 | // Need exactly 2 stores to different arrays. |
| 69 | let stores: Vec<(usize, ValueId)> = block |
| 70 | .insts |
| 71 | .iter() |
| 72 | .enumerate() |
| 73 | .filter(|(_, i)| matches!(i.kind, InstKind::Store(..))) |
| 74 | .map(|(idx, i)| (idx, i.id)) |
| 75 | .collect(); |
| 76 | if stores.len() != 2 { |
| 77 | continue; |
| 78 | } |
| 79 | |
| 80 | // Check independence via dep analysis. |
| 81 | let mut ivs = HashSet::new(); |
| 82 | ivs.insert(iv); |
| 83 | let mem_refs = collect_mem_refs(func, &lp.body, &ivs); |
| 84 | let writes: Vec<_> = mem_refs.iter().filter(|r| r.is_write).collect(); |
| 85 | if writes.len() != 2 { |
| 86 | continue; |
| 87 | } |
| 88 | if writes[0].base == writes[1].base { |
| 89 | continue; |
| 90 | } |
| 91 | |
| 92 | let mut has_cross_dep = false; |
| 93 | for i in 0..mem_refs.len() { |
| 94 | for j in (i + 1)..mem_refs.len() { |
| 95 | if !mem_refs[i].is_write && !mem_refs[j].is_write { |
| 96 | continue; |
| 97 | } |
| 98 | if mem_refs[i].base == mem_refs[j].base { |
| 99 | continue; |
| 100 | } |
| 101 | let dep = test_dependence(&mem_refs[i], &mem_refs[j]); |
| 102 | if dep.dependent { |
| 103 | has_cross_dep = true; |
| 104 | break; |
| 105 | } |
| 106 | } |
| 107 | if has_cross_dep { |
| 108 | break; |
| 109 | } |
| 110 | } |
| 111 | if has_cross_dep { |
| 112 | continue; |
| 113 | } |
| 114 | |
| 115 | // Find the exit block. |
| 116 | let exit_id = find_loop_exit(func, lp); |
| 117 | let Some(exit_id) = exit_id else { continue }; |
| 118 | |
| 119 | // Compute backward slices for each store. |
| 120 | // Clone the loop. |
| 121 | let (block_map, _) = clone_loop(func, lp); |
| 122 | |
| 123 | // In original body: neutralize store B (replace value with undef). |
| 124 | neutralize_store(func, body_bid, stores[1].0); |
| 125 | |
| 126 | // In clone body: neutralize store A. |
| 127 | let clone_body = block_map[&body_bid]; |
| 128 | neutralize_store(func, clone_body, stores[0].0); |
| 129 | |
| 130 | // Wire original exit → clone's header via a bridge block. |
| 131 | let clone_header = block_map[&lp.header]; |
| 132 | let ph_id = find_preheader(func, lp, &predecessors(func)).unwrap(); |
| 133 | let init_val = match &func.block(ph_id).terminator { |
| 134 | Some(Terminator::Branch(_, args)) if !args.is_empty() => args[0], |
| 135 | _ => continue, |
| 136 | }; |
| 137 | |
| 138 | let bridge = func.create_block("fission_bridge"); |
| 139 | func.block_mut(bridge).terminator = Some(Terminator::Branch(clone_header, vec![init_val])); |
| 140 | |
| 141 | // Redirect original cmp's exit → bridge. |
| 142 | for &bid in &lp.body { |
| 143 | let block = func.block_mut(bid); |
| 144 | if let Some(Terminator::CondBranch { |
| 145 | false_dest, |
| 146 | false_args, |
| 147 | .. |
| 148 | }) = &mut block.terminator |
| 149 | { |
| 150 | if *false_dest == exit_id { |
| 151 | *false_dest = bridge; |
| 152 | false_args.clear(); |
| 153 | break; |
| 154 | } |
| 155 | } |
| 156 | } |
| 157 | |
| 158 | // Clone's cmp exit already points to exit_id (wasn't remapped |
| 159 | // since it's outside the body). Correct. |
| 160 | |
| 161 | return true; |
| 162 | } |
| 163 | false |
| 164 | } |
| 165 | |
| 166 | /// Replace a store instruction's value operand with Undef, effectively |
| 167 | /// making it a dead store that DSE/DCE can clean up. |
| 168 | fn neutralize_store(func: &mut Function, block_id: BlockId, store_idx: usize) { |
| 169 | let block = func.block_mut(block_id); |
| 170 | if store_idx >= block.insts.len() { |
| 171 | return; |
| 172 | } |
| 173 | if let InstKind::Store(_, _) = block.insts[store_idx].kind { |
| 174 | // Replace with a store of undef — the store is now dead. |
| 175 | // We keep the store instruction (not remove it) to preserve |
| 176 | // SSA structure. DCE will clean it up later. |
| 177 | let undef_id = ValueId(u32::MAX - 1); // sentinel; will be replaced below |
| 178 | let _ = undef_id; |
| 179 | } |
| 180 | // Actually, the simplest approach: just remove the store entirely. |
| 181 | // Stores have no result value used by other instructions, so removing |
| 182 | // them can't break SSA. The computation leading to the store value |
| 183 | // becomes dead and DCE removes it. |
| 184 | block.insts.remove(store_idx); |
| 185 | } |
| 186 | |
| 187 | fn find_computation_block( |
| 188 | func: &Function, |
| 189 | lp: &crate::ir::walk::NaturalLoop, |
| 190 | latch_id: BlockId, |
| 191 | ) -> Option<BlockId> { |
| 192 | let mut comp = None; |
| 193 | for &bid in &lp.body { |
| 194 | if bid == lp.header || bid == latch_id { |
| 195 | continue; |
| 196 | } |
| 197 | let block = func.block(bid); |
| 198 | if block |
| 199 | .insts |
| 200 | .iter() |
| 201 | .any(|i| matches!(i.kind, InstKind::Store(..))) |
| 202 | { |
| 203 | if comp.is_some() { |
| 204 | return None; |
| 205 | } |
| 206 | comp = Some(bid); |
| 207 | } |
| 208 | } |
| 209 | comp |
| 210 | } |
| 211 | |
| 212 | fn backward_slice( |
| 213 | func: &Function, |
| 214 | root: ValueId, |
| 215 | loop_defs: &HashSet<ValueId>, |
| 216 | ) -> HashSet<ValueId> { |
| 217 | let mut slice = HashSet::new(); |
| 218 | let mut worklist = vec![root]; |
| 219 | while let Some(vid) = worklist.pop() { |
| 220 | if !slice.insert(vid) { |
| 221 | continue; |
| 222 | } |
| 223 | for block in &func.blocks { |
| 224 | for inst in &block.insts { |
| 225 | if inst.id == vid { |
| 226 | for operand in crate::ir::walk::inst_uses(&inst.kind) { |
| 227 | if loop_defs.contains(&operand) && !slice.contains(&operand) { |
| 228 | worklist.push(operand); |
| 229 | } |
| 230 | } |
| 231 | } |
| 232 | } |
| 233 | } |
| 234 | } |
| 235 | slice |
| 236 | } |
| 237 | |
| 238 | fn find_loop_exit(func: &Function, lp: &crate::ir::walk::NaturalLoop) -> Option<BlockId> { |
| 239 | for &bid in &lp.body { |
| 240 | let block = func.block(bid); |
| 241 | if let Some(Terminator::CondBranch { false_dest, .. }) = &block.terminator { |
| 242 | if !lp.body.contains(false_dest) { |
| 243 | return Some(*false_dest); |
| 244 | } |
| 245 | } |
| 246 | } |
| 247 | None |
| 248 | } |
| 249 | |
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::types::{IntWidth, IrType};
    use crate::ir::verify::verify_module;
    use crate::lexer::{Position, Span};
    use crate::opt::pass::Pass;

    /// Zero-width dummy span for synthetic test instructions.
    fn span() -> Span {
        let pos = Position { line: 0, col: 0 };
        Span {
            file_id: 0,
            start: pos,
            end: pos,
        }
    }

    #[test]
    fn fission_no_op_on_empty() {
        // A function with no loops must be left untouched.
        let mut m = Module::new("test".into());
        let mut f = Function::new("test".into(), vec![], IrType::Void);
        f.block_mut(f.entry).terminator = Some(Terminator::Return(None));
        m.add_function(f);
        let pass = LoopFission;
        let changed = pass.run(&mut m);
        assert!(!changed, "no loops → no fission");
    }

    /// Builds the canonical fissionable loop:
    ///   entry → preheader → header(iv) → body → latch
    ///   latch: cond-branch back to header, or out to exit(iv)
    /// The body stores constants into two distinct 16-element arrays,
    /// and the exit block takes the final IV as a block argument — so
    /// rerouting the exit edge to the bridge exercises the
    /// `false_args.clear()` path, which the verifier then checks.
    #[test]
    fn fission_clears_exit_args_when_rerouting_to_bridge() {
        let mut m = Module::new("test".into());
        let mut f = Function::new("test".into(), vec![], IrType::Void);

        let preheader = f.create_block("preheader");
        let header = f.create_block("header");
        let body = f.create_block("body");
        let latch = f.create_block("latch");
        let exit = f.create_block("exit");
        let entry = f.entry;

        // Two distinct alloca'd arrays so the two stores have
        // different bases (a fission precondition).
        let arr_ty = IrType::Array(Box::new(IrType::Int(IntWidth::I32)), 16);
        let arr_a = f.next_value_id();
        f.register_type(arr_a, IrType::Ptr(Box::new(arr_ty.clone())));
        f.block_mut(entry).insts.push(Inst {
            id: arr_a,
            ty: IrType::Ptr(Box::new(arr_ty.clone())),
            span: span(),
            kind: InstKind::Alloca(arr_ty.clone()),
        });
        let arr_b = f.next_value_id();
        f.register_type(arr_b, IrType::Ptr(Box::new(arr_ty.clone())));
        f.block_mut(entry).insts.push(Inst {
            id: arr_b,
            ty: IrType::Ptr(Box::new(arr_ty)),
            span: span(),
            kind: InstKind::Alloca(IrType::Array(Box::new(IrType::Int(IntWidth::I32)), 16)),
        });
        // Integer constants: 0 (IV init), 1 (step / store value),
        // 2 (second store value), 4 (trip-count bound).
        let c0 = f.next_value_id();
        f.register_type(c0, IrType::Int(IntWidth::I32));
        f.block_mut(entry).insts.push(Inst {
            id: c0,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::ConstInt(0, IntWidth::I32),
        });
        let c1 = f.next_value_id();
        f.register_type(c1, IrType::Int(IntWidth::I32));
        f.block_mut(entry).insts.push(Inst {
            id: c1,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::ConstInt(1, IntWidth::I32),
        });
        let c2 = f.next_value_id();
        f.register_type(c2, IrType::Int(IntWidth::I32));
        f.block_mut(entry).insts.push(Inst {
            id: c2,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::ConstInt(2, IntWidth::I32),
        });
        let c4 = f.next_value_id();
        f.register_type(c4, IrType::Int(IntWidth::I32));
        f.block_mut(entry).insts.push(Inst {
            id: c4,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::ConstInt(4, IntWidth::I32),
        });
        f.block_mut(entry).terminator = Some(Terminator::Branch(preheader, vec![]));

        // Preheader seeds the IV with 0 — this branch arg is what the
        // pass reads as `init_val` for the bridge block.
        f.block_mut(preheader).terminator = Some(Terminator::Branch(header, vec![c0]));

        // Header's single block param is the induction variable.
        let iv = f.next_value_id();
        f.register_type(iv, IrType::Int(IntWidth::I32));
        f.block_mut(header).params.push(BlockParam {
            id: iv,
            ty: IrType::Int(IntWidth::I32),
        });
        f.block_mut(header).terminator = Some(Terminator::Branch(body, vec![]));

        // Body: a[iv] = 1; b[iv] = 2 — the two independent store
        // groups (4 insts total, meeting FISSION_MIN_BODY).
        let gep_a = f.next_value_id();
        f.register_type(gep_a, IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))));
        f.block_mut(body).insts.push(Inst {
            id: gep_a,
            ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
            span: span(),
            kind: InstKind::GetElementPtr(arr_a, vec![iv]),
        });
        let store_a = f.next_value_id();
        f.register_type(store_a, IrType::Void);
        f.block_mut(body).insts.push(Inst {
            id: store_a,
            ty: IrType::Void,
            span: span(),
            kind: InstKind::Store(c1, gep_a),
        });

        let gep_b = f.next_value_id();
        f.register_type(gep_b, IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))));
        f.block_mut(body).insts.push(Inst {
            id: gep_b,
            ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
            span: span(),
            kind: InstKind::GetElementPtr(arr_b, vec![iv]),
        });
        let store_b = f.next_value_id();
        f.register_type(store_b, IrType::Void);
        f.block_mut(body).insts.push(Inst {
            id: store_b,
            ty: IrType::Void,
            span: span(),
            kind: InstKind::Store(c2, gep_b),
        });
        f.block_mut(body).terminator = Some(Terminator::Branch(latch, vec![]));

        // Latch: iv' = iv + 1; loop while iv' <= 4.
        let nxt = f.next_value_id();
        f.register_type(nxt, IrType::Int(IntWidth::I32));
        f.block_mut(latch).insts.push(Inst {
            id: nxt,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::IAdd(iv, c1),
        });
        let cmp = f.next_value_id();
        f.register_type(cmp, IrType::Bool);
        f.block_mut(latch).insts.push(Inst {
            id: cmp,
            ty: IrType::Bool,
            span: span(),
            kind: InstKind::ICmp(CmpOp::Le, nxt, c4),
        });
        // The false (exit) edge carries `nxt` — the arg the pass must
        // clear when retargeting this edge to the bridge.
        f.block_mut(latch).terminator = Some(Terminator::CondBranch {
            cond: cmp,
            true_dest: header,
            true_args: vec![nxt],
            false_dest: exit,
            false_args: vec![nxt],
        });

        // Exit consumes its block param so the param is live.
        let exit_param = f.next_value_id();
        f.register_type(exit_param, IrType::Int(IntWidth::I32));
        f.block_mut(exit).params.push(BlockParam {
            id: exit_param,
            ty: IrType::Int(IntWidth::I32),
        });
        let _use_exit = f.next_value_id();
        f.register_type(_use_exit, IrType::Int(IntWidth::I32));
        f.block_mut(exit).insts.push(Inst {
            id: _use_exit,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::IAdd(exit_param, c1),
        });
        f.block_mut(exit).terminator = Some(Terminator::Return(None));

        m.add_function(f);
        assert!(verify_module(&m).is_empty(), "test setup must start valid");

        let pass = LoopFission;
        let changed = pass.run(&mut m);
        assert!(changed, "the loop should fission");
        assert!(
            verify_module(&m).is_empty(),
            "fission should keep bridge exit edges verifier-clean"
        );
    }
}
| 438 |