| 1 | //! Scalar Replacement of Aggregates (SROA). |
| 2 | //! |
| 3 | //! Decomposes small aggregate allocas (`alloca [T x N]` where N ≤ 8) |
| 4 | //! into N individual scalar allocas. After SROA, mem2reg can promote |
| 5 | //! the scalars to SSA values in registers. |
| 6 | //! |
| 7 | //! Eligibility: |
| 8 | //! - Alloca type is `Array(elem, count)` with count ≤ SROA_MAX_FIELDS |
| 9 | //! - ALL uses of the alloca are GEP with a single constant index |
| 10 | //! - The alloca address is never passed to a Call (no escape) |
| 11 | //! |
| 12 | //! After SROA, a second Mem2Reg pass promotes the new scalar allocas. |
| 13 | |
| 14 | use super::loop_utils::resolve_const_int; |
| 15 | use super::pass::Pass; |
| 16 | use crate::ir::inst::*; |
| 17 | use crate::ir::types::IrType; |
| 18 | use crate::ir::walk::inst_uses; |
| 19 | use std::collections::HashMap; |
| 20 | |
| 21 | const SROA_MAX_FIELDS: u64 = 8; |
| 22 | |
| 23 | pub struct Sroa; |
| 24 | |
| 25 | impl Pass for Sroa { |
| 26 | fn name(&self) -> &'static str { |
| 27 | "sroa" |
| 28 | } |
| 29 | |
| 30 | fn run(&self, module: &mut Module) -> bool { |
| 31 | let mut changed = false; |
| 32 | for func in &mut module.functions { |
| 33 | if sroa_function(func) { |
| 34 | changed = true; |
| 35 | } |
| 36 | } |
| 37 | changed |
| 38 | } |
| 39 | } |
| 40 | |
| 41 | fn sroa_function(func: &mut Function) -> bool { |
| 42 | // Collect candidate allocas: Array type, small, all constant-index GEPs. |
| 43 | let candidates = find_candidates(func); |
| 44 | if candidates.is_empty() { |
| 45 | return false; |
| 46 | } |
| 47 | |
| 48 | let mut changed = false; |
| 49 | for cand in candidates { |
| 50 | if decompose_alloca(func, &cand) { |
| 51 | changed = true; |
| 52 | } |
| 53 | } |
| 54 | changed |
| 55 | } |
| 56 | |
| 57 | struct SroaCandidate { |
| 58 | alloca_id: ValueId, |
| 59 | alloca_block: BlockId, |
| 60 | alloca_inst_idx: usize, |
| 61 | elem_ty: IrType, |
| 62 | count: u64, |
| 63 | } |
| 64 | |
| 65 | fn find_candidates(func: &Function) -> Vec<SroaCandidate> { |
| 66 | let mut candidates = Vec::new(); |
| 67 | |
| 68 | for block in &func.blocks { |
| 69 | for (inst_idx, inst) in block.insts.iter().enumerate() { |
| 70 | if let InstKind::Alloca(IrType::Array(ref elem, count)) = inst.kind { |
| 71 | if count <= SROA_MAX_FIELDS && count > 0 { |
| 72 | // Check eligibility: all uses are constant-index GEPs, no escape. |
| 73 | if is_eligible(func, inst.id) { |
| 74 | candidates.push(SroaCandidate { |
| 75 | alloca_id: inst.id, |
| 76 | alloca_block: block.id, |
| 77 | alloca_inst_idx: inst_idx, |
| 78 | elem_ty: (**elem).clone(), |
| 79 | count, |
| 80 | }); |
| 81 | } |
| 82 | } |
| 83 | } |
| 84 | } |
| 85 | } |
| 86 | candidates |
| 87 | } |
| 88 | |
| 89 | /// Check if all uses of an alloca are constant-index GEPs with no escape. |
| 90 | fn is_eligible(func: &Function, alloca_id: ValueId) -> bool { |
| 91 | for block in &func.blocks { |
| 92 | for inst in &block.insts { |
| 93 | let uses = inst_uses(&inst.kind); |
| 94 | if !uses.contains(&alloca_id) { |
| 95 | continue; |
| 96 | } |
| 97 | |
| 98 | // Classify this use of the alloca. |
| 99 | match &inst.kind { |
| 100 | // GEP with the alloca as base: OK if single constant index |
| 101 | // AND the result type matches the element type (not byte-level). |
| 102 | InstKind::GetElementPtr(base, indices) if *base == alloca_id => { |
| 103 | if indices.len() != 1 { |
| 104 | return false; |
| 105 | } |
| 106 | if resolve_const_int(func, indices[0]).is_none() { |
| 107 | return false; |
| 108 | } |
| 109 | // Reject byte-level GEPs (ptr<i8>) — SROA only handles |
| 110 | // element-typed accesses. Byte-level GEPs from array |
| 111 | // constructors use raw byte offsets that don't map to |
| 112 | // element indices. |
| 113 | if let Some(IrType::Ptr(inner)) = func.value_type(inst.id) { |
| 114 | if matches!(*inner, IrType::Int(crate::ir::types::IntWidth::I8)) { |
| 115 | // Result is ptr<i8> — byte-level access, not element. |
| 116 | return false; |
| 117 | } |
| 118 | } |
| 119 | if gep_result_escapes(func, inst.id) { |
| 120 | return false; |
| 121 | } |
| 122 | // This use is fine — constant-index element access. |
| 123 | } |
| 124 | // Store where the alloca is the VALUE being stored = pointer escape. |
| 125 | InstKind::Store(val, _) if *val == alloca_id => { |
| 126 | return false; |
| 127 | } |
| 128 | // Call/RuntimeCall with the alloca as an argument = escape. |
| 129 | InstKind::Call(_, args) if args.contains(&alloca_id) => { |
| 130 | return false; |
| 131 | } |
| 132 | InstKind::RuntimeCall(_, args) if args.contains(&alloca_id) => { |
| 133 | return false; |
| 134 | } |
| 135 | // Any other instruction that uses the alloca = ineligible. |
| 136 | // (Includes: direct store TO the aggregate base, arithmetic on the pointer, etc.) |
| 137 | _ => { |
| 138 | return false; |
| 139 | } |
| 140 | } |
| 141 | } |
| 142 | } |
| 143 | true |
| 144 | } |
| 145 | |
| 146 | fn gep_result_escapes(func: &Function, gep_id: ValueId) -> bool { |
| 147 | for block in &func.blocks { |
| 148 | for inst in &block.insts { |
| 149 | let uses = inst_uses(&inst.kind); |
| 150 | if !uses.contains(&gep_id) { |
| 151 | continue; |
| 152 | } |
| 153 | match &inst.kind { |
| 154 | InstKind::Load(ptr) if *ptr == gep_id => {} |
| 155 | InstKind::Store(_, ptr) if *ptr == gep_id => {} |
| 156 | _ => return true, |
| 157 | } |
| 158 | } |
| 159 | if let Some(term) = &block.terminator { |
| 160 | match term { |
| 161 | Terminator::Return(Some(v)) if *v == gep_id => return true, |
| 162 | Terminator::Branch(_, args) if args.contains(&gep_id) => return true, |
| 163 | Terminator::CondBranch { |
| 164 | cond, |
| 165 | true_args, |
| 166 | false_args, |
| 167 | .. |
| 168 | } if *cond == gep_id |
| 169 | || true_args.contains(&gep_id) |
| 170 | || false_args.contains(&gep_id) => |
| 171 | { |
| 172 | return true; |
| 173 | } |
| 174 | Terminator::Switch { selector, .. } if *selector == gep_id => { |
| 175 | return true; |
| 176 | } |
| 177 | _ => {} |
| 178 | } |
| 179 | } |
| 180 | } |
| 181 | false |
| 182 | } |
| 183 | |
| 184 | /// Decompose one alloca into individual scalar allocas. |
| 185 | fn decompose_alloca(func: &mut Function, cand: &SroaCandidate) -> bool { |
| 186 | // Create N individual allocas, one per field. |
| 187 | let mut field_allocas: Vec<ValueId> = Vec::new(); |
| 188 | let span = func.block(cand.alloca_block).insts[cand.alloca_inst_idx].span; |
| 189 | |
| 190 | // Create field allocas and insert them right after the original alloca |
| 191 | // so they dominate all subsequent uses. |
| 192 | let insert_pos = cand.alloca_inst_idx + 1; |
| 193 | for i in 0..cand.count { |
| 194 | let new_id = func.next_value_id(); |
| 195 | let ptr_ty = IrType::Ptr(Box::new(cand.elem_ty.clone())); |
| 196 | func.register_type(new_id, ptr_ty.clone()); |
| 197 | func.block_mut(cand.alloca_block).insts.insert( |
| 198 | insert_pos + i as usize, |
| 199 | Inst { |
| 200 | id: new_id, |
| 201 | kind: InstKind::Alloca(cand.elem_ty.clone()), |
| 202 | ty: ptr_ty, |
| 203 | span, |
| 204 | }, |
| 205 | ); |
| 206 | field_allocas.push(new_id); |
| 207 | } |
| 208 | |
| 209 | // Build a GEP→field_alloca mapping: for each GEP(alloca, [const_idx]), |
| 210 | // replace the GEP result with the corresponding field alloca. |
| 211 | let mut gep_to_field: HashMap<ValueId, ValueId> = HashMap::new(); |
| 212 | |
| 213 | for block in &func.blocks { |
| 214 | for inst in &block.insts { |
| 215 | if let InstKind::GetElementPtr(base, indices) = &inst.kind { |
| 216 | if *base == cand.alloca_id && indices.len() == 1 { |
| 217 | if let Some(idx) = resolve_const_int(func, indices[0]) { |
| 218 | if idx >= 0 && (idx as u64) < cand.count { |
| 219 | gep_to_field.insert(inst.id, field_allocas[idx as usize]); |
| 220 | } |
| 221 | } |
| 222 | } |
| 223 | } |
| 224 | } |
| 225 | } |
| 226 | |
| 227 | if gep_to_field.is_empty() { |
| 228 | return false; |
| 229 | } |
| 230 | |
| 231 | // Rewrite all uses of GEP results to use the field allocas. |
| 232 | for block in &mut func.blocks { |
| 233 | for inst in &mut block.insts { |
| 234 | // Replace operands that reference GEP results. |
| 235 | let mut new_kind = inst.kind.clone(); |
| 236 | let mut replaced = false; |
| 237 | match &mut new_kind { |
| 238 | InstKind::Load(ptr) => { |
| 239 | if let Some(&field) = gep_to_field.get(ptr) { |
| 240 | *ptr = field; |
| 241 | replaced = true; |
| 242 | } |
| 243 | } |
| 244 | InstKind::Store(_, ptr) => { |
| 245 | if let Some(&field) = gep_to_field.get(ptr) { |
| 246 | *ptr = field; |
| 247 | replaced = true; |
| 248 | } |
| 249 | } |
| 250 | _ => {} |
| 251 | } |
| 252 | if replaced { |
| 253 | inst.kind = new_kind; |
| 254 | } |
| 255 | } |
| 256 | } |
| 257 | |
| 258 | true |
| 259 | } |
| 260 | |
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::types::IntWidth;
    use crate::lexer::{Position, Span};
    use crate::opt::pass::Pass;

    /// Placeholder span (file 0, line 0, column 0) for synthetic instructions.
    fn span() -> Span {
        let origin = Position { line: 0, col: 0 };
        Span {
            file_id: 0,
            start: origin,
            end: origin,
        }
    }

    /// Builds the pointer type `ptr<inner>`.
    fn ptr_to(inner: IrType) -> IrType {
        IrType::Ptr(Box::new(inner))
    }

    /// Allocates a fresh value id, registers `ty` for it and appends the
    /// instruction to the entry block. Returns the new id.
    fn emit(f: &mut Function, ty: IrType, kind: InstKind) -> ValueId {
        let id = f.next_value_id();
        f.register_type(id, ty.clone());
        let entry = f.entry;
        f.block_mut(entry).insts.push(Inst {
            id,
            ty,
            span: span(),
            kind,
        });
        id
    }

    /// Terminates the entry block with a void return.
    fn finish_with_return(f: &mut Function) {
        let entry = f.entry;
        f.block_mut(entry).terminator = Some(Terminator::Return(None));
    }

    #[test]
    fn sroa_no_op_on_scalars() {
        let mut m = Module::new("test".into());
        let mut f = Function::new("test".into(), vec![], IrType::Void);

        // A lone `alloca i32` is not an aggregate.
        let i32_ty = IrType::Int(IntWidth::I32);
        emit(&mut f, ptr_to(i32_ty.clone()), InstKind::Alloca(i32_ty));

        finish_with_return(&mut f);
        m.add_function(f);

        assert!(!Sroa.run(&mut m), "scalar alloca should not be decomposed");
    }

    #[test]
    fn sroa_rejects_gep_address_escape() {
        let mut m = Module::new("test".into());
        let mut f = Function::new("test".into(), vec![], IrType::Void);

        // arr = alloca [ptr<i8> x 2]
        let byte_ptr = ptr_to(IrType::Int(IntWidth::I8));
        let arr_ty = IrType::Array(Box::new(byte_ptr.clone()), 2);
        let arr = emit(&mut f, ptr_to(arr_ty.clone()), InstKind::Alloca(arr_ty));

        // zero = const i64 0
        let zero = emit(
            &mut f,
            IrType::Int(IntWidth::I64),
            InstKind::ConstInt(0, IntWidth::I64),
        );

        // gep = &arr[0]
        let gep = emit(
            &mut f,
            ptr_to(byte_ptr.clone()),
            InstKind::GetElementPtr(arr, vec![zero]),
        );

        // sink = alloca ptr<ptr<i8>>
        let sink_slot_ty = ptr_to(byte_ptr);
        let sink = emit(
            &mut f,
            ptr_to(sink_slot_ty.clone()),
            InstKind::Alloca(sink_slot_ty),
        );

        // store gep -> sink: the element address itself is written to memory.
        emit(&mut f, IrType::Void, InstKind::Store(gep, sink));

        finish_with_return(&mut f);
        m.add_function(f);

        assert!(
            !Sroa.run(&mut m),
            "GEP addresses that escape should block SROA"
        );
    }
}
| 372 |