1 """Multi-turn coherence decay — does the adapter survive a multi-turn dialogue?
2
3 Every other adherence probe is single-turn: one user message, one
4 completion, one score. Adapters that pass single-turn probes
5 frequently "forget their training" by turn 2 or 3 of real dialogue,
6 where the model's own previous responses enter the context window
7 and create compounding drift. No other shipped sway probe catches
8 that failure mode.
9
10 The probe rolls a multi-turn synthetic dialogue per prompt:
11
12 1. Generate ft's turn-1 response greedily.
13 2. Build a turn-2 chat history `[user=prompt, asst=ft_t1,
14 user=follow_up_2]` and compute `KL(base || ft)` at turn 2.
15 3. Extend with ft's turn-2 response, build turn-3 history, score.
16 4. Repeat through `max_turns`.
17 5. Fit `kl = a · exp(-b · turn)` over turns 2..N; report
18 `half_life_turns = ln(2) / b`.
19
20 Lower half-life ⇒ adapter influence evaporates faster as dialogue
21 deepens. A "stable" adapter (near-flat KL across turns) reports a
22 saturated half-life with a "stable" marker rather than `inf`.
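
For example, a fitted decay rate of `b = 0.35` gives
`half_life_turns = ln(2) / 0.35 ≈ 1.98`: the adapter's KL footprint
halves roughly every two turns of dialogue.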

## Why no null calibration

Mirrors :mod:`prompt_collapse`: a null adapter has random-noise
weights with no real signal to decay. Its turn-1 greedy output is
essentially gibberish; fed back through the chat template, it makes
the per-turn KLs meaningless, and the resulting half-life
distribution is undefined. Fixed-threshold verdicts are the
published path.

## Chat-template requirement

Multi-turn requires the base's chat template to format turns. The
probe consults the backend's tokenizer for `chat_template`; bases
without one (raw completion models) SKIP gracefully with a clear
reason via :class:`Verdict.SKIP`.

Tests that exercise the math path inject a minimal fake tokenizer
on the dummy backend — see ``tests/unit/test_probe_multi_turn_coherence.py``.
"""

from __future__ import annotations

import math
from typing import Any, Literal

import numpy as np
from pydantic import Field

from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize
from dlm_sway.probes._divergence import Divergence, divergence
from dlm_sway.probes.base import Probe, ProbeSpec, RunContext


class MultiTurnCoherenceSpec(ProbeSpec):
    """Spec for ``kind: multi_turn_coherence_decay``."""

    kind: Literal["multi_turn_coherence_decay"] = "multi_turn_coherence_decay"
    prompts: list[str] = Field(default_factory=list, min_length=0)
    """Inline turn-1 user messages. Empty list → probe ERRORs (the
    .dlm autogen path doesn't yet seed multi-turn cases — when it
    does, this path will mirror :mod:`delta_kl`'s prompts_from)."""
    max_turns: int = Field(default=4, ge=2, le=8)
    """How many dialogue turns to roll. Minimum 2 (otherwise this is
    just :mod:`delta_kl`); cap 8 to keep the probe fast on real
    backends — each additional turn requires another greedy
    generation + a backend toggle pair."""
    max_new_tokens: int = 96
    """Greedy decode budget for ft's per-turn response. 96 is
    deliberately conservative so a long-winded adapter doesn't blow
    out the context budget on later turns."""
    follow_ups: list[str] = Field(
        default_factory=lambda: [
            "Continue.",
            "Tell me more.",
            "Can you elaborate?",
            "What else?",
            "Go deeper.",
            "Expand on that.",
            "And then?",
        ]
    )
    """Generic per-turn follow-up prompts cycled through to drive the
    dialogue forward. Cycled so any ``max_turns`` works without
    needing per-prompt customization. Future: prefer doc-author-
    written follow-ups when ``ctx.sections`` carries instruction
    blocks with explicit multi-turn structure."""
    divergence: Divergence = "kl"
    """Per-turn divergence metric. ``kl`` is the convention here
    because we want a directional measure (`base || ft`); ``js``
    is symmetric."""
    top_k: int | None = None
    """Top-k truncation for the per-turn next-token distributions;
    ``None`` defers to ``ctx.top_k``."""
    assert_half_life_turns: float = 2.0
    """Pass criterion: the fitted half-life must be at least this
    many turns. Tune upward for adapters that need to hold over
    longer conversations."""

    # No null-calibration fields (no assert_z_gte) — see module docstring.
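
    # Illustrative inline construction (field values are examples, not
    # recommendations; ``name`` is read as ``spec.name`` by the probe and
    # is assumed to come from the shared ProbeSpec base):
    #
    #     MultiTurnCoherenceSpec(
    #         name="coherence",
    #         prompts=["Walk me through the refund flow."],
    #         max_turns=4,
    #     )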


class MultiTurnCoherenceProbe(Probe):
    """The "did the adapter forget its training by turn 3?" probe."""

    kind = "multi_turn_coherence_decay"
    spec_cls = MultiTurnCoherenceSpec
    category = "adherence"

    # As noted in the module docstring: a null adapter has no
    # coherence to decay, so a null distribution of half_life_turns
    # is meaningless. Skip the calibration handshake entirely.

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        assert isinstance(spec, MultiTurnCoherenceSpec)
        if not spec.prompts:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message="no prompts provided (inline 'prompts' was empty)",
            )

        tokenizer = _peek_backend_tokenizer(ctx)
        chat_template = _chat_template_of(tokenizer)
        if chat_template is None:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message="base has no chat_template; multi-turn dialogue requires one",
            )

        top_k = spec.top_k if spec.top_k is not None else ctx.top_k
        # Per-turn KLs aggregated across prompts. ``per_turn_kls[t]``
        # holds one KL per prompt for turn t+2 (turn 1 is the seed;
        # turns 2..N are the scored positions).
        per_turn_kls: list[list[float]] = [[] for _ in range(spec.max_turns - 1)]

        for prompt in spec.prompts:
            # Turn 1: ft generates the seed assistant response that
            # populates the dialogue history. Generated under ft only —
            # the base is never asked to produce its own turn-1; both
            # views score the same ft-authored history, which is the
            # load-bearing design choice.
            messages: list[dict[str, str]] = [{"role": "user", "content": prompt}]
            with ctx.require_backend.as_finetuned() as fv:
                ft_t1 = fv.generate(
                    _format_chat(messages, tokenizer, add_generation_prompt=True),
                    max_new_tokens=spec.max_new_tokens,
                )
            messages.append({"role": "assistant", "content": ft_t1})

            for turn_idx in range(spec.max_turns - 1):
                # Append the per-turn user follow-up (cycled).
                follow_up = spec.follow_ups[turn_idx % len(spec.follow_ups)]
                messages.append({"role": "user", "content": follow_up})
                # Score the next-token dist at the current turn under both views.
                chat_str = _format_chat(messages, tokenizer, add_generation_prompt=True)
                with ctx.require_backend.as_base() as bv:
                    base_dist = bv.next_token_dist(chat_str, top_k=top_k)
                with ctx.require_backend.as_finetuned() as fv:
                    ft_dist = fv.next_token_dist(chat_str, top_k=top_k)
                per_turn_kls[turn_idx].append(
                    divergence(base_dist, ft_dist, kind=spec.divergence)
                )
                # Extend history with ft's response for the next iteration.
                # Skip on the final turn — saves one generation call.
                if turn_idx < spec.max_turns - 2:
                    with ctx.require_backend.as_finetuned() as fv:
                        ft_response = fv.generate(chat_str, max_new_tokens=spec.max_new_tokens)
                    messages.append({"role": "assistant", "content": ft_response})
        # Mean KL per turn (across prompts).
        mean_kls: list[float] = [float(np.mean(turn_kls)) for turn_kls in per_turn_kls]
        # Turn axis: 2, 3, 4, ..., max_turns. Turn 1 isn't scored
        # (it's the seed); turns 2..N are the curve points.
        turns_axis = np.asarray(list(range(2, spec.max_turns + 1)), dtype=np.float64)
        kls_axis = np.asarray(mean_kls, dtype=np.float64)

        half_life, fit_status = _fit_half_life_turns(
            turns_axis, kls_axis, max_turns=spec.max_turns
        )

        verdict, score, message = _verdict_from_half_life(
            half_life=half_life,
            fit_status=fit_status,
            target=spec.assert_half_life_turns,
            mean_kls=mean_kls,
            turns_axis=turns_axis.tolist(),
        )
        return safe_finalize(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=half_life if half_life is not None and math.isfinite(half_life) else None,
            evidence={
                "per_turn_kls": mean_kls,
                "turns_axis": turns_axis.tolist(),
                "fit_status": fit_status,
                "divergence_kind": spec.divergence,
                "max_turns": spec.max_turns,
                "num_prompts": len(spec.prompts),
                "weight": spec.weight,
                "sparkline": _ascii_sparkline(mean_kls),
            },
            message=message,
        )


# ---------------------------------------------------------------------------
# Curve fit + verdict logic
# ---------------------------------------------------------------------------


def _fit_half_life_turns(
    turns: np.ndarray, kls: np.ndarray, *, max_turns: int
) -> tuple[float | None, str]:
    """Fit ``kl = a * exp(-b * turn)`` via log-space linear regression.

    Returns ``(half_life_turns, status)``.

    Statuses:
    - ``"ok"``: a clean exponential fit produced a finite half-life.
    - ``"stable"``: KL stayed near-flat across turns. Half-life is
      formally infinite; we clip to ``max_turns * 10`` so the report
      doesn't print ``inf``. The probe interprets stable as
      "adapter held coherence" — passing.
    - ``"non_monotonic"``: KL grew with turn count (adapter becoming
      *more* distinct as dialogue deepens — physically possible but
      atypical). The half-life concept doesn't apply; we surface the
      curve as evidence and let the user judge. Also used defensively
      when the fit is ill-posed (zero slope or a degenerate turn axis).
    - ``"degenerate"``: fewer than two turns had positive KL (this
      covers the all-zero / all-negative case) — too few points to
      fit. The adapter is producing identical-to-base distributions;
      we report a half-life of 0 to flag a likely no-op adapter.
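
    The fit is ordinary least squares in log space: taking logs gives
    ``log(kl) = log(a) - b * turn``, so the fitted slope is ``-b`` and
    ``half_life = ln(2) / b = -ln(2) / slope``. Sanity check on synthetic
    data (KL halving each turn fits a half-life of exactly one turn):

        >>> t = np.asarray([2.0, 3.0, 4.0, 5.0])
        >>> hl, status = _fit_half_life_turns(t, 0.8 * 0.5 ** (t - 2), max_turns=5)
        >>> status, round(hl, 6)
        ('ok', 1.0)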
233 """
234 # All-zero / non-positive KLs ⇒ no signal at any turn ⇒ probable no-op.
235 if not (kls > 0.0).all():
236 # Mixed positive + zero: still try the fit on the positives if there
237 # are at least 2 of them. All-zero / all-negative ⇒ degenerate.
238 positive_mask = kls > 0.0
239 if positive_mask.sum() < 2:
240 return 0.0, "degenerate"
241 turns = turns[positive_mask]
242 kls = kls[positive_mask]
243
244 # Flat-curve detection runs *before* the fit: if the KLs sit
245 # within a tight relative band of the mean, the slope is ~0 and
246 # the half-life is formally infinite. Pre-detecting this avoids
247 # an awkward "slope == 0 vs slope < epsilon" boundary inside the
248 # regression branch.
249 kl_mean = float(kls.mean())
250 if kl_mean > 0.0:
251 relative_spread = float((kls.max() - kls.min()) / kl_mean)
252 if relative_spread < 1e-3:
253 return float(max_turns) * 10.0, "stable"
254
255 log_y = np.log(kls)
256 x_mean = float(turns.mean())
257 y_mean = float(log_y.mean())
258 denom = float(((turns - x_mean) ** 2).sum())
259 # All turns identical (impossible for our turn-axis but defensive).
260 if denom == 0.0:
261 return None, "non_monotonic"
262 slope = float(((turns - x_mean) * (log_y - y_mean)).sum()) / denom
263 if slope > 0.0:
264 # KL grew with turn — adapter became more distinct. Could
265 # mean genuine multi-turn personality emergence; leave half-
266 # life undefined and flag in the status.
267 return None, "non_monotonic"
268
269 half_life = float(math.log(2.0) * (-1.0 / slope))
270 # B11 — a near-stable adapter has a tiny negative slope; the fit
271 # produces an enormous half-life. Clip + label so the report
272 # doesn't print ``inf`` or implausible values.
273 saturation = float(max_turns) * 10.0
274 if half_life >= saturation or not math.isfinite(half_life):
275 return saturation, "stable"
276 return half_life, "ok"


def _verdict_from_half_life(
    *,
    half_life: float | None,
    fit_status: str,
    target: float,
    mean_kls: list[float],
    turns_axis: list[float],
) -> tuple[Verdict, float, str]:
    """Translate the curve-fit result into a (verdict, score, message)."""
    sparkline = _ascii_sparkline(mean_kls)

    if fit_status == "non_monotonic":
        # KL grew with turn — atypical but not necessarily wrong.
        # Surface as WARN with the curve; user judgment from there.
        return (
            Verdict.WARN,
            0.5,
            f"non-monotonic KL across turns (adapter grew more distinct); curve {sparkline}",
        )

    if fit_status == "degenerate":
        return (
            Verdict.FAIL,
            0.0,
            f"KL is zero across all turns (probable no-op adapter); curve {sparkline}",
        )

    if half_life is None:
        return (
            Verdict.ERROR,
            0.0,
            f"could not fit decay curve (status={fit_status}); curve {sparkline}",
        )

    if fit_status == "stable":
        return (
            Verdict.PASS,
            1.0,
            f"adapter held coherence across all {int(turns_axis[-1])} turns; curve {sparkline}",
        )

    # fit_status == "ok"
    passed = half_life >= target
    score = float(min(1.0, half_life / max(target, 1e-6)))
    return (
        Verdict.PASS if passed else Verdict.FAIL,
        score,
        f"half-life={half_life:.2f} turns ({'≥' if passed else '<'} {target}); curve {sparkline}",
    )


# ---------------------------------------------------------------------------
# Chat-template / tokenizer helpers
# ---------------------------------------------------------------------------


def _peek_backend_tokenizer(ctx: RunContext) -> Any | None:
    """Same trick :mod:`prompt_collapse` uses: backends store the
    tokenizer at ``_tokenizer``. Returns ``None`` for the dummy
    backend (which has no tokenizer)."""
    return getattr(ctx.backend, "_tokenizer", None)


def _chat_template_of(tokenizer: Any | None) -> str | None:
    """Return the tokenizer's chat_template string, or ``None``.

    HF tokenizers expose a ``chat_template`` attribute that's either
    a Jinja string or ``None``. Some custom tokenizers don't have the
    attribute at all — covered by the ``getattr`` default.
    """
    if tokenizer is None:
        return None
    template = getattr(tokenizer, "chat_template", None)
    if not isinstance(template, str) or not template.strip():
        return None
    return template


def _format_chat(
    messages: list[dict[str, str]],
    tokenizer: Any,
    *,
    add_generation_prompt: bool,
) -> str:
    """Render `messages` via the tokenizer's chat template.

    Caller has already verified the tokenizer carries a chat_template
    (via :func:`_chat_template_of`); a ``None`` tokenizer here is a
    bug, not a runtime fallback path. We still defensively fall
    back to a minimal role-marker concatenation so the dummy
    backend's tests can drive the probe end-to-end without standing
    up a real tokenizer.
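
    The fallback path is observable directly (this doctest exercises the
    ``None``-tokenizer branch only; real tokenizers go through
    ``apply_chat_template``):

        >>> _format_chat([{"role": "user", "content": "hi"}], None,
        ...              add_generation_prompt=False)
        'USER: hi'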
371 """
372 if tokenizer is None:
373 return _fallback_format(messages, add_generation_prompt=add_generation_prompt)
374 try:
375 out = tokenizer.apply_chat_template(
376 messages, tokenize=False, add_generation_prompt=add_generation_prompt
377 )
378 return str(out)
379 except Exception: # noqa: BLE001 — tokenizer impls vary; never let a probe crash on this
380 return _fallback_format(messages, add_generation_prompt=add_generation_prompt)


def _fallback_format(messages: list[dict[str, str]], *, add_generation_prompt: bool) -> str:
    """Minimal role-marker formatter for tokenizers that misbehave + tests."""
    parts: list[str] = []
    for m in messages:
        parts.append(f"{m['role'].upper()}: {m['content']}")
    if add_generation_prompt:
        parts.append("ASSISTANT:")
    return "\n".join(parts)


# ---------------------------------------------------------------------------
# Report sparkline
# ---------------------------------------------------------------------------


_SPARKLINE_BARS = "▁▂▃▄▅▆▇█"


def _ascii_sparkline(values: list[float]) -> str:
    """Compact unicode sparkline of per-turn KL.

    Surfaces in the report message so a terminal-only reader gets a
    sense of curve shape without opening the JSON. Empty/degenerate
    inputs render as an empty string.
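
    A decaying curve renders high-to-low (eight bar levels, min-max
    scaled):

        >>> _ascii_sparkline([3.0, 2.0, 1.0, 0.5])
        '█▅▂▁'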
407 """
408 if not values:
409 return ""
410 finite = [v for v in values if math.isfinite(v) and v >= 0.0]
411 if not finite:
412 return ""
413 lo, hi = min(finite), max(finite)
414 span = hi - lo
415 if span <= 1e-12:
416 # Flat line — every bar at mid-height.
417 return _SPARKLINE_BARS[len(_SPARKLINE_BARS) // 2] * len(values)
418 out: list[str] = []
419 for v in values:
420 if not math.isfinite(v) or v < 0.0:
421 out.append("?")
422 continue
423 idx = int((v - lo) / span * (len(_SPARKLINE_BARS) - 1))
424 out.append(_SPARKLINE_BARS[max(0, min(idx, len(_SPARKLINE_BARS) - 1))])
425 return "".join(out)