@@ -0,0 +1,418 @@ |
| | 1 | +"""Tests for :mod:`dlm_sway.probes.multi_turn_coherence`.""" |
| | 2 | + |
| | 3 | +from __future__ import annotations |
| | 4 | + |
| | 5 | +import math |
| | 6 | + |
| | 7 | +import numpy as np |
| | 8 | + |
| | 9 | +from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses |
| | 10 | +from dlm_sway.core.result import Verdict |
| | 11 | +from dlm_sway.core.scoring import TokenDist |
| | 12 | +from dlm_sway.probes.base import RunContext, build_probe |
| | 13 | +from dlm_sway.probes.multi_turn_coherence import ( |
| | 14 | + _ascii_sparkline, |
| | 15 | + _chat_template_of, |
| | 16 | + _fallback_format, |
| | 17 | + _fit_half_life_turns, |
| | 18 | + _verdict_from_half_life, |
| | 19 | +) |
| | 20 | + |
| | 21 | +# --------------------------------------------------------------------------- |
| | 22 | +# Fake tokenizer + dist-injection helpers |
| | 23 | +# --------------------------------------------------------------------------- |
| | 24 | + |
| | 25 | + |
| | 26 | +class _FakeTokenizer: |
| | 27 | + """Minimal stand-in for an HF tokenizer with a chat_template. |
| | 28 | + |
| | 29 | + Implements just enough of the surface :mod:`multi_turn_coherence` |
| | 30 | + consults: the ``chat_template`` attribute and an |
| | 31 | + ``apply_chat_template`` method that returns a concatenated |
| | 32 | + role-marker string. Real HF chat templates are Jinja-rendered; |
| | 33 | + ours is a deterministic string concat so per-turn comparisons |
| | 34 | + are reproducible. |
| | 35 | + """ |
| | 36 | + |
| | 37 | + chat_template = "fake-jinja-template-string" |
| | 38 | + |
| | 39 | + def apply_chat_template( |
| | 40 | + self, |
| | 41 | + messages: list[dict[str, str]], |
| | 42 | + *, |
| | 43 | + tokenize: bool = False, |
| | 44 | + add_generation_prompt: bool = False, |
| | 45 | + ) -> str: |
| | 46 | + del tokenize # we always return text |
| | 47 | + parts = [f"{m['role']}::{m['content']}" for m in messages] |
| | 48 | + if add_generation_prompt: |
| | 49 | + parts.append("assistant::") |
| | 50 | + return "|".join(parts) |
| | 51 | + |
| | 52 | + |
| | 53 | +def _attach_tokenizer( |
| | 54 | + backend: DummyDifferentialBackend, tokenizer: object | None |
| | 55 | +) -> DummyDifferentialBackend: |
| | 56 | + """Slot a tokenizer onto the dummy backend the same way HF does. |
| | 57 | + |
| | 58 | + multi_turn_coherence reads ``ctx.backend._tokenizer`` to find the |
| | 59 | + chat template — mirrors prompt_collapse's _peek_backend_tokenizer. |
| | 60 | + """ |
| | 61 | + backend._tokenizer = tokenizer # type: ignore[attr-defined] |
| | 62 | + return backend |
| | 63 | + |
| | 64 | + |
def _decay_dist(value: float, *, broad: bool, k: int = 8) -> TokenDist:
    """Build a TokenDist whose top-k logprobs encode a known signal.

    ``value`` controls how peaked the distribution is — feeding two
    such dists into the divergence helper produces a deterministic
    KL we can plant per-turn for the curve-fit tests.
    """
    if broad:
        # Near-uniform plus a tiny linear tilt so the
        # _UNIFORM_LOGPROB_TOL guard does not trip.
        logprobs = np.full(k, -math.log(k), dtype=np.float32)
        logprobs += np.linspace(-1e-4, 1e-4, k, dtype=np.float32)
    else:
        # Peaked: bulk of the mass on token 0; larger ``value``
        # means a sharper spike.
        logprobs = np.array(
            [-0.01 * value] + [-1.0 - value] * (k - 1), dtype=np.float32
        )
    return TokenDist(
        token_ids=np.arange(k, dtype=np.int64),
        logprobs=logprobs,
        vocab_size=1000,
        tail_logprob=None,
    )
| | 87 | + |
| | 88 | + |
| | 89 | +# --------------------------------------------------------------------------- |
| | 90 | +# Skip / error paths |
| | 91 | +# --------------------------------------------------------------------------- |
| | 92 | + |
| | 93 | + |
class TestSkipPaths:
    def test_skips_when_no_tokenizer(self) -> None:
        """Dummy backend has no tokenizer ⇒ probe SKIPs cleanly."""
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        spec_dict = {
            "name": "mtc",
            "kind": "multi_turn_coherence_decay",
            "prompts": ["hello"],
        }
        probe, spec = build_probe(spec_dict)
        result = probe.run(spec, RunContext(backend=backend))
        assert result.verdict == Verdict.SKIP
        assert "chat_template" in result.message

    def test_skips_when_tokenizer_lacks_chat_template(self) -> None:
        """A tokenizer whose chat_template is None also triggers SKIP."""
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())

        class _BareTokenizer:
            chat_template = None

        _attach_tokenizer(backend, _BareTokenizer())
        spec_dict = {
            "name": "mtc",
            "kind": "multi_turn_coherence_decay",
            "prompts": ["hello"],
        }
        probe, spec = build_probe(spec_dict)
        result = probe.run(spec, RunContext(backend=backend))
        assert result.verdict == Verdict.SKIP

    def test_errors_when_no_prompts(self) -> None:
        """A spec without prompts is a configuration ERROR, not a skip."""
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        _attach_tokenizer(backend, _FakeTokenizer())
        probe, spec = build_probe({"name": "mtc", "kind": "multi_turn_coherence_decay"})
        result = probe.run(spec, RunContext(backend=backend))
        assert result.verdict == Verdict.ERROR
| | 132 | + |
| | 133 | + |
| | 134 | +# --------------------------------------------------------------------------- |
| | 135 | +# Happy path: planted decay → curve fit recovers half-life |
| | 136 | +# --------------------------------------------------------------------------- |
| | 137 | + |
| | 138 | + |
class TestHappyPath:
    """End-to-end probe runs against the dummy backend.

    These tests check *wiring*: the probe drives the chat loop,
    requests next-token dists at the right positions, and produces a
    finalized result with the documented evidence shape.  Exact KL
    values depend on the dummy backend's synthetic-distribution
    math, so assertions pin curve shape (decreasing / flat /
    growing) rather than specific KL numbers.  Curve-fit
    correctness is covered separately in :class:`TestFitHalfLife`.
    """

    @staticmethod
    def _run_probe(backend: DummyDifferentialBackend, **overrides: object):
        """Build the standard probe spec (plus overrides) and run it once."""
        spec_dict: dict[str, object] = {
            "name": "mtc",
            "kind": "multi_turn_coherence_decay",
            "prompts": ["hello"],
            "max_turns": 4,
        }
        spec_dict.update(overrides)
        probe, spec = build_probe(spec_dict)
        return probe.run(spec, RunContext(backend=backend))

    def test_decreasing_curve_yields_finite_half_life(self) -> None:
        """Decreasing planted KLs → 'ok' fit status, finite half-life."""
        backend = _backend_with_decreasing_dists(
            prompt="hello", max_turns=4, sharpness_per_turn=[3.0, 1.0, 0.3]
        )
        result = self._run_probe(backend, assert_half_life_turns=0.5)
        assert result.verdict in {Verdict.PASS, Verdict.FAIL}, result.message
        assert result.evidence["fit_status"] in {"ok", "stable"}
        per_turn = result.evidence["per_turn_kls"]
        assert len(per_turn) == 3
        # Strictly decreasing — proves the probe wired turns in order.
        assert per_turn[0] > per_turn[1] > per_turn[2]
        assert result.raw is not None
        assert result.raw > 0.0

    def test_flat_curve_marked_stable(self) -> None:
        """Same dists at every turn → fit_status=stable → PASS."""
        backend = _backend_with_decreasing_dists(
            prompt="hello", max_turns=4, sharpness_per_turn=[1.0, 1.0, 1.0]
        )
        result = self._run_probe(backend)
        assert result.verdict == Verdict.PASS
        assert result.evidence["fit_status"] == "stable"
        assert result.score == 1.0
        assert "held coherence" in result.message

    def test_growing_curve_warns(self) -> None:
        """Increasing planted KLs → fit_status=non_monotonic → WARN."""
        backend = _backend_with_decreasing_dists(
            prompt="hello", max_turns=4, sharpness_per_turn=[0.3, 1.0, 3.0]
        )
        result = self._run_probe(backend)
        assert result.verdict == Verdict.WARN
        assert result.evidence["fit_status"] == "non_monotonic"

    def test_evidence_carries_turns_and_sparkline(self) -> None:
        """Evidence dict exposes axis, config echo, and a sparkline."""
        backend = _backend_with_decreasing_dists(
            prompt="hello", max_turns=4, sharpness_per_turn=[2.0, 1.0, 0.5]
        )
        result = self._run_probe(backend)
        assert result.evidence["turns_axis"] == [2.0, 3.0, 4.0]
        assert result.evidence["max_turns"] == 4
        assert result.evidence["num_prompts"] == 1
        assert isinstance(result.evidence["sparkline"], str)
        assert len(result.evidence["sparkline"]) == 3
| | 229 | + |
| | 230 | + |
| | 231 | +# --------------------------------------------------------------------------- |
| | 232 | +# Curve fit unit tests (math only, no backend) |
| | 233 | +# --------------------------------------------------------------------------- |
| | 234 | + |
| | 235 | + |
class TestFitHalfLife:
    def test_clean_exponential_recovers_half_life(self) -> None:
        # y = exp(-ln2 * x) halves every turn ⇒ half-life = 1 turn.
        turns = np.array([2.0, 3.0, 4.0])
        kls = np.exp(-math.log(2.0) * turns)
        half_life, status = _fit_half_life_turns(turns, kls, max_turns=4)
        assert status == "ok"
        assert half_life is not None
        assert math.isclose(half_life, 1.0, rel_tol=1e-3)

    def test_stable_returns_saturation(self) -> None:
        turns = np.array([2.0, 3.0, 4.0])
        kls = np.array([0.5, 0.5, 0.5])
        half_life, status = _fit_half_life_turns(turns, kls, max_turns=4)
        assert status == "stable"
        # Saturation sentinel is max_turns * 10.
        assert half_life == 40.0

    def test_growing_returns_non_monotonic(self) -> None:
        turns = np.array([2.0, 3.0, 4.0])
        kls = np.array([0.1, 0.3, 0.9])
        half_life, status = _fit_half_life_turns(turns, kls, max_turns=4)
        assert status == "non_monotonic"
        assert half_life is None

    def test_all_zero_returns_degenerate(self) -> None:
        turns = np.array([2.0, 3.0, 4.0])
        kls = np.array([0.0, 0.0, 0.0])
        half_life, status = _fit_half_life_turns(turns, kls, max_turns=4)
        assert status == "degenerate"
        assert half_life == 0.0

    def test_partial_zero_drops_zero_points(self) -> None:
        """One zero-KL turn: drop it, fit on the remaining positives."""
        turns = np.array([2.0, 3.0, 4.0])
        kls = np.array([0.4, 0.0, 0.1])
        half_life, status = _fit_half_life_turns(turns, kls, max_turns=4)
        assert status == "ok"
        assert half_life is not None
        assert half_life > 0.0
| | 275 | + |
| | 276 | + |
| | 277 | +# --------------------------------------------------------------------------- |
| | 278 | +# Verdict mapping (math-free) |
| | 279 | +# --------------------------------------------------------------------------- |
| | 280 | + |
| | 281 | + |
class TestVerdictMapping:
    def test_pass_on_half_life_above_target(self) -> None:
        verdict, score, message = _verdict_from_half_life(
            half_life=3.0,
            fit_status="ok",
            target=2.0,
            mean_kls=[0.4, 0.2, 0.1],
            turns_axis=[2.0, 3.0, 4.0],
        )
        assert verdict == Verdict.PASS
        assert score == 1.0
        assert "half-life=3.00" in message

    def test_fail_on_half_life_below_target(self) -> None:
        verdict, score, _message = _verdict_from_half_life(
            half_life=0.5,
            fit_status="ok",
            target=2.0,
            mean_kls=[0.4, 0.1, 0.02],
            turns_axis=[2.0, 3.0, 4.0],
        )
        assert verdict == Verdict.FAIL
        # Below-target half-life scores partial credit, never zero.
        assert 0.0 < score < 1.0
| | 305 | + |
| | 306 | + |
| | 307 | +# --------------------------------------------------------------------------- |
| | 308 | +# Sparkline + fallback formatter |
| | 309 | +# --------------------------------------------------------------------------- |
| | 310 | + |
| | 311 | + |
class TestSparkline:
    def test_renders_one_char_per_value(self) -> None:
        bars = _ascii_sparkline([0.4, 0.2, 0.1])
        assert len(bars) == 3
        # Decreasing input ⇒ the first bar is at least as tall as the last.
        assert bars[0] >= bars[-1]

    def test_flat_input_renders_uniform_mid(self) -> None:
        bars = _ascii_sparkline([0.5, 0.5, 0.5])
        assert len(set(bars)) == 1  # one repeated bar character

    def test_empty_input_returns_empty(self) -> None:
        assert _ascii_sparkline([]) == ""

    def test_non_finite_drops_and_marks(self) -> None:
        bars = _ascii_sparkline([0.4, math.inf, 0.1])
        assert bars[1] == "?"
| | 329 | + |
| | 330 | + |
class TestFallbackFormatter:
    def test_concatenates_with_role_markers(self) -> None:
        history = [
            {"role": "user", "content": "hi"},
            {"role": "assistant", "content": "yo"},
        ]
        rendered = _fallback_format(history, add_generation_prompt=True)
        assert "USER: hi" in rendered
        assert "ASSISTANT: yo" in rendered
        # Generation prompt leaves a trailing empty assistant marker.
        assert rendered.endswith("ASSISTANT:")
| | 340 | + |
| | 341 | + |
class TestChatTemplateDetection:
    def test_returns_none_when_tokenizer_is_none(self) -> None:
        assert _chat_template_of(None) is None

    def test_returns_none_when_no_attribute(self) -> None:
        class _NoTemplate:
            pass

        assert _chat_template_of(_NoTemplate()) is None

    def test_returns_template_when_set(self) -> None:
        assert _chat_template_of(_FakeTokenizer()) == "fake-jinja-template-string"
| | 354 | + |
| | 355 | + |
| | 356 | +# --------------------------------------------------------------------------- |
| | 357 | +# Test infrastructure |
| | 358 | +# --------------------------------------------------------------------------- |
| | 359 | + |
| | 360 | + |
def _backend_with_decreasing_dists(
    *,
    prompt: str,
    max_turns: int,
    sharpness_per_turn: list[float],
) -> DummyDifferentialBackend:
    """Build a dummy backend whose per-turn ft TokenDists vary in sharpness.

    Base view always returns the same broad (≈uniform) dist.  ft view
    returns a dist with a controllable sharpness per turn: larger
    sharpness ⇒ more peaked ⇒ larger KL from the broad base.  The
    sharpness sequence drives the curve shape (decreasing,
    increasing, flat) without trying to plant exact KL values — that
    coupling proved fragile in the first cut.

    Plants the chat strings the probe will see by replaying the
    probe's per-prompt loop with the same fake tokenizer + the same
    follow-up cycle.
    """
    if len(sharpness_per_turn) != max_turns - 1:
        raise ValueError(f"need {max_turns - 1} sharpness values, got {len(sharpness_per_turn)}")

    tok = _FakeTokenizer()
    follow_ups_default = [
        "Continue.",
        "Tell me more.",
        "Can you elaborate?",
        "What else?",
        "Go deeper.",
        "Expand on that.",
        "And then?",
    ]

    base_dists: dict[str, TokenDist] = {}
    ft_dists: dict[str, TokenDist] = {}
    ft_gens: dict[str, str] = {}

    # Replay turn 1: first user prompt, then the planted ft reply.
    conversation: list[dict[str, str]] = [{"role": "user", "content": prompt}]
    first_input = tok.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    first_reply = f"ft-response-1 for {prompt}"
    ft_gens[first_input] = first_reply
    conversation.append({"role": "assistant", "content": first_reply})

    last_idx = len(sharpness_per_turn) - 1
    for idx, sharpness in enumerate(sharpness_per_turn):
        follow_up = follow_ups_default[idx % len(follow_ups_default)]
        conversation.append({"role": "user", "content": follow_up})
        chat_str = tok.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
        base_dists[chat_str] = _decay_dist(0.0, broad=True)
        ft_dists[chat_str] = _decay_dist(value=sharpness, broad=False)
        # The final turn only needs a dist — no further reply is generated.
        if idx < last_idx:
            reply = f"ft-response-{idx + 2} for {prompt}"
            ft_gens[chat_str] = reply
            conversation.append({"role": "assistant", "content": reply})

    backend = DummyDifferentialBackend(
        base=DummyResponses(token_dists=base_dists),
        ft=DummyResponses(token_dists=ft_dists, generations=ft_gens),
    )
    return _attach_tokenizer(backend, tok)