//! Loop interchange pass.
//!
//! Swaps the iteration order of perfectly-nested loop pairs to improve
//! memory access patterns. This is critical for Fortran because arrays are
//! column-major: `a(i, j)` is stored with `i` varying fastest. If the
//! inner loop iterates over `j` while `i` is outer, array accesses
//! stride by the column extent on each iteration, which is cache-hostile.
//! Interchanging makes `i` the inner loop, giving stride-1 access.
//!
//! ## Algorithm
//!
//! 1. Build the loop tree and find perfectly-nested pairs.
//! 2. For each pair, detect the counted-loop structure (header, IV, bounds).
//! 3. Analyze the body's GEP instructions to determine which IV is used
//!    as the "fast" (first/leftmost) subscript.
//! 4. If the OUTER IV appears as the fast subscript, interchange is
//!    profitable (and almost always legal for simple array assignments).
//! 5. Transform by swapping the branch arguments that carry init values
//!    to each header, and swapping the bounds used in each comparison.
//!
//! ## Legality
//!
//! The legality decision is delegated to `dep_analysis::interchange_legal`.
//! The check is intended to be conservative: only interchange when all
//! array accesses in the body use both IVs as simple direct subscripts
//! with no cross-iteration read-before-write, so that a fully general
//! dependence analysis is not needed.
26
27 use super::loop_tree::build_loop_tree;
28 use super::pass::Pass;
29 use crate::ir::inst::*;
30 use crate::ir::walk::predecessors;
31
/// Optimization pass that interchanges perfectly-nested loop pairs.
///
/// A stateless unit struct; all of the work happens in its `Pass`
/// implementation.
pub struct LoopInterchange;
33
34 impl Pass for LoopInterchange {
35 fn name(&self) -> &'static str {
36 "loop-interchange"
37 }
38
39 fn run(&self, module: &mut Module) -> bool {
40 let mut changed = false;
41 for func in &mut module.functions {
42 if interchange_in_function(func) {
43 changed = true;
44 }
45 }
46 changed
47 }
48 }
49
50 fn interchange_in_function(func: &mut Function) -> bool {
51 let tree = build_loop_tree(func);
52 let preds = predecessors(func);
53 let pairs = tree.perfectly_nested_pairs(func);
54
55 for (outer_id, inner_id) in &pairs {
56 let outer = tree.node(*outer_id);
57 let inner = tree.node(*inner_id);
58
59 // Both loops must have a recognized counted-loop structure:
60 // header(%iv) → cmp_block(icmp, condBr) → body → latch(iadd, br header)
61 let Some(outer_shape) = detect_loop_shape(func, outer, &preds) else {
62 continue;
63 };
64 let Some(inner_shape) = detect_loop_shape(func, inner, &preds) else {
65 continue;
66 };
67
68 // Check profitability: is the outer IV used as the first (fast)
69 // subscript of a multi-dimensional array GEP?
70 if !should_interchange(func, inner, outer_shape.iv, inner_shape.iv) {
71 continue;
72 }
73
74 // Check legality: conservative — only if the body has no
75 // loop-carried dependencies that would change semantics.
76 if !is_interchange_legal(func, inner, outer_shape.iv, inner_shape.iv) {
77 continue;
78 }
79
80 // Perform the interchange by swapping the loop bounds and inits.
81 do_interchange(func, &outer_shape, &inner_shape);
82 return true; // one at a time
83 }
84 false
85 }
86
/// Minimal loop shape for interchange.
///
/// Built by `detect_loop_shape` and consumed by `do_interchange`.
struct LoopShape {
    /// Loop header block; it carries the IV as its only block parameter.
    header: BlockId,
    /// Block the header branches to; it contains the `ICmp` that compares
    /// `iv` against `bound`.
    cmp_block: BlockId,
    iv: ValueId,    // block param on header
    bound: ValueId, // the non-IV operand of the comparison (treated as the upper bound)
    /// The loop's single latch block.
    latch: BlockId,
    /// Position of the initial IV value in the argument list of the branch
    /// that enters the header. Currently always 0, because the header has
    /// exactly one block parameter.
    init_arg_idx: usize, // index in preheader's branch args
}
97
98 fn detect_loop_shape(
99 func: &Function,
100 node: &super::loop_tree::LoopTreeNode,
101 _preds: &std::collections::HashMap<BlockId, Vec<BlockId>>,
102 ) -> Option<LoopShape> {
103 let header = node.header;
104 let hdr = func.block(header);
105
106 // Header must have exactly 1 block param (the IV).
107 if hdr.params.len() != 1 {
108 return None;
109 }
110 let iv = hdr.params[0].id;
111
112 // Header must be a relay (0 instructions, branch to cmp_block).
113 if !hdr.insts.is_empty() {
114 return None;
115 }
116 let cmp_block = match &hdr.terminator {
117 Some(Terminator::Branch(t, args)) if args.is_empty() => *t,
118 _ => return None,
119 };
120 if !node.body.contains(&cmp_block) {
121 return None;
122 }
123
124 // Cmp block must have icmp + condBr.
125 let cmp_blk = func.block(cmp_block);
126 let bound = {
127 let mut found_bound = None;
128 for inst in &cmp_blk.insts {
129 if let InstKind::ICmp(_, a, b) = &inst.kind {
130 // One operand should be the IV, the other is the bound.
131 if *a == iv {
132 found_bound = Some(*b);
133 } else if *b == iv {
134 found_bound = Some(*a);
135 }
136 }
137 }
138 found_bound?
139 };
140
141 // Find the single latch.
142 if node.latches.len() != 1 {
143 return None;
144 }
145 let latch = node.latches[0];
146
147 Some(LoopShape {
148 header,
149 cmp_block,
150 iv,
151 bound,
152 latch,
153 init_arg_idx: 0, // always first param
154 })
155 }
156
157 /// Check if interchanging would improve memory access patterns.
158 ///
159 /// Returns true if the OUTER IV appears as the "fast-varying" (first)
160 /// subscript in a multi-dimensional array GEP. In column-major Fortran,
161 /// the first subscript should be the inner loop's IV for stride-1 access.
162 fn should_interchange(
163 func: &Function,
164 inner_loop: &super::loop_tree::LoopTreeNode,
165 outer_iv: ValueId,
166 inner_iv: ValueId,
167 ) -> bool {
168 // Scan the inner loop body for GEP instructions that use both IVs.
169 for &bid in &inner_loop.body {
170 let block = func.block(bid);
171 for inst in &block.insts {
172 if let InstKind::GetElementPtr(_, indices) = &inst.kind {
173 // We're looking for a flat-offset GEP where the offset
174 // is computed as: (outer_iv - lo) + (inner_iv - lo) * stride
175 // or equivalently: fast_part + slow_part * stride
176 //
177 // In column-major, the first addend (non-multiplied part)
178 // is the "fast" subscript. If the outer IV contributes to
179 // the non-multiplied addend, interchange is profitable.
180 if let Some(offset_val) = indices.first() {
181 if uses_iv_in_fast_position(func, *offset_val, outer_iv, inner_iv) {
182 return true;
183 }
184 }
185 }
186 }
187 }
188 false
189 }
190
191 /// Check if the flat offset computation has the outer IV in the
192 /// non-multiplied (fast) position. The lowered pattern is:
193 /// %fast = isub %outer_iv, %lo
194 /// %slow_raw = isub %inner_iv, %lo
195 /// %slow = imul %slow_raw, %stride
196 /// %offset = iadd %fast, %slow
197 ///
198 /// We trace back from the GEP index to find if outer_iv feeds the
199 /// non-multiplied side of the final iadd.
200 fn uses_iv_in_fast_position(
201 func: &Function,
202 offset: ValueId,
203 outer_iv: ValueId,
204 _inner_iv: ValueId,
205 ) -> bool {
206 // Find the instruction that produces `offset`.
207 let Some(inst) = find_inst(func, offset) else {
208 return false;
209 };
210
211 // The offset should be an iadd of two parts.
212 let (a, b) = match &inst.kind {
213 InstKind::IAdd(a, b) => (*a, *b),
214 _ => return false,
215 };
216
217 // One side should be the fast part (derived from outer_iv without
218 // multiplication), the other should be the slow part (involves imul).
219 let a_uses_mul = trace_involves_mul(func, a);
220 let b_uses_mul = trace_involves_mul(func, b);
221
222 // The fast part is the one WITHOUT multiplication.
223 let fast_part = if !a_uses_mul && b_uses_mul {
224 a
225 } else if a_uses_mul && !b_uses_mul {
226 b
227 } else {
228 return false;
229 };
230
231 // Does the fast part trace back to the outer IV?
232 traces_to_iv(func, fast_part, outer_iv)
233 }
234
235 /// Check if a value's computation involves an IMul somewhere.
236 fn trace_involves_mul(func: &Function, val: ValueId) -> bool {
237 let Some(inst) = find_inst(func, val) else {
238 return false;
239 };
240 match &inst.kind {
241 InstKind::IMul(..) => true,
242 InstKind::IAdd(a, b) | InstKind::ISub(a, b) => {
243 trace_involves_mul(func, *a) || trace_involves_mul(func, *b)
244 }
245 InstKind::IntExtend(a, _, _) => trace_involves_mul(func, *a),
246 _ => false,
247 }
248 }
249
250 /// Check if a value traces back to a specific IV (through isub, int_extend).
251 fn traces_to_iv(func: &Function, val: ValueId, iv: ValueId) -> bool {
252 if val == iv {
253 return true;
254 }
255 let Some(inst) = find_inst(func, val) else {
256 return false;
257 };
258 match &inst.kind {
259 InstKind::ISub(a, _) => traces_to_iv(func, *a, iv),
260 InstKind::IntExtend(a, _, _) => traces_to_iv(func, *a, iv),
261 _ => false,
262 }
263 }
264
265 /// Find the instruction that defines a value.
266 fn find_inst(func: &Function, vid: ValueId) -> Option<&Inst> {
267 for block in &func.blocks {
268 for inst in &block.insts {
269 if inst.id == vid {
270 return Some(inst);
271 }
272 }
273 }
274 None
275 }
276
277 /// Legality check using dependence analysis. Verifies that swapping
278 /// the loop order does not reverse any dependence direction.
279 fn is_interchange_legal(
280 func: &Function,
281 inner_loop: &super::loop_tree::LoopTreeNode,
282 outer_iv: ValueId,
283 inner_iv: ValueId,
284 ) -> bool {
285 super::dep_analysis::interchange_legal(func, &inner_loop.body, outer_iv, inner_iv)
286 }
287
288 /// Trace through a GEP chain to find the base array pointer.
289 fn trace_gep_base(func: &Function, ptr: ValueId) -> Option<ValueId> {
290 let Some(inst) = find_inst(func, ptr) else {
291 return Some(ptr);
292 };
293 match &inst.kind {
294 InstKind::GetElementPtr(base, _) => Some(*base),
295 _ => Some(ptr),
296 }
297 }
298
299 /// Perform the actual interchange transformation.
300 ///
301 /// Strategy: swap the init and bound values between the two loops.
302 /// After swapping, what was the outer loop iterates over the inner
303 /// range and vice versa.
304 fn do_interchange(func: &mut Function, outer: &LoopShape, inner: &LoopShape) {
305 // Find the preheader of the outer loop (it branches to outer.header
306 // with the outer IV init value).
307 let outer_preheader = {
308 let mut ph = None;
309 for block in &func.blocks {
310 if let Some(Terminator::Branch(dest, _)) = &block.terminator {
311 if *dest == outer.header && block.id != outer.latch {
312 ph = Some(block.id);
313 break;
314 }
315 }
316 if let Some(Terminator::CondBranch {
317 true_dest,
318 false_dest,
319 ..
320 }) = &block.terminator
321 {
322 if *true_dest == outer.header || *false_dest == outer.header {
323 // Could be a condBr preheader from preheader insertion
324 // but we need the unconditional one.
325 }
326 }
327 }
328 ph
329 };
330 let Some(outer_ph) = outer_preheader else {
331 return;
332 };
333
334 // Find the block that branches to the inner header with the inner
335 // IV init value (this is the outer loop's "body entry" block).
336 let inner_entry = {
337 let mut ie = None;
338 for block in &func.blocks {
339 if let Some(Terminator::Branch(dest, args)) = &block.terminator {
340 if *dest == inner.header && !args.is_empty() && block.id != inner.latch {
341 ie = Some(block.id);
342 break;
343 }
344 }
345 if let Some(Terminator::CondBranch {
346 true_dest,
347 true_args,
348 ..
349 }) = &block.terminator
350 {
351 if *true_dest == inner.header && !true_args.is_empty() {
352 ie = Some(block.id);
353 break;
354 }
355 }
356 }
357 ie
358 };
359 let Some(inner_entry_block) = inner_entry else {
360 return;
361 };
362
363 // Get current init values.
364 let outer_init = get_branch_arg_to(func, outer_ph, outer.header, 0);
365 let inner_init = get_branch_arg_to(func, inner_entry_block, inner.header, 0);
366 let Some(outer_init_val) = outer_init else {
367 return;
368 };
369 let Some(inner_init_val) = inner_init else {
370 return;
371 };
372
373 // Swap init values: outer preheader now passes inner's init to
374 // outer's header, and inner entry now passes outer's init to
375 // inner's header.
376 set_branch_arg_to(func, outer_ph, outer.header, 0, inner_init_val);
377 set_branch_arg_to(func, inner_entry_block, inner.header, 0, outer_init_val);
378
379 // Swap the bounds in the comparison blocks.
380 swap_bound(func, outer.cmp_block, outer.iv, outer.bound, inner.bound);
381 swap_bound(func, inner.cmp_block, inner.iv, inner.bound, outer.bound);
382 }
383
384 /// Get the Nth branch argument passed to a target block.
385 fn get_branch_arg_to(func: &Function, from: BlockId, to: BlockId, idx: usize) -> Option<ValueId> {
386 let block = func.block(from);
387 match &block.terminator {
388 Some(Terminator::Branch(dest, args)) if *dest == to => args.get(idx).copied(),
389 Some(Terminator::CondBranch {
390 true_dest,
391 true_args,
392 false_dest,
393 false_args,
394 ..
395 }) => {
396 if *true_dest == to {
397 true_args.get(idx).copied()
398 } else if *false_dest == to {
399 false_args.get(idx).copied()
400 } else {
401 None
402 }
403 }
404 _ => None,
405 }
406 }
407
408 /// Set the Nth branch argument passed to a target block.
409 fn set_branch_arg_to(func: &mut Function, from: BlockId, to: BlockId, idx: usize, val: ValueId) {
410 let block = func.block_mut(from);
411 match &mut block.terminator {
412 Some(Terminator::Branch(dest, args)) if *dest == to => {
413 if idx < args.len() {
414 args[idx] = val;
415 }
416 }
417 Some(Terminator::CondBranch {
418 true_dest,
419 true_args,
420 false_dest,
421 false_args,
422 ..
423 }) => {
424 if *true_dest == to && idx < true_args.len() {
425 true_args[idx] = val;
426 } else if *false_dest == to && idx < false_args.len() {
427 false_args[idx] = val;
428 }
429 }
430 _ => {}
431 }
432 }
433
434 /// Swap the bound value in a comparison block's ICmp instruction.
435 fn swap_bound(
436 func: &mut Function,
437 cmp_block: BlockId,
438 iv: ValueId,
439 old_bound: ValueId,
440 new_bound: ValueId,
441 ) {
442 let block = func.block_mut(cmp_block);
443 for inst in &mut block.insts {
444 if let InstKind::ICmp(_op, a, b) = &mut inst.kind {
445 if *a == iv && *b == old_bound {
446 *b = new_bound;
447 return;
448 } else if *b == iv && *a == old_bound {
449 *a = new_bound;
450 return;
451 }
452 }
453 }
454 }
455
456 // ---------------------------------------------------------------------------
457 // Tests
458 // ---------------------------------------------------------------------------
459
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::types::IrType;
    use crate::lexer::{Position, Span};
    use crate::opt::pass::Pass;

    /// Placeholder span at line 0, column 0 of file 0.
    fn span() -> Span {
        let origin = Position { line: 0, col: 0 };
        Span {
            file_id: 0,
            start: origin,
            end: origin,
        }
    }

    /// Smoke test: the pass runs on a module whose only function is a bare
    /// `return` and reports that nothing changed.
    #[test]
    fn interchange_pass_builds() {
        let mut module = Module::new("test".into());

        let mut func = Function::new("test".into(), vec![], IrType::Void);
        let entry = func.entry;
        func.block_mut(entry).terminator = Some(Terminator::Return(None));
        module.add_function(func);

        let changed = LoopInterchange.run(&mut module);
        assert!(!changed, "no loops → no interchange");
    }
}
488