@@ -0,0 +1,189 @@ |
| 1 | +"""B1 SectionInternalizationScore — the flagship attribution primitive. |
| 2 | + |
| 3 | +For each typed section of the training document, measure *how much the |
| 4 | +fine-tune moved the needle on that section's own content* — and subtract |
| 5 | +the same metric measured on *other* sections' content. The difference is |
| 6 | +the "effective SIS": signal attributable to *this* section, not to a |
| 7 | +broader lift across the whole document. |
| 8 | + |
| 9 | +Output is a per-section bar chart. In practice users see that sections |
| 10 | +2 and 7 actually moved the model, sections 3 and 5 did nothing, and |
| 11 | +section 11 moved it but also leaked into unrelated content — actionable |
| 12 | +signal for document authoring that no other eval tool provides. |
| 13 | + |
| 14 | +Math per section ``s`` with measurement function ``m(probe_set)``: |
| 15 | + |
| 16 | +.. math:: |
   sis_s^{own} &= (m_{base}(s) - m_{ft}(s)) / m_{base}(s) \\\\
   sis_s^{leak} &= (m_{base}(\\bar s) - m_{ft}(\\bar s)) / m_{base}(\\bar s) \\\\
   effective &= sis_s^{own} - sis_s^{leak}
| 20 | + |
| 21 | +For PROSE sections, ``m`` is the average NLL per token over the |
| 22 | +section's content. For INSTRUCTION and PREFERENCE sections, ``m`` is the |
| 23 | +average NLL per token over the answer/chosen spans given their prompts. |
| 24 | +""" |
| 25 | + |
| 26 | +from __future__ import annotations |
| 27 | + |
| 28 | +import statistics |
| 29 | +from typing import Literal |
| 30 | + |
| 31 | +from pydantic import Field |
| 32 | + |
| 33 | +from dlm_sway.core.result import ProbeResult, Verdict |
| 34 | +from dlm_sway.core.scoring import ScoringBackend |
| 35 | +from dlm_sway.core.sections import Section, SectionKind |
| 36 | +from dlm_sway.probes.base import Probe, ProbeSpec, RunContext |
| 37 | + |
| 38 | + |
| 39 | +def _default_include_kinds() -> list[SectionKind]: |
| 40 | + return ["prose", "instruction", "preference"] |
| 41 | + |
| 42 | + |
class SectionInternalizationSpec(ProbeSpec):
    """Configuration for the section-internalization (B1 SIS) probe."""

    # Discriminator value routing this spec to SectionInternalizationProbe.
    kind: Literal["section_internalization"] = "section_internalization"
    # Section kinds that participate in scoring; defaults to all supported kinds.
    include_kinds: list[SectionKind] = Field(default_factory=_default_include_kinds)
    per_section_threshold: float = 0.05
    """Minimum ``effective_sis`` for a section to be marked PASS."""
    assert_passing_section_frac: float = 0.5
    """Probe-level pass criterion: fraction of sections that must clear
    the per-section threshold."""
    max_prose_chars: int = 2000
    """Cap on the length of PROSE content we score, to keep runtime bounded.
    Content beyond the cap is truncated before scoring (``_section_nll``
    slices ``content[:max_prose_chars]``); it is not chunked."""
| 54 | + |
| 55 | + |
class SectionInternalizationProbe(Probe):
    """B1 SectionInternalizationScore probe (see module docstring for the math).

    For each eligible section, compares the fine-tune's NLL lift on the
    section's own content ("own lift") against the mean lift on all *other*
    eligible sections ("leak lift"). The difference — the effective SIS — is
    signal attributable to that section alone. A section passes when its
    effective SIS clears ``per_section_threshold``; the probe passes when
    ``assert_passing_section_frac`` of the sections do.
    """

    kind = "section_internalization"
    spec_cls = SectionInternalizationSpec
    category = "attribution"

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Score every eligible section and aggregate to a probe verdict.

        Returns a SKIP result when the context carries no sections, or when
        fewer than two sections are eligible — the leak-check needs at least
        one "other" section to compare against.
        """
        assert isinstance(spec, SectionInternalizationSpec)
        if ctx.sections is None or len(ctx.sections) == 0:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message="no sections in context — provide via the .dlm bridge",
            )

        kinds_allowed = set(spec.include_kinds)
        eligible = [s for s in ctx.sections if s.kind in kinds_allowed]
        if len(eligible) < 2:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=(
                    f"need ≥2 eligible sections for leak-check; got {len(eligible)} "
                    f"(kinds={spec.include_kinds})"
                ),
            )

        # Pre-compute per-section base and ft NLL-per-token so each model view
        # is entered exactly once and no forward pass is repeated for the
        # leak-checks below.
        base_nll: dict[str, float] = {}
        ft_nll: dict[str, float] = {}
        with ctx.backend.as_base() as base_view:
            for s in eligible:
                base_nll[s.id] = _section_nll(s, base_view, spec.max_prose_chars)
        with ctx.backend.as_finetuned() as ft_view:
            for s in eligible:
                ft_nll[s.id] = _section_nll(s, ft_view, spec.max_prose_chars)

        per_section: list[dict[str, float | str | bool]] = []
        passing = 0
        effective_scores: list[float] = []
        for s in eligible:
            others = [o for o in eligible if o.id != s.id]
            own_lift = _relative_lift(base_nll[s.id], ft_nll[s.id])
            # Mean lift over every *other* section: signal that leaked beyond s.
            leak_lift = statistics.fmean(
                _relative_lift(base_nll[o.id], ft_nll[o.id]) for o in others
            )
            effective = own_lift - leak_lift
            effective_scores.append(effective)
            did_pass = effective >= spec.per_section_threshold
            passing += int(did_pass)
            per_section.append(
                {
                    "section_id": s.id,
                    "kind": s.kind,
                    "tag": s.tag or "",
                    "base_nll": base_nll[s.id],
                    "ft_nll": ft_nll[s.id],
                    "own_lift": own_lift,
                    "leak_lift": leak_lift,
                    "effective_sis": effective,
                    "passed": did_pass,
                }
            )

        passing_frac = passing / len(eligible)
        # Hoisted: previously recomputed for both `raw` and the message below.
        mean_effective = statistics.fmean(effective_scores)
        verdict = Verdict.PASS if passing_frac >= spec.assert_passing_section_frac else Verdict.FAIL
        return ProbeResult(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=passing_frac,
            raw=mean_effective,
            evidence={
                "per_section": per_section,
                "num_sections": len(eligible),
                "passing_frac": passing_frac,
                "per_section_threshold": spec.per_section_threshold,
                "weight": spec.weight,
            },
            message=(
                f"{passing}/{len(eligible)} sections cleared "
                f"effective_sis≥{spec.per_section_threshold:.2f} (mean={mean_effective:+.3f})"
            ),
        )
| 145 | + |
| 146 | + |
def _section_nll(s: Section, view: ScoringBackend, max_prose_chars: int) -> float:
    """Average NLL per token of section ``s`` under model ``view``.

    PROSE sections score their (truncated) raw content. INSTRUCTION and
    PREFERENCE sections score their gold/chosen spans given their prompts,
    falling back to raw content when no probes/preferences are attached.

    Raises:
        ValueError: for an unrecognized section kind.
    """

    def content_nll() -> float:
        # Shared fallback: score the section text itself, capped for runtime.
        return _prose_nll(s.content[:max_prose_chars], view)

    if s.kind == "prose":
        return content_nll()

    if s.kind == "instruction":
        if not s.probes:
            return content_nll()
        per_probe = [
            -view.logprob_of(p.prompt, p.gold) / max(_token_estimate(p.gold), 1)
            for p in s.probes
        ]
        return statistics.fmean(per_probe)

    if s.kind == "preference":
        if not s.preferences:
            return content_nll()
        per_pref = [
            -view.logprob_of(p.prompt, p.chosen) / max(_token_estimate(p.chosen), 1)
            for p in s.preferences
        ]
        return statistics.fmean(per_pref)

    raise ValueError(f"unknown section kind: {s.kind!r}")
| 165 | + |
| 166 | + |
| 167 | +def _prose_nll(text: str, view: ScoringBackend) -> float: |
| 168 | + """Negative-mean-logprob over ``text``. Returns 0 for empty input.""" |
| 169 | + if not text.strip(): |
| 170 | + return 0.0 |
| 171 | + r = view.rolling_logprob(text) |
| 172 | + return -r.mean_logprob |
| 173 | + |
| 174 | + |
| 175 | +def _relative_lift(base_nll: float, ft_nll: float) -> float: |
| 176 | + """``(base - ft) / base``. Positive → ft is lower-PPL than base. |
| 177 | + |
| 178 | + Falls back to an absolute delta when ``base`` is pathological |
| 179 | + (zero or negative), so the probe doesn't crash on degenerate |
| 180 | + inputs. |
| 181 | + """ |
| 182 | + if base_nll <= 0.0: |
| 183 | + return float(base_nll - ft_nll) |
| 184 | + return float((base_nll - ft_nll) / base_nll) |
| 185 | + |
| 186 | + |
| 187 | +def _token_estimate(s: str) -> int: |
| 188 | + """Approximate tokens for normalization. Good enough for SentencePiece-ish vocabs.""" |
| 189 | + return max(1, len(s) // 4) |