Python · 3208 bytes Raw Blame History
1 """S13 prove-the-value (§F7): ``ApiScoringBackend`` against a real Ollama.
2
3 **Opt-in.** Skipped unless ``SWAY_OLLAMA_URL`` is set (typically
4 ``http://localhost:11434``). Also needs ``SWAY_OLLAMA_MODEL`` — the
5 name of a model already pulled via ``ollama pull <name>``. A minimal
6 run::
7
8 ollama pull llama3.2:1b
9 ollama serve &
10 SWAY_OLLAMA_URL=http://localhost:11434 \\
11 SWAY_OLLAMA_MODEL=llama3.2:1b \\
12 uv run pytest tests/integration/test_api_ollama.py -v
13
14 What the test proves:
15
16 1. The backend talks to a real OpenAI-compatible endpoint without
17 crashing on any of its three scoring primitives
18 (``logprob_of``, ``rolling_logprob``, ``next_token_dist``).
19 2. Preflight passes (non-finite logprobs would surface here).
20 3. Wall time per call is in a sane range — documents the latency
21 budget the sprint's "≤3× HF backend, ≤1.5× with concurrent_probes=4"
22 claim rests on.
23
24 This test is the F7 claim's concrete backing: ``sway`` can score
25 hosted-inference endpoints end-to-end, not just local HF loads.
26 """
27
28 from __future__ import annotations
29
30 import math
31 import os
32 import time
33 from collections.abc import Iterator
34
35 import pytest
36
# Gate configuration, read once at import time; both variables must be set
# for this module to run (enforced by the skipif marker below).
_ollama_url = os.environ.get("SWAY_OLLAMA_URL")
_ollama_model = os.environ.get("SWAY_OLLAMA_MODEL")

# Applied to every test in this module: the suite is slow (real network
# round-trips), requires connectivity, and is opt-in via the two env vars.
pytestmark = [
    pytest.mark.slow,
    pytest.mark.online,
    pytest.mark.skipif(
        not _ollama_url or not _ollama_model,
        reason="set SWAY_OLLAMA_URL + SWAY_OLLAMA_MODEL to run this test",
    ),
]

# Optional transitive deps of the API backend: skip collection (rather than
# fail at import) when the extras aren't installed.
pytest.importorskip("httpx")
pytest.importorskip("tenacity")

# Must stay below the importorskip calls, hence the E402 suppression.
from dlm_sway.backends.api import ApiScoringBackend  # noqa: E402
53
54
@pytest.fixture(scope="module")
def backend() -> Iterator[ApiScoringBackend]:
    """Module-scoped backend wired to the Ollama endpoint from the env.

    Built once for the whole module (each connection/warm-up is expensive)
    and closed after the last test has run.
    """
    # The skipif marker guarantees both vars are present; these asserts
    # only narrow ``str | None`` to ``str`` for the type-checker.
    assert _ollama_url is not None
    assert _ollama_model is not None
    config = dict(
        base_url=_ollama_url,
        model_name=_ollama_model,
        api_key=None,  # a default local Ollama serves unauthenticated
        max_retries=1,
        timeout_s=60.0,
    )
    scorer = ApiScoringBackend(**config)
    yield scorer
    scorer.close()
68
69
70 def test_preflight_passes(backend: ApiScoringBackend) -> None:
71 ok, reason = backend.preflight_finite_check()
72 assert ok, reason
73
74
75 def test_logprob_of_finite(backend: ApiScoringBackend) -> None:
76 t0 = time.perf_counter()
77 lp = backend.logprob_of(
78 prompt="The capital of France is",
79 completion=" Paris.",
80 )
81 wall = time.perf_counter() - t0
82 print(f"\n logprob_of wall: {wall:.2f}s")
83 assert math.isfinite(lp)
84 assert lp < 0.0, "logprobs of any non-empty text are negative"
85
86
87 def test_rolling_logprob_shape(backend: ApiScoringBackend) -> None:
88 r = backend.rolling_logprob("Hello world. This is a sentence.")
89 assert r.num_tokens >= 2
90 assert r.logprobs.size == r.num_tokens - 1
91 assert math.isfinite(r.total_logprob)
92 assert math.isfinite(r.perplexity)
93 assert r.perplexity > 1.0
94
95
96 def test_next_token_dist_shape(backend: ApiScoringBackend) -> None:
97 d = backend.next_token_dist("The quick brown fox jumps over the", top_k=8)
98 import numpy as np
99
100 assert d.logprobs.size <= 8
101 assert np.all(np.isfinite(d.logprobs))
102 # Descending by probability.
103 assert np.all(np.diff(d.logprobs) <= 1e-6)