Rust · 14615 bytes Raw Blame History
1 //! Loop fission (distribution) pass.
2 //!
3 //! Splits a loop with independent statement groups into two loops.
//! Uses an LLVM-inspired "clone the loop, then drop the other group's store" pattern:
//! 1. Clone the entire loop
//! 2. In the original: remove group B's store
//! 3. In the clone: remove group A's store
//! 4. Wire the original loop's exit → a bridge block that enters the clone's header
//! 5. Let DCE clean up the now-dead computations in later passes
10 //!
11 //! This avoids SSA domination bugs from selective instruction removal.
12
13 use super::dep_analysis::{collect_mem_refs, test_dependence};
14 use super::loop_utils::{clone_loop, find_preheader};
15 use super::pass::Pass;
16 use crate::ir::inst::*;
17 use crate::ir::walk::{find_natural_loops, predecessors};
18 use std::collections::HashSet;
19
/// Minimum number of instructions the store-carrying body block must contain
/// before fission is attempted; smaller bodies are skipped.
const FISSION_MIN_BODY: usize = 4;
21
22 pub struct LoopFission;
23
24 impl Pass for LoopFission {
25 fn name(&self) -> &'static str {
26 "loop-fission"
27 }
28
29 fn run(&self, module: &mut Module) -> bool {
30 let mut changed = false;
31 for func in &mut module.functions {
32 if fission_in_function(func) {
33 changed = true;
34 }
35 }
36 changed
37 }
38 }
39
40 fn fission_in_function(func: &mut Function) -> bool {
41 let loops = find_natural_loops(func);
42 let preds = predecessors(func);
43
44 for lp in &loops {
45 let Some(_ph_id) = find_preheader(func, lp, &preds) else {
46 continue;
47 };
48 if lp.latches.len() != 1 {
49 continue;
50 }
51 let latch_id = lp.latches[0];
52
53 let hdr = func.block(lp.header);
54 if hdr.params.len() != 1 {
55 continue;
56 }
57 let iv = hdr.params[0].id;
58
59 // Find the single computation body block.
60 let body_block = find_computation_block(func, lp, latch_id);
61 let Some(body_bid) = body_block else { continue };
62
63 let block = func.block(body_bid);
64 if block.insts.len() < FISSION_MIN_BODY {
65 continue;
66 }
67
68 // Need exactly 2 stores to different arrays.
69 let stores: Vec<(usize, ValueId)> = block
70 .insts
71 .iter()
72 .enumerate()
73 .filter(|(_, i)| matches!(i.kind, InstKind::Store(..)))
74 .map(|(idx, i)| (idx, i.id))
75 .collect();
76 if stores.len() != 2 {
77 continue;
78 }
79
80 // Check independence via dep analysis.
81 let mut ivs = HashSet::new();
82 ivs.insert(iv);
83 let mem_refs = collect_mem_refs(func, &lp.body, &ivs);
84 let writes: Vec<_> = mem_refs.iter().filter(|r| r.is_write).collect();
85 if writes.len() != 2 {
86 continue;
87 }
88 if writes[0].base == writes[1].base {
89 continue;
90 }
91
92 let mut has_cross_dep = false;
93 for i in 0..mem_refs.len() {
94 for j in (i + 1)..mem_refs.len() {
95 if !mem_refs[i].is_write && !mem_refs[j].is_write {
96 continue;
97 }
98 if mem_refs[i].base == mem_refs[j].base {
99 continue;
100 }
101 let dep = test_dependence(&mem_refs[i], &mem_refs[j]);
102 if dep.dependent {
103 has_cross_dep = true;
104 break;
105 }
106 }
107 if has_cross_dep {
108 break;
109 }
110 }
111 if has_cross_dep {
112 continue;
113 }
114
115 // Find the exit block.
116 let exit_id = find_loop_exit(func, lp);
117 let Some(exit_id) = exit_id else { continue };
118
119 // Compute backward slices for each store.
120 // Clone the loop.
121 let (block_map, _) = clone_loop(func, lp);
122
123 // In original body: neutralize store B (replace value with undef).
124 neutralize_store(func, body_bid, stores[1].0);
125
126 // In clone body: neutralize store A.
127 let clone_body = block_map[&body_bid];
128 neutralize_store(func, clone_body, stores[0].0);
129
130 // Wire original exit → clone's header via a bridge block.
131 let clone_header = block_map[&lp.header];
132 let ph_id = find_preheader(func, lp, &predecessors(func)).unwrap();
133 let init_val = match &func.block(ph_id).terminator {
134 Some(Terminator::Branch(_, args)) if !args.is_empty() => args[0],
135 _ => continue,
136 };
137
138 let bridge = func.create_block("fission_bridge");
139 func.block_mut(bridge).terminator = Some(Terminator::Branch(clone_header, vec![init_val]));
140
141 // Redirect original cmp's exit → bridge.
142 for &bid in &lp.body {
143 let block = func.block_mut(bid);
144 if let Some(Terminator::CondBranch {
145 false_dest,
146 false_args,
147 ..
148 }) = &mut block.terminator
149 {
150 if *false_dest == exit_id {
151 *false_dest = bridge;
152 false_args.clear();
153 break;
154 }
155 }
156 }
157
158 // Clone's cmp exit already points to exit_id (wasn't remapped
159 // since it's outside the body). Correct.
160
161 return true;
162 }
163 false
164 }
165
166 /// Replace a store instruction's value operand with Undef, effectively
167 /// making it a dead store that DSE/DCE can clean up.
168 fn neutralize_store(func: &mut Function, block_id: BlockId, store_idx: usize) {
169 let block = func.block_mut(block_id);
170 if store_idx >= block.insts.len() {
171 return;
172 }
173 if let InstKind::Store(_, _) = block.insts[store_idx].kind {
174 // Replace with a store of undef — the store is now dead.
175 // We keep the store instruction (not remove it) to preserve
176 // SSA structure. DCE will clean it up later.
177 let undef_id = ValueId(u32::MAX - 1); // sentinel; will be replaced below
178 let _ = undef_id;
179 }
180 // Actually, the simplest approach: just remove the store entirely.
181 // Stores have no result value used by other instructions, so removing
182 // them can't break SSA. The computation leading to the store value
183 // becomes dead and DCE removes it.
184 block.insts.remove(store_idx);
185 }
186
187 fn find_computation_block(
188 func: &Function,
189 lp: &crate::ir::walk::NaturalLoop,
190 latch_id: BlockId,
191 ) -> Option<BlockId> {
192 let mut comp = None;
193 for &bid in &lp.body {
194 if bid == lp.header || bid == latch_id {
195 continue;
196 }
197 let block = func.block(bid);
198 if block
199 .insts
200 .iter()
201 .any(|i| matches!(i.kind, InstKind::Store(..)))
202 {
203 if comp.is_some() {
204 return None;
205 }
206 comp = Some(bid);
207 }
208 }
209 comp
210 }
211
212 fn backward_slice(
213 func: &Function,
214 root: ValueId,
215 loop_defs: &HashSet<ValueId>,
216 ) -> HashSet<ValueId> {
217 let mut slice = HashSet::new();
218 let mut worklist = vec![root];
219 while let Some(vid) = worklist.pop() {
220 if !slice.insert(vid) {
221 continue;
222 }
223 for block in &func.blocks {
224 for inst in &block.insts {
225 if inst.id == vid {
226 for operand in crate::ir::walk::inst_uses(&inst.kind) {
227 if loop_defs.contains(&operand) && !slice.contains(&operand) {
228 worklist.push(operand);
229 }
230 }
231 }
232 }
233 }
234 }
235 slice
236 }
237
238 fn find_loop_exit(func: &Function, lp: &crate::ir::walk::NaturalLoop) -> Option<BlockId> {
239 for &bid in &lp.body {
240 let block = func.block(bid);
241 if let Some(Terminator::CondBranch { false_dest, .. }) = &block.terminator {
242 if !lp.body.contains(false_dest) {
243 return Some(*false_dest);
244 }
245 }
246 }
247 None
248 }
249
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::types::{IntWidth, IrType};
    use crate::ir::verify::verify_module;
    use crate::lexer::{Position, Span};
    use crate::opt::pass::Pass;

    /// Placeholder source span for hand-built instructions.
    fn span() -> Span {
        let origin = Position { line: 0, col: 0 };
        Span {
            file_id: 0,
            start: origin,
            end: origin,
        }
    }

    /// The 32-bit integer type used throughout the fixtures.
    fn i32_ty() -> IrType {
        IrType::Int(IntWidth::I32)
    }

    /// Allocates a value id, registers its type and appends the instruction
    /// to `block`; returns the new id.
    fn emit(f: &mut Function, block: BlockId, ty: IrType, kind: InstKind) -> ValueId {
        let id = f.next_value_id();
        f.register_type(id, ty.clone());
        f.block_mut(block).insts.push(Inst {
            id,
            ty,
            span: span(),
            kind,
        });
        id
    }

    /// Allocates a value id, registers its type and appends it as a block
    /// parameter of `block`; returns the new id.
    fn add_param(f: &mut Function, block: BlockId, ty: IrType) -> ValueId {
        let id = f.next_value_id();
        f.register_type(id, ty.clone());
        f.block_mut(block).params.push(BlockParam { id, ty });
        id
    }

    #[test]
    fn fission_no_op_on_empty() {
        let mut module = Module::new("test".into());
        let mut func = Function::new("test".into(), vec![], IrType::Void);
        func.block_mut(func.entry).terminator = Some(Terminator::Return(None));
        module.add_function(func);

        let changed = LoopFission.run(&mut module);
        assert!(!changed, "no loops → no fission");
    }

    #[test]
    fn fission_clears_exit_args_when_rerouting_to_bridge() {
        let mut m = Module::new("test".into());
        let mut f = Function::new("test".into(), vec![], IrType::Void);

        let preheader = f.create_block("preheader");
        let header = f.create_block("header");
        let body = f.create_block("body");
        let latch = f.create_block("latch");
        let exit = f.create_block("exit");
        let entry = f.entry;

        // entry: two i32[16] allocas and the integer constants 0, 1, 2, 4.
        let arr_ty = IrType::Array(Box::new(i32_ty()), 16);
        let arr_ptr_ty = IrType::Ptr(Box::new(arr_ty.clone()));
        let arr_a = emit(
            &mut f,
            entry,
            arr_ptr_ty.clone(),
            InstKind::Alloca(arr_ty.clone()),
        );
        let arr_b = emit(&mut f, entry, arr_ptr_ty, InstKind::Alloca(arr_ty));
        let c0 = emit(
            &mut f,
            entry,
            i32_ty(),
            InstKind::ConstInt(0, IntWidth::I32),
        );
        let c1 = emit(
            &mut f,
            entry,
            i32_ty(),
            InstKind::ConstInt(1, IntWidth::I32),
        );
        let c2 = emit(
            &mut f,
            entry,
            i32_ty(),
            InstKind::ConstInt(2, IntWidth::I32),
        );
        let c4 = emit(
            &mut f,
            entry,
            i32_ty(),
            InstKind::ConstInt(4, IntWidth::I32),
        );
        f.block_mut(entry).terminator = Some(Terminator::Branch(preheader, vec![]));

        // preheader → header(iv = 0) → body
        f.block_mut(preheader).terminator = Some(Terminator::Branch(header, vec![c0]));
        let iv = add_param(&mut f, header, i32_ty());
        f.block_mut(header).terminator = Some(Terminator::Branch(body, vec![]));

        // body: a[iv] = 1; b[iv] = 2
        let elem_ptr_ty = IrType::Ptr(Box::new(i32_ty()));
        let gep_a = emit(
            &mut f,
            body,
            elem_ptr_ty.clone(),
            InstKind::GetElementPtr(arr_a, vec![iv]),
        );
        emit(&mut f, body, IrType::Void, InstKind::Store(c1, gep_a));
        let gep_b = emit(
            &mut f,
            body,
            elem_ptr_ty,
            InstKind::GetElementPtr(arr_b, vec![iv]),
        );
        emit(&mut f, body, IrType::Void, InstKind::Store(c2, gep_b));
        f.block_mut(body).terminator = Some(Terminator::Branch(latch, vec![]));

        // latch: nxt = iv + 1; loop back while nxt <= 4, otherwise exit(nxt)
        let nxt = emit(&mut f, latch, i32_ty(), InstKind::IAdd(iv, c1));
        let cmp = emit(
            &mut f,
            latch,
            IrType::Bool,
            InstKind::ICmp(CmpOp::Le, nxt, c4),
        );
        f.block_mut(latch).terminator = Some(Terminator::CondBranch {
            cond: cmp,
            true_dest: header,
            true_args: vec![nxt],
            false_dest: exit,
            false_args: vec![nxt],
        });

        // exit takes one parameter and uses it, so the exit edge carries args.
        let exit_param = add_param(&mut f, exit, i32_ty());
        emit(&mut f, exit, i32_ty(), InstKind::IAdd(exit_param, c1));
        f.block_mut(exit).terminator = Some(Terminator::Return(None));

        m.add_function(f);
        assert!(verify_module(&m).is_empty(), "test setup must start valid");

        let changed = LoopFission.run(&mut m);
        assert!(changed, "the loop should fission");
        assert!(
            verify_module(&m).is_empty(),
            "fission should keep bridge exit edges verifier-clean"
        );
    }
}
438