Python · 10832 bytes Raw Blame History
1 """Tests for :mod:`dlm_sway.probes.external_perplexity`."""
2
3 from __future__ import annotations
4
5 import numpy as np
6 import pytest
7
8 from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
9 from dlm_sway.core.result import Verdict
10 from dlm_sway.core.scoring import RollingLogprob
11 from dlm_sway.probes._external_corpus import (
12 available_corpora,
13 chunk_corpus,
14 load_corpus,
15 )
16 from dlm_sway.probes.base import RunContext, build_probe
17 from dlm_sway.probes.external_perplexity import ExternalPerplexitySpec
18
19
def _rolling(text: str, per_tok: float) -> RollingLogprob:
    """Synthesize a ``RollingLogprob`` whose every scored position is ``per_tok``.

    ``text`` is split on whitespace only to decide how many tokens to
    report; the token ids are simply ``0..n-1``. One logprob is emitted per
    token after the first, so whenever ``text`` has at least two tokens the
    mean of ``logprobs`` equals ``per_tok``.

    Args:
        text: Chunk text; only its whitespace-token count is used.
        per_tok: Logprob assigned to every scored position.

    Returns:
        A ``RollingLogprob`` with ``num_tokens >= 1`` ids, ``num_tokens - 1``
        float32 logprobs, and ``total_logprob == per_tok * (num_tokens - 1)``.
    """
    # Report at least one token so empty text still yields a well-formed object.
    num_tokens = max(1, len(text.split()))
    # Only n - 1 positions carry a logprob (presumably the first token has
    # no left context to be scored against). num_tokens >= 1, so this is >= 0.
    scored = num_tokens - 1
    per_position = np.full(scored, per_tok, dtype=np.float32)
    return RollingLogprob(
        token_ids=np.arange(num_tokens, dtype=np.int64),
        logprobs=per_position,
        num_tokens=num_tokens,
        total_logprob=float(per_tok * scored),
    )
31
32
def _backend_with_delta(
    base_per_tok: float,
    ft_per_tok: float,
    *,
    chunk_chars: int = 2048,
    max_chunks: int = 16,
) -> DummyDifferentialBackend:
    """Return a dummy backend whose ft - base shift is uniform across chunks.

    The base view scores every chunk of the ``public_domain_en`` corpus at
    ``base_per_tok`` per token, and the fine-tuned view at ``ft_per_tok``.
    The per-chunk (and therefore mean) delta is then exactly
    ``ft_per_tok - base_per_tok`` nats/token, up to float32 rounding.

    The canned ``rolling`` maps are keyed by exact chunk text. We therefore
    chunk the corpus here with ``chunk_corpus`` and store one entry per
    chunk. This only lines up with what the probe scores if the probe chunks
    the corpus the same way:

    * ``chunk_chars`` must equal the probe's chunk size. It is assumed to be
      2048 by default; keep it in sync with ``ExternalPerplexitySpec``.
    * ``max_chunks`` must be at least the probe's ``max_chunks``. This
      assumes ``chunk_corpus`` returns chunks in corpus order, so that a
      smaller cap yields a prefix of this list.

    Any chunk text missing from the maps is instead scored by the dummy
    backend's synthesized rolling logprob (reportedly mode defaults of
    -2.0 / -1.5). That would silently replace the requested delta.

    Args:
        base_per_tok: Per-token logprob served by the base view.
        ft_per_tok: Per-token logprob served by the fine-tuned view.
        chunk_chars: Chunk size used to pre-chunk the corpus (keyword-only).
        max_chunks: Maximum number of chunks to pre-register (keyword-only).

    Returns:
        A ``DummyDifferentialBackend`` with canned rolling logprobs for
        every pre-registered chunk.
    """
    corpus = load_corpus("public_domain_en")
    chunks = chunk_corpus(corpus, chunk_chars=chunk_chars, max_chunks=max_chunks)
    base_map = {chunk: _rolling(chunk, base_per_tok) for chunk in chunks}
    ft_map = {chunk: _rolling(chunk, ft_per_tok) for chunk in chunks}
    return DummyDifferentialBackend(
        base=DummyResponses(rolling=base_map),
        ft=DummyResponses(rolling=ft_map),
    )
51
52
class TestCorpusLoader:
    """Unit tests for the external-corpus registry, loader and chunker."""

    def test_public_domain_en_is_available(self) -> None:
        registered = available_corpora()
        assert "public_domain_en" in registered

    def test_load_corpus_strips_comments(self) -> None:
        loaded = load_corpus("public_domain_en")
        # Provenance lines (``# -- source: ...``) exist only in the raw file;
        # they must be gone from the string handed to probes.
        assert "# --" not in loaded
        assert loaded.strip(), "loaded corpus should not be empty"

    def test_load_corpus_unknown_raises(self) -> None:
        with pytest.raises(KeyError):
            load_corpus("not_a_real_corpus")

    def test_chunk_corpus_respects_caps(self) -> None:
        # 10_000 chars would yield more than 4 chunks; the max_chunks cap wins.
        pieces = chunk_corpus("A" * 10_000, chunk_chars=1024, max_chunks=4)
        assert len(pieces) == 4
        assert {len(piece) for piece in pieces} == {1024}

    def test_chunk_corpus_drops_short_tail(self) -> None:
        # 2100 = 2 * 1024 + 52. The 52-char remainder is under the 64-char
        # floor, so only the two full chunks survive.
        pieces = chunk_corpus("A" * 2100, chunk_chars=1024, max_chunks=16)
        assert len(pieces) == 2

    def test_chunk_corpus_keeps_long_tail(self) -> None:
        # 2200 = 2 * 1024 + 152. The 152-char remainder clears the 64-char
        # floor, so it is kept as a third chunk.
        pieces = chunk_corpus("A" * 2200, chunk_chars=1024, max_chunks=16)
        assert len(pieces) == 3

    @pytest.mark.parametrize("bad", [0, -1])
    def test_chunk_corpus_rejects_nonpositive_chunk_chars(self, bad: int) -> None:
        with pytest.raises(ValueError):
            chunk_corpus("hello world", chunk_chars=bad, max_chunks=4)

    @pytest.mark.parametrize("bad", [0, -1])
    def test_chunk_corpus_rejects_nonpositive_max_chunks(self, bad: int) -> None:
        with pytest.raises(ValueError):
            chunk_corpus("hello world", chunk_chars=1024, max_chunks=bad)
97
98
class TestExternalPerplexityProbe:
    """Verdicts and evidence of the ``external_perplexity`` probe on canned logprobs."""

    def test_pass_when_ft_matches_base(self) -> None:
        """No perplexity shift → mean_delta ≈ 0 → fixed-threshold PASS."""
        backend = _backend_with_delta(base_per_tok=-2.0, ft_per_tok=-2.0)
        probe, spec = build_probe(
            {"name": "ext_ppl", "kind": "external_perplexity", "max_chunks": 4}
        )
        outcome = probe.run(spec, RunContext(backend=backend))

        assert outcome.verdict == Verdict.PASS
        assert outcome.raw is not None
        assert abs(outcome.raw) < 1e-6
        # No null stats were supplied, so the message must flag the
        # missing calibration.
        assert "no calibration" in (outcome.message or "").lower()

    def test_pass_when_ft_improves_base(self) -> None:
        """ft assigns higher logprobs → mean_delta > 0 → PASS."""
        backend = _backend_with_delta(base_per_tok=-2.0, ft_per_tok=-1.5)
        probe, spec = build_probe(
            {"name": "ext_ppl", "kind": "external_perplexity", "max_chunks": 4}
        )
        outcome = probe.run(spec, RunContext(backend=backend))

        assert outcome.verdict == Verdict.PASS
        assert outcome.raw is not None
        assert outcome.raw > 0

    def test_fail_on_large_regression(self) -> None:
        """ft raised perplexity by >0.1 nats/tok → fixed-threshold FAIL."""
        backend = _backend_with_delta(base_per_tok=-2.0, ft_per_tok=-2.5)
        probe, spec = build_probe(
            {"name": "ext_ppl", "kind": "external_perplexity", "max_chunks": 4}
        )
        outcome = probe.run(spec, RunContext(backend=backend))

        assert outcome.verdict == Verdict.FAIL
        assert outcome.raw is not None
        assert outcome.raw < -0.1

    def test_evidence_carries_per_chunk_deltas(self) -> None:
        backend = _backend_with_delta(base_per_tok=-2.0, ft_per_tok=-1.9)
        probe, spec = build_probe(
            {"name": "ext_ppl", "kind": "external_perplexity", "max_chunks": 3}
        )
        evidence = probe.run(spec, RunContext(backend=backend)).evidence

        assert evidence["corpus"] == "public_domain_en"
        assert evidence["num_chunks"] == 3
        deltas = evidence["per_chunk_delta"]
        assert len(deltas) == 3
        # The canned maps give every chunk the same +0.1 nats/tok shift
        # (-1.9 vs -2.0). The tolerance absorbs float32 rounding.
        assert all(abs(delta - 0.1) < 1e-5 for delta in deltas)

    def test_unknown_corpus_errors(self) -> None:
        """Spec validation accepts only declared corpora; bypassing it
        via direct construction yields a clean ERROR."""
        backend = _backend_with_delta(base_per_tok=-2.0, ft_per_tok=-2.0)
        # model_construct skips validation, so an undeclared corpus name
        # reaches the probe at run time.
        bogus_spec = ExternalPerplexitySpec.model_construct(
            name="ext",
            kind="external_perplexity",
            corpus="does_not_exist",  # type: ignore[arg-type]
        )
        # Only the probe instance is needed here; its validated spec is unused.
        probe, _ = build_probe(
            {"name": "ext_ppl", "kind": "external_perplexity", "max_chunks": 4}
        )
        outcome = probe.run(bogus_spec, RunContext(backend=backend))

        assert outcome.verdict == Verdict.ERROR
        assert "unknown corpus" in (outcome.message or "")

    def test_respects_max_chunks(self) -> None:
        backend = _backend_with_delta(base_per_tok=-2.0, ft_per_tok=-2.0)
        probe, spec = build_probe(
            {"name": "ext_ppl", "kind": "external_perplexity", "max_chunks": 2}
        )
        evidence = probe.run(spec, RunContext(backend=backend)).evidence
        assert evidence["num_chunks"] == 2
177
178
class TestCalibrateSpec:
    """The probe's ``calibrate_spec`` hook used for null-adapter calibration."""

    def test_returns_non_none_spec(self) -> None:
        from dlm_sway.probes.external_perplexity import ExternalPerplexityProbe

        ctx = RunContext(
            backend=_backend_with_delta(base_per_tok=-2.0, ft_per_tok=-2.0)
        )
        calibration = ExternalPerplexityProbe.calibrate_spec(ctx)

        assert calibration is not None
        assert calibration.kind == "external_perplexity"
        # Null calibration must stay fast, so the calibration spec scores
        # at most 4 chunks.
        assert calibration.max_chunks <= 4
190
191
class TestNullCalibrationEndToEnd:
    """Suite-runner integration: null calibration feeding ``external_perplexity``."""

    def test_runner_threads_null_stats_to_external_perplexity(self) -> None:
        """null_adapter → external_perplexity: null stats reach the downstream probe.

        Since F02 (Audit 03), the dummy backend's ``rolling_logprob`` does
        not depend on the seed. Every null-calibration run of
        ``external_perplexity`` therefore yields the same raw value, giving a
        genuinely degenerate null (``std == 0``).

        F02 reports such a null as ``degenerate: 1.0`` in the stats dict.
        ``z_score`` then declines to normalize against it, and the probe
        falls back to its fixed threshold.

        The contract pinned here is the runner's threading:

        * Whether or not a z-score is produced, ``null_stats`` must appear
          on the suite result.
        * A degenerate null must surface as the ``(no calibration for ...)``
          tag in the probe's message.

        If threading regressed, ``null_stats`` would be missing from the
        suite result entirely.
        """
        from dlm_sway.suite.runner import run as run_suite
        from dlm_sway.suite.spec import SwaySpec

        backend = _backend_with_delta(base_per_tok=-2.0, ft_per_tok=-1.9)
        suite_entries = [
            {"name": "null", "kind": "null_adapter", "runs": 2, "cache": False},
            {
                "name": "ext",
                "kind": "external_perplexity",
                "max_chunks": 3,
                "assert_mean_delta_gte": -100.0,  # permissive fixed threshold
            },
        ]
        sway_spec = SwaySpec.model_validate(
            {
                "version": 1,
                "models": {
                    "base": {"base": "b"},
                    "ft": {"base": "b", "adapter": "/tmp/a"},
                },
                "suite": suite_entries,
            }
        )
        suite_result = run_suite(sway_spec, backend)

        assert len(suite_result.probes) == 2
        null_result, ext_result = suite_result.probes
        assert null_result.verdict == Verdict.PASS

        # F02: null_stats are threaded into the suite result even for a
        # degenerate null. The probe reads the ``degenerate`` flag and
        # picks the fallback path itself.
        assert "external_perplexity" in suite_result.null_stats
        ext_null = suite_result.null_stats["external_perplexity"]
        assert ext_null.get("degenerate", 0.0) >= 0.5

        # Fixed-threshold fallback: no z-score. The message carries the
        # (no calibration) tag that the S02 report layer uses to annotate
        # the row.
        assert ext_result.z_score is None
        assert "no calibration for external_perplexity" in (ext_result.message or "")
251 assert "no calibration for external_perplexity" in (ext_result.message or "")