Rust · 17418 bytes Raw Blame History
1 //! Common-subexpression elimination (local).
2 //!
3 //! Within a basic block, two pure instructions that compute the same
4 //! expression on the same operands produce the same value. We can drop
5 //! the second one and rewrite its uses to point at the first. The
6 //! invariant we rely on is **SSA**: an instruction's `ValueId` never
7 //! changes once defined, so a downstream rewrite is just a textual
8 //! substitution.
9 //!
10 //! ## Why "local"
11 //!
12 //! Local CSE only matches duplicates inside a single basic block. It's
13 //! cheap, never crosses control-flow merges, and is provably correct
14 //! without alias analysis or dominance walks. A separate global-CSE
15 //! pass (or GVN, which subsumes CSE) will land later for matches
16 //! across dominating blocks.
17 //!
18 //! ## Side effects
19 //!
20 //! Only **pure** instructions are CSE candidates. We deliberately do
21 //! not deduplicate `Load` here — even though two loads of the same
22 //! address are usually equivalent, an intervening `Store` or `Call`
23 //! may have written the location. The future load-store-forwarding
24 //! pass will handle that with proper dependence analysis.
25 //!
26 //! ## Commutativity
27 //!
28 //! For commutative operators we canonicalize the operand pair so that
29 //! `iadd(a, b)` and `iadd(b, a)` collide on the same key. The
30 //! canonical form puts the smaller `ValueId` first. Comparisons
31 //! `Eq`/`Ne` are also commutative; `Lt`/`Le`/`Gt`/`Ge` are not.
32
33 use super::pass::Pass;
34 use super::util::{for_each_operand_mut, for_each_terminator_operand_mut};
35 use crate::ir::inst::*;
36 use crate::ir::types::IrType;
37 use std::collections::HashMap;
38
/// A canonical key for one pure instruction.
///
/// Two instructions producing the same key are guaranteed to compute
/// the same value (modulo type, which is also encoded). The integer
/// `tag` disambiguates instruction kinds; the remaining fields encode
/// the operands and any per-kind immediate data.
///
/// The derived `Eq` + `Hash` are what let `Key` serve as a `HashMap`
/// key in the per-block "seen" table.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct Key {
    /// Per-instruction-kind discriminator assigned by `key_of`.
    tag: u32,
    /// Value operands, in canonical order: `key_of` sorts the pair for
    /// commutative operators and keeps source order otherwise.
    operands: Vec<ValueId>,
    /// Auxiliary integer (used for things like comparison op, bitwidth, etc.)
    aux: i128,
    /// Optional name for instructions whose value depends on a
    /// symbol (currently only `GlobalAddr`). Audit Med-2: a hashed
    /// aux risked theoretical SipHash13 collisions merging two
    /// different globals into one ADRP+ADD; the explicit name is
    /// soundness-bearing, not a performance hint.
    name: Option<String>,
    /// Result type — the same expression on the same operands with a
    /// different declared result type would be a bug, but the type is
    /// included in the key for safety.
    ty: IrType,
}
61
62 /// Build a canonical key for an instruction. Returns `None` if the
63 /// instruction is impure or otherwise not a CSE candidate.
64 fn key_of(inst: &Inst) -> Option<Key> {
65 let mk = |tag: u32, ops: Vec<ValueId>, aux: i128| -> Option<Key> {
66 Some(Key { tag, operands: ops, aux, name: None, ty: inst.ty.clone() })
67 };
68 let mk_named = |tag: u32, name: String| -> Option<Key> {
69 Some(Key { tag, operands: vec![], aux: 0, name: Some(name), ty: inst.ty.clone() })
70 };
71 let canon = |a: ValueId, b: ValueId| -> Vec<ValueId> {
72 if a.0 <= b.0 { vec![a, b] } else { vec![b, a] }
73 };
74
75 match &inst.kind {
76 // Pure address-of-global — no operands, pure function of
77 // the symbol name. Two ADRP+ADD pairs to the same global
78 // inside the same block should fold. Audit Med-3 (CSE-eligible)
79 // and Med-2 (collision-free key).
80 InstKind::GlobalAddr(name) => mk_named(90, name.clone()),
81
82 // Constants ------------------------------------------------------
83 // Audit Min-4: width is carried by the `ty` field, so the
84 // aux can just be the literal value / bit pattern.
85 // Audit B-8: normalize the int value at its declared width
86 // before keying so that semantically-equal constants stored
87 // at different bit patterns dedupe. Example: `ConstInt(255, I8)`
88 // and `ConstInt(-1, I8)` both represent -1 in i8 — keying on
89 // the raw `*v` fails to dedupe them.
90 InstKind::ConstInt(v, w) => {
91 let bits = w.bits();
92 // Sign-extend at width: low `bits` bits → i64 sign-extended.
93 let signed = if bits >= 128 {
94 *v
95 } else {
96 let shift = 128 - bits;
97 (*v << shift) >> shift
98 };
99 mk(1, vec![], signed)
100 }
101 InstKind::ConstFloat(v, _) => mk(2, vec![], v.to_bits() as i128),
102 InstKind::ConstBool(b) => mk(3, vec![], if *b { 1 } else { 0 }),
103
104 // Integer arithmetic --------------------------------------------
105 InstKind::IAdd(a, b) => mk(10, canon(*a, *b), 0),
106 InstKind::ISub(a, b) => mk(11, vec![*a, *b], 0),
107 InstKind::IMul(a, b) => mk(12, canon(*a, *b), 0),
108 InstKind::IDiv(a, b) => mk(13, vec![*a, *b], 0),
109 InstKind::IMod(a, b) => mk(14, vec![*a, *b], 0),
110 InstKind::INeg(a) => mk(15, vec![*a], 0),
111
112 // Float arithmetic ----------------------------------------------
113 InstKind::FAdd(a, b) => mk(20, canon(*a, *b), 0),
114 InstKind::FSub(a, b) => mk(21, vec![*a, *b], 0),
115 InstKind::FMul(a, b) => mk(22, canon(*a, *b), 0),
116 InstKind::FDiv(a, b) => mk(23, vec![*a, *b], 0),
117 InstKind::FNeg(a) => mk(24, vec![*a], 0),
118 InstKind::FAbs(a) => mk(25, vec![*a], 0),
119 InstKind::FSqrt(a) => mk(26, vec![*a], 0),
120 InstKind::FPow(a, b) => mk(27, vec![*a, *b], 0),
121
122 // Comparisons ---------------------------------------------------
123 InstKind::ICmp(op, a, b) => {
124 let aux = *op as i128;
125 let ops = match op { CmpOp::Eq | CmpOp::Ne => canon(*a, *b), _ => vec![*a, *b] };
126 mk(30, ops, aux)
127 }
128 InstKind::FCmp(op, a, b) => {
129 let aux = *op as i128;
130 let ops = match op { CmpOp::Eq | CmpOp::Ne => canon(*a, *b), _ => vec![*a, *b] };
131 mk(31, ops, aux)
132 }
133
134 // Logic ---------------------------------------------------------
135 InstKind::And(a, b) => mk(40, canon(*a, *b), 0),
136 InstKind::Or(a, b) => mk(41, canon(*a, *b), 0),
137 InstKind::Not(a) => mk(42, vec![*a], 0),
138
139 InstKind::Select(c, t, f) => mk(43, vec![*c, *t, *f], 0),
140
141 // Bitwise -------------------------------------------------------
142 InstKind::BitAnd(a, b) => mk(50, canon(*a, *b), 0),
143 InstKind::BitOr(a, b) => mk(51, canon(*a, *b), 0),
144 InstKind::BitXor(a, b) => mk(52, canon(*a, *b), 0),
145 InstKind::BitNot(a) => mk(53, vec![*a], 0),
146 InstKind::Shl(a, b) => mk(54, vec![*a, *b], 0),
147 InstKind::LShr(a, b) => mk(55, vec![*a, *b], 0),
148 InstKind::AShr(a, b) => mk(56, vec![*a, *b], 0),
149 InstKind::CountLeadingZeros(a) => mk(57, vec![*a], 0),
150 InstKind::CountTrailingZeros(a) => mk(58, vec![*a], 0),
151 InstKind::PopCount(a) => mk(59, vec![*a], 0),
152
153 // Conversions ---------------------------------------------------
154 InstKind::IntToFloat(v, fw) => mk(60, vec![*v], fw.bits() as i128),
155 InstKind::FloatToInt(v, w) => mk(61, vec![*v], w.bits() as i128),
156 InstKind::FloatExtend(v, fw) => mk(62, vec![*v], fw.bits() as i128),
157 InstKind::FloatTrunc(v, fw) => mk(63, vec![*v], fw.bits() as i128),
158 InstKind::IntExtend(v, w, sgn) => mk(64, vec![*v], (w.bits() as i128) | ((*sgn as i128) << 32)),
159 InstKind::IntTrunc(v, w) => mk(65, vec![*v], w.bits() as i128),
160
161 // Address arithmetic --------------------------------------------
162 InstKind::GetElementPtr(base, idxs) => {
163 let mut ops = vec![*base];
164 ops.extend(idxs);
165 mk(70, ops, 0)
166 }
167
168 // Aggregates ----------------------------------------------------
169 InstKind::ExtractField(agg, i) => mk(80, vec![*agg], *i as i128),
170 // InsertField produces a new aggregate value — pure but rarely
171 // duplicate; include for completeness.
172 InstKind::InsertField(agg, i, v) => mk(81, vec![*agg, *v], *i as i128),
173
174 // Impure / not handled ------------------------------------------
175 InstKind::Load(..)
176 | InstKind::Store(..)
177 | InstKind::Alloca(..)
178 | InstKind::Call(..)
179 | InstKind::RuntimeCall(..)
180 | InstKind::ConstString(..)
181 | InstKind::Undef(..) => None,
182 }
183 }
184
185 /// The local CSE pass.
186 pub struct LocalCse;
187
188 impl Pass for LocalCse {
189 fn name(&self) -> &'static str { "local-cse" }
190
191 fn run(&self, module: &mut Module) -> bool {
192 let mut changed = false;
193 for func in &mut module.functions {
194 // Collect all (old, new) rewrites first, then apply them
195 // in a *single* function walk. Audit Min-1: the previous
196 // version called `substitute_uses` once per rewrite, so a
197 // function with N CSE candidates ran N full walks for an
198 // overall O(N · function_size). The batched form is one
199 // walk with HashMap-driven renaming.
200 let mut rewrite_map: HashMap<ValueId, ValueId> = HashMap::new();
201 for block in &func.blocks {
202 let mut seen: HashMap<Key, ValueId> = HashMap::new();
203 for inst in &block.insts {
204 let Some(k) = key_of(inst) else { continue };
205 if let Some(&first) = seen.get(&k) {
206 rewrite_map.insert(inst.id, first);
207 } else {
208 seen.insert(k, inst.id);
209 }
210 }
211 }
212 if rewrite_map.is_empty() { continue; }
213
214 // Audit B-7: in **local** CSE, every entry maps a later
215 // duplicate to its block's *first* occurrence — and that
216 // first occurrence is, by construction, never itself a
217 // key in the map. So pointer chains never form, and the
218 // chase loop the first version had would always exit
219 // after zero iterations. Removed. If a future global
220 // CSE / GVN pass reuses this map shape and CAN produce
221 // chains, the chase logic will need to come back —
222 // strict-decrease in ValueId guarantees termination.
223 substitute_uses_batch(func, &rewrite_map);
224 changed = true;
225 }
226 changed
227 }
228 }
229
230 /// Replace every operand `old` with `rewrite_map[old]` (if any) in a
231 /// single walk over the function. Pairs with Min-1: avoids the
232 /// O(N · size) cost of calling `substitute_uses` once per rename.
233 /// Audit B-6: delegates to `walk::for_each_*_operand_mut`.
234 ///
235 /// Audit B-9: the closure `r` captures `rewrites` by shared
236 /// reference and is therefore `Copy`. This is what lets us pass
237 /// `r` by value into multiple `for_each_operand_mut` calls inside
238 /// the per-block loop. If the closure ever needs mutable state
239 /// (e.g., to count rewrites), it stops being `Copy` and the loop
240 /// must shift to `&mut r` — at which point the helper signatures
241 /// in `walk.rs` would need to take `&mut impl FnMut(...)` instead
242 /// of `mut r: impl FnMut(...)`.
243 fn substitute_uses_batch(func: &mut Function, rewrites: &HashMap<ValueId, ValueId>) {
244 let r = |v: &mut ValueId| {
245 if let Some(&new) = rewrites.get(v) {
246 *v = new;
247 }
248 };
249 for block in &mut func.blocks {
250 for inst in &mut block.insts {
251 for_each_operand_mut(&mut inst.kind, r);
252 }
253 if let Some(term) = &mut block.terminator {
254 for_each_terminator_operand_mut(term, r);
255 }
256 }
257 }
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::types::{IrType, IntWidth, FloatWidth};
    use crate::lexer::{Span, Position};

    /// A throwaway source span; the pass never inspects spans.
    fn dummy_span() -> Span {
        let p = Position { line: 1, col: 1 };
        Span { start: p, end: p, file_id: 0 }
    }

    /// Shorthand for the 32-bit integer type used by most tests.
    fn i32_ty() -> IrType {
        IrType::Int(IntWidth::I32)
    }

    /// Shorthand for the 64-bit float type.
    fn f64_ty() -> IrType {
        IrType::Float(FloatWidth::F64)
    }

    /// Create an empty, parameterless function `f` with return type `ret`.
    fn new_fn(ret: IrType) -> Function {
        Function::new("f".into(), vec![], ret)
    }

    /// Append an instruction to the entry block and return its fresh id.
    fn push(f: &mut Function, kind: InstKind, ty: IrType) -> ValueId {
        let id = f.next_value_id();
        let entry = f.entry;
        f.block_mut(entry).insts.push(Inst { id, kind, ty, span: dummy_span() });
        id
    }

    /// Terminate the entry block with `ret <ret>` and wrap the function
    /// in a fresh single-function module.
    fn module_returning(mut f: Function, ret: Option<ValueId>) -> Module {
        let entry = f.entry;
        f.block_mut(entry).terminator = Some(Terminator::Return(ret));
        let mut m = Module::new("t".into());
        m.add_function(f);
        m
    }

    /// The value returned by the first block of the first function.
    fn returned_value(m: &Module) -> ValueId {
        match &m.functions[0].blocks[0].terminator {
            Some(Terminator::Return(Some(v))) => *v,
            _ => panic!(),
        }
    }

    #[test]
    fn dedupes_iadd_pair() {
        // %0 = const 1
        // %1 = const 2
        // %2 = iadd %0, %1
        // %3 = iadd %0, %1   ; same as %2
        // ret %3  → after CSE → ret %2 (and %3 is dead)
        let mut f = new_fn(i32_ty());
        let one = push(&mut f, InstKind::ConstInt(1, IntWidth::I32), i32_ty());
        let two = push(&mut f, InstKind::ConstInt(2, IntWidth::I32), i32_ty());
        let first = push(&mut f, InstKind::IAdd(one, two), i32_ty());
        let dup = push(&mut f, InstKind::IAdd(one, two), i32_ty());
        let mut m = module_returning(f, Some(dup));

        assert!(LocalCse.run(&mut m));
        // The terminator now references the first add, not the duplicate.
        assert_eq!(returned_value(&m), first);
    }

    #[test]
    fn commutative_iadd_dedupes_swapped_operands() {
        // %2 = iadd %0, %1
        // %3 = iadd %1, %0
        // Both must canonicalize to the same key.
        let mut f = new_fn(i32_ty());
        let one = push(&mut f, InstKind::ConstInt(1, IntWidth::I32), i32_ty());
        let two = push(&mut f, InstKind::ConstInt(2, IntWidth::I32), i32_ty());
        let first = push(&mut f, InstKind::IAdd(one, two), i32_ty());
        let swapped = push(&mut f, InstKind::IAdd(two, one), i32_ty());
        let mut m = module_returning(f, Some(swapped));

        assert!(LocalCse.run(&mut m));
        assert_eq!(returned_value(&m), first);
    }

    #[test]
    fn non_commutative_isub_does_not_dedupe_swapped() {
        let mut f = new_fn(i32_ty());
        let five = push(&mut f, InstKind::ConstInt(5, IntWidth::I32), i32_ty());
        let three = push(&mut f, InstKind::ConstInt(3, IntWidth::I32), i32_ty());
        let _first = push(&mut f, InstKind::ISub(five, three), i32_ty());
        let swapped = push(&mut f, InstKind::ISub(three, five), i32_ty());
        let mut m = module_returning(f, Some(swapped));

        // No rewrite is possible, so the pass reports no change.
        assert!(!LocalCse.run(&mut m));
    }

    #[test]
    fn keeps_load_pair_intact() {
        // Loads must NOT be deduplicated by local CSE.
        let mut f = new_fn(IrType::Void);
        let addr = push(
            &mut f,
            InstKind::Alloca(i32_ty()),
            IrType::Ptr(Box::new(i32_ty())),
        );
        let _l1 = push(&mut f, InstKind::Load(addr), i32_ty());
        let _l2 = push(&mut f, InstKind::Load(addr), i32_ty());
        let mut m = module_returning(f, None);

        assert!(!LocalCse.run(&mut m));
    }

    #[test]
    fn fmul_dedupes() {
        let mut f = new_fn(f64_ty());
        let x = push(&mut f, InstKind::ConstFloat(1.5, FloatWidth::F64), f64_ty());
        let y = push(&mut f, InstKind::ConstFloat(2.5, FloatWidth::F64), f64_ty());
        let first = push(&mut f, InstKind::FMul(x, y), f64_ty());
        let swapped = push(&mut f, InstKind::FMul(y, x), f64_ty());
        let mut m = module_returning(f, Some(swapped));

        assert!(LocalCse.run(&mut m));
        assert_eq!(returned_value(&m), first);
    }

    #[test]
    fn icmp_lt_not_canonicalized() {
        // Lt is not commutative — the swapped pair must not collapse.
        let mut f = new_fn(IrType::Bool);
        let one = push(&mut f, InstKind::ConstInt(1, IntWidth::I32), i32_ty());
        let two = push(&mut f, InstKind::ConstInt(2, IntWidth::I32), i32_ty());
        let _first = push(&mut f, InstKind::ICmp(CmpOp::Lt, one, two), IrType::Bool);
        let swapped = push(&mut f, InstKind::ICmp(CmpOp::Lt, two, one), IrType::Bool);
        let mut m = module_returning(f, Some(swapped));

        assert!(!LocalCse.run(&mut m));
    }
}
394