tenseleyflow/sway / cb92c9f

Browse files

sway(probes): C2 calibration_drift + 30-item built-in general-knowledge pack

Authored by espadonne
SHA
cb92c9f865859e1a2dbb6a0686ad8db75aa52e41
Parents
f0b5d50
Tree
be99664

3 changed files

Status | File | Added | Removed
A src/dlm_sway/probes/_calibration_pack.py 63 0
A src/dlm_sway/probes/calibration_drift.py 135 0
A tests/unit/test_probe_calibration_drift.py 57 0
src/dlm_sway/probes/_calibration_pack.pyadded
@@ -0,0 +1,63 @@
1
+"""A small, built-in general-knowledge probe pack for C2.
2
+
3
+Each item is a ``(prompt, gold)`` pair where ``gold`` is the next few
4
+tokens a competent base model should assign high probability to. The
5
+items are deliberately *factually trivial* — the point isn't "does the
6
+model know this?" but "did the fine-tune forget this?" — so the pack
7
+skews toward grade-school geography, chemistry, arithmetic, and
8
+high-frequency idiom.
9
+
10
+A real v1.0 will ship a 200-item pack sliced from TriviaQA + SQuAD +
11
+OpenBookQA. This 30-item seed lets the probe ship today and catches the
12
+most egregious over-fit cases.
13
+"""
14
+
15
+from __future__ import annotations
16
+
17
+from typing import Final
18
+
19
# One probe item: (prompt, gold continuation) pair.
CalibrationItem = tuple[str, str]

# NOTE: every gold string carries a leading space because it continues the
# prompt as next-token text; dropping the space would change tokenization.
BUILT_IN_PACK: Final[tuple[CalibrationItem, ...]] = (
    # Geography
    ("The capital of France is", " Paris"),
    ("The capital of Japan is", " Tokyo"),
    ("The largest ocean on Earth is the", " Pacific"),
    ("Mount Everest is located on the border of Nepal and", " China"),
    ("The longest river in South America is the", " Amazon"),
    # Natural sciences
    ("Water freezes at zero degrees", " Celsius"),
    ("The chemical symbol for gold is", " Au"),
    ("Light travels faster than", " sound"),
    ("Plants convert sunlight into energy through", " photosynthesis"),
    ("The Earth orbits around the", " Sun"),
    # Arithmetic
    ("Two plus two equals", " four"),
    ("Ten times ten equals", " one hundred"),
    ("Half of one hundred is", " fifty"),
    ("A dozen means", " twelve"),
    # Language and idiom
    ("A rose by any other name would smell as", " sweet"),
    ("To be or not to be, that is the", " question"),
    ("The early bird catches the", " worm"),
    ("Actions speak louder than", " words"),
    ("A picture is worth a thousand", " words"),
    # History
    ("World War II ended in the year", " 1945"),
    ("The first president of the United States was", " George Washington"),
    ("The Berlin Wall fell in", " 1989"),
    # Biology
    ("Humans have twenty", " fingers and toes"),
    ("The human body has two", " lungs"),
    ("Blood is pumped through the body by the", " heart"),
    # Technology
    ("HTML stands for HyperText", " Markup Language"),
    ("The World Wide Web was invented by Tim", " Berners-Lee"),
    # Miscellaneous
    ("One year has", " 365 days"),
    ("A week has seven", " days"),
    ("There are seven colors in a", " rainbow"),
)
"""30 items covering geography, science, arithmetic, language, history,
biology, and technology. Pulled from public-domain grade-school facts so
there's no licensing concern about shipping with the wheel."""
src/dlm_sway/probes/calibration_drift.pyadded
@@ -0,0 +1,135 @@
1
+"""C2 CalibrationDrift — did we break general knowledge while fitting the doc?
2
+
3
+The classic small-doc fine-tune failure mode: the adapter learned the
4
+document so well that it forgot the world. C2 catches this by scoring
5
+base and ft on a packaged set of general-knowledge completions (the
6
+``BUILT_IN_PACK`` — a 30-item seed of public-domain grade-school facts)
7
+and flagging items whose per-token logprob regressed significantly.
8
+
9
+A healthy fine-tune: some items drift slightly (mild confidence shift,
10
+normal), but essentially none regress below a nat of slack. An over-fit
11
+fine-tune: 20%+ of items regress, the adapter has torched its ability
12
+to answer anything outside the document.
13
+
14
+Pass when ``fraction_regressed < assert_fraction_regressed_lt`` AND
15
+``mean_delta_nats >= assert_mean_delta_gte``. Both thresholds default
16
+to values that trigger on genuine damage but tolerate normal drift.
17
+"""
18
+
19
+from __future__ import annotations
20
+
21
+import statistics
22
+from typing import Literal
23
+
24
+from pydantic import Field
25
+
26
+from dlm_sway.core.result import ProbeResult, Verdict
27
+from dlm_sway.probes._calibration_pack import BUILT_IN_PACK
28
+from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
29
+
30
+
31
class CalibrationItemSpec(ProbeSpec):
    """Not used directly — documents the shape of an item override.

    Mirrors one ``(prompt, gold)`` entry of the calibration pack so that
    tooling can introspect the expected item schema.
    """

    kind: Literal["__calibration_item"] = "__calibration_item"
    # Prompt text the model is asked to continue.
    prompt: str = ""
    # Expected continuation the model should assign high probability to.
    gold: str = ""
37
+
38
+
39
class CalibrationDriftSpec(ProbeSpec):
    """Configuration for the C2 calibration-drift probe.

    The probe passes only when the fraction of regressed items stays
    strictly below ``assert_fraction_regressed_lt`` AND the mean
    per-token delta stays at or above ``assert_mean_delta_gte``.
    """

    kind: Literal["calibration_drift"] = "calibration_drift"
    pack: Literal["builtin"] = "builtin"
    """Source of items. ``"builtin"`` uses :data:`BUILT_IN_PACK`. Custom
    packs will ship via a file reference in a later milestone."""
    items_limit: int | None = None
    """If set, truncate the pack to this many items (for fast runs)."""
    # Fail threshold on the regressed fraction; the comparison is strict (<).
    assert_fraction_regressed_lt: float = 0.15
    assert_mean_delta_gte: float = -0.5
    """Mean per-token logprob delta (ft − base) across the pack. Slightly
    negative is tolerable; deeply negative is not."""
    regression_nats: float = 1.0
    """How many nats worse an item must get to count as regressed."""
    items: list[tuple[str, str]] = Field(default_factory=list)
    """Optional inline override of the packaged items."""
54
+
55
+
56
class CalibrationDriftProbe(Probe):
    """C2 — flag general-knowledge completions the fine-tune regressed.

    Scores every pack item under the base and fine-tuned models, computes
    the per-token logprob delta (ft − base), and fails when too many items
    regressed or the mean delta is too negative (thresholds on the spec).
    """

    kind = "calibration_drift"
    spec_cls = CalibrationDriftSpec
    category = "calibration"

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Run the calibration-drift check over the configured item pack.

        Args:
            spec: Must be a :class:`CalibrationDriftSpec`.
            ctx: Run context providing the differential backend.

        Returns:
            A :class:`ProbeResult` whose ``raw`` is the fraction of
            regressed items and whose ``ft_value`` is the mean per-token
            delta in nats.

        Raises:
            TypeError: If *spec* is not a ``CalibrationDriftSpec``.
        """
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, which would let a mismatched spec slip through.
        if not isinstance(spec, CalibrationDriftSpec):
            raise TypeError(
                f"expected CalibrationDriftSpec, got {type(spec).__name__}"
            )
        items = list(spec.items) if spec.items else list(BUILT_IN_PACK)
        if spec.items_limit is not None:
            items = items[: spec.items_limit]
        if not items:
            # An empty pack is a configuration error, not a pass or a fail.
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message="no calibration items",
            )

        deltas: list[float] = []
        regressed = 0
        worst: list[dict[str, float | str]] = []

        for prompt, gold in items:
            # Normalize per estimated token so long gold strings don't
            # dominate. _token_estimate() already floors at 1, so the
            # division is always safe — no extra max() guard needed.
            tokens = _token_estimate(gold)
            with ctx.backend.as_base() as b:
                lp_base = b.logprob_of(prompt, gold) / tokens
            with ctx.backend.as_finetuned() as f:
                lp_ft = f.logprob_of(prompt, gold) / tokens
            delta = lp_ft - lp_base  # negative => fine-tune got worse here
            deltas.append(delta)
            if delta < -spec.regression_nats:
                regressed += 1
                worst.append({"prompt": prompt, "gold": gold, "delta": delta})

        # Surface the worst offenders — up to 5, most-regressed first.
        worst.sort(key=lambda d: float(d["delta"]))
        worst = worst[:5]

        frac_regressed = regressed / len(items)
        mean_delta = statistics.fmean(deltas)

        passed = (
            frac_regressed < spec.assert_fraction_regressed_lt
            and mean_delta >= spec.assert_mean_delta_gte
        )
        verdict = Verdict.PASS if passed else Verdict.FAIL
        # Score: 1.0 at zero regression + zero drift, declining with either.
        regress_component = max(
            0.0, 1.0 - frac_regressed / max(spec.assert_fraction_regressed_lt, 1e-6)
        )
        drift_component = max(0.0, min(1.0, (mean_delta + 1.0) / 1.5))
        score = 0.6 * regress_component + 0.4 * drift_component

        return ProbeResult(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=frac_regressed,
            base_value=None,
            ft_value=mean_delta,
            evidence={
                "fraction_regressed": frac_regressed,
                "mean_delta_nats": mean_delta,
                "regressed_count": regressed,
                "total_items": len(items),
                "worst_offenders": worst,
                "regression_nats_threshold": spec.regression_nats,
                "weight": spec.weight,
            },
            message=(
                f"{regressed}/{len(items)} items regressed >{spec.regression_nats:.1f} nats "
                f"(frac={frac_regressed:.1%}), mean_delta={mean_delta:+.3f} nats/tok"
            ),
        )
132
+
133
+
134
+def _token_estimate(s: str) -> int:
135
+    return max(1, len(s) // 4)
tests/unit/test_probe_calibration_drift.pyadded
@@ -0,0 +1,57 @@
1
+"""Tests for :mod:`dlm_sway.probes.calibration_drift`."""
2
+
3
+from __future__ import annotations
4
+
5
+from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
6
+from dlm_sway.core.result import Verdict
7
+from dlm_sway.probes._calibration_pack import BUILT_IN_PACK
8
+from dlm_sway.probes.base import RunContext, build_probe
9
+
10
+
11
def _backend(delta_per_token: float) -> DummyDifferentialBackend:
    """Apply a uniform per-token logprob delta across every item."""
    # Estimated token count per item (same ~4 chars/token heuristic as C2).
    n_tokens = {(p, g): max(len(g) // 4, 1) for p, g in BUILT_IN_PACK}
    base_lp = {key: -5.0 * n for key, n in n_tokens.items()}
    ft_lp = {key: base_lp[key] + delta_per_token * n for key, n in n_tokens.items()}
    return DummyDifferentialBackend(
        base=DummyResponses(logprobs=base_lp),
        ft=DummyResponses(logprobs=ft_lp),
    )
22
+
23
+
24
class TestCalibrationDrift:
    def _run(self, backend, **spec_overrides):
        """Build a C2 probe with *spec_overrides* and run it on *backend*."""
        probe, spec = build_probe(
            {"name": "c2", "kind": "calibration_drift", **spec_overrides}
        )
        return probe.run(spec, RunContext(backend=backend))

    def test_healthy_when_no_regression(self) -> None:
        # Zero drift on every item: nothing regresses, probe passes.
        result = self._run(_backend(delta_per_token=0.0))
        assert result.verdict == Verdict.PASS
        assert result.raw == 0.0  # zero fraction regressed

    def test_fail_on_uniform_large_regression(self) -> None:
        # Every item loses 2 nats/token: all regress, probe fails.
        result = self._run(_backend(delta_per_token=-2.0))
        assert result.verdict == Verdict.FAIL
        assert result.raw == 1.0

    def test_respects_items_limit(self) -> None:
        result = self._run(_backend(delta_per_token=0.0), items_limit=5)
        assert result.evidence["total_items"] == 5

    def test_worst_offenders_reported(self) -> None:
        result = self._run(_backend(delta_per_token=-2.0))
        worst = result.evidence["worst_offenders"]
        assert len(worst) <= 5
        # Each worst-offender record carries prompt/gold/delta fields.
        if worst:
            assert {"prompt", "gold", "delta"} <= set(worst[0].keys())