tenseleyflow/sway / 3890e1e

Browse files

sway(bridge): align resolver with real dlm API (type/parsers/for_dlm/base_models.resolve)

Authored by espadonne
SHA
3890e1e09452818cdf352d5fe678fc445fc15298
Parents
3a27d73
Tree
a5c717a

3 changed files

Status | File | + | -
M src/dlm_sway/backends/hf.py 7 3
M src/dlm_sway/integrations/dlm/resolver.py 84 47
M tests/unit/test_dlm_bridge.py 63 26
src/dlm_sway/backends/hf.py — modified
@@ -256,7 +256,9 @@ class HuggingFaceDifferentialBackend:
256256
     def as_base(self) -> Iterator[_HFView]:
257257
         self._enter("base")
258258
         try:
259
-            with self._peft_model.disable_adapter():
259
+            # peft.PeftModel.disable_adapter is a context manager; mypy
260
+            # mis-reads it as a Tensor on this transformers version.
261
+            with self._peft_model.disable_adapter():  # type: ignore[operator]
260262
                 yield self._make_view("base")
261263
         finally:
262264
             self._exit()
@@ -279,7 +281,9 @@ class HuggingFaceDifferentialBackend:
279281
         exception propagates, to keep the model in a sane state.
280282
         """
281283
         self._enter(f"scaled({lam})")
282
-        saved: list[tuple[object, str, float]] = []
284
+        # ``module`` is dynamic (peft LoraLayer subclass) — Any avoids
285
+        # mypy treating its ``.scaling`` as a Tensor when peft is loaded.
286
+        saved: list[tuple[Any, str, float]] = []
283287
         try:
284288
             import peft  # noqa: PLC0415 — already a hard dep of this backend
285289
 
@@ -298,7 +302,7 @@ class HuggingFaceDifferentialBackend:
298302
             yield self._make_view(f"scaled_{lam:.2f}")
299303
         finally:
300304
             for module, key, original in saved:
301
-                module.scaling[key] = original  # type: ignore[attr-defined]
305
+                module.scaling[key] = original
302306
             self._exit()
303307
 
304308
     @contextmanager
src/dlm_sway/integrations/dlm/resolver.py — modified
@@ -69,50 +69,75 @@ def resolve_dlm(dlm_path: Path) -> DlmHandle:
6969
     doc_text = "\n\n".join(s.content for s in sections)
7070
 
7171
     adapter_path = _resolve_adapter_path(fm.dlm_id)
72
+    base_hf_id = _resolve_base_model_to_hf_id(fm.base_model)
7273
 
7374
     return DlmHandle(
7475
         dlm_id=fm.dlm_id,
75
-        base_model=fm.base_model,
76
+        base_model=base_hf_id,
7677
         adapter_path=adapter_path,
7778
         sections=sections,
7879
         doc_text=doc_text,
7980
     )
8081
 
8182
 
83
+def _resolve_base_model_to_hf_id(base_model: str) -> str:
84
+    """Translate dlm's base-model *key* to a HuggingFace repo id.
85
+
86
+    dlm's frontmatter stores registry keys like ``smollm2-135m`` which
87
+    resolve to ``HuggingFaceTB/SmolLM2-135M-Instruct``. sway's backends
88
+    call ``AutoModelForCausalLM.from_pretrained`` directly and need the
89
+    HF id. The ``hf:org/name`` escape hatch passes through unchanged.
90
+    """
91
+    if base_model.startswith("hf:"):
92
+        return base_model[len("hf:") :]
93
+    try:
94
+        from dlm.base_models import resolve as resolve_base
95
+    except ImportError:
96
+        return base_model
97
+    try:
98
+        spec = resolve_base(base_model)
99
+    except Exception:  # noqa: BLE001 — unknown dlm errors
100
+        return base_model
101
+    hf_id = getattr(spec, "hf_id", None)
102
+    return str(hf_id) if hf_id else base_model
103
+
104
+
82105
 def _resolve_adapter_path(dlm_id: str) -> Path | None:
83106
     """Locate the current adapter directory for ``dlm_id``.
84107
 
85
-    Uses dlm's ``StorePath`` helper if available, else falls back to
86
-    the canonical ``~/.dlm/store/<dlm_id>/adapter/current.txt`` pointer.
87
-    Returns ``None`` if no adapter has been trained yet.
108
+    Uses dlm's module-level ``for_dlm`` helper if available, else falls
109
+    back to the canonical ``~/.dlm/store/<dlm_id>/adapter/current.txt``
110
+    pointer. Returns ``None`` if no adapter has been trained yet.
88111
     """
112
+    # Primary path: use dlm's own store-path helpers.
89113
     try:
90
-        from dlm.store.paths import StorePath
91
-
92
-        _store_path_cls: object | None = StorePath
114
+        from dlm.store.paths import for_dlm as _for_dlm
93115
     except ImportError:
94
-        _store_path_cls = None
116
+        _for_dlm = None
95117
 
96
-    if _store_path_cls is not None:
118
+    if _for_dlm is not None:
97119
         try:
98
-            store = _store_path_cls.for_dlm(dlm_id)  # type: ignore[attr-defined]
120
+            store = _for_dlm(dlm_id)
99121
         except Exception:  # noqa: BLE001 — unknown dlm exception shapes
100
-            return None
101
-        try:
102
-            resolved = store.resolve_current_adapter()
103
-        except (AttributeError, FileNotFoundError):
104
-            resolved = None
105
-        if resolved is not None and resolved.exists():
106
-            return Path(resolved)
107
-
108
-    # Manual fallback in case the dlm API evolves.
122
+            store = None
123
+        if store is not None:
124
+            try:
125
+                resolved = store.resolve_current_adapter()
126
+            except (AttributeError, FileNotFoundError):
127
+                resolved = None
128
+            if resolved is not None and Path(resolved).exists():
129
+                return Path(resolved)
130
+
131
+    # Manual fallback. The ``current.txt`` pointer is relative to the
132
+    # **store root**, not to current.txt's parent dir — so go up one level.
109133
     import os
110134
 
111135
     home = Path(os.environ.get("DLM_HOME", "~/.dlm")).expanduser()
112
-    current_file = home / "store" / dlm_id / "adapter" / "current.txt"
136
+    store_root = home / "store" / dlm_id
137
+    current_file = store_root / "adapter" / "current.txt"
113138
     if current_file.exists():
114139
         pointer = current_file.read_text(encoding="utf-8").strip()
115
-        candidate = (current_file.parent / pointer).resolve()
140
+        candidate = (store_root / pointer).resolve()
116141
         if candidate.exists():
117142
             return candidate
118143
     return None
@@ -121,12 +146,14 @@ def _resolve_adapter_path(dlm_id: str) -> Path | None:
121146
 def _translate_section(dlm_section: object) -> Section:
122147
     """Adapt a ``dlm.doc.sections.Section`` to sway's section type.
123148
 
124
-    The shape dlm uses has been stable through the v0.x series but we
125
-    treat field access defensively so a minor dlm refactor can't silently
126
-    misread section content.
149
+    dlm's Section dataclass uses the attribute name ``type`` (not
150
+    ``kind``) and stores instruction/preference content as raw markdown
151
+    — dlm ships dedicated parsers (``parse_instruction_body``,
152
+    ``parse_preference_body``) that we reuse here so any future dlm
153
+    syntax additions land in sway for free.
127154
     """
128
-    kind_raw = getattr(dlm_section, "kind", None)
129
-    # dlm uses the attribute name "kind" on its Section dataclass.
155
+    # dlm's current attribute is ``type``; older revisions used ``kind``.
156
+    kind_raw = getattr(dlm_section, "type", getattr(dlm_section, "kind", None))
130157
     kind = _normalize_kind(kind_raw)
131158
     content = str(getattr(dlm_section, "content", ""))
132159
     section_id = str(
@@ -139,9 +166,9 @@ def _translate_section(dlm_section: object) -> Section:
139166
     probes: tuple[SectionProbe, ...] = ()
140167
     preferences: tuple[SectionPreference, ...] = ()
141168
     if kind == "instruction":
142
-        probes = tuple(_extract_instruction_probes(dlm_section))
169
+        probes = tuple(_parse_instruction(content, section_id=section_id))
143170
     elif kind == "preference":
144
-        preferences = tuple(_extract_preference_triples(dlm_section))
171
+        preferences = tuple(_parse_preference(content, section_id=section_id))
145172
 
146173
     return Section(
147174
         id=section_id,
@@ -168,35 +195,45 @@ def _normalize_kind(raw: object) -> SectionKind:
168195
     return "prose"
169196
 
170197
 
171
-def _extract_instruction_probes(dlm_section: object) -> list[SectionProbe]:
172
-    """Pull (Q, A) pairs out of a dlm INSTRUCTION section.
198
+def _parse_instruction(content: str, *, section_id: str) -> list[SectionProbe]:
199
+    """Pull (Q, A) pairs out of a dlm INSTRUCTION section body.
173200
 
174
-    dlm's Section carries its parsed Q/A as ``probes`` or ``qa`` depending
175
-    on version. We read the first non-empty one and build
176
-    :class:`SectionProbe` records defensively.
201
+    Delegates to dlm's own ``parse_instruction_body`` so syntax additions
202
+    land in sway without code changes here. Falls back to an empty list
203
+    on parse errors — the probe will fail gracefully.
177204
     """
178
-    raw_probes = getattr(dlm_section, "probes", None) or getattr(dlm_section, "qa", None)
179
-    if not raw_probes:
205
+    try:
206
+        from dlm.data.instruction_parser import parse_instruction_body
207
+    except ImportError:
208
+        return []
209
+    try:
210
+        pairs = parse_instruction_body(content, section_id=section_id)
211
+    except Exception:  # noqa: BLE001 — dlm raises InstructionParseError
180212
         return []
181213
     out: list[SectionProbe] = []
182
-    for rp in raw_probes:
183
-        q = str(getattr(rp, "prompt", getattr(rp, "question", "")))
184
-        a = str(getattr(rp, "gold", getattr(rp, "answer", "")))
214
+    for p in pairs:
215
+        q = getattr(p, "question", getattr(p, "prompt", ""))
216
+        a = getattr(p, "answer", getattr(p, "gold", ""))
185217
         if q and a:
186
-            out.append(SectionProbe(prompt=q, gold=a))
218
+            out.append(SectionProbe(prompt=str(q), gold=str(a)))
187219
     return out
188220
 
189221
 
190
-def _extract_preference_triples(dlm_section: object) -> list[SectionPreference]:
191
-    """Pull (prompt, chosen, rejected) triples out of a dlm PREFERENCE section."""
192
-    raw = getattr(dlm_section, "preferences", None) or getattr(dlm_section, "triples", None)
193
-    if not raw:
222
+def _parse_preference(content: str, *, section_id: str) -> list[SectionPreference]:
223
+    """Pull (prompt, chosen, rejected) triples out of a PREFERENCE body."""
224
+    try:
225
+        from dlm.data.preference_parser import parse_preference_body
226
+    except ImportError:
227
+        return []
228
+    try:
229
+        triples = parse_preference_body(content, section_id=section_id)
230
+    except Exception:  # noqa: BLE001 — dlm raises PreferenceParseError
194231
         return []
195232
     out: list[SectionPreference] = []
196
-    for r in raw:
197
-        p = str(getattr(r, "prompt", ""))
198
-        c = str(getattr(r, "chosen", ""))
199
-        rej = str(getattr(r, "rejected", ""))
233
+    for t in triples:
234
+        p = str(getattr(t, "prompt", ""))
235
+        c = str(getattr(t, "chosen", ""))
236
+        rej = str(getattr(t, "rejected", ""))
200237
         if p and c and rej:
201238
             out.append(SectionPreference(prompt=p, chosen=c, rejected=rej))
202239
     return out
tests/unit/test_dlm_bridge.py — modified
@@ -25,26 +25,13 @@ def fake_dlm(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Path:
2525
     @dataclass
2626
     class _Frontmatter:
2727
         dlm_id: str = "01TESTULID"
28
-        base_model: str = "HuggingFaceTB/SmolLM2-135M-Instruct"
29
-
30
-    @dataclass
31
-    class _InstrProbe:
32
-        prompt: str
33
-        gold: str
34
-
35
-    @dataclass
36
-    class _PrefTriple:
37
-        prompt: str
38
-        chosen: str
39
-        rejected: str
28
+        base_model: str = "smollm2-135m"
4029
 
4130
     @dataclass
4231
     class _Section:
4332
         section_id: str
44
-        kind: str
33
+        type: str
4534
         content: str
46
-        probes: tuple[object, ...] = ()
47
-        preferences: tuple[object, ...] = ()
4835
         tag: str | None = None
4936
 
5037
     @dataclass
@@ -58,20 +45,18 @@ def fake_dlm(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Path:
5845
             sections=(
5946
                 _Section(
6047
                     section_id="prose-1",
61
-                    kind="PROSE",
48
+                    type="PROSE",
6249
                     content="This is a prose section with some information. Further detail follows.",
6350
                 ),
6451
                 _Section(
6552
                     section_id="instr-1",
66
-                    kind="INSTRUCTION",
67
-                    content="Q-A pairs",
68
-                    probes=(_InstrProbe("What is X?", "X is a concept"),),
53
+                    type="INSTRUCTION",
54
+                    content="### Q\nWhat is X?\n\n### A\nX is a concept\n",
6955
                 ),
7056
                 _Section(
7157
                     section_id="pref-1",
72
-                    kind="PREFERENCE",
73
-                    content="Prefs",
74
-                    preferences=(_PrefTriple("Which?", "good answer", "bad answer"),),
58
+                    type="PREFERENCE",
59
+                    content="chosen/rejected triple",
7560
                 ),
7661
             ),
7762
         )
@@ -94,20 +79,72 @@ def fake_dlm(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Path:
9479
         def __init__(self, path: Path) -> None:
9580
             self._p = path
9681
 
97
-        @classmethod
98
-        def for_dlm(cls, _dlm_id: str) -> _StorePath:
99
-            return cls(adapter_dir)
100
-
10182
         def resolve_current_adapter(self) -> Path:
10283
             return self._p
10384
 
85
+    def _for_dlm(_dlm_id: str) -> _StorePath:
86
+        return _StorePath(adapter_dir)
87
+
10488
     dlm_store_paths.StorePath = _StorePath  # type: ignore[attr-defined]
89
+    dlm_store_paths.for_dlm = _for_dlm  # type: ignore[attr-defined]
90
+
91
+    # Fake base-model resolver — returns a stub with an ``hf_id`` attribute.
92
+    dlm_base = types.ModuleType("dlm.base_models")
93
+
94
+    @dataclass
95
+    class _BaseSpec:
96
+        hf_id: str
97
+        key: str
98
+
99
+    def _resolve(key: str) -> _BaseSpec:
100
+        return _BaseSpec(hf_id="HuggingFaceTB/SmolLM2-135M-Instruct", key=key)
101
+
102
+    dlm_base.resolve = _resolve  # type: ignore[attr-defined]
103
+
104
+    # Fake instruction / preference parsers.
105
+    dlm_data = types.ModuleType("dlm.data")
106
+    dlm_data_instr = types.ModuleType("dlm.data.instruction_parser")
107
+    dlm_data_pref = types.ModuleType("dlm.data.preference_parser")
108
+
109
+    @dataclass
110
+    class _QAPair:
111
+        question: str
112
+        answer: str
113
+
114
+    @dataclass
115
+    class _Triple:
116
+        prompt: str
117
+        chosen: str
118
+        rejected: str
119
+
120
+    def _parse_instr(body: str, *, section_id: str) -> list[_QAPair]:
121
+        del section_id
122
+        out: list[_QAPair] = []
123
+        parts = body.split("### Q")
124
+        for part in parts[1:]:
125
+            q_block, _, a_block = part.partition("### A")
126
+            q = q_block.strip()
127
+            a = a_block.strip()
128
+            if q and a:
129
+                out.append(_QAPair(question=q, answer=a))
130
+        return out
131
+
132
+    def _parse_pref(body: str, *, section_id: str) -> list[_Triple]:
133
+        del body, section_id
134
+        return [_Triple(prompt="Which?", chosen="good answer", rejected="bad answer")]
135
+
136
+    dlm_data_instr.parse_instruction_body = _parse_instr  # type: ignore[attr-defined]
137
+    dlm_data_pref.parse_preference_body = _parse_pref  # type: ignore[attr-defined]
105138
 
106139
     monkeypatch.setitem(sys.modules, "dlm", dlm_pkg)
107140
     monkeypatch.setitem(sys.modules, "dlm.doc", dlm_doc)
108141
     monkeypatch.setitem(sys.modules, "dlm.doc.parser", dlm_doc_parser)
109142
     monkeypatch.setitem(sys.modules, "dlm.store", dlm_store)
110143
     monkeypatch.setitem(sys.modules, "dlm.store.paths", dlm_store_paths)
144
+    monkeypatch.setitem(sys.modules, "dlm.base_models", dlm_base)
145
+    monkeypatch.setitem(sys.modules, "dlm.data", dlm_data)
146
+    monkeypatch.setitem(sys.modules, "dlm.data.instruction_parser", dlm_data_instr)
147
+    monkeypatch.setitem(sys.modules, "dlm.data.preference_parser", dlm_data_pref)
111148
 
112149
     # Return a path to a fake .dlm file (the parser won't actually read it).
113150
     dlm_file = tmp_path / "doc.dlm"