1 //! Simple loop vectorization onto bulk runtime kernels.
2 //!
3 //! This pass recognizes a narrow but high-value class of ordinary
4 //! counted `DO` loops over fixed-size local arrays and rewrites them to
5 //! the existing SIMD-backed runtime bulk kernels used by whole-array
6 //! lowering. That makes `-O3` / `-Ofast`'s vectorization claim honest
7 //! even when the source used an explicit scalar loop instead of a whole-
8 //! array assignment or `DO CONCURRENT`.
9
10 use std::collections::{HashMap, HashSet};
11
12 use crate::ir::inst::*;
13 use crate::ir::types::{FloatWidth, IntWidth, IrType};
14 use crate::ir::walk::prune_unreachable;
15
16 use super::loop_utils::{find_preheader, loop_defined_values, resolve_const_int};
17 use super::pass::Pass;
18 use super::util::{find_natural_loops, inst_uses, predecessors, terminator_uses, NaturalLoop};
19
20 pub struct Vectorize;
21
22 impl Pass for Vectorize {
23 fn name(&self) -> &'static str {
24 "vectorize"
25 }
26
27 fn run(&self, module: &mut Module) -> bool {
28 let mut changed = false;
29 for func in &mut module.functions {
30 if vectorize_function(func) {
31 changed = true;
32 }
33 }
34 changed
35 }
36 }
37
/// Canonical shape of a recognized counted loop: a two-block natural loop
/// whose integer induction variable steps by +1 from `iv_init` through
/// `iv_bound` (inclusive).
#[derive(Debug, Clone, Copy)]
struct CountedLoop {
    preheader: BlockId, // sole out-of-loop predecessor branching into `header`
    header: BlockId,    // block carrying the IV as its single block parameter
    body: BlockId,      // single body block; also the loop's unique latch
    exit: BlockId,      // false destination of the header's bounds check
    iv_param: ValueId,  // induction-variable block parameter of `header`
    iv_init: i64,       // constant initial IV value passed by the preheader
    iv_bound: i64,      // inclusive constant upper bound on the IV
}
48
/// Element-wise binary operations with bulk-kernel equivalents. Integer
/// and float forms of an operation share one kind; the destination element
/// type later selects the concrete kernel symbol.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BinaryKind {
    Add,
    Sub,
    Mul,
}
55
/// A single-index GEP into a fixed-size array, normalized against the loop
/// induction variable.
#[derive(Debug, Clone)]
struct ArrayAccess {
    base: ValueId,   // pointer to the whole `[elem_ty; len]` array
    elem_ty: IrType, // element type of the array
    len: u64,        // static element count
    lower: i64,      // IV value that addresses element 0 (e.g. 1 when the index is `iv - 1`)
}
63
/// Classification of one operand of the stored expression: either a full
/// IV-indexed array access, or a loop-invariant scalar value.
#[derive(Debug, Clone)]
enum OperandClass {
    Array(ArrayAccess),
    Scalar(ValueId),
}
69
/// A fully-resolved rewrite: which runtime kernel to call and with which
/// value arguments. Field meaning matches the kernel argument order
/// assembled in `apply_kernel_plan`.
#[derive(Debug, Clone, Copy)]
enum KernelPlan {
    /// `dest(i) = scalar` → `kernel(dest, n, scalar)`.
    Fill {
        kernel: &'static str,
        dest: ValueId,
        len: u64,
        scalar: ValueId,
    },
    /// `dest(i) = lhs(i) op rhs(i)` → `kernel(dest, lhs, rhs, n)`.
    ArrayBinary {
        kernel: &'static str,
        dest: ValueId,
        lhs: ValueId,
        rhs: ValueId,
        len: u64,
    },
    /// `dest(i) = src(i) op scalar` → `kernel(dest, src, scalar, n)`.
    ArrayScalar {
        kernel: &'static str,
        dest: ValueId,
        src: ValueId,
        scalar: ValueId,
        len: u64,
    },
    /// `dest(i) = scalar op src(i)` → `kernel(dest, scalar, src, n)`.
    ScalarArray {
        kernel: &'static str,
        dest: ValueId,
        scalar: ValueId,
        src: ValueId,
        len: u64,
    },
}
100
101 fn vectorize_function(func: &mut Function) -> bool {
102 let loops = find_natural_loops(func);
103 if loops.is_empty() {
104 return false;
105 }
106 let preds = predecessors(func);
107
108 for lp in &loops {
109 let Some(shape) = detect_counted_loop(func, lp, &preds) else {
110 continue;
111 };
112 let loop_defs = loop_defined_values(func, lp);
113 if loop_values_escape(func, lp, &loop_defs) {
114 continue;
115 }
116 let Some((plan, span)) = build_kernel_plan(func, lp, &shape, &loop_defs) else {
117 continue;
118 };
119 apply_kernel_plan(func, shape, plan, span);
120 prune_unreachable(func);
121 return true;
122 }
123
124 false
125 }
126
/// Match `lp` against the narrow counted-loop shape this pass handles.
///
/// Requirements, all checked here:
/// - exactly two blocks: a header plus a body that is also the unique latch;
/// - the header has a single integer block parameter (the IV);
/// - the preheader branches to the header with a constant initial IV;
/// - the header ends in `iv <= bound` (or `iv < bound`) against a constant,
///   entering the body on the true edge and leaving the loop on the false edge;
/// - the latch branches back to the header passing `iv + 1`.
fn detect_counted_loop(
    func: &Function,
    lp: &NaturalLoop,
    preds: &HashMap<BlockId, Vec<BlockId>>,
) -> Option<CountedLoop> {
    if lp.latches.len() != 1 || lp.body.len() != 2 {
        return None;
    }

    let header = lp.header;
    let body = lp.latches[0];
    if body == header {
        return None;
    }

    // The IV must be the header's sole block parameter, and an integer.
    let header_block = func.block(header);
    if header_block.params.len() != 1 {
        return None;
    }
    let iv_param = header_block.params[0].id;
    if !matches!(header_block.params[0].ty, IrType::Int(_)) {
        return None;
    }

    // The initial IV value is the single branch argument the preheader
    // passes to the header; it must be a compile-time constant.
    let preheader = find_preheader(func, lp, preds)?;
    let iv_init = match &func.block(preheader).terminator {
        Some(Terminator::Branch(dest, args)) if *dest == header && args.len() == 1 => {
            resolve_const_int(func, args[0])?
        }
        _ => return None,
    };

    let (cond_id, true_dest, false_dest, true_args, false_args) = match &header_block.terminator {
        Some(Terminator::CondBranch {
            cond,
            true_dest,
            true_args,
            false_dest,
            false_args,
        }) => (*cond, *true_dest, *false_dest, true_args, false_args),
        _ => return None,
    };
    // Loop entered on the true edge, exited on the false edge, with no
    // block arguments on either edge.
    if !true_args.is_empty()
        || !false_args.is_empty()
        || true_dest != body
        || lp.body.contains(&false_dest)
    {
        return None;
    }

    // Normalize the bounds check to an inclusive upper bound:
    // `iv < c` becomes `iv <= c - 1`.
    let cond_inst = header_block.insts.iter().find(|inst| inst.id == cond_id)?;
    let iv_bound = match cond_inst.kind {
        InstKind::ICmp(CmpOp::Le, lhs, rhs) if lhs == iv_param => resolve_const_int(func, rhs)?,
        InstKind::ICmp(CmpOp::Lt, lhs, rhs) if lhs == iv_param => {
            resolve_const_int(func, rhs)?.checked_sub(1)?
        }
        _ => return None,
    };

    // The latch must feed the next IV back to the header, stepping by +1.
    let body_block = func.block(body);
    let next = match &body_block.terminator {
        Some(Terminator::Branch(dest, args)) if *dest == header && args.len() == 1 => args[0],
        _ => return None,
    };
    if step_size(func, iv_param, next)? != 1 {
        return None;
    }

    Some(CountedLoop {
        preheader,
        header,
        body,
        exit: false_dest,
        iv_param,
        iv_init,
        iv_bound,
    })
}
205
206 fn step_size(func: &Function, iv_param: ValueId, next: ValueId) -> Option<i64> {
207 if next == iv_param {
208 return Some(0);
209 }
210 let defs = inst_map(func);
211 let inst = defs.get(&next)?;
212 match inst.kind {
213 InstKind::IAdd(lhs, rhs) => {
214 if lhs == iv_param {
215 resolve_const_int(func, rhs)
216 } else if rhs == iv_param {
217 resolve_const_int(func, lhs)
218 } else {
219 None
220 }
221 }
222 InstKind::ISub(lhs, rhs) if lhs == iv_param => resolve_const_int(func, rhs).map(|v| -v),
223 _ => None,
224 }
225 }
226
227 fn build_kernel_plan(
228 func: &Function,
229 lp: &NaturalLoop,
230 shape: &CountedLoop,
231 loop_defs: &HashSet<ValueId>,
232 ) -> Option<(KernelPlan, crate::lexer::Span)> {
233 let body = func.block(shape.body);
234 if body
235 .insts
236 .iter()
237 .any(|inst| matches!(inst.kind, InstKind::Call(..) | InstKind::RuntimeCall(..)))
238 {
239 return None;
240 }
241
242 let mut stores = body.insts.iter().filter_map(|inst| match inst.kind {
243 InstKind::Store(value, ptr) => Some((inst.span, value, ptr)),
244 _ => None,
245 });
246 let (span, stored_value, dest_ptr) = stores.next()?;
247 if stores.next().is_some() {
248 return None;
249 }
250
251 let dest = classify_array_access(func, dest_ptr, shape.iv_param)?;
252 if !covers_full_array(shape, &dest) {
253 return None;
254 }
255
256 let plan = if let Some(scalar) =
257 classify_invariant_scalar(func, loop_defs, stored_value, &dest.elem_ty)
258 {
259 let kernel = fill_kernel_name(&dest.elem_ty)?;
260 KernelPlan::Fill {
261 kernel,
262 dest: dest.base,
263 len: dest.len,
264 scalar,
265 }
266 } else {
267 classify_store_value(func, lp, shape, loop_defs, stored_value, dest)?
268 };
269
270 Some((plan, span))
271 }
272
273 fn covers_full_array(shape: &CountedLoop, access: &ArrayAccess) -> bool {
274 if access.len == 0 {
275 return false;
276 }
277 let Some(upper) = access
278 .lower
279 .checked_add(access.len as i64)
280 .and_then(|value| value.checked_sub(1))
281 else {
282 return false;
283 };
284 shape.iv_init == access.lower && shape.iv_bound == upper
285 }
286
/// Classify the stored expression as `lhs op rhs` over array/scalar
/// operands and pick the matching bulk kernel.
///
/// Integer and float forms of the same operation share one `BinaryKind`;
/// the destination element type selects the concrete kernel symbol. For
/// scalar-op-array, `+` and `*` commute onto the array-scalar kernels
/// while `-` requires the dedicated scalar-array form.
fn classify_store_value(
    func: &Function,
    _lp: &NaturalLoop,
    shape: &CountedLoop,
    loop_defs: &HashSet<ValueId>,
    value: ValueId,
    dest: ArrayAccess,
) -> Option<KernelPlan> {
    let defs = inst_map(func);
    let inst = defs.get(&value)?;
    let (op, lhs, rhs) = match inst.kind {
        InstKind::IAdd(lhs, rhs) => (BinaryKind::Add, lhs, rhs),
        InstKind::ISub(lhs, rhs) => (BinaryKind::Sub, lhs, rhs),
        InstKind::IMul(lhs, rhs) => (BinaryKind::Mul, lhs, rhs),
        InstKind::FAdd(lhs, rhs) => (BinaryKind::Add, lhs, rhs),
        InstKind::FSub(lhs, rhs) => (BinaryKind::Sub, lhs, rhs),
        InstKind::FMul(lhs, rhs) => (BinaryKind::Mul, lhs, rhs),
        _ => return None,
    };

    let lhs = classify_operand(func, loop_defs, lhs, &dest.elem_ty, shape.iv_param)?;
    let rhs = classify_operand(func, loop_defs, rhs, &dest.elem_ty, shape.iv_param)?;

    match (lhs, rhs) {
        // array op array
        (OperandClass::Array(lhs), OperandClass::Array(rhs))
            if arrays_compatible(&dest, &lhs) && arrays_compatible(&dest, &rhs) =>
        {
            let kernel = array_binary_kernel_name(op, &dest.elem_ty)?;
            Some(KernelPlan::ArrayBinary {
                kernel,
                dest: dest.base,
                lhs: lhs.base,
                rhs: rhs.base,
                len: dest.len,
            })
        }
        // array op scalar
        (OperandClass::Array(src), OperandClass::Scalar(scalar))
            if arrays_compatible(&dest, &src) =>
        {
            let kernel = array_scalar_kernel_name(op, &dest.elem_ty)?;
            Some(KernelPlan::ArrayScalar {
                kernel,
                dest: dest.base,
                src: src.base,
                scalar,
                len: dest.len,
            })
        }
        // scalar op array
        (OperandClass::Scalar(scalar), OperandClass::Array(src))
            if arrays_compatible(&dest, &src) =>
        {
            match op {
                // `s + a` and `s * a` commute onto the array-scalar kernels.
                BinaryKind::Add | BinaryKind::Mul => {
                    let kernel = array_scalar_kernel_name(op, &dest.elem_ty)?;
                    Some(KernelPlan::ArrayScalar {
                        kernel,
                        dest: dest.base,
                        src: src.base,
                        scalar,
                        len: dest.len,
                    })
                }
                // `s - a` is not commutative; use the scalar-array kernel.
                BinaryKind::Sub => {
                    let kernel = scalar_array_kernel_name(op, &dest.elem_ty)?;
                    Some(KernelPlan::ScalarArray {
                        kernel,
                        dest: dest.base,
                        scalar,
                        src: src.base,
                        len: dest.len,
                    })
                }
            }
        }
        _ => None,
    }
}
364
365 fn classify_operand(
366 func: &Function,
367 loop_defs: &HashSet<ValueId>,
368 value: ValueId,
369 elem_ty: &IrType,
370 iv_param: ValueId,
371 ) -> Option<OperandClass> {
372 if let Some(access) = classify_loaded_array(func, value, iv_param) {
373 if access.elem_ty == *elem_ty {
374 return Some(OperandClass::Array(access));
375 }
376 }
377 classify_invariant_scalar(func, loop_defs, value, elem_ty).map(OperandClass::Scalar)
378 }
379
380 fn classify_invariant_scalar(
381 func: &Function,
382 loop_defs: &HashSet<ValueId>,
383 value: ValueId,
384 elem_ty: &IrType,
385 ) -> Option<ValueId> {
386 if loop_defs.contains(&value) {
387 return None;
388 }
389 (func.value_type(value).as_ref() == Some(elem_ty)).then_some(value)
390 }
391
392 fn classify_loaded_array(
393 func: &Function,
394 value: ValueId,
395 iv_param: ValueId,
396 ) -> Option<ArrayAccess> {
397 let defs = inst_map(func);
398 let inst = defs.get(&value)?;
399 let InstKind::Load(ptr) = inst.kind else {
400 return None;
401 };
402 classify_array_access(func, ptr, iv_param)
403 }
404
405 fn classify_array_access(func: &Function, ptr: ValueId, iv_param: ValueId) -> Option<ArrayAccess> {
406 let defs = inst_map(func);
407 let inst = defs.get(&ptr)?;
408 let InstKind::GetElementPtr(base, ref indices) = inst.kind else {
409 return None;
410 };
411 if indices.len() != 1 {
412 return None;
413 }
414 let IrType::Ptr(inner) = func.value_type(base)? else {
415 return None;
416 };
417 let IrType::Array(elem, len) = inner.as_ref() else {
418 return None;
419 };
420 let lower = normalized_index_lower(func, indices[0], iv_param)?;
421 Some(ArrayAccess {
422 base,
423 elem_ty: elem.as_ref().clone(),
424 len: *len,
425 lower,
426 })
427 }
428
/// Compute the IV value at which the index expression `value` equals zero,
/// i.e. the IV value that addresses array element 0.
///
/// Recognizes the IV itself (lower 0), the IV widened to i64, and
/// `expr - const` chains: if `expr` reaches zero at IV value `l`, then
/// `expr - c` reaches zero at `l + c`.
fn normalized_index_lower(func: &Function, value: ValueId, iv_param: ValueId) -> Option<i64> {
    if value == iv_param {
        return Some(0);
    }

    let defs = inst_map(func);
    let inst = defs.get(&value)?;
    match inst.kind {
        // Widening the IV to i64 does not change the index value.
        InstKind::IntExtend(src, IntWidth::I64, _) if src == iv_param => Some(0),
        InstKind::ISub(lhs, rhs) => {
            let lhs_lower = normalized_index_lower(func, lhs, iv_param)?;
            let rhs_const = resolve_const_int(func, rhs)?;
            lhs_lower.checked_add(rhs_const)
        }
        _ => None,
    }
}
446
447 fn arrays_compatible(dest: &ArrayAccess, other: &ArrayAccess) -> bool {
448 dest.elem_ty == other.elem_ty && dest.len == other.len && dest.lower == other.lower
449 }
450
451 fn loop_values_escape(func: &Function, lp: &NaturalLoop, loop_defs: &HashSet<ValueId>) -> bool {
452 for block in &func.blocks {
453 if lp.body.contains(&block.id) {
454 continue;
455 }
456 if block.insts.iter().any(|inst| {
457 inst_uses(&inst.kind)
458 .into_iter()
459 .any(|value| loop_defs.contains(&value))
460 }) {
461 return true;
462 }
463 if block.terminator.as_ref().is_some_and(|term| {
464 terminator_uses(term)
465 .into_iter()
466 .any(|value| loop_defs.contains(&value))
467 }) {
468 return true;
469 }
470 }
471 false
472 }
473
/// Rewrite the loop described by `shape` into a single runtime kernel call.
///
/// Emits the array-length constant and the call at the end of the
/// preheader, then redirects the preheader straight to the loop exit; the
/// now-unreachable header/body blocks are pruned by the caller.
fn apply_kernel_plan(
    func: &mut Function,
    shape: CountedLoop,
    plan: KernelPlan,
    span: crate::lexer::Span,
) {
    let n_id = ensure_i64_const(func, shape.preheader, kernel_len(plan) as i64, span);
    // Assemble the argument list; order matches the runtime kernel ABI.
    let (kernel, args) = match plan {
        KernelPlan::Fill {
            kernel,
            dest,
            scalar,
            ..
        } => (kernel, vec![dest, n_id, scalar]),
        KernelPlan::ArrayBinary {
            kernel,
            dest,
            lhs,
            rhs,
            ..
        } => (kernel, vec![dest, lhs, rhs, n_id]),
        KernelPlan::ArrayScalar {
            kernel,
            dest,
            src,
            scalar,
            ..
        } => (kernel, vec![dest, src, scalar, n_id]),
        KernelPlan::ScalarArray {
            kernel,
            dest,
            scalar,
            src,
            ..
        } => (kernel, vec![dest, scalar, src, n_id]),
    };

    let id = func.next_value_id();
    func.register_type(id, IrType::Void);
    let inst = Inst {
        id,
        kind: InstKind::Call(FuncRef::External(kernel.into()), args),
        ty: IrType::Void,
        span,
    };

    // Append the call and bypass the loop entirely.
    let preheader = func.block_mut(shape.preheader);
    preheader.insts.push(inst);
    preheader.terminator = Some(Terminator::Branch(shape.exit, vec![]));
}
524
525 fn kernel_len(plan: KernelPlan) -> u64 {
526 match plan {
527 KernelPlan::Fill { len, .. }
528 | KernelPlan::ArrayBinary { len, .. }
529 | KernelPlan::ArrayScalar { len, .. }
530 | KernelPlan::ScalarArray { len, .. } => len,
531 }
532 }
533
534 fn ensure_i64_const(
535 func: &mut Function,
536 block_id: BlockId,
537 value: i64,
538 span: crate::lexer::Span,
539 ) -> ValueId {
540 let id = func.next_value_id();
541 let ty = IrType::Int(IntWidth::I64);
542 func.register_type(id, ty.clone());
543 func.block_mut(block_id).insts.push(Inst {
544 id,
545 kind: InstKind::ConstInt(value as i128, IntWidth::I64),
546 ty,
547 span,
548 });
549 id
550 }
551
552 fn inst_map(func: &Function) -> HashMap<ValueId, &Inst> {
553 func.blocks
554 .iter()
555 .flat_map(|block| block.insts.iter())
556 .map(|inst| (inst.id, inst))
557 .collect()
558 }
559
560 fn fill_kernel_name(ty: &IrType) -> Option<&'static str> {
561 match ty {
562 IrType::Int(IntWidth::I32) => Some("afs_fill_i32"),
563 IrType::Float(FloatWidth::F32) => Some("afs_fill_f32"),
564 IrType::Float(FloatWidth::F64) => Some("afs_fill_f64"),
565 _ => None,
566 }
567 }
568
569 fn array_binary_kernel_name(kind: BinaryKind, ty: &IrType) -> Option<&'static str> {
570 match (kind, ty) {
571 (BinaryKind::Add, IrType::Int(IntWidth::I32)) => Some("afs_array_add_i32"),
572 (BinaryKind::Add, IrType::Float(FloatWidth::F32)) => Some("afs_array_add_f32"),
573 (BinaryKind::Add, IrType::Float(FloatWidth::F64)) => Some("afs_array_add_f64"),
574 (BinaryKind::Sub, IrType::Int(IntWidth::I32)) => Some("afs_array_sub_i32"),
575 (BinaryKind::Sub, IrType::Float(FloatWidth::F32)) => Some("afs_array_sub_f32"),
576 (BinaryKind::Sub, IrType::Float(FloatWidth::F64)) => Some("afs_array_sub_f64"),
577 (BinaryKind::Mul, IrType::Int(IntWidth::I32)) => Some("afs_array_mul_i32"),
578 (BinaryKind::Mul, IrType::Float(FloatWidth::F32)) => Some("afs_array_mul_f32"),
579 (BinaryKind::Mul, IrType::Float(FloatWidth::F64)) => Some("afs_array_mul_f64"),
580 _ => None,
581 }
582 }
583
584 fn array_scalar_kernel_name(kind: BinaryKind, ty: &IrType) -> Option<&'static str> {
585 match (kind, ty) {
586 (BinaryKind::Add, IrType::Int(IntWidth::I32)) => Some("afs_array_add_scalar_i32"),
587 (BinaryKind::Add, IrType::Float(FloatWidth::F32)) => Some("afs_array_add_scalar_f32"),
588 (BinaryKind::Add, IrType::Float(FloatWidth::F64)) => Some("afs_array_add_scalar_f64"),
589 (BinaryKind::Sub, IrType::Int(IntWidth::I32)) => Some("afs_array_sub_scalar_i32"),
590 (BinaryKind::Sub, IrType::Float(FloatWidth::F32)) => Some("afs_array_sub_scalar_f32"),
591 (BinaryKind::Sub, IrType::Float(FloatWidth::F64)) => Some("afs_array_sub_scalar_f64"),
592 (BinaryKind::Mul, IrType::Int(IntWidth::I32)) => Some("afs_array_mul_scalar_i32"),
593 (BinaryKind::Mul, IrType::Float(FloatWidth::F32)) => Some("afs_array_mul_scalar_f32"),
594 (BinaryKind::Mul, IrType::Float(FloatWidth::F64)) => Some("afs_array_mul_scalar_f64"),
595 _ => None,
596 }
597 }
598
599 fn scalar_array_kernel_name(kind: BinaryKind, ty: &IrType) -> Option<&'static str> {
600 match (kind, ty) {
601 (BinaryKind::Sub, IrType::Int(IntWidth::I32)) => Some("afs_scalar_sub_array_i32"),
602 (BinaryKind::Sub, IrType::Float(FloatWidth::F32)) => Some("afs_scalar_sub_array_f32"),
603 (BinaryKind::Sub, IrType::Float(FloatWidth::F64)) => Some("afs_scalar_sub_array_f64"),
604 _ => None,
605 }
606 }
607
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::types::IrType;
    use crate::lexer::{Position, Span};
    use crate::opt::pass::Pass;

    // Zero-width span for synthesized test IR.
    fn dummy_span() -> Span {
        let p = Position { line: 0, col: 0 };
        Span {
            file_id: 0,
            start: p,
            end: p,
        }
    }

    // Append an instruction to `block`, registering its result type, and
    // return the new value id.
    fn push_inst(func: &mut Function, block: BlockId, kind: InstKind, ty: IrType) -> ValueId {
        let id = func.next_value_id();
        func.register_type(id, ty.clone());
        func.block_mut(block).insts.push(Inst {
            id,
            kind,
            ty,
            span: dummy_span(),
        });
        id
    }

    // Builds IR for roughly:
    //     integer :: a(32), b(32), c(32)
    //     do i = 1, 32
    //       c(i) = a(i) + b(i)
    //     end do
    // and checks the pass rewrites it into one afs_array_add_i32 call.
    #[test]
    fn vectorizes_simple_array_add_loop() {
        let mut module = Module::new("m".into());
        let mut func = Function::new("__prog_vec".into(), vec![], IrType::Void);

        let entry = func.entry;
        let header = func.create_block("do_check");
        let body = func.create_block("do_body");
        let exit = func.create_block("do_exit");

        // entry: three i32[32] allocas plus the loop constants.
        let a = push_inst(
            &mut func,
            entry,
            InstKind::Alloca(IrType::Array(Box::new(IrType::Int(IntWidth::I32)), 32)),
            IrType::Ptr(Box::new(IrType::Array(
                Box::new(IrType::Int(IntWidth::I32)),
                32,
            ))),
        );
        let b = push_inst(
            &mut func,
            entry,
            InstKind::Alloca(IrType::Array(Box::new(IrType::Int(IntWidth::I32)), 32)),
            IrType::Ptr(Box::new(IrType::Array(
                Box::new(IrType::Int(IntWidth::I32)),
                32,
            ))),
        );
        let c = push_inst(
            &mut func,
            entry,
            InstKind::Alloca(IrType::Array(Box::new(IrType::Int(IntWidth::I32)), 32)),
            IrType::Ptr(Box::new(IrType::Array(
                Box::new(IrType::Int(IntWidth::I32)),
                32,
            ))),
        );
        let one_i32 = push_inst(
            &mut func,
            entry,
            InstKind::ConstInt(1, IntWidth::I32),
            IrType::Int(IntWidth::I32),
        );
        let hi_i32 = push_inst(
            &mut func,
            entry,
            InstKind::ConstInt(32, IntWidth::I32),
            IrType::Int(IntWidth::I32),
        );
        let one_i64 = push_inst(
            &mut func,
            entry,
            InstKind::ConstInt(1, IntWidth::I64),
            IrType::Int(IntWidth::I64),
        );
        // entry is the preheader: branch to the header with iv = 1.
        func.block_mut(entry).terminator = Some(Terminator::Branch(header, vec![one_i32]));

        // header: iv block parameter plus the `iv <= 32` bounds check.
        let iv = func.next_value_id();
        func.register_type(iv, IrType::Int(IntWidth::I32));
        func.block_mut(header).params.push(BlockParam {
            id: iv,
            ty: IrType::Int(IntWidth::I32),
        });
        let cmp = push_inst(
            &mut func,
            header,
            InstKind::ICmp(CmpOp::Le, iv, hi_i32),
            IrType::Bool,
        );
        func.block_mut(header).terminator = Some(Terminator::CondBranch {
            cond: cmp,
            true_dest: body,
            true_args: vec![],
            false_dest: exit,
            false_args: vec![],
        });

        // body: c(i) = a(i) + b(i), using zero-based offset i - 1.
        let idx64 = push_inst(
            &mut func,
            body,
            InstKind::IntExtend(iv, IntWidth::I64, true),
            IrType::Int(IntWidth::I64),
        );
        let offset = push_inst(
            &mut func,
            body,
            InstKind::ISub(idx64, one_i64),
            IrType::Int(IntWidth::I64),
        );
        let a_ptr = push_inst(
            &mut func,
            body,
            InstKind::GetElementPtr(a, vec![offset]),
            IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
        );
        let a_val = push_inst(
            &mut func,
            body,
            InstKind::Load(a_ptr),
            IrType::Int(IntWidth::I32),
        );
        let b_ptr = push_inst(
            &mut func,
            body,
            InstKind::GetElementPtr(b, vec![offset]),
            IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
        );
        let b_val = push_inst(
            &mut func,
            body,
            InstKind::Load(b_ptr),
            IrType::Int(IntWidth::I32),
        );
        let sum = push_inst(
            &mut func,
            body,
            InstKind::IAdd(a_val, b_val),
            IrType::Int(IntWidth::I32),
        );
        let c_ptr = push_inst(
            &mut func,
            body,
            InstKind::GetElementPtr(c, vec![offset]),
            IrType::Ptr(Box::new(IrType::Int(IntWidth::I32))),
        );
        push_inst(&mut func, body, InstKind::Store(sum, c_ptr), IrType::Void);
        // Latch: feed iv + 1 back to the header.
        let next = push_inst(
            &mut func,
            body,
            InstKind::IAdd(iv, one_i32),
            IrType::Int(IntWidth::I32),
        );
        func.block_mut(body).terminator = Some(Terminator::Branch(header, vec![next]));
        func.block_mut(exit).terminator = Some(Terminator::Return(None));

        module.add_function(func);

        let changed = Vectorize.run(&mut module);
        assert!(
            changed,
            "vectorize should rewrite the counted array-add loop"
        );

        let func = &module.functions[0];
        let entry_block = func.block(entry);
        assert!(
            entry_block.insts.iter().any(|inst| matches!(
                &inst.kind,
                InstKind::Call(FuncRef::External(name), _args) if name == "afs_array_add_i32"
            )),
            "entry should now contain a bulk add kernel call: {:?}",
            entry_block.insts
        );
        assert!(
            func.try_block(header).is_none() || !func.blocks.iter().any(|block| block.id == header),
            "loop header should be pruned after vectorization"
        );
    }
}
795