Rust · 4165 bytes Raw Blame History
1 use std::collections::BTreeSet;
2 use std::path::PathBuf;
3
4 use armfortas::driver::OptLevel;
5 use armfortas::testing::{capture_from_path, CaptureRequest, CapturedStage, Stage};
6
/// Resolves `name` inside the `test_programs` fixture directory.
///
/// # Panics
/// If the resulting path does not exist.
fn fixture(name: &str) -> PathBuf {
    let mut resolved = PathBuf::from("test_programs");
    resolved.push(name);
    assert!(resolved.exists(), "missing test fixture {}", resolved.display());
    resolved
}
12
13 fn capture_text(request: CaptureRequest, stage: Stage) -> String {
14 let result = capture_from_path(&request).expect("capture should succeed");
15 match result.get(stage) {
16 Some(CapturedStage::Text(text)) => text.clone(),
17 Some(CapturedStage::Run(_)) => panic!("expected text stage for {}", stage.as_str()),
18 None => panic!("missing requested stage {}", stage.as_str()),
19 }
20 }
21
/// Returns the section of `ir` for the function named exactly `name`.
///
/// The section runs from the start of its `func @name` header line through its
/// closing `}` line, without the trailing newline. Headers and closing braces are
/// found line by line, not by searching for hard-coded indentation strings:
/// - a header is any line whose trimmed text starts with `func @`;
/// - the function ends at the first later line that is `}` at the same
///   indentation as its header.
///
/// The name must match exactly, so asking for `foo` never yields `func @foo_impl`.
///
/// # Panics
/// If no function named `name` exists, or its header has no matching closing brace.
fn function_section<'a>(ir: &'a str, name: &str) -> &'a str {
    let start = function_header_offsets(ir)
        .find(|&offset| header_function_name(&ir[offset..]) == name)
        .unwrap_or_else(|| panic!("missing function section for {}", name));
    let rest = &ir[start..];
    let len = function_section_len(rest)
        .unwrap_or_else(|| panic!("unterminated function section for {}", name));
    &rest[..len]
}

/// Returns every function section in `ir`, in source order.
///
/// Sections are delimited exactly as in [`function_section`].
///
/// # Panics
/// If any function header has no matching closing brace.
fn function_sections(ir: &str) -> Vec<&str> {
    function_header_offsets(ir)
        .map(|offset| {
            let rest = &ir[offset..];
            let len = function_section_len(rest)
                .unwrap_or_else(|| panic!("unterminated function section in:\n{}", rest));
            &rest[..len]
        })
        .collect()
}

/// Yields `(byte offset of line start, line text without its '\n')` for each line of `text`.
fn lines_with_offsets(text: &str) -> impl Iterator<Item = (usize, &str)> + '_ {
    text.split_inclusive('\n').scan(0, |next_offset, chunk| {
        let offset = *next_offset;
        *next_offset += chunk.len();
        Some((offset, chunk.strip_suffix('\n').unwrap_or(chunk)))
    })
}

/// Byte offsets of the start of every line in `ir` whose trimmed text starts with `func @`.
fn function_header_offsets(ir: &str) -> impl Iterator<Item = usize> + '_ {
    lines_with_offsets(ir)
        .filter(|(_, line)| line.trim_start().starts_with("func @"))
        .map(|(offset, _)| offset)
}

/// Name token of the `func @…` header on the first line of `section`.
///
/// The token ends at the first space or `(`. Returns "" if the first line is not a header.
fn header_function_name(section: &str) -> &str {
    let header = section.lines().next().unwrap_or("");
    let rest = header.trim().strip_prefix("func @").unwrap_or("");
    rest.split(|ch: char| ch == ' ' || ch == '(')
        .next()
        .unwrap_or(rest)
}

/// Byte length of the function section starting at the beginning of `rest`.
///
/// The length covers the header line up to and including the closing `}` line,
/// which must sit at the header's indentation. Returns `None` if there is no such line.
/// A trailing newline after the `}` is not required.
fn function_section_len(rest: &str) -> Option<usize> {
    let mut lines = lines_with_offsets(rest);
    let (_, header) = lines.next()?;
    let indent = &header[..header.len() - header.trim_start().len()];
    lines
        .find(|(_, line)| line.strip_prefix(indent).map(str::trim_end) == Some("}"))
        .map(|(offset, line)| offset + line.len())
}
45
/// Extracts the function name from the `func @name…` header on the first line
/// of `func_section`. The name ends at the first space or `(`.
///
/// # Panics
/// If `func_section` is empty, or its first line (trimmed) does not start with `func @`.
fn function_name<'a>(func_section: &'a str) -> &'a str {
    let first_line = func_section.lines().next().expect("function header");
    let after_marker = first_line
        .trim()
        .strip_prefix("func @")
        .expect("function header prefix");
    // `split` always yields at least one piece: the text before the first delimiter.
    after_marker
        .split(|ch: char| ch == ' ' || ch == '(')
        .next()
        .unwrap_or(after_marker)
}
56
57 fn non_program_function_names(ir: &str) -> Vec<&str> {
58 function_sections(ir)
59 .into_iter()
60 .map(function_name)
61 .filter(|name| !name.starts_with("__prog_"))
62 .collect()
63 }
64
/// Returns the basic block of `func_section` whose label starts with `prefix`.
///
/// The result is the label line plus every following line up to, but not including,
/// the next label or the function's closing `}`. It has no trailing newline.
///
/// Lines are classified by structure rather than by a hard-coded column, so any
/// indent width (spaces or tabs) works:
/// - Label lines are the non-blank lines after the header that sit at the
///   shallowest indentation deeper than the header. Instruction lines are
///   indented further, so an instruction that merely mentions `prefix`
///   (e.g. a branch target) is never mistaken for a label.
/// - The closing brace is a `}` line at the header's indentation.
///
/// # Panics
/// If no label line starts with `prefix`.
fn block_section<'a>(func_section: &'a str, prefix: &str) -> &'a str {
    /// Byte width of the leading whitespace of `line`.
    fn indent_of(line: &str) -> usize {
        line.len() - line.trim_start().len()
    }

    // (byte offset of line start, line text without '\n') for every line.
    let mut offset = 0;
    let mut lines = Vec::new();
    for chunk in func_section.split_inclusive('\n') {
        lines.push((offset, chunk.strip_suffix('\n').unwrap_or(chunk)));
        offset += chunk.len();
    }
    let header_indent = lines.first().map_or(0, |&(_, header)| indent_of(header));
    // Everything after the header line.
    let body = lines.get(1..).unwrap_or(&[]);

    let label_indent = body
        .iter()
        .map(|&(_, line)| line)
        .filter(|line| !line.trim().is_empty() && indent_of(line) > header_indent)
        .map(indent_of)
        .min();
    let is_label =
        |line: &str| !line.trim().is_empty() && Some(indent_of(line)) == label_indent;
    let is_closing = |line: &str| indent_of(line) == header_indent && line.trim() == "}";

    let start_idx = body
        .iter()
        .position(|&(_, line)| is_label(line) && line.trim_start().starts_with(prefix))
        .unwrap_or_else(|| panic!("missing block with prefix {}", prefix));
    let start = body[start_idx].0;
    // Stop at the newline preceding the next label or the closing brace; with
    // neither present, the block runs to the end of the section.
    let end = body[start_idx + 1..]
        .iter()
        .find(|&&(_, line)| is_label(line) || is_closing(line))
        .map_or(func_section.len(), |&(line_start, _)| line_start - 1);
    &func_section[start..end]
}
98
/// At O2, the optimized IR should keep the `load %0` (per the assertion messages,
/// the non-aliasing dummy argument's load) in the kernel's `if_end_*` block.
/// The kernel's `do_body_*` block should not reload it.
#[test]
fn o2_hoists_noalias_dummy_load_into_loop_preheader() {
    let request = CaptureRequest {
        input: fixture("licm_noalias_dummy_load.f90"),
        requested: BTreeSet::from([Stage::OptIr]),
        opt_level: OptLevel::O2,
    };
    let opt_ir = capture_text(request, Stage::OptIr);

    // Every function other than the `__prog_*` ones should be the one contained kernel.
    let helpers = non_program_function_names(&opt_ir);
    assert_eq!(
        helpers.len(),
        1,
        "optimized IR should include exactly one contained kernel helper:\n{}",
        opt_ir
    );
    let kernel = function_section(&opt_ir, helpers[0]);

    // The `if_end_*` block is the one this test treats as the loop preheader.
    let preheader_has_load = block_section(kernel, "if_end_").contains("load %0");
    let body_has_load = block_section(kernel, "do_body_").contains("load %0");

    assert!(
        preheader_has_load,
        "O2 kernel preheader should preload the non-aliasing dummy arg:\n{}",
        kernel
    );
    assert!(
        !body_has_load,
        "O2 kernel loop body should reuse the hoisted dummy-arg load:\n{}",
        kernel
    );
}
134