Rust · 6074 bytes Raw Blame History
1 use std::collections::BTreeSet;
2 use std::path::PathBuf;
3
4 use armfortas::driver::OptLevel;
5 use armfortas::testing::{capture_from_path, CaptureRequest, CapturedStage, RunCapture, Stage};
6
/// Resolves `name` inside the `test_programs` directory and returns the path.
///
/// # Panics
///
/// Panics when the resolved path does not exist, so a misspelled fixture name
/// fails immediately instead of surfacing later as a capture error.
fn fixture(name: &str) -> PathBuf {
    let mut resolved = PathBuf::from("test_programs");
    resolved.push(name);
    assert!(
        resolved.exists(),
        "missing test fixture {}",
        resolved.display()
    );
    resolved
}
12
13 fn capture_text(request: CaptureRequest, stage: Stage) -> String {
14 let result = capture_from_path(&request).expect("capture should succeed");
15 match result.get(stage) {
16 Some(CapturedStage::Text(text)) => text.clone(),
17 Some(CapturedStage::Run(_)) => panic!("expected text stage for {}", stage.as_str()),
18 None => panic!("missing requested stage {}", stage.as_str()),
19 }
20 }
21
22 fn capture_run(request: CaptureRequest) -> RunCapture {
23 let result = capture_from_path(&request).expect("capture should succeed");
24 match result.get(Stage::Run) {
25 Some(CapturedStage::Run(run)) => run.clone(),
26 Some(CapturedStage::Text(_)) => panic!("expected run stage"),
27 None => panic!("missing requested stage {}", Stage::Run.as_str()),
28 }
29 }
30
/// Returns the slice of `ir` holding the function called exactly `name`.
///
/// The slice starts at the ` func @<name>` header and ends with the first
/// `"\n }"` closing line that follows it (the newline after the brace is
/// required to be present but is not part of the returned slice).
///
/// A header only counts when the name is followed by `(`, whitespace, or the
/// end of the input. Without that check a lookup for `foo` would silently
/// return the section of `foo_bar` whenever the longer name appears first.
///
/// # Panics
///
/// Panics when no header for `name` exists, or when the matching header is not
/// followed by a closing `"\n }\n"`.
fn function_section<'a>(ir: &'a str, name: &str) -> &'a str {
    let header = format!(" func @{}", name);
    let start = ir
        .match_indices(header.as_str())
        .map(|(idx, _)| idx)
        .find(|&idx| {
            // Accept the match only when the identifier really ends here.
            ir[idx + header.len()..]
                .chars()
                .next()
                .map_or(true, |ch| ch == '(' || ch.is_whitespace())
        })
        .unwrap_or_else(|| panic!("missing function section for {}", name));
    let rest = &ir[start..];
    let end = rest
        .find("\n }\n")
        .unwrap_or_else(|| panic!("unterminated function section for {}", name));
    &rest[..end + "\n }".len()]
}
42
/// Collects every function section found in `ir`, in order of appearance.
///
/// Each section starts at an occurrence of ` func @` and ends with the first
/// `"\n }"` closing line after it (the newline following the brace must be
/// present but is not included in the slice).
///
/// # Panics
///
/// Panics when a ` func @` occurrence is not followed by `"\n }\n"`.
fn function_sections(ir: &str) -> Vec<&str> {
    let mut sections = Vec::new();
    for (start, _) in ir.match_indices(" func @") {
        let tail = &ir[start..];
        match tail.find("\n }\n") {
            Some(close) => sections.push(&tail[..close + "\n }".len()]),
            None => panic!("unterminated function section in:\n{}", tail),
        }
    }
    sections
}
54
/// Extracts the function identifier from the first line of `func_section`.
///
/// The first line is trimmed, the `func @` prefix is removed, and the name is
/// whatever precedes the first space or `(` (or the rest of the line when
/// neither occurs).
///
/// # Panics
///
/// Panics when `func_section` has no lines or when its first line does not
/// start with `func @` after trimming.
fn function_name<'a>(func_section: &'a str) -> &'a str {
    let first_line = func_section.lines().next().expect("function header");
    let after_sigil = first_line
        .trim()
        .strip_prefix("func @")
        .expect("function header prefix");
    // `split` always yields at least one piece; the first piece is the name.
    after_sigil
        .split(&[' ', '('][..])
        .next()
        .unwrap_or(after_sigil)
}
65
/// Compares the O0 `Stage::Ir` text with the O2 `Stage::OptIr` text of the
/// blend helper in `realworld_binomial_blend.f90`: the raw IR must contain the
/// four-element `alloca` and the bounds-check runtime call, and the optimized
/// IR must contain neither.
#[test]
fn o2_binomial_blend_scalarizes_taps_and_removes_safe_stencil_checks() {
    const TAPS_ALLOCA: &str = "alloca [i32 x 4]";
    const BOUNDS_CHECK: &str = "rt_call @__afs_check_bounds";

    let source = fixture("realworld_binomial_blend.f90");
    // Captures one text stage of the fixture at the given optimization level.
    let snapshot = |stage: Stage, opt_level: OptLevel| {
        capture_text(
            CaptureRequest {
                input: source.clone(),
                requested: BTreeSet::from([stage]),
                opt_level,
            },
            stage,
        )
    };

    let raw_ir = snapshot(Stage::Ir, OptLevel::O0);
    let opt_ir = snapshot(Stage::OptIr, OptLevel::O2);

    let raw_sections = function_sections(&raw_ir);
    assert_eq!(
        raw_sections.len(),
        2,
        "raw IR should include the program body plus one contained blend helper:\n{}",
        raw_ir
    );

    // The helper is the second section; look it up by name in both dumps.
    let helper_name = function_name(raw_sections[1]);
    let raw_blend = function_section(&raw_ir, helper_name);
    let opt_blend = function_section(&opt_ir, helper_name);

    assert!(
        raw_blend.contains(TAPS_ALLOCA),
        "raw IR should still materialize taps(4) as an aggregate before SROA:\n{}",
        raw_blend
    );
    assert!(
        raw_blend.contains(BOUNDS_CHECK),
        "raw IR should still contain stencil bounds checks before BCE:\n{}",
        raw_blend
    );
    assert!(
        !opt_blend.contains(TAPS_ALLOCA),
        "O2 optimized IR should scalarize/remove taps(4):\n{}",
        opt_blend
    );
    assert!(
        !opt_blend.contains(BOUNDS_CHECK),
        "O2 optimized IR should eliminate safe stencil bounds checks:\n{}",
        opt_blend
    );
}
119
/// Exercises `realworld_shape_guard.f90` in three ways: the O0 IR must call
/// the three `afs_array_*` shape queries, the program must exit with code 0
/// and print the expected values at every listed optimization level, and two
/// separate O2 `Stage::Obj` captures must be identical.
#[test]
fn realworld_shape_guard_uses_runtime_shape_queries_and_stays_deterministic() {
    let source = fixture("realworld_shape_guard.f90");
    // Builds a single-stage capture request for the fixture.
    let request = |stage: Stage, opt_level: OptLevel| CaptureRequest {
        input: source.clone(),
        requested: BTreeSet::from([stage]),
        opt_level,
    };

    let raw_ir = capture_text(request(Stage::Ir, OptLevel::O0), Stage::Ir);

    // (runtime call expected in the IR, Fortran intrinsic named in the message)
    let shape_queries = [
        ("call @afs_array_size(", "SIZE(work)"),
        ("call @afs_array_lbound(", "LBOUND(work, 1)"),
        ("call @afs_array_ubound(", "UBOUND(work, 1)"),
    ];
    for (runtime_call, intrinsic) in shape_queries {
        assert!(
            raw_ir.contains(runtime_call),
            "raw IR should route {} through the runtime shape query:\n{}",
            intrinsic,
            raw_ir
        );
    }

    let levels = [
        OptLevel::O0,
        OptLevel::O1,
        OptLevel::O2,
        OptLevel::O3,
        OptLevel::Os,
        OptLevel::Ofast,
    ];
    for level in levels {
        let run = capture_run(request(Stage::Run, level));
        assert_eq!(
            run.exit_code, 0,
            "real-world runtime-shape guard should run successfully at {:?}:\n{:#?}",
            level, run
        );
        assert!(
            run.stdout.contains("6 0 5 12 36"),
            "runtime-shape guard should preserve the descriptor-backed query results at {:?}:\n{:#?}",
            level, run
        );
    }

    // Two independent captures of the same stage must produce the same text.
    let object_snapshot = || capture_text(request(Stage::Obj, OptLevel::O2), Stage::Obj);
    let obj_a = object_snapshot();
    let obj_b = object_snapshot();
    assert_eq!(
        obj_a, obj_b,
        "descriptor-backed runtime-shape guard should have a deterministic O2 object snapshot"
    );
}
195