tenseleyflow/documentlanguagemodel / 8cb8a8c

Refresh platform-sensitive runtime tests

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA: 8cb8a8c73bf575b5ea8e6b0d1d0513fc43a86bff
Parents: 4151ac0
Tree: 397b77b

3 changed files

Status  File                                         +  -
M       src/dlm/inference/backends/mlx_backend.py    2  2
M       tests/unit/hardware/test_plan.py             7  1
M       tests/unit/synth/test_teachers.py            4  1
src/dlm/inference/backends/mlx_backend.py (modified)

@@ -164,14 +164,14 @@ class MlxBackend(InferenceBackend):
         *,
         adapter_name: str | None = None,
     ) -> None:
-        from mlx_lm import load
-
         from dlm.inference.loader import resolve_adapter_path
 
         adapter_path = resolve_adapter_path(store, adapter_name=adapter_name)
         if not adapter_path.exists():
             raise AdapterNotFoundError(f"mlx backend: adapter dir {adapter_path} does not exist")
 
+        from mlx_lm import load
+
         # Stage both tensors + adapter_config.json into a scratch dir.
         # `stage_mlx_adapter_dir` performs the preflight PEFT-shape
         # check + translates PEFT config into mlx-lm's schema + writes
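
The effect of this hunk is to defer the platform-sensitive `mlx_lm` import until after the adapter directory has been validated, so a missing adapter surfaces as AdapterNotFoundError even on hosts where mlx-lm cannot be installed. A minimal sketch of the pattern (standalone and illustrative, not the project's actual method; `load_adapter` and its argument are stand-ins):

from pathlib import Path

class AdapterNotFoundError(RuntimeError):
    """Raised when the requested adapter directory is missing."""

def load_adapter(adapter_path: Path):
    # Platform-independent precondition first: this fails the same way on every
    # machine, including CI runners that cannot install mlx-lm.
    if not adapter_path.exists():
        raise AdapterNotFoundError(f"adapter dir {adapter_path} does not exist")

    # Platform-sensitive import second: only reached once the adapter is known
    # to exist, so the failure mode is AdapterNotFoundError, not ImportError.
    from mlx_lm import load

    # Actual model/adapter loading elided; it depends on how the backend stages
    # the adapter directory for mlx-lm.
    ...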
tests/unit/hardware/test_plan.py (modified)

@@ -45,7 +45,13 @@ class TestPrecisionPicker:
         with force_mps():
             caps = probe()
         with caplog.at_level(logging.WARNING, logger="dlm.hardware.plan"):  # type: ignore[attr-defined]
-            plan = resolve(_cfg(precision="fp16"), caps, base_params=8_000_000_000, seq_len=2048)
+            plan = resolve(
+                _cfg(precision="fp16"),
+                caps,
+                base_params=8_000_000_000,
+                seq_len=2048,
+                force=True,
+            )
         assert plan.precision == "fp16"
         # The caller must see the risk explicitly — silent fp16 on MPS
         # is what caused the original bug.
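
The new force=True argument makes the test opt in to fp16 on MPS explicitly instead of relying on resolve() accepting it silently. A hedged sketch of the contract the assertions pin down (the function name, signature, and no-force fallback below are stand-ins, not dlm's real resolver):

import logging
from dataclasses import dataclass

log = logging.getLogger("dlm.hardware.plan")

@dataclass
class Plan:
    precision: str

def resolve_precision(requested: str, backend: str, *, force: bool = False) -> Plan:
    # fp16 on Apple's MPS backend is numerically risky, so it is only honoured
    # when the caller forces it, and the risk is logged instead of staying silent.
    if backend == "mps" and requested == "fp16":
        if not force:
            # Assumed fallback: without force, downgrade to a safe precision.
            log.warning("fp16 on MPS is unsafe; using fp32 (pass force=True to override)")
            return Plan(precision="fp32")
        log.warning("fp16 on MPS forced by caller; numeric overflow is possible")
    return Plan(precision=requested)

Under a contract like that, the caplog block still captures the warning, and plan.precision stays "fp16" only because the test passed force=True.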
tests/unit/synth/test_teachers.py (modified)

@@ -128,7 +128,10 @@ class TestHfTeacher:
             "system", "user", max_new_tokens=21, temperature=0.5, top_p=0.8, seed=11
         )
         assert out == "hf output"
-        assert seen["loader"] == ("Qwen/Qwen2.5-1.5B-Instruct", "cpu")
+        assert seen["loader"] == (
+            "Qwen/Qwen2.5-1.5B-Instruct",
+            teachers_mod._resolve_generation_device("auto"),
+        )
         assert seen["runner"][3:] == (21, 0.5, 0.8, 11)
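
Replacing the hard-coded "cpu" with `teachers_mod._resolve_generation_device("auto")` keeps the assertion valid on CUDA, MPS, and CPU-only machines alike: the test now compares the device the loader received against whatever the module itself resolves for "auto". The helper's body is not shown in this diff; an "auto" device resolver of this kind typically looks something like the following (illustrative only, assuming a torch-based backend):

import torch

def resolve_generation_device(choice: str = "auto") -> str:
    # Non-"auto" requests are passed through unchanged.
    if choice != "auto":
        return choice
    # Prefer CUDA, then Apple's MPS, then fall back to CPU.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"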