"""Unit tests for Sprint 42's HF reward-model judge runtime."""

from __future__ import annotations

import builtins
from collections import deque
from types import SimpleNamespace
from unittest.mock import patch

import pytest

from dlm.preference import (
    HfRewardModelJudge,
    InvalidJudgeSpecError,
    JudgeInvocationError,
    JudgeUnavailableError,
)
from dlm.preference.judge import (
    _default_reward_loader,
    _encode_reward_input,
    _extract_reward_scalar,
    _move_to_device,
    _resolve_reward_device,
)


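# ---------------------------------------------------------------------------
# Test doubles: minimal stand-ins for the torch / transformers objects the
# judge runtime touches, so the suite runs without either dependency
# installed.
# ---------------------------------------------------------------------------


# Mimics a 0-d tensor value: exposes only `.item()`.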
class FakeScalar:
    def __init__(self, value: float) -> None:
        self._value = value

    def item(self) -> float:
        return self._value


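# Mimics a logits tensor: `.numel()`, a no-op `.reshape()`, and indexing that
# yields FakeScalar values.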
class FakeLogits:
    def __init__(self, values: list[float]) -> None:
        self._values = values

    def numel(self) -> int:
        return len(self._values)

    def reshape(self, *_shape: int) -> FakeLogits:
        return self

    def __getitem__(self, idx: int) -> FakeScalar:
        return FakeScalar(self._values[idx])


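# Mimics a tokenizer's BatchEncoding: a dict whose `.to()` records the device.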
class FakeBatch(dict[str, object]):
    def to(self, device: str) -> FakeBatch:
        self["__device__"] = device
        return self


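# Mimics a tensor whose `.to()` records the requested device.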
class FakeTensor:
    def __init__(self) -> None:
        self.device: str | None = None

    def to(self, device: str) -> FakeTensor:
        self.device = device
        return self


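# Records every call it receives and can simulate a chat template that
# succeeds, raises, or returns a non-string, exercising each encoding path.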
class FakeTokenizer:
    def __init__(
        self,
        *,
        use_chat_template: bool = False,
        template_error: Exception | None = None,
        template_returns_non_string: bool = False,
    ) -> None:
        self.calls: list[tuple[str, tuple[object, ...], dict[str, object]]] = []
        self._template_error = template_error
        self._template_returns_non_string = template_returns_non_string
        if use_chat_template:
            self.chat_template = "fake-template"

    def apply_chat_template(self, messages: list[dict[str, str]], **kwargs: object) -> str:
        self.calls.append(("apply_chat_template", (messages,), dict(kwargs)))
        if self._template_error is not None:
            raise self._template_error
        if self._template_returns_non_string:
            # A non-string result (here an arbitrary list) forces the caller
            # onto the fallback encoding path.
            return [101, 102]  # type: ignore[return-value]
        return f"templated::{messages[0]['content']}::{messages[1]['content']}"

    def __call__(self, *args: object, **kwargs: object) -> FakeBatch:
        self.calls.append(("tokenizer", args, dict(kwargs)))
        return FakeBatch({"input_ids": object()})


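# Mimics a sequence-classification model: pops one pre-seeded logits value per
# forward call and records the kwargs it was invoked with.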
class FakeModel:
    def __init__(self, logits: list[FakeLogits]) -> None:
        self._logits = deque(logits)
        self.calls: list[dict[str, object]] = []

    def __call__(self, **kwargs: object):
        self.calls.append(dict(kwargs))

        class Output:
            def __init__(self, logits: FakeLogits) -> None:
                self.logits = logits

        return Output(self._logits.popleft())


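# Scalar logits without `.reshape()` or indexing, so extraction must fall
# back to `.item()`.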
class FakeTorchScalarLogits:
    def __init__(self, value: float) -> None:
        self._value = value

    def numel(self) -> int:
        return 1

    def item(self) -> float:
        return self._value


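# Tracks the `.to()` / `.eval()` calls the default loader is expected to make.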
class FakePretrainedRewardModel:
    def __init__(self) -> None:
        self.device: str | None = None
        self.eval_called = False

    def to(self, device: str) -> FakePretrainedRewardModel:
        self.device = device
        return self

    def eval(self) -> None:
        self.eval_called = True


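# Builds an injectable loader that skips any HF download: it records each
# (hf_id, device) request and returns the supplied fakes as a loaded bundle.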
def _loader_factory(tokenizer: FakeTokenizer, model: FakeModel):
    calls: list[tuple[str, str]] = []

    def _loader(hf_id: str, device: str):
        calls.append((hf_id, device))
        from dlm.preference.judge import _LoadedRewardJudge

        return _LoadedRewardJudge(model=model, tokenizer=tokenizer, device=device)

    return calls, _loader


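# Behavioral tests for HfRewardModelJudge.score_pair: scoring, loader caching,
# chat-template use, and error reporting.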
class TestHfRewardModelJudge:
    def test_blank_selector_is_rejected(self) -> None:
        with pytest.raises(InvalidJudgeSpecError, match="include a model id"):
            HfRewardModelJudge(" ")

    def test_scores_pair_and_caches_loaded_bundle(self) -> None:
        tokenizer = FakeTokenizer()
        model = FakeModel([FakeLogits([0.2]), FakeLogits([0.9])])
        calls, loader = _loader_factory(tokenizer, model)
        judge = HfRewardModelJudge("reward/model", device="cpu", loader=loader)

        score = judge.score_pair("What is DGEMM?", "bad", "good")

        assert score.score_a == pytest.approx(0.2)
        assert score.score_b == pytest.approx(0.9)
        assert score.preferred == "b"
        assert calls == [("reward/model", "cpu")]
        assert len(model.calls) == 2
        first_call = tokenizer.calls[0]
        assert first_call[0] == "tokenizer"
        assert first_call[1] == ("What is DGEMM?",)
        assert first_call[2]["text_pair"] == "bad"

    def test_chat_template_path_is_used_when_available(self) -> None:
        tokenizer = FakeTokenizer(use_chat_template=True)
        model = FakeModel([FakeLogits([0.4]), FakeLogits([0.1])])
        _, loader = _loader_factory(tokenizer, model)
        judge = HfRewardModelJudge("reward/model", device="cpu", loader=loader)

        judge.score_pair("prompt", "cand-a", "cand-b")

        assert tokenizer.calls[0][0] == "apply_chat_template"
        assert tokenizer.calls[1][0] == "tokenizer"
        assert tokenizer.calls[1][1] == ("templated::prompt::cand-a",)

    def test_non_scalar_logits_are_refused(self) -> None:
        tokenizer = FakeTokenizer()
        model = FakeModel([FakeLogits([0.1, 0.2]), FakeLogits([0.3])])
        _, loader = _loader_factory(tokenizer, model)
        judge = HfRewardModelJudge("reward/model", device="cpu", loader=loader)

        with pytest.raises(JudgeInvocationError, match="single scalar logit"):
            judge.score_pair("prompt", "a", "b")

    def test_missing_logits_are_refused(self) -> None:
        tokenizer = FakeTokenizer()

        class NoLogitsModel:
            def __call__(self, **_kwargs: object):
                class Output:
                    pass

                return Output()

        calls: list[tuple[str, str]] = []

        def loader(hf_id: str, device: str):
            calls.append((hf_id, device))
            from dlm.preference.judge import _LoadedRewardJudge

            return _LoadedRewardJudge(model=NoLogitsModel(), tokenizer=tokenizer, device=device)

        judge = HfRewardModelJudge("reward/model", device="cpu", loader=loader)

        with pytest.raises(JudgeInvocationError, match="no `.logits`"):
            judge.score_pair("prompt", "a", "b")
        assert calls == [("reward/model", "cpu")]

    def test_missing_torch_is_reported(self) -> None:
        tokenizer = FakeTokenizer()
        model = FakeModel([FakeLogits([0.2]), FakeLogits([0.1])])
        _, loader = _loader_factory(tokenizer, model)
        judge = HfRewardModelJudge("reward/model", device="cpu", loader=loader)
        real_import = builtins.__import__

        def fake_import(name: str, *args: object, **kwargs: object):
            if name == "torch":
                raise ImportError("no torch here")
            return real_import(name, *args, **kwargs)

        with (
            patch("builtins.__import__", side_effect=fake_import),
            pytest.raises(JudgeUnavailableError, match="requires torch"),
        ):
            judge.score_pair("prompt", "a", "b")

    def test_default_loader_path_is_used_when_no_loader_is_supplied(self) -> None:
        tokenizer = FakeTokenizer()
        model = FakeModel([FakeLogits([0.7]), FakeLogits([0.1])])

        def fake_default_loader(hf_id: str, device: str):
            from dlm.preference.judge import _LoadedRewardJudge

            assert hf_id == "reward/model"
            assert device == "cpu"
            return _LoadedRewardJudge(model=model, tokenizer=tokenizer, device=device)

        judge = HfRewardModelJudge("reward/model", device="cpu")
        with patch("dlm.preference.judge._default_reward_loader", side_effect=fake_default_loader):
            score = judge.score_pair("prompt", "a", "b")

        assert score.preferred == "a"


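# Unit tests for the module-private helpers behind the judge runtime.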
class TestHfRewardHelpers:
    def test_default_reward_loader_requires_transformers(self) -> None:
        real_import = builtins.__import__

        def fake_import(name: str, *args: object, **kwargs: object):
            if name == "transformers":
                raise ImportError("missing transformers")
            return real_import(name, *args, **kwargs)

        with (
            patch("builtins.__import__", side_effect=fake_import),
            pytest.raises(JudgeUnavailableError, match="requires transformers"),
        ):
            _default_reward_loader("reward/model", "cpu")

    def test_default_reward_loader_moves_model_and_sets_eval(self) -> None:
        model = FakePretrainedRewardModel()
        tokenizer = FakeTokenizer()

        class AutoModelForSequenceClassification:
            @staticmethod
            def from_pretrained(hf_id: str) -> FakePretrainedRewardModel:
                assert hf_id == "reward/model"
                return model

        class AutoTokenizer:
            @staticmethod
            def from_pretrained(hf_id: str) -> FakeTokenizer:
                assert hf_id == "reward/model"
                return tokenizer

        fake_transformers = SimpleNamespace(
            AutoModelForSequenceClassification=AutoModelForSequenceClassification,
            AutoTokenizer=AutoTokenizer,
        )

        with patch.dict("sys.modules", {"transformers": fake_transformers}):
            loaded = _default_reward_loader("reward/model", "mps")

        assert loaded.model is model
        assert loaded.tokenizer is tokenizer
        assert loaded.device == "mps"
        assert model.device == "mps"
        assert model.eval_called is True

    def test_resolve_reward_device_respects_explicit_request(self) -> None:
        assert _resolve_reward_device("cuda:3") == "cuda:3"

    def test_resolve_reward_device_returns_cpu_when_torch_is_missing(self) -> None:
        real_import = builtins.__import__

        def fake_import(name: str, *args: object, **kwargs: object):
            if name == "torch":
                raise ImportError("no torch")
            return real_import(name, *args, **kwargs)

        with patch("builtins.__import__", side_effect=fake_import):
            assert _resolve_reward_device("auto") == "cpu"

    def test_resolve_reward_device_prefers_cuda_then_mps_then_cpu(self) -> None:
        torch_cuda = SimpleNamespace(
            cuda=SimpleNamespace(is_available=lambda: True),
            backends=SimpleNamespace(mps=SimpleNamespace(is_available=lambda: True)),
        )
        torch_mps = SimpleNamespace(
            cuda=SimpleNamespace(is_available=lambda: False),
            backends=SimpleNamespace(mps=SimpleNamespace(is_available=lambda: True)),
        )
        torch_cpu = SimpleNamespace(
            cuda=SimpleNamespace(is_available=lambda: False),
            backends=SimpleNamespace(mps=SimpleNamespace(is_available=lambda: False)),
        )

        with patch.dict("sys.modules", {"torch": torch_cuda}):
            assert _resolve_reward_device("auto") == "cuda"
        with patch.dict("sys.modules", {"torch": torch_mps}):
            assert _resolve_reward_device("auto") == "mps"
        with patch.dict("sys.modules", {"torch": torch_cpu}):
            assert _resolve_reward_device("auto") == "cpu"

    def test_encode_reward_input_falls_back_when_template_raises(self) -> None:
        tokenizer = FakeTokenizer(use_chat_template=True, template_error=RuntimeError("boom"))

        encoded = _encode_reward_input(tokenizer, "prompt", "candidate")

        assert isinstance(encoded, FakeBatch)
        assert tokenizer.calls[-1][0] == "tokenizer"
        assert tokenizer.calls[-1][1] == ("prompt",)
        assert tokenizer.calls[-1][2]["text_pair"] == "candidate"

    def test_encode_reward_input_falls_back_when_template_returns_non_string(self) -> None:
        tokenizer = FakeTokenizer(use_chat_template=True, template_returns_non_string=True)

        encoded = _encode_reward_input(tokenizer, "prompt", "candidate")

        assert isinstance(encoded, FakeBatch)
        assert tokenizer.calls[-1][0] == "tokenizer"

    def test_move_to_device_moves_mapping_values(self) -> None:
        tensor = FakeTensor()
        payload = {"input_ids": tensor, "meta": "keep"}

        moved = _move_to_device(payload, "mps")

        assert moved["input_ids"] is tensor
        assert tensor.device == "mps"
        assert moved["meta"] == "keep"

    def test_move_to_device_returns_unmodified_non_mapping_values(self) -> None:
        value = object()
        assert _move_to_device(value, "cpu") is value

    def test_extract_reward_scalar_uses_item_fallback(self) -> None:
        assert _extract_reward_scalar(FakeTorchScalarLogits(0.75)) == pytest.approx(0.75)

    def test_extract_reward_scalar_rejects_unreadable_values(self) -> None:
        class UnreadableLogits:
            def numel(self) -> int:
                return 1

        with pytest.raises(JudgeInvocationError, match="unreadable scalar logit"):
            _extract_reward_scalar(UnreadableLogits())