Python · 9059 bytes Raw Blame History
1 """Tolerance-aware JSON comparator for cross-platform golden tests (S18).
2
3 A plain ``actual == expected`` on the suite's JSON payload fails across
4 platforms for reasons that aren't real drift — BLAS implementation
5 differences produce last-ULP noise on ``raw`` values, wall-time and
6 timestamps are by definition non-deterministic, and ``sway_version``
7 bumps every release.
8
9 This module encodes the comparison rules the golden test actually
10 wants:
11
- **Variable fields are masked** before comparison. Timestamps,
  per-probe ``duration_s``, suite ``wall_seconds``, the running
  ``sway_version``, ``backend_stats``, and the cwd-dependent
  ``adapter_id`` / ``base_model_id`` paths are all stripped — these are
  not load-bearing on the determinism claim.
- **Numeric drift is bounded by tolerance.** ``logprob_tol`` (default
  1e-4) covers raw metrics; ``score_tol`` (default 1e-4) covers
  composite ``overall`` scores and probe ``score`` fields. Differences
  beyond those tolerances surface as explicit drift reports with path,
  values, and delta.
21
22 No torch / HF dependency — the module is usable from the fast lane
23 for the comparator's own unit tests.
24 """
25
26 from __future__ import annotations
27
28 import copy
29 import math
30 from dataclasses import dataclass
31 from typing import Any
32
33 #: Default field paths stripped before comparison. Every entry is a
34 #: dotted key path (``probes.duration_s``) or a plain field name
35 #: (``sway_version``). Path components are matched anywhere in the
36 #: nested structure — so ``duration_s`` masks both
37 #: ``probes[i].duration_s`` and any future top-level ``duration_s``.
38 DEFAULT_VARIABLE_FIELDS: frozenset[str] = frozenset(
39 {
40 "started_at",
41 "finished_at",
42 "wall_seconds",
43 "duration_s",
44 "sway_version",
45 # ``backend_stats`` records per-run wall times + cache counters
46 # that vary with load and cold/warm cache — not part of the
47 # determinism contract.
48 "backend_stats",
49 # ``adapter_id`` and ``base_model_id`` are absolute-path
50 # identifiers the spec loader resolves against cwd. Different
51 # cwds on different platforms (``/Users/.../`` on darwin vs
52 # ``/home/runner/...`` on ubuntu) surface as drift without
53 # any real numeric change. The numeric fields (``raw``,
54 # ``score``, etc.) are what the determinism contract covers.
55 "adapter_id",
56 "base_model_id",
57 }
58 )
59
60 #: Score-level floats: composite scores + per-probe scores. Applies
61 #: the looser ``score_tol`` since these are derived metrics.
62 _SCORE_FIELD_NAMES: frozenset[str] = frozenset({"score", "overall"})
63
64
65 @dataclass(frozen=True, slots=True)
66 class Diff:
67 """One tolerance-exceeding drift between actual and expected."""
68
69 path: str
70 actual: Any
71 expected: Any
72 reason: str
73
74 def __str__(self) -> str:
75 return f"{self.path}: {self.reason} (actual={self.actual!r}, expected={self.expected!r})"
76
77
def mask_variable_fields(payload: Any, *, fields: frozenset[str] = DEFAULT_VARIABLE_FIELDS) -> Any:
    """Return a copy of ``payload`` with every dict key listed in ``fields`` dropped.

    Dicts and lists are rebuilt recursively, so nested occurrences of a masked
    key are removed too. List order and the relative order of surviving dict
    keys are preserved. Any other value is passed through ``copy.copy``. For
    JSON scalars (str/int/float/bool/None) that returns an equal value, so
    the result shares no mutable containers with the input.

    Args:
        payload: Decoded JSON-like structure to mask.
        fields: Key names to remove, matched exactly at any nesting depth.

    Returns:
        The masked copy; the input is not modified.
    """
    if isinstance(payload, list):
        return [mask_variable_fields(element, fields=fields) for element in payload]
    if not isinstance(payload, dict):
        # Leaf value: nothing to recurse into.
        return copy.copy(payload)

    masked: dict[Any, Any] = {}
    for key, value in payload.items():
        if key in fields:
            continue
        masked[key] = mask_variable_fields(value, fields=fields)
    return masked
91
92
def compare_goldens(
    actual: Any,
    expected: Any,
    *,
    logprob_tol: float = 1e-4,
    score_tol: float = 1e-4,
) -> list[Diff]:
    """Diff two masked JSON payloads and report every tolerance-exceeding drift.

    The two structures are walked in lockstep, starting from the root path
    ``$``. Numeric leaves are compared by absolute difference:
    - Leaves whose last key is score-like (``score`` / ``overall``) use
      ``score_tol``.
    - All other numeric leaves use ``logprob_tol``.

    Missing or extra keys, list-length changes and type changes are reported
    as structural diffs whatever the tolerance. String and None leaves must
    match exactly.

    Args:
        actual: Payload produced by the current run, already masked.
        expected: Golden payload to compare against, already masked.
        logprob_tol: Absolute tolerance for raw / non-score numeric fields.
        score_tol: Absolute tolerance for score-like fields.

    Returns:
        The collected :class:`Diff` records. An empty list means the payloads
        agree within tolerance.

    **Tolerance rationale.** S18's first CI observation showed
    intra-platform BLAS drift in the 1e-5–1e-6 band on ubuntu-latest
    runners (heterogeneous Intel/AMD hardware + variable OpenBLAS
    builds), which put 1e-6 below the natural noise floor. A real
    algorithm change — e.g. flipping ``top_k=256`` → 128 in
    :mod:`delta_kl` — shifts probe raws by 1e-2 to 1e-1, two to three
    orders of magnitude above the current ``1e-4`` tolerance. There is
    tuning room to revisit once we have more CI history; see the sprint's
    risks section for the "too tight vs too loose" tradeoff.
    """
    found: list[Diff] = []
    _walk(
        actual,
        expected,
        path="$",
        diffs=found,
        logprob_tol=logprob_tol,
        score_tol=score_tol,
    )
    return found
123
124
def _walk(
    actual: Any,
    expected: Any,
    *,
    path: str,
    diffs: list[Diff],
    logprob_tol: float,
    score_tol: float,
) -> None:
    """Recursively compare ``actual`` to ``expected``, appending drifts to ``diffs``.

    Args:
        actual: Node from the freshly produced payload.
        expected: Node at the same position in the golden payload.
        path: Location of this node. It is ``$`` at the root; dict members
            append ``.key`` and list items append ``[i]``.
        diffs: Accumulator list, mutated in place.
        logprob_tol: Absolute tolerance for numeric leaves that are not
            score-like.
        score_tol: Absolute tolerance for numeric leaves whose path ends in
            a score-like key.
    """
    # ``bool`` is a subclass of ``int``, so a bare ``isinstance(x, int | float)``
    # accepts True/False. Booleans are flags, not measurements. Excluding
    # them means True-vs-1 is reported as a type mismatch, and True-vs-False
    # reaches the exact-equality branch instead of the numeric-tolerance path.
    actual_is_num = isinstance(actual, int | float) and not isinstance(actual, bool)
    expected_is_num = isinstance(expected, int | float) and not isinstance(expected, bool)

    # Type mismatch — catches e.g. scalar-to-dict transitions between
    # schema versions. int and float are interchangeable: an expected
    # int 0 vs actual 0.0 is not drift.
    if type(actual) is not type(expected) and not (actual_is_num and expected_is_num):
        diffs.append(Diff(path=path, actual=actual, expected=expected, reason="type mismatch"))
        return

    if isinstance(expected, dict):
        assert isinstance(actual, dict)  # guaranteed by the type check above
        extra_actual = set(actual) - set(expected)
        missing_actual = set(expected) - set(actual)
        for key in sorted(extra_actual):
            diffs.append(
                Diff(
                    path=f"{path}.{key}",
                    actual=actual[key],
                    expected=None,
                    reason="unexpected key in actual",
                )
            )
        for key in sorted(missing_actual):
            diffs.append(
                Diff(
                    path=f"{path}.{key}",
                    actual=None,
                    expected=expected[key],
                    reason="missing key in actual",
                )
            )
        for key in sorted(set(actual) & set(expected)):
            _walk(
                actual[key],
                expected[key],
                path=f"{path}.{key}",
                diffs=diffs,
                logprob_tol=logprob_tol,
                score_tol=score_tol,
            )
        return

    if isinstance(expected, list):
        assert isinstance(actual, list)  # guaranteed by the type check above
        if len(actual) != len(expected):
            # Element-wise comparison is meaningless once lengths differ;
            # report the structural change only.
            diffs.append(
                Diff(
                    path=path,
                    actual=f"len={len(actual)}",
                    expected=f"len={len(expected)}",
                    reason="list length mismatch",
                )
            )
            return
        for i, (a_item, e_item) in enumerate(zip(actual, expected, strict=True)):
            _walk(
                a_item,
                e_item,
                path=f"{path}[{i}]",
                diffs=diffs,
                logprob_tol=logprob_tol,
                score_tol=score_tol,
            )
        return

    if expected_is_num:
        assert actual_is_num  # guaranteed by the type check above
        # Use score_tol for fields whose last path segment is a known score
        # name. Everything else (raw, z_score, base_value, ft_value, logprob
        # entries, component values) gets logprob_tol.
        tol = score_tol if _is_score_path(path) else logprob_tol
        if not _within_tol(float(actual), float(expected), tol):
            diff = float(actual) - float(expected)
            diffs.append(
                Diff(
                    path=path,
                    actual=actual,
                    expected=expected,
                    reason=f"|Δ|={abs(diff):.3e} > tol={tol:.3e}",
                )
            )
        return

    # Strings, None, bool — exact equality required.
    if actual != expected:
        diffs.append(Diff(path=path, actual=actual, expected=expected, reason="value mismatch"))
221
222
223 def _within_tol(a: float, b: float, tol: float) -> bool:
224 """Absolute-tolerance float equality, treating NaN/inf carefully.
225
226 Both non-finite on the same side (both +inf, both -inf, both NaN)
227 counts as equal. Otherwise the diff must be within ``tol``.
228 """
229 if math.isnan(a) and math.isnan(b):
230 return True
231 if math.isnan(a) or math.isnan(b):
232 return False
233 if math.isinf(a) or math.isinf(b):
234 return a == b # same-signed infinities compare equal
235 return abs(a - b) <= tol
236
237
def _is_score_path(path: str) -> bool:
    """Return True when the final segment of ``path`` is a score-like key.

    ``path`` is a location string as built by the walker, e.g.
    ``$.probes[0].score``. The segment after the last ``.`` is taken and any
    ``[idx]`` subscripts trailing it are stripped. So ``$.probes[0].score``
    and ``$.score[1]`` both resolve to ``score``, while ``$.probes[0].raw``
    resolves to ``raw``. The bare key is then looked up in
    ``_SCORE_FIELD_NAMES``.
    """
    last_segment = path.rsplit(".", 1)[-1]
    # Drop list subscripts so an indexed score value (``score[1]``) still
    # matches its parent key.
    key = last_segment.split("[", 1)[0]
    return key in _SCORE_FIELD_NAMES
244
245
# Public API of this module; the leading-underscore helpers are internal.
__all__ = [
    "DEFAULT_VARIABLE_FIELDS",
    "Diff",
    "compare_goldens",
    "mask_variable_fields",
]