@@ -0,0 +1,418 @@ |
| 1 | +"""Tests for :mod:`dlm_sway.probes.multi_turn_coherence`.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import math |
| 6 | + |
| 7 | +import numpy as np |
| 8 | + |
| 9 | +from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses |
| 10 | +from dlm_sway.core.result import Verdict |
| 11 | +from dlm_sway.core.scoring import TokenDist |
| 12 | +from dlm_sway.probes.base import RunContext, build_probe |
| 13 | +from dlm_sway.probes.multi_turn_coherence import ( |
| 14 | + _ascii_sparkline, |
| 15 | + _chat_template_of, |
| 16 | + _fallback_format, |
| 17 | + _fit_half_life_turns, |
| 18 | + _verdict_from_half_life, |
| 19 | +) |
| 20 | + |
| 21 | +# --------------------------------------------------------------------------- |
| 22 | +# Fake tokenizer + dist-injection helpers |
| 23 | +# --------------------------------------------------------------------------- |
| 24 | + |
| 25 | + |
| 26 | +class _FakeTokenizer: |
| 27 | + """Minimal stand-in for an HF tokenizer with a chat_template. |
| 28 | + |
| 29 | + Implements just enough of the surface :mod:`multi_turn_coherence` |
| 30 | + consults: the ``chat_template`` attribute and an |
| 31 | + ``apply_chat_template`` method that returns a concatenated |
| 32 | + role-marker string. Real HF chat templates are Jinja-rendered; |
| 33 | + ours is a deterministic string concat so per-turn comparisons |
| 34 | + are reproducible. |
| 35 | + """ |
| 36 | + |
| 37 | + chat_template = "fake-jinja-template-string" |
| 38 | + |
| 39 | + def apply_chat_template( |
| 40 | + self, |
| 41 | + messages: list[dict[str, str]], |
| 42 | + *, |
| 43 | + tokenize: bool = False, |
| 44 | + add_generation_prompt: bool = False, |
| 45 | + ) -> str: |
| 46 | + del tokenize # we always return text |
| 47 | + parts = [f"{m['role']}::{m['content']}" for m in messages] |
| 48 | + if add_generation_prompt: |
| 49 | + parts.append("assistant::") |
| 50 | + return "|".join(parts) |
| 51 | + |
| 52 | + |
| 53 | +def _attach_tokenizer( |
| 54 | + backend: DummyDifferentialBackend, tokenizer: object | None |
| 55 | +) -> DummyDifferentialBackend: |
| 56 | + """Slot a tokenizer onto the dummy backend the same way HF does. |
| 57 | + |
| 58 | + multi_turn_coherence reads ``ctx.backend._tokenizer`` to find the |
| 59 | + chat template — mirrors prompt_collapse's _peek_backend_tokenizer. |
| 60 | + """ |
| 61 | + backend._tokenizer = tokenizer # type: ignore[attr-defined] |
| 62 | + return backend |
| 63 | + |
| 64 | + |
def _decay_dist(value: float, *, broad: bool, k: int = 8) -> TokenDist:
    """Build a TokenDist whose top-k logprobs encode a known signal.

    ``value`` controls how peaked the distribution is — feeding two
    such dists into the divergence helper produces a deterministic
    KL we can plant per-turn for the curve-fit tests.
    """
    if broad:
        # Near-uniform, plus a tiny tilt so the
        # _UNIFORM_LOGPROB_TOL guard doesn't trip.
        logprobs = np.full(k, -math.log(k), dtype=np.float32)
        logprobs = logprobs + np.linspace(-1e-4, 1e-4, k, dtype=np.float32)
    else:
        # Peaked: bulk of the mass sits on token 0, and a bigger
        # ``value`` means a sharper spike.
        head = [-0.01 * value]
        tail = [-1.0 - value] * (k - 1)
        logprobs = np.array(head + tail, dtype=np.float32)
    return TokenDist(
        token_ids=np.arange(k, dtype=np.int64),
        logprobs=logprobs,
        vocab_size=1000,
        tail_logprob=None,
    )
| 87 | + |
| 88 | + |
| 89 | +# --------------------------------------------------------------------------- |
| 90 | +# Skip / error paths |
| 91 | +# --------------------------------------------------------------------------- |
| 92 | + |
| 93 | + |
class TestSkipPaths:
    def test_skips_when_no_tokenizer(self) -> None:
        """Dummy backend has no tokenizer ⇒ probe SKIPs cleanly."""
        dummy = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        probe, spec = build_probe(
            dict(name="mtc", kind="multi_turn_coherence_decay", prompts=["hello"])
        )
        outcome = probe.run(spec, RunContext(backend=dummy))
        assert outcome.verdict == Verdict.SKIP
        assert "chat_template" in outcome.message

    def test_skips_when_tokenizer_lacks_chat_template(self) -> None:
        class _BareTokenizer:
            chat_template = None

        dummy = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        _attach_tokenizer(dummy, _BareTokenizer())
        probe, spec = build_probe(
            dict(name="mtc", kind="multi_turn_coherence_decay", prompts=["hello"])
        )
        outcome = probe.run(spec, RunContext(backend=dummy))
        assert outcome.verdict == Verdict.SKIP

    def test_errors_when_no_prompts(self) -> None:
        dummy = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        _attach_tokenizer(dummy, _FakeTokenizer())
        probe, spec = build_probe(dict(name="mtc", kind="multi_turn_coherence_decay"))
        outcome = probe.run(spec, RunContext(backend=dummy))
        assert outcome.verdict == Verdict.ERROR
| 132 | + |
| 133 | + |
| 134 | +# --------------------------------------------------------------------------- |
| 135 | +# Happy path: planted decay → curve fit recovers half-life |
| 136 | +# --------------------------------------------------------------------------- |
| 137 | + |
| 138 | + |
class TestHappyPath:
    """End-to-end probe runs against the dummy backend.

    These tests check *wiring*: the probe drives the chat loop,
    requests next-token-dists at the right positions, and emits a
    finalized result with the documented evidence shape. Because the
    exact KL values depend on the dummy backend's synthetic
    distributions, assertions pin the curve's shape (monotonic,
    non-monotonic, flat) rather than specific KL numbers. Curve-fit
    correctness lives separately in :class:`TestFitHalfLife`.
    """

    def test_decreasing_curve_yields_finite_half_life(self) -> None:
        """Decreasing planted KLs → 'ok' fit status, finite half-life."""
        dummy = _backend_with_decreasing_dists(
            prompt="hello", max_turns=4, sharpness_per_turn=[3.0, 1.0, 0.3]
        )
        probe, spec = build_probe(
            dict(
                name="mtc",
                kind="multi_turn_coherence_decay",
                prompts=["hello"],
                max_turns=4,
                assert_half_life_turns=0.5,
            )
        )
        outcome = probe.run(spec, RunContext(backend=dummy))
        assert outcome.verdict in {Verdict.PASS, Verdict.FAIL}, outcome.message
        assert outcome.evidence["fit_status"] in {"ok", "stable"}
        kl_series = outcome.evidence["per_turn_kls"]
        assert len(kl_series) == 3
        # Strictly decreasing — confirms the probe wired turns in order.
        assert kl_series[0] > kl_series[1] > kl_series[2]
        assert outcome.raw is not None
        assert outcome.raw > 0.0

    def test_flat_curve_marked_stable(self) -> None:
        """Same dists at every turn → fit_status=stable → PASS."""
        dummy = _backend_with_decreasing_dists(
            prompt="hello", max_turns=4, sharpness_per_turn=[1.0, 1.0, 1.0]
        )
        probe, spec = build_probe(
            dict(
                name="mtc",
                kind="multi_turn_coherence_decay",
                prompts=["hello"],
                max_turns=4,
            )
        )
        outcome = probe.run(spec, RunContext(backend=dummy))
        assert outcome.verdict == Verdict.PASS
        assert outcome.evidence["fit_status"] == "stable"
        assert outcome.score == 1.0
        assert "held coherence" in outcome.message

    def test_growing_curve_warns(self) -> None:
        """Increasing planted KLs → fit_status=non_monotonic → WARN."""
        dummy = _backend_with_decreasing_dists(
            prompt="hello", max_turns=4, sharpness_per_turn=[0.3, 1.0, 3.0]
        )
        probe, spec = build_probe(
            dict(
                name="mtc",
                kind="multi_turn_coherence_decay",
                prompts=["hello"],
                max_turns=4,
            )
        )
        outcome = probe.run(spec, RunContext(backend=dummy))
        assert outcome.verdict == Verdict.WARN
        assert outcome.evidence["fit_status"] == "non_monotonic"

    def test_evidence_carries_turns_and_sparkline(self) -> None:
        dummy = _backend_with_decreasing_dists(
            prompt="hello", max_turns=4, sharpness_per_turn=[2.0, 1.0, 0.5]
        )
        probe, spec = build_probe(
            dict(
                name="mtc",
                kind="multi_turn_coherence_decay",
                prompts=["hello"],
                max_turns=4,
            )
        )
        outcome = probe.run(spec, RunContext(backend=dummy))
        assert outcome.evidence["turns_axis"] == [2.0, 3.0, 4.0]
        assert outcome.evidence["max_turns"] == 4
        assert outcome.evidence["num_prompts"] == 1
        assert isinstance(outcome.evidence["sparkline"], str)
        assert len(outcome.evidence["sparkline"]) == 3
| 229 | + |
| 230 | + |
| 231 | +# --------------------------------------------------------------------------- |
| 232 | +# Curve fit unit tests (math only, no backend) |
| 233 | +# --------------------------------------------------------------------------- |
| 234 | + |
| 235 | + |
class TestFitHalfLife:
    def test_clean_exponential_recovers_half_life(self) -> None:
        # y = exp(-ln2 * x) decays by half each turn ⇒ half-life = 1.
        turns_axis = np.array([2.0, 3.0, 4.0])
        planted = np.exp(-math.log(2.0) * turns_axis)
        half_life, status = _fit_half_life_turns(turns_axis, planted, max_turns=4)
        assert status == "ok"
        assert half_life is not None
        assert math.isclose(half_life, 1.0, rel_tol=1e-3)

    def test_stable_returns_saturation(self) -> None:
        turns_axis = np.array([2.0, 3.0, 4.0])
        half_life, status = _fit_half_life_turns(
            turns_axis, np.array([0.5, 0.5, 0.5]), max_turns=4
        )
        assert status == "stable"
        assert half_life == 40.0  # saturation value: max_turns * 10

    def test_growing_returns_non_monotonic(self) -> None:
        turns_axis = np.array([2.0, 3.0, 4.0])
        half_life, status = _fit_half_life_turns(
            turns_axis, np.array([0.1, 0.3, 0.9]), max_turns=4
        )
        assert status == "non_monotonic"
        assert half_life is None

    def test_all_zero_returns_degenerate(self) -> None:
        turns_axis = np.array([2.0, 3.0, 4.0])
        half_life, status = _fit_half_life_turns(
            turns_axis, np.array([0.0, 0.0, 0.0]), max_turns=4
        )
        assert status == "degenerate"
        assert half_life == 0.0

    def test_partial_zero_drops_zero_points(self) -> None:
        """One zero-KL turn: drop it, fit on the remaining positives."""
        turns_axis = np.array([2.0, 3.0, 4.0])
        half_life, status = _fit_half_life_turns(
            turns_axis, np.array([0.4, 0.0, 0.1]), max_turns=4
        )
        assert status == "ok"
        assert half_life is not None
        assert half_life > 0.0
| 275 | + |
| 276 | + |
| 277 | +# --------------------------------------------------------------------------- |
| 278 | +# Verdict mapping (math-free) |
| 279 | +# --------------------------------------------------------------------------- |
| 280 | + |
| 281 | + |
class TestVerdictMapping:
    def test_pass_on_half_life_above_target(self) -> None:
        verdict, score, message = _verdict_from_half_life(
            half_life=3.0,
            fit_status="ok",
            target=2.0,
            mean_kls=[0.4, 0.2, 0.1],
            turns_axis=[2.0, 3.0, 4.0],
        )
        assert verdict == Verdict.PASS
        assert score == 1.0
        assert "half-life=3.00" in message

    def test_fail_on_half_life_below_target(self) -> None:
        verdict, score, _message = _verdict_from_half_life(
            half_life=0.5,
            fit_status="ok",
            target=2.0,
            mean_kls=[0.4, 0.1, 0.02],
            turns_axis=[2.0, 3.0, 4.0],
        )
        assert verdict == Verdict.FAIL
        assert 0.0 < score < 1.0
| 305 | + |
| 306 | + |
| 307 | +# --------------------------------------------------------------------------- |
| 308 | +# Sparkline + fallback formatter |
| 309 | +# --------------------------------------------------------------------------- |
| 310 | + |
| 311 | + |
class TestSparkline:
    def test_renders_one_char_per_value(self) -> None:
        bars = _ascii_sparkline([0.4, 0.2, 0.1])
        assert len(bars) == 3
        # Decreasing input ⇒ the first bar is at least as tall as the last.
        assert bars[0] >= bars[-1]

    def test_flat_input_renders_uniform_mid(self) -> None:
        bars = _ascii_sparkline([0.5, 0.5, 0.5])
        assert len(set(bars)) == 1  # a single repeated glyph

    def test_empty_input_returns_empty(self) -> None:
        assert _ascii_sparkline([]) == ""

    def test_non_finite_drops_and_marks(self) -> None:
        bars = _ascii_sparkline([0.4, math.inf, 0.1])
        assert bars[1] == "?"
| 329 | + |
| 330 | + |
class TestFallbackFormatter:
    def test_concatenates_with_role_markers(self) -> None:
        history = [
            {"role": "user", "content": "hi"},
            {"role": "assistant", "content": "yo"},
        ]
        rendered = _fallback_format(history, add_generation_prompt=True)
        assert "USER: hi" in rendered
        assert "ASSISTANT: yo" in rendered
        assert rendered.endswith("ASSISTANT:")
| 340 | + |
| 341 | + |
class TestChatTemplateDetection:
    def test_returns_none_when_tokenizer_is_none(self) -> None:
        assert _chat_template_of(None) is None

    def test_returns_none_when_no_attribute(self) -> None:
        class _NoTemplate:
            pass

        assert _chat_template_of(_NoTemplate()) is None

    def test_returns_template_when_set(self) -> None:
        assert _chat_template_of(_FakeTokenizer()) == "fake-jinja-template-string"
| 354 | + |
| 355 | + |
| 356 | +# --------------------------------------------------------------------------- |
| 357 | +# Test infrastructure |
| 358 | +# --------------------------------------------------------------------------- |
| 359 | + |
| 360 | + |
def _backend_with_decreasing_dists(
    *,
    prompt: str,
    max_turns: int,
    sharpness_per_turn: list[float],
) -> DummyDifferentialBackend:
    """Build a dummy backend whose per-turn ft TokenDists vary in sharpness.

    The base view always serves the same broad (≈uniform) dist. The ft
    view serves a dist whose sharpness we control per turn: larger
    sharpness ⇒ more peaked ⇒ larger KL from the broad base. The
    sharpness sequence alone dictates the curve shape (decreasing,
    increasing, flat); planting exact KL values was tried first and
    proved fragile.

    The chat strings the probe will look up are planted by replaying
    the probe's per-prompt loop with the same fake tokenizer and the
    same follow-up cycle.
    """
    if len(sharpness_per_turn) != max_turns - 1:
        raise ValueError(
            f"need {max_turns - 1} sharpness values, got {len(sharpness_per_turn)}"
        )

    fake_tok = _FakeTokenizer()
    follow_ups_default = [
        "Continue.",
        "Tell me more.",
        "Can you elaborate?",
        "What else?",
        "Go deeper.",
        "Expand on that.",
        "And then?",
    ]

    base_dists: dict[str, TokenDist] = {}
    ft_dists: dict[str, TokenDist] = {}
    ft_gens: dict[str, str] = {}

    # Turn 1: the user prompt gets a planted ft generation.
    history: list[dict[str, str]] = [{"role": "user", "content": prompt}]
    first_input = fake_tok.apply_chat_template(
        history, tokenize=False, add_generation_prompt=True
    )
    first_reply = f"ft-response-1 for {prompt}"
    ft_gens[first_input] = first_reply
    history.append({"role": "assistant", "content": first_reply})

    # Turns 2..max_turns: cycle through follow-ups, plant one dist pair
    # per turn, and extend the transcript except after the final turn.
    last_idx = len(sharpness_per_turn) - 1
    for idx, sharpness in enumerate(sharpness_per_turn):
        next_question = follow_ups_default[idx % len(follow_ups_default)]
        history.append({"role": "user", "content": next_question})
        rendered = fake_tok.apply_chat_template(
            history, tokenize=False, add_generation_prompt=True
        )
        base_dists[rendered] = _decay_dist(0.0, broad=True)
        ft_dists[rendered] = _decay_dist(value=sharpness, broad=False)
        if idx < last_idx:
            reply = f"ft-response-{idx + 2} for {prompt}"
            ft_gens[rendered] = reply
            history.append({"role": "assistant", "content": reply})

    backend = DummyDifferentialBackend(
        base=DummyResponses(token_dists=base_dists),
        ft=DummyResponses(token_dists=ft_dists, generations=ft_gens),
    )
    _attach_tokenizer(backend, fake_tok)
    return backend