1 //! Loop fusion pass.
2 //!
3 //! Merges two adjacent loops with identical iteration spaces using
4 //! the LLVM-inspired "latch redirect" pattern:
5 //! 1. Redirect A's latch to B's body (instead of A's header)
6 //! 2. Redirect B's latch back to A's header (instead of B's header)
7 //! 3. Remove B's header/cmp blocks (now unreachable)
8 //! 4. Remap B's IV to A's IV
9 //!
10 //! This avoids splicing instructions between blocks entirely.
11
12 use super::dep_analysis;
13 use super::loop_utils::remap_inst_kind;
14 use super::loop_utils::{find_preheader, resolve_const_int};
15 use super::pass::Pass;
16 use crate::ir::inst::*;
17 use crate::ir::walk::{find_natural_loops, predecessors};
18 use std::collections::HashSet;
19
/// Optimization pass that fuses adjacent loops with identical
/// iteration spaces (same constant init and bound) into one loop,
/// using the latch-redirect scheme described in the module docs.
pub struct LoopFusion;
21
22 impl Pass for LoopFusion {
23 fn name(&self) -> &'static str {
24 "loop-fusion"
25 }
26
27 fn run(&self, module: &mut Module) -> bool {
28 let mut changed = false;
29 for func in &mut module.functions {
30 if fusion_in_function(func) {
31 changed = true;
32 }
33 }
34 changed
35 }
36 }
37
/// Scan `func` for one fusable pair of loops and fuse the first legal
/// candidate found.
///
/// A pair (A, B) is considered fusable when:
/// * both loops have a preheader and exactly one latch,
/// * both headers carry exactly one block parameter (the IV),
/// * neither loop is nested inside the other, and neither contains an
///   inner loop,
/// * neither loop carries fission-generated clone/bridge blocks,
/// * A's exit block is (or unconditionally flows to) B's preheader,
/// * both IVs start at the same constant and are compared against the
///   same constant bound, and
/// * dependence analysis (`dep_analysis::fusion_legal`) allows it.
///
/// At most one fusion is performed per call (`return true` on the
/// first success); rerunning the pass picks up further pairs.
fn fusion_in_function(func: &mut Function) -> bool {
    let loops = find_natural_loops(func);
    if loops.len() < 2 {
        return false;
    }
    let preds = predecessors(func);

    // Consider every unordered pair of distinct loops. NOTE(review):
    // only the (i, j) orientation with i < j is tried; the CFG check
    // below (`exit_a` reaching `ph_b`) filters out wrongly-ordered
    // pairs rather than retrying them swapped.
    for i in 0..loops.len() {
        for j in (i + 1)..loops.len() {
            let lp_a = &loops[i];
            let lp_b = &loops[j];

            // Both need preheaders and single latches.
            let Some(_ph_a) = find_preheader(func, lp_a, &preds) else {
                continue;
            };
            let Some(ph_b) = find_preheader(func, lp_b, &preds) else {
                continue;
            };
            if lp_a.latches.len() != 1 || lp_b.latches.len() != 1 {
                continue;
            }
            let latch_a = lp_a.latches[0];
            let latch_b = lp_b.latches[0];

            // Both headers must have exactly 1 block param (IV).
            let hdr_a = func.block(lp_a.header);
            let hdr_b = func.block(lp_b.header);
            if hdr_a.params.len() != 1 || hdr_b.params.len() != 1 {
                continue;
            }
            let iv_a = hdr_a.params[0].id;
            let iv_b = hdr_b.params[0].id;

            // Neither loop should be nested inside the other.
            if lp_a.body.iter().any(|b| lp_b.body.contains(b)) {
                continue;
            }
            if lp_b.body.iter().any(|b| lp_a.body.contains(b)) {
                continue;
            }

            // Neither loop should contain inner loops. Fusing loops
            // with nested inner loops requires more complex handling
            // (inner loop blocks need reparenting). V1: simple only.
            let has_inner_a = loops
                .iter()
                .any(|other| other.header != lp_a.header && lp_a.body.is_superset(&other.body));
            let has_inner_b = loops
                .iter()
                .any(|other| other.header != lp_b.header && lp_b.body.is_superset(&other.body));
            if has_inner_a || has_inner_b {
                continue;
            }

            // Skip loops produced by fission in the same pipeline
            // iteration — they have clone/bridge blocks that fusion
            // can't handle safely. Detection is by block-name
            // convention ("clone"/"fission" substrings).
            let has_clone_a = lp_a.body.iter().any(|&b| {
                func.block(b).name.contains("clone") || func.block(b).name.contains("fission")
            });
            let has_clone_b = lp_b.body.iter().any(|&b| {
                func.block(b).name.contains("clone") || func.block(b).name.contains("fission")
            });
            if has_clone_a || has_clone_b {
                continue;
            }

            // Loop A's exit must be (or flow to) loop B's preheader.
            let exit_a = find_loop_exit(func, lp_a);
            let Some(exit_a) = exit_a else { continue };
            if exit_a != ph_b && !flows_to(func, exit_a, ph_b) {
                continue;
            }

            // Matching iteration spaces: same constant initial value...
            let Some(init_a) = get_init_const(func, lp_a, &preds) else {
                continue;
            };
            let Some(init_b) = get_init_const(func, lp_b, &preds) else {
                continue;
            };
            if init_a != init_b {
                continue;
            }

            // ...and same constant bound in the compare.
            let Some(bound_a) = find_bound_const(func, lp_a, iv_a) else {
                continue;
            };
            let Some(bound_b) = find_bound_const(func, lp_b, iv_b) else {
                continue;
            };
            if bound_a != bound_b {
                continue;
            }

            // Dep analysis: fusion legal?
            if !dep_analysis::fusion_legal(func, &lp_a.body, &lp_b.body, iv_a, iv_b) {
                continue;
            }

            let body_a = find_body_block(func, lp_a, latch_a);
            let Some(body_a_id) = body_a else { continue };

            // Find B's body block (the one with stores/computation).
            let body_b = find_body_block(func, lp_b, latch_b);
            let Some(body_b_id) = body_b else { continue };

            let cmp_a = find_cmp_block(func, lp_a);
            let Some(cmp_a_id) = cmp_a else { continue };

            // Find B's cmp block (the one with ICmp).
            let cmp_b = find_cmp_block(func, lp_b);
            let Some(cmp_b_id) = cmp_b else { continue };

            // Restrict to the canonical 4-block loop shape
            // (header / cmp / body / latch) on both sides.
            if !has_simple_fusion_shape(func, lp_a, latch_a, body_a_id, cmp_a_id) {
                continue;
            }
            if !has_simple_fusion_shape(func, lp_b, latch_b, body_b_id, cmp_b_id) {
                continue;
            }

            // Find B's exit block.
            let exit_b = find_loop_exit(func, lp_b);
            let Some(exit_b) = exit_b else { continue };

            // ---- Perform fusion via latch redirect ----
            do_fusion_latch_redirect(
                func, lp_a, lp_b, latch_a, latch_b, iv_a, iv_b, body_b_id, cmp_b_id, exit_a, exit_b,
            );
            return true;
        }
    }
    false
}
173
/// Fuse by redirecting latches:
///   A's latch → B's body (skipping B's header/cmp)
///   B's latch → A's header
///   Remap iv_b → iv_a throughout B's body
///
/// Preconditions (established by the caller): both loops have the
/// simple header/cmp/body/latch shape, matching constant iteration
/// spaces, and dependence analysis has approved the fusion.
///
/// The resulting per-iteration flow is:
///   A header/cmp → A body → B body → B latch → A latch (IV increment)
///   → A header, with A's exit rewritten to B's original exit.
fn do_fusion_latch_redirect(
    func: &mut Function,
    lp_a: &crate::ir::walk::NaturalLoop,
    lp_b: &crate::ir::walk::NaturalLoop,
    latch_a: BlockId,
    latch_b: BlockId,
    iv_a: ValueId,
    iv_b: ValueId,
    body_b_id: BlockId,
    _cmp_b_id: BlockId,
    exit_a: BlockId,
    exit_b: BlockId,
) {
    // Step 1: Remap iv_b → iv_a throughout ALL of B's body blocks.
    // We rebuild each instruction with the IV substitution.
    let mut sub_map = std::collections::HashMap::new();
    sub_map.insert(iv_b, iv_a);
    for &bid in &lp_b.body {
        let old_insts: Vec<Inst> = func.block(bid).insts.clone();
        let new_insts: Vec<Inst> = old_insts
            .into_iter()
            .map(|inst| {
                let new_kind = remap_inst_kind(&inst.kind, &sub_map);
                Inst {
                    kind: new_kind,
                    ..inst
                }
            })
            .collect();
        func.block_mut(bid).insts = new_insts;
        // Also remap terminator operands.
        let old_term = func.block(bid).terminator.clone();
        if let Some(term) = old_term {
            let new_term = super::loop_utils::remap_terminator(
                &term,
                &std::collections::HashMap::new(), // no block remapping
                &sub_map,
            );
            func.block_mut(bid).terminator = Some(new_term);
        }
    }

    // Step 1.5: Clone the gap block's (exit_a's) instructions into B's
    // body block. These are constants (like `%24 = const 1`) that B's
    // body needs but that will be destroyed when we replace exit_a's
    // content with B's exit block in step 3. Prepend them to B's body
    // so they dominate all uses within B's body.
    let gap_insts: Vec<Inst> = func.block(exit_a).insts.clone();
    let existing_b_insts: Vec<Inst> = func.block(body_b_id).insts.clone();
    func.block_mut(body_b_id).insts = gap_insts;
    func.block_mut(body_b_id).insts.extend(existing_b_insts);

    // Step 2: A's latch currently branches to A's header. Redirect it
    // to B's body block instead. This makes the fused loop run A's body
    // then B's body on each iteration.
    //
    // But we need to preserve the IV increment: A's latch computes
    // iv_next and passes it to A's header. After fusion, B's body
    // should see iv_a (current iteration), and B's latch should pass
    // iv_next back to A's header.
    //
    // The cleanest approach: redirect A's body exit (the block that
    // branches to A's latch) to branch to B's body instead. Then B's
    // body branches to B's latch, and B's latch branches to A's latch,
    // and A's latch branches to A's header with iv_next.
    //
    // Actually simpler: redirect A's latch to branch through B's body
    // before going back to A's header.
    //
    // The issue is that A's latch has the IV increment instruction.
    // We want: A's body → B's body → A's latch (increment) → A's header.
    //
    // So redirect the block that branches to A's latch to go to B's body
    // instead. Then B's latch branches to A's latch (not B's header).

    // Find the block that branches to latch_a (A's body exit).
    let a_body_exit = find_branch_to(func, lp_a, latch_a);
    let Some(a_body_exit_id) = a_body_exit else {
        // Shape check should have guaranteed this exists; bail out
        // rather than leave the CFG half-rewritten. NOTE(review): the
        // IV remap in step 1 has already happened at this point —
        // confirm that an early return here is actually unreachable.
        return;
    };

    // Redirect A's body exit → B's body.
    redirect_branch(func, a_body_exit_id, latch_a, body_b_id);

    // Redirect B's latch → A's latch (not B's header).
    // B's latch currently: br B's_header(iv_next_b). Change to: br A's_latch.
    // But B's latch's iv_next_b should be discarded — A's latch will
    // compute iv_next_a. So B's latch should just branch to A's latch
    // with no args.
    func.block_mut(latch_b).terminator = Some(Terminator::Branch(latch_a, vec![]));

    // Step 3: A's cmp exit (exit_a) should now go to B's exit (exit_b)
    // since B's header/cmp are bypassed. Copy B's exit block content
    // into A's exit block.
    let exit_b_insts: Vec<Inst> = func.block(exit_b).insts.clone();
    let exit_b_term = func.block(exit_b).terminator.clone();
    func.block_mut(exit_a).insts = exit_b_insts;
    func.block_mut(exit_a).terminator = exit_b_term;

    // Step 4: Mark B's header, cmp, and preheader blocks as unreachable.
    func.block_mut(lp_b.header).terminator = Some(Terminator::Unreachable);
    // Mark B's cmp blocks too. B's body and latch stay live — they are
    // now part of the fused loop.
    for &bid in &lp_b.body {
        if bid == body_b_id || bid == latch_b {
            continue;
        }
        func.block_mut(bid).terminator = Some(Terminator::Unreachable);
    }

    // Drop everything that is no longer reachable (B's old header,
    // cmp, preheader, and the now-duplicated exit_b).
    crate::ir::walk::prune_unreachable(func);
}
289
290 fn find_branch_to(
291 func: &Function,
292 lp: &crate::ir::walk::NaturalLoop,
293 target: BlockId,
294 ) -> Option<BlockId> {
295 for &bid in &lp.body {
296 if bid == target {
297 continue;
298 }
299 let block = func.block(bid);
300 match &block.terminator {
301 Some(Terminator::Branch(dest, _)) if *dest == target => return Some(bid),
302 _ => {}
303 }
304 }
305 None
306 }
307
308 fn redirect_branch(func: &mut Function, from: BlockId, old_target: BlockId, new_target: BlockId) {
309 let block = func.block_mut(from);
310 if let Some(Terminator::Branch(dest, args)) = &mut block.terminator {
311 if *dest == old_target {
312 *dest = new_target;
313 args.clear(); // B's body doesn't take args from A's body
314 }
315 }
316 }
317
318 fn find_loop_exit(func: &Function, lp: &crate::ir::walk::NaturalLoop) -> Option<BlockId> {
319 for &bid in &lp.body {
320 let block = func.block(bid);
321 if let Some(Terminator::CondBranch { false_dest, .. }) = &block.terminator {
322 if !lp.body.contains(false_dest) {
323 return Some(*false_dest);
324 }
325 }
326 }
327 None
328 }
329
330 fn flows_to(func: &Function, from: BlockId, to: BlockId) -> bool {
331 let block = func.block(from);
332 match &block.terminator {
333 Some(Terminator::Branch(dest, _)) => {
334 if *dest == to {
335 return true;
336 }
337 let mid = func.block(*dest);
338 if let Some(Terminator::Branch(dest2, _)) = &mid.terminator {
339 return *dest2 == to;
340 }
341 false
342 }
343 _ => false,
344 }
345 }
346
347 fn get_init_const(
348 func: &Function,
349 lp: &crate::ir::walk::NaturalLoop,
350 preds: &std::collections::HashMap<BlockId, Vec<BlockId>>,
351 ) -> Option<i64> {
352 let ph = find_preheader(func, lp, preds)?;
353 let init_val = match &func.block(ph).terminator {
354 Some(Terminator::Branch(_, args)) if !args.is_empty() => args[0],
355 _ => return None,
356 };
357 resolve_const_int(func, init_val)
358 }
359
360 fn find_bound_const(
361 func: &Function,
362 lp: &crate::ir::walk::NaturalLoop,
363 iv: ValueId,
364 ) -> Option<i64> {
365 for &bid in &lp.body {
366 let block = func.block(bid);
367 for inst in &block.insts {
368 if let InstKind::ICmp(_, a, b) = &inst.kind {
369 let bound_val = if *a == iv {
370 *b
371 } else if *b == iv {
372 *a
373 } else {
374 continue;
375 };
376 return resolve_const_int(func, bound_val);
377 }
378 }
379 }
380 None
381 }
382
383 fn find_body_block(
384 func: &Function,
385 lp: &crate::ir::walk::NaturalLoop,
386 latch_id: BlockId,
387 ) -> Option<BlockId> {
388 for &bid in &lp.body {
389 if bid == lp.header || bid == latch_id {
390 continue;
391 }
392 let block = func.block(bid);
393 if block
394 .insts
395 .iter()
396 .any(|i| matches!(i.kind, InstKind::Store(..)))
397 {
398 return Some(bid);
399 }
400 }
401 None
402 }
403
404 fn find_cmp_block(func: &Function, lp: &crate::ir::walk::NaturalLoop) -> Option<BlockId> {
405 for &bid in &lp.body {
406 let block = func.block(bid);
407 if block
408 .insts
409 .iter()
410 .any(|i| matches!(i.kind, InstKind::ICmp(..)))
411 {
412 return Some(bid);
413 }
414 }
415 None
416 }
417
418 fn has_simple_fusion_shape(
419 func: &Function,
420 lp: &crate::ir::walk::NaturalLoop,
421 latch_id: BlockId,
422 body_id: BlockId,
423 cmp_id: BlockId,
424 ) -> bool {
425 let allowed: HashSet<BlockId> = [lp.header, cmp_id, body_id, latch_id].into_iter().collect();
426 if lp.body.iter().any(|bid| !allowed.contains(bid)) {
427 return false;
428 }
429
430 let body = func.block(body_id);
431 if !body.params.is_empty() {
432 return false;
433 }
434 matches!(&body.terminator, Some(Terminator::Branch(dest, args)) if *dest == latch_id && args.is_empty())
435 }
436
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::types::IrType;
    use crate::opt::pass::Pass;

    // A function with no loops must leave the pass reporting "no
    // change" — guards the `loops.len() < 2` early-out.
    #[test]
    fn fusion_no_op_on_empty() {
        let mut m = Module::new("test".into());
        let mut f = Function::new("test".into(), vec![], IrType::Void);
        f.block_mut(f.entry).terminator = Some(Terminator::Return(None));
        m.add_function(f);
        let pass = LoopFusion;
        let changed = pass.run(&mut m);
        assert!(!changed);
    }
}
454