| 1 | //! Sparse Conditional Constant Propagation (SCCP) pass. |
| 2 | //! |
| 3 | //! Combines constant tracking with reachability analysis. The key |
| 4 | //! insight over plain `const_fold` + `const_prop` is that constants |
| 5 | //! flowing into a block parameter via a CFG edge whose source is |
| 6 | //! statically unreachable do **not** force the parameter to Bottom. |
| 7 | //! That lets SCCP fold cases the basic passes can't reach, e.g. |
| 8 | //! ```text |
| 9 | //! entry: br .true., a, b |
| 10 | //! a: br merge(c1) |
| 11 | //! b: br merge(c2) ; b unreachable from entry |
| 12 | //! merge: param p; ... use p ... |
| 13 | //! ``` |
| 14 | //! After SCCP, the merge param `p` is constant `c1` because only the |
| 15 | //! `a → merge` edge is reachable. |
| 16 | //! |
| 17 | //! Algorithm (Wegman-Zadeck, 1991): |
| 18 | //! 1. Lattice: `Top` (not yet seen), `Const(c)` (proven constant), |
| 19 | //! `Bottom` (overdefined / non-constant). |
| 20 | //! 2. Reachability: a `HashSet<BlockId>` of reachable blocks plus a |
| 21 | //! `HashSet<(BlockId, BlockId)>` of reachable CFG edges. |
| 22 | //! 3. Two worklists: |
| 23 | //! * SSA-edge worklist — re-evaluate uses when a value's lattice |
| 24 | //! moves down (Top → Const, Top → Bottom, Const → Bottom). |
| 25 | //! * CFG-edge worklist — when a new edge becomes reachable, |
| 26 | //! re-meet the destination block's params. |
| 27 | //! 4. Iterate until both worklists are empty. |
| 28 | //! |
| 29 | //! After fixpoint we materialize the result: |
| 30 | //! * Each constant block-param is rewritten: a `Const*` instruction |
| 31 | //! is inserted at the start of its block; uses are redirected; |
| 32 | //! the param entry and matching predecessor args are dropped. |
| 33 | //! * Each constant-condition branch is folded to an unconditional |
| 34 | //! `Branch` to the live target. |
| 35 | //! * `prune_unreachable` reaps the now-dead blocks. |
| 36 | //! |
| 37 | //! Per-instruction arithmetic transfer (e.g. `IAdd Const Const → |
| 38 | //! Const`) is intentionally left to `const_fold`. The pass-manager |
| 39 | //! fixpoint composes them: SCCP exposes new reachability → |
| 40 | //! `const_fold` propagates → SCCP runs again. This keeps SCCP narrow |
| 41 | //! and avoids duplicating the dozens of width-aware folds in |
| 42 | //! `const_fold::try_fold`. |
| 43 | |
| 44 | use super::pass::Pass; |
| 45 | use super::util::prune_unreachable; |
| 46 | use crate::ir::inst::*; |
| 47 | use crate::ir::types::{FloatWidth, IntWidth, IrType}; |
| 48 | use std::collections::{HashMap, HashSet, VecDeque}; |
| 49 | |
/// A compile-time constant tracked by the SCCP lattice.
///
/// Equality is derived, so two `ConstVal`s are equal only when both the
/// payload and the width agree.
#[derive(Debug, Clone, Copy, PartialEq)]
enum ConstVal {
    /// Integer constant with its width. Values built through
    /// `ConstVal::from_inst` are sign-extended from that width.
    Int(i128, IntWidth),
    /// Floating-point constant stored as its raw bit pattern.
    Float(u64, FloatWidth), // bit pattern — avoids NaN equality issues
    /// Boolean constant.
    Bool(bool),
}
| 56 | |
/// The three-level SCCP lattice for a single SSA value.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Lattice {
    /// Not yet seen: no information has reached this value.
    Top,
    /// Proven to be this constant.
    Const(ConstVal),
    /// Overdefined: not a single known constant.
    Bottom,
}
| 63 | |
| 64 | impl ConstVal { |
| 65 | fn from_inst(kind: &InstKind) -> Option<Self> { |
| 66 | match kind { |
| 67 | InstKind::ConstInt(v, w) => Some(ConstVal::Int(sext(*v, w.bits()), *w)), |
| 68 | InstKind::ConstFloat(v, w) => Some(ConstVal::Float(v.to_bits(), *w)), |
| 69 | InstKind::ConstBool(b) => Some(ConstVal::Bool(*b)), |
| 70 | _ => None, |
| 71 | } |
| 72 | } |
| 73 | |
| 74 | fn to_inst_kind(self) -> InstKind { |
| 75 | match self { |
| 76 | ConstVal::Int(v, w) => InstKind::ConstInt(v, w), |
| 77 | ConstVal::Float(bits, w) => InstKind::ConstFloat(f64::from_bits(bits), w), |
| 78 | ConstVal::Bool(b) => InstKind::ConstBool(b), |
| 79 | } |
| 80 | } |
| 81 | |
| 82 | fn ir_type(self) -> IrType { |
| 83 | match self { |
| 84 | ConstVal::Int(_, w) => IrType::Int(w), |
| 85 | ConstVal::Float(_, w) => IrType::Float(w), |
| 86 | ConstVal::Bool(_) => IrType::Bool, |
| 87 | } |
| 88 | } |
| 89 | } |
| 90 | |
/// Sign-extend the low `bits` bits of `v` to a full `i128`.
///
/// * `bits == 0` — a zero-width integer can only hold 0, so 0 is returned.
///   (Without this case the shift amount would be 128, which panics in
///   debug builds and is masked to a no-op in release builds.)
/// * `1..=127` — bit `bits - 1` is treated as the sign bit and replicated
///   into all higher bits.
/// * `bits >= 128` — `v` is already full width and is returned unchanged.
fn sext(v: i128, bits: u32) -> i128 {
    match bits {
        0 => 0,
        1..=127 => {
            // Shift the sign bit up to bit 127, then let the arithmetic
            // right shift smear it back down.
            let shift = 128 - bits;
            (v << shift) >> shift
        }
        _ => v,
    }
}
| 99 | |
| 100 | fn meet(a: Lattice, b: Lattice) -> Lattice { |
| 101 | match (a, b) { |
| 102 | (Lattice::Top, x) | (x, Lattice::Top) => x, |
| 103 | (Lattice::Bottom, _) | (_, Lattice::Bottom) => Lattice::Bottom, |
| 104 | (Lattice::Const(x), Lattice::Const(y)) => { |
| 105 | if x == y { |
| 106 | Lattice::Const(x) |
| 107 | } else { |
| 108 | Lattice::Bottom |
| 109 | } |
| 110 | } |
| 111 | } |
| 112 | } |
| 113 | |
/// Solver state for one run of SCCP over a single function.
///
/// The function is only borrowed immutably; the solver records lattice
/// values and reachability but performs no rewriting itself.
struct Sccp<'a> {
    /// The function under analysis.
    func: &'a Function,
    /// Reverse map: ValueId → its defining block id (for SSA-edge
    /// worklist processing of block params, since we need to know
    /// which block's incoming edges to re-meet).
    block_param_owner: HashMap<ValueId, BlockId>,
    /// Lattice value per ValueId. Absent = Top.
    lattice: HashMap<ValueId, Lattice>,
    /// Reachable blocks.
    reachable_blocks: HashSet<BlockId>,
    /// Reachable CFG edges (pred, succ).
    reachable_edges: HashSet<(BlockId, BlockId)>,
    /// CFG worklist — newly reachable edges to process.
    cfg_worklist: VecDeque<(BlockId, BlockId)>,
    /// SSA worklist — values whose lattice changed; re-evaluate users
    /// (block params and terminators that consume them).
    ssa_worklist: VecDeque<ValueId>,
    /// Predecessor list (built once). A predecessor appears once per
    /// terminator arm that targets the block.
    preds: HashMap<BlockId, Vec<BlockId>>,
    /// Arguments passed along each (pred → succ) edge — used to look
    /// up what value flows into each block param along an edge.
    /// Switch edges get an empty list; Return/Unreachable terminators
    /// create no entries.
    ///
    /// NOTE: the key is (pred, succ), so a CondBranch whose two arms
    /// target the same block stores only one argument list (the
    /// false-arm list, which is inserted last).
    edge_args: HashMap<(BlockId, BlockId), Vec<ValueId>>,
    /// For each block param ValueId: the block that owns it and its
    /// index within that block's param list. The meet over reachable
    /// incoming edges is rebuilt whenever any of the inputs changes.
    param_inputs: HashMap<ValueId, ParamInputInfo>,
    /// For each ValueId, the set of dependents we should requeue on
    /// SSA-worklist when it changes. For now: only block params and
    /// terminator-conditions are tracked (we don't model arithmetic).
    value_users: HashMap<ValueId, Vec<ValueUser>>,
}
| 147 | |
/// Where a block param lives: enough to find the argument that feeds it
/// on any incoming edge.
#[derive(Debug, Clone)]
struct ParamInputInfo {
    /// Block that declares the param.
    block: BlockId,
    /// Position of the param in that block's param list; also the index
    /// of the matching argument in each predecessor's branch args.
    param_idx: usize,
}
| 153 | |
/// A dependent that must be re-evaluated when a value's lattice entry
/// changes.
#[derive(Debug, Clone)]
enum ValueUser {
    /// The value is the `cond` of a CondBranch (or the selector of a
    /// Switch) in the given block; reanalyze the terminator, which may
    /// mark further successor edges reachable.
    Terminator(BlockId),
    /// A block param that consumes this value along some edge —
    /// re-meet the param when the value's lattice changes.
    BlockParam(ValueId),
}
| 163 | |
| 164 | impl<'a> Sccp<'a> { |
| 165 | fn new(func: &'a Function) -> Self { |
| 166 | let mut block_param_owner = HashMap::new(); |
| 167 | let mut param_inputs = HashMap::new(); |
| 168 | for block in &func.blocks { |
| 169 | for (i, p) in block.params.iter().enumerate() { |
| 170 | block_param_owner.insert(p.id, block.id); |
| 171 | param_inputs.insert( |
| 172 | p.id, |
| 173 | ParamInputInfo { |
| 174 | block: block.id, |
| 175 | param_idx: i, |
| 176 | }, |
| 177 | ); |
| 178 | } |
| 179 | } |
| 180 | |
| 181 | let mut preds: HashMap<BlockId, Vec<BlockId>> = HashMap::new(); |
| 182 | let mut edge_args: HashMap<(BlockId, BlockId), Vec<ValueId>> = HashMap::new(); |
| 183 | for block in &func.blocks { |
| 184 | preds.entry(block.id).or_default(); |
| 185 | } |
| 186 | for block in &func.blocks { |
| 187 | let Some(term) = &block.terminator else { |
| 188 | continue; |
| 189 | }; |
| 190 | match term { |
| 191 | Terminator::Branch(succ, args) => { |
| 192 | preds.entry(*succ).or_default().push(block.id); |
| 193 | edge_args.insert((block.id, *succ), args.clone()); |
| 194 | } |
| 195 | Terminator::CondBranch { |
| 196 | true_dest, |
| 197 | true_args, |
| 198 | false_dest, |
| 199 | false_args, |
| 200 | .. |
| 201 | } => { |
| 202 | preds.entry(*true_dest).or_default().push(block.id); |
| 203 | preds.entry(*false_dest).or_default().push(block.id); |
| 204 | edge_args.insert((block.id, *true_dest), true_args.clone()); |
| 205 | edge_args.insert((block.id, *false_dest), false_args.clone()); |
| 206 | } |
| 207 | Terminator::Switch { cases, default, .. } => { |
| 208 | preds.entry(*default).or_default().push(block.id); |
| 209 | edge_args.insert((block.id, *default), vec![]); |
| 210 | for (_, tgt) in cases { |
| 211 | preds.entry(*tgt).or_default().push(block.id); |
| 212 | edge_args.insert((block.id, *tgt), vec![]); |
| 213 | } |
| 214 | } |
| 215 | Terminator::Return(_) | Terminator::Unreachable => {} |
| 216 | } |
| 217 | } |
| 218 | |
| 219 | // Build dependency graph: for each value, the users that need |
| 220 | // to be reanalyzed when it changes. |
| 221 | let mut value_users: HashMap<ValueId, Vec<ValueUser>> = HashMap::new(); |
| 222 | for block in &func.blocks { |
| 223 | // Block-param dependents: each param's incoming arg |
| 224 | // along each pred edge. |
| 225 | for (pi, p) in block.params.iter().enumerate() { |
| 226 | for pred in preds.get(&block.id).cloned().unwrap_or_default() { |
| 227 | if let Some(args) = edge_args.get(&(pred, block.id)) { |
| 228 | if let Some(arg) = args.get(pi) { |
| 229 | value_users |
| 230 | .entry(*arg) |
| 231 | .or_default() |
| 232 | .push(ValueUser::BlockParam(p.id)); |
| 233 | } |
| 234 | } |
| 235 | } |
| 236 | } |
| 237 | // Terminator-condition dependents (only CondBranch/Switch |
| 238 | // can fold on a constant condition). |
| 239 | if let Some(term) = &block.terminator { |
| 240 | match term { |
| 241 | Terminator::CondBranch { cond, .. } => { |
| 242 | value_users |
| 243 | .entry(*cond) |
| 244 | .or_default() |
| 245 | .push(ValueUser::Terminator(block.id)); |
| 246 | } |
| 247 | Terminator::Switch { selector, .. } => { |
| 248 | value_users |
| 249 | .entry(*selector) |
| 250 | .or_default() |
| 251 | .push(ValueUser::Terminator(block.id)); |
| 252 | } |
| 253 | _ => {} |
| 254 | } |
| 255 | } |
| 256 | } |
| 257 | |
| 258 | Self { |
| 259 | func, |
| 260 | block_param_owner, |
| 261 | lattice: HashMap::new(), |
| 262 | reachable_blocks: HashSet::new(), |
| 263 | reachable_edges: HashSet::new(), |
| 264 | cfg_worklist: VecDeque::new(), |
| 265 | ssa_worklist: VecDeque::new(), |
| 266 | preds, |
| 267 | edge_args, |
| 268 | param_inputs, |
| 269 | value_users, |
| 270 | } |
| 271 | } |
| 272 | |
| 273 | fn lat(&self, v: ValueId) -> Lattice { |
| 274 | self.lattice.get(&v).copied().unwrap_or(Lattice::Top) |
| 275 | } |
| 276 | |
| 277 | /// Set lattice value, queue SSA-worklist if changed and not |
| 278 | /// already at the bottom. |
| 279 | fn set_lat(&mut self, v: ValueId, new: Lattice) { |
| 280 | let old = self.lat(v); |
| 281 | let merged = meet(old, new); |
| 282 | if merged != old { |
| 283 | self.lattice.insert(v, merged); |
| 284 | self.ssa_worklist.push_back(v); |
| 285 | } |
| 286 | } |
| 287 | |
| 288 | /// Make a block reachable (idempotent). Seeds CFG worklist with |
| 289 | /// outgoing edges and marks the block. |
| 290 | fn mark_block_reachable(&mut self, b: BlockId) { |
| 291 | if !self.reachable_blocks.insert(b) { |
| 292 | return; |
| 293 | } |
| 294 | // First time reaching this block — also seed its outgoing |
| 295 | // edges based on its (current) terminator analysis. |
| 296 | self.process_terminator(b); |
| 297 | } |
| 298 | |
| 299 | /// Mark a CFG edge reachable; if newly so, mark the destination |
| 300 | /// block reachable and queue the edge for block-param re-meet. |
| 301 | fn mark_edge(&mut self, from: BlockId, to: BlockId) { |
| 302 | if self.reachable_edges.insert((from, to)) { |
| 303 | self.cfg_worklist.push_back((from, to)); |
| 304 | self.mark_block_reachable(to); |
| 305 | } |
| 306 | } |
| 307 | |
| 308 | /// Analyze a block's terminator and mark live successors. |
| 309 | /// Called when the block becomes reachable or when its |
| 310 | /// terminator-condition lattice changes. |
| 311 | fn process_terminator(&mut self, b: BlockId) { |
| 312 | let Some(block) = self.func.blocks.iter().find(|bb| bb.id == b) else { |
| 313 | return; |
| 314 | }; |
| 315 | let Some(term) = &block.terminator else { |
| 316 | return; |
| 317 | }; |
| 318 | match term { |
| 319 | Terminator::Return(_) | Terminator::Unreachable => {} |
| 320 | Terminator::Branch(dest, _) => { |
| 321 | let dest = *dest; |
| 322 | self.mark_edge(b, dest); |
| 323 | } |
| 324 | Terminator::CondBranch { |
| 325 | cond, |
| 326 | true_dest, |
| 327 | false_dest, |
| 328 | .. |
| 329 | } => match self.lat(*cond) { |
| 330 | Lattice::Const(ConstVal::Bool(true)) => self.mark_edge(b, *true_dest), |
| 331 | Lattice::Const(ConstVal::Bool(false)) => self.mark_edge(b, *false_dest), |
| 332 | Lattice::Bottom => { |
| 333 | let t = *true_dest; |
| 334 | let f = *false_dest; |
| 335 | self.mark_edge(b, t); |
| 336 | self.mark_edge(b, f); |
| 337 | } |
| 338 | // Const(non-bool) is a type error in well-typed IR; |
| 339 | // be conservative and mark both. Top means we'll |
| 340 | // revisit when `cond` resolves. |
| 341 | Lattice::Const(_) => { |
| 342 | let t = *true_dest; |
| 343 | let f = *false_dest; |
| 344 | self.mark_edge(b, t); |
| 345 | self.mark_edge(b, f); |
| 346 | } |
| 347 | Lattice::Top => {} |
| 348 | }, |
| 349 | Terminator::Switch { |
| 350 | selector, |
| 351 | cases, |
| 352 | default, |
| 353 | } => match self.lat(*selector) { |
| 354 | Lattice::Const(ConstVal::Int(sv, w)) => { |
| 355 | let bits = w.bits(); |
| 356 | let target = cases |
| 357 | .iter() |
| 358 | .find(|(k, _)| sext(*k as i128, bits) == sv) |
| 359 | .map(|(_, blk)| *blk) |
| 360 | .unwrap_or(*default); |
| 361 | self.mark_edge(b, target); |
| 362 | } |
| 363 | Lattice::Bottom => { |
| 364 | let cases_clone: Vec<BlockId> = cases.iter().map(|(_, t)| *t).collect(); |
| 365 | let default = *default; |
| 366 | for t in cases_clone { |
| 367 | self.mark_edge(b, t); |
| 368 | } |
| 369 | self.mark_edge(b, default); |
| 370 | } |
| 371 | Lattice::Const(_) => { |
| 372 | let cases_clone: Vec<BlockId> = cases.iter().map(|(_, t)| *t).collect(); |
| 373 | let default = *default; |
| 374 | for t in cases_clone { |
| 375 | self.mark_edge(b, t); |
| 376 | } |
| 377 | self.mark_edge(b, default); |
| 378 | } |
| 379 | Lattice::Top => {} |
| 380 | }, |
| 381 | } |
| 382 | } |
| 383 | |
| 384 | /// Re-meet a block param's lattice value over all reachable |
| 385 | /// incoming edges. |
| 386 | fn recompute_param(&mut self, param_id: ValueId) { |
| 387 | let Some(info) = self.param_inputs.get(¶m_id).cloned() else { |
| 388 | return; |
| 389 | }; |
| 390 | let block = info.block; |
| 391 | let pi = info.param_idx; |
| 392 | let preds = self.preds.get(&block).cloned().unwrap_or_default(); |
| 393 | let mut new = Lattice::Top; |
| 394 | for pred in preds { |
| 395 | if !self.reachable_edges.contains(&(pred, block)) { |
| 396 | continue; |
| 397 | } |
| 398 | let Some(args) = self.edge_args.get(&(pred, block)) else { |
| 399 | continue; |
| 400 | }; |
| 401 | let Some(arg) = args.get(pi) else { |
| 402 | continue; |
| 403 | }; |
| 404 | // If the arg IS the param itself (back-edge with same |
| 405 | // value), don't pollute with our own current state. |
| 406 | // SSA permits this, e.g. loop induction merges. |
| 407 | if *arg == param_id { |
| 408 | continue; |
| 409 | } |
| 410 | new = meet(new, self.lat(*arg)); |
| 411 | if new == Lattice::Bottom { |
| 412 | break; |
| 413 | } |
| 414 | } |
| 415 | // Convert "Top with at least one reachable edge" → keep Top |
| 416 | // until a concrete arg shows up (Top is sound: still optimistic). |
| 417 | // monotone update via set_lat. |
| 418 | let old = self.lat(param_id); |
| 419 | let merged = meet(old, new); |
| 420 | if merged != old { |
| 421 | self.lattice.insert(param_id, merged); |
| 422 | self.ssa_worklist.push_back(param_id); |
| 423 | } |
| 424 | } |
| 425 | |
| 426 | fn run(&mut self) { |
| 427 | // Seed lattice for every constant-defining instruction. |
| 428 | for block in &self.func.blocks { |
| 429 | for inst in &block.insts { |
| 430 | if let Some(c) = ConstVal::from_inst(&inst.kind) { |
| 431 | self.lattice.insert(inst.id, Lattice::Const(c)); |
| 432 | } else { |
| 433 | // Non-const, non-block-param values default to |
| 434 | // Bottom — SCCP doesn't model arithmetic |
| 435 | // transfer. const_fold handles those after we |
| 436 | // expose new reachability. |
| 437 | self.lattice.insert(inst.id, Lattice::Bottom); |
| 438 | } |
| 439 | } |
| 440 | } |
| 441 | // Function params are unknown → Bottom. |
| 442 | for p in &self.func.params { |
| 443 | self.lattice.insert(p.id, Lattice::Bottom); |
| 444 | } |
| 445 | |
| 446 | // Entry is always reachable. |
| 447 | self.mark_block_reachable(self.func.entry); |
| 448 | |
| 449 | while !self.cfg_worklist.is_empty() || !self.ssa_worklist.is_empty() { |
| 450 | while let Some((from, to)) = self.cfg_worklist.pop_front() { |
| 451 | let _ = from; |
| 452 | // Re-meet every block param of `to`, since a new |
| 453 | // reachable edge may add a new input. |
| 454 | let params: Vec<ValueId> = self |
| 455 | .func |
| 456 | .blocks |
| 457 | .iter() |
| 458 | .find(|b| b.id == to) |
| 459 | .map(|b| b.params.iter().map(|p| p.id).collect()) |
| 460 | .unwrap_or_default(); |
| 461 | for pid in params { |
| 462 | self.recompute_param(pid); |
| 463 | } |
| 464 | } |
| 465 | while let Some(v) = self.ssa_worklist.pop_front() { |
| 466 | let users = self.value_users.get(&v).cloned().unwrap_or_default(); |
| 467 | for u in users { |
| 468 | match u { |
| 469 | ValueUser::Terminator(b) => { |
| 470 | if self.reachable_blocks.contains(&b) { |
| 471 | self.process_terminator(b); |
| 472 | } |
| 473 | } |
| 474 | ValueUser::BlockParam(pid) => { |
| 475 | self.recompute_param(pid); |
| 476 | } |
| 477 | } |
| 478 | } |
| 479 | // Block params themselves change → reanalyze |
| 480 | // terminators that consume them. Already handled by |
| 481 | // `value_users`. |
| 482 | let _ = self.block_param_owner.get(&v); |
| 483 | } |
| 484 | } |
| 485 | } |
| 486 | } |
| 487 | |
/// Result of SCCP analysis — the data needed by `apply` without
/// holding any borrow on the function.
struct SccpResult {
    /// Final lattice value per ValueId, taken from the solver after it
    /// has run. `apply` only consults the entries of block params.
    lattice: HashMap<ValueId, Lattice>,
}
| 493 | |
| 494 | /// Apply the SCCP analysis result to `func`. Returns whether anything |
| 495 | /// was rewritten. |
| 496 | fn apply(func: &mut Function, analysis: &SccpResult) -> bool { |
| 497 | let mut changed = false; |
| 498 | |
| 499 | // Step 1: rewrite constant block params. |
| 500 | // |
| 501 | // For every block param whose lattice resolved to `Const`: |
| 502 | // - Insert a fresh `Const*` instruction at the front of its |
| 503 | // block, allocate a new ValueId for it. |
| 504 | // - Substitute uses of the block-param ValueId with the new |
| 505 | // constant ValueId across the function. |
| 506 | // - Remove the param from the block's `params`. |
| 507 | // - Remove the matching argument from every predecessor's |
| 508 | // branch terminator. |
| 509 | // |
| 510 | // We do this block-by-block, scanning back→front within the |
| 511 | // params list so index-based predecessor arg removal stays |
| 512 | // valid. |
| 513 | let mut const_param_rewrites: Vec<(BlockId, usize, ValueId, ConstVal, IrType)> = Vec::new(); |
| 514 | for block in &func.blocks { |
| 515 | for (pi, p) in block.params.iter().enumerate() { |
| 516 | if let Some(Lattice::Const(c)) = analysis.lattice.get(&p.id).copied() { |
| 517 | // Sanity: lattice type should be compatible with |
| 518 | // declared param type. If not (defensive), skip. |
| 519 | if !ty_compatible(&p.ty, &c.ir_type()) { |
| 520 | continue; |
| 521 | } |
| 522 | const_param_rewrites.push((block.id, pi, p.id, c, p.ty.clone())); |
| 523 | } |
| 524 | } |
| 525 | } |
| 526 | |
| 527 | if !const_param_rewrites.is_empty() { |
| 528 | // Sort by (block, descending param index) so predecessor |
| 529 | // arg removal is index-stable as we mutate. |
| 530 | const_param_rewrites.sort_by_key(|x| std::cmp::Reverse(x.1)); |
| 531 | for (block_id, pi, param_id, cval, ty) in const_param_rewrites { |
| 532 | // Allocate a new value for the constant. |
| 533 | let new_id = func.next_value_id(); |
| 534 | // Insert at the front of the target block. |
| 535 | if let Some(block) = func.blocks.iter_mut().find(|b| b.id == block_id) { |
| 536 | let span = block |
| 537 | .insts |
| 538 | .first() |
| 539 | .map(|i| i.span) |
| 540 | .or_else(|| { |
| 541 | block.terminator.as_ref().map(|_| crate::lexer::Span { |
| 542 | start: crate::lexer::Position { line: 1, col: 1 }, |
| 543 | end: crate::lexer::Position { line: 1, col: 1 }, |
| 544 | file_id: 0, |
| 545 | }) |
| 546 | }) |
| 547 | .unwrap_or(crate::lexer::Span { |
| 548 | start: crate::lexer::Position { line: 1, col: 1 }, |
| 549 | end: crate::lexer::Position { line: 1, col: 1 }, |
| 550 | file_id: 0, |
| 551 | }); |
| 552 | block.insts.insert( |
| 553 | 0, |
| 554 | Inst { |
| 555 | id: new_id, |
| 556 | kind: cval.to_inst_kind(), |
| 557 | ty: ty.clone(), |
| 558 | span, |
| 559 | }, |
| 560 | ); |
| 561 | // Drop the param. |
| 562 | block.params.remove(pi); |
| 563 | } |
| 564 | // Rewrite uses of param_id → new_id. |
| 565 | crate::ir::walk::substitute_uses(func, param_id, new_id); |
| 566 | // Drop the matching argument in every predecessor's |
| 567 | // branch. |
| 568 | for pred_block in &mut func.blocks { |
| 569 | let Some(term) = pred_block.terminator.as_mut() else { |
| 570 | continue; |
| 571 | }; |
| 572 | match term { |
| 573 | Terminator::Branch(dest, args) if *dest == block_id && pi < args.len() => { |
| 574 | args.remove(pi); |
| 575 | } |
| 576 | Terminator::CondBranch { |
| 577 | true_dest, |
| 578 | true_args, |
| 579 | false_dest, |
| 580 | false_args, |
| 581 | .. |
| 582 | } => { |
| 583 | if *true_dest == block_id && pi < true_args.len() { |
| 584 | true_args.remove(pi); |
| 585 | } |
| 586 | if *false_dest == block_id && pi < false_args.len() { |
| 587 | false_args.remove(pi); |
| 588 | } |
| 589 | } |
| 590 | _ => {} |
| 591 | } |
| 592 | } |
| 593 | changed = true; |
| 594 | } |
| 595 | func.rebuild_type_cache(); |
| 596 | } |
| 597 | |
| 598 | // Step 2: fold constant-condition terminators. |
| 599 | // |
| 600 | // Re-collect const map (block-param rewrites just added new Const |
| 601 | // instructions). |
| 602 | let mut consts: HashMap<ValueId, ConstVal> = HashMap::new(); |
| 603 | for block in &func.blocks { |
| 604 | for inst in &block.insts { |
| 605 | if let Some(c) = ConstVal::from_inst(&inst.kind) { |
| 606 | consts.insert(inst.id, c); |
| 607 | } |
| 608 | } |
| 609 | } |
| 610 | let mut folded_any = false; |
| 611 | for block in &mut func.blocks { |
| 612 | let Some(term) = block.terminator.take() else { |
| 613 | continue; |
| 614 | }; |
| 615 | let new_term = match term { |
| 616 | Terminator::CondBranch { |
| 617 | cond, |
| 618 | true_dest, |
| 619 | true_args, |
| 620 | false_dest, |
| 621 | false_args, |
| 622 | } => match consts.get(&cond) { |
| 623 | Some(ConstVal::Bool(true)) => { |
| 624 | folded_any = true; |
| 625 | Terminator::Branch(true_dest, true_args) |
| 626 | } |
| 627 | Some(ConstVal::Bool(false)) => { |
| 628 | folded_any = true; |
| 629 | Terminator::Branch(false_dest, false_args) |
| 630 | } |
| 631 | _ => Terminator::CondBranch { |
| 632 | cond, |
| 633 | true_dest, |
| 634 | true_args, |
| 635 | false_dest, |
| 636 | false_args, |
| 637 | }, |
| 638 | }, |
| 639 | Terminator::Switch { |
| 640 | selector, |
| 641 | cases, |
| 642 | default, |
| 643 | } => match consts.get(&selector) { |
| 644 | Some(ConstVal::Int(sv, w)) => { |
| 645 | folded_any = true; |
| 646 | let bits = w.bits(); |
| 647 | let target = cases |
| 648 | .iter() |
| 649 | .find(|(k, _)| sext(*k as i128, bits) == *sv) |
| 650 | .map(|(_, blk)| *blk) |
| 651 | .unwrap_or(default); |
| 652 | Terminator::Branch(target, vec![]) |
| 653 | } |
| 654 | _ => Terminator::Switch { |
| 655 | selector, |
| 656 | cases, |
| 657 | default, |
| 658 | }, |
| 659 | }, |
| 660 | other => other, |
| 661 | }; |
| 662 | block.terminator = Some(new_term); |
| 663 | } |
| 664 | if folded_any { |
| 665 | changed = true; |
| 666 | } |
| 667 | |
| 668 | // Step 3: prune unreachable blocks. |
| 669 | // |
| 670 | // The SCCP analysis itself produced a reachability set, but the |
| 671 | // safe thing is to recompute it post-rewrite — block-param |
| 672 | // rewrites and terminator folds can change the CFG. |
| 673 | if prune_unreachable(func) { |
| 674 | changed = true; |
| 675 | } |
| 676 | if changed { |
| 677 | func.rebuild_type_cache(); |
| 678 | } |
| 679 | |
| 680 | changed |
| 681 | } |
| 682 | |
| 683 | fn ty_compatible(a: &IrType, b: &IrType) -> bool { |
| 684 | // Cheap structural compare for the const types we materialize. |
| 685 | match (a, b) { |
| 686 | (IrType::Bool, IrType::Bool) => true, |
| 687 | (IrType::Int(wa), IrType::Int(wb)) => wa == wb, |
| 688 | (IrType::Float(wa), IrType::Float(wb)) => wa == wb, |
| 689 | _ => false, |
| 690 | } |
| 691 | } |
| 692 | |
| 693 | /// The pass entry point. |
| 694 | pub struct Sccp_; |
| 695 | |
| 696 | impl Pass for Sccp_ { |
| 697 | fn name(&self) -> &'static str { |
| 698 | "sccp" |
| 699 | } |
| 700 | |
| 701 | fn run(&self, module: &mut Module) -> bool { |
| 702 | let mut changed = false; |
| 703 | for func in &mut module.functions { |
| 704 | let analysis = { |
| 705 | let mut s = Sccp::new(func); |
| 706 | s.run(); |
| 707 | SccpResult { lattice: s.lattice } |
| 708 | }; |
| 709 | if apply(func, &analysis) { |
| 710 | changed = true; |
| 711 | } |
| 712 | } |
| 713 | changed |
| 714 | } |
| 715 | } |
| 716 | |
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::types::IrType;
    use crate::lexer::{Position, Span};

    /// Span used for every hand-built instruction in these tests.
    fn dummy_span() -> Span {
        let origin = Position { line: 1, col: 1 };
        Span {
            start: origin,
            end: origin,
            file_id: 0,
        }
    }

    /// Append a constant-defining instruction to `block`.
    fn push_const(f: &mut Function, block: BlockId, id: ValueId, kind: InstKind, ty: IrType) {
        f.block_mut(block).insts.push(Inst {
            id,
            kind,
            ty,
            span: dummy_span(),
        });
    }

    /// Append an `i32` constant to `block`.
    fn push_i32(f: &mut Function, block: BlockId, id: ValueId, value: i128) {
        push_const(
            f,
            block,
            id,
            InstKind::ConstInt(value, IntWidth::I32),
            IrType::Int(IntWidth::I32),
        );
    }

    /// A conditional branch that passes no block arguments on either arm.
    fn cond_branch(cond: ValueId, true_dest: BlockId, false_dest: BlockId) -> Terminator {
        Terminator::CondBranch {
            cond,
            true_dest,
            true_args: vec![],
            false_dest,
            false_args: vec![],
        }
    }

    /// A boolean function parameter occupying `ValueId(0)`.
    fn bool_param(name: &'static str) -> Param {
        Param {
            name: name.into(),
            ty: IrType::Bool,
            id: ValueId(0),
            fortran_noalias: false,
        }
    }

    /// Give `block` a single `i32` block param with the given id.
    fn add_i32_param(f: &mut Function, block: BlockId, id: ValueId) {
        f.block_mut(block).params.push(BlockParam {
            id,
            ty: IrType::Int(IntWidth::I32),
        });
    }

    #[test]
    fn folds_constant_true_condbranch() {
        // entry: cond = const(true); cond_branch cond, then, else
        // then: ret
        // else: ret
        let mut m = Module::new("t".into());
        let mut f = Function::new("f".into(), vec![], IrType::Void);
        let bb_t = f.create_block("then");
        let bb_e = f.create_block("else");
        let cond_id = f.next_value_id();
        let entry = f.entry;

        push_const(&mut f, entry, cond_id, InstKind::ConstBool(true), IrType::Bool);
        f.block_mut(entry).terminator = Some(cond_branch(cond_id, bb_t, bb_e));
        f.block_mut(bb_t).terminator = Some(Terminator::Return(None));
        f.block_mut(bb_e).terminator = Some(Terminator::Return(None));
        m.add_function(f);

        assert!(Sccp_.run(&mut m));
        let f = &m.functions[0];
        match &f.blocks[0].terminator {
            Some(Terminator::Branch(d, _)) => assert_eq!(*d, bb_t),
            other => panic!("expected Branch, got {:?}", other),
        }
        // The else block is dead after the fold and must be pruned.
        assert!(f.blocks.iter().all(|b| b.id != bb_e));
    }

    #[test]
    fn merge_param_with_uniform_const_resolves() {
        // entry: cond=const(true); cond_branch cond, a, b
        // a: branch merge(const 7)
        // b: branch merge(const 99) ; statically unreachable
        // merge: param p; ret p
        //
        // Plain const_prop can't fold this — `p` has two distinct
        // incoming arg values, and const_prop doesn't know `b` is
        // unreachable. SCCP does.
        let mut m = Module::new("t".into());
        let mut f = Function::new("f".into(), vec![], IrType::Int(IntWidth::I32));

        let bb_a = f.create_block("a");
        let bb_b = f.create_block("b");
        let bb_merge = f.create_block("merge");

        let cond_id = f.next_value_id();
        let const_a = f.next_value_id();
        let const_b = f.next_value_id();
        let merge_param = f.next_value_id();
        let entry = f.entry;

        add_i32_param(&mut f, bb_merge, merge_param);

        push_const(&mut f, entry, cond_id, InstKind::ConstBool(true), IrType::Bool);
        f.block_mut(entry).terminator = Some(cond_branch(cond_id, bb_a, bb_b));

        push_i32(&mut f, bb_a, const_a, 7);
        f.block_mut(bb_a).terminator = Some(Terminator::Branch(bb_merge, vec![const_a]));

        push_i32(&mut f, bb_b, const_b, 99);
        f.block_mut(bb_b).terminator = Some(Terminator::Branch(bb_merge, vec![const_b]));

        f.block_mut(bb_merge).terminator = Some(Terminator::Return(Some(merge_param)));
        f.rebuild_type_cache();
        m.add_function(f);

        assert!(Sccp_.run(&mut m), "SCCP should fold this");

        let f = &m.functions[0];
        // merge block: param removed, new ConstInt(7) inserted, ret
        // now references that const.
        let merge = f.blocks.iter().find(|b| b.id == bb_merge).unwrap();
        assert!(merge.params.is_empty(), "merge param should be gone");
        // The materialized constant sits at the front of the block.
        match merge.insts.first().map(|i| &i.kind) {
            Some(InstKind::ConstInt(7, IntWidth::I32)) => {}
            other => panic!("expected ConstInt(7,I32) at start of merge, got {:?}", other),
        }
        // bb_b is unreachable post-fold and should be pruned.
        assert!(
            !f.blocks.iter().any(|b| b.id == bb_b),
            "unreachable b should be pruned"
        );
    }

    #[test]
    fn non_constant_param_is_not_rewritten() {
        // entry: cond_branch <param> a, b
        // a: branch merge(const 1)
        // b: branch merge(const 2)
        // merge: ret param
        //
        // Both arms reachable, two distinct constants → param stays.
        let mut m = Module::new("t".into());
        let mut f = Function::new(
            "f".into(),
            vec![bool_param("c")],
            IrType::Int(IntWidth::I32),
        );
        let bb_a = f.create_block("a");
        let bb_b = f.create_block("b");
        let bb_merge = f.create_block("merge");

        let const_a = f.next_value_id();
        let const_b = f.next_value_id();
        let merge_param = f.next_value_id();
        let entry = f.entry;

        add_i32_param(&mut f, bb_merge, merge_param);
        f.block_mut(entry).terminator = Some(cond_branch(ValueId(0), bb_a, bb_b));

        push_i32(&mut f, bb_a, const_a, 1);
        f.block_mut(bb_a).terminator = Some(Terminator::Branch(bb_merge, vec![const_a]));

        push_i32(&mut f, bb_b, const_b, 2);
        f.block_mut(bb_b).terminator = Some(Terminator::Branch(bb_merge, vec![const_b]));

        f.block_mut(bb_merge).terminator = Some(Terminator::Return(Some(merge_param)));
        f.rebuild_type_cache();
        m.add_function(f);

        let _changed = Sccp_.run(&mut m);
        let f = &m.functions[0];
        // The merge block's param must remain — both arms are
        // genuinely reachable and disagree.
        let merge = f.blocks.iter().find(|b| b.id == bb_merge).unwrap();
        assert_eq!(
            merge.params.len(),
            1,
            "merge param should survive when both arms are reachable and disagree"
        );
    }

    #[test]
    fn unknown_cond_left_alone() {
        // The condition is a function param, so neither arm can be
        // ruled out and the pass must report no change.
        let mut m = Module::new("t".into());
        let mut f = Function::new("f".into(), vec![bool_param("p")], IrType::Void);
        let bb_t = f.create_block("then");
        let bb_e = f.create_block("else");
        let entry = f.entry;

        f.block_mut(entry).terminator = Some(cond_branch(ValueId(0), bb_t, bb_e));
        f.block_mut(bb_t).terminator = Some(Terminator::Return(None));
        f.block_mut(bb_e).terminator = Some(Terminator::Return(None));
        m.add_function(f);

        assert!(!Sccp_.run(&mut m));
        assert!(matches!(
            m.functions[0].blocks[0].terminator,
            Some(Terminator::CondBranch { .. })
        ));
    }
}
| 943 |