tenseleyflow/sway / e4fdbf1

Browse files

core: safe_finalize helper — non-finite critical fields → Verdict.ERROR

Authored by espadonne
SHA
e4fdbf1bd877b23c6efa11c70ff92e1abcc4ab7a
Parents
d21e208
Tree
a4a0ec6

3 changed files

Status File + -
M src/dlm_sway/__init__.py 2 1
M src/dlm_sway/core/result.py 100 0
A tests/unit/test_safe_finalize.py 204 0
src/dlm_sway/__init__.py modified
@@ -13,7 +13,7 @@ from dlm_sway.core.errors import (
1313
     SwayError,
1414
 )
1515
 from dlm_sway.core.model import LoadedModel, Model, ModelSpec
16
-from dlm_sway.core.result import ProbeResult, SuiteResult, SwayScore, Verdict
16
+from dlm_sway.core.result import ProbeResult, SuiteResult, SwayScore, Verdict, safe_finalize
1717
 from dlm_sway.core.scoring import (
1818
     DifferentialBackend,
1919
     NullCalibratedBackend,
@@ -41,6 +41,7 @@ __all__ = [
4141
     "SwayScore",
4242
     "TokenDist",
4343
     "Verdict",
44
+    "safe_finalize",
4445
 ]
4546
 
4647
 __version__ = "0.1.0.dev0"
src/dlm_sway/core/result.py modified
@@ -13,6 +13,7 @@ cheap probes.
1313
 
1414
 from __future__ import annotations
1515
 
16
+import math
1617
 from dataclasses import dataclass, field
1718
 from datetime import UTC, datetime
1819
 from enum import StrEnum
@@ -137,3 +138,102 @@ class SwayScore:
137138
 def utcnow() -> datetime:
138139
     """Timezone-aware UTC timestamp (used by the runner)."""
139140
     return datetime.now(UTC)
141
+
142
+
143
+def safe_finalize(
144
+    *,
145
+    name: str,
146
+    kind: str,
147
+    verdict: Verdict,
148
+    score: float | None = None,
149
+    raw: float | None = None,
150
+    z_score: float | None = None,
151
+    base_value: float | None = None,
152
+    ft_value: float | None = None,
153
+    evidence: dict[str, Any] | None = None,
154
+    message: str = "",
155
+    duration_s: float = 0.0,
156
+    critical_fields: tuple[str, ...] = ("raw",),
157
+) -> ProbeResult:
158
+    """Build a :class:`ProbeResult` with defense against non-finite metrics.
159
+
160
+    Probes hand their candidate result kwargs here instead of constructing
161
+    a :class:`ProbeResult` directly. The helper inspects each scalar metric
162
+    field (``score``, ``raw``, ``z_score``, ``base_value``, ``ft_value``) and classifies it:
163
+
164
+    - **Critical field non-finite** (any field named in ``critical_fields``
165
+      whose value is ``NaN`` or ``±inf``): the whole probe result is
166
+      converted to :attr:`Verdict.ERROR` with all scalar fields nulled out,
167
+      every non-finite value (critical or not) is preserved under
168
+      ``evidence["non_finite_inputs"]``, and the message explains which
169
+      field(s) were non-finite.
170
+    - **Non-critical field non-finite**: nulled out silently (set to
171
+      ``None``), and the field name is appended to
172
+      ``evidence["defensively_nulled"]`` so a report reader can see what
173
+      happened.
174
+    - **Everything finite**: passthrough, no change.
175
+
176
+    The default ``critical_fields = ("raw",)`` reflects the design stance:
177
+    ``raw`` is the probe's ground-truth metric; a non-finite ``raw`` means
178
+    the probe cannot make a meaningful statement. Probes that care about
179
+    other fields (e.g., probes whose ``z_score`` is load-bearing) pass a
180
+    broader tuple.
181
+
182
+    This helper is the single shared guardrail sprint 01 installs against
183
+    the +11639σ class of bug, where NaN logprobs flowed silently through
184
+    to a PASS verdict. Every numeric probe is expected to finalize through
185
+    this function.
186
+    """
187
+    numeric_kwargs: dict[str, float | None] = {
188
+        "score": score,
189
+        "raw": raw,
190
+        "z_score": z_score,
191
+        "base_value": base_value,
192
+        "ft_value": ft_value,
193
+    }
194
+
195
+    non_finite: dict[str, float] = {}
196
+    for fname, v in numeric_kwargs.items():
197
+        if isinstance(v, int | float) and not isinstance(v, bool) and not math.isfinite(float(v)):
198
+            non_finite[fname] = float(v)
199
+
200
+    ev: dict[str, Any] = dict(evidence) if evidence is not None else {}
201
+
202
+    critical_non_finite = {k: v for k, v in non_finite.items() if k in critical_fields}
203
+    if critical_non_finite:
204
+        ev["non_finite_inputs"] = non_finite
205
+        return ProbeResult(
206
+            name=name,
207
+            kind=kind,
208
+            verdict=Verdict.ERROR,
209
+            score=None,
210
+            raw=None,
211
+            z_score=None,
212
+            base_value=None,
213
+            ft_value=None,
214
+            evidence=ev,
215
+            message=(
216
+                f"non-finite critical field(s): {', '.join(sorted(critical_non_finite))} "
217
+                f"— probe cannot produce a meaningful result"
218
+            ),
219
+            duration_s=duration_s,
220
+        )
221
+
222
+    if non_finite:
223
+        ev.setdefault("defensively_nulled", []).extend(sorted(non_finite))
224
+        for fname in non_finite:
225
+            numeric_kwargs[fname] = None
226
+
227
+    return ProbeResult(
228
+        name=name,
229
+        kind=kind,
230
+        verdict=verdict,
231
+        score=numeric_kwargs["score"],
232
+        raw=numeric_kwargs["raw"],
233
+        z_score=numeric_kwargs["z_score"],
234
+        base_value=numeric_kwargs["base_value"],
235
+        ft_value=numeric_kwargs["ft_value"],
236
+        evidence=ev,
237
+        message=message,
238
+        duration_s=duration_s,
239
+    )
tests/unit/test_safe_finalize.py added
@@ -0,0 +1,204 @@
1
+"""Tests for :func:`dlm_sway.core.result.safe_finalize`.
2
+
3
+This helper is the shared guardrail S01 installs against NaN-flows-through
4
+bugs. It must:
5
+
6
+- Route critical non-finite fields to :attr:`Verdict.ERROR` with score nulled
7
+- Defensively null non-critical non-finite fields without changing the verdict
8
+- Leave all-finite inputs untouched
9
+- Preserve the original non-finite values in evidence for postmortem
10
+"""
11
+
12
+from __future__ import annotations
13
+
14
+import math
15
+
16
+from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize
17
+
18
+
19
+class TestAllFinite:
20
+    def test_passthrough_preserves_all_fields(self) -> None:
21
+        r = safe_finalize(
22
+            name="p1",
23
+            kind="delta_kl",
24
+            verdict=Verdict.PASS,
25
+            score=0.75,
26
+            raw=0.08,
27
+            z_score=3.2,
28
+            base_value=0.0,
29
+            ft_value=0.08,
30
+            evidence={"num_prompts": 4},
31
+            message="looks fine",
32
+            duration_s=1.2,
33
+        )
34
+        assert r.verdict == Verdict.PASS
35
+        assert r.score == 0.75
36
+        assert r.raw == 0.08
37
+        assert r.z_score == 3.2
38
+        assert r.base_value == 0.0
39
+        assert r.ft_value == 0.08
40
+        assert r.message == "looks fine"
41
+        assert r.duration_s == 1.2
42
+        assert r.evidence == {"num_prompts": 4}
43
+
44
+    def test_defaults(self) -> None:
45
+        r = safe_finalize(name="p", kind="k", verdict=Verdict.PASS, score=1.0)
46
+        assert r.raw is None
47
+        assert r.z_score is None
48
+        assert r.evidence == {}
49
+        assert r.duration_s == 0.0
50
+
51
+
52
+class TestCriticalNonFinite:
53
+    def test_nan_raw_routes_to_error(self) -> None:
54
+        r = safe_finalize(
55
+            name="p",
56
+            kind="delta_kl",
57
+            verdict=Verdict.PASS,
58
+            score=1.0,
59
+            raw=math.nan,
60
+            z_score=3.0,
61
+        )
62
+        assert r.verdict == Verdict.ERROR
63
+        assert r.score is None
64
+        assert r.raw is None
65
+        assert r.z_score is None
66
+        assert "non-finite critical" in r.message
67
+        assert "raw" in r.message
68
+        assert "raw" in r.evidence["non_finite_inputs"]
69
+        assert math.isnan(r.evidence["non_finite_inputs"]["raw"])
70
+
71
+    def test_inf_raw_routes_to_error(self) -> None:
72
+        r = safe_finalize(
73
+            name="p",
74
+            kind="delta_kl",
75
+            verdict=Verdict.PASS,
76
+            score=1.0,
77
+            raw=math.inf,
78
+        )
79
+        assert r.verdict == Verdict.ERROR
80
+        assert r.evidence["non_finite_inputs"]["raw"] == math.inf
81
+
82
+    def test_negative_inf_raw_routes_to_error(self) -> None:
83
+        r = safe_finalize(
84
+            name="p",
85
+            kind="delta_kl",
86
+            verdict=Verdict.PASS,
87
+            score=1.0,
88
+            raw=-math.inf,
89
+        )
90
+        assert r.verdict == Verdict.ERROR
91
+
92
+    def test_error_capture_includes_all_non_finite_fields(self) -> None:
93
+        """Even non-critical fields that are non-finite are recorded in evidence."""
94
+        r = safe_finalize(
95
+            name="p",
96
+            kind="delta_kl",
97
+            verdict=Verdict.PASS,
98
+            score=1.0,
99
+            raw=math.nan,
100
+            z_score=math.inf,
101
+            base_value=math.nan,
102
+        )
103
+        assert r.verdict == Verdict.ERROR
104
+        captured = r.evidence["non_finite_inputs"]
105
+        assert set(captured) == {"raw", "z_score", "base_value"}
106
+
107
+    def test_error_preserves_caller_evidence_keys(self) -> None:
108
+        r = safe_finalize(
109
+            name="p",
110
+            kind="delta_kl",
111
+            verdict=Verdict.PASS,
112
+            score=1.0,
113
+            raw=math.nan,
114
+            evidence={"per_prompt": [1, 2, 3], "num_prompts": 3},
115
+        )
116
+        assert r.verdict == Verdict.ERROR
117
+        assert r.evidence["per_prompt"] == [1, 2, 3]
118
+        assert r.evidence["num_prompts"] == 3
119
+        assert "non_finite_inputs" in r.evidence
120
+
121
+
122
+class TestNonCriticalNonFinite:
123
+    def test_nan_z_score_is_nulled_silently(self) -> None:
124
+        r = safe_finalize(
125
+            name="p",
126
+            kind="delta_kl",
127
+            verdict=Verdict.PASS,
128
+            score=0.7,
129
+            raw=0.05,
130
+            z_score=math.nan,
131
+        )
132
+        assert r.verdict == Verdict.PASS
133
+        assert r.score == 0.7
134
+        assert r.raw == 0.05
135
+        assert r.z_score is None
136
+        assert "z_score" in r.evidence["defensively_nulled"]
137
+
138
+    def test_nan_base_and_ft_nulled_preserves_passing_score(self) -> None:
139
+        r = safe_finalize(
140
+            name="p",
141
+            kind="delta_kl",
142
+            verdict=Verdict.PASS,
143
+            score=0.9,
144
+            raw=0.1,
145
+            base_value=math.nan,
146
+            ft_value=math.inf,
147
+        )
148
+        assert r.verdict == Verdict.PASS
149
+        assert r.base_value is None
150
+        assert r.ft_value is None
151
+        assert sorted(r.evidence["defensively_nulled"]) == ["base_value", "ft_value"]
152
+
153
+
154
+class TestCriticalFieldsOverride:
155
+    def test_z_score_critical_triggers_error_on_nan(self) -> None:
156
+        r = safe_finalize(
157
+            name="p",
158
+            kind="adapter_ablation",
159
+            verdict=Verdict.PASS,
160
+            score=1.0,
161
+            raw=0.9,
162
+            z_score=math.nan,
163
+            critical_fields=("raw", "z_score"),
164
+        )
165
+        assert r.verdict == Verdict.ERROR
166
+        assert "z_score" in r.message
167
+
168
+    def test_critical_fields_empty_allows_all_through(self) -> None:
169
+        """When no field is critical, even NaN raw only gets defensively nulled."""
170
+        r = safe_finalize(
171
+            name="p",
172
+            kind="delta_kl",
173
+            verdict=Verdict.PASS,
174
+            score=1.0,
175
+            raw=math.nan,
176
+            critical_fields=(),
177
+        )
178
+        assert r.verdict == Verdict.PASS
179
+        assert r.raw is None
180
+        assert "raw" in r.evidence["defensively_nulled"]
181
+
182
+
183
+class TestBoolFieldsNotMistakenForFloat:
184
+    """Pydantic sometimes wraps bools as ints; isinstance(True, int) is True.
185
+    We don't want booleans to be subjected to the numeric finiteness check.
186
+    """
187
+
188
+    def test_true_in_a_numeric_slot_is_not_non_finite(self) -> None:
189
+        # This test pins behavior: even if a caller passes True, we don't
190
+        # crash. We also don't treat True as non-finite.
191
+        r = safe_finalize(
192
+            name="p",
193
+            kind="test",
194
+            verdict=Verdict.PASS,
195
+            score=1.0,
196
+            raw=True,  # type: ignore[arg-type]
197
+        )
198
+        assert r.verdict == Verdict.PASS  # bool is finite
199
+
200
+
201
+class TestResultTypeReturned:
202
+    def test_returns_probe_result(self) -> None:
203
+        r = safe_finalize(name="p", kind="k", verdict=Verdict.PASS, score=1.0)
204
+        assert isinstance(r, ProbeResult)