| 1 | """Multi-turn coherence decay — does the adapter survive a multi-turn dialogue? |
| 2 | |
| 3 | Every other adherence probe is single-turn: one user message, one |
| 4 | completion, one score. Adapters that pass single-turn probes |
| 5 | frequently "forget their training" by turn 2 or 3 of real dialogue, |
| 6 | where the model's own previous responses enter the context window |
| 7 | and create compounding drift. No other shipped sway probe catches |
| 8 | that failure mode. |
| 9 | |
| 10 | The probe rolls a multi-turn synthetic dialogue per prompt: |
| 11 | |
| 12 | 1. Generate ft's turn-1 response greedily. |
| 13 | 2. Build a turn-2 chat history `[user=prompt, asst=ft_t1, |
| 14 | user=follow_up_2]` and compute `KL(base || ft)` at turn 2. |
| 15 | 3. Extend with ft's turn-2 response, build turn-3 history, score. |
| 16 | 4. Repeat through `max_turns`. |
| 17 | 5. Fit `kl = a · exp(-b · turn)` over turns 2..N; report |
| 18 | `half_life_turns = ln(2) / b`. |
| 19 | |
| 20 | Lower half-life ⇒ adapter influence evaporates faster as dialogue |
| 21 | deepens. A "stable" adapter (near-flat KL across turns) reports a |
| 22 | saturated half-life with a "stable" marker rather than `inf`. |
| 23 | |
| 24 | ## Why no null calibration |
| 25 | |
| 26 | Mirrors :mod:`prompt_collapse`: a null adapter has random-noise |
weights with no real signal to decay. Its turn-1 greedy output is
essentially gibberish; feeding it back through the chat template makes
the per-turn KLs meaningless, and the resulting half-life distribution
| 30 | is undefined. Fixed-threshold verdicts are the published path. |
| 31 | |
| 32 | ## Chat-template requirement |
| 33 | |
| 34 | Multi-turn requires the base's chat template to format turns. The |
| 35 | probe consults the backend's tokenizer for `chat_template`; bases |
| 36 | without one (raw completion models) SKIP gracefully with a clear |
| 37 | reason via :class:`Verdict.SKIP`. |
| 38 | |
| 39 | Tests that exercise the math path inject a minimal fake tokenizer |
| 40 | on the dummy backend — see ``tests/unit/test_probe_multi_turn_coherence.py``. |
| 41 | """ |
| 42 | |
| 43 | from __future__ import annotations |
| 44 | |
| 45 | import math |
| 46 | from typing import Any, Literal |
| 47 | |
| 48 | import numpy as np |
| 49 | from pydantic import Field |
| 50 | |
| 51 | from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize |
| 52 | from dlm_sway.probes._divergence import Divergence, divergence |
| 53 | from dlm_sway.probes.base import Probe, ProbeSpec, RunContext |
| 54 | |
| 55 | |
class MultiTurnCoherenceSpec(ProbeSpec):
    """Spec for ``kind: multi_turn_coherence_decay``.

    Configures the multi-turn dialogue roll: seed prompts, turn count,
    per-turn generation budget, the follow-up cycle, the divergence
    metric, and the half-life pass threshold. Deliberately carries no
    null-calibration fields — see the module docstring.
    """

    kind: Literal["multi_turn_coherence_decay"] = "multi_turn_coherence_decay"
    prompts: list[str] = Field(default_factory=list, min_length=0)
    """Inline turn-1 user messages. Empty list → probe ERRORs (the
    .dlm autogen path doesn't yet seed multi-turn cases — when it
    does, this path will mirror :mod:`delta_kl`'s prompts_from)."""
    max_turns: int = Field(default=4, ge=2, le=8)
    """How many dialogue turns to roll. Minimum 2 (otherwise this is
    just :mod:`delta_kl`); cap 8 to keep the probe fast on real
    backends — each additional turn requires another greedy
    generation + a backend toggle pair."""
    max_new_tokens: int = 96
    """Greedy decode budget for ft's per-turn response. 96 is
    deliberately conservative so a long-winded adapter doesn't blow
    out the context budget on later turns."""
    follow_ups: list[str] = Field(
        default_factory=lambda: [
            "Continue.",
            "Tell me more.",
            "Can you elaborate?",
            "What else?",
            "Go deeper.",
            "Expand on that.",
            "And then?",
        ]
    )
    """Generic per-turn follow-up prompts cycled through to drive the
    dialogue forward. Cycled so any ``max_turns`` works without
    needing per-prompt customization. Future: prefer doc-author-
    written follow-ups when ``ctx.sections`` carries instruction
    blocks with explicit multi-turn structure."""
    divergence: Divergence = "kl"
    """Per-turn divergence metric. ``kl`` is the convention here
    because we want a directional measure (`base || ft`); ``js``
    is symmetric."""
    # Next-token distribution truncation passed to both views' scoring
    # calls. ``None`` defers to the run context's default (``ctx.top_k``
    # in the probe's ``run``).
    top_k: int | None = None
    assert_half_life_turns: float = 2.0
    """Pass criterion: adapter influence persists past turn 2 by
    at least one half-life. Tune upward for adapters that need to
    hold over longer conversations."""

    # No null-calibration fields (no assert_z_gte) — see module docstring.
| 100 | |
| 101 | |
class MultiTurnCoherenceProbe(Probe):
    """The "did the adapter forget its training by turn 3?" probe."""

    kind = "multi_turn_coherence_decay"
    spec_cls = MultiTurnCoherenceSpec
    category = "adherence"

    # As noted in the module docstring: a null adapter has no
    # coherence to decay, so a null distribution of half_life_turns
    # is meaningless. Skip the calibration handshake entirely.

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Roll the multi-turn dialogue per prompt and fit the decay curve.

        Flow: validate prompts (ERROR when empty) → require a chat
        template on the backend tokenizer (SKIP without one) → per
        prompt, seed with ft's greedy turn-1 reply, then for each later
        turn append a cycled follow-up, score the next-token divergence
        under both views, and extend history with ft's reply → average
        KLs per turn across prompts, fit the exponential decay, and map
        the fitted half-life onto a (verdict, score, message).
        """
        assert isinstance(spec, MultiTurnCoherenceSpec)
        if not spec.prompts:
            # No autogen path seeds multi-turn prompts yet, so an empty
            # inline list is a configuration error rather than a SKIP.
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message="no prompts provided (inline 'prompts' was empty)",
            )

        tokenizer = _peek_backend_tokenizer(ctx)
        chat_template = _chat_template_of(tokenizer)
        if chat_template is None:
            # Raw completion models can't format turns — skip, don't fail.
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message="base has no chat_template; multi-turn dialogue requires one",
            )

        # Spec-level top_k wins; otherwise fall back to the run context's default.
        top_k = spec.top_k if spec.top_k is not None else ctx.top_k
        # Per-turn KLs aggregated across prompts. ``per_turn_kls[t]``
        # is the mean KL at turn t+2 (turn 1 is the seed; turns 2..N
        # are the scored positions).
        per_turn_kls: list[list[float]] = [[] for _ in range(spec.max_turns - 1)]

        for prompt in spec.prompts:
            # Turn 1: ft generates the seed assistant response that
            # populates the dialogue history. Generated under ft only —
            # base never sees its own turn-1; both views share ft's
            # history, which is the load-bearing design choice.
            messages: list[dict[str, str]] = [{"role": "user", "content": prompt}]
            with ctx.require_backend.as_finetuned() as fv:
                ft_t1 = fv.generate(
                    _format_chat(messages, tokenizer, add_generation_prompt=True),
                    max_new_tokens=spec.max_new_tokens,
                )
            messages.append({"role": "assistant", "content": ft_t1})

            for turn_idx in range(spec.max_turns - 1):
                # Append the per-turn user follow-up (cycled).
                follow_up = spec.follow_ups[turn_idx % len(spec.follow_ups)]
                messages.append({"role": "user", "content": follow_up})
                # Score next-token dist at the current turn under both views.
                chat_str = _format_chat(messages, tokenizer, add_generation_prompt=True)
                with ctx.require_backend.as_base() as bv:
                    base_dist = bv.next_token_dist(chat_str, top_k=top_k)
                with ctx.require_backend.as_finetuned() as fv:
                    ft_dist = fv.next_token_dist(chat_str, top_k=top_k)
                per_turn_kls[turn_idx].append(divergence(base_dist, ft_dist, kind=spec.divergence))
                # Extend history with ft's response for the next iteration.
                # Skip on the final turn — saves one generation call.
                if turn_idx < spec.max_turns - 2:
                    with ctx.require_backend.as_finetuned() as fv:
                        ft_response = fv.generate(chat_str, max_new_tokens=spec.max_new_tokens)
                    messages.append({"role": "assistant", "content": ft_response})

        # Mean KL per turn (across prompts). Each inner list has one
        # entry per prompt, so means are well-defined here.
        mean_kls: list[float] = [float(np.mean(turn_kls)) for turn_kls in per_turn_kls]
        # Turn axis: 2, 3, 4, ..., max_turns. Turn 1 isn't scored
        # (it's the seed); turns 2..N are the curve points.
        turns_axis = np.asarray(list(range(2, spec.max_turns + 1)), dtype=np.float64)
        kls_axis = np.asarray(mean_kls, dtype=np.float64)

        half_life, fit_status = _fit_half_life_turns(turns_axis, kls_axis, max_turns=spec.max_turns)

        verdict, score, message = _verdict_from_half_life(
            half_life=half_life,
            fit_status=fit_status,
            target=spec.assert_half_life_turns,
            mean_kls=mean_kls,
            turns_axis=turns_axis.tolist(),
        )
        # ``raw`` carries the fitted half-life only when it is a real
        # finite number; "stable"/"non_monotonic" outcomes are readable
        # from the evidence payload instead.
        return safe_finalize(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=half_life if half_life is not None and math.isfinite(half_life) else None,
            evidence={
                "per_turn_kls": mean_kls,
                "turns_axis": turns_axis.tolist(),
                "fit_status": fit_status,
                "divergence_kind": spec.divergence,
                "max_turns": spec.max_turns,
                "num_prompts": len(spec.prompts),
                "weight": spec.weight,
                "sparkline": _ascii_sparkline(mean_kls),
            },
            message=message,
        )
| 206 | |
| 207 | |
| 208 | # --------------------------------------------------------------------------- |
| 209 | # Curve fit + verdict logic |
| 210 | # --------------------------------------------------------------------------- |
| 211 | |
| 212 | |
| 213 | def _fit_half_life_turns( |
| 214 | turns: np.ndarray, kls: np.ndarray, *, max_turns: int |
| 215 | ) -> tuple[float | None, str]: |
| 216 | """Fit ``kl = a * exp(-b * turn)`` via log-space linear regression. |
| 217 | |
| 218 | Returns ``(half_life_turns, status)``. |
| 219 | |
| 220 | Statuses: |
| 221 | - ``"ok"``: a clean exponential fit produced a finite half-life. |
| 222 | - ``"stable"``: KL stayed near-flat across turns. Half-life is |
| 223 | formally infinite; we clip to ``max_turns * 10`` so the report |
| 224 | doesn't print ``inf``. The probe interprets stable as |
| 225 | "adapter held coherence" — passing. |
| 226 | - ``"non_monotonic"``: KL grew with turn count (adapter becoming |
| 227 | *more* distinct as dialogue deepens — physically possible but |
| 228 | atypical). The half-life concept doesn't apply; we surface the |
| 229 | curve as evidence and let the user judge. |
| 230 | - ``"degenerate"``: KL was zero / negative at every turn. The |
| 231 | adapter is producing identical-to-base distributions; we |
| 232 | report a half-life of 0 to flag a likely no-op adapter. |
| 233 | """ |
| 234 | # All-zero / non-positive KLs ⇒ no signal at any turn ⇒ probable no-op. |
| 235 | if not (kls > 0.0).all(): |
| 236 | # Mixed positive + zero: still try the fit on the positives if there |
| 237 | # are at least 2 of them. All-zero / all-negative ⇒ degenerate. |
| 238 | positive_mask = kls > 0.0 |
| 239 | if positive_mask.sum() < 2: |
| 240 | return 0.0, "degenerate" |
| 241 | turns = turns[positive_mask] |
| 242 | kls = kls[positive_mask] |
| 243 | |
| 244 | # Flat-curve detection runs *before* the fit: if the KLs sit |
| 245 | # within a tight relative band of the mean, the slope is ~0 and |
| 246 | # the half-life is formally infinite. Pre-detecting this avoids |
| 247 | # an awkward "slope == 0 vs slope < epsilon" boundary inside the |
| 248 | # regression branch. |
| 249 | kl_mean = float(kls.mean()) |
| 250 | if kl_mean > 0.0: |
| 251 | relative_spread = float((kls.max() - kls.min()) / kl_mean) |
| 252 | if relative_spread < 1e-3: |
| 253 | return float(max_turns) * 10.0, "stable" |
| 254 | |
| 255 | log_y = np.log(kls) |
| 256 | x_mean = float(turns.mean()) |
| 257 | y_mean = float(log_y.mean()) |
| 258 | denom = float(((turns - x_mean) ** 2).sum()) |
| 259 | # All turns identical (impossible for our turn-axis but defensive). |
| 260 | if denom == 0.0: |
| 261 | return None, "non_monotonic" |
| 262 | slope = float(((turns - x_mean) * (log_y - y_mean)).sum()) / denom |
| 263 | if slope > 0.0: |
| 264 | # KL grew with turn — adapter became more distinct. Could |
| 265 | # mean genuine multi-turn personality emergence; leave half- |
| 266 | # life undefined and flag in the status. |
| 267 | return None, "non_monotonic" |
| 268 | |
| 269 | half_life = float(math.log(2.0) * (-1.0 / slope)) |
| 270 | # B11 — a near-stable adapter has a tiny negative slope; the fit |
| 271 | # produces an enormous half-life. Clip + label so the report |
| 272 | # doesn't print ``inf`` or implausible values. |
| 273 | saturation = float(max_turns) * 10.0 |
| 274 | if half_life >= saturation or not math.isfinite(half_life): |
| 275 | return saturation, "stable" |
| 276 | return half_life, "ok" |
| 277 | |
| 278 | |
def _verdict_from_half_life(
    *,
    half_life: float | None,
    fit_status: str,
    target: float,
    mean_kls: list[float],
    turns_axis: list[float],
) -> tuple[Verdict, float, str]:
    """Map the curve-fit outcome onto the probe's (verdict, score, message)."""
    curve = _ascii_sparkline(mean_kls)

    if fit_status == "non_monotonic":
        # KL grew with turn — atypical but not necessarily wrong.
        # Surface as WARN with the curve; user judgment from there.
        msg = f"non-monotonic KL across turns (adapter grew more distinct); curve {curve}"
        return Verdict.WARN, 0.5, msg

    if fit_status == "degenerate":
        msg = f"KL is zero across all turns (probable no-op adapter); curve {curve}"
        return Verdict.FAIL, 0.0, msg

    if half_life is None:
        msg = f"could not fit decay curve (status={fit_status}); curve {curve}"
        return Verdict.ERROR, 0.0, msg

    if fit_status == "stable":
        final_turn = int(turns_axis[-1])
        msg = f"adapter held coherence across all {final_turn} turns; curve {curve}"
        return Verdict.PASS, 1.0, msg

    # fit_status == "ok": compare the fitted half-life against the floor.
    met = half_life >= target
    clipped = float(min(1.0, half_life / max(target, 1e-6)))
    symbol = "≥" if met else "<"
    msg = f"half-life={half_life:.2f} turns ({symbol} {target}); curve {curve}"
    return (Verdict.PASS if met else Verdict.FAIL), clipped, msg
| 328 | |
| 329 | |
| 330 | # --------------------------------------------------------------------------- |
| 331 | # Chat-template / tokenizer helpers |
| 332 | # --------------------------------------------------------------------------- |
| 333 | |
| 334 | |
| 335 | def _peek_backend_tokenizer(ctx: RunContext) -> Any | None: |
| 336 | """Same trick :mod:`prompt_collapse` uses: backends store the |
| 337 | tokenizer at ``_tokenizer``. Returns ``None`` for the dummy |
| 338 | backend (which has no tokenizer).""" |
| 339 | return getattr(ctx.backend, "_tokenizer", None) |
| 340 | |
| 341 | |
| 342 | def _chat_template_of(tokenizer: Any | None) -> str | None: |
| 343 | """Return the tokenizer's chat_template string, or ``None``. |
| 344 | |
| 345 | HF tokenizers expose a ``chat_template`` attribute that's either |
| 346 | a Jinja string or ``None``. Some custom tokenizers don't have the |
| 347 | attribute at all — covered by the ``getattr`` default. |
| 348 | """ |
| 349 | if tokenizer is None: |
| 350 | return None |
| 351 | template = getattr(tokenizer, "chat_template", None) |
| 352 | if template is None or not isinstance(template, str) or not template.strip(): |
| 353 | return None |
| 354 | return str(template) |
| 355 | |
| 356 | |
| 357 | def _format_chat( |
| 358 | messages: list[dict[str, str]], |
| 359 | tokenizer: Any, |
| 360 | *, |
| 361 | add_generation_prompt: bool, |
| 362 | ) -> str: |
| 363 | """Render `messages` via the tokenizer's chat template. |
| 364 | |
| 365 | Caller has already verified the tokenizer carries a chat_template |
| 366 | (via :func:`_chat_template_of`); a ``None`` tokenizer here is a |
| 367 | bug, not a runtime fallback path. We still defensively fall |
| 368 | back to a minimal role-marker concatenation so the dummy |
| 369 | backend's tests can drive the probe end-to-end without standing |
| 370 | up a real tokenizer. |
| 371 | """ |
| 372 | if tokenizer is None: |
| 373 | return _fallback_format(messages, add_generation_prompt=add_generation_prompt) |
| 374 | try: |
| 375 | out = tokenizer.apply_chat_template( |
| 376 | messages, tokenize=False, add_generation_prompt=add_generation_prompt |
| 377 | ) |
| 378 | return str(out) |
| 379 | except Exception: # noqa: BLE001 — tokenizer impls vary; never let a probe crash on this |
| 380 | return _fallback_format(messages, add_generation_prompt=add_generation_prompt) |
| 381 | |
| 382 | |
| 383 | def _fallback_format(messages: list[dict[str, str]], *, add_generation_prompt: bool) -> str: |
| 384 | """Minimal role-marker formatter for tokenizers that misbehave + tests.""" |
| 385 | parts: list[str] = [] |
| 386 | for m in messages: |
| 387 | parts.append(f"{m['role'].upper()}: {m['content']}") |
| 388 | if add_generation_prompt: |
| 389 | parts.append("ASSISTANT:") |
| 390 | return "\n".join(parts) |
| 391 | |
| 392 | |
| 393 | # --------------------------------------------------------------------------- |
| 394 | # Report sparkline |
| 395 | # --------------------------------------------------------------------------- |
| 396 | |
| 397 | |
| 398 | _SPARKLINE_BARS = "▁▂▃▄▅▆▇█" |
| 399 | |
| 400 | |
| 401 | def _ascii_sparkline(values: list[float]) -> str: |
| 402 | """Compact unicode sparkline of per-turn KL. |
| 403 | |
| 404 | Surfaces in the report message so a terminal-only reader gets a |
| 405 | sense of curve shape without opening the JSON. Empty/degenerate |
| 406 | inputs render as an empty string. |
| 407 | """ |
| 408 | if not values: |
| 409 | return "" |
| 410 | finite = [v for v in values if math.isfinite(v) and v >= 0.0] |
| 411 | if not finite: |
| 412 | return "" |
| 413 | lo, hi = min(finite), max(finite) |
| 414 | span = hi - lo |
| 415 | if span <= 1e-12: |
| 416 | # Flat line — every bar at mid-height. |
| 417 | return _SPARKLINE_BARS[len(_SPARKLINE_BARS) // 2] * len(values) |
| 418 | out: list[str] = [] |
| 419 | for v in values: |
| 420 | if not math.isfinite(v) or v < 0.0: |
| 421 | out.append("?") |
| 422 | continue |
| 423 | idx = int((v - lo) / span * (len(_SPARKLINE_BARS) - 1)) |
| 424 | out.append(_SPARKLINE_BARS[max(0, min(idx, len(_SPARKLINE_BARS) - 1))]) |
| 425 | return "".join(out) |