Rust · 12721 bytes Raw Blame History
1 //! Scalar Replacement of Aggregates (SROA).
2 //!
3 //! Decomposes small aggregate allocas (`alloca [T x N]` where N ≤ 8)
4 //! into N individual scalar allocas. After SROA, mem2reg can promote
5 //! the scalars to SSA values in registers.
6 //!
7 //! Eligibility:
8 //! - Alloca type is `Array(elem, count)` with count ≤ SROA_MAX_FIELDS
9 //! - ALL uses of the alloca are GEP with a single constant index
10 //! - The alloca address is never passed to a Call (no escape)
11 //!
12 //! After SROA, a second Mem2Reg pass promotes the new scalar allocas.
13
14 use super::loop_utils::resolve_const_int;
15 use super::pass::Pass;
16 use crate::ir::inst::*;
17 use crate::ir::types::IrType;
18 use crate::ir::walk::inst_uses;
19 use std::collections::HashMap;
20
/// Largest array element count an alloca may have and still be decomposed.
const SROA_MAX_FIELDS: u64 = 8;
22
/// Scalar Replacement of Aggregates pass (see module docs above).
pub struct Sroa;
24
25 impl Pass for Sroa {
26 fn name(&self) -> &'static str {
27 "sroa"
28 }
29
30 fn run(&self, module: &mut Module) -> bool {
31 let mut changed = false;
32 for func in &mut module.functions {
33 if sroa_function(func) {
34 changed = true;
35 }
36 }
37 changed
38 }
39 }
40
41 fn sroa_function(func: &mut Function) -> bool {
42 // Collect candidate allocas: Array type, small, all constant-index GEPs.
43 let candidates = find_candidates(func);
44 if candidates.is_empty() {
45 return false;
46 }
47
48 let mut changed = false;
49 for cand in candidates {
50 if decompose_alloca(func, &cand) {
51 changed = true;
52 }
53 }
54 changed
55 }
56
/// One alloca selected for scalar replacement, captured during the scan.
struct SroaCandidate {
    // SSA id of the `alloca [T x N]` instruction.
    alloca_id: ValueId,
    // Block containing the alloca; field allocas are inserted there too.
    alloca_block: BlockId,
    // Index of the alloca within its block, as seen at scan time.
    alloca_inst_idx: usize,
    // Element type T of the array.
    elem_ty: IrType,
    // Number of array elements N (1 ..= SROA_MAX_FIELDS).
    count: u64,
}
64
65 fn find_candidates(func: &Function) -> Vec<SroaCandidate> {
66 let mut candidates = Vec::new();
67
68 for block in &func.blocks {
69 for (inst_idx, inst) in block.insts.iter().enumerate() {
70 if let InstKind::Alloca(IrType::Array(ref elem, count)) = inst.kind {
71 if count <= SROA_MAX_FIELDS && count > 0 {
72 // Check eligibility: all uses are constant-index GEPs, no escape.
73 if is_eligible(func, inst.id) {
74 candidates.push(SroaCandidate {
75 alloca_id: inst.id,
76 alloca_block: block.id,
77 alloca_inst_idx: inst_idx,
78 elem_ty: (**elem).clone(),
79 count,
80 });
81 }
82 }
83 }
84 }
85 }
86 candidates
87 }
88
89 /// Check if all uses of an alloca are constant-index GEPs with no escape.
90 fn is_eligible(func: &Function, alloca_id: ValueId) -> bool {
91 for block in &func.blocks {
92 for inst in &block.insts {
93 let uses = inst_uses(&inst.kind);
94 if !uses.contains(&alloca_id) {
95 continue;
96 }
97
98 // Classify this use of the alloca.
99 match &inst.kind {
100 // GEP with the alloca as base: OK if single constant index
101 // AND the result type matches the element type (not byte-level).
102 InstKind::GetElementPtr(base, indices) if *base == alloca_id => {
103 if indices.len() != 1 {
104 return false;
105 }
106 if resolve_const_int(func, indices[0]).is_none() {
107 return false;
108 }
109 // Reject byte-level GEPs (ptr<i8>) — SROA only handles
110 // element-typed accesses. Byte-level GEPs from array
111 // constructors use raw byte offsets that don't map to
112 // element indices.
113 if let Some(IrType::Ptr(inner)) = func.value_type(inst.id) {
114 if matches!(*inner, IrType::Int(crate::ir::types::IntWidth::I8)) {
115 // Result is ptr<i8> — byte-level access, not element.
116 return false;
117 }
118 }
119 if gep_result_escapes(func, inst.id) {
120 return false;
121 }
122 // This use is fine — constant-index element access.
123 }
124 // Store where the alloca is the VALUE being stored = pointer escape.
125 InstKind::Store(val, _) if *val == alloca_id => {
126 return false;
127 }
128 // Call/RuntimeCall with the alloca as an argument = escape.
129 InstKind::Call(_, args) if args.contains(&alloca_id) => {
130 return false;
131 }
132 InstKind::RuntimeCall(_, args) if args.contains(&alloca_id) => {
133 return false;
134 }
135 // Any other instruction that uses the alloca = ineligible.
136 // (Includes: direct store TO the aggregate base, arithmetic on the pointer, etc.)
137 _ => {
138 return false;
139 }
140 }
141 }
142 }
143 true
144 }
145
146 fn gep_result_escapes(func: &Function, gep_id: ValueId) -> bool {
147 for block in &func.blocks {
148 for inst in &block.insts {
149 let uses = inst_uses(&inst.kind);
150 if !uses.contains(&gep_id) {
151 continue;
152 }
153 match &inst.kind {
154 InstKind::Load(ptr) if *ptr == gep_id => {}
155 InstKind::Store(_, ptr) if *ptr == gep_id => {}
156 _ => return true,
157 }
158 }
159 if let Some(term) = &block.terminator {
160 match term {
161 Terminator::Return(Some(v)) if *v == gep_id => return true,
162 Terminator::Branch(_, args) if args.contains(&gep_id) => return true,
163 Terminator::CondBranch {
164 cond,
165 true_args,
166 false_args,
167 ..
168 } if *cond == gep_id
169 || true_args.contains(&gep_id)
170 || false_args.contains(&gep_id) =>
171 {
172 return true;
173 }
174 Terminator::Switch { selector, .. } if *selector == gep_id => {
175 return true;
176 }
177 _ => {}
178 }
179 }
180 }
181 false
182 }
183
184 /// Decompose one alloca into individual scalar allocas.
185 fn decompose_alloca(func: &mut Function, cand: &SroaCandidate) -> bool {
186 // Create N individual allocas, one per field.
187 let mut field_allocas: Vec<ValueId> = Vec::new();
188 let span = func.block(cand.alloca_block).insts[cand.alloca_inst_idx].span;
189
190 // Create field allocas and insert them right after the original alloca
191 // so they dominate all subsequent uses.
192 let insert_pos = cand.alloca_inst_idx + 1;
193 for i in 0..cand.count {
194 let new_id = func.next_value_id();
195 let ptr_ty = IrType::Ptr(Box::new(cand.elem_ty.clone()));
196 func.register_type(new_id, ptr_ty.clone());
197 func.block_mut(cand.alloca_block).insts.insert(
198 insert_pos + i as usize,
199 Inst {
200 id: new_id,
201 kind: InstKind::Alloca(cand.elem_ty.clone()),
202 ty: ptr_ty,
203 span,
204 },
205 );
206 field_allocas.push(new_id);
207 }
208
209 // Build a GEP→field_alloca mapping: for each GEP(alloca, [const_idx]),
210 // replace the GEP result with the corresponding field alloca.
211 let mut gep_to_field: HashMap<ValueId, ValueId> = HashMap::new();
212
213 for block in &func.blocks {
214 for inst in &block.insts {
215 if let InstKind::GetElementPtr(base, indices) = &inst.kind {
216 if *base == cand.alloca_id && indices.len() == 1 {
217 if let Some(idx) = resolve_const_int(func, indices[0]) {
218 if idx >= 0 && (idx as u64) < cand.count {
219 gep_to_field.insert(inst.id, field_allocas[idx as usize]);
220 }
221 }
222 }
223 }
224 }
225 }
226
227 if gep_to_field.is_empty() {
228 return false;
229 }
230
231 // Rewrite all uses of GEP results to use the field allocas.
232 for block in &mut func.blocks {
233 for inst in &mut block.insts {
234 // Replace operands that reference GEP results.
235 let mut new_kind = inst.kind.clone();
236 let mut replaced = false;
237 match &mut new_kind {
238 InstKind::Load(ptr) => {
239 if let Some(&field) = gep_to_field.get(ptr) {
240 *ptr = field;
241 replaced = true;
242 }
243 }
244 InstKind::Store(_, ptr) => {
245 if let Some(&field) = gep_to_field.get(ptr) {
246 *ptr = field;
247 replaced = true;
248 }
249 }
250 _ => {}
251 }
252 if replaced {
253 inst.kind = new_kind;
254 }
255 }
256 }
257
258 true
259 }
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::types::IntWidth;
    use crate::lexer::{Position, Span};
    use crate::opt::pass::Pass;

    /// Build a zeroed dummy span for synthetic test instructions.
    fn span() -> Span {
        let pos = Position { line: 0, col: 0 };
        Span {
            file_id: 0,
            start: pos,
            end: pos,
        }
    }

    /// A plain scalar alloca is not an aggregate, so SROA must leave it alone.
    #[test]
    fn sroa_no_op_on_scalars() {
        let mut m = Module::new("test".into());
        let mut f = Function::new("test".into(), vec![], IrType::Void);
        // Single i32 alloca — not an Array type, hence never a candidate.
        let a = f.next_value_id();
        f.register_type(a, IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))));
        f.block_mut(f.entry).insts.push(Inst {
            id: a,
            ty: IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
            span: span(),
            kind: InstKind::Alloca(IrType::Int(IntWidth::I32)),
        });
        f.block_mut(f.entry).terminator = Some(Terminator::Return(None));
        m.add_function(f);
        let pass = Sroa;
        assert!(!pass.run(&mut m), "scalar alloca should not be decomposed");
    }

    /// Storing a GEP result's address into memory is an escape: the field
    /// address could be read back and used after decomposition, so SROA
    /// must reject the whole alloca.
    #[test]
    fn sroa_rejects_gep_address_escape() {
        let mut m = Module::new("test".into());
        let mut f = Function::new("test".into(), vec![], IrType::Void);
        // [ptr<i8> x 2] aggregate — small enough to be an SROA candidate.
        let arr_ty = IrType::Array(
            Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I8)))),
            2,
        );
        let arr = f.next_value_id();
        f.register_type(arr, IrType::Ptr(Box::new(arr_ty.clone())));
        f.block_mut(f.entry).insts.push(Inst {
            id: arr,
            ty: IrType::Ptr(Box::new(arr_ty)),
            span: span(),
            kind: InstKind::Alloca(IrType::Array(
                Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I8)))),
                2,
            )),
        });

        // Constant index 0, so the GEP itself satisfies the const-index rule.
        let zero = f.next_value_id();
        f.register_type(zero, IrType::Int(IntWidth::I64));
        f.block_mut(f.entry).insts.push(Inst {
            id: zero,
            ty: IrType::Int(IntWidth::I64),
            span: span(),
            kind: InstKind::ConstInt(0, IntWidth::I64),
        });

        // gep = &arr[0], typed ptr<ptr<i8>> (element-typed, not byte-level).
        let gep = f.next_value_id();
        f.register_type(
            gep,
            IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I8))))),
        );
        f.block_mut(f.entry).insts.push(Inst {
            id: gep,
            ty: IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I8))))),
            span: span(),
            kind: InstKind::GetElementPtr(arr, vec![zero]),
        });

        // A separate slot to store the GEP address into.
        let sink = f.next_value_id();
        f.register_type(
            sink,
            IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Ptr(Box::new(
                IrType::Int(IntWidth::I8),
            )))))),
        );
        f.block_mut(f.entry).insts.push(Inst {
            id: sink,
            ty: IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Ptr(Box::new(
                IrType::Int(IntWidth::I8),
            )))))),
            span: span(),
            kind: InstKind::Alloca(IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Int(
                IntWidth::I8,
            )))))),
        });

        // Store the GEP result as a VALUE — its address escapes to memory.
        let escape_store = f.next_value_id();
        f.register_type(escape_store, IrType::Void);
        f.block_mut(f.entry).insts.push(Inst {
            id: escape_store,
            ty: IrType::Void,
            span: span(),
            kind: InstKind::Store(gep, sink),
        });

        f.block_mut(f.entry).terminator = Some(Terminator::Return(None));
        m.add_function(f);

        assert!(
            !Sroa.run(&mut m),
            "GEP addresses that escape should block SROA"
        );
    }
}
372