| 1 | //! Shared loop analysis utilities. |
| 2 | //! |
| 3 | //! Promotes preheader finding (from licm.rs) and constant resolution |
| 4 | //! (from unroll.rs) into public shared functions so all loop passes |
| 5 | //! use the same logic. |
| 6 | |
| 7 | use crate::ir::inst::*; |
| 8 | use crate::ir::walk::NaturalLoop; |
| 9 | use std::collections::{HashMap, HashSet}; |
| 10 | |
| 11 | /// Find the unique preheader for a natural loop, if one exists. |
| 12 | /// |
| 13 | /// Returns `Some(preheader_id)` only when: |
| 14 | /// * the header has exactly one predecessor outside the loop body, |
| 15 | /// * that predecessor's terminator is an unconditional `Branch` to |
| 16 | /// the header, and |
| 17 | /// * the preheader is not itself the header. |
| 18 | pub fn find_preheader( |
| 19 | func: &Function, |
| 20 | lp: &NaturalLoop, |
| 21 | preds: &HashMap<BlockId, Vec<BlockId>>, |
| 22 | ) -> Option<BlockId> { |
| 23 | let header_preds = preds.get(&lp.header)?; |
| 24 | let mut outside: Vec<BlockId> = header_preds |
| 25 | .iter() |
| 26 | .copied() |
| 27 | .filter(|p| !lp.body.contains(p)) |
| 28 | .collect(); |
| 29 | outside.sort_by_key(|b| b.0); |
| 30 | outside.dedup(); |
| 31 | if outside.len() != 1 { |
| 32 | return None; |
| 33 | } |
| 34 | let ph = outside[0]; |
| 35 | if ph == lp.header { |
| 36 | return None; |
| 37 | } |
| 38 | let ph_block = func.block(ph); |
| 39 | match &ph_block.terminator { |
| 40 | Some(Terminator::Branch(dest, _)) if *dest == lp.header => Some(ph), |
| 41 | _ => None, |
| 42 | } |
| 43 | } |
| 44 | |
| 45 | /// Resolve a ValueId to a compile-time constant integer, if it was |
| 46 | /// produced by a `ConstInt` instruction. |
| 47 | pub fn resolve_const_int(func: &Function, vid: ValueId) -> Option<i64> { |
| 48 | for block in &func.blocks { |
| 49 | for inst in &block.insts { |
| 50 | if inst.id == vid { |
| 51 | if let InstKind::ConstInt(v, _) = inst.kind { |
| 52 | return i64::try_from(v).ok(); |
| 53 | } |
| 54 | return None; |
| 55 | } |
| 56 | } |
| 57 | } |
| 58 | None |
| 59 | } |
| 60 | |
| 61 | /// Collect all ValueIds defined inside a loop body (block params + |
| 62 | /// instruction results). |
| 63 | pub fn loop_defined_values(func: &Function, lp: &NaturalLoop) -> HashSet<ValueId> { |
| 64 | let mut defs = HashSet::new(); |
| 65 | for &bid in &lp.body { |
| 66 | let block = func.block(bid); |
| 67 | for bp in &block.params { |
| 68 | defs.insert(bp.id); |
| 69 | } |
| 70 | for inst in &block.insts { |
| 71 | defs.insert(inst.id); |
| 72 | } |
| 73 | } |
| 74 | defs |
| 75 | } |
| 76 | |
| 77 | // --------------------------------------------------------------------------- |
| 78 | // Loop cloning and value remapping (shared by unswitching, fission, fusion) |
| 79 | // --------------------------------------------------------------------------- |
| 80 | |
| 81 | /// Clone all blocks in a loop body, returning a block-ID mapping |
| 82 | /// (old → new) and the list of new block IDs. |
| 83 | pub fn clone_loop( |
| 84 | func: &mut Function, |
| 85 | lp: &NaturalLoop, |
| 86 | ) -> (HashMap<BlockId, BlockId>, Vec<BlockId>) { |
| 87 | let mut block_map: HashMap<BlockId, BlockId> = HashMap::new(); |
| 88 | let mut new_blocks = Vec::new(); |
| 89 | |
| 90 | let body_sorted: Vec<BlockId> = { |
| 91 | let mut v: Vec<BlockId> = lp.body.iter().copied().collect(); |
| 92 | v.sort_by_key(|b| b.0); |
| 93 | v |
| 94 | }; |
| 95 | for &old_id in &body_sorted { |
| 96 | let old_name = &func.block(old_id).name; |
| 97 | let new_id = func.create_block(&format!("{}_clone", old_name)); |
| 98 | block_map.insert(old_id, new_id); |
| 99 | new_blocks.push(new_id); |
| 100 | } |
| 101 | |
| 102 | let mut val_map: HashMap<ValueId, ValueId> = HashMap::new(); |
| 103 | |
| 104 | for &old_id in &body_sorted { |
| 105 | let new_id = block_map[&old_id]; |
| 106 | let old_params: Vec<BlockParam> = func.block(old_id).params.clone(); |
| 107 | for bp in &old_params { |
| 108 | let new_vid = func.next_value_id(); |
| 109 | func.register_type(new_vid, bp.ty.clone()); |
| 110 | val_map.insert(bp.id, new_vid); |
| 111 | func.block_mut(new_id).params.push(BlockParam { |
| 112 | id: new_vid, |
| 113 | ty: bp.ty.clone(), |
| 114 | }); |
| 115 | } |
| 116 | } |
| 117 | |
| 118 | for &old_id in &body_sorted { |
| 119 | let old_insts: Vec<Inst> = func.block(old_id).insts.clone(); |
| 120 | for inst in &old_insts { |
| 121 | let new_vid = func.next_value_id(); |
| 122 | func.register_type(new_vid, inst.ty.clone()); |
| 123 | val_map.insert(inst.id, new_vid); |
| 124 | } |
| 125 | } |
| 126 | |
| 127 | for &old_id in &body_sorted { |
| 128 | let new_id = block_map[&old_id]; |
| 129 | let old_insts: Vec<Inst> = func.block(old_id).insts.clone(); |
| 130 | for inst in &old_insts { |
| 131 | let new_vid = val_map[&inst.id]; |
| 132 | let new_kind = remap_inst_kind(&inst.kind, &val_map); |
| 133 | func.block_mut(new_id).insts.push(Inst { |
| 134 | id: new_vid, |
| 135 | kind: new_kind, |
| 136 | ty: inst.ty.clone(), |
| 137 | span: inst.span, |
| 138 | }); |
| 139 | } |
| 140 | } |
| 141 | |
| 142 | for &old_id in &body_sorted { |
| 143 | let new_id = block_map[&old_id]; |
| 144 | let old_term = func.block(old_id).terminator.clone(); |
| 145 | if let Some(term) = old_term { |
| 146 | let new_term = remap_terminator(&term, &block_map, &val_map); |
| 147 | func.block_mut(new_id).terminator = Some(new_term); |
| 148 | } |
| 149 | } |
| 150 | |
| 151 | (block_map, new_blocks) |
| 152 | } |
| 153 | |
// NOTE(review): non-test items (`build_value_map`, `remap_inst_kind`,
// `remap_terminator`) follow this module in the file; clippy's
// `items_after_test_module` lint flags that layout. Consider moving this
// module to the end of the file.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::builder::FuncBuilder;
    use crate::ir::types::{IntWidth, IrType};

    /// Regression test: a cloned block that uses a value defined in a
    /// block cloned *after* it must reference the cloned definition, not
    /// the original one.
    ///
    /// Control flow built below:
    /// `(initial block) -> preheader -> header -> def -> use -> latch -> header`,
    /// with `header` also branching to `exit`.
    #[test]
    fn clone_loop_remaps_values_defined_in_later_blocks() {
        let mut func = Function::new("test".into(), vec![], IrType::Void);
        let mut b = FuncBuilder::new(&mut func);

        // `use` is created before `def` on purpose. `clone_loop` visits
        // blocks sorted by block id, so (presumably, if the builder hands
        // out ascending ids — TODO confirm) the using block is processed
        // before the defining block.
        let preheader = b.create_block("preheader");
        let header = b.create_block("header");
        let use_block = b.create_block("use");
        let def_block = b.create_block("def");
        let latch = b.create_block("latch");
        let exit = b.create_block("exit");

        // Constants are emitted before any `set_block` call, i.e. into
        // whatever block the builder starts in (outside the loop body).
        let one = b.const_i32(1);
        let ten = b.const_i32(10);
        b.branch(preheader, vec![]);

        b.set_block(preheader);
        b.branch(header, vec![one]);

        // Loop header: carries the induction variable as a block parameter.
        let iv = b.add_block_param(header, IrType::Int(IntWidth::I32));
        b.set_block(header);
        let keep_looping = b.icmp(CmpOp::Le, iv, ten);
        b.cond_branch(keep_looping, def_block, vec![], exit, vec![]);

        // Defining block: produces `carried`, consumed by `use_block`.
        b.set_block(def_block);
        let carried = b.iadd(iv, one);
        b.branch(use_block, vec![]);

        // Using block: `observed` reads `carried`, defined in `def_block`.
        b.set_block(use_block);
        let observed = b.iadd(carried, one);
        let spill = b.alloca(IrType::Int(IntWidth::I32));
        b.store(observed, spill);
        b.branch(latch, vec![]);

        // Latch: feeds the next induction value back to the header.
        b.set_block(latch);
        let nxt = b.iadd(iv, one);
        b.branch(header, vec![nxt]);

        b.set_block(exit);
        b.ret_void();

        // The loop description is assembled by hand rather than computed.
        let loop_body = [header, use_block, def_block, latch]
            .into_iter()
            .collect::<HashSet<_>>();
        let lp = NaturalLoop {
            header,
            body: loop_body,
            latches: vec![latch],
        };

        let (block_map, _) = clone_loop(&mut func, &lp);
        // Recover the old → new value mapping from the cloned blocks.
        let val_map = build_value_map(&func, &lp, &block_map);
        let cloned_use = func.block(block_map[&use_block]);

        let cloned_carried = val_map[&carried];
        let cloned_observed = val_map[&observed];
        let use_inst = cloned_use
            .insts
            .iter()
            .find(|inst| inst.id == cloned_observed)
            .expect("cloned use-block instruction");

        // The clone of `observed` must take the clone of `carried` as its
        // left operand.
        assert!(
            matches!(use_inst.kind, InstKind::IAdd(lhs, _) if lhs == cloned_carried),
            "cloned use block should reference the cloned later-block value, got {:?}",
            use_inst.kind,
        );
    }
}
| 229 | |
| 230 | /// Build a value map from original loop blocks → cloned loop blocks. |
| 231 | pub fn build_value_map( |
| 232 | func: &Function, |
| 233 | lp: &NaturalLoop, |
| 234 | block_map: &HashMap<BlockId, BlockId>, |
| 235 | ) -> HashMap<ValueId, ValueId> { |
| 236 | let mut val_map: HashMap<ValueId, ValueId> = HashMap::new(); |
| 237 | let body_sorted: Vec<BlockId> = { |
| 238 | let mut v: Vec<BlockId> = lp.body.iter().copied().collect(); |
| 239 | v.sort_by_key(|b| b.0); |
| 240 | v |
| 241 | }; |
| 242 | for &old_id in &body_sorted { |
| 243 | let new_id = block_map[&old_id]; |
| 244 | let old_block = func.block(old_id); |
| 245 | let new_block = func.block(new_id); |
| 246 | for (old_bp, new_bp) in old_block.params.iter().zip(new_block.params.iter()) { |
| 247 | val_map.insert(old_bp.id, new_bp.id); |
| 248 | } |
| 249 | for (old_inst, new_inst) in old_block.insts.iter().zip(new_block.insts.iter()) { |
| 250 | val_map.insert(old_inst.id, new_inst.id); |
| 251 | } |
| 252 | } |
| 253 | val_map |
| 254 | } |
| 255 | |
| 256 | /// Remap all ValueId operands in an InstKind. Values not in the map |
| 257 | /// are left unchanged (they're defined outside the cloned region). |
| 258 | pub fn remap_inst_kind(kind: &InstKind, map: &HashMap<ValueId, ValueId>) -> InstKind { |
| 259 | let r = |v: &ValueId| *map.get(v).unwrap_or(v); |
| 260 | match kind { |
| 261 | InstKind::ConstInt(v, w) => InstKind::ConstInt(*v, *w), |
| 262 | InstKind::ConstFloat(v, w) => InstKind::ConstFloat(*v, *w), |
| 263 | InstKind::ConstBool(v) => InstKind::ConstBool(*v), |
| 264 | InstKind::ConstString(v) => InstKind::ConstString(v.clone()), |
| 265 | InstKind::Undef(t) => InstKind::Undef(t.clone()), |
| 266 | InstKind::GlobalAddr(s) => InstKind::GlobalAddr(s.clone()), |
| 267 | InstKind::IAdd(a, b) => InstKind::IAdd(r(a), r(b)), |
| 268 | InstKind::ISub(a, b) => InstKind::ISub(r(a), r(b)), |
| 269 | InstKind::IMul(a, b) => InstKind::IMul(r(a), r(b)), |
| 270 | InstKind::IDiv(a, b) => InstKind::IDiv(r(a), r(b)), |
| 271 | InstKind::IMod(a, b) => InstKind::IMod(r(a), r(b)), |
| 272 | InstKind::INeg(a) => InstKind::INeg(r(a)), |
| 273 | InstKind::FAdd(a, b) => InstKind::FAdd(r(a), r(b)), |
| 274 | InstKind::FSub(a, b) => InstKind::FSub(r(a), r(b)), |
| 275 | InstKind::FMul(a, b) => InstKind::FMul(r(a), r(b)), |
| 276 | InstKind::FDiv(a, b) => InstKind::FDiv(r(a), r(b)), |
| 277 | InstKind::FNeg(a) => InstKind::FNeg(r(a)), |
| 278 | InstKind::FAbs(a) => InstKind::FAbs(r(a)), |
| 279 | InstKind::FSqrt(a) => InstKind::FSqrt(r(a)), |
| 280 | InstKind::FPow(a, b) => InstKind::FPow(r(a), r(b)), |
| 281 | InstKind::ICmp(op, a, b) => InstKind::ICmp(*op, r(a), r(b)), |
| 282 | InstKind::FCmp(op, a, b) => InstKind::FCmp(*op, r(a), r(b)), |
| 283 | InstKind::And(a, b) => InstKind::And(r(a), r(b)), |
| 284 | InstKind::Or(a, b) => InstKind::Or(r(a), r(b)), |
| 285 | InstKind::Not(a) => InstKind::Not(r(a)), |
| 286 | InstKind::Select(c, t, f) => InstKind::Select(r(c), r(t), r(f)), |
| 287 | InstKind::BitAnd(a, b) => InstKind::BitAnd(r(a), r(b)), |
| 288 | InstKind::BitOr(a, b) => InstKind::BitOr(r(a), r(b)), |
| 289 | InstKind::BitXor(a, b) => InstKind::BitXor(r(a), r(b)), |
| 290 | InstKind::BitNot(a) => InstKind::BitNot(r(a)), |
| 291 | InstKind::Shl(a, b) => InstKind::Shl(r(a), r(b)), |
| 292 | InstKind::LShr(a, b) => InstKind::LShr(r(a), r(b)), |
| 293 | InstKind::AShr(a, b) => InstKind::AShr(r(a), r(b)), |
| 294 | InstKind::CountLeadingZeros(a) => InstKind::CountLeadingZeros(r(a)), |
| 295 | InstKind::CountTrailingZeros(a) => InstKind::CountTrailingZeros(r(a)), |
| 296 | InstKind::PopCount(a) => InstKind::PopCount(r(a)), |
| 297 | InstKind::IntToFloat(a, w) => InstKind::IntToFloat(r(a), *w), |
| 298 | InstKind::FloatToInt(a, w) => InstKind::FloatToInt(r(a), *w), |
| 299 | InstKind::FloatExtend(a, w) => InstKind::FloatExtend(r(a), *w), |
| 300 | InstKind::FloatTrunc(a, w) => InstKind::FloatTrunc(r(a), *w), |
| 301 | InstKind::IntExtend(a, w, s) => InstKind::IntExtend(r(a), *w, *s), |
| 302 | InstKind::IntTrunc(a, w) => InstKind::IntTrunc(r(a), *w), |
| 303 | InstKind::PtrToInt(a) => InstKind::PtrToInt(r(a)), |
| 304 | InstKind::IntToPtr(a, ty) => InstKind::IntToPtr(r(a), ty.clone()), |
| 305 | InstKind::Alloca(t) => InstKind::Alloca(t.clone()), |
| 306 | InstKind::Load(a) => InstKind::Load(r(a)), |
| 307 | InstKind::Store(v, p) => InstKind::Store(r(v), r(p)), |
| 308 | InstKind::GetElementPtr(base, idxs) => { |
| 309 | InstKind::GetElementPtr(r(base), idxs.iter().map(&r).collect()) |
| 310 | } |
| 311 | InstKind::Call(f, args) => InstKind::Call(f.clone(), args.iter().map(&r).collect()), |
| 312 | InstKind::RuntimeCall(f, args) => { |
| 313 | InstKind::RuntimeCall(f.clone(), args.iter().map(&r).collect()) |
| 314 | } |
| 315 | InstKind::ExtractField(v, idx) => InstKind::ExtractField(r(v), *idx), |
| 316 | InstKind::InsertField(v, idx, fld) => InstKind::InsertField(r(v), *idx, r(fld)), |
| 317 | |
| 318 | // ---- SIMD vector ops ---- |
| 319 | InstKind::VAdd(a, b) => InstKind::VAdd(r(a), r(b)), |
| 320 | InstKind::VSub(a, b) => InstKind::VSub(r(a), r(b)), |
| 321 | InstKind::VMul(a, b) => InstKind::VMul(r(a), r(b)), |
| 322 | InstKind::VDiv(a, b) => InstKind::VDiv(r(a), r(b)), |
| 323 | InstKind::VNeg(a) => InstKind::VNeg(r(a)), |
| 324 | InstKind::VAbs(a) => InstKind::VAbs(r(a)), |
| 325 | InstKind::VSqrt(a) => InstKind::VSqrt(r(a)), |
| 326 | InstKind::VFma(a, b, c) => InstKind::VFma(r(a), r(b), r(c)), |
| 327 | InstKind::VSelect(m, t, f) => InstKind::VSelect(r(m), r(t), r(f)), |
| 328 | InstKind::VMin(a, b) => InstKind::VMin(r(a), r(b)), |
| 329 | InstKind::VMax(a, b) => InstKind::VMax(r(a), r(b)), |
| 330 | InstKind::VICmp(op, a, b) => InstKind::VICmp(*op, r(a), r(b)), |
| 331 | InstKind::VFCmp(op, a, b) => InstKind::VFCmp(*op, r(a), r(b)), |
| 332 | InstKind::VLoad(p) => InstKind::VLoad(r(p)), |
| 333 | InstKind::VStore(v, p) => InstKind::VStore(r(v), r(p)), |
| 334 | InstKind::VBitcast(v, ty) => InstKind::VBitcast(r(v), ty.clone()), |
| 335 | InstKind::VExtract(v, lane) => InstKind::VExtract(r(v), *lane), |
| 336 | InstKind::VInsert(v, lane, s) => InstKind::VInsert(r(v), *lane, r(s)), |
| 337 | InstKind::VBroadcast(s) => InstKind::VBroadcast(r(s)), |
| 338 | InstKind::VReduceSum(v) => InstKind::VReduceSum(r(v)), |
| 339 | InstKind::VReduceMin(v) => InstKind::VReduceMin(r(v)), |
| 340 | InstKind::VReduceMax(v) => InstKind::VReduceMax(r(v)), |
| 341 | } |
| 342 | } |
| 343 | |
| 344 | /// Remap block targets and value operands in a terminator. |
| 345 | pub fn remap_terminator( |
| 346 | term: &Terminator, |
| 347 | block_map: &HashMap<BlockId, BlockId>, |
| 348 | val_map: &HashMap<ValueId, ValueId>, |
| 349 | ) -> Terminator { |
| 350 | let rb = |b: &BlockId| *block_map.get(b).unwrap_or(b); |
| 351 | let rv = |v: &ValueId| *val_map.get(v).unwrap_or(v); |
| 352 | let rvs = |vs: &[ValueId]| -> Vec<ValueId> { vs.iter().map(&rv).collect() }; |
| 353 | match term { |
| 354 | Terminator::Return(v) => Terminator::Return(v.map(|x| rv(&x))), |
| 355 | Terminator::Branch(dest, args) => Terminator::Branch(rb(dest), rvs(args)), |
| 356 | Terminator::CondBranch { |
| 357 | cond, |
| 358 | true_dest, |
| 359 | true_args, |
| 360 | false_dest, |
| 361 | false_args, |
| 362 | } => Terminator::CondBranch { |
| 363 | cond: rv(cond), |
| 364 | true_dest: rb(true_dest), |
| 365 | true_args: rvs(true_args), |
| 366 | false_dest: rb(false_dest), |
| 367 | false_args: rvs(false_args), |
| 368 | }, |
| 369 | Terminator::Switch { |
| 370 | selector, |
| 371 | default, |
| 372 | cases, |
| 373 | } => Terminator::Switch { |
| 374 | selector: rv(selector), |
| 375 | default: rb(default), |
| 376 | cases: cases.iter().map(|(v, d)| (*v, rb(d))).collect(), |
| 377 | }, |
| 378 | Terminator::Unreachable => Terminator::Unreachable, |
| 379 | } |
| 380 | } |
| 381 |