Python · 4226 bytes Raw Blame History
1 """Direct coverage for modality dispatch wrapper modules."""
2
3 from __future__ import annotations
4
5 from types import SimpleNamespace
6 from unittest.mock import patch
7
8 import pytest
9
10 from dlm.base_models import BaseModelSpec
11 from dlm.modality.audio import AudioLanguageModality
12 from dlm.modality.errors import UnknownModalityError
13 from dlm.modality.registry import TextModality, _unknown, modality_for
14 from dlm.modality.text import TextModality as ReexportedTextModality
15 from dlm.modality.vl import VisionLanguageModality
16
17
def _minimal_text_spec(*, modality: str = "text") -> BaseModelSpec:
    """Build a small, fully populated ``BaseModelSpec`` for dispatch tests.

    Every field holds a fixed placeholder value except ``modality``, which
    callers may override.

    Args:
        modality: Value written into the spec's ``modality`` field.

    Returns:
        The object returned by ``BaseModelSpec.model_validate`` for the
        placeholder payload.
    """
    payload = {
        "key": "demo-1b",
        "hf_id": "org/demo-1b",
        # Placeholder revision string of 40 hex characters.
        "revision": "0123456789abcdef0123456789abcdef01234567",
        "architecture": "DemoForCausalLM",
        "params": 1_000_000_000,
        "target_modules": ["q_proj", "v_proj"],
        "template": "chatml",
        "gguf_arch": "demo",
        "tokenizer_pre": "demo",
        "license_spdx": "Apache-2.0",
        "license_url": None,
        "requires_acceptance": False,
        "redistributable": True,
        "size_gb_fp16": 2.0,
        "context_length": 4096,
        "recommended_seq_len": 2048,
        "modality": modality,
    }
    return BaseModelSpec.model_validate(payload)
40
41
def test_text_module_reexports_text_modality() -> None:
    """``dlm.modality.text`` exposes the very same class as the registry."""
    assert TextModality is ReexportedTextModality
44
45
def test_text_dispatch_defaults_are_noops() -> None:
    """Both hooks on ``TextModality`` return ``None`` for a text spec."""
    text_dispatch = TextModality()

    processor = text_dispatch.load_processor(_minimal_text_spec())
    assert processor is None

    exported = text_dispatch.dispatch_export(
        # A bare object() stands in for a real store.
        store=object(),
        spec=_minimal_text_spec(),
        adapter_name=None,
        quant=None,
        merged=False,
        adapter_mix_raw=None,
    )
    assert exported is None
61
62
def test_unknown_error_contains_registration_hint() -> None:
    """``_unknown`` builds an ``UnknownModalityError`` that tells users how to fix it."""
    error = _unknown("mystery")

    assert isinstance(error, UnknownModalityError)

    message = str(error)
    assert "Register a ModalityDispatch subclass" in message
67
68
def test_modality_for_unknown_modality_raises() -> None:
    """``modality_for`` rejects an unregistered modality and names it in the error."""
    # Only the ``modality`` attribute is supplied on this stand-in spec.
    fake_spec = SimpleNamespace(modality="mystery")

    with pytest.raises(UnknownModalityError, match="mystery"):
        modality_for(fake_spec)
72
73
def test_audio_modality_loads_processor_and_dispatches_export() -> None:
    """``AudioLanguageModality`` delegates to the shared loader and the audio exporter.

    Both collaborators are patched at their defining modules. The test checks
    that each one is invoked exactly once with the arguments it was handed,
    and that its return value is passed back unchanged.
    """
    audio_dispatch = AudioLanguageModality()
    fake_spec = SimpleNamespace()
    # The same keyword set is used to call the wrapper and to verify what the
    # patched exporter received, so the two cannot drift apart.
    export_kwargs = {
        "store": "store",
        "spec": fake_spec,
        "adapter_name": "adapter",
        "quant": "q4_k_m",
        "merged": False,
        "adapter_mix_raw": "named",
    }

    with (
        patch("dlm.train.loader.load_processor", return_value="processor") as mock_loader,
        patch("dlm.export.dispatch.dispatch_audio_export", return_value="audio-export") as mock_export,
    ):
        loaded = audio_dispatch.load_processor(fake_spec)
        outcome = audio_dispatch.dispatch_export(**export_kwargs)

    assert loaded == "processor"
    mock_loader.assert_called_once_with(fake_spec)
    assert outcome == "audio-export"
    mock_export.assert_called_once_with(**export_kwargs)
103
104
def test_vl_modality_loads_processor_and_dispatches_export() -> None:
    """``VisionLanguageModality`` delegates to the shared loader and the VL exporter.

    Unlike the audio case, this call also passes ``gguf_emission_context``.
    The test checks that the same context object reaches the patched VL
    exporter.
    """
    vl_dispatch = VisionLanguageModality()
    fake_spec = SimpleNamespace()
    emission_context = {"emit": "gguf"}
    # One keyword set serves both the wrapper call and the exporter check.
    export_kwargs = {
        "store": "store",
        "spec": fake_spec,
        "adapter_name": "adapter",
        "quant": "q8_0",
        "merged": True,
        "adapter_mix_raw": None,
        "gguf_emission_context": emission_context,
    }

    with (
        patch("dlm.train.loader.load_processor", return_value="processor") as mock_loader,
        patch("dlm.export.dispatch.dispatch_vl_export", return_value="vl-export") as mock_export,
    ):
        loaded = vl_dispatch.load_processor(fake_spec)
        outcome = vl_dispatch.dispatch_export(**export_kwargs)

    assert loaded == "processor"
    mock_loader.assert_called_once_with(fake_spec)
    assert outcome == "vl-export"
    mock_export.assert_called_once_with(**export_kwargs)