1 """Tests for :mod:`dlm_sway.probes.multi_turn_coherence`."""
2
3 from __future__ import annotations
4
5 import math
6
7 import numpy as np
8
9 from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
10 from dlm_sway.core.result import Verdict
11 from dlm_sway.core.scoring import TokenDist
12 from dlm_sway.probes.base import RunContext, build_probe
13 from dlm_sway.probes.multi_turn_coherence import (
14 _ascii_sparkline,
15 _chat_template_of,
16 _fallback_format,
17 _fit_half_life_turns,
18 _verdict_from_half_life,
19 )
20
21 # ---------------------------------------------------------------------------
22 # Fake tokenizer + dist-injection helpers
23 # ---------------------------------------------------------------------------
24
25
26 class _FakeTokenizer:
27 """Minimal stand-in for an HF tokenizer with a chat_template.
28
29 Implements just enough of the surface :mod:`multi_turn_coherence`
30 consults: the ``chat_template`` attribute and an
31 ``apply_chat_template`` method that returns a concatenated
32 role-marker string. Real HF chat templates are Jinja-rendered;
33 ours is a deterministic string concat so per-turn comparisons
34 are reproducible.
35 """
36
37 chat_template = "fake-jinja-template-string"
38
39 def apply_chat_template(
40 self,
41 messages: list[dict[str, str]],
42 *,
43 tokenize: bool = False,
44 add_generation_prompt: bool = False,
45 ) -> str:
46 del tokenize # we always return text
47 parts = [f"{m['role']}::{m['content']}" for m in messages]
48 if add_generation_prompt:
49 parts.append("assistant::")
50 return "|".join(parts)
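
    # Illustrative rendering (a sketch of the format above, not exercised
    # as a standalone test):
    #   _FakeTokenizer().apply_chat_template(
    #       [{"role": "user", "content": "hi"}],
    #       add_generation_prompt=True,
    #   )
    #   -> "user::hi|assistant::"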


def _attach_tokenizer(
    backend: DummyDifferentialBackend, tokenizer: object | None
) -> DummyDifferentialBackend:
    """Slot a tokenizer onto the dummy backend the same way HF does.

    multi_turn_coherence reads ``ctx.backend._tokenizer`` to find the
    chat template — mirrors prompt_collapse's _peek_backend_tokenizer.
    """
    backend._tokenizer = tokenizer  # type: ignore[attr-defined]
    return backend


def _decay_dist(value: float, *, broad: bool, k: int = 8) -> TokenDist:
    """Build a TokenDist whose top-k logprobs encode a known signal.

    ``value`` controls how peaked the distribution is — feeding two
    such dists into the divergence helper produces a deterministic
    KL we can plant per-turn for the curve-fit tests.
    """
    if broad:
        # Roughly uniform with a tiny perturbation (clears the
        # _UNIFORM_LOGPROB_TOL guard).
        lp = np.full(k, -math.log(k), dtype=np.float32)
        lp += np.linspace(-1e-4, 1e-4, k, dtype=np.float32)
    else:
        # Sharp: most mass on the first token, with `value` controlling
        # how peaked. Larger value ⇒ sharper.
        lp = np.array([-0.01 * value] + [-1.0 - value] * (k - 1), dtype=np.float32)
    return TokenDist(
        token_ids=np.arange(k, dtype=np.int64),
        logprobs=lp,
        vocab_size=1000,
        tail_logprob=None,
    )
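
# Rough intuition for the planted signal (a sketch, not asserted anywhere):
# with k=8 the broad dist sits near logprob -log(8) ≈ -2.08 on every token,
# while a sharp dist at value=3.0 is roughly [-0.03, -4.0, -4.0, ...]. The
# farther apart the two shapes, the larger the divergence the probe sees,
# so ramping `value` down across turns plants a decreasing per-turn KL
# curve, and ramping it up an increasing one.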


# ---------------------------------------------------------------------------
# Skip / error paths
# ---------------------------------------------------------------------------


class TestSkipPaths:
    def test_skips_when_no_tokenizer(self) -> None:
        """Dummy backend has no tokenizer ⇒ probe SKIPs cleanly."""
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        probe, spec = build_probe(
            {
                "name": "mtc",
                "kind": "multi_turn_coherence_decay",
                "prompts": ["hello"],
            }
        )
        result = probe.run(spec, RunContext(backend=backend))
        assert result.verdict == Verdict.SKIP
        assert "chat_template" in result.message

    def test_skips_when_tokenizer_lacks_chat_template(self) -> None:
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())

        class _BareTokenizer:
            chat_template = None

        _attach_tokenizer(backend, _BareTokenizer())
        probe, spec = build_probe(
            {
                "name": "mtc",
                "kind": "multi_turn_coherence_decay",
                "prompts": ["hello"],
            }
        )
        result = probe.run(spec, RunContext(backend=backend))
        assert result.verdict == Verdict.SKIP

    def test_errors_when_no_prompts(self) -> None:
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        _attach_tokenizer(backend, _FakeTokenizer())
        probe, spec = build_probe({"name": "mtc", "kind": "multi_turn_coherence_decay"})
        result = probe.run(spec, RunContext(backend=backend))
        assert result.verdict == Verdict.ERROR


# ---------------------------------------------------------------------------
# Happy path: planted decay → curve fit recovers half-life
# ---------------------------------------------------------------------------


class TestHappyPath:
140 """End-to-end probe runs against the dummy backend.
141
142 Tests verify *wiring* — the probe correctly drives the chat loop,
143 asks for next-token-dists at the right positions, and produces a
144 finalized result with the documented evidence shape. The exact
145 KL values depend on the dummy backend's synthetic-distribution
146 math; tests assert the curve is monotonic / non-monotonic / flat
147 as intended rather than pinning specific KL targets. Curve-fit
148 correctness is covered separately in :class:`TestFitHalfLife`.
149 """

    def test_decreasing_curve_yields_finite_half_life(self) -> None:
        """Decreasing planted KLs → 'ok' fit status, finite half-life."""
        backend = _backend_with_decreasing_dists(
            prompt="hello", max_turns=4, sharpness_per_turn=[3.0, 1.0, 0.3]
        )
        probe, spec = build_probe(
            {
                "name": "mtc",
                "kind": "multi_turn_coherence_decay",
                "prompts": ["hello"],
                "max_turns": 4,
                "assert_half_life_turns": 0.5,
            }
        )
        result = probe.run(spec, RunContext(backend=backend))
        assert result.verdict in {Verdict.PASS, Verdict.FAIL}, result.message
        assert result.evidence["fit_status"] in {"ok", "stable"}
        per_turn = result.evidence["per_turn_kls"]
        assert len(per_turn) == 3
        # Monotonic decrease — verifies the probe wired turns in order.
        assert per_turn[0] > per_turn[1] > per_turn[2]
        assert result.raw is not None
        assert result.raw > 0.0

    def test_flat_curve_marked_stable(self) -> None:
        """Same dists at every turn → fit_status=stable → PASS."""
        backend = _backend_with_decreasing_dists(
            prompt="hello", max_turns=4, sharpness_per_turn=[1.0, 1.0, 1.0]
        )
        probe, spec = build_probe(
            {
                "name": "mtc",
                "kind": "multi_turn_coherence_decay",
                "prompts": ["hello"],
                "max_turns": 4,
            }
        )
        result = probe.run(spec, RunContext(backend=backend))
        assert result.verdict == Verdict.PASS
        assert result.evidence["fit_status"] == "stable"
        assert result.score == 1.0
        assert "held coherence" in result.message

    def test_growing_curve_warns(self) -> None:
        """Increasing planted KLs → fit_status=non_monotonic → WARN."""
        backend = _backend_with_decreasing_dists(
            prompt="hello", max_turns=4, sharpness_per_turn=[0.3, 1.0, 3.0]
        )
        probe, spec = build_probe(
            {
                "name": "mtc",
                "kind": "multi_turn_coherence_decay",
                "prompts": ["hello"],
                "max_turns": 4,
            }
        )
        result = probe.run(spec, RunContext(backend=backend))
        assert result.verdict == Verdict.WARN
        assert result.evidence["fit_status"] == "non_monotonic"

    def test_evidence_carries_turns_and_sparkline(self) -> None:
        backend = _backend_with_decreasing_dists(
            prompt="hello", max_turns=4, sharpness_per_turn=[2.0, 1.0, 0.5]
        )
        probe, spec = build_probe(
            {
                "name": "mtc",
                "kind": "multi_turn_coherence_decay",
                "prompts": ["hello"],
                "max_turns": 4,
            }
        )
        result = probe.run(spec, RunContext(backend=backend))
        assert result.evidence["turns_axis"] == [2.0, 3.0, 4.0]
        assert result.evidence["max_turns"] == 4
        assert result.evidence["num_prompts"] == 1
        assert isinstance(result.evidence["sparkline"], str)
        assert len(result.evidence["sparkline"]) == 3
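
    # Note on the axis: the first turn only seeds the conversation (the
    # fixture plants an ft generation for it, not a dist), so per-turn KLs
    # exist from turn 2 onward; hence turns_axis == [2.0, 3.0, 4.0] with
    # max_turns - 1 == 3 entries. See _backend_with_decreasing_dists below.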


# ---------------------------------------------------------------------------
# Curve fit unit tests (math only, no backend)
# ---------------------------------------------------------------------------


class TestFitHalfLife:
    def test_clean_exponential_recovers_half_life(self) -> None:
        # y = exp(-0.693 * x) ⇒ half-life = 1 turn
        turns = np.array([2.0, 3.0, 4.0])
        kls = np.exp(-math.log(2.0) * turns)
        h, status = _fit_half_life_turns(turns, kls, max_turns=4)
        assert status == "ok"
        assert h is not None
        assert math.isclose(h, 1.0, rel_tol=1e-3)
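
    # Math sketch for the planted curve above: an exponential decay
    # kl(t) = A * exp(-lam * t) has half-life ln(2) / lam (the t-step
    # that halves kl). The test uses lam = log(2) and A = 1, so the
    # recovered half-life should be log(2) / log(2) = 1.0 turns.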

    def test_stable_returns_saturation(self) -> None:
        turns = np.array([2.0, 3.0, 4.0])
        kls = np.array([0.5, 0.5, 0.5])
        h, status = _fit_half_life_turns(turns, kls, max_turns=4)
        assert status == "stable"
        assert h == 40.0  # max_turns * 10

    def test_growing_returns_non_monotonic(self) -> None:
        turns = np.array([2.0, 3.0, 4.0])
        kls = np.array([0.1, 0.3, 0.9])
        h, status = _fit_half_life_turns(turns, kls, max_turns=4)
        assert status == "non_monotonic"
        assert h is None

    def test_all_zero_returns_degenerate(self) -> None:
        turns = np.array([2.0, 3.0, 4.0])
        kls = np.array([0.0, 0.0, 0.0])
        h, status = _fit_half_life_turns(turns, kls, max_turns=4)
        assert status == "degenerate"
        assert h == 0.0

    def test_partial_zero_drops_zero_points(self) -> None:
        """One zero-KL turn: drop it, fit on the remaining positives."""
        turns = np.array([2.0, 3.0, 4.0])
        kls = np.array([0.4, 0.0, 0.1])
        h, status = _fit_half_life_turns(turns, kls, max_turns=4)
        assert status == "ok"
        assert h is not None
        assert h > 0.0


# ---------------------------------------------------------------------------
# Verdict mapping (math-free)
# ---------------------------------------------------------------------------


class TestVerdictMapping:
    def test_pass_on_half_life_above_target(self) -> None:
        v, s, msg = _verdict_from_half_life(
            half_life=3.0,
            fit_status="ok",
            target=2.0,
            mean_kls=[0.4, 0.2, 0.1],
            turns_axis=[2.0, 3.0, 4.0],
        )
        assert v == Verdict.PASS
        assert s == 1.0
        assert "half-life=3.00" in msg

    def test_fail_on_half_life_below_target(self) -> None:
        v, s, msg = _verdict_from_half_life(
            half_life=0.5,
            fit_status="ok",
            target=2.0,
            mean_kls=[0.4, 0.1, 0.02],
            turns_axis=[2.0, 3.0, 4.0],
        )
        assert v == Verdict.FAIL
        assert 0.0 < s < 1.0


# ---------------------------------------------------------------------------
# Sparkline + fallback formatter
# ---------------------------------------------------------------------------


class TestSparkline:
    def test_renders_one_char_per_value(self) -> None:
        out = _ascii_sparkline([0.4, 0.2, 0.1])
        assert len(out) == 3
        # Decreasing input ⇒ first bar at least as tall as the last
        assert out[0] >= out[-1]

    def test_flat_input_renders_uniform_mid(self) -> None:
        out = _ascii_sparkline([0.5, 0.5, 0.5])
        assert len(set(out)) == 1  # all the same bar

    def test_empty_input_returns_empty(self) -> None:
        assert _ascii_sparkline([]) == ""

    def test_non_finite_drops_and_marks(self) -> None:
        out = _ascii_sparkline([0.4, math.inf, 0.1])
        assert out[1] == "?"


class TestFallbackFormatter:
    def test_concatenates_with_role_markers(self) -> None:
        msg = _fallback_format(
            [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "yo"}],
            add_generation_prompt=True,
        )
        assert "USER: hi" in msg
        assert "ASSISTANT: yo" in msg
        assert msg.endswith("ASSISTANT:")
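
    # The exact separator between role-marker lines is an implementation
    # detail of _fallback_format; the assertions above deliberately pin
    # only the "ROLE: content" substrings and the trailing "ASSISTANT:"
    # generation marker, so whitespace changes won't break this test.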


class TestChatTemplateDetection:
    def test_returns_none_when_tokenizer_is_none(self) -> None:
        assert _chat_template_of(None) is None

    def test_returns_none_when_no_attribute(self) -> None:
        class _T:
            pass

        assert _chat_template_of(_T()) is None

    def test_returns_template_when_set(self) -> None:
        assert _chat_template_of(_FakeTokenizer()) == "fake-jinja-template-string"


# ---------------------------------------------------------------------------
# Test infrastructure
# ---------------------------------------------------------------------------


def _backend_with_decreasing_dists(
    *,
    prompt: str,
    max_turns: int,
    sharpness_per_turn: list[float],
) -> DummyDifferentialBackend:
    """Build a dummy backend whose per-turn ft TokenDists vary in sharpness.

    Base view always returns the same broad (≈uniform) dist. ft view
    returns a dist with a controllable sharpness per turn: larger
    sharpness ⇒ more peaked ⇒ larger KL from the broad base. The
    sharpness sequence drives the curve shape (decreasing,
    increasing, flat) without trying to plant exact KL values — that
    coupling proved fragile in the first cut.

    Plants the chat strings the probe will see by replaying the
    probe's per-prompt loop with the same fake tokenizer + the same
    follow-up cycle.
    """
    if len(sharpness_per_turn) != max_turns - 1:
        raise ValueError(f"need {max_turns - 1} sharpness values, got {len(sharpness_per_turn)}")

    tok = _FakeTokenizer()
    follow_ups_default = [
        "Continue.",
        "Tell me more.",
        "Can you elaborate?",
        "What else?",
        "Go deeper.",
        "Expand on that.",
        "And then?",
    ]

    base_dists: dict[str, TokenDist] = {}
    ft_dists: dict[str, TokenDist] = {}
    ft_gens: dict[str, str] = {}

    messages: list[dict[str, str]] = [{"role": "user", "content": prompt}]
    t1_input = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    ft_gens[t1_input] = f"ft-response-1 for {prompt}"
    messages.append({"role": "assistant", "content": ft_gens[t1_input]})

    for turn_idx, sharpness in enumerate(sharpness_per_turn):
        follow_up = follow_ups_default[turn_idx % len(follow_ups_default)]
        messages.append({"role": "user", "content": follow_up})
        chat_str = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        base_dists[chat_str] = _decay_dist(0.0, broad=True)
        ft_dists[chat_str] = _decay_dist(value=sharpness, broad=False)
        if turn_idx < len(sharpness_per_turn) - 1:
            ft_gens[chat_str] = f"ft-response-{turn_idx + 2} for {prompt}"
            messages.append({"role": "assistant", "content": ft_gens[chat_str]})

    backend = DummyDifferentialBackend(
        base=DummyResponses(token_dists=base_dists),
        ft=DummyResponses(token_dists=ft_dists, generations=ft_gens),
    )
    _attach_tokenizer(backend, tok)
    return backend
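
# Typical wiring (mirrors TestHappyPath): one sharpness value per follow-up
# turn, i.e. max_turns - 1 of them, decreasing to plant a decaying KL curve:
#   backend = _backend_with_decreasing_dists(
#       prompt="hello", max_turns=4, sharpness_per_turn=[3.0, 1.0, 0.3]
#   )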