1 """S23 follow-up — HF-backend batched next_token_dist end-to-end.
2
3 Proves the S23 batched-forward path works against a real
4 HuggingFace + PEFT stack (not just the dummy backend's loop
5 fallback). The unit tests in ``tests/unit/test_batched_backend_s23``
6 pin the Protocol and probe-level contracts on the dummy backend;
7 this one rides the same SmolLM2-135M fixture every other slow+online
8 test uses and exercises the left-padded ``model.forward`` batched
9 code path in ``_HFView.next_token_dist_batch``.
10
11 What this test locks down:
12
13 1. Batched output is numerically equivalent to the serial per-prompt
14 output on the same prompts (within fp32 batch-reorder tolerance).
15 2. The instrumentation counters (``batches_sent``, ``batched_prompts``,
16 ``max_batch_size``) reflect at least one real batched forward.
17 3. The cache short-circuits per-prompt when the same prompt re-enters
18 a batch, so the batched counter doesn't inflate on repeat runs.
19
20 What this test *doesn't* do (deferred):
21
22 - The fortran-spec wall-time benchmark (≤ 60s vs 155s baseline) —
23 needs a 1.5B adapter and real GPU; this fixture is 135M on CPU.
24 - Real MLX batched forward — MLX backend still loops; real
25 ``mx.array`` padded forward is a separate follow-up.
26
27 Marked ``slow + online``.
28 """

from __future__ import annotations

from collections.abc import Iterator
from pathlib import Path

import numpy as np
import pytest

from dlm_sway.backends.hf import HuggingFaceDifferentialBackend
from dlm_sway.core.model import ModelSpec

pytestmark = [pytest.mark.slow, pytest.mark.online]
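
# The call shape these tests exercise, in miniature (a sketch only; the
# binding Protocol is pinned by the S23 unit tests on the dummy backend):
#
#     with backend.as_base() as view:
#         dists = view.next_token_dist_batch(prompts, top_k=32)  # one padded batched forward
#         one = view.next_token_dist(prompts[0], top_k=32)       # serial single-prompt path
#
# dists[i] corresponds to prompts[i]; each entry exposes token_ids and
# logprobs arrays for its top-k next-token candidates.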


# Same deterministic-LoRA build the other integration tests use.
def _build_random_lora_adapter(base_dir: Path, out_dir: Path) -> None:
    import torch
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer

    torch.manual_seed(0)
    tokenizer = AutoTokenizer.from_pretrained(str(base_dir))
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    base = AutoModelForCausalLM.from_pretrained(str(base_dir), torch_dtype=torch.float32)
    cfg = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.0,
        bias="none",
        task_type="CAUSAL_LM",
    )
    peft_model = get_peft_model(base, cfg)
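    # PEFT initialises lora_B to zeros by default, which would make a fresh
    # adapter a functional no-op; randomising lora_B (seeded above) makes the
    # adapter genuinely perturb the base model's outputs.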
    with torch.no_grad():
        for name, param in peft_model.named_parameters():
            if "lora_B" in name:
                param.copy_(torch.randn_like(param) * 0.05)
    peft_model.save_pretrained(str(out_dir))
    tokenizer.save_pretrained(str(out_dir))


@pytest.fixture(scope="module")
def batched_adapter(tiny_model_dir: Path, tmp_path_factory: pytest.TempPathFactory) -> Path:
    adapter_dir = tmp_path_factory.mktemp("batched-s23-adapter")
    _build_random_lora_adapter(tiny_model_dir, adapter_dir)
    return adapter_dir


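# Module-scoped: one backend (and one model load) is shared by all three
# tests below. The instrumentation counters and the prompt cache therefore
# persist across tests, hence the explicit _inst.cache.clear() calls.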
@pytest.fixture(scope="module")
def hf_backend(
    tiny_model_dir: Path,
    batched_adapter: Path,
) -> Iterator[HuggingFaceDifferentialBackend]:
    backend = HuggingFaceDifferentialBackend(
        base_spec=ModelSpec(base=str(tiny_model_dir), kind="hf", dtype="fp32", device="cpu"),
        adapter_path=batched_adapter,
    )
    yield backend
    backend.close()


# Varied-length prompts so the left-padding path genuinely matters
# (a single length would let a misimplementation slip through).
_PROMPTS = [
    "The capital of France is",
    "Two plus two equals",
    "The quick brown fox jumps over the",
    "Paris",
]
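
# Left-padding, illustratively (token counts here are made up; only the idea
# matters): shorter prompts are padded on the *left*,
#
#     "Paris"                    -> [PAD PAD PAD PAD tok]
#     "The capital of France is" -> [tok tok tok tok tok]
#
# so every prompt's final real token lands in the last position and the
# batched next-token logits can be read from the same index for every row.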


def test_batched_output_matches_serial_on_real_model(
    hf_backend: HuggingFaceDifferentialBackend,
) -> None:
    """The batched forward's top-k logprobs must match the per-prompt
    serial forward's on the same prompts, within a tight fp32
    reorder tolerance.

    Rationale: left-padded batches reorder the underlying attention
    accumulations vs a single-prompt forward. We accept ~1e-4
    divergence — same bar S18's determinism golden uses on CPU.
    """
    # Run the batched pass first; the cache is cleared before the serial
    # pass so it can't serve identical results from this first call and
    # hide any real divergence.
    with hf_backend.as_base() as base_view:
        batched_base = base_view.next_token_dist_batch(_PROMPTS, top_k=32)

    # Clear the cache so the serial calls actually re-forward.
    hf_backend._inst.cache.clear()  # noqa: SLF001

    with hf_backend.as_base() as base_view:
        serial_base = [base_view.next_token_dist(p, top_k=32) for p in _PROMPTS]

    for i, (b, s) in enumerate(zip(batched_base, serial_base, strict=True)):
        # Token-id sets should be identical in the top-k slice
        # (ordering can swap on exact-tie logprobs, so compare as sets).
        assert set(b.token_ids.tolist()) == set(s.token_ids.tolist()), (
            f"prompt[{i}]={_PROMPTS[i]!r}: top-k token sets differ "
            f"(batched {b.token_ids.tolist()}, serial {s.token_ids.tolist()})"
        )
        # Top-5 logprobs should match within the fp32 reorder tolerance.
        np.testing.assert_allclose(
            sorted(b.logprobs.tolist(), reverse=True)[:5],
            sorted(s.logprobs.tolist(), reverse=True)[:5],
            atol=1e-4,
            rtol=1e-3,
            err_msg=f"prompt[{i}]={_PROMPTS[i]!r}: top-5 logprobs diverged",
        )


def test_batched_forward_fires_instrumentation(hf_backend: HuggingFaceDifferentialBackend) -> None:
    """A batched call on the HF backend must bump ``batches_sent`` and
    ``batched_prompts`` and update ``max_batch_size``. This is how the
    report footer knows to print the ``batches: N (avg=K)`` segment."""
    hf_backend._inst.cache.clear()  # noqa: SLF001
    stats = hf_backend._inst.stats  # noqa: SLF001
    before = (stats.batches_sent, stats.batched_prompts, stats.max_batch_size)

    with hf_backend.as_base() as base_view:
        out = base_view.next_token_dist_batch(_PROMPTS, top_k=16)

    assert len(out) == len(_PROMPTS)
    after = (stats.batches_sent, stats.batched_prompts, stats.max_batch_size)
    assert after[0] == before[0] + 1, f"expected one new batch, got {after[0] - before[0]}"
    assert after[1] == before[1] + len(_PROMPTS), (
        f"expected +{len(_PROMPTS)} batched prompts, got {after[1] - before[1]}"
    )
    assert after[2] >= len(_PROMPTS)


def test_batched_cache_short_circuits_repeat_prompts(
    hf_backend: HuggingFaceDifferentialBackend,
) -> None:
    """A second batched call with identical prompts hits the cache
    per-prompt. ``batches_sent`` must NOT increment a second time
    because no prompts missed."""
    hf_backend._inst.cache.clear()  # noqa: SLF001
    with hf_backend.as_base() as base_view:
        base_view.next_token_dist_batch(_PROMPTS, top_k=16)
    before_batches = hf_backend._inst.stats.batches_sent  # noqa: SLF001
    before_hits = hf_backend._inst.stats.cache_hits  # noqa: SLF001

    with hf_backend.as_base() as base_view:
        base_view.next_token_dist_batch(_PROMPTS, top_k=16)

    after_batches = hf_backend._inst.stats.batches_sent  # noqa: SLF001
    after_hits = hf_backend._inst.stats.cache_hits  # noqa: SLF001
    # No new batch — everything came from the cache.
    assert after_batches == before_batches, (
        f"second all-cache-hit call spuriously fired a batch ({before_batches} -> {after_batches})"
    )
    assert after_hits - before_hits == len(_PROMPTS), (
        f"expected {len(_PROMPTS)} fresh cache hits, got {after_hits - before_hits}"
    )