1 """B1 SectionInternalizationScore — the flagship attribution primitive.
2
3 For each typed section of the training document, measure *how much the
4 fine-tune moved the needle on that section's own content* — and subtract
5 the same metric measured on *other* sections' content. The difference is
6 the "effective SIS": signal attributable to *this* section, not to a
7 broader lift across the whole document.
8
Output is a per-section bar chart. In practice a user sees, for example,
that sections 2 and 7 actually moved the model, sections 3 and 5 did
nothing, and section 11 moved it but also leaked into unrelated content —
actionable signal for document authoring that no other eval tool provides.

Math per section ``s`` with measurement function ``m(probe_set)``:

.. math::

    sis_s^{own}  &= (m_{base}(s) - m_{ft}(s)) / m_{base}(s) \\\\
    sis_s^{leak} &= (m_{base}(\\bar s) - m_{ft}(\\bar s)) / m_{base}(\\bar s) \\\\
    effective_s  &= sis_s^{own} - sis_s^{leak}

For PROSE sections, ``m`` is the average NLL per token over the
section's content. For INSTRUCTION and PREFERENCE sections, ``m`` is the
average NLL per token over the answer/chosen spans given their prompts.
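
Worked example (illustrative numbers): if a section's base NLL is 2.0
and its fine-tuned NLL is 1.5, its own lift is (2.0 - 1.5) / 2.0 = 0.25;
if the other sections' mean NLL only dropped from 2.0 to 1.9, the leak
lift is 0.05 and the effective SIS is 0.25 - 0.05 = 0.20.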
24 """
25
26 from __future__ import annotations
27
28 import statistics
29 from typing import Literal
30
31 from pydantic import Field
32
33 from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize
34 from dlm_sway.core.scoring import ScoringBackend
35 from dlm_sway.core.sections import Section, SectionKind
36 from dlm_sway.core.stats import bootstrap_ci
37 from dlm_sway.probes._zscore import (
38 no_calibration_note,
39 score_from_z,
40 verdict_from_z,
41 z_score,
42 z_scores_by_rank,
43 )
44 from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
45 from dlm_sway.probes.null_adapter import get_null_stats, get_null_stats_by_rank
46
47
48 def _default_include_kinds() -> list[SectionKind]:
49 return ["prose", "instruction", "preference"]
50
51
52 class SectionInternalizationSpec(ProbeSpec):
53 kind: Literal["section_internalization"] = "section_internalization"
54 include_kinds: list[SectionKind] = Field(default_factory=_default_include_kinds)
55 per_section_threshold: float = 0.05
56 """Minimum ``effective_sis`` for a section to be marked PASS."""
57 assert_passing_section_frac: float = 0.5
58 """Probe-level pass criterion: fraction of sections that must clear
59 the per-section threshold."""
60 assert_z_gte: float = 3.0
61 """Z-score pass criterion against the null-adapter baseline, when it
62 exists. Preferred over the raw threshold. The statistic z-scored is
63 the mean ``effective_sis`` across sections."""
    max_prose_chars: int = 2000
    """Cap on the length of PROSE content we score, to keep runtime
    bounded; longer content is truncated to this many characters."""
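

# A minimal spec sketch (illustrative values; the field names are the ones
# defined above):
#
#     SectionInternalizationSpec(
#         name="sis",
#         include_kinds=["prose", "instruction"],
#         per_section_threshold=0.05,
#         assert_z_gte=3.0,
#     )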


class SectionInternalizationProbe(Probe):
    kind = "section_internalization"
    spec_cls = SectionInternalizationSpec
    category = "attribution"

    @classmethod
    def calibrate_spec(cls, ctx: RunContext) -> SectionInternalizationSpec | None:
        # Needs sections; if the bridge didn't populate them, opt out.
        if ctx.sections is None or len(ctx.sections) < 2:
            return None
        return SectionInternalizationSpec(
            name="_calibration",
            kind="section_internalization",
            per_section_threshold=0.05,
        )

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        assert isinstance(spec, SectionInternalizationSpec)
        if ctx.sections is None or len(ctx.sections) == 0:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message="no sections in context — provide via the .dlm bridge",
            )

        kinds_allowed = set(spec.include_kinds)
        eligible = [s for s in ctx.sections if s.kind in kinds_allowed]
        if len(eligible) < 2:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=(
                    f"need ≥2 eligible sections for leak-check; got {len(eligible)} "
                    f"(kinds={spec.include_kinds})"
                ),
            )

        # Pre-compute per-section base and ft NLL-per-token to avoid
        # re-running the forward pass for leak-checks.
        base_nll: dict[str, float] = {}
        ft_nll: dict[str, float] = {}
        with ctx.require_backend.as_base() as base_view:
            for s in eligible:
                base_nll[s.id] = _section_nll(s, base_view, spec.max_prose_chars)
        with ctx.require_backend.as_finetuned() as ft_view:
            for s in eligible:
                ft_nll[s.id] = _section_nll(s, ft_view, spec.max_prose_chars)

        per_section: list[dict[str, float | str | bool]] = []
        passing = 0
        effective_scores: list[float] = []
        for s in eligible:
            others = [o for o in eligible if o.id != s.id]
            own_lift = _relative_lift(base_nll[s.id], ft_nll[s.id])
            leak_lift = statistics.fmean(
                _relative_lift(base_nll[o.id], ft_nll[o.id]) for o in others
            )
            effective = own_lift - leak_lift
            effective_scores.append(effective)
            did_pass = effective >= spec.per_section_threshold
            passing += int(did_pass)
            per_section.append(
                {
                    "section_id": s.id,
                    "kind": s.kind,
                    "tag": s.tag or "",
                    "base_nll": base_nll[s.id],
                    "ft_nll": ft_nll[s.id],
                    "own_lift": own_lift,
                    "leak_lift": leak_lift,
                    "effective_sis": effective,
                    "passed": did_pass,
                }
            )

        passing_frac = passing / len(eligible)
        raw_mean = statistics.fmean(effective_scores)
        ci_95 = bootstrap_ci(effective_scores, seed=ctx.seed)

        # Null-adapter calibration wins when available.
        stats = get_null_stats(ctx, spec.kind)
        z = z_score(raw_mean, stats)
        z_by_rank = z_scores_by_rank(raw_mean, get_null_stats_by_rank(ctx, spec.kind), sign=+1)
        verdict_z = verdict_from_z(z, spec.assert_z_gte)
        if verdict_z is not None:
            verdict = verdict_z
            score_val = score_from_z(z)
            score = score_val if score_val is not None else 0.0
            message = (
                f"{passing}/{len(eligible)} sections cleared; "
                f"mean effective_sis={raw_mean:+.3f}, z={z:+.2f}σ vs null"
            )
        else:
            verdict = (
                Verdict.PASS if passing_frac >= spec.assert_passing_section_frac else Verdict.FAIL
            )
            score = passing_frac
            message = (
                f"{passing}/{len(eligible)} sections cleared "
                f"effective_sis≥{spec.per_section_threshold:.2f} "
                f"(mean={raw_mean:+.3f}) {no_calibration_note(spec.kind)}"
            )

        return safe_finalize(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=raw_mean,
            z_score=z,
            evidence={
                "per_section": per_section,
                "num_sections": len(eligible),
                "passing_frac": passing_frac,
                "per_section_threshold": spec.per_section_threshold,
                "weight": spec.weight,
                "z_by_rank": z_by_rank,
                "raw_ci_95": list(ci_95) if ci_95 is not None else None,
            },
            message=message,
            ci_95=ci_95,
        )


def _section_nll(s: Section, view: ScoringBackend, max_prose_chars: int) -> float:
    """Average NLL per token for the section's content under ``view``."""
    if s.kind == "prose":
        return _prose_nll(s.content[:max_prose_chars], view)
    if s.kind == "instruction":
        if not s.probes:
            return _prose_nll(s.content[:max_prose_chars], view)
        return statistics.fmean(
            -view.logprob_of(p.prompt, p.gold) / _token_estimate(p.gold) for p in s.probes
        )
    if s.kind == "preference":
        if not s.preferences:
            return _prose_nll(s.content[:max_prose_chars], view)
        return statistics.fmean(
            -view.logprob_of(p.prompt, p.chosen) / _token_estimate(p.chosen)
            for p in s.preferences
        )
    raise ValueError(f"unknown section kind: {s.kind!r}")


def _prose_nll(text: str, view: ScoringBackend) -> float:
    """Negative mean logprob over ``text``. Returns 0.0 for empty or
    whitespace-only input."""
    if not text.strip():
        return 0.0
    r = view.rolling_logprob(text)
    return -r.mean_logprob


def _relative_lift(base_nll: float, ft_nll: float) -> float:
    """``(base - ft) / base``. Positive → the fine-tune has lower NLL
    (hence lower perplexity) than base.

    Falls back to an absolute delta when ``base`` is pathological
    (zero or negative), so the probe doesn't crash on degenerate
    inputs.
    """
    if base_nll <= 0.0:
        return float(base_nll - ft_nll)
    return float((base_nll - ft_nll) / base_nll)


def _token_estimate(s: str) -> int:
    """Approximate tokens for normalization. Good enough for SentencePiece-ish vocabs."""
    return max(1, len(s) // 4)