"""B1 SectionInternalizationScore — the flagship attribution primitive.

For each typed section of the training document, measure *how much the
fine-tune moved the needle on that section's own content*, and subtract
the same metric measured on *other* sections' content. The difference is
the "effective SIS": signal attributable to *this* section rather than to
a broader lift across the whole document.

Output is a per-section bar chart. A typical readout: sections 2 and 7
actually moved the model, sections 3 and 5 did nothing, and section 11
moved it but also leaked into unrelated content — actionable signal for
document authoring.

Math per section ``s`` with measurement function ``m(probe_set)``:

.. math::

    sis_s^{own} &= (m_{base}(s) - m_{ft}(s)) / m_{base}(s) \\\\
    sis_s^{leak} &= (m_{base}(\\bar s) - m_{ft}(\\bar s)) / m_{base}(\\bar s) \\\\
    effective_s &= sis_s^{own} - sis_s^{leak}

For PROSE sections, ``m`` is the average NLL per token over the
section's content. For INSTRUCTION and PREFERENCE sections, ``m`` is the
average NLL per token over the answer/chosen spans given their prompts.
"""
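
# Worked example (illustrative numbers; two eligible sections A and B):
#   base NLL: A = 2.00, B = 2.00    ft NLL: A = 1.50, B = 1.90
#   own_lift(A)  = (2.00 - 1.50) / 2.00 = 0.25
#   leak_lift(A) = (2.00 - 1.90) / 2.00 = 0.05   (mean over the other sections)
#   effective_sis(A) = 0.25 - 0.05 = 0.20        -> clears the default 0.05 threshold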

from __future__ import annotations

import statistics
from typing import Literal

from pydantic import Field

from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize
from dlm_sway.core.scoring import ScoringBackend
from dlm_sway.core.sections import Section, SectionKind
from dlm_sway.core.stats import bootstrap_ci
from dlm_sway.probes._zscore import (
    no_calibration_note,
    score_from_z,
    verdict_from_z,
    z_score,
    z_scores_by_rank,
)
from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
from dlm_sway.probes.null_adapter import get_null_stats, get_null_stats_by_rank


def _default_include_kinds() -> list[SectionKind]:
    return ["prose", "instruction", "preference"]

class SectionInternalizationSpec(ProbeSpec):
    kind: Literal["section_internalization"] = "section_internalization"
    include_kinds: list[SectionKind] = Field(default_factory=_default_include_kinds)
    per_section_threshold: float = 0.05
    """Minimum ``effective_sis`` for a section to be marked PASS."""
    assert_passing_section_frac: float = 0.5
    """Probe-level pass criterion: the fraction of sections that must clear
    the per-section threshold."""
    assert_z_gte: float = 3.0
    """Z-score pass criterion against the null-adapter baseline, when one
    exists; preferred over the raw threshold. The z-scored statistic is
    the mean ``effective_sis`` across sections."""
    max_prose_chars: int = 2000
    """Cap on the length of PROSE content we score, to keep runtime bounded.
    Longer sections are truncated to this many characters."""
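
# A minimal configuration sketch (field names are real; the enclosing suite
# config that would normally construct this spec is assumed):
#
#     SectionInternalizationSpec(
#         name="sis_main",
#         include_kinds=["prose", "instruction"],
#         per_section_threshold=0.10,  # stricter than the 0.05 default
#         assert_z_gte=3.0,
#     )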


class SectionInternalizationProbe(Probe):
    kind = "section_internalization"
    spec_cls = SectionInternalizationSpec
    category = "attribution"

    @classmethod
    def calibrate_spec(cls, ctx: RunContext) -> SectionInternalizationSpec | None:
        # Needs sections; if the bridge didn't populate them, opt out.
        if ctx.sections is None or len(ctx.sections) < 2:
            return None
        return SectionInternalizationSpec(
            name="_calibration",
            kind="section_internalization",
            per_section_threshold=0.05,
        )

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        assert isinstance(spec, SectionInternalizationSpec)
        if ctx.sections is None or len(ctx.sections) == 0:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message="no sections in context — provide via the .dlm bridge",
            )

        kinds_allowed = set(spec.include_kinds)
        eligible = [s for s in ctx.sections if s.kind in kinds_allowed]
        if len(eligible) < 2:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=(
                    f"need ≥2 eligible sections for leak-check; got {len(eligible)} "
                    f"(kinds={spec.include_kinds})"
                ),
            )

        # Pre-compute per-section base and ft NLL-per-token to avoid
        # re-running the forward pass for leak-checks.
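        # Cost: each (section, view) pair is scored exactly once, i.e. 2*n
        # scoring passes; the leak terms below reuse these cached values.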
        base_nll: dict[str, float] = {}
        ft_nll: dict[str, float] = {}
        with ctx.require_backend.as_base() as base_view:
            for s in eligible:
                base_nll[s.id] = _section_nll(s, base_view, spec.max_prose_chars)
        with ctx.require_backend.as_finetuned() as ft_view:
            for s in eligible:
                ft_nll[s.id] = _section_nll(s, ft_view, spec.max_prose_chars)

        per_section: list[dict[str, float | str | bool]] = []
        passing = 0
        effective_scores: list[float] = []
        for s in eligible:
            others = [o for o in eligible if o.id != s.id]
            own_lift = _relative_lift(base_nll[s.id], ft_nll[s.id])
            leak_lift = statistics.fmean(
                _relative_lift(base_nll[o.id], ft_nll[o.id]) for o in others
            )
            effective = own_lift - leak_lift
            effective_scores.append(effective)
            did_pass = effective >= spec.per_section_threshold
            passing += int(did_pass)
            per_section.append(
                {
                    "section_id": s.id,
                    "kind": s.kind,
                    "tag": s.tag or "",
                    "base_nll": base_nll[s.id],
                    "ft_nll": ft_nll[s.id],
                    "own_lift": own_lift,
                    "leak_lift": leak_lift,
                    "effective_sis": effective,
                    "passed": did_pass,
                }
            )

        passing_frac = passing / len(eligible)
        raw_mean = statistics.fmean(effective_scores)
        ci_95 = bootstrap_ci(effective_scores, seed=ctx.seed)
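        # bootstrap_ci is seeded with ctx.seed, so reruns reproduce the interval.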

        # Null-adapter calibration wins when available.
        stats = get_null_stats(ctx, spec.kind)
        z = z_score(raw_mean, stats)
        z_by_rank = z_scores_by_rank(raw_mean, get_null_stats_by_rank(ctx, spec.kind), sign=+1)
        verdict_z = verdict_from_z(z, spec.assert_z_gte)
        if verdict_z is not None:
            verdict = verdict_z
            score_val = score_from_z(z)
            score = score_val if score_val is not None else 0.0
            message = (
                f"{passing}/{len(eligible)} sections cleared; "
                f"mean effective_sis={raw_mean:+.3f}, z={z:+.2f}σ vs null"
            )
        else:
            verdict = (
                Verdict.PASS if passing_frac >= spec.assert_passing_section_frac else Verdict.FAIL
            )
            score = passing_frac
            message = (
                f"{passing}/{len(eligible)} sections cleared "
                f"effective_sis≥{spec.per_section_threshold:.2f} "
                f"(mean={raw_mean:+.3f}) {no_calibration_note(spec.kind)}"
            )

        return safe_finalize(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=raw_mean,
            z_score=z,
            evidence={
                "per_section": per_section,
                "num_sections": len(eligible),
                "passing_frac": passing_frac,
                "per_section_threshold": spec.per_section_threshold,
                "weight": spec.weight,
                "z_by_rank": z_by_rank,
                "raw_ci_95": list(ci_95) if ci_95 is not None else None,
            },
            message=message,
            ci_95=ci_95,
        )
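

# A minimal usage sketch (hypothetical harness: how ``ctx`` is constructed
# depends on the suite runner, so treat this as orientation, not API):
#
#     probe = SectionInternalizationProbe()
#     spec = SectionInternalizationSpec(name="sis")
#     result = probe.run(spec, ctx)
#     for row in result.evidence["per_section"]:
#         print(row["section_id"], f"{row['effective_sis']:+.3f}", row["passed"])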


def _section_nll(s: Section, view: ScoringBackend, max_prose_chars: int) -> float:
    """Average NLL per token for the section's content under ``view``."""
    if s.kind == "prose":
        return _prose_nll(s.content[:max_prose_chars], view)
    if s.kind == "instruction":
        if not s.probes:
            return _prose_nll(s.content[:max_prose_chars], view)
        return statistics.fmean(
            -view.logprob_of(p.prompt, p.gold) / max(_token_estimate(p.gold), 1)
            for p in s.probes
        )
    if s.kind == "preference":
        if not s.preferences:
            return _prose_nll(s.content[:max_prose_chars], view)
        return statistics.fmean(
            -view.logprob_of(p.prompt, p.chosen) / max(_token_estimate(p.chosen), 1)
            for p in s.preferences
        )
    raise ValueError(f"unknown section kind: {s.kind!r}")
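
# Note: ``logprob_of(prompt, continuation)`` is taken to return the *summed*
# logprob of the continuation given the prompt (inferred from its use above),
# so dividing by the token estimate yields a per-token NLL comparable to the
# prose path.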


def _prose_nll(text: str, view: ScoringBackend) -> float:
    """Negative mean logprob over ``text``. Returns 0.0 for empty or
    whitespace-only input."""
    if not text.strip():
        return 0.0
    r = view.rolling_logprob(text)
    return -r.mean_logprob


def _relative_lift(base_nll: float, ft_nll: float) -> float:
    """``(base - ft) / base``. Positive → ft is lower-PPL than base.

    Falls back to an absolute delta when ``base`` is pathological
    (zero or negative), so the probe doesn't crash on degenerate
    inputs.
    """
    if base_nll <= 0.0:
        return float(base_nll - ft_nll)
    return float((base_nll - ft_nll) / base_nll)
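
# Sanity checks (pure function; exact with these float-friendly inputs):
#   _relative_lift(2.0, 1.5) == 0.25    # fine-tune improved: positive lift
#   _relative_lift(2.0, 2.5) == -0.25   # fine-tune regressed: negative lift
#   _relative_lift(0.0, 1.0) == -1.0    # degenerate base: absolute-delta fallback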


def _token_estimate(s: str) -> int:
    """Approximate tokens for normalization. Good enough for SentencePiece-ish vocabs."""
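    # Heuristic: roughly 4 characters per token, so a 100-char span counts as 25.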
    return max(1, len(s) // 4)