Python · 3674 bytes Raw Blame History
1 """Tests for :mod:`dlm_sway.probes.calibration_drift`."""
2
3 from __future__ import annotations
4
5 from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
6 from dlm_sway.core.result import Verdict
7 from dlm_sway.probes._calibration_pack import BUILT_IN_PACK
8 from dlm_sway.probes.base import RunContext, build_probe
9
10
def _backend(delta_per_token: float) -> DummyDifferentialBackend:
    """Build a dummy differential backend with a uniform logprob shift.

    Every ``(prompt, gold)`` pair in ``BUILT_IN_PACK`` is given a base
    logprob of ``-5.0`` multiplied by a length-derived count,
    ``max(len(gold) // 4, 1)``. The fine-tuned logprob for the same pair is
    the base value plus ``delta_per_token`` times that same count.

    Args:
        delta_per_token: Amount added to the fine-tuned logprob per counted
            unit of ``gold``; ``0.0`` makes base and fine-tuned identical.

    Returns:
        A ``DummyDifferentialBackend`` wrapping the two logprob tables.
    """

    def _unit_count(gold: str) -> int:
        # One unit per four characters of gold, never fewer than one.
        return max(len(gold) // 4, 1)

    base_scores: dict[tuple[str, str], float] = {
        (prompt, gold): -5.0 * _unit_count(gold) for prompt, gold in BUILT_IN_PACK
    }
    # Derive the fine-tuned table from the base table so both share keys.
    ft_scores: dict[tuple[str, str], float] = {
        pair: base_score + delta_per_token * _unit_count(pair[1])
        for pair, base_score in base_scores.items()
    }
    return DummyDifferentialBackend(
        base=DummyResponses(logprobs=base_scores),
        ft=DummyResponses(logprobs=ft_scores),
    )
22
23
class TestCalibrationDrift:
    """Run the ``calibration_drift`` probe against dummy backends."""

    def test_healthy_when_no_regression(self) -> None:
        """With a zero delta the probe passes and reports a raw value of 0.0."""
        unchanged = _backend(delta_per_token=0.0)  # no drift
        probe, spec = build_probe({"name": "c2", "kind": "calibration_drift"})
        outcome = probe.run(spec, RunContext(backend=unchanged))
        assert outcome.verdict == Verdict.PASS
        assert outcome.raw == 0.0  # zero fraction regressed

    def test_fail_on_uniform_large_regression(self) -> None:
        """With a -2.0 delta on every item the probe fails with raw 1.0."""
        regressed = _backend(delta_per_token=-2.0)  # every item regresses
        probe, spec = build_probe({"name": "c2", "kind": "calibration_drift"})
        outcome = probe.run(spec, RunContext(backend=regressed))
        assert outcome.verdict == Verdict.FAIL
        assert outcome.raw == 1.0

    def test_respects_items_limit(self) -> None:
        """``items_limit=5`` yields ``total_items == 5`` in the evidence."""
        unchanged = _backend(delta_per_token=0.0)
        probe, spec = build_probe(
            {"name": "c2", "kind": "calibration_drift", "items_limit": 5}
        )
        outcome = probe.run(spec, RunContext(backend=unchanged))
        assert outcome.evidence["total_items"] == 5

    def test_worst_offenders_reported(self) -> None:
        """Evidence lists at most five worst offenders with the expected keys."""
        regressed = _backend(delta_per_token=-2.0)
        probe, spec = build_probe({"name": "c2", "kind": "calibration_drift"})
        outcome = probe.run(spec, RunContext(backend=regressed))
        offenders = outcome.evidence["worst_offenders"]
        assert len(offenders) <= 5
        # Each worst-offender record carries prompt/gold/delta fields.
        if offenders:
            required_fields = {"prompt", "gold", "delta"}
            assert required_fields <= set(offenders[0].keys())
58
59
class TestPackContract:
    """B12: the built-in pack holds exactly 200 well-formed items."""

    def test_pack_size_is_200(self) -> None:
        """The built-in pack contains exactly 200 entries."""
        pack_size = len(BUILT_IN_PACK)
        assert pack_size == 200

    def test_pack_items_are_well_formed(self) -> None:
        """Every entry is a two-element tuple of non-empty strings."""
        for idx, entry in enumerate(BUILT_IN_PACK):
            # Check the container shape first so the unpacking below is safe.
            assert isinstance(entry, tuple), f"item {idx}: not a tuple"
            assert len(entry) == 2, f"item {idx}: not a pair"
            prompt, gold = entry
            assert isinstance(prompt, str), f"item {idx}: prompt not str"
            assert prompt, f"item {idx}: empty prompt"
            assert isinstance(gold, str), f"item {idx}: gold not str"
            assert gold, f"item {idx}: empty gold"

    def test_items_limit_subsets_pack(self) -> None:
        """``items_limit=50`` yields ``total_items == 50`` in the evidence."""
        unchanged = _backend(delta_per_token=0.0)
        probe, spec = build_probe(
            {"name": "c2", "kind": "calibration_drift", "items_limit": 50}
        )
        outcome = probe.run(spec, RunContext(backend=unchanged))
        assert outcome.evidence["total_items"] == 50
83 assert result.evidence["total_items"] == 50