| 1 | """Tests for :mod:`dlm_sway.probes.prompt_collapse`. |
| 2 | |
| 3 | Uses a programmable dummy backend that serves different token dists |
| 4 | depending on whether the prompt contains the stuffing prefix. That's the |
| 5 | cleanest way to simulate "divergence decays with context length" without |
| 6 | a real model. |
| 7 | """ |
| 8 | |
| 9 | from __future__ import annotations |
| 10 | |
| 11 | import numpy as np |
| 12 | |
| 13 | from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses |
| 14 | from dlm_sway.core.result import Verdict |
| 15 | from dlm_sway.core.scoring import TokenDist |
| 16 | from dlm_sway.probes.base import RunContext, build_probe |
| 17 | from dlm_sway.probes.prompt_collapse import _fit_half_life |
| 18 | |
| 19 | |
class TestFitHalfLife:
    """Unit tests for the exponential half-life fitter ``_fit_half_life``."""

    def test_exponential_recovered(self) -> None:
        """A clean exponential decay should recover the analytic half-life."""
        lengths = np.array([0.0, 100.0, 200.0, 300.0])
        # y = 1.0 * exp(-x / 100)
        y = np.exp(-lengths / 100.0)
        h = _fit_half_life(lengths, y)
        assert h is not None
        # True half-life = ln(2) * 100 ≈ 69.3
        assert abs(h - np.log(2.0) * 100.0) < 1e-6

    def test_returns_none_for_flat(self) -> None:
        """Flat (near-zero) input must not crash the fitter.

        Either ``None`` or a huge positive half-life is acceptable for flat
        input.  NOTE: the original assertion here was a tautology
        (``x is not None or x is None``) and verified nothing.
        """
        lengths = np.array([0.0, 100.0, 200.0])
        y = np.array([1e-10, 1e-10, 1e-10])
        result = _fit_half_life(lengths, y)
        # Either None or a huge half-life — both acceptable for flat input,
        # but a fitted value must at least be positive.
        assert result is None or result > 0.0

    def test_returns_none_for_increasing(self) -> None:
        """Divergence that *grows* with length has no meaningful half-life."""
        lengths = np.array([0.0, 100.0, 200.0])
        y = np.array([0.1, 0.3, 0.5])
        assert _fit_half_life(lengths, y) is None
| 42 | |
| 43 | |
def _programmed_backend(stuffing_sensitivity: float) -> DummyDifferentialBackend:
    """Return a backend whose base/ft divergence decays with prompt length.

    ``stuffing_sensitivity`` controls how quickly the ft distribution
    snaps back to base as prompt length grows; lower = healthier adapter.

    The probe prefixes stuffing so the dummy sees the exact final prompt;
    we therefore pre-build a distribution for every prompt the probe is
    expected to query.  (An earlier draft subclassed ``DummyResponses`` to
    override the lookup path — that dead class has been removed.)
    """
    from dlm_sway.probes.prompt_collapse import _stuffing

    base_probs = np.array([0.5, 0.3, 0.2], dtype=np.float32)

    base = DummyResponses()
    ft = DummyResponses()

    # Pre-generate prompts the probe will query. The probe uses default
    # context_lengths=[0,256,512,1024] times _STUFFING ~4 chars/tok.
    for ctx_len in (0, 256, 512, 1024):
        prefix = _stuffing(ctx_len)
        for prompt in ("q1",):
            key = prefix + prompt
            # Base: always tight on token 1.
            base.token_dists[key] = TokenDist(
                token_ids=np.array([1, 2, 3], dtype=np.int64),
                logprobs=np.log(base_probs),
                vocab_size=100,
            )
            # FT: diverges at ctx=0, decays toward base with length.
            decay = np.exp(-ctx_len * stuffing_sensitivity)
            ft_probs = base_probs * (1.0 - decay) + np.array([0.1, 0.45, 0.45]) * decay
            ft_probs = ft_probs / ft_probs.sum()
            ft.token_dists[key] = TokenDist(
                token_ids=np.array([1, 2, 3], dtype=np.int64),
                logprobs=np.log(ft_probs.astype(np.float32)),
                vocab_size=100,
            )
    return DummyDifferentialBackend(base=base, ft=ft)
| 91 | |
| 92 | |
class TestPromptCollapse:
    """End-to-end runs of the prompt-collapse probe against the dummy backend."""

    def test_healthy_adapter_passes(self) -> None:
        """Slow decay yields a half-life well above the 100-token threshold."""
        spec_dict = {
            "name": "pc",
            "kind": "prompt_collapse",
            "prompts": ["q1"],
            "context_lengths": [0, 256, 512, 1024],
            "assert_half_life_tokens": 100,
        }
        probe, spec = build_probe(spec_dict)
        backend = _programmed_backend(stuffing_sensitivity=0.001)
        outcome = probe.run(spec, RunContext(backend=backend))
        assert outcome.verdict == Verdict.PASS
        assert outcome.raw is not None
        assert outcome.raw > 100

    def test_collapsing_adapter_fails(self) -> None:
        """Fast decay yields a short half-life that misses the 500-token bar."""
        spec_dict = {
            "name": "pc",
            "kind": "prompt_collapse",
            "prompts": ["q1"],
            "context_lengths": [0, 256, 512, 1024],
            "assert_half_life_tokens": 500,
        }
        probe, spec = build_probe(spec_dict)
        backend = _programmed_backend(stuffing_sensitivity=0.02)
        outcome = probe.run(spec, RunContext(backend=backend))
        assert outcome.verdict == Verdict.FAIL

    def test_tokenizer_aware_stuffing_uses_pad_token(self) -> None:
        """B13: with a tokenizer supplied, stuffing comes from the model's
        pad/unk token rather than the hardcoded English string."""
        from dlm_sway.probes.prompt_collapse import _stuffing

        class _StubTokenizer:
            pad_token = "<pad>"

            def encode(self, text: str) -> list[int]:
                # One id per character — trivial, but enough to verify length.
                return [1 for _ in text]

            def decode(self, ids: list[int], *, skip_special_tokens: bool = False) -> str:
                del skip_special_tokens  # unused by the stub
                return "<pad>" * len(ids)

        stuffed = _stuffing(50, tokenizer=_StubTokenizer())
        # The legacy English filler must not leak through.
        assert "archived for historical record" not in stuffed
        assert "<pad>" in stuffed

    def test_legacy_path_used_when_no_tokenizer(self) -> None:
        """Without a tokenizer, ``_stuffing(n)`` falls back to legacy English."""
        from dlm_sway.probes.prompt_collapse import _stuffing

        stuffed = _stuffing(50)
        assert "archived for historical record" in stuffed

    def test_legacy_stuffing_spec_field_forces_english(self) -> None:
        """``legacy_stuffing=True`` opts out of the tokenizer path."""
        spec_dict = {
            "name": "pc",
            "kind": "prompt_collapse",
            "prompts": ["q1"],
            "context_lengths": [0, 256],
            "assert_half_life_tokens": 0,
            "legacy_stuffing": True,
        }
        probe, spec = build_probe(spec_dict)
        # Even a backend that grew a tokenizer would be ignored under this
        # spec. Smoke test: the probe must run end-to-end.
        outcome = probe.run(spec, RunContext(backend=_programmed_backend(0.001)))
        assert outcome.verdict in (Verdict.PASS, Verdict.FAIL)

    def test_error_on_empty_prompts(self) -> None:
        """An empty prompt list is a configuration error, not a pass/fail."""
        spec_dict = {
            "name": "pc",
            "kind": "prompt_collapse",
            "prompts": [],
            "context_lengths": [0, 256],
        }
        probe, spec = build_probe(spec_dict)
        outcome = probe.run(spec, RunContext(backend=_programmed_backend(0.001)))
        assert outcome.verdict == Verdict.ERROR