tenseleyflow/sway / 70bc167

Browse files

sway(probes): A3 prompt_collapse — KL decay fit in log space

Authored by espadonne
SHA
70bc167ced51f716ea737cf1d3a2c4509c1818e3
Parents
ed303dd
Tree
61cfca0

2 changed files

StatusFile+-
A src/dlm_sway/probes/prompt_collapse.py 159 0
A tests/unit/test_probe_prompt_collapse.py 137 0
src/dlm_sway/probes/prompt_collapse.pyadded
@@ -0,0 +1,159 @@
1
+"""A3 PromptCollapse — does adapter influence decay with context length?
2
+
3
+For each test prompt we prepend irrelevant "stuffing" of varying length
4
+and measure ``divergence(base, ft)`` at the final position. A healthy
5
+adapter shows a modest, slow decay; a degenerate one collapses quickly
6
+— its signal evaporates once the base has a lot of context to lean on.
7
+
8
+We fit an exponential decay ``KL(L) = KL0 * exp(-L / tau)`` in log
9
+space and report the half-life ``ln(2) * tau`` in tokens. Pass if the half-life is at
10
+least :attr:`PromptCollapseSpec.assert_half_life_tokens` — which
11
+defaults to half the default sequence length.
12
+
13
+All math is numpy-only to avoid a scipy dependency on the install path.
14
+"""
15
+
16
+from __future__ import annotations
17
+
18
+from typing import Literal
19
+
20
+import numpy as np
21
+from pydantic import Field
22
+
23
+from dlm_sway.core.result import ProbeResult, Verdict
24
+from dlm_sway.probes._divergence import Divergence, divergence
25
+from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
26
+
27
# Neutral filler text that ``_stuffing`` repeats and truncates to build
# the irrelevant prefix prepended to each test prompt, stressing the base
# model's long-context handling. Deliberately low-information so the
# "answer" at the end is the only thing driving next-token predictions.
# The literal ends with a space so repeated copies do not fuse words.
_STUFFING = (
    "The following log lines are archived for historical record and have no "
    "bearing on the question that follows. They are retained for audit purposes "
    "only and should be ignored when forming an answer. "
)
35
+
36
+
37
class PromptCollapseSpec(ProbeSpec):
    """Configuration for the ``prompt_collapse`` probe.

    Selects the prompts to stuff, the stuffing lengths to sweep, the
    divergence measure, and the half-life threshold that decides
    pass/fail.
    """

    kind: Literal["prompt_collapse"] = "prompt_collapse"
    # Prompts to evaluate. An empty list passes validation; the probe
    # then returns an ERROR result instead of running.
    prompts: list[str] = Field(default_factory=list, min_length=0)
    context_lengths: list[int] = Field(
        default_factory=lambda: [0, 256, 512, 1024],
        min_length=2,
    )
    """Approximate token counts of stuffing to prepend. ≥2 required
    because the exponential fit is undefined for a single point."""
    # Divergence measure forwarded to ``divergence(..., kind=...)``.
    divergence: Divergence = "js"
    # Per-probe override; when None the probe falls back to ``ctx.top_k``.
    top_k: int | None = None
    assert_half_life_tokens: int = 512
    """Minimum half-life to pass. Default is deliberately permissive —
    tune upward for high-stakes deployments."""
51
+
52
+
53
class PromptCollapseProbe(Probe):
    """Probe that measures how quickly adapter influence fades with context.

    For every configured context length the probe prepends that much
    stuffing to each prompt, averages ``divergence(base, ft)`` over the
    prompts, fits an exponential decay across the lengths and compares
    the resulting half-life with ``spec.assert_half_life_tokens``.
    """

    kind = "prompt_collapse"
    spec_cls = PromptCollapseSpec
    category = "adherence"

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Run the context-length sweep and grade the fitted half-life.

        Args:
            spec: Must be a :class:`PromptCollapseSpec`.
            ctx: Supplies the differential backend and the default ``top_k``.

        Returns:
            An ERROR result when no prompts are configured. Otherwise PASS
            when a half-life could be fitted and meets the threshold, and
            FAIL in every other case. ``raw`` carries the half-life, or
            ``None`` when no decay could be fitted.
        """
        assert isinstance(spec, PromptCollapseSpec)
        if not spec.prompts:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message="no prompts provided",
            )

        top_k = spec.top_k if spec.top_k is not None else ctx.top_k
        # Mean divergence at each context length.
        mean_divs = [
            self._mean_divergence(ctx, spec, _stuffing(ctx_len), top_k)
            for ctx_len in spec.context_lengths
        ]

        half_life = _fit_half_life(
            np.asarray(spec.context_lengths, dtype=np.float64),
            np.asarray(mean_divs, dtype=np.float64),
        )
        # A NaN or infinite divergence is not caught by a ``<= 0`` check and
        # turns the regression slope into NaN. Treat a NaN half-life as
        # "could not fit" so it never reaches the message, ``raw`` or score.
        if half_life is not None and np.isnan(half_life):
            half_life = None

        # NOTE(review): a flat or growing divergence also yields ``None`` and
        # therefore FAIL, although it arguably means the adapter does not
        # collapse at all — confirm this is the intended verdict.
        verdict = (
            Verdict.PASS
            if half_life is not None and half_life >= spec.assert_half_life_tokens
            else Verdict.FAIL
        )
        score = _score(half_life, spec.assert_half_life_tokens)

        msg = (
            f"half-life={half_life:.0f} tokens"
            if half_life is not None
            else "could not fit exponential decay (too flat or non-monotonic)"
        )
        return ProbeResult(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=half_life,
            evidence={
                # Copy so the result does not alias the spec's own list.
                "context_lengths": list(spec.context_lengths),
                "mean_divergence_per_length": mean_divs,
                "divergence_kind": spec.divergence,
                "weight": spec.weight,
            },
            message=msg,
        )

    @staticmethod
    def _mean_divergence(
        ctx: RunContext,
        spec: PromptCollapseSpec,
        prefix: str,
        top_k: int | None,
    ) -> float:
        """Average base-vs-finetuned divergence over all prompts for one prefix."""
        divs: list[float] = []
        for prompt in spec.prompts:
            full_prompt = prefix + prompt
            with ctx.backend.as_base() as bv:
                base_dist = bv.next_token_dist(full_prompt, top_k=top_k)
            with ctx.backend.as_finetuned() as fv:
                ft_dist = fv.next_token_dist(full_prompt, top_k=top_k)
            divs.append(divergence(base_dist, ft_dist, kind=spec.divergence))
        return float(np.mean(divs))
115
+
116
+
117
def _stuffing(target_tokens: int) -> str:
    """Build filler text of roughly ``target_tokens`` tokens.

    Length is budgeted in characters using the 4 chars ≈ 1 token rule of
    thumb, which is only meant to be right at the order-of-magnitude level
    for SentencePiece-style tokenizers. A non-positive target yields an
    empty string; otherwise the filler is followed by a blank line.
    """
    if target_tokens <= 0:
        return ""
    char_budget = 4 * target_tokens
    # One copy more than the whole-number quotient always covers the budget.
    whole_copies, _ = divmod(char_budget, len(_STUFFING))
    filler = _STUFFING * (whole_copies + 1)
    return f"{filler[:char_budget]}\n\n"
126
+
127
+
128
+def _fit_half_life(lengths: np.ndarray, divergences: np.ndarray) -> float | None:
129
+    """Fit ``y = a * exp(-x / h)`` via log-space linear regression.
130
+
131
+    Returns ``None`` if the divergences aren't strictly positive or the
132
+    fit is non-decreasing (i.e. the fine-tune got *more* distinct with
133
+    context, which invalidates the half-life concept).
134
+    """
135
+    if (divergences <= 0.0).any():
136
+        # Can't take a log; treat near-zero as too-flat-to-fit.
137
+        return None
138
+    log_y = np.log(divergences)
139
+    # Standard linear regression slope.
140
+    x_mean = float(lengths.mean())
141
+    y_mean = float(log_y.mean())
142
+    denom = float(((lengths - x_mean) ** 2).sum())
143
+    if denom == 0.0:
144
+        return None
145
+    slope = float(((lengths - x_mean) * (log_y - y_mean)).sum()) / denom
146
+    if slope >= 0.0:
147
+        # Signal grew with context — can't express as half-life.
148
+        return None
149
+    # Slope = -1/h → h = -1/slope → half_life = ln(2) * h.
150
+    import math
151
+
152
+    return float(math.log(2.0) * (-1.0 / slope))
153
+
154
+
155
+def _score(half_life: float | None, target: int) -> float:
156
+    if half_life is None:
157
+        return 0.0
158
+    # Asymptotic: score saturates at 1.0 when hits target, declines toward 0.
159
+    return float(min(1.0, half_life / max(target, 1)))
tests/unit/test_probe_prompt_collapse.pyadded
@@ -0,0 +1,137 @@
1
+"""Tests for :mod:`dlm_sway.probes.prompt_collapse`.
2
+
3
+Uses a programmable dummy backend that serves different token dists
4
+depending on the exact stuffed prompt it is asked about. That's the
5
+cleanest way to simulate "divergence decays with context length" without
6
+a real model.
7
+"""
8
+
9
+from __future__ import annotations
10
+
11
+import numpy as np
12
+
13
+from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
14
+from dlm_sway.core.result import Verdict
15
+from dlm_sway.core.scoring import TokenDist
16
+from dlm_sway.probes.base import RunContext, build_probe
17
+from dlm_sway.probes.prompt_collapse import _fit_half_life
18
+
19
+
20
class TestFitHalfLife:
    """Unit tests for the log-space half-life fit."""

    def test_exponential_recovered(self) -> None:
        lengths = np.array([0.0, 100.0, 200.0, 300.0])
        # y = 1.0 * exp(-x / 100)
        y = np.exp(-lengths / 100.0)
        h = _fit_half_life(lengths, y)
        assert h is not None
        # True half-life = ln(2) * 100 ≈ 69.3
        assert abs(h - float(np.log(2.0)) * 100.0) < 1e-6

    def test_returns_none_for_flat(self) -> None:
        lengths = np.array([0.0, 100.0, 200.0])
        y = np.array([1e-10, 1e-10, 1e-10])
        # Constant divergence gives a zero slope, which is not a decay, so
        # no half-life can be reported. (The previous assertion was a
        # tautology and could never fail.)
        assert _fit_half_life(lengths, y) is None

    def test_returns_none_for_increasing(self) -> None:
        lengths = np.array([0.0, 100.0, 200.0])
        y = np.array([0.1, 0.3, 0.5])
        assert _fit_half_life(lengths, y) is None
42
+
43
+
44
def _programmed_backend(
    stuffing_sensitivity: float,
    *,
    context_lengths: tuple[int, ...] = (0, 256, 512, 1024),
    prompts: tuple[str, ...] = ("q1",),
) -> DummyDifferentialBackend:
    """Return a backend whose divergence decays with prompt length.

    The dummy backend is seeded with one base and one fine-tuned
    distribution per exact prompt string the probe will query (stuffing
    prefix + prompt), so the seeded lengths and prompts must cover the
    ones used in the probe spec.

    Args:
        stuffing_sensitivity: Controls how quickly the ft distribution
            snaps back to base as prompt length grows; lower = healthier
            adapter.
        context_lengths: Stuffing lengths to seed; defaults to the
            probe's default sweep.
        prompts: Prompts to seed for every length.
    """
    # Imported here because only this helper needs it; builds the same
    # prefix the probe will prepend, so the keys match exactly.
    from dlm_sway.probes.prompt_collapse import _stuffing

    base_probs = np.array([0.5, 0.3, 0.2], dtype=np.float32)
    # Distribution the fine-tune shows when its influence is at full strength.
    diverged_probs = np.array([0.1, 0.45, 0.45])

    def _dist(probs: np.ndarray) -> TokenDist:
        # Fresh arrays per entry so no two seeded dists share storage.
        return TokenDist(
            token_ids=np.array([1, 2, 3], dtype=np.int64),
            logprobs=np.log(probs.astype(np.float32)),
            vocab_size=100,
        )

    base = DummyResponses()
    ft = DummyResponses()
    for ctx_len in context_lengths:
        prefix = _stuffing(ctx_len)
        # FT: diverges at ctx=0, decays toward base with length.
        decay = np.exp(-ctx_len * stuffing_sensitivity)
        ft_probs = base_probs * (1.0 - decay) + diverged_probs * decay
        ft_probs = ft_probs / ft_probs.sum()
        for prompt in prompts:
            key = prefix + prompt
            # Base: the same distribution at every length.
            base.token_dists[key] = _dist(base_probs)
            ft.token_dists[key] = _dist(ft_probs)
    return DummyDifferentialBackend(base=base, ft=ft)
91
+
92
+
93
class TestPromptCollapse:
    """End-to-end probe runs against the programmed dummy backend."""

    @staticmethod
    def _run_probe(sensitivity: float, **spec_fields: object):
        """Build a ``prompt_collapse`` probe from ``spec_fields`` and run it."""
        raw_spec = {"name": "pc", "kind": "prompt_collapse", **spec_fields}
        probe, spec = build_probe(raw_spec)
        backend = _programmed_backend(stuffing_sensitivity=sensitivity)
        return probe.run(spec, RunContext(backend=backend))

    def test_healthy_adapter_passes(self) -> None:
        outcome = self._run_probe(
            0.001,
            prompts=["q1"],
            context_lengths=[0, 256, 512, 1024],
            assert_half_life_tokens=100,
        )
        # Slow decay should put the half-life well above the 100-token bar.
        assert outcome.verdict == Verdict.PASS
        assert outcome.raw is not None
        assert outcome.raw > 100

    def test_collapsing_adapter_fails(self) -> None:
        outcome = self._run_probe(
            0.02,
            prompts=["q1"],
            context_lengths=[0, 256, 512, 1024],
            assert_half_life_tokens=500,
        )
        # Fast decay cannot satisfy a 500-token half-life threshold.
        assert outcome.verdict == Verdict.FAIL

    def test_error_on_empty_prompts(self) -> None:
        outcome = self._run_probe(
            0.001,
            prompts=[],
            context_lengths=[0, 256],
        )
        assert outcome.verdict == Verdict.ERROR