tenseleyflow/sway / 96aef85

Browse files

backends: PreflightCheckable protocol + finite-check on HF and dummy

Authored by espadonne
SHA
96aef85c9d2a12363cc3457715336548e422a193
Parents
ebba1a3
Tree
39d3a80

5 changed files

Status | File | + | -
M src/dlm_sway/__init__.py 2 0
M src/dlm_sway/backends/dummy.py 26 0
M src/dlm_sway/backends/hf.py 46 0
M src/dlm_sway/core/scoring.py 22 0
A tests/unit/test_preflight_check.py 67 0
src/dlm_sway/__init__.py (modified)
@@ -17,6 +17,7 @@ from dlm_sway.core.result import ProbeResult, SuiteResult, SwayScore, Verdict, s
1717
 from dlm_sway.core.scoring import (
1818
     DifferentialBackend,
1919
     NullCalibratedBackend,
20
+    PreflightCheckable,
2021
     RollingLogprob,
2122
     ScalableDifferentialBackend,
2223
     ScoringBackend,
@@ -30,6 +31,7 @@ __all__ = [
3031
     "Model",
3132
     "ModelSpec",
3233
     "NullCalibratedBackend",
34
+    "PreflightCheckable",
3335
     "ProbeError",
3436
     "ProbeResult",
3537
     "RollingLogprob",
src/dlm_sway/backends/dummy.py (modified)
@@ -245,6 +245,32 @@ class DummyDifferentialBackend:
245245
         finally:
246246
             self._exit()
247247
 
248
+    def preflight_finite_check(self) -> tuple[bool, str]:
249
+        """Smoke a single forward pass per view; reject non-finite logits.
250
+
251
+        For the dummy backend the canned data is finite by construction
252
+        unless tests deliberately seed NaN-laden ``TokenDist`` entries —
253
+        which is exactly what S01 tests do to verify the runner gate.
254
+        """
255
+        prompt = "preflight"
256
+        try:
257
+            with self.as_base() as base_view:
258
+                base_dist = base_view.next_token_dist(prompt, top_k=8)
259
+            with self.as_finetuned() as ft_view:
260
+                ft_dist = ft_view.next_token_dist(prompt, top_k=8)
261
+        except Exception as exc:  # noqa: BLE001
262
+            return False, f"preflight raised {type(exc).__name__}: {exc}"
263
+
264
+        for label, dist in (("base", base_dist), ("ft", ft_dist)):
265
+            if not np.all(np.isfinite(dist.logprobs)):
266
+                n_bad = int((~np.isfinite(dist.logprobs)).sum())
267
+                return (
268
+                    False,
269
+                    f"{label} view produced {n_bad}/{dist.logprobs.size} non-finite "
270
+                    f"logprob(s) on prompt {prompt!r}",
271
+                )
272
+        return True, ""
273
+
248274
     def _enter(self, mode: str) -> None:
249275
         if self._active is not None:
250276
             raise RuntimeError(
src/dlm_sway/backends/hf.py (modified)
@@ -350,6 +350,52 @@ class HuggingFaceDifferentialBackend:
350350
         if self._torch.cuda.is_available():
351351
             self._torch.cuda.empty_cache()
352352
 
353
+    # -- PreflightCheckable -------------------------------------------
354
+
355
+    _PREFLIGHT_PROMPT = "hello"
356
+    _PREFLIGHT_TOP_K = 8
357
+
358
+    def preflight_finite_check(self) -> tuple[bool, str]:
359
+        """One forward pass per view; assert both produce finite logits.
360
+
361
+        Catches the +11639σ class of bug at suite-load time: a NaN-weighted
362
+        adapter would produce non-finite logprobs here, the runner sees
363
+        ``ok=False``, and the suite aborts with a single synthetic ERROR
364
+        probe — never reaching a probe that would pass on garbage.
365
+        """
366
+        import math
367
+
368
+        try:
369
+            with self.as_base() as base_view:
370
+                base_dist = base_view.next_token_dist(
371
+                    self._PREFLIGHT_PROMPT, top_k=self._PREFLIGHT_TOP_K
372
+                )
373
+            with self.as_finetuned() as ft_view:
374
+                ft_dist = ft_view.next_token_dist(
375
+                    self._PREFLIGHT_PROMPT, top_k=self._PREFLIGHT_TOP_K
376
+                )
377
+        except Exception as exc:  # noqa: BLE001 — backend may raise anything
378
+            return False, f"preflight forward pass raised {type(exc).__name__}: {exc}"
379
+
380
+        for label, dist in (("base", base_dist), ("ft", ft_dist)):
381
+            n_bad = int((~np.isfinite(dist.logprobs)).sum())
382
+            if n_bad > 0:
383
+                return (
384
+                    False,
385
+                    f"{label} view produced {n_bad}/{dist.logprobs.size} non-finite "
386
+                    f"logprob(s) on prompt {self._PREFLIGHT_PROMPT!r} — adapter is "
387
+                    f"likely broken (NaN/inf weights). sway refuses to score a model "
388
+                    f"producing non-finite outputs.",
389
+                )
390
+            tail = dist.tail_logprob
391
+            if not math.isfinite(tail):
392
+                return (
393
+                    False,
394
+                    f"{label} view produced non-finite tail_logprob = {tail}",
395
+                )
396
+
397
+        return True, ""
398
+
353399
     # -- internals -----------------------------------------------------
354400
 
355401
     def _make_view(self, mode: str) -> _HFView:
src/dlm_sway/core/scoring.py (modified)
@@ -161,6 +161,28 @@ class ScalableDifferentialBackend(DifferentialBackend, Protocol):
161161
     def as_scaled_adapter(self, lam: float) -> AbstractContextManager[_ScoringModel]: ...
162162
 
163163
 
164
@runtime_checkable
class PreflightCheckable(Protocol):
    """Structural type for backends that expose a self-validation hook.

    Any object with a ``preflight_finite_check()`` method satisfies this
    protocol; because it is ``runtime_checkable``, ``isinstance`` can be
    used to detect support (only the method's presence is checked, not
    its signature).

    Intended use (the runner's behaviour is not enforced by this class):
    the suite runner calls the hook once at suite start, before any
    probe. If it reports failure, the runner aborts with a single
    synthetic ERROR probe carrying the reason, so an adapter with
    NaN/inf weights cannot yield a false PASS verdict.

    Implementing the protocol is **opt-in**. A backend without the hook
    still runs; the runner skips the check and records a NOTE-level log
    entry. Custom backends are encouraged to implement it.
    """

    def preflight_finite_check(self) -> tuple[bool, str]:
        """Return ``(ok, reason)`` for a one-shot sanity pass.

        Implementations run a single forward pass per view (base and
        fine-tuned) on a fixed sentinel prompt and confirm that the
        resulting log-probabilities are finite. ``reason`` is a
        human-readable explanation when ``ok`` is false.
        """
        ...
184
+
185
+
164186
 @runtime_checkable
165187
 class NullCalibratedBackend(DifferentialBackend, Protocol):
166188
     """A differential backend that can produce a "null adapter" view.
tests/unit/test_preflight_check.py (added)
@@ -0,0 +1,67 @@
1
+"""Tests for ``preflight_finite_check`` on the shipped backends.
2
+
3
+The HF backend's check is exercised in the integration suite (it needs
4
+a real model). Here we verify the dummy backend's contract and the
5
+PreflightCheckable Protocol shape.
6
+"""
7
+
8
+from __future__ import annotations
9
+
10
+import math
11
+
12
+import numpy as np
13
+
14
+from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
15
+from dlm_sway.core.scoring import PreflightCheckable, TokenDist
16
+
17
+
18
class TestProtocolShape:
    """The dummy backend must structurally satisfy ``PreflightCheckable``."""

    def test_dummy_satisfies_preflight_protocol(self) -> None:
        # A runtime isinstance check against the protocol only needs the
        # method to exist on the instance.
        dummy = DummyDifferentialBackend(
            base=DummyResponses(),
            ft=DummyResponses(),
        )
        assert isinstance(dummy, PreflightCheckable)
22
+
23
+
24
class TestDummyDefaultIsFinite:
    """A dummy backend built from default responses passes the preflight."""

    def test_default_dummy_passes_preflight(self) -> None:
        dummy = DummyDifferentialBackend(
            base=DummyResponses(),
            ft=DummyResponses(),
        )
        passed, why = dummy.preflight_finite_check()
        # Success is reported as exactly ``True`` with an empty reason.
        assert passed is True
        assert why == ""
30
+
31
+
32
class TestDummyNanDetection:
    """Non-finite logprobs seeded into either view must fail the preflight.

    Each test registers a ``TokenDist`` under the key ``"preflight"`` —
    the sentinel prompt the dummy backend's check queries — so the
    seeded distribution is the one the check inspects.
    """

    def test_nan_in_ft_token_dist_caught(self) -> None:
        nan_dist = TokenDist(
            token_ids=np.array([1, 2, 3], dtype=np.int64),
            logprobs=np.array([-0.1, math.nan, -2.0], dtype=np.float32),
            vocab_size=100,
        )
        ft = DummyResponses(token_dists={"preflight": nan_dist})
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=ft)
        ok, reason = backend.preflight_finite_check()
        assert ok is False
        assert "ft view" in reason
        assert "non-finite" in reason

    def test_nan_in_base_token_dist_caught(self) -> None:
        nan_dist = TokenDist(
            token_ids=np.array([1, 2], dtype=np.int64),
            logprobs=np.array([math.nan, math.nan], dtype=np.float32),
            vocab_size=100,
        )
        base = DummyResponses(token_dists={"preflight": nan_dist})
        backend = DummyDifferentialBackend(base=base, ft=DummyResponses())
        ok, reason = backend.preflight_finite_check()
        assert ok is False
        assert "base view" in reason

    def test_inf_caught(self) -> None:
        inf_dist = TokenDist(
            token_ids=np.array([1, 2], dtype=np.int64),
            logprobs=np.array([-0.5, math.inf], dtype=np.float32),
            vocab_size=100,
        )
        ft = DummyResponses(token_dists={"preflight": inf_dist})
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=ft)
        ok, _reason = backend.preflight_finite_check()
        assert ok is False

    def test_neg_inf_caught(self) -> None:
        # ``-inf`` is the other infinity ``np.isfinite`` rejects; cover it
        # explicitly so the check is pinned for both signs.
        neg_inf_dist = TokenDist(
            token_ids=np.array([1, 2], dtype=np.int64),
            logprobs=np.array([-0.5, -math.inf], dtype=np.float32),
            vocab_size=100,
        )
        ft = DummyResponses(token_dists={"preflight": neg_inf_dist})
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=ft)
        ok, _reason = backend.preflight_finite_check()
        assert ok is False