| 1 | """Tolerance-aware JSON comparator for cross-platform golden tests (S18). |
| 2 | |
| 3 | A plain ``actual == expected`` on the suite's JSON payload fails across |
| 4 | platforms for reasons that aren't real drift — BLAS implementation |
| 5 | differences produce last-ULP noise on ``raw`` values, wall-time and |
| 6 | timestamps are by definition non-deterministic, and ``sway_version`` |
| 7 | bumps every release. |
| 8 | |
| 9 | This module encodes the comparison rules the golden test actually |
| 10 | wants: |
| 11 | |
- **Variable fields are masked** before comparison. Timestamps,
  per-probe ``duration_s``, suite ``wall_seconds``, the running
  ``sway_version``, per-run ``backend_stats``, and the cwd-dependent
  ``adapter_id`` / ``base_model_id`` identifiers are all stripped —
  these are not load-bearing on the determinism claim.
- **Numeric drift is bounded by tolerance.** ``logprob_tol`` (default
  1e-4) covers raw metrics; ``score_tol`` (default 1e-4) covers
  composite ``overall`` scores and probe ``score`` fields. Differences
  beyond those tolerances surface as explicit drift reports with path,
  values, and delta.
| 21 | |
| 22 | No torch / HF dependency — the module is usable from the fast lane |
| 23 | for the comparator's own unit tests. |
| 24 | """ |
| 25 | |
| 26 | from __future__ import annotations |
| 27 | |
| 28 | import copy |
| 29 | import math |
| 30 | from dataclasses import dataclass |
| 31 | from typing import Any |
| 32 | |
#: Default field names stripped before comparison. Each entry is a
#: plain dict key name (``sway_version``), not a dotted path: a key is
#: dropped wherever it appears in the nested structure — so
#: ``duration_s`` masks both ``probes[i].duration_s`` and any future
#: top-level ``duration_s``.
DEFAULT_VARIABLE_FIELDS: frozenset[str] = frozenset(
    (
        # Run timestamps and wall-clock durations — non-deterministic
        # by definition.
        "started_at",
        "finished_at",
        "wall_seconds",
        "duration_s",
        # Bumps every release.
        "sway_version",
        # ``backend_stats`` records per-run wall times + cache counters
        # that vary with load and cold/warm cache — not part of the
        # determinism contract.
        "backend_stats",
        # ``adapter_id`` and ``base_model_id`` are absolute-path
        # identifiers the spec loader resolves against cwd. Different
        # cwds on different platforms (``/Users/.../`` on darwin vs
        # ``/home/runner/...`` on ubuntu) surface as drift without
        # any real numeric change. The numeric fields (``raw``,
        # ``score``, etc.) are what the determinism contract covers.
        "adapter_id",
        "base_model_id",
    )
)
| 59 | |
#: Field names whose numeric values are derived score metrics — the
#: composite ``overall`` plus each probe's ``score``. A number whose
#: last path segment is one of these is checked against the looser
#: ``score_tol`` instead of ``logprob_tol``.
_SCORE_FIELD_NAMES: frozenset[str] = frozenset(("overall", "score"))
| 63 | |
| 64 | |
| 65 | @dataclass(frozen=True, slots=True) |
| 66 | class Diff: |
| 67 | """One tolerance-exceeding drift between actual and expected.""" |
| 68 | |
| 69 | path: str |
| 70 | actual: Any |
| 71 | expected: Any |
| 72 | reason: str |
| 73 | |
| 74 | def __str__(self) -> str: |
| 75 | return f"{self.path}: {self.reason} (actual={self.actual!r}, expected={self.expected!r})" |
| 76 | |
| 77 | |
def mask_variable_fields(payload: Any, *, fields: frozenset[str] = DEFAULT_VARIABLE_FIELDS) -> Any:
    """Return a copy of ``payload`` with every dict key named in ``fields`` removed.

    Masking is recursive: a matching key is dropped at any nesting
    depth inside dicts and lists. Lists keep their element order. Any
    other value is returned via ``copy.copy``. The input is never
    mutated — new dicts and lists are built on the way back up.
    """
    if isinstance(payload, list):
        return [mask_variable_fields(element, fields=fields) for element in payload]
    if isinstance(payload, dict):
        return {
            key: mask_variable_fields(value, fields=fields)
            for key, value in payload.items()
            if key not in fields
        }
    # Leaf value: nothing nested to mask.
    return copy.copy(payload)
| 91 | |
| 92 | |
def compare_goldens(
    actual: Any,
    expected: Any,
    *,
    logprob_tol: float = 1e-4,
    score_tol: float = 1e-4,
) -> list[Diff]:
    """Compare two masked JSON payloads and return every tolerance-exceeding diff.

    An empty list means the payloads agree within tolerance; otherwise
    each :class:`Diff` names one drifted field. Both payloads are
    expected to have already been passed through
    :func:`mask_variable_fields`.

    The structures are walked in parallel from the root (``$``). Numeric
    leaves use ``score_tol`` when their last path segment is a score-like
    field name and ``logprob_tol`` everywhere else. Missing or unexpected
    keys, list length mismatches, and type changes are structural diffs
    and are reported regardless of tolerance.

    **Tolerance rationale.** S18's first CI observation showed
    intra-platform BLAS drift in the 1e-5–1e-6 band on ubuntu-latest
    runners (heterogeneous Intel/AMD hardware + variable OpenBLAS
    builds), which put 1e-6 below the natural noise floor. A real
    algorithm change — e.g. flipping ``top_k=256`` → 128 in
    :mod:`delta_kl` — shifts probe raws by 1e-2 to 1e-1, two to three
    orders of magnitude above the current ``1e-4`` tolerance. The value
    should be revisited as more CI history accumulates; see the sprint's
    risks section for the "too tight vs too loose" tradeoff.
    """
    drift: list[Diff] = []
    _walk(
        actual,
        expected,
        path="$",
        diffs=drift,
        logprob_tol=logprob_tol,
        score_tol=score_tol,
    )
    return drift
| 123 | |
| 124 | |
def _walk(
    actual: Any,
    expected: Any,
    *,
    path: str,
    diffs: list[Diff],
    logprob_tol: float,
    score_tol: float,
) -> None:
    """Recursively compare ``actual`` against ``expected``, appending drift to ``diffs``.

    ``path`` locates the current node: ``$`` at the root, extended with
    ``.key`` for dict members and ``[i]`` for list items.

    Rules, checked in order:

    - Type mismatch → one structural diff, no recursion. ``int`` and
      ``float`` are interchangeable (expected ``0`` vs actual ``0.0`` is
      not drift), but ``bool`` is never treated as a number: ``True``
      vs ``1`` is a type mismatch.
    - Dicts: unexpected and missing keys are reported (sorted), and
      shared keys are recursed into.
    - Lists: a length mismatch is reported once; otherwise items are
      compared pairwise.
    - Numbers: ``score_tol`` applies when the last path segment is a
      score-like name, ``logprob_tol`` otherwise.
    - Everything else (str, None, bool): exact equality.
    """
    # ``bool`` subclasses ``int``. Without this exclusion, flags would go
    # through the tolerance comparison: True vs 1.0 would pass silently,
    # and True vs False would report a |Δ| instead of a value mismatch.
    actual_is_num = isinstance(actual, int | float) and not isinstance(actual, bool)
    expected_is_num = isinstance(expected, int | float) and not isinstance(expected, bool)

    # Type mismatch — catches e.g. scalar-to-dict transitions between
    # schema versions. int and float are comparable to each other.
    if type(actual) is not type(expected) and not (actual_is_num and expected_is_num):
        diffs.append(Diff(path=path, actual=actual, expected=expected, reason="type mismatch"))
        return

    if isinstance(expected, dict):
        assert isinstance(actual, dict)
        extra_actual = set(actual) - set(expected)
        missing_actual = set(expected) - set(actual)
        for key in sorted(extra_actual):
            diffs.append(
                Diff(
                    path=f"{path}.{key}",
                    actual=actual[key],
                    expected=None,
                    reason="unexpected key in actual",
                )
            )
        for key in sorted(missing_actual):
            diffs.append(
                Diff(
                    path=f"{path}.{key}",
                    actual=None,
                    expected=expected[key],
                    reason="missing key in actual",
                )
            )
        for key in sorted(set(actual) & set(expected)):
            _walk(
                actual[key],
                expected[key],
                path=f"{path}.{key}",
                diffs=diffs,
                logprob_tol=logprob_tol,
                score_tol=score_tol,
            )
        return

    if isinstance(expected, list):
        assert isinstance(actual, list)
        if len(actual) != len(expected):
            # Positional pairing is meaningless once lengths differ, so
            # report the structural change once and stop here.
            diffs.append(
                Diff(
                    path=path,
                    actual=f"len={len(actual)}",
                    expected=f"len={len(expected)}",
                    reason="list length mismatch",
                )
            )
            return
        for i, (a_item, e_item) in enumerate(zip(actual, expected, strict=True)):
            _walk(
                a_item,
                e_item,
                path=f"{path}[{i}]",
                diffs=diffs,
                logprob_tol=logprob_tol,
                score_tol=score_tol,
            )
        return

    if expected_is_num:
        assert actual_is_num
        # Use the looser score_tol for fields whose last path segment
        # is a known score name. Everything else (raw, z_score,
        # base_value, ft_value, logprob entries, component values)
        # gets logprob_tol.
        tol = score_tol if _is_score_path(path) else logprob_tol
        if not _within_tol(float(actual), float(expected), tol):
            diff = float(actual) - float(expected)
            diffs.append(
                Diff(
                    path=path,
                    actual=actual,
                    expected=expected,
                    reason=f"|Δ|={abs(diff):.3e} > tol={tol:.3e}",
                )
            )
        return

    # Strings, None, bool — exact equality required.
    if actual != expected:
        diffs.append(Diff(path=path, actual=actual, expected=expected, reason="value mismatch"))
| 221 | |
| 222 | |
def _within_tol(a: float, b: float, tol: float) -> bool:
    """Return whether ``a`` and ``b`` agree within absolute tolerance ``tol``.

    Non-finite values are matched exactly rather than by tolerance:
    NaN matches only NaN, and an infinity matches only the same-signed
    infinity. Two finite values match when ``|a - b| <= tol``.
    """
    a_nan = math.isnan(a)
    b_nan = math.isnan(b)
    if a_nan or b_nan:
        # NaN on either side: equal only if NaN on both sides.
        return a_nan and b_nan
    if math.isfinite(a) and math.isfinite(b):
        return abs(a - b) <= tol
    # At least one infinity: require the identical value (+inf == +inf).
    return a == b
| 236 | |
| 237 | |
def _is_score_path(path: str) -> bool:
    """Return whether the last dotted segment of ``path`` is a score-like field name."""
    # Only the text after the final "." is inspected. A trailing ``[idx]``
    # is deliberately NOT stripped: ``$.score[0]`` yields ``score[0]`` and
    # is not treated as a score. Scores are expected to live as scalars
    # directly inside their parent dict, never inside a list.
    _, _, last_segment = path.rpartition(".")
    return last_segment in _SCORE_FIELD_NAMES
| 244 | |
| 245 | |
# Public API of this module; the underscore-prefixed helpers are internal.
__all__ = ["DEFAULT_VARIABLE_FIELDS", "Diff", "compare_goldens", "mask_variable_fields"]