tenseleyflow/sway / d955dfd

Browse files

tests/unit: 22 tests for multi_turn_coherence probe + curve-fit math

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
d955dfd4241120abd9ba029f0d815330a3a1ed64
Parents
ecca902
Tree
cbe3f97

1 changed file

StatusFile+-
A tests/unit/test_probe_multi_turn_coherence.py 418 0
tests/unit/test_probe_multi_turn_coherence.pyadded
@@ -0,0 +1,418 @@
1
+"""Tests for :mod:`dlm_sway.probes.multi_turn_coherence`."""
2
+
3
+from __future__ import annotations
4
+
5
+import math
6
+
7
+import numpy as np
8
+
9
+from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
10
+from dlm_sway.core.result import Verdict
11
+from dlm_sway.core.scoring import TokenDist
12
+from dlm_sway.probes.base import RunContext, build_probe
13
+from dlm_sway.probes.multi_turn_coherence import (
14
+    _ascii_sparkline,
15
+    _chat_template_of,
16
+    _fallback_format,
17
+    _fit_half_life_turns,
18
+    _verdict_from_half_life,
19
+)
20
+
21
+# ---------------------------------------------------------------------------
22
+# Fake tokenizer + dist-injection helpers
23
+# ---------------------------------------------------------------------------
24
+
25
+
26
+class _FakeTokenizer:
27
+    """Minimal stand-in for an HF tokenizer with a chat_template.
28
+
29
+    Implements just enough of the surface :mod:`multi_turn_coherence`
30
+    consults: the ``chat_template`` attribute and an
31
+    ``apply_chat_template`` method that returns a concatenated
32
+    role-marker string. Real HF chat templates are Jinja-rendered;
33
+    ours is a deterministic string concat so per-turn comparisons
34
+    are reproducible.
35
+    """
36
+
37
+    chat_template = "fake-jinja-template-string"
38
+
39
+    def apply_chat_template(
40
+        self,
41
+        messages: list[dict[str, str]],
42
+        *,
43
+        tokenize: bool = False,
44
+        add_generation_prompt: bool = False,
45
+    ) -> str:
46
+        del tokenize  # we always return text
47
+        parts = [f"{m['role']}::{m['content']}" for m in messages]
48
+        if add_generation_prompt:
49
+            parts.append("assistant::")
50
+        return "|".join(parts)
51
+
52
+
53
+def _attach_tokenizer(
54
+    backend: DummyDifferentialBackend, tokenizer: object | None
55
+) -> DummyDifferentialBackend:
56
+    """Slot a tokenizer onto the dummy backend the same way HF does.
57
+
58
+    multi_turn_coherence reads ``ctx.backend._tokenizer`` to find the
59
+    chat template — mirrors prompt_collapse's _peek_backend_tokenizer.
60
+    """
61
+    backend._tokenizer = tokenizer  # type: ignore[attr-defined]
62
+    return backend
63
+
64
+
65
def _decay_dist(value: float, *, broad: bool, k: int = 8) -> TokenDist:
    """Build a TokenDist whose top-k logprobs encode a known signal.

    ``value`` controls how peaked the distribution is. Pairing a broad
    dist with a sharp one yields a deterministic KL, which the
    curve-fit tests plant per turn.
    """
    if broad:
        # Near-uniform plus a tiny linear perturbation, enough to clear
        # the _UNIFORM_LOGPROB_TOL guard in the probe.
        logprobs = np.full(k, -math.log(k), dtype=np.float32)
        logprobs = logprobs + np.linspace(-1e-4, 1e-4, k, dtype=np.float32)
    else:
        # Peaked: almost all mass on token 0. Larger `value` ⇒ sharper.
        head = -0.01 * value
        tail = -1.0 - value
        logprobs = np.array([head] + [tail] * (k - 1), dtype=np.float32)
    return TokenDist(
        token_ids=np.arange(k, dtype=np.int64),
        logprobs=logprobs,
        vocab_size=1000,
        tail_logprob=None,
    )
87
+
88
+
89
+# ---------------------------------------------------------------------------
90
+# Skip / error paths
91
+# ---------------------------------------------------------------------------
92
+
93
+
94
class TestSkipPaths:
    @staticmethod
    def _spec_with_prompt() -> dict:
        # Minimal well-formed probe spec used by the skip-path tests.
        return {
            "name": "mtc",
            "kind": "multi_turn_coherence_decay",
            "prompts": ["hello"],
        }

    def test_skips_when_no_tokenizer(self) -> None:
        """No tokenizer on the dummy backend ⇒ probe SKIPs cleanly."""
        dummy = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        probe, spec = build_probe(self._spec_with_prompt())
        outcome = probe.run(spec, RunContext(backend=dummy))
        assert outcome.verdict == Verdict.SKIP
        assert "chat_template" in outcome.message

    def test_skips_when_tokenizer_lacks_chat_template(self) -> None:
        dummy = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        # Tokenizer present but with a null template — still unusable.
        bare = type("_BareTokenizer", (), {"chat_template": None})()
        _attach_tokenizer(dummy, bare)
        probe, spec = build_probe(self._spec_with_prompt())
        outcome = probe.run(spec, RunContext(backend=dummy))
        assert outcome.verdict == Verdict.SKIP

    def test_errors_when_no_prompts(self) -> None:
        dummy = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        _attach_tokenizer(dummy, _FakeTokenizer())
        probe, spec = build_probe({"name": "mtc", "kind": "multi_turn_coherence_decay"})
        outcome = probe.run(spec, RunContext(backend=dummy))
        assert outcome.verdict == Verdict.ERROR
132
+
133
+
134
+# ---------------------------------------------------------------------------
135
+# Happy path: planted decay → curve fit recovers half-life
136
+# ---------------------------------------------------------------------------
137
+
138
+
139
class TestHappyPath:
    """End-to-end runs of the probe against the dummy backend.

    These tests verify *wiring*: the probe drives the chat loop, asks
    for next-token dists at the right positions, and produces a
    finalized result with the documented evidence shape. The exact KL
    values depend on the dummy backend's synthetic-distribution math,
    so assertions check curve shape (decreasing / flat / growing)
    rather than pinning specific KL targets. Curve-fit correctness is
    covered separately in :class:`TestFitHalfLife`.
    """

    @staticmethod
    def _run(backend, **spec_extra):
        # Shared driver: build the standard 4-turn spec, with optional
        # per-test extras, and run the probe against `backend`.
        cfg = {
            "name": "mtc",
            "kind": "multi_turn_coherence_decay",
            "prompts": ["hello"],
            "max_turns": 4,
        }
        cfg.update(spec_extra)
        probe, spec = build_probe(cfg)
        return probe.run(spec, RunContext(backend=backend))

    def test_decreasing_curve_yields_finite_half_life(self) -> None:
        """Decreasing planted KLs → 'ok' fit status, finite half-life."""
        dummy = _backend_with_decreasing_dists(
            prompt="hello", max_turns=4, sharpness_per_turn=[3.0, 1.0, 0.3]
        )
        res = self._run(dummy, assert_half_life_turns=0.5)
        assert res.verdict in {Verdict.PASS, Verdict.FAIL}, res.message
        assert res.evidence["fit_status"] in {"ok", "stable"}
        kls = res.evidence["per_turn_kls"]
        assert len(kls) == 3
        # Strictly decreasing — confirms the probe wired turns in order.
        assert all(a > b for a, b in zip(kls, kls[1:]))
        assert res.raw is not None
        assert res.raw > 0.0

    def test_flat_curve_marked_stable(self) -> None:
        """Identical dists at every turn → fit_status=stable → PASS."""
        dummy = _backend_with_decreasing_dists(
            prompt="hello", max_turns=4, sharpness_per_turn=[1.0, 1.0, 1.0]
        )
        res = self._run(dummy)
        assert res.verdict == Verdict.PASS
        assert res.evidence["fit_status"] == "stable"
        assert res.score == 1.0
        assert "held coherence" in res.message

    def test_growing_curve_warns(self) -> None:
        """Increasing planted KLs → fit_status=non_monotonic → WARN."""
        dummy = _backend_with_decreasing_dists(
            prompt="hello", max_turns=4, sharpness_per_turn=[0.3, 1.0, 3.0]
        )
        res = self._run(dummy)
        assert res.verdict == Verdict.WARN
        assert res.evidence["fit_status"] == "non_monotonic"

    def test_evidence_carries_turns_and_sparkline(self) -> None:
        dummy = _backend_with_decreasing_dists(
            prompt="hello", max_turns=4, sharpness_per_turn=[2.0, 1.0, 0.5]
        )
        evidence = self._run(dummy).evidence
        assert evidence["turns_axis"] == [2.0, 3.0, 4.0]
        assert evidence["max_turns"] == 4
        assert evidence["num_prompts"] == 1
        assert isinstance(evidence["sparkline"], str)
        assert len(evidence["sparkline"]) == 3
229
+
230
+
231
+# ---------------------------------------------------------------------------
232
+# Curve fit unit tests (math only, no backend)
233
+# ---------------------------------------------------------------------------
234
+
235
+
236
class TestFitHalfLife:
    def test_clean_exponential_recovers_half_life(self) -> None:
        # Plant y = 2**(-x), i.e. a half-life of exactly one turn.
        axis = np.array([2.0, 3.0, 4.0])
        curve = np.exp(-math.log(2.0) * axis)
        half, status = _fit_half_life_turns(axis, curve, max_turns=4)
        assert status == "ok"
        assert half is not None
        assert math.isclose(half, 1.0, rel_tol=1e-3)

    def test_stable_returns_saturation(self) -> None:
        axis = np.array([2.0, 3.0, 4.0])
        flat = np.array([0.5, 0.5, 0.5])
        half, status = _fit_half_life_turns(axis, flat, max_turns=4)
        assert status == "stable"
        assert half == 40.0  # saturation value: max_turns * 10

    def test_growing_returns_non_monotonic(self) -> None:
        axis = np.array([2.0, 3.0, 4.0])
        rising = np.array([0.1, 0.3, 0.9])
        half, status = _fit_half_life_turns(axis, rising, max_turns=4)
        assert status == "non_monotonic"
        assert half is None

    def test_all_zero_returns_degenerate(self) -> None:
        axis = np.array([2.0, 3.0, 4.0])
        zeros = np.array([0.0, 0.0, 0.0])
        half, status = _fit_half_life_turns(axis, zeros, max_turns=4)
        assert status == "degenerate"
        assert half == 0.0

    def test_partial_zero_drops_zero_points(self) -> None:
        """A single zero-KL turn is dropped; the remaining positives fit."""
        axis = np.array([2.0, 3.0, 4.0])
        mostly = np.array([0.4, 0.0, 0.1])
        half, status = _fit_half_life_turns(axis, mostly, max_turns=4)
        assert status == "ok"
        assert half is not None
        assert half > 0.0
275
+
276
+
277
+# ---------------------------------------------------------------------------
278
+# Verdict mapping (math-free)
279
+# ---------------------------------------------------------------------------
280
+
281
+
282
class TestVerdictMapping:
    def test_pass_on_half_life_above_target(self) -> None:
        verdict, score, message = _verdict_from_half_life(
            half_life=3.0,
            fit_status="ok",
            target=2.0,
            mean_kls=[0.4, 0.2, 0.1],
            turns_axis=[2.0, 3.0, 4.0],
        )
        assert verdict == Verdict.PASS
        assert score == 1.0
        assert "half-life=3.00" in message

    def test_fail_on_half_life_below_target(self) -> None:
        verdict, score, _message = _verdict_from_half_life(
            half_life=0.5,
            fit_status="ok",
            target=2.0,
            mean_kls=[0.4, 0.1, 0.02],
            turns_axis=[2.0, 3.0, 4.0],
        )
        assert verdict == Verdict.FAIL
        assert 0.0 < score < 1.0
305
+
306
+
307
+# ---------------------------------------------------------------------------
308
+# Sparkline + fallback formatter
309
+# ---------------------------------------------------------------------------
310
+
311
+
312
class TestSparkline:
    def test_renders_one_char_per_value(self) -> None:
        bars = _ascii_sparkline([0.4, 0.2, 0.1])
        assert len(bars) == 3
        # A decreasing series should start at its tallest bar.
        assert bars[0] >= bars[-1]

    def test_flat_input_renders_uniform_mid(self) -> None:
        bars = _ascii_sparkline([0.5, 0.5, 0.5])
        assert len(set(bars)) == 1  # one repeated glyph

    def test_empty_input_returns_empty(self) -> None:
        assert _ascii_sparkline([]) == ""

    def test_non_finite_drops_and_marks(self) -> None:
        bars = _ascii_sparkline([0.4, math.inf, 0.1])
        assert bars[1] == "?"
329
+
330
+
331
class TestFallbackFormatter:
    def test_concatenates_with_role_markers(self) -> None:
        history = [
            {"role": "user", "content": "hi"},
            {"role": "assistant", "content": "yo"},
        ]
        rendered = _fallback_format(history, add_generation_prompt=True)
        for marker in ("USER: hi", "ASSISTANT: yo"):
            assert marker in rendered
        # Generation prompt opens an empty assistant turn at the end.
        assert rendered.endswith("ASSISTANT:")
340
+
341
+
342
class TestChatTemplateDetection:
    def test_returns_none_when_tokenizer_is_none(self) -> None:
        assert _chat_template_of(None) is None

    def test_returns_none_when_no_attribute(self) -> None:
        # An object with no chat_template attribute at all.
        bare = type("_Bare", (), {})()
        assert _chat_template_of(bare) is None

    def test_returns_template_when_set(self) -> None:
        assert _chat_template_of(_FakeTokenizer()) == "fake-jinja-template-string"
354
+
355
+
356
+# ---------------------------------------------------------------------------
357
+# Test infrastructure
358
+# ---------------------------------------------------------------------------
359
+
360
+
361
def _backend_with_decreasing_dists(
    *,
    prompt: str,
    max_turns: int,
    sharpness_per_turn: list[float],
) -> DummyDifferentialBackend:
    """Build a dummy backend whose per-turn ft TokenDists vary in sharpness.

    Base view always returns the same broad (≈uniform) dist. ft view
    returns a dist with a controllable sharpness per turn: larger
    sharpness ⇒ more peaked ⇒ larger KL from the broad base. The
    sharpness sequence drives the curve shape (decreasing,
    increasing, flat) without trying to plant exact KL values — that
    coupling proved fragile in the first cut.

    Plants the chat strings the probe will see by replaying the
    probe's per-prompt loop with the same fake tokenizer + the same
    follow-up cycle. ``sharpness_per_turn`` must supply exactly
    ``max_turns - 1`` values (one per measured turn, turns 2..max_turns);
    anything else raises ValueError.
    """
    if len(sharpness_per_turn) != max_turns - 1:
        raise ValueError(f"need {max_turns - 1} sharpness values, got {len(sharpness_per_turn)}")

    tok = _FakeTokenizer()
    # NOTE(review): assumed to mirror the probe's own follow-up cycle —
    # if multi_turn_coherence changes its follow-ups, this replay drifts
    # and the planted chat-string keys stop matching. Verify on change.
    follow_ups_default = [
        "Continue.",
        "Tell me more.",
        "Can you elaborate?",
        "What else?",
        "Go deeper.",
        "Expand on that.",
        "And then?",
    ]

    # All three maps are keyed by the exact chat string the fake
    # tokenizer renders for that point in the conversation.
    base_dists: dict[str, TokenDist] = {}
    ft_dists: dict[str, TokenDist] = {}
    ft_gens: dict[str, str] = {}

    # Turn 1: only an ft generation is planted for the opening prompt —
    # no dists, consistent with turns_axis starting at 2.0 in the tests.
    messages: list[dict[str, str]] = [{"role": "user", "content": prompt}]
    t1_input = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    ft_gens[t1_input] = f"ft-response-1 for {prompt}"
    messages.append({"role": "assistant", "content": ft_gens[t1_input]})

    # Turns 2..max_turns: each follow-up gets a broad base dist and an
    # ft dist whose peakedness comes from sharpness_per_turn.
    for turn_idx, sharpness in enumerate(sharpness_per_turn):
        follow_up = follow_ups_default[turn_idx % len(follow_ups_default)]
        messages.append({"role": "user", "content": follow_up})
        chat_str = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        base_dists[chat_str] = _decay_dist(0.0, broad=True)
        ft_dists[chat_str] = _decay_dist(value=sharpness, broad=False)
        # The final measured turn needs no assistant reply: nothing is
        # rendered after it, so no generation (or message) is planted.
        if turn_idx < len(sharpness_per_turn) - 1:
            ft_gens[chat_str] = f"ft-response-{turn_idx + 2} for {prompt}"
            messages.append({"role": "assistant", "content": ft_gens[chat_str]})

    backend = DummyDifferentialBackend(
        base=DummyResponses(token_dists=base_dists),
        ft=DummyResponses(token_dists=ft_dists, generations=ft_gens),
    )
    _attach_tokenizer(backend, tok)
    return backend