tenseleyflow/sway / b81426c

Browse files

tests/backend_hf_helpers: unit cover _resolve_dtype + _detect_device

Authored by espadonne
SHA
b81426c31ed47fda14f449ae132ace2fb22e6b96
Parents
6ad2814
Tree
0a6061a

1 changed file

Status | File | + | -
A tests/unit/test_backend_hf_helpers.py 57 0
tests/unit/test_backend_hf_helpers.pyadded
@@ -0,0 +1,57 @@
"""Unit tests for the pure helpers inside ``backends/hf.py``.

The helpers (``_resolve_dtype``, ``_detect_device``) gate the HF
backend's dtype/device choices. They're exercised end-to-end by the
integration suite, but doing direct unit coverage on them keeps the
slow lane focused on the actual backend behavior — and makes the
fast lane catch dtype/device regressions in seconds.
"""

from __future__ import annotations

import importlib.util

import pytest

# These tests need torch to construct the dtype objects we're asserting
# on; skip cleanly when the [hf] extra isn't installed.
if importlib.util.find_spec("torch") is None:
    pytest.skip(
        "torch not installed — install the [hf] extra to run HF helper tests",
        allow_module_level=True,
    )

# NOTE(review): imported below the guard — presumably ``backends/hf``
# imports torch at module scope, so importing it earlier would raise
# during collection when torch is absent. Confirm against the backend.
from dlm_sway.backends.hf import _detect_device, _resolve_dtype
class TestResolveDtype:
    """Verify ``_resolve_dtype`` maps each spec string to the expected
    ``torch.dtype`` for a given device string."""

    def test_explicit_fp16(self) -> None:
        import torch

        resolved = _resolve_dtype("fp16", "cpu")
        assert resolved is torch.float16

    def test_explicit_bf16(self) -> None:
        import torch

        resolved = _resolve_dtype("bf16", "cpu")
        assert resolved is torch.bfloat16

    def test_explicit_fp32(self) -> None:
        import torch

        resolved = _resolve_dtype("fp32", "cpu")
        assert resolved is torch.float32

    def test_auto_on_cpu_picks_fp32_for_numerical_stability(self) -> None:
        import torch

        # "auto" on CPU should favor full precision over speed.
        resolved = _resolve_dtype("auto", "cpu")
        assert resolved is torch.float32

    def test_auto_on_mps_picks_fp16(self) -> None:
        import torch

        # "auto" on Apple MPS should favor half precision.
        resolved = _resolve_dtype("auto", "mps")
        assert resolved is torch.float16
class TestDetectDevice:
    """Smoke check: ``_detect_device`` must name a supported backend."""

    def test_returns_one_of_supported_devices(self) -> None:
        supported = {"cuda", "mps", "cpu"}
        assert _detect_device() in supported