Rust · 2415 bytes Raw Blame History
1 use std::collections::BTreeSet;
2 use std::path::PathBuf;
3
4 use armfortas::driver::OptLevel;
5 use armfortas::testing::{capture_from_path, CaptureRequest, CapturedStage, Stage};
6
/// Builds the path of the test program `name` under the `test_programs`
/// directory and returns it.
///
/// # Panics
///
/// Panics with a message naming the path when nothing exists there, so a
/// misspelled or absent fixture is reported before any capture is attempted.
fn fixture(name: &str) -> PathBuf {
    let mut location = PathBuf::from("test_programs");
    location.push(name);
    if !location.exists() {
        panic!("missing test fixture {}", location.display());
    }
    location
}
12
13 fn capture_text(request: CaptureRequest, stage: Stage) -> String {
14 let result = capture_from_path(&request).expect("capture should succeed");
15 match result.get(stage) {
16 Some(CapturedStage::Text(text)) => text.clone(),
17 Some(CapturedStage::Run(_)) => panic!("expected text stage for {}", stage.as_str()),
18 None => panic!("missing requested stage {}", stage.as_str()),
19 }
20 }
21
22 fn capture_run_stdout(request: CaptureRequest) -> String {
23 let result = capture_from_path(&request).expect("capture should succeed");
24 match result.get(Stage::Run) {
25 Some(CapturedStage::Run(run)) => run.stdout.clone(),
26 _ => panic!("missing run stage"),
27 }
28 }
29
/// Checks that compiling `do_loop_vectorize_dot.f90` at `-O3` yields
/// optimized IR with the vector dot-product shape, and that the resulting
/// program still prints the expected sum.
#[test]
fn o3_vectorizes_manual_dot_product_loop() {
    let program = fixture("do_loop_vectorize_dot.f90");

    let ir_request = CaptureRequest {
        input: program.clone(),
        requested: BTreeSet::from([Stage::OptIr]),
        opt_level: OptLevel::O3,
    };
    let optimized = capture_text(ir_request, Stage::OptIr);

    // The dot-product path produces two vloads, a vmul, a vadd into
    // the vector accumulator, and a final vreduce_sum.
    let vload_count = optimized.matches("vload").count();
    assert!(
        vload_count >= 2,
        "dot product needs at least 2 VLoads in the body, got {}:\n{}",
        vload_count,
        optimized
    );

    let has_dot_shape = ["vmul", "vadd", "vreduce_sum"]
        .iter()
        .all(|&opcode| optimized.contains(opcode));
    assert!(
        has_dot_shape,
        "expected dot-product shape (vmul + vadd + vreduce_sum):\n{}",
        optimized
    );

    // Runtime: sum(i*i for i = 1..32) = 32*33*65/6 = 11440.
    let run_output = capture_run_stdout(CaptureRequest {
        input: program,
        requested: BTreeSet::from([Stage::Run]),
        opt_level: OptLevel::O3,
    });

    // Compare only the non-blank lines, with surrounding whitespace removed.
    let mut printed: Vec<&str> = Vec::new();
    for raw in run_output.lines() {
        let line = raw.trim();
        if !line.is_empty() {
            printed.push(line);
        }
    }
    assert_eq!(
        printed,
        vec!["11440"],
        "vectorized dot product should produce 11440:\n{}",
        run_output
    );
}
77