Python · 13963 bytes Raw Blame History
1 """Audio HF-snapshot export (Sprint 35.2 T10) — manifest shape + layout.
2
3 Mirrors `test_vl_snapshot.py`. Covers:
4
5 - `run_audio_snapshot_export` refuses non-audio specs.
6 - Writes adapter + manifest + README to the right paths.
7 - Manifest carries export_target="hf_snapshot" + the base's
8 AudioPreprocessorPlan params.
9 - `verify_artifacts` round-trips.
10 - `load_audio_snapshot_manifest` deserializes what `_save_manifest` wrote.
11 """
12
13 from __future__ import annotations
14
15 import json
16 from pathlib import Path
17
18 import pytest
19
20 from dlm.base_models.schema import AudioPreprocessorPlan, BaseModelSpec
21 from dlm.export.audio_snapshot import (
22 AUDIO_SNAPSHOT_SUBDIR,
23 SNAPSHOT_MANIFEST_FILENAME,
24 SNAPSHOT_README_FILENAME,
25 AudioSnapshotManifest,
26 load_audio_snapshot_manifest,
27 run_audio_snapshot_export,
28 verify_artifacts,
29 )
30 from dlm.export.errors import ExportError, ExportManifestError
31 from dlm.store.paths import for_dlm
32
33 _VALID_ULID = "01KPMGSTNGSTTSTTSTTSTTSTVA"
34
35
def _audio_spec(**overrides: object) -> BaseModelSpec:
    """Return a `BaseModelSpec` for an audio-language test base.

    The defaults describe a Qwen2-Audio style model with an
    `AudioPreprocessorPlan` attached. Any keyword argument replaces the
    default field of the same name before the spec is constructed.
    """
    plan = AudioPreprocessorPlan(
        sample_rate=16_000,
        max_length_seconds=30.0,
        audio_token="<|AUDIO|>",
        num_audio_tokens=750,
    )
    defaults: dict[str, object] = {
        "key": "qwen2-audio-test",
        "hf_id": "Qwen/Qwen2-Audio-test",
        "revision": "c" * 40,
        "architecture": "Qwen2AudioForConditionalGeneration",
        "params": 8_400_000_000,
        "target_modules": ["q_proj"],
        "template": "qwen2-audio",
        "gguf_arch": "qwen2-audio",
        "tokenizer_pre": "qwen2",
        "license_spdx": "Apache-2.0",
        "redistributable": False,
        "size_gb_fp16": 15.0,
        "context_length": 8192,
        "recommended_seq_len": 2048,
        "modality": "audio-language",
        "audio_preprocessor_plan": plan,
    }
    # Later entries win, so caller-supplied overrides shadow the defaults.
    merged = {**defaults, **overrides}
    return BaseModelSpec(**merged)  # type: ignore[arg-type]
62
63
def _text_spec() -> BaseModelSpec:
    """Return a `BaseModelSpec` for a plain text (Llama-style) test base.

    No modality or audio preprocessor plan is supplied; the refusal tests
    pass this spec to the audio exporter and expect it to be rejected.
    """
    fields: dict[str, object] = {
        "key": "text-base",
        "hf_id": "org/text-base",
        "revision": "a" * 40,
        "architecture": "LlamaForCausalLM",
        "params": 1_000_000,
        "target_modules": ["q_proj"],
        "template": "chatml",
        "gguf_arch": "llama",
        "tokenizer_pre": "llama-bpe",
        "license_spdx": "Apache-2.0",
        "redistributable": True,
        "size_gb_fp16": 0.5,
        "context_length": 4096,
        "recommended_seq_len": 1024,
    }
    return BaseModelSpec(**fields)  # type: ignore[arg-type]
81
82
@pytest.fixture
def populated_store(tmp_path: Path):
    """Store rooted at `tmp_path` with a fake adapter at `adapter/versions/v0001/`.

    The version-1 adapter directory receives a small `adapter_config.json`
    and a placeholder `adapter_model.safetensors`, and is then registered
    as the store's current adapter.
    """
    store = for_dlm(_VALID_ULID, home=tmp_path)
    store.ensure_layout()

    version_dir = store.adapter_version(1)
    version_dir.mkdir(parents=True, exist_ok=True)

    config_file = version_dir / "adapter_config.json"
    weights_file = version_dir / "adapter_model.safetensors"
    config_file.write_text('{"r": 16}', encoding="utf-8")
    weights_file.write_bytes(b"fake audio adapter bytes")

    store.set_current_adapter(version_dir)
    return store
94
95
class TestRefusals:
    """Inputs that `run_audio_snapshot_export` must reject with `ExportError`."""

    def test_text_spec_refused(self, populated_store) -> None:
        """A text-only base spec is refused."""
        text_only = _text_spec()
        with pytest.raises(ExportError, match="only audio-language bases"):
            run_audio_snapshot_export(populated_store, text_only)

    def test_missing_audio_preprocessor_plan_refused(self, populated_store) -> None:
        """An audio spec whose preprocessor plan has been blanked is refused."""
        planless = _audio_spec()
        # Go through object.__setattr__ so the plan is cleared after
        # construction without involving the spec class's own __setattr__.
        object.__setattr__(planless, "audio_preprocessor_plan", None)
        with pytest.raises(ExportError, match="no audio_preprocessor_plan"):
            run_audio_snapshot_export(populated_store, planless)

    def test_missing_adapter_refused(self, tmp_path: Path) -> None:
        """A store with layout but no adapter set is refused."""
        empty_store = for_dlm(_VALID_ULID, home=tmp_path)
        empty_store.ensure_layout()
        spec = _audio_spec()
        with pytest.raises(ExportError, match="no current adapter"):
            run_audio_snapshot_export(empty_store, spec)
112
113
class TestSnapshotLayout:
    """Directory layout and file contents produced by the audio snapshot export."""

    def test_export_dir_under_exports_hf_audio_snapshot(self, populated_store) -> None:
        """The export directory is `<store.exports>/<AUDIO_SNAPSHOT_SUBDIR>`."""
        export_dir = run_audio_snapshot_export(populated_store, _audio_spec()).export_dir
        assert export_dir.name == AUDIO_SNAPSHOT_SUBDIR
        assert export_dir.parent == populated_store.exports

    def test_adapter_files_copied(self, populated_store) -> None:
        """Both adapter files land in the result's adapter directory."""
        adapter_dir = run_audio_snapshot_export(populated_store, _audio_spec()).adapter_dir
        assert (adapter_dir / "adapter_config.json").exists()
        copied_weights = (adapter_dir / "adapter_model.safetensors").read_bytes()
        assert copied_weights == b"fake audio adapter bytes"

    def test_manifest_and_readme_written(self, populated_store) -> None:
        """Manifest and README use the published filenames and exist on disk."""
        outcome = run_audio_snapshot_export(populated_store, _audio_spec())
        manifest_file = outcome.manifest_path
        assert manifest_file.name == SNAPSHOT_MANIFEST_FILENAME
        readme_file = outcome.readme_path
        assert readme_file.name == SNAPSHOT_README_FILENAME
        assert manifest_file.exists()
        assert readme_file.exists()

    def test_repeat_export_overwrites_adapter(self, populated_store) -> None:
        """A second export picks up adapter bytes changed since the first."""
        run_audio_snapshot_export(populated_store, _audio_spec())

        source_weights = populated_store.adapter_version(1) / "adapter_model.safetensors"
        source_weights.write_bytes(b"new bytes")

        second = run_audio_snapshot_export(populated_store, _audio_spec())
        exported_weights = second.adapter_dir / "adapter_model.safetensors"
        assert exported_weights.read_bytes() == b"new bytes"

    def test_vl_and_audio_snapshots_disjoint_subdirs(self, populated_store) -> None:
        """Audio snapshots do not share the `hf-snapshot` subdirectory.

        One store may end up holding exports from separate training runs
        against different bases; keeping the audio export in its own
        subdirectory stops those exports from overwriting each other.
        """
        subdir_name = run_audio_snapshot_export(populated_store, _audio_spec()).export_dir.name
        assert subdir_name != "hf-snapshot"
        assert subdir_name == AUDIO_SNAPSHOT_SUBDIR

    def test_named_adapter_export_uses_named_current_pointer(self, populated_store) -> None:
        """Passing `adapter_name` exports that adapter's current version."""
        podcast_dir = populated_store.adapter_version_for("podcast", 5)
        podcast_dir.mkdir(parents=True, exist_ok=True)
        (podcast_dir / "adapter_config.json").write_text('{"r": 32}', encoding="utf-8")
        (podcast_dir / "adapter_model.safetensors").write_bytes(b"named audio bytes")
        populated_store.set_current_adapter_for("podcast", podcast_dir)

        outcome = run_audio_snapshot_export(
            populated_store,
            _audio_spec(),
            adapter_name="podcast",
        )

        exported_weights = (outcome.adapter_dir / "adapter_model.safetensors").read_bytes()
        assert exported_weights == b"named audio bytes"
        recorded = load_audio_snapshot_manifest(outcome.export_dir)
        assert recorded.adapter_version == 5
        assert recorded.adapter_name == "podcast"

    def test_adapter_override_uses_provided_dir(self, populated_store, tmp_path: Path) -> None:
        """`adapter_path_override` supplies the exported adapter bytes."""
        override_dir = tmp_path / "merged-adapter"
        override_dir.mkdir()
        (override_dir / "adapter_model.safetensors").write_bytes(b"override audio bytes")

        outcome = run_audio_snapshot_export(
            populated_store,
            _audio_spec(),
            adapter_path_override=override_dir,
        )

        exported_weights = (outcome.adapter_dir / "adapter_model.safetensors").read_bytes()
        assert exported_weights == b"override audio bytes"
        # The manifest still reports version 1 even though the files came
        # from the override directory.
        recorded = load_audio_snapshot_manifest(outcome.export_dir)
        assert recorded.adapter_version == 1

    def test_missing_adapter_override_refused(self, populated_store, tmp_path: Path) -> None:
        """An override path that does not exist is rejected."""
        absent = tmp_path / "missing"
        with pytest.raises(ExportError, match="adapter_path_override .* does not exist"):
            run_audio_snapshot_export(
                populated_store,
                _audio_spec(),
                adapter_path_override=absent,
            )

    def test_processor_save_pretrained_writes_processor_artifact(self, populated_store) -> None:
        """Files written by `processor.save_pretrained` are exported and listed."""

        class _Processor:
            # Minimal stand-in: drops a single config file into the target dir.
            def save_pretrained(self, out_dir: str) -> None:
                config_path = Path(out_dir) / "processor_config.json"
                config_path.write_text("{}", encoding="utf-8")

        outcome = run_audio_snapshot_export(populated_store, _audio_spec(), processor=_Processor())

        assert (outcome.processor_dir / "processor_config.json").exists()
        recorded = load_audio_snapshot_manifest(outcome.export_dir)
        listed = {artifact.path for artifact in recorded.artifacts}
        assert "processor/processor_config.json" in listed

    def test_noncallable_processor_save_is_ignored(self, populated_store) -> None:
        """A non-callable `save_pretrained` leaves the processor dir empty."""

        class _Processor:
            save_pretrained = "not-callable"

        outcome = run_audio_snapshot_export(populated_store, _audio_spec(), processor=_Processor())

        processor_dir = outcome.processor_dir
        assert processor_dir.exists()
        assert list(processor_dir.iterdir()) == []
217
218
class TestManifestContent:
    """Fields recorded in the manifest written by the audio snapshot export."""

    @staticmethod
    def _export_then_load(store):
        """Run the default audio export on `store` and load its manifest back."""
        run_audio_snapshot_export(store, _audio_spec())
        return load_audio_snapshot_manifest(store.exports / AUDIO_SNAPSHOT_SUBDIR)

    def test_export_target_is_hf_snapshot(self, populated_store) -> None:
        """Manifest identifies the export target and modality."""
        recorded = self._export_then_load(populated_store)
        assert recorded.export_target == "hf_snapshot"
        assert recorded.modality == "audio-language"

    def test_base_pinned_in_manifest(self, populated_store) -> None:
        """Manifest carries the base model id, revision and architecture."""
        recorded = self._export_then_load(populated_store)
        assert recorded.base_model_hf_id == "Qwen/Qwen2-Audio-test"
        assert recorded.base_model_revision == "c" * 40
        assert recorded.base_model_architecture == "Qwen2AudioForConditionalGeneration"

    def test_preprocessor_params_pinned(self, populated_store) -> None:
        """Manifest mirrors the spec's `AudioPreprocessorPlan` values."""
        recorded = self._export_then_load(populated_store)
        assert recorded.audio_token == "<|AUDIO|>"
        assert recorded.num_audio_tokens == 750
        assert recorded.sample_rate == 16_000
        assert recorded.max_length_seconds == 30.0

    def test_adapter_version_recorded(self, populated_store) -> None:
        """Manifest records the exported adapter version."""
        recorded = self._export_then_load(populated_store)
        assert recorded.adapter_version == 1

    def test_adapter_artifacts_listed(self, populated_store) -> None:
        """Both adapter files appear in the manifest's artifact list."""
        recorded = self._export_then_load(populated_store)
        listed = {artifact.path for artifact in recorded.artifacts}
        assert "adapter/adapter_config.json" in listed
        assert "adapter/adapter_model.safetensors" in listed
252
253
class TestVerifyArtifacts:
    """`verify_artifacts` accepts a fresh export and flags altered ones."""

    @staticmethod
    def _fresh_export(store):
        """Export with the default audio spec; return `(export_dir, manifest)`."""
        run_audio_snapshot_export(store, _audio_spec())
        export_dir = store.exports / AUDIO_SNAPSHOT_SUBDIR
        return export_dir, load_audio_snapshot_manifest(export_dir)

    def test_pristine_snapshot_verifies(self, populated_store) -> None:
        """An untouched export verifies without raising."""
        export_dir, recorded = self._fresh_export(populated_store)
        verify_artifacts(export_dir, recorded)  # no raise

    def test_tampered_artifact_detected(self, populated_store) -> None:
        """Rewriting a declared artifact's bytes triggers a sha256 mismatch."""
        export_dir, recorded = self._fresh_export(populated_store)
        first_artifact = export_dir / recorded.artifacts[0].path
        first_artifact.write_bytes(b"tampered")
        with pytest.raises(ExportManifestError, match="sha256 mismatch"):
            verify_artifacts(export_dir, recorded)

    def test_missing_artifact_detected(self, populated_store) -> None:
        """Deleting a declared artifact is reported as missing."""
        export_dir, recorded = self._fresh_export(populated_store)
        first_artifact = export_dir / recorded.artifacts[0].path
        first_artifact.unlink()
        with pytest.raises(ExportManifestError, match="missing declared artifact"):
            verify_artifacts(export_dir, recorded)
277
278
class TestManifestLoadFailures:
    """Ways `load_audio_snapshot_manifest` fails with `ExportManifestError`."""

    def test_missing_manifest_raises(self, tmp_path: Path) -> None:
        """A directory with no manifest file is rejected."""
        with pytest.raises(ExportManifestError, match="missing"):
            load_audio_snapshot_manifest(tmp_path)

    def test_malformed_json_raises(self, tmp_path: Path) -> None:
        """A manifest file that is not JSON is rejected."""
        manifest_file = tmp_path / SNAPSHOT_MANIFEST_FILENAME
        manifest_file.write_text("not json", encoding="utf-8")
        with pytest.raises(ExportManifestError, match="cannot parse"):
            load_audio_snapshot_manifest(tmp_path)

    def test_invalid_shape_raises(self, tmp_path: Path) -> None:
        """Valid JSON carrying only `created_by` is rejected as an invalid shape."""
        incomplete = json.dumps({"created_by": "dlm-test"})
        manifest_file = tmp_path / SNAPSHOT_MANIFEST_FILENAME
        manifest_file.write_text(incomplete, encoding="utf-8")
        with pytest.raises(ExportManifestError, match="invalid shape"):
            load_audio_snapshot_manifest(tmp_path)
296
297
class TestManifestModelDirect:
    """Checks made directly against the `AudioSnapshotManifest` model."""

    def test_frozen(self) -> None:
        """Assigning to a field of a constructed manifest raises `ValidationError`."""
        from datetime import UTC, datetime

        from pydantic import ValidationError

        fields = {
            # Naive timestamp: the tzinfo is stripped after taking the UTC time.
            "created_at": datetime.now(UTC).replace(tzinfo=None),
            "created_by": "dlm-test",
            "base_model_hf_id": "x/y",
            "base_model_revision": "a" * 40,
            "base_model_architecture": "X",
            "audio_token": "<|AUDIO|>",
            "num_audio_tokens": 750,
            "sample_rate": 16_000,
            "max_length_seconds": 30.0,
            "adapter_version": 1,
        }
        manifest = AudioSnapshotManifest(**fields)

        with pytest.raises(ValidationError):
            manifest.adapter_version = 2  # type: ignore[misc]
318
319
class TestReadmeContent:
    """Text that the generated README is expected to contain."""

    def test_mentions_architecture_class(self, populated_store) -> None:
        """README names the architecture class, sample rate and audio token."""
        outcome = run_audio_snapshot_export(populated_store, _audio_spec())
        readme_text = outcome.readme_path.read_text(encoding="utf-8")
        expected_snippets = (
            # Architecture class, for the load snippet.
            "Qwen2AudioForConditionalGeneration",
            # Sample rate and placeholder token surface the runtime contract.
            "16000 Hz",
            "<|AUDIO|>",
        )
        for snippet in expected_snippets:
            assert snippet in readme_text