Python · 9659 bytes Raw Blame History
1 """Unit tests for :mod:`dlm_sway.core.golden`.
2
3 Pins the comparator's tolerance math and the variable-field mask so
4 the cross-platform golden test (S18) has a reliable backbone. No HF
5 or torch dependency — the comparator is pure-Python and runs in the
6 fast lane.
7 """
8
9 from __future__ import annotations
10
11 import math
12
13 from dlm_sway.core.golden import (
14 DEFAULT_VARIABLE_FIELDS,
15 Diff,
16 compare_goldens,
17 mask_variable_fields,
18 )
19
20
class TestMaskVariableFields:
    """``mask_variable_fields`` drops run-to-run noise and keeps everything else."""

    def test_strips_top_level_fields(self) -> None:
        result = mask_variable_fields(
            {"sway_version": "0.1.0", "wall_seconds": 1.23, "probes": []}
        )
        for dropped in ("sway_version", "wall_seconds"):
            assert dropped not in result
        assert "probes" in result

    def test_strips_nested_duration_s(self) -> None:
        source = {
            "probes": [
                {"name": "p1", "raw": 0.5, "duration_s": 0.01},
                {"name": "p2", "raw": 0.8, "duration_s": 0.02},
            ],
        }
        # The mask must recurse through lists of dicts, not just the top level.
        for entry in mask_variable_fields(source)["probes"]:
            assert "duration_s" not in entry
            assert "raw" in entry

    def test_strips_started_and_finished(self) -> None:
        timestamps = {
            "started_at": "2026-01-01T00:00:00Z",
            "finished_at": "2026-01-01T00:00:05Z",
        }
        assert mask_variable_fields(timestamps) == {}

    def test_strips_backend_stats(self) -> None:
        result = mask_variable_fields(
            {"backend_stats": {"cache_hits": 42, "wall_ms": 1230.0}, "overall": 0.8}
        )
        assert "backend_stats" not in result
        assert result["overall"] == 0.8

    def test_preserves_scalars(self) -> None:
        # Non-container values pass straight through the mask.
        for scalar in (42, "hello"):
            assert mask_variable_fields(scalar) == scalar
        assert mask_variable_fields(None) is None

    def test_default_variable_fields_has_expected_members(self) -> None:
        """Pin the members the golden test depends on.

        This is a subset check, not equality: new variable fields may be
        added freely. Dropping any of these would make the golden test
        newly flaky.
        """
        required = frozenset(
            {
                "started_at",
                "finished_at",
                "wall_seconds",
                "duration_s",
                "sway_version",
                "backend_stats",
                # Platform-dependent path identifiers.
                "adapter_id",
                "base_model_id",
            }
        )
        assert required <= DEFAULT_VARIABLE_FIELDS
79
80
class TestCompareGoldensIdentical:
    """Comparing a payload with itself (or empty with empty) yields no diffs."""

    def test_identical_payload_no_diffs(self) -> None:
        snapshot = {"overall": 0.85, "probes": [{"raw": 0.123, "score": 0.9}]}
        diffs = compare_goldens(snapshot, snapshot)
        assert diffs == []

    def test_empty_payload_no_diffs(self) -> None:
        diffs = compare_goldens({}, {})
        assert diffs == []
88
89
class TestCompareGoldensTolerance:
    """Tolerance math: ``logprob_tol`` for raw values, ``score_tol`` for scores.

    Both default to 1e-4. NaN and infinity get explicit cases because a
    plain ``abs(a - b) <= tol`` check mishandles them: NaN compares false
    against everything, and ``inf - inf`` is NaN.
    """

    def test_floats_within_logprob_tol_pass(self) -> None:
        actual = {"probes": [{"raw": 0.12345}]}
        expected = {"probes": [{"raw": 0.12345 + 5e-5}]}  # half the 1e-4 tol
        assert compare_goldens(actual, expected) == []

    def test_floats_just_above_logprob_tol_fail(self) -> None:
        actual = {"probes": [{"raw": 0.12345}]}
        expected = {"probes": [{"raw": 0.12345 + 2e-4}]}  # double the 1e-4 tol
        diffs = compare_goldens(actual, expected)
        assert len(diffs) == 1
        assert "raw" in diffs[0].path
        assert "Δ" in diffs[0].reason

    def test_scores_match_logprob_tol_default(self) -> None:
        """Score fields use ``score_tol`` (1e-4), the same as ``logprob_tol``
        after S18's first-week tuning. A 5e-5 drift passes on both."""
        actual = {"overall": 0.85}
        expected = {"overall": 0.85 + 5e-5}
        assert compare_goldens(actual, expected) == []

    def test_score_field_drift_above_score_tol_fails(self) -> None:
        actual = {"overall": 0.85}
        expected = {"overall": 0.85 + 2e-4}  # double the score tol
        diffs = compare_goldens(actual, expected)
        assert len(diffs) == 1
        assert diffs[0].path == "$.overall"

    def test_custom_tolerances_respected(self) -> None:
        """Callers can tighten or loosen both ``logprob_tol`` and ``score_tol``."""
        # A 5e-4 drift sits above the 1e-4 default but below a loosened 1e-3.
        raw_actual = {"probes": [{"raw": 0.1}]}
        raw_expected = {"probes": [{"raw": 0.1 + 5e-4}]}
        assert compare_goldens(raw_actual, raw_expected) != []
        assert compare_goldens(raw_actual, raw_expected, logprob_tol=1e-3) == []
        # Tightening must keep failing. This guards a future tighter default.
        assert compare_goldens(raw_actual, raw_expected, logprob_tol=1e-6) != []

        # Same three checks for score fields, driven by ``score_tol``.
        score_actual = {"overall": 0.85}
        score_expected = {"overall": 0.85 + 5e-4}
        assert compare_goldens(score_actual, score_expected) != []
        assert compare_goldens(score_actual, score_expected, score_tol=1e-3) == []
        assert compare_goldens(score_actual, score_expected, score_tol=1e-6) != []

    def test_nan_vs_nan_treated_equal(self) -> None:
        actual = {"z_score": float("nan")}
        expected = {"z_score": float("nan")}
        assert compare_goldens(actual, expected) == []

    def test_nan_vs_finite_is_drift(self) -> None:
        # Check both directions: a NaN check applied to only one side of
        # the comparison would let the other direction slip through.
        cases = (
            ({"z_score": float("nan")}, {"z_score": 3.0}),
            ({"z_score": 3.0}, {"z_score": float("nan")}),
        )
        for actual, expected in cases:
            diffs = compare_goldens(actual, expected)
            assert len(diffs) == 1
            assert diffs[0].path == "$.z_score"

    def test_inf_comparison(self) -> None:
        """Same-signed infinities compare equal; opposite signs drift."""
        # Under IEEE 754, inf - inf is NaN, so same-signed infinities need an
        # equality check rather than a delta-vs-tolerance test. Opposite
        # signs give an infinite delta that no tolerance can absorb.
        assert compare_goldens({"raw": float("inf")}, {"raw": float("inf")}) == []
        assert compare_goldens({"raw": float("-inf")}, {"raw": float("-inf")}) == []
        diffs = compare_goldens({"raw": float("inf")}, {"raw": float("-inf")})
        assert len(diffs) == 1
        assert diffs[0].path == "$.raw"

    def test_int_vs_float_not_type_mismatch(self) -> None:
        """``raw: 0`` (int) vs ``raw: 0.0`` (float) is not drift."""
        assert compare_goldens({"raw": 0}, {"raw": 0.0}) == []
155
156
class TestCompareGoldensStructural:
    """Shape-level differences (keys, lengths, types, strings) are reported."""

    def test_missing_key_flagged(self) -> None:
        found = compare_goldens({"overall": 0.8}, {"overall": 0.8, "band": "healthy"})
        assert "missing key in actual" in [d.reason for d in found]

    def test_extra_key_flagged(self) -> None:
        found = compare_goldens({"overall": 0.8, "new_field": 42}, {"overall": 0.8})
        assert "unexpected key in actual" in [d.reason for d in found]

    def test_list_length_mismatch_flagged(self) -> None:
        found = compare_goldens(
            {"probes": [{"raw": 0.1}]},
            {"probes": [{"raw": 0.1}, {"raw": 0.2}]},
        )
        assert len(found) == 1
        assert "list length mismatch" in found[0].reason

    def test_type_mismatch_flagged(self) -> None:
        found = compare_goldens(
            {"band": "healthy"},
            {"band": {"name": "healthy", "level": 3}},
        )
        assert "type mismatch" in [d.reason for d in found]

    def test_string_mismatch_flagged(self) -> None:
        found = compare_goldens({"band": "noise"}, {"band": "healthy"})
        assert len(found) == 1
        assert found[0].reason == "value mismatch"
188 assert diffs[0].reason == "value mismatch"
189
190
class TestDiffRepr:
    """``str(Diff)`` carries the path, reason, and both values."""

    def test_str_includes_path_and_reason(self) -> None:
        rendered = str(Diff(path="$.foo", actual=1.0, expected=2.0, reason="drift"))
        for fragment in ("$.foo", "drift", "1.0", "2.0"):
            assert fragment in rendered
199
200
class TestRealisticPayload:
    """Report-shaped payloads run through masking and comparison together."""

    @staticmethod
    def _report(
        *,
        sway_version: str,
        started_at: str,
        finished_at: str,
        wall_seconds: float,
        overall: float,
        raw: float,
        duration_s: float,
    ) -> dict[str, object]:
        """Build a one-probe report in which only the keyword arguments vary."""
        return {
            "schema_version": 1,
            "sway_version": sway_version,
            "started_at": started_at,
            "finished_at": finished_at,
            "wall_seconds": wall_seconds,
            "overall": overall,
            "probes": [
                {
                    "name": "dk",
                    "raw": raw,
                    "score": 0.87,
                    "duration_s": duration_s,
                },
            ],
        }

    def test_two_masked_payloads_match(self) -> None:
        """End-to-end sanity check. Two runs that differ only in masked
        fields and sub-tolerance float noise compare drift-free."""
        current = self._report(
            sway_version="0.1.0",
            started_at="2026-04-01T00:00:00Z",
            finished_at="2026-04-01T00:00:05Z",
            wall_seconds=5.123,
            overall=0.82,
            raw=0.4561,
            duration_s=0.123,
        )
        golden = self._report(
            sway_version="0.0.9",  # version bumped (masked)
            started_at="2026-03-15T12:00:00Z",  # masked
            finished_at="2026-03-15T12:00:03Z",  # masked
            wall_seconds=3.456,  # different wall time (masked)
            overall=0.82 + 5e-5,  # within score_tol
            raw=0.4561 + 5e-5,  # within logprob_tol (1e-4)
            duration_s=0.789,  # different duration (masked)
        )
        masked_current = mask_variable_fields(current)
        masked_golden = mask_variable_fields(golden)
        assert compare_goldens(masked_current, masked_golden) == []

    def test_simulated_silent_algorithm_change_is_caught(self) -> None:
        """A 1e-2 drift on a probe's raw value must be flagged.

        That is far above the 1e-4 default tolerance. Real algorithm
        changes (e.g. flipping ``top_k=256`` → 128) shift raw values by
        this order of magnitude.
        """
        baseline_raw = 0.4561
        drift = 1e-2
        diffs = compare_goldens(
            {"probes": [{"raw": baseline_raw + drift}]},
            {"probes": [{"raw": baseline_raw}]},
        )
        assert len(diffs) == 1
        flagged = diffs[0]
        assert "raw" in flagged.path
        observed_delta = abs(float(flagged.actual) - float(flagged.expected))
        assert math.isclose(observed_delta, drift, abs_tol=1e-9)
253 )