Rust · 35507 bytes Raw Blame History
1 //! Constant folding pass.
2 //!
3 //! Walks each function and replaces every instruction whose operands
4 //! are all compile-time constants with the corresponding `Const*`
5 //! instruction. The `ValueId` of each folded instruction is preserved
6 //! so that downstream uses see the same SSA name and don't need
7 //! rewriting.
8 //!
9 //! We deliberately keep this pass narrow:
10 //!
11 //! * Only operands defined by `ConstInt` / `ConstFloat` / `ConstBool`
12 //! in the same function are eligible.
13 //! * Integer arithmetic uses two's-complement wrapping with the
14 //! instruction's declared width — we mask down to 8/16/32/64 bits
15 //! so that `i8 + i8` overflow matches what codegen would produce.
16 //! * Integer division/modulo by zero is left alone — Fortran's
17 //! behavior on divide-by-zero is implementation-defined, and the
18 //! runtime / hardware should observe it the same way it would
19 //! without optimization.
20 //! * Float operations follow IEEE 754 semantics; division by zero
21 //! yields well-defined ±inf or NaN and is folded.
22 //!
23 //! Constant *propagation* (replacing uses of a value that happens to
24 //! be a constant) is a separate pass; this one only rewrites the
25 //! defining instruction in place. The pair runs to fixpoint via the
26 //! pass manager, so propagation feeds folding which feeds propagation.
27
28 use super::pass::Pass;
29 use crate::ir::inst::*;
30 use crate::ir::types::{FloatWidth, IntWidth, IrType};
31 use std::collections::HashMap;
32
33 /// Compile-time constant value, normalized for folding.
34 #[derive(Debug, Clone, Copy)]
35 enum Const {
36 Int(i128, IntWidth),
37 Float(f64, FloatWidth),
38 Bool(bool),
39 }
40
41 impl Const {
42 /// Build a `Const` from a constant-producing instruction, putting
43 /// the stored value into its **canonical** form for the declared
44 /// width. Audit N-9 (root-cause fix for N-1/M-1/M-2/M-5):
45 ///
46 /// * Integer values are sign-extended at their width. After this
47 /// call, `Const::Int(-1, I8)` and `Const::Int(255, I8)` both
48 /// land at the same stored value (`-1`), so downstream folds
49 /// can destructure `Const::Int(av, _)` without needing to
50 /// re-sext at the source width. This retires 12+ width-drift
51 /// hazards across IAdd/ISub/IMul/IDiv/IMod/BitAnd/... in one
52 /// move.
53 /// * Float values are rounded through their declared precision
54 /// (f32 values become exactly representable in f32 after
55 /// `v as f32 as f64`). This is belt-and-suspenders on top of
56 /// the producer-side guarantee from M-1.
57 fn from_inst(kind: &InstKind) -> Option<Self> {
58 match kind {
59 InstKind::ConstInt(v, w) => Some(Const::Int(sext(*v, w.bits()), *w)),
60 InstKind::ConstFloat(v, w) => Some(Const::Float(round_for_width(*v, *w), *w)),
61 InstKind::ConstBool(b) => Some(Const::Bool(*b)),
62 _ => None,
63 }
64 }
65 }
66
/// Sign-extend the low `bits` bits of `v` into a full-width `i128`.
///
/// Narrow integers must compare and compute the way the hardware would.
/// For example, a stored `0xff` tagged as 8-bit has to read back as `-1`.
/// Widths of 128 bits or more leave `v` unchanged.
fn sext(v: i128, bits: u32) -> i128 {
    if bits < 128 {
        // Move the narrow value's sign bit up to bit 127, then let the
        // arithmetic right shift replicate it over the high bits.
        let unused_high = 128 - bits;
        return (v << unused_high) >> unused_high;
    }
    v
}
77
/// Keep only the low `bits` bits of `v`, clearing everything above.
///
/// The IR stores integers as `i128`, but arithmetic wraps at the declared
/// width. Widths of 128 bits or more leave `v` unchanged.
fn mask(v: i128, bits: u32) -> i128 {
    if bits < 128 {
        let low_ones = (1i128 << bits) - 1;
        v & low_ones
    } else {
        v
    }
}
87
88 /// Truncate-then-sign-extend a wide arithmetic result back to its
89 /// declared width, normalized to a sign-extended i128.
90 fn norm(v: i128, w: IntWidth) -> i128 {
91 sext(mask(v, w.bits()), w.bits())
92 }
93
94 fn signed_min(w: IntWidth) -> i128 {
95 match w {
96 IntWidth::I8 => i8::MIN as i128,
97 IntWidth::I16 => i16::MIN as i128,
98 IntWidth::I32 => i32::MIN as i128,
99 IntWidth::I64 => i64::MIN as i128,
100 IntWidth::I128 => i128::MIN,
101 }
102 }
103
104 /// Round an `f64`-stored value to the precision of its declared
105 /// `FloatWidth`. Audit B-3: used as defense in depth in `FCmp` and
106 /// `FloatToInt` so that those folds don't silently miscompute when
107 /// (some hypothetical future) producer breaks the M-1 invariant
108 /// (every `Const::Float(_, F32)` value is already f32-rounded).
109 fn round_for_width(v: f64, w: FloatWidth) -> f64 {
110 match w {
111 FloatWidth::F32 => v as f32 as f64,
112 FloatWidth::F64 => v,
113 }
114 }
115
116 /// Try to evaluate a single instruction given a map of known
117 /// constants. Returns `Some(new_kind)` if the instruction can be
118 /// replaced with a `Const*`.
119 #[allow(clippy::too_many_lines)]
120 fn try_fold(kind: &InstKind, ty: &IrType, consts: &HashMap<ValueId, Const>) -> Option<InstKind> {
121 let get = |id: &ValueId| consts.get(id).copied();
122
123 match kind {
124 // Integer arithmetic ----------------------------------------------
125 InstKind::IAdd(a, b) => {
126 if let (Some(Const::Int(av, _)), Some(Const::Int(bv, _))) = (get(a), get(b)) {
127 if let IrType::Int(w) = ty {
128 return Some(InstKind::ConstInt(norm(av.wrapping_add(bv), *w), *w));
129 }
130 }
131 None
132 }
133 InstKind::ISub(a, b) => {
134 if let (Some(Const::Int(av, _)), Some(Const::Int(bv, _))) = (get(a), get(b)) {
135 if let IrType::Int(w) = ty {
136 return Some(InstKind::ConstInt(norm(av.wrapping_sub(bv), *w), *w));
137 }
138 }
139 None
140 }
141 InstKind::IMul(a, b) => {
142 if let (Some(Const::Int(av, _)), Some(Const::Int(bv, _))) = (get(a), get(b)) {
143 if let IrType::Int(w) = ty {
144 return Some(InstKind::ConstInt(norm(av.wrapping_mul(bv), *w), *w));
145 }
146 }
147 None
148 }
149 InstKind::IDiv(a, b) => {
150 if let (Some(Const::Int(av, _)), Some(Const::Int(bv, _))) = (get(a), get(b)) {
151 if let IrType::Int(w) = ty {
152 if bv == 0 {
153 return None;
154 } // leave divide-by-zero to runtime
155 if av == signed_min(*w) && bv == -1 {
156 return None;
157 }
158 return Some(InstKind::ConstInt(norm(av / bv, *w), *w));
159 }
160 }
161 None
162 }
163 InstKind::IMod(a, b) => {
164 if let (Some(Const::Int(av, _)), Some(Const::Int(bv, _))) = (get(a), get(b)) {
165 if let IrType::Int(w) = ty {
166 if bv == 0 {
167 return None;
168 }
169 if av == signed_min(*w) && bv == -1 {
170 return None;
171 }
172 return Some(InstKind::ConstInt(norm(av % bv, *w), *w));
173 }
174 }
175 None
176 }
177 InstKind::INeg(a) => {
178 if let Some(Const::Int(av, _)) = get(a) {
179 if let IrType::Int(w) = ty {
180 return Some(InstKind::ConstInt(norm(av.wrapping_neg(), *w), *w));
181 }
182 }
183 None
184 }
185
186 // Float arithmetic ------------------------------------------------
187 InstKind::FAdd(a, b) => fold_float_bin(get(a), get(b), ty, |x, y| x + y),
188 InstKind::FSub(a, b) => fold_float_bin(get(a), get(b), ty, |x, y| x - y),
189 InstKind::FMul(a, b) => fold_float_bin(get(a), get(b), ty, |x, y| x * y),
190 InstKind::FDiv(a, b) => fold_float_bin(get(a), get(b), ty, |x, y| x / y),
191 InstKind::FNeg(a) => fold_float_un(get(a), ty, |x| -x),
192 InstKind::FAbs(a) => fold_float_un(get(a), ty, |x| x.abs()),
193 InstKind::FSqrt(a) => fold_float_un(get(a), ty, |x| x.sqrt()),
194 InstKind::FPow(a, b) => fold_float_bin(get(a), get(b), ty, |x, y| x.powf(y)),
195
196 // Comparisons -----------------------------------------------------
197 InstKind::ICmp(op, a, b) => {
198 if let (Some(Const::Int(av, aw)), Some(Const::Int(bv, bw))) = (get(a), get(b)) {
199 // Audit M-D: each operand must be sign-extended at
200 // its OWN declared width. The earlier version reused
201 // `a`'s width for both operands, which is correct
202 // only when the verifier guarantees both operands
203 // have the same width — and the verifier currently
204 // only checks `is_int()`, not bit width. After a
205 // future GVN-style pass introduces width drift this
206 // becomes a wrong-result hazard.
207 let av = sext(av, aw.bits());
208 let bv = sext(bv, bw.bits());
209 let r = match op {
210 CmpOp::Eq => av == bv,
211 CmpOp::Ne => av != bv,
212 CmpOp::Lt => av < bv,
213 CmpOp::Le => av <= bv,
214 CmpOp::Gt => av > bv,
215 CmpOp::Ge => av >= bv,
216 };
217 return Some(InstKind::ConstBool(r));
218 }
219 None
220 }
221 InstKind::FCmp(op, a, b) => {
222 if let (Some(Const::Float(av, aw)), Some(Const::Float(bv, bw))) = (get(a), get(b)) {
223 // Audit B-3 (defense in depth): explicitly round each
224 // operand at its declared width before comparing.
225 // Today every Const::Float producer respects the
226 // M-1 invariant that f32-tagged values are already
227 // f32-rounded — but doing the rounding here protects
228 // future folds from a bug that breaks the invariant.
229 let av = round_for_width(av, aw);
230 let bv = round_for_width(bv, bw);
231 let r = match op {
232 CmpOp::Eq => av == bv,
233 CmpOp::Ne => av != bv,
234 CmpOp::Lt => av < bv,
235 CmpOp::Le => av <= bv,
236 CmpOp::Gt => av > bv,
237 CmpOp::Ge => av >= bv,
238 };
239 return Some(InstKind::ConstBool(r));
240 }
241 None
242 }
243
244 // Logic on bools --------------------------------------------------
245 InstKind::And(a, b) => {
246 if let (Some(Const::Bool(av)), Some(Const::Bool(bv))) = (get(a), get(b)) {
247 return Some(InstKind::ConstBool(av && bv));
248 }
249 None
250 }
251 InstKind::Or(a, b) => {
252 if let (Some(Const::Bool(av)), Some(Const::Bool(bv))) = (get(a), get(b)) {
253 return Some(InstKind::ConstBool(av || bv));
254 }
255 None
256 }
257 InstKind::Not(a) => {
258 if let Some(Const::Bool(av)) = get(a) {
259 return Some(InstKind::ConstBool(!av));
260 }
261 None
262 }
263
264 // Bitwise ---------------------------------------------------------
265 InstKind::BitAnd(a, b) => fold_int_bin(get(a), get(b), ty, |x, y| x & y),
266 InstKind::BitOr(a, b) => fold_int_bin(get(a), get(b), ty, |x, y| x | y),
267 InstKind::BitXor(a, b) => fold_int_bin(get(a), get(b), ty, |x, y| x ^ y),
268 InstKind::BitNot(a) => {
269 if let Some(Const::Int(av, _)) = get(a) {
270 if let IrType::Int(w) = ty {
271 return Some(InstKind::ConstInt(norm(!av, *w), *w));
272 }
273 }
274 None
275 }
276 // Audit M-B / M-C: shift counts that are negative or ≥ the
277 // operand width are *implementation-defined* in the IR. We
278 // deliberately bail (return None) so the runtime / codegen
279 // path is the single source of truth. AArch64 LSL/LSR/ASR
280 // mask the count to the low log2(width) bits — emulating
281 // that here is risky because Fortran ISHFT lowering computes
282 // both branches and selects, and a bogus folded value could
283 // leak through CSE / strength_reduce / other passes.
284 InstKind::Shl(a, b) => {
285 if let (Some(Const::Int(av, _)), Some(Const::Int(bv, _))) = (get(a), get(b)) {
286 if let IrType::Int(w) = ty {
287 let bits = w.bits() as i128;
288 if (0..bits).contains(&bv) {
289 return Some(InstKind::ConstInt(norm(av.wrapping_shl(bv as u32), *w), *w));
290 }
291 return None;
292 }
293 }
294 None
295 }
296 InstKind::LShr(a, b) => {
297 if let (Some(Const::Int(av, _)), Some(Const::Int(bv, _))) = (get(a), get(b)) {
298 if let IrType::Int(w) = ty {
299 let bits = w.bits() as i128;
300 if (0..bits).contains(&bv) {
301 let unsigned = (mask(av, w.bits()) as u128) >> (bv as u32);
302 return Some(InstKind::ConstInt(norm(unsigned as i128, *w), *w));
303 }
304 return None;
305 }
306 }
307 None
308 }
309 InstKind::AShr(a, b) => {
310 if let (Some(Const::Int(av, _)), Some(Const::Int(bv, _))) = (get(a), get(b)) {
311 if let IrType::Int(w) = ty {
312 let bits = w.bits() as i128;
313 if (0..bits).contains(&bv) {
314 let signed = sext(av, w.bits()) >> (bv as u32);
315 return Some(InstKind::ConstInt(norm(signed, *w), *w));
316 }
317 return None;
318 }
319 }
320 None
321 }
322 // Audit Med-3: the count uses the *source* operand's width
323 // for the bit-mask but emits the result tagged with `inst.ty`'s
324 // width (taken from the IR declaration), not the source's.
325 // Today they always match because the lowerer keeps them in
326 // sync, but a future change that makes popcount return i32
327 // unconditionally (matching C intrinsics) would silently
328 // produce a wrong-width constant if we kept reading from `w`.
329 InstKind::PopCount(a) => {
330 if let (Some(Const::Int(av, src_w)), IrType::Int(out_w)) = (get(a), ty) {
331 let val = mask(av, src_w.bits()) as u128;
332 return Some(InstKind::ConstInt(
333 norm(val.count_ones() as i128, *out_w),
334 *out_w,
335 ));
336 }
337 None
338 }
339 InstKind::CountLeadingZeros(a) => {
340 if let (Some(Const::Int(av, src_w)), IrType::Int(out_w)) = (get(a), ty) {
341 let bits = src_w.bits();
342 let val = mask(av, bits) as u128;
343 let lz = if val == 0 {
344 bits as i128
345 } else {
346 (val.leading_zeros() as i128) - (128 - bits as i128)
347 };
348 return Some(InstKind::ConstInt(norm(lz, *out_w), *out_w));
349 }
350 None
351 }
352 InstKind::CountTrailingZeros(a) => {
353 if let (Some(Const::Int(av, src_w)), IrType::Int(out_w)) = (get(a), ty) {
354 let bits = src_w.bits();
355 let val = mask(av, bits) as u128;
356 let tz = if val == 0 {
357 bits as i128
358 } else {
359 val.trailing_zeros() as i128
360 };
361 return Some(InstKind::ConstInt(norm(tz, *out_w), *out_w));
362 }
363 None
364 }
365
366 // Conversions -----------------------------------------------------
367 InstKind::IntToFloat(a, fw) => {
368 if let Some(Const::Int(av, w)) = get(a) {
369 let signed = sext(av, w.bits());
370 // Round through the destination precision so that
371 // downstream `const_fold` operations see the exact
372 // value the runtime would observe. AArch64 `SCVTF`
373 // produces an f32 directly for f32 destinations; we
374 // approximate that by going through `f32` then back
375 // to `f64` for IR storage. The double-rounding from
376 // `i64 → f64 → f32` differs from a hypothetical
377 // direct `i64 → f32` by at most one ULP near
378 // round-tie boundaries — acceptable for now;
379 // codegen always does the direct conversion at
380 // runtime, so the only divergence is in IR-level
381 // operations chained off of this constant.
382 let v = match fw {
383 FloatWidth::F32 => signed as f32 as f64,
384 FloatWidth::F64 => signed as f64,
385 };
386 return Some(InstKind::ConstFloat(v, *fw));
387 }
388 None
389 }
390 InstKind::FloatToInt(a, w) => {
391 if let Some(Const::Float(av, src_fw)) = get(a) {
392 if av.is_nan() || av.is_infinite() {
393 return None;
394 }
395 // Audit B-3 (defense in depth): re-round at the
396 // source float width before truncating. The M-1
397 // invariant says any f32-tagged Const::Float is
398 // already f32-rounded, but doing it explicitly here
399 // means a future invariant break can't silently
400 // produce a wrong int.
401 let av = round_for_width(av, src_fw);
402 // Fortran INT() truncates toward zero. Match.
403 let truncd = av.trunc();
404 let lo = match w {
405 IntWidth::I8 => i8::MIN as f64,
406 IntWidth::I16 => i16::MIN as f64,
407 IntWidth::I32 => i32::MIN as f64,
408 IntWidth::I64 => i64::MIN as f64,
409 IntWidth::I128 => return None,
410 };
411 let hi = match w {
412 IntWidth::I8 => i8::MAX as f64,
413 IntWidth::I16 => i16::MAX as f64,
414 IntWidth::I32 => i32::MAX as f64,
415 IntWidth::I64 => i64::MAX as f64,
416 IntWidth::I128 => return None,
417 };
418 if truncd < lo || truncd > hi {
419 return None;
420 }
421 return Some(InstKind::ConstInt(norm(truncd as i128, *w), *w));
422 }
423 None
424 }
425 InstKind::FloatExtend(a, fw) | InstKind::FloatTrunc(a, fw) => {
426 if let Some(Const::Float(av, _)) = get(a) {
427 let v = match fw {
428 FloatWidth::F32 => av as f32 as f64, // round through f32
429 FloatWidth::F64 => av,
430 };
431 return Some(InstKind::ConstFloat(v, *fw));
432 }
433 None
434 }
435 InstKind::IntExtend(a, w, signed) => {
436 if let Some(Const::Int(av, src_w)) = get(a) {
437 let v = if *signed {
438 sext(av, src_w.bits())
439 } else {
440 mask(av, src_w.bits())
441 };
442 return Some(InstKind::ConstInt(norm(v, *w), *w));
443 }
444 None
445 }
446 InstKind::IntTrunc(a, w) => {
447 if let Some(Const::Int(av, _)) = get(a) {
448 return Some(InstKind::ConstInt(norm(av, *w), *w));
449 }
450 None
451 }
452
453 // Select on a known condition ------------------------------------
454 InstKind::Select(c, t, f) => {
455 if let Some(Const::Bool(cv)) = get(c) {
456 let chosen = if cv { *t } else { *f };
457 if let Some(k) = consts.get(&chosen) {
458 // Re-emit using the **destination** type carried
459 // by the Select instruction itself, NOT the
460 // chosen branch's source type. After IntExtend /
461 // FloatTrunc / similar fold chains the chosen
462 // operand can carry a narrower width than the
463 // Select declares; reusing it would silently
464 // produce a `kind` whose embedded width disagrees
465 // with `inst.ty`. Audit C-C.
466 return match (ty, *k) {
467 (IrType::Int(w), Const::Int(v, src_w)) => {
468 // Audit B-1: sign-extend the chosen
469 // value at its **source** width before
470 // re-norming at the destination. The
471 // earlier version discarded `src_w` and
472 // treated the raw stored bits as already
473 // sign-extended at the destination width,
474 // silently producing a wrong-signed value
475 // when widths differed. Example:
476 // chosen = ConstInt(255, I8) (i.e. -1
477 // at i8 precision); destination = I64.
478 // Without source-width sext, we'd emit
479 // ConstInt(255, I64) instead of
480 // ConstInt(-1, I64).
481 let signed = sext(v, src_w.bits());
482 Some(InstKind::ConstInt(norm(signed, *w), *w))
483 }
484 (IrType::Float(FloatWidth::F32), Const::Float(v, _)) => {
485 Some(InstKind::ConstFloat(v as f32 as f64, FloatWidth::F32))
486 }
487 (IrType::Float(FloatWidth::F64), Const::Float(v, _)) => {
488 Some(InstKind::ConstFloat(v, FloatWidth::F64))
489 }
490 (IrType::Bool, Const::Bool(b)) => Some(InstKind::ConstBool(b)),
491 // Type categories disagree — bail rather than
492 // type-pun. The verifier doesn't catch the
493 // mismatch (it only checks `is_int`/`is_float`
494 // not specific widths), so being conservative
495 // here is the correctness floor.
496 _ => None,
497 };
498 }
499 }
500 None
501 }
502
503 _ => None,
504 }
505 }
506
507 // fold_float_bin / fold_float_un rely on the N-9 invariant:
508 // `Const::from_inst` normalizes every incoming `Const::Float` via
509 // `round_for_width`, so `av` / `bv` below are already exact at their
510 // declared precision. This lets us operate on the raw f64 storage
511 // without re-rounding per input — the N-3 defense-in-depth concern
512 // raised in the third audit is satisfied at the producer side rather
513 // than repeated at every use site. We DO still round the result for
514 // f32 destinations because the op itself can produce a value
515 // outside f32's representable range.
516 fn fold_float_bin<F>(a: Option<Const>, b: Option<Const>, ty: &IrType, op: F) -> Option<InstKind>
517 where
518 F: FnOnce(f64, f64) -> f64,
519 {
520 if let (Some(Const::Float(av, _)), Some(Const::Float(bv, _))) = (a, b) {
521 if let IrType::Float(w) = ty {
522 let r = op(av, bv);
523 let r = round_for_width(r, *w);
524 return Some(InstKind::ConstFloat(r, *w));
525 }
526 }
527 None
528 }
529
530 fn fold_float_un<F>(a: Option<Const>, ty: &IrType, op: F) -> Option<InstKind>
531 where
532 F: FnOnce(f64) -> f64,
533 {
534 if let Some(Const::Float(av, _)) = a {
535 if let IrType::Float(w) = ty {
536 let r = op(av);
537 let r = round_for_width(r, *w);
538 return Some(InstKind::ConstFloat(r, *w));
539 }
540 }
541 None
542 }
543
544 fn fold_int_bin<F>(a: Option<Const>, b: Option<Const>, ty: &IrType, op: F) -> Option<InstKind>
545 where
546 F: FnOnce(i128, i128) -> i128,
547 {
548 if let (Some(Const::Int(av, _)), Some(Const::Int(bv, _))) = (a, b) {
549 if let IrType::Int(w) = ty {
550 return Some(InstKind::ConstInt(norm(op(av, bv), *w), *w));
551 }
552 }
553 None
554 }
555
556 /// The constant folding pass.
557 pub struct ConstFold;
558
559 impl Pass for ConstFold {
560 fn name(&self) -> &'static str {
561 "const-fold"
562 }
563
564 fn run(&self, module: &mut Module) -> bool {
565 let mut changed = false;
566 for func in &mut module.functions {
567 // Audit N-8: we walk `func.blocks` in vec order, which
568 // is NOT guaranteed to be reverse-postorder. If a fold
569 // in an early-vec block needs a constant defined by a
570 // later-vec block (that still dominates it in the CFG),
571 // the earlier version would miss the fold — and the
572 // outer pass-manager fixpoint would NOT rescue us when
573 // no other pass in the pipeline produced a change.
574 //
575 // Fix: two phases.
576 //
577 // 1. Pre-collect every `Const*` instruction's value so
578 // the map is fully populated before any fold runs.
579 // 2. Iterate the fold walk to a local fixpoint. A fold
580 // that creates a new constant is added to the map
581 // on the fly, unlocking downstream folds in the
582 // next iteration. Each iteration is monotonic (the
583 // consts map only grows), so the inner loop is
584 // bounded by O(number of foldable instructions).
585 let mut consts: HashMap<ValueId, Const> = HashMap::new();
586 for block in &func.blocks {
587 for inst in &block.insts {
588 if let Some(c) = Const::from_inst(&inst.kind) {
589 consts.insert(inst.id, c);
590 }
591 }
592 }
593
594 let mut inner_changed = true;
595 while inner_changed {
596 inner_changed = false;
597 for block in &mut func.blocks {
598 for inst in &mut block.insts {
599 if consts.contains_key(&inst.id) {
600 continue;
601 }
602 if let Some(new_kind) = try_fold(&inst.kind, &inst.ty, &consts) {
603 if let Some(c) = Const::from_inst(&new_kind) {
604 consts.insert(inst.id, c);
605 }
606 inst.kind = new_kind;
607 changed = true;
608 inner_changed = true;
609 }
610 }
611 }
612 }
613 }
614 changed
615 }
616 }
617
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::types::IrType;

    /// Placeholder source span for synthesized instructions.
    fn dummy_span() -> crate::lexer::Span {
        let p = crate::lexer::Position { line: 1, col: 1 };
        crate::lexer::Span {
            start: p,
            end: p,
            file_id: 0,
        }
    }

    /// Build a module with one function `f` whose entry block holds
    /// `insts` in order, followed by `return`. Returns the module and the
    /// ids handed out by `next_value_id`, one per instruction.
    ///
    /// The tests below refer to operands as `ValueId(0)`, `ValueId(1)`, …,
    /// which assumes `next_value_id` numbers a fresh function from 0.
    fn make_module_with(insts: Vec<(InstKind, IrType)>) -> (Module, Vec<ValueId>) {
        let mut m = Module::new("t".into());
        let mut f = Function::new("f".into(), vec![], IrType::Void);
        let entry = f.entry;
        let mut ids = Vec::new();
        for (kind, ty) in insts {
            let id = f.next_value_id();
            ids.push(id);
            f.block_mut(entry).insts.push(Inst {
                id,
                kind,
                ty,
                span: dummy_span(),
            });
        }
        f.block_mut(entry).terminator = Some(Terminator::Return(None));
        m.add_function(f);
        (m, ids)
    }

    /// Snapshot of the instruction kinds in the first block of the first
    /// function, in order.
    fn first_block_kinds(m: &Module) -> Vec<InstKind> {
        m.functions[0].blocks[0]
            .insts
            .iter()
            .map(|i| i.kind.clone())
            .collect()
    }

    #[test]
    fn folds_iadd_i32() {
        let (mut m, _) = make_module_with(vec![
            (
                InstKind::ConstInt(3, IntWidth::I32),
                IrType::Int(IntWidth::I32),
            ),
            (
                InstKind::ConstInt(4, IntWidth::I32),
                IrType::Int(IntWidth::I32),
            ),
            (
                InstKind::IAdd(ValueId(0), ValueId(1)),
                IrType::Int(IntWidth::I32),
            ),
        ]);
        assert!(ConstFold.run(&mut m));
        let kinds = first_block_kinds(&m);
        assert!(matches!(kinds[2], InstKind::ConstInt(7, IntWidth::I32)));
    }

    #[test]
    fn integer_overflow_wraps_at_width() {
        // i8: 100 + 50 = 150 → wraps to -106
        let (mut m, _) = make_module_with(vec![
            (
                InstKind::ConstInt(100, IntWidth::I8),
                IrType::Int(IntWidth::I8),
            ),
            (
                InstKind::ConstInt(50, IntWidth::I8),
                IrType::Int(IntWidth::I8),
            ),
            (
                InstKind::IAdd(ValueId(0), ValueId(1)),
                IrType::Int(IntWidth::I8),
            ),
        ]);
        assert!(ConstFold.run(&mut m));
        let kinds = first_block_kinds(&m);
        match kinds[2] {
            InstKind::ConstInt(v, IntWidth::I8) => assert_eq!(v, -106),
            _ => panic!("expected ConstInt"),
        }
    }

    #[test]
    fn idiv_by_zero_left_alone() {
        let (mut m, _) = make_module_with(vec![
            (
                InstKind::ConstInt(10, IntWidth::I32),
                IrType::Int(IntWidth::I32),
            ),
            (
                InstKind::ConstInt(0, IntWidth::I32),
                IrType::Int(IntWidth::I32),
            ),
            (
                InstKind::IDiv(ValueId(0), ValueId(1)),
                IrType::Int(IntWidth::I32),
            ),
        ]);
        assert!(!ConstFold.run(&mut m));
        let kinds = first_block_kinds(&m);
        assert!(matches!(kinds[2], InstKind::IDiv(..)));
    }

    #[test]
    fn idiv_signed_min_by_minus_one_left_alone() {
        // i32::MIN / -1 overflows i32; the fold must not invent a value.
        let (mut m, _) = make_module_with(vec![
            (
                InstKind::ConstInt(i128::from(i32::MIN), IntWidth::I32),
                IrType::Int(IntWidth::I32),
            ),
            (
                InstKind::ConstInt(-1, IntWidth::I32),
                IrType::Int(IntWidth::I32),
            ),
            (
                InstKind::IDiv(ValueId(0), ValueId(1)),
                IrType::Int(IntWidth::I32),
            ),
        ]);
        assert!(!ConstFold.run(&mut m));
        assert!(matches!(first_block_kinds(&m)[2], InstKind::IDiv(..)));
    }

    #[test]
    fn fmul_f64() {
        // Operands and product are exactly representable, so the result
        // can be compared exactly. (Literals like 3.14 / 6.28 trip
        // clippy's deny-by-default `approx_constant` lint.)
        let (mut m, _) = make_module_with(vec![
            (
                InstKind::ConstFloat(2.0, FloatWidth::F64),
                IrType::Float(FloatWidth::F64),
            ),
            (
                InstKind::ConstFloat(3.25, FloatWidth::F64),
                IrType::Float(FloatWidth::F64),
            ),
            (
                InstKind::FMul(ValueId(0), ValueId(1)),
                IrType::Float(FloatWidth::F64),
            ),
        ]);
        assert!(ConstFold.run(&mut m));
        let kinds = first_block_kinds(&m);
        match kinds[2] {
            InstKind::ConstFloat(v, FloatWidth::F64) => assert_eq!(v, 6.5),
            _ => panic!("expected ConstFloat"),
        }
    }

    #[test]
    fn icmp_eq_true() {
        let (mut m, _) = make_module_with(vec![
            (
                InstKind::ConstInt(5, IntWidth::I32),
                IrType::Int(IntWidth::I32),
            ),
            (
                InstKind::ConstInt(5, IntWidth::I32),
                IrType::Int(IntWidth::I32),
            ),
            (
                InstKind::ICmp(CmpOp::Eq, ValueId(0), ValueId(1)),
                IrType::Bool,
            ),
        ]);
        assert!(ConstFold.run(&mut m));
        assert!(matches!(
            first_block_kinds(&m)[2],
            InstKind::ConstBool(true)
        ));
    }

    #[test]
    fn icmp_signed_lt_with_i8_negatives() {
        // i8 representation: -1 stored as 0xff (255). Without sign-extension,
        // a naive (av < bv) would compare 255 < 1 → false; correct is true.
        let (mut m, _) = make_module_with(vec![
            (
                InstKind::ConstInt(-1, IntWidth::I8),
                IrType::Int(IntWidth::I8),
            ),
            (
                InstKind::ConstInt(1, IntWidth::I8),
                IrType::Int(IntWidth::I8),
            ),
            (
                InstKind::ICmp(CmpOp::Lt, ValueId(0), ValueId(1)),
                IrType::Bool,
            ),
        ]);
        assert!(ConstFold.run(&mut m));
        assert!(matches!(
            first_block_kinds(&m)[2],
            InstKind::ConstBool(true)
        ));
    }

    #[test]
    fn shl_power_of_two() {
        let (mut m, _) = make_module_with(vec![
            (
                InstKind::ConstInt(1, IntWidth::I32),
                IrType::Int(IntWidth::I32),
            ),
            (
                InstKind::ConstInt(3, IntWidth::I32),
                IrType::Int(IntWidth::I32),
            ),
            (
                InstKind::Shl(ValueId(0), ValueId(1)),
                IrType::Int(IntWidth::I32),
            ),
        ]);
        assert!(ConstFold.run(&mut m));
        match first_block_kinds(&m)[2] {
            InstKind::ConstInt(v, IntWidth::I32) => assert_eq!(v, 8),
            _ => panic!(),
        }
    }

    #[test]
    fn shl_by_full_width_left_alone() {
        // A count equal to the operand width is implementation-defined
        // in the IR (Audit M-B / M-C), so it must not be folded.
        let (mut m, _) = make_module_with(vec![
            (
                InstKind::ConstInt(1, IntWidth::I32),
                IrType::Int(IntWidth::I32),
            ),
            (
                InstKind::ConstInt(32, IntWidth::I32),
                IrType::Int(IntWidth::I32),
            ),
            (
                InstKind::Shl(ValueId(0), ValueId(1)),
                IrType::Int(IntWidth::I32),
            ),
        ]);
        assert!(!ConstFold.run(&mut m));
        assert!(matches!(first_block_kinds(&m)[2], InstKind::Shl(..)));
    }

    #[test]
    fn int_to_float_chain() {
        // const(42:i32) → int_to_float → const(42.0:f64)
        let (mut m, _) = make_module_with(vec![
            (
                InstKind::ConstInt(42, IntWidth::I32),
                IrType::Int(IntWidth::I32),
            ),
            (
                InstKind::IntToFloat(ValueId(0), FloatWidth::F64),
                IrType::Float(FloatWidth::F64),
            ),
        ]);
        assert!(ConstFold.run(&mut m));
        match first_block_kinds(&m)[1] {
            InstKind::ConstFloat(v, FloatWidth::F64) => assert_eq!(v, 42.0),
            _ => panic!(),
        }
    }

    #[test]
    fn int_to_float_f32_rounds_to_nearest_even() {
        // 2^24 + 1 is not representable in f32. It lies exactly halfway
        // between 2^24 and 2^24 + 2 and rounds to the even neighbour, 2^24.
        let (mut m, _) = make_module_with(vec![
            (
                InstKind::ConstInt(16_777_217, IntWidth::I32),
                IrType::Int(IntWidth::I32),
            ),
            (
                InstKind::IntToFloat(ValueId(0), FloatWidth::F32),
                IrType::Float(FloatWidth::F32),
            ),
        ]);
        assert!(ConstFold.run(&mut m));
        match first_block_kinds(&m)[1] {
            InstKind::ConstFloat(v, FloatWidth::F32) => assert_eq!(v, 16_777_216.0),
            _ => panic!("expected ConstFloat"),
        }
    }

    #[test]
    fn float_to_int_truncates_toward_zero() {
        // Fortran INT() semantics: -2.75 → -2, not -3.
        let (mut m, _) = make_module_with(vec![
            (
                InstKind::ConstFloat(-2.75, FloatWidth::F64),
                IrType::Float(FloatWidth::F64),
            ),
            (
                InstKind::FloatToInt(ValueId(0), IntWidth::I32),
                IrType::Int(IntWidth::I32),
            ),
        ]);
        assert!(ConstFold.run(&mut m));
        match first_block_kinds(&m)[1] {
            InstKind::ConstInt(v, IntWidth::I32) => assert_eq!(v, -2),
            _ => panic!("expected ConstInt"),
        }
    }

    #[test]
    fn select_on_known_cond_picks_branch() {
        let (mut m, _) = make_module_with(vec![
            (InstKind::ConstBool(true), IrType::Bool),
            (
                InstKind::ConstInt(10, IntWidth::I32),
                IrType::Int(IntWidth::I32),
            ),
            (
                InstKind::ConstInt(20, IntWidth::I32),
                IrType::Int(IntWidth::I32),
            ),
            (
                InstKind::Select(ValueId(0), ValueId(1), ValueId(2)),
                IrType::Int(IntWidth::I32),
            ),
        ]);
        assert!(ConstFold.run(&mut m));
        match first_block_kinds(&m)[3] {
            InstKind::ConstInt(v, IntWidth::I32) => assert_eq!(v, 10),
            _ => panic!(),
        }
    }

    #[test]
    fn fixpoint_via_chained_constants() {
        // (3 + 4) * 2 = 14, where the inner add must fold first.
        let (mut m, _) = make_module_with(vec![
            (
                InstKind::ConstInt(3, IntWidth::I32),
                IrType::Int(IntWidth::I32),
            ),
            (
                InstKind::ConstInt(4, IntWidth::I32),
                IrType::Int(IntWidth::I32),
            ),
            (
                InstKind::IAdd(ValueId(0), ValueId(1)),
                IrType::Int(IntWidth::I32),
            ),
            (
                InstKind::ConstInt(2, IntWidth::I32),
                IrType::Int(IntWidth::I32),
            ),
            (
                InstKind::IMul(ValueId(2), ValueId(3)),
                IrType::Int(IntWidth::I32),
            ),
        ]);
        // A single sweep already handles this because the block is
        // walked in order and the add's result is recorded immediately.
        assert!(ConstFold.run(&mut m));
        let kinds = first_block_kinds(&m);
        assert!(matches!(kinds[2], InstKind::ConstInt(7, IntWidth::I32)));
        assert!(matches!(kinds[4], InstKind::ConstInt(14, IntWidth::I32)));
    }

    #[test]
    fn nan_float_to_int_left_alone() {
        let (mut m, _) = make_module_with(vec![
            (
                InstKind::ConstFloat(f64::NAN, FloatWidth::F64),
                IrType::Float(FloatWidth::F64),
            ),
            (
                InstKind::FloatToInt(ValueId(0), IntWidth::I32),
                IrType::Int(IntWidth::I32),
            ),
        ]);
        assert!(!ConstFold.run(&mut m));
        assert!(matches!(first_block_kinds(&m)[1], InstKind::FloatToInt(..)));
    }
}
915