tenseleyflow/sway / b7d17f4


tests/integration: HF-backend batched next_token_dist end-to-end (S23 followup)

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA: b7d17f4c88a585d75e58514032db56a540131edb
Parents: 2f53720
Tree: d92fe86

1 changed file

Status  File                                        +     -
A       tests/integration/test_hf_batched_s23.py   181    0

tests/integration/test_hf_batched_s23.py (added)
@@ -0,0 +1,181 @@
+"""S23 follow-up — HF-backend batched next_token_dist end-to-end.
+
+Proves the S23 batched-forward path works against a real
+HuggingFace + PEFT stack (not just the dummy backend's loop
+fallback). The unit tests in ``tests/unit/test_batched_backend_s23``
+pin the Protocol and probe-level contracts on the dummy backend;
+this one rides the same SmolLM2-135M fixture every other slow+online
+test uses and exercises the left-padded ``model.forward`` batched
+code path in ``_HFView.next_token_dist_batch``.
+
+What this test locks down:
+
+1. Batched output is numerically equivalent to the serial per-prompt
+   output on the same prompts (within fp32 batch-reorder tolerance).
+2. The instrumentation counters (``batches_sent``, ``batched_prompts``,
+   ``max_batch_size``) reflect at least one real batched forward.
+3. The cache short-circuits per-prompt when the same prompt re-enters
+   a batch, so the batched counter doesn't inflate on repeat runs.
+
+What this test *doesn't* do (deferred):
+
+- The fortran-spec wall-time benchmark (≤ 60s vs 155s baseline) —
+  needs a 1.5B adapter and real GPU; this fixture is 135M on CPU.
+- Real MLX batched forward — MLX backend still loops; real
+  ``mx.array`` padded forward is a separate follow-up.
+
+Marked ``slow + online``.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+import numpy as np
+import pytest
+
+from dlm_sway.backends.hf import HuggingFaceDifferentialBackend
+from dlm_sway.core.model import ModelSpec
+
+pytestmark = [pytest.mark.slow, pytest.mark.online]
+
+
+# Same deterministic-LoRA build the other integration tests use.
+def _build_random_lora_adapter(base_dir: Path, out_dir: Path) -> None:
+    import torch
+    from peft import LoraConfig, get_peft_model
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    torch.manual_seed(0)
+    tokenizer = AutoTokenizer.from_pretrained(str(base_dir))
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    base = AutoModelForCausalLM.from_pretrained(str(base_dir), torch_dtype=torch.float32)
+    cfg = LoraConfig(
+        r=8,
+        lora_alpha=16,
+        target_modules=["q_proj", "v_proj"],
+        lora_dropout=0.0,
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
+    peft_model = get_peft_model(base, cfg)
+    with torch.no_grad():
+        for name, param in peft_model.named_parameters():
+            if "lora_B" in name:
+                param.copy_(torch.randn_like(param) * 0.05)
+    peft_model.save_pretrained(str(out_dir))
+    tokenizer.save_pretrained(str(out_dir))
+
+
+@pytest.fixture(scope="module")
+def batched_adapter(tiny_model_dir: Path, tmp_path_factory: pytest.TempPathFactory) -> Path:
+    adapter_dir = tmp_path_factory.mktemp("batched-s23-adapter")
+    _build_random_lora_adapter(tiny_model_dir, adapter_dir)
+    return adapter_dir
+
+
+@pytest.fixture(scope="module")
+def hf_backend(tiny_model_dir: Path, batched_adapter: Path) -> HuggingFaceDifferentialBackend:
+    backend = HuggingFaceDifferentialBackend(
+        base_spec=ModelSpec(base=str(tiny_model_dir), kind="hf", dtype="fp32", device="cpu"),
+        adapter_path=batched_adapter,
+    )
+    yield backend
+    backend.close()
+
+
+# Varied-length prompts so the left-padding path genuinely matters
+# (a single length would let a misimplementation slip through).
+_PROMPTS = [
+    "The capital of France is",
+    "Two plus two equals",
+    "The quick brown fox jumps over the",
+    "Paris",
+]
+
+
+def test_batched_output_matches_serial_on_real_model(
+    hf_backend: HuggingFaceDifferentialBackend,
+) -> None:
+    """The batched forward's top-k logprobs must match the per-prompt
+    serial forward's on the same prompts, within a tight fp32
+    reorder tolerance.
+
+    Rationale: left-padded batches reorder the underlying attention
+    accumulations vs a single-prompt forward. We accept ~1e-4
+    divergence — same bar S18's determinism golden uses on CPU.
+    """
+    # Fresh views to avoid the cache serving identical results from
+    # the first call and hiding any real divergence.
+    with hf_backend.as_base() as base_view:
+        batched_base = base_view.next_token_dist_batch(_PROMPTS, top_k=32)
+
+    # Clear the cache so the serial calls actually re-forward.
+    hf_backend._inst.cache.clear()  # noqa: SLF001
+
+    with hf_backend.as_base() as base_view:
+        serial_base = [base_view.next_token_dist(p, top_k=32) for p in _PROMPTS]
+
+    for i, (b, s) in enumerate(zip(batched_base, serial_base, strict=True)):
+        # Token-id sets should be identical in the top-k slice
+        # (ordering can swap on exact-tie logprobs, so compare as sets).
+        assert set(b.token_ids.tolist()) == set(s.token_ids.tolist()), (
+            f"prompt[{i}]={_PROMPTS[i]!r}: top-k token sets differ "
+            f"(batched {b.token_ids.tolist()}, serial {s.token_ids.tolist()})"
+        )
+        # Top-5 logprobs should match within the fp32 reorder tol.
+        np.testing.assert_allclose(
+            sorted(b.logprobs.tolist(), reverse=True)[:5],
+            sorted(s.logprobs.tolist(), reverse=True)[:5],
+            atol=1e-4,
+            rtol=1e-3,
+            err_msg=f"prompt[{i}]={_PROMPTS[i]!r}: top-5 logprobs diverged",
+        )
+
+
+def test_batched_forward_fires_instrumentation(hf_backend: HuggingFaceDifferentialBackend) -> None:
+    """A batched call on the HF backend must bump ``batches_sent``
+    and ``batched_prompts`` and update ``max_batch_size``. This is
+    how the report footer knows to print the ``batches: N (avg=K)``
+    segment."""
+    hf_backend._inst.cache.clear()  # noqa: SLF001
+    stats = hf_backend._inst.stats  # noqa: SLF001
+    before = (stats.batches_sent, stats.batched_prompts, stats.max_batch_size)
+
+    with hf_backend.as_base() as base_view:
+        out = base_view.next_token_dist_batch(_PROMPTS, top_k=16)
+
+    assert len(out) == len(_PROMPTS)
+    after = (stats.batches_sent, stats.batched_prompts, stats.max_batch_size)
+    assert after[0] == before[0] + 1, f"expected one new batch, got {after[0] - before[0]}"
+    assert after[1] == before[1] + len(_PROMPTS), (
+        f"expected +{len(_PROMPTS)} batched prompts, got {after[1] - before[1]}"
+    )
+    assert after[2] >= len(_PROMPTS)
+
+
+def test_batched_cache_short_circuits_repeat_prompts(
+    hf_backend: HuggingFaceDifferentialBackend,
+) -> None:
+    """Second batched call with identical prompts hits the cache
+    per-prompt. ``batches_sent`` must NOT increment a second time
+    because no prompts missed."""
+    hf_backend._inst.cache.clear()  # noqa: SLF001
+    with hf_backend.as_base() as base_view:
+        base_view.next_token_dist_batch(_PROMPTS, top_k=16)
+    before_batches = hf_backend._inst.stats.batches_sent  # noqa: SLF001
+    before_hits = hf_backend._inst.stats.cache_hits  # noqa: SLF001
+
+    with hf_backend.as_base() as base_view:
+        base_view.next_token_dist_batch(_PROMPTS, top_k=16)
+
+    after_batches = hf_backend._inst.stats.batches_sent  # noqa: SLF001
+    after_hits = hf_backend._inst.stats.cache_hits  # noqa: SLF001
+    # No new batch — everything came from the cache.
+    assert after_batches == before_batches, (
+        f"second all-cache-hit call spuriously fired a batch ({before_batches} → {after_batches})"
+    )
+    assert after_hits - before_hits == len(_PROMPTS), (
+        f"expected {len(_PROMPTS)} fresh cache hits, got {after_hits - before_hits}"
+    )
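
For orientation, the left-padded batched forward that ``_HFView.next_token_dist_batch`` is described as exercising can be sketched roughly as below. This is an illustrative assumption, not the project's actual implementation: the helper name, signature, and return shape are invented for the example, and it relies only on a stock HuggingFace tokenizer/causal-LM pair.

    # Illustrative sketch only (not dlm_sway code). Assumes a plain HF
    # AutoTokenizer / AutoModelForCausalLM pair running fp32 on CPU.
    import torch

    def sketch_next_token_dist_batch(model, tokenizer, prompts, top_k=32):
        # Left padding keeps every prompt's final real token at position -1,
        # so one forward over the padded batch yields all next-token logits.
        tokenizer.padding_side = "left"
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token
        enc = tokenizer(prompts, return_tensors="pt", padding=True)
        with torch.no_grad():
            out = model(input_ids=enc.input_ids, attention_mask=enc.attention_mask)
        logprobs = torch.log_softmax(out.logits[:, -1, :], dim=-1)
        top = torch.topk(logprobs, k=top_k, dim=-1)
        # One (token_ids, logprobs) pair per prompt, mirroring the serial path.
        return [
            (ids.tolist(), vals.tolist())
            for ids, vals in zip(top.indices, top.values)
        ]

Called on ``_PROMPTS`` with the SmolLM2-135M fixture, a sketch like this would return one top-k pair per prompt, which mirrors the per-prompt distributions the equality test compares against the serial ``next_token_dist`` calls.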