Python · 11204 bytes Raw Blame History
1 """HF-snapshot export for audio-language bases.
2
3 Parallel to `vl_snapshot.py`. GGUF conversion for audio architectures
4 isn't on llama.cpp's roadmap, so this path emits an HF snapshot: a
5 self-contained directory a recipient can load with
6 `AutoProcessor.from_pretrained` + architecture-class `from_pretrained`
7 + `PeftModel.from_pretrained`.
8
9 Layout under `exports/hf-audio-snapshot/`:
10
11 adapter/ # PEFT adapter dir (copy of the current store adapter)
12 processor/ # processor config + tokenizer + feature-extractor files
13 snapshot_manifest.json # see AudioSnapshotManifest
14 README.md # load instructions for the recipient
15
16 The base weights are NOT copied — pinned by `hf_id` + `revision` in
17 the manifest. Qwen2-Audio-7B is ~15 GB fp16 and marked
18 `redistributable=False` for our pack path; the same policy applies to
19 the snapshot so users don't accidentally ship base weights through
20 the pack system downstream.
21
22 The architectural symmetry to VL is intentional: both paths write a
23 `snapshot_manifest.json` with overlapping fields (`created_at`,
24 `adapter_version`, `artifacts`). A future refactor can extract a
25 shared base once a third media modality lands.
26 """
27
from __future__ import annotations

import json
import shutil
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path
from typing import TYPE_CHECKING, Literal

from pydantic import BaseModel, ConfigDict, Field, ValidationError

from dlm.export.errors import ExportError, ExportManifestError
from dlm.export.manifest import ExportArtifact, build_artifact, compute_sha256
from dlm.io.atomic import write_text

if TYPE_CHECKING:
    from dlm.base_models import BaseModelSpec
    from dlm.store.paths import StorePath
46
# Directory name, under the store's exports dir, that holds the audio snapshot.
AUDIO_SNAPSHOT_SUBDIR = "hf-audio-snapshot"
# Top-level files written into the export dir. They are written after the
# artifacts are hashed, so neither appears in the manifest's artifact list.
SNAPSHOT_MANIFEST_FILENAME = "snapshot_manifest.json"
SNAPSHOT_README_FILENAME = "README.md"
50
51
class AudioSnapshotManifest(BaseModel):
    """Self-describing record of one audio HF-snapshot export.

    Parallel to `VlSnapshotManifest` but scoped to the audio path.
    `export_target="hf_snapshot"` is the same value the VL manifest
    uses, so downstream tooling that iterates over all exports in a
    store has to discriminate on `modality`, not on `export_target`.

    Serialized to `snapshot_manifest.json` by `_save_manifest` and read
    back (and validated) by `load_audio_snapshot_manifest`.
    """

    # extra="forbid": unknown keys fail validation; frozen=True: instances
    # are immutable once built.
    model_config = ConfigDict(extra="forbid", frozen=True)

    export_target: Literal["hf_snapshot"] = "hf_snapshot"
    # Filled from `_utc_now()` by the exporter: a naive datetime holding UTC,
    # truncated to whole seconds.
    created_at: datetime
    created_by: str = Field(..., description="dlm version that wrote this manifest.")
    # The base weights are not shipped; the base is pinned by hf_id + revision.
    base_model_hf_id: str
    base_model_revision: str
    base_model_architecture: str
    modality: Literal["audio-language"] = "audio-language"
    # The next four fields are copied from the base spec's audio_preprocessor_plan.
    audio_token: str
    num_audio_tokens: int
    sample_rate: int
    max_length_seconds: float
    # Parsed from the adapter dir name ("vNNNN"); must be >= 1.
    adapter_version: int = Field(..., ge=1)
    # None when the default (unnamed) adapter was exported.
    adapter_name: str | None = None
    rationale: str = Field(
        default=(
            "Audio-language architectures are not on the llama.cpp "
            "roadmap; this build emits an HF-snapshot fallback so "
            "users can share trained adapters without waiting for "
            "upstream GGUF support."
        ),
    )
    # One entry per exported file (the manifest and README themselves are not
    # listed); `verify_artifacts` re-hashes each entry.
    artifacts: list[ExportArtifact] = Field(default_factory=list)
85
86
@dataclass(frozen=True)
class AudioSnapshotResult:
    """Return value of `run_audio_snapshot_export` — what the CLI prints.

    Every path points inside `export_dir`.
    """

    export_dir: Path  # `<store exports>/hf-audio-snapshot`
    manifest_path: Path  # `snapshot_manifest.json` inside export_dir
    readme_path: Path  # `README.md` inside export_dir
    adapter_dir: Path  # copy of the resolved PEFT adapter directory
    processor_dir: Path  # may not exist on disk if no processor was supplied
    artifacts: list[Path]  # files that were hashed into the manifest
97
98
def run_audio_snapshot_export(
    store: StorePath,
    spec: BaseModelSpec,
    *,
    adapter_name: str | None = None,
    adapter_path_override: Path | None = None,
    dlm_version: str = "dlm-0",
    processor: object | None = None,
) -> AudioSnapshotResult:
    """Emit an audio HF-snapshot export under `exports/hf-audio-snapshot/`.

    Resolves the adapter dir, copies it into the export directory,
    saves the processor (if supplied) under `processor/`, writes the
    manifest + README, and returns the layout paths.

    `processor=None` lets callers skip the processor save (tests, dry
    runs). Production paths pass an `AutoProcessor` loaded via
    `dlm.train.loader.load_processor`.

    Args:
        store: Store whose adapter is exported and whose `exports` dir
            receives the snapshot.
        spec: Base-model spec; must be an audio-language base with an
            `audio_preprocessor_plan`.
        adapter_name: Named adapter to export; None selects the default
            current adapter.
        adapter_path_override: Explicit adapter directory to export instead
            of the store's current adapter.
        dlm_version: Recorded as `created_by` in the manifest.
        processor: Object with a `save_pretrained(dir)` method, or None to
            skip writing `processor/`.

    Returns:
        The layout paths plus the list of hashed artifact files.

    Raises:
        ExportError: `spec` is not an audio-language base, has no audio
            preprocessor plan, or no adapter directory can be resolved.
    """
    if spec.modality != "audio-language":
        raise ExportError(
            f"run_audio_snapshot_export: {spec.key!r} is modality={spec.modality!r}; "
            "only audio-language bases go through the audio HF-snapshot path"
        )
    if spec.audio_preprocessor_plan is None:
        raise ExportError(
            f"run_audio_snapshot_export: {spec.key!r} has modality='audio-language' "
            "but no audio_preprocessor_plan (this is a schema bug — file an issue)"
        )

    adapter_path, adapter_version = _resolve_adapter_for_export(
        store=store,
        adapter_name=adapter_name,
        adapter_path_override=adapter_path_override,
    )

    export_dir = store.exports / AUDIO_SNAPSHOT_SUBDIR
    export_dir.mkdir(parents=True, exist_ok=True)

    adapter_out = export_dir / "adapter"
    _copy_adapter_dir(adapter_path, adapter_out)

    processor_out = export_dir / "processor"
    if processor is not None:
        # Start from an empty directory (as the adapter copy does) so files
        # left by a previous export cannot be hashed into this manifest.
        if processor_out.exists():
            shutil.rmtree(processor_out)
        processor_out.mkdir(parents=True, exist_ok=True)
        save = getattr(processor, "save_pretrained", None)
        if callable(save):
            save(str(processor_out))

    # Skip only the top-level manifest/README written below. Matching on the
    # basename alone would also drop nested files with the same name (e.g. a
    # README.md model card PEFT may put in adapter/), leaving them unverified.
    reserved = {
        export_dir / SNAPSHOT_MANIFEST_FILENAME,
        export_dir / SNAPSHOT_README_FILENAME,
    }
    artifacts: list[Path] = [
        path
        for path in sorted(export_dir.rglob("*"))
        if path.is_file() and path not in reserved
    ]

    plan = spec.audio_preprocessor_plan
    manifest = AudioSnapshotManifest(
        created_at=_utc_now(),
        created_by=dlm_version,
        base_model_hf_id=spec.hf_id,
        base_model_revision=spec.revision,
        base_model_architecture=spec.architecture,
        audio_token=plan.audio_token,
        num_audio_tokens=plan.num_audio_tokens,
        sample_rate=plan.sample_rate,
        max_length_seconds=plan.max_length_seconds,
        adapter_version=adapter_version,
        adapter_name=adapter_name,
        artifacts=[build_artifact(export_dir, p) for p in artifacts],
    )
    manifest_path = _save_manifest(export_dir, manifest)
    readme_path = _write_readme(export_dir, spec=spec, manifest=manifest)

    return AudioSnapshotResult(
        export_dir=export_dir,
        manifest_path=manifest_path,
        readme_path=readme_path,
        adapter_dir=adapter_out,
        processor_dir=processor_out,
        artifacts=artifacts,
    )
182
183
184 # --- internals ---------------------------------------------------------------
185
186
def _resolve_adapter_for_export(
    *,
    store: StorePath,
    adapter_name: str | None,
    adapter_path_override: Path | None,
) -> tuple[Path, int]:
    """Pick the adapter directory to export and its version number.

    An explicit override wins. Otherwise the store's current-adapter
    pointer is followed, either the default one or the one belonging to
    `adapter_name`.

    Raises:
        ExportError: the override path does not exist, or the pointer does
            not resolve to an existing directory.
    """
    override = adapter_path_override
    if override is not None:
        if override.exists():
            return override, _version_from_dir_name(override)
        raise ExportError(f"adapter_path_override {adapter_path_override} does not exist")

    if adapter_name is not None:
        candidate = store.resolve_current_adapter_for(adapter_name)
        pointer = store.adapter_current_pointer_for(adapter_name)
    else:
        candidate = store.resolve_current_adapter()
        pointer = store.adapter_current_pointer

    if candidate is not None and candidate.exists():
        return candidate, _version_from_dir_name(candidate)
    raise ExportError(f"no current adapter under {pointer}; run `dlm train` before exporting.")
209
210
211 def _version_from_dir_name(path: Path) -> int:
212 stem = path.name
213 if not stem.startswith("v") or not stem[1:].isdigit():
214 return 1
215 return int(stem[1:])
216
217
218 def _copy_adapter_dir(src: Path, dst: Path) -> None:
219 if dst.exists():
220 shutil.rmtree(dst)
221 shutil.copytree(src, dst)
222
223
def _save_manifest(export_dir: Path, manifest: AudioSnapshotManifest) -> Path:
    """Write `manifest` as key-sorted, 2-space-indented JSON with a trailing newline.

    Returns the path of the written `snapshot_manifest.json`.
    """
    target = export_dir / SNAPSHOT_MANIFEST_FILENAME
    text = json.dumps(manifest.model_dump(mode="json"), indent=2, sort_keys=True)
    write_text(target, f"{text}\n")
    return target
230
231
def _write_readme(
    export_dir: Path,
    *,
    spec: BaseModelSpec,
    manifest: AudioSnapshotManifest,
) -> Path:
    """Render the recipient-facing load instructions into `README.md`.

    The text names the pinned base model, the adapter version, a Python
    snippet that loads base + adapter + processor, and the audio input
    parameters recorded in `manifest`. Returns the README path.
    """
    arch = spec.architecture
    name_suffix = f" ({manifest.adapter_name})" if manifest.adapter_name else ""
    lines = [
        "# HF-audio-snapshot export",
        "",
        f"Target: **{spec.key}** ({spec.hf_id} @ {spec.revision[:12]}…)",
        f"Adapter version: v{manifest.adapter_version:04d}{name_suffix}",
        "",
        "## Load this snapshot",
        "",
        "```python",
        f"from transformers import AutoProcessor, {arch}",
        "from peft import PeftModel",
        "",
        f"base = {arch}.from_pretrained(",
        f' "{spec.hf_id}", revision="{spec.revision}",',
        ")",
        'model = PeftModel.from_pretrained(base, "./adapter")',
        'processor = AutoProcessor.from_pretrained("./processor")',
        "```",
        "",
        "## Audio input shape",
        "",
        f"- Sample rate: {manifest.sample_rate} Hz (required; mismatches refused)",
        f"- Max length: {manifest.max_length_seconds} s (longer clips truncated)",
        f"- Placeholder token: `{manifest.audio_token}` "
        f"(expands to {manifest.num_audio_tokens} tokens per clip)",
        "",
        "## Why HF snapshot (not GGUF)",
        "",
        "Audio-language architectures are not currently supported by",
        "`llama.cpp`. The HF-snapshot path gives you a portable adapter",
        "directory that loads on any PyTorch + transformers install.",
    ]
    readme = export_dir / SNAPSHOT_README_FILENAME
    # Every line, including the last, is newline-terminated.
    write_text(readme, "\n".join(lines) + "\n")
    return readme
275
276
277 def _utc_now() -> datetime:
278 return datetime.now(UTC).replace(tzinfo=None, microsecond=0)
279
280
def load_audio_snapshot_manifest(export_dir: Path) -> AudioSnapshotManifest:
    """Read + validate `<export_dir>/snapshot_manifest.json`.

    Raises:
        ExportManifestError: the file is missing, cannot be read, is not
            UTF-8 JSON, or does not match `AudioSnapshotManifest`.
    """
    path = export_dir / SNAPSHOT_MANIFEST_FILENAME
    # EAFP: read directly instead of exists()-then-read, which could race.
    try:
        raw = path.read_text(encoding="utf-8")
    except FileNotFoundError:
        raise ExportManifestError(f"missing {path}") from None
    except (OSError, UnicodeDecodeError) as exc:
        # UnicodeDecodeError is a ValueError, not an OSError, so a corrupt
        # (non-UTF-8) manifest previously escaped as a raw exception.
        raise ExportManifestError(f"cannot parse {path}: {exc}") from exc
    try:
        data = json.loads(raw)
    except json.JSONDecodeError as exc:
        raise ExportManifestError(f"cannot parse {path}: {exc}") from exc
    try:
        return AudioSnapshotManifest.model_validate(data)
    except ValidationError as exc:
        raise ExportManifestError(f"{path} has invalid shape: {exc}") from exc
294
295
def verify_artifacts(export_dir: Path, manifest: AudioSnapshotManifest) -> None:
    """Re-hash every artifact listed in `manifest` and compare with the recorded digest.

    Only listed entries are checked; extra files on disk are ignored.

    Raises:
        ExportManifestError: on the first listed artifact that is missing
            under `export_dir` or whose recomputed hash differs.
    """
    for artifact in manifest.artifacts:
        candidate = export_dir / artifact.path
        if not candidate.exists():
            raise ExportManifestError(f"missing declared artifact: {candidate}")
        digest = compute_sha256(candidate)
        if digest == artifact.sha256:
            continue
        raise ExportManifestError(
            f"sha256 mismatch for {artifact.path}: "
            f"manifest={artifact.sha256[:12]}… disk={digest[:12]}…"
        )