tenseleyflow/sway / 051144b


probes/multi_turn_coherence: KL decay across dialogue turns + exp half-life fit

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA      051144b1cf3131ce326d2cb1d60321313836b2ed
Parents  a86a866
Tree     cbe9fd7

1 changed file

Status  File                                          +    -
A       src/dlm_sway/probes/multi_turn_coherence.py  425    0

src/dlm_sway/probes/multi_turn_coherence.py (added)
@@ -0,0 +1,425 @@
+"""Multi-turn coherence decay — does the adapter survive a multi-turn dialogue?
+
+Every other adherence probe is single-turn: one user message, one
+completion, one score. Adapters that pass single-turn probes
+frequently "forget their training" by turn 2 or 3 of real dialogue,
+where the model's own previous responses enter the context window
+and create compounding drift. No other shipped sway probe catches
+that failure mode.
+
+The probe rolls a multi-turn synthetic dialogue per prompt:
+
+1. Generate ft's turn-1 response greedily.
+2. Build a turn-2 chat history `[user=prompt, asst=ft_t1,
+   user=follow_up_2]` and compute `KL(base || ft)` at turn 2.
+3. Extend with ft's turn-2 response, build turn-3 history, score.
+4. Repeat through `max_turns`.
+5. Fit `kl = a · exp(-b · turn)` over turns 2..N; report
+   `half_life_turns = ln(2) / b`.
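+
+For example, a fitted decay rate of `b = 0.35` gives
+`half_life_turns = ln(2) / 0.35 ≈ 1.98`: the adapter's KL
+signature halves roughly every two dialogue turns.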
+
+Lower half-life ⇒ adapter influence evaporates faster as dialogue
+deepens. A "stable" adapter (near-flat KL across turns) reports a
+saturated half-life with a "stable" marker rather than `inf`.
+
+## Why no null calibration
+
+Mirrors :mod:`prompt_collapse`: a null adapter has random-noise
+weights with no real signal to decay. Its turn-1 greedy output is
+essentially gibberish; feeding it back through the chat template
+makes the per-turn KLs meaningless, and the resulting half-life
+distribution is undefined. Fixed-threshold verdicts are the
+published path.
+
+## Chat-template requirement
+
+Multi-turn requires the base's chat template to format turns. The
+probe consults the backend's tokenizer for `chat_template`; bases
+without one (raw completion models) SKIP gracefully with a clear
+reason via :class:`Verdict.SKIP`.
+
+Tests that exercise the math path inject a minimal fake tokenizer
+on the dummy backend — see ``tests/unit/test_probe_multi_turn_coherence.py``.
+"""
+
+from __future__ import annotations
+
+import math
+from typing import Any, Literal
+
+import numpy as np
+from pydantic import Field
+
+from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize
+from dlm_sway.probes._divergence import Divergence, divergence
+from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
+
+
+class MultiTurnCoherenceSpec(ProbeSpec):
+    """Spec for ``kind: multi_turn_coherence_decay``."""
+
+    kind: Literal["multi_turn_coherence_decay"] = "multi_turn_coherence_decay"
+    prompts: list[str] = Field(default_factory=list, min_length=0)
+    """Inline turn-1 user messages. Empty list → probe ERRORs (the
+    .dlm autogen path doesn't yet seed multi-turn cases — when it
+    does, this path will mirror :mod:`delta_kl`'s prompts_from)."""
+    max_turns: int = Field(default=4, ge=2, le=8)
+    """How many dialogue turns to roll. Minimum 2 (otherwise this is
+    just :mod:`delta_kl`); cap 8 to keep the probe fast on real
+    backends — each additional turn requires another greedy
+    generation + a backend toggle pair."""
+    max_new_tokens: int = 96
+    """Greedy decode budget for ft's per-turn response. 96 is
+    deliberately conservative so a long-winded adapter doesn't blow
+    out the context budget on later turns."""
+    follow_ups: list[str] = Field(
+        default_factory=lambda: [
+            "Continue.",
+            "Tell me more.",
+            "Can you elaborate?",
+            "What else?",
+            "Go deeper.",
+            "Expand on that.",
+            "And then?",
+        ]
+    )
+    """Generic per-turn follow-up prompts cycled through to drive the
+    dialogue forward. Cycled so any ``max_turns`` works without
+    needing per-prompt customization. Future: prefer doc-author-
+    written follow-ups when ``ctx.sections`` carries instruction
+    blocks with explicit multi-turn structure."""
+    divergence: Divergence = "kl"
+    """Per-turn divergence metric. ``kl`` is the convention here
+    because we want a directional measure (`base || ft`); ``js``
+    is symmetric."""
+    top_k: int | None = None
+    assert_half_life_turns: float = 2.0
+    """Pass criterion: the fitted half-life must be at least this
+    many turns, i.e. adapter influence persists past turn 2 by at
+    least one half-life. Tune upward for adapters that need to hold
+    over longer conversations."""
+
+    # No null-calibration fields (no assert_z_gte) — see module docstring.
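+
+    # Construction sketch (hypothetical values; the usual path is a
+    # .dlm config block with ``kind: multi_turn_coherence_decay``):
+    #
+    #   MultiTurnCoherenceSpec(
+    #       name="persona-coherence",
+    #       prompts=["Introduce yourself."],
+    #       max_turns=4,
+    #       assert_half_life_turns=2.0,
+    #   )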
+
+
+class MultiTurnCoherenceProbe(Probe):
+    """The "did the adapter forget its training by turn 3?" probe."""
+
+    kind = "multi_turn_coherence_decay"
+    spec_cls = MultiTurnCoherenceSpec
+    category = "adherence"
+
+    # As noted in the module docstring: a null adapter has no
+    # coherence to decay, so a null distribution of half_life_turns
+    # is meaningless. Skip the calibration handshake entirely.
+
+    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
+        assert isinstance(spec, MultiTurnCoherenceSpec)
+        if not spec.prompts:
+            return ProbeResult(
+                name=spec.name,
+                kind=spec.kind,
+                verdict=Verdict.ERROR,
+                score=None,
+                message="no prompts provided (inline 'prompts' was empty)",
+            )
+
+        tokenizer = _peek_backend_tokenizer(ctx)
+        chat_template = _chat_template_of(tokenizer)
+        if chat_template is None:
+            return ProbeResult(
+                name=spec.name,
+                kind=spec.kind,
+                verdict=Verdict.SKIP,
+                score=None,
+                message="base has no chat_template; multi-turn dialogue requires one",
+            )
+
+        top_k = spec.top_k if spec.top_k is not None else ctx.top_k
+        # Per-turn KLs aggregated across prompts. ``per_turn_kls[t]``
+        # collects one KL per prompt at turn t+2 (turn 1 is the seed;
+        # turns 2..N are the scored positions).
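+        # With the default max_turns=4 that's three buckets, holding
+        # per-prompt KLs for turns 2, 3, and 4.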
+        per_turn_kls: list[list[float]] = [[] for _ in range(spec.max_turns - 1)]
+
+        for prompt in spec.prompts:
+            # Turn 1: ft generates the seed assistant response that
+            # populates the dialogue history. Generated under ft only —
+            # base never produces a turn-1 of its own; both views score
+            # against ft's history, which is the load-bearing design choice.
+            messages: list[dict[str, str]] = [{"role": "user", "content": prompt}]
+            with ctx.require_backend.as_finetuned() as fv:
+                ft_t1 = fv.generate(
+                    _format_chat(messages, tokenizer, add_generation_prompt=True),
+                    max_new_tokens=spec.max_new_tokens,
+                )
+            messages.append({"role": "assistant", "content": ft_t1})
+
+            for turn_idx in range(spec.max_turns - 1):
+                # Append the per-turn user follow-up (cycled).
+                follow_up = spec.follow_ups[turn_idx % len(spec.follow_ups)]
+                messages.append({"role": "user", "content": follow_up})
+                # Score next-token dist at the current turn under both views.
+                chat_str = _format_chat(messages, tokenizer, add_generation_prompt=True)
+                with ctx.require_backend.as_base() as bv:
+                    base_dist = bv.next_token_dist(chat_str, top_k=top_k)
+                with ctx.require_backend.as_finetuned() as fv:
+                    ft_dist = fv.next_token_dist(chat_str, top_k=top_k)
+                per_turn_kls[turn_idx].append(divergence(base_dist, ft_dist, kind=spec.divergence))
+                # Extend history with ft's response for the next iteration.
+                # Skip on the final turn — saves one generation call.
+                if turn_idx < spec.max_turns - 2:
+                    with ctx.require_backend.as_finetuned() as fv:
+                        ft_response = fv.generate(chat_str, max_new_tokens=spec.max_new_tokens)
+                    messages.append({"role": "assistant", "content": ft_response})
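+                    # History now reads [user=prompt, asst=ft_t1,
+                    #   user=follow_ups[0], asst=ft_t2, ...]; the next
+                    #   iteration appends follow_ups[1] and re-scores.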
+
+        # Mean KL per turn (across prompts).
+        mean_kls: list[float] = [float(np.mean(turn_kls)) for turn_kls in per_turn_kls]
+        # Turn axis: 2, 3, 4, ..., max_turns. Turn 1 isn't scored
+        # (it's the seed); turns 2..N are the curve points.
+        turns_axis = np.asarray(list(range(2, spec.max_turns + 1)), dtype=np.float64)
+        kls_axis = np.asarray(mean_kls, dtype=np.float64)
+
+        half_life, fit_status = _fit_half_life_turns(turns_axis, kls_axis, max_turns=spec.max_turns)
+
+        verdict, score, message = _verdict_from_half_life(
+            half_life=half_life,
+            fit_status=fit_status,
+            target=spec.assert_half_life_turns,
+            mean_kls=mean_kls,
+            turns_axis=turns_axis.tolist(),
+        )
+        return safe_finalize(
+            name=spec.name,
+            kind=spec.kind,
+            verdict=verdict,
+            score=score,
+            raw=half_life if half_life is not None and math.isfinite(half_life) else None,
+            evidence={
+                "per_turn_kls": mean_kls,
+                "turns_axis": turns_axis.tolist(),
+                "fit_status": fit_status,
+                "divergence_kind": spec.divergence,
+                "max_turns": spec.max_turns,
+                "num_prompts": len(spec.prompts),
+                "weight": spec.weight,
+                "sparkline": _ascii_sparkline(mean_kls),
+            },
+            message=message,
+        )
+
+
+# ---------------------------------------------------------------------------
+# Curve fit + verdict logic
+# ---------------------------------------------------------------------------
+
+
+def _fit_half_life_turns(
+    turns: np.ndarray, kls: np.ndarray, *, max_turns: int
+) -> tuple[float | None, str]:
+    """Fit ``kl = a * exp(-b * turn)`` via log-space linear regression.
+
+    Returns ``(half_life_turns, status)``.
+
+    Statuses:
+    - ``"ok"``: a clean exponential fit produced a finite half-life.
+    - ``"stable"``: KL stayed near-flat across turns. Half-life is
+      formally infinite; we clip to ``max_turns * 10`` so the report
+      doesn't print ``inf``. The probe interprets stable as
+      "adapter held coherence" — passing.
+    - ``"non_monotonic"``: KL grew with turn count (adapter becoming
+      *more* distinct as dialogue deepens — physically possible but
+      atypical). The half-life concept doesn't apply; we surface the
+      curve as evidence and let the user judge.
+    - ``"degenerate"``: KL was non-positive at all (or all but one)
+      of the turns. The adapter is producing identical-to-base
+      distributions; we report a half-life of 0 to flag a likely
+      no-op adapter.
+    """
+    # All-zero / non-positive KLs ⇒ no signal at any turn ⇒ probable no-op.
+    if not (kls > 0.0).all():
+        # Mixed positive + zero: still try the fit on the positives if there
+        # are at least 2 of them. All-zero / all-negative ⇒ degenerate.
+        positive_mask = kls > 0.0
+        if positive_mask.sum() < 2:
+            return 0.0, "degenerate"
+        turns = turns[positive_mask]
+        kls = kls[positive_mask]
+
+    # Flat-curve detection runs *before* the fit: if the KLs sit
+    # within a tight relative band of the mean, the slope is ~0 and
+    # the half-life is formally infinite. Pre-detecting this avoids
+    # an awkward "slope == 0 vs slope < epsilon" boundary inside the
+    # regression branch.
+    kl_mean = float(kls.mean())
+    if kl_mean > 0.0:
+        relative_spread = float((kls.max() - kls.min()) / kl_mean)
+        if relative_spread < 1e-3:
+            return float(max_turns) * 10.0, "stable"
+
+    log_y = np.log(kls)
+    x_mean = float(turns.mean())
+    y_mean = float(log_y.mean())
+    denom = float(((turns - x_mean) ** 2).sum())
+    # All turns identical (impossible for our turn-axis but defensive).
+    if denom == 0.0:
+        return None, "non_monotonic"
+    slope = float(((turns - x_mean) * (log_y - y_mean)).sum()) / denom
+    if slope >= 0.0:
+        # KL grew (or stayed exactly flat) with turn — a zero slope
+        # must not reach the division below, and growth could mean
+        # genuine multi-turn personality emergence. Leave half-life
+        # undefined and flag in the status.
+        return None, "non_monotonic"
+
+    half_life = float(math.log(2.0) * (-1.0 / slope))
+    # B11 — a near-stable adapter has a tiny negative slope; the fit
+    # produces an enormous half-life. Clip + label so the report
+    # doesn't print ``inf`` or implausible values.
+    saturation = float(max_turns) * 10.0
+    if half_life >= saturation or not math.isfinite(half_life):
+        return saturation, "stable"
+    return half_life, "ok"
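+
+# Illustrative sanity check of the fit on synthetic data (assumed
+# numbers, not a shipped doctest):
+#
+#     turns = np.asarray([2.0, 3.0, 4.0])
+#     kls = 0.8 * np.exp(-0.35 * turns)
+#     _fit_half_life_turns(turns, kls, max_turns=4)
+#     # -> (1.98, "ok")  because ln(2) / 0.35 ≈ 1.98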
+
+
+def _verdict_from_half_life(
+    *,
+    half_life: float | None,
+    fit_status: str,
+    target: float,
+    mean_kls: list[float],
+    turns_axis: list[float],
+) -> tuple[Verdict, float, str]:
+    """Translate the curve-fit result into a (verdict, score, message)."""
+    sparkline = _ascii_sparkline(mean_kls)
+
+    if fit_status == "non_monotonic":
+        # KL grew with turn — atypical but not necessarily wrong.
+        # Surface as WARN with the curve; user judgment from there.
+        return (
+            Verdict.WARN,
+            0.5,
+            f"non-monotonic KL across turns (adapter grew more distinct); curve {sparkline}",
+        )
+
+    if fit_status == "degenerate":
+        return (
+            Verdict.FAIL,
+            0.0,
+            f"KL is zero across all turns (probable no-op adapter); curve {sparkline}",
+        )
+
+    if half_life is None:
+        return (
+            Verdict.ERROR,
+            0.0,
+            f"could not fit decay curve (status={fit_status}); curve {sparkline}",
+        )
+
+    if fit_status == "stable":
+        return (
+            Verdict.PASS,
+            1.0,
+            f"adapter held coherence across all {int(turns_axis[-1])} turns; curve {sparkline}",
+        )
+
+    # fit_status == "ok"
+    passed = half_life >= target
+    score = float(min(1.0, half_life / max(target, 1e-6)))
+    return (
+        Verdict.PASS if passed else Verdict.FAIL,
+        score,
+        f"half-life={half_life:.2f} turns ({'≥' if passed else '<'} {target}); curve {sparkline}",
+    )
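+
+# e.g. against the default target of 2.0 turns: half_life=1.0 ⇒ FAIL
+# with score 0.5; half_life=3.0 ⇒ PASS with score clipped to 1.0.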
+
+
+# ---------------------------------------------------------------------------
+# Chat-template / tokenizer helpers
+# ---------------------------------------------------------------------------
+
+
+def _peek_backend_tokenizer(ctx: RunContext) -> Any | None:
+    """Same trick :mod:`prompt_collapse` uses: backends store the
+    tokenizer at ``_tokenizer``. Returns ``None`` for the dummy
+    backend (which has no tokenizer)."""
+    return getattr(ctx.backend, "_tokenizer", None)
+
+
+def _chat_template_of(tokenizer: Any | None) -> str | None:
+    """Return the tokenizer's chat_template string, or ``None``.
+
+    HF tokenizers expose a ``chat_template`` attribute that's either
+    a Jinja string or ``None``. Some custom tokenizers don't have the
+    attribute at all — covered by the ``getattr`` default.
+    """
+    if tokenizer is None:
+        return None
+    template = getattr(tokenizer, "chat_template", None)
+    if not isinstance(template, str) or not template.strip():
+        return None
+    return template
+
+
+def _format_chat(
+    messages: list[dict[str, str]],
+    tokenizer: Any,
+    *,
+    add_generation_prompt: bool,
+) -> str:
+    """Render `messages` via the tokenizer's chat template.
+
+    Caller has already verified the tokenizer carries a chat_template
+    (via :func:`_chat_template_of`); a ``None`` tokenizer here is a
+    bug, not a runtime fallback path. We still defensively fall
+    back to a minimal role-marker concatenation so the dummy
+    backend's tests can drive the probe end-to-end without standing
+    up a real tokenizer.
+    """
+    if tokenizer is None:
+        return _fallback_format(messages, add_generation_prompt=add_generation_prompt)
+    try:
+        out = tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=add_generation_prompt
+        )
+        return str(out)
+    except Exception:  # noqa: BLE001 — tokenizer impls vary; never let a probe crash on this
+        return _fallback_format(messages, add_generation_prompt=add_generation_prompt)
+
+
+def _fallback_format(messages: list[dict[str, str]], *, add_generation_prompt: bool) -> str:
+    """Minimal role-marker formatter for tokenizers that misbehave + tests."""
+    parts: list[str] = []
+    for m in messages:
+        parts.append(f"{m['role'].upper()}: {m['content']}")
+    if add_generation_prompt:
+        parts.append("ASSISTANT:")
+    return "\n".join(parts)
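+
+# Fallback rendering of [user="Hi", assistant="Hello"] with a
+# generation prompt (illustrative):
+#
+#     USER: Hi
+#     ASSISTANT: Hello
+#     ASSISTANT: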
+
+
+# ---------------------------------------------------------------------------
+# Report sparkline
+# ---------------------------------------------------------------------------
+
+
+_SPARKLINE_BARS = "▁▂▃▄▅▆▇█"
+
+
+def _ascii_sparkline(values: list[float]) -> str:
+    """Compact unicode sparkline of per-turn KL.
+
+    Surfaces in the report message so a terminal-only reader gets a
+    sense of curve shape without opening the JSON. Empty/degenerate
+    inputs render as an empty string.
+    """
+    if not values:
+        return ""
+    finite = [v for v in values if math.isfinite(v) and v >= 0.0]
+    if not finite:
+        return ""
+    lo, hi = min(finite), max(finite)
+    span = hi - lo
+    if span <= 1e-12:
+        # Flat line — every bar at mid-height.
+        return _SPARKLINE_BARS[len(_SPARKLINE_BARS) // 2] * len(values)
+    out: list[str] = []
+    for v in values:
+        if not math.isfinite(v) or v < 0.0:
+            out.append("?")
+            continue
+        idx = int((v - lo) / span * (len(_SPARKLINE_BARS) - 1))
+        out.append(_SPARKLINE_BARS[max(0, min(idx, len(_SPARKLINE_BARS) - 1))])
+    return "".join(out)