Rust · 11630 bytes Raw Blame History
1 //! Dead-argument elimination for contained procedures.
2 //!
3 //! Removes unused parameters from internal-only functions and trims the
4 //! matching argument slots from all `Call(Internal, ...)` sites.
5
6 use crate::ir::inst::*;
7
8 use super::pass::Pass;
9
10 pub struct DeadArgElim;
11
12 impl Pass for DeadArgElim {
13 fn name(&self) -> &'static str {
14 "dead-arg-elim"
15 }
16
17 fn run(&self, module: &mut Module) -> bool {
18 let mut live_masks: Vec<Vec<bool>> = module
19 .functions
20 .iter()
21 .map(|func| vec![false; func.params.len()])
22 .collect();
23
24 let mut changed = true;
25 while changed {
26 changed = false;
27 for (func_idx, func) in module.functions.iter().enumerate() {
28 if !func.internal_only || func.params.is_empty() {
29 continue;
30 }
31 for (param_idx, param) in func.params.iter().enumerate() {
32 let live = param_has_observable_use(module, func_idx, param.id, &live_masks);
33 if live_masks[func_idx][param_idx] != live {
34 live_masks[func_idx][param_idx] = live;
35 changed = true;
36 }
37 }
38 }
39 }
40
41 let mut removed_any = false;
42 for (idx, func) in module.functions.iter_mut().enumerate() {
43 if !func.internal_only || func.params.is_empty() {
44 continue;
45 }
46 let keep = &live_masks[idx];
47 if keep.iter().all(|&slot| slot) {
48 continue;
49 }
50 let old_params = std::mem::take(&mut func.params);
51 func.params = old_params
52 .into_iter()
53 .zip(keep.iter().copied())
54 .filter_map(|(param, keep_slot)| keep_slot.then_some(param))
55 .collect();
56 removed_any = true;
57 }
58
59 if !removed_any {
60 return false;
61 }
62
63 for func in &mut module.functions {
64 for block in &mut func.blocks {
65 for inst in &mut block.insts {
66 let InstKind::Call(FuncRef::Internal(callee_idx), args) = &mut inst.kind else {
67 continue;
68 };
69 let Some(keep) = live_masks.get(*callee_idx as usize) else {
70 continue;
71 };
72 let old_args = std::mem::take(args);
73 *args = old_args
74 .into_iter()
75 .zip(keep.iter().copied())
76 .filter_map(|(arg, keep_slot)| keep_slot.then_some(arg))
77 .collect();
78 }
79 }
80 }
81
82 true
83 }
84 }
85
86 fn param_has_observable_use(
87 module: &Module,
88 func_idx: usize,
89 param_id: ValueId,
90 live_masks: &[Vec<bool>],
91 ) -> bool {
92 let func = &module.functions[func_idx];
93 for block in &func.blocks {
94 for inst in &block.insts {
95 match &inst.kind {
96 InstKind::Call(FuncRef::Internal(callee_idx), args) => {
97 for (arg_idx, arg) in args.iter().enumerate() {
98 if *arg != param_id {
99 continue;
100 }
101 let Some(callee) = module.functions.get(*callee_idx as usize) else {
102 return true;
103 };
104 if !callee.internal_only {
105 return true;
106 }
107 let Some(callee_live) = live_masks
108 .get(*callee_idx as usize)
109 .and_then(|mask| mask.get(arg_idx))
110 else {
111 return true;
112 };
113 if *callee_live {
114 return true;
115 }
116 }
117 }
118 kind => {
119 if inst_uses_param(kind, param_id) {
120 return true;
121 }
122 }
123 }
124 }
125 if block_terminator_uses_param(&block.terminator, param_id) {
126 return true;
127 }
128 }
129 false
130 }
131
132 fn inst_uses_param(kind: &InstKind, param_id: ValueId) -> bool {
133 match kind {
134 InstKind::ConstInt(..)
135 | InstKind::ConstFloat(..)
136 | InstKind::ConstBool(..)
137 | InstKind::ConstString(..)
138 | InstKind::Undef(..)
139 | InstKind::Alloca(..)
140 | InstKind::GlobalAddr(..) => false,
141
142 InstKind::IAdd(a, b)
143 | InstKind::ISub(a, b)
144 | InstKind::IMul(a, b)
145 | InstKind::IDiv(a, b)
146 | InstKind::IMod(a, b)
147 | InstKind::FAdd(a, b)
148 | InstKind::FSub(a, b)
149 | InstKind::FMul(a, b)
150 | InstKind::FDiv(a, b)
151 | InstKind::FPow(a, b)
152 | InstKind::ICmp(_, a, b)
153 | InstKind::FCmp(_, a, b)
154 | InstKind::And(a, b)
155 | InstKind::Or(a, b)
156 | InstKind::BitAnd(a, b)
157 | InstKind::BitOr(a, b)
158 | InstKind::BitXor(a, b)
159 | InstKind::Shl(a, b)
160 | InstKind::LShr(a, b)
161 | InstKind::AShr(a, b)
162 | InstKind::Store(a, b) => *a == param_id || *b == param_id,
163
164 InstKind::INeg(a)
165 | InstKind::FNeg(a)
166 | InstKind::FAbs(a)
167 | InstKind::FSqrt(a)
168 | InstKind::Not(a)
169 | InstKind::BitNot(a)
170 | InstKind::CountLeadingZeros(a)
171 | InstKind::CountTrailingZeros(a)
172 | InstKind::PopCount(a)
173 | InstKind::IntToFloat(a, _)
174 | InstKind::FloatToInt(a, _)
175 | InstKind::FloatExtend(a, _)
176 | InstKind::FloatTrunc(a, _)
177 | InstKind::IntExtend(a, _, _)
178 | InstKind::IntTrunc(a, _)
179 | InstKind::PtrToInt(a)
180 | InstKind::IntToPtr(a, _)
181 | InstKind::Load(a)
182 | InstKind::ExtractField(a, _) => *a == param_id,
183
184 InstKind::Select(c, t, f) => *c == param_id || *t == param_id || *f == param_id,
185 InstKind::GetElementPtr(base, idxs) => *base == param_id || idxs.contains(&param_id),
186 InstKind::RuntimeCall(_, args) | InstKind::Call(FuncRef::External(_), args) => {
187 args.contains(&param_id)
188 }
189 InstKind::Call(FuncRef::Indirect(target), args) => {
190 *target == param_id || args.contains(&param_id)
191 }
192 InstKind::InsertField(agg, _, val) => *agg == param_id || *val == param_id,
193 InstKind::Call(FuncRef::Internal(_), _) => false,
194
195 // Vector ops — fall through to the generic operand walk so we
196 // don't accidentally treat a vector inst as not using the
197 // param if a future vectorizer makes it consume one.
198 kind @ (InstKind::VAdd(..)
199 | InstKind::VSub(..)
200 | InstKind::VMul(..)
201 | InstKind::VDiv(..)
202 | InstKind::VNeg(..)
203 | InstKind::VAbs(..)
204 | InstKind::VSqrt(..)
205 | InstKind::VFma(..)
206 | InstKind::VSelect(..)
207 | InstKind::VMin(..)
208 | InstKind::VMax(..)
209 | InstKind::VICmp(..)
210 | InstKind::VFCmp(..)
211 | InstKind::VLoad(..)
212 | InstKind::VStore(..)
213 | InstKind::VBitcast(..)
214 | InstKind::VExtract(..)
215 | InstKind::VInsert(..)
216 | InstKind::VBroadcast(..)
217 | InstKind::VReduceSum(..)
218 | InstKind::VReduceMin(..)
219 | InstKind::VReduceMax(..)) => {
220 crate::ir::walk::inst_uses(kind).contains(&param_id)
221 }
222 }
223 }
224
225 fn block_terminator_uses_param(term: &Option<Terminator>, param_id: ValueId) -> bool {
226 match term {
227 Some(Terminator::Return(Some(v))) => *v == param_id,
228 Some(Terminator::Branch(_, args)) => args.contains(&param_id),
229 Some(Terminator::CondBranch {
230 cond,
231 true_args,
232 false_args,
233 ..
234 }) => *cond == param_id || true_args.contains(&param_id) || false_args.contains(&param_id),
235 Some(Terminator::Switch { selector, .. }) => *selector == param_id,
236 _ => false,
237 }
238 }
239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::types::{IntWidth, IrType};
    use crate::lexer::{Position, Span};

    fn i32_ty() -> IrType {
        IrType::Int(IntWidth::I32)
    }

    fn span() -> Span {
        let origin = Position { line: 0, col: 0 };
        Span {
            file_id: 0,
            start: origin,
            end: origin,
        }
    }

    /// Appends `kind` to `f`'s entry block, records its type, and returns the new value id.
    fn push(f: &mut Function, kind: InstKind, ty: IrType) -> ValueId {
        let id = f.next_value_id();
        let entry = f.entry;
        let inst = Inst {
            id,
            kind,
            ty: ty.clone(),
            span: span(),
        };
        f.block_mut(entry).insts.push(inst);
        f.register_type(id, ty);
        id
    }

    fn i32_param(name: &'static str, id: ValueId) -> Param {
        Param {
            name: name.into(),
            ty: i32_ty(),
            id,
            fortran_noalias: false,
        }
    }

    /// Parameter list `(x, unused)` shared by both tests; `x` is `ValueId(0)`.
    fn x_and_unused_params() -> Vec<Param> {
        vec![i32_param("x", ValueId(0)), i32_param("unused", ValueId(1))]
    }

    #[test]
    fn trims_unused_param_and_call_arg_for_internal_only_function() {
        let mut module = Module::new("t".into());

        let mut helper = Function::new("helper".into(), x_and_unused_params(), i32_ty());
        helper.internal_only = true;
        let helper_entry = helper.entry;
        helper.block_mut(helper_entry).terminator = Some(Terminator::Return(Some(ValueId(0))));
        module.add_function(helper);

        let mut main_fn = Function::new("main".into(), vec![], i32_ty());
        let seven = push(&mut main_fn, InstKind::ConstInt(7, IntWidth::I32), i32_ty());
        let nine = push(&mut main_fn, InstKind::ConstInt(9, IntWidth::I32), i32_ty());
        let call = push(
            &mut main_fn,
            InstKind::Call(FuncRef::Internal(0), vec![seven, nine]),
            i32_ty(),
        );
        let main_entry = main_fn.entry;
        main_fn.block_mut(main_entry).terminator = Some(Terminator::Return(Some(call)));
        module.add_function(main_fn);

        assert!(DeadArgElim.run(&mut module));
        assert_eq!(module.functions[0].params.len(), 1);

        let call_kind = &module.functions[1].blocks[0].insts[2].kind;
        let InstKind::Call(FuncRef::Internal(0), call_args) = call_kind else {
            panic!("expected internal call, got {:?}", call_kind);
        };
        assert_eq!(*call_args, vec![seven]);
    }

    #[test]
    fn preserves_abi_for_non_internal_function() {
        let mut module = Module::new("t".into());

        let mut exported = Function::new("exported".into(), x_and_unused_params(), i32_ty());
        let exported_entry = exported.entry;
        exported.block_mut(exported_entry).terminator = Some(Terminator::Return(Some(ValueId(0))));
        module.add_function(exported);

        assert!(!DeadArgElim.run(&mut module));
        assert_eq!(module.functions[0].params.len(), 2);
    }
}
348