1 """B1 SectionInternalizationScore — the flagship attribution primitive.
2
3 For each typed section of the training document, measure *how much the
4 fine-tune moved the needle on that section's own content* — and subtract
5 the same metric measured on *other* sections' content. The difference is
6 the "effective SIS": signal attributable to *this* section, not to a
7 broader lift across the whole document.
8
9 Output is a per-section bar chart. In practice users see that sections
10 2 and 7 actually moved the model, sections 3 and 5 did nothing, and
11 section 11 moved it but also leaked into unrelated content — actionable
12 signal for document authoring that no other eval tool provides.
13
14 Math per section ``s`` with measurement function ``m(probe_set)``:
15
16 .. math::
17 sis_s^{own} &= (m_{base}(s) - m_{ft}(s)) / m_{base}(s)
18 sis_s^{leak} &= (m_{base}(\\bar s) - m_{ft}(\\bar s)) / m_{base}(\\bar s)
19 effective &= sis_s^{own} - sis_s^{leak}
20
21 For PROSE sections, ``m`` is the average NLL per token over the
22 section's content. For INSTRUCTION and PREFERENCE sections, ``m`` is the
23 average NLL per token over the answer/chosen spans given their prompts.
24 """
25
26 from __future__ import annotations
27
28 import statistics
29 from typing import Literal
30
31 from pydantic import Field
32
33 from dlm_sway.core.result import ProbeResult, Verdict
34 from dlm_sway.core.scoring import ScoringBackend
35 from dlm_sway.core.sections import Section, SectionKind
36 from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
37
38
39 def _default_include_kinds() -> list[SectionKind]:
40 return ["prose", "instruction", "preference"]
41
42
class SectionInternalizationSpec(ProbeSpec):
    """Configuration for the section-internalization (SIS) probe.

    Controls which section kinds are eligible, the per-section pass
    threshold on ``effective_sis``, the probe-level pass fraction, and the
    PROSE scoring length cap. Field order and the bare-string field
    docstrings are part of the pydantic schema — keep them in place.
    """

    # Discriminator used by the probe registry to route specs to this probe.
    kind: Literal["section_internalization"] = "section_internalization"
    # Section kinds considered for scoring; default_factory avoids a shared
    # mutable default across spec instances.
    include_kinds: list[SectionKind] = Field(default_factory=_default_include_kinds)
    per_section_threshold: float = 0.05
    """Minimum ``effective_sis`` for a section to be marked PASS."""
    assert_passing_section_frac: float = 0.5
    """Probe-level pass criterion: fraction of sections that must clear
    the per-section threshold."""
    max_prose_chars: int = 2000
    """Cap the length of PROSE content we score to keep runtime bounded.
    Long sections are chunked; this is the per-chunk cap."""
54
55
class SectionInternalizationProbe(Probe):
    """B1 probe: per-section internalization with a leak correction.

    For each eligible section, the fine-tune's relative NLL improvement on
    the section's own content ("own lift") is compared against the mean
    improvement over every *other* eligible section ("leak lift"). The
    difference — the effective SIS — is the signal attributable to that
    section alone. The probe passes when at least
    ``assert_passing_section_frac`` of sections clear
    ``per_section_threshold``.
    """

    kind = "section_internalization"
    spec_cls = SectionInternalizationSpec
    category = "attribution"

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Score every eligible section and aggregate into one verdict.

        Returns a SKIP result when the context carries no sections, or
        fewer than two eligible ones (the leak-check needs at least one
        "other" section to average over).

        Raises:
            TypeError: if ``spec`` is not a ``SectionInternalizationSpec``.
        """
        # Validate explicitly — a bare `assert` vanishes under `python -O`.
        if not isinstance(spec, SectionInternalizationSpec):
            raise TypeError(
                f"expected SectionInternalizationSpec, got {type(spec).__name__}"
            )
        # `ctx.sections` may be None or an empty list; both mean "nothing to score".
        if not ctx.sections:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message="no sections in context — provide via the .dlm bridge",
            )

        kinds_allowed = set(spec.include_kinds)
        eligible = [s for s in ctx.sections if s.kind in kinds_allowed]
        if len(eligible) < 2:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=(
                    f"need ≥2 eligible sections for leak-check; got {len(eligible)} "
                    f"(kinds={spec.include_kinds})"
                ),
            )

        # Pre-compute per-section base and ft NLL-per-token so each model
        # view is entered exactly once — the forward passes dominate runtime.
        base_nll: dict[str, float] = {}
        ft_nll: dict[str, float] = {}
        with ctx.backend.as_base() as base_view:
            for s in eligible:
                base_nll[s.id] = _section_nll(s, base_view, spec.max_prose_chars)
        with ctx.backend.as_finetuned() as ft_view:
            for s in eligible:
                ft_nll[s.id] = _section_nll(s, ft_view, spec.max_prose_chars)

        # Hoist per-section lifts out of the leak loop: the original
        # recomputed each lift O(n) times while averaging "others".
        lift: dict[str, float] = {
            s.id: _relative_lift(base_nll[s.id], ft_nll[s.id]) for s in eligible
        }

        per_section: list[dict[str, float | str | bool]] = []
        passing = 0
        effective_scores: list[float] = []
        for s in eligible:
            own_lift = lift[s.id]
            # Mean lift on every other section: the "leak" baseline.
            leak_lift = statistics.fmean(
                lift[o.id] for o in eligible if o.id != s.id
            )
            effective = own_lift - leak_lift
            effective_scores.append(effective)
            did_pass = effective >= spec.per_section_threshold
            passing += int(did_pass)
            per_section.append(
                {
                    "section_id": s.id,
                    "kind": s.kind,
                    "tag": s.tag or "",
                    "base_nll": base_nll[s.id],
                    "ft_nll": ft_nll[s.id],
                    "own_lift": own_lift,
                    "leak_lift": leak_lift,
                    "effective_sis": effective,
                    "passed": did_pass,
                }
            )

        passing_frac = passing / len(eligible)
        # Computed once; used for both `raw` and the message.
        mean_effective = statistics.fmean(effective_scores)
        verdict = Verdict.PASS if passing_frac >= spec.assert_passing_section_frac else Verdict.FAIL
        return ProbeResult(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=passing_frac,
            raw=mean_effective,
            evidence={
                "per_section": per_section,
                "num_sections": len(eligible),
                "passing_frac": passing_frac,
                "per_section_threshold": spec.per_section_threshold,
                "weight": spec.weight,
            },
            message=(
                f"{passing}/{len(eligible)} sections cleared "
                f"effective_sis≥{spec.per_section_threshold:.2f} (mean={mean_effective:+.3f})"
            ),
        )
145
146
def _section_nll(s: Section, view: ScoringBackend, max_prose_chars: int) -> float:
    """Average NLL per token of section ``s``'s content under ``view``.

    PROSE sections score their raw content, capped at ``max_prose_chars``.
    INSTRUCTION/PREFERENCE sections score their gold/chosen spans given
    their prompts, falling back to raw content when no pairs are present.

    Raises:
        ValueError: for a section kind this function does not handle.
    """
    clipped = s.content[:max_prose_chars]
    if s.kind == "prose":
        return _prose_nll(clipped, view)
    if s.kind == "instruction":
        if not s.probes:
            return _prose_nll(clipped, view)
        per_probe = (
            -view.logprob_of(p.prompt, p.gold) / max(_token_estimate(p.gold), 1)
            for p in s.probes
        )
        return statistics.fmean(per_probe)
    if s.kind == "preference":
        if not s.preferences:
            return _prose_nll(clipped, view)
        per_pair = (
            -view.logprob_of(p.prompt, p.chosen) / max(_token_estimate(p.chosen), 1)
            for p in s.preferences
        )
        return statistics.fmean(per_pair)
    raise ValueError(f"unknown section kind: {s.kind!r}")
165
166
167 def _prose_nll(text: str, view: ScoringBackend) -> float:
168 """Negative-mean-logprob over ``text``. Returns 0 for empty input."""
169 if not text.strip():
170 return 0.0
171 r = view.rolling_logprob(text)
172 return -r.mean_logprob
173
174
175 def _relative_lift(base_nll: float, ft_nll: float) -> float:
176 """``(base - ft) / base``. Positive → ft is lower-PPL than base.
177
178 Falls back to an absolute delta when ``base`` is pathological
179 (zero or negative), so the probe doesn't crash on degenerate
180 inputs.
181 """
182 if base_nll <= 0.0:
183 return float(base_nll - ft_nll)
184 return float((base_nll - ft_nll) / base_nll)
185
186
187 def _token_estimate(s: str) -> int:
188 """Approximate tokens for normalization. Good enough for SentencePiece-ish vocabs."""
189 return max(1, len(s) // 4)