// (stray "Raw / Blame / History" scrape banner — not part of the source)
1 //! Global Value Numbering (GVN).
2 //!
3 //! Extends local CSE to cross-block redundancy elimination. Processes
4 //! blocks in dominator-tree preorder, maintaining a scoped hash table
5 //! of value numbers. When a block computes an expression already
6 //! available from a dominating block, the redundant instruction is
7 //! replaced with the dominating definition.
8 //!
9 //! Uses the same canonical Key structure as local CSE (cse.rs).
10
11 use super::pass::Pass;
12 use crate::ir::inst::*;
13 use crate::ir::types::IrType;
14 use crate::ir::walk::{compute_immediate_dominators, dominator_tree_children};
15 use std::collections::HashMap;
16
17 pub struct Gvn;
18
19 impl Pass for Gvn {
20 fn name(&self) -> &'static str {
21 "gvn"
22 }
23
24 fn run(&self, module: &mut Module) -> bool {
25 let pure_calls: Vec<PureCallPolicy> = module
26 .functions
27 .iter()
28 .map(PureCallPolicy::for_function)
29 .collect();
30 let mut changed = false;
31 for func in &mut module.functions {
32 // Sprint 15 Stage 2: congruence closure via fixpoint.
33 // After the first pass merges some values, downstream
34 // instructions may now have equivalent operands; their
35 // keys must be recomputed with the new `replacements`
36 // table. Iterating gvn_function until it stops mutating
37 // achieves the fixpoint without rewriting the per-block
38 // dominator-tree walk. Cap at 8 rounds to bound the worst
39 // case (in practice 1-2 rounds covers it).
40 let mut local_changed = false;
41 for _ in 0..8 {
42 if gvn_function(func, &pure_calls) {
43 local_changed = true;
44 } else {
45 break;
46 }
47 }
48 if local_changed {
49 changed = true;
50 }
51 }
52 changed
53 }
54 }
55
/// Canonical key for value numbering (same shape as local CSE).
///
/// Two instructions with equal keys compute the same value, so the
/// later one can be replaced by the earlier, dominating one.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct Key {
    // Discriminant identifying the operation class (see `key_of`).
    tag: u32,
    // Canonicalized operands: remapped through the replacement table,
    // and for commutative/associative ops sorted (and flattened).
    operands: Vec<ValueId>,
    // Operation-specific immediate payload (constant value, compare
    // op, target width, callee index, ...); 0 when unused.
    aux: i128,
    // Symbol name for name-keyed ops (GlobalAddr); None otherwise.
    name: Option<String>,
    // Result type — identical expressions at different types must not
    // unify.
    ty: IrType,
}
65
/// How a PURE callee consumes one of its parameters, used to decide
/// whether call sites may be value-numbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PureArgPolicy {
    // Plain scalar passed by value — always safe to key on.
    ByValue,
    // Pointer to a scalar the callee only reads (the lowered Fortran
    // by-reference pattern); call sites are keyed on the pointee's
    // stored SSA value instead of the pointer identity.
    ReadOnlyWrapperPtr,
    // Anything else — disqualifies the whole call from GVN reuse.
    Unsupported,
}
72
/// Per-function summary of whether calls to it can be hash-consed.
#[derive(Debug, Clone)]
struct PureCallPolicy {
    // True when call sites may be value-numbered at all.
    reusable: bool,
    // One policy per parameter, in declaration order.
    arg_policies: Vec<PureArgPolicy>,
}
78
79 impl PureCallPolicy {
80 fn for_function(func: &Function) -> Self {
81 let arg_policies = func
82 .params
83 .iter()
84 .map(|param| classify_pure_arg(func, param))
85 .collect::<Vec<_>>();
86 let reusable = func.is_pure
87 && !matches!(func.return_type, IrType::Void)
88 && !func.return_type.is_ptr()
89 && arg_policies
90 .iter()
91 .all(|policy| *policy != PureArgPolicy::Unsupported)
92 && !reads_non_argument_memory(func);
93 Self {
94 reusable,
95 arg_policies,
96 }
97 }
98 }
99
// Per F2018 15.7 a PURE function may *read* common blocks and module
// variables — it just can't write them — so Fortran's `is_pure` flag
// alone doesn't mean "result depends only on arguments". That is why
// call reuse additionally requires `reads_non_argument_memory` to
// return false; see that function for the conservative rejection
// rules.

/// External math intrinsics whose `Call` sites are pure functions of
/// their by-value arguments — emitting one is identical to emitting
/// any other equivalent one, so they don't disqualify their hosting
/// function from GVN's pure-call CSE policy. Names match the libm /
/// armfortas runtime symbols emitted by lower.
const PURE_EXTERNAL_INTRINSICS: &[&str] = &[
    // libm scalar math (single + double).
    "sinf", "cosf", "tanf", "asinf", "acosf", "atanf", "atan2f",
    "sinhf", "coshf", "tanhf", "expf", "expm1f", "logf", "log2f",
    "log10f", "log1pf", "sqrtf", "cbrtf", "fabsf", "ceilf", "floorf",
    "roundf", "truncf", "powf", "fmodf", "hypotf", "copysignf",
    "sin", "cos", "tan", "asin", "acos", "atan", "atan2",
    "sinh", "cosh", "tanh", "exp", "expm1", "log", "log2",
    "log10", "log1p", "sqrt", "cbrt", "fabs", "ceil", "floor",
    "round", "trunc", "pow", "fmod", "hypot", "copysign",
];

/// Linear membership test against [`PURE_EXTERNAL_INTRINSICS`]; the
/// table is small enough that a scan beats building a set.
fn is_pure_external_intrinsic(name: &str) -> bool {
    PURE_EXTERNAL_INTRINSICS.iter().any(|&known| known == name)
}
135
136 fn reads_non_argument_memory(func: &Function) -> bool {
137 for block in &func.blocks {
138 for inst in &block.insts {
139 match &inst.kind {
140 InstKind::GlobalAddr(_) => return true,
141 InstKind::Call(FuncRef::External(name), _)
142 if !is_pure_external_intrinsic(name) =>
143 {
144 return true;
145 }
146 InstKind::RuntimeCall(_, _) => return true,
147 _ => {}
148 }
149 }
150 }
151 false
152 }
153
154 fn is_scalar_type(ty: &IrType) -> bool {
155 matches!(ty, IrType::Bool | IrType::Int(_) | IrType::Float(_))
156 }
157
158 fn inst_map(func: &Function) -> HashMap<ValueId, &Inst> {
159 func.blocks
160 .iter()
161 .flat_map(|block| block.insts.iter())
162 .map(|inst| (inst.id, inst))
163 .collect()
164 }
165
166 fn classify_pure_arg(func: &Function, param: &Param) -> PureArgPolicy {
167 if !param.ty.is_ptr() {
168 return PureArgPolicy::ByValue;
169 }
170 let IrType::Ptr(inner) = &param.ty else {
171 return PureArgPolicy::Unsupported;
172 };
173 if !is_scalar_type(inner.as_ref()) {
174 return PureArgPolicy::Unsupported;
175 }
176 if pointer_param_is_read_only_scalar(func, param.id) {
177 PureArgPolicy::ReadOnlyWrapperPtr
178 } else {
179 PureArgPolicy::Unsupported
180 }
181 }
182
/// Verify that pointer parameter `param_id` is only ever used as a
/// read-only scalar input. Accepted uses are: (a) being spilled into
/// a scalar-pointer shadow alloca (the lowered Fortran entry
/// pattern) and (b) being loaded through directly. Every pointer
/// re-loaded out of a shadow slot must itself only be loaded
/// through. Any write through the parameter, any terminator use, or
/// any unrecognized instruction rejects the pattern.
fn pointer_param_is_read_only_scalar(func: &Function, param_id: ValueId) -> bool {
    let defs = inst_map(func);
    // Allocas the parameter pointer is stored into (entry spill slots).
    let mut shadow_slots = Vec::new();

    // Phase 1: every direct use of the parameter must be a spill into
    // a scalar-pointer alloca or a direct load.
    for block in &func.blocks {
        for inst in &block.insts {
            let uses = super::util::inst_uses(&inst.kind);
            if !uses.contains(&param_id) {
                continue;
            }
            match &inst.kind {
                InstKind::Store(value, addr) if *value == param_id => {
                    // Storing the pointer itself: only allowed into an
                    // alloca whose slot type is pointer-to-scalar.
                    let Some(def) = defs.get(addr) else {
                        return false;
                    };
                    match &def.kind {
                        InstKind::Alloca(slot_ty) => {
                            let IrType::Ptr(inner) = slot_ty else {
                                return false;
                            };
                            if !is_scalar_type(inner.as_ref()) {
                                return false;
                            }
                            shadow_slots.push(*addr);
                        }
                        _ => return false,
                    }
                }
                // After mem2reg, the entry shadow slot often folds away and
                // the PURE callee just reads the incoming by-ref scalar
                // directly.
                InstKind::Load(addr) if *addr == param_id => {}
                _ => return false,
            }
        }
        if let Some(term) = &block.terminator {
            if super::util::terminator_uses(term).contains(&param_id) {
                return false;
            }
        }
    }

    // Phase 2: each shadow slot may only be written with the
    // parameter and read back; collect the re-loaded pointer values.
    for slot in shadow_slots {
        let mut alias_ptrs = Vec::new();
        for block in &func.blocks {
            for inst in &block.insts {
                let uses = super::util::inst_uses(&inst.kind);
                if !uses.contains(&slot) {
                    continue;
                }
                match &inst.kind {
                    InstKind::Store(value, addr) if *value == param_id && *addr == slot => {}
                    InstKind::Load(addr) if *addr == slot => alias_ptrs.push(inst.id),
                    _ => return false,
                }
            }
            if let Some(term) = &block.terminator {
                if super::util::terminator_uses(term).contains(&slot) {
                    return false;
                }
            }
        }

        // Phase 3: pointers re-loaded from the slot must only be
        // dereferenced (loaded through), never stored through, passed
        // on, or used in a terminator.
        for alias in alias_ptrs {
            for block in &func.blocks {
                for inst in &block.insts {
                    let uses = super::util::inst_uses(&inst.kind);
                    if !uses.contains(&alias) {
                        continue;
                    }
                    match &inst.kind {
                        InstKind::Load(addr) if *addr == alias => {}
                        _ => return false,
                    }
                }
                if let Some(term) = &block.terminator {
                    if super::util::terminator_uses(term).contains(&alias) {
                        return false;
                    }
                }
            }
        }
    }

    true
}
269
/// Find scalar allocas acting as call-argument wrappers: stored to
/// exactly once and otherwise only passed to reusable PURE calls in
/// a `ReadOnlyWrapperPtr` argument position. Returns a map from each
/// such alloca to the single SSA value stored into it, so `key_of`
/// can key the call on the pointee value instead of the pointer
/// identity.
fn wrapper_alloca_values(
    func: &Function,
    pure_calls: &[PureCallPolicy],
) -> HashMap<ValueId, ValueId> {
    let mut wrappers = HashMap::new();
    // Candidates: every alloca of a plain scalar type.
    let candidates: Vec<ValueId> = func
        .blocks
        .iter()
        .flat_map(|block| block.insts.iter())
        .filter_map(|inst| match &inst.kind {
            InstKind::Alloca(inner_ty) if is_scalar_type(inner_ty) => Some(inst.id),
            _ => None,
        })
        .collect();

    // Reject a candidate on the first use that doesn't fit the
    // wrapper pattern; the labeled continue skips to the next alloca.
    'candidate: for alloca_id in candidates {
        let mut stored_value = None;

        for scan_block in &func.blocks {
            for scan_inst in &scan_block.insts {
                let uses = super::util::inst_uses(&scan_inst.kind);
                if !uses.contains(&alloca_id) {
                    continue;
                }
                match &scan_inst.kind {
                    InstKind::Store(value, addr) if *addr == alloca_id => {
                        // `replace` returning Some means a second
                        // store — value is not single-assignment.
                        if stored_value.replace(*value).is_some() {
                            continue 'candidate;
                        }
                    }
                    InstKind::Call(FuncRef::Internal(idx), args) => {
                        // The alloca may only flow into a reusable
                        // PURE call, and only in a position the
                        // callee treats as a read-only wrapper.
                        let Some(policy) = pure_calls.get(*idx as usize) else {
                            continue 'candidate;
                        };
                        if !policy.reusable {
                            continue 'candidate;
                        }
                        let valid_arg = args.iter().enumerate().any(|(arg_idx, arg)| {
                            *arg == alloca_id
                                && policy.arg_policies.get(arg_idx).copied()
                                    == Some(PureArgPolicy::ReadOnlyWrapperPtr)
                        });
                        if !valid_arg {
                            continue 'candidate;
                        }
                    }
                    _ => continue 'candidate,
                }
            }
            if let Some(term) = &scan_block.terminator {
                if super::util::terminator_uses(term).contains(&alloca_id) {
                    continue 'candidate;
                }
            }
        }

        if let Some(value) = stored_value {
            wrappers.insert(alloca_id, value);
        }
    }

    wrappers
}
333
334 fn resolve_value(map: &HashMap<ValueId, ValueId>, mut value: ValueId) -> ValueId {
335 while let Some(&next) = map.get(&value) {
336 if next == value {
337 break;
338 }
339 value = next;
340 }
341 value
342 }
343
344 /// Tag identifying an associative+commutative integer/bitwise op.
345 /// Used to detect when an operand of an outer op of the same kind can
346 /// be flattened into a multi-operand list. FP ops are intentionally
347 /// excluded — under default rounding mode `(a+b)+c` is not equivalent
348 /// to `a+(b+c)`, so reordering would change results.
349 fn assoc_tag(kind: &InstKind) -> Option<u32> {
350 match kind {
351 InstKind::IAdd(..) => Some(1),
352 InstKind::IMul(..) => Some(3),
353 InstKind::And(..) => Some(30),
354 InstKind::Or(..) => Some(31),
355 InstKind::BitAnd(..) => Some(40),
356 InstKind::BitOr(..) => Some(41),
357 InstKind::BitXor(..) => Some(42),
358 _ => None,
359 }
360 }
361
362 fn assoc_inputs(kind: &InstKind) -> Option<(ValueId, ValueId)> {
363 match kind {
364 InstKind::IAdd(a, b)
365 | InstKind::IMul(a, b)
366 | InstKind::And(a, b)
367 | InstKind::Or(a, b)
368 | InstKind::BitAnd(a, b)
369 | InstKind::BitOr(a, b)
370 | InstKind::BitXor(a, b) => Some((*a, *b)),
371 _ => None,
372 }
373 }
374
375 /// Walk `a` and `b`, expanding any operand whose defining instruction
376 /// is the same associative-commutative op (after replacements). The
377 /// returned vector is sorted by ValueId so reorderings of the same
378 /// chain hash identically. `(a+b)+c` and `a+(b+c)` both flatten to
379 /// the sorted multiset `{a, b, c}`.
380 fn flatten_assoc(
381 tag: u32,
382 a: ValueId,
383 b: ValueId,
384 defs: &HashMap<ValueId, &Inst>,
385 replacements: &HashMap<ValueId, ValueId>,
386 ) -> Vec<ValueId> {
387 let mut out = Vec::new();
388 let push = |v: ValueId, out: &mut Vec<ValueId>, defs: &HashMap<ValueId, &Inst>| {
389 let r = resolve_value(replacements, v);
390 if let Some(inst) = defs.get(&r) {
391 if assoc_tag(&inst.kind) == Some(tag) {
392 if let Some((aa, bb)) = assoc_inputs(&inst.kind) {
393 out.extend(flatten_assoc(tag, aa, bb, defs, replacements));
394 return;
395 }
396 }
397 }
398 out.push(r);
399 };
400 push(a, &mut out, defs);
401 push(b, &mut out, defs);
402 out.sort_by_key(|v| v.0);
403 out
404 }
405
/// Build the canonical value-numbering `Key` for `inst`, or `None`
/// when the instruction must not be value-numbered (memory ops,
/// non-reusable calls, vector ops).
///
/// Operands are remapped through `replacements` first so previously
/// merged values hash identically. Commutative ops sort their two
/// operands; associative+commutative integer/bitwise ops are
/// flattened into a sorted multiset via `flatten_assoc`. The `tag`
/// values are arbitrary but fixed discriminants; `aux` carries
/// operation-specific immediates.
fn key_of(
    inst: &Inst,
    replacements: &HashMap<ValueId, ValueId>,
    pure_calls: &[PureCallPolicy],
    wrapper_values: &HashMap<ValueId, ValueId>,
    defs: &HashMap<ValueId, &Inst>,
) -> Option<Key> {
    // Builder for operand-keyed ops.
    let mk = |tag: u32, ops: Vec<ValueId>, aux: i128| -> Option<Key> {
        Some(Key {
            tag,
            operands: ops,
            aux,
            name: None,
            ty: inst.ty.clone(),
        })
    };
    // Builder for name-keyed ops (GlobalAddr).
    let mk_named = |tag: u32, name: String| -> Option<Key> {
        Some(Key {
            tag,
            operands: vec![],
            aux: 0,
            name: Some(name),
            ty: inst.ty.clone(),
        })
    };
    let remap = |v: ValueId| resolve_value(replacements, v);
    // Sort a commutative pair so (a,b) and (b,a) key identically.
    fn canon(a: ValueId, b: ValueId) -> Vec<ValueId> {
        if a.0 <= b.0 {
            vec![a, b]
        } else {
            vec![b, a]
        }
    }

    match &inst.kind {
        // Pure arithmetic — associative+commutative ops are flattened
        // and sorted so `(a+b)+c` ≡ `a+(b+c)` ≡ `b+(a+c)` etc.
        InstKind::IAdd(a, b) => mk(1, flatten_assoc(1, *a, *b, defs, replacements), 0),
        InstKind::ISub(a, b) => mk(2, vec![remap(*a), remap(*b)], 0),
        InstKind::IMul(a, b) => mk(3, flatten_assoc(3, *a, *b, defs, replacements), 0),
        InstKind::IDiv(a, b) => mk(4, vec![remap(*a), remap(*b)], 0),
        InstKind::IMod(a, b) => mk(5, vec![remap(*a), remap(*b)], 0),
        InstKind::INeg(a) => mk(6, vec![remap(*a)], 0),
        // FP add/mul are commutative (sorted pair) but NOT flattened:
        // FP reassociation changes results under rounding.
        InstKind::FAdd(a, b) => mk(10, canon(remap(*a), remap(*b)), 0),
        InstKind::FSub(a, b) => mk(11, vec![remap(*a), remap(*b)], 0),
        InstKind::FMul(a, b) => mk(12, canon(remap(*a), remap(*b)), 0),
        InstKind::FDiv(a, b) => mk(13, vec![remap(*a), remap(*b)], 0),
        InstKind::FNeg(a) => mk(14, vec![remap(*a)], 0),
        InstKind::FAbs(a) => mk(15, vec![remap(*a)], 0),
        InstKind::FSqrt(a) => mk(16, vec![remap(*a)], 0),
        InstKind::FPow(a, b) => mk(17, vec![remap(*a), remap(*b)], 0),
        InstKind::ICmp(op, a, b) => {
            // Eq/Ne are symmetric, so their operands can be sorted;
            // ordered compares are direction-sensitive. The compare
            // op lives in `aux`, so Eq and Ne still key differently.
            let op_val = *op as i128;
            match op {
                CmpOp::Eq | CmpOp::Ne => mk(20, canon(remap(*a), remap(*b)), op_val),
                _ => mk(20, vec![remap(*a), remap(*b)], op_val),
            }
        }
        // FP compares are never canonicalized (NaN semantics).
        InstKind::FCmp(op, a, b) => mk(21, vec![remap(*a), remap(*b)], *op as i128),
        InstKind::And(a, b) => mk(30, flatten_assoc(30, *a, *b, defs, replacements), 0),
        InstKind::Or(a, b) => mk(31, flatten_assoc(31, *a, *b, defs, replacements), 0),
        InstKind::Not(a) => mk(32, vec![remap(*a)], 0),
        InstKind::Select(c, t, f) => mk(33, vec![remap(*c), remap(*t), remap(*f)], 0),
        InstKind::BitAnd(a, b) => mk(40, flatten_assoc(40, *a, *b, defs, replacements), 0),
        InstKind::BitOr(a, b) => mk(41, flatten_assoc(41, *a, *b, defs, replacements), 0),
        InstKind::BitXor(a, b) => mk(42, flatten_assoc(42, *a, *b, defs, replacements), 0),
        InstKind::BitNot(a) => mk(43, vec![remap(*a)], 0),
        InstKind::Shl(a, b) => mk(44, vec![remap(*a), remap(*b)], 0),
        InstKind::LShr(a, b) => mk(45, vec![remap(*a), remap(*b)], 0),
        InstKind::AShr(a, b) => mk(46, vec![remap(*a), remap(*b)], 0),
        InstKind::CountLeadingZeros(a) => mk(47, vec![remap(*a)], 0),
        InstKind::CountTrailingZeros(a) => mk(48, vec![remap(*a)], 0),
        InstKind::PopCount(a) => mk(49, vec![remap(*a)], 0),
        // Conversions — the target width goes in `aux`.
        InstKind::IntToFloat(a, w) => mk(50, vec![remap(*a)], w.bits() as i128),
        InstKind::FloatToInt(a, w) => mk(51, vec![remap(*a)], w.bits() as i128),
        InstKind::FloatExtend(a, w) => mk(52, vec![remap(*a)], w.bits() as i128),
        InstKind::FloatTrunc(a, w) => mk(53, vec![remap(*a)], w.bits() as i128),
        // Sign flag is folded into `aux`'s sign: +width for signed
        // extend, -width for unsigned.
        InstKind::IntExtend(a, w, s) => mk(
            54,
            vec![remap(*a)],
            (w.bits() as i128) * if *s { 1 } else { -1 },
        ),
        InstKind::IntTrunc(a, w) => mk(55, vec![remap(*a)], w.bits() as i128),
        InstKind::PtrToInt(a) => mk(56, vec![remap(*a)], 0),
        InstKind::IntToPtr(a, _) => mk(57, vec![remap(*a)], 0),
        // Constants.
        InstKind::ConstInt(v, w) => {
            // Sign-extend to a canonical i128 so e.g. an i8 -1 and an
            // i8 255 written differently key identically; the shift
            // pair is the standard sign-extension idiom.
            let bits = w.bits();
            let signed = if bits >= 128 {
                *v
            } else {
                let shift = 128 - bits;
                (*v << shift) >> shift
            };
            mk(60, vec![], signed)
        }
        // Bit-pattern keying distinguishes -0.0 from +0.0 and keys
        // NaNs by payload; width XOR keeps same-bits/different-width
        // constants apart (ty in the Key also separates them).
        InstKind::ConstFloat(v, w) => mk(61, vec![], ((*v).to_bits() as i128) ^ (w.bits() as i128)),
        InstKind::ConstBool(v) => mk(62, vec![], *v as i128),
        // GlobalAddr — keyed by symbol name.
        InstKind::GlobalAddr(name) => mk_named(70, name.clone()),
        // GEP — base followed by all indices, order-sensitive.
        InstKind::GetElementPtr(base, idxs) => {
            let mut ops = vec![remap(*base)];
            ops.extend(idxs.iter().copied().map(remap));
            mk(80, ops, 0)
        }
        // Reusable PURE calls: by-value arguments only, real value return,
        // and no pointer result identity to preserve. For lowered
        // Fortran scalar-by-reference args, look through the wrapper
        // alloca to the stored SSA value when the callee only reads the
        // pointee as a scalar input.
        InstKind::Call(FuncRef::Internal(idx), args) => {
            let policy = pure_calls.get(*idx as usize)?;
            if !policy.reusable || policy.arg_policies.len() != args.len() {
                return None;
            }
            let mut ops = Vec::with_capacity(args.len());
            for (arg, arg_policy) in args.iter().zip(policy.arg_policies.iter()) {
                match arg_policy {
                    PureArgPolicy::ByValue => ops.push(remap(*arg)),
                    PureArgPolicy::ReadOnlyWrapperPtr => {
                        // Key on the value stored into the wrapper
                        // alloca, not the alloca's address; bail if
                        // this argument isn't a recognized wrapper.
                        let wrapper = remap(*arg);
                        let stored = wrapper_values.get(&wrapper)?;
                        ops.push(remap(*stored));
                    }
                    PureArgPolicy::Unsupported => return None,
                }
            }
            // Callee index in `aux` keeps different callees apart.
            mk(90, ops, *idx as i128)
        }
        // Impure: loads, stores, runtime calls, external calls, alloca — not GVN candidates.
        InstKind::Load(..)
        | InstKind::Store(..)
        | InstKind::Alloca(..)
        | InstKind::Call(..)
        | InstKind::RuntimeCall(..)
        | InstKind::ConstString(..)
        | InstKind::Undef(..)
        | InstKind::ExtractField(..)
        | InstKind::InsertField(..) => None,

        // Vector ops opt out of GVN until the vectorizer lands and we
        // can characterize aliasing / commutativity properly. Same
        // policy as CSE keying.
        InstKind::VAdd(..)
        | InstKind::VSub(..)
        | InstKind::VMul(..)
        | InstKind::VDiv(..)
        | InstKind::VNeg(..)
        | InstKind::VAbs(..)
        | InstKind::VSqrt(..)
        | InstKind::VFma(..)
        | InstKind::VSelect(..)
        | InstKind::VMin(..)
        | InstKind::VMax(..)
        | InstKind::VICmp(..)
        | InstKind::VFCmp(..)
        | InstKind::VLoad(..)
        | InstKind::VStore(..)
        | InstKind::VBitcast(..)
        | InstKind::VExtract(..)
        | InstKind::VInsert(..)
        | InstKind::VBroadcast(..)
        | InstKind::VReduceSum(..)
        | InstKind::VReduceMin(..)
        | InstKind::VReduceMax(..) => None,
    }
}
575
/// One GVN round over a single function: walk the dominator tree in
/// preorder with a scoped hash table of canonical keys, record
/// redundant instructions in a replacement map, then rewrite all
/// operands and delete the duplicates. Returns true if anything
/// changed.
fn gvn_function(func: &mut Function, pure_calls: &[PureCallPolicy]) -> bool {
    let idoms = compute_immediate_dominators(func);
    let children = dominator_tree_children(&idoms);
    let wrapper_values = wrapper_alloca_values(func, pure_calls);
    let defs = inst_map(func);

    // Scoped value number table: Key → dominating ValueId.
    let mut vn_table: HashMap<Key, ValueId> = HashMap::new();
    // Replacement map: redundant ValueId → dominating ValueId.
    let mut replacements: HashMap<ValueId, ValueId> = HashMap::new();

    // Process blocks in dominator-tree preorder (DFS from entry).
    let mut stack = vec![(func.entry, 0usize)]; // (block, depth for scope management)
    let mut scope_stack: Vec<Vec<Key>> = Vec::new(); // keys to remove when leaving scope

    while let Some((block_id, depth)) = stack.pop() {
        // Pop scope entries for blocks we've left. A node at depth d
        // has exactly its d ancestors' scopes live, so trim the scope
        // stack down to that depth before processing it.
        while scope_stack.len() > depth {
            if let Some(keys) = scope_stack.pop() {
                for key in keys {
                    vn_table.remove(&key);
                }
            }
        }

        // Process this block's instructions.
        let mut new_keys = Vec::new();
        let block = func.block(block_id);
        for inst in &block.insts {
            if let Some(key) = key_of(inst, &replacements, pure_calls, &wrapper_values, &defs) {
                if let Some(&existing) = vn_table.get(&key) {
                    // This expression is already available from a dominating block.
                    replacements.insert(inst.id, existing);
                } else {
                    // First occurrence — register in the table.
                    vn_table.insert(key.clone(), inst.id);
                    new_keys.push(key);
                }
            }
        }

        scope_stack.push(new_keys);

        // Push children (reverse order so leftmost child is processed first).
        if let Some(kids) = children.get(&block_id) {
            for &child in kids.iter().rev() {
                stack.push((child, depth + 1));
            }
        }
    }

    if replacements.is_empty() {
        return false;
    }

    // Apply replacements: substitute all uses of redundant values, then
    // drop the now-dead duplicate instructions directly. This matters
    // for PURE calls because DCE conservatively treats generic calls as
    // side-effecting.
    for block in &mut func.blocks {
        for inst in &mut block.insts {
            inst.kind = remap_operands(&inst.kind, &replacements);
        }
        block
            .insts
            .retain(|inst| !replacements.contains_key(&inst.id));
        if let Some(ref mut term) = block.terminator {
            remap_terminator_operands(term, &replacements);
        }
    }

    true
}
649
/// Remap operands in an instruction using the replacement map.
/// Thin delegate to the shared rewriter in `loop_utils` so every
/// pass substitutes operands identically.
fn remap_operands(kind: &InstKind, map: &HashMap<ValueId, ValueId>) -> InstKind {
    super::loop_utils::remap_inst_kind(kind, map)
}
654
655 /// Remap operands in a terminator.
656 fn remap_terminator_operands(term: &mut Terminator, map: &HashMap<ValueId, ValueId>) {
657 let r = |v: &ValueId| *map.get(v).unwrap_or(v);
658 match term {
659 Terminator::Return(Some(v)) => *v = r(v),
660 Terminator::Branch(_, args) => {
661 for a in args.iter_mut() {
662 *a = r(a);
663 }
664 }
665 Terminator::CondBranch {
666 cond,
667 true_args,
668 false_args,
669 ..
670 } => {
671 *cond = r(cond);
672 for a in true_args.iter_mut() {
673 *a = r(a);
674 }
675 for a in false_args.iter_mut() {
676 *a = r(a);
677 }
678 }
679 Terminator::Switch { selector, .. } => {
680 *selector = r(selector);
681 }
682 _ => {}
683 }
684 }
685
686 #[cfg(test)]
687 mod tests {
688 use super::*;
689 use crate::ir::lower;
690 use crate::ir::types::{IntWidth, IrType};
691 use crate::lexer::{tokenize, SourceForm};
692 use crate::lexer::{Position, Span};
693 use crate::opt::pass::Pass;
694 use crate::opt::pass::PassManager;
695 use crate::opt::pipeline::OptLevel;
696 use crate::opt::{
697 bce::Bce, call_resolve::CallResolve, const_fold::ConstFold, const_prop::ConstProp,
698 dead_func::DeadFuncElim, dse::Dse, fission::LoopFission, fusion::LoopFusion,
699 global_lsf::GlobalLsf, inline::Inline, interchange::LoopInterchange, licm::Licm,
700 lsf::LocalLsf, mem2reg::Mem2Reg, peel::LoopPeel, preheader::PreheaderInsert,
701 simplify_cfg::SimplifyCfg, sroa::Sroa, strength_reduce::StrengthReduce, unroll::LoopUnroll,
702 unswitch::LoopUnswitch,
703 };
704 use crate::parser::Parser;
705 use crate::preprocess::PreprocConfig;
706 use crate::sema::{resolve, validate};
707 use std::fs;
708 use std::path::PathBuf;
709
710 fn span() -> Span {
711 let pos = Position { line: 0, col: 0 };
712 Span {
713 file_id: 0,
714 start: pos,
715 end: pos,
716 }
717 }
718
    /// Append an instruction of `kind`/`ty` to `block`, allocating a
    /// fresh value id and registering its type; returns the new id.
    fn push_inst(func: &mut Function, block: BlockId, kind: InstKind, ty: IrType) -> ValueId {
        let id = func.next_value_id();
        func.register_type(id, ty.clone());
        func.block_mut(block).insts.push(Inst {
            id,
            kind,
            ty,
            span: span(),
        });
        id
    }
730
    /// Preprocess, tokenize, parse, resolve, validate, and lower a
    /// free-form Fortran fixture from `test_programs/` into an IR
    /// module. Panics (via `expect`/`assert`) if any stage fails —
    /// fixtures are expected to be well-formed.
    // NOTE(review): no caller is visible in this chunk; presumably
    // used by fixture-driven GVN tests later in the module — confirm.
    fn lower_fixture(name: &str) -> Module {
        let path = PathBuf::from("test_programs").join(name);
        let source = fs::read_to_string(&path).expect("fixture source");
        let pp = crate::preprocess::preprocess(
            &source,
            &PreprocConfig {
                filename: path.to_string_lossy().into_owned(),
                fixed_form: false,
                ..PreprocConfig::default()
            },
        )
        .expect("preprocess fixture");
        let tokens = tokenize(&pp.text, 0, SourceForm::FreeForm).expect("tokenize fixture");
        let mut parser = Parser::new(&tokens);
        let units = parser.parse_file().expect("parse fixture");
        let (st, type_layouts) = {
            let rr = resolve::resolve_file(&units, &[]).expect("resolve fixture");
            (rr.st, rr.type_layouts)
        };
        // Validation diagnostics of Error severity mean the fixture
        // itself is broken — fail loudly rather than lower garbage.
        let diags = validate::validate_file(&units, &st);
        assert!(
            !diags
                .iter()
                .any(|diag| diag.kind == validate::DiagKind::Error),
            "fixture should lower cleanly: {:?}",
            diags
        );
        lower::lower_file(
            &units,
            &st,
            &type_layouts,
            std::collections::HashMap::new(),
            std::collections::HashMap::new(),
            std::collections::HashMap::new(),
            std::collections::HashMap::new(),
        )
        .0
    }
769
    /// Assemble the O2 pass pipeline as it runs *before* GVN, so
    /// tests can exercise GVN on realistically pre-optimized IR
    /// (mem2reg'd, inlined, CSE'd, loop-transformed).
    // NOTE(review): no caller is visible in this chunk; presumably
    // used by fixture-driven tests later in the module — confirm the
    // ordering stays in sync with the real O2 pipeline.
    fn build_pre_gvn_o2_pipeline() -> PassManager {
        let mut pm = PassManager::new();
        pm.add(Box::new(CallResolve));
        pm.add(Box::new(Mem2Reg));
        pm.add(Box::new(ConstFold));
        pm.add(Box::new(Sroa));
        pm.add(Box::new(Mem2Reg));
        pm.add(Box::new(Inline::for_level(OptLevel::O2)));
        pm.add(Box::new(SimplifyCfg));
        pm.add(Box::new(DeadFuncElim));
        pm.add(Box::new(Bce));
        pm.add(Box::new(StrengthReduce));
        pm.add(Box::new(LocalLsf));
        pm.add(Box::new(GlobalLsf));
        pm.add(Box::new(crate::opt::cse::LocalCse));
        pm.add(Box::new(PreheaderInsert));
        pm.add(Box::new(LoopPeel));
        pm.add(Box::new(LoopUnswitch));
        pm.add(Box::new(Licm));
        pm.add(Box::new(ConstProp));
        pm.add(Box::new(Dse));
        pm.add(Box::new(LoopInterchange));
        pm.add(Box::new(LoopFission));
        pm.add(Box::new(LoopFusion));
        pm.add(Box::new(LoopUnroll));
        pm
    }
797
    #[test]
    fn gvn_associativity_folds_left_and_right_regroups() {
        // r1 = (a + b) + c
        // r2 = a + (b + c)
        // After GVN, r2 should be replaced by r1 (or vice versa) — the
        // pass should detect at least one redundancy. This exercises
        // the flatten_assoc path: both groupings hash to the sorted
        // multiset {a, b, c}.
        let mut m = Module::new("test".into());
        let mut f = Function::new("test".into(), vec![], IrType::Int(IntWidth::I32));
        let entry = f.entry;
        let i32t = IrType::Int(IntWidth::I32);
        let a = push_inst(&mut f, entry, InstKind::ConstInt(7, IntWidth::I32), i32t.clone());
        let b = push_inst(&mut f, entry, InstKind::ConstInt(11, IntWidth::I32), i32t.clone());
        let c = push_inst(&mut f, entry, InstKind::ConstInt(13, IntWidth::I32), i32t.clone());
        // (a + b) + c
        let ab = push_inst(&mut f, entry, InstKind::IAdd(a, b), i32t.clone());
        let r1 = push_inst(&mut f, entry, InstKind::IAdd(ab, c), i32t.clone());
        // a + (b + c)
        let bc = push_inst(&mut f, entry, InstKind::IAdd(b, c), i32t.clone());
        let r2 = push_inst(&mut f, entry, InstKind::IAdd(a, bc), i32t.clone());
        // Use both so DCE doesn't strip them: sum = r1 + r2
        let sum = push_inst(&mut f, entry, InstKind::IAdd(r1, r2), i32t.clone());
        f.block_mut(entry).terminator = Some(Terminator::Return(Some(sum)));
        m.add_function(f);

        let pass = Gvn;
        let changed = pass.run(&mut m);
        assert!(
            changed,
            "GVN should detect that (a+b)+c and a+(b+c) are the same value"
        );
    }
829
830 #[test]
831 fn gvn_no_op_on_empty() {
832 let mut m = Module::new("test".into());
833 let mut f = Function::new("test".into(), vec![], IrType::Void);
834 f.block_mut(f.entry).terminator = Some(Terminator::Return(None));
835 m.add_function(f);
836 let pass = Gvn;
837 assert!(!pass.run(&mut m));
838 }
839
840 #[test]
841 fn gvn_handles_max_i128_constant_keys() {
842 let mut m = Module::new("test".into());
843 let mut f = Function::new("test".into(), vec![], IrType::Int(IntWidth::I128));
844 let entry = f.entry;
845 let wide = push_inst(
846 &mut f,
847 entry,
848 InstKind::ConstInt(i128::MAX, IntWidth::I128),
849 IrType::Int(IntWidth::I128),
850 );
851 f.block_mut(entry).terminator = Some(Terminator::Return(Some(wide)));
852 m.add_function(f);
853 let pass = Gvn;
854 assert!(
855 !pass.run(&mut m),
856 "single max-i128 constant should not produce a spurious rewrite"
857 );
858 }
859
    #[test]
    fn gvn_reuses_pure_by_value_calls() {
        // Two identical calls to a PURE function with the same
        // by-value argument: GVN should keep one call and rewrite the
        // second's uses to the survivor.
        let mut m = Module::new("test".into());

        // PURE callee `square(x) = x` — body is irrelevant, only the
        // policy (pure, by-value scalar arg, scalar return) matters.
        let param = Param {
            name: "x".into(),
            ty: IrType::Int(IntWidth::I32),
            id: ValueId(0),
            fortran_noalias: false,
        };
        let mut callee = Function::new("square".into(), vec![param], IrType::Int(IntWidth::I32));
        callee.is_pure = true;
        callee.block_mut(callee.entry).terminator = Some(Terminator::Return(Some(ValueId(0))));
        m.add_function(callee);

        // Caller: call1 = square(7); call2 = square(7); return call1+call2.
        let mut caller = Function::new("main".into(), vec![], IrType::Int(IntWidth::I32));
        let entry = caller.entry;
        let c7 = push_inst(
            &mut caller,
            entry,
            InstKind::ConstInt(7, IntWidth::I32),
            IrType::Int(IntWidth::I32),
        );
        let call1 = push_inst(
            &mut caller,
            entry,
            InstKind::Call(FuncRef::Internal(0), vec![c7]),
            IrType::Int(IntWidth::I32),
        );
        let call2 = push_inst(
            &mut caller,
            entry,
            InstKind::Call(FuncRef::Internal(0), vec![c7]),
            IrType::Int(IntWidth::I32),
        );
        let sum = push_inst(
            &mut caller,
            entry,
            InstKind::IAdd(call1, call2),
            IrType::Int(IntWidth::I32),
        );
        caller.block_mut(entry).terminator = Some(Terminator::Return(Some(sum)));
        m.add_function(caller);

        let pass = Gvn;
        assert!(pass.run(&mut m));

        // Exactly one call should survive...
        let caller = &m.functions[1];
        let calls: Vec<&Inst> = caller.blocks[0]
            .insts
            .iter()
            .filter(|inst| matches!(inst.kind, InstKind::Call(..)))
            .collect();
        assert_eq!(calls.len(), 1, "redundant PURE call should be removed");
        let kept_call = calls[0].id;

        // ...and both IAdd operands must have been rewritten to it.
        let add = caller.blocks[0]
            .insts
            .iter()
            .find(|inst| matches!(inst.kind, InstKind::IAdd(..)))
            .expect("caller should still contain the sum");
        match add.kind {
            InstKind::IAdd(lhs, rhs) => {
                assert_eq!(lhs, kept_call);
                assert_eq!(rhs, kept_call);
            }
            _ => unreachable!(),
        }
    }
929
    #[test]
    fn gvn_reuses_pure_calls_after_operand_canonicalization() {
        // The two calls take *different* SSA values (sum1 vs sum2)
        // that GVN itself proves equal in an earlier round; the
        // fixpoint loop must then also merge the calls. This is the
        // congruence-closure behavior the 8-round loop in `run`
        // exists for.
        let mut m = Module::new("test".into());

        // PURE callee with one by-value scalar parameter.
        let param = Param {
            name: "x".into(),
            ty: IrType::Int(IntWidth::I32),
            id: ValueId(0),
            fortran_noalias: false,
        };
        let mut callee = Function::new("square".into(), vec![param], IrType::Int(IntWidth::I32));
        callee.is_pure = true;
        callee.block_mut(callee.entry).terminator = Some(Terminator::Return(Some(ValueId(0))));
        m.add_function(callee);

        // Caller: sum1 = a+b; sum2 = a+b; call1 = f(sum1); call2 = f(sum2).
        let mut caller = Function::new("main".into(), vec![], IrType::Int(IntWidth::I32));
        let entry = caller.entry;
        let a = push_inst(
            &mut caller,
            entry,
            InstKind::ConstInt(3, IntWidth::I32),
            IrType::Int(IntWidth::I32),
        );
        let b = push_inst(
            &mut caller,
            entry,
            InstKind::ConstInt(4, IntWidth::I32),
            IrType::Int(IntWidth::I32),
        );
        let sum1 = push_inst(
            &mut caller,
            entry,
            InstKind::IAdd(a, b),
            IrType::Int(IntWidth::I32),
        );
        let sum2 = push_inst(
            &mut caller,
            entry,
            InstKind::IAdd(a, b),
            IrType::Int(IntWidth::I32),
        );
        let call1 = push_inst(
            &mut caller,
            entry,
            InstKind::Call(FuncRef::Internal(0), vec![sum1]),
            IrType::Int(IntWidth::I32),
        );
        let call2 = push_inst(
            &mut caller,
            entry,
            InstKind::Call(FuncRef::Internal(0), vec![sum2]),
            IrType::Int(IntWidth::I32),
        );
        caller.block_mut(entry).terminator = Some(Terminator::Return(Some(call2)));
        m.add_function(caller);

        let pass = Gvn;
        assert!(pass.run(&mut m));

        // Only the first call should remain after the fixpoint.
        let caller = &m.functions[1];
        let call_count = caller.blocks[0]
            .insts
            .iter()
            .filter(|inst| matches!(inst.kind, InstKind::Call(..)))
            .count();
        assert_eq!(
            call_count, 1,
            "PURE call should reuse equivalent dominating args"
        );

        // The return (originally call2) must now reference call1.
        match caller.blocks[0]
            .terminator
            .as_ref()
            .expect("return terminator")
        {
            Terminator::Return(Some(v)) => assert_eq!(*v, call1),
            other => panic!("unexpected terminator: {:?}", other),
        }
    }
1009
    #[test]
    fn pure_pointer_param_policy_accepts_read_only_scalar_pattern() {
        // Build the canonical lowered by-reference pattern: the
        // incoming i32* param is spilled into a shadow alloca,
        // re-loaded, and dereferenced once. The policy classifier
        // must accept this as ReadOnlyWrapperPtr and mark the
        // function reusable.
        let param = Param {
            name: "n".into(),
            ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
            id: ValueId(0),
            fortran_noalias: false,
        };
        let mut callee =
            Function::new("heavy_fact".into(), vec![param], IrType::Int(IntWidth::I32));
        callee.is_pure = true;
        let entry = callee.entry;

        // slot: i32** shadow alloca holding the incoming pointer.
        let slot = push_inst(
            &mut callee,
            entry,
            InstKind::Alloca(IrType::Ptr(Box::new(IrType::Int(IntWidth::I32)))),
            IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))))),
        );
        // Spill the parameter pointer into the slot.
        push_inst(
            &mut callee,
            entry,
            InstKind::Store(ValueId(0), slot),
            IrType::Void,
        );
        // Re-load the pointer, then read the scalar through it.
        let arg_ptr = push_inst(
            &mut callee,
            entry,
            InstKind::Load(slot),
            IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
        );
        let arg_val = push_inst(
            &mut callee,
            entry,
            InstKind::Load(arg_ptr),
            IrType::Int(IntWidth::I32),
        );
        callee.block_mut(entry).terminator = Some(Terminator::Return(Some(arg_val)));

        let policy = PureCallPolicy::for_function(&callee);
        assert!(policy.reusable);
        assert_eq!(policy.arg_policies, vec![PureArgPolicy::ReadOnlyWrapperPtr]);
    }
1053
1054 #[test]
1055 fn wrapper_alloca_values_accept_scalar_by_ref_call_pattern() {
1056 let mut m = Module::new("test".into());
1057
1058 let param = Param {
1059 name: "n".into(),
1060 ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1061 id: ValueId(0),
1062 fortran_noalias: false,
1063 };
1064 let mut callee =
1065 Function::new("heavy_fact".into(), vec![param], IrType::Int(IntWidth::I32));
1066 callee.is_pure = true;
1067 let callee_entry = callee.entry;
1068 let slot = push_inst(
1069 &mut callee,
1070 callee_entry,
1071 InstKind::Alloca(IrType::Ptr(Box::new(IrType::Int(IntWidth::I32)))),
1072 IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))))),
1073 );
1074 push_inst(
1075 &mut callee,
1076 callee_entry,
1077 InstKind::Store(ValueId(0), slot),
1078 IrType::Void,
1079 );
1080 let arg_ptr = push_inst(
1081 &mut callee,
1082 callee_entry,
1083 InstKind::Load(slot),
1084 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1085 );
1086 let arg_val = push_inst(
1087 &mut callee,
1088 callee_entry,
1089 InstKind::Load(arg_ptr),
1090 IrType::Int(IntWidth::I32),
1091 );
1092 callee.block_mut(callee_entry).terminator = Some(Terminator::Return(Some(arg_val)));
1093 m.add_function(callee);
1094
1095 let pure_calls = vec![PureCallPolicy::for_function(&m.functions[0])];
1096
1097 let mut caller = Function::new("main".into(), vec![], IrType::Int(IntWidth::I32));
1098 let entry = caller.entry;
1099 let c = push_inst(
1100 &mut caller,
1101 entry,
1102 InstKind::ConstInt(6, IntWidth::I32),
1103 IrType::Int(IntWidth::I32),
1104 );
1105 let wrapper = push_inst(
1106 &mut caller,
1107 entry,
1108 InstKind::Alloca(IrType::Int(IntWidth::I32)),
1109 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1110 );
1111 push_inst(
1112 &mut caller,
1113 entry,
1114 InstKind::Store(c, wrapper),
1115 IrType::Void,
1116 );
1117 let call = push_inst(
1118 &mut caller,
1119 entry,
1120 InstKind::Call(FuncRef::Internal(0), vec![wrapper]),
1121 IrType::Int(IntWidth::I32),
1122 );
1123 caller.block_mut(entry).terminator = Some(Terminator::Return(Some(call)));
1124
1125 let wrappers = wrapper_alloca_values(&caller, &pure_calls);
1126 assert_eq!(wrappers.get(&wrapper), Some(&c));
1127 }
1128
    #[test]
    fn gvn_reuses_pure_calls_through_scalar_wrapper_allocas() {
        // Two calls to the same PURE callee pass *distinct* wrapper allocas
        // (wrap1, wrap2) that hold equal constants. GVN should look through
        // the wrappers via wrapper_alloca_values, give both calls the same
        // key, and dedupe them down to a single call instruction.
        let mut m = Module::new("test".into());

        // PURE callee taking a scalar by reference (Fortran-style).
        let param = Param {
            name: "n".into(),
            ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
            id: ValueId(0),
            fortran_noalias: false,
        };
        let mut callee =
            Function::new("heavy_fact".into(), vec![param], IrType::Int(IntWidth::I32));
        callee.is_pure = true;
        let callee_entry = callee.entry;
        // Read-only access pattern: shadow-slot store, pointer reload,
        // then a load of the pointed-to scalar. This is the exact shape
        // PureCallPolicy accepts as ReadOnlyWrapperPtr.
        let shadow = push_inst(
            &mut callee,
            callee_entry,
            InstKind::Alloca(IrType::Ptr(Box::new(IrType::Int(IntWidth::I32)))),
            IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))))),
        );
        push_inst(
            &mut callee,
            callee_entry,
            InstKind::Store(ValueId(0), shadow),
            IrType::Void,
        );
        let arg_ptr = push_inst(
            &mut callee,
            callee_entry,
            InstKind::Load(shadow),
            IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
        );
        let arg_val = push_inst(
            &mut callee,
            callee_entry,
            InstKind::Load(arg_ptr),
            IrType::Int(IntWidth::I32),
        );
        callee.block_mut(callee_entry).terminator = Some(Terminator::Return(Some(arg_val)));
        m.add_function(callee);

        // Caller: two call sites, each storing an (equal-valued) constant
        // into its own wrapper alloca before the call. `first`/`second`
        // hold the call results so neither call is trivially dead.
        let mut caller = Function::new("main".into(), vec![], IrType::Int(IntWidth::I32));
        let entry = caller.entry;
        let first = push_inst(
            &mut caller,
            entry,
            InstKind::Alloca(IrType::Int(IntWidth::I32)),
            IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
        );
        let second = push_inst(
            &mut caller,
            entry,
            InstKind::Alloca(IrType::Int(IntWidth::I32)),
            IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
        );
        // Four separate ConstInt(6) instructions: GVN must first merge the
        // constants before the wrapper-based call keys can line up.
        let c1 = push_inst(
            &mut caller,
            entry,
            InstKind::ConstInt(6, IntWidth::I32),
            IrType::Int(IntWidth::I32),
        );
        let c2 = push_inst(
            &mut caller,
            entry,
            InstKind::ConstInt(6, IntWidth::I32),
            IrType::Int(IntWidth::I32),
        );
        let wrap1 = push_inst(
            &mut caller,
            entry,
            InstKind::Alloca(IrType::Int(IntWidth::I32)),
            IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
        );
        push_inst(&mut caller, entry, InstKind::Store(c2, wrap1), IrType::Void);
        let call1 = push_inst(
            &mut caller,
            entry,
            InstKind::Call(FuncRef::Internal(0), vec![wrap1]),
            IrType::Int(IntWidth::I32),
        );
        push_inst(
            &mut caller,
            entry,
            InstKind::Store(call1, first),
            IrType::Void,
        );
        let c3 = push_inst(
            &mut caller,
            entry,
            InstKind::ConstInt(6, IntWidth::I32),
            IrType::Int(IntWidth::I32),
        );
        let c4 = push_inst(
            &mut caller,
            entry,
            InstKind::ConstInt(6, IntWidth::I32),
            IrType::Int(IntWidth::I32),
        );
        let wrap2 = push_inst(
            &mut caller,
            entry,
            InstKind::Alloca(IrType::Int(IntWidth::I32)),
            IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
        );
        push_inst(&mut caller, entry, InstKind::Store(c4, wrap2), IrType::Void);
        let call2 = push_inst(
            &mut caller,
            entry,
            InstKind::Call(FuncRef::Internal(0), vec![wrap2]),
            IrType::Int(IntWidth::I32),
        );
        push_inst(
            &mut caller,
            entry,
            InstKind::Store(call2, second),
            IrType::Void,
        );
        let left = push_inst(
            &mut caller,
            entry,
            InstKind::Load(first),
            IrType::Int(IntWidth::I32),
        );
        let right = push_inst(
            &mut caller,
            entry,
            InstKind::Load(second),
            IrType::Int(IntWidth::I32),
        );
        let sum = push_inst(
            &mut caller,
            entry,
            InstKind::IAdd(left, right),
            IrType::Int(IntWidth::I32),
        );
        caller.block_mut(entry).terminator = Some(Terminator::Return(Some(sum)));
        m.add_function(caller);

        // Sanity-check the analysis pieces individually before running GVN.
        let pure_calls: Vec<PureCallPolicy> = m
            .functions
            .iter()
            .map(PureCallPolicy::for_function)
            .collect();
        assert!(pure_calls[0].reusable);
        assert_eq!(
            pure_calls[0].arg_policies,
            vec![PureArgPolicy::ReadOnlyWrapperPtr]
        );
        let wrappers = wrapper_alloca_values(&m.functions[1], &pure_calls);
        assert_eq!(wrappers.get(&wrap1), Some(&c2));
        assert_eq!(wrappers.get(&wrap2), Some(&c4));

        // Simulate the state after GVN has merged the duplicate constants
        // (c2/c3/c4 all replaced by c1): the two call keys must then match.
        let mut replacements = HashMap::new();
        replacements.insert(c2, c1);
        replacements.insert(c3, c1);
        replacements.insert(c4, c1);
        let caller_before = &m.functions[1];
        let call1_inst = caller_before.blocks[0]
            .insts
            .iter()
            .find(|inst| inst.id == call1)
            .expect("call1 inst");
        let call2_inst = caller_before.blocks[0]
            .insts
            .iter()
            .find(|inst| inst.id == call2)
            .expect("call2 inst");
        let test_defs = inst_map(caller_before);
        let key1 = key_of(call1_inst, &replacements, &pure_calls, &wrappers, &test_defs)
            .expect("call1 should key");
        let key2 = key_of(call2_inst, &replacements, &pure_calls, &wrappers, &test_defs)
            .expect("call2 should key");
        assert_eq!(
            key1, key2,
            "wrapper-based call keys should line up before the pass runs"
        );

        // End-to-end: the pass itself must collapse the two calls into one.
        let pass = Gvn;
        assert!(pass.run(&mut m));

        let caller = &m.functions[1];
        let call_count = caller.blocks[0]
            .insts
            .iter()
            .filter(|inst| matches!(inst.kind, InstKind::Call(..)))
            .count();
        assert_eq!(call_count, 1, "wrapper-alloca PURE calls should dedupe");
    }
1317
    #[test]
    fn gvn_reuses_pure_calls_for_recursive_factorial_shape() {
        // Mirrors the IR shape a lowered recursive PURE factorial produces:
        // a diamond CFG in the callee with a recursive wrapper-alloca call in
        // the else branch, and a caller that invokes it twice with equal
        // constants through distinct wrappers. GVN should dedupe the two
        // caller-side calls despite the callee's own internal recursion.
        let mut m = Module::new("test".into());

        let param = Param {
            name: "n".into(),
            ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
            id: ValueId(0),
            fortran_noalias: false,
        };
        let mut callee =
            Function::new("heavy_fact".into(), vec![param], IrType::Int(IntWidth::I32));
        callee.is_pure = true;
        let entry = callee.entry;
        // NOTE: if_end is created before if_then/if_else, so block indices do
        // not follow control-flow order — the walk must not rely on that.
        let if_end = callee.create_block("if_end");
        let if_then = callee.create_block("if_then");
        let if_else = callee.create_block("if_else");

        // Entry: shadow-slot the by-ref param, alloca a result slot, load
        // n, and branch on n <= 1.
        let shadow = push_inst(
            &mut callee,
            entry,
            InstKind::Alloca(IrType::Ptr(Box::new(IrType::Int(IntWidth::I32)))),
            IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))))),
        );
        push_inst(
            &mut callee,
            entry,
            InstKind::Store(ValueId(0), shadow),
            IrType::Void,
        );
        let result = push_inst(
            &mut callee,
            entry,
            InstKind::Alloca(IrType::Int(IntWidth::I32)),
            IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
        );
        let p1 = push_inst(
            &mut callee,
            entry,
            InstKind::Load(shadow),
            IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
        );
        let n1 = push_inst(
            &mut callee,
            entry,
            InstKind::Load(p1),
            IrType::Int(IntWidth::I32),
        );
        let one = push_inst(
            &mut callee,
            entry,
            InstKind::ConstInt(1, IntWidth::I32),
            IrType::Int(IntWidth::I32),
        );
        let cond = push_inst(
            &mut callee,
            entry,
            InstKind::ICmp(CmpOp::Le, n1, one),
            IrType::Bool,
        );
        callee.block_mut(entry).terminator = Some(Terminator::CondBranch {
            cond,
            true_dest: if_then,
            true_args: vec![],
            false_dest: if_else,
            false_args: vec![],
        });

        // then-branch: base case, result = 1.
        let then_one = push_inst(
            &mut callee,
            if_then,
            InstKind::ConstInt(1, IntWidth::I32),
            IrType::Int(IntWidth::I32),
        );
        push_inst(
            &mut callee,
            if_then,
            InstKind::Store(then_one, result),
            IrType::Void,
        );
        callee.block_mut(if_then).terminator = Some(Terminator::Branch(if_end, vec![]));

        // else-branch: deliberately redundant reloads of the param pointer
        // (p2/p3/p4 -> n2/n3/n4) and a duplicated n-1 (dec1 is never used;
        // dec2 feeds the recursive call) — the redundancy gives GVN
        // cross-instruction equivalences to discover inside the callee too.
        let p2 = push_inst(
            &mut callee,
            if_else,
            InstKind::Load(shadow),
            IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
        );
        let n2 = push_inst(
            &mut callee,
            if_else,
            InstKind::Load(p2),
            IrType::Int(IntWidth::I32),
        );
        let p3 = push_inst(
            &mut callee,
            if_else,
            InstKind::Load(shadow),
            IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
        );
        let n3 = push_inst(
            &mut callee,
            if_else,
            InstKind::Load(p3),
            IrType::Int(IntWidth::I32),
        );
        let one2 = push_inst(
            &mut callee,
            if_else,
            InstKind::ConstInt(1, IntWidth::I32),
            IrType::Int(IntWidth::I32),
        );
        let dec1 = push_inst(
            &mut callee,
            if_else,
            InstKind::ISub(n3, one2),
            IrType::Int(IntWidth::I32),
        );
        let p4 = push_inst(
            &mut callee,
            if_else,
            InstKind::Load(shadow),
            IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
        );
        let n4 = push_inst(
            &mut callee,
            if_else,
            InstKind::Load(p4),
            IrType::Int(IntWidth::I32),
        );
        let one3 = push_inst(
            &mut callee,
            if_else,
            InstKind::ConstInt(1, IntWidth::I32),
            IrType::Int(IntWidth::I32),
        );
        let dec2 = push_inst(
            &mut callee,
            if_else,
            InstKind::ISub(n4, one3),
            IrType::Int(IntWidth::I32),
        );
        // Recursive call goes through its own wrapper alloca, like the
        // caller-side call sites.
        let wrap = push_inst(
            &mut callee,
            if_else,
            InstKind::Alloca(IrType::Int(IntWidth::I32)),
            IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
        );
        push_inst(
            &mut callee,
            if_else,
            InstKind::Store(dec2, wrap),
            IrType::Void,
        );
        let rec = push_inst(
            &mut callee,
            if_else,
            InstKind::Call(FuncRef::Internal(0), vec![wrap]),
            IrType::Int(IntWidth::I32),
        );
        let prod = push_inst(
            &mut callee,
            if_else,
            InstKind::IMul(n2, rec),
            IrType::Int(IntWidth::I32),
        );
        push_inst(
            &mut callee,
            if_else,
            InstKind::Store(prod, result),
            IrType::Void,
        );
        callee.block_mut(if_else).terminator = Some(Terminator::Branch(if_end, vec![]));

        // Join block: return whatever was stored into the result slot.
        let final_val = push_inst(
            &mut callee,
            if_end,
            InstKind::Load(result),
            IrType::Int(IntWidth::I32),
        );
        callee.block_mut(if_end).terminator = Some(Terminator::Return(Some(final_val)));
        m.add_function(callee);

        // Caller: two wrapper-alloca calls with equal constants (the first
        // wrapper stores unit2, the second stores unit — the cross-wiring is
        // harmless since both are ConstInt(6)).
        // NOTE(review): the caller is declared Void yet returns `sum`;
        // presumably the tested pass does not validate terminators — confirm.
        let mut caller = Function::new("main".into(), vec![], IrType::Void);
        let caller_entry = caller.entry;
        let unit = push_inst(
            &mut caller,
            caller_entry,
            InstKind::ConstInt(6, IntWidth::I32),
            IrType::Int(IntWidth::I32),
        );
        let unit2 = push_inst(
            &mut caller,
            caller_entry,
            InstKind::ConstInt(6, IntWidth::I32),
            IrType::Int(IntWidth::I32),
        );
        let wrap1 = push_inst(
            &mut caller,
            caller_entry,
            InstKind::Alloca(IrType::Int(IntWidth::I32)),
            IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
        );
        push_inst(
            &mut caller,
            caller_entry,
            InstKind::Store(unit2, wrap1),
            IrType::Void,
        );
        let call1 = push_inst(
            &mut caller,
            caller_entry,
            InstKind::Call(FuncRef::Internal(0), vec![wrap1]),
            IrType::Int(IntWidth::I32),
        );
        let wrap2 = push_inst(
            &mut caller,
            caller_entry,
            InstKind::Alloca(IrType::Int(IntWidth::I32)),
            IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
        );
        push_inst(
            &mut caller,
            caller_entry,
            InstKind::Store(unit, wrap2),
            IrType::Void,
        );
        let call2 = push_inst(
            &mut caller,
            caller_entry,
            InstKind::Call(FuncRef::Internal(0), vec![wrap2]),
            IrType::Int(IntWidth::I32),
        );
        let sum = push_inst(
            &mut caller,
            caller_entry,
            InstKind::IAdd(call1, call2),
            IrType::Int(IntWidth::I32),
        );
        caller.block_mut(caller_entry).terminator = Some(Terminator::Return(Some(sum)));
        m.add_function(caller);

        let pass = Gvn;
        assert!(pass.run(&mut m));
        let caller = &m.functions[1];
        let call_count = caller.blocks[0]
            .insts
            .iter()
            .filter(|inst| matches!(inst.kind, InstKind::Call(..)))
            .count();
        assert_eq!(
            call_count, 1,
            "recursive PURE factorial calls should dedupe"
        );
    }
1573
1574 #[test]
1575 fn gvn_reuses_pure_calls_when_callee_follows_caller() {
1576 let mut m = Module::new("test".into());
1577
1578 let mut caller = Function::new("main".into(), vec![], IrType::Int(IntWidth::I32));
1579 let entry = caller.entry;
1580 let c1 = push_inst(
1581 &mut caller,
1582 entry,
1583 InstKind::ConstInt(6, IntWidth::I32),
1584 IrType::Int(IntWidth::I32),
1585 );
1586 let wrap1 = push_inst(
1587 &mut caller,
1588 entry,
1589 InstKind::Alloca(IrType::Int(IntWidth::I32)),
1590 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1591 );
1592 push_inst(&mut caller, entry, InstKind::Store(c1, wrap1), IrType::Void);
1593 let call1 = push_inst(
1594 &mut caller,
1595 entry,
1596 InstKind::Call(FuncRef::Internal(1), vec![wrap1]),
1597 IrType::Int(IntWidth::I32),
1598 );
1599 let c2 = push_inst(
1600 &mut caller,
1601 entry,
1602 InstKind::ConstInt(6, IntWidth::I32),
1603 IrType::Int(IntWidth::I32),
1604 );
1605 let wrap2 = push_inst(
1606 &mut caller,
1607 entry,
1608 InstKind::Alloca(IrType::Int(IntWidth::I32)),
1609 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1610 );
1611 push_inst(&mut caller, entry, InstKind::Store(c2, wrap2), IrType::Void);
1612 let call2 = push_inst(
1613 &mut caller,
1614 entry,
1615 InstKind::Call(FuncRef::Internal(1), vec![wrap2]),
1616 IrType::Int(IntWidth::I32),
1617 );
1618 let sum = push_inst(
1619 &mut caller,
1620 entry,
1621 InstKind::IAdd(call1, call2),
1622 IrType::Int(IntWidth::I32),
1623 );
1624 caller.block_mut(entry).terminator = Some(Terminator::Return(Some(sum)));
1625 m.add_function(caller);
1626
1627 let param = Param {
1628 name: "n".into(),
1629 ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1630 id: ValueId(0),
1631 fortran_noalias: false,
1632 };
1633 let mut callee =
1634 Function::new("heavy_fact".into(), vec![param], IrType::Int(IntWidth::I32));
1635 callee.is_pure = true;
1636 let callee_entry = callee.entry;
1637 let shadow = push_inst(
1638 &mut callee,
1639 callee_entry,
1640 InstKind::Alloca(IrType::Ptr(Box::new(IrType::Int(IntWidth::I32)))),
1641 IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))))),
1642 );
1643 push_inst(
1644 &mut callee,
1645 callee_entry,
1646 InstKind::Store(ValueId(0), shadow),
1647 IrType::Void,
1648 );
1649 let arg_ptr = push_inst(
1650 &mut callee,
1651 callee_entry,
1652 InstKind::Load(shadow),
1653 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1654 );
1655 let arg_val = push_inst(
1656 &mut callee,
1657 callee_entry,
1658 InstKind::Load(arg_ptr),
1659 IrType::Int(IntWidth::I32),
1660 );
1661 callee.block_mut(callee_entry).terminator = Some(Terminator::Return(Some(arg_val)));
1662 m.add_function(callee);
1663
1664 let pass = Gvn;
1665 assert!(pass.run(&mut m));
1666 let caller = &m.functions[0];
1667 let call_count = caller.blocks[0]
1668 .insts
1669 .iter()
1670 .filter(|inst| matches!(inst.kind, InstKind::Call(..)))
1671 .count();
1672 assert_eq!(
1673 call_count, 1,
1674 "caller-before-callee ordering should still dedupe"
1675 );
1676 }
1677
1678 #[test]
1679 fn gvn_matches_real_pure_recursive_fixture_after_pre_o2_passes() {
1680 let mut module = lower_fixture("pure_recursive_reuse.f90");
1681 let pm = build_pre_gvn_o2_pipeline();
1682 pm.run(&mut module);
1683
1684 let pure_calls: Vec<PureCallPolicy> = module
1685 .functions
1686 .iter()
1687 .map(PureCallPolicy::for_function)
1688 .collect();
1689 assert!(
1690 pure_calls.iter().any(|policy| policy.reusable),
1691 "at least one function in the fixture should remain a reusable PURE callee:\n{}",
1692 crate::ir::printer::print_module(&module)
1693 );
1694
1695 let caller_idx = module
1696 .functions
1697 .iter()
1698 .position(|func| func.name == "__prog_pure_recursive_reuse")
1699 .expect("program entry in fixture");
1700 let wrappers = wrapper_alloca_values(&module.functions[caller_idx], &pure_calls);
1701 assert!(
1702 !wrappers.is_empty(),
1703 "caller should still expose scalar by-ref wrapper allocas before GVN"
1704 );
1705
1706 let pass = Gvn;
1707 assert!(
1708 pass.run(&mut module),
1709 "GVN should make progress on the real fixture"
1710 );
1711
1712 let caller = &module.functions[caller_idx];
1713 let call_count = caller.blocks[0]
1714 .insts
1715 .iter()
1716 .filter(|inst| matches!(inst.kind, InstKind::Call(FuncRef::Internal(_), _)))
1717 .count();
1718 assert_eq!(
1719 call_count, 1,
1720 "real fixture caller should end with one PURE recursive call"
1721 );
1722 }
1723
1724 #[test]
1725 fn gvn_does_not_reuse_pure_calls_with_pointer_args() {
1726 let mut m = Module::new("test".into());
1727
1728 let param = Param {
1729 name: "p".into(),
1730 ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1731 id: ValueId(0),
1732 fortran_noalias: false,
1733 };
1734 let mut callee = Function::new("peek".into(), vec![param], IrType::Int(IntWidth::I32));
1735 callee.is_pure = true;
1736 let callee_entry = callee.entry;
1737 let zero = push_inst(
1738 &mut callee,
1739 callee_entry,
1740 InstKind::ConstInt(0, IntWidth::I32),
1741 IrType::Int(IntWidth::I32),
1742 );
1743 callee.block_mut(callee_entry).terminator = Some(Terminator::Return(Some(zero)));
1744 m.add_function(callee);
1745
1746 let mut caller = Function::new("main".into(), vec![], IrType::Int(IntWidth::I32));
1747 let entry = caller.entry;
1748 let slot = push_inst(
1749 &mut caller,
1750 entry,
1751 InstKind::Alloca(IrType::Int(IntWidth::I32)),
1752 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1753 );
1754 let call1 = push_inst(
1755 &mut caller,
1756 entry,
1757 InstKind::Call(FuncRef::Internal(0), vec![slot]),
1758 IrType::Int(IntWidth::I32),
1759 );
1760 let call2 = push_inst(
1761 &mut caller,
1762 entry,
1763 InstKind::Call(FuncRef::Internal(0), vec![slot]),
1764 IrType::Int(IntWidth::I32),
1765 );
1766 let sum = push_inst(
1767 &mut caller,
1768 entry,
1769 InstKind::IAdd(call1, call2),
1770 IrType::Int(IntWidth::I32),
1771 );
1772 caller.block_mut(entry).terminator = Some(Terminator::Return(Some(sum)));
1773 m.add_function(caller);
1774
1775 let pass = Gvn;
1776 assert!(
1777 !pass.run(&mut m),
1778 "pointer-arg PURE calls must stay distinct"
1779 );
1780
1781 let caller = &m.functions[1];
1782 let call_count = caller.blocks[0]
1783 .insts
1784 .iter()
1785 .filter(|inst| matches!(inst.kind, InstKind::Call(..)))
1786 .count();
1787 assert_eq!(call_count, 2);
1788 }
1789 }
1790