Rust · 2767 bytes Raw Blame History
1 use std::collections::BTreeSet;
2 use std::path::PathBuf;
3
4 use armfortas::driver::OptLevel;
5 use armfortas::testing::{capture_from_path, CaptureRequest, CapturedStage, Stage};
6
/// Resolves `name` inside the `test_programs` fixture directory.
///
/// # Panics
///
/// Panics when the resolved path does not exist, so a missing fixture is
/// reported immediately instead of surfacing later as a capture failure.
fn fixture(name: &str) -> PathBuf {
    let mut resolved = PathBuf::from("test_programs");
    resolved.push(name);
    assert!(
        resolved.exists(),
        "missing test fixture {}",
        resolved.display()
    );
    resolved
}
12
13 fn capture_text(request: CaptureRequest, stage: Stage) -> String {
14 let result = capture_from_path(&request).expect("capture should succeed");
15 match result.get(stage) {
16 Some(CapturedStage::Text(text)) => text.clone(),
17 Some(CapturedStage::Run(_)) => panic!("expected text stage for {}", stage.as_str()),
18 None => panic!("missing requested stage {}", stage.as_str()),
19 }
20 }
21
22 fn capture_run_stdout(request: CaptureRequest) -> String {
23 let result = capture_from_path(&request).expect("capture should succeed");
24 match result.get(Stage::Run) {
25 Some(CapturedStage::Run(run)) => run.stdout.clone(),
26 _ => panic!("missing run stage"),
27 }
28 }
29
/// Checks that `-O3` vectorizes sum reductions over i32, i64, f32 and f64, and
/// that the compiled program still prints the correct sums.
///
/// The fixture sums 1..=31, so every reduction should produce
/// 31 * 32 / 2 = 496. Because 31 is not a multiple of either vector width
/// used here (4 lanes for 32-bit, 2 lanes for 64-bit), a correct result at
/// run time also depends on the scalar tail being handled correctly.
#[test]
fn o3_vectorizes_sum_reductions_with_scalar_tail() {
    let source = fixture("do_loop_vectorize_reduce_sum_tail.f90");

    let o3_ir = capture_text(
        CaptureRequest {
            input: source.clone(),
            requested: BTreeSet::from([Stage::OptIr]),
            opt_level: OptLevel::O3,
        },
        Stage::OptIr,
    );

    // Each of the four reductions is expected to collapse its vector
    // accumulator with one vreduce_sum at loop exit. The peeled scalar tail
    // is not inspected in the IR; it is validated by the run output below.
    assert_eq!(
        o3_ir.matches("vreduce_sum").count(),
        4,
        "expected four vreduce_sum (i32, i64, f32, f64):\n{}",
        o3_ir
    );

    // Check each accumulator type on its own so a failure names exactly which
    // reductions were not vectorized.
    let missing: Vec<&str> = ["<4 x i32>", "<2 x i64>", "<4 x f32>", "<2 x f64>"]
        .iter()
        .copied()
        .filter(|ty| !o3_ir.contains(*ty))
        .collect();
    assert!(
        missing.is_empty(),
        "expected i32/i64/f32/f64 vector accumulators in IR; missing {:?}:\n{}",
        missing,
        o3_ir
    );

    let stdout = capture_run_stdout(CaptureRequest {
        input: source,
        requested: BTreeSet::from([Stage::Run]),
        opt_level: OptLevel::O3,
    });
    let trimmed: Vec<&str> = stdout
        .lines()
        .map(|l| l.trim())
        .filter(|l| !l.is_empty())
        .collect();
    assert_eq!(trimmed.len(), 4, "expected four output lines:\n{}", stdout);

    // 1 + 2 + ... + 31 = 31 * 32 / 2 = 496.
    assert_eq!(trimmed[0], "496", "i32 sum wrong: got {:?}", trimmed[0]);
    assert_eq!(trimmed[1], "496", "i64 sum wrong: got {:?}", trimmed[1]);
    // Reals are matched on their leading digits only, not on the full
    // formatted text.
    assert!(
        trimmed[2].starts_with("4.96"),
        "f32 sum wrong: got {:?}",
        trimmed[2]
    );
    assert!(
        trimmed[3].starts_with("4.96"),
        "f64 sum wrong: got {:?}",
        trimmed[3]
    );
}
85