@@ -106,6 +106,9 @@ fn is_eligible(func: &Function, alloca_id: ValueId) -> bool { |
| 106 | 106 | return false; |
| 107 | 107 | } |
| 108 | 108 | } |
| 109 | + if gep_result_escapes(func, inst.id) { |
| 110 | + return false; |
| 111 | + } |
| 109 | 112 | // This use is fine — constant-index element access. |
| 110 | 113 | } |
| 111 | 114 | // Store where the alloca is the VALUE being stored = pointer escape. |
@@ -130,6 +133,40 @@ fn is_eligible(func: &Function, alloca_id: ValueId) -> bool { |
| 130 | 133 | true |
| 131 | 134 | } |
| 132 | 135 | |
| 136 | +fn gep_result_escapes(func: &Function, gep_id: ValueId) -> bool { |
| 137 | + for block in &func.blocks { |
| 138 | + for inst in &block.insts { |
| 139 | + let uses = inst_uses(&inst.kind); |
| 140 | + if !uses.contains(&gep_id) { |
| 141 | + continue; |
| 142 | + } |
| 143 | + match &inst.kind { |
| 144 | + InstKind::Load(ptr) if *ptr == gep_id => {} |
| 145 | + InstKind::Store(_, ptr) if *ptr == gep_id => {} |
| 146 | + _ => return true, |
| 147 | + } |
| 148 | + } |
| 149 | + if let Some(term) = &block.terminator { |
| 150 | + match term { |
| 151 | + Terminator::Return(Some(v)) if *v == gep_id => return true, |
| 152 | + Terminator::Branch(_, args) if args.contains(&gep_id) => return true, |
| 153 | + Terminator::CondBranch { cond, true_args, false_args, .. } => { |
| 154 | + if *cond == gep_id || true_args.contains(&gep_id) || false_args.contains(&gep_id) { |
| 155 | + return true; |
| 156 | + } |
| 157 | + } |
| 158 | + Terminator::Switch { selector, .. } => { |
| 159 | + if *selector == gep_id { |
| 160 | + return true; |
| 161 | + } |
| 162 | + } |
| 163 | + _ => {} |
| 164 | + } |
| 165 | + } |
| 166 | + } |
| 167 | + false |
| 168 | +} |
| 169 | + |
| 133 | 170 | /// Decompose one alloca into individual scalar allocas. |
| 134 | 171 | fn decompose_alloca(func: &mut Function, cand: &SroaCandidate) -> bool { |
| 135 | 172 | // Create N individual allocas, one per field. |
@@ -229,4 +266,66 @@ mod tests { |
| 229 | 266 | let pass = Sroa; |
| 230 | 267 | assert!(!pass.run(&mut m), "scalar alloca should not be decomposed"); |
| 231 | 268 | } |
| 269 | + |
| 270 | + #[test] |
| 271 | + fn sroa_rejects_gep_address_escape() { |
| 272 | + let mut m = Module::new("test".into()); |
| 273 | + let mut f = Function::new("test".into(), vec![], IrType::Void); |
| 274 | + let arr_ty = IrType::Array( |
| 275 | + Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I8)))), |
| 276 | + 2, |
| 277 | + ); |
| 278 | + let arr = f.next_value_id(); |
| 279 | + f.register_type(arr, IrType::Ptr(Box::new(arr_ty.clone()))); |
| 280 | + f.block_mut(f.entry).insts.push(Inst { |
| 281 | + id: arr, |
| 282 | + ty: IrType::Ptr(Box::new(arr_ty)), |
| 283 | + span: span(), |
| 284 | + kind: InstKind::Alloca(IrType::Array( |
| 285 | + Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I8)))), |
| 286 | + 2, |
| 287 | + )), |
| 288 | + }); |
| 289 | + |
| 290 | + let zero = f.next_value_id(); |
| 291 | + f.register_type(zero, IrType::Int(IntWidth::I64)); |
| 292 | + f.block_mut(f.entry).insts.push(Inst { |
| 293 | + id: zero, |
| 294 | + ty: IrType::Int(IntWidth::I64), |
| 295 | + span: span(), |
| 296 | + kind: InstKind::ConstInt(0, IntWidth::I64), |
| 297 | + }); |
| 298 | + |
| 299 | + let gep = f.next_value_id(); |
| 300 | + f.register_type(gep, IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I8)))))); |
| 301 | + f.block_mut(f.entry).insts.push(Inst { |
| 302 | + id: gep, |
| 303 | + ty: IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I8))))), |
| 304 | + span: span(), |
| 305 | + kind: InstKind::GetElementPtr(arr, vec![zero]), |
| 306 | + }); |
| 307 | + |
| 308 | + let sink = f.next_value_id(); |
| 309 | + f.register_type(sink, IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I8)))))))); |
| 310 | + f.block_mut(f.entry).insts.push(Inst { |
| 311 | + id: sink, |
| 312 | + ty: IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I8))))))), |
| 313 | + span: span(), |
| 314 | + kind: InstKind::Alloca(IrType::Ptr(Box::new(IrType::Ptr(Box::new(IrType::Int(IntWidth::I8)))))), |
| 315 | + }); |
| 316 | + |
| 317 | + let escape_store = f.next_value_id(); |
| 318 | + f.register_type(escape_store, IrType::Void); |
| 319 | + f.block_mut(f.entry).insts.push(Inst { |
| 320 | + id: escape_store, |
| 321 | + ty: IrType::Void, |
| 322 | + span: span(), |
| 323 | + kind: InstKind::Store(gep, sink), |
| 324 | + }); |
| 325 | + |
| 326 | + f.block_mut(f.entry).terminator = Some(Terminator::Return(None)); |
| 327 | + m.add_function(f); |
| 328 | + |
| 329 | + assert!(!Sroa.run(&mut m), "GEP addresses that escape should block SROA"); |
| 330 | + } |
| 232 | 331 | } |