1 """Tests for :mod:`dlm_sway.backends.api` (S13 / F7).
2
3 The backend hits real HTTP in production. Here we use httpx's
4 ``MockTransport`` to intercept every request and hand back canned
5 OpenAI-shaped responses. That lets us exercise the full request-build
6 + response-parse roundtrip without running a server.
7 """
8
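# The pattern in miniature (httpx's real public API, which every fixture
# below builds on): MockTransport turns a plain handler function into the
# client's entire network layer, so no socket is ever opened.
#
#     handler = lambda request: httpx.Response(200, json={"ok": True})
#     client = httpx.Client(transport=httpx.MockTransport(handler))
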
from __future__ import annotations

import json as _json
from collections.abc import Callable
from typing import Any

import numpy as np
import pytest

pytest.importorskip("httpx")
pytest.importorskip("tenacity")

import httpx  # noqa: E402

from dlm_sway.backends.api import (  # noqa: E402
    ApiScoringBackend,
    _split_echo_at_char,
)
from dlm_sway.core.errors import ProbeError  # noqa: E402


def _echo_response(tokens: list[str], logprobs: list[float | None]) -> dict[str, Any]:
    """Shape an ``echo=True`` /v1/completions response."""
    return {
        "choices": [
            {
                "text": "".join(tokens),
                "logprobs": {
                    "tokens": tokens,
                    "token_logprobs": logprobs,
                    "top_logprobs": [],
                },
                "finish_reason": "length",
            }
        ],
        "usage": {
            "prompt_tokens": len(tokens),
            "completion_tokens": 0,
            "total_tokens": len(tokens),
        },
    }

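# Example: ``_echo_response(["Hi", "!"], [None, -3.0])`` yields a payload
# whose ``choices[0]["logprobs"]`` carries tokens ["Hi", "!"] with
# token_logprobs [None, -3.0]; the leading None mirrors real APIs, which
# cannot score a token that has no left context.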

def _top_logprobs_response(top: dict[str, float]) -> dict[str, Any]:
    """Shape a ``max_tokens=1, logprobs=K`` response."""
    return {
        "choices": [
            {
                "text": next(iter(top)),
                "logprobs": {
                    "tokens": [next(iter(top))],
                    "token_logprobs": [next(iter(top.values()))],
                    "top_logprobs": [top],
                },
                "finish_reason": "length",
            }
        ],
        "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
    }


def _make_backend(
    handler: Callable[[httpx.Request], httpx.Response], **kwargs: Any
) -> ApiScoringBackend:
    """Construct an ApiScoringBackend whose httpx client uses a MockTransport."""
    backend = ApiScoringBackend(
        base_url="https://mock.example/",
        model_name="mock-model",
        api_key=None,
        max_retries=0,
        **kwargs,
    )
    # Swap in a client wired to a MockTransport. Setting ``_client`` up
    # front pre-empts the backend's lazy init, so no later method call
    # ever builds a real network client.
    backend._client = httpx.Client(
        base_url="https://mock.example",
        transport=httpx.MockTransport(handler),
        headers={"content-type": "application/json"},
    )
    return backend


@pytest.fixture(autouse=True)
def _clear_api_key_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """Keep tests hermetic — ignore any API keys the dev machine has set."""
    monkeypatch.delenv("SWAY_API_KEY", raising=False)
    monkeypatch.delenv("OPENAI_API_KEY", raising=False)


class TestLogprobOf:
    def test_sums_completion_tokens(self) -> None:
        """Prompt/completion that split cleanly on a token boundary.

        With ``prompt="The cat"`` + ``completion=" sat."`` the split
        lands after the second token, so the completion's logprobs
        are exactly the last two tokens' values.
        """

        def handler(request: httpx.Request) -> httpx.Response:
            body = _json.loads(request.content)
            assert body["echo"] is True
            assert body["max_tokens"] == 0
            # Echo tokens for "The cat sat." — 4 tokens; the first is
            # None (no left context), the rest have logprobs.
            return httpx.Response(
                200,
                json=_echo_response(
                    tokens=["The", " cat", " sat", "."],
                    logprobs=[None, -2.1, -1.7, -0.9],
                ),
            )

        backend = _make_backend(handler)
        # prompt_chars=7 ("The cat"): after "The" total=3, after " cat"
        # total=7 (>=7) → split index 2. Completion = [" sat", "."].
        lp = backend.logprob_of(prompt="The cat", completion=" sat.")
        assert lp == pytest.approx(-1.7 + -0.9)

    def test_mid_token_prompt_leans_conservative(self) -> None:
        """When the prompt boundary falls inside a token, that token
        lands on the prompt side — over-counting the prompt and
        under-attributing to the completion. Documented in
        ``_split_echo_at_char``.
        """

        def handler(_: httpx.Request) -> httpx.Response:
            return httpx.Response(
                200,
                json=_echo_response(
                    tokens=["The", " cat", " sat", "."],
                    logprobs=[None, -2.1, -1.7, -0.9],
                ),
            )

        backend = _make_backend(handler)
        # prompt_chars=4 ("The ") — boundary is inside " cat".
        # Conservative: " cat" joins the prompt side; completion = [" sat", "."].
        lp = backend.logprob_of(prompt="The ", completion="cat sat.")
        assert lp == pytest.approx(-1.7 + -0.9)

    def test_empty_completion_raises(self) -> None:
        backend = _make_backend(lambda _: httpx.Response(200, json={}))
        with pytest.raises(ProbeError, match="tokenized to zero"):
            backend.logprob_of(prompt="hi", completion="")

    def test_caches_repeated_calls(self) -> None:
        """S07 cache dedupes identical (prompt, completion) pairs."""
        calls = {"count": 0}

        def handler(_: httpx.Request) -> httpx.Response:
            calls["count"] += 1
            return httpx.Response(
                200,
                json=_echo_response(
                    tokens=["Hi", "!"],
                    logprobs=[None, -3.0],
                ),
            )

        backend = _make_backend(handler)
        backend.logprob_of("Hi", "!")
        backend.logprob_of("Hi", "!")
        assert calls["count"] == 1


class TestRollingLogprob:
    def test_returns_numeric_array(self) -> None:
        def handler(_: httpx.Request) -> httpx.Response:
            return httpx.Response(
                200,
                json=_echo_response(
                    tokens=["Hello", " world", "!"],
                    logprobs=[None, -2.5, -1.5],
                ),
            )

        backend = _make_backend(handler)
        r = backend.rolling_logprob("Hello world!")
        assert r.num_tokens == 3
        assert r.logprobs.dtype == np.float32
        # The leading None is dropped; the rest form the rolling array.
        np.testing.assert_allclose(r.logprobs, [-2.5, -1.5])
        assert r.total_logprob == pytest.approx(-4.0)

    def test_perplexity_finite(self) -> None:
        def handler(_: httpx.Request) -> httpx.Response:
            return httpx.Response(
                200,
                json=_echo_response(
                    tokens=["a", "b"],
                    logprobs=[None, -1.0],
                ),
            )

        backend = _make_backend(handler)
        r = backend.rolling_logprob("ab")
        assert r.perplexity > 1.0
        assert np.isfinite(r.perplexity)
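        # Sanity math, assuming the conventional definition
        # perplexity = exp(-mean(logprobs)): the single scored token at
        # -1.0 gives exp(1.0) ≈ 2.718, finite and > 1 as asserted.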


class TestNextTokenDist:
    def test_parses_top_logprobs_descending(self) -> None:
        def handler(request: httpx.Request) -> httpx.Response:
            body = _json.loads(request.content)
            assert body["max_tokens"] == 1
            assert body["echo"] is False
            assert body["logprobs"] >= 1
            return httpx.Response(
                200,
                json=_top_logprobs_response(
                    # Logprobs whose probabilities sum to < 1, so the
                    # residual is a real positive tail mass (~0.22).
                    {" cat": -0.5, " dog": -2.0, " fish": -3.5, " bird": -5.0},
                ),
            )

        backend = _make_backend(handler, vocab_size=50_000)
        d = backend.next_token_dist("The quick brown fox chased the", top_k=4)
        assert d.token_ids.shape == (4,)
        assert d.logprobs.shape == (4,)
        # Sorted descending by logprob.
        assert np.all(np.diff(d.logprobs) <= 0)
        assert d.logprobs[0] == pytest.approx(-0.5)
        # vocab_size threads through; tail_logprob is a finite residual.
        assert d.vocab_size == 50_000
        assert d.tail_logprob is not None
        assert d.tail_logprob < 0
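        # Worked residual for this fixture, assuming the tail is simply
        # log(1 - sum of top-K probabilities):
        #   exp(-0.5) + exp(-2.0) + exp(-3.5) + exp(-5.0) ≈ 0.779
        #   log(1 - 0.779) ≈ log(0.221) ≈ -1.51 → negative, as asserted.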

    def test_unknown_vocab_produces_none_tail(self) -> None:
        def handler(_: httpx.Request) -> httpx.Response:
            return httpx.Response(200, json=_top_logprobs_response({"a": -0.5, "b": -1.5}))

        backend = _make_backend(handler)  # no vocab_size
        d = backend.next_token_dist("x", top_k=2)
        assert d.tail_logprob is None


class TestPreflight:
    def test_passes_on_finite_dist(self) -> None:
        def handler(_: httpx.Request) -> httpx.Response:
            return httpx.Response(200, json=_top_logprobs_response({"ok": -0.1}))

        backend = _make_backend(handler, vocab_size=100)
        ok, reason = backend.preflight_finite_check()
        assert ok, reason

    def test_fails_on_nan_logprob(self) -> None:
        # JSON has no NaN literal, so rather than lean on the transport's
        # ``json=`` kwarg we hand-craft the body with Python's ``json``
        # module (which emits ``NaN`` by default), making the fixture
        # behave like a lenient upstream.
        body = _json.dumps(_top_logprobs_response({"nope": float("nan")}))

        def handler(_: httpx.Request) -> httpx.Response:
            return httpx.Response(
                200,
                content=body,
                headers={"content-type": "application/json"},
            )

        backend = _make_backend(handler, vocab_size=100)
        ok, reason = backend.preflight_finite_check()
        assert not ok
        assert "non-finite" in reason.lower() or "nan" in reason.lower()

    def test_fails_on_http_error(self) -> None:
        def handler(_: httpx.Request) -> httpx.Response:
            return httpx.Response(500, text="internal server error")

        backend = _make_backend(handler, vocab_size=100)
        ok, reason = backend.preflight_finite_check()
        assert not ok
        assert "500" in reason or "server" in reason.lower() or "error" in reason.lower()


class TestHttpErrorPaths:
    def test_4xx_raises_probe_error(self) -> None:
        def handler(_: httpx.Request) -> httpx.Response:
            return httpx.Response(401, text="bad api key")

        backend = _make_backend(handler)
        with pytest.raises(ProbeError, match="401"):
            backend.logprob_of("hi", "there")

    def test_retry_path_eventually_succeeds(self) -> None:
        """With retries enabled, a 503 followed by a 200 returns the 200's body."""
        sequence = iter(
            [
                httpx.Response(503, text="overloaded"),
                httpx.Response(
                    200,
                    json=_echo_response(
                        tokens=["hi", " there"],
                        logprobs=[None, -2.0],
                    ),
                ),
            ]
        )

        def handler(_: httpx.Request) -> httpx.Response:
            return next(sequence)

        backend = ApiScoringBackend(
            base_url="https://mock.example",
            model_name="mock-model",
            api_key=None,
            max_retries=2,
        )
        # Force the client with a MockTransport. No wait patching needed:
        # tenacity's wait_exponential stays short at this tiny retry count.
        backend._client = httpx.Client(
            base_url="https://mock.example",
            transport=httpx.MockTransport(handler),
        )
        lp = backend.logprob_of("hi", " there")
        assert lp == pytest.approx(-2.0)


class TestSplitEchoAtChar:
    def test_prompt_at_token_boundary(self) -> None:
        assert _split_echo_at_char(["The", " cat"], prompt_chars=3) == 1

    def test_prompt_mid_token_includes_token_in_prompt(self) -> None:
        # "Th" is 2 chars of "The" (3 chars) — prompt extends into the
        # token; conservative: include the whole token on the prompt side.
        assert _split_echo_at_char(["The", " cat"], prompt_chars=2) == 1

    def test_prompt_past_all_tokens(self) -> None:
        assert _split_echo_at_char(["a", "b"], prompt_chars=10) == 2

    def test_zero_prompt(self) -> None:
        assert _split_echo_at_char(["x"], prompt_chars=0) == 0

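
# For orientation: the semantics the TestSplitEchoAtChar cases pin down
# amount to the scan below. This is a reference sketch only, not the
# backend's actual implementation (that lives in dlm_sway.backends.api).
def _split_echo_at_char_ref(tokens: list[str], prompt_chars: int) -> int:
    """Index of the first token that starts at or after ``prompt_chars``."""
    total = 0  # characters consumed so far == start offset of tokens[i]
    for i, tok in enumerate(tokens):
        if total >= prompt_chars:
            return i  # token i starts the completion side
        total += len(tok)  # a mid-token boundary keeps the token on the prompt side
    return len(tokens)  # prompt covers (or exceeds) every echoed token
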

class TestConstructor:
    def test_api_key_from_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.setenv("SWAY_API_KEY", "env-token")
        backend = ApiScoringBackend(base_url="http://x", model_name="m", max_retries=0)
        assert backend._api_key == "env-token"

    def test_openai_api_key_fallback(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.delenv("SWAY_API_KEY", raising=False)
        monkeypatch.setenv("OPENAI_API_KEY", "openai-token")
        backend = ApiScoringBackend(base_url="http://x", model_name="m", max_retries=0)
        assert backend._api_key == "openai-token"

    def test_safe_for_concurrent_views_is_true(self) -> None:
        """Flag is class-level; no instantiation needed."""
        assert ApiScoringBackend.safe_for_concurrent_views is True

    def test_cache_identity_stable(self) -> None:
        backend = ApiScoringBackend(
            base_url="http://host.example:8000",
            model_name="llama-3-8b",
            max_retries=0,
        )
        assert backend.cache_identity() == "api:http://host.example:8000:llama-3-8b"