Rust · 34852 bytes Raw Blame History
1 //! Sparse Conditional Constant Propagation (SCCP) pass.
2 //!
3 //! Combines constant tracking with reachability analysis. The key
4 //! insight over plain `const_fold` + `const_prop` is that constants
5 //! flowing into a block parameter via a CFG edge whose source is
6 //! statically unreachable do **not** force the parameter to Bottom.
7 //! That lets SCCP fold cases the basic passes can't reach, e.g.
8 //! ```text
9 //! entry: br .true., a, b
10 //! a: br merge(c1)
11 //! b: br merge(c2) ; b unreachable from entry
12 //! merge: param p; ... use p ...
13 //! ```
14 //! After SCCP, the merge param `p` is constant `c1` because only the
15 //! `a → merge` edge is reachable.
16 //!
17 //! Algorithm (Wegman-Zadeck, 1991):
18 //! 1. Lattice: `Top` (not yet seen), `Const(c)` (proven constant),
19 //! `Bottom` (overdefined / non-constant).
20 //! 2. Reachability: a `HashSet<BlockId>` of reachable blocks plus a
21 //! `HashSet<(BlockId, BlockId)>` of reachable CFG edges.
22 //! 3. Two worklists:
23 //! * SSA-edge worklist — re-evaluate uses when a value's lattice
24 //! moves down (Top → Const, Top → Bottom, Const → Bottom).
25 //! * CFG-edge worklist — when a new edge becomes reachable,
26 //! re-meet the destination block's params.
27 //! 4. Iterate until both worklists are empty.
28 //!
29 //! After fixpoint we materialize the result:
30 //! * Each constant block-param is rewritten: a `Const*` instruction
31 //! is inserted at the start of its block; uses are redirected;
32 //! the param entry and matching predecessor args are dropped.
33 //! * Each constant-condition branch is folded to an unconditional
34 //! `Branch` to the live target.
35 //! * `prune_unreachable` reaps the now-dead blocks.
36 //!
37 //! Per-instruction arithmetic transfer (e.g. `IAdd Const Const →
38 //! Const`) is intentionally left to `const_fold`. The pass-manager
39 //! fixpoint composes them: SCCP exposes new reachability →
40 //! `const_fold` propagates → SCCP runs again. This keeps SCCP narrow
41 //! and avoids duplicating the dozens of width-aware folds in
42 //! `const_fold::try_fold`.
43
44 use super::pass::Pass;
45 use super::util::prune_unreachable;
46 use crate::ir::inst::*;
47 use crate::ir::types::{FloatWidth, IntWidth, IrType};
48 use std::collections::{HashMap, HashSet, VecDeque};
49
50 #[derive(Debug, Clone, Copy, PartialEq)]
51 enum ConstVal {
52 Int(i128, IntWidth),
53 Float(u64, FloatWidth), // bit pattern — avoids NaN equality issues
54 Bool(bool),
55 }
56
57 #[derive(Debug, Clone, Copy, PartialEq)]
58 enum Lattice {
59 Top,
60 Const(ConstVal),
61 Bottom,
62 }
63
64 impl ConstVal {
65 fn from_inst(kind: &InstKind) -> Option<Self> {
66 match kind {
67 InstKind::ConstInt(v, w) => Some(ConstVal::Int(sext(*v, w.bits()), *w)),
68 InstKind::ConstFloat(v, w) => Some(ConstVal::Float(v.to_bits(), *w)),
69 InstKind::ConstBool(b) => Some(ConstVal::Bool(*b)),
70 _ => None,
71 }
72 }
73
74 fn to_inst_kind(self) -> InstKind {
75 match self {
76 ConstVal::Int(v, w) => InstKind::ConstInt(v, w),
77 ConstVal::Float(bits, w) => InstKind::ConstFloat(f64::from_bits(bits), w),
78 ConstVal::Bool(b) => InstKind::ConstBool(b),
79 }
80 }
81
82 fn ir_type(self) -> IrType {
83 match self {
84 ConstVal::Int(_, w) => IrType::Int(w),
85 ConstVal::Float(_, w) => IrType::Float(w),
86 ConstVal::Bool(_) => IrType::Bool,
87 }
88 }
89 }
90
91 fn sext(v: i128, bits: u32) -> i128 {
92 if bits >= 128 {
93 v
94 } else {
95 let shift = 128 - bits;
96 (v << shift) >> shift
97 }
98 }
99
100 fn meet(a: Lattice, b: Lattice) -> Lattice {
101 match (a, b) {
102 (Lattice::Top, x) | (x, Lattice::Top) => x,
103 (Lattice::Bottom, _) | (_, Lattice::Bottom) => Lattice::Bottom,
104 (Lattice::Const(x), Lattice::Const(y)) => {
105 if x == y {
106 Lattice::Const(x)
107 } else {
108 Lattice::Bottom
109 }
110 }
111 }
112 }
113
114 struct Sccp<'a> {
115 func: &'a Function,
116 /// Reverse map: ValueId → its defining block id (for SSA-edge
117 /// worklist processing of block params, since we need to know
118 /// which block's incoming edges to re-meet).
119 block_param_owner: HashMap<ValueId, BlockId>,
120 /// Lattice value per ValueId. Absent = Top.
121 lattice: HashMap<ValueId, Lattice>,
122 /// Reachable blocks.
123 reachable_blocks: HashSet<BlockId>,
124 /// Reachable CFG edges (pred, succ).
125 reachable_edges: HashSet<(BlockId, BlockId)>,
126 /// CFG worklist — newly reachable edges to process.
127 cfg_worklist: VecDeque<(BlockId, BlockId)>,
128 /// SSA worklist — values whose lattice changed; re-evaluate users
129 /// (block params and terminators that consume them).
130 ssa_worklist: VecDeque<ValueId>,
131 /// Predecessor list (built once).
132 preds: HashMap<BlockId, Vec<BlockId>>,
133 /// (pred_block, args_passed) per (pred → succ) edge — used to
134 /// look up what value flows into each block param along an edge.
135 /// Indexed by (pred, succ); empty for Return/Unreachable terms.
136 edge_args: HashMap<(BlockId, BlockId), Vec<ValueId>>,
137 /// For each block param ValueId: the (pred → succ) edges that
138 /// feed it, plus its index within the block's param list. We
139 /// rebuild the meet over reachable edges whenever any of the
140 /// inputs changes.
141 param_inputs: HashMap<ValueId, ParamInputInfo>,
142 /// For each ValueId, the set of dependents we should requeue on
143 /// SSA-worklist when it changes. For now: only block params and
144 /// terminator-conditions are tracked (we don't model arithmetic).
145 value_users: HashMap<ValueId, Vec<ValueUser>>,
146 }
147
148 #[derive(Debug, Clone)]
149 struct ParamInputInfo {
150 block: BlockId,
151 param_idx: usize,
152 }
153
154 #[derive(Debug, Clone)]
155 enum ValueUser {
156 /// `cond` of a CondBranch in the given block; reanalyze the
157 /// terminator to maybe expose newly-unreachable successors.
158 Terminator(BlockId),
159 /// A block param that consumes this value along some edge —
160 /// re-meet the param when the value's lattice changes.
161 BlockParam(ValueId),
162 }
163
164 impl<'a> Sccp<'a> {
165 fn new(func: &'a Function) -> Self {
166 let mut block_param_owner = HashMap::new();
167 let mut param_inputs = HashMap::new();
168 for block in &func.blocks {
169 for (i, p) in block.params.iter().enumerate() {
170 block_param_owner.insert(p.id, block.id);
171 param_inputs.insert(
172 p.id,
173 ParamInputInfo {
174 block: block.id,
175 param_idx: i,
176 },
177 );
178 }
179 }
180
181 let mut preds: HashMap<BlockId, Vec<BlockId>> = HashMap::new();
182 let mut edge_args: HashMap<(BlockId, BlockId), Vec<ValueId>> = HashMap::new();
183 for block in &func.blocks {
184 preds.entry(block.id).or_default();
185 }
186 for block in &func.blocks {
187 let Some(term) = &block.terminator else {
188 continue;
189 };
190 match term {
191 Terminator::Branch(succ, args) => {
192 preds.entry(*succ).or_default().push(block.id);
193 edge_args.insert((block.id, *succ), args.clone());
194 }
195 Terminator::CondBranch {
196 true_dest,
197 true_args,
198 false_dest,
199 false_args,
200 ..
201 } => {
202 preds.entry(*true_dest).or_default().push(block.id);
203 preds.entry(*false_dest).or_default().push(block.id);
204 edge_args.insert((block.id, *true_dest), true_args.clone());
205 edge_args.insert((block.id, *false_dest), false_args.clone());
206 }
207 Terminator::Switch { cases, default, .. } => {
208 preds.entry(*default).or_default().push(block.id);
209 edge_args.insert((block.id, *default), vec![]);
210 for (_, tgt) in cases {
211 preds.entry(*tgt).or_default().push(block.id);
212 edge_args.insert((block.id, *tgt), vec![]);
213 }
214 }
215 Terminator::Return(_) | Terminator::Unreachable => {}
216 }
217 }
218
219 // Build dependency graph: for each value, the users that need
220 // to be reanalyzed when it changes.
221 let mut value_users: HashMap<ValueId, Vec<ValueUser>> = HashMap::new();
222 for block in &func.blocks {
223 // Block-param dependents: each param's incoming arg
224 // along each pred edge.
225 for (pi, p) in block.params.iter().enumerate() {
226 for pred in preds.get(&block.id).cloned().unwrap_or_default() {
227 if let Some(args) = edge_args.get(&(pred, block.id)) {
228 if let Some(arg) = args.get(pi) {
229 value_users
230 .entry(*arg)
231 .or_default()
232 .push(ValueUser::BlockParam(p.id));
233 }
234 }
235 }
236 }
237 // Terminator-condition dependents (only CondBranch/Switch
238 // can fold on a constant condition).
239 if let Some(term) = &block.terminator {
240 match term {
241 Terminator::CondBranch { cond, .. } => {
242 value_users
243 .entry(*cond)
244 .or_default()
245 .push(ValueUser::Terminator(block.id));
246 }
247 Terminator::Switch { selector, .. } => {
248 value_users
249 .entry(*selector)
250 .or_default()
251 .push(ValueUser::Terminator(block.id));
252 }
253 _ => {}
254 }
255 }
256 }
257
258 Self {
259 func,
260 block_param_owner,
261 lattice: HashMap::new(),
262 reachable_blocks: HashSet::new(),
263 reachable_edges: HashSet::new(),
264 cfg_worklist: VecDeque::new(),
265 ssa_worklist: VecDeque::new(),
266 preds,
267 edge_args,
268 param_inputs,
269 value_users,
270 }
271 }
272
273 fn lat(&self, v: ValueId) -> Lattice {
274 self.lattice.get(&v).copied().unwrap_or(Lattice::Top)
275 }
276
277 /// Set lattice value, queue SSA-worklist if changed and not
278 /// already at the bottom.
279 fn set_lat(&mut self, v: ValueId, new: Lattice) {
280 let old = self.lat(v);
281 let merged = meet(old, new);
282 if merged != old {
283 self.lattice.insert(v, merged);
284 self.ssa_worklist.push_back(v);
285 }
286 }
287
288 /// Make a block reachable (idempotent). Seeds CFG worklist with
289 /// outgoing edges and marks the block.
290 fn mark_block_reachable(&mut self, b: BlockId) {
291 if !self.reachable_blocks.insert(b) {
292 return;
293 }
294 // First time reaching this block — also seed its outgoing
295 // edges based on its (current) terminator analysis.
296 self.process_terminator(b);
297 }
298
299 /// Mark a CFG edge reachable; if newly so, mark the destination
300 /// block reachable and queue the edge for block-param re-meet.
301 fn mark_edge(&mut self, from: BlockId, to: BlockId) {
302 if self.reachable_edges.insert((from, to)) {
303 self.cfg_worklist.push_back((from, to));
304 self.mark_block_reachable(to);
305 }
306 }
307
308 /// Analyze a block's terminator and mark live successors.
309 /// Called when the block becomes reachable or when its
310 /// terminator-condition lattice changes.
311 fn process_terminator(&mut self, b: BlockId) {
312 let Some(block) = self.func.blocks.iter().find(|bb| bb.id == b) else {
313 return;
314 };
315 let Some(term) = &block.terminator else {
316 return;
317 };
318 match term {
319 Terminator::Return(_) | Terminator::Unreachable => {}
320 Terminator::Branch(dest, _) => {
321 let dest = *dest;
322 self.mark_edge(b, dest);
323 }
324 Terminator::CondBranch {
325 cond,
326 true_dest,
327 false_dest,
328 ..
329 } => match self.lat(*cond) {
330 Lattice::Const(ConstVal::Bool(true)) => self.mark_edge(b, *true_dest),
331 Lattice::Const(ConstVal::Bool(false)) => self.mark_edge(b, *false_dest),
332 Lattice::Bottom => {
333 let t = *true_dest;
334 let f = *false_dest;
335 self.mark_edge(b, t);
336 self.mark_edge(b, f);
337 }
338 // Const(non-bool) is a type error in well-typed IR;
339 // be conservative and mark both. Top means we'll
340 // revisit when `cond` resolves.
341 Lattice::Const(_) => {
342 let t = *true_dest;
343 let f = *false_dest;
344 self.mark_edge(b, t);
345 self.mark_edge(b, f);
346 }
347 Lattice::Top => {}
348 },
349 Terminator::Switch {
350 selector,
351 cases,
352 default,
353 } => match self.lat(*selector) {
354 Lattice::Const(ConstVal::Int(sv, w)) => {
355 let bits = w.bits();
356 let target = cases
357 .iter()
358 .find(|(k, _)| sext(*k as i128, bits) == sv)
359 .map(|(_, blk)| *blk)
360 .unwrap_or(*default);
361 self.mark_edge(b, target);
362 }
363 Lattice::Bottom => {
364 let cases_clone: Vec<BlockId> = cases.iter().map(|(_, t)| *t).collect();
365 let default = *default;
366 for t in cases_clone {
367 self.mark_edge(b, t);
368 }
369 self.mark_edge(b, default);
370 }
371 Lattice::Const(_) => {
372 let cases_clone: Vec<BlockId> = cases.iter().map(|(_, t)| *t).collect();
373 let default = *default;
374 for t in cases_clone {
375 self.mark_edge(b, t);
376 }
377 self.mark_edge(b, default);
378 }
379 Lattice::Top => {}
380 },
381 }
382 }
383
384 /// Re-meet a block param's lattice value over all reachable
385 /// incoming edges.
386 fn recompute_param(&mut self, param_id: ValueId) {
387 let Some(info) = self.param_inputs.get(&param_id).cloned() else {
388 return;
389 };
390 let block = info.block;
391 let pi = info.param_idx;
392 let preds = self.preds.get(&block).cloned().unwrap_or_default();
393 let mut new = Lattice::Top;
394 for pred in preds {
395 if !self.reachable_edges.contains(&(pred, block)) {
396 continue;
397 }
398 let Some(args) = self.edge_args.get(&(pred, block)) else {
399 continue;
400 };
401 let Some(arg) = args.get(pi) else {
402 continue;
403 };
404 // If the arg IS the param itself (back-edge with same
405 // value), don't pollute with our own current state.
406 // SSA permits this, e.g. loop induction merges.
407 if *arg == param_id {
408 continue;
409 }
410 new = meet(new, self.lat(*arg));
411 if new == Lattice::Bottom {
412 break;
413 }
414 }
415 // Convert "Top with at least one reachable edge" → keep Top
416 // until a concrete arg shows up (Top is sound: still optimistic).
417 // monotone update via set_lat.
418 let old = self.lat(param_id);
419 let merged = meet(old, new);
420 if merged != old {
421 self.lattice.insert(param_id, merged);
422 self.ssa_worklist.push_back(param_id);
423 }
424 }
425
426 fn run(&mut self) {
427 // Seed lattice for every constant-defining instruction.
428 for block in &self.func.blocks {
429 for inst in &block.insts {
430 if let Some(c) = ConstVal::from_inst(&inst.kind) {
431 self.lattice.insert(inst.id, Lattice::Const(c));
432 } else {
433 // Non-const, non-block-param values default to
434 // Bottom — SCCP doesn't model arithmetic
435 // transfer. const_fold handles those after we
436 // expose new reachability.
437 self.lattice.insert(inst.id, Lattice::Bottom);
438 }
439 }
440 }
441 // Function params are unknown → Bottom.
442 for p in &self.func.params {
443 self.lattice.insert(p.id, Lattice::Bottom);
444 }
445
446 // Entry is always reachable.
447 self.mark_block_reachable(self.func.entry);
448
449 while !self.cfg_worklist.is_empty() || !self.ssa_worklist.is_empty() {
450 while let Some((from, to)) = self.cfg_worklist.pop_front() {
451 let _ = from;
452 // Re-meet every block param of `to`, since a new
453 // reachable edge may add a new input.
454 let params: Vec<ValueId> = self
455 .func
456 .blocks
457 .iter()
458 .find(|b| b.id == to)
459 .map(|b| b.params.iter().map(|p| p.id).collect())
460 .unwrap_or_default();
461 for pid in params {
462 self.recompute_param(pid);
463 }
464 }
465 while let Some(v) = self.ssa_worklist.pop_front() {
466 let users = self.value_users.get(&v).cloned().unwrap_or_default();
467 for u in users {
468 match u {
469 ValueUser::Terminator(b) => {
470 if self.reachable_blocks.contains(&b) {
471 self.process_terminator(b);
472 }
473 }
474 ValueUser::BlockParam(pid) => {
475 self.recompute_param(pid);
476 }
477 }
478 }
479 // Block params themselves change → reanalyze
480 // terminators that consume them. Already handled by
481 // `value_users`.
482 let _ = self.block_param_owner.get(&v);
483 }
484 }
485 }
486 }
487
488 /// Result of SCCP analysis — the data needed by `apply` without
489 /// holding any borrow on the function.
490 struct SccpResult {
491 lattice: HashMap<ValueId, Lattice>,
492 }
493
494 /// Apply the SCCP analysis result to `func`. Returns whether anything
495 /// was rewritten.
496 fn apply(func: &mut Function, analysis: &SccpResult) -> bool {
497 let mut changed = false;
498
499 // Step 1: rewrite constant block params.
500 //
501 // For every block param whose lattice resolved to `Const`:
502 // - Insert a fresh `Const*` instruction at the front of its
503 // block, allocate a new ValueId for it.
504 // - Substitute uses of the block-param ValueId with the new
505 // constant ValueId across the function.
506 // - Remove the param from the block's `params`.
507 // - Remove the matching argument from every predecessor's
508 // branch terminator.
509 //
510 // We do this block-by-block, scanning back→front within the
511 // params list so index-based predecessor arg removal stays
512 // valid.
513 let mut const_param_rewrites: Vec<(BlockId, usize, ValueId, ConstVal, IrType)> = Vec::new();
514 for block in &func.blocks {
515 for (pi, p) in block.params.iter().enumerate() {
516 if let Some(Lattice::Const(c)) = analysis.lattice.get(&p.id).copied() {
517 // Sanity: lattice type should be compatible with
518 // declared param type. If not (defensive), skip.
519 if !ty_compatible(&p.ty, &c.ir_type()) {
520 continue;
521 }
522 const_param_rewrites.push((block.id, pi, p.id, c, p.ty.clone()));
523 }
524 }
525 }
526
527 if !const_param_rewrites.is_empty() {
528 // Sort by (block, descending param index) so predecessor
529 // arg removal is index-stable as we mutate.
530 const_param_rewrites.sort_by_key(|x| std::cmp::Reverse(x.1));
531 for (block_id, pi, param_id, cval, ty) in const_param_rewrites {
532 // Allocate a new value for the constant.
533 let new_id = func.next_value_id();
534 // Insert at the front of the target block.
535 if let Some(block) = func.blocks.iter_mut().find(|b| b.id == block_id) {
536 let span = block
537 .insts
538 .first()
539 .map(|i| i.span)
540 .or_else(|| {
541 block.terminator.as_ref().map(|_| crate::lexer::Span {
542 start: crate::lexer::Position { line: 1, col: 1 },
543 end: crate::lexer::Position { line: 1, col: 1 },
544 file_id: 0,
545 })
546 })
547 .unwrap_or(crate::lexer::Span {
548 start: crate::lexer::Position { line: 1, col: 1 },
549 end: crate::lexer::Position { line: 1, col: 1 },
550 file_id: 0,
551 });
552 block.insts.insert(
553 0,
554 Inst {
555 id: new_id,
556 kind: cval.to_inst_kind(),
557 ty: ty.clone(),
558 span,
559 },
560 );
561 // Drop the param.
562 block.params.remove(pi);
563 }
564 // Rewrite uses of param_id → new_id.
565 crate::ir::walk::substitute_uses(func, param_id, new_id);
566 // Drop the matching argument in every predecessor's
567 // branch.
568 for pred_block in &mut func.blocks {
569 let Some(term) = pred_block.terminator.as_mut() else {
570 continue;
571 };
572 match term {
573 Terminator::Branch(dest, args) if *dest == block_id && pi < args.len() => {
574 args.remove(pi);
575 }
576 Terminator::CondBranch {
577 true_dest,
578 true_args,
579 false_dest,
580 false_args,
581 ..
582 } => {
583 if *true_dest == block_id && pi < true_args.len() {
584 true_args.remove(pi);
585 }
586 if *false_dest == block_id && pi < false_args.len() {
587 false_args.remove(pi);
588 }
589 }
590 _ => {}
591 }
592 }
593 changed = true;
594 }
595 func.rebuild_type_cache();
596 }
597
598 // Step 2: fold constant-condition terminators.
599 //
600 // Re-collect const map (block-param rewrites just added new Const
601 // instructions).
602 let mut consts: HashMap<ValueId, ConstVal> = HashMap::new();
603 for block in &func.blocks {
604 for inst in &block.insts {
605 if let Some(c) = ConstVal::from_inst(&inst.kind) {
606 consts.insert(inst.id, c);
607 }
608 }
609 }
610 let mut folded_any = false;
611 for block in &mut func.blocks {
612 let Some(term) = block.terminator.take() else {
613 continue;
614 };
615 let new_term = match term {
616 Terminator::CondBranch {
617 cond,
618 true_dest,
619 true_args,
620 false_dest,
621 false_args,
622 } => match consts.get(&cond) {
623 Some(ConstVal::Bool(true)) => {
624 folded_any = true;
625 Terminator::Branch(true_dest, true_args)
626 }
627 Some(ConstVal::Bool(false)) => {
628 folded_any = true;
629 Terminator::Branch(false_dest, false_args)
630 }
631 _ => Terminator::CondBranch {
632 cond,
633 true_dest,
634 true_args,
635 false_dest,
636 false_args,
637 },
638 },
639 Terminator::Switch {
640 selector,
641 cases,
642 default,
643 } => match consts.get(&selector) {
644 Some(ConstVal::Int(sv, w)) => {
645 folded_any = true;
646 let bits = w.bits();
647 let target = cases
648 .iter()
649 .find(|(k, _)| sext(*k as i128, bits) == *sv)
650 .map(|(_, blk)| *blk)
651 .unwrap_or(default);
652 Terminator::Branch(target, vec![])
653 }
654 _ => Terminator::Switch {
655 selector,
656 cases,
657 default,
658 },
659 },
660 other => other,
661 };
662 block.terminator = Some(new_term);
663 }
664 if folded_any {
665 changed = true;
666 }
667
668 // Step 3: prune unreachable blocks.
669 //
670 // The SCCP analysis itself produced a reachability set, but the
671 // safe thing is to recompute it post-rewrite — block-param
672 // rewrites and terminator folds can change the CFG.
673 if prune_unreachable(func) {
674 changed = true;
675 }
676 if changed {
677 func.rebuild_type_cache();
678 }
679
680 changed
681 }
682
683 fn ty_compatible(a: &IrType, b: &IrType) -> bool {
684 // Cheap structural compare for the const types we materialize.
685 match (a, b) {
686 (IrType::Bool, IrType::Bool) => true,
687 (IrType::Int(wa), IrType::Int(wb)) => wa == wb,
688 (IrType::Float(wa), IrType::Float(wb)) => wa == wb,
689 _ => false,
690 }
691 }
692
693 /// The pass entry point.
694 pub struct Sccp_;
695
696 impl Pass for Sccp_ {
697 fn name(&self) -> &'static str {
698 "sccp"
699 }
700
701 fn run(&self, module: &mut Module) -> bool {
702 let mut changed = false;
703 for func in &mut module.functions {
704 let analysis = {
705 let mut s = Sccp::new(func);
706 s.run();
707 SccpResult { lattice: s.lattice }
708 };
709 if apply(func, &analysis) {
710 changed = true;
711 }
712 }
713 changed
714 }
715 }
716
717 #[cfg(test)]
718 mod tests {
719 use super::*;
720 use crate::ir::types::IrType;
721 use crate::lexer::{Position, Span};
722
723 fn dummy_span() -> Span {
724 let p = Position { line: 1, col: 1 };
725 Span {
726 start: p,
727 end: p,
728 file_id: 0,
729 }
730 }
731
732 #[test]
733 fn folds_constant_true_condbranch() {
734 // entry: cond = const(true); cond_branch cond, then, else
735 // then: ret
736 // else: ret
737 let mut m = Module::new("t".into());
738 let mut f = Function::new("f".into(), vec![], IrType::Void);
739 let bb_t = f.create_block("then");
740 let bb_e = f.create_block("else");
741 let cond_id = f.next_value_id();
742 f.block_mut(f.entry).insts.push(Inst {
743 id: cond_id,
744 kind: InstKind::ConstBool(true),
745 ty: IrType::Bool,
746 span: dummy_span(),
747 });
748 f.block_mut(f.entry).terminator = Some(Terminator::CondBranch {
749 cond: cond_id,
750 true_dest: bb_t,
751 true_args: vec![],
752 false_dest: bb_e,
753 false_args: vec![],
754 });
755 f.block_mut(bb_t).terminator = Some(Terminator::Return(None));
756 f.block_mut(bb_e).terminator = Some(Terminator::Return(None));
757 m.add_function(f);
758
759 assert!(Sccp_.run(&mut m));
760 let f = &m.functions[0];
761 match &f.blocks[0].terminator {
762 Some(Terminator::Branch(d, _)) => assert_eq!(*d, bb_t),
763 other => panic!("expected Branch, got {:?}", other),
764 }
765 // else block should be pruned.
766 assert!(f.blocks.iter().all(|b| b.id != bb_e));
767 }
768
769 #[test]
770 fn merge_param_with_uniform_const_resolves() {
771 // entry: cond=const(true); cond_branch cond, a, b
772 // a: branch merge(const 7)
773 // b: branch merge(const 99) ; statically unreachable
774 // merge: param p; ret p
775 //
776 // Plain const_prop can't fold this — `p` has two distinct
777 // incoming arg values, and const_prop doesn't know `b` is
778 // unreachable. SCCP does.
779 let mut m = Module::new("t".into());
780 let mut f = Function::new("f".into(), vec![], IrType::Int(IntWidth::I32));
781
782 let bb_a = f.create_block("a");
783 let bb_b = f.create_block("b");
784 let bb_merge = f.create_block("merge");
785
786 let cond_id = f.next_value_id();
787 let const_a = f.next_value_id();
788 let const_b = f.next_value_id();
789 let merge_param = f.next_value_id();
790
791 f.block_mut(bb_merge).params.push(BlockParam {
792 id: merge_param,
793 ty: IrType::Int(IntWidth::I32),
794 });
795
796 f.block_mut(f.entry).insts.push(Inst {
797 id: cond_id,
798 kind: InstKind::ConstBool(true),
799 ty: IrType::Bool,
800 span: dummy_span(),
801 });
802 f.block_mut(f.entry).terminator = Some(Terminator::CondBranch {
803 cond: cond_id,
804 true_dest: bb_a,
805 true_args: vec![],
806 false_dest: bb_b,
807 false_args: vec![],
808 });
809
810 f.block_mut(bb_a).insts.push(Inst {
811 id: const_a,
812 kind: InstKind::ConstInt(7, IntWidth::I32),
813 ty: IrType::Int(IntWidth::I32),
814 span: dummy_span(),
815 });
816 f.block_mut(bb_a).terminator = Some(Terminator::Branch(bb_merge, vec![const_a]));
817
818 f.block_mut(bb_b).insts.push(Inst {
819 id: const_b,
820 kind: InstKind::ConstInt(99, IntWidth::I32),
821 ty: IrType::Int(IntWidth::I32),
822 span: dummy_span(),
823 });
824 f.block_mut(bb_b).terminator = Some(Terminator::Branch(bb_merge, vec![const_b]));
825
826 f.block_mut(bb_merge).terminator = Some(Terminator::Return(Some(merge_param)));
827 f.rebuild_type_cache();
828 m.add_function(f);
829
830 assert!(Sccp_.run(&mut m), "SCCP should fold this");
831
832 let f = &m.functions[0];
833 // merge block: param removed, new ConstInt(7) inserted, ret
834 // now references that const.
835 let merge = f.blocks.iter().find(|b| b.id == bb_merge).unwrap();
836 assert!(merge.params.is_empty(), "merge param should be gone");
837 // First inst should be a ConstInt(7).
838 match merge.insts.first().map(|i| &i.kind) {
839 Some(InstKind::ConstInt(7, IntWidth::I32)) => {}
840 other => panic!("expected ConstInt(7,I32) at start of merge, got {:?}", other),
841 }
842 // bb_b is unreachable post-fold and should be pruned.
843 assert!(
844 !f.blocks.iter().any(|b| b.id == bb_b),
845 "unreachable b should be pruned"
846 );
847 }
848
849 #[test]
850 fn non_constant_param_is_not_rewritten() {
851 // entry: cond_branch <param> a, b
852 // a: branch merge(const 1)
853 // b: branch merge(const 2)
854 // merge: ret param
855 //
856 // Both arms reachable, two distinct constants → param stays.
857 let mut m = Module::new("t".into());
858 let params = vec![Param {
859 name: "c".into(),
860 ty: IrType::Bool,
861 id: ValueId(0),
862 fortran_noalias: false,
863 }];
864 let mut f = Function::new("f".into(), params, IrType::Int(IntWidth::I32));
865 let bb_a = f.create_block("a");
866 let bb_b = f.create_block("b");
867 let bb_merge = f.create_block("merge");
868
869 let const_a = f.next_value_id();
870 let const_b = f.next_value_id();
871 let merge_param = f.next_value_id();
872 f.block_mut(bb_merge).params.push(BlockParam {
873 id: merge_param,
874 ty: IrType::Int(IntWidth::I32),
875 });
876 f.block_mut(f.entry).terminator = Some(Terminator::CondBranch {
877 cond: ValueId(0),
878 true_dest: bb_a,
879 true_args: vec![],
880 false_dest: bb_b,
881 false_args: vec![],
882 });
883 f.block_mut(bb_a).insts.push(Inst {
884 id: const_a,
885 kind: InstKind::ConstInt(1, IntWidth::I32),
886 ty: IrType::Int(IntWidth::I32),
887 span: dummy_span(),
888 });
889 f.block_mut(bb_a).terminator = Some(Terminator::Branch(bb_merge, vec![const_a]));
890 f.block_mut(bb_b).insts.push(Inst {
891 id: const_b,
892 kind: InstKind::ConstInt(2, IntWidth::I32),
893 ty: IrType::Int(IntWidth::I32),
894 span: dummy_span(),
895 });
896 f.block_mut(bb_b).terminator = Some(Terminator::Branch(bb_merge, vec![const_b]));
897 f.block_mut(bb_merge).terminator = Some(Terminator::Return(Some(merge_param)));
898 f.rebuild_type_cache();
899 m.add_function(f);
900
901 let _changed = Sccp_.run(&mut m);
902 let f = &m.functions[0];
903 // The merge block's param must remain — both arms are
904 // genuinely reachable and disagree.
905 let merge = f.blocks.iter().find(|b| b.id == bb_merge).unwrap();
906 assert_eq!(
907 merge.params.len(),
908 1,
909 "merge param should survive when both arms are reachable and disagree"
910 );
911 }
912
913 #[test]
914 fn unknown_cond_left_alone() {
915 let mut m = Module::new("t".into());
916 let params = vec![Param {
917 name: "p".into(),
918 ty: IrType::Bool,
919 id: ValueId(0),
920 fortran_noalias: false,
921 }];
922 let mut f = Function::new("f".into(), params, IrType::Void);
923 let bb_t = f.create_block("then");
924 let bb_e = f.create_block("else");
925 f.block_mut(f.entry).terminator = Some(Terminator::CondBranch {
926 cond: ValueId(0),
927 true_dest: bb_t,
928 true_args: vec![],
929 false_dest: bb_e,
930 false_args: vec![],
931 });
932 f.block_mut(bb_t).terminator = Some(Terminator::Return(None));
933 f.block_mut(bb_e).terminator = Some(Terminator::Return(None));
934 m.add_function(f);
935
936 assert!(!Sccp_.run(&mut m));
937 assert!(matches!(
938 m.functions[0].blocks[0].terminator,
939 Some(Terminator::CondBranch { .. })
940 ));
941 }
942 }
943