Python · 9865 bytes Raw Blame History
1 """Probe and suite result types.
2
3 Every numeric probe ultimately returns a :class:`ProbeResult`. The suite
4 runner collects them into a :class:`SuiteResult` and the scorer folds
5 that into a single :class:`SwayScore` with transparent per-component
6 weights.
7
8 These dataclasses are deliberately plain — no pydantic — because they
9 cross probe/backend boundaries hundreds of times per run and a free
10 ``model_validate`` on every construction would dominate the runtime of
11 cheap probes.
12 """
13
14 from __future__ import annotations
15
16 import math
17 from dataclasses import dataclass, field
18 from datetime import UTC, datetime
19 from enum import StrEnum
20 from typing import Any
21
22
23 class Verdict(StrEnum):
24 """Outcome of a single probe against its assertion."""
25
26 PASS = "pass"
27 FAIL = "fail"
28 WARN = "warn"
29 SKIP = "skip"
30 ERROR = "error"
31
32
@dataclass(frozen=True, slots=True)
class ProbeResult:
    """The result of running one probe.

    Attributes
    ----------
    name:
        User-facing name from the spec (unique within a suite).
    kind:
        Probe discriminator (``delta_kl``, ``section_internalization`` …).
    verdict:
        Pass / fail / warn / skip / error.
    score:
        Normalized [0, 1] score. ``sigmoid(z_vs_null / 3)`` for numeric
        probes; 1.0 / 0.0 for binary ones. ``None`` for :attr:`Verdict.SKIP`.
    raw:
        The raw metric value (e.g. KL=0.083). Probe-specific units.
    z_score:
        Standard deviations above the null-adapter baseline. ``None``
        when no null calibration was run.
    base_value:
        The metric evaluated on the base model, when meaningful.
    ft_value:
        The metric evaluated on the fine-tuned model, when meaningful.
    evidence:
        Small structured payload for the report — prompts, example
        completions, per-section breakdowns. Kept bounded (<10 KB) so
        suite JSON stays under a megabyte.
    message:
        One-line diagnostic. Surfaces in the terminal report.
    duration_s:
        Wall time to execute.
    ci_95:
        95% percentile-bootstrap confidence interval on :attr:`raw`.
        Populated by aggregating probes (``delta_kl`` over N prompts,
        ``calibration_drift`` over pack items, etc.) via
        :func:`dlm_sway.core.stats.bootstrap_ci`. ``None`` when the
        probe doesn't aggregate (``adapter_ablation``,
        ``style_fingerprint``), when the sample count is too low to
        bootstrap meaningfully, or when the raw value isn't finite.
        Report surfaces it inline as ``raw [lo–hi]``.

    Notes
    -----
    Instances are immutable (``frozen=True, slots=True``). Numeric
    probes are expected to construct results through
    :func:`safe_finalize`, which guards the scalar fields against
    non-finite (NaN / ±inf) values.
    """

    name: str
    kind: str
    verdict: Verdict
    score: float | None
    raw: float | None = None
    z_score: float | None = None
    base_value: float | None = None
    ft_value: float | None = None
    evidence: dict[str, Any] = field(default_factory=dict)
    message: str = ""
    duration_s: float = 0.0
    ci_95: tuple[float, float] | None = None
88
89
@dataclass(frozen=True, slots=True)
class DeterminismReport:
    """Serializable view of what seeding the runner accomplished.

    Mirrors :class:`dlm_sway.core.determinism.DeterminismSummary` but
    lives here so :class:`SuiteResult` doesn't pull ``determinism`` as
    an import-time dependency of ``core.result``.

    Attributes
    ----------
    class_:
        Classification label for the determinism regime the suite ran
        under. Trailing underscore because ``class`` is a reserved word
        in Python.
    seed:
        The seed the runner applied.
    notes:
        Supplementary notes about the regime, empty by default.
    """

    class_: str
    seed: int
    notes: tuple[str, ...] = ()
102
103
@dataclass(frozen=True, slots=True)
class SuiteResult:
    """A full run of a sway.yaml suite.

    Aggregates every :class:`ProbeResult` from one suite execution,
    plus the run metadata needed to report it: spec path, base-model
    and adapter identity, sway version, start/finish timestamps,
    null-calibration stats, and backend counters.
    """

    spec_path: str
    started_at: datetime
    finished_at: datetime
    base_model_id: str
    adapter_id: str
    sway_version: str
    probes: tuple[ProbeResult, ...] = ()
    null_stats: dict[str, dict[str, float]] = field(default_factory=dict)
    """Per-primitive null-adapter baseline stats (mean, std, runs). Used
    to turn raw metrics into z-scores when rendering the report."""
    determinism: DeterminismReport | None = None
    """Classification of the determinism regime the suite ran under, from
    :func:`dlm_sway.core.determinism.seed_everything`. ``None`` when the
    caller bypassed seeding (e.g., unit tests constructing a
    ``SuiteResult`` directly)."""
    backend_stats: dict[str, float | int] = field(default_factory=dict)
    """Forward-pass + cache counters from the backend
    (:class:`dlm_sway.backends._instrumentation.BackendStats.to_dict`).
    Populated by the runner at suite-end; ``{}`` when the backend
    doesn't expose instrumentation (custom backends, pre-S07 snapshots)."""

    @property
    def wall_seconds(self) -> float:
        """Total wall-clock duration of the run, in seconds."""
        return (self.finished_at - self.started_at).total_seconds()
132
133
# Component weights for the composite score. Overridable in sway.yaml.
# ``baseline`` carries weight 0.0: the null-calibration row still shows
# up in the report for transparency but it is an informational
# category, not a judgment one, and contributes nothing to the
# composite.
DEFAULT_COMPONENT_WEIGHTS: dict[str, float] = {
    component: weight
    for component, weight in (
        ("adherence", 0.30),
        ("attribution", 0.35),
        ("calibration", 0.20),
        ("ablation", 0.15),
        ("baseline", 0.0),
    )
}
145
146
147 @dataclass(frozen=True, slots=True)
148 class SwayScore:
149 """Composite score with a transparent per-component breakdown."""
150
151 overall: float
152 components: dict[str, float]
153 weights: dict[str, float] = field(default_factory=lambda: dict(DEFAULT_COMPONENT_WEIGHTS))
154 band: str = ""
155 findings: tuple[str, ...] = ()
156
157 @staticmethod
158 def band_for(overall: float) -> str:
159 """Map a score to a human-readable band.
160
161 Bands (from the plan):
162 - <0.3 : indistinguishable from noise
163 - 0.3–0.6 : partial fit
164 - 0.6–0.85: healthy
165 - >0.85 : suspiciously good (possible overfit / memorization)
166 """
167 if overall < 0.3:
168 return "noise"
169 if overall < 0.6:
170 return "partial"
171 if overall <= 0.85:
172 return "healthy"
173 return "suspicious"
174
175
176 def utcnow() -> datetime:
177 """Timezone-aware UTC timestamp (used by the runner)."""
178 return datetime.now(UTC)
179
180
def safe_finalize(
    *,
    name: str,
    kind: str,
    verdict: Verdict,
    score: float | None = None,
    raw: float | None = None,
    z_score: float | None = None,
    base_value: float | None = None,
    ft_value: float | None = None,
    evidence: dict[str, Any] | None = None,
    message: str = "",
    duration_s: float = 0.0,
    ci_95: tuple[float, float] | None = None,
    critical_fields: tuple[str, ...] = ("raw",),
) -> ProbeResult:
    """Build a :class:`ProbeResult` with defense against non-finite metrics.

    Probes hand their candidate result kwargs here instead of constructing
    a :class:`ProbeResult` directly. The helper inspects every numeric
    field and classifies it:

    - **Critical field non-finite** (any field named in ``critical_fields``
      whose value is ``NaN`` or ``±inf``): the whole probe result is
      converted to :attr:`Verdict.ERROR` with all scalar fields nulled out,
      the offending values are preserved under
      ``evidence["non_finite_inputs"]``, and the message explains which
      field(s) were non-finite.
    - **Non-critical field non-finite**: nulled out silently (set to
      ``None``), and the field name appended to
      ``evidence["defensively_nulled"]`` so a report reader can see what
      happened.
    - **Everything finite**: passthrough, no change.

    ``ci_95`` gets the same treatment as the scalar fields: an interval
    with any ``NaN``/``±inf`` bound is dropped (set to ``None``) and
    recorded under ``evidence["defensively_nulled"]``. A CI is also
    never attached to a result whose ``raw`` was nulled out.

    The default ``critical_fields = ("raw",)`` reflects the design stance:
    ``raw`` is the probe's ground-truth metric; a non-finite ``raw`` means
    the probe cannot make a meaningful statement. Probes that care about
    other fields (e.g., probes whose ``z_score`` is load-bearing) pass a
    broader tuple.

    This helper is the single shared guardrail sprint 01 installs against
    the +11639σ class of bug, where NaN logprobs flowed silently through
    to a PASS verdict. Every numeric probe is expected to finalize through
    this function.
    """
    numeric_kwargs: dict[str, float | None] = {
        "score": score,
        "raw": raw,
        "z_score": z_score,
        "base_value": base_value,
        "ft_value": ft_value,
    }

    # Collect every scalar that is numeric but not finite. ``bool`` is
    # excluded explicitly: it is an ``int`` subclass and always finite,
    # so it can never be non-finite, and we don't want to coerce
    # True/False through ``float()`` here.
    non_finite: dict[str, float] = {}
    for fname, v in numeric_kwargs.items():
        if isinstance(v, int | float) and not isinstance(v, bool) and not math.isfinite(float(v)):
            non_finite[fname] = float(v)

    # Copy caller-supplied evidence so we never mutate their dict.
    ev: dict[str, Any] = dict(evidence) if evidence is not None else {}

    critical_non_finite = {k: v for k, v in non_finite.items() if k in critical_fields}
    if critical_non_finite:
        # A non-finite critical field invalidates the whole result:
        # degrade to ERROR, null every scalar, and keep the offending
        # raw values around for the report.
        ev["non_finite_inputs"] = non_finite
        return ProbeResult(
            name=name,
            kind=kind,
            verdict=Verdict.ERROR,
            score=None,
            raw=None,
            z_score=None,
            base_value=None,
            ft_value=None,
            evidence=ev,
            message=(
                f"non-finite critical field(s): {', '.join(sorted(critical_non_finite))} "
                f"— probe cannot produce a meaningful result"
            ),
            duration_s=duration_s,
        )

    if non_finite:
        ev.setdefault("defensively_nulled", []).extend(sorted(non_finite))
        for fname in non_finite:
            numeric_kwargs[fname] = None

    # Fix: the original sweep covered only the five scalar fields, so a
    # CI with a NaN/±inf bound flowed through to the report untouched.
    # Give ``ci_95`` the same defensive-null treatment with the same
    # bookkeeping.
    if ci_95 is not None and not all(
        isinstance(bound, int | float)
        and not isinstance(bound, bool)
        and math.isfinite(float(bound))
        for bound in ci_95
    ):
        ev.setdefault("defensively_nulled", []).append("ci_95")
        ci_95 = None

    # ``ci_95`` is only attached when ``raw`` survived the
    # defensive-null sweep — a CI bracketing a nulled-out point
    # estimate would mislead more than it informs.
    final_ci_95 = ci_95 if numeric_kwargs["raw"] is not None else None

    return ProbeResult(
        name=name,
        kind=kind,
        verdict=verdict,
        score=numeric_kwargs["score"],
        raw=numeric_kwargs["raw"],
        z_score=numeric_kwargs["z_score"],
        base_value=numeric_kwargs["base_value"],
        ft_value=numeric_kwargs["ft_value"],
        evidence=ev,
        message=message,
        duration_s=duration_s,
        ci_95=final_ci_95,
    )
284 )