Python · 12964 bytes Raw Blame History
1 """Suite runner.
2
3 Iterates the probe list, materializes each into a ``(Probe, Spec)`` via
4 the registry, executes it with a :class:`~dlm_sway.probes.base.RunContext`,
5 and assembles a :class:`~dlm_sway.core.result.SuiteResult`.
6
7 Runtime contract:
8
9 - Probes are executed in declaration order (not sorted, not parallelized).
10 The null-adapter baseline has to run before any probe that needs z-scores,
11 so authoring order is load-bearing.
12 - A probe that raises is recorded as
13 :attr:`~dlm_sway.core.result.Verdict.ERROR` and the suite continues —
14 one broken probe doesn't torch the whole report.
15 - The backend is the caller's responsibility: the runner does not build
16 or close it, so callers can reuse a backend across multiple suites.
17 """
18
19 from __future__ import annotations
20
21 import time
22 from pathlib import Path
23 from types import MappingProxyType
24
25 from dlm_sway import __version__
26 from dlm_sway.core.determinism import seed_everything
27 from dlm_sway.core.errors import BackendNotAvailableError, ProbeError
28 from dlm_sway.core.result import DeterminismReport, ProbeResult, SuiteResult, Verdict, utcnow
29 from dlm_sway.core.scoring import DifferentialBackend, PreflightCheckable
30 from dlm_sway.core.sections import Section
31 from dlm_sway.probes.base import RunContext, build_probe
32 from dlm_sway.probes.null_adapter import NullAdapterSpec, get_null_stats
33 from dlm_sway.suite.spec import SwaySpec
34
35
def run(
    spec: SwaySpec,
    backend: DifferentialBackend | None,
    *,
    spec_path: str = "<memory>",
    doc_text: str | None = None,
    sections: tuple[Section, ...] | None = None,
    skip_preflight: bool = False,
    trace_path: Path | None = None,
) -> SuiteResult:
    """Execute every probe in ``spec`` against ``backend``.

    Before any probe runs, the runner asks the backend to preflight-check
    itself if it implements :class:`PreflightCheckable`. On failure
    (e.g., a NaN-weighted adapter), the runner aborts the suite with a
    single synthetic ERROR probe and skips every configured probe — so
    a broken model never produces a false PASS verdict (the +11639σ
    class of bug from Audit 01).

    Set ``skip_preflight=True`` to disable the gate (e.g., for sub-second
    test suites where the cost matters); the default is to run it.

    **S25 — backend-optional path.** ``backend`` may be ``None`` when
    every scheduled probe declares ``needs_backend=False`` (e.g. a
    suite composed entirely of pre-run diagnostic probes like
    ``gradient_ghost``). In that case the runner skips backend-related
    work (preflight, trace writer, backend-stats snapshot) and the
    caller can avoid building a model entirely. Passing ``None`` while
    any probe needs a backend raises ``BackendNotAvailableError``
    listing the offending probe kinds.

    Args:
        spec: Suite specification. ``spec.suite`` supplies the raw probe
            entries, ``spec.defaults`` the seed and ``top_k``, and
            ``spec.models`` the ids recorded in the report.
        backend: Backend handed to every probe through the run context,
            or ``None`` for suites whose enabled probes need no backend.
        spec_path: Label recorded in the report and quoted in the
            missing-backend error.
        doc_text: Forwarded unchanged to the run context.
        sections: Forwarded unchanged to the run context.
        skip_preflight: Skip the backend preflight gate when true.
        trace_path: When set (and ``backend`` is not ``None``), a trace
            writer for this path is attached to the backend's
            instrumentation.

    Returns:
        A :class:`SuiteResult` holding one :class:`ProbeResult` per suite
        entry in declaration order (disabled entries are recorded as
        SKIP), or a single ``__preflight__`` ERROR result when the
        preflight gate fails.

    Raises:
        BackendNotAvailableError: ``backend`` is ``None`` but at least one
            enabled probe requires a backend.

    The probe label set on the backend's instrumentation is cleared when
    the probe loop exits, including when it exits with an exception, so a
    backend reused for another suite does not keep a stale label.
    """
    started = utcnow()

    # S25 — backend-optional gate. Resolve every probe class up front
    # so we can decide whether the backend is actually needed before
    # any expensive work fires. (build_probe is cheap — pydantic
    # validation only.)
    _scheduled_probes = [build_probe(raw) for raw in spec.suite]
    _needs_backend_kinds = sorted(
        {ps.kind for probe, ps in _scheduled_probes if probe.__class__.needs_backend and ps.enabled}
    )
    if backend is None and _needs_backend_kinds:
        raise BackendNotAvailableError(
            f"<runner: spec={spec_path}>",
            extra="hf",
            hint=(
                "The runner was called with backend=None but the suite "
                f"includes probes that require a backend: "
                f"{_needs_backend_kinds}. Build a backend before calling "
                "run(), or remove those probes from the suite."
            ),
        )

    # Sprint 07: ``concurrent_probes`` is scaffolding. The HF and MLX
    # backends declare ``safe_for_concurrent_views = False`` so the
    # runner stays sequential below. The spec field is still validated
    # + accepted (so a future pool can light up without a schema bump)
    # but does nothing until the pool lands; see
    # ``.docs/design/backend-concurrency.md``.
    #
    # DC4 (Audit 02) — the previous stderr warning fired every run
    # with ``concurrent_probes>1`` even though execution was
    # unchanged, which was noise for no benefit. The warning is
    # reinstated by the sprint that wires the pool.

    # Seed every RNG sway's probes touch before any backend work runs.
    # ``strict=True`` asks torch for deterministic algorithms and sets
    # CUBLAS_WORKSPACE_CONFIG; this is a no-op when torch is absent, so
    # callers without the ``hf`` extra still get a seeded python/numpy
    # state and a ``best_effort`` classification in the report.
    det_summary = seed_everything(spec.defaults.seed, strict=True)
    determinism = DeterminismReport(
        class_=det_summary.class_,
        seed=det_summary.seed,
        notes=det_summary.notes,
    )

    def _suite_result(
        probes: list[ProbeResult], stats: dict[str, dict[str, float]]
    ) -> SuiteResult:
        """Assemble the report; shared by the preflight-abort and normal exits."""
        return SuiteResult(
            spec_path=spec_path,
            started_at=started,
            finished_at=utcnow(),
            base_model_id=spec.models.base.base,
            adapter_id=str(spec.models.ft.adapter) if spec.models.ft.adapter else "",
            sway_version=__version__,
            probes=tuple(probes),
            null_stats=stats,
            determinism=determinism,
            backend_stats=_snapshot_backend_stats(backend),
        )

    # Sprint 07: attach a trace writer to the backend's
    # instrumentation if the caller asked for one. Silent no-op for
    # backends without ``_inst`` (custom backends) and for the default
    # ``trace_path=None`` case. S25 — also skipped when backend is None.
    if backend is not None:
        _install_trace_writer(backend, trace_path)

    ctx = RunContext(
        backend=backend,
        seed=spec.defaults.seed,
        top_k=spec.defaults.top_k,
        sections=sections,
        doc_text=doc_text,
    )

    results: list[ProbeResult] = []
    null_stats: dict[str, dict[str, float]] = {}
    null_stats_by_rank: dict[str, dict[str, dict[str, float]]] = {}

    # Preflight gate: if the backend can self-check, do so before any
    # probe runs. A failing preflight aborts the suite. S25 — when
    # backend is None (only pre-run probes scheduled), preflight is
    # also skipped since there's no model state to check.
    if not skip_preflight and backend is not None and isinstance(backend, PreflightCheckable):
        t0 = time.perf_counter()
        ok, reason = backend.preflight_finite_check()
        duration = time.perf_counter() - t0
        if not ok:
            results.append(
                ProbeResult(
                    name="__preflight__",
                    kind="preflight",
                    verdict=Verdict.ERROR,
                    score=None,
                    message=(
                        f"backend preflight failed — suite aborted before any probe ran. {reason}"
                    ),
                    duration_s=duration,
                    evidence={"preflight_reason": reason},
                )
            )
            return _suite_result(results, {})

    # Pre-extract suite kinds so each probe sees only what's *after* it.
    suite_kinds: list[str] = [str(entry.get("kind", "")) for entry in spec.suite]

    try:
        for idx, (probe, probe_spec) in enumerate(_scheduled_probes):
            if not probe_spec.enabled:
                results.append(
                    ProbeResult(
                        name=probe_spec.name,
                        kind=probe_spec.kind,
                        verdict=Verdict.SKIP,
                        score=None,
                        message="disabled in spec",
                    )
                )
                continue

            # Refresh ctx.downstream_kinds for this probe. The current
            # null_adapter probe wants to know which probe kinds it's
            # calibrating *for* — that's the kinds in the suite after itself.
            downstream_kinds = tuple(k for k in suite_kinds[idx + 1 :] if k)
            ctx = RunContext(
                backend=ctx.backend,
                seed=ctx.seed,
                top_k=ctx.top_k,
                sections=ctx.sections,
                doc_text=ctx.doc_text,
                null_stats=ctx.null_stats,
                null_stats_by_rank=ctx.null_stats_by_rank,
                downstream_kinds=downstream_kinds,
            )

            # Label trace events with the currently-running probe so the
            # JSONL is filterable by probe name. Skipped when backend is
            # None (S25 pre-run-only suite path).
            if backend is not None:
                _set_backend_probe_label(backend, probe_spec.name)

            t0 = time.perf_counter()
            try:
                result = probe.run(probe_spec, ctx)
            except ProbeError as exc:
                result = ProbeResult(
                    name=probe_spec.name,
                    kind=probe_spec.kind,
                    verdict=Verdict.ERROR,
                    score=None,
                    message=str(exc),
                )
            except Exception as exc:  # noqa: BLE001 — probe impls may raise anything
                result = ProbeResult(
                    name=probe_spec.name,
                    kind=probe_spec.kind,
                    verdict=Verdict.ERROR,
                    score=None,
                    message=f"{type(exc).__name__}: {exc}",
                )
            duration = time.perf_counter() - t0
            # Re-stamp duration (probes don't know their own wall time).
            result = _with_duration(result, duration)
            results.append(result)

            # Null-adapter result seeds ctx.null_stats for subsequent probes.
            if isinstance(probe_spec, NullAdapterSpec) and result.evidence.get("null_stats"):
                null_stats.update(result.evidence["null_stats"])
                null_stats_by_rank.update(result.evidence.get("null_stats_by_rank") or {})
                # The dataclass is frozen, but the dict was previously
                # passed by reference — a probe could have mutated stats
                # other probes consume. Wrap in MappingProxyType so the
                # contract matches the docstring (B21).
                ctx = RunContext(
                    backend=ctx.backend,
                    seed=ctx.seed,
                    top_k=ctx.top_k,
                    sections=ctx.sections,
                    doc_text=ctx.doc_text,
                    null_stats=MappingProxyType(null_stats),
                    null_stats_by_rank=MappingProxyType(null_stats_by_rank),
                )
    finally:
        # Clear the label on every exit from the loop. Only ``probe.run`` is
        # guarded by the per-probe ``try``; anything else raising in the loop
        # body would otherwise leave the last probe's name on a backend that
        # the caller may reuse for another suite.
        if backend is not None:
            _set_backend_probe_label(backend, None)

    return _suite_result(results, null_stats)
262
263
264 def _install_trace_writer(backend: DifferentialBackend, trace_path: Path | None) -> None:
265 """Attach a :class:`TraceWriter` to the backend's instrumentation.
266
267 Silent no-op when the backend doesn't expose ``_inst`` (custom
268 backends). When ``trace_path`` is ``None``, the existing no-op
269 tracer stays in place.
270 """
271 if trace_path is None:
272 return
273 inst = getattr(backend, "_inst", None)
274 if inst is None:
275 return
276 from dlm_sway.backends._instrumentation import TraceWriter
277
278 inst.trace = TraceWriter(trace_path)
279
280
281 def _set_backend_probe_label(backend: DifferentialBackend, name: str | None) -> None:
282 inst = getattr(backend, "_inst", None)
283 if inst is None:
284 return
285 inst.set_current_probe(name)
286
287
288 def _snapshot_backend_stats(backend: DifferentialBackend | None) -> dict[str, float | int]:
289 """Copy the backend's counters into a plain dict for ``SuiteResult``.
290
291 Empty dict when backend is None (S25 pre-run-only suite path).
292 """
293 if backend is None:
294 return {}
295 inst = getattr(backend, "_inst", None)
296 if inst is None:
297 return {}
298 return dict(inst.stats.to_dict())
299
300
def _with_duration(result: ProbeResult, duration: float) -> ProbeResult:
    """Build a new ``ProbeResult`` equal to ``result`` except for ``duration_s``.

    Fields are carried over by name, one by one. Any ``ProbeResult`` field
    missing from ``carried_over`` is not copied, so the returned result
    falls back to that field's default for the whole runner path. This is
    how F01 (bootstrap ``ci_95`` stripped from every probe) landed before
    the audit caught it, so the tuple must list every field.
    """
    carried_over = (
        "name",
        "kind",
        "verdict",
        "score",
        "raw",
        "z_score",
        "base_value",
        "ft_value",
        "evidence",
        "message",
        "ci_95",
    )
    copied = {field: getattr(result, field) for field in carried_over}
    # ``duration_s`` is the only field taken from the runner, not the probe.
    return ProbeResult(duration_s=duration, **copied)
323
324
# Public names of this module. ``get_null_stats`` is not defined here: it is
# imported from ``dlm_sway.probes.null_adapter`` above and re-exported.
__all__ = ["get_null_stats", "run"]