| 1 | """Tests for :mod:`dlm_sway.probes.external_perplexity`.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import numpy as np |
| 6 | import pytest |
| 7 | |
| 8 | from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses |
| 9 | from dlm_sway.core.result import Verdict |
| 10 | from dlm_sway.core.scoring import RollingLogprob |
| 11 | from dlm_sway.probes._external_corpus import ( |
| 12 | available_corpora, |
| 13 | chunk_corpus, |
| 14 | load_corpus, |
| 15 | ) |
| 16 | from dlm_sway.probes.base import RunContext, build_probe |
| 17 | from dlm_sway.probes.external_perplexity import ExternalPerplexitySpec |
| 18 | |
| 19 | |
def _rolling(text: str, per_tok: float) -> RollingLogprob:
    """Return a :class:`RollingLogprob` whose per-token mean equals ``per_tok``.

    Token count is approximated by whitespace splitting (floored at 1);
    the logprob array has one fewer entry than the token count, matching
    the rolling-scoring convention.
    """
    n_tokens = max(len(text.split()), 1)
    n_scores = max(n_tokens - 1, 0)
    return RollingLogprob(
        token_ids=np.arange(n_tokens, dtype=np.int64),
        logprobs=np.full(n_scores, per_tok, dtype=np.float32),
        num_tokens=n_tokens,
        total_logprob=float(per_tok * n_scores),
    )
| 31 | |
| 32 | |
def _backend_with_delta(base_per_tok: float, ft_per_tok: float) -> DummyDifferentialBackend:
    """Build a dummy backend whose ft/base per-token delta is ``ft_per_tok - base_per_tok``.

    The canned rolling maps are keyed by raw chunk text, so whichever
    chunk slice the probe selects still hits a preloaded entry rather
    than the synthesized defaults of ``_compute_rolling_logprob``
    (-2.0 / -1.5 per mode). We therefore precompute the corpus chunks
    here and register one entry per chunk in both maps.
    """
    chunks = chunk_corpus(
        load_corpus("public_domain_en"), chunk_chars=2048, max_chunks=16
    )
    return DummyDifferentialBackend(
        base=DummyResponses(rolling={c: _rolling(c, base_per_tok) for c in chunks}),
        ft=DummyResponses(rolling={c: _rolling(c, ft_per_tok) for c in chunks}),
    )
| 51 | |
| 52 | |
class TestCorpusLoader:
    """Unit tests for the corpus loading/chunking helpers."""

    def test_public_domain_en_is_available(self) -> None:
        assert "public_domain_en" in available_corpora()

    def test_load_corpus_strips_comments(self) -> None:
        loaded = load_corpus("public_domain_en")
        # Provenance lines (`# -- source:`) exist in the raw file but
        # must be stripped before the text reaches any probe.
        assert "# --" not in loaded
        assert loaded.strip(), "loaded corpus should not be empty"

    def test_load_corpus_unknown_raises(self) -> None:
        with pytest.raises(KeyError):
            load_corpus("not_a_real_corpus")

    def test_chunk_corpus_respects_caps(self) -> None:
        pieces = chunk_corpus("A" * 10_000, chunk_chars=1024, max_chunks=4)
        assert len(pieces) == 4
        assert all(len(p) == 1024 for p in pieces)

    def test_chunk_corpus_drops_short_tail(self) -> None:
        # 2100 chars / 1024 → two full chunks plus a 52-char remainder,
        # which falls under the 64-char floor and is discarded.
        pieces = chunk_corpus("A" * 2100, chunk_chars=1024, max_chunks=16)
        assert len(pieces) == 2

    def test_chunk_corpus_keeps_long_tail(self) -> None:
        # 2200 chars / 1024 → two full chunks plus a 152-char remainder,
        # which clears the 64-char floor and is retained.
        pieces = chunk_corpus("A" * 2200, chunk_chars=1024, max_chunks=16)
        assert len(pieces) == 3

    @pytest.mark.parametrize("bad", [0, -1])
    def test_chunk_corpus_rejects_nonpositive_chunk_chars(self, bad: int) -> None:
        with pytest.raises(ValueError):
            chunk_corpus("hello world", chunk_chars=bad, max_chunks=4)

    @pytest.mark.parametrize("bad", [0, -1])
    def test_chunk_corpus_rejects_nonpositive_max_chunks(self, bad: int) -> None:
        with pytest.raises(ValueError):
            chunk_corpus("hello world", chunk_chars=1024, max_chunks=bad)
| 97 | |
| 98 | |
class TestExternalPerplexityProbe:
    """Probe-level behavior against canned dummy backends."""

    def test_pass_when_ft_matches_base(self) -> None:
        """Zero perplexity shift → mean_delta ≈ 0 → fixed-threshold PASS."""
        dummy = _backend_with_delta(base_per_tok=-2.0, ft_per_tok=-2.0)
        probe, spec = build_probe(
            {"name": "ext_ppl", "kind": "external_perplexity", "max_chunks": 4}
        )
        outcome = probe.run(spec, RunContext(backend=dummy))
        assert outcome.verdict == Verdict.PASS
        assert outcome.raw is not None
        assert abs(outcome.raw) < 1e-6
        # With no null stats available, the message must carry the
        # no-calibration annotation.
        assert "no calibration" in (outcome.message or "").lower()

    def test_pass_when_ft_improves_base(self) -> None:
        """Higher ft logprobs → positive mean_delta → PASS."""
        dummy = _backend_with_delta(base_per_tok=-2.0, ft_per_tok=-1.5)
        probe, spec = build_probe(
            {"name": "ext_ppl", "kind": "external_perplexity", "max_chunks": 4}
        )
        outcome = probe.run(spec, RunContext(backend=dummy))
        assert outcome.verdict == Verdict.PASS
        assert outcome.raw is not None
        assert outcome.raw > 0

    def test_fail_on_large_regression(self) -> None:
        """Perplexity regression beyond 0.1 nats/tok → fixed-threshold FAIL."""
        dummy = _backend_with_delta(base_per_tok=-2.0, ft_per_tok=-2.5)
        probe, spec = build_probe(
            {"name": "ext_ppl", "kind": "external_perplexity", "max_chunks": 4}
        )
        outcome = probe.run(spec, RunContext(backend=dummy))
        assert outcome.verdict == Verdict.FAIL
        assert outcome.raw is not None
        assert outcome.raw < -0.1

    def test_evidence_carries_per_chunk_deltas(self) -> None:
        dummy = _backend_with_delta(base_per_tok=-2.0, ft_per_tok=-1.9)
        probe, spec = build_probe(
            {"name": "ext_ppl", "kind": "external_perplexity", "max_chunks": 3}
        )
        outcome = probe.run(spec, RunContext(backend=dummy))
        evidence = outcome.evidence
        assert evidence["corpus"] == "public_domain_en"
        assert evidence["num_chunks"] == 3
        assert len(evidence["per_chunk_delta"]) == 3
        # The canned rolling maps give every chunk the same 0.1 nats/tok
        # improvement by construction.
        for delta in evidence["per_chunk_delta"]:
            assert abs(delta - 0.1) < 1e-5

    def test_unknown_corpus_errors(self) -> None:
        """A corpus that bypasses spec validation (direct construction)
        must surface as a clean ERROR, not a crash."""
        dummy = _backend_with_delta(base_per_tok=-2.0, ft_per_tok=-2.0)
        bad_spec = ExternalPerplexitySpec.model_construct(
            name="ext",
            kind="external_perplexity",
            corpus="does_not_exist",  # type: ignore[arg-type]
        )
        probe, _ = build_probe({"name": "ext_ppl", "kind": "external_perplexity", "max_chunks": 4})
        outcome = probe.run(bad_spec, RunContext(backend=dummy))
        assert outcome.verdict == Verdict.ERROR
        assert "unknown corpus" in (outcome.message or "")

    def test_respects_max_chunks(self) -> None:
        dummy = _backend_with_delta(base_per_tok=-2.0, ft_per_tok=-2.0)
        probe, spec = build_probe(
            {"name": "ext_ppl", "kind": "external_perplexity", "max_chunks": 2}
        )
        outcome = probe.run(spec, RunContext(backend=dummy))
        assert outcome.evidence["num_chunks"] == 2
| 177 | |
| 178 | |
class TestCalibrateSpec:
    """The probe class advertises a cheap calibration spec."""

    def test_returns_non_none_spec(self) -> None:
        from dlm_sway.probes.external_perplexity import ExternalPerplexityProbe

        dummy = _backend_with_delta(base_per_tok=-2.0, ft_per_tok=-2.0)
        spec = ExternalPerplexityProbe.calibrate_spec(RunContext(backend=dummy))
        assert spec is not None
        assert spec.kind == "external_perplexity"
        # Calibration must stay cheap: at most 4 chunks per null run.
        assert spec.max_chunks <= 4
| 190 | |
| 191 | |
class TestNullCalibrationEndToEnd:
    """Runner-level threading of null stats into downstream probes."""

    def test_runner_threads_null_stats_to_external_perplexity(self) -> None:
        """null_adapter → external_perplexity: the runner must thread the
        per-kind null stats into the downstream probe's context.

        Since the F02 fix (Audit 03), the dummy backend's
        ``rolling_logprob`` ignores the seed, so every null-calibration
        run of ``external_perplexity`` yields the same raw. That is a
        legitimately-degenerate null (``std == 0``), now reported as
        ``degenerate: 1.0`` in the stats dict; ``z_score`` refuses to
        divide by a lifted std and the probe takes the fixed-threshold
        fallback.

        This test pins only the threading contract: whether or not the
        z-score fires, ``null_stats`` must land on the suite result, and
        the probe's message must carry the ``(no calibration for ...)``
        tag when the null is degenerate. A threading regression would
        drop ``null_stats`` from the suite result entirely.
        """
        from dlm_sway.suite.runner import run as run_suite
        from dlm_sway.suite.spec import SwaySpec

        dummy = _backend_with_delta(base_per_tok=-2.0, ft_per_tok=-1.9)
        sway_spec = SwaySpec.model_validate(
            {
                "version": 1,
                "models": {
                    "base": {"base": "b"},
                    "ft": {"base": "b", "adapter": "/tmp/a"},
                },
                "suite": [
                    {"name": "null", "kind": "null_adapter", "runs": 2, "cache": False},
                    {
                        "name": "ext",
                        "kind": "external_perplexity",
                        "max_chunks": 3,
                        # Deliberately permissive so only the threading
                        # contract (not the threshold) is under test.
                        "assert_mean_delta_gte": -100.0,
                    },
                ],
            }
        )
        suite_result = run_suite(sway_spec, dummy)
        assert len(suite_result.probes) == 2
        null_result, ext_result = suite_result.probes
        assert null_result.verdict == Verdict.PASS

        # F02 — null_stats reach suite.null_stats even for a degenerate
        # null; the probe itself sees the stats dict (with
        # ``degenerate: 1.0``) and elects the fallback path.
        assert "external_perplexity" in suite_result.null_stats
        ext_null = suite_result.null_stats["external_perplexity"]
        assert ext_null.get("degenerate", 0.0) >= 0.5
        # Fixed-threshold fallback: z_score is None and the message
        # carries the (no calibration) tag consumed by the S02 report
        # layer when annotating the row.
        assert ext_result.z_score is None
        assert "no calibration for external_perplexity" in (ext_result.message or "")