| 1 |
"""Direct coverage for modality dispatch wrapper modules.""" |
| 2 |
|
| 3 |
from __future__ import annotations |
| 4 |
|
| 5 |
from types import SimpleNamespace |
| 6 |
from unittest.mock import patch |
| 7 |
|
| 8 |
import pytest |
| 9 |
|
| 10 |
from dlm.base_models import BaseModelSpec |
| 11 |
from dlm.modality.audio import AudioLanguageModality |
| 12 |
from dlm.modality.errors import UnknownModalityError |
| 13 |
from dlm.modality.registry import TextModality, _unknown, modality_for |
| 14 |
from dlm.modality.text import TextModality as ReexportedTextModality |
| 15 |
from dlm.modality.vl import VisionLanguageModality |
| 16 |
|
| 17 |
|
| 18 |
def _minimal_text_spec(*, modality: str = "text") -> BaseModelSpec:
    """Return a validated ``BaseModelSpec`` filled with hard-coded demo values.

    Only ``modality`` is configurable (keyword-only, default ``"text"``);
    every other field is a fixed literal so tests get a reproducible spec.
    """
    payload = {
        "key": "demo-1b",
        "hf_id": "org/demo-1b",
        "revision": "0123456789abcdef0123456789abcdef01234567",
        "architecture": "DemoForCausalLM",
        "params": 1_000_000_000,
        "target_modules": ["q_proj", "v_proj"],
        "template": "chatml",
        "gguf_arch": "demo",
        "tokenizer_pre": "demo",
        "license_spdx": "Apache-2.0",
        "license_url": None,
        "requires_acceptance": False,
        "redistributable": True,
        "size_gb_fp16": 2.0,
        "context_length": 4096,
        "recommended_seq_len": 2048,
        "modality": modality,
    }
    return BaseModelSpec.model_validate(payload)
| 40 |
|
| 41 |
|
| 42 |
def test_text_module_reexports_text_modality() -> None:
    """``dlm.modality.text.TextModality`` is the very object defined in the registry."""
    assert TextModality is ReexportedTextModality
| 44 |
|
| 45 |
|
| 46 |
def test_text_dispatch_defaults_are_noops() -> None:
    """Text dispatch's ``load_processor`` and ``dispatch_export`` both return ``None``."""
    dispatch = TextModality()
    spec = _minimal_text_spec()

    processor = dispatch.load_processor(spec)
    assert processor is None

    exported = dispatch.dispatch_export(
        store=object(),
        spec=spec,
        adapter_name=None,
        quant=None,
        merged=False,
        adapter_mix_raw=None,
    )
    assert exported is None
| 61 |
|
| 62 |
|
| 63 |
def test_unknown_error_contains_registration_hint() -> None:
    """``_unknown`` yields an ``UnknownModalityError`` whose text explains how to register."""
    error = _unknown("mystery")

    assert isinstance(error, UnknownModalityError)
    message = str(error)
    assert "Register a ModalityDispatch subclass" in message
| 67 |
|
| 68 |
|
| 69 |
def test_modality_for_unknown_modality_raises() -> None:
    """``modality_for`` raises for an unregistered modality and mentions its name."""
    # A bare namespace is enough: only the ``modality`` attribute is supplied.
    unregistered = SimpleNamespace(modality="mystery")

    with pytest.raises(UnknownModalityError, match="mystery"):
        modality_for(unregistered)
| 72 |
|
| 73 |
|
| 74 |
def test_audio_modality_loads_processor_and_dispatches_export() -> None:
    """Audio dispatch hands off to the shared loader and the audio exporter.

    Both collaborators are patched. The test checks that their return values
    come back unchanged and that each receives exactly the arguments given.
    """
    dispatch = AudioLanguageModality()
    spec = SimpleNamespace()
    # The same mapping drives the call and the expectation, so any argument
    # the wrapper drops or alters shows up as an assert_called_once_with miss.
    export_kwargs = {
        "store": "store",
        "spec": spec,
        "adapter_name": "adapter",
        "quant": "q4_k_m",
        "merged": False,
        "adapter_mix_raw": "named",
    }

    # NOTE(review): these targets are the defining modules, not
    # dlm.modality.audio; that only intercepts the calls if the modality
    # resolves the names at call time (e.g. a function-local import) - confirm.
    with (
        patch(
            "dlm.train.loader.load_processor",
            return_value="processor",
        ) as loader,
        patch(
            "dlm.export.dispatch.dispatch_audio_export",
            return_value="audio-export",
        ) as exporter,
    ):
        processor = dispatch.load_processor(spec)
        result = dispatch.dispatch_export(**export_kwargs)

    assert processor == "processor"
    loader.assert_called_once_with(spec)
    assert result == "audio-export"
    exporter.assert_called_once_with(**export_kwargs)
| 103 |
|
| 104 |
|
| 105 |
def test_vl_modality_loads_processor_and_dispatches_export() -> None:
    """Vision-language dispatch hands off to the shared loader and the VL exporter.

    Both collaborators are patched. The test checks that their return values
    come back unchanged and that every argument, including the
    ``gguf_emission_context`` mapping, is forwarded unchanged.
    """
    dispatch = VisionLanguageModality()
    spec = SimpleNamespace()
    context = {"emit": "gguf"}
    # One mapping serves as both the call arguments and the expected call.
    export_kwargs = {
        "store": "store",
        "spec": spec,
        "adapter_name": "adapter",
        "quant": "q8_0",
        "merged": True,
        "adapter_mix_raw": None,
        "gguf_emission_context": context,
    }

    # NOTE(review): these targets are the defining modules, not dlm.modality.vl;
    # that only intercepts the calls if the modality resolves the names at
    # call time (e.g. a function-local import) - confirm.
    with (
        patch(
            "dlm.train.loader.load_processor",
            return_value="processor",
        ) as loader,
        patch(
            "dlm.export.dispatch.dispatch_vl_export",
            return_value="vl-export",
        ) as exporter,
    ):
        processor = dispatch.load_processor(spec)
        result = dispatch.dispatch_export(**export_kwargs)

    assert processor == "processor"
    loader.assert_called_once_with(spec)
    assert result == "vl-export"
    exporter.assert_called_once_with(**export_kwargs)