Rust · 14142 bytes Raw Blame History
1 //! Constant propagation pass.
2 //!
3 //! In our SSA IR, propagating a constant value into the operand slot of
4 //! an instruction is a no-op rewrite — every consumer already finds the
5 //! defining `Const*` instruction by `ValueId`, and codegen happily turns
6 //! that into an immediate. The transformations that *do* matter are the
7 //! ones the constant-folding pass cannot reach because they live in
8 //! block terminators:
9 //!
10 //! * `CondBranch { cond, true_dest, false_dest }` with a constant
11 //! `cond` becomes an unconditional `Branch` to the live target.
12 //! * `Switch { selector, cases, default }` with a constant `selector`
13 //! becomes an unconditional `Branch` to the matching case.
14 //!
15 //! Both rewrites are critical: they expose dead blocks to later passes
16 //! (CSE, LICM, GVN) and shrink the CFG.
17 //!
//! After folding terminators we **must** prune any blocks that are no
//! longer reachable from the entry. A block that merely lost one incoming
//! edge but is still reachable through some other path is fine, but our
//! verifier rejects values used across blocks that no longer dominate the
//! use — and unreachable blocks can break dominance for downstream merges.
//! The cleanest fix is to remove the dead blocks as part of the same
//! transformation so the IR handed to the verifier is always well-formed.
25 //!
//! Block parameters need no special handling: if the taken target of a
//! folded `CondBranch` has block parameters (which Fortran lowering only
//! emits for loop headers and merges), we still rewrite and simply keep the
//! original branch arguments for the taken side, since those already
//! matched the target's parameters in the original `CondBranch`.
31
32 use super::pass::Pass;
33 use super::util::prune_unreachable;
34 use crate::ir::inst::*;
35 use crate::ir::types::IntWidth;
36 use std::collections::HashMap;
37
38 /// What we know about a value at compile time.
39 #[derive(Debug, Clone, Copy)]
40 enum Const {
41 Bool(bool),
42 Int(i64, IntWidth),
43 }
44
45 impl Const {
46 /// Build a `Const` from a constant-producing instruction,
47 /// sign-extending integer values at their declared width so the
48 /// stored value is canonical. Audit M4-4: matches the N-9
49 /// normalization in `const_fold::Const::from_inst`. The Switch
50 /// fold at line 119 already sexts at use site, but relying on
51 /// every use site to remember that is fragile — canonicalizing
52 /// here closes the consistency gap with `const_fold`.
53 fn from_inst(kind: &InstKind) -> Option<Self> {
54 match kind {
55 InstKind::ConstBool(b) => Some(Const::Bool(*b)),
56 InstKind::ConstInt(v, w) => {
57 let bits = w.bits();
58 let signed = if bits >= 64 {
59 i64::try_from(*v).ok()?
60 } else {
61 let shift = 64 - bits;
62 let narrow = i64::try_from(*v).ok()?;
63 (narrow << shift) >> shift
64 };
65 Some(Const::Int(signed, *w))
66 }
67 _ => None,
68 }
69 }
70 }
71
72 /// Build the const-value map for one function. Walks every block in
73 /// definition order; SSA dominance guarantees an instruction's operands
74 /// are already in the map when we reach it.
75 fn collect_consts(func: &Function) -> HashMap<ValueId, Const> {
76 let mut consts = HashMap::new();
77 for block in &func.blocks {
78 for inst in &block.insts {
79 if let Some(c) = Const::from_inst(&inst.kind) {
80 consts.insert(inst.id, c);
81 }
82 }
83 }
84 consts
85 }
86
87 /// The constant propagation pass.
88 pub struct ConstProp;
89
90 impl Pass for ConstProp {
91 fn name(&self) -> &'static str {
92 "const-prop"
93 }
94
95 fn run(&self, module: &mut Module) -> bool {
96 let mut changed = false;
97 for func in &mut module.functions {
98 let consts = collect_consts(func);
99 let mut local_changed = false;
100 for block in &mut func.blocks {
101 let Some(term) = block.terminator.take() else {
102 continue;
103 };
104 let new_term = simplify_terminator(term, &consts, &mut local_changed);
105 block.terminator = Some(new_term);
106 }
107 if local_changed {
108 // Folded a CondBranch/Switch — drop any block that is no
109 // longer reachable from the entry so the verifier doesn't
110 // see stale uses across a now-impossible path.
111 let _pruned = prune_unreachable(func);
112 changed = true;
113 }
114 }
115 changed
116 }
117 }
118
119 fn simplify_terminator(
120 term: Terminator,
121 consts: &HashMap<ValueId, Const>,
122 changed: &mut bool,
123 ) -> Terminator {
124 match term {
125 Terminator::CondBranch {
126 cond,
127 true_dest,
128 true_args,
129 false_dest,
130 false_args,
131 } => {
132 if let Some(Const::Bool(c)) = consts.get(&cond).copied() {
133 *changed = true;
134 if c {
135 Terminator::Branch(true_dest, true_args)
136 } else {
137 Terminator::Branch(false_dest, false_args)
138 }
139 } else {
140 Terminator::CondBranch {
141 cond,
142 true_dest,
143 true_args,
144 false_dest,
145 false_args,
146 }
147 }
148 }
149 Terminator::Switch {
150 selector,
151 cases,
152 default,
153 } => {
154 if let Some(Const::Int(sv, w)) = consts.get(&selector).copied() {
155 // `sv` is already canonical (sign-extended at width)
156 // thanks to the M4-4 normalization in `Const::from_inst`.
157 // Case keys come from the IR and aren't guaranteed
158 // canonical, so we still normalize them before
159 // comparing.
160 *changed = true;
161 let target = cases
162 .iter()
163 .find(|(k, _)| sext(*k, w.bits()) == sv)
164 .map(|(_, b)| *b)
165 .unwrap_or(default);
166 // Switch targets in our IR cannot have block params, so
167 // an empty arg vec is correct.
168 Terminator::Branch(target, vec![])
169 } else {
170 Terminator::Switch {
171 selector,
172 cases,
173 default,
174 }
175 }
176 }
177 other => other,
178 }
179 }
180
181 fn sext(v: i64, bits: u32) -> i64 {
182 if bits >= 64 {
183 v
184 } else {
185 let shift = 64 - bits;
186 (v << shift) >> shift
187 }
188 }
189
190 #[cfg(test)]
191 mod tests {
192 use super::*;
193 use crate::ir::types::IrType;
194 use crate::lexer::{Position, Span};
195
196 fn dummy_span() -> Span {
197 let p = Position { line: 1, col: 1 };
198 Span {
199 start: p,
200 end: p,
201 file_id: 0,
202 }
203 }
204
205 #[test]
206 fn condbranch_with_true_const_folds_to_branch_true() {
207 let mut m = Module::new("t".into());
208 let mut f = Function::new("f".into(), vec![], IrType::Void);
209
210 // Build: entry { cond=const(true); cond_branch cond, then, else }
211 // then { ret }
212 // else { ret }
213 let bb_t = f.create_block("then");
214 let bb_e = f.create_block("else");
215 let cond_id = f.next_value_id();
216 f.block_mut(f.entry).insts.push(Inst {
217 id: cond_id,
218 kind: InstKind::ConstBool(true),
219 ty: IrType::Bool,
220 span: dummy_span(),
221 });
222 f.block_mut(f.entry).terminator = Some(Terminator::CondBranch {
223 cond: cond_id,
224 true_dest: bb_t,
225 true_args: vec![],
226 false_dest: bb_e,
227 false_args: vec![],
228 });
229 f.block_mut(bb_t).terminator = Some(Terminator::Return(None));
230 f.block_mut(bb_e).terminator = Some(Terminator::Return(None));
231 m.add_function(f);
232
233 assert!(ConstProp.run(&mut m));
234 match &m.functions[0].blocks[0].terminator {
235 Some(Terminator::Branch(dest, _)) => assert_eq!(*dest, bb_t),
236 other => panic!("expected Branch, got {:?}", other),
237 }
238 }
239
240 #[test]
241 fn condbranch_with_false_const_folds_to_branch_false() {
242 let mut m = Module::new("t".into());
243 let mut f = Function::new("f".into(), vec![], IrType::Void);
244 let bb_t = f.create_block("then");
245 let bb_e = f.create_block("else");
246 let cond_id = f.next_value_id();
247 f.block_mut(f.entry).insts.push(Inst {
248 id: cond_id,
249 kind: InstKind::ConstBool(false),
250 ty: IrType::Bool,
251 span: dummy_span(),
252 });
253 f.block_mut(f.entry).terminator = Some(Terminator::CondBranch {
254 cond: cond_id,
255 true_dest: bb_t,
256 true_args: vec![],
257 false_dest: bb_e,
258 false_args: vec![],
259 });
260 f.block_mut(bb_t).terminator = Some(Terminator::Return(None));
261 f.block_mut(bb_e).terminator = Some(Terminator::Return(None));
262 m.add_function(f);
263
264 assert!(ConstProp.run(&mut m));
265 match &m.functions[0].blocks[0].terminator {
266 Some(Terminator::Branch(dest, _)) => assert_eq!(*dest, bb_e),
267 other => panic!("expected Branch, got {:?}", other),
268 }
269 }
270
271 #[test]
272 fn switch_with_const_selector_picks_case() {
273 let mut m = Module::new("t".into());
274 let mut f = Function::new("f".into(), vec![], IrType::Void);
275 let case_a = f.create_block("a");
276 let case_b = f.create_block("b");
277 let default = f.create_block("default");
278 let sel_id = f.next_value_id();
279 f.block_mut(f.entry).insts.push(Inst {
280 id: sel_id,
281 kind: InstKind::ConstInt(2, IntWidth::I32),
282 ty: IrType::Int(IntWidth::I32),
283 span: dummy_span(),
284 });
285 f.block_mut(f.entry).terminator = Some(Terminator::Switch {
286 selector: sel_id,
287 cases: vec![(1, case_a), (2, case_b)],
288 default,
289 });
290 f.block_mut(case_a).terminator = Some(Terminator::Return(None));
291 f.block_mut(case_b).terminator = Some(Terminator::Return(None));
292 f.block_mut(default).terminator = Some(Terminator::Return(None));
293 m.add_function(f);
294
295 assert!(ConstProp.run(&mut m));
296 match &m.functions[0].blocks[0].terminator {
297 Some(Terminator::Branch(dest, _)) => assert_eq!(*dest, case_b),
298 other => panic!("expected Branch, got {:?}", other),
299 }
300 }
301
302 #[test]
303 fn switch_with_unmatched_const_selector_picks_default() {
304 let mut m = Module::new("t".into());
305 let mut f = Function::new("f".into(), vec![], IrType::Void);
306 let case_a = f.create_block("a");
307 let default = f.create_block("default");
308 let sel_id = f.next_value_id();
309 f.block_mut(f.entry).insts.push(Inst {
310 id: sel_id,
311 kind: InstKind::ConstInt(99, IntWidth::I32),
312 ty: IrType::Int(IntWidth::I32),
313 span: dummy_span(),
314 });
315 f.block_mut(f.entry).terminator = Some(Terminator::Switch {
316 selector: sel_id,
317 cases: vec![(1, case_a)],
318 default,
319 });
320 f.block_mut(case_a).terminator = Some(Terminator::Return(None));
321 f.block_mut(default).terminator = Some(Terminator::Return(None));
322 m.add_function(f);
323
324 assert!(ConstProp.run(&mut m));
325 match &m.functions[0].blocks[0].terminator {
326 Some(Terminator::Branch(dest, _)) => assert_eq!(*dest, default),
327 other => panic!("expected Branch, got {:?}", other),
328 }
329 }
330
331 #[test]
332 fn folded_condbranch_prunes_dead_block() {
333 // entry { cond=const(true); cond_branch -> then or dead }
334 // then { ret }
335 // dead { ret } // should be removed
336 let mut m = Module::new("t".into());
337 let mut f = Function::new("f".into(), vec![], IrType::Void);
338 let bb_t = f.create_block("then");
339 let bb_d = f.create_block("dead");
340 let cond_id = f.next_value_id();
341 f.block_mut(f.entry).insts.push(Inst {
342 id: cond_id,
343 kind: InstKind::ConstBool(true),
344 ty: IrType::Bool,
345 span: dummy_span(),
346 });
347 f.block_mut(f.entry).terminator = Some(Terminator::CondBranch {
348 cond: cond_id,
349 true_dest: bb_t,
350 true_args: vec![],
351 false_dest: bb_d,
352 false_args: vec![],
353 });
354 f.block_mut(bb_t).terminator = Some(Terminator::Return(None));
355 f.block_mut(bb_d).terminator = Some(Terminator::Return(None));
356 m.add_function(f);
357
358 assert!(ConstProp.run(&mut m));
359 let f = &m.functions[0];
360 assert_eq!(f.blocks.len(), 2, "dead block should be pruned");
361 assert!(f.blocks.iter().any(|b| b.id == bb_t));
362 assert!(!f.blocks.iter().any(|b| b.id == bb_d));
363 }
364
365 #[test]
366 fn condbranch_with_unknown_cond_left_alone() {
367 // Operand is not a constant — must not transform.
368 let mut m = Module::new("t".into());
369 let params = vec![Param {
370 name: "p".into(),
371 ty: IrType::Bool,
372 id: ValueId(0),
373 fortran_noalias: false,
374 }];
375 let mut f = Function::new("f".into(), params, IrType::Void);
376 let bb_t = f.create_block("then");
377 let bb_e = f.create_block("else");
378 f.block_mut(f.entry).terminator = Some(Terminator::CondBranch {
379 cond: ValueId(0),
380 true_dest: bb_t,
381 true_args: vec![],
382 false_dest: bb_e,
383 false_args: vec![],
384 });
385 f.block_mut(bb_t).terminator = Some(Terminator::Return(None));
386 f.block_mut(bb_e).terminator = Some(Terminator::Return(None));
387 m.add_function(f);
388
389 assert!(!ConstProp.run(&mut m));
390 assert!(matches!(
391 m.functions[0].blocks[0].terminator,
392 Some(Terminator::CondBranch { .. })
393 ));
394 }
395 }
396