| 1 | //! Constant propagation pass. |
| 2 | //! |
| 3 | //! In our SSA IR, propagating a constant value into the operand slot of |
| 4 | //! an instruction is a no-op rewrite — every consumer already finds the |
| 5 | //! defining `Const*` instruction by `ValueId`, and codegen happily turns |
| 6 | //! that into an immediate. The transformations that *do* matter are the |
| 7 | //! ones the constant-folding pass cannot reach because they live in |
| 8 | //! block terminators: |
| 9 | //! |
| 10 | //! * `CondBranch { cond, true_dest, false_dest }` with a constant |
| 11 | //! `cond` becomes an unconditional `Branch` to the live target. |
| 12 | //! * `Switch { selector, cases, default }` with a constant `selector` |
| 13 | //! becomes an unconditional `Branch` to the matching case. |
| 14 | //! |
| 15 | //! Both rewrites are critical: they expose dead blocks to later passes |
| 16 | //! (CSE, LICM, GVN) and shrink the CFG. |
| 17 | //! |
//! After folding terminators we **must** prune any blocks that are no
//! longer reachable from the entry. A block that is still reachable
//! through some other path is fine, but a true orphan is not: our
//! verifier rejects values used across blocks that no longer dominate
//! the use, and unreachable blocks can break dominance for downstream
//! merges. The cleanest fix is to remove the dead blocks as part of the
//! same transformation so the IR handed to the verifier is always
//! well-formed.
| 25 | //! |
//! Block parameters on the picked target do not prevent a fold. When a
//! `CondBranch` target has block parameters (which Fortran lowering only
//! emits for loop headers and merges), we still rewrite and simply
//! forward the taken side's original branch arguments: they already
//! matched the target's parameters in the original `CondBranch`.
| 31 | |
| 32 | use super::pass::Pass; |
| 33 | use super::util::prune_unreachable; |
| 34 | use crate::ir::inst::*; |
| 35 | use crate::ir::types::IntWidth; |
| 36 | use std::collections::HashMap; |
| 37 | |
| 38 | /// What we know about a value at compile time. |
| 39 | #[derive(Debug, Clone, Copy)] |
| 40 | enum Const { |
| 41 | Bool(bool), |
| 42 | Int(i64, IntWidth), |
| 43 | } |
| 44 | |
| 45 | impl Const { |
| 46 | /// Build a `Const` from a constant-producing instruction, |
| 47 | /// sign-extending integer values at their declared width so the |
| 48 | /// stored value is canonical. Audit M4-4: matches the N-9 |
| 49 | /// normalization in `const_fold::Const::from_inst`. The Switch |
| 50 | /// fold at line 119 already sexts at use site, but relying on |
| 51 | /// every use site to remember that is fragile — canonicalizing |
| 52 | /// here closes the consistency gap with `const_fold`. |
| 53 | fn from_inst(kind: &InstKind) -> Option<Self> { |
| 54 | match kind { |
| 55 | InstKind::ConstBool(b) => Some(Const::Bool(*b)), |
| 56 | InstKind::ConstInt(v, w) => { |
| 57 | let bits = w.bits(); |
| 58 | let signed = if bits >= 64 { |
| 59 | i64::try_from(*v).ok()? |
| 60 | } else { |
| 61 | let shift = 64 - bits; |
| 62 | let narrow = i64::try_from(*v).ok()?; |
| 63 | (narrow << shift) >> shift |
| 64 | }; |
| 65 | Some(Const::Int(signed, *w)) |
| 66 | } |
| 67 | _ => None, |
| 68 | } |
| 69 | } |
| 70 | } |
| 71 | |
| 72 | /// Build the const-value map for one function. Walks every block in |
| 73 | /// definition order; SSA dominance guarantees an instruction's operands |
| 74 | /// are already in the map when we reach it. |
| 75 | fn collect_consts(func: &Function) -> HashMap<ValueId, Const> { |
| 76 | let mut consts = HashMap::new(); |
| 77 | for block in &func.blocks { |
| 78 | for inst in &block.insts { |
| 79 | if let Some(c) = Const::from_inst(&inst.kind) { |
| 80 | consts.insert(inst.id, c); |
| 81 | } |
| 82 | } |
| 83 | } |
| 84 | consts |
| 85 | } |
| 86 | |
| 87 | /// The constant propagation pass. |
| 88 | pub struct ConstProp; |
| 89 | |
| 90 | impl Pass for ConstProp { |
| 91 | fn name(&self) -> &'static str { |
| 92 | "const-prop" |
| 93 | } |
| 94 | |
| 95 | fn run(&self, module: &mut Module) -> bool { |
| 96 | let mut changed = false; |
| 97 | for func in &mut module.functions { |
| 98 | let consts = collect_consts(func); |
| 99 | let mut local_changed = false; |
| 100 | for block in &mut func.blocks { |
| 101 | let Some(term) = block.terminator.take() else { |
| 102 | continue; |
| 103 | }; |
| 104 | let new_term = simplify_terminator(term, &consts, &mut local_changed); |
| 105 | block.terminator = Some(new_term); |
| 106 | } |
| 107 | if local_changed { |
| 108 | // Folded a CondBranch/Switch — drop any block that is no |
| 109 | // longer reachable from the entry so the verifier doesn't |
| 110 | // see stale uses across a now-impossible path. |
| 111 | let _pruned = prune_unreachable(func); |
| 112 | changed = true; |
| 113 | } |
| 114 | } |
| 115 | changed |
| 116 | } |
| 117 | } |
| 118 | |
| 119 | fn simplify_terminator( |
| 120 | term: Terminator, |
| 121 | consts: &HashMap<ValueId, Const>, |
| 122 | changed: &mut bool, |
| 123 | ) -> Terminator { |
| 124 | match term { |
| 125 | Terminator::CondBranch { |
| 126 | cond, |
| 127 | true_dest, |
| 128 | true_args, |
| 129 | false_dest, |
| 130 | false_args, |
| 131 | } => { |
| 132 | if let Some(Const::Bool(c)) = consts.get(&cond).copied() { |
| 133 | *changed = true; |
| 134 | if c { |
| 135 | Terminator::Branch(true_dest, true_args) |
| 136 | } else { |
| 137 | Terminator::Branch(false_dest, false_args) |
| 138 | } |
| 139 | } else { |
| 140 | Terminator::CondBranch { |
| 141 | cond, |
| 142 | true_dest, |
| 143 | true_args, |
| 144 | false_dest, |
| 145 | false_args, |
| 146 | } |
| 147 | } |
| 148 | } |
| 149 | Terminator::Switch { |
| 150 | selector, |
| 151 | cases, |
| 152 | default, |
| 153 | } => { |
| 154 | if let Some(Const::Int(sv, w)) = consts.get(&selector).copied() { |
| 155 | // `sv` is already canonical (sign-extended at width) |
| 156 | // thanks to the M4-4 normalization in `Const::from_inst`. |
| 157 | // Case keys come from the IR and aren't guaranteed |
| 158 | // canonical, so we still normalize them before |
| 159 | // comparing. |
| 160 | *changed = true; |
| 161 | let target = cases |
| 162 | .iter() |
| 163 | .find(|(k, _)| sext(*k, w.bits()) == sv) |
| 164 | .map(|(_, b)| *b) |
| 165 | .unwrap_or(default); |
| 166 | // Switch targets in our IR cannot have block params, so |
| 167 | // an empty arg vec is correct. |
| 168 | Terminator::Branch(target, vec![]) |
| 169 | } else { |
| 170 | Terminator::Switch { |
| 171 | selector, |
| 172 | cases, |
| 173 | default, |
| 174 | } |
| 175 | } |
| 176 | } |
| 177 | other => other, |
| 178 | } |
| 179 | } |
| 180 | |
/// Sign-extend the low `bits` bits of `v` into a full `i64`.
///
/// For `bits >= 64` the value is returned unchanged. Otherwise the low bits
/// are shifted up to the top of the word and arithmetically shifted back,
/// replicating bit `bits - 1` into every higher position.
fn sext(v: i64, bits: u32) -> i64 {
    if bits < 64 {
        let unused_high = 64 - bits;
        (v << unused_high) >> unused_high
    } else {
        v
    }
}
| 189 | |
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::types::IrType;
    use crate::lexer::{Position, Span};

    /// A zero-width span at 1:1; these tests do not inspect locations.
    fn dummy_span() -> Span {
        let p = Position { line: 1, col: 1 };
        Span {
            start: p,
            end: p,
            file_id: 0,
        }
    }

    #[test]
    fn condbranch_with_true_const_folds_to_branch_true() {
        let mut m = Module::new("t".into());
        let mut f = Function::new("f".into(), vec![], IrType::Void);

        // Build: entry { cond=const(true); cond_branch cond, then, else }
        //        then  { ret }
        //        else  { ret }
        let bb_t = f.create_block("then");
        let bb_e = f.create_block("else");
        let cond_id = f.next_value_id();
        f.block_mut(f.entry).insts.push(Inst {
            id: cond_id,
            kind: InstKind::ConstBool(true),
            ty: IrType::Bool,
            span: dummy_span(),
        });
        f.block_mut(f.entry).terminator = Some(Terminator::CondBranch {
            cond: cond_id,
            true_dest: bb_t,
            true_args: vec![],
            false_dest: bb_e,
            false_args: vec![],
        });
        f.block_mut(bb_t).terminator = Some(Terminator::Return(None));
        f.block_mut(bb_e).terminator = Some(Terminator::Return(None));
        m.add_function(f);

        assert!(ConstProp.run(&mut m));
        match &m.functions[0].blocks[0].terminator {
            Some(Terminator::Branch(dest, _)) => assert_eq!(*dest, bb_t),
            other => panic!("expected Branch, got {:?}", other),
        }
    }

    #[test]
    fn condbranch_with_false_const_folds_to_branch_false() {
        let mut m = Module::new("t".into());
        let mut f = Function::new("f".into(), vec![], IrType::Void);
        let bb_t = f.create_block("then");
        let bb_e = f.create_block("else");
        let cond_id = f.next_value_id();
        f.block_mut(f.entry).insts.push(Inst {
            id: cond_id,
            kind: InstKind::ConstBool(false),
            ty: IrType::Bool,
            span: dummy_span(),
        });
        f.block_mut(f.entry).terminator = Some(Terminator::CondBranch {
            cond: cond_id,
            true_dest: bb_t,
            true_args: vec![],
            false_dest: bb_e,
            false_args: vec![],
        });
        f.block_mut(bb_t).terminator = Some(Terminator::Return(None));
        f.block_mut(bb_e).terminator = Some(Terminator::Return(None));
        m.add_function(f);

        assert!(ConstProp.run(&mut m));
        match &m.functions[0].blocks[0].terminator {
            Some(Terminator::Branch(dest, _)) => assert_eq!(*dest, bb_e),
            other => panic!("expected Branch, got {:?}", other),
        }
    }

    #[test]
    fn switch_with_const_selector_picks_case() {
        let mut m = Module::new("t".into());
        let mut f = Function::new("f".into(), vec![], IrType::Void);
        let case_a = f.create_block("a");
        let case_b = f.create_block("b");
        let default = f.create_block("default");
        let sel_id = f.next_value_id();
        f.block_mut(f.entry).insts.push(Inst {
            id: sel_id,
            kind: InstKind::ConstInt(2, IntWidth::I32),
            ty: IrType::Int(IntWidth::I32),
            span: dummy_span(),
        });
        f.block_mut(f.entry).terminator = Some(Terminator::Switch {
            selector: sel_id,
            cases: vec![(1, case_a), (2, case_b)],
            default,
        });
        f.block_mut(case_a).terminator = Some(Terminator::Return(None));
        f.block_mut(case_b).terminator = Some(Terminator::Return(None));
        f.block_mut(default).terminator = Some(Terminator::Return(None));
        m.add_function(f);

        assert!(ConstProp.run(&mut m));
        match &m.functions[0].blocks[0].terminator {
            Some(Terminator::Branch(dest, _)) => assert_eq!(*dest, case_b),
            other => panic!("expected Branch, got {:?}", other),
        }
    }

    #[test]
    fn switch_with_unmatched_const_selector_picks_default() {
        let mut m = Module::new("t".into());
        let mut f = Function::new("f".into(), vec![], IrType::Void);
        let case_a = f.create_block("a");
        let default = f.create_block("default");
        let sel_id = f.next_value_id();
        f.block_mut(f.entry).insts.push(Inst {
            id: sel_id,
            kind: InstKind::ConstInt(99, IntWidth::I32),
            ty: IrType::Int(IntWidth::I32),
            span: dummy_span(),
        });
        f.block_mut(f.entry).terminator = Some(Terminator::Switch {
            selector: sel_id,
            cases: vec![(1, case_a)],
            default,
        });
        f.block_mut(case_a).terminator = Some(Terminator::Return(None));
        f.block_mut(default).terminator = Some(Terminator::Return(None));
        m.add_function(f);

        assert!(ConstProp.run(&mut m));
        match &m.functions[0].blocks[0].terminator {
            Some(Terminator::Branch(dest, _)) => assert_eq!(*dest, default),
            other => panic!("expected Branch, got {:?}", other),
        }
    }

    #[test]
    fn switch_selector_is_canonicalized_before_matching() {
        // The selector is the 32-bit all-ones pattern written unsigned. At
        // I32 it means -1, so it must hit the `-1` case rather than fall
        // through to `default` (pins the M4-4 normalization in
        // `Const::from_inst`).
        let mut m = Module::new("t".into());
        let mut f = Function::new("f".into(), vec![], IrType::Void);
        let case_neg = f.create_block("neg");
        let default = f.create_block("default");
        let sel_id = f.next_value_id();
        f.block_mut(f.entry).insts.push(Inst {
            id: sel_id,
            kind: InstKind::ConstInt(0xFFFF_FFFF, IntWidth::I32),
            ty: IrType::Int(IntWidth::I32),
            span: dummy_span(),
        });
        f.block_mut(f.entry).terminator = Some(Terminator::Switch {
            selector: sel_id,
            cases: vec![(-1, case_neg)],
            default,
        });
        f.block_mut(case_neg).terminator = Some(Terminator::Return(None));
        f.block_mut(default).terminator = Some(Terminator::Return(None));
        m.add_function(f);

        assert!(ConstProp.run(&mut m));
        match &m.functions[0].blocks[0].terminator {
            Some(Terminator::Branch(dest, _)) => assert_eq!(*dest, case_neg),
            other => panic!("expected Branch, got {:?}", other),
        }
    }

    #[test]
    fn folded_condbranch_prunes_dead_block() {
        // entry { cond=const(true); cond_branch -> then or dead }
        // then  { ret }
        // dead  { ret }   // should be removed
        let mut m = Module::new("t".into());
        let mut f = Function::new("f".into(), vec![], IrType::Void);
        let bb_t = f.create_block("then");
        let bb_d = f.create_block("dead");
        let cond_id = f.next_value_id();
        f.block_mut(f.entry).insts.push(Inst {
            id: cond_id,
            kind: InstKind::ConstBool(true),
            ty: IrType::Bool,
            span: dummy_span(),
        });
        f.block_mut(f.entry).terminator = Some(Terminator::CondBranch {
            cond: cond_id,
            true_dest: bb_t,
            true_args: vec![],
            false_dest: bb_d,
            false_args: vec![],
        });
        f.block_mut(bb_t).terminator = Some(Terminator::Return(None));
        f.block_mut(bb_d).terminator = Some(Terminator::Return(None));
        m.add_function(f);

        assert!(ConstProp.run(&mut m));
        let f = &m.functions[0];
        assert_eq!(f.blocks.len(), 2, "dead block should be pruned");
        assert!(f.blocks.iter().any(|b| b.id == bb_t));
        assert!(!f.blocks.iter().any(|b| b.id == bb_d));
    }

    #[test]
    fn folded_switch_prunes_unselected_targets() {
        // entry { sel=const(2); switch sel [1 -> a, 2 -> b], default }
        // Only `b` remains reachable after the fold.
        let mut m = Module::new("t".into());
        let mut f = Function::new("f".into(), vec![], IrType::Void);
        let case_a = f.create_block("a");
        let case_b = f.create_block("b");
        let default = f.create_block("default");
        let sel_id = f.next_value_id();
        f.block_mut(f.entry).insts.push(Inst {
            id: sel_id,
            kind: InstKind::ConstInt(2, IntWidth::I32),
            ty: IrType::Int(IntWidth::I32),
            span: dummy_span(),
        });
        f.block_mut(f.entry).terminator = Some(Terminator::Switch {
            selector: sel_id,
            cases: vec![(1, case_a), (2, case_b)],
            default,
        });
        f.block_mut(case_a).terminator = Some(Terminator::Return(None));
        f.block_mut(case_b).terminator = Some(Terminator::Return(None));
        f.block_mut(default).terminator = Some(Terminator::Return(None));
        m.add_function(f);

        assert!(ConstProp.run(&mut m));
        let f = &m.functions[0];
        assert_eq!(f.blocks.len(), 2, "unselected switch targets should be pruned");
        assert!(f.blocks.iter().any(|b| b.id == case_b));
        assert!(!f.blocks.iter().any(|b| b.id == case_a));
        assert!(!f.blocks.iter().any(|b| b.id == default));
    }

    #[test]
    fn condbranch_with_unknown_cond_left_alone() {
        // Operand is not a constant — must not transform.
        let mut m = Module::new("t".into());
        let params = vec![Param {
            name: "p".into(),
            ty: IrType::Bool,
            id: ValueId(0),
            fortran_noalias: false,
        }];
        let mut f = Function::new("f".into(), params, IrType::Void);
        let bb_t = f.create_block("then");
        let bb_e = f.create_block("else");
        f.block_mut(f.entry).terminator = Some(Terminator::CondBranch {
            cond: ValueId(0),
            true_dest: bb_t,
            true_args: vec![],
            false_dest: bb_e,
            false_args: vec![],
        });
        f.block_mut(bb_t).terminator = Some(Terminator::Return(None));
        f.block_mut(bb_e).terminator = Some(Terminator::Return(None));
        m.add_function(f);

        assert!(!ConstProp.run(&mut m));
        assert!(matches!(
            m.functions[0].blocks[0].terminator,
            Some(Terminator::CondBranch { .. })
        ));
    }

    #[test]
    fn switch_with_unknown_selector_left_alone() {
        // Selector is a function parameter, not a constant — must not
        // transform, and the pass must report no change.
        let mut m = Module::new("t".into());
        let params = vec![Param {
            name: "p".into(),
            ty: IrType::Int(IntWidth::I32),
            id: ValueId(0),
            fortran_noalias: false,
        }];
        let mut f = Function::new("f".into(), params, IrType::Void);
        let case_a = f.create_block("a");
        let default = f.create_block("default");
        f.block_mut(f.entry).terminator = Some(Terminator::Switch {
            selector: ValueId(0),
            cases: vec![(1, case_a)],
            default,
        });
        f.block_mut(case_a).terminator = Some(Terminator::Return(None));
        f.block_mut(default).terminator = Some(Terminator::Return(None));
        m.add_function(f);

        assert!(!ConstProp.run(&mut m));
        assert!(matches!(
            m.functions[0].blocks[0].terminator,
            Some(Terminator::Switch { .. })
        ));
    }
}
| 396 |