tenseleyflow/sway / 022829a

Browse files

sway(probes): B1 section_internalization (flagship per-section attribution)

Authored by espadonne
SHA
022829ae68ea2d9fd01280f20420f88d8f071d9d
Parents
79552fe
Tree
8cf1b5d

2 changed files

Status | File | + | -
A src/dlm_sway/probes/section_internalization.py 189 0
A tests/unit/test_probe_section_internalization.py 94 0
src/dlm_sway/probes/section_internalization.pyadded
@@ -0,0 +1,189 @@
1
+"""B1 SectionInternalizationScore — the flagship attribution primitive.
2
+
3
+For each typed section of the training document, measure *how much the
4
+fine-tune moved the needle on that section's own content* — and subtract
5
+the same metric measured on *other* sections' content. The difference is
6
+the "effective SIS": signal attributable to *this* section, not to a
7
+broader lift across the whole document.
8
+
9
+Output is a per-section bar chart. In practice users see that sections
10
+2 and 7 actually moved the model, sections 3 and 5 did nothing, and
11
+section 11 moved it but also leaked into unrelated content — actionable
12
+signal for document authoring that no other eval tool provides.
13
+
14
+Math per section ``s`` with measurement function ``m(probe_set)``:
15
+
16
+.. math::
17
+    sis_s^{own}  &= (m_{base}(s) - m_{ft}(s)) / m_{base}(s)
18
+    sis_s^{leak} &= (m_{base}(\\bar s) - m_{ft}(\\bar s)) / m_{base}(\\bar s)
19
+    effective    &= sis_s^{own} - sis_s^{leak}
20
+
21
+For PROSE sections, ``m`` is the average NLL per token over the
22
+section's content. For INSTRUCTION and PREFERENCE sections, ``m`` is the
23
+average NLL per token over the answer/chosen spans given their prompts.
24
+"""
25
+
26
+from __future__ import annotations
27
+
28
+import statistics
29
+from typing import Literal
30
+
31
+from pydantic import Field
32
+
33
+from dlm_sway.core.result import ProbeResult, Verdict
34
+from dlm_sway.core.scoring import ScoringBackend
35
+from dlm_sway.core.sections import Section, SectionKind
36
+from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
37
+
38
+
39
+def _default_include_kinds() -> list[SectionKind]:
40
+    return ["prose", "instruction", "preference"]
41
+
42
+
43
class SectionInternalizationSpec(ProbeSpec):
    """Spec for the B1 section-internalization probe.

    Controls which section kinds are scored and the two-level pass
    criteria: a per-section ``effective_sis`` threshold, and the
    fraction of sections that must clear it for the probe to PASS.
    """

    kind: Literal["section_internalization"] = "section_internalization"
    include_kinds: list[SectionKind] = Field(default_factory=_default_include_kinds)
    per_section_threshold: float = 0.05
    """Minimum ``effective_sis`` for a section to be marked PASS."""
    assert_passing_section_frac: float = 0.5
    """Probe-level pass criterion: fraction of sections that must clear
    the per-section threshold."""
    max_prose_chars: int = 2000
    """Cap on the PROSE content length we score, to keep runtime bounded.
    Content past the cap is truncated (``content[:max_prose_chars]``),
    not chunked."""
54
+
55
+
56
class SectionInternalizationProbe(Probe):
    """B1 flagship probe: per-section attribution of fine-tune lift.

    For every eligible section, compares the section's own relative NLL
    lift against the mean lift over all *other* eligible sections (the
    leak-check). A section passes when the difference (``effective_sis``)
    clears ``per_section_threshold``; the probe passes when at least
    ``assert_passing_section_frac`` of sections pass.
    """

    kind = "section_internalization"
    spec_cls = SectionInternalizationSpec
    category = "attribution"

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Score every eligible section and return a PASS/FAIL/SKIP result.

        SKIPs when the context carries no sections, or fewer than two
        eligible ones (the leak-check needs at least one "other" section).
        """
        assert isinstance(spec, SectionInternalizationSpec)
        if ctx.sections is None or len(ctx.sections) == 0:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message="no sections in context — provide via the .dlm bridge",
            )

        kinds_allowed = set(spec.include_kinds)
        eligible = [s for s in ctx.sections if s.kind in kinds_allowed]
        if len(eligible) < 2:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=(
                    f"need ≥2 eligible sections for leak-check; got {len(eligible)} "
                    f"(kinds={spec.include_kinds})"
                ),
            )

        # Pre-compute per-section base and ft NLL-per-token to avoid
        # re-running the forward pass for leak-checks.
        base_nll: dict[str, float] = {}
        ft_nll: dict[str, float] = {}
        with ctx.backend.as_base() as base_view:
            for s in eligible:
                base_nll[s.id] = _section_nll(s, base_view, spec.max_prose_chars)
        with ctx.backend.as_finetuned() as ft_view:
            for s in eligible:
                ft_nll[s.id] = _section_nll(s, ft_view, spec.max_prose_chars)

        # Hoisted: each section's lift is consumed n-1 times by the
        # leak-checks, so compute it once per section (O(n) calls to
        # _relative_lift instead of O(n^2)). The fmean inputs below are
        # the same values as before, so results are bit-identical.
        lift: dict[str, float] = {
            s.id: _relative_lift(base_nll[s.id], ft_nll[s.id]) for s in eligible
        }

        per_section: list[dict[str, float | str | bool]] = []
        passing = 0
        effective_scores: list[float] = []
        for s in eligible:
            others = [o for o in eligible if o.id != s.id]
            own_lift = lift[s.id]
            # Mean lift on everything that is NOT this section: the
            # document-wide "leak" baseline we subtract out.
            leak_lift = statistics.fmean(lift[o.id] for o in others)
            effective = own_lift - leak_lift
            effective_scores.append(effective)
            did_pass = effective >= spec.per_section_threshold
            passing += int(did_pass)
            per_section.append(
                {
                    "section_id": s.id,
                    "kind": s.kind,
                    "tag": s.tag or "",
                    "base_nll": base_nll[s.id],
                    "ft_nll": ft_nll[s.id],
                    "own_lift": own_lift,
                    "leak_lift": leak_lift,
                    "effective_sis": effective,
                    "passed": did_pass,
                }
            )

        passing_frac = passing / len(eligible)
        verdict = Verdict.PASS if passing_frac >= spec.assert_passing_section_frac else Verdict.FAIL
        # Computed once; used for both `raw` and the message.
        mean_effective = statistics.fmean(effective_scores)
        return ProbeResult(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=passing_frac,
            raw=mean_effective,
            evidence={
                "per_section": per_section,
                "num_sections": len(eligible),
                "passing_frac": passing_frac,
                "per_section_threshold": spec.per_section_threshold,
                "weight": spec.weight,
            },
            message=(
                f"{passing}/{len(eligible)} sections cleared "
                f"effective_sis≥{spec.per_section_threshold:.2f} (mean={mean_effective:+.3f})"
            ),
        )
145
+
146
+
147
def _section_nll(s: Section, view: ScoringBackend, max_prose_chars: int) -> float:
    """Average NLL per token for the section's content under ``view``.

    PROSE scores the (truncated) raw content; INSTRUCTION/PREFERENCE
    score their answer/chosen spans given the prompts, falling back to
    raw content when a section carries no probes/preferences.
    """
    clipped = s.content[:max_prose_chars]
    if s.kind == "prose":
        return _prose_nll(clipped, view)
    if s.kind == "instruction":
        pairs = [(p.prompt, p.gold) for p in s.probes]
    elif s.kind == "preference":
        pairs = [(p.prompt, p.chosen) for p in s.preferences]
    else:
        raise ValueError(f"unknown section kind: {s.kind!r}")
    if not pairs:
        # No structured probes attached — degrade to prose scoring.
        return _prose_nll(clipped, view)
    return statistics.fmean(
        -view.logprob_of(prompt, target) / max(_token_estimate(target), 1)
        for prompt, target in pairs
    )
165
+
166
+
167
+def _prose_nll(text: str, view: ScoringBackend) -> float:
168
+    """Negative-mean-logprob over ``text``. Returns 0 for empty input."""
169
+    if not text.strip():
170
+        return 0.0
171
+    r = view.rolling_logprob(text)
172
+    return -r.mean_logprob
173
+
174
+
175
+def _relative_lift(base_nll: float, ft_nll: float) -> float:
176
+    """``(base - ft) / base``. Positive → ft is lower-PPL than base.
177
+
178
+    Falls back to an absolute delta when ``base`` is pathological
179
+    (zero or negative), so the probe doesn't crash on degenerate
180
+    inputs.
181
+    """
182
+    if base_nll <= 0.0:
183
+        return float(base_nll - ft_nll)
184
+    return float((base_nll - ft_nll) / base_nll)
185
+
186
+
187
+def _token_estimate(s: str) -> int:
188
+    """Approximate tokens for normalization. Good enough for SentencePiece-ish vocabs."""
189
+    return max(1, len(s) // 4)
tests/unit/test_probe_section_internalization.pyadded
@@ -0,0 +1,94 @@
1
+"""Tests for :mod:`dlm_sway.probes.section_internalization` (the flagship B1)."""
2
+
3
+from __future__ import annotations
4
+
5
+import numpy as np
6
+
7
+from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
8
+from dlm_sway.core.result import Verdict
9
+from dlm_sway.core.scoring import RollingLogprob
10
+from dlm_sway.core.sections import Section, SectionProbe
11
+from dlm_sway.probes.base import RunContext, build_probe
12
+
13
+
14
def _rolling(mean_lp: float, n: int = 10) -> RollingLogprob:
    """Build a RollingLogprob fixture with a flat per-token logprob.

    ``n`` tokens yield ``n - 1`` next-token logprobs, all equal to
    ``mean_lp``.
    """
    per_token = np.full(n - 1, mean_lp, dtype=np.float32)
    ids = np.arange(n, dtype=np.int64)
    return RollingLogprob(
        token_ids=ids,
        logprobs=per_token,
        num_tokens=n,
        total_logprob=float(per_token.sum()),
    )
22
+
23
+
24
def _section(sid: str, kind: str = "prose", content: str = "content", probes=()) -> Section:
    """Minimal Section fixture; ``probes`` is coerced to a tuple."""
    fields = {"id": sid, "kind": kind, "content": content, "probes": tuple(probes)}
    return Section(**fields)  # type: ignore[arg-type]
26
+
27
+
28
def test_skip_without_sections() -> None:
    """A context with no sections at all must SKIP, not FAIL."""
    probe, spec = build_probe({"name": "sis", "kind": "section_internalization"})
    empty_backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
    outcome = probe.run(spec, RunContext(backend=empty_backend))
    assert outcome.verdict == Verdict.SKIP
34
+
35
+
36
def test_skip_with_single_section() -> None:
    """One eligible section leaves nothing to leak-check against → SKIP."""
    probe, spec = build_probe({"name": "sis", "kind": "section_internalization"})
    lone_backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
    outcome = probe.run(spec, RunContext(backend=lone_backend, sections=(_section("a"),)))
    assert outcome.verdict == Verdict.SKIP
42
+
43
+
44
def test_pass_when_each_section_gets_distinct_lift() -> None:
    """Section ``a`` gets a strong lift and ``b`` a weak one: ``a``
    clears the leak-adjusted threshold, ``b`` does not, and with the
    default passing fraction (0.5) the probe verdict is PASS.

    (The original assertion ``verdict in (PASS, FAIL)`` was vacuously
    true; this pins the actual expected outcome.)
    """
    content_a = "aaa " * 10
    content_b = "bbb " * 10

    # base scores both sections at mean logprob -3.0; ft drops `a` to
    # -1.0 (big lift) and `b` only to -2.5 (small lift).
    base = DummyResponses(rolling={content_a: _rolling(-3.0), content_b: _rolling(-3.0)})
    ft = DummyResponses(rolling={content_a: _rolling(-1.0), content_b: _rolling(-2.5)})
    backend = DummyDifferentialBackend(base=base, ft=ft)

    sections = (
        _section("a", content=content_a),
        _section("b", content=content_b),
    )
    probe, spec = build_probe(
        {
            "name": "sis",
            "kind": "section_internalization",
            "per_section_threshold": 0.05,
        }
    )
    ctx = RunContext(backend=backend, sections=sections)
    result = probe.run(spec, ctx)
    # Relative lifts (any uniform NLL normalization cancels in the ratio):
    # own(a) = (3-1)/3 = 2/3, own(b) = (3-2.5)/3 = 1/6.
    # effective(a) = 2/3 - 1/6 = +0.5 → passes; effective(b) = -0.5 → fails.
    # 1/2 sections passing meets the default assert_passing_section_frac=0.5.
    assert result.verdict == Verdict.PASS
    rows = {row["section_id"]: row for row in result.evidence["per_section"]}
    assert len(rows) == 2
    assert rows["a"]["passed"]
    assert not rows["b"]["passed"]
71
+
72
+
73
def test_instruction_uses_logprob_of() -> None:
    """Instruction sections are scored through their probe Q/A pairs.

    Feed per-pair logprobs so the ft view is much cheaper than base on
    section ``a``'s answer and only slightly cheaper on ``b``'s.
    """
    qa_a = (SectionProbe(prompt="Qa", gold="Aa"),)
    qa_b = (SectionProbe(prompt="Qb", gold="Ab"),)
    base_view = DummyResponses(logprobs={("Qa", "Aa"): -10.0, ("Qb", "Ab"): -10.0})
    ft_view = DummyResponses(logprobs={("Qa", "Aa"): -3.0, ("Qb", "Ab"): -8.0})
    differential = DummyDifferentialBackend(base=base_view, ft=ft_view)

    sections = (
        _section("a", kind="instruction", content="...", probes=qa_a),
        _section("b", kind="instruction", content="...", probes=qa_b),
    )
    probe, spec = build_probe(
        {"name": "sis", "kind": "section_internalization", "per_section_threshold": 0.05}
    )
    result = probe.run(spec, RunContext(backend=differential, sections=sections))
    # Section A got much more lift than B, so effective_sis(a) > effective_sis(b).
    effective = {
        row["section_id"]: row["effective_sis"]
        for row in result.evidence["per_section"]
    }
    assert effective["a"] > effective["b"]