tenseleyflow/sway / 74d94e2

probes: RunContext.backend optional + ctx.backend.as_* → ctx.require_backend.as_* sweep (S25 P4 prep)

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA: 74d94e2d0ce18025ec24be6cd483e4cab092c0d0
Parents: 3021f24
Tree: 3218183

13 changed files

Status  File                                             +   -
M       src/dlm_sway/probes/adapter_ablation.py         7   3
M       src/dlm_sway/probes/adapter_revert.py           2   2
M       src/dlm_sway/probes/base.py                     32   1
M       src/dlm_sway/probes/calibration_drift.py        2   2
M       src/dlm_sway/probes/cluster_kl.py               2   2
M       src/dlm_sway/probes/delta_kl.py                 2   2
M       src/dlm_sway/probes/external_perplexity.py      2   2
M       src/dlm_sway/probes/leakage.py                  1   1
M       src/dlm_sway/probes/paraphrase_invariance.py    2   2
M       src/dlm_sway/probes/preference_flip.py          2   2
M       src/dlm_sway/probes/prompt_collapse.py          2   2
M       src/dlm_sway/probes/section_internalization.py  2   2
M       src/dlm_sway/probes/style_fingerprint.py        2   2
src/dlm_sway/probes/adapter_ablation.py (modified)

@@ -89,7 +89,11 @@ class AdapterAblationProbe(Probe):
                 score=None,
                 message="no prompts provided",
             )
-        if not isinstance(ctx.backend, ScalableDifferentialBackend):
+        # Local binding so mypy keeps the ScalableDifferentialBackend
+        # narrowing across the loop below (require_backend's return
+        # type is the base DifferentialBackend; we narrow once here).
+        scalable = ctx.backend
+        if not isinstance(scalable, ScalableDifferentialBackend):
             return ProbeResult(
                 name=spec.name,
                 kind=spec.kind,
@@ -109,9 +113,9 @@ class AdapterAblationProbe(Probe):
         for lam in spec.lambdas:
             divs_for_lam: list[float] = []
             for prompt in spec.prompts:
-                with ctx.backend.as_scaled_adapter(lam_zero) as ref:
+                with scalable.as_scaled_adapter(lam_zero) as ref:
                     ref_dist = ref.next_token_dist(prompt, top_k=top_k)
-                with ctx.backend.as_scaled_adapter(lam) as scaled:
+                with scalable.as_scaled_adapter(lam) as scaled:
                     scaled_dist = scaled.next_token_dist(prompt, top_k=top_k)
                 divs_for_lam.append(divergence(ref_dist, scaled_dist, kind=spec.divergence))
             per_lambda.append(float(np.mean(divs_for_lam)))
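Why the local binding: mypy will not keep an isinstance() narrowing attached to a property like require_backend, since each access could in principle return a different object, so the probe binds the backend to a local once and narrows that. A minimal standalone sketch of the pattern (the class names here are placeholders, not the real backend types):

class Backend: ...

class ScalableBackend(Backend):
    def as_scaled_adapter(self, lam: float) -> None: ...

def run(backend: Backend | None) -> None:
    scalable = backend  # bind once so the narrowing below sticks
    if not isinstance(scalable, ScalableBackend):
        return
    # mypy now treats `scalable` as ScalableBackend for the rest of
    # the function body, including loops and nested `with` blocks.
    # Note the isinstance check also rules out None.
    scalable.as_scaled_adapter(0.0)
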
src/dlm_sway/probes/adapter_revert.py (modified)

@@ -104,9 +104,9 @@ class AdapterRevertProbe(Probe):
         for case in spec.cases:
             gold_vec = embed([case.gold])[0]
             for pp in case.paraphrases:
-                with ctx.backend.as_base() as bv:
+                with ctx.require_backend.as_base() as bv:
                     base_gen = bv.generate(pp, max_new_tokens=spec.max_new_tokens, seed=ctx.seed)
-                with ctx.backend.as_finetuned() as fv:
+                with ctx.require_backend.as_finetuned() as fv:
                     ft_gen = fv.generate(pp, max_new_tokens=spec.max_new_tokens, seed=ctx.seed)
                 vecs = embed([base_gen, ft_gen])
                 base_vec, ft_vec = vecs[0], vecs[1]
src/dlm_sway/probes/base.py (modified)

@@ -90,7 +90,18 @@ class RunContext:
         to calibrate per-kind null stats for.
     """
 
-    backend: DifferentialBackend
+    backend: DifferentialBackend | None = None
+    """The model-scoring backend. Required for every probe with
+    ``needs_backend=True`` (the default). Pre-run probes
+    (``needs_backend=False``, e.g. S25 ``gradient_ghost``) tolerate
+    ``None`` here so the runner can skip backend construction
+    entirely when only pre-flight probes are scheduled.
+
+    Existing probes access ``self.require_backend`` instead of
+    ``backend`` directly — the property narrows the type for mypy
+    and gives a clear runtime error if the runner ever passes
+    ``None`` to a probe that needs the backend.
+    """
     seed: int = 0
     top_k: int = 256
     sections: tuple[Section, ...] | None = None
@@ -101,6 +112,26 @@ class RunContext:
     )
     downstream_kinds: tuple[str, ...] = field(default_factory=tuple)
 
+    @property
+    def require_backend(self) -> DifferentialBackend:
+        """Return :attr:`backend`, asserting non-None.
+
+        Probes with ``needs_backend=True`` (default) call this to
+        narrow the type from ``DifferentialBackend | None`` to
+        ``DifferentialBackend``. The runner contract guarantees
+        non-None when scheduling backend-dependent probes; this
+        accessor turns a runner bug into a clear error rather than
+        a confusing AttributeError on ``None.as_base()``.
+        """
+        if self.backend is None:
+            raise RuntimeError(
+                "RunContext.backend is None — probe requires a backend "
+                "(needs_backend=True) but the runner did not provide one. "
+                "If this is a pre-run probe, set needs_backend=False on "
+                "the Probe subclass."
+            )
+        return self.backend
+
 
 _REGISTRY: dict[str, type[Probe]] = {}

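For reference, a minimal standalone sketch of the contract these docstrings describe. The GradientGhostProbe stub and the runner loop below are illustrative assumptions; only needs_backend, RunContext, and require_backend are the real hooks from this diff:

from dataclasses import dataclass


class DifferentialBackend:  # stand-in for the real backend type
    ...


@dataclass
class RunContext:
    backend: DifferentialBackend | None = None

    @property
    def require_backend(self) -> DifferentialBackend:
        if self.backend is None:
            raise RuntimeError("probe needs a backend, but the runner passed None")
        return self.backend


class Probe:
    needs_backend: bool = True  # default: the probe scores the model


class GradientGhostProbe(Probe):  # hypothetical pre-run probe
    needs_backend = False  # tolerates backend=None


# Runner side: build the backend only when a scheduled probe needs it.
scheduled = [GradientGhostProbe()]
backend = DifferentialBackend() if any(p.needs_backend for p in scheduled) else None
ctx = RunContext(backend=backend)  # backend is None here, and that is fine

A probe that forgets to set needs_backend=False but runs without a backend now hits the RuntimeError with the remediation message, rather than an AttributeError deep inside None.as_base().
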
src/dlm_sway/probes/calibration_drift.py (modified)

@@ -109,9 +109,9 @@ class CalibrationDriftProbe(Probe):
 
         for prompt, gold in items:
             tokens = max(_token_estimate(gold), 1)
-            with ctx.backend.as_base() as b:
+            with ctx.require_backend.as_base() as b:
                 lp_base = b.logprob_of(prompt, gold) / tokens
-            with ctx.backend.as_finetuned() as f:
+            with ctx.require_backend.as_finetuned() as f:
                 lp_ft = f.logprob_of(prompt, gold) / tokens
             delta = lp_ft - lp_base
             deltas.append(delta)
src/dlm_sway/probes/cluster_kl.py (modified)

@@ -192,9 +192,9 @@ class ClusterKLProbe(Probe):
         # S23 — per-prompt divergences, now via one batched forward
         # per view (same math as ``delta_kl``).
         top_k = spec.top_k if spec.top_k is not None else ctx.top_k
-        with ctx.backend.as_base() as base_view:
+        with ctx.require_backend.as_base() as base_view:
             base_dists = base_view.next_token_dist_batch(list(spec.prompts), top_k=top_k)
-        with ctx.backend.as_finetuned() as ft_view:
+        with ctx.require_backend.as_finetuned() as ft_view:
             ft_dists = ft_view.next_token_dist_batch(list(spec.prompts), top_k=top_k)
         divergences: list[float] = [
             divergence(b, f, kind=spec.divergence)
src/dlm_sway/probes/delta_kl.py (modified)

@@ -93,9 +93,9 @@ class DeltaKLProbe(Probe):
         # manager entries. next_token_dist_batch falls back to a
         # per-prompt loop on backends without real batching so the
         # result list stays identical to the pre-S23 path.
-        with ctx.backend.as_base() as base_view:
+        with ctx.require_backend.as_base() as base_view:
             base_dists = base_view.next_token_dist_batch(list(spec.prompts), top_k=top_k)
-        with ctx.backend.as_finetuned() as ft_view:
+        with ctx.require_backend.as_finetuned() as ft_view:
             ft_dists = ft_view.next_token_dist_batch(list(spec.prompts), top_k=top_k)
         divergences: list[float] = [
             divergence(b, f, kind=spec.divergence)
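The fallback that comment mentions is easy to picture; here is a sketch of what such a base-class default could look like (an assumption about the backend view, not the actual dlm_sway implementation; only the method names come from the diff):

class View:  # stand-in for a backend view (as_base() / as_finetuned() result)
    def next_token_dist(self, prompt: str, top_k: int) -> dict[int, float]:
        raise NotImplementedError

    def next_token_dist_batch(
        self, prompts: list[str], top_k: int
    ) -> list[dict[int, float]]:
        # Default: a plain per-prompt loop, so backends without real
        # batching return exactly the same list, in the same order,
        # as the pre-S23 one-call-per-prompt path.
        return [self.next_token_dist(p, top_k=top_k) for p in prompts]
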
src/dlm_sway/probes/external_perplexity.py (modified)

@@ -147,9 +147,9 @@ class ExternalPerplexityProbe(Probe):
         total_base_lp = 0.0
         total_ft_lp = 0.0
         for chunk in chunks:
-            with ctx.backend.as_base() as b:
+            with ctx.require_backend.as_base() as b:
                 base_rl = b.rolling_logprob(chunk)
-            with ctx.backend.as_finetuned() as f:
+            with ctx.require_backend.as_finetuned() as f:
                 ft_rl = f.rolling_logprob(chunk)
             # Per-token mean logprob for this chunk. ``logprobs.size``
             # is ``num_tokens - 1`` by the RollingLogprob contract.
src/dlm_sway/probes/leakage.py (modified)

@@ -120,7 +120,7 @@ class LeakageSusceptibilityProbe(Probe):
         perturbed_recalls: list[float] = []
         per_section: list[dict[str, float | str]] = []
 
-        with ctx.backend.as_finetuned() as ft:
+        with ctx.require_backend.as_finetuned() as ft:
             for s in prose:
                 prefix = s.content[: spec.prefix_chars]
                 target = s.content[spec.prefix_chars : spec.prefix_chars + spec.continuation_chars]
src/dlm_sway/probes/paraphrase_invariance.py (modified)

@@ -115,10 +115,10 @@ class ParaphraseInvarianceProbe(Probe):
 
         for case in spec.cases:
             tokens = max(_token_estimate(case.gold), 1)
-            with ctx.backend.as_base() as b:
+            with ctx.require_backend.as_base() as b:
                 lp_base_verb = b.logprob_of(case.prompt, case.gold) / tokens
                 lp_base_par = [b.logprob_of(p, case.gold) / tokens for p in case.paraphrases]
-            with ctx.backend.as_finetuned() as f:
+            with ctx.require_backend.as_finetuned() as f:
                 lp_ft_verb = f.logprob_of(case.prompt, case.gold) / tokens
                 lp_ft_par = [f.logprob_of(p, case.gold) / tokens for p in case.paraphrases]
 
src/dlm_sway/probes/preference_flip.py (modified)

@@ -103,11 +103,11 @@ class PreferenceFlipProbe(Probe):
             # batch down. Fence per triple so probes degrade gracefully:
             # drop the offending triple, count it, surface in evidence.
             try:
-                with ctx.backend.as_base() as b:
+                with ctx.require_backend.as_base() as b:
                     base_margin = b.logprob_of(t.prompt, t.chosen) - b.logprob_of(
                         t.prompt, t.rejected
                     )
-                with ctx.backend.as_finetuned() as f:
+                with ctx.require_backend.as_finetuned() as f:
                     ft_margin = f.logprob_of(t.prompt, t.chosen) - f.logprob_of(
                         t.prompt, t.rejected
                     )
src/dlm_sway/probes/prompt_collapse.py (modified)

@@ -100,9 +100,9 @@ class PromptCollapseProbe(Probe):
             divs: list[float] = []
             for prompt in spec.prompts:
                 full_prompt = prefix + prompt
-                with ctx.backend.as_base() as bv:
+                with ctx.require_backend.as_base() as bv:
                     base_dist = bv.next_token_dist(full_prompt, top_k=top_k)
-                with ctx.backend.as_finetuned() as fv:
+                with ctx.require_backend.as_finetuned() as fv:
                     ft_dist = fv.next_token_dist(full_prompt, top_k=top_k)
                 divs.append(divergence(base_dist, ft_dist, kind=spec.divergence))
             mean_divs.append(float(np.mean(divs)))
src/dlm_sway/probes/section_internalization.py (modified)

@@ -111,10 +111,10 @@ class SectionInternalizationProbe(Probe):
         # re-running the forward pass for leak-checks.
         base_nll: dict[str, float] = {}
         ft_nll: dict[str, float] = {}
-        with ctx.backend.as_base() as base_view:
+        with ctx.require_backend.as_base() as base_view:
             for s in eligible:
                 base_nll[s.id] = _section_nll(s, base_view, spec.max_prose_chars)
-        with ctx.backend.as_finetuned() as ft_view:
+        with ctx.require_backend.as_finetuned() as ft_view:
             for s in eligible:
                 ft_nll[s.id] = _section_nll(s, ft_view, spec.max_prose_chars)
 
src/dlm_sway/probes/style_fingerprint.py (modified)

@@ -300,11 +300,11 @@ class StyleFingerprintProbe(Probe):
         base_samples: list[str] = []
         ft_samples: list[str] = []
         for prompt in spec.prompts:
-            with ctx.backend.as_base() as b:
+            with ctx.require_backend.as_base() as b:
                 base_samples.append(
                     b.generate(prompt, max_new_tokens=spec.max_new_tokens, seed=ctx.seed)
                 )
-            with ctx.backend.as_finetuned() as f:
+            with ctx.require_backend.as_finetuned() as f:
                 ft_samples.append(
                     f.generate(prompt, max_new_tokens=spec.max_new_tokens, seed=ctx.seed)
                 )