1 """Tests for :mod:`dlm_sway.probes.prompt_collapse`.
2
3 Uses a programmable dummy backend that serves different token dists
4 depending on whether the prompt contains the stuffing prefix. That's the
5 cleanest way to simulate "divergence decays with context length" without
6 a real model.
7 """
8
from __future__ import annotations

import math

import numpy as np

from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
from dlm_sway.core.result import Verdict
from dlm_sway.core.scoring import TokenDist
from dlm_sway.probes.base import RunContext, build_probe
from dlm_sway.probes.prompt_collapse import _fit_half_life
18
19
class TestFitHalfLife:
    """Unit tests for the exponential half-life fitter ``_fit_half_life``."""

    def test_exponential_recovered(self) -> None:
        """An exact exponential y = exp(-x / 100) recovers half-life ln(2)*100."""
        lengths = np.array([0.0, 100.0, 200.0, 300.0])
        # y = 1.0 * exp(-x / 100)
        y = np.exp(-lengths / 100.0)
        h = _fit_half_life(lengths, y)
        assert h is not None
        # True half-life = ln(2) * 100 ≈ 69.3
        assert abs(h - math.log(2.0) * 100.0) < 1e-6

    def test_returns_none_for_flat(self) -> None:
        """Flat (near-zero) input has no meaningful decay.

        Either None or a huge half-life is acceptable for flat input; what
        would be a bug is a zero or negative half-life. The original
        assertion here was a tautology (``x is not None or x is None``)
        that could never fail — this pins the actual contract instead.
        """
        lengths = np.array([0.0, 100.0, 200.0])
        y = np.array([1e-10, 1e-10, 1e-10])
        h = _fit_half_life(lengths, y)  # call once, not twice
        assert h is None or h > 0.0

    def test_returns_none_for_increasing(self) -> None:
        """Divergence that grows with context length has no half-life."""
        lengths = np.array([0.0, 100.0, 200.0])
        y = np.array([0.1, 0.3, 0.5])
        assert _fit_half_life(lengths, y) is None
42
43
def _programmed_backend(stuffing_sensitivity: float) -> DummyDifferentialBackend:
    """Return a dummy backend whose base/ft divergence decays with prompt length.

    ``stuffing_sensitivity`` controls how quickly the ft distribution snaps
    back to base as prompt length grows; lower = healthier adapter (longer
    half-life).

    The probe prefixes stuffing, so the dummy sees the exact final prompt:
    we pre-seed a distribution for every prompt the probe will query —
    the default ``context_lengths=[0, 256, 512, 1024]`` combined with the
    single test prompt ``"q1"``.
    """
    # Local import keeps the probe-internal helper out of module scope.
    from dlm_sway.probes.prompt_collapse import _stuffing

    base_probs = np.array([0.5, 0.3, 0.2], dtype=np.float32)
    # Same token ids for every entry; hoisted out of the loop since the
    # dists only read it.
    token_ids = np.array([1, 2, 3], dtype=np.int64)

    base = DummyResponses()
    ft = DummyResponses()

    for ctx_len in (0, 256, 512, 1024):
        key = _stuffing(ctx_len) + "q1"
        # Base model: always tight on token 1, independent of context.
        base.token_dists[key] = TokenDist(
            token_ids=token_ids,
            logprobs=np.log(base_probs),
            vocab_size=100,
        )
        # FT model: diverges fully at ctx=0 and decays toward base with length.
        decay = np.exp(-ctx_len * stuffing_sensitivity)
        ft_probs = base_probs * (1.0 - decay) + np.array([0.1, 0.45, 0.45]) * decay
        ft_probs = ft_probs / ft_probs.sum()
        ft.token_dists[key] = TokenDist(
            token_ids=token_ids,
            logprobs=np.log(ft_probs.astype(np.float32)),
            vocab_size=100,
        )
    return DummyDifferentialBackend(base=base, ft=ft)
91
92
class TestPromptCollapse:
    """End-to-end probe tests against the programmable dummy backend."""

    def test_healthy_adapter_passes(self) -> None:
        """Slow divergence decay → long half-life → PASS at a 100-token bar."""
        config = {
            "name": "pc",
            "kind": "prompt_collapse",
            "prompts": ["q1"],
            "context_lengths": [0, 256, 512, 1024],
            "assert_half_life_tokens": 100,
        }
        probe, spec = build_probe(config)
        run_ctx = RunContext(backend=_programmed_backend(stuffing_sensitivity=0.001))
        outcome = probe.run(spec, run_ctx)
        # With slow decay the fitted half-life sits well above the threshold.
        assert outcome.verdict == Verdict.PASS
        assert outcome.raw is not None
        assert outcome.raw > 100

    def test_collapsing_adapter_fails(self) -> None:
        """Fast divergence decay → short half-life → FAIL at a 500-token bar."""
        config = {
            "name": "pc",
            "kind": "prompt_collapse",
            "prompts": ["q1"],
            "context_lengths": [0, 256, 512, 1024],
            "assert_half_life_tokens": 500,
        }
        probe, spec = build_probe(config)
        run_ctx = RunContext(backend=_programmed_backend(stuffing_sensitivity=0.02))
        outcome = probe.run(spec, run_ctx)
        assert outcome.verdict == Verdict.FAIL

    def test_tokenizer_aware_stuffing_uses_pad_token(self) -> None:
        """B13: when a tokenizer is supplied, the stuffing is built from
        the model's pad/unk token, not the hardcoded English string."""
        from dlm_sway.probes.prompt_collapse import _stuffing

        class _FakeTokenizer:
            pad_token = "<pad>"

            def encode(self, text: str) -> list[int]:
                # One id per character of text — simple enough to verify length.
                return [1] * len(text)

            def decode(self, ids: list[int], *, skip_special_tokens: bool = False) -> str:
                del skip_special_tokens
                return "<pad>" * len(ids)

        stuffed = _stuffing(50, tokenizer=_FakeTokenizer())
        # The legacy English filler must not leak into the tokenizer path.
        assert "archived for historical record" not in stuffed
        assert "<pad>" in stuffed

    def test_legacy_path_used_when_no_tokenizer(self) -> None:
        """The default ``_stuffing(n)`` (no tokenizer) returns the legacy English."""
        from dlm_sway.probes.prompt_collapse import _stuffing

        assert "archived for historical record" in _stuffing(50)

    def test_legacy_stuffing_spec_field_forces_english(self) -> None:
        """``legacy_stuffing=True`` opts out of the tokenizer path.

        Even if the dummy backend grew a tokenizer, this spec would not
        use it. Smoke test only: the probe runs end-to-end.
        """
        config = {
            "name": "pc",
            "kind": "prompt_collapse",
            "prompts": ["q1"],
            "context_lengths": [0, 256],
            "assert_half_life_tokens": 0,
            "legacy_stuffing": True,
        }
        probe, spec = build_probe(config)
        run_ctx = RunContext(backend=_programmed_backend(0.001))
        outcome = probe.run(spec, run_ctx)
        assert outcome.verdict in (Verdict.PASS, Verdict.FAIL)

    def test_error_on_empty_prompts(self) -> None:
        """An empty prompt list cannot be probed and must yield ERROR."""
        config = {
            "name": "pc",
            "kind": "prompt_collapse",
            "prompts": [],
            "context_lengths": [0, 256],
        }
        probe, spec = build_probe(config)
        run_ctx = RunContext(backend=_programmed_backend(0.001))
        outcome = probe.run(spec, run_ctx)
        assert outcome.verdict == Verdict.ERROR