Rust · 9254 bytes Raw Blame History
1 use std::collections::BTreeSet;
2 use std::path::PathBuf;
3
4 use armfortas::driver::OptLevel;
5 use armfortas::testing::{capture_from_path, CaptureRequest, CapturedStage, Stage};
6
/// Resolves a fixture file name to `test_programs/<name>`.
///
/// The path is relative, so it is looked up from the process's current working directory.
///
/// # Panics
/// Panics when nothing exists at the resulting path.
fn fixture(name: &str) -> PathBuf {
    let mut path = PathBuf::from("test_programs");
    path.push(name);
    // Fail here with the offending path rather than later inside the capture pipeline.
    assert!(path.exists(), "missing test fixture {}", path.display());
    path
}
12
13 fn capture_text(request: CaptureRequest, stage: Stage) -> String {
14 let result = capture_from_path(&request).expect("capture should succeed");
15 match result.get(stage) {
16 Some(CapturedStage::Text(text)) => text.clone(),
17 Some(CapturedStage::Run(_)) => panic!("expected text stage for {}", stage.as_str()),
18 None => panic!("missing requested stage {}", stage.as_str()),
19 }
20 }
21
/// Returns the text of the IR function called `name`.
///
/// The slice starts at the ` func @<name>` header (including its leading space) and ends
/// with the first following ` }` line; the newline after that brace is not included.
///
/// A header is accepted only when the character right after `name` cannot continue an
/// identifier (it is not alphanumeric and not `_`). This keeps a lookup for `foo` from
/// returning the section of `foo_bar` when that function happens to come first.
///
/// # Panics
/// Panics when no matching header exists, or when the header is not followed by a
/// `"\n }\n"` terminator.
fn function_section<'a>(ir: &'a str, name: &str) -> &'a str {
    let header = format!(" func @{}", name);
    let start = ir
        .match_indices(header.as_str())
        .map(|(idx, _)| idx)
        .find(|&idx| {
            // Reject headers where `name` is only a prefix of a longer identifier.
            !ir[idx + header.len()..].starts_with(|ch: char| ch.is_alphanumeric() || ch == '_')
        })
        .unwrap_or_else(|| panic!("missing function section for {}", name));
    let rest = &ir[start..];
    let end = rest
        .find("\n }\n")
        .unwrap_or_else(|| panic!("unterminated function section for {}", name));
    // Keep the closing brace line but drop the newline that follows it.
    &rest[..end + "\n }".len()]
}
33
/// Splits `ir` into one slice per ` func @` header, in order of appearance.
///
/// Each slice runs from the header through the first following ` }` line, without the
/// newline after the brace.
///
/// # Panics
/// Panics when a header is not followed by a `"\n }\n"` terminator.
fn function_sections(ir: &str) -> Vec<&str> {
    const CLOSING: &str = "\n }";

    let mut sections = Vec::new();
    for (offset, _) in ir.match_indices(" func @") {
        let tail = &ir[offset..];
        match tail.find("\n }\n") {
            Some(close) => sections.push(&tail[..close + CLOSING.len()]),
            None => panic!("unterminated function section in:\n{}", tail),
        }
    }
    sections
}
45
/// Extracts the function name from the first line of a function section.
///
/// The first line is trimmed, must begin with `func @`, and the name is whatever follows up
/// to the first space or `(` (or the end of the line when neither occurs).
///
/// # Panics
/// Panics when `func_section` has no lines or its first line does not start with `func @`.
fn function_name<'a>(func_section: &'a str) -> &'a str {
    let first_line = func_section.lines().next().expect("function header");
    let after_sigil = first_line
        .trim()
        .strip_prefix("func @")
        .expect("function header prefix");
    // `split` always yields a first piece, which is the text before the first delimiter.
    after_sigil
        .split(&[' ', '('][..])
        .next()
        .unwrap_or(after_sigil)
}
56
57 fn non_program_function_names(ir: &str) -> Vec<&str> {
58 function_sections(ir)
59 .into_iter()
60 .map(function_name)
61 .filter(|name| !name.starts_with("__prog_"))
62 .collect()
63 }
64
/// Returns the first block of `func_section` whose label starts with `prefix`.
///
/// A block label is a line indented by exactly four spaces. The first line of
/// `func_section` (the function header) is never considered. The returned slice starts at
/// the label line and stops just before the newline that precedes the next label line or
/// the function's closing brace; if neither follows, it runs to the end of `func_section`.
///
/// # Panics
/// Panics when no label line starts with `prefix`.
fn block_section<'a>(func_section: &'a str, prefix: &str) -> &'a str {
    // Number of leading spaces that marks a block label line.
    const LABEL_INDENT: usize = 4;

    // The line of `text` beginning at byte offset `start`, without its newline.
    fn line_at(text: &str, start: usize) -> &str {
        let rest = &text[start..];
        rest.split_once('\n').map_or(rest, |(line, _)| line)
    }

    // Count of leading space characters.
    fn indent_of(line: &str) -> usize {
        line.len() - line.trim_start_matches(' ').len()
    }

    fn is_label(line: &str) -> bool {
        indent_of(line) == LABEL_INDENT
    }

    // A lone `}` indented shallower than a label closes the function.
    fn is_function_close(line: &str) -> bool {
        indent_of(line) < LABEL_INDENT && line.trim_start_matches(' ') == "}"
    }

    // Byte offsets of every line after the header.
    let mut line_starts = func_section
        .match_indices('\n')
        .map(|(newline, _)| newline + 1);

    let start = line_starts
        .find(|&at| {
            let line = line_at(func_section, at);
            is_label(line) && line[LABEL_INDENT..].starts_with(prefix)
        })
        .unwrap_or_else(|| panic!("missing block with prefix {}", prefix));

    // The iterator now sits on the line after the label, so the label cannot end itself.
    let end = line_starts
        .find(|&at| {
            let line = line_at(func_section, at);
            is_label(line) || is_function_close(line)
        })
        // Step back onto the newline so the block does not end with a line break.
        .map_or(func_section.len(), |at| at - 1);

    &func_section[start..end]
}
98
/// Returns the last block of `func_section` whose label starts with `prefix`.
///
/// A block label is a line indented by exactly four spaces. The first line of
/// `func_section` (the function header) is never considered. The returned slice starts at
/// the label line and stops just before the newline that precedes the next label line or
/// the function's closing brace; if neither follows, it runs to the end of `func_section`.
///
/// # Panics
/// Panics when no label line starts with `prefix`.
fn last_block_section<'a>(func_section: &'a str, prefix: &str) -> &'a str {
    // Number of leading spaces that marks a block label line.
    const LABEL_INDENT: usize = 4;

    // The line of `text` beginning at byte offset `start`, without its newline.
    fn line_at(text: &str, start: usize) -> &str {
        let rest = &text[start..];
        rest.split_once('\n').map_or(rest, |(line, _)| line)
    }

    // Count of leading space characters.
    fn indent_of(line: &str) -> usize {
        line.len() - line.trim_start_matches(' ').len()
    }

    fn is_label(line: &str) -> bool {
        indent_of(line) == LABEL_INDENT
    }

    // A lone `}` indented shallower than a label closes the function.
    fn is_function_close(line: &str) -> bool {
        indent_of(line) < LABEL_INDENT && line.trim_start_matches(' ') == "}"
    }

    // Scan every line after the header and keep the final matching label.
    let start = func_section
        .match_indices('\n')
        .map(|(newline, _)| newline + 1)
        .filter(|&at| {
            let line = line_at(func_section, at);
            is_label(line) && line[LABEL_INDENT..].starts_with(prefix)
        })
        .last()
        .unwrap_or_else(|| panic!("missing block with prefix {}", prefix));

    let tail = &func_section[start..];
    // Only newlines inside `tail` are visited, so the label line cannot end itself.
    let end = tail
        .match_indices('\n')
        .map(|(newline, _)| newline)
        .find(|&newline| {
            let line = line_at(tail, newline + 1);
            is_label(line) || is_function_close(line)
        })
        .unwrap_or(tail.len());

    &tail[..end]
}
131
/// Returns the part of `text` that follows the first occurrence of `needle`.
///
/// # Panics
/// Panics when `needle` does not occur in `text`.
fn tail_after<'a>(text: &'a str, needle: &str) -> &'a str {
    match text.split_once(needle) {
        Some((_, after)) => after,
        None => panic!("missing '{}' in:\n{}", needle, text),
    }
}
138
/// Captures O0 IR and O2 optimized IR for `realworld_affine_shift.f90` and compares the
/// single non-program function in both.
///
/// The O0 `do_body_` block must contain the two `ptr<i32>` loads. In the O2 text, the two
/// `i32` loads must appear in the first `if_end_` block and must not appear in the first
/// `do_body_` block.
#[test]
fn o2_hoists_affine_dummy_loads_out_of_loop() {
    let source = fixture("realworld_affine_shift.f90");

    let raw_request = CaptureRequest {
        input: source.clone(),
        requested: BTreeSet::from([Stage::Ir]),
        opt_level: OptLevel::O0,
    };
    let raw_ir = capture_text(raw_request, Stage::Ir);

    let opt_request = CaptureRequest {
        input: source,
        requested: BTreeSet::from([Stage::OptIr]),
        opt_level: OptLevel::O2,
    };
    let opt_ir = capture_text(opt_request, Stage::OptIr);

    let helper_names = non_program_function_names(&raw_ir);
    assert_eq!(
        helper_names.len(),
        1,
        "raw IR should include exactly one contained affine helper:\n{}",
        raw_ir
    );
    let helper = helper_names[0];

    let raw_apply = function_section(&raw_ir, helper);
    let opt_apply = function_section(&opt_ir, helper);
    let raw_body = block_section(raw_apply, "do_body_");
    let opt_preheader = block_section(opt_apply, "if_end_");
    let opt_body = block_section(opt_apply, "do_body_");

    let wrapper_loads = ["load %5 : ptr<i32>", "load %7 : ptr<i32>"];
    let hoisted_loads = ["load %0 : i32", "load %1 : i32"];

    assert!(
        wrapper_loads.iter().all(|load| raw_body.contains(*load)),
        "raw LICM kernel should still chase scalar dummy wrappers inside the loop before hoisting:\n{}",
        raw_body
    );
    assert!(
        hoisted_loads.iter().all(|load| opt_preheader.contains(*load)),
        "O2 LICM should hoist invariant dummy loads into the loop preheader:\n{}",
        opt_apply
    );
    assert!(
        !hoisted_loads.iter().any(|load| opt_body.contains(*load)),
        "O2 loop body should reuse the hoisted dummy loads:\n{}",
        opt_apply
    );
}
189
/// Captures O0 IR and O2 optimized IR for `realworld_noalias_reuse.f90` and inspects the
/// second of its three non-program functions.
///
/// Within that function's first `do_body_` block, only the text after the call to the first
/// non-program function is examined: at O0 it must contain both `gep` and `load`, at O2 it
/// must contain no `gep`.
#[test]
fn o2_forwards_local_store_reuse_across_noalias_call() {
    let source = fixture("realworld_noalias_reuse.f90");

    let raw_request = CaptureRequest {
        input: source.clone(),
        requested: BTreeSet::from([Stage::Ir]),
        opt_level: OptLevel::O0,
    };
    let raw_ir = capture_text(raw_request, Stage::Ir);

    let opt_request = CaptureRequest {
        input: source,
        requested: BTreeSet::from([Stage::OptIr]),
        opt_level: OptLevel::O2,
    };
    let opt_ir = capture_text(opt_request, Stage::OptIr);

    let helper_names = non_program_function_names(&raw_ir);
    assert_eq!(
        helper_names.len(),
        3,
        "raw IR should include a side-effect helper plus two contained noalias workers:\n{}",
        raw_ir
    );
    let (side_effect_helper, local_worker) = (helper_names[0], helper_names[1]);

    let raw_local = function_section(&raw_ir, local_worker);
    let opt_local = function_section(&opt_ir, local_worker);
    let raw_body = block_section(raw_local, "do_body_");
    let opt_body = block_section(opt_local, "do_body_");

    // Only instructions that follow the helper call are relevant to the check.
    let side_effect_call = format!("call @{}", side_effect_helper);
    let raw_after_call = tail_after(raw_body, &side_effect_call);
    let opt_after_call = tail_after(opt_body, &side_effect_call);

    let raw_reloads = ["gep", "load"]
        .iter()
        .all(|op| raw_after_call.contains(*op));
    assert!(
        raw_reloads,
        "raw local LSF kernel should still recompute and reload y(i) after the helper call:\n{}",
        raw_body
    );

    let opt_recomputes_address = opt_after_call.contains("gep");
    assert!(
        !opt_recomputes_address,
        "O2 local LSF should reuse the stored y(i) value directly after the noalias helper call:\n{}",
        opt_body
    );
}
238
/// Captures O0 IR and O2 optimized IR for `realworld_noalias_reuse.f90` and inspects the
/// third of its three non-program functions.
///
/// The last `if_end_` block of that function must contain both `gep` and `load` at O0 and
/// must contain no `gep` at O2.
#[test]
fn o2_forwards_branch_join_reuse_across_noalias_side_call() {
    let source = fixture("realworld_noalias_reuse.f90");

    let raw_request = CaptureRequest {
        input: source.clone(),
        requested: BTreeSet::from([Stage::Ir]),
        opt_level: OptLevel::O0,
    };
    let raw_ir = capture_text(raw_request, Stage::Ir);

    let opt_request = CaptureRequest {
        input: source,
        requested: BTreeSet::from([Stage::OptIr]),
        opt_level: OptLevel::O2,
    };
    let opt_ir = capture_text(opt_request, Stage::OptIr);

    let helper_names = non_program_function_names(&raw_ir);
    assert_eq!(
        helper_names.len(),
        3,
        "raw IR should include a side-effect helper plus two contained noalias workers:\n{}",
        raw_ir
    );
    let branchy_worker = helper_names[2];

    let raw_branchy = function_section(&raw_ir, branchy_worker);
    let opt_branchy = function_section(&opt_ir, branchy_worker);
    let raw_join = last_block_section(raw_branchy, "if_end_");
    let opt_join = last_block_section(opt_branchy, "if_end_");

    let raw_reloads = ["gep", "load"].iter().all(|op| raw_join.contains(*op));
    assert!(
        raw_reloads,
        "raw branch-join kernel should still reload y(i) after the side-path helper call:\n{}",
        raw_join
    );

    let opt_recomputes_address = opt_join.contains("gep");
    assert!(
        !opt_recomputes_address,
        "O2 global LSF should remove the join-block y(i) reload across the noalias side-path call:\n{}",
        opt_join
    );
}
283