Rust · 18378 bytes Raw Blame History
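//! Common-subexpression elimination (local).
//!
//! Within a basic block, two pure instructions that compute the same
//! expression on the same operands produce the same value, so every use
//! of the second can be rewritten to point at the first. This pass only
//! performs that rewrite: the duplicate instruction stays in the block,
//! now unused, for a later dead-code pass to remove. Soundness rests on
//! **SSA**: each `ValueId` is defined exactly once and never redefined,
//! and the first occurrence precedes the duplicate in the same block, so
//! it is available at every use of the duplicate and the rewrite is a
//! plain substitution of one `ValueId` for another.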
9 //!
10 //! ## Why "local"
11 //!
12 //! Local CSE only matches duplicates inside a single basic block. It's
13 //! cheap, never crosses control-flow merges, and is provably correct
14 //! without alias analysis or dominance walks. A separate global-CSE
15 //! pass (or GVN, which subsumes CSE) will land later for matches
16 //! across dominating blocks.
17 //!
18 //! ## Side effects
19 //!
20 //! Only **pure** instructions are CSE candidates. We deliberately do
21 //! not deduplicate `Load` here — even though two loads of the same
22 //! address are usually equivalent, an intervening `Store` or `Call`
23 //! may have written the location. The future load-store-forwarding
24 //! pass will handle that with proper dependence analysis.
25 //!
26 //! ## Commutativity
27 //!
28 //! For commutative operators we canonicalize the operand pair so that
29 //! `iadd(a, b)` and `iadd(b, a)` collide on the same key. The
30 //! canonical form puts the smaller `ValueId` first. Comparisons
31 //! `Eq`/`Ne` are also commutative; `Lt`/`Le`/`Gt`/`Ge` are not.
32
33 use super::pass::Pass;
34 use super::util::{for_each_operand_mut, for_each_terminator_operand_mut};
35 use crate::ir::inst::*;
36 use crate::ir::types::IrType;
37 use std::collections::HashMap;
38
39 /// A canonical key for one pure instruction.
40 ///
41 /// Two instructions producing the same key are guaranteed to compute
42 /// the same value (modulo type, which is also encoded). The integer
43 /// "tag" disambiguates instruction kinds; the rest of the tuple
44 /// encodes the operands.
45 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
46 struct Key {
47 tag: u32,
48 operands: Vec<ValueId>,
49 /// Auxiliary integer (used for things like comparison op, bitwidth, etc.)
50 aux: i128,
51 /// Optional name for instructions whose value depends on a
52 /// symbol (currently only `GlobalAddr`). Audit Med-2: a hashed
53 /// aux risked theoretical SipHash13 collisions merging two
54 /// different globals into one ADRP+ADD; the explicit name is
55 /// soundness-bearing, not a performance hint.
56 name: Option<String>,
57 /// Result type — same expression on same operands but different
58 /// declared result type would be a bug, but we include it for safety.
59 ty: IrType,
60 }
61
62 /// Build a canonical key for an instruction. Returns `None` if the
63 /// instruction is impure or otherwise not a CSE candidate.
64 fn key_of(inst: &Inst) -> Option<Key> {
65 let mk = |tag: u32, ops: Vec<ValueId>, aux: i128| -> Option<Key> {
66 Some(Key {
67 tag,
68 operands: ops,
69 aux,
70 name: None,
71 ty: inst.ty.clone(),
72 })
73 };
74 let mk_named = |tag: u32, name: String| -> Option<Key> {
75 Some(Key {
76 tag,
77 operands: vec![],
78 aux: 0,
79 name: Some(name),
80 ty: inst.ty.clone(),
81 })
82 };
83 let canon = |a: ValueId, b: ValueId| -> Vec<ValueId> {
84 if a.0 <= b.0 {
85 vec![a, b]
86 } else {
87 vec![b, a]
88 }
89 };
90
91 match &inst.kind {
92 // Pure address-of-global — no operands, pure function of
93 // the symbol name. Two ADRP+ADD pairs to the same global
94 // inside the same block should fold. Audit Med-3 (CSE-eligible)
95 // and Med-2 (collision-free key).
96 InstKind::GlobalAddr(name) => mk_named(90, name.clone()),
97
98 // Constants ------------------------------------------------------
99 // Audit Min-4: width is carried by the `ty` field, so the
100 // aux can just be the literal value / bit pattern.
101 // Audit B-8: normalize the int value at its declared width
102 // before keying so that semantically-equal constants stored
103 // at different bit patterns dedupe. Example: `ConstInt(255, I8)`
104 // and `ConstInt(-1, I8)` both represent -1 in i8 — keying on
105 // the raw `*v` fails to dedupe them.
106 InstKind::ConstInt(v, w) => {
107 let bits = w.bits();
108 // Sign-extend at width: low `bits` bits → i64 sign-extended.
109 let signed = if bits >= 128 {
110 *v
111 } else {
112 let shift = 128 - bits;
113 (*v << shift) >> shift
114 };
115 mk(1, vec![], signed)
116 }
117 InstKind::ConstFloat(v, _) => mk(2, vec![], v.to_bits() as i128),
118 InstKind::ConstBool(b) => mk(3, vec![], if *b { 1 } else { 0 }),
119
120 // Integer arithmetic --------------------------------------------
121 InstKind::IAdd(a, b) => mk(10, canon(*a, *b), 0),
122 InstKind::ISub(a, b) => mk(11, vec![*a, *b], 0),
123 InstKind::IMul(a, b) => mk(12, canon(*a, *b), 0),
124 InstKind::IDiv(a, b) => mk(13, vec![*a, *b], 0),
125 InstKind::IMod(a, b) => mk(14, vec![*a, *b], 0),
126 InstKind::INeg(a) => mk(15, vec![*a], 0),
127
128 // Float arithmetic ----------------------------------------------
129 InstKind::FAdd(a, b) => mk(20, canon(*a, *b), 0),
130 InstKind::FSub(a, b) => mk(21, vec![*a, *b], 0),
131 InstKind::FMul(a, b) => mk(22, canon(*a, *b), 0),
132 InstKind::FDiv(a, b) => mk(23, vec![*a, *b], 0),
133 InstKind::FNeg(a) => mk(24, vec![*a], 0),
134 InstKind::FAbs(a) => mk(25, vec![*a], 0),
135 InstKind::FSqrt(a) => mk(26, vec![*a], 0),
136 InstKind::FPow(a, b) => mk(27, vec![*a, *b], 0),
137
138 // Comparisons ---------------------------------------------------
139 InstKind::ICmp(op, a, b) => {
140 let aux = *op as i128;
141 let ops = match op {
142 CmpOp::Eq | CmpOp::Ne => canon(*a, *b),
143 _ => vec![*a, *b],
144 };
145 mk(30, ops, aux)
146 }
147 InstKind::FCmp(op, a, b) => {
148 let aux = *op as i128;
149 let ops = match op {
150 CmpOp::Eq | CmpOp::Ne => canon(*a, *b),
151 _ => vec![*a, *b],
152 };
153 mk(31, ops, aux)
154 }
155
156 // Logic ---------------------------------------------------------
157 InstKind::And(a, b) => mk(40, canon(*a, *b), 0),
158 InstKind::Or(a, b) => mk(41, canon(*a, *b), 0),
159 InstKind::Not(a) => mk(42, vec![*a], 0),
160
161 InstKind::Select(c, t, f) => mk(43, vec![*c, *t, *f], 0),
162
163 // Bitwise -------------------------------------------------------
164 InstKind::BitAnd(a, b) => mk(50, canon(*a, *b), 0),
165 InstKind::BitOr(a, b) => mk(51, canon(*a, *b), 0),
166 InstKind::BitXor(a, b) => mk(52, canon(*a, *b), 0),
167 InstKind::BitNot(a) => mk(53, vec![*a], 0),
168 InstKind::Shl(a, b) => mk(54, vec![*a, *b], 0),
169 InstKind::LShr(a, b) => mk(55, vec![*a, *b], 0),
170 InstKind::AShr(a, b) => mk(56, vec![*a, *b], 0),
171 InstKind::CountLeadingZeros(a) => mk(57, vec![*a], 0),
172 InstKind::CountTrailingZeros(a) => mk(58, vec![*a], 0),
173 InstKind::PopCount(a) => mk(59, vec![*a], 0),
174
175 // Conversions ---------------------------------------------------
176 InstKind::IntToFloat(v, fw) => mk(60, vec![*v], fw.bits() as i128),
177 InstKind::FloatToInt(v, w) => mk(61, vec![*v], w.bits() as i128),
178 InstKind::FloatExtend(v, fw) => mk(62, vec![*v], fw.bits() as i128),
179 InstKind::FloatTrunc(v, fw) => mk(63, vec![*v], fw.bits() as i128),
180 InstKind::IntExtend(v, w, sgn) => {
181 mk(64, vec![*v], (w.bits() as i128) | ((*sgn as i128) << 32))
182 }
183 InstKind::IntTrunc(v, w) => mk(65, vec![*v], w.bits() as i128),
184 InstKind::PtrToInt(v) => mk(66, vec![*v], 0),
185 InstKind::IntToPtr(v, _) => mk(67, vec![*v], 0),
186
187 // Address arithmetic --------------------------------------------
188 InstKind::GetElementPtr(base, idxs) => {
189 let mut ops = vec![*base];
190 ops.extend(idxs);
191 mk(70, ops, 0)
192 }
193
194 // Aggregates ----------------------------------------------------
195 InstKind::ExtractField(agg, i) => mk(80, vec![*agg], *i as i128),
196 // InsertField produces a new aggregate value — pure but rarely
197 // duplicate; include for completeness.
198 InstKind::InsertField(agg, i, v) => mk(81, vec![*agg, *v], *i as i128),
199
200 // Impure / not handled ------------------------------------------
201 InstKind::Load(..)
202 | InstKind::Store(..)
203 | InstKind::Alloca(..)
204 | InstKind::Call(..)
205 | InstKind::RuntimeCall(..)
206 | InstKind::ConstString(..)
207 | InstKind::Undef(..) => None,
208 }
209 }
210
211 /// The local CSE pass.
212 pub struct LocalCse;
213
214 impl Pass for LocalCse {
215 fn name(&self) -> &'static str {
216 "local-cse"
217 }
218
219 fn run(&self, module: &mut Module) -> bool {
220 let mut changed = false;
221 for func in &mut module.functions {
222 // Collect all (old, new) rewrites first, then apply them
223 // in a *single* function walk. Audit Min-1: the previous
224 // version called `substitute_uses` once per rewrite, so a
225 // function with N CSE candidates ran N full walks for an
226 // overall O(N · function_size). The batched form is one
227 // walk with HashMap-driven renaming.
228 let mut rewrite_map: HashMap<ValueId, ValueId> = HashMap::new();
229 for block in &func.blocks {
230 let mut seen: HashMap<Key, ValueId> = HashMap::new();
231 for inst in &block.insts {
232 let Some(k) = key_of(inst) else { continue };
233 if let Some(&first) = seen.get(&k) {
234 rewrite_map.insert(inst.id, first);
235 } else {
236 seen.insert(k, inst.id);
237 }
238 }
239 }
240 if rewrite_map.is_empty() {
241 continue;
242 }
243
244 // Audit B-7: in **local** CSE, every entry maps a later
245 // duplicate to its block's *first* occurrence — and that
246 // first occurrence is, by construction, never itself a
247 // key in the map. So pointer chains never form, and the
248 // chase loop the first version had would always exit
249 // after zero iterations. Removed. If a future global
250 // CSE / GVN pass reuses this map shape and CAN produce
251 // chains, the chase logic will need to come back —
252 // strict-decrease in ValueId guarantees termination.
253 substitute_uses_batch(func, &rewrite_map);
254 changed = true;
255 }
256 changed
257 }
258 }
259
260 /// Replace every operand `old` with `rewrite_map[old]` (if any) in a
261 /// single walk over the function. Pairs with Min-1: avoids the
262 /// O(N · size) cost of calling `substitute_uses` once per rename.
263 /// Audit B-6: delegates to `walk::for_each_*_operand_mut`.
264 ///
265 /// Audit B-9: the closure `r` captures `rewrites` by shared
266 /// reference and is therefore `Copy`. This is what lets us pass
267 /// `r` by value into multiple `for_each_operand_mut` calls inside
268 /// the per-block loop. If the closure ever needs mutable state
269 /// (e.g., to count rewrites), it stops being `Copy` and the loop
270 /// must shift to `&mut r` — at which point the helper signatures
271 /// in `walk.rs` would need to take `&mut impl FnMut(...)` instead
272 /// of `mut r: impl FnMut(...)`.
273 fn substitute_uses_batch(func: &mut Function, rewrites: &HashMap<ValueId, ValueId>) {
274 let r = |v: &mut ValueId| {
275 if let Some(&new) = rewrites.get(v) {
276 *v = new;
277 }
278 };
279 for block in &mut func.blocks {
280 for inst in &mut block.insts {
281 for_each_operand_mut(&mut inst.kind, r);
282 }
283 if let Some(term) = &mut block.terminator {
284 for_each_terminator_operand_mut(term, r);
285 }
286 }
287 }
288
289 #[cfg(test)]
290 mod tests {
291 use super::*;
292 use crate::ir::types::{FloatWidth, IntWidth, IrType};
293 use crate::lexer::{Position, Span};
294
295 fn dummy_span() -> Span {
296 let p = Position { line: 1, col: 1 };
297 Span {
298 start: p,
299 end: p,
300 file_id: 0,
301 }
302 }
303
304 fn push(f: &mut Function, kind: InstKind, ty: IrType) -> ValueId {
305 let id = f.next_value_id();
306 let entry = f.entry;
307 f.block_mut(entry).insts.push(Inst {
308 id,
309 kind,
310 ty,
311 span: dummy_span(),
312 });
313 id
314 }
315
316 #[test]
317 fn dedupes_iadd_pair() {
318 // %0 = const 1
319 // %1 = const 2
320 // %2 = iadd %0, %1
321 // %3 = iadd %0, %1 ; same as %2
322 // ret %3 → after CSE → ret %2 (and %3 is dead)
323 let mut m = Module::new("t".into());
324 let mut f = Function::new("f".into(), vec![], IrType::Int(IntWidth::I32));
325 let a = push(
326 &mut f,
327 InstKind::ConstInt(1, IntWidth::I32),
328 IrType::Int(IntWidth::I32),
329 );
330 let b = push(
331 &mut f,
332 InstKind::ConstInt(2, IntWidth::I32),
333 IrType::Int(IntWidth::I32),
334 );
335 let c1 = push(&mut f, InstKind::IAdd(a, b), IrType::Int(IntWidth::I32));
336 let c2 = push(&mut f, InstKind::IAdd(a, b), IrType::Int(IntWidth::I32));
337 let entry = f.entry;
338 f.block_mut(entry).terminator = Some(Terminator::Return(Some(c2)));
339 m.add_function(f);
340
341 assert!(LocalCse.run(&mut m));
342 // Terminator now references c1 instead of c2.
343 match &m.functions[0].blocks[0].terminator {
344 Some(Terminator::Return(Some(v))) => assert_eq!(*v, c1),
345 _ => panic!(),
346 }
347 }
348
349 #[test]
350 fn commutative_iadd_dedupes_swapped_operands() {
351 // %2 = iadd %0, %1
352 // %3 = iadd %1, %0
353 // Should canonicalize to the same key.
354 let mut m = Module::new("t".into());
355 let mut f = Function::new("f".into(), vec![], IrType::Int(IntWidth::I32));
356 let a = push(
357 &mut f,
358 InstKind::ConstInt(1, IntWidth::I32),
359 IrType::Int(IntWidth::I32),
360 );
361 let b = push(
362 &mut f,
363 InstKind::ConstInt(2, IntWidth::I32),
364 IrType::Int(IntWidth::I32),
365 );
366 let c1 = push(&mut f, InstKind::IAdd(a, b), IrType::Int(IntWidth::I32));
367 let c2 = push(&mut f, InstKind::IAdd(b, a), IrType::Int(IntWidth::I32));
368 let _ = c2;
369 let entry = f.entry;
370 f.block_mut(entry).terminator = Some(Terminator::Return(Some(c2)));
371 m.add_function(f);
372
373 assert!(LocalCse.run(&mut m));
374 match &m.functions[0].blocks[0].terminator {
375 Some(Terminator::Return(Some(v))) => assert_eq!(*v, c1),
376 _ => panic!(),
377 }
378 }
379
380 #[test]
381 fn non_commutative_isub_does_not_dedupe_swapped() {
382 let mut m = Module::new("t".into());
383 let mut f = Function::new("f".into(), vec![], IrType::Int(IntWidth::I32));
384 let a = push(
385 &mut f,
386 InstKind::ConstInt(5, IntWidth::I32),
387 IrType::Int(IntWidth::I32),
388 );
389 let b = push(
390 &mut f,
391 InstKind::ConstInt(3, IntWidth::I32),
392 IrType::Int(IntWidth::I32),
393 );
394 let _c1 = push(&mut f, InstKind::ISub(a, b), IrType::Int(IntWidth::I32));
395 let c2 = push(&mut f, InstKind::ISub(b, a), IrType::Int(IntWidth::I32));
396 let entry = f.entry;
397 f.block_mut(entry).terminator = Some(Terminator::Return(Some(c2)));
398 m.add_function(f);
399
400 // Returns c2 unchanged — no rewrite was possible.
401 assert!(!LocalCse.run(&mut m));
402 }
403
404 #[test]
405 fn keeps_load_pair_intact() {
406 // Loads must NOT be deduplicated by local CSE.
407 let mut m = Module::new("t".into());
408 let mut f = Function::new("f".into(), vec![], IrType::Void);
409 let addr = push(
410 &mut f,
411 InstKind::Alloca(IrType::Int(IntWidth::I32)),
412 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
413 );
414 let _l1 = push(&mut f, InstKind::Load(addr), IrType::Int(IntWidth::I32));
415 let _l2 = push(&mut f, InstKind::Load(addr), IrType::Int(IntWidth::I32));
416 let entry = f.entry;
417 f.block_mut(entry).terminator = Some(Terminator::Return(None));
418 m.add_function(f);
419
420 assert!(!LocalCse.run(&mut m));
421 }
422
423 #[test]
424 fn fmul_dedupes() {
425 let mut m = Module::new("t".into());
426 let mut f = Function::new("f".into(), vec![], IrType::Float(FloatWidth::F64));
427 let a = push(
428 &mut f,
429 InstKind::ConstFloat(1.5, FloatWidth::F64),
430 IrType::Float(FloatWidth::F64),
431 );
432 let b = push(
433 &mut f,
434 InstKind::ConstFloat(2.5, FloatWidth::F64),
435 IrType::Float(FloatWidth::F64),
436 );
437 let m1 = push(&mut f, InstKind::FMul(a, b), IrType::Float(FloatWidth::F64));
438 let m2 = push(&mut f, InstKind::FMul(b, a), IrType::Float(FloatWidth::F64));
439 let entry = f.entry;
440 f.block_mut(entry).terminator = Some(Terminator::Return(Some(m2)));
441 m.add_function(f);
442
443 assert!(LocalCse.run(&mut m));
444 match &m.functions[0].blocks[0].terminator {
445 Some(Terminator::Return(Some(v))) => assert_eq!(*v, m1),
446 _ => panic!(),
447 }
448 }
449
450 #[test]
451 fn icmp_lt_not_canonicalized() {
452 // Lt is not commutative — must not collapse.
453 let mut m = Module::new("t".into());
454 let mut f = Function::new("f".into(), vec![], IrType::Bool);
455 let a = push(
456 &mut f,
457 InstKind::ConstInt(1, IntWidth::I32),
458 IrType::Int(IntWidth::I32),
459 );
460 let b = push(
461 &mut f,
462 InstKind::ConstInt(2, IntWidth::I32),
463 IrType::Int(IntWidth::I32),
464 );
465 let _c1 = push(&mut f, InstKind::ICmp(CmpOp::Lt, a, b), IrType::Bool);
466 let c2 = push(&mut f, InstKind::ICmp(CmpOp::Lt, b, a), IrType::Bool);
467 let entry = f.entry;
468 f.block_mut(entry).terminator = Some(Terminator::Return(Some(c2)));
469 m.add_function(f);
470
471 assert!(!LocalCse.run(&mut m));
472 }
473 }
474