Rust · 141441 bytes Raw Blame History
1 //! True NEON loop vectorizer (Sprint 12 Stage 4 MVP).
2 //!
3 //! Detects counted DO loops with statically-known trip count divisible
4 //! by the NEON lane count V, and rewrites the inner body to consume and
5 //! produce vector IR (`VLoad`/`VAdd`/`VStore`/...). The downstream
6 //! `isel` pass then emits real NEON intrinsics.
7 //!
8 //! This is the *real* vectorizer. The older `vectorize.rs` (which
9 //! batches scalar dispatch through runtime kernel calls) remains as a
10 //! fallback for shapes this MVP does not yet handle: mismatched trip
11 //! counts (no scalar tail yet), multi-statement bodies, reductions,
12 //! WHERE masks. Stages 5–6 of the sprint plan extend this pass to those
13 //! cases.
14
15 use std::collections::{HashMap, HashSet};
16
17 use crate::ir::inst::*;
18 use crate::ir::types::{FloatWidth, IntWidth, IrType};
19
20 use super::loop_utils::{find_preheader, loop_defined_values, remap_inst_kind, resolve_const_int};
21 use super::pass::Pass;
22 use super::util::{find_natural_loops, inst_uses, predecessors, terminator_uses, NaturalLoop};
23
24 pub struct NeonVectorize;
25
26 impl Pass for NeonVectorize {
27 fn name(&self) -> &'static str {
28 "neon_vectorize"
29 }
30
31 fn run(&self, module: &mut Module) -> bool {
32 let mut changed = false;
33 for func in &mut module.functions {
34 while vectorize_one_loop(func) {
35 changed = true;
36 }
37 }
38 if changed {
39 for func in &mut module.functions {
40 func.rebuild_type_cache();
41 }
42 }
43 changed
44 }
45 }
46
47 #[derive(Debug, Clone, Copy)]
48 struct CountedLoop {
49 preheader: BlockId,
50 header: BlockId,
51 body: BlockId,
52 iv_param: ValueId,
53 iv_init: i64,
54 iv_bound: i64,
55 /// The header's `icmp le|lt iv, hi_const` instruction id. Needed
56 /// when scalar-tail peeling has to retarget the loop's bound to
57 /// `iv_init + head_count - 1`.
58 cond_id: ValueId,
59 /// The ConstInt feeding the icmp's RHS. `apply_vector_plan` does
60 /// not mutate it in place (could be aliased) but inserts a fresh
61 /// const and rewires the icmp.
62 bound_const_id: ValueId,
63 }
64
65 #[derive(Debug, Clone)]
66 struct ArrayAccess {
67 base: ValueId,
68 elem_ty: IrType,
69 len: u64,
70 lower: i64,
71 }
72
73 /// A counted WHERE-block loop. The natural-loop body is a 4-block
74 /// diamond: header (cmp + cond_br exit/body) → body (load + cmp +
75 /// cond_br then/incr) → then (conditional store + br incr) → incr
76 /// (iv += 1 + br header). The vectorizer rewrites this into:
77 ///
78 /// body': vload a; vload b_old; v(f|i)cmp predicate; vselect mask, va, vb_old; vstore;
79 /// drop the `then` block, branch body' → incr unconditionally.
80 #[derive(Debug, Clone, Copy)]
81 struct WhereLoop {
82 preheader: BlockId,
83 header: BlockId,
84 /// The body block holding the per-iteration cmp and cond_br
85 /// to `then` / `incr`.
86 body: BlockId,
87 /// The "then" arm with the conditional store(s).
88 then_block: BlockId,
89 /// The "else" arm, when WHERE/ELSEWHERE is used. The block
90 /// branches unconditionally to `incr_block`. When `None`, the
91 /// body's false branch goes directly to `incr_block` (single-arm
92 /// WHERE).
93 else_block: Option<BlockId>,
94 /// The latch / incr block (iv + 1, br header).
95 incr_block: BlockId,
96 iv_param: ValueId,
97 iv_init: i64,
98 iv_bound: i64,
99 /// Header `icmp ge|gt iv, hi` (body on FALSE branch).
100 cond_id: ValueId,
101 bound_const_id: ValueId,
102 }
103
104 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
105 enum BinaryKind {
106 Add,
107 Sub,
108 Mul,
109 /// Float-only — NEON has no integer vector divide.
110 Div,
111 /// Element-wise `max(lhs, rhs)`. IR shape is
112 /// `select(cmp ge|gt lhs, rhs, lhs, rhs)`.
113 Max,
114 /// Element-wise `min(lhs, rhs)`. IR shape is
115 /// `select(cmp le|lt lhs, rhs, lhs, rhs)`.
116 Min,
117 }
118
119 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
120 enum UnaryKind {
121 Neg,
122 Abs,
123 Sqrt,
124 }
125
126 /// One operand of the body's binop, classified as either an array
127 /// `Load` that becomes a `VLoad` or a loop-invariant scalar that
128 /// becomes a `VBroadcast` hoisted into the preheader.
129 #[derive(Debug, Clone)]
130 enum BinopOperand {
131 /// A scalar `Load` whose pointer is `gep base, [iv-derived]`.
132 /// `load_id` is the original IR Load instruction we'll rewrite
133 /// to a VLoad.
134 ArrayLoad(ValueId),
135 /// A loop-invariant scalar value defined outside the loop. We
136 /// will emit a `VBroadcast` in the preheader to splat it across
137 /// every lane and rewrite the binop to consume that vector.
138 InvariantScalar(ValueId),
139 }
140
141 /// What kind of element-wise op the loop body computes.
142 #[derive(Debug, Clone)]
143 enum BodyOp {
144 /// `dest(i) = source` — a pure copy of one array load (no
145 /// arithmetic). The Load inst gets rewritten to VLoad and its
146 /// result is stored directly. `InvariantScalar` is rejected
147 /// here: a constant fill goes through the older bulk path.
148 Copy { source: BinopOperand },
149 /// `dest(i) = -src` or `dest(i) = abs(src)` — single-operand
150 /// element-wise op. `src` must be an `ArrayLoad` (negating an
151 /// invariant scalar would be a constant fill).
152 Unary {
153 source: BinopOperand,
154 unary_id: ValueId,
155 kind: UnaryKind,
156 },
157 /// `dest(i) = lhs op rhs` — a single element-wise binop with at
158 /// least one array load.
159 Binop {
160 lhs: BinopOperand,
161 rhs: BinopOperand,
162 binop_id: ValueId,
163 kind: BinaryKind,
164 },
165 /// `dest(i) = a*b + c` — element-wise FMA. Float-only (NEON has
166 /// `fmla.4s` / `fmla.2d` for floats; integer `mla.4s` exists but
167 /// VFma in our IR is float). At least one of {a,b,c} must be
168 /// an array load; the others can be invariant scalars (broadcast).
169 Fma {
170 a: BinopOperand,
171 b: BinopOperand,
172 c: BinopOperand,
173 fmul_id: ValueId,
174 fadd_id: ValueId,
175 },
176 }
177
178 /// One element-wise statement (one store) inside a multi-statement
179 /// vectorizable body. All statements in a `VectorPlan` share the
180 /// same lane count and element type.
181 #[derive(Debug, Clone)]
182 struct Statement {
183 /// What expression feeds the store.
184 op: BodyOp,
185 /// Original Store instruction ID to be rewritten to VStore.
186 store: ValueId,
187 }
188
189 /// Concrete plan: one or more element-wise statements (each up to
190 /// two array loads or one load + one invariant scalar plus one
191 /// array store) sharing the same iteration space and element type.
192 #[derive(Debug, Clone)]
193 struct VectorPlan {
194 lanes: u8,
195 elem_ty: IrType,
196 /// Every statement that feeds a store in the body.
197 statements: Vec<Statement>,
198 /// Original `iadd iv, 1` step instruction in the body.
199 step_iadd: ValueId,
200 /// The `1` ConstInt used as the step (for replacement with V).
201 step_const: ValueId,
202 /// Width of the IV ConstInt (i32 for typical 1..N loops).
203 iv_int_width: IntWidth,
204 /// Number of vector iterations × `lanes` = head iteration count.
205 /// When `tail_count == 0` the loop fully vectorizes; otherwise we
206 /// peel `tail_count` scalar iterations into the exit block.
207 head_count: i64,
208 /// Remaining iterations after the head (always `< lanes`).
209 tail_count: i64,
210 /// Span to use for synthesised instructions.
211 span: crate::lexer::Span,
212 }
213
214 fn vectorize_one_loop(func: &mut Function) -> bool {
215 let loops = find_natural_loops(func);
216 if loops.is_empty() {
217 return false;
218 }
219 let preds = predecessors(func);
220
221 for lp in &loops {
222 // Try element-wise vectorization first (no escaping values).
223 if let Some(shape) = detect_counted_loop(func, lp, &preds) {
224 let loop_defs = loop_defined_values(func, lp);
225 if !loop_values_escape(func, lp, &loop_defs) {
226 if let Some(plan) = build_vector_plan(func, &shape, &loop_defs) {
227 apply_vector_plan(func, &shape, plan);
228 return true;
229 }
230 }
231 }
232 // WHERE-block diamond (4-block: header / body / then / incr).
233 if let Some(shape) = detect_where_loop(func, lp, &preds) {
234 if let Some(plan) = build_where_plan(func, &shape) {
235 apply_where_plan(func, &shape, plan);
236 return true;
237 }
238 }
239 // Fall back: reduction loop (one escaping accumulator).
240 if let Some(plan) = detect_reduction_plan(func, lp, &preds) {
241 apply_reduction_plan(func, lp, plan);
242 return true;
243 }
244 }
245 false
246 }
247
248 fn detect_counted_loop(
249 func: &Function,
250 lp: &NaturalLoop,
251 preds: &HashMap<BlockId, Vec<BlockId>>,
252 ) -> Option<CountedLoop> {
253 if lp.latches.len() != 1 || lp.body.len() != 2 {
254 return None;
255 }
256 let header = lp.header;
257 let body = lp.latches[0];
258 if body == header {
259 return None;
260 }
261 let header_block = func.block(header);
262 if header_block.params.len() != 1 {
263 return None;
264 }
265 let iv_param = header_block.params[0].id;
266 if !matches!(header_block.params[0].ty, IrType::Int(_)) {
267 return None;
268 }
269 let preheader = find_preheader(func, lp, preds)?;
270 let iv_init = match &func.block(preheader).terminator {
271 Some(Terminator::Branch(dest, args)) if *dest == header && args.len() == 1 => {
272 resolve_const_int(func, args[0])?
273 }
274 _ => return None,
275 };
276 let (cond_id, true_dest, false_dest, true_args, false_args) = match &header_block.terminator {
277 Some(Terminator::CondBranch {
278 cond,
279 true_dest,
280 true_args,
281 false_dest,
282 false_args,
283 }) => (*cond, *true_dest, *false_dest, true_args, false_args),
284 _ => return None,
285 };
286 if !true_args.is_empty()
287 || !false_args.is_empty()
288 || true_dest != body
289 || lp.body.contains(&false_dest)
290 {
291 return None;
292 }
293 let cond_inst = header_block.insts.iter().find(|inst| inst.id == cond_id)?;
294 let (iv_bound, bound_const_id) = match cond_inst.kind {
295 InstKind::ICmp(CmpOp::Le, lhs, rhs) if lhs == iv_param => {
296 (resolve_const_int(func, rhs)?, rhs)
297 }
298 InstKind::ICmp(CmpOp::Lt, lhs, rhs) if lhs == iv_param => {
299 (resolve_const_int(func, rhs)?.checked_sub(1)?, rhs)
300 }
301 _ => return None,
302 };
303 Some(CountedLoop {
304 preheader,
305 header,
306 body,
307 iv_param,
308 iv_init,
309 iv_bound,
310 cond_id,
311 bound_const_id,
312 })
313 }
314
315 /// Detect a counted WHERE-block diamond:
316 /// header(iv): icmp ge iv, hi; cond_br c, exit, body
317 /// body: load + cmp + cond_br mask, then, incr
318 /// then: store(s) + br incr
319 /// incr: iv+1 + br header(iv+1)
320 fn detect_where_loop(
321 func: &Function,
322 lp: &NaturalLoop,
323 preds: &HashMap<BlockId, Vec<BlockId>>,
324 ) -> Option<WhereLoop> {
325 // 4 blocks: header / body / then / incr (single-arm WHERE).
326 // 5 blocks: header / body / then / else / incr (WHERE/ELSEWHERE).
327 if lp.latches.len() != 1 || (lp.body.len() != 4 && lp.body.len() != 5) {
328 return None;
329 }
330 let header = lp.header;
331 let incr_block = lp.latches[0];
332 if incr_block == header {
333 return None;
334 }
335 let header_block = func.block(header);
336 if header_block.params.len() != 1 {
337 return None;
338 }
339 let iv_param = header_block.params[0].id;
340 if !matches!(header_block.params[0].ty, IrType::Int(_)) {
341 return None;
342 }
343 let preheader = find_preheader(func, lp, preds)?;
344 let iv_init = match &func.block(preheader).terminator {
345 Some(Terminator::Branch(dest, args)) if *dest == header && args.len() == 1 => {
346 resolve_const_int(func, args[0])?
347 }
348 _ => return None,
349 };
350 // Header terminator: cond_br with body on FALSE (exit on TRUE).
351 let (cond_id, true_dest, false_dest) = match &header_block.terminator {
352 Some(Terminator::CondBranch {
353 cond,
354 true_dest,
355 true_args,
356 false_dest,
357 false_args,
358 }) if true_args.is_empty() && false_args.is_empty() => {
359 (*cond, *true_dest, *false_dest)
360 }
361 _ => return None,
362 };
363 if lp.body.contains(&true_dest) || !lp.body.contains(&false_dest) {
364 return None;
365 }
366 let body = false_dest;
367 // Header cmp: `icmp ge iv, hi` (or `gt`, in which case bound is hi-1).
368 let cond_inst = header_block.insts.iter().find(|inst| inst.id == cond_id)?;
369 let (iv_bound, bound_const_id) = match cond_inst.kind {
370 InstKind::ICmp(CmpOp::Ge, lhs, rhs) if lhs == iv_param => {
371 (resolve_const_int(func, rhs)?.checked_sub(1)?, rhs)
372 }
373 InstKind::ICmp(CmpOp::Gt, lhs, rhs) if lhs == iv_param => {
374 (resolve_const_int(func, rhs)?, rhs)
375 }
376 _ => return None,
377 };
378 // Body terminator: cond_br to {then, incr}, with incr being the
379 // latch and `then` being a 4th body block.
380 let body_block = func.block(body);
381 let (then_block, body_else) = match &body_block.terminator {
382 Some(Terminator::CondBranch {
383 cond: _,
384 true_dest,
385 true_args,
386 false_dest,
387 false_args,
388 }) if true_args.is_empty() && false_args.is_empty() => (*true_dest, *false_dest),
389 _ => return None,
390 };
391 if !lp.body.contains(&then_block) || !lp.body.contains(&body_else) {
392 return None;
393 }
394 if then_block == header || then_block == incr_block || then_block == body {
395 return None;
396 }
397 // body_else may be either incr_block (single-arm WHERE) or a
398 // distinct else_block that itself branches to incr_block
399 // (WHERE/ELSEWHERE two-arm).
400 let else_block = if body_else == incr_block {
401 None
402 } else {
403 if body_else == header || body_else == body || body_else == then_block {
404 return None;
405 }
406 let else_blk = func.block(body_else);
407 match &else_blk.terminator {
408 Some(Terminator::Branch(d, args)) if *d == incr_block && args.is_empty() => {}
409 _ => return None,
410 }
411 Some(body_else)
412 };
413 // The then block must br unconditionally to incr.
414 let then_blk = func.block(then_block);
415 match &then_blk.terminator {
416 Some(Terminator::Branch(d, args)) if *d == incr_block && args.is_empty() => {}
417 _ => return None,
418 }
419 // The incr block must be `iadd iv, 1; br header(iv+1)`.
420 let incr_blk = func.block(incr_block);
421 match &incr_blk.terminator {
422 Some(Terminator::Branch(d, args)) if *d == header && args.len() == 1 => {}
423 _ => return None,
424 }
425 Some(WhereLoop {
426 preheader,
427 header,
428 body,
429 then_block,
430 else_block,
431 incr_block,
432 iv_param,
433 iv_init,
434 iv_bound,
435 cond_id,
436 bound_const_id,
437 })
438 }
439
440 fn build_vector_plan(
441 func: &Function,
442 shape: &CountedLoop,
443 loop_defs: &HashSet<ValueId>,
444 ) -> Option<VectorPlan> {
445 let body = func.block(shape.body);
446
447 // Reject loops with calls in the body — too risky to vectorize.
448 if body
449 .insts
450 .iter()
451 .any(|inst| matches!(inst.kind, InstKind::Call(..) | InstKind::RuntimeCall(..)))
452 {
453 return None;
454 }
455
456 // Walk every store. Each one must have a destination access that
457 // covers the full iteration space and an element type identical
458 // to the first store's. Each statement is classified independently
459 // (Copy or Binop) and contributes one entry to `statements`.
460 let stores: Vec<(ValueId, crate::lexer::Span, ValueId, ValueId)> = body
461 .insts
462 .iter()
463 .filter_map(|inst| match inst.kind {
464 InstKind::Store(value, ptr) => Some((inst.id, inst.span, value, ptr)),
465 _ => None,
466 })
467 .collect();
468 if stores.is_empty() {
469 return None;
470 }
471
472 // Pin the lane shape on the first destination; reject any later
473 // store whose dest disagrees.
474 let first_dest = classify_array_access(func, stores[0].3, shape.iv_param)?;
475 if !covers_full_array(shape, &first_dest) {
476 return None;
477 }
478 let lanes = lane_count_for(&first_dest.elem_ty)?;
479 let trip = shape
480 .iv_bound
481 .checked_sub(shape.iv_init)
482 .and_then(|d| d.checked_add(1))?;
483 if trip <= 0 {
484 return None;
485 }
486 // Head count is the largest multiple of `lanes` that fits within
487 // the trip; the remainder runs as scalar tail peeled into the
488 // exit block.
489 let head_count = trip - (trip % lanes as i64);
490 if head_count == 0 {
491 // Not a single full vector iteration would run — bail.
492 return None;
493 }
494 let tail_count = trip - head_count;
495 let elem_ty = first_dest.elem_ty.clone();
496 let span = stores[0].1;
497
498 let defs = inst_map(func);
499 let mut statements = Vec::with_capacity(stores.len());
500 for (store_id, _, stored_value, dest_ptr) in &stores {
501 let dest = classify_array_access(func, *dest_ptr, shape.iv_param)?;
502 if !covers_full_array(shape, &dest) {
503 return None;
504 }
505 if dest.elem_ty != elem_ty {
506 return None;
507 }
508 let stored_inst = defs.get(stored_value)?;
509 let op = classify_body_op(*stored_value, &stored_inst.kind, func, shape, &dest, loop_defs)?;
510 statements.push(Statement {
511 op,
512 store: *store_id,
513 });
514 }
515
516 // Find the iv-increment in the body.
517 let body_term = match &body.terminator {
518 Some(Terminator::Branch(dest, args)) if *dest == shape.header && args.len() == 1 => {
519 args[0]
520 }
521 _ => return None,
522 };
523 let step_inst = defs.get(&body_term)?;
524 let (step_lhs, step_rhs) = match step_inst.kind {
525 InstKind::IAdd(l, r) => (l, r),
526 _ => return None,
527 };
528 let (step_const, _step_value) = if step_lhs == shape.iv_param {
529 (step_rhs, resolve_const_int(func, step_rhs)?)
530 } else if step_rhs == shape.iv_param {
531 (step_lhs, resolve_const_int(func, step_lhs)?)
532 } else {
533 return None;
534 };
535
536 // Pull the IV's int width from the ConstInt that defines the step.
537 let iv_int_width = match defs.get(&step_const)?.kind {
538 InstKind::ConstInt(_, w) => w,
539 _ => return None,
540 };
541
542 Some(VectorPlan {
543 lanes,
544 elem_ty,
545 head_count,
546 tail_count,
547 statements,
548 step_iadd: step_inst.id,
549 step_const,
550 iv_int_width,
551 span,
552 })
553 }
554
555 /// Classify the expression `stored_value = ...` feeding one store as
556 /// either a `Copy` (pure load), a `Unary` (neg/abs), or a `Binop`.
557 /// Returns `None` for any shape we don't yet vectorize.
558 fn classify_body_op(
559 stored_value: ValueId,
560 kind: &InstKind,
561 func: &Function,
562 shape: &CountedLoop,
563 dest: &ArrayAccess,
564 loop_defs: &HashSet<ValueId>,
565 ) -> Option<BodyOp> {
566 match kind {
567 InstKind::Load(_) => {
568 let source =
569 classify_binop_operand(func, stored_value, shape.iv_param, dest, loop_defs)?;
570 match source {
571 BinopOperand::ArrayLoad(_) => Some(BodyOp::Copy { source }),
572 BinopOperand::InvariantScalar(_) => None,
573 }
574 }
575 InstKind::INeg(src) | InstKind::FNeg(src) => {
576 unary_body(stored_value, UnaryKind::Neg, *src, func, shape, dest, loop_defs)
577 }
578 InstKind::FAbs(src) => {
579 unary_body(stored_value, UnaryKind::Abs, *src, func, shape, dest, loop_defs)
580 }
581 InstKind::FSqrt(src) => {
582 // sqrt is float-only.
583 if !matches!(dest.elem_ty, IrType::Float(_)) {
584 return None;
585 }
586 unary_body(stored_value, UnaryKind::Sqrt, *src, func, shape, dest, loop_defs)
587 }
588 InstKind::IAdd(l, r) => {
589 binop_body(stored_value, BinaryKind::Add, *l, *r, func, shape, dest, loop_defs)
590 }
591 InstKind::ISub(l, r) => {
592 binop_body(stored_value, BinaryKind::Sub, *l, *r, func, shape, dest, loop_defs)
593 }
594 InstKind::IMul(l, r) => {
595 binop_body(stored_value, BinaryKind::Mul, *l, *r, func, shape, dest, loop_defs)
596 }
597 InstKind::FAdd(l, r) => {
598 // Detect element-wise FMA: `c(i) = a(i)*b(i) + d(i)`.
599 // The store value is FAdd whose one operand is an FMul of
600 // two operands (each load or invariant scalar). NEON
601 // FMLA is float-only, so gate on a Float dest.
602 if matches!(dest.elem_ty, IrType::Float(_)) {
603 if let Some(fma) =
604 fma_body(stored_value, *l, *r, func, shape, dest, loop_defs)
605 {
606 return Some(fma);
607 }
608 }
609 binop_body(stored_value, BinaryKind::Add, *l, *r, func, shape, dest, loop_defs)
610 }
611 InstKind::FSub(l, r) => {
612 binop_body(stored_value, BinaryKind::Sub, *l, *r, func, shape, dest, loop_defs)
613 }
614 InstKind::FMul(l, r) => {
615 binop_body(stored_value, BinaryKind::Mul, *l, *r, func, shape, dest, loop_defs)
616 }
617 InstKind::FDiv(l, r) => {
618 // Integer divide has no NEON form; only floats. The
619 // binop_body classifier doesn't itself check element
620 // type, but `lane_count_for` already required the dest
621 // to be Float for this code path to reach here.
622 if !matches!(dest.elem_ty, IrType::Float(_)) {
623 return None;
624 }
625 binop_body(stored_value, BinaryKind::Div, *l, *r, func, shape, dest, loop_defs)
626 }
627 // Element-wise `c(i) = max(a(i), b(i))` and `min(...)`. The
628 // IR shape is `select(cmp(la, lb), t, f)` where {t, f} is
629 // {la, lb}. Only fires when the cmp's operands match the
630 // select's true/false arms (in some order); all four slots
631 // must be classifiable as load or invariant scalar.
632 InstKind::Select(cmp_v, t, f) => {
633 let defs = inst_map(func);
634 let cmp_inst = defs.get(cmp_v)?;
635 let (cmp_op, cmp_a, cmp_b) = match cmp_inst.kind {
636 InstKind::ICmp(op, a, b) | InstKind::FCmp(op, a, b) => (op, a, b),
637 _ => return None,
638 };
639 // The select's arms must be exactly the cmp's operands in
640 // some order so the result is `max` or `min` of them.
641 let bk = if cmp_a == *t && cmp_b == *f {
642 match cmp_op {
643 CmpOp::Ge | CmpOp::Gt => BinaryKind::Max,
644 CmpOp::Le | CmpOp::Lt => BinaryKind::Min,
645 _ => return None,
646 }
647 } else if cmp_a == *f && cmp_b == *t {
648 match cmp_op {
649 CmpOp::Ge | CmpOp::Gt => BinaryKind::Min,
650 CmpOp::Le | CmpOp::Lt => BinaryKind::Max,
651 _ => return None,
652 }
653 } else {
654 return None;
655 };
656 binop_body(stored_value, bk, *t, *f, func, shape, dest, loop_defs)
657 }
658 _ => None,
659 }
660 }
661
662 /// Classify a body that is `dest(i) = -src` or `dest(i) = abs(src)`.
663 /// `src` must be an array load (a unary on an invariant scalar would
664 /// just be a constant fill).
665 fn unary_body(
666 unary_id: ValueId,
667 kind: UnaryKind,
668 src_v: ValueId,
669 func: &Function,
670 shape: &CountedLoop,
671 dest: &ArrayAccess,
672 loop_defs: &HashSet<ValueId>,
673 ) -> Option<BodyOp> {
674 let source = classify_binop_operand(func, src_v, shape.iv_param, dest, loop_defs)?;
675 match source {
676 BinopOperand::ArrayLoad(_) => Some(BodyOp::Unary {
677 source,
678 unary_id,
679 kind,
680 }),
681 BinopOperand::InvariantScalar(_) => None,
682 }
683 }
684
685 /// Classify a body that is `dest(i) = lhs op rhs`. At least one
686 /// side must be an array load — the all-scalar form has no business
687 /// being a vectorizable counted loop.
688 fn binop_body(
689 binop_id: ValueId,
690 kind: BinaryKind,
691 lhs_v: ValueId,
692 rhs_v: ValueId,
693 func: &Function,
694 shape: &CountedLoop,
695 dest: &ArrayAccess,
696 loop_defs: &HashSet<ValueId>,
697 ) -> Option<BodyOp> {
698 let lhs_op = classify_binop_operand(func, lhs_v, shape.iv_param, dest, loop_defs)?;
699 let rhs_op = classify_binop_operand(func, rhs_v, shape.iv_param, dest, loop_defs)?;
700 if matches!(lhs_op, BinopOperand::InvariantScalar(_))
701 && matches!(rhs_op, BinopOperand::InvariantScalar(_))
702 {
703 return None;
704 }
705 Some(BodyOp::Binop {
706 lhs: lhs_op,
707 rhs: rhs_op,
708 binop_id,
709 kind,
710 })
711 }
712
713 /// Classify a body that's `dest(i) = (a*b) + c` (or `c + (a*b)`).
714 /// `fadd_id` is the FAdd's value id; `lhs_v` and `rhs_v` are its
715 /// operands. One of them must itself be an `FMul` whose two operands
716 /// are each load-or-invariant-scalar; the other operand is `c`.
717 fn fma_body(
718 fadd_id: ValueId,
719 lhs_v: ValueId,
720 rhs_v: ValueId,
721 func: &Function,
722 shape: &CountedLoop,
723 dest: &ArrayAccess,
724 loop_defs: &HashSet<ValueId>,
725 ) -> Option<BodyOp> {
726 let defs = inst_map(func);
727 let try_fmul = |fmul_v: ValueId, other_v: ValueId| -> Option<BodyOp> {
728 let fmul_inst = defs.get(&fmul_v)?;
729 let (a_v, b_v) = match fmul_inst.kind {
730 InstKind::FMul(a, b) => (a, b),
731 _ => return None,
732 };
733 let a = classify_binop_operand(func, a_v, shape.iv_param, dest, loop_defs)?;
734 let b = classify_binop_operand(func, b_v, shape.iv_param, dest, loop_defs)?;
735 let c = classify_binop_operand(func, other_v, shape.iv_param, dest, loop_defs)?;
736 // At least one operand must be an array load — otherwise
737 // there's no per-iteration data to vectorize.
738 if matches!(a, BinopOperand::InvariantScalar(_))
739 && matches!(b, BinopOperand::InvariantScalar(_))
740 && matches!(c, BinopOperand::InvariantScalar(_))
741 {
742 return None;
743 }
744 Some(BodyOp::Fma {
745 a,
746 b,
747 c,
748 fmul_id: fmul_v,
749 fadd_id,
750 })
751 };
752 if let Some(op) = try_fmul(lhs_v, rhs_v) {
753 return Some(op);
754 }
755 try_fmul(rhs_v, lhs_v)
756 }
757
758 /// Classify one operand of the body's binop as either a load from
759 /// the destination array's iteration space (which becomes a `VLoad`)
760 /// or a value defined entirely outside the loop (which becomes a
761 /// preheader `VBroadcast`).
762 fn classify_binop_operand(
763 func: &Function,
764 value: ValueId,
765 iv_param: ValueId,
766 dest: &ArrayAccess,
767 loop_defs: &HashSet<ValueId>,
768 ) -> Option<BinopOperand> {
769 if let Some(load) = classify_loaded_array(func, value, iv_param) {
770 if !arrays_compatible(dest, &load.access) {
771 return None;
772 }
773 return Some(BinopOperand::ArrayLoad(load.load_id));
774 }
775 // Not an array load: only valid if it is loop-invariant.
776 if loop_defs.contains(&value) {
777 return None;
778 }
779 // Type must match the destination element type so the broadcast
780 // produces a vector compatible with the rewritten binop.
781 let ty = func.value_type(value)?;
782 if ty != dest.elem_ty {
783 return None;
784 }
785 Some(BinopOperand::InvariantScalar(value))
786 }
787
788 #[derive(Debug, Clone)]
789 struct LoadedArray {
790 load_id: ValueId,
791 access: ArrayAccess,
792 }
793
794 fn classify_loaded_array(
795 func: &Function,
796 value: ValueId,
797 iv_param: ValueId,
798 ) -> Option<LoadedArray> {
799 let defs = inst_map(func);
800 let inst = defs.get(&value)?;
801 let InstKind::Load(ptr) = inst.kind else {
802 return None;
803 };
804 let access = classify_array_access(func, ptr, iv_param)?;
805 Some(LoadedArray {
806 load_id: inst.id,
807 access,
808 })
809 }
810
811 fn classify_array_access(func: &Function, ptr: ValueId, iv_param: ValueId) -> Option<ArrayAccess> {
812 let defs = inst_map(func);
813 let inst = defs.get(&ptr)?;
814 let InstKind::GetElementPtr(base, ref indices) = inst.kind else {
815 return None;
816 };
817 if indices.len() != 1 {
818 return None;
819 }
820 let IrType::Ptr(inner) = func.value_type(base)? else {
821 return None;
822 };
823 let IrType::Array(elem, len) = inner.as_ref() else {
824 return None;
825 };
826 let lower = normalized_index_lower(func, indices[0], iv_param)
827 .or_else(|| byte_stride_lower(func, indices[0], iv_param, elem.as_ref()))?;
828 Some(ArrayAccess {
829 base,
830 elem_ty: elem.as_ref().clone(),
831 len: *len,
832 lower,
833 })
834 }
835
836 /// Recognize the byte-stride form `shl(iv, log2(elem_bytes))` (with
837 /// an optional `IntExtend` between iv and shl). Returns the lower
838 /// bound (currently only `0`, since the matcher requires the access
839 /// to start at iv_init = the array's lower bound).
840 fn byte_stride_lower(
841 func: &Function,
842 value: ValueId,
843 iv_param: ValueId,
844 elem_ty: &IrType,
845 ) -> Option<i64> {
846 let defs = inst_map(func);
847 let inst = defs.get(&value)?;
848 let (lhs, rhs) = match inst.kind {
849 InstKind::Shl(l, r) => (l, r),
850 _ => return None,
851 };
852 let shift = resolve_const_int(func, rhs)?;
853 let bytes = elem_size_bytes(elem_ty)?;
854 if bytes <= 0 || (1i64 << shift) != bytes {
855 return None;
856 }
857 if lhs == iv_param {
858 return Some(0);
859 }
860 let inner = defs.get(&lhs)?;
861 if let InstKind::IntExtend(src, _, _) = inner.kind {
862 if src == iv_param {
863 return Some(0);
864 }
865 }
866 None
867 }
868
869 fn normalized_index_lower(func: &Function, value: ValueId, iv_param: ValueId) -> Option<i64> {
870 if value == iv_param {
871 return Some(0);
872 }
873 let defs = inst_map(func);
874 let inst = defs.get(&value)?;
875 match inst.kind {
876 InstKind::IntExtend(src, IntWidth::I64, _) if src == iv_param => Some(0),
877 InstKind::ISub(lhs, rhs) => {
878 let lhs_lower = normalized_index_lower(func, lhs, iv_param)?;
879 let rhs_const = resolve_const_int(func, rhs)?;
880 lhs_lower.checked_add(rhs_const)
881 }
882 _ => None,
883 }
884 }
885
886 fn arrays_compatible(dest: &ArrayAccess, other: &ArrayAccess) -> bool {
887 dest.elem_ty == other.elem_ty && dest.len == other.len && dest.lower == other.lower
888 }
889
890 fn covers_full_array(shape: &CountedLoop, access: &ArrayAccess) -> bool {
891 if access.len == 0 {
892 return false;
893 }
894 let Some(upper) = access
895 .lower
896 .checked_add(access.len as i64)
897 .and_then(|value| value.checked_sub(1))
898 else {
899 return false;
900 };
901 shape.iv_init == access.lower && shape.iv_bound == upper
902 }
903
904 fn loop_values_escape(func: &Function, lp: &NaturalLoop, loop_defs: &HashSet<ValueId>) -> bool {
905 for block in &func.blocks {
906 if lp.body.contains(&block.id) {
907 continue;
908 }
909 if block.insts.iter().any(|inst| {
910 inst_uses(&inst.kind)
911 .into_iter()
912 .any(|value| loop_defs.contains(&value))
913 }) {
914 return true;
915 }
916 if block.terminator.as_ref().is_some_and(|term| {
917 terminator_uses(term)
918 .into_iter()
919 .any(|value| loop_defs.contains(&value))
920 }) {
921 return true;
922 }
923 }
924 false
925 }
926
927 fn lane_count_for(elem: &IrType) -> Option<u8> {
928 match elem {
929 IrType::Int(IntWidth::I32) => Some(4),
930 IrType::Int(IntWidth::I64) => Some(2),
931 IrType::Float(FloatWidth::F32) => Some(4),
932 IrType::Float(FloatWidth::F64) => Some(2),
933 _ => None,
934 }
935 }
936
937 /// Size of a scalar IR type in bytes. Only used to recognize
938 /// byte-stride GEP indexing in WHERE-block lowering, where a gep
939 /// index of `shl(iv, log2(elem_size))` denotes the i-th element.
940 fn elem_size_bytes(elem: &IrType) -> Option<i64> {
941 match elem {
942 IrType::Int(w) => Some(w.bytes() as i64),
943 IrType::Float(w) => Some((w.bits() / 8) as i64),
944 _ => None,
945 }
946 }
947
948 fn vector_ty(elem: &IrType, lanes: u8) -> IrType {
949 IrType::Vector {
950 lanes,
951 elem: Box::new(elem.clone()),
952 }
953 }
954
955 fn apply_vector_plan(func: &mut Function, shape: &CountedLoop, plan: VectorPlan) {
956 let v_ty = vector_ty(&plan.elem_ty, plan.lanes);
957
958 // 0. If we'll be peeling scalar tail iterations, snapshot the body's
959 // instruction list BEFORE we mutate it in place. The snapshot
960 // holds the original (scalar) Load/Store/Binop shape that the
961 // peel walks per remainder iteration. Take a clone of the Vec
962 // so subsequent in-place mutation doesn't disturb the snapshot.
963 let body_snapshot: Option<Vec<Inst>> = if plan.tail_count > 0 {
964 Some(func.block(shape.body).insts.clone())
965 } else {
966 None
967 };
968 // Identify the exit block (the false-dest of the header's
969 // cond_br) so we know where to peel into.
970 let exit_block_id: Option<BlockId> = if plan.tail_count > 0 {
971 match &func.block(shape.header).terminator {
972 Some(Terminator::CondBranch { false_dest, .. }) => Some(*false_dest),
973 _ => None,
974 }
975 } else {
976 None
977 };
978
979 // 1. Replace the step `iadd iv, 1` constant operand with V (using
980 // a fresh ConstInt to avoid clobbering shared `1` constants).
981 let new_step_const = func.next_value_id();
982 let step_const_ty = IrType::Int(plan.iv_int_width);
983 func.register_type(new_step_const, step_const_ty.clone());
984 let body_block = func.block_mut(shape.body);
985 // Insert the new const at the top of the body block.
986 body_block.insts.insert(
987 0,
988 Inst {
989 id: new_step_const,
990 kind: InstKind::ConstInt(plan.lanes as i128, plan.iv_int_width),
991 ty: step_const_ty,
992 span: plan.span,
993 },
994 );
995 // Update the iadd to reference the new const.
996 if let Some(step_inst) = body_block
997 .insts
998 .iter_mut()
999 .find(|inst| inst.id == plan.step_iadd)
1000 {
1001 if let InstKind::IAdd(ref mut l, ref mut r) = step_inst.kind {
1002 if *l == plan.step_const {
1003 *l = new_step_const;
1004 }
1005 if *r == plan.step_const {
1006 *r = new_step_const;
1007 }
1008 }
1009 }
1010
1011 // 2. For each statement, vectorize: rewrite array loads to VLoads,
1012 // emit any required preheader VBroadcasts, rewrite the binop
1013 // into a v-op, and finally rewrite the store into a VStore.
1014 for stmt in plan.statements.clone() {
1015 for op in op_operands(&stmt.op) {
1016 rewrite_array_load(func, shape.body, op, &v_ty);
1017 }
1018 let (lhs_subst, rhs_subst) = match &stmt.op {
1019 BodyOp::Copy { .. } | BodyOp::Unary { .. } => (None, None),
1020 BodyOp::Binop { lhs, rhs, .. } => (
1021 broadcast_if_invariant(func, shape.preheader, lhs, &v_ty, plan.span),
1022 broadcast_if_invariant(func, shape.preheader, rhs, &v_ty, plan.span),
1023 ),
1024 BodyOp::Fma { .. } => (None, None),
1025 };
1026 let fma_subst = if let BodyOp::Fma { a, b, c, .. } = &stmt.op {
1027 Some((
1028 broadcast_if_invariant(func, shape.preheader, a, &v_ty, plan.span),
1029 broadcast_if_invariant(func, shape.preheader, b, &v_ty, plan.span),
1030 broadcast_if_invariant(func, shape.preheader, c, &v_ty, plan.span),
1031 ))
1032 } else {
1033 None
1034 };
1035
1036 if let BodyOp::Unary {
1037 unary_id,
1038 kind: unary_kind,
1039 ..
1040 } = &stmt.op
1041 {
1042 let body_block = func.block_mut(shape.body);
1043 if let Some(inst) = body_block.insts.iter_mut().find(|i| i.id == *unary_id) {
1044 let new_kind = match (inst.kind.clone(), unary_kind) {
1045 (InstKind::INeg(s), UnaryKind::Neg)
1046 | (InstKind::FNeg(s), UnaryKind::Neg) => InstKind::VNeg(s),
1047 (InstKind::FAbs(s), UnaryKind::Abs) => InstKind::VAbs(s),
1048 (InstKind::FSqrt(s), UnaryKind::Sqrt) => InstKind::VSqrt(s),
1049 _ => inst.kind.clone(),
1050 };
1051 inst.kind = new_kind;
1052 inst.ty = v_ty.clone();
1053 }
1054 func.register_type(*unary_id, v_ty.clone());
1055 }
1056
1057 if let BodyOp::Binop {
1058 binop_id,
1059 kind: binop_kind,
1060 ..
1061 } = &stmt.op
1062 {
1063 let body_block = func.block_mut(shape.body);
1064 if let Some(inst) = body_block.insts.iter_mut().find(|i| i.id == *binop_id) {
1065 let new_kind = match (inst.kind.clone(), binop_kind) {
1066 (InstKind::IAdd(l, r), BinaryKind::Add)
1067 | (InstKind::FAdd(l, r), BinaryKind::Add) => {
1068 InstKind::VAdd(lhs_subst.unwrap_or(l), rhs_subst.unwrap_or(r))
1069 }
1070 (InstKind::ISub(l, r), BinaryKind::Sub)
1071 | (InstKind::FSub(l, r), BinaryKind::Sub) => {
1072 InstKind::VSub(lhs_subst.unwrap_or(l), rhs_subst.unwrap_or(r))
1073 }
1074 (InstKind::IMul(l, r), BinaryKind::Mul)
1075 | (InstKind::FMul(l, r), BinaryKind::Mul) => {
1076 InstKind::VMul(lhs_subst.unwrap_or(l), rhs_subst.unwrap_or(r))
1077 }
1078 (InstKind::FDiv(l, r), BinaryKind::Div) => {
1079 InstKind::VDiv(lhs_subst.unwrap_or(l), rhs_subst.unwrap_or(r))
1080 }
1081 (InstKind::Select(_, t, f), BinaryKind::Max) => {
1082 InstKind::VMax(lhs_subst.unwrap_or(t), rhs_subst.unwrap_or(f))
1083 }
1084 (InstKind::Select(_, t, f), BinaryKind::Min) => {
1085 InstKind::VMin(lhs_subst.unwrap_or(t), rhs_subst.unwrap_or(f))
1086 }
1087 _ => inst.kind.clone(),
1088 };
1089 inst.kind = new_kind;
1090 inst.ty = v_ty.clone();
1091 }
1092 func.register_type(*binop_id, v_ty.clone());
1093 }
1094
1095 if let BodyOp::Fma {
1096 fmul_id, fadd_id, ..
1097 } = &stmt.op
1098 {
1099 let (a_subst, b_subst, c_subst) = fma_subst.unwrap();
1100 let body_block = func.block_mut(shape.body);
1101 // Rewrite fmul to VMul (becomes dead — DCE will clean up;
1102 // we still rewrite to avoid leaving a scalar fmul whose
1103 // operands have been retyped to vectors).
1104 if let Some(inst) = body_block.insts.iter_mut().find(|i| i.id == *fmul_id) {
1105 if let InstKind::FMul(l, r) = inst.kind {
1106 inst.kind = InstKind::VMul(a_subst.unwrap_or(l), b_subst.unwrap_or(r));
1107 inst.ty = v_ty.clone();
1108 }
1109 }
1110 func.register_type(*fmul_id, v_ty.clone());
1111 // Rewrite fadd to VFma(a, b, c). Lookup fmul to recover
1112 // its (possibly subst'd) operands so VFma reads the
1113 // original / broadcast values rather than the dead VMul.
1114 let (a_v, b_v) = {
1115 let body_ro = func.block(shape.body);
1116 let fmul_inst = body_ro.insts.iter().find(|i| i.id == *fmul_id).unwrap();
1117 if let InstKind::VMul(l, r) = fmul_inst.kind {
1118 (l, r)
1119 } else {
1120 unreachable!()
1121 }
1122 };
1123 let body_block = func.block_mut(shape.body);
1124 if let Some(inst) = body_block.insts.iter_mut().find(|i| i.id == *fadd_id) {
1125 if let InstKind::FAdd(l, r) = inst.kind {
1126 let c = if l == *fmul_id { r } else { l };
1127 let c_final = c_subst.unwrap_or(c);
1128 inst.kind = InstKind::VFma(a_v, b_v, c_final);
1129 inst.ty = v_ty.clone();
1130 }
1131 }
1132 func.register_type(*fadd_id, v_ty.clone());
1133 }
1134
1135 let body_block = func.block_mut(shape.body);
1136 if let Some(inst) = body_block.insts.iter_mut().find(|i| i.id == stmt.store) {
1137 if let InstKind::Store(val, ptr) = inst.kind {
1138 inst.kind = InstKind::VStore(val, ptr);
1139 }
1140 }
1141 }
1142
1143 // 3. Scalar tail. If `tail_count` remainder iterations live at the
1144 // end of the loop, retarget the original icmp's bound to
1145 // `iv_init + head_count - 1` and peel the remaining scalar
1146 // iterations into the head of the exit block.
1147 if plan.tail_count > 0 {
1148 if let (Some(snapshot), Some(exit_block)) = (body_snapshot, exit_block_id) {
1149 apply_scalar_tail_peel(func, shape, &plan, &snapshot, exit_block);
1150 }
1151 }
1152 }
1153
1154 #[derive(Debug, Clone, Copy)]
1155 struct ThenBinop {
1156 inst_id: ValueId,
1157 kind: BinaryKind,
1158 /// When `binop_on_load_b == false`: the non-load_a operand. With
1159 /// `other_is_load_b == false`, this is a loop-invariant scalar
1160 /// to broadcast; with `other_is_load_b == true`, this is the
1161 /// second array's `then_load_b` value id (the b-array load
1162 /// defined in the then_block).
1163 /// When `binop_on_load_b == true`: this is the loop-invariant
1164 /// scalar paired with the b-array load (load_a is unused; e.g.
1165 /// `c = K + d` where d is the second array).
1166 scalar_v: ValueId,
1167 /// Whether the "main" load is on the LHS of the binop. The main
1168 /// load is load_a unless `binop_on_load_b == true`, in which
1169 /// case it is load_b.
1170 load_on_lhs: bool,
1171 /// True iff the non-load_a operand is a second array load (`c = a + b`).
1172 /// Always false when `binop_on_load_b == true`.
1173 other_is_load_b: bool,
1174 /// True iff the binop's main load is load_b (`c = K + d` where d
1175 /// is the second array). When false, the main load is load_a
1176 /// (current default — `c = a + K` or `c = a + b`).
1177 binop_on_load_b: bool,
1178 }
1179
1180 /// Vectorizable WHERE-block plan: one conditional store guarded by
1181 /// a scalar fcmp/icmp predicate. Only the simplest shape is handled
1182 /// for now: the store value is a load of the same pointer used in
1183 /// the predicate (`b(i) = a(i)` under `where (a(i) op K)`).
1184 #[derive(Debug, Clone)]
1185 struct WherePlan {
1186 lanes: u8,
1187 elem_ty: IrType,
1188 /// The load in the body block that feeds the predicate.
1189 load_a_id: ValueId,
1190 /// The cmp inst id (in body block).
1191 cmp_id: ValueId,
1192 /// Whether the cmp is fcmp (true) or icmp (false).
1193 cmp_is_float: bool,
1194 /// The cmp's CmpOp.
1195 cmp_op: CmpOp,
1196 /// The threshold operand of the cmp (the other side, not load_a).
1197 /// Must be loop-invariant.
1198 threshold_v: ValueId,
1199 /// Whether `load_a` is on the LHS of the cmp.
1200 load_on_lhs: bool,
1201 /// The conditional Store in the then_block.
1202 store_id: ValueId,
1203 /// The redundant Load in the then_block (same ptr as load_a) —
1204 /// will be dropped during rewrite.
1205 then_load_id: Option<ValueId>,
1206 /// Optional unary applied to the loaded value before storing
1207 /// (`b = -a`, `b = abs(a)`, `b = sqrt(a)`). When `Some(uid, kind,
1208 /// on_load_b)`, `uid` is the inst id in then_block, `kind` is
1209 /// the unary kind, and `on_load_b == true` means the unary is
1210 /// applied to the second-array load (`c = -d`) rather than
1211 /// load_a.
1212 then_unary: Option<(ValueId, UnaryKind, bool)>,
1213 /// Optional binop applied to the loaded value with a loop-invariant
1214 /// scalar (`b = a + K`, `b = a * scale`, etc.). The `scalar_v` is
1215 /// the invariant operand (will be broadcast in preheader);
1216 /// `load_on_lhs` indicates whether the load is the LHS of the
1217 /// binop (to preserve operand order for non-commutative ops like
1218 /// Sub/Div).
1219 then_binop: Option<ThenBinop>,
1220 /// Optional loop-invariant scalar that is stored directly
1221 /// (`where (cond) b = K`). When set, the true arm of the vselect
1222 /// is `VBroadcast(K)`; no load_a is consumed by the store.
1223 then_const: Option<ValueId>,
1224 /// When `then_binop.other_is_load_b == true`, this holds the
1225 /// b-array's GEP ptr (so the apply path can emit a VLoad on it).
1226 b_ptr_id: Option<ValueId>,
1227 /// When the WHERE has an ELSEWHERE arm, the value to store in
1228 /// the false-mask lanes. Currently supports a loop-invariant
1229 /// scalar (`elsewhere; b = K; end where`) — broadcast in the
1230 /// preheader and used as vselect's false arm in lieu of the
1231 /// dest's prior value.
1232 else_const: Option<ValueId>,
1233 /// When ELSEWHERE loads from a different array
1234 /// (`elsewhere; c = d; end where`), this is the body-defined
1235 /// GEP ptr to load via VLoad for the false-mask lanes.
1236 else_load_ptr: Option<ValueId>,
1237 /// Unary lifted from `elsewhere; b = -d` / `abs(d)` / `sqrt(d)`.
1238 /// When `Some`, the apply path applies a V-unary to the
1239 /// else_load_ptr's vload before feeding vselect's false arm.
1240 else_unary: Option<UnaryKind>,
1241 /// Binop lifted from `elsewhere; b = d + K` / `d * scale`.
1242 /// `(BinaryKind, scalar_v, load_on_lhs)`.
1243 else_binop: Option<(BinaryKind, ValueId, bool)>,
1244 /// The dest pointer GEP (computed in body block).
1245 dest_ptr_id: ValueId,
1246 /// Source array's access shape (where load_a reads from).
1247 src_access: ArrayAccess,
1248 /// Destination array's access shape (where the store writes to).
1249 dest_access: ArrayAccess,
1250 span: crate::lexer::Span,
1251 }
1252
1253 fn build_where_plan(func: &Function, shape: &WhereLoop) -> Option<WherePlan> {
1254 let body_block = func.block(shape.body);
1255 let then_block = func.block(shape.then_block);
1256 // Reject calls in body or then.
1257 if body_block
1258 .insts
1259 .iter()
1260 .chain(then_block.insts.iter())
1261 .any(|inst| matches!(inst.kind, InstKind::Call(..) | InstKind::RuntimeCall(..)))
1262 {
1263 return None;
1264 }
1265 // Find the body's cmp (terminator's cond) and its associated load_a.
1266 let cond_id = match &body_block.terminator {
1267 Some(Terminator::CondBranch { cond, .. }) => *cond,
1268 _ => return None,
1269 };
1270 let cmp_inst = body_block.insts.iter().find(|i| i.id == cond_id)?;
1271 let (cmp_op, lhs_v, rhs_v, cmp_is_float) = match cmp_inst.kind {
1272 InstKind::FCmp(op, l, r) => (op, l, r, true),
1273 InstKind::ICmp(op, l, r) => (op, l, r, false),
1274 _ => return None,
1275 };
1276 // One of {lhs_v, rhs_v} must be a Load in body block; the other
1277 // must be loop-invariant (typically a ConstFloat / ConstInt).
1278 let body_loads: Vec<&Inst> = body_block
1279 .insts
1280 .iter()
1281 .filter(|i| matches!(i.kind, InstKind::Load(_)))
1282 .collect();
1283 if body_loads.len() != 1 {
1284 return None;
1285 }
1286 let load_a = body_loads[0];
1287 let load_a_id = load_a.id;
1288 let load_a_ptr = match load_a.kind {
1289 InstKind::Load(p) => p,
1290 _ => return None,
1291 };
1292 let (threshold_v, load_on_lhs) = if lhs_v == load_a_id {
1293 (rhs_v, true)
1294 } else if rhs_v == load_a_id {
1295 (lhs_v, false)
1296 } else {
1297 return None;
1298 };
1299 // Threshold must be defined OUTSIDE the loop body (loop-invariant).
1300 // Conservative: require it to be a Const* in the function (any block),
1301 // not defined in body, then, or incr.
1302 let body_ids: HashSet<ValueId> = body_block.insts.iter().map(|i| i.id).collect();
1303 let then_ids: HashSet<ValueId> = then_block.insts.iter().map(|i| i.id).collect();
1304 let incr_ids: HashSet<ValueId> = func
1305 .block(shape.incr_block)
1306 .insts
1307 .iter()
1308 .map(|i| i.id)
1309 .collect();
1310 if body_ids.contains(&threshold_v)
1311 || then_ids.contains(&threshold_v)
1312 || incr_ids.contains(&threshold_v)
1313 {
1314 return None;
1315 }
1316 // Source array access.
1317 let src_access = classify_array_access(func, load_a_ptr, shape.iv_param)?;
1318 // Find the dest pointer GEP (computed in body block) and the store.
1319 // The then block has the conditional store; dest_ptr must be a
1320 // GEP defined in body.
1321 let body_geps: Vec<&Inst> = body_block
1322 .insts
1323 .iter()
1324 .filter(|i| matches!(i.kind, InstKind::GetElementPtr(..)))
1325 .collect();
1326 if body_geps.is_empty() {
1327 return None;
1328 }
1329 // Walk then-block for the store. Expect optionally a redundant
1330 // Load (same ptr as load_a), optionally a second array Load (a
1331 // different body-defined ptr — array+array body), optionally a
1332 // unary (FNeg/FAbs/FSqrt/INeg) OR a binop with an invariant
1333 // scalar OR a binop with the second array load, and exactly one
1334 // Store.
1335 let body_gep_ids: HashSet<ValueId> =
1336 body_geps.iter().map(|i| i.id).collect();
1337 let mut store_id = None;
1338 let mut then_load_id = None;
1339 let mut then_load_b: Option<(ValueId, ValueId)> = None;
1340 // Unary tracking: (inst_id, kind, on_load_b). on_load_b == true
1341 // means the unary is applied to the second-array load (`c = -d`)
1342 // rather than load_a.
1343 let mut then_unary: Option<(ValueId, UnaryKind, bool)> = None;
1344 let mut then_binop: Option<ThenBinop> = None;
1345 let mut store_value = None;
1346 let mut store_ptr = None;
1347 for inst in &then_block.insts {
1348 let is_load_alias = |v: ValueId| v == load_a_id || Some(v) == then_load_id;
1349 let is_load_b = |v: ValueId| then_load_b.map(|(id, _)| id) == Some(v);
1350 match inst.kind {
1351 InstKind::Load(p) if p == load_a_ptr => {
1352 if then_load_id.is_some() {
1353 return None;
1354 }
1355 then_load_id = Some(inst.id);
1356 }
1357 InstKind::Load(p) if body_gep_ids.contains(&p) => {
1358 // Second array load — array+array body (`c = a + b`).
1359 // Must be a different ptr than load_a's. We accept at
1360 // most one such load.
1361 if then_load_b.is_some() {
1362 return None;
1363 }
1364 then_load_b = Some((inst.id, p));
1365 }
1366 InstKind::FNeg(src) | InstKind::INeg(src) => {
1367 if then_unary.is_some() || then_binop.is_some() {
1368 return None;
1369 }
1370 let on_load_b = if is_load_alias(src) {
1371 false
1372 } else if is_load_b(src) {
1373 true
1374 } else {
1375 return None;
1376 };
1377 then_unary = Some((inst.id, UnaryKind::Neg, on_load_b));
1378 }
1379 InstKind::FAbs(src) => {
1380 if then_unary.is_some() || then_binop.is_some() {
1381 return None;
1382 }
1383 let on_load_b = if is_load_alias(src) {
1384 false
1385 } else if is_load_b(src) {
1386 true
1387 } else {
1388 return None;
1389 };
1390 then_unary = Some((inst.id, UnaryKind::Abs, on_load_b));
1391 }
1392 InstKind::FSqrt(src) => {
1393 if then_unary.is_some() || then_binop.is_some() {
1394 return None;
1395 }
1396 let on_load_b = if is_load_alias(src) {
1397 false
1398 } else if is_load_b(src) {
1399 true
1400 } else {
1401 return None;
1402 };
1403 then_unary = Some((inst.id, UnaryKind::Sqrt, on_load_b));
1404 }
1405 InstKind::IAdd(l, r)
1406 | InstKind::ISub(l, r)
1407 | InstKind::IMul(l, r)
1408 | InstKind::FAdd(l, r)
1409 | InstKind::FSub(l, r)
1410 | InstKind::FMul(l, r)
1411 | InstKind::FDiv(l, r) => {
1412 if then_unary.is_some() || then_binop.is_some() {
1413 return None;
1414 }
1415 let kind = match inst.kind {
1416 InstKind::IAdd(..) | InstKind::FAdd(..) => BinaryKind::Add,
1417 InstKind::ISub(..) | InstKind::FSub(..) => BinaryKind::Sub,
1418 InstKind::IMul(..) | InstKind::FMul(..) => BinaryKind::Mul,
1419 InstKind::FDiv(..) => BinaryKind::Div,
1420 _ => unreachable!(),
1421 };
1422 // Three accepted shapes:
1423 // (i) binop(load_a, scalar) — current default
1424 // (ii) binop(load_a, load_b) — array+array body
1425 // (iii) binop(load_b, scalar) — `c = K + d` where d
1426 // is a second array (load_a only feeds cmp).
1427 let (load_on_lhs, scalar_v, binop_on_load_b) = if is_load_alias(l) {
1428 (true, r, false)
1429 } else if is_load_alias(r) {
1430 (false, l, false)
1431 } else if is_load_b(l) {
1432 (true, r, true)
1433 } else if is_load_b(r) {
1434 (false, l, true)
1435 } else {
1436 return None;
1437 };
1438 let other_is_load_b = if binop_on_load_b {
1439 false
1440 } else {
1441 is_load_b(scalar_v)
1442 };
1443 then_binop = Some(ThenBinop {
1444 inst_id: inst.id,
1445 kind,
1446 scalar_v,
1447 load_on_lhs,
1448 other_is_load_b,
1449 binop_on_load_b,
1450 });
1451 }
1452 InstKind::Store(v, p) => {
1453 if store_id.is_some() {
1454 return None;
1455 }
1456 store_id = Some(inst.id);
1457 store_value = Some(v);
1458 store_ptr = Some(p);
1459 }
1460 _ => return None,
1461 }
1462 }
1463 let store_id = store_id?;
1464 let store_value = store_value?;
1465 let store_ptr = store_ptr?;
1466 // The store value must be either: load_a, the redundant then-load,
1467 // the then-block unary, the then-block binop, or a loop-invariant
1468 // scalar (typically a literal constant — `where (cond) b = K`).
1469 let unary_id = then_unary.map(|(id, _, _)| id);
1470 let binop_id = then_binop.map(|b| b.inst_id);
1471 let mut then_const: Option<ValueId> = None;
1472 if store_value != load_a_id
1473 && Some(store_value) != then_load_id
1474 && Some(store_value) != unary_id
1475 && Some(store_value) != binop_id
1476 {
1477 // Accept iff the store value is loop-invariant (defined
1478 // outside body / then / incr). The scalar will be broadcast
1479 // in the preheader and routed through vselect's true arm.
1480 if body_ids.contains(&store_value)
1481 || then_ids.contains(&store_value)
1482 || incr_ids.contains(&store_value)
1483 {
1484 return None;
1485 }
1486 // Element type of the store value must match the dest array
1487 // element type — defer the check until dest_access is known
1488 // (validated below alongside src/dest type match).
1489 then_const = Some(store_value);
1490 }
1491 // For binop: scalar operand must be loop-invariant (not defined
1492 // in body, then, or incr). The `other_is_load_b == true` case
1493 // (a + b) doesn't have a scalar; everything else does. FDiv is
1494 // float-only.
1495 if let Some(b) = then_binop {
1496 if !b.other_is_load_b
1497 && (body_ids.contains(&b.scalar_v)
1498 || then_ids.contains(&b.scalar_v)
1499 || incr_ids.contains(&b.scalar_v))
1500 {
1501 return None;
1502 }
1503 if matches!(b.kind, BinaryKind::Div) && !matches!(src_access.elem_ty, IrType::Float(_)) {
1504 return None;
1505 }
1506 }
1507 // FSqrt is float-only; INeg is int-only. The dest elem type
1508 // must match.
1509 if let Some((_, k, _)) = then_unary {
1510 match (&src_access.elem_ty, k) {
1511 (IrType::Float(_), UnaryKind::Neg)
1512 | (IrType::Float(_), UnaryKind::Abs)
1513 | (IrType::Float(_), UnaryKind::Sqrt)
1514 | (IrType::Int(_), UnaryKind::Neg) => {}
1515 _ => return None,
1516 }
1517 }
1518 let dest_access = classify_array_access(func, store_ptr, shape.iv_param)?;
1519 // Both src and dest must cover the full array.
1520 let trip = shape.iv_bound.checked_sub(shape.iv_init)?.checked_add(1)?;
1521 let src_upper = src_access
1522 .lower
1523 .checked_add(src_access.len as i64)
1524 .and_then(|v| v.checked_sub(1))?;
1525 let dest_upper = dest_access
1526 .lower
1527 .checked_add(dest_access.len as i64)
1528 .and_then(|v| v.checked_sub(1))?;
1529 if shape.iv_init != src_access.lower
1530 || shape.iv_init != dest_access.lower
1531 || shape.iv_bound != src_upper
1532 || shape.iv_bound != dest_upper
1533 {
1534 return None;
1535 }
1536 if src_access.elem_ty != dest_access.elem_ty {
1537 return None;
1538 }
1539 // ELSEWHERE arm: walk the else_block (when present). Shapes:
1540 // (a) `Store(invariant_const, dest_ptr)` — broadcast-in-preheader.
1541 // (b) `Load(body_gep_ptr); Store(load_val, dest_ptr)` — VLoad
1542 // on the body-defined ptr (e.g., `elsewhere; c = d`).
1543 // (c) `Load(p); FNeg/FAbs/FSqrt/INeg(load); Store(unary, dest)`.
1544 // (d) `Load(p); binop(load, K); Store(binop, dest)` where K is
1545 // loop-invariant.
1546 type ElseArmInfo = (
1547 Option<ValueId>,
1548 Option<ValueId>,
1549 Option<UnaryKind>,
1550 Option<(BinaryKind, ValueId, bool)>,
1551 );
1552 let (else_const, else_load_ptr, else_unary, else_binop): ElseArmInfo
1553 = if let Some(else_blk_id) = shape.else_block {
1554 let else_blk = func.block(else_blk_id);
1555 if else_blk
1556 .insts
1557 .iter()
1558 .any(|inst| matches!(inst.kind, InstKind::Call(..) | InstKind::RuntimeCall(..)))
1559 {
1560 return None;
1561 }
1562 let mut else_load: Option<(ValueId, ValueId)> = None;
1563 let mut e_unary: Option<(ValueId, UnaryKind)> = None;
1564 let mut e_binop: Option<(ValueId, BinaryKind, ValueId, bool)> = None;
1565 let mut else_store: Option<(ValueId, ValueId)> = None;
1566 for inst in &else_blk.insts {
1567 let is_else_load = |v: ValueId| else_load.map(|(id, _)| id) == Some(v);
1568 match inst.kind {
1569 InstKind::Load(p) if body_gep_ids.contains(&p) => {
1570 if else_load.is_some() {
1571 return None;
1572 }
1573 else_load = Some((inst.id, p));
1574 }
1575 InstKind::FNeg(src) | InstKind::INeg(src) => {
1576 if e_unary.is_some() || e_binop.is_some() || !is_else_load(src) {
1577 return None;
1578 }
1579 e_unary = Some((inst.id, UnaryKind::Neg));
1580 }
1581 InstKind::FAbs(src) => {
1582 if e_unary.is_some() || e_binop.is_some() || !is_else_load(src) {
1583 return None;
1584 }
1585 e_unary = Some((inst.id, UnaryKind::Abs));
1586 }
1587 InstKind::FSqrt(src) => {
1588 if e_unary.is_some() || e_binop.is_some() || !is_else_load(src) {
1589 return None;
1590 }
1591 e_unary = Some((inst.id, UnaryKind::Sqrt));
1592 }
1593 InstKind::IAdd(l, r)
1594 | InstKind::ISub(l, r)
1595 | InstKind::IMul(l, r)
1596 | InstKind::FAdd(l, r)
1597 | InstKind::FSub(l, r)
1598 | InstKind::FMul(l, r)
1599 | InstKind::FDiv(l, r) => {
1600 if e_unary.is_some() || e_binop.is_some() {
1601 return None;
1602 }
1603 let kind = match inst.kind {
1604 InstKind::IAdd(..) | InstKind::FAdd(..) => BinaryKind::Add,
1605 InstKind::ISub(..) | InstKind::FSub(..) => BinaryKind::Sub,
1606 InstKind::IMul(..) | InstKind::FMul(..) => BinaryKind::Mul,
1607 InstKind::FDiv(..) => BinaryKind::Div,
1608 _ => unreachable!(),
1609 };
1610 let (load_on_lhs, scalar_v) = if is_else_load(l) {
1611 (true, r)
1612 } else if is_else_load(r) {
1613 (false, l)
1614 } else {
1615 return None;
1616 };
1617 e_binop = Some((inst.id, kind, scalar_v, load_on_lhs));
1618 }
1619 InstKind::Store(v, p) => {
1620 if else_store.is_some() {
1621 return None;
1622 }
1623 else_store = Some((v, p));
1624 }
1625 _ => return None,
1626 }
1627 }
1628 let (else_v, else_p) = else_store?;
1629 if else_p != store_ptr {
1630 return None;
1631 }
1632 let else_ids: HashSet<ValueId> = else_blk.insts.iter().map(|i| i.id).collect();
1633 let unary_id = e_unary.map(|(id, _)| id);
1634 let binop_id = e_binop.map(|(id, _, _, _)| id);
1635 // Determine the case via the store_value.
1636 if let Some((load_id, load_ptr)) = else_load {
1637 // Validate the load's access shape covers the full span.
1638 let acc = classify_array_access(func, load_ptr, shape.iv_param)?;
1639 let upper = acc
1640 .lower
1641 .checked_add(acc.len as i64)
1642 .and_then(|v| v.checked_sub(1))?;
1643 if shape.iv_init != acc.lower
1644 || shape.iv_bound != upper
1645 || acc.elem_ty != src_access.elem_ty
1646 {
1647 return None;
1648 }
1649 if Some(else_v) == unary_id {
1650 // Case (c): unary on load.
1651 let (_, kind) = e_unary.unwrap();
1652 match (&src_access.elem_ty, kind) {
1653 (IrType::Float(_), UnaryKind::Neg)
1654 | (IrType::Float(_), UnaryKind::Abs)
1655 | (IrType::Float(_), UnaryKind::Sqrt)
1656 | (IrType::Int(_), UnaryKind::Neg) => {}
1657 _ => return None,
1658 }
1659 (None, Some(load_ptr), Some(kind), None)
1660 } else if Some(else_v) == binop_id {
1661 // Case (d): binop on (load, invariant_scalar).
1662 let (_, kind, scalar_v, load_on_lhs) = e_binop.unwrap();
1663 if body_ids.contains(&scalar_v)
1664 || then_ids.contains(&scalar_v)
1665 || else_ids.contains(&scalar_v)
1666 || incr_ids.contains(&scalar_v)
1667 {
1668 return None;
1669 }
1670 if matches!(kind, BinaryKind::Div)
1671 && !matches!(src_access.elem_ty, IrType::Float(_))
1672 {
1673 return None;
1674 }
1675 (None, Some(load_ptr), None, Some((kind, scalar_v, load_on_lhs)))
1676 } else if else_v == load_id {
1677 // Case (b): identity load.
1678 (None, Some(load_ptr), None, None)
1679 } else {
1680 return None;
1681 }
1682 } else {
1683 // No load in else_block — Case (a): invariant constant.
1684 if body_ids.contains(&else_v)
1685 || then_ids.contains(&else_v)
1686 || else_ids.contains(&else_v)
1687 || incr_ids.contains(&else_v)
1688 {
1689 return None;
1690 }
1691 (Some(else_v), None, None, None)
1692 }
1693 } else {
1694 (None, None, None, None)
1695 };
1696 // If the binop's other operand is a b-array load, OR the binop's
1697 // main load is on b, OR the unary is applied to a b-array load,
1698 // validate that b's access shape covers the same span and elem
1699 // type.
1700 let unary_on_b = then_unary.map(|(_, _, on_b)| on_b).unwrap_or(false);
1701 let binop_on_b_pair = then_binop
1702 .map(|b| b.other_is_load_b || b.binop_on_load_b)
1703 .unwrap_or(false);
1704 let b_ptr_id = if unary_on_b || binop_on_b_pair {
1705 let (_b_load_id, b_ptr) = then_load_b?;
1706 let b_access = classify_array_access(func, b_ptr, shape.iv_param)?;
1707 let b_upper = b_access
1708 .lower
1709 .checked_add(b_access.len as i64)
1710 .and_then(|v| v.checked_sub(1))?;
1711 if shape.iv_init != b_access.lower
1712 || shape.iv_bound != b_upper
1713 || b_access.elem_ty != src_access.elem_ty
1714 {
1715 return None;
1716 }
1717 Some(b_ptr)
1718 } else {
1719 None
1720 };
1721 let elem_ty = src_access.elem_ty.clone();
1722 let lanes = lane_count_for(&elem_ty)?;
1723 // Skip tail for v0: require trip divisible by lanes.
1724 if trip % (lanes as i64) != 0 {
1725 return None;
1726 }
1727 // FCmp requires float dest; ICmp requires int dest.
1728 match (&elem_ty, cmp_is_float) {
1729 (IrType::Float(_), true) | (IrType::Int(_), false) => {}
1730 _ => return None,
1731 }
1732 Some(WherePlan {
1733 lanes,
1734 elem_ty,
1735 load_a_id,
1736 cmp_id: cond_id,
1737 cmp_is_float,
1738 cmp_op,
1739 threshold_v,
1740 load_on_lhs,
1741 store_id,
1742 then_load_id,
1743 then_unary,
1744 then_binop,
1745 then_const,
1746 b_ptr_id,
1747 else_const,
1748 else_load_ptr,
1749 else_unary,
1750 else_binop,
1751 dest_ptr_id: store_ptr,
1752 src_access,
1753 dest_access,
1754 span: cmp_inst.span,
1755 })
1756 }
1757
1758 /// Rewrite a WHERE diamond into a vectorized straight-line body:
1759 /// body: vload a; vload b_old; v(f|i)cmp pred; vselect mask, va, vb_old;
1760 /// vstore result, b_ptr; br incr_block
1761 /// The original `then` block becomes unreachable.
1762 fn apply_where_plan(func: &mut Function, shape: &WhereLoop, plan: WherePlan) {
1763 let v_ty = IrType::Vector {
1764 elem: Box::new(plan.elem_ty.clone()),
1765 lanes: plan.lanes,
1766 };
1767 let span = plan.span;
1768
1769 // 1. Broadcast the threshold into the preheader (it's loop-invariant,
1770 // typically a const). Use VBroadcast so vfcmp/vicmp gets a
1771 // full vector lane.
1772 let bcast_id = {
1773 let preheader = func.block_mut(shape.preheader);
1774 let id = preheader.params.first().map(|_| ()).map_or_else(|| 0, |_| 0);
1775 let _ = id;
1776 let new_id = func.next_value_id();
1777 func.register_type(new_id, v_ty.clone());
1778 let preheader = func.block_mut(shape.preheader);
1779 // Insert just before the terminator branch.
1780 let pos = preheader.insts.len();
1781 preheader.insts.insert(
1782 pos,
1783 Inst {
1784 id: new_id,
1785 kind: InstKind::VBroadcast(plan.threshold_v),
1786 ty: v_ty.clone(),
1787 span,
1788 },
1789 );
1790 new_id
1791 };
1792
1793 // 2. Rewrite load_a (in body) to VLoad. Type changes from elem to vector.
1794 let load_a_ptr = {
1795 let body = func.block_mut(shape.body);
1796 let inst = body.insts.iter_mut().find(|i| i.id == plan.load_a_id).unwrap();
1797 let p = match inst.kind {
1798 InstKind::Load(p) => p,
1799 _ => unreachable!(),
1800 };
1801 inst.kind = InstKind::VLoad(p);
1802 inst.ty = v_ty.clone();
1803 p
1804 };
1805 func.register_type(plan.load_a_id, v_ty.clone());
1806 let _ = load_a_ptr;
1807
1808 // 3. In body block, after load_a, emit:
1809 // vload_b_old (only when no ELSEWHERE — we need the dest's
1810 // prior value for the masked-off lanes), v(f|i)cmp, optional
1811 // v-unary, vselect, vstore. The cmp+cond_br are dropped (we
1812 // replace the terminator below).
1813 let vcmp_id = func.next_value_id();
1814 func.register_type(vcmp_id, v_ty.clone());
1815 let vsel_id = func.next_value_id();
1816 func.register_type(vsel_id, v_ty.clone());
1817 let vstore_id = func.next_value_id();
1818 func.register_type(vstore_id, IrType::Void);
1819
1820 let mut new_insts: Vec<Inst> = Vec::new();
1821 // The false-mask arm: prefer ELSEWHERE-supplied values when
1822 // present, else reload the dest's prior value so masked-off lanes
1823 // are preserved.
1824 // else_const → VBroadcast(K) in preheader.
1825 // else_load_ptr → VLoad on a body-defined GEP (e.g. `c = d`).
1826 // neither → VLoad on dest_ptr_id (preserve old lanes).
1827 let false_arm_id = if let Some(else_v) = plan.else_const {
1828 let vk_id = func.next_value_id();
1829 func.register_type(vk_id, v_ty.clone());
1830 let preheader = func.block_mut(shape.preheader);
1831 let pos = preheader.insts.len();
1832 preheader.insts.insert(
1833 pos,
1834 Inst {
1835 id: vk_id,
1836 kind: InstKind::VBroadcast(else_v),
1837 ty: v_ty.clone(),
1838 span,
1839 },
1840 );
1841 vk_id
1842 } else if let Some(load_ptr) = plan.else_load_ptr {
1843 let vload_else_id = func.next_value_id();
1844 func.register_type(vload_else_id, v_ty.clone());
1845 new_insts.push(Inst {
1846 id: vload_else_id,
1847 kind: InstKind::VLoad(load_ptr),
1848 ty: v_ty.clone(),
1849 span,
1850 });
1851 // Apply unary or binop on the else load when present.
1852 if let Some(kind) = plan.else_unary {
1853 let vu_id = func.next_value_id();
1854 func.register_type(vu_id, v_ty.clone());
1855 let vu_kind = match kind {
1856 UnaryKind::Neg => InstKind::VNeg(vload_else_id),
1857 UnaryKind::Abs => InstKind::VAbs(vload_else_id),
1858 UnaryKind::Sqrt => InstKind::VSqrt(vload_else_id),
1859 };
1860 new_insts.push(Inst {
1861 id: vu_id,
1862 kind: vu_kind,
1863 ty: v_ty.clone(),
1864 span,
1865 });
1866 vu_id
1867 } else if let Some((kind, scalar_v, load_on_lhs)) = plan.else_binop {
1868 let vk_id = func.next_value_id();
1869 func.register_type(vk_id, v_ty.clone());
1870 let preheader = func.block_mut(shape.preheader);
1871 let pos = preheader.insts.len();
1872 preheader.insts.insert(
1873 pos,
1874 Inst {
1875 id: vk_id,
1876 kind: InstKind::VBroadcast(scalar_v),
1877 ty: v_ty.clone(),
1878 span,
1879 },
1880 );
1881 let (l_id, r_id) = if load_on_lhs {
1882 (vload_else_id, vk_id)
1883 } else {
1884 (vk_id, vload_else_id)
1885 };
1886 let vbin_id = func.next_value_id();
1887 func.register_type(vbin_id, v_ty.clone());
1888 let vbin_kind = match kind {
1889 BinaryKind::Add => InstKind::VAdd(l_id, r_id),
1890 BinaryKind::Sub => InstKind::VSub(l_id, r_id),
1891 BinaryKind::Mul => InstKind::VMul(l_id, r_id),
1892 BinaryKind::Div => InstKind::VDiv(l_id, r_id),
1893 BinaryKind::Min | BinaryKind::Max => InstKind::VAdd(l_id, r_id),
1894 };
1895 new_insts.push(Inst {
1896 id: vbin_id,
1897 kind: vbin_kind,
1898 ty: v_ty.clone(),
1899 span,
1900 });
1901 vbin_id
1902 } else {
1903 vload_else_id
1904 }
1905 } else {
1906 let vload_b_id = func.next_value_id();
1907 func.register_type(vload_b_id, v_ty.clone());
1908 new_insts.push(Inst {
1909 id: vload_b_id,
1910 kind: InstKind::VLoad(plan.dest_ptr_id),
1911 ty: v_ty.clone(),
1912 span,
1913 });
1914 vload_b_id
1915 };
1916 new_insts.push(Inst {
1917 id: vcmp_id,
1918 kind: if plan.cmp_is_float {
1919 if plan.load_on_lhs {
1920 InstKind::VFCmp(plan.cmp_op, plan.load_a_id, bcast_id)
1921 } else {
1922 InstKind::VFCmp(plan.cmp_op, bcast_id, plan.load_a_id)
1923 }
1924 } else if plan.load_on_lhs {
1925 InstKind::VICmp(plan.cmp_op, plan.load_a_id, bcast_id)
1926 } else {
1927 InstKind::VICmp(plan.cmp_op, bcast_id, plan.load_a_id)
1928 },
1929 ty: v_ty.clone(),
1930 span,
1931 });
1932 // If the WHERE body computes `b = unary(a)` or `b = a op K`,
1933 // emit the vector op on the vload_a value (broadcasting the
1934 // scalar K for the binop case) and use that as the vselect's
1935 // "true" arm.
1936 let true_arm_id = if let Some((_then_uid, kind, on_load_b)) = plan.then_unary {
1937 // The source vector for the unary: load_a (default) or a
1938 // VLoad on the b-array's ptr (`c = -d`).
1939 let src_vec_id = if on_load_b {
1940 let b_ptr = plan.b_ptr_id.expect("unary_on_load_b must have b_ptr_id");
1941 let vload_b_id = func.next_value_id();
1942 func.register_type(vload_b_id, v_ty.clone());
1943 new_insts.push(Inst {
1944 id: vload_b_id,
1945 kind: InstKind::VLoad(b_ptr),
1946 ty: v_ty.clone(),
1947 span,
1948 });
1949 vload_b_id
1950 } else {
1951 plan.load_a_id
1952 };
1953 let vu_id = func.next_value_id();
1954 func.register_type(vu_id, v_ty.clone());
1955 let vu_kind = match kind {
1956 UnaryKind::Neg => InstKind::VNeg(src_vec_id),
1957 UnaryKind::Abs => InstKind::VAbs(src_vec_id),
1958 UnaryKind::Sqrt => InstKind::VSqrt(src_vec_id),
1959 };
1960 new_insts.push(Inst {
1961 id: vu_id,
1962 kind: vu_kind,
1963 ty: v_ty.clone(),
1964 span,
1965 });
1966 vu_id
1967 } else if let Some(b) = plan.then_binop {
1968 // The "main load" is load_a unless `binop_on_load_b == true`,
1969 // in which case it's a fresh VLoad on b_ptr.
1970 let main_load_id = if b.binop_on_load_b {
1971 let b_ptr = plan.b_ptr_id.expect("binop_on_load_b must have b_ptr_id");
1972 let vload_b_id = func.next_value_id();
1973 func.register_type(vload_b_id, v_ty.clone());
1974 new_insts.push(Inst {
1975 id: vload_b_id,
1976 kind: InstKind::VLoad(b_ptr),
1977 ty: v_ty.clone(),
1978 span,
1979 });
1980 vload_b_id
1981 } else {
1982 plan.load_a_id
1983 };
1984 // Other operand: either a vload on the b-array's ptr (array+
1985 // array body) or a vbroadcast of a loop-invariant scalar.
1986 let other_v = if b.other_is_load_b {
1987 let b_ptr = plan.b_ptr_id.expect("load_b binop must have b_ptr_id");
1988 let vload_b_id = func.next_value_id();
1989 func.register_type(vload_b_id, v_ty.clone());
1990 new_insts.push(Inst {
1991 id: vload_b_id,
1992 kind: InstKind::VLoad(b_ptr),
1993 ty: v_ty.clone(),
1994 span,
1995 });
1996 vload_b_id
1997 } else {
1998 let vk_id = func.next_value_id();
1999 func.register_type(vk_id, v_ty.clone());
2000 let preheader = func.block_mut(shape.preheader);
2001 let pos = preheader.insts.len();
2002 preheader.insts.insert(
2003 pos,
2004 Inst {
2005 id: vk_id,
2006 kind: InstKind::VBroadcast(b.scalar_v),
2007 ty: v_ty.clone(),
2008 span,
2009 },
2010 );
2011 vk_id
2012 };
2013 // Compute the binop in body block, in original operand order.
2014 let (l_id, r_id) = if b.load_on_lhs {
2015 (main_load_id, other_v)
2016 } else {
2017 (other_v, main_load_id)
2018 };
2019 let vbin_id = func.next_value_id();
2020 func.register_type(vbin_id, v_ty.clone());
2021 let vbin_kind = match b.kind {
2022 BinaryKind::Add => InstKind::VAdd(l_id, r_id),
2023 BinaryKind::Sub => InstKind::VSub(l_id, r_id),
2024 BinaryKind::Mul => InstKind::VMul(l_id, r_id),
2025 BinaryKind::Div => InstKind::VDiv(l_id, r_id),
2026 // Min/Max not produced by then-binop walker (those are
2027 // recognized via Select, not directly).
2028 BinaryKind::Min | BinaryKind::Max => InstKind::VAdd(l_id, r_id),
2029 };
2030 new_insts.push(Inst {
2031 id: vbin_id,
2032 kind: vbin_kind,
2033 ty: v_ty.clone(),
2034 span,
2035 });
2036 vbin_id
2037 } else if let Some(k_scalar) = plan.then_const {
2038 // Broadcast the loop-invariant scalar in the preheader so the
2039 // vselect sees a full lane-vector of K's in its true arm.
2040 let vk_id = func.next_value_id();
2041 func.register_type(vk_id, v_ty.clone());
2042 let preheader = func.block_mut(shape.preheader);
2043 let pos = preheader.insts.len();
2044 preheader.insts.insert(
2045 pos,
2046 Inst {
2047 id: vk_id,
2048 kind: InstKind::VBroadcast(k_scalar),
2049 ty: v_ty.clone(),
2050 span,
2051 },
2052 );
2053 vk_id
2054 } else {
2055 plan.load_a_id
2056 };
2057 new_insts.push(Inst {
2058 id: vsel_id,
2059 kind: InstKind::VSelect(vcmp_id, true_arm_id, false_arm_id),
2060 ty: v_ty.clone(),
2061 span,
2062 });
2063 new_insts.push(Inst {
2064 id: vstore_id,
2065 kind: InstKind::VStore(vsel_id, plan.dest_ptr_id),
2066 ty: IrType::Void,
2067 span,
2068 });
2069
2070 // Drop the original cmp inst from the body (it's the cond_id) —
2071 // it'll be dead. Drop everything *after* load_a that we don't
2072 // need (the original cmp). For simplicity, walk the body, keep
2073 // load_a + its dependency chain (gep ptrs), drop the cmp.
2074 {
2075 let body = func.block_mut(shape.body);
2076 body.insts.retain(|i| i.id != plan.cmp_id);
2077 // Append the new vector ops at the end of the body.
2078 body.insts.extend(new_insts);
2079 // Replace cond_br terminator with unconditional br to incr.
2080 body.terminator = Some(Terminator::Branch(shape.incr_block, vec![]));
2081 }
2082
2083 // 4. Drop then-block: clear its insts and make it unreachable.
2084 // prune_unreachable will remove the block after the pass.
2085 {
2086 let then = func.block_mut(shape.then_block);
2087 then.insts.clear();
2088 then.terminator = Some(Terminator::Branch(shape.incr_block, vec![]));
2089 }
2090 // Same for the else_block (when ELSEWHERE was present).
2091 if let Some(else_id) = shape.else_block {
2092 let else_blk = func.block_mut(else_id);
2093 else_blk.insts.clear();
2094 else_blk.terminator = Some(Terminator::Branch(shape.incr_block, vec![]));
2095 }
2096
2097 // 5. Update the incr block's iadd to step by `lanes` instead of 1.
2098 let incr = func.block_mut(shape.incr_block);
2099 let step_id = match &incr.terminator {
2100 Some(Terminator::Branch(_, args)) if args.len() == 1 => args[0],
2101 _ => return,
2102 };
2103 let iadd_inst = incr.insts.iter().find(|i| i.id == step_id).cloned();
2104 let (iv_param, old_step_const, iv_int_width) = match iadd_inst {
2105 Some(inst) => match inst.kind {
2106 InstKind::IAdd(l, r) => {
2107 let (iv, k) = if l == shape.iv_param {
2108 (l, r)
2109 } else if r == shape.iv_param {
2110 (r, l)
2111 } else {
2112 return;
2113 };
2114 let width = match inst.ty {
2115 IrType::Int(w) => w,
2116 _ => return,
2117 };
2118 (iv, k, width)
2119 }
2120 _ => return,
2121 },
2122 _ => return,
2123 };
2124 let _ = iv_param;
2125 // Allocate a fresh ConstInt for the new step.
2126 let new_step = func.next_value_id();
2127 func.register_type(new_step, IrType::Int(iv_int_width));
2128 let incr = func.block_mut(shape.incr_block);
2129 incr.insts.insert(
2130 0,
2131 Inst {
2132 id: new_step,
2133 kind: InstKind::ConstInt(plan.lanes as i128, iv_int_width),
2134 ty: IrType::Int(iv_int_width),
2135 span,
2136 },
2137 );
2138 if let Some(inst) = incr.insts.iter_mut().find(|i| i.id == step_id) {
2139 if let InstKind::IAdd(l, r) = inst.kind {
2140 if l == old_step_const {
2141 inst.kind = InstKind::IAdd(new_step, r);
2142 } else if r == old_step_const {
2143 inst.kind = InstKind::IAdd(l, new_step);
2144 }
2145 }
2146 }
2147 }
2148
2149 /// Insert a fresh ConstInt for the head bound
2150 /// (`iv_init + head_count - 1`) into the preheader and rewire the
2151 /// original icmp's RHS to reference it; then peel `tail_count` scalar
2152 /// copies of the body into the top of the exit block, with the IV
2153 /// substituted by a constant per iteration.
2154 fn apply_scalar_tail_peel(
2155 func: &mut Function,
2156 shape: &CountedLoop,
2157 plan: &VectorPlan,
2158 body_snapshot: &[Inst],
2159 exit_block: BlockId,
2160 ) {
2161 let int_ty = IrType::Int(plan.iv_int_width);
2162
2163 // Insert the new bound const (iv_init + head_count - 1) at the
2164 // top of the preheader. It dominates the header's icmp.
2165 let new_bound = shape.iv_init + plan.head_count - 1;
2166 let new_bound_id = func.next_value_id();
2167 func.register_type(new_bound_id, int_ty.clone());
2168 let pre_block = func.block_mut(shape.preheader);
2169 pre_block.insts.insert(
2170 0,
2171 Inst {
2172 id: new_bound_id,
2173 kind: InstKind::ConstInt(new_bound as i128, plan.iv_int_width),
2174 ty: int_ty.clone(),
2175 span: plan.span,
2176 },
2177 );
2178
2179 // Rewrite the icmp's RHS to point at the new bound const.
2180 let header_block = func.block_mut(shape.header);
2181 if let Some(inst) = header_block
2182 .insts
2183 .iter_mut()
2184 .find(|i| i.id == shape.cond_id)
2185 {
2186 if let InstKind::ICmp(_, _, rhs) = &mut inst.kind {
2187 if *rhs == shape.bound_const_id {
2188 *rhs = new_bound_id;
2189 }
2190 }
2191 }
2192
2193 // Skip the step iadd in the snapshot: the peel doesn't need to
2194 // bump the IV.
2195 let step_inst_id = plan.step_iadd;
2196
2197 // Build a vector of `(new_inst_id, new_kind, ty, span)` per peel
2198 // iteration, then prepend them to the exit block's insts.
2199 let mut peeled: Vec<Inst> = Vec::new();
2200 for t in 0..plan.tail_count {
2201 let tail_iv = shape.iv_init + plan.head_count + t;
2202 let tail_iv_const_id = func.next_value_id();
2203 func.register_type(tail_iv_const_id, int_ty.clone());
2204 peeled.push(Inst {
2205 id: tail_iv_const_id,
2206 kind: InstKind::ConstInt(tail_iv as i128, plan.iv_int_width),
2207 ty: int_ty.clone(),
2208 span: plan.span,
2209 });
2210
2211 let mut val_map: HashMap<ValueId, ValueId> = HashMap::new();
2212 val_map.insert(shape.iv_param, tail_iv_const_id);
2213
2214 for inst in body_snapshot {
2215 // Skip the step iadd — peel iterations don't bump the IV.
2216 if inst.id == step_inst_id {
2217 continue;
2218 }
2219 let new_id = func.next_value_id();
2220 func.register_type(new_id, inst.ty.clone());
2221 let new_kind = remap_inst_kind(&inst.kind, &val_map);
2222 val_map.insert(inst.id, new_id);
2223 peeled.push(Inst {
2224 id: new_id,
2225 kind: new_kind,
2226 ty: inst.ty.clone(),
2227 span: inst.span,
2228 });
2229 }
2230 }
2231
2232 // Prepend peeled insts at the top of the exit block.
2233 let exit = func.block_mut(exit_block);
2234 let existing = std::mem::take(&mut exit.insts);
2235 let mut new_insts = peeled;
2236 new_insts.extend(existing);
2237 exit.insts = new_insts;
2238 }
2239
2240 /// Iterate the operands of a body op (one for `Copy`/`Unary`, two
2241 /// for `Binop`).
2242 fn op_operands(op: &BodyOp) -> Vec<&BinopOperand> {
2243 match op {
2244 BodyOp::Copy { source } | BodyOp::Unary { source, .. } => vec![source],
2245 BodyOp::Binop { lhs, rhs, .. } => vec![lhs, rhs],
2246 BodyOp::Fma { a, b, c, .. } => vec![a, b, c],
2247 }
2248 }
2249
2250 /// If `op` is an `ArrayLoad`, rewrite its scalar Load to a VLoad and
2251 /// register the load's type as the vector type.
2252 fn rewrite_array_load(
2253 func: &mut Function,
2254 body: BlockId,
2255 op: &BinopOperand,
2256 v_ty: &IrType,
2257 ) {
2258 let load_id = match op {
2259 BinopOperand::ArrayLoad(id) => *id,
2260 BinopOperand::InvariantScalar(_) => return,
2261 };
2262 let body_block = func.block_mut(body);
2263 if let Some(inst) = body_block.insts.iter_mut().find(|i| i.id == load_id) {
2264 if let InstKind::Load(ptr) = inst.kind {
2265 inst.kind = InstKind::VLoad(ptr);
2266 inst.ty = v_ty.clone();
2267 }
2268 }
2269 func.register_type(load_id, v_ty.clone());
2270 }
2271
2272 /// If `op` is an `InvariantScalar`, append a `VBroadcast` to the
2273 /// loop's preheader (just before its terminator) and return the
2274 /// resulting vector value. Returns `None` for `ArrayLoad` operands —
2275 /// those are rewritten in place by the load loop.
2276 fn broadcast_if_invariant(
2277 func: &mut Function,
2278 preheader: BlockId,
2279 op: &BinopOperand,
2280 v_ty: &IrType,
2281 span: crate::lexer::Span,
2282 ) -> Option<ValueId> {
2283 let scalar = match op {
2284 BinopOperand::InvariantScalar(v) => *v,
2285 BinopOperand::ArrayLoad(_) => return None,
2286 };
2287 let new_id = func.next_value_id();
2288 func.register_type(new_id, v_ty.clone());
2289 let pre_block = func.block_mut(preheader);
2290 // Insert the broadcast just before the preheader's terminator
2291 // (which is the unconditional branch into the header).
2292 let pos = pre_block.insts.len();
2293 pre_block.insts.insert(
2294 pos,
2295 Inst {
2296 id: new_id,
2297 kind: InstKind::VBroadcast(scalar),
2298 ty: v_ty.clone(),
2299 span,
2300 },
2301 );
2302 Some(new_id)
2303 }
2304
2305 fn inst_map(func: &Function) -> HashMap<ValueId, &Inst> {
2306 func.blocks
2307 .iter()
2308 .flat_map(|block| block.insts.iter())
2309 .map(|inst| (inst.id, inst))
2310 .collect()
2311 }
2312
2313 /// What feeds the accumulator on each iteration. `Sum` is a single
2314 /// load; `Dot` multiplies two loads, then adds the product to the
2315 /// accumulator (i.e. dot-product fold).
2316 #[derive(Debug, Clone)]
2317 enum AccumulateSource {
2318 Sum {
2319 load_id: ValueId,
2320 },
2321 /// `acc' = acc + neg(load)` or `acc' = acc + abs(load)`. The
2322 /// pre-existing `Sum` rewriter rewrites the load → vload; we
2323 /// also rewrite the unary `INeg`/`FNeg` → `VNeg` and
2324 /// `FAbs` → `VAbs` so the pre-fold value flows through the
2325 /// vector lanes.
2326 SumWithUnary {
2327 load_id: ValueId,
2328 unary_id: ValueId,
2329 kind: UnaryKind,
2330 },
2331 Dot {
2332 imul_id: ValueId,
2333 load_a: ValueId,
2334 load_b: ValueId,
2335 },
2336 /// `acc' = acc + (a(i) - b(i))` — sum of differences (variance,
2337 /// MSE, L1-distance numerator). The body has two loads and one
2338 /// `ISub`/`FSub` feeding `IAdd`/`FAdd` into the accumulator.
2339 SumOfDiff {
2340 sub_id: ValueId,
2341 load_a: ValueId,
2342 load_b: ValueId,
2343 },
2344 }
2345
2346 /// What kind of accumulator combine the body performs.
2347 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
2348 enum ReductionKind {
2349 /// `acc' = acc + value` (or `acc - value`, treated as Sum after
2350 /// negation — not yet supported).
2351 Sum,
2352 /// `acc' = max(acc, load)` lowered as `select(icmp ge, acc, load)`.
2353 Max,
2354 /// `acc' = min(acc, load)` lowered as `select(icmp le, acc, load)`.
2355 Min,
2356 }
2357
2358 /// A sum-reduction loop:
2359 /// `do i = lo, hi; s = s + a(i); end do` (or fadd for floats).
2360 /// or a dot-product fold:
2361 /// `do i = lo, hi; s = s + a(i)*b(i); end do`.
2362 ///
2363 /// The loop header carries the IV and a scalar accumulator as block
2364 /// params. The accumulator escapes the loop and is reduced to a
2365 /// scalar `vreduce_sum` after the vectorized body.
2366 #[derive(Debug, Clone)]
2367 struct ReductionPlan {
2368 preheader: BlockId,
2369 header: BlockId,
2370 body: BlockId,
2371 /// Block reachable when the loop exits (false-dest of header
2372 /// cond_br).
2373 exit: BlockId,
2374 /// Block param indices in the header: `[iv_idx, acc_idx]`.
2375 iv_param: ValueId,
2376 acc_param: ValueId,
2377 acc_param_idx: usize,
2378 /// Scalar accumulator init value passed in the preheader's branch.
2379 acc_init: ValueId,
2380 /// What computes the per-iteration value to fold into `acc`.
2381 source: AccumulateSource,
2382 /// What combine op the body performs (sum / min / max).
2383 reduce: ReductionKind,
2384 /// Original `acc' = ...` instruction (the IAdd, FAdd, or
2385 /// `select(icmp, ...)` whose result feeds back into the header).
2386 /// For min/max the icmp is part of the rewrite too — we hoist
2387 /// the `select+icmp` pair into a single `vmin`/`vmax`.
2388 accumulate_id: ValueId,
2389 /// For Min/Max, the icmp instruction we'll discard during the
2390 /// rewrite (its result is dead once the select becomes vmin/vmax).
2391 cmp_id: Option<ValueId>,
2392 /// Original `iv' = iv + 1` instruction.
2393 step_iadd: ValueId,
2394 /// The `1` ConstInt operand of the iv step.
2395 step_const: ValueId,
2396 /// IV ConstInt width.
2397 iv_int_width: IntWidth,
2398 /// Element type (i32 / i64 / f32 / f64).
2399 elem_ty: IrType,
2400 lanes: u8,
2401 /// IV's lower bound (preheader passes this as the initial iv).
2402 iv_init: i64,
2403 /// Number of vector iterations × `lanes`. When `tail_count > 0`,
2404 /// the head runs vectorized for `head_count` iterations and the
2405 /// remaining `tail_count` iterations are peeled as scalar code
2406 /// after the post-loop `vreduce_*`.
2407 head_count: i64,
2408 tail_count: i64,
2409 /// Header's `icmp le|lt iv, hi_const` instruction id, plus the
2410 /// const id feeding its RHS. Needed when `tail_count > 0` to
2411 /// retarget the bound to `iv_init + head_count - 1`.
2412 cond_id: ValueId,
2413 bound_const_id: ValueId,
2414 span: crate::lexer::Span,
2415 }
2416
2417 fn detect_reduction_plan(
2418 func: &Function,
2419 lp: &NaturalLoop,
2420 preds: &HashMap<BlockId, Vec<BlockId>>,
2421 ) -> Option<ReductionPlan> {
2422 if lp.latches.len() != 1 || lp.body.len() != 2 {
2423 return None;
2424 }
2425 let header = lp.header;
2426 let body = lp.latches[0];
2427 if body == header {
2428 return None;
2429 }
2430
2431 let header_block = func.block(header);
2432 if header_block.params.len() != 2 {
2433 return None;
2434 }
2435 // Identify which param is the IV (int type, used as gep index)
2436 // and which is the accumulator. We require the IV to be param 0
2437 // in this MVP — Fortran's lowered form always emits IV first.
2438 let iv_param = header_block.params[0].id;
2439 let acc_param = header_block.params[1].id;
2440 let iv_int_width = match header_block.params[0].ty {
2441 IrType::Int(w) => w,
2442 _ => return None,
2443 };
2444 let acc_ty = header_block.params[1].ty.clone();
2445 let elem_ty = match acc_ty.clone() {
2446 IrType::Int(_) | IrType::Float(_) => acc_ty,
2447 _ => return None,
2448 };
2449 let lanes = lane_count_for(&elem_ty)?;
2450
2451 let preheader = find_preheader(func, lp, preds)?;
2452 let (iv_init, acc_init) = match &func.block(preheader).terminator {
2453 Some(Terminator::Branch(dest, args)) if *dest == header && args.len() == 2 => {
2454 (resolve_const_int(func, args[0])?, args[1])
2455 }
2456 _ => return None,
2457 };
2458
2459 // Header cond_br shape: `iv <= bound` → body, exit.
2460 let (cond_id, true_dest, false_dest, true_args, false_args) = match &header_block.terminator {
2461 Some(Terminator::CondBranch {
2462 cond,
2463 true_dest,
2464 true_args,
2465 false_dest,
2466 false_args,
2467 }) => (*cond, *true_dest, *false_dest, true_args, false_args),
2468 _ => return None,
2469 };
2470 if !true_args.is_empty()
2471 || !false_args.is_empty()
2472 || true_dest != body
2473 || lp.body.contains(&false_dest)
2474 {
2475 return None;
2476 }
2477 let exit = false_dest;
2478 let cond_inst = header_block.insts.iter().find(|inst| inst.id == cond_id)?;
2479 let (iv_bound, bound_const_id) = match cond_inst.kind {
2480 InstKind::ICmp(CmpOp::Le, lhs, rhs) if lhs == iv_param => {
2481 (resolve_const_int(func, rhs)?, rhs)
2482 }
2483 InstKind::ICmp(CmpOp::Lt, lhs, rhs) if lhs == iv_param => {
2484 (resolve_const_int(func, rhs)?.checked_sub(1)?, rhs)
2485 }
2486 _ => return None,
2487 };
2488
2489 let trip = iv_bound.checked_sub(iv_init).and_then(|d| d.checked_add(1))?;
2490 if trip <= 0 {
2491 return None;
2492 }
2493 let head_count = trip - (trip % lanes as i64);
2494 if head_count == 0 {
2495 return None;
2496 }
2497 let tail_count = trip - head_count;
2498
2499 // Body shape: load + iadd(acc, load) + iadd(iv, 1) + branch back.
2500 let body_block = func.block(body);
2501 if body_block
2502 .insts
2503 .iter()
2504 .any(|inst| matches!(inst.kind, InstKind::Call(..) | InstKind::RuntimeCall(..)))
2505 {
2506 return None;
2507 }
2508 let body_term_arg_iv;
2509 let body_term_arg_acc;
2510 match &body_block.terminator {
2511 Some(Terminator::Branch(dest, args)) if *dest == header && args.len() == 2 => {
2512 body_term_arg_iv = args[0];
2513 body_term_arg_acc = args[1];
2514 }
2515 _ => return None,
2516 }
2517
2518 let defs = inst_map(func);
2519 // The acc-update is one of:
2520 // acc' = acc + value (Sum / Dot, IAdd or FAdd)
2521 // acc' = select(icmp_ge_or_gt acc, value, acc, value) (Max)
2522 // acc' = select(icmp_le_or_lt acc, value, acc, value) (Min)
2523 let accumulate_inst = defs.get(&body_term_arg_acc)?;
2524 let (reduce, cmp_id, acc_lhs, acc_rhs) = match (&elem_ty, &accumulate_inst.kind) {
2525 (IrType::Int(_), InstKind::IAdd(l, r)) => (ReductionKind::Sum, None, *l, *r),
2526 (IrType::Float(_), InstKind::FAdd(l, r)) => (ReductionKind::Sum, None, *l, *r),
2527 (_, InstKind::Select(c, t, f)) => {
2528 // Look at the predicate: it must be `icmp <op> acc, load`
2529 // (or `icmp <op> load, acc`). The arms must be `(acc,
2530 // load)` for max/min (so the select picks acc when the
2531 // predicate is true).
2532 let cmp_inst = defs.get(c)?;
2533 let (cmp_op, cmp_a, cmp_b) = match cmp_inst.kind {
2534 InstKind::ICmp(op, a, b) => (op, a, b),
2535 InstKind::FCmp(op, a, b) => (op, a, b),
2536 _ => return None,
2537 };
2538 // Identify which side of the icmp / select is the acc
2539 // and infer Max vs Min.
2540 //
2541 // select(acc >= value, acc, value) → max(acc, value)
2542 // select(acc <= value, acc, value) → min(acc, value)
2543 // select(value >= acc, value, acc) → max(acc, value)
2544 // select(value <= acc, value, acc) → min(acc, value)
2545 //
2546 // For NEON SmaxV4S/SminV4S the operand order doesn't
2547 // matter (commutative).
2548 let (kind, acc_side, value_side) = if cmp_a == acc_param && *t == acc_param {
2549 let kind = match cmp_op {
2550 CmpOp::Ge | CmpOp::Gt => ReductionKind::Max,
2551 CmpOp::Le | CmpOp::Lt => ReductionKind::Min,
2552 _ => return None,
2553 };
2554 (kind, *t, *f)
2555 } else if cmp_b == acc_param && *t != acc_param && *f == acc_param {
2556 let kind = match cmp_op {
2557 CmpOp::Le | CmpOp::Lt => ReductionKind::Max,
2558 CmpOp::Ge | CmpOp::Gt => ReductionKind::Min,
2559 _ => return None,
2560 };
2561 (kind, *f, *t)
2562 } else {
2563 return None;
2564 };
2565 (kind, Some(cmp_inst.id), acc_side, value_side)
2566 }
2567 _ => return None,
2568 };
2569 // For Sum / Dot (IAdd/FAdd), one operand must be acc_param. For
2570 // Max/Min the `acc_lhs` is already the acc and `acc_rhs` is the
2571 // value (set up by the match arm above).
2572 let (accumulate_id, value_v) = if matches!(reduce, ReductionKind::Sum) {
2573 if acc_lhs == acc_param {
2574 (accumulate_inst.id, acc_rhs)
2575 } else if acc_rhs == acc_param {
2576 (accumulate_inst.id, acc_lhs)
2577 } else {
2578 return None;
2579 }
2580 } else {
2581 (accumulate_inst.id, acc_rhs)
2582 };
2583 let value_inst = defs.get(&value_v)?;
2584 // Classify `value_v` as a load (Sum / Min / Max), an
2585 // unary-of-load (Sum / Min / Max — folds to VNeg / VAbs), or
2586 // an imul/fmul of two loads (Sum-only Dot fold). The
2587 // dot-product fold is meaningless under min/max.
2588 if !matches!(reduce, ReductionKind::Sum)
2589 && !matches!(
2590 value_inst.kind,
2591 InstKind::Load(_)
2592 | InstKind::INeg(_)
2593 | InstKind::FNeg(_)
2594 | InstKind::FAbs(_)
2595 )
2596 {
2597 return None;
2598 }
2599 let source = match (&elem_ty, &value_inst.kind) {
2600 (_, InstKind::Load(load_ptr)) => {
2601 let access = classify_array_access(func, *load_ptr, iv_param)?;
2602 if access.elem_ty != elem_ty {
2603 return None;
2604 }
2605 let upper = access
2606 .lower
2607 .checked_add(access.len as i64)
2608 .and_then(|v| v.checked_sub(1))?;
2609 if iv_init != access.lower || iv_bound != upper {
2610 return None;
2611 }
2612 AccumulateSource::Sum {
2613 load_id: value_inst.id,
2614 }
2615 }
2616 // Reductions over `acc + (-load)` / `acc + abs(load)` (Sum)
2617 // or `max/min(acc, abs(load))` (Min/Max) — the unary
2618 // applies per-element and lifts cleanly to VNeg / VAbs.
2619 (IrType::Int(_), InstKind::INeg(load_v))
2620 | (IrType::Float(_), InstKind::FNeg(load_v))
2621 | (IrType::Float(_), InstKind::FAbs(load_v)) => {
2622 let unary_kind = match value_inst.kind {
2623 InstKind::FAbs(_) => UnaryKind::Abs,
2624 _ => UnaryKind::Neg,
2625 };
2626 let load_inst = defs.get(load_v)?;
2627 let load_ptr = match load_inst.kind {
2628 InstKind::Load(p) => p,
2629 _ => return None,
2630 };
2631 let access = classify_array_access(func, load_ptr, iv_param)?;
2632 if access.elem_ty != elem_ty {
2633 return None;
2634 }
2635 let upper = access
2636 .lower
2637 .checked_add(access.len as i64)
2638 .and_then(|v| v.checked_sub(1))?;
2639 if iv_init != access.lower || iv_bound != upper {
2640 return None;
2641 }
2642 AccumulateSource::SumWithUnary {
2643 load_id: load_inst.id,
2644 unary_id: value_inst.id,
2645 kind: unary_kind,
2646 }
2647 }
2648 (IrType::Int(_), InstKind::IMul(la, lb))
2649 | (IrType::Float(_), InstKind::FMul(la, lb)) => {
2650 let load_a_inst = defs.get(la)?;
2651 let load_b_inst = defs.get(lb)?;
2652 let InstKind::Load(ptr_a) = load_a_inst.kind else {
2653 return None;
2654 };
2655 let InstKind::Load(ptr_b) = load_b_inst.kind else {
2656 return None;
2657 };
2658 let acc_a = classify_array_access(func, ptr_a, iv_param)?;
2659 let acc_b = classify_array_access(func, ptr_b, iv_param)?;
2660 if acc_a.elem_ty != elem_ty || acc_b.elem_ty != elem_ty {
2661 return None;
2662 }
2663 let upper_a = acc_a
2664 .lower
2665 .checked_add(acc_a.len as i64)
2666 .and_then(|v| v.checked_sub(1))?;
2667 let upper_b = acc_b
2668 .lower
2669 .checked_add(acc_b.len as i64)
2670 .and_then(|v| v.checked_sub(1))?;
2671 if iv_init != acc_a.lower
2672 || iv_bound != upper_a
2673 || iv_init != acc_b.lower
2674 || iv_bound != upper_b
2675 {
2676 return None;
2677 }
2678 AccumulateSource::Dot {
2679 imul_id: value_inst.id,
2680 load_a: load_a_inst.id,
2681 load_b: load_b_inst.id,
2682 }
2683 }
2684 // `acc + (a(i) - b(i))` — sum of differences. Two loads, one
2685 // sub feeding the accumulator's add.
2686 (IrType::Int(_), InstKind::ISub(la, lb))
2687 | (IrType::Float(_), InstKind::FSub(la, lb)) => {
2688 let load_a_inst = defs.get(la)?;
2689 let load_b_inst = defs.get(lb)?;
2690 let InstKind::Load(ptr_a) = load_a_inst.kind else {
2691 return None;
2692 };
2693 let InstKind::Load(ptr_b) = load_b_inst.kind else {
2694 return None;
2695 };
2696 let acc_a = classify_array_access(func, ptr_a, iv_param)?;
2697 let acc_b = classify_array_access(func, ptr_b, iv_param)?;
2698 if acc_a.elem_ty != elem_ty || acc_b.elem_ty != elem_ty {
2699 return None;
2700 }
2701 let upper_a = acc_a
2702 .lower
2703 .checked_add(acc_a.len as i64)
2704 .and_then(|v| v.checked_sub(1))?;
2705 let upper_b = acc_b
2706 .lower
2707 .checked_add(acc_b.len as i64)
2708 .and_then(|v| v.checked_sub(1))?;
2709 if iv_init != acc_a.lower
2710 || iv_bound != upper_a
2711 || iv_init != acc_b.lower
2712 || iv_bound != upper_b
2713 {
2714 return None;
2715 }
2716 AccumulateSource::SumOfDiff {
2717 sub_id: value_inst.id,
2718 load_a: load_a_inst.id,
2719 load_b: load_b_inst.id,
2720 }
2721 }
2722 _ => return None,
2723 };
2724
2725 // The iv step.
2726 let step_inst = defs.get(&body_term_arg_iv)?;
2727 let (step_lhs, step_rhs) = match step_inst.kind {
2728 InstKind::IAdd(l, r) => (l, r),
2729 _ => return None,
2730 };
2731 let (step_const, _) = if step_lhs == iv_param {
2732 (step_rhs, resolve_const_int(func, step_rhs)?)
2733 } else if step_rhs == iv_param {
2734 (step_lhs, resolve_const_int(func, step_lhs)?)
2735 } else {
2736 return None;
2737 };
2738 let step_iadd = step_inst.id;
2739
2740 // Validate that `acc_param` doesn't have any *other* uses inside
2741 // the loop besides the accumulate inst (and, for Min/Max, the
2742 // companion icmp that we'll discard during rewrite).
2743 let acc_extra_uses: usize = func
2744 .blocks
2745 .iter()
2746 .filter(|b| lp.body.contains(&b.id))
2747 .flat_map(|b| b.insts.iter())
2748 .filter(|inst| inst.id != accumulate_id && Some(inst.id) != cmp_id)
2749 .filter(|inst| inst_uses(&inst.kind).contains(&acc_param))
2750 .count();
2751 if acc_extra_uses != 0 {
2752 return None;
2753 }
2754 // Min/Max codegen is wired for i32 (smaxv/sminv.4s + umov.s),
2755 // f32 (fmaxv/fminv.4s direct to scalar fp reg), and f64 (no
2756 // fmaxv.2d on NEON; fmaxp/fminp.2d gives the across-lane reduce
2757 // for the two f64 lanes).
2758 if !matches!(reduce, ReductionKind::Sum)
2759 && !matches!(
2760 elem_ty,
2761 IrType::Int(IntWidth::I32)
2762 | IrType::Float(FloatWidth::F32)
2763 | IrType::Float(FloatWidth::F64)
2764 )
2765 {
2766 return None;
2767 }
2768
2769 // The accumulate_inst result must not be used inside the loop
2770 // (other than as the body terminator's arg). All in-loop uses
2771 // would conflict with our vector rewrite.
2772 let acc_result_extra_uses: usize = func
2773 .blocks
2774 .iter()
2775 .filter(|b| lp.body.contains(&b.id))
2776 .flat_map(|b| b.insts.iter())
2777 .filter(|inst| inst_uses(&inst.kind).contains(&accumulate_id))
2778 .count();
2779 if acc_result_extra_uses != 0 {
2780 return None;
2781 }
2782
2783 Some(ReductionPlan {
2784 preheader,
2785 header,
2786 body,
2787 exit,
2788 iv_param,
2789 acc_param,
2790 acc_param_idx: 1,
2791 acc_init,
2792 source,
2793 reduce,
2794 accumulate_id,
2795 cmp_id,
2796 step_iadd,
2797 step_const,
2798 iv_int_width,
2799 elem_ty,
2800 lanes,
2801 iv_init,
2802 head_count,
2803 tail_count,
2804 cond_id,
2805 bound_const_id,
2806 span: accumulate_inst.span,
2807 })
2808 }
2809
2810 fn apply_reduction_plan(func: &mut Function, lp: &NaturalLoop, plan: ReductionPlan) {
2811 let v_ty = vector_ty(&plan.elem_ty, plan.lanes);
2812
2813 // 0. Snapshot the body before any in-place mutation. Used by the
2814 // scalar-tail peel below (sum reductions only).
2815 let body_snapshot: Option<Vec<Inst>> = if plan.tail_count > 0 {
2816 Some(func.block(plan.body).insts.clone())
2817 } else {
2818 None
2819 };
2820
2821 // 1. Insert `vacc_init = vbroadcast(acc_init)` at the end of the
2822 // preheader, before its branch terminator.
2823 let vacc_init = func.next_value_id();
2824 func.register_type(vacc_init, v_ty.clone());
2825 let pre_block = func.block_mut(plan.preheader);
2826 let pos = pre_block.insts.len();
2827 pre_block.insts.insert(
2828 pos,
2829 Inst {
2830 id: vacc_init,
2831 kind: InstKind::VBroadcast(plan.acc_init),
2832 ty: v_ty.clone(),
2833 span: plan.span,
2834 },
2835 );
2836 // 1b. Update the preheader branch arg slot for the accumulator.
2837 if let Some(Terminator::Branch(_, args)) = &mut pre_block.terminator {
2838 if let Some(slot) = args.get_mut(plan.acc_param_idx) {
2839 *slot = vacc_init;
2840 }
2841 }
2842
2843 // 2. Update the header's accumulator block param type to the
2844 // vector type.
2845 let header_block = func.block_mut(plan.header);
2846 if let Some(param) = header_block.params.get_mut(plan.acc_param_idx) {
2847 param.ty = v_ty.clone();
2848 }
2849 func.register_type(plan.acc_param, v_ty.clone());
2850
2851 // 3. Insert a fresh ConstInt(V) for the iv step (avoid clobbering
2852 // a shared `1`).
2853 let new_step_const = func.next_value_id();
2854 let step_const_ty = IrType::Int(plan.iv_int_width);
2855 func.register_type(new_step_const, step_const_ty.clone());
2856 let body_block = func.block_mut(plan.body);
2857 body_block.insts.insert(
2858 0,
2859 Inst {
2860 id: new_step_const,
2861 kind: InstKind::ConstInt(plan.lanes as i128, plan.iv_int_width),
2862 ty: step_const_ty,
2863 span: plan.span,
2864 },
2865 );
2866 // Update the iadd to reference the new const.
2867 if let Some(step_inst) = body_block
2868 .insts
2869 .iter_mut()
2870 .find(|inst| inst.id == plan.step_iadd)
2871 {
2872 if let InstKind::IAdd(ref mut l, ref mut r) = step_inst.kind {
2873 if *l == plan.step_const {
2874 *l = new_step_const;
2875 }
2876 if *r == plan.step_const {
2877 *r = new_step_const;
2878 }
2879 }
2880 }
2881
2882 // 4. Rewrite the per-iteration source value:
2883 // - Sum: one Load → VLoad.
2884 // - SumWithUnary: Load → VLoad and unary → VNeg / VAbs.
2885 // - Dot: two Loads → VLoad each, plus IMul/FMul → VMul.
2886 match plan.source.clone() {
2887 AccumulateSource::Sum { load_id } => {
2888 let body_block = func.block_mut(plan.body);
2889 if let Some(inst) = body_block.insts.iter_mut().find(|i| i.id == load_id) {
2890 if let InstKind::Load(ptr) = inst.kind {
2891 inst.kind = InstKind::VLoad(ptr);
2892 inst.ty = v_ty.clone();
2893 }
2894 }
2895 func.register_type(load_id, v_ty.clone());
2896 }
2897 AccumulateSource::SumWithUnary {
2898 load_id,
2899 unary_id,
2900 kind,
2901 } => {
2902 let body_block = func.block_mut(plan.body);
2903 if let Some(inst) = body_block.insts.iter_mut().find(|i| i.id == load_id) {
2904 if let InstKind::Load(ptr) = inst.kind {
2905 inst.kind = InstKind::VLoad(ptr);
2906 inst.ty = v_ty.clone();
2907 }
2908 }
2909 func.register_type(load_id, v_ty.clone());
2910 let body_block = func.block_mut(plan.body);
2911 if let Some(inst) = body_block.insts.iter_mut().find(|i| i.id == unary_id) {
2912 let new_kind = match (inst.kind.clone(), kind) {
2913 (InstKind::INeg(s), UnaryKind::Neg)
2914 | (InstKind::FNeg(s), UnaryKind::Neg) => InstKind::VNeg(s),
2915 (InstKind::FAbs(s), UnaryKind::Abs) => InstKind::VAbs(s),
2916 (other, _) => other,
2917 };
2918 inst.kind = new_kind;
2919 inst.ty = v_ty.clone();
2920 }
2921 func.register_type(unary_id, v_ty.clone());
2922 }
2923 AccumulateSource::Dot {
2924 imul_id,
2925 load_a,
2926 load_b,
2927 } => {
2928 for load_id in [load_a, load_b] {
2929 let body_block = func.block_mut(plan.body);
2930 if let Some(inst) = body_block.insts.iter_mut().find(|i| i.id == load_id) {
2931 if let InstKind::Load(ptr) = inst.kind {
2932 inst.kind = InstKind::VLoad(ptr);
2933 inst.ty = v_ty.clone();
2934 }
2935 }
2936 func.register_type(load_id, v_ty.clone());
2937 }
2938 let body_block = func.block_mut(plan.body);
2939 if let Some(inst) = body_block.insts.iter_mut().find(|i| i.id == imul_id) {
2940 let new_kind = match inst.kind.clone() {
2941 InstKind::IMul(l, r) | InstKind::FMul(l, r) => InstKind::VMul(l, r),
2942 other => other,
2943 };
2944 inst.kind = new_kind;
2945 inst.ty = v_ty.clone();
2946 }
2947 func.register_type(imul_id, v_ty.clone());
2948 }
2949 AccumulateSource::SumOfDiff {
2950 sub_id,
2951 load_a,
2952 load_b,
2953 } => {
2954 for load_id in [load_a, load_b] {
2955 let body_block = func.block_mut(plan.body);
2956 if let Some(inst) = body_block.insts.iter_mut().find(|i| i.id == load_id) {
2957 if let InstKind::Load(ptr) = inst.kind {
2958 inst.kind = InstKind::VLoad(ptr);
2959 inst.ty = v_ty.clone();
2960 }
2961 }
2962 func.register_type(load_id, v_ty.clone());
2963 }
2964 let body_block = func.block_mut(plan.body);
2965 if let Some(inst) = body_block.insts.iter_mut().find(|i| i.id == sub_id) {
2966 let new_kind = match inst.kind.clone() {
2967 InstKind::ISub(l, r) | InstKind::FSub(l, r) => InstKind::VSub(l, r),
2968 other => other,
2969 };
2970 inst.kind = new_kind;
2971 inst.ty = v_ty.clone();
2972 }
2973 func.register_type(sub_id, v_ty.clone());
2974 }
2975 }
2976
2977 // 5. Rewrite the accumulate inst:
2978 // - Sum: IAdd/FAdd → VAdd
2979 // - Max: select(icmp, acc, value) → VMax(acc, value)
2980 // - Min: select(icmp, acc, value) → VMin(acc, value)
2981 // For Min/Max we also drop the icmp predicate's type since
2982 // its result is no longer used (regalloc will dead-code it).
2983 let body_block = func.block_mut(plan.body);
2984 if let Some(inst) = body_block.insts.iter_mut().find(|i| i.id == plan.accumulate_id) {
2985 let acc_v = plan.acc_param;
2986 let new_kind = match (inst.kind.clone(), plan.reduce) {
2987 (InstKind::IAdd(l, r), ReductionKind::Sum)
2988 | (InstKind::FAdd(l, r), ReductionKind::Sum) => InstKind::VAdd(l, r),
2989 (InstKind::Select(_, t, f), ReductionKind::Max) => {
2990 // The detection guarantees one arm is acc, the
2991 // other is the value (now a vector). Use whichever
2992 // arm equals acc.
2993 if t == acc_v {
2994 InstKind::VMax(t, f)
2995 } else {
2996 InstKind::VMax(f, t)
2997 }
2998 }
2999 (InstKind::Select(_, t, f), ReductionKind::Min) => {
3000 if t == acc_v {
3001 InstKind::VMin(t, f)
3002 } else {
3003 InstKind::VMin(f, t)
3004 }
3005 }
3006 (other, _) => other,
3007 };
3008 inst.kind = new_kind;
3009 inst.ty = v_ty.clone();
3010 }
3011 func.register_type(plan.accumulate_id, v_ty.clone());
3012
3013 // 6. Insert `acc_scalar = vreduce_*(acc_param)` at the top of
3014 // the exit block, then walk every block NOT in the loop and
3015 // rewrite acc_param → acc_scalar.
3016 let acc_scalar = func.next_value_id();
3017 func.register_type(acc_scalar, plan.elem_ty.clone());
3018 let reduce_kind = match plan.reduce {
3019 ReductionKind::Sum => InstKind::VReduceSum(plan.acc_param),
3020 ReductionKind::Max => InstKind::VReduceMax(plan.acc_param),
3021 ReductionKind::Min => InstKind::VReduceMin(plan.acc_param),
3022 };
3023 let exit_block = func.block_mut(plan.exit);
3024 exit_block.insts.insert(
3025 0,
3026 Inst {
3027 id: acc_scalar,
3028 kind: reduce_kind,
3029 ty: plan.elem_ty.clone(),
3030 span: plan.span,
3031 },
3032 );
3033 let lp_body: HashSet<BlockId> = lp.body.iter().copied().collect();
3034 for block in func.blocks.iter_mut() {
3035 if lp_body.contains(&block.id) {
3036 continue;
3037 }
3038 for inst in &mut block.insts {
3039 // Skip the vreduce we just inserted — its sole purpose
3040 // is to consume the (now-vector) acc_param.
3041 if inst.id == acc_scalar {
3042 continue;
3043 }
3044 substitute_in_inst(&mut inst.kind, plan.acc_param, acc_scalar);
3045 }
3046 if let Some(term) = &mut block.terminator {
3047 substitute_in_terminator(term, plan.acc_param, acc_scalar);
3048 }
3049 }
3050
3051 // 7. Reduction scalar tail (sum only). Peel `tail_count` scalar
3052 // iterations into the exit block so they accumulate from
3053 // `acc_scalar` into a chained `final_acc`. Then retarget
3054 // post-tail consumers of `acc_scalar` to `final_acc`.
3055 if plan.tail_count > 0 {
3056 if let Some(snapshot) = body_snapshot {
3057 apply_reduction_scalar_tail(func, &plan, &snapshot, acc_scalar, &lp_body);
3058 }
3059 }
3060 }
3061
3062 /// Peel `plan.tail_count` scalar iterations of the body into the
3063 /// exit block (just after the `vreduce_*`), each iteration chaining
3064 /// from the previous accumulator. The first iteration's seed is
3065 /// `acc_scalar`; the last produces `final_acc`. After peeling, every
3066 /// non-loop, non-peel use of `acc_scalar` is rewritten to
3067 /// `final_acc`.
3068 fn apply_reduction_scalar_tail(
3069 func: &mut Function,
3070 plan: &ReductionPlan,
3071 body_snapshot: &[Inst],
3072 acc_scalar: ValueId,
3073 lp_body: &HashSet<BlockId>,
3074 ) {
3075 let int_ty = IrType::Int(plan.iv_int_width);
3076
3077 // Insert the new head-bound const at the top of the preheader.
3078 let new_bound = plan.iv_init + plan.head_count - 1;
3079 let new_bound_id = func.next_value_id();
3080 func.register_type(new_bound_id, int_ty.clone());
3081 func.block_mut(plan.preheader).insts.insert(
3082 0,
3083 Inst {
3084 id: new_bound_id,
3085 kind: InstKind::ConstInt(new_bound as i128, plan.iv_int_width),
3086 ty: int_ty.clone(),
3087 span: plan.span,
3088 },
3089 );
3090
3091 // Rewrite the original icmp's RHS to point at the new bound.
3092 if let Some(inst) = func
3093 .block_mut(plan.header)
3094 .insts
3095 .iter_mut()
3096 .find(|i| i.id == plan.cond_id)
3097 {
3098 if let InstKind::ICmp(_, _, rhs) = &mut inst.kind {
3099 if *rhs == plan.bound_const_id {
3100 *rhs = new_bound_id;
3101 }
3102 }
3103 }
3104
3105 let step_inst_id = plan.step_iadd;
3106 let mut peeled: Vec<Inst> = Vec::new();
3107 let mut peel_ids: HashSet<ValueId> = HashSet::new();
3108 let mut current_acc = acc_scalar;
3109
3110 for t in 0..plan.tail_count {
3111 let tail_iv = plan.iv_init + plan.head_count + t;
3112 let tail_iv_const_id = func.next_value_id();
3113 func.register_type(tail_iv_const_id, int_ty.clone());
3114 peeled.push(Inst {
3115 id: tail_iv_const_id,
3116 kind: InstKind::ConstInt(tail_iv as i128, plan.iv_int_width),
3117 ty: int_ty.clone(),
3118 span: plan.span,
3119 });
3120 peel_ids.insert(tail_iv_const_id);
3121
3122 let mut val_map: HashMap<ValueId, ValueId> = HashMap::new();
3123 val_map.insert(plan.iv_param, tail_iv_const_id);
3124 val_map.insert(plan.acc_param, current_acc);
3125
3126 for inst in body_snapshot {
3127 // Skip the IV step iadd — peel iterations don't bump iv.
3128 if inst.id == step_inst_id {
3129 continue;
3130 }
3131 let new_id = func.next_value_id();
3132 func.register_type(new_id, inst.ty.clone());
3133 let new_kind = remap_inst_kind(&inst.kind, &val_map);
3134 val_map.insert(inst.id, new_id);
3135 peel_ids.insert(new_id);
3136 peeled.push(Inst {
3137 id: new_id,
3138 kind: new_kind,
3139 ty: inst.ty.clone(),
3140 span: inst.span,
3141 });
3142 }
3143 current_acc = val_map[&plan.accumulate_id];
3144 }
3145 let final_acc = current_acc;
3146
3147 // Splice peeled insts into the exit block, just after `acc_scalar`
3148 // (which is at exit[0]).
3149 let exit = func.block_mut(plan.exit);
3150 let acc_pos = exit
3151 .insts
3152 .iter()
3153 .position(|i| i.id == acc_scalar)
3154 .unwrap_or(0);
3155 let after = acc_pos + 1;
3156 let tail = exit.insts.split_off(after);
3157 exit.insts.extend(peeled);
3158 exit.insts.extend(tail);
3159
3160 // Retarget non-loop, non-peel uses of acc_scalar → final_acc.
3161 if final_acc != acc_scalar {
3162 for block in func.blocks.iter_mut() {
3163 if lp_body.contains(&block.id) {
3164 continue;
3165 }
3166 for inst in &mut block.insts {
3167 if peel_ids.contains(&inst.id) || inst.id == acc_scalar {
3168 continue;
3169 }
3170 substitute_in_inst(&mut inst.kind, acc_scalar, final_acc);
3171 }
3172 if let Some(term) = &mut block.terminator {
3173 substitute_in_terminator(term, acc_scalar, final_acc);
3174 }
3175 }
3176 }
3177 }
3178
3179 fn substitute_in_inst(kind: &mut InstKind, from: ValueId, to: ValueId) {
3180 let replace = |v: &mut ValueId| {
3181 if *v == from {
3182 *v = to;
3183 }
3184 };
3185 match kind {
3186 InstKind::Load(p) => replace(p),
3187 InstKind::Store(v, p) => {
3188 replace(v);
3189 replace(p);
3190 }
3191 InstKind::IAdd(a, b)
3192 | InstKind::ISub(a, b)
3193 | InstKind::IMul(a, b)
3194 | InstKind::IDiv(a, b)
3195 | InstKind::FAdd(a, b)
3196 | InstKind::FSub(a, b)
3197 | InstKind::FMul(a, b)
3198 | InstKind::FDiv(a, b) => {
3199 replace(a);
3200 replace(b);
3201 }
3202 InstKind::Call(_, args) | InstKind::RuntimeCall(_, args) => {
3203 for a in args {
3204 replace(a);
3205 }
3206 }
3207 InstKind::ICmp(_, a, b) | InstKind::FCmp(_, a, b) => {
3208 replace(a);
3209 replace(b);
3210 }
3211 InstKind::IntExtend(v, _, _) | InstKind::IntTrunc(v, _) => replace(v),
3212 InstKind::IntToFloat(v, _) | InstKind::FloatToInt(v, _) => replace(v),
3213 InstKind::FloatExtend(v, _) | InstKind::FloatTrunc(v, _) => replace(v),
3214 InstKind::INeg(v) | InstKind::FNeg(v) | InstKind::FAbs(v) => replace(v),
3215 _ => {
3216 // Conservative fallback: walk inst_uses and replace where
3217 // possible. The exact set varies; for the limited cases
3218 // we hit (post-loop scalar use of `acc_param`), the
3219 // explicit arms above are usually enough.
3220 }
3221 }
3222 }
3223
3224 fn substitute_in_terminator(term: &mut Terminator, from: ValueId, to: ValueId) {
3225 let replace = |v: &mut ValueId| {
3226 if *v == from {
3227 *v = to;
3228 }
3229 };
3230 match term {
3231 Terminator::Return(Some(v)) => replace(v),
3232 Terminator::Return(None) => {}
3233 Terminator::Branch(_, args) => {
3234 for a in args {
3235 replace(a);
3236 }
3237 }
3238 Terminator::CondBranch {
3239 cond,
3240 true_args,
3241 false_args,
3242 ..
3243 } => {
3244 replace(cond);
3245 for a in true_args {
3246 replace(a);
3247 }
3248 for a in false_args {
3249 replace(a);
3250 }
3251 }
3252 _ => {}
3253 }
3254 }
3255
3256 #[cfg(test)]
3257 mod tests {
3258 use super::*;
3259 use crate::ir::types::IrType;
3260 use crate::lexer::{Position, Span};
3261 use crate::opt::pass::Pass;
3262
3263 fn dummy_span() -> Span {
3264 let p = Position { line: 0, col: 0 };
3265 Span {
3266 file_id: 0,
3267 start: p,
3268 end: p,
3269 }
3270 }
3271
3272 fn push_inst(func: &mut Function, block: BlockId, kind: InstKind, ty: IrType) -> ValueId {
3273 let id = func.next_value_id();
3274 func.register_type(id, ty.clone());
3275 func.block_mut(block).insts.push(Inst {
3276 id,
3277 kind,
3278 ty,
3279 span: dummy_span(),
3280 });
3281 id
3282 }
3283
3284 /// Build the canonical `c(i) = a(i) + b(i)` loop over i32 arrays of
3285 /// length 32 (trip count divisible by V=4).
3286 fn build_array_add_loop() -> (Module, BlockId) {
3287 let mut module = Module::new("m".into());
3288 let mut func = Function::new("__prog_vec".into(), vec![], IrType::Void);
3289 let entry = func.entry;
3290 let header = func.create_block("do_check");
3291 let body = func.create_block("do_body");
3292 let exit = func.create_block("do_exit");
3293
3294 let arr_ty = IrType::Array(Box::new(IrType::Int(IntWidth::I32)), 32);
3295 let arr_ptr_ty = IrType::Ptr(Box::new(arr_ty.clone()));
3296 let a = push_inst(&mut func, entry, InstKind::Alloca(arr_ty.clone()), arr_ptr_ty.clone());
3297 let b = push_inst(&mut func, entry, InstKind::Alloca(arr_ty.clone()), arr_ptr_ty.clone());
3298 let c = push_inst(&mut func, entry, InstKind::Alloca(arr_ty.clone()), arr_ptr_ty.clone());
3299
3300 let one_i32 = push_inst(
3301 &mut func,
3302 entry,
3303 InstKind::ConstInt(1, IntWidth::I32),
3304 IrType::Int(IntWidth::I32),
3305 );
3306 let hi_i32 = push_inst(
3307 &mut func,
3308 entry,
3309 InstKind::ConstInt(32, IntWidth::I32),
3310 IrType::Int(IntWidth::I32),
3311 );
3312 let one_i64 = push_inst(
3313 &mut func,
3314 entry,
3315 InstKind::ConstInt(1, IntWidth::I64),
3316 IrType::Int(IntWidth::I64),
3317 );
3318 func.block_mut(entry).terminator = Some(Terminator::Branch(header, vec![one_i32]));
3319
3320 let iv = func.next_value_id();
3321 func.register_type(iv, IrType::Int(IntWidth::I32));
3322 func.block_mut(header).params.push(BlockParam {
3323 id: iv,
3324 ty: IrType::Int(IntWidth::I32),
3325 });
3326 let cmp = push_inst(
3327 &mut func,
3328 header,
3329 InstKind::ICmp(CmpOp::Le, iv, hi_i32),
3330 IrType::Bool,
3331 );
3332 func.block_mut(header).terminator = Some(Terminator::CondBranch {
3333 cond: cmp,
3334 true_dest: body,
3335 true_args: vec![],
3336 false_dest: exit,
3337 false_args: vec![],
3338 });
3339
3340 let idx64 = push_inst(
3341 &mut func,
3342 body,
3343 InstKind::IntExtend(iv, IntWidth::I64, true),
3344 IrType::Int(IntWidth::I64),
3345 );
3346 let offset = push_inst(
3347 &mut func,
3348 body,
3349 InstKind::ISub(idx64, one_i64),
3350 IrType::Int(IntWidth::I64),
3351 );
3352 let elem_ptr_ty = IrType::Ptr(Box::new(IrType::Int(IntWidth::I32)));
3353 let a_ptr = push_inst(
3354 &mut func,
3355 body,
3356 InstKind::GetElementPtr(a, vec![offset]),
3357 elem_ptr_ty.clone(),
3358 );
3359 let a_val = push_inst(
3360 &mut func,
3361 body,
3362 InstKind::Load(a_ptr),
3363 IrType::Int(IntWidth::I32),
3364 );
3365 let b_ptr = push_inst(
3366 &mut func,
3367 body,
3368 InstKind::GetElementPtr(b, vec![offset]),
3369 elem_ptr_ty.clone(),
3370 );
3371 let b_val = push_inst(
3372 &mut func,
3373 body,
3374 InstKind::Load(b_ptr),
3375 IrType::Int(IntWidth::I32),
3376 );
3377 let sum = push_inst(
3378 &mut func,
3379 body,
3380 InstKind::IAdd(a_val, b_val),
3381 IrType::Int(IntWidth::I32),
3382 );
3383 let c_ptr = push_inst(
3384 &mut func,
3385 body,
3386 InstKind::GetElementPtr(c, vec![offset]),
3387 elem_ptr_ty.clone(),
3388 );
3389 push_inst(&mut func, body, InstKind::Store(sum, c_ptr), IrType::Void);
3390 let next = push_inst(
3391 &mut func,
3392 body,
3393 InstKind::IAdd(iv, one_i32),
3394 IrType::Int(IntWidth::I32),
3395 );
3396 func.block_mut(body).terminator = Some(Terminator::Branch(header, vec![next]));
3397 func.block_mut(exit).terminator = Some(Terminator::Return(None));
3398 module.add_function(func);
3399 (module, body)
3400 }
3401
3402 #[test]
3403 fn rewrites_array_add_loop_to_vload_vadd_vstore() {
3404 let (mut module, body) = build_array_add_loop();
3405 let changed = NeonVectorize.run(&mut module);
3406 assert!(changed, "neon_vectorize should fire on a clean array-add loop");
3407
3408 let func = &module.functions[0];
3409 let body_block = func.block(body);
3410
3411 let n_vload = body_block
3412 .insts
3413 .iter()
3414 .filter(|i| matches!(i.kind, InstKind::VLoad(_)))
3415 .count();
3416 assert_eq!(n_vload, 2, "two scalar Loads should become VLoads");
3417
3418 let n_vadd = body_block
3419 .insts
3420 .iter()
3421 .filter(|i| matches!(i.kind, InstKind::VAdd(..)))
3422 .count();
3423 assert_eq!(n_vadd, 1, "the IAdd should become a VAdd");
3424
3425 let n_vstore = body_block
3426 .insts
3427 .iter()
3428 .filter(|i| matches!(i.kind, InstKind::VStore(..)))
3429 .count();
3430 assert_eq!(n_vstore, 1, "the Store should become a VStore");
3431
3432 // The loaded values should now have vector type.
3433 for inst in &body_block.insts {
3434 if let InstKind::VLoad(_) = inst.kind {
3435 assert_eq!(
3436 inst.ty,
3437 IrType::Vector { lanes: 4, elem: Box::new(IrType::Int(IntWidth::I32)) }
3438 );
3439 }
3440 }
3441
3442 // The IV step should now use a ConstInt(4) somewhere in the body.
3443 let has_v_step = body_block.insts.iter().any(|i| {
3444 matches!(i.kind, InstKind::ConstInt(4, IntWidth::I32))
3445 });
3446 assert!(has_v_step, "step should now be ConstInt(4)");
3447 }
3448
3449 /// Build `c(i) = a(i) + scale` over i32(32) where `scale` is a
3450 /// loop-invariant ConstInt defined in the entry/preheader. The
3451 /// vectorizer should classify `scale` as `InvariantScalar`, hoist
3452 /// a `VBroadcast` into the preheader, and rewrite the binop to
3453 /// consume the broadcast vector.
3454 fn build_array_add_scalar_loop() -> (Module, BlockId, BlockId) {
3455 let mut module = Module::new("m".into());
3456 let mut func = Function::new("__prog_vec".into(), vec![], IrType::Void);
3457 let entry = func.entry;
3458 let header = func.create_block("do_check");
3459 let body = func.create_block("do_body");
3460 let exit = func.create_block("do_exit");
3461
3462 let arr_ty = IrType::Array(Box::new(IrType::Int(IntWidth::I32)), 32);
3463 let arr_ptr_ty = IrType::Ptr(Box::new(arr_ty.clone()));
3464 let a = push_inst(&mut func, entry, InstKind::Alloca(arr_ty.clone()), arr_ptr_ty.clone());
3465 let c = push_inst(&mut func, entry, InstKind::Alloca(arr_ty.clone()), arr_ptr_ty.clone());
3466
3467 let scale = push_inst(
3468 &mut func,
3469 entry,
3470 InstKind::ConstInt(7, IntWidth::I32),
3471 IrType::Int(IntWidth::I32),
3472 );
3473 let one_i32 = push_inst(
3474 &mut func,
3475 entry,
3476 InstKind::ConstInt(1, IntWidth::I32),
3477 IrType::Int(IntWidth::I32),
3478 );
3479 let hi_i32 = push_inst(
3480 &mut func,
3481 entry,
3482 InstKind::ConstInt(32, IntWidth::I32),
3483 IrType::Int(IntWidth::I32),
3484 );
3485 let one_i64 = push_inst(
3486 &mut func,
3487 entry,
3488 InstKind::ConstInt(1, IntWidth::I64),
3489 IrType::Int(IntWidth::I64),
3490 );
3491 func.block_mut(entry).terminator = Some(Terminator::Branch(header, vec![one_i32]));
3492
3493 let iv = func.next_value_id();
3494 func.register_type(iv, IrType::Int(IntWidth::I32));
3495 func.block_mut(header).params.push(BlockParam {
3496 id: iv,
3497 ty: IrType::Int(IntWidth::I32),
3498 });
3499 let cmp = push_inst(
3500 &mut func,
3501 header,
3502 InstKind::ICmp(CmpOp::Le, iv, hi_i32),
3503 IrType::Bool,
3504 );
3505 func.block_mut(header).terminator = Some(Terminator::CondBranch {
3506 cond: cmp,
3507 true_dest: body,
3508 true_args: vec![],
3509 false_dest: exit,
3510 false_args: vec![],
3511 });
3512
3513 let idx64 = push_inst(
3514 &mut func,
3515 body,
3516 InstKind::IntExtend(iv, IntWidth::I64, true),
3517 IrType::Int(IntWidth::I64),
3518 );
3519 let offset = push_inst(
3520 &mut func,
3521 body,
3522 InstKind::ISub(idx64, one_i64),
3523 IrType::Int(IntWidth::I64),
3524 );
3525 let elem_ptr_ty = IrType::Ptr(Box::new(IrType::Int(IntWidth::I32)));
3526 let a_ptr = push_inst(
3527 &mut func,
3528 body,
3529 InstKind::GetElementPtr(a, vec![offset]),
3530 elem_ptr_ty.clone(),
3531 );
3532 let a_val = push_inst(
3533 &mut func,
3534 body,
3535 InstKind::Load(a_ptr),
3536 IrType::Int(IntWidth::I32),
3537 );
3538 let sum = push_inst(
3539 &mut func,
3540 body,
3541 InstKind::IAdd(a_val, scale),
3542 IrType::Int(IntWidth::I32),
3543 );
3544 let c_ptr = push_inst(
3545 &mut func,
3546 body,
3547 InstKind::GetElementPtr(c, vec![offset]),
3548 elem_ptr_ty.clone(),
3549 );
3550 push_inst(&mut func, body, InstKind::Store(sum, c_ptr), IrType::Void);
3551 let next = push_inst(
3552 &mut func,
3553 body,
3554 InstKind::IAdd(iv, one_i32),
3555 IrType::Int(IntWidth::I32),
3556 );
3557 func.block_mut(body).terminator = Some(Terminator::Branch(header, vec![next]));
3558 func.block_mut(exit).terminator = Some(Terminator::Return(None));
3559 module.add_function(func);
3560 (module, entry, body)
3561 }
3562
3563 #[test]
3564 fn broadcasts_invariant_scalar_into_preheader() {
3565 let (mut module, preheader, body) = build_array_add_scalar_loop();
3566 let changed = NeonVectorize.run(&mut module);
3567 assert!(
3568 changed,
3569 "neon_vectorize should fire on a(i) + invariant scalar"
3570 );
3571
3572 let func = &module.functions[0];
3573 let pre_block = func.block(preheader);
3574 let body_block = func.block(body);
3575
3576 let n_vbroadcast = pre_block
3577 .insts
3578 .iter()
3579 .filter(|i| matches!(i.kind, InstKind::VBroadcast(_)))
3580 .count();
3581 assert_eq!(
3582 n_vbroadcast, 1,
3583 "the invariant scalar should be broadcast once in the preheader"
3584 );
3585
3586 let n_vload = body_block
3587 .insts
3588 .iter()
3589 .filter(|i| matches!(i.kind, InstKind::VLoad(_)))
3590 .count();
3591 assert_eq!(n_vload, 1, "only the array operand becomes a VLoad");
3592
3593 let n_vadd = body_block
3594 .insts
3595 .iter()
3596 .filter(|i| matches!(i.kind, InstKind::VAdd(..)))
3597 .count();
3598 assert_eq!(n_vadd, 1, "the IAdd should become a VAdd");
3599
3600 let n_vstore = body_block
3601 .insts
3602 .iter()
3603 .filter(|i| matches!(i.kind, InstKind::VStore(..)))
3604 .count();
3605 assert_eq!(n_vstore, 1, "the Store should become a VStore");
3606 }
3607
3608 /// Build `c(i) = b(i)` over i32(32) — a pure array copy with no
3609 /// arithmetic between the load and the store.
3610 fn build_array_copy_loop() -> (Module, BlockId) {
3611 let mut module = Module::new("m".into());
3612 let mut func = Function::new("__prog_vec".into(), vec![], IrType::Void);
3613 let entry = func.entry;
3614 let header = func.create_block("do_check");
3615 let body = func.create_block("do_body");
3616 let exit = func.create_block("do_exit");
3617
3618 let arr_ty = IrType::Array(Box::new(IrType::Int(IntWidth::I32)), 32);
3619 let arr_ptr_ty = IrType::Ptr(Box::new(arr_ty.clone()));
3620 let b = push_inst(&mut func, entry, InstKind::Alloca(arr_ty.clone()), arr_ptr_ty.clone());
3621 let c = push_inst(&mut func, entry, InstKind::Alloca(arr_ty.clone()), arr_ptr_ty.clone());
3622
3623 let one_i32 = push_inst(
3624 &mut func,
3625 entry,
3626 InstKind::ConstInt(1, IntWidth::I32),
3627 IrType::Int(IntWidth::I32),
3628 );
3629 let hi_i32 = push_inst(
3630 &mut func,
3631 entry,
3632 InstKind::ConstInt(32, IntWidth::I32),
3633 IrType::Int(IntWidth::I32),
3634 );
3635 let one_i64 = push_inst(
3636 &mut func,
3637 entry,
3638 InstKind::ConstInt(1, IntWidth::I64),
3639 IrType::Int(IntWidth::I64),
3640 );
3641 func.block_mut(entry).terminator = Some(Terminator::Branch(header, vec![one_i32]));
3642
3643 let iv = func.next_value_id();
3644 func.register_type(iv, IrType::Int(IntWidth::I32));
3645 func.block_mut(header).params.push(BlockParam {
3646 id: iv,
3647 ty: IrType::Int(IntWidth::I32),
3648 });
3649 let cmp = push_inst(
3650 &mut func,
3651 header,
3652 InstKind::ICmp(CmpOp::Le, iv, hi_i32),
3653 IrType::Bool,
3654 );
3655 func.block_mut(header).terminator = Some(Terminator::CondBranch {
3656 cond: cmp,
3657 true_dest: body,
3658 true_args: vec![],
3659 false_dest: exit,
3660 false_args: vec![],
3661 });
3662
3663 let idx64 = push_inst(
3664 &mut func,
3665 body,
3666 InstKind::IntExtend(iv, IntWidth::I64, true),
3667 IrType::Int(IntWidth::I64),
3668 );
3669 let offset = push_inst(
3670 &mut func,
3671 body,
3672 InstKind::ISub(idx64, one_i64),
3673 IrType::Int(IntWidth::I64),
3674 );
3675 let elem_ptr_ty = IrType::Ptr(Box::new(IrType::Int(IntWidth::I32)));
3676 let b_ptr = push_inst(
3677 &mut func,
3678 body,
3679 InstKind::GetElementPtr(b, vec![offset]),
3680 elem_ptr_ty.clone(),
3681 );
3682 let b_val = push_inst(
3683 &mut func,
3684 body,
3685 InstKind::Load(b_ptr),
3686 IrType::Int(IntWidth::I32),
3687 );
3688 let c_ptr = push_inst(
3689 &mut func,
3690 body,
3691 InstKind::GetElementPtr(c, vec![offset]),
3692 elem_ptr_ty.clone(),
3693 );
3694 push_inst(&mut func, body, InstKind::Store(b_val, c_ptr), IrType::Void);
3695 let next = push_inst(
3696 &mut func,
3697 body,
3698 InstKind::IAdd(iv, one_i32),
3699 IrType::Int(IntWidth::I32),
3700 );
3701 func.block_mut(body).terminator = Some(Terminator::Branch(header, vec![next]));
3702 func.block_mut(exit).terminator = Some(Terminator::Return(None));
3703 module.add_function(func);
3704 (module, body)
3705 }
3706
3707 #[test]
3708 fn rewrites_pure_array_copy_to_vload_vstore() {
3709 let (mut module, body) = build_array_copy_loop();
3710 let changed = NeonVectorize.run(&mut module);
3711 assert!(changed, "neon_vectorize should fire on a pure copy loop");
3712
3713 let func = &module.functions[0];
3714 let body_block = func.block(body);
3715
3716 let n_vload = body_block
3717 .insts
3718 .iter()
3719 .filter(|i| matches!(i.kind, InstKind::VLoad(_)))
3720 .count();
3721 assert_eq!(n_vload, 1, "the single Load becomes a VLoad");
3722
3723 let n_vstore = body_block
3724 .insts
3725 .iter()
3726 .filter(|i| matches!(i.kind, InstKind::VStore(..)))
3727 .count();
3728 assert_eq!(n_vstore, 1, "the Store becomes a VStore");
3729
3730 // No binop should appear — pure copy has none.
3731 let n_binop = body_block
3732 .insts
3733 .iter()
3734 .filter(|i| {
3735 matches!(
3736 i.kind,
3737 InstKind::VAdd(..) | InstKind::VSub(..) | InstKind::VMul(..)
3738 )
3739 })
3740 .count();
3741 assert_eq!(n_binop, 0, "pure copy must not introduce a v-binop");
3742 }
3743
3744 #[test]
3745 fn peels_scalar_tail_for_non_divisible_trip_count() {
3746 // length 31 → not divisible by V=4. The pass vectorizes 28
3747 // iterations (head_count = 7 × 4) and peels 3 scalar
3748 // iterations into the exit block.
3749 let mut module = Module::new("m".into());
3750 let mut func = Function::new("__prog_vec".into(), vec![], IrType::Void);
3751 let entry = func.entry;
3752 let header = func.create_block("do_check");
3753 let body = func.create_block("do_body");
3754 let exit = func.create_block("do_exit");
3755
3756 let arr_ty = IrType::Array(Box::new(IrType::Int(IntWidth::I32)), 31);
3757 let arr_ptr_ty = IrType::Ptr(Box::new(arr_ty.clone()));
3758 let a = push_inst(&mut func, entry, InstKind::Alloca(arr_ty.clone()), arr_ptr_ty.clone());
3759 let b = push_inst(&mut func, entry, InstKind::Alloca(arr_ty.clone()), arr_ptr_ty.clone());
3760 let c = push_inst(&mut func, entry, InstKind::Alloca(arr_ty.clone()), arr_ptr_ty.clone());
3761
3762 let one_i32 = push_inst(
3763 &mut func,
3764 entry,
3765 InstKind::ConstInt(1, IntWidth::I32),
3766 IrType::Int(IntWidth::I32),
3767 );
3768 let hi_i32 = push_inst(
3769 &mut func,
3770 entry,
3771 InstKind::ConstInt(31, IntWidth::I32),
3772 IrType::Int(IntWidth::I32),
3773 );
3774 let one_i64 = push_inst(
3775 &mut func,
3776 entry,
3777 InstKind::ConstInt(1, IntWidth::I64),
3778 IrType::Int(IntWidth::I64),
3779 );
3780 func.block_mut(entry).terminator = Some(Terminator::Branch(header, vec![one_i32]));
3781
3782 let iv = func.next_value_id();
3783 func.register_type(iv, IrType::Int(IntWidth::I32));
3784 func.block_mut(header).params.push(BlockParam {
3785 id: iv,
3786 ty: IrType::Int(IntWidth::I32),
3787 });
3788 let cmp = push_inst(
3789 &mut func,
3790 header,
3791 InstKind::ICmp(CmpOp::Le, iv, hi_i32),
3792 IrType::Bool,
3793 );
3794 func.block_mut(header).terminator = Some(Terminator::CondBranch {
3795 cond: cmp,
3796 true_dest: body,
3797 true_args: vec![],
3798 false_dest: exit,
3799 false_args: vec![],
3800 });
3801
3802 let idx64 = push_inst(
3803 &mut func,
3804 body,
3805 InstKind::IntExtend(iv, IntWidth::I64, true),
3806 IrType::Int(IntWidth::I64),
3807 );
3808 let offset = push_inst(
3809 &mut func,
3810 body,
3811 InstKind::ISub(idx64, one_i64),
3812 IrType::Int(IntWidth::I64),
3813 );
3814 let elem_ptr_ty = IrType::Ptr(Box::new(IrType::Int(IntWidth::I32)));
3815 let a_ptr = push_inst(
3816 &mut func,
3817 body,
3818 InstKind::GetElementPtr(a, vec![offset]),
3819 elem_ptr_ty.clone(),
3820 );
3821 let a_val = push_inst(
3822 &mut func,
3823 body,
3824 InstKind::Load(a_ptr),
3825 IrType::Int(IntWidth::I32),
3826 );
3827 let b_ptr = push_inst(
3828 &mut func,
3829 body,
3830 InstKind::GetElementPtr(b, vec![offset]),
3831 elem_ptr_ty.clone(),
3832 );
3833 let b_val = push_inst(
3834 &mut func,
3835 body,
3836 InstKind::Load(b_ptr),
3837 IrType::Int(IntWidth::I32),
3838 );
3839 let sum = push_inst(
3840 &mut func,
3841 body,
3842 InstKind::IAdd(a_val, b_val),
3843 IrType::Int(IntWidth::I32),
3844 );
3845 let c_ptr = push_inst(
3846 &mut func,
3847 body,
3848 InstKind::GetElementPtr(c, vec![offset]),
3849 elem_ptr_ty.clone(),
3850 );
3851 push_inst(&mut func, body, InstKind::Store(sum, c_ptr), IrType::Void);
3852 let next = push_inst(
3853 &mut func,
3854 body,
3855 InstKind::IAdd(iv, one_i32),
3856 IrType::Int(IntWidth::I32),
3857 );
3858 func.block_mut(body).terminator = Some(Terminator::Branch(header, vec![next]));
3859 func.block_mut(exit).terminator = Some(Terminator::Return(None));
3860 module.add_function(func);
3861
3862 let changed = NeonVectorize.run(&mut module);
3863 assert!(changed, "scalar tail should let the head vectorize");
3864
3865 let func = &module.functions[0];
3866 let body_block = func.block(body);
3867
3868 // Body has at least one VLoad and one VStore (vectorized head).
3869 let n_vload = body_block
3870 .insts
3871 .iter()
3872 .filter(|i| matches!(i.kind, InstKind::VLoad(_)))
3873 .count();
3874 assert!(n_vload >= 2, "two array loads should become VLoads");
3875 let n_vstore = body_block
3876 .insts
3877 .iter()
3878 .filter(|i| matches!(i.kind, InstKind::VStore(..)))
3879 .count();
3880 assert_eq!(n_vstore, 1, "the destination store should become a VStore");
3881
3882 // Exit block has 3 peeled scalar Stores (one per tail iter).
3883 let exit_block = func.block(exit);
3884 let exit_stores = exit_block
3885 .insts
3886 .iter()
3887 .filter(|i| matches!(i.kind, InstKind::Store(..)))
3888 .count();
3889 assert_eq!(
3890 exit_stores, 3,
3891 "three scalar stores should be peeled into the exit block"
3892 );
3893 }
3894 }
3895