Rust · 17236 bytes Raw Blame History
1 //! Constant-argument specialization for contained procedures.
2 //!
3 //! When every internal call site passes the same compile-time constant
4 //! to a contained helper, rewrite the helper body to materialize that
5 //! constant directly. `DeadArgElim` then trims the now-unused dummy.
6
7 use std::collections::HashMap;
8
9 use crate::ir::inst::*;
10 use crate::ir::types::{FloatWidth, IntWidth, IrType};
11 use crate::lexer::{Position, Span};
12
13 use super::pass::Pass;
14 use super::util::substitute_uses;
15
/// Optimization pass that specializes contained procedures on arguments
/// that are the same compile-time constant at every internal call site.
pub struct ConstArgSpecialize;
17
18 impl Pass for ConstArgSpecialize {
19 fn name(&self) -> &'static str {
20 "const-arg-specialize"
21 }
22
23 fn run(&self, module: &mut Module) -> bool {
24 let param_modes: Vec<Vec<Option<ParamMode>>> = module
25 .functions
26 .iter()
27 .map(|func| {
28 func.params
29 .iter()
30 .map(|param| classify_param(func, param))
31 .collect()
32 })
33 .collect();
34
35 let plans = specialization_plans(module, &param_modes);
36 if plans.is_empty() {
37 return false;
38 }
39
40 for plan in plans {
41 apply_plan(module, plan);
42 }
43
44 true
45 }
46 }
47
/// How a specializable parameter reaches the callee.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ParamMode {
    /// The parameter's own type is a scalar (`Bool`, `Int`, or `Float`).
    ByValueScalar,
    /// The parameter is a pointer to a scalar and the callee only ever
    /// loads directly from it.
    ReadOnlyScalarPtr,
}
53
/// Compile-time scalar constant recognized by this pass.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ScalarConst {
    /// Integer value together with its width.
    Int(i128, IntWidth),
    /// Floating-point value kept as its raw bit pattern (the `f32` bits
    /// widened to `u64` for `FloatWidth::F32`) together with its width.
    Float(u64, FloatWidth),
    /// Boolean value.
    Bool(bool),
}
60
61 impl ScalarConst {
62 fn ty(self) -> IrType {
63 match self {
64 Self::Int(_, width) => IrType::Int(width),
65 Self::Float(_, width) => IrType::Float(width),
66 Self::Bool(_) => IrType::Bool,
67 }
68 }
69
70 fn kind(self) -> InstKind {
71 match self {
72 Self::Int(value, width) => InstKind::ConstInt(value, width),
73 Self::Float(bits, width) => match width {
74 FloatWidth::F32 => InstKind::ConstFloat(f32::from_bits(bits as u32) as f64, width),
75 FloatWidth::F64 => InstKind::ConstFloat(f64::from_bits(bits), width),
76 },
77 Self::Bool(value) => InstKind::ConstBool(value),
78 }
79 }
80 }
81
/// A single parameter chosen for specialization.
#[derive(Debug, Clone, Copy)]
struct SpecializedParam {
    /// How the callee receives the parameter.
    mode: ParamMode,
    /// The constant that every call site agreed on.
    value: ScalarConst,
}
87
/// The set of parameters to specialize in one function.
#[derive(Debug, Clone)]
struct SpecializationPlan {
    /// Index of the target function in `module.functions`.
    func_idx: usize,
    /// One slot per parameter of the function; `None` leaves that
    /// parameter untouched.
    params: Vec<Option<SpecializedParam>>,
}
93
94 fn is_scalar_type(ty: &IrType) -> bool {
95 matches!(ty, IrType::Bool | IrType::Int(_) | IrType::Float(_))
96 }
97
98 fn classify_param(func: &Function, param: &Param) -> Option<ParamMode> {
99 if !func.internal_only {
100 return None;
101 }
102
103 if !param.ty.is_ptr() {
104 return is_scalar_type(&param.ty).then_some(ParamMode::ByValueScalar);
105 }
106
107 let IrType::Ptr(inner) = &param.ty else {
108 return None;
109 };
110 if !is_scalar_type(inner.as_ref()) {
111 return None;
112 }
113 pointer_param_is_direct_read_only_scalar(func, param.id).then_some(ParamMode::ReadOnlyScalarPtr)
114 }
115
116 fn pointer_param_is_direct_read_only_scalar(func: &Function, param_id: ValueId) -> bool {
117 let mut saw_load = false;
118 for block in &func.blocks {
119 for inst in &block.insts {
120 let uses = super::util::inst_uses(&inst.kind);
121 if !uses.contains(&param_id) {
122 continue;
123 }
124 match inst.kind {
125 InstKind::Load(addr) if addr == param_id => {
126 saw_load = true;
127 }
128 _ => return false,
129 }
130 }
131 if let Some(term) = &block.terminator {
132 if super::util::terminator_uses(term).contains(&param_id) {
133 return false;
134 }
135 }
136 }
137 saw_load
138 }
139
140 fn inst_map(func: &Function) -> HashMap<ValueId, &Inst> {
141 func.blocks
142 .iter()
143 .flat_map(|block| block.insts.iter())
144 .map(|inst| (inst.id, inst))
145 .collect()
146 }
147
148 fn const_value_of(func: &Function, value: ValueId) -> Option<ScalarConst> {
149 let defs = inst_map(func);
150 let inst = defs.get(&value)?;
151 match inst.kind {
152 InstKind::ConstInt(v, w) => Some(ScalarConst::Int(v, w)),
153 InstKind::ConstFloat(v, w) => Some(ScalarConst::Float(
154 match w {
155 FloatWidth::F32 => (v as f32).to_bits() as u64,
156 FloatWidth::F64 => v.to_bits(),
157 },
158 w,
159 )),
160 InstKind::ConstBool(v) => Some(ScalarConst::Bool(v)),
161 _ => None,
162 }
163 }
164
165 fn wrapper_const_value(
166 caller: &Function,
167 ptr_value: ValueId,
168 param_modes: &[Vec<Option<ParamMode>>],
169 ) -> Option<ScalarConst> {
170 let defs = inst_map(caller);
171 let def = defs.get(&ptr_value)?;
172 match &def.kind {
173 InstKind::Alloca(slot_ty) if is_scalar_type(slot_ty) => {}
174 _ => return None,
175 }
176
177 let mut stored = None;
178 let mut saw_call_use = false;
179
180 for block in &caller.blocks {
181 for inst in &block.insts {
182 let uses = super::util::inst_uses(&inst.kind);
183 if !uses.contains(&ptr_value) {
184 continue;
185 }
186 match &inst.kind {
187 InstKind::Store(value, addr) if *addr == ptr_value => {
188 let const_value = const_value_of(caller, *value)?;
189 if let Some(existing) = stored {
190 if existing != const_value {
191 return None;
192 }
193 } else {
194 stored = Some(const_value);
195 }
196 }
197 InstKind::Call(FuncRef::Internal(idx), args) => {
198 let valid = args.iter().enumerate().any(|(arg_idx, arg)| {
199 *arg == ptr_value
200 && param_modes
201 .get(*idx as usize)
202 .and_then(|modes| modes.get(arg_idx))
203 .copied()
204 .flatten()
205 == Some(ParamMode::ReadOnlyScalarPtr)
206 });
207 if !valid {
208 return None;
209 }
210 saw_call_use = true;
211 }
212 _ => return None,
213 }
214 }
215 if let Some(term) = &block.terminator {
216 if super::util::terminator_uses(term).contains(&ptr_value) {
217 return None;
218 }
219 }
220 }
221
222 if saw_call_use {
223 stored
224 } else {
225 None
226 }
227 }
228
229 fn const_arg_value(
230 caller: &Function,
231 arg: ValueId,
232 mode: ParamMode,
233 param_modes: &[Vec<Option<ParamMode>>],
234 ) -> Option<ScalarConst> {
235 match mode {
236 ParamMode::ByValueScalar => const_value_of(caller, arg),
237 ParamMode::ReadOnlyScalarPtr => wrapper_const_value(caller, arg, param_modes),
238 }
239 }
240
241 fn specialization_plans(
242 module: &Module,
243 param_modes: &[Vec<Option<ParamMode>>],
244 ) -> Vec<SpecializationPlan> {
245 let mut plans = Vec::new();
246
247 for (func_idx, func) in module.functions.iter().enumerate() {
248 if !func.internal_only || func.params.is_empty() {
249 continue;
250 }
251
252 let mut params = vec![None; func.params.len()];
253
254 for (param_idx, mode) in param_modes[func_idx].iter().copied().enumerate() {
255 let Some(mode) = mode else {
256 continue;
257 };
258
259 let mut agreed = None;
260 let mut saw_call = false;
261 let mut conflicted = false;
262
263 'callers: for caller in &module.functions {
264 for block in &caller.blocks {
265 for inst in &block.insts {
266 let InstKind::Call(FuncRef::Internal(idx), args) = &inst.kind else {
267 continue;
268 };
269 if *idx as usize != func_idx {
270 continue;
271 }
272 saw_call = true;
273 let Some(arg) = args.get(param_idx) else {
274 conflicted = true;
275 break 'callers;
276 };
277 let Some(value) = const_arg_value(caller, *arg, mode, param_modes) else {
278 conflicted = true;
279 break 'callers;
280 };
281 if let Some(existing) = agreed {
282 if existing != value {
283 conflicted = true;
284 break 'callers;
285 }
286 } else {
287 agreed = Some(value);
288 }
289 }
290 }
291 }
292
293 if saw_call && !conflicted {
294 if let Some(value) = agreed {
295 params[param_idx] = Some(SpecializedParam { mode, value });
296 }
297 }
298 }
299
300 if params.iter().any(|param| param.is_some()) {
301 plans.push(SpecializationPlan { func_idx, params });
302 }
303 }
304
305 plans
306 }
307
308 fn default_span(func: &Function) -> Span {
309 if let Some(span) = func
310 .blocks
311 .iter()
312 .flat_map(|block| block.insts.iter())
313 .map(|inst| inst.span)
314 .next()
315 {
316 span
317 } else {
318 let pos = Position { line: 0, col: 0 };
319 Span {
320 file_id: 0,
321 start: pos,
322 end: pos,
323 }
324 }
325 }
326
327 fn apply_plan(module: &mut Module, plan: SpecializationPlan) {
328 let func = &mut module.functions[plan.func_idx];
329 let span = default_span(func);
330 let mut new_consts = Vec::new();
331 let mut substitutions = Vec::new();
332 let mut remove_loads = Vec::new();
333
334 for (param_idx, specialized) in plan.params.iter().enumerate() {
335 let Some(specialized) = specialized else {
336 continue;
337 };
338 let const_id = func.next_value_id();
339 let const_ty = specialized.value.ty();
340 func.register_type(const_id, const_ty.clone());
341 new_consts.push(Inst {
342 id: const_id,
343 kind: specialized.value.kind(),
344 ty: const_ty,
345 span,
346 });
347
348 let param_id = func.params[param_idx].id;
349 match specialized.mode {
350 ParamMode::ByValueScalar => {
351 substitutions.push((param_id, const_id));
352 }
353 ParamMode::ReadOnlyScalarPtr => {
354 for block in &func.blocks {
355 for inst in &block.insts {
356 if matches!(inst.kind, InstKind::Load(addr) if addr == param_id) {
357 substitutions.push((inst.id, const_id));
358 remove_loads.push(inst.id);
359 }
360 }
361 }
362 }
363 }
364 }
365
366 if new_consts.is_empty() {
367 return;
368 }
369
370 let entry = func.entry;
371 func.block_mut(entry).insts.splice(0..0, new_consts);
372 for (old, new) in substitutions {
373 substitute_uses(func, old, new);
374 }
375 if !remove_loads.is_empty() {
376 for block in &mut func.blocks {
377 block.insts.retain(|inst| !remove_loads.contains(&inst.id));
378 }
379 }
380 }
381
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::types::IrType;
    use crate::lexer::{Position, Span};
    use crate::opt::dead_arg::DeadArgElim;

    /// All-zero span used for every instruction built by the tests.
    fn span() -> Span {
        let pos = Position { line: 0, col: 0 };
        Span {
            file_id: 0,
            start: pos,
            end: pos,
        }
    }

    /// Appends an instruction to the entry block of `f`, registers its
    /// type, and returns the new value id.
    fn push(f: &mut Function, kind: InstKind, ty: IrType) -> ValueId {
        let id = f.next_value_id();
        let entry = f.entry;
        f.block_mut(entry).insts.push(Inst {
            id,
            kind,
            ty: ty.clone(),
            span: span(),
        });
        f.register_type(id, ty);
        id
    }

    /// The 32-bit integer type shared by every value in these tests.
    fn i32_ty() -> IrType {
        IrType::Int(IntWidth::I32)
    }

    /// A by-value 32-bit integer parameter.
    fn i32_param(name: &'static str, id: ValueId) -> Param {
        Param {
            name: name.into(),
            ty: i32_ty(),
            id,
            fortran_noalias: false,
        }
    }

    /// Module whose function 0 is the internal-only `helper(x, step)`
    /// returning `x + step`.
    fn module_with_helper() -> Module {
        let mut module = Module::new("t".into());

        let params = vec![i32_param("x", ValueId(0)), i32_param("step", ValueId(1))];
        let mut callee = Function::new("helper".into(), params, i32_ty());
        callee.internal_only = true;
        let sum = push(
            &mut callee,
            InstKind::IAdd(ValueId(0), ValueId(1)),
            i32_ty(),
        );
        let callee_entry = callee.entry;
        callee.block_mut(callee_entry).terminator = Some(Terminator::Return(Some(sum)));
        module.add_function(callee);

        module
    }

    /// Emits two integer constants followed by a call to function 0 with
    /// them as arguments; returns the id of the call.
    fn call_helper(caller: &mut Function, x: i128, step: i128) -> ValueId {
        let x_id = push(caller, InstKind::ConstInt(x, IntWidth::I32), i32_ty());
        let step_id = push(caller, InstKind::ConstInt(step, IntWidth::I32), i32_ty());
        push(
            caller,
            InstKind::Call(FuncRef::Internal(0), vec![x_id, step_id]),
            i32_ty(),
        )
    }

    /// Terminates `caller` with a return of `result` and adds it to the
    /// module.
    fn finish_main(module: &mut Module, mut caller: Function, result: ValueId) {
        let caller_entry = caller.entry;
        caller.block_mut(caller_entry).terminator = Some(Terminator::Return(Some(result)));
        module.add_function(caller);
    }

    #[test]
    fn specializes_by_value_constant_dummy_and_enables_arg_elision() {
        let mut module = module_with_helper();

        let mut caller = Function::new("main".into(), vec![], i32_ty());
        let c1 = call_helper(&mut caller, 7, 3);
        let c2 = call_helper(&mut caller, 9, 3);
        let total = push(&mut caller, InstKind::IAdd(c1, c2), i32_ty());
        finish_main(&mut module, caller, total);

        assert!(ConstArgSpecialize.run(&mut module));
        assert!(DeadArgElim.run(&mut module));

        assert_eq!(
            module.functions[0].params.len(),
            1,
            "specialized helper should drop constant dummy"
        );
        let call_args_1 = match &module.functions[1].blocks[0].insts[2].kind {
            InstKind::Call(FuncRef::Internal(0), args) => args,
            other => panic!("expected call, got {:?}", other),
        };
        assert_eq!(
            call_args_1.len(),
            1,
            "call site should drop specialized constant arg"
        );
    }

    #[test]
    fn does_not_specialize_when_callsite_constants_disagree() {
        let mut module = module_with_helper();

        let mut caller = Function::new("main".into(), vec![], i32_ty());
        let _ = call_helper(&mut caller, 7, 3);
        let call = call_helper(&mut caller, 9, 4);
        finish_main(&mut module, caller, call);

        assert!(!ConstArgSpecialize.run(&mut module));
        assert_eq!(module.functions[0].params.len(), 2);
    }
}
567