fortrangoingonforty/armfortas / 38dd813

Browse files

Add SCCP pass with reachability + block-param const folding

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
38dd813df6c25a2006dd9703b260320833c3e658
Parents
96ef033
Tree
2b32cf4

1 changed file

StatusFile+-
A src/opt/sccp.rs 944 0
src/opt/sccp.rsadded
@@ -0,0 +1,944 @@
1
+//! Sparse Conditional Constant Propagation (SCCP) pass.
2
+//!
3
+//! Combines constant tracking with reachability analysis. The key
4
+//! insight over plain `const_fold` + `const_prop` is that constants
5
+//! flowing into a block parameter via a CFG edge whose source is
6
+//! statically unreachable do **not** force the parameter to Bottom.
7
+//! That lets SCCP fold cases the basic passes can't reach, e.g.
8
+//! ```text
9
+//! entry: br .true., a, b
10
+//! a:     br merge(c1)
11
+//! b:     br merge(c2)             ; b unreachable from entry
12
+//! merge: param p; ... use p ...
13
+//! ```
14
+//! After SCCP, the merge param `p` is constant `c1` because only the
15
+//! `a → merge` edge is reachable.
16
+//!
17
+//! Algorithm (Wegman-Zadeck, 1991):
18
+//!  1. Lattice: `Top` (not yet seen), `Const(c)` (proven constant),
19
+//!     `Bottom` (overdefined / non-constant).
20
+//!  2. Reachability: a `HashSet<BlockId>` of reachable blocks plus a
21
+//!     `HashSet<(BlockId, BlockId)>` of reachable CFG edges.
22
+//!  3. Two worklists:
23
+//!      * SSA-edge worklist — re-evaluate uses when a value's lattice
24
+//!        moves down (Top → Const, Top → Bottom, Const → Bottom).
25
+//!      * CFG-edge worklist — when a new edge becomes reachable,
26
+//!        re-meet the destination block's params.
27
+//!  4. Iterate until both worklists are empty.
28
+//!
29
+//! After fixpoint we materialize the result:
30
+//!  * Each constant block-param is rewritten: a `Const*` instruction
31
+//!    is inserted at the start of its block; uses are redirected;
32
+//!    the param entry and matching predecessor args are dropped.
33
+//!  * Each constant-condition branch is folded to an unconditional
34
+//!    `Branch` to the live target.
35
+//!  * `prune_unreachable` reaps the now-dead blocks.
36
+//!
37
+//! Per-instruction arithmetic transfer (e.g. `IAdd Const Const →
38
+//! Const`) is intentionally left to `const_fold`. The pass-manager
39
+//! fixpoint composes them: SCCP exposes new reachability →
40
+//! `const_fold` propagates → SCCP runs again. This keeps SCCP narrow
41
+//! and avoids duplicating the dozens of width-aware folds in
42
+//! `const_fold::try_fold`.
43
+
44
+use super::pass::Pass;
45
+use super::util::prune_unreachable;
46
+use crate::ir::inst::*;
47
+use crate::ir::types::{FloatWidth, IntWidth, IrType};
48
+use std::collections::{HashMap, HashSet, VecDeque};
49
+
50
/// A constant value tracked by the SCCP lattice. Mirrors the three
/// `Const*` instruction kinds that `ConstVal::from_inst` recognizes.
#[derive(Debug, Clone, Copy, PartialEq)]
enum ConstVal {
    /// Integer constant plus its width. `from_inst` stores the value
    /// sign-extended from that width, so equal constants compare equal.
    Int(i128, IntWidth),
    /// Float constant plus its width, held as the `f64` bit pattern so
    /// the derived `PartialEq` compares bits rather than float values.
    Float(u64, FloatWidth), // bit pattern — avoids NaN equality issues
    /// Boolean constant.
    Bool(bool),
}
56
+
57
/// Three-level constant-propagation lattice. Values only ever move
/// downward: `Top` → `Const` → `Bottom` (see `meet`).
#[derive(Debug, Clone, Copy, PartialEq)]
enum Lattice {
    /// Not yet seen — the optimistic starting state.
    Top,
    /// Proven to hold exactly this constant on every reachable path
    /// examined so far.
    Const(ConstVal),
    /// Overdefined: not a single known constant.
    Bottom,
}
63
+
64
+impl ConstVal {
65
+    fn from_inst(kind: &InstKind) -> Option<Self> {
66
+        match kind {
67
+            InstKind::ConstInt(v, w) => Some(ConstVal::Int(sext(*v, w.bits()), *w)),
68
+            InstKind::ConstFloat(v, w) => Some(ConstVal::Float(v.to_bits(), *w)),
69
+            InstKind::ConstBool(b) => Some(ConstVal::Bool(*b)),
70
+            _ => None,
71
+        }
72
+    }
73
+
74
+    fn to_inst_kind(self) -> InstKind {
75
+        match self {
76
+            ConstVal::Int(v, w) => InstKind::ConstInt(v, w),
77
+            ConstVal::Float(bits, w) => InstKind::ConstFloat(f64::from_bits(bits), w),
78
+            ConstVal::Bool(b) => InstKind::ConstBool(b),
79
+        }
80
+    }
81
+
82
+    fn ir_type(self) -> IrType {
83
+        match self {
84
+            ConstVal::Int(_, w) => IrType::Int(w),
85
+            ConstVal::Float(_, w) => IrType::Float(w),
86
+            ConstVal::Bool(_) => IrType::Bool,
87
+        }
88
+    }
89
+}
90
+
91
/// Sign-extends the low `bits` bits of `v` to a full `i128`.
///
/// * `bits >= 128` — `v` already occupies the whole word and is
///   returned unchanged.
/// * `bits == 0` — a zero-width integer can only hold 0, so 0 is
///   returned. Without this guard the shift amount would be 128, which
///   panics in debug builds and silently becomes a no-op shift in
///   release builds.
/// * otherwise — bit `bits - 1` is replicated into all higher bits.
fn sext(v: i128, bits: u32) -> i128 {
    match bits {
        0 => 0,
        wide if wide >= 128 => v,
        narrow => {
            // Shift the field's sign bit up to bit 127, then let the
            // arithmetic right shift smear it back down.
            let shift = 128 - narrow;
            (v << shift) >> shift
        }
    }
}
99
+
100
+fn meet(a: Lattice, b: Lattice) -> Lattice {
101
+    match (a, b) {
102
+        (Lattice::Top, x) | (x, Lattice::Top) => x,
103
+        (Lattice::Bottom, _) | (_, Lattice::Bottom) => Lattice::Bottom,
104
+        (Lattice::Const(x), Lattice::Const(y)) => {
105
+            if x == y {
106
+                Lattice::Const(x)
107
+            } else {
108
+                Lattice::Bottom
109
+            }
110
+        }
111
+    }
112
+}
113
+
114
/// Solver state for the SCCP analysis of a single function: the
/// lattice, the reachability sets, both worklists, and the CFG / use
/// indexes that `new` builds once up front.
struct Sccp<'a> {
    /// Function under analysis. Held by shared reference; the solver
    /// only reads it.
    func: &'a Function,
    /// Reverse map: ValueId → its defining block id (for SSA-edge
    /// worklist processing of block params, since we need to know
    /// which block's incoming edges to re-meet).
    block_param_owner: HashMap<ValueId, BlockId>,
    /// Lattice value per ValueId. Absent = Top.
    lattice: HashMap<ValueId, Lattice>,
    /// Reachable blocks.
    reachable_blocks: HashSet<BlockId>,
    /// Reachable CFG edges (pred, succ).
    reachable_edges: HashSet<(BlockId, BlockId)>,
    /// CFG worklist — newly reachable edges to process.
    cfg_worklist: VecDeque<(BlockId, BlockId)>,
    /// SSA worklist — values whose lattice changed; re-evaluate users
    /// (block params and terminators that consume them).
    ssa_worklist: VecDeque<ValueId>,
    /// Predecessor list (built once). A predecessor appears once per
    /// terminator operand that names the block, so it can repeat.
    preds: HashMap<BlockId, Vec<BlockId>>,
    /// Arguments passed along each (pred → succ) edge — used to
    /// look up what value flows into each block param along an edge.
    /// Indexed by (pred, succ); Switch edges carry an empty list and
    /// Return/Unreachable terminators contribute no entry. Because the
    /// key is the block pair, a terminator that names the same
    /// successor twice keeps only the list inserted last.
    edge_args: HashMap<(BlockId, BlockId), Vec<ValueId>>,
    /// For each block param ValueId: its owning block plus its index
    /// within that block's param list. We rebuild the meet over
    /// reachable incoming edges whenever any of the inputs changes.
    param_inputs: HashMap<ValueId, ParamInputInfo>,
    /// For each ValueId, the set of dependents we should requeue on
    /// SSA-worklist when it changes. For now: only block params and
    /// terminator-conditions are tracked (we don't model arithmetic).
    value_users: HashMap<ValueId, Vec<ValueUser>>,
}
147
+
148
/// Location of a block param: which block declares it and at which
/// position in that block's param list. The position doubles as the
/// index into each incoming edge's argument list.
#[derive(Debug, Clone)]
struct ParamInputInfo {
    /// Block that declares the param.
    block: BlockId,
    /// Zero-based index of the param within `block.params`.
    param_idx: usize,
}
153
+
154
/// A dependent that must be re-evaluated when a value's lattice state
/// changes. Only the two kinds of user the solver models are tracked.
#[derive(Debug, Clone)]
enum ValueUser {
    /// The value is the `cond` of a CondBranch or the `selector` of a
    /// Switch in the given block; reanalyze the terminator so newly
    /// live successors get marked.
    Terminator(BlockId),
    /// A block param that consumes this value along some edge —
    /// re-meet the param when the value's lattice changes.
    BlockParam(ValueId),
}
163
+
164
+impl<'a> Sccp<'a> {
165
+    fn new(func: &'a Function) -> Self {
166
+        let mut block_param_owner = HashMap::new();
167
+        let mut param_inputs = HashMap::new();
168
+        for block in &func.blocks {
169
+            for (i, p) in block.params.iter().enumerate() {
170
+                block_param_owner.insert(p.id, block.id);
171
+                param_inputs.insert(
172
+                    p.id,
173
+                    ParamInputInfo {
174
+                        block: block.id,
175
+                        param_idx: i,
176
+                    },
177
+                );
178
+            }
179
+        }
180
+
181
+        let mut preds: HashMap<BlockId, Vec<BlockId>> = HashMap::new();
182
+        let mut edge_args: HashMap<(BlockId, BlockId), Vec<ValueId>> = HashMap::new();
183
+        for block in &func.blocks {
184
+            preds.entry(block.id).or_default();
185
+        }
186
+        for block in &func.blocks {
187
+            let Some(term) = &block.terminator else {
188
+                continue;
189
+            };
190
+            match term {
191
+                Terminator::Branch(succ, args) => {
192
+                    preds.entry(*succ).or_default().push(block.id);
193
+                    edge_args.insert((block.id, *succ), args.clone());
194
+                }
195
+                Terminator::CondBranch {
196
+                    true_dest,
197
+                    true_args,
198
+                    false_dest,
199
+                    false_args,
200
+                    ..
201
+                } => {
202
+                    preds.entry(*true_dest).or_default().push(block.id);
203
+                    preds.entry(*false_dest).or_default().push(block.id);
204
+                    edge_args.insert((block.id, *true_dest), true_args.clone());
205
+                    edge_args.insert((block.id, *false_dest), false_args.clone());
206
+                }
207
+                Terminator::Switch { cases, default, .. } => {
208
+                    preds.entry(*default).or_default().push(block.id);
209
+                    edge_args.insert((block.id, *default), vec![]);
210
+                    for (_, tgt) in cases {
211
+                        preds.entry(*tgt).or_default().push(block.id);
212
+                        edge_args.insert((block.id, *tgt), vec![]);
213
+                    }
214
+                }
215
+                Terminator::Return(_) | Terminator::Unreachable => {}
216
+            }
217
+        }
218
+
219
+        // Build dependency graph: for each value, the users that need
220
+        // to be reanalyzed when it changes.
221
+        let mut value_users: HashMap<ValueId, Vec<ValueUser>> = HashMap::new();
222
+        for block in &func.blocks {
223
+            // Block-param dependents: each param's incoming arg
224
+            // along each pred edge.
225
+            for (pi, p) in block.params.iter().enumerate() {
226
+                for pred in preds.get(&block.id).cloned().unwrap_or_default() {
227
+                    if let Some(args) = edge_args.get(&(pred, block.id)) {
228
+                        if let Some(arg) = args.get(pi) {
229
+                            value_users
230
+                                .entry(*arg)
231
+                                .or_default()
232
+                                .push(ValueUser::BlockParam(p.id));
233
+                        }
234
+                    }
235
+                }
236
+            }
237
+            // Terminator-condition dependents (only CondBranch/Switch
238
+            // can fold on a constant condition).
239
+            if let Some(term) = &block.terminator {
240
+                match term {
241
+                    Terminator::CondBranch { cond, .. } => {
242
+                        value_users
243
+                            .entry(*cond)
244
+                            .or_default()
245
+                            .push(ValueUser::Terminator(block.id));
246
+                    }
247
+                    Terminator::Switch { selector, .. } => {
248
+                        value_users
249
+                            .entry(*selector)
250
+                            .or_default()
251
+                            .push(ValueUser::Terminator(block.id));
252
+                    }
253
+                    _ => {}
254
+                }
255
+            }
256
+        }
257
+
258
+        Self {
259
+            func,
260
+            block_param_owner,
261
+            lattice: HashMap::new(),
262
+            reachable_blocks: HashSet::new(),
263
+            reachable_edges: HashSet::new(),
264
+            cfg_worklist: VecDeque::new(),
265
+            ssa_worklist: VecDeque::new(),
266
+            preds,
267
+            edge_args,
268
+            param_inputs,
269
+            value_users,
270
+        }
271
+    }
272
+
273
+    fn lat(&self, v: ValueId) -> Lattice {
274
+        self.lattice.get(&v).copied().unwrap_or(Lattice::Top)
275
+    }
276
+
277
+    /// Set lattice value, queue SSA-worklist if changed and not
278
+    /// already at the bottom.
279
+    fn set_lat(&mut self, v: ValueId, new: Lattice) {
280
+        let old = self.lat(v);
281
+        let merged = meet(old, new);
282
+        if merged != old {
283
+            self.lattice.insert(v, merged);
284
+            self.ssa_worklist.push_back(v);
285
+        }
286
+    }
287
+
288
+    /// Make a block reachable (idempotent). Seeds CFG worklist with
289
+    /// outgoing edges and marks the block.
290
+    fn mark_block_reachable(&mut self, b: BlockId) {
291
+        if !self.reachable_blocks.insert(b) {
292
+            return;
293
+        }
294
+        // First time reaching this block — also seed its outgoing
295
+        // edges based on its (current) terminator analysis.
296
+        self.process_terminator(b);
297
+    }
298
+
299
+    /// Mark a CFG edge reachable; if newly so, mark the destination
300
+    /// block reachable and queue the edge for block-param re-meet.
301
+    fn mark_edge(&mut self, from: BlockId, to: BlockId) {
302
+        if self.reachable_edges.insert((from, to)) {
303
+            self.cfg_worklist.push_back((from, to));
304
+            self.mark_block_reachable(to);
305
+        }
306
+    }
307
+
308
+    /// Analyze a block's terminator and mark live successors.
309
+    /// Called when the block becomes reachable or when its
310
+    /// terminator-condition lattice changes.
311
+    fn process_terminator(&mut self, b: BlockId) {
312
+        let Some(block) = self.func.blocks.iter().find(|bb| bb.id == b) else {
313
+            return;
314
+        };
315
+        let Some(term) = &block.terminator else {
316
+            return;
317
+        };
318
+        match term {
319
+            Terminator::Return(_) | Terminator::Unreachable => {}
320
+            Terminator::Branch(dest, _) => {
321
+                let dest = *dest;
322
+                self.mark_edge(b, dest);
323
+            }
324
+            Terminator::CondBranch {
325
+                cond,
326
+                true_dest,
327
+                false_dest,
328
+                ..
329
+            } => match self.lat(*cond) {
330
+                Lattice::Const(ConstVal::Bool(true)) => self.mark_edge(b, *true_dest),
331
+                Lattice::Const(ConstVal::Bool(false)) => self.mark_edge(b, *false_dest),
332
+                Lattice::Bottom => {
333
+                    let t = *true_dest;
334
+                    let f = *false_dest;
335
+                    self.mark_edge(b, t);
336
+                    self.mark_edge(b, f);
337
+                }
338
+                // Const(non-bool) is a type error in well-typed IR;
339
+                // be conservative and mark both. Top means we'll
340
+                // revisit when `cond` resolves.
341
+                Lattice::Const(_) => {
342
+                    let t = *true_dest;
343
+                    let f = *false_dest;
344
+                    self.mark_edge(b, t);
345
+                    self.mark_edge(b, f);
346
+                }
347
+                Lattice::Top => {}
348
+            },
349
+            Terminator::Switch {
350
+                selector,
351
+                cases,
352
+                default,
353
+            } => match self.lat(*selector) {
354
+                Lattice::Const(ConstVal::Int(sv, w)) => {
355
+                    let bits = w.bits();
356
+                    let target = cases
357
+                        .iter()
358
+                        .find(|(k, _)| sext(*k as i128, bits) == sv)
359
+                        .map(|(_, blk)| *blk)
360
+                        .unwrap_or(*default);
361
+                    self.mark_edge(b, target);
362
+                }
363
+                Lattice::Bottom => {
364
+                    let cases_clone: Vec<BlockId> = cases.iter().map(|(_, t)| *t).collect();
365
+                    let default = *default;
366
+                    for t in cases_clone {
367
+                        self.mark_edge(b, t);
368
+                    }
369
+                    self.mark_edge(b, default);
370
+                }
371
+                Lattice::Const(_) => {
372
+                    let cases_clone: Vec<BlockId> = cases.iter().map(|(_, t)| *t).collect();
373
+                    let default = *default;
374
+                    for t in cases_clone {
375
+                        self.mark_edge(b, t);
376
+                    }
377
+                    self.mark_edge(b, default);
378
+                }
379
+                Lattice::Top => {}
380
+            },
381
+        }
382
+    }
383
+
384
+    /// Re-meet a block param's lattice value over all reachable
385
+    /// incoming edges.
386
+    fn recompute_param(&mut self, param_id: ValueId) {
387
+        let Some(info) = self.param_inputs.get(&param_id).cloned() else {
388
+            return;
389
+        };
390
+        let block = info.block;
391
+        let pi = info.param_idx;
392
+        let preds = self.preds.get(&block).cloned().unwrap_or_default();
393
+        let mut new = Lattice::Top;
394
+        for pred in preds {
395
+            if !self.reachable_edges.contains(&(pred, block)) {
396
+                continue;
397
+            }
398
+            let Some(args) = self.edge_args.get(&(pred, block)) else {
399
+                continue;
400
+            };
401
+            let Some(arg) = args.get(pi) else {
402
+                continue;
403
+            };
404
+            // If the arg IS the param itself (back-edge with same
405
+            // value), don't pollute with our own current state.
406
+            // SSA permits this, e.g. loop induction merges.
407
+            if *arg == param_id {
408
+                continue;
409
+            }
410
+            new = meet(new, self.lat(*arg));
411
+            if new == Lattice::Bottom {
412
+                break;
413
+            }
414
+        }
415
+        // Convert "Top with at least one reachable edge" → keep Top
416
+        // until a concrete arg shows up (Top is sound: still optimistic).
417
+        // monotone update via set_lat.
418
+        let old = self.lat(param_id);
419
+        let merged = meet(old, new);
420
+        if merged != old {
421
+            self.lattice.insert(param_id, merged);
422
+            self.ssa_worklist.push_back(param_id);
423
+        }
424
+    }
425
+
426
+    fn run(&mut self) {
427
+        // Seed lattice for every constant-defining instruction.
428
+        for block in &self.func.blocks {
429
+            for inst in &block.insts {
430
+                if let Some(c) = ConstVal::from_inst(&inst.kind) {
431
+                    self.lattice.insert(inst.id, Lattice::Const(c));
432
+                } else {
433
+                    // Non-const, non-block-param values default to
434
+                    // Bottom — SCCP doesn't model arithmetic
435
+                    // transfer. const_fold handles those after we
436
+                    // expose new reachability.
437
+                    self.lattice.insert(inst.id, Lattice::Bottom);
438
+                }
439
+            }
440
+        }
441
+        // Function params are unknown → Bottom.
442
+        for p in &self.func.params {
443
+            self.lattice.insert(p.id, Lattice::Bottom);
444
+        }
445
+
446
+        // Entry is always reachable.
447
+        self.mark_block_reachable(self.func.entry);
448
+
449
+        while !self.cfg_worklist.is_empty() || !self.ssa_worklist.is_empty() {
450
+            while let Some((from, to)) = self.cfg_worklist.pop_front() {
451
+                let _ = from;
452
+                // Re-meet every block param of `to`, since a new
453
+                // reachable edge may add a new input.
454
+                let params: Vec<ValueId> = self
455
+                    .func
456
+                    .blocks
457
+                    .iter()
458
+                    .find(|b| b.id == to)
459
+                    .map(|b| b.params.iter().map(|p| p.id).collect())
460
+                    .unwrap_or_default();
461
+                for pid in params {
462
+                    self.recompute_param(pid);
463
+                }
464
+            }
465
+            while let Some(v) = self.ssa_worklist.pop_front() {
466
+                let users = self.value_users.get(&v).cloned().unwrap_or_default();
467
+                for u in users {
468
+                    match u {
469
+                        ValueUser::Terminator(b) => {
470
+                            if self.reachable_blocks.contains(&b) {
471
+                                self.process_terminator(b);
472
+                            }
473
+                        }
474
+                        ValueUser::BlockParam(pid) => {
475
+                            self.recompute_param(pid);
476
+                        }
477
+                    }
478
+                }
479
+                // Block params themselves change → reanalyze
480
+                // terminators that consume them. Already handled by
481
+                // `value_users`.
482
+                let _ = self.block_param_owner.get(&v);
483
+            }
484
+        }
485
+    }
486
+}
487
+
488
/// Result of SCCP analysis — the data needed by `apply` without
/// holding any borrow on the function.
struct SccpResult {
    /// Final lattice state per value; values with no entry never left
    /// `Top`. `apply` reads it to find block params proven constant.
    lattice: HashMap<ValueId, Lattice>,
}
493
+
494
+/// Apply the SCCP analysis result to `func`. Returns whether anything
495
+/// was rewritten.
496
+fn apply(func: &mut Function, analysis: &SccpResult) -> bool {
497
+    let mut changed = false;
498
+
499
+    // Step 1: rewrite constant block params.
500
+    //
501
+    // For every block param whose lattice resolved to `Const`:
502
+    //   - Insert a fresh `Const*` instruction at the front of its
503
+    //     block, allocate a new ValueId for it.
504
+    //   - Substitute uses of the block-param ValueId with the new
505
+    //     constant ValueId across the function.
506
+    //   - Remove the param from the block's `params`.
507
+    //   - Remove the matching argument from every predecessor's
508
+    //     branch terminator.
509
+    //
510
+    // We do this block-by-block, scanning back→front within the
511
+    // params list so index-based predecessor arg removal stays
512
+    // valid.
513
+    let mut const_param_rewrites: Vec<(BlockId, usize, ValueId, ConstVal, IrType)> = Vec::new();
514
+    for block in &func.blocks {
515
+        for (pi, p) in block.params.iter().enumerate() {
516
+            if let Some(Lattice::Const(c)) = analysis.lattice.get(&p.id).copied() {
517
+                // Sanity: lattice type should be compatible with
518
+                // declared param type. If not (defensive), skip.
519
+                if !ty_compatible(&p.ty, &c.ir_type()) {
520
+                    continue;
521
+                }
522
+                const_param_rewrites.push((block.id, pi, p.id, c, p.ty.clone()));
523
+            }
524
+        }
525
+    }
526
+
527
+    if !const_param_rewrites.is_empty() {
528
+        // Sort by (block, descending param index) so predecessor
529
+        // arg removal is index-stable as we mutate.
530
+        const_param_rewrites.sort_by(|a, b| b.1.cmp(&a.1));
531
+        for (block_id, pi, param_id, cval, ty) in const_param_rewrites {
532
+            // Allocate a new value for the constant.
533
+            let new_id = func.next_value_id();
534
+            // Insert at the front of the target block.
535
+            if let Some(block) = func.blocks.iter_mut().find(|b| b.id == block_id) {
536
+                let span = block
537
+                    .insts
538
+                    .first()
539
+                    .map(|i| i.span)
540
+                    .or_else(|| {
541
+                        block.terminator.as_ref().map(|_| crate::lexer::Span {
542
+                            start: crate::lexer::Position { line: 1, col: 1 },
543
+                            end: crate::lexer::Position { line: 1, col: 1 },
544
+                            file_id: 0,
545
+                        })
546
+                    })
547
+                    .unwrap_or(crate::lexer::Span {
548
+                        start: crate::lexer::Position { line: 1, col: 1 },
549
+                        end: crate::lexer::Position { line: 1, col: 1 },
550
+                        file_id: 0,
551
+                    });
552
+                block.insts.insert(
553
+                    0,
554
+                    Inst {
555
+                        id: new_id,
556
+                        kind: cval.to_inst_kind(),
557
+                        ty: ty.clone(),
558
+                        span,
559
+                    },
560
+                );
561
+                // Drop the param.
562
+                block.params.remove(pi);
563
+            }
564
+            // Rewrite uses of param_id → new_id.
565
+            crate::ir::walk::substitute_uses(func, param_id, new_id);
566
+            // Drop the matching argument in every predecessor's
567
+            // branch.
568
+            for pred_block in &mut func.blocks {
569
+                let Some(term) = pred_block.terminator.as_mut() else {
570
+                    continue;
571
+                };
572
+                match term {
573
+                    Terminator::Branch(dest, args) if *dest == block_id => {
574
+                        if pi < args.len() {
575
+                            args.remove(pi);
576
+                        }
577
+                    }
578
+                    Terminator::CondBranch {
579
+                        true_dest,
580
+                        true_args,
581
+                        false_dest,
582
+                        false_args,
583
+                        ..
584
+                    } => {
585
+                        if *true_dest == block_id && pi < true_args.len() {
586
+                            true_args.remove(pi);
587
+                        }
588
+                        if *false_dest == block_id && pi < false_args.len() {
589
+                            false_args.remove(pi);
590
+                        }
591
+                    }
592
+                    _ => {}
593
+                }
594
+            }
595
+            changed = true;
596
+        }
597
+        func.rebuild_type_cache();
598
+    }
599
+
600
+    // Step 2: fold constant-condition terminators.
601
+    //
602
+    // Re-collect const map (block-param rewrites just added new Const
603
+    // instructions).
604
+    let mut consts: HashMap<ValueId, ConstVal> = HashMap::new();
605
+    for block in &func.blocks {
606
+        for inst in &block.insts {
607
+            if let Some(c) = ConstVal::from_inst(&inst.kind) {
608
+                consts.insert(inst.id, c);
609
+            }
610
+        }
611
+    }
612
+    let mut folded_any = false;
613
+    for block in &mut func.blocks {
614
+        let Some(term) = block.terminator.take() else {
615
+            continue;
616
+        };
617
+        let new_term = match term {
618
+            Terminator::CondBranch {
619
+                cond,
620
+                true_dest,
621
+                true_args,
622
+                false_dest,
623
+                false_args,
624
+            } => match consts.get(&cond) {
625
+                Some(ConstVal::Bool(true)) => {
626
+                    folded_any = true;
627
+                    Terminator::Branch(true_dest, true_args)
628
+                }
629
+                Some(ConstVal::Bool(false)) => {
630
+                    folded_any = true;
631
+                    Terminator::Branch(false_dest, false_args)
632
+                }
633
+                _ => Terminator::CondBranch {
634
+                    cond,
635
+                    true_dest,
636
+                    true_args,
637
+                    false_dest,
638
+                    false_args,
639
+                },
640
+            },
641
+            Terminator::Switch {
642
+                selector,
643
+                cases,
644
+                default,
645
+            } => match consts.get(&selector) {
646
+                Some(ConstVal::Int(sv, w)) => {
647
+                    folded_any = true;
648
+                    let bits = w.bits();
649
+                    let target = cases
650
+                        .iter()
651
+                        .find(|(k, _)| sext(*k as i128, bits) == *sv)
652
+                        .map(|(_, blk)| *blk)
653
+                        .unwrap_or(default);
654
+                    Terminator::Branch(target, vec![])
655
+                }
656
+                _ => Terminator::Switch {
657
+                    selector,
658
+                    cases,
659
+                    default,
660
+                },
661
+            },
662
+            other => other,
663
+        };
664
+        block.terminator = Some(new_term);
665
+    }
666
+    if folded_any {
667
+        changed = true;
668
+    }
669
+
670
+    // Step 3: prune unreachable blocks.
671
+    //
672
+    // The SCCP analysis itself produced a reachability set, but the
673
+    // safe thing is to recompute it post-rewrite — block-param
674
+    // rewrites and terminator folds can change the CFG.
675
+    if prune_unreachable(func) {
676
+        changed = true;
677
+    }
678
+    if changed {
679
+        func.rebuild_type_cache();
680
+    }
681
+
682
+    changed
683
+}
684
+
685
+fn ty_compatible(a: &IrType, b: &IrType) -> bool {
686
+    // Cheap structural compare for the const types we materialize.
687
+    match (a, b) {
688
+        (IrType::Bool, IrType::Bool) => true,
689
+        (IrType::Int(wa), IrType::Int(wb)) => wa == wb,
690
+        (IrType::Float(wa), IrType::Float(wb)) => wa == wb,
691
+        _ => false,
692
+    }
693
+}
694
+
695
+/// The pass entry point.
696
+pub struct Sccp_;
697
+
698
+impl Pass for Sccp_ {
699
+    fn name(&self) -> &'static str {
700
+        "sccp"
701
+    }
702
+
703
+    fn run(&self, module: &mut Module) -> bool {
704
+        let mut changed = false;
705
+        for func in &mut module.functions {
706
+            let analysis = {
707
+                let mut s = Sccp::new(func);
708
+                s.run();
709
+                SccpResult { lattice: s.lattice }
710
+            };
711
+            if apply(func, &analysis) {
712
+                changed = true;
713
+            }
714
+        }
715
+        changed
716
+    }
717
+}
718
+
719
+#[cfg(test)]
720
+mod tests {
721
+    use super::*;
722
+    use crate::ir::types::IrType;
723
+    use crate::lexer::{Position, Span};
724
+
725
+    fn dummy_span() -> Span {
726
+        let p = Position { line: 1, col: 1 };
727
+        Span {
728
+            start: p,
729
+            end: p,
730
+            file_id: 0,
731
+        }
732
+    }
733
+
734
+    #[test]
735
+    fn folds_constant_true_condbranch() {
736
+        // entry: cond = const(true); cond_branch cond, then, else
737
+        // then:  ret
738
+        // else:  ret
739
+        let mut m = Module::new("t".into());
740
+        let mut f = Function::new("f".into(), vec![], IrType::Void);
741
+        let bb_t = f.create_block("then");
742
+        let bb_e = f.create_block("else");
743
+        let cond_id = f.next_value_id();
744
+        f.block_mut(f.entry).insts.push(Inst {
745
+            id: cond_id,
746
+            kind: InstKind::ConstBool(true),
747
+            ty: IrType::Bool,
748
+            span: dummy_span(),
749
+        });
750
+        f.block_mut(f.entry).terminator = Some(Terminator::CondBranch {
751
+            cond: cond_id,
752
+            true_dest: bb_t,
753
+            true_args: vec![],
754
+            false_dest: bb_e,
755
+            false_args: vec![],
756
+        });
757
+        f.block_mut(bb_t).terminator = Some(Terminator::Return(None));
758
+        f.block_mut(bb_e).terminator = Some(Terminator::Return(None));
759
+        m.add_function(f);
760
+
761
+        assert!(Sccp_.run(&mut m));
762
+        let f = &m.functions[0];
763
+        match &f.blocks[0].terminator {
764
+            Some(Terminator::Branch(d, _)) => assert_eq!(*d, bb_t),
765
+            other => panic!("expected Branch, got {:?}", other),
766
+        }
767
+        // else block should be pruned.
768
+        assert!(f.blocks.iter().all(|b| b.id != bb_e));
769
+    }
770
+
771
+    #[test]
772
+    fn merge_param_with_uniform_const_resolves() {
773
+        // entry:  cond=const(true); cond_branch cond, a, b
774
+        // a:      branch merge(const 7)
775
+        // b:      branch merge(const 99)   ; statically unreachable
776
+        // merge:  param p; ret p
777
+        //
778
+        // Plain const_prop can't fold this — `p` has two distinct
779
+        // incoming arg values, and const_prop doesn't know `b` is
780
+        // unreachable. SCCP does.
781
+        let mut m = Module::new("t".into());
782
+        let mut f = Function::new("f".into(), vec![], IrType::Int(IntWidth::I32));
783
+
784
+        let bb_a = f.create_block("a");
785
+        let bb_b = f.create_block("b");
786
+        let bb_merge = f.create_block("merge");
787
+
788
+        let cond_id = f.next_value_id();
789
+        let const_a = f.next_value_id();
790
+        let const_b = f.next_value_id();
791
+        let merge_param = f.next_value_id();
792
+
793
+        f.block_mut(bb_merge).params.push(BlockParam {
794
+            id: merge_param,
795
+            ty: IrType::Int(IntWidth::I32),
796
+        });
797
+
798
+        f.block_mut(f.entry).insts.push(Inst {
799
+            id: cond_id,
800
+            kind: InstKind::ConstBool(true),
801
+            ty: IrType::Bool,
802
+            span: dummy_span(),
803
+        });
804
+        f.block_mut(f.entry).terminator = Some(Terminator::CondBranch {
805
+            cond: cond_id,
806
+            true_dest: bb_a,
807
+            true_args: vec![],
808
+            false_dest: bb_b,
809
+            false_args: vec![],
810
+        });
811
+
812
+        f.block_mut(bb_a).insts.push(Inst {
813
+            id: const_a,
814
+            kind: InstKind::ConstInt(7, IntWidth::I32),
815
+            ty: IrType::Int(IntWidth::I32),
816
+            span: dummy_span(),
817
+        });
818
+        f.block_mut(bb_a).terminator = Some(Terminator::Branch(bb_merge, vec![const_a]));
819
+
820
+        f.block_mut(bb_b).insts.push(Inst {
821
+            id: const_b,
822
+            kind: InstKind::ConstInt(99, IntWidth::I32),
823
+            ty: IrType::Int(IntWidth::I32),
824
+            span: dummy_span(),
825
+        });
826
+        f.block_mut(bb_b).terminator = Some(Terminator::Branch(bb_merge, vec![const_b]));
827
+
828
+        f.block_mut(bb_merge).terminator = Some(Terminator::Return(Some(merge_param)));
829
+        f.rebuild_type_cache();
830
+        m.add_function(f);
831
+
832
+        assert!(Sccp_.run(&mut m), "SCCP should fold this");
833
+
834
+        let f = &m.functions[0];
835
+        // merge block: param removed, new ConstInt(7) inserted, ret
836
+        // now references that const.
837
+        let merge = f.blocks.iter().find(|b| b.id == bb_merge).unwrap();
838
+        assert!(merge.params.is_empty(), "merge param should be gone");
839
+        // First inst should be a ConstInt(7).
840
+        match merge.insts.first().map(|i| &i.kind) {
841
+            Some(InstKind::ConstInt(7, IntWidth::I32)) => {}
842
+            other => panic!("expected ConstInt(7,I32) at start of merge, got {:?}", other),
843
+        }
844
+        // bb_b is unreachable post-fold and should be pruned.
845
+        assert!(
846
+            !f.blocks.iter().any(|b| b.id == bb_b),
847
+            "unreachable b should be pruned"
848
+        );
849
+    }
850
+
851
+    #[test]
852
+    fn non_constant_param_is_not_rewritten() {
853
+        // entry: cond_branch <param> a, b
854
+        // a:     branch merge(const 1)
855
+        // b:     branch merge(const 2)
856
+        // merge: ret param
857
+        //
858
+        // Both arms reachable, two distinct constants → param stays.
859
+        let mut m = Module::new("t".into());
860
+        let params = vec![Param {
861
+            name: "c".into(),
862
+            ty: IrType::Bool,
863
+            id: ValueId(0),
864
+            fortran_noalias: false,
865
+        }];
866
+        let mut f = Function::new("f".into(), params, IrType::Int(IntWidth::I32));
867
+        let bb_a = f.create_block("a");
868
+        let bb_b = f.create_block("b");
869
+        let bb_merge = f.create_block("merge");
870
+
871
+        let const_a = f.next_value_id();
872
+        let const_b = f.next_value_id();
873
+        let merge_param = f.next_value_id();
874
+        f.block_mut(bb_merge).params.push(BlockParam {
875
+            id: merge_param,
876
+            ty: IrType::Int(IntWidth::I32),
877
+        });
878
+        f.block_mut(f.entry).terminator = Some(Terminator::CondBranch {
879
+            cond: ValueId(0),
880
+            true_dest: bb_a,
881
+            true_args: vec![],
882
+            false_dest: bb_b,
883
+            false_args: vec![],
884
+        });
885
+        f.block_mut(bb_a).insts.push(Inst {
886
+            id: const_a,
887
+            kind: InstKind::ConstInt(1, IntWidth::I32),
888
+            ty: IrType::Int(IntWidth::I32),
889
+            span: dummy_span(),
890
+        });
891
+        f.block_mut(bb_a).terminator = Some(Terminator::Branch(bb_merge, vec![const_a]));
892
+        f.block_mut(bb_b).insts.push(Inst {
893
+            id: const_b,
894
+            kind: InstKind::ConstInt(2, IntWidth::I32),
895
+            ty: IrType::Int(IntWidth::I32),
896
+            span: dummy_span(),
897
+        });
898
+        f.block_mut(bb_b).terminator = Some(Terminator::Branch(bb_merge, vec![const_b]));
899
+        f.block_mut(bb_merge).terminator = Some(Terminator::Return(Some(merge_param)));
900
+        f.rebuild_type_cache();
901
+        m.add_function(f);
902
+
903
+        let _changed = Sccp_.run(&mut m);
904
+        let f = &m.functions[0];
905
+        // The merge block's param must remain — both arms are
906
+        // genuinely reachable and disagree.
907
+        let merge = f.blocks.iter().find(|b| b.id == bb_merge).unwrap();
908
+        assert_eq!(
909
+            merge.params.len(),
910
+            1,
911
+            "merge param should survive when both arms are reachable and disagree"
912
+        );
913
+    }
914
+
915
+    #[test]
916
+    fn unknown_cond_left_alone() {
917
+        let mut m = Module::new("t".into());
918
+        let params = vec![Param {
919
+            name: "p".into(),
920
+            ty: IrType::Bool,
921
+            id: ValueId(0),
922
+            fortran_noalias: false,
923
+        }];
924
+        let mut f = Function::new("f".into(), params, IrType::Void);
925
+        let bb_t = f.create_block("then");
926
+        let bb_e = f.create_block("else");
927
+        f.block_mut(f.entry).terminator = Some(Terminator::CondBranch {
928
+            cond: ValueId(0),
929
+            true_dest: bb_t,
930
+            true_args: vec![],
931
+            false_dest: bb_e,
932
+            false_args: vec![],
933
+        });
934
+        f.block_mut(bb_t).terminator = Some(Terminator::Return(None));
935
+        f.block_mut(bb_e).terminator = Some(Terminator::Return(None));
936
+        m.add_function(f);
937
+
938
+        assert!(!Sccp_.run(&mut m));
939
+        assert!(matches!(
940
+            m.functions[0].blocks[0].terminator,
941
+            Some(Terminator::CondBranch { .. })
942
+        ));
943
+    }
944
+}