| 1 | //! Preheader insertion pass. |
| 2 | //! |
| 3 | //! For each natural loop that lacks a unique preheader, creates one by |
| 4 | //! inserting a new block between the out-of-loop predecessors and the |
| 5 | //! header. Block arguments are threaded through so SSA form is preserved. |
| 6 | //! |
| 7 | //! Run early at O2+ so downstream passes (LICM, unswitching, interchange) |
| 8 | //! can assume every loop has a preheader. |
| 9 | |
| 10 | use super::loop_utils::find_preheader; |
| 11 | use super::pass::Pass; |
| 12 | use crate::ir::inst::*; |
| 13 | use crate::ir::walk::{find_natural_loops, predecessors}; |
| 14 | |
| 15 | /// Preheader insertion pass. |
| 16 | pub struct PreheaderInsert; |
| 17 | |
| 18 | impl Pass for PreheaderInsert { |
| 19 | fn name(&self) -> &'static str { |
| 20 | "preheader-insert" |
| 21 | } |
| 22 | |
| 23 | fn run(&self, module: &mut Module) -> bool { |
| 24 | let mut changed = false; |
| 25 | for func in &mut module.functions { |
| 26 | if insert_preheaders(func) { |
| 27 | changed = true; |
| 28 | } |
| 29 | } |
| 30 | changed |
| 31 | } |
| 32 | } |
| 33 | |
| 34 | /// Insert preheaders for all loops in one function. |
| 35 | fn insert_preheaders(func: &mut Function) -> bool { |
| 36 | let mut changed = false; |
| 37 | |
| 38 | // Re-discover loops and predecessors each iteration since we mutate |
| 39 | // the CFG. The pass runs at most once per loop (idempotent). |
| 40 | let loops = find_natural_loops(func); |
| 41 | let preds = predecessors(func); |
| 42 | |
| 43 | for lp in &loops { |
| 44 | // Skip if a preheader already exists. |
| 45 | if find_preheader(func, lp, &preds).is_some() { |
| 46 | continue; |
| 47 | } |
| 48 | |
| 49 | // Collect out-of-loop predecessors of the header. |
| 50 | let header_preds = match preds.get(&lp.header) { |
| 51 | Some(p) => p, |
| 52 | None => continue, |
| 53 | }; |
| 54 | let mut outside: Vec<BlockId> = header_preds |
| 55 | .iter() |
| 56 | .copied() |
| 57 | .filter(|p| !lp.body.contains(p)) |
| 58 | .collect(); |
| 59 | outside.sort_by_key(|b| b.0); |
| 60 | outside.dedup(); |
| 61 | |
| 62 | if outside.is_empty() { |
| 63 | continue; |
| 64 | } |
| 65 | |
| 66 | // Create the preheader block. It receives the same block params |
| 67 | // as the header and unconditionally branches to the header, |
| 68 | // passing them through. |
| 69 | let hdr = func.block(lp.header); |
| 70 | let hdr_params: Vec<BlockParam> = hdr.params.clone(); |
| 71 | let ph_id = func.create_block("preheader"); |
| 72 | |
| 73 | // Add matching block params to the preheader. |
| 74 | let mut ph_param_ids = Vec::new(); |
| 75 | for bp in &hdr_params { |
| 76 | let new_id = func.next_value_id(); |
| 77 | func.register_type(new_id, bp.ty.clone()); |
| 78 | func.block_mut(ph_id).params.push(BlockParam { |
| 79 | id: new_id, |
| 80 | ty: bp.ty.clone(), |
| 81 | }); |
| 82 | ph_param_ids.push(new_id); |
| 83 | } |
| 84 | |
| 85 | // Preheader terminates with unconditional branch to header, |
| 86 | // passing its params through. |
| 87 | func.block_mut(ph_id).terminator = Some(Terminator::Branch(lp.header, ph_param_ids)); |
| 88 | |
| 89 | // Redirect each out-of-loop predecessor to branch to the |
| 90 | // preheader instead of the header. |
| 91 | let header = lp.header; |
| 92 | for &pred_id in &outside { |
| 93 | redirect_terminator(func, pred_id, header, ph_id); |
| 94 | } |
| 95 | |
| 96 | changed = true; |
| 97 | } |
| 98 | |
| 99 | changed |
| 100 | } |
| 101 | |
| 102 | /// Rewrite all edges from `block` that target `old_dest` to target |
| 103 | /// `new_dest` instead. Branch arguments are preserved. |
| 104 | fn redirect_terminator(func: &mut Function, block: BlockId, old_dest: BlockId, new_dest: BlockId) { |
| 105 | let blk = func.block_mut(block); |
| 106 | let term = match &mut blk.terminator { |
| 107 | Some(t) => t, |
| 108 | None => return, |
| 109 | }; |
| 110 | match term { |
| 111 | Terminator::Branch(dest, _) => { |
| 112 | if *dest == old_dest { |
| 113 | *dest = new_dest; |
| 114 | } |
| 115 | } |
| 116 | Terminator::CondBranch { |
| 117 | true_dest, |
| 118 | false_dest, |
| 119 | .. |
| 120 | } => { |
| 121 | if *true_dest == old_dest { |
| 122 | *true_dest = new_dest; |
| 123 | } |
| 124 | if *false_dest == old_dest { |
| 125 | *false_dest = new_dest; |
| 126 | } |
| 127 | } |
| 128 | Terminator::Switch { default, cases, .. } => { |
| 129 | if *default == old_dest { |
| 130 | *default = new_dest; |
| 131 | } |
| 132 | for (_, dest) in cases.iter_mut() { |
| 133 | if *dest == old_dest { |
| 134 | *dest = new_dest; |
| 135 | } |
| 136 | } |
| 137 | } |
| 138 | _ => {} |
| 139 | } |
| 140 | } |
| 141 | |
| 142 | // --------------------------------------------------------------------------- |
| 143 | // Tests |
| 144 | // --------------------------------------------------------------------------- |
| 145 | |
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::types::{IntWidth, IrType};
    use crate::ir::walk::predecessors;
    use crate::lexer::{Position, Span};

    /// Dummy zero-width span for synthesized instructions.
    fn span() -> Span {
        let pos = Position { line: 0, col: 0 };
        Span {
            file_id: 0,
            start: pos,
            end: pos,
        }
    }

    /// Build a function where the loop header has TWO out-of-loop
    /// predecessors (`entry` and `alt_entry`), each passing a different initial
    /// value for the header's block param, so no preheader exists.
    ///
    /// Edges: entry -> {header(c1), alt_entry}; alt_entry -> header(c5);
    /// header -> {body, exit}; body -> latch; latch -> header(iv + 1).
    fn build_two_entry_loop() -> Module {
        let mut m = Module::new("test".into());
        let mut f = Function::new("test".into(), vec![], IrType::Void);

        let header = f.create_block("header");
        let body = f.create_block("body");
        let latch = f.create_block("latch");
        let exit = f.create_block("exit");
        let alt = f.create_block("alt_entry");
        let entry = f.entry;

        // Entry: const 1, condBr → header(1) or alt_entry
        let c1 = f.next_value_id();
        f.register_type(c1, IrType::Int(IntWidth::I32));
        f.block_mut(entry).insts.push(Inst {
            id: c1,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::ConstInt(1, IntWidth::I32),
        });
        let cond = f.next_value_id();
        f.register_type(cond, IrType::Bool);
        f.block_mut(entry).insts.push(Inst {
            id: cond,
            ty: IrType::Bool,
            span: span(),
            kind: InstKind::ConstBool(true),
        });
        f.block_mut(entry).terminator = Some(Terminator::CondBranch {
            cond,
            true_dest: header,
            true_args: vec![c1],
            false_dest: alt,
            false_args: vec![],
        });

        // Alt entry: const 5, br header(5)
        let c5 = f.next_value_id();
        f.register_type(c5, IrType::Int(IntWidth::I32));
        f.block_mut(alt).insts.push(Inst {
            id: c5,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::ConstInt(5, IntWidth::I32),
        });
        f.block_mut(alt).terminator = Some(Terminator::Branch(header, vec![c5]));

        // Header: block param, icmp, condBr body/exit
        let iv = f.next_value_id();
        f.register_type(iv, IrType::Int(IntWidth::I32));
        f.block_mut(header).params.push(BlockParam {
            id: iv,
            ty: IrType::Int(IntWidth::I32),
        });
        let c10 = f.next_value_id();
        f.register_type(c10, IrType::Int(IntWidth::I32));
        f.block_mut(header).insts.push(Inst {
            id: c10,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::ConstInt(10, IntWidth::I32),
        });
        let cmp = f.next_value_id();
        f.register_type(cmp, IrType::Bool);
        f.block_mut(header).insts.push(Inst {
            id: cmp,
            ty: IrType::Bool,
            span: span(),
            kind: InstKind::ICmp(CmpOp::Le, iv, c10),
        });
        f.block_mut(header).terminator = Some(Terminator::CondBranch {
            cond: cmp,
            true_dest: body,
            true_args: vec![],
            false_dest: exit,
            false_args: vec![],
        });

        // Body → latch
        f.block_mut(body).terminator = Some(Terminator::Branch(latch, vec![]));

        // Latch: iadd + br header
        let one_l = f.next_value_id();
        f.register_type(one_l, IrType::Int(IntWidth::I32));
        f.block_mut(latch).insts.push(Inst {
            id: one_l,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::ConstInt(1, IntWidth::I32),
        });
        let nxt = f.next_value_id();
        f.register_type(nxt, IrType::Int(IntWidth::I32));
        f.block_mut(latch).insts.push(Inst {
            id: nxt,
            ty: IrType::Int(IntWidth::I32),
            span: span(),
            kind: InstKind::IAdd(iv, one_l),
        });
        f.block_mut(latch).terminator = Some(Terminator::Branch(header, vec![nxt]));

        // Exit
        f.block_mut(exit).terminator = Some(Terminator::Return(None));

        m.add_function(f);
        m
    }

    #[test]
    fn inserts_preheader_for_multi_entry_loop() {
        let mut m = build_two_entry_loop();
        let f = &m.functions[0];

        // Before: header has 2 out-of-loop preds (entry, alt_entry).
        let loops = find_natural_loops(f);
        let preds = predecessors(f);
        assert_eq!(loops.len(), 1);
        assert!(
            find_preheader(f, &loops[0], &preds).is_none(),
            "should have no preheader before insertion"
        );

        let pass = PreheaderInsert;
        let changed = pass.run(&mut m);
        assert!(changed, "should have inserted a preheader");

        // After: header should have exactly one out-of-loop pred.
        let f = &m.functions[0];
        let loops = find_natural_loops(f);
        let preds = predecessors(f);
        let header = loops[0].header;
        let ph_id = find_preheader(f, &loops[0], &preds).expect("preheader should now exist");

        // The preheader mirrors the header's single block param and forwards
        // it to the header.
        let ph_block = f.block(ph_id);
        assert_eq!(ph_block.params.len(), 1, "preheader should have 1 param");
        match &ph_block.terminator {
            Some(Terminator::Branch(dest, args)) => {
                assert!(*dest == header, "preheader must branch to the header");
                assert_eq!(args.len(), 1, "preheader must forward its param");
            }
            _ => panic!("preheader must end in an unconditional branch"),
        }

        // Both original entries now target the preheader and still pass their
        // initial value as a branch argument.
        let alt = match &f.block(f.entry).terminator {
            Some(Terminator::CondBranch {
                true_dest,
                true_args,
                false_dest,
                ..
            }) => {
                assert!(*true_dest == ph_id, "entry must now branch to the preheader");
                assert_eq!(true_args.len(), 1, "entry's branch argument must be preserved");
                *false_dest
            }
            _ => panic!("entry must still end in a conditional branch"),
        };
        match &f.block(alt).terminator {
            Some(Terminator::Branch(dest, args)) => {
                assert!(*dest == ph_id, "alt_entry must now branch to the preheader");
                assert_eq!(args.len(), 1, "alt_entry's branch argument must be preserved");
            }
            _ => panic!("alt_entry must still end in an unconditional branch"),
        }
    }

    #[test]
    fn idempotent_when_preheader_exists() {
        let mut m = build_two_entry_loop();
        let pass = PreheaderInsert;
        pass.run(&mut m);
        let block_count_after_first = m.functions[0].blocks.len();

        // Running again should not change anything.
        let changed = pass.run(&mut m);
        assert!(!changed, "second run should be idempotent");
        assert_eq!(m.functions[0].blocks.len(), block_count_after_first);
    }
}
| 314 |