Rust · 19378 bytes Raw Blame History
1 //! Loop nesting tree.
2 //!
3 //! Builds a parent/child hierarchy from the flat `Vec<NaturalLoop>`
4 //! returned by `find_natural_loops`. Each node knows its depth,
5 //! parent, and children, so passes like interchange can find
6 //! perfectly-nested pairs and unswitching can target innermost loops.
7
8 use crate::ir::inst::{BlockId, Function};
9 use crate::ir::walk::find_natural_loops;
10 use std::collections::{HashMap, HashSet};
11
/// Unique identifier for a loop in the tree.
///
/// The wrapped value is the loop's index into [`LoopTree::nodes`]
/// (`build_loop_tree` assigns `LoopId(i)` to the `i`-th natural loop it
/// discovers, and [`LoopTree::node`] indexes `nodes` with it).
///
/// `PartialOrd`/`Ord` are derived so IDs can be sorted and used as keys of
/// ordered collections such as `BTreeMap`/`BTreeSet` for deterministic output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct LoopId(pub u32);
15
/// A node in the loop nesting tree.
///
/// One node is created by `build_loop_tree` for each natural loop returned by
/// `find_natural_loops`; `header`, `body` and `latches` are copied from it.
#[derive(Debug, Clone)]
pub struct LoopTreeNode {
    /// This node's ID; `build_loop_tree` makes it equal to the node's index in
    /// `LoopTree::nodes`.
    pub id: LoopId,
    /// Header block of the loop.
    pub header: BlockId,
    /// Every block of the loop, including the blocks of nested child loops
    /// (a parent's body is a strict superset of each child's body).
    pub body: HashSet<BlockId>,
    /// Latch blocks as reported by `find_natural_loops` — presumably the
    /// sources of back edges to `header`; NOTE(review): confirm in `walk`.
    pub latches: Vec<BlockId>,
    /// Smallest loop whose body strictly contains this loop's body, or `None`
    /// for an outermost loop.
    pub parent: Option<LoopId>,
    /// Loops whose `parent` is this loop, sorted by header block ID.
    pub children: Vec<LoopId>,
    /// Nesting depth: 1 for outermost loops, 2 for their children, etc.
    pub depth: u32,
}
28
29 /// The complete loop nesting forest for a function.
30 #[derive(Debug)]
31 pub struct LoopTree {
32 pub nodes: Vec<LoopTreeNode>,
33 /// Maps each block to the innermost loop that contains it.
34 pub block_to_loop: HashMap<BlockId, LoopId>,
35 }
36
37 impl LoopTree {
38 /// Return the IDs of all innermost (leaf) loops — loops with no children.
39 pub fn innermost_loops(&self) -> Vec<LoopId> {
40 self.nodes
41 .iter()
42 .filter(|n| n.children.is_empty())
43 .map(|n| n.id)
44 .collect()
45 }
46
47 /// Get a node by ID.
48 pub fn node(&self, id: LoopId) -> &LoopTreeNode {
49 &self.nodes[id.0 as usize]
50 }
51
52 /// Nesting depth of a block (0 if not in any loop).
53 pub fn loop_depth(&self, block: BlockId) -> u32 {
54 self.block_to_loop
55 .get(&block)
56 .map(|lid| self.node(*lid).depth)
57 .unwrap_or(0)
58 }
59
60 /// Return the parent loop of a given loop, if any.
61 pub fn parent(&self, id: LoopId) -> Option<LoopId> {
62 self.node(id).parent
63 }
64
65 /// Return (outer, inner) pairs of perfectly nested loops.
66 /// "Perfectly nested" means outer's body blocks (excluding inner's
67 /// body) contain no instructions other than control flow to/from
68 /// the inner loop.
69 pub fn perfectly_nested_pairs(&self, func: &Function) -> Vec<(LoopId, LoopId)> {
70 let mut pairs = Vec::new();
71 for node in &self.nodes {
72 if node.children.len() != 1 {
73 continue;
74 }
75 let child_id = node.children[0];
76 let child = self.node(child_id);
77
78 // The outer loop's "own" blocks (not in the inner loop)
79 // must contain no non-trivial instructions. We allow:
80 // - the outer header (relay block, 0 insts ok)
81 // - the outer cmp block (comparison + condBr)
82 // - the outer latch (increment + branch)
83 // - the outer "body" block that just branches to the inner preheader
84 // Everything else must be part of the inner loop.
85 let outer_only: Vec<BlockId> = node
86 .body
87 .iter()
88 .filter(|b| !child.body.contains(b))
89 .copied()
90 .collect();
91
92 // Conservative check: outer-only blocks should have very few
93 // instructions total (header relay + cmp + latch + body-entry).
94 let total_outer_insts: usize =
95 outer_only.iter().map(|&b| func.block(b).insts.len()).sum();
96
97 // A typical Fortran DO nest has ~4-6 instructions in the
98 // outer shell (const bound, icmp, iadd). Allow up to 10
99 // to account for LICM-hoisted invariants.
100 if total_outer_insts <= 10 {
101 pairs.push((node.id, child_id));
102 }
103 }
104 pairs
105 }
106 }
107
108 /// Build the loop nesting tree for a function.
109 ///
110 /// Algorithm: discover natural loops, sort by body size descending,
111 /// then for each loop find the smallest enclosing loop (its parent).
112 pub fn build_loop_tree(func: &Function) -> LoopTree {
113 let natural = find_natural_loops(func);
114
115 if natural.is_empty() {
116 return LoopTree {
117 nodes: Vec::new(),
118 block_to_loop: HashMap::new(),
119 };
120 }
121
122 // Create nodes sorted by body size descending (largest = outermost first).
123 let mut indexed: Vec<(usize, &crate::ir::walk::NaturalLoop)> =
124 natural.iter().enumerate().collect();
125 indexed.sort_by_key(|entry| std::cmp::Reverse(entry.1.body.len()));
126
127 // Build nodes with stable IDs (original discovery order).
128 let mut nodes: Vec<LoopTreeNode> = natural
129 .iter()
130 .enumerate()
131 .map(|(i, nl)| LoopTreeNode {
132 id: LoopId(i as u32),
133 header: nl.header,
134 body: nl.body.clone(),
135 latches: nl.latches.clone(),
136 parent: None,
137 children: Vec::new(),
138 depth: 0,
139 })
140 .collect();
141
142 // For each loop, find its parent = the smallest loop that strictly
143 // contains it. We iterate in body-size order so we can check
144 // containment efficiently.
145 let n = nodes.len();
146 for i in 0..n {
147 let mut best_parent: Option<LoopId> = None;
148 let mut best_size = usize::MAX;
149 for j in 0..n {
150 if i == j {
151 continue;
152 }
153 // j is a candidate parent if j's body strictly contains i's body.
154 if nodes[j].body.len() > nodes[i].body.len()
155 && nodes[i].body.is_subset(&nodes[j].body)
156 && nodes[j].body.len() < best_size
157 {
158 best_parent = Some(nodes[j].id);
159 best_size = nodes[j].body.len();
160 }
161 }
162 nodes[i].parent = best_parent;
163 }
164
165 // Wire children from parent links.
166 let parents: Vec<Option<LoopId>> = nodes.iter().map(|n| n.parent).collect();
167 for (i, parent) in parents.into_iter().enumerate() {
168 if let Some(pid) = parent {
169 nodes[pid.0 as usize].children.push(LoopId(i as u32));
170 }
171 }
172
173 // Sort children by header for determinism.
174 // (Two-step to satisfy the borrow checker: collect sort keys, then sort.)
175 for i in 0..n {
176 let headers: Vec<(LoopId, u32)> = nodes[i]
177 .children
178 .iter()
179 .map(|c| (*c, nodes[c.0 as usize].header.0))
180 .collect();
181 nodes[i].children.sort_by_key(|c| {
182 headers
183 .iter()
184 .find(|(id, _)| id == c)
185 .map(|(_, h)| *h)
186 .unwrap_or(0)
187 });
188 }
189
190 // Compute depths.
191 fn set_depth(nodes: &mut [LoopTreeNode], id: LoopId, depth: u32) {
192 nodes[id.0 as usize].depth = depth;
193 let children: Vec<LoopId> = nodes[id.0 as usize].children.clone();
194 for child in children {
195 set_depth(nodes, child, depth + 1);
196 }
197 }
198 for i in 0..n {
199 if nodes[i].parent.is_none() {
200 set_depth(&mut nodes, LoopId(i as u32), 1);
201 }
202 }
203
204 // Build block → innermost loop mapping.
205 // Process innermost (deepest) loops last so they overwrite parents.
206 let mut block_to_loop: HashMap<BlockId, LoopId> = HashMap::new();
207 let mut by_depth: Vec<(u32, LoopId)> = nodes.iter().map(|n| (n.depth, n.id)).collect();
208 by_depth.sort_by_key(|(d, _)| *d);
209 for (_, lid) in by_depth {
210 for &block in &nodes[lid.0 as usize].body {
211 block_to_loop.insert(block, lid);
212 }
213 }
214
215 LoopTree {
216 nodes,
217 block_to_loop,
218 }
219 }
220
221 // ---------------------------------------------------------------------------
222 // Tests
223 // ---------------------------------------------------------------------------
224
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::inst::*;
    use crate::ir::types::{IntWidth, IrType};
    use crate::lexer::{Position, Span};

    /// A dummy source span for synthesized instructions.
    fn span() -> Span {
        let pos = Position { line: 0, col: 0 };
        Span {
            file_id: 0,
            start: pos,
            end: pos,
        }
    }

    /// Give `hdr` an `i32` induction-variable block param and make it branch
    /// unconditionally to `cmp`. Returns the induction variable.
    fn loop_header(f: &mut Function, hdr: BlockId, cmp: BlockId) -> ValueId {
        let iv = f.next_value_id();
        f.register_type(iv, IrType::Int(IntWidth::I32));
        f.block_mut(hdr).params.push(BlockParam {
            id: iv,
            ty: IrType::Int(IntWidth::I32),
        });
        f.block_mut(hdr).terminator = Some(Terminator::Branch(cmp, vec![]));
        iv
    }

    /// Emit `cond = icmp le iv, bound` into `cmp` and terminate it with
    /// `condBr cond, body_dest(body_args), exit_dest()`.
    fn loop_cmp(
        f: &mut Function,
        cmp: BlockId,
        iv: ValueId,
        bound: ValueId,
        body_dest: BlockId,
        body_args: Vec<ValueId>,
        exit_dest: BlockId,
    ) {
        let cond = f.next_value_id();
        f.register_type(cond, IrType::Bool);
        f.block_mut(cmp).insts.push(Inst {
            id: cond,
            ty: IrType::Bool,
            span: span(),
            kind: InstKind::ICmp(CmpOp::Le, iv, bound),
        });
        f.block_mut(cmp).terminator = Some(Terminator::CondBranch {
            cond,
            true_dest: body_dest,
            true_args: body_args,
            false_dest: exit_dest,
            false_args: vec![],
        });
    }

    /// Emit `next = iadd iv, step` into `latch` and terminate it with the
    /// back edge `br hdr(next)`.
    fn loop_latch(f: &mut Function, latch: BlockId, iv: ValueId, step: ValueId, hdr: BlockId) {
        let next = f.next_value_id();
        f.register_type(next, IrType::Int(IntWidth::I32));
        f.block_mut(latch).insts.push(Inst {
            id: next,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::IAdd(iv, step),
        });
        f.block_mut(latch).terminator = Some(Terminator::Branch(hdr, vec![next]));
    }

    /// Build a function with three nested counted DO loops:
    ///
    /// ```text
    /// entry:          c1 = 1; c10 = 10; br outer_hdr(c1)
    /// outer_hdr(oi):  br outer_cmp
    /// outer_cmp:      oi <= c10 ? inner_hdr(c1) : exit
    /// inner_hdr(ii):  br inner_cmp
    /// inner_cmp:      ii <= c10 ? deep_hdr(c1) : outer_latch
    /// deep_hdr(di):   br deep_cmp
    /// deep_cmp:       di <= c10 ? body : inner_latch
    /// body:           br deep_latch
    /// deep_latch:     br deep_hdr(di + 1)
    /// inner_latch:    br inner_hdr(ii + 1)
    /// outer_latch:    br outer_hdr(oi + 1)
    /// exit:           ret
    /// ```
    fn build_triple_nested() -> Function {
        let mut f = Function::new("triple".into(), vec![], IrType::Void);

        // Create all blocks upfront.
        let outer_hdr = f.create_block("outer_hdr");
        let outer_cmp = f.create_block("outer_cmp");
        let inner_hdr = f.create_block("inner_hdr");
        let inner_cmp = f.create_block("inner_cmp");
        let deep_hdr = f.create_block("deep_hdr");
        let deep_cmp = f.create_block("deep_cmp");
        let body = f.create_block("body");
        let deep_latch = f.create_block("deep_latch");
        let inner_latch = f.create_block("inner_latch");
        let outer_latch = f.create_block("outer_latch");
        let exit = f.create_block("exit");
        let entry = f.entry;

        // Entry: shared constants 1 (initial value and step) and 10 (bound).
        let c1 = f.next_value_id();
        f.register_type(c1, IrType::Int(IntWidth::I32));
        f.block_mut(entry).insts.push(Inst {
            id: c1,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::ConstInt(1, IntWidth::I32),
        });
        let c10 = f.next_value_id();
        f.register_type(c10, IrType::Int(IntWidth::I32));
        f.block_mut(entry).insts.push(Inst {
            id: c10,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::ConstInt(10, IntWidth::I32),
        });
        f.block_mut(entry).terminator = Some(Terminator::Branch(outer_hdr, vec![c1]));

        // Outer: hdr(%oi) → cmp → {inner_hdr(1), exit}
        let oi = loop_header(&mut f, outer_hdr, outer_cmp);
        loop_cmp(&mut f, outer_cmp, oi, c10, inner_hdr, vec![c1], exit);

        // Inner: hdr(%ii) → cmp → {deep_hdr(1), outer_latch}
        let ii = loop_header(&mut f, inner_hdr, inner_cmp);
        loop_cmp(&mut f, inner_cmp, ii, c10, deep_hdr, vec![c1], outer_latch);

        // Deep: hdr(%di) → cmp → {body, inner_latch}
        let di = loop_header(&mut f, deep_hdr, deep_cmp);
        loop_cmp(&mut f, deep_cmp, di, c10, body, vec![], inner_latch);

        // Body: no work, just continue to the deep latch.
        f.block_mut(body).terminator = Some(Terminator::Branch(deep_latch, vec![]));

        // Latches: iv + 1, then the back edge to the owning header.
        loop_latch(&mut f, deep_latch, di, c1, deep_hdr);
        loop_latch(&mut f, inner_latch, ii, c1, inner_hdr);
        loop_latch(&mut f, outer_latch, oi, c1, outer_hdr);

        f.block_mut(exit).terminator = Some(Terminator::Return(None));

        f
    }

    /// ID of the first block of `f` whose name starts with `prefix`.
    fn block_named(f: &Function, prefix: &str) -> BlockId {
        f.blocks
            .iter()
            .find(|b| b.name.starts_with(prefix))
            .map(|b| b.id)
            .expect("no block with the requested name prefix")
    }

    #[test]
    fn triple_nested_tree_structure() {
        let f = build_triple_nested();
        let tree = build_loop_tree(&f);

        assert_eq!(tree.nodes.len(), 3, "should have 3 loops");

        // Find by depth
        let roots: Vec<_> = tree.nodes.iter().filter(|n| n.parent.is_none()).collect();
        assert_eq!(roots.len(), 1, "one outermost loop");
        assert_eq!(roots[0].depth, 1);
        assert_eq!(roots[0].children.len(), 1, "outer has one child");

        let mid = tree.node(roots[0].children[0]);
        assert_eq!(mid.depth, 2);
        assert_eq!(mid.children.len(), 1, "middle has one child");

        let inner = tree.node(mid.children[0]);
        assert_eq!(inner.depth, 3);
        assert!(inner.children.is_empty(), "innermost has no children");
    }

    #[test]
    fn innermost_returns_leaf_only() {
        let f = build_triple_nested();
        let tree = build_loop_tree(&f);

        let leaves = tree.innermost_loops();
        assert_eq!(leaves.len(), 1, "only the deepest loop is innermost");
        assert_eq!(tree.node(leaves[0]).depth, 3);
    }

    #[test]
    fn block_to_loop_maps_to_innermost() {
        let f = build_triple_nested();
        let tree = build_loop_tree(&f);

        // The body block should map to the innermost loop.
        let body_block = block_named(&f, "body");
        let mapped = tree.block_to_loop.get(&body_block);
        assert!(mapped.is_some(), "body block should be in a loop");
        let mapped_node = tree.node(*mapped.unwrap());
        assert_eq!(mapped_node.depth, 3, "body should map to innermost loop");
    }

    #[test]
    fn blocks_outside_loops_have_depth_zero() {
        let f = build_triple_nested();
        let tree = build_loop_tree(&f);

        assert_eq!(tree.loop_depth(f.entry), 0, "entry is not in a loop");
        assert_eq!(tree.loop_depth(block_named(&f, "exit")), 0, "exit is not in a loop");
    }

    #[test]
    fn parent_links_match_children() {
        let f = build_triple_nested();
        let tree = build_loop_tree(&f);

        for node in &tree.nodes {
            for &child in &node.children {
                assert_eq!(tree.parent(child), Some(node.id));
                assert_eq!(tree.node(child).depth, node.depth + 1);
            }
        }
    }

    #[test]
    fn triple_nest_yields_two_perfectly_nested_pairs() {
        let f = build_triple_nested();
        let tree = build_loop_tree(&f);

        let pairs = tree.perfectly_nested_pairs(&f);
        assert_eq!(pairs.len(), 2, "outer/middle and middle/inner are perfectly nested");
        for (outer, inner) in pairs {
            assert_eq!(tree.parent(inner), Some(outer));
        }
    }

    #[test]
    fn empty_function_has_no_loops() {
        let mut f = Function::new("empty".into(), vec![], IrType::Void);
        f.block_mut(f.entry).terminator = Some(Terminator::Return(None));
        let tree = build_loop_tree(&f);
        assert!(tree.nodes.is_empty());
        assert!(tree.innermost_loops().is_empty());
    }
}
536