| 1 | """Tests for :mod:`dlm_sway.probes.multi_turn_coherence`.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import math |
| 6 | |
| 7 | import numpy as np |
| 8 | |
| 9 | from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses |
| 10 | from dlm_sway.core.result import Verdict |
| 11 | from dlm_sway.core.scoring import TokenDist |
| 12 | from dlm_sway.probes.base import RunContext, build_probe |
| 13 | from dlm_sway.probes.multi_turn_coherence import ( |
| 14 | _ascii_sparkline, |
| 15 | _chat_template_of, |
| 16 | _fallback_format, |
| 17 | _fit_half_life_turns, |
| 18 | _verdict_from_half_life, |
| 19 | ) |
| 20 | |
| 21 | # --------------------------------------------------------------------------- |
| 22 | # Fake tokenizer + dist-injection helpers |
| 23 | # --------------------------------------------------------------------------- |
| 24 | |
| 25 | |
| 26 | class _FakeTokenizer: |
| 27 | """Minimal stand-in for an HF tokenizer with a chat_template. |
| 28 | |
| 29 | Implements just enough of the surface :mod:`multi_turn_coherence` |
| 30 | consults: the ``chat_template`` attribute and an |
| 31 | ``apply_chat_template`` method that returns a concatenated |
| 32 | role-marker string. Real HF chat templates are Jinja-rendered; |
| 33 | ours is a deterministic string concat so per-turn comparisons |
| 34 | are reproducible. |
| 35 | """ |
| 36 | |
| 37 | chat_template = "fake-jinja-template-string" |
| 38 | |
| 39 | def apply_chat_template( |
| 40 | self, |
| 41 | messages: list[dict[str, str]], |
| 42 | *, |
| 43 | tokenize: bool = False, |
| 44 | add_generation_prompt: bool = False, |
| 45 | ) -> str: |
| 46 | del tokenize # we always return text |
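        # e.g. [{"role": "user", "content": "hi"}] with
        # add_generation_prompt=True renders "user::hi|assistant::".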
        parts = [f"{m['role']}::{m['content']}" for m in messages]
        if add_generation_prompt:
            parts.append("assistant::")
        return "|".join(parts)


def _attach_tokenizer(
    backend: DummyDifferentialBackend, tokenizer: object | None
) -> DummyDifferentialBackend:
    """Slot a tokenizer onto the dummy backend the same way HF does.

    multi_turn_coherence reads ``ctx.backend._tokenizer`` to find the
    chat template — mirrors prompt_collapse's _peek_backend_tokenizer.
    """
    backend._tokenizer = tokenizer  # type: ignore[attr-defined]
    return backend


def _decay_dist(value: float, *, broad: bool, k: int = 8) -> TokenDist:
    """Build a TokenDist whose top-k logprobs encode a known signal.

    ``value`` controls how peaked the distribution is — feeding two
    such dists into the divergence helper produces a deterministic
    KL we can plant per-turn for the curve-fit tests.
    """
    if broad:
        # Roughly uniform with tiny perturbation (clears the
        # _UNIFORM_LOGPROB_TOL guard).
        lp = np.full(k, -math.log(k), dtype=np.float32)
        lp += np.linspace(-1e-4, 1e-4, k, dtype=np.float32)
    else:
        # Sharp: most mass on first token, with `value` controlling
        # how peaked. Larger value ⇒ sharper.
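        # e.g. value=3.0 ⇒ lp = [-0.03, -4.0, -4.0, ...]: a near-certain
        # top token, with the gap widening as `value` grows.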
        lp = np.array([-0.01 * value] + [-1.0 - value] * (k - 1), dtype=np.float32)
    return TokenDist(
        token_ids=np.arange(k, dtype=np.int64),
        logprobs=lp,
        vocab_size=1000,
        tail_logprob=None,
    )


# ---------------------------------------------------------------------------
# Skip / error paths
# ---------------------------------------------------------------------------


class TestSkipPaths:
    def test_skips_when_no_tokenizer(self) -> None:
        """Dummy backend has no tokenizer ⇒ probe SKIPs cleanly."""
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        probe, spec = build_probe(
            {
                "name": "mtc",
                "kind": "multi_turn_coherence_decay",
                "prompts": ["hello"],
            }
        )
        result = probe.run(spec, RunContext(backend=backend))
        assert result.verdict == Verdict.SKIP
        assert "chat_template" in result.message

    def test_skips_when_tokenizer_lacks_chat_template(self) -> None:
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())

        class _BareTokenizer:
            chat_template = None

        _attach_tokenizer(backend, _BareTokenizer())
        probe, spec = build_probe(
            {
                "name": "mtc",
                "kind": "multi_turn_coherence_decay",
                "prompts": ["hello"],
            }
        )
        result = probe.run(spec, RunContext(backend=backend))
        assert result.verdict == Verdict.SKIP

    def test_errors_when_no_prompts(self) -> None:
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        _attach_tokenizer(backend, _FakeTokenizer())
        probe, spec = build_probe({"name": "mtc", "kind": "multi_turn_coherence_decay"})
        result = probe.run(spec, RunContext(backend=backend))
        assert result.verdict == Verdict.ERROR


# ---------------------------------------------------------------------------
# Happy path: planted decay → curve fit recovers half-life
# ---------------------------------------------------------------------------


class TestHappyPath:
    """End-to-end probe runs against the dummy backend.

    Tests verify *wiring* — the probe correctly drives the chat loop,
    asks for next-token-dists at the right positions, and produces a
    finalized result with the documented evidence shape. The exact
    KL values depend on the dummy backend's synthetic-distribution
    math; tests assert the curve is monotonic / non-monotonic / flat
    as intended rather than pinning specific KL targets. Curve-fit
    correctness is covered separately in :class:`TestFitHalfLife`.
    """

    def test_decreasing_curve_yields_finite_half_life(self) -> None:
        """Decreasing planted KLs → 'ok' fit status, finite half-life."""
        backend = _backend_with_decreasing_dists(
            prompt="hello", max_turns=4, sharpness_per_turn=[3.0, 1.0, 0.3]
        )
        probe, spec = build_probe(
            {
                "name": "mtc",
                "kind": "multi_turn_coherence_decay",
                "prompts": ["hello"],
                "max_turns": 4,
                "assert_half_life_turns": 0.5,
            }
        )
        result = probe.run(spec, RunContext(backend=backend))
        assert result.verdict in {Verdict.PASS, Verdict.FAIL}, result.message
        assert result.evidence["fit_status"] in {"ok", "stable"}
        per_turn = result.evidence["per_turn_kls"]
        assert len(per_turn) == 3
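        # max_turns=4 ⇒ three measurements: turn 1 only seeds the
        # conversation, so KLs are taken at turns 2..4 (cf. turns_axis).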
        # Monotonic decrease — verifies the probe wired turns in order.
        assert per_turn[0] > per_turn[1] > per_turn[2]
        assert result.raw is not None
        assert result.raw > 0.0

    def test_flat_curve_marked_stable(self) -> None:
        """Same dists at every turn → fit_status=stable → PASS."""
        backend = _backend_with_decreasing_dists(
            prompt="hello", max_turns=4, sharpness_per_turn=[1.0, 1.0, 1.0]
        )
        probe, spec = build_probe(
            {
                "name": "mtc",
                "kind": "multi_turn_coherence_decay",
                "prompts": ["hello"],
                "max_turns": 4,
            }
        )
        result = probe.run(spec, RunContext(backend=backend))
        assert result.verdict == Verdict.PASS
        assert result.evidence["fit_status"] == "stable"
        assert result.score == 1.0
        assert "held coherence" in result.message

    def test_growing_curve_warns(self) -> None:
        """Increasing planted KLs → fit_status=non_monotonic → WARN."""
        backend = _backend_with_decreasing_dists(
            prompt="hello", max_turns=4, sharpness_per_turn=[0.3, 1.0, 3.0]
        )
        probe, spec = build_probe(
            {
                "name": "mtc",
                "kind": "multi_turn_coherence_decay",
                "prompts": ["hello"],
                "max_turns": 4,
            }
        )
        result = probe.run(spec, RunContext(backend=backend))
        assert result.verdict == Verdict.WARN
        assert result.evidence["fit_status"] == "non_monotonic"

    def test_evidence_carries_turns_and_sparkline(self) -> None:
        backend = _backend_with_decreasing_dists(
            prompt="hello", max_turns=4, sharpness_per_turn=[2.0, 1.0, 0.5]
        )
        probe, spec = build_probe(
            {
                "name": "mtc",
                "kind": "multi_turn_coherence_decay",
                "prompts": ["hello"],
                "max_turns": 4,
            }
        )
        result = probe.run(spec, RunContext(backend=backend))
        assert result.evidence["turns_axis"] == [2.0, 3.0, 4.0]
        assert result.evidence["max_turns"] == 4
        assert result.evidence["num_prompts"] == 1
        assert isinstance(result.evidence["sparkline"], str)
        assert len(result.evidence["sparkline"]) == 3


# ---------------------------------------------------------------------------
# Curve fit unit tests (math only, no backend)
# ---------------------------------------------------------------------------


class TestFitHalfLife:
    def test_clean_exponential_recovers_half_life(self) -> None:
        # y = exp(-0.693 * x) ⇒ half-life = 1 turn
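        # (The fit models KL(t) ≈ A·exp(-λ·t), so half-life = ln(2)/λ;
        # here λ = ln 2, giving exactly one turn.)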
        turns = np.array([2.0, 3.0, 4.0])
        kls = np.exp(-math.log(2.0) * turns)
        h, status = _fit_half_life_turns(turns, kls, max_turns=4)
        assert status == "ok"
        assert h is not None
        assert math.isclose(h, 1.0, rel_tol=1e-3)

    def test_stable_returns_saturation(self) -> None:
        turns = np.array([2.0, 3.0, 4.0])
        kls = np.array([0.5, 0.5, 0.5])
        h, status = _fit_half_life_turns(turns, kls, max_turns=4)
        assert status == "stable"
        assert h == 40.0  # max_turns * 10

    def test_growing_returns_non_monotonic(self) -> None:
        turns = np.array([2.0, 3.0, 4.0])
        kls = np.array([0.1, 0.3, 0.9])
        h, status = _fit_half_life_turns(turns, kls, max_turns=4)
        assert status == "non_monotonic"
        assert h is None

    def test_all_zero_returns_degenerate(self) -> None:
        turns = np.array([2.0, 3.0, 4.0])
        kls = np.array([0.0, 0.0, 0.0])
        h, status = _fit_half_life_turns(turns, kls, max_turns=4)
        assert status == "degenerate"
        assert h == 0.0

    def test_partial_zero_drops_zero_points(self) -> None:
        """One zero-KL turn: drop it, fit on the remaining positives."""
        turns = np.array([2.0, 3.0, 4.0])
        kls = np.array([0.4, 0.0, 0.1])
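        # With the zero at t=3 dropped, the surviving points (0.4 @ t=2,
        # 0.1 @ t=4) still describe a clean 4x decay over two turns.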
        h, status = _fit_half_life_turns(turns, kls, max_turns=4)
        assert status == "ok"
        assert h is not None
        assert h > 0.0


# ---------------------------------------------------------------------------
# Verdict mapping (math-free)
# ---------------------------------------------------------------------------


class TestVerdictMapping:
    def test_pass_on_half_life_above_target(self) -> None:
        v, s, msg = _verdict_from_half_life(
            half_life=3.0,
            fit_status="ok",
            target=2.0,
            mean_kls=[0.4, 0.2, 0.1],
            turns_axis=[2.0, 3.0, 4.0],
        )
        assert v == Verdict.PASS
        assert s == 1.0
        assert "half-life=3.00" in msg

    def test_fail_on_half_life_below_target(self) -> None:
        v, s, msg = _verdict_from_half_life(
            half_life=0.5,
            fit_status="ok",
            target=2.0,
            mean_kls=[0.4, 0.1, 0.02],
            turns_axis=[2.0, 3.0, 4.0],
        )
        assert v == Verdict.FAIL
        assert 0.0 < s < 1.0


# ---------------------------------------------------------------------------
# Sparkline + fallback formatter
# ---------------------------------------------------------------------------


class TestSparkline:
    def test_renders_one_char_per_value(self) -> None:
        out = _ascii_sparkline([0.4, 0.2, 0.1])
        assert len(out) == 3
        # Decreasing input ⇒ first bar should be the tallest
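        # (Plain char comparison stands in for height here, assuming the
        # bar glyphs sort by height in codepoint order, as the block
        # characters U+2581..U+2588 do.)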
        assert out[0] >= out[-1]

    def test_flat_input_renders_uniform_mid(self) -> None:
        out = _ascii_sparkline([0.5, 0.5, 0.5])
        assert len(set(out)) == 1  # all the same bar

    def test_empty_input_returns_empty(self) -> None:
        assert _ascii_sparkline([]) == ""

    def test_non_finite_drops_and_marks(self) -> None:
        out = _ascii_sparkline([0.4, math.inf, 0.1])
        assert out[1] == "?"


class TestFallbackFormatter:
    def test_concatenates_with_role_markers(self) -> None:
        msg = _fallback_format(
            [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "yo"}],
            add_generation_prompt=True,
        )
        assert "USER: hi" in msg
        assert "ASSISTANT: yo" in msg
        assert msg.endswith("ASSISTANT:")


class TestChatTemplateDetection:
    def test_returns_none_when_tokenizer_is_none(self) -> None:
        assert _chat_template_of(None) is None

    def test_returns_none_when_no_attribute(self) -> None:
        class _T:
            pass

        assert _chat_template_of(_T()) is None

    def test_returns_template_when_set(self) -> None:
        assert _chat_template_of(_FakeTokenizer()) == "fake-jinja-template-string"


# ---------------------------------------------------------------------------
# Test infrastructure
# ---------------------------------------------------------------------------


def _backend_with_decreasing_dists(
    *,
    prompt: str,
    max_turns: int,
    sharpness_per_turn: list[float],
) -> DummyDifferentialBackend:
    """Build a dummy backend whose per-turn ft TokenDists vary in sharpness.

    Base view always returns the same broad (≈uniform) dist. ft view
    returns a dist with a controllable sharpness per turn: larger
    sharpness ⇒ more peaked ⇒ larger KL from the broad base. The
    sharpness sequence drives the curve shape (decreasing,
    increasing, flat) without trying to plant exact KL values — that
    coupling proved fragile in the first cut.

    Plants the chat strings the probe will see by replaying the
    probe's per-prompt loop with the same fake tokenizer + the same
    follow-up cycle.
    """
    if len(sharpness_per_turn) != max_turns - 1:
        raise ValueError(f"need {max_turns - 1} sharpness values, got {len(sharpness_per_turn)}")

    tok = _FakeTokenizer()
    follow_ups_default = [
        "Continue.",
        "Tell me more.",
        "Can you elaborate?",
        "What else?",
        "Go deeper.",
        "Expand on that.",
        "And then?",
    ]

    base_dists: dict[str, TokenDist] = {}
    ft_dists: dict[str, TokenDist] = {}
    ft_gens: dict[str, str] = {}

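    # Turn 1 only seeds the conversation: plant the ft generation the
    # probe will append as the assistant reply; no dists are needed yet.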
    messages: list[dict[str, str]] = [{"role": "user", "content": prompt}]
    t1_input = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    ft_gens[t1_input] = f"ft-response-1 for {prompt}"
    messages.append({"role": "assistant", "content": ft_gens[t1_input]})

    for turn_idx, sharpness in enumerate(sharpness_per_turn):
        follow_up = follow_ups_default[turn_idx % len(follow_ups_default)]
        messages.append({"role": "user", "content": follow_up})
        chat_str = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        base_dists[chat_str] = _decay_dist(0.0, broad=True)
        ft_dists[chat_str] = _decay_dist(value=sharpness, broad=False)
        if turn_idx < len(sharpness_per_turn) - 1:
            ft_gens[chat_str] = f"ft-response-{turn_idx + 2} for {prompt}"
            messages.append({"role": "assistant", "content": ft_gens[chat_str]})

    backend = DummyDifferentialBackend(
        base=DummyResponses(token_dists=base_dists),
        ft=DummyResponses(token_dists=ft_dists, generations=ft_gens),
    )
    _attach_tokenizer(backend, tok)
    return backend