Python · 11695 bytes Raw Blame History
1 """Probe abstract base + per-kind registry.
2
3 The registry is the extension point. Adding a new probe means:
4
5 1. Subclass :class:`ProbeSpec` with a unique ``kind`` field (Literal).
6 2. Subclass :class:`Probe` setting ``kind`` and ``spec_cls``.
3. Import the probe module at least once (the ``__init_subclass__`` hook
   registers the subclass).
9
10 The runner uses :func:`build_probe` to map each raw spec dict to a
11 ``(Probe, ProbeSpec)`` pair. Validation errors are turned into
12 :class:`~dlm_sway.core.errors.SpecValidationError` with the probe name
13 as the source so error messages localize to the offending entry.
14 """
15
16 from __future__ import annotations
17
18 from abc import ABC, abstractmethod
19 from collections.abc import Mapping
20 from dataclasses import dataclass, field
21 from typing import Any, ClassVar
22
23 from pydantic import BaseModel, ConfigDict, ValidationError
24
25 from dlm_sway.core.errors import SpecValidationError
26 from dlm_sway.core.result import ProbeResult
27 from dlm_sway.core.scoring import DifferentialBackend
28 from dlm_sway.core.sections import Section
29
30
class ProbeSpec(BaseModel):
    """Base spec shared by every probe entry in ``sway.yaml``.

    Instances are immutable (``frozen=True``) and reject unknown keys
    (``extra="forbid"``), so a misspelled field fails validation instead
    of being silently dropped.
    """

    model_config = ConfigDict(frozen=True, extra="forbid")

    name: str
    """Label for this entry; must be unique within a suite and is shown in the report."""
    kind: str
    """Discriminator selecting the registered :class:`Probe` subclass."""
    enabled: bool = True
    """When ``False`` the runner records a :class:`~dlm_sway.core.result.Verdict.SKIP`."""
    weight: float = 1.0
    """Relative weight inside the probe's component (adherence / attribution / …)."""
44
45
46 @dataclass(frozen=True, slots=True)
47 class RunContext:
48 """What a probe can read beyond its own spec.
49
50 Probes should receive exactly what they need and nothing more; fat
51 contexts encourage coupling between unrelated probes.
52
53 Attributes
54 ----------
55 backend:
56 The differential backend holding base + fine-tuned views.
57 seed:
58 Seed for deterministic probe RNGs (paraphrase sampling, etc).
59 top_k:
60 Default truncation for next-token distributions.
61 sections:
62 Optional list of typed sections (populated by the .dlm bridge;
63 ``None`` when sway is invoked against bare HF+PEFT).
64 doc_text:
65 Raw document text, if available.
66 null_stats:
67 Null-adapter baseline stats for z-score calibration, keyed by
68 probe *kind*. Populated by the runner after it's executed the
69 ``null_adapter`` probe (if configured). Typed as a read-only
70 :class:`~collections.abc.Mapping` and constructed from a
71 ``MappingProxyType`` so probes can't accidentally mutate the
72 stats other probes will consume.
73
74 When ``null_adapter`` was run with multiple ``rank_multipliers``,
75 this field carries the 1.0x group (or the first multiplier when
76 1.0 isn't present) for back-compat with probes that consume a
77 single calibration level. For the full rank profile, consult
78 :attr:`null_stats_by_rank`.
79 null_stats_by_rank:
80 Null-adapter stats across every rank multiplier the
81 ``null_adapter`` probe calibrated. Outer key is the canonical
82 rank label (``"rank_1.00"``, ``"rank_0.50"``, …); inner key is
83 the probe kind. Empty when a single-rank calibration ran —
84 probes that render a rank profile should check length > 1
85 before surfacing it.
86 downstream_kinds:
87 Tuple of probe kinds that appear *after* the current probe in
88 the suite. Populated by the runner before each probe runs;
89 ``NullAdapterProbe`` consults it to decide which probe kinds
90 to calibrate per-kind null stats for.
91 """
92
93 backend: DifferentialBackend | None = None
94 """The model-scoring backend. Required for every probe with
95 ``needs_backend=True`` (the default). Pre-run probes
96 (``needs_backend=False``, e.g. S25 ``gradient_ghost``) tolerate
97 ``None`` here so the runner can skip backend construction
98 entirely when only pre-flight probes are scheduled.
99
100 Existing probes access ``self.require_backend`` instead of
101 ``backend`` directly — the property narrows the type for mypy
102 and gives a clear runtime error if the runner ever passes
103 ``None`` to a probe that needs the backend.
104 """
105 seed: int = 0
106 top_k: int = 256
107 sections: tuple[Section, ...] | None = None
108 doc_text: str | None = None
109 null_stats: Mapping[str, Mapping[str, float]] = field(default_factory=dict)
110 null_stats_by_rank: Mapping[str, Mapping[str, Mapping[str, float]]] = field(
111 default_factory=dict
112 )
113 downstream_kinds: tuple[str, ...] = field(default_factory=tuple)
114
115 @property
116 def require_backend(self) -> DifferentialBackend:
117 """Return :attr:`backend`, asserting non-None.
118
119 Probes with ``needs_backend=True`` (default) call this to
120 narrow the type from ``DifferentialBackend | None`` to
121 ``DifferentialBackend``. The runner contract guarantees
122 non-None when scheduling backend-dependent probes; this
123 accessor turns a runner bug into a clear error rather than
124 a confusing AttributeError on ``None.as_base()``.
125 """
126 if self.backend is None:
127 raise RuntimeError(
128 "RunContext.backend is None — probe requires a backend "
129 "(needs_backend=True) but the runner did not provide one. "
130 "If this is a pre-run probe, set needs_backend=False on "
131 "the Probe subclass."
132 )
133 return self.backend
134
135
#: Registered probe classes keyed by their ``kind`` string. Filled in by
#: ``Probe.__init_subclass__`` as probe modules are imported.
_REGISTRY: dict[str, type[Probe]] = {}


#: Generic, LM-agnostic prompts that per-probe ``calibrate_spec()``
#: overrides use as sentinel inputs. Each one is short and
#: content-neutral, and should yield a finite logprob from any sane base
#: model.
SENTINEL_PROMPTS: tuple[str, ...] = (
    "The capital of",
    "Once upon a time",
    "An interesting fact is",
    "The next step is to",
)


#: Fixed reference text for stylistic-fingerprint calibration: two short
#: paragraphs of mid-length, mid-complexity prose, so the fingerprint
#: vector exercises every dimension non-degenerately.
SENTINEL_DOC: str = "\n\n".join(
    (
        "This is a brief reference paragraph. It contains several short sentences. "
        "The vocabulary is plain and the punctuation density is moderate. "
        "A second paragraph follows the first.",
        "Each clause is concise. Together they sample the fingerprint dimensions "
        "a probe needs to compute a meaningful style shift.",
    )
)
160
161
class Probe(ABC):
    """Abstract base for probes; one instance exists per spec in the suite.

    A subclass that assigns ``kind`` in its own class body is entered
    into the module registry as soon as the class is created.
    """

    kind: ClassVar[str]
    """Value of the ``kind`` field that selects this probe in ``sway.yaml``."""
    spec_cls: ClassVar[type[ProbeSpec]]
    """Pydantic model class used to validate this probe's spec."""
    category: ClassVar[str] = "adherence"
    """Scoring component this probe belongs to. One of ``adherence``,
    ``attribution``, ``calibration``, ``ablation``, ``baseline``; drives
    composite scoring."""
    needs_backend: ClassVar[bool] = True
    """Whether this probe ever calls ``ctx.backend``.

    ``True`` by default — every probe shipped through Sprint 24 reads
    next-token distributions or logprobs from a backend view, so the
    runner builds the backend before any probe runs.

    Use ``False`` for **pre-run diagnostic probes** that only consume
    artifacts on disk (e.g. S25's ``gradient_ghost``, which loads
    ``training_state.pt``). If *every* scheduled probe has
    ``needs_backend=False`` the runner skips backend construction
    altogether — no model load, no GPU memory, no cold-start latency.
    Pre-flight verdicts can then short-circuit a slow suite when the
    adapter is obviously broken.
    """
    batch_score: ClassVar[bool] = False
    """Whether this probe scores prompts through the backend's batched
    :meth:`~dlm_sway.core.scoring.ScoringBackend.next_token_dist_batch`
    path (S23).

    With ``True``, ``run()`` issues one batched call per view (typically
    inside a ``with backend.as_base()`` block) rather than looping over
    prompts one at a time. Backends with real batching (HF) amortize
    kernel-launch overhead; backends without it (dummy, MLX for now)
    behave identically via the Protocol's default loop.

    Opt-in. Enable it only when every prompt goes through a uniform
    ``next_token_dist`` call (delta_kl, cluster_kl). Probes that
    parameterize per-prompt context (prompt_collapse) or interleave
    backend toggles per prompt (adapter_ablation) stay ``False`` — they
    need bespoke batching logic deferred to a follow-up sprint.
    """

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Register ``cls`` under its ``kind`` when it declares one.

        Raises
        ------
        ValueError
            If another class is already registered under the same kind.
        """
        super().__init_subclass__(**kwargs)
        # Only classes that set `kind` in their own body are registered;
        # an inherited value does not count.
        if "kind" in cls.__dict__:
            label = cls.kind
            clash = _REGISTRY.get(label)
            if clash is not None:
                raise ValueError(f"duplicate probe kind {label!r}: {clash!r} vs {cls!r}")
            _REGISTRY[label] = cls

    @abstractmethod
    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Execute the probe for ``spec`` using ``ctx`` and return its result."""

    @classmethod
    def calibrate_spec(cls, ctx: RunContext) -> ProbeSpec | None:
        """Return a small spec for null-adapter calibration, or ``None``.

        ``NullAdapterProbe`` invokes this once for each kind listed in
        ``ctx.downstream_kinds`` to harvest that kind's null
        distribution. ``None`` opts the kind out of calibration:
        downstream probes of that kind fall back to their
        fixed-threshold paths and show ``(no calibration)`` in the
        report.

        The base implementation always returns ``None``. Probes that
        *can* be calibrated override it with a small spec (typically 4
        sentinel prompts, minimal compute) suitable for running against
        the
        :class:`~dlm_sway.probes._null_proxy.NullCalibrationBackendProxy`.
        """
        # The default ignores the context; drop the name explicitly.
        del ctx
        return None
237
238
def registry() -> dict[str, type[Probe]]:
    """Return a snapshot of the registered probes, keyed by kind.

    The result is a shallow copy, so mutating it leaves the live
    registry untouched.
    """
    return _REGISTRY.copy()
242
243
def build_probe(raw: dict[str, Any]) -> tuple[Probe, ProbeSpec]:
    """Validate a raw YAML probe entry and return ``(Probe instance, spec)``.

    Parameters
    ----------
    raw:
        One probe entry as loaded from the suite file. Its ``kind``
        selects the registered probe class; the whole mapping is then
        validated against that class's ``spec_cls``.

    Returns
    -------
    tuple[Probe, ProbeSpec]
        A new instance of the registered probe class and the validated
        spec.

    Raises
    ------
    SpecValidationError
        If ``raw`` is not a mapping, has no string ``kind``, names an
        unregistered kind, or fails pydantic validation. ``source`` is
        the entry's ``name`` when present, else ``"<unknown>"``.
    """
    # A scalar or list entry (e.g. ``- delta_kl`` instead of a mapping)
    # has no ``.get``; report it as a spec error rather than letting an
    # AttributeError escape.
    if not isinstance(raw, Mapping):
        raise SpecValidationError(
            f"probe entry must be a mapping, got {type(raw).__name__}",
            source="<unknown>",
        )

    # Every failure below is attributed to the same entry label.
    source = str(raw.get("name", "<unknown>"))

    kind = raw.get("kind")
    if not isinstance(kind, str):
        raise SpecValidationError("probe entry missing string 'kind' field", source=source)

    probe_cls = _REGISTRY.get(kind)
    if probe_cls is None:
        known = ", ".join(sorted(_REGISTRY))
        raise SpecValidationError(
            f"unknown probe kind {kind!r} (registered: {known})",
            source=source,
        )

    try:
        spec = probe_cls.spec_cls.model_validate(raw)
    except ValidationError as exc:
        raise SpecValidationError(str(exc), source=source) from exc
    return probe_cls(), spec
264
265
def validate_all_probes(suite: list[dict[str, Any]]) -> None:
    """Run :func:`build_probe` against every entry; collect all errors.

    The runner / CLI calls this before constructing the backend (which
    is the slow, network-touching step). A typo in any ``kind:`` field
    therefore surfaces before any model is loaded — and *every* typo in
    the spec surfaces in a single error message instead of forcing the
    user to fix one, re-run, fix the next, re-run (B7).

    Parameters
    ----------
    suite:
        The raw probe entries, in suite order.

    Raises
    ------
    SpecValidationError
        If at least one entry is invalid. The message has one line per
        failing entry, labelled by its ``name`` or, when that is missing
        or empty, by ``entry #<index>``.
    """
    errors: list[str] = []
    for idx, raw in enumerate(suite):
        # A non-mapping entry (bare string, list, ``None``) has no
        # ``.get``. Record it like any other bad entry so the pass still
        # reports everything instead of dying on an AttributeError.
        if not isinstance(raw, Mapping):
            errors.append(
                f" - entry #{idx}: probe entry must be a mapping, got {type(raw).__name__}"
            )
            continue
        try:
            build_probe(raw)
        except SpecValidationError as exc:
            label = raw.get("name") or f"entry #{idx}"
            errors.append(f" - {label}: {exc}")
    if errors:
        raise SpecValidationError("spec contains invalid probe entries:\n" + "\n".join(errors))