tenseleyflow/sway / 5bec484

Browse files

backends/mlx: auto-convert PEFT adapter on load + content-hash cache (F01)

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
5bec484b49e09aa88d6865fa413583a0e17d7292
Parents
922ce3e
Tree
bee5aa6

2 changed files

StatusFile+-
M src/dlm_sway/backends/mlx.py 67 1
M tests/unit/test_mlx_convert.py 68 0
src/dlm_sway/backends/mlx.pymodified
@@ -212,7 +212,12 @@ class MLXDifferentialBackend:
212212
         mx, mlx_lm = _require_mlx()
213213
         self._mx = mx
214214
         self._spec = base_spec
215
-        self._adapter_path = Path(adapter_path).expanduser().resolve()
215
+        raw_path = Path(adapter_path).expanduser().resolve()
216
+        # S24: when the user points us at a PEFT adapter (typical
217
+        # `dlm export` output), auto-convert into the user's cache
218
+        # so the headline `.dlm → sway` flow on MLX just works.
219
+        # Cached by content hash so repeated runs skip the convert.
220
+        self._adapter_path = _ensure_mlx_adapter(raw_path)
216221
 
217222
         # Load bare base (no adapter).
218223
         self._base_model, self._tokenizer = mlx_lm.load(base_spec.base)
@@ -267,6 +272,67 @@ class MLXDifferentialBackend:
267272
         self._active = None
268273
 
269274
 
275
def _ensure_mlx_adapter(adapter_path: Path) -> Path:
    """Return an adapter directory in MLX-LM format, converting PEFT if needed (S24).

    Detection is structural:

    * ``adapter_path/adapters.safetensors`` exists (mlx-lm's filename):
      the directory is treated as already MLX-shaped and returned
      unchanged — assumes the user converted manually or another tool
      produced it.
    * ``adapter_path/adapter_model.safetensors`` exists: the directory is
      treated as PEFT and converted with ``convert_peft_to_mlx``.
    * Neither exists: the path is returned unchanged so the downstream
      loader can surface its own error.

    Converted adapters are cached under ``_mlx_cache_root() / <digest>``.
    The digest is a 16-byte blake2b over the source
    ``adapter_model.safetensors`` and, when present, the source
    ``adapter_config.json``. Hashing the config as well means that a
    config-only edit (same weights, different hyperparameters) gets a
    fresh cache entry instead of reusing a stale conversion.
    NOTE(review): this presumes the converter's output depends on
    ``adapter_config.json`` — confirm against ``_mlx_convert``.
    Cache entries written under a different key scheme are simply not
    reused.

    Args:
        adapter_path: Directory holding either a PEFT or an MLX adapter.

    Returns:
        ``adapter_path`` itself for the pass-through cases, otherwise the
        cache directory containing the converted adapter.
    """
    import hashlib

    if (adapter_path / "adapters.safetensors").exists():
        # Already in MLX format — pass through unchanged.
        return adapter_path

    weights_file = adapter_path / "adapter_model.safetensors"
    if not weights_file.exists():
        # Neither MLX nor PEFT shape; leave error reporting to the loader.
        return adapter_path

    # The weights file is always hashed; the config only when it is a
    # regular file, so a PEFT dir without one still converts as before.
    sources = [weights_file]
    config_file = adapter_path / "adapter_config.json"
    if config_file.is_file():
        sources.append(config_file)

    digest = hashlib.blake2b(digest_size=16)
    for src in sources:
        # Prefix each file with its name and size so the concatenated
        # stream cannot be confused across file boundaries.
        digest.update(f"{src.name}:{src.stat().st_size}\0".encode())
        with src.open("rb") as fh:
            # Stream in 1 MiB chunks; adapter files can be large.
            while chunk := fh.read(1024 * 1024):
                digest.update(chunk)

    cache_dir = _mlx_cache_root() / digest.hexdigest()
    cache_hit = (cache_dir / "adapters.safetensors").exists() and (
        cache_dir / "adapter_config.json"
    ).exists()
    if cache_hit:
        return cache_dir

    # First-run conversion. Imported lazily so the converter stays off
    # the import path of callers that never reach this branch.
    from dlm_sway.backends._mlx_convert import convert_peft_to_mlx

    cache_dir.mkdir(parents=True, exist_ok=True)
    # overwrite=True: a half-populated dir left by an interrupted earlier
    # run fails the cache-hit check above and is rewritten here.
    convert_peft_to_mlx(adapter_path, cache_dir, overwrite=True)
    return cache_dir
321
+
322
+
323
def _mlx_cache_root() -> Path:
    """Return ``$XDG_CACHE_HOME/dlm-sway/mlx-converted`` (or ``~/.cache/...``).

    Honors ``XDG_CACHE_HOME`` so Linux users get their conventional cache
    location. When the variable is unset, empty, or not an absolute path,
    the result falls back to ``~/.cache/dlm-sway/mlx-converted``; macOS
    users typically land on that fallback, since XDG isn't standard on
    darwin.

    Relative values are ignored, as the XDG Base Directory specification
    requires; honoring them would make the cache location depend on the
    current working directory.

    Returns:
        The cache root directory path. The directory is not created here.
    """
    import os

    xdg_cache_home = os.environ.get("XDG_CACHE_HOME")
    if xdg_cache_home and os.path.isabs(xdg_cache_home):
        base = Path(xdg_cache_home)
    else:
        base = Path.home() / ".cache"
    return base / "dlm-sway" / "mlx-converted"
334
+
335
+
270336
 def _log_softmax(x: np.ndarray, *, axis: int) -> np.ndarray:
271337
     x_max = np.max(x, axis=axis, keepdims=True)
272338
     y = x - x_max
tests/unit/test_mlx_convert.pymodified
@@ -234,6 +234,74 @@ class TestConvertPeftToMlxErrors:
234234
             convert_peft_to_mlx(src, tmp_path / "mlx")
235235
 
236236
 
237
class TestEnsureMlxAdapterAutoConvert:
    """Behavior of ``_ensure_mlx_adapter`` from ``backends/mlx.py``.

    ``MLXDifferentialBackend.__init__`` uses this helper to upgrade
    PEFT-shaped adapter directories to MLX format on the fly. Each test
    imports the helper inside the test body and drives it directly with
    directories built under ``tmp_path``."""

    def test_passes_through_when_dir_is_already_mlx_shape(self, tmp_path: Path) -> None:
        """A directory that already holds ``adapters.safetensors`` is
        returned as-is: manual conversions and pre-built MLX adapters
        from other tools must not be converted again."""
        from dlm_sway.backends.mlx import _ensure_mlx_adapter

        prebuilt = tmp_path / "mlx"
        prebuilt.mkdir()
        save_file({}, str(prebuilt / "adapters.safetensors"))
        (prebuilt / "adapter_config.json").write_text('{"fine_tune_type":"lora"}')

        assert _ensure_mlx_adapter(prebuilt) == prebuilt

    def test_auto_converts_peft_dir_into_cache(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """The first call on a PEFT-shaped directory converts it under
        ``XDG_CACHE_HOME`` and hands back the cache directory rather
        than the source directory."""
        from dlm_sway.backends.mlx import _ensure_mlx_adapter

        xdg_cache = tmp_path / "cache"
        monkeypatch.setenv("XDG_CACHE_HOME", str(xdg_cache))
        source = tmp_path / "peft"
        _write_synthetic_peft_adapter(source)

        converted = _ensure_mlx_adapter(source)

        assert converted != source
        for filename in ("adapters.safetensors", "adapter_config.json"):
            assert (converted / filename).exists()
        assert str(converted).startswith(str(xdg_cache))

    def test_repeated_calls_short_circuit_on_cache_hit(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Identical PEFT bytes map to the same cache entry, so a second
        call returns the cached directory without rewriting it; an
        unchanged mtime on the converted weights is the evidence."""
        from dlm_sway.backends.mlx import _ensure_mlx_adapter

        monkeypatch.setenv("XDG_CACHE_HOME", str(tmp_path / "cache"))
        source = tmp_path / "peft"
        _write_synthetic_peft_adapter(source)

        def weights_mtime(adapter_dir: Path) -> int:
            return (adapter_dir / "adapters.safetensors").stat().st_mtime_ns

        initial = _ensure_mlx_adapter(source)
        mtime_after_convert = weights_mtime(initial)

        # The repeat call must leave the converted file untouched.
        repeat = _ensure_mlx_adapter(source)
        assert repeat == initial
        assert weights_mtime(repeat) == mtime_after_convert

    def test_passes_through_unrecognized_dir(self, tmp_path: Path) -> None:
        """A directory matching neither layout is returned unchanged;
        reporting the problem is left to ``mlx_lm.load`` instead of
        this helper second-guessing it."""
        from dlm_sway.backends.mlx import _ensure_mlx_adapter

        unrecognized = tmp_path / "empty"
        unrecognized.mkdir()

        assert _ensure_mlx_adapter(unrecognized) == unrecognized
303
+
304
+
237305
 class TestModulesToSave:
238306
     """``modules_to_save`` (e.g. embed_tokens, lm_head) must be skipped
239307
     cleanly with a report entry, not crash the converter."""