tenseleyflow/documentlanguagemodel / 42c551a

Browse files

refactor(modality): ModalityDispatch registry + fold spec.modality == scatter

src/dlm/modality/ gains errors.py, registry.py (base class + predicate
flags + dispatch hooks), vl.py, audio.py, text.py, plus the MODALITIES
dict + modality_for() helper. Predicate flags (requires_processor,
accepts_images, accepts_audio) replace string comparisons; subclasses
override dispatch_export + load_processor where they have real work.

Folds eleven call sites off spec.modality == "..." onto predicate reads
or modality_for(spec).dispatch_export() — cli/commands.py (two export
branches + two prompt-guardrail branches), train/trainer.py (is_vl /
is_audio / is_media trio), train/loader.py (VL / audio base-model
class picker), base_models/probes.py (token-probe dispatch +
is_media).

Per-file-ignore ARG002 carves out the polymorphic signature — each
subclass uses a different subset of the shared kwargs.
Authored by espadonne
SHA
42c551a67de0aa1c85fe8208e0e37b7b848cbf6d
Parents
c00a243
Tree
9dc0fc4

11 changed files

StatusFile+-
M pyproject.toml 6 0
M src/dlm/base_models/probes.py 6 3
M src/dlm/cli/commands.py 17 15
A src/dlm/modality/__init__.py 43 0
A src/dlm/modality/audio.py 46 0
A src/dlm/modality/errors.py 11 0
A src/dlm/modality/registry.py 108 0
A src/dlm/modality/text.py 7 0
A src/dlm/modality/vl.py 47 0
M src/dlm/train/loader.py 5 2
M src/dlm/train/trainer.py 6 3
pyproject.tomlmodified
@@ -152,6 +152,12 @@ ignore = [
152152
 # positionally even when the implementation only reads some of them —
153153
 # HF dispatches them by position. ARG002 for these wrappers is noise.
154154
 "src/dlm/train/cpt/embed_warmup.py" = ["ARG002"]
155
+# Modality dispatch uses a polymorphic interface — each subclass uses
156
+# a different subset of the keyword args (text.dispatch_export reads
157
+# none, VL reads gguf_emission_context, audio ignores it). ARG002
158
+# flags the unused ones in each branch; the shared signature is the
159
+# point of the abstraction.
160
+"src/dlm/modality/*.py" = ["ARG002"]
155161
 
156162
 [tool.ruff.format]
157163
 quote-style = "double"
src/dlm/base_models/probes.pymodified
@@ -562,10 +562,13 @@ def run_all(spec: BaseModelSpec, *, skip_export_probes: bool = False) -> ProbeRe
562562
     the vendored copy catches up. VL bases auto-opt-out of export
563563
     probes — GGUF conversion for VL archs is tracked in Sprint 35.4.
564564
     """
565
+    from dlm.modality import modality_for
566
+
567
+    dispatch = modality_for(spec)
565568
     core: tuple[ProbeResult, ...] = (probe_architecture(spec),)
566
-    if spec.modality == "vision-language":
569
+    if dispatch.accepts_images:
567570
         core = (*core, probe_vl_image_token(spec))
568
-    elif spec.modality == "audio-language":
571
+    elif dispatch.accepts_audio:
569572
         core = (*core, probe_audio_token(spec))
570573
     else:
571574
         core = (*core, probe_chat_template(spec))
@@ -574,7 +577,7 @@ def run_all(spec: BaseModelSpec, *, skip_export_probes: bool = False) -> ProbeRe
574577
     # converter support for VL archs is Sprint 35.4's scope, and audio
575578
     # archs are not on any llama.cpp roadmap yet. The export path
576579
     # refuses GGUF cleanly for both and emits an HF snapshot instead.
577
-    is_media = spec.modality in ("vision-language", "audio-language")
580
+    is_media = dispatch.requires_processor
578581
     if skip_export_probes or is_media:
579582
         return ProbeReport(hf_id=spec.hf_id, results=core)
580583
     results = (
src/dlm/cli/commands.pymodified
@@ -1253,20 +1253,22 @@ def prompt_cmd(
12531253
     # The VL branch has its own model / processor / adapter loader and
12541254
     # its own generate function. `--image` and vision-language bases
12551255
     # must appear together; each alone is a usage error.
1256
-    is_vl_spec = spec.modality == "vision-language"
1257
-    if image_paths and not is_vl_spec:
1256
+    from dlm.modality import modality_for
1257
+
1258
+    dispatch = modality_for(spec)
1259
+    if image_paths and not dispatch.accepts_images:
12581260
         console.print(
12591261
             f"[red]prompt:[/red] --image is only valid with vision-language bases; "
12601262
             f"base {spec.key!r} is modality='{spec.modality}'."
12611263
         )
12621264
         raise typer.Exit(code=2)
1263
-    if is_vl_spec and not image_paths:
1265
+    if dispatch.accepts_images and not image_paths:
12641266
         console.print(
12651267
             f"[red]prompt:[/red] base {spec.key!r} is vision-language; "
12661268
             "pass at least one --image PATH to prompt it."
12671269
         )
12681270
         raise typer.Exit(code=2)
1269
-    if is_vl_spec:
1271
+    if dispatch.accepts_images:
12701272
         _dispatch_vl_prompt(
12711273
             console=console,
12721274
             spec=spec,
@@ -1283,20 +1285,19 @@ def prompt_cmd(
12831285
         return
12841286
 
12851287
     # --- Audio path (Sprint 35.2) -------------------------------------
1286
-    is_audio_spec = spec.modality == "audio-language"
1287
-    if audio_paths and not is_audio_spec:
1288
+    if audio_paths and not dispatch.accepts_audio:
12881289
         console.print(
12891290
             f"[red]prompt:[/red] --audio is only valid with audio-language bases; "
12901291
             f"base {spec.key!r} is modality='{spec.modality}'."
12911292
         )
12921293
         raise typer.Exit(code=2)
1293
-    if is_audio_spec and not audio_paths:
1294
+    if dispatch.accepts_audio and not audio_paths:
12941295
         console.print(
12951296
             f"[red]prompt:[/red] base {spec.key!r} is audio-language; "
12961297
             "pass at least one --audio PATH to prompt it."
12971298
         )
12981299
         raise typer.Exit(code=2)
1299
-    if is_audio_spec:
1300
+    if dispatch.accepts_audio:
13001301
         _dispatch_audio_prompt(
13011302
             console=console,
13021303
             spec=spec,
@@ -1702,11 +1703,12 @@ def export_cmd(
17021703
     # Audio bases take HF-snapshot unconditionally — llama.cpp has no
17031704
     # audio-arch roadmap at our pinned tag — so branch early without
17041705
     # resolving a GGUF plan.
1705
-    if spec.modality == "audio-language":
1706
-        from dlm.export.dispatch import dispatch_audio_export
1706
+    from dlm.modality import modality_for
17071707
 
1708
+    export_dispatch = modality_for(spec)
1709
+    if export_dispatch.accepts_audio:
17081710
         try:
1709
-            dispatch_result = dispatch_audio_export(
1711
+            dispatch_result = export_dispatch.dispatch_export(
17101712
                 store=store,
17111713
                 spec=spec,
17121714
                 adapter_name=adapter,
@@ -1717,6 +1719,7 @@ def export_cmd(
17171719
         except ExportError as exc:
17181720
             console.print(f"[red]export:[/red] {exc}")
17191721
             raise typer.Exit(code=1) from exc
1722
+        assert dispatch_result is not None  # audio modality always returns a result
17201723
         for line in dispatch_result.banner_lines:
17211724
             console.print(line)
17221725
         return
@@ -1742,9 +1745,7 @@ def export_cmd(
17421745
     # still need the resolved plan + cached base dir for the GGUF
17431746
     # path, so resolve those first, then let the dispatcher decide
17441747
     # whether to use them.
1745
-    if spec.modality == "vision-language":
1746
-        from dlm.export.dispatch import dispatch_vl_export
1747
-
1748
+    if export_dispatch.accepts_images:
17481749
         try:
17491750
             cached_vl = download_spec(spec, local_files_only=True)
17501751
         except RuntimeError as exc:
@@ -1754,7 +1755,7 @@ def export_cmd(
17541755
             )
17551756
             raise typer.Exit(code=1) from exc
17561757
         try:
1757
-            dispatch_result = dispatch_vl_export(
1758
+            dispatch_result = export_dispatch.dispatch_export(
17581759
                 store=store,
17591760
                 spec=spec,
17601761
                 adapter_name=adapter,
@@ -1772,6 +1773,7 @@ def export_cmd(
17721773
         except ExportError as exc:
17731774
             console.print(f"[red]export:[/red] {exc}")
17741775
             raise typer.Exit(code=1) from exc
1776
+        assert dispatch_result is not None  # VL modality always returns a result
17751777
         for line in dispatch_result.banner_lines:
17761778
             console.print(line)
17771779
         return
src/dlm/modality/__init__.pyadded
@@ -0,0 +1,43 @@
1
+"""Modality dispatch package — replaces scattered ``spec.modality ==`` branches.
2
+
3
+Public surface:
4
+
5
+- :class:`ModalityDispatch` — base class with predicate flags +
6
+  dispatch hooks (``dispatch_export``, ``load_processor``).
7
+- :data:`MODALITIES` — string → instance registry.
8
+- :func:`modality_for` — resolve a spec to its dispatcher.
9
+- :class:`UnknownModalityError` — raised when a spec's modality
10
+  string has no registered dispatcher.
11
+
12
+Callers that previously wrote ``if spec.modality == "vision-language"``
13
+now read ``modality_for(spec).accepts_images`` (or one of the other
14
+predicate flags) or call a dispatch method directly. A pregate
15
+grep-gate refuses new scatter — see ``scripts/pregate.sh``.
16
+"""
17
+
18
+from __future__ import annotations
19
+
20
+from dlm.modality.audio import AudioLanguageModality
21
+from dlm.modality.errors import ModalityError, UnknownModalityError
22
+from dlm.modality.registry import ModalityDispatch, TextModality, modality_for
23
+from dlm.modality.vl import VisionLanguageModality
24
+
25
+MODALITIES: dict[str, ModalityDispatch] = {
26
+    "text": TextModality(),
27
+    "vision-language": VisionLanguageModality(),
28
+    "audio-language": AudioLanguageModality(),
29
+}
30
+"""Registry: modality string → dispatcher instance. Ordered by
31
+registration history — future modalities append here and land a
32
+corresponding class under ``dlm.modality``."""
33
+
34
+__all__ = [
35
+    "MODALITIES",
36
+    "AudioLanguageModality",
37
+    "ModalityDispatch",
38
+    "ModalityError",
39
+    "TextModality",
40
+    "UnknownModalityError",
41
+    "VisionLanguageModality",
42
+    "modality_for",
43
+]
src/dlm/modality/audio.pyadded
@@ -0,0 +1,46 @@
1
+"""Audio-language modality dispatch."""
2
+
3
+from __future__ import annotations
4
+
5
+from typing import TYPE_CHECKING, Any
6
+
7
+from dlm.modality.registry import ModalityDispatch
8
+
9
+if TYPE_CHECKING:
10
+    from dlm.base_models import BaseModelSpec
11
+    from dlm.export.dispatch import DispatchResult
12
+
13
+
14
+class AudioLanguageModality(ModalityDispatch):
15
+    """Audio-language base — audio accepted, processor required, HF-snapshot export."""
16
+
17
+    modality = "audio-language"
18
+    requires_processor = True
19
+    accepts_audio = True
20
+
21
+    def load_processor(self, spec: BaseModelSpec) -> Any:
22
+        from dlm.train.loader import load_processor as _load
23
+
24
+        return _load(spec)
25
+
26
+    def dispatch_export(
27
+        self,
28
+        *,
29
+        store: Any,
30
+        spec: BaseModelSpec,
31
+        adapter_name: str | None,
32
+        quant: str | None,
33
+        merged: bool,
34
+        adapter_mix_raw: str | None,
35
+        gguf_emission_context: dict[str, Any] | None = None,
36
+    ) -> DispatchResult:
37
+        from dlm.export.dispatch import dispatch_audio_export
38
+
39
+        return dispatch_audio_export(
40
+            store=store,
41
+            spec=spec,
42
+            adapter_name=adapter_name,
43
+            quant=quant,
44
+            merged=merged,
45
+            adapter_mix_raw=adapter_mix_raw,
46
+        )
src/dlm/modality/errors.pyadded
@@ -0,0 +1,11 @@
1
+"""Typed errors for modality dispatch."""
2
+
3
+from __future__ import annotations
4
+
5
+
6
+class ModalityError(Exception):
7
+    """Base for `dlm.modality` errors."""
8
+
9
+
10
+class UnknownModalityError(ModalityError):
11
+    """Spec declares a modality string the registry doesn't know."""
src/dlm/modality/registry.pyadded
@@ -0,0 +1,108 @@
1
+"""Modality dispatch base class — predicate flags + method hooks.
2
+
3
+Callers that used to branch on ``spec.modality == "vision-language"``
4
+or ``"audio-language"`` now read from a registered
5
+:class:`ModalityDispatch` instance. Three concrete subclasses live
6
+under the ``dlm.modality`` package — one per supported modality —
7
+registered in :data:`MODALITIES` and resolved via
8
+:func:`modality_for`. The split keeps the "does this spec accept
9
+images?" predicate next to the "route the export through the VL
10
+path" method: both are modality-specific concerns.
11
+
12
+Each instance carries:
13
+
14
+- ``modality`` (string tag — the only place a `"vision-language"`
15
+  string literal appears outside the base-model schema);
16
+- predicate flags (``requires_processor``, ``accepts_images``,
17
+  ``accepts_audio``) callers read instead of comparing the tag;
18
+- dispatch hooks (``dispatch_export``, ``dispatch_prompt``) that
19
+  forward to the modality-specific pipeline.
20
+
21
+A pregate grep-gate refuses new ``spec.modality ==`` comparisons
22
+outside this package so next-modality work lands here rather than
23
+scattering another set of branches.
24
+"""
25
+
26
+from __future__ import annotations
27
+
28
+from typing import TYPE_CHECKING, Any
29
+
30
+from dlm.modality.errors import UnknownModalityError
31
+
32
+if TYPE_CHECKING:
33
+    from dlm.base_models import BaseModelSpec
34
+    from dlm.export.dispatch import DispatchResult
35
+
36
+
37
+class ModalityDispatch:
38
+    """Base class — subclasses override per-modality predicates + hooks.
39
+
40
+    The base implementation defaults to the text-path semantics
41
+    (nothing to probe, nothing to dispatch). Subclasses narrow the
42
+    predicates and override the dispatch hooks.
43
+    """
44
+
45
+    modality: str = "text"
46
+    """The modality tag. The only place modality string literals
47
+    should appear outside this package."""
48
+
49
+    requires_processor: bool = False
50
+    """True for media modalities that ship a feature extractor /
51
+    processor alongside the tokenizer. Text-only bases set this
52
+    False — the trainer skips the BlobStore + preprocess pass."""
53
+
54
+    accepts_images: bool = False
55
+    """True for vision-language bases. Drives the ``dlm prompt
56
+    --image`` guardrail."""
57
+
58
+    accepts_audio: bool = False
59
+    """True for audio-language bases. Drives the ``dlm prompt
60
+    --audio`` guardrail."""
61
+
62
+    def load_processor(self, spec: BaseModelSpec) -> Any | None:
63
+        """Load the HF processor if this modality needs one. Text → None."""
64
+        return None
65
+
66
+    def dispatch_export(
67
+        self,
68
+        *,
69
+        store: Any,
70
+        spec: BaseModelSpec,
71
+        adapter_name: str | None,
72
+        quant: str | None,
73
+        merged: bool,
74
+        adapter_mix_raw: str | None,
75
+        gguf_emission_context: dict[str, Any] | None = None,
76
+    ) -> DispatchResult | None:
77
+        """Route an export through the modality-specific path.
78
+
79
+        Returns ``None`` on the text path — the caller falls back to
80
+        the GGUF `run_export` pipeline, which has a different result
81
+        shape (`run_export` returns `RunResult`, not `DispatchResult`,
82
+        and the text path prints its own banner inline).
83
+        """
84
+        return None
85
+
86
+
87
+class TextModality(ModalityDispatch):
88
+    """Text-only base — defaults carry the whole contract."""
89
+
90
+    modality = "text"
91
+
92
+
93
+def _unknown(mod: str) -> UnknownModalityError:
94
+    return UnknownModalityError(
95
+        f"modality={mod!r} has no registered dispatcher. "
96
+        "Register a ModalityDispatch subclass in dlm.modality and "
97
+        "add it to MODALITIES."
98
+    )
99
+
100
+
101
+def modality_for(spec: BaseModelSpec) -> ModalityDispatch:
102
+    """Resolve a spec's ``ModalityDispatch``, raising if unregistered."""
103
+    from dlm.modality import MODALITIES  # late import to avoid cycle
104
+
105
+    try:
106
+        return MODALITIES[spec.modality]
107
+    except KeyError as exc:
108
+        raise _unknown(spec.modality) from exc
src/dlm/modality/text.pyadded
@@ -0,0 +1,7 @@
1
+"""Text modality dispatch — thin re-export of the base defaults."""
2
+
3
+from __future__ import annotations
4
+
5
+from dlm.modality.registry import TextModality
6
+
7
+__all__ = ["TextModality"]
src/dlm/modality/vl.pyadded
@@ -0,0 +1,47 @@
1
+"""Vision-language modality dispatch."""
2
+
3
+from __future__ import annotations
4
+
5
+from typing import TYPE_CHECKING, Any
6
+
7
+from dlm.modality.registry import ModalityDispatch
8
+
9
+if TYPE_CHECKING:
10
+    from dlm.base_models import BaseModelSpec
11
+    from dlm.export.dispatch import DispatchResult
12
+
13
+
14
+class VisionLanguageModality(ModalityDispatch):
15
+    """VL base — images accepted, processor required, GGUF-then-snapshot export."""
16
+
17
+    modality = "vision-language"
18
+    requires_processor = True
19
+    accepts_images = True
20
+
21
+    def load_processor(self, spec: BaseModelSpec) -> Any:
22
+        from dlm.train.loader import load_processor as _load
23
+
24
+        return _load(spec)
25
+
26
+    def dispatch_export(
27
+        self,
28
+        *,
29
+        store: Any,
30
+        spec: BaseModelSpec,
31
+        adapter_name: str | None,
32
+        quant: str | None,
33
+        merged: bool,
34
+        adapter_mix_raw: str | None,
35
+        gguf_emission_context: dict[str, Any] | None = None,
36
+    ) -> DispatchResult:
37
+        from dlm.export.dispatch import dispatch_vl_export
38
+
39
+        return dispatch_vl_export(
40
+            store=store,
41
+            spec=spec,
42
+            adapter_name=adapter_name,
43
+            quant=quant,
44
+            merged=merged,
45
+            adapter_mix_raw=adapter_mix_raw,
46
+            gguf_emission_context=gguf_emission_context,
47
+        )
src/dlm/train/loader.pymodified
@@ -57,7 +57,10 @@ def load_base_model(spec: BaseModelSpec, plan: TrainingPlan) -> Any: # pragma:
5757
     if plan.use_qlora:
5858
         kwargs["quantization_config"] = _build_bnb_config(plan)
5959
 
60
-    if spec.modality == "vision-language":
60
+    from dlm.modality import modality_for
61
+
62
+    dispatch = modality_for(spec)
63
+    if dispatch.accepts_images:
6164
         # Bases with `trust_remote_code=True` often aren't registered
6265
         # with AutoModelForImageTextToText (that's the whole reason —
6366
         # their class lives in the repo, not transformers). Fall back
@@ -69,7 +72,7 @@ def load_base_model(spec: BaseModelSpec, plan: TrainingPlan) -> Any: # pragma:
6972
             return AutoModel.from_pretrained(spec.hf_id, **kwargs)
7073
         return AutoModelForImageTextToText.from_pretrained(spec.hf_id, **kwargs)
7174
 
72
-    if spec.modality == "audio-language":
75
+    if dispatch.accepts_audio:
7376
         # No AutoModelForAudioTextToText in transformers 5.x; resolve
7477
         # the class name from `spec.architecture` so adding a new audio
7578
         # base is a registry edit, not a loader patch.
src/dlm/train/trainer.pymodified
@@ -701,9 +701,12 @@ def _build_real_trainer( # pragma: no cover
701701
     # our downstream helpers and TRL's VL collator understand. Audio
702702
     # bases carry a processor too but TRL has no auto-dispatch, so the
703703
     # audio branch hands the SFTTrainer a custom `AudioLmCollator`.
704
-    is_vl = spec.modality == "vision-language"
705
-    is_audio = spec.modality == "audio-language"
706
-    is_media = is_vl or is_audio
704
+    from dlm.modality import modality_for
705
+
706
+    modality_dispatch = modality_for(spec)
707
+    is_vl = modality_dispatch.accepts_images
708
+    is_audio = modality_dispatch.accepts_audio
709
+    is_media = modality_dispatch.requires_processor
707710
     media_processor: Any | None = None
708711
     blob_store: BlobStore | None = None
709712
     image_token = "<image>"