backends: PreflightCheckable protocol + finite-check on HF and dummy
- SHA
96aef85c9d2a12363cc3457715336548e422a193- Parents
-
ebba1a3 - Tree
39d3a80
96aef85
96aef85c9d2a12363cc3457715336548e422a193ebba1a3
39d3a80| Status | File | + | - |
|---|---|---|---|
| M |
src/dlm_sway/__init__.py
|
2 | 0 |
| M |
src/dlm_sway/backends/dummy.py
|
26 | 0 |
| M |
src/dlm_sway/backends/hf.py
|
46 | 0 |
| M |
src/dlm_sway/core/scoring.py
|
22 | 0 |
| A |
tests/unit/test_preflight_check.py
|
67 | 0 |
src/dlm_sway/__init__.pymodified@@ -17,6 +17,7 @@ from dlm_sway.core.result import ProbeResult, SuiteResult, SwayScore, Verdict, s | ||
| 17 | 17 | from dlm_sway.core.scoring import ( |
| 18 | 18 | DifferentialBackend, |
| 19 | 19 | NullCalibratedBackend, |
| 20 | + PreflightCheckable, | |
| 20 | 21 | RollingLogprob, |
| 21 | 22 | ScalableDifferentialBackend, |
| 22 | 23 | ScoringBackend, |
@@ -30,6 +31,7 @@ __all__ = [ | ||
| 30 | 31 | "Model", |
| 31 | 32 | "ModelSpec", |
| 32 | 33 | "NullCalibratedBackend", |
| 34 | + "PreflightCheckable", | |
| 33 | 35 | "ProbeError", |
| 34 | 36 | "ProbeResult", |
| 35 | 37 | "RollingLogprob", |
src/dlm_sway/backends/dummy.pymodified@@ -245,6 +245,32 @@ class DummyDifferentialBackend: | ||
| 245 | 245 | finally: |
| 246 | 246 | self._exit() |
| 247 | 247 | |
| 248 | + def preflight_finite_check(self) -> tuple[bool, str]: | |
| 249 | + """Smoke a single forward pass per view; reject non-finite logits. | |
| 250 | + | |
| 251 | + For the dummy backend the canned data is finite by construction | |
| 252 | + unless tests deliberately seed NaN-laden ``TokenDist`` entries — | |
| 253 | + which is exactly what S01 tests do to verify the runner gate. | |
| 254 | + """ | |
| 255 | + prompt = "preflight" | |
| 256 | + try: | |
| 257 | + with self.as_base() as base_view: | |
| 258 | + base_dist = base_view.next_token_dist(prompt, top_k=8) | |
| 259 | + with self.as_finetuned() as ft_view: | |
| 260 | + ft_dist = ft_view.next_token_dist(prompt, top_k=8) | |
| 261 | + except Exception as exc: # noqa: BLE001 | |
| 262 | + return False, f"preflight raised {type(exc).__name__}: {exc}" | |
| 263 | + | |
| 264 | + for label, dist in (("base", base_dist), ("ft", ft_dist)): | |
| 265 | + if not np.all(np.isfinite(dist.logprobs)): | |
| 266 | + n_bad = int((~np.isfinite(dist.logprobs)).sum()) | |
| 267 | + return ( | |
| 268 | + False, | |
| 269 | + f"{label} view produced {n_bad}/{dist.logprobs.size} non-finite " | |
| 270 | + f"logprob(s) on prompt {prompt!r}", | |
| 271 | + ) | |
| 272 | + return True, "" | |
| 273 | + | |
| 248 | 274 | def _enter(self, mode: str) -> None: |
| 249 | 275 | if self._active is not None: |
| 250 | 276 | raise RuntimeError( |
src/dlm_sway/backends/hf.pymodified@@ -350,6 +350,52 @@ class HuggingFaceDifferentialBackend: | ||
| 350 | 350 | if self._torch.cuda.is_available(): |
| 351 | 351 | self._torch.cuda.empty_cache() |
| 352 | 352 | |
| 353 | + # -- PreflightCheckable ------------------------------------------- | |
| 354 | + | |
| 355 | + _PREFLIGHT_PROMPT = "hello" | |
| 356 | + _PREFLIGHT_TOP_K = 8 | |
| 357 | + | |
| 358 | + def preflight_finite_check(self) -> tuple[bool, str]: | |
| 359 | + """One forward pass per view; assert both produce finite logits. | |
| 360 | + | |
| 361 | + Catches the +11639σ class of bug at suite-load time: a NaN-weighted | |
| 362 | + adapter would produce non-finite logprobs here, the runner sees | |
| 363 | + ``ok=False``, and the suite aborts with a single synthetic ERROR | |
| 364 | + probe — never reaching a probe that would pass on garbage. | |
| 365 | + """ | |
| 366 | + import math | |
| 367 | + | |
| 368 | + try: | |
| 369 | + with self.as_base() as base_view: | |
| 370 | + base_dist = base_view.next_token_dist( | |
| 371 | + self._PREFLIGHT_PROMPT, top_k=self._PREFLIGHT_TOP_K | |
| 372 | + ) | |
| 373 | + with self.as_finetuned() as ft_view: | |
| 374 | + ft_dist = ft_view.next_token_dist( | |
| 375 | + self._PREFLIGHT_PROMPT, top_k=self._PREFLIGHT_TOP_K | |
| 376 | + ) | |
| 377 | + except Exception as exc: # noqa: BLE001 — backend may raise anything | |
| 378 | + return False, f"preflight forward pass raised {type(exc).__name__}: {exc}" | |
| 379 | + | |
| 380 | + for label, dist in (("base", base_dist), ("ft", ft_dist)): | |
| 381 | + n_bad = int((~np.isfinite(dist.logprobs)).sum()) | |
| 382 | + if n_bad > 0: | |
| 383 | + return ( | |
| 384 | + False, | |
| 385 | + f"{label} view produced {n_bad}/{dist.logprobs.size} non-finite " | |
| 386 | + f"logprob(s) on prompt {self._PREFLIGHT_PROMPT!r} — adapter is " | |
| 387 | + f"likely broken (NaN/inf weights). sway refuses to score a model " | |
| 388 | + f"producing non-finite outputs.", | |
| 389 | + ) | |
| 390 | + tail = dist.tail_logprob | |
| 391 | + if not math.isfinite(tail): | |
| 392 | + return ( | |
| 393 | + False, | |
| 394 | + f"{label} view produced non-finite tail_logprob = {tail}", | |
| 395 | + ) | |
| 396 | + | |
| 397 | + return True, "" | |
| 398 | + | |
| 353 | 399 | # -- internals ----------------------------------------------------- |
| 354 | 400 | |
| 355 | 401 | def _make_view(self, mode: str) -> _HFView: |
src/dlm_sway/core/scoring.pymodified@@ -161,6 +161,28 @@ class ScalableDifferentialBackend(DifferentialBackend, Protocol): | ||
| 161 | 161 | def as_scaled_adapter(self, lam: float) -> AbstractContextManager[_ScoringModel]: ... |
| 162 | 162 | |
| 163 | 163 | |
| 164 | +@runtime_checkable | |
| 165 | +class PreflightCheckable(Protocol): | |
| 166 | + """A backend that can validate itself before any probe runs. | |
| 167 | + | |
| 168 | + Returns ``(ok, reason)`` from a single forward pass per view with a | |
| 169 | + fixed sentinel prompt, asserting that both the base and fine-tuned | |
| 170 | + distributions contain finite logits. | |
| 171 | + | |
| 172 | + The runner calls this at suite start; on failure it aborts with a | |
| 173 | + single synthetic ERROR probe explaining the issue, so a NaN-weighted | |
| 174 | + adapter never produces a false PASS verdict (the +11639σ class of | |
| 175 | + bug from Audit 01). | |
| 176 | + | |
| 177 | + This Protocol is **opt-in** — backends that don't implement it run | |
| 178 | + without the check (the runner skips with a NOTE-level log entry). | |
| 179 | + All shipped backends in this version implement it; custom backends | |
| 180 | + are encouraged to. | |
| 181 | + """ | |
| 182 | + | |
| 183 | + def preflight_finite_check(self) -> tuple[bool, str]: ... | |
| 184 | + | |
| 185 | + | |
| 164 | 186 | @runtime_checkable |
| 165 | 187 | class NullCalibratedBackend(DifferentialBackend, Protocol): |
| 166 | 188 | """A differential backend that can produce a "null adapter" view. |
tests/unit/test_preflight_check.pyadded@@ -0,0 +1,67 @@ | ||
| 1 | +"""Tests for ``preflight_finite_check`` on the shipped backends. | |
| 2 | + | |
| 3 | +The HF backend's check is exercised in the integration suite (it needs | |
| 4 | +a real model). Here we verify the dummy backend's contract and the | |
| 5 | +PreflightCheckable Protocol shape. | |
| 6 | +""" | |
| 7 | + | |
| 8 | +from __future__ import annotations | |
| 9 | + | |
| 10 | +import math | |
| 11 | + | |
| 12 | +import numpy as np | |
| 13 | + | |
| 14 | +from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses | |
| 15 | +from dlm_sway.core.scoring import PreflightCheckable, TokenDist | |
| 16 | + | |
| 17 | + | |
| 18 | +class TestProtocolShape: | |
| 19 | + def test_dummy_satisfies_preflight_protocol(self) -> None: | |
| 20 | + backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses()) | |
| 21 | + assert isinstance(backend, PreflightCheckable) | |
| 22 | + | |
| 23 | + | |
| 24 | +class TestDummyDefaultIsFinite: | |
| 25 | + def test_default_dummy_passes_preflight(self) -> None: | |
| 26 | + backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses()) | |
| 27 | + ok, reason = backend.preflight_finite_check() | |
| 28 | + assert ok is True | |
| 29 | + assert reason == "" | |
| 30 | + | |
| 31 | + | |
| 32 | +class TestDummyNanDetection: | |
| 33 | + def test_nan_in_ft_token_dist_caught(self) -> None: | |
| 34 | + nan_dist = TokenDist( | |
| 35 | + token_ids=np.array([1, 2, 3], dtype=np.int64), | |
| 36 | + logprobs=np.array([-0.1, math.nan, -2.0], dtype=np.float32), | |
| 37 | + vocab_size=100, | |
| 38 | + ) | |
| 39 | + ft = DummyResponses(token_dists={"preflight": nan_dist}) | |
| 40 | + backend = DummyDifferentialBackend(base=DummyResponses(), ft=ft) | |
| 41 | + ok, reason = backend.preflight_finite_check() | |
| 42 | + assert ok is False | |
| 43 | + assert "ft view" in reason | |
| 44 | + assert "non-finite" in reason | |
| 45 | + | |
| 46 | + def test_nan_in_base_token_dist_caught(self) -> None: | |
| 47 | + nan_dist = TokenDist( | |
| 48 | + token_ids=np.array([1, 2], dtype=np.int64), | |
| 49 | + logprobs=np.array([math.nan, math.nan], dtype=np.float32), | |
| 50 | + vocab_size=100, | |
| 51 | + ) | |
| 52 | + base = DummyResponses(token_dists={"preflight": nan_dist}) | |
| 53 | + backend = DummyDifferentialBackend(base=base, ft=DummyResponses()) | |
| 54 | + ok, reason = backend.preflight_finite_check() | |
| 55 | + assert ok is False | |
| 56 | + assert "base view" in reason | |
| 57 | + | |
| 58 | + def test_inf_caught(self) -> None: | |
| 59 | + inf_dist = TokenDist( | |
| 60 | + token_ids=np.array([1, 2], dtype=np.int64), | |
| 61 | + logprobs=np.array([-0.5, math.inf], dtype=np.float32), | |
| 62 | + vocab_size=100, | |
| 63 | + ) | |
| 64 | + ft = DummyResponses(token_dists={"preflight": inf_dist}) | |
| 65 | + backend = DummyDifferentialBackend(base=DummyResponses(), ft=ft) | |
| 66 | + ok, _reason = backend.preflight_finite_check() | |
| 67 | + assert ok is False | |