Rust · 16384 bytes Raw Blame History
1 //! Ofast-only fast-math reassociation.
2 //!
3 //! Reassociates floating add/sub chains that consist of one non-constant
4 //! base value plus finite constant terms:
5 //!
6 //! - `(x + c1) + c2` -> `x + (c1 + c2)`
7 //! - `(x + c1) - c2` -> `x + (c1 - c2)`
8 //! - `(x - c1) + c2` -> `x + (c2 - c1)`
9 //! - `(x - c1) - c2` -> `x + (-(c1 + c2))`
10 //!
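//! Equal constants with opposite signs cancel exactly before anything is
//! summed (so `(x + c) - c` folds away), and when the combined constant is
//! zero the chain is replaced by its base value outright.
//!
//! Under strict IEEE semantics this can change results because it changes
//! rounding, overflow, and signed-zero behavior. We therefore gate it at
//! `-Ofast` only, where fast-math relaxation is explicitly enabled.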
14
15 use std::collections::{HashMap, HashSet};
16
17 use crate::ir::inst::*;
18 use crate::ir::types::{FloatWidth, IrType};
19 use crate::lexer::Span;
20
21 use super::pass::Pass;
22 use super::util::substitute_uses;
23
24 pub struct FastMathReassoc;
25
26 impl Pass for FastMathReassoc {
27 fn name(&self) -> &'static str {
28 "fast-math-reassoc"
29 }
30
31 fn run(&self, module: &mut Module) -> bool {
32 let mut changed = false;
33 for func in &mut module.functions {
34 if rewrite_function(func) {
35 changed = true;
36 }
37 }
38 changed
39 }
40 }
41
42 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
43 struct FloatConst {
44 width: FloatWidth,
45 bits: u64,
46 }
47
48 impl FloatConst {
49 fn from_value(value: f64, width: FloatWidth) -> Option<Self> {
50 let rounded = round_for_width(value, width);
51 if !rounded.is_finite() {
52 return None;
53 }
54 let bits = match width {
55 FloatWidth::F32 => (rounded as f32).to_bits() as u64,
56 FloatWidth::F64 => rounded.to_bits(),
57 };
58 Some(Self { width, bits })
59 }
60
61 fn as_f64(self) -> f64 {
62 match self.width {
63 FloatWidth::F32 => f32::from_bits(self.bits as u32) as f64,
64 FloatWidth::F64 => f64::from_bits(self.bits),
65 }
66 }
67
68 fn kind(self) -> InstKind {
69 InstKind::ConstFloat(self.as_f64(), self.width)
70 }
71 }
72
73 #[derive(Debug, Clone)]
74 struct AddChain {
75 base: Option<ValueId>,
76 terms: Vec<SignedFloatConst>,
77 additive_nodes: usize,
78 }
79
80 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
81 struct SignedFloatConst {
82 constant: FloatConst,
83 negative: bool,
84 }
85
86 impl SignedFloatConst {
87 fn positive(constant: FloatConst) -> Self {
88 Self {
89 constant,
90 negative: false,
91 }
92 }
93
94 fn negated(self) -> Self {
95 Self {
96 constant: self.constant,
97 negative: !self.negative,
98 }
99 }
100 }
101
102 #[derive(Debug, Clone, Copy)]
103 struct RewritePlan {
104 inst_id: ValueId,
105 block_id: BlockId,
106 base: ValueId,
107 constant: FloatConst,
108 span: Span,
109 }
110
111 fn rewrite_function(func: &mut Function) -> bool {
112 let defs = inst_map(func);
113 let mut plans = Vec::new();
114 let mut rewrites = HashMap::new();
115 let mut remove_ids = HashSet::new();
116
117 for block in &func.blocks {
118 for inst in &block.insts {
119 let width = match inst.ty {
120 IrType::Float(width) => width,
121 _ => continue,
122 };
123 if !matches!(inst.kind, InstKind::FAdd(..) | InstKind::FSub(..)) {
124 continue;
125 }
126
127 let Some(chain) = collect_chain(&defs, inst.id, width) else {
128 continue;
129 };
130 if chain.additive_nodes <= 1 {
131 continue;
132 }
133 let Some(base) = chain.base else {
134 continue;
135 };
136
137 let rounded = combine_terms(&chain.terms, width);
138 if is_effective_zero(rounded) {
139 rewrites.insert(inst.id, base);
140 remove_ids.insert(inst.id);
141 continue;
142 }
143
144 let Some(constant) = FloatConst::from_value(rounded, width) else {
145 continue;
146 };
147
148 let already_canonical = match inst.kind {
149 InstKind::FAdd(lhs, rhs) => {
150 lhs == base && const_float_of(&defs, rhs, width) == Some(constant)
151 }
152 _ => false,
153 };
154 if already_canonical {
155 continue;
156 }
157
158 plans.push(RewritePlan {
159 inst_id: inst.id,
160 block_id: block.id,
161 base,
162 constant,
163 span: inst.span,
164 });
165 }
166 }
167
168 if plans.is_empty() && rewrites.is_empty() {
169 return false;
170 }
171
172 let mut const_cache = HashMap::new();
173 for plan in plans {
174 let const_id = ensure_const_in_entry(func, &mut const_cache, plan.constant, plan.span);
175 let block = func.block_mut(plan.block_id);
176 let Some(inst) = block.insts.iter_mut().find(|inst| inst.id == plan.inst_id) else {
177 continue;
178 };
179 inst.kind = InstKind::FAdd(plan.base, const_id);
180 }
181
182 if !rewrites.is_empty() {
183 let keys: Vec<ValueId> = rewrites.keys().copied().collect();
184 for key in keys {
185 let mut cur = rewrites[&key];
186 let mut hops = 0usize;
187 while let Some(&next) = rewrites.get(&cur) {
188 if next == cur || hops > rewrites.len() {
189 break;
190 }
191 cur = next;
192 hops += 1;
193 }
194 rewrites.insert(key, cur);
195 }
196
197 for (old, new) in &rewrites {
198 substitute_uses(func, *old, *new);
199 }
200 for block in &mut func.blocks {
201 block.insts.retain(|inst| !remove_ids.contains(&inst.id));
202 }
203 }
204
205 true
206 }
207
208 fn inst_map(func: &Function) -> HashMap<ValueId, &Inst> {
209 func.blocks
210 .iter()
211 .flat_map(|block| block.insts.iter())
212 .map(|inst| (inst.id, inst))
213 .collect()
214 }
215
216 fn collect_chain(
217 defs: &HashMap<ValueId, &Inst>,
218 value: ValueId,
219 width: FloatWidth,
220 ) -> Option<AddChain> {
221 if let Some(constant) = const_float_of(defs, value, width) {
222 return Some(AddChain {
223 base: None,
224 terms: vec![SignedFloatConst::positive(constant)],
225 additive_nodes: 0,
226 });
227 }
228
229 let Some(inst) = defs.get(&value) else {
230 return Some(AddChain {
231 base: Some(value),
232 terms: Vec::new(),
233 additive_nodes: 0,
234 });
235 };
236 match inst.kind {
237 InstKind::FAdd(lhs, rhs) if inst.ty == IrType::Float(width) => {
238 let lhs_chain = collect_chain(defs, lhs, width)?;
239 let rhs_chain = collect_chain(defs, rhs, width)?;
240 let base = merge_bases(lhs_chain.base, rhs_chain.base)?;
241 let mut terms = lhs_chain.terms;
242 terms.extend(rhs_chain.terms);
243 Some(AddChain {
244 base,
245 terms,
246 additive_nodes: lhs_chain.additive_nodes + rhs_chain.additive_nodes + 1,
247 })
248 }
249 InstKind::FSub(lhs, rhs) if inst.ty == IrType::Float(width) => {
250 let lhs_chain = collect_chain(defs, lhs, width)?;
251 let rhs_chain = collect_chain(defs, rhs, width)?;
252 if rhs_chain.base.is_some() {
253 return None;
254 }
255 let mut terms = lhs_chain.terms;
256 terms.extend(rhs_chain.terms.into_iter().map(SignedFloatConst::negated));
257 Some(AddChain {
258 base: lhs_chain.base,
259 terms,
260 additive_nodes: lhs_chain.additive_nodes + rhs_chain.additive_nodes + 1,
261 })
262 }
263 _ => Some(AddChain {
264 base: Some(value),
265 terms: Vec::new(),
266 additive_nodes: 0,
267 }),
268 }
269 }
270
271 fn merge_bases(lhs: Option<ValueId>, rhs: Option<ValueId>) -> Option<Option<ValueId>> {
272 match (lhs, rhs) {
273 (Some(_), Some(_)) => None,
274 (Some(base), None) | (None, Some(base)) => Some(Some(base)),
275 (None, None) => Some(None),
276 }
277 }
278
279 fn const_float_of(
280 defs: &HashMap<ValueId, &Inst>,
281 value: ValueId,
282 width: FloatWidth,
283 ) -> Option<FloatConst> {
284 let inst = defs.get(&value)?;
285 match inst.kind {
286 InstKind::ConstFloat(value, inst_width) if inst_width == width => {
287 FloatConst::from_value(value, width)
288 }
289 _ => None,
290 }
291 }
292
293 fn combine_terms(terms: &[SignedFloatConst], width: FloatWidth) -> f64 {
294 let mut counts: HashMap<FloatConst, i32> = HashMap::new();
295 for term in terms {
296 *counts.entry(term.constant).or_insert(0) += if term.negative { -1 } else { 1 };
297 }
298
299 let mut surviving: Vec<(FloatConst, i32)> = counts
300 .into_iter()
301 .filter(|(_, count)| *count != 0)
302 .collect();
303 surviving.sort_by(|(lhs_const, lhs_count), (rhs_const, rhs_count)| {
304 lhs_const
305 .as_f64()
306 .abs()
307 .total_cmp(&rhs_const.as_f64().abs())
308 .then_with(|| lhs_const.as_f64().total_cmp(&rhs_const.as_f64()))
309 .then_with(|| lhs_count.cmp(rhs_count))
310 });
311
312 let mut sum = 0.0;
313 for (constant, count) in surviving {
314 for _ in 0..count.unsigned_abs() {
315 if count > 0 {
316 sum += constant.as_f64();
317 } else {
318 sum -= constant.as_f64();
319 }
320 }
321 }
322 round_for_width(sum, width)
323 }
324
325 fn ensure_const_in_entry(
326 func: &mut Function,
327 cache: &mut HashMap<FloatConst, ValueId>,
328 value: FloatConst,
329 span: Span,
330 ) -> ValueId {
331 if let Some(&id) = cache.get(&value) {
332 return id;
333 }
334
335 let id = func.next_value_id();
336 let ty = IrType::Float(value.width);
337 func.register_type(id, ty.clone());
338 let inst = Inst {
339 id,
340 kind: value.kind(),
341 ty,
342 span,
343 };
344 let entry = func.entry;
345 func.block_mut(entry).insts.insert(0, inst);
346 cache.insert(value, id);
347 id
348 }
349
350 fn round_for_width(value: f64, width: FloatWidth) -> f64 {
351 match width {
352 FloatWidth::F32 => value as f32 as f64,
353 FloatWidth::F64 => value,
354 }
355 }
356
/// Whether a folded constant is zero of either sign.
///
/// Under fast-math the sign of zero is not significant, so both `0.0` and
/// `-0.0` let the chain collapse to its base. NaN, infinities and subnormals
/// are never treated as zero.
fn is_effective_zero(value: f64) -> bool {
    matches!(value.classify(), std::num::FpCategory::Zero)
}
360
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::types::IrType;
    use crate::lexer::Position;

    fn span() -> Span {
        let pos = Position { line: 0, col: 0 };
        Span {
            file_id: 0,
            start: pos,
            end: pos,
        }
    }

    /// Appends an instruction of type `ty` to the entry block and returns its id.
    fn push(f: &mut Function, kind: InstKind, ty: IrType) -> ValueId {
        let id = f.next_value_id();
        let entry = f.entry;
        f.block_mut(entry).insts.push(Inst {
            id,
            kind,
            ty: ty.clone(),
            span: span(),
        });
        f.register_type(id, ty);
        id
    }

    /// Appends an `f64`-typed instruction to the entry block.
    fn push_f64(f: &mut Function, kind: InstKind) -> ValueId {
        push(f, kind, IrType::Float(FloatWidth::F64))
    }

    /// Builds `f(x: f64) -> f64` with `x` bound to `ValueId(0)`.
    fn f64_unary_function() -> Function {
        let params = vec![Param {
            name: "x".into(),
            ty: IrType::Float(FloatWidth::F64),
            id: ValueId(0),
            fortran_noalias: false,
        }];
        Function::new("f".into(), params, IrType::Float(FloatWidth::F64))
    }

    /// Makes `ret` the entry block's return value and wraps `func` in a module.
    fn module_returning(mut func: Function, ret: ValueId) -> Module {
        let entry = func.entry;
        func.block_mut(entry).terminator = Some(Terminator::Return(Some(ret)));
        let mut module = Module::new("t".into());
        module.add_function(func);
        module
    }

    /// Looks up an instruction by id in the first (entry) block.
    fn find_inst(func: &Function, id: ValueId) -> Option<&Inst> {
        func.blocks[0].insts.iter().find(|inst| inst.id == id)
    }

    /// Builds `(x + 2.0) + 3.0` returning the outer add; yields (module, outer).
    fn nested_add_module() -> (Module, ValueId) {
        let mut func = f64_unary_function();
        let c1 = push_f64(&mut func, InstKind::ConstFloat(2.0, FloatWidth::F64));
        let c2 = push_f64(&mut func, InstKind::ConstFloat(3.0, FloatWidth::F64));
        let add1 = push_f64(&mut func, InstKind::FAdd(ValueId(0), c1));
        let add2 = push_f64(&mut func, InstKind::FAdd(add1, c2));
        (module_returning(func, add2), add2)
    }

    #[test]
    fn collapses_nested_float_constant_chain() {
        let (mut module, add2) = nested_add_module();

        assert!(FastMathReassoc.run(&mut module));
        let func = &module.functions[0];
        let insts = &func.blocks[0].insts;
        assert!(
            insts.iter().any(
                |inst| matches!(inst.kind, InstKind::ConstFloat(v, FloatWidth::F64) if v == 5.0)
            ),
            "reassociated chain should materialize a combined constant:\n{:?}",
            insts
        );
        assert!(
            matches!(func.blocks[0].terminator, Some(Terminator::Return(Some(v))) if v == add2),
            "outer value should stay the return root"
        );
        let outer = find_inst(func, add2).expect("outer add should remain");
        assert!(
            matches!(outer.kind, InstKind::FAdd(ValueId(0), _)),
            "outer add should become x + const, got {:?}",
            outer.kind
        );
    }

    #[test]
    fn second_run_reaches_fixpoint() {
        let (mut module, _) = nested_add_module();

        assert!(FastMathReassoc.run(&mut module), "first run should reassociate");
        // Every chain is now `x + const` (a single additive node).
        assert!(
            !FastMathReassoc.run(&mut module),
            "second run should find nothing left to fold"
        );
    }

    #[test]
    fn folds_subtract_subtract_chain_to_negated_sum() {
        let mut func = f64_unary_function();
        let two = push_f64(&mut func, InstKind::ConstFloat(2.0, FloatWidth::F64));
        let three = push_f64(&mut func, InstKind::ConstFloat(3.0, FloatWidth::F64));
        let sub1 = push_f64(&mut func, InstKind::FSub(ValueId(0), two));
        let sub2 = push_f64(&mut func, InstKind::FSub(sub1, three));
        let mut module = module_returning(func, sub2);

        assert!(FastMathReassoc.run(&mut module));
        let func = &module.functions[0];
        let outer = find_inst(func, sub2).expect("outer subtraction should remain");
        let const_id = match outer.kind {
            InstKind::FAdd(ValueId(0), const_id) => const_id,
            ref other => panic!("expected x + const after reassociation, got {:?}", other),
        };
        let const_inst =
            find_inst(func, const_id).expect("combined constant should be materialized");
        assert!(
            matches!(const_inst.kind, InstKind::ConstFloat(v, FloatWidth::F64) if v == -5.0),
            "(x - 2) - 3 should become x + -5, got {:?}",
            const_inst.kind
        );
    }

    #[test]
    fn leaves_chain_with_non_finite_constant_untouched() {
        let mut func = f64_unary_function();
        let inf = push_f64(
            &mut func,
            InstKind::ConstFloat(f64::INFINITY, FloatWidth::F64),
        );
        let two = push_f64(&mut func, InstKind::ConstFloat(2.0, FloatWidth::F64));
        let add1 = push_f64(&mut func, InstKind::FAdd(ValueId(0), inf));
        let add2 = push_f64(&mut func, InstKind::FAdd(add1, two));
        let mut module = module_returning(func, add2);

        // An infinite constant is not foldable, so it acts as a second
        // non-constant operand and the chain has no `base + constants` form.
        assert!(!FastMathReassoc.run(&mut module));
        let func = &module.functions[0];
        let outer = find_inst(func, add2).expect("outer add should remain");
        assert!(
            matches!(outer.kind, InstKind::FAdd(lhs, rhs) if lhs == add1 && rhs == two),
            "outer add should be unchanged, got {:?}",
            outer.kind
        );
    }

    #[test]
    fn cancels_rounding_sensitive_add_sub_chain_to_base() {
        let mut func = f64_unary_function();
        let big = push_f64(&mut func, InstKind::ConstFloat(1.0e16, FloatWidth::F64));
        let add = push_f64(&mut func, InstKind::FAdd(ValueId(0), big));
        let sub = push_f64(&mut func, InstKind::FSub(add, big));
        let mut module = module_returning(func, sub);

        assert!(FastMathReassoc.run(&mut module));
        let func = &module.functions[0];
        assert!(
            matches!(func.blocks[0].terminator, Some(Terminator::Return(Some(v))) if v == ValueId(0)),
            "cancelled chain should return the original base directly"
        );
        assert!(
            !func.blocks[0].insts.iter().any(|inst| inst.id == sub),
            "cancelled outer subtraction should be removed"
        );
    }

    #[test]
    fn preserves_small_constant_when_large_terms_cancel_under_ofast() {
        let mut func = f64_unary_function();
        let one = push_f64(&mut func, InstKind::ConstFloat(1.0, FloatWidth::F64));
        let big = push_f64(&mut func, InstKind::ConstFloat(1.0e16, FloatWidth::F64));
        let add_small = push_f64(&mut func, InstKind::FAdd(ValueId(0), one));
        let add_big = push_f64(&mut func, InstKind::FAdd(add_small, big));
        let sub_big = push_f64(&mut func, InstKind::FSub(add_big, big));
        let mut module = module_returning(func, sub_big);

        assert!(FastMathReassoc.run(&mut module));
        let func = &module.functions[0];
        let outer =
            find_inst(func, sub_big).expect("outer value should remain after reassociation");
        let const_id = match outer.kind {
            InstKind::FAdd(ValueId(0), const_id) => const_id,
            ref other => panic!("expected x + const after reassociation, got {:?}", other),
        };
        let const_inst =
            find_inst(func, const_id).expect("combined constant should be materialized");
        assert!(
            matches!(const_inst.kind, InstKind::ConstFloat(v, FloatWidth::F64) if v == 1.0),
            "large cancelling terms should preserve the remaining +1 constant, got {:?}",
            const_inst.kind
        );
    }
}
551