Rust · 19249 bytes Raw Blame History
//! Common-subexpression elimination (local).
//!
//! Within a basic block, two pure instructions that compute the same
//! expression on the same operands produce the same value. We can drop
//! the second one and rewrite its uses to point at the first. The
//! invariant we rely on is **SSA**: an instruction's `ValueId` never
//! changes once defined, so a downstream rewrite is just a textual
//! substitution.
//!
//! ## Why "local"
//!
//! Local CSE only matches duplicates inside a single basic block. It's
//! cheap, never crosses control-flow merges, and is provably correct
//! without alias analysis or dominance walks. A separate global-CSE
//! pass (or GVN, which subsumes CSE) will land later for matches
//! across dominating blocks.
//!
//! ## Side effects
//!
//! Only **pure** instructions are CSE candidates. We deliberately do
//! not deduplicate `Load` here — even though two loads of the same
//! address are usually equivalent, an intervening `Store` or `Call`
//! may have written the location. The future load-store-forwarding
//! pass will handle that with proper dependence analysis.
//!
//! ## Commutativity
//!
//! For commutative operators we canonicalize the operand pair so that
//! `iadd(a, b)` and `iadd(b, a)` collide on the same key. The
//! canonical form puts the smaller `ValueId` first. Comparisons
//! `Eq`/`Ne` are also commutative; `Lt`/`Le`/`Gt`/`Ge` are not.
32
33 use super::pass::Pass;
34 use super::util::{for_each_operand_mut, for_each_terminator_operand_mut};
35 use crate::ir::inst::*;
36 use crate::ir::types::IrType;
37 use std::collections::HashMap;
38
39 /// A canonical key for one pure instruction.
40 ///
41 /// Two instructions producing the same key are guaranteed to compute
42 /// the same value (modulo type, which is also encoded). The integer
43 /// "tag" disambiguates instruction kinds; the rest of the tuple
44 /// encodes the operands.
45 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
46 struct Key {
47 tag: u32,
48 operands: Vec<ValueId>,
49 /// Auxiliary integer (used for things like comparison op, bitwidth, etc.)
50 aux: i128,
51 /// Optional name for instructions whose value depends on a
52 /// symbol (currently only `GlobalAddr`). Audit Med-2: a hashed
53 /// aux risked theoretical SipHash13 collisions merging two
54 /// different globals into one ADRP+ADD; the explicit name is
55 /// soundness-bearing, not a performance hint.
56 name: Option<String>,
57 /// Result type — same expression on same operands but different
58 /// declared result type would be a bug, but we include it for safety.
59 ty: IrType,
60 }
61
62 /// Build a canonical key for an instruction. Returns `None` if the
63 /// instruction is impure or otherwise not a CSE candidate.
64 fn key_of(inst: &Inst) -> Option<Key> {
65 let mk = |tag: u32, ops: Vec<ValueId>, aux: i128| -> Option<Key> {
66 Some(Key {
67 tag,
68 operands: ops,
69 aux,
70 name: None,
71 ty: inst.ty.clone(),
72 })
73 };
74 let mk_named = |tag: u32, name: String| -> Option<Key> {
75 Some(Key {
76 tag,
77 operands: vec![],
78 aux: 0,
79 name: Some(name),
80 ty: inst.ty.clone(),
81 })
82 };
83 let canon = |a: ValueId, b: ValueId| -> Vec<ValueId> {
84 if a.0 <= b.0 {
85 vec![a, b]
86 } else {
87 vec![b, a]
88 }
89 };
90
91 match &inst.kind {
92 // Pure address-of-global — no operands, pure function of
93 // the symbol name. Two ADRP+ADD pairs to the same global
94 // inside the same block should fold. Audit Med-3 (CSE-eligible)
95 // and Med-2 (collision-free key).
96 InstKind::GlobalAddr(name) => mk_named(90, name.clone()),
97
98 // Constants ------------------------------------------------------
99 // Audit Min-4: width is carried by the `ty` field, so the
100 // aux can just be the literal value / bit pattern.
101 // Audit B-8: normalize the int value at its declared width
102 // before keying so that semantically-equal constants stored
103 // at different bit patterns dedupe. Example: `ConstInt(255, I8)`
104 // and `ConstInt(-1, I8)` both represent -1 in i8 — keying on
105 // the raw `*v` fails to dedupe them.
106 InstKind::ConstInt(v, w) => {
107 let bits = w.bits();
108 // Sign-extend at width: low `bits` bits → i64 sign-extended.
109 let signed = if bits >= 128 {
110 *v
111 } else {
112 let shift = 128 - bits;
113 (*v << shift) >> shift
114 };
115 mk(1, vec![], signed)
116 }
117 InstKind::ConstFloat(v, _) => mk(2, vec![], v.to_bits() as i128),
118 InstKind::ConstBool(b) => mk(3, vec![], if *b { 1 } else { 0 }),
119
120 // Integer arithmetic --------------------------------------------
121 InstKind::IAdd(a, b) => mk(10, canon(*a, *b), 0),
122 InstKind::ISub(a, b) => mk(11, vec![*a, *b], 0),
123 InstKind::IMul(a, b) => mk(12, canon(*a, *b), 0),
124 InstKind::IDiv(a, b) => mk(13, vec![*a, *b], 0),
125 InstKind::IMod(a, b) => mk(14, vec![*a, *b], 0),
126 InstKind::INeg(a) => mk(15, vec![*a], 0),
127
128 // Float arithmetic ----------------------------------------------
129 InstKind::FAdd(a, b) => mk(20, canon(*a, *b), 0),
130 InstKind::FSub(a, b) => mk(21, vec![*a, *b], 0),
131 InstKind::FMul(a, b) => mk(22, canon(*a, *b), 0),
132 InstKind::FDiv(a, b) => mk(23, vec![*a, *b], 0),
133 InstKind::FNeg(a) => mk(24, vec![*a], 0),
134 InstKind::FAbs(a) => mk(25, vec![*a], 0),
135 InstKind::FSqrt(a) => mk(26, vec![*a], 0),
136 InstKind::FPow(a, b) => mk(27, vec![*a, *b], 0),
137
138 // Comparisons ---------------------------------------------------
139 InstKind::ICmp(op, a, b) => {
140 let aux = *op as i128;
141 let ops = match op {
142 CmpOp::Eq | CmpOp::Ne => canon(*a, *b),
143 _ => vec![*a, *b],
144 };
145 mk(30, ops, aux)
146 }
147 InstKind::FCmp(op, a, b) => {
148 let aux = *op as i128;
149 let ops = match op {
150 CmpOp::Eq | CmpOp::Ne => canon(*a, *b),
151 _ => vec![*a, *b],
152 };
153 mk(31, ops, aux)
154 }
155
156 // Logic ---------------------------------------------------------
157 InstKind::And(a, b) => mk(40, canon(*a, *b), 0),
158 InstKind::Or(a, b) => mk(41, canon(*a, *b), 0),
159 InstKind::Not(a) => mk(42, vec![*a], 0),
160
161 InstKind::Select(c, t, f) => mk(43, vec![*c, *t, *f], 0),
162
163 // Bitwise -------------------------------------------------------
164 InstKind::BitAnd(a, b) => mk(50, canon(*a, *b), 0),
165 InstKind::BitOr(a, b) => mk(51, canon(*a, *b), 0),
166 InstKind::BitXor(a, b) => mk(52, canon(*a, *b), 0),
167 InstKind::BitNot(a) => mk(53, vec![*a], 0),
168 InstKind::Shl(a, b) => mk(54, vec![*a, *b], 0),
169 InstKind::LShr(a, b) => mk(55, vec![*a, *b], 0),
170 InstKind::AShr(a, b) => mk(56, vec![*a, *b], 0),
171 InstKind::CountLeadingZeros(a) => mk(57, vec![*a], 0),
172 InstKind::CountTrailingZeros(a) => mk(58, vec![*a], 0),
173 InstKind::PopCount(a) => mk(59, vec![*a], 0),
174
175 // Conversions ---------------------------------------------------
176 InstKind::IntToFloat(v, fw) => mk(60, vec![*v], fw.bits() as i128),
177 InstKind::FloatToInt(v, w) => mk(61, vec![*v], w.bits() as i128),
178 InstKind::FloatExtend(v, fw) => mk(62, vec![*v], fw.bits() as i128),
179 InstKind::FloatTrunc(v, fw) => mk(63, vec![*v], fw.bits() as i128),
180 InstKind::IntExtend(v, w, sgn) => {
181 mk(64, vec![*v], (w.bits() as i128) | ((*sgn as i128) << 32))
182 }
183 InstKind::IntTrunc(v, w) => mk(65, vec![*v], w.bits() as i128),
184 InstKind::PtrToInt(v) => mk(66, vec![*v], 0),
185 InstKind::IntToPtr(v, _) => mk(67, vec![*v], 0),
186
187 // Address arithmetic --------------------------------------------
188 InstKind::GetElementPtr(base, idxs) => {
189 let mut ops = vec![*base];
190 ops.extend(idxs);
191 mk(70, ops, 0)
192 }
193
194 // Aggregates ----------------------------------------------------
195 InstKind::ExtractField(agg, i) => mk(80, vec![*agg], *i as i128),
196 // InsertField produces a new aggregate value — pure but rarely
197 // duplicate; include for completeness.
198 InstKind::InsertField(agg, i, v) => mk(81, vec![*agg, *v], *i as i128),
199
200 // Impure / not handled ------------------------------------------
201 InstKind::Load(..)
202 | InstKind::Store(..)
203 | InstKind::Alloca(..)
204 | InstKind::Call(..)
205 | InstKind::RuntimeCall(..)
206 | InstKind::ConstString(..)
207 | InstKind::Undef(..) => None,
208
209 // Vector ops not yet CSE-eligible — Stage 1 lands the
210 // type/instruction system; CSE keying for SIMD lands when the
211 // vectorizer starts producing them.
212 InstKind::VAdd(..)
213 | InstKind::VSub(..)
214 | InstKind::VMul(..)
215 | InstKind::VDiv(..)
216 | InstKind::VNeg(..)
217 | InstKind::VAbs(..)
218 | InstKind::VSqrt(..)
219 | InstKind::VFma(..)
220 | InstKind::VSelect(..)
221 | InstKind::VMin(..)
222 | InstKind::VMax(..)
223 | InstKind::VICmp(..)
224 | InstKind::VFCmp(..)
225 | InstKind::VLoad(..)
226 | InstKind::VStore(..)
227 | InstKind::VBitcast(..)
228 | InstKind::VExtract(..)
229 | InstKind::VInsert(..)
230 | InstKind::VBroadcast(..)
231 | InstKind::VReduceSum(..)
232 | InstKind::VReduceMin(..)
233 | InstKind::VReduceMax(..) => None,
234 }
235 }
236
237 /// The local CSE pass.
238 pub struct LocalCse;
239
240 impl Pass for LocalCse {
241 fn name(&self) -> &'static str {
242 "local-cse"
243 }
244
245 fn run(&self, module: &mut Module) -> bool {
246 let mut changed = false;
247 for func in &mut module.functions {
248 // Collect all (old, new) rewrites first, then apply them
249 // in a *single* function walk. Audit Min-1: the previous
250 // version called `substitute_uses` once per rewrite, so a
251 // function with N CSE candidates ran N full walks for an
252 // overall O(N · function_size). The batched form is one
253 // walk with HashMap-driven renaming.
254 let mut rewrite_map: HashMap<ValueId, ValueId> = HashMap::new();
255 for block in &func.blocks {
256 let mut seen: HashMap<Key, ValueId> = HashMap::new();
257 for inst in &block.insts {
258 let Some(k) = key_of(inst) else { continue };
259 if let Some(&first) = seen.get(&k) {
260 rewrite_map.insert(inst.id, first);
261 } else {
262 seen.insert(k, inst.id);
263 }
264 }
265 }
266 if rewrite_map.is_empty() {
267 continue;
268 }
269
270 // Audit B-7: in **local** CSE, every entry maps a later
271 // duplicate to its block's *first* occurrence — and that
272 // first occurrence is, by construction, never itself a
273 // key in the map. So pointer chains never form, and the
274 // chase loop the first version had would always exit
275 // after zero iterations. Removed. If a future global
276 // CSE / GVN pass reuses this map shape and CAN produce
277 // chains, the chase logic will need to come back —
278 // strict-decrease in ValueId guarantees termination.
279 substitute_uses_batch(func, &rewrite_map);
280 changed = true;
281 }
282 changed
283 }
284 }
285
286 /// Replace every operand `old` with `rewrite_map[old]` (if any) in a
287 /// single walk over the function. Pairs with Min-1: avoids the
288 /// O(N · size) cost of calling `substitute_uses` once per rename.
289 /// Audit B-6: delegates to `walk::for_each_*_operand_mut`.
290 ///
291 /// Audit B-9: the closure `r` captures `rewrites` by shared
292 /// reference and is therefore `Copy`. This is what lets us pass
293 /// `r` by value into multiple `for_each_operand_mut` calls inside
294 /// the per-block loop. If the closure ever needs mutable state
295 /// (e.g., to count rewrites), it stops being `Copy` and the loop
296 /// must shift to `&mut r` — at which point the helper signatures
297 /// in `walk.rs` would need to take `&mut impl FnMut(...)` instead
298 /// of `mut r: impl FnMut(...)`.
299 fn substitute_uses_batch(func: &mut Function, rewrites: &HashMap<ValueId, ValueId>) {
300 let r = |v: &mut ValueId| {
301 if let Some(&new) = rewrites.get(v) {
302 *v = new;
303 }
304 };
305 for block in &mut func.blocks {
306 for inst in &mut block.insts {
307 for_each_operand_mut(&mut inst.kind, r);
308 }
309 if let Some(term) = &mut block.terminator {
310 for_each_terminator_operand_mut(term, r);
311 }
312 }
313 }
314
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::types::{FloatWidth, IntWidth, IrType};
    use crate::lexer::{Position, Span};

    /// Placeholder span for synthetic instructions.
    fn dummy_span() -> Span {
        let origin = Position { line: 1, col: 1 };
        Span {
            start: origin,
            end: origin,
            file_id: 0,
        }
    }

    /// Shorthand for the 32-bit integer type.
    fn i32_ty() -> IrType {
        IrType::Int(IntWidth::I32)
    }

    /// Shorthand for the 64-bit float type.
    fn f64_ty() -> IrType {
        IrType::Float(FloatWidth::F64)
    }

    /// Fresh parameterless function named `f` with the given return type.
    fn new_fn(ret: IrType) -> Function {
        Function::new("f".into(), vec![], ret)
    }

    /// Append an instruction to the entry block and return its id.
    fn push(f: &mut Function, kind: InstKind, ty: IrType) -> ValueId {
        let id = f.next_value_id();
        let inst = Inst {
            id,
            kind,
            ty,
            span: dummy_span(),
        };
        let entry = f.entry;
        f.block_mut(entry).insts.push(inst);
        id
    }

    /// Append an `i32` integer constant to the entry block.
    fn const_i32(f: &mut Function, value: i128) -> ValueId {
        push(f, InstKind::ConstInt(value, IntWidth::I32), i32_ty())
    }

    /// Terminate the entry block with `ret`, wrap the function in a
    /// module named `t`, run the pass once, and hand back the pass
    /// result together with the transformed module.
    fn run_with_return(mut f: Function, ret: Option<ValueId>) -> (bool, Module) {
        let entry = f.entry;
        f.block_mut(entry).terminator = Some(Terminator::Return(ret));
        let mut m = Module::new("t".into());
        m.add_function(f);
        let changed = LocalCse.run(&mut m);
        (changed, m)
    }

    /// The value returned by the first block of the first function.
    fn returned_value(m: &Module) -> ValueId {
        match &m.functions[0].blocks[0].terminator {
            Some(Terminator::Return(Some(v))) => *v,
            _ => panic!(),
        }
    }

    #[test]
    fn dedupes_iadd_pair() {
        // %0 = const 1
        // %1 = const 2
        // %2 = iadd %0, %1
        // %3 = iadd %0, %1   ; same as %2
        // ret %3  → after CSE →  ret %2 (and %3 is dead)
        let mut f = new_fn(i32_ty());
        let a = const_i32(&mut f, 1);
        let b = const_i32(&mut f, 2);
        let first = push(&mut f, InstKind::IAdd(a, b), i32_ty());
        let second = push(&mut f, InstKind::IAdd(a, b), i32_ty());

        let (changed, m) = run_with_return(f, Some(second));

        assert!(changed);
        // The terminator now refers to the first add, not the second.
        assert_eq!(returned_value(&m), first);
    }

    #[test]
    fn commutative_iadd_dedupes_swapped_operands() {
        // %2 = iadd %0, %1
        // %3 = iadd %1, %0
        // Both must canonicalize to the same key.
        let mut f = new_fn(i32_ty());
        let a = const_i32(&mut f, 1);
        let b = const_i32(&mut f, 2);
        let first = push(&mut f, InstKind::IAdd(a, b), i32_ty());
        let swapped = push(&mut f, InstKind::IAdd(b, a), i32_ty());

        let (changed, m) = run_with_return(f, Some(swapped));

        assert!(changed);
        assert_eq!(returned_value(&m), first);
    }

    #[test]
    fn non_commutative_isub_does_not_dedupe_swapped() {
        let mut f = new_fn(i32_ty());
        let a = const_i32(&mut f, 5);
        let b = const_i32(&mut f, 3);
        let _first = push(&mut f, InstKind::ISub(a, b), i32_ty());
        let swapped = push(&mut f, InstKind::ISub(b, a), i32_ty());

        let (changed, _m) = run_with_return(f, Some(swapped));

        // The swapped subtraction is a different expression, so no
        // rewrite was possible.
        assert!(!changed);
    }

    #[test]
    fn keeps_load_pair_intact() {
        // Loads must NOT be deduplicated by local CSE.
        let mut f = new_fn(IrType::Void);
        let addr = push(
            &mut f,
            InstKind::Alloca(i32_ty()),
            IrType::Ptr(Box::new(i32_ty())),
        );
        let _l1 = push(&mut f, InstKind::Load(addr), i32_ty());
        let _l2 = push(&mut f, InstKind::Load(addr), i32_ty());

        let (changed, _m) = run_with_return(f, None);

        assert!(!changed);
    }

    #[test]
    fn fmul_dedupes() {
        let mut f = new_fn(f64_ty());
        let a = push(
            &mut f,
            InstKind::ConstFloat(1.5, FloatWidth::F64),
            f64_ty(),
        );
        let b = push(
            &mut f,
            InstKind::ConstFloat(2.5, FloatWidth::F64),
            f64_ty(),
        );
        let first = push(&mut f, InstKind::FMul(a, b), f64_ty());
        let swapped = push(&mut f, InstKind::FMul(b, a), f64_ty());

        let (changed, m) = run_with_return(f, Some(swapped));

        assert!(changed);
        assert_eq!(returned_value(&m), first);
    }

    #[test]
    fn icmp_lt_not_canonicalized() {
        // Lt is not commutative — the two compares must not collapse.
        let mut f = new_fn(IrType::Bool);
        let a = const_i32(&mut f, 1);
        let b = const_i32(&mut f, 2);
        let _first = push(&mut f, InstKind::ICmp(CmpOp::Lt, a, b), IrType::Bool);
        let swapped = push(&mut f, InstKind::ICmp(CmpOp::Lt, b, a), IrType::Bool);

        let (changed, _m) = run_with_return(f, Some(swapped));

        assert!(!changed);
    }
}
500