sway(probes): C2 calibration_drift + 30-item built-in general-knowledge pack
- SHA
cb92c9f865859e1a2dbb6a0686ad8db75aa52e41- Parents
-
f0b5d50 - Tree
be99664
cb92c9f
cb92c9f865859e1a2dbb6a0686ad8db75aa52e41f0b5d50
be99664

| Status | File | + | - |
|---|---|---|---|
| A |
src/dlm_sway/probes/_calibration_pack.py
|
63 | 0 |
| A |
src/dlm_sway/probes/calibration_drift.py
|
135 | 0 |
| A |
tests/unit/test_probe_calibration_drift.py
|
57 | 0 |
src/dlm_sway/probes/_calibration_pack.pyadded@@ -0,0 +1,63 @@ | ||
| 1 | +"""A small, built-in general-knowledge probe pack for C2. | |
| 2 | + | |
| 3 | +Each item is a ``(prompt, gold)`` pair where ``gold`` is the next few | |
| 4 | +tokens a competent base model should assign high probability to. The | |
| 5 | +items are deliberately *factually trivial* — the point isn't "does the | |
| 6 | +model know this?" but "did the fine-tune forget this?" — so the pack | |
| 7 | +skews toward grade-school geography, chemistry, arithmetic, and | |
| 8 | +high-frequency idiom. | |
| 9 | + | |
| 10 | +A real v1.0 will ship a 200-item pack sliced from TriviaQA + SQuAD + | |
| 11 | +OpenBookQA. This 30-item seed lets the probe ship today and catches the | |
| 12 | +most egregious over-fit cases. | |
| 13 | +""" | |
| 14 | + | |
| 15 | +from __future__ import annotations | |
| 16 | + | |
| 17 | +from typing import Final | |
| 18 | + | |
# One calibration item: a (prompt, gold) pair of plain strings.
CalibrationItem = tuple[str, str]

# Immutable built-in pack, grouped by topic via the inline section comments.
# Every gold string below starts with a single space.
# NOTE(review): presumably the leading space is there so the gold reads (and
# tokenizes) as a natural continuation of the prompt — confirm against the
# backend's ``logprob_of`` contract.
BUILT_IN_PACK: Final[tuple[CalibrationItem, ...]] = (
    # Geography
    ("The capital of France is", " Paris"),
    ("The capital of Japan is", " Tokyo"),
    ("The largest ocean on Earth is the", " Pacific"),
    ("Mount Everest is located on the border of Nepal and", " China"),
    ("The longest river in South America is the", " Amazon"),
    # Natural sciences
    ("Water freezes at zero degrees", " Celsius"),
    ("The chemical symbol for gold is", " Au"),
    ("Light travels faster than", " sound"),
    ("Plants convert sunlight into energy through", " photosynthesis"),
    ("The Earth orbits around the", " Sun"),
    # Arithmetic
    ("Two plus two equals", " four"),
    ("Ten times ten equals", " one hundred"),
    ("Half of one hundred is", " fifty"),
    ("A dozen means", " twelve"),
    # Language and idiom
    ("A rose by any other name would smell as", " sweet"),
    ("To be or not to be, that is the", " question"),
    ("The early bird catches the", " worm"),
    ("Actions speak louder than", " words"),
    ("A picture is worth a thousand", " words"),
    # History
    ("World War II ended in the year", " 1945"),
    ("The first president of the United States was", " George Washington"),
    ("The Berlin Wall fell in", " 1989"),
    # Biology
    ("Humans have twenty", " fingers and toes"),
    ("The human body has two", " lungs"),
    ("Blood is pumped through the body by the", " heart"),
    # Technology
    ("HTML stands for HyperText", " Markup Language"),
    ("The World Wide Web was invented by Tim", " Berners-Lee"),
    # Miscellaneous
    ("One year has", " 365 days"),
    ("A week has seven", " days"),
    ("There are seven colors in a", " rainbow"),
)
"""30 items covering geography, science, arithmetic, language, history,
biology, and technology. Pulled from public-domain grade-school facts so
there's no licensing concern about shipping with the wheel."""
src/dlm_sway/probes/calibration_drift.pyadded@@ -0,0 +1,135 @@ | ||
| 1 | +"""C2 CalibrationDrift — did we break general knowledge while fitting the doc? | |
| 2 | + | |
| 3 | +The classic small-doc fine-tune failure mode: the adapter learned the | |
| 4 | +document so well that it forgot the world. C2 catches this by scoring | |
| 5 | +base and ft on a packaged set of general-knowledge completions (the | |
| 6 | +``BUILT_IN_PACK`` — a 30-item seed of public-domain grade-school facts) | |
| 7 | +and flagging items whose per-token logprob regressed significantly. | |
| 8 | + | |
| 9 | +A healthy fine-tune: some items drift slightly (a mild confidence shift is | |
| 10 | +normal), but essentially none lose more than ``regression_nats`` (default | |
| 11 | +1.0) nats per token. An over-fit fine-tune: 20%+ of items regress because the | |
| 12 | +adapter has torched its ability to answer anything outside the document. | |
| 13 | + | |
| 14 | +Pass when ``fraction_regressed < assert_fraction_regressed_lt`` AND | |
| 15 | +``mean_delta_nats >= assert_mean_delta_gte``. Both thresholds default | |
| 16 | +to values that trigger on genuine damage but tolerate normal drift. | |
| 17 | +""" | |
| 18 | + | |
| 19 | +from __future__ import annotations | |
| 20 | + | |
| 21 | +import statistics | |
| 22 | +from typing import Literal | |
| 23 | + | |
| 24 | +from pydantic import Field | |
| 25 | + | |
| 26 | +from dlm_sway.core.result import ProbeResult, Verdict | |
| 27 | +from dlm_sway.probes._calibration_pack import BUILT_IN_PACK | |
| 28 | +from dlm_sway.probes.base import Probe, ProbeSpec, RunContext | |
| 29 | + | |
| 30 | + | |
class CalibrationItemSpec(ProbeSpec):
    """Not used directly — documents the shape of an item override."""

    # Nothing in this module instantiates this class; inline overrides on
    # ``CalibrationDriftSpec.items`` are plain ``(prompt, gold)`` tuples.
    kind: Literal["__calibration_item"] = "__calibration_item"
    # Text the completion is scored against (first element of an item pair).
    prompt: str = ""
    # Expected continuation of ``prompt`` (second element of an item pair).
    gold: str = ""
| 37 | + | |
| 38 | + | |
class CalibrationDriftSpec(ProbeSpec):
    # Configuration for the calibration-drift probe. Every field declared here
    # has a default, so none of them needs to be supplied explicitly.
    kind: Literal["calibration_drift"] = "calibration_drift"
    pack: Literal["builtin"] = "builtin"
    """Source of items. ``"builtin"`` uses :data:`BUILT_IN_PACK`. Custom
    packs will ship via a file reference in a later milestone."""
    items_limit: int | None = None
    """If set, truncate the pack to this many items (for fast runs)."""
    # Exclusive upper bound: the probe passes only while the fraction of
    # regressed items is strictly below this value.
    assert_fraction_regressed_lt: float = 0.15
    # Inclusive lower bound on the mean delta for the probe to pass.
    assert_mean_delta_gte: float = -0.5
    """Mean per-token logprob delta (ft − base) across the pack. Slightly
    negative is tolerable; deeply negative is not."""
    # An item counts as regressed when its delta is strictly below
    # ``-regression_nats``.
    regression_nats: float = 1.0
    """How many nats worse an item must get to count as regressed."""
    # A non-empty list replaces the built-in pack entirely; the default empty
    # list means "use BUILT_IN_PACK".
    items: list[tuple[str, str]] = Field(default_factory=list)
    """Optional inline override of the packaged items."""
| 54 | + | |
| 55 | + | |
class CalibrationDriftProbe(Probe):
    """Probe that scores base vs. fine-tuned logprobs on calibration items.

    For every ``(prompt, gold)`` item it asks the backend for
    ``logprob_of(prompt, gold)`` under the base view and under the
    fine-tuned view, normalizes both by an estimated token count, and
    tracks the difference (ft − base).
    """

    kind = "calibration_drift"
    spec_cls = CalibrationDriftSpec
    category = "calibration"

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Run the probe and summarize how many items regressed.

        Args:
            spec: Must be a :class:`CalibrationDriftSpec`.
            ctx: Run context; only ``ctx.backend`` is used, through its
                ``as_base()`` and ``as_finetuned()`` context managers.

        Returns:
            A :class:`ProbeResult` whose verdict is ``ERROR`` when there
            are no items to score, ``PASS`` when the regressed fraction is
            strictly below ``spec.assert_fraction_regressed_lt`` and the
            mean delta is at least ``spec.assert_mean_delta_gte``, and
            ``FAIL`` otherwise. ``raw`` carries the regressed fraction and
            ``ft_value`` the mean delta.
        """
        assert isinstance(spec, CalibrationDriftSpec)
        items = list(spec.items) if spec.items else list(BUILT_IN_PACK)
        if spec.items_limit is not None:
            # Clamp at zero: a negative limit used as a raw slice bound would
            # drop items from the END of the pack instead of limiting it.
            items = items[: max(spec.items_limit, 0)]
        if not items:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message="no calibration items",
            )

        deltas: list[float] = []
        offenders: list[dict[str, float | str]] = []

        for prompt, gold in items:
            # _token_estimate never returns less than 1, so dividing is safe.
            tokens = _token_estimate(gold)
            with ctx.backend.as_base() as b:
                lp_base = b.logprob_of(prompt, gold) / tokens
            with ctx.backend.as_finetuned() as f:
                lp_ft = f.logprob_of(prompt, gold) / tokens
            delta = lp_ft - lp_base
            deltas.append(delta)
            if delta < -spec.regression_nats:
                offenders.append({"prompt": prompt, "gold": gold, "delta": delta})

        regressed = len(offenders)
        # Surface the worst offenders — most negative delta first, up to 5.
        worst = sorted(offenders, key=lambda d: float(d["delta"]))[:5]

        frac_regressed = regressed / len(items)
        mean_delta = statistics.fmean(deltas)

        passed = (
            frac_regressed < spec.assert_fraction_regressed_lt
            and mean_delta >= spec.assert_mean_delta_gte
        )
        verdict = Verdict.PASS if passed else Verdict.FAIL
        # Score = 0.6 * regression component + 0.4 * drift component.
        # Regression component: 1.0 with no regressed items, falling linearly
        # to 0.0 once frac_regressed reaches the assert threshold.
        # Drift component: linear ramp from 0.0 at mean_delta <= -1.0 up to
        # 1.0 at mean_delta >= +0.5. Zero drift therefore sits at ~0.667 on
        # the ramp, so an unchanged model scores ~0.867 rather than 1.0.
        regress_component = max(
            0.0, 1.0 - frac_regressed / max(spec.assert_fraction_regressed_lt, 1e-6)
        )
        drift_component = max(0.0, min(1.0, (mean_delta + 1.0) / 1.5))
        score = 0.6 * regress_component + 0.4 * drift_component

        return ProbeResult(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=frac_regressed,
            base_value=None,
            ft_value=mean_delta,
            evidence={
                "fraction_regressed": frac_regressed,
                "mean_delta_nats": mean_delta,
                "regressed_count": regressed,
                "total_items": len(items),
                "worst_offenders": worst,
                "regression_nats_threshold": spec.regression_nats,
                "weight": spec.weight,
            },
            message=(
                f"{regressed}/{len(items)} items regressed >{spec.regression_nats:.1f} nats "
                f"(frac={frac_regressed:.1%}), mean_delta={mean_delta:+.3f} nats/tok"
            ),
        )
| 132 | + | |
| 133 | + | |
| 134 | +def _token_estimate(s: str) -> int: | |
| 135 | + return max(1, len(s) // 4) | |
tests/unit/test_probe_calibration_drift.pyadded@@ -0,0 +1,57 @@ | ||
| 1 | +"""Tests for :mod:`dlm_sway.probes.calibration_drift`.""" | |
| 2 | + | |
| 3 | +from __future__ import annotations | |
| 4 | + | |
| 5 | +from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses | |
| 6 | +from dlm_sway.core.result import Verdict | |
| 7 | +from dlm_sway.probes._calibration_pack import BUILT_IN_PACK | |
| 8 | +from dlm_sway.probes.base import RunContext, build_probe | |
| 9 | + | |
| 10 | + | |
def _backend(delta_per_token: float) -> DummyDifferentialBackend:
    """Build a dummy backend whose ft logprobs are offset from base per token.

    Base logprobs are ``-5.0`` per estimated token of the gold string; the
    fine-tuned logprobs add ``delta_per_token`` per estimated token on top.
    """
    base_scores: dict[tuple[str, str], float] = {}
    tuned_scores: dict[tuple[str, str], float] = {}
    for item in BUILT_IN_PACK:
        _, gold = item
        # Same length-based token estimate the test originally inlined twice.
        n_tok = max(len(gold) // 4, 1)
        base_scores[item] = -5.0 * n_tok
        tuned_scores[item] = base_scores[item] + delta_per_token * n_tok
    return DummyDifferentialBackend(
        base=DummyResponses(logprobs=base_scores),
        ft=DummyResponses(logprobs=tuned_scores),
    )
| 22 | + | |
| 23 | + | |
class TestCalibrationDrift:
    """Behavioral tests for the ``calibration_drift`` probe on the dummy backend."""

    @staticmethod
    def _run_probe(delta_per_token: float, **spec_overrides: object):
        """Build the probe with optional spec overrides and run it once."""
        backend = _backend(delta_per_token=delta_per_token)
        probe, spec = build_probe({"name": "c2", "kind": "calibration_drift", **spec_overrides})
        ctx = RunContext(backend=backend)
        return probe.run(spec, ctx)

    def test_healthy_when_no_regression(self) -> None:
        result = self._run_probe(0.0)  # no drift
        assert result.verdict == Verdict.PASS
        assert result.raw == 0.0  # zero fraction regressed

    def test_fail_on_uniform_large_regression(self) -> None:
        result = self._run_probe(-2.0)  # every item regresses
        assert result.verdict == Verdict.FAIL
        assert result.raw == 1.0

    def test_respects_items_limit(self) -> None:
        result = self._run_probe(0.0, items_limit=5)
        assert result.evidence["total_items"] == 5

    def test_worst_offenders_reported(self) -> None:
        result = self._run_probe(-2.0)
        worst = result.evidence["worst_offenders"]
        # Every item regresses at -2.0 nats/token, so the list must be filled
        # up to its cap of 5 — an empty list here would be a bug, not a pass.
        assert len(worst) == 5
        # Each worst-offender record carries prompt/gold/delta fields.
        for record in worst:
            assert {"prompt", "gold", "delta"} <= set(record.keys())
            assert record["delta"] < 0.0