Python · 1711 bytes Raw Blame History
"""Unit tests for the pure helpers inside ``backends/hf.py``.

The helpers (``_resolve_dtype``, ``_detect_device``) gate the HF
backend's dtype/device choices. The integration suite exercises them
end-to-end, but covering them directly with unit tests keeps the slow
lane focused on actual backend behavior and lets the fast lane catch
dtype/device regressions in seconds.
"""
9
10 from __future__ import annotations
11
12 import importlib.util
13
14 import pytest
15
# Every assertion below compares against torch dtype objects, so the whole
# module is skipped (not errored) when torch -- shipped by the [hf] extra --
# cannot be located on the import path.
_torch_spec = importlib.util.find_spec("torch")
if _torch_spec is None:
    pytest.skip(
        "torch not installed — install the [hf] extra to run HF helper tests",
        allow_module_level=True,
    )
23
24 from dlm_sway.backends.hf import _detect_device, _resolve_dtype
25
26
27 class TestResolveDtype:
28 def test_explicit_fp16(self) -> None:
29 import torch
30
31 assert _resolve_dtype("fp16", "cpu") is torch.float16
32
33 def test_explicit_bf16(self) -> None:
34 import torch
35
36 assert _resolve_dtype("bf16", "cpu") is torch.bfloat16
37
38 def test_explicit_fp32(self) -> None:
39 import torch
40
41 assert _resolve_dtype("fp32", "cpu") is torch.float32
42
43 def test_auto_on_cpu_picks_fp32_for_numerical_stability(self) -> None:
44 import torch
45
46 assert _resolve_dtype("auto", "cpu") is torch.float32
47
48 def test_auto_on_mps_picks_fp16(self) -> None:
49 import torch
50
51 assert _resolve_dtype("auto", "mps") is torch.float16
52
53
54 class TestDetectDevice:
55 def test_returns_one_of_supported_devices(self) -> None:
56 device = _detect_device()
57 assert device in ("cuda", "mps", "cpu")