Rust · 10658 bytes Raw Blame History
1 //! Dead-argument elimination for contained procedures.
2 //!
3 //! Removes unused parameters from internal-only functions and trims the
4 //! matching argument slots from all `Call(Internal, ...)` sites.
5
6 use crate::ir::inst::*;
7
8 use super::pass::Pass;
9
10 pub struct DeadArgElim;
11
12 impl Pass for DeadArgElim {
13 fn name(&self) -> &'static str {
14 "dead-arg-elim"
15 }
16
17 fn run(&self, module: &mut Module) -> bool {
18 let mut live_masks: Vec<Vec<bool>> = module
19 .functions
20 .iter()
21 .map(|func| vec![false; func.params.len()])
22 .collect();
23
24 let mut changed = true;
25 while changed {
26 changed = false;
27 for (func_idx, func) in module.functions.iter().enumerate() {
28 if !func.internal_only || func.params.is_empty() {
29 continue;
30 }
31 for (param_idx, param) in func.params.iter().enumerate() {
32 let live = param_has_observable_use(module, func_idx, param.id, &live_masks);
33 if live_masks[func_idx][param_idx] != live {
34 live_masks[func_idx][param_idx] = live;
35 changed = true;
36 }
37 }
38 }
39 }
40
41 let mut removed_any = false;
42 for (idx, func) in module.functions.iter_mut().enumerate() {
43 if !func.internal_only || func.params.is_empty() {
44 continue;
45 }
46 let keep = &live_masks[idx];
47 if keep.iter().all(|&slot| slot) {
48 continue;
49 }
50 let old_params = std::mem::take(&mut func.params);
51 func.params = old_params
52 .into_iter()
53 .zip(keep.iter().copied())
54 .filter_map(|(param, keep_slot)| keep_slot.then_some(param))
55 .collect();
56 removed_any = true;
57 }
58
59 if !removed_any {
60 return false;
61 }
62
63 for func in &mut module.functions {
64 for block in &mut func.blocks {
65 for inst in &mut block.insts {
66 let InstKind::Call(FuncRef::Internal(callee_idx), args) = &mut inst.kind else {
67 continue;
68 };
69 let Some(keep) = live_masks.get(*callee_idx as usize) else {
70 continue;
71 };
72 let old_args = std::mem::take(args);
73 *args = old_args
74 .into_iter()
75 .zip(keep.iter().copied())
76 .filter_map(|(arg, keep_slot)| keep_slot.then_some(arg))
77 .collect();
78 }
79 }
80 }
81
82 true
83 }
84 }
85
86 fn param_has_observable_use(
87 module: &Module,
88 func_idx: usize,
89 param_id: ValueId,
90 live_masks: &[Vec<bool>],
91 ) -> bool {
92 let func = &module.functions[func_idx];
93 for block in &func.blocks {
94 for inst in &block.insts {
95 match &inst.kind {
96 InstKind::Call(FuncRef::Internal(callee_idx), args) => {
97 for (arg_idx, arg) in args.iter().enumerate() {
98 if *arg != param_id {
99 continue;
100 }
101 let Some(callee) = module.functions.get(*callee_idx as usize) else {
102 return true;
103 };
104 if !callee.internal_only {
105 return true;
106 }
107 let Some(callee_live) = live_masks
108 .get(*callee_idx as usize)
109 .and_then(|mask| mask.get(arg_idx))
110 else {
111 return true;
112 };
113 if *callee_live {
114 return true;
115 }
116 }
117 }
118 kind => {
119 if inst_uses_param(kind, param_id) {
120 return true;
121 }
122 }
123 }
124 }
125 if block_terminator_uses_param(&block.terminator, param_id) {
126 return true;
127 }
128 }
129 false
130 }
131
132 fn inst_uses_param(kind: &InstKind, param_id: ValueId) -> bool {
133 match kind {
134 InstKind::ConstInt(..)
135 | InstKind::ConstFloat(..)
136 | InstKind::ConstBool(..)
137 | InstKind::ConstString(..)
138 | InstKind::Undef(..)
139 | InstKind::Alloca(..)
140 | InstKind::GlobalAddr(..) => false,
141
142 InstKind::IAdd(a, b)
143 | InstKind::ISub(a, b)
144 | InstKind::IMul(a, b)
145 | InstKind::IDiv(a, b)
146 | InstKind::IMod(a, b)
147 | InstKind::FAdd(a, b)
148 | InstKind::FSub(a, b)
149 | InstKind::FMul(a, b)
150 | InstKind::FDiv(a, b)
151 | InstKind::FPow(a, b)
152 | InstKind::ICmp(_, a, b)
153 | InstKind::FCmp(_, a, b)
154 | InstKind::And(a, b)
155 | InstKind::Or(a, b)
156 | InstKind::BitAnd(a, b)
157 | InstKind::BitOr(a, b)
158 | InstKind::BitXor(a, b)
159 | InstKind::Shl(a, b)
160 | InstKind::LShr(a, b)
161 | InstKind::AShr(a, b)
162 | InstKind::Store(a, b) => *a == param_id || *b == param_id,
163
164 InstKind::INeg(a)
165 | InstKind::FNeg(a)
166 | InstKind::FAbs(a)
167 | InstKind::FSqrt(a)
168 | InstKind::Not(a)
169 | InstKind::BitNot(a)
170 | InstKind::CountLeadingZeros(a)
171 | InstKind::CountTrailingZeros(a)
172 | InstKind::PopCount(a)
173 | InstKind::IntToFloat(a, _)
174 | InstKind::FloatToInt(a, _)
175 | InstKind::FloatExtend(a, _)
176 | InstKind::FloatTrunc(a, _)
177 | InstKind::IntExtend(a, _, _)
178 | InstKind::IntTrunc(a, _)
179 | InstKind::PtrToInt(a)
180 | InstKind::IntToPtr(a, _)
181 | InstKind::Load(a)
182 | InstKind::ExtractField(a, _) => *a == param_id,
183
184 InstKind::Select(c, t, f) => *c == param_id || *t == param_id || *f == param_id,
185 InstKind::GetElementPtr(base, idxs) => *base == param_id || idxs.contains(&param_id),
186 InstKind::RuntimeCall(_, args) | InstKind::Call(FuncRef::External(_), args) => {
187 args.contains(&param_id)
188 }
189 InstKind::Call(FuncRef::Indirect(target), args) => {
190 *target == param_id || args.contains(&param_id)
191 }
192 InstKind::InsertField(agg, _, val) => *agg == param_id || *val == param_id,
193 InstKind::Call(FuncRef::Internal(_), _) => false,
194 }
195 }
196
197 fn block_terminator_uses_param(term: &Option<Terminator>, param_id: ValueId) -> bool {
198 match term {
199 Some(Terminator::Return(Some(v))) => *v == param_id,
200 Some(Terminator::Branch(_, args)) => args.contains(&param_id),
201 Some(Terminator::CondBranch {
202 cond,
203 true_args,
204 false_args,
205 ..
206 }) => *cond == param_id || true_args.contains(&param_id) || false_args.contains(&param_id),
207 Some(Terminator::Switch { selector, .. }) => *selector == param_id,
208 _ => false,
209 }
210 }
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::types::{IntWidth, IrType};
    use crate::lexer::{Position, Span};

    /// Placeholder span (file 0, line 0, column 0) for synthetic instructions.
    fn span() -> Span {
        let origin = Position { line: 0, col: 0 };
        Span {
            file_id: 0,
            start: origin,
            end: origin,
        }
    }

    /// The 32-bit integer IR type used by every value in these tests.
    fn i32_ty() -> IrType {
        IrType::Int(IntWidth::I32)
    }

    /// Builds an `i32` parameter with the given name and value id.
    fn i32_param(name: &'static str, id: ValueId) -> Param {
        Param {
            name: name.into(),
            ty: i32_ty(),
            id,
            fortran_noalias: false,
        }
    }

    /// Parameter list `(x, unused)` shared by both tests.
    fn x_and_unused() -> Vec<Param> {
        vec![
            i32_param("x", ValueId(0)),
            i32_param("unused", ValueId(1)),
        ]
    }

    /// Appends an instruction to the entry block of `f`, registers its type,
    /// and returns the id of the new value.
    fn push(f: &mut Function, kind: InstKind, ty: IrType) -> ValueId {
        let id = f.next_value_id();
        let inst = Inst {
            id,
            kind,
            ty: ty.clone(),
            span: span(),
        };
        let entry = f.entry;
        f.block_mut(entry).insts.push(inst);
        f.register_type(id, ty);
        id
    }

    #[test]
    fn trims_unused_param_and_call_arg_for_internal_only_function() {
        let mut module = Module::new("t".into());

        // Internal-only helper that returns its first parameter and never
        // touches the second one.
        let mut helper = Function::new("helper".into(), x_and_unused(), i32_ty());
        helper.internal_only = true;
        let helper_entry = helper.entry;
        helper.block_mut(helper_entry).terminator = Some(Terminator::Return(Some(ValueId(0))));
        module.add_function(helper);

        // Caller passing two constants to the helper and returning the result.
        let mut main_fn = Function::new("main".into(), vec![], i32_ty());
        let a = push(&mut main_fn, InstKind::ConstInt(7, IntWidth::I32), i32_ty());
        let b = push(&mut main_fn, InstKind::ConstInt(9, IntWidth::I32), i32_ty());
        let call = push(
            &mut main_fn,
            InstKind::Call(FuncRef::Internal(0), vec![a, b]),
            i32_ty(),
        );
        let main_entry = main_fn.entry;
        main_fn.block_mut(main_entry).terminator = Some(Terminator::Return(Some(call)));
        module.add_function(main_fn);

        assert!(DeadArgElim.run(&mut module));
        assert_eq!(module.functions[0].params.len(), 1);

        // The call is the third instruction of the caller's first block.
        let call_args = match &module.functions[1].blocks[0].insts[2].kind {
            InstKind::Call(FuncRef::Internal(0), args) => args,
            other => panic!("expected internal call, got {:?}", other),
        };
        assert_eq!(*call_args, vec![a]);
    }

    #[test]
    fn preserves_abi_for_non_internal_function() {
        let mut module = Module::new("t".into());

        // Same shape as the helper above, but `internal_only` is never set.
        let mut exported = Function::new("exported".into(), x_and_unused(), i32_ty());
        let entry = exported.entry;
        exported.block_mut(entry).terminator = Some(Terminator::Return(Some(ValueId(0))));
        module.add_function(exported);

        assert!(!DeadArgElim.run(&mut module));
        assert_eq!(module.functions[0].params.len(), 2);
    }
}
320