1 """S14 F9 prove-the-value: bootstrap CI width shrinks as N grows.
2
3 The audit's F9 pitch is "every numeric probe should publish a CI so
4 downstream claims are honest about sampling noise." The narrow-vs-
5 wide behavior is the concrete evidence that the CI is informative
6 rather than a fixed-width decoration.
7
8 Test construction: build a dummy backend that produces *per-prompt
9 varying* divergences by seeding each prompt's ft token distribution
10 with its hash. Run ``delta_kl`` at N=4 and N=32 on that backend,
11 assert the N=32 CI is strictly narrower than the N=4 CI — the F9
12 claim in concrete form.
13
14 The stock `DummyDifferentialBackend.as_finetuned()` returns the same
15 synthesized distribution for every prompt, which produces identical
16 per-prompt divergences and a zero-width CI at any N. This test's
17 fixture subclasses the dummy to inject per-prompt variation so the
18 bootstrap has actual dispersion to measure.
19 """

from __future__ import annotations

import hashlib
import math
from collections.abc import Iterator
from contextlib import contextmanager

import numpy as np

from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses, _DummyView
from dlm_sway.core.scoring import TokenDist
from dlm_sway.probes.base import RunContext, build_probe


class _VariableFtView(_DummyView):
    """A dummy view whose ``next_token_dist`` varies by prompt.

    Each prompt gets a deterministically-seeded small perturbation of
    the default ft distribution — enough to produce per-prompt JS
    differences in the 0.001–0.05 range, which is where the bootstrap
    CI narrowing is visible.
    """

    def next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist:
        base_dist = super().next_token_dist(prompt, top_k=top_k)
        # Seed from a stable hash (hashlib) rather than Python's built-in
        # ``hash()``, which salts per-process via PYTHONHASHSEED and
        # would make per-prompt dispersion vary across pytest runs.
        seed = int(hashlib.md5(prompt.encode("utf-8")).hexdigest()[:8], 16)
        rng = np.random.default_rng(seed)
        noise = rng.normal(0.0, 0.5, size=base_dist.logprobs.shape).astype(np.float32)
        perturbed = base_dist.logprobs + noise
        # Renormalize within the top-k slice (a softmax over the perturbed
        # logprobs); the tail logprob is carried over unchanged.
        max_lp = perturbed.max()
        probs = np.exp(perturbed - max_lp)
        probs /= probs.sum()
        return TokenDist(
            token_ids=base_dist.token_ids,
            logprobs=np.log(probs).astype(np.float32),
            vocab_size=base_dist.vocab_size,
            tail_logprob=base_dist.tail_logprob,
        )


class _VariableFtBackend(DummyDifferentialBackend):
    """Dummy backend whose ft view perturbs per-prompt."""

    @contextmanager
    def as_finetuned(self) -> Iterator[_DummyView]:
        self._enter("ft")
        try:
            view = _VariableFtView("ft", self._ft_r, inst=self._inst)
            yield view
        finally:
            self._exit()
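

# Sanity check on the fixture: a minimal sketch, assuming the stock dummy
# view is deterministic per prompt (the module docstring says its ft
# distribution is prompt-invariant) and that repeated ``next_token_dist``
# calls within a single ``as_finetuned()`` context are allowed.
def test_variable_ft_view_disperses_per_prompt() -> None:
    backend = _VariableFtBackend(base=DummyResponses(), ft=DummyResponses())
    with backend.as_finetuned() as view:
        same_a = view.next_token_dist("prompt-000")
        same_b = view.next_token_dist("prompt-000")
        other = view.next_token_dist("prompt-001")
    # Stable md5 seeding: the same prompt yields byte-identical logprobs.
    assert np.array_equal(same_a.logprobs, same_b.logprobs)
    # Different prompts draw different noise, so their slices must differ.
    assert not np.array_equal(same_a.logprobs, other.logprobs)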


def _run_delta_kl(n_prompts: int) -> tuple[float, float, float]:
    """Run ``delta_kl`` with ``n_prompts`` synthesized prompts.

    Returns ``(raw, ci_lo, ci_hi)``.
    """
    backend = _VariableFtBackend(base=DummyResponses(), ft=DummyResponses())
    prompts = [f"prompt-{i:03d}" for i in range(n_prompts)]
    probe, spec = build_probe({"name": f"dk_{n_prompts}", "kind": "delta_kl", "prompts": prompts})
    ctx = RunContext(backend=backend)
    result = probe.run(spec, ctx)
    assert result.raw is not None, "dummy backend delta_kl should produce a raw value"
    assert result.ci_95 is not None, "bootstrap_ci should land on delta_kl output"
    lo, hi = result.ci_95
    return result.raw, lo, hi
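

# The module docstring claims the stock dummy produces identical per-prompt
# divergences and therefore a zero-width CI at any N. A minimal sketch of
# that claim as a test, assuming the bootstrap collapses exactly when every
# resample mean is the same value (a tiny abs_tol guards float noise).
def test_stock_dummy_ci_is_zero_width() -> None:
    backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
    prompts = [f"prompt-{i:03d}" for i in range(8)]
    probe, spec = build_probe({"name": "dk_stock", "kind": "delta_kl", "prompts": prompts})
    result = probe.run(spec, RunContext(backend=backend))
    assert result.ci_95 is not None
    lo, hi = result.ci_95
    assert math.isclose(lo, hi, abs_tol=1e-12)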


def test_ci_width_shrinks_with_more_prompts() -> None:
    """The F9 claim: ``delta_kl = 0.05 [0.01, 0.11]`` at N=4 narrows to
    something tighter at N=32."""
    raw_4, lo_4, hi_4 = _run_delta_kl(n_prompts=4)
    raw_32, lo_32, hi_32 = _run_delta_kl(n_prompts=32)

    width_4 = hi_4 - lo_4
    width_32 = hi_32 - lo_32

    # Both raws are positive divergences, and each CI brackets its own
    # raw value.
    assert raw_4 > 0
    assert raw_32 > 0
    assert lo_4 <= raw_4 <= hi_4
    assert lo_32 <= raw_32 <= hi_32

    # The N=32 CI is strictly tighter than the N=4 CI. Theory predicts
    # the CI half-width scales as 1/sqrt(N), so N=32 should be roughly
    # sqrt(8) ≈ 2.8× narrower than N=4.
    assert width_32 < width_4, (
        f"expected width_32 < width_4; got {width_32:.4f} >= {width_4:.4f} "
        f"(CIs: N=4 {[lo_4, hi_4]}, N=32 {[lo_32, hi_32]})"
    )
    # Loose additional check: the narrowing factor is meaningfully bigger
    # than 1.0, so the CI isn't just slightly tighter from RNG noise.
    assert width_4 / max(width_32, 1e-9) > 1.5, (
        f"expected N=4 width at least 1.5× N=32 width; got ratio "
        f"{width_4 / max(width_32, 1e-9):.2f}"
    )
    # Sanity on magnitudes: widths are positive and finite.
    for w in (width_4, width_32):
        assert w > 0
        assert math.isfinite(w)
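

# The 1/sqrt(N) comment above can be demonstrated without the probe stack.
# A minimal sketch using a plain percentile bootstrap over fixed, evenly
# spaced "divergences" in the same small-positive range the fixture
# docstring mentions; the helper name, resample count, and sample values
# are illustrative assumptions, not delta_kl internals.
def _percentile_bootstrap_width(samples: np.ndarray, n_boot: int = 2000) -> float:
    """Width of a 95% percentile-bootstrap CI for the mean of ``samples``."""
    rng = np.random.default_rng(0)
    means = np.array(
        [rng.choice(samples, size=samples.size, replace=True).mean() for _ in range(n_boot)]
    )
    lo, hi = np.percentile(means, [2.5, 97.5])
    return float(hi - lo)


def test_plain_bootstrap_width_scales_like_inv_sqrt_n() -> None:
    width_4 = _percentile_bootstrap_width(np.linspace(0.01, 0.05, 4))
    width_32 = _percentile_bootstrap_width(np.linspace(0.01, 0.05, 32))
    # Same spread, 8x the sample size: expect roughly sqrt(8) ≈ 2.8x
    # narrowing, asserted loosely to stay robust to resampling noise.
    assert width_32 < width_4
    assert width_4 / width_32 > 1.5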