| 1 | //! Constant folding pass. |
| 2 | //! |
| 3 | //! Walks each function and replaces every instruction whose operands |
| 4 | //! are all compile-time constants with the corresponding `Const*` |
| 5 | //! instruction. The `ValueId` of each folded instruction is preserved |
| 6 | //! so that downstream uses see the same SSA name and don't need |
| 7 | //! rewriting. |
| 8 | //! |
| 9 | //! We deliberately keep this pass narrow: |
| 10 | //! |
| 11 | //! * Only operands defined by `ConstInt` / `ConstFloat` / `ConstBool` |
| 12 | //! in the same function are eligible. |
| 13 | //! * Integer arithmetic uses two's-complement wrapping with the |
| 14 | //! instruction's declared width — we mask down to 8/16/32/64 bits |
| 15 | //! so that `i8 + i8` overflow matches what codegen would produce. |
//! * Integer division/modulo by zero is left alone — Fortran's
//!   behavior on divide-by-zero is implementation-defined, so the
//!   runtime / hardware must see the same behavior with and without
//!   optimization.
| 20 | //! * Float operations follow IEEE 754 semantics; division by zero |
| 21 | //! yields well-defined ±inf or NaN and is folded. |
| 22 | //! |
| 23 | //! Constant *propagation* (replacing uses of a value that happens to |
| 24 | //! be a constant) is a separate pass; this one only rewrites the |
| 25 | //! defining instruction in place. The pair runs to fixpoint via the |
| 26 | //! pass manager, so propagation feeds folding which feeds propagation. |
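//!
//! For example (illustrative textual IR, not this compiler's actual
//! printer syntax):
//!
//! ```text
//! %0 = const_int 3 : i32
//! %1 = const_int 4 : i32
//! %2 = iadd %0, %1        ; rewritten in place to: const_int 7 : i32
//! ```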
| 27 | |
| 28 | use super::pass::Pass; |
| 29 | use crate::ir::inst::*; |
| 30 | use crate::ir::types::{FloatWidth, IntWidth, IrType}; |
| 31 | use std::collections::HashMap; |
| 32 | |
| 33 | /// Compile-time constant value, normalized for folding. |
| 34 | #[derive(Debug, Clone, Copy)] |
| 35 | enum Const { |
| 36 | Int(i128, IntWidth), |
| 37 | Float(f64, FloatWidth), |
| 38 | Bool(bool), |
| 39 | } |
| 40 | |
| 41 | impl Const { |
| 42 | /// Build a `Const` from a constant-producing instruction, putting |
| 43 | /// the stored value into its **canonical** form for the declared |
| 44 | /// width. Audit N-9 (root-cause fix for N-1/M-1/M-2/M-5): |
| 45 | /// |
| 46 | /// * Integer values are sign-extended at their width. After this |
| 47 | /// call, `Const::Int(-1, I8)` and `Const::Int(255, I8)` both |
| 48 | /// land at the same stored value (`-1`), so downstream folds |
| 49 | /// can destructure `Const::Int(av, _)` without needing to |
| 50 | /// re-sext at the source width. This retires 12+ width-drift |
| 51 | /// hazards across IAdd/ISub/IMul/IDiv/IMod/BitAnd/... in one |
| 52 | /// move. |
| 53 | /// * Float values are rounded through their declared precision |
| 54 | /// (f32 values become exactly representable in f32 after |
| 55 | /// `v as f32 as f64`). This is belt-and-suspenders on top of |
| 56 | /// the producer-side guarantee from M-1. |
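    ///
    /// A quick illustration (doctest `ignore`d since this is a
    /// private helper):
    ///
    /// ```ignore
    /// // 255 with 8 significant bits canonicalizes to -1.
    /// assert!(matches!(
    ///     Const::from_inst(&InstKind::ConstInt(255, IntWidth::I8)),
    ///     Some(Const::Int(-1, IntWidth::I8))
    /// ));
    /// ```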
| 57 | fn from_inst(kind: &InstKind) -> Option<Self> { |
| 58 | match kind { |
| 59 | InstKind::ConstInt(v, w) => Some(Const::Int(sext(*v, w.bits()), *w)), |
| 60 | InstKind::ConstFloat(v, w) => Some(Const::Float(round_for_width(*v, *w), *w)), |
| 61 | InstKind::ConstBool(b) => Some(Const::Bool(*b)), |
| 62 | _ => None, |
| 63 | } |
| 64 | } |
| 65 | } |
| 66 | |
| 67 | /// Sign-extend `v` from `bits` to a full i128. Used so that comparisons |
| 68 | /// and arithmetic on narrower integer types match hardware behavior. |
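///
/// A quick sketch (doctest `ignore`d since this is a private helper):
///
/// ```ignore
/// assert_eq!(sext(0xFF, 8), -1);  // high bit set: negative
/// assert_eq!(sext(0x7F, 8), 127); // high bit clear: unchanged
/// ```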
| 69 | fn sext(v: i128, bits: u32) -> i128 { |
| 70 | if bits >= 128 { |
| 71 | v |
| 72 | } else { |
| 73 | let shift = 128 - bits; |
| 74 | (v << shift) >> shift |
| 75 | } |
| 76 | } |
| 77 | |
| 78 | /// Mask `v` down to `bits` low bits. The IR stores integers as i128 but |
| 79 | /// arithmetic must wrap at the declared width. |
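///
/// For instance (doctest `ignore`d since this is a private helper):
///
/// ```ignore
/// assert_eq!(mask(-1, 8), 0xFF);    // low 8 bits of all-ones
/// assert_eq!(mask(0x1FF, 8), 0xFF); // bit 8 discarded
/// ```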
| 80 | fn mask(v: i128, bits: u32) -> i128 { |
| 81 | if bits >= 128 { |
| 82 | v |
| 83 | } else { |
| 84 | v & ((1i128 << bits) - 1) |
| 85 | } |
| 86 | } |
| 87 | |
| 88 | /// Truncate-then-sign-extend a wide arithmetic result back to its |
| 89 | /// declared width, normalized to a sign-extended i128. |
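///
/// E.g. an i8 result of 150 wraps (doctest `ignore`d, private helper):
///
/// ```ignore
/// assert_eq!(norm(150, IntWidth::I8), -106);
/// ```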
| 90 | fn norm(v: i128, w: IntWidth) -> i128 { |
| 91 | sext(mask(v, w.bits()), w.bits()) |
| 92 | } |
| 93 | |
| 94 | fn signed_min(w: IntWidth) -> i128 { |
| 95 | match w { |
| 96 | IntWidth::I8 => i8::MIN as i128, |
| 97 | IntWidth::I16 => i16::MIN as i128, |
| 98 | IntWidth::I32 => i32::MIN as i128, |
| 99 | IntWidth::I64 => i64::MIN as i128, |
| 100 | IntWidth::I128 => i128::MIN, |
| 101 | } |
| 102 | } |
| 103 | |
| 104 | /// Round an `f64`-stored value to the precision of its declared |
/// `FloatWidth`. Audit B-3: used as defense in depth in `FCmp` and
/// `FloatToInt` so that those folds don't silently miscompute if some
/// hypothetical future producer breaks the M-1 invariant (every
/// `Const::Float(_, F32)` value is already f32-rounded).
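///
/// Illustration (doctest `ignore`d since this is a private helper):
///
/// ```ignore
/// // 0.1 is inexact in f32, so rounding through f32 changes it...
/// assert_ne!(round_for_width(0.1, FloatWidth::F32), 0.1);
/// // ...while an f64-tagged value passes through untouched.
/// assert_eq!(round_for_width(0.1, FloatWidth::F64), 0.1);
/// ```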
| 109 | fn round_for_width(v: f64, w: FloatWidth) -> f64 { |
| 110 | match w { |
| 111 | FloatWidth::F32 => v as f32 as f64, |
| 112 | FloatWidth::F64 => v, |
| 113 | } |
| 114 | } |
| 115 | |
| 116 | /// Try to evaluate a single instruction given a map of known |
| 117 | /// constants. Returns `Some(new_kind)` if the instruction can be |
| 118 | /// replaced with a `Const*`. |
| 119 | #[allow(clippy::too_many_lines)] |
| 120 | fn try_fold(kind: &InstKind, ty: &IrType, consts: &HashMap<ValueId, Const>) -> Option<InstKind> { |
| 121 | let get = |id: &ValueId| consts.get(id).copied(); |
| 122 | |
| 123 | match kind { |
| 124 | // Integer arithmetic ---------------------------------------------- |
| 125 | InstKind::IAdd(a, b) => { |
| 126 | if let (Some(Const::Int(av, _)), Some(Const::Int(bv, _))) = (get(a), get(b)) { |
| 127 | if let IrType::Int(w) = ty { |
| 128 | return Some(InstKind::ConstInt(norm(av.wrapping_add(bv), *w), *w)); |
| 129 | } |
| 130 | } |
| 131 | None |
| 132 | } |
| 133 | InstKind::ISub(a, b) => { |
| 134 | if let (Some(Const::Int(av, _)), Some(Const::Int(bv, _))) = (get(a), get(b)) { |
| 135 | if let IrType::Int(w) = ty { |
| 136 | return Some(InstKind::ConstInt(norm(av.wrapping_sub(bv), *w), *w)); |
| 137 | } |
| 138 | } |
| 139 | None |
| 140 | } |
| 141 | InstKind::IMul(a, b) => { |
| 142 | if let (Some(Const::Int(av, _)), Some(Const::Int(bv, _))) = (get(a), get(b)) { |
| 143 | if let IrType::Int(w) = ty { |
| 144 | return Some(InstKind::ConstInt(norm(av.wrapping_mul(bv), *w), *w)); |
| 145 | } |
| 146 | } |
| 147 | None |
| 148 | } |
| 149 | InstKind::IDiv(a, b) => { |
| 150 | if let (Some(Const::Int(av, _)), Some(Const::Int(bv, _))) = (get(a), get(b)) { |
| 151 | if let IrType::Int(w) = ty { |
                    if bv == 0 {
                        // Leave divide-by-zero to the runtime.
                        return None;
                    }
| 155 | if av == signed_min(*w) && bv == -1 { |
| 156 | return None; |
| 157 | } |
| 158 | return Some(InstKind::ConstInt(norm(av / bv, *w), *w)); |
| 159 | } |
| 160 | } |
| 161 | None |
| 162 | } |
| 163 | InstKind::IMod(a, b) => { |
| 164 | if let (Some(Const::Int(av, _)), Some(Const::Int(bv, _))) = (get(a), get(b)) { |
| 165 | if let IrType::Int(w) = ty { |
| 166 | if bv == 0 { |
| 167 | return None; |
| 168 | } |
| 169 | if av == signed_min(*w) && bv == -1 { |
| 170 | return None; |
| 171 | } |
| 172 | return Some(InstKind::ConstInt(norm(av % bv, *w), *w)); |
| 173 | } |
| 174 | } |
| 175 | None |
| 176 | } |
| 177 | InstKind::INeg(a) => { |
| 178 | if let Some(Const::Int(av, _)) = get(a) { |
| 179 | if let IrType::Int(w) = ty { |
| 180 | return Some(InstKind::ConstInt(norm(av.wrapping_neg(), *w), *w)); |
| 181 | } |
| 182 | } |
| 183 | None |
| 184 | } |
| 185 | |
| 186 | // Float arithmetic ------------------------------------------------ |
| 187 | InstKind::FAdd(a, b) => fold_float_bin(get(a), get(b), ty, |x, y| x + y), |
| 188 | InstKind::FSub(a, b) => fold_float_bin(get(a), get(b), ty, |x, y| x - y), |
| 189 | InstKind::FMul(a, b) => fold_float_bin(get(a), get(b), ty, |x, y| x * y), |
| 190 | InstKind::FDiv(a, b) => fold_float_bin(get(a), get(b), ty, |x, y| x / y), |
| 191 | InstKind::FNeg(a) => fold_float_un(get(a), ty, |x| -x), |
| 192 | InstKind::FAbs(a) => fold_float_un(get(a), ty, |x| x.abs()), |
| 193 | InstKind::FSqrt(a) => fold_float_un(get(a), ty, |x| x.sqrt()), |
| 194 | InstKind::FPow(a, b) => fold_float_bin(get(a), get(b), ty, |x, y| x.powf(y)), |
| 195 | |
| 196 | // Comparisons ----------------------------------------------------- |
| 197 | InstKind::ICmp(op, a, b) => { |
| 198 | if let (Some(Const::Int(av, aw)), Some(Const::Int(bv, bw))) = (get(a), get(b)) { |
                // Audit M-D: each operand must be sign-extended at
                // its OWN declared width. The earlier version reused
                // `a`'s width for both operands, which is correct
                // only when the verifier guarantees both operands
                // have the same width — and the verifier currently
                // only checks `is_int()`, not bit width, so a future
                // GVN-style pass that introduces width drift would
                // turn this into a wrong-result hazard. Post-N-9 the
                // operands already arrive canonicalized, so the
                // re-sext below is defense in depth, like B-3.
| 207 | let av = sext(av, aw.bits()); |
| 208 | let bv = sext(bv, bw.bits()); |
| 209 | let r = match op { |
| 210 | CmpOp::Eq => av == bv, |
| 211 | CmpOp::Ne => av != bv, |
| 212 | CmpOp::Lt => av < bv, |
| 213 | CmpOp::Le => av <= bv, |
| 214 | CmpOp::Gt => av > bv, |
| 215 | CmpOp::Ge => av >= bv, |
| 216 | }; |
| 217 | return Some(InstKind::ConstBool(r)); |
| 218 | } |
| 219 | None |
| 220 | } |
| 221 | InstKind::FCmp(op, a, b) => { |
| 222 | if let (Some(Const::Float(av, aw)), Some(Const::Float(bv, bw))) = (get(a), get(b)) { |
| 223 | // Audit B-3 (defense in depth): explicitly round each |
| 224 | // operand at its declared width before comparing. |
| 225 | // Today every Const::Float producer respects the |
| 226 | // M-1 invariant that f32-tagged values are already |
| 227 | // f32-rounded — but doing the rounding here protects |
| 228 | // future folds from a bug that breaks the invariant. |
| 229 | let av = round_for_width(av, aw); |
| 230 | let bv = round_for_width(bv, bw); |
| 231 | let r = match op { |
| 232 | CmpOp::Eq => av == bv, |
| 233 | CmpOp::Ne => av != bv, |
| 234 | CmpOp::Lt => av < bv, |
| 235 | CmpOp::Le => av <= bv, |
| 236 | CmpOp::Gt => av > bv, |
| 237 | CmpOp::Ge => av >= bv, |
| 238 | }; |
| 239 | return Some(InstKind::ConstBool(r)); |
| 240 | } |
| 241 | None |
| 242 | } |
| 243 | |
| 244 | // Logic on bools -------------------------------------------------- |
| 245 | InstKind::And(a, b) => { |
| 246 | if let (Some(Const::Bool(av)), Some(Const::Bool(bv))) = (get(a), get(b)) { |
| 247 | return Some(InstKind::ConstBool(av && bv)); |
| 248 | } |
| 249 | None |
| 250 | } |
| 251 | InstKind::Or(a, b) => { |
| 252 | if let (Some(Const::Bool(av)), Some(Const::Bool(bv))) = (get(a), get(b)) { |
| 253 | return Some(InstKind::ConstBool(av || bv)); |
| 254 | } |
| 255 | None |
| 256 | } |
| 257 | InstKind::Not(a) => { |
| 258 | if let Some(Const::Bool(av)) = get(a) { |
| 259 | return Some(InstKind::ConstBool(!av)); |
| 260 | } |
| 261 | None |
| 262 | } |
| 263 | |
| 264 | // Bitwise --------------------------------------------------------- |
| 265 | InstKind::BitAnd(a, b) => fold_int_bin(get(a), get(b), ty, |x, y| x & y), |
| 266 | InstKind::BitOr(a, b) => fold_int_bin(get(a), get(b), ty, |x, y| x | y), |
| 267 | InstKind::BitXor(a, b) => fold_int_bin(get(a), get(b), ty, |x, y| x ^ y), |
| 268 | InstKind::BitNot(a) => { |
| 269 | if let Some(Const::Int(av, _)) = get(a) { |
| 270 | if let IrType::Int(w) = ty { |
| 271 | return Some(InstKind::ConstInt(norm(!av, *w), *w)); |
| 272 | } |
| 273 | } |
| 274 | None |
| 275 | } |
| 276 | // Audit M-B / M-C: shift counts that are negative or ≥ the |
| 277 | // operand width are *implementation-defined* in the IR. We |
| 278 | // deliberately bail (return None) so the runtime / codegen |
| 279 | // path is the single source of truth. AArch64 LSL/LSR/ASR |
| 280 | // mask the count to the low log2(width) bits — emulating |
| 281 | // that here is risky because Fortran ISHFT lowering computes |
| 282 | // both branches and selects, and a bogus folded value could |
| 283 | // leak through CSE / strength_reduce / other passes. |
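        //
        // Concretely: `shl(%x:i32, 35)` and `shl(%x:i32, -1)` stay
        // unfolded; only counts in `0..width` fold below.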
| 284 | InstKind::Shl(a, b) => { |
| 285 | if let (Some(Const::Int(av, _)), Some(Const::Int(bv, _))) = (get(a), get(b)) { |
| 286 | if let IrType::Int(w) = ty { |
| 287 | let bits = w.bits() as i128; |
| 288 | if (0..bits).contains(&bv) { |
| 289 | return Some(InstKind::ConstInt(norm(av.wrapping_shl(bv as u32), *w), *w)); |
| 290 | } |
| 291 | return None; |
| 292 | } |
| 293 | } |
| 294 | None |
| 295 | } |
| 296 | InstKind::LShr(a, b) => { |
| 297 | if let (Some(Const::Int(av, _)), Some(Const::Int(bv, _))) = (get(a), get(b)) { |
| 298 | if let IrType::Int(w) = ty { |
| 299 | let bits = w.bits() as i128; |
| 300 | if (0..bits).contains(&bv) { |
| 301 | let unsigned = (mask(av, w.bits()) as u128) >> (bv as u32); |
| 302 | return Some(InstKind::ConstInt(norm(unsigned as i128, *w), *w)); |
| 303 | } |
| 304 | return None; |
| 305 | } |
| 306 | } |
| 307 | None |
| 308 | } |
| 309 | InstKind::AShr(a, b) => { |
| 310 | if let (Some(Const::Int(av, _)), Some(Const::Int(bv, _))) = (get(a), get(b)) { |
| 311 | if let IrType::Int(w) = ty { |
| 312 | let bits = w.bits() as i128; |
| 313 | if (0..bits).contains(&bv) { |
| 314 | let signed = sext(av, w.bits()) >> (bv as u32); |
| 315 | return Some(InstKind::ConstInt(norm(signed, *w), *w)); |
| 316 | } |
| 317 | return None; |
| 318 | } |
| 319 | } |
| 320 | None |
| 321 | } |
        // Audit Med-3: the bit-count folds below mask the input at
        // the *source* operand's width but tag the result with the
        // width from `inst.ty` (the IR declaration), not the
        // source's. Today the two always match because the lowerer
        // keeps them in sync, but a future change that makes
        // popcount return i32 unconditionally (matching the C
        // intrinsics) would silently produce a wrong-width constant
        // if we kept reading the source width here.
| 329 | InstKind::PopCount(a) => { |
| 330 | if let (Some(Const::Int(av, src_w)), IrType::Int(out_w)) = (get(a), ty) { |
| 331 | let val = mask(av, src_w.bits()) as u128; |
| 332 | return Some(InstKind::ConstInt( |
| 333 | norm(val.count_ones() as i128, *out_w), |
| 334 | *out_w, |
| 335 | )); |
| 336 | } |
| 337 | None |
| 338 | } |
| 339 | InstKind::CountLeadingZeros(a) => { |
| 340 | if let (Some(Const::Int(av, src_w)), IrType::Int(out_w)) = (get(a), ty) { |
| 341 | let bits = src_w.bits(); |
| 342 | let val = mask(av, bits) as u128; |
| 343 | let lz = if val == 0 { |
| 344 | bits as i128 |
| 345 | } else { |
| 346 | (val.leading_zeros() as i128) - (128 - bits as i128) |
| 347 | }; |
| 348 | return Some(InstKind::ConstInt(norm(lz, *out_w), *out_w)); |
| 349 | } |
| 350 | None |
| 351 | } |
| 352 | InstKind::CountTrailingZeros(a) => { |
| 353 | if let (Some(Const::Int(av, src_w)), IrType::Int(out_w)) = (get(a), ty) { |
| 354 | let bits = src_w.bits(); |
| 355 | let val = mask(av, bits) as u128; |
| 356 | let tz = if val == 0 { |
| 357 | bits as i128 |
| 358 | } else { |
| 359 | val.trailing_zeros() as i128 |
| 360 | }; |
| 361 | return Some(InstKind::ConstInt(norm(tz, *out_w), *out_w)); |
| 362 | } |
| 363 | None |
| 364 | } |
| 365 | |
| 366 | // Conversions ----------------------------------------------------- |
| 367 | InstKind::IntToFloat(a, fw) => { |
| 368 | if let Some(Const::Int(av, w)) = get(a) { |
| 369 | let signed = sext(av, w.bits()); |
                // Convert at the destination precision so that
                // downstream folds see the exact value the runtime
                // would observe. Rust's integer-to-f32 `as` cast is
                // a single, correctly rounded conversion (the same
                // result AArch64 `SCVTF` produces for an f32
                // destination under the default rounding mode), and
                // the widening back to `f64` for IR storage is
                // exact, so no double rounding is introduced here.
| 382 | let v = match fw { |
| 383 | FloatWidth::F32 => signed as f32 as f64, |
| 384 | FloatWidth::F64 => signed as f64, |
| 385 | }; |
| 386 | return Some(InstKind::ConstFloat(v, *fw)); |
| 387 | } |
| 388 | None |
| 389 | } |
| 390 | InstKind::FloatToInt(a, w) => { |
| 391 | if let Some(Const::Float(av, src_fw)) = get(a) { |
| 392 | if av.is_nan() || av.is_infinite() { |
| 393 | return None; |
| 394 | } |
| 395 | // Audit B-3 (defense in depth): re-round at the |
| 396 | // source float width before truncating. The M-1 |
| 397 | // invariant says any f32-tagged Const::Float is |
| 398 | // already f32-rounded, but doing it explicitly here |
| 399 | // means a future invariant break can't silently |
| 400 | // produce a wrong int. |
| 401 | let av = round_for_width(av, src_fw); |
| 402 | // Fortran INT() truncates toward zero. Match. |
| 403 | let truncd = av.trunc(); |
                // Range-check before casting. We compare against the
                // *exclusive* bound 2^(bits-1): unlike e.g.
                // `i64::MAX`, that power of two is exactly
                // representable in f64. Comparing `truncd` against
                // `i64::MAX as f64` (which rounds up to 2^63) would
                // let an out-of-range 2^63 slip through and fold to
                // `i64::MIN`.
                let bound = match w {
                    IntWidth::I8 => 128.0,
                    IntWidth::I16 => 32768.0,
                    IntWidth::I32 => 2_147_483_648.0,
                    IntWidth::I64 => 9_223_372_036_854_775_808.0,
                    IntWidth::I128 => return None, // stay conservative
                };
                if truncd < -bound || truncd >= bound {
                    return None;
                }
                return Some(InstKind::ConstInt(norm(truncd as i128, *w), *w));
| 422 | } |
| 423 | None |
| 424 | } |
| 425 | InstKind::FloatExtend(a, fw) | InstKind::FloatTrunc(a, fw) => { |
| 426 | if let Some(Const::Float(av, _)) = get(a) { |
| 427 | let v = match fw { |
| 428 | FloatWidth::F32 => av as f32 as f64, // round through f32 |
| 429 | FloatWidth::F64 => av, |
| 430 | }; |
| 431 | return Some(InstKind::ConstFloat(v, *fw)); |
| 432 | } |
| 433 | None |
| 434 | } |
| 435 | InstKind::IntExtend(a, w, signed) => { |
| 436 | if let Some(Const::Int(av, src_w)) = get(a) { |
| 437 | let v = if *signed { |
| 438 | sext(av, src_w.bits()) |
| 439 | } else { |
| 440 | mask(av, src_w.bits()) |
| 441 | }; |
| 442 | return Some(InstKind::ConstInt(norm(v, *w), *w)); |
| 443 | } |
| 444 | None |
| 445 | } |
| 446 | InstKind::IntTrunc(a, w) => { |
| 447 | if let Some(Const::Int(av, _)) = get(a) { |
| 448 | return Some(InstKind::ConstInt(norm(av, *w), *w)); |
| 449 | } |
| 450 | None |
| 451 | } |
| 452 | |
| 453 | // Select on a known condition ------------------------------------ |
| 454 | InstKind::Select(c, t, f) => { |
| 455 | if let Some(Const::Bool(cv)) = get(c) { |
| 456 | let chosen = if cv { *t } else { *f }; |
| 457 | if let Some(k) = consts.get(&chosen) { |
| 458 | // Re-emit using the **destination** type carried |
| 459 | // by the Select instruction itself, NOT the |
| 460 | // chosen branch's source type. After IntExtend / |
| 461 | // FloatTrunc / similar fold chains the chosen |
| 462 | // operand can carry a narrower width than the |
| 463 | // Select declares; reusing it would silently |
| 464 | // produce a `kind` whose embedded width disagrees |
| 465 | // with `inst.ty`. Audit C-C. |
| 466 | return match (ty, *k) { |
| 467 | (IrType::Int(w), Const::Int(v, src_w)) => { |
| 468 | // Audit B-1: sign-extend the chosen |
| 469 | // value at its **source** width before |
| 470 | // re-norming at the destination. The |
| 471 | // earlier version discarded `src_w` and |
| 472 | // treated the raw stored bits as already |
| 473 | // sign-extended at the destination width, |
| 474 | // silently producing a wrong-signed value |
| 475 | // when widths differed. Example: |
| 476 | // chosen = ConstInt(255, I8) (i.e. -1 |
| 477 | // at i8 precision); destination = I64. |
| 478 | // Without source-width sext, we'd emit |
| 479 | // ConstInt(255, I64) instead of |
| 480 | // ConstInt(-1, I64). |
| 481 | let signed = sext(v, src_w.bits()); |
| 482 | Some(InstKind::ConstInt(norm(signed, *w), *w)) |
| 483 | } |
| 484 | (IrType::Float(FloatWidth::F32), Const::Float(v, _)) => { |
| 485 | Some(InstKind::ConstFloat(v as f32 as f64, FloatWidth::F32)) |
| 486 | } |
| 487 | (IrType::Float(FloatWidth::F64), Const::Float(v, _)) => { |
| 488 | Some(InstKind::ConstFloat(v, FloatWidth::F64)) |
| 489 | } |
| 490 | (IrType::Bool, Const::Bool(b)) => Some(InstKind::ConstBool(b)), |
| 491 | // Type categories disagree — bail rather than |
| 492 | // type-pun. The verifier doesn't catch the |
| 493 | // mismatch (it only checks `is_int`/`is_float` |
| 494 | // not specific widths), so being conservative |
| 495 | // here is the correctness floor. |
| 496 | _ => None, |
| 497 | }; |
| 498 | } |
| 499 | } |
| 500 | None |
| 501 | } |
| 502 | |
| 503 | _ => None, |
| 504 | } |
| 505 | } |
| 506 | |
| 507 | // fold_float_bin / fold_float_un rely on the N-9 invariant: |
| 508 | // `Const::from_inst` normalizes every incoming `Const::Float` via |
| 509 | // `round_for_width`, so `av` / `bv` below are already exact at their |
| 510 | // declared precision. This lets us operate on the raw f64 storage |
| 511 | // without re-rounding per input — the N-3 defense-in-depth concern |
| 512 | // raised in the third audit is satisfied at the producer side rather |
| 513 | // than repeated at every use site. We DO still round the result for |
| 514 | // f32 destinations because the op itself can produce a value |
| 515 | // outside f32's representable range. |
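//
// Concretely: folding `fmul` of two f32 constants near 1e30 computes
// roughly 1e60 in the f64 working value; `round_for_width` then
// collapses it to `f32::INFINITY as f64`, matching runtime f32 math.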
| 516 | fn fold_float_bin<F>(a: Option<Const>, b: Option<Const>, ty: &IrType, op: F) -> Option<InstKind> |
| 517 | where |
| 518 | F: FnOnce(f64, f64) -> f64, |
| 519 | { |
| 520 | if let (Some(Const::Float(av, _)), Some(Const::Float(bv, _))) = (a, b) { |
| 521 | if let IrType::Float(w) = ty { |
| 522 | let r = op(av, bv); |
| 523 | let r = round_for_width(r, *w); |
| 524 | return Some(InstKind::ConstFloat(r, *w)); |
| 525 | } |
| 526 | } |
| 527 | None |
| 528 | } |
| 529 | |
| 530 | fn fold_float_un<F>(a: Option<Const>, ty: &IrType, op: F) -> Option<InstKind> |
| 531 | where |
| 532 | F: FnOnce(f64) -> f64, |
| 533 | { |
| 534 | if let Some(Const::Float(av, _)) = a { |
| 535 | if let IrType::Float(w) = ty { |
| 536 | let r = op(av); |
| 537 | let r = round_for_width(r, *w); |
| 538 | return Some(InstKind::ConstFloat(r, *w)); |
| 539 | } |
| 540 | } |
| 541 | None |
| 542 | } |
| 543 | |
| 544 | fn fold_int_bin<F>(a: Option<Const>, b: Option<Const>, ty: &IrType, op: F) -> Option<InstKind> |
| 545 | where |
| 546 | F: FnOnce(i128, i128) -> i128, |
| 547 | { |
| 548 | if let (Some(Const::Int(av, _)), Some(Const::Int(bv, _))) = (a, b) { |
| 549 | if let IrType::Int(w) = ty { |
| 550 | return Some(InstKind::ConstInt(norm(op(av, bv), *w), *w)); |
| 551 | } |
| 552 | } |
| 553 | None |
| 554 | } |
| 555 | |
| 556 | /// The constant folding pass. |
| 557 | pub struct ConstFold; |
| 558 | |
| 559 | impl Pass for ConstFold { |
| 560 | fn name(&self) -> &'static str { |
| 561 | "const-fold" |
| 562 | } |
| 563 | |
| 564 | fn run(&self, module: &mut Module) -> bool { |
| 565 | let mut changed = false; |
| 566 | for func in &mut module.functions { |
| 567 | // Audit N-8: we walk `func.blocks` in vec order, which |
| 568 | // is NOT guaranteed to be reverse-postorder. If a fold |
| 569 | // in an early-vec block needs a constant defined by a |
| 570 | // later-vec block (that still dominates it in the CFG), |
| 571 | // the earlier version would miss the fold — and the |
| 572 | // outer pass-manager fixpoint would NOT rescue us when |
| 573 | // no other pass in the pipeline produced a change. |
| 574 | // |
| 575 | // Fix: two phases. |
| 576 | // |
| 577 | // 1. Pre-collect every `Const*` instruction's value so |
| 578 | // the map is fully populated before any fold runs. |
| 579 | // 2. Iterate the fold walk to a local fixpoint. A fold |
| 580 | // that creates a new constant is added to the map |
| 581 | // on the fly, unlocking downstream folds in the |
| 582 | // next iteration. Each iteration is monotonic (the |
| 583 | // consts map only grows), so the inner loop is |
| 584 | // bounded by O(number of foldable instructions). |
| 585 | let mut consts: HashMap<ValueId, Const> = HashMap::new(); |
| 586 | for block in &func.blocks { |
| 587 | for inst in &block.insts { |
| 588 | if let Some(c) = Const::from_inst(&inst.kind) { |
| 589 | consts.insert(inst.id, c); |
| 590 | } |
| 591 | } |
| 592 | } |
| 593 | |
| 594 | let mut inner_changed = true; |
| 595 | while inner_changed { |
| 596 | inner_changed = false; |
| 597 | for block in &mut func.blocks { |
| 598 | for inst in &mut block.insts { |
| 599 | if consts.contains_key(&inst.id) { |
| 600 | continue; |
| 601 | } |
| 602 | if let Some(new_kind) = try_fold(&inst.kind, &inst.ty, &consts) { |
| 603 | if let Some(c) = Const::from_inst(&new_kind) { |
| 604 | consts.insert(inst.id, c); |
| 605 | } |
| 606 | inst.kind = new_kind; |
| 607 | changed = true; |
| 608 | inner_changed = true; |
| 609 | } |
| 610 | } |
| 611 | } |
| 612 | } |
| 613 | } |
| 614 | changed |
| 615 | } |
| 616 | } |
| 617 | |
| 618 | #[cfg(test)] |
| 619 | mod tests { |
| 620 | use super::*; |
| 621 | use crate::ir::types::IrType; |
| 622 | |
| 623 | fn dummy_span() -> crate::lexer::Span { |
| 624 | let p = crate::lexer::Position { line: 1, col: 1 }; |
| 625 | crate::lexer::Span { |
| 626 | start: p, |
| 627 | end: p, |
| 628 | file_id: 0, |
| 629 | } |
| 630 | } |
| 631 | |
| 632 | fn make_module_with(insts: Vec<(InstKind, IrType)>) -> (Module, Vec<ValueId>) { |
| 633 | let mut m = Module::new("t".into()); |
| 634 | let mut f = Function::new("f".into(), vec![], IrType::Void); |
| 635 | let entry = f.entry; |
| 636 | let mut ids = Vec::new(); |
| 637 | for (kind, ty) in insts { |
| 638 | let id = f.next_value_id(); |
| 639 | ids.push(id); |
| 640 | f.block_mut(entry).insts.push(Inst { |
| 641 | id, |
| 642 | kind, |
| 643 | ty, |
| 644 | span: dummy_span(), |
| 645 | }); |
| 646 | } |
| 647 | f.block_mut(entry).terminator = Some(Terminator::Return(None)); |
| 648 | m.add_function(f); |
| 649 | (m, ids) |
| 650 | } |
| 651 | |
| 652 | fn first_block_kinds(m: &Module) -> Vec<InstKind> { |
| 653 | m.functions[0].blocks[0] |
| 654 | .insts |
| 655 | .iter() |
| 656 | .map(|i| i.kind.clone()) |
| 657 | .collect() |
| 658 | } |
| 659 | |
| 660 | #[test] |
| 661 | fn folds_iadd_i32() { |
| 662 | let (mut m, ids) = make_module_with(vec![ |
| 663 | ( |
| 664 | InstKind::ConstInt(3, IntWidth::I32), |
| 665 | IrType::Int(IntWidth::I32), |
| 666 | ), |
| 667 | ( |
| 668 | InstKind::ConstInt(4, IntWidth::I32), |
| 669 | IrType::Int(IntWidth::I32), |
| 670 | ), |
| 671 | ( |
| 672 | InstKind::IAdd(ValueId(0), ValueId(1)), |
| 673 | IrType::Int(IntWidth::I32), |
| 674 | ), |
| 675 | ]); |
| 676 | let _ = ids; |
| 677 | assert!(ConstFold.run(&mut m)); |
| 678 | let kinds = first_block_kinds(&m); |
| 679 | assert!(matches!(kinds[2], InstKind::ConstInt(7, IntWidth::I32))); |
| 680 | } |
| 681 | |
| 682 | #[test] |
| 683 | fn integer_overflow_wraps_at_width() { |
| 684 | // i8: 100 + 50 = 150 → wraps to -106 |
| 685 | let (mut m, _) = make_module_with(vec![ |
| 686 | ( |
| 687 | InstKind::ConstInt(100, IntWidth::I8), |
| 688 | IrType::Int(IntWidth::I8), |
| 689 | ), |
| 690 | ( |
| 691 | InstKind::ConstInt(50, IntWidth::I8), |
| 692 | IrType::Int(IntWidth::I8), |
| 693 | ), |
| 694 | ( |
| 695 | InstKind::IAdd(ValueId(0), ValueId(1)), |
| 696 | IrType::Int(IntWidth::I8), |
| 697 | ), |
| 698 | ]); |
| 699 | assert!(ConstFold.run(&mut m)); |
| 700 | let kinds = first_block_kinds(&m); |
| 701 | match kinds[2] { |
| 702 | InstKind::ConstInt(v, IntWidth::I8) => assert_eq!(v, -106), |
| 703 | _ => panic!("expected ConstInt"), |
| 704 | } |
| 705 | } |
| 706 | |
| 707 | #[test] |
| 708 | fn idiv_by_zero_left_alone() { |
| 709 | let (mut m, _) = make_module_with(vec![ |
| 710 | ( |
| 711 | InstKind::ConstInt(10, IntWidth::I32), |
| 712 | IrType::Int(IntWidth::I32), |
| 713 | ), |
| 714 | ( |
| 715 | InstKind::ConstInt(0, IntWidth::I32), |
| 716 | IrType::Int(IntWidth::I32), |
| 717 | ), |
| 718 | ( |
| 719 | InstKind::IDiv(ValueId(0), ValueId(1)), |
| 720 | IrType::Int(IntWidth::I32), |
| 721 | ), |
| 722 | ]); |
| 723 | assert!(!ConstFold.run(&mut m)); |
| 724 | let kinds = first_block_kinds(&m); |
| 725 | assert!(matches!(kinds[2], InstKind::IDiv(..))); |
| 726 | } |
| 727 | |
| 728 | #[test] |
| 729 | fn fmul_f64() { |
| 730 | let (mut m, _) = make_module_with(vec![ |
| 731 | ( |
| 732 | InstKind::ConstFloat(2.0, FloatWidth::F64), |
| 733 | IrType::Float(FloatWidth::F64), |
| 734 | ), |
| 735 | ( |
| 736 | InstKind::ConstFloat(3.14, FloatWidth::F64), |
| 737 | IrType::Float(FloatWidth::F64), |
| 738 | ), |
| 739 | ( |
| 740 | InstKind::FMul(ValueId(0), ValueId(1)), |
| 741 | IrType::Float(FloatWidth::F64), |
| 742 | ), |
| 743 | ]); |
| 744 | assert!(ConstFold.run(&mut m)); |
| 745 | let kinds = first_block_kinds(&m); |
| 746 | match kinds[2] { |
| 747 | InstKind::ConstFloat(v, FloatWidth::F64) => assert!((v - 6.28).abs() < 1e-12), |
| 748 | _ => panic!("expected ConstFloat"), |
| 749 | } |
| 750 | } |
| 751 | |
| 752 | #[test] |
| 753 | fn icmp_eq_true() { |
| 754 | let (mut m, _) = make_module_with(vec![ |
| 755 | ( |
| 756 | InstKind::ConstInt(5, IntWidth::I32), |
| 757 | IrType::Int(IntWidth::I32), |
| 758 | ), |
| 759 | ( |
| 760 | InstKind::ConstInt(5, IntWidth::I32), |
| 761 | IrType::Int(IntWidth::I32), |
| 762 | ), |
| 763 | ( |
| 764 | InstKind::ICmp(CmpOp::Eq, ValueId(0), ValueId(1)), |
| 765 | IrType::Bool, |
| 766 | ), |
| 767 | ]); |
| 768 | assert!(ConstFold.run(&mut m)); |
| 769 | assert!(matches!( |
| 770 | first_block_kinds(&m)[2], |
| 771 | InstKind::ConstBool(true) |
| 772 | )); |
| 773 | } |
| 774 | |
| 775 | #[test] |
| 776 | fn icmp_signed_lt_with_i8_negatives() { |
        // At i8, -1's raw bit pattern is 0xff (255). Without
        // sign-extension, a naive compare of raw masked bits would
        // evaluate 255 < 1 → false; the correct signed answer is true.
| 779 | let (mut m, _) = make_module_with(vec![ |
| 780 | ( |
| 781 | InstKind::ConstInt(-1, IntWidth::I8), |
| 782 | IrType::Int(IntWidth::I8), |
| 783 | ), |
| 784 | ( |
| 785 | InstKind::ConstInt(1, IntWidth::I8), |
| 786 | IrType::Int(IntWidth::I8), |
| 787 | ), |
| 788 | ( |
| 789 | InstKind::ICmp(CmpOp::Lt, ValueId(0), ValueId(1)), |
| 790 | IrType::Bool, |
| 791 | ), |
| 792 | ]); |
| 793 | assert!(ConstFold.run(&mut m)); |
| 794 | assert!(matches!( |
| 795 | first_block_kinds(&m)[2], |
| 796 | InstKind::ConstBool(true) |
| 797 | )); |
| 798 | } |
| 799 | |
| 800 | #[test] |
| 801 | fn shl_power_of_two() { |
| 802 | let (mut m, _) = make_module_with(vec![ |
| 803 | ( |
| 804 | InstKind::ConstInt(1, IntWidth::I32), |
| 805 | IrType::Int(IntWidth::I32), |
| 806 | ), |
| 807 | ( |
| 808 | InstKind::ConstInt(3, IntWidth::I32), |
| 809 | IrType::Int(IntWidth::I32), |
| 810 | ), |
| 811 | ( |
| 812 | InstKind::Shl(ValueId(0), ValueId(1)), |
| 813 | IrType::Int(IntWidth::I32), |
| 814 | ), |
| 815 | ]); |
| 816 | assert!(ConstFold.run(&mut m)); |
| 817 | match first_block_kinds(&m)[2] { |
| 818 | InstKind::ConstInt(v, IntWidth::I32) => assert_eq!(v, 8), |
| 819 | _ => panic!(), |
| 820 | } |
| 821 | } |
| 822 | |
| 823 | #[test] |
| 824 | fn int_to_float_chain() { |
| 825 | // const(42:i32) → int_to_float → const(42.0:f64) |
| 826 | let (mut m, _) = make_module_with(vec![ |
| 827 | ( |
| 828 | InstKind::ConstInt(42, IntWidth::I32), |
| 829 | IrType::Int(IntWidth::I32), |
| 830 | ), |
| 831 | ( |
| 832 | InstKind::IntToFloat(ValueId(0), FloatWidth::F64), |
| 833 | IrType::Float(FloatWidth::F64), |
| 834 | ), |
| 835 | ]); |
| 836 | assert!(ConstFold.run(&mut m)); |
| 837 | match first_block_kinds(&m)[1] { |
| 838 | InstKind::ConstFloat(v, FloatWidth::F64) => assert_eq!(v, 42.0), |
| 839 | _ => panic!(), |
| 840 | } |
| 841 | } |
| 842 | |
| 843 | #[test] |
| 844 | fn select_on_known_cond_picks_branch() { |
| 845 | let (mut m, _) = make_module_with(vec![ |
| 846 | (InstKind::ConstBool(true), IrType::Bool), |
| 847 | ( |
| 848 | InstKind::ConstInt(10, IntWidth::I32), |
| 849 | IrType::Int(IntWidth::I32), |
| 850 | ), |
| 851 | ( |
| 852 | InstKind::ConstInt(20, IntWidth::I32), |
| 853 | IrType::Int(IntWidth::I32), |
| 854 | ), |
| 855 | ( |
| 856 | InstKind::Select(ValueId(0), ValueId(1), ValueId(2)), |
| 857 | IrType::Int(IntWidth::I32), |
| 858 | ), |
| 859 | ]); |
| 860 | assert!(ConstFold.run(&mut m)); |
| 861 | match first_block_kinds(&m)[3] { |
| 862 | InstKind::ConstInt(v, IntWidth::I32) => assert_eq!(v, 10), |
| 863 | _ => panic!(), |
| 864 | } |
| 865 | } |
| 866 | |
| 867 | #[test] |
| 868 | fn fixpoint_via_chained_constants() { |
| 869 | // (3 + 4) * 2 = 14, where the inner add must fold first. |
| 870 | let (mut m, _) = make_module_with(vec![ |
| 871 | ( |
| 872 | InstKind::ConstInt(3, IntWidth::I32), |
| 873 | IrType::Int(IntWidth::I32), |
| 874 | ), |
| 875 | ( |
| 876 | InstKind::ConstInt(4, IntWidth::I32), |
| 877 | IrType::Int(IntWidth::I32), |
| 878 | ), |
| 879 | ( |
| 880 | InstKind::IAdd(ValueId(0), ValueId(1)), |
| 881 | IrType::Int(IntWidth::I32), |
| 882 | ), |
| 883 | ( |
| 884 | InstKind::ConstInt(2, IntWidth::I32), |
| 885 | IrType::Int(IntWidth::I32), |
| 886 | ), |
| 887 | ( |
| 888 | InstKind::IMul(ValueId(2), ValueId(3)), |
| 889 | IrType::Int(IntWidth::I32), |
| 890 | ), |
| 891 | ]); |
| 892 | // A single pass already handles this because we walk in order. |
| 893 | assert!(ConstFold.run(&mut m)); |
| 894 | let kinds = first_block_kinds(&m); |
| 895 | assert!(matches!(kinds[2], InstKind::ConstInt(7, IntWidth::I32))); |
| 896 | assert!(matches!(kinds[4], InstKind::ConstInt(14, IntWidth::I32))); |
| 897 | } |
| 898 | |
| 899 | #[test] |
| 900 | fn nan_float_to_int_left_alone() { |
| 901 | let (mut m, _) = make_module_with(vec![ |
| 902 | ( |
| 903 | InstKind::ConstFloat(f64::NAN, FloatWidth::F64), |
| 904 | IrType::Float(FloatWidth::F64), |
| 905 | ), |
| 906 | ( |
| 907 | InstKind::FloatToInt(ValueId(0), IntWidth::I32), |
| 908 | IrType::Int(IntWidth::I32), |
| 909 | ), |
| 910 | ]); |
| 911 | assert!(!ConstFold.run(&mut m)); |
| 912 | assert!(matches!(first_block_kinds(&m)[1], InstKind::FloatToInt(..))); |
| 913 | } |
| 914 | } |
| 915 |