Rust · 10268 bytes Raw Blame History
1 //! Preheader insertion pass.
2 //!
3 //! For each natural loop that lacks a unique preheader, creates one by
4 //! inserting a new block between the out-of-loop predecessors and the
5 //! header. Block arguments are threaded through so SSA form is preserved.
6 //!
7 //! Run early at O2+ so downstream passes (LICM, unswitching, interchange)
8 //! can assume every loop has a preheader.
9
10 use super::loop_utils::find_preheader;
11 use super::pass::Pass;
12 use crate::ir::inst::*;
13 use crate::ir::walk::{find_natural_loops, predecessors};
14
/// Preheader insertion pass.
///
/// A stateless unit struct: all work happens in [`Pass::run`], which visits
/// every function in the module. See the module documentation for what the
/// pass establishes for later passes.
#[derive(Debug, Clone, Copy, Default)]
pub struct PreheaderInsert;
17
18 impl Pass for PreheaderInsert {
19 fn name(&self) -> &'static str {
20 "preheader-insert"
21 }
22
23 fn run(&self, module: &mut Module) -> bool {
24 let mut changed = false;
25 for func in &mut module.functions {
26 if insert_preheaders(func) {
27 changed = true;
28 }
29 }
30 changed
31 }
32 }
33
34 /// Insert preheaders for all loops in one function.
35 fn insert_preheaders(func: &mut Function) -> bool {
36 let mut changed = false;
37
38 // Re-discover loops and predecessors each iteration since we mutate
39 // the CFG. The pass runs at most once per loop (idempotent).
40 let loops = find_natural_loops(func);
41 let preds = predecessors(func);
42
43 for lp in &loops {
44 // Skip if a preheader already exists.
45 if find_preheader(func, lp, &preds).is_some() {
46 continue;
47 }
48
49 // Collect out-of-loop predecessors of the header.
50 let header_preds = match preds.get(&lp.header) {
51 Some(p) => p,
52 None => continue,
53 };
54 let mut outside: Vec<BlockId> = header_preds
55 .iter()
56 .copied()
57 .filter(|p| !lp.body.contains(p))
58 .collect();
59 outside.sort_by_key(|b| b.0);
60 outside.dedup();
61
62 if outside.is_empty() {
63 continue;
64 }
65
66 // Create the preheader block. It receives the same block params
67 // as the header and unconditionally branches to the header,
68 // passing them through.
69 let hdr = func.block(lp.header);
70 let hdr_params: Vec<BlockParam> = hdr.params.clone();
71 let ph_id = func.create_block("preheader");
72
73 // Add matching block params to the preheader.
74 let mut ph_param_ids = Vec::new();
75 for bp in &hdr_params {
76 let new_id = func.next_value_id();
77 func.register_type(new_id, bp.ty.clone());
78 func.block_mut(ph_id).params.push(BlockParam {
79 id: new_id,
80 ty: bp.ty.clone(),
81 });
82 ph_param_ids.push(new_id);
83 }
84
85 // Preheader terminates with unconditional branch to header,
86 // passing its params through.
87 func.block_mut(ph_id).terminator = Some(Terminator::Branch(lp.header, ph_param_ids));
88
89 // Redirect each out-of-loop predecessor to branch to the
90 // preheader instead of the header.
91 let header = lp.header;
92 for &pred_id in &outside {
93 redirect_terminator(func, pred_id, header, ph_id);
94 }
95
96 changed = true;
97 }
98
99 changed
100 }
101
102 /// Rewrite all edges from `block` that target `old_dest` to target
103 /// `new_dest` instead. Branch arguments are preserved.
104 fn redirect_terminator(func: &mut Function, block: BlockId, old_dest: BlockId, new_dest: BlockId) {
105 let blk = func.block_mut(block);
106 let term = match &mut blk.terminator {
107 Some(t) => t,
108 None => return,
109 };
110 match term {
111 Terminator::Branch(dest, _) => {
112 if *dest == old_dest {
113 *dest = new_dest;
114 }
115 }
116 Terminator::CondBranch {
117 true_dest,
118 false_dest,
119 ..
120 } => {
121 if *true_dest == old_dest {
122 *true_dest = new_dest;
123 }
124 if *false_dest == old_dest {
125 *false_dest = new_dest;
126 }
127 }
128 Terminator::Switch { default, cases, .. } => {
129 if *default == old_dest {
130 *default = new_dest;
131 }
132 for (_, dest) in cases.iter_mut() {
133 if *dest == old_dest {
134 *dest = new_dest;
135 }
136 }
137 }
138 _ => {}
139 }
140 }
141
142 // ---------------------------------------------------------------------------
143 // Tests
144 // ---------------------------------------------------------------------------
145
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::types::{IntWidth, IrType};
    use crate::ir::walk::predecessors;
    use crate::lexer::{Position, Span};

    /// Dummy zero-position span for synthesized instructions.
    fn span() -> Span {
        let pos = Position { line: 0, col: 0 };
        Span {
            file_id: 0,
            start: pos,
            end: pos,
        }
    }

    /// Build a function where the loop header has TWO out-of-loop
    /// predecessors (`entry` and the `alt_entry` block), so no preheader
    /// exists. The header carries one i32 block param (the induction value).
    fn build_two_entry_loop() -> Module {
        let mut m = Module::new("test".into());
        let mut f = Function::new("test".into(), vec![], IrType::Void);

        let header = f.create_block("header");
        let body = f.create_block("body");
        let latch = f.create_block("latch");
        let exit = f.create_block("exit");
        let alt = f.create_block("alt_entry");
        let entry = f.entry;

        // Entry: const 1, condBr → header(1) or alt_entry
        let c1 = f.next_value_id();
        f.register_type(c1, IrType::Int(IntWidth::I32));
        f.block_mut(entry).insts.push(Inst {
            id: c1,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::ConstInt(1, IntWidth::I32),
        });
        let cond = f.next_value_id();
        f.register_type(cond, IrType::Bool);
        f.block_mut(entry).insts.push(Inst {
            id: cond,
            ty: IrType::Bool,
            span: span(),
            kind: InstKind::ConstBool(true),
        });
        f.block_mut(entry).terminator = Some(Terminator::CondBranch {
            cond,
            true_dest: header,
            true_args: vec![c1],
            false_dest: alt,
            false_args: vec![],
        });

        // Alt entry: const 5, br header(5)
        let c5 = f.next_value_id();
        f.register_type(c5, IrType::Int(IntWidth::I32));
        f.block_mut(alt).insts.push(Inst {
            id: c5,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::ConstInt(5, IntWidth::I32),
        });
        f.block_mut(alt).terminator = Some(Terminator::Branch(header, vec![c5]));

        // Header: block param, icmp, condBr body/exit
        let iv = f.next_value_id();
        f.register_type(iv, IrType::Int(IntWidth::I32));
        f.block_mut(header).params.push(BlockParam {
            id: iv,
            ty: IrType::Int(IntWidth::I32),
        });
        let c10 = f.next_value_id();
        f.register_type(c10, IrType::Int(IntWidth::I32));
        f.block_mut(header).insts.push(Inst {
            id: c10,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::ConstInt(10, IntWidth::I32),
        });
        let cmp = f.next_value_id();
        f.register_type(cmp, IrType::Bool);
        f.block_mut(header).insts.push(Inst {
            id: cmp,
            ty: IrType::Bool,
            span: span(),
            kind: InstKind::ICmp(CmpOp::Le, iv, c10),
        });
        f.block_mut(header).terminator = Some(Terminator::CondBranch {
            cond: cmp,
            true_dest: body,
            true_args: vec![],
            false_dest: exit,
            false_args: vec![],
        });

        // Body → latch
        f.block_mut(body).terminator = Some(Terminator::Branch(latch, vec![]));

        // Latch: iadd + br header
        let one_l = f.next_value_id();
        f.register_type(one_l, IrType::Int(IntWidth::I32));
        f.block_mut(latch).insts.push(Inst {
            id: one_l,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::ConstInt(1, IntWidth::I32),
        });
        let nxt = f.next_value_id();
        f.register_type(nxt, IrType::Int(IntWidth::I32));
        f.block_mut(latch).insts.push(Inst {
            id: nxt,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::IAdd(iv, one_l),
        });
        f.block_mut(latch).terminator = Some(Terminator::Branch(header, vec![nxt]));

        // Exit
        f.block_mut(exit).terminator = Some(Terminator::Return(None));

        m.add_function(f);
        m
    }

    #[test]
    fn inserts_preheader_for_multi_entry_loop() {
        let mut m = build_two_entry_loop();
        let f = &m.functions[0];

        // Before: header has 2 out-of-loop preds (entry, alt_entry).
        let loops = find_natural_loops(f);
        let preds = predecessors(f);
        assert_eq!(loops.len(), 1);
        assert!(
            find_preheader(f, &loops[0], &preds).is_none(),
            "should have no preheader before insertion"
        );

        let pass = PreheaderInsert;
        let changed = pass.run(&mut m);
        assert!(changed, "should have inserted a preheader");

        // After: header should have exactly one out-of-loop pred.
        let f = &m.functions[0];
        let loops = find_natural_loops(f);
        let preds = predecessors(f);
        let ph = find_preheader(f, &loops[0], &preds);
        assert!(ph.is_some(), "preheader should now exist");

        // The preheader should have a block param matching the header's.
        let ph_block = f.block(ph.unwrap());
        assert_eq!(ph_block.params.len(), 1, "preheader should have 1 param");
    }

    /// The inserted preheader must branch straight to the header, forwarding
    /// exactly its own block params, and the loop's back edge must survive.
    #[test]
    fn preheader_forwards_params_and_keeps_back_edge() {
        let mut m = build_two_entry_loop();
        let pass = PreheaderInsert;
        assert!(pass.run(&mut m), "should have inserted a preheader");

        let f = &m.functions[0];
        let loops = find_natural_loops(f);
        assert_eq!(loops.len(), 1);
        let lp = &loops[0];
        let preds = predecessors(f);
        let ph = find_preheader(f, lp, &preds).expect("preheader should now exist");
        let ph_block = f.block(ph);

        let ph_param_ids: Vec<_> = ph_block.params.iter().map(|p| p.id).collect();
        match &ph_block.terminator {
            Some(Terminator::Branch(dest, args)) => {
                assert!(*dest == lp.header, "preheader must branch to the header");
                assert!(
                    *args == ph_param_ids,
                    "preheader must forward its own params to the header"
                );
            }
            _ => panic!("preheader must end in an unconditional branch"),
        }

        let header_preds = preds
            .get(&lp.header)
            .expect("header should have predecessors");
        for p in header_preds.iter() {
            assert!(
                *p == ph || lp.body.contains(p),
                "every header predecessor must be the preheader or an in-loop block"
            );
        }
        assert!(
            header_preds.iter().any(|p| *p == ph),
            "preheader must be a predecessor of the header"
        );
        assert!(
            header_preds.iter().any(|p| lp.body.contains(p)),
            "back edge from the latch must still target the header"
        );
    }

    #[test]
    fn idempotent_when_preheader_exists() {
        let mut m = build_two_entry_loop();
        let pass = PreheaderInsert;
        pass.run(&mut m);
        let block_count_after_first = m.functions[0].blocks.len();

        // Running again should not change anything.
        let changed = pass.run(&mut m);
        assert!(!changed, "second run should be idempotent");
        assert_eq!(m.functions[0].blocks.len(), block_count_after_first);
    }
}
314