//! Call graph construction and analysis.
//!
//! Builds a call graph from a module's functions after call resolution.
//! Detects recursive functions via DFS cycle detection. Provides
//! post-order iteration for bottom-up inlining (callees first).
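//!
//! Typical use, as a sketch (`module` is assumed to be a fully resolved
//! `Module`, and `THRESHOLD` is a hypothetical size cutoff, not an API
//! defined here):
//!
//! ```ignore
//! // THRESHOLD is a hypothetical cutoff, not defined by this module.
//! let cg = CallGraph::build(&module);
//! for idx in cg.bottom_up_order() {
//!     if !cg.is_recursive(idx) && cg.inline_cost(idx) < THRESHOLD {
//!         // candidate for inlining
//!     }
//! }
//! ```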

use crate::ir::inst::*;
use crate::ir::walk::find_natural_loops;
use std::collections::HashSet;

/// A node in the call graph — one per function in the module.
#[derive(Debug)]
pub struct CallNode {
    /// Index into Module::functions.
    pub func_idx: u32,
    /// Indices of functions this function calls (via Internal refs).
    pub callees: Vec<u32>,
    /// Indices of functions that call this one.
    pub callers: Vec<u32>,
    /// Loop-adjusted instruction count (cost model).
    pub inst_count: usize,
    /// True if this function is (directly or transitively) recursive.
    pub is_recursive: bool,
}

/// The call graph for an entire module.
#[derive(Debug)]
pub struct CallGraph {
    pub nodes: Vec<CallNode>,
}

impl CallGraph {
    /// Build the call graph from a module.
    pub fn build(module: &Module) -> Self {
        let n = module.functions.len();
        let mut nodes: Vec<CallNode> = (0..n)
            .map(|i| CallNode {
                func_idx: i as u32,
                callees: Vec::new(),
                callers: Vec::new(),
                inst_count: 0,
                is_recursive: false,
            })
            .collect();

        // Scan each function for Call(Internal) instructions.
        for (i, func) in module.functions.iter().enumerate() {
            let mut callees_set: HashSet<u32> = HashSet::new();
            let mut count = 0usize;
            for block in &func.blocks {
                count += block.insts.len();
                for inst in &block.insts {
                    if let InstKind::Call(FuncRef::Internal(idx), _) = &inst.kind {
                        callees_set.insert(*idx);
                    }
                }
            }
            let mut callees: Vec<u32> = callees_set.into_iter().collect();
            callees.sort();
            nodes[i].callees = callees.clone();

            // Apply a loop cost multiplier: functions with loops have a
            // higher dynamic cost than their static instruction count
            // suggests. The static count is scaled by 4 per natural loop
            // (a rough, conservative stand-in for average trip count
            // impact on code expansion after inlining).
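            // e.g. a 50-instruction function containing 2 loops is
            // costed at 50 * (4 * 2) = 400.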
            let loops = find_natural_loops(func);
            let loop_multiplier = if loops.is_empty() { 1 } else { 4 * loops.len() };
            nodes[i].inst_count = count * loop_multiplier;

            // Register callers.
            for &callee in &callees {
                if (callee as usize) < n {
                    nodes[callee as usize].callers.push(i as u32);
                }
            }
        }

        // Detect recursion via DFS from each node.
        for i in 0..n {
            if detect_cycle(&nodes, i as u32) {
                nodes[i].is_recursive = true;
            }
        }

        CallGraph { nodes }
    }

    /// Return function indices in bottom-up order: callees before
    /// callers. This is the correct order for inlining.
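    ///
    /// For example, if function 2 calls function 1, which in turn calls
    /// function 0, the returned order is [0, 1, 2].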
    pub fn bottom_up_order(&self) -> Vec<u32> {
        let n = self.nodes.len();
        let mut visited = vec![false; n];
        let mut order = Vec::new();

        // Start DFS from root nodes (functions with no callers) so each
        // call chain is traversed from the top; post-order then emits
        // every function after all of its callees.
        let roots: Vec<u32> = (0..n as u32)
            .filter(|&i| self.nodes[i as usize].callers.is_empty())
            .collect();
        for root in &roots {
            postorder_dfs(&self.nodes, *root, &mut visited, &mut order);
        }
        // Also visit any unvisited nodes (cycles, disconnected).
        for i in 0..n {
            if !visited[i] {
                postorder_dfs(&self.nodes, i as u32, &mut visited, &mut order);
            }
        }

        // Post-order places callees BEFORE callers (a node is pushed
        // only after all of its callees), which is exactly the bottom-up
        // order we want, so no reversal is needed.
        order
    }

    /// Get the cost of inlining a function (loop-adjusted instruction
    /// count).
    pub fn inline_cost(&self, func_idx: u32) -> usize {
        self.nodes[func_idx as usize].inst_count
    }

    /// Is the function recursive?
    pub fn is_recursive(&self, func_idx: u32) -> bool {
        self.nodes[func_idx as usize].is_recursive
    }
}

/// Post-order DFS over call edges: a node is appended to `order` only
/// after all of its callees have been appended.
fn postorder_dfs(nodes: &[CallNode], idx: u32, visited: &mut [bool], order: &mut Vec<u32>) {
    if visited[idx as usize] {
        return;
    }
    visited[idx as usize] = true;
    for &callee in &nodes[idx as usize].callees {
        if (callee as usize) < nodes.len() {
            postorder_dfs(nodes, callee, visited, order);
        }
    }
    order.push(idx);
}

/// Detect whether function `start` is part of a cycle in the call graph,
/// i.e. whether `start` can reach itself through one or more calls.
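///
/// For example, with calls a → b and b → a, both `a` and `b` are on a
/// cycle; a function that merely calls a recursive function is not
/// itself part of one.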
fn detect_cycle(nodes: &[CallNode], start: u32) -> bool {
    let mut visited = HashSet::new();
    for &callee in &nodes[start as usize].callees {
        if (callee as usize) < nodes.len() && reaches(nodes, callee, start, &mut visited) {
            return true;
        }
    }
    false
}

/// DFS reachability: can `target` be reached from `idx` along call edges?
fn reaches(nodes: &[CallNode], idx: u32, target: u32, visited: &mut HashSet<u32>) -> bool {
    if idx == target {
        return true; // path back to the start → cycle
    }
    if !visited.insert(idx) {
        return false; // already explored (or in progress); skip
    }
    for &callee in &nodes[idx as usize].callees {
        if (callee as usize) < nodes.len() && reaches(nodes, callee, target, visited) {
            return true;
        }
    }
    false
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::types::{IntWidth, IrType};
    use crate::lexer::{Position, Span};

    fn span() -> Span {
        let pos = Position { line: 0, col: 0 };
        Span {
            file_id: 0,
            start: pos,
            end: pos,
        }
    }

    #[test]
    fn detects_self_recursion() {
        let mut m = Module::new("test".into());
        let mut f = Function::new("factorial".into(), vec![], IrType::Int(IntWidth::I32));
        // Self-call: factorial calls factorial.
        let call_id = f.next_value_id();
        f.register_type(call_id, IrType::Int(IntWidth::I32));
        f.block_mut(f.entry).insts.push(Inst {
            id: call_id,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::Call(FuncRef::Internal(0), vec![]),
        });
        f.block_mut(f.entry).terminator = Some(Terminator::Return(Some(call_id)));
        m.add_function(f);

        let cg = CallGraph::build(&m);
        assert!(
            cg.is_recursive(0),
            "self-calling function should be recursive"
        );
    }

    #[test]
    fn non_recursive_detected() {
        let mut m = Module::new("test".into());
        // Function 0: "callee" — no calls.
        let mut callee = Function::new("callee".into(), vec![], IrType::Int(IntWidth::I32));
        callee.block_mut(callee.entry).terminator = Some(Terminator::Return(None));
        m.add_function(callee);

        // Function 1: "caller" — calls callee (Internal(0)).
        let mut caller = Function::new("caller".into(), vec![], IrType::Void);
        let call_id = caller.next_value_id();
        caller.register_type(call_id, IrType::Int(IntWidth::I32));
        caller.block_mut(caller.entry).insts.push(Inst {
            id: call_id,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::Call(FuncRef::Internal(0), vec![]),
        });
        caller.block_mut(caller.entry).terminator = Some(Terminator::Return(None));
        m.add_function(caller);

        let cg = CallGraph::build(&m);
        assert!(!cg.is_recursive(0), "callee should not be recursive");
        assert!(!cg.is_recursive(1), "caller should not be recursive");
        assert_eq!(cg.nodes[1].callees, vec![0]);
        assert_eq!(cg.nodes[0].callers, vec![1]);
    }

    #[test]
    fn bottom_up_order_callees_first() {
        let mut m = Module::new("test".into());
        // callee (idx 0) — no calls
        let mut callee = Function::new("callee".into(), vec![], IrType::Void);
        callee.block_mut(callee.entry).terminator = Some(Terminator::Return(None));
        m.add_function(callee);
        // caller (idx 1) — calls callee
        let mut caller = Function::new("caller".into(), vec![], IrType::Void);
        let cid = caller.next_value_id();
        caller.register_type(cid, IrType::Void);
        caller.block_mut(caller.entry).insts.push(Inst {
            id: cid,
            ty: IrType::Void,
            span: span(),
            kind: InstKind::Call(FuncRef::Internal(0), vec![]),
        });
        caller.block_mut(caller.entry).terminator = Some(Terminator::Return(None));
        m.add_function(caller);

        let cg = CallGraph::build(&m);
        let order = cg.bottom_up_order();
        // Post-order should place the callee before the caller.
        let callee_pos = order.iter().position(|&i| i == 0).unwrap();
        let caller_pos = order.iter().position(|&i| i == 1).unwrap();
        assert!(
            callee_pos < caller_pos,
            "callee should come before caller in bottom-up order"
        );
    }
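
    // A sketch of an additional case, using the same Module/Function
    // helpers as the tests above (assumed available): mutual recursion
    // (a calls b, b calls a) should mark both functions recursive even
    // though neither calls itself directly.
    #[test]
    fn detects_mutual_recursion() {
        let mut m = Module::new("test".into());
        for (name, callee_idx) in [("a", 1u32), ("b", 0u32)] {
            let mut f = Function::new(name.into(), vec![], IrType::Void);
            let call_id = f.next_value_id();
            f.register_type(call_id, IrType::Void);
            f.block_mut(f.entry).insts.push(Inst {
                id: call_id,
                ty: IrType::Void,
                span: span(),
                kind: InstKind::Call(FuncRef::Internal(callee_idx), vec![]),
            });
            f.block_mut(f.entry).terminator = Some(Terminator::Return(None));
            m.add_function(f);
        }

        let cg = CallGraph::build(&m);
        assert!(cg.is_recursive(0), "a should be recursive via b");
        assert!(cg.is_recursive(1), "b should be recursive via a");
    }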
}