1 """Unit tests for Sprint 42's HF reward-model judge runtime."""
2
3 from __future__ import annotations
4
5 import builtins
6 from collections import deque
7 from types import SimpleNamespace
8 from unittest.mock import patch
9
10 import pytest
11
12 from dlm.preference import (
13 HfRewardModelJudge,
14 InvalidJudgeSpecError,
15 JudgeInvocationError,
16 JudgeUnavailableError,
17 )
18 from dlm.preference.judge import (
19 _default_reward_loader,
20 _encode_reward_input,
21 _extract_reward_scalar,
22 _move_to_device,
23 _resolve_reward_device,
24 )
25
26
class FakeScalar:
    """Stand-in for a zero-dimensional tensor exposing ``.item()``."""

    def __init__(self, value: float) -> None:
        # Held privately; only readable through item(), like a torch scalar.
        self._wrapped = value

    def item(self) -> float:
        """Return the wrapped float, mimicking ``torch.Tensor.item``."""
        return self._wrapped
33
34
class FakeLogits:
    """Stand-in for a logits tensor backed by a plain list of floats."""

    def __init__(self, values: list[float]) -> None:
        self._values = values

    def numel(self) -> int:
        """Number of elements, mirroring ``torch.Tensor.numel``."""
        return len(self._values)

    def reshape(self, *_shape: int) -> FakeLogits:
        """No-op reshape; callers only chain further calls off the result."""
        return self

    def __getitem__(self, idx: int) -> FakeScalar:
        """Index into the values, returning an ``.item()``-capable scalar."""
        return FakeScalar(self._values[idx])
47
48
class FakeBatch(dict[str, object]):
    """Dict-based stand-in for a tokenizer batch with a fluent ``.to``."""

    def to(self, device: str) -> FakeBatch:
        """Record the requested device under a sentinel key and return self."""
        self.update({"__device__": device})
        return self
53
54
class FakeTensor:
    """Minimal tensor double that remembers the device it was moved to."""

    # Last device passed to .to(); None until the first move.
    device: str | None

    def __init__(self) -> None:
        self.device = None

    def to(self, device: str) -> FakeTensor:
        """Record the device and return self, torch-style."""
        self.device = device
        return self
62
63
class FakeTokenizer:
    """Recording tokenizer double for the reward-judge encoding paths.

    Every interaction is appended to ``calls`` as a
    ``(call_name, positional_args, keyword_args)`` tuple so tests can assert
    on exactly how the judge drove the tokenizer.

    Keyword flags:
        use_chat_template: when True, expose a ``chat_template`` attribute so
            the judge's template-detection path activates.
        template_error: exception raised by ``apply_chat_template`` to test
            the fallback-on-error path.
        template_returns_non_string: make ``apply_chat_template`` return a
            non-string value to test the fallback-on-bad-output path.
    """

    def __init__(
        self,
        *,
        use_chat_template: bool = False,
        template_error: Exception | None = None,
        template_returns_non_string: bool = False,
    ) -> None:
        self.calls: list[tuple[str, tuple[object, ...], dict[str, object]]] = []
        self._template_error = template_error
        self._template_returns_non_string = template_returns_non_string
        # Only define the attribute when requested so attribute-existence
        # probing sees a template-less tokenizer by default.
        if use_chat_template:
            self.chat_template = "fake-template"

    def apply_chat_template(self, messages: list[dict[str, str]], **kwargs: object) -> str:
        self.calls.append(("apply_chat_template", (messages,), dict(kwargs)))
        if self._template_error is not None:
            raise self._template_error
        if self._template_returns_non_string:
            # Fix: previously returned "" which is still a str (the
            # type-ignore betrayed the intent). Return an actual non-string
            # so the non-string fallback is genuinely exercised.
            return None  # type: ignore[return-value]
        return f"templated::{messages[0]['content']}::{messages[1]['content']}"

    def __call__(self, *args: object, **kwargs: object) -> FakeBatch:
        self.calls.append(("tokenizer", args, dict(kwargs)))
        return FakeBatch({"input_ids": object()})
89
90
class FakeModel:
    """Callable reward-model double that replays a queue of logits.

    Each invocation records its keyword arguments in ``calls`` and pops the
    next queued logits object, wrapping it in an object exposing ``.logits``
    like a transformers model output.
    """

    def __init__(self, logits: list[FakeLogits]) -> None:
        self._logits = deque(logits)
        self.calls: list[dict[str, object]] = []

    def __call__(self, **kwargs: object):
        self.calls.append(dict(kwargs))
        # SimpleNamespace stands in for the transformers output dataclass.
        return SimpleNamespace(logits=self._logits.popleft())
104
105
class FakeTorchScalarLogits:
    """Scalar logits double exposing ``numel`` and ``item`` but no indexing."""

    def __init__(self, value: float) -> None:
        self._scalar = value

    def numel(self) -> int:
        """Always a single element, steering readers to the item() path."""
        return 1

    def item(self) -> float:
        """Return the stored scalar value."""
        return self._scalar
115
116
class FakePretrainedRewardModel:
    """Model double that tracks ``.to()`` and ``.eval()`` calls by the loader."""

    def __init__(self) -> None:
        # Last device requested via .to(); None until moved.
        self.device: str | None = None
        # Flipped to True once eval() runs.
        self.eval_called = False

    def to(self, device: str) -> FakePretrainedRewardModel:
        """Mirror torch's fluent API: record placement, return self."""
        self.device = device
        return self

    def eval(self) -> None:
        """Record that inference mode was requested."""
        self.eval_called = True
128
129
130 def _loader_factory(tokenizer: FakeTokenizer, model: FakeModel):
131 calls: list[tuple[str, str]] = []
132
133 def _loader(hf_id: str, device: str):
134 calls.append((hf_id, device))
135 from dlm.preference.judge import _LoadedRewardJudge
136
137 return _LoadedRewardJudge(model=model, tokenizer=tokenizer, device=device)
138
139 return calls, _loader
140
141
class TestHfRewardModelJudge:
    """Behavioural tests for ``HfRewardModelJudge`` driven by fake loaders."""

    def test_blank_selector_is_rejected(self) -> None:
        """A whitespace-only model selector raises ``InvalidJudgeSpecError``."""
        with pytest.raises(InvalidJudgeSpecError, match="include a model id"):
            HfRewardModelJudge(" ")

    def test_scores_pair_and_caches_loaded_bundle(self) -> None:
        """One ``score_pair`` call loads the bundle once and scores both sides."""
        tokenizer = FakeTokenizer()
        model = FakeModel([FakeLogits([0.2]), FakeLogits([0.9])])
        calls, loader = _loader_factory(tokenizer, model)
        judge = HfRewardModelJudge("reward/model", device="cpu", loader=loader)

        score = judge.score_pair("What is DGEMM?", "bad", "good")

        assert score.score_a == pytest.approx(0.2)
        assert score.score_b == pytest.approx(0.9)
        assert score.preferred == "b"
        # The loader ran exactly once even though two candidates were scored.
        assert calls == [("reward/model", "cpu")]
        assert len(model.calls) == 2
        first_call = tokenizer.calls[0]
        # Without a chat template the tokenizer is invoked directly with the
        # prompt as text and the candidate as text_pair.
        assert first_call[0] == "tokenizer"
        assert first_call[1] == ("What is DGEMM?",)
        assert first_call[2]["text_pair"] == "bad"

    def test_chat_template_path_is_used_when_available(self) -> None:
        """Tokenizers exposing ``chat_template`` take the templated path."""
        tokenizer = FakeTokenizer(use_chat_template=True)
        model = FakeModel([FakeLogits([0.4]), FakeLogits([0.1])])
        _, loader = _loader_factory(tokenizer, model)
        judge = HfRewardModelJudge("reward/model", device="cpu", loader=loader)

        judge.score_pair("prompt", "cand-a", "cand-b")

        assert tokenizer.calls[0][0] == "apply_chat_template"
        assert tokenizer.calls[1][0] == "tokenizer"
        # The rendered template string is what actually gets tokenized.
        assert tokenizer.calls[1][1] == ("templated::prompt::cand-a",)

    def test_non_scalar_logits_are_refused(self) -> None:
        """A logits tensor with more than one element is a hard error."""
        tokenizer = FakeTokenizer()
        model = FakeModel([FakeLogits([0.1, 0.2]), FakeLogits([0.3])])
        _, loader = _loader_factory(tokenizer, model)
        judge = HfRewardModelJudge("reward/model", device="cpu", loader=loader)

        with pytest.raises(JudgeInvocationError, match="single scalar logit"):
            judge.score_pair("prompt", "a", "b")

    def test_missing_logits_are_refused(self) -> None:
        """Model outputs lacking a ``.logits`` attribute are rejected."""
        tokenizer = FakeTokenizer()

        class NoLogitsModel:
            def __call__(self, **_kwargs: object):
                class Output:
                    pass

                return Output()

        calls: list[tuple[str, str]] = []

        def loader(hf_id: str, device: str):
            calls.append((hf_id, device))
            from dlm.preference.judge import _LoadedRewardJudge

            return _LoadedRewardJudge(model=NoLogitsModel(), tokenizer=tokenizer, device=device)

        judge = HfRewardModelJudge("reward/model", device="cpu", loader=loader)

        with pytest.raises(JudgeInvocationError, match="no `.logits`"):
            judge.score_pair("prompt", "a", "b")
        # The bundle still loaded exactly once before the failure surfaced.
        assert calls == [("reward/model", "cpu")]

    def test_missing_torch_is_reported(self) -> None:
        """An unimportable torch surfaces as ``JudgeUnavailableError``."""
        tokenizer = FakeTokenizer()
        model = FakeModel([FakeLogits([0.2]), FakeLogits([0.1])])
        _, loader = _loader_factory(tokenizer, model)
        judge = HfRewardModelJudge("reward/model", device="cpu", loader=loader)
        real_import = builtins.__import__

        def fake_import(name: str, *args: object, **kwargs: object):
            # Only torch is blocked; every other import proceeds normally.
            if name == "torch":
                raise ImportError("no torch here")
            return real_import(name, *args, **kwargs)

        with (
            patch("builtins.__import__", side_effect=fake_import),
            pytest.raises(JudgeUnavailableError, match="requires torch"),
        ):
            judge.score_pair("prompt", "a", "b")

    def test_default_loader_path_is_used_when_no_loader_is_supplied(self) -> None:
        """Omitting ``loader`` routes through ``_default_reward_loader``."""
        tokenizer = FakeTokenizer()
        model = FakeModel([FakeLogits([0.7]), FakeLogits([0.1])])

        def fake_default_loader(hf_id: str, device: str):
            from dlm.preference.judge import _LoadedRewardJudge

            assert hf_id == "reward/model"
            assert device == "cpu"
            return _LoadedRewardJudge(model=model, tokenizer=tokenizer, device=device)

        judge = HfRewardModelJudge("reward/model", device="cpu")
        with patch("dlm.preference.judge._default_reward_loader", side_effect=fake_default_loader):
            score = judge.score_pair("prompt", "a", "b")

        assert score.preferred == "a"
244
245
class TestHfRewardHelpers:
    """Unit tests for the module-level reward-judge helper functions."""

    def test_default_reward_loader_requires_transformers(self) -> None:
        """A missing ``transformers`` package maps to ``JudgeUnavailableError``."""
        real_import = builtins.__import__

        def fake_import(name: str, *args: object, **kwargs: object):
            # Only transformers is blocked; other imports proceed normally.
            if name == "transformers":
                raise ImportError("missing transformers")
            return real_import(name, *args, **kwargs)

        with (
            patch("builtins.__import__", side_effect=fake_import),
            pytest.raises(JudgeUnavailableError, match="requires transformers"),
        ):
            _default_reward_loader("reward/model", "cpu")

    def test_default_reward_loader_moves_model_and_sets_eval(self) -> None:
        """The default loader places the model on the device and calls eval()."""
        model = FakePretrainedRewardModel()
        tokenizer = FakeTokenizer()

        class AutoModelForSequenceClassification:
            @staticmethod
            def from_pretrained(hf_id: str) -> FakePretrainedRewardModel:
                assert hf_id == "reward/model"
                return model

        class AutoTokenizer:
            @staticmethod
            def from_pretrained(hf_id: str) -> FakeTokenizer:
                assert hf_id == "reward/model"
                return tokenizer

        fake_transformers = SimpleNamespace(
            AutoModelForSequenceClassification=AutoModelForSequenceClassification,
            AutoTokenizer=AutoTokenizer,
        )

        # Injecting a fake module via sys.modules avoids the heavyweight
        # real transformers import entirely.
        with patch.dict("sys.modules", {"transformers": fake_transformers}):
            loaded = _default_reward_loader("reward/model", "mps")

        assert loaded.model is model
        assert loaded.tokenizer is tokenizer
        assert loaded.device == "mps"
        assert model.device == "mps"
        assert model.eval_called is True

    def test_resolve_reward_device_respects_explicit_request(self) -> None:
        """An explicit device string is passed through untouched."""
        assert _resolve_reward_device("cuda:3") == "cuda:3"

    def test_resolve_reward_device_returns_cpu_when_torch_is_missing(self) -> None:
        """The "auto" device falls back to cpu when torch cannot be imported."""
        real_import = builtins.__import__

        def fake_import(name: str, *args: object, **kwargs: object):
            if name == "torch":
                raise ImportError("no torch")
            return real_import(name, *args, **kwargs)

        with patch("builtins.__import__", side_effect=fake_import):
            assert _resolve_reward_device("auto") == "cpu"

    def test_resolve_reward_device_prefers_cuda_then_mps_then_cpu(self) -> None:
        """Auto-detection prefers cuda, then mps, then cpu, in that order."""
        torch_cuda = SimpleNamespace(
            cuda=SimpleNamespace(is_available=lambda: True),
            backends=SimpleNamespace(mps=SimpleNamespace(is_available=lambda: True)),
        )
        torch_mps = SimpleNamespace(
            cuda=SimpleNamespace(is_available=lambda: False),
            backends=SimpleNamespace(mps=SimpleNamespace(is_available=lambda: True)),
        )
        torch_cpu = SimpleNamespace(
            cuda=SimpleNamespace(is_available=lambda: False),
            backends=SimpleNamespace(mps=SimpleNamespace(is_available=lambda: False)),
        )

        with patch.dict("sys.modules", {"torch": torch_cuda}):
            assert _resolve_reward_device("auto") == "cuda"
        with patch.dict("sys.modules", {"torch": torch_mps}):
            assert _resolve_reward_device("auto") == "mps"
        with patch.dict("sys.modules", {"torch": torch_cpu}):
            assert _resolve_reward_device("auto") == "cpu"

    def test_encode_reward_input_falls_back_when_template_raises(self) -> None:
        """A raising chat template falls back to plain pair tokenization."""
        tokenizer = FakeTokenizer(use_chat_template=True, template_error=RuntimeError("boom"))

        encoded = _encode_reward_input(tokenizer, "prompt", "candidate")

        assert isinstance(encoded, FakeBatch)
        # The last call is the plain tokenizer invocation used as fallback.
        assert tokenizer.calls[-1][0] == "tokenizer"
        assert tokenizer.calls[-1][1] == ("prompt",)
        assert tokenizer.calls[-1][2]["text_pair"] == "candidate"

    def test_encode_reward_input_falls_back_when_template_returns_non_string(self) -> None:
        """A template yielding an unusable value also triggers the fallback."""
        tokenizer = FakeTokenizer(use_chat_template=True, template_returns_non_string=True)

        encoded = _encode_reward_input(tokenizer, "prompt", "candidate")

        assert isinstance(encoded, FakeBatch)
        assert tokenizer.calls[-1][0] == "tokenizer"

    def test_move_to_device_moves_mapping_values(self) -> None:
        """Mapping values with a ``.to`` method are moved; others pass through."""
        tensor = FakeTensor()
        payload = {"input_ids": tensor, "meta": "keep"}

        moved = _move_to_device(payload, "mps")

        assert moved["input_ids"] is tensor
        assert tensor.device == "mps"
        assert moved["meta"] == "keep"

    def test_move_to_device_returns_unmodified_non_mapping_values(self) -> None:
        """Non-mapping payloads are returned unchanged (identity preserved)."""
        value = object()
        assert _move_to_device(value, "cpu") is value

    def test_extract_reward_scalar_uses_item_fallback(self) -> None:
        """A one-element value exposing only ``item()`` is read via that path."""
        assert _extract_reward_scalar(FakeTorchScalarLogits(0.75)) == pytest.approx(0.75)

    def test_extract_reward_scalar_rejects_unreadable_values(self) -> None:
        """A one-element value with no readable scalar raises an error."""

        class UnreadableLogits:
            def numel(self) -> int:
                return 1

        with pytest.raises(JudgeInvocationError, match="unreadable scalar logit"):
            _extract_reward_scalar(UnreadableLogits())