tenseleyflow/sway / c5e9de2

Browse files

tests/unit: multi-rank null calibration — rank_scale semantics, z-profile emission, prove-the-value rank saturation

Authored by espadonne
SHA
c5e9de2b267260a0f32dac931b6f68ccd2c6cf2d
Parents
6b27ee1
Tree
898b0c4

1 changed file

StatusFile+-
A tests/unit/test_null_multi_rank.py 286 0
tests/unit/test_null_multi_rank.pyadded
@@ -0,0 +1,286 @@
1
+"""Tests for the multi-rank null-adapter calibration path (S10 / F4)."""
2
+
3
+from __future__ import annotations
4
+
5
+import math
6
+
7
+import numpy as np
8
+import pytest
9
+
10
+from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
11
+from dlm_sway.core.result import Verdict
12
+from dlm_sway.probes._zscore import format_z_profile, z_scores_by_rank
13
+from dlm_sway.probes.base import RunContext, build_probe
14
+from dlm_sway.suite.runner import run as run_suite
15
+from dlm_sway.suite.spec import SwaySpec
16
+
17
+
18
+def _backend() -> DummyDifferentialBackend:
19
+    return DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
20
+
21
+
22
+class TestDummyBackendRankScale:
23
+    def test_rank_scale_scales_noise_std(self) -> None:
24
+        """sqrt(rank_scale) scales the null-view perturbation std."""
25
+        backend = _backend()
26
+        prompt = "hello"
27
+
28
+        # Collect 40 samples at each rank_scale to estimate std.
29
+        def _std_at(rank_scale: float) -> float:
30
+            lps = []
31
+            for seed in range(40):
32
+                with backend.as_null_adapter(seed=seed, rank_scale=rank_scale) as view:
33
+                    d = view.next_token_dist(prompt, top_k=8)
34
+                lps.append(d.logprobs)
35
+            arr = np.asarray(lps)
36
+            # Variance across seeds at each position, averaged.
37
+            return float(np.mean(np.std(arr, axis=0)))
38
+
39
+        std_1 = _std_at(1.0)
40
+        std_half = _std_at(0.5)
41
+        std_2 = _std_at(2.0)
42
+        # std ∝ sqrt(rank_scale). Tolerance loose because of the
43
+        # top-k renorm and seed discretion.
44
+        assert std_half < std_1 < std_2
45
+        # Ratio should be roughly sqrt(2) with 20% tolerance.
46
+        ratio_up = std_2 / std_1
47
+        ratio_down = std_1 / std_half
48
+        assert 1.15 < ratio_up < 1.75, f"2x ratio={ratio_up}"
49
+        assert 1.15 < ratio_down < 1.75, f"0.5x ratio={ratio_down}"
50
+
51
+    @pytest.mark.parametrize("bad", [0.0, -1.0, float("nan"), float("inf")])
52
+    def test_rejects_non_positive_rank_scale(self, bad: float) -> None:
53
+        backend = _backend()
54
+        with pytest.raises(ValueError, match="rank_scale"):
55
+            with backend.as_null_adapter(seed=0, rank_scale=bad):
56
+                pass
57
+
58
+    def test_rank_scale_1_preserves_pre_s10_behavior(self) -> None:
59
+        """rank_scale=1.0 → identical output to calling without the kwarg."""
60
+        backend = _backend()
61
+        with backend.as_null_adapter(seed=7) as v1:
62
+            d1 = v1.next_token_dist("hello", top_k=8)
63
+        with backend.as_null_adapter(seed=7, rank_scale=1.0) as v2:
64
+            d2 = v2.next_token_dist("hello", top_k=8)
65
+        np.testing.assert_array_equal(d1.logprobs, d2.logprobs)
66
+
67
+
68
+class TestNullProbeMultiRank:
69
+    def test_single_rank_default_matches_pre_s10(self) -> None:
70
+        """Default rank_multipliers=[1.0] produces the same shape of
71
+        evidence as pre-S10 + a null_stats_by_rank with one entry."""
72
+        backend = _backend()
73
+        probe, spec = build_probe(
74
+            {
75
+                "name": "null",
76
+                "kind": "null_adapter",
77
+                "runs": 3,
78
+                "calibrate_kinds": ["delta_kl"],
79
+            }
80
+        )
81
+        ctx = RunContext(backend=backend)
82
+        result = probe.run(spec, ctx)
83
+        assert result.verdict == Verdict.PASS
84
+        stats = result.evidence["null_stats"]
85
+        by_rank = result.evidence["null_stats_by_rank"]
86
+        assert "delta_kl" in stats
87
+        assert set(by_rank) == {"rank_1.00"}
88
+        assert by_rank["rank_1.00"] == stats
89
+
90
+    def test_three_ranks_produce_three_groups(self) -> None:
91
+        backend = _backend()
92
+        probe, spec = build_probe(
93
+            {
94
+                "name": "null",
95
+                "kind": "null_adapter",
96
+                "runs": 3,
97
+                "rank_multipliers": [0.5, 1.0, 2.0],
98
+                "calibrate_kinds": ["delta_kl"],
99
+            }
100
+        )
101
+        ctx = RunContext(backend=backend)
102
+        result = probe.run(spec, ctx)
103
+        assert result.verdict == Verdict.PASS
104
+        by_rank = result.evidence["null_stats_by_rank"]
105
+        assert set(by_rank) == {"rank_0.50", "rank_1.00", "rank_2.00"}
106
+        for rkey, kind_stats in by_rank.items():
107
+            assert "delta_kl" in kind_stats, f"{rkey} missing delta_kl"
108
+            assert kind_stats["delta_kl"]["std"] > 0.0
109
+
110
+    def test_rank_0_and_negative_rejected(self) -> None:
111
+        backend = _backend()
112
+        probe, spec = build_probe(
113
+            {
114
+                "name": "null",
115
+                "kind": "null_adapter",
116
+                "runs": 2,
117
+                "rank_multipliers": [1.0, -0.5],
118
+                "calibrate_kinds": ["delta_kl"],
119
+            }
120
+        )
121
+        ctx = RunContext(backend=backend)
122
+        result = probe.run(spec, ctx)
123
+        assert result.verdict == Verdict.ERROR
124
+        assert "rank_multipliers" in (result.message or "")
125
+
126
+    def test_higher_rank_has_larger_null_std(self) -> None:
127
+        """A 2x rank null should show more delta_kl variance than a 0.5x one."""
128
+        backend = _backend()
129
+        probe, spec = build_probe(
130
+            {
131
+                "name": "null",
132
+                "kind": "null_adapter",
133
+                "runs": 5,
134
+                "rank_multipliers": [0.5, 2.0],
135
+                "calibrate_kinds": ["delta_kl"],
136
+                "cache": False,
137
+            }
138
+        )
139
+        ctx = RunContext(backend=backend)
140
+        result = probe.run(spec, ctx)
141
+        by_rank = result.evidence["null_stats_by_rank"]
142
+        std_half = by_rank["rank_0.50"]["delta_kl"]["std"]
143
+        std_2 = by_rank["rank_2.00"]["delta_kl"]["std"]
144
+        assert std_2 > std_half, f"2x std={std_2} not > 0.5x std={std_half}"
145
+
146
+
147
+class TestRunnerThreadsNullStatsByRank:
148
+    def test_delta_kl_emits_z_by_rank(self) -> None:
149
+        """null_adapter → delta_kl: evidence carries z_by_rank with three entries."""
150
+        backend = _backend()
151
+        raw_spec = SwaySpec.model_validate(
152
+            {
153
+                "version": 1,
154
+                "models": {
155
+                    "base": {"base": "b"},
156
+                    "ft": {"base": "b", "adapter": "/tmp/a"},
157
+                },
158
+                "suite": [
159
+                    {
160
+                        "name": "null",
161
+                        "kind": "null_adapter",
162
+                        "runs": 3,
163
+                        "rank_multipliers": [0.5, 1.0, 2.0],
164
+                        "cache": False,
165
+                    },
166
+                    {
167
+                        "name": "dk",
168
+                        "kind": "delta_kl",
169
+                        "prompts": ["p1", "p2"],
170
+                        "assert_z_gte": -100.0,  # permissive
171
+                    },
172
+                ],
173
+            }
174
+        )
175
+        result = run_suite(raw_spec, backend)
176
+        assert len(result.probes) == 2
177
+        dk = result.probes[1]
178
+        z_by_rank = dk.evidence.get("z_by_rank")
179
+        assert z_by_rank is not None
180
+        assert set(z_by_rank) == {"rank_0.50", "rank_1.00", "rank_2.00"}
181
+        # Each z is finite.
182
+        for z in z_by_rank.values():
183
+            assert math.isfinite(z)
184
+
185
+
186
+class TestZScoreHelpers:
187
+    def test_z_scores_by_rank_positive_sign(self) -> None:
188
+        raw = 1.0
189
+        stats_by_rank = {
190
+            "rank_1.00": {"mean": 0.5, "std": 0.1, "n": 3.0},
191
+            "rank_0.50": {"mean": 0.3, "std": 0.1, "n": 3.0},
192
+        }
193
+        z = z_scores_by_rank(raw, stats_by_rank, sign=+1)
194
+        assert z is not None
195
+        assert abs(z["rank_1.00"] - 5.0) < 1e-9
196
+        assert abs(z["rank_0.50"] - 7.0) < 1e-9
197
+
198
+    def test_z_scores_by_rank_negative_sign(self) -> None:
199
+        """Lower-is-better probes invert the sign."""
200
+        raw = 0.1
201
+        stats_by_rank = {"rank_1.00": {"mean": 0.5, "std": 0.1, "n": 3.0}}
202
+        z = z_scores_by_rank(raw, stats_by_rank, sign=-1)
203
+        assert z is not None
204
+        assert abs(z["rank_1.00"] - 4.0) < 1e-9  # -((0.1 - 0.5)/0.1) = 4
205
+
206
+    def test_z_scores_by_rank_none_on_empty(self) -> None:
207
+        assert z_scores_by_rank(0.0, None) is None
208
+        assert z_scores_by_rank(0.0, {}) is None
209
+
210
+    def test_z_scores_by_rank_drops_degenerate_ranks(self) -> None:
211
+        """Ranks with std < MIN_STD silently drop out."""
212
+        stats_by_rank = {
213
+            "rank_1.00": {"mean": 0.0, "std": 0.1, "n": 3.0},
214
+            "rank_0.50": {"mean": 0.0, "std": 1e-9, "n": 3.0},  # degenerate
215
+        }
216
+        z = z_scores_by_rank(1.0, stats_by_rank, sign=+1)
217
+        assert z is not None
218
+        assert set(z) == {"rank_1.00"}
219
+
220
+    def test_format_z_profile_readable_labels(self) -> None:
221
+        s = format_z_profile(
222
+            {"rank_1.00": 4.2, "rank_0.50": 6.8, "rank_2.00": 2.1},
223
+        )
224
+        assert "+4.20σ @ 1x" in s
225
+        assert "+6.80σ @ 0.5x" in s
226
+        assert "+2.10σ @ 2x" in s
227
+        assert " / " in s
228
+
229
+    def test_format_z_profile_empty(self) -> None:
230
+        assert format_z_profile(None) == ""
231
+        assert format_z_profile({}) == ""
232
+
233
+
234
+class TestProveTheValueRankSaturation:
235
+    """S10 prove-the-value (§F4): rank profile reveals adapter saturation.
236
+
237
+    The dummy backend's null view injects noise scaled by
238
+    ``sqrt(rank_scale)`` into ``next_token_dist``. That scales the
239
+    null distribution of ``delta_kl``'s raw metric (mean JS divergence
240
+    across prompts) so that smaller ``rank_scale`` → tighter null →
241
+    larger z at the same adapter divergence.
242
+
243
+    Test: hold the adapter fixed (ft responses that produce a known
244
+    divergence from base), vary rank_scale across {0.5, 1.0, 2.0}, and
245
+    assert z_0.5 > z_1 > z_2 — exactly the signature of a rank-sized
246
+    adapter: stronger signal vs a smaller-rank null, weaker signal vs
247
+    a larger-rank null.
248
+    """
249
+
250
+    def test_rank_profile_monotone_in_inverse_rank(self) -> None:
251
+        backend = _backend()
252
+        raw_spec = SwaySpec.model_validate(
253
+            {
254
+                "version": 1,
255
+                "models": {
256
+                    "base": {"base": "b"},
257
+                    "ft": {"base": "b", "adapter": "/tmp/a"},
258
+                },
259
+                "suite": [
260
+                    {
261
+                        "name": "null",
262
+                        "kind": "null_adapter",
263
+                        "runs": 5,
264
+                        "rank_multipliers": [0.5, 1.0, 2.0],
265
+                        "cache": False,
266
+                    },
267
+                    {
268
+                        "name": "dk",
269
+                        "kind": "delta_kl",
270
+                        "prompts": ["p1", "p2", "p3", "p4"],
271
+                        "assert_z_gte": -100.0,  # permissive
272
+                    },
273
+                ],
274
+            }
275
+        )
276
+        result = run_suite(raw_spec, backend)
277
+        dk = result.probes[1]
278
+        z_by_rank = dk.evidence["z_by_rank"]
279
+        # Smaller rank → tighter null → larger (more positive) z.
280
+        z_half = z_by_rank["rank_0.50"]
281
+        z_1 = z_by_rank["rank_1.00"]
282
+        z_2 = z_by_rank["rank_2.00"]
283
+        assert z_half > z_1 > z_2, (
284
+            f"expected z monotone-decreasing in rank; got "
285
+            f"0.5x={z_half:.2f}, 1x={z_1:.2f}, 2x={z_2:.2f}"
286
+        )