| 1 | //! Dead-argument elimination for contained procedures. |
| 2 | //! |
| 3 | //! Removes unused parameters from internal-only functions and trims the |
| 4 | //! matching argument slots from all `Call(Internal, ...)` sites. |
| 5 | |
| 6 | use crate::ir::inst::*; |
| 7 | |
| 8 | use super::pass::Pass; |
| 9 | |
| 10 | pub struct DeadArgElim; |
| 11 | |
| 12 | impl Pass for DeadArgElim { |
| 13 | fn name(&self) -> &'static str { |
| 14 | "dead-arg-elim" |
| 15 | } |
| 16 | |
| 17 | fn run(&self, module: &mut Module) -> bool { |
| 18 | let mut live_masks: Vec<Vec<bool>> = module |
| 19 | .functions |
| 20 | .iter() |
| 21 | .map(|func| vec![false; func.params.len()]) |
| 22 | .collect(); |
| 23 | |
| 24 | let mut changed = true; |
| 25 | while changed { |
| 26 | changed = false; |
| 27 | for (func_idx, func) in module.functions.iter().enumerate() { |
| 28 | if !func.internal_only || func.params.is_empty() { |
| 29 | continue; |
| 30 | } |
| 31 | for (param_idx, param) in func.params.iter().enumerate() { |
| 32 | let live = param_has_observable_use(module, func_idx, param.id, &live_masks); |
| 33 | if live_masks[func_idx][param_idx] != live { |
| 34 | live_masks[func_idx][param_idx] = live; |
| 35 | changed = true; |
| 36 | } |
| 37 | } |
| 38 | } |
| 39 | } |
| 40 | |
| 41 | let mut removed_any = false; |
| 42 | for (idx, func) in module.functions.iter_mut().enumerate() { |
| 43 | if !func.internal_only || func.params.is_empty() { |
| 44 | continue; |
| 45 | } |
| 46 | let keep = &live_masks[idx]; |
| 47 | if keep.iter().all(|&slot| slot) { |
| 48 | continue; |
| 49 | } |
| 50 | let old_params = std::mem::take(&mut func.params); |
| 51 | func.params = old_params |
| 52 | .into_iter() |
| 53 | .zip(keep.iter().copied()) |
| 54 | .filter_map(|(param, keep_slot)| keep_slot.then_some(param)) |
| 55 | .collect(); |
| 56 | removed_any = true; |
| 57 | } |
| 58 | |
| 59 | if !removed_any { |
| 60 | return false; |
| 61 | } |
| 62 | |
| 63 | for func in &mut module.functions { |
| 64 | for block in &mut func.blocks { |
| 65 | for inst in &mut block.insts { |
| 66 | let InstKind::Call(FuncRef::Internal(callee_idx), args) = &mut inst.kind else { |
| 67 | continue; |
| 68 | }; |
| 69 | let Some(keep) = live_masks.get(*callee_idx as usize) else { |
| 70 | continue; |
| 71 | }; |
| 72 | let old_args = std::mem::take(args); |
| 73 | *args = old_args |
| 74 | .into_iter() |
| 75 | .zip(keep.iter().copied()) |
| 76 | .filter_map(|(arg, keep_slot)| keep_slot.then_some(arg)) |
| 77 | .collect(); |
| 78 | } |
| 79 | } |
| 80 | } |
| 81 | |
| 82 | true |
| 83 | } |
| 84 | } |
| 85 | |
| 86 | fn param_has_observable_use( |
| 87 | module: &Module, |
| 88 | func_idx: usize, |
| 89 | param_id: ValueId, |
| 90 | live_masks: &[Vec<bool>], |
| 91 | ) -> bool { |
| 92 | let func = &module.functions[func_idx]; |
| 93 | for block in &func.blocks { |
| 94 | for inst in &block.insts { |
| 95 | match &inst.kind { |
| 96 | InstKind::Call(FuncRef::Internal(callee_idx), args) => { |
| 97 | for (arg_idx, arg) in args.iter().enumerate() { |
| 98 | if *arg != param_id { |
| 99 | continue; |
| 100 | } |
| 101 | let Some(callee) = module.functions.get(*callee_idx as usize) else { |
| 102 | return true; |
| 103 | }; |
| 104 | if !callee.internal_only { |
| 105 | return true; |
| 106 | } |
| 107 | let Some(callee_live) = live_masks |
| 108 | .get(*callee_idx as usize) |
| 109 | .and_then(|mask| mask.get(arg_idx)) |
| 110 | else { |
| 111 | return true; |
| 112 | }; |
| 113 | if *callee_live { |
| 114 | return true; |
| 115 | } |
| 116 | } |
| 117 | } |
| 118 | kind => { |
| 119 | if inst_uses_param(kind, param_id) { |
| 120 | return true; |
| 121 | } |
| 122 | } |
| 123 | } |
| 124 | } |
| 125 | if block_terminator_uses_param(&block.terminator, param_id) { |
| 126 | return true; |
| 127 | } |
| 128 | } |
| 129 | false |
| 130 | } |
| 131 | |
| 132 | fn inst_uses_param(kind: &InstKind, param_id: ValueId) -> bool { |
| 133 | match kind { |
| 134 | InstKind::ConstInt(..) |
| 135 | | InstKind::ConstFloat(..) |
| 136 | | InstKind::ConstBool(..) |
| 137 | | InstKind::ConstString(..) |
| 138 | | InstKind::Undef(..) |
| 139 | | InstKind::Alloca(..) |
| 140 | | InstKind::GlobalAddr(..) => false, |
| 141 | |
| 142 | InstKind::IAdd(a, b) |
| 143 | | InstKind::ISub(a, b) |
| 144 | | InstKind::IMul(a, b) |
| 145 | | InstKind::IDiv(a, b) |
| 146 | | InstKind::IMod(a, b) |
| 147 | | InstKind::FAdd(a, b) |
| 148 | | InstKind::FSub(a, b) |
| 149 | | InstKind::FMul(a, b) |
| 150 | | InstKind::FDiv(a, b) |
| 151 | | InstKind::FPow(a, b) |
| 152 | | InstKind::ICmp(_, a, b) |
| 153 | | InstKind::FCmp(_, a, b) |
| 154 | | InstKind::And(a, b) |
| 155 | | InstKind::Or(a, b) |
| 156 | | InstKind::BitAnd(a, b) |
| 157 | | InstKind::BitOr(a, b) |
| 158 | | InstKind::BitXor(a, b) |
| 159 | | InstKind::Shl(a, b) |
| 160 | | InstKind::LShr(a, b) |
| 161 | | InstKind::AShr(a, b) |
| 162 | | InstKind::Store(a, b) => *a == param_id || *b == param_id, |
| 163 | |
| 164 | InstKind::INeg(a) |
| 165 | | InstKind::FNeg(a) |
| 166 | | InstKind::FAbs(a) |
| 167 | | InstKind::FSqrt(a) |
| 168 | | InstKind::Not(a) |
| 169 | | InstKind::BitNot(a) |
| 170 | | InstKind::CountLeadingZeros(a) |
| 171 | | InstKind::CountTrailingZeros(a) |
| 172 | | InstKind::PopCount(a) |
| 173 | | InstKind::IntToFloat(a, _) |
| 174 | | InstKind::FloatToInt(a, _) |
| 175 | | InstKind::FloatExtend(a, _) |
| 176 | | InstKind::FloatTrunc(a, _) |
| 177 | | InstKind::IntExtend(a, _, _) |
| 178 | | InstKind::IntTrunc(a, _) |
| 179 | | InstKind::PtrToInt(a) |
| 180 | | InstKind::IntToPtr(a, _) |
| 181 | | InstKind::Load(a) |
| 182 | | InstKind::ExtractField(a, _) => *a == param_id, |
| 183 | |
| 184 | InstKind::Select(c, t, f) => *c == param_id || *t == param_id || *f == param_id, |
| 185 | InstKind::GetElementPtr(base, idxs) => *base == param_id || idxs.contains(¶m_id), |
| 186 | InstKind::RuntimeCall(_, args) | InstKind::Call(FuncRef::External(_), args) => { |
| 187 | args.contains(¶m_id) |
| 188 | } |
| 189 | InstKind::Call(FuncRef::Indirect(target), args) => { |
| 190 | *target == param_id || args.contains(¶m_id) |
| 191 | } |
| 192 | InstKind::InsertField(agg, _, val) => *agg == param_id || *val == param_id, |
| 193 | InstKind::Call(FuncRef::Internal(_), _) => false, |
| 194 | } |
| 195 | } |
| 196 | |
| 197 | fn block_terminator_uses_param(term: &Option<Terminator>, param_id: ValueId) -> bool { |
| 198 | match term { |
| 199 | Some(Terminator::Return(Some(v))) => *v == param_id, |
| 200 | Some(Terminator::Branch(_, args)) => args.contains(¶m_id), |
| 201 | Some(Terminator::CondBranch { |
| 202 | cond, |
| 203 | true_args, |
| 204 | false_args, |
| 205 | .. |
| 206 | }) => *cond == param_id || true_args.contains(¶m_id) || false_args.contains(¶m_id), |
| 207 | Some(Terminator::Switch { selector, .. }) => *selector == param_id, |
| 208 | _ => false, |
| 209 | } |
| 210 | } |
| 211 | |
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::types::{IntWidth, IrType};
    use crate::lexer::{Position, Span};

    /// Zero-width span at line 0, column 0 of file 0.
    fn span() -> Span {
        let origin = Position { line: 0, col: 0 };
        Span {
            file_id: 0,
            start: origin,
            end: origin,
        }
    }

    /// The 32-bit integer IR type used throughout these tests.
    fn i32_ty() -> IrType {
        IrType::Int(IntWidth::I32)
    }

    /// A 32-bit integer parameter named `name` bound to value `id`.
    fn i32_param(name: &'static str, id: ValueId) -> Param {
        Param {
            name: name.into(),
            ty: i32_ty(),
            id,
            fortran_noalias: false,
        }
    }

    /// Appends `kind` to the entry block of `f` as a new value of type `ty`.
    fn push(f: &mut Function, kind: InstKind, ty: IrType) -> ValueId {
        let id = f.next_value_id();
        let entry = f.entry;
        let inst = Inst {
            id,
            kind,
            ty: ty.clone(),
            span: span(),
        };
        f.block_mut(entry).insts.push(inst);
        f.register_type(id, ty);
        id
    }

    #[test]
    fn trims_unused_param_and_call_arg_for_internal_only_function() {
        let mut module = Module::new("t".into());

        // helper(x, unused) returns x and is only visible inside the module.
        let mut callee = Function::new(
            "helper".into(),
            vec![i32_param("x", ValueId(0)), i32_param("unused", ValueId(1))],
            i32_ty(),
        );
        callee.internal_only = true;
        let callee_entry = callee.entry;
        callee.block_mut(callee_entry).terminator = Some(Terminator::Return(Some(ValueId(0))));
        module.add_function(callee);

        // main() returns helper(7, 9).
        let mut caller = Function::new("main".into(), vec![], i32_ty());
        let a = push(&mut caller, InstKind::ConstInt(7, IntWidth::I32), i32_ty());
        let b = push(&mut caller, InstKind::ConstInt(9, IntWidth::I32), i32_ty());
        let call = push(
            &mut caller,
            InstKind::Call(FuncRef::Internal(0), vec![a, b]),
            i32_ty(),
        );
        let caller_entry = caller.entry;
        caller.block_mut(caller_entry).terminator = Some(Terminator::Return(Some(call)));
        module.add_function(caller);

        assert!(DeadArgElim.run(&mut module));
        assert_eq!(module.functions[0].params.len(), 1);

        let call_args = match &module.functions[1].blocks[0].insts[2].kind {
            InstKind::Call(FuncRef::Internal(0), args) => args,
            other => panic!("expected internal call, got {:?}", other),
        };
        assert_eq!(call_args, &vec![a]);
    }

    #[test]
    fn preserves_abi_for_non_internal_function() {
        let mut module = Module::new("t".into());

        // exported(x, unused) returns x but is not marked internal-only.
        let mut func = Function::new(
            "exported".into(),
            vec![i32_param("x", ValueId(0)), i32_param("unused", ValueId(1))],
            i32_ty(),
        );
        let entry = func.entry;
        func.block_mut(entry).terminator = Some(Terminator::Return(Some(ValueId(0))));
        module.add_function(func);

        assert!(!DeadArgElim.run(&mut module));
        assert_eq!(module.functions[0].params.len(), 2);
    }
}
| 320 |