Rust · 8431 bytes Raw Blame History
1 use std::collections::BTreeSet;
2 use std::path::PathBuf;
3
4 use armfortas::driver::OptLevel;
5 use armfortas::testing::{capture_from_path, CaptureRequest, CapturedStage, Stage};
6
/// Builds the path of the test program `name` under the relative
/// `test_programs` directory.
///
/// # Panics
///
/// Panics with `missing test fixture <path>` when the resulting path does
/// not exist.
fn fixture(name: &str) -> PathBuf {
    let mut path = PathBuf::from("test_programs");
    path.push(name);
    assert!(path.exists(), "missing test fixture {}", path.display());
    path
}
12
13 fn capture_text(request: CaptureRequest, stage: Stage) -> String {
14 let result = capture_from_path(&request).expect("capture should succeed");
15 match result.get(stage) {
16 Some(CapturedStage::Text(text)) => text.clone(),
17 Some(CapturedStage::Run(_)) => panic!("expected text stage for {}", stage.as_str()),
18 None => panic!("missing requested stage {}", stage.as_str()),
19 }
20 }
21
/// Returns the slice of `ir` holding the function called `name`, from its
/// `func @name` header up to and including its closing brace line.
///
/// The header only matches when the name ends right after `name`, that is
/// when the next character is `(`, whitespace, or the end of the input.
/// This is the same boundary `function_name` uses, and it stops a lookup
/// for `tally` from returning an earlier `func @tally_helper`.
///
/// # Panics
///
/// Panics when no such header exists or when no closing brace line follows
/// the header.
fn function_section<'a>(ir: &'a str, name: &str) -> &'a str {
    let header = format!("func @{}", name);
    let start = ir
        .match_indices(header.as_str())
        .map(|(idx, _)| idx)
        .find(|&idx| {
            // Reject headers where the identifier keeps going past `name`.
            ir[idx + header.len()..]
                .chars()
                .next()
                .map_or(true, |ch| ch == '(' || ch.is_whitespace())
        })
        .unwrap_or_else(|| panic!("missing function section for {}", name));
    let rest = &ir[start..];
    // NOTE(review): the indentation inside this literal has to match the IR
    // printer's closing-brace line exactly — confirm against a real dump.
    let end = rest
        .find("\n }\n")
        .unwrap_or_else(|| panic!("unterminated function section for {}", name));
    // Keep the closing brace but drop the newline that follows it.
    &rest[..end + "\n }".len()]
}
33
/// Collects every function in `ir`, in order of appearance.
///
/// Each entry runs from a `func @` occurrence up to and including the first
/// closing brace line that follows it.
///
/// # Panics
///
/// Panics when a `func @` occurrence has no closing brace line after it.
fn function_sections(ir: &str) -> Vec<&str> {
    let mut sections = Vec::new();
    for (start, _) in ir.match_indices("func @") {
        let remainder = &ir[start..];
        // NOTE(review): the indentation inside this literal has to match the
        // IR printer's closing-brace line exactly — confirm against a real dump.
        match remainder.find("\n }\n") {
            Some(close) => sections.push(&remainder[..close + "\n }".len()]),
            None => panic!("unterminated function section in:\n{}", remainder),
        }
    }
    sections
}
45
/// Extracts the function name from the first line of `func_section`.
///
/// The first line is trimmed, must begin with `func @`, and the name is
/// everything after that prefix up to the first space or `(` (or the end of
/// the line when neither occurs).
///
/// # Panics
///
/// Panics when `func_section` has no lines or its first line does not start
/// with `func @`.
fn function_name<'a>(func_section: &'a str) -> &'a str {
    let first_line = func_section.lines().next().expect("function header");
    let after_prefix = first_line
        .trim()
        .strip_prefix("func @")
        .expect("function header prefix");
    // `split` always yields at least one piece, so the fallback is only a
    // formality; it mirrors "no delimiter means the whole remainder".
    after_prefix
        .split(|ch: char| ch == ' ' || ch == '(')
        .next()
        .unwrap_or(after_prefix)
}
56
/// Returns the first basic block of `func_section` whose label starts with
/// `prefix`.
///
/// A label line is a line indented by exactly four spaces; deeper lines are
/// treated as the block's contents. The returned slice starts at the label
/// line and stops just before the newline that precedes the next label line
/// or the function's closing brace. If neither follows, it runs to the end
/// of `func_section`. The first line of `func_section` (the function header)
/// is never considered.
///
/// The indentation is measured rather than compared against space-only
/// string literals, so the "exactly four columns" rule stays visible in the
/// code and cannot be lost to whitespace normalisation.
///
/// # Panics
///
/// Panics when no label line starts with `prefix`.
fn block_section<'a>(func_section: &'a str, prefix: &str) -> &'a str {
    // Column at which block labels sit.
    const LABEL_INDENT: usize = 4;

    // Number of leading ASCII spaces on `line`.
    fn indent_of(line: &str) -> usize {
        line.len() - line.trim_start_matches(' ').len()
    }

    let mut start = None;
    let mut end = None;

    for (idx, _) in func_section.match_indices('\n') {
        let line_start = idx + 1;
        let tail = &func_section[line_start..];
        let line = tail.split_once('\n').map_or(tail, |(line, _)| line);
        let indent = indent_of(line);
        let content = &line[indent..];

        if start.is_none() {
            if indent == LABEL_INDENT && content.starts_with(prefix) {
                start = Some(line_start);
            }
            continue;
        }

        // The block ends at the next label or at the closing brace of the
        // function, which sits at a shallower indentation than any label.
        let next_label = indent == LABEL_INDENT;
        let function_close = indent < LABEL_INDENT && content == "}";
        if next_label || function_close {
            end = Some(idx);
            break;
        }
    }

    let start = start.unwrap_or_else(|| panic!("missing block with prefix {}", prefix));
    let end = end.unwrap_or(func_section.len());
    &func_section[start..end]
}
90
/// Returns the last basic block of `func_section` whose label starts with
/// `prefix`.
///
/// A label line is a line indented by exactly four spaces; deeper lines are
/// treated as the block's contents. The returned slice starts at the final
/// matching label line and stops just before the newline that precedes the
/// next label line or the function's closing brace. If neither follows, it
/// runs to the end of `func_section`. The first line of `func_section` (the
/// function header) is never considered.
///
/// The indentation is measured rather than compared against space-only
/// string literals, so the "exactly four columns" rule stays visible in the
/// code and cannot be lost to whitespace normalisation.
///
/// # Panics
///
/// Panics when no label line starts with `prefix`.
fn last_block_section<'a>(func_section: &'a str, prefix: &str) -> &'a str {
    // Column at which block labels sit.
    const LABEL_INDENT: usize = 4;

    // Number of leading ASCII spaces on `line`.
    fn indent_of(line: &str) -> usize {
        line.len() - line.trim_start_matches(' ').len()
    }

    // The single line of `text` that begins at byte offset `line_start`.
    fn line_at(text: &str, line_start: usize) -> &str {
        let tail = &text[line_start..];
        tail.split_once('\n').map_or(tail, |(line, _)| line)
    }

    // Walk the line starts backwards so the first hit is the last label.
    let start = func_section
        .rmatch_indices('\n')
        .map(|(idx, _)| idx + 1)
        .find(|&line_start| {
            let line = line_at(func_section, line_start);
            let indent = indent_of(line);
            indent == LABEL_INDENT && line[indent..].starts_with(prefix)
        })
        .unwrap_or_else(|| panic!("missing block with prefix {}", prefix));

    let tail = &func_section[start..];
    let end = tail
        .match_indices('\n')
        .map(|(idx, _)| idx)
        .find(|&idx| {
            let line = line_at(tail, idx + 1);
            let indent = indent_of(line);
            // Next label, or the function's closing brace at a shallower
            // indentation than any label.
            indent == LABEL_INDENT || (indent < LABEL_INDENT && &line[indent..] == "}")
        })
        .unwrap_or(tail.len());
    &tail[..end]
}
130
/// Checks that the O2 pipeline emits fewer calls to the helper function than
/// the unoptimised IR, given that the unoptimised join block still calls it
/// at least twice.
#[test]
fn o2_reuses_branch_join_affine_expression() {
    let source = fixture("realworld_join_bias_sum.f90");

    // Capture the unoptimised IR first, then the O2 IR for the same input.
    let raw_request = CaptureRequest {
        input: source.clone(),
        requested: BTreeSet::from([Stage::Ir]),
        opt_level: OptLevel::O0,
    };
    let raw_ir = capture_text(raw_request, Stage::Ir);

    let opt_request = CaptureRequest {
        input: source,
        requested: BTreeSet::from([Stage::OptIr]),
        opt_level: OptLevel::O2,
    };
    let opt_ir = capture_text(opt_request, Stage::OptIr);

    let raw_sections = function_sections(&raw_ir);
    assert_eq!(
        raw_sections.len(),
        3,
        "raw IR should include the program body plus a pure helper and a tally worker:\n{}",
        raw_ir
    );
    // Names are read from the raw IR by position: the second section is the
    // helper, the third is the worker.
    let helper_name = function_name(raw_sections[1]);
    let worker_name = function_name(raw_sections[2]);

    let raw_tally = function_section(&raw_ir, worker_name);
    let raw_join = last_block_section(raw_tally, "if_end_");

    let helper_call = format!("call @{}", helper_name);
    let helper_calls_in = |text: &str| text.matches(helper_call.as_str()).count();

    assert!(
        helper_calls_in(raw_join) >= 2,
        "raw join block should still recompute the repeated branch-join PURE helper call:\n{}",
        raw_join
    );
    assert!(
        helper_calls_in(&opt_ir) < helper_calls_in(&raw_ir),
        "O2 should reduce duplicated branch-join PURE helper calls:\n{}",
        opt_ir
    );
}
176
/// Checks that the O2 IR of the fill worker no longer stores a zero
/// constant, given that the unoptimised IR stores one at least once.
#[test]
fn o2_removes_dead_seed_store_across_noalias_call() {
    // Detects the "seed-zero" pattern: a store whose value is a
    // `const_int 0`. Counting these rather than all stores keeps the check
    // robust to loop transformations such as partial unrolling, which
    // multiply store counts; once DSE removes the dead seed, no zero
    // constant feeds any store.
    fn count_store_of_zero(func_section: &str) -> usize {
        use std::collections::HashSet;

        // Every value id defined as `const_int 0`.
        let mut zero_ids = HashSet::new();
        for line in func_section.lines() {
            if let Some(definition) = line.trim().strip_prefix("%") {
                if let Some((id, _)) = definition.split_once(" = const_int 0 ") {
                    zero_ids.insert(id);
                }
            }
        }

        // Every `store %ID, ...` whose %ID is one of those ids.
        let mut stores = 0;
        for line in func_section.lines() {
            if let Some(operands) = line.trim().strip_prefix("store %") {
                let id = operands.split(',').next().unwrap_or("").trim();
                if zero_ids.contains(id) {
                    stores += 1;
                }
            }
        }
        stores
    }

    let source = fixture("realworld_seed_overwrite.f90");

    // Capture the unoptimised IR first, then the O2 IR for the same input.
    let raw_request = CaptureRequest {
        input: source.clone(),
        requested: BTreeSet::from([Stage::Ir]),
        opt_level: OptLevel::O0,
    };
    let raw_ir = capture_text(raw_request, Stage::Ir);

    let opt_request = CaptureRequest {
        input: source,
        requested: BTreeSet::from([Stage::OptIr]),
        opt_level: OptLevel::O2,
    };
    let opt_ir = capture_text(opt_request, Stage::OptIr);

    let raw_sections = function_sections(&raw_ir);
    assert_eq!(
        raw_sections.len(),
        3,
        "raw IR should include the program body plus a helper and a fill worker:\n{}",
        raw_ir
    );
    let worker_name = function_name(raw_sections[2]);

    let raw_fill = function_section(&raw_ir, worker_name);
    let opt_fill = function_section(&opt_ir, worker_name);
    let raw_body = block_section(raw_fill, "do_body_");
    // Not inspected further; the lookup is kept because it panics when the
    // optimised worker has no `do_body_` block.
    let _opt_body = block_section(opt_fill, "do_body_");

    assert!(
        raw_body.matches("store ").count() >= 2,
        "raw loop body should still contain the seed store and the real fill store:\n{}",
        raw_body
    );

    let raw_zeros = count_store_of_zero(raw_fill);
    let opt_zeros = count_store_of_zero(opt_fill);
    assert!(
        raw_zeros >= 1,
        "raw seed_and_fill should have at least one store-of-zero (the dead seed):\n{}",
        raw_fill
    );
    assert_eq!(
        opt_zeros, 0,
        "O2 DSE should remove the dead store-of-zero (raw had {}):\n{}",
        raw_zeros, opt_fill
    );
}
256