@@ -0,0 +1,425 @@ |
| 1 | +"""Multi-turn coherence decay — does the adapter survive a multi-turn dialogue? |
| 2 | + |
| 3 | +Every other adherence probe is single-turn: one user message, one |
| 4 | +completion, one score. Adapters that pass single-turn probes |
| 5 | +frequently "forget their training" by turn 2 or 3 of real dialogue, |
| 6 | +where the model's own previous responses enter the context window |
| 7 | +and create compounding drift. No other shipped sway probe catches |
| 8 | +that failure mode. |
| 9 | + |
| 10 | +The probe rolls a multi-turn synthetic dialogue per prompt: |
| 11 | + |
| 12 | +1. Generate ft's turn-1 response greedily. |
| 13 | +2. Build a turn-2 chat history `[user=prompt, asst=ft_t1, |
| 14 | + user=follow_up_2]` and compute `KL(base || ft)` at turn 2. |
| 15 | +3. Extend with ft's turn-2 response, build turn-3 history, score. |
| 16 | +4. Repeat through `max_turns`. |
| 17 | +5. Fit `kl = a · exp(-b · turn)` over turns 2..N; report |
| 18 | + `half_life_turns = ln(2) / b`. |
| 19 | + |
| 20 | +Lower half-life ⇒ adapter influence evaporates faster as dialogue |
| 21 | +deepens. A "stable" adapter (near-flat KL across turns) reports a |
| 22 | +saturated half-life with a "stable" marker rather than `inf`. |
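|  | +
|  | +For intuition (illustrative numbers, not from a real run): mean KLs of
|  | +0.40, 0.20, 0.10 at turns 2, 3, 4 halve once per turn, so the fit gives
|  | +b = ln(2) and half_life_turns = ln(2) / b = 1.0: adapter influence
|  | +halves every turn, failing the default 2.0-turn criterion.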
| 23 | + |
| 24 | +## Why no null calibration |
| 25 | + |
| 26 | +Mirrors :mod:`prompt_collapse`: a null adapter has random-noise |
| 27 | +weights with no real signal to decay. Its turn-1 greedy output is |
| 28 | +essentially gibberish; feeding it back through the chat template
| 29 | +makes the per-turn KLs meaningless, and the resulting half-life
| 30 | +distribution is undefined. Fixed-threshold verdicts are the published path.
| 31 | + |
| 32 | +## Chat-template requirement |
| 33 | + |
| 34 | +Multi-turn requires the base's chat template to format turns. The |
| 35 | +probe consults the backend's tokenizer for `chat_template`; bases |
| 36 | +without one (raw completion models) SKIP gracefully with a clear |
| 37 | +reason via :attr:`Verdict.SKIP`.
| 38 | + |
| 39 | +Tests that exercise the math path inject a minimal fake tokenizer |
| 40 | +on the dummy backend — see ``tests/unit/test_probe_multi_turn_coherence.py``. |
| 41 | +""" |
| 42 | + |
| 43 | +from __future__ import annotations |
| 44 | + |
| 45 | +import math |
| 46 | +from typing import Any, Literal |
| 47 | + |
| 48 | +import numpy as np |
| 49 | +from pydantic import Field |
| 50 | + |
| 51 | +from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize |
| 52 | +from dlm_sway.probes._divergence import Divergence, divergence |
| 53 | +from dlm_sway.probes.base import Probe, ProbeSpec, RunContext |
| 54 | + |
| 55 | + |
| 56 | +class MultiTurnCoherenceSpec(ProbeSpec): |
| 57 | + """Spec for ``kind: multi_turn_coherence_decay``.""" |
| 58 | + |
| 59 | + kind: Literal["multi_turn_coherence_decay"] = "multi_turn_coherence_decay" |
| 60 | +    prompts: list[str] = Field(default_factory=list)
| 61 | + """Inline turn-1 user messages. Empty list → probe ERRORs (the |
| 62 | + .dlm autogen path doesn't yet seed multi-turn cases — when it |
| 63 | + does, this path will mirror :mod:`delta_kl`'s prompts_from).""" |
| 64 | + max_turns: int = Field(default=4, ge=2, le=8) |
| 65 | + """How many dialogue turns to roll. Minimum 2 (otherwise this is |
| 66 | + just :mod:`delta_kl`); cap 8 to keep the probe fast on real |
| 67 | + backends — each additional turn requires another greedy |
| 68 | + generation + a backend toggle pair.""" |
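|  | +    # Cost sketch (counting the calls in ``run`` below): per prompt the
|  | +    # probe makes max_turns - 1 greedy generations plus 2 · (max_turns - 1)
|  | +    # next-token-dist calls; at the default max_turns=4 that is 9 calls.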
| 69 | + max_new_tokens: int = 96 |
| 70 | + """Greedy decode budget for ft's per-turn response. 96 is |
| 71 | + deliberately conservative so a long-winded adapter doesn't blow |
| 72 | + out the context budget on later turns.""" |
| 73 | + follow_ups: list[str] = Field( |
| 74 | + default_factory=lambda: [ |
| 75 | + "Continue.", |
| 76 | + "Tell me more.", |
| 77 | + "Can you elaborate?", |
| 78 | + "What else?", |
| 79 | + "Go deeper.", |
| 80 | + "Expand on that.", |
| 81 | + "And then?", |
| 82 | + ] |
| 83 | + ) |
| 84 | + """Generic per-turn follow-up prompts cycled through to drive the |
| 85 | + dialogue forward. Cycled so any ``max_turns`` works without |
| 86 | + needing per-prompt customization. Future: prefer doc-author- |
| 87 | + written follow-ups when ``ctx.sections`` carries instruction |
| 88 | + blocks with explicit multi-turn structure.""" |
| 89 | + divergence: Divergence = "kl" |
| 90 | + """Per-turn divergence metric. ``kl`` is the convention here |
| 91 | + because we want a directional measure (`base || ft`); ``js`` |
| 92 | + is symmetric.""" |
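|  | +    # Recall KL(base || ft) = Σ_v p_base(v) · log(p_base(v) / p_ft(v))
|  | +    # over the next-token distribution: zero when the adapter leaves the
|  | +    # (possibly top-k-truncated) distribution untouched.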
| 93 | + top_k: int | None = None |
| 94 | + assert_half_life_turns: float = 2.0 |
| 95 | + """Pass criterion: adapter influence persists past turn 2 by |
| 96 | + at least one half-life. Tune upward for adapters that need to |
| 97 | + hold over longer conversations.""" |
| 98 | + |
| 99 | + # No null-calibration fields (no assert_z_gte) — see module docstring. |
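|  | +
|  | +    # A minimal sway-file sketch for this spec (illustrative: the list
|  | +    # nesting and prompt text are assumptions; the keys are this spec's
|  | +    # real fields):
|  | +    #
|  | +    #   - kind: multi_turn_coherence_decay
|  | +    #     name: coherence_decay
|  | +    #     prompts:
|  | +    #       - "Walk me through your approach."
|  | +    #     max_turns: 4
|  | +    #     assert_half_life_turns: 2.0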
| 100 | + |
| 101 | + |
| 102 | +class MultiTurnCoherenceProbe(Probe): |
| 103 | + """The "did the adapter forget its training by turn 3?" probe.""" |
| 104 | + |
| 105 | + kind = "multi_turn_coherence_decay" |
| 106 | + spec_cls = MultiTurnCoherenceSpec |
| 107 | + category = "adherence" |
| 108 | + |
| 109 | + # As noted in the module docstring: a null adapter has no |
| 110 | + # coherence to decay, so a null distribution of half_life_turns |
| 111 | + # is meaningless. Skip the calibration handshake entirely. |
| 112 | + |
| 113 | + def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult: |
| 114 | + assert isinstance(spec, MultiTurnCoherenceSpec) |
| 115 | + if not spec.prompts: |
| 116 | + return ProbeResult( |
| 117 | + name=spec.name, |
| 118 | + kind=spec.kind, |
| 119 | + verdict=Verdict.ERROR, |
| 120 | + score=None, |
| 121 | + message="no prompts provided (inline 'prompts' was empty)", |
| 122 | + ) |
| 123 | + |
| 124 | + tokenizer = _peek_backend_tokenizer(ctx) |
| 125 | + chat_template = _chat_template_of(tokenizer) |
| 126 | + if chat_template is None: |
| 127 | + return ProbeResult( |
| 128 | + name=spec.name, |
| 129 | + kind=spec.kind, |
| 130 | + verdict=Verdict.SKIP, |
| 131 | + score=None, |
| 132 | + message="base has no chat_template; multi-turn dialogue requires one", |
| 133 | + ) |
| 134 | + |
| 135 | + top_k = spec.top_k if spec.top_k is not None else ctx.top_k |
| 136 | +        # Per-turn KLs across prompts. ``per_turn_kls[t]`` collects one
| 137 | +        # KL per prompt at turn t+2 (turn 1 is the seed; turns 2..N are
| 138 | +        # the scored positions); means are taken after the loop.
| 139 | + per_turn_kls: list[list[float]] = [[] for _ in range(spec.max_turns - 1)] |
| 140 | + |
| 141 | + for prompt in spec.prompts: |
| 142 | + # Turn 1: ft generates the seed assistant response that |
| 143 | + # populates the dialogue history. Generated under ft only — |
| 144 | +            # base never generates a turn-1 of its own; both views score
| 145 | +            # against ft's history, which is the load-bearing design choice.
| 146 | + messages: list[dict[str, str]] = [{"role": "user", "content": prompt}] |
| 147 | + with ctx.require_backend.as_finetuned() as fv: |
| 148 | + ft_t1 = fv.generate( |
| 149 | + _format_chat(messages, tokenizer, add_generation_prompt=True), |
| 150 | + max_new_tokens=spec.max_new_tokens, |
| 151 | + ) |
| 152 | + messages.append({"role": "assistant", "content": ft_t1}) |
| 153 | + |
| 154 | + for turn_idx in range(spec.max_turns - 1): |
| 155 | + # Append the per-turn user follow-up (cycled). |
| 156 | + follow_up = spec.follow_ups[turn_idx % len(spec.follow_ups)] |
| 157 | + messages.append({"role": "user", "content": follow_up}) |
| 158 | + # Score next-token dist at the current turn under both views. |
| 159 | + chat_str = _format_chat(messages, tokenizer, add_generation_prompt=True) |
| 160 | + with ctx.require_backend.as_base() as bv: |
| 161 | + base_dist = bv.next_token_dist(chat_str, top_k=top_k) |
| 162 | + with ctx.require_backend.as_finetuned() as fv: |
| 163 | + ft_dist = fv.next_token_dist(chat_str, top_k=top_k) |
| 164 | + per_turn_kls[turn_idx].append(divergence(base_dist, ft_dist, kind=spec.divergence)) |
| 165 | + # Extend history with ft's response for the next iteration. |
| 166 | + # Skip on the final turn — saves one generation call. |
| 167 | + if turn_idx < spec.max_turns - 2: |
| 168 | + with ctx.require_backend.as_finetuned() as fv: |
| 169 | + ft_response = fv.generate(chat_str, max_new_tokens=spec.max_new_tokens) |
| 170 | + messages.append({"role": "assistant", "content": ft_response}) |
| 171 | + |
| 172 | + # Mean KL per turn (across prompts). |
| 173 | + mean_kls: list[float] = [float(np.mean(turn_kls)) for turn_kls in per_turn_kls] |
| 174 | + # Turn axis: 2, 3, 4, ..., max_turns. Turn 1 isn't scored |
| 175 | + # (it's the seed); turns 2..N are the curve points. |
| 176 | + turns_axis = np.asarray(list(range(2, spec.max_turns + 1)), dtype=np.float64) |
| 177 | + kls_axis = np.asarray(mean_kls, dtype=np.float64) |
| 178 | + |
| 179 | + half_life, fit_status = _fit_half_life_turns(turns_axis, kls_axis, max_turns=spec.max_turns) |
| 180 | + |
| 181 | + verdict, score, message = _verdict_from_half_life( |
| 182 | + half_life=half_life, |
| 183 | + fit_status=fit_status, |
| 184 | + target=spec.assert_half_life_turns, |
| 185 | + mean_kls=mean_kls, |
| 186 | + turns_axis=turns_axis.tolist(), |
| 187 | + ) |
| 188 | + return safe_finalize( |
| 189 | + name=spec.name, |
| 190 | + kind=spec.kind, |
| 191 | + verdict=verdict, |
| 192 | + score=score, |
| 193 | + raw=half_life if half_life is not None and math.isfinite(half_life) else None, |
| 194 | + evidence={ |
| 195 | + "per_turn_kls": mean_kls, |
| 196 | + "turns_axis": turns_axis.tolist(), |
| 197 | + "fit_status": fit_status, |
| 198 | + "divergence_kind": spec.divergence, |
| 199 | + "max_turns": spec.max_turns, |
| 200 | + "num_prompts": len(spec.prompts), |
| 201 | + "weight": spec.weight, |
| 202 | + "sparkline": _ascii_sparkline(mean_kls), |
| 203 | + }, |
| 204 | + message=message, |
| 205 | + ) |
| 206 | + |
| 207 | + |
| 208 | +# --------------------------------------------------------------------------- |
| 209 | +# Curve fit + verdict logic |
| 210 | +# --------------------------------------------------------------------------- |
| 211 | + |
| 212 | + |
| 213 | +def _fit_half_life_turns( |
| 214 | + turns: np.ndarray, kls: np.ndarray, *, max_turns: int |
| 215 | +) -> tuple[float | None, str]: |
| 216 | + """Fit ``kl = a * exp(-b * turn)`` via log-space linear regression. |
| 217 | + |
| 218 | + Returns ``(half_life_turns, status)``. |
| 219 | + |
| 220 | + Statuses: |
| 221 | + - ``"ok"``: a clean exponential fit produced a finite half-life. |
| 222 | + - ``"stable"``: KL stayed near-flat across turns. Half-life is |
| 223 | + formally infinite; we clip to ``max_turns * 10`` so the report |
| 224 | + doesn't print ``inf``. The probe interprets stable as |
| 225 | + "adapter held coherence" — passing. |
| 226 | + - ``"non_monotonic"``: KL grew with turn count (adapter becoming |
| 227 | + *more* distinct as dialogue deepens — physically possible but |
| 228 | + atypical). The half-life concept doesn't apply; we surface the |
| 229 | + curve as evidence and let the user judge. |
| 230 | + - ``"degenerate"``: KL was zero / negative at every turn. The |
| 231 | + adapter is producing identical-to-base distributions; we |
| 232 | + report a half-life of 0 to flag a likely no-op adapter. |
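|  | +
|  | +    Worked example (illustrative numbers): kls = [0.9, 0.6, 0.4] at
|  | +    turns = [2, 3, 4] has log-space slope (ln 0.4 - ln 0.9) / 2 ≈ -0.41,
|  | +    so half_life = ln(2) / 0.41 ≈ 1.7 turns with status "ok".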
| 233 | + """ |
| 234 | + # All-zero / non-positive KLs ⇒ no signal at any turn ⇒ probable no-op. |
| 235 | + if not (kls > 0.0).all(): |
| 236 | + # Mixed positive + zero: still try the fit on the positives if there |
| 237 | + # are at least 2 of them. All-zero / all-negative ⇒ degenerate. |
| 238 | + positive_mask = kls > 0.0 |
| 239 | + if positive_mask.sum() < 2: |
| 240 | + return 0.0, "degenerate" |
| 241 | + turns = turns[positive_mask] |
| 242 | + kls = kls[positive_mask] |
| 243 | + |
| 244 | + # Flat-curve detection runs *before* the fit: if the KLs sit |
| 245 | + # within a tight relative band of the mean, the slope is ~0 and |
| 246 | + # the half-life is formally infinite. Pre-detecting this avoids |
| 247 | + # an awkward "slope == 0 vs slope < epsilon" boundary inside the |
| 248 | + # regression branch. |
| 249 | + kl_mean = float(kls.mean()) |
| 250 | + if kl_mean > 0.0: |
| 251 | + relative_spread = float((kls.max() - kls.min()) / kl_mean) |
| 252 | + if relative_spread < 1e-3: |
| 253 | + return float(max_turns) * 10.0, "stable" |
| 254 | + |
| 255 | + log_y = np.log(kls) |
| 256 | + x_mean = float(turns.mean()) |
| 257 | + y_mean = float(log_y.mean()) |
| 258 | + denom = float(((turns - x_mean) ** 2).sum()) |
| 259 | + # All turns identical (impossible for our turn-axis but defensive). |
| 260 | + if denom == 0.0: |
| 261 | + return None, "non_monotonic" |
| 262 | + slope = float(((turns - x_mean) * (log_y - y_mean)).sum()) / denom |
| 263 | +    if slope >= 0.0:
| 264 | +        # KL grew with turn (or had zero net slope): the adapter became
| 265 | +        # more distinct, possibly genuine personality emergence. Leave
| 266 | +        # half-life undefined; >= also guards the zero-slope division.
| 267 | +        return None, "non_monotonic"
| 268 | + |
| 269 | + half_life = float(math.log(2.0) * (-1.0 / slope)) |
| 270 | + # B11 — a near-stable adapter has a tiny negative slope; the fit |
| 271 | + # produces an enormous half-life. Clip + label so the report |
| 272 | + # doesn't print ``inf`` or implausible values. |
| 273 | + saturation = float(max_turns) * 10.0 |
| 274 | + if half_life >= saturation or not math.isfinite(half_life): |
| 275 | + return saturation, "stable" |
| 276 | + return half_life, "ok" |
| 277 | + |
| 278 | + |
| 279 | +def _verdict_from_half_life( |
| 280 | + *, |
| 281 | + half_life: float | None, |
| 282 | + fit_status: str, |
| 283 | + target: float, |
| 284 | + mean_kls: list[float], |
| 285 | + turns_axis: list[float], |
| 286 | +) -> tuple[Verdict, float, str]: |
| 287 | + """Translate the curve-fit result into a (verdict, score, message).""" |
| 288 | + sparkline = _ascii_sparkline(mean_kls) |
| 289 | + |
| 290 | + if fit_status == "non_monotonic": |
| 291 | + # KL grew with turn — atypical but not necessarily wrong. |
| 292 | + # Surface as WARN with the curve; user judgment from there. |
| 293 | + return ( |
| 294 | + Verdict.WARN, |
| 295 | + 0.5, |
| 296 | + f"non-monotonic KL across turns (adapter grew more distinct); curve {sparkline}", |
| 297 | + ) |
| 298 | + |
| 299 | + if fit_status == "degenerate": |
| 300 | + return ( |
| 301 | + Verdict.FAIL, |
| 302 | + 0.0, |
| 303 | + f"KL is zero across all turns (probable no-op adapter); curve {sparkline}", |
| 304 | + ) |
| 305 | + |
| 306 | + if half_life is None: |
| 307 | + return ( |
| 308 | + Verdict.ERROR, |
| 309 | + 0.0, |
| 310 | + f"could not fit decay curve (status={fit_status}); curve {sparkline}", |
| 311 | + ) |
| 312 | + |
| 313 | + if fit_status == "stable": |
| 314 | + return ( |
| 315 | + Verdict.PASS, |
| 316 | + 1.0, |
| 317 | + f"adapter held coherence across all {int(turns_axis[-1])} turns; curve {sparkline}", |
| 318 | + ) |
| 319 | + |
| 320 | + # fit_status == "ok" |
| 321 | + passed = half_life >= target |
| 322 | + score = float(min(1.0, half_life / max(target, 1e-6))) |
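|  | +    # e.g. half_life=1.7 against the default target 2.0 → FAIL with
|  | +    # score 0.85; anything at or above target saturates at 1.0.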
| 323 | + return ( |
| 324 | + Verdict.PASS if passed else Verdict.FAIL, |
| 325 | + score, |
| 326 | + f"half-life={half_life:.2f} turns ({'≥' if passed else '<'} {target}); curve {sparkline}", |
| 327 | + ) |
| 328 | + |
| 329 | + |
| 330 | +# --------------------------------------------------------------------------- |
| 331 | +# Chat-template / tokenizer helpers |
| 332 | +# --------------------------------------------------------------------------- |
| 333 | + |
| 334 | + |
| 335 | +def _peek_backend_tokenizer(ctx: RunContext) -> Any | None: |
| 336 | + """Same trick :mod:`prompt_collapse` uses: backends store the |
| 337 | + tokenizer at ``_tokenizer``. Returns ``None`` for the dummy |
| 338 | + backend (which has no tokenizer).""" |
| 339 | + return getattr(ctx.backend, "_tokenizer", None) |
| 340 | + |
| 341 | + |
| 342 | +def _chat_template_of(tokenizer: Any | None) -> str | None: |
| 343 | + """Return the tokenizer's chat_template string, or ``None``. |
| 344 | + |
| 345 | + HF tokenizers expose a ``chat_template`` attribute that's either |
| 346 | + a Jinja string or ``None``. Some custom tokenizers don't have the |
| 347 | + attribute at all — covered by the ``getattr`` default. |
| 348 | + """ |
| 349 | + if tokenizer is None: |
| 350 | + return None |
| 351 | + template = getattr(tokenizer, "chat_template", None) |
| 352 | +    if not isinstance(template, str) or not template.strip():
| 353 | +        return None
| 354 | +    return template
| 355 | + |
| 356 | + |
| 357 | +def _format_chat( |
| 358 | + messages: list[dict[str, str]], |
| 359 | + tokenizer: Any, |
| 360 | + *, |
| 361 | + add_generation_prompt: bool, |
| 362 | +) -> str: |
| 363 | + """Render `messages` via the tokenizer's chat template. |
| 364 | + |
| 365 | + Caller has already verified the tokenizer carries a chat_template |
| 366 | + (via :func:`_chat_template_of`); a ``None`` tokenizer here is a |
| 367 | + bug, not a runtime fallback path. We still defensively fall |
| 368 | + back to a minimal role-marker concatenation so the dummy |
| 369 | + backend's tests can drive the probe end-to-end without standing |
| 370 | + up a real tokenizer. |
| 371 | + """ |
| 372 | + if tokenizer is None: |
| 373 | + return _fallback_format(messages, add_generation_prompt=add_generation_prompt) |
| 374 | + try: |
| 375 | + out = tokenizer.apply_chat_template( |
| 376 | + messages, tokenize=False, add_generation_prompt=add_generation_prompt |
| 377 | + ) |
| 378 | + return str(out) |
| 379 | + except Exception: # noqa: BLE001 — tokenizer impls vary; never let a probe crash on this |
| 380 | + return _fallback_format(messages, add_generation_prompt=add_generation_prompt) |
| 381 | + |
| 382 | + |
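|  | +# Example fallback rendering for [user "Hi", assistant "Hello!",
|  | +# user "Continue."] with add_generation_prompt=True:
|  | +#   USER: Hi
|  | +#   ASSISTANT: Hello!
|  | +#   USER: Continue.
|  | +#   ASSISTANT: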
| 383 | +def _fallback_format(messages: list[dict[str, str]], *, add_generation_prompt: bool) -> str: |
| 384 | + """Minimal role-marker formatter for tokenizers that misbehave + tests.""" |
| 385 | + parts: list[str] = [] |
| 386 | + for m in messages: |
| 387 | + parts.append(f"{m['role'].upper()}: {m['content']}") |
| 388 | + if add_generation_prompt: |
| 389 | + parts.append("ASSISTANT:") |
| 390 | + return "\n".join(parts) |
| 391 | + |
| 392 | + |
| 393 | +# --------------------------------------------------------------------------- |
| 394 | +# Report sparkline |
| 395 | +# --------------------------------------------------------------------------- |
| 396 | + |
| 397 | + |
| 398 | +_SPARKLINE_BARS = "▁▂▃▄▅▆▇█" |
| 399 | + |
| 400 | + |
| 401 | +def _ascii_sparkline(values: list[float]) -> str: |
| 402 | + """Compact unicode sparkline of per-turn KL. |
| 403 | + |
| 404 | + Surfaces in the report message so a terminal-only reader gets a |
| 405 | + sense of curve shape without opening the JSON. Empty/degenerate |
| 406 | + inputs render as an empty string. |
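|  | +
|  | +    Example: [0.9, 0.6, 0.4] renders as "█▃▁" (bar indices 7, 2, 0
|  | +    into the 8-step ramp after min-max scaling).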
| 407 | + """ |
| 408 | + if not values: |
| 409 | + return "" |
| 410 | + finite = [v for v in values if math.isfinite(v) and v >= 0.0] |
| 411 | + if not finite: |
| 412 | + return "" |
| 413 | + lo, hi = min(finite), max(finite) |
| 414 | + span = hi - lo |
| 415 | + if span <= 1e-12: |
| 416 | + # Flat line — every bar at mid-height. |
| 417 | + return _SPARKLINE_BARS[len(_SPARKLINE_BARS) // 2] * len(values) |
| 418 | + out: list[str] = [] |
| 419 | + for v in values: |
| 420 | + if not math.isfinite(v) or v < 0.0: |
| 421 | + out.append("?") |
| 422 | + continue |
| 423 | + idx = int((v - lo) / span * (len(_SPARKLINE_BARS) - 1)) |
| 424 | + out.append(_SPARKLINE_BARS[max(0, min(idx, len(_SPARKLINE_BARS) - 1))]) |
| 425 | + return "".join(out) |