tenseleyflow/sway / fab3c83


tests/unit: backend_api — MockTransport coverage across all three scoring methods, retries, preflight

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA: fab3c832bf577f17fde8957884f71ce79ab51ab8
Parents: 74c011c
Tree: 95af811

1 changed file

Status  File                            +    -
A       tests/unit/test_backend_api.py  364  0

tests/unit/test_backend_api.py (added)
@@ -0,0 +1,364 @@
+"""Tests for :mod:`dlm_sway.backends.api` (S13 / F7).
+
+The backend hits real HTTP in production. Here we use httpx's
+``MockTransport`` to intercept every request and hand back canned
+OpenAI-shaped responses. That lets us exercise the full request-build
+and response-parse roundtrip without running a server.
+"""
+
+from __future__ import annotations
+
+import json as _json
+from collections.abc import Callable
+from typing import Any
+
+import numpy as np
+import pytest
+
+pytest.importorskip("httpx")
+pytest.importorskip("tenacity")
+
+import httpx  # noqa: E402
+
+from dlm_sway.backends.api import (  # noqa: E402
+    ApiScoringBackend,
+    _split_echo_at_char,
+)
+from dlm_sway.core.errors import ProbeError  # noqa: E402
+
+
+def _echo_response(tokens: list[str], logprobs: list[float | None]) -> dict[str, Any]:
+    """Shape an ``echo=True`` /v1/completions response."""
+    return {
+        "choices": [
+            {
+                "text": "".join(tokens),
+                "logprobs": {
+                    "tokens": tokens,
+                    "token_logprobs": logprobs,
+                    "top_logprobs": [],
+                },
+                "finish_reason": "length",
+            }
+        ],
+        "usage": {
+            "prompt_tokens": len(tokens),
+            "completion_tokens": 0,
+            "total_tokens": len(tokens),
+        },
+    }
+
+
+def _top_logprobs_response(top: dict[str, float]) -> dict[str, Any]:
+    """Shape a ``max_tokens=1, logprobs=K`` response."""
+    return {
+        "choices": [
+            {
+                "text": next(iter(top)),
+                "logprobs": {
+                    "tokens": [next(iter(top))],
+                    "token_logprobs": [next(iter(top.values()))],
+                    "top_logprobs": [top],
+                },
+                "finish_reason": "length",
+            }
+        ],
+        "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
+    }
+
+
+def _make_backend(
+    handler: Callable[[httpx.Request], httpx.Response], **kwargs: Any
+) -> ApiScoringBackend:
+    """Construct an ApiScoringBackend whose httpx client uses a MockTransport."""
+    backend = ApiScoringBackend(
+        base_url="https://mock.example/",
+        model_name="mock-model",
+        api_key=None,
+        max_retries=0,
+        **kwargs,
+    )
+    # Swap in a client wired to a MockTransport, pre-empting the
+    # backend's lazy init by setting the client before any method
+    # would create the real one.
+    backend._client = httpx.Client(
+        base_url="https://mock.example",
+        transport=httpx.MockTransport(handler),
+        headers={"content-type": "application/json"},
+    )
+    return backend
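+# Typical use, mirroring the tests below: hand ``_make_backend`` any
+# ``handler`` and every HTTP call the backend makes is answered by it:
+#
+#     backend = _make_backend(lambda _: httpx.Response(200, json=...))
+#     backend.logprob_of("prompt", " completion")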
+
+
+@pytest.fixture(autouse=True)
+def _clear_api_key_env(monkeypatch: pytest.MonkeyPatch) -> None:
+    """Keep tests hermetic — ignore any API keys the dev machine has set."""
+    monkeypatch.delenv("SWAY_API_KEY", raising=False)
+    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
+
+
+class TestLogprobOf:
+    def test_sums_completion_tokens(self) -> None:
+        """A prompt/completion pair that splits cleanly on a token boundary.
+
+        With ``prompt="The cat"`` + ``completion=" sat."`` the split
+        lands after the second token, so the completion's logprobs
+        are exactly the last two tokens' values.
+        """
+
+        def handler(request: httpx.Request) -> httpx.Response:
+            body = _json.loads(request.content)
+            assert body["echo"] is True
+            assert body["max_tokens"] == 0
+            # Echo tokens for "The cat sat." — 4 tokens; the first is
+            # None (no left context), the rest have logprobs.
+            return httpx.Response(
+                200,
+                json=_echo_response(
+                    tokens=["The", " cat", " sat", "."],
+                    logprobs=[None, -2.1, -1.7, -0.9],
+                ),
+            )
+
+        backend = _make_backend(handler)
+        # prompt_chars=7 ("The cat"): after "The" total=3, after " cat"
+        # total=7 (>=7) → split index 2. Completion = [" sat", "."].
+        lp = backend.logprob_of(prompt="The cat", completion=" sat.")
+        assert lp == pytest.approx(-1.7 + -0.9)
+
+    def test_mid_token_prompt_leans_conservative(self) -> None:
+        """When the prompt boundary falls inside a token, that token
+        lands on the prompt side — over-counting the prompt and
+        under-attributing mass to the completion. Documented in
+        ``_split_echo_at_char``.
+        """
+
+        def handler(_: httpx.Request) -> httpx.Response:
+            return httpx.Response(
+                200,
+                json=_echo_response(
+                    tokens=["The", " cat", " sat", "."],
+                    logprobs=[None, -2.1, -1.7, -0.9],
+                ),
+            )
+
+        backend = _make_backend(handler)
+        # prompt_chars=4 ("The ") — boundary is inside " cat".
+        # Conservative: " cat" joins the prompt side; completion = [" sat", "."].
+        lp = backend.logprob_of(prompt="The ", completion="cat sat.")
+        assert lp == pytest.approx(-1.7 + -0.9)
+
+    def test_empty_completion_raises(self) -> None:
+        backend = _make_backend(lambda _: httpx.Response(200, json={}))
+        with pytest.raises(ProbeError, match="tokenized to zero"):
+            backend.logprob_of(prompt="hi", completion="")
+
+    def test_caches_repeated_calls(self) -> None:
+        """S07 cache dedupes identical (prompt, completion) pairs."""
+        calls = {"count": 0}
+
+        def handler(_: httpx.Request) -> httpx.Response:
+            calls["count"] += 1
+            return httpx.Response(
+                200,
+                json=_echo_response(
+                    tokens=["Hi", "!"],
+                    logprobs=[None, -3.0],
+                ),
+            )
+
+        backend = _make_backend(handler)
+        backend.logprob_of("Hi", "!")
+        backend.logprob_of("Hi", "!")
+        assert calls["count"] == 1
+
+
+class TestRollingLogprob:
+    def test_returns_numeric_array(self) -> None:
+        def handler(_: httpx.Request) -> httpx.Response:
+            return httpx.Response(
+                200,
+                json=_echo_response(
+                    tokens=["Hello", " world", "!"],
+                    logprobs=[None, -2.5, -1.5],
+                ),
+            )
+
+        backend = _make_backend(handler)
+        r = backend.rolling_logprob("Hello world!")
+        assert r.num_tokens == 3
+        assert r.logprobs.dtype == np.float32
+        # The leading None is dropped; the rest form the rolling array.
+        np.testing.assert_allclose(r.logprobs, [-2.5, -1.5])
+        assert r.total_logprob == pytest.approx(-4.0)
+
+    def test_perplexity_finite(self) -> None:
+        def handler(_: httpx.Request) -> httpx.Response:
+            return httpx.Response(
+                200,
+                json=_echo_response(
+                    tokens=["a", "b"],
+                    logprobs=[None, -1.0],
+                ),
+            )
+
+        backend = _make_backend(handler)
+        r = backend.rolling_logprob("ab")
+        assert r.perplexity > 1.0
+        assert np.isfinite(r.perplexity)
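+        # Only one scored logprob (-1.0) survives the leading None, so
+        # under the usual definition perplexity = exp(-mean logprob)
+        # the value here would be e ≈ 2.718; the assertions above
+        # deliberately pin only finiteness and > 1, not the formula.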
+
+
+class TestNextTokenDist:
+    def test_parses_top_logprobs_descending(self) -> None:
+        def handler(request: httpx.Request) -> httpx.Response:
+            body = _json.loads(request.content)
+            assert body["max_tokens"] == 1
+            assert body["echo"] is False
+            assert body["logprobs"] >= 1
+            return httpx.Response(
+                200,
+                json=_top_logprobs_response(
+                    # Logprobs that sum to < 1 in prob-space so the
+                    # residual is a real positive tail mass (~0.22).
+                    {" cat": -0.5, " dog": -2.0, " fish": -3.5, " bird": -5.0},
+                ),
+            )
+
+        backend = _make_backend(handler, vocab_size=50_000)
+        d = backend.next_token_dist("The quick brown fox chased the", top_k=4)
+        assert d.token_ids.shape == (4,)
+        assert d.logprobs.shape == (4,)
+        # Sorted descending by logprob.
+        assert np.all(np.diff(d.logprobs) <= 0)
+        assert d.logprobs[0] == pytest.approx(-0.5)
+        # vocab_size threads through; tail_logprob is a finite residual.
+        assert d.vocab_size == 50_000
+        assert d.tail_logprob is not None
+        assert d.tail_logprob < 0
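+        # Worked out: exp(-0.5) + exp(-2.0) + exp(-3.5) + exp(-5.0)
+        # ≈ 0.607 + 0.135 + 0.030 + 0.007 ≈ 0.779, leaving ≈ 0.221 of
+        # probability mass off-list, so tail_logprob ≈ log(0.221) ≈ -1.51.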
+
+    def test_unknown_vocab_produces_none_tail(self) -> None:
+        def handler(_: httpx.Request) -> httpx.Response:
+            return httpx.Response(200, json=_top_logprobs_response({"a": -0.5, "b": -1.5}))
+
+        backend = _make_backend(handler)  # no vocab_size
+        d = backend.next_token_dist("x", top_k=2)
+        assert d.tail_logprob is None
+
+
+class TestPreflight:
+    def test_passes_on_finite_dist(self) -> None:
+        def handler(_: httpx.Request) -> httpx.Response:
+            return httpx.Response(200, json=_top_logprobs_response({"ok": -0.1}))
+
+        backend = _make_backend(handler, vocab_size=100)
+        ok, reason = backend.preflight_finite_check()
+        assert ok, reason
+
+    def test_fails_on_nan_logprob(self) -> None:
+        # The JSON spec has no NaN, and httpx's ``json=`` kwarg rejects
+        # it. Hand-craft the body with Python's default json module
+        # (which does emit ``NaN``) so the fixture behaves like a
+        # lenient upstream.
+        body = _json.dumps(_top_logprobs_response({"nope": float("nan")}))
+
+        def handler(_: httpx.Request) -> httpx.Response:
+            return httpx.Response(
+                200,
+                content=body,
+                headers={"content-type": "application/json"},
+            )
+
+        backend = _make_backend(handler, vocab_size=100)
+        ok, reason = backend.preflight_finite_check()
+        assert not ok
+        assert "non-finite" in reason.lower() or "nan" in reason.lower()
+
+    def test_fails_on_http_error(self) -> None:
+        def handler(_: httpx.Request) -> httpx.Response:
+            return httpx.Response(500, text="internal server error")
+
+        backend = _make_backend(handler, vocab_size=100)
+        ok, reason = backend.preflight_finite_check()
+        assert not ok
+        assert "500" in reason or "server" in reason.lower() or "error" in reason.lower()
+
+
+class TestHttpErrorPaths:
+    def test_4xx_raises_probe_error(self) -> None:
+        def handler(_: httpx.Request) -> httpx.Response:
+            return httpx.Response(401, text="bad api key")
+
+        backend = _make_backend(handler)
+        with pytest.raises(ProbeError, match="401"):
+            backend.logprob_of("hi", "there")
+
+    def test_retry_path_eventually_succeeds(self) -> None:
+        """With retries enabled, a 503 followed by a 200 returns the 200 body."""
+        sequence = iter(
+            [
+                httpx.Response(503, text="overloaded"),
+                httpx.Response(
+                    200,
+                    json=_echo_response(
+                        tokens=["hi", " there"],
+                        logprobs=[None, -2.0],
+                    ),
+                ),
+            ]
+        )
+
+        def handler(_: httpx.Request) -> httpx.Response:
+            return next(sequence)
+
+        backend = ApiScoringBackend(
+            base_url="https://mock.example",
+            model_name="mock-model",
+            api_key=None,
+            max_retries=2,
+        )
+        # Force the client onto a MockTransport; no wait-patching is
+        # needed, since tenacity's wait_exponential is tolerable at
+        # this tiny retry count.
+        backend._client = httpx.Client(
+            base_url="https://mock.example",
+            transport=httpx.MockTransport(handler),
+        )
+        lp = backend.logprob_of("hi", " there")
+        assert lp == pytest.approx(-2.0)
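+        # The iterator doubles as an attempt counter: it serves exactly
+        # one 503 and one 200, so success requires at least one retry,
+        # and a third attempt would raise StopIteration and fail the
+        # test loudly.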
+
+
+class TestSplitEchoAtChar:
+    def test_prompt_at_token_boundary(self) -> None:
+        assert _split_echo_at_char(["The", " cat"], prompt_chars=3) == 1
+
+    def test_prompt_mid_token_includes_token_in_prompt(self) -> None:
+        # "Th" is 2 chars of "The" (3 chars) — prompt extends into the
+        # token; conservative: include the whole token on the prompt side.
+        assert _split_echo_at_char(["The", " cat"], prompt_chars=2) == 1
+
+    def test_prompt_past_all_tokens(self) -> None:
+        assert _split_echo_at_char(["a", "b"], prompt_chars=10) == 2
+
+    def test_zero_prompt(self) -> None:
+        assert _split_echo_at_char(["x"], prompt_chars=0) == 0
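+    # Full map for tokens ["The", " cat"] (cumulative chars 3, 7)
+    # under the conservative rule exercised above:
+    #   prompt_chars 0     -> 0  (nothing on the prompt side)
+    #   prompt_chars 1..3  -> 1  (boundary at or inside "The")
+    #   prompt_chars 4..7  -> 2  (" cat" absorbed as well)
+    #   prompt_chars 8+    -> 2  (clamped to the token count)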
+
+
+class TestConstructor:
+    def test_api_key_from_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        monkeypatch.setenv("SWAY_API_KEY", "env-token")
+        backend = ApiScoringBackend(base_url="http://x", model_name="m", max_retries=0)
+        assert backend._api_key == "env-token"
+
+    def test_openai_api_key_fallback(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        monkeypatch.delenv("SWAY_API_KEY", raising=False)
+        monkeypatch.setenv("OPENAI_API_KEY", "openai-token")
+        backend = ApiScoringBackend(base_url="http://x", model_name="m", max_retries=0)
+        assert backend._api_key == "openai-token"
+
+    def test_safe_for_concurrent_views_is_true(self) -> None:
+        """Flag is class-level; no instantiation needed."""
+        assert ApiScoringBackend.safe_for_concurrent_views is True
+
+    def test_cache_identity_stable(self) -> None:
+        backend = ApiScoringBackend(
+            base_url="http://host.example:8000",
+            model_name="llama-3-8b",
+            max_retries=0,
+        )
+        assert backend.cache_identity() == "api:http://host.example:8000:llama-3-8b"
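+        # The asserted identity composes "api:", the base_url, and the
+        # model name, so the same model served from a different host
+        # gets its own cache namespace.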