| 1 | use std::collections::BTreeSet; |
| 2 | use std::path::PathBuf; |
| 3 | |
| 4 | use armfortas::driver::OptLevel; |
| 5 | use armfortas::testing::{capture_from_path, CaptureRequest, CapturedStage, Stage}; |
| 6 | |
/// Resolves `name` inside the `test_programs` directory.
///
/// # Panics
///
/// Panics with `missing test fixture <path>` when the resolved path does not
/// exist.
fn fixture(name: &str) -> PathBuf {
    let mut location = PathBuf::from("test_programs");
    location.push(name);
    if !location.exists() {
        panic!("missing test fixture {}", location.display());
    }
    location
}
| 12 | |
| 13 | fn capture_text(request: CaptureRequest, stage: Stage) -> String { |
| 14 | let result = capture_from_path(&request).expect("capture should succeed"); |
| 15 | match result.get(stage) { |
| 16 | Some(CapturedStage::Text(text)) => text.clone(), |
| 17 | Some(CapturedStage::Run(_)) => panic!("expected text stage for {}", stage.as_str()), |
| 18 | None => panic!("missing requested stage {}", stage.as_str()), |
| 19 | } |
| 20 | } |
| 21 | |
/// Returns the slice of `ir` holding the function called `name`, from its
/// `func @name` header through the first following `"\n }"` closer line.
///
/// A header only counts as an exact hit when the name ends right after the
/// match (end of input, a space, `(` or a line break), so asking for `work`
/// does not return the section of `work_helper` just because it comes first.
/// If no exact hit exists, the first prefix match is used instead, which keeps
/// every previously successful lookup working.
///
/// # Panics
///
/// Panics when no header starting with `func @name` exists, or when the
/// selected header is never followed by the `"\n }\n"` closer.
fn function_section<'a>(ir: &'a str, name: &str) -> &'a str {
    let header = format!("func @{}", name);

    // True when the text right after the header match cannot be a
    // continuation of a longer function name.
    let ends_name_at = |idx: usize| {
        matches!(
            ir[idx + header.len()..].chars().next(),
            None | Some(' ' | '(' | '\n' | '\r')
        )
    };

    let mut candidates = ir.match_indices(header.as_str()).map(|(idx, _)| idx);
    let first = candidates
        .next()
        .unwrap_or_else(|| panic!("missing function section for {}", name));
    let start = if ends_name_at(first) {
        first
    } else {
        // Prefer a later exact header; otherwise keep the legacy behaviour.
        candidates.find(|&idx| ends_name_at(idx)).unwrap_or(first)
    };

    let rest = &ir[start..];
    let end = rest
        .find("\n }\n")
        .unwrap_or_else(|| panic!("unterminated function section for {}", name));
    &rest[..end + "\n }".len()]
}
| 33 | |
/// Splits `ir` into one slice per `func @` occurrence, in source order.
///
/// Each slice starts at a `func @` marker and runs through the first
/// `"\n }"` closer that is itself followed by a newline.
///
/// # Panics
///
/// Panics when any `func @` marker has no such closer after it.
fn function_sections(ir: &str) -> Vec<&str> {
    let mut sections = Vec::new();
    for (offset, _) in ir.match_indices("func @") {
        let remainder = &ir[offset..];
        match remainder.find("\n }\n") {
            // Keep the closer line itself but not the newline after it.
            Some(close) => sections.push(&remainder[..close + "\n }".len()]),
            None => panic!("unterminated function section in:\n{}", remainder),
        }
    }
    sections
}
| 45 | |
/// Extracts the function name from the first line of `func_section`.
///
/// The first line is trimmed, must begin with `func @`, and the name is the
/// text after that marker up to the first space or `(` (or the end of the
/// line when neither occurs).
///
/// # Panics
///
/// Panics when `func_section` has no first line or when that line does not
/// start with `func @` after trimming.
fn function_name<'a>(func_section: &'a str) -> &'a str {
    let first_line = func_section.lines().next().expect("function header");
    let after_marker = first_line
        .trim()
        .strip_prefix("func @")
        .expect("function header prefix");
    match after_marker.find(|c: char| matches!(c, ' ' | '(')) {
        Some(stop) => &after_marker[..stop],
        None => after_marker,
    }
}
| 56 | |
/// Returns the first block of `func_section` whose label starts with `prefix`.
///
/// The first line of `func_section` (the function header) is never treated as
/// a label. A label line is one indented by exactly four spaces. The returned
/// slice starts at the matching label line and stops before the newline that
/// precedes the next label line or the `" }"` closer line; if neither follows,
/// it runs to the end of `func_section`.
///
/// NOTE(review): the four-space label indent is inferred from the label text
/// being read at byte offset 4 of the line — confirm against the IR printer.
///
/// # Panics
///
/// Panics when no label line starting with `prefix` exists.
fn block_section<'a>(func_section: &'a str, prefix: &str) -> &'a str {
    /// Label text of a line indented by exactly four spaces, if it is one.
    fn label_text(line: &str) -> Option<&str> {
        line.strip_prefix("    ")
            .filter(|rest| !rest.starts_with(' '))
    }

    let mut start = None;
    let mut end = func_section.len();

    for (newline_idx, _) in func_section.match_indices('\n') {
        let line_start = newline_idx + 1;
        let line = func_section[line_start..]
            .split('\n')
            .next()
            .unwrap_or("");

        if start.is_none() {
            // Still looking for the opening label of the requested block.
            if matches!(label_text(line), Some(label) if label.starts_with(prefix)) {
                start = Some(line_start);
            }
        } else if label_text(line).is_some() || line == " }" {
            // The next label, or the function closer, ends the block.
            end = newline_idx;
            break;
        }
    }

    let start = start.unwrap_or_else(|| panic!("missing block with prefix {}", prefix));
    &func_section[start..end]
}
| 90 | |
/// Returns the last block of `func_section` whose label starts with `prefix`.
///
/// The first line of `func_section` (the function header) is never treated as
/// a label. A label line is one indented by exactly four spaces. The returned
/// slice starts at the last matching label line and stops before the newline
/// that precedes the next label line or the `" }"` closer line; if neither
/// follows, it runs to the end of `func_section`.
///
/// NOTE(review): the four-space label indent is inferred from the label text
/// being read at byte offset 4 of the line — confirm against the IR printer.
///
/// # Panics
///
/// Panics when no label line starting with `prefix` exists.
fn last_block_section<'a>(func_section: &'a str, prefix: &str) -> &'a str {
    /// Label text of a line indented by exactly four spaces, if it is one.
    fn label_text(line: &str) -> Option<&str> {
        line.strip_prefix("    ")
            .filter(|rest| !rest.starts_with(' '))
    }

    /// The line of `text` beginning at byte offset `start`, without its newline.
    fn line_at(text: &str, start: usize) -> &str {
        text[start..].split('\n').next().unwrap_or("")
    }

    // Scan line starts from the back so the first hit is the last matching label.
    let start = func_section
        .rmatch_indices('\n')
        .map(|(newline_idx, _)| newline_idx + 1)
        .find(|&line_start| {
            matches!(
                label_text(line_at(func_section, line_start)),
                Some(label) if label.starts_with(prefix)
            )
        })
        .unwrap_or_else(|| panic!("missing block with prefix {}", prefix));

    // The block ends just before the next label line or the function closer.
    let end = func_section[start..]
        .match_indices('\n')
        .map(|(rel_idx, _)| start + rel_idx)
        .find(|&newline_idx| {
            let line = line_at(func_section, newline_idx + 1);
            label_text(line).is_some() || line == " }"
        })
        .unwrap_or(func_section.len());

    &func_section[start..end]
}
| 130 | |
/// Checks that O2 leaves fewer calls to the helper function than the O0 IR
/// has, given that the raw join block of the worker calls it at least twice.
#[test]
fn o2_reuses_branch_join_affine_expression() {
    let source = fixture("realworld_join_bias_sum.f90");

    let raw_request = CaptureRequest {
        input: source.clone(),
        requested: BTreeSet::from([Stage::Ir]),
        opt_level: OptLevel::O0,
    };
    let raw_ir = capture_text(raw_request, Stage::Ir);

    let opt_request = CaptureRequest {
        input: source,
        requested: BTreeSet::from([Stage::OptIr]),
        opt_level: OptLevel::O2,
    };
    let opt_ir = capture_text(opt_request, Stage::OptIr);

    let raw_sections = function_sections(&raw_ir);
    assert_eq!(
        raw_sections.len(),
        3,
        "raw IR should include the program body plus a pure helper and a tally worker:\n{}",
        raw_ir
    );
    // Sections are taken positionally: index 1 is the helper, index 2 the worker.
    let helper_name = function_name(raw_sections[1]);
    let worker_name = function_name(raw_sections[2]);

    let raw_tally = function_section(&raw_ir, worker_name);
    let raw_join = last_block_section(raw_tally, "if_end_");

    // Every check below counts occurrences of the same call text.
    let helper_call = format!("call @{}", helper_name);
    let helper_calls_in = |text: &str| text.matches(helper_call.as_str()).count();

    assert!(
        helper_calls_in(raw_join) >= 2,
        "raw join block should still recompute the repeated branch-join PURE helper call:\n{}",
        raw_join
    );
    assert!(
        helper_calls_in(&opt_ir) < helper_calls_in(&raw_ir),
        "O2 should reduce duplicated branch-join PURE helper calls:\n{}",
        opt_ir
    );
}
| 176 | |
/// Checks that the worker function stores a `const_int 0` value at O0 and
/// that no such store remains in the same function at O2.
#[test]
fn o2_removes_dead_seed_store_across_noalias_call() {
    /// Counts `store %ID, ...` lines whose `%ID` is defined in
    /// `func_section` by a line containing ` = const_int 0 `.
    ///
    /// Counting stores of a zero constant, rather than all stores, keeps the
    /// check robust to loop transformations such as partial unrolling that
    /// multiply store counts.
    fn count_store_of_zero(func_section: &str) -> usize {
        use std::collections::HashSet;

        // First pass: every value id defined as `const_int 0`.
        let mut zero_ids: HashSet<&str> = HashSet::new();
        for line in func_section.lines() {
            if let Some(definition) = line.trim().strip_prefix("%") {
                if let Some((id, _)) = definition.split_once(" = const_int 0 ") {
                    zero_ids.insert(id);
                }
            }
        }

        // Second pass: stores whose first operand is one of those ids.
        let mut hits = 0;
        for line in func_section.lines() {
            if let Some(operands) = line.trim().strip_prefix("store %") {
                let stored = operands.split(',').next().unwrap_or("").trim();
                if zero_ids.contains(stored) {
                    hits += 1;
                }
            }
        }
        hits
    }

    let source = fixture("realworld_seed_overwrite.f90");

    let raw_request = CaptureRequest {
        input: source.clone(),
        requested: BTreeSet::from([Stage::Ir]),
        opt_level: OptLevel::O0,
    };
    let raw_ir = capture_text(raw_request, Stage::Ir);

    let opt_request = CaptureRequest {
        input: source,
        requested: BTreeSet::from([Stage::OptIr]),
        opt_level: OptLevel::O2,
    };
    let opt_ir = capture_text(opt_request, Stage::OptIr);

    let raw_sections = function_sections(&raw_ir);
    assert_eq!(
        raw_sections.len(),
        3,
        "raw IR should include the program body plus a helper and a fill worker:\n{}",
        raw_ir
    );
    // The worker is taken positionally as the third section.
    let worker_name = function_name(raw_sections[2]);

    let raw_fill = function_section(&raw_ir, worker_name);
    let opt_fill = function_section(&opt_ir, worker_name);
    let raw_body = block_section(raw_fill, "do_body_");
    // Not inspected further, but the lookup still runs: `block_section`
    // panics if the optimized worker has no `do_body_` block.
    let _opt_body = block_section(opt_fill, "do_body_");

    assert!(
        raw_body.matches("store ").count() >= 2,
        "raw loop body should still contain the seed store and the real fill store:\n{}",
        raw_body
    );

    let raw_zeros = count_store_of_zero(raw_fill);
    let opt_zeros = count_store_of_zero(opt_fill);
    assert!(
        raw_zeros >= 1,
        "raw seed_and_fill should have at least one store-of-zero (the dead seed):\n{}",
        raw_fill
    );
    assert_eq!(
        opt_zeros, 0,
        "O2 DSE should remove the dead store-of-zero (raw had {}):\n{}",
        raw_zeros, opt_fill
    );
}
| 256 |