@@ -0,0 +1,944 @@ |
| 1 | +//! Sparse Conditional Constant Propagation (SCCP) pass. |
| 2 | +//! |
| 3 | +//! Combines constant tracking with reachability analysis. The key |
| 4 | +//! insight over plain `const_fold` + `const_prop` is that constants |
| 5 | +//! flowing into a block parameter via a CFG edge whose source is |
| 6 | +//! statically unreachable do **not** force the parameter to Bottom. |
| 7 | +//! That lets SCCP fold cases the basic passes can't reach, e.g. |
| 8 | +//! ```text |
| 9 | +//! entry: br .true., a, b |
| 10 | +//! a: br merge(c1) |
| 11 | +//! b: br merge(c2) ; b unreachable from entry |
| 12 | +//! merge: param p; ... use p ... |
| 13 | +//! ``` |
| 14 | +//! After SCCP, the merge param `p` is constant `c1` because only the |
| 15 | +//! `a → merge` edge is reachable. |
| 16 | +//! |
| 17 | +//! Algorithm (Wegman-Zadeck, 1991): |
| 18 | +//! 1. Lattice: `Top` (not yet seen), `Const(c)` (proven constant), |
| 19 | +//! `Bottom` (overdefined / non-constant). |
| 20 | +//! 2. Reachability: a `HashSet<BlockId>` of reachable blocks plus a |
| 21 | +//! `HashSet<(BlockId, BlockId)>` of reachable CFG edges. |
| 22 | +//! 3. Two worklists: |
| 23 | +//! * SSA-edge worklist — re-evaluate uses when a value's lattice |
| 24 | +//! moves down (Top → Const, Top → Bottom, Const → Bottom). |
| 25 | +//! * CFG-edge worklist — when a new edge becomes reachable, |
| 26 | +//! re-meet the destination block's params. |
| 27 | +//! 4. Iterate until both worklists are empty. |
| 28 | +//! |
| 29 | +//! After fixpoint we materialize the result: |
| 30 | +//! * Each constant block-param is rewritten: a `Const*` instruction |
| 31 | +//! is inserted at the start of its block; uses are redirected; |
| 32 | +//! the param entry and matching predecessor args are dropped. |
| 33 | +//! * Each constant-condition branch is folded to an unconditional |
| 34 | +//! `Branch` to the live target. |
| 35 | +//! * `prune_unreachable` reaps the now-dead blocks. |
| 36 | +//! |
| 37 | +//! Per-instruction arithmetic transfer (e.g. `IAdd Const Const → |
| 38 | +//! Const`) is intentionally left to `const_fold`. The pass-manager |
| 39 | +//! fixpoint composes them: SCCP exposes new reachability → |
| 40 | +//! `const_fold` propagates → SCCP runs again. This keeps SCCP narrow |
| 41 | +//! and avoids duplicating the dozens of width-aware folds in |
| 42 | +//! `const_fold::try_fold`. |
| 43 | + |
| 44 | +use super::pass::Pass; |
| 45 | +use super::util::prune_unreachable; |
| 46 | +use crate::ir::inst::*; |
| 47 | +use crate::ir::types::{FloatWidth, IntWidth, IrType}; |
| 48 | +use std::collections::{HashMap, HashSet, VecDeque}; |
| 49 | + |
/// A compile-time constant the analysis can track through block params.
///
/// Equality is derived, so two values are "the same constant" only when
/// both the payload and the width match.
#[derive(Debug, Clone, Copy, PartialEq)]
enum ConstVal {
    /// Integer constant; `from_inst` stores the value sign-extended from
    /// the given width so equal constants compare equal.
    Int(i128, IntWidth),
    /// Float constant kept as its raw bit pattern, so comparison is
    /// bitwise rather than IEEE (avoids NaN equality issues).
    Float(u64, FloatWidth), // bit pattern — avoids NaN equality issues
    /// Boolean constant (branch conditions).
    Bool(bool),
}
| 56 | + |
/// Three-level constant-propagation lattice for a single SSA value.
///
/// Values only ever move downwards: `Top` → `Const` → `Bottom`.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Lattice {
    /// Not yet seen — the optimistic starting state.
    Top,
    /// Proven to be exactly this constant on every reachable path seen so far.
    Const(ConstVal),
    /// Overdefined / not a constant.
    Bottom,
}
| 63 | + |
| 64 | +impl ConstVal { |
| 65 | + fn from_inst(kind: &InstKind) -> Option<Self> { |
| 66 | + match kind { |
| 67 | + InstKind::ConstInt(v, w) => Some(ConstVal::Int(sext(*v, w.bits()), *w)), |
| 68 | + InstKind::ConstFloat(v, w) => Some(ConstVal::Float(v.to_bits(), *w)), |
| 69 | + InstKind::ConstBool(b) => Some(ConstVal::Bool(*b)), |
| 70 | + _ => None, |
| 71 | + } |
| 72 | + } |
| 73 | + |
| 74 | + fn to_inst_kind(self) -> InstKind { |
| 75 | + match self { |
| 76 | + ConstVal::Int(v, w) => InstKind::ConstInt(v, w), |
| 77 | + ConstVal::Float(bits, w) => InstKind::ConstFloat(f64::from_bits(bits), w), |
| 78 | + ConstVal::Bool(b) => InstKind::ConstBool(b), |
| 79 | + } |
| 80 | + } |
| 81 | + |
| 82 | + fn ir_type(self) -> IrType { |
| 83 | + match self { |
| 84 | + ConstVal::Int(_, w) => IrType::Int(w), |
| 85 | + ConstVal::Float(_, w) => IrType::Float(w), |
| 86 | + ConstVal::Bool(_) => IrType::Bool, |
| 87 | + } |
| 88 | + } |
| 89 | +} |
| 90 | + |
/// Sign-extends the low `bits` bits of `v` to a full `i128`.
///
/// * `bits >= 128` — `v` is returned unchanged (nothing to extend).
/// * `bits == 0` — a zero-width integer can only hold `0`, so `0` is
///   returned. Without this guard the shift amount would be 128, which
///   panics in overflow-checked builds.
/// * otherwise — bit `bits - 1` is treated as the sign bit and
///   replicated into all higher bits; bits above it in `v` are discarded.
fn sext(v: i128, bits: u32) -> i128 {
    match bits {
        0 => 0,
        b if b >= 128 => v,
        b => {
            let shift = 128 - b;
            // Left shift drops the high bits; the arithmetic right shift
            // then smears the new sign bit back down.
            (v << shift) >> shift
        }
    }
}
| 99 | + |
| 100 | +fn meet(a: Lattice, b: Lattice) -> Lattice { |
| 101 | + match (a, b) { |
| 102 | + (Lattice::Top, x) | (x, Lattice::Top) => x, |
| 103 | + (Lattice::Bottom, _) | (_, Lattice::Bottom) => Lattice::Bottom, |
| 104 | + (Lattice::Const(x), Lattice::Const(y)) => { |
| 105 | + if x == y { |
| 106 | + Lattice::Const(x) |
| 107 | + } else { |
| 108 | + Lattice::Bottom |
| 109 | + } |
| 110 | + } |
| 111 | + } |
| 112 | +} |
| 113 | + |
/// Working state of one SCCP analysis over a single function.
///
/// The side tables (`preds`, `edge_args`, `param_inputs`, `value_users`,
/// `block_param_owner`) are built once in `new`; the lattice,
/// reachability sets and worklists are mutated while solving.
struct Sccp<'a> {
    /// The function being analyzed (read-only during analysis).
    func: &'a Function,
    /// Reverse map: ValueId → its defining block id (for SSA-edge
    /// worklist processing of block params, since we need to know
    /// which block's incoming edges to re-meet).
    block_param_owner: HashMap<ValueId, BlockId>,
    /// Lattice value per ValueId. Absent = Top.
    lattice: HashMap<ValueId, Lattice>,
    /// Reachable blocks.
    reachable_blocks: HashSet<BlockId>,
    /// Reachable CFG edges (pred, succ).
    reachable_edges: HashSet<(BlockId, BlockId)>,
    /// CFG worklist — newly reachable edges to process.
    cfg_worklist: VecDeque<(BlockId, BlockId)>,
    /// SSA worklist — values whose lattice changed; re-evaluate users
    /// (block params and terminators that consume them).
    ssa_worklist: VecDeque<ValueId>,
    /// Predecessor list (built once). A predecessor appears once per
    /// terminator target that names the block, so it can be repeated.
    preds: HashMap<BlockId, Vec<BlockId>>,
    /// Arguments passed along each (pred → succ) edge — used to look up
    /// what value flows into each block param along that edge.
    /// Switch edges get an empty list; Return/Unreachable terminators
    /// contribute no entries. Keyed by the block pair only, so two
    /// edges between the same pair of blocks share one entry.
    edge_args: HashMap<(BlockId, BlockId), Vec<ValueId>>,
    /// For each block param ValueId: the block that owns it and its
    /// index within that block's param list. The meet over reachable
    /// incoming edges is rebuilt from `preds`/`edge_args` whenever any
    /// of the inputs changes.
    param_inputs: HashMap<ValueId, ParamInputInfo>,
    /// For each ValueId, the set of dependents we should requeue on
    /// SSA-worklist when it changes. For now: only block params and
    /// terminator-conditions are tracked (we don't model arithmetic).
    value_users: HashMap<ValueId, Vec<ValueUser>>,
}
| 147 | + |
/// Where a block param lives: enough to find the value that each
/// incoming edge passes for it.
#[derive(Debug, Clone)]
struct ParamInputInfo {
    /// Block that declares the param.
    block: BlockId,
    /// Position of the param in that block's `params` list; also the
    /// index into each predecessor's argument list.
    param_idx: usize,
}
| 153 | + |
/// A dependent that must be re-evaluated when a value's lattice changes.
#[derive(Debug, Clone)]
enum ValueUser {
    /// The value is the `cond` of a CondBranch or the `selector` of a
    /// Switch in the given block; reanalyze the terminator, since the
    /// set of live successors may have grown.
    Terminator(BlockId),
    /// A block param that consumes this value along some edge —
    /// re-meet the param when the value's lattice changes.
    BlockParam(ValueId),
}
| 163 | + |
| 164 | +impl<'a> Sccp<'a> { |
| 165 | + fn new(func: &'a Function) -> Self { |
| 166 | + let mut block_param_owner = HashMap::new(); |
| 167 | + let mut param_inputs = HashMap::new(); |
| 168 | + for block in &func.blocks { |
| 169 | + for (i, p) in block.params.iter().enumerate() { |
| 170 | + block_param_owner.insert(p.id, block.id); |
| 171 | + param_inputs.insert( |
| 172 | + p.id, |
| 173 | + ParamInputInfo { |
| 174 | + block: block.id, |
| 175 | + param_idx: i, |
| 176 | + }, |
| 177 | + ); |
| 178 | + } |
| 179 | + } |
| 180 | + |
| 181 | + let mut preds: HashMap<BlockId, Vec<BlockId>> = HashMap::new(); |
| 182 | + let mut edge_args: HashMap<(BlockId, BlockId), Vec<ValueId>> = HashMap::new(); |
| 183 | + for block in &func.blocks { |
| 184 | + preds.entry(block.id).or_default(); |
| 185 | + } |
| 186 | + for block in &func.blocks { |
| 187 | + let Some(term) = &block.terminator else { |
| 188 | + continue; |
| 189 | + }; |
| 190 | + match term { |
| 191 | + Terminator::Branch(succ, args) => { |
| 192 | + preds.entry(*succ).or_default().push(block.id); |
| 193 | + edge_args.insert((block.id, *succ), args.clone()); |
| 194 | + } |
| 195 | + Terminator::CondBranch { |
| 196 | + true_dest, |
| 197 | + true_args, |
| 198 | + false_dest, |
| 199 | + false_args, |
| 200 | + .. |
| 201 | + } => { |
| 202 | + preds.entry(*true_dest).or_default().push(block.id); |
| 203 | + preds.entry(*false_dest).or_default().push(block.id); |
| 204 | + edge_args.insert((block.id, *true_dest), true_args.clone()); |
| 205 | + edge_args.insert((block.id, *false_dest), false_args.clone()); |
| 206 | + } |
| 207 | + Terminator::Switch { cases, default, .. } => { |
| 208 | + preds.entry(*default).or_default().push(block.id); |
| 209 | + edge_args.insert((block.id, *default), vec![]); |
| 210 | + for (_, tgt) in cases { |
| 211 | + preds.entry(*tgt).or_default().push(block.id); |
| 212 | + edge_args.insert((block.id, *tgt), vec![]); |
| 213 | + } |
| 214 | + } |
| 215 | + Terminator::Return(_) | Terminator::Unreachable => {} |
| 216 | + } |
| 217 | + } |
| 218 | + |
| 219 | + // Build dependency graph: for each value, the users that need |
| 220 | + // to be reanalyzed when it changes. |
| 221 | + let mut value_users: HashMap<ValueId, Vec<ValueUser>> = HashMap::new(); |
| 222 | + for block in &func.blocks { |
| 223 | + // Block-param dependents: each param's incoming arg |
| 224 | + // along each pred edge. |
| 225 | + for (pi, p) in block.params.iter().enumerate() { |
| 226 | + for pred in preds.get(&block.id).cloned().unwrap_or_default() { |
| 227 | + if let Some(args) = edge_args.get(&(pred, block.id)) { |
| 228 | + if let Some(arg) = args.get(pi) { |
| 229 | + value_users |
| 230 | + .entry(*arg) |
| 231 | + .or_default() |
| 232 | + .push(ValueUser::BlockParam(p.id)); |
| 233 | + } |
| 234 | + } |
| 235 | + } |
| 236 | + } |
| 237 | + // Terminator-condition dependents (only CondBranch/Switch |
| 238 | + // can fold on a constant condition). |
| 239 | + if let Some(term) = &block.terminator { |
| 240 | + match term { |
| 241 | + Terminator::CondBranch { cond, .. } => { |
| 242 | + value_users |
| 243 | + .entry(*cond) |
| 244 | + .or_default() |
| 245 | + .push(ValueUser::Terminator(block.id)); |
| 246 | + } |
| 247 | + Terminator::Switch { selector, .. } => { |
| 248 | + value_users |
| 249 | + .entry(*selector) |
| 250 | + .or_default() |
| 251 | + .push(ValueUser::Terminator(block.id)); |
| 252 | + } |
| 253 | + _ => {} |
| 254 | + } |
| 255 | + } |
| 256 | + } |
| 257 | + |
| 258 | + Self { |
| 259 | + func, |
| 260 | + block_param_owner, |
| 261 | + lattice: HashMap::new(), |
| 262 | + reachable_blocks: HashSet::new(), |
| 263 | + reachable_edges: HashSet::new(), |
| 264 | + cfg_worklist: VecDeque::new(), |
| 265 | + ssa_worklist: VecDeque::new(), |
| 266 | + preds, |
| 267 | + edge_args, |
| 268 | + param_inputs, |
| 269 | + value_users, |
| 270 | + } |
| 271 | + } |
| 272 | + |
| 273 | + fn lat(&self, v: ValueId) -> Lattice { |
| 274 | + self.lattice.get(&v).copied().unwrap_or(Lattice::Top) |
| 275 | + } |
| 276 | + |
| 277 | + /// Set lattice value, queue SSA-worklist if changed and not |
| 278 | + /// already at the bottom. |
| 279 | + fn set_lat(&mut self, v: ValueId, new: Lattice) { |
| 280 | + let old = self.lat(v); |
| 281 | + let merged = meet(old, new); |
| 282 | + if merged != old { |
| 283 | + self.lattice.insert(v, merged); |
| 284 | + self.ssa_worklist.push_back(v); |
| 285 | + } |
| 286 | + } |
| 287 | + |
| 288 | + /// Make a block reachable (idempotent). Seeds CFG worklist with |
| 289 | + /// outgoing edges and marks the block. |
| 290 | + fn mark_block_reachable(&mut self, b: BlockId) { |
| 291 | + if !self.reachable_blocks.insert(b) { |
| 292 | + return; |
| 293 | + } |
| 294 | + // First time reaching this block — also seed its outgoing |
| 295 | + // edges based on its (current) terminator analysis. |
| 296 | + self.process_terminator(b); |
| 297 | + } |
| 298 | + |
| 299 | + /// Mark a CFG edge reachable; if newly so, mark the destination |
| 300 | + /// block reachable and queue the edge for block-param re-meet. |
| 301 | + fn mark_edge(&mut self, from: BlockId, to: BlockId) { |
| 302 | + if self.reachable_edges.insert((from, to)) { |
| 303 | + self.cfg_worklist.push_back((from, to)); |
| 304 | + self.mark_block_reachable(to); |
| 305 | + } |
| 306 | + } |
| 307 | + |
| 308 | + /// Analyze a block's terminator and mark live successors. |
| 309 | + /// Called when the block becomes reachable or when its |
| 310 | + /// terminator-condition lattice changes. |
| 311 | + fn process_terminator(&mut self, b: BlockId) { |
| 312 | + let Some(block) = self.func.blocks.iter().find(|bb| bb.id == b) else { |
| 313 | + return; |
| 314 | + }; |
| 315 | + let Some(term) = &block.terminator else { |
| 316 | + return; |
| 317 | + }; |
| 318 | + match term { |
| 319 | + Terminator::Return(_) | Terminator::Unreachable => {} |
| 320 | + Terminator::Branch(dest, _) => { |
| 321 | + let dest = *dest; |
| 322 | + self.mark_edge(b, dest); |
| 323 | + } |
| 324 | + Terminator::CondBranch { |
| 325 | + cond, |
| 326 | + true_dest, |
| 327 | + false_dest, |
| 328 | + .. |
| 329 | + } => match self.lat(*cond) { |
| 330 | + Lattice::Const(ConstVal::Bool(true)) => self.mark_edge(b, *true_dest), |
| 331 | + Lattice::Const(ConstVal::Bool(false)) => self.mark_edge(b, *false_dest), |
| 332 | + Lattice::Bottom => { |
| 333 | + let t = *true_dest; |
| 334 | + let f = *false_dest; |
| 335 | + self.mark_edge(b, t); |
| 336 | + self.mark_edge(b, f); |
| 337 | + } |
| 338 | + // Const(non-bool) is a type error in well-typed IR; |
| 339 | + // be conservative and mark both. Top means we'll |
| 340 | + // revisit when `cond` resolves. |
| 341 | + Lattice::Const(_) => { |
| 342 | + let t = *true_dest; |
| 343 | + let f = *false_dest; |
| 344 | + self.mark_edge(b, t); |
| 345 | + self.mark_edge(b, f); |
| 346 | + } |
| 347 | + Lattice::Top => {} |
| 348 | + }, |
| 349 | + Terminator::Switch { |
| 350 | + selector, |
| 351 | + cases, |
| 352 | + default, |
| 353 | + } => match self.lat(*selector) { |
| 354 | + Lattice::Const(ConstVal::Int(sv, w)) => { |
| 355 | + let bits = w.bits(); |
| 356 | + let target = cases |
| 357 | + .iter() |
| 358 | + .find(|(k, _)| sext(*k as i128, bits) == sv) |
| 359 | + .map(|(_, blk)| *blk) |
| 360 | + .unwrap_or(*default); |
| 361 | + self.mark_edge(b, target); |
| 362 | + } |
| 363 | + Lattice::Bottom => { |
| 364 | + let cases_clone: Vec<BlockId> = cases.iter().map(|(_, t)| *t).collect(); |
| 365 | + let default = *default; |
| 366 | + for t in cases_clone { |
| 367 | + self.mark_edge(b, t); |
| 368 | + } |
| 369 | + self.mark_edge(b, default); |
| 370 | + } |
| 371 | + Lattice::Const(_) => { |
| 372 | + let cases_clone: Vec<BlockId> = cases.iter().map(|(_, t)| *t).collect(); |
| 373 | + let default = *default; |
| 374 | + for t in cases_clone { |
| 375 | + self.mark_edge(b, t); |
| 376 | + } |
| 377 | + self.mark_edge(b, default); |
| 378 | + } |
| 379 | + Lattice::Top => {} |
| 380 | + }, |
| 381 | + } |
| 382 | + } |
| 383 | + |
| 384 | + /// Re-meet a block param's lattice value over all reachable |
| 385 | + /// incoming edges. |
| 386 | + fn recompute_param(&mut self, param_id: ValueId) { |
| 387 | + let Some(info) = self.param_inputs.get(¶m_id).cloned() else { |
| 388 | + return; |
| 389 | + }; |
| 390 | + let block = info.block; |
| 391 | + let pi = info.param_idx; |
| 392 | + let preds = self.preds.get(&block).cloned().unwrap_or_default(); |
| 393 | + let mut new = Lattice::Top; |
| 394 | + for pred in preds { |
| 395 | + if !self.reachable_edges.contains(&(pred, block)) { |
| 396 | + continue; |
| 397 | + } |
| 398 | + let Some(args) = self.edge_args.get(&(pred, block)) else { |
| 399 | + continue; |
| 400 | + }; |
| 401 | + let Some(arg) = args.get(pi) else { |
| 402 | + continue; |
| 403 | + }; |
| 404 | + // If the arg IS the param itself (back-edge with same |
| 405 | + // value), don't pollute with our own current state. |
| 406 | + // SSA permits this, e.g. loop induction merges. |
| 407 | + if *arg == param_id { |
| 408 | + continue; |
| 409 | + } |
| 410 | + new = meet(new, self.lat(*arg)); |
| 411 | + if new == Lattice::Bottom { |
| 412 | + break; |
| 413 | + } |
| 414 | + } |
| 415 | + // Convert "Top with at least one reachable edge" → keep Top |
| 416 | + // until a concrete arg shows up (Top is sound: still optimistic). |
| 417 | + // monotone update via set_lat. |
| 418 | + let old = self.lat(param_id); |
| 419 | + let merged = meet(old, new); |
| 420 | + if merged != old { |
| 421 | + self.lattice.insert(param_id, merged); |
| 422 | + self.ssa_worklist.push_back(param_id); |
| 423 | + } |
| 424 | + } |
| 425 | + |
| 426 | + fn run(&mut self) { |
| 427 | + // Seed lattice for every constant-defining instruction. |
| 428 | + for block in &self.func.blocks { |
| 429 | + for inst in &block.insts { |
| 430 | + if let Some(c) = ConstVal::from_inst(&inst.kind) { |
| 431 | + self.lattice.insert(inst.id, Lattice::Const(c)); |
| 432 | + } else { |
| 433 | + // Non-const, non-block-param values default to |
| 434 | + // Bottom — SCCP doesn't model arithmetic |
| 435 | + // transfer. const_fold handles those after we |
| 436 | + // expose new reachability. |
| 437 | + self.lattice.insert(inst.id, Lattice::Bottom); |
| 438 | + } |
| 439 | + } |
| 440 | + } |
| 441 | + // Function params are unknown → Bottom. |
| 442 | + for p in &self.func.params { |
| 443 | + self.lattice.insert(p.id, Lattice::Bottom); |
| 444 | + } |
| 445 | + |
| 446 | + // Entry is always reachable. |
| 447 | + self.mark_block_reachable(self.func.entry); |
| 448 | + |
| 449 | + while !self.cfg_worklist.is_empty() || !self.ssa_worklist.is_empty() { |
| 450 | + while let Some((from, to)) = self.cfg_worklist.pop_front() { |
| 451 | + let _ = from; |
| 452 | + // Re-meet every block param of `to`, since a new |
| 453 | + // reachable edge may add a new input. |
| 454 | + let params: Vec<ValueId> = self |
| 455 | + .func |
| 456 | + .blocks |
| 457 | + .iter() |
| 458 | + .find(|b| b.id == to) |
| 459 | + .map(|b| b.params.iter().map(|p| p.id).collect()) |
| 460 | + .unwrap_or_default(); |
| 461 | + for pid in params { |
| 462 | + self.recompute_param(pid); |
| 463 | + } |
| 464 | + } |
| 465 | + while let Some(v) = self.ssa_worklist.pop_front() { |
| 466 | + let users = self.value_users.get(&v).cloned().unwrap_or_default(); |
| 467 | + for u in users { |
| 468 | + match u { |
| 469 | + ValueUser::Terminator(b) => { |
| 470 | + if self.reachable_blocks.contains(&b) { |
| 471 | + self.process_terminator(b); |
| 472 | + } |
| 473 | + } |
| 474 | + ValueUser::BlockParam(pid) => { |
| 475 | + self.recompute_param(pid); |
| 476 | + } |
| 477 | + } |
| 478 | + } |
| 479 | + // Block params themselves change → reanalyze |
| 480 | + // terminators that consume them. Already handled by |
| 481 | + // `value_users`. |
| 482 | + let _ = self.block_param_owner.get(&v); |
| 483 | + } |
| 484 | + } |
| 485 | + } |
| 486 | +} |
| 487 | + |
/// Result of SCCP analysis — the data needed by `apply` without
/// holding any borrow on the function.
struct SccpResult {
    /// Final lattice value per ValueId at fixpoint. `apply` only looks
    /// up block-param ids here and treats a missing entry as "not a
    /// constant".
    lattice: HashMap<ValueId, Lattice>,
}
| 493 | + |
| 494 | +/// Apply the SCCP analysis result to `func`. Returns whether anything |
| 495 | +/// was rewritten. |
| 496 | +fn apply(func: &mut Function, analysis: &SccpResult) -> bool { |
| 497 | + let mut changed = false; |
| 498 | + |
| 499 | + // Step 1: rewrite constant block params. |
| 500 | + // |
| 501 | + // For every block param whose lattice resolved to `Const`: |
| 502 | + // - Insert a fresh `Const*` instruction at the front of its |
| 503 | + // block, allocate a new ValueId for it. |
| 504 | + // - Substitute uses of the block-param ValueId with the new |
| 505 | + // constant ValueId across the function. |
| 506 | + // - Remove the param from the block's `params`. |
| 507 | + // - Remove the matching argument from every predecessor's |
| 508 | + // branch terminator. |
| 509 | + // |
| 510 | + // We do this block-by-block, scanning back→front within the |
| 511 | + // params list so index-based predecessor arg removal stays |
| 512 | + // valid. |
| 513 | + let mut const_param_rewrites: Vec<(BlockId, usize, ValueId, ConstVal, IrType)> = Vec::new(); |
| 514 | + for block in &func.blocks { |
| 515 | + for (pi, p) in block.params.iter().enumerate() { |
| 516 | + if let Some(Lattice::Const(c)) = analysis.lattice.get(&p.id).copied() { |
| 517 | + // Sanity: lattice type should be compatible with |
| 518 | + // declared param type. If not (defensive), skip. |
| 519 | + if !ty_compatible(&p.ty, &c.ir_type()) { |
| 520 | + continue; |
| 521 | + } |
| 522 | + const_param_rewrites.push((block.id, pi, p.id, c, p.ty.clone())); |
| 523 | + } |
| 524 | + } |
| 525 | + } |
| 526 | + |
| 527 | + if !const_param_rewrites.is_empty() { |
| 528 | + // Sort by (block, descending param index) so predecessor |
| 529 | + // arg removal is index-stable as we mutate. |
| 530 | + const_param_rewrites.sort_by(|a, b| b.1.cmp(&a.1)); |
| 531 | + for (block_id, pi, param_id, cval, ty) in const_param_rewrites { |
| 532 | + // Allocate a new value for the constant. |
| 533 | + let new_id = func.next_value_id(); |
| 534 | + // Insert at the front of the target block. |
| 535 | + if let Some(block) = func.blocks.iter_mut().find(|b| b.id == block_id) { |
| 536 | + let span = block |
| 537 | + .insts |
| 538 | + .first() |
| 539 | + .map(|i| i.span) |
| 540 | + .or_else(|| { |
| 541 | + block.terminator.as_ref().map(|_| crate::lexer::Span { |
| 542 | + start: crate::lexer::Position { line: 1, col: 1 }, |
| 543 | + end: crate::lexer::Position { line: 1, col: 1 }, |
| 544 | + file_id: 0, |
| 545 | + }) |
| 546 | + }) |
| 547 | + .unwrap_or(crate::lexer::Span { |
| 548 | + start: crate::lexer::Position { line: 1, col: 1 }, |
| 549 | + end: crate::lexer::Position { line: 1, col: 1 }, |
| 550 | + file_id: 0, |
| 551 | + }); |
| 552 | + block.insts.insert( |
| 553 | + 0, |
| 554 | + Inst { |
| 555 | + id: new_id, |
| 556 | + kind: cval.to_inst_kind(), |
| 557 | + ty: ty.clone(), |
| 558 | + span, |
| 559 | + }, |
| 560 | + ); |
| 561 | + // Drop the param. |
| 562 | + block.params.remove(pi); |
| 563 | + } |
| 564 | + // Rewrite uses of param_id → new_id. |
| 565 | + crate::ir::walk::substitute_uses(func, param_id, new_id); |
| 566 | + // Drop the matching argument in every predecessor's |
| 567 | + // branch. |
| 568 | + for pred_block in &mut func.blocks { |
| 569 | + let Some(term) = pred_block.terminator.as_mut() else { |
| 570 | + continue; |
| 571 | + }; |
| 572 | + match term { |
| 573 | + Terminator::Branch(dest, args) if *dest == block_id => { |
| 574 | + if pi < args.len() { |
| 575 | + args.remove(pi); |
| 576 | + } |
| 577 | + } |
| 578 | + Terminator::CondBranch { |
| 579 | + true_dest, |
| 580 | + true_args, |
| 581 | + false_dest, |
| 582 | + false_args, |
| 583 | + .. |
| 584 | + } => { |
| 585 | + if *true_dest == block_id && pi < true_args.len() { |
| 586 | + true_args.remove(pi); |
| 587 | + } |
| 588 | + if *false_dest == block_id && pi < false_args.len() { |
| 589 | + false_args.remove(pi); |
| 590 | + } |
| 591 | + } |
| 592 | + _ => {} |
| 593 | + } |
| 594 | + } |
| 595 | + changed = true; |
| 596 | + } |
| 597 | + func.rebuild_type_cache(); |
| 598 | + } |
| 599 | + |
| 600 | + // Step 2: fold constant-condition terminators. |
| 601 | + // |
| 602 | + // Re-collect const map (block-param rewrites just added new Const |
| 603 | + // instructions). |
| 604 | + let mut consts: HashMap<ValueId, ConstVal> = HashMap::new(); |
| 605 | + for block in &func.blocks { |
| 606 | + for inst in &block.insts { |
| 607 | + if let Some(c) = ConstVal::from_inst(&inst.kind) { |
| 608 | + consts.insert(inst.id, c); |
| 609 | + } |
| 610 | + } |
| 611 | + } |
| 612 | + let mut folded_any = false; |
| 613 | + for block in &mut func.blocks { |
| 614 | + let Some(term) = block.terminator.take() else { |
| 615 | + continue; |
| 616 | + }; |
| 617 | + let new_term = match term { |
| 618 | + Terminator::CondBranch { |
| 619 | + cond, |
| 620 | + true_dest, |
| 621 | + true_args, |
| 622 | + false_dest, |
| 623 | + false_args, |
| 624 | + } => match consts.get(&cond) { |
| 625 | + Some(ConstVal::Bool(true)) => { |
| 626 | + folded_any = true; |
| 627 | + Terminator::Branch(true_dest, true_args) |
| 628 | + } |
| 629 | + Some(ConstVal::Bool(false)) => { |
| 630 | + folded_any = true; |
| 631 | + Terminator::Branch(false_dest, false_args) |
| 632 | + } |
| 633 | + _ => Terminator::CondBranch { |
| 634 | + cond, |
| 635 | + true_dest, |
| 636 | + true_args, |
| 637 | + false_dest, |
| 638 | + false_args, |
| 639 | + }, |
| 640 | + }, |
| 641 | + Terminator::Switch { |
| 642 | + selector, |
| 643 | + cases, |
| 644 | + default, |
| 645 | + } => match consts.get(&selector) { |
| 646 | + Some(ConstVal::Int(sv, w)) => { |
| 647 | + folded_any = true; |
| 648 | + let bits = w.bits(); |
| 649 | + let target = cases |
| 650 | + .iter() |
| 651 | + .find(|(k, _)| sext(*k as i128, bits) == *sv) |
| 652 | + .map(|(_, blk)| *blk) |
| 653 | + .unwrap_or(default); |
| 654 | + Terminator::Branch(target, vec![]) |
| 655 | + } |
| 656 | + _ => Terminator::Switch { |
| 657 | + selector, |
| 658 | + cases, |
| 659 | + default, |
| 660 | + }, |
| 661 | + }, |
| 662 | + other => other, |
| 663 | + }; |
| 664 | + block.terminator = Some(new_term); |
| 665 | + } |
| 666 | + if folded_any { |
| 667 | + changed = true; |
| 668 | + } |
| 669 | + |
| 670 | + // Step 3: prune unreachable blocks. |
| 671 | + // |
| 672 | + // The SCCP analysis itself produced a reachability set, but the |
| 673 | + // safe thing is to recompute it post-rewrite — block-param |
| 674 | + // rewrites and terminator folds can change the CFG. |
| 675 | + if prune_unreachable(func) { |
| 676 | + changed = true; |
| 677 | + } |
| 678 | + if changed { |
| 679 | + func.rebuild_type_cache(); |
| 680 | + } |
| 681 | + |
| 682 | + changed |
| 683 | +} |
| 684 | + |
| 685 | +fn ty_compatible(a: &IrType, b: &IrType) -> bool { |
| 686 | + // Cheap structural compare for the const types we materialize. |
| 687 | + match (a, b) { |
| 688 | + (IrType::Bool, IrType::Bool) => true, |
| 689 | + (IrType::Int(wa), IrType::Int(wb)) => wa == wb, |
| 690 | + (IrType::Float(wa), IrType::Float(wb)) => wa == wb, |
| 691 | + _ => false, |
| 692 | + } |
| 693 | +} |
| 694 | + |
/// The pass entry point: a unit struct implementing `Pass` that runs
/// the SCCP analysis and rewrite on every function of a module.
/// (The trailing underscore keeps the name distinct from the private
/// `Sccp` analysis state in this file.)
pub struct Sccp_;
| 697 | + |
| 698 | +impl Pass for Sccp_ { |
| 699 | + fn name(&self) -> &'static str { |
| 700 | + "sccp" |
| 701 | + } |
| 702 | + |
| 703 | + fn run(&self, module: &mut Module) -> bool { |
| 704 | + let mut changed = false; |
| 705 | + for func in &mut module.functions { |
| 706 | + let analysis = { |
| 707 | + let mut s = Sccp::new(func); |
| 708 | + s.run(); |
| 709 | + SccpResult { lattice: s.lattice } |
| 710 | + }; |
| 711 | + if apply(func, &analysis) { |
| 712 | + changed = true; |
| 713 | + } |
| 714 | + } |
| 715 | + changed |
| 716 | + } |
| 717 | +} |
| 718 | + |
| 719 | +#[cfg(test)] |
| 720 | +mod tests { |
| 721 | + use super::*; |
| 722 | + use crate::ir::types::IrType; |
| 723 | + use crate::lexer::{Position, Span}; |
| 724 | + |
| 725 | + fn dummy_span() -> Span { |
| 726 | + let p = Position { line: 1, col: 1 }; |
| 727 | + Span { |
| 728 | + start: p, |
| 729 | + end: p, |
| 730 | + file_id: 0, |
| 731 | + } |
| 732 | + } |
| 733 | + |
| 734 | + #[test] |
| 735 | + fn folds_constant_true_condbranch() { |
| 736 | + // entry: cond = const(true); cond_branch cond, then, else |
| 737 | + // then: ret |
| 738 | + // else: ret |
| 739 | + let mut m = Module::new("t".into()); |
| 740 | + let mut f = Function::new("f".into(), vec![], IrType::Void); |
| 741 | + let bb_t = f.create_block("then"); |
| 742 | + let bb_e = f.create_block("else"); |
| 743 | + let cond_id = f.next_value_id(); |
| 744 | + f.block_mut(f.entry).insts.push(Inst { |
| 745 | + id: cond_id, |
| 746 | + kind: InstKind::ConstBool(true), |
| 747 | + ty: IrType::Bool, |
| 748 | + span: dummy_span(), |
| 749 | + }); |
| 750 | + f.block_mut(f.entry).terminator = Some(Terminator::CondBranch { |
| 751 | + cond: cond_id, |
| 752 | + true_dest: bb_t, |
| 753 | + true_args: vec![], |
| 754 | + false_dest: bb_e, |
| 755 | + false_args: vec![], |
| 756 | + }); |
| 757 | + f.block_mut(bb_t).terminator = Some(Terminator::Return(None)); |
| 758 | + f.block_mut(bb_e).terminator = Some(Terminator::Return(None)); |
| 759 | + m.add_function(f); |
| 760 | + |
| 761 | + assert!(Sccp_.run(&mut m)); |
| 762 | + let f = &m.functions[0]; |
| 763 | + match &f.blocks[0].terminator { |
| 764 | + Some(Terminator::Branch(d, _)) => assert_eq!(*d, bb_t), |
| 765 | + other => panic!("expected Branch, got {:?}", other), |
| 766 | + } |
| 767 | + // else block should be pruned. |
| 768 | + assert!(f.blocks.iter().all(|b| b.id != bb_e)); |
| 769 | + } |
| 770 | + |
| 771 | + #[test] |
| 772 | + fn merge_param_with_uniform_const_resolves() { |
| 773 | + // entry: cond=const(true); cond_branch cond, a, b |
| 774 | + // a: branch merge(const 7) |
| 775 | + // b: branch merge(const 99) ; statically unreachable |
| 776 | + // merge: param p; ret p |
| 777 | + // |
| 778 | + // Plain const_prop can't fold this — `p` has two distinct |
| 779 | + // incoming arg values, and const_prop doesn't know `b` is |
| 780 | + // unreachable. SCCP does. |
| 781 | + let mut m = Module::new("t".into()); |
| 782 | + let mut f = Function::new("f".into(), vec![], IrType::Int(IntWidth::I32)); |
| 783 | + |
| 784 | + let bb_a = f.create_block("a"); |
| 785 | + let bb_b = f.create_block("b"); |
| 786 | + let bb_merge = f.create_block("merge"); |
| 787 | + |
| 788 | + let cond_id = f.next_value_id(); |
| 789 | + let const_a = f.next_value_id(); |
| 790 | + let const_b = f.next_value_id(); |
| 791 | + let merge_param = f.next_value_id(); |
| 792 | + |
| 793 | + f.block_mut(bb_merge).params.push(BlockParam { |
| 794 | + id: merge_param, |
| 795 | + ty: IrType::Int(IntWidth::I32), |
| 796 | + }); |
| 797 | + |
| 798 | + f.block_mut(f.entry).insts.push(Inst { |
| 799 | + id: cond_id, |
| 800 | + kind: InstKind::ConstBool(true), |
| 801 | + ty: IrType::Bool, |
| 802 | + span: dummy_span(), |
| 803 | + }); |
| 804 | + f.block_mut(f.entry).terminator = Some(Terminator::CondBranch { |
| 805 | + cond: cond_id, |
| 806 | + true_dest: bb_a, |
| 807 | + true_args: vec![], |
| 808 | + false_dest: bb_b, |
| 809 | + false_args: vec![], |
| 810 | + }); |
| 811 | + |
| 812 | + f.block_mut(bb_a).insts.push(Inst { |
| 813 | + id: const_a, |
| 814 | + kind: InstKind::ConstInt(7, IntWidth::I32), |
| 815 | + ty: IrType::Int(IntWidth::I32), |
| 816 | + span: dummy_span(), |
| 817 | + }); |
| 818 | + f.block_mut(bb_a).terminator = Some(Terminator::Branch(bb_merge, vec![const_a])); |
| 819 | + |
| 820 | + f.block_mut(bb_b).insts.push(Inst { |
| 821 | + id: const_b, |
| 822 | + kind: InstKind::ConstInt(99, IntWidth::I32), |
| 823 | + ty: IrType::Int(IntWidth::I32), |
| 824 | + span: dummy_span(), |
| 825 | + }); |
| 826 | + f.block_mut(bb_b).terminator = Some(Terminator::Branch(bb_merge, vec![const_b])); |
| 827 | + |
| 828 | + f.block_mut(bb_merge).terminator = Some(Terminator::Return(Some(merge_param))); |
| 829 | + f.rebuild_type_cache(); |
| 830 | + m.add_function(f); |
| 831 | + |
| 832 | + assert!(Sccp_.run(&mut m), "SCCP should fold this"); |
| 833 | + |
| 834 | + let f = &m.functions[0]; |
| 835 | + // merge block: param removed, new ConstInt(7) inserted, ret |
| 836 | + // now references that const. |
| 837 | + let merge = f.blocks.iter().find(|b| b.id == bb_merge).unwrap(); |
| 838 | + assert!(merge.params.is_empty(), "merge param should be gone"); |
| 839 | + // First inst should be a ConstInt(7). |
| 840 | + match merge.insts.first().map(|i| &i.kind) { |
| 841 | + Some(InstKind::ConstInt(7, IntWidth::I32)) => {} |
| 842 | + other => panic!("expected ConstInt(7,I32) at start of merge, got {:?}", other), |
| 843 | + } |
| 844 | + // bb_b is unreachable post-fold and should be pruned. |
| 845 | + assert!( |
| 846 | + !f.blocks.iter().any(|b| b.id == bb_b), |
| 847 | + "unreachable b should be pruned" |
| 848 | + ); |
| 849 | + } |
| 850 | + |
| 851 | + #[test] |
| 852 | + fn non_constant_param_is_not_rewritten() { |
| 853 | + // entry: cond_branch <param> a, b |
| 854 | + // a: branch merge(const 1) |
| 855 | + // b: branch merge(const 2) |
| 856 | + // merge: ret param |
| 857 | + // |
| 858 | + // Both arms reachable, two distinct constants → param stays. |
| 859 | + let mut m = Module::new("t".into()); |
| 860 | + let params = vec![Param { |
| 861 | + name: "c".into(), |
| 862 | + ty: IrType::Bool, |
| 863 | + id: ValueId(0), |
| 864 | + fortran_noalias: false, |
| 865 | + }]; |
| 866 | + let mut f = Function::new("f".into(), params, IrType::Int(IntWidth::I32)); |
| 867 | + let bb_a = f.create_block("a"); |
| 868 | + let bb_b = f.create_block("b"); |
| 869 | + let bb_merge = f.create_block("merge"); |
| 870 | + |
| 871 | + let const_a = f.next_value_id(); |
| 872 | + let const_b = f.next_value_id(); |
| 873 | + let merge_param = f.next_value_id(); |
| 874 | + f.block_mut(bb_merge).params.push(BlockParam { |
| 875 | + id: merge_param, |
| 876 | + ty: IrType::Int(IntWidth::I32), |
| 877 | + }); |
| 878 | + f.block_mut(f.entry).terminator = Some(Terminator::CondBranch { |
| 879 | + cond: ValueId(0), |
| 880 | + true_dest: bb_a, |
| 881 | + true_args: vec![], |
| 882 | + false_dest: bb_b, |
| 883 | + false_args: vec![], |
| 884 | + }); |
| 885 | + f.block_mut(bb_a).insts.push(Inst { |
| 886 | + id: const_a, |
| 887 | + kind: InstKind::ConstInt(1, IntWidth::I32), |
| 888 | + ty: IrType::Int(IntWidth::I32), |
| 889 | + span: dummy_span(), |
| 890 | + }); |
| 891 | + f.block_mut(bb_a).terminator = Some(Terminator::Branch(bb_merge, vec![const_a])); |
| 892 | + f.block_mut(bb_b).insts.push(Inst { |
| 893 | + id: const_b, |
| 894 | + kind: InstKind::ConstInt(2, IntWidth::I32), |
| 895 | + ty: IrType::Int(IntWidth::I32), |
| 896 | + span: dummy_span(), |
| 897 | + }); |
| 898 | + f.block_mut(bb_b).terminator = Some(Terminator::Branch(bb_merge, vec![const_b])); |
| 899 | + f.block_mut(bb_merge).terminator = Some(Terminator::Return(Some(merge_param))); |
| 900 | + f.rebuild_type_cache(); |
| 901 | + m.add_function(f); |
| 902 | + |
| 903 | + let _changed = Sccp_.run(&mut m); |
| 904 | + let f = &m.functions[0]; |
| 905 | + // The merge block's param must remain — both arms are |
| 906 | + // genuinely reachable and disagree. |
| 907 | + let merge = f.blocks.iter().find(|b| b.id == bb_merge).unwrap(); |
| 908 | + assert_eq!( |
| 909 | + merge.params.len(), |
| 910 | + 1, |
| 911 | + "merge param should survive when both arms are reachable and disagree" |
| 912 | + ); |
| 913 | + } |
| 914 | + |
| 915 | + #[test] |
| 916 | + fn unknown_cond_left_alone() { |
| 917 | + let mut m = Module::new("t".into()); |
| 918 | + let params = vec![Param { |
| 919 | + name: "p".into(), |
| 920 | + ty: IrType::Bool, |
| 921 | + id: ValueId(0), |
| 922 | + fortran_noalias: false, |
| 923 | + }]; |
| 924 | + let mut f = Function::new("f".into(), params, IrType::Void); |
| 925 | + let bb_t = f.create_block("then"); |
| 926 | + let bb_e = f.create_block("else"); |
| 927 | + f.block_mut(f.entry).terminator = Some(Terminator::CondBranch { |
| 928 | + cond: ValueId(0), |
| 929 | + true_dest: bb_t, |
| 930 | + true_args: vec![], |
| 931 | + false_dest: bb_e, |
| 932 | + false_args: vec![], |
| 933 | + }); |
| 934 | + f.block_mut(bb_t).terminator = Some(Terminator::Return(None)); |
| 935 | + f.block_mut(bb_e).terminator = Some(Terminator::Return(None)); |
| 936 | + m.add_function(f); |
| 937 | + |
| 938 | + assert!(!Sccp_.run(&mut m)); |
| 939 | + assert!(matches!( |
| 940 | + m.functions[0].blocks[0].terminator, |
| 941 | + Some(Terminator::CondBranch { .. }) |
| 942 | + )); |
| 943 | + } |
| 944 | +} |