Rust · 14779 bytes Raw Blame History
1 //! Shared loop analysis utilities.
2 //!
3 //! Promotes preheader finding (from licm.rs) and constant resolution
4 //! (from unroll.rs) into public shared functions so all loop passes
5 //! use the same logic.
6
7 use crate::ir::inst::*;
8 use crate::ir::walk::NaturalLoop;
9 use std::collections::{HashMap, HashSet};
10
11 /// Find the unique preheader for a natural loop, if one exists.
12 ///
13 /// Returns `Some(preheader_id)` only when:
14 /// * the header has exactly one predecessor outside the loop body,
15 /// * that predecessor's terminator is an unconditional `Branch` to
16 /// the header, and
17 /// * the preheader is not itself the header.
18 pub fn find_preheader(
19 func: &Function,
20 lp: &NaturalLoop,
21 preds: &HashMap<BlockId, Vec<BlockId>>,
22 ) -> Option<BlockId> {
23 let header_preds = preds.get(&lp.header)?;
24 let mut outside: Vec<BlockId> = header_preds
25 .iter()
26 .copied()
27 .filter(|p| !lp.body.contains(p))
28 .collect();
29 outside.sort_by_key(|b| b.0);
30 outside.dedup();
31 if outside.len() != 1 {
32 return None;
33 }
34 let ph = outside[0];
35 if ph == lp.header {
36 return None;
37 }
38 let ph_block = func.block(ph);
39 match &ph_block.terminator {
40 Some(Terminator::Branch(dest, _)) if *dest == lp.header => Some(ph),
41 _ => None,
42 }
43 }
44
45 /// Resolve a ValueId to a compile-time constant integer, if it was
46 /// produced by a `ConstInt` instruction.
47 pub fn resolve_const_int(func: &Function, vid: ValueId) -> Option<i64> {
48 for block in &func.blocks {
49 for inst in &block.insts {
50 if inst.id == vid {
51 if let InstKind::ConstInt(v, _) = inst.kind {
52 return i64::try_from(v).ok();
53 }
54 return None;
55 }
56 }
57 }
58 None
59 }
60
61 /// Collect all ValueIds defined inside a loop body (block params +
62 /// instruction results).
63 pub fn loop_defined_values(func: &Function, lp: &NaturalLoop) -> HashSet<ValueId> {
64 let mut defs = HashSet::new();
65 for &bid in &lp.body {
66 let block = func.block(bid);
67 for bp in &block.params {
68 defs.insert(bp.id);
69 }
70 for inst in &block.insts {
71 defs.insert(inst.id);
72 }
73 }
74 defs
75 }
76
77 // ---------------------------------------------------------------------------
78 // Loop cloning and value remapping (shared by unswitching, fission, fusion)
79 // ---------------------------------------------------------------------------
80
/// Clone all blocks in a loop body, returning a block-ID mapping
/// (old → new) and the list of new block IDs.
///
/// Cloning runs in four passes over the body, sorted by block id so
/// fresh block/value IDs are allocated deterministically:
/// 1. create an empty `<name>_clone` block for every body block,
/// 2. clone block params, allocating a fresh ValueId for each,
/// 3. pre-allocate fresh ValueIds for *every* instruction result, so
///    pass 4 can remap uses of values defined in later blocks,
/// 4. clone instructions and terminators, rewriting value operands via
///    the value map and branch targets via the block map.
///
/// Values and blocks defined outside the loop are not in the maps and
/// keep their original IDs in the clone.
pub fn clone_loop(
    func: &mut Function,
    lp: &NaturalLoop,
) -> (HashMap<BlockId, BlockId>, Vec<BlockId>) {
    let mut block_map: HashMap<BlockId, BlockId> = HashMap::new();
    let mut new_blocks = Vec::new();

    // Sort for a deterministic cloning order (lp.body is a HashSet).
    let body_sorted: Vec<BlockId> = {
        let mut v: Vec<BlockId> = lp.body.iter().copied().collect();
        v.sort_by_key(|b| b.0);
        v
    };
    // Pass 1: create the (still empty) clone blocks.
    for &old_id in &body_sorted {
        let old_name = &func.block(old_id).name;
        let new_id = func.create_block(&format!("{}_clone", old_name));
        block_map.insert(old_id, new_id);
        new_blocks.push(new_id);
    }

    let mut val_map: HashMap<ValueId, ValueId> = HashMap::new();

    // Pass 2: clone block params with fresh ValueIds.
    // (params cloned to a Vec first to end the borrow before mutating func)
    for &old_id in &body_sorted {
        let new_id = block_map[&old_id];
        let old_params: Vec<BlockParam> = func.block(old_id).params.clone();
        for bp in &old_params {
            let new_vid = func.next_value_id();
            func.register_type(new_vid, bp.ty.clone());
            val_map.insert(bp.id, new_vid);
            func.block_mut(new_id).params.push(BlockParam {
                id: new_vid,
                ty: bp.ty.clone(),
            });
        }
    }

    // Pass 3: register result IDs for all instructions up front, before
    // any operand remapping, so a use of a value defined in a block that
    // is cloned later still finds its cloned ID in val_map.
    for &old_id in &body_sorted {
        let old_insts: Vec<Inst> = func.block(old_id).insts.clone();
        for inst in &old_insts {
            let new_vid = func.next_value_id();
            func.register_type(new_vid, inst.ty.clone());
            val_map.insert(inst.id, new_vid);
        }
    }

    // Pass 4a: clone instructions, remapping operands through val_map.
    for &old_id in &body_sorted {
        let new_id = block_map[&old_id];
        let old_insts: Vec<Inst> = func.block(old_id).insts.clone();
        for inst in &old_insts {
            let new_vid = val_map[&inst.id];
            let new_kind = remap_inst_kind(&inst.kind, &val_map);
            func.block_mut(new_id).insts.push(Inst {
                id: new_vid,
                kind: new_kind,
                ty: inst.ty.clone(),
                span: inst.span,
            });
        }
    }

    // Pass 4b: clone terminators, remapping branch targets and operands.
    for &old_id in &body_sorted {
        let new_id = block_map[&old_id];
        let old_term = func.block(old_id).terminator.clone();
        if let Some(term) = old_term {
            let new_term = remap_terminator(&term, &block_map, &val_map);
            func.block_mut(new_id).terminator = Some(new_term);
        }
    }

    (block_map, new_blocks)
}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::builder::FuncBuilder;
    use crate::ir::types::{IntWidth, IrType};

    // Regression test: `clone_loop` must pre-register ValueIds for all
    // instructions before remapping any operands. The CFG below makes
    // `use_block` (created — and therefore cloned — before `def_block`)
    // consume a value defined in `def_block`, so a single-pass clone
    // would leave the use pointing at the *original* value.
    #[test]
    fn clone_loop_remaps_values_defined_in_later_blocks() {
        let mut func = Function::new("test".into(), vec![], IrType::Void);
        let mut b = FuncBuilder::new(&mut func);

        // Block creation order fixes block IDs, which in turn fixes the
        // sorted cloning order inside clone_loop.
        let preheader = b.create_block("preheader");
        let header = b.create_block("header");
        let use_block = b.create_block("use");
        let def_block = b.create_block("def");
        let latch = b.create_block("latch");
        let exit = b.create_block("exit");

        let one = b.const_i32(1);
        let ten = b.const_i32(10);
        b.branch(preheader, vec![]);

        b.set_block(preheader);
        b.branch(header, vec![one]);

        // Loop-carried induction variable arrives as a header param.
        let iv = b.add_block_param(header, IrType::Int(IntWidth::I32));
        b.set_block(header);
        let keep_looping = b.icmp(CmpOp::Le, iv, ten);
        b.cond_branch(keep_looping, def_block, vec![], exit, vec![]);

        // `carried` is defined here, in a block with a *higher* ID than
        // the block that uses it.
        b.set_block(def_block);
        let carried = b.iadd(iv, one);
        b.branch(use_block, vec![]);

        // `use_block` consumes `carried`; the store keeps the value live.
        b.set_block(use_block);
        let observed = b.iadd(carried, one);
        let spill = b.alloca(IrType::Int(IntWidth::I32));
        b.store(observed, spill);
        b.branch(latch, vec![]);

        b.set_block(latch);
        let nxt = b.iadd(iv, one);
        b.branch(header, vec![nxt]);

        b.set_block(exit);
        b.ret_void();

        let loop_body = [header, use_block, def_block, latch]
            .into_iter()
            .collect::<HashSet<_>>();
        let lp = NaturalLoop {
            header,
            body: loop_body,
            latches: vec![latch],
        };

        let (block_map, _) = clone_loop(&mut func, &lp);
        let val_map = build_value_map(&func, &lp, &block_map);
        let cloned_use = func.block(block_map[&use_block]);

        // Find the cloned `observed` instruction in the cloned use block.
        let cloned_carried = val_map[&carried];
        let cloned_observed = val_map[&observed];
        let use_inst = cloned_use
            .insts
            .iter()
            .find(|inst| inst.id == cloned_observed)
            .expect("cloned use-block instruction");

        // Its lhs operand must be the *cloned* carried value, proving the
        // forward reference was remapped.
        assert!(
            matches!(use_inst.kind, InstKind::IAdd(lhs, _) if lhs == cloned_carried),
            "cloned use block should reference the cloned later-block value, got {:?}",
            use_inst.kind,
        );
    }
}
229
230 /// Build a value map from original loop blocks → cloned loop blocks.
231 pub fn build_value_map(
232 func: &Function,
233 lp: &NaturalLoop,
234 block_map: &HashMap<BlockId, BlockId>,
235 ) -> HashMap<ValueId, ValueId> {
236 let mut val_map: HashMap<ValueId, ValueId> = HashMap::new();
237 let body_sorted: Vec<BlockId> = {
238 let mut v: Vec<BlockId> = lp.body.iter().copied().collect();
239 v.sort_by_key(|b| b.0);
240 v
241 };
242 for &old_id in &body_sorted {
243 let new_id = block_map[&old_id];
244 let old_block = func.block(old_id);
245 let new_block = func.block(new_id);
246 for (old_bp, new_bp) in old_block.params.iter().zip(new_block.params.iter()) {
247 val_map.insert(old_bp.id, new_bp.id);
248 }
249 for (old_inst, new_inst) in old_block.insts.iter().zip(new_block.insts.iter()) {
250 val_map.insert(old_inst.id, new_inst.id);
251 }
252 }
253 val_map
254 }
255
/// Remap all ValueId operands in an InstKind. Values not in the map
/// are left unchanged (they're defined outside the cloned region).
///
/// Purely structural: every variant is rebuilt with its value operands
/// passed through the map, while non-value payloads (constants, types,
/// widths, symbol names, lane/field indices) are copied or cloned
/// verbatim. The match is deliberately exhaustive — adding an InstKind
/// variant forces a compile error here instead of a silent miss.
pub fn remap_inst_kind(kind: &InstKind, map: &HashMap<ValueId, ValueId>) -> InstKind {
    // Remap one operand; identity for values defined outside the region.
    let r = |v: &ValueId| *map.get(v).unwrap_or(v);
    match kind {
        // ---- constants / addresses: no value operands ----
        InstKind::ConstInt(v, w) => InstKind::ConstInt(*v, *w),
        InstKind::ConstFloat(v, w) => InstKind::ConstFloat(*v, *w),
        InstKind::ConstBool(v) => InstKind::ConstBool(*v),
        InstKind::ConstString(v) => InstKind::ConstString(v.clone()),
        InstKind::Undef(t) => InstKind::Undef(t.clone()),
        InstKind::GlobalAddr(s) => InstKind::GlobalAddr(s.clone()),
        // ---- integer / float arithmetic ----
        InstKind::IAdd(a, b) => InstKind::IAdd(r(a), r(b)),
        InstKind::ISub(a, b) => InstKind::ISub(r(a), r(b)),
        InstKind::IMul(a, b) => InstKind::IMul(r(a), r(b)),
        InstKind::IDiv(a, b) => InstKind::IDiv(r(a), r(b)),
        InstKind::IMod(a, b) => InstKind::IMod(r(a), r(b)),
        InstKind::INeg(a) => InstKind::INeg(r(a)),
        InstKind::FAdd(a, b) => InstKind::FAdd(r(a), r(b)),
        InstKind::FSub(a, b) => InstKind::FSub(r(a), r(b)),
        InstKind::FMul(a, b) => InstKind::FMul(r(a), r(b)),
        InstKind::FDiv(a, b) => InstKind::FDiv(r(a), r(b)),
        InstKind::FNeg(a) => InstKind::FNeg(r(a)),
        InstKind::FAbs(a) => InstKind::FAbs(r(a)),
        InstKind::FSqrt(a) => InstKind::FSqrt(r(a)),
        InstKind::FPow(a, b) => InstKind::FPow(r(a), r(b)),
        // ---- comparisons / logic / select ----
        InstKind::ICmp(op, a, b) => InstKind::ICmp(*op, r(a), r(b)),
        InstKind::FCmp(op, a, b) => InstKind::FCmp(*op, r(a), r(b)),
        InstKind::And(a, b) => InstKind::And(r(a), r(b)),
        InstKind::Or(a, b) => InstKind::Or(r(a), r(b)),
        InstKind::Not(a) => InstKind::Not(r(a)),
        InstKind::Select(c, t, f) => InstKind::Select(r(c), r(t), r(f)),
        // ---- bitwise / shifts / bit counting ----
        InstKind::BitAnd(a, b) => InstKind::BitAnd(r(a), r(b)),
        InstKind::BitOr(a, b) => InstKind::BitOr(r(a), r(b)),
        InstKind::BitXor(a, b) => InstKind::BitXor(r(a), r(b)),
        InstKind::BitNot(a) => InstKind::BitNot(r(a)),
        InstKind::Shl(a, b) => InstKind::Shl(r(a), r(b)),
        InstKind::LShr(a, b) => InstKind::LShr(r(a), r(b)),
        InstKind::AShr(a, b) => InstKind::AShr(r(a), r(b)),
        InstKind::CountLeadingZeros(a) => InstKind::CountLeadingZeros(r(a)),
        InstKind::CountTrailingZeros(a) => InstKind::CountTrailingZeros(r(a)),
        InstKind::PopCount(a) => InstKind::PopCount(r(a)),
        // ---- conversions ----
        InstKind::IntToFloat(a, w) => InstKind::IntToFloat(r(a), *w),
        InstKind::FloatToInt(a, w) => InstKind::FloatToInt(r(a), *w),
        InstKind::FloatExtend(a, w) => InstKind::FloatExtend(r(a), *w),
        InstKind::FloatTrunc(a, w) => InstKind::FloatTrunc(r(a), *w),
        InstKind::IntExtend(a, w, s) => InstKind::IntExtend(r(a), *w, *s),
        InstKind::IntTrunc(a, w) => InstKind::IntTrunc(r(a), *w),
        InstKind::PtrToInt(a) => InstKind::PtrToInt(r(a)),
        InstKind::IntToPtr(a, ty) => InstKind::IntToPtr(r(a), ty.clone()),
        // ---- memory ----
        InstKind::Alloca(t) => InstKind::Alloca(t.clone()),
        InstKind::Load(a) => InstKind::Load(r(a)),
        InstKind::Store(v, p) => InstKind::Store(r(v), r(p)),
        InstKind::GetElementPtr(base, idxs) => {
            InstKind::GetElementPtr(r(base), idxs.iter().map(&r).collect())
        }
        // ---- calls / aggregates ----
        InstKind::Call(f, args) => InstKind::Call(f.clone(), args.iter().map(&r).collect()),
        InstKind::RuntimeCall(f, args) => {
            InstKind::RuntimeCall(f.clone(), args.iter().map(&r).collect())
        }
        InstKind::ExtractField(v, idx) => InstKind::ExtractField(r(v), *idx),
        InstKind::InsertField(v, idx, fld) => InstKind::InsertField(r(v), *idx, r(fld)),

        // ---- SIMD vector ops ----
        InstKind::VAdd(a, b) => InstKind::VAdd(r(a), r(b)),
        InstKind::VSub(a, b) => InstKind::VSub(r(a), r(b)),
        InstKind::VMul(a, b) => InstKind::VMul(r(a), r(b)),
        InstKind::VDiv(a, b) => InstKind::VDiv(r(a), r(b)),
        InstKind::VNeg(a) => InstKind::VNeg(r(a)),
        InstKind::VAbs(a) => InstKind::VAbs(r(a)),
        InstKind::VSqrt(a) => InstKind::VSqrt(r(a)),
        InstKind::VFma(a, b, c) => InstKind::VFma(r(a), r(b), r(c)),
        InstKind::VSelect(m, t, f) => InstKind::VSelect(r(m), r(t), r(f)),
        InstKind::VMin(a, b) => InstKind::VMin(r(a), r(b)),
        InstKind::VMax(a, b) => InstKind::VMax(r(a), r(b)),
        InstKind::VICmp(op, a, b) => InstKind::VICmp(*op, r(a), r(b)),
        InstKind::VFCmp(op, a, b) => InstKind::VFCmp(*op, r(a), r(b)),
        InstKind::VLoad(p) => InstKind::VLoad(r(p)),
        InstKind::VStore(v, p) => InstKind::VStore(r(v), r(p)),
        InstKind::VBitcast(v, ty) => InstKind::VBitcast(r(v), ty.clone()),
        InstKind::VExtract(v, lane) => InstKind::VExtract(r(v), *lane),
        InstKind::VInsert(v, lane, s) => InstKind::VInsert(r(v), *lane, r(s)),
        InstKind::VBroadcast(s) => InstKind::VBroadcast(r(s)),
        InstKind::VReduceSum(v) => InstKind::VReduceSum(r(v)),
        InstKind::VReduceMin(v) => InstKind::VReduceMin(r(v)),
        InstKind::VReduceMax(v) => InstKind::VReduceMax(r(v)),
    }
}
343
344 /// Remap block targets and value operands in a terminator.
345 pub fn remap_terminator(
346 term: &Terminator,
347 block_map: &HashMap<BlockId, BlockId>,
348 val_map: &HashMap<ValueId, ValueId>,
349 ) -> Terminator {
350 let rb = |b: &BlockId| *block_map.get(b).unwrap_or(b);
351 let rv = |v: &ValueId| *val_map.get(v).unwrap_or(v);
352 let rvs = |vs: &[ValueId]| -> Vec<ValueId> { vs.iter().map(&rv).collect() };
353 match term {
354 Terminator::Return(v) => Terminator::Return(v.map(|x| rv(&x))),
355 Terminator::Branch(dest, args) => Terminator::Branch(rb(dest), rvs(args)),
356 Terminator::CondBranch {
357 cond,
358 true_dest,
359 true_args,
360 false_dest,
361 false_args,
362 } => Terminator::CondBranch {
363 cond: rv(cond),
364 true_dest: rb(true_dest),
365 true_args: rvs(true_args),
366 false_dest: rb(false_dest),
367 false_args: rvs(false_args),
368 },
369 Terminator::Switch {
370 selector,
371 default,
372 cases,
373 } => Terminator::Switch {
374 selector: rv(selector),
375 default: rb(default),
376 cases: cases.iter().map(|(v, d)| (*v, rb(d))).collect(),
377 },
378 Terminator::Unreachable => Terminator::Unreachable,
379 }
380 }
381