1 //! Scalar Replacement of Aggregates (SROA).
2 //!
3 //! Decomposes small aggregate allocas (`alloca [T x N]` where N ≤ 8)
4 //! into N individual scalar allocas. After SROA, mem2reg can promote
5 //! the scalars to SSA values in registers.
6 //!
7 //! Eligibility:
8 //! - Alloca type is `Array(elem, count)` with count ≤ SROA_MAX_FIELDS
9 //! - ALL uses of the alloca are GEP with a single constant index
10 //! - The alloca address is never passed to a Call (no escape)
11 //!
12 //! After SROA, a second Mem2Reg pass promotes the new scalar allocas.
13
14 use super::loop_utils::resolve_const_int;
15 use super::pass::Pass;
16 use crate::ir::inst::*;
17 use crate::ir::types::IrType;
18 use crate::ir::walk::inst_uses;
19 use std::collections::HashMap;
20
/// Maximum element count for an array alloca to be considered for SROA.
/// Larger aggregates are left intact to bound code growth (one scalar
/// alloca is created per element).
const SROA_MAX_FIELDS: u64 = 8;

/// Scalar Replacement of Aggregates pass; see the module-level docs.
pub struct Sroa;
24
25 impl Pass for Sroa {
26 fn name(&self) -> &'static str {
27 "sroa"
28 }
29
30 fn run(&self, module: &mut Module) -> bool {
31 let mut changed = false;
32 for func in &mut module.functions {
33 if sroa_function(func) {
34 changed = true;
35 }
36 }
37 changed
38 }
39 }
40
41 fn sroa_function(func: &mut Function) -> bool {
42 // Collect candidate allocas: Array type, small, all constant-index GEPs.
43 let candidates = find_candidates(func);
44 if candidates.is_empty() {
45 return false;
46 }
47
48 let mut changed = false;
49 for cand in candidates {
50 if decompose_alloca(func, &cand) {
51 changed = true;
52 }
53 }
54 changed
55 }
56
/// Description of one aggregate alloca that is safe to decompose.
struct SroaCandidate {
    // Value id of the aggregate `Alloca` instruction itself.
    alloca_id: ValueId,
    // Block containing the alloca.
    alloca_block: BlockId,
    // Position of the alloca within `alloca_block`'s instruction list;
    // new scalar allocas are inserted immediately after this index.
    alloca_inst_idx: usize,
    // Element type of the array being split.
    elem_ty: IrType,
    // Number of array elements (1..=SROA_MAX_FIELDS).
    count: u64,
}
64
65 fn find_candidates(func: &Function) -> Vec<SroaCandidate> {
66 let mut candidates = Vec::new();
67
68 for block in &func.blocks {
69 for (inst_idx, inst) in block.insts.iter().enumerate() {
70 if let InstKind::Alloca(IrType::Array(ref elem, count)) = inst.kind {
71 if count <= SROA_MAX_FIELDS && count > 0 {
72 // Check eligibility: all uses are constant-index GEPs, no escape.
73 if is_eligible(func, inst.id) {
74 candidates.push(SroaCandidate {
75 alloca_id: inst.id,
76 alloca_block: block.id,
77 alloca_inst_idx: inst_idx,
78 elem_ty: (**elem).clone(),
79 count,
80 });
81 }
82 }
83 }
84 }
85 }
86 candidates
87 }
88
89 /// Check if all uses of an alloca are constant-index GEPs with no escape.
/// Check if all uses of an alloca are constant-index GEPs with no escape.
///
/// Walks every instruction and classifies each use of `alloca_id`; any
/// use outside the allowed categories makes the alloca ineligible.
fn is_eligible(func: &Function, alloca_id: ValueId) -> bool {
    for block in &func.blocks {
        for inst in &block.insts {
            let uses = inst_uses(&inst.kind);
            if !uses.contains(&alloca_id) {
                continue;
            }

            // Classify this use of the alloca.
            // NOTE: arm order matters — the guarded arms must be tried before
            // the catch-all, which conservatively rejects anything unmatched.
            match &inst.kind {
                // GEP with the alloca as base: OK if single constant index
                // AND the result type matches the element type (not byte-level).
                InstKind::GetElementPtr(base, indices) if *base == alloca_id => {
                    if indices.len() != 1 {
                        return false;
                    }
                    if resolve_const_int(func, indices[0]).is_none() {
                        return false;
                    }
                    // NOTE(review): the constant index is not range-checked
                    // here; out-of-range indices are filtered later in
                    // `decompose_alloca`, which leaves such GEPs untouched.
                    // Reject byte-level GEPs (ptr<i8>) — SROA only handles
                    // element-typed accesses. Byte-level GEPs from array
                    // constructors use raw byte offsets that don't map to
                    // element indices.
                    if let Some(IrType::Ptr(inner)) = func.value_type(inst.id) {
                        if matches!(*inner, IrType::Int(crate::ir::types::IntWidth::I8)) {
                            // Result is ptr<i8> — byte-level access, not element.
                            return false;
                        }
                    }
                    // The derived element address must itself be used only by
                    // loads/stores (checked transitively at depth one).
                    if gep_result_escapes(func, inst.id) {
                        return false;
                    }
                    // This use is fine — constant-index element access.
                }
                // Store where the alloca is the VALUE being stored = pointer escape.
                InstKind::Store(val, _) if *val == alloca_id => {
                    return false;
                }
                // Call/RuntimeCall with the alloca as an argument = escape.
                InstKind::Call(_, args) if args.contains(&alloca_id) => {
                    return false;
                }
                InstKind::RuntimeCall(_, args) if args.contains(&alloca_id) => {
                    return false;
                }
                // Any other instruction that uses the alloca = ineligible.
                // (Includes: direct store TO the aggregate base, arithmetic on the pointer, etc.)
                _ => {
                    return false;
                }
            }
        }
    }
    true
}
145
146 fn gep_result_escapes(func: &Function, gep_id: ValueId) -> bool {
147 for block in &func.blocks {
148 for inst in &block.insts {
149 let uses = inst_uses(&inst.kind);
150 if !uses.contains(&gep_id) {
151 continue;
152 }
153 match &inst.kind {
154 InstKind::Load(ptr) if *ptr == gep_id => {}
155 InstKind::Store(_, ptr) if *ptr == gep_id => {}
156 _ => return true,
157 }
158 }
159 if let Some(term) = &block.terminator {
160 match term {
161 Terminator::Return(Some(v)) if *v == gep_id => return true,
162 Terminator::Branch(_, args) if args.contains(&gep_id) => return true,
163 Terminator::CondBranch {
164 cond,
165 true_args,
166 false_args,
167 ..
168 } => {
169 if *cond == gep_id
170 || true_args.contains(&gep_id)
171 || false_args.contains(&gep_id)
172 {
173 return true;
174 }
175 }
176 Terminator::Switch { selector, .. } => {
177 if *selector == gep_id {
178 return true;
179 }
180 }
181 _ => {}
182 }
183 }
184 }
185 false
186 }
187
188 /// Decompose one alloca into individual scalar allocas.
/// Decompose one alloca into individual scalar allocas.
///
/// Creates one scalar alloca per array element (inserted right after the
/// original alloca so they dominate all uses), then retargets every
/// `Load`/`Store` that went through a constant-index GEP to the matching
/// scalar slot. Returns `true` if any rewrite happened.
///
/// NOTE(review): the original aggregate alloca and the now-unused GEP
/// instructions are left in place as dead code — presumably a later
/// DCE/mem2reg pass cleans them up; confirm against the pass pipeline.
fn decompose_alloca(func: &mut Function, cand: &SroaCandidate) -> bool {
    // Create N individual allocas, one per field.
    let mut field_allocas: Vec<ValueId> = Vec::new();
    // Reuse the original alloca's span for the synthesized instructions.
    let span = func.block(cand.alloca_block).insts[cand.alloca_inst_idx].span;

    // Create field allocas and insert them right after the original alloca
    // so they dominate all subsequent uses.
    let insert_pos = cand.alloca_inst_idx + 1;
    for i in 0..cand.count {
        let new_id = func.next_value_id();
        let ptr_ty = IrType::Ptr(Box::new(cand.elem_ty.clone()));
        func.register_type(new_id, ptr_ty.clone());
        // `insert_pos + i` keeps the fields in index order as earlier
        // insertions shift later positions.
        func.block_mut(cand.alloca_block).insts.insert(
            insert_pos + i as usize,
            Inst {
                id: new_id,
                kind: InstKind::Alloca(cand.elem_ty.clone()),
                ty: ptr_ty,
                span,
            },
        );
        field_allocas.push(new_id);
    }

    // Build a GEP→field_alloca mapping: for each GEP(alloca, [const_idx]),
    // replace the GEP result with the corresponding field alloca.
    let mut gep_to_field: HashMap<ValueId, ValueId> = HashMap::new();

    for block in &func.blocks {
        for inst in &block.insts {
            if let InstKind::GetElementPtr(base, indices) = &inst.kind {
                if *base == cand.alloca_id && indices.len() == 1 {
                    if let Some(idx) = resolve_const_int(func, indices[0]) {
                        // Out-of-range constant indices (UB in the source
                        // program) are simply skipped: their GEPs keep
                        // pointing at the original aggregate.
                        if idx >= 0 && (idx as u64) < cand.count {
                            gep_to_field.insert(inst.id, field_allocas[idx as usize]);
                        }
                    }
                }
            }
        }
    }

    if gep_to_field.is_empty() {
        return false;
    }

    // Rewrite all uses of GEP results to use the field allocas.
    // Eligibility guaranteed that GEP results are only used as the pointer
    // operand of Load/Store, so these two cases cover every use.
    for block in &mut func.blocks {
        for inst in &mut block.insts {
            // Replace operands that reference GEP results.
            let mut new_kind = inst.kind.clone();
            let mut replaced = false;
            match &mut new_kind {
                InstKind::Load(ptr) => {
                    if let Some(&field) = gep_to_field.get(ptr) {
                        *ptr = field;
                        replaced = true;
                    }
                }
                InstKind::Store(_, ptr) => {
                    if let Some(&field) = gep_to_field.get(ptr) {
                        *ptr = field;
                        replaced = true;
                    }
                }
                _ => {}
            }
            if replaced {
                inst.kind = new_kind;
            }
        }
    }

    true
}
264
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::types::IntWidth;
    use crate::lexer::{Position, Span};
    use crate::opt::pass::Pass;

    // Zero-width dummy span for synthesized test instructions.
    fn span() -> Span {
        let pos = Position { line: 0, col: 0 };
        Span {
            file_id: 0,
            start: pos,
            end: pos,
        }
    }

    // A scalar (non-array) alloca is not an SROA candidate, so the pass
    // must report "no change".
    #[test]
    fn sroa_no_op_on_scalars() {
        let mut m = Module::new("test".into());
        let mut f = Function::new("test".into(), vec![], IrType::Void);
        // Single i32 alloca — not Array-typed, so never a candidate.
        let a = f.next_value_id();
        f.register_type(a, IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))));
        f.block_mut(f.entry).insts.push(Inst {
            id: a,
            ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
            span: span(),
            kind: InstKind::Alloca(IrType::Int(IntWidth::I32)),
        });
        f.block_mut(f.entry).terminator = Some(Terminator::Return(None));
        m.add_function(f);
        let pass = Sroa;
        assert!(!pass.run(&mut m), "scalar alloca should not be decomposed");
    }

    // Storing a GEP-derived element address into another slot lets the
    // address escape, which must block SROA for the whole alloca.
    #[test]
    fn sroa_rejects_gep_address_escape() {
        let mut m = Module::new("test".into());
        let mut f = Function::new("test".into(), vec![], IrType::Void);
        // arr: alloca [ptr<i8> x 2] — small array, otherwise eligible.
        let arr_ty = IrType::Array(
            Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I8)))),
            2,
        );
        let arr = f.next_value_id();
        f.register_type(arr, IrType::Ptr(Box::new(arr_ty.clone())));
        f.block_mut(f.entry).insts.push(Inst {
            id: arr,
            ty: IrType::Ptr(Box::new(arr_ty)),
            span: span(),
            kind: InstKind::Alloca(IrType::Array(
                Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I8)))),
                2,
            )),
        });

        // Constant 0 used as the GEP index.
        let zero = f.next_value_id();
        f.register_type(zero, IrType::Int(IntWidth::I64));
        f.block_mut(f.entry).insts.push(Inst {
            id: zero,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::ConstInt(0, IntWidth::I64),
        });

        // gep = &arr[0] — a constant-index element address.
        let gep = f.next_value_id();
        f.register_type(
            gep,
            IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I8))))),
        );
        f.block_mut(f.entry).insts.push(Inst {
            id: gep,
            ty: IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I8))))),
            span: span(),
            kind: InstKind::GetElementPtr(arr, vec![zero]),
        });

        // sink: a separate alloca that will receive the GEP address.
        let sink = f.next_value_id();
        f.register_type(
            sink,
            IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Ptr(Box::new(
                IrType::Int(IntWidth::I8),
            )))))),
        );
        f.block_mut(f.entry).insts.push(Inst {
            id: sink,
            ty: IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Ptr(Box::new(
                IrType::Int(IntWidth::I8),
            )))))),
            span: span(),
            kind: InstKind::Alloca(IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Int(
                IntWidth::I8,
            )))))),
        });

        // store gep -> sink: the GEP result is the stored VALUE, i.e. the
        // element address escapes into memory.
        let escape_store = f.next_value_id();
        f.register_type(escape_store, IrType::Void);
        f.block_mut(f.entry).insts.push(Inst {
            id: escape_store,
            ty: IrType::Void,
            span: span(),
            kind: InstKind::Store(gep, sink),
        });

        f.block_mut(f.entry).terminator = Some(Terminator::Return(None));
        m.add_function(f);

        assert!(
            !Sroa.run(&mut m),
            "GEP addresses that escape should block SROA"
        );
    }
}
376