tenseleyflow/sway / 468a436


tests/integration: end-to-end PEFT→MLX convert+load+score (S24 prove-the-value)

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA: 468a436ecfaf567a0a6db765a32c1a91306ea908
Parents: 7390b0a
Tree: b960d56

1 changed file

A  tests/integration/test_mlx_converter_e2e.py  (+184 −0)

tests/integration/test_mlx_converter_e2e.py (added)
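
The headline flow the new test exercises, compressed to its API surface. This is a sketch assembled from the imports and calls in the test itself; the adapter path is an illustrative placeholder for whatever ``dlm train`` produced.

```python
# Sketch of the flow under test, using only names the test itself imports.
from pathlib import Path

from dlm_sway.backends.mlx import MLXDifferentialBackend
from dlm_sway.core.model import ModelSpec

backend = MLXDifferentialBackend(
    base_spec=ModelSpec(base="mlx-community/SmolLM2-135M-Instruct", kind="mlx"),
    adapter_path=Path("out/my-peft-adapter"),  # PEFT-shaped; auto-converted on init
)
with backend.as_finetuned() as ft:
    print(ft.logprob_of("The capital of France is", " Paris"))
backend.close()
```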
@@ -0,0 +1,184 @@
+"""S24 — end-to-end PEFT → MLX adapter conversion (darwin-arm64-only).
+
+Closes the F01 audit gap: ``dlm train`` writes a PEFT-shaped adapter,
+``MLXDifferentialBackend`` is pointed at it, the backend auto-converts
+it into the user's cache, ``mlx_lm.load`` consumes the result, and
+scoring returns finite logprobs.
+
+This is the **prove-the-value** test the sprint file calls out — every
+other layer of testing (synthetic-input unit tests, CLI smoke tests)
+is upstream of this. If this passes locally on darwin-arm64, the
+headline ``.dlm → MLX`` flow works.
+
+Skips cleanly on:
+- non-darwin (mlx is Apple Silicon only)
+- non-arm64
+- ``mlx_lm`` not installed (the ``[mlx]`` extra is optional)
+- ``peft`` / ``transformers`` not installed (the ``[hf]`` extra needed
+  to *build* the source PEFT adapter)
+"""
+
+from __future__ import annotations
+
+import math
+import platform
+import sys
+from pathlib import Path
+
+import numpy as np
+import pytest
+
+pytestmark = [pytest.mark.slow, pytest.mark.online]
+
+
+# Default to the unquantized MLX repo because the 4-bit variant has
+# slipped into a gated/auth state on HF Hub. Either repo's adapter
+# slot works for the converter — the test only cares that mlx-lm
+# loads our converted ``adapters.safetensors``.
+_MODEL_ID = "mlx-community/SmolLM2-135M-Instruct"
+
+
+def _platform_supports_mlx() -> bool:
+    return sys.platform == "darwin" and platform.machine() == "arm64"
+
+
+def _build_random_peft_lora(base_dir: Path, out_dir: Path) -> None:
+    """Build the same deterministic LoRA the HF integration tests use,
+    duplicated here so we don't import from another test file."""
+    import torch
+    from peft import LoraConfig, get_peft_model
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    torch.manual_seed(0)
+    tokenizer = AutoTokenizer.from_pretrained(str(base_dir))
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    base = AutoModelForCausalLM.from_pretrained(str(base_dir), torch_dtype=torch.float32)
+    cfg = LoraConfig(
+        r=8,
+        lora_alpha=16,
+        target_modules=["q_proj", "v_proj"],
+        lora_dropout=0.0,
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
+    peft_model = get_peft_model(base, cfg)
+    # PEFT zero-initializes lora_B, so randomize it (seeded above) to give
+    # the adapter a non-trivial, deterministic effect on the logits.
+    with torch.no_grad():
+        for name, param in peft_model.named_parameters():
+            if "lora_B" in name:
+                param.copy_(torch.randn_like(param) * 0.05)
+    peft_model.save_pretrained(str(out_dir))
+    tokenizer.save_pretrained(str(out_dir))
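
The helper's output is PEFT's standard on-disk adapter layout. A minimal sanity check of that layout might look like the sketch below; the file names are PEFT's ``save_pretrained`` defaults, nothing dlm_sway-specific, and ``_assert_peft_layout`` is a hypothetical helper, not part of the test suite.

```python
# Layout check for the helper's output; adapter_config.json and
# adapter_model.safetensors are what peft's save_pretrained writes by default.
from pathlib import Path

def _assert_peft_layout(adapter_dir: Path) -> None:
    assert (adapter_dir / "adapter_config.json").exists()
    assert (adapter_dir / "adapter_model.safetensors").exists()
```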
+
+
+@pytest.fixture(scope="module")
+def peft_adapter(tiny_model_dir: Path, tmp_path_factory: pytest.TempPathFactory) -> Path:
+    if not _platform_supports_mlx():
+        pytest.skip("MLX requires darwin-arm64")
+    pytest.importorskip("peft", reason="needs the [hf] extra to build a PEFT adapter")
+    out = tmp_path_factory.mktemp("peft-for-mlx-convert")
+    _build_random_peft_lora(tiny_model_dir, out)
+    return out
+
+
+@pytest.fixture(scope="module")
+def mlx_backend(peft_adapter: Path, tmp_path_factory: pytest.TempPathFactory):
+    """Point the MLX backend at a PEFT-shaped adapter dir; the backend
+    auto-converts into a tmp cache (XDG_CACHE_HOME redirected so we
+    don't pollute the user's real cache)."""
+    pytest.importorskip("mlx_lm", reason="install the [mlx] extra to run MLX tests")
+
+    # Redirect the cache so this test doesn't write to the user's
+    # ~/.cache/dlm-sway/. Each fixture invocation gets a fresh dir.
+    import os
+
+    cache_root = tmp_path_factory.mktemp("mlx-convert-cache")
+    prev = os.environ.get("XDG_CACHE_HOME")
+    os.environ["XDG_CACHE_HOME"] = str(cache_root)
+    try:
+        from dlm_sway.backends.mlx import MLXDifferentialBackend
+        from dlm_sway.core.model import ModelSpec
+
+        backend = MLXDifferentialBackend(
+            base_spec=ModelSpec(base=_MODEL_ID, kind="mlx"),
+            adapter_path=peft_adapter,
+        )
+        try:
+            yield backend, cache_root
+        finally:
+            # Runs even when a test using the fixture fails, so the
+            # backend is never leaked.
+            backend.close()
+    finally:
+        if prev is None:
+            os.environ.pop("XDG_CACHE_HOME", None)
+        else:
+            os.environ["XDG_CACHE_HOME"] = prev
+
+
+def test_auto_conversion_writes_to_xdg_cache(mlx_backend) -> None:
+    """The backend's __init__ must have populated the cache dir with
+    an MLX-format adapter — proves the auto-convert path fired."""
+    _backend, cache_root = mlx_backend
+    converted = list((cache_root / "dlm-sway" / "mlx-converted").glob("*"))
+    assert len(converted) == 1, f"expected exactly one cached MLX adapter dir, got {converted}"
+    cache_dir = converted[0]
+    assert (cache_dir / "adapters.safetensors").exists()
+    assert (cache_dir / "adapter_config.json").exists()
+
+
+def test_next_token_dist_returns_finite_topk_via_converted_adapter(mlx_backend) -> None:
+    """The converted adapter, loaded via mlx_lm and scored via the MLX
+    backend, must produce finite, well-ordered top-k logprobs."""
+    backend, _ = mlx_backend
+    with backend.as_finetuned() as ft:
+        d = ft.next_token_dist("The capital of France is", top_k=32)
+    assert d.token_ids.shape == (32,)
+    assert d.logprobs.shape == (32,)
+    assert np.all(np.isfinite(d.logprobs))
+    assert np.all(np.diff(d.logprobs) <= 1e-7)  # non-increasing, i.e. sorted descending
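
The ``np.diff`` assertion is the compact way to pin sort order: every consecutive difference in a non-increasing vector is <= 0, and the 1e-7 slack only absorbs float ties. A standalone illustration:

```python
# Standalone illustration of the sort-order check used above.
import numpy as np

lp = np.array([-0.8, -1.5, -1.5, -2.2], dtype=np.float32)  # descending, with a tie
assert np.all(np.diff(lp) <= 1e-7)            # passes: all differences <= 0
assert not np.all(np.diff(lp[::-1]) <= 1e-7)  # ascending order fails the check
```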
+
+
+def test_logprob_of_finite_via_converted_adapter(mlx_backend) -> None:
+    backend, _ = mlx_backend
+    with backend.as_finetuned() as ft:
+        lp = ft.logprob_of("The capital of France is", " Paris")
+    assert math.isfinite(lp)
+    assert lp < 0.0  # a softmax log-probability is log(p) with p < 1
+
+
+def test_repeat_load_skips_reconvert(
+    peft_adapter: Path, tmp_path_factory: pytest.TempPathFactory
+) -> None:
+    """A second backend instance against the same PEFT adapter must
+    short-circuit on the cache and NOT rewrite the converted file."""
+    pytest.importorskip("mlx_lm", reason="install the [mlx] extra to run MLX tests")
+
+    import os
+
+    cache_root = tmp_path_factory.mktemp("mlx-convert-cache-2")
+    prev = os.environ.get("XDG_CACHE_HOME")
+    os.environ["XDG_CACHE_HOME"] = str(cache_root)
+    try:
+        from dlm_sway.backends.mlx import MLXDifferentialBackend
+        from dlm_sway.core.model import ModelSpec
+
+        b1 = MLXDifferentialBackend(
+            base_spec=ModelSpec(base=_MODEL_ID, kind="mlx"),
+            adapter_path=peft_adapter,
+        )
+        cache_dir = next((cache_root / "dlm-sway" / "mlx-converted").glob("*"))
+        first_mtime = (cache_dir / "adapters.safetensors").stat().st_mtime_ns
+        b1.close()
+
+        b2 = MLXDifferentialBackend(
+            base_spec=ModelSpec(base=_MODEL_ID, kind="mlx"),
+            adapter_path=peft_adapter,
+        )
+        second_mtime = (cache_dir / "adapters.safetensors").stat().st_mtime_ns
+        b2.close()
+
+        assert second_mtime == first_mtime, (
+            "second backend init re-wrote the cached MLX adapter — cache short-circuit is broken"
+        )
+    finally:
+        if prev is None:
+            os.environ.pop("XDG_CACHE_HOME", None)
+        else:
+            os.environ["XDG_CACHE_HOME"] = prev
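
The mtime comparison pins down a cache short-circuit whose assumed shape is sketched below. This is an illustration of the behaviour under test, not the converter's actual code, and ``_needs_reconvert`` is a hypothetical name.

```python
# Assumed shape of the short-circuit the test above pins down; the real
# check lives in dlm_sway.backends.mlx.
from pathlib import Path

def _needs_reconvert(src_adapter: Path, cached: Path) -> bool:
    if not cached.exists():
        return True  # nothing cached yet: convert
    # Cached copy exists: only reconvert if the source adapter is newer.
    return src_adapter.stat().st_mtime_ns > cached.stat().st_mtime_ns
```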