| 1 | """B1 SectionInternalizationScore — the flagship attribution primitive. |
| 2 | |
| 3 | For each typed section of the training document, measure *how much the |
| 4 | fine-tune moved the needle on that section's own content* — and subtract |
| 5 | the same metric measured on *other* sections' content. The difference is |
| 6 | the "effective SIS": signal attributable to *this* section, not to a |
| 7 | broader lift across the whole document. |
| 8 | |
| 9 | Output is a per-section bar chart. In practice users see that sections |
| 10 | 2 and 7 actually moved the model, sections 3 and 5 did nothing, and |
| 11 | section 11 moved it but also leaked into unrelated content — actionable |
| 12 | signal for document authoring that no other eval tool provides. |
| 13 | |
| 14 | Math per section ``s`` with measurement function ``m(probe_set)``: |
| 15 | |
| 16 | .. math:: |
| 17 | sis_s^{own} &= (m_{base}(s) - m_{ft}(s)) / m_{base}(s) |
| 18 | sis_s^{leak} &= (m_{base}(\\bar s) - m_{ft}(\\bar s)) / m_{base}(\\bar s) |
| 19 | effective &= sis_s^{own} - sis_s^{leak} |
| 20 | |
| 21 | For PROSE sections, ``m`` is the average NLL per token over the |
| 22 | section's content. For INSTRUCTION and PREFERENCE sections, ``m`` is the |
| 23 | average NLL per token over the answer/chosen spans given their prompts. |
| 24 | """ |
| 25 | |
| 26 | from __future__ import annotations |
| 27 | |
| 28 | import statistics |
| 29 | from typing import Literal |
| 30 | |
| 31 | from pydantic import Field |
| 32 | |
| 33 | from dlm_sway.core.result import ProbeResult, Verdict |
| 34 | from dlm_sway.core.scoring import ScoringBackend |
| 35 | from dlm_sway.core.sections import Section, SectionKind |
| 36 | from dlm_sway.probes.base import Probe, ProbeSpec, RunContext |
| 37 | |
| 38 | |
| 39 | def _default_include_kinds() -> list[SectionKind]: |
| 40 | return ["prose", "instruction", "preference"] |
| 41 | |
| 42 | |
class SectionInternalizationSpec(ProbeSpec):
    """Configuration for the section-internalization (SIS) probe.

    Controls which section kinds are scored, the per-section pass
    threshold on ``effective_sis``, and the probe-level pass criterion.
    """

    # Discriminator used by the probe registry to route this spec.
    kind: Literal["section_internalization"] = "section_internalization"
    # Section kinds eligible for scoring; sections of other kinds are skipped.
    include_kinds: list[SectionKind] = Field(default_factory=_default_include_kinds)
    per_section_threshold: float = 0.05
    """Minimum ``effective_sis`` for a section to be marked PASS."""
    assert_passing_section_frac: float = 0.5
    """Probe-level pass criterion: fraction of sections that must clear
    the per-section threshold."""
    max_prose_chars: int = 2000
    """Cap the length of PROSE content we score to keep runtime bounded.
    Long sections are chunked; this is the per-chunk cap."""
| 54 | |
| 55 | |
class SectionInternalizationProbe(Probe):
    """Per-section attribution probe: effective SIS = own lift − mean leak lift.

    See the module docstring for the math. The probe verdict is PASS when
    at least ``assert_passing_section_frac`` of the eligible sections clear
    ``per_section_threshold`` on their effective SIS; ``score`` is the
    passing fraction and ``raw`` is the mean effective SIS.
    """

    kind = "section_internalization"
    spec_cls = SectionInternalizationSpec
    category = "attribution"

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Score every eligible section and aggregate into a probe result.

        Returns a SKIP result when no sections are available or fewer than
        two sections are eligible (the leak-check needs at least one
        "other" section to average over).

        Raises:
            TypeError: if ``spec`` is not a ``SectionInternalizationSpec``.
        """
        # Explicit check instead of `assert`: asserts are stripped under -O,
        # which would let a mismatched spec surface as confusing attribute
        # errors further down.
        if not isinstance(spec, SectionInternalizationSpec):
            raise TypeError(
                f"expected SectionInternalizationSpec, got {type(spec).__name__}"
            )
        if ctx.sections is None or len(ctx.sections) == 0:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message="no sections in context — provide via the .dlm bridge",
            )

        kinds_allowed = set(spec.include_kinds)
        eligible = [s for s in ctx.sections if s.kind in kinds_allowed]
        if len(eligible) < 2:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=(
                    f"need ≥2 eligible sections for leak-check; got {len(eligible)} "
                    f"(kinds={spec.include_kinds})"
                ),
            )

        # Pre-compute per-section base and ft NLL-per-token so each model
        # view is entered exactly once — swapping model views is the
        # expensive part, not the per-section arithmetic.
        base_nll: dict[str, float] = {}
        ft_nll: dict[str, float] = {}
        with ctx.backend.as_base() as base_view:
            for s in eligible:
                base_nll[s.id] = _section_nll(s, base_view, spec.max_prose_chars)
        with ctx.backend.as_finetuned() as ft_view:
            for s in eligible:
                ft_nll[s.id] = _section_nll(s, ft_view, spec.max_prose_chars)

        # Each section's lift appears both as its own "own_lift" and inside
        # every other section's leak average. Compute each lift once; the
        # mean over the n-1 *other* sections is (total - own) / (n - 1),
        # which replaces the naive O(n^2) pairwise pass with O(n).
        lift: dict[str, float] = {
            s.id: _relative_lift(base_nll[s.id], ft_nll[s.id]) for s in eligible
        }
        lift_total = sum(lift.values())
        n_others = len(eligible) - 1  # >= 1, guarded above

        per_section: list[dict[str, float | str | bool]] = []
        passing = 0
        effective_scores: list[float] = []
        for s in eligible:
            own_lift = lift[s.id]
            leak_lift = (lift_total - own_lift) / n_others
            effective = own_lift - leak_lift
            effective_scores.append(effective)
            did_pass = effective >= spec.per_section_threshold
            passing += int(did_pass)
            per_section.append(
                {
                    "section_id": s.id,
                    "kind": s.kind,
                    "tag": s.tag or "",
                    "base_nll": base_nll[s.id],
                    "ft_nll": ft_nll[s.id],
                    "own_lift": own_lift,
                    "leak_lift": leak_lift,
                    "effective_sis": effective,
                    "passed": did_pass,
                }
            )

        passing_frac = passing / len(eligible)
        mean_effective = statistics.fmean(effective_scores)  # computed once, used twice
        verdict = Verdict.PASS if passing_frac >= spec.assert_passing_section_frac else Verdict.FAIL
        return ProbeResult(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=passing_frac,
            raw=mean_effective,
            evidence={
                "per_section": per_section,
                "num_sections": len(eligible),
                "passing_frac": passing_frac,
                "per_section_threshold": spec.per_section_threshold,
                "weight": spec.weight,
            },
            message=(
                f"{passing}/{len(eligible)} sections cleared "
                f"effective_sis≥{spec.per_section_threshold:.2f} (mean={mean_effective:+.3f})"
            ),
        )
| 145 | |
| 146 | |
def _section_nll(s: Section, view: ScoringBackend, max_prose_chars: int) -> float:
    """Average NLL per token for the section's content under ``view``.

    PROSE sections score their (capped) raw content; INSTRUCTION and
    PREFERENCE sections score their answer/chosen spans given the prompt,
    falling back to raw content when no probes/preferences are attached.

    Raises:
        ValueError: for a section kind this probe does not understand.
    """
    clipped = s.content[:max_prose_chars]

    if s.kind == "prose":
        return _prose_nll(clipped, view)

    if s.kind == "instruction":
        if not s.probes:
            return _prose_nll(clipped, view)
        pairs = [(p.prompt, p.gold) for p in s.probes]
    elif s.kind == "preference":
        if not s.preferences:
            return _prose_nll(clipped, view)
        pairs = [(p.prompt, p.chosen) for p in s.preferences]
    else:
        raise ValueError(f"unknown section kind: {s.kind!r}")

    per_item = (
        -view.logprob_of(prompt, target) / max(_token_estimate(target), 1)
        for prompt, target in pairs
    )
    return statistics.fmean(per_item)
| 165 | |
| 166 | |
| 167 | def _prose_nll(text: str, view: ScoringBackend) -> float: |
| 168 | """Negative-mean-logprob over ``text``. Returns 0 for empty input.""" |
| 169 | if not text.strip(): |
| 170 | return 0.0 |
| 171 | r = view.rolling_logprob(text) |
| 172 | return -r.mean_logprob |
| 173 | |
| 174 | |
| 175 | def _relative_lift(base_nll: float, ft_nll: float) -> float: |
| 176 | """``(base - ft) / base``. Positive → ft is lower-PPL than base. |
| 177 | |
| 178 | Falls back to an absolute delta when ``base`` is pathological |
| 179 | (zero or negative), so the probe doesn't crash on degenerate |
| 180 | inputs. |
| 181 | """ |
| 182 | if base_nll <= 0.0: |
| 183 | return float(base_nll - ft_nll) |
| 184 | return float((base_nll - ft_nll) / base_nll) |
| 185 | |
| 186 | |
| 187 | def _token_estimate(s: str) -> int: |
| 188 | """Approximate tokens for normalization. Good enough for SentencePiece-ish vocabs.""" |
| 189 | return max(1, len(s) // 4) |