Rust · 55060 bytes Raw Blame History
1 //! Global Value Numbering (GVN).
2 //!
3 //! Extends local CSE to cross-block redundancy elimination. Processes
4 //! blocks in dominator-tree preorder, maintaining a scoped hash table
5 //! of value numbers. When a block computes an expression already
6 //! available from a dominating block, the redundant instruction is
7 //! replaced with the dominating definition.
8 //!
9 //! Uses the same canonical Key structure as local CSE (cse.rs).
10
11 use super::pass::Pass;
12 use crate::ir::inst::*;
13 use crate::ir::types::IrType;
14 use crate::ir::walk::{compute_immediate_dominators, dominator_tree_children};
15 use std::collections::HashMap;
16
17 pub struct Gvn;
18
19 impl Pass for Gvn {
20 fn name(&self) -> &'static str {
21 "gvn"
22 }
23
24 fn run(&self, module: &mut Module) -> bool {
25 let pure_calls: Vec<PureCallPolicy> = module
26 .functions
27 .iter()
28 .map(PureCallPolicy::for_function)
29 .collect();
30 let mut changed = false;
31 for func in &mut module.functions {
32 if gvn_function(func, &pure_calls) {
33 changed = true;
34 }
35 }
36 changed
37 }
38 }
39
40 /// Canonical key for value numbering (same as CSE).
41 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
42 struct Key {
43 tag: u32,
44 operands: Vec<ValueId>,
45 aux: i128,
46 name: Option<String>,
47 ty: IrType,
48 }
49
50 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
51 enum PureArgPolicy {
52 ByValue,
53 ReadOnlyWrapperPtr,
54 Unsupported,
55 }
56
57 #[derive(Debug, Clone)]
58 struct PureCallPolicy {
59 reusable: bool,
60 arg_policies: Vec<PureArgPolicy>,
61 }
62
63 impl PureCallPolicy {
64 fn for_function(func: &Function) -> Self {
65 let arg_policies = func
66 .params
67 .iter()
68 .map(|param| classify_pure_arg(func, param))
69 .collect::<Vec<_>>();
70 let reusable = func.is_pure
71 && !matches!(func.return_type, IrType::Void)
72 && !func.return_type.is_ptr()
73 && arg_policies
74 .iter()
75 .all(|policy| *policy != PureArgPolicy::Unsupported)
76 && !reads_non_argument_memory(func);
77 Self {
78 reusable,
79 arg_policies,
80 }
81 }
82 }
83
84 /// True if the function body touches any state outside of its
85 /// arguments — e.g. a module global via `GlobalAddr`, an external
86 /// call, or a runtime call that reads mutable state.
87 ///
88 /// Per F2018 15.7 a PURE function is free to *read* common blocks
89 /// and module variables; it just can't *write* them. The Fortran
90 /// `is_pure` flag therefore doesn't imply "result depends only on
91 /// arguments" — a pure recursive accumulator over a module
92 /// variable is fully legal, and GVN must not hash-cons such a call
93 /// across an intervening store to that variable.
94 ///
95 /// We conservatively reject any callee that contains a `GlobalAddr`
96 /// or a non-Internal `Call` / `RuntimeCall`. The existing
97 /// argument-policy machinery already handles the "reads through
98 /// argument pointer" case via `ReadOnlyWrapperPtr`.
99 fn reads_non_argument_memory(func: &Function) -> bool {
100 for block in &func.blocks {
101 for inst in &block.insts {
102 match &inst.kind {
103 InstKind::GlobalAddr(_) => return true,
104 InstKind::Call(FuncRef::External(_), _) => return true,
105 InstKind::RuntimeCall(_, _) => return true,
106 _ => {}
107 }
108 }
109 }
110 false
111 }
112
113 fn is_scalar_type(ty: &IrType) -> bool {
114 matches!(ty, IrType::Bool | IrType::Int(_) | IrType::Float(_))
115 }
116
117 fn inst_map(func: &Function) -> HashMap<ValueId, &Inst> {
118 func.blocks
119 .iter()
120 .flat_map(|block| block.insts.iter())
121 .map(|inst| (inst.id, inst))
122 .collect()
123 }
124
125 fn classify_pure_arg(func: &Function, param: &Param) -> PureArgPolicy {
126 if !param.ty.is_ptr() {
127 return PureArgPolicy::ByValue;
128 }
129 let IrType::Ptr(inner) = &param.ty else {
130 return PureArgPolicy::Unsupported;
131 };
132 if !is_scalar_type(inner.as_ref()) {
133 return PureArgPolicy::Unsupported;
134 }
135 if pointer_param_is_read_only_scalar(func, param.id) {
136 PureArgPolicy::ReadOnlyWrapperPtr
137 } else {
138 PureArgPolicy::Unsupported
139 }
140 }
141
142 fn pointer_param_is_read_only_scalar(func: &Function, param_id: ValueId) -> bool {
143 let defs = inst_map(func);
144 let mut shadow_slots = Vec::new();
145
146 for block in &func.blocks {
147 for inst in &block.insts {
148 let uses = super::util::inst_uses(&inst.kind);
149 if !uses.contains(&param_id) {
150 continue;
151 }
152 match &inst.kind {
153 InstKind::Store(value, addr) if *value == param_id => {
154 let Some(def) = defs.get(addr) else {
155 return false;
156 };
157 match &def.kind {
158 InstKind::Alloca(slot_ty) => {
159 let IrType::Ptr(inner) = slot_ty else {
160 return false;
161 };
162 if !is_scalar_type(inner.as_ref()) {
163 return false;
164 }
165 shadow_slots.push(*addr);
166 }
167 _ => return false,
168 }
169 }
170 // After mem2reg, the entry shadow slot often folds away and
171 // the PURE callee just reads the incoming by-ref scalar
172 // directly.
173 InstKind::Load(addr) if *addr == param_id => {}
174 _ => return false,
175 }
176 }
177 if let Some(term) = &block.terminator {
178 if super::util::terminator_uses(term).contains(&param_id) {
179 return false;
180 }
181 }
182 }
183
184 for slot in shadow_slots {
185 let mut alias_ptrs = Vec::new();
186 for block in &func.blocks {
187 for inst in &block.insts {
188 let uses = super::util::inst_uses(&inst.kind);
189 if !uses.contains(&slot) {
190 continue;
191 }
192 match &inst.kind {
193 InstKind::Store(value, addr) if *value == param_id && *addr == slot => {}
194 InstKind::Load(addr) if *addr == slot => alias_ptrs.push(inst.id),
195 _ => return false,
196 }
197 }
198 if let Some(term) = &block.terminator {
199 if super::util::terminator_uses(term).contains(&slot) {
200 return false;
201 }
202 }
203 }
204
205 for alias in alias_ptrs {
206 for block in &func.blocks {
207 for inst in &block.insts {
208 let uses = super::util::inst_uses(&inst.kind);
209 if !uses.contains(&alias) {
210 continue;
211 }
212 match &inst.kind {
213 InstKind::Load(addr) if *addr == alias => {}
214 _ => return false,
215 }
216 }
217 if let Some(term) = &block.terminator {
218 if super::util::terminator_uses(term).contains(&alias) {
219 return false;
220 }
221 }
222 }
223 }
224 }
225
226 true
227 }
228
229 fn wrapper_alloca_values(
230 func: &Function,
231 pure_calls: &[PureCallPolicy],
232 ) -> HashMap<ValueId, ValueId> {
233 let mut wrappers = HashMap::new();
234 let candidates: Vec<ValueId> = func
235 .blocks
236 .iter()
237 .flat_map(|block| block.insts.iter())
238 .filter_map(|inst| match &inst.kind {
239 InstKind::Alloca(inner_ty) if is_scalar_type(inner_ty) => Some(inst.id),
240 _ => None,
241 })
242 .collect();
243
244 'candidate: for alloca_id in candidates {
245 let mut stored_value = None;
246
247 for scan_block in &func.blocks {
248 for scan_inst in &scan_block.insts {
249 let uses = super::util::inst_uses(&scan_inst.kind);
250 if !uses.contains(&alloca_id) {
251 continue;
252 }
253 match &scan_inst.kind {
254 InstKind::Store(value, addr) if *addr == alloca_id => {
255 if stored_value.replace(*value).is_some() {
256 continue 'candidate;
257 }
258 }
259 InstKind::Call(FuncRef::Internal(idx), args) => {
260 let Some(policy) = pure_calls.get(*idx as usize) else {
261 continue 'candidate;
262 };
263 if !policy.reusable {
264 continue 'candidate;
265 }
266 let valid_arg = args.iter().enumerate().any(|(arg_idx, arg)| {
267 *arg == alloca_id
268 && policy.arg_policies.get(arg_idx).copied()
269 == Some(PureArgPolicy::ReadOnlyWrapperPtr)
270 });
271 if !valid_arg {
272 continue 'candidate;
273 }
274 }
275 _ => continue 'candidate,
276 }
277 }
278 if let Some(term) = &scan_block.terminator {
279 if super::util::terminator_uses(term).contains(&alloca_id) {
280 continue 'candidate;
281 }
282 }
283 }
284
285 if let Some(value) = stored_value {
286 wrappers.insert(alloca_id, value);
287 }
288 }
289
290 wrappers
291 }
292
293 fn resolve_value(map: &HashMap<ValueId, ValueId>, mut value: ValueId) -> ValueId {
294 while let Some(&next) = map.get(&value) {
295 if next == value {
296 break;
297 }
298 value = next;
299 }
300 value
301 }
302
303 fn key_of(
304 inst: &Inst,
305 replacements: &HashMap<ValueId, ValueId>,
306 pure_calls: &[PureCallPolicy],
307 wrapper_values: &HashMap<ValueId, ValueId>,
308 ) -> Option<Key> {
309 let mk = |tag: u32, ops: Vec<ValueId>, aux: i128| -> Option<Key> {
310 Some(Key {
311 tag,
312 operands: ops,
313 aux,
314 name: None,
315 ty: inst.ty.clone(),
316 })
317 };
318 let mk_named = |tag: u32, name: String| -> Option<Key> {
319 Some(Key {
320 tag,
321 operands: vec![],
322 aux: 0,
323 name: Some(name),
324 ty: inst.ty.clone(),
325 })
326 };
327 let remap = |v: ValueId| resolve_value(replacements, v);
328 fn canon(a: ValueId, b: ValueId) -> Vec<ValueId> {
329 if a.0 <= b.0 {
330 vec![a, b]
331 } else {
332 vec![b, a]
333 }
334 }
335
336 match &inst.kind {
337 // Pure arithmetic — commutative ops get canonicalized operand order.
338 InstKind::IAdd(a, b) => mk(1, canon(remap(*a), remap(*b)), 0),
339 InstKind::ISub(a, b) => mk(2, vec![remap(*a), remap(*b)], 0),
340 InstKind::IMul(a, b) => mk(3, canon(remap(*a), remap(*b)), 0),
341 InstKind::IDiv(a, b) => mk(4, vec![remap(*a), remap(*b)], 0),
342 InstKind::IMod(a, b) => mk(5, vec![remap(*a), remap(*b)], 0),
343 InstKind::INeg(a) => mk(6, vec![remap(*a)], 0),
344 InstKind::FAdd(a, b) => mk(10, canon(remap(*a), remap(*b)), 0),
345 InstKind::FSub(a, b) => mk(11, vec![remap(*a), remap(*b)], 0),
346 InstKind::FMul(a, b) => mk(12, canon(remap(*a), remap(*b)), 0),
347 InstKind::FDiv(a, b) => mk(13, vec![remap(*a), remap(*b)], 0),
348 InstKind::FNeg(a) => mk(14, vec![remap(*a)], 0),
349 InstKind::FAbs(a) => mk(15, vec![remap(*a)], 0),
350 InstKind::FSqrt(a) => mk(16, vec![remap(*a)], 0),
351 InstKind::FPow(a, b) => mk(17, vec![remap(*a), remap(*b)], 0),
352 InstKind::ICmp(op, a, b) => {
353 let op_val = *op as i128;
354 match op {
355 CmpOp::Eq | CmpOp::Ne => mk(20, canon(remap(*a), remap(*b)), op_val),
356 _ => mk(20, vec![remap(*a), remap(*b)], op_val),
357 }
358 }
359 InstKind::FCmp(op, a, b) => mk(21, vec![remap(*a), remap(*b)], *op as i128),
360 InstKind::And(a, b) => mk(30, canon(remap(*a), remap(*b)), 0),
361 InstKind::Or(a, b) => mk(31, canon(remap(*a), remap(*b)), 0),
362 InstKind::Not(a) => mk(32, vec![remap(*a)], 0),
363 InstKind::Select(c, t, f) => mk(33, vec![remap(*c), remap(*t), remap(*f)], 0),
364 InstKind::BitAnd(a, b) => mk(40, canon(remap(*a), remap(*b)), 0),
365 InstKind::BitOr(a, b) => mk(41, canon(remap(*a), remap(*b)), 0),
366 InstKind::BitXor(a, b) => mk(42, canon(remap(*a), remap(*b)), 0),
367 InstKind::BitNot(a) => mk(43, vec![remap(*a)], 0),
368 InstKind::Shl(a, b) => mk(44, vec![remap(*a), remap(*b)], 0),
369 InstKind::LShr(a, b) => mk(45, vec![remap(*a), remap(*b)], 0),
370 InstKind::AShr(a, b) => mk(46, vec![remap(*a), remap(*b)], 0),
371 InstKind::CountLeadingZeros(a) => mk(47, vec![remap(*a)], 0),
372 InstKind::CountTrailingZeros(a) => mk(48, vec![remap(*a)], 0),
373 InstKind::PopCount(a) => mk(49, vec![remap(*a)], 0),
374 // Conversions.
375 InstKind::IntToFloat(a, w) => mk(50, vec![remap(*a)], w.bits() as i128),
376 InstKind::FloatToInt(a, w) => mk(51, vec![remap(*a)], w.bits() as i128),
377 InstKind::FloatExtend(a, w) => mk(52, vec![remap(*a)], w.bits() as i128),
378 InstKind::FloatTrunc(a, w) => mk(53, vec![remap(*a)], w.bits() as i128),
379 InstKind::IntExtend(a, w, s) => mk(
380 54,
381 vec![remap(*a)],
382 (w.bits() as i128) * if *s { 1 } else { -1 },
383 ),
384 InstKind::IntTrunc(a, w) => mk(55, vec![remap(*a)], w.bits() as i128),
385 InstKind::PtrToInt(a) => mk(56, vec![remap(*a)], 0),
386 InstKind::IntToPtr(a, _) => mk(57, vec![remap(*a)], 0),
387 // Constants.
388 InstKind::ConstInt(v, w) => {
389 let bits = w.bits();
390 let signed = if bits >= 128 {
391 *v
392 } else {
393 let shift = 128 - bits;
394 (*v << shift) >> shift
395 };
396 mk(60, vec![], signed)
397 }
398 InstKind::ConstFloat(v, w) => mk(61, vec![], ((*v).to_bits() as i128) ^ (w.bits() as i128)),
399 InstKind::ConstBool(v) => mk(62, vec![], *v as i128),
400 // GlobalAddr.
401 InstKind::GlobalAddr(name) => mk_named(70, name.clone()),
402 // GEP.
403 InstKind::GetElementPtr(base, idxs) => {
404 let mut ops = vec![remap(*base)];
405 ops.extend(idxs.iter().copied().map(remap));
406 mk(80, ops, 0)
407 }
408 // Reusable PURE calls: by-value arguments only, real value return,
409 // and no pointer result identity to preserve. For lowered
410 // Fortran scalar-by-reference args, look through the wrapper
411 // alloca to the stored SSA value when the callee only reads the
412 // pointee as a scalar input.
413 InstKind::Call(FuncRef::Internal(idx), args) => {
414 let policy = pure_calls.get(*idx as usize)?;
415 if !policy.reusable || policy.arg_policies.len() != args.len() {
416 return None;
417 }
418 let mut ops = Vec::with_capacity(args.len());
419 for (arg, arg_policy) in args.iter().zip(policy.arg_policies.iter()) {
420 match arg_policy {
421 PureArgPolicy::ByValue => ops.push(remap(*arg)),
422 PureArgPolicy::ReadOnlyWrapperPtr => {
423 let wrapper = remap(*arg);
424 let stored = wrapper_values.get(&wrapper)?;
425 ops.push(remap(*stored));
426 }
427 PureArgPolicy::Unsupported => return None,
428 }
429 }
430 mk(90, ops, *idx as i128)
431 }
432 // Impure: loads, stores, runtime calls, external calls, alloca — not GVN candidates.
433 InstKind::Load(..)
434 | InstKind::Store(..)
435 | InstKind::Alloca(..)
436 | InstKind::Call(..)
437 | InstKind::RuntimeCall(..)
438 | InstKind::ConstString(..)
439 | InstKind::Undef(..)
440 | InstKind::ExtractField(..)
441 | InstKind::InsertField(..) => None,
442 }
443 }
444
445 fn gvn_function(func: &mut Function, pure_calls: &[PureCallPolicy]) -> bool {
446 let idoms = compute_immediate_dominators(func);
447 let children = dominator_tree_children(&idoms);
448 let wrapper_values = wrapper_alloca_values(func, pure_calls);
449
450 // Scoped value number table: Key → dominating ValueId.
451 let mut vn_table: HashMap<Key, ValueId> = HashMap::new();
452 // Replacement map: redundant ValueId → dominating ValueId.
453 let mut replacements: HashMap<ValueId, ValueId> = HashMap::new();
454
455 // Process blocks in dominator-tree preorder (DFS from entry).
456 let mut stack = vec![(func.entry, 0usize)]; // (block, depth for scope management)
457 let mut scope_stack: Vec<Vec<Key>> = Vec::new(); // keys to remove when leaving scope
458
459 while let Some((block_id, depth)) = stack.pop() {
460 // Pop scope entries for blocks we've left.
461 while scope_stack.len() > depth {
462 if let Some(keys) = scope_stack.pop() {
463 for key in keys {
464 vn_table.remove(&key);
465 }
466 }
467 }
468
469 // Process this block's instructions.
470 let mut new_keys = Vec::new();
471 let block = func.block(block_id);
472 for inst in &block.insts {
473 if let Some(key) = key_of(inst, &replacements, pure_calls, &wrapper_values) {
474 if let Some(&existing) = vn_table.get(&key) {
475 // This expression is already available from a dominating block.
476 replacements.insert(inst.id, existing);
477 } else {
478 // First occurrence — register in the table.
479 vn_table.insert(key.clone(), inst.id);
480 new_keys.push(key);
481 }
482 }
483 }
484
485 scope_stack.push(new_keys);
486
487 // Push children (reverse order so leftmost child is processed first).
488 if let Some(kids) = children.get(&block_id) {
489 for &child in kids.iter().rev() {
490 stack.push((child, depth + 1));
491 }
492 }
493 }
494
495 if replacements.is_empty() {
496 return false;
497 }
498
499 // Apply replacements: substitute all uses of redundant values, then
500 // drop the now-dead duplicate instructions directly. This matters
501 // for PURE calls because DCE conservatively treats generic calls as
502 // side-effecting.
503 for block in &mut func.blocks {
504 for inst in &mut block.insts {
505 inst.kind = remap_operands(&inst.kind, &replacements);
506 }
507 block
508 .insts
509 .retain(|inst| !replacements.contains_key(&inst.id));
510 if let Some(ref mut term) = block.terminator {
511 remap_terminator_operands(term, &replacements);
512 }
513 }
514
515 true
516 }
517
518 /// Remap operands in an instruction using the replacement map.
519 fn remap_operands(kind: &InstKind, map: &HashMap<ValueId, ValueId>) -> InstKind {
520 super::loop_utils::remap_inst_kind(kind, map)
521 }
522
523 /// Remap operands in a terminator.
524 fn remap_terminator_operands(term: &mut Terminator, map: &HashMap<ValueId, ValueId>) {
525 let r = |v: &ValueId| *map.get(v).unwrap_or(v);
526 match term {
527 Terminator::Return(Some(v)) => *v = r(v),
528 Terminator::Branch(_, args) => {
529 for a in args.iter_mut() {
530 *a = r(a);
531 }
532 }
533 Terminator::CondBranch {
534 cond,
535 true_args,
536 false_args,
537 ..
538 } => {
539 *cond = r(cond);
540 for a in true_args.iter_mut() {
541 *a = r(a);
542 }
543 for a in false_args.iter_mut() {
544 *a = r(a);
545 }
546 }
547 Terminator::Switch { selector, .. } => {
548 *selector = r(selector);
549 }
550 _ => {}
551 }
552 }
553
554 #[cfg(test)]
555 mod tests {
556 use super::*;
557 use crate::ir::lower;
558 use crate::ir::types::{IntWidth, IrType};
559 use crate::lexer::{tokenize, SourceForm};
560 use crate::lexer::{Position, Span};
561 use crate::opt::pass::Pass;
562 use crate::opt::pass::PassManager;
563 use crate::opt::pipeline::OptLevel;
564 use crate::opt::{
565 bce::Bce, call_resolve::CallResolve, const_fold::ConstFold, const_prop::ConstProp,
566 dead_func::DeadFuncElim, dse::Dse, fission::LoopFission, fusion::LoopFusion,
567 global_lsf::GlobalLsf, inline::Inline, interchange::LoopInterchange, licm::Licm,
568 lsf::LocalLsf, mem2reg::Mem2Reg, peel::LoopPeel, preheader::PreheaderInsert,
569 simplify_cfg::SimplifyCfg, sroa::Sroa, strength_reduce::StrengthReduce, unroll::LoopUnroll,
570 unswitch::LoopUnswitch,
571 };
572 use crate::parser::Parser;
573 use crate::preprocess::PreprocConfig;
574 use crate::sema::{resolve, validate};
575 use std::fs;
576 use std::path::PathBuf;
577
578 fn span() -> Span {
579 let pos = Position { line: 0, col: 0 };
580 Span {
581 file_id: 0,
582 start: pos,
583 end: pos,
584 }
585 }
586
587 fn push_inst(func: &mut Function, block: BlockId, kind: InstKind, ty: IrType) -> ValueId {
588 let id = func.next_value_id();
589 func.register_type(id, ty.clone());
590 func.block_mut(block).insts.push(Inst {
591 id,
592 kind,
593 ty,
594 span: span(),
595 });
596 id
597 }
598
599 fn lower_fixture(name: &str) -> Module {
600 let path = PathBuf::from("test_programs").join(name);
601 let source = fs::read_to_string(&path).expect("fixture source");
602 let pp = crate::preprocess::preprocess(
603 &source,
604 &PreprocConfig {
605 filename: path.to_string_lossy().into_owned(),
606 fixed_form: false,
607 ..PreprocConfig::default()
608 },
609 )
610 .expect("preprocess fixture");
611 let tokens = tokenize(&pp.text, 0, SourceForm::FreeForm).expect("tokenize fixture");
612 let mut parser = Parser::new(&tokens);
613 let units = parser.parse_file().expect("parse fixture");
614 let (st, type_layouts) = {
615 let rr = resolve::resolve_file(&units, &[]).expect("resolve fixture");
616 (rr.st, rr.type_layouts)
617 };
618 let diags = validate::validate_file(&units, &st);
619 assert!(
620 !diags
621 .iter()
622 .any(|diag| diag.kind == validate::DiagKind::Error),
623 "fixture should lower cleanly: {:?}",
624 diags
625 );
626 lower::lower_file(
627 &units,
628 &st,
629 &type_layouts,
630 std::collections::HashMap::new(),
631 std::collections::HashMap::new(),
632 std::collections::HashMap::new(),
633 )
634 .0
635 }
636
637 fn build_pre_gvn_o2_pipeline() -> PassManager {
638 let mut pm = PassManager::new();
639 pm.add(Box::new(CallResolve));
640 pm.add(Box::new(Mem2Reg));
641 pm.add(Box::new(ConstFold));
642 pm.add(Box::new(Sroa));
643 pm.add(Box::new(Mem2Reg));
644 pm.add(Box::new(Inline::for_level(OptLevel::O2)));
645 pm.add(Box::new(SimplifyCfg));
646 pm.add(Box::new(DeadFuncElim));
647 pm.add(Box::new(Bce));
648 pm.add(Box::new(StrengthReduce));
649 pm.add(Box::new(LocalLsf));
650 pm.add(Box::new(GlobalLsf));
651 pm.add(Box::new(crate::opt::cse::LocalCse));
652 pm.add(Box::new(PreheaderInsert));
653 pm.add(Box::new(LoopPeel));
654 pm.add(Box::new(LoopUnswitch));
655 pm.add(Box::new(Licm));
656 pm.add(Box::new(ConstProp));
657 pm.add(Box::new(Dse));
658 pm.add(Box::new(LoopInterchange));
659 pm.add(Box::new(LoopFission));
660 pm.add(Box::new(LoopFusion));
661 pm.add(Box::new(LoopUnroll));
662 pm
663 }
664
665 #[test]
666 fn gvn_no_op_on_empty() {
667 let mut m = Module::new("test".into());
668 let mut f = Function::new("test".into(), vec![], IrType::Void);
669 f.block_mut(f.entry).terminator = Some(Terminator::Return(None));
670 m.add_function(f);
671 let pass = Gvn;
672 assert!(!pass.run(&mut m));
673 }
674
675 #[test]
676 fn gvn_handles_max_i128_constant_keys() {
677 let mut m = Module::new("test".into());
678 let mut f = Function::new("test".into(), vec![], IrType::Int(IntWidth::I128));
679 let entry = f.entry;
680 let wide = push_inst(
681 &mut f,
682 entry,
683 InstKind::ConstInt(i128::MAX, IntWidth::I128),
684 IrType::Int(IntWidth::I128),
685 );
686 f.block_mut(entry).terminator = Some(Terminator::Return(Some(wide)));
687 m.add_function(f);
688 let pass = Gvn;
689 assert!(
690 !pass.run(&mut m),
691 "single max-i128 constant should not produce a spurious rewrite"
692 );
693 }
694
695 #[test]
696 fn gvn_reuses_pure_by_value_calls() {
697 let mut m = Module::new("test".into());
698
699 let param = Param {
700 name: "x".into(),
701 ty: IrType::Int(IntWidth::I32),
702 id: ValueId(0),
703 fortran_noalias: false,
704 };
705 let mut callee = Function::new("square".into(), vec![param], IrType::Int(IntWidth::I32));
706 callee.is_pure = true;
707 callee.block_mut(callee.entry).terminator = Some(Terminator::Return(Some(ValueId(0))));
708 m.add_function(callee);
709
710 let mut caller = Function::new("main".into(), vec![], IrType::Int(IntWidth::I32));
711 let entry = caller.entry;
712 let c7 = push_inst(
713 &mut caller,
714 entry,
715 InstKind::ConstInt(7, IntWidth::I32),
716 IrType::Int(IntWidth::I32),
717 );
718 let call1 = push_inst(
719 &mut caller,
720 entry,
721 InstKind::Call(FuncRef::Internal(0), vec![c7]),
722 IrType::Int(IntWidth::I32),
723 );
724 let call2 = push_inst(
725 &mut caller,
726 entry,
727 InstKind::Call(FuncRef::Internal(0), vec![c7]),
728 IrType::Int(IntWidth::I32),
729 );
730 let sum = push_inst(
731 &mut caller,
732 entry,
733 InstKind::IAdd(call1, call2),
734 IrType::Int(IntWidth::I32),
735 );
736 caller.block_mut(entry).terminator = Some(Terminator::Return(Some(sum)));
737 m.add_function(caller);
738
739 let pass = Gvn;
740 assert!(pass.run(&mut m));
741
742 let caller = &m.functions[1];
743 let calls: Vec<&Inst> = caller.blocks[0]
744 .insts
745 .iter()
746 .filter(|inst| matches!(inst.kind, InstKind::Call(..)))
747 .collect();
748 assert_eq!(calls.len(), 1, "redundant PURE call should be removed");
749 let kept_call = calls[0].id;
750
751 let add = caller.blocks[0]
752 .insts
753 .iter()
754 .find(|inst| matches!(inst.kind, InstKind::IAdd(..)))
755 .expect("caller should still contain the sum");
756 match add.kind {
757 InstKind::IAdd(lhs, rhs) => {
758 assert_eq!(lhs, kept_call);
759 assert_eq!(rhs, kept_call);
760 }
761 _ => unreachable!(),
762 }
763 }
764
765 #[test]
766 fn gvn_reuses_pure_calls_after_operand_canonicalization() {
767 let mut m = Module::new("test".into());
768
769 let param = Param {
770 name: "x".into(),
771 ty: IrType::Int(IntWidth::I32),
772 id: ValueId(0),
773 fortran_noalias: false,
774 };
775 let mut callee = Function::new("square".into(), vec![param], IrType::Int(IntWidth::I32));
776 callee.is_pure = true;
777 callee.block_mut(callee.entry).terminator = Some(Terminator::Return(Some(ValueId(0))));
778 m.add_function(callee);
779
780 let mut caller = Function::new("main".into(), vec![], IrType::Int(IntWidth::I32));
781 let entry = caller.entry;
782 let a = push_inst(
783 &mut caller,
784 entry,
785 InstKind::ConstInt(3, IntWidth::I32),
786 IrType::Int(IntWidth::I32),
787 );
788 let b = push_inst(
789 &mut caller,
790 entry,
791 InstKind::ConstInt(4, IntWidth::I32),
792 IrType::Int(IntWidth::I32),
793 );
794 let sum1 = push_inst(
795 &mut caller,
796 entry,
797 InstKind::IAdd(a, b),
798 IrType::Int(IntWidth::I32),
799 );
800 let sum2 = push_inst(
801 &mut caller,
802 entry,
803 InstKind::IAdd(a, b),
804 IrType::Int(IntWidth::I32),
805 );
806 let call1 = push_inst(
807 &mut caller,
808 entry,
809 InstKind::Call(FuncRef::Internal(0), vec![sum1]),
810 IrType::Int(IntWidth::I32),
811 );
812 let call2 = push_inst(
813 &mut caller,
814 entry,
815 InstKind::Call(FuncRef::Internal(0), vec![sum2]),
816 IrType::Int(IntWidth::I32),
817 );
818 caller.block_mut(entry).terminator = Some(Terminator::Return(Some(call2)));
819 m.add_function(caller);
820
821 let pass = Gvn;
822 assert!(pass.run(&mut m));
823
824 let caller = &m.functions[1];
825 let call_count = caller.blocks[0]
826 .insts
827 .iter()
828 .filter(|inst| matches!(inst.kind, InstKind::Call(..)))
829 .count();
830 assert_eq!(
831 call_count, 1,
832 "PURE call should reuse equivalent dominating args"
833 );
834
835 match caller.blocks[0]
836 .terminator
837 .as_ref()
838 .expect("return terminator")
839 {
840 Terminator::Return(Some(v)) => assert_eq!(*v, call1),
841 other => panic!("unexpected terminator: {:?}", other),
842 }
843 }
844
845 #[test]
846 fn pure_pointer_param_policy_accepts_read_only_scalar_pattern() {
847 let param = Param {
848 name: "n".into(),
849 ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
850 id: ValueId(0),
851 fortran_noalias: false,
852 };
853 let mut callee =
854 Function::new("heavy_fact".into(), vec![param], IrType::Int(IntWidth::I32));
855 callee.is_pure = true;
856 let entry = callee.entry;
857
858 let slot = push_inst(
859 &mut callee,
860 entry,
861 InstKind::Alloca(IrType::Ptr(Box::new(IrType::Int(IntWidth::I32)))),
862 IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))))),
863 );
864 push_inst(
865 &mut callee,
866 entry,
867 InstKind::Store(ValueId(0), slot),
868 IrType::Void,
869 );
870 let arg_ptr = push_inst(
871 &mut callee,
872 entry,
873 InstKind::Load(slot),
874 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
875 );
876 let arg_val = push_inst(
877 &mut callee,
878 entry,
879 InstKind::Load(arg_ptr),
880 IrType::Int(IntWidth::I32),
881 );
882 callee.block_mut(entry).terminator = Some(Terminator::Return(Some(arg_val)));
883
884 let policy = PureCallPolicy::for_function(&callee);
885 assert!(policy.reusable);
886 assert_eq!(policy.arg_policies, vec![PureArgPolicy::ReadOnlyWrapperPtr]);
887 }
888
889 #[test]
890 fn wrapper_alloca_values_accept_scalar_by_ref_call_pattern() {
891 let mut m = Module::new("test".into());
892
893 let param = Param {
894 name: "n".into(),
895 ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
896 id: ValueId(0),
897 fortran_noalias: false,
898 };
899 let mut callee =
900 Function::new("heavy_fact".into(), vec![param], IrType::Int(IntWidth::I32));
901 callee.is_pure = true;
902 let callee_entry = callee.entry;
903 let slot = push_inst(
904 &mut callee,
905 callee_entry,
906 InstKind::Alloca(IrType::Ptr(Box::new(IrType::Int(IntWidth::I32)))),
907 IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))))),
908 );
909 push_inst(
910 &mut callee,
911 callee_entry,
912 InstKind::Store(ValueId(0), slot),
913 IrType::Void,
914 );
915 let arg_ptr = push_inst(
916 &mut callee,
917 callee_entry,
918 InstKind::Load(slot),
919 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
920 );
921 let arg_val = push_inst(
922 &mut callee,
923 callee_entry,
924 InstKind::Load(arg_ptr),
925 IrType::Int(IntWidth::I32),
926 );
927 callee.block_mut(callee_entry).terminator = Some(Terminator::Return(Some(arg_val)));
928 m.add_function(callee);
929
930 let pure_calls = vec![PureCallPolicy::for_function(&m.functions[0])];
931
932 let mut caller = Function::new("main".into(), vec![], IrType::Int(IntWidth::I32));
933 let entry = caller.entry;
934 let c = push_inst(
935 &mut caller,
936 entry,
937 InstKind::ConstInt(6, IntWidth::I32),
938 IrType::Int(IntWidth::I32),
939 );
940 let wrapper = push_inst(
941 &mut caller,
942 entry,
943 InstKind::Alloca(IrType::Int(IntWidth::I32)),
944 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
945 );
946 push_inst(
947 &mut caller,
948 entry,
949 InstKind::Store(c, wrapper),
950 IrType::Void,
951 );
952 let call = push_inst(
953 &mut caller,
954 entry,
955 InstKind::Call(FuncRef::Internal(0), vec![wrapper]),
956 IrType::Int(IntWidth::I32),
957 );
958 caller.block_mut(entry).terminator = Some(Terminator::Return(Some(call)));
959
960 let wrappers = wrapper_alloca_values(&caller, &pure_calls);
961 assert_eq!(wrappers.get(&wrapper), Some(&c));
962 }
963
964 #[test]
965 fn gvn_reuses_pure_calls_through_scalar_wrapper_allocas() {
966 let mut m = Module::new("test".into());
967
968 let param = Param {
969 name: "n".into(),
970 ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
971 id: ValueId(0),
972 fortran_noalias: false,
973 };
974 let mut callee =
975 Function::new("heavy_fact".into(), vec![param], IrType::Int(IntWidth::I32));
976 callee.is_pure = true;
977 let callee_entry = callee.entry;
978 let shadow = push_inst(
979 &mut callee,
980 callee_entry,
981 InstKind::Alloca(IrType::Ptr(Box::new(IrType::Int(IntWidth::I32)))),
982 IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))))),
983 );
984 push_inst(
985 &mut callee,
986 callee_entry,
987 InstKind::Store(ValueId(0), shadow),
988 IrType::Void,
989 );
990 let arg_ptr = push_inst(
991 &mut callee,
992 callee_entry,
993 InstKind::Load(shadow),
994 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
995 );
996 let arg_val = push_inst(
997 &mut callee,
998 callee_entry,
999 InstKind::Load(arg_ptr),
1000 IrType::Int(IntWidth::I32),
1001 );
1002 callee.block_mut(callee_entry).terminator = Some(Terminator::Return(Some(arg_val)));
1003 m.add_function(callee);
1004
1005 let mut caller = Function::new("main".into(), vec![], IrType::Int(IntWidth::I32));
1006 let entry = caller.entry;
1007 let first = push_inst(
1008 &mut caller,
1009 entry,
1010 InstKind::Alloca(IrType::Int(IntWidth::I32)),
1011 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1012 );
1013 let second = push_inst(
1014 &mut caller,
1015 entry,
1016 InstKind::Alloca(IrType::Int(IntWidth::I32)),
1017 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1018 );
1019 let c1 = push_inst(
1020 &mut caller,
1021 entry,
1022 InstKind::ConstInt(6, IntWidth::I32),
1023 IrType::Int(IntWidth::I32),
1024 );
1025 let c2 = push_inst(
1026 &mut caller,
1027 entry,
1028 InstKind::ConstInt(6, IntWidth::I32),
1029 IrType::Int(IntWidth::I32),
1030 );
1031 let wrap1 = push_inst(
1032 &mut caller,
1033 entry,
1034 InstKind::Alloca(IrType::Int(IntWidth::I32)),
1035 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1036 );
1037 push_inst(&mut caller, entry, InstKind::Store(c2, wrap1), IrType::Void);
1038 let call1 = push_inst(
1039 &mut caller,
1040 entry,
1041 InstKind::Call(FuncRef::Internal(0), vec![wrap1]),
1042 IrType::Int(IntWidth::I32),
1043 );
1044 push_inst(
1045 &mut caller,
1046 entry,
1047 InstKind::Store(call1, first),
1048 IrType::Void,
1049 );
1050 let c3 = push_inst(
1051 &mut caller,
1052 entry,
1053 InstKind::ConstInt(6, IntWidth::I32),
1054 IrType::Int(IntWidth::I32),
1055 );
1056 let c4 = push_inst(
1057 &mut caller,
1058 entry,
1059 InstKind::ConstInt(6, IntWidth::I32),
1060 IrType::Int(IntWidth::I32),
1061 );
1062 let wrap2 = push_inst(
1063 &mut caller,
1064 entry,
1065 InstKind::Alloca(IrType::Int(IntWidth::I32)),
1066 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1067 );
1068 push_inst(&mut caller, entry, InstKind::Store(c4, wrap2), IrType::Void);
1069 let call2 = push_inst(
1070 &mut caller,
1071 entry,
1072 InstKind::Call(FuncRef::Internal(0), vec![wrap2]),
1073 IrType::Int(IntWidth::I32),
1074 );
1075 push_inst(
1076 &mut caller,
1077 entry,
1078 InstKind::Store(call2, second),
1079 IrType::Void,
1080 );
1081 let left = push_inst(
1082 &mut caller,
1083 entry,
1084 InstKind::Load(first),
1085 IrType::Int(IntWidth::I32),
1086 );
1087 let right = push_inst(
1088 &mut caller,
1089 entry,
1090 InstKind::Load(second),
1091 IrType::Int(IntWidth::I32),
1092 );
1093 let sum = push_inst(
1094 &mut caller,
1095 entry,
1096 InstKind::IAdd(left, right),
1097 IrType::Int(IntWidth::I32),
1098 );
1099 caller.block_mut(entry).terminator = Some(Terminator::Return(Some(sum)));
1100 m.add_function(caller);
1101
1102 let pure_calls: Vec<PureCallPolicy> = m
1103 .functions
1104 .iter()
1105 .map(PureCallPolicy::for_function)
1106 .collect();
1107 assert!(pure_calls[0].reusable);
1108 assert_eq!(
1109 pure_calls[0].arg_policies,
1110 vec![PureArgPolicy::ReadOnlyWrapperPtr]
1111 );
1112 let wrappers = wrapper_alloca_values(&m.functions[1], &pure_calls);
1113 assert_eq!(wrappers.get(&wrap1), Some(&c2));
1114 assert_eq!(wrappers.get(&wrap2), Some(&c4));
1115
1116 let mut replacements = HashMap::new();
1117 replacements.insert(c2, c1);
1118 replacements.insert(c3, c1);
1119 replacements.insert(c4, c1);
1120 let caller_before = &m.functions[1];
1121 let call1_inst = caller_before.blocks[0]
1122 .insts
1123 .iter()
1124 .find(|inst| inst.id == call1)
1125 .expect("call1 inst");
1126 let call2_inst = caller_before.blocks[0]
1127 .insts
1128 .iter()
1129 .find(|inst| inst.id == call2)
1130 .expect("call2 inst");
1131 let key1 =
1132 key_of(call1_inst, &replacements, &pure_calls, &wrappers).expect("call1 should key");
1133 let key2 =
1134 key_of(call2_inst, &replacements, &pure_calls, &wrappers).expect("call2 should key");
1135 assert_eq!(
1136 key1, key2,
1137 "wrapper-based call keys should line up before the pass runs"
1138 );
1139
1140 let pass = Gvn;
1141 assert!(pass.run(&mut m));
1142
1143 let caller = &m.functions[1];
1144 let call_count = caller.blocks[0]
1145 .insts
1146 .iter()
1147 .filter(|inst| matches!(inst.kind, InstKind::Call(..)))
1148 .count();
1149 assert_eq!(call_count, 1, "wrapper-alloca PURE calls should dedupe");
1150 }
1151
1152 #[test]
1153 fn gvn_reuses_pure_calls_for_recursive_factorial_shape() {
1154 let mut m = Module::new("test".into());
1155
1156 let param = Param {
1157 name: "n".into(),
1158 ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1159 id: ValueId(0),
1160 fortran_noalias: false,
1161 };
1162 let mut callee =
1163 Function::new("heavy_fact".into(), vec![param], IrType::Int(IntWidth::I32));
1164 callee.is_pure = true;
1165 let entry = callee.entry;
1166 let if_end = callee.create_block("if_end");
1167 let if_then = callee.create_block("if_then");
1168 let if_else = callee.create_block("if_else");
1169
1170 let shadow = push_inst(
1171 &mut callee,
1172 entry,
1173 InstKind::Alloca(IrType::Ptr(Box::new(IrType::Int(IntWidth::I32)))),
1174 IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))))),
1175 );
1176 push_inst(
1177 &mut callee,
1178 entry,
1179 InstKind::Store(ValueId(0), shadow),
1180 IrType::Void,
1181 );
1182 let result = push_inst(
1183 &mut callee,
1184 entry,
1185 InstKind::Alloca(IrType::Int(IntWidth::I32)),
1186 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1187 );
1188 let p1 = push_inst(
1189 &mut callee,
1190 entry,
1191 InstKind::Load(shadow),
1192 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1193 );
1194 let n1 = push_inst(
1195 &mut callee,
1196 entry,
1197 InstKind::Load(p1),
1198 IrType::Int(IntWidth::I32),
1199 );
1200 let one = push_inst(
1201 &mut callee,
1202 entry,
1203 InstKind::ConstInt(1, IntWidth::I32),
1204 IrType::Int(IntWidth::I32),
1205 );
1206 let cond = push_inst(
1207 &mut callee,
1208 entry,
1209 InstKind::ICmp(CmpOp::Le, n1, one),
1210 IrType::Bool,
1211 );
1212 callee.block_mut(entry).terminator = Some(Terminator::CondBranch {
1213 cond,
1214 true_dest: if_then,
1215 true_args: vec![],
1216 false_dest: if_else,
1217 false_args: vec![],
1218 });
1219
1220 let then_one = push_inst(
1221 &mut callee,
1222 if_then,
1223 InstKind::ConstInt(1, IntWidth::I32),
1224 IrType::Int(IntWidth::I32),
1225 );
1226 push_inst(
1227 &mut callee,
1228 if_then,
1229 InstKind::Store(then_one, result),
1230 IrType::Void,
1231 );
1232 callee.block_mut(if_then).terminator = Some(Terminator::Branch(if_end, vec![]));
1233
1234 let p2 = push_inst(
1235 &mut callee,
1236 if_else,
1237 InstKind::Load(shadow),
1238 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1239 );
1240 let n2 = push_inst(
1241 &mut callee,
1242 if_else,
1243 InstKind::Load(p2),
1244 IrType::Int(IntWidth::I32),
1245 );
1246 let p3 = push_inst(
1247 &mut callee,
1248 if_else,
1249 InstKind::Load(shadow),
1250 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1251 );
1252 let n3 = push_inst(
1253 &mut callee,
1254 if_else,
1255 InstKind::Load(p3),
1256 IrType::Int(IntWidth::I32),
1257 );
1258 let one2 = push_inst(
1259 &mut callee,
1260 if_else,
1261 InstKind::ConstInt(1, IntWidth::I32),
1262 IrType::Int(IntWidth::I32),
1263 );
1264 let dec1 = push_inst(
1265 &mut callee,
1266 if_else,
1267 InstKind::ISub(n3, one2),
1268 IrType::Int(IntWidth::I32),
1269 );
1270 let p4 = push_inst(
1271 &mut callee,
1272 if_else,
1273 InstKind::Load(shadow),
1274 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1275 );
1276 let n4 = push_inst(
1277 &mut callee,
1278 if_else,
1279 InstKind::Load(p4),
1280 IrType::Int(IntWidth::I32),
1281 );
1282 let one3 = push_inst(
1283 &mut callee,
1284 if_else,
1285 InstKind::ConstInt(1, IntWidth::I32),
1286 IrType::Int(IntWidth::I32),
1287 );
1288 let dec2 = push_inst(
1289 &mut callee,
1290 if_else,
1291 InstKind::ISub(n4, one3),
1292 IrType::Int(IntWidth::I32),
1293 );
1294 let wrap = push_inst(
1295 &mut callee,
1296 if_else,
1297 InstKind::Alloca(IrType::Int(IntWidth::I32)),
1298 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1299 );
1300 push_inst(
1301 &mut callee,
1302 if_else,
1303 InstKind::Store(dec2, wrap),
1304 IrType::Void,
1305 );
1306 let rec = push_inst(
1307 &mut callee,
1308 if_else,
1309 InstKind::Call(FuncRef::Internal(0), vec![wrap]),
1310 IrType::Int(IntWidth::I32),
1311 );
1312 let prod = push_inst(
1313 &mut callee,
1314 if_else,
1315 InstKind::IMul(n2, rec),
1316 IrType::Int(IntWidth::I32),
1317 );
1318 push_inst(
1319 &mut callee,
1320 if_else,
1321 InstKind::Store(prod, result),
1322 IrType::Void,
1323 );
1324 callee.block_mut(if_else).terminator = Some(Terminator::Branch(if_end, vec![]));
1325
1326 let final_val = push_inst(
1327 &mut callee,
1328 if_end,
1329 InstKind::Load(result),
1330 IrType::Int(IntWidth::I32),
1331 );
1332 callee.block_mut(if_end).terminator = Some(Terminator::Return(Some(final_val)));
1333 m.add_function(callee);
1334
1335 let mut caller = Function::new("main".into(), vec![], IrType::Void);
1336 let caller_entry = caller.entry;
1337 let unit = push_inst(
1338 &mut caller,
1339 caller_entry,
1340 InstKind::ConstInt(6, IntWidth::I32),
1341 IrType::Int(IntWidth::I32),
1342 );
1343 let unit2 = push_inst(
1344 &mut caller,
1345 caller_entry,
1346 InstKind::ConstInt(6, IntWidth::I32),
1347 IrType::Int(IntWidth::I32),
1348 );
1349 let wrap1 = push_inst(
1350 &mut caller,
1351 caller_entry,
1352 InstKind::Alloca(IrType::Int(IntWidth::I32)),
1353 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1354 );
1355 push_inst(
1356 &mut caller,
1357 caller_entry,
1358 InstKind::Store(unit2, wrap1),
1359 IrType::Void,
1360 );
1361 let call1 = push_inst(
1362 &mut caller,
1363 caller_entry,
1364 InstKind::Call(FuncRef::Internal(0), vec![wrap1]),
1365 IrType::Int(IntWidth::I32),
1366 );
1367 let wrap2 = push_inst(
1368 &mut caller,
1369 caller_entry,
1370 InstKind::Alloca(IrType::Int(IntWidth::I32)),
1371 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1372 );
1373 push_inst(
1374 &mut caller,
1375 caller_entry,
1376 InstKind::Store(unit, wrap2),
1377 IrType::Void,
1378 );
1379 let call2 = push_inst(
1380 &mut caller,
1381 caller_entry,
1382 InstKind::Call(FuncRef::Internal(0), vec![wrap2]),
1383 IrType::Int(IntWidth::I32),
1384 );
1385 let sum = push_inst(
1386 &mut caller,
1387 caller_entry,
1388 InstKind::IAdd(call1, call2),
1389 IrType::Int(IntWidth::I32),
1390 );
1391 caller.block_mut(caller_entry).terminator = Some(Terminator::Return(Some(sum)));
1392 m.add_function(caller);
1393
1394 let pass = Gvn;
1395 assert!(pass.run(&mut m));
1396 let caller = &m.functions[1];
1397 let call_count = caller.blocks[0]
1398 .insts
1399 .iter()
1400 .filter(|inst| matches!(inst.kind, InstKind::Call(..)))
1401 .count();
1402 assert_eq!(
1403 call_count, 1,
1404 "recursive PURE factorial calls should dedupe"
1405 );
1406 }
1407
1408 #[test]
1409 fn gvn_reuses_pure_calls_when_callee_follows_caller() {
1410 let mut m = Module::new("test".into());
1411
1412 let mut caller = Function::new("main".into(), vec![], IrType::Int(IntWidth::I32));
1413 let entry = caller.entry;
1414 let c1 = push_inst(
1415 &mut caller,
1416 entry,
1417 InstKind::ConstInt(6, IntWidth::I32),
1418 IrType::Int(IntWidth::I32),
1419 );
1420 let wrap1 = push_inst(
1421 &mut caller,
1422 entry,
1423 InstKind::Alloca(IrType::Int(IntWidth::I32)),
1424 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1425 );
1426 push_inst(&mut caller, entry, InstKind::Store(c1, wrap1), IrType::Void);
1427 let call1 = push_inst(
1428 &mut caller,
1429 entry,
1430 InstKind::Call(FuncRef::Internal(1), vec![wrap1]),
1431 IrType::Int(IntWidth::I32),
1432 );
1433 let c2 = push_inst(
1434 &mut caller,
1435 entry,
1436 InstKind::ConstInt(6, IntWidth::I32),
1437 IrType::Int(IntWidth::I32),
1438 );
1439 let wrap2 = push_inst(
1440 &mut caller,
1441 entry,
1442 InstKind::Alloca(IrType::Int(IntWidth::I32)),
1443 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1444 );
1445 push_inst(&mut caller, entry, InstKind::Store(c2, wrap2), IrType::Void);
1446 let call2 = push_inst(
1447 &mut caller,
1448 entry,
1449 InstKind::Call(FuncRef::Internal(1), vec![wrap2]),
1450 IrType::Int(IntWidth::I32),
1451 );
1452 let sum = push_inst(
1453 &mut caller,
1454 entry,
1455 InstKind::IAdd(call1, call2),
1456 IrType::Int(IntWidth::I32),
1457 );
1458 caller.block_mut(entry).terminator = Some(Terminator::Return(Some(sum)));
1459 m.add_function(caller);
1460
1461 let param = Param {
1462 name: "n".into(),
1463 ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1464 id: ValueId(0),
1465 fortran_noalias: false,
1466 };
1467 let mut callee =
1468 Function::new("heavy_fact".into(), vec![param], IrType::Int(IntWidth::I32));
1469 callee.is_pure = true;
1470 let callee_entry = callee.entry;
1471 let shadow = push_inst(
1472 &mut callee,
1473 callee_entry,
1474 InstKind::Alloca(IrType::Ptr(Box::new(IrType::Int(IntWidth::I32)))),
1475 IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))))),
1476 );
1477 push_inst(
1478 &mut callee,
1479 callee_entry,
1480 InstKind::Store(ValueId(0), shadow),
1481 IrType::Void,
1482 );
1483 let arg_ptr = push_inst(
1484 &mut callee,
1485 callee_entry,
1486 InstKind::Load(shadow),
1487 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1488 );
1489 let arg_val = push_inst(
1490 &mut callee,
1491 callee_entry,
1492 InstKind::Load(arg_ptr),
1493 IrType::Int(IntWidth::I32),
1494 );
1495 callee.block_mut(callee_entry).terminator = Some(Terminator::Return(Some(arg_val)));
1496 m.add_function(callee);
1497
1498 let pass = Gvn;
1499 assert!(pass.run(&mut m));
1500 let caller = &m.functions[0];
1501 let call_count = caller.blocks[0]
1502 .insts
1503 .iter()
1504 .filter(|inst| matches!(inst.kind, InstKind::Call(..)))
1505 .count();
1506 assert_eq!(
1507 call_count, 1,
1508 "caller-before-callee ordering should still dedupe"
1509 );
1510 }
1511
1512 #[test]
1513 fn gvn_matches_real_pure_recursive_fixture_after_pre_o2_passes() {
1514 let mut module = lower_fixture("pure_recursive_reuse.f90");
1515 let pm = build_pre_gvn_o2_pipeline();
1516 pm.run(&mut module);
1517
1518 let pure_calls: Vec<PureCallPolicy> = module
1519 .functions
1520 .iter()
1521 .map(PureCallPolicy::for_function)
1522 .collect();
1523 assert!(
1524 pure_calls.iter().any(|policy| policy.reusable),
1525 "at least one function in the fixture should remain a reusable PURE callee:\n{}",
1526 crate::ir::printer::print_module(&module)
1527 );
1528
1529 let caller_idx = module
1530 .functions
1531 .iter()
1532 .position(|func| func.name == "__prog_pure_recursive_reuse")
1533 .expect("program entry in fixture");
1534 let wrappers = wrapper_alloca_values(&module.functions[caller_idx], &pure_calls);
1535 assert!(
1536 !wrappers.is_empty(),
1537 "caller should still expose scalar by-ref wrapper allocas before GVN"
1538 );
1539
1540 let pass = Gvn;
1541 assert!(
1542 pass.run(&mut module),
1543 "GVN should make progress on the real fixture"
1544 );
1545
1546 let caller = &module.functions[caller_idx];
1547 let call_count = caller.blocks[0]
1548 .insts
1549 .iter()
1550 .filter(|inst| matches!(inst.kind, InstKind::Call(FuncRef::Internal(_), _)))
1551 .count();
1552 assert_eq!(
1553 call_count, 1,
1554 "real fixture caller should end with one PURE recursive call"
1555 );
1556 }
1557
1558 #[test]
1559 fn gvn_does_not_reuse_pure_calls_with_pointer_args() {
1560 let mut m = Module::new("test".into());
1561
1562 let param = Param {
1563 name: "p".into(),
1564 ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1565 id: ValueId(0),
1566 fortran_noalias: false,
1567 };
1568 let mut callee = Function::new("peek".into(), vec![param], IrType::Int(IntWidth::I32));
1569 callee.is_pure = true;
1570 let callee_entry = callee.entry;
1571 let zero = push_inst(
1572 &mut callee,
1573 callee_entry,
1574 InstKind::ConstInt(0, IntWidth::I32),
1575 IrType::Int(IntWidth::I32),
1576 );
1577 callee.block_mut(callee_entry).terminator = Some(Terminator::Return(Some(zero)));
1578 m.add_function(callee);
1579
1580 let mut caller = Function::new("main".into(), vec![], IrType::Int(IntWidth::I32));
1581 let entry = caller.entry;
1582 let slot = push_inst(
1583 &mut caller,
1584 entry,
1585 InstKind::Alloca(IrType::Int(IntWidth::I32)),
1586 IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
1587 );
1588 let call1 = push_inst(
1589 &mut caller,
1590 entry,
1591 InstKind::Call(FuncRef::Internal(0), vec![slot]),
1592 IrType::Int(IntWidth::I32),
1593 );
1594 let call2 = push_inst(
1595 &mut caller,
1596 entry,
1597 InstKind::Call(FuncRef::Internal(0), vec![slot]),
1598 IrType::Int(IntWidth::I32),
1599 );
1600 let sum = push_inst(
1601 &mut caller,
1602 entry,
1603 InstKind::IAdd(call1, call2),
1604 IrType::Int(IntWidth::I32),
1605 );
1606 caller.block_mut(entry).terminator = Some(Terminator::Return(Some(sum)));
1607 m.add_function(caller);
1608
1609 let pass = Gvn;
1610 assert!(
1611 !pass.run(&mut m),
1612 "pointer-arg PURE calls must stay distinct"
1613 );
1614
1615 let caller = &m.functions[1];
1616 let call_count = caller.blocks[0]
1617 .insts
1618 .iter()
1619 .filter(|inst| matches!(inst.kind, InstKind::Call(..)))
1620 .count();
1621 assert_eq!(call_count, 2);
1622 }
1623 }
1624